Skip to main content

edgefirst_decoder/decoder/
builder.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashSet;
5
6use super::config::ConfigOutputRef;
7use super::configs::{self, DecoderType, DecoderVersion, DimName, ModelType};
8use super::merge::DecodeProgram;
9use super::{ConfigOutput, ConfigOutputs, Decoder};
10use crate::per_scale::{DecodeDtype, PerScalePlan};
11use crate::schema::SchemaV2;
12use crate::DecoderError;
13
14/// Extract `(width, height)` from a schema [`crate::schema::InputSpec`].
15///
16/// Prefers named dims (`DimName::Width` / `DimName::Height`) when the
17/// `dshape` is populated, falling back to the NHWC convention
18/// (`shape[1] = H, shape[2] = W`) for 4-element shapes whenever either
19/// named dim is missing. Returns `None` for any other shape arity — the
20/// decoder treats that as "input dims unknown" and skips the EDGEAI-1303
21/// normalization path.
22///
23/// The fallback fires when **either** `Height` or `Width` is missing from
24/// the dshape (not only when both are absent), so a partially-named
25/// dshape (e.g. only `Width`) still resolves both dims via positional
26/// inference instead of silently disabling normalization.
27fn input_dims_from_spec(input: &crate::schema::InputSpec) -> Option<(usize, usize)> {
28    use crate::configs::DimName;
29    let mut h = None;
30    let mut w = None;
31    // `SchemaV2::validate()` doesn't currently enforce
32    // `dshape.len() <= shape.len()`, so a malformed schema can trip an
33    // out-of-bounds index here. Use `shape.get(i)` to silently skip
34    // dshape entries that overflow the shape — the caller treats
35    // missing dims as "unknown" and disables EDGEAI-1303 normalization.
36    for (i, (name, _)) in input.dshape.iter().enumerate() {
37        match name {
38            DimName::Height => h = input.shape.get(i).copied(),
39            DimName::Width => w = input.shape.get(i).copied(),
40            _ => {}
41        }
42    }
43    if (h.is_none() || w.is_none()) && input.shape.len() == 4 {
44        // NHWC default: [N, H, W, C]. Mirrors the per-scale `extract_hw`
45        // fallback (`crates/decoder/src/per_scale/plan.rs::extract_hw`).
46        // Only fill the missing axis so a partial named dshape still
47        // resolves both dims.
48        h = h.or(Some(input.shape[1]));
49        w = w.or(Some(input.shape[2]));
50    }
51    match (w, h) {
52        (Some(w), Some(h)) => Some((w, h)),
53        _ => None,
54    }
55}
56
#[cfg(test)]
mod input_dims_from_spec_tests {
    use super::input_dims_from_spec;
    use crate::configs::DimName;
    use crate::schema::InputSpec;

    /// Builds an `InputSpec` from a positional shape and a named dshape,
    /// leaving the camera adaptor unset.
    fn spec(shape: Vec<usize>, dshape: Vec<(DimName, usize)>) -> InputSpec {
        InputSpec {
            shape,
            dshape,
            cameraadaptor: None,
        }
    }

    #[test]
    fn named_dshape_resolves_dims() {
        // Fully named NHWC dshape: Height/Width are read by name.
        let named = vec![
            (DimName::Batch, 1),
            (DimName::Height, 480),
            (DimName::Width, 640),
            (DimName::NumFeatures, 3),
        ];
        let resolved = input_dims_from_spec(&spec(vec![1, 480, 640, 3], named));
        assert_eq!(resolved, Some((640, 480)));
    }

    #[test]
    fn empty_dshape_falls_back_to_nhwc_for_4d() {
        // No named dims at all: the 4-element shape is read as [N, H, W, C].
        let resolved = input_dims_from_spec(&spec(vec![1, 480, 640, 3], vec![]));
        assert_eq!(resolved, Some((640, 480)));
    }

    #[test]
    fn malformed_dshape_longer_than_shape_does_not_panic() {
        // Regression for Copilot review on PR #63: indexing
        // `input.shape[i]` while iterating dshape can OOB-panic when
        // `dshape.len() > shape.len()`. Entries past the end of the
        // shape must be treated as "dim missing" instead.
        let two_d_shape = vec![640, 480];
        let named = vec![
            (DimName::Width, 640),
            (DimName::Height, 480),
            (DimName::NumFeatures, 3), // overflow — index 2 ≥ shape.len()
        ];
        // The first two dshape entries are in bounds and resolve both
        // dims; the third is a no-op, so the result is `Some`.
        let resolved = input_dims_from_spec(&spec(two_d_shape, named));
        assert_eq!(resolved, Some((640, 480)));
    }

