Skip to main content

edgefirst_decoder/decoder/
builder.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashSet;
5
6use super::config::ConfigOutputRef;
7use super::configs::{self, DecoderType, DecoderVersion, DimName, ModelType};
8use super::{ConfigOutput, ConfigOutputs, Decoder};
9use crate::DecoderError;
10
/// Builder that gathers a model configuration together with the
/// post-processing settings and turns them into a [`Decoder`] when
/// `build()` is called.
#[derive(Debug, Clone, PartialEq)]
pub struct DecoderBuilder {
    /// Where the model configuration comes from. Stays `None` until one of
    /// the configuration methods is called; `build()` fails in that case.
    config_src: Option<ConfigSource>,
    /// IoU threshold handed to the built decoder (0.5 by default).
    iou_threshold: f32,
    /// Score threshold handed to the built decoder (0.5 by default).
    score_threshold: f32,
    /// NMS mode: Some(mode) applies NMS, None bypasses NMS (for end-to-end
    /// models). Only consulted by `build()` when the configuration itself
    /// does not carry an NMS setting.
    nms: Option<configs::Nms>,
}
20
/// The form in which the model configuration was handed to the builder.
///
/// String variants are stored untouched and only deserialized by
/// `DecoderBuilder::build()`.
#[derive(Debug, Clone, PartialEq)]
enum ConfigSource {
    /// Raw YAML text, deserialized with `serde_yaml` during `build()`.
    Yaml(String),
    /// Raw JSON text, deserialized with `serde_json` during `build()`.
    Json(String),
    /// An already constructed configuration, used as-is during `build()`.
    Config(ConfigOutputs),
}
27
28impl Default for DecoderBuilder {
29    /// Creates a default DecoderBuilder with no configuration and 0.5 score
30    /// threshold and 0.5 IoU threshold.
31    ///
32    /// A valid configuration must be provided before building the Decoder.
33    ///
34    /// # Examples
35    /// ```rust
36    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
37    /// # fn main() -> DecoderResult<()> {
38    /// #  let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
39    /// let decoder = DecoderBuilder::default()
40    ///     .with_config_yaml_str(config_yaml)
41    ///     .build()?;
42    /// assert_eq!(decoder.score_threshold, 0.5);
43    /// assert_eq!(decoder.iou_threshold, 0.5);
44    ///
45    /// # Ok(())
46    /// # }
47    /// ```
48    fn default() -> Self {
49        Self {
50            config_src: None,
51            iou_threshold: 0.5,
52            score_threshold: 0.5,
53            nms: Some(configs::Nms::ClassAgnostic),
54        }
55    }
56}
57
58impl DecoderBuilder {
59    /// Creates a default DecoderBuilder with no configuration and 0.5 score
60    /// threshold and 0.5 IoU threshold.
61    ///
62    /// A valid configuration must be provided before building the Decoder.
63    ///
64    /// # Examples
65    /// ```rust
66    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
67    /// # fn main() -> DecoderResult<()> {
68    /// #  let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
69    /// let decoder = DecoderBuilder::new()
70    ///     .with_config_yaml_str(config_yaml)
71    ///     .build()?;
72    /// assert_eq!(decoder.score_threshold, 0.5);
73    /// assert_eq!(decoder.iou_threshold, 0.5);
74    ///
75    /// # Ok(())
76    /// # }
77    /// ```
78    pub fn new() -> Self {
79        Self::default()
80    }
81
82    /// Loads a model configuration in YAML format. Does not check if the string
83    /// is a correct configuration file. Use `DecoderBuilder.build()` to
84    /// deserialize the YAML and parse the model configuration.
85    ///
86    /// # Examples
87    /// ```rust
88    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
89    /// # fn main() -> DecoderResult<()> {
90    /// let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
91    /// let decoder = DecoderBuilder::new()
92    ///     .with_config_yaml_str(config_yaml)
93    ///     .build()?;
94    ///
95    /// # Ok(())
96    /// # }
97    /// ```
98    pub fn with_config_yaml_str(mut self, yaml_str: String) -> Self {
99        self.config_src.replace(ConfigSource::Yaml(yaml_str));
100        self
101    }
102
103    /// Loads a model configuration in JSON format. Does not check if the string
104    /// is a correct configuration file. Use `DecoderBuilder.build()` to
105    /// deserialize the JSON and parse the model configuration.
106    ///
107    /// # Examples
108    /// ```rust
109    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
110    /// # fn main() -> DecoderResult<()> {
111    /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
112    /// let decoder = DecoderBuilder::new()
113    ///     .with_config_json_str(config_json)
114    ///     .build()?;
115    ///
116    /// # Ok(())
117    /// # }
118    /// ```
119    pub fn with_config_json_str(mut self, json_str: String) -> Self {
120        self.config_src.replace(ConfigSource::Json(json_str));
121        self
122    }
123
124    /// Loads a model configuration. Does not check if the configuration is
    /// correct. Intended to be used when the user needs control over the
    /// deserialization of the configuration information. Use
127    /// `DecoderBuilder.build()` to parse the model configuration.
128    ///
129    /// # Examples
130    /// ```rust
131    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
132    /// # fn main() -> DecoderResult<()> {
133    /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json"));
134    /// let config = serde_json::from_str(config_json)?;
135    /// let decoder = DecoderBuilder::new().with_config(config).build()?;
136    ///
137    /// # Ok(())
138    /// # }
139    /// ```
140    pub fn with_config(mut self, config: ConfigOutputs) -> Self {
141        self.config_src.replace(ConfigSource::Config(config));
142        self
143    }
144
145    /// Loads a YOLO detection model configuration.  Use
146    /// `DecoderBuilder.build()` to parse the model configuration.
147    ///
148    /// # Examples
149    /// ```rust
150    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
151    /// # fn main() -> DecoderResult<()> {
152    /// let decoder = DecoderBuilder::new()
153    ///     .with_config_yolo_det(
154    ///         configs::Detection {
155    ///             anchors: None,
156    ///             decoder: configs::DecoderType::Ultralytics,
157    ///             quantization: Some(configs::QuantTuple(0.012345, 26)),
158    ///             shape: vec![1, 84, 8400],
159    ///             dshape: Vec::new(),
160    ///             normalized: Some(true),
161    ///         },
162    ///         None,
163    ///     )
164    ///     .build()?;
165    ///
166    /// # Ok(())
167    /// # }
168    /// ```
169    pub fn with_config_yolo_det(
170        mut self,
171        boxes: configs::Detection,
172        version: Option<DecoderVersion>,
173    ) -> Self {
174        let config = ConfigOutputs {
175            outputs: vec![ConfigOutput::Detection(boxes)],
176            decoder_version: version,
177            ..Default::default()
178        };
179        self.config_src.replace(ConfigSource::Config(config));
180        self
181    }
182
183    /// Loads a YOLO split detection model configuration.  Use
184    /// `DecoderBuilder.build()` to parse the model configuration.
185    ///
186    /// # Examples
187    /// ```rust
188    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
189    /// # fn main() -> DecoderResult<()> {
190    /// let boxes_config = configs::Boxes {
191    ///     decoder: configs::DecoderType::Ultralytics,
192    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
193    ///     shape: vec![1, 4, 8400],
194    ///     dshape: Vec::new(),
195    ///     normalized: Some(true),
196    /// };
197    /// let scores_config = configs::Scores {
198    ///     decoder: configs::DecoderType::Ultralytics,
199    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
200    ///     shape: vec![1, 80, 8400],
201    ///     dshape: Vec::new(),
202    /// };
203    /// let decoder = DecoderBuilder::new()
204    ///     .with_config_yolo_split_det(boxes_config, scores_config)
205    ///     .build()?;
206    /// # Ok(())
207    /// # }
208    /// ```
209    pub fn with_config_yolo_split_det(
210        mut self,
211        boxes: configs::Boxes,
212        scores: configs::Scores,
213    ) -> Self {
214        let config = ConfigOutputs {
215            outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
216            ..Default::default()
217        };
218        self.config_src.replace(ConfigSource::Config(config));
219        self
220    }
221
222    /// Loads a YOLO segmentation model configuration.  Use
223    /// `DecoderBuilder.build()` to parse the model configuration.
224    ///
225    /// # Examples
226    /// ```rust
227    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
228    /// # fn main() -> DecoderResult<()> {
229    /// let seg_config = configs::Detection {
230    ///     decoder: configs::DecoderType::Ultralytics,
231    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
232    ///     shape: vec![1, 116, 8400],
233    ///     anchors: None,
234    ///     dshape: Vec::new(),
235    ///     normalized: Some(true),
236    /// };
237    /// let protos_config = configs::Protos {
238    ///     decoder: configs::DecoderType::Ultralytics,
239    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
240    ///     shape: vec![1, 160, 160, 32],
241    ///     dshape: Vec::new(),
242    /// };
243    /// let decoder = DecoderBuilder::new()
244    ///     .with_config_yolo_segdet(
245    ///         seg_config,
246    ///         protos_config,
247    ///         Some(configs::DecoderVersion::Yolov8),
248    ///     )
249    ///     .build()?;
250    /// # Ok(())
251    /// # }
252    /// ```
253    pub fn with_config_yolo_segdet(
254        mut self,
255        boxes: configs::Detection,
256        protos: configs::Protos,
257        version: Option<DecoderVersion>,
258    ) -> Self {
259        let config = ConfigOutputs {
260            outputs: vec![ConfigOutput::Detection(boxes), ConfigOutput::Protos(protos)],
261            decoder_version: version,
262            ..Default::default()
263        };
264        self.config_src.replace(ConfigSource::Config(config));
265        self
266    }
267
268    /// Loads a YOLO split segmentation model configuration.  Use
269    /// `DecoderBuilder.build()` to parse the model configuration.
270    ///
271    /// # Examples
272    /// ```rust
273    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
274    /// # fn main() -> DecoderResult<()> {
275    /// let boxes_config = configs::Boxes {
276    ///     decoder: configs::DecoderType::Ultralytics,
277    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
278    ///     shape: vec![1, 4, 8400],
279    ///     dshape: Vec::new(),
280    ///     normalized: Some(true),
281    /// };
282    /// let scores_config = configs::Scores {
283    ///     decoder: configs::DecoderType::Ultralytics,
284    ///     quantization: Some(configs::QuantTuple(0.012345, 14)),
285    ///     shape: vec![1, 80, 8400],
286    ///     dshape: Vec::new(),
287    /// };
288    /// let mask_config = configs::MaskCoefficients {
289    ///     decoder: configs::DecoderType::Ultralytics,
290    ///     quantization: Some(configs::QuantTuple(0.0064123, 125)),
291    ///     shape: vec![1, 32, 8400],
292    ///     dshape: Vec::new(),
293    /// };
294    /// let protos_config = configs::Protos {
295    ///     decoder: configs::DecoderType::Ultralytics,
296    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
297    ///     shape: vec![1, 160, 160, 32],
298    ///     dshape: Vec::new(),
299    /// };
300    /// let decoder = DecoderBuilder::new()
301    ///     .with_config_yolo_split_segdet(boxes_config, scores_config, mask_config, protos_config)
302    ///     .build()?;
303    /// # Ok(())
304    /// # }
305    /// ```
306    pub fn with_config_yolo_split_segdet(
307        mut self,
308        boxes: configs::Boxes,
309        scores: configs::Scores,
310        mask_coefficients: configs::MaskCoefficients,
311        protos: configs::Protos,
312    ) -> Self {
313        let config = ConfigOutputs {
314            outputs: vec![
315                ConfigOutput::Boxes(boxes),
316                ConfigOutput::Scores(scores),
317                ConfigOutput::MaskCoefficients(mask_coefficients),
318                ConfigOutput::Protos(protos),
319            ],
320            ..Default::default()
321        };
322        self.config_src.replace(ConfigSource::Config(config));
323        self
324    }
325
326    /// Loads a ModelPack detection model configuration.  Use
327    /// `DecoderBuilder.build()` to parse the model configuration.
328    ///
329    /// # Examples
330    /// ```rust
331    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
332    /// # fn main() -> DecoderResult<()> {
333    /// let boxes_config = configs::Boxes {
334    ///     decoder: configs::DecoderType::ModelPack,
335    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
336    ///     shape: vec![1, 8400, 1, 4],
337    ///     dshape: Vec::new(),
338    ///     normalized: Some(true),
339    /// };
340    /// let scores_config = configs::Scores {
341    ///     decoder: configs::DecoderType::ModelPack,
342    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
343    ///     shape: vec![1, 8400, 3],
344    ///     dshape: Vec::new(),
345    /// };
346    /// let decoder = DecoderBuilder::new()
347    ///     .with_config_modelpack_det(boxes_config, scores_config)
348    ///     .build()?;
349    /// # Ok(())
350    /// # }
351    /// ```
352    pub fn with_config_modelpack_det(
353        mut self,
354        boxes: configs::Boxes,
355        scores: configs::Scores,
356    ) -> Self {
357        let config = ConfigOutputs {
358            outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
359            ..Default::default()
360        };
361        self.config_src.replace(ConfigSource::Config(config));
362        self
363    }
364
365    /// Loads a ModelPack split detection model configuration. Use
366    /// `DecoderBuilder.build()` to parse the model configuration.
367    ///
368    /// # Examples
369    /// ```rust
370    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
371    /// # fn main() -> DecoderResult<()> {
372    /// let config0 = configs::Detection {
373    ///     anchors: Some(vec![
374    ///         [0.13750000298023224, 0.2074074000120163],
375    ///         [0.2541666626930237, 0.21481481194496155],
376    ///         [0.23125000298023224, 0.35185185074806213],
377    ///     ]),
378    ///     decoder: configs::DecoderType::ModelPack,
379    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
380    ///     shape: vec![1, 17, 30, 18],
381    ///     dshape: Vec::new(),
382    ///     normalized: Some(true),
383    /// };
384    /// let config1 = configs::Detection {
385    ///     anchors: Some(vec![
386    ///         [0.36666667461395264, 0.31481480598449707],
387    ///         [0.38749998807907104, 0.4740740656852722],
388    ///         [0.5333333611488342, 0.644444465637207],
389    ///     ]),
390    ///     decoder: configs::DecoderType::ModelPack,
391    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
392    ///     shape: vec![1, 9, 15, 18],
393    ///     dshape: Vec::new(),
394    ///     normalized: Some(true),
395    /// };
396    ///
397    /// let decoder = DecoderBuilder::new()
398    ///     .with_config_modelpack_det_split(vec![config0, config1])
399    ///     .build()?;
400    /// # Ok(())
401    /// # }
402    /// ```
403    pub fn with_config_modelpack_det_split(mut self, boxes: Vec<configs::Detection>) -> Self {
404        let outputs = boxes.into_iter().map(ConfigOutput::Detection).collect();
405        let config = ConfigOutputs {
406            outputs,
407            ..Default::default()
408        };
409        self.config_src.replace(ConfigSource::Config(config));
410        self
411    }
412
413    /// Loads a ModelPack segmentation detection model configuration. Use
414    /// `DecoderBuilder.build()` to parse the model configuration.
415    ///
416    /// # Examples
417    /// ```rust
418    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
419    /// # fn main() -> DecoderResult<()> {
420    /// let boxes_config = configs::Boxes {
421    ///     decoder: configs::DecoderType::ModelPack,
422    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
423    ///     shape: vec![1, 8400, 1, 4],
424    ///     dshape: Vec::new(),
425    ///     normalized: Some(true),
426    /// };
427    /// let scores_config = configs::Scores {
428    ///     decoder: configs::DecoderType::ModelPack,
429    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
430    ///     shape: vec![1, 8400, 2],
431    ///     dshape: Vec::new(),
432    /// };
433    /// let seg_config = configs::Segmentation {
434    ///     decoder: configs::DecoderType::ModelPack,
435    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
436    ///     shape: vec![1, 640, 640, 3],
437    ///     dshape: Vec::new(),
438    /// };
439    /// let decoder = DecoderBuilder::new()
440    ///     .with_config_modelpack_segdet(boxes_config, scores_config, seg_config)
441    ///     .build()?;
442    /// # Ok(())
443    /// # }
444    /// ```
445    pub fn with_config_modelpack_segdet(
446        mut self,
447        boxes: configs::Boxes,
448        scores: configs::Scores,
449        segmentation: configs::Segmentation,
450    ) -> Self {
451        let config = ConfigOutputs {
452            outputs: vec![
453                ConfigOutput::Boxes(boxes),
454                ConfigOutput::Scores(scores),
455                ConfigOutput::Segmentation(segmentation),
456            ],
457            ..Default::default()
458        };
459        self.config_src.replace(ConfigSource::Config(config));
460        self
461    }
462
463    /// Loads a ModelPack segmentation split detection model configuration. Use
464    /// `DecoderBuilder.build()` to parse the model configuration.
465    ///
466    /// # Examples
467    /// ```rust
468    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
469    /// # fn main() -> DecoderResult<()> {
470    /// let config0 = configs::Detection {
471    ///     anchors: Some(vec![
472    ///         [0.36666667461395264, 0.31481480598449707],
473    ///         [0.38749998807907104, 0.4740740656852722],
474    ///         [0.5333333611488342, 0.644444465637207],
475    ///     ]),
476    ///     decoder: configs::DecoderType::ModelPack,
477    ///     quantization: Some(configs::QuantTuple(0.08547406643629074, 174)),
478    ///     shape: vec![1, 9, 15, 18],
479    ///     dshape: Vec::new(),
480    ///     normalized: Some(true),
481    /// };
482    /// let config1 = configs::Detection {
483    ///     anchors: Some(vec![
484    ///         [0.13750000298023224, 0.2074074000120163],
485    ///         [0.2541666626930237, 0.21481481194496155],
486    ///         [0.23125000298023224, 0.35185185074806213],
487    ///     ]),
488    ///     decoder: configs::DecoderType::ModelPack,
489    ///     quantization: Some(configs::QuantTuple(0.09929127991199493, 183)),
490    ///     shape: vec![1, 17, 30, 18],
491    ///     dshape: Vec::new(),
492    ///     normalized: Some(true),
493    /// };
494    /// let seg_config = configs::Segmentation {
495    ///     decoder: configs::DecoderType::ModelPack,
496    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
497    ///     shape: vec![1, 640, 640, 2],
498    ///     dshape: Vec::new(),
499    /// };
500    /// let decoder = DecoderBuilder::new()
501    ///     .with_config_modelpack_segdet_split(vec![config0, config1], seg_config)
502    ///     .build()?;
503    /// # Ok(())
504    /// # }
505    /// ```
506    pub fn with_config_modelpack_segdet_split(
507        mut self,
508        boxes: Vec<configs::Detection>,
509        segmentation: configs::Segmentation,
510    ) -> Self {
511        let mut outputs = boxes
512            .into_iter()
513            .map(ConfigOutput::Detection)
514            .collect::<Vec<_>>();
515        outputs.push(ConfigOutput::Segmentation(segmentation));
516        let config = ConfigOutputs {
517            outputs,
518            ..Default::default()
519        };
520        self.config_src.replace(ConfigSource::Config(config));
521        self
522    }
523
524    /// Loads a ModelPack segmentation model configuration. Use
525    /// `DecoderBuilder.build()` to parse the model configuration.
526    ///
527    /// # Examples
528    /// ```rust
529    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
530    /// # fn main() -> DecoderResult<()> {
531    /// let seg_config = configs::Segmentation {
532    ///     decoder: configs::DecoderType::ModelPack,
533    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
534    ///     shape: vec![1, 640, 640, 3],
535    ///     dshape: Vec::new(),
536    /// };
537    /// let decoder = DecoderBuilder::new()
538    ///     .with_config_modelpack_seg(seg_config)
539    ///     .build()?;
540    /// # Ok(())
541    /// # }
542    /// ```
543    pub fn with_config_modelpack_seg(mut self, segmentation: configs::Segmentation) -> Self {
544        let config = ConfigOutputs {
545            outputs: vec![ConfigOutput::Segmentation(segmentation)],
546            ..Default::default()
547        };
548        self.config_src.replace(ConfigSource::Config(config));
549        self
550    }
551
552    /// Add an output to the decoder configuration.
553    ///
554    /// Incrementally builds the model configuration by adding outputs one at
555    /// a time. The decoder resolves the model type from the combination of
556    /// outputs during `build()`.
557    ///
558    /// If `dshape` is non-empty on the output, `shape` is automatically
559    /// derived from it (the size component of each named dimension). This
560    /// prevents conflicts between `shape` and `dshape`.
561    ///
562    /// This uses the programmatic config path. Calling this after
563    /// `with_config_json_str()` or `with_config_yaml_str()` replaces the
564    /// string-based config source.
565    ///
566    /// # Examples
567    /// ```rust
568    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
569    /// # fn main() -> DecoderResult<()> {
570    /// let decoder = DecoderBuilder::new()
571    ///     .add_output(ConfigOutput::Scores(configs::Scores {
572    ///         decoder: configs::DecoderType::Ultralytics,
573    ///         dshape: vec![
574    ///             (configs::DimName::Batch, 1),
575    ///             (configs::DimName::NumClasses, 80),
576    ///             (configs::DimName::NumBoxes, 8400),
577    ///         ],
578    ///         ..Default::default()
579    ///     }))
580    ///     .add_output(ConfigOutput::Boxes(configs::Boxes {
581    ///         decoder: configs::DecoderType::Ultralytics,
582    ///         dshape: vec![
583    ///             (configs::DimName::Batch, 1),
584    ///             (configs::DimName::BoxCoords, 4),
585    ///             (configs::DimName::NumBoxes, 8400),
586    ///         ],
587    ///         ..Default::default()
588    ///     }))
589    ///     .build()?;
590    /// # Ok(())
591    /// # }
592    /// ```
593    pub fn add_output(mut self, output: ConfigOutput) -> Self {
594        if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
595            self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
596        }
597        if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
598            config.outputs.push(Self::normalize_output(output));
599        }
600        self
601    }
602
603    /// Sets the decoder version for Ultralytics models.
604    ///
605    /// This is used with `add_output()` to specify the YOLO architecture
606    /// version when it cannot be inferred from the output shapes alone.
607    ///
608    /// - `Yolov5`, `Yolov8`, `Yolo11`: Traditional models requiring external
609    ///   NMS
610    /// - `Yolo26`: End-to-end models with NMS embedded in the model graph
611    ///
612    /// # Examples
613    /// ```rust
614    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
615    /// # fn main() -> DecoderResult<()> {
616    /// let decoder = DecoderBuilder::new()
617    ///     .add_output(ConfigOutput::Detection(configs::Detection {
618    ///         decoder: configs::DecoderType::Ultralytics,
619    ///         dshape: vec![
620    ///             (configs::DimName::Batch, 1),
621    ///             (configs::DimName::NumBoxes, 100),
622    ///             (configs::DimName::NumFeatures, 6),
623    ///         ],
624    ///         ..Default::default()
625    ///     }))
626    ///     .with_decoder_version(configs::DecoderVersion::Yolo26)
627    ///     .build()?;
628    /// # Ok(())
629    /// # }
630    /// ```
631    pub fn with_decoder_version(mut self, version: configs::DecoderVersion) -> Self {
632        if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
633            self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
634        }
635        if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
636            config.decoder_version = Some(version);
637        }
638        self
639    }
640
641    /// Normalize an output: if dshape is non-empty, derive shape from it.
642    fn normalize_output(mut output: ConfigOutput) -> ConfigOutput {
643        fn normalize_shape(shape: &mut Vec<usize>, dshape: &[(configs::DimName, usize)]) {
644            if !dshape.is_empty() {
645                *shape = dshape.iter().map(|(_, size)| *size).collect();
646            }
647        }
648        match &mut output {
649            ConfigOutput::Detection(c) => normalize_shape(&mut c.shape, &c.dshape),
650            ConfigOutput::Boxes(c) => normalize_shape(&mut c.shape, &c.dshape),
651            ConfigOutput::Scores(c) => normalize_shape(&mut c.shape, &c.dshape),
652            ConfigOutput::Protos(c) => normalize_shape(&mut c.shape, &c.dshape),
653            ConfigOutput::Segmentation(c) => normalize_shape(&mut c.shape, &c.dshape),
654            ConfigOutput::MaskCoefficients(c) => normalize_shape(&mut c.shape, &c.dshape),
655            ConfigOutput::Mask(c) => normalize_shape(&mut c.shape, &c.dshape),
656            ConfigOutput::Classes(c) => normalize_shape(&mut c.shape, &c.dshape),
657        }
658        output
659    }
660
    /// Sets the score threshold of the decoder.
