Skip to main content

edgefirst_decoder/
decoder.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashSet;
5
6use ndarray::{s, Array3, ArrayView, ArrayViewD, Dimension};
7use ndarray_stats::QuantileExt;
8use num_traits::{AsPrimitive, Float};
9use serde::{Deserialize, Serialize};
10
11use crate::{
12    configs::{DecoderType, DimName, ModelType, QuantTuple},
13    dequantize_ndarray,
14    modelpack::{
15        decode_modelpack_det, decode_modelpack_float, decode_modelpack_split_float,
16        ModelPackDetectionConfig,
17    },
18    yolo::{
19        decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float, decode_yolo_segdet_quant,
20        decode_yolo_split_det_float, decode_yolo_split_det_quant, decode_yolo_split_segdet_float,
21        impl_yolo_split_segdet_quant_get_boxes, impl_yolo_split_segdet_quant_process_masks,
22    },
23    DecoderError, DecoderVersion, DetectBox, Quantization, Segmentation, XYWH,
24};
25
26/// Used to represent the outputs in the model configuration.
27/// # Examples
28/// ```rust
29/// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutputs};
30/// # fn main() -> DecoderResult<()> {
31/// let config_json = include_str!("../../../testdata/modelpack_split.json");
32/// let config: ConfigOutputs = serde_json::from_str(config_json)?;
33/// let decoder = DecoderBuilder::new().with_config(config).build()?;
34///
35/// # Ok(())
36/// # }
37#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
38pub struct ConfigOutputs {
39    #[serde(default)]
40    pub outputs: Vec<ConfigOutput>,
41    /// NMS mode from config file. When present, overrides the builder's NMS
42    /// setting.
43    /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS: suppress overlapping
44    ///   boxes regardless of class
45    /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes with
46    ///   the same class
47    /// - `None` — use builder default or skip NMS (user handles it externally)
48    #[serde(default, skip_serializing_if = "Option::is_none")]
49    pub nms: Option<configs::Nms>,
50    /// Decoder version for Ultralytics models. Determines the decoding
51    /// strategy.
52    /// - `Some(Yolo26)` — end-to-end model with embedded NMS
53    /// - `Some(Yolov5/Yolov8/Yolo11)` — traditional models requiring external
54    ///   NMS
55    /// - `None` — infer from other settings (legacy behavior)
56    #[serde(default, skip_serializing_if = "Option::is_none")]
57    pub decoder_version: Option<configs::DecoderVersion>,
58}
59
60#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
61#[serde(tag = "type")]
62pub enum ConfigOutput {
63    #[serde(rename = "detection")]
64    Detection(configs::Detection),
65    #[serde(rename = "masks")]
66    Mask(configs::Mask),
67    #[serde(rename = "segmentation")]
68    Segmentation(configs::Segmentation),
69    #[serde(rename = "protos")]
70    Protos(configs::Protos),
71    #[serde(rename = "scores")]
72    Scores(configs::Scores),
73    #[serde(rename = "boxes")]
74    Boxes(configs::Boxes),
75    #[serde(rename = "mask_coefficients")]
76    MaskCoefficients(configs::MaskCoefficients),
77}
78
79#[derive(Debug, PartialEq, Clone)]
80pub enum ConfigOutputRef<'a> {
81    Detection(&'a configs::Detection),
82    Mask(&'a configs::Mask),
83    Segmentation(&'a configs::Segmentation),
84    Protos(&'a configs::Protos),
85    Scores(&'a configs::Scores),
86    Boxes(&'a configs::Boxes),
87    MaskCoefficients(&'a configs::MaskCoefficients),
88}
89
90impl<'a> ConfigOutputRef<'a> {
91    fn decoder(&self) -> configs::DecoderType {
92        match self {
93            ConfigOutputRef::Detection(v) => v.decoder,
94            ConfigOutputRef::Mask(v) => v.decoder,
95            ConfigOutputRef::Segmentation(v) => v.decoder,
96            ConfigOutputRef::Protos(v) => v.decoder,
97            ConfigOutputRef::Scores(v) => v.decoder,
98            ConfigOutputRef::Boxes(v) => v.decoder,
99            ConfigOutputRef::MaskCoefficients(v) => v.decoder,
100        }
101    }
102
103    fn dshape(&self) -> &[(DimName, usize)] {
104        match self {
105            ConfigOutputRef::Detection(v) => &v.dshape,
106            ConfigOutputRef::Mask(v) => &v.dshape,
107            ConfigOutputRef::Segmentation(v) => &v.dshape,
108            ConfigOutputRef::Protos(v) => &v.dshape,
109            ConfigOutputRef::Scores(v) => &v.dshape,
110            ConfigOutputRef::Boxes(v) => &v.dshape,
111            ConfigOutputRef::MaskCoefficients(v) => &v.dshape,
112        }
113    }
114}
115
116impl<'a> From<&'a configs::Detection> for ConfigOutputRef<'a> {
117    /// Converts from references of config structs to ConfigOutputRef
118    /// # Examples
119    /// ```rust
120    /// # use edgefirst_decoder::{configs, ConfigOutputRef};
121    /// let detection_config = configs::Detection {
122    ///     anchors: None,
123    ///     decoder: configs::DecoderType::Ultralytics,
124    ///     quantization: None,
125    ///     shape: vec![1, 84, 8400],
126    ///     dshape: Vec::new(),
127    ///     normalized: Some(true),
128    /// };
129    /// let output: ConfigOutputRef = (&detection_config).into();
130    /// ```
131    fn from(v: &'a configs::Detection) -> ConfigOutputRef<'a> {
132        ConfigOutputRef::Detection(v)
133    }
134}
135
136impl<'a> From<&'a configs::Mask> for ConfigOutputRef<'a> {
137    /// Converts from references of config structs to ConfigOutputRef
138    /// # Examples
139    /// ```rust
140    /// # use edgefirst_decoder::{configs, ConfigOutputRef};
141    /// let mask = configs::Mask {
142    ///     decoder: configs::DecoderType::ModelPack,
143    ///     quantization: None,
144    ///     shape: vec![1, 160, 160, 1],
145    ///     dshape: Vec::new(),
146    /// };
147    /// let output: ConfigOutputRef = (&mask).into();
148    /// ```
149    fn from(v: &'a configs::Mask) -> ConfigOutputRef<'a> {
150        ConfigOutputRef::Mask(v)
151    }
152}
153
154impl<'a> From<&'a configs::Segmentation> for ConfigOutputRef<'a> {
155    /// Converts from references of config structs to ConfigOutputRef
156    /// # Examples
157    /// ```rust
158    /// # use edgefirst_decoder::{configs, ConfigOutputRef};
159    /// let seg = configs::Segmentation {
160    ///     decoder: configs::DecoderType::ModelPack,
161    ///     quantization: None,
162    ///     shape: vec![1, 160, 160, 3],
163    ///     dshape: Vec::new(),
164    /// };
165    /// let output: ConfigOutputRef = (&seg).into();
166    /// ```
167    fn from(v: &'a configs::Segmentation) -> ConfigOutputRef<'a> {
168        ConfigOutputRef::Segmentation(v)
169    }
170}
171
172impl<'a> From<&'a configs::Protos> for ConfigOutputRef<'a> {
173    /// Converts from references of config structs to ConfigOutputRef
174    /// # Examples
175    /// ```rust
176    /// # use edgefirst_decoder::{configs, ConfigOutputRef};
177    /// let protos = configs::Protos {
178    ///     decoder: configs::DecoderType::Ultralytics,
179    ///     quantization: None,
180    ///     shape: vec![1, 160, 160, 32],
181    ///     dshape: Vec::new(),
182    /// };
183    /// let output: ConfigOutputRef = (&protos).into();
184    /// ```
185    fn from(v: &'a configs::Protos) -> ConfigOutputRef<'a> {
186        ConfigOutputRef::Protos(v)
187    }
188}
189
190impl<'a> From<&'a configs::Scores> for ConfigOutputRef<'a> {
191    /// Converts from references of config structs to ConfigOutputRef
192    /// # Examples
193    /// ```rust
194    /// # use edgefirst_decoder::{configs, ConfigOutputRef};
195    /// let scores = configs::Scores {
196    ///     decoder: configs::DecoderType::Ultralytics,
197    ///     quantization: None,
198    ///     shape: vec![1, 40, 8400],
199    ///     dshape: Vec::new(),
200    /// };
201    /// let output: ConfigOutputRef = (&scores).into();
202    /// ```
203    fn from(v: &'a configs::Scores) -> ConfigOutputRef<'a> {
204        ConfigOutputRef::Scores(v)
205    }
206}
207
208impl<'a> From<&'a configs::Boxes> for ConfigOutputRef<'a> {
209    /// Converts from references of config structs to ConfigOutputRef
210    /// # Examples
211    /// ```rust
212    /// # use edgefirst_decoder::{configs, ConfigOutputRef};
213    /// let boxes = configs::Boxes {
214    ///     decoder: configs::DecoderType::Ultralytics,
215    ///     quantization: None,
216    ///     shape: vec![1, 4, 8400],
217    ///     dshape: Vec::new(),
218    ///     normalized: Some(true),
219    /// };
220    /// let output: ConfigOutputRef = (&boxes).into();
221    /// ```
222    fn from(v: &'a configs::Boxes) -> ConfigOutputRef<'a> {
223        ConfigOutputRef::Boxes(v)
224    }
225}
226
227impl<'a> From<&'a configs::MaskCoefficients> for ConfigOutputRef<'a> {
228    /// Converts from references of config structs to ConfigOutputRef
229    /// # Examples
230    /// ```rust
231    /// # use edgefirst_decoder::{configs, ConfigOutputRef};
232    /// let mask_coefficients = configs::MaskCoefficients {
233    ///     decoder: configs::DecoderType::Ultralytics,
234    ///     quantization: None,
235    ///     shape: vec![1, 32, 8400],
236    ///     dshape: Vec::new(),
237    /// };
238    /// let output: ConfigOutputRef = (&mask_coefficients).into();
239    /// ```
240    fn from(v: &'a configs::MaskCoefficients) -> ConfigOutputRef<'a> {
241        ConfigOutputRef::MaskCoefficients(v)
242    }
243}
244
245impl ConfigOutput {
246    /// Returns the shape of the output.
247    ///
248    /// # Examples
249    /// ```rust
250    /// # use edgefirst_decoder::{configs, ConfigOutput};
251    /// let detection_config = configs::Detection {
252    ///     anchors: None,
253    ///     decoder: configs::DecoderType::Ultralytics,
254    ///     quantization: None,
255    ///     shape: vec![1, 84, 8400],
256    ///     dshape: Vec::new(),
257    ///     normalized: Some(true),
258    /// };
259    /// let output = ConfigOutput::Detection(detection_config);
260    /// assert_eq!(output.shape(), &[1, 84, 8400]);
261    /// ```
262    pub fn shape(&self) -> &[usize] {
263        match self {
264            ConfigOutput::Detection(detection) => &detection.shape,
265            ConfigOutput::Mask(mask) => &mask.shape,
266            ConfigOutput::Segmentation(segmentation) => &segmentation.shape,
267            ConfigOutput::Scores(scores) => &scores.shape,
268            ConfigOutput::Boxes(boxes) => &boxes.shape,
269            ConfigOutput::Protos(protos) => &protos.shape,
270            ConfigOutput::MaskCoefficients(mask_coefficients) => &mask_coefficients.shape,
271        }
272    }
273
274    /// Returns the decoder type of the output.
275    ///    
276    /// # Examples
277    /// ```rust
278    /// # use edgefirst_decoder::{configs, ConfigOutput};
279    /// let detection_config = configs::Detection {
280    ///     anchors: None,
281    ///     decoder: configs::DecoderType::Ultralytics,
282    ///     quantization: None,
283    ///     shape: vec![1, 84, 8400],
284    ///     dshape: Vec::new(),
285    ///     normalized: Some(true),
286    /// };
287    /// let output = ConfigOutput::Detection(detection_config);
288    /// assert_eq!(output.decoder(), &configs::DecoderType::Ultralytics);
289    /// ```
290    pub fn decoder(&self) -> &configs::DecoderType {
291        match self {
292            ConfigOutput::Detection(detection) => &detection.decoder,
293            ConfigOutput::Mask(mask) => &mask.decoder,
294            ConfigOutput::Segmentation(segmentation) => &segmentation.decoder,
295            ConfigOutput::Scores(scores) => &scores.decoder,
296            ConfigOutput::Boxes(boxes) => &boxes.decoder,
297            ConfigOutput::Protos(protos) => &protos.decoder,
298            ConfigOutput::MaskCoefficients(mask_coefficients) => &mask_coefficients.decoder,
299        }
300    }
301
302    /// Returns the quantization of the output.
303    ///
304    /// # Examples
305    /// ```rust
306    /// # use edgefirst_decoder::{configs, ConfigOutput};
307    /// let detection_config = configs::Detection {
308    ///   anchors: None,
309    ///   decoder: configs::DecoderType::Ultralytics,
310    ///   quantization: Some(configs::QuantTuple(0.012345, 26)),
311    ///   shape: vec![1, 84, 8400],
312    ///   dshape: Vec::new(),
313    ///   normalized: Some(true),
314    /// };
315    /// let output = ConfigOutput::Detection(detection_config);
316    /// assert_eq!(output.quantization(),
317    /// Some(configs::QuantTuple(0.012345,26))); ```
318    pub fn quantization(&self) -> Option<QuantTuple> {
319        match self {
320            ConfigOutput::Detection(detection) => detection.quantization,
321            ConfigOutput::Mask(mask) => mask.quantization,
322            ConfigOutput::Segmentation(segmentation) => segmentation.quantization,
323            ConfigOutput::Scores(scores) => scores.quantization,
324            ConfigOutput::Boxes(boxes) => boxes.quantization,
325            ConfigOutput::Protos(protos) => protos.quantization,
326            ConfigOutput::MaskCoefficients(mask_coefficients) => mask_coefficients.quantization,
327        }
328    }
329}
330
331pub mod configs {
332    use std::fmt::Display;
333
334    use serde::{Deserialize, Serialize};
335
336    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy)]
337    pub struct QuantTuple(pub f32, pub i32);
338    impl From<QuantTuple> for (f32, i32) {
339        fn from(value: QuantTuple) -> Self {
340            (value.0, value.1)
341        }
342    }
343
344    impl From<(f32, i32)> for QuantTuple {
345        fn from(value: (f32, i32)) -> Self {
346            QuantTuple(value.0, value.1)
347        }
348    }
349
350    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
351    pub struct Segmentation {
352        pub decoder: DecoderType,
353        pub quantization: Option<QuantTuple>,
354        pub shape: Vec<usize>,
355        // #[serde(default)]
356        // pub channels_first: bool,
357        #[serde(default)]
358        pub dshape: Vec<(DimName, usize)>,
359    }
360
361    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
362    pub struct Protos {
363        pub decoder: DecoderType,
364        pub quantization: Option<QuantTuple>,
365        pub shape: Vec<usize>,
366        // #[serde(default)]
367        // pub channels_first: bool,
368        #[serde(default)]
369        pub dshape: Vec<(DimName, usize)>,
370    }
371
372    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
373    pub struct MaskCoefficients {
374        pub decoder: DecoderType,
375        pub quantization: Option<QuantTuple>,
376        pub shape: Vec<usize>,
377        // #[serde(default)]
378        // pub channels_first: bool,
379        #[serde(default)]
380        pub dshape: Vec<(DimName, usize)>,
381    }
382
383    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
384    pub struct Mask {
385        pub decoder: DecoderType,
386        pub quantization: Option<QuantTuple>,
387        pub shape: Vec<usize>,
388        // #[serde(default)]
389        // pub channels_first: bool,
390        #[serde(default)]
391        pub dshape: Vec<(DimName, usize)>,
392    }
393
394    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
395    pub struct Detection {
396        pub anchors: Option<Vec<[f32; 2]>>,
397        pub decoder: DecoderType,
398        pub quantization: Option<QuantTuple>,
399        pub shape: Vec<usize>,
400        // #[serde(default)]
401        // pub channels_first: bool,
402        #[serde(default)]
403        pub dshape: Vec<(DimName, usize)>,
404        /// Whether box coordinates are normalized to [0,1] range.
405        /// - `Some(true)`: Coordinates in [0,1] range relative to model input
406        /// - `Some(false)`: Pixel coordinates relative to model input
407        ///   (letterboxed)
408        /// - `None`: Unknown, caller must infer (e.g., check if any coordinate
409        ///   > 1.0)
410        #[serde(default)]
411        pub normalized: Option<bool>,
412    }
413
414    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
415    pub struct Scores {
416        pub decoder: DecoderType,
417        pub quantization: Option<QuantTuple>,
418        pub shape: Vec<usize>,
419        // #[serde(default)]
420        // pub channels_first: bool,
421        #[serde(default)]
422        pub dshape: Vec<(DimName, usize)>,
423    }
424
425    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
426    pub struct Boxes {
427        pub decoder: DecoderType,
428        pub quantization: Option<QuantTuple>,
429        pub shape: Vec<usize>,
430        // #[serde(default)]
431        // pub channels_first: bool,
432        #[serde(default)]
433        pub dshape: Vec<(DimName, usize)>,
434        /// Whether box coordinates are normalized to [0,1] range.
435        /// - `Some(true)`: Coordinates in [0,1] range relative to model input
436        /// - `Some(false)`: Pixel coordinates relative to model input
437        ///   (letterboxed)
438        /// - `None`: Unknown, caller must infer (e.g., check if any coordinate
439        ///   > 1.0)
440        #[serde(default)]
441        pub normalized: Option<bool>,
442    }
443
444    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
445    pub enum DimName {
446        #[serde(rename = "batch")]
447        Batch,
448        #[serde(rename = "height")]
449        Height,
450        #[serde(rename = "width")]
451        Width,
452        #[serde(rename = "num_classes")]
453        NumClasses,
454        #[serde(rename = "num_features")]
455        NumFeatures,
456        #[serde(rename = "num_boxes")]
457        NumBoxes,
458        #[serde(rename = "num_protos")]
459        NumProtos,
460        #[serde(rename = "num_anchors_x_features")]
461        NumAnchorsXFeatures,
462        #[serde(rename = "padding")]
463        Padding,
464        #[serde(rename = "box_coords")]
465        BoxCoords,
466    }
467
468    impl Display for DimName {
469        /// Formats the DimName for display
470        /// # Examples
471        /// ```rust
472        /// # use edgefirst_decoder::configs::DimName;
473        /// let dim = DimName::Height;
474        /// assert_eq!(format!("{}", dim), "height");
475        /// # let s = format!("{} {} {} {} {} {} {} {} {} {}", DimName::Batch, DimName::Height, DimName::Width, DimName::NumClasses, DimName::NumFeatures, DimName::NumBoxes, DimName::NumProtos, DimName::NumAnchorsXFeatures, DimName::Padding, DimName::BoxCoords);
476        /// # assert_eq!(s, "batch height width num_classes num_features num_boxes num_protos num_anchors_x_features padding box_coords");
477        /// ```
478        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
479            match self {
480                DimName::Batch => write!(f, "batch"),
481                DimName::Height => write!(f, "height"),
482                DimName::Width => write!(f, "width"),
483                DimName::NumClasses => write!(f, "num_classes"),
484                DimName::NumFeatures => write!(f, "num_features"),
485                DimName::NumBoxes => write!(f, "num_boxes"),
486                DimName::NumProtos => write!(f, "num_protos"),
487                DimName::NumAnchorsXFeatures => write!(f, "num_anchors_x_features"),
488                DimName::Padding => write!(f, "padding"),
489                DimName::BoxCoords => write!(f, "box_coords"),
490            }
491        }
492    }
493
494    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
495    pub enum DecoderType {
496        #[serde(rename = "modelpack")]
497        ModelPack,
498        #[serde(rename = "ultralytics")]
499        Ultralytics,
500    }
501
502    /// Decoder version for Ultralytics models.
503    ///
504    /// Specifies the YOLO architecture version, which determines the decoding
505    /// strategy:
506    /// - `Yolov5`, `Yolov8`, `Yolo11`: Traditional models requiring external
507    ///   NMS
508    /// - `Yolo26`: End-to-end models with NMS embedded in the model
509    ///   architecture
510    ///
511    /// When `decoder_version` is set to `Yolo26`, the decoder uses end-to-end
512    /// model types regardless of the `nms` setting.
513    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
514    #[serde(rename_all = "lowercase")]
515    pub enum DecoderVersion {
516        /// YOLOv5 - anchor-free DFL decoder, requires external NMS
517        #[serde(rename = "yolov5")]
518        Yolov5,
519        /// YOLOv8 - anchor-free DFL decoder, requires external NMS
520        #[serde(rename = "yolov8")]
521        Yolov8,
522        /// YOLO11 - anchor-free DFL decoder, requires external NMS
523        #[serde(rename = "yolo11")]
524        Yolo11,
525        /// YOLO26 - end-to-end model with embedded NMS (one-to-one matching
526        /// heads)
527        #[serde(rename = "yolo26")]
528        Yolo26,
529    }
530
531    impl DecoderVersion {
532        /// Returns true if this version uses end-to-end inference (embedded
533        /// NMS).
534        pub fn is_end_to_end(&self) -> bool {
535            matches!(self, DecoderVersion::Yolo26)
536        }
537    }
538
539    /// NMS (Non-Maximum Suppression) mode for filtering overlapping detections.
540    ///
541    /// This enum is used with `Option<Nms>`:
542    /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS (default): suppress
543    ///   overlapping boxes regardless of class label
544    /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
545    ///   share the same class label AND overlap above the IoU threshold
546    /// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
547    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq, Default)]
548    #[serde(rename_all = "snake_case")]
549    pub enum Nms {
550        /// Suppress overlapping boxes regardless of class label (default HAL
551        /// behavior)
552        #[default]
553        ClassAgnostic,
554        /// Only suppress boxes with the same class label that overlap
555        ClassAware,
556    }
557
558    #[derive(Debug, Clone, PartialEq)]
559    pub enum ModelType {
560        ModelPackSegDet {
561            boxes: Boxes,
562            scores: Scores,
563            segmentation: Segmentation,
564        },
565        ModelPackSegDetSplit {
566            detection: Vec<Detection>,
567            segmentation: Segmentation,
568        },
569        ModelPackDet {
570            boxes: Boxes,
571            scores: Scores,
572        },
573        ModelPackDetSplit {
574            detection: Vec<Detection>,
575        },
576        ModelPackSeg {
577            segmentation: Segmentation,
578        },
579        YoloDet {
580            boxes: Detection,
581        },
582        YoloSegDet {
583            boxes: Detection,
584            protos: Protos,
585        },
586        YoloSplitDet {
587            boxes: Boxes,
588            scores: Scores,
589        },
590        YoloSplitSegDet {
591            boxes: Boxes,
592            scores: Scores,
593            mask_coeff: MaskCoefficients,
594            protos: Protos,
595        },
596        /// End-to-end YOLO detection (post-NMS output from model)
597        /// Input shape: (1, N, 6+) where columns are [x1, y1, x2, y2, conf,
598        /// class, ...]
599        YoloEndToEndDet {
600            boxes: Detection,
601        },
602        /// End-to-end YOLO detection + segmentation (post-NMS output from
603        /// model) Input shape: (1, N, 6 + num_protos) where columns are
604        /// [x1, y1, x2, y2, conf, class, mask_coeff_0, ..., mask_coeff_31]
605        YoloEndToEndSegDet {
606            boxes: Detection,
607            protos: Protos,
608        },
609    }
610
611    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
612    #[serde(rename_all = "lowercase")]
613    pub enum DataType {
614        Raw = 0,
615        Int8 = 1,
616        UInt8 = 2,
617        Int16 = 3,
618        UInt16 = 4,
619        Float16 = 5,
620        Int32 = 6,
621        UInt32 = 7,
622        Float32 = 8,
623        Int64 = 9,
624        UInt64 = 10,
625        Float64 = 11,
626        String = 12,
627    }
628}
629
630#[derive(Debug, Clone, PartialEq)]
631pub struct DecoderBuilder {
632    config_src: Option<ConfigSource>,
633    iou_threshold: f32,
634    score_threshold: f32,
635    /// NMS mode: Some(mode) applies NMS, None bypasses NMS (for end-to-end
636    /// models)
637    nms: Option<configs::Nms>,
638}
639
640#[derive(Debug, Clone, PartialEq)]
641enum ConfigSource {
642    Yaml(String),
643    Json(String),
644    Config(ConfigOutputs),
645}
646
647impl Default for DecoderBuilder {
648    /// Creates a default DecoderBuilder with no configuration and 0.5 score
649    /// threshold and 0.5 OU threshold.
650    ///
651    /// A valid confguration must be provided before building the Decoder.
652    ///
653    /// # Examples
654    /// ```rust
655    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
656    /// # fn main() -> DecoderResult<()> {
657    /// #  let config_yaml = include_str!("../../../testdata/modelpack_split.yaml").to_string();
658    /// let decoder = DecoderBuilder::default()
659    ///     .with_config_yaml_str(config_yaml)
660    ///     .build()?;
661    /// assert_eq!(decoder.score_threshold, 0.5);
662    /// assert_eq!(decoder.iou_threshold, 0.5);
663    ///
664    /// # Ok(())
665    /// # }
666    /// ```
667    fn default() -> Self {
668        Self {
669            config_src: None,
670            iou_threshold: 0.5,
671            score_threshold: 0.5,
672            nms: Some(configs::Nms::ClassAgnostic),
673        }
674    }
675}
676
677impl DecoderBuilder {
678    /// Creates a default DecoderBuilder with no configuration and 0.5 score
679    /// threshold and 0.5 OU threshold.
680    ///
681    /// A valid confguration must be provided before building the Decoder.
682    ///
683    /// # Examples
684    /// ```rust
685    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
686    /// # fn main() -> DecoderResult<()> {
687    /// #  let config_yaml = include_str!("../../../testdata/modelpack_split.yaml").to_string();
688    /// let decoder = DecoderBuilder::new()
689    ///     .with_config_yaml_str(config_yaml)
690    ///     .build()?;
691    /// assert_eq!(decoder.score_threshold, 0.5);
692    /// assert_eq!(decoder.iou_threshold, 0.5);
693    ///
694    /// # Ok(())
695    /// # }
696    /// ```
697    pub fn new() -> Self {
698        Self::default()
699    }
700
701    /// Loads a model configuration in YAML format. Does not check if the string
702    /// is a correct configuration file. Use `DecoderBuilder.build()` to
703    /// deserialize the YAML and parse the model configuration.
704    ///
705    /// # Examples
706    /// ```rust
707    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
708    /// # fn main() -> DecoderResult<()> {
709    /// let config_yaml = include_str!("../../../testdata/modelpack_split.yaml").to_string();
710    /// let decoder = DecoderBuilder::new()
711    ///     .with_config_yaml_str(config_yaml)
712    ///     .build()?;
713    ///
714    /// # Ok(())
715    /// # }
716    /// ```
717    pub fn with_config_yaml_str(mut self, yaml_str: String) -> Self {
718        self.config_src.replace(ConfigSource::Yaml(yaml_str));
719        self
720    }
721
722    /// Loads a model configuration in JSON format. Does not check if the string
723    /// is a correct configuration file. Use `DecoderBuilder.build()` to
724    /// deserialize the JSON and parse the model configuration.
725    ///
726    /// # Examples
727    /// ```rust
728    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
729    /// # fn main() -> DecoderResult<()> {
730    /// let config_json = include_str!("../../../testdata/modelpack_split.json").to_string();
731    /// let decoder = DecoderBuilder::new()
732    ///     .with_config_json_str(config_json)
733    ///     .build()?;
734    ///
735    /// # Ok(())
736    /// # }
737    /// ```
738    pub fn with_config_json_str(mut self, json_str: String) -> Self {
739        self.config_src.replace(ConfigSource::Json(json_str));
740        self
741    }
742
743    /// Loads a model configuration. Does not check if the configuration is
744    /// correct. Intended to be used when the user needs control over the
745    /// deserialize of the configuration information. Use
746    /// `DecoderBuilder.build()` to parse the model configuration.
747    ///
748    /// # Examples
749    /// ```rust
750    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
751    /// # fn main() -> DecoderResult<()> {
752    /// let config_json = include_str!("../../../testdata/modelpack_split.json");
753    /// let config = serde_json::from_str(config_json)?;
754    /// let decoder = DecoderBuilder::new().with_config(config).build()?;
755    ///
756    /// # Ok(())
757    /// # }
758    /// ```
759    pub fn with_config(mut self, config: ConfigOutputs) -> Self {
760        self.config_src.replace(ConfigSource::Config(config));
761        self
762    }
763
764    /// Loads a YOLO detection model configuration.  Use
765    /// `DecoderBuilder.build()` to parse the model configuration.
766    ///
767    /// # Examples
768    /// ```rust
769    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
770    /// # fn main() -> DecoderResult<()> {
771    /// let decoder = DecoderBuilder::new()
772    ///     .with_config_yolo_det(
773    ///         configs::Detection {
774    ///             anchors: None,
775    ///             decoder: configs::DecoderType::Ultralytics,
776    ///             quantization: Some(configs::QuantTuple(0.012345, 26)),
777    ///             shape: vec![1, 84, 8400],
778    ///             dshape: Vec::new(),
779    ///             normalized: Some(true),
780    ///         },
781    ///         None,
782    ///     )
783    ///     .build()?;
784    ///
785    /// # Ok(())
786    /// # }
787    /// ```
788    pub fn with_config_yolo_det(
789        mut self,
790        boxes: configs::Detection,
791        version: Option<DecoderVersion>,
792    ) -> Self {
793        let config = ConfigOutputs {
794            outputs: vec![ConfigOutput::Detection(boxes)],
795            decoder_version: version,
796            ..Default::default()
797        };
798        self.config_src.replace(ConfigSource::Config(config));
799        self
800    }
801
802    /// Loads a YOLO split detection model configuration.  Use
803    /// `DecoderBuilder.build()` to parse the model configuration.
804    ///
805    /// # Examples
806    /// ```rust
807    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
808    /// # fn main() -> DecoderResult<()> {
809    /// let boxes_config = configs::Boxes {
810    ///     decoder: configs::DecoderType::Ultralytics,
811    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
812    ///     shape: vec![1, 4, 8400],
813    ///     dshape: Vec::new(),
814    ///     normalized: Some(true),
815    /// };
816    /// let scores_config = configs::Scores {
817    ///     decoder: configs::DecoderType::Ultralytics,
818    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
819    ///     shape: vec![1, 80, 8400],
820    ///     dshape: Vec::new(),
821    /// };
822    /// let decoder = DecoderBuilder::new()
823    ///     .with_config_yolo_split_det(boxes_config, scores_config)
824    ///     .build()?;
825    /// # Ok(())
826    /// # }
827    /// ```
828    pub fn with_config_yolo_split_det(
829        mut self,
830        boxes: configs::Boxes,
831        scores: configs::Scores,
832    ) -> Self {
833        let config = ConfigOutputs {
834            outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
835            ..Default::default()
836        };
837        self.config_src.replace(ConfigSource::Config(config));
838        self
839    }
840
841    /// Loads a YOLO segmentation model configuration.  Use
842    /// `DecoderBuilder.build()` to parse the model configuration.
843    ///
844    /// # Examples
845    /// ```rust
846    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
847    /// # fn main() -> DecoderResult<()> {
848    /// let seg_config = configs::Detection {
849    ///     decoder: configs::DecoderType::Ultralytics,
850    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
851    ///     shape: vec![1, 116, 8400],
852    ///     anchors: None,
853    ///     dshape: Vec::new(),
854    ///     normalized: Some(true),
855    /// };
856    /// let protos_config = configs::Protos {
857    ///     decoder: configs::DecoderType::Ultralytics,
858    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
859    ///     shape: vec![1, 160, 160, 32],
860    ///     dshape: Vec::new(),
861    /// };
862    /// let decoder = DecoderBuilder::new()
863    ///     .with_config_yolo_segdet(
864    ///         seg_config,
865    ///         protos_config,
866    ///         Some(configs::DecoderVersion::Yolov8),
867    ///     )
868    ///     .build()?;
869    /// # Ok(())
870    /// # }
871    /// ```
872    pub fn with_config_yolo_segdet(
873        mut self,
874        boxes: configs::Detection,
875        protos: configs::Protos,
876        version: Option<DecoderVersion>,
877    ) -> Self {
878        let config = ConfigOutputs {
879            outputs: vec![ConfigOutput::Detection(boxes), ConfigOutput::Protos(protos)],
880            decoder_version: version,
881            ..Default::default()
882        };
883        self.config_src.replace(ConfigSource::Config(config));
884        self
885    }
886
887    /// Loads a YOLO split segmentation model configuration.  Use
888    /// `DecoderBuilder.build()` to parse the model configuration.
889    ///
890    /// # Examples
891    /// ```rust
892    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
893    /// # fn main() -> DecoderResult<()> {
894    /// let boxes_config = configs::Boxes {
895    ///     decoder: configs::DecoderType::Ultralytics,
896    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
897    ///     shape: vec![1, 4, 8400],
898    ///     dshape: Vec::new(),
899    ///     normalized: Some(true),
900    /// };
901    /// let scores_config = configs::Scores {
902    ///     decoder: configs::DecoderType::Ultralytics,
903    ///     quantization: Some(configs::QuantTuple(0.012345, 14)),
904    ///     shape: vec![1, 80, 8400],
905    ///     dshape: Vec::new(),
906    /// };
907    /// let mask_config = configs::MaskCoefficients {
908    ///     decoder: configs::DecoderType::Ultralytics,
909    ///     quantization: Some(configs::QuantTuple(0.0064123, 125)),
910    ///     shape: vec![1, 32, 8400],
911    ///     dshape: Vec::new(),
912    /// };
913    /// let protos_config = configs::Protos {
914    ///     decoder: configs::DecoderType::Ultralytics,
915    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
916    ///     shape: vec![1, 160, 160, 32],
917    ///     dshape: Vec::new(),
918    /// };
919    /// let decoder = DecoderBuilder::new()
920    ///     .with_config_yolo_split_segdet(boxes_config, scores_config, mask_config, protos_config)
921    ///     .build()?;
922    /// # Ok(())
923    /// # }
924    /// ```
925    pub fn with_config_yolo_split_segdet(
926        mut self,
927        boxes: configs::Boxes,
928        scores: configs::Scores,
929        mask_coefficients: configs::MaskCoefficients,
930        protos: configs::Protos,
931    ) -> Self {
932        let config = ConfigOutputs {
933            outputs: vec![
934                ConfigOutput::Boxes(boxes),
935                ConfigOutput::Scores(scores),
936                ConfigOutput::MaskCoefficients(mask_coefficients),
937                ConfigOutput::Protos(protos),
938            ],
939            ..Default::default()
940        };
941        self.config_src.replace(ConfigSource::Config(config));
942        self
943    }
944
945    /// Loads a ModelPack detection model configuration.  Use
946    /// `DecoderBuilder.build()` to parse the model configuration.
947    ///
948    /// # Examples
949    /// ```rust
950    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
951    /// # fn main() -> DecoderResult<()> {
952    /// let boxes_config = configs::Boxes {
953    ///     decoder: configs::DecoderType::ModelPack,
954    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
955    ///     shape: vec![1, 8400, 1, 4],
956    ///     dshape: Vec::new(),
957    ///     normalized: Some(true),
958    /// };
959    /// let scores_config = configs::Scores {
960    ///     decoder: configs::DecoderType::ModelPack,
961    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
962    ///     shape: vec![1, 8400, 3],
963    ///     dshape: Vec::new(),
964    /// };
965    /// let decoder = DecoderBuilder::new()
966    ///     .with_config_modelpack_det(boxes_config, scores_config)
967    ///     .build()?;
968    /// # Ok(())
969    /// # }
970    /// ```
971    pub fn with_config_modelpack_det(
972        mut self,
973        boxes: configs::Boxes,
974        scores: configs::Scores,
975    ) -> Self {
976        let config = ConfigOutputs {
977            outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
978            ..Default::default()
979        };
980        self.config_src.replace(ConfigSource::Config(config));
981        self
982    }
983
984    /// Loads a ModelPack split detection model configuration. Use
985    /// `DecoderBuilder.build()` to parse the model configuration.
986    ///
987    /// # Examples
988    /// ```rust
989    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
990    /// # fn main() -> DecoderResult<()> {
991    /// let config0 = configs::Detection {
992    ///     anchors: Some(vec![
993    ///         [0.13750000298023224, 0.2074074000120163],
994    ///         [0.2541666626930237, 0.21481481194496155],
995    ///         [0.23125000298023224, 0.35185185074806213],
996    ///     ]),
997    ///     decoder: configs::DecoderType::ModelPack,
998    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
999    ///     shape: vec![1, 17, 30, 18],
1000    ///     dshape: Vec::new(),
1001    ///     normalized: Some(true),
1002    /// };
1003    /// let config1 = configs::Detection {
1004    ///     anchors: Some(vec![
1005    ///         [0.36666667461395264, 0.31481480598449707],
1006    ///         [0.38749998807907104, 0.4740740656852722],
1007    ///         [0.5333333611488342, 0.644444465637207],
1008    ///     ]),
1009    ///     decoder: configs::DecoderType::ModelPack,
1010    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
1011    ///     shape: vec![1, 9, 15, 18],
1012    ///     dshape: Vec::new(),
1013    ///     normalized: Some(true),
1014    /// };
1015    ///
1016    /// let decoder = DecoderBuilder::new()
1017    ///     .with_config_modelpack_det_split(vec![config0, config1])
1018    ///     .build()?;
1019    /// # Ok(())
1020    /// # }
1021    /// ```
1022    pub fn with_config_modelpack_det_split(mut self, boxes: Vec<configs::Detection>) -> Self {
1023        let outputs = boxes.into_iter().map(ConfigOutput::Detection).collect();
1024        let config = ConfigOutputs {
1025            outputs,
1026            ..Default::default()
1027        };
1028        self.config_src.replace(ConfigSource::Config(config));
1029        self
1030    }
1031
1032    /// Loads a ModelPack segmentation detection model configuration. Use
1033    /// `DecoderBuilder.build()` to parse the model configuration.
1034    ///
1035    /// # Examples
1036    /// ```rust
1037    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
1038    /// # fn main() -> DecoderResult<()> {
1039    /// let boxes_config = configs::Boxes {
1040    ///     decoder: configs::DecoderType::ModelPack,
1041    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
1042    ///     shape: vec![1, 8400, 1, 4],
1043    ///     dshape: Vec::new(),
1044    ///     normalized: Some(true),
1045    /// };
1046    /// let scores_config = configs::Scores {
1047    ///     decoder: configs::DecoderType::ModelPack,
1048    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
1049    ///     shape: vec![1, 8400, 2],
1050    ///     dshape: Vec::new(),
1051    /// };
1052    /// let seg_config = configs::Segmentation {
1053    ///     decoder: configs::DecoderType::ModelPack,
1054    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
1055    ///     shape: vec![1, 640, 640, 3],
1056    ///     dshape: Vec::new(),
1057    /// };
1058    /// let decoder = DecoderBuilder::new()
1059    ///     .with_config_modelpack_segdet(boxes_config, scores_config, seg_config)
1060    ///     .build()?;
1061    /// # Ok(())
1062    /// # }
1063    /// ```
1064    pub fn with_config_modelpack_segdet(
1065        mut self,
1066        boxes: configs::Boxes,
1067        scores: configs::Scores,
1068        segmentation: configs::Segmentation,
1069    ) -> Self {
1070        let config = ConfigOutputs {
1071            outputs: vec![
1072                ConfigOutput::Boxes(boxes),
1073                ConfigOutput::Scores(scores),
1074                ConfigOutput::Segmentation(segmentation),
1075            ],
1076            ..Default::default()
1077        };
1078        self.config_src.replace(ConfigSource::Config(config));
1079        self
1080    }
1081
1082    /// Loads a ModelPack segmentation split detection model configuration. Use
1083    /// `DecoderBuilder.build()` to parse the model configuration.
1084    ///
1085    /// # Examples
1086    /// ```rust
1087    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
1088    /// # fn main() -> DecoderResult<()> {
1089    /// let config0 = configs::Detection {
1090    ///     anchors: Some(vec![
1091    ///         [0.36666667461395264, 0.31481480598449707],
1092    ///         [0.38749998807907104, 0.4740740656852722],
1093    ///         [0.5333333611488342, 0.644444465637207],
1094    ///     ]),
1095    ///     decoder: configs::DecoderType::ModelPack,
1096    ///     quantization: Some(configs::QuantTuple(0.08547406643629074, 174)),
1097    ///     shape: vec![1, 9, 15, 18],
1098    ///     dshape: Vec::new(),
1099    ///     normalized: Some(true),
1100    /// };
1101    /// let config1 = configs::Detection {
1102    ///     anchors: Some(vec![
1103    ///         [0.13750000298023224, 0.2074074000120163],
1104    ///         [0.2541666626930237, 0.21481481194496155],
1105    ///         [0.23125000298023224, 0.35185185074806213],
1106    ///     ]),
1107    ///     decoder: configs::DecoderType::ModelPack,
1108    ///     quantization: Some(configs::QuantTuple(0.09929127991199493, 183)),
1109    ///     shape: vec![1, 17, 30, 18],
1110    ///     dshape: Vec::new(),
1111    ///     normalized: Some(true),
1112    /// };
1113    /// let seg_config = configs::Segmentation {
1114    ///     decoder: configs::DecoderType::ModelPack,
1115    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
1116    ///     shape: vec![1, 640, 640, 2],
1117    ///     dshape: Vec::new(),
1118    /// };
1119    /// let decoder = DecoderBuilder::new()
1120    ///     .with_config_modelpack_segdet_split(vec![config0, config1], seg_config)
1121    ///     .build()?;
1122    /// # Ok(())
1123    /// # }
1124    /// ```
1125    pub fn with_config_modelpack_segdet_split(
1126        mut self,
1127        boxes: Vec<configs::Detection>,
1128        segmentation: configs::Segmentation,
1129    ) -> Self {
1130        let mut outputs = boxes
1131            .into_iter()
1132            .map(ConfigOutput::Detection)
1133            .collect::<Vec<_>>();
1134        outputs.push(ConfigOutput::Segmentation(segmentation));
1135        let config = ConfigOutputs {
1136            outputs,
1137            ..Default::default()
1138        };
1139        self.config_src.replace(ConfigSource::Config(config));
1140        self
1141    }
1142
1143    /// Loads a ModelPack segmentation model configuration. Use
1144    /// `DecoderBuilder.build()` to parse the model configuration.
1145    ///
1146    /// # Examples
1147    /// ```rust
1148    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
1149    /// # fn main() -> DecoderResult<()> {
1150    /// let seg_config = configs::Segmentation {
1151    ///     decoder: configs::DecoderType::ModelPack,
1152    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
1153    ///     shape: vec![1, 640, 640, 3],
1154    ///     dshape: Vec::new(),
1155    /// };
1156    /// let decoder = DecoderBuilder::new()
1157    ///     .with_config_modelpack_seg(seg_config)
1158    ///     .build()?;
1159    /// # Ok(())
1160    /// # }
1161    /// ```
1162    pub fn with_config_modelpack_seg(mut self, segmentation: configs::Segmentation) -> Self {
1163        let config = ConfigOutputs {
1164            outputs: vec![ConfigOutput::Segmentation(segmentation)],
1165            ..Default::default()
1166        };
1167        self.config_src.replace(ConfigSource::Config(config));
1168        self
1169    }
1170
    /// Sets the score threshold of the decoder.
1172    ///
1173    /// # Examples
1174    /// ```rust
1175    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
1176    /// # fn main() -> DecoderResult<()> {
1177    /// # let config_json = include_str!("../../../testdata/modelpack_split.json").to_string();
1178    /// let decoder = DecoderBuilder::new()
1179    ///     .with_config_json_str(config_json)
1180    ///     .with_score_threshold(0.654)
1181    ///     .build()?;
1182    /// assert_eq!(decoder.score_threshold, 0.654);
1183    /// # Ok(())
1184    /// # }
1185    /// ```
1186    pub fn with_score_threshold(mut self, score_threshold: f32) -> Self {
1187        self.score_threshold = score_threshold;
1188        self
1189    }
1190
1191    /// Sets the IOU threshold of the decoder. Has no effect when NMS is set to
1192    /// `None`
1193    ///
1194    /// # Examples
1195    /// ```rust
1196    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
1197    /// # fn main() -> DecoderResult<()> {
1198    /// # let config_json = include_str!("../../../testdata/modelpack_split.json").to_string();
1199    /// let decoder = DecoderBuilder::new()
1200    ///     .with_config_json_str(config_json)
1201    ///     .with_iou_threshold(0.654)
1202    ///     .build()?;
1203    /// assert_eq!(decoder.iou_threshold, 0.654);
1204    /// # Ok(())
1205    /// # }
1206    /// ```
1207    pub fn with_iou_threshold(mut self, iou_threshold: f32) -> Self {
1208        self.iou_threshold = iou_threshold;
1209        self
1210    }
1211
1212    /// Sets the NMS mode for the decoder.
1213    ///
1214    /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS (default): suppress
1215    ///   overlapping boxes regardless of class label
1216    /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
1217    ///   share the same class label AND overlap above the IoU threshold
1218    /// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
1219    ///
1220    /// # Examples
1221    /// ```rust
1222    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::Nms};
1223    /// # fn main() -> DecoderResult<()> {
1224    /// # let config_json = include_str!("../../../testdata/modelpack_split.json").to_string();
1225    /// let decoder = DecoderBuilder::new()
1226    ///     .with_config_json_str(config_json)
1227    ///     .with_nms(Some(Nms::ClassAware))
1228    ///     .build()?;
1229    /// assert_eq!(decoder.nms, Some(Nms::ClassAware));
1230    /// # Ok(())
1231    /// # }
1232    /// ```
1233    pub fn with_nms(mut self, nms: Option<configs::Nms>) -> Self {
1234        self.nms = nms;
1235        self
1236    }
1237
1238    /// Builds the decoder with the given settings. If the config is a JSON or
1239    /// YAML string, this will deserialize the JSON or YAML and then parse the
1240    /// configuration information.
1241    ///
1242    /// # Examples
1243    /// ```rust
1244    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
1245    /// # fn main() -> DecoderResult<()> {
1246    /// # let config_json = include_str!("../../../testdata/modelpack_split.json").to_string();
1247    /// let decoder = DecoderBuilder::new()
1248    ///     .with_config_json_str(config_json)
1249    ///     .with_score_threshold(0.654)
1250    ///     .build()?;
1251    /// # Ok(())
1252    /// # }
1253    /// ```
1254    pub fn build(self) -> Result<Decoder, DecoderError> {
1255        let config = match self.config_src {
1256            Some(ConfigSource::Json(s)) => serde_json::from_str(&s)?,
1257            Some(ConfigSource::Yaml(s)) => serde_yaml::from_str(&s)?,
1258            Some(ConfigSource::Config(c)) => c,
1259            None => return Err(DecoderError::NoConfig),
1260        };
1261
1262        // Extract normalized flag from config outputs
1263        let normalized = Self::get_normalized(&config.outputs);
1264
1265        // Use NMS from config if present, otherwise use builder's NMS setting
1266        let nms = config.nms.or(self.nms);
1267        let model_type = Self::get_model_type(config)?;
1268
1269        Ok(Decoder {
1270            model_type,
1271            iou_threshold: self.iou_threshold,
1272            score_threshold: self.score_threshold,
1273            nms,
1274            normalized,
1275        })
1276    }
1277
1278    /// Extracts the normalized flag from config outputs.
1279    /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
1280    /// - `Some(false)`: Boxes are in pixel coordinates
1281    /// - `None`: Unknown (not specified in config), caller must infer
1282    fn get_normalized(outputs: &[ConfigOutput]) -> Option<bool> {
1283        for output in outputs {
1284            match output {
1285                ConfigOutput::Detection(det) => return det.normalized,
1286                ConfigOutput::Boxes(boxes) => return boxes.normalized,
1287                _ => {}
1288            }
1289        }
1290        None // not specified
1291    }
1292
1293    fn get_model_type(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1294        // yolo or modelpack
1295        let mut yolo = false;
1296        let mut modelpack = false;
1297        for c in &configs.outputs {
1298            match c.decoder() {
1299                DecoderType::ModelPack => modelpack = true,
1300                DecoderType::Ultralytics => yolo = true,
1301            }
1302        }
1303        match (modelpack, yolo) {
1304            (true, true) => Err(DecoderError::InvalidConfig(
1305                "Both ModelPack and Yolo outputs found in config".to_string(),
1306            )),
1307            (true, false) => Self::get_model_type_modelpack(configs),
1308            (false, true) => Self::get_model_type_yolo(configs),
1309            (false, false) => Err(DecoderError::InvalidConfig(
1310                "No outputs found in config".to_string(),
1311            )),
1312        }
1313    }
1314
1315    fn get_model_type_yolo(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1316        let mut boxes = None;
1317        let mut protos = None;
1318        let mut split_boxes = None;
1319        let mut split_scores = None;
1320        let mut split_mask_coeff = None;
1321        for c in configs.outputs {
1322            match c {
1323                ConfigOutput::Detection(detection) => boxes = Some(detection),
1324                ConfigOutput::Segmentation(_) => {
1325                    return Err(DecoderError::InvalidConfig(
1326                        "Invalid Segmentation output with Yolo decoder".to_string(),
1327                    ));
1328                }
1329                ConfigOutput::Protos(protos_) => protos = Some(protos_),
1330                ConfigOutput::Mask(_) => {
1331                    return Err(DecoderError::InvalidConfig(
1332                        "Invalid Mask output with Yolo decoder".to_string(),
1333                    ));
1334                }
1335                ConfigOutput::Scores(scores) => split_scores = Some(scores),
1336                ConfigOutput::Boxes(boxes) => split_boxes = Some(boxes),
1337                ConfigOutput::MaskCoefficients(mask_coeff) => split_mask_coeff = Some(mask_coeff),
1338            }
1339        }
1340
1341        // Use end-to-end model types when:
1342        // 1. decoder_version is explicitly set to Yolo26 (definitive), OR
1343        //    decoder_version is not set but the dshapes are (batch, num_boxes,
1344        //    num_features)
1345        let is_end_to_end_dshape = boxes.as_ref().is_some_and(|b| {
1346            let dims = b.dshape.iter().map(|(d, _)| *d).collect::<Vec<_>>();
1347            dims == vec![DimName::Batch, DimName::NumBoxes, DimName::NumFeatures]
1348        });
1349
1350        let is_end_to_end = configs
1351            .decoder_version
1352            .map(|v| v.is_end_to_end())
1353            .unwrap_or(is_end_to_end_dshape);
1354
1355        if is_end_to_end {
1356            if let Some(boxes) = boxes {
1357                if let Some(protos) = protos {
1358                    Self::verify_yolo_seg_det_26(&boxes, &protos)?;
1359                    return Ok(ModelType::YoloEndToEndSegDet { boxes, protos });
1360                } else {
1361                    Self::verify_yolo_det_26(&boxes)?;
1362                    return Ok(ModelType::YoloEndToEndDet { boxes });
1363                }
1364            } else {
1365                return Err(DecoderError::InvalidConfig(
1366                    "Invalid Yolo end-to-end model outputs".to_string(),
1367                ));
1368            }
1369        }
1370
1371        if let Some(boxes) = boxes {
1372            if let Some(protos) = protos {
1373                Self::verify_yolo_seg_det(&boxes, &protos)?;
1374                Ok(ModelType::YoloSegDet { boxes, protos })
1375            } else {
1376                Self::verify_yolo_det(&boxes)?;
1377                Ok(ModelType::YoloDet { boxes })
1378            }
1379        } else if let (Some(boxes), Some(scores)) = (split_boxes, split_scores) {
1380            if let (Some(mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
1381                Self::verify_yolo_split_segdet(&boxes, &scores, &mask_coeff, &protos)?;
1382                Ok(ModelType::YoloSplitSegDet {
1383                    boxes,
1384                    scores,
1385                    mask_coeff,
1386                    protos,
1387                })
1388            } else {
1389                Self::verify_yolo_split_det(&boxes, &scores)?;
1390                Ok(ModelType::YoloSplitDet { boxes, scores })
1391            }
1392        } else {
1393            Err(DecoderError::InvalidConfig(
1394                "Invalid Yolo model outputs".to_string(),
1395            ))
1396        }
1397    }
1398
1399    fn verify_yolo_det(detect: &configs::Detection) -> Result<(), DecoderError> {
1400        if detect.shape.len() != 3 {
1401            return Err(DecoderError::InvalidConfig(format!(
1402                "Invalid Yolo Detection shape {:?}",
1403                detect.shape
1404            )));
1405        }
1406
1407        Self::verify_dshapes(
1408            &detect.dshape,
1409            &detect.shape,
1410            "Detection",
1411            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1412        )?;
1413        if !detect.dshape.is_empty() {
1414            Self::get_class_count(&detect.dshape, None, None)?;
1415        } else {
1416            Self::get_class_count_no_dshape(detect.into(), None)?;
1417        }
1418
1419        Ok(())
1420    }
1421
1422    fn verify_yolo_det_26(detect: &configs::Detection) -> Result<(), DecoderError> {
1423        if detect.shape.len() != 3 {
1424            return Err(DecoderError::InvalidConfig(format!(
1425                "Invalid Yolo Detection shape {:?}",
1426                detect.shape
1427            )));
1428        }
1429
1430        Self::verify_dshapes(
1431            &detect.dshape,
1432            &detect.shape,
1433            "Detection",
1434            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1435        )?;
1436
1437        if !detect.shape.contains(&6) {
1438            return Err(DecoderError::InvalidConfig(
1439                "Yolo26 Detection must have 6 features".to_string(),
1440            ));
1441        }
1442
1443        Ok(())
1444    }
1445
1446    fn verify_yolo_seg_det(
1447        detection: &configs::Detection,
1448        protos: &configs::Protos,
1449    ) -> Result<(), DecoderError> {
1450        if detection.shape.len() != 3 {
1451            return Err(DecoderError::InvalidConfig(format!(
1452                "Invalid Yolo Detection shape {:?}",
1453                detection.shape
1454            )));
1455        }
1456        if protos.shape.len() != 4 {
1457            return Err(DecoderError::InvalidConfig(format!(
1458                "Invalid Yolo Protos shape {:?}",
1459                protos.shape
1460            )));
1461        }
1462
1463        Self::verify_dshapes(
1464            &detection.dshape,
1465            &detection.shape,
1466            "Detection",
1467            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1468        )?;
1469        Self::verify_dshapes(
1470            &protos.dshape,
1471            &protos.shape,
1472            "Protos",
1473            &[
1474                DimName::Batch,
1475                DimName::Height,
1476                DimName::Width,
1477                DimName::NumProtos,
1478            ],
1479        )?;
1480
1481        let protos_count = Self::get_protos_count(&protos.dshape).unwrap_or(protos.shape[3]);
1482        log::debug!("Protos count: {}", protos_count);
1483        log::debug!("Detection dshape: {:?}", detection.dshape);
1484        let classes = if !detection.dshape.is_empty() {
1485            Self::get_class_count(&detection.dshape, Some(protos_count), None)?
1486        } else {
1487            Self::get_class_count_no_dshape(detection.into(), Some(protos_count))?
1488        };
1489
1490        if classes == 0 {
1491            return Err(DecoderError::InvalidConfig(
1492                "Yolo Segmentation Detection has zero classes".to_string(),
1493            ));
1494        }
1495
1496        Ok(())
1497    }
1498
1499    fn verify_yolo_seg_det_26(
1500        detection: &configs::Detection,
1501        protos: &configs::Protos,
1502    ) -> Result<(), DecoderError> {
1503        if detection.shape.len() != 3 {
1504            return Err(DecoderError::InvalidConfig(format!(
1505                "Invalid Yolo Detection shape {:?}",
1506                detection.shape
1507            )));
1508        }
1509        if protos.shape.len() != 4 {
1510            return Err(DecoderError::InvalidConfig(format!(
1511                "Invalid Yolo Protos shape {:?}",
1512                protos.shape
1513            )));
1514        }
1515
1516        Self::verify_dshapes(
1517            &detection.dshape,
1518            &detection.shape,
1519            "Detection",
1520            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1521        )?;
1522        Self::verify_dshapes(
1523            &protos.dshape,
1524            &protos.shape,
1525            "Protos",
1526            &[
1527                DimName::Batch,
1528                DimName::Height,
1529                DimName::Width,
1530                DimName::NumProtos,
1531            ],
1532        )?;
1533
1534        let protos_count = Self::get_protos_count(&protos.dshape).unwrap_or(protos.shape[3]);
1535        log::debug!("Protos count: {}", protos_count);
1536        log::debug!("Detection dshape: {:?}", detection.dshape);
1537
1538        if !detection.shape.contains(&(6 + protos_count)) {
1539            return Err(DecoderError::InvalidConfig(format!(
1540                "Yolo26 Segmentation Detection must have num_features be 6 + num_protos = {}",
1541                6 + protos_count
1542            )));
1543        }
1544
1545        Ok(())
1546    }
1547
1548    fn verify_yolo_split_det(
1549        boxes: &configs::Boxes,
1550        scores: &configs::Scores,
1551    ) -> Result<(), DecoderError> {
1552        if boxes.shape.len() != 3 {
1553            return Err(DecoderError::InvalidConfig(format!(
1554                "Invalid Yolo Split Boxes shape {:?}",
1555                boxes.shape
1556            )));
1557        }
1558        if scores.shape.len() != 3 {
1559            return Err(DecoderError::InvalidConfig(format!(
1560                "Invalid Yolo Split Scores shape {:?}",
1561                scores.shape
1562            )));
1563        }
1564
1565        Self::verify_dshapes(
1566            &boxes.dshape,
1567            &boxes.shape,
1568            "Boxes",
1569            &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1570        )?;
1571        Self::verify_dshapes(
1572            &scores.dshape,
1573            &scores.shape,
1574            "Scores",
1575            &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1576        )?;
1577
1578        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1579        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1580
1581        if boxes_num != scores_num {
1582            return Err(DecoderError::InvalidConfig(format!(
1583                "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1584                boxes_num, scores_num
1585            )));
1586        }
1587
1588        Ok(())
1589    }
1590
1591    fn verify_yolo_split_segdet(
1592        boxes: &configs::Boxes,
1593        scores: &configs::Scores,
1594        mask_coeff: &configs::MaskCoefficients,
1595        protos: &configs::Protos,
1596    ) -> Result<(), DecoderError> {
1597        if boxes.shape.len() != 3 {
1598            return Err(DecoderError::InvalidConfig(format!(
1599                "Invalid Yolo Split Boxes shape {:?}",
1600                boxes.shape
1601            )));
1602        }
1603        if scores.shape.len() != 3 {
1604            return Err(DecoderError::InvalidConfig(format!(
1605                "Invalid Yolo Split Scores shape {:?}",
1606                scores.shape
1607            )));
1608        }
1609
1610        if mask_coeff.shape.len() != 3 {
1611            return Err(DecoderError::InvalidConfig(format!(
1612                "Invalid Yolo Split Mask Coefficients shape {:?}",
1613                mask_coeff.shape
1614            )));
1615        }
1616
1617        if protos.shape.len() != 4 {
1618            return Err(DecoderError::InvalidConfig(format!(
1619                "Invalid Yolo Protos shape {:?}",
1620                mask_coeff.shape
1621            )));
1622        }
1623
1624        Self::verify_dshapes(
1625            &boxes.dshape,
1626            &boxes.shape,
1627            "Boxes",
1628            &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1629        )?;
1630        Self::verify_dshapes(
1631            &scores.dshape,
1632            &scores.shape,
1633            "Scores",
1634            &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1635        )?;
1636        Self::verify_dshapes(
1637            &mask_coeff.dshape,
1638            &mask_coeff.shape,
1639            "Mask Coefficients",
1640            &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1641        )?;
1642        Self::verify_dshapes(
1643            &protos.dshape,
1644            &protos.shape,
1645            "Protos",
1646            &[
1647                DimName::Batch,
1648                DimName::Height,
1649                DimName::Width,
1650                DimName::NumProtos,
1651            ],
1652        )?;
1653
1654        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1655        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1656        let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1657
1658        let mask_channels = if !mask_coeff.dshape.is_empty() {
1659            Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1660                DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1661            })?
1662        } else {
1663            mask_coeff.shape[1]
1664        };
1665        let proto_channels = if !protos.dshape.is_empty() {
1666            Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1667                DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1668            })?
1669        } else {
1670            protos.shape[3]
1671        };
1672
1673        if boxes_num != scores_num {
1674            return Err(DecoderError::InvalidConfig(format!(
1675                "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1676                boxes_num, scores_num
1677            )));
1678        }
1679
1680        if boxes_num != mask_num {
1681            return Err(DecoderError::InvalidConfig(format!(
1682                "Yolo Split Detection Boxes num {} incompatible with Mask Coefficients num {}",
1683                boxes_num, mask_num
1684            )));
1685        }
1686
1687        if proto_channels != mask_channels {
1688            return Err(DecoderError::InvalidConfig(format!(
1689                "Yolo Protos channels {} incompatible with Mask Coefficients channels {}",
1690                proto_channels, mask_channels
1691            )));
1692        }
1693
1694        Ok(())
1695    }
1696
1697    fn get_model_type_modelpack(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1698        let mut split_decoders = Vec::new();
1699        let mut segment_ = None;
1700        let mut scores_ = None;
1701        let mut boxes_ = None;
1702        for c in configs.outputs {
1703            match c {
1704                ConfigOutput::Detection(detection) => split_decoders.push(detection),
1705                ConfigOutput::Segmentation(segmentation) => segment_ = Some(segmentation),
1706                ConfigOutput::Mask(_) => {}
1707                ConfigOutput::Protos(_) => {
1708                    return Err(DecoderError::InvalidConfig(
1709                        "ModelPack should not have protos".to_string(),
1710                    ));
1711                }
1712                ConfigOutput::Scores(scores) => scores_ = Some(scores),
1713                ConfigOutput::Boxes(boxes) => boxes_ = Some(boxes),
1714                ConfigOutput::MaskCoefficients(_) => {
1715                    return Err(DecoderError::InvalidConfig(
1716                        "ModelPack should not have mask coefficients".to_string(),
1717                    ));
1718                }
1719            }
1720        }
1721
1722        if let Some(segmentation) = segment_ {
1723            if !split_decoders.is_empty() {
1724                let classes = Self::verify_modelpack_split_det(&split_decoders)?;
1725                Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1726                Ok(ModelType::ModelPackSegDetSplit {
1727                    detection: split_decoders,
1728                    segmentation,
1729                })
1730            } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1731                let classes = Self::verify_modelpack_det(&boxes, &scores)?;
1732                Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1733                Ok(ModelType::ModelPackSegDet {
1734                    boxes,
1735                    scores,
1736                    segmentation,
1737                })
1738            } else {
1739                Self::verify_modelpack_seg(&segmentation, None)?;
1740                Ok(ModelType::ModelPackSeg { segmentation })
1741            }
1742        } else if !split_decoders.is_empty() {
1743            Self::verify_modelpack_split_det(&split_decoders)?;
1744            Ok(ModelType::ModelPackDetSplit {
1745                detection: split_decoders,
1746            })
1747        } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1748            Self::verify_modelpack_det(&boxes, &scores)?;
1749            Ok(ModelType::ModelPackDet { boxes, scores })
1750        } else {
1751            Err(DecoderError::InvalidConfig(
1752                "Invalid ModelPack model outputs".to_string(),
1753            ))
1754        }
1755    }
1756
1757    fn verify_modelpack_det(
1758        boxes: &configs::Boxes,
1759        scores: &configs::Scores,
1760    ) -> Result<usize, DecoderError> {
1761        if boxes.shape.len() != 4 {
1762            return Err(DecoderError::InvalidConfig(format!(
1763                "Invalid ModelPack Boxes shape {:?}",
1764                boxes.shape
1765            )));
1766        }
1767        if scores.shape.len() != 3 {
1768            return Err(DecoderError::InvalidConfig(format!(
1769                "Invalid ModelPack Scores shape {:?}",
1770                scores.shape
1771            )));
1772        }
1773
1774        Self::verify_dshapes(
1775            &boxes.dshape,
1776            &boxes.shape,
1777            "Boxes",
1778            &[
1779                DimName::Batch,
1780                DimName::NumBoxes,
1781                DimName::Padding,
1782                DimName::BoxCoords,
1783            ],
1784        )?;
1785        Self::verify_dshapes(
1786            &scores.dshape,
1787            &scores.shape,
1788            "Scores",
1789            &[DimName::Batch, DimName::NumBoxes, DimName::NumClasses],
1790        )?;
1791
1792        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[1]);
1793        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[1]);
1794
1795        if boxes_num != scores_num {
1796            return Err(DecoderError::InvalidConfig(format!(
1797                "ModelPack Detection Boxes num {} incompatible with Scores num {}",
1798                boxes_num, scores_num
1799            )));
1800        }
1801
1802        let num_classes = if !scores.dshape.is_empty() {
1803            Self::get_class_count(&scores.dshape, None, None)?
1804        } else {
1805            Self::get_class_count_no_dshape(scores.into(), None)?
1806        };
1807
1808        Ok(num_classes)
1809    }
1810
1811    fn verify_modelpack_split_det(boxes: &[configs::Detection]) -> Result<usize, DecoderError> {
1812        let mut num_classes = None;
1813        for b in boxes {
1814            let Some(num_anchors) = b.anchors.as_ref().map(|a| a.len()) else {
1815                return Err(DecoderError::InvalidConfig(
1816                    "ModelPack Split Detection missing anchors".to_string(),
1817                ));
1818            };
1819
1820            if num_anchors == 0 {
1821                return Err(DecoderError::InvalidConfig(
1822                    "ModelPack Split Detection has zero anchors".to_string(),
1823                ));
1824            }
1825
1826            if b.shape.len() != 4 {
1827                return Err(DecoderError::InvalidConfig(format!(
1828                    "Invalid ModelPack Split Detection shape {:?}",
1829                    b.shape
1830                )));
1831            }
1832
1833            Self::verify_dshapes(
1834                &b.dshape,
1835                &b.shape,
1836                "Split Detection",
1837                &[
1838                    DimName::Batch,
1839                    DimName::Height,
1840                    DimName::Width,
1841                    DimName::NumAnchorsXFeatures,
1842                ],
1843            )?;
1844            let classes = if !b.dshape.is_empty() {
1845                Self::get_class_count(&b.dshape, None, Some(num_anchors))?
1846            } else {
1847                Self::get_class_count_no_dshape(b.into(), None)?
1848            };
1849
1850            match num_classes {
1851                Some(n) => {
1852                    if n != classes {
1853                        return Err(DecoderError::InvalidConfig(format!(
1854                            "ModelPack Split Detection inconsistent number of classes: previous {}, current {}",
1855                            n, classes
1856                        )));
1857                    }
1858                }
1859                None => {
1860                    num_classes = Some(classes);
1861                }
1862            }
1863        }
1864
1865        Ok(num_classes.unwrap_or(0))
1866    }
1867
1868    fn verify_modelpack_seg(
1869        segmentation: &configs::Segmentation,
1870        classes: Option<usize>,
1871    ) -> Result<(), DecoderError> {
1872        if segmentation.shape.len() != 4 {
1873            return Err(DecoderError::InvalidConfig(format!(
1874                "Invalid ModelPack Segmentation shape {:?}",
1875                segmentation.shape
1876            )));
1877        }
1878        Self::verify_dshapes(
1879            &segmentation.dshape,
1880            &segmentation.shape,
1881            "Segmentation",
1882            &[
1883                DimName::Batch,
1884                DimName::Height,
1885                DimName::Width,
1886                DimName::NumClasses,
1887            ],
1888        )?;
1889
1890        if let Some(classes) = classes {
1891            let seg_classes = if !segmentation.dshape.is_empty() {
1892                Self::get_class_count(&segmentation.dshape, None, None)?
1893            } else {
1894                Self::get_class_count_no_dshape(segmentation.into(), None)?
1895            };
1896
1897            if seg_classes != classes + 1 {
1898                return Err(DecoderError::InvalidConfig(format!(
1899                    "ModelPack Segmentation channels {} incompatible with number of classes {}",
1900                    seg_classes, classes
1901                )));
1902            }
1903        }
1904        Ok(())
1905    }
1906
1907    // verifies that dshapes match the given shape
1908    fn verify_dshapes(
1909        dshape: &[(DimName, usize)],
1910        shape: &[usize],
1911        name: &str,
1912        dims: &[DimName],
1913    ) -> Result<(), DecoderError> {
1914        for s in shape {
1915            if *s == 0 {
1916                return Err(DecoderError::InvalidConfig(format!(
1917                    "{} shape has zero dimension",
1918                    name
1919                )));
1920            }
1921        }
1922
1923        if shape.len() != dims.len() {
1924            return Err(DecoderError::InvalidConfig(format!(
1925                "{} shape length {} does not match expected dims length {}",
1926                name,
1927                shape.len(),
1928                dims.len()
1929            )));
1930        }
1931
1932        if dshape.is_empty() {
1933            return Ok(());
1934        }
1935        // check the dshape lengths match the shape lengths
1936        if dshape.len() != shape.len() {
1937            return Err(DecoderError::InvalidConfig(format!(
1938                "{} dshape length does not match shape length",
1939                name
1940            )));
1941        }
1942
1943        // check the dshape values match the shape values
1944        for ((dim_name, dim_size), shape_size) in dshape.iter().zip(shape) {
1945            if dim_size != shape_size {
1946                return Err(DecoderError::InvalidConfig(format!(
1947                    "{} dshape dimension {} size {} does not match shape size {}",
1948                    name, dim_name, dim_size, shape_size
1949                )));
1950            }
1951            if *dim_name == DimName::Padding && *dim_size != 1 {
1952                return Err(DecoderError::InvalidConfig(
1953                    "Padding dimension size must be 1".to_string(),
1954                ));
1955            }
1956
1957            if *dim_name == DimName::BoxCoords && *dim_size != 4 {
1958                return Err(DecoderError::InvalidConfig(
1959                    "BoxCoords dimension size must be 4".to_string(),
1960                ));
1961            }
1962        }
1963
1964        let dims_present = HashSet::<DimName>::from_iter(dshape.iter().map(|(name, _)| *name));
1965        for dim in dims {
1966            if !dims_present.contains(dim) {
1967                return Err(DecoderError::InvalidConfig(format!(
1968                    "{} dshape missing required dimension {:?}",
1969                    name, dim
1970                )));
1971            }
1972        }
1973
1974        Ok(())
1975    }
1976
1977    fn get_box_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1978        for (dim_name, dim_size) in dshape {
1979            if *dim_name == DimName::NumBoxes {
1980                return Some(*dim_size);
1981            }
1982        }
1983        None
1984    }
1985
1986    fn get_class_count_no_dshape(
1987        config: ConfigOutputRef,
1988        protos: Option<usize>,
1989    ) -> Result<usize, DecoderError> {
1990        match config {
1991            ConfigOutputRef::Detection(detection) => match detection.decoder {
1992                DecoderType::Ultralytics => {
1993                    if detection.shape[1] <= 4 + protos.unwrap_or(0) {
1994                        return Err(DecoderError::InvalidConfig(format!(
1995                            "Invalid shape: Yolo num_features {} must be greater than {}",
1996                            detection.shape[1],
1997                            4 + protos.unwrap_or(0),
1998                        )));
1999                    }
2000                    Ok(detection.shape[1] - 4 - protos.unwrap_or(0))
2001                }
2002                DecoderType::ModelPack => {
2003                    let Some(num_anchors) = detection.anchors.as_ref().map(|a| a.len()) else {
2004                        return Err(DecoderError::Internal(
2005                            "ModelPack Detection missing anchors".to_string(),
2006                        ));
2007                    };
2008                    let anchors_x_features = detection.shape[3];
2009                    if anchors_x_features <= num_anchors * 5 {
2010                        return Err(DecoderError::InvalidConfig(format!(
2011                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2012                            anchors_x_features,
2013                            num_anchors * 5,
2014                        )));
2015                    }
2016
2017                    if !anchors_x_features.is_multiple_of(num_anchors) {
2018                        return Err(DecoderError::InvalidConfig(format!(
2019                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2020                            anchors_x_features, num_anchors
2021                        )));
2022                    }
2023                    Ok(anchors_x_features / num_anchors - 5)
2024                }
2025            },
2026
2027            ConfigOutputRef::Scores(scores) => match scores.decoder {
2028                DecoderType::Ultralytics => Ok(scores.shape[1]),
2029                DecoderType::ModelPack => Ok(scores.shape[2]),
2030            },
2031            ConfigOutputRef::Segmentation(seg) => Ok(seg.shape[3]),
2032            _ => Err(DecoderError::Internal(
2033                "Attempted to get class count from unsupported config output".to_owned(),
2034            )),
2035        }
2036    }
2037
2038    // get the class count from dshape or calculate from num_features
2039    fn get_class_count(
2040        dshape: &[(DimName, usize)],
2041        protos: Option<usize>,
2042        anchors: Option<usize>,
2043    ) -> Result<usize, DecoderError> {
2044        if dshape.is_empty() {
2045            return Ok(0);
2046        }
2047        // if it has num_classes in dshape, return it
2048        for (dim_name, dim_size) in dshape {
2049            if *dim_name == DimName::NumClasses {
2050                return Ok(*dim_size);
2051            }
2052        }
2053
2054        // number of classes can be calculated from num_features - 4 for yolo.  If the
2055        // model has protos, we also subtract the number of protos.
2056        for (dim_name, dim_size) in dshape {
2057            if *dim_name == DimName::NumFeatures {
2058                let protos = protos.unwrap_or(0);
2059                if protos + 4 >= *dim_size {
2060                    return Err(DecoderError::InvalidConfig(format!(
2061                        "Invalid shape: Yolo num_features {} must be greater than {}",
2062                        *dim_size,
2063                        protos + 4,
2064                    )));
2065                }
2066                return Ok(*dim_size - 4 - protos);
2067            }
2068        }
2069
2070        // number of classes can be calculated from number of anchors for modelpack
2071        // split detection
2072        if let Some(num_anchors) = anchors {
2073            for (dim_name, dim_size) in dshape {
2074                if *dim_name == DimName::NumAnchorsXFeatures {
2075                    let anchors_x_features = *dim_size;
2076                    if anchors_x_features <= num_anchors * 5 {
2077                        return Err(DecoderError::InvalidConfig(format!(
2078                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2079                            anchors_x_features,
2080                            num_anchors * 5,
2081                        )));
2082                    }
2083
2084                    if !anchors_x_features.is_multiple_of(num_anchors) {
2085                        return Err(DecoderError::InvalidConfig(format!(
2086                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2087                            anchors_x_features, num_anchors
2088                        )));
2089                    }
2090                    return Ok((anchors_x_features / num_anchors) - 5);
2091                }
2092            }
2093        }
2094        Err(DecoderError::InvalidConfig(
2095            "Cannot determine number of classes from dshape".to_owned(),
2096        ))
2097    }
2098
2099    fn get_protos_count(dshape: &[(DimName, usize)]) -> Option<usize> {
2100        for (dim_name, dim_size) in dshape {
2101            if *dim_name == DimName::NumProtos {
2102                return Some(*dim_size);
2103            }
2104        }
2105        None
2106    }
2107}
2108
/// Decoder for model outputs, holding the parsed model type together with
/// the thresholds and NMS settings used while decoding.
#[derive(Debug, Clone, PartialEq)]
pub struct Decoder {
    /// Parsed model type; exposed read-only through `Decoder::model_type`.
    model_type: ModelType,
    /// IoU threshold. NOTE(review): presumably applied during NMS — confirm
    /// against the decode routines.
    pub iou_threshold: f32,
    /// Score threshold. NOTE(review): presumably the minimum score for a box
    /// to be kept — confirm against the decode routines.
    pub score_threshold: f32,
    /// NMS mode: Some(mode) applies NMS, None bypasses NMS (for end-to-end
    /// models)
    pub nms: Option<configs::Nms>,
    /// Whether decoded boxes are in normalized [0,1] coordinates.
    /// - `Some(true)`: Coordinates in [0,1] range
    /// - `Some(false)`: Pixel coordinates
    /// - `None`: Unknown, caller must infer (e.g., check if any coordinate >
    ///   1.0)
    normalized: Option<bool>,
}
2124
/// A borrowed, dynamically-dimensioned array view over quantized data,
/// tagged with the integer element type it holds.
///
/// Values are accepted by `Decoder::decode_quantized` and can be built from
/// any `ArrayView` of a supported element type through `From`.
#[derive(Debug)]
pub enum ArrayViewDQuantized<'a> {
    /// View over `u8` elements.
    UInt8(ArrayViewD<'a, u8>),
    /// View over `i8` elements.
    Int8(ArrayViewD<'a, i8>),
    /// View over `u16` elements.
    UInt16(ArrayViewD<'a, u16>),
    /// View over `i16` elements.
    Int16(ArrayViewD<'a, i16>),
    /// View over `u32` elements.
    UInt32(ArrayViewD<'a, u32>),
    /// View over `i32` elements.
    Int32(ArrayViewD<'a, i32>),
}
2134
2135impl<'a, D> From<ArrayView<'a, u8, D>> for ArrayViewDQuantized<'a>
2136where
2137    D: Dimension,
2138{
2139    fn from(arr: ArrayView<'a, u8, D>) -> Self {
2140        Self::UInt8(arr.into_dyn())
2141    }
2142}
2143
2144impl<'a, D> From<ArrayView<'a, i8, D>> for ArrayViewDQuantized<'a>
2145where
2146    D: Dimension,
2147{
2148    fn from(arr: ArrayView<'a, i8, D>) -> Self {
2149        Self::Int8(arr.into_dyn())
2150    }
2151}
2152
2153impl<'a, D> From<ArrayView<'a, u16, D>> for ArrayViewDQuantized<'a>
2154where
2155    D: Dimension,
2156{
2157    fn from(arr: ArrayView<'a, u16, D>) -> Self {
2158        Self::UInt16(arr.into_dyn())
2159    }
2160}
2161
2162impl<'a, D> From<ArrayView<'a, i16, D>> for ArrayViewDQuantized<'a>
2163where
2164    D: Dimension,
2165{
2166    fn from(arr: ArrayView<'a, i16, D>) -> Self {
2167        Self::Int16(arr.into_dyn())
2168    }
2169}
2170
2171impl<'a, D> From<ArrayView<'a, u32, D>> for ArrayViewDQuantized<'a>
2172where
2173    D: Dimension,
2174{
2175    fn from(arr: ArrayView<'a, u32, D>) -> Self {
2176        Self::UInt32(arr.into_dyn())
2177    }
2178}
2179
2180impl<'a, D> From<ArrayView<'a, i32, D>> for ArrayViewDQuantized<'a>
2181where
2182    D: Dimension,
2183{
2184    fn from(arr: ArrayView<'a, i32, D>) -> Self {
2185        Self::Int32(arr.into_dyn())
2186    }
2187}
2188
2189impl<'a> ArrayViewDQuantized<'a> {
2190    /// Returns the shape of the underlying array.
2191    ///
2192    /// # Examples
2193    /// ```rust
2194    /// # use edgefirst_decoder::ArrayViewDQuantized;
2195    /// # use ndarray::Array2;
2196    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
2197    /// let arr = Array2::from_shape_vec((2, 3), vec![1u8, 2, 3, 4, 5, 6])?;
2198    /// let view = ArrayViewDQuantized::from(arr.view().into_dyn());
2199    /// assert_eq!(view.shape(), &[2, 3]);
2200    /// # Ok(())
2201    /// # }
2202    /// ```
2203    pub fn shape(&self) -> &[usize] {
2204        match self {
2205            ArrayViewDQuantized::UInt8(a) => a.shape(),
2206            ArrayViewDQuantized::Int8(a) => a.shape(),
2207            ArrayViewDQuantized::UInt16(a) => a.shape(),
2208            ArrayViewDQuantized::Int16(a) => a.shape(),
2209            ArrayViewDQuantized::UInt32(a) => a.shape(),
2210            ArrayViewDQuantized::Int32(a) => a.shape(),
2211        }
2212    }
2213}
2214
// Runs `$body` with `$var` bound to the array view held by the
// `ArrayViewDQuantized` value `$x`, whatever its element type. `$body` is
// expanded once per variant, so it must type-check for every element type.
macro_rules! with_quantized {
    ($x:expr, $var:ident, $body:expr) => {
        with_quantized!(@arms $x, $var, $body, UInt8, Int8, UInt16, Int16, UInt32, Int32)
    };
    // Internal rule: emits one match arm per listed variant.
    (@arms $x:expr, $var:ident, $body:expr, $($variant:ident),+) => {
        match $x {
            $(
                ArrayViewDQuantized::$variant(inner) => {
                    let $var = inner;
                    $body
                }
            )+
        }
    };
}
2245
2246impl Decoder {
/// Returns a reference to the parsed model type of the decoder.
///
/// # Examples
///
/// ```rust
/// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::ModelType};
/// # fn main() -> DecoderResult<()> {
/// #    let config_yaml = include_str!("../../../testdata/modelpack_split.yaml").to_string();
///     let decoder = DecoderBuilder::default()
///         .with_config_yaml_str(config_yaml)
///         .build()?;
///     assert!(matches!(
///         decoder.model_type(),
///         ModelType::ModelPackDetSplit { .. }
///     ));
/// #    Ok(())
/// # }
/// ```
pub fn model_type(&self) -> &ModelType {
    &self.model_type
}
2268
/// Returns the box coordinate format if known from the model config.
///
/// - `Some(true)`: Boxes are in normalized [0,1] coordinates
/// - `Some(false)`: Boxes are in pixel coordinates relative to model input
/// - `None`: Unknown, caller must infer (e.g., check if any coordinate >
///   1.0)
///
/// This is determined by the model config's `normalized` field, not the NMS
/// mode. When coordinates are in pixels or unknown, the caller may need
/// to normalize using the model input dimensions.
///
/// The value is returned by copy from the decoder's `normalized` field.
///
/// # Examples
///
/// ```rust
/// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
/// # fn main() -> DecoderResult<()> {
/// #    let config_yaml = include_str!("../../../testdata/modelpack_split.yaml").to_string();
///     let decoder = DecoderBuilder::default()
///         .with_config_yaml_str(config_yaml)
///         .build()?;
///     // Config doesn't specify normalized, so it's None
///     assert!(decoder.normalized_boxes().is_none());
/// #    Ok(())
/// # }
/// ```
pub fn normalized_boxes(&self) -> Option<bool> {
    self.normalized
}
2297
/// This function decodes quantized model outputs into detection boxes and
/// segmentation masks. The quantized outputs can be of u8, i8, u16, i16,
/// u32, or i32 types. Up to `output_boxes.capacity()` boxes and masks
/// will be decoded. The function clears the provided output vectors
/// before populating them with the decoded results.
///
/// # Errors
///
/// This function returns a `DecoderError` if the provided outputs don't
/// match the configuration provided by the user when building the decoder.
/// End-to-end Yolo model types always fail with
/// `DecoderError::InvalidConfig`, because they require float decoding.
///
/// # Examples
///
/// ```rust
/// # use edgefirst_decoder::{BoundingBox, DecoderBuilder, DetectBox, DecoderResult};
/// # use ndarray::Array4;
/// # fn main() -> DecoderResult<()> {
/// #    let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
/// #    let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec())?;
/// #
/// #    let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
/// #    let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec())?;
/// #    let model_output = vec![
/// #        detect1.view().into_dyn().into(),
/// #        detect0.view().into_dyn().into(),
/// #    ];
/// let decoder = DecoderBuilder::default()
///     .with_config_yaml_str(include_str!("../../../testdata/modelpack_split.yaml").to_string())
///     .with_score_threshold(0.45)
///     .with_iou_threshold(0.45)
///     .build()?;
///
/// let mut output_boxes: Vec<_> = Vec::with_capacity(10);
/// let mut output_masks: Vec<_> = Vec::with_capacity(10);
/// decoder.decode_quantized(&model_output, &mut output_boxes, &mut output_masks)?;
/// assert!(output_boxes[0].equal_within_delta(
///     &DetectBox {
///         bbox: BoundingBox {
///             xmin: 0.43171933,
///             ymin: 0.68243736,
///             xmax: 0.5626645,
///             ymax: 0.808863,
///         },
///         score: 0.99240804,
///         label: 0
///     },
///     1e-6
/// ));
/// #    Ok(())
/// # }
/// ```
pub fn decode_quantized(
    &self,
    outputs: &[ArrayViewDQuantized],
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), DecoderError> {
    // Results from a previous call are discarded before dispatching.
    output_boxes.clear();
    output_masks.clear();
    // Dispatch to the quantized decode routine for the configured model type.
    match &self.model_type {
        ModelType::ModelPackSegDet {
            boxes,
            scores,
            segmentation,
        } => {
            // Detection first; segmentation only runs if detection succeeded.
            self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
            self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
        }
        ModelType::ModelPackSegDetSplit {
            detection,
            segmentation,
        } => {
            self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
            self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
        }
        ModelType::ModelPackDet { boxes, scores } => {
            self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)
        }
        ModelType::ModelPackDetSplit { detection } => {
            self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)
        }
        ModelType::ModelPackSeg { segmentation } => {
            self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
        }
        ModelType::YoloDet { boxes } => {
            self.decode_yolo_det_quantized(outputs, boxes, output_boxes)
        }
        ModelType::YoloSegDet { boxes, protos } => self.decode_yolo_segdet_quantized(
            outputs,
            boxes,
            protos,
            output_boxes,
            output_masks,
        ),
        ModelType::YoloSplitDet { boxes, scores } => {
            self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)
        }
        ModelType::YoloSplitSegDet {
            boxes,
            scores,
            mask_coeff,
            protos,
        } => self.decode_yolo_split_segdet_quantized(
            outputs,
            boxes,
            scores,
            mask_coeff,
            protos,
            output_boxes,
            output_masks,
        ),
        // End-to-end models have no quantized decode path.
        ModelType::YoloEndToEndDet { .. } | ModelType::YoloEndToEndSegDet { .. } => {
            Err(DecoderError::InvalidConfig(
                "End-to-end models require float decode, not quantized".to_string(),
            ))
        }
    }
}
2414
2415    /// This function decodes floating point model outputs into detection boxes
2416    /// and segmentation masks. Up to `output_boxes.capacity()` boxes and
2417    /// masks will be decoded. The function clears the provided output
2418    /// vectors before populating them with the decoded results.
2419    ///
/// This function returns an `Error` if the provided outputs don't
2421    /// match the configuration provided by the user when building the decoder.
2422    ///
2423    /// Any quantization information in the configuration will be ignored.
2424    ///
2425    /// # Examples
2426    ///
2427    /// ```rust
2428    /// # use edgefirst_decoder::{BoundingBox, DecoderBuilder, DetectBox, DecoderResult, configs, configs::{DecoderType, DecoderVersion}, dequantize_cpu, Quantization};
2429    /// # use ndarray::Array3;
2430    /// # fn main() -> DecoderResult<()> {
2431    /// #   let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
2432    /// #   let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
2433    /// #   let mut out_dequant = vec![0.0_f64; 84 * 8400];
2434    /// #   let quant = Quantization::new(0.0040811873, -123);
2435    /// #   dequantize_cpu(out, quant, &mut out_dequant);
2436    /// #   let model_output_f64 = Array3::from_shape_vec((1, 84, 8400), out_dequant)?.into_dyn();
2437    ///    let decoder = DecoderBuilder::default()
2438    ///     .with_config_yolo_det(configs::Detection {
2439    ///         decoder: DecoderType::Ultralytics,
2440    ///         quantization: None,
2441    ///         shape: vec![1, 84, 8400],
2442    ///         anchors: None,
2443    ///         dshape: Vec::new(),
2444    ///         normalized: Some(true),
2445    ///     },
2446    ///     Some(DecoderVersion::Yolo11))
2447    ///     .with_score_threshold(0.25)
2448    ///     .with_iou_threshold(0.7)
2449    ///     .build()?;
2450    ///
2451    /// let mut output_boxes: Vec<_> = Vec::with_capacity(10);
2452    /// let mut output_masks: Vec<_> = Vec::with_capacity(10);
2453    /// let model_output_f64 = vec![model_output_f64.view().into()];
2454    /// decoder.decode_float(&model_output_f64, &mut output_boxes, &mut output_masks)?;    
2455    /// assert!(output_boxes[0].equal_within_delta(
2456    ///        &DetectBox {
2457    ///            bbox: BoundingBox {
2458    ///                xmin: 0.5285137,
2459    ///                ymin: 0.05305544,
2460    ///                xmax: 0.87541467,
2461    ///                ymax: 0.9998909,
2462    ///            },
2463    ///            score: 0.5591227,
2464    ///            label: 0
2465    ///        },
2466    ///        1e-6
2467    ///    ));
2468    ///
2469    /// #    Ok(())
2470    /// # }
2471    pub fn decode_float<T>(
2472        &self,
2473        outputs: &[ArrayViewD<T>],
2474        output_boxes: &mut Vec<DetectBox>,
2475        output_masks: &mut Vec<Segmentation>,
2476    ) -> Result<(), DecoderError>
2477    where
2478        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
2479        f32: AsPrimitive<T>,
2480    {
2481        output_boxes.clear();
2482        output_masks.clear();
2483        match &self.model_type {
2484            ModelType::ModelPackSegDet {
2485                boxes,
2486                scores,
2487                segmentation,
2488            } => {
2489                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
2490                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
2491            }
2492            ModelType::ModelPackSegDetSplit {
2493                detection,
2494                segmentation,
2495            } => {
2496                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
2497                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
2498            }
2499            ModelType::ModelPackDet { boxes, scores } => {
2500                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
2501            }
2502            ModelType::ModelPackDetSplit { detection } => {
2503                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
2504            }
2505            ModelType::ModelPackSeg { segmentation } => {
2506                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
2507            }
2508            ModelType::YoloDet { boxes } => {
2509                self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
2510            }
2511            ModelType::YoloSegDet { boxes, protos } => {
2512                self.decode_yolo_segdet_float(outputs, boxes, protos, output_boxes, output_masks)?;
2513            }
2514            ModelType::YoloSplitDet { boxes, scores } => {
2515                self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
2516            }
2517            ModelType::YoloSplitSegDet {
2518                boxes,
2519                scores,
2520                mask_coeff,
2521                protos,
2522            } => {
2523                self.decode_yolo_split_segdet_float(
2524                    outputs,
2525                    boxes,
2526                    scores,
2527                    mask_coeff,
2528                    protos,
2529                    output_boxes,
2530                    output_masks,
2531                )?;
2532            }
2533            ModelType::YoloEndToEndDet { boxes } => {
2534                self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
2535            }
2536            ModelType::YoloEndToEndSegDet { boxes, protos } => {
2537                self.decode_yolo_end_to_end_segdet_float(
2538                    outputs,
2539                    boxes,
2540                    protos,
2541                    output_boxes,
2542                    output_masks,
2543                )?;
2544            }
2545        }
2546        Ok(())
2547    }
2548
2549    fn decode_modelpack_det_quantized(
2550        &self,
2551        outputs: &[ArrayViewDQuantized],
2552        boxes: &configs::Boxes,
2553        scores: &configs::Scores,
2554        output_boxes: &mut Vec<DetectBox>,
2555    ) -> Result<(), DecoderError> {
2556        let (boxes_tensor, ind) =
2557            Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
2558        let (scores_tensor, _) =
2559            Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &[ind])?;
2560        let quant_boxes = boxes
2561            .quantization
2562            .map(Quantization::from)
2563            .unwrap_or_default();
2564        let quant_scores = scores
2565            .quantization
2566            .map(Quantization::from)
2567            .unwrap_or_default();
2568
2569        with_quantized!(boxes_tensor, b, {
2570            with_quantized!(scores_tensor, s, {
2571                let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
2572                let boxes_tensor = boxes_tensor.slice(s![0, .., 0, ..]);
2573
2574                let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
2575                let scores_tensor = scores_tensor.slice(s![0, .., ..]);
2576                decode_modelpack_det(
2577                    (boxes_tensor, quant_boxes),
2578                    (scores_tensor, quant_scores),
2579                    self.score_threshold,
2580                    self.iou_threshold,
2581                    output_boxes,
2582                );
2583            });
2584        });
2585
2586        Ok(())
2587    }
2588
2589    fn decode_modelpack_seg_quantized(
2590        &self,
2591        outputs: &[ArrayViewDQuantized],
2592        segmentation: &configs::Segmentation,
2593        output_masks: &mut Vec<Segmentation>,
2594    ) -> Result<(), DecoderError> {
2595        let (seg, _) = Self::find_outputs_with_shape_quantized(&segmentation.shape, outputs, &[])?;
2596
2597        macro_rules! modelpack_seg {
2598            ($seg:expr, $body:expr) => {{
2599                let seg = Self::swap_axes_if_needed($seg, segmentation.into());
2600                let seg = seg.slice(s![0, .., .., ..]);
2601                seg.mapv($body)
2602            }};
2603        }
2604        use ArrayViewDQuantized::*;
2605        let seg = match seg {
2606            UInt8(s) => {
2607                modelpack_seg!(s, |x| x)
2608            }
2609            Int8(s) => {
2610                modelpack_seg!(s, |x| (x as i16 + 128) as u8)
2611            }
2612            UInt16(s) => {
2613                modelpack_seg!(s, |x| (x >> 8) as u8)
2614            }
2615            Int16(s) => {
2616                modelpack_seg!(s, |x| ((x as i32 + 32768) >> 8) as u8)
2617            }
2618            UInt32(s) => {
2619                modelpack_seg!(s, |x| (x >> 24) as u8)
2620            }
2621            Int32(s) => {
2622                modelpack_seg!(s, |x| ((x as i64 + 2147483648) >> 24) as u8)
2623            }
2624        };
2625
2626        output_masks.push(Segmentation {
2627            xmin: 0.0,
2628            ymin: 0.0,
2629            xmax: 1.0,
2630            ymax: 1.0,
2631            segmentation: seg,
2632        });
2633        Ok(())
2634    }
2635
2636    fn decode_modelpack_det_split_quantized(
2637        &self,
2638        outputs: &[ArrayViewDQuantized],
2639        detection: &[configs::Detection],
2640        output_boxes: &mut Vec<DetectBox>,
2641    ) -> Result<(), DecoderError> {
2642        let new_detection = detection
2643            .iter()
2644            .map(|x| match &x.anchors {
2645                None => Err(DecoderError::InvalidConfig(
2646                    "ModelPack Split Detection missing anchors".to_string(),
2647                )),
2648                Some(a) => Ok(ModelPackDetectionConfig {
2649                    anchors: a.clone(),
2650                    quantization: None,
2651                }),
2652            })
2653            .collect::<Result<Vec<_>, _>>()?;
2654        let new_outputs = Self::match_outputs_to_detect_quantized(detection, outputs)?;
2655
2656        macro_rules! dequant_output {
2657            ($det_tensor:expr, $detection:expr) => {{
2658                let det_tensor = Self::swap_axes_if_needed($det_tensor, $detection.into());
2659                let det_tensor = det_tensor.slice(s![0, .., .., ..]);
2660                if let Some(q) = $detection.quantization {
2661                    dequantize_ndarray(det_tensor, q.into())
2662                } else {
2663                    det_tensor.map(|x| *x as f32)
2664                }
2665            }};
2666        }
2667
2668        let new_outputs = new_outputs
2669            .iter()
2670            .zip(detection)
2671            .map(|(det_tensor, detection)| {
2672                with_quantized!(det_tensor, d, dequant_output!(d, detection))
2673            })
2674            .collect::<Vec<_>>();
2675
2676        let new_outputs_view = new_outputs
2677            .iter()
2678            .map(|d: &Array3<f32>| d.view())
2679            .collect::<Vec<_>>();
2680        decode_modelpack_split_float(
2681            &new_outputs_view,
2682            &new_detection,
2683            self.score_threshold,
2684            self.iou_threshold,
2685            output_boxes,
2686        );
2687        Ok(())
2688    }
2689
2690    fn decode_yolo_det_quantized(
2691        &self,
2692        outputs: &[ArrayViewDQuantized],
2693        boxes: &configs::Detection,
2694        output_boxes: &mut Vec<DetectBox>,
2695    ) -> Result<(), DecoderError> {
2696        let (boxes_tensor, _) =
2697            Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
2698        let quant_boxes = boxes
2699            .quantization
2700            .map(Quantization::from)
2701            .unwrap_or_default();
2702
2703        with_quantized!(boxes_tensor, b, {
2704            let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
2705            let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
2706            decode_yolo_det(
2707                (boxes_tensor, quant_boxes),
2708                self.score_threshold,
2709                self.iou_threshold,
2710                self.nms,
2711                output_boxes,
2712            );
2713        });
2714
2715        Ok(())
2716    }
2717
2718    fn decode_yolo_segdet_quantized(
2719        &self,
2720        outputs: &[ArrayViewDQuantized],
2721        boxes: &configs::Detection,
2722        protos: &configs::Protos,
2723        output_boxes: &mut Vec<DetectBox>,
2724        output_masks: &mut Vec<Segmentation>,
2725    ) -> Result<(), DecoderError> {
2726        let (boxes_tensor, ind) =
2727            Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
2728        let (protos_tensor, _) =
2729            Self::find_outputs_with_shape_quantized(&protos.shape, outputs, &[ind])?;
2730
2731        let quant_boxes = boxes
2732            .quantization
2733            .map(Quantization::from)
2734            .unwrap_or_default();
2735        let quant_protos = protos
2736            .quantization
2737            .map(Quantization::from)
2738            .unwrap_or_default();
2739
2740        with_quantized!(boxes_tensor, b, {
2741            with_quantized!(protos_tensor, p, {
2742                let box_tensor = Self::swap_axes_if_needed(b, boxes.into());
2743                let box_tensor = box_tensor.slice(s![0, .., ..]);
2744
2745                let protos_tensor = Self::swap_axes_if_needed(p, protos.into());
2746                let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
2747                decode_yolo_segdet_quant(
2748                    (box_tensor, quant_boxes),
2749                    (protos_tensor, quant_protos),
2750                    self.score_threshold,
2751                    self.iou_threshold,
2752                    self.nms,
2753                    output_boxes,
2754                    output_masks,
2755                );
2756            });
2757        });
2758
2759        Ok(())
2760    }
2761
2762    fn decode_yolo_split_det_quantized(
2763        &self,
2764        outputs: &[ArrayViewDQuantized],
2765        boxes: &configs::Boxes,
2766        scores: &configs::Scores,
2767        output_boxes: &mut Vec<DetectBox>,
2768    ) -> Result<(), DecoderError> {
2769        let (boxes_tensor, ind) =
2770            Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
2771        let (scores_tensor, _) =
2772            Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &[ind])?;
2773        let quant_boxes = boxes
2774            .quantization
2775            .map(Quantization::from)
2776            .unwrap_or_default();
2777        let quant_scores = scores
2778            .quantization
2779            .map(Quantization::from)
2780            .unwrap_or_default();
2781
2782        with_quantized!(boxes_tensor, b, {
2783            with_quantized!(scores_tensor, s, {
2784                let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
2785                let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
2786
2787                let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
2788                let scores_tensor = scores_tensor.slice(s![0, .., ..]);
2789                decode_yolo_split_det_quant(
2790                    (boxes_tensor, quant_boxes),
2791                    (scores_tensor, quant_scores),
2792                    self.score_threshold,
2793                    self.iou_threshold,
2794                    self.nms,
2795                    output_boxes,
2796                );
2797            });
2798        });
2799
2800        Ok(())
2801    }
2802
2803    #[allow(clippy::too_many_arguments)]
2804    fn decode_yolo_split_segdet_quantized(
2805        &self,
2806        outputs: &[ArrayViewDQuantized],
2807        boxes: &configs::Boxes,
2808        scores: &configs::Scores,
2809        mask_coeff: &configs::MaskCoefficients,
2810        protos: &configs::Protos,
2811        output_boxes: &mut Vec<DetectBox>,
2812        output_masks: &mut Vec<Segmentation>,
2813    ) -> Result<(), DecoderError> {
2814        let quant_boxes = boxes
2815            .quantization
2816            .map(Quantization::from)
2817            .unwrap_or_default();
2818        let quant_scores = scores
2819            .quantization
2820            .map(Quantization::from)
2821            .unwrap_or_default();
2822        let quant_masks = mask_coeff
2823            .quantization
2824            .map(Quantization::from)
2825            .unwrap_or_default();
2826        let quant_protos = protos
2827            .quantization
2828            .map(Quantization::from)
2829            .unwrap_or_default();
2830
2831        let mut skip = vec![];
2832
2833        let (boxes_tensor, ind) =
2834            Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &skip)?;
2835        skip.push(ind);
2836
2837        let (scores_tensor, ind) =
2838            Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &skip)?;
2839        skip.push(ind);
2840
2841        let (mask_tensor, ind) =
2842            Self::find_outputs_with_shape_quantized(&mask_coeff.shape, outputs, &skip)?;
2843        skip.push(ind);
2844
2845        let (protos_tensor, _) =
2846            Self::find_outputs_with_shape_quantized(&protos.shape, outputs, &skip)?;
2847
2848        let boxes = with_quantized!(boxes_tensor, b, {
2849            with_quantized!(scores_tensor, s, {
2850                let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
2851                let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
2852
2853                let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
2854                let scores_tensor = scores_tensor.slice(s![0, .., ..]);
2855                impl_yolo_split_segdet_quant_get_boxes::<XYWH, _, _>(
2856                    (boxes_tensor, quant_boxes),
2857                    (scores_tensor, quant_scores),
2858                    self.score_threshold,
2859                    self.iou_threshold,
2860                    self.nms,
2861                    output_boxes.capacity(),
2862                )
2863            })
2864        });
2865
2866        with_quantized!(mask_tensor, m, {
2867            with_quantized!(protos_tensor, p, {
2868                let mask_tensor = Self::swap_axes_if_needed(m, mask_coeff.into());
2869                let mask_tensor = mask_tensor.slice(s![0, .., ..]);
2870
2871                let protos_tensor = Self::swap_axes_if_needed(p, protos.into());
2872                let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
2873                impl_yolo_split_segdet_quant_process_masks::<_, _>(
2874                    boxes,
2875                    (mask_tensor, quant_masks),
2876                    (protos_tensor, quant_protos),
2877                    output_boxes,
2878                    output_masks,
2879                )
2880            })
2881        });
2882
2883        Ok(())
2884    }
2885
2886    fn decode_modelpack_det_split_float<D>(
2887        &self,
2888        outputs: &[ArrayViewD<D>],
2889        detection: &[configs::Detection],
2890        output_boxes: &mut Vec<DetectBox>,
2891    ) -> Result<(), DecoderError>
2892    where
2893        D: AsPrimitive<f32>,
2894    {
2895        let new_detection = detection
2896            .iter()
2897            .map(|x| match &x.anchors {
2898                None => Err(DecoderError::InvalidConfig(
2899                    "ModelPack Split Detection missing anchors".to_string(),
2900                )),
2901                Some(a) => Ok(ModelPackDetectionConfig {
2902                    anchors: a.clone(),
2903                    quantization: None,
2904                }),
2905            })
2906            .collect::<Result<Vec<_>, _>>()?;
2907
2908        let new_outputs = Self::match_outputs_to_detect(detection, outputs)?;
2909        let new_outputs = new_outputs
2910            .into_iter()
2911            .map(|x| x.slice(s![0, .., .., ..]))
2912            .collect::<Vec<_>>();
2913
2914        decode_modelpack_split_float(
2915            &new_outputs,
2916            &new_detection,
2917            self.score_threshold,
2918            self.iou_threshold,
2919            output_boxes,
2920        );
2921        Ok(())
2922    }
2923
2924    fn decode_modelpack_seg_float<T>(
2925        &self,
2926        outputs: &[ArrayViewD<T>],
2927        segmentation: &configs::Segmentation,
2928        output_masks: &mut Vec<Segmentation>,
2929    ) -> Result<(), DecoderError>
2930    where
2931        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
2932        f32: AsPrimitive<T>,
2933    {
2934        let (seg, _) = Self::find_outputs_with_shape(&segmentation.shape, outputs, &[])?;
2935
2936        let seg = Self::swap_axes_if_needed(seg, segmentation.into());
2937        let seg = seg.slice(s![0, .., .., ..]);
2938        let u8_max = 255.0_f32.as_();
2939        let max = *seg.max().unwrap_or(&u8_max);
2940        let min = *seg.min().unwrap_or(&0.0_f32.as_());
2941        let seg = seg.mapv(|x| ((x - min) / (max - min) * u8_max).as_());
2942        output_masks.push(Segmentation {
2943            xmin: 0.0,
2944            ymin: 0.0,
2945            xmax: 1.0,
2946            ymax: 1.0,
2947            segmentation: seg,
2948        });
2949        Ok(())
2950    }
2951
2952    fn decode_modelpack_det_float<T>(
2953        &self,
2954        outputs: &[ArrayViewD<T>],
2955        boxes: &configs::Boxes,
2956        scores: &configs::Scores,
2957        output_boxes: &mut Vec<DetectBox>,
2958    ) -> Result<(), DecoderError>
2959    where
2960        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
2961        f32: AsPrimitive<T>,
2962    {
2963        let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
2964
2965        let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
2966        let boxes_tensor = boxes_tensor.slice(s![0, .., 0, ..]);
2967
2968        let (scores_tensor, _) = Self::find_outputs_with_shape(&scores.shape, outputs, &[ind])?;
2969        let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
2970        let scores_tensor = scores_tensor.slice(s![0, .., ..]);
2971
2972        decode_modelpack_float(
2973            boxes_tensor,
2974            scores_tensor,
2975            self.score_threshold,
2976            self.iou_threshold,
2977            output_boxes,
2978        );
2979        Ok(())
2980    }
2981
2982    fn decode_yolo_det_float<T>(
2983        &self,
2984        outputs: &[ArrayViewD<T>],
2985        boxes: &configs::Detection,
2986        output_boxes: &mut Vec<DetectBox>,
2987    ) -> Result<(), DecoderError>
2988    where
2989        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
2990        f32: AsPrimitive<T>,
2991    {
2992        let (boxes_tensor, _) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
2993
2994        let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
2995        let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
2996        decode_yolo_det_float(
2997            boxes_tensor,
2998            self.score_threshold,
2999            self.iou_threshold,
3000            self.nms,
3001            output_boxes,
3002        );
3003        Ok(())
3004    }
3005
3006    fn decode_yolo_segdet_float<T>(
3007        &self,
3008        outputs: &[ArrayViewD<T>],
3009        boxes: &configs::Detection,
3010        protos: &configs::Protos,
3011        output_boxes: &mut Vec<DetectBox>,
3012        output_masks: &mut Vec<Segmentation>,
3013    ) -> Result<(), DecoderError>
3014    where
3015        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3016        f32: AsPrimitive<T>,
3017    {
3018        let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
3019
3020        let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3021        let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3022
3023        let (protos_tensor, _) = Self::find_outputs_with_shape(&protos.shape, outputs, &[ind])?;
3024
3025        let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos.into());
3026        let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3027        decode_yolo_segdet_float(
3028            boxes_tensor,
3029            protos_tensor,
3030            self.score_threshold,
3031            self.iou_threshold,
3032            self.nms,
3033            output_boxes,
3034            output_masks,
3035        );
3036        Ok(())
3037    }
3038
3039    fn decode_yolo_split_det_float<T>(
3040        &self,
3041        outputs: &[ArrayViewD<T>],
3042        boxes: &configs::Boxes,
3043        scores: &configs::Scores,
3044        output_boxes: &mut Vec<DetectBox>,
3045    ) -> Result<(), DecoderError>
3046    where
3047        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3048        f32: AsPrimitive<T>,
3049    {
3050        let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
3051        let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3052        let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3053
3054        let (scores_tensor, _) = Self::find_outputs_with_shape(&scores.shape, outputs, &[ind])?;
3055
3056        let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
3057        let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3058
3059        decode_yolo_split_det_float(
3060            boxes_tensor,
3061            scores_tensor,
3062            self.score_threshold,
3063            self.iou_threshold,
3064            self.nms,
3065            output_boxes,
3066        );
3067        Ok(())
3068    }
3069
3070    #[allow(clippy::too_many_arguments)]
3071    fn decode_yolo_split_segdet_float<T>(
3072        &self,
3073        outputs: &[ArrayViewD<T>],
3074        boxes: &configs::Boxes,
3075        scores: &configs::Scores,
3076        mask_coeff: &configs::MaskCoefficients,
3077        protos: &configs::Protos,
3078        output_boxes: &mut Vec<DetectBox>,
3079        output_masks: &mut Vec<Segmentation>,
3080    ) -> Result<(), DecoderError>
3081    where
3082        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3083        f32: AsPrimitive<T>,
3084    {
3085        let mut skip = vec![];
3086        let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &skip)?;
3087
3088        let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3089        let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3090        skip.push(ind);
3091
3092        let (scores_tensor, ind) = Self::find_outputs_with_shape(&scores.shape, outputs, &skip)?;
3093
3094        let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
3095        let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3096        skip.push(ind);
3097
3098        let (mask_tensor, ind) = Self::find_outputs_with_shape(&mask_coeff.shape, outputs, &skip)?;
3099        let mask_tensor = Self::swap_axes_if_needed(mask_tensor, mask_coeff.into());
3100        let mask_tensor = mask_tensor.slice(s![0, .., ..]);
3101        skip.push(ind);
3102
3103        let (protos_tensor, _) = Self::find_outputs_with_shape(&protos.shape, outputs, &skip)?;
3104        let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos.into());
3105        let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3106        decode_yolo_split_segdet_float(
3107            boxes_tensor,
3108            scores_tensor,
3109            mask_tensor,
3110            protos_tensor,
3111            self.score_threshold,
3112            self.iou_threshold,
3113            self.nms,
3114            output_boxes,
3115            output_masks,
3116        );
3117        Ok(())
3118    }
3119
3120    /// Decodes end-to-end YOLO detection outputs (post-NMS from model).
3121    ///
3122    /// Input shape: (1, N, 6+) where columns are [x1, y1, x2, y2, conf, class,
3123    /// ...] Boxes are output directly from model (may be normalized or
3124    /// pixel coords depending on config).
3125    fn decode_yolo_end_to_end_det_float<T>(
3126        &self,
3127        outputs: &[ArrayViewD<T>],
3128        boxes_config: &configs::Detection,
3129        output_boxes: &mut Vec<DetectBox>,
3130    ) -> Result<(), DecoderError>
3131    where
3132        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3133        f32: AsPrimitive<T>,
3134    {
3135        let (det_tensor, _) = Self::find_outputs_with_shape(&boxes_config.shape, outputs, &[])?;
3136        let det_tensor = Self::swap_axes_if_needed(det_tensor, boxes_config.into());
3137        let det_tensor = det_tensor.slice(s![0, .., ..]);
3138
3139        crate::yolo::decode_yolo_end_to_end_det_float(
3140            det_tensor,
3141            self.score_threshold,
3142            output_boxes,
3143        )?;
3144        Ok(())
3145    }
3146
3147    /// Decodes end-to-end YOLO detection + segmentation outputs (post-NMS from
3148    /// model).
3149    ///
3150    /// Input shapes:
3151    /// - detection: (1, N, 6 + num_protos) where columns are [x1, y1, x2, y2,
3152    ///   conf, class, mask_coeff_0, ..., mask_coeff_31]
3153    /// - protos: (1, proto_height, proto_width, num_protos)
3154    fn decode_yolo_end_to_end_segdet_float<T>(
3155        &self,
3156        outputs: &[ArrayViewD<T>],
3157        boxes_config: &configs::Detection,
3158        protos_config: &configs::Protos,
3159        output_boxes: &mut Vec<DetectBox>,
3160        output_masks: &mut Vec<Segmentation>,
3161    ) -> Result<(), DecoderError>
3162    where
3163        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3164        f32: AsPrimitive<T>,
3165    {
3166        if outputs.len() < 2 {
3167            return Err(DecoderError::InvalidShape(
3168                "End-to-end segdet requires detection and protos outputs".to_string(),
3169            ));
3170        }
3171
3172        let (det_tensor, det_ind) =
3173            Self::find_outputs_with_shape(&boxes_config.shape, outputs, &[])?;
3174        let det_tensor = Self::swap_axes_if_needed(det_tensor, boxes_config.into());
3175        let det_tensor = det_tensor.slice(s![0, .., ..]);
3176
3177        let (protos_tensor, _) =
3178            Self::find_outputs_with_shape(&protos_config.shape, outputs, &[det_ind])?;
3179        let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos_config.into());
3180        let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3181
3182        crate::yolo::decode_yolo_end_to_end_segdet_float(
3183            det_tensor,
3184            protos_tensor,
3185            self.score_threshold,
3186            output_boxes,
3187            output_masks,
3188        )?;
3189        Ok(())
3190    }
3191
3192    fn match_outputs_to_detect<'a, 'b, T>(
3193        configs: &[configs::Detection],
3194        outputs: &'a [ArrayViewD<'b, T>],
3195    ) -> Result<Vec<&'a ArrayViewD<'b, T>>, DecoderError> {
3196        let mut new_output_order = Vec::new();
3197        for c in configs {
3198            let mut found = false;
3199            for o in outputs {
3200                if o.shape() == c.shape {
3201                    new_output_order.push(o);
3202                    found = true;
3203                    break;
3204                }
3205            }
3206            if !found {
3207                return Err(DecoderError::InvalidShape(format!(
3208                    "Did not find output with shape {:?}",
3209                    c.shape
3210                )));
3211            }
3212        }
3213        Ok(new_output_order)
3214    }
3215
3216    fn find_outputs_with_shape<'a, 'b, T>(
3217        shape: &[usize],
3218        outputs: &'a [ArrayViewD<'b, T>],
3219        skip: &[usize],
3220    ) -> Result<(&'a ArrayViewD<'b, T>, usize), DecoderError> {
3221        for (ind, o) in outputs.iter().enumerate() {
3222            if skip.contains(&ind) {
3223                continue;
3224            }
3225            if o.shape() == shape {
3226                return Ok((o, ind));
3227            }
3228        }
3229        Err(DecoderError::InvalidShape(format!(
3230            "Did not find output with shape {:?}",
3231            shape
3232        )))
3233    }
3234
3235    fn find_outputs_with_shape_quantized<'a, 'b>(
3236        shape: &[usize],
3237        outputs: &'a [ArrayViewDQuantized<'b>],
3238        skip: &[usize],
3239    ) -> Result<(&'a ArrayViewDQuantized<'b>, usize), DecoderError> {
3240        for (ind, o) in outputs.iter().enumerate() {
3241            if skip.contains(&ind) {
3242                continue;
3243            }
3244            if o.shape() == shape {
3245                return Ok((o, ind));
3246            }
3247        }
3248        Err(DecoderError::InvalidShape(format!(
3249            "Did not find output with shape {:?}",
3250            shape
3251        )))
3252    }
3253
3254    /// This is split detection, need to swap axes to batch, height, width,
3255    /// num_anchors_x_features,
3256    fn modelpack_det_order(x: DimName) -> usize {
3257        match x {
3258            DimName::Batch => 0,
3259            DimName::NumBoxes => 1,
3260            DimName::Padding => 2,
3261            DimName::BoxCoords => 3,
3262            _ => 1000, // this should be unreachable
3263        }
3264    }
3265
3266    // This is Ultralytics detection, need to swap axes to batch, num_features,
3267    // height, width
3268    fn yolo_det_order(x: DimName) -> usize {
3269        match x {
3270            DimName::Batch => 0,
3271            DimName::NumFeatures => 1,
3272            DimName::NumBoxes => 2,
3273            _ => 1000, // this should be unreachable
3274        }
3275    }
3276
3277    // This is modelpack boxes, need to swap axes to batch, num_boxes, padding,
3278    // box_coords
3279    fn modelpack_boxes_order(x: DimName) -> usize {
3280        match x {
3281            DimName::Batch => 0,
3282            DimName::NumBoxes => 1,
3283            DimName::Padding => 2,
3284            DimName::BoxCoords => 3,
3285            _ => 1000, // this should be unreachable
3286        }
3287    }
3288
3289    /// This is Ultralytics boxes, need to swap axes to batch, box_coords,
3290    /// num_boxes
3291    fn yolo_boxes_order(x: DimName) -> usize {
3292        match x {
3293            DimName::Batch => 0,
3294            DimName::BoxCoords => 1,
3295            DimName::NumBoxes => 2,
3296            _ => 1000, // this should be unreachable
3297        }
3298    }
3299
3300    /// This is modelpack scores, need to swap axes to batch, num_boxes,
3301    /// num_classes
3302    fn modelpack_scores_order(x: DimName) -> usize {
3303        match x {
3304            DimName::Batch => 0,
3305            DimName::NumBoxes => 1,
3306            DimName::NumClasses => 2,
3307            _ => 1000, // this should be unreachable
3308        }
3309    }
3310
3311    fn yolo_scores_order(x: DimName) -> usize {
3312        match x {
3313            DimName::Batch => 0,
3314            DimName::NumClasses => 1,
3315            DimName::NumBoxes => 2,
3316            _ => 1000, // this should be unreachable
3317        }
3318    }
3319
3320    /// This is modelpack segmentation, need to swap axes to batch, height,
3321    /// width, num_classes
3322    fn modelpack_segmentation_order(x: DimName) -> usize {
3323        match x {
3324            DimName::Batch => 0,
3325            DimName::Height => 1,
3326            DimName::Width => 2,
3327            DimName::NumClasses => 3,
3328            _ => 1000, // this should be unreachable
3329        }
3330    }
3331
3332    /// This is modelpack masks, need to swap axes to batch, height,
3333    /// width
3334    fn modelpack_mask_order(x: DimName) -> usize {
3335        match x {
3336            DimName::Batch => 0,
3337            DimName::Height => 1,
3338            DimName::Width => 2,
3339            _ => 1000, // this should be unreachable
3340        }
3341    }
3342
3343    /// This is yolo protos, need to swap axes to batch, height, width,
3344    /// num_protos
3345    fn yolo_protos_order(x: DimName) -> usize {
3346        match x {
3347            DimName::Batch => 0,
3348            DimName::Height => 1,
3349            DimName::Width => 2,
3350            DimName::NumProtos => 3,
3351            _ => 1000, // this should be unreachable
3352        }
3353    }
3354
3355    /// This is yolo mask coefficients, need to swap axes to batch, num_protos,
3356    /// num_boxes
3357    fn yolo_maskcoefficients_order(x: DimName) -> usize {
3358        match x {
3359            DimName::Batch => 0,
3360            DimName::NumProtos => 1,
3361            DimName::NumBoxes => 2,
3362            _ => 1000, // this should be unreachable
3363        }
3364    }
3365
3366    fn get_order_fn(config: ConfigOutputRef) -> fn(DimName) -> usize {
3367        let decoder_type = config.decoder();
3368        match (config, decoder_type) {
3369            (ConfigOutputRef::Detection(_), DecoderType::ModelPack) => Self::modelpack_det_order,
3370            (ConfigOutputRef::Detection(_), DecoderType::Ultralytics) => Self::yolo_det_order,
3371            (ConfigOutputRef::Boxes(_), DecoderType::ModelPack) => Self::modelpack_boxes_order,
3372            (ConfigOutputRef::Boxes(_), DecoderType::Ultralytics) => Self::yolo_boxes_order,
3373            (ConfigOutputRef::Scores(_), DecoderType::ModelPack) => Self::modelpack_scores_order,
3374            (ConfigOutputRef::Scores(_), DecoderType::Ultralytics) => Self::yolo_scores_order,
3375            (ConfigOutputRef::Segmentation(_), _) => Self::modelpack_segmentation_order,
3376            (ConfigOutputRef::Mask(_), _) => Self::modelpack_mask_order,
3377            (ConfigOutputRef::Protos(_), _) => Self::yolo_protos_order,
3378            (ConfigOutputRef::MaskCoefficients(_), _) => Self::yolo_maskcoefficients_order,
3379        }
3380    }
3381
3382    fn swap_axes_if_needed<'a, T, D: Dimension>(
3383        array: &ArrayView<'a, T, D>,
3384        config: ConfigOutputRef,
3385    ) -> ArrayView<'a, T, D> {
3386        let mut array = array.clone();
3387        if config.dshape().is_empty() {
3388            return array;
3389        }
3390        let order_fn: fn(DimName) -> usize = Self::get_order_fn(config.clone());
3391        let mut current_order: Vec<usize> = config
3392            .dshape()
3393            .iter()
3394            .map(|x| order_fn(x.0))
3395            .collect::<Vec<_>>();
3396
3397        assert_eq!(array.shape().len(), current_order.len());
3398        // do simple bubble sort as swap_axes is inexpensive and the
3399        // number of dimensions is small
3400        for i in 0..current_order.len() {
3401            let mut swapped = false;
3402            for j in 0..current_order.len() - 1 - i {
3403                if current_order[j] > current_order[j + 1] {
3404                    array.swap_axes(j, j + 1);
3405                    current_order.swap(j, j + 1);
3406                    swapped = true;
3407                }
3408            }
3409            if !swapped {
3410                break;
3411            }
3412        }
3413        array
3414    }
3415
3416    fn match_outputs_to_detect_quantized<'a, 'b>(
3417        configs: &[configs::Detection],
3418        outputs: &'a [ArrayViewDQuantized<'b>],
3419    ) -> Result<Vec<&'a ArrayViewDQuantized<'b>>, DecoderError> {
3420        let mut new_output_order = Vec::new();
3421        for c in configs {
3422            let mut found = false;
3423            for o in outputs {
3424                if o.shape() == c.shape {
3425                    new_output_order.push(o);
3426                    found = true;
3427                    break;
3428                }
3429            }
3430            if !found {
3431                return Err(DecoderError::InvalidShape(format!(
3432                    "Did not find output with shape {:?}",
3433                    c.shape
3434                )));
3435            }
3436        }
3437        Ok(new_output_order)
3438    }
3439}
3440
3441#[cfg(test)]
3442#[cfg_attr(coverage_nightly, coverage(off))]
3443mod decoder_builder_tests {
3444    use super::*;
3445
3446    #[test]
3447    fn test_decoder_builder_no_config() {
3448        use crate::DecoderBuilder;
3449        let result = DecoderBuilder::default().build();
3450        assert!(matches!(result, Err(DecoderError::NoConfig)));
3451    }
3452
3453    #[test]
3454    fn test_decoder_builder_empty_config() {
3455        use crate::DecoderBuilder;
3456        let result = DecoderBuilder::default()
3457            .with_config(ConfigOutputs {
3458                outputs: vec![],
3459                ..Default::default()
3460            })
3461            .build();
3462        assert!(
3463            matches!(result, Err(DecoderError::InvalidConfig(s)) if s == "No outputs found in config")
3464        );
3465    }
3466
3467    #[test]
3468    fn test_malformed_config_yaml() {
3469        let malformed_yaml = "
3470        model_type: yolov8_det
3471        outputs:
3472          - shape: [1, 84, 8400]
3473        "
3474        .to_owned();
3475        let result = DecoderBuilder::new()
3476            .with_config_yaml_str(malformed_yaml)
3477            .build();
3478        assert!(matches!(result, Err(DecoderError::Yaml(_))));
3479    }
3480
3481    #[test]
3482    fn test_malformed_config_json() {
3483        let malformed_yaml = "
3484        {
3485            \"model_type\": \"yolov8_det\",
3486            \"outputs\": [
3487                {
3488                    \"shape\": [1, 84, 8400]
3489                }
3490            ]
3491        }"
3492        .to_owned();
3493        let result = DecoderBuilder::new()
3494            .with_config_json_str(malformed_yaml)
3495            .build();
3496        assert!(matches!(result, Err(DecoderError::Json(_))));
3497    }
3498
3499    #[test]
3500    fn test_modelpack_and_yolo_config_error() {
3501        let result = DecoderBuilder::new()
3502            .with_config_modelpack_det(
3503                configs::Boxes {
3504                    decoder: configs::DecoderType::Ultralytics,
3505                    shape: vec![1, 4, 8400],
3506                    quantization: None,
3507                    dshape: vec![
3508                        (DimName::Batch, 1),
3509                        (DimName::BoxCoords, 4),
3510                        (DimName::NumBoxes, 8400),
3511                    ],
3512                    normalized: Some(true),
3513                },
3514                configs::Scores {
3515                    decoder: configs::DecoderType::ModelPack,
3516                    shape: vec![1, 80, 8400],
3517                    quantization: None,
3518                    dshape: vec![
3519                        (DimName::Batch, 1),
3520                        (DimName::NumClasses, 80),
3521                        (DimName::NumBoxes, 8400),
3522                    ],
3523                },
3524            )
3525            .build();
3526
3527        assert!(matches!(
3528            result, Err(DecoderError::InvalidConfig(s)) if s == "Both ModelPack and Yolo outputs found in config"
3529        ));
3530    }
3531
3532    #[test]
3533    fn test_yolo_invalid_seg_shape() {
3534        let result = DecoderBuilder::new()
3535            .with_config_yolo_segdet(
3536                configs::Detection {
3537                    decoder: configs::DecoderType::Ultralytics,
3538                    shape: vec![1, 85, 8400, 1], // Invalid shape
3539                    quantization: None,
3540                    anchors: None,
3541                    dshape: vec![
3542                        (DimName::Batch, 1),
3543                        (DimName::NumFeatures, 85),
3544                        (DimName::NumBoxes, 8400),
3545                        (DimName::Batch, 1),
3546                    ],
3547                    normalized: Some(true),
3548                },
3549                configs::Protos {
3550                    decoder: configs::DecoderType::Ultralytics,
3551                    shape: vec![1, 32, 160, 160],
3552                    quantization: None,
3553                    dshape: vec![
3554                        (DimName::Batch, 1),
3555                        (DimName::NumProtos, 32),
3556                        (DimName::Height, 160),
3557                        (DimName::Width, 160),
3558                    ],
3559                },
3560                Some(DecoderVersion::Yolo11),
3561            )
3562            .build();
3563
3564        assert!(matches!(
3565            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Detection shape")
3566        ));
3567    }
3568
3569    #[test]
3570    fn test_yolo_invalid_mask() {
3571        let result = DecoderBuilder::new()
3572            .with_config(ConfigOutputs {
3573                outputs: vec![ConfigOutput::Mask(configs::Mask {
3574                    shape: vec![1, 160, 160, 1],
3575                    decoder: configs::DecoderType::Ultralytics,
3576                    quantization: None,
3577                    dshape: vec![
3578                        (DimName::Batch, 1),
3579                        (DimName::Height, 160),
3580                        (DimName::Width, 160),
3581                        (DimName::NumFeatures, 1),
3582                    ],
3583                })],
3584                ..Default::default()
3585            })
3586            .build();
3587
3588        assert!(matches!(
3589            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Mask output with Yolo decoder")
3590        ));
3591    }
3592
3593    #[test]
3594    fn test_yolo_invalid_outputs() {
3595        let result = DecoderBuilder::new()
3596            .with_config(ConfigOutputs {
3597                outputs: vec![ConfigOutput::Segmentation(configs::Segmentation {
3598                    shape: vec![1, 84, 8400],
3599                    decoder: configs::DecoderType::Ultralytics,
3600                    quantization: None,
3601                    dshape: vec![
3602                        (DimName::Batch, 1),
3603                        (DimName::NumFeatures, 84),
3604                        (DimName::NumBoxes, 8400),
3605                    ],
3606                })],
3607                ..Default::default()
3608            })
3609            .build();
3610
3611        assert!(
3612            matches!(result, Err(DecoderError::InvalidConfig(s)) if s == "Invalid Segmentation output with Yolo decoder")
3613        );
3614    }
3615
3616    #[test]
3617    fn test_yolo_invalid_det() {
3618        let result = DecoderBuilder::new()
3619            .with_config_yolo_det(
3620                configs::Detection {
3621                    anchors: None,
3622                    decoder: DecoderType::Ultralytics,
3623                    quantization: None,
3624                    shape: vec![1, 84, 8400, 1], // Invalid shape
3625                    dshape: vec![
3626                        (DimName::Batch, 1),
3627                        (DimName::NumFeatures, 84),
3628                        (DimName::NumBoxes, 8400),
3629                        (DimName::Batch, 1),
3630                    ],
3631                    normalized: Some(true),
3632                },
3633                Some(DecoderVersion::Yolo11),
3634            )
3635            .build();
3636
3637        assert!(matches!(
3638            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Detection shape")));
3639
3640        let result = DecoderBuilder::new()
3641            .with_config_yolo_det(
3642                configs::Detection {
3643                    anchors: None,
3644                    decoder: DecoderType::Ultralytics,
3645                    quantization: None,
3646                    shape: vec![1, 8400, 3], // Invalid shape
3647                    dshape: vec![
3648                        (DimName::Batch, 1),
3649                        (DimName::NumBoxes, 8400),
3650                        (DimName::NumFeatures, 3),
3651                    ],
3652                    normalized: Some(true),
3653                },
3654                Some(DecoderVersion::Yolo11),
3655            )
3656            .build();
3657
3658        assert!(
3659            matches!(
3660            &result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid shape: Yolo num_features 3 must be greater than 4")),
3661            "{}",
3662            result.unwrap_err()
3663        );
3664
3665        let result = DecoderBuilder::new()
3666            .with_config_yolo_det(
3667                configs::Detection {
3668                    anchors: None,
3669                    decoder: DecoderType::Ultralytics,
3670                    quantization: None,
3671                    shape: vec![1, 3, 8400], // Invalid shape
3672                    dshape: Vec::new(),
3673                    normalized: Some(true),
3674                },
3675                Some(DecoderVersion::Yolo11),
3676            )
3677            .build();
3678
3679        assert!(matches!(
3680            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid shape: Yolo num_features 3 must be greater than 4")));
3681    }
3682
3683    #[test]
3684    fn test_yolo_invalid_segdet() {
3685        let result = DecoderBuilder::new()
3686            .with_config_yolo_segdet(
3687                configs::Detection {
3688                    decoder: configs::DecoderType::Ultralytics,
3689                    shape: vec![1, 85, 8400, 1], // Invalid shape
3690                    quantization: None,
3691                    anchors: None,
3692                    dshape: vec![
3693                        (DimName::Batch, 1),
3694                        (DimName::NumFeatures, 85),
3695                        (DimName::NumBoxes, 8400),
3696                        (DimName::Batch, 1),
3697                    ],
3698                    normalized: Some(true),
3699                },
3700                configs::Protos {
3701                    decoder: configs::DecoderType::Ultralytics,
3702                    shape: vec![1, 32, 160, 160],
3703                    quantization: None,
3704                    dshape: vec![
3705                        (DimName::Batch, 1),
3706                        (DimName::NumProtos, 32),
3707                        (DimName::Height, 160),
3708                        (DimName::Width, 160),
3709                    ],
3710                },
3711                Some(DecoderVersion::Yolo11),
3712            )
3713            .build();
3714
3715        assert!(matches!(
3716            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Detection shape")));
3717
3718        let result = DecoderBuilder::new()
3719            .with_config_yolo_segdet(
3720                configs::Detection {
3721                    decoder: configs::DecoderType::Ultralytics,
3722                    shape: vec![1, 85, 8400],
3723                    quantization: None,
3724                    anchors: None,
3725                    dshape: vec![
3726                        (DimName::Batch, 1),
3727                        (DimName::NumFeatures, 85),
3728                        (DimName::NumBoxes, 8400),
3729                    ],
3730                    normalized: Some(true),
3731                },
3732                configs::Protos {
3733                    decoder: configs::DecoderType::Ultralytics,
3734                    shape: vec![1, 32, 160, 160, 1], // Invalid shape
3735                    dshape: vec![
3736                        (DimName::Batch, 1),
3737                        (DimName::NumProtos, 32),
3738                        (DimName::Height, 160),
3739                        (DimName::Width, 160),
3740                        (DimName::Batch, 1),
3741                    ],
3742                    quantization: None,
3743                },
3744                Some(DecoderVersion::Yolo11),
3745            )
3746            .build();
3747
3748        assert!(matches!(
3749            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Protos shape")));
3750
3751        let result = DecoderBuilder::new()
3752            .with_config_yolo_segdet(
3753                configs::Detection {
3754                    decoder: configs::DecoderType::Ultralytics,
3755                    shape: vec![1, 8400, 36], // too few classes
3756                    quantization: None,
3757                    anchors: None,
3758                    dshape: vec![
3759                        (DimName::Batch, 1),
3760                        (DimName::NumBoxes, 8400),
3761                        (DimName::NumFeatures, 36),
3762                    ],
3763                    normalized: Some(true),
3764                },
3765                configs::Protos {
3766                    decoder: configs::DecoderType::Ultralytics,
3767                    shape: vec![1, 32, 160, 160],
3768                    quantization: None,
3769                    dshape: vec![
3770                        (DimName::Batch, 1),
3771                        (DimName::NumProtos, 32),
3772                        (DimName::Height, 160),
3773                        (DimName::Width, 160),
3774                    ],
3775                },
3776                Some(DecoderVersion::Yolo11),
3777            )
3778            .build();
3779        println!("{:?}", result);
3780        assert!(matches!(
3781            result, Err(DecoderError::InvalidConfig(s)) if s == "Invalid shape: Yolo num_features 36 must be greater than 36"));
3782    }
3783
3784    #[test]
3785    fn test_yolo_invalid_split_det() {
3786        let result = DecoderBuilder::new()
3787            .with_config_yolo_split_det(
3788                configs::Boxes {
3789                    decoder: configs::DecoderType::Ultralytics,
3790                    shape: vec![1, 4, 8400, 1], // Invalid shape
3791                    quantization: None,
3792                    dshape: vec![
3793                        (DimName::Batch, 1),
3794                        (DimName::BoxCoords, 4),
3795                        (DimName::NumBoxes, 8400),
3796                        (DimName::Batch, 1),
3797                    ],
3798                    normalized: Some(true),
3799                },
3800                configs::Scores {
3801                    decoder: configs::DecoderType::Ultralytics,
3802                    shape: vec![1, 80, 8400],
3803                    quantization: None,
3804                    dshape: vec![
3805                        (DimName::Batch, 1),
3806                        (DimName::NumClasses, 80),
3807                        (DimName::NumBoxes, 8400),
3808                    ],
3809                },
3810            )
3811            .build();
3812
3813        assert!(matches!(
3814            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Boxes shape")));
3815
3816        let result = DecoderBuilder::new()
3817            .with_config_yolo_split_det(
3818                configs::Boxes {
3819                    decoder: configs::DecoderType::Ultralytics,
3820                    shape: vec![1, 4, 8400],
3821                    quantization: None,
3822                    dshape: vec![
3823                        (DimName::Batch, 1),
3824                        (DimName::BoxCoords, 4),
3825                        (DimName::NumBoxes, 8400),
3826                    ],
3827                    normalized: Some(true),
3828                },
3829                configs::Scores {
3830                    decoder: configs::DecoderType::Ultralytics,
3831                    shape: vec![1, 80, 8400, 1], // Invalid shape
3832                    quantization: None,
3833                    dshape: vec![
3834                        (DimName::Batch, 1),
3835                        (DimName::NumClasses, 80),
3836                        (DimName::NumBoxes, 8400),
3837                        (DimName::Batch, 1),
3838                    ],
3839                },
3840            )
3841            .build();
3842
3843        assert!(matches!(
3844            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Scores shape")));
3845
3846        let result = DecoderBuilder::new()
3847            .with_config_yolo_split_det(
3848                configs::Boxes {
3849                    decoder: configs::DecoderType::Ultralytics,
3850                    shape: vec![1, 8400, 4],
3851                    quantization: None,
3852                    dshape: vec![
3853                        (DimName::Batch, 1),
3854                        (DimName::NumBoxes, 8400),
3855                        (DimName::BoxCoords, 4),
3856                    ],
3857                    normalized: Some(true),
3858                },
3859                configs::Scores {
3860                    decoder: configs::DecoderType::Ultralytics,
3861                    shape: vec![1, 8400 + 1, 80], // Invalid number of boxes
3862                    quantization: None,
3863                    dshape: vec![
3864                        (DimName::Batch, 1),
3865                        (DimName::NumBoxes, 8401),
3866                        (DimName::NumClasses, 80),
3867                    ],
3868                },
3869            )
3870            .build();
3871
3872        assert!(matches!(
3873            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Yolo Split Detection Boxes num 8400 incompatible with Scores num 8401")));
3874
3875        let result = DecoderBuilder::new()
3876            .with_config_yolo_split_det(
3877                configs::Boxes {
3878                    decoder: configs::DecoderType::Ultralytics,
3879                    shape: vec![1, 5, 8400], // Invalid boxes dimensions
3880                    quantization: None,
3881                    dshape: vec![
3882                        (DimName::Batch, 1),
3883                        (DimName::BoxCoords, 5),
3884                        (DimName::NumBoxes, 8400),
3885                    ],
3886                    normalized: Some(true),
3887                },
3888                configs::Scores {
3889                    decoder: configs::DecoderType::Ultralytics,
3890                    shape: vec![1, 80, 8400],
3891                    quantization: None,
3892                    dshape: vec![
3893                        (DimName::Batch, 1),
3894                        (DimName::NumClasses, 80),
3895                        (DimName::NumBoxes, 8400),
3896                    ],
3897                },
3898            )
3899            .build();
3900        assert!(matches!(
3901            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("BoxCoords dimension size must be 4")));
3902    }
3903
3904    #[test]
3905    fn test_yolo_invalid_split_segdet() {
3906        let result = DecoderBuilder::new()
3907            .with_config_yolo_split_segdet(
3908                configs::Boxes {
3909                    decoder: configs::DecoderType::Ultralytics,
3910                    shape: vec![1, 8400, 4, 1],
3911                    quantization: None,
3912                    dshape: vec![
3913                        (DimName::Batch, 1),
3914                        (DimName::NumBoxes, 8400),
3915                        (DimName::BoxCoords, 4),
3916                        (DimName::Batch, 1),
3917                    ],
3918                    normalized: Some(true),
3919                },
3920                configs::Scores {
3921                    decoder: configs::DecoderType::Ultralytics,
3922                    shape: vec![1, 8400, 80],
3923
3924                    quantization: None,
3925                    dshape: vec![
3926                        (DimName::Batch, 1),
3927                        (DimName::NumBoxes, 8400),
3928                        (DimName::NumClasses, 80),
3929                    ],
3930                },
3931                configs::MaskCoefficients {
3932                    decoder: configs::DecoderType::Ultralytics,
3933                    shape: vec![1, 8400, 32],
3934                    quantization: None,
3935                    dshape: vec![
3936                        (DimName::Batch, 1),
3937                        (DimName::NumBoxes, 8400),
3938                        (DimName::NumProtos, 32),
3939                    ],
3940                },
3941                configs::Protos {
3942                    decoder: configs::DecoderType::Ultralytics,
3943                    shape: vec![1, 32, 160, 160],
3944                    quantization: None,
3945                    dshape: vec![
3946                        (DimName::Batch, 1),
3947                        (DimName::NumProtos, 32),
3948                        (DimName::Height, 160),
3949                        (DimName::Width, 160),
3950                    ],
3951                },
3952            )
3953            .build();
3954
3955        assert!(matches!(
3956            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Boxes shape")));
3957
3958        let result = DecoderBuilder::new()
3959            .with_config_yolo_split_segdet(
3960                configs::Boxes {
3961                    decoder: configs::DecoderType::Ultralytics,
3962                    shape: vec![1, 8400, 4],
3963                    quantization: None,
3964                    dshape: vec![
3965                        (DimName::Batch, 1),
3966                        (DimName::NumBoxes, 8400),
3967                        (DimName::BoxCoords, 4),
3968                    ],
3969                    normalized: Some(true),
3970                },
3971                configs::Scores {
3972                    decoder: configs::DecoderType::Ultralytics,
3973                    shape: vec![1, 8400, 80, 1],
3974                    quantization: None,
3975                    dshape: vec![
3976                        (DimName::Batch, 1),
3977                        (DimName::NumBoxes, 8400),
3978                        (DimName::NumClasses, 80),
3979                        (DimName::Batch, 1),
3980                    ],
3981                },
3982                configs::MaskCoefficients {
3983                    decoder: configs::DecoderType::Ultralytics,
3984                    shape: vec![1, 8400, 32],
3985                    quantization: None,
3986                    dshape: vec![
3987                        (DimName::Batch, 1),
3988                        (DimName::NumBoxes, 8400),
3989                        (DimName::NumProtos, 32),
3990                    ],
3991                },
3992                configs::Protos {
3993                    decoder: configs::DecoderType::Ultralytics,
3994                    shape: vec![1, 32, 160, 160],
3995                    quantization: None,
3996                    dshape: vec![
3997                        (DimName::Batch, 1),
3998                        (DimName::NumProtos, 32),
3999                        (DimName::Height, 160),
4000                        (DimName::Width, 160),
4001                    ],
4002                },
4003            )
4004            .build();
4005
4006        assert!(matches!(
4007            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Scores shape")));
4008
4009        let result = DecoderBuilder::new()
4010            .with_config_yolo_split_segdet(
4011                configs::Boxes {
4012                    decoder: configs::DecoderType::Ultralytics,
4013                    shape: vec![1, 8400, 4],
4014                    quantization: None,
4015                    dshape: vec![
4016                        (DimName::Batch, 1),
4017                        (DimName::NumBoxes, 8400),
4018                        (DimName::BoxCoords, 4),
4019                    ],
4020                    normalized: Some(true),
4021                },
4022                configs::Scores {
4023                    decoder: configs::DecoderType::Ultralytics,
4024                    shape: vec![1, 8400, 80],
4025                    quantization: None,
4026                    dshape: vec![
4027                        (DimName::Batch, 1),
4028                        (DimName::NumBoxes, 8400),
4029                        (DimName::NumClasses, 80),
4030                    ],
4031                },
4032                configs::MaskCoefficients {
4033                    decoder: configs::DecoderType::Ultralytics,
4034                    shape: vec![1, 8400, 32, 1],
4035                    quantization: None,
4036                    dshape: vec![
4037                        (DimName::Batch, 1),
4038                        (DimName::NumBoxes, 8400),
4039                        (DimName::NumProtos, 32),
4040                        (DimName::Batch, 1),
4041                    ],
4042                },
4043                configs::Protos {
4044                    decoder: configs::DecoderType::Ultralytics,
4045                    shape: vec![1, 32, 160, 160],
4046                    quantization: None,
4047                    dshape: vec![
4048                        (DimName::Batch, 1),
4049                        (DimName::NumProtos, 32),
4050                        (DimName::Height, 160),
4051                        (DimName::Width, 160),
4052                    ],
4053                },
4054            )
4055            .build();
4056
4057        assert!(matches!(
4058            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Mask Coefficients shape")));
4059
4060        let result = DecoderBuilder::new()
4061            .with_config_yolo_split_segdet(
4062                configs::Boxes {
4063                    decoder: configs::DecoderType::Ultralytics,
4064                    shape: vec![1, 8400, 4],
4065                    quantization: None,
4066                    dshape: vec![
4067                        (DimName::Batch, 1),
4068                        (DimName::NumBoxes, 8400),
4069                        (DimName::BoxCoords, 4),
4070                    ],
4071                    normalized: Some(true),
4072                },
4073                configs::Scores {
4074                    decoder: configs::DecoderType::Ultralytics,
4075                    shape: vec![1, 8400, 80],
4076                    quantization: None,
4077                    dshape: vec![
4078                        (DimName::Batch, 1),
4079                        (DimName::NumBoxes, 8400),
4080                        (DimName::NumClasses, 80),
4081                    ],
4082                },
4083                configs::MaskCoefficients {
4084                    decoder: configs::DecoderType::Ultralytics,
4085                    shape: vec![1, 8400, 32],
4086                    quantization: None,
4087                    dshape: vec![
4088                        (DimName::Batch, 1),
4089                        (DimName::NumBoxes, 8400),
4090                        (DimName::NumProtos, 32),
4091                    ],
4092                },
4093                configs::Protos {
4094                    decoder: configs::DecoderType::Ultralytics,
4095                    shape: vec![1, 32, 160, 160, 1],
4096                    quantization: None,
4097                    dshape: vec![
4098                        (DimName::Batch, 1),
4099                        (DimName::NumProtos, 32),
4100                        (DimName::Height, 160),
4101                        (DimName::Width, 160),
4102                        (DimName::Batch, 1),
4103                    ],
4104                },
4105            )
4106            .build();
4107
4108        assert!(matches!(
4109            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Protos shape")));
4110
4111        let result = DecoderBuilder::new()
4112            .with_config_yolo_split_segdet(
4113                configs::Boxes {
4114                    decoder: configs::DecoderType::Ultralytics,
4115                    shape: vec![1, 8400, 4],
4116                    quantization: None,
4117                    dshape: vec![
4118                        (DimName::Batch, 1),
4119                        (DimName::NumBoxes, 8400),
4120                        (DimName::BoxCoords, 4),
4121                    ],
4122                    normalized: Some(true),
4123                },
4124                configs::Scores {
4125                    decoder: configs::DecoderType::Ultralytics,
4126                    shape: vec![1, 8401, 80],
4127                    quantization: None,
4128                    dshape: vec![
4129                        (DimName::Batch, 1),
4130                        (DimName::NumBoxes, 8401),
4131                        (DimName::NumClasses, 80),
4132                    ],
4133                },
4134                configs::MaskCoefficients {
4135                    decoder: configs::DecoderType::Ultralytics,
4136                    shape: vec![1, 8400, 32],
4137                    quantization: None,
4138                    dshape: vec![
4139                        (DimName::Batch, 1),
4140                        (DimName::NumBoxes, 8400),
4141                        (DimName::NumProtos, 32),
4142                    ],
4143                },
4144                configs::Protos {
4145                    decoder: configs::DecoderType::Ultralytics,
4146                    shape: vec![1, 32, 160, 160],
4147                    quantization: None,
4148                    dshape: vec![
4149                        (DimName::Batch, 1),
4150                        (DimName::NumProtos, 32),
4151                        (DimName::Height, 160),
4152                        (DimName::Width, 160),
4153                    ],
4154                },
4155            )
4156            .build();
4157
4158        assert!(matches!(
4159            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Yolo Split Detection Boxes num 8400 incompatible with Scores num 8401")));
4160
4161        let result = DecoderBuilder::new()
4162            .with_config_yolo_split_segdet(
4163                configs::Boxes {
4164                    decoder: configs::DecoderType::Ultralytics,
4165                    shape: vec![1, 8400, 4],
4166                    quantization: None,
4167                    dshape: vec![
4168                        (DimName::Batch, 1),
4169                        (DimName::NumBoxes, 8400),
4170                        (DimName::BoxCoords, 4),
4171                    ],
4172                    normalized: Some(true),
4173                },
4174                configs::Scores {
4175                    decoder: configs::DecoderType::Ultralytics,
4176                    shape: vec![1, 8400, 80],
4177                    quantization: None,
4178                    dshape: vec![
4179                        (DimName::Batch, 1),
4180                        (DimName::NumBoxes, 8400),
4181                        (DimName::NumClasses, 80),
4182                    ],
4183                },
4184                configs::MaskCoefficients {
4185                    decoder: configs::DecoderType::Ultralytics,
4186                    shape: vec![1, 8401, 32],
4187
4188                    quantization: None,
4189                    dshape: vec![
4190                        (DimName::Batch, 1),
4191                        (DimName::NumBoxes, 8401),
4192                        (DimName::NumProtos, 32),
4193                    ],
4194                },
4195                configs::Protos {
4196                    decoder: configs::DecoderType::Ultralytics,
4197                    shape: vec![1, 32, 160, 160],
4198                    quantization: None,
4199                    dshape: vec![
4200                        (DimName::Batch, 1),
4201                        (DimName::NumProtos, 32),
4202                        (DimName::Height, 160),
4203                        (DimName::Width, 160),
4204                    ],
4205                },
4206            )
4207            .build();
4208
4209        assert!(matches!(
4210            result, Err(DecoderError::InvalidConfig(ref s)) if s.starts_with("Yolo Split Detection Boxes num 8400 incompatible with Mask Coefficients num 8401")));
4211        let result = DecoderBuilder::new()
4212            .with_config_yolo_split_segdet(
4213                configs::Boxes {
4214                    decoder: configs::DecoderType::Ultralytics,
4215                    shape: vec![1, 8400, 4],
4216                    quantization: None,
4217                    dshape: vec![
4218                        (DimName::Batch, 1),
4219                        (DimName::NumBoxes, 8400),
4220                        (DimName::BoxCoords, 4),
4221                    ],
4222                    normalized: Some(true),
4223                },
4224                configs::Scores {
4225                    decoder: configs::DecoderType::Ultralytics,
4226                    shape: vec![1, 8400, 80],
4227                    quantization: None,
4228                    dshape: vec![
4229                        (DimName::Batch, 1),
4230                        (DimName::NumBoxes, 8400),
4231                        (DimName::NumClasses, 80),
4232                    ],
4233                },
4234                configs::MaskCoefficients {
4235                    decoder: configs::DecoderType::Ultralytics,
4236                    shape: vec![1, 8400, 32],
4237                    quantization: None,
4238                    dshape: vec![
4239                        (DimName::Batch, 1),
4240                        (DimName::NumBoxes, 8400),
4241                        (DimName::NumProtos, 32),
4242                    ],
4243                },
4244                configs::Protos {
4245                    decoder: configs::DecoderType::Ultralytics,
4246                    shape: vec![1, 31, 160, 160],
4247                    quantization: None,
4248                    dshape: vec![
4249                        (DimName::Batch, 1),
4250                        (DimName::NumProtos, 31),
4251                        (DimName::Height, 160),
4252                        (DimName::Width, 160),
4253                    ],
4254                },
4255            )
4256            .build();
4257        println!("{:?}", result);
4258        assert!(matches!(
4259            result, Err(DecoderError::InvalidConfig(ref s)) if s.starts_with( "Yolo Protos channels 31 incompatible with Mask Coefficients channels 32")));
4260    }
4261
4262    #[test]
        // Checks that `DecoderBuilder::with_config(..).build()` rejects ModelPack
        // output sets that are not a valid combination. Each case builds a
        // `ConfigOutputs` by hand and asserts on the exact `InvalidConfig` message.
4263    fn test_modelpack_invalid_config() {
            // Case 1: Boxes + Scores + Protos, all tagged as ModelPack. The build is
            // expected to fail because a Protos output is present.
4264        let result = DecoderBuilder::new()
4265            .with_config(ConfigOutputs {
4266                outputs: vec![
4267                    ConfigOutput::Boxes(configs::Boxes {
4268                        decoder: configs::DecoderType::ModelPack,
4269                        shape: vec![1, 8400, 1, 4],
4270                        quantization: None,
4271                        dshape: vec![
4272                            (DimName::Batch, 1),
4273                            (DimName::NumBoxes, 8400),
4274                            (DimName::Padding, 1),
4275                            (DimName::BoxCoords, 4),
4276                        ],
4277                        normalized: Some(true),
4278                    }),
4279                    ConfigOutput::Scores(configs::Scores {
4280                        decoder: configs::DecoderType::ModelPack,
4281                        shape: vec![1, 8400, 3],
4282                        quantization: None,
4283                        dshape: vec![
4284                            (DimName::Batch, 1),
4285                            (DimName::NumBoxes, 8400),
4286                            (DimName::NumClasses, 3),
4287                        ],
4288                    }),
4289                    ConfigOutput::Protos(configs::Protos {
4290                        decoder: configs::DecoderType::ModelPack,
4291                        shape: vec![1, 8400, 3],
4292                        quantization: None,
4293                        dshape: vec![
4294                            (DimName::Batch, 1),
4295                            (DimName::NumBoxes, 8400),
4296                            (DimName::NumFeatures, 3),
4297                        ],
4298                    }),
4299                ],
4300                ..Default::default()
4301            })
4302            .build();
4303
4304        assert!(matches!(
4305            result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack should not have protos"));
4306
            // Case 2: same Boxes + Scores, but the third output is MaskCoefficients.
            // The build is expected to fail with the mask-coefficients message.
4307        let result = DecoderBuilder::new()
4308            .with_config(ConfigOutputs {
4309                outputs: vec![
4310                    ConfigOutput::Boxes(configs::Boxes {
4311                        decoder: configs::DecoderType::ModelPack,
4312                        shape: vec![1, 8400, 1, 4],
4313                        quantization: None,
4314                        dshape: vec![
4315                            (DimName::Batch, 1),
4316                            (DimName::NumBoxes, 8400),
4317                            (DimName::Padding, 1),
4318                            (DimName::BoxCoords, 4),
4319                        ],
4320                        normalized: Some(true),
4321                    }),
4322                    ConfigOutput::Scores(configs::Scores {
4323                        decoder: configs::DecoderType::ModelPack,
4324                        shape: vec![1, 8400, 3],
4325                        quantization: None,
4326                        dshape: vec![
4327                            (DimName::Batch, 1),
4328                            (DimName::NumBoxes, 8400),
4329                            (DimName::NumClasses, 3),
4330                        ],
4331                    }),
4332                    ConfigOutput::MaskCoefficients(configs::MaskCoefficients {
4333                        decoder: configs::DecoderType::ModelPack,
4334                        shape: vec![1, 8400, 3],
4335                        quantization: None,
4336                        dshape: vec![
4337                            (DimName::Batch, 1),
4338                            (DimName::NumBoxes, 8400),
4339                            (DimName::NumProtos, 3),
4340                        ],
4341                    }),
4342                ],
4343                ..Default::default()
4344            })
4345            .build();
4346
4347        assert!(matches!(
4348            result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack should not have mask coefficients"));
4349
            // Case 3: a single Boxes output with nothing else. The build is expected
            // to fail with the generic "invalid model outputs" message.
4350        let result = DecoderBuilder::new()
4351            .with_config(ConfigOutputs {
4352                outputs: vec![ConfigOutput::Boxes(configs::Boxes {
4353                    decoder: configs::DecoderType::ModelPack,
4354                    shape: vec![1, 8400, 1, 4],
4355                    quantization: None,
4356                    dshape: vec![
4357                        (DimName::Batch, 1),
4358                        (DimName::NumBoxes, 8400),
4359                        (DimName::Padding, 1),
4360                        (DimName::BoxCoords, 4),
4361                    ],
4362                    normalized: Some(true),
4363                })],
4364                ..Default::default()
4365            })
4366            .build();
4367
4368        assert!(matches!(
4369            result, Err(DecoderError::InvalidConfig(s)) if s == "Invalid ModelPack model outputs"));
4370    }
4371
4372    #[test]
        // Checks that `with_config_modelpack_det(boxes, scores).build()` rejects
        // malformed Boxes/Scores descriptions. Each case changes one aspect of the
        // shapes and asserts on the resulting `InvalidConfig` message.
4373    fn test_modelpack_invalid_det() {
            // Case 1: Boxes described with only three dimensions (no Padding entry).
            // Expected error starts with "Invalid ModelPack Boxes shape".
            // NOTE(review): the later cases that get past this check all use a
            // four-dimensional Boxes shape — presumably that is the required rank;
            // confirm against the validator.
4374        let result = DecoderBuilder::new()
4375            .with_config_modelpack_det(
4376                configs::Boxes {
4377                    decoder: DecoderType::ModelPack,
4378                    quantization: None,
4379                    shape: vec![1, 4, 8400],
4380                    dshape: vec![
4381                        (DimName::Batch, 1),
4382                        (DimName::BoxCoords, 4),
4383                        (DimName::NumBoxes, 8400),
4384                    ],
4385                    normalized: Some(true),
4386                },
4387                configs::Scores {
4388                    decoder: DecoderType::ModelPack,
4389                    quantization: None,
4390                    shape: vec![1, 80, 8400],
4391                    dshape: vec![
4392                        (DimName::Batch, 1),
4393                        (DimName::NumClasses, 80),
4394                        (DimName::NumBoxes, 8400),
4395                    ],
4396                },
4397            )
4398            .build();
4399
4400        assert!(matches!(
4401            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Boxes shape")));
4402
            // Case 2: Boxes is four-dimensional, but Scores carries an extra trailing
            // Padding dimension. Expected error starts with
            // "Invalid ModelPack Scores shape".
4403        let result = DecoderBuilder::new()
4404            .with_config_modelpack_det(
4405                configs::Boxes {
4406                    decoder: DecoderType::ModelPack,
4407                    quantization: None,
4408                    shape: vec![1, 4, 1, 8400],
4409                    dshape: vec![
4410                        (DimName::Batch, 1),
4411                        (DimName::BoxCoords, 4),
4412                        (DimName::Padding, 1),
4413                        (DimName::NumBoxes, 8400),
4414                    ],
4415                    normalized: Some(true),
4416                },
4417                configs::Scores {
4418                    decoder: DecoderType::ModelPack,
4419                    quantization: None,
4420                    shape: vec![1, 80, 8400, 1],
4421                    dshape: vec![
4422                        (DimName::Batch, 1),
4423                        (DimName::NumClasses, 80),
4424                        (DimName::NumBoxes, 8400),
4425                        (DimName::Padding, 1),
4426                    ],
4427                },
4428            )
4429            .build();
4430
4431        assert!(matches!(
4432            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Scores shape")));
4433
            // Case 3: the Boxes Padding dimension has size 2 instead of 1.
4434        let result = DecoderBuilder::new()
4435            .with_config_modelpack_det(
4436                configs::Boxes {
4437                    decoder: DecoderType::ModelPack,
4438                    quantization: None,
4439                    shape: vec![1, 4, 2, 8400],
4440                    dshape: vec![
4441                        (DimName::Batch, 1),
4442                        (DimName::BoxCoords, 4),
4443                        (DimName::Padding, 2),
4444                        (DimName::NumBoxes, 8400),
4445                    ],
4446                    normalized: Some(true),
4447                },
4448                configs::Scores {
4449                    decoder: DecoderType::ModelPack,
4450                    quantization: None,
4451                    shape: vec![1, 80, 8400],
4452                    dshape: vec![
4453                        (DimName::Batch, 1),
4454                        (DimName::NumClasses, 80),
4455                        (DimName::NumBoxes, 8400),
4456                    ],
4457                },
4458            )
4459            .build();
4460        assert!(matches!(
4461            result, Err(DecoderError::InvalidConfig(s)) if s == "Padding dimension size must be 1"));
4462
            // Case 4: the Boxes BoxCoords dimension has size 5 instead of 4.
4463        let result = DecoderBuilder::new()
4464            .with_config_modelpack_det(
4465                configs::Boxes {
4466                    decoder: DecoderType::ModelPack,
4467                    quantization: None,
4468                    shape: vec![1, 5, 1, 8400],
4469                    dshape: vec![
4470                        (DimName::Batch, 1),
4471                        (DimName::BoxCoords, 5),
4472                        (DimName::Padding, 1),
4473                        (DimName::NumBoxes, 8400),
4474                    ],
4475                    normalized: Some(true),
4476                },
4477                configs::Scores {
4478                    decoder: DecoderType::ModelPack,
4479                    quantization: None,
4480                    shape: vec![1, 80, 8400],
4481                    dshape: vec![
4482                        (DimName::Batch, 1),
4483                        (DimName::NumClasses, 80),
4484                        (DimName::NumBoxes, 8400),
4485                    ],
4486                },
4487            )
4488            .build();
4489
4490        assert!(matches!(
4491            result, Err(DecoderError::InvalidConfig(s)) if s == "BoxCoords dimension size must be 4"));
4492
            // Case 5: both shapes are individually well-formed, but Boxes declares
            // 8400 boxes while Scores declares 8401. The exact mismatch message is
            // asserted.
4493        let result = DecoderBuilder::new()
4494            .with_config_modelpack_det(
4495                configs::Boxes {
4496                    decoder: DecoderType::ModelPack,
4497                    quantization: None,
4498                    shape: vec![1, 4, 1, 8400],
4499                    dshape: vec![
4500                        (DimName::Batch, 1),
4501                        (DimName::BoxCoords, 4),
4502                        (DimName::Padding, 1),
4503                        (DimName::NumBoxes, 8400),
4504                    ],
4505                    normalized: Some(true),
4506                },
4507                configs::Scores {
4508                    decoder: DecoderType::ModelPack,
4509                    quantization: None,
4510                    shape: vec![1, 80, 8401],
4511                    dshape: vec![
4512                        (DimName::Batch, 1),
4513                        (DimName::NumClasses, 80),
4514                        (DimName::NumBoxes, 8401),
4515                    ],
4516                },
4517            )
4518            .build();
4519
4520        assert!(matches!(
4521            result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Detection Boxes num 8400 incompatible with Scores num 8401"));
4522    }
4523
4524    #[test]
        // Checks that `with_config_modelpack_det_split(..).build()` rejects invalid
        // split-detection outputs: missing or empty anchors, wrong rank, feature
        // counts that do not fit the anchor count, a dshape without the
        // NumAnchorsXFeatures dimension, and outputs that disagree with each other.
        // Several cases are run twice, once with an explicit `dshape` and once with
        // an empty one, so both ways of describing the layout are exercised.
4525    fn test_modelpack_invalid_det_split() {
            // Case 1: two outputs, both with `anchors: None`.
4526        let result = DecoderBuilder::default()
4527            .with_config_modelpack_det_split(vec![
4528                configs::Detection {
4529                    decoder: DecoderType::ModelPack,
4530                    shape: vec![1, 17, 30, 18],
4531                    anchors: None,
4532                    quantization: None,
4533                    dshape: vec![
4534                        (DimName::Batch, 1),
4535                        (DimName::Height, 17),
4536                        (DimName::Width, 30),
4537                        (DimName::NumAnchorsXFeatures, 18),
4538                    ],
4539                    normalized: Some(true),
4540                },
4541                configs::Detection {
4542                    decoder: DecoderType::ModelPack,
4543                    shape: vec![1, 9, 15, 18],
4544                    anchors: None,
4545                    quantization: None,
4546                    dshape: vec![
4547                        (DimName::Batch, 1),
4548                        (DimName::Height, 9),
4549                        (DimName::Width, 15),
4550                        (DimName::NumAnchorsXFeatures, 18),
4551                    ],
4552                    normalized: Some(true),
4553                },
4554            ])
4555            .build();
4556
4557        assert!(matches!(
4558            result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection missing anchors"));
4559
            // Case 2: a single output with `anchors: None` and an empty dshape; the
            // same "missing anchors" message is expected.
4560        let result = DecoderBuilder::default()
4561            .with_config_modelpack_det_split(vec![configs::Detection {
4562                decoder: DecoderType::ModelPack,
4563                shape: vec![1, 17, 30, 18],
4564                anchors: None,
4565                quantization: None,
4566                dshape: Vec::new(),
4567                normalized: Some(true),
4568            }])
4569            .build();
4570
4571        assert!(matches!(
4572            result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection missing anchors"));
4573
            // Case 3: anchors are present but the list is empty (`Some(vec![])`).
4574        let result = DecoderBuilder::default()
4575            .with_config_modelpack_det_split(vec![configs::Detection {
4576                decoder: DecoderType::ModelPack,
4577                shape: vec![1, 17, 30, 18],
4578                anchors: Some(vec![]),
4579                quantization: None,
4580                dshape: vec![
4581                    (DimName::Batch, 1),
4582                    (DimName::Height, 17),
4583                    (DimName::Width, 30),
4584                    (DimName::NumAnchorsXFeatures, 18),
4585                ],
4586                normalized: Some(true),
4587            }])
4588            .build();
4589
4590        assert!(matches!(
4591            result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection has zero anchors"));
4592
            // Case 4: three anchors, but the shape has five dimensions (an extra
            // trailing Padding). Expected error starts with
            // "Invalid ModelPack Split Detection shape".
4593        let result = DecoderBuilder::default()
4594            .with_config_modelpack_det_split(vec![configs::Detection {
4595                decoder: DecoderType::ModelPack,
4596                shape: vec![1, 17, 30, 18, 1],
4597                anchors: Some(vec![
4598                    [0.3666666, 0.3148148],
4599                    [0.3874999, 0.474074],
4600                    [0.5333333, 0.644444],
4601                ]),
4602                quantization: None,
4603                dshape: vec![
4604                    (DimName::Batch, 1),
4605                    (DimName::Height, 17),
4606                    (DimName::Width, 30),
4607                    (DimName::NumAnchorsXFeatures, 18),
4608                    (DimName::Padding, 1),
4609                ],
4610                normalized: Some(true),
4611            }])
4612            .build();
4613
4614        assert!(matches!(
4615            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Split Detection shape")));
4616
            // Case 5: three anchors with NumAnchorsXFeatures = 15, i.e. exactly
            // 3 * 5. The expected message says this is "not greater than number of
            // anchors * 5". The dshape here places NumAnchorsXFeatures before
            // Height and Width.
4617        let result = DecoderBuilder::default()
4618            .with_config_modelpack_det_split(vec![configs::Detection {
4619                decoder: DecoderType::ModelPack,
4620                shape: vec![1, 15, 17, 30],
4621                anchors: Some(vec![
4622                    [0.3666666, 0.3148148],
4623                    [0.3874999, 0.474074],
4624                    [0.5333333, 0.644444],
4625                ]),
4626                quantization: None,
4627                dshape: vec![
4628                    (DimName::Batch, 1),
4629                    (DimName::NumAnchorsXFeatures, 15),
4630                    (DimName::Height, 17),
4631                    (DimName::Width, 30),
4632                ],
4633                normalized: Some(true),
4634            }])
4635            .build();
4636
4637        assert!(matches!(
4638            result, Err(DecoderError::InvalidConfig(s)) if s.contains("not greater than number of anchors * 5 =")));
4639
            // Case 6: same condition as case 5, but with an empty dshape and the 15
            // placed in the last position of `shape`.
4640        let result = DecoderBuilder::default()
4641            .with_config_modelpack_det_split(vec![configs::Detection {
4642                decoder: DecoderType::ModelPack,
4643                shape: vec![1, 17, 30, 15],
4644                anchors: Some(vec![
4645                    [0.3666666, 0.3148148],
4646                    [0.3874999, 0.474074],
4647                    [0.5333333, 0.644444],
4648                ]),
4649                quantization: None,
4650                dshape: Vec::new(),
4651                normalized: Some(true),
4652            }])
4653            .build();
4654
4655        assert!(matches!(
4656            result, Err(DecoderError::InvalidConfig(s)) if s.contains("not greater than number of anchors * 5 =")));
4657
            // Case 7: three anchors with NumAnchorsXFeatures = 16, which is not
            // divisible by 3. Explicit dshape variant.
4658        let result = DecoderBuilder::default()
4659            .with_config_modelpack_det_split(vec![configs::Detection {
4660                decoder: DecoderType::ModelPack,
4661                shape: vec![1, 16, 17, 30],
4662                anchors: Some(vec![
4663                    [0.3666666, 0.3148148],
4664                    [0.3874999, 0.474074],
4665                    [0.5333333, 0.644444],
4666                ]),
4667                quantization: None,
4668                dshape: vec![
4669                    (DimName::Batch, 1),
4670                    (DimName::NumAnchorsXFeatures, 16),
4671                    (DimName::Height, 17),
4672                    (DimName::Width, 30),
4673                ],
4674                normalized: Some(true),
4675            }])
4676            .build();
4677
4678        assert!(matches!(
4679            result, Err(DecoderError::InvalidConfig(s)) if s.contains("not a multiple of number of anchors")));
4680
            // Case 8: same condition as case 7, but with an empty dshape and the 16
            // placed in the last position of `shape`.
4681        let result = DecoderBuilder::default()
4682            .with_config_modelpack_det_split(vec![configs::Detection {
4683                decoder: DecoderType::ModelPack,
4684                shape: vec![1, 17, 30, 16],
4685                anchors: Some(vec![
4686                    [0.3666666, 0.3148148],
4687                    [0.3874999, 0.474074],
4688                    [0.5333333, 0.644444],
4689                ]),
4690                quantization: None,
4691                dshape: Vec::new(),
4692                normalized: Some(true),
4693            }])
4694            .build();
4695
4696        assert!(matches!(
4697            result, Err(DecoderError::InvalidConfig(s)) if s.contains("not a multiple of number of anchors")));
4698
            // Case 9: the dshape names its second dimension NumProtos instead of
            // NumAnchorsXFeatures, so the required dimension is absent.
4699        let result = DecoderBuilder::default()
4700            .with_config_modelpack_det_split(vec![configs::Detection {
4701                decoder: DecoderType::ModelPack,
4702                shape: vec![1, 18, 17, 30],
4703                anchors: Some(vec![
4704                    [0.3666666, 0.3148148],
4705                    [0.3874999, 0.474074],
4706                    [0.5333333, 0.644444],
4707                ]),
4708                quantization: None,
4709                dshape: vec![
4710                    (DimName::Batch, 1),
4711                    (DimName::NumProtos, 18),
4712                    (DimName::Height, 17),
4713                    (DimName::Width, 30),
4714                ],
4715                normalized: Some(true),
4716            }])
4717            .build();
4718        assert!(matches!(
4719            result, Err(DecoderError::InvalidConfig(s)) if s.contains("Split Detection dshape missing required dimension NumAnchorsXFeature")));
4720
            // Case 10: two outputs with three anchors each, but 18 features in one
            // and 21 in the other. Expected error starts with
            // "ModelPack Split Detection inconsistent number of classes:".
            // Explicit dshape variant.
4721        let result = DecoderBuilder::default()
4722            .with_config_modelpack_det_split(vec![
4723                configs::Detection {
4724                    decoder: DecoderType::ModelPack,
4725                    shape: vec![1, 17, 30, 18],
4726                    anchors: Some(vec![
4727                        [0.3666666, 0.3148148],
4728                        [0.3874999, 0.474074],
4729                        [0.5333333, 0.644444],
4730                    ]),
4731                    quantization: None,
4732                    dshape: vec![
4733                        (DimName::Batch, 1),
4734                        (DimName::Height, 17),
4735                        (DimName::Width, 30),
4736                        (DimName::NumAnchorsXFeatures, 18),
4737                    ],
4738                    normalized: Some(true),
4739                },
4740                configs::Detection {
4741                    decoder: DecoderType::ModelPack,
4742                    shape: vec![1, 17, 30, 21],
4743                    anchors: Some(vec![
4744                        [0.3666666, 0.3148148],
4745                        [0.3874999, 0.474074],
4746                        [0.5333333, 0.644444],
4747                    ]),
4748                    quantization: None,
4749                    dshape: vec![
4750                        (DimName::Batch, 1),
4751                        (DimName::Height, 17),
4752                        (DimName::Width, 30),
4753                        (DimName::NumAnchorsXFeatures, 21),
4754                    ],
4755                    normalized: Some(true),
4756                },
4757            ])
4758            .build();
4759
4760        assert!(matches!(
4761            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("ModelPack Split Detection inconsistent number of classes:")));
4762
            // Case 11: same mismatch as case 10, with both dshapes left empty.
4763        let result = DecoderBuilder::default()
4764            .with_config_modelpack_det_split(vec![
4765                configs::Detection {
4766                    decoder: DecoderType::ModelPack,
4767                    shape: vec![1, 17, 30, 18],
4768                    anchors: Some(vec![
4769                        [0.3666666, 0.3148148],
4770                        [0.3874999, 0.474074],
4771                        [0.5333333, 0.644444],
4772                    ]),
4773                    quantization: None,
4774                    dshape: vec![],
4775                    normalized: Some(true),
4776                },
4777                configs::Detection {
4778                    decoder: DecoderType::ModelPack,
4779                    shape: vec![1, 17, 30, 21],
4780                    anchors: Some(vec![
4781                        [0.3666666, 0.3148148],
4782                        [0.3874999, 0.474074],
4783                        [0.5333333, 0.644444],
4784                    ]),
4785                    quantization: None,
4786                    dshape: vec![],
4787                    normalized: Some(true),
4788                },
4789            ])
4790            .build();
4791
4792        assert!(matches!(
4793            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("ModelPack Split Detection inconsistent number of classes:")));
4794    }
4795
4796    #[test]
4797    fn test_modelpack_invalid_seg() {
4798        let result = DecoderBuilder::new()
4799            .with_config_modelpack_seg(configs::Segmentation {
4800                decoder: DecoderType::ModelPack,
4801                quantization: None,
4802                shape: vec![1, 160, 106, 3, 1],
4803                dshape: vec![
4804                    (DimName::Batch, 1),
4805                    (DimName::Height, 160),
4806                    (DimName::Width, 106),
4807                    (DimName::NumClasses, 3),
4808                    (DimName::Padding, 1),
4809                ],
4810            })
4811            .build();
4812
4813        assert!(matches!(
4814            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Segmentation shape")));
4815    }
4816
4817    #[test]
4818    fn test_modelpack_invalid_segdet() {
4819        let result = DecoderBuilder::new()
4820            .with_config_modelpack_segdet(
4821                configs::Boxes {
4822                    decoder: DecoderType::ModelPack,
4823                    quantization: None,
4824                    shape: vec![1, 4, 1, 8400],
4825                    dshape: vec![
4826                        (DimName::Batch, 1),
4827                        (DimName::BoxCoords, 4),
4828                        (DimName::Padding, 1),
4829                        (DimName::NumBoxes, 8400),
4830                    ],
4831                    normalized: Some(true),
4832                },
4833                configs::Scores {
4834                    decoder: DecoderType::ModelPack,
4835                    quantization: None,
4836                    shape: vec![1, 4, 8400],
4837                    dshape: vec![
4838                        (DimName::Batch, 1),
4839                        (DimName::NumClasses, 4),
4840                        (DimName::NumBoxes, 8400),
4841                    ],
4842                },
4843                configs::Segmentation {
4844                    decoder: DecoderType::ModelPack,
4845                    quantization: None,
4846                    shape: vec![1, 160, 106, 3],
4847                    dshape: vec![
4848                        (DimName::Batch, 1),
4849                        (DimName::Height, 160),
4850                        (DimName::Width, 106),
4851                        (DimName::NumClasses, 3),
4852                    ],
4853                },
4854            )
4855            .build();
4856
4857        assert!(matches!(
4858            result, Err(DecoderError::InvalidConfig(s)) if s.contains("incompatible with number of classes")));
4859    }
4860
4861    #[test]
4862    fn test_modelpack_invalid_segdet_split() {
4863        let result = DecoderBuilder::new()
4864            .with_config_modelpack_segdet_split(
4865                vec![configs::Detection {
4866                    decoder: DecoderType::ModelPack,
4867                    shape: vec![1, 17, 30, 18],
4868                    anchors: Some(vec![
4869                        [0.3666666, 0.3148148],
4870                        [0.3874999, 0.474074],
4871                        [0.5333333, 0.644444],
4872                    ]),
4873                    quantization: None,
4874                    dshape: vec![
4875                        (DimName::Batch, 1),
4876                        (DimName::Height, 17),
4877                        (DimName::Width, 30),
4878                        (DimName::NumAnchorsXFeatures, 18),
4879                    ],
4880                    normalized: Some(true),
4881                }],
4882                configs::Segmentation {
4883                    decoder: DecoderType::ModelPack,
4884                    quantization: None,
4885                    shape: vec![1, 160, 106, 3],
4886                    dshape: vec![
4887                        (DimName::Batch, 1),
4888                        (DimName::Height, 160),
4889                        (DimName::Width, 106),
4890                        (DimName::NumClasses, 3),
4891                    ],
4892                },
4893            )
4894            .build();
4895
4896        assert!(matches!(
4897            result, Err(DecoderError::InvalidConfig(s)) if s.contains("incompatible with number of classes")));
4898    }
4899
4900    #[test]
4901    fn test_decode_bad_shapes() {
4902        let score_threshold = 0.25;
4903        let iou_threshold = 0.7;
4904        let quant = (0.0040811873, -123);
4905        let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
4906        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
4907        let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
4908        let out_float: Array3<f32> = dequantize_ndarray(out.view(), quant.into());
4909
4910        let decoder = DecoderBuilder::default()
4911            .with_config_yolo_det(
4912                configs::Detection {
4913                    decoder: DecoderType::Ultralytics,
4914                    shape: vec![1, 85, 8400],
4915                    anchors: None,
4916                    quantization: Some(quant.into()),
4917                    dshape: vec![
4918                        (DimName::Batch, 1),
4919                        (DimName::NumFeatures, 85),
4920                        (DimName::NumBoxes, 8400),
4921                    ],
4922                    normalized: Some(true),
4923                },
4924                Some(DecoderVersion::Yolo11),
4925            )
4926            .with_score_threshold(score_threshold)
4927            .with_iou_threshold(iou_threshold)
4928            .build()
4929            .unwrap();
4930
4931        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
4932        let mut output_masks: Vec<_> = Vec::with_capacity(50);
4933        let result =
4934            decoder.decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks);
4935
4936        assert!(matches!(
4937            result, Err(DecoderError::InvalidShape(s)) if s == "Did not find output with shape [1, 85, 8400]"));
4938
4939        let result = decoder.decode_float(
4940            &[out_float.view().into_dyn()],
4941            &mut output_boxes,
4942            &mut output_masks,
4943        );
4944
4945        assert!(matches!(
4946            result, Err(DecoderError::InvalidShape(s)) if s == "Did not find output with shape [1, 85, 8400]"));
4947    }
4948
4949    #[test]
4950    fn test_config_outputs() {
4951        let outputs = [
4952            ConfigOutput::Detection(configs::Detection {
4953                decoder: configs::DecoderType::Ultralytics,
4954                anchors: None,
4955                shape: vec![1, 8400, 85],
4956                quantization: Some(QuantTuple(0.123, 0)),
4957                dshape: vec![
4958                    (DimName::Batch, 1),
4959                    (DimName::NumBoxes, 8400),
4960                    (DimName::NumFeatures, 85),
4961                ],
4962                normalized: Some(true),
4963            }),
4964            ConfigOutput::Mask(configs::Mask {
4965                decoder: configs::DecoderType::Ultralytics,
4966                shape: vec![1, 160, 160, 1],
4967                quantization: Some(QuantTuple(0.223, 0)),
4968                dshape: vec![
4969                    (DimName::Batch, 1),
4970                    (DimName::Height, 160),
4971                    (DimName::Width, 160),
4972                    (DimName::NumFeatures, 1),
4973                ],
4974            }),
4975            ConfigOutput::Segmentation(configs::Segmentation {
4976                decoder: configs::DecoderType::Ultralytics,
4977                shape: vec![1, 160, 160, 80],
4978                quantization: Some(QuantTuple(0.323, 0)),
4979                dshape: vec![
4980                    (DimName::Batch, 1),
4981                    (DimName::Height, 160),
4982                    (DimName::Width, 160),
4983                    (DimName::NumClasses, 80),
4984                ],
4985            }),
4986            ConfigOutput::Scores(configs::Scores {
4987                decoder: configs::DecoderType::Ultralytics,
4988                shape: vec![1, 8400, 80],
4989                quantization: Some(QuantTuple(0.423, 0)),
4990                dshape: vec![
4991                    (DimName::Batch, 1),
4992                    (DimName::NumBoxes, 8400),
4993                    (DimName::NumClasses, 80),
4994                ],
4995            }),
4996            ConfigOutput::Boxes(configs::Boxes {
4997                decoder: configs::DecoderType::Ultralytics,
4998                shape: vec![1, 8400, 4],
4999                quantization: Some(QuantTuple(0.523, 0)),
5000                dshape: vec![
5001                    (DimName::Batch, 1),
5002                    (DimName::NumBoxes, 8400),
5003                    (DimName::BoxCoords, 4),
5004                ],
5005                normalized: Some(true),
5006            }),
5007            ConfigOutput::Protos(configs::Protos {
5008                decoder: configs::DecoderType::Ultralytics,
5009                shape: vec![1, 32, 160, 160],
5010                quantization: Some(QuantTuple(0.623, 0)),
5011                dshape: vec![
5012                    (DimName::Batch, 1),
5013                    (DimName::NumProtos, 32),
5014                    (DimName::Height, 160),
5015                    (DimName::Width, 160),
5016                ],
5017            }),
5018            ConfigOutput::MaskCoefficients(configs::MaskCoefficients {
5019                decoder: configs::DecoderType::Ultralytics,
5020                shape: vec![1, 8400, 32],
5021                quantization: Some(QuantTuple(0.723, 0)),
5022                dshape: vec![
5023                    (DimName::Batch, 1),
5024                    (DimName::NumBoxes, 8400),
5025                    (DimName::NumProtos, 32),
5026                ],
5027            }),
5028        ];
5029
5030        let shapes = outputs.clone().map(|x| x.shape().to_vec());
5031        assert_eq!(
5032            shapes,
5033            [
5034                vec![1, 8400, 85],
5035                vec![1, 160, 160, 1],
5036                vec![1, 160, 160, 80],
5037                vec![1, 8400, 80],
5038                vec![1, 8400, 4],
5039                vec![1, 32, 160, 160],
5040                vec![1, 8400, 32],
5041            ]
5042        );
5043
5044        let quants: [Option<(f32, i32)>; 7] = outputs.map(|x| x.quantization().map(|q| q.into()));
5045        assert_eq!(
5046            quants,
5047            [
5048                Some((0.123, 0)),
5049                Some((0.223, 0)),
5050                Some((0.323, 0)),
5051                Some((0.423, 0)),
5052                Some((0.523, 0)),
5053                Some((0.623, 0)),
5054                Some((0.723, 0)),
5055            ]
5056        );
5057    }
5058
5059    #[test]
5060    fn test_nms_from_config_yaml() {
5061        // Test parsing NMS from YAML config
5062        let yaml_class_agnostic = r#"
5063outputs:
5064  - decoder: ultralytics
5065    type: detection
5066    shape: [1, 84, 8400]
5067    dshape:
5068      - [batch, 1]
5069      - [num_features, 84]
5070      - [num_boxes, 8400]
5071nms: class_agnostic
5072"#;
5073        let decoder = DecoderBuilder::new()
5074            .with_config_yaml_str(yaml_class_agnostic.to_string())
5075            .build()
5076            .unwrap();
5077        assert_eq!(decoder.nms, Some(configs::Nms::ClassAgnostic));
5078
5079        let yaml_class_aware = r#"
5080outputs:
5081  - decoder: ultralytics
5082    type: detection
5083    shape: [1, 84, 8400]
5084    dshape:
5085      - [batch, 1]
5086      - [num_features, 84]
5087      - [num_boxes, 8400]
5088nms: class_aware
5089"#;
5090        let decoder = DecoderBuilder::new()
5091            .with_config_yaml_str(yaml_class_aware.to_string())
5092            .build()
5093            .unwrap();
5094        assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
5095
5096        // Test that config NMS overrides builder NMS
5097        let decoder = DecoderBuilder::new()
5098            .with_config_yaml_str(yaml_class_aware.to_string())
5099            .with_nms(Some(configs::Nms::ClassAgnostic)) // Builder sets agnostic
5100            .build()
5101            .unwrap();
5102        // Config should override builder
5103        assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
5104    }
5105
5106    #[test]
5107    fn test_nms_from_config_json() {
5108        // Test parsing NMS from JSON config
5109        let json_class_aware = r#"{
5110            "outputs": [{
5111                "decoder": "ultralytics",
5112                "type": "detection",
5113                "shape": [1, 84, 8400],
5114                "dshape": [["batch", 1], ["num_features", 84], ["num_boxes", 8400]]
5115            }],
5116            "nms": "class_aware"
5117        }"#;
5118        let decoder = DecoderBuilder::new()
5119            .with_config_json_str(json_class_aware.to_string())
5120            .build()
5121            .unwrap();
5122        assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
5123    }
5124
5125    #[test]
5126    fn test_nms_missing_from_config_uses_builder_default() {
5127        // Test that missing NMS in config uses builder default
5128        let yaml_no_nms = r#"
5129outputs:
5130  - decoder: ultralytics
5131    type: detection
5132    shape: [1, 84, 8400]
5133    dshape:
5134      - [batch, 1]
5135      - [num_features, 84]
5136      - [num_boxes, 8400]
5137"#;
5138        let decoder = DecoderBuilder::new()
5139            .with_config_yaml_str(yaml_no_nms.to_string())
5140            .build()
5141            .unwrap();
5142        // Default builder NMS is ClassAgnostic
5143        assert_eq!(decoder.nms, Some(configs::Nms::ClassAgnostic));
5144
5145        // Test with explicit builder NMS
5146        let decoder = DecoderBuilder::new()
5147            .with_config_yaml_str(yaml_no_nms.to_string())
5148            .with_nms(None) // Explicitly set to None (bypass NMS)
5149            .build()
5150            .unwrap();
5151        assert_eq!(decoder.nms, None);
5152    }
5153
5154    #[test]
5155    fn test_decoder_version_yolo26_end_to_end() {
5156        // Test that decoder_version: yolo26 creates end-to-end model type
5157        let yaml = r#"
5158outputs:
5159  - decoder: ultralytics
5160    type: detection
5161    shape: [1, 6, 8400]
5162    dshape:
5163      - [batch, 1]
5164      - [num_features, 6]
5165      - [num_boxes, 8400]
5166decoder_version: yolo26
5167"#;
5168        let decoder = DecoderBuilder::new()
5169            .with_config_yaml_str(yaml.to_string())
5170            .build()
5171            .unwrap();
5172        assert!(matches!(
5173            decoder.model_type,
5174            ModelType::YoloEndToEndDet { .. }
5175        ));
5176
5177        // Even with NMS set, yolo26 should use end-to-end
5178        let yaml_with_nms = r#"
5179outputs:
5180  - decoder: ultralytics
5181    type: detection
5182    shape: [1, 6, 8400]
5183    dshape:
5184      - [batch, 1]
5185      - [num_features, 6]
5186      - [num_boxes, 8400]
5187decoder_version: yolo26
5188nms: class_agnostic
5189"#;
5190        let decoder = DecoderBuilder::new()
5191            .with_config_yaml_str(yaml_with_nms.to_string())
5192            .build()
5193            .unwrap();
5194        assert!(matches!(
5195            decoder.model_type,
5196            ModelType::YoloEndToEndDet { .. }
5197        ));
5198    }
5199
5200    #[test]
5201    fn test_decoder_version_yolov8_traditional() {
5202        // Test that decoder_version: yolov8 creates traditional model type
5203        let yaml = r#"
5204outputs:
5205  - decoder: ultralytics
5206    type: detection
5207    shape: [1, 84, 8400]
5208    dshape:
5209      - [batch, 1]
5210      - [num_features, 84]
5211      - [num_boxes, 8400]
5212decoder_version: yolov8
5213"#;
5214        let decoder = DecoderBuilder::new()
5215            .with_config_yaml_str(yaml.to_string())
5216            .build()
5217            .unwrap();
5218        assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
5219    }
5220
5221    #[test]
5222    fn test_decoder_version_all_versions() {
5223        // Test all supported decoder versions parse correctly
5224        for version in ["yolov5", "yolov8", "yolo11"] {
5225            let yaml = format!(
5226                r#"
5227outputs:
5228  - decoder: ultralytics
5229    type: detection
5230    shape: [1, 84, 8400]
5231    dshape:
5232      - [batch, 1]
5233      - [num_features, 84]
5234      - [num_boxes, 8400]
5235decoder_version: {}
5236"#,
5237                version
5238            );
5239            let decoder = DecoderBuilder::new()
5240                .with_config_yaml_str(yaml)
5241                .build()
5242                .unwrap();
5243
5244            assert!(
5245                matches!(decoder.model_type, ModelType::YoloDet { .. }),
5246                "Expected traditional for {}",
5247                version
5248            );
5249        }
5250
5251        let yaml = r#"
5252outputs:
5253  - decoder: ultralytics
5254    type: detection
5255    shape: [1, 6, 8400]
5256    dshape:
5257      - [batch, 1]
5258      - [num_features, 6]
5259      - [num_boxes, 8400]
5260decoder_version: yolo26
5261"#
5262        .to_string();
5263
5264        let decoder = DecoderBuilder::new()
5265            .with_config_yaml_str(yaml)
5266            .build()
5267            .unwrap();
5268
5269        assert!(
5270            matches!(decoder.model_type, ModelType::YoloEndToEndDet { .. }),
5271            "Expected end to end for yolo26",
5272        );
5273    }
5274
5275    #[test]
5276    fn test_decoder_version_json() {
5277        // Test parsing decoder_version from JSON config
5278        let json = r#"{
5279            "outputs": [{
5280                "decoder": "ultralytics",
5281                "type": "detection",
5282                "shape": [1, 6, 8400],
5283                "dshape": [["batch", 1], ["num_features", 6], ["num_boxes", 8400]]
5284            }],
5285            "decoder_version": "yolo26"
5286        }"#;
5287        let decoder = DecoderBuilder::new()
5288            .with_config_json_str(json.to_string())
5289            .build()
5290            .unwrap();
5291        assert!(matches!(
5292            decoder.model_type,
5293            ModelType::YoloEndToEndDet { .. }
5294        ));
5295    }
5296
5297    #[test]
5298    fn test_decoder_version_none_uses_traditional() {
5299        // Without decoder_version, traditional model type is used
5300        let yaml = r#"
5301outputs:
5302  - decoder: ultralytics
5303    type: detection
5304    shape: [1, 84, 8400]
5305    dshape:
5306      - [batch, 1]
5307      - [num_features, 84]
5308      - [num_boxes, 8400]
5309"#;
5310        let decoder = DecoderBuilder::new()
5311            .with_config_yaml_str(yaml.to_string())
5312            .build()
5313            .unwrap();
5314        assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
5315    }
5316
5317    #[test]
5318    fn test_decoder_version_none_with_nms_none_still_traditional() {
5319        // Without decoder_version, nms: None now means user handles NMS, not end-to-end
5320        // This is a behavior change from the previous implementation
5321        let yaml = r#"
5322outputs:
5323  - decoder: ultralytics
5324    type: detection
5325    shape: [1, 84, 8400]
5326    dshape:
5327      - [batch, 1]
5328      - [num_features, 84]
5329      - [num_boxes, 8400]
5330"#;
5331        let decoder = DecoderBuilder::new()
5332            .with_config_yaml_str(yaml.to_string())
5333            .with_nms(None) // User wants to handle NMS themselves
5334            .build()
5335            .unwrap();
5336        // nms=None with 84 features (80 classes) -> traditional YoloDet (user handles
5337        // NMS)
5338        assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
5339    }
5340
5341    #[test]
5342    fn test_decoder_heuristic_end_to_end_detection() {
5343        // models with (batch, num_boxes, num_features) output shape are treated
5344        // as end-to-end detection
5345        let yaml = r#"
5346outputs:
5347  - decoder: ultralytics
5348    type: detection
5349    shape: [1, 300, 6]
5350    dshape:
5351      - [batch, 1]
5352      - [num_boxes, 300]
5353      - [num_features, 6]
5354 
5355"#;
5356        let decoder = DecoderBuilder::new()
5357            .with_config_yaml_str(yaml.to_string())
5358            .build()
5359            .unwrap();
5360        // 6 features with (batch, N, features) layout -> end-to-end detection
5361        assert!(matches!(
5362            decoder.model_type,
5363            ModelType::YoloEndToEndDet { .. }
5364        ));
5365
5366        let yaml = r#"
5367outputs:
5368  - decoder: ultralytics
5369    type: detection
5370    shape: [1, 300, 38]
5371    dshape:
5372      - [batch, 1]
5373      - [num_boxes, 300]
5374      - [num_features, 38]
5375  - decoder: ultralytics
5376    type: protos
5377    shape: [1, 160, 160, 32]
5378    dshape:
5379      - [batch, 1]
5380      - [height, 160]
5381      - [width, 160]
5382      - [num_protos, 32]
5383"#;
5384        let decoder = DecoderBuilder::new()
5385            .with_config_yaml_str(yaml.to_string())
5386            .build()
5387            .unwrap();
5388        // 7 features with protos -> end-to-end segmentation detection
5389        assert!(matches!(
5390            decoder.model_type,
5391            ModelType::YoloEndToEndSegDet { .. }
5392        ));
5393
5394        let yaml = r#"
5395outputs:
5396  - decoder: ultralytics
5397    type: detection
5398    shape: [1, 6, 300]
5399    dshape:
5400      - [batch, 1]
5401      - [num_features, 6]
5402      - [num_boxes, 300] 
5403"#;
5404        let decoder = DecoderBuilder::new()
5405            .with_config_yaml_str(yaml.to_string())
5406            .build()
5407            .unwrap();
5408        // 6 features -> traditional YOLO detection (needs num_classes > 0 for
5409        // end-to-end)
5410        assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
5411
5412        let yaml = r#"
5413outputs:
5414  - decoder: ultralytics
5415    type: detection
5416    shape: [1, 38, 300]
5417    dshape:
5418      - [batch, 1]
5419      - [num_features, 38]
5420      - [num_boxes, 300]
5421
5422  - decoder: ultralytics
5423    type: protos
5424    shape: [1, 160, 160, 32]
5425    dshape:
5426      - [batch, 1]
5427      - [height, 160]
5428      - [width, 160]
5429      - [num_protos, 32]
5430"#;
5431        let decoder = DecoderBuilder::new()
5432            .with_config_yaml_str(yaml.to_string())
5433            .build()
5434            .unwrap();
5435        // 38 features (4+2+32) with protos -> traditional YOLO segmentation detection
5436        assert!(matches!(decoder.model_type, ModelType::YoloSegDet { .. }));
5437    }
5438
5439    #[test]
5440    fn test_decoder_version_is_end_to_end() {
5441        assert!(!configs::DecoderVersion::Yolov5.is_end_to_end());
5442        assert!(!configs::DecoderVersion::Yolov8.is_end_to_end());
5443        assert!(!configs::DecoderVersion::Yolo11.is_end_to_end());
5444        assert!(configs::DecoderVersion::Yolo26.is_end_to_end());
5445    }
5446}