edgefirst_decoder/decoder/builder.rs
1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashSet;
5
6use super::config::ConfigOutputRef;
7use super::configs::{self, DecoderType, DecoderVersion, DimName, ModelType};
8use super::merge::DecodeProgram;
9use super::{ConfigOutput, ConfigOutputs, Decoder};
10use crate::per_scale::{DecodeDtype, PerScalePlan};
11use crate::schema::SchemaV2;
12use crate::DecoderError;
13
14/// Extract `(width, height)` from a schema [`crate::schema::InputSpec`].
15///
16/// Prefers named dims (`DimName::Width` / `DimName::Height`) when the
17/// `dshape` is populated, falling back to the NHWC convention
18/// (`shape[1] = H, shape[2] = W`) for 4-element shapes whenever either
19/// named dim is missing. Returns `None` for any other shape arity — the
20/// decoder treats that as "input dims unknown" and skips the EDGEAI-1303
21/// normalization path.
22///
23/// The fallback fires when **either** `Height` or `Width` is missing from
24/// the dshape (not only when both are absent), so a partially-named
25/// dshape (e.g. only `Width`) still resolves both dims via positional
26/// inference instead of silently disabling normalization.
27fn input_dims_from_spec(input: &crate::schema::InputSpec) -> Option<(usize, usize)> {
28 use crate::configs::DimName;
29 let mut h = None;
30 let mut w = None;
31 // `SchemaV2::validate()` doesn't currently enforce
32 // `dshape.len() <= shape.len()`, so a malformed schema can trip an
33 // out-of-bounds index here. Use `shape.get(i)` to silently skip
34 // dshape entries that overflow the shape — the caller treats
35 // missing dims as "unknown" and disables EDGEAI-1303 normalization.
36 for (i, (name, _)) in input.dshape.iter().enumerate() {
37 match name {
38 DimName::Height => h = input.shape.get(i).copied(),
39 DimName::Width => w = input.shape.get(i).copied(),
40 _ => {}
41 }
42 }
43 if (h.is_none() || w.is_none()) && input.shape.len() == 4 {
44 // NHWC default: [N, H, W, C]. Mirrors the per-scale `extract_hw`
45 // fallback (`crates/decoder/src/per_scale/plan.rs::extract_hw`).
46 // Only fill the missing axis so a partial named dshape still
47 // resolves both dims.
48 h = h.or(Some(input.shape[1]));
49 w = w.or(Some(input.shape[2]));
50 }
51 match (w, h) {
52 (Some(w), Some(h)) => Some((w, h)),
53 _ => None,
54 }
55}
56
#[cfg(test)]
mod input_dims_from_spec_tests {
    //! Unit tests for `input_dims_from_spec`: named-dim resolution, the
    //! NHWC positional fallback, and robustness against malformed dshapes.

    use super::input_dims_from_spec;
    use crate::configs::DimName;
    use crate::schema::InputSpec;

