Skip to main content

edgefirst_decoder/
decoder.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashSet;
5
6use ndarray::{s, Array3, ArrayView, ArrayViewD, Dimension};
7use ndarray_stats::QuantileExt;
8use num_traits::{AsPrimitive, Float};
9use serde::{Deserialize, Serialize};
10
11use crate::{
12    configs::{DecoderType, DimName, ModelType, QuantTuple},
13    dequantize_ndarray,
14    modelpack::{
15        decode_modelpack_det, decode_modelpack_float, decode_modelpack_split_float,
16        ModelPackDetectionConfig,
17    },
18    yolo::{
19        decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float, decode_yolo_segdet_quant,
20        decode_yolo_split_det_float, decode_yolo_split_det_quant, decode_yolo_split_segdet_float,
21        impl_yolo_split_segdet_quant_get_boxes, impl_yolo_split_segdet_quant_process_masks,
22    },
23    DecoderError, DecoderVersion, DetectBox, ProtoData, Quantization, Segmentation, XYWH,
24};
25
26/// Used to represent the outputs in the model configuration.
27/// # Examples
28/// ```rust
29/// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutputs};
30/// # fn main() -> DecoderResult<()> {
31/// let config_json = include_str!("../../../testdata/modelpack_split.json");
32/// let config: ConfigOutputs = serde_json::from_str(config_json)?;
33/// let decoder = DecoderBuilder::new().with_config(config).build()?;
34///
35/// # Ok(())
36/// # }
37#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
38pub struct ConfigOutputs {
39    #[serde(default)]
40    pub outputs: Vec<ConfigOutput>,
41    /// NMS mode from config file. When present, overrides the builder's NMS
42    /// setting.
43    /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS: suppress overlapping
44    ///   boxes regardless of class
45    /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes with
46    ///   the same class
47    /// - `None` — use builder default or skip NMS (user handles it externally)
48    #[serde(default, skip_serializing_if = "Option::is_none")]
49    pub nms: Option<configs::Nms>,
50    /// Decoder version for Ultralytics models. Determines the decoding
51    /// strategy.
52    /// - `Some(Yolo26)` — end-to-end model with embedded NMS
53    /// - `Some(Yolov5/Yolov8/Yolo11)` — traditional models requiring external
54    ///   NMS
55    /// - `None` — infer from other settings (legacy behavior)
56    #[serde(default, skip_serializing_if = "Option::is_none")]
57    pub decoder_version: Option<configs::DecoderVersion>,
58}
59
60#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
61#[serde(tag = "type")]
62pub enum ConfigOutput {
63    #[serde(rename = "detection")]
64    Detection(configs::Detection),
65    #[serde(rename = "masks")]
66    Mask(configs::Mask),
67    #[serde(rename = "segmentation")]
68    Segmentation(configs::Segmentation),
69    #[serde(rename = "protos")]
70    Protos(configs::Protos),
71    #[serde(rename = "scores")]
72    Scores(configs::Scores),
73    #[serde(rename = "boxes")]
74    Boxes(configs::Boxes),
75    #[serde(rename = "mask_coefficients")]
76    MaskCoefficients(configs::MaskCoefficients),
77    #[serde(rename = "classes")]
78    Classes(configs::Classes),
79}
80
81#[derive(Debug, PartialEq, Clone)]
82pub enum ConfigOutputRef<'a> {
83    Detection(&'a configs::Detection),
84    Mask(&'a configs::Mask),
85    Segmentation(&'a configs::Segmentation),
86    Protos(&'a configs::Protos),
87    Scores(&'a configs::Scores),
88    Boxes(&'a configs::Boxes),
89    MaskCoefficients(&'a configs::MaskCoefficients),
90    Classes(&'a configs::Classes),
91}
92
93impl<'a> ConfigOutputRef<'a> {
94    fn decoder(&self) -> configs::DecoderType {
95        match self {
96            ConfigOutputRef::Detection(v) => v.decoder,
97            ConfigOutputRef::Mask(v) => v.decoder,
98            ConfigOutputRef::Segmentation(v) => v.decoder,
99            ConfigOutputRef::Protos(v) => v.decoder,
100            ConfigOutputRef::Scores(v) => v.decoder,
101            ConfigOutputRef::Boxes(v) => v.decoder,
102            ConfigOutputRef::MaskCoefficients(v) => v.decoder,
103            ConfigOutputRef::Classes(v) => v.decoder,
104        }
105    }
106
107    fn dshape(&self) -> &[(DimName, usize)] {
108        match self {
109            ConfigOutputRef::Detection(v) => &v.dshape,
110            ConfigOutputRef::Mask(v) => &v.dshape,
111            ConfigOutputRef::Segmentation(v) => &v.dshape,
112            ConfigOutputRef::Protos(v) => &v.dshape,
113            ConfigOutputRef::Scores(v) => &v.dshape,
114            ConfigOutputRef::Boxes(v) => &v.dshape,
115            ConfigOutputRef::MaskCoefficients(v) => &v.dshape,
116            ConfigOutputRef::Classes(v) => &v.dshape,
117        }
118    }
119}
120
121impl<'a> From<&'a configs::Detection> for ConfigOutputRef<'a> {
122    /// Converts from references of config structs to ConfigOutputRef
123    /// # Examples
124    /// ```rust
125    /// # use edgefirst_decoder::{configs, ConfigOutputRef};
126    /// let detection_config = configs::Detection {
127    ///     anchors: None,
128    ///     decoder: configs::DecoderType::Ultralytics,
129    ///     quantization: None,
130    ///     shape: vec![1, 84, 8400],
131    ///     dshape: Vec::new(),
132    ///     normalized: Some(true),
133    /// };
134    /// let output: ConfigOutputRef = (&detection_config).into();
135    /// ```
136    fn from(v: &'a configs::Detection) -> ConfigOutputRef<'a> {
137        ConfigOutputRef::Detection(v)
138    }
139}
140
141impl<'a> From<&'a configs::Mask> for ConfigOutputRef<'a> {
142    /// Converts from references of config structs to ConfigOutputRef
143    /// # Examples
144    /// ```rust
145    /// # use edgefirst_decoder::{configs, ConfigOutputRef};
146    /// let mask = configs::Mask {
147    ///     decoder: configs::DecoderType::ModelPack,
148    ///     quantization: None,
149    ///     shape: vec![1, 160, 160, 1],
150    ///     dshape: Vec::new(),
151    /// };
152    /// let output: ConfigOutputRef = (&mask).into();
153    /// ```
154    fn from(v: &'a configs::Mask) -> ConfigOutputRef<'a> {
155        ConfigOutputRef::Mask(v)
156    }
157}
158
159impl<'a> From<&'a configs::Segmentation> for ConfigOutputRef<'a> {
160    /// Converts from references of config structs to ConfigOutputRef
161    /// # Examples
162    /// ```rust
163    /// # use edgefirst_decoder::{configs, ConfigOutputRef};
164    /// let seg = configs::Segmentation {
165    ///     decoder: configs::DecoderType::ModelPack,
166    ///     quantization: None,
167    ///     shape: vec![1, 160, 160, 3],
168    ///     dshape: Vec::new(),
169    /// };
170    /// let output: ConfigOutputRef = (&seg).into();
171    /// ```
172    fn from(v: &'a configs::Segmentation) -> ConfigOutputRef<'a> {
173        ConfigOutputRef::Segmentation(v)
174    }
175}
176
177impl<'a> From<&'a configs::Protos> for ConfigOutputRef<'a> {
178    /// Converts from references of config structs to ConfigOutputRef
179    /// # Examples
180    /// ```rust
181    /// # use edgefirst_decoder::{configs, ConfigOutputRef};
182    /// let protos = configs::Protos {
183    ///     decoder: configs::DecoderType::Ultralytics,
184    ///     quantization: None,
185    ///     shape: vec![1, 160, 160, 32],
186    ///     dshape: Vec::new(),
187    /// };
188    /// let output: ConfigOutputRef = (&protos).into();
189    /// ```
190    fn from(v: &'a configs::Protos) -> ConfigOutputRef<'a> {
191        ConfigOutputRef::Protos(v)
192    }
193}
194
195impl<'a> From<&'a configs::Scores> for ConfigOutputRef<'a> {
196    /// Converts from references of config structs to ConfigOutputRef
197    /// # Examples
198    /// ```rust
199    /// # use edgefirst_decoder::{configs, ConfigOutputRef};
200    /// let scores = configs::Scores {
201    ///     decoder: configs::DecoderType::Ultralytics,
202    ///     quantization: None,
203    ///     shape: vec![1, 40, 8400],
204    ///     dshape: Vec::new(),
205    /// };
206    /// let output: ConfigOutputRef = (&scores).into();
207    /// ```
208    fn from(v: &'a configs::Scores) -> ConfigOutputRef<'a> {
209        ConfigOutputRef::Scores(v)
210    }
211}
212
213impl<'a> From<&'a configs::Boxes> for ConfigOutputRef<'a> {
214    /// Converts from references of config structs to ConfigOutputRef
215    /// # Examples
216    /// ```rust
217    /// # use edgefirst_decoder::{configs, ConfigOutputRef};
218    /// let boxes = configs::Boxes {
219    ///     decoder: configs::DecoderType::Ultralytics,
220    ///     quantization: None,
221    ///     shape: vec![1, 4, 8400],
222    ///     dshape: Vec::new(),
223    ///     normalized: Some(true),
224    /// };
225    /// let output: ConfigOutputRef = (&boxes).into();
226    /// ```
227    fn from(v: &'a configs::Boxes) -> ConfigOutputRef<'a> {
228        ConfigOutputRef::Boxes(v)
229    }
230}
231
232impl<'a> From<&'a configs::MaskCoefficients> for ConfigOutputRef<'a> {
233    /// Converts from references of config structs to ConfigOutputRef
234    /// # Examples
235    /// ```rust
236    /// # use edgefirst_decoder::{configs, ConfigOutputRef};
237    /// let mask_coefficients = configs::MaskCoefficients {
238    ///     decoder: configs::DecoderType::Ultralytics,
239    ///     quantization: None,
240    ///     shape: vec![1, 32, 8400],
241    ///     dshape: Vec::new(),
242    /// };
243    /// let output: ConfigOutputRef = (&mask_coefficients).into();
244    /// ```
245    fn from(v: &'a configs::MaskCoefficients) -> ConfigOutputRef<'a> {
246        ConfigOutputRef::MaskCoefficients(v)
247    }
248}
249
250impl<'a> From<&'a configs::Classes> for ConfigOutputRef<'a> {
251    fn from(v: &'a configs::Classes) -> ConfigOutputRef<'a> {
252        ConfigOutputRef::Classes(v)
253    }
254}
255
256impl ConfigOutput {
257    /// Returns the shape of the output.
258    ///
259    /// # Examples
260    /// ```rust
261    /// # use edgefirst_decoder::{configs, ConfigOutput};
262    /// let detection_config = configs::Detection {
263    ///     anchors: None,
264    ///     decoder: configs::DecoderType::Ultralytics,
265    ///     quantization: None,
266    ///     shape: vec![1, 84, 8400],
267    ///     dshape: Vec::new(),
268    ///     normalized: Some(true),
269    /// };
270    /// let output = ConfigOutput::Detection(detection_config);
271    /// assert_eq!(output.shape(), &[1, 84, 8400]);
272    /// ```
273    pub fn shape(&self) -> &[usize] {
274        match self {
275            ConfigOutput::Detection(detection) => &detection.shape,
276            ConfigOutput::Mask(mask) => &mask.shape,
277            ConfigOutput::Segmentation(segmentation) => &segmentation.shape,
278            ConfigOutput::Scores(scores) => &scores.shape,
279            ConfigOutput::Boxes(boxes) => &boxes.shape,
280            ConfigOutput::Protos(protos) => &protos.shape,
281            ConfigOutput::MaskCoefficients(mask_coefficients) => &mask_coefficients.shape,
282            ConfigOutput::Classes(classes) => &classes.shape,
283        }
284    }
285
286    /// Returns the decoder type of the output.
287    ///    
288    /// # Examples
289    /// ```rust
290    /// # use edgefirst_decoder::{configs, ConfigOutput};
291    /// let detection_config = configs::Detection {
292    ///     anchors: None,
293    ///     decoder: configs::DecoderType::Ultralytics,
294    ///     quantization: None,
295    ///     shape: vec![1, 84, 8400],
296    ///     dshape: Vec::new(),
297    ///     normalized: Some(true),
298    /// };
299    /// let output = ConfigOutput::Detection(detection_config);
300    /// assert_eq!(output.decoder(), &configs::DecoderType::Ultralytics);
301    /// ```
302    pub fn decoder(&self) -> &configs::DecoderType {
303        match self {
304            ConfigOutput::Detection(detection) => &detection.decoder,
305            ConfigOutput::Mask(mask) => &mask.decoder,
306            ConfigOutput::Segmentation(segmentation) => &segmentation.decoder,
307            ConfigOutput::Scores(scores) => &scores.decoder,
308            ConfigOutput::Boxes(boxes) => &boxes.decoder,
309            ConfigOutput::Protos(protos) => &protos.decoder,
310            ConfigOutput::MaskCoefficients(mask_coefficients) => &mask_coefficients.decoder,
311            ConfigOutput::Classes(classes) => &classes.decoder,
312        }
313    }
314
315    /// Returns the quantization of the output.
316    ///
317    /// # Examples
318    /// ```rust
319    /// # use edgefirst_decoder::{configs, ConfigOutput};
320    /// let detection_config = configs::Detection {
321    ///   anchors: None,
322    ///   decoder: configs::DecoderType::Ultralytics,
323    ///   quantization: Some(configs::QuantTuple(0.012345, 26)),
324    ///   shape: vec![1, 84, 8400],
325    ///   dshape: Vec::new(),
326    ///   normalized: Some(true),
327    /// };
328    /// let output = ConfigOutput::Detection(detection_config);
329    /// assert_eq!(output.quantization(),
330    /// Some(configs::QuantTuple(0.012345,26))); ```
331    pub fn quantization(&self) -> Option<QuantTuple> {
332        match self {
333            ConfigOutput::Detection(detection) => detection.quantization,
334            ConfigOutput::Mask(mask) => mask.quantization,
335            ConfigOutput::Segmentation(segmentation) => segmentation.quantization,
336            ConfigOutput::Scores(scores) => scores.quantization,
337            ConfigOutput::Boxes(boxes) => boxes.quantization,
338            ConfigOutput::Protos(protos) => protos.quantization,
339            ConfigOutput::MaskCoefficients(mask_coefficients) => mask_coefficients.quantization,
340            ConfigOutput::Classes(classes) => classes.quantization,
341        }
342    }
343}
344
345pub mod configs {
346    use std::collections::HashMap;
347    use std::fmt::Display;
348
349    use serde::{Deserialize, Serialize};
350
351    /// Deserialize dshape from either array-of-tuples or array-of-single-key-dicts.
352    ///
353    /// The metadata spec produces `[{"batch": 1}, {"num_features": 84}]` (dict format),
354    /// while serde's default `Vec<(A, B)>` expects `[["batch", 1]]` (tuple format).
355    /// This deserializer accepts both.
356    pub fn deserialize_dshape<'de, D>(deserializer: D) -> Result<Vec<(DimName, usize)>, D::Error>
357    where
358        D: serde::Deserializer<'de>,
359    {
360        #[derive(Deserialize)]
361        #[serde(untagged)]
362        enum DShapeItem {
363            Tuple(DimName, usize),
364            Map(HashMap<DimName, usize>),
365        }
366
367        let items: Vec<DShapeItem> = Vec::deserialize(deserializer)?;
368        items
369            .into_iter()
370            .map(|item| match item {
371                DShapeItem::Tuple(name, size) => Ok((name, size)),
372                DShapeItem::Map(map) => {
373                    if map.len() != 1 {
374                        return Err(serde::de::Error::custom(
375                            "dshape map entry must have exactly one key",
376                        ));
377                    }
378                    let (name, size) = map.into_iter().next().unwrap();
379                    Ok((name, size))
380                }
381            })
382            .collect()
383    }
384
385    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy)]
386    pub struct QuantTuple(pub f32, pub i32);
387    impl From<QuantTuple> for (f32, i32) {
388        fn from(value: QuantTuple) -> Self {
389            (value.0, value.1)
390        }
391    }
392
393    impl From<(f32, i32)> for QuantTuple {
394        fn from(value: (f32, i32)) -> Self {
395            QuantTuple(value.0, value.1)
396        }
397    }
398
399    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
400    pub struct Segmentation {
401        #[serde(default)]
402        pub decoder: DecoderType,
403        #[serde(default)]
404        pub quantization: Option<QuantTuple>,
405        #[serde(default)]
406        pub shape: Vec<usize>,
407        #[serde(default, deserialize_with = "deserialize_dshape")]
408        pub dshape: Vec<(DimName, usize)>,
409    }
410
411    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
412    pub struct Protos {
413        #[serde(default)]
414        pub decoder: DecoderType,
415        #[serde(default)]
416        pub quantization: Option<QuantTuple>,
417        #[serde(default)]
418        pub shape: Vec<usize>,
419        #[serde(default, deserialize_with = "deserialize_dshape")]
420        pub dshape: Vec<(DimName, usize)>,
421    }
422
423    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
424    pub struct MaskCoefficients {
425        #[serde(default)]
426        pub decoder: DecoderType,
427        #[serde(default)]
428        pub quantization: Option<QuantTuple>,
429        #[serde(default)]
430        pub shape: Vec<usize>,
431        #[serde(default, deserialize_with = "deserialize_dshape")]
432        pub dshape: Vec<(DimName, usize)>,
433    }
434
435    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
436    pub struct Mask {
437        #[serde(default)]
438        pub decoder: DecoderType,
439        #[serde(default)]
440        pub quantization: Option<QuantTuple>,
441        #[serde(default)]
442        pub shape: Vec<usize>,
443        #[serde(default, deserialize_with = "deserialize_dshape")]
444        pub dshape: Vec<(DimName, usize)>,
445    }
446
447    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
448    pub struct Detection {
449        #[serde(default)]
450        pub anchors: Option<Vec<[f32; 2]>>,
451        #[serde(default)]
452        pub decoder: DecoderType,
453        #[serde(default)]
454        pub quantization: Option<QuantTuple>,
455        #[serde(default)]
456        pub shape: Vec<usize>,
457        #[serde(default, deserialize_with = "deserialize_dshape")]
458        pub dshape: Vec<(DimName, usize)>,
459        /// Whether box coordinates are normalized to [0,1] range.
460        /// - `Some(true)`: Coordinates in [0,1] range relative to model input
461        /// - `Some(false)`: Pixel coordinates relative to model input
462        ///   (letterboxed)
463        /// - `None`: Unknown, caller must infer (e.g., check if any coordinate
464        ///   > 1.0)
465        #[serde(default)]
466        pub normalized: Option<bool>,
467    }
468
469    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
470    pub struct Scores {
471        #[serde(default)]
472        pub decoder: DecoderType,
473        #[serde(default)]
474        pub quantization: Option<QuantTuple>,
475        #[serde(default)]
476        pub shape: Vec<usize>,
477        #[serde(default, deserialize_with = "deserialize_dshape")]
478        pub dshape: Vec<(DimName, usize)>,
479    }
480
481    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
482    pub struct Boxes {
483        #[serde(default)]
484        pub decoder: DecoderType,
485        #[serde(default)]
486        pub quantization: Option<QuantTuple>,
487        #[serde(default)]
488        pub shape: Vec<usize>,
489        #[serde(default, deserialize_with = "deserialize_dshape")]
490        pub dshape: Vec<(DimName, usize)>,
491        /// Whether box coordinates are normalized to [0,1] range.
492        /// - `Some(true)`: Coordinates in [0,1] range relative to model input
493        /// - `Some(false)`: Pixel coordinates relative to model input
494        ///   (letterboxed)
495        /// - `None`: Unknown, caller must infer (e.g., check if any coordinate
496        ///   > 1.0)
497        #[serde(default)]
498        pub normalized: Option<bool>,
499    }
500
501    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
502    pub struct Classes {
503        #[serde(default)]
504        pub decoder: DecoderType,
505        #[serde(default)]
506        pub quantization: Option<QuantTuple>,
507        #[serde(default)]
508        pub shape: Vec<usize>,
509        #[serde(default, deserialize_with = "deserialize_dshape")]
510        pub dshape: Vec<(DimName, usize)>,
511    }
512
513    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
514    pub enum DimName {
515        #[serde(rename = "batch")]
516        Batch,
517        #[serde(rename = "height")]
518        Height,
519        #[serde(rename = "width")]
520        Width,
521        #[serde(rename = "num_classes")]
522        NumClasses,
523        #[serde(rename = "num_features")]
524        NumFeatures,
525        #[serde(rename = "num_boxes")]
526        NumBoxes,
527        #[serde(rename = "num_protos")]
528        NumProtos,
529        #[serde(rename = "num_anchors_x_features")]
530        NumAnchorsXFeatures,
531        #[serde(rename = "padding")]
532        Padding,
533        #[serde(rename = "box_coords")]
534        BoxCoords,
535    }
536
537    impl Display for DimName {
538        /// Formats the DimName for display
539        /// # Examples
540        /// ```rust
541        /// # use edgefirst_decoder::configs::DimName;
542        /// let dim = DimName::Height;
543        /// assert_eq!(format!("{}", dim), "height");
544        /// # let s = format!("{} {} {} {} {} {} {} {} {} {}", DimName::Batch, DimName::Height, DimName::Width, DimName::NumClasses, DimName::NumFeatures, DimName::NumBoxes, DimName::NumProtos, DimName::NumAnchorsXFeatures, DimName::Padding, DimName::BoxCoords);
545        /// # assert_eq!(s, "batch height width num_classes num_features num_boxes num_protos num_anchors_x_features padding box_coords");
546        /// ```
547        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
548            match self {
549                DimName::Batch => write!(f, "batch"),
550                DimName::Height => write!(f, "height"),
551                DimName::Width => write!(f, "width"),
552                DimName::NumClasses => write!(f, "num_classes"),
553                DimName::NumFeatures => write!(f, "num_features"),
554                DimName::NumBoxes => write!(f, "num_boxes"),
555                DimName::NumProtos => write!(f, "num_protos"),
556                DimName::NumAnchorsXFeatures => write!(f, "num_anchors_x_features"),
557                DimName::Padding => write!(f, "padding"),
558                DimName::BoxCoords => write!(f, "box_coords"),
559            }
560        }
561    }
562
563    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq, Default)]
564    pub enum DecoderType {
565        #[serde(rename = "modelpack")]
566        ModelPack,
567        #[default]
568        #[serde(rename = "ultralytics", alias = "yolov8")]
569        Ultralytics,
570    }
571
572    /// Decoder version for Ultralytics models.
573    ///
574    /// Specifies the YOLO architecture version, which determines the decoding
575    /// strategy:
576    /// - `Yolov5`, `Yolov8`, `Yolo11`: Traditional models requiring external
577    ///   NMS
578    /// - `Yolo26`: End-to-end models with NMS embedded in the model
579    ///   architecture
580    ///
581    /// When `decoder_version` is set to `Yolo26`, the decoder uses end-to-end
582    /// model types regardless of the `nms` setting.
583    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
584    #[serde(rename_all = "lowercase")]
585    pub enum DecoderVersion {
586        /// YOLOv5 - anchor-free DFL decoder, requires external NMS
587        #[serde(rename = "yolov5")]
588        Yolov5,
589        /// YOLOv8 - anchor-free DFL decoder, requires external NMS
590        #[serde(rename = "yolov8")]
591        Yolov8,
592        /// YOLO11 - anchor-free DFL decoder, requires external NMS
593        #[serde(rename = "yolo11")]
594        Yolo11,
595        /// YOLO26 - end-to-end model with embedded NMS (one-to-one matching
596        /// heads)
597        #[serde(rename = "yolo26")]
598        Yolo26,
599    }
600
601    impl DecoderVersion {
602        /// Returns true if this version uses end-to-end inference (embedded
603        /// NMS).
604        pub fn is_end_to_end(&self) -> bool {
605            matches!(self, DecoderVersion::Yolo26)
606        }
607    }
608
609    /// NMS (Non-Maximum Suppression) mode for filtering overlapping detections.
610    ///
611    /// This enum is used with `Option<Nms>`:
612    /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS (default): suppress
613    ///   overlapping boxes regardless of class label
614    /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
615    ///   share the same class label AND overlap above the IoU threshold
616    /// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
617    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq, Default)]
618    #[serde(rename_all = "snake_case")]
619    pub enum Nms {
620        /// Suppress overlapping boxes regardless of class label (default HAL
621        /// behavior)
622        #[default]
623        ClassAgnostic,
624        /// Only suppress boxes with the same class label that overlap
625        ClassAware,
626    }
627
628    #[derive(Debug, Clone, PartialEq)]
629    pub enum ModelType {
630        ModelPackSegDet {
631            boxes: Boxes,
632            scores: Scores,
633            segmentation: Segmentation,
634        },
635        ModelPackSegDetSplit {
636            detection: Vec<Detection>,
637            segmentation: Segmentation,
638        },
639        ModelPackDet {
640            boxes: Boxes,
641            scores: Scores,
642        },
643        ModelPackDetSplit {
644            detection: Vec<Detection>,
645        },
646        ModelPackSeg {
647            segmentation: Segmentation,
648        },
649        YoloDet {
650            boxes: Detection,
651        },
652        YoloSegDet {
653            boxes: Detection,
654            protos: Protos,
655        },
656        YoloSplitDet {
657            boxes: Boxes,
658            scores: Scores,
659        },
660        YoloSplitSegDet {
661            boxes: Boxes,
662            scores: Scores,
663            mask_coeff: MaskCoefficients,
664            protos: Protos,
665        },
666        /// End-to-end YOLO detection (post-NMS output from model)
667        /// Input shape: (1, N, 6+) where columns are [x1, y1, x2, y2, conf,
668        /// class, ...]
669        YoloEndToEndDet {
670            boxes: Detection,
671        },
672        /// End-to-end YOLO detection + segmentation (post-NMS output from
673        /// model) Input shape: (1, N, 6 + num_protos) where columns are
674        /// [x1, y1, x2, y2, conf, class, mask_coeff_0, ..., mask_coeff_31]
675        YoloEndToEndSegDet {
676            boxes: Detection,
677            protos: Protos,
678        },
679        /// Split end-to-end YOLO detection (onnx2tf splits [1,N,6] into 3
680        /// tensors) boxes: [batch, N, 4] xyxy, scores: [batch, N, 1],
681        /// classes: [batch, N, 1]
682        YoloSplitEndToEndDet {
683            boxes: Boxes,
684            scores: Scores,
685            classes: Classes,
686        },
687        /// Split end-to-end YOLO seg detection (onnx2tf splits into 5
688        /// tensors)
689        YoloSplitEndToEndSegDet {
690            boxes: Boxes,
691            scores: Scores,
692            classes: Classes,
693            mask_coeff: MaskCoefficients,
694            protos: Protos,
695        },
696    }
697
698    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
699    #[serde(rename_all = "lowercase")]
700    pub enum DataType {
701        Raw = 0,
702        Int8 = 1,
703        UInt8 = 2,
704        Int16 = 3,
705        UInt16 = 4,
706        Float16 = 5,
707        Int32 = 6,
708        UInt32 = 7,
709        Float32 = 8,
710        Int64 = 9,
711        UInt64 = 10,
712        Float64 = 11,
713        String = 12,
714    }
715}
716
717#[derive(Debug, Clone, PartialEq)]
718pub struct DecoderBuilder {
719    config_src: Option<ConfigSource>,
720    iou_threshold: f32,
721    score_threshold: f32,
722    /// NMS mode: Some(mode) applies NMS, None bypasses NMS (for end-to-end
723    /// models)
724    nms: Option<configs::Nms>,
725}
726
727#[derive(Debug, Clone, PartialEq)]
728enum ConfigSource {
729    Yaml(String),
730    Json(String),
731    Config(ConfigOutputs),
732}
733
734impl Default for DecoderBuilder {
735    /// Creates a default DecoderBuilder with no configuration and 0.5 score
736    /// threshold and 0.5 OU threshold.
737    ///
738    /// A valid confguration must be provided before building the Decoder.
739    ///
740    /// # Examples
741    /// ```rust
742    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
743    /// # fn main() -> DecoderResult<()> {
744    /// #  let config_yaml = include_str!("../../../testdata/modelpack_split.yaml").to_string();
745    /// let decoder = DecoderBuilder::default()
746    ///     .with_config_yaml_str(config_yaml)
747    ///     .build()?;
748    /// assert_eq!(decoder.score_threshold, 0.5);
749    /// assert_eq!(decoder.iou_threshold, 0.5);
750    ///
751    /// # Ok(())
752    /// # }
753    /// ```
754    fn default() -> Self {
755        Self {
756            config_src: None,
757            iou_threshold: 0.5,
758            score_threshold: 0.5,
759            nms: Some(configs::Nms::ClassAgnostic),
760        }
761    }
762}
763
764impl DecoderBuilder {
765    /// Creates a default DecoderBuilder with no configuration and 0.5 score
766    /// threshold and 0.5 OU threshold.
767    ///
768    /// A valid confguration must be provided before building the Decoder.
769    ///
770    /// # Examples
771    /// ```rust
772    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
773    /// # fn main() -> DecoderResult<()> {
774    /// #  let config_yaml = include_str!("../../../testdata/modelpack_split.yaml").to_string();
775    /// let decoder = DecoderBuilder::new()
776    ///     .with_config_yaml_str(config_yaml)
777    ///     .build()?;
778    /// assert_eq!(decoder.score_threshold, 0.5);
779    /// assert_eq!(decoder.iou_threshold, 0.5);
780    ///
781    /// # Ok(())
782    /// # }
783    /// ```
784    pub fn new() -> Self {
785        Self::default()
786    }
787
788    /// Loads a model configuration in YAML format. Does not check if the string
789    /// is a correct configuration file. Use `DecoderBuilder.build()` to
790    /// deserialize the YAML and parse the model configuration.
791    ///
792    /// # Examples
793    /// ```rust
794    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
795    /// # fn main() -> DecoderResult<()> {
796    /// let config_yaml = include_str!("../../../testdata/modelpack_split.yaml").to_string();
797    /// let decoder = DecoderBuilder::new()
798    ///     .with_config_yaml_str(config_yaml)
799    ///     .build()?;
800    ///
801    /// # Ok(())
802    /// # }
803    /// ```
804    pub fn with_config_yaml_str(mut self, yaml_str: String) -> Self {
805        self.config_src.replace(ConfigSource::Yaml(yaml_str));
806        self
807    }
808
809    /// Loads a model configuration in JSON format. Does not check if the string
810    /// is a correct configuration file. Use `DecoderBuilder.build()` to
811    /// deserialize the JSON and parse the model configuration.
812    ///
813    /// # Examples
814    /// ```rust
815    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
816    /// # fn main() -> DecoderResult<()> {
817    /// let config_json = include_str!("../../../testdata/modelpack_split.json").to_string();
818    /// let decoder = DecoderBuilder::new()
819    ///     .with_config_json_str(config_json)
820    ///     .build()?;
821    ///
822    /// # Ok(())
823    /// # }
824    /// ```
825    pub fn with_config_json_str(mut self, json_str: String) -> Self {
826        self.config_src.replace(ConfigSource::Json(json_str));
827        self
828    }
829
830    /// Loads a model configuration. Does not check if the configuration is
831    /// correct. Intended to be used when the user needs control over the
832    /// deserialize of the configuration information. Use
833    /// `DecoderBuilder.build()` to parse the model configuration.
834    ///
835    /// # Examples
836    /// ```rust
837    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
838    /// # fn main() -> DecoderResult<()> {
839    /// let config_json = include_str!("../../../testdata/modelpack_split.json");
840    /// let config = serde_json::from_str(config_json)?;
841    /// let decoder = DecoderBuilder::new().with_config(config).build()?;
842    ///
843    /// # Ok(())
844    /// # }
845    /// ```
846    pub fn with_config(mut self, config: ConfigOutputs) -> Self {
847        self.config_src.replace(ConfigSource::Config(config));
848        self
849    }
850
851    /// Loads a YOLO detection model configuration.  Use
852    /// `DecoderBuilder.build()` to parse the model configuration.
853    ///
854    /// # Examples
855    /// ```rust
856    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
857    /// # fn main() -> DecoderResult<()> {
858    /// let decoder = DecoderBuilder::new()
859    ///     .with_config_yolo_det(
860    ///         configs::Detection {
861    ///             anchors: None,
862    ///             decoder: configs::DecoderType::Ultralytics,
863    ///             quantization: Some(configs::QuantTuple(0.012345, 26)),
864    ///             shape: vec![1, 84, 8400],
865    ///             dshape: Vec::new(),
866    ///             normalized: Some(true),
867    ///         },
868    ///         None,
869    ///     )
870    ///     .build()?;
871    ///
872    /// # Ok(())
873    /// # }
874    /// ```
875    pub fn with_config_yolo_det(
876        mut self,
877        boxes: configs::Detection,
878        version: Option<DecoderVersion>,
879    ) -> Self {
880        let config = ConfigOutputs {
881            outputs: vec![ConfigOutput::Detection(boxes)],
882            decoder_version: version,
883            ..Default::default()
884        };
885        self.config_src.replace(ConfigSource::Config(config));
886        self
887    }
888
889    /// Loads a YOLO split detection model configuration.  Use
890    /// `DecoderBuilder.build()` to parse the model configuration.
891    ///
892    /// # Examples
893    /// ```rust
894    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
895    /// # fn main() -> DecoderResult<()> {
896    /// let boxes_config = configs::Boxes {
897    ///     decoder: configs::DecoderType::Ultralytics,
898    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
899    ///     shape: vec![1, 4, 8400],
900    ///     dshape: Vec::new(),
901    ///     normalized: Some(true),
902    /// };
903    /// let scores_config = configs::Scores {
904    ///     decoder: configs::DecoderType::Ultralytics,
905    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
906    ///     shape: vec![1, 80, 8400],
907    ///     dshape: Vec::new(),
908    /// };
909    /// let decoder = DecoderBuilder::new()
910    ///     .with_config_yolo_split_det(boxes_config, scores_config)
911    ///     .build()?;
912    /// # Ok(())
913    /// # }
914    /// ```
915    pub fn with_config_yolo_split_det(
916        mut self,
917        boxes: configs::Boxes,
918        scores: configs::Scores,
919    ) -> Self {
920        let config = ConfigOutputs {
921            outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
922            ..Default::default()
923        };
924        self.config_src.replace(ConfigSource::Config(config));
925        self
926    }
927
928    /// Loads a YOLO segmentation model configuration.  Use
929    /// `DecoderBuilder.build()` to parse the model configuration.
930    ///
931    /// # Examples
932    /// ```rust
933    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
934    /// # fn main() -> DecoderResult<()> {
935    /// let seg_config = configs::Detection {
936    ///     decoder: configs::DecoderType::Ultralytics,
937    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
938    ///     shape: vec![1, 116, 8400],
939    ///     anchors: None,
940    ///     dshape: Vec::new(),
941    ///     normalized: Some(true),
942    /// };
943    /// let protos_config = configs::Protos {
944    ///     decoder: configs::DecoderType::Ultralytics,
945    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
946    ///     shape: vec![1, 160, 160, 32],
947    ///     dshape: Vec::new(),
948    /// };
949    /// let decoder = DecoderBuilder::new()
950    ///     .with_config_yolo_segdet(
951    ///         seg_config,
952    ///         protos_config,
953    ///         Some(configs::DecoderVersion::Yolov8),
954    ///     )
955    ///     .build()?;
956    /// # Ok(())
957    /// # }
958    /// ```
959    pub fn with_config_yolo_segdet(
960        mut self,
961        boxes: configs::Detection,
962        protos: configs::Protos,
963        version: Option<DecoderVersion>,
964    ) -> Self {
965        let config = ConfigOutputs {
966            outputs: vec![ConfigOutput::Detection(boxes), ConfigOutput::Protos(protos)],
967            decoder_version: version,
968            ..Default::default()
969        };
970        self.config_src.replace(ConfigSource::Config(config));
971        self
972    }
973
974    /// Loads a YOLO split segmentation model configuration.  Use
975    /// `DecoderBuilder.build()` to parse the model configuration.
976    ///
977    /// # Examples
978    /// ```rust
979    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
980    /// # fn main() -> DecoderResult<()> {
981    /// let boxes_config = configs::Boxes {
982    ///     decoder: configs::DecoderType::Ultralytics,
983    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
984    ///     shape: vec![1, 4, 8400],
985    ///     dshape: Vec::new(),
986    ///     normalized: Some(true),
987    /// };
988    /// let scores_config = configs::Scores {
989    ///     decoder: configs::DecoderType::Ultralytics,
990    ///     quantization: Some(configs::QuantTuple(0.012345, 14)),
991    ///     shape: vec![1, 80, 8400],
992    ///     dshape: Vec::new(),
993    /// };
994    /// let mask_config = configs::MaskCoefficients {
995    ///     decoder: configs::DecoderType::Ultralytics,
996    ///     quantization: Some(configs::QuantTuple(0.0064123, 125)),
997    ///     shape: vec![1, 32, 8400],
998    ///     dshape: Vec::new(),
999    /// };
1000    /// let protos_config = configs::Protos {
1001    ///     decoder: configs::DecoderType::Ultralytics,
1002    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
1003    ///     shape: vec![1, 160, 160, 32],
1004    ///     dshape: Vec::new(),
1005    /// };
1006    /// let decoder = DecoderBuilder::new()
1007    ///     .with_config_yolo_split_segdet(boxes_config, scores_config, mask_config, protos_config)
1008    ///     .build()?;
1009    /// # Ok(())
1010    /// # }
1011    /// ```
1012    pub fn with_config_yolo_split_segdet(
1013        mut self,
1014        boxes: configs::Boxes,
1015        scores: configs::Scores,
1016        mask_coefficients: configs::MaskCoefficients,
1017        protos: configs::Protos,
1018    ) -> Self {
1019        let config = ConfigOutputs {
1020            outputs: vec![
1021                ConfigOutput::Boxes(boxes),
1022                ConfigOutput::Scores(scores),
1023                ConfigOutput::MaskCoefficients(mask_coefficients),
1024                ConfigOutput::Protos(protos),
1025            ],
1026            ..Default::default()
1027        };
1028        self.config_src.replace(ConfigSource::Config(config));
1029        self
1030    }
1031
1032    /// Loads a ModelPack detection model configuration.  Use
1033    /// `DecoderBuilder.build()` to parse the model configuration.
1034    ///
1035    /// # Examples
1036    /// ```rust
1037    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
1038    /// # fn main() -> DecoderResult<()> {
1039    /// let boxes_config = configs::Boxes {
1040    ///     decoder: configs::DecoderType::ModelPack,
1041    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
1042    ///     shape: vec![1, 8400, 1, 4],
1043    ///     dshape: Vec::new(),
1044    ///     normalized: Some(true),
1045    /// };
1046    /// let scores_config = configs::Scores {
1047    ///     decoder: configs::DecoderType::ModelPack,
1048    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
1049    ///     shape: vec![1, 8400, 3],
1050    ///     dshape: Vec::new(),
1051    /// };
1052    /// let decoder = DecoderBuilder::new()
1053    ///     .with_config_modelpack_det(boxes_config, scores_config)
1054    ///     .build()?;
1055    /// # Ok(())
1056    /// # }
1057    /// ```
1058    pub fn with_config_modelpack_det(
1059        mut self,
1060        boxes: configs::Boxes,
1061        scores: configs::Scores,
1062    ) -> Self {
1063        let config = ConfigOutputs {
1064            outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
1065            ..Default::default()
1066        };
1067        self.config_src.replace(ConfigSource::Config(config));
1068        self
1069    }
1070
1071    /// Loads a ModelPack split detection model configuration. Use
1072    /// `DecoderBuilder.build()` to parse the model configuration.
1073    ///
1074    /// # Examples
1075    /// ```rust
1076    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
1077    /// # fn main() -> DecoderResult<()> {
1078    /// let config0 = configs::Detection {
1079    ///     anchors: Some(vec![
1080    ///         [0.13750000298023224, 0.2074074000120163],
1081    ///         [0.2541666626930237, 0.21481481194496155],
1082    ///         [0.23125000298023224, 0.35185185074806213],
1083    ///     ]),
1084    ///     decoder: configs::DecoderType::ModelPack,
1085    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
1086    ///     shape: vec![1, 17, 30, 18],
1087    ///     dshape: Vec::new(),
1088    ///     normalized: Some(true),
1089    /// };
1090    /// let config1 = configs::Detection {
1091    ///     anchors: Some(vec![
1092    ///         [0.36666667461395264, 0.31481480598449707],
1093    ///         [0.38749998807907104, 0.4740740656852722],
1094    ///         [0.5333333611488342, 0.644444465637207],
1095    ///     ]),
1096    ///     decoder: configs::DecoderType::ModelPack,
1097    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
1098    ///     shape: vec![1, 9, 15, 18],
1099    ///     dshape: Vec::new(),
1100    ///     normalized: Some(true),
1101    /// };
1102    ///
1103    /// let decoder = DecoderBuilder::new()
1104    ///     .with_config_modelpack_det_split(vec![config0, config1])
1105    ///     .build()?;
1106    /// # Ok(())
1107    /// # }
1108    /// ```
1109    pub fn with_config_modelpack_det_split(mut self, boxes: Vec<configs::Detection>) -> Self {
1110        let outputs = boxes.into_iter().map(ConfigOutput::Detection).collect();
1111        let config = ConfigOutputs {
1112            outputs,
1113            ..Default::default()
1114        };
1115        self.config_src.replace(ConfigSource::Config(config));
1116        self
1117    }
1118
1119    /// Loads a ModelPack segmentation detection model configuration. Use
1120    /// `DecoderBuilder.build()` to parse the model configuration.
1121    ///
1122    /// # Examples
1123    /// ```rust
1124    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
1125    /// # fn main() -> DecoderResult<()> {
1126    /// let boxes_config = configs::Boxes {
1127    ///     decoder: configs::DecoderType::ModelPack,
1128    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
1129    ///     shape: vec![1, 8400, 1, 4],
1130    ///     dshape: Vec::new(),
1131    ///     normalized: Some(true),
1132    /// };
1133    /// let scores_config = configs::Scores {
1134    ///     decoder: configs::DecoderType::ModelPack,
1135    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
1136    ///     shape: vec![1, 8400, 2],
1137    ///     dshape: Vec::new(),
1138    /// };
1139    /// let seg_config = configs::Segmentation {
1140    ///     decoder: configs::DecoderType::ModelPack,
1141    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
1142    ///     shape: vec![1, 640, 640, 3],
1143    ///     dshape: Vec::new(),
1144    /// };
1145    /// let decoder = DecoderBuilder::new()
1146    ///     .with_config_modelpack_segdet(boxes_config, scores_config, seg_config)
1147    ///     .build()?;
1148    /// # Ok(())
1149    /// # }
1150    /// ```
1151    pub fn with_config_modelpack_segdet(
1152        mut self,
1153        boxes: configs::Boxes,
1154        scores: configs::Scores,
1155        segmentation: configs::Segmentation,
1156    ) -> Self {
1157        let config = ConfigOutputs {
1158            outputs: vec![
1159                ConfigOutput::Boxes(boxes),
1160                ConfigOutput::Scores(scores),
1161                ConfigOutput::Segmentation(segmentation),
1162            ],
1163            ..Default::default()
1164        };
1165        self.config_src.replace(ConfigSource::Config(config));
1166        self
1167    }
1168
1169    /// Loads a ModelPack segmentation split detection model configuration. Use
1170    /// `DecoderBuilder.build()` to parse the model configuration.
1171    ///
1172    /// # Examples
1173    /// ```rust
1174    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
1175    /// # fn main() -> DecoderResult<()> {
1176    /// let config0 = configs::Detection {
1177    ///     anchors: Some(vec![
1178    ///         [0.36666667461395264, 0.31481480598449707],
1179    ///         [0.38749998807907104, 0.4740740656852722],
1180    ///         [0.5333333611488342, 0.644444465637207],
1181    ///     ]),
1182    ///     decoder: configs::DecoderType::ModelPack,
1183    ///     quantization: Some(configs::QuantTuple(0.08547406643629074, 174)),
1184    ///     shape: vec![1, 9, 15, 18],
1185    ///     dshape: Vec::new(),
1186    ///     normalized: Some(true),
1187    /// };
1188    /// let config1 = configs::Detection {
1189    ///     anchors: Some(vec![
1190    ///         [0.13750000298023224, 0.2074074000120163],
1191    ///         [0.2541666626930237, 0.21481481194496155],
1192    ///         [0.23125000298023224, 0.35185185074806213],
1193    ///     ]),
1194    ///     decoder: configs::DecoderType::ModelPack,
1195    ///     quantization: Some(configs::QuantTuple(0.09929127991199493, 183)),
1196    ///     shape: vec![1, 17, 30, 18],
1197    ///     dshape: Vec::new(),
1198    ///     normalized: Some(true),
1199    /// };
1200    /// let seg_config = configs::Segmentation {
1201    ///     decoder: configs::DecoderType::ModelPack,
1202    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
1203    ///     shape: vec![1, 640, 640, 2],
1204    ///     dshape: Vec::new(),
1205    /// };
1206    /// let decoder = DecoderBuilder::new()
1207    ///     .with_config_modelpack_segdet_split(vec![config0, config1], seg_config)
1208    ///     .build()?;
1209    /// # Ok(())
1210    /// # }
1211    /// ```
1212    pub fn with_config_modelpack_segdet_split(
1213        mut self,
1214        boxes: Vec<configs::Detection>,
1215        segmentation: configs::Segmentation,
1216    ) -> Self {
1217        let mut outputs = boxes
1218            .into_iter()
1219            .map(ConfigOutput::Detection)
1220            .collect::<Vec<_>>();
1221        outputs.push(ConfigOutput::Segmentation(segmentation));
1222        let config = ConfigOutputs {
1223            outputs,
1224            ..Default::default()
1225        };
1226        self.config_src.replace(ConfigSource::Config(config));
1227        self
1228    }
1229
1230    /// Loads a ModelPack segmentation model configuration. Use
1231    /// `DecoderBuilder.build()` to parse the model configuration.
1232    ///
1233    /// # Examples
1234    /// ```rust
1235    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
1236    /// # fn main() -> DecoderResult<()> {
1237    /// let seg_config = configs::Segmentation {
1238    ///     decoder: configs::DecoderType::ModelPack,
1239    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
1240    ///     shape: vec![1, 640, 640, 3],
1241    ///     dshape: Vec::new(),
1242    /// };
1243    /// let decoder = DecoderBuilder::new()
1244    ///     .with_config_modelpack_seg(seg_config)
1245    ///     .build()?;
1246    /// # Ok(())
1247    /// # }
1248    /// ```
1249    pub fn with_config_modelpack_seg(mut self, segmentation: configs::Segmentation) -> Self {
1250        let config = ConfigOutputs {
1251            outputs: vec![ConfigOutput::Segmentation(segmentation)],
1252            ..Default::default()
1253        };
1254        self.config_src.replace(ConfigSource::Config(config));
1255        self
1256    }
1257
1258    /// Add an output to the decoder configuration.
1259    ///
1260    /// Incrementally builds the model configuration by adding outputs one at
1261    /// a time. The decoder resolves the model type from the combination of
1262    /// outputs during `build()`.
1263    ///
1264    /// If `dshape` is non-empty on the output, `shape` is automatically
1265    /// derived from it (the size component of each named dimension). This
1266    /// prevents conflicts between `shape` and `dshape`.
1267    ///
1268    /// This uses the programmatic config path. Calling this after
1269    /// `with_config_json_str()` or `with_config_yaml_str()` replaces the
1270    /// string-based config source.
1271    ///
1272    /// # Examples
1273    /// ```rust
1274    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
1275    /// # fn main() -> DecoderResult<()> {
1276    /// let decoder = DecoderBuilder::new()
1277    ///     .add_output(ConfigOutput::Scores(configs::Scores {
1278    ///         decoder: configs::DecoderType::Ultralytics,
1279    ///         dshape: vec![
1280    ///             (configs::DimName::Batch, 1),
1281    ///             (configs::DimName::NumClasses, 80),
1282    ///             (configs::DimName::NumBoxes, 8400),
1283    ///         ],
1284    ///         ..Default::default()
1285    ///     }))
1286    ///     .add_output(ConfigOutput::Boxes(configs::Boxes {
1287    ///         decoder: configs::DecoderType::Ultralytics,
1288    ///         dshape: vec![
1289    ///             (configs::DimName::Batch, 1),
1290    ///             (configs::DimName::BoxCoords, 4),
1291    ///             (configs::DimName::NumBoxes, 8400),
1292    ///         ],
1293    ///         ..Default::default()
1294    ///     }))
1295    ///     .build()?;
1296    /// # Ok(())
1297    /// # }
1298    /// ```
1299    pub fn add_output(mut self, output: ConfigOutput) -> Self {
1300        if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
1301            self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
1302        }
1303        if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
1304            config.outputs.push(Self::normalize_output(output));
1305        }
1306        self
1307    }
1308
1309    /// Sets the decoder version for Ultralytics models.
1310    ///
1311    /// This is used with `add_output()` to specify the YOLO architecture
1312    /// version when it cannot be inferred from the output shapes alone.
1313    ///
1314    /// - `Yolov5`, `Yolov8`, `Yolo11`: Traditional models requiring external
1315    ///   NMS
1316    /// - `Yolo26`: End-to-end models with NMS embedded in the model graph
1317    ///
1318    /// # Examples
1319    /// ```rust
1320    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
1321    /// # fn main() -> DecoderResult<()> {
1322    /// let decoder = DecoderBuilder::new()
1323    ///     .add_output(ConfigOutput::Detection(configs::Detection {
1324    ///         decoder: configs::DecoderType::Ultralytics,
1325    ///         dshape: vec![
1326    ///             (configs::DimName::Batch, 1),
1327    ///             (configs::DimName::NumBoxes, 100),
1328    ///             (configs::DimName::NumFeatures, 6),
1329    ///         ],
1330    ///         ..Default::default()
1331    ///     }))
1332    ///     .with_decoder_version(configs::DecoderVersion::Yolo26)
1333    ///     .build()?;
1334    /// # Ok(())
1335    /// # }
1336    /// ```
1337    pub fn with_decoder_version(mut self, version: configs::DecoderVersion) -> Self {
1338        if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
1339            self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
1340        }
1341        if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
1342            config.decoder_version = Some(version);
1343        }
1344        self
1345    }
1346
1347    /// Normalize an output: if dshape is non-empty, derive shape from it.
1348    fn normalize_output(mut output: ConfigOutput) -> ConfigOutput {
1349        fn normalize_shape(shape: &mut Vec<usize>, dshape: &[(configs::DimName, usize)]) {
1350            if !dshape.is_empty() {
1351                *shape = dshape.iter().map(|(_, size)| *size).collect();
1352            }
1353        }
1354        match &mut output {
1355            ConfigOutput::Detection(c) => normalize_shape(&mut c.shape, &c.dshape),
1356            ConfigOutput::Boxes(c) => normalize_shape(&mut c.shape, &c.dshape),
1357            ConfigOutput::Scores(c) => normalize_shape(&mut c.shape, &c.dshape),
1358            ConfigOutput::Protos(c) => normalize_shape(&mut c.shape, &c.dshape),
1359            ConfigOutput::Segmentation(c) => normalize_shape(&mut c.shape, &c.dshape),
1360            ConfigOutput::MaskCoefficients(c) => normalize_shape(&mut c.shape, &c.dshape),
1361            ConfigOutput::Mask(c) => normalize_shape(&mut c.shape, &c.dshape),
1362            ConfigOutput::Classes(c) => normalize_shape(&mut c.shape, &c.dshape),
1363        }
1364        output
1365    }
1366
1367    /// Sets the scores threshold of the decoder
1368    ///
1369    /// # Examples
1370    /// ```rust
1371    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
1372    /// # fn main() -> DecoderResult<()> {
1373    /// # let config_json = include_str!("../../../testdata/modelpack_split.json").to_string();
1374    /// let decoder = DecoderBuilder::new()
1375    ///     .with_config_json_str(config_json)
1376    ///     .with_score_threshold(0.654)
1377    ///     .build()?;
1378    /// assert_eq!(decoder.score_threshold, 0.654);
1379    /// # Ok(())
1380    /// # }
1381    /// ```
1382    pub fn with_score_threshold(mut self, score_threshold: f32) -> Self {
1383        self.score_threshold = score_threshold;
1384        self
1385    }
1386
1387    /// Sets the IOU threshold of the decoder. Has no effect when NMS is set to
1388    /// `None`
1389    ///
1390    /// # Examples
1391    /// ```rust
1392    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
1393    /// # fn main() -> DecoderResult<()> {
1394    /// # let config_json = include_str!("../../../testdata/modelpack_split.json").to_string();
1395    /// let decoder = DecoderBuilder::new()
1396    ///     .with_config_json_str(config_json)
1397    ///     .with_iou_threshold(0.654)
1398    ///     .build()?;
1399    /// assert_eq!(decoder.iou_threshold, 0.654);
1400    /// # Ok(())
1401    /// # }
1402    /// ```
1403    pub fn with_iou_threshold(mut self, iou_threshold: f32) -> Self {
1404        self.iou_threshold = iou_threshold;
1405        self
1406    }
1407
1408    /// Sets the NMS mode for the decoder.
1409    ///
1410    /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS (default): suppress
1411    ///   overlapping boxes regardless of class label
1412    /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
1413    ///   share the same class label AND overlap above the IoU threshold
1414    /// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
1415    ///
1416    /// # Examples
1417    /// ```rust
1418    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::Nms};
1419    /// # fn main() -> DecoderResult<()> {
1420    /// # let config_json = include_str!("../../../testdata/modelpack_split.json").to_string();
1421    /// let decoder = DecoderBuilder::new()
1422    ///     .with_config_json_str(config_json)
1423    ///     .with_nms(Some(Nms::ClassAware))
1424    ///     .build()?;
1425    /// assert_eq!(decoder.nms, Some(Nms::ClassAware));
1426    /// # Ok(())
1427    /// # }
1428    /// ```
1429    pub fn with_nms(mut self, nms: Option<configs::Nms>) -> Self {
1430        self.nms = nms;
1431        self
1432    }
1433
1434    /// Builds the decoder with the given settings. If the config is a JSON or
1435    /// YAML string, this will deserialize the JSON or YAML and then parse the
1436    /// configuration information.
1437    ///
1438    /// # Examples
1439    /// ```rust
1440    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
1441    /// # fn main() -> DecoderResult<()> {
1442    /// # let config_json = include_str!("../../../testdata/modelpack_split.json").to_string();
1443    /// let decoder = DecoderBuilder::new()
1444    ///     .with_config_json_str(config_json)
1445    ///     .with_score_threshold(0.654)
1446    ///     .build()?;
1447    /// # Ok(())
1448    /// # }
1449    /// ```
1450    pub fn build(self) -> Result<Decoder, DecoderError> {
1451        let config = match self.config_src {
1452            Some(ConfigSource::Json(s)) => serde_json::from_str(&s)?,
1453            Some(ConfigSource::Yaml(s)) => serde_yaml::from_str(&s)?,
1454            Some(ConfigSource::Config(c)) => c,
1455            None => return Err(DecoderError::NoConfig),
1456        };
1457
1458        // Extract normalized flag from config outputs
1459        let normalized = Self::get_normalized(&config.outputs);
1460
1461        // Use NMS from config if present, otherwise use builder's NMS setting
1462        let nms = config.nms.or(self.nms);
1463        let model_type = Self::get_model_type(config)?;
1464
1465        Ok(Decoder {
1466            model_type,
1467            iou_threshold: self.iou_threshold,
1468            score_threshold: self.score_threshold,
1469            nms,
1470            normalized,
1471        })
1472    }
1473
1474    /// Extracts the normalized flag from config outputs.
1475    /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
1476    /// - `Some(false)`: Boxes are in pixel coordinates
1477    /// - `None`: Unknown (not specified in config), caller must infer
1478    fn get_normalized(outputs: &[ConfigOutput]) -> Option<bool> {
1479        for output in outputs {
1480            match output {
1481                ConfigOutput::Detection(det) => return det.normalized,
1482                ConfigOutput::Boxes(boxes) => return boxes.normalized,
1483                _ => {}
1484            }
1485        }
1486        None // not specified
1487    }
1488
1489    fn get_model_type(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1490        // yolo or modelpack
1491        let mut yolo = false;
1492        let mut modelpack = false;
1493        for c in &configs.outputs {
1494            match c.decoder() {
1495                DecoderType::ModelPack => modelpack = true,
1496                DecoderType::Ultralytics => yolo = true,
1497            }
1498        }
1499        match (modelpack, yolo) {
1500            (true, true) => Err(DecoderError::InvalidConfig(
1501                "Both ModelPack and Yolo outputs found in config".to_string(),
1502            )),
1503            (true, false) => Self::get_model_type_modelpack(configs),
1504            (false, true) => Self::get_model_type_yolo(configs),
1505            (false, false) => Err(DecoderError::InvalidConfig(
1506                "No outputs found in config".to_string(),
1507            )),
1508        }
1509    }
1510
1511    fn get_model_type_yolo(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1512        let mut boxes = None;
1513        let mut protos = None;
1514        let mut split_boxes = None;
1515        let mut split_scores = None;
1516        let mut split_mask_coeff = None;
1517        let mut split_classes = None;
1518        for c in configs.outputs {
1519            match c {
1520                ConfigOutput::Detection(detection) => boxes = Some(detection),
1521                ConfigOutput::Segmentation(_) => {
1522                    return Err(DecoderError::InvalidConfig(
1523                        "Invalid Segmentation output with Yolo decoder".to_string(),
1524                    ));
1525                }
1526                ConfigOutput::Protos(protos_) => protos = Some(protos_),
1527                ConfigOutput::Mask(_) => {
1528                    return Err(DecoderError::InvalidConfig(
1529                        "Invalid Mask output with Yolo decoder".to_string(),
1530                    ));
1531                }
1532                ConfigOutput::Scores(scores) => split_scores = Some(scores),
1533                ConfigOutput::Boxes(boxes) => split_boxes = Some(boxes),
1534                ConfigOutput::MaskCoefficients(mask_coeff) => split_mask_coeff = Some(mask_coeff),
1535                ConfigOutput::Classes(classes) => split_classes = Some(classes),
1536            }
1537        }
1538
1539        // Use end-to-end model types when:
1540        // 1. decoder_version is explicitly set to Yolo26 (definitive), OR
1541        //    decoder_version is not set but the dshapes are (batch, num_boxes,
1542        //    num_features)
1543        let is_end_to_end_dshape = boxes.as_ref().is_some_and(|b| {
1544            let dims = b.dshape.iter().map(|(d, _)| *d).collect::<Vec<_>>();
1545            dims == vec![DimName::Batch, DimName::NumBoxes, DimName::NumFeatures]
1546        });
1547
1548        let is_end_to_end = configs
1549            .decoder_version
1550            .map(|v| v.is_end_to_end())
1551            .unwrap_or(is_end_to_end_dshape);
1552
1553        if is_end_to_end {
1554            if let Some(boxes) = boxes {
1555                if let Some(protos) = protos {
1556                    Self::verify_yolo_seg_det_26(&boxes, &protos)?;
1557                    return Ok(ModelType::YoloEndToEndSegDet { boxes, protos });
1558                } else {
1559                    Self::verify_yolo_det_26(&boxes)?;
1560                    return Ok(ModelType::YoloEndToEndDet { boxes });
1561                }
1562            } else if let (Some(split_boxes), Some(split_scores), Some(split_classes)) =
1563                (split_boxes, split_scores, split_classes)
1564            {
1565                if let (Some(split_mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
1566                    Self::verify_yolo_split_end_to_end_segdet(
1567                        &split_boxes,
1568                        &split_scores,
1569                        &split_classes,
1570                        &split_mask_coeff,
1571                        &protos,
1572                    )?;
1573                    return Ok(ModelType::YoloSplitEndToEndSegDet {
1574                        boxes: split_boxes,
1575                        scores: split_scores,
1576                        classes: split_classes,
1577                        mask_coeff: split_mask_coeff,
1578                        protos,
1579                    });
1580                }
1581                Self::verify_yolo_split_end_to_end_det(
1582                    &split_boxes,
1583                    &split_scores,
1584                    &split_classes,
1585                )?;
1586                return Ok(ModelType::YoloSplitEndToEndDet {
1587                    boxes: split_boxes,
1588                    scores: split_scores,
1589                    classes: split_classes,
1590                });
1591            } else {
1592                return Err(DecoderError::InvalidConfig(
1593                    "Invalid Yolo end-to-end model outputs".to_string(),
1594                ));
1595            }
1596        }
1597
1598        if let Some(boxes) = boxes {
1599            if let Some(protos) = protos {
1600                Self::verify_yolo_seg_det(&boxes, &protos)?;
1601                Ok(ModelType::YoloSegDet { boxes, protos })
1602            } else {
1603                Self::verify_yolo_det(&boxes)?;
1604                Ok(ModelType::YoloDet { boxes })
1605            }
1606        } else if let (Some(boxes), Some(scores)) = (split_boxes, split_scores) {
1607            if let (Some(mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
1608                Self::verify_yolo_split_segdet(&boxes, &scores, &mask_coeff, &protos)?;
1609                Ok(ModelType::YoloSplitSegDet {
1610                    boxes,
1611                    scores,
1612                    mask_coeff,
1613                    protos,
1614                })
1615            } else {
1616                Self::verify_yolo_split_det(&boxes, &scores)?;
1617                Ok(ModelType::YoloSplitDet { boxes, scores })
1618            }
1619        } else {
1620            Err(DecoderError::InvalidConfig(
1621                "Invalid Yolo model outputs".to_string(),
1622            ))
1623        }
1624    }
1625
1626    fn verify_yolo_det(detect: &configs::Detection) -> Result<(), DecoderError> {
1627        if detect.shape.len() != 3 {
1628            return Err(DecoderError::InvalidConfig(format!(
1629                "Invalid Yolo Detection shape {:?}",
1630                detect.shape
1631            )));
1632        }
1633
1634        Self::verify_dshapes(
1635            &detect.dshape,
1636            &detect.shape,
1637            "Detection",
1638            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1639        )?;
1640        if !detect.dshape.is_empty() {
1641            Self::get_class_count(&detect.dshape, None, None)?;
1642        } else {
1643            Self::get_class_count_no_dshape(detect.into(), None)?;
1644        }
1645
1646        Ok(())
1647    }
1648
1649    fn verify_yolo_det_26(detect: &configs::Detection) -> Result<(), DecoderError> {
1650        if detect.shape.len() != 3 {
1651            return Err(DecoderError::InvalidConfig(format!(
1652                "Invalid Yolo Detection shape {:?}",
1653                detect.shape
1654            )));
1655        }
1656
1657        Self::verify_dshapes(
1658            &detect.dshape,
1659            &detect.shape,
1660            "Detection",
1661            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1662        )?;
1663
1664        if !detect.shape.contains(&6) {
1665            return Err(DecoderError::InvalidConfig(
1666                "Yolo26 Detection must have 6 features".to_string(),
1667            ));
1668        }
1669
1670        Ok(())
1671    }
1672
1673    fn verify_yolo_seg_det(
1674        detection: &configs::Detection,
1675        protos: &configs::Protos,
1676    ) -> Result<(), DecoderError> {
1677        if detection.shape.len() != 3 {
1678            return Err(DecoderError::InvalidConfig(format!(
1679                "Invalid Yolo Detection shape {:?}",
1680                detection.shape
1681            )));
1682        }
1683        if protos.shape.len() != 4 {
1684            return Err(DecoderError::InvalidConfig(format!(
1685                "Invalid Yolo Protos shape {:?}",
1686                protos.shape
1687            )));
1688        }
1689
1690        Self::verify_dshapes(
1691            &detection.dshape,
1692            &detection.shape,
1693            "Detection",
1694            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1695        )?;
1696        Self::verify_dshapes(
1697            &protos.dshape,
1698            &protos.shape,
1699            "Protos",
1700            &[
1701                DimName::Batch,
1702                DimName::Height,
1703                DimName::Width,
1704                DimName::NumProtos,
1705            ],
1706        )?;
1707
1708        let protos_count = Self::get_protos_count(&protos.dshape).unwrap_or(protos.shape[3]);
1709        log::debug!("Protos count: {}", protos_count);
1710        log::debug!("Detection dshape: {:?}", detection.dshape);
1711        let classes = if !detection.dshape.is_empty() {
1712            Self::get_class_count(&detection.dshape, Some(protos_count), None)?
1713        } else {
1714            Self::get_class_count_no_dshape(detection.into(), Some(protos_count))?
1715        };
1716
1717        if classes == 0 {
1718            return Err(DecoderError::InvalidConfig(
1719                "Yolo Segmentation Detection has zero classes".to_string(),
1720            ));
1721        }
1722
1723        Ok(())
1724    }
1725
1726    fn verify_yolo_seg_det_26(
1727        detection: &configs::Detection,
1728        protos: &configs::Protos,
1729    ) -> Result<(), DecoderError> {
1730        if detection.shape.len() != 3 {
1731            return Err(DecoderError::InvalidConfig(format!(
1732                "Invalid Yolo Detection shape {:?}",
1733                detection.shape
1734            )));
1735        }
1736        if protos.shape.len() != 4 {
1737            return Err(DecoderError::InvalidConfig(format!(
1738                "Invalid Yolo Protos shape {:?}",
1739                protos.shape
1740            )));
1741        }
1742
1743        Self::verify_dshapes(
1744            &detection.dshape,
1745            &detection.shape,
1746            "Detection",
1747            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1748        )?;
1749        Self::verify_dshapes(
1750            &protos.dshape,
1751            &protos.shape,
1752            "Protos",
1753            &[
1754                DimName::Batch,
1755                DimName::Height,
1756                DimName::Width,
1757                DimName::NumProtos,
1758            ],
1759        )?;
1760
1761        let protos_count = Self::get_protos_count(&protos.dshape).unwrap_or(protos.shape[3]);
1762        log::debug!("Protos count: {}", protos_count);
1763        log::debug!("Detection dshape: {:?}", detection.dshape);
1764
1765        if !detection.shape.contains(&(6 + protos_count)) {
1766            return Err(DecoderError::InvalidConfig(format!(
1767                "Yolo26 Segmentation Detection must have num_features be 6 + num_protos = {}",
1768                6 + protos_count
1769            )));
1770        }
1771
1772        Ok(())
1773    }
1774
1775    fn verify_yolo_split_det(
1776        boxes: &configs::Boxes,
1777        scores: &configs::Scores,
1778    ) -> Result<(), DecoderError> {
1779        if boxes.shape.len() != 3 {
1780            return Err(DecoderError::InvalidConfig(format!(
1781                "Invalid Yolo Split Boxes shape {:?}",
1782                boxes.shape
1783            )));
1784        }
1785        if scores.shape.len() != 3 {
1786            return Err(DecoderError::InvalidConfig(format!(
1787                "Invalid Yolo Split Scores shape {:?}",
1788                scores.shape
1789            )));
1790        }
1791
1792        Self::verify_dshapes(
1793            &boxes.dshape,
1794            &boxes.shape,
1795            "Boxes",
1796            &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1797        )?;
1798        Self::verify_dshapes(
1799            &scores.dshape,
1800            &scores.shape,
1801            "Scores",
1802            &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1803        )?;
1804
1805        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1806        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1807
1808        if boxes_num != scores_num {
1809            return Err(DecoderError::InvalidConfig(format!(
1810                "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1811                boxes_num, scores_num
1812            )));
1813        }
1814
1815        Ok(())
1816    }
1817
1818    fn verify_yolo_split_segdet(
1819        boxes: &configs::Boxes,
1820        scores: &configs::Scores,
1821        mask_coeff: &configs::MaskCoefficients,
1822        protos: &configs::Protos,
1823    ) -> Result<(), DecoderError> {
1824        if boxes.shape.len() != 3 {
1825            return Err(DecoderError::InvalidConfig(format!(
1826                "Invalid Yolo Split Boxes shape {:?}",
1827                boxes.shape
1828            )));
1829        }
1830        if scores.shape.len() != 3 {
1831            return Err(DecoderError::InvalidConfig(format!(
1832                "Invalid Yolo Split Scores shape {:?}",
1833                scores.shape
1834            )));
1835        }
1836
1837        if mask_coeff.shape.len() != 3 {
1838            return Err(DecoderError::InvalidConfig(format!(
1839                "Invalid Yolo Split Mask Coefficients shape {:?}",
1840                mask_coeff.shape
1841            )));
1842        }
1843
1844        if protos.shape.len() != 4 {
1845            return Err(DecoderError::InvalidConfig(format!(
1846                "Invalid Yolo Protos shape {:?}",
1847                mask_coeff.shape
1848            )));
1849        }
1850
1851        Self::verify_dshapes(
1852            &boxes.dshape,
1853            &boxes.shape,
1854            "Boxes",
1855            &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1856        )?;
1857        Self::verify_dshapes(
1858            &scores.dshape,
1859            &scores.shape,
1860            "Scores",
1861            &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1862        )?;
1863        Self::verify_dshapes(
1864            &mask_coeff.dshape,
1865            &mask_coeff.shape,
1866            "Mask Coefficients",
1867            &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1868        )?;
1869        Self::verify_dshapes(
1870            &protos.dshape,
1871            &protos.shape,
1872            "Protos",
1873            &[
1874                DimName::Batch,
1875                DimName::Height,
1876                DimName::Width,
1877                DimName::NumProtos,
1878            ],
1879        )?;
1880
1881        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1882        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1883        let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1884
1885        let mask_channels = if !mask_coeff.dshape.is_empty() {
1886            Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1887                DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1888            })?
1889        } else {
1890            mask_coeff.shape[1]
1891        };
1892        let proto_channels = if !protos.dshape.is_empty() {
1893            Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1894                DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1895            })?
1896        } else {
1897            protos.shape[3]
1898        };
1899
1900        if boxes_num != scores_num {
1901            return Err(DecoderError::InvalidConfig(format!(
1902                "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1903                boxes_num, scores_num
1904            )));
1905        }
1906
1907        if boxes_num != mask_num {
1908            return Err(DecoderError::InvalidConfig(format!(
1909                "Yolo Split Detection Boxes num {} incompatible with Mask Coefficients num {}",
1910                boxes_num, mask_num
1911            )));
1912        }
1913
1914        if proto_channels != mask_channels {
1915            return Err(DecoderError::InvalidConfig(format!(
1916                "Yolo Protos channels {} incompatible with Mask Coefficients channels {}",
1917                proto_channels, mask_channels
1918            )));
1919        }
1920
1921        Ok(())
1922    }
1923
1924    fn verify_yolo_split_end_to_end_det(
1925        boxes: &configs::Boxes,
1926        scores: &configs::Scores,
1927        classes: &configs::Classes,
1928    ) -> Result<(), DecoderError> {
1929        if boxes.shape.len() != 3 || !boxes.shape.contains(&4) {
1930            return Err(DecoderError::InvalidConfig(format!(
1931                "Split end-to-end boxes must be [batch, N, 4], got {:?}",
1932                boxes.shape
1933            )));
1934        }
1935        if scores.shape.len() != 3 || !scores.shape.contains(&1) {
1936            return Err(DecoderError::InvalidConfig(format!(
1937                "Split end-to-end scores must be [batch, N, 1], got {:?}",
1938                scores.shape
1939            )));
1940        }
1941        if classes.shape.len() != 3 || !classes.shape.contains(&1) {
1942            return Err(DecoderError::InvalidConfig(format!(
1943                "Split end-to-end classes must be [batch, N, 1], got {:?}",
1944                classes.shape
1945            )));
1946        }
1947        Ok(())
1948    }
1949
1950    fn verify_yolo_split_end_to_end_segdet(
1951        boxes: &configs::Boxes,
1952        scores: &configs::Scores,
1953        classes: &configs::Classes,
1954        mask_coeff: &configs::MaskCoefficients,
1955        protos: &configs::Protos,
1956    ) -> Result<(), DecoderError> {
1957        Self::verify_yolo_split_end_to_end_det(boxes, scores, classes)?;
1958        if mask_coeff.shape.len() != 3 {
1959            return Err(DecoderError::InvalidConfig(format!(
1960                "Invalid split end-to-end mask coefficients shape {:?}",
1961                mask_coeff.shape
1962            )));
1963        }
1964        if protos.shape.len() != 4 {
1965            return Err(DecoderError::InvalidConfig(format!(
1966                "Invalid protos shape {:?}",
1967                protos.shape
1968            )));
1969        }
1970        Ok(())
1971    }
1972
1973    fn get_model_type_modelpack(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1974        let mut split_decoders = Vec::new();
1975        let mut segment_ = None;
1976        let mut scores_ = None;
1977        let mut boxes_ = None;
1978        for c in configs.outputs {
1979            match c {
1980                ConfigOutput::Detection(detection) => split_decoders.push(detection),
1981                ConfigOutput::Segmentation(segmentation) => segment_ = Some(segmentation),
1982                ConfigOutput::Mask(_) => {}
1983                ConfigOutput::Protos(_) => {
1984                    return Err(DecoderError::InvalidConfig(
1985                        "ModelPack should not have protos".to_string(),
1986                    ));
1987                }
1988                ConfigOutput::Scores(scores) => scores_ = Some(scores),
1989                ConfigOutput::Boxes(boxes) => boxes_ = Some(boxes),
1990                ConfigOutput::MaskCoefficients(_) => {
1991                    return Err(DecoderError::InvalidConfig(
1992                        "ModelPack should not have mask coefficients".to_string(),
1993                    ));
1994                }
1995                ConfigOutput::Classes(_) => {
1996                    return Err(DecoderError::InvalidConfig(
1997                        "ModelPack should not have classes output".to_string(),
1998                    ));
1999                }
2000            }
2001        }
2002
2003        if let Some(segmentation) = segment_ {
2004            if !split_decoders.is_empty() {
2005                let classes = Self::verify_modelpack_split_det(&split_decoders)?;
2006                Self::verify_modelpack_seg(&segmentation, Some(classes))?;
2007                Ok(ModelType::ModelPackSegDetSplit {
2008                    detection: split_decoders,
2009                    segmentation,
2010                })
2011            } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
2012                let classes = Self::verify_modelpack_det(&boxes, &scores)?;
2013                Self::verify_modelpack_seg(&segmentation, Some(classes))?;
2014                Ok(ModelType::ModelPackSegDet {
2015                    boxes,
2016                    scores,
2017                    segmentation,
2018                })
2019            } else {
2020                Self::verify_modelpack_seg(&segmentation, None)?;
2021                Ok(ModelType::ModelPackSeg { segmentation })
2022            }
2023        } else if !split_decoders.is_empty() {
2024            Self::verify_modelpack_split_det(&split_decoders)?;
2025            Ok(ModelType::ModelPackDetSplit {
2026                detection: split_decoders,
2027            })
2028        } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
2029            Self::verify_modelpack_det(&boxes, &scores)?;
2030            Ok(ModelType::ModelPackDet { boxes, scores })
2031        } else {
2032            Err(DecoderError::InvalidConfig(
2033                "Invalid ModelPack model outputs".to_string(),
2034            ))
2035        }
2036    }
2037
2038    fn verify_modelpack_det(
2039        boxes: &configs::Boxes,
2040        scores: &configs::Scores,
2041    ) -> Result<usize, DecoderError> {
2042        if boxes.shape.len() != 4 {
2043            return Err(DecoderError::InvalidConfig(format!(
2044                "Invalid ModelPack Boxes shape {:?}",
2045                boxes.shape
2046            )));
2047        }
2048        if scores.shape.len() != 3 {
2049            return Err(DecoderError::InvalidConfig(format!(
2050                "Invalid ModelPack Scores shape {:?}",
2051                scores.shape
2052            )));
2053        }
2054
2055        Self::verify_dshapes(
2056            &boxes.dshape,
2057            &boxes.shape,
2058            "Boxes",
2059            &[
2060                DimName::Batch,
2061                DimName::NumBoxes,
2062                DimName::Padding,
2063                DimName::BoxCoords,
2064            ],
2065        )?;
2066        Self::verify_dshapes(
2067            &scores.dshape,
2068            &scores.shape,
2069            "Scores",
2070            &[DimName::Batch, DimName::NumBoxes, DimName::NumClasses],
2071        )?;
2072
2073        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[1]);
2074        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[1]);
2075
2076        if boxes_num != scores_num {
2077            return Err(DecoderError::InvalidConfig(format!(
2078                "ModelPack Detection Boxes num {} incompatible with Scores num {}",
2079                boxes_num, scores_num
2080            )));
2081        }
2082
2083        let num_classes = if !scores.dshape.is_empty() {
2084            Self::get_class_count(&scores.dshape, None, None)?
2085        } else {
2086            Self::get_class_count_no_dshape(scores.into(), None)?
2087        };
2088
2089        Ok(num_classes)
2090    }
2091
2092    fn verify_modelpack_split_det(boxes: &[configs::Detection]) -> Result<usize, DecoderError> {
2093        let mut num_classes = None;
2094        for b in boxes {
2095            let Some(num_anchors) = b.anchors.as_ref().map(|a| a.len()) else {
2096                return Err(DecoderError::InvalidConfig(
2097                    "ModelPack Split Detection missing anchors".to_string(),
2098                ));
2099            };
2100
2101            if num_anchors == 0 {
2102                return Err(DecoderError::InvalidConfig(
2103                    "ModelPack Split Detection has zero anchors".to_string(),
2104                ));
2105            }
2106
2107            if b.shape.len() != 4 {
2108                return Err(DecoderError::InvalidConfig(format!(
2109                    "Invalid ModelPack Split Detection shape {:?}",
2110                    b.shape
2111                )));
2112            }
2113
2114            Self::verify_dshapes(
2115                &b.dshape,
2116                &b.shape,
2117                "Split Detection",
2118                &[
2119                    DimName::Batch,
2120                    DimName::Height,
2121                    DimName::Width,
2122                    DimName::NumAnchorsXFeatures,
2123                ],
2124            )?;
2125            let classes = if !b.dshape.is_empty() {
2126                Self::get_class_count(&b.dshape, None, Some(num_anchors))?
2127            } else {
2128                Self::get_class_count_no_dshape(b.into(), None)?
2129            };
2130
2131            match num_classes {
2132                Some(n) => {
2133                    if n != classes {
2134                        return Err(DecoderError::InvalidConfig(format!(
2135                            "ModelPack Split Detection inconsistent number of classes: previous {}, current {}",
2136                            n, classes
2137                        )));
2138                    }
2139                }
2140                None => {
2141                    num_classes = Some(classes);
2142                }
2143            }
2144        }
2145
2146        Ok(num_classes.unwrap_or(0))
2147    }
2148
2149    fn verify_modelpack_seg(
2150        segmentation: &configs::Segmentation,
2151        classes: Option<usize>,
2152    ) -> Result<(), DecoderError> {
2153        if segmentation.shape.len() != 4 {
2154            return Err(DecoderError::InvalidConfig(format!(
2155                "Invalid ModelPack Segmentation shape {:?}",
2156                segmentation.shape
2157            )));
2158        }
2159        Self::verify_dshapes(
2160            &segmentation.dshape,
2161            &segmentation.shape,
2162            "Segmentation",
2163            &[
2164                DimName::Batch,
2165                DimName::Height,
2166                DimName::Width,
2167                DimName::NumClasses,
2168            ],
2169        )?;
2170
2171        if let Some(classes) = classes {
2172            let seg_classes = if !segmentation.dshape.is_empty() {
2173                Self::get_class_count(&segmentation.dshape, None, None)?
2174            } else {
2175                Self::get_class_count_no_dshape(segmentation.into(), None)?
2176            };
2177
2178            if seg_classes != classes + 1 {
2179                return Err(DecoderError::InvalidConfig(format!(
2180                    "ModelPack Segmentation channels {} incompatible with number of classes {}",
2181                    seg_classes, classes
2182                )));
2183            }
2184        }
2185        Ok(())
2186    }
2187
2188    // verifies that dshapes match the given shape
2189    fn verify_dshapes(
2190        dshape: &[(DimName, usize)],
2191        shape: &[usize],
2192        name: &str,
2193        dims: &[DimName],
2194    ) -> Result<(), DecoderError> {
2195        for s in shape {
2196            if *s == 0 {
2197                return Err(DecoderError::InvalidConfig(format!(
2198                    "{} shape has zero dimension",
2199                    name
2200                )));
2201            }
2202        }
2203
2204        if shape.len() != dims.len() {
2205            return Err(DecoderError::InvalidConfig(format!(
2206                "{} shape length {} does not match expected dims length {}",
2207                name,
2208                shape.len(),
2209                dims.len()
2210            )));
2211        }
2212
2213        if dshape.is_empty() {
2214            return Ok(());
2215        }
2216        // check the dshape lengths match the shape lengths
2217        if dshape.len() != shape.len() {
2218            return Err(DecoderError::InvalidConfig(format!(
2219                "{} dshape length does not match shape length",
2220                name
2221            )));
2222        }
2223
2224        // check the dshape values match the shape values
2225        for ((dim_name, dim_size), shape_size) in dshape.iter().zip(shape) {
2226            if dim_size != shape_size {
2227                return Err(DecoderError::InvalidConfig(format!(
2228                    "{} dshape dimension {} size {} does not match shape size {}",
2229                    name, dim_name, dim_size, shape_size
2230                )));
2231            }
2232            if *dim_name == DimName::Padding && *dim_size != 1 {
2233                return Err(DecoderError::InvalidConfig(
2234                    "Padding dimension size must be 1".to_string(),
2235                ));
2236            }
2237
2238            if *dim_name == DimName::BoxCoords && *dim_size != 4 {
2239                return Err(DecoderError::InvalidConfig(
2240                    "BoxCoords dimension size must be 4".to_string(),
2241                ));
2242            }
2243        }
2244
2245        let dims_present = HashSet::<DimName>::from_iter(dshape.iter().map(|(name, _)| *name));
2246        for dim in dims {
2247            if !dims_present.contains(dim) {
2248                return Err(DecoderError::InvalidConfig(format!(
2249                    "{} dshape missing required dimension {:?}",
2250                    name, dim
2251                )));
2252            }
2253        }
2254
2255        Ok(())
2256    }
2257
2258    fn get_box_count(dshape: &[(DimName, usize)]) -> Option<usize> {
2259        for (dim_name, dim_size) in dshape {
2260            if *dim_name == DimName::NumBoxes {
2261                return Some(*dim_size);
2262            }
2263        }
2264        None
2265    }
2266
2267    fn get_class_count_no_dshape(
2268        config: ConfigOutputRef,
2269        protos: Option<usize>,
2270    ) -> Result<usize, DecoderError> {
2271        match config {
2272            ConfigOutputRef::Detection(detection) => match detection.decoder {
2273                DecoderType::Ultralytics => {
2274                    if detection.shape[1] <= 4 + protos.unwrap_or(0) {
2275                        return Err(DecoderError::InvalidConfig(format!(
2276                            "Invalid shape: Yolo num_features {} must be greater than {}",
2277                            detection.shape[1],
2278                            4 + protos.unwrap_or(0),
2279                        )));
2280                    }
2281                    Ok(detection.shape[1] - 4 - protos.unwrap_or(0))
2282                }
2283                DecoderType::ModelPack => {
2284                    let Some(num_anchors) = detection.anchors.as_ref().map(|a| a.len()) else {
2285                        return Err(DecoderError::Internal(
2286                            "ModelPack Detection missing anchors".to_string(),
2287                        ));
2288                    };
2289                    let anchors_x_features = detection.shape[3];
2290                    if anchors_x_features <= num_anchors * 5 {
2291                        return Err(DecoderError::InvalidConfig(format!(
2292                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2293                            anchors_x_features,
2294                            num_anchors * 5,
2295                        )));
2296                    }
2297
2298                    if !anchors_x_features.is_multiple_of(num_anchors) {
2299                        return Err(DecoderError::InvalidConfig(format!(
2300                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2301                            anchors_x_features, num_anchors
2302                        )));
2303                    }
2304                    Ok(anchors_x_features / num_anchors - 5)
2305                }
2306            },
2307
2308            ConfigOutputRef::Scores(scores) => match scores.decoder {
2309                DecoderType::Ultralytics => Ok(scores.shape[1]),
2310                DecoderType::ModelPack => Ok(scores.shape[2]),
2311            },
2312            ConfigOutputRef::Segmentation(seg) => Ok(seg.shape[3]),
2313            _ => Err(DecoderError::Internal(
2314                "Attempted to get class count from unsupported config output".to_owned(),
2315            )),
2316        }
2317    }
2318
2319    // get the class count from dshape or calculate from num_features
2320    fn get_class_count(
2321        dshape: &[(DimName, usize)],
2322        protos: Option<usize>,
2323        anchors: Option<usize>,
2324    ) -> Result<usize, DecoderError> {
2325        if dshape.is_empty() {
2326            return Ok(0);
2327        }
2328        // if it has num_classes in dshape, return it
2329        for (dim_name, dim_size) in dshape {
2330            if *dim_name == DimName::NumClasses {
2331                return Ok(*dim_size);
2332            }
2333        }
2334
2335        // number of classes can be calculated from num_features - 4 for yolo.  If the
2336        // model has protos, we also subtract the number of protos.
2337        for (dim_name, dim_size) in dshape {
2338            if *dim_name == DimName::NumFeatures {
2339                let protos = protos.unwrap_or(0);
2340                if protos + 4 >= *dim_size {
2341                    return Err(DecoderError::InvalidConfig(format!(
2342                        "Invalid shape: Yolo num_features {} must be greater than {}",
2343                        *dim_size,
2344                        protos + 4,
2345                    )));
2346                }
2347                return Ok(*dim_size - 4 - protos);
2348            }
2349        }
2350
2351        // number of classes can be calculated from number of anchors for modelpack
2352        // split detection
2353        if let Some(num_anchors) = anchors {
2354            for (dim_name, dim_size) in dshape {
2355                if *dim_name == DimName::NumAnchorsXFeatures {
2356                    let anchors_x_features = *dim_size;
2357                    if anchors_x_features <= num_anchors * 5 {
2358                        return Err(DecoderError::InvalidConfig(format!(
2359                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2360                            anchors_x_features,
2361                            num_anchors * 5,
2362                        )));
2363                    }
2364
2365                    if !anchors_x_features.is_multiple_of(num_anchors) {
2366                        return Err(DecoderError::InvalidConfig(format!(
2367                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2368                            anchors_x_features, num_anchors
2369                        )));
2370                    }
2371                    return Ok((anchors_x_features / num_anchors) - 5);
2372                }
2373            }
2374        }
2375        Err(DecoderError::InvalidConfig(
2376            "Cannot determine number of classes from dshape".to_owned(),
2377        ))
2378    }
2379
2380    fn get_protos_count(dshape: &[(DimName, usize)]) -> Option<usize> {
2381        for (dim_name, dim_size) in dshape {
2382            if *dim_name == DimName::NumProtos {
2383                return Some(*dim_size);
2384            }
2385        }
2386        None
2387    }
2388}
2389
/// Decodes raw model output tensors into detection boxes and segmentation
/// masks according to a parsed [`ModelType`].
///
/// Instances are produced by `DecoderBuilder` (see the `build()` calls in the
/// method examples). The public threshold and NMS fields are read by the
/// `decode_*` methods and handed to the underlying decode routines.
#[derive(Debug, Clone, PartialEq)]
pub struct Decoder {
    /// Output layout this decoder was configured for. Every `decode_*` method
    /// matches on it to choose which tensors to look for and how to decode them.
    model_type: ModelType,
    /// IoU threshold passed to the box-decoding routines (for example
    /// `decode_modelpack_det` and `decode_yolo_det`).
    pub iou_threshold: f32,
    /// Score threshold passed to the box-decoding routines (for example
    /// `decode_modelpack_det` and `decode_yolo_det`).
    pub score_threshold: f32,
    /// NMS mode: Some(mode) applies NMS, None bypasses NMS (for end-to-end
    /// models)
    pub nms: Option<configs::Nms>,
    /// Whether decoded boxes are in normalized [0,1] coordinates.
    /// - `Some(true)`: Coordinates in [0,1] range
    /// - `Some(false)`: Pixel coordinates
    /// - `None`: Unknown, caller must infer (e.g., check if any coordinate >
    ///   1.0)
    normalized: Option<bool>,
}
2405
/// Borrowed, dynamic-rank view of a quantized (integer) model output tensor.
///
/// Each variant wraps an [`ArrayViewD`] of one supported element type. Any
/// `ArrayView` of `u8`, `i8`, `u16`, `i16`, `u32`, or `i32`, of any
/// dimensionality, converts into this type via `From`/`into()`. This is how
/// the inputs to [`Decoder::decode_quantized`] and
/// [`Decoder::decode_quantized_proto`] are built in the examples.
#[derive(Debug)]
pub enum ArrayViewDQuantized<'a> {
    /// View over unsigned 8-bit elements.
    UInt8(ArrayViewD<'a, u8>),
    /// View over signed 8-bit elements.
    Int8(ArrayViewD<'a, i8>),
    /// View over unsigned 16-bit elements.
    UInt16(ArrayViewD<'a, u16>),
    /// View over signed 16-bit elements.
    Int16(ArrayViewD<'a, i16>),
    /// View over unsigned 32-bit elements.
    UInt32(ArrayViewD<'a, u32>),
    /// View over signed 32-bit elements.
    Int32(ArrayViewD<'a, i32>),
}
2415
2416impl<'a, D> From<ArrayView<'a, u8, D>> for ArrayViewDQuantized<'a>
2417where
2418    D: Dimension,
2419{
2420    fn from(arr: ArrayView<'a, u8, D>) -> Self {
2421        Self::UInt8(arr.into_dyn())
2422    }
2423}
2424
2425impl<'a, D> From<ArrayView<'a, i8, D>> for ArrayViewDQuantized<'a>
2426where
2427    D: Dimension,
2428{
2429    fn from(arr: ArrayView<'a, i8, D>) -> Self {
2430        Self::Int8(arr.into_dyn())
2431    }
2432}
2433
2434impl<'a, D> From<ArrayView<'a, u16, D>> for ArrayViewDQuantized<'a>
2435where
2436    D: Dimension,
2437{
2438    fn from(arr: ArrayView<'a, u16, D>) -> Self {
2439        Self::UInt16(arr.into_dyn())
2440    }
2441}
2442
2443impl<'a, D> From<ArrayView<'a, i16, D>> for ArrayViewDQuantized<'a>
2444where
2445    D: Dimension,
2446{
2447    fn from(arr: ArrayView<'a, i16, D>) -> Self {
2448        Self::Int16(arr.into_dyn())
2449    }
2450}
2451
2452impl<'a, D> From<ArrayView<'a, u32, D>> for ArrayViewDQuantized<'a>
2453where
2454    D: Dimension,
2455{
2456    fn from(arr: ArrayView<'a, u32, D>) -> Self {
2457        Self::UInt32(arr.into_dyn())
2458    }
2459}
2460
2461impl<'a, D> From<ArrayView<'a, i32, D>> for ArrayViewDQuantized<'a>
2462where
2463    D: Dimension,
2464{
2465    fn from(arr: ArrayView<'a, i32, D>) -> Self {
2466        Self::Int32(arr.into_dyn())
2467    }
2468}
2469
2470impl<'a> ArrayViewDQuantized<'a> {
2471    /// Returns the shape of the underlying array.
2472    ///
2473    /// # Examples
2474    /// ```rust
2475    /// # use edgefirst_decoder::ArrayViewDQuantized;
2476    /// # use ndarray::Array2;
2477    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
2478    /// let arr = Array2::from_shape_vec((2, 3), vec![1u8, 2, 3, 4, 5, 6])?;
2479    /// let view = ArrayViewDQuantized::from(arr.view().into_dyn());
2480    /// assert_eq!(view.shape(), &[2, 3]);
2481    /// # Ok(())
2482    /// # }
2483    /// ```
2484    pub fn shape(&self) -> &[usize] {
2485        match self {
2486            ArrayViewDQuantized::UInt8(a) => a.shape(),
2487            ArrayViewDQuantized::Int8(a) => a.shape(),
2488            ArrayViewDQuantized::UInt16(a) => a.shape(),
2489            ArrayViewDQuantized::Int16(a) => a.shape(),
2490            ArrayViewDQuantized::UInt32(a) => a.shape(),
2491            ArrayViewDQuantized::Int32(a) => a.shape(),
2492        }
2493    }
2494}
2495
/// Evaluates `$body` with `$var` bound to the typed `ArrayViewD` held by
/// `$x`. The macro expands to one match arm per integer variant of
/// [`ArrayViewDQuantized`].
///
/// WARNING: Do NOT nest `with_quantized!` calls. Each level multiplies
/// monomorphized code paths by 6 (one per integer variant), so nesting
/// N levels deep produces 6^N instantiations.
///
/// Instead, dequantize each tensor sequentially with `dequant_3d!`/`dequant_4d!`
/// (6*N paths) or split into independent phases that each nest at most 2 levels.
macro_rules! with_quantized {
    ($x:expr, $var:ident, $body:expr) => {
        match $x {
            ArrayViewDQuantized::UInt8($var) => $body,
            ArrayViewDQuantized::Int8($var) => $body,
            ArrayViewDQuantized::UInt16($var) => $body,
            ArrayViewDQuantized::Int16($var) => $body,
            ArrayViewDQuantized::UInt32($var) => $body,
            ArrayViewDQuantized::Int32($var) => $body,
        }
    };
}
2532
2533impl Decoder {
2534    /// This function returns the parsed model type of the decoder.
2535    ///
2536    /// # Examples
2537    ///
2538    /// ```rust
2539    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::ModelType};
2540    /// # fn main() -> DecoderResult<()> {
2541    /// #    let config_yaml = include_str!("../../../testdata/modelpack_split.yaml").to_string();
2542    ///     let decoder = DecoderBuilder::default()
2543    ///         .with_config_yaml_str(config_yaml)
2544    ///         .build()?;
2545    ///     assert!(matches!(
2546    ///         decoder.model_type(),
2547    ///         ModelType::ModelPackDetSplit { .. }
2548    ///     ));
2549    /// #    Ok(())
2550    /// # }
2551    /// ```
2552    pub fn model_type(&self) -> &ModelType {
2553        &self.model_type
2554    }
2555
2556    /// Returns the box coordinate format if known from the model config.
2557    ///
2558    /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
2559    /// - `Some(false)`: Boxes are in pixel coordinates relative to model input
2560    /// - `None`: Unknown, caller must infer (e.g., check if any coordinate >
2561    ///   1.0)
2562    ///
2563    /// This is determined by the model config's `normalized` field, not the NMS
2564    /// mode. When coordinates are in pixels or unknown, the caller may need
2565    /// to normalize using the model input dimensions.
2566    ///
2567    /// # Examples
2568    ///
2569    /// ```rust
2570    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
2571    /// # fn main() -> DecoderResult<()> {
2572    /// #    let config_yaml = include_str!("../../../testdata/modelpack_split.yaml").to_string();
2573    ///     let decoder = DecoderBuilder::default()
2574    ///         .with_config_yaml_str(config_yaml)
2575    ///         .build()?;
2576    ///     // Config doesn't specify normalized, so it's None
2577    ///     assert!(decoder.normalized_boxes().is_none());
2578    /// #    Ok(())
2579    /// # }
2580    /// ```
2581    pub fn normalized_boxes(&self) -> Option<bool> {
2582        self.normalized
2583    }
2584
2585    /// This function decodes quantized model outputs into detection boxes and
2586    /// segmentation masks. The quantized outputs can be of u8, i8, u16, i16,
2587    /// u32, or i32 types. Up to `output_boxes.capacity()` boxes and masks
2588    /// will be decoded. The function clears the provided output vectors
2589    /// before populating them with the decoded results.
2590    ///
2591    /// This function returns a `DecoderError` if the the provided outputs don't
2592    /// match the configuration provided by the user when building the decoder.
2593    ///
2594    /// # Examples
2595    ///
2596    /// ```rust
2597    /// # use edgefirst_decoder::{BoundingBox, DecoderBuilder, DetectBox, DecoderResult};
2598    /// # use ndarray::Array4;
2599    /// # fn main() -> DecoderResult<()> {
2600    /// #    let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
2601    /// #    let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec())?;
2602    /// #
2603    /// #    let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
2604    /// #    let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec())?;
2605    /// #    let model_output = vec![
2606    /// #        detect1.view().into_dyn().into(),
2607    /// #        detect0.view().into_dyn().into(),
2608    /// #    ];
2609    /// let decoder = DecoderBuilder::default()
2610    ///     .with_config_yaml_str(include_str!("../../../testdata/modelpack_split.yaml").to_string())
2611    ///     .with_score_threshold(0.45)
2612    ///     .with_iou_threshold(0.45)
2613    ///     .build()?;
2614    ///
2615    /// let mut output_boxes: Vec<_> = Vec::with_capacity(10);
2616    /// let mut output_masks: Vec<_> = Vec::with_capacity(10);
2617    /// decoder.decode_quantized(&model_output, &mut output_boxes, &mut output_masks)?;
2618    /// assert!(output_boxes[0].equal_within_delta(
2619    ///     &DetectBox {
2620    ///         bbox: BoundingBox {
2621    ///             xmin: 0.43171933,
2622    ///             ymin: 0.68243736,
2623    ///             xmax: 0.5626645,
2624    ///             ymax: 0.808863,
2625    ///         },
2626    ///         score: 0.99240804,
2627    ///         label: 0
2628    ///     },
2629    ///     1e-6
2630    /// ));
2631    /// #    Ok(())
2632    /// # }
2633    /// ```
2634    pub fn decode_quantized(
2635        &self,
2636        outputs: &[ArrayViewDQuantized],
2637        output_boxes: &mut Vec<DetectBox>,
2638        output_masks: &mut Vec<Segmentation>,
2639    ) -> Result<(), DecoderError> {
2640        output_boxes.clear();
2641        output_masks.clear();
2642        match &self.model_type {
2643            ModelType::ModelPackSegDet {
2644                boxes,
2645                scores,
2646                segmentation,
2647            } => {
2648                self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
2649                self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
2650            }
2651            ModelType::ModelPackSegDetSplit {
2652                detection,
2653                segmentation,
2654            } => {
2655                self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
2656                self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
2657            }
2658            ModelType::ModelPackDet { boxes, scores } => {
2659                self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)
2660            }
2661            ModelType::ModelPackDetSplit { detection } => {
2662                self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)
2663            }
2664            ModelType::ModelPackSeg { segmentation } => {
2665                self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
2666            }
2667            ModelType::YoloDet { boxes } => {
2668                self.decode_yolo_det_quantized(outputs, boxes, output_boxes)
2669            }
2670            ModelType::YoloSegDet { boxes, protos } => self.decode_yolo_segdet_quantized(
2671                outputs,
2672                boxes,
2673                protos,
2674                output_boxes,
2675                output_masks,
2676            ),
2677            ModelType::YoloSplitDet { boxes, scores } => {
2678                self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)
2679            }
2680            ModelType::YoloSplitSegDet {
2681                boxes,
2682                scores,
2683                mask_coeff,
2684                protos,
2685            } => self.decode_yolo_split_segdet_quantized(
2686                outputs,
2687                boxes,
2688                scores,
2689                mask_coeff,
2690                protos,
2691                output_boxes,
2692                output_masks,
2693            ),
2694            ModelType::YoloEndToEndDet { boxes } => {
2695                self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)
2696            }
2697            ModelType::YoloEndToEndSegDet { boxes, protos } => self
2698                .decode_yolo_end_to_end_segdet_quantized(
2699                    outputs,
2700                    boxes,
2701                    protos,
2702                    output_boxes,
2703                    output_masks,
2704                ),
2705            ModelType::YoloSplitEndToEndDet {
2706                boxes,
2707                scores,
2708                classes,
2709            } => self.decode_yolo_split_end_to_end_det_quantized(
2710                outputs,
2711                boxes,
2712                scores,
2713                classes,
2714                output_boxes,
2715            ),
2716            ModelType::YoloSplitEndToEndSegDet {
2717                boxes,
2718                scores,
2719                classes,
2720                mask_coeff,
2721                protos,
2722            } => self.decode_yolo_split_end_to_end_segdet_quantized(
2723                outputs,
2724                boxes,
2725                scores,
2726                classes,
2727                mask_coeff,
2728                protos,
2729                output_boxes,
2730                output_masks,
2731            ),
2732        }
2733    }
2734
2735    /// This function decodes floating point model outputs into detection boxes
2736    /// and segmentation masks. Up to `output_boxes.capacity()` boxes and
2737    /// masks will be decoded. The function clears the provided output
2738    /// vectors before populating them with the decoded results.
2739    ///
2740    /// This function returns an `Error` if the the provided outputs don't
2741    /// match the configuration provided by the user when building the decoder.
2742    ///
2743    /// Any quantization information in the configuration will be ignored.
2744    ///
2745    /// # Examples
2746    ///
2747    /// ```rust
2748    /// # use edgefirst_decoder::{BoundingBox, DecoderBuilder, DetectBox, DecoderResult, configs, configs::{DecoderType, DecoderVersion}, dequantize_cpu, Quantization};
2749    /// # use ndarray::Array3;
2750    /// # fn main() -> DecoderResult<()> {
2751    /// #   let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
2752    /// #   let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
2753    /// #   let mut out_dequant = vec![0.0_f64; 84 * 8400];
2754    /// #   let quant = Quantization::new(0.0040811873, -123);
2755    /// #   dequantize_cpu(out, quant, &mut out_dequant);
2756    /// #   let model_output_f64 = Array3::from_shape_vec((1, 84, 8400), out_dequant)?.into_dyn();
2757    ///    let decoder = DecoderBuilder::default()
2758    ///     .with_config_yolo_det(configs::Detection {
2759    ///         decoder: DecoderType::Ultralytics,
2760    ///         quantization: None,
2761    ///         shape: vec![1, 84, 8400],
2762    ///         anchors: None,
2763    ///         dshape: Vec::new(),
2764    ///         normalized: Some(true),
2765    ///     },
2766    ///     Some(DecoderVersion::Yolo11))
2767    ///     .with_score_threshold(0.25)
2768    ///     .with_iou_threshold(0.7)
2769    ///     .build()?;
2770    ///
2771    /// let mut output_boxes: Vec<_> = Vec::with_capacity(10);
2772    /// let mut output_masks: Vec<_> = Vec::with_capacity(10);
2773    /// let model_output_f64 = vec![model_output_f64.view().into()];
2774    /// decoder.decode_float(&model_output_f64, &mut output_boxes, &mut output_masks)?;    
2775    /// assert!(output_boxes[0].equal_within_delta(
2776    ///        &DetectBox {
2777    ///            bbox: BoundingBox {
2778    ///                xmin: 0.5285137,
2779    ///                ymin: 0.05305544,
2780    ///                xmax: 0.87541467,
2781    ///                ymax: 0.9998909,
2782    ///            },
2783    ///            score: 0.5591227,
2784    ///            label: 0
2785    ///        },
2786    ///        1e-6
2787    ///    ));
2788    ///
2789    /// #    Ok(())
2790    /// # }
2791    pub fn decode_float<T>(
2792        &self,
2793        outputs: &[ArrayViewD<T>],
2794        output_boxes: &mut Vec<DetectBox>,
2795        output_masks: &mut Vec<Segmentation>,
2796    ) -> Result<(), DecoderError>
2797    where
2798        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
2799        f32: AsPrimitive<T>,
2800    {
2801        output_boxes.clear();
2802        output_masks.clear();
2803        match &self.model_type {
2804            ModelType::ModelPackSegDet {
2805                boxes,
2806                scores,
2807                segmentation,
2808            } => {
2809                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
2810                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
2811            }
2812            ModelType::ModelPackSegDetSplit {
2813                detection,
2814                segmentation,
2815            } => {
2816                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
2817                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
2818            }
2819            ModelType::ModelPackDet { boxes, scores } => {
2820                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
2821            }
2822            ModelType::ModelPackDetSplit { detection } => {
2823                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
2824            }
2825            ModelType::ModelPackSeg { segmentation } => {
2826                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
2827            }
2828            ModelType::YoloDet { boxes } => {
2829                self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
2830            }
2831            ModelType::YoloSegDet { boxes, protos } => {
2832                self.decode_yolo_segdet_float(outputs, boxes, protos, output_boxes, output_masks)?;
2833            }
2834            ModelType::YoloSplitDet { boxes, scores } => {
2835                self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
2836            }
2837            ModelType::YoloSplitSegDet {
2838                boxes,
2839                scores,
2840                mask_coeff,
2841                protos,
2842            } => {
2843                self.decode_yolo_split_segdet_float(
2844                    outputs,
2845                    boxes,
2846                    scores,
2847                    mask_coeff,
2848                    protos,
2849                    output_boxes,
2850                    output_masks,
2851                )?;
2852            }
2853            ModelType::YoloEndToEndDet { boxes } => {
2854                self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
2855            }
2856            ModelType::YoloEndToEndSegDet { boxes, protos } => {
2857                self.decode_yolo_end_to_end_segdet_float(
2858                    outputs,
2859                    boxes,
2860                    protos,
2861                    output_boxes,
2862                    output_masks,
2863                )?;
2864            }
2865            ModelType::YoloSplitEndToEndDet {
2866                boxes,
2867                scores,
2868                classes,
2869            } => {
2870                self.decode_yolo_split_end_to_end_det_float(
2871                    outputs,
2872                    boxes,
2873                    scores,
2874                    classes,
2875                    output_boxes,
2876                )?;
2877            }
2878            ModelType::YoloSplitEndToEndSegDet {
2879                boxes,
2880                scores,
2881                classes,
2882                mask_coeff,
2883                protos,
2884            } => {
2885                self.decode_yolo_split_end_to_end_segdet_float(
2886                    outputs,
2887                    boxes,
2888                    scores,
2889                    classes,
2890                    mask_coeff,
2891                    protos,
2892                    output_boxes,
2893                    output_masks,
2894                )?;
2895            }
2896        }
2897        Ok(())
2898    }
2899
2900    /// Decodes quantized model outputs into detection boxes, returning raw
2901    /// `ProtoData` for segmentation models instead of materialized masks.
2902    ///
2903    /// Returns `Ok(None)` for detection-only and ModelPack models (use
2904    /// `decode_quantized` for those). Returns `Ok(Some(ProtoData))` for
2905    /// YOLO segmentation models.
2906    pub fn decode_quantized_proto(
2907        &self,
2908        outputs: &[ArrayViewDQuantized],
2909        output_boxes: &mut Vec<DetectBox>,
2910    ) -> Result<Option<ProtoData>, DecoderError> {
2911        output_boxes.clear();
2912        match &self.model_type {
2913            // Detection-only and ModelPack variants: no proto data
2914            ModelType::ModelPackSegDet { .. }
2915            | ModelType::ModelPackSegDetSplit { .. }
2916            | ModelType::ModelPackDet { .. }
2917            | ModelType::ModelPackDetSplit { .. }
2918            | ModelType::ModelPackSeg { .. }
2919            | ModelType::YoloDet { .. }
2920            | ModelType::YoloSplitDet { .. }
2921            | ModelType::YoloEndToEndDet { .. }
2922            | ModelType::YoloSplitEndToEndDet { .. } => Ok(None),
2923
2924            ModelType::YoloSegDet { boxes, protos } => {
2925                let proto =
2926                    self.decode_yolo_segdet_quantized_proto(outputs, boxes, protos, output_boxes)?;
2927                Ok(Some(proto))
2928            }
2929            ModelType::YoloSplitSegDet {
2930                boxes,
2931                scores,
2932                mask_coeff,
2933                protos,
2934            } => {
2935                let proto = self.decode_yolo_split_segdet_quantized_proto(
2936                    outputs,
2937                    boxes,
2938                    scores,
2939                    mask_coeff,
2940                    protos,
2941                    output_boxes,
2942                )?;
2943                Ok(Some(proto))
2944            }
2945            ModelType::YoloEndToEndSegDet { boxes, protos } => {
2946                let proto = self.decode_yolo_end_to_end_segdet_quantized_proto(
2947                    outputs,
2948                    boxes,
2949                    protos,
2950                    output_boxes,
2951                )?;
2952                Ok(Some(proto))
2953            }
2954            ModelType::YoloSplitEndToEndSegDet {
2955                boxes,
2956                scores,
2957                classes,
2958                mask_coeff,
2959                protos,
2960            } => {
2961                let proto = self.decode_yolo_split_end_to_end_segdet_quantized_proto(
2962                    outputs,
2963                    boxes,
2964                    scores,
2965                    classes,
2966                    mask_coeff,
2967                    protos,
2968                    output_boxes,
2969                )?;
2970                Ok(Some(proto))
2971            }
2972        }
2973    }
2974
2975    /// Decodes floating-point model outputs into detection boxes, returning
2976    /// raw `ProtoData` for segmentation models instead of materialized masks.
2977    ///
2978    /// Returns `Ok(None)` for detection-only and ModelPack models. Returns
2979    /// `Ok(Some(ProtoData))` for YOLO segmentation models.
2980    pub fn decode_float_proto<T>(
2981        &self,
2982        outputs: &[ArrayViewD<T>],
2983        output_boxes: &mut Vec<DetectBox>,
2984    ) -> Result<Option<ProtoData>, DecoderError>
2985    where
2986        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
2987        f32: AsPrimitive<T>,
2988    {
2989        output_boxes.clear();
2990        match &self.model_type {
2991            // Detection-only and ModelPack variants: no proto data
2992            ModelType::ModelPackSegDet { .. }
2993            | ModelType::ModelPackSegDetSplit { .. }
2994            | ModelType::ModelPackDet { .. }
2995            | ModelType::ModelPackDetSplit { .. }
2996            | ModelType::ModelPackSeg { .. }
2997            | ModelType::YoloDet { .. }
2998            | ModelType::YoloSplitDet { .. }
2999            | ModelType::YoloEndToEndDet { .. }
3000            | ModelType::YoloSplitEndToEndDet { .. } => Ok(None),
3001
3002            ModelType::YoloSegDet { boxes, protos } => {
3003                let proto =
3004                    self.decode_yolo_segdet_float_proto(outputs, boxes, protos, output_boxes)?;
3005                Ok(Some(proto))
3006            }
3007            ModelType::YoloSplitSegDet {
3008                boxes,
3009                scores,
3010                mask_coeff,
3011                protos,
3012            } => {
3013                let proto = self.decode_yolo_split_segdet_float_proto(
3014                    outputs,
3015                    boxes,
3016                    scores,
3017                    mask_coeff,
3018                    protos,
3019                    output_boxes,
3020                )?;
3021                Ok(Some(proto))
3022            }
3023            ModelType::YoloEndToEndSegDet { boxes, protos } => {
3024                let proto = self.decode_yolo_end_to_end_segdet_float_proto(
3025                    outputs,
3026                    boxes,
3027                    protos,
3028                    output_boxes,
3029                )?;
3030                Ok(Some(proto))
3031            }
3032            ModelType::YoloSplitEndToEndSegDet {
3033                boxes,
3034                scores,
3035                classes,
3036                mask_coeff,
3037                protos,
3038            } => {
3039                let proto = self.decode_yolo_split_end_to_end_segdet_float_proto(
3040                    outputs,
3041                    boxes,
3042                    scores,
3043                    classes,
3044                    mask_coeff,
3045                    protos,
3046                    output_boxes,
3047                )?;
3048                Ok(Some(proto))
3049            }
3050        }
3051    }
3052
3053    fn decode_modelpack_det_quantized(
3054        &self,
3055        outputs: &[ArrayViewDQuantized],
3056        boxes: &configs::Boxes,
3057        scores: &configs::Scores,
3058        output_boxes: &mut Vec<DetectBox>,
3059    ) -> Result<(), DecoderError> {
3060        let (boxes_tensor, ind) =
3061            Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
3062        let (scores_tensor, _) =
3063            Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &[ind])?;
3064        let quant_boxes = boxes
3065            .quantization
3066            .map(Quantization::from)
3067            .unwrap_or_default();
3068        let quant_scores = scores
3069            .quantization
3070            .map(Quantization::from)
3071            .unwrap_or_default();
3072
3073        with_quantized!(boxes_tensor, b, {
3074            with_quantized!(scores_tensor, s, {
3075                let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
3076                let boxes_tensor = boxes_tensor.slice(s![0, .., 0, ..]);
3077
3078                let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
3079                let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3080                decode_modelpack_det(
3081                    (boxes_tensor, quant_boxes),
3082                    (scores_tensor, quant_scores),
3083                    self.score_threshold,
3084                    self.iou_threshold,
3085                    output_boxes,
3086                );
3087            });
3088        });
3089
3090        Ok(())
3091    }
3092
3093    fn decode_modelpack_seg_quantized(
3094        &self,
3095        outputs: &[ArrayViewDQuantized],
3096        segmentation: &configs::Segmentation,
3097        output_masks: &mut Vec<Segmentation>,
3098    ) -> Result<(), DecoderError> {
3099        let (seg, _) = Self::find_outputs_with_shape_quantized(&segmentation.shape, outputs, &[])?;
3100
3101        macro_rules! modelpack_seg {
3102            ($seg:expr, $body:expr) => {{
3103                let seg = Self::swap_axes_if_needed($seg, segmentation.into());
3104                let seg = seg.slice(s![0, .., .., ..]);
3105                seg.mapv($body)
3106            }};
3107        }
3108        use ArrayViewDQuantized::*;
3109        let seg = match seg {
3110            UInt8(s) => {
3111                modelpack_seg!(s, |x| x)
3112            }
3113            Int8(s) => {
3114                modelpack_seg!(s, |x| (x as i16 + 128) as u8)
3115            }
3116            UInt16(s) => {
3117                modelpack_seg!(s, |x| (x >> 8) as u8)
3118            }
3119            Int16(s) => {
3120                modelpack_seg!(s, |x| ((x as i32 + 32768) >> 8) as u8)
3121            }
3122            UInt32(s) => {
3123                modelpack_seg!(s, |x| (x >> 24) as u8)
3124            }
3125            Int32(s) => {
3126                modelpack_seg!(s, |x| ((x as i64 + 2147483648) >> 24) as u8)
3127            }
3128        };
3129
3130        output_masks.push(Segmentation {
3131            xmin: 0.0,
3132            ymin: 0.0,
3133            xmax: 1.0,
3134            ymax: 1.0,
3135            segmentation: seg,
3136        });
3137        Ok(())
3138    }
3139
3140    fn decode_modelpack_det_split_quantized(
3141        &self,
3142        outputs: &[ArrayViewDQuantized],
3143        detection: &[configs::Detection],
3144        output_boxes: &mut Vec<DetectBox>,
3145    ) -> Result<(), DecoderError> {
3146        let new_detection = detection
3147            .iter()
3148            .map(|x| match &x.anchors {
3149                None => Err(DecoderError::InvalidConfig(
3150                    "ModelPack Split Detection missing anchors".to_string(),
3151                )),
3152                Some(a) => Ok(ModelPackDetectionConfig {
3153                    anchors: a.clone(),
3154                    quantization: None,
3155                }),
3156            })
3157            .collect::<Result<Vec<_>, _>>()?;
3158        let new_outputs = Self::match_outputs_to_detect_quantized(detection, outputs)?;
3159
3160        macro_rules! dequant_output {
3161            ($det_tensor:expr, $detection:expr) => {{
3162                let det_tensor = Self::swap_axes_if_needed($det_tensor, $detection.into());
3163                let det_tensor = det_tensor.slice(s![0, .., .., ..]);
3164                if let Some(q) = $detection.quantization {
3165                    dequantize_ndarray(det_tensor, q.into())
3166                } else {
3167                    det_tensor.map(|x| *x as f32)
3168                }
3169            }};
3170        }
3171
3172        let new_outputs = new_outputs
3173            .iter()
3174            .zip(detection)
3175            .map(|(det_tensor, detection)| {
3176                with_quantized!(det_tensor, d, dequant_output!(d, detection))
3177            })
3178            .collect::<Vec<_>>();
3179
3180        let new_outputs_view = new_outputs
3181            .iter()
3182            .map(|d: &Array3<f32>| d.view())
3183            .collect::<Vec<_>>();
3184        decode_modelpack_split_float(
3185            &new_outputs_view,
3186            &new_detection,
3187            self.score_threshold,
3188            self.iou_threshold,
3189            output_boxes,
3190        );
3191        Ok(())
3192    }
3193
3194    fn decode_yolo_det_quantized(
3195        &self,
3196        outputs: &[ArrayViewDQuantized],
3197        boxes: &configs::Detection,
3198        output_boxes: &mut Vec<DetectBox>,
3199    ) -> Result<(), DecoderError> {
3200        let (boxes_tensor, _) =
3201            Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
3202        let quant_boxes = boxes
3203            .quantization
3204            .map(Quantization::from)
3205            .unwrap_or_default();
3206
3207        with_quantized!(boxes_tensor, b, {
3208            let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
3209            let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3210            decode_yolo_det(
3211                (boxes_tensor, quant_boxes),
3212                self.score_threshold,
3213                self.iou_threshold,
3214                self.nms,
3215                output_boxes,
3216            );
3217        });
3218
3219        Ok(())
3220    }
3221
3222    fn decode_yolo_segdet_quantized(
3223        &self,
3224        outputs: &[ArrayViewDQuantized],
3225        boxes: &configs::Detection,
3226        protos: &configs::Protos,
3227        output_boxes: &mut Vec<DetectBox>,
3228        output_masks: &mut Vec<Segmentation>,
3229    ) -> Result<(), DecoderError> {
3230        let (boxes_tensor, ind) =
3231            Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
3232        let (protos_tensor, _) =
3233            Self::find_outputs_with_shape_quantized(&protos.shape, outputs, &[ind])?;
3234
3235        let quant_boxes = boxes
3236            .quantization
3237            .map(Quantization::from)
3238            .unwrap_or_default();
3239        let quant_protos = protos
3240            .quantization
3241            .map(Quantization::from)
3242            .unwrap_or_default();
3243
3244        with_quantized!(boxes_tensor, b, {
3245            with_quantized!(protos_tensor, p, {
3246                let box_tensor = Self::swap_axes_if_needed(b, boxes.into());
3247                let box_tensor = box_tensor.slice(s![0, .., ..]);
3248
3249                let protos_tensor = Self::swap_axes_if_needed(p, protos.into());
3250                let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3251                decode_yolo_segdet_quant(
3252                    (box_tensor, quant_boxes),
3253                    (protos_tensor, quant_protos),
3254                    self.score_threshold,
3255                    self.iou_threshold,
3256                    self.nms,
3257                    output_boxes,
3258                    output_masks,
3259                );
3260            });
3261        });
3262
3263        Ok(())
3264    }
3265
3266    fn decode_yolo_split_det_quantized(
3267        &self,
3268        outputs: &[ArrayViewDQuantized],
3269        boxes: &configs::Boxes,
3270        scores: &configs::Scores,
3271        output_boxes: &mut Vec<DetectBox>,
3272    ) -> Result<(), DecoderError> {
3273        let (boxes_tensor, ind) =
3274            Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
3275        let (scores_tensor, _) =
3276            Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &[ind])?;
3277        let quant_boxes = boxes
3278            .quantization
3279            .map(Quantization::from)
3280            .unwrap_or_default();
3281        let quant_scores = scores
3282            .quantization
3283            .map(Quantization::from)
3284            .unwrap_or_default();
3285
3286        with_quantized!(boxes_tensor, b, {
3287            with_quantized!(scores_tensor, s, {
3288                let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
3289                let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3290
3291                let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
3292                let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3293                decode_yolo_split_det_quant(
3294                    (boxes_tensor, quant_boxes),
3295                    (scores_tensor, quant_scores),
3296                    self.score_threshold,
3297                    self.iou_threshold,
3298                    self.nms,
3299                    output_boxes,
3300                );
3301            });
3302        });
3303
3304        Ok(())
3305    }
3306
3307    #[allow(clippy::too_many_arguments)]
3308    fn decode_yolo_split_segdet_quantized(
3309        &self,
3310        outputs: &[ArrayViewDQuantized],
3311        boxes: &configs::Boxes,
3312        scores: &configs::Scores,
3313        mask_coeff: &configs::MaskCoefficients,
3314        protos: &configs::Protos,
3315        output_boxes: &mut Vec<DetectBox>,
3316        output_masks: &mut Vec<Segmentation>,
3317    ) -> Result<(), DecoderError> {
3318        let quant_boxes = boxes
3319            .quantization
3320            .map(Quantization::from)
3321            .unwrap_or_default();
3322        let quant_scores = scores
3323            .quantization
3324            .map(Quantization::from)
3325            .unwrap_or_default();
3326        let quant_masks = mask_coeff
3327            .quantization
3328            .map(Quantization::from)
3329            .unwrap_or_default();
3330        let quant_protos = protos
3331            .quantization
3332            .map(Quantization::from)
3333            .unwrap_or_default();
3334
3335        let mut skip = vec![];
3336
3337        let (boxes_tensor, ind) =
3338            Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &skip)?;
3339        skip.push(ind);
3340
3341        let (scores_tensor, ind) =
3342            Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &skip)?;
3343        skip.push(ind);
3344
3345        let (mask_tensor, ind) =
3346            Self::find_outputs_with_shape_quantized(&mask_coeff.shape, outputs, &skip)?;
3347        skip.push(ind);
3348
3349        let (protos_tensor, _) =
3350            Self::find_outputs_with_shape_quantized(&protos.shape, outputs, &skip)?;
3351
3352        let boxes = with_quantized!(boxes_tensor, b, {
3353            with_quantized!(scores_tensor, s, {
3354                let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
3355                let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3356
3357                let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
3358                let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3359                impl_yolo_split_segdet_quant_get_boxes::<XYWH, _, _>(
3360                    (boxes_tensor, quant_boxes),
3361                    (scores_tensor, quant_scores),
3362                    self.score_threshold,
3363                    self.iou_threshold,
3364                    self.nms,
3365                    output_boxes.capacity(),
3366                )
3367            })
3368        });
3369
3370        with_quantized!(mask_tensor, m, {
3371            with_quantized!(protos_tensor, p, {
3372                let mask_tensor = Self::swap_axes_if_needed(m, mask_coeff.into());
3373                let mask_tensor = mask_tensor.slice(s![0, .., ..]);
3374
3375                let protos_tensor = Self::swap_axes_if_needed(p, protos.into());
3376                let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3377                impl_yolo_split_segdet_quant_process_masks::<_, _>(
3378                    boxes,
3379                    (mask_tensor, quant_masks),
3380                    (protos_tensor, quant_protos),
3381                    output_boxes,
3382                    output_masks,
3383                )
3384            })
3385        });
3386
3387        Ok(())
3388    }
3389
3390    fn decode_modelpack_det_split_float<D>(
3391        &self,
3392        outputs: &[ArrayViewD<D>],
3393        detection: &[configs::Detection],
3394        output_boxes: &mut Vec<DetectBox>,
3395    ) -> Result<(), DecoderError>
3396    where
3397        D: AsPrimitive<f32>,
3398    {
3399        let new_detection = detection
3400            .iter()
3401            .map(|x| match &x.anchors {
3402                None => Err(DecoderError::InvalidConfig(
3403                    "ModelPack Split Detection missing anchors".to_string(),
3404                )),
3405                Some(a) => Ok(ModelPackDetectionConfig {
3406                    anchors: a.clone(),
3407                    quantization: None,
3408                }),
3409            })
3410            .collect::<Result<Vec<_>, _>>()?;
3411
3412        let new_outputs = Self::match_outputs_to_detect(detection, outputs)?;
3413        let new_outputs = new_outputs
3414            .into_iter()
3415            .map(|x| x.slice(s![0, .., .., ..]))
3416            .collect::<Vec<_>>();
3417
3418        decode_modelpack_split_float(
3419            &new_outputs,
3420            &new_detection,
3421            self.score_threshold,
3422            self.iou_threshold,
3423            output_boxes,
3424        );
3425        Ok(())
3426    }
3427
3428    fn decode_modelpack_seg_float<T>(
3429        &self,
3430        outputs: &[ArrayViewD<T>],
3431        segmentation: &configs::Segmentation,
3432        output_masks: &mut Vec<Segmentation>,
3433    ) -> Result<(), DecoderError>
3434    where
3435        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
3436        f32: AsPrimitive<T>,
3437    {
3438        let (seg, _) = Self::find_outputs_with_shape(&segmentation.shape, outputs, &[])?;
3439
3440        let seg = Self::swap_axes_if_needed(seg, segmentation.into());
3441        let seg = seg.slice(s![0, .., .., ..]);
3442        let u8_max = 255.0_f32.as_();
3443        let max = *seg.max().unwrap_or(&u8_max);
3444        let min = *seg.min().unwrap_or(&0.0_f32.as_());
3445        let seg = seg.mapv(|x| ((x - min) / (max - min) * u8_max).as_());
3446        output_masks.push(Segmentation {
3447            xmin: 0.0,
3448            ymin: 0.0,
3449            xmax: 1.0,
3450            ymax: 1.0,
3451            segmentation: seg,
3452        });
3453        Ok(())
3454    }
3455
3456    fn decode_modelpack_det_float<T>(
3457        &self,
3458        outputs: &[ArrayViewD<T>],
3459        boxes: &configs::Boxes,
3460        scores: &configs::Scores,
3461        output_boxes: &mut Vec<DetectBox>,
3462    ) -> Result<(), DecoderError>
3463    where
3464        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3465        f32: AsPrimitive<T>,
3466    {
3467        let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
3468
3469        let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3470        let boxes_tensor = boxes_tensor.slice(s![0, .., 0, ..]);
3471
3472        let (scores_tensor, _) = Self::find_outputs_with_shape(&scores.shape, outputs, &[ind])?;
3473        let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
3474        let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3475
3476        decode_modelpack_float(
3477            boxes_tensor,
3478            scores_tensor,
3479            self.score_threshold,
3480            self.iou_threshold,
3481            output_boxes,
3482        );
3483        Ok(())
3484    }
3485
3486    fn decode_yolo_det_float<T>(
3487        &self,
3488        outputs: &[ArrayViewD<T>],
3489        boxes: &configs::Detection,
3490        output_boxes: &mut Vec<DetectBox>,
3491    ) -> Result<(), DecoderError>
3492    where
3493        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3494        f32: AsPrimitive<T>,
3495    {
3496        let (boxes_tensor, _) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
3497
3498        let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3499        let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3500        decode_yolo_det_float(
3501            boxes_tensor,
3502            self.score_threshold,
3503            self.iou_threshold,
3504            self.nms,
3505            output_boxes,
3506        );
3507        Ok(())
3508    }
3509
3510    fn decode_yolo_segdet_float<T>(
3511        &self,
3512        outputs: &[ArrayViewD<T>],
3513        boxes: &configs::Detection,
3514        protos: &configs::Protos,
3515        output_boxes: &mut Vec<DetectBox>,
3516        output_masks: &mut Vec<Segmentation>,
3517    ) -> Result<(), DecoderError>
3518    where
3519        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3520        f32: AsPrimitive<T>,
3521    {
3522        let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
3523
3524        let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3525        let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3526
3527        let (protos_tensor, _) = Self::find_outputs_with_shape(&protos.shape, outputs, &[ind])?;
3528
3529        let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos.into());
3530        let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3531        decode_yolo_segdet_float(
3532            boxes_tensor,
3533            protos_tensor,
3534            self.score_threshold,
3535            self.iou_threshold,
3536            self.nms,
3537            output_boxes,
3538            output_masks,
3539        );
3540        Ok(())
3541    }
3542
3543    fn decode_yolo_split_det_float<T>(
3544        &self,
3545        outputs: &[ArrayViewD<T>],
3546        boxes: &configs::Boxes,
3547        scores: &configs::Scores,
3548        output_boxes: &mut Vec<DetectBox>,
3549    ) -> Result<(), DecoderError>
3550    where
3551        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3552        f32: AsPrimitive<T>,
3553    {
3554        let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
3555        let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3556        let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3557
3558        let (scores_tensor, _) = Self::find_outputs_with_shape(&scores.shape, outputs, &[ind])?;
3559
3560        let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
3561        let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3562
3563        decode_yolo_split_det_float(
3564            boxes_tensor,
3565            scores_tensor,
3566            self.score_threshold,
3567            self.iou_threshold,
3568            self.nms,
3569            output_boxes,
3570        );
3571        Ok(())
3572    }
3573
3574    #[allow(clippy::too_many_arguments)]
3575    fn decode_yolo_split_segdet_float<T>(
3576        &self,
3577        outputs: &[ArrayViewD<T>],
3578        boxes: &configs::Boxes,
3579        scores: &configs::Scores,
3580        mask_coeff: &configs::MaskCoefficients,
3581        protos: &configs::Protos,
3582        output_boxes: &mut Vec<DetectBox>,
3583        output_masks: &mut Vec<Segmentation>,
3584    ) -> Result<(), DecoderError>
3585    where
3586        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3587        f32: AsPrimitive<T>,
3588    {
3589        let mut skip = vec![];
3590        let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &skip)?;
3591
3592        let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3593        let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3594        skip.push(ind);
3595
3596        let (scores_tensor, ind) = Self::find_outputs_with_shape(&scores.shape, outputs, &skip)?;
3597
3598        let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
3599        let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3600        skip.push(ind);
3601
3602        let (mask_tensor, ind) = Self::find_outputs_with_shape(&mask_coeff.shape, outputs, &skip)?;
3603        let mask_tensor = Self::swap_axes_if_needed(mask_tensor, mask_coeff.into());
3604        let mask_tensor = mask_tensor.slice(s![0, .., ..]);
3605        skip.push(ind);
3606
3607        let (protos_tensor, _) = Self::find_outputs_with_shape(&protos.shape, outputs, &skip)?;
3608        let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos.into());
3609        let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3610        decode_yolo_split_segdet_float(
3611            boxes_tensor,
3612            scores_tensor,
3613            mask_tensor,
3614            protos_tensor,
3615            self.score_threshold,
3616            self.iou_threshold,
3617            self.nms,
3618            output_boxes,
3619            output_masks,
3620        );
3621        Ok(())
3622    }
3623
3624    /// Decodes end-to-end YOLO detection outputs (post-NMS from model).
3625    ///
3626    /// Input shape: (1, N, 6+) where columns are [x1, y1, x2, y2, conf, class,
3627    /// ...] Boxes are output directly from model (may be normalized or
3628    /// pixel coords depending on config).
3629    fn decode_yolo_end_to_end_det_float<T>(
3630        &self,
3631        outputs: &[ArrayViewD<T>],
3632        boxes_config: &configs::Detection,
3633        output_boxes: &mut Vec<DetectBox>,
3634    ) -> Result<(), DecoderError>
3635    where
3636        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3637        f32: AsPrimitive<T>,
3638    {
3639        let (det_tensor, _) = Self::find_outputs_with_shape(&boxes_config.shape, outputs, &[])?;
3640        let det_tensor = Self::swap_axes_if_needed(det_tensor, boxes_config.into());
3641        let det_tensor = det_tensor.slice(s![0, .., ..]);
3642
3643        crate::yolo::decode_yolo_end_to_end_det_float(
3644            det_tensor,
3645            self.score_threshold,
3646            output_boxes,
3647        )?;
3648        Ok(())
3649    }
3650
3651    /// Decodes end-to-end YOLO detection + segmentation outputs (post-NMS from
3652    /// model).
3653    ///
3654    /// Input shapes:
3655    /// - detection: (1, N, 6 + num_protos) where columns are [x1, y1, x2, y2,
3656    ///   conf, class, mask_coeff_0, ..., mask_coeff_31]
3657    /// - protos: (1, proto_height, proto_width, num_protos)
3658    fn decode_yolo_end_to_end_segdet_float<T>(
3659        &self,
3660        outputs: &[ArrayViewD<T>],
3661        boxes_config: &configs::Detection,
3662        protos_config: &configs::Protos,
3663        output_boxes: &mut Vec<DetectBox>,
3664        output_masks: &mut Vec<Segmentation>,
3665    ) -> Result<(), DecoderError>
3666    where
3667        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3668        f32: AsPrimitive<T>,
3669    {
3670        if outputs.len() < 2 {
3671            return Err(DecoderError::InvalidShape(
3672                "End-to-end segdet requires detection and protos outputs".to_string(),
3673            ));
3674        }
3675
3676        let (det_tensor, det_ind) =
3677            Self::find_outputs_with_shape(&boxes_config.shape, outputs, &[])?;
3678        let det_tensor = Self::swap_axes_if_needed(det_tensor, boxes_config.into());
3679        let det_tensor = det_tensor.slice(s![0, .., ..]);
3680
3681        let (protos_tensor, _) =
3682            Self::find_outputs_with_shape(&protos_config.shape, outputs, &[det_ind])?;
3683        let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos_config.into());
3684        let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3685
3686        crate::yolo::decode_yolo_end_to_end_segdet_float(
3687            det_tensor,
3688            protos_tensor,
3689            self.score_threshold,
3690            output_boxes,
3691            output_masks,
3692        )?;
3693        Ok(())
3694    }
3695
3696    /// Decodes monolithic end-to-end YOLO detection from quantized tensors.
3697    /// Dequantizes then delegates to the float decode path.
3698    fn decode_yolo_end_to_end_det_quantized(
3699        &self,
3700        outputs: &[ArrayViewDQuantized],
3701        boxes_config: &configs::Detection,
3702        output_boxes: &mut Vec<DetectBox>,
3703    ) -> Result<(), DecoderError> {
3704        let (det_tensor, _) =
3705            Self::find_outputs_with_shape_quantized(&boxes_config.shape, outputs, &[])?;
3706        let quant = boxes_config
3707            .quantization
3708            .map(Quantization::from)
3709            .unwrap_or_default();
3710
3711        with_quantized!(det_tensor, d, {
3712            let d = Self::swap_axes_if_needed(d, boxes_config.into());
3713            let d = d.slice(s![0, .., ..]);
3714            let dequant = d.map(|v| {
3715                let val: f32 = v.as_();
3716                (val - quant.zero_point as f32) * quant.scale
3717            });
3718            crate::yolo::decode_yolo_end_to_end_det_float(
3719                dequant.view(),
3720                self.score_threshold,
3721                output_boxes,
3722            )?;
3723        });
3724        Ok(())
3725    }
3726
3727    /// Decodes monolithic end-to-end YOLO seg detection from quantized tensors.
3728    #[allow(clippy::too_many_arguments)]
3729    fn decode_yolo_end_to_end_segdet_quantized(
3730        &self,
3731        outputs: &[ArrayViewDQuantized],
3732        boxes_config: &configs::Detection,
3733        protos_config: &configs::Protos,
3734        output_boxes: &mut Vec<DetectBox>,
3735        output_masks: &mut Vec<Segmentation>,
3736    ) -> Result<(), DecoderError> {
3737        let (det_tensor, det_ind) =
3738            Self::find_outputs_with_shape_quantized(&boxes_config.shape, outputs, &[])?;
3739        let (protos_tensor, _) =
3740            Self::find_outputs_with_shape_quantized(&protos_config.shape, outputs, &[det_ind])?;
3741
3742        let quant_det = boxes_config
3743            .quantization
3744            .map(Quantization::from)
3745            .unwrap_or_default();
3746        let quant_protos = protos_config
3747            .quantization
3748            .map(Quantization::from)
3749            .unwrap_or_default();
3750
3751        // Dequantize each tensor independently to avoid monomorphization explosion.
3752        // Nesting 2 with_quantized! calls produces 6^2 = 36 instantiations; sequential is 6*2 = 12.
3753        macro_rules! dequant_3d {
3754            ($tensor:expr, $config:expr, $quant:expr) => {{
3755                with_quantized!($tensor, t, {
3756                    let t = Self::swap_axes_if_needed(t, $config.into());
3757                    let t = t.slice(s![0, .., ..]);
3758                    t.map(|v| {
3759                        let val: f32 = v.as_();
3760                        (val - $quant.zero_point as f32) * $quant.scale
3761                    })
3762                })
3763            }};
3764        }
3765        macro_rules! dequant_4d {
3766            ($tensor:expr, $config:expr, $quant:expr) => {{
3767                with_quantized!($tensor, t, {
3768                    let t = Self::swap_axes_if_needed(t, $config.into());
3769                    let t = t.slice(s![0, .., .., ..]);
3770                    t.map(|v| {
3771                        let val: f32 = v.as_();
3772                        (val - $quant.zero_point as f32) * $quant.scale
3773                    })
3774                })
3775            }};
3776        }
3777
3778        let dequant_d = dequant_3d!(det_tensor, boxes_config, quant_det);
3779        let dequant_p = dequant_4d!(protos_tensor, protos_config, quant_protos);
3780
3781        crate::yolo::decode_yolo_end_to_end_segdet_float(
3782            dequant_d.view(),
3783            dequant_p.view(),
3784            self.score_threshold,
3785            output_boxes,
3786            output_masks,
3787        )?;
3788        Ok(())
3789    }
3790
3791    /// Decodes split end-to-end YOLO detection from float tensors.
3792    fn decode_yolo_split_end_to_end_det_float<T>(
3793        &self,
3794        outputs: &[ArrayViewD<T>],
3795        boxes_config: &configs::Boxes,
3796        scores_config: &configs::Scores,
3797        classes_config: &configs::Classes,
3798        output_boxes: &mut Vec<DetectBox>,
3799    ) -> Result<(), DecoderError>
3800    where
3801        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3802        f32: AsPrimitive<T>,
3803    {
3804        let mut skip = vec![];
3805        let (boxes_tensor, ind) =
3806            Self::find_outputs_with_shape(&boxes_config.shape, outputs, &skip)?;
3807        let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes_config.into());
3808        let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3809        skip.push(ind);
3810
3811        let (scores_tensor, ind) =
3812            Self::find_outputs_with_shape(&scores_config.shape, outputs, &skip)?;
3813        let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores_config.into());
3814        let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3815        skip.push(ind);
3816
3817        let (classes_tensor, _) =
3818            Self::find_outputs_with_shape(&classes_config.shape, outputs, &skip)?;
3819        let classes_tensor = Self::swap_axes_if_needed(classes_tensor, classes_config.into());
3820        let classes_tensor = classes_tensor.slice(s![0, .., ..]);
3821
3822        crate::yolo::decode_yolo_split_end_to_end_det_float(
3823            boxes_tensor,
3824            scores_tensor,
3825            classes_tensor,
3826            self.score_threshold,
3827            output_boxes,
3828        )?;
3829        Ok(())
3830    }
3831
3832    /// Decodes split end-to-end YOLO seg detection from float tensors.
3833    #[allow(clippy::too_many_arguments)]
3834    fn decode_yolo_split_end_to_end_segdet_float<T>(
3835        &self,
3836        outputs: &[ArrayViewD<T>],
3837        boxes_config: &configs::Boxes,
3838        scores_config: &configs::Scores,
3839        classes_config: &configs::Classes,
3840        mask_coeff_config: &configs::MaskCoefficients,
3841        protos_config: &configs::Protos,
3842        output_boxes: &mut Vec<DetectBox>,
3843        output_masks: &mut Vec<Segmentation>,
3844    ) -> Result<(), DecoderError>
3845    where
3846        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3847        f32: AsPrimitive<T>,
3848    {
3849        let mut skip = vec![];
3850        let (boxes_tensor, ind) =
3851            Self::find_outputs_with_shape(&boxes_config.shape, outputs, &skip)?;
3852        let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes_config.into());
3853        let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3854        skip.push(ind);
3855
3856        let (scores_tensor, ind) =
3857            Self::find_outputs_with_shape(&scores_config.shape, outputs, &skip)?;
3858        let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores_config.into());
3859        let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3860        skip.push(ind);
3861
3862        let (classes_tensor, ind) =
3863            Self::find_outputs_with_shape(&classes_config.shape, outputs, &skip)?;
3864        let classes_tensor = Self::swap_axes_if_needed(classes_tensor, classes_config.into());
3865        let classes_tensor = classes_tensor.slice(s![0, .., ..]);
3866        skip.push(ind);
3867
3868        let (mask_tensor, ind) =
3869            Self::find_outputs_with_shape(&mask_coeff_config.shape, outputs, &skip)?;
3870        let mask_tensor = Self::swap_axes_if_needed(mask_tensor, mask_coeff_config.into());
3871        let mask_tensor = mask_tensor.slice(s![0, .., ..]);
3872        skip.push(ind);
3873
3874        let (protos_tensor, _) =
3875            Self::find_outputs_with_shape(&protos_config.shape, outputs, &skip)?;
3876        let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos_config.into());
3877        let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3878
3879        crate::yolo::decode_yolo_split_end_to_end_segdet_float(
3880            boxes_tensor,
3881            scores_tensor,
3882            classes_tensor,
3883            mask_tensor,
3884            protos_tensor,
3885            self.score_threshold,
3886            output_boxes,
3887            output_masks,
3888        )?;
3889        Ok(())
3890    }
3891
3892    /// Decodes split end-to-end YOLO detection from quantized tensors.
3893    /// Dequantizes each tensor then delegates to the float decode path.
3894    fn decode_yolo_split_end_to_end_det_quantized(
3895        &self,
3896        outputs: &[ArrayViewDQuantized],
3897        boxes_config: &configs::Boxes,
3898        scores_config: &configs::Scores,
3899        classes_config: &configs::Classes,
3900        output_boxes: &mut Vec<DetectBox>,
3901    ) -> Result<(), DecoderError> {
3902        let mut skip = vec![];
3903        let (boxes_tensor, ind) =
3904            Self::find_outputs_with_shape_quantized(&boxes_config.shape, outputs, &skip)?;
3905        skip.push(ind);
3906        let (scores_tensor, ind) =
3907            Self::find_outputs_with_shape_quantized(&scores_config.shape, outputs, &skip)?;
3908        skip.push(ind);
3909        let (classes_tensor, _) =
3910            Self::find_outputs_with_shape_quantized(&classes_config.shape, outputs, &skip)?;
3911
3912        let quant_boxes = boxes_config
3913            .quantization
3914            .map(Quantization::from)
3915            .unwrap_or_default();
3916        let quant_scores = scores_config
3917            .quantization
3918            .map(Quantization::from)
3919            .unwrap_or_default();
3920        let quant_classes = classes_config
3921            .quantization
3922            .map(Quantization::from)
3923            .unwrap_or_default();
3924
3925        // Dequantize each tensor independently to avoid monomorphization explosion.
3926        // Nesting N with_quantized! calls produces 6^N instantiations; sequential is 6*N.
3927        macro_rules! dequant_3d {
3928            ($tensor:expr, $config:expr, $quant:expr) => {{
3929                with_quantized!($tensor, t, {
3930                    let t = Self::swap_axes_if_needed(t, $config.into());
3931                    let t = t.slice(s![0, .., ..]);
3932                    t.map(|v| {
3933                        let val: f32 = v.as_();
3934                        (val - $quant.zero_point as f32) * $quant.scale
3935                    })
3936                })
3937            }};
3938        }
3939
3940        let dequant_b = dequant_3d!(boxes_tensor, boxes_config, quant_boxes);
3941        let dequant_s = dequant_3d!(scores_tensor, scores_config, quant_scores);
3942        let dequant_c = dequant_3d!(classes_tensor, classes_config, quant_classes);
3943
3944        crate::yolo::decode_yolo_split_end_to_end_det_float(
3945            dequant_b.view(),
3946            dequant_s.view(),
3947            dequant_c.view(),
3948            self.score_threshold,
3949            output_boxes,
3950        )?;
3951        Ok(())
3952    }
3953
3954    /// Decodes split end-to-end YOLO seg detection from quantized tensors.
3955    #[allow(clippy::too_many_arguments)]
3956    fn decode_yolo_split_end_to_end_segdet_quantized(
3957        &self,
3958        outputs: &[ArrayViewDQuantized],
3959        boxes_config: &configs::Boxes,
3960        scores_config: &configs::Scores,
3961        classes_config: &configs::Classes,
3962        mask_coeff_config: &configs::MaskCoefficients,
3963        protos_config: &configs::Protos,
3964        output_boxes: &mut Vec<DetectBox>,
3965        output_masks: &mut Vec<Segmentation>,
3966    ) -> Result<(), DecoderError> {
3967        let mut skip = vec![];
3968        let (boxes_tensor, ind) =
3969            Self::find_outputs_with_shape_quantized(&boxes_config.shape, outputs, &skip)?;
3970        skip.push(ind);
3971        let (scores_tensor, ind) =
3972            Self::find_outputs_with_shape_quantized(&scores_config.shape, outputs, &skip)?;
3973        skip.push(ind);
3974        let (classes_tensor, ind) =
3975            Self::find_outputs_with_shape_quantized(&classes_config.shape, outputs, &skip)?;
3976        skip.push(ind);
3977        let (mask_tensor, ind) =
3978            Self::find_outputs_with_shape_quantized(&mask_coeff_config.shape, outputs, &skip)?;
3979        skip.push(ind);
3980        let (protos_tensor, _) =
3981            Self::find_outputs_with_shape_quantized(&protos_config.shape, outputs, &skip)?;
3982
3983        let quant_boxes = boxes_config
3984            .quantization
3985            .map(Quantization::from)
3986            .unwrap_or_default();
3987        let quant_scores = scores_config
3988            .quantization
3989            .map(Quantization::from)
3990            .unwrap_or_default();
3991        let quant_classes = classes_config
3992            .quantization
3993            .map(Quantization::from)
3994            .unwrap_or_default();
3995        let quant_masks = mask_coeff_config
3996            .quantization
3997            .map(Quantization::from)
3998            .unwrap_or_default();
3999        let quant_protos = protos_config
4000            .quantization
4001            .map(Quantization::from)
4002            .unwrap_or_default();
4003
4004        // Dequantize each tensor independently to avoid monomorphization explosion.
4005        // Nesting 5 with_quantized! calls would produce 6^5 = 7776 instantiations.
4006        macro_rules! dequant_3d {
4007            ($tensor:expr, $config:expr, $quant:expr) => {{
4008                with_quantized!($tensor, t, {
4009                    let t = Self::swap_axes_if_needed(t, $config.into());
4010                    let t = t.slice(s![0, .., ..]);
4011                    t.map(|v| {
4012                        let val: f32 = v.as_();
4013                        (val - $quant.zero_point as f32) * $quant.scale
4014                    })
4015                })
4016            }};
4017        }
4018        macro_rules! dequant_4d {
4019            ($tensor:expr, $config:expr, $quant:expr) => {{
4020                with_quantized!($tensor, t, {
4021                    let t = Self::swap_axes_if_needed(t, $config.into());
4022                    let t = t.slice(s![0, .., .., ..]);
4023                    t.map(|v| {
4024                        let val: f32 = v.as_();
4025                        (val - $quant.zero_point as f32) * $quant.scale
4026                    })
4027                })
4028            }};
4029        }
4030
4031        let dequant_b = dequant_3d!(boxes_tensor, boxes_config, quant_boxes);
4032        let dequant_s = dequant_3d!(scores_tensor, scores_config, quant_scores);
4033        let dequant_c = dequant_3d!(classes_tensor, classes_config, quant_classes);
4034        let dequant_m = dequant_3d!(mask_tensor, mask_coeff_config, quant_masks);
4035        let dequant_p = dequant_4d!(protos_tensor, protos_config, quant_protos);
4036
4037        crate::yolo::decode_yolo_split_end_to_end_segdet_float(
4038            dequant_b.view(),
4039            dequant_s.view(),
4040            dequant_c.view(),
4041            dequant_m.view(),
4042            dequant_p.view(),
4043            self.score_threshold,
4044            output_boxes,
4045            output_masks,
4046        )?;
4047        Ok(())
4048    }
4049
4050    // ------------------------------------------------------------------
4051    // Proto-extraction private helpers (mirror the non-proto variants)
4052    // ------------------------------------------------------------------
4053
4054    fn decode_yolo_segdet_quantized_proto(
4055        &self,
4056        outputs: &[ArrayViewDQuantized],
4057        boxes: &configs::Detection,
4058        protos: &configs::Protos,
4059        output_boxes: &mut Vec<DetectBox>,
4060    ) -> Result<ProtoData, DecoderError> {
4061        let (boxes_tensor, ind) =
4062            Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
4063        let (protos_tensor, _) =
4064            Self::find_outputs_with_shape_quantized(&protos.shape, outputs, &[ind])?;
4065
4066        let quant_boxes = boxes
4067            .quantization
4068            .map(Quantization::from)
4069            .unwrap_or_default();
4070        let quant_protos = protos
4071            .quantization
4072            .map(Quantization::from)
4073            .unwrap_or_default();
4074
4075        let proto = with_quantized!(boxes_tensor, b, {
4076            with_quantized!(protos_tensor, p, {
4077                let box_tensor = Self::swap_axes_if_needed(b, boxes.into());
4078                let box_tensor = box_tensor.slice(s![0, .., ..]);
4079
4080                let protos_tensor = Self::swap_axes_if_needed(p, protos.into());
4081                let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
4082                crate::yolo::impl_yolo_segdet_quant_proto::<XYWH, _, _>(
4083                    (box_tensor, quant_boxes),
4084                    (protos_tensor, quant_protos),
4085                    self.score_threshold,
4086                    self.iou_threshold,
4087                    self.nms,
4088                    output_boxes,
4089                )
4090            })
4091        });
4092        Ok(proto)
4093    }
4094
4095    fn decode_yolo_segdet_float_proto<T>(
4096        &self,
4097        outputs: &[ArrayViewD<T>],
4098        boxes: &configs::Detection,
4099        protos: &configs::Protos,
4100        output_boxes: &mut Vec<DetectBox>,
4101    ) -> Result<ProtoData, DecoderError>
4102    where
4103        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
4104        f32: AsPrimitive<T>,
4105    {
4106        let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
4107        let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
4108        let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
4109
4110        let (protos_tensor, _) = Self::find_outputs_with_shape(&protos.shape, outputs, &[ind])?;
4111        let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos.into());
4112        let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
4113
4114        Ok(crate::yolo::impl_yolo_segdet_float_proto::<XYWH, _, _>(
4115            boxes_tensor,
4116            protos_tensor,
4117            self.score_threshold,
4118            self.iou_threshold,
4119            self.nms,
4120            output_boxes,
4121        ))
4122    }
4123
4124    #[allow(clippy::too_many_arguments)]
4125    fn decode_yolo_split_segdet_quantized_proto(
4126        &self,
4127        outputs: &[ArrayViewDQuantized],
4128        boxes: &configs::Boxes,
4129        scores: &configs::Scores,
4130        mask_coeff: &configs::MaskCoefficients,
4131        protos: &configs::Protos,
4132        output_boxes: &mut Vec<DetectBox>,
4133    ) -> Result<ProtoData, DecoderError> {
4134        let quant_boxes = boxes
4135            .quantization
4136            .map(Quantization::from)
4137            .unwrap_or_default();
4138        let quant_scores = scores
4139            .quantization
4140            .map(Quantization::from)
4141            .unwrap_or_default();
4142        let quant_masks = mask_coeff
4143            .quantization
4144            .map(Quantization::from)
4145            .unwrap_or_default();
4146        let quant_protos = protos
4147            .quantization
4148            .map(Quantization::from)
4149            .unwrap_or_default();
4150
4151        let mut skip = vec![];
4152
4153        let (boxes_tensor, ind) =
4154            Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &skip)?;
4155        skip.push(ind);
4156
4157        let (scores_tensor, ind) =
4158            Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &skip)?;
4159        skip.push(ind);
4160
4161        let (mask_tensor, ind) =
4162            Self::find_outputs_with_shape_quantized(&mask_coeff.shape, outputs, &skip)?;
4163        skip.push(ind);
4164
4165        let (protos_tensor, _) =
4166            Self::find_outputs_with_shape_quantized(&protos.shape, outputs, &skip)?;
4167
4168        // Phase 1: boxes + scores (2-level nesting, 36 paths).
4169        let det_indices = with_quantized!(boxes_tensor, b, {
4170            with_quantized!(scores_tensor, s, {
4171                let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
4172                let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
4173
4174                let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
4175                let scores_tensor = scores_tensor.slice(s![0, .., ..]);
4176
4177                impl_yolo_split_segdet_quant_get_boxes::<XYWH, _, _>(
4178                    (boxes_tensor, quant_boxes),
4179                    (scores_tensor, quant_scores),
4180                    self.score_threshold,
4181                    self.iou_threshold,
4182                    self.nms,
4183                    output_boxes.capacity(),
4184                )
4185            })
4186        });
4187
4188        // Phase 2: masks + protos (2-level nesting, 36 paths).
4189        let proto = with_quantized!(mask_tensor, m, {
4190            with_quantized!(protos_tensor, p, {
4191                let mask_tensor = Self::swap_axes_if_needed(m, mask_coeff.into());
4192                let mask_tensor = mask_tensor.slice(s![0, .., ..]);
4193                let mask_tensor = mask_tensor.reversed_axes();
4194
4195                let protos_tensor = Self::swap_axes_if_needed(p, protos.into());
4196                let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
4197
4198                crate::yolo::extract_proto_data_quant(
4199                    det_indices,
4200                    mask_tensor,
4201                    quant_masks,
4202                    protos_tensor,
4203                    quant_protos,
4204                    output_boxes,
4205                )
4206            })
4207        });
4208        Ok(proto)
4209    }
4210
4211    #[allow(clippy::too_many_arguments)]
4212    fn decode_yolo_split_segdet_float_proto<T>(
4213        &self,
4214        outputs: &[ArrayViewD<T>],
4215        boxes: &configs::Boxes,
4216        scores: &configs::Scores,
4217        mask_coeff: &configs::MaskCoefficients,
4218        protos: &configs::Protos,
4219        output_boxes: &mut Vec<DetectBox>,
4220    ) -> Result<ProtoData, DecoderError>
4221    where
4222        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
4223        f32: AsPrimitive<T>,
4224    {
4225        let mut skip = vec![];
4226        let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &skip)?;
4227        let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
4228        let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
4229        skip.push(ind);
4230
4231        let (scores_tensor, ind) = Self::find_outputs_with_shape(&scores.shape, outputs, &skip)?;
4232        let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
4233        let scores_tensor = scores_tensor.slice(s![0, .., ..]);
4234        skip.push(ind);
4235
4236        let (mask_tensor, ind) = Self::find_outputs_with_shape(&mask_coeff.shape, outputs, &skip)?;
4237        let mask_tensor = Self::swap_axes_if_needed(mask_tensor, mask_coeff.into());
4238        let mask_tensor = mask_tensor.slice(s![0, .., ..]);
4239        skip.push(ind);
4240
4241        let (protos_tensor, _) = Self::find_outputs_with_shape(&protos.shape, outputs, &skip)?;
4242        let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos.into());
4243        let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
4244
4245        Ok(crate::yolo::impl_yolo_split_segdet_float_proto::<
4246            XYWH,
4247            _,
4248            _,
4249            _,
4250            _,
4251        >(
4252            boxes_tensor,
4253            scores_tensor,
4254            mask_tensor,
4255            protos_tensor,
4256            self.score_threshold,
4257            self.iou_threshold,
4258            self.nms,
4259            output_boxes,
4260        ))
4261    }
4262
4263    fn decode_yolo_end_to_end_segdet_float_proto<T>(
4264        &self,
4265        outputs: &[ArrayViewD<T>],
4266        boxes_config: &configs::Detection,
4267        protos_config: &configs::Protos,
4268        output_boxes: &mut Vec<DetectBox>,
4269    ) -> Result<ProtoData, DecoderError>
4270    where
4271        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
4272        f32: AsPrimitive<T>,
4273    {
4274        if outputs.len() < 2 {
4275            return Err(DecoderError::InvalidShape(
4276                "End-to-end segdet requires detection and protos outputs".to_string(),
4277            ));
4278        }
4279
4280        let (det_tensor, det_ind) =
4281            Self::find_outputs_with_shape(&boxes_config.shape, outputs, &[])?;
4282        let det_tensor = Self::swap_axes_if_needed(det_tensor, boxes_config.into());
4283        let det_tensor = det_tensor.slice(s![0, .., ..]);
4284
4285        let (protos_tensor, _) =
4286            Self::find_outputs_with_shape(&protos_config.shape, outputs, &[det_ind])?;
4287        let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos_config.into());
4288        let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
4289
4290        crate::yolo::decode_yolo_end_to_end_segdet_float_proto(
4291            det_tensor,
4292            protos_tensor,
4293            self.score_threshold,
4294            output_boxes,
4295        )
4296    }
4297
4298    fn decode_yolo_end_to_end_segdet_quantized_proto(
4299        &self,
4300        outputs: &[ArrayViewDQuantized],
4301        boxes_config: &configs::Detection,
4302        protos_config: &configs::Protos,
4303        output_boxes: &mut Vec<DetectBox>,
4304    ) -> Result<ProtoData, DecoderError> {
4305        let (det_tensor, det_ind) =
4306            Self::find_outputs_with_shape_quantized(&boxes_config.shape, outputs, &[])?;
4307        let (protos_tensor, _) =
4308            Self::find_outputs_with_shape_quantized(&protos_config.shape, outputs, &[det_ind])?;
4309
4310        let quant_det = boxes_config
4311            .quantization
4312            .map(Quantization::from)
4313            .unwrap_or_default();
4314        let quant_protos = protos_config
4315            .quantization
4316            .map(Quantization::from)
4317            .unwrap_or_default();
4318
4319        // Dequantize each tensor independently to avoid monomorphization explosion.
4320        // Nesting 2 with_quantized! calls produces 6^2 = 36 instantiations; sequential is 6*2 = 12.
4321        macro_rules! dequant_3d {
4322            ($tensor:expr, $config:expr, $quant:expr) => {{
4323                with_quantized!($tensor, t, {
4324                    let t = Self::swap_axes_if_needed(t, $config.into());
4325                    let t = t.slice(s![0, .., ..]);
4326                    t.map(|v| {
4327                        let val: f32 = v.as_();
4328                        (val - $quant.zero_point as f32) * $quant.scale
4329                    })
4330                })
4331            }};
4332        }
4333        macro_rules! dequant_4d {
4334            ($tensor:expr, $config:expr, $quant:expr) => {{
4335                with_quantized!($tensor, t, {
4336                    let t = Self::swap_axes_if_needed(t, $config.into());
4337                    let t = t.slice(s![0, .., .., ..]);
4338                    t.map(|v| {
4339                        let val: f32 = v.as_();
4340                        (val - $quant.zero_point as f32) * $quant.scale
4341                    })
4342                })
4343            }};
4344        }
4345
4346        let dequant_d = dequant_3d!(det_tensor, boxes_config, quant_det);
4347        let dequant_p = dequant_4d!(protos_tensor, protos_config, quant_protos);
4348
4349        let proto = crate::yolo::decode_yolo_end_to_end_segdet_float_proto(
4350            dequant_d.view(),
4351            dequant_p.view(),
4352            self.score_threshold,
4353            output_boxes,
4354        )?;
4355        Ok(proto)
4356    }
4357
4358    #[allow(clippy::too_many_arguments)]
4359    fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
4360        &self,
4361        outputs: &[ArrayViewD<T>],
4362        boxes_config: &configs::Boxes,
4363        scores_config: &configs::Scores,
4364        classes_config: &configs::Classes,
4365        mask_coeff_config: &configs::MaskCoefficients,
4366        protos_config: &configs::Protos,
4367        output_boxes: &mut Vec<DetectBox>,
4368    ) -> Result<ProtoData, DecoderError>
4369    where
4370        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
4371        f32: AsPrimitive<T>,
4372    {
4373        let mut skip = vec![];
4374        let (boxes_tensor, ind) =
4375            Self::find_outputs_with_shape(&boxes_config.shape, outputs, &skip)?;
4376        let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes_config.into());
4377        let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
4378        skip.push(ind);
4379
4380        let (scores_tensor, ind) =
4381            Self::find_outputs_with_shape(&scores_config.shape, outputs, &skip)?;
4382        let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores_config.into());
4383        let scores_tensor = scores_tensor.slice(s![0, .., ..]);
4384        skip.push(ind);
4385
4386        let (classes_tensor, ind) =
4387            Self::find_outputs_with_shape(&classes_config.shape, outputs, &skip)?;
4388        let classes_tensor = Self::swap_axes_if_needed(classes_tensor, classes_config.into());
4389        let classes_tensor = classes_tensor.slice(s![0, .., ..]);
4390        skip.push(ind);
4391
4392        let (mask_tensor, ind) =
4393            Self::find_outputs_with_shape(&mask_coeff_config.shape, outputs, &skip)?;
4394        let mask_tensor = Self::swap_axes_if_needed(mask_tensor, mask_coeff_config.into());
4395        let mask_tensor = mask_tensor.slice(s![0, .., ..]);
4396        skip.push(ind);
4397
4398        let (protos_tensor, _) =
4399            Self::find_outputs_with_shape(&protos_config.shape, outputs, &skip)?;
4400        let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos_config.into());
4401        let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
4402
4403        crate::yolo::decode_yolo_split_end_to_end_segdet_float_proto(
4404            boxes_tensor,
4405            scores_tensor,
4406            classes_tensor,
4407            mask_tensor,
4408            protos_tensor,
4409            self.score_threshold,
4410            output_boxes,
4411        )
4412    }
4413
4414    #[allow(clippy::too_many_arguments)]
4415    fn decode_yolo_split_end_to_end_segdet_quantized_proto(
4416        &self,
4417        outputs: &[ArrayViewDQuantized],
4418        boxes_config: &configs::Boxes,
4419        scores_config: &configs::Scores,
4420        classes_config: &configs::Classes,
4421        mask_coeff_config: &configs::MaskCoefficients,
4422        protos_config: &configs::Protos,
4423        output_boxes: &mut Vec<DetectBox>,
4424    ) -> Result<ProtoData, DecoderError> {
4425        let mut skip = vec![];
4426        let (boxes_tensor, ind) =
4427            Self::find_outputs_with_shape_quantized(&boxes_config.shape, outputs, &skip)?;
4428        skip.push(ind);
4429        let (scores_tensor, ind) =
4430            Self::find_outputs_with_shape_quantized(&scores_config.shape, outputs, &skip)?;
4431        skip.push(ind);
4432        let (classes_tensor, ind) =
4433            Self::find_outputs_with_shape_quantized(&classes_config.shape, outputs, &skip)?;
4434        skip.push(ind);
4435        let (mask_tensor, ind) =
4436            Self::find_outputs_with_shape_quantized(&mask_coeff_config.shape, outputs, &skip)?;
4437        skip.push(ind);
4438        let (protos_tensor, _) =
4439            Self::find_outputs_with_shape_quantized(&protos_config.shape, outputs, &skip)?;
4440
4441        let quant_boxes = boxes_config
4442            .quantization
4443            .map(Quantization::from)
4444            .unwrap_or_default();
4445        let quant_scores = scores_config
4446            .quantization
4447            .map(Quantization::from)
4448            .unwrap_or_default();
4449        let quant_classes = classes_config
4450            .quantization
4451            .map(Quantization::from)
4452            .unwrap_or_default();
4453        let quant_masks = mask_coeff_config
4454            .quantization
4455            .map(Quantization::from)
4456            .unwrap_or_default();
4457        let quant_protos = protos_config
4458            .quantization
4459            .map(Quantization::from)
4460            .unwrap_or_default();
4461
4462        macro_rules! dequant_3d {
4463            ($tensor:expr, $config:expr, $quant:expr) => {{
4464                with_quantized!($tensor, t, {
4465                    let t = Self::swap_axes_if_needed(t, $config.into());
4466                    let t = t.slice(s![0, .., ..]);
4467                    t.map(|v| {
4468                        let val: f32 = v.as_();
4469                        (val - $quant.zero_point as f32) * $quant.scale
4470                    })
4471                })
4472            }};
4473        }
4474        macro_rules! dequant_4d {
4475            ($tensor:expr, $config:expr, $quant:expr) => {{
4476                with_quantized!($tensor, t, {
4477                    let t = Self::swap_axes_if_needed(t, $config.into());
4478                    let t = t.slice(s![0, .., .., ..]);
4479                    t.map(|v| {
4480                        let val: f32 = v.as_();
4481                        (val - $quant.zero_point as f32) * $quant.scale
4482                    })
4483                })
4484            }};
4485        }
4486
4487        let dequant_b = dequant_3d!(boxes_tensor, boxes_config, quant_boxes);
4488        let dequant_s = dequant_3d!(scores_tensor, scores_config, quant_scores);
4489        let dequant_c = dequant_3d!(classes_tensor, classes_config, quant_classes);
4490        let dequant_m = dequant_3d!(mask_tensor, mask_coeff_config, quant_masks);
4491        let dequant_p = dequant_4d!(protos_tensor, protos_config, quant_protos);
4492
4493        crate::yolo::decode_yolo_split_end_to_end_segdet_float_proto(
4494            dequant_b.view(),
4495            dequant_s.view(),
4496            dequant_c.view(),
4497            dequant_m.view(),
4498            dequant_p.view(),
4499            self.score_threshold,
4500            output_boxes,
4501        )
4502    }
4503
4504    fn match_outputs_to_detect<'a, 'b, T>(
4505        configs: &[configs::Detection],
4506        outputs: &'a [ArrayViewD<'b, T>],
4507    ) -> Result<Vec<&'a ArrayViewD<'b, T>>, DecoderError> {
4508        let mut new_output_order = Vec::new();
4509        for c in configs {
4510            let mut found = false;
4511            for o in outputs {
4512                if o.shape() == c.shape {
4513                    new_output_order.push(o);
4514                    found = true;
4515                    break;
4516                }
4517            }
4518            if !found {
4519                return Err(DecoderError::InvalidShape(format!(
4520                    "Did not find output with shape {:?}",
4521                    c.shape
4522                )));
4523            }
4524        }
4525        Ok(new_output_order)
4526    }
4527
4528    fn find_outputs_with_shape<'a, 'b, T>(
4529        shape: &[usize],
4530        outputs: &'a [ArrayViewD<'b, T>],
4531        skip: &[usize],
4532    ) -> Result<(&'a ArrayViewD<'b, T>, usize), DecoderError> {
4533        for (ind, o) in outputs.iter().enumerate() {
4534            if skip.contains(&ind) {
4535                continue;
4536            }
4537            if o.shape() == shape {
4538                return Ok((o, ind));
4539            }
4540        }
4541        Err(DecoderError::InvalidShape(format!(
4542            "Did not find output with shape {:?}",
4543            shape
4544        )))
4545    }
4546
4547    fn find_outputs_with_shape_quantized<'a, 'b>(
4548        shape: &[usize],
4549        outputs: &'a [ArrayViewDQuantized<'b>],
4550        skip: &[usize],
4551    ) -> Result<(&'a ArrayViewDQuantized<'b>, usize), DecoderError> {
4552        for (ind, o) in outputs.iter().enumerate() {
4553            if skip.contains(&ind) {
4554                continue;
4555            }
4556            if o.shape() == shape {
4557                return Ok((o, ind));
4558            }
4559        }
4560        Err(DecoderError::InvalidShape(format!(
4561            "Did not find output with shape {:?}",
4562            shape
4563        )))
4564    }
4565
4566    /// This is split detection, need to swap axes to batch, height, width,
4567    /// num_anchors_x_features,
4568    fn modelpack_det_order(x: DimName) -> usize {
4569        match x {
4570            DimName::Batch => 0,
4571            DimName::NumBoxes => 1,
4572            DimName::Padding => 2,
4573            DimName::BoxCoords => 3,
4574            _ => 1000, // this should be unreachable
4575        }
4576    }
4577
4578    // This is Ultralytics detection, need to swap axes to batch, num_features,
4579    // height, width
4580    fn yolo_det_order(x: DimName) -> usize {
4581        match x {
4582            DimName::Batch => 0,
4583            DimName::NumFeatures => 1,
4584            DimName::NumBoxes => 2,
4585            _ => 1000, // this should be unreachable
4586        }
4587    }
4588
4589    // This is modelpack boxes, need to swap axes to batch, num_boxes, padding,
4590    // box_coords
4591    fn modelpack_boxes_order(x: DimName) -> usize {
4592        match x {
4593            DimName::Batch => 0,
4594            DimName::NumBoxes => 1,
4595            DimName::Padding => 2,
4596            DimName::BoxCoords => 3,
4597            _ => 1000, // this should be unreachable
4598        }
4599    }
4600
4601    /// This is Ultralytics boxes, need to swap axes to batch, box_coords,
4602    /// num_boxes
4603    fn yolo_boxes_order(x: DimName) -> usize {
4604        match x {
4605            DimName::Batch => 0,
4606            DimName::BoxCoords => 1,
4607            DimName::NumBoxes => 2,
4608            _ => 1000, // this should be unreachable
4609        }
4610    }
4611
4612    /// This is modelpack scores, need to swap axes to batch, num_boxes,
4613    /// num_classes
4614    fn modelpack_scores_order(x: DimName) -> usize {
4615        match x {
4616            DimName::Batch => 0,
4617            DimName::NumBoxes => 1,
4618            DimName::NumClasses => 2,
4619            _ => 1000, // this should be unreachable
4620        }
4621    }
4622
4623    fn yolo_scores_order(x: DimName) -> usize {
4624        match x {
4625            DimName::Batch => 0,
4626            DimName::NumClasses => 1,
4627            DimName::NumBoxes => 2,
4628            _ => 1000, // this should be unreachable
4629        }
4630    }
4631
4632    /// This is modelpack segmentation, need to swap axes to batch, height,
4633    /// width, num_classes
4634    fn modelpack_segmentation_order(x: DimName) -> usize {
4635        match x {
4636            DimName::Batch => 0,
4637            DimName::Height => 1,
4638            DimName::Width => 2,
4639            DimName::NumClasses => 3,
4640            _ => 1000, // this should be unreachable
4641        }
4642    }
4643
4644    /// This is modelpack masks, need to swap axes to batch, height,
4645    /// width
4646    fn modelpack_mask_order(x: DimName) -> usize {
4647        match x {
4648            DimName::Batch => 0,
4649            DimName::Height => 1,
4650            DimName::Width => 2,
4651            _ => 1000, // this should be unreachable
4652        }
4653    }
4654
4655    /// This is yolo protos, need to swap axes to batch, height, width,
4656    /// num_protos
4657    fn yolo_protos_order(x: DimName) -> usize {
4658        match x {
4659            DimName::Batch => 0,
4660            DimName::Height => 1,
4661            DimName::Width => 2,
4662            DimName::NumProtos => 3,
4663            _ => 1000, // this should be unreachable
4664        }
4665    }
4666
4667    /// This is yolo mask coefficients, need to swap axes to batch, num_protos,
4668    /// num_boxes
4669    fn yolo_maskcoefficients_order(x: DimName) -> usize {
4670        match x {
4671            DimName::Batch => 0,
4672            DimName::NumProtos => 1,
4673            DimName::NumBoxes => 2,
4674            _ => 1000, // this should be unreachable
4675        }
4676    }
4677
4678    fn get_order_fn(config: ConfigOutputRef) -> fn(DimName) -> usize {
4679        let decoder_type = config.decoder();
4680        match (config, decoder_type) {
4681            (ConfigOutputRef::Detection(_), DecoderType::ModelPack) => Self::modelpack_det_order,
4682            (ConfigOutputRef::Detection(_), DecoderType::Ultralytics) => Self::yolo_det_order,
4683            (ConfigOutputRef::Boxes(_), DecoderType::ModelPack) => Self::modelpack_boxes_order,
4684            (ConfigOutputRef::Boxes(_), DecoderType::Ultralytics) => Self::yolo_boxes_order,
4685            (ConfigOutputRef::Scores(_), DecoderType::ModelPack) => Self::modelpack_scores_order,
4686            (ConfigOutputRef::Scores(_), DecoderType::Ultralytics) => Self::yolo_scores_order,
4687            (ConfigOutputRef::Segmentation(_), _) => Self::modelpack_segmentation_order,
4688            (ConfigOutputRef::Mask(_), _) => Self::modelpack_mask_order,
4689            (ConfigOutputRef::Protos(_), _) => Self::yolo_protos_order,
4690            (ConfigOutputRef::MaskCoefficients(_), _) => Self::yolo_maskcoefficients_order,
4691            (ConfigOutputRef::Classes(_), _) => Self::yolo_scores_order,
4692        }
4693    }
4694
4695    fn swap_axes_if_needed<'a, T, D: Dimension>(
4696        array: &ArrayView<'a, T, D>,
4697        config: ConfigOutputRef,
4698    ) -> ArrayView<'a, T, D> {
4699        let mut array = array.clone();
4700        if config.dshape().is_empty() {
4701            return array;
4702        }
4703        let order_fn: fn(DimName) -> usize = Self::get_order_fn(config.clone());
4704        let mut current_order: Vec<usize> = config
4705            .dshape()
4706            .iter()
4707            .map(|x| order_fn(x.0))
4708            .collect::<Vec<_>>();
4709
4710        assert_eq!(array.shape().len(), current_order.len());
4711        // do simple bubble sort as swap_axes is inexpensive and the
4712        // number of dimensions is small
4713        for i in 0..current_order.len() {
4714            let mut swapped = false;
4715            for j in 0..current_order.len() - 1 - i {
4716                if current_order[j] > current_order[j + 1] {
4717                    array.swap_axes(j, j + 1);
4718                    current_order.swap(j, j + 1);
4719                    swapped = true;
4720                }
4721            }
4722            if !swapped {
4723                break;
4724            }
4725        }
4726        array
4727    }
4728
4729    fn match_outputs_to_detect_quantized<'a, 'b>(
4730        configs: &[configs::Detection],
4731        outputs: &'a [ArrayViewDQuantized<'b>],
4732    ) -> Result<Vec<&'a ArrayViewDQuantized<'b>>, DecoderError> {
4733        let mut new_output_order = Vec::new();
4734        for c in configs {
4735            let mut found = false;
4736            for o in outputs {
4737                if o.shape() == c.shape {
4738                    new_output_order.push(o);
4739                    found = true;
4740                    break;
4741                }
4742            }
4743            if !found {
4744                return Err(DecoderError::InvalidShape(format!(
4745                    "Did not find output with shape {:?}",
4746                    c.shape
4747                )));
4748            }
4749        }
4750        Ok(new_output_order)
4751    }
4752}
4753
4754#[cfg(test)]
4755#[cfg_attr(coverage_nightly, coverage(off))]
4756mod decoder_builder_tests {
4757    use super::*;
4758
4759    #[test]
4760    fn test_decoder_builder_no_config() {
4761        use crate::DecoderBuilder;
4762        let result = DecoderBuilder::default().build();
4763        assert!(matches!(result, Err(DecoderError::NoConfig)));
4764    }
4765
4766    #[test]
4767    fn test_decoder_builder_empty_config() {
4768        use crate::DecoderBuilder;
4769        let result = DecoderBuilder::default()
4770            .with_config(ConfigOutputs {
4771                outputs: vec![],
4772                ..Default::default()
4773            })
4774            .build();
4775        assert!(
4776            matches!(result, Err(DecoderError::InvalidConfig(s)) if s == "No outputs found in config")
4777        );
4778    }
4779
4780    #[test]
4781    fn test_malformed_config_yaml() {
4782        let malformed_yaml = "
4783        model_type: yolov8_det
4784        outputs:
4785          - shape: [1, 84, 8400]
4786        "
4787        .to_owned();
4788        let result = DecoderBuilder::new()
4789            .with_config_yaml_str(malformed_yaml)
4790            .build();
4791        assert!(matches!(result, Err(DecoderError::Yaml(_))));
4792    }
4793
4794    #[test]
4795    fn test_malformed_config_json() {
4796        let malformed_yaml = "
4797        {
4798            \"model_type\": \"yolov8_det\",
4799            \"outputs\": [
4800                {
4801                    \"shape\": [1, 84, 8400]
4802                }
4803            ]
4804        }"
4805        .to_owned();
4806        let result = DecoderBuilder::new()
4807            .with_config_json_str(malformed_yaml)
4808            .build();
4809        assert!(matches!(result, Err(DecoderError::Json(_))));
4810    }
4811
4812    #[test]
4813    fn test_modelpack_and_yolo_config_error() {
4814        let result = DecoderBuilder::new()
4815            .with_config_modelpack_det(
4816                configs::Boxes {
4817                    decoder: configs::DecoderType::Ultralytics,
4818                    shape: vec![1, 4, 8400],
4819                    quantization: None,
4820                    dshape: vec![
4821                        (DimName::Batch, 1),
4822                        (DimName::BoxCoords, 4),
4823                        (DimName::NumBoxes, 8400),
4824                    ],
4825                    normalized: Some(true),
4826                },
4827                configs::Scores {
4828                    decoder: configs::DecoderType::ModelPack,
4829                    shape: vec![1, 80, 8400],
4830                    quantization: None,
4831                    dshape: vec![
4832                        (DimName::Batch, 1),
4833                        (DimName::NumClasses, 80),
4834                        (DimName::NumBoxes, 8400),
4835                    ],
4836                },
4837            )
4838            .build();
4839
4840        assert!(matches!(
4841            result, Err(DecoderError::InvalidConfig(s)) if s == "Both ModelPack and Yolo outputs found in config"
4842        ));
4843    }
4844
4845    #[test]
4846    fn test_yolo_invalid_seg_shape() {
4847        let result = DecoderBuilder::new()
4848            .with_config_yolo_segdet(
4849                configs::Detection {
4850                    decoder: configs::DecoderType::Ultralytics,
4851                    shape: vec![1, 85, 8400, 1], // Invalid shape
4852                    quantization: None,
4853                    anchors: None,
4854                    dshape: vec![
4855                        (DimName::Batch, 1),
4856                        (DimName::NumFeatures, 85),
4857                        (DimName::NumBoxes, 8400),
4858                        (DimName::Batch, 1),
4859                    ],
4860                    normalized: Some(true),
4861                },
4862                configs::Protos {
4863                    decoder: configs::DecoderType::Ultralytics,
4864                    shape: vec![1, 32, 160, 160],
4865                    quantization: None,
4866                    dshape: vec![
4867                        (DimName::Batch, 1),
4868                        (DimName::NumProtos, 32),
4869                        (DimName::Height, 160),
4870                        (DimName::Width, 160),
4871                    ],
4872                },
4873                Some(DecoderVersion::Yolo11),
4874            )
4875            .build();
4876
4877        assert!(matches!(
4878            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Detection shape")
4879        ));
4880    }
4881
4882    #[test]
4883    fn test_yolo_invalid_mask() {
4884        let result = DecoderBuilder::new()
4885            .with_config(ConfigOutputs {
4886                outputs: vec![ConfigOutput::Mask(configs::Mask {
4887                    shape: vec![1, 160, 160, 1],
4888                    decoder: configs::DecoderType::Ultralytics,
4889                    quantization: None,
4890                    dshape: vec![
4891                        (DimName::Batch, 1),
4892                        (DimName::Height, 160),
4893                        (DimName::Width, 160),
4894                        (DimName::NumFeatures, 1),
4895                    ],
4896                })],
4897                ..Default::default()
4898            })
4899            .build();
4900
4901        assert!(matches!(
4902            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Mask output with Yolo decoder")
4903        ));
4904    }
4905
4906    #[test]
4907    fn test_yolo_invalid_outputs() {
4908        let result = DecoderBuilder::new()
4909            .with_config(ConfigOutputs {
4910                outputs: vec![ConfigOutput::Segmentation(configs::Segmentation {
4911                    shape: vec![1, 84, 8400],
4912                    decoder: configs::DecoderType::Ultralytics,
4913                    quantization: None,
4914                    dshape: vec![
4915                        (DimName::Batch, 1),
4916                        (DimName::NumFeatures, 84),
4917                        (DimName::NumBoxes, 8400),
4918                    ],
4919                })],
4920                ..Default::default()
4921            })
4922            .build();
4923
4924        assert!(
4925            matches!(result, Err(DecoderError::InvalidConfig(s)) if s == "Invalid Segmentation output with Yolo decoder")
4926        );
4927    }
4928
4929    #[test]
4930    fn test_yolo_invalid_det() {
4931        let result = DecoderBuilder::new()
4932            .with_config_yolo_det(
4933                configs::Detection {
4934                    anchors: None,
4935                    decoder: DecoderType::Ultralytics,
4936                    quantization: None,
4937                    shape: vec![1, 84, 8400, 1], // Invalid shape
4938                    dshape: vec![
4939                        (DimName::Batch, 1),
4940                        (DimName::NumFeatures, 84),
4941                        (DimName::NumBoxes, 8400),
4942                        (DimName::Batch, 1),
4943                    ],
4944                    normalized: Some(true),
4945                },
4946                Some(DecoderVersion::Yolo11),
4947            )
4948            .build();
4949
4950        assert!(matches!(
4951            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Detection shape")));
4952
4953        let result = DecoderBuilder::new()
4954            .with_config_yolo_det(
4955                configs::Detection {
4956                    anchors: None,
4957                    decoder: DecoderType::Ultralytics,
4958                    quantization: None,
4959                    shape: vec![1, 8400, 3], // Invalid shape
4960                    dshape: vec![
4961                        (DimName::Batch, 1),
4962                        (DimName::NumBoxes, 8400),
4963                        (DimName::NumFeatures, 3),
4964                    ],
4965                    normalized: Some(true),
4966                },
4967                Some(DecoderVersion::Yolo11),
4968            )
4969            .build();
4970
4971        assert!(
4972            matches!(
4973            &result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid shape: Yolo num_features 3 must be greater than 4")),
4974            "{}",
4975            result.unwrap_err()
4976        );
4977
4978        let result = DecoderBuilder::new()
4979            .with_config_yolo_det(
4980                configs::Detection {
4981                    anchors: None,
4982                    decoder: DecoderType::Ultralytics,
4983                    quantization: None,
4984                    shape: vec![1, 3, 8400], // Invalid shape
4985                    dshape: Vec::new(),
4986                    normalized: Some(true),
4987                },
4988                Some(DecoderVersion::Yolo11),
4989            )
4990            .build();
4991
4992        assert!(matches!(
4993            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid shape: Yolo num_features 3 must be greater than 4")));
4994    }
4995
4996    #[test]
4997    fn test_yolo_invalid_segdet() {
4998        let result = DecoderBuilder::new()
4999            .with_config_yolo_segdet(
5000                configs::Detection {
5001                    decoder: configs::DecoderType::Ultralytics,
5002                    shape: vec![1, 85, 8400, 1], // Invalid shape
5003                    quantization: None,
5004                    anchors: None,
5005                    dshape: vec![
5006                        (DimName::Batch, 1),
5007                        (DimName::NumFeatures, 85),
5008                        (DimName::NumBoxes, 8400),
5009                        (DimName::Batch, 1),
5010                    ],
5011                    normalized: Some(true),
5012                },
5013                configs::Protos {
5014                    decoder: configs::DecoderType::Ultralytics,
5015                    shape: vec![1, 32, 160, 160],
5016                    quantization: None,
5017                    dshape: vec![
5018                        (DimName::Batch, 1),
5019                        (DimName::NumProtos, 32),
5020                        (DimName::Height, 160),
5021                        (DimName::Width, 160),
5022                    ],
5023                },
5024                Some(DecoderVersion::Yolo11),
5025            )
5026            .build();
5027
5028        assert!(matches!(
5029            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Detection shape")));
5030
5031        let result = DecoderBuilder::new()
5032            .with_config_yolo_segdet(
5033                configs::Detection {
5034                    decoder: configs::DecoderType::Ultralytics,
5035                    shape: vec![1, 85, 8400],
5036                    quantization: None,
5037                    anchors: None,
5038                    dshape: vec![
5039                        (DimName::Batch, 1),
5040                        (DimName::NumFeatures, 85),
5041                        (DimName::NumBoxes, 8400),
5042                    ],
5043                    normalized: Some(true),
5044                },
5045                configs::Protos {
5046                    decoder: configs::DecoderType::Ultralytics,
5047                    shape: vec![1, 32, 160, 160, 1], // Invalid shape
5048                    dshape: vec![
5049                        (DimName::Batch, 1),
5050                        (DimName::NumProtos, 32),
5051                        (DimName::Height, 160),
5052                        (DimName::Width, 160),
5053                        (DimName::Batch, 1),
5054                    ],
5055                    quantization: None,
5056                },
5057                Some(DecoderVersion::Yolo11),
5058            )
5059            .build();
5060
5061        assert!(matches!(
5062            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Protos shape")));
5063
5064        let result = DecoderBuilder::new()
5065            .with_config_yolo_segdet(
5066                configs::Detection {
5067                    decoder: configs::DecoderType::Ultralytics,
5068                    shape: vec![1, 8400, 36], // too few classes
5069                    quantization: None,
5070                    anchors: None,
5071                    dshape: vec![
5072                        (DimName::Batch, 1),
5073                        (DimName::NumBoxes, 8400),
5074                        (DimName::NumFeatures, 36),
5075                    ],
5076                    normalized: Some(true),
5077                },
5078                configs::Protos {
5079                    decoder: configs::DecoderType::Ultralytics,
5080                    shape: vec![1, 32, 160, 160],
5081                    quantization: None,
5082                    dshape: vec![
5083                        (DimName::Batch, 1),
5084                        (DimName::NumProtos, 32),
5085                        (DimName::Height, 160),
5086                        (DimName::Width, 160),
5087                    ],
5088                },
5089                Some(DecoderVersion::Yolo11),
5090            )
5091            .build();
5092        println!("{:?}", result);
5093        assert!(matches!(
5094            result, Err(DecoderError::InvalidConfig(s)) if s == "Invalid shape: Yolo num_features 36 must be greater than 36"));
5095    }
5096
5097    #[test]
5098    fn test_yolo_invalid_split_det() {
5099        let result = DecoderBuilder::new()
5100            .with_config_yolo_split_det(
5101                configs::Boxes {
5102                    decoder: configs::DecoderType::Ultralytics,
5103                    shape: vec![1, 4, 8400, 1], // Invalid shape
5104                    quantization: None,
5105                    dshape: vec![
5106                        (DimName::Batch, 1),
5107                        (DimName::BoxCoords, 4),
5108                        (DimName::NumBoxes, 8400),
5109                        (DimName::Batch, 1),
5110                    ],
5111                    normalized: Some(true),
5112                },
5113                configs::Scores {
5114                    decoder: configs::DecoderType::Ultralytics,
5115                    shape: vec![1, 80, 8400],
5116                    quantization: None,
5117                    dshape: vec![
5118                        (DimName::Batch, 1),
5119                        (DimName::NumClasses, 80),
5120                        (DimName::NumBoxes, 8400),
5121                    ],
5122                },
5123            )
5124            .build();
5125
5126        assert!(matches!(
5127            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Boxes shape")));
5128
5129        let result = DecoderBuilder::new()
5130            .with_config_yolo_split_det(
5131                configs::Boxes {
5132                    decoder: configs::DecoderType::Ultralytics,
5133                    shape: vec![1, 4, 8400],
5134                    quantization: None,
5135                    dshape: vec![
5136                        (DimName::Batch, 1),
5137                        (DimName::BoxCoords, 4),
5138                        (DimName::NumBoxes, 8400),
5139                    ],
5140                    normalized: Some(true),
5141                },
5142                configs::Scores {
5143                    decoder: configs::DecoderType::Ultralytics,
5144                    shape: vec![1, 80, 8400, 1], // Invalid shape
5145                    quantization: None,
5146                    dshape: vec![
5147                        (DimName::Batch, 1),
5148                        (DimName::NumClasses, 80),
5149                        (DimName::NumBoxes, 8400),
5150                        (DimName::Batch, 1),
5151                    ],
5152                },
5153            )
5154            .build();
5155
5156        assert!(matches!(
5157            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Scores shape")));
5158
5159        let result = DecoderBuilder::new()
5160            .with_config_yolo_split_det(
5161                configs::Boxes {
5162                    decoder: configs::DecoderType::Ultralytics,
5163                    shape: vec![1, 8400, 4],
5164                    quantization: None,
5165                    dshape: vec![
5166                        (DimName::Batch, 1),
5167                        (DimName::NumBoxes, 8400),
5168                        (DimName::BoxCoords, 4),
5169                    ],
5170                    normalized: Some(true),
5171                },
5172                configs::Scores {
5173                    decoder: configs::DecoderType::Ultralytics,
5174                    shape: vec![1, 8400 + 1, 80], // Invalid number of boxes
5175                    quantization: None,
5176                    dshape: vec![
5177                        (DimName::Batch, 1),
5178                        (DimName::NumBoxes, 8401),
5179                        (DimName::NumClasses, 80),
5180                    ],
5181                },
5182            )
5183            .build();
5184
5185        assert!(matches!(
5186            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Yolo Split Detection Boxes num 8400 incompatible with Scores num 8401")));
5187
5188        let result = DecoderBuilder::new()
5189            .with_config_yolo_split_det(
5190                configs::Boxes {
5191                    decoder: configs::DecoderType::Ultralytics,
5192                    shape: vec![1, 5, 8400], // Invalid boxes dimensions
5193                    quantization: None,
5194                    dshape: vec![
5195                        (DimName::Batch, 1),
5196                        (DimName::BoxCoords, 5),
5197                        (DimName::NumBoxes, 8400),
5198                    ],
5199                    normalized: Some(true),
5200                },
5201                configs::Scores {
5202                    decoder: configs::DecoderType::Ultralytics,
5203                    shape: vec![1, 80, 8400],
5204                    quantization: None,
5205                    dshape: vec![
5206                        (DimName::Batch, 1),
5207                        (DimName::NumClasses, 80),
5208                        (DimName::NumBoxes, 8400),
5209                    ],
5210                },
5211            )
5212            .build();
5213        assert!(matches!(
5214            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("BoxCoords dimension size must be 4")));
5215    }
5216
5217    #[test]
5218    fn test_yolo_invalid_split_segdet() {
5219        let result = DecoderBuilder::new()
5220            .with_config_yolo_split_segdet(
5221                configs::Boxes {
5222                    decoder: configs::DecoderType::Ultralytics,
5223                    shape: vec![1, 8400, 4, 1],
5224                    quantization: None,
5225                    dshape: vec![
5226                        (DimName::Batch, 1),
5227                        (DimName::NumBoxes, 8400),
5228                        (DimName::BoxCoords, 4),
5229                        (DimName::Batch, 1),
5230                    ],
5231                    normalized: Some(true),
5232                },
5233                configs::Scores {
5234                    decoder: configs::DecoderType::Ultralytics,
5235                    shape: vec![1, 8400, 80],
5236
5237                    quantization: None,
5238                    dshape: vec![
5239                        (DimName::Batch, 1),
5240                        (DimName::NumBoxes, 8400),
5241                        (DimName::NumClasses, 80),
5242                    ],
5243                },
5244                configs::MaskCoefficients {
5245                    decoder: configs::DecoderType::Ultralytics,
5246                    shape: vec![1, 8400, 32],
5247                    quantization: None,
5248                    dshape: vec![
5249                        (DimName::Batch, 1),
5250                        (DimName::NumBoxes, 8400),
5251                        (DimName::NumProtos, 32),
5252                    ],
5253                },
5254                configs::Protos {
5255                    decoder: configs::DecoderType::Ultralytics,
5256                    shape: vec![1, 32, 160, 160],
5257                    quantization: None,
5258                    dshape: vec![
5259                        (DimName::Batch, 1),
5260                        (DimName::NumProtos, 32),
5261                        (DimName::Height, 160),
5262                        (DimName::Width, 160),
5263                    ],
5264                },
5265            )
5266            .build();
5267
5268        assert!(matches!(
5269            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Boxes shape")));
5270
5271        let result = DecoderBuilder::new()
5272            .with_config_yolo_split_segdet(
5273                configs::Boxes {
5274                    decoder: configs::DecoderType::Ultralytics,
5275                    shape: vec![1, 8400, 4],
5276                    quantization: None,
5277                    dshape: vec![
5278                        (DimName::Batch, 1),
5279                        (DimName::NumBoxes, 8400),
5280                        (DimName::BoxCoords, 4),
5281                    ],
5282                    normalized: Some(true),
5283                },
5284                configs::Scores {
5285                    decoder: configs::DecoderType::Ultralytics,
5286                    shape: vec![1, 8400, 80, 1],
5287                    quantization: None,
5288                    dshape: vec![
5289                        (DimName::Batch, 1),
5290                        (DimName::NumBoxes, 8400),
5291                        (DimName::NumClasses, 80),
5292                        (DimName::Batch, 1),
5293                    ],
5294                },
5295                configs::MaskCoefficients {
5296                    decoder: configs::DecoderType::Ultralytics,
5297                    shape: vec![1, 8400, 32],
5298                    quantization: None,
5299                    dshape: vec![
5300                        (DimName::Batch, 1),
5301                        (DimName::NumBoxes, 8400),
5302                        (DimName::NumProtos, 32),
5303                    ],
5304                },
5305                configs::Protos {
5306                    decoder: configs::DecoderType::Ultralytics,
5307                    shape: vec![1, 32, 160, 160],
5308                    quantization: None,
5309                    dshape: vec![
5310                        (DimName::Batch, 1),
5311                        (DimName::NumProtos, 32),
5312                        (DimName::Height, 160),
5313                        (DimName::Width, 160),
5314                    ],
5315                },
5316            )
5317            .build();
5318
5319        assert!(matches!(
5320            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Scores shape")));
5321
5322        let result = DecoderBuilder::new()
5323            .with_config_yolo_split_segdet(
5324                configs::Boxes {
5325                    decoder: configs::DecoderType::Ultralytics,
5326                    shape: vec![1, 8400, 4],
5327                    quantization: None,
5328                    dshape: vec![
5329                        (DimName::Batch, 1),
5330                        (DimName::NumBoxes, 8400),
5331                        (DimName::BoxCoords, 4),
5332                    ],
5333                    normalized: Some(true),
5334                },
5335                configs::Scores {
5336                    decoder: configs::DecoderType::Ultralytics,
5337                    shape: vec![1, 8400, 80],
5338                    quantization: None,
5339                    dshape: vec![
5340                        (DimName::Batch, 1),
5341                        (DimName::NumBoxes, 8400),
5342                        (DimName::NumClasses, 80),
5343                    ],
5344                },
5345                configs::MaskCoefficients {
5346                    decoder: configs::DecoderType::Ultralytics,
5347                    shape: vec![1, 8400, 32, 1],
5348                    quantization: None,
5349                    dshape: vec![
5350                        (DimName::Batch, 1),
5351                        (DimName::NumBoxes, 8400),
5352                        (DimName::NumProtos, 32),
5353                        (DimName::Batch, 1),
5354                    ],
5355                },
5356                configs::Protos {
5357                    decoder: configs::DecoderType::Ultralytics,
5358                    shape: vec![1, 32, 160, 160],
5359                    quantization: None,
5360                    dshape: vec![
5361                        (DimName::Batch, 1),
5362                        (DimName::NumProtos, 32),
5363                        (DimName::Height, 160),
5364                        (DimName::Width, 160),
5365                    ],
5366                },
5367            )
5368            .build();
5369
5370        assert!(matches!(
5371            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Mask Coefficients shape")));
5372
5373        let result = DecoderBuilder::new()
5374            .with_config_yolo_split_segdet(
5375                configs::Boxes {
5376                    decoder: configs::DecoderType::Ultralytics,
5377                    shape: vec![1, 8400, 4],
5378                    quantization: None,
5379                    dshape: vec![
5380                        (DimName::Batch, 1),
5381                        (DimName::NumBoxes, 8400),
5382                        (DimName::BoxCoords, 4),
5383                    ],
5384                    normalized: Some(true),
5385                },
5386                configs::Scores {
5387                    decoder: configs::DecoderType::Ultralytics,
5388                    shape: vec![1, 8400, 80],
5389                    quantization: None,
5390                    dshape: vec![
5391                        (DimName::Batch, 1),
5392                        (DimName::NumBoxes, 8400),
5393                        (DimName::NumClasses, 80),
5394                    ],
5395                },
5396                configs::MaskCoefficients {
5397                    decoder: configs::DecoderType::Ultralytics,
5398                    shape: vec![1, 8400, 32],
5399                    quantization: None,
5400                    dshape: vec![
5401                        (DimName::Batch, 1),
5402                        (DimName::NumBoxes, 8400),
5403                        (DimName::NumProtos, 32),
5404                    ],
5405                },
5406                configs::Protos {
5407                    decoder: configs::DecoderType::Ultralytics,
5408                    shape: vec![1, 32, 160, 160, 1],
5409                    quantization: None,
5410                    dshape: vec![
5411                        (DimName::Batch, 1),
5412                        (DimName::NumProtos, 32),
5413                        (DimName::Height, 160),
5414                        (DimName::Width, 160),
5415                        (DimName::Batch, 1),
5416                    ],
5417                },
5418            )
5419            .build();
5420
5421        assert!(matches!(
5422            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Protos shape")));
5423
5424        let result = DecoderBuilder::new()
5425            .with_config_yolo_split_segdet(
5426                configs::Boxes {
5427                    decoder: configs::DecoderType::Ultralytics,
5428                    shape: vec![1, 8400, 4],
5429                    quantization: None,
5430                    dshape: vec![
5431                        (DimName::Batch, 1),
5432                        (DimName::NumBoxes, 8400),
5433                        (DimName::BoxCoords, 4),
5434                    ],
5435                    normalized: Some(true),
5436                },
5437                configs::Scores {
5438                    decoder: configs::DecoderType::Ultralytics,
5439                    shape: vec![1, 8401, 80],
5440                    quantization: None,
5441                    dshape: vec![
5442                        (DimName::Batch, 1),
5443                        (DimName::NumBoxes, 8401),
5444                        (DimName::NumClasses, 80),
5445                    ],
5446                },
5447                configs::MaskCoefficients {
5448                    decoder: configs::DecoderType::Ultralytics,
5449                    shape: vec![1, 8400, 32],
5450                    quantization: None,
5451                    dshape: vec![
5452                        (DimName::Batch, 1),
5453                        (DimName::NumBoxes, 8400),
5454                        (DimName::NumProtos, 32),
5455                    ],
5456                },
5457                configs::Protos {
5458                    decoder: configs::DecoderType::Ultralytics,
5459                    shape: vec![1, 32, 160, 160],
5460                    quantization: None,
5461                    dshape: vec![
5462                        (DimName::Batch, 1),
5463                        (DimName::NumProtos, 32),
5464                        (DimName::Height, 160),
5465                        (DimName::Width, 160),
5466                    ],
5467                },
5468            )
5469            .build();
5470
5471        assert!(matches!(
5472            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Yolo Split Detection Boxes num 8400 incompatible with Scores num 8401")));
5473
5474        let result = DecoderBuilder::new()
5475            .with_config_yolo_split_segdet(
5476                configs::Boxes {
5477                    decoder: configs::DecoderType::Ultralytics,
5478                    shape: vec![1, 8400, 4],
5479                    quantization: None,
5480                    dshape: vec![
5481                        (DimName::Batch, 1),
5482                        (DimName::NumBoxes, 8400),
5483                        (DimName::BoxCoords, 4),
5484                    ],
5485                    normalized: Some(true),
5486                },
5487                configs::Scores {
5488                    decoder: configs::DecoderType::Ultralytics,
5489                    shape: vec![1, 8400, 80],
5490                    quantization: None,
5491                    dshape: vec![
5492                        (DimName::Batch, 1),
5493                        (DimName::NumBoxes, 8400),
5494                        (DimName::NumClasses, 80),
5495                    ],
5496                },
5497                configs::MaskCoefficients {
5498                    decoder: configs::DecoderType::Ultralytics,
5499                    shape: vec![1, 8401, 32],
5500
5501                    quantization: None,
5502                    dshape: vec![
5503                        (DimName::Batch, 1),
5504                        (DimName::NumBoxes, 8401),
5505                        (DimName::NumProtos, 32),
5506                    ],
5507                },
5508                configs::Protos {
5509                    decoder: configs::DecoderType::Ultralytics,
5510                    shape: vec![1, 32, 160, 160],
5511                    quantization: None,
5512                    dshape: vec![
5513                        (DimName::Batch, 1),
5514                        (DimName::NumProtos, 32),
5515                        (DimName::Height, 160),
5516                        (DimName::Width, 160),
5517                    ],
5518                },
5519            )
5520            .build();
5521
5522        assert!(matches!(
5523            result, Err(DecoderError::InvalidConfig(ref s)) if s.starts_with("Yolo Split Detection Boxes num 8400 incompatible with Mask Coefficients num 8401")));
5524        let result = DecoderBuilder::new()
5525            .with_config_yolo_split_segdet(
5526                configs::Boxes {
5527                    decoder: configs::DecoderType::Ultralytics,
5528                    shape: vec![1, 8400, 4],
5529                    quantization: None,
5530                    dshape: vec![
5531                        (DimName::Batch, 1),
5532                        (DimName::NumBoxes, 8400),
5533                        (DimName::BoxCoords, 4),
5534                    ],
5535                    normalized: Some(true),
5536                },
5537                configs::Scores {
5538                    decoder: configs::DecoderType::Ultralytics,
5539                    shape: vec![1, 8400, 80],
5540                    quantization: None,
5541                    dshape: vec![
5542                        (DimName::Batch, 1),
5543                        (DimName::NumBoxes, 8400),
5544                        (DimName::NumClasses, 80),
5545                    ],
5546                },
5547                configs::MaskCoefficients {
5548                    decoder: configs::DecoderType::Ultralytics,
5549                    shape: vec![1, 8400, 32],
5550                    quantization: None,
5551                    dshape: vec![
5552                        (DimName::Batch, 1),
5553                        (DimName::NumBoxes, 8400),
5554                        (DimName::NumProtos, 32),
5555                    ],
5556                },
5557                configs::Protos {
5558                    decoder: configs::DecoderType::Ultralytics,
5559                    shape: vec![1, 31, 160, 160],
5560                    quantization: None,
5561                    dshape: vec![
5562                        (DimName::Batch, 1),
5563                        (DimName::NumProtos, 31),
5564                        (DimName::Height, 160),
5565                        (DimName::Width, 160),
5566                    ],
5567                },
5568            )
5569            .build();
5570        println!("{:?}", result);
5571        assert!(matches!(
5572            result, Err(DecoderError::InvalidConfig(ref s)) if s.starts_with( "Yolo Protos channels 31 incompatible with Mask Coefficients channels 32")));
5573    }
5574
5575    #[test]
5576    fn test_modelpack_invalid_config() {
5577        let result = DecoderBuilder::new()
5578            .with_config(ConfigOutputs {
5579                outputs: vec![
5580                    ConfigOutput::Boxes(configs::Boxes {
5581                        decoder: configs::DecoderType::ModelPack,
5582                        shape: vec![1, 8400, 1, 4],
5583                        quantization: None,
5584                        dshape: vec![
5585                            (DimName::Batch, 1),
5586                            (DimName::NumBoxes, 8400),
5587                            (DimName::Padding, 1),
5588                            (DimName::BoxCoords, 4),
5589                        ],
5590                        normalized: Some(true),
5591                    }),
5592                    ConfigOutput::Scores(configs::Scores {
5593                        decoder: configs::DecoderType::ModelPack,
5594                        shape: vec![1, 8400, 3],
5595                        quantization: None,
5596                        dshape: vec![
5597                            (DimName::Batch, 1),
5598                            (DimName::NumBoxes, 8400),
5599                            (DimName::NumClasses, 3),
5600                        ],
5601                    }),
5602                    ConfigOutput::Protos(configs::Protos {
5603                        decoder: configs::DecoderType::ModelPack,
5604                        shape: vec![1, 8400, 3],
5605                        quantization: None,
5606                        dshape: vec![
5607                            (DimName::Batch, 1),
5608                            (DimName::NumBoxes, 8400),
5609                            (DimName::NumFeatures, 3),
5610                        ],
5611                    }),
5612                ],
5613                ..Default::default()
5614            })
5615            .build();
5616
5617        assert!(matches!(
5618            result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack should not have protos"));
5619
5620        let result = DecoderBuilder::new()
5621            .with_config(ConfigOutputs {
5622                outputs: vec![
5623                    ConfigOutput::Boxes(configs::Boxes {
5624                        decoder: configs::DecoderType::ModelPack,
5625                        shape: vec![1, 8400, 1, 4],
5626                        quantization: None,
5627                        dshape: vec![
5628                            (DimName::Batch, 1),
5629                            (DimName::NumBoxes, 8400),
5630                            (DimName::Padding, 1),
5631                            (DimName::BoxCoords, 4),
5632                        ],
5633                        normalized: Some(true),
5634                    }),
5635                    ConfigOutput::Scores(configs::Scores {
5636                        decoder: configs::DecoderType::ModelPack,
5637                        shape: vec![1, 8400, 3],
5638                        quantization: None,
5639                        dshape: vec![
5640                            (DimName::Batch, 1),
5641                            (DimName::NumBoxes, 8400),
5642                            (DimName::NumClasses, 3),
5643                        ],
5644                    }),
5645                    ConfigOutput::MaskCoefficients(configs::MaskCoefficients {
5646                        decoder: configs::DecoderType::ModelPack,
5647                        shape: vec![1, 8400, 3],
5648                        quantization: None,
5649                        dshape: vec![
5650                            (DimName::Batch, 1),
5651                            (DimName::NumBoxes, 8400),
5652                            (DimName::NumProtos, 3),
5653                        ],
5654                    }),
5655                ],
5656                ..Default::default()
5657            })
5658            .build();
5659
5660        assert!(matches!(
5661            result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack should not have mask coefficients"));
5662
5663        let result = DecoderBuilder::new()
5664            .with_config(ConfigOutputs {
5665                outputs: vec![ConfigOutput::Boxes(configs::Boxes {
5666                    decoder: configs::DecoderType::ModelPack,
5667                    shape: vec![1, 8400, 1, 4],
5668                    quantization: None,
5669                    dshape: vec![
5670                        (DimName::Batch, 1),
5671                        (DimName::NumBoxes, 8400),
5672                        (DimName::Padding, 1),
5673                        (DimName::BoxCoords, 4),
5674                    ],
5675                    normalized: Some(true),
5676                })],
5677                ..Default::default()
5678            })
5679            .build();
5680
5681        assert!(matches!(
5682            result, Err(DecoderError::InvalidConfig(s)) if s == "Invalid ModelPack model outputs"));
5683    }
5684
5685    #[test]
5686    fn test_modelpack_invalid_det() {
5687        let result = DecoderBuilder::new()
5688            .with_config_modelpack_det(
5689                configs::Boxes {
5690                    decoder: DecoderType::ModelPack,
5691                    quantization: None,
5692                    shape: vec![1, 4, 8400],
5693                    dshape: vec![
5694                        (DimName::Batch, 1),
5695                        (DimName::BoxCoords, 4),
5696                        (DimName::NumBoxes, 8400),
5697                    ],
5698                    normalized: Some(true),
5699                },
5700                configs::Scores {
5701                    decoder: DecoderType::ModelPack,
5702                    quantization: None,
5703                    shape: vec![1, 80, 8400],
5704                    dshape: vec![
5705                        (DimName::Batch, 1),
5706                        (DimName::NumClasses, 80),
5707                        (DimName::NumBoxes, 8400),
5708                    ],
5709                },
5710            )
5711            .build();
5712
5713        assert!(matches!(
5714            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Boxes shape")));
5715
5716        let result = DecoderBuilder::new()
5717            .with_config_modelpack_det(
5718                configs::Boxes {
5719                    decoder: DecoderType::ModelPack,
5720                    quantization: None,
5721                    shape: vec![1, 4, 1, 8400],
5722                    dshape: vec![
5723                        (DimName::Batch, 1),
5724                        (DimName::BoxCoords, 4),
5725                        (DimName::Padding, 1),
5726                        (DimName::NumBoxes, 8400),
5727                    ],
5728                    normalized: Some(true),
5729                },
5730                configs::Scores {
5731                    decoder: DecoderType::ModelPack,
5732                    quantization: None,
5733                    shape: vec![1, 80, 8400, 1],
5734                    dshape: vec![
5735                        (DimName::Batch, 1),
5736                        (DimName::NumClasses, 80),
5737                        (DimName::NumBoxes, 8400),
5738                        (DimName::Padding, 1),
5739                    ],
5740                },
5741            )
5742            .build();
5743
5744        assert!(matches!(
5745            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Scores shape")));
5746
5747        let result = DecoderBuilder::new()
5748            .with_config_modelpack_det(
5749                configs::Boxes {
5750                    decoder: DecoderType::ModelPack,
5751                    quantization: None,
5752                    shape: vec![1, 4, 2, 8400],
5753                    dshape: vec![
5754                        (DimName::Batch, 1),
5755                        (DimName::BoxCoords, 4),
5756                        (DimName::Padding, 2),
5757                        (DimName::NumBoxes, 8400),
5758                    ],
5759                    normalized: Some(true),
5760                },
5761                configs::Scores {
5762                    decoder: DecoderType::ModelPack,
5763                    quantization: None,
5764                    shape: vec![1, 80, 8400],
5765                    dshape: vec![
5766                        (DimName::Batch, 1),
5767                        (DimName::NumClasses, 80),
5768                        (DimName::NumBoxes, 8400),
5769                    ],
5770                },
5771            )
5772            .build();
5773        assert!(matches!(
5774            result, Err(DecoderError::InvalidConfig(s)) if s == "Padding dimension size must be 1"));
5775
5776        let result = DecoderBuilder::new()
5777            .with_config_modelpack_det(
5778                configs::Boxes {
5779                    decoder: DecoderType::ModelPack,
5780                    quantization: None,
5781                    shape: vec![1, 5, 1, 8400],
5782                    dshape: vec![
5783                        (DimName::Batch, 1),
5784                        (DimName::BoxCoords, 5),
5785                        (DimName::Padding, 1),
5786                        (DimName::NumBoxes, 8400),
5787                    ],
5788                    normalized: Some(true),
5789                },
5790                configs::Scores {
5791                    decoder: DecoderType::ModelPack,
5792                    quantization: None,
5793                    shape: vec![1, 80, 8400],
5794                    dshape: vec![
5795                        (DimName::Batch, 1),
5796                        (DimName::NumClasses, 80),
5797                        (DimName::NumBoxes, 8400),
5798                    ],
5799                },
5800            )
5801            .build();
5802
5803        assert!(matches!(
5804            result, Err(DecoderError::InvalidConfig(s)) if s == "BoxCoords dimension size must be 4"));
5805
5806        let result = DecoderBuilder::new()
5807            .with_config_modelpack_det(
5808                configs::Boxes {
5809                    decoder: DecoderType::ModelPack,
5810                    quantization: None,
5811                    shape: vec![1, 4, 1, 8400],
5812                    dshape: vec![
5813                        (DimName::Batch, 1),
5814                        (DimName::BoxCoords, 4),
5815                        (DimName::Padding, 1),
5816                        (DimName::NumBoxes, 8400),
5817                    ],
5818                    normalized: Some(true),
5819                },
5820                configs::Scores {
5821                    decoder: DecoderType::ModelPack,
5822                    quantization: None,
5823                    shape: vec![1, 80, 8401],
5824                    dshape: vec![
5825                        (DimName::Batch, 1),
5826                        (DimName::NumClasses, 80),
5827                        (DimName::NumBoxes, 8401),
5828                    ],
5829                },
5830            )
5831            .build();
5832
5833        assert!(matches!(
5834            result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Detection Boxes num 8400 incompatible with Scores num 8401"));
5835    }
5836
5837    #[test]
5838    fn test_modelpack_invalid_det_split() {
5839        let result = DecoderBuilder::default()
5840            .with_config_modelpack_det_split(vec![
5841                configs::Detection {
5842                    decoder: DecoderType::ModelPack,
5843                    shape: vec![1, 17, 30, 18],
5844                    anchors: None,
5845                    quantization: None,
5846                    dshape: vec![
5847                        (DimName::Batch, 1),
5848                        (DimName::Height, 17),
5849                        (DimName::Width, 30),
5850                        (DimName::NumAnchorsXFeatures, 18),
5851                    ],
5852                    normalized: Some(true),
5853                },
5854                configs::Detection {
5855                    decoder: DecoderType::ModelPack,
5856                    shape: vec![1, 9, 15, 18],
5857                    anchors: None,
5858                    quantization: None,
5859                    dshape: vec![
5860                        (DimName::Batch, 1),
5861                        (DimName::Height, 9),
5862                        (DimName::Width, 15),
5863                        (DimName::NumAnchorsXFeatures, 18),
5864                    ],
5865                    normalized: Some(true),
5866                },
5867            ])
5868            .build();
5869
5870        assert!(matches!(
5871            result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection missing anchors"));
5872
5873        let result = DecoderBuilder::default()
5874            .with_config_modelpack_det_split(vec![configs::Detection {
5875                decoder: DecoderType::ModelPack,
5876                shape: vec![1, 17, 30, 18],
5877                anchors: None,
5878                quantization: None,
5879                dshape: Vec::new(),
5880                normalized: Some(true),
5881            }])
5882            .build();
5883
5884        assert!(matches!(
5885            result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection missing anchors"));
5886
5887        let result = DecoderBuilder::default()
5888            .with_config_modelpack_det_split(vec![configs::Detection {
5889                decoder: DecoderType::ModelPack,
5890                shape: vec![1, 17, 30, 18],
5891                anchors: Some(vec![]),
5892                quantization: None,
5893                dshape: vec![
5894                    (DimName::Batch, 1),
5895                    (DimName::Height, 17),
5896                    (DimName::Width, 30),
5897                    (DimName::NumAnchorsXFeatures, 18),
5898                ],
5899                normalized: Some(true),
5900            }])
5901            .build();
5902
5903        assert!(matches!(
5904            result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection has zero anchors"));
5905
5906        let result = DecoderBuilder::default()
5907            .with_config_modelpack_det_split(vec![configs::Detection {
5908                decoder: DecoderType::ModelPack,
5909                shape: vec![1, 17, 30, 18, 1],
5910                anchors: Some(vec![
5911                    [0.3666666, 0.3148148],
5912                    [0.3874999, 0.474074],
5913                    [0.5333333, 0.644444],
5914                ]),
5915                quantization: None,
5916                dshape: vec![
5917                    (DimName::Batch, 1),
5918                    (DimName::Height, 17),
5919                    (DimName::Width, 30),
5920                    (DimName::NumAnchorsXFeatures, 18),
5921                    (DimName::Padding, 1),
5922                ],
5923                normalized: Some(true),
5924            }])
5925            .build();
5926
5927        assert!(matches!(
5928            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Split Detection shape")));
5929
5930        let result = DecoderBuilder::default()
5931            .with_config_modelpack_det_split(vec![configs::Detection {
5932                decoder: DecoderType::ModelPack,
5933                shape: vec![1, 15, 17, 30],
5934                anchors: Some(vec![
5935                    [0.3666666, 0.3148148],
5936                    [0.3874999, 0.474074],
5937                    [0.5333333, 0.644444],
5938                ]),
5939                quantization: None,
5940                dshape: vec![
5941                    (DimName::Batch, 1),
5942                    (DimName::NumAnchorsXFeatures, 15),
5943                    (DimName::Height, 17),
5944                    (DimName::Width, 30),
5945                ],
5946                normalized: Some(true),
5947            }])
5948            .build();
5949
5950        assert!(matches!(
5951            result, Err(DecoderError::InvalidConfig(s)) if s.contains("not greater than number of anchors * 5 =")));
5952
5953        let result = DecoderBuilder::default()
5954            .with_config_modelpack_det_split(vec![configs::Detection {
5955                decoder: DecoderType::ModelPack,
5956                shape: vec![1, 17, 30, 15],
5957                anchors: Some(vec![
5958                    [0.3666666, 0.3148148],
5959                    [0.3874999, 0.474074],
5960                    [0.5333333, 0.644444],
5961                ]),
5962                quantization: None,
5963                dshape: Vec::new(),
5964                normalized: Some(true),
5965            }])
5966            .build();
5967
5968        assert!(matches!(
5969            result, Err(DecoderError::InvalidConfig(s)) if s.contains("not greater than number of anchors * 5 =")));
5970
5971        let result = DecoderBuilder::default()
5972            .with_config_modelpack_det_split(vec![configs::Detection {
5973                decoder: DecoderType::ModelPack,
5974                shape: vec![1, 16, 17, 30],
5975                anchors: Some(vec![
5976                    [0.3666666, 0.3148148],
5977                    [0.3874999, 0.474074],
5978                    [0.5333333, 0.644444],
5979                ]),
5980                quantization: None,
5981                dshape: vec![
5982                    (DimName::Batch, 1),
5983                    (DimName::NumAnchorsXFeatures, 16),
5984                    (DimName::Height, 17),
5985                    (DimName::Width, 30),
5986                ],
5987                normalized: Some(true),
5988            }])
5989            .build();
5990
5991        assert!(matches!(
5992            result, Err(DecoderError::InvalidConfig(s)) if s.contains("not a multiple of number of anchors")));
5993
5994        let result = DecoderBuilder::default()
5995            .with_config_modelpack_det_split(vec![configs::Detection {
5996                decoder: DecoderType::ModelPack,
5997                shape: vec![1, 17, 30, 16],
5998                anchors: Some(vec![
5999                    [0.3666666, 0.3148148],
6000                    [0.3874999, 0.474074],
6001                    [0.5333333, 0.644444],
6002                ]),
6003                quantization: None,
6004                dshape: Vec::new(),
6005                normalized: Some(true),
6006            }])
6007            .build();
6008
6009        assert!(matches!(
6010            result, Err(DecoderError::InvalidConfig(s)) if s.contains("not a multiple of number of anchors")));
6011
6012        let result = DecoderBuilder::default()
6013            .with_config_modelpack_det_split(vec![configs::Detection {
6014                decoder: DecoderType::ModelPack,
6015                shape: vec![1, 18, 17, 30],
6016                anchors: Some(vec![
6017                    [0.3666666, 0.3148148],
6018                    [0.3874999, 0.474074],
6019                    [0.5333333, 0.644444],
6020                ]),
6021                quantization: None,
6022                dshape: vec![
6023                    (DimName::Batch, 1),
6024                    (DimName::NumProtos, 18),
6025                    (DimName::Height, 17),
6026                    (DimName::Width, 30),
6027                ],
6028                normalized: Some(true),
6029            }])
6030            .build();
6031        assert!(matches!(
6032            result, Err(DecoderError::InvalidConfig(s)) if s.contains("Split Detection dshape missing required dimension NumAnchorsXFeature")));
6033
6034        let result = DecoderBuilder::default()
6035            .with_config_modelpack_det_split(vec![
6036                configs::Detection {
6037                    decoder: DecoderType::ModelPack,
6038                    shape: vec![1, 17, 30, 18],
6039                    anchors: Some(vec![
6040                        [0.3666666, 0.3148148],
6041                        [0.3874999, 0.474074],
6042                        [0.5333333, 0.644444],
6043                    ]),
6044                    quantization: None,
6045                    dshape: vec![
6046                        (DimName::Batch, 1),
6047                        (DimName::Height, 17),
6048                        (DimName::Width, 30),
6049                        (DimName::NumAnchorsXFeatures, 18),
6050                    ],
6051                    normalized: Some(true),
6052                },
6053                configs::Detection {
6054                    decoder: DecoderType::ModelPack,
6055                    shape: vec![1, 17, 30, 21],
6056                    anchors: Some(vec![
6057                        [0.3666666, 0.3148148],
6058                        [0.3874999, 0.474074],
6059                        [0.5333333, 0.644444],
6060                    ]),
6061                    quantization: None,
6062                    dshape: vec![
6063                        (DimName::Batch, 1),
6064                        (DimName::Height, 17),
6065                        (DimName::Width, 30),
6066                        (DimName::NumAnchorsXFeatures, 21),
6067                    ],
6068                    normalized: Some(true),
6069                },
6070            ])
6071            .build();
6072
6073        assert!(matches!(
6074            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("ModelPack Split Detection inconsistent number of classes:")));
6075
6076        let result = DecoderBuilder::default()
6077            .with_config_modelpack_det_split(vec![
6078                configs::Detection {
6079                    decoder: DecoderType::ModelPack,
6080                    shape: vec![1, 17, 30, 18],
6081                    anchors: Some(vec![
6082                        [0.3666666, 0.3148148],
6083                        [0.3874999, 0.474074],
6084                        [0.5333333, 0.644444],
6085                    ]),
6086                    quantization: None,
6087                    dshape: vec![],
6088                    normalized: Some(true),
6089                },
6090                configs::Detection {
6091                    decoder: DecoderType::ModelPack,
6092                    shape: vec![1, 17, 30, 21],
6093                    anchors: Some(vec![
6094                        [0.3666666, 0.3148148],
6095                        [0.3874999, 0.474074],
6096                        [0.5333333, 0.644444],
6097                    ]),
6098                    quantization: None,
6099                    dshape: vec![],
6100                    normalized: Some(true),
6101                },
6102            ])
6103            .build();
6104
6105        assert!(matches!(
6106            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("ModelPack Split Detection inconsistent number of classes:")));
6107    }
6108
6109    #[test]
6110    fn test_modelpack_invalid_seg() {
6111        let result = DecoderBuilder::new()
6112            .with_config_modelpack_seg(configs::Segmentation {
6113                decoder: DecoderType::ModelPack,
6114                quantization: None,
6115                shape: vec![1, 160, 106, 3, 1],
6116                dshape: vec![
6117                    (DimName::Batch, 1),
6118                    (DimName::Height, 160),
6119                    (DimName::Width, 106),
6120                    (DimName::NumClasses, 3),
6121                    (DimName::Padding, 1),
6122                ],
6123            })
6124            .build();
6125
6126        assert!(matches!(
6127            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Segmentation shape")));
6128    }
6129
6130    #[test]
6131    fn test_modelpack_invalid_segdet() {
6132        let result = DecoderBuilder::new()
6133            .with_config_modelpack_segdet(
6134                configs::Boxes {
6135                    decoder: DecoderType::ModelPack,
6136                    quantization: None,
6137                    shape: vec![1, 4, 1, 8400],
6138                    dshape: vec![
6139                        (DimName::Batch, 1),
6140                        (DimName::BoxCoords, 4),
6141                        (DimName::Padding, 1),
6142                        (DimName::NumBoxes, 8400),
6143                    ],
6144                    normalized: Some(true),
6145                },
6146                configs::Scores {
6147                    decoder: DecoderType::ModelPack,
6148                    quantization: None,
6149                    shape: vec![1, 4, 8400],
6150                    dshape: vec![
6151                        (DimName::Batch, 1),
6152                        (DimName::NumClasses, 4),
6153                        (DimName::NumBoxes, 8400),
6154                    ],
6155                },
6156                configs::Segmentation {
6157                    decoder: DecoderType::ModelPack,
6158                    quantization: None,
6159                    shape: vec![1, 160, 106, 3],
6160                    dshape: vec![
6161                        (DimName::Batch, 1),
6162                        (DimName::Height, 160),
6163                        (DimName::Width, 106),
6164                        (DimName::NumClasses, 3),
6165                    ],
6166                },
6167            )
6168            .build();
6169
6170        assert!(matches!(
6171            result, Err(DecoderError::InvalidConfig(s)) if s.contains("incompatible with number of classes")));
6172    }
6173
6174    #[test]
6175    fn test_modelpack_invalid_segdet_split() {
6176        let result = DecoderBuilder::new()
6177            .with_config_modelpack_segdet_split(
6178                vec![configs::Detection {
6179                    decoder: DecoderType::ModelPack,
6180                    shape: vec![1, 17, 30, 18],
6181                    anchors: Some(vec![
6182                        [0.3666666, 0.3148148],
6183                        [0.3874999, 0.474074],
6184                        [0.5333333, 0.644444],
6185                    ]),
6186                    quantization: None,
6187                    dshape: vec![
6188                        (DimName::Batch, 1),
6189                        (DimName::Height, 17),
6190                        (DimName::Width, 30),
6191                        (DimName::NumAnchorsXFeatures, 18),
6192                    ],
6193                    normalized: Some(true),
6194                }],
6195                configs::Segmentation {
6196                    decoder: DecoderType::ModelPack,
6197                    quantization: None,
6198                    shape: vec![1, 160, 106, 3],
6199                    dshape: vec![
6200                        (DimName::Batch, 1),
6201                        (DimName::Height, 160),
6202                        (DimName::Width, 106),
6203                        (DimName::NumClasses, 3),
6204                    ],
6205                },
6206            )
6207            .build();
6208
6209        assert!(matches!(
6210            result, Err(DecoderError::InvalidConfig(s)) if s.contains("incompatible with number of classes")));
6211    }
6212
6213    #[test]
6214    fn test_decode_bad_shapes() {
6215        let score_threshold = 0.25;
6216        let iou_threshold = 0.7;
6217        let quant = (0.0040811873, -123);
6218        let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
6219        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
6220        let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
6221        let out_float: Array3<f32> = dequantize_ndarray(out.view(), quant.into());
6222
6223        let decoder = DecoderBuilder::default()
6224            .with_config_yolo_det(
6225                configs::Detection {
6226                    decoder: DecoderType::Ultralytics,
6227                    shape: vec![1, 85, 8400],
6228                    anchors: None,
6229                    quantization: Some(quant.into()),
6230                    dshape: vec![
6231                        (DimName::Batch, 1),
6232                        (DimName::NumFeatures, 85),
6233                        (DimName::NumBoxes, 8400),
6234                    ],
6235                    normalized: Some(true),
6236                },
6237                Some(DecoderVersion::Yolo11),
6238            )
6239            .with_score_threshold(score_threshold)
6240            .with_iou_threshold(iou_threshold)
6241            .build()
6242            .unwrap();
6243
6244        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
6245        let mut output_masks: Vec<_> = Vec::with_capacity(50);
6246        let result =
6247            decoder.decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks);
6248
6249        assert!(matches!(
6250            result, Err(DecoderError::InvalidShape(s)) if s == "Did not find output with shape [1, 85, 8400]"));
6251
6252        let result = decoder.decode_float(
6253            &[out_float.view().into_dyn()],
6254            &mut output_boxes,
6255            &mut output_masks,
6256        );
6257
6258        assert!(matches!(
6259            result, Err(DecoderError::InvalidShape(s)) if s == "Did not find output with shape [1, 85, 8400]"));
6260    }
6261
6262    #[test]
6263    fn test_config_outputs() {
6264        let outputs = [
6265            ConfigOutput::Detection(configs::Detection {
6266                decoder: configs::DecoderType::Ultralytics,
6267                anchors: None,
6268                shape: vec![1, 8400, 85],
6269                quantization: Some(QuantTuple(0.123, 0)),
6270                dshape: vec![
6271                    (DimName::Batch, 1),
6272                    (DimName::NumBoxes, 8400),
6273                    (DimName::NumFeatures, 85),
6274                ],
6275                normalized: Some(true),
6276            }),
6277            ConfigOutput::Mask(configs::Mask {
6278                decoder: configs::DecoderType::Ultralytics,
6279                shape: vec![1, 160, 160, 1],
6280                quantization: Some(QuantTuple(0.223, 0)),
6281                dshape: vec![
6282                    (DimName::Batch, 1),
6283                    (DimName::Height, 160),
6284                    (DimName::Width, 160),
6285                    (DimName::NumFeatures, 1),
6286                ],
6287            }),
6288            ConfigOutput::Segmentation(configs::Segmentation {
6289                decoder: configs::DecoderType::Ultralytics,
6290                shape: vec![1, 160, 160, 80],
6291                quantization: Some(QuantTuple(0.323, 0)),
6292                dshape: vec![
6293                    (DimName::Batch, 1),
6294                    (DimName::Height, 160),
6295                    (DimName::Width, 160),
6296                    (DimName::NumClasses, 80),
6297                ],
6298            }),
6299            ConfigOutput::Scores(configs::Scores {
6300                decoder: configs::DecoderType::Ultralytics,
6301                shape: vec![1, 8400, 80],
6302                quantization: Some(QuantTuple(0.423, 0)),
6303                dshape: vec![
6304                    (DimName::Batch, 1),
6305                    (DimName::NumBoxes, 8400),
6306                    (DimName::NumClasses, 80),
6307                ],
6308            }),
6309            ConfigOutput::Boxes(configs::Boxes {
6310                decoder: configs::DecoderType::Ultralytics,
6311                shape: vec![1, 8400, 4],
6312                quantization: Some(QuantTuple(0.523, 0)),
6313                dshape: vec![
6314                    (DimName::Batch, 1),
6315                    (DimName::NumBoxes, 8400),
6316                    (DimName::BoxCoords, 4),
6317                ],
6318                normalized: Some(true),
6319            }),
6320            ConfigOutput::Protos(configs::Protos {
6321                decoder: configs::DecoderType::Ultralytics,
6322                shape: vec![1, 32, 160, 160],
6323                quantization: Some(QuantTuple(0.623, 0)),
6324                dshape: vec![
6325                    (DimName::Batch, 1),
6326                    (DimName::NumProtos, 32),
6327                    (DimName::Height, 160),
6328                    (DimName::Width, 160),
6329                ],
6330            }),
6331            ConfigOutput::MaskCoefficients(configs::MaskCoefficients {
6332                decoder: configs::DecoderType::Ultralytics,
6333                shape: vec![1, 8400, 32],
6334                quantization: Some(QuantTuple(0.723, 0)),
6335                dshape: vec![
6336                    (DimName::Batch, 1),
6337                    (DimName::NumBoxes, 8400),
6338                    (DimName::NumProtos, 32),
6339                ],
6340            }),
6341        ];
6342
6343        let shapes = outputs.clone().map(|x| x.shape().to_vec());
6344        assert_eq!(
6345            shapes,
6346            [
6347                vec![1, 8400, 85],
6348                vec![1, 160, 160, 1],
6349                vec![1, 160, 160, 80],
6350                vec![1, 8400, 80],
6351                vec![1, 8400, 4],
6352                vec![1, 32, 160, 160],
6353                vec![1, 8400, 32],
6354            ]
6355        );
6356
6357        let quants: [Option<(f32, i32)>; 7] = outputs.map(|x| x.quantization().map(|q| q.into()));
6358        assert_eq!(
6359            quants,
6360            [
6361                Some((0.123, 0)),
6362                Some((0.223, 0)),
6363                Some((0.323, 0)),
6364                Some((0.423, 0)),
6365                Some((0.523, 0)),
6366                Some((0.623, 0)),
6367                Some((0.723, 0)),
6368            ]
6369        );
6370    }
6371
6372    #[test]
6373    fn test_nms_from_config_yaml() {
6374        // Test parsing NMS from YAML config
6375        let yaml_class_agnostic = r#"
6376outputs:
6377  - decoder: ultralytics
6378    type: detection
6379    shape: [1, 84, 8400]
6380    dshape:
6381      - [batch, 1]
6382      - [num_features, 84]
6383      - [num_boxes, 8400]
6384nms: class_agnostic
6385"#;
6386        let decoder = DecoderBuilder::new()
6387            .with_config_yaml_str(yaml_class_agnostic.to_string())
6388            .build()
6389            .unwrap();
6390        assert_eq!(decoder.nms, Some(configs::Nms::ClassAgnostic));
6391
6392        let yaml_class_aware = r#"
6393outputs:
6394  - decoder: ultralytics
6395    type: detection
6396    shape: [1, 84, 8400]
6397    dshape:
6398      - [batch, 1]
6399      - [num_features, 84]
6400      - [num_boxes, 8400]
6401nms: class_aware
6402"#;
6403        let decoder = DecoderBuilder::new()
6404            .with_config_yaml_str(yaml_class_aware.to_string())
6405            .build()
6406            .unwrap();
6407        assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
6408
6409        // Test that config NMS overrides builder NMS
6410        let decoder = DecoderBuilder::new()
6411            .with_config_yaml_str(yaml_class_aware.to_string())
6412            .with_nms(Some(configs::Nms::ClassAgnostic)) // Builder sets agnostic
6413            .build()
6414            .unwrap();
6415        // Config should override builder
6416        assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
6417    }
6418
6419    #[test]
6420    fn test_nms_from_config_json() {
6421        // Test parsing NMS from JSON config
6422        let json_class_aware = r#"{
6423            "outputs": [{
6424                "decoder": "ultralytics",
6425                "type": "detection",
6426                "shape": [1, 84, 8400],
6427                "dshape": [["batch", 1], ["num_features", 84], ["num_boxes", 8400]]
6428            }],
6429            "nms": "class_aware"
6430        }"#;
6431        let decoder = DecoderBuilder::new()
6432            .with_config_json_str(json_class_aware.to_string())
6433            .build()
6434            .unwrap();
6435        assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
6436    }
6437
6438    #[test]
6439    fn test_nms_missing_from_config_uses_builder_default() {
6440        // Test that missing NMS in config uses builder default
6441        let yaml_no_nms = r#"
6442outputs:
6443  - decoder: ultralytics
6444    type: detection
6445    shape: [1, 84, 8400]
6446    dshape:
6447      - [batch, 1]
6448      - [num_features, 84]
6449      - [num_boxes, 8400]
6450"#;
6451        let decoder = DecoderBuilder::new()
6452            .with_config_yaml_str(yaml_no_nms.to_string())
6453            .build()
6454            .unwrap();
6455        // Default builder NMS is ClassAgnostic
6456        assert_eq!(decoder.nms, Some(configs::Nms::ClassAgnostic));
6457
6458        // Test with explicit builder NMS
6459        let decoder = DecoderBuilder::new()
6460            .with_config_yaml_str(yaml_no_nms.to_string())
6461            .with_nms(None) // Explicitly set to None (bypass NMS)
6462            .build()
6463            .unwrap();
6464        assert_eq!(decoder.nms, None);
6465    }
6466
6467    #[test]
6468    fn test_decoder_version_yolo26_end_to_end() {
6469        // Test that decoder_version: yolo26 creates end-to-end model type
6470        let yaml = r#"
6471outputs:
6472  - decoder: ultralytics
6473    type: detection
6474    shape: [1, 6, 8400]
6475    dshape:
6476      - [batch, 1]
6477      - [num_features, 6]
6478      - [num_boxes, 8400]
6479decoder_version: yolo26
6480"#;
6481        let decoder = DecoderBuilder::new()
6482            .with_config_yaml_str(yaml.to_string())
6483            .build()
6484            .unwrap();
6485        assert!(matches!(
6486            decoder.model_type,
6487            ModelType::YoloEndToEndDet { .. }
6488        ));
6489
6490        // Even with NMS set, yolo26 should use end-to-end
6491        let yaml_with_nms = r#"
6492outputs:
6493  - decoder: ultralytics
6494    type: detection
6495    shape: [1, 6, 8400]
6496    dshape:
6497      - [batch, 1]
6498      - [num_features, 6]
6499      - [num_boxes, 8400]
6500decoder_version: yolo26
6501nms: class_agnostic
6502"#;
6503        let decoder = DecoderBuilder::new()
6504            .with_config_yaml_str(yaml_with_nms.to_string())
6505            .build()
6506            .unwrap();
6507        assert!(matches!(
6508            decoder.model_type,
6509            ModelType::YoloEndToEndDet { .. }
6510        ));
6511    }
6512
6513    #[test]
6514    fn test_decoder_version_yolov8_traditional() {
6515        // Test that decoder_version: yolov8 creates traditional model type
6516        let yaml = r#"
6517outputs:
6518  - decoder: ultralytics
6519    type: detection
6520    shape: [1, 84, 8400]
6521    dshape:
6522      - [batch, 1]
6523      - [num_features, 84]
6524      - [num_boxes, 8400]
6525decoder_version: yolov8
6526"#;
6527        let decoder = DecoderBuilder::new()
6528            .with_config_yaml_str(yaml.to_string())
6529            .build()
6530            .unwrap();
6531        assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
6532    }
6533
6534    #[test]
6535    fn test_decoder_version_all_versions() {
6536        // Test all supported decoder versions parse correctly
6537        for version in ["yolov5", "yolov8", "yolo11"] {
6538            let yaml = format!(
6539                r#"
6540outputs:
6541  - decoder: ultralytics
6542    type: detection
6543    shape: [1, 84, 8400]
6544    dshape:
6545      - [batch, 1]
6546      - [num_features, 84]
6547      - [num_boxes, 8400]
6548decoder_version: {}
6549"#,
6550                version
6551            );
6552            let decoder = DecoderBuilder::new()
6553                .with_config_yaml_str(yaml)
6554                .build()
6555                .unwrap();
6556
6557            assert!(
6558                matches!(decoder.model_type, ModelType::YoloDet { .. }),
6559                "Expected traditional for {}",
6560                version
6561            );
6562        }
6563
6564        let yaml = r#"
6565outputs:
6566  - decoder: ultralytics
6567    type: detection
6568    shape: [1, 6, 8400]
6569    dshape:
6570      - [batch, 1]
6571      - [num_features, 6]
6572      - [num_boxes, 8400]
6573decoder_version: yolo26
6574"#
6575        .to_string();
6576
6577        let decoder = DecoderBuilder::new()
6578            .with_config_yaml_str(yaml)
6579            .build()
6580            .unwrap();
6581
6582        assert!(
6583            matches!(decoder.model_type, ModelType::YoloEndToEndDet { .. }),
6584            "Expected end to end for yolo26",
6585        );
6586    }
6587
6588    #[test]
6589    fn test_decoder_version_json() {
6590        // Test parsing decoder_version from JSON config
6591        let json = r#"{
6592            "outputs": [{
6593                "decoder": "ultralytics",
6594                "type": "detection",
6595                "shape": [1, 6, 8400],
6596                "dshape": [["batch", 1], ["num_features", 6], ["num_boxes", 8400]]
6597            }],
6598            "decoder_version": "yolo26"
6599        }"#;
6600        let decoder = DecoderBuilder::new()
6601            .with_config_json_str(json.to_string())
6602            .build()
6603            .unwrap();
6604        assert!(matches!(
6605            decoder.model_type,
6606            ModelType::YoloEndToEndDet { .. }
6607        ));
6608    }
6609
6610    #[test]
6611    fn test_decoder_version_none_uses_traditional() {
6612        // Without decoder_version, traditional model type is used
6613        let yaml = r#"
6614outputs:
6615  - decoder: ultralytics
6616    type: detection
6617    shape: [1, 84, 8400]
6618    dshape:
6619      - [batch, 1]
6620      - [num_features, 84]
6621      - [num_boxes, 8400]
6622"#;
6623        let decoder = DecoderBuilder::new()
6624            .with_config_yaml_str(yaml.to_string())
6625            .build()
6626            .unwrap();
6627        assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
6628    }
6629
6630    #[test]
6631    fn test_decoder_version_none_with_nms_none_still_traditional() {
6632        // Without decoder_version, nms: None now means user handles NMS, not end-to-end
6633        // This is a behavior change from the previous implementation
6634        let yaml = r#"
6635outputs:
6636  - decoder: ultralytics
6637    type: detection
6638    shape: [1, 84, 8400]
6639    dshape:
6640      - [batch, 1]
6641      - [num_features, 84]
6642      - [num_boxes, 8400]
6643"#;
6644        let decoder = DecoderBuilder::new()
6645            .with_config_yaml_str(yaml.to_string())
6646            .with_nms(None) // User wants to handle NMS themselves
6647            .build()
6648            .unwrap();
6649        // nms=None with 84 features (80 classes) -> traditional YoloDet (user handles
6650        // NMS)
6651        assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
6652    }
6653
6654    #[test]
6655    fn test_decoder_heuristic_end_to_end_detection() {
6656        // models with (batch, num_boxes, num_features) output shape are treated
6657        // as end-to-end detection
6658        let yaml = r#"
6659outputs:
6660  - decoder: ultralytics
6661    type: detection
6662    shape: [1, 300, 6]
6663    dshape:
6664      - [batch, 1]
6665      - [num_boxes, 300]
6666      - [num_features, 6]
6667 
6668"#;
6669        let decoder = DecoderBuilder::new()
6670            .with_config_yaml_str(yaml.to_string())
6671            .build()
6672            .unwrap();
6673        // 6 features with (batch, N, features) layout -> end-to-end detection
6674        assert!(matches!(
6675            decoder.model_type,
6676            ModelType::YoloEndToEndDet { .. }
6677        ));
6678
6679        let yaml = r#"
6680outputs:
6681  - decoder: ultralytics
6682    type: detection
6683    shape: [1, 300, 38]
6684    dshape:
6685      - [batch, 1]
6686      - [num_boxes, 300]
6687      - [num_features, 38]
6688  - decoder: ultralytics
6689    type: protos
6690    shape: [1, 160, 160, 32]
6691    dshape:
6692      - [batch, 1]
6693      - [height, 160]
6694      - [width, 160]
6695      - [num_protos, 32]
6696"#;
6697        let decoder = DecoderBuilder::new()
6698            .with_config_yaml_str(yaml.to_string())
6699            .build()
6700            .unwrap();
6701        // 7 features with protos -> end-to-end segmentation detection
6702        assert!(matches!(
6703            decoder.model_type,
6704            ModelType::YoloEndToEndSegDet { .. }
6705        ));
6706
6707        let yaml = r#"
6708outputs:
6709  - decoder: ultralytics
6710    type: detection
6711    shape: [1, 6, 300]
6712    dshape:
6713      - [batch, 1]
6714      - [num_features, 6]
6715      - [num_boxes, 300] 
6716"#;
6717        let decoder = DecoderBuilder::new()
6718            .with_config_yaml_str(yaml.to_string())
6719            .build()
6720            .unwrap();
6721        // 6 features -> traditional YOLO detection (needs num_classes > 0 for
6722        // end-to-end)
6723        assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
6724
6725        let yaml = r#"
6726outputs:
6727  - decoder: ultralytics
6728    type: detection
6729    shape: [1, 38, 300]
6730    dshape:
6731      - [batch, 1]
6732      - [num_features, 38]
6733      - [num_boxes, 300]
6734
6735  - decoder: ultralytics
6736    type: protos
6737    shape: [1, 160, 160, 32]
6738    dshape:
6739      - [batch, 1]
6740      - [height, 160]
6741      - [width, 160]
6742      - [num_protos, 32]
6743"#;
6744        let decoder = DecoderBuilder::new()
6745            .with_config_yaml_str(yaml.to_string())
6746            .build()
6747            .unwrap();
6748        // 38 features (4+2+32) with protos -> traditional YOLO segmentation detection
6749        assert!(matches!(decoder.model_type, ModelType::YoloSegDet { .. }));
6750    }
6751
6752    #[test]
6753    fn test_decoder_version_is_end_to_end() {
6754        assert!(!configs::DecoderVersion::Yolov5.is_end_to_end());
6755        assert!(!configs::DecoderVersion::Yolov8.is_end_to_end());
6756        assert!(!configs::DecoderVersion::Yolo11.is_end_to_end());
6757        assert!(configs::DecoderVersion::Yolo26.is_end_to_end());
6758    }
6759
6760    #[test]
6761    fn test_dshape_dict_format() {
6762        // Spec produces array-of-single-key-dicts: [{"batch": 1}, {"num_features": 84}]
6763        let json = r#"{
6764            "decoder": "ultralytics",
6765            "shape": [1, 84, 8400],
6766            "dshape": [{"batch": 1}, {"num_features": 84}, {"num_boxes": 8400}]
6767        }"#;
6768        let det: configs::Detection = serde_json::from_str(json).unwrap();
6769        assert_eq!(det.dshape.len(), 3);
6770        assert_eq!(det.dshape[0], (configs::DimName::Batch, 1));
6771        assert_eq!(det.dshape[1], (configs::DimName::NumFeatures, 84));
6772        assert_eq!(det.dshape[2], (configs::DimName::NumBoxes, 8400));
6773    }
6774
6775    #[test]
6776    fn test_dshape_tuple_format() {
6777        // Serde native tuple format: [["batch", 1], ["num_features", 84]]
6778        let json = r#"{
6779            "decoder": "ultralytics",
6780            "shape": [1, 84, 8400],
6781            "dshape": [["batch", 1], ["num_features", 84], ["num_boxes", 8400]]
6782        }"#;
6783        let det: configs::Detection = serde_json::from_str(json).unwrap();
6784        assert_eq!(det.dshape.len(), 3);
6785        assert_eq!(det.dshape[0], (configs::DimName::Batch, 1));
6786        assert_eq!(det.dshape[1], (configs::DimName::NumFeatures, 84));
6787        assert_eq!(det.dshape[2], (configs::DimName::NumBoxes, 8400));
6788    }
6789
6790    #[test]
6791    fn test_dshape_empty_default() {
6792        // When dshape is omitted entirely, default to empty vec
6793        let json = r#"{
6794            "decoder": "ultralytics",
6795            "shape": [1, 84, 8400]
6796        }"#;
6797        let det: configs::Detection = serde_json::from_str(json).unwrap();
6798        assert!(det.dshape.is_empty());
6799    }
6800
6801    #[test]
6802    fn test_dshape_dict_format_protos() {
6803        let json = r#"{
6804            "decoder": "ultralytics",
6805            "shape": [1, 32, 160, 160],
6806            "dshape": [{"batch": 1}, {"num_protos": 32}, {"height": 160}, {"width": 160}]
6807        }"#;
6808        let protos: configs::Protos = serde_json::from_str(json).unwrap();
6809        assert_eq!(protos.dshape.len(), 4);
6810        assert_eq!(protos.dshape[0], (configs::DimName::Batch, 1));
6811        assert_eq!(protos.dshape[1], (configs::DimName::NumProtos, 32));
6812    }
6813
6814    #[test]
6815    fn test_dshape_dict_format_boxes() {
6816        let json = r#"{
6817            "decoder": "ultralytics",
6818            "shape": [1, 8400, 4],
6819            "dshape": [{"batch": 1}, {"num_boxes": 8400}, {"box_coords": 4}]
6820        }"#;
6821        let boxes: configs::Boxes = serde_json::from_str(json).unwrap();
6822        assert_eq!(boxes.dshape.len(), 3);
6823        assert_eq!(boxes.dshape[2], (configs::DimName::BoxCoords, 4));
6824    }
6825}