Skip to main content

edgefirst_decoder/decoder/
builder.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashSet;
5
6use super::config::ConfigOutputRef;
7use super::configs::{self, DecoderType, DecoderVersion, DimName, ModelType};
8use super::merge::DecodeProgram;
9use super::{ConfigOutput, ConfigOutputs, Decoder};
10use crate::schema::SchemaV2;
11use crate::DecoderError;
12
13#[derive(Debug, Clone, PartialEq)]
14pub struct DecoderBuilder {
15    config_src: Option<ConfigSource>,
16    iou_threshold: f32,
17    score_threshold: f32,
18    /// NMS mode: Some(mode) applies NMS, None bypasses NMS (for end-to-end
19    /// models)
20    nms: Option<configs::Nms>,
21}
22
23#[derive(Debug, Clone, PartialEq)]
24enum ConfigSource {
25    Yaml(String),
26    Json(String),
27    Config(ConfigOutputs),
28    /// Schema v2 metadata. During build the schema is either converted
29    /// to a legacy [`ConfigOutputs`] (flat case) or used to construct a
30    /// [`DecodeProgram`] that performs per-child dequant + merge at
31    /// decode time.
32    Schema(SchemaV2),
33}
34
35impl Default for DecoderBuilder {
36    /// Creates a default DecoderBuilder with no configuration and 0.5 score
37    /// threshold and 0.5 IoU threshold.
38    ///
39    /// A valid configuration must be provided before building the Decoder.
40    ///
41    /// # Examples
42    /// ```rust
43    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
44    /// # fn main() -> DecoderResult<()> {
45    /// #  let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
46    /// let decoder = DecoderBuilder::default()
47    ///     .with_config_yaml_str(config_yaml)
48    ///     .build()?;
49    /// assert_eq!(decoder.score_threshold, 0.5);
50    /// assert_eq!(decoder.iou_threshold, 0.5);
51    ///
52    /// # Ok(())
53    /// # }
54    /// ```
55    fn default() -> Self {
56        Self {
57            config_src: None,
58            iou_threshold: 0.5,
59            score_threshold: 0.5,
60            nms: Some(configs::Nms::ClassAgnostic),
61        }
62    }
63}
64
65impl DecoderBuilder {
66    /// Creates a default DecoderBuilder with no configuration and 0.5 score
67    /// threshold and 0.5 IoU threshold.
68    ///
69    /// A valid configuration must be provided before building the Decoder.
70    ///
71    /// # Examples
72    /// ```rust
73    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
74    /// # fn main() -> DecoderResult<()> {
75    /// #  let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
76    /// let decoder = DecoderBuilder::new()
77    ///     .with_config_yaml_str(config_yaml)
78    ///     .build()?;
79    /// assert_eq!(decoder.score_threshold, 0.5);
80    /// assert_eq!(decoder.iou_threshold, 0.5);
81    ///
82    /// # Ok(())
83    /// # }
84    /// ```
85    pub fn new() -> Self {
86        Self::default()
87    }
88
89    /// Loads a model configuration in YAML format. Does not check if the string
90    /// is a correct configuration file. Use `DecoderBuilder.build()` to
91    /// deserialize the YAML and parse the model configuration.
92    ///
93    /// # Examples
94    /// ```rust
95    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
96    /// # fn main() -> DecoderResult<()> {
97    /// let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
98    /// let decoder = DecoderBuilder::new()
99    ///     .with_config_yaml_str(config_yaml)
100    ///     .build()?;
101    ///
102    /// # Ok(())
103    /// # }
104    /// ```
105    pub fn with_config_yaml_str(mut self, yaml_str: String) -> Self {
106        self.config_src.replace(ConfigSource::Yaml(yaml_str));
107        self
108    }
109
110    /// Loads a model configuration in JSON format. Does not check if the string
111    /// is a correct configuration file. Use `DecoderBuilder.build()` to
112    /// deserialize the JSON and parse the model configuration.
113    ///
114    /// # Examples
115    /// ```rust
116    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
117    /// # fn main() -> DecoderResult<()> {
118    /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
119    /// let decoder = DecoderBuilder::new()
120    ///     .with_config_json_str(config_json)
121    ///     .build()?;
122    ///
123    /// # Ok(())
124    /// # }
125    /// ```
126    pub fn with_config_json_str(mut self, json_str: String) -> Self {
127        self.config_src.replace(ConfigSource::Json(json_str));
128        self
129    }
130
131    /// Loads a model configuration. Does not check if the configuration is
132    /// correct. Intended to be used when the user needs control over the
133    /// deserialize of the configuration information. Use
134    /// `DecoderBuilder.build()` to parse the model configuration.
135    ///
136    /// # Examples
137    /// ```rust
138    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
139    /// # fn main() -> DecoderResult<()> {
140    /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json"));
141    /// let config = serde_json::from_str(config_json)?;
142    /// let decoder = DecoderBuilder::new().with_config(config).build()?;
143    ///
144    /// # Ok(())
145    /// # }
146    /// ```
147    pub fn with_config(mut self, config: ConfigOutputs) -> Self {
148        self.config_src.replace(ConfigSource::Config(config));
149        self
150    }
151
152    /// Configure the decoder from a schema v2 metadata document.
153    ///
154    /// Accepts a [`SchemaV2`] as produced by [`SchemaV2::parse_json`],
155    /// [`SchemaV2::parse_yaml`], [`SchemaV2::parse_file`], or
156    /// constructed programmatically. The builder validates the schema,
157    /// compiles a [`DecodeProgram`] for any split logical outputs
158    /// (per-scale or channel sub-splits), and downconverts the
159    /// logical-level semantics to the legacy [`ConfigOutputs`]
160    /// representation consumed by the existing decoder dispatch.
161    ///
162    /// # Examples
163    ///
164    /// ```rust,no_run
165    /// use edgefirst_decoder::{DecoderBuilder, DecoderResult};
166    /// use edgefirst_decoder::schema::SchemaV2;
167    ///
168    /// # fn main() -> DecoderResult<()> {
169    /// let schema = SchemaV2::parse_file("model/edgefirst.json")?;
170    /// let decoder = DecoderBuilder::new()
171    ///     .with_schema(schema)
172    ///     .with_score_threshold(0.25)
173    ///     .build()?;
174    /// # Ok(())
175    /// # }
176    /// ```
177    pub fn with_schema(mut self, schema: SchemaV2) -> Self {
178        self.config_src.replace(ConfigSource::Schema(schema));
179        self
180    }
181
182    /// Loads a YOLO detection model configuration.  Use
183    /// `DecoderBuilder.build()` to parse the model configuration.
184    ///
185    /// # Examples
186    /// ```rust
187    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
188    /// # fn main() -> DecoderResult<()> {
189    /// let decoder = DecoderBuilder::new()
190    ///     .with_config_yolo_det(
191    ///         configs::Detection {
192    ///             anchors: None,
193    ///             decoder: configs::DecoderType::Ultralytics,
194    ///             quantization: Some(configs::QuantTuple(0.012345, 26)),
195    ///             shape: vec![1, 84, 8400],
196    ///             dshape: Vec::new(),
197    ///             normalized: Some(true),
198    ///         },
199    ///         None,
200    ///     )
201    ///     .build()?;
202    ///
203    /// # Ok(())
204    /// # }
205    /// ```
206    pub fn with_config_yolo_det(
207        mut self,
208        boxes: configs::Detection,
209        version: Option<DecoderVersion>,
210    ) -> Self {
211        let config = ConfigOutputs {
212            outputs: vec![ConfigOutput::Detection(boxes)],
213            decoder_version: version,
214            ..Default::default()
215        };
216        self.config_src.replace(ConfigSource::Config(config));
217        self
218    }
219
220    /// Loads a YOLO split detection model configuration.  Use
221    /// `DecoderBuilder.build()` to parse the model configuration.
222    ///
223    /// # Examples
224    /// ```rust
225    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
226    /// # fn main() -> DecoderResult<()> {
227    /// let boxes_config = configs::Boxes {
228    ///     decoder: configs::DecoderType::Ultralytics,
229    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
230    ///     shape: vec![1, 4, 8400],
231    ///     dshape: Vec::new(),
232    ///     normalized: Some(true),
233    /// };
234    /// let scores_config = configs::Scores {
235    ///     decoder: configs::DecoderType::Ultralytics,
236    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
237    ///     shape: vec![1, 80, 8400],
238    ///     dshape: Vec::new(),
239    /// };
240    /// let decoder = DecoderBuilder::new()
241    ///     .with_config_yolo_split_det(boxes_config, scores_config)
242    ///     .build()?;
243    /// # Ok(())
244    /// # }
245    /// ```
246    pub fn with_config_yolo_split_det(
247        mut self,
248        boxes: configs::Boxes,
249        scores: configs::Scores,
250    ) -> Self {
251        let config = ConfigOutputs {
252            outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
253            ..Default::default()
254        };
255        self.config_src.replace(ConfigSource::Config(config));
256        self
257    }
258
259    /// Loads a YOLO segmentation model configuration.  Use
260    /// `DecoderBuilder.build()` to parse the model configuration.
261    ///
262    /// # Examples
263    /// ```rust
264    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
265    /// # fn main() -> DecoderResult<()> {
266    /// let seg_config = configs::Detection {
267    ///     decoder: configs::DecoderType::Ultralytics,
268    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
269    ///     shape: vec![1, 116, 8400],
270    ///     anchors: None,
271    ///     dshape: Vec::new(),
272    ///     normalized: Some(true),
273    /// };
274    /// let protos_config = configs::Protos {
275    ///     decoder: configs::DecoderType::Ultralytics,
276    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
277    ///     shape: vec![1, 160, 160, 32],
278    ///     dshape: Vec::new(),
279    /// };
280    /// let decoder = DecoderBuilder::new()
281    ///     .with_config_yolo_segdet(
282    ///         seg_config,
283    ///         protos_config,
284    ///         Some(configs::DecoderVersion::Yolov8),
285    ///     )
286    ///     .build()?;
287    /// # Ok(())
288    /// # }
289    /// ```
290    pub fn with_config_yolo_segdet(
291        mut self,
292        boxes: configs::Detection,
293        protos: configs::Protos,
294        version: Option<DecoderVersion>,
295    ) -> Self {
296        let config = ConfigOutputs {
297            outputs: vec![ConfigOutput::Detection(boxes), ConfigOutput::Protos(protos)],
298            decoder_version: version,
299            ..Default::default()
300        };
301        self.config_src.replace(ConfigSource::Config(config));
302        self
303    }
304
305    /// Loads a YOLO split segmentation model configuration.  Use
306    /// `DecoderBuilder.build()` to parse the model configuration.
307    ///
308    /// # Examples
309    /// ```rust
310    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
311    /// # fn main() -> DecoderResult<()> {
312    /// let boxes_config = configs::Boxes {
313    ///     decoder: configs::DecoderType::Ultralytics,
314    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
315    ///     shape: vec![1, 4, 8400],
316    ///     dshape: Vec::new(),
317    ///     normalized: Some(true),
318    /// };
319    /// let scores_config = configs::Scores {
320    ///     decoder: configs::DecoderType::Ultralytics,
321    ///     quantization: Some(configs::QuantTuple(0.012345, 14)),
322    ///     shape: vec![1, 80, 8400],
323    ///     dshape: Vec::new(),
324    /// };
325    /// let mask_config = configs::MaskCoefficients {
326    ///     decoder: configs::DecoderType::Ultralytics,
327    ///     quantization: Some(configs::QuantTuple(0.0064123, 125)),
328    ///     shape: vec![1, 32, 8400],
329    ///     dshape: Vec::new(),
330    /// };
331    /// let protos_config = configs::Protos {
332    ///     decoder: configs::DecoderType::Ultralytics,
333    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
334    ///     shape: vec![1, 160, 160, 32],
335    ///     dshape: Vec::new(),
336    /// };
337    /// let decoder = DecoderBuilder::new()
338    ///     .with_config_yolo_split_segdet(boxes_config, scores_config, mask_config, protos_config)
339    ///     .build()?;
340    /// # Ok(())
341    /// # }
342    /// ```
343    pub fn with_config_yolo_split_segdet(
344        mut self,
345        boxes: configs::Boxes,
346        scores: configs::Scores,
347        mask_coefficients: configs::MaskCoefficients,
348        protos: configs::Protos,
349    ) -> Self {
350        let config = ConfigOutputs {
351            outputs: vec![
352                ConfigOutput::Boxes(boxes),
353                ConfigOutput::Scores(scores),
354                ConfigOutput::MaskCoefficients(mask_coefficients),
355                ConfigOutput::Protos(protos),
356            ],
357            ..Default::default()
358        };
359        self.config_src.replace(ConfigSource::Config(config));
360        self
361    }
362
363    /// Loads a ModelPack detection model configuration.  Use
364    /// `DecoderBuilder.build()` to parse the model configuration.
365    ///
366    /// # Examples
367    /// ```rust
368    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
369    /// # fn main() -> DecoderResult<()> {
370    /// let boxes_config = configs::Boxes {
371    ///     decoder: configs::DecoderType::ModelPack,
372    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
373    ///     shape: vec![1, 8400, 1, 4],
374    ///     dshape: Vec::new(),
375    ///     normalized: Some(true),
376    /// };
377    /// let scores_config = configs::Scores {
378    ///     decoder: configs::DecoderType::ModelPack,
379    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
380    ///     shape: vec![1, 8400, 3],
381    ///     dshape: Vec::new(),
382    /// };
383    /// let decoder = DecoderBuilder::new()
384    ///     .with_config_modelpack_det(boxes_config, scores_config)
385    ///     .build()?;
386    /// # Ok(())
387    /// # }
388    /// ```
389    pub fn with_config_modelpack_det(
390        mut self,
391        boxes: configs::Boxes,
392        scores: configs::Scores,
393    ) -> Self {
394        let config = ConfigOutputs {
395            outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
396            ..Default::default()
397        };
398        self.config_src.replace(ConfigSource::Config(config));
399        self
400    }
401
402    /// Loads a ModelPack split detection model configuration. Use
403    /// `DecoderBuilder.build()` to parse the model configuration.
404    ///
405    /// # Examples
406    /// ```rust
407    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
408    /// # fn main() -> DecoderResult<()> {
409    /// let config0 = configs::Detection {
410    ///     anchors: Some(vec![
411    ///         [0.13750000298023224, 0.2074074000120163],
412    ///         [0.2541666626930237, 0.21481481194496155],
413    ///         [0.23125000298023224, 0.35185185074806213],
414    ///     ]),
415    ///     decoder: configs::DecoderType::ModelPack,
416    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
417    ///     shape: vec![1, 17, 30, 18],
418    ///     dshape: Vec::new(),
419    ///     normalized: Some(true),
420    /// };
421    /// let config1 = configs::Detection {
422    ///     anchors: Some(vec![
423    ///         [0.36666667461395264, 0.31481480598449707],
424    ///         [0.38749998807907104, 0.4740740656852722],
425    ///         [0.5333333611488342, 0.644444465637207],
426    ///     ]),
427    ///     decoder: configs::DecoderType::ModelPack,
428    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
429    ///     shape: vec![1, 9, 15, 18],
430    ///     dshape: Vec::new(),
431    ///     normalized: Some(true),
432    /// };
433    ///
434    /// let decoder = DecoderBuilder::new()
435    ///     .with_config_modelpack_det_split(vec![config0, config1])
436    ///     .build()?;
437    /// # Ok(())
438    /// # }
439    /// ```
440    pub fn with_config_modelpack_det_split(mut self, boxes: Vec<configs::Detection>) -> Self {
441        let outputs = boxes.into_iter().map(ConfigOutput::Detection).collect();
442        let config = ConfigOutputs {
443            outputs,
444            ..Default::default()
445        };
446        self.config_src.replace(ConfigSource::Config(config));
447        self
448    }
449
450    /// Loads a ModelPack segmentation detection model configuration. Use
451    /// `DecoderBuilder.build()` to parse the model configuration.
452    ///
453    /// # Examples
454    /// ```rust
455    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
456    /// # fn main() -> DecoderResult<()> {
457    /// let boxes_config = configs::Boxes {
458    ///     decoder: configs::DecoderType::ModelPack,
459    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
460    ///     shape: vec![1, 8400, 1, 4],
461    ///     dshape: Vec::new(),
462    ///     normalized: Some(true),
463    /// };
464    /// let scores_config = configs::Scores {
465    ///     decoder: configs::DecoderType::ModelPack,
466    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
467    ///     shape: vec![1, 8400, 2],
468    ///     dshape: Vec::new(),
469    /// };
470    /// let seg_config = configs::Segmentation {
471    ///     decoder: configs::DecoderType::ModelPack,
472    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
473    ///     shape: vec![1, 640, 640, 3],
474    ///     dshape: Vec::new(),
475    /// };
476    /// let decoder = DecoderBuilder::new()
477    ///     .with_config_modelpack_segdet(boxes_config, scores_config, seg_config)
478    ///     .build()?;
479    /// # Ok(())
480    /// # }
481    /// ```
482    pub fn with_config_modelpack_segdet(
483        mut self,
484        boxes: configs::Boxes,
485        scores: configs::Scores,
486        segmentation: configs::Segmentation,
487    ) -> Self {
488        let config = ConfigOutputs {
489            outputs: vec![
490                ConfigOutput::Boxes(boxes),
491                ConfigOutput::Scores(scores),
492                ConfigOutput::Segmentation(segmentation),
493            ],
494            ..Default::default()
495        };
496        self.config_src.replace(ConfigSource::Config(config));
497        self
498    }
499
500    /// Loads a ModelPack segmentation split detection model configuration. Use
501    /// `DecoderBuilder.build()` to parse the model configuration.
502    ///
503    /// # Examples
504    /// ```rust
505    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
506    /// # fn main() -> DecoderResult<()> {
507    /// let config0 = configs::Detection {
508    ///     anchors: Some(vec![
509    ///         [0.36666667461395264, 0.31481480598449707],
510    ///         [0.38749998807907104, 0.4740740656852722],
511    ///         [0.5333333611488342, 0.644444465637207],
512    ///     ]),
513    ///     decoder: configs::DecoderType::ModelPack,
514    ///     quantization: Some(configs::QuantTuple(0.08547406643629074, 174)),
515    ///     shape: vec![1, 9, 15, 18],
516    ///     dshape: Vec::new(),
517    ///     normalized: Some(true),
518    /// };
519    /// let config1 = configs::Detection {
520    ///     anchors: Some(vec![
521    ///         [0.13750000298023224, 0.2074074000120163],
522    ///         [0.2541666626930237, 0.21481481194496155],
523    ///         [0.23125000298023224, 0.35185185074806213],
524    ///     ]),
525    ///     decoder: configs::DecoderType::ModelPack,
526    ///     quantization: Some(configs::QuantTuple(0.09929127991199493, 183)),
527    ///     shape: vec![1, 17, 30, 18],
528    ///     dshape: Vec::new(),
529    ///     normalized: Some(true),
530    /// };
531    /// let seg_config = configs::Segmentation {
532    ///     decoder: configs::DecoderType::ModelPack,
533    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
534    ///     shape: vec![1, 640, 640, 2],
535    ///     dshape: Vec::new(),
536    /// };
537    /// let decoder = DecoderBuilder::new()
538    ///     .with_config_modelpack_segdet_split(vec![config0, config1], seg_config)
539    ///     .build()?;
540    /// # Ok(())
541    /// # }
542    /// ```
543    pub fn with_config_modelpack_segdet_split(
544        mut self,
545        boxes: Vec<configs::Detection>,
546        segmentation: configs::Segmentation,
547    ) -> Self {
548        let mut outputs = boxes
549            .into_iter()
550            .map(ConfigOutput::Detection)
551            .collect::<Vec<_>>();
552        outputs.push(ConfigOutput::Segmentation(segmentation));
553        let config = ConfigOutputs {
554            outputs,
555            ..Default::default()
556        };
557        self.config_src.replace(ConfigSource::Config(config));
558        self
559    }
560
561    /// Loads a ModelPack segmentation model configuration. Use
562    /// `DecoderBuilder.build()` to parse the model configuration.
563    ///
564    /// # Examples
565    /// ```rust
566    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
567    /// # fn main() -> DecoderResult<()> {
568    /// let seg_config = configs::Segmentation {
569    ///     decoder: configs::DecoderType::ModelPack,
570    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
571    ///     shape: vec![1, 640, 640, 3],
572    ///     dshape: Vec::new(),
573    /// };
574    /// let decoder = DecoderBuilder::new()
575    ///     .with_config_modelpack_seg(seg_config)
576    ///     .build()?;
577    /// # Ok(())
578    /// # }
579    /// ```
580    pub fn with_config_modelpack_seg(mut self, segmentation: configs::Segmentation) -> Self {
581        let config = ConfigOutputs {
582            outputs: vec![ConfigOutput::Segmentation(segmentation)],
583            ..Default::default()
584        };
585        self.config_src.replace(ConfigSource::Config(config));
586        self
587    }
588
589    /// Add an output to the decoder configuration.
590    ///
591    /// Incrementally builds the model configuration by adding outputs one at
592    /// a time. The decoder resolves the model type from the combination of
593    /// outputs during `build()`.
594    ///
595    /// If `dshape` is non-empty on the output, `shape` is automatically
596    /// derived from it (the size component of each named dimension). This
597    /// prevents conflicts between `shape` and `dshape`.
598    ///
599    /// This uses the programmatic config path. Calling this after
600    /// `with_config_json_str()` or `with_config_yaml_str()` replaces the
601    /// string-based config source.
602    ///
603    /// # Examples
604    /// ```rust
605    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
606    /// # fn main() -> DecoderResult<()> {
607    /// let decoder = DecoderBuilder::new()
608    ///     .add_output(ConfigOutput::Scores(configs::Scores {
609    ///         decoder: configs::DecoderType::Ultralytics,
610    ///         dshape: vec![
611    ///             (configs::DimName::Batch, 1),
612    ///             (configs::DimName::NumClasses, 80),
613    ///             (configs::DimName::NumBoxes, 8400),
614    ///         ],
615    ///         ..Default::default()
616    ///     }))
617    ///     .add_output(ConfigOutput::Boxes(configs::Boxes {
618    ///         decoder: configs::DecoderType::Ultralytics,
619    ///         dshape: vec![
620    ///             (configs::DimName::Batch, 1),
621    ///             (configs::DimName::BoxCoords, 4),
622    ///             (configs::DimName::NumBoxes, 8400),
623    ///         ],
624    ///         ..Default::default()
625    ///     }))
626    ///     .build()?;
627    /// # Ok(())
628    /// # }
629    /// ```
630    pub fn add_output(mut self, output: ConfigOutput) -> Self {
631        if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
632            self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
633        }
634        if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
635            config.outputs.push(Self::normalize_output(output));
636        }
637        self
638    }
639
640    /// Sets the decoder version for Ultralytics models.
641    ///
642    /// This is used with `add_output()` to specify the YOLO architecture
643    /// version when it cannot be inferred from the output shapes alone.
644    ///
645    /// - `Yolov5`, `Yolov8`, `Yolo11`: Traditional models requiring external
646    ///   NMS
647    /// - `Yolo26`: End-to-end models with NMS embedded in the model graph
648    ///
649    /// # Examples
650    /// ```rust
651    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
652    /// # fn main() -> DecoderResult<()> {
653    /// let decoder = DecoderBuilder::new()
654    ///     .add_output(ConfigOutput::Detection(configs::Detection {
655    ///         decoder: configs::DecoderType::Ultralytics,
656    ///         dshape: vec![
657    ///             (configs::DimName::Batch, 1),
658    ///             (configs::DimName::NumBoxes, 100),
659    ///             (configs::DimName::NumFeatures, 6),
660    ///         ],
661    ///         ..Default::default()
662    ///     }))
663    ///     .with_decoder_version(configs::DecoderVersion::Yolo26)
664    ///     .build()?;
665    /// # Ok(())
666    /// # }
667    /// ```
668    pub fn with_decoder_version(mut self, version: configs::DecoderVersion) -> Self {
669        if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
670            self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
671        }
672        if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
673            config.decoder_version = Some(version);
674        }
675        self
676    }
677
678    /// Normalize an output: if dshape is non-empty, derive shape from it.
679    fn normalize_output(mut output: ConfigOutput) -> ConfigOutput {
680        fn normalize_shape(shape: &mut Vec<usize>, dshape: &[(configs::DimName, usize)]) {
681            if !dshape.is_empty() {
682                *shape = dshape.iter().map(|(_, size)| *size).collect();
683            }
684        }
685        match &mut output {
686            ConfigOutput::Detection(c) => normalize_shape(&mut c.shape, &c.dshape),
687            ConfigOutput::Boxes(c) => normalize_shape(&mut c.shape, &c.dshape),
688            ConfigOutput::Scores(c) => normalize_shape(&mut c.shape, &c.dshape),
689            ConfigOutput::Protos(c) => normalize_shape(&mut c.shape, &c.dshape),
690            ConfigOutput::Segmentation(c) => normalize_shape(&mut c.shape, &c.dshape),
691            ConfigOutput::MaskCoefficients(c) => normalize_shape(&mut c.shape, &c.dshape),
692            ConfigOutput::Mask(c) => normalize_shape(&mut c.shape, &c.dshape),
693            ConfigOutput::Classes(c) => normalize_shape(&mut c.shape, &c.dshape),
694        }
695        output
696    }
697
698    /// Sets the scores threshold of the decoder
699    ///
700    /// # Examples
701    /// ```rust
702    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
703    /// # fn main() -> DecoderResult<()> {
704    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
705    /// let decoder = DecoderBuilder::new()
706    ///     .with_config_json_str(config_json)
707    ///     .with_score_threshold(0.654)
708    ///     .build()?;
709    /// assert_eq!(decoder.score_threshold, 0.654);
710    /// # Ok(())
711    /// # }
712    /// ```
713    pub fn with_score_threshold(mut self, score_threshold: f32) -> Self {
714        self.score_threshold = score_threshold;
715        self
716    }
717
718    /// Sets the IOU threshold of the decoder. Has no effect when NMS is set to
719    /// `None`
720    ///
721    /// # Examples
722    /// ```rust
723    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
724    /// # fn main() -> DecoderResult<()> {
725    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
726    /// let decoder = DecoderBuilder::new()
727    ///     .with_config_json_str(config_json)
728    ///     .with_iou_threshold(0.654)
729    ///     .build()?;
730    /// assert_eq!(decoder.iou_threshold, 0.654);
731    /// # Ok(())
732    /// # }
733    /// ```
734    pub fn with_iou_threshold(mut self, iou_threshold: f32) -> Self {
735        self.iou_threshold = iou_threshold;
736        self
737    }
738
739    /// Sets the NMS mode for the decoder.
740    ///
741    /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS (default): suppress
742    ///   overlapping boxes regardless of class label
743    /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
744    ///   share the same class label AND overlap above the IoU threshold
745    /// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
746    ///
747    /// # Examples
748    /// ```rust
749    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::Nms};
750    /// # fn main() -> DecoderResult<()> {
751    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
752    /// let decoder = DecoderBuilder::new()
753    ///     .with_config_json_str(config_json)
754    ///     .with_nms(Some(Nms::ClassAware))
755    ///     .build()?;
756    /// assert_eq!(decoder.nms, Some(Nms::ClassAware));
757    /// # Ok(())
758    /// # }
759    /// ```
760    pub fn with_nms(mut self, nms: Option<configs::Nms>) -> Self {
761        self.nms = nms;
762        self
763    }
764
765    /// Builds the decoder with the given settings. If the config is a JSON or
766    /// YAML string, this will deserialize the JSON or YAML and then parse the
767    /// configuration information.
768    ///
769    /// # Examples
770    /// ```rust
771    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
772    /// # fn main() -> DecoderResult<()> {
773    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
774    /// let decoder = DecoderBuilder::new()
775    ///     .with_config_json_str(config_json)
776    ///     .with_score_threshold(0.654)
777    ///     .build()?;
778    /// # Ok(())
779    /// # }
780    /// ```
781    pub fn build(self) -> Result<Decoder, DecoderError> {
782        let (config, decode_program) = match self.config_src {
783            Some(ConfigSource::Json(s)) => Self::build_from_schema(SchemaV2::parse_json(&s)?)?,
784            Some(ConfigSource::Yaml(s)) => Self::build_from_schema(SchemaV2::parse_yaml(&s)?)?,
785            Some(ConfigSource::Config(c)) => (c, None),
786            Some(ConfigSource::Schema(schema)) => Self::build_from_schema(schema)?,
787            None => return Err(DecoderError::NoConfig),
788        };
789
790        // Enforce the physical-order contract: when dshape is present
791        // it must describe the same axes as shape in the same order,
792        // listed from outermost to innermost. Ambiguous-layout roles
793        // (Protos, Boxes, Scores, MaskCoefficients, Classes, Detection)
794        // may still omit dshape when shape is already in the decoder's
795        // canonical order.
796        for output in &config.outputs {
797            Decoder::validate_output_layout(output.into())?;
798        }
799
800        // Extract normalized flag from config outputs
801        let normalized = Self::get_normalized(&config.outputs);
802
803        // Use NMS from config if present, otherwise use builder's NMS setting
804        let nms = config.nms.or(self.nms);
805        let model_type = Self::get_model_type(config)?;
806
807        Ok(Decoder {
808            model_type,
809            iou_threshold: self.iou_threshold,
810            score_threshold: self.score_threshold,
811            nms,
812            normalized,
813            decode_program,
814        })
815    }
816
817    /// Validate a [`SchemaV2`] and lower it to the (legacy `ConfigOutputs`,
818    /// optional `DecodeProgram`) pair the rest of `build()` consumes.
819    ///
820    /// Centralises the v2 lowering so JSON, YAML, and direct
821    /// `with_schema` callers all go through the same validation and
822    /// merge-program construction. `SchemaV2::parse_json` /
823    /// `parse_yaml` already auto-detect v1 vs v2 input and return a v2
824    /// schema either way (v1 inputs are upgraded in memory via
825    /// `from_v1`), so this helper is the sole place that turns a
826    /// schema into builder-ready state.
827    fn build_from_schema(
828        schema: SchemaV2,
829    ) -> Result<(ConfigOutputs, Option<DecodeProgram>), DecoderError> {
830        schema.validate()?;
831        let program = DecodeProgram::try_from_schema(&schema)?;
832        let legacy = schema.to_legacy_config_outputs()?;
833        Ok((legacy, program))
834    }
835
836    /// Extracts the normalized flag from config outputs.
837    /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
838    /// - `Some(false)`: Boxes are in pixel coordinates
839    /// - `None`: Unknown (not specified in config), caller must infer
840    fn get_normalized(outputs: &[ConfigOutput]) -> Option<bool> {
841        for output in outputs {
842            match output {
843                ConfigOutput::Detection(det) => return det.normalized,
844                ConfigOutput::Boxes(boxes) => return boxes.normalized,
845                _ => {}
846            }
847        }
848        None // not specified
849    }
850
851    fn get_model_type(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
852        // yolo or modelpack
853        let mut yolo = false;
854        let mut modelpack = false;
855        for c in &configs.outputs {
856            match c.decoder() {
857                DecoderType::ModelPack => modelpack = true,
858                DecoderType::Ultralytics => yolo = true,
859            }
860        }
861        match (modelpack, yolo) {
862            (true, true) => Err(DecoderError::InvalidConfig(
863                "Both ModelPack and Yolo outputs found in config".to_string(),
864            )),
865            (true, false) => Self::get_model_type_modelpack(configs),
866            (false, true) => Self::get_model_type_yolo(configs),
867            (false, false) => Err(DecoderError::InvalidConfig(
868                "No outputs found in config".to_string(),
869            )),
870        }
871    }
872
873    fn get_model_type_yolo(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
874        let mut boxes = None;
875        let mut protos = None;
876        let mut split_boxes = None;
877        let mut split_scores = None;
878        let mut split_mask_coeff = None;
879        let mut split_classes = None;
880        for c in configs.outputs {
881            match c {
882                ConfigOutput::Detection(detection) => boxes = Some(detection),
883                ConfigOutput::Segmentation(_) => {
884                    return Err(DecoderError::InvalidConfig(
885                        "Invalid Segmentation output with Yolo decoder".to_string(),
886                    ));
887                }
888                ConfigOutput::Protos(protos_) => protos = Some(protos_),
889                ConfigOutput::Mask(_) => {
890                    return Err(DecoderError::InvalidConfig(
891                        "Invalid Mask output with Yolo decoder".to_string(),
892                    ));
893                }
894                ConfigOutput::Scores(scores) => split_scores = Some(scores),
895                ConfigOutput::Boxes(boxes) => split_boxes = Some(boxes),
896                ConfigOutput::MaskCoefficients(mask_coeff) => split_mask_coeff = Some(mask_coeff),
897                ConfigOutput::Classes(classes) => split_classes = Some(classes),
898            }
899        }
900
901        // Use end-to-end model types when:
902        // 1. decoder_version is explicitly set to Yolo26 (definitive), OR
903        //    decoder_version is not set but the dshapes are (batch, num_boxes,
904        //    num_features)
905        let is_end_to_end_dshape = boxes.as_ref().is_some_and(|b| {
906            let dims = b.dshape.iter().map(|(d, _)| *d).collect::<Vec<_>>();
907            dims == vec![DimName::Batch, DimName::NumBoxes, DimName::NumFeatures]
908        });
909
910        let is_end_to_end = configs
911            .decoder_version
912            .map(|v| v.is_end_to_end())
913            .unwrap_or(is_end_to_end_dshape);
914
915        if is_end_to_end {
916            if let Some(boxes) = boxes {
917                if let Some(protos) = protos {
918                    Self::verify_yolo_seg_det_26(&boxes, &protos)?;
919                    return Ok(ModelType::YoloEndToEndSegDet { boxes, protos });
920                } else {
921                    Self::verify_yolo_det_26(&boxes)?;
922                    return Ok(ModelType::YoloEndToEndDet { boxes });
923                }
924            } else if let (Some(split_boxes), Some(split_scores), Some(split_classes)) =
925                (split_boxes, split_scores, split_classes)
926            {
927                if let (Some(split_mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
928                    Self::verify_yolo_split_end_to_end_segdet(
929                        &split_boxes,
930                        &split_scores,
931                        &split_classes,
932                        &split_mask_coeff,
933                        &protos,
934                    )?;
935                    return Ok(ModelType::YoloSplitEndToEndSegDet {
936                        boxes: split_boxes,
937                        scores: split_scores,
938                        classes: split_classes,
939                        mask_coeff: split_mask_coeff,
940                        protos,
941                    });
942                }
943                Self::verify_yolo_split_end_to_end_det(
944                    &split_boxes,
945                    &split_scores,
946                    &split_classes,
947                )?;
948                return Ok(ModelType::YoloSplitEndToEndDet {
949                    boxes: split_boxes,
950                    scores: split_scores,
951                    classes: split_classes,
952                });
953            } else {
954                return Err(DecoderError::InvalidConfig(
955                    "Invalid Yolo end-to-end model outputs".to_string(),
956                ));
957            }
958        }
959
960        if let Some(boxes) = boxes {
961            match (split_mask_coeff, protos) {
962                (Some(mask_coeff), Some(protos)) => {
963                    // 2-way split: combined detection + separate mask_coeff + protos
964                    Self::verify_yolo_seg_det_2way(&boxes, &mask_coeff, &protos)?;
965                    Ok(ModelType::YoloSegDet2Way {
966                        boxes,
967                        mask_coeff,
968                        protos,
969                    })
970                }
971                (_, Some(protos)) => {
972                    // Unsplit: mask_coefs embedded in detection tensor
973                    Self::verify_yolo_seg_det(&boxes, &protos)?;
974                    Ok(ModelType::YoloSegDet { boxes, protos })
975                }
976                _ => {
977                    Self::verify_yolo_det(&boxes)?;
978                    Ok(ModelType::YoloDet { boxes })
979                }
980            }
981        } else if let (Some(boxes), Some(scores)) = (split_boxes, split_scores) {
982            if let (Some(mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
983                Self::verify_yolo_split_segdet(&boxes, &scores, &mask_coeff, &protos)?;
984                Ok(ModelType::YoloSplitSegDet {
985                    boxes,
986                    scores,
987                    mask_coeff,
988                    protos,
989                })
990            } else {
991                Self::verify_yolo_split_det(&boxes, &scores)?;
992                Ok(ModelType::YoloSplitDet { boxes, scores })
993            }
994        } else {
995            Err(DecoderError::InvalidConfig(
996                "Invalid Yolo model outputs".to_string(),
997            ))
998        }
999    }
1000
1001    fn verify_yolo_det(detect: &configs::Detection) -> Result<(), DecoderError> {
1002        if detect.shape.len() != 3 {
1003            return Err(DecoderError::InvalidConfig(format!(
1004                "Invalid Yolo Detection shape {:?}",
1005                detect.shape
1006            )));
1007        }
1008
1009        Self::verify_dshapes(
1010            &detect.dshape,
1011            &detect.shape,
1012            "Detection",
1013            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1014        )?;
1015        if !detect.dshape.is_empty() {
1016            Self::get_class_count(&detect.dshape, None, None)?;
1017        } else {
1018            Self::get_class_count_no_dshape(detect.into(), None)?;
1019        }
1020
1021        Ok(())
1022    }
1023
1024    fn verify_yolo_det_26(detect: &configs::Detection) -> Result<(), DecoderError> {
1025        if detect.shape.len() != 3 {
1026            return Err(DecoderError::InvalidConfig(format!(
1027                "Invalid Yolo Detection shape {:?}",
1028                detect.shape
1029            )));
1030        }
1031
1032        Self::verify_dshapes(
1033            &detect.dshape,
1034            &detect.shape,
1035            "Detection",
1036            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1037        )?;
1038
1039        if !detect.shape.contains(&6) {
1040            return Err(DecoderError::InvalidConfig(
1041                "Yolo26 Detection must have 6 features".to_string(),
1042            ));
1043        }
1044
1045        Ok(())
1046    }
1047
1048    fn verify_yolo_seg_det(
1049        detection: &configs::Detection,
1050        protos: &configs::Protos,
1051    ) -> Result<(), DecoderError> {
1052        if detection.shape.len() != 3 {
1053            return Err(DecoderError::InvalidConfig(format!(
1054                "Invalid Yolo Detection shape {:?}",
1055                detection.shape
1056            )));
1057        }
1058        if protos.shape.len() != 4 {
1059            return Err(DecoderError::InvalidConfig(format!(
1060                "Invalid Yolo Protos shape {:?}",
1061                protos.shape
1062            )));
1063        }
1064
1065        Self::verify_dshapes(
1066            &detection.dshape,
1067            &detection.shape,
1068            "Detection",
1069            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1070        )?;
1071        Self::verify_dshapes(
1072            &protos.dshape,
1073            &protos.shape,
1074            "Protos",
1075            &[
1076                DimName::Batch,
1077                DimName::Height,
1078                DimName::Width,
1079                DimName::NumProtos,
1080            ],
1081        )?;
1082
1083        let protos_count = Self::get_protos_count(&protos.dshape)
1084            .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1085        log::debug!("Protos count: {}", protos_count);
1086        log::debug!("Detection dshape: {:?}", detection.dshape);
1087        let classes = if !detection.dshape.is_empty() {
1088            Self::get_class_count(&detection.dshape, Some(protos_count), None)?
1089        } else {
1090            Self::get_class_count_no_dshape(detection.into(), Some(protos_count))?
1091        };
1092
1093        if classes == 0 {
1094            return Err(DecoderError::InvalidConfig(
1095                "Yolo Segmentation Detection has zero classes".to_string(),
1096            ));
1097        }
1098
1099        Ok(())
1100    }
1101
1102    fn verify_yolo_seg_det_2way(
1103        detection: &configs::Detection,
1104        mask_coeff: &configs::MaskCoefficients,
1105        protos: &configs::Protos,
1106    ) -> Result<(), DecoderError> {
1107        if detection.shape.len() != 3 {
1108            return Err(DecoderError::InvalidConfig(format!(
1109                "Invalid Yolo 2-Way Detection shape {:?}",
1110                detection.shape
1111            )));
1112        }
1113        if mask_coeff.shape.len() != 3 {
1114            return Err(DecoderError::InvalidConfig(format!(
1115                "Invalid Yolo 2-Way Mask Coefficients shape {:?}",
1116                mask_coeff.shape
1117            )));
1118        }
1119        if protos.shape.len() != 4 {
1120            return Err(DecoderError::InvalidConfig(format!(
1121                "Invalid Yolo 2-Way Protos shape {:?}",
1122                protos.shape
1123            )));
1124        }
1125
1126        Self::verify_dshapes(
1127            &detection.dshape,
1128            &detection.shape,
1129            "Detection",
1130            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1131        )?;
1132        Self::verify_dshapes(
1133            &mask_coeff.dshape,
1134            &mask_coeff.shape,
1135            "Mask Coefficients",
1136            &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1137        )?;
1138        Self::verify_dshapes(
1139            &protos.dshape,
1140            &protos.shape,
1141            "Protos",
1142            &[
1143                DimName::Batch,
1144                DimName::Height,
1145                DimName::Width,
1146                DimName::NumProtos,
1147            ],
1148        )?;
1149
1150        // Validate num_boxes match between detection and mask_coeff
1151        let det_num = Self::get_box_count(&detection.dshape).unwrap_or(detection.shape[2]);
1152        let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1153        if det_num != mask_num {
1154            return Err(DecoderError::InvalidConfig(format!(
1155                "Yolo 2-Way Detection num_boxes {} incompatible with Mask Coefficients num_boxes {}",
1156                det_num, mask_num
1157            )));
1158        }
1159
1160        // Validate mask_coeff channels match protos channels
1161        let mask_channels = if !mask_coeff.dshape.is_empty() {
1162            Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1163                DecoderError::InvalidConfig(
1164                    "Could not find num_protos in mask_coeff config".to_string(),
1165                )
1166            })?
1167        } else {
1168            mask_coeff.shape[1]
1169        };
1170        let proto_channels = if !protos.dshape.is_empty() {
1171            Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1172                DecoderError::InvalidConfig(
1173                    "Could not find num_protos in protos config".to_string(),
1174                )
1175            })?
1176        } else {
1177            protos.shape[1].min(protos.shape[3])
1178        };
1179        if mask_channels != proto_channels {
1180            return Err(DecoderError::InvalidConfig(format!(
1181                "Yolo 2-Way Protos channels {} incompatible with Mask Coefficients channels {}",
1182                proto_channels, mask_channels
1183            )));
1184        }
1185
1186        // Validate detection has classes (nc+4 features, no mask_coefs embedded)
1187        if !detection.dshape.is_empty() {
1188            Self::get_class_count(&detection.dshape, None, None)?;
1189        } else {
1190            Self::get_class_count_no_dshape(detection.into(), None)?;
1191        }
1192
1193        Ok(())
1194    }
1195
1196    fn verify_yolo_seg_det_26(
1197        detection: &configs::Detection,
1198        protos: &configs::Protos,
1199    ) -> Result<(), DecoderError> {
1200        if detection.shape.len() != 3 {
1201            return Err(DecoderError::InvalidConfig(format!(
1202                "Invalid Yolo Detection shape {:?}",
1203                detection.shape
1204            )));
1205        }
1206        if protos.shape.len() != 4 {
1207            return Err(DecoderError::InvalidConfig(format!(
1208                "Invalid Yolo Protos shape {:?}",
1209                protos.shape
1210            )));
1211        }
1212
1213        Self::verify_dshapes(
1214            &detection.dshape,
1215            &detection.shape,
1216            "Detection",
1217            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1218        )?;
1219        Self::verify_dshapes(
1220            &protos.dshape,
1221            &protos.shape,
1222            "Protos",
1223            &[
1224                DimName::Batch,
1225                DimName::Height,
1226                DimName::Width,
1227                DimName::NumProtos,
1228            ],
1229        )?;
1230
1231        let protos_count = Self::get_protos_count(&protos.dshape)
1232            .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1233        log::debug!("Protos count: {}", protos_count);
1234        log::debug!("Detection dshape: {:?}", detection.dshape);
1235
1236        if !detection.shape.contains(&(6 + protos_count)) {
1237            return Err(DecoderError::InvalidConfig(format!(
1238                "Yolo26 Segmentation Detection must have num_features be 6 + num_protos = {}",
1239                6 + protos_count
1240            )));
1241        }
1242
1243        Ok(())
1244    }
1245
1246    fn verify_yolo_split_det(
1247        boxes: &configs::Boxes,
1248        scores: &configs::Scores,
1249    ) -> Result<(), DecoderError> {
1250        if boxes.shape.len() != 3 {
1251            return Err(DecoderError::InvalidConfig(format!(
1252                "Invalid Yolo Split Boxes shape {:?}",
1253                boxes.shape
1254            )));
1255        }
1256        if scores.shape.len() != 3 {
1257            return Err(DecoderError::InvalidConfig(format!(
1258                "Invalid Yolo Split Scores shape {:?}",
1259                scores.shape
1260            )));
1261        }
1262
1263        Self::verify_dshapes(
1264            &boxes.dshape,
1265            &boxes.shape,
1266            "Boxes",
1267            &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1268        )?;
1269        Self::verify_dshapes(
1270            &scores.dshape,
1271            &scores.shape,
1272            "Scores",
1273            &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1274        )?;
1275
1276        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1277        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1278
1279        if boxes_num != scores_num {
1280            return Err(DecoderError::InvalidConfig(format!(
1281                "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1282                boxes_num, scores_num
1283            )));
1284        }
1285
1286        Ok(())
1287    }
1288
1289    fn verify_yolo_split_segdet(
1290        boxes: &configs::Boxes,
1291        scores: &configs::Scores,
1292        mask_coeff: &configs::MaskCoefficients,
1293        protos: &configs::Protos,
1294    ) -> Result<(), DecoderError> {
1295        if boxes.shape.len() != 3 {
1296            return Err(DecoderError::InvalidConfig(format!(
1297                "Invalid Yolo Split Boxes shape {:?}",
1298                boxes.shape
1299            )));
1300        }
1301        if scores.shape.len() != 3 {
1302            return Err(DecoderError::InvalidConfig(format!(
1303                "Invalid Yolo Split Scores shape {:?}",
1304                scores.shape
1305            )));
1306        }
1307
1308        if mask_coeff.shape.len() != 3 {
1309            return Err(DecoderError::InvalidConfig(format!(
1310                "Invalid Yolo Split Mask Coefficients shape {:?}",
1311                mask_coeff.shape
1312            )));
1313        }
1314
1315        if protos.shape.len() != 4 {
1316            return Err(DecoderError::InvalidConfig(format!(
1317                "Invalid Yolo Protos shape {:?}",
1318                mask_coeff.shape
1319            )));
1320        }
1321
1322        Self::verify_dshapes(
1323            &boxes.dshape,
1324            &boxes.shape,
1325            "Boxes",
1326            &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1327        )?;
1328        Self::verify_dshapes(
1329            &scores.dshape,
1330            &scores.shape,
1331            "Scores",
1332            &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1333        )?;
1334        Self::verify_dshapes(
1335            &mask_coeff.dshape,
1336            &mask_coeff.shape,
1337            "Mask Coefficients",
1338            &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1339        )?;
1340        Self::verify_dshapes(
1341            &protos.dshape,
1342            &protos.shape,
1343            "Protos",
1344            &[
1345                DimName::Batch,
1346                DimName::Height,
1347                DimName::Width,
1348                DimName::NumProtos,
1349            ],
1350        )?;
1351
1352        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1353        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1354        let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1355
1356        let mask_channels = if !mask_coeff.dshape.is_empty() {
1357            Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1358                DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1359            })?
1360        } else {
1361            mask_coeff.shape[1]
1362        };
1363        let proto_channels = if !protos.dshape.is_empty() {
1364            Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1365                DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1366            })?
1367        } else {
1368            protos.shape[1].min(protos.shape[3])
1369        };
1370
1371        if boxes_num != scores_num {
1372            return Err(DecoderError::InvalidConfig(format!(
1373                "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1374                boxes_num, scores_num
1375            )));
1376        }
1377
1378        if boxes_num != mask_num {
1379            return Err(DecoderError::InvalidConfig(format!(
1380                "Yolo Split Detection Boxes num {} incompatible with Mask Coefficients num {}",
1381                boxes_num, mask_num
1382            )));
1383        }
1384
1385        if proto_channels != mask_channels {
1386            return Err(DecoderError::InvalidConfig(format!(
1387                "Yolo Protos channels {} incompatible with Mask Coefficients channels {}",
1388                proto_channels, mask_channels
1389            )));
1390        }
1391
1392        Ok(())
1393    }
1394
1395    fn verify_yolo_split_end_to_end_det(
1396        boxes: &configs::Boxes,
1397        scores: &configs::Scores,
1398        classes: &configs::Classes,
1399    ) -> Result<(), DecoderError> {
1400        if boxes.shape.len() != 3 || !boxes.shape.contains(&4) {
1401            return Err(DecoderError::InvalidConfig(format!(
1402                "Split end-to-end boxes must be [batch, N, 4], got {:?}",
1403                boxes.shape
1404            )));
1405        }
1406        if scores.shape.len() != 3 || !scores.shape.contains(&1) {
1407            return Err(DecoderError::InvalidConfig(format!(
1408                "Split end-to-end scores must be [batch, N, 1], got {:?}",
1409                scores.shape
1410            )));
1411        }
1412        if classes.shape.len() != 3 || !classes.shape.contains(&1) {
1413            return Err(DecoderError::InvalidConfig(format!(
1414                "Split end-to-end classes must be [batch, N, 1], got {:?}",
1415                classes.shape
1416            )));
1417        }
1418        Ok(())
1419    }
1420
1421    fn verify_yolo_split_end_to_end_segdet(
1422        boxes: &configs::Boxes,
1423        scores: &configs::Scores,
1424        classes: &configs::Classes,
1425        mask_coeff: &configs::MaskCoefficients,
1426        protos: &configs::Protos,
1427    ) -> Result<(), DecoderError> {
1428        Self::verify_yolo_split_end_to_end_det(boxes, scores, classes)?;
1429        if mask_coeff.shape.len() != 3 {
1430            return Err(DecoderError::InvalidConfig(format!(
1431                "Invalid split end-to-end mask coefficients shape {:?}",
1432                mask_coeff.shape
1433            )));
1434        }
1435        if protos.shape.len() != 4 {
1436            return Err(DecoderError::InvalidConfig(format!(
1437                "Invalid protos shape {:?}",
1438                protos.shape
1439            )));
1440        }
1441        Ok(())
1442    }
1443
1444    fn get_model_type_modelpack(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1445        let mut split_decoders = Vec::new();
1446        let mut segment_ = None;
1447        let mut scores_ = None;
1448        let mut boxes_ = None;
1449        for c in configs.outputs {
1450            match c {
1451                ConfigOutput::Detection(detection) => split_decoders.push(detection),
1452                ConfigOutput::Segmentation(segmentation) => segment_ = Some(segmentation),
1453                ConfigOutput::Mask(_) => {}
1454                ConfigOutput::Protos(_) => {
1455                    return Err(DecoderError::InvalidConfig(
1456                        "ModelPack should not have protos".to_string(),
1457                    ));
1458                }
1459                ConfigOutput::Scores(scores) => scores_ = Some(scores),
1460                ConfigOutput::Boxes(boxes) => boxes_ = Some(boxes),
1461                ConfigOutput::MaskCoefficients(_) => {
1462                    return Err(DecoderError::InvalidConfig(
1463                        "ModelPack should not have mask coefficients".to_string(),
1464                    ));
1465                }
1466                ConfigOutput::Classes(_) => {
1467                    return Err(DecoderError::InvalidConfig(
1468                        "ModelPack should not have classes output".to_string(),
1469                    ));
1470                }
1471            }
1472        }
1473
1474        if let Some(segmentation) = segment_ {
1475            if !split_decoders.is_empty() {
1476                let classes = Self::verify_modelpack_split_det(&split_decoders)?;
1477                Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1478                Ok(ModelType::ModelPackSegDetSplit {
1479                    detection: split_decoders,
1480                    segmentation,
1481                })
1482            } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1483                let classes = Self::verify_modelpack_det(&boxes, &scores)?;
1484                Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1485                Ok(ModelType::ModelPackSegDet {
1486                    boxes,
1487                    scores,
1488                    segmentation,
1489                })
1490            } else {
1491                Self::verify_modelpack_seg(&segmentation, None)?;
1492                Ok(ModelType::ModelPackSeg { segmentation })
1493            }
1494        } else if !split_decoders.is_empty() {
1495            Self::verify_modelpack_split_det(&split_decoders)?;
1496            Ok(ModelType::ModelPackDetSplit {
1497                detection: split_decoders,
1498            })
1499        } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1500            Self::verify_modelpack_det(&boxes, &scores)?;
1501            Ok(ModelType::ModelPackDet { boxes, scores })
1502        } else {
1503            Err(DecoderError::InvalidConfig(
1504                "Invalid ModelPack model outputs".to_string(),
1505            ))
1506        }
1507    }
1508
1509    fn verify_modelpack_det(
1510        boxes: &configs::Boxes,
1511        scores: &configs::Scores,
1512    ) -> Result<usize, DecoderError> {
1513        if boxes.shape.len() != 4 {
1514            return Err(DecoderError::InvalidConfig(format!(
1515                "Invalid ModelPack Boxes shape {:?}",
1516                boxes.shape
1517            )));
1518        }
1519        if scores.shape.len() != 3 {
1520            return Err(DecoderError::InvalidConfig(format!(
1521                "Invalid ModelPack Scores shape {:?}",
1522                scores.shape
1523            )));
1524        }
1525
1526        Self::verify_dshapes(
1527            &boxes.dshape,
1528            &boxes.shape,
1529            "Boxes",
1530            &[
1531                DimName::Batch,
1532                DimName::NumBoxes,
1533                DimName::Padding,
1534                DimName::BoxCoords,
1535            ],
1536        )?;
1537        Self::verify_dshapes(
1538            &scores.dshape,
1539            &scores.shape,
1540            "Scores",
1541            &[DimName::Batch, DimName::NumBoxes, DimName::NumClasses],
1542        )?;
1543
1544        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[1]);
1545        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[1]);
1546
1547        if boxes_num != scores_num {
1548            return Err(DecoderError::InvalidConfig(format!(
1549                "ModelPack Detection Boxes num {} incompatible with Scores num {}",
1550                boxes_num, scores_num
1551            )));
1552        }
1553
1554        let num_classes = if !scores.dshape.is_empty() {
1555            Self::get_class_count(&scores.dshape, None, None)?
1556        } else {
1557            Self::get_class_count_no_dshape(scores.into(), None)?
1558        };
1559
1560        Ok(num_classes)
1561    }
1562
1563    fn verify_modelpack_split_det(boxes: &[configs::Detection]) -> Result<usize, DecoderError> {
1564        let mut num_classes = None;
1565        for b in boxes {
1566            let Some(num_anchors) = b.anchors.as_ref().map(|a| a.len()) else {
1567                return Err(DecoderError::InvalidConfig(
1568                    "ModelPack Split Detection missing anchors".to_string(),
1569                ));
1570            };
1571
1572            if num_anchors == 0 {
1573                return Err(DecoderError::InvalidConfig(
1574                    "ModelPack Split Detection has zero anchors".to_string(),
1575                ));
1576            }
1577
1578            if b.shape.len() != 4 {
1579                return Err(DecoderError::InvalidConfig(format!(
1580                    "Invalid ModelPack Split Detection shape {:?}",
1581                    b.shape
1582                )));
1583            }
1584
1585            Self::verify_dshapes(
1586                &b.dshape,
1587                &b.shape,
1588                "Split Detection",
1589                &[
1590                    DimName::Batch,
1591                    DimName::Height,
1592                    DimName::Width,
1593                    DimName::NumAnchorsXFeatures,
1594                ],
1595            )?;
1596            let classes = if !b.dshape.is_empty() {
1597                Self::get_class_count(&b.dshape, None, Some(num_anchors))?
1598            } else {
1599                Self::get_class_count_no_dshape(b.into(), None)?
1600            };
1601
1602            match num_classes {
1603                Some(n) => {
1604                    if n != classes {
1605                        return Err(DecoderError::InvalidConfig(format!(
1606                            "ModelPack Split Detection inconsistent number of classes: previous {}, current {}",
1607                            n, classes
1608                        )));
1609                    }
1610                }
1611                None => {
1612                    num_classes = Some(classes);
1613                }
1614            }
1615        }
1616
1617        Ok(num_classes.unwrap_or(0))
1618    }
1619
1620    fn verify_modelpack_seg(
1621        segmentation: &configs::Segmentation,
1622        classes: Option<usize>,
1623    ) -> Result<(), DecoderError> {
1624        if segmentation.shape.len() != 4 {
1625            return Err(DecoderError::InvalidConfig(format!(
1626                "Invalid ModelPack Segmentation shape {:?}",
1627                segmentation.shape
1628            )));
1629        }
1630        Self::verify_dshapes(
1631            &segmentation.dshape,
1632            &segmentation.shape,
1633            "Segmentation",
1634            &[
1635                DimName::Batch,
1636                DimName::Height,
1637                DimName::Width,
1638                DimName::NumClasses,
1639            ],
1640        )?;
1641
1642        if let Some(classes) = classes {
1643            let seg_classes = if !segmentation.dshape.is_empty() {
1644                Self::get_class_count(&segmentation.dshape, None, None)?
1645            } else {
1646                Self::get_class_count_no_dshape(segmentation.into(), None)?
1647            };
1648
1649            if seg_classes != classes + 1 {
1650                return Err(DecoderError::InvalidConfig(format!(
1651                    "ModelPack Segmentation channels {} incompatible with number of classes {}",
1652                    seg_classes, classes
1653                )));
1654            }
1655        }
1656        Ok(())
1657    }
1658
1659    // verifies that dshapes match the given shape
1660    fn verify_dshapes(
1661        dshape: &[(DimName, usize)],
1662        shape: &[usize],
1663        name: &str,
1664        dims: &[DimName],
1665    ) -> Result<(), DecoderError> {
1666        for s in shape {
1667            if *s == 0 {
1668                return Err(DecoderError::InvalidConfig(format!(
1669                    "{} shape has zero dimension",
1670                    name
1671                )));
1672            }
1673        }
1674
1675        if shape.len() != dims.len() {
1676            return Err(DecoderError::InvalidConfig(format!(
1677                "{} shape length {} does not match expected dims length {}",
1678                name,
1679                shape.len(),
1680                dims.len()
1681            )));
1682        }
1683
1684        if dshape.is_empty() {
1685            return Ok(());
1686        }
1687        // check the dshape lengths match the shape lengths
1688        if dshape.len() != shape.len() {
1689            return Err(DecoderError::InvalidConfig(format!(
1690                "{} dshape length does not match shape length",
1691                name
1692            )));
1693        }
1694
1695        // check the dshape values match the shape values
1696        for ((dim_name, dim_size), shape_size) in dshape.iter().zip(shape) {
1697            if dim_size != shape_size {
1698                return Err(DecoderError::InvalidConfig(format!(
1699                    "{} dshape dimension {} size {} does not match shape size {}",
1700                    name, dim_name, dim_size, shape_size
1701                )));
1702            }
1703            if *dim_name == DimName::Padding && *dim_size != 1 {
1704                return Err(DecoderError::InvalidConfig(
1705                    "Padding dimension size must be 1".to_string(),
1706                ));
1707            }
1708
1709            if *dim_name == DimName::BoxCoords && *dim_size != 4 {
1710                return Err(DecoderError::InvalidConfig(
1711                    "BoxCoords dimension size must be 4".to_string(),
1712                ));
1713            }
1714        }
1715
1716        let dims_present = HashSet::<DimName>::from_iter(dshape.iter().map(|(name, _)| *name));
1717        for dim in dims {
1718            if !dims_present.contains(dim) {
1719                return Err(DecoderError::InvalidConfig(format!(
1720                    "{} dshape missing required dimension {:?}",
1721                    name, dim
1722                )));
1723            }
1724        }
1725
1726        Ok(())
1727    }
1728
1729    fn get_box_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1730        for (dim_name, dim_size) in dshape {
1731            if *dim_name == DimName::NumBoxes {
1732                return Some(*dim_size);
1733            }
1734        }
1735        None
1736    }
1737
1738    fn get_class_count_no_dshape(
1739        config: ConfigOutputRef,
1740        protos: Option<usize>,
1741    ) -> Result<usize, DecoderError> {
1742        match config {
1743            ConfigOutputRef::Detection(detection) => match detection.decoder {
1744                DecoderType::Ultralytics => {
1745                    if detection.shape[1] <= 4 + protos.unwrap_or(0) {
1746                        return Err(DecoderError::InvalidConfig(format!(
1747                            "Invalid shape: Yolo num_features {} must be greater than {}",
1748                            detection.shape[1],
1749                            4 + protos.unwrap_or(0),
1750                        )));
1751                    }
1752                    Ok(detection.shape[1] - 4 - protos.unwrap_or(0))
1753                }
1754                DecoderType::ModelPack => {
1755                    let Some(num_anchors) = detection.anchors.as_ref().map(|a| a.len()) else {
1756                        return Err(DecoderError::Internal(
1757                            "ModelPack Detection missing anchors".to_string(),
1758                        ));
1759                    };
1760                    let anchors_x_features = detection.shape[3];
1761                    if anchors_x_features <= num_anchors * 5 {
1762                        return Err(DecoderError::InvalidConfig(format!(
1763                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
1764                            anchors_x_features,
1765                            num_anchors * 5,
1766                        )));
1767                    }
1768
1769                    if !anchors_x_features.is_multiple_of(num_anchors) {
1770                        return Err(DecoderError::InvalidConfig(format!(
1771                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
1772                            anchors_x_features, num_anchors
1773                        )));
1774                    }
1775                    Ok(anchors_x_features / num_anchors - 5)
1776                }
1777            },
1778
1779            ConfigOutputRef::Scores(scores) => match scores.decoder {
1780                DecoderType::Ultralytics => Ok(scores.shape[1]),
1781                DecoderType::ModelPack => Ok(scores.shape[2]),
1782            },
1783            ConfigOutputRef::Segmentation(seg) => Ok(seg.shape[3]),
1784            _ => Err(DecoderError::Internal(
1785                "Attempted to get class count from unsupported config output".to_owned(),
1786            )),
1787        }
1788    }
1789
1790    // get the class count from dshape or calculate from num_features
1791    fn get_class_count(
1792        dshape: &[(DimName, usize)],
1793        protos: Option<usize>,
1794        anchors: Option<usize>,
1795    ) -> Result<usize, DecoderError> {
1796        if dshape.is_empty() {
1797            return Ok(0);
1798        }
1799        // if it has num_classes in dshape, return it
1800        for (dim_name, dim_size) in dshape {
1801            if *dim_name == DimName::NumClasses {
1802                return Ok(*dim_size);
1803            }
1804        }
1805
1806        // number of classes can be calculated from num_features - 4 for yolo.  If the
1807        // model has protos, we also subtract the number of protos.
1808        for (dim_name, dim_size) in dshape {
1809            if *dim_name == DimName::NumFeatures {
1810                let protos = protos.unwrap_or(0);
1811                if protos + 4 >= *dim_size {
1812                    return Err(DecoderError::InvalidConfig(format!(
1813                        "Invalid shape: Yolo num_features {} must be greater than {}",
1814                        *dim_size,
1815                        protos + 4,
1816                    )));
1817                }
1818                return Ok(*dim_size - 4 - protos);
1819            }
1820        }
1821
1822        // number of classes can be calculated from number of anchors for modelpack
1823        // split detection
1824        if let Some(num_anchors) = anchors {
1825            for (dim_name, dim_size) in dshape {
1826                if *dim_name == DimName::NumAnchorsXFeatures {
1827                    let anchors_x_features = *dim_size;
1828                    if anchors_x_features <= num_anchors * 5 {
1829                        return Err(DecoderError::InvalidConfig(format!(
1830                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
1831                            anchors_x_features,
1832                            num_anchors * 5,
1833                        )));
1834                    }
1835
1836                    if !anchors_x_features.is_multiple_of(num_anchors) {
1837                        return Err(DecoderError::InvalidConfig(format!(
1838                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
1839                            anchors_x_features, num_anchors
1840                        )));
1841                    }
1842                    return Ok((anchors_x_features / num_anchors) - 5);
1843                }
1844            }
1845        }
1846        Err(DecoderError::InvalidConfig(
1847            "Cannot determine number of classes from dshape".to_owned(),
1848        ))
1849    }
1850
1851    fn get_protos_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1852        for (dim_name, dim_size) in dshape {
1853            if *dim_name == DimName::NumProtos {
1854                return Some(*dim_size);
1855            }
1856        }
1857        None
1858    }
1859}