Skip to main content

edgefirst_decoder/decoder/
builder.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashSet;
5
6use super::config::ConfigOutputRef;
7use super::configs::{self, DecoderType, DecoderVersion, DimName, ModelType};
8use super::merge::DecodeProgram;
9use super::{ConfigOutput, ConfigOutputs, Decoder};
10use crate::per_scale::{DecodeDtype, PerScalePlan};
11use crate::schema::SchemaV2;
12use crate::DecoderError;
13
14/// Extract `(width, height)` from a schema [`crate::schema::InputSpec`].
15///
16/// Prefers named dims (`DimName::Width` / `DimName::Height`) when the
17/// `dshape` is populated, falling back to the NHWC convention
18/// (`shape[1] = H, shape[2] = W`) for 4-element shapes whenever either
19/// named dim is missing. Returns `None` for any other shape arity — the
20/// decoder treats that as "input dims unknown" and skips the EDGEAI-1303
21/// normalization path.
22///
23/// The fallback fires when **either** `Height` or `Width` is missing from
24/// the dshape (not only when both are absent), so a partially-named
25/// dshape (e.g. only `Width`) still resolves both dims via positional
26/// inference instead of silently disabling normalization.
27fn input_dims_from_spec(input: &crate::schema::InputSpec) -> Option<(usize, usize)> {
28    use crate::configs::DimName;
29    let mut h = None;
30    let mut w = None;
31    // `SchemaV2::validate()` doesn't currently enforce
32    // `dshape.len() <= shape.len()`, so a malformed schema can trip an
33    // out-of-bounds index here. Use `shape.get(i)` to silently skip
34    // dshape entries that overflow the shape — the caller treats
35    // missing dims as "unknown" and disables EDGEAI-1303 normalization.
36    for (i, (name, _)) in input.dshape.iter().enumerate() {
37        match name {
38            DimName::Height => h = input.shape.get(i).copied(),
39            DimName::Width => w = input.shape.get(i).copied(),
40            _ => {}
41        }
42    }
43    if (h.is_none() || w.is_none()) && input.shape.len() == 4 {
44        // NHWC default: [N, H, W, C]. Mirrors the per-scale `extract_hw`
45        // fallback (`crates/decoder/src/per_scale/plan.rs::extract_hw`).
46        // Only fill the missing axis so a partial named dshape still
47        // resolves both dims.
48        h = h.or(Some(input.shape[1]));
49        w = w.or(Some(input.shape[2]));
50    }
51    match (w, h) {
52        (Some(w), Some(h)) => Some((w, h)),
53        _ => None,
54    }
55}
56
#[cfg(test)]
mod input_dims_from_spec_tests {
    use super::input_dims_from_spec;
    use crate::configs::DimName;
    use crate::schema::InputSpec;

    /// Assemble an `InputSpec` without a camera adaptor from a positional
    /// shape and its (possibly empty or malformed) named dims.
    fn spec_with(shape: &[usize], dshape: Vec<(DimName, usize)>) -> InputSpec {
        InputSpec {
            shape: shape.to_vec(),
            dshape,
            cameraadaptor: None,
        }
    }

    #[test]
    fn named_dshape_resolves_dims() {
        let named = vec![
            (DimName::Batch, 1),
            (DimName::Height, 480),
            (DimName::Width, 640),
            (DimName::NumFeatures, 3),
        ];
        let dims = input_dims_from_spec(&spec_with(&[1, 480, 640, 3], named));
        assert_eq!(dims, Some((640, 480)));
    }

    #[test]
    fn empty_dshape_falls_back_to_nhwc_for_4d() {
        // No names at all: a 4-element shape is read as [N, H, W, C].
        let dims = input_dims_from_spec(&spec_with(&[1, 480, 640, 3], Vec::new()));
        assert_eq!(dims, Some((640, 480)));
    }

    #[test]
    fn malformed_dshape_longer_than_shape_does_not_panic() {
        // Regression for Copilot review on PR #63: indexing
        // `input.shape[i]` while iterating dshape can OOB-panic when
        // `dshape.len() > shape.len()`. The fix uses `shape.get(i)`
        // and silently treats overflow as "dim missing".
        let overlong = vec![
            (DimName::Width, 640),
            (DimName::Height, 480),
            (DimName::NumFeatures, 3), // overflow — index 2 ≥ shape.len()
        ];
        // 2-D shape: the first two dshape entries resolve via
        // `shape.get()`; the third is a no-op. Width/Height both
        // resolved, so we expect Some.
        let dims = input_dims_from_spec(&spec_with(&[640, 480], overlong));
        assert_eq!(dims, Some((640, 480)));
    }