    #[test]
    fn malformed_dshape_only_overflow_returns_none() {
        // Height and Width are both named past the shape boundary, so
        // neither resolves, and the 4-D NHWC fallback doesn't fire
        // (shape.len() == 1): the result is `None` rather than a panic.
        let named = vec![
            (DimName::NumFeatures, 3),
            (DimName::Height, 480),
            (DimName::Width, 640),
        ];
        let resolved = input_dims_from_spec(&spec(vec![1], named));
        assert_eq!(resolved, None);
    }
}
125
/// Builder that collects a model configuration plus decode parameters and
/// is turned into a [`Decoder`] by `build()`.
///
/// A configuration source must be supplied (via one of the `with_config*`
/// methods, `with_schema`, or `add_output`) before building; all other
/// fields start from the values set in the `Default` implementation.
#[derive(Debug, Clone, PartialEq)]
pub struct DecoderBuilder {
    /// Selected model configuration source. `None` until a configuration
    /// method has been called; each such call replaces any earlier source.
    config_src: Option<ConfigSource>,
    /// IoU threshold (builder default: 0.5).
    // NOTE(review): presumably the NMS overlap threshold — confirm against `build()`.
    iou_threshold: f32,
    /// Score threshold (builder default: 0.5).
    // NOTE(review): presumably the minimum score a detection must reach — confirm.
    score_threshold: f32,
    /// NMS mode.
    ///
    /// - `Some(Nms::Auto)` — resolve from config or fallback to
    ///   `ClassAgnostic` (builder default)
    /// - `Some(Nms::ClassAgnostic)` — explicit class-agnostic override
    /// - `Some(Nms::ClassAware)` — explicit class-aware override
    /// - `None` — bypass NMS entirely
    nms: Option<configs::Nms>,
    /// Output dtype for the per-scale fast path. Has no effect on
    /// schemas without per-scale children (which use the legacy decode
    /// path).
    decode_dtype: DecodeDtype,
    /// Builder default: 300.
    // NOTE(review): name suggests the candidate cap applied before NMS — confirm.
    pre_nms_top_k: usize,
    /// Builder default: 300.
    // NOTE(review): name suggests the maximum number of detections returned — confirm.
    max_det: usize,
    /// Explicit override for the model input dimensions `(width, height)`,
    /// consumed by EDGEAI-1303 normalization. When set, takes precedence
    /// over schema-derived dims; when `None`, the value is read from the
    /// schema's `input.shape` / `input.dshape` at build time.
    input_dims: Option<(usize, usize)>,
}
151
/// The form in which the model configuration was handed to the builder.
///
/// String variants are stored unparsed; per the builder method docs they
/// are only deserialized when `build()` runs.
#[derive(Debug, Clone, PartialEq)]
enum ConfigSource {
    /// Raw YAML configuration text (set by `with_config_yaml_str`).
    Yaml(String),
    /// Raw JSON configuration text (set by `with_config_json_str`).
    Json(String),
    /// Already-constructed configuration, either passed in whole
    /// (`with_config`, the `with_config_yolo_*` / `with_config_modelpack_*`
    /// helpers) or assembled incrementally (`add_output`,
    /// `with_decoder_version`).
    Config(ConfigOutputs),
    /// Schema v2 metadata. During build the schema is either converted
    /// to a legacy [`ConfigOutputs`] (flat case) or used to construct a
    /// [`DecodeProgram`] that performs per-child dequant + merge at
    /// decode time.
    Schema(SchemaV2),
}
163
164impl Default for DecoderBuilder {
165    /// Creates a default DecoderBuilder with no configuration and 0.5 score
166    /// threshold and 0.5 IoU threshold.
167    ///
168    /// A valid configuration must be provided before building the Decoder.
169    ///
170    /// # Examples
171    /// ```rust
172    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
173    /// # fn main() -> DecoderResult<()> {
174    /// #  let config_yaml = edgefirst_bench::testdata::read_to_string("modelpack_split.yaml").to_string();
175    /// let decoder = DecoderBuilder::default()
176    ///     .with_config_yaml_str(config_yaml)
177    ///     .build()?;
178    /// assert_eq!(decoder.score_threshold, 0.5);
179    /// assert_eq!(decoder.iou_threshold, 0.5);
180    ///
181    /// # Ok(())
182    /// # }
183    /// ```
184    fn default() -> Self {
185        Self {
186            config_src: None,
187            iou_threshold: 0.5,
188            score_threshold: 0.5,
189            nms: Some(configs::Nms::Auto),
190            decode_dtype: DecodeDtype::F32,
191            pre_nms_top_k: 300,
192            max_det: 300,
193            input_dims: None,
194        }
195    }
196}
197
198impl DecoderBuilder {
199    /// Creates a default DecoderBuilder with no configuration and 0.5 score
200    /// threshold and 0.5 IoU threshold.
201    ///
202    /// A valid configuration must be provided before building the Decoder.
203    ///
204    /// # Examples
205    /// ```rust
206    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
207    /// # fn main() -> DecoderResult<()> {
208    /// #  let config_yaml = edgefirst_bench::testdata::read_to_string("modelpack_split.yaml").to_string();
209    /// let decoder = DecoderBuilder::new()
210    ///     .with_config_yaml_str(config_yaml)
211    ///     .build()?;
212    /// assert_eq!(decoder.score_threshold, 0.5);
213    /// assert_eq!(decoder.iou_threshold, 0.5);
214    ///
215    /// # Ok(())
216    /// # }
217    /// ```
218    pub fn new() -> Self {
219        Self::default()
220    }
221
222    /// Loads a model configuration in YAML format. Does not check if the string
223    /// is a correct configuration file. Use `DecoderBuilder.build()` to
224    /// deserialize the YAML and parse the model configuration.
225    ///
226    /// # Examples
227    /// ```rust,no_run
228    /// # use edgefirst_decoder::DecoderBuilder;
229    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
230    /// let config_yaml = std::fs::read_to_string("modelpack_split.yaml")?;
231    /// let decoder = DecoderBuilder::new()
232    ///     .with_config_yaml_str(config_yaml)
233    ///     .build()?;
234    ///
235    /// # Ok(())
236    /// # }
237    /// ```
238    pub fn with_config_yaml_str(mut self, yaml_str: String) -> Self {
239        self.config_src.replace(ConfigSource::Yaml(yaml_str));
240        self
241    }
242
243    /// Loads a model configuration in JSON format. Does not check if the string
244    /// is a correct configuration file. Use `DecoderBuilder.build()` to
245    /// deserialize the JSON and parse the model configuration.
246    ///
247    /// # Examples
248    /// ```rust,no_run
249    /// # use edgefirst_decoder::DecoderBuilder;
250    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
251    /// let config_json = std::fs::read_to_string("modelpack_split.json")?;
252    /// let decoder = DecoderBuilder::new()
253    ///     .with_config_json_str(config_json)
254    ///     .build()?;
255    ///
256    /// # Ok(())
257    /// # }
258    /// ```
259    pub fn with_config_json_str(mut self, json_str: String) -> Self {
260        self.config_src.replace(ConfigSource::Json(json_str));
261        self
262    }
263
264    /// Loads a model configuration. Does not check if the configuration is
265    /// correct. Intended to be used when the user needs control over the
266    /// deserialize of the configuration information. Use
267    /// `DecoderBuilder.build()` to parse the model configuration.
268    ///
269    /// # Examples
270    /// ```rust,no_run
271    /// # use edgefirst_decoder::{DecoderBuilder, ConfigOutputs};
272    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
273    /// let config_json = std::fs::read_to_string("modelpack_split.json")?;
274    /// let config: ConfigOutputs = serde_json::from_str(&config_json)?;
275    /// let decoder = DecoderBuilder::new().with_config(config).build()?;
276    ///
277    /// # Ok(())
278    /// # }
279    /// ```
280    pub fn with_config(mut self, config: ConfigOutputs) -> Self {
281        self.config_src.replace(ConfigSource::Config(config));
282        self
283    }
284
285    /// Configure the decoder from a schema v2 metadata document.
286    ///
287    /// Accepts a [`SchemaV2`] as produced by [`SchemaV2::parse_json`],
288    /// [`SchemaV2::parse_yaml`], [`SchemaV2::parse_file`], or
289    /// constructed programmatically. The builder validates the schema,
290    /// compiles a [`DecodeProgram`] for any split logical outputs
291    /// (per-scale or channel sub-splits), and downconverts the
292    /// logical-level semantics to the legacy [`ConfigOutputs`]
293    /// representation consumed by the existing decoder dispatch.
294    ///
295    /// # Examples
296    ///
297    /// ```rust,no_run
298    /// use edgefirst_decoder::{DecoderBuilder, DecoderResult};
299    /// use edgefirst_decoder::schema::SchemaV2;
300    ///
301    /// # fn main() -> DecoderResult<()> {
302    /// let schema = SchemaV2::parse_file("model/edgefirst.json")?;
303    /// let decoder = DecoderBuilder::new()
304    ///     .with_schema(schema)
305    ///     .with_score_threshold(0.25)
306    ///     .build()?;
307    /// # Ok(())
308    /// # }
309    /// ```
310    pub fn with_schema(mut self, schema: SchemaV2) -> Self {
311        self.config_src.replace(ConfigSource::Schema(schema));
312        self
313    }
314
315    /// Choose the output dtype for the per-scale decoder pipeline.
316    ///
317    /// Defaults to [`DecodeDtype::F32`]. Has no effect on schemas
318    /// without per-scale children (which use the legacy decode path).
319    /// `F16` saves ~2× memory bandwidth at the cost of 10-bit mantissa
320    /// precision; empirically safe for YOLO-family models.
321    ///
322    /// # Examples
323    ///
324    /// ```rust,no_run
325    /// use edgefirst_decoder::{DecodeDtype, DecoderBuilder, DecoderResult};
326    /// use edgefirst_decoder::schema::SchemaV2;
327    ///
328    /// # fn main() -> DecoderResult<()> {
329    /// let schema = SchemaV2::parse_file("model/edgefirst.json")?;
330    /// let decoder = DecoderBuilder::new()
331    ///     .with_schema(schema)
332    ///     .with_decode_dtype(DecodeDtype::F32)
333    ///     .build()?;
334    /// # Ok(())
335    /// # }
336    /// ```
337    pub fn with_decode_dtype(mut self, dtype: DecodeDtype) -> Self {
338        self.decode_dtype = dtype;
339        self
340    }
341
342    /// Loads a YOLO detection model configuration.  Use
343    /// `DecoderBuilder.build()` to parse the model configuration.
344    ///
345    /// # Examples
346    /// ```rust
347    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
348    /// # fn main() -> DecoderResult<()> {
349    /// let decoder = DecoderBuilder::new()
350    ///     .with_config_yolo_det(
351    ///         configs::Detection {
352    ///             anchors: None,
353    ///             decoder: configs::DecoderType::Ultralytics,
354    ///             quantization: Some(configs::QuantTuple(0.012345, 26)),
355    ///             shape: vec![1, 84, 8400],
356    ///             dshape: Vec::new(),
357    ///             normalized: Some(true),
358    ///         },
359    ///         None,
360    ///     )
361    ///     .build()?;
362    ///
363    /// # Ok(())
364    /// # }
365    /// ```
366    pub fn with_config_yolo_det(
367        mut self,
368        boxes: configs::Detection,
369        version: Option<DecoderVersion>,
370    ) -> Self {
371        let config = ConfigOutputs {
372            outputs: vec![ConfigOutput::Detection(boxes)],
373            decoder_version: version,
374            ..Default::default()
375        };
376        self.config_src.replace(ConfigSource::Config(config));
377        self
378    }
379
380    /// Loads a YOLO split detection model configuration.  Use
381    /// `DecoderBuilder.build()` to parse the model configuration.
382    ///
383    /// # Examples
384    /// ```rust
385    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
386    /// # fn main() -> DecoderResult<()> {
387    /// let boxes_config = configs::Boxes {
388    ///     decoder: configs::DecoderType::Ultralytics,
389    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
390    ///     shape: vec![1, 4, 8400],
391    ///     dshape: Vec::new(),
392    ///     normalized: Some(true),
393    /// };
394    /// let scores_config = configs::Scores {
395    ///     decoder: configs::DecoderType::Ultralytics,
396    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
397    ///     shape: vec![1, 80, 8400],
398    ///     dshape: Vec::new(),
399    /// };
400    /// let decoder = DecoderBuilder::new()
401    ///     .with_config_yolo_split_det(boxes_config, scores_config)
402    ///     .build()?;
403    /// # Ok(())
404    /// # }
405    /// ```
406    pub fn with_config_yolo_split_det(
407        mut self,
408        boxes: configs::Boxes,
409        scores: configs::Scores,
410    ) -> Self {
411        let config = ConfigOutputs {
412            outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
413            ..Default::default()
414        };
415        self.config_src.replace(ConfigSource::Config(config));
416        self
417    }
418
419    /// Loads a YOLO segmentation model configuration.  Use
420    /// `DecoderBuilder.build()` to parse the model configuration.
421    ///
422    /// # Examples
423    /// ```rust
424    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
425    /// # fn main() -> DecoderResult<()> {
426    /// let seg_config = configs::Detection {
427    ///     decoder: configs::DecoderType::Ultralytics,
428    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
429    ///     shape: vec![1, 116, 8400],
430    ///     anchors: None,
431    ///     dshape: Vec::new(),
432    ///     normalized: Some(true),
433    /// };
434    /// let protos_config = configs::Protos {
435    ///     decoder: configs::DecoderType::Ultralytics,
436    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
437    ///     shape: vec![1, 160, 160, 32],
438    ///     dshape: Vec::new(),
439    /// };
440    /// let decoder = DecoderBuilder::new()
441    ///     .with_config_yolo_segdet(
442    ///         seg_config,
443    ///         protos_config,
444    ///         Some(configs::DecoderVersion::Yolov8),
445    ///     )
446    ///     .build()?;
447    /// # Ok(())
448    /// # }
449    /// ```
450    pub fn with_config_yolo_segdet(
451        mut self,
452        boxes: configs::Detection,
453        protos: configs::Protos,
454        version: Option<DecoderVersion>,
455    ) -> Self {
456        let config = ConfigOutputs {
457            outputs: vec![ConfigOutput::Detection(boxes), ConfigOutput::Protos(protos)],
458            decoder_version: version,
459            ..Default::default()
460        };
461        self.config_src.replace(ConfigSource::Config(config));
462        self
463    }
464
465    /// Loads a YOLO split segmentation model configuration.  Use
466    /// `DecoderBuilder.build()` to parse the model configuration.
467    ///
468    /// # Examples
469    /// ```rust
470    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
471    /// # fn main() -> DecoderResult<()> {
472    /// let boxes_config = configs::Boxes {
473    ///     decoder: configs::DecoderType::Ultralytics,
474    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
475    ///     shape: vec![1, 4, 8400],
476    ///     dshape: Vec::new(),
477    ///     normalized: Some(true),
478    /// };
479    /// let scores_config = configs::Scores {
480    ///     decoder: configs::DecoderType::Ultralytics,
481    ///     quantization: Some(configs::QuantTuple(0.012345, 14)),
482    ///     shape: vec![1, 80, 8400],
483    ///     dshape: Vec::new(),
484    /// };
485    /// let mask_config = configs::MaskCoefficients {
486    ///     decoder: configs::DecoderType::Ultralytics,
487    ///     quantization: Some(configs::QuantTuple(0.0064123, 125)),
488    ///     shape: vec![1, 32, 8400],
489    ///     dshape: Vec::new(),
490    /// };
491    /// let protos_config = configs::Protos {
492    ///     decoder: configs::DecoderType::Ultralytics,
493    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
494    ///     shape: vec![1, 160, 160, 32],
495    ///     dshape: Vec::new(),
496    /// };
497    /// let decoder = DecoderBuilder::new()
498    ///     .with_config_yolo_split_segdet(boxes_config, scores_config, mask_config, protos_config)
499    ///     .build()?;
500    /// # Ok(())
501    /// # }
502    /// ```
503    pub fn with_config_yolo_split_segdet(
504        mut self,
505        boxes: configs::Boxes,
506        scores: configs::Scores,
507        mask_coefficients: configs::MaskCoefficients,
508        protos: configs::Protos,
509    ) -> Self {
510        let config = ConfigOutputs {
511            outputs: vec![
512                ConfigOutput::Boxes(boxes),
513                ConfigOutput::Scores(scores),
514                ConfigOutput::MaskCoefficients(mask_coefficients),
515                ConfigOutput::Protos(protos),
516            ],
517            ..Default::default()
518        };
519        self.config_src.replace(ConfigSource::Config(config));
520        self
521    }
522
523    /// Loads a ModelPack detection model configuration.  Use
524    /// `DecoderBuilder.build()` to parse the model configuration.
525    ///
526    /// # Examples
527    /// ```rust
528    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
529    /// # fn main() -> DecoderResult<()> {
530    /// let boxes_config = configs::Boxes {
531    ///     decoder: configs::DecoderType::ModelPack,
532    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
533    ///     shape: vec![1, 8400, 1, 4],
534    ///     dshape: Vec::new(),
535    ///     normalized: Some(true),
536    /// };
537    /// let scores_config = configs::Scores {
538    ///     decoder: configs::DecoderType::ModelPack,
539    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
540    ///     shape: vec![1, 8400, 3],
541    ///     dshape: Vec::new(),
542    /// };
543    /// let decoder = DecoderBuilder::new()
544    ///     .with_config_modelpack_det(boxes_config, scores_config)
545    ///     .build()?;
546    /// # Ok(())
547    /// # }
548    /// ```
549    pub fn with_config_modelpack_det(
550        mut self,
551        boxes: configs::Boxes,
552        scores: configs::Scores,
553    ) -> Self {
554        let config = ConfigOutputs {
555            outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
556            ..Default::default()
557        };
558        self.config_src.replace(ConfigSource::Config(config));
559        self
560    }
561
562    /// Loads a ModelPack split detection model configuration. Use
563    /// `DecoderBuilder.build()` to parse the model configuration.
564    ///
565    /// # Examples
566    /// ```rust
567    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
568    /// # fn main() -> DecoderResult<()> {
569    /// let config0 = configs::Detection {
570    ///     anchors: Some(vec![
571    ///         [0.13750000298023224, 0.2074074000120163],
572    ///         [0.2541666626930237, 0.21481481194496155],
573    ///         [0.23125000298023224, 0.35185185074806213],
574    ///     ]),
575    ///     decoder: configs::DecoderType::ModelPack,
576    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
577    ///     shape: vec![1, 17, 30, 18],
578    ///     dshape: Vec::new(),
579    ///     normalized: Some(true),
580    /// };
581    /// let config1 = configs::Detection {
582    ///     anchors: Some(vec![
583    ///         [0.36666667461395264, 0.31481480598449707],
584    ///         [0.38749998807907104, 0.4740740656852722],
585    ///         [0.5333333611488342, 0.644444465637207],
586    ///     ]),
587    ///     decoder: configs::DecoderType::ModelPack,
588    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
589    ///     shape: vec![1, 9, 15, 18],
590    ///     dshape: Vec::new(),
591    ///     normalized: Some(true),
592    /// };
593    ///
594    /// let decoder = DecoderBuilder::new()
595    ///     .with_config_modelpack_det_split(vec![config0, config1])
596    ///     .build()?;
597    /// # Ok(())
598    /// # }
599    /// ```
600    pub fn with_config_modelpack_det_split(mut self, boxes: Vec<configs::Detection>) -> Self {
601        let outputs = boxes.into_iter().map(ConfigOutput::Detection).collect();
602        let config = ConfigOutputs {
603            outputs,
604            ..Default::default()
605        };
606        self.config_src.replace(ConfigSource::Config(config));
607        self
608    }
609
610    /// Loads a ModelPack segmentation detection model configuration. Use
611    /// `DecoderBuilder.build()` to parse the model configuration.
612    ///
613    /// # Examples
614    /// ```rust
615    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
616    /// # fn main() -> DecoderResult<()> {
617    /// let boxes_config = configs::Boxes {
618    ///     decoder: configs::DecoderType::ModelPack,
619    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
620    ///     shape: vec![1, 8400, 1, 4],
621    ///     dshape: Vec::new(),
622    ///     normalized: Some(true),
623    /// };
624    /// let scores_config = configs::Scores {
625    ///     decoder: configs::DecoderType::ModelPack,
626    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
627    ///     shape: vec![1, 8400, 2],
628    ///     dshape: Vec::new(),
629    /// };
630    /// let seg_config = configs::Segmentation {
631    ///     decoder: configs::DecoderType::ModelPack,
632    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
633    ///     shape: vec![1, 640, 640, 3],
634    ///     dshape: Vec::new(),
635    /// };
636    /// let decoder = DecoderBuilder::new()
637    ///     .with_config_modelpack_segdet(boxes_config, scores_config, seg_config)
638    ///     .build()?;
639    /// # Ok(())
640    /// # }
641    /// ```
642    pub fn with_config_modelpack_segdet(
643        mut self,
644        boxes: configs::Boxes,
645        scores: configs::Scores,
646        segmentation: configs::Segmentation,
647    ) -> Self {
648        let config = ConfigOutputs {
649            outputs: vec![
650                ConfigOutput::Boxes(boxes),
651                ConfigOutput::Scores(scores),
652                ConfigOutput::Segmentation(segmentation),
653            ],
654            ..Default::default()
655        };
656        self.config_src.replace(ConfigSource::Config(config));
657        self
658    }
659
660    /// Loads a ModelPack segmentation split detection model configuration. Use
661    /// `DecoderBuilder.build()` to parse the model configuration.
662    ///
663    /// # Examples
664    /// ```rust
665    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
666    /// # fn main() -> DecoderResult<()> {
667    /// let config0 = configs::Detection {
668    ///     anchors: Some(vec![
669    ///         [0.36666667461395264, 0.31481480598449707],
670    ///         [0.38749998807907104, 0.4740740656852722],
671    ///         [0.5333333611488342, 0.644444465637207],
672    ///     ]),
673    ///     decoder: configs::DecoderType::ModelPack,
674    ///     quantization: Some(configs::QuantTuple(0.08547406643629074, 174)),
675    ///     shape: vec![1, 9, 15, 18],
676    ///     dshape: Vec::new(),
677    ///     normalized: Some(true),
678    /// };
679    /// let config1 = configs::Detection {
680    ///     anchors: Some(vec![
681    ///         [0.13750000298023224, 0.2074074000120163],
682    ///         [0.2541666626930237, 0.21481481194496155],
683    ///         [0.23125000298023224, 0.35185185074806213],
684    ///     ]),
685    ///     decoder: configs::DecoderType::ModelPack,
686    ///     quantization: Some(configs::QuantTuple(0.09929127991199493, 183)),
687    ///     shape: vec![1, 17, 30, 18],
688    ///     dshape: Vec::new(),
689    ///     normalized: Some(true),
690    /// };
691    /// let seg_config = configs::Segmentation {
692    ///     decoder: configs::DecoderType::ModelPack,
693    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
694    ///     shape: vec![1, 640, 640, 2],
695    ///     dshape: Vec::new(),
696    /// };
697    /// let decoder = DecoderBuilder::new()
698    ///     .with_config_modelpack_segdet_split(vec![config0, config1], seg_config)
699    ///     .build()?;
700    /// # Ok(())
701    /// # }
702    /// ```
703    pub fn with_config_modelpack_segdet_split(
704        mut self,
705        boxes: Vec<configs::Detection>,
706        segmentation: configs::Segmentation,
707    ) -> Self {
708        let mut outputs = boxes
709            .into_iter()
710            .map(ConfigOutput::Detection)
711            .collect::<Vec<_>>();
712        outputs.push(ConfigOutput::Segmentation(segmentation));
713        let config = ConfigOutputs {
714            outputs,
715            ..Default::default()
716        };
717        self.config_src.replace(ConfigSource::Config(config));
718        self
719    }
720
721    /// Loads a ModelPack segmentation model configuration. Use
722    /// `DecoderBuilder.build()` to parse the model configuration.
723    ///
724    /// # Examples
725    /// ```rust
726    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
727    /// # fn main() -> DecoderResult<()> {
728    /// let seg_config = configs::Segmentation {
729    ///     decoder: configs::DecoderType::ModelPack,
730    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
731    ///     shape: vec![1, 640, 640, 3],
732    ///     dshape: Vec::new(),
733    /// };
734    /// let decoder = DecoderBuilder::new()
735    ///     .with_config_modelpack_seg(seg_config)
736    ///     .build()?;
737    /// # Ok(())
738    /// # }
739    /// ```
740    pub fn with_config_modelpack_seg(mut self, segmentation: configs::Segmentation) -> Self {
741        let config = ConfigOutputs {
742            outputs: vec![ConfigOutput::Segmentation(segmentation)],
743            ..Default::default()
744        };
745        self.config_src.replace(ConfigSource::Config(config));
746        self
747    }
748
749    /// Add an output to the decoder configuration.
750    ///
751    /// Incrementally builds the model configuration by adding outputs one at
752    /// a time. The decoder resolves the model type from the combination of
753    /// outputs during `build()`.
754    ///
755    /// If `dshape` is non-empty on the output, `shape` is automatically
756    /// derived from it (the size component of each named dimension). This
757    /// prevents conflicts between `shape` and `dshape`.
758    ///
759    /// This uses the programmatic config path. Calling this after
760    /// `with_config_json_str()` or `with_config_yaml_str()` replaces the
761    /// string-based config source.
762    ///
763    /// # Examples
764    /// ```rust
765    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
766    /// # fn main() -> DecoderResult<()> {
767    /// let decoder = DecoderBuilder::new()
768    ///     .add_output(ConfigOutput::Scores(configs::Scores {
769    ///         decoder: configs::DecoderType::Ultralytics,
770    ///         dshape: vec![
771    ///             (configs::DimName::Batch, 1),
772    ///             (configs::DimName::NumClasses, 80),
773    ///             (configs::DimName::NumBoxes, 8400),
774    ///         ],
775    ///         ..Default::default()
776    ///     }))
777    ///     .add_output(ConfigOutput::Boxes(configs::Boxes {
778    ///         decoder: configs::DecoderType::Ultralytics,
779    ///         dshape: vec![
780    ///             (configs::DimName::Batch, 1),
781    ///             (configs::DimName::BoxCoords, 4),
782    ///             (configs::DimName::NumBoxes, 8400),
783    ///         ],
784    ///         ..Default::default()
785    ///     }))
786    ///     .build()?;
787    /// # Ok(())
788    /// # }
789    /// ```
790    pub fn add_output(mut self, output: ConfigOutput) -> Self {
791        if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
792            self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
793        }
794        if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
795            config.outputs.push(Self::normalize_output(output));
796        }
797        self
798    }
799
800    /// Sets the decoder version for Ultralytics models.
801    ///
802    /// This is used with `add_output()` to specify the YOLO architecture
803    /// version when it cannot be inferred from the output shapes alone.
804    ///
805    /// - `Yolov5`, `Yolov8`, `Yolo11`: Traditional models requiring external
806    ///   NMS
807    /// - `Yolo26`: End-to-end models with NMS embedded in the model graph
808    ///
809    /// # Examples
810    /// ```rust
811    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
812    /// # fn main() -> DecoderResult<()> {
813    /// let decoder = DecoderBuilder::new()
814    ///     .add_output(ConfigOutput::Detection(configs::Detection {
815    ///         decoder: configs::DecoderType::Ultralytics,
816    ///         dshape: vec![
817    ///             (configs::DimName::Batch, 1),
818    ///             (configs::DimName::NumBoxes, 100),
819    ///             (configs::DimName::NumFeatures, 6),
820    ///         ],
821    ///         ..Default::default()
822    ///     }))
823    ///     .with_decoder_version(configs::DecoderVersion::Yolo26)
824    ///     .build()?;
825    /// # Ok(())
826    /// # }
827    /// ```
828    pub fn with_decoder_version(mut self, version: configs::DecoderVersion) -> Self {
829        if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
830            self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
831        }
832        if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
833            config.decoder_version = Some(version);
834        }
835        self
836    }
837
838    /// Normalize an output: if dshape is non-empty, derive shape from it.
839    fn normalize_output(mut output: ConfigOutput) -> ConfigOutput {
840        fn normalize_shape(shape: &mut Vec<usize>, dshape: &[(configs::DimName, usize)]) {
841            if !dshape.is_empty() {
842                *shape = dshape.iter().map(|(_, size)| *size).collect();
843            }
844        }
845        match &mut output {
846            ConfigOutput::Detection(c) => normalize_shape(&mut c.shape, &c.dshape),
847            ConfigOutput::Boxes(c) => normalize_shape(&mut c.shape, &c.dshape),
848            ConfigOutput::Scores(c) => normalize_shape(&mut c.shape, &c.dshape),
849            ConfigOutput::Protos(c) => normalize_shape(&mut c.shape, &c.dshape),
850            ConfigOutput::Segmentation(c) => normalize_shape(&mut c.shape, &c.dshape),
851            ConfigOutput::MaskCoefficients(c) => normalize_shape(&mut c.shape, &c.dshape),
852            ConfigOutput::Mask(c) => normalize_shape(&mut c.shape, &c.dshape),
853            ConfigOutput::Classes(c) => normalize_shape(&mut c.shape, &c.dshape),
854        }
855        output
856    }
857
858    /// Sets the scores threshold of the decoder
859    ///
860    /// # Examples
861    /// ```rust
862    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
863    /// # fn main() -> DecoderResult<()> {
864    /// # let config_json = edgefirst_bench::testdata::read_to_string("modelpack_split.json").to_string();
865    /// let decoder = DecoderBuilder::new()
866    ///     .with_config_json_str(config_json)
867    ///     .with_score_threshold(0.654)
868    ///     .build()?;
869    /// assert_eq!(decoder.score_threshold, 0.654);
870    /// # Ok(())
871    /// # }
872    /// ```
873    pub fn with_score_threshold(mut self, score_threshold: f32) -> Self {
874        self.score_threshold = score_threshold;
875        self
876    }
877
878    /// Sets the IOU threshold of the decoder. Has no effect when NMS is set to
879    /// `None`
880    ///
881    /// # Examples
882    /// ```rust
883    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
884    /// # fn main() -> DecoderResult<()> {
885    /// # let config_json = edgefirst_bench::testdata::read_to_string("modelpack_split.json").to_string();
886    /// let decoder = DecoderBuilder::new()
887    ///     .with_config_json_str(config_json)
888    ///     .with_iou_threshold(0.654)
889    ///     .build()?;
890    /// assert_eq!(decoder.iou_threshold, 0.654);
891    /// # Ok(())
892    /// # }
893    /// ```
894    pub fn with_iou_threshold(mut self, iou_threshold: f32) -> Self {
895        self.iou_threshold = iou_threshold;
896        self
897    }
898
899    /// Sets the NMS mode for the decoder.
900    ///
901    /// - `Some(Nms::Auto)` — resolve from model config (e.g. `edgefirst.json`)
902    ///   or fall back to `ClassAgnostic` (this is the builder default)
903    /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS: suppress overlapping
904    ///   boxes regardless of class label
905    /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
906    ///   share the same class label AND overlap above the IoU threshold
907    /// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
908    ///
909    /// # Examples
910    /// ```rust
911    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::Nms};
912    /// # fn main() -> DecoderResult<()> {
913    /// # let config_json = edgefirst_bench::testdata::read_to_string("modelpack_split.json").to_string();
914    /// let decoder = DecoderBuilder::new()
915    ///     .with_config_json_str(config_json)
916    ///     .with_nms(Some(Nms::ClassAware))
917    ///     .build()?;
918    /// assert_eq!(decoder.nms, Some(Nms::ClassAware));
919    /// # Ok(())
920    /// # }
921    /// ```
922    pub fn with_nms(mut self, nms: Option<configs::Nms>) -> Self {
923        self.nms = nms;
924        self
925    }
926
927    /// Sets the maximum number of candidate boxes fed into NMS after score
928    /// filtering.  Uses partial sort (O(N)) to select the top-K candidates,
929    /// dramatically reducing the O(N²) NMS cost when many low-confidence
930    /// proposals pass the threshold (common with mAP eval at 0.001).
931    ///
932    /// Default: 300.
933    ///
934    /// # ⚠️ Validation vs Deployment
935    ///
936    /// The default is appropriate for **deployment** where
937    /// `score_threshold ≥ 0.25` means few anchors survive filtering and
938    /// top-K is effectively a no-op.
939    ///
940    /// For **COCO mAP evaluation** (`score_threshold ≈ 0.001`), set this to
941    /// the total anchor count (8 400 for standard 640 × 640 YOLO models) or
942    /// to `0` (no limit) so that all score-passing candidates reach NMS.
943    /// Failing to do so causes **~9 pp box mAP loss** — the decoder math is
944    /// correct but the evaluation protocol requires full recall across the
945    /// confidence range.
946    ///
947    /// Post-processing latency scales with candidate count. At deployment
948    /// thresholds the cost difference is negligible; at validation thresholds
949    /// it is measurable but necessary for correct results.
950    ///
951    /// # Examples
952    ///
953    /// Deployment (default top-K, high score threshold):
954    /// ```rust
955    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
956    /// # fn main() -> DecoderResult<()> {
957    /// # let config_json = edgefirst_bench::testdata::read_to_string("modelpack_split.json").to_string();
958    /// let decoder = DecoderBuilder::new()
959    ///     .with_config_json_str(config_json)
960    ///     .with_score_threshold(0.25)
961    ///     // pre_nms_top_k defaults to 300 — appropriate here
962    ///     .build()?;
963    /// # Ok(())
964    /// # }
965    /// ```
966    ///
967    /// COCO mAP evaluation (pass all anchors to NMS):
968    /// ```rust
969    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
970    /// # fn main() -> DecoderResult<()> {
971    /// # let config_json = edgefirst_bench::testdata::read_to_string("modelpack_split.json").to_string();
972    /// let decoder = DecoderBuilder::new()
973    ///     .with_config_json_str(config_json)
974    ///     .with_score_threshold(0.001)
975    ///     .with_pre_nms_top_k(8400)  // all YOLO anchors
976    ///     .with_max_det(300)
977    ///     .build()?;
978    /// assert_eq!(decoder.pre_nms_top_k, 8400);
979    /// # Ok(())
980    /// # }
981    /// ```
982    pub fn with_pre_nms_top_k(mut self, pre_nms_top_k: usize) -> Self {
983        self.pre_nms_top_k = pre_nms_top_k;
984        self
985    }
986
987    /// Sets the maximum number of detections returned after NMS.
988    /// Matches the Ultralytics `max_det` parameter.
989    ///
990    /// Default: 300.
991    ///
992    /// # Examples
993    /// ```rust
994    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
995    /// # fn main() -> DecoderResult<()> {
996    /// # let config_json = edgefirst_bench::testdata::read_to_string("modelpack_split.json").to_string();
997    /// let decoder = DecoderBuilder::new()
998    ///     .with_config_json_str(config_json)
999    ///     .with_max_det(100)
1000    ///     .build()?;
1001    /// assert_eq!(decoder.max_det, 100);
1002    /// # Ok(())
1003    /// # }
1004    /// ```
1005    pub fn with_max_det(mut self, max_det: usize) -> Self {
1006        self.max_det = max_det;
1007        self
1008    }
1009
1010    /// Sets the model input dimensions `(width, height)` consumed by the
1011    /// EDGEAI-1303 normalization path. Use this when building via
1012    /// [`with_config`](Self::with_config) / [`add_output`](Self::add_output)
1013    /// (no schema) and the model emits pixel-space boxes that need to be
1014    /// divided by `(W, H)` before NMS.
1015    ///
1016    /// When the builder is also configured with [`with_schema`](Self::with_schema)
1017    /// (or `with_config_json_str` / `with_config_yaml_str`) and the schema's
1018    /// `input` block carries usable dims, this explicit override **takes
1019    /// precedence** so callers can correct schemas with missing or wrong
1020    /// input specs without rewriting the schema.
1021    ///
1022    /// # Examples
1023    /// ```rust
1024    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
1025    /// # fn main() -> DecoderResult<()> {
1026    /// # let config_yaml = edgefirst_bench::testdata::read_to_string("modelpack_split.yaml").to_string();
1027    /// let decoder = DecoderBuilder::new()
1028    ///     .with_config_yaml_str(config_yaml)
1029    ///     .with_input_dims(640, 640)
1030    ///     .build()?;
1031    /// assert_eq!(decoder.input_dims(), Some((640, 640)));
1032    /// # Ok(())
1033    /// # }
1034    /// ```
1035    pub fn with_input_dims(mut self, width: usize, height: usize) -> Self {
1036        self.input_dims = Some((width, height));
1037        self
1038    }
1039
1040    /// Builds the decoder with the given settings. If the config is a JSON or
1041    /// YAML string, this will deserialize the JSON or YAML and then parse the
1042    /// configuration information.
1043    ///
1044    /// # Examples
1045    /// ```rust
1046    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
1047    /// # fn main() -> DecoderResult<()> {
1048    /// # let config_json = edgefirst_bench::testdata::read_to_string("modelpack_split.json").to_string();
1049    /// let decoder = DecoderBuilder::new()
1050    ///     .with_config_json_str(config_json)
1051    ///     .with_score_threshold(0.654)
1052    ///     .build()?;
1053    /// # Ok(())
1054    /// # }
1055    /// ```
1056    pub fn build(self) -> Result<Decoder, DecoderError> {
1057        let decode_dtype = self.decode_dtype;
1058        let explicit_input_dims = self.input_dims;
1059        let (config, decode_program, per_scale_plan, schema_input_dims) = match self.config_src {
1060            Some(ConfigSource::Json(s)) => {
1061                Self::build_from_schema(SchemaV2::parse_json(&s)?, decode_dtype)?
1062            }
1063            Some(ConfigSource::Yaml(s)) => {
1064                Self::build_from_schema(SchemaV2::parse_yaml(&s)?, decode_dtype)?
1065            }
1066            Some(ConfigSource::Config(c)) => (c, None, None, None),
1067            Some(ConfigSource::Schema(schema)) => Self::build_from_schema(schema, decode_dtype)?,
1068            None => return Err(DecoderError::NoConfig),
1069        };
1070        // Explicit `with_input_dims(W, H)` overrides any schema-derived
1071        // value so callers can fix schemas with missing or wrong input
1072        // specs without rewriting the schema (EDGEAI-1303).
1073        let input_dims = explicit_input_dims.or(schema_input_dims);
1074
1075        // Enforce the physical-order contract: when dshape is present
1076        // it must describe the same axes as shape in the same order,
1077        // listed from outermost to innermost. Ambiguous-layout roles
1078        // (Protos, Boxes, Scores, MaskCoefficients, Classes, Detection)
1079        // may still omit dshape when shape is already in the decoder's
1080        // canonical order.
1081        for output in &config.outputs {
1082            Decoder::validate_output_layout(output.into())?;
1083        }
1084
1085        // Extract normalized flag from config outputs.
1086        //
1087        // The per-scale subsystem (DFL/LTRB → dist2bbox → sigmoid) emits
1088        // boxes in pixel coordinates by design — `(grid + dist) * stride`
1089        // — independently of any `normalized: true` annotation in the
1090        // schema. Override to `Some(false)` so the per-scale bridge's
1091        // call to `yolo::maybe_normalize_boxes_in_place` fires and
1092        // divides by `input_dims`, yielding `[0, 1]` output. The
1093        // accessor `Decoder::normalized_boxes()` applies the
1094        // pixel→normalized upgrade for the per-scale path and for any
1095        // legacy `ModelType` whose every entry point normalizes
1096        // uniformly (currently `YoloSegDet`, `YoloSplitSegDet`, and
1097        // `YoloSegDet2Way`); other paths surface the raw flag.
1098        let normalized = if per_scale_plan.is_some() {
1099            Some(false)
1100        } else {
1101            Self::get_normalized(&config.outputs)
1102        };
1103
1104        // NMS precedence:
1105        //   Some(ClassAgnostic|ClassAware) → explicit user override
1106        //   Some(Auto) → resolve from config, fallback to ClassAgnostic
1107        //   None → NMS disabled (explicit)
1108        //
1109        // `Auto` is always resolved to a concrete mode here — it never
1110        // persists into the built `Decoder`, even if the config itself
1111        // contains `Auto`.
1112        let resolve_auto = |nms: Option<configs::Nms>| match nms {
1113            Some(configs::Nms::Auto) | None => Some(configs::Nms::ClassAgnostic),
1114            concrete => concrete,
1115        };
1116        let nms = match self.nms {
1117            Some(configs::Nms::Auto) => resolve_auto(config.nms),
1118            other => other,
1119        };
1120        // When the per-scale path is active, the per_scale subsystem owns
1121        // model decoding entirely — `decode` / `decode_proto` short-circuit
1122        // on `per_scale.is_some()` before reading `model_type`. Skip the
1123        // legacy ModelType validation, which otherwise rejects per-scale
1124        // schemas that carry `decoder_version: yolo26` (an
1125        // "end-to-end" hint) but use the per-scale split outputs rather
1126        // than the end-to-end split-output shape the legacy validator
1127        // expects. We keep a placeholder `ModelType` so the field remains
1128        // valid; it is dead state for per-scale Decoders.
1129        let model_type = if per_scale_plan.is_some() {
1130            // Drop the un-needed config; the per-scale subsystem owns it.
1131            drop(config);
1132            ModelType::PerScale
1133        } else {
1134            Self::get_model_type(config)?
1135        };
1136
1137        let per_scale = per_scale_plan
1138            .map(|plan| std::sync::Mutex::new(crate::per_scale::PerScaleDecoder::new(plan)));
1139
1140        debug_assert!(
1141            !matches!(nms, Some(configs::Nms::Auto)),
1142            "Nms::Auto must be resolved to a concrete mode before building Decoder"
1143        );
1144
1145        Ok(Decoder {
1146            model_type,
1147            iou_threshold: self.iou_threshold,
1148            score_threshold: self.score_threshold,
1149            nms,
1150            pre_nms_top_k: self.pre_nms_top_k,
1151            max_det: self.max_det,
1152            normalized,
1153            input_dims,
1154            decode_program,
1155            per_scale,
1156        })
1157    }
1158
1159    /// Validate a [`SchemaV2`] and lower it to the (legacy `ConfigOutputs`,
1160    /// optional `DecodeProgram`, optional `PerScalePlan`) tuple the rest
1161    /// of `build()` consumes.
1162    ///
1163    /// Centralises the v2 lowering so JSON, YAML, and direct
1164    /// `with_schema` callers all go through the same validation,
1165    /// merge-program, and per-scale plan construction. `SchemaV2::parse_json`
1166    /// / `parse_yaml` already auto-detect v1 vs v2 input and return a v2
1167    /// schema either way (v1 inputs are upgraded in memory via
1168    /// `from_v1`), so this helper is the sole place that turns a
1169    /// schema into builder-ready state.
1170    #[allow(clippy::type_complexity)]
1171    fn build_from_schema(
1172        schema: SchemaV2,
1173        decode_dtype: DecodeDtype,
1174    ) -> Result<
1175        (
1176            ConfigOutputs,
1177            Option<DecodeProgram>,
1178            Option<PerScalePlan>,
1179            Option<(usize, usize)>,
1180        ),
1181        DecoderError,
1182    > {
1183        schema.validate()?;
1184        let program = DecodeProgram::try_from_schema(&schema)?;
1185        let per_scale = PerScalePlan::try_from_schema(&schema, decode_dtype)?;
1186        // Extract model input (W, H) from `input.shape`/`dshape`. Used by
1187        // the legacy decode path to honour `normalized: false` (see
1188        // EDGEAI-1303). `None` is fine when the schema omits the input
1189        // spec — the decoder falls back to the protobox `>2.0` reject.
1190        let input_dims = schema.input.as_ref().and_then(input_dims_from_spec);
1191        let legacy = schema.to_legacy_config_outputs()?;
1192        Ok((legacy, program, per_scale, input_dims))
1193    }
1194
1195    /// Extracts the normalized flag from config outputs.
1196    /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
1197    /// - `Some(false)`: Boxes are in pixel coordinates
1198    /// - `None`: Unknown (not specified in config), caller must infer
1199    fn get_normalized(outputs: &[ConfigOutput]) -> Option<bool> {
1200        for output in outputs {
1201            match output {
1202                ConfigOutput::Detection(det) => return det.normalized,
1203                ConfigOutput::Boxes(boxes) => return boxes.normalized,
1204                _ => {}
1205            }
1206        }
1207        None // not specified
1208    }
1209
1210    fn get_model_type(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1211        // yolo or modelpack
1212        let mut yolo = false;
1213        let mut modelpack = false;
1214        for c in &configs.outputs {
1215            match c.decoder() {
1216                DecoderType::ModelPack => modelpack = true,
1217                DecoderType::Ultralytics => yolo = true,
1218            }
1219        }
1220        match (modelpack, yolo) {
1221            (true, true) => Err(DecoderError::InvalidConfig(
1222                "Both ModelPack and Yolo outputs found in config".to_string(),
1223            )),
1224            (true, false) => Self::get_model_type_modelpack(configs),
1225            (false, true) => Self::get_model_type_yolo(configs),
1226            (false, false) => Err(DecoderError::InvalidConfig(
1227                "No outputs found in config".to_string(),
1228            )),
1229        }
1230    }
1231
1232    fn get_model_type_yolo(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1233        let mut boxes = None;
1234        let mut protos = None;
1235        let mut split_boxes = None;
1236        let mut split_scores = None;
1237        let mut split_mask_coeff = None;
1238        let mut split_classes = None;
1239        for c in configs.outputs {
1240            match c {
1241                ConfigOutput::Detection(detection) => boxes = Some(detection),
1242                ConfigOutput::Segmentation(_) => {
1243                    return Err(DecoderError::InvalidConfig(
1244                        "Invalid Segmentation output with Yolo decoder".to_string(),
1245                    ));
1246                }
1247                ConfigOutput::Protos(protos_) => protos = Some(protos_),
1248                ConfigOutput::Mask(_) => {
1249                    return Err(DecoderError::InvalidConfig(
1250                        "Invalid Mask output with Yolo decoder".to_string(),
1251                    ));
1252                }
1253                ConfigOutput::Scores(scores) => split_scores = Some(scores),
1254                ConfigOutput::Boxes(boxes) => split_boxes = Some(boxes),
1255                ConfigOutput::MaskCoefficients(mask_coeff) => split_mask_coeff = Some(mask_coeff),
1256                ConfigOutput::Classes(classes) => split_classes = Some(classes),
1257            }
1258        }
1259
1260        // Use end-to-end model types when:
1261        // 1. decoder_version is explicitly set to Yolo26 (definitive), OR
1262        //    decoder_version is not set but the dshapes are (batch, num_boxes,
1263        //    num_features)
1264        let is_end_to_end_dshape = boxes.as_ref().is_some_and(|b| {
1265            let dims = b.dshape.iter().map(|(d, _)| *d).collect::<Vec<_>>();
1266            dims == vec![DimName::Batch, DimName::NumBoxes, DimName::NumFeatures]
1267        });
1268
1269        let is_end_to_end = configs
1270            .decoder_version
1271            .map(|v| v.is_end_to_end())
1272            .unwrap_or(is_end_to_end_dshape);
1273
1274        if is_end_to_end {
1275            if let Some(boxes) = boxes {
1276                if let Some(protos) = protos {
1277                    Self::verify_yolo_seg_det_26(&boxes, &protos)?;
1278                    return Ok(ModelType::YoloEndToEndSegDet { boxes, protos });
1279                } else {
1280                    Self::verify_yolo_det_26(&boxes)?;
1281                    return Ok(ModelType::YoloEndToEndDet { boxes });
1282                }
1283            } else if let (Some(split_boxes), Some(split_scores), Some(split_classes)) =
1284                (split_boxes, split_scores, split_classes)
1285            {
1286                if let (Some(split_mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
1287                    Self::verify_yolo_split_end_to_end_segdet(
1288                        &split_boxes,
1289                        &split_scores,
1290                        &split_classes,
1291                        &split_mask_coeff,
1292                        &protos,
1293                    )?;
1294                    return Ok(ModelType::YoloSplitEndToEndSegDet {
1295                        boxes: split_boxes,
1296                        scores: split_scores,
1297                        classes: split_classes,
1298                        mask_coeff: split_mask_coeff,
1299                        protos,
1300                    });
1301                }
1302                Self::verify_yolo_split_end_to_end_det(
1303                    &split_boxes,
1304                    &split_scores,
1305                    &split_classes,
1306                )?;
1307                return Ok(ModelType::YoloSplitEndToEndDet {
1308                    boxes: split_boxes,
1309                    scores: split_scores,
1310                    classes: split_classes,
1311                });
1312            } else {
1313                return Err(DecoderError::InvalidConfig(
1314                    "Invalid Yolo end-to-end model outputs".to_string(),
1315                ));
1316            }
1317        }
1318
1319        if let Some(boxes) = boxes {
1320            match (split_mask_coeff, protos) {
1321                (Some(mask_coeff), Some(protos)) => {
1322                    // 2-way split: combined detection + separate mask_coeff + protos
1323                    Self::verify_yolo_seg_det_2way(&boxes, &mask_coeff, &protos)?;
1324                    Ok(ModelType::YoloSegDet2Way {
1325                        boxes,
1326                        mask_coeff,
1327                        protos,
1328                    })
1329                }
1330                (_, Some(protos)) => {
1331                    // Unsplit: mask_coefs embedded in detection tensor
1332                    Self::verify_yolo_seg_det(&boxes, &protos)?;
1333                    Ok(ModelType::YoloSegDet { boxes, protos })
1334                }
1335                _ => {
1336                    Self::verify_yolo_det(&boxes)?;
1337                    Ok(ModelType::YoloDet { boxes })
1338                }
1339            }
1340        } else if let (Some(boxes), Some(scores)) = (split_boxes, split_scores) {
1341            if let (Some(mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
1342                Self::verify_yolo_split_segdet(&boxes, &scores, &mask_coeff, &protos)?;
1343                Ok(ModelType::YoloSplitSegDet {
1344                    boxes,
1345                    scores,
1346                    mask_coeff,
1347                    protos,
1348                })
1349            } else {
1350                Self::verify_yolo_split_det(&boxes, &scores)?;
1351                Ok(ModelType::YoloSplitDet { boxes, scores })
1352            }
1353        } else {
1354            Err(DecoderError::InvalidConfig(
1355                "Invalid Yolo model outputs".to_string(),
1356            ))
1357        }
1358    }
1359
1360    fn verify_yolo_det(detect: &configs::Detection) -> Result<(), DecoderError> {
1361        if detect.shape.len() != 3 {
1362            return Err(DecoderError::InvalidConfig(format!(
1363                "Invalid Yolo Detection shape {:?}",
1364                detect.shape
1365            )));
1366        }
1367
1368        Self::verify_dshapes(
1369            &detect.dshape,
1370            &detect.shape,
1371            "Detection",
1372            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1373        )?;
1374        if !detect.dshape.is_empty() {
1375            Self::get_class_count(&detect.dshape, None, None)?;
1376        } else {
1377            Self::get_class_count_no_dshape(detect.into(), None)?;
1378        }
1379
1380        Ok(())
1381    }
1382
1383    fn verify_yolo_det_26(detect: &configs::Detection) -> Result<(), DecoderError> {
1384        if detect.shape.len() != 3 {
1385            return Err(DecoderError::InvalidConfig(format!(
1386                "Invalid Yolo Detection shape {:?}",
1387                detect.shape
1388            )));
1389        }
1390
1391        Self::verify_dshapes(
1392            &detect.dshape,
1393            &detect.shape,
1394            "Detection",
1395            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1396        )?;
1397
1398        if !detect.shape.contains(&6) {
1399            return Err(DecoderError::InvalidConfig(
1400                "Yolo26 Detection must have 6 features".to_string(),
1401            ));
1402        }
1403
1404        Ok(())
1405    }
1406
1407    fn verify_yolo_seg_det(
1408        detection: &configs::Detection,
1409        protos: &configs::Protos,
1410    ) -> Result<(), DecoderError> {
1411        if detection.shape.len() != 3 {
1412            return Err(DecoderError::InvalidConfig(format!(
1413                "Invalid Yolo Detection shape {:?}",
1414                detection.shape
1415            )));
1416        }
1417        if protos.shape.len() != 4 {
1418            return Err(DecoderError::InvalidConfig(format!(
1419                "Invalid Yolo Protos shape {:?}",
1420                protos.shape
1421            )));
1422        }
1423
1424        Self::verify_dshapes(
1425            &detection.dshape,
1426            &detection.shape,
1427            "Detection",
1428            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1429        )?;
1430        Self::verify_dshapes(
1431            &protos.dshape,
1432            &protos.shape,
1433            "Protos",
1434            &[
1435                DimName::Batch,
1436                DimName::Height,
1437                DimName::Width,
1438                DimName::NumProtos,
1439            ],
1440        )?;
1441
1442        let protos_count = Self::get_protos_count(&protos.dshape)
1443            .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1444        log::debug!("Protos count: {}", protos_count);
1445        log::debug!("Detection dshape: {:?}", detection.dshape);
1446        let classes = if !detection.dshape.is_empty() {
1447            Self::get_class_count(&detection.dshape, Some(protos_count), None)?
1448        } else {
1449            Self::get_class_count_no_dshape(detection.into(), Some(protos_count))?
1450        };
1451
1452        if classes == 0 {
1453            return Err(DecoderError::InvalidConfig(
1454                "Yolo Segmentation Detection has zero classes".to_string(),
1455            ));
1456        }
1457
1458        Ok(())
1459    }
1460
1461    fn verify_yolo_seg_det_2way(
1462        detection: &configs::Detection,
1463        mask_coeff: &configs::MaskCoefficients,
1464        protos: &configs::Protos,
1465    ) -> Result<(), DecoderError> {
1466        if detection.shape.len() != 3 {
1467            return Err(DecoderError::InvalidConfig(format!(
1468                "Invalid Yolo 2-Way Detection shape {:?}",
1469                detection.shape
1470            )));
1471        }
1472        if mask_coeff.shape.len() != 3 {
1473            return Err(DecoderError::InvalidConfig(format!(
1474                "Invalid Yolo 2-Way Mask Coefficients shape {:?}",
1475                mask_coeff.shape
1476            )));
1477        }
1478        if protos.shape.len() != 4 {
1479            return Err(DecoderError::InvalidConfig(format!(
1480                "Invalid Yolo 2-Way Protos shape {:?}",
1481                protos.shape
1482            )));
1483        }
1484
1485        Self::verify_dshapes(
1486            &detection.dshape,
1487            &detection.shape,
1488            "Detection",
1489            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1490        )?;
1491        Self::verify_dshapes(
1492            &mask_coeff.dshape,
1493            &mask_coeff.shape,
1494            "Mask Coefficients",
1495            &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1496        )?;
1497        Self::verify_dshapes(
1498            &protos.dshape,
1499            &protos.shape,
1500            "Protos",
1501            &[
1502                DimName::Batch,
1503                DimName::Height,
1504                DimName::Width,
1505                DimName::NumProtos,
1506            ],
1507        )?;
1508
1509        // Validate num_boxes match between detection and mask_coeff
1510        let det_num = Self::get_box_count(&detection.dshape).unwrap_or(detection.shape[2]);
1511        let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1512        if det_num != mask_num {
1513            return Err(DecoderError::InvalidConfig(format!(
1514                "Yolo 2-Way Detection num_boxes {} incompatible with Mask Coefficients num_boxes {}",
1515                det_num, mask_num
1516            )));
1517        }
1518
1519        // Validate mask_coeff channels match protos channels
1520        let mask_channels = if !mask_coeff.dshape.is_empty() {
1521            Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1522                DecoderError::InvalidConfig(
1523                    "Could not find num_protos in mask_coeff config".to_string(),
1524                )
1525            })?
1526        } else {
1527            mask_coeff.shape[1]
1528        };
1529        let proto_channels = if !protos.dshape.is_empty() {
1530            Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1531                DecoderError::InvalidConfig(
1532                    "Could not find num_protos in protos config".to_string(),
1533                )
1534            })?
1535        } else {
1536            protos.shape[1].min(protos.shape[3])
1537        };
1538        if mask_channels != proto_channels {
1539            return Err(DecoderError::InvalidConfig(format!(
1540                "Yolo 2-Way Protos channels {} incompatible with Mask Coefficients channels {}",
1541                proto_channels, mask_channels
1542            )));
1543        }
1544
1545        // Validate detection has classes (nc+4 features, no mask_coefs embedded)
1546        if !detection.dshape.is_empty() {
1547            Self::get_class_count(&detection.dshape, None, None)?;
1548        } else {
1549            Self::get_class_count_no_dshape(detection.into(), None)?;
1550        }
1551
1552        Ok(())
1553    }
1554
1555    fn verify_yolo_seg_det_26(
1556        detection: &configs::Detection,
1557        protos: &configs::Protos,
1558    ) -> Result<(), DecoderError> {
1559        if detection.shape.len() != 3 {
1560            return Err(DecoderError::InvalidConfig(format!(
1561                "Invalid Yolo Detection shape {:?}",
1562                detection.shape
1563            )));
1564        }
1565        if protos.shape.len() != 4 {
1566            return Err(DecoderError::InvalidConfig(format!(
1567                "Invalid Yolo Protos shape {:?}",
1568                protos.shape
1569            )));
1570        }
1571
1572        Self::verify_dshapes(
1573            &detection.dshape,
1574            &detection.shape,
1575            "Detection",
1576            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1577        )?;
1578        Self::verify_dshapes(
1579            &protos.dshape,
1580            &protos.shape,
1581            "Protos",
1582            &[
1583                DimName::Batch,
1584                DimName::Height,
1585                DimName::Width,
1586                DimName::NumProtos,
1587            ],
1588        )?;
1589
1590        let protos_count = Self::get_protos_count(&protos.dshape)
1591            .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1592        log::debug!("Protos count: {}", protos_count);
1593        log::debug!("Detection dshape: {:?}", detection.dshape);
1594
1595        if !detection.shape.contains(&(6 + protos_count)) {
1596            return Err(DecoderError::InvalidConfig(format!(
1597                "Yolo26 Segmentation Detection must have num_features be 6 + num_protos = {}",
1598                6 + protos_count
1599            )));
1600        }
1601
1602        Ok(())
1603    }
1604
1605    fn verify_yolo_split_det(
1606        boxes: &configs::Boxes,
1607        scores: &configs::Scores,
1608    ) -> Result<(), DecoderError> {
1609        if boxes.shape.len() != 3 {
1610            return Err(DecoderError::InvalidConfig(format!(
1611                "Invalid Yolo Split Boxes shape {:?}",
1612                boxes.shape
1613            )));
1614        }
1615        if scores.shape.len() != 3 {
1616            return Err(DecoderError::InvalidConfig(format!(
1617                "Invalid Yolo Split Scores shape {:?}",
1618                scores.shape
1619            )));
1620        }
1621
1622        Self::verify_dshapes(
1623            &boxes.dshape,
1624            &boxes.shape,
1625            "Boxes",
1626            &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1627        )?;
1628        Self::verify_dshapes(
1629            &scores.dshape,
1630            &scores.shape,
1631            "Scores",
1632            &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1633        )?;
1634
1635        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1636        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1637
1638        if boxes_num != scores_num {
1639            return Err(DecoderError::InvalidConfig(format!(
1640                "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1641                boxes_num, scores_num
1642            )));
1643        }
1644
1645        Ok(())
1646    }
1647
1648    fn verify_yolo_split_segdet(
1649        boxes: &configs::Boxes,
1650        scores: &configs::Scores,
1651        mask_coeff: &configs::MaskCoefficients,
1652        protos: &configs::Protos,
1653    ) -> Result<(), DecoderError> {
1654        if boxes.shape.len() != 3 {
1655            return Err(DecoderError::InvalidConfig(format!(
1656                "Invalid Yolo Split Boxes shape {:?}",
1657                boxes.shape
1658            )));
1659        }
1660        if scores.shape.len() != 3 {
1661            return Err(DecoderError::InvalidConfig(format!(
1662                "Invalid Yolo Split Scores shape {:?}",
1663                scores.shape
1664            )));
1665        }
1666
1667        if mask_coeff.shape.len() != 3 {
1668            return Err(DecoderError::InvalidConfig(format!(
1669                "Invalid Yolo Split Mask Coefficients shape {:?}",
1670                mask_coeff.shape
1671            )));
1672        }
1673
1674        if protos.shape.len() != 4 {
1675            return Err(DecoderError::InvalidConfig(format!(
1676                "Invalid Yolo Protos shape {:?}",
1677                mask_coeff.shape
1678            )));
1679        }
1680
1681        Self::verify_dshapes(
1682            &boxes.dshape,
1683            &boxes.shape,
1684            "Boxes",
1685            &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1686        )?;
1687        Self::verify_dshapes(
1688            &scores.dshape,
1689            &scores.shape,
1690            "Scores",
1691            &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1692        )?;
1693        Self::verify_dshapes(
1694            &mask_coeff.dshape,
1695            &mask_coeff.shape,
1696            "Mask Coefficients",
1697            &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1698        )?;
1699        Self::verify_dshapes(
1700            &protos.dshape,
1701            &protos.shape,
1702            "Protos",
1703            &[
1704                DimName::Batch,
1705                DimName::Height,
1706                DimName::Width,
1707                DimName::NumProtos,
1708            ],
1709        )?;
1710
1711        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1712        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1713        let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1714
1715        let mask_channels = if !mask_coeff.dshape.is_empty() {
1716            Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1717                DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1718            })?
1719        } else {
1720            mask_coeff.shape[1]
1721        };
1722        let proto_channels = if !protos.dshape.is_empty() {
1723            Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1724                DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1725            })?
1726        } else {
1727            protos.shape[1].min(protos.shape[3])
1728        };
1729
1730        if boxes_num != scores_num {
1731            return Err(DecoderError::InvalidConfig(format!(
1732                "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1733                boxes_num, scores_num
1734            )));
1735        }
1736
1737        if boxes_num != mask_num {
1738            return Err(DecoderError::InvalidConfig(format!(
1739                "Yolo Split Detection Boxes num {} incompatible with Mask Coefficients num {}",
1740                boxes_num, mask_num
1741            )));
1742        }
1743
1744        if proto_channels != mask_channels {
1745            return Err(DecoderError::InvalidConfig(format!(
1746                "Yolo Protos channels {} incompatible with Mask Coefficients channels {}",
1747                proto_channels, mask_channels
1748            )));
1749        }
1750
1751        Ok(())
1752    }
1753
1754    fn verify_yolo_split_end_to_end_det(
1755        boxes: &configs::Boxes,
1756        scores: &configs::Scores,
1757        classes: &configs::Classes,
1758    ) -> Result<(), DecoderError> {
1759        if boxes.shape.len() != 3 || !boxes.shape.contains(&4) {
1760            return Err(DecoderError::InvalidConfig(format!(
1761                "Split end-to-end boxes must be [batch, N, 4], got {:?}",
1762                boxes.shape
1763            )));
1764        }
1765        if scores.shape.len() != 3 || !scores.shape.contains(&1) {
1766            return Err(DecoderError::InvalidConfig(format!(
1767                "Split end-to-end scores must be [batch, N, 1], got {:?}",
1768                scores.shape
1769            )));
1770        }
1771        if classes.shape.len() != 3 || !classes.shape.contains(&1) {
1772            return Err(DecoderError::InvalidConfig(format!(
1773                "Split end-to-end classes must be [batch, N, 1], got {:?}",
1774                classes.shape
1775            )));
1776        }
1777        Ok(())
1778    }
1779
1780    fn verify_yolo_split_end_to_end_segdet(
1781        boxes: &configs::Boxes,
1782        scores: &configs::Scores,
1783        classes: &configs::Classes,
1784        mask_coeff: &configs::MaskCoefficients,
1785        protos: &configs::Protos,
1786    ) -> Result<(), DecoderError> {
1787        Self::verify_yolo_split_end_to_end_det(boxes, scores, classes)?;
1788        if mask_coeff.shape.len() != 3 {
1789            return Err(DecoderError::InvalidConfig(format!(
1790                "Invalid split end-to-end mask coefficients shape {:?}",
1791                mask_coeff.shape
1792            )));
1793        }
1794        if protos.shape.len() != 4 {
1795            return Err(DecoderError::InvalidConfig(format!(
1796                "Invalid protos shape {:?}",
1797                protos.shape
1798            )));
1799        }
1800        Ok(())
1801    }
1802
1803    fn get_model_type_modelpack(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1804        let mut split_decoders = Vec::new();
1805        let mut segment_ = None;
1806        let mut scores_ = None;
1807        let mut boxes_ = None;
1808        for c in configs.outputs {
1809            match c {
1810                ConfigOutput::Detection(detection) => split_decoders.push(detection),
1811                ConfigOutput::Segmentation(segmentation) => segment_ = Some(segmentation),
1812                ConfigOutput::Mask(_) => {}
1813                ConfigOutput::Protos(_) => {
1814                    return Err(DecoderError::InvalidConfig(
1815                        "ModelPack should not have protos".to_string(),
1816                    ));
1817                }
1818                ConfigOutput::Scores(scores) => scores_ = Some(scores),
1819                ConfigOutput::Boxes(boxes) => boxes_ = Some(boxes),
1820                ConfigOutput::MaskCoefficients(_) => {
1821                    return Err(DecoderError::InvalidConfig(
1822                        "ModelPack should not have mask coefficients".to_string(),
1823                    ));
1824                }
1825                ConfigOutput::Classes(_) => {
1826                    return Err(DecoderError::InvalidConfig(
1827                        "ModelPack should not have classes output".to_string(),
1828                    ));
1829                }
1830            }
1831        }
1832
1833        if let Some(segmentation) = segment_ {
1834            if !split_decoders.is_empty() {
1835                let classes = Self::verify_modelpack_split_det(&split_decoders)?;
1836                Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1837                Ok(ModelType::ModelPackSegDetSplit {
1838                    detection: split_decoders,
1839                    segmentation,
1840                })
1841            } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1842                let classes = Self::verify_modelpack_det(&boxes, &scores)?;
1843                Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1844                Ok(ModelType::ModelPackSegDet {
1845                    boxes,
1846                    scores,
1847                    segmentation,
1848                })
1849            } else {
1850                Self::verify_modelpack_seg(&segmentation, None)?;
1851                Ok(ModelType::ModelPackSeg { segmentation })
1852            }
1853        } else if !split_decoders.is_empty() {
1854            Self::verify_modelpack_split_det(&split_decoders)?;
1855            Ok(ModelType::ModelPackDetSplit {
1856                detection: split_decoders,
1857            })
1858        } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1859            Self::verify_modelpack_det(&boxes, &scores)?;
1860            Ok(ModelType::ModelPackDet { boxes, scores })
1861        } else {
1862            Err(DecoderError::InvalidConfig(
1863                "Invalid ModelPack model outputs".to_string(),
1864            ))
1865        }
1866    }
1867
1868    fn verify_modelpack_det(
1869        boxes: &configs::Boxes,
1870        scores: &configs::Scores,
1871    ) -> Result<usize, DecoderError> {
1872        if boxes.shape.len() != 4 {
1873            return Err(DecoderError::InvalidConfig(format!(
1874                "Invalid ModelPack Boxes shape {:?}",
1875                boxes.shape
1876            )));
1877        }
1878        if scores.shape.len() != 3 {
1879            return Err(DecoderError::InvalidConfig(format!(
1880                "Invalid ModelPack Scores shape {:?}",
1881                scores.shape
1882            )));
1883        }
1884
1885        Self::verify_dshapes(
1886            &boxes.dshape,
1887            &boxes.shape,
1888            "Boxes",
1889            &[
1890                DimName::Batch,
1891                DimName::NumBoxes,
1892                DimName::Padding,
1893                DimName::BoxCoords,
1894            ],
1895        )?;
1896        Self::verify_dshapes(
1897            &scores.dshape,
1898            &scores.shape,
1899            "Scores",
1900            &[DimName::Batch, DimName::NumBoxes, DimName::NumClasses],
1901        )?;
1902
1903        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[1]);
1904        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[1]);
1905
1906        if boxes_num != scores_num {
1907            return Err(DecoderError::InvalidConfig(format!(
1908                "ModelPack Detection Boxes num {} incompatible with Scores num {}",
1909                boxes_num, scores_num
1910            )));
1911        }
1912
1913        let num_classes = if !scores.dshape.is_empty() {
1914            Self::get_class_count(&scores.dshape, None, None)?
1915        } else {
1916            Self::get_class_count_no_dshape(scores.into(), None)?
1917        };
1918
1919        Ok(num_classes)
1920    }
1921
1922    fn verify_modelpack_split_det(boxes: &[configs::Detection]) -> Result<usize, DecoderError> {
1923        let mut num_classes = None;
1924        for b in boxes {
1925            let Some(num_anchors) = b.anchors.as_ref().map(|a| a.len()) else {
1926                return Err(DecoderError::InvalidConfig(
1927                    "ModelPack Split Detection missing anchors".to_string(),
1928                ));
1929            };
1930
1931            if num_anchors == 0 {
1932                return Err(DecoderError::InvalidConfig(
1933                    "ModelPack Split Detection has zero anchors".to_string(),
1934                ));
1935            }
1936
1937            if b.shape.len() != 4 {
1938                return Err(DecoderError::InvalidConfig(format!(
1939                    "Invalid ModelPack Split Detection shape {:?}",
1940                    b.shape
1941                )));
1942            }
1943
1944            Self::verify_dshapes(
1945                &b.dshape,
1946                &b.shape,
1947                "Split Detection",
1948                &[
1949                    DimName::Batch,
1950                    DimName::Height,
1951                    DimName::Width,
1952                    DimName::NumAnchorsXFeatures,
1953                ],
1954            )?;
1955            let classes = if !b.dshape.is_empty() {
1956                Self::get_class_count(&b.dshape, None, Some(num_anchors))?
1957            } else {
1958                Self::get_class_count_no_dshape(b.into(), None)?
1959            };
1960
1961            match num_classes {
1962                Some(n) => {
1963                    if n != classes {
1964                        return Err(DecoderError::InvalidConfig(format!(
1965                            "ModelPack Split Detection inconsistent number of classes: previous {}, current {}",
1966                            n, classes
1967                        )));
1968                    }
1969                }
1970                None => {
1971                    num_classes = Some(classes);
1972                }
1973            }
1974        }
1975
1976        Ok(num_classes.unwrap_or(0))
1977    }
1978
1979    fn verify_modelpack_seg(
1980        segmentation: &configs::Segmentation,
1981        classes: Option<usize>,
1982    ) -> Result<(), DecoderError> {
1983        if segmentation.shape.len() != 4 {
1984            return Err(DecoderError::InvalidConfig(format!(
1985                "Invalid ModelPack Segmentation shape {:?}",
1986                segmentation.shape
1987            )));
1988        }
1989        Self::verify_dshapes(
1990            &segmentation.dshape,
1991            &segmentation.shape,
1992            "Segmentation",
1993            &[
1994                DimName::Batch,
1995                DimName::Height,
1996                DimName::Width,
1997                DimName::NumClasses,
1998            ],
1999        )?;
2000
2001        if let Some(classes) = classes {
2002            let seg_classes = if !segmentation.dshape.is_empty() {
2003                Self::get_class_count(&segmentation.dshape, None, None)?
2004            } else {
2005                Self::get_class_count_no_dshape(segmentation.into(), None)?
2006            };
2007
2008            if seg_classes != classes + 1 {
2009                return Err(DecoderError::InvalidConfig(format!(
2010                    "ModelPack Segmentation channels {} incompatible with number of classes {}",
2011                    seg_classes, classes
2012                )));
2013            }
2014        }
2015        Ok(())
2016    }
2017
2018    // verifies that dshapes match the given shape
2019    fn verify_dshapes(
2020        dshape: &[(DimName, usize)],
2021        shape: &[usize],
2022        name: &str,
2023        dims: &[DimName],
2024    ) -> Result<(), DecoderError> {
2025        for s in shape {
2026            if *s == 0 {
2027                return Err(DecoderError::InvalidConfig(format!(
2028                    "{} shape has zero dimension",
2029                    name
2030                )));
2031            }
2032        }
2033
2034        if shape.len() != dims.len() {
2035            return Err(DecoderError::InvalidConfig(format!(
2036                "{} shape length {} does not match expected dims length {}",
2037                name,
2038                shape.len(),
2039                dims.len()
2040            )));
2041        }
2042
2043        if dshape.is_empty() {
2044            return Ok(());
2045        }
2046        // check the dshape lengths match the shape lengths
2047        if dshape.len() != shape.len() {
2048            return Err(DecoderError::InvalidConfig(format!(
2049                "{} dshape length does not match shape length",
2050                name
2051            )));
2052        }
2053
2054        // check the dshape values match the shape values
2055        for ((dim_name, dim_size), shape_size) in dshape.iter().zip(shape) {
2056            if dim_size != shape_size {
2057                return Err(DecoderError::InvalidConfig(format!(
2058                    "{} dshape dimension {} size {} does not match shape size {}",
2059                    name, dim_name, dim_size, shape_size
2060                )));
2061            }
2062            if *dim_name == DimName::Padding && *dim_size != 1 {
2063                return Err(DecoderError::InvalidConfig(
2064                    "Padding dimension size must be 1".to_string(),
2065                ));
2066            }
2067
2068            if *dim_name == DimName::BoxCoords && *dim_size != 4 {
2069                return Err(DecoderError::InvalidConfig(
2070                    "BoxCoords dimension size must be 4".to_string(),
2071                ));
2072            }
2073        }
2074
2075        let dims_present = HashSet::<DimName>::from_iter(dshape.iter().map(|(name, _)| *name));
2076        for dim in dims {
2077            if !dims_present.contains(dim) {
2078                return Err(DecoderError::InvalidConfig(format!(
2079                    "{} dshape missing required dimension {:?}",
2080                    name, dim
2081                )));
2082            }
2083        }
2084
2085        Ok(())
2086    }
2087
2088    fn get_box_count(dshape: &[(DimName, usize)]) -> Option<usize> {
2089        for (dim_name, dim_size) in dshape {
2090            if *dim_name == DimName::NumBoxes {
2091                return Some(*dim_size);
2092            }
2093        }
2094        None
2095    }
2096
2097    fn get_class_count_no_dshape(
2098        config: ConfigOutputRef,
2099        protos: Option<usize>,
2100    ) -> Result<usize, DecoderError> {
2101        match config {
2102            ConfigOutputRef::Detection(detection) => match detection.decoder {
2103                DecoderType::Ultralytics => {
2104                    if detection.shape[1] <= 4 + protos.unwrap_or(0) {
2105                        return Err(DecoderError::InvalidConfig(format!(
2106                            "Invalid shape: Yolo num_features {} must be greater than {}",
2107                            detection.shape[1],
2108                            4 + protos.unwrap_or(0),
2109                        )));
2110                    }
2111                    Ok(detection.shape[1] - 4 - protos.unwrap_or(0))
2112                }
2113                DecoderType::ModelPack => {
2114                    let Some(num_anchors) = detection.anchors.as_ref().map(|a| a.len()) else {
2115                        return Err(DecoderError::Internal(
2116                            "ModelPack Detection missing anchors".to_string(),
2117                        ));
2118                    };
2119                    let anchors_x_features = detection.shape[3];
2120                    if anchors_x_features <= num_anchors * 5 {
2121                        return Err(DecoderError::InvalidConfig(format!(
2122                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2123                            anchors_x_features,
2124                            num_anchors * 5,
2125                        )));
2126                    }
2127
2128                    if !anchors_x_features.is_multiple_of(num_anchors) {
2129                        return Err(DecoderError::InvalidConfig(format!(
2130                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2131                            anchors_x_features, num_anchors
2132                        )));
2133                    }
2134                    Ok(anchors_x_features / num_anchors - 5)
2135                }
2136            },
2137
2138            ConfigOutputRef::Scores(scores) => match scores.decoder {
2139                DecoderType::Ultralytics => Ok(scores.shape[1]),
2140                DecoderType::ModelPack => Ok(scores.shape[2]),
2141            },
2142            ConfigOutputRef::Segmentation(seg) => Ok(seg.shape[3]),
2143            _ => Err(DecoderError::Internal(
2144                "Attempted to get class count from unsupported config output".to_owned(),
2145            )),
2146        }
2147    }
2148
2149    // get the class count from dshape or calculate from num_features
2150    fn get_class_count(
2151        dshape: &[(DimName, usize)],
2152        protos: Option<usize>,
2153        anchors: Option<usize>,
2154    ) -> Result<usize, DecoderError> {
2155        if dshape.is_empty() {
2156            return Ok(0);
2157        }
2158        // if it has num_classes in dshape, return it
2159        for (dim_name, dim_size) in dshape {
2160            if *dim_name == DimName::NumClasses {
2161                return Ok(*dim_size);
2162            }
2163        }
2164
2165        // number of classes can be calculated from num_features - 4 for yolo.  If the
2166        // model has protos, we also subtract the number of protos.
2167        for (dim_name, dim_size) in dshape {
2168            if *dim_name == DimName::NumFeatures {
2169                let protos = protos.unwrap_or(0);
2170                if protos + 4 >= *dim_size {
2171                    return Err(DecoderError::InvalidConfig(format!(
2172                        "Invalid shape: Yolo num_features {} must be greater than {}",
2173                        *dim_size,
2174                        protos + 4,
2175                    )));
2176                }
2177                return Ok(*dim_size - 4 - protos);
2178            }
2179        }
2180
2181        // number of classes can be calculated from number of anchors for modelpack
2182        // split detection
2183        if let Some(num_anchors) = anchors {
2184            for (dim_name, dim_size) in dshape {
2185                if *dim_name == DimName::NumAnchorsXFeatures {
2186                    let anchors_x_features = *dim_size;
2187                    if anchors_x_features <= num_anchors * 5 {
2188                        return Err(DecoderError::InvalidConfig(format!(
2189                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2190                            anchors_x_features,
2191                            num_anchors * 5,
2192                        )));
2193                    }
2194
2195                    if !anchors_x_features.is_multiple_of(num_anchors) {
2196                        return Err(DecoderError::InvalidConfig(format!(
2197                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2198                            anchors_x_features, num_anchors
2199                        )));
2200                    }
2201                    return Ok((anchors_x_features / num_anchors) - 5);
2202                }
2203            }
2204        }
2205        Err(DecoderError::InvalidConfig(
2206            "Cannot determine number of classes from dshape".to_owned(),
2207        ))
2208    }
2209
2210    fn get_protos_count(dshape: &[(DimName, usize)]) -> Option<usize> {
2211        for (dim_name, dim_size) in dshape {
2212            if *dim_name == DimName::NumProtos {
2213                return Some(*dim_size);
2214            }
2215        }
2216        None
2217    }
2218}