Skip to main content

edgefirst_decoder/decoder/
builder.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashSet;
5
6use super::config::ConfigOutputRef;
7use super::configs::{self, DecoderType, DecoderVersion, DimName, ModelType};
8use super::merge::DecodeProgram;
9use super::{ConfigOutput, ConfigOutputs, Decoder};
10use crate::per_scale::{DecodeDtype, PerScalePlan};
11use crate::schema::SchemaV2;
12use crate::DecoderError;
13
14/// Extract `(width, height)` from a schema [`crate::schema::InputSpec`].
15///
16/// Prefers named dims (`DimName::Width` / `DimName::Height`) when the
17/// `dshape` is populated, falling back to the NHWC convention
18/// (`shape[1] = H, shape[2] = W`) for 4-element shapes whenever either
19/// named dim is missing. Returns `None` for any other shape arity — the
20/// decoder treats that as "input dims unknown" and skips the EDGEAI-1303
21/// normalization path.
22///
23/// The fallback fires when **either** `Height` or `Width` is missing from
24/// the dshape (not only when both are absent), so a partially-named
25/// dshape (e.g. only `Width`) still resolves both dims via positional
26/// inference instead of silently disabling normalization.
27fn input_dims_from_spec(input: &crate::schema::InputSpec) -> Option<(usize, usize)> {
28    use crate::configs::DimName;
29    let mut h = None;
30    let mut w = None;
31    // `SchemaV2::validate()` doesn't currently enforce
32    // `dshape.len() <= shape.len()`, so a malformed schema can trip an
33    // out-of-bounds index here. Use `shape.get(i)` to silently skip
34    // dshape entries that overflow the shape — the caller treats
35    // missing dims as "unknown" and disables EDGEAI-1303 normalization.
36    for (i, (name, _)) in input.dshape.iter().enumerate() {
37        match name {
38            DimName::Height => h = input.shape.get(i).copied(),
39            DimName::Width => w = input.shape.get(i).copied(),
40            _ => {}
41        }
42    }
43    if (h.is_none() || w.is_none()) && input.shape.len() == 4 {
44        // NHWC default: [N, H, W, C]. Mirrors the per-scale `extract_hw`
45        // fallback (`crates/decoder/src/per_scale/plan.rs::extract_hw`).
46        // Only fill the missing axis so a partial named dshape still
47        // resolves both dims.
48        h = h.or(Some(input.shape[1]));
49        w = w.or(Some(input.shape[2]));
50    }
51    match (w, h) {
52        (Some(w), Some(h)) => Some((w, h)),
53        _ => None,
54    }
55}
56
#[cfg(test)]
mod input_dims_from_spec_tests {
    // Coverage for `input_dims_from_spec`: named dims, the NHWC fallback
    // (full and partial), non-4-D shapes, and malformed dshapes that are
    // longer than the shape.
    use super::input_dims_from_spec;
    use crate::configs::DimName;
    use crate::schema::InputSpec;

