edgefirst_decoder/decoder/builder.rs
1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashSet;
5
6use super::config::ConfigOutputRef;
7use super::configs::{self, DecoderType, DecoderVersion, DimName, ModelType};
8use super::merge::DecodeProgram;
9use super::{ConfigOutput, ConfigOutputs, Decoder};
10use crate::per_scale::{DecodeDtype, PerScalePlan};
11use crate::schema::SchemaV2;
12use crate::DecoderError;
13
14/// Extract `(width, height)` from a schema [`crate::schema::InputSpec`].
15///
16/// Prefers named dims (`DimName::Width` / `DimName::Height`) when the
17/// `dshape` is populated, falling back to the NHWC convention
18/// (`shape[1] = H, shape[2] = W`) for 4-element shapes whenever either
19/// named dim is missing. Returns `None` for any other shape arity — the
20/// decoder treats that as "input dims unknown" and skips the EDGEAI-1303
21/// normalization path.
22///
23/// The fallback fires when **either** `Height` or `Width` is missing from
24/// the dshape (not only when both are absent), so a partially-named
25/// dshape (e.g. only `Width`) still resolves both dims via positional
26/// inference instead of silently disabling normalization.
27fn input_dims_from_spec(input: &crate::schema::InputSpec) -> Option<(usize, usize)> {
28 use crate::configs::DimName;
29 let mut h = None;
30 let mut w = None;
31 // `SchemaV2::validate()` doesn't currently enforce
32 // `dshape.len() <= shape.len()`, so a malformed schema can trip an
33 // out-of-bounds index here. Use `shape.get(i)` to silently skip
34 // dshape entries that overflow the shape — the caller treats
35 // missing dims as "unknown" and disables EDGEAI-1303 normalization.
36 for (i, (name, _)) in input.dshape.iter().enumerate() {
37 match name {
38 DimName::Height => h = input.shape.get(i).copied(),
39 DimName::Width => w = input.shape.get(i).copied(),
40 _ => {}
41 }
42 }
43 if (h.is_none() || w.is_none()) && input.shape.len() == 4 {
44 // NHWC default: [N, H, W, C]. Mirrors the per-scale `extract_hw`
45 // fallback (`crates/decoder/src/per_scale/plan.rs::extract_hw`).
46 // Only fill the missing axis so a partial named dshape still
47 // resolves both dims.
48 h = h.or(Some(input.shape[1]));
49 w = w.or(Some(input.shape[2]));
50 }
51 match (w, h) {
52 (Some(w), Some(h)) => Some((w, h)),
53 _ => None,
54 }
55}
56
#[cfg(test)]
mod input_dims_from_spec_tests {
    use super::input_dims_from_spec;
    use crate::configs::DimName;
    use crate::schema::InputSpec;