662    ///
663    /// # Examples
664    /// ```rust
665    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
666    /// # fn main() -> DecoderResult<()> {
667    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
668    /// let decoder = DecoderBuilder::new()
669    ///     .with_config_json_str(config_json)
670    ///     .with_score_threshold(0.654)
671    ///     .build()?;
672    /// assert_eq!(decoder.score_threshold, 0.654);
673    /// # Ok(())
674    /// # }
675    /// ```
    pub fn with_score_threshold(mut self, score_threshold: f32) -> Self {
        // Stored as given (no range check here); `build()` copies it into
        // the decoder.
        self.score_threshold = score_threshold;
        self
    }
680
681    /// Sets the IOU threshold of the decoder. Has no effect when NMS is set to
682    /// `None`
683    ///
684    /// # Examples
685    /// ```rust
686    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
687    /// # fn main() -> DecoderResult<()> {
688    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
689    /// let decoder = DecoderBuilder::new()
690    ///     .with_config_json_str(config_json)
691    ///     .with_iou_threshold(0.654)
692    ///     .build()?;
693    /// assert_eq!(decoder.iou_threshold, 0.654);
694    /// # Ok(())
695    /// # }
696    /// ```
    pub fn with_iou_threshold(mut self, iou_threshold: f32) -> Self {
        // Stored as given (no range check here); `build()` copies it into
        // the decoder.
        self.iou_threshold = iou_threshold;
        self
    }