    #[test]
    fn named_dshape_resolves_dims() {
        let spec = InputSpec {
            shape: vec![1, 480, 640, 3],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::Height, 480),
                (DimName::Width, 640),
                (DimName::NumFeatures, 3),
            ],
            cameraadaptor: None,
        };
        assert_eq!(input_dims_from_spec(&spec), Some((640, 480)));
    }

    #[test]
    fn named_dshape_takes_precedence_over_nhwc_positions() {
        // Channels-first layout: positional NHWC would read H=3, W=480.
        // Fully named dims must win, and the fallback must not fire.
        let spec = InputSpec {
            shape: vec![1, 3, 480, 640],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumFeatures, 3),
                (DimName::Height, 480),
                (DimName::Width, 640),
            ],
            cameraadaptor: None,
        };
        assert_eq!(input_dims_from_spec(&spec), Some((640, 480)));
    }

    #[test]
    fn empty_dshape_falls_back_to_nhwc_for_4d() {
        let spec = InputSpec {
            shape: vec![1, 480, 640, 3],
            dshape: vec![],
            cameraadaptor: None,
        };
        assert_eq!(input_dims_from_spec(&spec), Some((640, 480)));
    }

    #[test]
    fn partial_named_dshape_fills_missing_axis_from_nhwc() {
        // Only `Height` is named. `Width` must come from the NHWC fallback
        // (`shape[2]`) instead of disabling normalization entirely.
        let spec = InputSpec {
            shape: vec![1, 480, 640, 3],
            dshape: vec![(DimName::Batch, 1), (DimName::Height, 480)],
            cameraadaptor: None,
        };
        assert_eq!(input_dims_from_spec(&spec), Some((640, 480)));
    }

    #[test]
    fn non_4d_shape_without_named_dims_returns_none() {
        // The NHWC fallback only applies to 4-element shapes.
        let spec = InputSpec {
            shape: vec![480, 640, 3],
            dshape: vec![],
            cameraadaptor: None,
        };
        assert_eq!(input_dims_from_spec(&spec), None);
    }

    #[test]
    fn malformed_dshape_longer_than_shape_does_not_panic() {
        // Regression for Copilot review on PR #63: indexing
        // `input.shape[i]` while iterating dshape can OOB-panic when
        // `dshape.len() > shape.len()`. The fix uses `shape.get(i)`
        // and silently treats overflow as "dim missing".
        let spec = InputSpec {
            shape: vec![640, 480], // 2-D shape
            dshape: vec![
                (DimName::Width, 640),
                (DimName::Height, 480),
                (DimName::NumFeatures, 3), // overflow — index 2 ≥ shape.len()
            ],
            cameraadaptor: None,
        };
        // First two dshape entries resolve via `shape.get()`; the third
        // is a no-op. Width/Height both resolved, so we expect Some.
        assert_eq!(input_dims_from_spec(&spec), Some((640, 480)));
    }

    #[test]
    fn malformed_dshape_only_overflow_returns_none() {
        // The `Height` (index 1) and `Width` (index 2) entries both sit
        // past the shape boundary. Only the leading `NumFeatures` entry is
        // in range, and it is ignored. Width and height therefore stay None,
        // and the 4-D NHWC fallback doesn't fire (shape.len() == 1), so we
        // get None instead of a panic.
        let spec = InputSpec {
            shape: vec![1],
            dshape: vec![
                (DimName::NumFeatures, 3),
                (DimName::Height, 480),
                (DimName::Width, 640),
            ],
            cameraadaptor: None,
        };
        assert_eq!(input_dims_from_spec(&spec), None);
    }
}
125
126#[derive(Debug, Clone, PartialEq)]
127pub struct DecoderBuilder {
128    config_src: Option<ConfigSource>,
129    iou_threshold: f32,
130    score_threshold: f32,
131    /// NMS mode.
132    ///
133    /// - `Some(Nms::Auto)` — resolve from config or fallback to
134    ///   `ClassAgnostic` (builder default)
135    /// - `Some(Nms::ClassAgnostic)` — explicit class-agnostic override
136    /// - `Some(Nms::ClassAware)` — explicit class-aware override
137    /// - `None` — bypass NMS entirely
138    nms: Option<configs::Nms>,
139    /// Output dtype for the per-scale fast path. Has no effect on
140    /// schemas without per-scale children (which use the legacy decode
141    /// path).
142    decode_dtype: DecodeDtype,
143    pre_nms_top_k: usize,
144    max_det: usize,
145    /// Explicit override for the model input dimensions `(width, height)`,
146    /// consumed by EDGEAI-1303 normalization. When set, takes precedence
147    /// over schema-derived dims; when `None`, the value is read from the
148    /// schema's `input.shape` / `input.dshape` at build time.
149    input_dims: Option<(usize, usize)>,
150}
151
152#[derive(Debug, Clone, PartialEq)]
153enum ConfigSource {
154    Yaml(String),
155    Json(String),
156    Config(ConfigOutputs),
157    /// Schema v2 metadata. During build the schema is either converted
158    /// to a legacy [`ConfigOutputs`] (flat case) or used to construct a
159    /// [`DecodeProgram`] that performs per-child dequant + merge at
160    /// decode time.
161    Schema(SchemaV2),
162}
163
164impl Default for DecoderBuilder {
165    /// Creates a default DecoderBuilder with no configuration and 0.5 score
166    /// threshold and 0.5 IoU threshold.
167    ///
168    /// A valid configuration must be provided before building the Decoder.
169    ///
170    /// # Examples
171    /// ```rust
172    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
173    /// # fn main() -> DecoderResult<()> {
174    /// #  let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
175    /// let decoder = DecoderBuilder::default()
176    ///     .with_config_yaml_str(config_yaml)
177    ///     .build()?;
178    /// assert_eq!(decoder.score_threshold, 0.5);
179    /// assert_eq!(decoder.iou_threshold, 0.5);
180    ///
181    /// # Ok(())
182    /// # }
183    /// ```
184    fn default() -> Self {
185        Self {
186            config_src: None,
187            iou_threshold: 0.5,
188            score_threshold: 0.5,
189            nms: Some(configs::Nms::Auto),
190            decode_dtype: DecodeDtype::F32,
191            pre_nms_top_k: 300,
192            max_det: 300,
193            input_dims: None,
194        }
195    }
196}
197
198impl DecoderBuilder {
199    /// Creates a default DecoderBuilder with no configuration and 0.5 score
200    /// threshold and 0.5 IoU threshold.
201    ///
202    /// A valid configuration must be provided before building the Decoder.
203    ///
204    /// # Examples
205    /// ```rust
206    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
207    /// # fn main() -> DecoderResult<()> {
208    /// #  let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
209    /// let decoder = DecoderBuilder::new()
210    ///     .with_config_yaml_str(config_yaml)
211    ///     .build()?;
212    /// assert_eq!(decoder.score_threshold, 0.5);
213    /// assert_eq!(decoder.iou_threshold, 0.5);
214    ///
215    /// # Ok(())
216    /// # }
217    /// ```
218    pub fn new() -> Self {
219        Self::default()
220    }
221
222    /// Loads a model configuration in YAML format. Does not check if the string
223    /// is a correct configuration file. Use `DecoderBuilder.build()` to
224    /// deserialize the YAML and parse the model configuration.
225    ///
226    /// # Examples
227    /// ```rust
228    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
229    /// # fn main() -> DecoderResult<()> {
230    /// let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
231    /// let decoder = DecoderBuilder::new()
232    ///     .with_config_yaml_str(config_yaml)
233    ///     .build()?;
234    ///
235    /// # Ok(())
236    /// # }
237    /// ```
238    pub fn with_config_yaml_str(mut self, yaml_str: String) -> Self {
239        self.config_src.replace(ConfigSource::Yaml(yaml_str));
240        self
241    }
242
243    /// Loads a model configuration in JSON format. Does not check if the string
244    /// is a correct configuration file. Use `DecoderBuilder.build()` to
245    /// deserialize the JSON and parse the model configuration.
246    ///
247    /// # Examples
248    /// ```rust
249    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
250    /// # fn main() -> DecoderResult<()> {
251    /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
252    /// let decoder = DecoderBuilder::new()
253    ///     .with_config_json_str(config_json)
254    ///     .build()?;
255    ///
256    /// # Ok(())
257    /// # }
258    /// ```
259    pub fn with_config_json_str(mut self, json_str: String) -> Self {
260        self.config_src.replace(ConfigSource::Json(json_str));
261        self
262    }
263
264    /// Loads a model configuration. Does not check if the configuration is
265    /// correct. Intended to be used when the user needs control over the
266    /// deserialize of the configuration information. Use
267    /// `DecoderBuilder.build()` to parse the model configuration.
268    ///
269    /// # Examples
270    /// ```rust
271    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
272    /// # fn main() -> DecoderResult<()> {
273    /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json"));
274    /// let config = serde_json::from_str(config_json)?;
275    /// let decoder = DecoderBuilder::new().with_config(config).build()?;
276    ///
277    /// # Ok(())
278    /// # }
279    /// ```
280    pub fn with_config(mut self, config: ConfigOutputs) -> Self {
281        self.config_src.replace(ConfigSource::Config(config));
282        self
283    }
284
285    /// Configure the decoder from a schema v2 metadata document.
286    ///
287    /// Accepts a [`SchemaV2`] as produced by [`SchemaV2::parse_json`],
288    /// [`SchemaV2::parse_yaml`], [`SchemaV2::parse_file`], or
289    /// constructed programmatically. The builder validates the schema,
290    /// compiles a [`DecodeProgram`] for any split logical outputs
291    /// (per-scale or channel sub-splits), and downconverts the
292    /// logical-level semantics to the legacy [`ConfigOutputs`]
293    /// representation consumed by the existing decoder dispatch.
294    ///
295    /// # Examples
296    ///
297    /// ```rust,no_run
298    /// use edgefirst_decoder::{DecoderBuilder, DecoderResult};
299    /// use edgefirst_decoder::schema::SchemaV2;
300    ///
301    /// # fn main() -> DecoderResult<()> {
302    /// let schema = SchemaV2::parse_file("model/edgefirst.json")?;
303    /// let decoder = DecoderBuilder::new()
304    ///     .with_schema(schema)
305    ///     .with_score_threshold(0.25)
306    ///     .build()?;
307    /// # Ok(())
308    /// # }
309    /// ```
310    pub fn with_schema(mut self, schema: SchemaV2) -> Self {
311        self.config_src.replace(ConfigSource::Schema(schema));
312        self
313    }
314
315    /// Choose the output dtype for the per-scale decoder pipeline.
316    ///
317    /// Defaults to [`DecodeDtype::F32`]. Has no effect on schemas
318    /// without per-scale children (which use the legacy decode path).
319    /// `F16` saves ~2× memory bandwidth at the cost of 10-bit mantissa
320    /// precision; empirically safe for YOLO-family models.
321    ///
322    /// # Examples
323    ///
324    /// ```rust,no_run
325    /// use edgefirst_decoder::{DecodeDtype, DecoderBuilder, DecoderResult};
326    /// use edgefirst_decoder::schema::SchemaV2;
327    ///
328    /// # fn main() -> DecoderResult<()> {
329    /// let schema = SchemaV2::parse_file("model/edgefirst.json")?;
330    /// let decoder = DecoderBuilder::new()
331    ///     .with_schema(schema)
332    ///     .with_decode_dtype(DecodeDtype::F32)
333    ///     .build()?;
334    /// # Ok(())
335    /// # }
336    /// ```
337    pub fn with_decode_dtype(mut self, dtype: DecodeDtype) -> Self {
338        self.decode_dtype = dtype;
339        self
340    }
341
342    /// Loads a YOLO detection model configuration.  Use
343    /// `DecoderBuilder.build()` to parse the model configuration.
344    ///
345    /// # Examples
346    /// ```rust
347    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
348    /// # fn main() -> DecoderResult<()> {
349    /// let decoder = DecoderBuilder::new()
350    ///     .with_config_yolo_det(
351    ///         configs::Detection {
352    ///             anchors: None,
353    ///             decoder: configs::DecoderType::Ultralytics,
354    ///             quantization: Some(configs::QuantTuple(0.012345, 26)),
355    ///             shape: vec![1, 84, 8400],
356    ///             dshape: Vec::new(),
357    ///             normalized: Some(true),
358    ///         },
359    ///         None,
360    ///     )
361    ///     .build()?;
362    ///
363    /// # Ok(())
364    /// # }
365    /// ```
366    pub fn with_config_yolo_det(
367        mut self,
368        boxes: configs::Detection,
369        version: Option<DecoderVersion>,
370    ) -> Self {
371        let config = ConfigOutputs {
372            outputs: vec![ConfigOutput::Detection(boxes)],
373            decoder_version: version,
374            ..Default::default()
375        };
376        self.config_src.replace(ConfigSource::Config(config));
377        self
378    }
379
380    /// Loads a YOLO split detection model configuration.  Use
381    /// `DecoderBuilder.build()` to parse the model configuration.
382    ///
383    /// # Examples
384    /// ```rust
385    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
386    /// # fn main() -> DecoderResult<()> {
387    /// let boxes_config = configs::Boxes {
388    ///     decoder: configs::DecoderType::Ultralytics,
389    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
390    ///     shape: vec![1, 4, 8400],
391    ///     dshape: Vec::new(),
392    ///     normalized: Some(true),
393    /// };
394    /// let scores_config = configs::Scores {
395    ///     decoder: configs::DecoderType::Ultralytics,
396    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
397    ///     shape: vec![1, 80, 8400],
398    ///     dshape: Vec::new(),
399    /// };
400    /// let decoder = DecoderBuilder::new()
401    ///     .with_config_yolo_split_det(boxes_config, scores_config)
402    ///     .build()?;
403    /// # Ok(())
404    /// # }
405    /// ```
406    pub fn with_config_yolo_split_det(
407        mut self,
408        boxes: configs::Boxes,
409        scores: configs::Scores,
410    ) -> Self {
411        let config = ConfigOutputs {
412            outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
413            ..Default::default()
414        };
415        self.config_src.replace(ConfigSource::Config(config));
416        self
417    }
418
419    /// Loads a YOLO segmentation model configuration.  Use
420    /// `DecoderBuilder.build()` to parse the model configuration.
421    ///
422    /// # Examples
423    /// ```rust
424    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
425    /// # fn main() -> DecoderResult<()> {
426    /// let seg_config = configs::Detection {
427    ///     decoder: configs::DecoderType::Ultralytics,
428    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
429    ///     shape: vec![1, 116, 8400],
430    ///     anchors: None,
431    ///     dshape: Vec::new(),
432    ///     normalized: Some(true),
433    /// };
434    /// let protos_config = configs::Protos {
435    ///     decoder: configs::DecoderType::Ultralytics,
436    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
437    ///     shape: vec![1, 160, 160, 32],
438    ///     dshape: Vec::new(),
439    /// };
440    /// let decoder = DecoderBuilder::new()
441    ///     .with_config_yolo_segdet(
442    ///         seg_config,
443    ///         protos_config,
444    ///         Some(configs::DecoderVersion::Yolov8),
445    ///     )
446    ///     .build()?;
447    /// # Ok(())
448    /// # }
449    /// ```
450    pub fn with_config_yolo_segdet(
451        mut self,
452        boxes: configs::Detection,
453        protos: configs::Protos,
454        version: Option<DecoderVersion>,
455    ) -> Self {
456        let config = ConfigOutputs {
457            outputs: vec![ConfigOutput::Detection(boxes), ConfigOutput::Protos(protos)],
458            decoder_version: version,
459            ..Default::default()
460        };
461        self.config_src.replace(ConfigSource::Config(config));
462        self
463    }
464
465    /// Loads a YOLO split segmentation model configuration.  Use
466    /// `DecoderBuilder.build()` to parse the model configuration.
467    ///
468    /// # Examples
469    /// ```rust
470    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
471    /// # fn main() -> DecoderResult<()> {
472    /// let boxes_config = configs::Boxes {
473    ///     decoder: configs::DecoderType::Ultralytics,
474    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
475    ///     shape: vec![1, 4, 8400],
476    ///     dshape: Vec::new(),
477    ///     normalized: Some(true),
478    /// };
479    /// let scores_config = configs::Scores {
480    ///     decoder: configs::DecoderType::Ultralytics,
481    ///     quantization: Some(configs::QuantTuple(0.012345, 14)),
482    ///     shape: vec![1, 80, 8400],
483    ///     dshape: Vec::new(),
484    /// };
485    /// let mask_config = configs::MaskCoefficients {
486    ///     decoder: configs::DecoderType::Ultralytics,
487    ///     quantization: Some(configs::QuantTuple(0.0064123, 125)),
488    ///     shape: vec![1, 32, 8400],
489    ///     dshape: Vec::new(),
490    /// };
491    /// let protos_config = configs::Protos {
492    ///     decoder: configs::DecoderType::Ultralytics,
493    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
494    ///     shape: vec![1, 160, 160, 32],
495    ///     dshape: Vec::new(),
496    /// };
497    /// let decoder = DecoderBuilder::new()
498    ///     .with_config_yolo_split_segdet(boxes_config, scores_config, mask_config, protos_config)
499    ///     .build()?;
500    /// # Ok(())
501    /// # }
502    /// ```
503    pub fn with_config_yolo_split_segdet(
504        mut self,
505        boxes: configs::Boxes,
506        scores: configs::Scores,
507        mask_coefficients: configs::MaskCoefficients,
508        protos: configs::Protos,
509    ) -> Self {
510        let config = ConfigOutputs {
511            outputs: vec![
512                ConfigOutput::Boxes(boxes),
513                ConfigOutput::Scores(scores),
514                ConfigOutput::MaskCoefficients(mask_coefficients),
515                ConfigOutput::Protos(protos),
516            ],
517            ..Default::default()
518        };
519        self.config_src.replace(ConfigSource::Config(config));
520        self
521    }
522
523    /// Loads a ModelPack detection model configuration.  Use
524    /// `DecoderBuilder.build()` to parse the model configuration.
525    ///
526    /// # Examples
527    /// ```rust
528    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
529    /// # fn main() -> DecoderResult<()> {
530    /// let boxes_config = configs::Boxes {
531    ///     decoder: configs::DecoderType::ModelPack,
532    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
533    ///     shape: vec![1, 8400, 1, 4],
534    ///     dshape: Vec::new(),
535    ///     normalized: Some(true),
536    /// };
537    /// let scores_config = configs::Scores {
538    ///     decoder: configs::DecoderType::ModelPack,
539    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
540    ///     shape: vec![1, 8400, 3],
541    ///     dshape: Vec::new(),
542    /// };
543    /// let decoder = DecoderBuilder::new()
544    ///     .with_config_modelpack_det(boxes_config, scores_config)
545    ///     .build()?;
546    /// # Ok(())
547    /// # }
548    /// ```
549    pub fn with_config_modelpack_det(
550        mut self,
551        boxes: configs::Boxes,
552        scores: configs::Scores,
553    ) -> Self {
554        let config = ConfigOutputs {
555            outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
556            ..Default::default()
557        };
558        self.config_src.replace(ConfigSource::Config(config));
559        self
560    }
561
562    /// Loads a ModelPack split detection model configuration. Use
563    /// `DecoderBuilder.build()` to parse the model configuration.
564    ///
565    /// # Examples
566    /// ```rust
567    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
568    /// # fn main() -> DecoderResult<()> {
569    /// let config0 = configs::Detection {
570    ///     anchors: Some(vec![
571    ///         [0.13750000298023224, 0.2074074000120163],
572    ///         [0.2541666626930237, 0.21481481194496155],
573    ///         [0.23125000298023224, 0.35185185074806213],
574    ///     ]),
575    ///     decoder: configs::DecoderType::ModelPack,
576    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
577    ///     shape: vec![1, 17, 30, 18],
578    ///     dshape: Vec::new(),
579    ///     normalized: Some(true),
580    /// };
581    /// let config1 = configs::Detection {
582    ///     anchors: Some(vec![
583    ///         [0.36666667461395264, 0.31481480598449707],
584    ///         [0.38749998807907104, 0.4740740656852722],
585    ///         [0.5333333611488342, 0.644444465637207],
586    ///     ]),
587    ///     decoder: configs::DecoderType::ModelPack,
588    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
589    ///     shape: vec![1, 9, 15, 18],
590    ///     dshape: Vec::new(),
591    ///     normalized: Some(true),
592    /// };
593    ///
594    /// let decoder = DecoderBuilder::new()
595    ///     .with_config_modelpack_det_split(vec![config0, config1])
596    ///     .build()?;
597    /// # Ok(())
598    /// # }
599    /// ```
600    pub fn with_config_modelpack_det_split(mut self, boxes: Vec<configs::Detection>) -> Self {
601        let outputs = boxes.into_iter().map(ConfigOutput::Detection).collect();
602        let config = ConfigOutputs {
603            outputs,
604            ..Default::default()
605        };
606        self.config_src.replace(ConfigSource::Config(config));
607        self
608    }
609
610    /// Loads a ModelPack segmentation detection model configuration. Use
611    /// `DecoderBuilder.build()` to parse the model configuration.
612    ///
613    /// # Examples
614    /// ```rust
615    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
616    /// # fn main() -> DecoderResult<()> {
617    /// let boxes_config = configs::Boxes {
618    ///     decoder: configs::DecoderType::ModelPack,
619    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
620    ///     shape: vec![1, 8400, 1, 4],
621    ///     dshape: Vec::new(),
622    ///     normalized: Some(true),
623    /// };
624    /// let scores_config = configs::Scores {
625    ///     decoder: configs::DecoderType::ModelPack,
626    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
627    ///     shape: vec![1, 8400, 2],
628    ///     dshape: Vec::new(),
629    /// };
630    /// let seg_config = configs::Segmentation {
631    ///     decoder: configs::DecoderType::ModelPack,
632    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
633    ///     shape: vec![1, 640, 640, 3],
634    ///     dshape: Vec::new(),
635    /// };
636    /// let decoder = DecoderBuilder::new()
637    ///     .with_config_modelpack_segdet(boxes_config, scores_config, seg_config)
638    ///     .build()?;
639    /// # Ok(())
640    /// # }
641    /// ```
642    pub fn with_config_modelpack_segdet(
643        mut self,
644        boxes: configs::Boxes,
645        scores: configs::Scores,
646        segmentation: configs::Segmentation,
647    ) -> Self {
648        let config = ConfigOutputs {
649            outputs: vec![
650                ConfigOutput::Boxes(boxes),
651                ConfigOutput::Scores(scores),
652                ConfigOutput::Segmentation(segmentation),
653            ],
654            ..Default::default()
655        };
656        self.config_src.replace(ConfigSource::Config(config));
657        self
658    }
659
660    /// Loads a ModelPack segmentation split detection model configuration. Use
661    /// `DecoderBuilder.build()` to parse the model configuration.
662    ///
663    /// # Examples
664    /// ```rust
665    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
666    /// # fn main() -> DecoderResult<()> {
667    /// let config0 = configs::Detection {
668    ///     anchors: Some(vec![
669    ///         [0.36666667461395264, 0.31481480598449707],
670    ///         [0.38749998807907104, 0.4740740656852722],
671    ///         [0.5333333611488342, 0.644444465637207],
672    ///     ]),
673    ///     decoder: configs::DecoderType::ModelPack,
674    ///     quantization: Some(configs::QuantTuple(0.08547406643629074, 174)),
675    ///     shape: vec![1, 9, 15, 18],
676    ///     dshape: Vec::new(),
677    ///     normalized: Some(true),
678    /// };
679    /// let config1 = configs::Detection {
680    ///     anchors: Some(vec![
681    ///         [0.13750000298023224, 0.2074074000120163],
682    ///         [0.2541666626930237, 0.21481481194496155],
683    ///         [0.23125000298023224, 0.35185185074806213],
684    ///     ]),
685    ///     decoder: configs::DecoderType::ModelPack,
686    ///     quantization: Some(configs::QuantTuple(0.09929127991199493, 183)),
687    ///     shape: vec![1, 17, 30, 18],
688    ///     dshape: Vec::new(),
689    ///     normalized: Some(true),
690    /// };
691    /// let seg_config = configs::Segmentation {
692    ///     decoder: configs::DecoderType::ModelPack,
693    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
694    ///     shape: vec![1, 640, 640, 2],
695    ///     dshape: Vec::new(),
696    /// };
697    /// let decoder = DecoderBuilder::new()
698    ///     .with_config_modelpack_segdet_split(vec![config0, config1], seg_config)
699    ///     .build()?;
700    /// # Ok(())
701    /// # }
702    /// ```
703    pub fn with_config_modelpack_segdet_split(
704        mut self,
705        boxes: Vec<configs::Detection>,
706        segmentation: configs::Segmentation,
707    ) -> Self {
708        let mut outputs = boxes
709            .into_iter()
710            .map(ConfigOutput::Detection)
711            .collect::<Vec<_>>();
712        outputs.push(ConfigOutput::Segmentation(segmentation));
713        let config = ConfigOutputs {
714            outputs,
715            ..Default::default()
716        };
717        self.config_src.replace(ConfigSource::Config(config));
718        self
719    }
720
721    /// Loads a ModelPack segmentation model configuration. Use
722    /// `DecoderBuilder.build()` to parse the model configuration.
723    ///
724    /// # Examples
725    /// ```rust
726    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
727    /// # fn main() -> DecoderResult<()> {
728    /// let seg_config = configs::Segmentation {
729    ///     decoder: configs::DecoderType::ModelPack,
730    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
731    ///     shape: vec![1, 640, 640, 3],
732    ///     dshape: Vec::new(),
733    /// };
734    /// let decoder = DecoderBuilder::new()
735    ///     .with_config_modelpack_seg(seg_config)
736    ///     .build()?;
737    /// # Ok(())
738    /// # }
739    /// ```
740    pub fn with_config_modelpack_seg(mut self, segmentation: configs::Segmentation) -> Self {
741        let config = ConfigOutputs {
742            outputs: vec![ConfigOutput::Segmentation(segmentation)],
743            ..Default::default()
744        };
745        self.config_src.replace(ConfigSource::Config(config));
746        self
747    }
748
749    /// Add an output to the decoder configuration.
750    ///
751    /// Incrementally builds the model configuration by adding outputs one at
752    /// a time. The decoder resolves the model type from the combination of
753    /// outputs during `build()`.
754    ///
755    /// If `dshape` is non-empty on the output, `shape` is automatically
756    /// derived from it (the size component of each named dimension). This
757    /// prevents conflicts between `shape` and `dshape`.
758    ///
759    /// This uses the programmatic config path. Calling this after
760    /// `with_config_json_str()` or `with_config_yaml_str()` replaces the
761    /// string-based config source.
762    ///
763    /// # Examples
764    /// ```rust
765    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
766    /// # fn main() -> DecoderResult<()> {
767    /// let decoder = DecoderBuilder::new()
768    ///     .add_output(ConfigOutput::Scores(configs::Scores {
769    ///         decoder: configs::DecoderType::Ultralytics,
770    ///         dshape: vec![
771    ///             (configs::DimName::Batch, 1),
772    ///             (configs::DimName::NumClasses, 80),
773    ///             (configs::DimName::NumBoxes, 8400),
774    ///         ],
775    ///         ..Default::default()
776    ///     }))
777    ///     .add_output(ConfigOutput::Boxes(configs::Boxes {
778    ///         decoder: configs::DecoderType::Ultralytics,
779    ///         dshape: vec![
780    ///             (configs::DimName::Batch, 1),
781    ///             (configs::DimName::BoxCoords, 4),
782    ///             (configs::DimName::NumBoxes, 8400),
783    ///         ],
784    ///         ..Default::default()
785    ///     }))
786    ///     .build()?;
787    /// # Ok(())
788    /// # }
789    /// ```
790    pub fn add_output(mut self, output: ConfigOutput) -> Self {
791        if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
792            self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
793        }
794        if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
795            config.outputs.push(Self::normalize_output(output));
796        }
797        self
798    }
799
800    /// Sets the decoder version for Ultralytics models.
801    ///
802    /// This is used with `add_output()` to specify the YOLO architecture
803    /// version when it cannot be inferred from the output shapes alone.
804    ///
805    /// - `Yolov5`, `Yolov8`, `Yolo11`: Traditional models requiring external
806    ///   NMS
807    /// - `Yolo26`: End-to-end models with NMS embedded in the model graph
808    ///
809    /// # Examples
810    /// ```rust
811    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
812    /// # fn main() -> DecoderResult<()> {
813    /// let decoder = DecoderBuilder::new()
814    ///     .add_output(ConfigOutput::Detection(configs::Detection {
815    ///         decoder: configs::DecoderType::Ultralytics,
816    ///         dshape: vec![
817    ///             (configs::DimName::Batch, 1),
818    ///             (configs::DimName::NumBoxes, 100),
819    ///             (configs::DimName::NumFeatures, 6),
820    ///         ],
821    ///         ..Default::default()
822    ///     }))
823    ///     .with_decoder_version(configs::DecoderVersion::Yolo26)
824    ///     .build()?;
825    /// # Ok(())
826    /// # }
827    /// ```
828    pub fn with_decoder_version(mut self, version: configs::DecoderVersion) -> Self {
829        if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
830            self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
831        }
832        if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
833            config.decoder_version = Some(version);
834        }
835        self
836    }
837
838    /// Normalize an output: if dshape is non-empty, derive shape from it.
839    fn normalize_output(mut output: ConfigOutput) -> ConfigOutput {
840        fn normalize_shape(shape: &mut Vec<usize>, dshape: &[(configs::DimName, usize)]) {
841            if !dshape.is_empty() {
842                *shape = dshape.iter().map(|(_, size)| *size).collect();
843            }
844        }
845        match &mut output {
846            ConfigOutput::Detection(c) => normalize_shape(&mut c.shape, &c.dshape),
847            ConfigOutput::Boxes(c) => normalize_shape(&mut c.shape, &c.dshape),
848            ConfigOutput::Scores(c) => normalize_shape(&mut c.shape, &c.dshape),
849            ConfigOutput::Protos(c) => normalize_shape(&mut c.shape, &c.dshape),
850            ConfigOutput::Segmentation(c) => normalize_shape(&mut c.shape, &c.dshape),
851            ConfigOutput::MaskCoefficients(c) => normalize_shape(&mut c.shape, &c.dshape),
852            ConfigOutput::Mask(c) => normalize_shape(&mut c.shape, &c.dshape),
853            ConfigOutput::Classes(c) => normalize_shape(&mut c.shape, &c.dshape),
854        }
855        output
856    }
857
858    /// Sets the scores threshold of the decoder
859    ///
860    /// # Examples
861    /// ```rust
862    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
863    /// # fn main() -> DecoderResult<()> {
864    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
865    /// let decoder = DecoderBuilder::new()
866    ///     .with_config_json_str(config_json)
867    ///     .with_score_threshold(0.654)
868    ///     .build()?;
869    /// assert_eq!(decoder.score_threshold, 0.654);
870    /// # Ok(())
871    /// # }
872    /// ```
873    pub fn with_score_threshold(mut self, score_threshold: f32) -> Self {
874        self.score_threshold = score_threshold;
875        self
876    }
877
878    /// Sets the IOU threshold of the decoder. Has no effect when NMS is set to
879    /// `None`
880    ///
881    /// # Examples
882    /// ```rust
883    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
884    /// # fn main() -> DecoderResult<()> {
885    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
886    /// let decoder = DecoderBuilder::new()
887    ///     .with_config_json_str(config_json)
888    ///     .with_iou_threshold(0.654)
889    ///     .build()?;
890    /// assert_eq!(decoder.iou_threshold, 0.654);
891    /// # Ok(())
892    /// # }
893    /// ```
894    pub fn with_iou_threshold(mut self, iou_threshold: f32) -> Self {
895        self.iou_threshold = iou_threshold;
896        self
897    }
898
899    /// Sets the NMS mode for the decoder.
900    ///
901    /// - `Some(Nms::Auto)` — resolve from model config (e.g. `edgefirst.json`)
902    ///   or fall back to `ClassAgnostic` (this is the builder default)
903    /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS: suppress overlapping
904    ///   boxes regardless of class label
905    /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
906    ///   share the same class label AND overlap above the IoU threshold
907    /// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
908    ///
909    /// # Examples
910    /// ```rust
911    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::Nms};
912    /// # fn main() -> DecoderResult<()> {
913    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
914    /// let decoder = DecoderBuilder::new()
915    ///     .with_config_json_str(config_json)
916    ///     .with_nms(Some(Nms::ClassAware))
917    ///     .build()?;
918    /// assert_eq!(decoder.nms, Some(Nms::ClassAware));
919    /// # Ok(())
920    /// # }
921    /// ```
922    pub fn with_nms(mut self, nms: Option<configs::Nms>) -> Self {
923        self.nms = nms;
924        self
925    }
926
927    /// Sets the maximum number of candidate boxes fed into NMS after score
928    /// filtering.  Uses partial sort (O(N)) to select the top-K candidates,
929    /// dramatically reducing the O(N²) NMS cost when many low-confidence
930    /// proposals pass the threshold (common with mAP eval at 0.001).
931    ///
932    /// Default: 300.
933    ///
934    /// # ⚠️ Validation vs Deployment
935    ///
936    /// The default is appropriate for **deployment** where
937    /// `score_threshold ≥ 0.25` means few anchors survive filtering and
938    /// top-K is effectively a no-op.
939    ///
940    /// For **COCO mAP evaluation** (`score_threshold ≈ 0.001`), set this to
941    /// the total anchor count (8 400 for standard 640 × 640 YOLO models) or
942    /// to `0` (no limit) so that all score-passing candidates reach NMS.
943    /// Failing to do so causes **~9 pp box mAP loss** — the decoder math is
944    /// correct but the evaluation protocol requires full recall across the
945    /// confidence range.
946    ///
947    /// Post-processing latency scales with candidate count. At deployment
948    /// thresholds the cost difference is negligible; at validation thresholds
949    /// it is measurable but necessary for correct results.
950    ///
951    /// # Examples
952    ///
953    /// Deployment (default top-K, high score threshold):
954    /// ```rust
955    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
956    /// # fn main() -> DecoderResult<()> {
957    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
958    /// let decoder = DecoderBuilder::new()
959    ///     .with_config_json_str(config_json)
960    ///     .with_score_threshold(0.25)
961    ///     // pre_nms_top_k defaults to 300 — appropriate here
962    ///     .build()?;
963    /// # Ok(())
964    /// # }
965    /// ```
966    ///
967    /// COCO mAP evaluation (pass all anchors to NMS):
968    /// ```rust
969    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
970    /// # fn main() -> DecoderResult<()> {
971    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
972    /// let decoder = DecoderBuilder::new()
973    ///     .with_config_json_str(config_json)
974    ///     .with_score_threshold(0.001)
975    ///     .with_pre_nms_top_k(8400)  // all YOLO anchors
976    ///     .with_max_det(300)
977    ///     .build()?;
978    /// assert_eq!(decoder.pre_nms_top_k, 8400);
979    /// # Ok(())
980    /// # }
981    /// ```
982    pub fn with_pre_nms_top_k(mut self, pre_nms_top_k: usize) -> Self {
983        self.pre_nms_top_k = pre_nms_top_k;
984        self
985    }
986
987    /// Sets the maximum number of detections returned after NMS.
988    /// Matches the Ultralytics `max_det` parameter.
989    ///
990    /// Default: 300.
991    ///
992    /// # Examples
993    /// ```rust
994    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
995    /// # fn main() -> DecoderResult<()> {
996    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
997    /// let decoder = DecoderBuilder::new()
998    ///     .with_config_json_str(config_json)
999    ///     .with_max_det(100)
1000    ///     .build()?;
1001    /// assert_eq!(decoder.max_det, 100);
1002    /// # Ok(())
1003    /// # }
1004    /// ```
1005    pub fn with_max_det(mut self, max_det: usize) -> Self {
1006        self.max_det = max_det;
1007        self
1008    }
1009
1010    /// Sets the model input dimensions `(width, height)` consumed by the
1011    /// EDGEAI-1303 normalization path. Use this when building via
1012    /// [`with_config`](Self::with_config) / [`add_output`](Self::add_output)
1013    /// (no schema) and the model emits pixel-space boxes that need to be
1014    /// divided by `(W, H)` before NMS.
1015    ///
1016    /// When the builder is also configured with [`with_schema`](Self::with_schema)
1017    /// (or `with_config_json_str` / `with_config_yaml_str`) and the schema's
1018    /// `input` block carries usable dims, this explicit override **takes
1019    /// precedence** so callers can correct schemas with missing or wrong
1020    /// input specs without rewriting the schema.
1021    ///
1022    /// # Examples
1023    /// ```rust
1024    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
1025    /// # fn main() -> DecoderResult<()> {
1026    /// # let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
1027    /// let decoder = DecoderBuilder::new()
1028    ///     .with_config_yaml_str(config_yaml)
1029    ///     .with_input_dims(640, 640)
1030    ///     .build()?;
1031    /// assert_eq!(decoder.input_dims(), Some((640, 640)));
1032    /// # Ok(())
1033    /// # }
1034    /// ```
1035    pub fn with_input_dims(mut self, width: usize, height: usize) -> Self {
1036        self.input_dims = Some((width, height));
1037        self
1038    }
1039
1040    /// Builds the decoder with the given settings. If the config is a JSON or
1041    /// YAML string, this will deserialize the JSON or YAML and then parse the
1042    /// configuration information.
1043    ///
1044    /// # Examples
1045    /// ```rust
1046    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
1047    /// # fn main() -> DecoderResult<()> {
1048    /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
1049    /// let decoder = DecoderBuilder::new()
1050    ///     .with_config_json_str(config_json)
1051    ///     .with_score_threshold(0.654)
1052    ///     .build()?;
1053    /// # Ok(())
1054    /// # }
1055    /// ```
1056    pub fn build(self) -> Result<Decoder, DecoderError> {
1057        let decode_dtype = self.decode_dtype;
1058        let explicit_input_dims = self.input_dims;
1059        let (config, decode_program, per_scale_plan, schema_input_dims) = match self.config_src {
1060            Some(ConfigSource::Json(s)) => {
1061                Self::build_from_schema(SchemaV2::parse_json(&s)?, decode_dtype)?
1062            }
1063            Some(ConfigSource::Yaml(s)) => {
1064                Self::build_from_schema(SchemaV2::parse_yaml(&s)?, decode_dtype)?
1065            }
1066            Some(ConfigSource::Config(c)) => (c, None, None, None),
1067            Some(ConfigSource::Schema(schema)) => Self::build_from_schema(schema, decode_dtype)?,
1068            None => return Err(DecoderError::NoConfig),
1069        };
1070        // Explicit `with_input_dims(W, H)` overrides any schema-derived
1071        // value so callers can fix schemas with missing or wrong input
1072        // specs without rewriting the schema (EDGEAI-1303).
1073        let input_dims = explicit_input_dims.or(schema_input_dims);
1074
1075        // Enforce the physical-order contract: when dshape is present
1076        // it must describe the same axes as shape in the same order,
1077        // listed from outermost to innermost. Ambiguous-layout roles
1078        // (Protos, Boxes, Scores, MaskCoefficients, Classes, Detection)
1079        // may still omit dshape when shape is already in the decoder's
1080        // canonical order.
1081        for output in &config.outputs {
1082            Decoder::validate_output_layout(output.into())?;
1083        }
1084
1085        // Extract normalized flag from config outputs.
1086        //
1087        // The per-scale subsystem (DFL/LTRB → dist2bbox → sigmoid) emits
1088        // boxes in pixel coordinates by design — `(grid + dist) * stride`
1089        // — independently of any `normalized: true` annotation in the
1090        // schema. The schema's `normalized` flag describes the model's
1091        // training-time convention, not the runtime output coord space
1092        // for this code path. Override to `Some(false)` when the
1093        // per-scale path is active so `Decoder::normalized_boxes()`
1094        // matches what `decode_proto`/`decode` actually produce; the
1095        // legacy / non-per-scale paths still honor the schema flag.
1096        let normalized = if per_scale_plan.is_some() {
1097            Some(false)
1098        } else {
1099            Self::get_normalized(&config.outputs)
1100        };
1101
1102        // NMS precedence:
1103        //   Some(ClassAgnostic|ClassAware) → explicit user override
1104        //   Some(Auto) → resolve from config, fallback to ClassAgnostic
1105        //   None → NMS disabled (explicit)
1106        //
1107        // `Auto` is always resolved to a concrete mode here — it never
1108        // persists into the built `Decoder`, even if the config itself
1109        // contains `Auto`.
1110        let resolve_auto = |nms: Option<configs::Nms>| match nms {
1111            Some(configs::Nms::Auto) | None => Some(configs::Nms::ClassAgnostic),
1112            concrete => concrete,
1113        };
1114        let nms = match self.nms {
1115            Some(configs::Nms::Auto) => resolve_auto(config.nms),
1116            other => other,
1117        };
1118        // When the per-scale path is active, the per_scale subsystem owns
1119        // model decoding entirely — `decode` / `decode_proto` short-circuit
1120        // on `per_scale.is_some()` before reading `model_type`. Skip the
1121        // legacy ModelType validation, which otherwise rejects per-scale
1122        // schemas that carry `decoder_version: yolo26` (an
1123        // "end-to-end" hint) but use the per-scale split outputs rather
1124        // than the end-to-end split-output shape the legacy validator
1125        // expects. We keep a placeholder `ModelType` so the field remains
1126        // valid; it is dead state for per-scale Decoders.
1127        let model_type = if per_scale_plan.is_some() {
1128            // Drop the un-needed config; the per-scale subsystem owns it.
1129            drop(config);
1130            ModelType::PerScale
1131        } else {
1132            Self::get_model_type(config)?
1133        };
1134
1135        let per_scale = per_scale_plan
1136            .map(|plan| std::sync::Mutex::new(crate::per_scale::PerScaleDecoder::new(plan)));
1137
1138        debug_assert!(
1139            !matches!(nms, Some(configs::Nms::Auto)),
1140            "Nms::Auto must be resolved to a concrete mode before building Decoder"
1141        );
1142
1143        Ok(Decoder {
1144            model_type,
1145            iou_threshold: self.iou_threshold,
1146            score_threshold: self.score_threshold,
1147            nms,
1148            pre_nms_top_k: self.pre_nms_top_k,
1149            max_det: self.max_det,
1150            normalized,
1151            input_dims,
1152            decode_program,
1153            per_scale,
1154        })
1155    }
1156
1157    /// Validate a [`SchemaV2`] and lower it to the (legacy `ConfigOutputs`,
1158    /// optional `DecodeProgram`, optional `PerScalePlan`) tuple the rest
1159    /// of `build()` consumes.
1160    ///
1161    /// Centralises the v2 lowering so JSON, YAML, and direct
1162    /// `with_schema` callers all go through the same validation,
1163    /// merge-program, and per-scale plan construction. `SchemaV2::parse_json`
1164    /// / `parse_yaml` already auto-detect v1 vs v2 input and return a v2
1165    /// schema either way (v1 inputs are upgraded in memory via
1166    /// `from_v1`), so this helper is the sole place that turns a
1167    /// schema into builder-ready state.
1168    #[allow(clippy::type_complexity)]
1169    fn build_from_schema(
1170        schema: SchemaV2,
1171        decode_dtype: DecodeDtype,
1172    ) -> Result<
1173        (
1174            ConfigOutputs,
1175            Option<DecodeProgram>,
1176            Option<PerScalePlan>,
1177            Option<(usize, usize)>,
1178        ),
1179        DecoderError,
1180    > {
1181        schema.validate()?;
1182        let program = DecodeProgram::try_from_schema(&schema)?;
1183        let per_scale = PerScalePlan::try_from_schema(&schema, decode_dtype)?;
1184        // Extract model input (W, H) from `input.shape`/`dshape`. Used by
1185        // the legacy decode path to honour `normalized: false` (see
1186        // EDGEAI-1303). `None` is fine when the schema omits the input
1187        // spec — the decoder falls back to the protobox `>2.0` reject.
1188        let input_dims = schema.input.as_ref().and_then(input_dims_from_spec);
1189        let legacy = schema.to_legacy_config_outputs()?;
1190        Ok((legacy, program, per_scale, input_dims))
1191    }
1192
1193    /// Extracts the normalized flag from config outputs.
1194    /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
1195    /// - `Some(false)`: Boxes are in pixel coordinates
1196    /// - `None`: Unknown (not specified in config), caller must infer
1197    fn get_normalized(outputs: &[ConfigOutput]) -> Option<bool> {
1198        for output in outputs {
1199            match output {
1200                ConfigOutput::Detection(det) => return det.normalized,
1201                ConfigOutput::Boxes(boxes) => return boxes.normalized,
1202                _ => {}
1203            }
1204        }
1205        None // not specified
1206    }
1207
1208    fn get_model_type(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1209        // yolo or modelpack
1210        let mut yolo = false;
1211        let mut modelpack = false;
1212        for c in &configs.outputs {
1213            match c.decoder() {
1214                DecoderType::ModelPack => modelpack = true,
1215                DecoderType::Ultralytics => yolo = true,
1216            }
1217        }
1218        match (modelpack, yolo) {
1219            (true, true) => Err(DecoderError::InvalidConfig(
1220                "Both ModelPack and Yolo outputs found in config".to_string(),
1221            )),
1222            (true, false) => Self::get_model_type_modelpack(configs),
1223            (false, true) => Self::get_model_type_yolo(configs),
1224            (false, false) => Err(DecoderError::InvalidConfig(
1225                "No outputs found in config".to_string(),
1226            )),
1227        }
1228    }
1229
    /// Resolves the [`ModelType`] for a config whose outputs use the
    /// Ultralytics decoder. `get_model_type` only calls it when no ModelPack
    /// outputs are present.
    ///
    /// Outputs are first bucketed by role. If a role appears more than once,
    /// the last occurrence wins. `Segmentation` and `Mask` outputs are
    /// rejected outright.
    ///
    /// The end-to-end variants are considered first (see the inline
    /// comment). Otherwise:
    /// - a combined `Detection` output selects `YoloSegDet2Way`,
    ///   `YoloSegDet` or `YoloDet`;
    /// - split `Boxes` + `Scores` outputs select `YoloSplitSegDet` or
    ///   `YoloSplitDet`.
    ///
    /// Each candidate is checked by the matching `verify_*` helper before
    /// being returned. Any other combination is an `InvalidConfig` error.
    fn get_model_type_yolo(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
        let mut boxes = None;
        let mut protos = None;
        let mut split_boxes = None;
        let mut split_scores = None;
        let mut split_mask_coeff = None;
        let mut split_classes = None;
        for c in configs.outputs {
            match c {
                ConfigOutput::Detection(detection) => boxes = Some(detection),
                ConfigOutput::Segmentation(_) => {
                    return Err(DecoderError::InvalidConfig(
                        "Invalid Segmentation output with Yolo decoder".to_string(),
                    ));
                }
                ConfigOutput::Protos(protos_) => protos = Some(protos_),
                ConfigOutput::Mask(_) => {
                    return Err(DecoderError::InvalidConfig(
                        "Invalid Mask output with Yolo decoder".to_string(),
                    ));
                }
                ConfigOutput::Scores(scores) => split_scores = Some(scores),
                ConfigOutput::Boxes(boxes) => split_boxes = Some(boxes),
                ConfigOutput::MaskCoefficients(mask_coeff) => split_mask_coeff = Some(mask_coeff),
                ConfigOutput::Classes(classes) => split_classes = Some(classes),
            }
        }

        // Use the end-to-end model types when either:
        // 1. `decoder_version` is set and reports `is_end_to_end()` (e.g.
        //    Yolo26). An explicit version is definitive and overrides the
        //    dshape heuristic; or
        // 2. `decoder_version` is not set and the Detection dshape is exactly
        //    (batch, num_boxes, num_features).
        let is_end_to_end_dshape = boxes.as_ref().is_some_and(|b| {
            let dims = b.dshape.iter().map(|(d, _)| *d).collect::<Vec<_>>();
            dims == vec![DimName::Batch, DimName::NumBoxes, DimName::NumFeatures]
        });

        let is_end_to_end = configs
            .decoder_version
            .map(|v| v.is_end_to_end())
            .unwrap_or(is_end_to_end_dshape);

        if is_end_to_end {
            // A combined Detection output takes priority over split outputs;
            // Protos decides between the seg+det and det-only variants.
            if let Some(boxes) = boxes {
                if let Some(protos) = protos {
                    Self::verify_yolo_seg_det_26(&boxes, &protos)?;
                    return Ok(ModelType::YoloEndToEndSegDet { boxes, protos });
                } else {
                    Self::verify_yolo_det_26(&boxes)?;
                    return Ok(ModelType::YoloEndToEndDet { boxes });
                }
            } else if let (Some(split_boxes), Some(split_scores), Some(split_classes)) =
                (split_boxes, split_scores, split_classes)
            {
                // Split end-to-end outputs need Boxes, Scores and Classes;
                // MaskCoefficients + Protos upgrade this to seg+det.
                if let (Some(split_mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
                    Self::verify_yolo_split_end_to_end_segdet(
                        &split_boxes,
                        &split_scores,
                        &split_classes,
                        &split_mask_coeff,
                        &protos,
                    )?;
                    return Ok(ModelType::YoloSplitEndToEndSegDet {
                        boxes: split_boxes,
                        scores: split_scores,
                        classes: split_classes,
                        mask_coeff: split_mask_coeff,
                        protos,
                    });
                }
                Self::verify_yolo_split_end_to_end_det(
                    &split_boxes,
                    &split_scores,
                    &split_classes,
                )?;
                return Ok(ModelType::YoloSplitEndToEndDet {
                    boxes: split_boxes,
                    scores: split_scores,
                    classes: split_classes,
                });
            } else {
                return Err(DecoderError::InvalidConfig(
                    "Invalid Yolo end-to-end model outputs".to_string(),
                ));
            }
        }

        // Traditional (non-end-to-end) layouts; a Classes output is ignored here.
        if let Some(boxes) = boxes {
            match (split_mask_coeff, protos) {
                (Some(mask_coeff), Some(protos)) => {
                    // 2-way split: combined detection + separate mask_coeff + protos
                    Self::verify_yolo_seg_det_2way(&boxes, &mask_coeff, &protos)?;
                    Ok(ModelType::YoloSegDet2Way {
                        boxes,
                        mask_coeff,
                        protos,
                    })
                }
                (_, Some(protos)) => {
                    // Unsplit: mask_coefs embedded in detection tensor
                    Self::verify_yolo_seg_det(&boxes, &protos)?;
                    Ok(ModelType::YoloSegDet { boxes, protos })
                }
                _ => {
                    Self::verify_yolo_det(&boxes)?;
                    Ok(ModelType::YoloDet { boxes })
                }
            }
        } else if let (Some(boxes), Some(scores)) = (split_boxes, split_scores) {
            if let (Some(mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
                Self::verify_yolo_split_segdet(&boxes, &scores, &mask_coeff, &protos)?;
                Ok(ModelType::YoloSplitSegDet {
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                })
            } else {
                Self::verify_yolo_split_det(&boxes, &scores)?;
                Ok(ModelType::YoloSplitDet { boxes, scores })
            }
        } else {
            Err(DecoderError::InvalidConfig(
                "Invalid Yolo model outputs".to_string(),
            ))
        }
    }
1357
1358    fn verify_yolo_det(detect: &configs::Detection) -> Result<(), DecoderError> {
1359        if detect.shape.len() != 3 {
1360            return Err(DecoderError::InvalidConfig(format!(
1361                "Invalid Yolo Detection shape {:?}",
1362                detect.shape
1363            )));
1364        }
1365
1366        Self::verify_dshapes(
1367            &detect.dshape,
1368            &detect.shape,
1369            "Detection",
1370            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1371        )?;
1372        if !detect.dshape.is_empty() {
1373            Self::get_class_count(&detect.dshape, None, None)?;
1374        } else {
1375            Self::get_class_count_no_dshape(detect.into(), None)?;
1376        }
1377
1378        Ok(())
1379    }
1380
1381    fn verify_yolo_det_26(detect: &configs::Detection) -> Result<(), DecoderError> {
1382        if detect.shape.len() != 3 {
1383            return Err(DecoderError::InvalidConfig(format!(
1384                "Invalid Yolo Detection shape {:?}",
1385                detect.shape
1386            )));
1387        }
1388
1389        Self::verify_dshapes(
1390            &detect.dshape,
1391            &detect.shape,
1392            "Detection",
1393            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1394        )?;
1395
1396        if !detect.shape.contains(&6) {
1397            return Err(DecoderError::InvalidConfig(
1398                "Yolo26 Detection must have 6 features".to_string(),
1399            ));
1400        }
1401
1402        Ok(())
1403    }
1404
1405    fn verify_yolo_seg_det(
1406        detection: &configs::Detection,
1407        protos: &configs::Protos,
1408    ) -> Result<(), DecoderError> {
1409        if detection.shape.len() != 3 {
1410            return Err(DecoderError::InvalidConfig(format!(
1411                "Invalid Yolo Detection shape {:?}",
1412                detection.shape
1413            )));
1414        }
1415        if protos.shape.len() != 4 {
1416            return Err(DecoderError::InvalidConfig(format!(
1417                "Invalid Yolo Protos shape {:?}",
1418                protos.shape
1419            )));
1420        }
1421
1422        Self::verify_dshapes(
1423            &detection.dshape,
1424            &detection.shape,
1425            "Detection",
1426            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1427        )?;
1428        Self::verify_dshapes(
1429            &protos.dshape,
1430            &protos.shape,
1431            "Protos",
1432            &[
1433                DimName::Batch,
1434                DimName::Height,
1435                DimName::Width,
1436                DimName::NumProtos,
1437            ],
1438        )?;
1439
1440        let protos_count = Self::get_protos_count(&protos.dshape)
1441            .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1442        log::debug!("Protos count: {}", protos_count);
1443        log::debug!("Detection dshape: {:?}", detection.dshape);
1444        let classes = if !detection.dshape.is_empty() {
1445            Self::get_class_count(&detection.dshape, Some(protos_count), None)?
1446        } else {
1447            Self::get_class_count_no_dshape(detection.into(), Some(protos_count))?
1448        };
1449
1450        if classes == 0 {
1451            return Err(DecoderError::InvalidConfig(
1452                "Yolo Segmentation Detection has zero classes".to_string(),
1453            ));
1454        }
1455
1456        Ok(())
1457    }
1458
1459    fn verify_yolo_seg_det_2way(
1460        detection: &configs::Detection,
1461        mask_coeff: &configs::MaskCoefficients,
1462        protos: &configs::Protos,
1463    ) -> Result<(), DecoderError> {
1464        if detection.shape.len() != 3 {
1465            return Err(DecoderError::InvalidConfig(format!(
1466                "Invalid Yolo 2-Way Detection shape {:?}",
1467                detection.shape
1468            )));
1469        }
1470        if mask_coeff.shape.len() != 3 {
1471            return Err(DecoderError::InvalidConfig(format!(
1472                "Invalid Yolo 2-Way Mask Coefficients shape {:?}",
1473                mask_coeff.shape
1474            )));
1475        }
1476        if protos.shape.len() != 4 {
1477            return Err(DecoderError::InvalidConfig(format!(
1478                "Invalid Yolo 2-Way Protos shape {:?}",
1479                protos.shape
1480            )));
1481        }
1482
1483        Self::verify_dshapes(
1484            &detection.dshape,
1485            &detection.shape,
1486            "Detection",
1487            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1488        )?;
1489        Self::verify_dshapes(
1490            &mask_coeff.dshape,
1491            &mask_coeff.shape,
1492            "Mask Coefficients",
1493            &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1494        )?;
1495        Self::verify_dshapes(
1496            &protos.dshape,
1497            &protos.shape,
1498            "Protos",
1499            &[
1500                DimName::Batch,
1501                DimName::Height,
1502                DimName::Width,
1503                DimName::NumProtos,
1504            ],
1505        )?;
1506
1507        // Validate num_boxes match between detection and mask_coeff
1508        let det_num = Self::get_box_count(&detection.dshape).unwrap_or(detection.shape[2]);
1509        let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1510        if det_num != mask_num {
1511            return Err(DecoderError::InvalidConfig(format!(
1512                "Yolo 2-Way Detection num_boxes {} incompatible with Mask Coefficients num_boxes {}",
1513                det_num, mask_num
1514            )));
1515        }
1516
1517        // Validate mask_coeff channels match protos channels
1518        let mask_channels = if !mask_coeff.dshape.is_empty() {
1519            Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1520                DecoderError::InvalidConfig(
1521                    "Could not find num_protos in mask_coeff config".to_string(),
1522                )
1523            })?
1524        } else {
1525            mask_coeff.shape[1]
1526        };
1527        let proto_channels = if !protos.dshape.is_empty() {
1528            Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1529                DecoderError::InvalidConfig(
1530                    "Could not find num_protos in protos config".to_string(),
1531                )
1532            })?
1533        } else {
1534            protos.shape[1].min(protos.shape[3])
1535        };
1536        if mask_channels != proto_channels {
1537            return Err(DecoderError::InvalidConfig(format!(
1538                "Yolo 2-Way Protos channels {} incompatible with Mask Coefficients channels {}",
1539                proto_channels, mask_channels
1540            )));
1541        }
1542
1543        // Validate detection has classes (nc+4 features, no mask_coefs embedded)
1544        if !detection.dshape.is_empty() {
1545            Self::get_class_count(&detection.dshape, None, None)?;
1546        } else {
1547            Self::get_class_count_no_dshape(detection.into(), None)?;
1548        }
1549
1550        Ok(())
1551    }
1552
1553    fn verify_yolo_seg_det_26(
1554        detection: &configs::Detection,
1555        protos: &configs::Protos,
1556    ) -> Result<(), DecoderError> {
1557        if detection.shape.len() != 3 {
1558            return Err(DecoderError::InvalidConfig(format!(
1559                "Invalid Yolo Detection shape {:?}",
1560                detection.shape
1561            )));
1562        }
1563        if protos.shape.len() != 4 {
1564            return Err(DecoderError::InvalidConfig(format!(
1565                "Invalid Yolo Protos shape {:?}",
1566                protos.shape
1567            )));
1568        }
1569
1570        Self::verify_dshapes(
1571            &detection.dshape,
1572            &detection.shape,
1573            "Detection",
1574            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1575        )?;
1576        Self::verify_dshapes(
1577            &protos.dshape,
1578            &protos.shape,
1579            "Protos",
1580            &[
1581                DimName::Batch,
1582                DimName::Height,
1583                DimName::Width,
1584                DimName::NumProtos,
1585            ],
1586        )?;
1587
1588        let protos_count = Self::get_protos_count(&protos.dshape)
1589            .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1590        log::debug!("Protos count: {}", protos_count);
1591        log::debug!("Detection dshape: {:?}", detection.dshape);
1592
1593        if !detection.shape.contains(&(6 + protos_count)) {
1594            return Err(DecoderError::InvalidConfig(format!(
1595                "Yolo26 Segmentation Detection must have num_features be 6 + num_protos = {}",
1596                6 + protos_count
1597            )));
1598        }
1599
1600        Ok(())
1601    }
1602
1603    fn verify_yolo_split_det(
1604        boxes: &configs::Boxes,
1605        scores: &configs::Scores,
1606    ) -> Result<(), DecoderError> {
1607        if boxes.shape.len() != 3 {
1608            return Err(DecoderError::InvalidConfig(format!(
1609                "Invalid Yolo Split Boxes shape {:?}",
1610                boxes.shape
1611            )));
1612        }
1613        if scores.shape.len() != 3 {
1614            return Err(DecoderError::InvalidConfig(format!(
1615                "Invalid Yolo Split Scores shape {:?}",
1616                scores.shape
1617            )));
1618        }
1619
1620        Self::verify_dshapes(
1621            &boxes.dshape,
1622            &boxes.shape,
1623            "Boxes",
1624            &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1625        )?;
1626        Self::verify_dshapes(
1627            &scores.dshape,
1628            &scores.shape,
1629            "Scores",
1630            &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1631        )?;
1632
1633        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1634        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1635
1636        if boxes_num != scores_num {
1637            return Err(DecoderError::InvalidConfig(format!(
1638                "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1639                boxes_num, scores_num
1640            )));
1641        }
1642
1643        Ok(())
1644    }
1645
1646    fn verify_yolo_split_segdet(
1647        boxes: &configs::Boxes,
1648        scores: &configs::Scores,
1649        mask_coeff: &configs::MaskCoefficients,
1650        protos: &configs::Protos,
1651    ) -> Result<(), DecoderError> {
1652        if boxes.shape.len() != 3 {
1653            return Err(DecoderError::InvalidConfig(format!(
1654                "Invalid Yolo Split Boxes shape {:?}",
1655                boxes.shape
1656            )));
1657        }
1658        if scores.shape.len() != 3 {
1659            return Err(DecoderError::InvalidConfig(format!(
1660                "Invalid Yolo Split Scores shape {:?}",
1661                scores.shape
1662            )));
1663        }
1664
1665        if mask_coeff.shape.len() != 3 {
1666            return Err(DecoderError::InvalidConfig(format!(
1667                "Invalid Yolo Split Mask Coefficients shape {:?}",
1668                mask_coeff.shape
1669            )));
1670        }
1671
1672        if protos.shape.len() != 4 {
1673            return Err(DecoderError::InvalidConfig(format!(
1674                "Invalid Yolo Protos shape {:?}",
1675                mask_coeff.shape
1676            )));
1677        }
1678
1679        Self::verify_dshapes(
1680            &boxes.dshape,
1681            &boxes.shape,
1682            "Boxes",
1683            &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1684        )?;
1685        Self::verify_dshapes(
1686            &scores.dshape,
1687            &scores.shape,
1688            "Scores",
1689            &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1690        )?;
1691        Self::verify_dshapes(
1692            &mask_coeff.dshape,
1693            &mask_coeff.shape,
1694            "Mask Coefficients",
1695            &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1696        )?;
1697        Self::verify_dshapes(
1698            &protos.dshape,
1699            &protos.shape,
1700            "Protos",
1701            &[
1702                DimName::Batch,
1703                DimName::Height,
1704                DimName::Width,
1705                DimName::NumProtos,
1706            ],
1707        )?;
1708
1709        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1710        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1711        let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1712
1713        let mask_channels = if !mask_coeff.dshape.is_empty() {
1714            Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1715                DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1716            })?
1717        } else {
1718            mask_coeff.shape[1]
1719        };
1720        let proto_channels = if !protos.dshape.is_empty() {
1721            Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1722                DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1723            })?
1724        } else {
1725            protos.shape[1].min(protos.shape[3])
1726        };
1727
1728        if boxes_num != scores_num {
1729            return Err(DecoderError::InvalidConfig(format!(
1730                "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1731                boxes_num, scores_num
1732            )));
1733        }
1734
1735        if boxes_num != mask_num {
1736            return Err(DecoderError::InvalidConfig(format!(
1737                "Yolo Split Detection Boxes num {} incompatible with Mask Coefficients num {}",
1738                boxes_num, mask_num
1739            )));
1740        }
1741
1742        if proto_channels != mask_channels {
1743            return Err(DecoderError::InvalidConfig(format!(
1744                "Yolo Protos channels {} incompatible with Mask Coefficients channels {}",
1745                proto_channels, mask_channels
1746            )));
1747        }
1748
1749        Ok(())
1750    }
1751
1752    fn verify_yolo_split_end_to_end_det(
1753        boxes: &configs::Boxes,
1754        scores: &configs::Scores,
1755        classes: &configs::Classes,
1756    ) -> Result<(), DecoderError> {
1757        if boxes.shape.len() != 3 || !boxes.shape.contains(&4) {
1758            return Err(DecoderError::InvalidConfig(format!(
1759                "Split end-to-end boxes must be [batch, N, 4], got {:?}",
1760                boxes.shape
1761            )));
1762        }
1763        if scores.shape.len() != 3 || !scores.shape.contains(&1) {
1764            return Err(DecoderError::InvalidConfig(format!(
1765                "Split end-to-end scores must be [batch, N, 1], got {:?}",
1766                scores.shape
1767            )));
1768        }
1769        if classes.shape.len() != 3 || !classes.shape.contains(&1) {
1770            return Err(DecoderError::InvalidConfig(format!(
1771                "Split end-to-end classes must be [batch, N, 1], got {:?}",
1772                classes.shape
1773            )));
1774        }
1775        Ok(())
1776    }
1777
1778    fn verify_yolo_split_end_to_end_segdet(
1779        boxes: &configs::Boxes,
1780        scores: &configs::Scores,
1781        classes: &configs::Classes,
1782        mask_coeff: &configs::MaskCoefficients,
1783        protos: &configs::Protos,
1784    ) -> Result<(), DecoderError> {
1785        Self::verify_yolo_split_end_to_end_det(boxes, scores, classes)?;
1786        if mask_coeff.shape.len() != 3 {
1787            return Err(DecoderError::InvalidConfig(format!(
1788                "Invalid split end-to-end mask coefficients shape {:?}",
1789                mask_coeff.shape
1790            )));
1791        }
1792        if protos.shape.len() != 4 {
1793            return Err(DecoderError::InvalidConfig(format!(
1794                "Invalid protos shape {:?}",
1795                protos.shape
1796            )));
1797        }
1798        Ok(())
1799    }
1800
1801    fn get_model_type_modelpack(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1802        let mut split_decoders = Vec::new();
1803        let mut segment_ = None;
1804        let mut scores_ = None;
1805        let mut boxes_ = None;
1806        for c in configs.outputs {
1807            match c {
1808                ConfigOutput::Detection(detection) => split_decoders.push(detection),
1809                ConfigOutput::Segmentation(segmentation) => segment_ = Some(segmentation),
1810                ConfigOutput::Mask(_) => {}
1811                ConfigOutput::Protos(_) => {
1812                    return Err(DecoderError::InvalidConfig(
1813                        "ModelPack should not have protos".to_string(),
1814                    ));
1815                }
1816                ConfigOutput::Scores(scores) => scores_ = Some(scores),
1817                ConfigOutput::Boxes(boxes) => boxes_ = Some(boxes),
1818                ConfigOutput::MaskCoefficients(_) => {
1819                    return Err(DecoderError::InvalidConfig(
1820                        "ModelPack should not have mask coefficients".to_string(),
1821                    ));
1822                }
1823                ConfigOutput::Classes(_) => {
1824                    return Err(DecoderError::InvalidConfig(
1825                        "ModelPack should not have classes output".to_string(),
1826                    ));
1827                }
1828            }
1829        }
1830
1831        if let Some(segmentation) = segment_ {
1832            if !split_decoders.is_empty() {
1833                let classes = Self::verify_modelpack_split_det(&split_decoders)?;
1834                Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1835                Ok(ModelType::ModelPackSegDetSplit {
1836                    detection: split_decoders,
1837                    segmentation,
1838                })
1839            } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1840                let classes = Self::verify_modelpack_det(&boxes, &scores)?;
1841                Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1842                Ok(ModelType::ModelPackSegDet {
1843                    boxes,
1844                    scores,
1845                    segmentation,
1846                })
1847            } else {
1848                Self::verify_modelpack_seg(&segmentation, None)?;
1849                Ok(ModelType::ModelPackSeg { segmentation })
1850            }
1851        } else if !split_decoders.is_empty() {
1852            Self::verify_modelpack_split_det(&split_decoders)?;
1853            Ok(ModelType::ModelPackDetSplit {
1854                detection: split_decoders,
1855            })
1856        } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1857            Self::verify_modelpack_det(&boxes, &scores)?;
1858            Ok(ModelType::ModelPackDet { boxes, scores })
1859        } else {
1860            Err(DecoderError::InvalidConfig(
1861                "Invalid ModelPack model outputs".to_string(),
1862            ))
1863        }
1864    }
1865
1866    fn verify_modelpack_det(
1867        boxes: &configs::Boxes,
1868        scores: &configs::Scores,
1869    ) -> Result<usize, DecoderError> {
1870        if boxes.shape.len() != 4 {
1871            return Err(DecoderError::InvalidConfig(format!(
1872                "Invalid ModelPack Boxes shape {:?}",
1873                boxes.shape
1874            )));
1875        }
1876        if scores.shape.len() != 3 {
1877            return Err(DecoderError::InvalidConfig(format!(
1878                "Invalid ModelPack Scores shape {:?}",
1879                scores.shape
1880            )));
1881        }
1882
1883        Self::verify_dshapes(
1884            &boxes.dshape,
1885            &boxes.shape,
1886            "Boxes",
1887            &[
1888                DimName::Batch,
1889                DimName::NumBoxes,
1890                DimName::Padding,
1891                DimName::BoxCoords,
1892            ],
1893        )?;
1894        Self::verify_dshapes(
1895            &scores.dshape,
1896            &scores.shape,
1897            "Scores",
1898            &[DimName::Batch, DimName::NumBoxes, DimName::NumClasses],
1899        )?;
1900
1901        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[1]);
1902        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[1]);
1903
1904        if boxes_num != scores_num {
1905            return Err(DecoderError::InvalidConfig(format!(
1906                "ModelPack Detection Boxes num {} incompatible with Scores num {}",
1907                boxes_num, scores_num
1908            )));
1909        }
1910
1911        let num_classes = if !scores.dshape.is_empty() {
1912            Self::get_class_count(&scores.dshape, None, None)?
1913        } else {
1914            Self::get_class_count_no_dshape(scores.into(), None)?
1915        };
1916
1917        Ok(num_classes)
1918    }
1919
1920    fn verify_modelpack_split_det(boxes: &[configs::Detection]) -> Result<usize, DecoderError> {
1921        let mut num_classes = None;
1922        for b in boxes {
1923            let Some(num_anchors) = b.anchors.as_ref().map(|a| a.len()) else {
1924                return Err(DecoderError::InvalidConfig(
1925                    "ModelPack Split Detection missing anchors".to_string(),
1926                ));
1927            };
1928
1929            if num_anchors == 0 {
1930                return Err(DecoderError::InvalidConfig(
1931                    "ModelPack Split Detection has zero anchors".to_string(),
1932                ));
1933            }
1934
1935            if b.shape.len() != 4 {
1936                return Err(DecoderError::InvalidConfig(format!(
1937                    "Invalid ModelPack Split Detection shape {:?}",
1938                    b.shape
1939                )));
1940            }
1941
1942            Self::verify_dshapes(
1943                &b.dshape,
1944                &b.shape,
1945                "Split Detection",
1946                &[
1947                    DimName::Batch,
1948                    DimName::Height,
1949                    DimName::Width,
1950                    DimName::NumAnchorsXFeatures,
1951                ],
1952            )?;
1953            let classes = if !b.dshape.is_empty() {
1954                Self::get_class_count(&b.dshape, None, Some(num_anchors))?
1955            } else {
1956                Self::get_class_count_no_dshape(b.into(), None)?
1957            };
1958
1959            match num_classes {
1960                Some(n) => {
1961                    if n != classes {
1962                        return Err(DecoderError::InvalidConfig(format!(
1963                            "ModelPack Split Detection inconsistent number of classes: previous {}, current {}",
1964                            n, classes
1965                        )));
1966                    }
1967                }
1968                None => {
1969                    num_classes = Some(classes);
1970                }
1971            }
1972        }
1973
1974        Ok(num_classes.unwrap_or(0))
1975    }
1976
1977    fn verify_modelpack_seg(
1978        segmentation: &configs::Segmentation,
1979        classes: Option<usize>,
1980    ) -> Result<(), DecoderError> {
1981        if segmentation.shape.len() != 4 {
1982            return Err(DecoderError::InvalidConfig(format!(
1983                "Invalid ModelPack Segmentation shape {:?}",
1984                segmentation.shape
1985            )));
1986        }
1987        Self::verify_dshapes(
1988            &segmentation.dshape,
1989            &segmentation.shape,
1990            "Segmentation",
1991            &[
1992                DimName::Batch,
1993                DimName::Height,
1994                DimName::Width,
1995                DimName::NumClasses,
1996            ],
1997        )?;
1998
1999        if let Some(classes) = classes {
2000            let seg_classes = if !segmentation.dshape.is_empty() {
2001                Self::get_class_count(&segmentation.dshape, None, None)?
2002            } else {
2003                Self::get_class_count_no_dshape(segmentation.into(), None)?
2004            };
2005
2006            if seg_classes != classes + 1 {
2007                return Err(DecoderError::InvalidConfig(format!(
2008                    "ModelPack Segmentation channels {} incompatible with number of classes {}",
2009                    seg_classes, classes
2010                )));
2011            }
2012        }
2013        Ok(())
2014    }
2015
2016    // verifies that dshapes match the given shape
2017    fn verify_dshapes(
2018        dshape: &[(DimName, usize)],
2019        shape: &[usize],
2020        name: &str,
2021        dims: &[DimName],
2022    ) -> Result<(), DecoderError> {
2023        for s in shape {
2024            if *s == 0 {
2025                return Err(DecoderError::InvalidConfig(format!(
2026                    "{} shape has zero dimension",
2027                    name
2028                )));
2029            }
2030        }
2031
2032        if shape.len() != dims.len() {
2033            return Err(DecoderError::InvalidConfig(format!(
2034                "{} shape length {} does not match expected dims length {}",
2035                name,
2036                shape.len(),
2037                dims.len()
2038            )));
2039        }
2040
2041        if dshape.is_empty() {
2042            return Ok(());
2043        }
2044        // check the dshape lengths match the shape lengths
2045        if dshape.len() != shape.len() {
2046            return Err(DecoderError::InvalidConfig(format!(
2047                "{} dshape length does not match shape length",
2048                name
2049            )));
2050        }
2051
2052        // check the dshape values match the shape values
2053        for ((dim_name, dim_size), shape_size) in dshape.iter().zip(shape) {
2054            if dim_size != shape_size {
2055                return Err(DecoderError::InvalidConfig(format!(
2056                    "{} dshape dimension {} size {} does not match shape size {}",
2057                    name, dim_name, dim_size, shape_size
2058                )));
2059            }
2060            if *dim_name == DimName::Padding && *dim_size != 1 {
2061                return Err(DecoderError::InvalidConfig(
2062                    "Padding dimension size must be 1".to_string(),
2063                ));
2064            }
2065
2066            if *dim_name == DimName::BoxCoords && *dim_size != 4 {
2067                return Err(DecoderError::InvalidConfig(
2068                    "BoxCoords dimension size must be 4".to_string(),
2069                ));
2070            }
2071        }
2072
2073        let dims_present = HashSet::<DimName>::from_iter(dshape.iter().map(|(name, _)| *name));
2074        for dim in dims {
2075            if !dims_present.contains(dim) {
2076                return Err(DecoderError::InvalidConfig(format!(
2077                    "{} dshape missing required dimension {:?}",
2078                    name, dim
2079                )));
2080            }
2081        }
2082
2083        Ok(())
2084    }
2085
2086    fn get_box_count(dshape: &[(DimName, usize)]) -> Option<usize> {
2087        for (dim_name, dim_size) in dshape {
2088            if *dim_name == DimName::NumBoxes {
2089                return Some(*dim_size);
2090            }
2091        }
2092        None
2093    }
2094
2095    fn get_class_count_no_dshape(
2096        config: ConfigOutputRef,
2097        protos: Option<usize>,
2098    ) -> Result<usize, DecoderError> {
2099        match config {
2100            ConfigOutputRef::Detection(detection) => match detection.decoder {
2101                DecoderType::Ultralytics => {
2102                    if detection.shape[1] <= 4 + protos.unwrap_or(0) {
2103                        return Err(DecoderError::InvalidConfig(format!(
2104                            "Invalid shape: Yolo num_features {} must be greater than {}",
2105                            detection.shape[1],
2106                            4 + protos.unwrap_or(0),
2107                        )));
2108                    }
2109                    Ok(detection.shape[1] - 4 - protos.unwrap_or(0))
2110                }
2111                DecoderType::ModelPack => {
2112                    let Some(num_anchors) = detection.anchors.as_ref().map(|a| a.len()) else {
2113                        return Err(DecoderError::Internal(
2114                            "ModelPack Detection missing anchors".to_string(),
2115                        ));
2116                    };
2117                    let anchors_x_features = detection.shape[3];
2118                    if anchors_x_features <= num_anchors * 5 {
2119                        return Err(DecoderError::InvalidConfig(format!(
2120                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2121                            anchors_x_features,
2122                            num_anchors * 5,
2123                        )));
2124                    }
2125
2126                    if !anchors_x_features.is_multiple_of(num_anchors) {
2127                        return Err(DecoderError::InvalidConfig(format!(
2128                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2129                            anchors_x_features, num_anchors
2130                        )));
2131                    }
2132                    Ok(anchors_x_features / num_anchors - 5)
2133                }
2134            },
2135
2136            ConfigOutputRef::Scores(scores) => match scores.decoder {
2137                DecoderType::Ultralytics => Ok(scores.shape[1]),
2138                DecoderType::ModelPack => Ok(scores.shape[2]),
2139            },
2140            ConfigOutputRef::Segmentation(seg) => Ok(seg.shape[3]),
2141            _ => Err(DecoderError::Internal(
2142                "Attempted to get class count from unsupported config output".to_owned(),
2143            )),
2144        }
2145    }
2146
2147    // get the class count from dshape or calculate from num_features
2148    fn get_class_count(
2149        dshape: &[(DimName, usize)],
2150        protos: Option<usize>,
2151        anchors: Option<usize>,
2152    ) -> Result<usize, DecoderError> {
2153        if dshape.is_empty() {
2154            return Ok(0);
2155        }
2156        // if it has num_classes in dshape, return it
2157        for (dim_name, dim_size) in dshape {
2158            if *dim_name == DimName::NumClasses {
2159                return Ok(*dim_size);
2160            }
2161        }
2162
2163        // number of classes can be calculated from num_features - 4 for yolo.  If the
2164        // model has protos, we also subtract the number of protos.
2165        for (dim_name, dim_size) in dshape {
2166            if *dim_name == DimName::NumFeatures {
2167                let protos = protos.unwrap_or(0);
2168                if protos + 4 >= *dim_size {
2169                    return Err(DecoderError::InvalidConfig(format!(
2170                        "Invalid shape: Yolo num_features {} must be greater than {}",
2171                        *dim_size,
2172                        protos + 4,
2173                    )));
2174                }
2175                return Ok(*dim_size - 4 - protos);
2176            }
2177        }
2178
2179        // number of classes can be calculated from number of anchors for modelpack
2180        // split detection
2181        if let Some(num_anchors) = anchors {
2182            for (dim_name, dim_size) in dshape {
2183                if *dim_name == DimName::NumAnchorsXFeatures {
2184                    let anchors_x_features = *dim_size;
2185                    if anchors_x_features <= num_anchors * 5 {
2186                        return Err(DecoderError::InvalidConfig(format!(
2187                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2188                            anchors_x_features,
2189                            num_anchors * 5,
2190                        )));
2191                    }
2192
2193                    if !anchors_x_features.is_multiple_of(num_anchors) {
2194                        return Err(DecoderError::InvalidConfig(format!(
2195                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2196                            anchors_x_features, num_anchors
2197                        )));
2198                    }
2199                    return Ok((anchors_x_features / num_anchors) - 5);
2200                }
2201            }
2202        }
2203        Err(DecoderError::InvalidConfig(
2204            "Cannot determine number of classes from dshape".to_owned(),
2205        ))
2206    }
2207
2208    fn get_protos_count(dshape: &[(DimName, usize)]) -> Option<usize> {
2209        for (dim_name, dim_size) in dshape {
2210            if *dim_name == DimName::NumProtos {
2211                return Some(*dim_size);
2212            }
2213        }
2214        None
2215    }
2216}