Skip to main content

edgefirst_decoder/decoder/
builder.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashSet;
5
6use super::config::ConfigOutputRef;
7use super::configs::{self, DecoderType, DecoderVersion, DimName, ModelType};
8use super::merge::DecodeProgram;
9use super::{ConfigOutput, ConfigOutputs, Decoder};
10use crate::schema::SchemaV2;
11use crate::DecoderError;
12
13#[derive(Debug, Clone, PartialEq)]
14pub struct DecoderBuilder {
15    config_src: Option<ConfigSource>,
16    iou_threshold: f32,
17    score_threshold: f32,
18    /// NMS mode: Some(mode) applies NMS, None bypasses NMS (for end-to-end
19    /// models)
20    nms: Option<configs::Nms>,
21    pre_nms_top_k: usize,
22    max_det: usize,
23}
24
/// The model configuration currently held by a [`DecoderBuilder`].
///
/// Only one source is kept at a time; the `with_*` configuration methods
/// replace whatever source was stored before.
#[derive(Debug, Clone, PartialEq)]
enum ConfigSource {
    /// Raw YAML text. It is not inspected until `build()`, which parses it
    /// with `SchemaV2::parse_yaml`.
    Yaml(String),
    /// Raw JSON text. It is not inspected until `build()`, which parses it
    /// with `SchemaV2::parse_json`.
    Json(String),
    /// Programmatic legacy configuration. `build()` takes it directly,
    /// without schema parsing and without a [`DecodeProgram`].
    Config(ConfigOutputs),
    /// Schema v2 metadata. During build the schema is either converted
    /// to a legacy [`ConfigOutputs`] (flat case) or used to construct a
    /// [`DecodeProgram`] that performs per-child dequant + merge at
    /// decode time.
    Schema(SchemaV2),
}
36
37impl Default for DecoderBuilder {
38    /// Creates a default DecoderBuilder with no configuration and 0.5 score
39    /// threshold and 0.5 IoU threshold.
40    ///
41    /// A valid configuration must be provided before building the Decoder.
42    ///
43    /// # Examples
44    /// ```rust
45    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
46    /// # fn main() -> DecoderResult<()> {
47    /// #  let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
48    /// let decoder = DecoderBuilder::default()
49    ///     .with_config_yaml_str(config_yaml)
50    ///     .build()?;
51    /// assert_eq!(decoder.score_threshold, 0.5);
52    /// assert_eq!(decoder.iou_threshold, 0.5);
53    ///
54    /// # Ok(())
55    /// # }
56    /// ```
57    fn default() -> Self {
58        Self {
59            config_src: None,
60            iou_threshold: 0.5,
61            score_threshold: 0.5,
62            nms: Some(configs::Nms::ClassAgnostic),
63            pre_nms_top_k: 300,
64            max_det: 300,
65        }
66    }
67}
68
69impl DecoderBuilder {
70    /// Creates a default DecoderBuilder with no configuration and 0.5 score
71    /// threshold and 0.5 IoU threshold.
72    ///
73    /// A valid configuration must be provided before building the Decoder.
74    ///
75    /// # Examples
76    /// ```rust
77    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
78    /// # fn main() -> DecoderResult<()> {
79    /// #  let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
80    /// let decoder = DecoderBuilder::new()
81    ///     .with_config_yaml_str(config_yaml)
82    ///     .build()?;
83    /// assert_eq!(decoder.score_threshold, 0.5);
84    /// assert_eq!(decoder.iou_threshold, 0.5);
85    ///
86    /// # Ok(())
87    /// # }
88    /// ```
89    pub fn new() -> Self {
90        Self::default()
91    }
92
93    /// Loads a model configuration in YAML format. Does not check if the string
94    /// is a correct configuration file. Use `DecoderBuilder.build()` to
95    /// deserialize the YAML and parse the model configuration.
96    ///
97    /// # Examples
98    /// ```rust
99    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
100    /// # fn main() -> DecoderResult<()> {
101    /// let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
102    /// let decoder = DecoderBuilder::new()
103    ///     .with_config_yaml_str(config_yaml)
104    ///     .build()?;
105    ///
106    /// # Ok(())
107    /// # }
108    /// ```
109    pub fn with_config_yaml_str(mut self, yaml_str: String) -> Self {
110        self.config_src.replace(ConfigSource::Yaml(yaml_str));
111        self
112    }
113
114    /// Loads a model configuration in JSON format. Does not check if the string
115    /// is a correct configuration file. Use `DecoderBuilder.build()` to
116    /// deserialize the JSON and parse the model configuration.
117    ///
118    /// # Examples
119    /// ```rust
120    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
121    /// # fn main() -> DecoderResult<()> {
122    /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
123    /// let decoder = DecoderBuilder::new()
124    ///     .with_config_json_str(config_json)
125    ///     .build()?;
126    ///
127    /// # Ok(())
128    /// # }
129    /// ```
130    pub fn with_config_json_str(mut self, json_str: String) -> Self {
131        self.config_src.replace(ConfigSource::Json(json_str));
132        self
133    }
134
135    /// Loads a model configuration. Does not check if the configuration is
136    /// correct. Intended to be used when the user needs control over the
137    /// deserialize of the configuration information. Use
138    /// `DecoderBuilder.build()` to parse the model configuration.
139    ///
140    /// # Examples
141    /// ```rust
142    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
143    /// # fn main() -> DecoderResult<()> {
144    /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json"));
145    /// let config = serde_json::from_str(config_json)?;
146    /// let decoder = DecoderBuilder::new().with_config(config).build()?;
147    ///
148    /// # Ok(())
149    /// # }
150    /// ```
151    pub fn with_config(mut self, config: ConfigOutputs) -> Self {
152        self.config_src.replace(ConfigSource::Config(config));
153        self
154    }
155
156    /// Configure the decoder from a schema v2 metadata document.
157    ///
158    /// Accepts a [`SchemaV2`] as produced by [`SchemaV2::parse_json`],
159    /// [`SchemaV2::parse_yaml`], [`SchemaV2::parse_file`], or
160    /// constructed programmatically. The builder validates the schema,
161    /// compiles a [`DecodeProgram`] for any split logical outputs
162    /// (per-scale or channel sub-splits), and downconverts the
163    /// logical-level semantics to the legacy [`ConfigOutputs`]
164    /// representation consumed by the existing decoder dispatch.
165    ///
166    /// # Examples
167    ///
168    /// ```rust,no_run
169    /// use edgefirst_decoder::{DecoderBuilder, DecoderResult};
170    /// use edgefirst_decoder::schema::SchemaV2;
171    ///
172    /// # fn main() -> DecoderResult<()> {
173    /// let schema = SchemaV2::parse_file("model/edgefirst.json")?;
174    /// let decoder = DecoderBuilder::new()
175    ///     .with_schema(schema)
176    ///     .with_score_threshold(0.25)
177    ///     .build()?;
178    /// # Ok(())
179    /// # }
180    /// ```
181    pub fn with_schema(mut self, schema: SchemaV2) -> Self {
182        self.config_src.replace(ConfigSource::Schema(schema));
183        self
184    }
185
186    /// Loads a YOLO detection model configuration.  Use
187    /// `DecoderBuilder.build()` to parse the model configuration.
188    ///
189    /// # Examples
190    /// ```rust
191    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
192    /// # fn main() -> DecoderResult<()> {
193    /// let decoder = DecoderBuilder::new()
194    ///     .with_config_yolo_det(
195    ///         configs::Detection {
196    ///             anchors: None,
197    ///             decoder: configs::DecoderType::Ultralytics,
198    ///             quantization: Some(configs::QuantTuple(0.012345, 26)),
199    ///             shape: vec![1, 84, 8400],
200    ///             dshape: Vec::new(),
201    ///             normalized: Some(true),
202    ///         },
203    ///         None,
204    ///     )
205    ///     .build()?;
206    ///
207    /// # Ok(())
208    /// # }
209    /// ```
210    pub fn with_config_yolo_det(
211        mut self,
212        boxes: configs::Detection,
213        version: Option<DecoderVersion>,
214    ) -> Self {
215        let config = ConfigOutputs {
216            outputs: vec![ConfigOutput::Detection(boxes)],
217            decoder_version: version,
218            ..Default::default()
219        };
220        self.config_src.replace(ConfigSource::Config(config));
221        self
222    }
223
224    /// Loads a YOLO split detection model configuration.  Use
225    /// `DecoderBuilder.build()` to parse the model configuration.
226    ///
227    /// # Examples
228    /// ```rust
229    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
230    /// # fn main() -> DecoderResult<()> {
231    /// let boxes_config = configs::Boxes {
232    ///     decoder: configs::DecoderType::Ultralytics,
233    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
234    ///     shape: vec![1, 4, 8400],
235    ///     dshape: Vec::new(),
236    ///     normalized: Some(true),
237    /// };
238    /// let scores_config = configs::Scores {
239    ///     decoder: configs::DecoderType::Ultralytics,
240    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
241    ///     shape: vec![1, 80, 8400],
242    ///     dshape: Vec::new(),
243    /// };
244    /// let decoder = DecoderBuilder::new()
245    ///     .with_config_yolo_split_det(boxes_config, scores_config)
246    ///     .build()?;
247    /// # Ok(())
248    /// # }
249    /// ```
250    pub fn with_config_yolo_split_det(
251        mut self,
252        boxes: configs::Boxes,
253        scores: configs::Scores,
254    ) -> Self {
255        let config = ConfigOutputs {
256            outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
257            ..Default::default()
258        };
259        self.config_src.replace(ConfigSource::Config(config));
260        self
261    }
262
263    /// Loads a YOLO segmentation model configuration.  Use
264    /// `DecoderBuilder.build()` to parse the model configuration.
265    ///
266    /// # Examples
267    /// ```rust
268    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
269    /// # fn main() -> DecoderResult<()> {
270    /// let seg_config = configs::Detection {
271    ///     decoder: configs::DecoderType::Ultralytics,
272    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
273    ///     shape: vec![1, 116, 8400],
274    ///     anchors: None,
275    ///     dshape: Vec::new(),
276    ///     normalized: Some(true),
277    /// };
278    /// let protos_config = configs::Protos {
279    ///     decoder: configs::DecoderType::Ultralytics,
280    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
281    ///     shape: vec![1, 160, 160, 32],
282    ///     dshape: Vec::new(),
283    /// };
284    /// let decoder = DecoderBuilder::new()
285    ///     .with_config_yolo_segdet(
286    ///         seg_config,
287    ///         protos_config,
288    ///         Some(configs::DecoderVersion::Yolov8),
289    ///     )
290    ///     .build()?;
291    /// # Ok(())
292    /// # }
293    /// ```
294    pub fn with_config_yolo_segdet(
295        mut self,
296        boxes: configs::Detection,
297        protos: configs::Protos,
298        version: Option<DecoderVersion>,
299    ) -> Self {
300        let config = ConfigOutputs {
301            outputs: vec![ConfigOutput::Detection(boxes), ConfigOutput::Protos(protos)],
302            decoder_version: version,
303            ..Default::default()
304        };
305        self.config_src.replace(ConfigSource::Config(config));
306        self
307    }
308
309    /// Loads a YOLO split segmentation model configuration.  Use
310    /// `DecoderBuilder.build()` to parse the model configuration.
311    ///
312    /// # Examples
313    /// ```rust
314    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
315    /// # fn main() -> DecoderResult<()> {
316    /// let boxes_config = configs::Boxes {
317    ///     decoder: configs::DecoderType::Ultralytics,
318    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
319    ///     shape: vec![1, 4, 8400],
320    ///     dshape: Vec::new(),
321    ///     normalized: Some(true),
322    /// };
323    /// let scores_config = configs::Scores {
324    ///     decoder: configs::DecoderType::Ultralytics,
325    ///     quantization: Some(configs::QuantTuple(0.012345, 14)),
326    ///     shape: vec![1, 80, 8400],
327    ///     dshape: Vec::new(),
328    /// };
329    /// let mask_config = configs::MaskCoefficients {
330    ///     decoder: configs::DecoderType::Ultralytics,
331    ///     quantization: Some(configs::QuantTuple(0.0064123, 125)),
332    ///     shape: vec![1, 32, 8400],
333    ///     dshape: Vec::new(),
334    /// };
335    /// let protos_config = configs::Protos {
336    ///     decoder: configs::DecoderType::Ultralytics,
337    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
338    ///     shape: vec![1, 160, 160, 32],
339    ///     dshape: Vec::new(),
340    /// };
341    /// let decoder = DecoderBuilder::new()
342    ///     .with_config_yolo_split_segdet(boxes_config, scores_config, mask_config, protos_config)
343    ///     .build()?;
344    /// # Ok(())
345    /// # }
346    /// ```
347    pub fn with_config_yolo_split_segdet(
348        mut self,
349        boxes: configs::Boxes,
350        scores: configs::Scores,
351        mask_coefficients: configs::MaskCoefficients,
352        protos: configs::Protos,
353    ) -> Self {
354        let config = ConfigOutputs {
355            outputs: vec![
356                ConfigOutput::Boxes(boxes),
357                ConfigOutput::Scores(scores),
358                ConfigOutput::MaskCoefficients(mask_coefficients),
359                ConfigOutput::Protos(protos),
360            ],
361            ..Default::default()
362        };
363        self.config_src.replace(ConfigSource::Config(config));
364        self
365    }
366
367    /// Loads a ModelPack detection model configuration.  Use
368    /// `DecoderBuilder.build()` to parse the model configuration.
369    ///
370    /// # Examples
371    /// ```rust
372    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
373    /// # fn main() -> DecoderResult<()> {
374    /// let boxes_config = configs::Boxes {
375    ///     decoder: configs::DecoderType::ModelPack,
376    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
377    ///     shape: vec![1, 8400, 1, 4],
378    ///     dshape: Vec::new(),
379    ///     normalized: Some(true),
380    /// };
381    /// let scores_config = configs::Scores {
382    ///     decoder: configs::DecoderType::ModelPack,
383    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
384    ///     shape: vec![1, 8400, 3],
385    ///     dshape: Vec::new(),
386    /// };
387    /// let decoder = DecoderBuilder::new()
388    ///     .with_config_modelpack_det(boxes_config, scores_config)
389    ///     .build()?;
390    /// # Ok(())
391    /// # }
392    /// ```
393    pub fn with_config_modelpack_det(
394        mut self,
395        boxes: configs::Boxes,
396        scores: configs::Scores,
397    ) -> Self {
398        let config = ConfigOutputs {
399            outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
400            ..Default::default()
401        };
402        self.config_src.replace(ConfigSource::Config(config));
403        self
404    }
405
406    /// Loads a ModelPack split detection model configuration. Use
407    /// `DecoderBuilder.build()` to parse the model configuration.
408    ///
409    /// # Examples
410    /// ```rust
411    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
412    /// # fn main() -> DecoderResult<()> {
413    /// let config0 = configs::Detection {
414    ///     anchors: Some(vec![
415    ///         [0.13750000298023224, 0.2074074000120163],
416    ///         [0.2541666626930237, 0.21481481194496155],
417    ///         [0.23125000298023224, 0.35185185074806213],
418    ///     ]),
419    ///     decoder: configs::DecoderType::ModelPack,
420    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
421    ///     shape: vec![1, 17, 30, 18],
422    ///     dshape: Vec::new(),
423    ///     normalized: Some(true),
424    /// };
425    /// let config1 = configs::Detection {
426    ///     anchors: Some(vec![
427    ///         [0.36666667461395264, 0.31481480598449707],
428    ///         [0.38749998807907104, 0.4740740656852722],
429    ///         [0.5333333611488342, 0.644444465637207],
430    ///     ]),
431    ///     decoder: configs::DecoderType::ModelPack,
432    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
433    ///     shape: vec![1, 9, 15, 18],
434    ///     dshape: Vec::new(),
435    ///     normalized: Some(true),
436    /// };
437    ///
438    /// let decoder = DecoderBuilder::new()
439    ///     .with_config_modelpack_det_split(vec![config0, config1])
440    ///     .build()?;
441    /// # Ok(())
442    /// # }
443    /// ```
444    pub fn with_config_modelpack_det_split(mut self, boxes: Vec<configs::Detection>) -> Self {
445        let outputs = boxes.into_iter().map(ConfigOutput::Detection).collect();
446        let config = ConfigOutputs {
447            outputs,
448            ..Default::default()
449        };
450        self.config_src.replace(ConfigSource::Config(config));
451        self
452    }
453
454    /// Loads a ModelPack segmentation detection model configuration. Use
455    /// `DecoderBuilder.build()` to parse the model configuration.
456    ///
457    /// # Examples
458    /// ```rust
459    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
460    /// # fn main() -> DecoderResult<()> {
461    /// let boxes_config = configs::Boxes {
462    ///     decoder: configs::DecoderType::ModelPack,
463    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
464    ///     shape: vec![1, 8400, 1, 4],
465    ///     dshape: Vec::new(),
466    ///     normalized: Some(true),
467    /// };
468    /// let scores_config = configs::Scores {
469    ///     decoder: configs::DecoderType::ModelPack,
470    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
471    ///     shape: vec![1, 8400, 2],
472    ///     dshape: Vec::new(),
473    /// };
474    /// let seg_config = configs::Segmentation {
475    ///     decoder: configs::DecoderType::ModelPack,
476    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
477    ///     shape: vec![1, 640, 640, 3],
478    ///     dshape: Vec::new(),
479    /// };
480    /// let decoder = DecoderBuilder::new()
481    ///     .with_config_modelpack_segdet(boxes_config, scores_config, seg_config)
482    ///     .build()?;
483    /// # Ok(())
484    /// # }
485    /// ```
486    pub fn with_config_modelpack_segdet(
487        mut self,
488        boxes: configs::Boxes,
489        scores: configs::Scores,
490        segmentation: configs::Segmentation,
491    ) -> Self {
492        let config = ConfigOutputs {
493            outputs: vec![
494                ConfigOutput::Boxes(boxes),
495                ConfigOutput::Scores(scores),
496                ConfigOutput::Segmentation(segmentation),
497            ],
498            ..Default::default()
499        };
500        self.config_src.replace(ConfigSource::Config(config));
501        self
502    }
503
504    /// Loads a ModelPack segmentation split detection model configuration. Use
505    /// `DecoderBuilder.build()` to parse the model configuration.
506    ///
507    /// # Examples
508    /// ```rust
509    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
510    /// # fn main() -> DecoderResult<()> {
511    /// let config0 = configs::Detection {
512    ///     anchors: Some(vec![
513    ///         [0.36666667461395264, 0.31481480598449707],
514    ///         [0.38749998807907104, 0.4740740656852722],
515    ///         [0.5333333611488342, 0.644444465637207],
516    ///     ]),
517    ///     decoder: configs::DecoderType::ModelPack,
518    ///     quantization: Some(configs::QuantTuple(0.08547406643629074, 174)),
519    ///     shape: vec![1, 9, 15, 18],
520    ///     dshape: Vec::new(),
521    ///     normalized: Some(true),
522    /// };
523    /// let config1 = configs::Detection {
524    ///     anchors: Some(vec![
525    ///         [0.13750000298023224, 0.2074074000120163],
526    ///         [0.2541666626930237, 0.21481481194496155],
527    ///         [0.23125000298023224, 0.35185185074806213],
528    ///     ]),
529    ///     decoder: configs::DecoderType::ModelPack,
530    ///     quantization: Some(configs::QuantTuple(0.09929127991199493, 183)),
531    ///     shape: vec![1, 17, 30, 18],
532    ///     dshape: Vec::new(),
533    ///     normalized: Some(true),
534    /// };
535    /// let seg_config = configs::Segmentation {
536    ///     decoder: configs::DecoderType::ModelPack,
537    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
538    ///     shape: vec![1, 640, 640, 2],
539    ///     dshape: Vec::new(),
540    /// };
541    /// let decoder = DecoderBuilder::new()
542    ///     .with_config_modelpack_segdet_split(vec![config0, config1], seg_config)
543    ///     .build()?;
544    /// # Ok(())
545    /// # }
546    /// ```
547    pub fn with_config_modelpack_segdet_split(
548        mut self,
549        boxes: Vec<configs::Detection>,
550        segmentation: configs::Segmentation,
551    ) -> Self {
552        let mut outputs = boxes
553            .into_iter()
554            .map(ConfigOutput::Detection)
555            .collect::<Vec<_>>();
556        outputs.push(ConfigOutput::Segmentation(segmentation));
557        let config = ConfigOutputs {
558            outputs,
559            ..Default::default()
560        };
561        self.config_src.replace(ConfigSource::Config(config));
562        self
563    }
564
565    /// Loads a ModelPack segmentation model configuration. Use
566    /// `DecoderBuilder.build()` to parse the model configuration.
567    ///
568    /// # Examples
569    /// ```rust
570    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
571    /// # fn main() -> DecoderResult<()> {
572    /// let seg_config = configs::Segmentation {
573    ///     decoder: configs::DecoderType::ModelPack,
574    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
575    ///     shape: vec![1, 640, 640, 3],
576    ///     dshape: Vec::new(),
577    /// };
578    /// let decoder = DecoderBuilder::new()
579    ///     .with_config_modelpack_seg(seg_config)
580    ///     .build()?;
581    /// # Ok(())
582    /// # }
583    /// ```
584    pub fn with_config_modelpack_seg(mut self, segmentation: configs::Segmentation) -> Self {
585        let config = ConfigOutputs {
586            outputs: vec![ConfigOutput::Segmentation(segmentation)],
587            ..Default::default()
588        };
589        self.config_src.replace(ConfigSource::Config(config));
590        self
591    }
592
593    /// Add an output to the decoder configuration.
594    ///
595    /// Incrementally builds the model configuration by adding outputs one at
596    /// a time. The decoder resolves the model type from the combination of
597    /// outputs during `build()`.
598    ///
599    /// If `dshape` is non-empty on the output, `shape` is automatically
600    /// derived from it (the size component of each named dimension). This
601    /// prevents conflicts between `shape` and `dshape`.
602    ///
603    /// This uses the programmatic config path. Calling this after
604    /// `with_config_json_str()` or `with_config_yaml_str()` replaces the
605    /// string-based config source.
606    ///
607    /// # Examples
608    /// ```rust
609    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
610    /// # fn main() -> DecoderResult<()> {
611    /// let decoder = DecoderBuilder::new()
612    ///     .add_output(ConfigOutput::Scores(configs::Scores {
613    ///         decoder: configs::DecoderType::Ultralytics,
614    ///         dshape: vec![
615    ///             (configs::DimName::Batch, 1),
616    ///             (configs::DimName::NumClasses, 80),
617    ///             (configs::DimName::NumBoxes, 8400),
618    ///         ],
619    ///         ..Default::default()
620    ///     }))
621    ///     .add_output(ConfigOutput::Boxes(configs::Boxes {
622    ///         decoder: configs::DecoderType::Ultralytics,
623    ///         dshape: vec![
624    ///             (configs::DimName::Batch, 1),
625    ///             (configs::DimName::BoxCoords, 4),
626    ///             (configs::DimName::NumBoxes, 8400),
627    ///         ],
628    ///         ..Default::default()
629    ///     }))
630    ///     .build()?;
631    /// # Ok(())
632    /// # }
633    /// ```
634    pub fn add_output(mut self, output: ConfigOutput) -> Self {
635        if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
636            self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
637        }
638        if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
639            config.outputs.push(Self::normalize_output(output));
640        }
641        self
642    }
643
644    /// Sets the decoder version for Ultralytics models.
645    ///
646    /// This is used with `add_output()` to specify the YOLO architecture
647    /// version when it cannot be inferred from the output shapes alone.
648    ///
649    /// - `Yolov5`, `Yolov8`, `Yolo11`: Traditional models requiring external
650    ///   NMS
651    /// - `Yolo26`: End-to-end models with NMS embedded in the model graph
652    ///
653    /// # Examples
654    /// ```rust
655    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
656    /// # fn main() -> DecoderResult<()> {
657    /// let decoder = DecoderBuilder::new()
658    ///     .add_output(ConfigOutput::Detection(configs::Detection {
659    ///         decoder: configs::DecoderType::Ultralytics,
660    ///         dshape: vec![
661    ///             (configs::DimName::Batch, 1),
662    ///             (configs::DimName::NumBoxes, 100),
663    ///             (configs::DimName::NumFeatures, 6),
664    ///         ],
665    ///         ..Default::default()
666    ///     }))
667    ///     .with_decoder_version(configs::DecoderVersion::Yolo26)
668    ///     .build()?;
669    /// # Ok(())
670    /// # }
671    /// ```
672    pub fn with_decoder_version(mut self, version: configs::DecoderVersion) -> Self {
673        if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
674            self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
675        }
676        if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
677            config.decoder_version = Some(version);
678        }
679        self
680    }
681
682    /// Normalize an output: if dshape is non-empty, derive shape from it.
683    fn normalize_output(mut output: ConfigOutput) -> ConfigOutput {
684        fn normalize_shape(shape: &mut Vec<usize>, dshape: &[(configs::DimName, usize)]) {
685            if !dshape.is_empty() {
686                *shape = dshape.iter().map(|(_, size)| *size).collect();
687            }
688        }
689        match &mut output {
690            ConfigOutput::Detection(c) => normalize_shape(&mut c.shape, &c.dshape),
691            ConfigOutput::Boxes(c) => normalize_shape(&mut c.shape, &c.dshape),
692            ConfigOutput::Scores(c) => normalize_shape(&mut c.shape, &c.dshape),
693            ConfigOutput::Protos(c) => normalize_shape(&mut c.shape, &c.dshape),
694            ConfigOutput::Segmentation(c) => normalize_shape(&mut c.shape, &c.dshape),
695            ConfigOutput::MaskCoefficients(c) => normalize_shape(&mut c.shape, &c.dshape),
696            ConfigOutput::Mask(c) => normalize_shape(&mut c.shape, &c.dshape),
697            ConfigOutput::Classes(c) => normalize_shape(&mut c.shape, &c.dshape),
698        }
699        output
700    }
701
702    /// Sets the scores threshold of the decoder
703    ///
704    /// # Examples
705    /// ```rust
706    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
707    /// # fn main() -> DecoderResult<()> {
708    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
709    /// let decoder = DecoderBuilder::new()
710    ///     .with_config_json_str(config_json)
711    ///     .with_score_threshold(0.654)
712    ///     .build()?;
713    /// assert_eq!(decoder.score_threshold, 0.654);
714    /// # Ok(())
715    /// # }
716    /// ```
717    pub fn with_score_threshold(mut self, score_threshold: f32) -> Self {
718        self.score_threshold = score_threshold;
719        self
720    }
721
722    /// Sets the IOU threshold of the decoder. Has no effect when NMS is set to
723    /// `None`
724    ///
725    /// # Examples
726    /// ```rust
727    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
728    /// # fn main() -> DecoderResult<()> {
729    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
730    /// let decoder = DecoderBuilder::new()
731    ///     .with_config_json_str(config_json)
732    ///     .with_iou_threshold(0.654)
733    ///     .build()?;
734    /// assert_eq!(decoder.iou_threshold, 0.654);
735    /// # Ok(())
736    /// # }
737    /// ```
738    pub fn with_iou_threshold(mut self, iou_threshold: f32) -> Self {
739        self.iou_threshold = iou_threshold;
740        self
741    }
742
743    /// Sets the NMS mode for the decoder.
744    ///
745    /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS (default): suppress
746    ///   overlapping boxes regardless of class label
747    /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
748    ///   share the same class label AND overlap above the IoU threshold
749    /// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
750    ///
751    /// # Examples
752    /// ```rust
753    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::Nms};
754    /// # fn main() -> DecoderResult<()> {
755    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
756    /// let decoder = DecoderBuilder::new()
757    ///     .with_config_json_str(config_json)
758    ///     .with_nms(Some(Nms::ClassAware))
759    ///     .build()?;
760    /// assert_eq!(decoder.nms, Some(Nms::ClassAware));
761    /// # Ok(())
762    /// # }
763    /// ```
764    pub fn with_nms(mut self, nms: Option<configs::Nms>) -> Self {
765        self.nms = nms;
766        self
767    }
768
769    /// Sets the maximum number of candidate boxes fed into NMS after score
770    /// filtering.  Uses partial sort (O(N)) to select the top-K candidates,
771    /// dramatically reducing the O(N²) NMS cost when many low-confidence
772    /// proposals pass the threshold (common with mAP eval at 0.001).
773    ///
774    /// Default: 300 (matches Ultralytics `max_det`).
775    ///
776    /// # Examples
777    /// ```rust
778    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
779    /// # fn main() -> DecoderResult<()> {
780    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
781    /// let decoder = DecoderBuilder::new()
782    ///     .with_config_json_str(config_json)
783    ///     .with_pre_nms_top_k(1000)
784    ///     .build()?;
785    /// assert_eq!(decoder.pre_nms_top_k, 1000);
786    /// # Ok(())
787    /// # }
788    /// ```
789    pub fn with_pre_nms_top_k(mut self, pre_nms_top_k: usize) -> Self {
790        self.pre_nms_top_k = pre_nms_top_k;
791        self
792    }
793
794    /// Sets the maximum number of detections returned after NMS.
795    /// Matches the Ultralytics `max_det` parameter.
796    ///
797    /// Default: 300.
798    ///
799    /// # Examples
800    /// ```rust
801    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
802    /// # fn main() -> DecoderResult<()> {
803    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
804    /// let decoder = DecoderBuilder::new()
805    ///     .with_config_json_str(config_json)
806    ///     .with_max_det(100)
807    ///     .build()?;
808    /// assert_eq!(decoder.max_det, 100);
809    /// # Ok(())
810    /// # }
811    /// ```
812    pub fn with_max_det(mut self, max_det: usize) -> Self {
813        self.max_det = max_det;
814        self
815    }
816
817    /// Builds the decoder with the given settings. If the config is a JSON or
818    /// YAML string, this will deserialize the JSON or YAML and then parse the
819    /// configuration information.
820    ///
821    /// # Examples
822    /// ```rust
823    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
824    /// # fn main() -> DecoderResult<()> {
825    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
826    /// let decoder = DecoderBuilder::new()
827    ///     .with_config_json_str(config_json)
828    ///     .with_score_threshold(0.654)
829    ///     .build()?;
830    /// # Ok(())
831    /// # }
832    /// ```
833    pub fn build(self) -> Result<Decoder, DecoderError> {
834        let (config, decode_program) = match self.config_src {
835            Some(ConfigSource::Json(s)) => Self::build_from_schema(SchemaV2::parse_json(&s)?)?,
836            Some(ConfigSource::Yaml(s)) => Self::build_from_schema(SchemaV2::parse_yaml(&s)?)?,
837            Some(ConfigSource::Config(c)) => (c, None),
838            Some(ConfigSource::Schema(schema)) => Self::build_from_schema(schema)?,
839            None => return Err(DecoderError::NoConfig),
840        };
841
842        // Enforce the physical-order contract: when dshape is present
843        // it must describe the same axes as shape in the same order,
844        // listed from outermost to innermost. Ambiguous-layout roles
845        // (Protos, Boxes, Scores, MaskCoefficients, Classes, Detection)
846        // may still omit dshape when shape is already in the decoder's
847        // canonical order.
848        for output in &config.outputs {
849            Decoder::validate_output_layout(output.into())?;
850        }
851
852        // Extract normalized flag from config outputs
853        let normalized = Self::get_normalized(&config.outputs);
854
855        // Use NMS from config if present, otherwise use builder's NMS setting
856        let nms = config.nms.or(self.nms);
857        let model_type = Self::get_model_type(config)?;
858
859        Ok(Decoder {
860            model_type,
861            iou_threshold: self.iou_threshold,
862            score_threshold: self.score_threshold,
863            nms,
864            pre_nms_top_k: self.pre_nms_top_k,
865            max_det: self.max_det,
866            normalized,
867            decode_program,
868        })
869    }
870
871    /// Validate a [`SchemaV2`] and lower it to the (legacy `ConfigOutputs`,
872    /// optional `DecodeProgram`) pair the rest of `build()` consumes.
873    ///
874    /// Centralises the v2 lowering so JSON, YAML, and direct
875    /// `with_schema` callers all go through the same validation and
876    /// merge-program construction. `SchemaV2::parse_json` /
877    /// `parse_yaml` already auto-detect v1 vs v2 input and return a v2
878    /// schema either way (v1 inputs are upgraded in memory via
879    /// `from_v1`), so this helper is the sole place that turns a
880    /// schema into builder-ready state.
881    fn build_from_schema(
882        schema: SchemaV2,
883    ) -> Result<(ConfigOutputs, Option<DecodeProgram>), DecoderError> {
884        schema.validate()?;
885        let program = DecodeProgram::try_from_schema(&schema)?;
886        let legacy = schema.to_legacy_config_outputs()?;
887        Ok((legacy, program))
888    }
889
890    /// Extracts the normalized flag from config outputs.
891    /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
892    /// - `Some(false)`: Boxes are in pixel coordinates
893    /// - `None`: Unknown (not specified in config), caller must infer
894    fn get_normalized(outputs: &[ConfigOutput]) -> Option<bool> {
895        for output in outputs {
896            match output {
897                ConfigOutput::Detection(det) => return det.normalized,
898                ConfigOutput::Boxes(boxes) => return boxes.normalized,
899                _ => {}
900            }
901        }
902        None // not specified
903    }
904
905    fn get_model_type(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
906        // yolo or modelpack
907        let mut yolo = false;
908        let mut modelpack = false;
909        for c in &configs.outputs {
910            match c.decoder() {
911                DecoderType::ModelPack => modelpack = true,
912                DecoderType::Ultralytics => yolo = true,
913            }
914        }
915        match (modelpack, yolo) {
916            (true, true) => Err(DecoderError::InvalidConfig(
917                "Both ModelPack and Yolo outputs found in config".to_string(),
918            )),
919            (true, false) => Self::get_model_type_modelpack(configs),
920            (false, true) => Self::get_model_type_yolo(configs),
921            (false, false) => Err(DecoderError::InvalidConfig(
922                "No outputs found in config".to_string(),
923            )),
924        }
925    }
926
927    fn get_model_type_yolo(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
928        let mut boxes = None;
929        let mut protos = None;
930        let mut split_boxes = None;
931        let mut split_scores = None;
932        let mut split_mask_coeff = None;
933        let mut split_classes = None;
934        for c in configs.outputs {
935            match c {
936                ConfigOutput::Detection(detection) => boxes = Some(detection),
937                ConfigOutput::Segmentation(_) => {
938                    return Err(DecoderError::InvalidConfig(
939                        "Invalid Segmentation output with Yolo decoder".to_string(),
940                    ));
941                }
942                ConfigOutput::Protos(protos_) => protos = Some(protos_),
943                ConfigOutput::Mask(_) => {
944                    return Err(DecoderError::InvalidConfig(
945                        "Invalid Mask output with Yolo decoder".to_string(),
946                    ));
947                }
948                ConfigOutput::Scores(scores) => split_scores = Some(scores),
949                ConfigOutput::Boxes(boxes) => split_boxes = Some(boxes),
950                ConfigOutput::MaskCoefficients(mask_coeff) => split_mask_coeff = Some(mask_coeff),
951                ConfigOutput::Classes(classes) => split_classes = Some(classes),
952            }
953        }
954
955        // Use end-to-end model types when:
956        // 1. decoder_version is explicitly set to Yolo26 (definitive), OR
957        //    decoder_version is not set but the dshapes are (batch, num_boxes,
958        //    num_features)
959        let is_end_to_end_dshape = boxes.as_ref().is_some_and(|b| {
960            let dims = b.dshape.iter().map(|(d, _)| *d).collect::<Vec<_>>();
961            dims == vec![DimName::Batch, DimName::NumBoxes, DimName::NumFeatures]
962        });
963
964        let is_end_to_end = configs
965            .decoder_version
966            .map(|v| v.is_end_to_end())
967            .unwrap_or(is_end_to_end_dshape);
968
969        if is_end_to_end {
970            if let Some(boxes) = boxes {
971                if let Some(protos) = protos {
972                    Self::verify_yolo_seg_det_26(&boxes, &protos)?;
973                    return Ok(ModelType::YoloEndToEndSegDet { boxes, protos });
974                } else {
975                    Self::verify_yolo_det_26(&boxes)?;
976                    return Ok(ModelType::YoloEndToEndDet { boxes });
977                }
978            } else if let (Some(split_boxes), Some(split_scores), Some(split_classes)) =
979                (split_boxes, split_scores, split_classes)
980            {
981                if let (Some(split_mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
982                    Self::verify_yolo_split_end_to_end_segdet(
983                        &split_boxes,
984                        &split_scores,
985                        &split_classes,
986                        &split_mask_coeff,
987                        &protos,
988                    )?;
989                    return Ok(ModelType::YoloSplitEndToEndSegDet {
990                        boxes: split_boxes,
991                        scores: split_scores,
992                        classes: split_classes,
993                        mask_coeff: split_mask_coeff,
994                        protos,
995                    });
996                }
997                Self::verify_yolo_split_end_to_end_det(
998                    &split_boxes,
999                    &split_scores,
1000                    &split_classes,
1001                )?;
1002                return Ok(ModelType::YoloSplitEndToEndDet {
1003                    boxes: split_boxes,
1004                    scores: split_scores,
1005                    classes: split_classes,
1006                });
1007            } else {
1008                return Err(DecoderError::InvalidConfig(
1009                    "Invalid Yolo end-to-end model outputs".to_string(),
1010                ));
1011            }
1012        }
1013
1014        if let Some(boxes) = boxes {
1015            match (split_mask_coeff, protos) {
1016                (Some(mask_coeff), Some(protos)) => {
1017                    // 2-way split: combined detection + separate mask_coeff + protos
1018                    Self::verify_yolo_seg_det_2way(&boxes, &mask_coeff, &protos)?;
1019                    Ok(ModelType::YoloSegDet2Way {
1020                        boxes,
1021                        mask_coeff,
1022                        protos,
1023                    })
1024                }
1025                (_, Some(protos)) => {
1026                    // Unsplit: mask_coefs embedded in detection tensor
1027                    Self::verify_yolo_seg_det(&boxes, &protos)?;
1028                    Ok(ModelType::YoloSegDet { boxes, protos })
1029                }
1030                _ => {
1031                    Self::verify_yolo_det(&boxes)?;
1032                    Ok(ModelType::YoloDet { boxes })
1033                }
1034            }
1035        } else if let (Some(boxes), Some(scores)) = (split_boxes, split_scores) {
1036            if let (Some(mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
1037                Self::verify_yolo_split_segdet(&boxes, &scores, &mask_coeff, &protos)?;
1038                Ok(ModelType::YoloSplitSegDet {
1039                    boxes,
1040                    scores,
1041                    mask_coeff,
1042                    protos,
1043                })
1044            } else {
1045                Self::verify_yolo_split_det(&boxes, &scores)?;
1046                Ok(ModelType::YoloSplitDet { boxes, scores })
1047            }
1048        } else {
1049            Err(DecoderError::InvalidConfig(
1050                "Invalid Yolo model outputs".to_string(),
1051            ))
1052        }
1053    }
1054
1055    fn verify_yolo_det(detect: &configs::Detection) -> Result<(), DecoderError> {
1056        if detect.shape.len() != 3 {
1057            return Err(DecoderError::InvalidConfig(format!(
1058                "Invalid Yolo Detection shape {:?}",
1059                detect.shape
1060            )));
1061        }
1062
1063        Self::verify_dshapes(
1064            &detect.dshape,
1065            &detect.shape,
1066            "Detection",
1067            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1068        )?;
1069        if !detect.dshape.is_empty() {
1070            Self::get_class_count(&detect.dshape, None, None)?;
1071        } else {
1072            Self::get_class_count_no_dshape(detect.into(), None)?;
1073        }
1074
1075        Ok(())
1076    }
1077
1078    fn verify_yolo_det_26(detect: &configs::Detection) -> Result<(), DecoderError> {
1079        if detect.shape.len() != 3 {
1080            return Err(DecoderError::InvalidConfig(format!(
1081                "Invalid Yolo Detection shape {:?}",
1082                detect.shape
1083            )));
1084        }
1085
1086        Self::verify_dshapes(
1087            &detect.dshape,
1088            &detect.shape,
1089            "Detection",
1090            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1091        )?;
1092
1093        if !detect.shape.contains(&6) {
1094            return Err(DecoderError::InvalidConfig(
1095                "Yolo26 Detection must have 6 features".to_string(),
1096            ));
1097        }
1098
1099        Ok(())
1100    }
1101
1102    fn verify_yolo_seg_det(
1103        detection: &configs::Detection,
1104        protos: &configs::Protos,
1105    ) -> Result<(), DecoderError> {
1106        if detection.shape.len() != 3 {
1107            return Err(DecoderError::InvalidConfig(format!(
1108                "Invalid Yolo Detection shape {:?}",
1109                detection.shape
1110            )));
1111        }
1112        if protos.shape.len() != 4 {
1113            return Err(DecoderError::InvalidConfig(format!(
1114                "Invalid Yolo Protos shape {:?}",
1115                protos.shape
1116            )));
1117        }
1118
1119        Self::verify_dshapes(
1120            &detection.dshape,
1121            &detection.shape,
1122            "Detection",
1123            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1124        )?;
1125        Self::verify_dshapes(
1126            &protos.dshape,
1127            &protos.shape,
1128            "Protos",
1129            &[
1130                DimName::Batch,
1131                DimName::Height,
1132                DimName::Width,
1133                DimName::NumProtos,
1134            ],
1135        )?;
1136
1137        let protos_count = Self::get_protos_count(&protos.dshape)
1138            .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1139        log::debug!("Protos count: {}", protos_count);
1140        log::debug!("Detection dshape: {:?}", detection.dshape);
1141        let classes = if !detection.dshape.is_empty() {
1142            Self::get_class_count(&detection.dshape, Some(protos_count), None)?
1143        } else {
1144            Self::get_class_count_no_dshape(detection.into(), Some(protos_count))?
1145        };
1146
1147        if classes == 0 {
1148            return Err(DecoderError::InvalidConfig(
1149                "Yolo Segmentation Detection has zero classes".to_string(),
1150            ));
1151        }
1152
1153        Ok(())
1154    }
1155
1156    fn verify_yolo_seg_det_2way(
1157        detection: &configs::Detection,
1158        mask_coeff: &configs::MaskCoefficients,
1159        protos: &configs::Protos,
1160    ) -> Result<(), DecoderError> {
1161        if detection.shape.len() != 3 {
1162            return Err(DecoderError::InvalidConfig(format!(
1163                "Invalid Yolo 2-Way Detection shape {:?}",
1164                detection.shape
1165            )));
1166        }
1167        if mask_coeff.shape.len() != 3 {
1168            return Err(DecoderError::InvalidConfig(format!(
1169                "Invalid Yolo 2-Way Mask Coefficients shape {:?}",
1170                mask_coeff.shape
1171            )));
1172        }
1173        if protos.shape.len() != 4 {
1174            return Err(DecoderError::InvalidConfig(format!(
1175                "Invalid Yolo 2-Way Protos shape {:?}",
1176                protos.shape
1177            )));
1178        }
1179
1180        Self::verify_dshapes(
1181            &detection.dshape,
1182            &detection.shape,
1183            "Detection",
1184            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1185        )?;
1186        Self::verify_dshapes(
1187            &mask_coeff.dshape,
1188            &mask_coeff.shape,
1189            "Mask Coefficients",
1190            &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1191        )?;
1192        Self::verify_dshapes(
1193            &protos.dshape,
1194            &protos.shape,
1195            "Protos",
1196            &[
1197                DimName::Batch,
1198                DimName::Height,
1199                DimName::Width,
1200                DimName::NumProtos,
1201            ],
1202        )?;
1203
1204        // Validate num_boxes match between detection and mask_coeff
1205        let det_num = Self::get_box_count(&detection.dshape).unwrap_or(detection.shape[2]);
1206        let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1207        if det_num != mask_num {
1208            return Err(DecoderError::InvalidConfig(format!(
1209                "Yolo 2-Way Detection num_boxes {} incompatible with Mask Coefficients num_boxes {}",
1210                det_num, mask_num
1211            )));
1212        }
1213
1214        // Validate mask_coeff channels match protos channels
1215        let mask_channels = if !mask_coeff.dshape.is_empty() {
1216            Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1217                DecoderError::InvalidConfig(
1218                    "Could not find num_protos in mask_coeff config".to_string(),
1219                )
1220            })?
1221        } else {
1222            mask_coeff.shape[1]
1223        };
1224        let proto_channels = if !protos.dshape.is_empty() {
1225            Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1226                DecoderError::InvalidConfig(
1227                    "Could not find num_protos in protos config".to_string(),
1228                )
1229            })?
1230        } else {
1231            protos.shape[1].min(protos.shape[3])
1232        };
1233        if mask_channels != proto_channels {
1234            return Err(DecoderError::InvalidConfig(format!(
1235                "Yolo 2-Way Protos channels {} incompatible with Mask Coefficients channels {}",
1236                proto_channels, mask_channels
1237            )));
1238        }
1239
1240        // Validate detection has classes (nc+4 features, no mask_coefs embedded)
1241        if !detection.dshape.is_empty() {
1242            Self::get_class_count(&detection.dshape, None, None)?;
1243        } else {
1244            Self::get_class_count_no_dshape(detection.into(), None)?;
1245        }
1246
1247        Ok(())
1248    }
1249
1250    fn verify_yolo_seg_det_26(
1251        detection: &configs::Detection,
1252        protos: &configs::Protos,
1253    ) -> Result<(), DecoderError> {
1254        if detection.shape.len() != 3 {
1255            return Err(DecoderError::InvalidConfig(format!(
1256                "Invalid Yolo Detection shape {:?}",
1257                detection.shape
1258            )));
1259        }
1260        if protos.shape.len() != 4 {
1261            return Err(DecoderError::InvalidConfig(format!(
1262                "Invalid Yolo Protos shape {:?}",
1263                protos.shape
1264            )));
1265        }
1266
1267        Self::verify_dshapes(
1268            &detection.dshape,
1269            &detection.shape,
1270            "Detection",
1271            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1272        )?;
1273        Self::verify_dshapes(
1274            &protos.dshape,
1275            &protos.shape,
1276            "Protos",
1277            &[
1278                DimName::Batch,
1279                DimName::Height,
1280                DimName::Width,
1281                DimName::NumProtos,
1282            ],
1283        )?;
1284
1285        let protos_count = Self::get_protos_count(&protos.dshape)
1286            .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1287        log::debug!("Protos count: {}", protos_count);
1288        log::debug!("Detection dshape: {:?}", detection.dshape);
1289
1290        if !detection.shape.contains(&(6 + protos_count)) {
1291            return Err(DecoderError::InvalidConfig(format!(
1292                "Yolo26 Segmentation Detection must have num_features be 6 + num_protos = {}",
1293                6 + protos_count
1294            )));
1295        }
1296
1297        Ok(())
1298    }
1299
1300    fn verify_yolo_split_det(
1301        boxes: &configs::Boxes,
1302        scores: &configs::Scores,
1303    ) -> Result<(), DecoderError> {
1304        if boxes.shape.len() != 3 {
1305            return Err(DecoderError::InvalidConfig(format!(
1306                "Invalid Yolo Split Boxes shape {:?}",
1307                boxes.shape
1308            )));
1309        }
1310        if scores.shape.len() != 3 {
1311            return Err(DecoderError::InvalidConfig(format!(
1312                "Invalid Yolo Split Scores shape {:?}",
1313                scores.shape
1314            )));
1315        }
1316
1317        Self::verify_dshapes(
1318            &boxes.dshape,
1319            &boxes.shape,
1320            "Boxes",
1321            &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1322        )?;
1323        Self::verify_dshapes(
1324            &scores.dshape,
1325            &scores.shape,
1326            "Scores",
1327            &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1328        )?;
1329
1330        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1331        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1332
1333        if boxes_num != scores_num {
1334            return Err(DecoderError::InvalidConfig(format!(
1335                "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1336                boxes_num, scores_num
1337            )));
1338        }
1339
1340        Ok(())
1341    }
1342
1343    fn verify_yolo_split_segdet(
1344        boxes: &configs::Boxes,
1345        scores: &configs::Scores,
1346        mask_coeff: &configs::MaskCoefficients,
1347        protos: &configs::Protos,
1348    ) -> Result<(), DecoderError> {
1349        if boxes.shape.len() != 3 {
1350            return Err(DecoderError::InvalidConfig(format!(
1351                "Invalid Yolo Split Boxes shape {:?}",
1352                boxes.shape
1353            )));
1354        }
1355        if scores.shape.len() != 3 {
1356            return Err(DecoderError::InvalidConfig(format!(
1357                "Invalid Yolo Split Scores shape {:?}",
1358                scores.shape
1359            )));
1360        }
1361
1362        if mask_coeff.shape.len() != 3 {
1363            return Err(DecoderError::InvalidConfig(format!(
1364                "Invalid Yolo Split Mask Coefficients shape {:?}",
1365                mask_coeff.shape
1366            )));
1367        }
1368
1369        if protos.shape.len() != 4 {
1370            return Err(DecoderError::InvalidConfig(format!(
1371                "Invalid Yolo Protos shape {:?}",
1372                mask_coeff.shape
1373            )));
1374        }
1375
1376        Self::verify_dshapes(
1377            &boxes.dshape,
1378            &boxes.shape,
1379            "Boxes",
1380            &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1381        )?;
1382        Self::verify_dshapes(
1383            &scores.dshape,
1384            &scores.shape,
1385            "Scores",
1386            &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1387        )?;
1388        Self::verify_dshapes(
1389            &mask_coeff.dshape,
1390            &mask_coeff.shape,
1391            "Mask Coefficients",
1392            &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1393        )?;
1394        Self::verify_dshapes(
1395            &protos.dshape,
1396            &protos.shape,
1397            "Protos",
1398            &[
1399                DimName::Batch,
1400                DimName::Height,
1401                DimName::Width,
1402                DimName::NumProtos,
1403            ],
1404        )?;
1405
1406        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1407        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1408        let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1409
1410        let mask_channels = if !mask_coeff.dshape.is_empty() {
1411            Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1412                DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1413            })?
1414        } else {
1415            mask_coeff.shape[1]
1416        };
1417        let proto_channels = if !protos.dshape.is_empty() {
1418            Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1419                DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1420            })?
1421        } else {
1422            protos.shape[1].min(protos.shape[3])
1423        };
1424
1425        if boxes_num != scores_num {
1426            return Err(DecoderError::InvalidConfig(format!(
1427                "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1428                boxes_num, scores_num
1429            )));
1430        }
1431
1432        if boxes_num != mask_num {
1433            return Err(DecoderError::InvalidConfig(format!(
1434                "Yolo Split Detection Boxes num {} incompatible with Mask Coefficients num {}",
1435                boxes_num, mask_num
1436            )));
1437        }
1438
1439        if proto_channels != mask_channels {
1440            return Err(DecoderError::InvalidConfig(format!(
1441                "Yolo Protos channels {} incompatible with Mask Coefficients channels {}",
1442                proto_channels, mask_channels
1443            )));
1444        }
1445
1446        Ok(())
1447    }
1448
1449    fn verify_yolo_split_end_to_end_det(
1450        boxes: &configs::Boxes,
1451        scores: &configs::Scores,
1452        classes: &configs::Classes,
1453    ) -> Result<(), DecoderError> {
1454        if boxes.shape.len() != 3 || !boxes.shape.contains(&4) {
1455            return Err(DecoderError::InvalidConfig(format!(
1456                "Split end-to-end boxes must be [batch, N, 4], got {:?}",
1457                boxes.shape
1458            )));
1459        }
1460        if scores.shape.len() != 3 || !scores.shape.contains(&1) {
1461            return Err(DecoderError::InvalidConfig(format!(
1462                "Split end-to-end scores must be [batch, N, 1], got {:?}",
1463                scores.shape
1464            )));
1465        }
1466        if classes.shape.len() != 3 || !classes.shape.contains(&1) {
1467            return Err(DecoderError::InvalidConfig(format!(
1468                "Split end-to-end classes must be [batch, N, 1], got {:?}",
1469                classes.shape
1470            )));
1471        }
1472        Ok(())
1473    }
1474
1475    fn verify_yolo_split_end_to_end_segdet(
1476        boxes: &configs::Boxes,
1477        scores: &configs::Scores,
1478        classes: &configs::Classes,
1479        mask_coeff: &configs::MaskCoefficients,
1480        protos: &configs::Protos,
1481    ) -> Result<(), DecoderError> {
1482        Self::verify_yolo_split_end_to_end_det(boxes, scores, classes)?;
1483        if mask_coeff.shape.len() != 3 {
1484            return Err(DecoderError::InvalidConfig(format!(
1485                "Invalid split end-to-end mask coefficients shape {:?}",
1486                mask_coeff.shape
1487            )));
1488        }
1489        if protos.shape.len() != 4 {
1490            return Err(DecoderError::InvalidConfig(format!(
1491                "Invalid protos shape {:?}",
1492                protos.shape
1493            )));
1494        }
1495        Ok(())
1496    }
1497
1498    fn get_model_type_modelpack(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1499        let mut split_decoders = Vec::new();
1500        let mut segment_ = None;
1501        let mut scores_ = None;
1502        let mut boxes_ = None;
1503        for c in configs.outputs {
1504            match c {
1505                ConfigOutput::Detection(detection) => split_decoders.push(detection),
1506                ConfigOutput::Segmentation(segmentation) => segment_ = Some(segmentation),
1507                ConfigOutput::Mask(_) => {}
1508                ConfigOutput::Protos(_) => {
1509                    return Err(DecoderError::InvalidConfig(
1510                        "ModelPack should not have protos".to_string(),
1511                    ));
1512                }
1513                ConfigOutput::Scores(scores) => scores_ = Some(scores),
1514                ConfigOutput::Boxes(boxes) => boxes_ = Some(boxes),
1515                ConfigOutput::MaskCoefficients(_) => {
1516                    return Err(DecoderError::InvalidConfig(
1517                        "ModelPack should not have mask coefficients".to_string(),
1518                    ));
1519                }
1520                ConfigOutput::Classes(_) => {
1521                    return Err(DecoderError::InvalidConfig(
1522                        "ModelPack should not have classes output".to_string(),
1523                    ));
1524                }
1525            }
1526        }
1527
1528        if let Some(segmentation) = segment_ {
1529            if !split_decoders.is_empty() {
1530                let classes = Self::verify_modelpack_split_det(&split_decoders)?;
1531                Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1532                Ok(ModelType::ModelPackSegDetSplit {
1533                    detection: split_decoders,
1534                    segmentation,
1535                })
1536            } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1537                let classes = Self::verify_modelpack_det(&boxes, &scores)?;
1538                Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1539                Ok(ModelType::ModelPackSegDet {
1540                    boxes,
1541                    scores,
1542                    segmentation,
1543                })
1544            } else {
1545                Self::verify_modelpack_seg(&segmentation, None)?;
1546                Ok(ModelType::ModelPackSeg { segmentation })
1547            }
1548        } else if !split_decoders.is_empty() {
1549            Self::verify_modelpack_split_det(&split_decoders)?;
1550            Ok(ModelType::ModelPackDetSplit {
1551                detection: split_decoders,
1552            })
1553        } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1554            Self::verify_modelpack_det(&boxes, &scores)?;
1555            Ok(ModelType::ModelPackDet { boxes, scores })
1556        } else {
1557            Err(DecoderError::InvalidConfig(
1558                "Invalid ModelPack model outputs".to_string(),
1559            ))
1560        }
1561    }
1562
1563    fn verify_modelpack_det(
1564        boxes: &configs::Boxes,
1565        scores: &configs::Scores,
1566    ) -> Result<usize, DecoderError> {
1567        if boxes.shape.len() != 4 {
1568            return Err(DecoderError::InvalidConfig(format!(
1569                "Invalid ModelPack Boxes shape {:?}",
1570                boxes.shape
1571            )));
1572        }
1573        if scores.shape.len() != 3 {
1574            return Err(DecoderError::InvalidConfig(format!(
1575                "Invalid ModelPack Scores shape {:?}",
1576                scores.shape
1577            )));
1578        }
1579
1580        Self::verify_dshapes(
1581            &boxes.dshape,
1582            &boxes.shape,
1583            "Boxes",
1584            &[
1585                DimName::Batch,
1586                DimName::NumBoxes,
1587                DimName::Padding,
1588                DimName::BoxCoords,
1589            ],
1590        )?;
1591        Self::verify_dshapes(
1592            &scores.dshape,
1593            &scores.shape,
1594            "Scores",
1595            &[DimName::Batch, DimName::NumBoxes, DimName::NumClasses],
1596        )?;
1597
1598        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[1]);
1599        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[1]);
1600
1601        if boxes_num != scores_num {
1602            return Err(DecoderError::InvalidConfig(format!(
1603                "ModelPack Detection Boxes num {} incompatible with Scores num {}",
1604                boxes_num, scores_num
1605            )));
1606        }
1607
1608        let num_classes = if !scores.dshape.is_empty() {
1609            Self::get_class_count(&scores.dshape, None, None)?
1610        } else {
1611            Self::get_class_count_no_dshape(scores.into(), None)?
1612        };
1613
1614        Ok(num_classes)
1615    }
1616
1617    fn verify_modelpack_split_det(boxes: &[configs::Detection]) -> Result<usize, DecoderError> {
1618        let mut num_classes = None;
1619        for b in boxes {
1620            let Some(num_anchors) = b.anchors.as_ref().map(|a| a.len()) else {
1621                return Err(DecoderError::InvalidConfig(
1622                    "ModelPack Split Detection missing anchors".to_string(),
1623                ));
1624            };
1625
1626            if num_anchors == 0 {
1627                return Err(DecoderError::InvalidConfig(
1628                    "ModelPack Split Detection has zero anchors".to_string(),
1629                ));
1630            }
1631
1632            if b.shape.len() != 4 {
1633                return Err(DecoderError::InvalidConfig(format!(
1634                    "Invalid ModelPack Split Detection shape {:?}",
1635                    b.shape
1636                )));
1637            }
1638
1639            Self::verify_dshapes(
1640                &b.dshape,
1641                &b.shape,
1642                "Split Detection",
1643                &[
1644                    DimName::Batch,
1645                    DimName::Height,
1646                    DimName::Width,
1647                    DimName::NumAnchorsXFeatures,
1648                ],
1649            )?;
1650            let classes = if !b.dshape.is_empty() {
1651                Self::get_class_count(&b.dshape, None, Some(num_anchors))?
1652            } else {
1653                Self::get_class_count_no_dshape(b.into(), None)?
1654            };
1655
1656            match num_classes {
1657                Some(n) => {
1658                    if n != classes {
1659                        return Err(DecoderError::InvalidConfig(format!(
1660                            "ModelPack Split Detection inconsistent number of classes: previous {}, current {}",
1661                            n, classes
1662                        )));
1663                    }
1664                }
1665                None => {
1666                    num_classes = Some(classes);
1667                }
1668            }
1669        }
1670
1671        Ok(num_classes.unwrap_or(0))
1672    }
1673
1674    fn verify_modelpack_seg(
1675        segmentation: &configs::Segmentation,
1676        classes: Option<usize>,
1677    ) -> Result<(), DecoderError> {
1678        if segmentation.shape.len() != 4 {
1679            return Err(DecoderError::InvalidConfig(format!(
1680                "Invalid ModelPack Segmentation shape {:?}",
1681                segmentation.shape
1682            )));
1683        }
1684        Self::verify_dshapes(
1685            &segmentation.dshape,
1686            &segmentation.shape,
1687            "Segmentation",
1688            &[
1689                DimName::Batch,
1690                DimName::Height,
1691                DimName::Width,
1692                DimName::NumClasses,
1693            ],
1694        )?;
1695
1696        if let Some(classes) = classes {
1697            let seg_classes = if !segmentation.dshape.is_empty() {
1698                Self::get_class_count(&segmentation.dshape, None, None)?
1699            } else {
1700                Self::get_class_count_no_dshape(segmentation.into(), None)?
1701            };
1702
1703            if seg_classes != classes + 1 {
1704                return Err(DecoderError::InvalidConfig(format!(
1705                    "ModelPack Segmentation channels {} incompatible with number of classes {}",
1706                    seg_classes, classes
1707                )));
1708            }
1709        }
1710        Ok(())
1711    }
1712
1713    // verifies that dshapes match the given shape
1714    fn verify_dshapes(
1715        dshape: &[(DimName, usize)],
1716        shape: &[usize],
1717        name: &str,
1718        dims: &[DimName],
1719    ) -> Result<(), DecoderError> {
1720        for s in shape {
1721            if *s == 0 {
1722                return Err(DecoderError::InvalidConfig(format!(
1723                    "{} shape has zero dimension",
1724                    name
1725                )));
1726            }
1727        }
1728
1729        if shape.len() != dims.len() {
1730            return Err(DecoderError::InvalidConfig(format!(
1731                "{} shape length {} does not match expected dims length {}",
1732                name,
1733                shape.len(),
1734                dims.len()
1735            )));
1736        }
1737
1738        if dshape.is_empty() {
1739            return Ok(());
1740        }
1741        // check the dshape lengths match the shape lengths
1742        if dshape.len() != shape.len() {
1743            return Err(DecoderError::InvalidConfig(format!(
1744                "{} dshape length does not match shape length",
1745                name
1746            )));
1747        }
1748
1749        // check the dshape values match the shape values
1750        for ((dim_name, dim_size), shape_size) in dshape.iter().zip(shape) {
1751            if dim_size != shape_size {
1752                return Err(DecoderError::InvalidConfig(format!(
1753                    "{} dshape dimension {} size {} does not match shape size {}",
1754                    name, dim_name, dim_size, shape_size
1755                )));
1756            }
1757            if *dim_name == DimName::Padding && *dim_size != 1 {
1758                return Err(DecoderError::InvalidConfig(
1759                    "Padding dimension size must be 1".to_string(),
1760                ));
1761            }
1762
1763            if *dim_name == DimName::BoxCoords && *dim_size != 4 {
1764                return Err(DecoderError::InvalidConfig(
1765                    "BoxCoords dimension size must be 4".to_string(),
1766                ));
1767            }
1768        }
1769
1770        let dims_present = HashSet::<DimName>::from_iter(dshape.iter().map(|(name, _)| *name));
1771        for dim in dims {
1772            if !dims_present.contains(dim) {
1773                return Err(DecoderError::InvalidConfig(format!(
1774                    "{} dshape missing required dimension {:?}",
1775                    name, dim
1776                )));
1777            }
1778        }
1779
1780        Ok(())
1781    }
1782
1783    fn get_box_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1784        for (dim_name, dim_size) in dshape {
1785            if *dim_name == DimName::NumBoxes {
1786                return Some(*dim_size);
1787            }
1788        }
1789        None
1790    }
1791
1792    fn get_class_count_no_dshape(
1793        config: ConfigOutputRef,
1794        protos: Option<usize>,
1795    ) -> Result<usize, DecoderError> {
1796        match config {
1797            ConfigOutputRef::Detection(detection) => match detection.decoder {
1798                DecoderType::Ultralytics => {
1799                    if detection.shape[1] <= 4 + protos.unwrap_or(0) {
1800                        return Err(DecoderError::InvalidConfig(format!(
1801                            "Invalid shape: Yolo num_features {} must be greater than {}",
1802                            detection.shape[1],
1803                            4 + protos.unwrap_or(0),
1804                        )));
1805                    }
1806                    Ok(detection.shape[1] - 4 - protos.unwrap_or(0))
1807                }
1808                DecoderType::ModelPack => {
1809                    let Some(num_anchors) = detection.anchors.as_ref().map(|a| a.len()) else {
1810                        return Err(DecoderError::Internal(
1811                            "ModelPack Detection missing anchors".to_string(),
1812                        ));
1813                    };
1814                    let anchors_x_features = detection.shape[3];
1815                    if anchors_x_features <= num_anchors * 5 {
1816                        return Err(DecoderError::InvalidConfig(format!(
1817                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
1818                            anchors_x_features,
1819                            num_anchors * 5,
1820                        )));
1821                    }
1822
1823                    if !anchors_x_features.is_multiple_of(num_anchors) {
1824                        return Err(DecoderError::InvalidConfig(format!(
1825                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
1826                            anchors_x_features, num_anchors
1827                        )));
1828                    }
1829                    Ok(anchors_x_features / num_anchors - 5)
1830                }
1831            },
1832
1833            ConfigOutputRef::Scores(scores) => match scores.decoder {
1834                DecoderType::Ultralytics => Ok(scores.shape[1]),
1835                DecoderType::ModelPack => Ok(scores.shape[2]),
1836            },
1837            ConfigOutputRef::Segmentation(seg) => Ok(seg.shape[3]),
1838            _ => Err(DecoderError::Internal(
1839                "Attempted to get class count from unsupported config output".to_owned(),
1840            )),
1841        }
1842    }
1843
1844    // get the class count from dshape or calculate from num_features
1845    fn get_class_count(
1846        dshape: &[(DimName, usize)],
1847        protos: Option<usize>,
1848        anchors: Option<usize>,
1849    ) -> Result<usize, DecoderError> {
1850        if dshape.is_empty() {
1851            return Ok(0);
1852        }
1853        // if it has num_classes in dshape, return it
1854        for (dim_name, dim_size) in dshape {
1855            if *dim_name == DimName::NumClasses {
1856                return Ok(*dim_size);
1857            }
1858        }
1859
1860        // number of classes can be calculated from num_features - 4 for yolo.  If the
1861        // model has protos, we also subtract the number of protos.
1862        for (dim_name, dim_size) in dshape {
1863            if *dim_name == DimName::NumFeatures {
1864                let protos = protos.unwrap_or(0);
1865                if protos + 4 >= *dim_size {
1866                    return Err(DecoderError::InvalidConfig(format!(
1867                        "Invalid shape: Yolo num_features {} must be greater than {}",
1868                        *dim_size,
1869                        protos + 4,
1870                    )));
1871                }
1872                return Ok(*dim_size - 4 - protos);
1873            }
1874        }
1875
1876        // number of classes can be calculated from number of anchors for modelpack
1877        // split detection
1878        if let Some(num_anchors) = anchors {
1879            for (dim_name, dim_size) in dshape {
1880                if *dim_name == DimName::NumAnchorsXFeatures {
1881                    let anchors_x_features = *dim_size;
1882                    if anchors_x_features <= num_anchors * 5 {
1883                        return Err(DecoderError::InvalidConfig(format!(
1884                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
1885                            anchors_x_features,
1886                            num_anchors * 5,
1887                        )));
1888                    }
1889
1890                    if !anchors_x_features.is_multiple_of(num_anchors) {
1891                        return Err(DecoderError::InvalidConfig(format!(
1892                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
1893                            anchors_x_features, num_anchors
1894                        )));
1895                    }
1896                    return Ok((anchors_x_features / num_anchors) - 5);
1897                }
1898            }
1899        }
1900        Err(DecoderError::InvalidConfig(
1901            "Cannot determine number of classes from dshape".to_owned(),
1902        ))
1903    }
1904
1905    fn get_protos_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1906        for (dim_name, dim_size) in dshape {
1907            if *dim_name == DimName::NumProtos {
1908                return Some(*dim_size);
1909            }
1910        }
1911        None
1912    }
1913}