Skip to main content

edgefirst_decoder/decoder/
builder.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashSet;
5
6use super::config::ConfigOutputRef;
7use super::configs::{self, DecoderType, DecoderVersion, DimName, ModelType};
8use super::merge::DecodeProgram;
9use super::{ConfigOutput, ConfigOutputs, Decoder};
10use crate::schema::SchemaV2;
11use crate::DecoderError;
12
13#[derive(Debug, Clone, PartialEq)]
14pub struct DecoderBuilder {
15    config_src: Option<ConfigSource>,
16    iou_threshold: f32,
17    score_threshold: f32,
18    /// NMS mode: Some(mode) applies NMS, None bypasses NMS (for end-to-end
19    /// models)
20    nms: Option<configs::Nms>,
21}
22
23#[derive(Debug, Clone, PartialEq)]
24enum ConfigSource {
25    Yaml(String),
26    Json(String),
27    Config(ConfigOutputs),
28    /// Schema v2 metadata. During build the schema is either converted
29    /// to a legacy [`ConfigOutputs`] (flat case) or used to construct a
30    /// [`DecodeProgram`] that performs per-child dequant + merge at
31    /// decode time.
32    Schema(SchemaV2),
33}
34
35impl Default for DecoderBuilder {
36    /// Creates a default DecoderBuilder with no configuration and 0.5 score
37    /// threshold and 0.5 IoU threshold.
38    ///
39    /// A valid configuration must be provided before building the Decoder.
40    ///
41    /// # Examples
42    /// ```rust
43    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
44    /// # fn main() -> DecoderResult<()> {
45    /// #  let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
46    /// let decoder = DecoderBuilder::default()
47    ///     .with_config_yaml_str(config_yaml)
48    ///     .build()?;
49    /// assert_eq!(decoder.score_threshold, 0.5);
50    /// assert_eq!(decoder.iou_threshold, 0.5);
51    ///
52    /// # Ok(())
53    /// # }
54    /// ```
55    fn default() -> Self {
56        Self {
57            config_src: None,
58            iou_threshold: 0.5,
59            score_threshold: 0.5,
60            nms: Some(configs::Nms::ClassAgnostic),
61        }
62    }
63}
64
65impl DecoderBuilder {
66    /// Creates a default DecoderBuilder with no configuration and 0.5 score
67    /// threshold and 0.5 IoU threshold.
68    ///
69    /// A valid configuration must be provided before building the Decoder.
70    ///
71    /// # Examples
72    /// ```rust
73    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
74    /// # fn main() -> DecoderResult<()> {
75    /// #  let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
76    /// let decoder = DecoderBuilder::new()
77    ///     .with_config_yaml_str(config_yaml)
78    ///     .build()?;
79    /// assert_eq!(decoder.score_threshold, 0.5);
80    /// assert_eq!(decoder.iou_threshold, 0.5);
81    ///
82    /// # Ok(())
83    /// # }
84    /// ```
85    pub fn new() -> Self {
86        Self::default()
87    }
88
89    /// Loads a model configuration in YAML format. Does not check if the string
90    /// is a correct configuration file. Use `DecoderBuilder.build()` to
91    /// deserialize the YAML and parse the model configuration.
92    ///
93    /// # Examples
94    /// ```rust
95    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
96    /// # fn main() -> DecoderResult<()> {
97    /// let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
98    /// let decoder = DecoderBuilder::new()
99    ///     .with_config_yaml_str(config_yaml)
100    ///     .build()?;
101    ///
102    /// # Ok(())
103    /// # }
104    /// ```
105    pub fn with_config_yaml_str(mut self, yaml_str: String) -> Self {
106        self.config_src.replace(ConfigSource::Yaml(yaml_str));
107        self
108    }
109
110    /// Loads a model configuration in JSON format. Does not check if the string
111    /// is a correct configuration file. Use `DecoderBuilder.build()` to
112    /// deserialize the JSON and parse the model configuration.
113    ///
114    /// # Examples
115    /// ```rust
116    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
117    /// # fn main() -> DecoderResult<()> {
118    /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
119    /// let decoder = DecoderBuilder::new()
120    ///     .with_config_json_str(config_json)
121    ///     .build()?;
122    ///
123    /// # Ok(())
124    /// # }
125    /// ```
126    pub fn with_config_json_str(mut self, json_str: String) -> Self {
127        self.config_src.replace(ConfigSource::Json(json_str));
128        self
129    }
130
131    /// Loads a model configuration. Does not check if the configuration is
132    /// correct. Intended to be used when the user needs control over the
133    /// deserialize of the configuration information. Use
134    /// `DecoderBuilder.build()` to parse the model configuration.
135    ///
136    /// # Examples
137    /// ```rust
138    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
139    /// # fn main() -> DecoderResult<()> {
140    /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json"));
141    /// let config = serde_json::from_str(config_json)?;
142    /// let decoder = DecoderBuilder::new().with_config(config).build()?;
143    ///
144    /// # Ok(())
145    /// # }
146    /// ```
147    pub fn with_config(mut self, config: ConfigOutputs) -> Self {
148        self.config_src.replace(ConfigSource::Config(config));
149        self
150    }
151
152    /// Configure the decoder from a schema v2 metadata document.
153    ///
154    /// Accepts a [`SchemaV2`] as produced by [`SchemaV2::parse_json`],
155    /// [`SchemaV2::parse_yaml`], [`SchemaV2::parse_file`], or
156    /// constructed programmatically. The builder validates the schema,
157    /// compiles a [`DecodeProgram`] for any split logical outputs
158    /// (per-scale or channel sub-splits), and downconverts the
159    /// logical-level semantics to the legacy [`ConfigOutputs`]
160    /// representation consumed by the existing decoder dispatch.
161    ///
162    /// # Examples
163    ///
164    /// ```rust,no_run
165    /// use edgefirst_decoder::{DecoderBuilder, DecoderResult};
166    /// use edgefirst_decoder::schema::SchemaV2;
167    ///
168    /// # fn main() -> DecoderResult<()> {
169    /// let schema = SchemaV2::parse_file("model/edgefirst.json")?;
170    /// let decoder = DecoderBuilder::new()
171    ///     .with_schema(schema)
172    ///     .with_score_threshold(0.25)
173    ///     .build()?;
174    /// # Ok(())
175    /// # }
176    /// ```
177    pub fn with_schema(mut self, schema: SchemaV2) -> Self {
178        self.config_src.replace(ConfigSource::Schema(schema));
179        self
180    }
181
182    /// Loads a YOLO detection model configuration.  Use
183    /// `DecoderBuilder.build()` to parse the model configuration.
184    ///
185    /// # Examples
186    /// ```rust
187    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
188    /// # fn main() -> DecoderResult<()> {
189    /// let decoder = DecoderBuilder::new()
190    ///     .with_config_yolo_det(
191    ///         configs::Detection {
192    ///             anchors: None,
193    ///             decoder: configs::DecoderType::Ultralytics,
194    ///             quantization: Some(configs::QuantTuple(0.012345, 26)),
195    ///             shape: vec![1, 84, 8400],
196    ///             dshape: Vec::new(),
197    ///             normalized: Some(true),
198    ///         },
199    ///         None,
200    ///     )
201    ///     .build()?;
202    ///
203    /// # Ok(())
204    /// # }
205    /// ```
206    pub fn with_config_yolo_det(
207        mut self,
208        boxes: configs::Detection,
209        version: Option<DecoderVersion>,
210    ) -> Self {
211        let config = ConfigOutputs {
212            outputs: vec![ConfigOutput::Detection(boxes)],
213            decoder_version: version,
214            ..Default::default()
215        };
216        self.config_src.replace(ConfigSource::Config(config));
217        self
218    }
219
220    /// Loads a YOLO split detection model configuration.  Use
221    /// `DecoderBuilder.build()` to parse the model configuration.
222    ///
223    /// # Examples
224    /// ```rust
225    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
226    /// # fn main() -> DecoderResult<()> {
227    /// let boxes_config = configs::Boxes {
228    ///     decoder: configs::DecoderType::Ultralytics,
229    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
230    ///     shape: vec![1, 4, 8400],
231    ///     dshape: Vec::new(),
232    ///     normalized: Some(true),
233    /// };
234    /// let scores_config = configs::Scores {
235    ///     decoder: configs::DecoderType::Ultralytics,
236    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
237    ///     shape: vec![1, 80, 8400],
238    ///     dshape: Vec::new(),
239    /// };
240    /// let decoder = DecoderBuilder::new()
241    ///     .with_config_yolo_split_det(boxes_config, scores_config)
242    ///     .build()?;
243    /// # Ok(())
244    /// # }
245    /// ```
246    pub fn with_config_yolo_split_det(
247        mut self,
248        boxes: configs::Boxes,
249        scores: configs::Scores,
250    ) -> Self {
251        let config = ConfigOutputs {
252            outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
253            ..Default::default()
254        };
255        self.config_src.replace(ConfigSource::Config(config));
256        self
257    }
258
259    /// Loads a YOLO segmentation model configuration.  Use
260    /// `DecoderBuilder.build()` to parse the model configuration.
261    ///
262    /// # Examples
263    /// ```rust
264    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
265    /// # fn main() -> DecoderResult<()> {
266    /// let seg_config = configs::Detection {
267    ///     decoder: configs::DecoderType::Ultralytics,
268    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
269    ///     shape: vec![1, 116, 8400],
270    ///     anchors: None,
271    ///     dshape: Vec::new(),
272    ///     normalized: Some(true),
273    /// };
274    /// let protos_config = configs::Protos {
275    ///     decoder: configs::DecoderType::Ultralytics,
276    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
277    ///     shape: vec![1, 160, 160, 32],
278    ///     dshape: Vec::new(),
279    /// };
280    /// let decoder = DecoderBuilder::new()
281    ///     .with_config_yolo_segdet(
282    ///         seg_config,
283    ///         protos_config,
284    ///         Some(configs::DecoderVersion::Yolov8),
285    ///     )
286    ///     .build()?;
287    /// # Ok(())
288    /// # }
289    /// ```
290    pub fn with_config_yolo_segdet(
291        mut self,
292        boxes: configs::Detection,
293        protos: configs::Protos,
294        version: Option<DecoderVersion>,
295    ) -> Self {
296        let config = ConfigOutputs {
297            outputs: vec![ConfigOutput::Detection(boxes), ConfigOutput::Protos(protos)],
298            decoder_version: version,
299            ..Default::default()
300        };
301        self.config_src.replace(ConfigSource::Config(config));
302        self
303    }
304
305    /// Loads a YOLO split segmentation model configuration.  Use
306    /// `DecoderBuilder.build()` to parse the model configuration.
307    ///
308    /// # Examples
309    /// ```rust
310    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
311    /// # fn main() -> DecoderResult<()> {
312    /// let boxes_config = configs::Boxes {
313    ///     decoder: configs::DecoderType::Ultralytics,
314    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
315    ///     shape: vec![1, 4, 8400],
316    ///     dshape: Vec::new(),
317    ///     normalized: Some(true),
318    /// };
319    /// let scores_config = configs::Scores {
320    ///     decoder: configs::DecoderType::Ultralytics,
321    ///     quantization: Some(configs::QuantTuple(0.012345, 14)),
322    ///     shape: vec![1, 80, 8400],
323    ///     dshape: Vec::new(),
324    /// };
325    /// let mask_config = configs::MaskCoefficients {
326    ///     decoder: configs::DecoderType::Ultralytics,
327    ///     quantization: Some(configs::QuantTuple(0.0064123, 125)),
328    ///     shape: vec![1, 32, 8400],
329    ///     dshape: Vec::new(),
330    /// };
331    /// let protos_config = configs::Protos {
332    ///     decoder: configs::DecoderType::Ultralytics,
333    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
334    ///     shape: vec![1, 160, 160, 32],
335    ///     dshape: Vec::new(),
336    /// };
337    /// let decoder = DecoderBuilder::new()
338    ///     .with_config_yolo_split_segdet(boxes_config, scores_config, mask_config, protos_config)
339    ///     .build()?;
340    /// # Ok(())
341    /// # }
342    /// ```
343    pub fn with_config_yolo_split_segdet(
344        mut self,
345        boxes: configs::Boxes,
346        scores: configs::Scores,
347        mask_coefficients: configs::MaskCoefficients,
348        protos: configs::Protos,
349    ) -> Self {
350        let config = ConfigOutputs {
351            outputs: vec![
352                ConfigOutput::Boxes(boxes),
353                ConfigOutput::Scores(scores),
354                ConfigOutput::MaskCoefficients(mask_coefficients),
355                ConfigOutput::Protos(protos),
356            ],
357            ..Default::default()
358        };
359        self.config_src.replace(ConfigSource::Config(config));
360        self
361    }
362
363    /// Loads a ModelPack detection model configuration.  Use
364    /// `DecoderBuilder.build()` to parse the model configuration.
365    ///
366    /// # Examples
367    /// ```rust
368    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
369    /// # fn main() -> DecoderResult<()> {
370    /// let boxes_config = configs::Boxes {
371    ///     decoder: configs::DecoderType::ModelPack,
372    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
373    ///     shape: vec![1, 8400, 1, 4],
374    ///     dshape: Vec::new(),
375    ///     normalized: Some(true),
376    /// };
377    /// let scores_config = configs::Scores {
378    ///     decoder: configs::DecoderType::ModelPack,
379    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
380    ///     shape: vec![1, 8400, 3],
381    ///     dshape: Vec::new(),
382    /// };
383    /// let decoder = DecoderBuilder::new()
384    ///     .with_config_modelpack_det(boxes_config, scores_config)
385    ///     .build()?;
386    /// # Ok(())
387    /// # }
388    /// ```
389    pub fn with_config_modelpack_det(
390        mut self,
391        boxes: configs::Boxes,
392        scores: configs::Scores,
393    ) -> Self {
394        let config = ConfigOutputs {
395            outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
396            ..Default::default()
397        };
398        self.config_src.replace(ConfigSource::Config(config));
399        self
400    }
401
402    /// Loads a ModelPack split detection model configuration. Use
403    /// `DecoderBuilder.build()` to parse the model configuration.
404    ///
405    /// # Examples
406    /// ```rust
407    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
408    /// # fn main() -> DecoderResult<()> {
409    /// let config0 = configs::Detection {
410    ///     anchors: Some(vec![
411    ///         [0.13750000298023224, 0.2074074000120163],
412    ///         [0.2541666626930237, 0.21481481194496155],
413    ///         [0.23125000298023224, 0.35185185074806213],
414    ///     ]),
415    ///     decoder: configs::DecoderType::ModelPack,
416    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
417    ///     shape: vec![1, 17, 30, 18],
418    ///     dshape: Vec::new(),
419    ///     normalized: Some(true),
420    /// };
421    /// let config1 = configs::Detection {
422    ///     anchors: Some(vec![
423    ///         [0.36666667461395264, 0.31481480598449707],
424    ///         [0.38749998807907104, 0.4740740656852722],
425    ///         [0.5333333611488342, 0.644444465637207],
426    ///     ]),
427    ///     decoder: configs::DecoderType::ModelPack,
428    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
429    ///     shape: vec![1, 9, 15, 18],
430    ///     dshape: Vec::new(),
431    ///     normalized: Some(true),
432    /// };
433    ///
434    /// let decoder = DecoderBuilder::new()
435    ///     .with_config_modelpack_det_split(vec![config0, config1])
436    ///     .build()?;
437    /// # Ok(())
438    /// # }
439    /// ```
440    pub fn with_config_modelpack_det_split(mut self, boxes: Vec<configs::Detection>) -> Self {
441        let outputs = boxes.into_iter().map(ConfigOutput::Detection).collect();
442        let config = ConfigOutputs {
443            outputs,
444            ..Default::default()
445        };
446        self.config_src.replace(ConfigSource::Config(config));
447        self
448    }
449
450    /// Loads a ModelPack segmentation detection model configuration. Use
451    /// `DecoderBuilder.build()` to parse the model configuration.
452    ///
453    /// # Examples
454    /// ```rust
455    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
456    /// # fn main() -> DecoderResult<()> {
457    /// let boxes_config = configs::Boxes {
458    ///     decoder: configs::DecoderType::ModelPack,
459    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
460    ///     shape: vec![1, 8400, 1, 4],
461    ///     dshape: Vec::new(),
462    ///     normalized: Some(true),
463    /// };
464    /// let scores_config = configs::Scores {
465    ///     decoder: configs::DecoderType::ModelPack,
466    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
467    ///     shape: vec![1, 8400, 2],
468    ///     dshape: Vec::new(),
469    /// };
470    /// let seg_config = configs::Segmentation {
471    ///     decoder: configs::DecoderType::ModelPack,
472    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
473    ///     shape: vec![1, 640, 640, 3],
474    ///     dshape: Vec::new(),
475    /// };
476    /// let decoder = DecoderBuilder::new()
477    ///     .with_config_modelpack_segdet(boxes_config, scores_config, seg_config)
478    ///     .build()?;
479    /// # Ok(())
480    /// # }
481    /// ```
482    pub fn with_config_modelpack_segdet(
483        mut self,
484        boxes: configs::Boxes,
485        scores: configs::Scores,
486        segmentation: configs::Segmentation,
487    ) -> Self {
488        let config = ConfigOutputs {
489            outputs: vec![
490                ConfigOutput::Boxes(boxes),
491                ConfigOutput::Scores(scores),
492                ConfigOutput::Segmentation(segmentation),
493            ],
494            ..Default::default()
495        };
496        self.config_src.replace(ConfigSource::Config(config));
497        self
498    }
499
500    /// Loads a ModelPack segmentation split detection model configuration. Use
501    /// `DecoderBuilder.build()` to parse the model configuration.
502    ///
503    /// # Examples
504    /// ```rust
505    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
506    /// # fn main() -> DecoderResult<()> {
507    /// let config0 = configs::Detection {
508    ///     anchors: Some(vec![
509    ///         [0.36666667461395264, 0.31481480598449707],
510    ///         [0.38749998807907104, 0.4740740656852722],
511    ///         [0.5333333611488342, 0.644444465637207],
512    ///     ]),
513    ///     decoder: configs::DecoderType::ModelPack,
514    ///     quantization: Some(configs::QuantTuple(0.08547406643629074, 174)),
515    ///     shape: vec![1, 9, 15, 18],
516    ///     dshape: Vec::new(),
517    ///     normalized: Some(true),
518    /// };
519    /// let config1 = configs::Detection {
520    ///     anchors: Some(vec![
521    ///         [0.13750000298023224, 0.2074074000120163],
522    ///         [0.2541666626930237, 0.21481481194496155],
523    ///         [0.23125000298023224, 0.35185185074806213],
524    ///     ]),
525    ///     decoder: configs::DecoderType::ModelPack,
526    ///     quantization: Some(configs::QuantTuple(0.09929127991199493, 183)),
527    ///     shape: vec![1, 17, 30, 18],
528    ///     dshape: Vec::new(),
529    ///     normalized: Some(true),
530    /// };
531    /// let seg_config = configs::Segmentation {
532    ///     decoder: configs::DecoderType::ModelPack,
533    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
534    ///     shape: vec![1, 640, 640, 2],
535    ///     dshape: Vec::new(),
536    /// };
537    /// let decoder = DecoderBuilder::new()
538    ///     .with_config_modelpack_segdet_split(vec![config0, config1], seg_config)
539    ///     .build()?;
540    /// # Ok(())
541    /// # }
542    /// ```
543    pub fn with_config_modelpack_segdet_split(
544        mut self,
545        boxes: Vec<configs::Detection>,
546        segmentation: configs::Segmentation,
547    ) -> Self {
548        let mut outputs = boxes
549            .into_iter()
550            .map(ConfigOutput::Detection)
551            .collect::<Vec<_>>();
552        outputs.push(ConfigOutput::Segmentation(segmentation));
553        let config = ConfigOutputs {
554            outputs,
555            ..Default::default()
556        };
557        self.config_src.replace(ConfigSource::Config(config));
558        self
559    }
560
561    /// Loads a ModelPack segmentation model configuration. Use
562    /// `DecoderBuilder.build()` to parse the model configuration.
563    ///
564    /// # Examples
565    /// ```rust
566    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
567    /// # fn main() -> DecoderResult<()> {
568    /// let seg_config = configs::Segmentation {
569    ///     decoder: configs::DecoderType::ModelPack,
570    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
571    ///     shape: vec![1, 640, 640, 3],
572    ///     dshape: Vec::new(),
573    /// };
574    /// let decoder = DecoderBuilder::new()
575    ///     .with_config_modelpack_seg(seg_config)
576    ///     .build()?;
577    /// # Ok(())
578    /// # }
579    /// ```
580    pub fn with_config_modelpack_seg(mut self, segmentation: configs::Segmentation) -> Self {
581        let config = ConfigOutputs {
582            outputs: vec![ConfigOutput::Segmentation(segmentation)],
583            ..Default::default()
584        };
585        self.config_src.replace(ConfigSource::Config(config));
586        self
587    }
588
589    /// Add an output to the decoder configuration.
590    ///
591    /// Incrementally builds the model configuration by adding outputs one at
592    /// a time. The decoder resolves the model type from the combination of
593    /// outputs during `build()`.
594    ///
595    /// If `dshape` is non-empty on the output, `shape` is automatically
596    /// derived from it (the size component of each named dimension). This
597    /// prevents conflicts between `shape` and `dshape`.
598    ///
599    /// This uses the programmatic config path. Calling this after
600    /// `with_config_json_str()` or `with_config_yaml_str()` replaces the
601    /// string-based config source.
602    ///
603    /// # Examples
604    /// ```rust
605    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
606    /// # fn main() -> DecoderResult<()> {
607    /// let decoder = DecoderBuilder::new()
608    ///     .add_output(ConfigOutput::Scores(configs::Scores {
609    ///         decoder: configs::DecoderType::Ultralytics,
610    ///         dshape: vec![
611    ///             (configs::DimName::Batch, 1),
612    ///             (configs::DimName::NumClasses, 80),
613    ///             (configs::DimName::NumBoxes, 8400),
614    ///         ],
615    ///         ..Default::default()
616    ///     }))
617    ///     .add_output(ConfigOutput::Boxes(configs::Boxes {
618    ///         decoder: configs::DecoderType::Ultralytics,
619    ///         dshape: vec![
620    ///             (configs::DimName::Batch, 1),
621    ///             (configs::DimName::BoxCoords, 4),
622    ///             (configs::DimName::NumBoxes, 8400),
623    ///         ],
624    ///         ..Default::default()
625    ///     }))
626    ///     .build()?;
627    /// # Ok(())
628    /// # }
629    /// ```
630    pub fn add_output(mut self, output: ConfigOutput) -> Self {
631        if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
632            self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
633        }
634        if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
635            config.outputs.push(Self::normalize_output(output));
636        }
637        self
638    }
639
640    /// Sets the decoder version for Ultralytics models.
641    ///
642    /// This is used with `add_output()` to specify the YOLO architecture
643    /// version when it cannot be inferred from the output shapes alone.
644    ///
645    /// - `Yolov5`, `Yolov8`, `Yolo11`: Traditional models requiring external
646    ///   NMS
647    /// - `Yolo26`: End-to-end models with NMS embedded in the model graph
648    ///
649    /// # Examples
650    /// ```rust
651    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
652    /// # fn main() -> DecoderResult<()> {
653    /// let decoder = DecoderBuilder::new()
654    ///     .add_output(ConfigOutput::Detection(configs::Detection {
655    ///         decoder: configs::DecoderType::Ultralytics,
656    ///         dshape: vec![
657    ///             (configs::DimName::Batch, 1),
658    ///             (configs::DimName::NumBoxes, 100),
659    ///             (configs::DimName::NumFeatures, 6),
660    ///         ],
661    ///         ..Default::default()
662    ///     }))
663    ///     .with_decoder_version(configs::DecoderVersion::Yolo26)
664    ///     .build()?;
665    /// # Ok(())
666    /// # }
667    /// ```
668    pub fn with_decoder_version(mut self, version: configs::DecoderVersion) -> Self {
669        if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
670            self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
671        }
672        if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
673            config.decoder_version = Some(version);
674        }
675        self
676    }
677
678    /// Normalize an output: if dshape is non-empty, derive shape from it.
679    fn normalize_output(mut output: ConfigOutput) -> ConfigOutput {
680        fn normalize_shape(shape: &mut Vec<usize>, dshape: &[(configs::DimName, usize)]) {
681            if !dshape.is_empty() {
682                *shape = dshape.iter().map(|(_, size)| *size).collect();
683            }
684        }
685        match &mut output {
686            ConfigOutput::Detection(c) => normalize_shape(&mut c.shape, &c.dshape),
687            ConfigOutput::Boxes(c) => normalize_shape(&mut c.shape, &c.dshape),
688            ConfigOutput::Scores(c) => normalize_shape(&mut c.shape, &c.dshape),
689            ConfigOutput::Protos(c) => normalize_shape(&mut c.shape, &c.dshape),
690            ConfigOutput::Segmentation(c) => normalize_shape(&mut c.shape, &c.dshape),
691            ConfigOutput::MaskCoefficients(c) => normalize_shape(&mut c.shape, &c.dshape),
692            ConfigOutput::Mask(c) => normalize_shape(&mut c.shape, &c.dshape),
693            ConfigOutput::Classes(c) => normalize_shape(&mut c.shape, &c.dshape),
694        }
695        output
696    }
697
698    /// Sets the scores threshold of the decoder
699    ///
700    /// # Examples
701    /// ```rust
702    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
703    /// # fn main() -> DecoderResult<()> {
704    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
705    /// let decoder = DecoderBuilder::new()
706    ///     .with_config_json_str(config_json)
707    ///     .with_score_threshold(0.654)
708    ///     .build()?;
709    /// assert_eq!(decoder.score_threshold, 0.654);
710    /// # Ok(())
711    /// # }
712    /// ```
713    pub fn with_score_threshold(mut self, score_threshold: f32) -> Self {
714        self.score_threshold = score_threshold;
715        self
716    }
717
718    /// Sets the IOU threshold of the decoder. Has no effect when NMS is set to
719    /// `None`
720    ///
721    /// # Examples
722    /// ```rust
723    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
724    /// # fn main() -> DecoderResult<()> {
725    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
726    /// let decoder = DecoderBuilder::new()
727    ///     .with_config_json_str(config_json)
728    ///     .with_iou_threshold(0.654)
729    ///     .build()?;
730    /// assert_eq!(decoder.iou_threshold, 0.654);
731    /// # Ok(())
732    /// # }
733    /// ```
734    pub fn with_iou_threshold(mut self, iou_threshold: f32) -> Self {
735        self.iou_threshold = iou_threshold;
736        self
737    }
738
739    /// Sets the NMS mode for the decoder.
740    ///
741    /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS (default): suppress
742    ///   overlapping boxes regardless of class label
743    /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
744    ///   share the same class label AND overlap above the IoU threshold
745    /// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
746    ///
747    /// # Examples
748    /// ```rust
749    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::Nms};
750    /// # fn main() -> DecoderResult<()> {
751    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
752    /// let decoder = DecoderBuilder::new()
753    ///     .with_config_json_str(config_json)
754    ///     .with_nms(Some(Nms::ClassAware))
755    ///     .build()?;
756    /// assert_eq!(decoder.nms, Some(Nms::ClassAware));
757    /// # Ok(())
758    /// # }
759    /// ```
760    pub fn with_nms(mut self, nms: Option<configs::Nms>) -> Self {
761        self.nms = nms;
762        self
763    }
764
765    /// Builds the decoder with the given settings. If the config is a JSON or
766    /// YAML string, this will deserialize the JSON or YAML and then parse the
767    /// configuration information.
768    ///
769    /// # Examples
770    /// ```rust
771    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
772    /// # fn main() -> DecoderResult<()> {
773    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
774    /// let decoder = DecoderBuilder::new()
775    ///     .with_config_json_str(config_json)
776    ///     .with_score_threshold(0.654)
777    ///     .build()?;
778    /// # Ok(())
779    /// # }
780    /// ```
781    pub fn build(self) -> Result<Decoder, DecoderError> {
782        let (config, decode_program) = match self.config_src {
783            Some(ConfigSource::Json(s)) => Self::build_from_schema(SchemaV2::parse_json(&s)?)?,
784            Some(ConfigSource::Yaml(s)) => Self::build_from_schema(SchemaV2::parse_yaml(&s)?)?,
785            Some(ConfigSource::Config(c)) => (c, None),
786            Some(ConfigSource::Schema(schema)) => Self::build_from_schema(schema)?,
787            None => return Err(DecoderError::NoConfig),
788        };
789
790        // Extract normalized flag from config outputs
791        let normalized = Self::get_normalized(&config.outputs);
792
793        // Use NMS from config if present, otherwise use builder's NMS setting
794        let nms = config.nms.or(self.nms);
795        let model_type = Self::get_model_type(config)?;
796
797        Ok(Decoder {
798            model_type,
799            iou_threshold: self.iou_threshold,
800            score_threshold: self.score_threshold,
801            nms,
802            normalized,
803            decode_program,
804        })
805    }
806
807    /// Validate a [`SchemaV2`] and lower it to the (legacy `ConfigOutputs`,
808    /// optional `DecodeProgram`) pair the rest of `build()` consumes.
809    ///
810    /// Centralises the v2 lowering so JSON, YAML, and direct
811    /// `with_schema` callers all go through the same validation and
812    /// merge-program construction. `SchemaV2::parse_json` /
813    /// `parse_yaml` already auto-detect v1 vs v2 input and return a v2
814    /// schema either way (v1 inputs are upgraded in memory via
815    /// `from_v1`), so this helper is the sole place that turns a
816    /// schema into builder-ready state.
817    fn build_from_schema(
818        schema: SchemaV2,
819    ) -> Result<(ConfigOutputs, Option<DecodeProgram>), DecoderError> {
820        schema.validate()?;
821        let program = DecodeProgram::try_from_schema(&schema)?;
822        let legacy = schema.to_legacy_config_outputs()?;
823        Ok((legacy, program))
824    }
825
826    /// Extracts the normalized flag from config outputs.
827    /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
828    /// - `Some(false)`: Boxes are in pixel coordinates
829    /// - `None`: Unknown (not specified in config), caller must infer
830    fn get_normalized(outputs: &[ConfigOutput]) -> Option<bool> {
831        for output in outputs {
832            match output {
833                ConfigOutput::Detection(det) => return det.normalized,
834                ConfigOutput::Boxes(boxes) => return boxes.normalized,
835                _ => {}
836            }
837        }
838        None // not specified
839    }
840
841    fn get_model_type(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
842        // yolo or modelpack
843        let mut yolo = false;
844        let mut modelpack = false;
845        for c in &configs.outputs {
846            match c.decoder() {
847                DecoderType::ModelPack => modelpack = true,
848                DecoderType::Ultralytics => yolo = true,
849            }
850        }
851        match (modelpack, yolo) {
852            (true, true) => Err(DecoderError::InvalidConfig(
853                "Both ModelPack and Yolo outputs found in config".to_string(),
854            )),
855            (true, false) => Self::get_model_type_modelpack(configs),
856            (false, true) => Self::get_model_type_yolo(configs),
857            (false, false) => Err(DecoderError::InvalidConfig(
858                "No outputs found in config".to_string(),
859            )),
860        }
861    }
862
863    fn get_model_type_yolo(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
864        let mut boxes = None;
865        let mut protos = None;
866        let mut split_boxes = None;
867        let mut split_scores = None;
868        let mut split_mask_coeff = None;
869        let mut split_classes = None;
870        for c in configs.outputs {
871            match c {
872                ConfigOutput::Detection(detection) => boxes = Some(detection),
873                ConfigOutput::Segmentation(_) => {
874                    return Err(DecoderError::InvalidConfig(
875                        "Invalid Segmentation output with Yolo decoder".to_string(),
876                    ));
877                }
878                ConfigOutput::Protos(protos_) => protos = Some(protos_),
879                ConfigOutput::Mask(_) => {
880                    return Err(DecoderError::InvalidConfig(
881                        "Invalid Mask output with Yolo decoder".to_string(),
882                    ));
883                }
884                ConfigOutput::Scores(scores) => split_scores = Some(scores),
885                ConfigOutput::Boxes(boxes) => split_boxes = Some(boxes),
886                ConfigOutput::MaskCoefficients(mask_coeff) => split_mask_coeff = Some(mask_coeff),
887                ConfigOutput::Classes(classes) => split_classes = Some(classes),
888            }
889        }
890
891        // Use end-to-end model types when:
892        // 1. decoder_version is explicitly set to Yolo26 (definitive), OR
893        //    decoder_version is not set but the dshapes are (batch, num_boxes,
894        //    num_features)
895        let is_end_to_end_dshape = boxes.as_ref().is_some_and(|b| {
896            let dims = b.dshape.iter().map(|(d, _)| *d).collect::<Vec<_>>();
897            dims == vec![DimName::Batch, DimName::NumBoxes, DimName::NumFeatures]
898        });
899
900        let is_end_to_end = configs
901            .decoder_version
902            .map(|v| v.is_end_to_end())
903            .unwrap_or(is_end_to_end_dshape);
904
905        if is_end_to_end {
906            if let Some(boxes) = boxes {
907                if let Some(protos) = protos {
908                    Self::verify_yolo_seg_det_26(&boxes, &protos)?;
909                    return Ok(ModelType::YoloEndToEndSegDet { boxes, protos });
910                } else {
911                    Self::verify_yolo_det_26(&boxes)?;
912                    return Ok(ModelType::YoloEndToEndDet { boxes });
913                }
914            } else if let (Some(split_boxes), Some(split_scores), Some(split_classes)) =
915                (split_boxes, split_scores, split_classes)
916            {
917                if let (Some(split_mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
918                    Self::verify_yolo_split_end_to_end_segdet(
919                        &split_boxes,
920                        &split_scores,
921                        &split_classes,
922                        &split_mask_coeff,
923                        &protos,
924                    )?;
925                    return Ok(ModelType::YoloSplitEndToEndSegDet {
926                        boxes: split_boxes,
927                        scores: split_scores,
928                        classes: split_classes,
929                        mask_coeff: split_mask_coeff,
930                        protos,
931                    });
932                }
933                Self::verify_yolo_split_end_to_end_det(
934                    &split_boxes,
935                    &split_scores,
936                    &split_classes,
937                )?;
938                return Ok(ModelType::YoloSplitEndToEndDet {
939                    boxes: split_boxes,
940                    scores: split_scores,
941                    classes: split_classes,
942                });
943            } else {
944                return Err(DecoderError::InvalidConfig(
945                    "Invalid Yolo end-to-end model outputs".to_string(),
946                ));
947            }
948        }
949
950        if let Some(boxes) = boxes {
951            match (split_mask_coeff, protos) {
952                (Some(mask_coeff), Some(protos)) => {
953                    // 2-way split: combined detection + separate mask_coeff + protos
954                    Self::verify_yolo_seg_det_2way(&boxes, &mask_coeff, &protos)?;
955                    Ok(ModelType::YoloSegDet2Way {
956                        boxes,
957                        mask_coeff,
958                        protos,
959                    })
960                }
961                (_, Some(protos)) => {
962                    // Unsplit: mask_coefs embedded in detection tensor
963                    Self::verify_yolo_seg_det(&boxes, &protos)?;
964                    Ok(ModelType::YoloSegDet { boxes, protos })
965                }
966                _ => {
967                    Self::verify_yolo_det(&boxes)?;
968                    Ok(ModelType::YoloDet { boxes })
969                }
970            }
971        } else if let (Some(boxes), Some(scores)) = (split_boxes, split_scores) {
972            if let (Some(mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
973                Self::verify_yolo_split_segdet(&boxes, &scores, &mask_coeff, &protos)?;
974                Ok(ModelType::YoloSplitSegDet {
975                    boxes,
976                    scores,
977                    mask_coeff,
978                    protos,
979                })
980            } else {
981                Self::verify_yolo_split_det(&boxes, &scores)?;
982                Ok(ModelType::YoloSplitDet { boxes, scores })
983            }
984        } else {
985            Err(DecoderError::InvalidConfig(
986                "Invalid Yolo model outputs".to_string(),
987            ))
988        }
989    }
990
991    fn verify_yolo_det(detect: &configs::Detection) -> Result<(), DecoderError> {
992        if detect.shape.len() != 3 {
993            return Err(DecoderError::InvalidConfig(format!(
994                "Invalid Yolo Detection shape {:?}",
995                detect.shape
996            )));
997        }
998
999        Self::verify_dshapes(
1000            &detect.dshape,
1001            &detect.shape,
1002            "Detection",
1003            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1004        )?;
1005        if !detect.dshape.is_empty() {
1006            Self::get_class_count(&detect.dshape, None, None)?;
1007        } else {
1008            Self::get_class_count_no_dshape(detect.into(), None)?;
1009        }
1010
1011        Ok(())
1012    }
1013
1014    fn verify_yolo_det_26(detect: &configs::Detection) -> Result<(), DecoderError> {
1015        if detect.shape.len() != 3 {
1016            return Err(DecoderError::InvalidConfig(format!(
1017                "Invalid Yolo Detection shape {:?}",
1018                detect.shape
1019            )));
1020        }
1021
1022        Self::verify_dshapes(
1023            &detect.dshape,
1024            &detect.shape,
1025            "Detection",
1026            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1027        )?;
1028
1029        if !detect.shape.contains(&6) {
1030            return Err(DecoderError::InvalidConfig(
1031                "Yolo26 Detection must have 6 features".to_string(),
1032            ));
1033        }
1034
1035        Ok(())
1036    }
1037
1038    fn verify_yolo_seg_det(
1039        detection: &configs::Detection,
1040        protos: &configs::Protos,
1041    ) -> Result<(), DecoderError> {
1042        if detection.shape.len() != 3 {
1043            return Err(DecoderError::InvalidConfig(format!(
1044                "Invalid Yolo Detection shape {:?}",
1045                detection.shape
1046            )));
1047        }
1048        if protos.shape.len() != 4 {
1049            return Err(DecoderError::InvalidConfig(format!(
1050                "Invalid Yolo Protos shape {:?}",
1051                protos.shape
1052            )));
1053        }
1054
1055        Self::verify_dshapes(
1056            &detection.dshape,
1057            &detection.shape,
1058            "Detection",
1059            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1060        )?;
1061        Self::verify_dshapes(
1062            &protos.dshape,
1063            &protos.shape,
1064            "Protos",
1065            &[
1066                DimName::Batch,
1067                DimName::Height,
1068                DimName::Width,
1069                DimName::NumProtos,
1070            ],
1071        )?;
1072
1073        let protos_count = Self::get_protos_count(&protos.dshape)
1074            .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1075        log::debug!("Protos count: {}", protos_count);
1076        log::debug!("Detection dshape: {:?}", detection.dshape);
1077        let classes = if !detection.dshape.is_empty() {
1078            Self::get_class_count(&detection.dshape, Some(protos_count), None)?
1079        } else {
1080            Self::get_class_count_no_dshape(detection.into(), Some(protos_count))?
1081        };
1082
1083        if classes == 0 {
1084            return Err(DecoderError::InvalidConfig(
1085                "Yolo Segmentation Detection has zero classes".to_string(),
1086            ));
1087        }
1088
1089        Ok(())
1090    }
1091
1092    fn verify_yolo_seg_det_2way(
1093        detection: &configs::Detection,
1094        mask_coeff: &configs::MaskCoefficients,
1095        protos: &configs::Protos,
1096    ) -> Result<(), DecoderError> {
1097        if detection.shape.len() != 3 {
1098            return Err(DecoderError::InvalidConfig(format!(
1099                "Invalid Yolo 2-Way Detection shape {:?}",
1100                detection.shape
1101            )));
1102        }
1103        if mask_coeff.shape.len() != 3 {
1104            return Err(DecoderError::InvalidConfig(format!(
1105                "Invalid Yolo 2-Way Mask Coefficients shape {:?}",
1106                mask_coeff.shape
1107            )));
1108        }
1109        if protos.shape.len() != 4 {
1110            return Err(DecoderError::InvalidConfig(format!(
1111                "Invalid Yolo 2-Way Protos shape {:?}",
1112                protos.shape
1113            )));
1114        }
1115
1116        Self::verify_dshapes(
1117            &detection.dshape,
1118            &detection.shape,
1119            "Detection",
1120            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1121        )?;
1122        Self::verify_dshapes(
1123            &mask_coeff.dshape,
1124            &mask_coeff.shape,
1125            "Mask Coefficients",
1126            &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1127        )?;
1128        Self::verify_dshapes(
1129            &protos.dshape,
1130            &protos.shape,
1131            "Protos",
1132            &[
1133                DimName::Batch,
1134                DimName::Height,
1135                DimName::Width,
1136                DimName::NumProtos,
1137            ],
1138        )?;
1139
1140        // Validate num_boxes match between detection and mask_coeff
1141        let det_num = Self::get_box_count(&detection.dshape).unwrap_or(detection.shape[2]);
1142        let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1143        if det_num != mask_num {
1144            return Err(DecoderError::InvalidConfig(format!(
1145                "Yolo 2-Way Detection num_boxes {} incompatible with Mask Coefficients num_boxes {}",
1146                det_num, mask_num
1147            )));
1148        }
1149
1150        // Validate mask_coeff channels match protos channels
1151        let mask_channels = if !mask_coeff.dshape.is_empty() {
1152            Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1153                DecoderError::InvalidConfig(
1154                    "Could not find num_protos in mask_coeff config".to_string(),
1155                )
1156            })?
1157        } else {
1158            mask_coeff.shape[1]
1159        };
1160        let proto_channels = if !protos.dshape.is_empty() {
1161            Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1162                DecoderError::InvalidConfig(
1163                    "Could not find num_protos in protos config".to_string(),
1164                )
1165            })?
1166        } else {
1167            protos.shape[1].min(protos.shape[3])
1168        };
1169        if mask_channels != proto_channels {
1170            return Err(DecoderError::InvalidConfig(format!(
1171                "Yolo 2-Way Protos channels {} incompatible with Mask Coefficients channels {}",
1172                proto_channels, mask_channels
1173            )));
1174        }
1175
1176        // Validate detection has classes (nc+4 features, no mask_coefs embedded)
1177        if !detection.dshape.is_empty() {
1178            Self::get_class_count(&detection.dshape, None, None)?;
1179        } else {
1180            Self::get_class_count_no_dshape(detection.into(), None)?;
1181        }
1182
1183        Ok(())
1184    }
1185
1186    fn verify_yolo_seg_det_26(
1187        detection: &configs::Detection,
1188        protos: &configs::Protos,
1189    ) -> Result<(), DecoderError> {
1190        if detection.shape.len() != 3 {
1191            return Err(DecoderError::InvalidConfig(format!(
1192                "Invalid Yolo Detection shape {:?}",
1193                detection.shape
1194            )));
1195        }
1196        if protos.shape.len() != 4 {
1197            return Err(DecoderError::InvalidConfig(format!(
1198                "Invalid Yolo Protos shape {:?}",
1199                protos.shape
1200            )));
1201        }
1202
1203        Self::verify_dshapes(
1204            &detection.dshape,
1205            &detection.shape,
1206            "Detection",
1207            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1208        )?;
1209        Self::verify_dshapes(
1210            &protos.dshape,
1211            &protos.shape,
1212            "Protos",
1213            &[
1214                DimName::Batch,
1215                DimName::Height,
1216                DimName::Width,
1217                DimName::NumProtos,
1218            ],
1219        )?;
1220
1221        let protos_count = Self::get_protos_count(&protos.dshape)
1222            .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1223        log::debug!("Protos count: {}", protos_count);
1224        log::debug!("Detection dshape: {:?}", detection.dshape);
1225
1226        if !detection.shape.contains(&(6 + protos_count)) {
1227            return Err(DecoderError::InvalidConfig(format!(
1228                "Yolo26 Segmentation Detection must have num_features be 6 + num_protos = {}",
1229                6 + protos_count
1230            )));
1231        }
1232
1233        Ok(())
1234    }
1235
1236    fn verify_yolo_split_det(
1237        boxes: &configs::Boxes,
1238        scores: &configs::Scores,
1239    ) -> Result<(), DecoderError> {
1240        if boxes.shape.len() != 3 {
1241            return Err(DecoderError::InvalidConfig(format!(
1242                "Invalid Yolo Split Boxes shape {:?}",
1243                boxes.shape
1244            )));
1245        }
1246        if scores.shape.len() != 3 {
1247            return Err(DecoderError::InvalidConfig(format!(
1248                "Invalid Yolo Split Scores shape {:?}",
1249                scores.shape
1250            )));
1251        }
1252
1253        Self::verify_dshapes(
1254            &boxes.dshape,
1255            &boxes.shape,
1256            "Boxes",
1257            &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1258        )?;
1259        Self::verify_dshapes(
1260            &scores.dshape,
1261            &scores.shape,
1262            "Scores",
1263            &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1264        )?;
1265
1266        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1267        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1268
1269        if boxes_num != scores_num {
1270            return Err(DecoderError::InvalidConfig(format!(
1271                "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1272                boxes_num, scores_num
1273            )));
1274        }
1275
1276        Ok(())
1277    }
1278
1279    fn verify_yolo_split_segdet(
1280        boxes: &configs::Boxes,
1281        scores: &configs::Scores,
1282        mask_coeff: &configs::MaskCoefficients,
1283        protos: &configs::Protos,
1284    ) -> Result<(), DecoderError> {
1285        if boxes.shape.len() != 3 {
1286            return Err(DecoderError::InvalidConfig(format!(
1287                "Invalid Yolo Split Boxes shape {:?}",
1288                boxes.shape
1289            )));
1290        }
1291        if scores.shape.len() != 3 {
1292            return Err(DecoderError::InvalidConfig(format!(
1293                "Invalid Yolo Split Scores shape {:?}",
1294                scores.shape
1295            )));
1296        }
1297
1298        if mask_coeff.shape.len() != 3 {
1299            return Err(DecoderError::InvalidConfig(format!(
1300                "Invalid Yolo Split Mask Coefficients shape {:?}",
1301                mask_coeff.shape
1302            )));
1303        }
1304
1305        if protos.shape.len() != 4 {
1306            return Err(DecoderError::InvalidConfig(format!(
1307                "Invalid Yolo Protos shape {:?}",
1308                mask_coeff.shape
1309            )));
1310        }
1311
1312        Self::verify_dshapes(
1313            &boxes.dshape,
1314            &boxes.shape,
1315            "Boxes",
1316            &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1317        )?;
1318        Self::verify_dshapes(
1319            &scores.dshape,
1320            &scores.shape,
1321            "Scores",
1322            &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1323        )?;
1324        Self::verify_dshapes(
1325            &mask_coeff.dshape,
1326            &mask_coeff.shape,
1327            "Mask Coefficients",
1328            &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1329        )?;
1330        Self::verify_dshapes(
1331            &protos.dshape,
1332            &protos.shape,
1333            "Protos",
1334            &[
1335                DimName::Batch,
1336                DimName::Height,
1337                DimName::Width,
1338                DimName::NumProtos,
1339            ],
1340        )?;
1341
1342        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1343        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1344        let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1345
1346        let mask_channels = if !mask_coeff.dshape.is_empty() {
1347            Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1348                DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1349            })?
1350        } else {
1351            mask_coeff.shape[1]
1352        };
1353        let proto_channels = if !protos.dshape.is_empty() {
1354            Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1355                DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1356            })?
1357        } else {
1358            protos.shape[1].min(protos.shape[3])
1359        };
1360
1361        if boxes_num != scores_num {
1362            return Err(DecoderError::InvalidConfig(format!(
1363                "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1364                boxes_num, scores_num
1365            )));
1366        }
1367
1368        if boxes_num != mask_num {
1369            return Err(DecoderError::InvalidConfig(format!(
1370                "Yolo Split Detection Boxes num {} incompatible with Mask Coefficients num {}",
1371                boxes_num, mask_num
1372            )));
1373        }
1374
1375        if proto_channels != mask_channels {
1376            return Err(DecoderError::InvalidConfig(format!(
1377                "Yolo Protos channels {} incompatible with Mask Coefficients channels {}",
1378                proto_channels, mask_channels
1379            )));
1380        }
1381
1382        Ok(())
1383    }
1384
1385    fn verify_yolo_split_end_to_end_det(
1386        boxes: &configs::Boxes,
1387        scores: &configs::Scores,
1388        classes: &configs::Classes,
1389    ) -> Result<(), DecoderError> {
1390        if boxes.shape.len() != 3 || !boxes.shape.contains(&4) {
1391            return Err(DecoderError::InvalidConfig(format!(
1392                "Split end-to-end boxes must be [batch, N, 4], got {:?}",
1393                boxes.shape
1394            )));
1395        }
1396        if scores.shape.len() != 3 || !scores.shape.contains(&1) {
1397            return Err(DecoderError::InvalidConfig(format!(
1398                "Split end-to-end scores must be [batch, N, 1], got {:?}",
1399                scores.shape
1400            )));
1401        }
1402        if classes.shape.len() != 3 || !classes.shape.contains(&1) {
1403            return Err(DecoderError::InvalidConfig(format!(
1404                "Split end-to-end classes must be [batch, N, 1], got {:?}",
1405                classes.shape
1406            )));
1407        }
1408        Ok(())
1409    }
1410
1411    fn verify_yolo_split_end_to_end_segdet(
1412        boxes: &configs::Boxes,
1413        scores: &configs::Scores,
1414        classes: &configs::Classes,
1415        mask_coeff: &configs::MaskCoefficients,
1416        protos: &configs::Protos,
1417    ) -> Result<(), DecoderError> {
1418        Self::verify_yolo_split_end_to_end_det(boxes, scores, classes)?;
1419        if mask_coeff.shape.len() != 3 {
1420            return Err(DecoderError::InvalidConfig(format!(
1421                "Invalid split end-to-end mask coefficients shape {:?}",
1422                mask_coeff.shape
1423            )));
1424        }
1425        if protos.shape.len() != 4 {
1426            return Err(DecoderError::InvalidConfig(format!(
1427                "Invalid protos shape {:?}",
1428                protos.shape
1429            )));
1430        }
1431        Ok(())
1432    }
1433
1434    fn get_model_type_modelpack(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1435        let mut split_decoders = Vec::new();
1436        let mut segment_ = None;
1437        let mut scores_ = None;
1438        let mut boxes_ = None;
1439        for c in configs.outputs {
1440            match c {
1441                ConfigOutput::Detection(detection) => split_decoders.push(detection),
1442                ConfigOutput::Segmentation(segmentation) => segment_ = Some(segmentation),
1443                ConfigOutput::Mask(_) => {}
1444                ConfigOutput::Protos(_) => {
1445                    return Err(DecoderError::InvalidConfig(
1446                        "ModelPack should not have protos".to_string(),
1447                    ));
1448                }
1449                ConfigOutput::Scores(scores) => scores_ = Some(scores),
1450                ConfigOutput::Boxes(boxes) => boxes_ = Some(boxes),
1451                ConfigOutput::MaskCoefficients(_) => {
1452                    return Err(DecoderError::InvalidConfig(
1453                        "ModelPack should not have mask coefficients".to_string(),
1454                    ));
1455                }
1456                ConfigOutput::Classes(_) => {
1457                    return Err(DecoderError::InvalidConfig(
1458                        "ModelPack should not have classes output".to_string(),
1459                    ));
1460                }
1461            }
1462        }
1463
1464        if let Some(segmentation) = segment_ {
1465            if !split_decoders.is_empty() {
1466                let classes = Self::verify_modelpack_split_det(&split_decoders)?;
1467                Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1468                Ok(ModelType::ModelPackSegDetSplit {
1469                    detection: split_decoders,
1470                    segmentation,
1471                })
1472            } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1473                let classes = Self::verify_modelpack_det(&boxes, &scores)?;
1474                Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1475                Ok(ModelType::ModelPackSegDet {
1476                    boxes,
1477                    scores,
1478                    segmentation,
1479                })
1480            } else {
1481                Self::verify_modelpack_seg(&segmentation, None)?;
1482                Ok(ModelType::ModelPackSeg { segmentation })
1483            }
1484        } else if !split_decoders.is_empty() {
1485            Self::verify_modelpack_split_det(&split_decoders)?;
1486            Ok(ModelType::ModelPackDetSplit {
1487                detection: split_decoders,
1488            })
1489        } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1490            Self::verify_modelpack_det(&boxes, &scores)?;
1491            Ok(ModelType::ModelPackDet { boxes, scores })
1492        } else {
1493            Err(DecoderError::InvalidConfig(
1494                "Invalid ModelPack model outputs".to_string(),
1495            ))
1496        }
1497    }
1498
1499    fn verify_modelpack_det(
1500        boxes: &configs::Boxes,
1501        scores: &configs::Scores,
1502    ) -> Result<usize, DecoderError> {
1503        if boxes.shape.len() != 4 {
1504            return Err(DecoderError::InvalidConfig(format!(
1505                "Invalid ModelPack Boxes shape {:?}",
1506                boxes.shape
1507            )));
1508        }
1509        if scores.shape.len() != 3 {
1510            return Err(DecoderError::InvalidConfig(format!(
1511                "Invalid ModelPack Scores shape {:?}",
1512                scores.shape
1513            )));
1514        }
1515
1516        Self::verify_dshapes(
1517            &boxes.dshape,
1518            &boxes.shape,
1519            "Boxes",
1520            &[
1521                DimName::Batch,
1522                DimName::NumBoxes,
1523                DimName::Padding,
1524                DimName::BoxCoords,
1525            ],
1526        )?;
1527        Self::verify_dshapes(
1528            &scores.dshape,
1529            &scores.shape,
1530            "Scores",
1531            &[DimName::Batch, DimName::NumBoxes, DimName::NumClasses],
1532        )?;
1533
1534        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[1]);
1535        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[1]);
1536
1537        if boxes_num != scores_num {
1538            return Err(DecoderError::InvalidConfig(format!(
1539                "ModelPack Detection Boxes num {} incompatible with Scores num {}",
1540                boxes_num, scores_num
1541            )));
1542        }
1543
1544        let num_classes = if !scores.dshape.is_empty() {
1545            Self::get_class_count(&scores.dshape, None, None)?
1546        } else {
1547            Self::get_class_count_no_dshape(scores.into(), None)?
1548        };
1549
1550        Ok(num_classes)
1551    }
1552
1553    fn verify_modelpack_split_det(boxes: &[configs::Detection]) -> Result<usize, DecoderError> {
1554        let mut num_classes = None;
1555        for b in boxes {
1556            let Some(num_anchors) = b.anchors.as_ref().map(|a| a.len()) else {
1557                return Err(DecoderError::InvalidConfig(
1558                    "ModelPack Split Detection missing anchors".to_string(),
1559                ));
1560            };
1561
1562            if num_anchors == 0 {
1563                return Err(DecoderError::InvalidConfig(
1564                    "ModelPack Split Detection has zero anchors".to_string(),
1565                ));
1566            }
1567
1568            if b.shape.len() != 4 {
1569                return Err(DecoderError::InvalidConfig(format!(
1570                    "Invalid ModelPack Split Detection shape {:?}",
1571                    b.shape
1572                )));
1573            }
1574
1575            Self::verify_dshapes(
1576                &b.dshape,
1577                &b.shape,
1578                "Split Detection",
1579                &[
1580                    DimName::Batch,
1581                    DimName::Height,
1582                    DimName::Width,
1583                    DimName::NumAnchorsXFeatures,
1584                ],
1585            )?;
1586            let classes = if !b.dshape.is_empty() {
1587                Self::get_class_count(&b.dshape, None, Some(num_anchors))?
1588            } else {
1589                Self::get_class_count_no_dshape(b.into(), None)?
1590            };
1591
1592            match num_classes {
1593                Some(n) => {
1594                    if n != classes {
1595                        return Err(DecoderError::InvalidConfig(format!(
1596                            "ModelPack Split Detection inconsistent number of classes: previous {}, current {}",
1597                            n, classes
1598                        )));
1599                    }
1600                }
1601                None => {
1602                    num_classes = Some(classes);
1603                }
1604            }
1605        }
1606
1607        Ok(num_classes.unwrap_or(0))
1608    }
1609
1610    fn verify_modelpack_seg(
1611        segmentation: &configs::Segmentation,
1612        classes: Option<usize>,
1613    ) -> Result<(), DecoderError> {
1614        if segmentation.shape.len() != 4 {
1615            return Err(DecoderError::InvalidConfig(format!(
1616                "Invalid ModelPack Segmentation shape {:?}",
1617                segmentation.shape
1618            )));
1619        }
1620        Self::verify_dshapes(
1621            &segmentation.dshape,
1622            &segmentation.shape,
1623            "Segmentation",
1624            &[
1625                DimName::Batch,
1626                DimName::Height,
1627                DimName::Width,
1628                DimName::NumClasses,
1629            ],
1630        )?;
1631
1632        if let Some(classes) = classes {
1633            let seg_classes = if !segmentation.dshape.is_empty() {
1634                Self::get_class_count(&segmentation.dshape, None, None)?
1635            } else {
1636                Self::get_class_count_no_dshape(segmentation.into(), None)?
1637            };
1638
1639            if seg_classes != classes + 1 {
1640                return Err(DecoderError::InvalidConfig(format!(
1641                    "ModelPack Segmentation channels {} incompatible with number of classes {}",
1642                    seg_classes, classes
1643                )));
1644            }
1645        }
1646        Ok(())
1647    }
1648
1649    // verifies that dshapes match the given shape
1650    fn verify_dshapes(
1651        dshape: &[(DimName, usize)],
1652        shape: &[usize],
1653        name: &str,
1654        dims: &[DimName],
1655    ) -> Result<(), DecoderError> {
1656        for s in shape {
1657            if *s == 0 {
1658                return Err(DecoderError::InvalidConfig(format!(
1659                    "{} shape has zero dimension",
1660                    name
1661                )));
1662            }
1663        }
1664
1665        if shape.len() != dims.len() {
1666            return Err(DecoderError::InvalidConfig(format!(
1667                "{} shape length {} does not match expected dims length {}",
1668                name,
1669                shape.len(),
1670                dims.len()
1671            )));
1672        }
1673
1674        if dshape.is_empty() {
1675            return Ok(());
1676        }
1677        // check the dshape lengths match the shape lengths
1678        if dshape.len() != shape.len() {
1679            return Err(DecoderError::InvalidConfig(format!(
1680                "{} dshape length does not match shape length",
1681                name
1682            )));
1683        }
1684
1685        // check the dshape values match the shape values
1686        for ((dim_name, dim_size), shape_size) in dshape.iter().zip(shape) {
1687            if dim_size != shape_size {
1688                return Err(DecoderError::InvalidConfig(format!(
1689                    "{} dshape dimension {} size {} does not match shape size {}",
1690                    name, dim_name, dim_size, shape_size
1691                )));
1692            }
1693            if *dim_name == DimName::Padding && *dim_size != 1 {
1694                return Err(DecoderError::InvalidConfig(
1695                    "Padding dimension size must be 1".to_string(),
1696                ));
1697            }
1698
1699            if *dim_name == DimName::BoxCoords && *dim_size != 4 {
1700                return Err(DecoderError::InvalidConfig(
1701                    "BoxCoords dimension size must be 4".to_string(),
1702                ));
1703            }
1704        }
1705
1706        let dims_present = HashSet::<DimName>::from_iter(dshape.iter().map(|(name, _)| *name));
1707        for dim in dims {
1708            if !dims_present.contains(dim) {
1709                return Err(DecoderError::InvalidConfig(format!(
1710                    "{} dshape missing required dimension {:?}",
1711                    name, dim
1712                )));
1713            }
1714        }
1715
1716        Ok(())
1717    }
1718
1719    fn get_box_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1720        for (dim_name, dim_size) in dshape {
1721            if *dim_name == DimName::NumBoxes {
1722                return Some(*dim_size);
1723            }
1724        }
1725        None
1726    }
1727
1728    fn get_class_count_no_dshape(
1729        config: ConfigOutputRef,
1730        protos: Option<usize>,
1731    ) -> Result<usize, DecoderError> {
1732        match config {
1733            ConfigOutputRef::Detection(detection) => match detection.decoder {
1734                DecoderType::Ultralytics => {
1735                    if detection.shape[1] <= 4 + protos.unwrap_or(0) {
1736                        return Err(DecoderError::InvalidConfig(format!(
1737                            "Invalid shape: Yolo num_features {} must be greater than {}",
1738                            detection.shape[1],
1739                            4 + protos.unwrap_or(0),
1740                        )));
1741                    }
1742                    Ok(detection.shape[1] - 4 - protos.unwrap_or(0))
1743                }
1744                DecoderType::ModelPack => {
1745                    let Some(num_anchors) = detection.anchors.as_ref().map(|a| a.len()) else {
1746                        return Err(DecoderError::Internal(
1747                            "ModelPack Detection missing anchors".to_string(),
1748                        ));
1749                    };
1750                    let anchors_x_features = detection.shape[3];
1751                    if anchors_x_features <= num_anchors * 5 {
1752                        return Err(DecoderError::InvalidConfig(format!(
1753                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
1754                            anchors_x_features,
1755                            num_anchors * 5,
1756                        )));
1757                    }
1758
1759                    if !anchors_x_features.is_multiple_of(num_anchors) {
1760                        return Err(DecoderError::InvalidConfig(format!(
1761                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
1762                            anchors_x_features, num_anchors
1763                        )));
1764                    }
1765                    Ok(anchors_x_features / num_anchors - 5)
1766                }
1767            },
1768
1769            ConfigOutputRef::Scores(scores) => match scores.decoder {
1770                DecoderType::Ultralytics => Ok(scores.shape[1]),
1771                DecoderType::ModelPack => Ok(scores.shape[2]),
1772            },
1773            ConfigOutputRef::Segmentation(seg) => Ok(seg.shape[3]),
1774            _ => Err(DecoderError::Internal(
1775                "Attempted to get class count from unsupported config output".to_owned(),
1776            )),
1777        }
1778    }
1779
1780    // get the class count from dshape or calculate from num_features
1781    fn get_class_count(
1782        dshape: &[(DimName, usize)],
1783        protos: Option<usize>,
1784        anchors: Option<usize>,
1785    ) -> Result<usize, DecoderError> {
1786        if dshape.is_empty() {
1787            return Ok(0);
1788        }
1789        // if it has num_classes in dshape, return it
1790        for (dim_name, dim_size) in dshape {
1791            if *dim_name == DimName::NumClasses {
1792                return Ok(*dim_size);
1793            }
1794        }
1795
1796        // number of classes can be calculated from num_features - 4 for yolo.  If the
1797        // model has protos, we also subtract the number of protos.
1798        for (dim_name, dim_size) in dshape {
1799            if *dim_name == DimName::NumFeatures {
1800                let protos = protos.unwrap_or(0);
1801                if protos + 4 >= *dim_size {
1802                    return Err(DecoderError::InvalidConfig(format!(
1803                        "Invalid shape: Yolo num_features {} must be greater than {}",
1804                        *dim_size,
1805                        protos + 4,
1806                    )));
1807                }
1808                return Ok(*dim_size - 4 - protos);
1809            }
1810        }
1811
1812        // number of classes can be calculated from number of anchors for modelpack
1813        // split detection
1814        if let Some(num_anchors) = anchors {
1815            for (dim_name, dim_size) in dshape {
1816                if *dim_name == DimName::NumAnchorsXFeatures {
1817                    let anchors_x_features = *dim_size;
1818                    if anchors_x_features <= num_anchors * 5 {
1819                        return Err(DecoderError::InvalidConfig(format!(
1820                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
1821                            anchors_x_features,
1822                            num_anchors * 5,
1823                        )));
1824                    }
1825
1826                    if !anchors_x_features.is_multiple_of(num_anchors) {
1827                        return Err(DecoderError::InvalidConfig(format!(
1828                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
1829                            anchors_x_features, num_anchors
1830                        )));
1831                    }
1832                    return Ok((anchors_x_features / num_anchors) - 5);
1833                }
1834            }
1835        }
1836        Err(DecoderError::InvalidConfig(
1837            "Cannot determine number of classes from dshape".to_owned(),
1838        ))
1839    }
1840
1841    fn get_protos_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1842        for (dim_name, dim_size) in dshape {
1843            if *dim_name == DimName::NumProtos {
1844                return Some(*dim_size);
1845            }
1846        }
1847        None
1848    }
1849}