    #[test]
    fn malformed_dshape_only_overflow_returns_none() {
        // Both Height and Width sit past the shape boundary — they stay
        // None and the 4-D NHWC fallback doesn't fire
        // (shape.len() == 1), so we get None instead of a panic.
        let overflowing = vec![
            (DimName::NumFeatures, 3),
            (DimName::Height, 480),
            (DimName::Width, 640),
        ];
        let dims = input_dims_from_spec(&spec_with(&[1], overflowing));
        assert_eq!(dims, None);
    }
}
125
/// Builder that collects a model configuration source plus decode
/// settings; calling `build()` on it yields the `Decoder`.
///
/// Start from [`DecoderBuilder::new`] / [`Default::default`], attach a
/// configuration with one of the `with_config*`, `with_schema`, or
/// `add_output` methods, and adjust the remaining settings as needed.
#[derive(Debug, Clone, PartialEq)]
pub struct DecoderBuilder {
    /// Where the model configuration comes from. `None` until one of the
    /// configuration methods is called; each such call overwrites the
    /// previous source.
    config_src: Option<ConfigSource>,
    /// IoU threshold; defaults to 0.5.
    /// NOTE(review): presumably the overlap cut-off used by NMS — confirm
    /// against the decoder.
    iou_threshold: f32,
    /// Score threshold; defaults to 0.5.
    /// NOTE(review): presumably the minimum confidence for a detection to
    /// be kept — confirm against the decoder.
    score_threshold: f32,
    /// NMS mode: Some(mode) applies NMS, None bypasses NMS (for end-to-end
    /// models)
    nms: Option<configs::Nms>,
    /// Output dtype for the per-scale fast path. Has no effect on
    /// schemas without per-scale children (which use the legacy decode
    /// path).
    decode_dtype: DecodeDtype,
    /// Defaults to 300.
    /// NOTE(review): presumably the number of top-scoring candidates kept
    /// before NMS runs — confirm against the decoder.
    pre_nms_top_k: usize,
    /// Defaults to 300.
    /// NOTE(review): presumably the maximum number of detections returned
    /// — confirm against the decoder.
    max_det: usize,
    /// Explicit override for the model input dimensions `(width, height)`,
    /// consumed by EDGEAI-1303 normalization. When set, takes precedence
    /// over schema-derived dims; when `None`, the value is read from the
    /// schema's `input.shape` / `input.dshape` at build time.
    input_dims: Option<(usize, usize)>,
}
146
/// The form in which a model configuration was handed to the
/// [`DecoderBuilder`]. The builder only stores it; interpretation is
/// deferred to `build()`.
#[derive(Debug, Clone, PartialEq)]
enum ConfigSource {
    /// Raw YAML text, stored unparsed by `with_config_yaml_str`.
    Yaml(String),
    /// Raw JSON text, stored unparsed by `with_config_json_str`.
    Json(String),
    /// An already-deserialized configuration, either passed in whole
    /// (`with_config` and the `with_config_*` convenience methods) or
    /// assembled incrementally via `add_output` / `with_decoder_version`.
    Config(ConfigOutputs),
    /// Schema v2 metadata. During build the schema is either converted
    /// to a legacy [`ConfigOutputs`] (flat case) or used to construct a
    /// [`DecodeProgram`] that performs per-child dequant + merge at
    /// decode time.
    Schema(SchemaV2),
}
158
159impl Default for DecoderBuilder {
160    /// Creates a default DecoderBuilder with no configuration and 0.5 score
161    /// threshold and 0.5 IoU threshold.
162    ///
163    /// A valid configuration must be provided before building the Decoder.
164    ///
165    /// # Examples
166    /// ```rust
167    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
168    /// # fn main() -> DecoderResult<()> {
169    /// #  let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
170    /// let decoder = DecoderBuilder::default()
171    ///     .with_config_yaml_str(config_yaml)
172    ///     .build()?;
173    /// assert_eq!(decoder.score_threshold, 0.5);
174    /// assert_eq!(decoder.iou_threshold, 0.5);
175    ///
176    /// # Ok(())
177    /// # }
178    /// ```
179    fn default() -> Self {
180        Self {
181            config_src: None,
182            iou_threshold: 0.5,
183            score_threshold: 0.5,
184            nms: Some(configs::Nms::ClassAgnostic),
185            decode_dtype: DecodeDtype::F32,
186            pre_nms_top_k: 300,
187            max_det: 300,
188            input_dims: None,
189        }
190    }
191}
192
193impl DecoderBuilder {
194    /// Creates a default DecoderBuilder with no configuration and 0.5 score
195    /// threshold and 0.5 IoU threshold.
196    ///
197    /// A valid configuration must be provided before building the Decoder.
198    ///
199    /// # Examples
200    /// ```rust
201    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
202    /// # fn main() -> DecoderResult<()> {
203    /// #  let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
204    /// let decoder = DecoderBuilder::new()
205    ///     .with_config_yaml_str(config_yaml)
206    ///     .build()?;
207    /// assert_eq!(decoder.score_threshold, 0.5);
208    /// assert_eq!(decoder.iou_threshold, 0.5);
209    ///
210    /// # Ok(())
211    /// # }
212    /// ```
213    pub fn new() -> Self {
214        Self::default()
215    }
216
217    /// Loads a model configuration in YAML format. Does not check if the string
218    /// is a correct configuration file. Use `DecoderBuilder.build()` to
219    /// deserialize the YAML and parse the model configuration.
220    ///
221    /// # Examples
222    /// ```rust
223    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
224    /// # fn main() -> DecoderResult<()> {
225    /// let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
226    /// let decoder = DecoderBuilder::new()
227    ///     .with_config_yaml_str(config_yaml)
228    ///     .build()?;
229    ///
230    /// # Ok(())
231    /// # }
232    /// ```
233    pub fn with_config_yaml_str(mut self, yaml_str: String) -> Self {
234        self.config_src.replace(ConfigSource::Yaml(yaml_str));
235        self
236    }
237
238    /// Loads a model configuration in JSON format. Does not check if the string
239    /// is a correct configuration file. Use `DecoderBuilder.build()` to
240    /// deserialize the JSON and parse the model configuration.
241    ///
242    /// # Examples
243    /// ```rust
244    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
245    /// # fn main() -> DecoderResult<()> {
246    /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
247    /// let decoder = DecoderBuilder::new()
248    ///     .with_config_json_str(config_json)
249    ///     .build()?;
250    ///
251    /// # Ok(())
252    /// # }
253    /// ```
254    pub fn with_config_json_str(mut self, json_str: String) -> Self {
255        self.config_src.replace(ConfigSource::Json(json_str));
256        self
257    }
258
259    /// Loads a model configuration. Does not check if the configuration is
260    /// correct. Intended to be used when the user needs control over the
261    /// deserialize of the configuration information. Use
262    /// `DecoderBuilder.build()` to parse the model configuration.
263    ///
264    /// # Examples
265    /// ```rust
266    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
267    /// # fn main() -> DecoderResult<()> {
268    /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json"));
269    /// let config = serde_json::from_str(config_json)?;
270    /// let decoder = DecoderBuilder::new().with_config(config).build()?;
271    ///
272    /// # Ok(())
273    /// # }
274    /// ```
275    pub fn with_config(mut self, config: ConfigOutputs) -> Self {
276        self.config_src.replace(ConfigSource::Config(config));
277        self
278    }
279
280    /// Configure the decoder from a schema v2 metadata document.
281    ///
282    /// Accepts a [`SchemaV2`] as produced by [`SchemaV2::parse_json`],
283    /// [`SchemaV2::parse_yaml`], [`SchemaV2::parse_file`], or
284    /// constructed programmatically. The builder validates the schema,
285    /// compiles a [`DecodeProgram`] for any split logical outputs
286    /// (per-scale or channel sub-splits), and downconverts the
287    /// logical-level semantics to the legacy [`ConfigOutputs`]
288    /// representation consumed by the existing decoder dispatch.
289    ///
290    /// # Examples
291    ///
292    /// ```rust,no_run
293    /// use edgefirst_decoder::{DecoderBuilder, DecoderResult};
294    /// use edgefirst_decoder::schema::SchemaV2;
295    ///
296    /// # fn main() -> DecoderResult<()> {
297    /// let schema = SchemaV2::parse_file("model/edgefirst.json")?;
298    /// let decoder = DecoderBuilder::new()
299    ///     .with_schema(schema)
300    ///     .with_score_threshold(0.25)
301    ///     .build()?;
302    /// # Ok(())
303    /// # }
304    /// ```
305    pub fn with_schema(mut self, schema: SchemaV2) -> Self {
306        self.config_src.replace(ConfigSource::Schema(schema));
307        self
308    }
309
310    /// Choose the output dtype for the per-scale decoder pipeline.
311    ///
312    /// Defaults to [`DecodeDtype::F32`]. Has no effect on schemas
313    /// without per-scale children (which use the legacy decode path).
314    /// `F16` saves ~2× memory bandwidth at the cost of 10-bit mantissa
315    /// precision; empirically safe for YOLO-family models.
316    ///
317    /// # Examples
318    ///
319    /// ```rust,no_run
320    /// use edgefirst_decoder::{DecodeDtype, DecoderBuilder, DecoderResult};
321    /// use edgefirst_decoder::schema::SchemaV2;
322    ///
323    /// # fn main() -> DecoderResult<()> {
324    /// let schema = SchemaV2::parse_file("model/edgefirst.json")?;
325    /// let decoder = DecoderBuilder::new()
326    ///     .with_schema(schema)
327    ///     .with_decode_dtype(DecodeDtype::F32)
328    ///     .build()?;
329    /// # Ok(())
330    /// # }
331    /// ```
332    pub fn with_decode_dtype(mut self, dtype: DecodeDtype) -> Self {
333        self.decode_dtype = dtype;
334        self
335    }
336
337    /// Loads a YOLO detection model configuration.  Use
338    /// `DecoderBuilder.build()` to parse the model configuration.
339    ///
340    /// # Examples
341    /// ```rust
342    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
343    /// # fn main() -> DecoderResult<()> {
344    /// let decoder = DecoderBuilder::new()
345    ///     .with_config_yolo_det(
346    ///         configs::Detection {
347    ///             anchors: None,
348    ///             decoder: configs::DecoderType::Ultralytics,
349    ///             quantization: Some(configs::QuantTuple(0.012345, 26)),
350    ///             shape: vec![1, 84, 8400],
351    ///             dshape: Vec::new(),
352    ///             normalized: Some(true),
353    ///         },
354    ///         None,
355    ///     )
356    ///     .build()?;
357    ///
358    /// # Ok(())
359    /// # }
360    /// ```
361    pub fn with_config_yolo_det(
362        mut self,
363        boxes: configs::Detection,
364        version: Option<DecoderVersion>,
365    ) -> Self {
366        let config = ConfigOutputs {
367            outputs: vec![ConfigOutput::Detection(boxes)],
368            decoder_version: version,
369            ..Default::default()
370        };
371        self.config_src.replace(ConfigSource::Config(config));
372        self
373    }
374
375    /// Loads a YOLO split detection model configuration.  Use
376    /// `DecoderBuilder.build()` to parse the model configuration.
377    ///
378    /// # Examples
379    /// ```rust
380    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
381    /// # fn main() -> DecoderResult<()> {
382    /// let boxes_config = configs::Boxes {
383    ///     decoder: configs::DecoderType::Ultralytics,
384    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
385    ///     shape: vec![1, 4, 8400],
386    ///     dshape: Vec::new(),
387    ///     normalized: Some(true),
388    /// };
389    /// let scores_config = configs::Scores {
390    ///     decoder: configs::DecoderType::Ultralytics,
391    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
392    ///     shape: vec![1, 80, 8400],
393    ///     dshape: Vec::new(),
394    /// };
395    /// let decoder = DecoderBuilder::new()
396    ///     .with_config_yolo_split_det(boxes_config, scores_config)
397    ///     .build()?;
398    /// # Ok(())
399    /// # }
400    /// ```
401    pub fn with_config_yolo_split_det(
402        mut self,
403        boxes: configs::Boxes,
404        scores: configs::Scores,
405    ) -> Self {
406        let config = ConfigOutputs {
407            outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
408            ..Default::default()
409        };
410        self.config_src.replace(ConfigSource::Config(config));
411        self
412    }
413
414    /// Loads a YOLO segmentation model configuration.  Use
415    /// `DecoderBuilder.build()` to parse the model configuration.
416    ///
417    /// # Examples
418    /// ```rust
419    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
420    /// # fn main() -> DecoderResult<()> {
421    /// let seg_config = configs::Detection {
422    ///     decoder: configs::DecoderType::Ultralytics,
423    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
424    ///     shape: vec![1, 116, 8400],
425    ///     anchors: None,
426    ///     dshape: Vec::new(),
427    ///     normalized: Some(true),
428    /// };
429    /// let protos_config = configs::Protos {
430    ///     decoder: configs::DecoderType::Ultralytics,
431    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
432    ///     shape: vec![1, 160, 160, 32],
433    ///     dshape: Vec::new(),
434    /// };
435    /// let decoder = DecoderBuilder::new()
436    ///     .with_config_yolo_segdet(
437    ///         seg_config,
438    ///         protos_config,
439    ///         Some(configs::DecoderVersion::Yolov8),
440    ///     )
441    ///     .build()?;
442    /// # Ok(())
443    /// # }
444    /// ```
445    pub fn with_config_yolo_segdet(
446        mut self,
447        boxes: configs::Detection,
448        protos: configs::Protos,
449        version: Option<DecoderVersion>,
450    ) -> Self {
451        let config = ConfigOutputs {
452            outputs: vec![ConfigOutput::Detection(boxes), ConfigOutput::Protos(protos)],
453            decoder_version: version,
454            ..Default::default()
455        };
456        self.config_src.replace(ConfigSource::Config(config));
457        self
458    }
459
460    /// Loads a YOLO split segmentation model configuration.  Use
461    /// `DecoderBuilder.build()` to parse the model configuration.
462    ///
463    /// # Examples
464    /// ```rust
465    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
466    /// # fn main() -> DecoderResult<()> {
467    /// let boxes_config = configs::Boxes {
468    ///     decoder: configs::DecoderType::Ultralytics,
469    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
470    ///     shape: vec![1, 4, 8400],
471    ///     dshape: Vec::new(),
472    ///     normalized: Some(true),
473    /// };
474    /// let scores_config = configs::Scores {
475    ///     decoder: configs::DecoderType::Ultralytics,
476    ///     quantization: Some(configs::QuantTuple(0.012345, 14)),
477    ///     shape: vec![1, 80, 8400],
478    ///     dshape: Vec::new(),
479    /// };
480    /// let mask_config = configs::MaskCoefficients {
481    ///     decoder: configs::DecoderType::Ultralytics,
482    ///     quantization: Some(configs::QuantTuple(0.0064123, 125)),
483    ///     shape: vec![1, 32, 8400],
484    ///     dshape: Vec::new(),
485    /// };
486    /// let protos_config = configs::Protos {
487    ///     decoder: configs::DecoderType::Ultralytics,
488    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
489    ///     shape: vec![1, 160, 160, 32],
490    ///     dshape: Vec::new(),
491    /// };
492    /// let decoder = DecoderBuilder::new()
493    ///     .with_config_yolo_split_segdet(boxes_config, scores_config, mask_config, protos_config)
494    ///     .build()?;
495    /// # Ok(())
496    /// # }
497    /// ```
498    pub fn with_config_yolo_split_segdet(
499        mut self,
500        boxes: configs::Boxes,
501        scores: configs::Scores,
502        mask_coefficients: configs::MaskCoefficients,
503        protos: configs::Protos,
504    ) -> Self {
505        let config = ConfigOutputs {
506            outputs: vec![
507                ConfigOutput::Boxes(boxes),
508                ConfigOutput::Scores(scores),
509                ConfigOutput::MaskCoefficients(mask_coefficients),
510                ConfigOutput::Protos(protos),
511            ],
512            ..Default::default()
513        };
514        self.config_src.replace(ConfigSource::Config(config));
515        self
516    }
517
518    /// Loads a ModelPack detection model configuration.  Use
519    /// `DecoderBuilder.build()` to parse the model configuration.
520    ///
521    /// # Examples
522    /// ```rust
523    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
524    /// # fn main() -> DecoderResult<()> {
525    /// let boxes_config = configs::Boxes {
526    ///     decoder: configs::DecoderType::ModelPack,
527    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
528    ///     shape: vec![1, 8400, 1, 4],
529    ///     dshape: Vec::new(),
530    ///     normalized: Some(true),
531    /// };
532    /// let scores_config = configs::Scores {
533    ///     decoder: configs::DecoderType::ModelPack,
534    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
535    ///     shape: vec![1, 8400, 3],
536    ///     dshape: Vec::new(),
537    /// };
538    /// let decoder = DecoderBuilder::new()
539    ///     .with_config_modelpack_det(boxes_config, scores_config)
540    ///     .build()?;
541    /// # Ok(())
542    /// # }
543    /// ```
544    pub fn with_config_modelpack_det(
545        mut self,
546        boxes: configs::Boxes,
547        scores: configs::Scores,
548    ) -> Self {
549        let config = ConfigOutputs {
550            outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
551            ..Default::default()
552        };
553        self.config_src.replace(ConfigSource::Config(config));
554        self
555    }
556
557    /// Loads a ModelPack split detection model configuration. Use
558    /// `DecoderBuilder.build()` to parse the model configuration.
559    ///
560    /// # Examples
561    /// ```rust
562    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
563    /// # fn main() -> DecoderResult<()> {
564    /// let config0 = configs::Detection {
565    ///     anchors: Some(vec![
566    ///         [0.13750000298023224, 0.2074074000120163],
567    ///         [0.2541666626930237, 0.21481481194496155],
568    ///         [0.23125000298023224, 0.35185185074806213],
569    ///     ]),
570    ///     decoder: configs::DecoderType::ModelPack,
571    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
572    ///     shape: vec![1, 17, 30, 18],
573    ///     dshape: Vec::new(),
574    ///     normalized: Some(true),
575    /// };
576    /// let config1 = configs::Detection {
577    ///     anchors: Some(vec![
578    ///         [0.36666667461395264, 0.31481480598449707],
579    ///         [0.38749998807907104, 0.4740740656852722],
580    ///         [0.5333333611488342, 0.644444465637207],
581    ///     ]),
582    ///     decoder: configs::DecoderType::ModelPack,
583    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
584    ///     shape: vec![1, 9, 15, 18],
585    ///     dshape: Vec::new(),
586    ///     normalized: Some(true),
587    /// };
588    ///
589    /// let decoder = DecoderBuilder::new()
590    ///     .with_config_modelpack_det_split(vec![config0, config1])
591    ///     .build()?;
592    /// # Ok(())
593    /// # }
594    /// ```
595    pub fn with_config_modelpack_det_split(mut self, boxes: Vec<configs::Detection>) -> Self {
596        let outputs = boxes.into_iter().map(ConfigOutput::Detection).collect();
597        let config = ConfigOutputs {
598            outputs,
599            ..Default::default()
600        };
601        self.config_src.replace(ConfigSource::Config(config));
602        self
603    }
604
605    /// Loads a ModelPack segmentation detection model configuration. Use
606    /// `DecoderBuilder.build()` to parse the model configuration.
607    ///
608    /// # Examples
609    /// ```rust
610    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
611    /// # fn main() -> DecoderResult<()> {
612    /// let boxes_config = configs::Boxes {
613    ///     decoder: configs::DecoderType::ModelPack,
614    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
615    ///     shape: vec![1, 8400, 1, 4],
616    ///     dshape: Vec::new(),
617    ///     normalized: Some(true),
618    /// };
619    /// let scores_config = configs::Scores {
620    ///     decoder: configs::DecoderType::ModelPack,
621    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
622    ///     shape: vec![1, 8400, 2],
623    ///     dshape: Vec::new(),
624    /// };
625    /// let seg_config = configs::Segmentation {
626    ///     decoder: configs::DecoderType::ModelPack,
627    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
628    ///     shape: vec![1, 640, 640, 3],
629    ///     dshape: Vec::new(),
630    /// };
631    /// let decoder = DecoderBuilder::new()
632    ///     .with_config_modelpack_segdet(boxes_config, scores_config, seg_config)
633    ///     .build()?;
634    /// # Ok(())
635    /// # }
636    /// ```
637    pub fn with_config_modelpack_segdet(
638        mut self,
639        boxes: configs::Boxes,
640        scores: configs::Scores,
641        segmentation: configs::Segmentation,
642    ) -> Self {
643        let config = ConfigOutputs {
644            outputs: vec![
645                ConfigOutput::Boxes(boxes),
646                ConfigOutput::Scores(scores),
647                ConfigOutput::Segmentation(segmentation),
648            ],
649            ..Default::default()
650        };
651        self.config_src.replace(ConfigSource::Config(config));
652        self
653    }
654
655    /// Loads a ModelPack segmentation split detection model configuration. Use
656    /// `DecoderBuilder.build()` to parse the model configuration.
657    ///
658    /// # Examples
659    /// ```rust
660    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
661    /// # fn main() -> DecoderResult<()> {
662    /// let config0 = configs::Detection {
663    ///     anchors: Some(vec![
664    ///         [0.36666667461395264, 0.31481480598449707],
665    ///         [0.38749998807907104, 0.4740740656852722],
666    ///         [0.5333333611488342, 0.644444465637207],
667    ///     ]),
668    ///     decoder: configs::DecoderType::ModelPack,
669    ///     quantization: Some(configs::QuantTuple(0.08547406643629074, 174)),
670    ///     shape: vec![1, 9, 15, 18],
671    ///     dshape: Vec::new(),
672    ///     normalized: Some(true),
673    /// };
674    /// let config1 = configs::Detection {
675    ///     anchors: Some(vec![
676    ///         [0.13750000298023224, 0.2074074000120163],
677    ///         [0.2541666626930237, 0.21481481194496155],
678    ///         [0.23125000298023224, 0.35185185074806213],
679    ///     ]),
680    ///     decoder: configs::DecoderType::ModelPack,
681    ///     quantization: Some(configs::QuantTuple(0.09929127991199493, 183)),
682    ///     shape: vec![1, 17, 30, 18],
683    ///     dshape: Vec::new(),
684    ///     normalized: Some(true),
685    /// };
686    /// let seg_config = configs::Segmentation {
687    ///     decoder: configs::DecoderType::ModelPack,
688    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
689    ///     shape: vec![1, 640, 640, 2],
690    ///     dshape: Vec::new(),
691    /// };
692    /// let decoder = DecoderBuilder::new()
693    ///     .with_config_modelpack_segdet_split(vec![config0, config1], seg_config)
694    ///     .build()?;
695    /// # Ok(())
696    /// # }
697    /// ```
698    pub fn with_config_modelpack_segdet_split(
699        mut self,
700        boxes: Vec<configs::Detection>,
701        segmentation: configs::Segmentation,
702    ) -> Self {
703        let mut outputs = boxes
704            .into_iter()
705            .map(ConfigOutput::Detection)
706            .collect::<Vec<_>>();
707        outputs.push(ConfigOutput::Segmentation(segmentation));
708        let config = ConfigOutputs {
709            outputs,
710            ..Default::default()
711        };
712        self.config_src.replace(ConfigSource::Config(config));
713        self
714    }
715
716    /// Loads a ModelPack segmentation model configuration. Use
717    /// `DecoderBuilder.build()` to parse the model configuration.
718    ///
719    /// # Examples
720    /// ```rust
721    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
722    /// # fn main() -> DecoderResult<()> {
723    /// let seg_config = configs::Segmentation {
724    ///     decoder: configs::DecoderType::ModelPack,
725    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
726    ///     shape: vec![1, 640, 640, 3],
727    ///     dshape: Vec::new(),
728    /// };
729    /// let decoder = DecoderBuilder::new()
730    ///     .with_config_modelpack_seg(seg_config)
731    ///     .build()?;
732    /// # Ok(())
733    /// # }
734    /// ```
735    pub fn with_config_modelpack_seg(mut self, segmentation: configs::Segmentation) -> Self {
736        let config = ConfigOutputs {
737            outputs: vec![ConfigOutput::Segmentation(segmentation)],
738            ..Default::default()
739        };
740        self.config_src.replace(ConfigSource::Config(config));
741        self
742    }
743
744    /// Add an output to the decoder configuration.
745    ///
746    /// Incrementally builds the model configuration by adding outputs one at
747    /// a time. The decoder resolves the model type from the combination of
748    /// outputs during `build()`.
749    ///
750    /// If `dshape` is non-empty on the output, `shape` is automatically
751    /// derived from it (the size component of each named dimension). This
752    /// prevents conflicts between `shape` and `dshape`.
753    ///
754    /// This uses the programmatic config path. Calling this after
755    /// `with_config_json_str()` or `with_config_yaml_str()` replaces the
756    /// string-based config source.
757    ///
758    /// # Examples
759    /// ```rust
760    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
761    /// # fn main() -> DecoderResult<()> {
762    /// let decoder = DecoderBuilder::new()
763    ///     .add_output(ConfigOutput::Scores(configs::Scores {
764    ///         decoder: configs::DecoderType::Ultralytics,
765    ///         dshape: vec![
766    ///             (configs::DimName::Batch, 1),
767    ///             (configs::DimName::NumClasses, 80),
768    ///             (configs::DimName::NumBoxes, 8400),
769    ///         ],
770    ///         ..Default::default()
771    ///     }))
772    ///     .add_output(ConfigOutput::Boxes(configs::Boxes {
773    ///         decoder: configs::DecoderType::Ultralytics,
774    ///         dshape: vec![
775    ///             (configs::DimName::Batch, 1),
776    ///             (configs::DimName::BoxCoords, 4),
777    ///             (configs::DimName::NumBoxes, 8400),
778    ///         ],
779    ///         ..Default::default()
780    ///     }))
781    ///     .build()?;
782    /// # Ok(())
783    /// # }
784    /// ```
785    pub fn add_output(mut self, output: ConfigOutput) -> Self {
786        if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
787            self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
788        }
789        if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
790            config.outputs.push(Self::normalize_output(output));
791        }
792        self
793    }
794
795    /// Sets the decoder version for Ultralytics models.
796    ///
797    /// This is used with `add_output()` to specify the YOLO architecture
798    /// version when it cannot be inferred from the output shapes alone.
799    ///
800    /// - `Yolov5`, `Yolov8`, `Yolo11`: Traditional models requiring external
801    ///   NMS
802    /// - `Yolo26`: End-to-end models with NMS embedded in the model graph
803    ///
804    /// # Examples
805    /// ```rust
806    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
807    /// # fn main() -> DecoderResult<()> {
808    /// let decoder = DecoderBuilder::new()
809    ///     .add_output(ConfigOutput::Detection(configs::Detection {
810    ///         decoder: configs::DecoderType::Ultralytics,
811    ///         dshape: vec![
812    ///             (configs::DimName::Batch, 1),
813    ///             (configs::DimName::NumBoxes, 100),
814    ///             (configs::DimName::NumFeatures, 6),
815    ///         ],
816    ///         ..Default::default()
817    ///     }))
818    ///     .with_decoder_version(configs::DecoderVersion::Yolo26)
819    ///     .build()?;
820    /// # Ok(())
821    /// # }
822    /// ```
823    pub fn with_decoder_version(mut self, version: configs::DecoderVersion) -> Self {
824        if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
825            self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
826        }
827        if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
828            config.decoder_version = Some(version);
829        }
830        self
831    }
832
833    /// Normalize an output: if dshape is non-empty, derive shape from it.
834    fn normalize_output(mut output: ConfigOutput) -> ConfigOutput {
835        fn normalize_shape(shape: &mut Vec<usize>, dshape: &[(configs::DimName, usize)]) {
836            if !dshape.is_empty() {
837                *shape = dshape.iter().map(|(_, size)| *size).collect();
838            }
839        }
840        match &mut output {
841            ConfigOutput::Detection(c) => normalize_shape(&mut c.shape, &c.dshape),
842            ConfigOutput::Boxes(c) => normalize_shape(&mut c.shape, &c.dshape),
843            ConfigOutput::Scores(c) => normalize_shape(&mut c.shape, &c.dshape),
844            ConfigOutput::Protos(c) => normalize_shape(&mut c.shape, &c.dshape),
845            ConfigOutput::Segmentation(c) => normalize_shape(&mut c.shape, &c.dshape),
846            ConfigOutput::MaskCoefficients(c) => normalize_shape(&mut c.shape, &c.dshape),
847            ConfigOutput::Mask(c) => normalize_shape(&mut c.shape, &c.dshape),
848            ConfigOutput::Classes(c) => normalize_shape(&mut c.shape, &c.dshape),
849        }
850        output
851    }
852
853    /// Sets the scores threshold of the decoder
854    ///
855    /// # Examples
856    /// ```rust
857    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
858    /// # fn main() -> DecoderResult<()> {
859    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
860    /// let decoder = DecoderBuilder::new()
861    ///     .with_config_json_str(config_json)
862    ///     .with_score_threshold(0.654)
863    ///     .build()?;
864    /// assert_eq!(decoder.score_threshold, 0.654);
865    /// # Ok(())
866    /// # }
867    /// ```
868    pub fn with_score_threshold(mut self, score_threshold: f32) -> Self {
869        self.score_threshold = score_threshold;
870        self
871    }
872
873    /// Sets the IOU threshold of the decoder. Has no effect when NMS is set to
874    /// `None`
875    ///
876    /// # Examples
877    /// ```rust
878    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
879    /// # fn main() -> DecoderResult<()> {
880    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
881    /// let decoder = DecoderBuilder::new()
882    ///     .with_config_json_str(config_json)
883    ///     .with_iou_threshold(0.654)
884    ///     .build()?;
885    /// assert_eq!(decoder.iou_threshold, 0.654);
886    /// # Ok(())
887    /// # }
888    /// ```
889    pub fn with_iou_threshold(mut self, iou_threshold: f32) -> Self {
890        self.iou_threshold = iou_threshold;
891        self
892    }
893
894    /// Sets the NMS mode for the decoder.
895    ///
896    /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS (default): suppress
897    ///   overlapping boxes regardless of class label
898    /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
899    ///   share the same class label AND overlap above the IoU threshold
900    /// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
901    ///
902    /// # Examples
903    /// ```rust
904    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::Nms};
905    /// # fn main() -> DecoderResult<()> {
906    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
907    /// let decoder = DecoderBuilder::new()
908    ///     .with_config_json_str(config_json)
909    ///     .with_nms(Some(Nms::ClassAware))
910    ///     .build()?;
911    /// assert_eq!(decoder.nms, Some(Nms::ClassAware));
912    /// # Ok(())
913    /// # }
914    /// ```
915    pub fn with_nms(mut self, nms: Option<configs::Nms>) -> Self {
916        self.nms = nms;
917        self
918    }
919
920    /// Sets the maximum number of candidate boxes fed into NMS after score
921    /// filtering.  Uses partial sort (O(N)) to select the top-K candidates,
922    /// dramatically reducing the O(N²) NMS cost when many low-confidence
923    /// proposals pass the threshold (common with mAP eval at 0.001).
924    ///
925    /// Default: 300 (matches Ultralytics `max_det`).
926    ///
927    /// # Examples
928    /// ```rust
929    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
930    /// # fn main() -> DecoderResult<()> {
931    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
932    /// let decoder = DecoderBuilder::new()
933    ///     .with_config_json_str(config_json)
934    ///     .with_pre_nms_top_k(1000)
935    ///     .build()?;
936    /// assert_eq!(decoder.pre_nms_top_k, 1000);
937    /// # Ok(())
938    /// # }
939    /// ```
940    pub fn with_pre_nms_top_k(mut self, pre_nms_top_k: usize) -> Self {
941        self.pre_nms_top_k = pre_nms_top_k;
942        self
943    }
944
945    /// Sets the maximum number of detections returned after NMS.
946    /// Matches the Ultralytics `max_det` parameter.
947    ///
948    /// Default: 300.
949    ///
950    /// # Examples
951    /// ```rust
952    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
953    /// # fn main() -> DecoderResult<()> {
954    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
955    /// let decoder = DecoderBuilder::new()
956    ///     .with_config_json_str(config_json)
957    ///     .with_max_det(100)
958    ///     .build()?;
959    /// assert_eq!(decoder.max_det, 100);
960    /// # Ok(())
961    /// # }
962    /// ```
963    pub fn with_max_det(mut self, max_det: usize) -> Self {
964        self.max_det = max_det;
965        self
966    }
967
968    /// Sets the model input dimensions `(width, height)` consumed by the
969    /// EDGEAI-1303 normalization path. Use this when building via
970    /// [`with_config`](Self::with_config) / [`add_output`](Self::add_output)
971    /// (no schema) and the model emits pixel-space boxes that need to be
972    /// divided by `(W, H)` before NMS.
973    ///
974    /// When the builder is also configured with [`with_schema`](Self::with_schema)
975    /// (or `with_config_json_str` / `with_config_yaml_str`) and the schema's
976    /// `input` block carries usable dims, this explicit override **takes
977    /// precedence** so callers can correct schemas with missing or wrong
978    /// input specs without rewriting the schema.
979    ///
980    /// # Examples
981    /// ```rust
982    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
983    /// # fn main() -> DecoderResult<()> {
984    /// # let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
985    /// let decoder = DecoderBuilder::new()
986    ///     .with_config_yaml_str(config_yaml)
987    ///     .with_input_dims(640, 640)
988    ///     .build()?;
989    /// assert_eq!(decoder.input_dims(), Some((640, 640)));
990    /// # Ok(())
991    /// # }
992    /// ```
993    pub fn with_input_dims(mut self, width: usize, height: usize) -> Self {
994        self.input_dims = Some((width, height));
995        self
996    }
997
998    /// Builds the decoder with the given settings. If the config is a JSON or
999    /// YAML string, this will deserialize the JSON or YAML and then parse the
1000    /// configuration information.
1001    ///
1002    /// # Examples
1003    /// ```rust
1004    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
1005    /// # fn main() -> DecoderResult<()> {
1006    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
1007    /// let decoder = DecoderBuilder::new()
1008    ///     .with_config_json_str(config_json)
1009    ///     .with_score_threshold(0.654)
1010    ///     .build()?;
1011    /// # Ok(())
1012    /// # }
1013    /// ```
1014    pub fn build(self) -> Result<Decoder, DecoderError> {
1015        let decode_dtype = self.decode_dtype;
1016        let explicit_input_dims = self.input_dims;
1017        let (config, decode_program, per_scale_plan, schema_input_dims) = match self.config_src {
1018            Some(ConfigSource::Json(s)) => {
1019                Self::build_from_schema(SchemaV2::parse_json(&s)?, decode_dtype)?
1020            }
1021            Some(ConfigSource::Yaml(s)) => {
1022                Self::build_from_schema(SchemaV2::parse_yaml(&s)?, decode_dtype)?
1023            }
1024            Some(ConfigSource::Config(c)) => (c, None, None, None),
1025            Some(ConfigSource::Schema(schema)) => Self::build_from_schema(schema, decode_dtype)?,
1026            None => return Err(DecoderError::NoConfig),
1027        };
1028        // Explicit `with_input_dims(W, H)` overrides any schema-derived
1029        // value so callers can fix schemas with missing or wrong input
1030        // specs without rewriting the schema (EDGEAI-1303).
1031        let input_dims = explicit_input_dims.or(schema_input_dims);
1032
1033        // Enforce the physical-order contract: when dshape is present
1034        // it must describe the same axes as shape in the same order,
1035        // listed from outermost to innermost. Ambiguous-layout roles
1036        // (Protos, Boxes, Scores, MaskCoefficients, Classes, Detection)
1037        // may still omit dshape when shape is already in the decoder's
1038        // canonical order.
1039        for output in &config.outputs {
1040            Decoder::validate_output_layout(output.into())?;
1041        }
1042
1043        // Extract normalized flag from config outputs.
1044        //
1045        // The per-scale subsystem (DFL/LTRB → dist2bbox → sigmoid) emits
1046        // boxes in pixel coordinates by design — `(grid + dist) * stride`
1047        // — independently of any `normalized: true` annotation in the
1048        // schema. The schema's `normalized` flag describes the model's
1049        // training-time convention, not the runtime output coord space
1050        // for this code path. Override to `Some(false)` when the
1051        // per-scale path is active so `Decoder::normalized_boxes()`
1052        // matches what `decode_proto`/`decode` actually produce; the
1053        // legacy / non-per-scale paths still honor the schema flag.
1054        let normalized = if per_scale_plan.is_some() {
1055            Some(false)
1056        } else {
1057            Self::get_normalized(&config.outputs)
1058        };
1059
1060        // Use NMS from config if present, otherwise use builder's NMS setting
1061        let nms = config.nms.or(self.nms);
1062        // When the per-scale path is active, the per_scale subsystem owns
1063        // model decoding entirely — `decode` / `decode_proto` short-circuit
1064        // on `per_scale.is_some()` before reading `model_type`. Skip the
1065        // legacy ModelType validation, which otherwise rejects per-scale
1066        // schemas that carry `decoder_version: yolo26` (an
1067        // "end-to-end" hint) but use the per-scale split outputs rather
1068        // than the end-to-end split-output shape the legacy validator
1069        // expects. We keep a placeholder `ModelType` so the field remains
1070        // valid; it is dead state for per-scale Decoders.
1071        let model_type = if per_scale_plan.is_some() {
1072            // Drop the un-needed config; the per-scale subsystem owns it.
1073            drop(config);
1074            ModelType::PerScale
1075        } else {
1076            Self::get_model_type(config)?
1077        };
1078
1079        let per_scale = per_scale_plan
1080            .map(|plan| std::sync::Mutex::new(crate::per_scale::PerScaleDecoder::new(plan)));
1081
1082        Ok(Decoder {
1083            model_type,
1084            iou_threshold: self.iou_threshold,
1085            score_threshold: self.score_threshold,
1086            nms,
1087            pre_nms_top_k: self.pre_nms_top_k,
1088            max_det: self.max_det,
1089            normalized,
1090            input_dims,
1091            decode_program,
1092            per_scale,
1093        })
1094    }
1095
1096    /// Validate a [`SchemaV2`] and lower it to the (legacy `ConfigOutputs`,
1097    /// optional `DecodeProgram`, optional `PerScalePlan`) tuple the rest
1098    /// of `build()` consumes.
1099    ///
1100    /// Centralises the v2 lowering so JSON, YAML, and direct
1101    /// `with_schema` callers all go through the same validation,
1102    /// merge-program, and per-scale plan construction. `SchemaV2::parse_json`
1103    /// / `parse_yaml` already auto-detect v1 vs v2 input and return a v2
1104    /// schema either way (v1 inputs are upgraded in memory via
1105    /// `from_v1`), so this helper is the sole place that turns a
1106    /// schema into builder-ready state.
1107    #[allow(clippy::type_complexity)]
1108    fn build_from_schema(
1109        schema: SchemaV2,
1110        decode_dtype: DecodeDtype,
1111    ) -> Result<
1112        (
1113            ConfigOutputs,
1114            Option<DecodeProgram>,
1115            Option<PerScalePlan>,
1116            Option<(usize, usize)>,
1117        ),
1118        DecoderError,
1119    > {
1120        schema.validate()?;
1121        let program = DecodeProgram::try_from_schema(&schema)?;
1122        let per_scale = PerScalePlan::try_from_schema(&schema, decode_dtype)?;
1123        // Extract model input (W, H) from `input.shape`/`dshape`. Used by
1124        // the legacy decode path to honour `normalized: false` (see
1125        // EDGEAI-1303). `None` is fine when the schema omits the input
1126        // spec — the decoder falls back to the protobox `>2.0` reject.
1127        let input_dims = schema.input.as_ref().and_then(input_dims_from_spec);
1128        let legacy = schema.to_legacy_config_outputs()?;
1129        Ok((legacy, program, per_scale, input_dims))
1130    }
1131
1132    /// Extracts the normalized flag from config outputs.
1133    /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
1134    /// - `Some(false)`: Boxes are in pixel coordinates
1135    /// - `None`: Unknown (not specified in config), caller must infer
1136    fn get_normalized(outputs: &[ConfigOutput]) -> Option<bool> {
1137        for output in outputs {
1138            match output {
1139                ConfigOutput::Detection(det) => return det.normalized,
1140                ConfigOutput::Boxes(boxes) => return boxes.normalized,
1141                _ => {}
1142            }
1143        }
1144        None // not specified
1145    }
1146
1147    fn get_model_type(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1148        // yolo or modelpack
1149        let mut yolo = false;
1150        let mut modelpack = false;
1151        for c in &configs.outputs {
1152            match c.decoder() {
1153                DecoderType::ModelPack => modelpack = true,
1154                DecoderType::Ultralytics => yolo = true,
1155            }
1156        }
1157        match (modelpack, yolo) {
1158            (true, true) => Err(DecoderError::InvalidConfig(
1159                "Both ModelPack and Yolo outputs found in config".to_string(),
1160            )),
1161            (true, false) => Self::get_model_type_modelpack(configs),
1162            (false, true) => Self::get_model_type_yolo(configs),
1163            (false, false) => Err(DecoderError::InvalidConfig(
1164                "No outputs found in config".to_string(),
1165            )),
1166        }
1167    }
1168
1169    fn get_model_type_yolo(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1170        let mut boxes = None;
1171        let mut protos = None;
1172        let mut split_boxes = None;
1173        let mut split_scores = None;
1174        let mut split_mask_coeff = None;
1175        let mut split_classes = None;
1176        for c in configs.outputs {
1177            match c {
1178                ConfigOutput::Detection(detection) => boxes = Some(detection),
1179                ConfigOutput::Segmentation(_) => {
1180                    return Err(DecoderError::InvalidConfig(
1181                        "Invalid Segmentation output with Yolo decoder".to_string(),
1182                    ));
1183                }
1184                ConfigOutput::Protos(protos_) => protos = Some(protos_),
1185                ConfigOutput::Mask(_) => {
1186                    return Err(DecoderError::InvalidConfig(
1187                        "Invalid Mask output with Yolo decoder".to_string(),
1188                    ));
1189                }
1190                ConfigOutput::Scores(scores) => split_scores = Some(scores),
1191                ConfigOutput::Boxes(boxes) => split_boxes = Some(boxes),
1192                ConfigOutput::MaskCoefficients(mask_coeff) => split_mask_coeff = Some(mask_coeff),
1193                ConfigOutput::Classes(classes) => split_classes = Some(classes),
1194            }
1195        }
1196
1197        // Use end-to-end model types when:
1198        // 1. decoder_version is explicitly set to Yolo26 (definitive), OR
1199        //    decoder_version is not set but the dshapes are (batch, num_boxes,
1200        //    num_features)
1201        let is_end_to_end_dshape = boxes.as_ref().is_some_and(|b| {
1202            let dims = b.dshape.iter().map(|(d, _)| *d).collect::<Vec<_>>();
1203            dims == vec![DimName::Batch, DimName::NumBoxes, DimName::NumFeatures]
1204        });
1205
1206        let is_end_to_end = configs
1207            .decoder_version
1208            .map(|v| v.is_end_to_end())
1209            .unwrap_or(is_end_to_end_dshape);
1210
1211        if is_end_to_end {
1212            if let Some(boxes) = boxes {
1213                if let Some(protos) = protos {
1214                    Self::verify_yolo_seg_det_26(&boxes, &protos)?;
1215                    return Ok(ModelType::YoloEndToEndSegDet { boxes, protos });
1216                } else {
1217                    Self::verify_yolo_det_26(&boxes)?;
1218                    return Ok(ModelType::YoloEndToEndDet { boxes });
1219                }
1220            } else if let (Some(split_boxes), Some(split_scores), Some(split_classes)) =
1221                (split_boxes, split_scores, split_classes)
1222            {
1223                if let (Some(split_mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
1224                    Self::verify_yolo_split_end_to_end_segdet(
1225                        &split_boxes,
1226                        &split_scores,
1227                        &split_classes,
1228                        &split_mask_coeff,
1229                        &protos,
1230                    )?;
1231                    return Ok(ModelType::YoloSplitEndToEndSegDet {
1232                        boxes: split_boxes,
1233                        scores: split_scores,
1234                        classes: split_classes,
1235                        mask_coeff: split_mask_coeff,
1236                        protos,
1237                    });
1238                }
1239                Self::verify_yolo_split_end_to_end_det(
1240                    &split_boxes,
1241                    &split_scores,
1242                    &split_classes,
1243                )?;
1244                return Ok(ModelType::YoloSplitEndToEndDet {
1245                    boxes: split_boxes,
1246                    scores: split_scores,
1247                    classes: split_classes,
1248                });
1249            } else {
1250                return Err(DecoderError::InvalidConfig(
1251                    "Invalid Yolo end-to-end model outputs".to_string(),
1252                ));
1253            }
1254        }
1255
1256        if let Some(boxes) = boxes {
1257            match (split_mask_coeff, protos) {
1258                (Some(mask_coeff), Some(protos)) => {
1259                    // 2-way split: combined detection + separate mask_coeff + protos
1260                    Self::verify_yolo_seg_det_2way(&boxes, &mask_coeff, &protos)?;
1261                    Ok(ModelType::YoloSegDet2Way {
1262                        boxes,
1263                        mask_coeff,
1264                        protos,
1265                    })
1266                }
1267                (_, Some(protos)) => {
1268                    // Unsplit: mask_coefs embedded in detection tensor
1269                    Self::verify_yolo_seg_det(&boxes, &protos)?;
1270                    Ok(ModelType::YoloSegDet { boxes, protos })
1271                }
1272                _ => {
1273                    Self::verify_yolo_det(&boxes)?;
1274                    Ok(ModelType::YoloDet { boxes })
1275                }
1276            }
1277        } else if let (Some(boxes), Some(scores)) = (split_boxes, split_scores) {
1278            if let (Some(mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
1279                Self::verify_yolo_split_segdet(&boxes, &scores, &mask_coeff, &protos)?;
1280                Ok(ModelType::YoloSplitSegDet {
1281                    boxes,
1282                    scores,
1283                    mask_coeff,
1284                    protos,
1285                })
1286            } else {
1287                Self::verify_yolo_split_det(&boxes, &scores)?;
1288                Ok(ModelType::YoloSplitDet { boxes, scores })
1289            }
1290        } else {
1291            Err(DecoderError::InvalidConfig(
1292                "Invalid Yolo model outputs".to_string(),
1293            ))
1294        }
1295    }
1296
1297    fn verify_yolo_det(detect: &configs::Detection) -> Result<(), DecoderError> {
1298        if detect.shape.len() != 3 {
1299            return Err(DecoderError::InvalidConfig(format!(
1300                "Invalid Yolo Detection shape {:?}",
1301                detect.shape
1302            )));
1303        }
1304
1305        Self::verify_dshapes(
1306            &detect.dshape,
1307            &detect.shape,
1308            "Detection",
1309            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1310        )?;
1311        if !detect.dshape.is_empty() {
1312            Self::get_class_count(&detect.dshape, None, None)?;
1313        } else {
1314            Self::get_class_count_no_dshape(detect.into(), None)?;
1315        }
1316
1317        Ok(())
1318    }
1319
1320    fn verify_yolo_det_26(detect: &configs::Detection) -> Result<(), DecoderError> {
1321        if detect.shape.len() != 3 {
1322            return Err(DecoderError::InvalidConfig(format!(
1323                "Invalid Yolo Detection shape {:?}",
1324                detect.shape
1325            )));
1326        }
1327
1328        Self::verify_dshapes(
1329            &detect.dshape,
1330            &detect.shape,
1331            "Detection",
1332            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1333        )?;
1334
1335        if !detect.shape.contains(&6) {
1336            return Err(DecoderError::InvalidConfig(
1337                "Yolo26 Detection must have 6 features".to_string(),
1338            ));
1339        }
1340
1341        Ok(())
1342    }
1343
1344    fn verify_yolo_seg_det(
1345        detection: &configs::Detection,
1346        protos: &configs::Protos,
1347    ) -> Result<(), DecoderError> {
1348        if detection.shape.len() != 3 {
1349            return Err(DecoderError::InvalidConfig(format!(
1350                "Invalid Yolo Detection shape {:?}",
1351                detection.shape
1352            )));
1353        }
1354        if protos.shape.len() != 4 {
1355            return Err(DecoderError::InvalidConfig(format!(
1356                "Invalid Yolo Protos shape {:?}",
1357                protos.shape
1358            )));
1359        }
1360
1361        Self::verify_dshapes(
1362            &detection.dshape,
1363            &detection.shape,
1364            "Detection",
1365            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1366        )?;
1367        Self::verify_dshapes(
1368            &protos.dshape,
1369            &protos.shape,
1370            "Protos",
1371            &[
1372                DimName::Batch,
1373                DimName::Height,
1374                DimName::Width,
1375                DimName::NumProtos,
1376            ],
1377        )?;
1378
1379        let protos_count = Self::get_protos_count(&protos.dshape)
1380            .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1381        log::debug!("Protos count: {}", protos_count);
1382        log::debug!("Detection dshape: {:?}", detection.dshape);
1383        let classes = if !detection.dshape.is_empty() {
1384            Self::get_class_count(&detection.dshape, Some(protos_count), None)?
1385        } else {
1386            Self::get_class_count_no_dshape(detection.into(), Some(protos_count))?
1387        };
1388
1389        if classes == 0 {
1390            return Err(DecoderError::InvalidConfig(
1391                "Yolo Segmentation Detection has zero classes".to_string(),
1392            ));
1393        }
1394
1395        Ok(())
1396    }
1397
1398    fn verify_yolo_seg_det_2way(
1399        detection: &configs::Detection,
1400        mask_coeff: &configs::MaskCoefficients,
1401        protos: &configs::Protos,
1402    ) -> Result<(), DecoderError> {
1403        if detection.shape.len() != 3 {
1404            return Err(DecoderError::InvalidConfig(format!(
1405                "Invalid Yolo 2-Way Detection shape {:?}",
1406                detection.shape
1407            )));
1408        }
1409        if mask_coeff.shape.len() != 3 {
1410            return Err(DecoderError::InvalidConfig(format!(
1411                "Invalid Yolo 2-Way Mask Coefficients shape {:?}",
1412                mask_coeff.shape
1413            )));
1414        }
1415        if protos.shape.len() != 4 {
1416            return Err(DecoderError::InvalidConfig(format!(
1417                "Invalid Yolo 2-Way Protos shape {:?}",
1418                protos.shape
1419            )));
1420        }
1421
1422        Self::verify_dshapes(
1423            &detection.dshape,
1424            &detection.shape,
1425            "Detection",
1426            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1427        )?;
1428        Self::verify_dshapes(
1429            &mask_coeff.dshape,
1430            &mask_coeff.shape,
1431            "Mask Coefficients",
1432            &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1433        )?;
1434        Self::verify_dshapes(
1435            &protos.dshape,
1436            &protos.shape,
1437            "Protos",
1438            &[
1439                DimName::Batch,
1440                DimName::Height,
1441                DimName::Width,
1442                DimName::NumProtos,
1443            ],
1444        )?;
1445
1446        // Validate num_boxes match between detection and mask_coeff
1447        let det_num = Self::get_box_count(&detection.dshape).unwrap_or(detection.shape[2]);
1448        let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1449        if det_num != mask_num {
1450            return Err(DecoderError::InvalidConfig(format!(
1451                "Yolo 2-Way Detection num_boxes {} incompatible with Mask Coefficients num_boxes {}",
1452                det_num, mask_num
1453            )));
1454        }
1455
1456        // Validate mask_coeff channels match protos channels
1457        let mask_channels = if !mask_coeff.dshape.is_empty() {
1458            Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1459                DecoderError::InvalidConfig(
1460                    "Could not find num_protos in mask_coeff config".to_string(),
1461                )
1462            })?
1463        } else {
1464            mask_coeff.shape[1]
1465        };
1466        let proto_channels = if !protos.dshape.is_empty() {
1467            Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1468                DecoderError::InvalidConfig(
1469                    "Could not find num_protos in protos config".to_string(),
1470                )
1471            })?
1472        } else {
1473            protos.shape[1].min(protos.shape[3])
1474        };
1475        if mask_channels != proto_channels {
1476            return Err(DecoderError::InvalidConfig(format!(
1477                "Yolo 2-Way Protos channels {} incompatible with Mask Coefficients channels {}",
1478                proto_channels, mask_channels
1479            )));
1480        }
1481
1482        // Validate detection has classes (nc+4 features, no mask_coefs embedded)
1483        if !detection.dshape.is_empty() {
1484            Self::get_class_count(&detection.dshape, None, None)?;
1485        } else {
1486            Self::get_class_count_no_dshape(detection.into(), None)?;
1487        }
1488
1489        Ok(())
1490    }
1491
1492    fn verify_yolo_seg_det_26(
1493        detection: &configs::Detection,
1494        protos: &configs::Protos,
1495    ) -> Result<(), DecoderError> {
1496        if detection.shape.len() != 3 {
1497            return Err(DecoderError::InvalidConfig(format!(
1498                "Invalid Yolo Detection shape {:?}",
1499                detection.shape
1500            )));
1501        }
1502        if protos.shape.len() != 4 {
1503            return Err(DecoderError::InvalidConfig(format!(
1504                "Invalid Yolo Protos shape {:?}",
1505                protos.shape
1506            )));
1507        }
1508
1509        Self::verify_dshapes(
1510            &detection.dshape,
1511            &detection.shape,
1512            "Detection",
1513            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1514        )?;
1515        Self::verify_dshapes(
1516            &protos.dshape,
1517            &protos.shape,
1518            "Protos",
1519            &[
1520                DimName::Batch,
1521                DimName::Height,
1522                DimName::Width,
1523                DimName::NumProtos,
1524            ],
1525        )?;
1526
1527        let protos_count = Self::get_protos_count(&protos.dshape)
1528            .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1529        log::debug!("Protos count: {}", protos_count);
1530        log::debug!("Detection dshape: {:?}", detection.dshape);
1531
1532        if !detection.shape.contains(&(6 + protos_count)) {
1533            return Err(DecoderError::InvalidConfig(format!(
1534                "Yolo26 Segmentation Detection must have num_features be 6 + num_protos = {}",
1535                6 + protos_count
1536            )));
1537        }
1538
1539        Ok(())
1540    }
1541
1542    fn verify_yolo_split_det(
1543        boxes: &configs::Boxes,
1544        scores: &configs::Scores,
1545    ) -> Result<(), DecoderError> {
1546        if boxes.shape.len() != 3 {
1547            return Err(DecoderError::InvalidConfig(format!(
1548                "Invalid Yolo Split Boxes shape {:?}",
1549                boxes.shape
1550            )));
1551        }
1552        if scores.shape.len() != 3 {
1553            return Err(DecoderError::InvalidConfig(format!(
1554                "Invalid Yolo Split Scores shape {:?}",
1555                scores.shape
1556            )));
1557        }
1558
1559        Self::verify_dshapes(
1560            &boxes.dshape,
1561            &boxes.shape,
1562            "Boxes",
1563            &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1564        )?;
1565        Self::verify_dshapes(
1566            &scores.dshape,
1567            &scores.shape,
1568            "Scores",
1569            &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1570        )?;
1571
1572        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1573        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1574
1575        if boxes_num != scores_num {
1576            return Err(DecoderError::InvalidConfig(format!(
1577                "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1578                boxes_num, scores_num
1579            )));
1580        }
1581
1582        Ok(())
1583    }
1584
1585    fn verify_yolo_split_segdet(
1586        boxes: &configs::Boxes,
1587        scores: &configs::Scores,
1588        mask_coeff: &configs::MaskCoefficients,
1589        protos: &configs::Protos,
1590    ) -> Result<(), DecoderError> {
1591        if boxes.shape.len() != 3 {
1592            return Err(DecoderError::InvalidConfig(format!(
1593                "Invalid Yolo Split Boxes shape {:?}",
1594                boxes.shape
1595            )));
1596        }
1597        if scores.shape.len() != 3 {
1598            return Err(DecoderError::InvalidConfig(format!(
1599                "Invalid Yolo Split Scores shape {:?}",
1600                scores.shape
1601            )));
1602        }
1603
1604        if mask_coeff.shape.len() != 3 {
1605            return Err(DecoderError::InvalidConfig(format!(
1606                "Invalid Yolo Split Mask Coefficients shape {:?}",
1607                mask_coeff.shape
1608            )));
1609        }
1610
1611        if protos.shape.len() != 4 {
1612            return Err(DecoderError::InvalidConfig(format!(
1613                "Invalid Yolo Protos shape {:?}",
1614                mask_coeff.shape
1615            )));
1616        }
1617
1618        Self::verify_dshapes(
1619            &boxes.dshape,
1620            &boxes.shape,
1621            "Boxes",
1622            &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1623        )?;
1624        Self::verify_dshapes(
1625            &scores.dshape,
1626            &scores.shape,
1627            "Scores",
1628            &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1629        )?;
1630        Self::verify_dshapes(
1631            &mask_coeff.dshape,
1632            &mask_coeff.shape,
1633            "Mask Coefficients",
1634            &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1635        )?;
1636        Self::verify_dshapes(
1637            &protos.dshape,
1638            &protos.shape,
1639            "Protos",
1640            &[
1641                DimName::Batch,
1642                DimName::Height,
1643                DimName::Width,
1644                DimName::NumProtos,
1645            ],
1646        )?;
1647
1648        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1649        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1650        let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1651
1652        let mask_channels = if !mask_coeff.dshape.is_empty() {
1653            Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1654                DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1655            })?
1656        } else {
1657            mask_coeff.shape[1]
1658        };
1659        let proto_channels = if !protos.dshape.is_empty() {
1660            Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1661                DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1662            })?
1663        } else {
1664            protos.shape[1].min(protos.shape[3])
1665        };
1666
1667        if boxes_num != scores_num {
1668            return Err(DecoderError::InvalidConfig(format!(
1669                "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1670                boxes_num, scores_num
1671            )));
1672        }
1673
1674        if boxes_num != mask_num {
1675            return Err(DecoderError::InvalidConfig(format!(
1676                "Yolo Split Detection Boxes num {} incompatible with Mask Coefficients num {}",
1677                boxes_num, mask_num
1678            )));
1679        }
1680
1681        if proto_channels != mask_channels {
1682            return Err(DecoderError::InvalidConfig(format!(
1683                "Yolo Protos channels {} incompatible with Mask Coefficients channels {}",
1684                proto_channels, mask_channels
1685            )));
1686        }
1687
1688        Ok(())
1689    }
1690
1691    fn verify_yolo_split_end_to_end_det(
1692        boxes: &configs::Boxes,
1693        scores: &configs::Scores,
1694        classes: &configs::Classes,
1695    ) -> Result<(), DecoderError> {
1696        if boxes.shape.len() != 3 || !boxes.shape.contains(&4) {
1697            return Err(DecoderError::InvalidConfig(format!(
1698                "Split end-to-end boxes must be [batch, N, 4], got {:?}",
1699                boxes.shape
1700            )));
1701        }
1702        if scores.shape.len() != 3 || !scores.shape.contains(&1) {
1703            return Err(DecoderError::InvalidConfig(format!(
1704                "Split end-to-end scores must be [batch, N, 1], got {:?}",
1705                scores.shape
1706            )));
1707        }
1708        if classes.shape.len() != 3 || !classes.shape.contains(&1) {
1709            return Err(DecoderError::InvalidConfig(format!(
1710                "Split end-to-end classes must be [batch, N, 1], got {:?}",
1711                classes.shape
1712            )));
1713        }
1714        Ok(())
1715    }
1716
1717    fn verify_yolo_split_end_to_end_segdet(
1718        boxes: &configs::Boxes,
1719        scores: &configs::Scores,
1720        classes: &configs::Classes,
1721        mask_coeff: &configs::MaskCoefficients,
1722        protos: &configs::Protos,
1723    ) -> Result<(), DecoderError> {
1724        Self::verify_yolo_split_end_to_end_det(boxes, scores, classes)?;
1725        if mask_coeff.shape.len() != 3 {
1726            return Err(DecoderError::InvalidConfig(format!(
1727                "Invalid split end-to-end mask coefficients shape {:?}",
1728                mask_coeff.shape
1729            )));
1730        }
1731        if protos.shape.len() != 4 {
1732            return Err(DecoderError::InvalidConfig(format!(
1733                "Invalid protos shape {:?}",
1734                protos.shape
1735            )));
1736        }
1737        Ok(())
1738    }
1739
1740    fn get_model_type_modelpack(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1741        let mut split_decoders = Vec::new();
1742        let mut segment_ = None;
1743        let mut scores_ = None;
1744        let mut boxes_ = None;
1745        for c in configs.outputs {
1746            match c {
1747                ConfigOutput::Detection(detection) => split_decoders.push(detection),
1748                ConfigOutput::Segmentation(segmentation) => segment_ = Some(segmentation),
1749                ConfigOutput::Mask(_) => {}
1750                ConfigOutput::Protos(_) => {
1751                    return Err(DecoderError::InvalidConfig(
1752                        "ModelPack should not have protos".to_string(),
1753                    ));
1754                }
1755                ConfigOutput::Scores(scores) => scores_ = Some(scores),
1756                ConfigOutput::Boxes(boxes) => boxes_ = Some(boxes),
1757                ConfigOutput::MaskCoefficients(_) => {
1758                    return Err(DecoderError::InvalidConfig(
1759                        "ModelPack should not have mask coefficients".to_string(),
1760                    ));
1761                }
1762                ConfigOutput::Classes(_) => {
1763                    return Err(DecoderError::InvalidConfig(
1764                        "ModelPack should not have classes output".to_string(),
1765                    ));
1766                }
1767            }
1768        }
1769
1770        if let Some(segmentation) = segment_ {
1771            if !split_decoders.is_empty() {
1772                let classes = Self::verify_modelpack_split_det(&split_decoders)?;
1773                Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1774                Ok(ModelType::ModelPackSegDetSplit {
1775                    detection: split_decoders,
1776                    segmentation,
1777                })
1778            } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1779                let classes = Self::verify_modelpack_det(&boxes, &scores)?;
1780                Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1781                Ok(ModelType::ModelPackSegDet {
1782                    boxes,
1783                    scores,
1784                    segmentation,
1785                })
1786            } else {
1787                Self::verify_modelpack_seg(&segmentation, None)?;
1788                Ok(ModelType::ModelPackSeg { segmentation })
1789            }
1790        } else if !split_decoders.is_empty() {
1791            Self::verify_modelpack_split_det(&split_decoders)?;
1792            Ok(ModelType::ModelPackDetSplit {
1793                detection: split_decoders,
1794            })
1795        } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1796            Self::verify_modelpack_det(&boxes, &scores)?;
1797            Ok(ModelType::ModelPackDet { boxes, scores })
1798        } else {
1799            Err(DecoderError::InvalidConfig(
1800                "Invalid ModelPack model outputs".to_string(),
1801            ))
1802        }
1803    }
1804
1805    fn verify_modelpack_det(
1806        boxes: &configs::Boxes,
1807        scores: &configs::Scores,
1808    ) -> Result<usize, DecoderError> {
1809        if boxes.shape.len() != 4 {
1810            return Err(DecoderError::InvalidConfig(format!(
1811                "Invalid ModelPack Boxes shape {:?}",
1812                boxes.shape
1813            )));
1814        }
1815        if scores.shape.len() != 3 {
1816            return Err(DecoderError::InvalidConfig(format!(
1817                "Invalid ModelPack Scores shape {:?}",
1818                scores.shape
1819            )));
1820        }
1821
1822        Self::verify_dshapes(
1823            &boxes.dshape,
1824            &boxes.shape,
1825            "Boxes",
1826            &[
1827                DimName::Batch,
1828                DimName::NumBoxes,
1829                DimName::Padding,
1830                DimName::BoxCoords,
1831            ],
1832        )?;
1833        Self::verify_dshapes(
1834            &scores.dshape,
1835            &scores.shape,
1836            "Scores",
1837            &[DimName::Batch, DimName::NumBoxes, DimName::NumClasses],
1838        )?;
1839
1840        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[1]);
1841        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[1]);
1842
1843        if boxes_num != scores_num {
1844            return Err(DecoderError::InvalidConfig(format!(
1845                "ModelPack Detection Boxes num {} incompatible with Scores num {}",
1846                boxes_num, scores_num
1847            )));
1848        }
1849
1850        let num_classes = if !scores.dshape.is_empty() {
1851            Self::get_class_count(&scores.dshape, None, None)?
1852        } else {
1853            Self::get_class_count_no_dshape(scores.into(), None)?
1854        };
1855
1856        Ok(num_classes)
1857    }
1858
1859    fn verify_modelpack_split_det(boxes: &[configs::Detection]) -> Result<usize, DecoderError> {
1860        let mut num_classes = None;
1861        for b in boxes {
1862            let Some(num_anchors) = b.anchors.as_ref().map(|a| a.len()) else {
1863                return Err(DecoderError::InvalidConfig(
1864                    "ModelPack Split Detection missing anchors".to_string(),
1865                ));
1866            };
1867
1868            if num_anchors == 0 {
1869                return Err(DecoderError::InvalidConfig(
1870                    "ModelPack Split Detection has zero anchors".to_string(),
1871                ));
1872            }
1873
1874            if b.shape.len() != 4 {
1875                return Err(DecoderError::InvalidConfig(format!(
1876                    "Invalid ModelPack Split Detection shape {:?}",
1877                    b.shape
1878                )));
1879            }
1880
1881            Self::verify_dshapes(
1882                &b.dshape,
1883                &b.shape,
1884                "Split Detection",
1885                &[
1886                    DimName::Batch,
1887                    DimName::Height,
1888                    DimName::Width,
1889                    DimName::NumAnchorsXFeatures,
1890                ],
1891            )?;
1892            let classes = if !b.dshape.is_empty() {
1893                Self::get_class_count(&b.dshape, None, Some(num_anchors))?
1894            } else {
1895                Self::get_class_count_no_dshape(b.into(), None)?
1896            };
1897
1898            match num_classes {
1899                Some(n) => {
1900                    if n != classes {
1901                        return Err(DecoderError::InvalidConfig(format!(
1902                            "ModelPack Split Detection inconsistent number of classes: previous {}, current {}",
1903                            n, classes
1904                        )));
1905                    }
1906                }
1907                None => {
1908                    num_classes = Some(classes);
1909                }
1910            }
1911        }
1912
1913        Ok(num_classes.unwrap_or(0))
1914    }
1915
1916    fn verify_modelpack_seg(
1917        segmentation: &configs::Segmentation,
1918        classes: Option<usize>,
1919    ) -> Result<(), DecoderError> {
1920        if segmentation.shape.len() != 4 {
1921            return Err(DecoderError::InvalidConfig(format!(
1922                "Invalid ModelPack Segmentation shape {:?}",
1923                segmentation.shape
1924            )));
1925        }
1926        Self::verify_dshapes(
1927            &segmentation.dshape,
1928            &segmentation.shape,
1929            "Segmentation",
1930            &[
1931                DimName::Batch,
1932                DimName::Height,
1933                DimName::Width,
1934                DimName::NumClasses,
1935            ],
1936        )?;
1937
1938        if let Some(classes) = classes {
1939            let seg_classes = if !segmentation.dshape.is_empty() {
1940                Self::get_class_count(&segmentation.dshape, None, None)?
1941            } else {
1942                Self::get_class_count_no_dshape(segmentation.into(), None)?
1943            };
1944
1945            if seg_classes != classes + 1 {
1946                return Err(DecoderError::InvalidConfig(format!(
1947                    "ModelPack Segmentation channels {} incompatible with number of classes {}",
1948                    seg_classes, classes
1949                )));
1950            }
1951        }
1952        Ok(())
1953    }
1954
1955    // verifies that dshapes match the given shape
1956    fn verify_dshapes(
1957        dshape: &[(DimName, usize)],
1958        shape: &[usize],
1959        name: &str,
1960        dims: &[DimName],
1961    ) -> Result<(), DecoderError> {
1962        for s in shape {
1963            if *s == 0 {
1964                return Err(DecoderError::InvalidConfig(format!(
1965                    "{} shape has zero dimension",
1966                    name
1967                )));
1968            }
1969        }
1970
1971        if shape.len() != dims.len() {
1972            return Err(DecoderError::InvalidConfig(format!(
1973                "{} shape length {} does not match expected dims length {}",
1974                name,
1975                shape.len(),
1976                dims.len()
1977            )));
1978        }
1979
1980        if dshape.is_empty() {
1981            return Ok(());
1982        }
1983        // check the dshape lengths match the shape lengths
1984        if dshape.len() != shape.len() {
1985            return Err(DecoderError::InvalidConfig(format!(
1986                "{} dshape length does not match shape length",
1987                name
1988            )));
1989        }
1990
1991        // check the dshape values match the shape values
1992        for ((dim_name, dim_size), shape_size) in dshape.iter().zip(shape) {
1993            if dim_size != shape_size {
1994                return Err(DecoderError::InvalidConfig(format!(
1995                    "{} dshape dimension {} size {} does not match shape size {}",
1996                    name, dim_name, dim_size, shape_size
1997                )));
1998            }
1999            if *dim_name == DimName::Padding && *dim_size != 1 {
2000                return Err(DecoderError::InvalidConfig(
2001                    "Padding dimension size must be 1".to_string(),
2002                ));
2003            }
2004
2005            if *dim_name == DimName::BoxCoords && *dim_size != 4 {
2006                return Err(DecoderError::InvalidConfig(
2007                    "BoxCoords dimension size must be 4".to_string(),
2008                ));
2009            }
2010        }
2011
2012        let dims_present = HashSet::<DimName>::from_iter(dshape.iter().map(|(name, _)| *name));
2013        for dim in dims {
2014            if !dims_present.contains(dim) {
2015                return Err(DecoderError::InvalidConfig(format!(
2016                    "{} dshape missing required dimension {:?}",
2017                    name, dim
2018                )));
2019            }
2020        }
2021
2022        Ok(())
2023    }
2024
2025    fn get_box_count(dshape: &[(DimName, usize)]) -> Option<usize> {
2026        for (dim_name, dim_size) in dshape {
2027            if *dim_name == DimName::NumBoxes {
2028                return Some(*dim_size);
2029            }
2030        }
2031        None
2032    }
2033
2034    fn get_class_count_no_dshape(
2035        config: ConfigOutputRef,
2036        protos: Option<usize>,
2037    ) -> Result<usize, DecoderError> {
2038        match config {
2039            ConfigOutputRef::Detection(detection) => match detection.decoder {
2040                DecoderType::Ultralytics => {
2041                    if detection.shape[1] <= 4 + protos.unwrap_or(0) {
2042                        return Err(DecoderError::InvalidConfig(format!(
2043                            "Invalid shape: Yolo num_features {} must be greater than {}",
2044                            detection.shape[1],
2045                            4 + protos.unwrap_or(0),
2046                        )));
2047                    }
2048                    Ok(detection.shape[1] - 4 - protos.unwrap_or(0))
2049                }
2050                DecoderType::ModelPack => {
2051                    let Some(num_anchors) = detection.anchors.as_ref().map(|a| a.len()) else {
2052                        return Err(DecoderError::Internal(
2053                            "ModelPack Detection missing anchors".to_string(),
2054                        ));
2055                    };
2056                    let anchors_x_features = detection.shape[3];
2057                    if anchors_x_features <= num_anchors * 5 {
2058                        return Err(DecoderError::InvalidConfig(format!(
2059                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2060                            anchors_x_features,
2061                            num_anchors * 5,
2062                        )));
2063                    }
2064
2065                    if !anchors_x_features.is_multiple_of(num_anchors) {
2066                        return Err(DecoderError::InvalidConfig(format!(
2067                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2068                            anchors_x_features, num_anchors
2069                        )));
2070                    }
2071                    Ok(anchors_x_features / num_anchors - 5)
2072                }
2073            },
2074
2075            ConfigOutputRef::Scores(scores) => match scores.decoder {
2076                DecoderType::Ultralytics => Ok(scores.shape[1]),
2077                DecoderType::ModelPack => Ok(scores.shape[2]),
2078            },
2079            ConfigOutputRef::Segmentation(seg) => Ok(seg.shape[3]),
2080            _ => Err(DecoderError::Internal(
2081                "Attempted to get class count from unsupported config output".to_owned(),
2082            )),
2083        }
2084    }
2085
2086    // get the class count from dshape or calculate from num_features
2087    fn get_class_count(
2088        dshape: &[(DimName, usize)],
2089        protos: Option<usize>,
2090        anchors: Option<usize>,
2091    ) -> Result<usize, DecoderError> {
2092        if dshape.is_empty() {
2093            return Ok(0);
2094        }
2095        // if it has num_classes in dshape, return it
2096        for (dim_name, dim_size) in dshape {
2097            if *dim_name == DimName::NumClasses {
2098                return Ok(*dim_size);
2099            }
2100        }
2101
2102        // number of classes can be calculated from num_features - 4 for yolo.  If the
2103        // model has protos, we also subtract the number of protos.
2104        for (dim_name, dim_size) in dshape {
2105            if *dim_name == DimName::NumFeatures {
2106                let protos = protos.unwrap_or(0);
2107                if protos + 4 >= *dim_size {
2108                    return Err(DecoderError::InvalidConfig(format!(
2109                        "Invalid shape: Yolo num_features {} must be greater than {}",
2110                        *dim_size,
2111                        protos + 4,
2112                    )));
2113                }
2114                return Ok(*dim_size - 4 - protos);
2115            }
2116        }
2117
2118        // number of classes can be calculated from number of anchors for modelpack
2119        // split detection
2120        if let Some(num_anchors) = anchors {
2121            for (dim_name, dim_size) in dshape {
2122                if *dim_name == DimName::NumAnchorsXFeatures {
2123                    let anchors_x_features = *dim_size;
2124                    if anchors_x_features <= num_anchors * 5 {
2125                        return Err(DecoderError::InvalidConfig(format!(
2126                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2127                            anchors_x_features,
2128                            num_anchors * 5,
2129                        )));
2130                    }
2131
2132                    if !anchors_x_features.is_multiple_of(num_anchors) {
2133                        return Err(DecoderError::InvalidConfig(format!(
2134                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2135                            anchors_x_features, num_anchors
2136                        )));
2137                    }
2138                    return Ok((anchors_x_features / num_anchors) - 5);
2139                }
2140            }
2141        }
2142        Err(DecoderError::InvalidConfig(
2143            "Cannot determine number of classes from dshape".to_owned(),
2144        ))
2145    }
2146
2147    fn get_protos_count(dshape: &[(DimName, usize)]) -> Option<usize> {
2148        for (dim_name, dim_size) in dshape {
2149            if *dim_name == DimName::NumProtos {
2150                return Some(*dim_size);
2151            }
2152        }
2153        None
2154    }
2155}