    #[test]
    fn named_dshape_resolves_dims() {
        let spec = InputSpec {
            shape: vec![1, 480, 640, 3],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::Height, 480),
                (DimName::Width, 640),
                (DimName::NumFeatures, 3),
            ],
            cameraadaptor: None,
        };
        assert_eq!(input_dims_from_spec(&spec), Some((640, 480)));
    }

    #[test]
    fn empty_dshape_falls_back_to_nhwc_for_4d() {
        let spec = InputSpec {
            shape: vec![1, 480, 640, 3],
            dshape: vec![],
            cameraadaptor: None,
        };
        assert_eq!(input_dims_from_spec(&spec), Some((640, 480)));
    }

    #[test]
    fn partial_named_dshape_fills_missing_dim_from_nhwc() {
        // Only `Height` is named. `Width` must come from the NHWC
        // positional fallback (`shape[2]`) rather than the whole lookup
        // collapsing to `None`.
        let spec = InputSpec {
            shape: vec![1, 480, 640, 3],
            dshape: vec![(DimName::Batch, 1), (DimName::Height, 480)],
            cameraadaptor: None,
        };
        assert_eq!(input_dims_from_spec(&spec), Some((640, 480)));
    }

    #[test]
    fn named_dshape_overrides_nhwc_for_nchw_layout() {
        // NCHW input. Positional NHWC inference alone would give
        // `(480, 3)`; the named dims must take precedence.
        let spec = InputSpec {
            shape: vec![1, 3, 480, 640],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumFeatures, 3),
                (DimName::Height, 480),
                (DimName::Width, 640),
            ],
            cameraadaptor: None,
        };
        assert_eq!(input_dims_from_spec(&spec), Some((640, 480)));
    }

    #[test]
    fn non_4d_shape_without_named_dims_returns_none() {
        // No named dims and no 4-D shape: the NHWC fallback cannot apply,
        // so the dims are reported as unknown.
        let spec = InputSpec {
            shape: vec![480, 640, 3],
            dshape: vec![],
            cameraadaptor: None,
        };
        assert_eq!(input_dims_from_spec(&spec), None);
    }

    #[test]
    fn malformed_dshape_longer_than_shape_does_not_panic() {
        // Regression for Copilot review on PR #63. Indexing
        // `input.shape[i]` while iterating dshape can OOB-panic when
        // `dshape.len() > shape.len()`. The fix uses `shape.get(i)`
        // and silently treats overflow as "dim missing".
        let spec = InputSpec {
            shape: vec![640, 480], // 2-D shape
            dshape: vec![
                (DimName::Width, 640),
                (DimName::Height, 480),
                (DimName::NumFeatures, 3), // overflow — index 2 ≥ shape.len()
            ],
            cameraadaptor: None,
        };
        // The first two dshape entries resolve via `shape.get()`; the third
        // is a no-op. Width and height are both resolved, so we expect Some.
        assert_eq!(input_dims_from_spec(&spec), Some((640, 480)));
    }

    #[test]
    fn malformed_dshape_only_overflow_returns_none() {
        // Every named dshape entry is past the shape boundary, so width and
        // height stay None. The 4-D NHWC fallback doesn't fire
        // (shape.len() == 1), so we get None instead of a panic.
        let spec = InputSpec {
            shape: vec![1],
            dshape: vec![
                (DimName::NumFeatures, 3),
                (DimName::Height, 480),
                (DimName::Width, 640),
            ],
            cameraadaptor: None,
        };
        assert_eq!(input_dims_from_spec(&spec), None);
    }
}
125
/// Builder that collects a model configuration and post-processing
/// parameters, then builds the decoder via `build()`.
///
/// The configuration source is stored as given: unparsed YAML/JSON text, a
/// programmatic [`ConfigOutputs`], or a [`SchemaV2`]. It is only
/// deserialized and checked when `build()` runs. Each `with_config_*` /
/// `with_schema` call replaces any previously set source.
#[derive(Debug, Clone, PartialEq)]
pub struct DecoderBuilder {
    /// Model configuration source. `None` until a configuration method
    /// (`with_config_*`, `with_schema`, `add_output`, ...) is called.
    config_src: Option<ConfigSource>,
    /// IoU threshold used by NMS (default `0.5`). Has no effect when `nms`
    /// is `None`.
    iou_threshold: f32,
    /// Score threshold used to filter candidates before they are fed into
    /// NMS (default `0.5`).
    score_threshold: f32,
    /// NMS mode.
    ///
    /// - `Some(Nms::Auto)` — resolve from config or fallback to
    ///   `ClassAgnostic` (builder default)
    /// - `Some(Nms::ClassAgnostic)` — explicit class-agnostic override
    /// - `Some(Nms::ClassAware)` — explicit class-aware override
    /// - `None` — bypass NMS entirely
    nms: Option<configs::Nms>,
    /// Output dtype for the per-scale fast path. Has no effect on
    /// schemas without per-scale children (which use the legacy decode
    /// path).
    decode_dtype: DecodeDtype,
    /// Maximum number of score-filtered candidates fed into NMS
    /// (default `300`; `0` means no limit).
    pre_nms_top_k: usize,
    /// Defaults to `300`.
    /// NOTE(review): presumably the cap on the number of detections
    /// returned after NMS — confirm against `build()` / the decode path.
    max_det: usize,
    /// Explicit override for the model input dimensions `(width, height)`,
    /// consumed by EDGEAI-1303 normalization. When set, takes precedence
    /// over schema-derived dims; when `None`, the value is read from the
    /// schema's `input.shape` / `input.dshape` at build time.
    input_dims: Option<(usize, usize)>,
}
151
/// Where the [`DecoderBuilder`]'s model configuration comes from.
///
/// The string variants hold unparsed text; deserialization happens in
/// `build()`.
#[derive(Debug, Clone, PartialEq)]
enum ConfigSource {
    /// Unparsed YAML text supplied through `with_config_yaml_str`.
    Yaml(String),
    /// Unparsed JSON text supplied through `with_config_json_str`.
    Json(String),
    /// Programmatic configuration supplied through `with_config`, the
    /// `with_config_*` convenience helpers, `add_output`, or
    /// `with_decoder_version`.
    Config(ConfigOutputs),
    /// Schema v2 metadata. During build the schema is either converted
    /// to a legacy [`ConfigOutputs`] (flat case) or used to construct a
    /// [`DecodeProgram`] that performs per-child dequant + merge at
    /// decode time.
    Schema(SchemaV2),
}
163
164impl Default for DecoderBuilder {
165 /// Creates a default DecoderBuilder with no configuration and 0.5 score
166 /// threshold and 0.5 IoU threshold.
167 ///
168 /// A valid configuration must be provided before building the Decoder.
169 ///
170 /// # Examples
171 /// ```rust
172 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
173 /// # fn main() -> DecoderResult<()> {
174 /// # let config_yaml = edgefirst_bench::testdata::read_to_string("modelpack_split.yaml").to_string();
175 /// let decoder = DecoderBuilder::default()
176 /// .with_config_yaml_str(config_yaml)
177 /// .build()?;
178 /// assert_eq!(decoder.score_threshold, 0.5);
179 /// assert_eq!(decoder.iou_threshold, 0.5);
180 ///
181 /// # Ok(())
182 /// # }
183 /// ```
184 fn default() -> Self {
185 Self {
186 config_src: None,
187 iou_threshold: 0.5,
188 score_threshold: 0.5,
189 nms: Some(configs::Nms::Auto),
190 decode_dtype: DecodeDtype::F32,
191 pre_nms_top_k: 300,
192 max_det: 300,
193 input_dims: None,
194 }
195 }
196}
197
198impl DecoderBuilder {
199 /// Creates a default DecoderBuilder with no configuration and 0.5 score
200 /// threshold and 0.5 IoU threshold.
201 ///
202 /// A valid configuration must be provided before building the Decoder.
203 ///
204 /// # Examples
205 /// ```rust
206 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
207 /// # fn main() -> DecoderResult<()> {
208 /// # let config_yaml = edgefirst_bench::testdata::read_to_string("modelpack_split.yaml").to_string();
209 /// let decoder = DecoderBuilder::new()
210 /// .with_config_yaml_str(config_yaml)
211 /// .build()?;
212 /// assert_eq!(decoder.score_threshold, 0.5);
213 /// assert_eq!(decoder.iou_threshold, 0.5);
214 ///
215 /// # Ok(())
216 /// # }
217 /// ```
218 pub fn new() -> Self {
219 Self::default()
220 }
221
222 /// Loads a model configuration in YAML format. Does not check if the string
223 /// is a correct configuration file. Use `DecoderBuilder.build()` to
224 /// deserialize the YAML and parse the model configuration.
225 ///
226 /// # Examples
227 /// ```rust,no_run
228 /// # use edgefirst_decoder::DecoderBuilder;
229 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
230 /// let config_yaml = std::fs::read_to_string("modelpack_split.yaml")?;
231 /// let decoder = DecoderBuilder::new()
232 /// .with_config_yaml_str(config_yaml)
233 /// .build()?;
234 ///
235 /// # Ok(())
236 /// # }
237 /// ```
238 pub fn with_config_yaml_str(mut self, yaml_str: String) -> Self {
239 self.config_src.replace(ConfigSource::Yaml(yaml_str));
240 self
241 }
242
243 /// Loads a model configuration in JSON format. Does not check if the string
244 /// is a correct configuration file. Use `DecoderBuilder.build()` to
245 /// deserialize the JSON and parse the model configuration.
246 ///
247 /// # Examples
248 /// ```rust,no_run
249 /// # use edgefirst_decoder::DecoderBuilder;
250 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
251 /// let config_json = std::fs::read_to_string("modelpack_split.json")?;
252 /// let decoder = DecoderBuilder::new()
253 /// .with_config_json_str(config_json)
254 /// .build()?;
255 ///
256 /// # Ok(())
257 /// # }
258 /// ```
259 pub fn with_config_json_str(mut self, json_str: String) -> Self {
260 self.config_src.replace(ConfigSource::Json(json_str));
261 self
262 }
263
264 /// Loads a model configuration. Does not check if the configuration is
265 /// correct. Intended to be used when the user needs control over the
266 /// deserialize of the configuration information. Use
267 /// `DecoderBuilder.build()` to parse the model configuration.
268 ///
269 /// # Examples
270 /// ```rust,no_run
271 /// # use edgefirst_decoder::{DecoderBuilder, ConfigOutputs};
272 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
273 /// let config_json = std::fs::read_to_string("modelpack_split.json")?;
274 /// let config: ConfigOutputs = serde_json::from_str(&config_json)?;
275 /// let decoder = DecoderBuilder::new().with_config(config).build()?;
276 ///
277 /// # Ok(())
278 /// # }
279 /// ```
280 pub fn with_config(mut self, config: ConfigOutputs) -> Self {
281 self.config_src.replace(ConfigSource::Config(config));
282 self
283 }
284
285 /// Configure the decoder from a schema v2 metadata document.
286 ///
287 /// Accepts a [`SchemaV2`] as produced by [`SchemaV2::parse_json`],
288 /// [`SchemaV2::parse_yaml`], [`SchemaV2::parse_file`], or
289 /// constructed programmatically. The builder validates the schema,
290 /// compiles a [`DecodeProgram`] for any split logical outputs
291 /// (per-scale or channel sub-splits), and downconverts the
292 /// logical-level semantics to the legacy [`ConfigOutputs`]
293 /// representation consumed by the existing decoder dispatch.
294 ///
295 /// # Examples
296 ///
297 /// ```rust,no_run
298 /// use edgefirst_decoder::{DecoderBuilder, DecoderResult};
299 /// use edgefirst_decoder::schema::SchemaV2;
300 ///
301 /// # fn main() -> DecoderResult<()> {
302 /// let schema = SchemaV2::parse_file("model/edgefirst.json")?;
303 /// let decoder = DecoderBuilder::new()
304 /// .with_schema(schema)
305 /// .with_score_threshold(0.25)
306 /// .build()?;
307 /// # Ok(())
308 /// # }
309 /// ```
310 pub fn with_schema(mut self, schema: SchemaV2) -> Self {
311 self.config_src.replace(ConfigSource::Schema(schema));
312 self
313 }
314
315 /// Choose the output dtype for the per-scale decoder pipeline.
316 ///
317 /// Defaults to [`DecodeDtype::F32`]. Has no effect on schemas
318 /// without per-scale children (which use the legacy decode path).
319 /// `F16` saves ~2× memory bandwidth at the cost of 10-bit mantissa
320 /// precision; empirically safe for YOLO-family models.
321 ///
322 /// # Examples
323 ///
324 /// ```rust,no_run
325 /// use edgefirst_decoder::{DecodeDtype, DecoderBuilder, DecoderResult};
326 /// use edgefirst_decoder::schema::SchemaV2;
327 ///
328 /// # fn main() -> DecoderResult<()> {
329 /// let schema = SchemaV2::parse_file("model/edgefirst.json")?;
330 /// let decoder = DecoderBuilder::new()
331 /// .with_schema(schema)
332 /// .with_decode_dtype(DecodeDtype::F32)
333 /// .build()?;
334 /// # Ok(())
335 /// # }
336 /// ```
337 pub fn with_decode_dtype(mut self, dtype: DecodeDtype) -> Self {
338 self.decode_dtype = dtype;
339 self
340 }
341
342 /// Loads a YOLO detection model configuration. Use
343 /// `DecoderBuilder.build()` to parse the model configuration.
344 ///
345 /// # Examples
346 /// ```rust
347 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
348 /// # fn main() -> DecoderResult<()> {
349 /// let decoder = DecoderBuilder::new()
350 /// .with_config_yolo_det(
351 /// configs::Detection {
352 /// anchors: None,
353 /// decoder: configs::DecoderType::Ultralytics,
354 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
355 /// shape: vec![1, 84, 8400],
356 /// dshape: Vec::new(),
357 /// normalized: Some(true),
358 /// },
359 /// None,
360 /// )
361 /// .build()?;
362 ///
363 /// # Ok(())
364 /// # }
365 /// ```
366 pub fn with_config_yolo_det(
367 mut self,
368 boxes: configs::Detection,
369 version: Option<DecoderVersion>,
370 ) -> Self {
371 let config = ConfigOutputs {
372 outputs: vec![ConfigOutput::Detection(boxes)],
373 decoder_version: version,
374 ..Default::default()
375 };
376 self.config_src.replace(ConfigSource::Config(config));
377 self
378 }
379
380 /// Loads a YOLO split detection model configuration. Use
381 /// `DecoderBuilder.build()` to parse the model configuration.
382 ///
383 /// # Examples
384 /// ```rust
385 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
386 /// # fn main() -> DecoderResult<()> {
387 /// let boxes_config = configs::Boxes {
388 /// decoder: configs::DecoderType::Ultralytics,
389 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
390 /// shape: vec![1, 4, 8400],
391 /// dshape: Vec::new(),
392 /// normalized: Some(true),
393 /// };
394 /// let scores_config = configs::Scores {
395 /// decoder: configs::DecoderType::Ultralytics,
396 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
397 /// shape: vec![1, 80, 8400],
398 /// dshape: Vec::new(),
399 /// };
400 /// let decoder = DecoderBuilder::new()
401 /// .with_config_yolo_split_det(boxes_config, scores_config)
402 /// .build()?;
403 /// # Ok(())
404 /// # }
405 /// ```
406 pub fn with_config_yolo_split_det(
407 mut self,
408 boxes: configs::Boxes,
409 scores: configs::Scores,
410 ) -> Self {
411 let config = ConfigOutputs {
412 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
413 ..Default::default()
414 };
415 self.config_src.replace(ConfigSource::Config(config));
416 self
417 }
418
419 /// Loads a YOLO segmentation model configuration. Use
420 /// `DecoderBuilder.build()` to parse the model configuration.
421 ///
422 /// # Examples
423 /// ```rust
424 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
425 /// # fn main() -> DecoderResult<()> {
426 /// let seg_config = configs::Detection {
427 /// decoder: configs::DecoderType::Ultralytics,
428 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
429 /// shape: vec![1, 116, 8400],
430 /// anchors: None,
431 /// dshape: Vec::new(),
432 /// normalized: Some(true),
433 /// };
434 /// let protos_config = configs::Protos {
435 /// decoder: configs::DecoderType::Ultralytics,
436 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
437 /// shape: vec![1, 160, 160, 32],
438 /// dshape: Vec::new(),
439 /// };
440 /// let decoder = DecoderBuilder::new()
441 /// .with_config_yolo_segdet(
442 /// seg_config,
443 /// protos_config,
444 /// Some(configs::DecoderVersion::Yolov8),
445 /// )
446 /// .build()?;
447 /// # Ok(())
448 /// # }
449 /// ```
450 pub fn with_config_yolo_segdet(
451 mut self,
452 boxes: configs::Detection,
453 protos: configs::Protos,
454 version: Option<DecoderVersion>,
455 ) -> Self {
456 let config = ConfigOutputs {
457 outputs: vec![ConfigOutput::Detection(boxes), ConfigOutput::Protos(protos)],
458 decoder_version: version,
459 ..Default::default()
460 };
461 self.config_src.replace(ConfigSource::Config(config));
462 self
463 }
464
465 /// Loads a YOLO split segmentation model configuration. Use
466 /// `DecoderBuilder.build()` to parse the model configuration.
467 ///
468 /// # Examples
469 /// ```rust
470 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
471 /// # fn main() -> DecoderResult<()> {
472 /// let boxes_config = configs::Boxes {
473 /// decoder: configs::DecoderType::Ultralytics,
474 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
475 /// shape: vec![1, 4, 8400],
476 /// dshape: Vec::new(),
477 /// normalized: Some(true),
478 /// };
479 /// let scores_config = configs::Scores {
480 /// decoder: configs::DecoderType::Ultralytics,
481 /// quantization: Some(configs::QuantTuple(0.012345, 14)),
482 /// shape: vec![1, 80, 8400],
483 /// dshape: Vec::new(),
484 /// };
485 /// let mask_config = configs::MaskCoefficients {
486 /// decoder: configs::DecoderType::Ultralytics,
487 /// quantization: Some(configs::QuantTuple(0.0064123, 125)),
488 /// shape: vec![1, 32, 8400],
489 /// dshape: Vec::new(),
490 /// };
491 /// let protos_config = configs::Protos {
492 /// decoder: configs::DecoderType::Ultralytics,
493 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
494 /// shape: vec![1, 160, 160, 32],
495 /// dshape: Vec::new(),
496 /// };
497 /// let decoder = DecoderBuilder::new()
498 /// .with_config_yolo_split_segdet(boxes_config, scores_config, mask_config, protos_config)
499 /// .build()?;
500 /// # Ok(())
501 /// # }
502 /// ```
503 pub fn with_config_yolo_split_segdet(
504 mut self,
505 boxes: configs::Boxes,
506 scores: configs::Scores,
507 mask_coefficients: configs::MaskCoefficients,
508 protos: configs::Protos,
509 ) -> Self {
510 let config = ConfigOutputs {
511 outputs: vec![
512 ConfigOutput::Boxes(boxes),
513 ConfigOutput::Scores(scores),
514 ConfigOutput::MaskCoefficients(mask_coefficients),
515 ConfigOutput::Protos(protos),
516 ],
517 ..Default::default()
518 };
519 self.config_src.replace(ConfigSource::Config(config));
520 self
521 }
522
523 /// Loads a ModelPack detection model configuration. Use
524 /// `DecoderBuilder.build()` to parse the model configuration.
525 ///
526 /// # Examples
527 /// ```rust
528 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
529 /// # fn main() -> DecoderResult<()> {
530 /// let boxes_config = configs::Boxes {
531 /// decoder: configs::DecoderType::ModelPack,
532 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
533 /// shape: vec![1, 8400, 1, 4],
534 /// dshape: Vec::new(),
535 /// normalized: Some(true),
536 /// };
537 /// let scores_config = configs::Scores {
538 /// decoder: configs::DecoderType::ModelPack,
539 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
540 /// shape: vec![1, 8400, 3],
541 /// dshape: Vec::new(),
542 /// };
543 /// let decoder = DecoderBuilder::new()
544 /// .with_config_modelpack_det(boxes_config, scores_config)
545 /// .build()?;
546 /// # Ok(())
547 /// # }
548 /// ```
549 pub fn with_config_modelpack_det(
550 mut self,
551 boxes: configs::Boxes,
552 scores: configs::Scores,
553 ) -> Self {
554 let config = ConfigOutputs {
555 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
556 ..Default::default()
557 };
558 self.config_src.replace(ConfigSource::Config(config));
559 self
560 }
561
562 /// Loads a ModelPack split detection model configuration. Use
563 /// `DecoderBuilder.build()` to parse the model configuration.
564 ///
565 /// # Examples
566 /// ```rust
567 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
568 /// # fn main() -> DecoderResult<()> {
569 /// let config0 = configs::Detection {
570 /// anchors: Some(vec![
571 /// [0.13750000298023224, 0.2074074000120163],
572 /// [0.2541666626930237, 0.21481481194496155],
573 /// [0.23125000298023224, 0.35185185074806213],
574 /// ]),
575 /// decoder: configs::DecoderType::ModelPack,
576 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
577 /// shape: vec![1, 17, 30, 18],
578 /// dshape: Vec::new(),
579 /// normalized: Some(true),
580 /// };
581 /// let config1 = configs::Detection {
582 /// anchors: Some(vec![
583 /// [0.36666667461395264, 0.31481480598449707],
584 /// [0.38749998807907104, 0.4740740656852722],
585 /// [0.5333333611488342, 0.644444465637207],
586 /// ]),
587 /// decoder: configs::DecoderType::ModelPack,
588 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
589 /// shape: vec![1, 9, 15, 18],
590 /// dshape: Vec::new(),
591 /// normalized: Some(true),
592 /// };
593 ///
594 /// let decoder = DecoderBuilder::new()
595 /// .with_config_modelpack_det_split(vec![config0, config1])
596 /// .build()?;
597 /// # Ok(())
598 /// # }
599 /// ```
600 pub fn with_config_modelpack_det_split(mut self, boxes: Vec<configs::Detection>) -> Self {
601 let outputs = boxes.into_iter().map(ConfigOutput::Detection).collect();
602 let config = ConfigOutputs {
603 outputs,
604 ..Default::default()
605 };
606 self.config_src.replace(ConfigSource::Config(config));
607 self
608 }
609
610 /// Loads a ModelPack segmentation detection model configuration. Use
611 /// `DecoderBuilder.build()` to parse the model configuration.
612 ///
613 /// # Examples
614 /// ```rust
615 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
616 /// # fn main() -> DecoderResult<()> {
617 /// let boxes_config = configs::Boxes {
618 /// decoder: configs::DecoderType::ModelPack,
619 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
620 /// shape: vec![1, 8400, 1, 4],
621 /// dshape: Vec::new(),
622 /// normalized: Some(true),
623 /// };
624 /// let scores_config = configs::Scores {
625 /// decoder: configs::DecoderType::ModelPack,
626 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
627 /// shape: vec![1, 8400, 2],
628 /// dshape: Vec::new(),
629 /// };
630 /// let seg_config = configs::Segmentation {
631 /// decoder: configs::DecoderType::ModelPack,
632 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
633 /// shape: vec![1, 640, 640, 3],
634 /// dshape: Vec::new(),
635 /// };
636 /// let decoder = DecoderBuilder::new()
637 /// .with_config_modelpack_segdet(boxes_config, scores_config, seg_config)
638 /// .build()?;
639 /// # Ok(())
640 /// # }
641 /// ```
642 pub fn with_config_modelpack_segdet(
643 mut self,
644 boxes: configs::Boxes,
645 scores: configs::Scores,
646 segmentation: configs::Segmentation,
647 ) -> Self {
648 let config = ConfigOutputs {
649 outputs: vec![
650 ConfigOutput::Boxes(boxes),
651 ConfigOutput::Scores(scores),
652 ConfigOutput::Segmentation(segmentation),
653 ],
654 ..Default::default()
655 };
656 self.config_src.replace(ConfigSource::Config(config));
657 self
658 }
659
660 /// Loads a ModelPack segmentation split detection model configuration. Use
661 /// `DecoderBuilder.build()` to parse the model configuration.
662 ///
663 /// # Examples
664 /// ```rust
665 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
666 /// # fn main() -> DecoderResult<()> {
667 /// let config0 = configs::Detection {
668 /// anchors: Some(vec![
669 /// [0.36666667461395264, 0.31481480598449707],
670 /// [0.38749998807907104, 0.4740740656852722],
671 /// [0.5333333611488342, 0.644444465637207],
672 /// ]),
673 /// decoder: configs::DecoderType::ModelPack,
674 /// quantization: Some(configs::QuantTuple(0.08547406643629074, 174)),
675 /// shape: vec![1, 9, 15, 18],
676 /// dshape: Vec::new(),
677 /// normalized: Some(true),
678 /// };
679 /// let config1 = configs::Detection {
680 /// anchors: Some(vec![
681 /// [0.13750000298023224, 0.2074074000120163],
682 /// [0.2541666626930237, 0.21481481194496155],
683 /// [0.23125000298023224, 0.35185185074806213],
684 /// ]),
685 /// decoder: configs::DecoderType::ModelPack,
686 /// quantization: Some(configs::QuantTuple(0.09929127991199493, 183)),
687 /// shape: vec![1, 17, 30, 18],
688 /// dshape: Vec::new(),
689 /// normalized: Some(true),
690 /// };
691 /// let seg_config = configs::Segmentation {
692 /// decoder: configs::DecoderType::ModelPack,
693 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
694 /// shape: vec![1, 640, 640, 2],
695 /// dshape: Vec::new(),
696 /// };
697 /// let decoder = DecoderBuilder::new()
698 /// .with_config_modelpack_segdet_split(vec![config0, config1], seg_config)
699 /// .build()?;
700 /// # Ok(())
701 /// # }
702 /// ```
703 pub fn with_config_modelpack_segdet_split(
704 mut self,
705 boxes: Vec<configs::Detection>,
706 segmentation: configs::Segmentation,
707 ) -> Self {
708 let mut outputs = boxes
709 .into_iter()
710 .map(ConfigOutput::Detection)
711 .collect::<Vec<_>>();
712 outputs.push(ConfigOutput::Segmentation(segmentation));
713 let config = ConfigOutputs {
714 outputs,
715 ..Default::default()
716 };
717 self.config_src.replace(ConfigSource::Config(config));
718 self
719 }
720
721 /// Loads a ModelPack segmentation model configuration. Use
722 /// `DecoderBuilder.build()` to parse the model configuration.
723 ///
724 /// # Examples
725 /// ```rust
726 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
727 /// # fn main() -> DecoderResult<()> {
728 /// let seg_config = configs::Segmentation {
729 /// decoder: configs::DecoderType::ModelPack,
730 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
731 /// shape: vec![1, 640, 640, 3],
732 /// dshape: Vec::new(),
733 /// };
734 /// let decoder = DecoderBuilder::new()
735 /// .with_config_modelpack_seg(seg_config)
736 /// .build()?;
737 /// # Ok(())
738 /// # }
739 /// ```
740 pub fn with_config_modelpack_seg(mut self, segmentation: configs::Segmentation) -> Self {
741 let config = ConfigOutputs {
742 outputs: vec![ConfigOutput::Segmentation(segmentation)],
743 ..Default::default()
744 };
745 self.config_src.replace(ConfigSource::Config(config));
746 self
747 }
748
749 /// Add an output to the decoder configuration.
750 ///
751 /// Incrementally builds the model configuration by adding outputs one at
752 /// a time. The decoder resolves the model type from the combination of
753 /// outputs during `build()`.
754 ///
755 /// If `dshape` is non-empty on the output, `shape` is automatically
756 /// derived from it (the size component of each named dimension). This
757 /// prevents conflicts between `shape` and `dshape`.
758 ///
759 /// This uses the programmatic config path. Calling this after
760 /// `with_config_json_str()` or `with_config_yaml_str()` replaces the
761 /// string-based config source.
762 ///
763 /// # Examples
764 /// ```rust
765 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
766 /// # fn main() -> DecoderResult<()> {
767 /// let decoder = DecoderBuilder::new()
768 /// .add_output(ConfigOutput::Scores(configs::Scores {
769 /// decoder: configs::DecoderType::Ultralytics,
770 /// dshape: vec![
771 /// (configs::DimName::Batch, 1),
772 /// (configs::DimName::NumClasses, 80),
773 /// (configs::DimName::NumBoxes, 8400),
774 /// ],
775 /// ..Default::default()
776 /// }))
777 /// .add_output(ConfigOutput::Boxes(configs::Boxes {
778 /// decoder: configs::DecoderType::Ultralytics,
779 /// dshape: vec![
780 /// (configs::DimName::Batch, 1),
781 /// (configs::DimName::BoxCoords, 4),
782 /// (configs::DimName::NumBoxes, 8400),
783 /// ],
784 /// ..Default::default()
785 /// }))
786 /// .build()?;
787 /// # Ok(())
788 /// # }
789 /// ```
790 pub fn add_output(mut self, output: ConfigOutput) -> Self {
791 if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
792 self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
793 }
794 if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
795 config.outputs.push(Self::normalize_output(output));
796 }
797 self
798 }
799
800 /// Sets the decoder version for Ultralytics models.
801 ///
802 /// This is used with `add_output()` to specify the YOLO architecture
803 /// version when it cannot be inferred from the output shapes alone.
804 ///
805 /// - `Yolov5`, `Yolov8`, `Yolo11`: Traditional models requiring external
806 /// NMS
807 /// - `Yolo26`: End-to-end models with NMS embedded in the model graph
808 ///
809 /// # Examples
810 /// ```rust
811 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
812 /// # fn main() -> DecoderResult<()> {
813 /// let decoder = DecoderBuilder::new()
814 /// .add_output(ConfigOutput::Detection(configs::Detection {
815 /// decoder: configs::DecoderType::Ultralytics,
816 /// dshape: vec![
817 /// (configs::DimName::Batch, 1),
818 /// (configs::DimName::NumBoxes, 100),
819 /// (configs::DimName::NumFeatures, 6),
820 /// ],
821 /// ..Default::default()
822 /// }))
823 /// .with_decoder_version(configs::DecoderVersion::Yolo26)
824 /// .build()?;
825 /// # Ok(())
826 /// # }
827 /// ```
828 pub fn with_decoder_version(mut self, version: configs::DecoderVersion) -> Self {
829 if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
830 self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
831 }
832 if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
833 config.decoder_version = Some(version);
834 }
835 self
836 }
837
838 /// Normalize an output: if dshape is non-empty, derive shape from it.
839 fn normalize_output(mut output: ConfigOutput) -> ConfigOutput {
840 fn normalize_shape(shape: &mut Vec<usize>, dshape: &[(configs::DimName, usize)]) {
841 if !dshape.is_empty() {
842 *shape = dshape.iter().map(|(_, size)| *size).collect();
843 }
844 }
845 match &mut output {
846 ConfigOutput::Detection(c) => normalize_shape(&mut c.shape, &c.dshape),
847 ConfigOutput::Boxes(c) => normalize_shape(&mut c.shape, &c.dshape),
848 ConfigOutput::Scores(c) => normalize_shape(&mut c.shape, &c.dshape),
849 ConfigOutput::Protos(c) => normalize_shape(&mut c.shape, &c.dshape),
850 ConfigOutput::Segmentation(c) => normalize_shape(&mut c.shape, &c.dshape),
851 ConfigOutput::MaskCoefficients(c) => normalize_shape(&mut c.shape, &c.dshape),
852 ConfigOutput::Mask(c) => normalize_shape(&mut c.shape, &c.dshape),
853 ConfigOutput::Classes(c) => normalize_shape(&mut c.shape, &c.dshape),
854 }
855 output
856 }
857
858 /// Sets the scores threshold of the decoder
859 ///
860 /// # Examples
861 /// ```rust
862 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
863 /// # fn main() -> DecoderResult<()> {
864 /// # let config_json = edgefirst_bench::testdata::read_to_string("modelpack_split.json").to_string();
865 /// let decoder = DecoderBuilder::new()
866 /// .with_config_json_str(config_json)
867 /// .with_score_threshold(0.654)
868 /// .build()?;
869 /// assert_eq!(decoder.score_threshold, 0.654);
870 /// # Ok(())
871 /// # }
872 /// ```
873 pub fn with_score_threshold(mut self, score_threshold: f32) -> Self {
874 self.score_threshold = score_threshold;
875 self
876 }
877
878 /// Sets the IOU threshold of the decoder. Has no effect when NMS is set to
879 /// `None`
880 ///
881 /// # Examples
882 /// ```rust
883 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
884 /// # fn main() -> DecoderResult<()> {
885 /// # let config_json = edgefirst_bench::testdata::read_to_string("modelpack_split.json").to_string();
886 /// let decoder = DecoderBuilder::new()
887 /// .with_config_json_str(config_json)
888 /// .with_iou_threshold(0.654)
889 /// .build()?;
890 /// assert_eq!(decoder.iou_threshold, 0.654);
891 /// # Ok(())
892 /// # }
893 /// ```
894 pub fn with_iou_threshold(mut self, iou_threshold: f32) -> Self {
895 self.iou_threshold = iou_threshold;
896 self
897 }
898
899 /// Sets the NMS mode for the decoder.
900 ///
901 /// - `Some(Nms::Auto)` — resolve from model config (e.g. `edgefirst.json`)
902 /// or fall back to `ClassAgnostic` (this is the builder default)
903 /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS: suppress overlapping
904 /// boxes regardless of class label
905 /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
906 /// share the same class label AND overlap above the IoU threshold
907 /// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
908 ///
909 /// # Examples
910 /// ```rust
911 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::Nms};
912 /// # fn main() -> DecoderResult<()> {
913 /// # let config_json = edgefirst_bench::testdata::read_to_string("modelpack_split.json").to_string();
914 /// let decoder = DecoderBuilder::new()
915 /// .with_config_json_str(config_json)
916 /// .with_nms(Some(Nms::ClassAware))
917 /// .build()?;
918 /// assert_eq!(decoder.nms, Some(Nms::ClassAware));
919 /// # Ok(())
920 /// # }
921 /// ```
922 pub fn with_nms(mut self, nms: Option<configs::Nms>) -> Self {
923 self.nms = nms;
924 self
925 }
926
927 /// Sets the maximum number of candidate boxes fed into NMS after score
928 /// filtering. Uses partial sort (O(N)) to select the top-K candidates,
929 /// dramatically reducing the O(N²) NMS cost when many low-confidence
930 /// proposals pass the threshold (common with mAP eval at 0.001).
931 ///
932 /// Default: 300.
933 ///
934 /// # ⚠️ Validation vs Deployment
935 ///
936 /// The default is appropriate for **deployment** where
937 /// `score_threshold ≥ 0.25` means few anchors survive filtering and
938 /// top-K is effectively a no-op.
939 ///
940 /// For **COCO mAP evaluation** (`score_threshold ≈ 0.001`), set this to
941 /// the total anchor count (8 400 for standard 640 × 640 YOLO models) or
942 /// to `0` (no limit) so that all score-passing candidates reach NMS.
943 /// Failing to do so causes **~9 pp box mAP loss** — the decoder math is
944 /// correct but the evaluation protocol requires full recall across the
945 /// confidence range.
946 ///
947 /// Post-processing latency scales with candidate count. At deployment
948 /// thresholds the cost difference is negligible; at validation thresholds
949 /// it is measurable but necessary for correct results.
950 ///
951 /// # Examples
952 ///
953 /// Deployment (default top-K, high score threshold):
954 /// ```rust
955 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
956 /// # fn main() -> DecoderResult<()> {
957 /// # let config_json = edgefirst_bench::testdata::read_to_string("modelpack_split.json").to_string();
958 /// let decoder = DecoderBuilder::new()
959 /// .with_config_json_str(config_json)
960 /// .with_score_threshold(0.25)
961 /// // pre_nms_top_k defaults to 300 — appropriate here
962 /// .build()?;
963 /// # Ok(())
964 /// # }
965 /// ```
966 ///
967 /// COCO mAP evaluation (pass all anchors to NMS):
968 /// ```rust
969 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
970 /// # fn main() -> DecoderResult<()> {
971 /// # let config_json = edgefirst_bench::testdata::read_to_string("modelpack_split.json").to_string();
972 /// let decoder = DecoderBuilder::new()
973 /// .with_config_json_str(config_json)
974 /// .with_score_threshold(0.001)
975 /// .with_pre_nms_top_k(8400) // all YOLO anchors
976 /// .with_max_det(300)
977 /// .build()?;
978 /// assert_eq!(decoder.pre_nms_top_k, 8400);
979 /// # Ok(())
980 /// # }
981 /// ```
982 pub fn with_pre_nms_top_k(mut self, pre_nms_top_k: usize) -> Self {
983 self.pre_nms_top_k = pre_nms_top_k;
984 self
985 }
986
987 /// Sets the maximum number of detections returned after NMS.
988 /// Matches the Ultralytics `max_det` parameter.
989 ///
990 /// Default: 300.
991 ///
992 /// # Examples
993 /// ```rust
994 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
995 /// # fn main() -> DecoderResult<()> {
996 /// # let config_json = edgefirst_bench::testdata::read_to_string("modelpack_split.json").to_string();
997 /// let decoder = DecoderBuilder::new()
998 /// .with_config_json_str(config_json)
999 /// .with_max_det(100)
1000 /// .build()?;
1001 /// assert_eq!(decoder.max_det, 100);
1002 /// # Ok(())
1003 /// # }
1004 /// ```
1005 pub fn with_max_det(mut self, max_det: usize) -> Self {
1006 self.max_det = max_det;
1007 self
1008 }
1009
1010 /// Sets the model input dimensions `(width, height)` consumed by the
1011 /// EDGEAI-1303 normalization path. Use this when building via
1012 /// [`with_config`](Self::with_config) / [`add_output`](Self::add_output)
1013 /// (no schema) and the model emits pixel-space boxes that need to be
1014 /// divided by `(W, H)` before NMS.
1015 ///
1016 /// When the builder is also configured with [`with_schema`](Self::with_schema)
1017 /// (or `with_config_json_str` / `with_config_yaml_str`) and the schema's
1018 /// `input` block carries usable dims, this explicit override **takes
1019 /// precedence** so callers can correct schemas with missing or wrong
1020 /// input specs without rewriting the schema.
1021 ///
1022 /// # Examples
1023 /// ```rust
1024 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
1025 /// # fn main() -> DecoderResult<()> {
1026 /// # let config_yaml = edgefirst_bench::testdata::read_to_string("modelpack_split.yaml").to_string();
1027 /// let decoder = DecoderBuilder::new()
1028 /// .with_config_yaml_str(config_yaml)
1029 /// .with_input_dims(640, 640)
1030 /// .build()?;
1031 /// assert_eq!(decoder.input_dims(), Some((640, 640)));
1032 /// # Ok(())
1033 /// # }
1034 /// ```
1035 pub fn with_input_dims(mut self, width: usize, height: usize) -> Self {
1036 self.input_dims = Some((width, height));
1037 self
1038 }
1039
1040 /// Builds the decoder with the given settings. If the config is a JSON or
1041 /// YAML string, this will deserialize the JSON or YAML and then parse the
1042 /// configuration information.
1043 ///
1044 /// # Examples
1045 /// ```rust
1046 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
1047 /// # fn main() -> DecoderResult<()> {
1048 /// # let config_json = edgefirst_bench::testdata::read_to_string("modelpack_split.json").to_string();
1049 /// let decoder = DecoderBuilder::new()
1050 /// .with_config_json_str(config_json)
1051 /// .with_score_threshold(0.654)
1052 /// .build()?;
1053 /// # Ok(())
1054 /// # }
1055 /// ```
1056 pub fn build(self) -> Result<Decoder, DecoderError> {
1057 let decode_dtype = self.decode_dtype;
1058 let explicit_input_dims = self.input_dims;
1059 let (config, decode_program, per_scale_plan, schema_input_dims) = match self.config_src {
1060 Some(ConfigSource::Json(s)) => {
1061 Self::build_from_schema(SchemaV2::parse_json(&s)?, decode_dtype)?
1062 }
1063 Some(ConfigSource::Yaml(s)) => {
1064 Self::build_from_schema(SchemaV2::parse_yaml(&s)?, decode_dtype)?
1065 }
1066 Some(ConfigSource::Config(c)) => (c, None, None, None),
1067 Some(ConfigSource::Schema(schema)) => Self::build_from_schema(schema, decode_dtype)?,
1068 None => return Err(DecoderError::NoConfig),
1069 };
1070 // Explicit `with_input_dims(W, H)` overrides any schema-derived
1071 // value so callers can fix schemas with missing or wrong input
1072 // specs without rewriting the schema (EDGEAI-1303).
1073 let input_dims = explicit_input_dims.or(schema_input_dims);
1074
1075 // Enforce the physical-order contract: when dshape is present
1076 // it must describe the same axes as shape in the same order,
1077 // listed from outermost to innermost. Ambiguous-layout roles
1078 // (Protos, Boxes, Scores, MaskCoefficients, Classes, Detection)
1079 // may still omit dshape when shape is already in the decoder's
1080 // canonical order.
1081 for output in &config.outputs {
1082 Decoder::validate_output_layout(output.into())?;
1083 }
1084
1085 // Extract normalized flag from config outputs.
1086 //
1087 // The per-scale subsystem (DFL/LTRB → dist2bbox → sigmoid) emits
1088 // boxes in pixel coordinates by design — `(grid + dist) * stride`
1089 // — independently of any `normalized: true` annotation in the
1090 // schema. Override to `Some(false)` so the per-scale bridge's
1091 // call to `yolo::maybe_normalize_boxes_in_place` fires and
1092 // divides by `input_dims`, yielding `[0, 1]` output. The
1093 // accessor `Decoder::normalized_boxes()` applies the
1094 // pixel→normalized upgrade for the per-scale path and for any
1095 // legacy `ModelType` whose every entry point normalizes
1096 // uniformly (currently `YoloSegDet`, `YoloSplitSegDet`, and
1097 // `YoloSegDet2Way`); other paths surface the raw flag.
1098 let normalized = if per_scale_plan.is_some() {
1099 Some(false)
1100 } else {
1101 Self::get_normalized(&config.outputs)
1102 };
1103
1104 // NMS precedence:
1105 // Some(ClassAgnostic|ClassAware) → explicit user override
1106 // Some(Auto) → resolve from config, fallback to ClassAgnostic
1107 // None → NMS disabled (explicit)
1108 //
1109 // `Auto` is always resolved to a concrete mode here — it never
1110 // persists into the built `Decoder`, even if the config itself
1111 // contains `Auto`.
1112 let resolve_auto = |nms: Option<configs::Nms>| match nms {
1113 Some(configs::Nms::Auto) | None => Some(configs::Nms::ClassAgnostic),
1114 concrete => concrete,
1115 };
1116 let nms = match self.nms {
1117 Some(configs::Nms::Auto) => resolve_auto(config.nms),
1118 other => other,
1119 };
1120 // When the per-scale path is active, the per_scale subsystem owns
1121 // model decoding entirely — `decode` / `decode_proto` short-circuit
1122 // on `per_scale.is_some()` before reading `model_type`. Skip the
1123 // legacy ModelType validation, which otherwise rejects per-scale
1124 // schemas that carry `decoder_version: yolo26` (an
1125 // "end-to-end" hint) but use the per-scale split outputs rather
1126 // than the end-to-end split-output shape the legacy validator
1127 // expects. We keep a placeholder `ModelType` so the field remains
1128 // valid; it is dead state for per-scale Decoders.
1129 let model_type = if per_scale_plan.is_some() {
1130 // Drop the un-needed config; the per-scale subsystem owns it.
1131 drop(config);
1132 ModelType::PerScale
1133 } else {
1134 Self::get_model_type(config)?
1135 };
1136
1137 let per_scale = per_scale_plan
1138 .map(|plan| std::sync::Mutex::new(crate::per_scale::PerScaleDecoder::new(plan)));
1139
1140 debug_assert!(
1141 !matches!(nms, Some(configs::Nms::Auto)),
1142 "Nms::Auto must be resolved to a concrete mode before building Decoder"
1143 );
1144
1145 Ok(Decoder {
1146 model_type,
1147 iou_threshold: self.iou_threshold,
1148 score_threshold: self.score_threshold,
1149 nms,
1150 pre_nms_top_k: self.pre_nms_top_k,
1151 max_det: self.max_det,
1152 normalized,
1153 input_dims,
1154 decode_program,
1155 per_scale,
1156 })
1157 }
1158
1159 /// Validate a [`SchemaV2`] and lower it to the (legacy `ConfigOutputs`,
1160 /// optional `DecodeProgram`, optional `PerScalePlan`) tuple the rest
1161 /// of `build()` consumes.
1162 ///
1163 /// Centralises the v2 lowering so JSON, YAML, and direct
1164 /// `with_schema` callers all go through the same validation,
1165 /// merge-program, and per-scale plan construction. `SchemaV2::parse_json`
1166 /// / `parse_yaml` already auto-detect v1 vs v2 input and return a v2
1167 /// schema either way (v1 inputs are upgraded in memory via
1168 /// `from_v1`), so this helper is the sole place that turns a
1169 /// schema into builder-ready state.
1170 #[allow(clippy::type_complexity)]
1171 fn build_from_schema(
1172 schema: SchemaV2,
1173 decode_dtype: DecodeDtype,
1174 ) -> Result<
1175 (
1176 ConfigOutputs,
1177 Option<DecodeProgram>,
1178 Option<PerScalePlan>,
1179 Option<(usize, usize)>,
1180 ),
1181 DecoderError,
1182 > {
1183 schema.validate()?;
1184 let program = DecodeProgram::try_from_schema(&schema)?;
1185 let per_scale = PerScalePlan::try_from_schema(&schema, decode_dtype)?;
1186 // Extract model input (W, H) from `input.shape`/`dshape`. Used by
1187 // the legacy decode path to honour `normalized: false` (see
1188 // EDGEAI-1303). `None` is fine when the schema omits the input
1189 // spec — the decoder falls back to the protobox `>2.0` reject.
1190 let input_dims = schema.input.as_ref().and_then(input_dims_from_spec);
1191 let legacy = schema.to_legacy_config_outputs()?;
1192 Ok((legacy, program, per_scale, input_dims))
1193 }
1194
1195 /// Extracts the normalized flag from config outputs.
1196 /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
1197 /// - `Some(false)`: Boxes are in pixel coordinates
1198 /// - `None`: Unknown (not specified in config), caller must infer
1199 fn get_normalized(outputs: &[ConfigOutput]) -> Option<bool> {
1200 for output in outputs {
1201 match output {
1202 ConfigOutput::Detection(det) => return det.normalized,
1203 ConfigOutput::Boxes(boxes) => return boxes.normalized,
1204 _ => {}
1205 }
1206 }
1207 None // not specified
1208 }
1209
1210 fn get_model_type(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1211 // yolo or modelpack
1212 let mut yolo = false;
1213 let mut modelpack = false;
1214 for c in &configs.outputs {
1215 match c.decoder() {
1216 DecoderType::ModelPack => modelpack = true,
1217 DecoderType::Ultralytics => yolo = true,
1218 }
1219 }
1220 match (modelpack, yolo) {
1221 (true, true) => Err(DecoderError::InvalidConfig(
1222 "Both ModelPack and Yolo outputs found in config".to_string(),
1223 )),
1224 (true, false) => Self::get_model_type_modelpack(configs),
1225 (false, true) => Self::get_model_type_yolo(configs),
1226 (false, false) => Err(DecoderError::InvalidConfig(
1227 "No outputs found in config".to_string(),
1228 )),
1229 }
1230 }
1231
1232 fn get_model_type_yolo(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1233 let mut boxes = None;
1234 let mut protos = None;
1235 let mut split_boxes = None;
1236 let mut split_scores = None;
1237 let mut split_mask_coeff = None;
1238 let mut split_classes = None;
1239 for c in configs.outputs {
1240 match c {
1241 ConfigOutput::Detection(detection) => boxes = Some(detection),
1242 ConfigOutput::Segmentation(_) => {
1243 return Err(DecoderError::InvalidConfig(
1244 "Invalid Segmentation output with Yolo decoder".to_string(),
1245 ));
1246 }
1247 ConfigOutput::Protos(protos_) => protos = Some(protos_),
1248 ConfigOutput::Mask(_) => {
1249 return Err(DecoderError::InvalidConfig(
1250 "Invalid Mask output with Yolo decoder".to_string(),
1251 ));
1252 }
1253 ConfigOutput::Scores(scores) => split_scores = Some(scores),
1254 ConfigOutput::Boxes(boxes) => split_boxes = Some(boxes),
1255 ConfigOutput::MaskCoefficients(mask_coeff) => split_mask_coeff = Some(mask_coeff),
1256 ConfigOutput::Classes(classes) => split_classes = Some(classes),
1257 }
1258 }
1259
1260 // Use end-to-end model types when:
1261 // 1. decoder_version is explicitly set to Yolo26 (definitive), OR
1262 // decoder_version is not set but the dshapes are (batch, num_boxes,
1263 // num_features)
1264 let is_end_to_end_dshape = boxes.as_ref().is_some_and(|b| {
1265 let dims = b.dshape.iter().map(|(d, _)| *d).collect::<Vec<_>>();
1266 dims == vec![DimName::Batch, DimName::NumBoxes, DimName::NumFeatures]
1267 });
1268
1269 let is_end_to_end = configs
1270 .decoder_version
1271 .map(|v| v.is_end_to_end())
1272 .unwrap_or(is_end_to_end_dshape);
1273
1274 if is_end_to_end {
1275 if let Some(boxes) = boxes {
1276 if let Some(protos) = protos {
1277 Self::verify_yolo_seg_det_26(&boxes, &protos)?;
1278 return Ok(ModelType::YoloEndToEndSegDet { boxes, protos });
1279 } else {
1280 Self::verify_yolo_det_26(&boxes)?;
1281 return Ok(ModelType::YoloEndToEndDet { boxes });
1282 }
1283 } else if let (Some(split_boxes), Some(split_scores), Some(split_classes)) =
1284 (split_boxes, split_scores, split_classes)
1285 {
1286 if let (Some(split_mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
1287 Self::verify_yolo_split_end_to_end_segdet(
1288 &split_boxes,
1289 &split_scores,
1290 &split_classes,
1291 &split_mask_coeff,
1292 &protos,
1293 )?;
1294 return Ok(ModelType::YoloSplitEndToEndSegDet {
1295 boxes: split_boxes,
1296 scores: split_scores,
1297 classes: split_classes,
1298 mask_coeff: split_mask_coeff,
1299 protos,
1300 });
1301 }
1302 Self::verify_yolo_split_end_to_end_det(
1303 &split_boxes,
1304 &split_scores,
1305 &split_classes,
1306 )?;
1307 return Ok(ModelType::YoloSplitEndToEndDet {
1308 boxes: split_boxes,
1309 scores: split_scores,
1310 classes: split_classes,
1311 });
1312 } else {
1313 return Err(DecoderError::InvalidConfig(
1314 "Invalid Yolo end-to-end model outputs".to_string(),
1315 ));
1316 }
1317 }
1318
1319 if let Some(boxes) = boxes {
1320 match (split_mask_coeff, protos) {
1321 (Some(mask_coeff), Some(protos)) => {
1322 // 2-way split: combined detection + separate mask_coeff + protos
1323 Self::verify_yolo_seg_det_2way(&boxes, &mask_coeff, &protos)?;
1324 Ok(ModelType::YoloSegDet2Way {
1325 boxes,
1326 mask_coeff,
1327 protos,
1328 })
1329 }
1330 (_, Some(protos)) => {
1331 // Unsplit: mask_coefs embedded in detection tensor
1332 Self::verify_yolo_seg_det(&boxes, &protos)?;
1333 Ok(ModelType::YoloSegDet { boxes, protos })
1334 }
1335 _ => {
1336 Self::verify_yolo_det(&boxes)?;
1337 Ok(ModelType::YoloDet { boxes })
1338 }
1339 }
1340 } else if let (Some(boxes), Some(scores)) = (split_boxes, split_scores) {
1341 if let (Some(mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
1342 Self::verify_yolo_split_segdet(&boxes, &scores, &mask_coeff, &protos)?;
1343 Ok(ModelType::YoloSplitSegDet {
1344 boxes,
1345 scores,
1346 mask_coeff,
1347 protos,
1348 })
1349 } else {
1350 Self::verify_yolo_split_det(&boxes, &scores)?;
1351 Ok(ModelType::YoloSplitDet { boxes, scores })
1352 }
1353 } else {
1354 Err(DecoderError::InvalidConfig(
1355 "Invalid Yolo model outputs".to_string(),
1356 ))
1357 }
1358 }
1359
1360 fn verify_yolo_det(detect: &configs::Detection) -> Result<(), DecoderError> {
1361 if detect.shape.len() != 3 {
1362 return Err(DecoderError::InvalidConfig(format!(
1363 "Invalid Yolo Detection shape {:?}",
1364 detect.shape
1365 )));
1366 }
1367
1368 Self::verify_dshapes(
1369 &detect.dshape,
1370 &detect.shape,
1371 "Detection",
1372 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1373 )?;
1374 if !detect.dshape.is_empty() {
1375 Self::get_class_count(&detect.dshape, None, None)?;
1376 } else {
1377 Self::get_class_count_no_dshape(detect.into(), None)?;
1378 }
1379
1380 Ok(())
1381 }
1382
1383 fn verify_yolo_det_26(detect: &configs::Detection) -> Result<(), DecoderError> {
1384 if detect.shape.len() != 3 {
1385 return Err(DecoderError::InvalidConfig(format!(
1386 "Invalid Yolo Detection shape {:?}",
1387 detect.shape
1388 )));
1389 }
1390
1391 Self::verify_dshapes(
1392 &detect.dshape,
1393 &detect.shape,
1394 "Detection",
1395 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1396 )?;
1397
1398 if !detect.shape.contains(&6) {
1399 return Err(DecoderError::InvalidConfig(
1400 "Yolo26 Detection must have 6 features".to_string(),
1401 ));
1402 }
1403
1404 Ok(())
1405 }
1406
1407 fn verify_yolo_seg_det(
1408 detection: &configs::Detection,
1409 protos: &configs::Protos,
1410 ) -> Result<(), DecoderError> {
1411 if detection.shape.len() != 3 {
1412 return Err(DecoderError::InvalidConfig(format!(
1413 "Invalid Yolo Detection shape {:?}",
1414 detection.shape
1415 )));
1416 }
1417 if protos.shape.len() != 4 {
1418 return Err(DecoderError::InvalidConfig(format!(
1419 "Invalid Yolo Protos shape {:?}",
1420 protos.shape
1421 )));
1422 }
1423
1424 Self::verify_dshapes(
1425 &detection.dshape,
1426 &detection.shape,
1427 "Detection",
1428 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1429 )?;
1430 Self::verify_dshapes(
1431 &protos.dshape,
1432 &protos.shape,
1433 "Protos",
1434 &[
1435 DimName::Batch,
1436 DimName::Height,
1437 DimName::Width,
1438 DimName::NumProtos,
1439 ],
1440 )?;
1441
1442 let protos_count = Self::get_protos_count(&protos.dshape)
1443 .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1444 log::debug!("Protos count: {}", protos_count);
1445 log::debug!("Detection dshape: {:?}", detection.dshape);
1446 let classes = if !detection.dshape.is_empty() {
1447 Self::get_class_count(&detection.dshape, Some(protos_count), None)?
1448 } else {
1449 Self::get_class_count_no_dshape(detection.into(), Some(protos_count))?
1450 };
1451
1452 if classes == 0 {
1453 return Err(DecoderError::InvalidConfig(
1454 "Yolo Segmentation Detection has zero classes".to_string(),
1455 ));
1456 }
1457
1458 Ok(())
1459 }
1460
1461 fn verify_yolo_seg_det_2way(
1462 detection: &configs::Detection,
1463 mask_coeff: &configs::MaskCoefficients,
1464 protos: &configs::Protos,
1465 ) -> Result<(), DecoderError> {
1466 if detection.shape.len() != 3 {
1467 return Err(DecoderError::InvalidConfig(format!(
1468 "Invalid Yolo 2-Way Detection shape {:?}",
1469 detection.shape
1470 )));
1471 }
1472 if mask_coeff.shape.len() != 3 {
1473 return Err(DecoderError::InvalidConfig(format!(
1474 "Invalid Yolo 2-Way Mask Coefficients shape {:?}",
1475 mask_coeff.shape
1476 )));
1477 }
1478 if protos.shape.len() != 4 {
1479 return Err(DecoderError::InvalidConfig(format!(
1480 "Invalid Yolo 2-Way Protos shape {:?}",
1481 protos.shape
1482 )));
1483 }
1484
1485 Self::verify_dshapes(
1486 &detection.dshape,
1487 &detection.shape,
1488 "Detection",
1489 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1490 )?;
1491 Self::verify_dshapes(
1492 &mask_coeff.dshape,
1493 &mask_coeff.shape,
1494 "Mask Coefficients",
1495 &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1496 )?;
1497 Self::verify_dshapes(
1498 &protos.dshape,
1499 &protos.shape,
1500 "Protos",
1501 &[
1502 DimName::Batch,
1503 DimName::Height,
1504 DimName::Width,
1505 DimName::NumProtos,
1506 ],
1507 )?;
1508
1509 // Validate num_boxes match between detection and mask_coeff
1510 let det_num = Self::get_box_count(&detection.dshape).unwrap_or(detection.shape[2]);
1511 let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1512 if det_num != mask_num {
1513 return Err(DecoderError::InvalidConfig(format!(
1514 "Yolo 2-Way Detection num_boxes {} incompatible with Mask Coefficients num_boxes {}",
1515 det_num, mask_num
1516 )));
1517 }
1518
1519 // Validate mask_coeff channels match protos channels
1520 let mask_channels = if !mask_coeff.dshape.is_empty() {
1521 Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1522 DecoderError::InvalidConfig(
1523 "Could not find num_protos in mask_coeff config".to_string(),
1524 )
1525 })?
1526 } else {
1527 mask_coeff.shape[1]
1528 };
1529 let proto_channels = if !protos.dshape.is_empty() {
1530 Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1531 DecoderError::InvalidConfig(
1532 "Could not find num_protos in protos config".to_string(),
1533 )
1534 })?
1535 } else {
1536 protos.shape[1].min(protos.shape[3])
1537 };
1538 if mask_channels != proto_channels {
1539 return Err(DecoderError::InvalidConfig(format!(
1540 "Yolo 2-Way Protos channels {} incompatible with Mask Coefficients channels {}",
1541 proto_channels, mask_channels
1542 )));
1543 }
1544
1545 // Validate detection has classes (nc+4 features, no mask_coefs embedded)
1546 if !detection.dshape.is_empty() {
1547 Self::get_class_count(&detection.dshape, None, None)?;
1548 } else {
1549 Self::get_class_count_no_dshape(detection.into(), None)?;
1550 }
1551
1552 Ok(())
1553 }
1554
1555 fn verify_yolo_seg_det_26(
1556 detection: &configs::Detection,
1557 protos: &configs::Protos,
1558 ) -> Result<(), DecoderError> {
1559 if detection.shape.len() != 3 {
1560 return Err(DecoderError::InvalidConfig(format!(
1561 "Invalid Yolo Detection shape {:?}",
1562 detection.shape
1563 )));
1564 }
1565 if protos.shape.len() != 4 {
1566 return Err(DecoderError::InvalidConfig(format!(
1567 "Invalid Yolo Protos shape {:?}",
1568 protos.shape
1569 )));
1570 }
1571
1572 Self::verify_dshapes(
1573 &detection.dshape,
1574 &detection.shape,
1575 "Detection",
1576 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1577 )?;
1578 Self::verify_dshapes(
1579 &protos.dshape,
1580 &protos.shape,
1581 "Protos",
1582 &[
1583 DimName::Batch,
1584 DimName::Height,
1585 DimName::Width,
1586 DimName::NumProtos,
1587 ],
1588 )?;
1589
1590 let protos_count = Self::get_protos_count(&protos.dshape)
1591 .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1592 log::debug!("Protos count: {}", protos_count);
1593 log::debug!("Detection dshape: {:?}", detection.dshape);
1594
1595 if !detection.shape.contains(&(6 + protos_count)) {
1596 return Err(DecoderError::InvalidConfig(format!(
1597 "Yolo26 Segmentation Detection must have num_features be 6 + num_protos = {}",
1598 6 + protos_count
1599 )));
1600 }
1601
1602 Ok(())
1603 }
1604
1605 fn verify_yolo_split_det(
1606 boxes: &configs::Boxes,
1607 scores: &configs::Scores,
1608 ) -> Result<(), DecoderError> {
1609 if boxes.shape.len() != 3 {
1610 return Err(DecoderError::InvalidConfig(format!(
1611 "Invalid Yolo Split Boxes shape {:?}",
1612 boxes.shape
1613 )));
1614 }
1615 if scores.shape.len() != 3 {
1616 return Err(DecoderError::InvalidConfig(format!(
1617 "Invalid Yolo Split Scores shape {:?}",
1618 scores.shape
1619 )));
1620 }
1621
1622 Self::verify_dshapes(
1623 &boxes.dshape,
1624 &boxes.shape,
1625 "Boxes",
1626 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1627 )?;
1628 Self::verify_dshapes(
1629 &scores.dshape,
1630 &scores.shape,
1631 "Scores",
1632 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1633 )?;
1634
1635 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1636 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1637
1638 if boxes_num != scores_num {
1639 return Err(DecoderError::InvalidConfig(format!(
1640 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1641 boxes_num, scores_num
1642 )));
1643 }
1644
1645 Ok(())
1646 }
1647
1648 fn verify_yolo_split_segdet(
1649 boxes: &configs::Boxes,
1650 scores: &configs::Scores,
1651 mask_coeff: &configs::MaskCoefficients,
1652 protos: &configs::Protos,
1653 ) -> Result<(), DecoderError> {
1654 if boxes.shape.len() != 3 {
1655 return Err(DecoderError::InvalidConfig(format!(
1656 "Invalid Yolo Split Boxes shape {:?}",
1657 boxes.shape
1658 )));
1659 }
1660 if scores.shape.len() != 3 {
1661 return Err(DecoderError::InvalidConfig(format!(
1662 "Invalid Yolo Split Scores shape {:?}",
1663 scores.shape
1664 )));
1665 }
1666
1667 if mask_coeff.shape.len() != 3 {
1668 return Err(DecoderError::InvalidConfig(format!(
1669 "Invalid Yolo Split Mask Coefficients shape {:?}",
1670 mask_coeff.shape
1671 )));
1672 }
1673
1674 if protos.shape.len() != 4 {
1675 return Err(DecoderError::InvalidConfig(format!(
1676 "Invalid Yolo Protos shape {:?}",
1677 mask_coeff.shape
1678 )));
1679 }
1680
1681 Self::verify_dshapes(
1682 &boxes.dshape,
1683 &boxes.shape,
1684 "Boxes",
1685 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1686 )?;
1687 Self::verify_dshapes(
1688 &scores.dshape,
1689 &scores.shape,
1690 "Scores",
1691 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1692 )?;
1693 Self::verify_dshapes(
1694 &mask_coeff.dshape,
1695 &mask_coeff.shape,
1696 "Mask Coefficients",
1697 &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1698 )?;
1699 Self::verify_dshapes(
1700 &protos.dshape,
1701 &protos.shape,
1702 "Protos",
1703 &[
1704 DimName::Batch,
1705 DimName::Height,
1706 DimName::Width,
1707 DimName::NumProtos,
1708 ],
1709 )?;
1710
1711 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1712 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1713 let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1714
1715 let mask_channels = if !mask_coeff.dshape.is_empty() {
1716 Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1717 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1718 })?
1719 } else {
1720 mask_coeff.shape[1]
1721 };
1722 let proto_channels = if !protos.dshape.is_empty() {
1723 Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1724 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1725 })?
1726 } else {
1727 protos.shape[1].min(protos.shape[3])
1728 };
1729
1730 if boxes_num != scores_num {
1731 return Err(DecoderError::InvalidConfig(format!(
1732 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1733 boxes_num, scores_num
1734 )));
1735 }
1736
1737 if boxes_num != mask_num {
1738 return Err(DecoderError::InvalidConfig(format!(
1739 "Yolo Split Detection Boxes num {} incompatible with Mask Coefficients num {}",
1740 boxes_num, mask_num
1741 )));
1742 }
1743
1744 if proto_channels != mask_channels {
1745 return Err(DecoderError::InvalidConfig(format!(
1746 "Yolo Protos channels {} incompatible with Mask Coefficients channels {}",
1747 proto_channels, mask_channels
1748 )));
1749 }
1750
1751 Ok(())
1752 }
1753
1754 fn verify_yolo_split_end_to_end_det(
1755 boxes: &configs::Boxes,
1756 scores: &configs::Scores,
1757 classes: &configs::Classes,
1758 ) -> Result<(), DecoderError> {
1759 if boxes.shape.len() != 3 || !boxes.shape.contains(&4) {
1760 return Err(DecoderError::InvalidConfig(format!(
1761 "Split end-to-end boxes must be [batch, N, 4], got {:?}",
1762 boxes.shape
1763 )));
1764 }
1765 if scores.shape.len() != 3 || !scores.shape.contains(&1) {
1766 return Err(DecoderError::InvalidConfig(format!(
1767 "Split end-to-end scores must be [batch, N, 1], got {:?}",
1768 scores.shape
1769 )));
1770 }
1771 if classes.shape.len() != 3 || !classes.shape.contains(&1) {
1772 return Err(DecoderError::InvalidConfig(format!(
1773 "Split end-to-end classes must be [batch, N, 1], got {:?}",
1774 classes.shape
1775 )));
1776 }
1777 Ok(())
1778 }
1779
1780 fn verify_yolo_split_end_to_end_segdet(
1781 boxes: &configs::Boxes,
1782 scores: &configs::Scores,
1783 classes: &configs::Classes,
1784 mask_coeff: &configs::MaskCoefficients,
1785 protos: &configs::Protos,
1786 ) -> Result<(), DecoderError> {
1787 Self::verify_yolo_split_end_to_end_det(boxes, scores, classes)?;
1788 if mask_coeff.shape.len() != 3 {
1789 return Err(DecoderError::InvalidConfig(format!(
1790 "Invalid split end-to-end mask coefficients shape {:?}",
1791 mask_coeff.shape
1792 )));
1793 }
1794 if protos.shape.len() != 4 {
1795 return Err(DecoderError::InvalidConfig(format!(
1796 "Invalid protos shape {:?}",
1797 protos.shape
1798 )));
1799 }
1800 Ok(())
1801 }
1802
1803 fn get_model_type_modelpack(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1804 let mut split_decoders = Vec::new();
1805 let mut segment_ = None;
1806 let mut scores_ = None;
1807 let mut boxes_ = None;
1808 for c in configs.outputs {
1809 match c {
1810 ConfigOutput::Detection(detection) => split_decoders.push(detection),
1811 ConfigOutput::Segmentation(segmentation) => segment_ = Some(segmentation),
1812 ConfigOutput::Mask(_) => {}
1813 ConfigOutput::Protos(_) => {
1814 return Err(DecoderError::InvalidConfig(
1815 "ModelPack should not have protos".to_string(),
1816 ));
1817 }
1818 ConfigOutput::Scores(scores) => scores_ = Some(scores),
1819 ConfigOutput::Boxes(boxes) => boxes_ = Some(boxes),
1820 ConfigOutput::MaskCoefficients(_) => {
1821 return Err(DecoderError::InvalidConfig(
1822 "ModelPack should not have mask coefficients".to_string(),
1823 ));
1824 }
1825 ConfigOutput::Classes(_) => {
1826 return Err(DecoderError::InvalidConfig(
1827 "ModelPack should not have classes output".to_string(),
1828 ));
1829 }
1830 }
1831 }
1832
1833 if let Some(segmentation) = segment_ {
1834 if !split_decoders.is_empty() {
1835 let classes = Self::verify_modelpack_split_det(&split_decoders)?;
1836 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1837 Ok(ModelType::ModelPackSegDetSplit {
1838 detection: split_decoders,
1839 segmentation,
1840 })
1841 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1842 let classes = Self::verify_modelpack_det(&boxes, &scores)?;
1843 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1844 Ok(ModelType::ModelPackSegDet {
1845 boxes,
1846 scores,
1847 segmentation,
1848 })
1849 } else {
1850 Self::verify_modelpack_seg(&segmentation, None)?;
1851 Ok(ModelType::ModelPackSeg { segmentation })
1852 }
1853 } else if !split_decoders.is_empty() {
1854 Self::verify_modelpack_split_det(&split_decoders)?;
1855 Ok(ModelType::ModelPackDetSplit {
1856 detection: split_decoders,
1857 })
1858 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1859 Self::verify_modelpack_det(&boxes, &scores)?;
1860 Ok(ModelType::ModelPackDet { boxes, scores })
1861 } else {
1862 Err(DecoderError::InvalidConfig(
1863 "Invalid ModelPack model outputs".to_string(),
1864 ))
1865 }
1866 }
1867
1868 fn verify_modelpack_det(
1869 boxes: &configs::Boxes,
1870 scores: &configs::Scores,
1871 ) -> Result<usize, DecoderError> {
1872 if boxes.shape.len() != 4 {
1873 return Err(DecoderError::InvalidConfig(format!(
1874 "Invalid ModelPack Boxes shape {:?}",
1875 boxes.shape
1876 )));
1877 }
1878 if scores.shape.len() != 3 {
1879 return Err(DecoderError::InvalidConfig(format!(
1880 "Invalid ModelPack Scores shape {:?}",
1881 scores.shape
1882 )));
1883 }
1884
1885 Self::verify_dshapes(
1886 &boxes.dshape,
1887 &boxes.shape,
1888 "Boxes",
1889 &[
1890 DimName::Batch,
1891 DimName::NumBoxes,
1892 DimName::Padding,
1893 DimName::BoxCoords,
1894 ],
1895 )?;
1896 Self::verify_dshapes(
1897 &scores.dshape,
1898 &scores.shape,
1899 "Scores",
1900 &[DimName::Batch, DimName::NumBoxes, DimName::NumClasses],
1901 )?;
1902
1903 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[1]);
1904 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[1]);
1905
1906 if boxes_num != scores_num {
1907 return Err(DecoderError::InvalidConfig(format!(
1908 "ModelPack Detection Boxes num {} incompatible with Scores num {}",
1909 boxes_num, scores_num
1910 )));
1911 }
1912
1913 let num_classes = if !scores.dshape.is_empty() {
1914 Self::get_class_count(&scores.dshape, None, None)?
1915 } else {
1916 Self::get_class_count_no_dshape(scores.into(), None)?
1917 };
1918
1919 Ok(num_classes)
1920 }
1921
1922 fn verify_modelpack_split_det(boxes: &[configs::Detection]) -> Result<usize, DecoderError> {
1923 let mut num_classes = None;
1924 for b in boxes {
1925 let Some(num_anchors) = b.anchors.as_ref().map(|a| a.len()) else {
1926 return Err(DecoderError::InvalidConfig(
1927 "ModelPack Split Detection missing anchors".to_string(),
1928 ));
1929 };
1930
1931 if num_anchors == 0 {
1932 return Err(DecoderError::InvalidConfig(
1933 "ModelPack Split Detection has zero anchors".to_string(),
1934 ));
1935 }
1936
1937 if b.shape.len() != 4 {
1938 return Err(DecoderError::InvalidConfig(format!(
1939 "Invalid ModelPack Split Detection shape {:?}",
1940 b.shape
1941 )));
1942 }
1943
1944 Self::verify_dshapes(
1945 &b.dshape,
1946 &b.shape,
1947 "Split Detection",
1948 &[
1949 DimName::Batch,
1950 DimName::Height,
1951 DimName::Width,
1952 DimName::NumAnchorsXFeatures,
1953 ],
1954 )?;
1955 let classes = if !b.dshape.is_empty() {
1956 Self::get_class_count(&b.dshape, None, Some(num_anchors))?
1957 } else {
1958 Self::get_class_count_no_dshape(b.into(), None)?
1959 };
1960
1961 match num_classes {
1962 Some(n) => {
1963 if n != classes {
1964 return Err(DecoderError::InvalidConfig(format!(
1965 "ModelPack Split Detection inconsistent number of classes: previous {}, current {}",
1966 n, classes
1967 )));
1968 }
1969 }
1970 None => {
1971 num_classes = Some(classes);
1972 }
1973 }
1974 }
1975
1976 Ok(num_classes.unwrap_or(0))
1977 }
1978
1979 fn verify_modelpack_seg(
1980 segmentation: &configs::Segmentation,
1981 classes: Option<usize>,
1982 ) -> Result<(), DecoderError> {
1983 if segmentation.shape.len() != 4 {
1984 return Err(DecoderError::InvalidConfig(format!(
1985 "Invalid ModelPack Segmentation shape {:?}",
1986 segmentation.shape
1987 )));
1988 }
1989 Self::verify_dshapes(
1990 &segmentation.dshape,
1991 &segmentation.shape,
1992 "Segmentation",
1993 &[
1994 DimName::Batch,
1995 DimName::Height,
1996 DimName::Width,
1997 DimName::NumClasses,
1998 ],
1999 )?;
2000
2001 if let Some(classes) = classes {
2002 let seg_classes = if !segmentation.dshape.is_empty() {
2003 Self::get_class_count(&segmentation.dshape, None, None)?
2004 } else {
2005 Self::get_class_count_no_dshape(segmentation.into(), None)?
2006 };
2007
2008 if seg_classes != classes + 1 {
2009 return Err(DecoderError::InvalidConfig(format!(
2010 "ModelPack Segmentation channels {} incompatible with number of classes {}",
2011 seg_classes, classes
2012 )));
2013 }
2014 }
2015 Ok(())
2016 }
2017
2018 // verifies that dshapes match the given shape
2019 fn verify_dshapes(
2020 dshape: &[(DimName, usize)],
2021 shape: &[usize],
2022 name: &str,
2023 dims: &[DimName],
2024 ) -> Result<(), DecoderError> {
2025 for s in shape {
2026 if *s == 0 {
2027 return Err(DecoderError::InvalidConfig(format!(
2028 "{} shape has zero dimension",
2029 name
2030 )));
2031 }
2032 }
2033
2034 if shape.len() != dims.len() {
2035 return Err(DecoderError::InvalidConfig(format!(
2036 "{} shape length {} does not match expected dims length {}",
2037 name,
2038 shape.len(),
2039 dims.len()
2040 )));
2041 }
2042
2043 if dshape.is_empty() {
2044 return Ok(());
2045 }
2046 // check the dshape lengths match the shape lengths
2047 if dshape.len() != shape.len() {
2048 return Err(DecoderError::InvalidConfig(format!(
2049 "{} dshape length does not match shape length",
2050 name
2051 )));
2052 }
2053
2054 // check the dshape values match the shape values
2055 for ((dim_name, dim_size), shape_size) in dshape.iter().zip(shape) {
2056 if dim_size != shape_size {
2057 return Err(DecoderError::InvalidConfig(format!(
2058 "{} dshape dimension {} size {} does not match shape size {}",
2059 name, dim_name, dim_size, shape_size
2060 )));
2061 }
2062 if *dim_name == DimName::Padding && *dim_size != 1 {
2063 return Err(DecoderError::InvalidConfig(
2064 "Padding dimension size must be 1".to_string(),
2065 ));
2066 }
2067
2068 if *dim_name == DimName::BoxCoords && *dim_size != 4 {
2069 return Err(DecoderError::InvalidConfig(
2070 "BoxCoords dimension size must be 4".to_string(),
2071 ));
2072 }
2073 }
2074
2075 let dims_present = HashSet::<DimName>::from_iter(dshape.iter().map(|(name, _)| *name));
2076 for dim in dims {
2077 if !dims_present.contains(dim) {
2078 return Err(DecoderError::InvalidConfig(format!(
2079 "{} dshape missing required dimension {:?}",
2080 name, dim
2081 )));
2082 }
2083 }
2084
2085 Ok(())
2086 }
2087
2088 fn get_box_count(dshape: &[(DimName, usize)]) -> Option<usize> {
2089 for (dim_name, dim_size) in dshape {
2090 if *dim_name == DimName::NumBoxes {
2091 return Some(*dim_size);
2092 }
2093 }
2094 None
2095 }
2096
2097 fn get_class_count_no_dshape(
2098 config: ConfigOutputRef,
2099 protos: Option<usize>,
2100 ) -> Result<usize, DecoderError> {
2101 match config {
2102 ConfigOutputRef::Detection(detection) => match detection.decoder {
2103 DecoderType::Ultralytics => {
2104 if detection.shape[1] <= 4 + protos.unwrap_or(0) {
2105 return Err(DecoderError::InvalidConfig(format!(
2106 "Invalid shape: Yolo num_features {} must be greater than {}",
2107 detection.shape[1],
2108 4 + protos.unwrap_or(0),
2109 )));
2110 }
2111 Ok(detection.shape[1] - 4 - protos.unwrap_or(0))
2112 }
2113 DecoderType::ModelPack => {
2114 let Some(num_anchors) = detection.anchors.as_ref().map(|a| a.len()) else {
2115 return Err(DecoderError::Internal(
2116 "ModelPack Detection missing anchors".to_string(),
2117 ));
2118 };
2119 let anchors_x_features = detection.shape[3];
2120 if anchors_x_features <= num_anchors * 5 {
2121 return Err(DecoderError::InvalidConfig(format!(
2122 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2123 anchors_x_features,
2124 num_anchors * 5,
2125 )));
2126 }
2127
2128 if !anchors_x_features.is_multiple_of(num_anchors) {
2129 return Err(DecoderError::InvalidConfig(format!(
2130 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2131 anchors_x_features, num_anchors
2132 )));
2133 }
2134 Ok(anchors_x_features / num_anchors - 5)
2135 }
2136 },
2137
2138 ConfigOutputRef::Scores(scores) => match scores.decoder {
2139 DecoderType::Ultralytics => Ok(scores.shape[1]),
2140 DecoderType::ModelPack => Ok(scores.shape[2]),
2141 },
2142 ConfigOutputRef::Segmentation(seg) => Ok(seg.shape[3]),
2143 _ => Err(DecoderError::Internal(
2144 "Attempted to get class count from unsupported config output".to_owned(),
2145 )),
2146 }
2147 }
2148
2149 // get the class count from dshape or calculate from num_features
2150 fn get_class_count(
2151 dshape: &[(DimName, usize)],
2152 protos: Option<usize>,
2153 anchors: Option<usize>,
2154 ) -> Result<usize, DecoderError> {
2155 if dshape.is_empty() {
2156 return Ok(0);
2157 }
2158 // if it has num_classes in dshape, return it
2159 for (dim_name, dim_size) in dshape {
2160 if *dim_name == DimName::NumClasses {
2161 return Ok(*dim_size);
2162 }
2163 }
2164
2165 // number of classes can be calculated from num_features - 4 for yolo. If the
2166 // model has protos, we also subtract the number of protos.
2167 for (dim_name, dim_size) in dshape {
2168 if *dim_name == DimName::NumFeatures {
2169 let protos = protos.unwrap_or(0);
2170 if protos + 4 >= *dim_size {
2171 return Err(DecoderError::InvalidConfig(format!(
2172 "Invalid shape: Yolo num_features {} must be greater than {}",
2173 *dim_size,
2174 protos + 4,
2175 )));
2176 }
2177 return Ok(*dim_size - 4 - protos);
2178 }
2179 }
2180
2181 // number of classes can be calculated from number of anchors for modelpack
2182 // split detection
2183 if let Some(num_anchors) = anchors {
2184 for (dim_name, dim_size) in dshape {
2185 if *dim_name == DimName::NumAnchorsXFeatures {
2186 let anchors_x_features = *dim_size;
2187 if anchors_x_features <= num_anchors * 5 {
2188 return Err(DecoderError::InvalidConfig(format!(
2189 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2190 anchors_x_features,
2191 num_anchors * 5,
2192 )));
2193 }
2194
2195 if !anchors_x_features.is_multiple_of(num_anchors) {
2196 return Err(DecoderError::InvalidConfig(format!(
2197 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2198 anchors_x_features, num_anchors
2199 )));
2200 }
2201 return Ok((anchors_x_features / num_anchors) - 5);
2202 }
2203 }
2204 }
2205 Err(DecoderError::InvalidConfig(
2206 "Cannot determine number of classes from dshape".to_owned(),
2207 ))
2208 }
2209
2210 fn get_protos_count(dshape: &[(DimName, usize)]) -> Option<usize> {
2211 for (dim_name, dim_size) in dshape {
2212 if *dim_name == DimName::NumProtos {
2213 return Some(*dim_size);
2214 }
2215 }
2216 None
2217 }
2218}