701
702    /// Sets the NMS mode for the decoder.
703    ///
704    /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS (default): suppress
705    ///   overlapping boxes regardless of class label
706    /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
707    ///   share the same class label AND overlap above the IoU threshold
708    /// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
709    ///
710    /// # Examples
711    /// ```rust
712    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::Nms};
713    /// # fn main() -> DecoderResult<()> {
714    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
715    /// let decoder = DecoderBuilder::new()
716    ///     .with_config_json_str(config_json)
717    ///     .with_nms(Some(Nms::ClassAware))
718    ///     .build()?;
719    /// assert_eq!(decoder.nms, Some(Nms::ClassAware));
720    /// # Ok(())
721    /// # }
722    /// ```
    pub fn with_nms(mut self, nms: Option<configs::Nms>) -> Self {
        // This is the builder-level setting; `build()` prefers an NMS value
        // carried by the configuration and falls back to this one otherwise.
        self.nms = nms;
        self
    }
727
728    /// Builds the decoder with the given settings. If the config is a JSON or
729    /// YAML string, this will deserialize the JSON or YAML and then parse the
730    /// configuration information.
731    ///
732    /// # Examples
733    /// ```rust
734    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
735    /// # fn main() -> DecoderResult<()> {
736    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
737    /// let decoder = DecoderBuilder::new()
738    ///     .with_config_json_str(config_json)
739    ///     .with_score_threshold(0.654)
740    ///     .build()?;
741    /// # Ok(())
742    /// # }
743    /// ```
744    pub fn build(self) -> Result<Decoder, DecoderError> {
745        let config = match self.config_src {
746            Some(ConfigSource::Json(s)) => serde_json::from_str(&s)?,
747            Some(ConfigSource::Yaml(s)) => serde_yaml::from_str(&s)?,
748            Some(ConfigSource::Config(c)) => c,
749            None => return Err(DecoderError::NoConfig),
750        };
751
752        // Extract normalized flag from config outputs
753        let normalized = Self::get_normalized(&config.outputs);
754
755        // Use NMS from config if present, otherwise use builder's NMS setting
756        let nms = config.nms.or(self.nms);
757        let model_type = Self::get_model_type(config)?;
758
759        Ok(Decoder {
760            model_type,
761            iou_threshold: self.iou_threshold,
762            score_threshold: self.score_threshold,
763            nms,
764            normalized,
765        })
766    }
767
768    /// Extracts the normalized flag from config outputs.
769    /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
770    /// - `Some(false)`: Boxes are in pixel coordinates
771    /// - `None`: Unknown (not specified in config), caller must infer
772    fn get_normalized(outputs: &[ConfigOutput]) -> Option<bool> {
773        for output in outputs {
774            match output {
775                ConfigOutput::Detection(det) => return det.normalized,
776                ConfigOutput::Boxes(boxes) => return boxes.normalized,
777                _ => {}
778            }
779        }
780        None // not specified
781    }
782
783    fn get_model_type(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
784        // yolo or modelpack
785        let mut yolo = false;
786        let mut modelpack = false;
787        for c in &configs.outputs {
788            match c.decoder() {
789                DecoderType::ModelPack => modelpack = true,
790                DecoderType::Ultralytics => yolo = true,
791            }
792        }
793        match (modelpack, yolo) {
794            (true, true) => Err(DecoderError::InvalidConfig(
795                "Both ModelPack and Yolo outputs found in config".to_string(),
796            )),
797            (true, false) => Self::get_model_type_modelpack(configs),
798            (false, true) => Self::get_model_type_yolo(configs),
799            (false, false) => Err(DecoderError::InvalidConfig(
800                "No outputs found in config".to_string(),
801            )),
802        }
803    }
804
805    fn get_model_type_yolo(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
806        let mut boxes = None;
807        let mut protos = None;
808        let mut split_boxes = None;
809        let mut split_scores = None;
810        let mut split_mask_coeff = None;
811        let mut split_classes = None;
812        for c in configs.outputs {
813            match c {
814                ConfigOutput::Detection(detection) => boxes = Some(detection),
815                ConfigOutput::Segmentation(_) => {
816                    return Err(DecoderError::InvalidConfig(
817                        "Invalid Segmentation output with Yolo decoder".to_string(),
818                    ));
819                }
820                ConfigOutput::Protos(protos_) => protos = Some(protos_),
821                ConfigOutput::Mask(_) => {
822                    return Err(DecoderError::InvalidConfig(
823                        "Invalid Mask output with Yolo decoder".to_string(),
824                    ));
825                }
826                ConfigOutput::Scores(scores) => split_scores = Some(scores),
827                ConfigOutput::Boxes(boxes) => split_boxes = Some(boxes),
828                ConfigOutput::MaskCoefficients(mask_coeff) => split_mask_coeff = Some(mask_coeff),
829                ConfigOutput::Classes(classes) => split_classes = Some(classes),
830            }
831        }
832
833        // Use end-to-end model types when:
834        // 1. decoder_version is explicitly set to Yolo26 (definitive), OR
835        //    decoder_version is not set but the dshapes are (batch, num_boxes,
836        //    num_features)
837        let is_end_to_end_dshape = boxes.as_ref().is_some_and(|b| {
838            let dims = b.dshape.iter().map(|(d, _)| *d).collect::<Vec<_>>();
839            dims == vec![DimName::Batch, DimName::NumBoxes, DimName::NumFeatures]
840        });
841
842        let is_end_to_end = configs
843            .decoder_version
844            .map(|v| v.is_end_to_end())
845            .unwrap_or(is_end_to_end_dshape);
846
847        if is_end_to_end {
848            if let Some(boxes) = boxes {
849                if let Some(protos) = protos {
850                    Self::verify_yolo_seg_det_26(&boxes, &protos)?;
851                    return Ok(ModelType::YoloEndToEndSegDet { boxes, protos });
852                } else {
853                    Self::verify_yolo_det_26(&boxes)?;
854                    return Ok(ModelType::YoloEndToEndDet { boxes });
855                }
856            } else if let (Some(split_boxes), Some(split_scores), Some(split_classes)) =
857                (split_boxes, split_scores, split_classes)
858            {
859                if let (Some(split_mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
860                    Self::verify_yolo_split_end_to_end_segdet(
861                        &split_boxes,
862                        &split_scores,
863                        &split_classes,
864                        &split_mask_coeff,
865                        &protos,
866                    )?;
867                    return Ok(ModelType::YoloSplitEndToEndSegDet {
868                        boxes: split_boxes,
869                        scores: split_scores,
870                        classes: split_classes,
871                        mask_coeff: split_mask_coeff,
872                        protos,
873                    });
874                }
875                Self::verify_yolo_split_end_to_end_det(
876                    &split_boxes,
877                    &split_scores,
878                    &split_classes,
879                )?;
880                return Ok(ModelType::YoloSplitEndToEndDet {
881                    boxes: split_boxes,
882                    scores: split_scores,
883                    classes: split_classes,
884                });
885            } else {
886                return Err(DecoderError::InvalidConfig(
887                    "Invalid Yolo end-to-end model outputs".to_string(),
888                ));
889            }
890        }
891
892        if let Some(boxes) = boxes {
893            if let Some(protos) = protos {
894                Self::verify_yolo_seg_det(&boxes, &protos)?;
895                Ok(ModelType::YoloSegDet { boxes, protos })
896            } else {
897                Self::verify_yolo_det(&boxes)?;
898                Ok(ModelType::YoloDet { boxes })
899            }
900        } else if let (Some(boxes), Some(scores)) = (split_boxes, split_scores) {
901            if let (Some(mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
902                Self::verify_yolo_split_segdet(&boxes, &scores, &mask_coeff, &protos)?;
903                Ok(ModelType::YoloSplitSegDet {
904                    boxes,
905                    scores,
906                    mask_coeff,
907                    protos,
908                })
909            } else {
910                Self::verify_yolo_split_det(&boxes, &scores)?;
911                Ok(ModelType::YoloSplitDet { boxes, scores })
912            }
913        } else {
914            Err(DecoderError::InvalidConfig(
915                "Invalid Yolo model outputs".to_string(),
916            ))
917        }
918    }
919
920    fn verify_yolo_det(detect: &configs::Detection) -> Result<(), DecoderError> {
921        if detect.shape.len() != 3 {
922            return Err(DecoderError::InvalidConfig(format!(
923                "Invalid Yolo Detection shape {:?}",
924                detect.shape
925            )));
926        }
927
928        Self::verify_dshapes(
929            &detect.dshape,
930            &detect.shape,
931            "Detection",
932            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
933        )?;
934        if !detect.dshape.is_empty() {
935            Self::get_class_count(&detect.dshape, None, None)?;
936        } else {
937            Self::get_class_count_no_dshape(detect.into(), None)?;
938        }
939
940        Ok(())
941    }
942
943    fn verify_yolo_det_26(detect: &configs::Detection) -> Result<(), DecoderError> {
944        if detect.shape.len() != 3 {
945            return Err(DecoderError::InvalidConfig(format!(
946                "Invalid Yolo Detection shape {:?}",
947                detect.shape
948            )));
949        }
950
951        Self::verify_dshapes(
952            &detect.dshape,
953            &detect.shape,
954            "Detection",
955            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
956        )?;
957
958        if !detect.shape.contains(&6) {
959            return Err(DecoderError::InvalidConfig(
960                "Yolo26 Detection must have 6 features".to_string(),
961            ));
962        }
963
964        Ok(())
965    }
966
967    fn verify_yolo_seg_det(
968        detection: &configs::Detection,
969        protos: &configs::Protos,
970    ) -> Result<(), DecoderError> {
971        if detection.shape.len() != 3 {
972            return Err(DecoderError::InvalidConfig(format!(
973                "Invalid Yolo Detection shape {:?}",
974                detection.shape
975            )));
976        }
977        if protos.shape.len() != 4 {
978            return Err(DecoderError::InvalidConfig(format!(
979                "Invalid Yolo Protos shape {:?}",
980                protos.shape
981            )));
982        }
983
984        Self::verify_dshapes(
985            &detection.dshape,
986            &detection.shape,
987            "Detection",
988            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
989        )?;
990        Self::verify_dshapes(
991            &protos.dshape,
992            &protos.shape,
993            "Protos",
994            &[
995                DimName::Batch,
996                DimName::Height,
997                DimName::Width,
998                DimName::NumProtos,
999            ],
1000        )?;
1001
1002        let protos_count = Self::get_protos_count(&protos.dshape)
1003            .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1004        log::debug!("Protos count: {}", protos_count);
1005        log::debug!("Detection dshape: {:?}", detection.dshape);
1006        let classes = if !detection.dshape.is_empty() {
1007            Self::get_class_count(&detection.dshape, Some(protos_count), None)?
1008        } else {
1009            Self::get_class_count_no_dshape(detection.into(), Some(protos_count))?
1010        };
1011
1012        if classes == 0 {
1013            return Err(DecoderError::InvalidConfig(
1014                "Yolo Segmentation Detection has zero classes".to_string(),
1015            ));
1016        }
1017
1018        Ok(())
1019    }
1020
1021    fn verify_yolo_seg_det_26(
1022        detection: &configs::Detection,
1023        protos: &configs::Protos,
1024    ) -> Result<(), DecoderError> {
1025        if detection.shape.len() != 3 {
1026            return Err(DecoderError::InvalidConfig(format!(
1027                "Invalid Yolo Detection shape {:?}",
1028                detection.shape
1029            )));
1030        }
1031        if protos.shape.len() != 4 {
1032            return Err(DecoderError::InvalidConfig(format!(
1033                "Invalid Yolo Protos shape {:?}",
1034                protos.shape
1035            )));
1036        }
1037
1038        Self::verify_dshapes(
1039            &detection.dshape,
1040            &detection.shape,
1041            "Detection",
1042            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1043        )?;
1044        Self::verify_dshapes(
1045            &protos.dshape,
1046            &protos.shape,
1047            "Protos",
1048            &[
1049                DimName::Batch,
1050                DimName::Height,
1051                DimName::Width,
1052                DimName::NumProtos,
1053            ],
1054        )?;
1055
1056        let protos_count = Self::get_protos_count(&protos.dshape)
1057            .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1058        log::debug!("Protos count: {}", protos_count);
1059        log::debug!("Detection dshape: {:?}", detection.dshape);
1060
1061        if !detection.shape.contains(&(6 + protos_count)) {
1062            return Err(DecoderError::InvalidConfig(format!(
1063                "Yolo26 Segmentation Detection must have num_features be 6 + num_protos = {}",
1064                6 + protos_count
1065            )));
1066        }
1067
1068        Ok(())
1069    }
1070
1071    fn verify_yolo_split_det(
1072        boxes: &configs::Boxes,
1073        scores: &configs::Scores,
1074    ) -> Result<(), DecoderError> {
1075        if boxes.shape.len() != 3 {
1076            return Err(DecoderError::InvalidConfig(format!(
1077                "Invalid Yolo Split Boxes shape {:?}",
1078                boxes.shape
1079            )));
1080        }
1081        if scores.shape.len() != 3 {
1082            return Err(DecoderError::InvalidConfig(format!(
1083                "Invalid Yolo Split Scores shape {:?}",
1084                scores.shape
1085            )));
1086        }
1087
1088        Self::verify_dshapes(
1089            &boxes.dshape,
1090            &boxes.shape,
1091            "Boxes",
1092            &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1093        )?;
1094        Self::verify_dshapes(
1095            &scores.dshape,
1096            &scores.shape,
1097            "Scores",
1098            &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1099        )?;
1100
1101        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1102        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1103
1104        if boxes_num != scores_num {
1105            return Err(DecoderError::InvalidConfig(format!(
1106                "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1107                boxes_num, scores_num
1108            )));
1109        }
1110
1111        Ok(())
1112    }
1113
1114    fn verify_yolo_split_segdet(
1115        boxes: &configs::Boxes,
1116        scores: &configs::Scores,
1117        mask_coeff: &configs::MaskCoefficients,
1118        protos: &configs::Protos,
1119    ) -> Result<(), DecoderError> {
1120        if boxes.shape.len() != 3 {
1121            return Err(DecoderError::InvalidConfig(format!(
1122                "Invalid Yolo Split Boxes shape {:?}",
1123                boxes.shape
1124            )));
1125        }
1126        if scores.shape.len() != 3 {
1127            return Err(DecoderError::InvalidConfig(format!(
1128                "Invalid Yolo Split Scores shape {:?}",
1129                scores.shape
1130            )));
1131        }
1132
1133        if mask_coeff.shape.len() != 3 {
1134            return Err(DecoderError::InvalidConfig(format!(
1135                "Invalid Yolo Split Mask Coefficients shape {:?}",
1136                mask_coeff.shape
1137            )));
1138        }
1139
1140        if protos.shape.len() != 4 {
1141            return Err(DecoderError::InvalidConfig(format!(
1142                "Invalid Yolo Protos shape {:?}",
1143                mask_coeff.shape
1144            )));
1145        }
1146
1147        Self::verify_dshapes(
1148            &boxes.dshape,
1149            &boxes.shape,
1150            "Boxes",
1151            &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1152        )?;
1153        Self::verify_dshapes(
1154            &scores.dshape,
1155            &scores.shape,
1156            "Scores",
1157            &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1158        )?;
1159        Self::verify_dshapes(
1160            &mask_coeff.dshape,
1161            &mask_coeff.shape,
1162            "Mask Coefficients",
1163            &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1164        )?;
1165        Self::verify_dshapes(
1166            &protos.dshape,
1167            &protos.shape,
1168            "Protos",
1169            &[
1170                DimName::Batch,
1171                DimName::Height,
1172                DimName::Width,
1173                DimName::NumProtos,
1174            ],
1175        )?;
1176
1177        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1178        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1179        let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1180
1181        let mask_channels = if !mask_coeff.dshape.is_empty() {
1182            Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1183                DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1184            })?
1185        } else {
1186            mask_coeff.shape[1]
1187        };
1188        let proto_channels = if !protos.dshape.is_empty() {
1189            Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1190                DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1191            })?
1192        } else {
1193            protos.shape[1].min(protos.shape[3])
1194        };
1195
1196        if boxes_num != scores_num {
1197            return Err(DecoderError::InvalidConfig(format!(
1198                "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1199                boxes_num, scores_num
1200            )));
1201        }
1202
1203        if boxes_num != mask_num {
1204            return Err(DecoderError::InvalidConfig(format!(
1205                "Yolo Split Detection Boxes num {} incompatible with Mask Coefficients num {}",
1206                boxes_num, mask_num
1207            )));
1208        }
1209
1210        if proto_channels != mask_channels {
1211            return Err(DecoderError::InvalidConfig(format!(
1212                "Yolo Protos channels {} incompatible with Mask Coefficients channels {}",
1213                proto_channels, mask_channels
1214            )));
1215        }
1216
1217        Ok(())
1218    }
1219
1220    fn verify_yolo_split_end_to_end_det(
1221        boxes: &configs::Boxes,
1222        scores: &configs::Scores,
1223        classes: &configs::Classes,
1224    ) -> Result<(), DecoderError> {
1225        if boxes.shape.len() != 3 || !boxes.shape.contains(&4) {
1226            return Err(DecoderError::InvalidConfig(format!(
1227                "Split end-to-end boxes must be [batch, N, 4], got {:?}",
1228                boxes.shape
1229            )));
1230        }
1231        if scores.shape.len() != 3 || !scores.shape.contains(&1) {
1232            return Err(DecoderError::InvalidConfig(format!(
1233                "Split end-to-end scores must be [batch, N, 1], got {:?}",
1234                scores.shape
1235            )));
1236        }
1237        if classes.shape.len() != 3 || !classes.shape.contains(&1) {
1238            return Err(DecoderError::InvalidConfig(format!(
1239                "Split end-to-end classes must be [batch, N, 1], got {:?}",
1240                classes.shape
1241            )));
1242        }
1243        Ok(())
1244    }
1245
1246    fn verify_yolo_split_end_to_end_segdet(
1247        boxes: &configs::Boxes,
1248        scores: &configs::Scores,
1249        classes: &configs::Classes,
1250        mask_coeff: &configs::MaskCoefficients,
1251        protos: &configs::Protos,
1252    ) -> Result<(), DecoderError> {
1253        Self::verify_yolo_split_end_to_end_det(boxes, scores, classes)?;
1254        if mask_coeff.shape.len() != 3 {
1255            return Err(DecoderError::InvalidConfig(format!(
1256                "Invalid split end-to-end mask coefficients shape {:?}",
1257                mask_coeff.shape
1258            )));
1259        }
1260        if protos.shape.len() != 4 {
1261            return Err(DecoderError::InvalidConfig(format!(
1262                "Invalid protos shape {:?}",
1263                protos.shape
1264            )));
1265        }
1266        Ok(())
1267    }
1268
1269    fn get_model_type_modelpack(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1270        let mut split_decoders = Vec::new();
1271        let mut segment_ = None;
1272        let mut scores_ = None;
1273        let mut boxes_ = None;
1274        for c in configs.outputs {
1275            match c {
1276                ConfigOutput::Detection(detection) => split_decoders.push(detection),
1277                ConfigOutput::Segmentation(segmentation) => segment_ = Some(segmentation),
1278                ConfigOutput::Mask(_) => {}
1279                ConfigOutput::Protos(_) => {
1280                    return Err(DecoderError::InvalidConfig(
1281                        "ModelPack should not have protos".to_string(),
1282                    ));
1283                }
1284                ConfigOutput::Scores(scores) => scores_ = Some(scores),
1285                ConfigOutput::Boxes(boxes) => boxes_ = Some(boxes),
1286                ConfigOutput::MaskCoefficients(_) => {
1287                    return Err(DecoderError::InvalidConfig(
1288                        "ModelPack should not have mask coefficients".to_string(),
1289                    ));
1290                }
1291                ConfigOutput::Classes(_) => {
1292                    return Err(DecoderError::InvalidConfig(
1293                        "ModelPack should not have classes output".to_string(),
1294                    ));
1295                }
1296            }
1297        }
1298
1299        if let Some(segmentation) = segment_ {
1300            if !split_decoders.is_empty() {
1301                let classes = Self::verify_modelpack_split_det(&split_decoders)?;
1302                Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1303                Ok(ModelType::ModelPackSegDetSplit {
1304                    detection: split_decoders,
1305                    segmentation,
1306                })
1307            } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1308                let classes = Self::verify_modelpack_det(&boxes, &scores)?;
1309                Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1310                Ok(ModelType::ModelPackSegDet {
1311                    boxes,
1312                    scores,
1313                    segmentation,
1314                })
1315            } else {
1316                Self::verify_modelpack_seg(&segmentation, None)?;
1317                Ok(ModelType::ModelPackSeg { segmentation })
1318            }
1319        } else if !split_decoders.is_empty() {
1320            Self::verify_modelpack_split_det(&split_decoders)?;
1321            Ok(ModelType::ModelPackDetSplit {
1322                detection: split_decoders,
1323            })
1324        } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1325            Self::verify_modelpack_det(&boxes, &scores)?;
1326            Ok(ModelType::ModelPackDet { boxes, scores })
1327        } else {
1328            Err(DecoderError::InvalidConfig(
1329                "Invalid ModelPack model outputs".to_string(),
1330            ))
1331        }
1332    }
1333
1334    fn verify_modelpack_det(
1335        boxes: &configs::Boxes,
1336        scores: &configs::Scores,
1337    ) -> Result<usize, DecoderError> {
1338        if boxes.shape.len() != 4 {
1339            return Err(DecoderError::InvalidConfig(format!(
1340                "Invalid ModelPack Boxes shape {:?}",
1341                boxes.shape
1342            )));
1343        }
1344        if scores.shape.len() != 3 {
1345            return Err(DecoderError::InvalidConfig(format!(
1346                "Invalid ModelPack Scores shape {:?}",
1347                scores.shape
1348            )));
1349        }
1350
1351        Self::verify_dshapes(
1352            &boxes.dshape,
1353            &boxes.shape,
1354            "Boxes",
1355            &[
1356                DimName::Batch,
1357                DimName::NumBoxes,
1358                DimName::Padding,
1359                DimName::BoxCoords,
1360            ],
1361        )?;
1362        Self::verify_dshapes(
1363            &scores.dshape,
1364            &scores.shape,
1365            "Scores",
1366            &[DimName::Batch, DimName::NumBoxes, DimName::NumClasses],
1367        )?;
1368
1369        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[1]);
1370        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[1]);
1371
1372        if boxes_num != scores_num {
1373            return Err(DecoderError::InvalidConfig(format!(
1374                "ModelPack Detection Boxes num {} incompatible with Scores num {}",
1375                boxes_num, scores_num
1376            )));
1377        }
1378
1379        let num_classes = if !scores.dshape.is_empty() {
1380            Self::get_class_count(&scores.dshape, None, None)?
1381        } else {
1382            Self::get_class_count_no_dshape(scores.into(), None)?
1383        };
1384
1385        Ok(num_classes)
1386    }
1387
1388    fn verify_modelpack_split_det(boxes: &[configs::Detection]) -> Result<usize, DecoderError> {
1389        let mut num_classes = None;
1390        for b in boxes {
1391            let Some(num_anchors) = b.anchors.as_ref().map(|a| a.len()) else {
1392                return Err(DecoderError::InvalidConfig(
1393                    "ModelPack Split Detection missing anchors".to_string(),
1394                ));
1395            };
1396
1397            if num_anchors == 0 {
1398                return Err(DecoderError::InvalidConfig(
1399                    "ModelPack Split Detection has zero anchors".to_string(),
1400                ));
1401            }
1402
1403            if b.shape.len() != 4 {
1404                return Err(DecoderError::InvalidConfig(format!(
1405                    "Invalid ModelPack Split Detection shape {:?}",
1406                    b.shape
1407                )));
1408            }
1409
1410            Self::verify_dshapes(
1411                &b.dshape,
1412                &b.shape,
1413                "Split Detection",
1414                &[
1415                    DimName::Batch,
1416                    DimName::Height,
1417                    DimName::Width,
1418                    DimName::NumAnchorsXFeatures,
1419                ],
1420            )?;
1421            let classes = if !b.dshape.is_empty() {
1422                Self::get_class_count(&b.dshape, None, Some(num_anchors))?
1423            } else {
1424                Self::get_class_count_no_dshape(b.into(), None)?
1425            };
1426
1427            match num_classes {
1428                Some(n) => {
1429                    if n != classes {
1430                        return Err(DecoderError::InvalidConfig(format!(
1431                            "ModelPack Split Detection inconsistent number of classes: previous {}, current {}",
1432                            n, classes
1433                        )));
1434                    }
1435                }
1436                None => {
1437                    num_classes = Some(classes);
1438                }
1439            }
1440        }
1441
1442        Ok(num_classes.unwrap_or(0))
1443    }
1444
1445    fn verify_modelpack_seg(
1446        segmentation: &configs::Segmentation,
1447        classes: Option<usize>,
1448    ) -> Result<(), DecoderError> {
1449        if segmentation.shape.len() != 4 {
1450            return Err(DecoderError::InvalidConfig(format!(
1451                "Invalid ModelPack Segmentation shape {:?}",
1452                segmentation.shape
1453            )));
1454        }
1455        Self::verify_dshapes(
1456            &segmentation.dshape,
1457            &segmentation.shape,
1458            "Segmentation",
1459            &[
1460                DimName::Batch,
1461                DimName::Height,
1462                DimName::Width,
1463                DimName::NumClasses,
1464            ],
1465        )?;
1466
1467        if let Some(classes) = classes {
1468            let seg_classes = if !segmentation.dshape.is_empty() {
1469                Self::get_class_count(&segmentation.dshape, None, None)?
1470            } else {
1471                Self::get_class_count_no_dshape(segmentation.into(), None)?
1472            };
1473
1474            if seg_classes != classes + 1 {
1475                return Err(DecoderError::InvalidConfig(format!(
1476                    "ModelPack Segmentation channels {} incompatible with number of classes {}",
1477                    seg_classes, classes
1478                )));
1479            }
1480        }
1481        Ok(())
1482    }
1483
1484    // verifies that dshapes match the given shape
1485    fn verify_dshapes(
1486        dshape: &[(DimName, usize)],
1487        shape: &[usize],
1488        name: &str,
1489        dims: &[DimName],
1490    ) -> Result<(), DecoderError> {
1491        for s in shape {
1492            if *s == 0 {
1493                return Err(DecoderError::InvalidConfig(format!(
1494                    "{} shape has zero dimension",
1495                    name
1496                )));
1497            }
1498        }
1499
1500        if shape.len() != dims.len() {
1501            return Err(DecoderError::InvalidConfig(format!(
1502                "{} shape length {} does not match expected dims length {}",
1503                name,
1504                shape.len(),
1505                dims.len()
1506            )));
1507        }
1508
1509        if dshape.is_empty() {
1510            return Ok(());
1511        }
1512        // check the dshape lengths match the shape lengths
1513        if dshape.len() != shape.len() {
1514            return Err(DecoderError::InvalidConfig(format!(
1515                "{} dshape length does not match shape length",
1516                name
1517            )));
1518        }
1519
1520        // check the dshape values match the shape values
1521        for ((dim_name, dim_size), shape_size) in dshape.iter().zip(shape) {
1522            if dim_size != shape_size {
1523                return Err(DecoderError::InvalidConfig(format!(
1524                    "{} dshape dimension {} size {} does not match shape size {}",
1525                    name, dim_name, dim_size, shape_size
1526                )));
1527            }
1528            if *dim_name == DimName::Padding && *dim_size != 1 {
1529                return Err(DecoderError::InvalidConfig(
1530                    "Padding dimension size must be 1".to_string(),
1531                ));
1532            }
1533
1534            if *dim_name == DimName::BoxCoords && *dim_size != 4 {
1535                return Err(DecoderError::InvalidConfig(
1536                    "BoxCoords dimension size must be 4".to_string(),
1537                ));
1538            }
1539        }
1540
1541        let dims_present = HashSet::<DimName>::from_iter(dshape.iter().map(|(name, _)| *name));
1542        for dim in dims {
1543            if !dims_present.contains(dim) {
1544                return Err(DecoderError::InvalidConfig(format!(
1545                    "{} dshape missing required dimension {:?}",
1546                    name, dim
1547                )));
1548            }
1549        }
1550
1551        Ok(())
1552    }
1553
1554    fn get_box_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1555        for (dim_name, dim_size) in dshape {
1556            if *dim_name == DimName::NumBoxes {
1557                return Some(*dim_size);
1558            }
1559        }
1560        None
1561    }
1562
1563    fn get_class_count_no_dshape(
1564        config: ConfigOutputRef,
1565        protos: Option<usize>,
1566    ) -> Result<usize, DecoderError> {
1567        match config {
1568            ConfigOutputRef::Detection(detection) => match detection.decoder {
1569                DecoderType::Ultralytics => {
1570                    if detection.shape[1] <= 4 + protos.unwrap_or(0) {
1571                        return Err(DecoderError::InvalidConfig(format!(
1572                            "Invalid shape: Yolo num_features {} must be greater than {}",
1573                            detection.shape[1],
1574                            4 + protos.unwrap_or(0),
1575                        )));
1576                    }
1577                    Ok(detection.shape[1] - 4 - protos.unwrap_or(0))
1578                }
1579                DecoderType::ModelPack => {
1580                    let Some(num_anchors) = detection.anchors.as_ref().map(|a| a.len()) else {
1581                        return Err(DecoderError::Internal(
1582                            "ModelPack Detection missing anchors".to_string(),
1583                        ));
1584                    };
1585                    let anchors_x_features = detection.shape[3];
1586                    if anchors_x_features <= num_anchors * 5 {
1587                        return Err(DecoderError::InvalidConfig(format!(
1588                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
1589                            anchors_x_features,
1590                            num_anchors * 5,
1591                        )));
1592                    }
1593
1594                    if !anchors_x_features.is_multiple_of(num_anchors) {
1595                        return Err(DecoderError::InvalidConfig(format!(
1596                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
1597                            anchors_x_features, num_anchors
1598                        )));
1599                    }
1600                    Ok(anchors_x_features / num_anchors - 5)
1601                }
1602            },
1603
1604            ConfigOutputRef::Scores(scores) => match scores.decoder {
1605                DecoderType::Ultralytics => Ok(scores.shape[1]),
1606                DecoderType::ModelPack => Ok(scores.shape[2]),
1607            },
1608            ConfigOutputRef::Segmentation(seg) => Ok(seg.shape[3]),
1609            _ => Err(DecoderError::Internal(
1610                "Attempted to get class count from unsupported config output".to_owned(),
1611            )),
1612        }
1613    }
1614
1615    // get the class count from dshape or calculate from num_features
1616    fn get_class_count(
1617        dshape: &[(DimName, usize)],
1618        protos: Option<usize>,
1619        anchors: Option<usize>,
1620    ) -> Result<usize, DecoderError> {
1621        if dshape.is_empty() {
1622            return Ok(0);
1623        }
1624        // if it has num_classes in dshape, return it
1625        for (dim_name, dim_size) in dshape {
1626            if *dim_name == DimName::NumClasses {
1627                return Ok(*dim_size);
1628            }
1629        }
1630
1631        // number of classes can be calculated from num_features - 4 for yolo.  If the
1632        // model has protos, we also subtract the number of protos.
1633        for (dim_name, dim_size) in dshape {
1634            if *dim_name == DimName::NumFeatures {
1635                let protos = protos.unwrap_or(0);
1636                if protos + 4 >= *dim_size {
1637                    return Err(DecoderError::InvalidConfig(format!(
1638                        "Invalid shape: Yolo num_features {} must be greater than {}",
1639                        *dim_size,
1640                        protos + 4,
1641                    )));
1642                }
1643                return Ok(*dim_size - 4 - protos);
1644            }
1645        }
1646
1647        // number of classes can be calculated from number of anchors for modelpack
1648        // split detection
1649        if let Some(num_anchors) = anchors {
1650            for (dim_name, dim_size) in dshape {
1651                if *dim_name == DimName::NumAnchorsXFeatures {
1652                    let anchors_x_features = *dim_size;
1653                    if anchors_x_features <= num_anchors * 5 {
1654                        return Err(DecoderError::InvalidConfig(format!(
1655                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
1656                            anchors_x_features,
1657                            num_anchors * 5,
1658                        )));
1659                    }
1660
1661                    if !anchors_x_features.is_multiple_of(num_anchors) {
1662                        return Err(DecoderError::InvalidConfig(format!(
1663                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
1664                            anchors_x_features, num_anchors
1665                        )));
1666                    }
1667                    return Ok((anchors_x_features / num_anchors) - 5);
1668                }
1669            }
1670        }
1671        Err(DecoderError::InvalidConfig(
1672            "Cannot determine number of classes from dshape".to_owned(),
1673        ))
1674    }
1675
1676    fn get_protos_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1677        for (dim_name, dim_size) in dshape {
1678            if *dim_name == DimName::NumProtos {
1679                return Some(*dim_size);
1680            }
1681        }
1682        None
1683    }
1684}