    #[test]
    fn named_dshape_resolves_dims() {
        let spec = InputSpec {
            shape: vec![1, 480, 640, 3],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::Height, 480),
                (DimName::Width, 640),
                (DimName::NumFeatures, 3),
            ],
            cameraadaptor: None,
        };
        assert_eq!(input_dims_from_spec(&spec), Some((640, 480)));
    }

    #[test]
    fn empty_dshape_falls_back_to_nhwc_for_4d() {
        let spec = InputSpec {
            shape: vec![1, 480, 640, 3],
            dshape: vec![],
            cameraadaptor: None,
        };
        assert_eq!(input_dims_from_spec(&spec), Some((640, 480)));
    }

    #[test]
    fn named_dshape_takes_precedence_over_nhwc_positions() {
        // NCHW layout: positional NHWC inference would read H = 3 and
        // W = 480. The named dshape must win and report the real 640x480.
        let spec = InputSpec {
            shape: vec![1, 3, 480, 640],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumFeatures, 3),
                (DimName::Height, 480),
                (DimName::Width, 640),
            ],
            cameraadaptor: None,
        };
        assert_eq!(input_dims_from_spec(&spec), Some((640, 480)));
    }

    #[test]
    fn partial_dshape_fills_only_missing_axis_from_nhwc() {
        // Only Height is named (index 2 -> 480). Width is missing, so the
        // NHWC fallback supplies shape[2] (= 480). The named Height must
        // NOT be overwritten by the positional shape[1] (= 3).
        let spec = InputSpec {
            shape: vec![1, 3, 480, 640],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumFeatures, 3),
                (DimName::Height, 480),
            ],
            cameraadaptor: None,
        };
        assert_eq!(input_dims_from_spec(&spec), Some((480, 480)));
    }

    #[test]
    fn malformed_dshape_longer_than_shape_does_not_panic() {
        // Regression for Copilot review on PR #63: indexing
        // `input.shape[i]` while iterating dshape can OOB-panic when
        // `dshape.len() > shape.len()`. The fix uses `shape.get(i)`
        // and silently treats overflow as "dim missing".
        let spec = InputSpec {
            shape: vec![640, 480], // 2-D shape
            dshape: vec![
                (DimName::Width, 640),
                (DimName::Height, 480),
                (DimName::NumFeatures, 3), // overflow — index 2 ≥ shape.len()
            ],
            cameraadaptor: None,
        };
        // First two dshape entries resolve via `shape.get()`; the third
        // is a no-op. Width/Height both resolved, so we expect Some.
        assert_eq!(input_dims_from_spec(&spec), Some((640, 480)));
    }

    #[test]
    fn malformed_dshape_only_overflow_returns_none() {
        // All dshape entries are past the shape boundary — width and
        // height stay None and the 4-D NHWC fallback doesn't fire
        // (shape.len() == 1), so we get None instead of a panic.
        let spec = InputSpec {
            shape: vec![1],
            dshape: vec![
                (DimName::NumFeatures, 3),
                (DimName::Height, 480),
                (DimName::Width, 640),
            ],
            cameraadaptor: None,
        };
        assert_eq!(input_dims_from_spec(&spec), None);
    }
}
125
126#[derive(Debug, Clone, PartialEq)]
127pub struct DecoderBuilder {
128 config_src: Option<ConfigSource>,
129 iou_threshold: f32,
130 score_threshold: f32,
131 /// NMS mode.
132 ///
133 /// - `Some(Nms::Auto)` — resolve from config or fallback to
134 /// `ClassAgnostic` (builder default)
135 /// - `Some(Nms::ClassAgnostic)` — explicit class-agnostic override
136 /// - `Some(Nms::ClassAware)` — explicit class-aware override
137 /// - `None` — bypass NMS entirely
138 nms: Option<configs::Nms>,
139 /// Output dtype for the per-scale fast path. Has no effect on
140 /// schemas without per-scale children (which use the legacy decode
141 /// path).
142 decode_dtype: DecodeDtype,
143 pre_nms_top_k: usize,
144 max_det: usize,
145 /// Explicit override for the model input dimensions `(width, height)`,
146 /// consumed by EDGEAI-1303 normalization. When set, takes precedence
147 /// over schema-derived dims; when `None`, the value is read from the
148 /// schema's `input.shape` / `input.dshape` at build time.
149 input_dims: Option<(usize, usize)>,
150}
151
152#[derive(Debug, Clone, PartialEq)]
153enum ConfigSource {
154 Yaml(String),
155 Json(String),
156 Config(ConfigOutputs),
157 /// Schema v2 metadata. During build the schema is either converted
158 /// to a legacy [`ConfigOutputs`] (flat case) or used to construct a
159 /// [`DecodeProgram`] that performs per-child dequant + merge at
160 /// decode time.
161 Schema(SchemaV2),
162}
163
164impl Default for DecoderBuilder {
165 /// Creates a default DecoderBuilder with no configuration and 0.5 score
166 /// threshold and 0.5 IoU threshold.
167 ///
168 /// A valid configuration must be provided before building the Decoder.
169 ///
170 /// # Examples
171 /// ```rust
172 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
173 /// # fn main() -> DecoderResult<()> {
174 /// # let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
175 /// let decoder = DecoderBuilder::default()
176 /// .with_config_yaml_str(config_yaml)
177 /// .build()?;
178 /// assert_eq!(decoder.score_threshold, 0.5);
179 /// assert_eq!(decoder.iou_threshold, 0.5);
180 ///
181 /// # Ok(())
182 /// # }
183 /// ```
184 fn default() -> Self {
185 Self {
186 config_src: None,
187 iou_threshold: 0.5,
188 score_threshold: 0.5,
189 nms: Some(configs::Nms::Auto),
190 decode_dtype: DecodeDtype::F32,
191 pre_nms_top_k: 300,
192 max_det: 300,
193 input_dims: None,
194 }
195 }
196}
197
198impl DecoderBuilder {
199 /// Creates a default DecoderBuilder with no configuration and 0.5 score
200 /// threshold and 0.5 IoU threshold.
201 ///
202 /// A valid configuration must be provided before building the Decoder.
203 ///
204 /// # Examples
205 /// ```rust
206 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
207 /// # fn main() -> DecoderResult<()> {
208 /// # let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
209 /// let decoder = DecoderBuilder::new()
210 /// .with_config_yaml_str(config_yaml)
211 /// .build()?;
212 /// assert_eq!(decoder.score_threshold, 0.5);
213 /// assert_eq!(decoder.iou_threshold, 0.5);
214 ///
215 /// # Ok(())
216 /// # }
217 /// ```
218 pub fn new() -> Self {
219 Self::default()
220 }
221
222 /// Loads a model configuration in YAML format. Does not check if the string
223 /// is a correct configuration file. Use `DecoderBuilder.build()` to
224 /// deserialize the YAML and parse the model configuration.
225 ///
226 /// # Examples
227 /// ```rust
228 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
229 /// # fn main() -> DecoderResult<()> {
230 /// let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
231 /// let decoder = DecoderBuilder::new()
232 /// .with_config_yaml_str(config_yaml)
233 /// .build()?;
234 ///
235 /// # Ok(())
236 /// # }
237 /// ```
238 pub fn with_config_yaml_str(mut self, yaml_str: String) -> Self {
239 self.config_src.replace(ConfigSource::Yaml(yaml_str));
240 self
241 }
242
243 /// Loads a model configuration in JSON format. Does not check if the string
244 /// is a correct configuration file. Use `DecoderBuilder.build()` to
245 /// deserialize the JSON and parse the model configuration.
246 ///
247 /// # Examples
248 /// ```rust
249 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
250 /// # fn main() -> DecoderResult<()> {
251 /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
252 /// let decoder = DecoderBuilder::new()
253 /// .with_config_json_str(config_json)
254 /// .build()?;
255 ///
256 /// # Ok(())
257 /// # }
258 /// ```
259 pub fn with_config_json_str(mut self, json_str: String) -> Self {
260 self.config_src.replace(ConfigSource::Json(json_str));
261 self
262 }
263
264 /// Loads a model configuration. Does not check if the configuration is
265 /// correct. Intended to be used when the user needs control over the
266 /// deserialize of the configuration information. Use
267 /// `DecoderBuilder.build()` to parse the model configuration.
268 ///
269 /// # Examples
270 /// ```rust
271 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
272 /// # fn main() -> DecoderResult<()> {
273 /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json"));
274 /// let config = serde_json::from_str(config_json)?;
275 /// let decoder = DecoderBuilder::new().with_config(config).build()?;
276 ///
277 /// # Ok(())
278 /// # }
279 /// ```
280 pub fn with_config(mut self, config: ConfigOutputs) -> Self {
281 self.config_src.replace(ConfigSource::Config(config));
282 self
283 }
284
285 /// Configure the decoder from a schema v2 metadata document.
286 ///
287 /// Accepts a [`SchemaV2`] as produced by [`SchemaV2::parse_json`],
288 /// [`SchemaV2::parse_yaml`], [`SchemaV2::parse_file`], or
289 /// constructed programmatically. The builder validates the schema,
290 /// compiles a [`DecodeProgram`] for any split logical outputs
291 /// (per-scale or channel sub-splits), and downconverts the
292 /// logical-level semantics to the legacy [`ConfigOutputs`]
293 /// representation consumed by the existing decoder dispatch.
294 ///
295 /// # Examples
296 ///
297 /// ```rust,no_run
298 /// use edgefirst_decoder::{DecoderBuilder, DecoderResult};
299 /// use edgefirst_decoder::schema::SchemaV2;
300 ///
301 /// # fn main() -> DecoderResult<()> {
302 /// let schema = SchemaV2::parse_file("model/edgefirst.json")?;
303 /// let decoder = DecoderBuilder::new()
304 /// .with_schema(schema)
305 /// .with_score_threshold(0.25)
306 /// .build()?;
307 /// # Ok(())
308 /// # }
309 /// ```
310 pub fn with_schema(mut self, schema: SchemaV2) -> Self {
311 self.config_src.replace(ConfigSource::Schema(schema));
312 self
313 }
314
315 /// Choose the output dtype for the per-scale decoder pipeline.
316 ///
317 /// Defaults to [`DecodeDtype::F32`]. Has no effect on schemas
318 /// without per-scale children (which use the legacy decode path).
319 /// `F16` saves ~2× memory bandwidth at the cost of 10-bit mantissa
320 /// precision; empirically safe for YOLO-family models.
321 ///
322 /// # Examples
323 ///
324 /// ```rust,no_run
325 /// use edgefirst_decoder::{DecodeDtype, DecoderBuilder, DecoderResult};
326 /// use edgefirst_decoder::schema::SchemaV2;
327 ///
328 /// # fn main() -> DecoderResult<()> {
329 /// let schema = SchemaV2::parse_file("model/edgefirst.json")?;
330 /// let decoder = DecoderBuilder::new()
331 /// .with_schema(schema)
332 /// .with_decode_dtype(DecodeDtype::F32)
333 /// .build()?;
334 /// # Ok(())
335 /// # }
336 /// ```
337 pub fn with_decode_dtype(mut self, dtype: DecodeDtype) -> Self {
338 self.decode_dtype = dtype;
339 self
340 }
341
342 /// Loads a YOLO detection model configuration. Use
343 /// `DecoderBuilder.build()` to parse the model configuration.
344 ///
345 /// # Examples
346 /// ```rust
347 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
348 /// # fn main() -> DecoderResult<()> {
349 /// let decoder = DecoderBuilder::new()
350 /// .with_config_yolo_det(
351 /// configs::Detection {
352 /// anchors: None,
353 /// decoder: configs::DecoderType::Ultralytics,
354 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
355 /// shape: vec![1, 84, 8400],
356 /// dshape: Vec::new(),
357 /// normalized: Some(true),
358 /// },
359 /// None,
360 /// )
361 /// .build()?;
362 ///
363 /// # Ok(())
364 /// # }
365 /// ```
366 pub fn with_config_yolo_det(
367 mut self,
368 boxes: configs::Detection,
369 version: Option<DecoderVersion>,
370 ) -> Self {
371 let config = ConfigOutputs {
372 outputs: vec![ConfigOutput::Detection(boxes)],
373 decoder_version: version,
374 ..Default::default()
375 };
376 self.config_src.replace(ConfigSource::Config(config));
377 self
378 }
379
380 /// Loads a YOLO split detection model configuration. Use
381 /// `DecoderBuilder.build()` to parse the model configuration.
382 ///
383 /// # Examples
384 /// ```rust
385 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
386 /// # fn main() -> DecoderResult<()> {
387 /// let boxes_config = configs::Boxes {
388 /// decoder: configs::DecoderType::Ultralytics,
389 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
390 /// shape: vec![1, 4, 8400],
391 /// dshape: Vec::new(),
392 /// normalized: Some(true),
393 /// };
394 /// let scores_config = configs::Scores {
395 /// decoder: configs::DecoderType::Ultralytics,
396 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
397 /// shape: vec![1, 80, 8400],
398 /// dshape: Vec::new(),
399 /// };
400 /// let decoder = DecoderBuilder::new()
401 /// .with_config_yolo_split_det(boxes_config, scores_config)
402 /// .build()?;
403 /// # Ok(())
404 /// # }
405 /// ```
406 pub fn with_config_yolo_split_det(
407 mut self,
408 boxes: configs::Boxes,
409 scores: configs::Scores,
410 ) -> Self {
411 let config = ConfigOutputs {
412 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
413 ..Default::default()
414 };
415 self.config_src.replace(ConfigSource::Config(config));
416 self
417 }
418
419 /// Loads a YOLO segmentation model configuration. Use
420 /// `DecoderBuilder.build()` to parse the model configuration.
421 ///
422 /// # Examples
423 /// ```rust
424 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
425 /// # fn main() -> DecoderResult<()> {
426 /// let seg_config = configs::Detection {
427 /// decoder: configs::DecoderType::Ultralytics,
428 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
429 /// shape: vec![1, 116, 8400],
430 /// anchors: None,
431 /// dshape: Vec::new(),
432 /// normalized: Some(true),
433 /// };
434 /// let protos_config = configs::Protos {
435 /// decoder: configs::DecoderType::Ultralytics,
436 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
437 /// shape: vec![1, 160, 160, 32],
438 /// dshape: Vec::new(),
439 /// };
440 /// let decoder = DecoderBuilder::new()
441 /// .with_config_yolo_segdet(
442 /// seg_config,
443 /// protos_config,
444 /// Some(configs::DecoderVersion::Yolov8),
445 /// )
446 /// .build()?;
447 /// # Ok(())
448 /// # }
449 /// ```
450 pub fn with_config_yolo_segdet(
451 mut self,
452 boxes: configs::Detection,
453 protos: configs::Protos,
454 version: Option<DecoderVersion>,
455 ) -> Self {
456 let config = ConfigOutputs {
457 outputs: vec![ConfigOutput::Detection(boxes), ConfigOutput::Protos(protos)],
458 decoder_version: version,
459 ..Default::default()
460 };
461 self.config_src.replace(ConfigSource::Config(config));
462 self
463 }
464
465 /// Loads a YOLO split segmentation model configuration. Use
466 /// `DecoderBuilder.build()` to parse the model configuration.
467 ///
468 /// # Examples
469 /// ```rust
470 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
471 /// # fn main() -> DecoderResult<()> {
472 /// let boxes_config = configs::Boxes {
473 /// decoder: configs::DecoderType::Ultralytics,
474 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
475 /// shape: vec![1, 4, 8400],
476 /// dshape: Vec::new(),
477 /// normalized: Some(true),
478 /// };
479 /// let scores_config = configs::Scores {
480 /// decoder: configs::DecoderType::Ultralytics,
481 /// quantization: Some(configs::QuantTuple(0.012345, 14)),
482 /// shape: vec![1, 80, 8400],
483 /// dshape: Vec::new(),
484 /// };
485 /// let mask_config = configs::MaskCoefficients {
486 /// decoder: configs::DecoderType::Ultralytics,
487 /// quantization: Some(configs::QuantTuple(0.0064123, 125)),
488 /// shape: vec![1, 32, 8400],
489 /// dshape: Vec::new(),
490 /// };
491 /// let protos_config = configs::Protos {
492 /// decoder: configs::DecoderType::Ultralytics,
493 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
494 /// shape: vec![1, 160, 160, 32],
495 /// dshape: Vec::new(),
496 /// };
497 /// let decoder = DecoderBuilder::new()
498 /// .with_config_yolo_split_segdet(boxes_config, scores_config, mask_config, protos_config)
499 /// .build()?;
500 /// # Ok(())
501 /// # }
502 /// ```
503 pub fn with_config_yolo_split_segdet(
504 mut self,
505 boxes: configs::Boxes,
506 scores: configs::Scores,
507 mask_coefficients: configs::MaskCoefficients,
508 protos: configs::Protos,
509 ) -> Self {
510 let config = ConfigOutputs {
511 outputs: vec![
512 ConfigOutput::Boxes(boxes),
513 ConfigOutput::Scores(scores),
514 ConfigOutput::MaskCoefficients(mask_coefficients),
515 ConfigOutput::Protos(protos),
516 ],
517 ..Default::default()
518 };
519 self.config_src.replace(ConfigSource::Config(config));
520 self
521 }
522
523 /// Loads a ModelPack detection model configuration. Use
524 /// `DecoderBuilder.build()` to parse the model configuration.
525 ///
526 /// # Examples
527 /// ```rust
528 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
529 /// # fn main() -> DecoderResult<()> {
530 /// let boxes_config = configs::Boxes {
531 /// decoder: configs::DecoderType::ModelPack,
532 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
533 /// shape: vec![1, 8400, 1, 4],
534 /// dshape: Vec::new(),
535 /// normalized: Some(true),
536 /// };
537 /// let scores_config = configs::Scores {
538 /// decoder: configs::DecoderType::ModelPack,
539 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
540 /// shape: vec![1, 8400, 3],
541 /// dshape: Vec::new(),
542 /// };
543 /// let decoder = DecoderBuilder::new()
544 /// .with_config_modelpack_det(boxes_config, scores_config)
545 /// .build()?;
546 /// # Ok(())
547 /// # }
548 /// ```
549 pub fn with_config_modelpack_det(
550 mut self,
551 boxes: configs::Boxes,
552 scores: configs::Scores,
553 ) -> Self {
554 let config = ConfigOutputs {
555 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
556 ..Default::default()
557 };
558 self.config_src.replace(ConfigSource::Config(config));
559 self
560 }
561
562 /// Loads a ModelPack split detection model configuration. Use
563 /// `DecoderBuilder.build()` to parse the model configuration.
564 ///
565 /// # Examples
566 /// ```rust
567 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
568 /// # fn main() -> DecoderResult<()> {
569 /// let config0 = configs::Detection {
570 /// anchors: Some(vec![
571 /// [0.13750000298023224, 0.2074074000120163],
572 /// [0.2541666626930237, 0.21481481194496155],
573 /// [0.23125000298023224, 0.35185185074806213],
574 /// ]),
575 /// decoder: configs::DecoderType::ModelPack,
576 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
577 /// shape: vec![1, 17, 30, 18],
578 /// dshape: Vec::new(),
579 /// normalized: Some(true),
580 /// };
581 /// let config1 = configs::Detection {
582 /// anchors: Some(vec![
583 /// [0.36666667461395264, 0.31481480598449707],
584 /// [0.38749998807907104, 0.4740740656852722],
585 /// [0.5333333611488342, 0.644444465637207],
586 /// ]),
587 /// decoder: configs::DecoderType::ModelPack,
588 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
589 /// shape: vec![1, 9, 15, 18],
590 /// dshape: Vec::new(),
591 /// normalized: Some(true),
592 /// };
593 ///
594 /// let decoder = DecoderBuilder::new()
595 /// .with_config_modelpack_det_split(vec![config0, config1])
596 /// .build()?;
597 /// # Ok(())
598 /// # }
599 /// ```
600 pub fn with_config_modelpack_det_split(mut self, boxes: Vec<configs::Detection>) -> Self {
601 let outputs = boxes.into_iter().map(ConfigOutput::Detection).collect();
602 let config = ConfigOutputs {
603 outputs,
604 ..Default::default()
605 };
606 self.config_src.replace(ConfigSource::Config(config));
607 self
608 }
609
610 /// Loads a ModelPack segmentation detection model configuration. Use
611 /// `DecoderBuilder.build()` to parse the model configuration.
612 ///
613 /// # Examples
614 /// ```rust
615 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
616 /// # fn main() -> DecoderResult<()> {
617 /// let boxes_config = configs::Boxes {
618 /// decoder: configs::DecoderType::ModelPack,
619 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
620 /// shape: vec![1, 8400, 1, 4],
621 /// dshape: Vec::new(),
622 /// normalized: Some(true),
623 /// };
624 /// let scores_config = configs::Scores {
625 /// decoder: configs::DecoderType::ModelPack,
626 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
627 /// shape: vec![1, 8400, 2],
628 /// dshape: Vec::new(),
629 /// };
630 /// let seg_config = configs::Segmentation {
631 /// decoder: configs::DecoderType::ModelPack,
632 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
633 /// shape: vec![1, 640, 640, 3],
634 /// dshape: Vec::new(),
635 /// };
636 /// let decoder = DecoderBuilder::new()
637 /// .with_config_modelpack_segdet(boxes_config, scores_config, seg_config)
638 /// .build()?;
639 /// # Ok(())
640 /// # }
641 /// ```
642 pub fn with_config_modelpack_segdet(
643 mut self,
644 boxes: configs::Boxes,
645 scores: configs::Scores,
646 segmentation: configs::Segmentation,
647 ) -> Self {
648 let config = ConfigOutputs {
649 outputs: vec![
650 ConfigOutput::Boxes(boxes),
651 ConfigOutput::Scores(scores),
652 ConfigOutput::Segmentation(segmentation),
653 ],
654 ..Default::default()
655 };
656 self.config_src.replace(ConfigSource::Config(config));
657 self
658 }
659
660 /// Loads a ModelPack segmentation split detection model configuration. Use
661 /// `DecoderBuilder.build()` to parse the model configuration.
662 ///
663 /// # Examples
664 /// ```rust
665 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
666 /// # fn main() -> DecoderResult<()> {
667 /// let config0 = configs::Detection {
668 /// anchors: Some(vec![
669 /// [0.36666667461395264, 0.31481480598449707],
670 /// [0.38749998807907104, 0.4740740656852722],
671 /// [0.5333333611488342, 0.644444465637207],
672 /// ]),
673 /// decoder: configs::DecoderType::ModelPack,
674 /// quantization: Some(configs::QuantTuple(0.08547406643629074, 174)),
675 /// shape: vec![1, 9, 15, 18],
676 /// dshape: Vec::new(),
677 /// normalized: Some(true),
678 /// };
679 /// let config1 = configs::Detection {
680 /// anchors: Some(vec![
681 /// [0.13750000298023224, 0.2074074000120163],
682 /// [0.2541666626930237, 0.21481481194496155],
683 /// [0.23125000298023224, 0.35185185074806213],
684 /// ]),
685 /// decoder: configs::DecoderType::ModelPack,
686 /// quantization: Some(configs::QuantTuple(0.09929127991199493, 183)),
687 /// shape: vec![1, 17, 30, 18],
688 /// dshape: Vec::new(),
689 /// normalized: Some(true),
690 /// };
691 /// let seg_config = configs::Segmentation {
692 /// decoder: configs::DecoderType::ModelPack,
693 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
694 /// shape: vec![1, 640, 640, 2],
695 /// dshape: Vec::new(),
696 /// };
697 /// let decoder = DecoderBuilder::new()
698 /// .with_config_modelpack_segdet_split(vec![config0, config1], seg_config)
699 /// .build()?;
700 /// # Ok(())
701 /// # }
702 /// ```
703 pub fn with_config_modelpack_segdet_split(
704 mut self,
705 boxes: Vec<configs::Detection>,
706 segmentation: configs::Segmentation,
707 ) -> Self {
708 let mut outputs = boxes
709 .into_iter()
710 .map(ConfigOutput::Detection)
711 .collect::<Vec<_>>();
712 outputs.push(ConfigOutput::Segmentation(segmentation));
713 let config = ConfigOutputs {
714 outputs,
715 ..Default::default()
716 };
717 self.config_src.replace(ConfigSource::Config(config));
718 self
719 }
720
721 /// Loads a ModelPack segmentation model configuration. Use
722 /// `DecoderBuilder.build()` to parse the model configuration.
723 ///
724 /// # Examples
725 /// ```rust
726 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
727 /// # fn main() -> DecoderResult<()> {
728 /// let seg_config = configs::Segmentation {
729 /// decoder: configs::DecoderType::ModelPack,
730 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
731 /// shape: vec![1, 640, 640, 3],
732 /// dshape: Vec::new(),
733 /// };
734 /// let decoder = DecoderBuilder::new()
735 /// .with_config_modelpack_seg(seg_config)
736 /// .build()?;
737 /// # Ok(())
738 /// # }
739 /// ```
740 pub fn with_config_modelpack_seg(mut self, segmentation: configs::Segmentation) -> Self {
741 let config = ConfigOutputs {
742 outputs: vec![ConfigOutput::Segmentation(segmentation)],
743 ..Default::default()
744 };
745 self.config_src.replace(ConfigSource::Config(config));
746 self
747 }
748
749 /// Add an output to the decoder configuration.
750 ///
751 /// Incrementally builds the model configuration by adding outputs one at
752 /// a time. The decoder resolves the model type from the combination of
753 /// outputs during `build()`.
754 ///
755 /// If `dshape` is non-empty on the output, `shape` is automatically
756 /// derived from it (the size component of each named dimension). This
757 /// prevents conflicts between `shape` and `dshape`.
758 ///
759 /// This uses the programmatic config path. Calling this after
760 /// `with_config_json_str()` or `with_config_yaml_str()` replaces the
761 /// string-based config source.
762 ///
763 /// # Examples
764 /// ```rust
765 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
766 /// # fn main() -> DecoderResult<()> {
767 /// let decoder = DecoderBuilder::new()
768 /// .add_output(ConfigOutput::Scores(configs::Scores {
769 /// decoder: configs::DecoderType::Ultralytics,
770 /// dshape: vec![
771 /// (configs::DimName::Batch, 1),
772 /// (configs::DimName::NumClasses, 80),
773 /// (configs::DimName::NumBoxes, 8400),
774 /// ],
775 /// ..Default::default()
776 /// }))
777 /// .add_output(ConfigOutput::Boxes(configs::Boxes {
778 /// decoder: configs::DecoderType::Ultralytics,
779 /// dshape: vec![
780 /// (configs::DimName::Batch, 1),
781 /// (configs::DimName::BoxCoords, 4),
782 /// (configs::DimName::NumBoxes, 8400),
783 /// ],
784 /// ..Default::default()
785 /// }))
786 /// .build()?;
787 /// # Ok(())
788 /// # }
789 /// ```
790 pub fn add_output(mut self, output: ConfigOutput) -> Self {
791 if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
792 self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
793 }
794 if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
795 config.outputs.push(Self::normalize_output(output));
796 }
797 self
798 }
799
800 /// Sets the decoder version for Ultralytics models.
801 ///
802 /// This is used with `add_output()` to specify the YOLO architecture
803 /// version when it cannot be inferred from the output shapes alone.
804 ///
805 /// - `Yolov5`, `Yolov8`, `Yolo11`: Traditional models requiring external
806 /// NMS
807 /// - `Yolo26`: End-to-end models with NMS embedded in the model graph
808 ///
809 /// # Examples
810 /// ```rust
811 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
812 /// # fn main() -> DecoderResult<()> {
813 /// let decoder = DecoderBuilder::new()
814 /// .add_output(ConfigOutput::Detection(configs::Detection {
815 /// decoder: configs::DecoderType::Ultralytics,
816 /// dshape: vec![
817 /// (configs::DimName::Batch, 1),
818 /// (configs::DimName::NumBoxes, 100),
819 /// (configs::DimName::NumFeatures, 6),
820 /// ],
821 /// ..Default::default()
822 /// }))
823 /// .with_decoder_version(configs::DecoderVersion::Yolo26)
824 /// .build()?;
825 /// # Ok(())
826 /// # }
827 /// ```
828 pub fn with_decoder_version(mut self, version: configs::DecoderVersion) -> Self {
829 if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
830 self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
831 }
832 if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
833 config.decoder_version = Some(version);
834 }
835 self
836 }
837
838 /// Normalize an output: if dshape is non-empty, derive shape from it.
839 fn normalize_output(mut output: ConfigOutput) -> ConfigOutput {
840 fn normalize_shape(shape: &mut Vec<usize>, dshape: &[(configs::DimName, usize)]) {
841 if !dshape.is_empty() {
842 *shape = dshape.iter().map(|(_, size)| *size).collect();
843 }
844 }
845 match &mut output {
846 ConfigOutput::Detection(c) => normalize_shape(&mut c.shape, &c.dshape),
847 ConfigOutput::Boxes(c) => normalize_shape(&mut c.shape, &c.dshape),
848 ConfigOutput::Scores(c) => normalize_shape(&mut c.shape, &c.dshape),
849 ConfigOutput::Protos(c) => normalize_shape(&mut c.shape, &c.dshape),
850 ConfigOutput::Segmentation(c) => normalize_shape(&mut c.shape, &c.dshape),
851 ConfigOutput::MaskCoefficients(c) => normalize_shape(&mut c.shape, &c.dshape),
852 ConfigOutput::Mask(c) => normalize_shape(&mut c.shape, &c.dshape),
853 ConfigOutput::Classes(c) => normalize_shape(&mut c.shape, &c.dshape),
854 }
855 output
856 }
857
858 /// Sets the scores threshold of the decoder
859 ///
860 /// # Examples
861 /// ```rust
862 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
863 /// # fn main() -> DecoderResult<()> {
864 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
865 /// let decoder = DecoderBuilder::new()
866 /// .with_config_json_str(config_json)
867 /// .with_score_threshold(0.654)
868 /// .build()?;
869 /// assert_eq!(decoder.score_threshold, 0.654);
870 /// # Ok(())
871 /// # }
872 /// ```
873 pub fn with_score_threshold(mut self, score_threshold: f32) -> Self {
874 self.score_threshold = score_threshold;
875 self
876 }
877
878 /// Sets the IOU threshold of the decoder. Has no effect when NMS is set to
879 /// `None`
880 ///
881 /// # Examples
882 /// ```rust
883 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
884 /// # fn main() -> DecoderResult<()> {
885 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
886 /// let decoder = DecoderBuilder::new()
887 /// .with_config_json_str(config_json)
888 /// .with_iou_threshold(0.654)
889 /// .build()?;
890 /// assert_eq!(decoder.iou_threshold, 0.654);
891 /// # Ok(())
892 /// # }
893 /// ```
894 pub fn with_iou_threshold(mut self, iou_threshold: f32) -> Self {
895 self.iou_threshold = iou_threshold;
896 self
897 }
898
899 /// Sets the NMS mode for the decoder.
900 ///
901 /// - `Some(Nms::Auto)` — resolve from model config (e.g. `edgefirst.json`)
902 /// or fall back to `ClassAgnostic` (this is the builder default)
903 /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS: suppress overlapping
904 /// boxes regardless of class label
905 /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
906 /// share the same class label AND overlap above the IoU threshold
907 /// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
908 ///
909 /// # Examples
910 /// ```rust
911 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::Nms};
912 /// # fn main() -> DecoderResult<()> {
913 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
914 /// let decoder = DecoderBuilder::new()
915 /// .with_config_json_str(config_json)
916 /// .with_nms(Some(Nms::ClassAware))
917 /// .build()?;
918 /// assert_eq!(decoder.nms, Some(Nms::ClassAware));
919 /// # Ok(())
920 /// # }
921 /// ```
922 pub fn with_nms(mut self, nms: Option<configs::Nms>) -> Self {
923 self.nms = nms;
924 self
925 }
926
927 /// Sets the maximum number of candidate boxes fed into NMS after score
928 /// filtering. Uses partial sort (O(N)) to select the top-K candidates,
929 /// dramatically reducing the O(N²) NMS cost when many low-confidence
930 /// proposals pass the threshold (common with mAP eval at 0.001).
931 ///
932 /// Default: 300.
933 ///
934 /// # ⚠️ Validation vs Deployment
935 ///
936 /// The default is appropriate for **deployment** where
937 /// `score_threshold ≥ 0.25` means few anchors survive filtering and
938 /// top-K is effectively a no-op.
939 ///
940 /// For **COCO mAP evaluation** (`score_threshold ≈ 0.001`), set this to
941 /// the total anchor count (8 400 for standard 640 × 640 YOLO models) or
942 /// to `0` (no limit) so that all score-passing candidates reach NMS.
943 /// Failing to do so causes **~9 pp box mAP loss** — the decoder math is
944 /// correct but the evaluation protocol requires full recall across the
945 /// confidence range.
946 ///
947 /// Post-processing latency scales with candidate count. At deployment
948 /// thresholds the cost difference is negligible; at validation thresholds
949 /// it is measurable but necessary for correct results.
950 ///
951 /// # Examples
952 ///
953 /// Deployment (default top-K, high score threshold):
954 /// ```rust
955 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
956 /// # fn main() -> DecoderResult<()> {
957 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
958 /// let decoder = DecoderBuilder::new()
959 /// .with_config_json_str(config_json)
960 /// .with_score_threshold(0.25)
961 /// // pre_nms_top_k defaults to 300 — appropriate here
962 /// .build()?;
963 /// # Ok(())
964 /// # }
965 /// ```
966 ///
967 /// COCO mAP evaluation (pass all anchors to NMS):
968 /// ```rust
969 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
970 /// # fn main() -> DecoderResult<()> {
971 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
972 /// let decoder = DecoderBuilder::new()
973 /// .with_config_json_str(config_json)
974 /// .with_score_threshold(0.001)
975 /// .with_pre_nms_top_k(8400) // all YOLO anchors
976 /// .with_max_det(300)
977 /// .build()?;
978 /// assert_eq!(decoder.pre_nms_top_k, 8400);
979 /// # Ok(())
980 /// # }
981 /// ```
982 pub fn with_pre_nms_top_k(mut self, pre_nms_top_k: usize) -> Self {
983 self.pre_nms_top_k = pre_nms_top_k;
984 self
985 }
986
987 /// Sets the maximum number of detections returned after NMS.
988 /// Matches the Ultralytics `max_det` parameter.
989 ///
990 /// Default: 300.
991 ///
992 /// # Examples
993 /// ```rust
994 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
995 /// # fn main() -> DecoderResult<()> {
996 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
997 /// let decoder = DecoderBuilder::new()
998 /// .with_config_json_str(config_json)
999 /// .with_max_det(100)
1000 /// .build()?;
1001 /// assert_eq!(decoder.max_det, 100);
1002 /// # Ok(())
1003 /// # }
1004 /// ```
1005 pub fn with_max_det(mut self, max_det: usize) -> Self {
1006 self.max_det = max_det;
1007 self
1008 }
1009
1010 /// Sets the model input dimensions `(width, height)` consumed by the
1011 /// EDGEAI-1303 normalization path. Use this when building via
1012 /// [`with_config`](Self::with_config) / [`add_output`](Self::add_output)
1013 /// (no schema) and the model emits pixel-space boxes that need to be
1014 /// divided by `(W, H)` before NMS.
1015 ///
1016 /// When the builder is also configured with [`with_schema`](Self::with_schema)
1017 /// (or `with_config_json_str` / `with_config_yaml_str`) and the schema's
1018 /// `input` block carries usable dims, this explicit override **takes
1019 /// precedence** so callers can correct schemas with missing or wrong
1020 /// input specs without rewriting the schema.
1021 ///
1022 /// # Examples
1023 /// ```rust
1024 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
1025 /// # fn main() -> DecoderResult<()> {
1026 /// # let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
1027 /// let decoder = DecoderBuilder::new()
1028 /// .with_config_yaml_str(config_yaml)
1029 /// .with_input_dims(640, 640)
1030 /// .build()?;
1031 /// assert_eq!(decoder.input_dims(), Some((640, 640)));
1032 /// # Ok(())
1033 /// # }
1034 /// ```
1035 pub fn with_input_dims(mut self, width: usize, height: usize) -> Self {
1036 self.input_dims = Some((width, height));
1037 self
1038 }
1039
1040 /// Builds the decoder with the given settings. If the config is a JSON or
1041 /// YAML string, this will deserialize the JSON or YAML and then parse the
1042 /// configuration information.
1043 ///
1044 /// # Examples
1045 /// ```rust
1046 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
1047 /// # fn main() -> DecoderResult<()> {
1048 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
1049 /// let decoder = DecoderBuilder::new()
1050 /// .with_config_json_str(config_json)
1051 /// .with_score_threshold(0.654)
1052 /// .build()?;
1053 /// # Ok(())
1054 /// # }
1055 /// ```
1056 pub fn build(self) -> Result<Decoder, DecoderError> {
1057 let decode_dtype = self.decode_dtype;
1058 let explicit_input_dims = self.input_dims;
1059 let (config, decode_program, per_scale_plan, schema_input_dims) = match self.config_src {
1060 Some(ConfigSource::Json(s)) => {
1061 Self::build_from_schema(SchemaV2::parse_json(&s)?, decode_dtype)?
1062 }
1063 Some(ConfigSource::Yaml(s)) => {
1064 Self::build_from_schema(SchemaV2::parse_yaml(&s)?, decode_dtype)?
1065 }
1066 Some(ConfigSource::Config(c)) => (c, None, None, None),
1067 Some(ConfigSource::Schema(schema)) => Self::build_from_schema(schema, decode_dtype)?,
1068 None => return Err(DecoderError::NoConfig),
1069 };
1070 // Explicit `with_input_dims(W, H)` overrides any schema-derived
1071 // value so callers can fix schemas with missing or wrong input
1072 // specs without rewriting the schema (EDGEAI-1303).
1073 let input_dims = explicit_input_dims.or(schema_input_dims);
1074
1075 // Enforce the physical-order contract: when dshape is present
1076 // it must describe the same axes as shape in the same order,
1077 // listed from outermost to innermost. Ambiguous-layout roles
1078 // (Protos, Boxes, Scores, MaskCoefficients, Classes, Detection)
1079 // may still omit dshape when shape is already in the decoder's
1080 // canonical order.
1081 for output in &config.outputs {
1082 Decoder::validate_output_layout(output.into())?;
1083 }
1084
1085 // Extract normalized flag from config outputs.
1086 //
1087 // The per-scale subsystem (DFL/LTRB → dist2bbox → sigmoid) emits
1088 // boxes in pixel coordinates by design — `(grid + dist) * stride`
1089 // — independently of any `normalized: true` annotation in the
1090 // schema. The schema's `normalized` flag describes the model's
1091 // training-time convention, not the runtime output coord space
1092 // for this code path. Override to `Some(false)` when the
1093 // per-scale path is active so `Decoder::normalized_boxes()`
1094 // matches what `decode_proto`/`decode` actually produce; the
1095 // legacy / non-per-scale paths still honor the schema flag.
1096 let normalized = if per_scale_plan.is_some() {
1097 Some(false)
1098 } else {
1099 Self::get_normalized(&config.outputs)
1100 };
1101
1102 // NMS precedence:
1103 // Some(ClassAgnostic|ClassAware) → explicit user override
1104 // Some(Auto) → resolve from config, fallback to ClassAgnostic
1105 // None → NMS disabled (explicit)
1106 //
1107 // `Auto` is always resolved to a concrete mode here — it never
1108 // persists into the built `Decoder`, even if the config itself
1109 // contains `Auto`.
1110 let resolve_auto = |nms: Option<configs::Nms>| match nms {
1111 Some(configs::Nms::Auto) | None => Some(configs::Nms::ClassAgnostic),
1112 concrete => concrete,
1113 };
1114 let nms = match self.nms {
1115 Some(configs::Nms::Auto) => resolve_auto(config.nms),
1116 other => other,
1117 };
1118 // When the per-scale path is active, the per_scale subsystem owns
1119 // model decoding entirely — `decode` / `decode_proto` short-circuit
1120 // on `per_scale.is_some()` before reading `model_type`. Skip the
1121 // legacy ModelType validation, which otherwise rejects per-scale
1122 // schemas that carry `decoder_version: yolo26` (an
1123 // "end-to-end" hint) but use the per-scale split outputs rather
1124 // than the end-to-end split-output shape the legacy validator
1125 // expects. We keep a placeholder `ModelType` so the field remains
1126 // valid; it is dead state for per-scale Decoders.
1127 let model_type = if per_scale_plan.is_some() {
1128 // Drop the un-needed config; the per-scale subsystem owns it.
1129 drop(config);
1130 ModelType::PerScale
1131 } else {
1132 Self::get_model_type(config)?
1133 };
1134
1135 let per_scale = per_scale_plan
1136 .map(|plan| std::sync::Mutex::new(crate::per_scale::PerScaleDecoder::new(plan)));
1137
1138 debug_assert!(
1139 !matches!(nms, Some(configs::Nms::Auto)),
1140 "Nms::Auto must be resolved to a concrete mode before building Decoder"
1141 );
1142
1143 Ok(Decoder {
1144 model_type,
1145 iou_threshold: self.iou_threshold,
1146 score_threshold: self.score_threshold,
1147 nms,
1148 pre_nms_top_k: self.pre_nms_top_k,
1149 max_det: self.max_det,
1150 normalized,
1151 input_dims,
1152 decode_program,
1153 per_scale,
1154 })
1155 }
1156
1157 /// Validate a [`SchemaV2`] and lower it to the (legacy `ConfigOutputs`,
1158 /// optional `DecodeProgram`, optional `PerScalePlan`) tuple the rest
1159 /// of `build()` consumes.
1160 ///
1161 /// Centralises the v2 lowering so JSON, YAML, and direct
1162 /// `with_schema` callers all go through the same validation,
1163 /// merge-program, and per-scale plan construction. `SchemaV2::parse_json`
1164 /// / `parse_yaml` already auto-detect v1 vs v2 input and return a v2
1165 /// schema either way (v1 inputs are upgraded in memory via
1166 /// `from_v1`), so this helper is the sole place that turns a
1167 /// schema into builder-ready state.
1168 #[allow(clippy::type_complexity)]
1169 fn build_from_schema(
1170 schema: SchemaV2,
1171 decode_dtype: DecodeDtype,
1172 ) -> Result<
1173 (
1174 ConfigOutputs,
1175 Option<DecodeProgram>,
1176 Option<PerScalePlan>,
1177 Option<(usize, usize)>,
1178 ),
1179 DecoderError,
1180 > {
1181 schema.validate()?;
1182 let program = DecodeProgram::try_from_schema(&schema)?;
1183 let per_scale = PerScalePlan::try_from_schema(&schema, decode_dtype)?;
1184 // Extract model input (W, H) from `input.shape`/`dshape`. Used by
1185 // the legacy decode path to honour `normalized: false` (see
1186 // EDGEAI-1303). `None` is fine when the schema omits the input
1187 // spec — the decoder falls back to the protobox `>2.0` reject.
1188 let input_dims = schema.input.as_ref().and_then(input_dims_from_spec);
1189 let legacy = schema.to_legacy_config_outputs()?;
1190 Ok((legacy, program, per_scale, input_dims))
1191 }
1192
1193 /// Extracts the normalized flag from config outputs.
1194 /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
1195 /// - `Some(false)`: Boxes are in pixel coordinates
1196 /// - `None`: Unknown (not specified in config), caller must infer
1197 fn get_normalized(outputs: &[ConfigOutput]) -> Option<bool> {
1198 for output in outputs {
1199 match output {
1200 ConfigOutput::Detection(det) => return det.normalized,
1201 ConfigOutput::Boxes(boxes) => return boxes.normalized,
1202 _ => {}
1203 }
1204 }
1205 None // not specified
1206 }
1207
1208 fn get_model_type(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1209 // yolo or modelpack
1210 let mut yolo = false;
1211 let mut modelpack = false;
1212 for c in &configs.outputs {
1213 match c.decoder() {
1214 DecoderType::ModelPack => modelpack = true,
1215 DecoderType::Ultralytics => yolo = true,
1216 }
1217 }
1218 match (modelpack, yolo) {
1219 (true, true) => Err(DecoderError::InvalidConfig(
1220 "Both ModelPack and Yolo outputs found in config".to_string(),
1221 )),
1222 (true, false) => Self::get_model_type_modelpack(configs),
1223 (false, true) => Self::get_model_type_yolo(configs),
1224 (false, false) => Err(DecoderError::InvalidConfig(
1225 "No outputs found in config".to_string(),
1226 )),
1227 }
1228 }
1229
1230 fn get_model_type_yolo(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1231 let mut boxes = None;
1232 let mut protos = None;
1233 let mut split_boxes = None;
1234 let mut split_scores = None;
1235 let mut split_mask_coeff = None;
1236 let mut split_classes = None;
1237 for c in configs.outputs {
1238 match c {
1239 ConfigOutput::Detection(detection) => boxes = Some(detection),
1240 ConfigOutput::Segmentation(_) => {
1241 return Err(DecoderError::InvalidConfig(
1242 "Invalid Segmentation output with Yolo decoder".to_string(),
1243 ));
1244 }
1245 ConfigOutput::Protos(protos_) => protos = Some(protos_),
1246 ConfigOutput::Mask(_) => {
1247 return Err(DecoderError::InvalidConfig(
1248 "Invalid Mask output with Yolo decoder".to_string(),
1249 ));
1250 }
1251 ConfigOutput::Scores(scores) => split_scores = Some(scores),
1252 ConfigOutput::Boxes(boxes) => split_boxes = Some(boxes),
1253 ConfigOutput::MaskCoefficients(mask_coeff) => split_mask_coeff = Some(mask_coeff),
1254 ConfigOutput::Classes(classes) => split_classes = Some(classes),
1255 }
1256 }
1257
1258 // Use end-to-end model types when:
1259 // 1. decoder_version is explicitly set to Yolo26 (definitive), OR
1260 // decoder_version is not set but the dshapes are (batch, num_boxes,
1261 // num_features)
1262 let is_end_to_end_dshape = boxes.as_ref().is_some_and(|b| {
1263 let dims = b.dshape.iter().map(|(d, _)| *d).collect::<Vec<_>>();
1264 dims == vec![DimName::Batch, DimName::NumBoxes, DimName::NumFeatures]
1265 });
1266
1267 let is_end_to_end = configs
1268 .decoder_version
1269 .map(|v| v.is_end_to_end())
1270 .unwrap_or(is_end_to_end_dshape);
1271
1272 if is_end_to_end {
1273 if let Some(boxes) = boxes {
1274 if let Some(protos) = protos {
1275 Self::verify_yolo_seg_det_26(&boxes, &protos)?;
1276 return Ok(ModelType::YoloEndToEndSegDet { boxes, protos });
1277 } else {
1278 Self::verify_yolo_det_26(&boxes)?;
1279 return Ok(ModelType::YoloEndToEndDet { boxes });
1280 }
1281 } else if let (Some(split_boxes), Some(split_scores), Some(split_classes)) =
1282 (split_boxes, split_scores, split_classes)
1283 {
1284 if let (Some(split_mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
1285 Self::verify_yolo_split_end_to_end_segdet(
1286 &split_boxes,
1287 &split_scores,
1288 &split_classes,
1289 &split_mask_coeff,
1290 &protos,
1291 )?;
1292 return Ok(ModelType::YoloSplitEndToEndSegDet {
1293 boxes: split_boxes,
1294 scores: split_scores,
1295 classes: split_classes,
1296 mask_coeff: split_mask_coeff,
1297 protos,
1298 });
1299 }
1300 Self::verify_yolo_split_end_to_end_det(
1301 &split_boxes,
1302 &split_scores,
1303 &split_classes,
1304 )?;
1305 return Ok(ModelType::YoloSplitEndToEndDet {
1306 boxes: split_boxes,
1307 scores: split_scores,
1308 classes: split_classes,
1309 });
1310 } else {
1311 return Err(DecoderError::InvalidConfig(
1312 "Invalid Yolo end-to-end model outputs".to_string(),
1313 ));
1314 }
1315 }
1316
1317 if let Some(boxes) = boxes {
1318 match (split_mask_coeff, protos) {
1319 (Some(mask_coeff), Some(protos)) => {
1320 // 2-way split: combined detection + separate mask_coeff + protos
1321 Self::verify_yolo_seg_det_2way(&boxes, &mask_coeff, &protos)?;
1322 Ok(ModelType::YoloSegDet2Way {
1323 boxes,
1324 mask_coeff,
1325 protos,
1326 })
1327 }
1328 (_, Some(protos)) => {
1329 // Unsplit: mask_coefs embedded in detection tensor
1330 Self::verify_yolo_seg_det(&boxes, &protos)?;
1331 Ok(ModelType::YoloSegDet { boxes, protos })
1332 }
1333 _ => {
1334 Self::verify_yolo_det(&boxes)?;
1335 Ok(ModelType::YoloDet { boxes })
1336 }
1337 }
1338 } else if let (Some(boxes), Some(scores)) = (split_boxes, split_scores) {
1339 if let (Some(mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
1340 Self::verify_yolo_split_segdet(&boxes, &scores, &mask_coeff, &protos)?;
1341 Ok(ModelType::YoloSplitSegDet {
1342 boxes,
1343 scores,
1344 mask_coeff,
1345 protos,
1346 })
1347 } else {
1348 Self::verify_yolo_split_det(&boxes, &scores)?;
1349 Ok(ModelType::YoloSplitDet { boxes, scores })
1350 }
1351 } else {
1352 Err(DecoderError::InvalidConfig(
1353 "Invalid Yolo model outputs".to_string(),
1354 ))
1355 }
1356 }
1357
1358 fn verify_yolo_det(detect: &configs::Detection) -> Result<(), DecoderError> {
1359 if detect.shape.len() != 3 {
1360 return Err(DecoderError::InvalidConfig(format!(
1361 "Invalid Yolo Detection shape {:?}",
1362 detect.shape
1363 )));
1364 }
1365
1366 Self::verify_dshapes(
1367 &detect.dshape,
1368 &detect.shape,
1369 "Detection",
1370 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1371 )?;
1372 if !detect.dshape.is_empty() {
1373 Self::get_class_count(&detect.dshape, None, None)?;
1374 } else {
1375 Self::get_class_count_no_dshape(detect.into(), None)?;
1376 }
1377
1378 Ok(())
1379 }
1380
1381 fn verify_yolo_det_26(detect: &configs::Detection) -> Result<(), DecoderError> {
1382 if detect.shape.len() != 3 {
1383 return Err(DecoderError::InvalidConfig(format!(
1384 "Invalid Yolo Detection shape {:?}",
1385 detect.shape
1386 )));
1387 }
1388
1389 Self::verify_dshapes(
1390 &detect.dshape,
1391 &detect.shape,
1392 "Detection",
1393 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1394 )?;
1395
1396 if !detect.shape.contains(&6) {
1397 return Err(DecoderError::InvalidConfig(
1398 "Yolo26 Detection must have 6 features".to_string(),
1399 ));
1400 }
1401
1402 Ok(())
1403 }
1404
1405 fn verify_yolo_seg_det(
1406 detection: &configs::Detection,
1407 protos: &configs::Protos,
1408 ) -> Result<(), DecoderError> {
1409 if detection.shape.len() != 3 {
1410 return Err(DecoderError::InvalidConfig(format!(
1411 "Invalid Yolo Detection shape {:?}",
1412 detection.shape
1413 )));
1414 }
1415 if protos.shape.len() != 4 {
1416 return Err(DecoderError::InvalidConfig(format!(
1417 "Invalid Yolo Protos shape {:?}",
1418 protos.shape
1419 )));
1420 }
1421
1422 Self::verify_dshapes(
1423 &detection.dshape,
1424 &detection.shape,
1425 "Detection",
1426 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1427 )?;
1428 Self::verify_dshapes(
1429 &protos.dshape,
1430 &protos.shape,
1431 "Protos",
1432 &[
1433 DimName::Batch,
1434 DimName::Height,
1435 DimName::Width,
1436 DimName::NumProtos,
1437 ],
1438 )?;
1439
1440 let protos_count = Self::get_protos_count(&protos.dshape)
1441 .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1442 log::debug!("Protos count: {}", protos_count);
1443 log::debug!("Detection dshape: {:?}", detection.dshape);
1444 let classes = if !detection.dshape.is_empty() {
1445 Self::get_class_count(&detection.dshape, Some(protos_count), None)?
1446 } else {
1447 Self::get_class_count_no_dshape(detection.into(), Some(protos_count))?
1448 };
1449
1450 if classes == 0 {
1451 return Err(DecoderError::InvalidConfig(
1452 "Yolo Segmentation Detection has zero classes".to_string(),
1453 ));
1454 }
1455
1456 Ok(())
1457 }
1458
1459 fn verify_yolo_seg_det_2way(
1460 detection: &configs::Detection,
1461 mask_coeff: &configs::MaskCoefficients,
1462 protos: &configs::Protos,
1463 ) -> Result<(), DecoderError> {
1464 if detection.shape.len() != 3 {
1465 return Err(DecoderError::InvalidConfig(format!(
1466 "Invalid Yolo 2-Way Detection shape {:?}",
1467 detection.shape
1468 )));
1469 }
1470 if mask_coeff.shape.len() != 3 {
1471 return Err(DecoderError::InvalidConfig(format!(
1472 "Invalid Yolo 2-Way Mask Coefficients shape {:?}",
1473 mask_coeff.shape
1474 )));
1475 }
1476 if protos.shape.len() != 4 {
1477 return Err(DecoderError::InvalidConfig(format!(
1478 "Invalid Yolo 2-Way Protos shape {:?}",
1479 protos.shape
1480 )));
1481 }
1482
1483 Self::verify_dshapes(
1484 &detection.dshape,
1485 &detection.shape,
1486 "Detection",
1487 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1488 )?;
1489 Self::verify_dshapes(
1490 &mask_coeff.dshape,
1491 &mask_coeff.shape,
1492 "Mask Coefficients",
1493 &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1494 )?;
1495 Self::verify_dshapes(
1496 &protos.dshape,
1497 &protos.shape,
1498 "Protos",
1499 &[
1500 DimName::Batch,
1501 DimName::Height,
1502 DimName::Width,
1503 DimName::NumProtos,
1504 ],
1505 )?;
1506
1507 // Validate num_boxes match between detection and mask_coeff
1508 let det_num = Self::get_box_count(&detection.dshape).unwrap_or(detection.shape[2]);
1509 let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1510 if det_num != mask_num {
1511 return Err(DecoderError::InvalidConfig(format!(
1512 "Yolo 2-Way Detection num_boxes {} incompatible with Mask Coefficients num_boxes {}",
1513 det_num, mask_num
1514 )));
1515 }
1516
1517 // Validate mask_coeff channels match protos channels
1518 let mask_channels = if !mask_coeff.dshape.is_empty() {
1519 Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1520 DecoderError::InvalidConfig(
1521 "Could not find num_protos in mask_coeff config".to_string(),
1522 )
1523 })?
1524 } else {
1525 mask_coeff.shape[1]
1526 };
1527 let proto_channels = if !protos.dshape.is_empty() {
1528 Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1529 DecoderError::InvalidConfig(
1530 "Could not find num_protos in protos config".to_string(),
1531 )
1532 })?
1533 } else {
1534 protos.shape[1].min(protos.shape[3])
1535 };
1536 if mask_channels != proto_channels {
1537 return Err(DecoderError::InvalidConfig(format!(
1538 "Yolo 2-Way Protos channels {} incompatible with Mask Coefficients channels {}",
1539 proto_channels, mask_channels
1540 )));
1541 }
1542
1543 // Validate detection has classes (nc+4 features, no mask_coefs embedded)
1544 if !detection.dshape.is_empty() {
1545 Self::get_class_count(&detection.dshape, None, None)?;
1546 } else {
1547 Self::get_class_count_no_dshape(detection.into(), None)?;
1548 }
1549
1550 Ok(())
1551 }
1552
1553 fn verify_yolo_seg_det_26(
1554 detection: &configs::Detection,
1555 protos: &configs::Protos,
1556 ) -> Result<(), DecoderError> {
1557 if detection.shape.len() != 3 {
1558 return Err(DecoderError::InvalidConfig(format!(
1559 "Invalid Yolo Detection shape {:?}",
1560 detection.shape
1561 )));
1562 }
1563 if protos.shape.len() != 4 {
1564 return Err(DecoderError::InvalidConfig(format!(
1565 "Invalid Yolo Protos shape {:?}",
1566 protos.shape
1567 )));
1568 }
1569
1570 Self::verify_dshapes(
1571 &detection.dshape,
1572 &detection.shape,
1573 "Detection",
1574 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1575 )?;
1576 Self::verify_dshapes(
1577 &protos.dshape,
1578 &protos.shape,
1579 "Protos",
1580 &[
1581 DimName::Batch,
1582 DimName::Height,
1583 DimName::Width,
1584 DimName::NumProtos,
1585 ],
1586 )?;
1587
1588 let protos_count = Self::get_protos_count(&protos.dshape)
1589 .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1590 log::debug!("Protos count: {}", protos_count);
1591 log::debug!("Detection dshape: {:?}", detection.dshape);
1592
1593 if !detection.shape.contains(&(6 + protos_count)) {
1594 return Err(DecoderError::InvalidConfig(format!(
1595 "Yolo26 Segmentation Detection must have num_features be 6 + num_protos = {}",
1596 6 + protos_count
1597 )));
1598 }
1599
1600 Ok(())
1601 }
1602
1603 fn verify_yolo_split_det(
1604 boxes: &configs::Boxes,
1605 scores: &configs::Scores,
1606 ) -> Result<(), DecoderError> {
1607 if boxes.shape.len() != 3 {
1608 return Err(DecoderError::InvalidConfig(format!(
1609 "Invalid Yolo Split Boxes shape {:?}",
1610 boxes.shape
1611 )));
1612 }
1613 if scores.shape.len() != 3 {
1614 return Err(DecoderError::InvalidConfig(format!(
1615 "Invalid Yolo Split Scores shape {:?}",
1616 scores.shape
1617 )));
1618 }
1619
1620 Self::verify_dshapes(
1621 &boxes.dshape,
1622 &boxes.shape,
1623 "Boxes",
1624 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1625 )?;
1626 Self::verify_dshapes(
1627 &scores.dshape,
1628 &scores.shape,
1629 "Scores",
1630 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1631 )?;
1632
1633 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1634 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1635
1636 if boxes_num != scores_num {
1637 return Err(DecoderError::InvalidConfig(format!(
1638 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1639 boxes_num, scores_num
1640 )));
1641 }
1642
1643 Ok(())
1644 }
1645
1646 fn verify_yolo_split_segdet(
1647 boxes: &configs::Boxes,
1648 scores: &configs::Scores,
1649 mask_coeff: &configs::MaskCoefficients,
1650 protos: &configs::Protos,
1651 ) -> Result<(), DecoderError> {
1652 if boxes.shape.len() != 3 {
1653 return Err(DecoderError::InvalidConfig(format!(
1654 "Invalid Yolo Split Boxes shape {:?}",
1655 boxes.shape
1656 )));
1657 }
1658 if scores.shape.len() != 3 {
1659 return Err(DecoderError::InvalidConfig(format!(
1660 "Invalid Yolo Split Scores shape {:?}",
1661 scores.shape
1662 )));
1663 }
1664
1665 if mask_coeff.shape.len() != 3 {
1666 return Err(DecoderError::InvalidConfig(format!(
1667 "Invalid Yolo Split Mask Coefficients shape {:?}",
1668 mask_coeff.shape
1669 )));
1670 }
1671
1672 if protos.shape.len() != 4 {
1673 return Err(DecoderError::InvalidConfig(format!(
1674 "Invalid Yolo Protos shape {:?}",
1675 mask_coeff.shape
1676 )));
1677 }
1678
1679 Self::verify_dshapes(
1680 &boxes.dshape,
1681 &boxes.shape,
1682 "Boxes",
1683 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1684 )?;
1685 Self::verify_dshapes(
1686 &scores.dshape,
1687 &scores.shape,
1688 "Scores",
1689 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1690 )?;
1691 Self::verify_dshapes(
1692 &mask_coeff.dshape,
1693 &mask_coeff.shape,
1694 "Mask Coefficients",
1695 &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1696 )?;
1697 Self::verify_dshapes(
1698 &protos.dshape,
1699 &protos.shape,
1700 "Protos",
1701 &[
1702 DimName::Batch,
1703 DimName::Height,
1704 DimName::Width,
1705 DimName::NumProtos,
1706 ],
1707 )?;
1708
1709 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1710 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1711 let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1712
1713 let mask_channels = if !mask_coeff.dshape.is_empty() {
1714 Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1715 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1716 })?
1717 } else {
1718 mask_coeff.shape[1]
1719 };
1720 let proto_channels = if !protos.dshape.is_empty() {
1721 Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1722 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1723 })?
1724 } else {
1725 protos.shape[1].min(protos.shape[3])
1726 };
1727
1728 if boxes_num != scores_num {
1729 return Err(DecoderError::InvalidConfig(format!(
1730 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1731 boxes_num, scores_num
1732 )));
1733 }
1734
1735 if boxes_num != mask_num {
1736 return Err(DecoderError::InvalidConfig(format!(
1737 "Yolo Split Detection Boxes num {} incompatible with Mask Coefficients num {}",
1738 boxes_num, mask_num
1739 )));
1740 }
1741
1742 if proto_channels != mask_channels {
1743 return Err(DecoderError::InvalidConfig(format!(
1744 "Yolo Protos channels {} incompatible with Mask Coefficients channels {}",
1745 proto_channels, mask_channels
1746 )));
1747 }
1748
1749 Ok(())
1750 }
1751
1752 fn verify_yolo_split_end_to_end_det(
1753 boxes: &configs::Boxes,
1754 scores: &configs::Scores,
1755 classes: &configs::Classes,
1756 ) -> Result<(), DecoderError> {
1757 if boxes.shape.len() != 3 || !boxes.shape.contains(&4) {
1758 return Err(DecoderError::InvalidConfig(format!(
1759 "Split end-to-end boxes must be [batch, N, 4], got {:?}",
1760 boxes.shape
1761 )));
1762 }
1763 if scores.shape.len() != 3 || !scores.shape.contains(&1) {
1764 return Err(DecoderError::InvalidConfig(format!(
1765 "Split end-to-end scores must be [batch, N, 1], got {:?}",
1766 scores.shape
1767 )));
1768 }
1769 if classes.shape.len() != 3 || !classes.shape.contains(&1) {
1770 return Err(DecoderError::InvalidConfig(format!(
1771 "Split end-to-end classes must be [batch, N, 1], got {:?}",
1772 classes.shape
1773 )));
1774 }
1775 Ok(())
1776 }
1777
1778 fn verify_yolo_split_end_to_end_segdet(
1779 boxes: &configs::Boxes,
1780 scores: &configs::Scores,
1781 classes: &configs::Classes,
1782 mask_coeff: &configs::MaskCoefficients,
1783 protos: &configs::Protos,
1784 ) -> Result<(), DecoderError> {
1785 Self::verify_yolo_split_end_to_end_det(boxes, scores, classes)?;
1786 if mask_coeff.shape.len() != 3 {
1787 return Err(DecoderError::InvalidConfig(format!(
1788 "Invalid split end-to-end mask coefficients shape {:?}",
1789 mask_coeff.shape
1790 )));
1791 }
1792 if protos.shape.len() != 4 {
1793 return Err(DecoderError::InvalidConfig(format!(
1794 "Invalid protos shape {:?}",
1795 protos.shape
1796 )));
1797 }
1798 Ok(())
1799 }
1800
1801 fn get_model_type_modelpack(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1802 let mut split_decoders = Vec::new();
1803 let mut segment_ = None;
1804 let mut scores_ = None;
1805 let mut boxes_ = None;
1806 for c in configs.outputs {
1807 match c {
1808 ConfigOutput::Detection(detection) => split_decoders.push(detection),
1809 ConfigOutput::Segmentation(segmentation) => segment_ = Some(segmentation),
1810 ConfigOutput::Mask(_) => {}
1811 ConfigOutput::Protos(_) => {
1812 return Err(DecoderError::InvalidConfig(
1813 "ModelPack should not have protos".to_string(),
1814 ));
1815 }
1816 ConfigOutput::Scores(scores) => scores_ = Some(scores),
1817 ConfigOutput::Boxes(boxes) => boxes_ = Some(boxes),
1818 ConfigOutput::MaskCoefficients(_) => {
1819 return Err(DecoderError::InvalidConfig(
1820 "ModelPack should not have mask coefficients".to_string(),
1821 ));
1822 }
1823 ConfigOutput::Classes(_) => {
1824 return Err(DecoderError::InvalidConfig(
1825 "ModelPack should not have classes output".to_string(),
1826 ));
1827 }
1828 }
1829 }
1830
1831 if let Some(segmentation) = segment_ {
1832 if !split_decoders.is_empty() {
1833 let classes = Self::verify_modelpack_split_det(&split_decoders)?;
1834 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1835 Ok(ModelType::ModelPackSegDetSplit {
1836 detection: split_decoders,
1837 segmentation,
1838 })
1839 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1840 let classes = Self::verify_modelpack_det(&boxes, &scores)?;
1841 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1842 Ok(ModelType::ModelPackSegDet {
1843 boxes,
1844 scores,
1845 segmentation,
1846 })
1847 } else {
1848 Self::verify_modelpack_seg(&segmentation, None)?;
1849 Ok(ModelType::ModelPackSeg { segmentation })
1850 }
1851 } else if !split_decoders.is_empty() {
1852 Self::verify_modelpack_split_det(&split_decoders)?;
1853 Ok(ModelType::ModelPackDetSplit {
1854 detection: split_decoders,
1855 })
1856 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1857 Self::verify_modelpack_det(&boxes, &scores)?;
1858 Ok(ModelType::ModelPackDet { boxes, scores })
1859 } else {
1860 Err(DecoderError::InvalidConfig(
1861 "Invalid ModelPack model outputs".to_string(),
1862 ))
1863 }
1864 }
1865
1866 fn verify_modelpack_det(
1867 boxes: &configs::Boxes,
1868 scores: &configs::Scores,
1869 ) -> Result<usize, DecoderError> {
1870 if boxes.shape.len() != 4 {
1871 return Err(DecoderError::InvalidConfig(format!(
1872 "Invalid ModelPack Boxes shape {:?}",
1873 boxes.shape
1874 )));
1875 }
1876 if scores.shape.len() != 3 {
1877 return Err(DecoderError::InvalidConfig(format!(
1878 "Invalid ModelPack Scores shape {:?}",
1879 scores.shape
1880 )));
1881 }
1882
1883 Self::verify_dshapes(
1884 &boxes.dshape,
1885 &boxes.shape,
1886 "Boxes",
1887 &[
1888 DimName::Batch,
1889 DimName::NumBoxes,
1890 DimName::Padding,
1891 DimName::BoxCoords,
1892 ],
1893 )?;
1894 Self::verify_dshapes(
1895 &scores.dshape,
1896 &scores.shape,
1897 "Scores",
1898 &[DimName::Batch, DimName::NumBoxes, DimName::NumClasses],
1899 )?;
1900
1901 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[1]);
1902 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[1]);
1903
1904 if boxes_num != scores_num {
1905 return Err(DecoderError::InvalidConfig(format!(
1906 "ModelPack Detection Boxes num {} incompatible with Scores num {}",
1907 boxes_num, scores_num
1908 )));
1909 }
1910
1911 let num_classes = if !scores.dshape.is_empty() {
1912 Self::get_class_count(&scores.dshape, None, None)?
1913 } else {
1914 Self::get_class_count_no_dshape(scores.into(), None)?
1915 };
1916
1917 Ok(num_classes)
1918 }
1919
1920 fn verify_modelpack_split_det(boxes: &[configs::Detection]) -> Result<usize, DecoderError> {
1921 let mut num_classes = None;
1922 for b in boxes {
1923 let Some(num_anchors) = b.anchors.as_ref().map(|a| a.len()) else {
1924 return Err(DecoderError::InvalidConfig(
1925 "ModelPack Split Detection missing anchors".to_string(),
1926 ));
1927 };
1928
1929 if num_anchors == 0 {
1930 return Err(DecoderError::InvalidConfig(
1931 "ModelPack Split Detection has zero anchors".to_string(),
1932 ));
1933 }
1934
1935 if b.shape.len() != 4 {
1936 return Err(DecoderError::InvalidConfig(format!(
1937 "Invalid ModelPack Split Detection shape {:?}",
1938 b.shape
1939 )));
1940 }
1941
1942 Self::verify_dshapes(
1943 &b.dshape,
1944 &b.shape,
1945 "Split Detection",
1946 &[
1947 DimName::Batch,
1948 DimName::Height,
1949 DimName::Width,
1950 DimName::NumAnchorsXFeatures,
1951 ],
1952 )?;
1953 let classes = if !b.dshape.is_empty() {
1954 Self::get_class_count(&b.dshape, None, Some(num_anchors))?
1955 } else {
1956 Self::get_class_count_no_dshape(b.into(), None)?
1957 };
1958
1959 match num_classes {
1960 Some(n) => {
1961 if n != classes {
1962 return Err(DecoderError::InvalidConfig(format!(
1963 "ModelPack Split Detection inconsistent number of classes: previous {}, current {}",
1964 n, classes
1965 )));
1966 }
1967 }
1968 None => {
1969 num_classes = Some(classes);
1970 }
1971 }
1972 }
1973
1974 Ok(num_classes.unwrap_or(0))
1975 }
1976
1977 fn verify_modelpack_seg(
1978 segmentation: &configs::Segmentation,
1979 classes: Option<usize>,
1980 ) -> Result<(), DecoderError> {
1981 if segmentation.shape.len() != 4 {
1982 return Err(DecoderError::InvalidConfig(format!(
1983 "Invalid ModelPack Segmentation shape {:?}",
1984 segmentation.shape
1985 )));
1986 }
1987 Self::verify_dshapes(
1988 &segmentation.dshape,
1989 &segmentation.shape,
1990 "Segmentation",
1991 &[
1992 DimName::Batch,
1993 DimName::Height,
1994 DimName::Width,
1995 DimName::NumClasses,
1996 ],
1997 )?;
1998
1999 if let Some(classes) = classes {
2000 let seg_classes = if !segmentation.dshape.is_empty() {
2001 Self::get_class_count(&segmentation.dshape, None, None)?
2002 } else {
2003 Self::get_class_count_no_dshape(segmentation.into(), None)?
2004 };
2005
2006 if seg_classes != classes + 1 {
2007 return Err(DecoderError::InvalidConfig(format!(
2008 "ModelPack Segmentation channels {} incompatible with number of classes {}",
2009 seg_classes, classes
2010 )));
2011 }
2012 }
2013 Ok(())
2014 }
2015
2016 // verifies that dshapes match the given shape
2017 fn verify_dshapes(
2018 dshape: &[(DimName, usize)],
2019 shape: &[usize],
2020 name: &str,
2021 dims: &[DimName],
2022 ) -> Result<(), DecoderError> {
2023 for s in shape {
2024 if *s == 0 {
2025 return Err(DecoderError::InvalidConfig(format!(
2026 "{} shape has zero dimension",
2027 name
2028 )));
2029 }
2030 }
2031
2032 if shape.len() != dims.len() {
2033 return Err(DecoderError::InvalidConfig(format!(
2034 "{} shape length {} does not match expected dims length {}",
2035 name,
2036 shape.len(),
2037 dims.len()
2038 )));
2039 }
2040
2041 if dshape.is_empty() {
2042 return Ok(());
2043 }
2044 // check the dshape lengths match the shape lengths
2045 if dshape.len() != shape.len() {
2046 return Err(DecoderError::InvalidConfig(format!(
2047 "{} dshape length does not match shape length",
2048 name
2049 )));
2050 }
2051
2052 // check the dshape values match the shape values
2053 for ((dim_name, dim_size), shape_size) in dshape.iter().zip(shape) {
2054 if dim_size != shape_size {
2055 return Err(DecoderError::InvalidConfig(format!(
2056 "{} dshape dimension {} size {} does not match shape size {}",
2057 name, dim_name, dim_size, shape_size
2058 )));
2059 }
2060 if *dim_name == DimName::Padding && *dim_size != 1 {
2061 return Err(DecoderError::InvalidConfig(
2062 "Padding dimension size must be 1".to_string(),
2063 ));
2064 }
2065
2066 if *dim_name == DimName::BoxCoords && *dim_size != 4 {
2067 return Err(DecoderError::InvalidConfig(
2068 "BoxCoords dimension size must be 4".to_string(),
2069 ));
2070 }
2071 }
2072
2073 let dims_present = HashSet::<DimName>::from_iter(dshape.iter().map(|(name, _)| *name));
2074 for dim in dims {
2075 if !dims_present.contains(dim) {
2076 return Err(DecoderError::InvalidConfig(format!(
2077 "{} dshape missing required dimension {:?}",
2078 name, dim
2079 )));
2080 }
2081 }
2082
2083 Ok(())
2084 }
2085
2086 fn get_box_count(dshape: &[(DimName, usize)]) -> Option<usize> {
2087 for (dim_name, dim_size) in dshape {
2088 if *dim_name == DimName::NumBoxes {
2089 return Some(*dim_size);
2090 }
2091 }
2092 None
2093 }
2094
2095 fn get_class_count_no_dshape(
2096 config: ConfigOutputRef,
2097 protos: Option<usize>,
2098 ) -> Result<usize, DecoderError> {
2099 match config {
2100 ConfigOutputRef::Detection(detection) => match detection.decoder {
2101 DecoderType::Ultralytics => {
2102 if detection.shape[1] <= 4 + protos.unwrap_or(0) {
2103 return Err(DecoderError::InvalidConfig(format!(
2104 "Invalid shape: Yolo num_features {} must be greater than {}",
2105 detection.shape[1],
2106 4 + protos.unwrap_or(0),
2107 )));
2108 }
2109 Ok(detection.shape[1] - 4 - protos.unwrap_or(0))
2110 }
2111 DecoderType::ModelPack => {
2112 let Some(num_anchors) = detection.anchors.as_ref().map(|a| a.len()) else {
2113 return Err(DecoderError::Internal(
2114 "ModelPack Detection missing anchors".to_string(),
2115 ));
2116 };
2117 let anchors_x_features = detection.shape[3];
2118 if anchors_x_features <= num_anchors * 5 {
2119 return Err(DecoderError::InvalidConfig(format!(
2120 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2121 anchors_x_features,
2122 num_anchors * 5,
2123 )));
2124 }
2125
2126 if !anchors_x_features.is_multiple_of(num_anchors) {
2127 return Err(DecoderError::InvalidConfig(format!(
2128 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2129 anchors_x_features, num_anchors
2130 )));
2131 }
2132 Ok(anchors_x_features / num_anchors - 5)
2133 }
2134 },
2135
2136 ConfigOutputRef::Scores(scores) => match scores.decoder {
2137 DecoderType::Ultralytics => Ok(scores.shape[1]),
2138 DecoderType::ModelPack => Ok(scores.shape[2]),
2139 },
2140 ConfigOutputRef::Segmentation(seg) => Ok(seg.shape[3]),
2141 _ => Err(DecoderError::Internal(
2142 "Attempted to get class count from unsupported config output".to_owned(),
2143 )),
2144 }
2145 }
2146
2147 // get the class count from dshape or calculate from num_features
2148 fn get_class_count(
2149 dshape: &[(DimName, usize)],
2150 protos: Option<usize>,
2151 anchors: Option<usize>,
2152 ) -> Result<usize, DecoderError> {
2153 if dshape.is_empty() {
2154 return Ok(0);
2155 }
2156 // if it has num_classes in dshape, return it
2157 for (dim_name, dim_size) in dshape {
2158 if *dim_name == DimName::NumClasses {
2159 return Ok(*dim_size);
2160 }
2161 }
2162
2163 // number of classes can be calculated from num_features - 4 for yolo. If the
2164 // model has protos, we also subtract the number of protos.
2165 for (dim_name, dim_size) in dshape {
2166 if *dim_name == DimName::NumFeatures {
2167 let protos = protos.unwrap_or(0);
2168 if protos + 4 >= *dim_size {
2169 return Err(DecoderError::InvalidConfig(format!(
2170 "Invalid shape: Yolo num_features {} must be greater than {}",
2171 *dim_size,
2172 protos + 4,
2173 )));
2174 }
2175 return Ok(*dim_size - 4 - protos);
2176 }
2177 }
2178
2179 // number of classes can be calculated from number of anchors for modelpack
2180 // split detection
2181 if let Some(num_anchors) = anchors {
2182 for (dim_name, dim_size) in dshape {
2183 if *dim_name == DimName::NumAnchorsXFeatures {
2184 let anchors_x_features = *dim_size;
2185 if anchors_x_features <= num_anchors * 5 {
2186 return Err(DecoderError::InvalidConfig(format!(
2187 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2188 anchors_x_features,
2189 num_anchors * 5,
2190 )));
2191 }
2192
2193 if !anchors_x_features.is_multiple_of(num_anchors) {
2194 return Err(DecoderError::InvalidConfig(format!(
2195 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2196 anchors_x_features, num_anchors
2197 )));
2198 }
2199 return Ok((anchors_x_features / num_anchors) - 5);
2200 }
2201 }
2202 }
2203 Err(DecoderError::InvalidConfig(
2204 "Cannot determine number of classes from dshape".to_owned(),
2205 ))
2206 }
2207
2208 fn get_protos_count(dshape: &[(DimName, usize)]) -> Option<usize> {
2209 for (dim_name, dim_size) in dshape {
2210 if *dim_name == DimName::NumProtos {
2211 return Some(*dim_size);
2212 }
2213 }
2214 None
2215 }
2216}