Skip to main content

edgefirst_decoder/decoder/
builder.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashSet;
5
6use super::config::ConfigOutputRef;
7use super::configs::{self, DecoderType, DecoderVersion, DimName, ModelType};
8use super::{ConfigOutput, ConfigOutputs, Decoder};
9use crate::DecoderError;
10
/// Builder that collects a model configuration plus post-processing settings
/// and turns them into a [`Decoder`] via [`DecoderBuilder::build`].
///
/// The configuration can come from a YAML or JSON string, or from a
/// programmatically assembled [`ConfigOutputs`].
#[derive(Debug, Clone, PartialEq)]
pub struct DecoderBuilder {
    /// Where the model configuration comes from; `None` until one of the
    /// `with_config*` / `add_output` / `with_decoder_version` methods is called.
    config_src: Option<ConfigSource>,
    /// IoU threshold copied into the built [`Decoder`].
    iou_threshold: f32,
    /// Score threshold copied into the built [`Decoder`].
    score_threshold: f32,
    /// NMS mode: Some(mode) applies NMS, None bypasses NMS (for end-to-end
    /// models). An `nms` value present in the configuration itself takes
    /// precedence over this field when building.
    nms: Option<configs::Nms>,
}
20
/// Origin of the model configuration held by a [`DecoderBuilder`].
#[derive(Debug, Clone, PartialEq)]
enum ConfigSource {
    /// A YAML document; `build()` deserializes it with `serde_yaml`.
    Yaml(String),
    /// A JSON document; `build()` deserializes it with `serde_json`.
    Json(String),
    /// An already-constructed configuration, used as-is by `build()`.
    Config(ConfigOutputs),
}
27
28impl Default for DecoderBuilder {
29    /// Creates a default DecoderBuilder with no configuration and 0.5 score
30    /// threshold and 0.5 IoU threshold.
31    ///
32    /// A valid configuration must be provided before building the Decoder.
33    ///
34    /// # Examples
35    /// ```rust
36    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
37    /// # fn main() -> DecoderResult<()> {
38    /// #  let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
39    /// let decoder = DecoderBuilder::default()
40    ///     .with_config_yaml_str(config_yaml)
41    ///     .build()?;
42    /// assert_eq!(decoder.score_threshold, 0.5);
43    /// assert_eq!(decoder.iou_threshold, 0.5);
44    ///
45    /// # Ok(())
46    /// # }
47    /// ```
48    fn default() -> Self {
49        Self {
50            config_src: None,
51            iou_threshold: 0.5,
52            score_threshold: 0.5,
53            nms: Some(configs::Nms::ClassAgnostic),
54        }
55    }
56}
57
58impl DecoderBuilder {
59    /// Creates a default DecoderBuilder with no configuration and 0.5 score
60    /// threshold and 0.5 IoU threshold.
61    ///
62    /// A valid configuration must be provided before building the Decoder.
63    ///
64    /// # Examples
65    /// ```rust
66    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
67    /// # fn main() -> DecoderResult<()> {
68    /// #  let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
69    /// let decoder = DecoderBuilder::new()
70    ///     .with_config_yaml_str(config_yaml)
71    ///     .build()?;
72    /// assert_eq!(decoder.score_threshold, 0.5);
73    /// assert_eq!(decoder.iou_threshold, 0.5);
74    ///
75    /// # Ok(())
76    /// # }
77    /// ```
78    pub fn new() -> Self {
79        Self::default()
80    }
81
82    /// Loads a model configuration in YAML format. Does not check if the string
83    /// is a correct configuration file. Use `DecoderBuilder.build()` to
84    /// deserialize the YAML and parse the model configuration.
85    ///
86    /// # Examples
87    /// ```rust
88    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
89    /// # fn main() -> DecoderResult<()> {
90    /// let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
91    /// let decoder = DecoderBuilder::new()
92    ///     .with_config_yaml_str(config_yaml)
93    ///     .build()?;
94    ///
95    /// # Ok(())
96    /// # }
97    /// ```
98    pub fn with_config_yaml_str(mut self, yaml_str: String) -> Self {
99        self.config_src.replace(ConfigSource::Yaml(yaml_str));
100        self
101    }
102
103    /// Loads a model configuration in JSON format. Does not check if the string
104    /// is a correct configuration file. Use `DecoderBuilder.build()` to
105    /// deserialize the JSON and parse the model configuration.
106    ///
107    /// # Examples
108    /// ```rust
109    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
110    /// # fn main() -> DecoderResult<()> {
111    /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
112    /// let decoder = DecoderBuilder::new()
113    ///     .with_config_json_str(config_json)
114    ///     .build()?;
115    ///
116    /// # Ok(())
117    /// # }
118    /// ```
119    pub fn with_config_json_str(mut self, json_str: String) -> Self {
120        self.config_src.replace(ConfigSource::Json(json_str));
121        self
122    }
123
124    /// Loads a model configuration. Does not check if the configuration is
125    /// correct. Intended to be used when the user needs control over the
126    /// deserialize of the configuration information. Use
127    /// `DecoderBuilder.build()` to parse the model configuration.
128    ///
129    /// # Examples
130    /// ```rust
131    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
132    /// # fn main() -> DecoderResult<()> {
133    /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json"));
134    /// let config = serde_json::from_str(config_json)?;
135    /// let decoder = DecoderBuilder::new().with_config(config).build()?;
136    ///
137    /// # Ok(())
138    /// # }
139    /// ```
140    pub fn with_config(mut self, config: ConfigOutputs) -> Self {
141        self.config_src.replace(ConfigSource::Config(config));
142        self
143    }
144
145    /// Loads a YOLO detection model configuration.  Use
146    /// `DecoderBuilder.build()` to parse the model configuration.
147    ///
148    /// # Examples
149    /// ```rust
150    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
151    /// # fn main() -> DecoderResult<()> {
152    /// let decoder = DecoderBuilder::new()
153    ///     .with_config_yolo_det(
154    ///         configs::Detection {
155    ///             anchors: None,
156    ///             decoder: configs::DecoderType::Ultralytics,
157    ///             quantization: Some(configs::QuantTuple(0.012345, 26)),
158    ///             shape: vec![1, 84, 8400],
159    ///             dshape: Vec::new(),
160    ///             normalized: Some(true),
161    ///         },
162    ///         None,
163    ///     )
164    ///     .build()?;
165    ///
166    /// # Ok(())
167    /// # }
168    /// ```
169    pub fn with_config_yolo_det(
170        mut self,
171        boxes: configs::Detection,
172        version: Option<DecoderVersion>,
173    ) -> Self {
174        let config = ConfigOutputs {
175            outputs: vec![ConfigOutput::Detection(boxes)],
176            decoder_version: version,
177            ..Default::default()
178        };
179        self.config_src.replace(ConfigSource::Config(config));
180        self
181    }
182
183    /// Loads a YOLO split detection model configuration.  Use
184    /// `DecoderBuilder.build()` to parse the model configuration.
185    ///
186    /// # Examples
187    /// ```rust
188    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
189    /// # fn main() -> DecoderResult<()> {
190    /// let boxes_config = configs::Boxes {
191    ///     decoder: configs::DecoderType::Ultralytics,
192    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
193    ///     shape: vec![1, 4, 8400],
194    ///     dshape: Vec::new(),
195    ///     normalized: Some(true),
196    /// };
197    /// let scores_config = configs::Scores {
198    ///     decoder: configs::DecoderType::Ultralytics,
199    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
200    ///     shape: vec![1, 80, 8400],
201    ///     dshape: Vec::new(),
202    /// };
203    /// let decoder = DecoderBuilder::new()
204    ///     .with_config_yolo_split_det(boxes_config, scores_config)
205    ///     .build()?;
206    /// # Ok(())
207    /// # }
208    /// ```
209    pub fn with_config_yolo_split_det(
210        mut self,
211        boxes: configs::Boxes,
212        scores: configs::Scores,
213    ) -> Self {
214        let config = ConfigOutputs {
215            outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
216            ..Default::default()
217        };
218        self.config_src.replace(ConfigSource::Config(config));
219        self
220    }
221
222    /// Loads a YOLO segmentation model configuration.  Use
223    /// `DecoderBuilder.build()` to parse the model configuration.
224    ///
225    /// # Examples
226    /// ```rust
227    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
228    /// # fn main() -> DecoderResult<()> {
229    /// let seg_config = configs::Detection {
230    ///     decoder: configs::DecoderType::Ultralytics,
231    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
232    ///     shape: vec![1, 116, 8400],
233    ///     anchors: None,
234    ///     dshape: Vec::new(),
235    ///     normalized: Some(true),
236    /// };
237    /// let protos_config = configs::Protos {
238    ///     decoder: configs::DecoderType::Ultralytics,
239    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
240    ///     shape: vec![1, 160, 160, 32],
241    ///     dshape: Vec::new(),
242    /// };
243    /// let decoder = DecoderBuilder::new()
244    ///     .with_config_yolo_segdet(
245    ///         seg_config,
246    ///         protos_config,
247    ///         Some(configs::DecoderVersion::Yolov8),
248    ///     )
249    ///     .build()?;
250    /// # Ok(())
251    /// # }
252    /// ```
253    pub fn with_config_yolo_segdet(
254        mut self,
255        boxes: configs::Detection,
256        protos: configs::Protos,
257        version: Option<DecoderVersion>,
258    ) -> Self {
259        let config = ConfigOutputs {
260            outputs: vec![ConfigOutput::Detection(boxes), ConfigOutput::Protos(protos)],
261            decoder_version: version,
262            ..Default::default()
263        };
264        self.config_src.replace(ConfigSource::Config(config));
265        self
266    }
267
268    /// Loads a YOLO split segmentation model configuration.  Use
269    /// `DecoderBuilder.build()` to parse the model configuration.
270    ///
271    /// # Examples
272    /// ```rust
273    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
274    /// # fn main() -> DecoderResult<()> {
275    /// let boxes_config = configs::Boxes {
276    ///     decoder: configs::DecoderType::Ultralytics,
277    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
278    ///     shape: vec![1, 4, 8400],
279    ///     dshape: Vec::new(),
280    ///     normalized: Some(true),
281    /// };
282    /// let scores_config = configs::Scores {
283    ///     decoder: configs::DecoderType::Ultralytics,
284    ///     quantization: Some(configs::QuantTuple(0.012345, 14)),
285    ///     shape: vec![1, 80, 8400],
286    ///     dshape: Vec::new(),
287    /// };
288    /// let mask_config = configs::MaskCoefficients {
289    ///     decoder: configs::DecoderType::Ultralytics,
290    ///     quantization: Some(configs::QuantTuple(0.0064123, 125)),
291    ///     shape: vec![1, 32, 8400],
292    ///     dshape: Vec::new(),
293    /// };
294    /// let protos_config = configs::Protos {
295    ///     decoder: configs::DecoderType::Ultralytics,
296    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
297    ///     shape: vec![1, 160, 160, 32],
298    ///     dshape: Vec::new(),
299    /// };
300    /// let decoder = DecoderBuilder::new()
301    ///     .with_config_yolo_split_segdet(boxes_config, scores_config, mask_config, protos_config)
302    ///     .build()?;
303    /// # Ok(())
304    /// # }
305    /// ```
306    pub fn with_config_yolo_split_segdet(
307        mut self,
308        boxes: configs::Boxes,
309        scores: configs::Scores,
310        mask_coefficients: configs::MaskCoefficients,
311        protos: configs::Protos,
312    ) -> Self {
313        let config = ConfigOutputs {
314            outputs: vec![
315                ConfigOutput::Boxes(boxes),
316                ConfigOutput::Scores(scores),
317                ConfigOutput::MaskCoefficients(mask_coefficients),
318                ConfigOutput::Protos(protos),
319            ],
320            ..Default::default()
321        };
322        self.config_src.replace(ConfigSource::Config(config));
323        self
324    }
325
326    /// Loads a ModelPack detection model configuration.  Use
327    /// `DecoderBuilder.build()` to parse the model configuration.
328    ///
329    /// # Examples
330    /// ```rust
331    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
332    /// # fn main() -> DecoderResult<()> {
333    /// let boxes_config = configs::Boxes {
334    ///     decoder: configs::DecoderType::ModelPack,
335    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
336    ///     shape: vec![1, 8400, 1, 4],
337    ///     dshape: Vec::new(),
338    ///     normalized: Some(true),
339    /// };
340    /// let scores_config = configs::Scores {
341    ///     decoder: configs::DecoderType::ModelPack,
342    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
343    ///     shape: vec![1, 8400, 3],
344    ///     dshape: Vec::new(),
345    /// };
346    /// let decoder = DecoderBuilder::new()
347    ///     .with_config_modelpack_det(boxes_config, scores_config)
348    ///     .build()?;
349    /// # Ok(())
350    /// # }
351    /// ```
352    pub fn with_config_modelpack_det(
353        mut self,
354        boxes: configs::Boxes,
355        scores: configs::Scores,
356    ) -> Self {
357        let config = ConfigOutputs {
358            outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
359            ..Default::default()
360        };
361        self.config_src.replace(ConfigSource::Config(config));
362        self
363    }
364
365    /// Loads a ModelPack split detection model configuration. Use
366    /// `DecoderBuilder.build()` to parse the model configuration.
367    ///
368    /// # Examples
369    /// ```rust
370    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
371    /// # fn main() -> DecoderResult<()> {
372    /// let config0 = configs::Detection {
373    ///     anchors: Some(vec![
374    ///         [0.13750000298023224, 0.2074074000120163],
375    ///         [0.2541666626930237, 0.21481481194496155],
376    ///         [0.23125000298023224, 0.35185185074806213],
377    ///     ]),
378    ///     decoder: configs::DecoderType::ModelPack,
379    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
380    ///     shape: vec![1, 17, 30, 18],
381    ///     dshape: Vec::new(),
382    ///     normalized: Some(true),
383    /// };
384    /// let config1 = configs::Detection {
385    ///     anchors: Some(vec![
386    ///         [0.36666667461395264, 0.31481480598449707],
387    ///         [0.38749998807907104, 0.4740740656852722],
388    ///         [0.5333333611488342, 0.644444465637207],
389    ///     ]),
390    ///     decoder: configs::DecoderType::ModelPack,
391    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
392    ///     shape: vec![1, 9, 15, 18],
393    ///     dshape: Vec::new(),
394    ///     normalized: Some(true),
395    /// };
396    ///
397    /// let decoder = DecoderBuilder::new()
398    ///     .with_config_modelpack_det_split(vec![config0, config1])
399    ///     .build()?;
400    /// # Ok(())
401    /// # }
402    /// ```
403    pub fn with_config_modelpack_det_split(mut self, boxes: Vec<configs::Detection>) -> Self {
404        let outputs = boxes.into_iter().map(ConfigOutput::Detection).collect();
405        let config = ConfigOutputs {
406            outputs,
407            ..Default::default()
408        };
409        self.config_src.replace(ConfigSource::Config(config));
410        self
411    }
412
413    /// Loads a ModelPack segmentation detection model configuration. Use
414    /// `DecoderBuilder.build()` to parse the model configuration.
415    ///
416    /// # Examples
417    /// ```rust
418    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
419    /// # fn main() -> DecoderResult<()> {
420    /// let boxes_config = configs::Boxes {
421    ///     decoder: configs::DecoderType::ModelPack,
422    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
423    ///     shape: vec![1, 8400, 1, 4],
424    ///     dshape: Vec::new(),
425    ///     normalized: Some(true),
426    /// };
427    /// let scores_config = configs::Scores {
428    ///     decoder: configs::DecoderType::ModelPack,
429    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
430    ///     shape: vec![1, 8400, 2],
431    ///     dshape: Vec::new(),
432    /// };
433    /// let seg_config = configs::Segmentation {
434    ///     decoder: configs::DecoderType::ModelPack,
435    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
436    ///     shape: vec![1, 640, 640, 3],
437    ///     dshape: Vec::new(),
438    /// };
439    /// let decoder = DecoderBuilder::new()
440    ///     .with_config_modelpack_segdet(boxes_config, scores_config, seg_config)
441    ///     .build()?;
442    /// # Ok(())
443    /// # }
444    /// ```
445    pub fn with_config_modelpack_segdet(
446        mut self,
447        boxes: configs::Boxes,
448        scores: configs::Scores,
449        segmentation: configs::Segmentation,
450    ) -> Self {
451        let config = ConfigOutputs {
452            outputs: vec![
453                ConfigOutput::Boxes(boxes),
454                ConfigOutput::Scores(scores),
455                ConfigOutput::Segmentation(segmentation),
456            ],
457            ..Default::default()
458        };
459        self.config_src.replace(ConfigSource::Config(config));
460        self
461    }
462
463    /// Loads a ModelPack segmentation split detection model configuration. Use
464    /// `DecoderBuilder.build()` to parse the model configuration.
465    ///
466    /// # Examples
467    /// ```rust
468    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
469    /// # fn main() -> DecoderResult<()> {
470    /// let config0 = configs::Detection {
471    ///     anchors: Some(vec![
472    ///         [0.36666667461395264, 0.31481480598449707],
473    ///         [0.38749998807907104, 0.4740740656852722],
474    ///         [0.5333333611488342, 0.644444465637207],
475    ///     ]),
476    ///     decoder: configs::DecoderType::ModelPack,
477    ///     quantization: Some(configs::QuantTuple(0.08547406643629074, 174)),
478    ///     shape: vec![1, 9, 15, 18],
479    ///     dshape: Vec::new(),
480    ///     normalized: Some(true),
481    /// };
482    /// let config1 = configs::Detection {
483    ///     anchors: Some(vec![
484    ///         [0.13750000298023224, 0.2074074000120163],
485    ///         [0.2541666626930237, 0.21481481194496155],
486    ///         [0.23125000298023224, 0.35185185074806213],
487    ///     ]),
488    ///     decoder: configs::DecoderType::ModelPack,
489    ///     quantization: Some(configs::QuantTuple(0.09929127991199493, 183)),
490    ///     shape: vec![1, 17, 30, 18],
491    ///     dshape: Vec::new(),
492    ///     normalized: Some(true),
493    /// };
494    /// let seg_config = configs::Segmentation {
495    ///     decoder: configs::DecoderType::ModelPack,
496    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
497    ///     shape: vec![1, 640, 640, 2],
498    ///     dshape: Vec::new(),
499    /// };
500    /// let decoder = DecoderBuilder::new()
501    ///     .with_config_modelpack_segdet_split(vec![config0, config1], seg_config)
502    ///     .build()?;
503    /// # Ok(())
504    /// # }
505    /// ```
506    pub fn with_config_modelpack_segdet_split(
507        mut self,
508        boxes: Vec<configs::Detection>,
509        segmentation: configs::Segmentation,
510    ) -> Self {
511        let mut outputs = boxes
512            .into_iter()
513            .map(ConfigOutput::Detection)
514            .collect::<Vec<_>>();
515        outputs.push(ConfigOutput::Segmentation(segmentation));
516        let config = ConfigOutputs {
517            outputs,
518            ..Default::default()
519        };
520        self.config_src.replace(ConfigSource::Config(config));
521        self
522    }
523
524    /// Loads a ModelPack segmentation model configuration. Use
525    /// `DecoderBuilder.build()` to parse the model configuration.
526    ///
527    /// # Examples
528    /// ```rust
529    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
530    /// # fn main() -> DecoderResult<()> {
531    /// let seg_config = configs::Segmentation {
532    ///     decoder: configs::DecoderType::ModelPack,
533    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
534    ///     shape: vec![1, 640, 640, 3],
535    ///     dshape: Vec::new(),
536    /// };
537    /// let decoder = DecoderBuilder::new()
538    ///     .with_config_modelpack_seg(seg_config)
539    ///     .build()?;
540    /// # Ok(())
541    /// # }
542    /// ```
543    pub fn with_config_modelpack_seg(mut self, segmentation: configs::Segmentation) -> Self {
544        let config = ConfigOutputs {
545            outputs: vec![ConfigOutput::Segmentation(segmentation)],
546            ..Default::default()
547        };
548        self.config_src.replace(ConfigSource::Config(config));
549        self
550    }
551
552    /// Add an output to the decoder configuration.
553    ///
554    /// Incrementally builds the model configuration by adding outputs one at
555    /// a time. The decoder resolves the model type from the combination of
556    /// outputs during `build()`.
557    ///
558    /// If `dshape` is non-empty on the output, `shape` is automatically
559    /// derived from it (the size component of each named dimension). This
560    /// prevents conflicts between `shape` and `dshape`.
561    ///
562    /// This uses the programmatic config path. Calling this after
563    /// `with_config_json_str()` or `with_config_yaml_str()` replaces the
564    /// string-based config source.
565    ///
566    /// # Examples
567    /// ```rust
568    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
569    /// # fn main() -> DecoderResult<()> {
570    /// let decoder = DecoderBuilder::new()
571    ///     .add_output(ConfigOutput::Scores(configs::Scores {
572    ///         decoder: configs::DecoderType::Ultralytics,
573    ///         dshape: vec![
574    ///             (configs::DimName::Batch, 1),
575    ///             (configs::DimName::NumClasses, 80),
576    ///             (configs::DimName::NumBoxes, 8400),
577    ///         ],
578    ///         ..Default::default()
579    ///     }))
580    ///     .add_output(ConfigOutput::Boxes(configs::Boxes {
581    ///         decoder: configs::DecoderType::Ultralytics,
582    ///         dshape: vec![
583    ///             (configs::DimName::Batch, 1),
584    ///             (configs::DimName::BoxCoords, 4),
585    ///             (configs::DimName::NumBoxes, 8400),
586    ///         ],
587    ///         ..Default::default()
588    ///     }))
589    ///     .build()?;
590    /// # Ok(())
591    /// # }
592    /// ```
593    pub fn add_output(mut self, output: ConfigOutput) -> Self {
594        if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
595            self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
596        }
597        if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
598            config.outputs.push(Self::normalize_output(output));
599        }
600        self
601    }
602
603    /// Sets the decoder version for Ultralytics models.
604    ///
605    /// This is used with `add_output()` to specify the YOLO architecture
606    /// version when it cannot be inferred from the output shapes alone.
607    ///
608    /// - `Yolov5`, `Yolov8`, `Yolo11`: Traditional models requiring external
609    ///   NMS
610    /// - `Yolo26`: End-to-end models with NMS embedded in the model graph
611    ///
612    /// # Examples
613    /// ```rust
614    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
615    /// # fn main() -> DecoderResult<()> {
616    /// let decoder = DecoderBuilder::new()
617    ///     .add_output(ConfigOutput::Detection(configs::Detection {
618    ///         decoder: configs::DecoderType::Ultralytics,
619    ///         dshape: vec![
620    ///             (configs::DimName::Batch, 1),
621    ///             (configs::DimName::NumBoxes, 100),
622    ///             (configs::DimName::NumFeatures, 6),
623    ///         ],
624    ///         ..Default::default()
625    ///     }))
626    ///     .with_decoder_version(configs::DecoderVersion::Yolo26)
627    ///     .build()?;
628    /// # Ok(())
629    /// # }
630    /// ```
631    pub fn with_decoder_version(mut self, version: configs::DecoderVersion) -> Self {
632        if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
633            self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
634        }
635        if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
636            config.decoder_version = Some(version);
637        }
638        self
639    }
640
641    /// Normalize an output: if dshape is non-empty, derive shape from it.
642    fn normalize_output(mut output: ConfigOutput) -> ConfigOutput {
643        fn normalize_shape(shape: &mut Vec<usize>, dshape: &[(configs::DimName, usize)]) {
644            if !dshape.is_empty() {
645                *shape = dshape.iter().map(|(_, size)| *size).collect();
646            }
647        }
648        match &mut output {
649            ConfigOutput::Detection(c) => normalize_shape(&mut c.shape, &c.dshape),
650            ConfigOutput::Boxes(c) => normalize_shape(&mut c.shape, &c.dshape),
651            ConfigOutput::Scores(c) => normalize_shape(&mut c.shape, &c.dshape),
652            ConfigOutput::Protos(c) => normalize_shape(&mut c.shape, &c.dshape),
653            ConfigOutput::Segmentation(c) => normalize_shape(&mut c.shape, &c.dshape),
654            ConfigOutput::MaskCoefficients(c) => normalize_shape(&mut c.shape, &c.dshape),
655            ConfigOutput::Mask(c) => normalize_shape(&mut c.shape, &c.dshape),
656            ConfigOutput::Classes(c) => normalize_shape(&mut c.shape, &c.dshape),
657        }
658        output
659    }
660
661    /// Sets the scores threshold of the decoder
662    ///
663    /// # Examples
664    /// ```rust
665    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
666    /// # fn main() -> DecoderResult<()> {
667    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
668    /// let decoder = DecoderBuilder::new()
669    ///     .with_config_json_str(config_json)
670    ///     .with_score_threshold(0.654)
671    ///     .build()?;
672    /// assert_eq!(decoder.score_threshold, 0.654);
673    /// # Ok(())
674    /// # }
675    /// ```
676    pub fn with_score_threshold(mut self, score_threshold: f32) -> Self {
677        self.score_threshold = score_threshold;
678        self
679    }
680
681    /// Sets the IOU threshold of the decoder. Has no effect when NMS is set to
682    /// `None`
683    ///
684    /// # Examples
685    /// ```rust
686    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
687    /// # fn main() -> DecoderResult<()> {
688    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
689    /// let decoder = DecoderBuilder::new()
690    ///     .with_config_json_str(config_json)
691    ///     .with_iou_threshold(0.654)
692    ///     .build()?;
693    /// assert_eq!(decoder.iou_threshold, 0.654);
694    /// # Ok(())
695    /// # }
696    /// ```
697    pub fn with_iou_threshold(mut self, iou_threshold: f32) -> Self {
698        self.iou_threshold = iou_threshold;
699        self
700    }
701
702    /// Sets the NMS mode for the decoder.
703    ///
704    /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS (default): suppress
705    ///   overlapping boxes regardless of class label
706    /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
707    ///   share the same class label AND overlap above the IoU threshold
708    /// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
709    ///
710    /// # Examples
711    /// ```rust
712    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::Nms};
713    /// # fn main() -> DecoderResult<()> {
714    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
715    /// let decoder = DecoderBuilder::new()
716    ///     .with_config_json_str(config_json)
717    ///     .with_nms(Some(Nms::ClassAware))
718    ///     .build()?;
719    /// assert_eq!(decoder.nms, Some(Nms::ClassAware));
720    /// # Ok(())
721    /// # }
722    /// ```
723    pub fn with_nms(mut self, nms: Option<configs::Nms>) -> Self {
724        self.nms = nms;
725        self
726    }
727
728    /// Builds the decoder with the given settings. If the config is a JSON or
729    /// YAML string, this will deserialize the JSON or YAML and then parse the
730    /// configuration information.
731    ///
732    /// # Examples
733    /// ```rust
734    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
735    /// # fn main() -> DecoderResult<()> {
736    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
737    /// let decoder = DecoderBuilder::new()
738    ///     .with_config_json_str(config_json)
739    ///     .with_score_threshold(0.654)
740    ///     .build()?;
741    /// # Ok(())
742    /// # }
743    /// ```
744    pub fn build(self) -> Result<Decoder, DecoderError> {
745        let config = match self.config_src {
746            Some(ConfigSource::Json(s)) => serde_json::from_str(&s)?,
747            Some(ConfigSource::Yaml(s)) => serde_yaml::from_str(&s)?,
748            Some(ConfigSource::Config(c)) => c,
749            None => return Err(DecoderError::NoConfig),
750        };
751
752        // Extract normalized flag from config outputs
753        let normalized = Self::get_normalized(&config.outputs);
754
755        // Use NMS from config if present, otherwise use builder's NMS setting
756        let nms = config.nms.or(self.nms);
757        let model_type = Self::get_model_type(config)?;
758
759        Ok(Decoder {
760            model_type,
761            iou_threshold: self.iou_threshold,
762            score_threshold: self.score_threshold,
763            nms,
764            normalized,
765        })
766    }
767
768    /// Extracts the normalized flag from config outputs.
769    /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
770    /// - `Some(false)`: Boxes are in pixel coordinates
771    /// - `None`: Unknown (not specified in config), caller must infer
772    fn get_normalized(outputs: &[ConfigOutput]) -> Option<bool> {
773        for output in outputs {
774            match output {
775                ConfigOutput::Detection(det) => return det.normalized,
776                ConfigOutput::Boxes(boxes) => return boxes.normalized,
777                _ => {}
778            }
779        }
780        None // not specified
781    }
782
783    fn get_model_type(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
784        // yolo or modelpack
785        let mut yolo = false;
786        let mut modelpack = false;
787        for c in &configs.outputs {
788            match c.decoder() {
789                DecoderType::ModelPack => modelpack = true,
790                DecoderType::Ultralytics => yolo = true,
791            }
792        }
793        match (modelpack, yolo) {
794            (true, true) => Err(DecoderError::InvalidConfig(
795                "Both ModelPack and Yolo outputs found in config".to_string(),
796            )),
797            (true, false) => Self::get_model_type_modelpack(configs),
798            (false, true) => Self::get_model_type_yolo(configs),
799            (false, false) => Err(DecoderError::InvalidConfig(
800                "No outputs found in config".to_string(),
801            )),
802        }
803    }
804
805    fn get_model_type_yolo(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
806        let mut boxes = None;
807        let mut protos = None;
808        let mut split_boxes = None;
809        let mut split_scores = None;
810        let mut split_mask_coeff = None;
811        let mut split_classes = None;
812        for c in configs.outputs {
813            match c {
814                ConfigOutput::Detection(detection) => boxes = Some(detection),
815                ConfigOutput::Segmentation(_) => {
816                    return Err(DecoderError::InvalidConfig(
817                        "Invalid Segmentation output with Yolo decoder".to_string(),
818                    ));
819                }
820                ConfigOutput::Protos(protos_) => protos = Some(protos_),
821                ConfigOutput::Mask(_) => {
822                    return Err(DecoderError::InvalidConfig(
823                        "Invalid Mask output with Yolo decoder".to_string(),
824                    ));
825                }
826                ConfigOutput::Scores(scores) => split_scores = Some(scores),
827                ConfigOutput::Boxes(boxes) => split_boxes = Some(boxes),
828                ConfigOutput::MaskCoefficients(mask_coeff) => split_mask_coeff = Some(mask_coeff),
829                ConfigOutput::Classes(classes) => split_classes = Some(classes),
830            }
831        }
832
833        // Use end-to-end model types when:
834        // 1. decoder_version is explicitly set to Yolo26 (definitive), OR
835        //    decoder_version is not set but the dshapes are (batch, num_boxes,
836        //    num_features)
837        let is_end_to_end_dshape = boxes.as_ref().is_some_and(|b| {
838            let dims = b.dshape.iter().map(|(d, _)| *d).collect::<Vec<_>>();
839            dims == vec![DimName::Batch, DimName::NumBoxes, DimName::NumFeatures]
840        });
841
842        let is_end_to_end = configs
843            .decoder_version
844            .map(|v| v.is_end_to_end())
845            .unwrap_or(is_end_to_end_dshape);
846
847        if is_end_to_end {
848            if let Some(boxes) = boxes {
849                if let Some(protos) = protos {
850                    Self::verify_yolo_seg_det_26(&boxes, &protos)?;
851                    return Ok(ModelType::YoloEndToEndSegDet { boxes, protos });
852                } else {
853                    Self::verify_yolo_det_26(&boxes)?;
854                    return Ok(ModelType::YoloEndToEndDet { boxes });
855                }
856            } else if let (Some(split_boxes), Some(split_scores), Some(split_classes)) =
857                (split_boxes, split_scores, split_classes)
858            {
859                if let (Some(split_mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
860                    Self::verify_yolo_split_end_to_end_segdet(
861                        &split_boxes,
862                        &split_scores,
863                        &split_classes,
864                        &split_mask_coeff,
865                        &protos,
866                    )?;
867                    return Ok(ModelType::YoloSplitEndToEndSegDet {
868                        boxes: split_boxes,
869                        scores: split_scores,
870                        classes: split_classes,
871                        mask_coeff: split_mask_coeff,
872                        protos,
873                    });
874                }
875                Self::verify_yolo_split_end_to_end_det(
876                    &split_boxes,
877                    &split_scores,
878                    &split_classes,
879                )?;
880                return Ok(ModelType::YoloSplitEndToEndDet {
881                    boxes: split_boxes,
882                    scores: split_scores,
883                    classes: split_classes,
884                });
885            } else {
886                return Err(DecoderError::InvalidConfig(
887                    "Invalid Yolo end-to-end model outputs".to_string(),
888                ));
889            }
890        }
891
892        if let Some(boxes) = boxes {
893            match (split_mask_coeff, protos) {
894                (Some(mask_coeff), Some(protos)) => {
895                    // 2-way split: combined detection + separate mask_coeff + protos
896                    Self::verify_yolo_seg_det_2way(&boxes, &mask_coeff, &protos)?;
897                    Ok(ModelType::YoloSegDet2Way {
898                        boxes,
899                        mask_coeff,
900                        protos,
901                    })
902                }
903                (_, Some(protos)) => {
904                    // Unsplit: mask_coefs embedded in detection tensor
905                    Self::verify_yolo_seg_det(&boxes, &protos)?;
906                    Ok(ModelType::YoloSegDet { boxes, protos })
907                }
908                _ => {
909                    Self::verify_yolo_det(&boxes)?;
910                    Ok(ModelType::YoloDet { boxes })
911                }
912            }
913        } else if let (Some(boxes), Some(scores)) = (split_boxes, split_scores) {
914            if let (Some(mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
915                Self::verify_yolo_split_segdet(&boxes, &scores, &mask_coeff, &protos)?;
916                Ok(ModelType::YoloSplitSegDet {
917                    boxes,
918                    scores,
919                    mask_coeff,
920                    protos,
921                })
922            } else {
923                Self::verify_yolo_split_det(&boxes, &scores)?;
924                Ok(ModelType::YoloSplitDet { boxes, scores })
925            }
926        } else {
927            Err(DecoderError::InvalidConfig(
928                "Invalid Yolo model outputs".to_string(),
929            ))
930        }
931    }
932
933    fn verify_yolo_det(detect: &configs::Detection) -> Result<(), DecoderError> {
934        if detect.shape.len() != 3 {
935            return Err(DecoderError::InvalidConfig(format!(
936                "Invalid Yolo Detection shape {:?}",
937                detect.shape
938            )));
939        }
940
941        Self::verify_dshapes(
942            &detect.dshape,
943            &detect.shape,
944            "Detection",
945            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
946        )?;
947        if !detect.dshape.is_empty() {
948            Self::get_class_count(&detect.dshape, None, None)?;
949        } else {
950            Self::get_class_count_no_dshape(detect.into(), None)?;
951        }
952
953        Ok(())
954    }
955
956    fn verify_yolo_det_26(detect: &configs::Detection) -> Result<(), DecoderError> {
957        if detect.shape.len() != 3 {
958            return Err(DecoderError::InvalidConfig(format!(
959                "Invalid Yolo Detection shape {:?}",
960                detect.shape
961            )));
962        }
963
964        Self::verify_dshapes(
965            &detect.dshape,
966            &detect.shape,
967            "Detection",
968            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
969        )?;
970
971        if !detect.shape.contains(&6) {
972            return Err(DecoderError::InvalidConfig(
973                "Yolo26 Detection must have 6 features".to_string(),
974            ));
975        }
976
977        Ok(())
978    }
979
980    fn verify_yolo_seg_det(
981        detection: &configs::Detection,
982        protos: &configs::Protos,
983    ) -> Result<(), DecoderError> {
984        if detection.shape.len() != 3 {
985            return Err(DecoderError::InvalidConfig(format!(
986                "Invalid Yolo Detection shape {:?}",
987                detection.shape
988            )));
989        }
990        if protos.shape.len() != 4 {
991            return Err(DecoderError::InvalidConfig(format!(
992                "Invalid Yolo Protos shape {:?}",
993                protos.shape
994            )));
995        }
996
997        Self::verify_dshapes(
998            &detection.dshape,
999            &detection.shape,
1000            "Detection",
1001            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1002        )?;
1003        Self::verify_dshapes(
1004            &protos.dshape,
1005            &protos.shape,
1006            "Protos",
1007            &[
1008                DimName::Batch,
1009                DimName::Height,
1010                DimName::Width,
1011                DimName::NumProtos,
1012            ],
1013        )?;
1014
1015        let protos_count = Self::get_protos_count(&protos.dshape)
1016            .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1017        log::debug!("Protos count: {}", protos_count);
1018        log::debug!("Detection dshape: {:?}", detection.dshape);
1019        let classes = if !detection.dshape.is_empty() {
1020            Self::get_class_count(&detection.dshape, Some(protos_count), None)?
1021        } else {
1022            Self::get_class_count_no_dshape(detection.into(), Some(protos_count))?
1023        };
1024
1025        if classes == 0 {
1026            return Err(DecoderError::InvalidConfig(
1027                "Yolo Segmentation Detection has zero classes".to_string(),
1028            ));
1029        }
1030
1031        Ok(())
1032    }
1033
1034    fn verify_yolo_seg_det_2way(
1035        detection: &configs::Detection,
1036        mask_coeff: &configs::MaskCoefficients,
1037        protos: &configs::Protos,
1038    ) -> Result<(), DecoderError> {
1039        if detection.shape.len() != 3 {
1040            return Err(DecoderError::InvalidConfig(format!(
1041                "Invalid Yolo 2-Way Detection shape {:?}",
1042                detection.shape
1043            )));
1044        }
1045        if mask_coeff.shape.len() != 3 {
1046            return Err(DecoderError::InvalidConfig(format!(
1047                "Invalid Yolo 2-Way Mask Coefficients shape {:?}",
1048                mask_coeff.shape
1049            )));
1050        }
1051        if protos.shape.len() != 4 {
1052            return Err(DecoderError::InvalidConfig(format!(
1053                "Invalid Yolo 2-Way Protos shape {:?}",
1054                protos.shape
1055            )));
1056        }
1057
1058        Self::verify_dshapes(
1059            &detection.dshape,
1060            &detection.shape,
1061            "Detection",
1062            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1063        )?;
1064        Self::verify_dshapes(
1065            &mask_coeff.dshape,
1066            &mask_coeff.shape,
1067            "Mask Coefficients",
1068            &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1069        )?;
1070        Self::verify_dshapes(
1071            &protos.dshape,
1072            &protos.shape,
1073            "Protos",
1074            &[
1075                DimName::Batch,
1076                DimName::Height,
1077                DimName::Width,
1078                DimName::NumProtos,
1079            ],
1080        )?;
1081
1082        // Validate num_boxes match between detection and mask_coeff
1083        let det_num = Self::get_box_count(&detection.dshape).unwrap_or(detection.shape[2]);
1084        let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1085        if det_num != mask_num {
1086            return Err(DecoderError::InvalidConfig(format!(
1087                "Yolo 2-Way Detection num_boxes {} incompatible with Mask Coefficients num_boxes {}",
1088                det_num, mask_num
1089            )));
1090        }
1091
1092        // Validate mask_coeff channels match protos channels
1093        let mask_channels = if !mask_coeff.dshape.is_empty() {
1094            Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1095                DecoderError::InvalidConfig(
1096                    "Could not find num_protos in mask_coeff config".to_string(),
1097                )
1098            })?
1099        } else {
1100            mask_coeff.shape[1]
1101        };
1102        let proto_channels = if !protos.dshape.is_empty() {
1103            Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1104                DecoderError::InvalidConfig(
1105                    "Could not find num_protos in protos config".to_string(),
1106                )
1107            })?
1108        } else {
1109            protos.shape[1].min(protos.shape[3])
1110        };
1111        if mask_channels != proto_channels {
1112            return Err(DecoderError::InvalidConfig(format!(
1113                "Yolo 2-Way Protos channels {} incompatible with Mask Coefficients channels {}",
1114                proto_channels, mask_channels
1115            )));
1116        }
1117
1118        // Validate detection has classes (nc+4 features, no mask_coefs embedded)
1119        if !detection.dshape.is_empty() {
1120            Self::get_class_count(&detection.dshape, None, None)?;
1121        } else {
1122            Self::get_class_count_no_dshape(detection.into(), None)?;
1123        }
1124
1125        Ok(())
1126    }
1127
1128    fn verify_yolo_seg_det_26(
1129        detection: &configs::Detection,
1130        protos: &configs::Protos,
1131    ) -> Result<(), DecoderError> {
1132        if detection.shape.len() != 3 {
1133            return Err(DecoderError::InvalidConfig(format!(
1134                "Invalid Yolo Detection shape {:?}",
1135                detection.shape
1136            )));
1137        }
1138        if protos.shape.len() != 4 {
1139            return Err(DecoderError::InvalidConfig(format!(
1140                "Invalid Yolo Protos shape {:?}",
1141                protos.shape
1142            )));
1143        }
1144
1145        Self::verify_dshapes(
1146            &detection.dshape,
1147            &detection.shape,
1148            "Detection",
1149            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1150        )?;
1151        Self::verify_dshapes(
1152            &protos.dshape,
1153            &protos.shape,
1154            "Protos",
1155            &[
1156                DimName::Batch,
1157                DimName::Height,
1158                DimName::Width,
1159                DimName::NumProtos,
1160            ],
1161        )?;
1162
1163        let protos_count = Self::get_protos_count(&protos.dshape)
1164            .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1165        log::debug!("Protos count: {}", protos_count);
1166        log::debug!("Detection dshape: {:?}", detection.dshape);
1167
1168        if !detection.shape.contains(&(6 + protos_count)) {
1169            return Err(DecoderError::InvalidConfig(format!(
1170                "Yolo26 Segmentation Detection must have num_features be 6 + num_protos = {}",
1171                6 + protos_count
1172            )));
1173        }
1174
1175        Ok(())
1176    }
1177
1178    fn verify_yolo_split_det(
1179        boxes: &configs::Boxes,
1180        scores: &configs::Scores,
1181    ) -> Result<(), DecoderError> {
1182        if boxes.shape.len() != 3 {
1183            return Err(DecoderError::InvalidConfig(format!(
1184                "Invalid Yolo Split Boxes shape {:?}",
1185                boxes.shape
1186            )));
1187        }
1188        if scores.shape.len() != 3 {
1189            return Err(DecoderError::InvalidConfig(format!(
1190                "Invalid Yolo Split Scores shape {:?}",
1191                scores.shape
1192            )));
1193        }
1194
1195        Self::verify_dshapes(
1196            &boxes.dshape,
1197            &boxes.shape,
1198            "Boxes",
1199            &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1200        )?;
1201        Self::verify_dshapes(
1202            &scores.dshape,
1203            &scores.shape,
1204            "Scores",
1205            &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1206        )?;
1207
1208        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1209        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1210
1211        if boxes_num != scores_num {
1212            return Err(DecoderError::InvalidConfig(format!(
1213                "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1214                boxes_num, scores_num
1215            )));
1216        }
1217
1218        Ok(())
1219    }
1220
1221    fn verify_yolo_split_segdet(
1222        boxes: &configs::Boxes,
1223        scores: &configs::Scores,
1224        mask_coeff: &configs::MaskCoefficients,
1225        protos: &configs::Protos,
1226    ) -> Result<(), DecoderError> {
1227        if boxes.shape.len() != 3 {
1228            return Err(DecoderError::InvalidConfig(format!(
1229                "Invalid Yolo Split Boxes shape {:?}",
1230                boxes.shape
1231            )));
1232        }
1233        if scores.shape.len() != 3 {
1234            return Err(DecoderError::InvalidConfig(format!(
1235                "Invalid Yolo Split Scores shape {:?}",
1236                scores.shape
1237            )));
1238        }
1239
1240        if mask_coeff.shape.len() != 3 {
1241            return Err(DecoderError::InvalidConfig(format!(
1242                "Invalid Yolo Split Mask Coefficients shape {:?}",
1243                mask_coeff.shape
1244            )));
1245        }
1246
1247        if protos.shape.len() != 4 {
1248            return Err(DecoderError::InvalidConfig(format!(
1249                "Invalid Yolo Protos shape {:?}",
1250                mask_coeff.shape
1251            )));
1252        }
1253
1254        Self::verify_dshapes(
1255            &boxes.dshape,
1256            &boxes.shape,
1257            "Boxes",
1258            &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1259        )?;
1260        Self::verify_dshapes(
1261            &scores.dshape,
1262            &scores.shape,
1263            "Scores",
1264            &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1265        )?;
1266        Self::verify_dshapes(
1267            &mask_coeff.dshape,
1268            &mask_coeff.shape,
1269            "Mask Coefficients",
1270            &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1271        )?;
1272        Self::verify_dshapes(
1273            &protos.dshape,
1274            &protos.shape,
1275            "Protos",
1276            &[
1277                DimName::Batch,
1278                DimName::Height,
1279                DimName::Width,
1280                DimName::NumProtos,
1281            ],
1282        )?;
1283
1284        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1285        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1286        let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1287
1288        let mask_channels = if !mask_coeff.dshape.is_empty() {
1289            Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1290                DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1291            })?
1292        } else {
1293            mask_coeff.shape[1]
1294        };
1295        let proto_channels = if !protos.dshape.is_empty() {
1296            Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1297                DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1298            })?
1299        } else {
1300            protos.shape[1].min(protos.shape[3])
1301        };
1302
1303        if boxes_num != scores_num {
1304            return Err(DecoderError::InvalidConfig(format!(
1305                "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1306                boxes_num, scores_num
1307            )));
1308        }
1309
1310        if boxes_num != mask_num {
1311            return Err(DecoderError::InvalidConfig(format!(
1312                "Yolo Split Detection Boxes num {} incompatible with Mask Coefficients num {}",
1313                boxes_num, mask_num
1314            )));
1315        }
1316
1317        if proto_channels != mask_channels {
1318            return Err(DecoderError::InvalidConfig(format!(
1319                "Yolo Protos channels {} incompatible with Mask Coefficients channels {}",
1320                proto_channels, mask_channels
1321            )));
1322        }
1323
1324        Ok(())
1325    }
1326
1327    fn verify_yolo_split_end_to_end_det(
1328        boxes: &configs::Boxes,
1329        scores: &configs::Scores,
1330        classes: &configs::Classes,
1331    ) -> Result<(), DecoderError> {
1332        if boxes.shape.len() != 3 || !boxes.shape.contains(&4) {
1333            return Err(DecoderError::InvalidConfig(format!(
1334                "Split end-to-end boxes must be [batch, N, 4], got {:?}",
1335                boxes.shape
1336            )));
1337        }
1338        if scores.shape.len() != 3 || !scores.shape.contains(&1) {
1339            return Err(DecoderError::InvalidConfig(format!(
1340                "Split end-to-end scores must be [batch, N, 1], got {:?}",
1341                scores.shape
1342            )));
1343        }
1344        if classes.shape.len() != 3 || !classes.shape.contains(&1) {
1345            return Err(DecoderError::InvalidConfig(format!(
1346                "Split end-to-end classes must be [batch, N, 1], got {:?}",
1347                classes.shape
1348            )));
1349        }
1350        Ok(())
1351    }
1352
1353    fn verify_yolo_split_end_to_end_segdet(
1354        boxes: &configs::Boxes,
1355        scores: &configs::Scores,
1356        classes: &configs::Classes,
1357        mask_coeff: &configs::MaskCoefficients,
1358        protos: &configs::Protos,
1359    ) -> Result<(), DecoderError> {
1360        Self::verify_yolo_split_end_to_end_det(boxes, scores, classes)?;
1361        if mask_coeff.shape.len() != 3 {
1362            return Err(DecoderError::InvalidConfig(format!(
1363                "Invalid split end-to-end mask coefficients shape {:?}",
1364                mask_coeff.shape
1365            )));
1366        }
1367        if protos.shape.len() != 4 {
1368            return Err(DecoderError::InvalidConfig(format!(
1369                "Invalid protos shape {:?}",
1370                protos.shape
1371            )));
1372        }
1373        Ok(())
1374    }
1375
1376    fn get_model_type_modelpack(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1377        let mut split_decoders = Vec::new();
1378        let mut segment_ = None;
1379        let mut scores_ = None;
1380        let mut boxes_ = None;
1381        for c in configs.outputs {
1382            match c {
1383                ConfigOutput::Detection(detection) => split_decoders.push(detection),
1384                ConfigOutput::Segmentation(segmentation) => segment_ = Some(segmentation),
1385                ConfigOutput::Mask(_) => {}
1386                ConfigOutput::Protos(_) => {
1387                    return Err(DecoderError::InvalidConfig(
1388                        "ModelPack should not have protos".to_string(),
1389                    ));
1390                }
1391                ConfigOutput::Scores(scores) => scores_ = Some(scores),
1392                ConfigOutput::Boxes(boxes) => boxes_ = Some(boxes),
1393                ConfigOutput::MaskCoefficients(_) => {
1394                    return Err(DecoderError::InvalidConfig(
1395                        "ModelPack should not have mask coefficients".to_string(),
1396                    ));
1397                }
1398                ConfigOutput::Classes(_) => {
1399                    return Err(DecoderError::InvalidConfig(
1400                        "ModelPack should not have classes output".to_string(),
1401                    ));
1402                }
1403            }
1404        }
1405
1406        if let Some(segmentation) = segment_ {
1407            if !split_decoders.is_empty() {
1408                let classes = Self::verify_modelpack_split_det(&split_decoders)?;
1409                Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1410                Ok(ModelType::ModelPackSegDetSplit {
1411                    detection: split_decoders,
1412                    segmentation,
1413                })
1414            } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1415                let classes = Self::verify_modelpack_det(&boxes, &scores)?;
1416                Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1417                Ok(ModelType::ModelPackSegDet {
1418                    boxes,
1419                    scores,
1420                    segmentation,
1421                })
1422            } else {
1423                Self::verify_modelpack_seg(&segmentation, None)?;
1424                Ok(ModelType::ModelPackSeg { segmentation })
1425            }
1426        } else if !split_decoders.is_empty() {
1427            Self::verify_modelpack_split_det(&split_decoders)?;
1428            Ok(ModelType::ModelPackDetSplit {
1429                detection: split_decoders,
1430            })
1431        } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1432            Self::verify_modelpack_det(&boxes, &scores)?;
1433            Ok(ModelType::ModelPackDet { boxes, scores })
1434        } else {
1435            Err(DecoderError::InvalidConfig(
1436                "Invalid ModelPack model outputs".to_string(),
1437            ))
1438        }
1439    }
1440
1441    fn verify_modelpack_det(
1442        boxes: &configs::Boxes,
1443        scores: &configs::Scores,
1444    ) -> Result<usize, DecoderError> {
1445        if boxes.shape.len() != 4 {
1446            return Err(DecoderError::InvalidConfig(format!(
1447                "Invalid ModelPack Boxes shape {:?}",
1448                boxes.shape
1449            )));
1450        }
1451        if scores.shape.len() != 3 {
1452            return Err(DecoderError::InvalidConfig(format!(
1453                "Invalid ModelPack Scores shape {:?}",
1454                scores.shape
1455            )));
1456        }
1457
1458        Self::verify_dshapes(
1459            &boxes.dshape,
1460            &boxes.shape,
1461            "Boxes",
1462            &[
1463                DimName::Batch,
1464                DimName::NumBoxes,
1465                DimName::Padding,
1466                DimName::BoxCoords,
1467            ],
1468        )?;
1469        Self::verify_dshapes(
1470            &scores.dshape,
1471            &scores.shape,
1472            "Scores",
1473            &[DimName::Batch, DimName::NumBoxes, DimName::NumClasses],
1474        )?;
1475
1476        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[1]);
1477        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[1]);
1478
1479        if boxes_num != scores_num {
1480            return Err(DecoderError::InvalidConfig(format!(
1481                "ModelPack Detection Boxes num {} incompatible with Scores num {}",
1482                boxes_num, scores_num
1483            )));
1484        }
1485
1486        let num_classes = if !scores.dshape.is_empty() {
1487            Self::get_class_count(&scores.dshape, None, None)?
1488        } else {
1489            Self::get_class_count_no_dshape(scores.into(), None)?
1490        };
1491
1492        Ok(num_classes)
1493    }
1494
1495    fn verify_modelpack_split_det(boxes: &[configs::Detection]) -> Result<usize, DecoderError> {
1496        let mut num_classes = None;
1497        for b in boxes {
1498            let Some(num_anchors) = b.anchors.as_ref().map(|a| a.len()) else {
1499                return Err(DecoderError::InvalidConfig(
1500                    "ModelPack Split Detection missing anchors".to_string(),
1501                ));
1502            };
1503
1504            if num_anchors == 0 {
1505                return Err(DecoderError::InvalidConfig(
1506                    "ModelPack Split Detection has zero anchors".to_string(),
1507                ));
1508            }
1509
1510            if b.shape.len() != 4 {
1511                return Err(DecoderError::InvalidConfig(format!(
1512                    "Invalid ModelPack Split Detection shape {:?}",
1513                    b.shape
1514                )));
1515            }
1516
1517            Self::verify_dshapes(
1518                &b.dshape,
1519                &b.shape,
1520                "Split Detection",
1521                &[
1522                    DimName::Batch,
1523                    DimName::Height,
1524                    DimName::Width,
1525                    DimName::NumAnchorsXFeatures,
1526                ],
1527            )?;
1528            let classes = if !b.dshape.is_empty() {
1529                Self::get_class_count(&b.dshape, None, Some(num_anchors))?
1530            } else {
1531                Self::get_class_count_no_dshape(b.into(), None)?
1532            };
1533
1534            match num_classes {
1535                Some(n) => {
1536                    if n != classes {
1537                        return Err(DecoderError::InvalidConfig(format!(
1538                            "ModelPack Split Detection inconsistent number of classes: previous {}, current {}",
1539                            n, classes
1540                        )));
1541                    }
1542                }
1543                None => {
1544                    num_classes = Some(classes);
1545                }
1546            }
1547        }
1548
1549        Ok(num_classes.unwrap_or(0))
1550    }
1551
1552    fn verify_modelpack_seg(
1553        segmentation: &configs::Segmentation,
1554        classes: Option<usize>,
1555    ) -> Result<(), DecoderError> {
1556        if segmentation.shape.len() != 4 {
1557            return Err(DecoderError::InvalidConfig(format!(
1558                "Invalid ModelPack Segmentation shape {:?}",
1559                segmentation.shape
1560            )));
1561        }
1562        Self::verify_dshapes(
1563            &segmentation.dshape,
1564            &segmentation.shape,
1565            "Segmentation",
1566            &[
1567                DimName::Batch,
1568                DimName::Height,
1569                DimName::Width,
1570                DimName::NumClasses,
1571            ],
1572        )?;
1573
1574        if let Some(classes) = classes {
1575            let seg_classes = if !segmentation.dshape.is_empty() {
1576                Self::get_class_count(&segmentation.dshape, None, None)?
1577            } else {
1578                Self::get_class_count_no_dshape(segmentation.into(), None)?
1579            };
1580
1581            if seg_classes != classes + 1 {
1582                return Err(DecoderError::InvalidConfig(format!(
1583                    "ModelPack Segmentation channels {} incompatible with number of classes {}",
1584                    seg_classes, classes
1585                )));
1586            }
1587        }
1588        Ok(())
1589    }
1590
1591    // verifies that dshapes match the given shape
1592    fn verify_dshapes(
1593        dshape: &[(DimName, usize)],
1594        shape: &[usize],
1595        name: &str,
1596        dims: &[DimName],
1597    ) -> Result<(), DecoderError> {
1598        for s in shape {
1599            if *s == 0 {
1600                return Err(DecoderError::InvalidConfig(format!(
1601                    "{} shape has zero dimension",
1602                    name
1603                )));
1604            }
1605        }
1606
1607        if shape.len() != dims.len() {
1608            return Err(DecoderError::InvalidConfig(format!(
1609                "{} shape length {} does not match expected dims length {}",
1610                name,
1611                shape.len(),
1612                dims.len()
1613            )));
1614        }
1615
1616        if dshape.is_empty() {
1617            return Ok(());
1618        }
1619        // check the dshape lengths match the shape lengths
1620        if dshape.len() != shape.len() {
1621            return Err(DecoderError::InvalidConfig(format!(
1622                "{} dshape length does not match shape length",
1623                name
1624            )));
1625        }
1626
1627        // check the dshape values match the shape values
1628        for ((dim_name, dim_size), shape_size) in dshape.iter().zip(shape) {
1629            if dim_size != shape_size {
1630                return Err(DecoderError::InvalidConfig(format!(
1631                    "{} dshape dimension {} size {} does not match shape size {}",
1632                    name, dim_name, dim_size, shape_size
1633                )));
1634            }
1635            if *dim_name == DimName::Padding && *dim_size != 1 {
1636                return Err(DecoderError::InvalidConfig(
1637                    "Padding dimension size must be 1".to_string(),
1638                ));
1639            }
1640
1641            if *dim_name == DimName::BoxCoords && *dim_size != 4 {
1642                return Err(DecoderError::InvalidConfig(
1643                    "BoxCoords dimension size must be 4".to_string(),
1644                ));
1645            }
1646        }
1647
1648        let dims_present = HashSet::<DimName>::from_iter(dshape.iter().map(|(name, _)| *name));
1649        for dim in dims {
1650            if !dims_present.contains(dim) {
1651                return Err(DecoderError::InvalidConfig(format!(
1652                    "{} dshape missing required dimension {:?}",
1653                    name, dim
1654                )));
1655            }
1656        }
1657
1658        Ok(())
1659    }
1660
1661    fn get_box_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1662        for (dim_name, dim_size) in dshape {
1663            if *dim_name == DimName::NumBoxes {
1664                return Some(*dim_size);
1665            }
1666        }
1667        None
1668    }
1669
1670    fn get_class_count_no_dshape(
1671        config: ConfigOutputRef,
1672        protos: Option<usize>,
1673    ) -> Result<usize, DecoderError> {
1674        match config {
1675            ConfigOutputRef::Detection(detection) => match detection.decoder {
1676                DecoderType::Ultralytics => {
1677                    if detection.shape[1] <= 4 + protos.unwrap_or(0) {
1678                        return Err(DecoderError::InvalidConfig(format!(
1679                            "Invalid shape: Yolo num_features {} must be greater than {}",
1680                            detection.shape[1],
1681                            4 + protos.unwrap_or(0),
1682                        )));
1683                    }
1684                    Ok(detection.shape[1] - 4 - protos.unwrap_or(0))
1685                }
1686                DecoderType::ModelPack => {
1687                    let Some(num_anchors) = detection.anchors.as_ref().map(|a| a.len()) else {
1688                        return Err(DecoderError::Internal(
1689                            "ModelPack Detection missing anchors".to_string(),
1690                        ));
1691                    };
1692                    let anchors_x_features = detection.shape[3];
1693                    if anchors_x_features <= num_anchors * 5 {
1694                        return Err(DecoderError::InvalidConfig(format!(
1695                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
1696                            anchors_x_features,
1697                            num_anchors * 5,
1698                        )));
1699                    }
1700
1701                    if !anchors_x_features.is_multiple_of(num_anchors) {
1702                        return Err(DecoderError::InvalidConfig(format!(
1703                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
1704                            anchors_x_features, num_anchors
1705                        )));
1706                    }
1707                    Ok(anchors_x_features / num_anchors - 5)
1708                }
1709            },
1710
1711            ConfigOutputRef::Scores(scores) => match scores.decoder {
1712                DecoderType::Ultralytics => Ok(scores.shape[1]),
1713                DecoderType::ModelPack => Ok(scores.shape[2]),
1714            },
1715            ConfigOutputRef::Segmentation(seg) => Ok(seg.shape[3]),
1716            _ => Err(DecoderError::Internal(
1717                "Attempted to get class count from unsupported config output".to_owned(),
1718            )),
1719        }
1720    }
1721
1722    // get the class count from dshape or calculate from num_features
1723    fn get_class_count(
1724        dshape: &[(DimName, usize)],
1725        protos: Option<usize>,
1726        anchors: Option<usize>,
1727    ) -> Result<usize, DecoderError> {
1728        if dshape.is_empty() {
1729            return Ok(0);
1730        }
1731        // if it has num_classes in dshape, return it
1732        for (dim_name, dim_size) in dshape {
1733            if *dim_name == DimName::NumClasses {
1734                return Ok(*dim_size);
1735            }
1736        }
1737
1738        // number of classes can be calculated from num_features - 4 for yolo.  If the
1739        // model has protos, we also subtract the number of protos.
1740        for (dim_name, dim_size) in dshape {
1741            if *dim_name == DimName::NumFeatures {
1742                let protos = protos.unwrap_or(0);
1743                if protos + 4 >= *dim_size {
1744                    return Err(DecoderError::InvalidConfig(format!(
1745                        "Invalid shape: Yolo num_features {} must be greater than {}",
1746                        *dim_size,
1747                        protos + 4,
1748                    )));
1749                }
1750                return Ok(*dim_size - 4 - protos);
1751            }
1752        }
1753
1754        // number of classes can be calculated from number of anchors for modelpack
1755        // split detection
1756        if let Some(num_anchors) = anchors {
1757            for (dim_name, dim_size) in dshape {
1758                if *dim_name == DimName::NumAnchorsXFeatures {
1759                    let anchors_x_features = *dim_size;
1760                    if anchors_x_features <= num_anchors * 5 {
1761                        return Err(DecoderError::InvalidConfig(format!(
1762                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
1763                            anchors_x_features,
1764                            num_anchors * 5,
1765                        )));
1766                    }
1767
1768                    if !anchors_x_features.is_multiple_of(num_anchors) {
1769                        return Err(DecoderError::InvalidConfig(format!(
1770                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
1771                            anchors_x_features, num_anchors
1772                        )));
1773                    }
1774                    return Ok((anchors_x_features / num_anchors) - 5);
1775                }
1776            }
1777        }
1778        Err(DecoderError::InvalidConfig(
1779            "Cannot determine number of classes from dshape".to_owned(),
1780        ))
1781    }
1782
1783    fn get_protos_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1784        for (dim_name, dim_size) in dshape {
1785            if *dim_name == DimName::NumProtos {
1786                return Some(*dim_size);
1787            }
1788        }
1789        None
1790    }
1791}