edgefirst_decoder/decoder/builder.rs
1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashSet;
5
6use super::config::ConfigOutputRef;
7use super::configs::{self, DecoderType, DecoderVersion, DimName, ModelType};
8use super::merge::DecodeProgram;
9use super::{ConfigOutput, ConfigOutputs, Decoder};
10use crate::per_scale::{DecodeDtype, PerScalePlan};
11use crate::schema::SchemaV2;
12use crate::DecoderError;
13
14/// Extract `(width, height)` from a schema [`crate::schema::InputSpec`].
15///
16/// Prefers named dims (`DimName::Width` / `DimName::Height`) when the
17/// `dshape` is populated, falling back to the NHWC convention
18/// (`shape[1] = H, shape[2] = W`) for 4-element shapes whenever either
19/// named dim is missing. Returns `None` for any other shape arity — the
20/// decoder treats that as "input dims unknown" and skips the EDGEAI-1303
21/// normalization path.
22///
23/// The fallback fires when **either** `Height` or `Width` is missing from
24/// the dshape (not only when both are absent), so a partially-named
25/// dshape (e.g. only `Width`) still resolves both dims via positional
26/// inference instead of silently disabling normalization.
27fn input_dims_from_spec(input: &crate::schema::InputSpec) -> Option<(usize, usize)> {
28 use crate::configs::DimName;
29 let mut h = None;
30 let mut w = None;
31 // `SchemaV2::validate()` doesn't currently enforce
32 // `dshape.len() <= shape.len()`, so a malformed schema can trip an
33 // out-of-bounds index here. Use `shape.get(i)` to silently skip
34 // dshape entries that overflow the shape — the caller treats
35 // missing dims as "unknown" and disables EDGEAI-1303 normalization.
36 for (i, (name, _)) in input.dshape.iter().enumerate() {
37 match name {
38 DimName::Height => h = input.shape.get(i).copied(),
39 DimName::Width => w = input.shape.get(i).copied(),
40 _ => {}
41 }
42 }
43 if (h.is_none() || w.is_none()) && input.shape.len() == 4 {
44 // NHWC default: [N, H, W, C]. Mirrors the per-scale `extract_hw`
45 // fallback (`crates/decoder/src/per_scale/plan.rs::extract_hw`).
46 // Only fill the missing axis so a partial named dshape still
47 // resolves both dims.
48 h = h.or(Some(input.shape[1]));
49 w = w.or(Some(input.shape[2]));
50 }
51 match (w, h) {
52 (Some(w), Some(h)) => Some((w, h)),
53 _ => None,
54 }
55}
56
#[cfg(test)]
mod input_dims_from_spec_tests {
    use super::input_dims_from_spec;
    use crate::configs::DimName;
    use crate::schema::InputSpec;

    #[test]
    fn named_dshape_resolves_dims() {
        let spec = InputSpec {
            shape: vec![1, 480, 640, 3],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::Height, 480),
                (DimName::Width, 640),
                (DimName::NumFeatures, 3),
            ],
            cameraadaptor: None,
        };
        assert_eq!(input_dims_from_spec(&spec), Some((640, 480)));
    }

    #[test]
    fn empty_dshape_falls_back_to_nhwc_for_4d() {
        let spec = InputSpec {
            shape: vec![1, 480, 640, 3],
            dshape: vec![],
            cameraadaptor: None,
        };
        assert_eq!(input_dims_from_spec(&spec), Some((640, 480)));
    }

    #[test]
    fn partial_named_dshape_fills_missing_axis_from_nhwc() {
        // Only `Width` is named; `Height` must come from the NHWC
        // positional fallback (`shape[1]`) instead of normalization being
        // disabled. The `NumFeatures` entry is just a non-H/W placeholder.
        let spec = InputSpec {
            shape: vec![1, 480, 640, 3],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumFeatures, 480),
                (DimName::Width, 640),
            ],
            cameraadaptor: None,
        };
        assert_eq!(input_dims_from_spec(&spec), Some((640, 480)));
    }

    #[test]
    fn unnamed_non_4d_shape_returns_none() {
        // No named dims and the shape isn't 4-D, so the NHWC fallback
        // cannot fire and the dims are reported as unknown.
        let spec = InputSpec {
            shape: vec![1, 480, 640],
            dshape: vec![],
            cameraadaptor: None,
        };
        assert_eq!(input_dims_from_spec(&spec), None);
    }

    #[test]
    fn malformed_dshape_longer_than_shape_does_not_panic() {
        // Regression for Copilot review on PR #63: indexing
        // `input.shape[i]` while iterating dshape can OOB-panic when
        // `dshape.len() > shape.len()`. Overflowing entries must be
        // treated as "dim missing" rather than panicking.
        let spec = InputSpec {
            shape: vec![640, 480], // 2-D shape
            dshape: vec![
                (DimName::Width, 640),
                (DimName::Height, 480),
                (DimName::NumFeatures, 3), // overflow — index 2 ≥ shape.len()
            ],
            cameraadaptor: None,
        };
        // The first two dshape entries are in bounds and resolve both
        // dims; the third is ignored. Width/Height both resolved, so we
        // expect Some even though the shape is not 4-D.
        assert_eq!(input_dims_from_spec(&spec), Some((640, 480)));
    }

    #[test]
    fn malformed_dshape_only_overflow_returns_none() {
        // The Height and Width entries both sit past the shape boundary
        // (only the leading NumFeatures entry is in bounds), so width and
        // height stay None and the 4-D NHWC fallback doesn't fire
        // (shape.len() == 1). We get None instead of a panic.
        let spec = InputSpec {
            shape: vec![1],
            dshape: vec![
                (DimName::NumFeatures, 3),
                (DimName::Height, 480),
                (DimName::Width, 640),
            ],
            cameraadaptor: None,
        };
        assert_eq!(input_dims_from_spec(&spec), None);
    }
}
125
/// Builder that collects a model configuration source and post-processing
/// parameters, then produces a decoder via `build()`.
///
/// Start from [`DecoderBuilder::new`] (or `Default`), supply a configuration
/// with one of the `with_config*` / `with_schema` / `add_output` methods,
/// tune thresholds and limits with the remaining `with_*` setters, and
/// finish with `build()`.
#[derive(Debug, Clone, PartialEq)]
pub struct DecoderBuilder {
    /// Where the model configuration comes from. `None` until one of the
    /// configuration methods sets it; a valid configuration must be
    /// provided before building. Each configuration method replaces any
    /// previously stored source.
    config_src: Option<ConfigSource>,
    /// IoU threshold used during NMS (default `0.5`). Has no effect when
    /// `nms` is `None`.
    iou_threshold: f32,
    /// Score threshold applied to candidate detections (default `0.5`).
    score_threshold: f32,
    /// NMS mode: Some(mode) applies NMS, None bypasses NMS (for end-to-end
    /// models)
    nms: Option<configs::Nms>,
    /// Output dtype for the per-scale fast path. Has no effect on
    /// schemas without per-scale children (which use the legacy decode
    /// path).
    decode_dtype: DecodeDtype,
    /// Maximum number of score-filtered candidates fed into NMS
    /// (default `300`).
    pre_nms_top_k: usize,
    /// Maximum number of detections returned after NMS (default `300`).
    max_det: usize,
    /// Explicit override for the model input dimensions `(width, height)`,
    /// consumed by EDGEAI-1303 normalization. When set, takes precedence
    /// over schema-derived dims; when `None`, the value is read from the
    /// schema's `input.shape` / `input.dshape` at build time.
    input_dims: Option<(usize, usize)>,
}
146
147#[derive(Debug, Clone, PartialEq)]
148enum ConfigSource {
149 Yaml(String),
150 Json(String),
151 Config(ConfigOutputs),
152 /// Schema v2 metadata. During build the schema is either converted
153 /// to a legacy [`ConfigOutputs`] (flat case) or used to construct a
154 /// [`DecodeProgram`] that performs per-child dequant + merge at
155 /// decode time.
156 Schema(SchemaV2),
157}
158
159impl Default for DecoderBuilder {
160 /// Creates a default DecoderBuilder with no configuration and 0.5 score
161 /// threshold and 0.5 IoU threshold.
162 ///
163 /// A valid configuration must be provided before building the Decoder.
164 ///
165 /// # Examples
166 /// ```rust
167 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
168 /// # fn main() -> DecoderResult<()> {
169 /// # let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
170 /// let decoder = DecoderBuilder::default()
171 /// .with_config_yaml_str(config_yaml)
172 /// .build()?;
173 /// assert_eq!(decoder.score_threshold, 0.5);
174 /// assert_eq!(decoder.iou_threshold, 0.5);
175 ///
176 /// # Ok(())
177 /// # }
178 /// ```
179 fn default() -> Self {
180 Self {
181 config_src: None,
182 iou_threshold: 0.5,
183 score_threshold: 0.5,
184 nms: Some(configs::Nms::ClassAgnostic),
185 decode_dtype: DecodeDtype::F32,
186 pre_nms_top_k: 300,
187 max_det: 300,
188 input_dims: None,
189 }
190 }
191}
192
193impl DecoderBuilder {
194 /// Creates a default DecoderBuilder with no configuration and 0.5 score
195 /// threshold and 0.5 IoU threshold.
196 ///
197 /// A valid configuration must be provided before building the Decoder.
198 ///
199 /// # Examples
200 /// ```rust
201 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
202 /// # fn main() -> DecoderResult<()> {
203 /// # let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
204 /// let decoder = DecoderBuilder::new()
205 /// .with_config_yaml_str(config_yaml)
206 /// .build()?;
207 /// assert_eq!(decoder.score_threshold, 0.5);
208 /// assert_eq!(decoder.iou_threshold, 0.5);
209 ///
210 /// # Ok(())
211 /// # }
212 /// ```
213 pub fn new() -> Self {
214 Self::default()
215 }
216
217 /// Loads a model configuration in YAML format. Does not check if the string
218 /// is a correct configuration file. Use `DecoderBuilder.build()` to
219 /// deserialize the YAML and parse the model configuration.
220 ///
221 /// # Examples
222 /// ```rust
223 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
224 /// # fn main() -> DecoderResult<()> {
225 /// let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
226 /// let decoder = DecoderBuilder::new()
227 /// .with_config_yaml_str(config_yaml)
228 /// .build()?;
229 ///
230 /// # Ok(())
231 /// # }
232 /// ```
233 pub fn with_config_yaml_str(mut self, yaml_str: String) -> Self {
234 self.config_src.replace(ConfigSource::Yaml(yaml_str));
235 self
236 }
237
238 /// Loads a model configuration in JSON format. Does not check if the string
239 /// is a correct configuration file. Use `DecoderBuilder.build()` to
240 /// deserialize the JSON and parse the model configuration.
241 ///
242 /// # Examples
243 /// ```rust
244 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
245 /// # fn main() -> DecoderResult<()> {
246 /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
247 /// let decoder = DecoderBuilder::new()
248 /// .with_config_json_str(config_json)
249 /// .build()?;
250 ///
251 /// # Ok(())
252 /// # }
253 /// ```
254 pub fn with_config_json_str(mut self, json_str: String) -> Self {
255 self.config_src.replace(ConfigSource::Json(json_str));
256 self
257 }
258
259 /// Loads a model configuration. Does not check if the configuration is
    /// correct. Intended for callers that need control over how the
    /// configuration is deserialized. Use
    /// `DecoderBuilder::build()` to parse the model configuration.
263 ///
264 /// # Examples
265 /// ```rust
266 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
267 /// # fn main() -> DecoderResult<()> {
268 /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json"));
269 /// let config = serde_json::from_str(config_json)?;
270 /// let decoder = DecoderBuilder::new().with_config(config).build()?;
271 ///
272 /// # Ok(())
273 /// # }
274 /// ```
275 pub fn with_config(mut self, config: ConfigOutputs) -> Self {
276 self.config_src.replace(ConfigSource::Config(config));
277 self
278 }
279
280 /// Configure the decoder from a schema v2 metadata document.
281 ///
282 /// Accepts a [`SchemaV2`] as produced by [`SchemaV2::parse_json`],
283 /// [`SchemaV2::parse_yaml`], [`SchemaV2::parse_file`], or
284 /// constructed programmatically. The builder validates the schema,
285 /// compiles a [`DecodeProgram`] for any split logical outputs
286 /// (per-scale or channel sub-splits), and downconverts the
287 /// logical-level semantics to the legacy [`ConfigOutputs`]
288 /// representation consumed by the existing decoder dispatch.
289 ///
290 /// # Examples
291 ///
292 /// ```rust,no_run
293 /// use edgefirst_decoder::{DecoderBuilder, DecoderResult};
294 /// use edgefirst_decoder::schema::SchemaV2;
295 ///
296 /// # fn main() -> DecoderResult<()> {
297 /// let schema = SchemaV2::parse_file("model/edgefirst.json")?;
298 /// let decoder = DecoderBuilder::new()
299 /// .with_schema(schema)
300 /// .with_score_threshold(0.25)
301 /// .build()?;
302 /// # Ok(())
303 /// # }
304 /// ```
305 pub fn with_schema(mut self, schema: SchemaV2) -> Self {
306 self.config_src.replace(ConfigSource::Schema(schema));
307 self
308 }
309
310 /// Choose the output dtype for the per-scale decoder pipeline.
311 ///
312 /// Defaults to [`DecodeDtype::F32`]. Has no effect on schemas
313 /// without per-scale children (which use the legacy decode path).
314 /// `F16` saves ~2× memory bandwidth at the cost of 10-bit mantissa
315 /// precision; empirically safe for YOLO-family models.
316 ///
317 /// # Examples
318 ///
319 /// ```rust,no_run
320 /// use edgefirst_decoder::{DecodeDtype, DecoderBuilder, DecoderResult};
321 /// use edgefirst_decoder::schema::SchemaV2;
322 ///
323 /// # fn main() -> DecoderResult<()> {
324 /// let schema = SchemaV2::parse_file("model/edgefirst.json")?;
325 /// let decoder = DecoderBuilder::new()
326 /// .with_schema(schema)
327 /// .with_decode_dtype(DecodeDtype::F32)
328 /// .build()?;
329 /// # Ok(())
330 /// # }
331 /// ```
332 pub fn with_decode_dtype(mut self, dtype: DecodeDtype) -> Self {
333 self.decode_dtype = dtype;
334 self
335 }
336
337 /// Loads a YOLO detection model configuration. Use
338 /// `DecoderBuilder.build()` to parse the model configuration.
339 ///
340 /// # Examples
341 /// ```rust
342 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
343 /// # fn main() -> DecoderResult<()> {
344 /// let decoder = DecoderBuilder::new()
345 /// .with_config_yolo_det(
346 /// configs::Detection {
347 /// anchors: None,
348 /// decoder: configs::DecoderType::Ultralytics,
349 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
350 /// shape: vec![1, 84, 8400],
351 /// dshape: Vec::new(),
352 /// normalized: Some(true),
353 /// },
354 /// None,
355 /// )
356 /// .build()?;
357 ///
358 /// # Ok(())
359 /// # }
360 /// ```
361 pub fn with_config_yolo_det(
362 mut self,
363 boxes: configs::Detection,
364 version: Option<DecoderVersion>,
365 ) -> Self {
366 let config = ConfigOutputs {
367 outputs: vec![ConfigOutput::Detection(boxes)],
368 decoder_version: version,
369 ..Default::default()
370 };
371 self.config_src.replace(ConfigSource::Config(config));
372 self
373 }
374
375 /// Loads a YOLO split detection model configuration. Use
376 /// `DecoderBuilder.build()` to parse the model configuration.
377 ///
378 /// # Examples
379 /// ```rust
380 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
381 /// # fn main() -> DecoderResult<()> {
382 /// let boxes_config = configs::Boxes {
383 /// decoder: configs::DecoderType::Ultralytics,
384 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
385 /// shape: vec![1, 4, 8400],
386 /// dshape: Vec::new(),
387 /// normalized: Some(true),
388 /// };
389 /// let scores_config = configs::Scores {
390 /// decoder: configs::DecoderType::Ultralytics,
391 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
392 /// shape: vec![1, 80, 8400],
393 /// dshape: Vec::new(),
394 /// };
395 /// let decoder = DecoderBuilder::new()
396 /// .with_config_yolo_split_det(boxes_config, scores_config)
397 /// .build()?;
398 /// # Ok(())
399 /// # }
400 /// ```
401 pub fn with_config_yolo_split_det(
402 mut self,
403 boxes: configs::Boxes,
404 scores: configs::Scores,
405 ) -> Self {
406 let config = ConfigOutputs {
407 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
408 ..Default::default()
409 };
410 self.config_src.replace(ConfigSource::Config(config));
411 self
412 }
413
414 /// Loads a YOLO segmentation model configuration. Use
415 /// `DecoderBuilder.build()` to parse the model configuration.
416 ///
417 /// # Examples
418 /// ```rust
419 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
420 /// # fn main() -> DecoderResult<()> {
421 /// let seg_config = configs::Detection {
422 /// decoder: configs::DecoderType::Ultralytics,
423 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
424 /// shape: vec![1, 116, 8400],
425 /// anchors: None,
426 /// dshape: Vec::new(),
427 /// normalized: Some(true),
428 /// };
429 /// let protos_config = configs::Protos {
430 /// decoder: configs::DecoderType::Ultralytics,
431 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
432 /// shape: vec![1, 160, 160, 32],
433 /// dshape: Vec::new(),
434 /// };
435 /// let decoder = DecoderBuilder::new()
436 /// .with_config_yolo_segdet(
437 /// seg_config,
438 /// protos_config,
439 /// Some(configs::DecoderVersion::Yolov8),
440 /// )
441 /// .build()?;
442 /// # Ok(())
443 /// # }
444 /// ```
445 pub fn with_config_yolo_segdet(
446 mut self,
447 boxes: configs::Detection,
448 protos: configs::Protos,
449 version: Option<DecoderVersion>,
450 ) -> Self {
451 let config = ConfigOutputs {
452 outputs: vec![ConfigOutput::Detection(boxes), ConfigOutput::Protos(protos)],
453 decoder_version: version,
454 ..Default::default()
455 };
456 self.config_src.replace(ConfigSource::Config(config));
457 self
458 }
459
460 /// Loads a YOLO split segmentation model configuration. Use
461 /// `DecoderBuilder.build()` to parse the model configuration.
462 ///
463 /// # Examples
464 /// ```rust
465 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
466 /// # fn main() -> DecoderResult<()> {
467 /// let boxes_config = configs::Boxes {
468 /// decoder: configs::DecoderType::Ultralytics,
469 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
470 /// shape: vec![1, 4, 8400],
471 /// dshape: Vec::new(),
472 /// normalized: Some(true),
473 /// };
474 /// let scores_config = configs::Scores {
475 /// decoder: configs::DecoderType::Ultralytics,
476 /// quantization: Some(configs::QuantTuple(0.012345, 14)),
477 /// shape: vec![1, 80, 8400],
478 /// dshape: Vec::new(),
479 /// };
480 /// let mask_config = configs::MaskCoefficients {
481 /// decoder: configs::DecoderType::Ultralytics,
482 /// quantization: Some(configs::QuantTuple(0.0064123, 125)),
483 /// shape: vec![1, 32, 8400],
484 /// dshape: Vec::new(),
485 /// };
486 /// let protos_config = configs::Protos {
487 /// decoder: configs::DecoderType::Ultralytics,
488 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
489 /// shape: vec![1, 160, 160, 32],
490 /// dshape: Vec::new(),
491 /// };
492 /// let decoder = DecoderBuilder::new()
493 /// .with_config_yolo_split_segdet(boxes_config, scores_config, mask_config, protos_config)
494 /// .build()?;
495 /// # Ok(())
496 /// # }
497 /// ```
498 pub fn with_config_yolo_split_segdet(
499 mut self,
500 boxes: configs::Boxes,
501 scores: configs::Scores,
502 mask_coefficients: configs::MaskCoefficients,
503 protos: configs::Protos,
504 ) -> Self {
505 let config = ConfigOutputs {
506 outputs: vec![
507 ConfigOutput::Boxes(boxes),
508 ConfigOutput::Scores(scores),
509 ConfigOutput::MaskCoefficients(mask_coefficients),
510 ConfigOutput::Protos(protos),
511 ],
512 ..Default::default()
513 };
514 self.config_src.replace(ConfigSource::Config(config));
515 self
516 }
517
518 /// Loads a ModelPack detection model configuration. Use
519 /// `DecoderBuilder.build()` to parse the model configuration.
520 ///
521 /// # Examples
522 /// ```rust
523 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
524 /// # fn main() -> DecoderResult<()> {
525 /// let boxes_config = configs::Boxes {
526 /// decoder: configs::DecoderType::ModelPack,
527 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
528 /// shape: vec![1, 8400, 1, 4],
529 /// dshape: Vec::new(),
530 /// normalized: Some(true),
531 /// };
532 /// let scores_config = configs::Scores {
533 /// decoder: configs::DecoderType::ModelPack,
534 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
535 /// shape: vec![1, 8400, 3],
536 /// dshape: Vec::new(),
537 /// };
538 /// let decoder = DecoderBuilder::new()
539 /// .with_config_modelpack_det(boxes_config, scores_config)
540 /// .build()?;
541 /// # Ok(())
542 /// # }
543 /// ```
544 pub fn with_config_modelpack_det(
545 mut self,
546 boxes: configs::Boxes,
547 scores: configs::Scores,
548 ) -> Self {
549 let config = ConfigOutputs {
550 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
551 ..Default::default()
552 };
553 self.config_src.replace(ConfigSource::Config(config));
554 self
555 }
556
557 /// Loads a ModelPack split detection model configuration. Use
558 /// `DecoderBuilder.build()` to parse the model configuration.
559 ///
560 /// # Examples
561 /// ```rust
562 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
563 /// # fn main() -> DecoderResult<()> {
564 /// let config0 = configs::Detection {
565 /// anchors: Some(vec![
566 /// [0.13750000298023224, 0.2074074000120163],
567 /// [0.2541666626930237, 0.21481481194496155],
568 /// [0.23125000298023224, 0.35185185074806213],
569 /// ]),
570 /// decoder: configs::DecoderType::ModelPack,
571 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
572 /// shape: vec![1, 17, 30, 18],
573 /// dshape: Vec::new(),
574 /// normalized: Some(true),
575 /// };
576 /// let config1 = configs::Detection {
577 /// anchors: Some(vec![
578 /// [0.36666667461395264, 0.31481480598449707],
579 /// [0.38749998807907104, 0.4740740656852722],
580 /// [0.5333333611488342, 0.644444465637207],
581 /// ]),
582 /// decoder: configs::DecoderType::ModelPack,
583 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
584 /// shape: vec![1, 9, 15, 18],
585 /// dshape: Vec::new(),
586 /// normalized: Some(true),
587 /// };
588 ///
589 /// let decoder = DecoderBuilder::new()
590 /// .with_config_modelpack_det_split(vec![config0, config1])
591 /// .build()?;
592 /// # Ok(())
593 /// # }
594 /// ```
595 pub fn with_config_modelpack_det_split(mut self, boxes: Vec<configs::Detection>) -> Self {
596 let outputs = boxes.into_iter().map(ConfigOutput::Detection).collect();
597 let config = ConfigOutputs {
598 outputs,
599 ..Default::default()
600 };
601 self.config_src.replace(ConfigSource::Config(config));
602 self
603 }
604
605 /// Loads a ModelPack segmentation detection model configuration. Use
606 /// `DecoderBuilder.build()` to parse the model configuration.
607 ///
608 /// # Examples
609 /// ```rust
610 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
611 /// # fn main() -> DecoderResult<()> {
612 /// let boxes_config = configs::Boxes {
613 /// decoder: configs::DecoderType::ModelPack,
614 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
615 /// shape: vec![1, 8400, 1, 4],
616 /// dshape: Vec::new(),
617 /// normalized: Some(true),
618 /// };
619 /// let scores_config = configs::Scores {
620 /// decoder: configs::DecoderType::ModelPack,
621 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
622 /// shape: vec![1, 8400, 2],
623 /// dshape: Vec::new(),
624 /// };
625 /// let seg_config = configs::Segmentation {
626 /// decoder: configs::DecoderType::ModelPack,
627 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
628 /// shape: vec![1, 640, 640, 3],
629 /// dshape: Vec::new(),
630 /// };
631 /// let decoder = DecoderBuilder::new()
632 /// .with_config_modelpack_segdet(boxes_config, scores_config, seg_config)
633 /// .build()?;
634 /// # Ok(())
635 /// # }
636 /// ```
637 pub fn with_config_modelpack_segdet(
638 mut self,
639 boxes: configs::Boxes,
640 scores: configs::Scores,
641 segmentation: configs::Segmentation,
642 ) -> Self {
643 let config = ConfigOutputs {
644 outputs: vec![
645 ConfigOutput::Boxes(boxes),
646 ConfigOutput::Scores(scores),
647 ConfigOutput::Segmentation(segmentation),
648 ],
649 ..Default::default()
650 };
651 self.config_src.replace(ConfigSource::Config(config));
652 self
653 }
654
655 /// Loads a ModelPack segmentation split detection model configuration. Use
656 /// `DecoderBuilder.build()` to parse the model configuration.
657 ///
658 /// # Examples
659 /// ```rust
660 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
661 /// # fn main() -> DecoderResult<()> {
662 /// let config0 = configs::Detection {
663 /// anchors: Some(vec![
664 /// [0.36666667461395264, 0.31481480598449707],
665 /// [0.38749998807907104, 0.4740740656852722],
666 /// [0.5333333611488342, 0.644444465637207],
667 /// ]),
668 /// decoder: configs::DecoderType::ModelPack,
669 /// quantization: Some(configs::QuantTuple(0.08547406643629074, 174)),
670 /// shape: vec![1, 9, 15, 18],
671 /// dshape: Vec::new(),
672 /// normalized: Some(true),
673 /// };
674 /// let config1 = configs::Detection {
675 /// anchors: Some(vec![
676 /// [0.13750000298023224, 0.2074074000120163],
677 /// [0.2541666626930237, 0.21481481194496155],
678 /// [0.23125000298023224, 0.35185185074806213],
679 /// ]),
680 /// decoder: configs::DecoderType::ModelPack,
681 /// quantization: Some(configs::QuantTuple(0.09929127991199493, 183)),
682 /// shape: vec![1, 17, 30, 18],
683 /// dshape: Vec::new(),
684 /// normalized: Some(true),
685 /// };
686 /// let seg_config = configs::Segmentation {
687 /// decoder: configs::DecoderType::ModelPack,
688 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
689 /// shape: vec![1, 640, 640, 2],
690 /// dshape: Vec::new(),
691 /// };
692 /// let decoder = DecoderBuilder::new()
693 /// .with_config_modelpack_segdet_split(vec![config0, config1], seg_config)
694 /// .build()?;
695 /// # Ok(())
696 /// # }
697 /// ```
698 pub fn with_config_modelpack_segdet_split(
699 mut self,
700 boxes: Vec<configs::Detection>,
701 segmentation: configs::Segmentation,
702 ) -> Self {
703 let mut outputs = boxes
704 .into_iter()
705 .map(ConfigOutput::Detection)
706 .collect::<Vec<_>>();
707 outputs.push(ConfigOutput::Segmentation(segmentation));
708 let config = ConfigOutputs {
709 outputs,
710 ..Default::default()
711 };
712 self.config_src.replace(ConfigSource::Config(config));
713 self
714 }
715
716 /// Loads a ModelPack segmentation model configuration. Use
717 /// `DecoderBuilder.build()` to parse the model configuration.
718 ///
719 /// # Examples
720 /// ```rust
721 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
722 /// # fn main() -> DecoderResult<()> {
723 /// let seg_config = configs::Segmentation {
724 /// decoder: configs::DecoderType::ModelPack,
725 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
726 /// shape: vec![1, 640, 640, 3],
727 /// dshape: Vec::new(),
728 /// };
729 /// let decoder = DecoderBuilder::new()
730 /// .with_config_modelpack_seg(seg_config)
731 /// .build()?;
732 /// # Ok(())
733 /// # }
734 /// ```
735 pub fn with_config_modelpack_seg(mut self, segmentation: configs::Segmentation) -> Self {
736 let config = ConfigOutputs {
737 outputs: vec![ConfigOutput::Segmentation(segmentation)],
738 ..Default::default()
739 };
740 self.config_src.replace(ConfigSource::Config(config));
741 self
742 }
743
744 /// Add an output to the decoder configuration.
745 ///
746 /// Incrementally builds the model configuration by adding outputs one at
747 /// a time. The decoder resolves the model type from the combination of
748 /// outputs during `build()`.
749 ///
750 /// If `dshape` is non-empty on the output, `shape` is automatically
751 /// derived from it (the size component of each named dimension). This
752 /// prevents conflicts between `shape` and `dshape`.
753 ///
    /// This uses the programmatic config path. Calling this after
    /// `with_config_json_str()`, `with_config_yaml_str()`, or `with_schema()`
    /// discards that config source and starts a new programmatic config.
757 ///
758 /// # Examples
759 /// ```rust
760 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
761 /// # fn main() -> DecoderResult<()> {
762 /// let decoder = DecoderBuilder::new()
763 /// .add_output(ConfigOutput::Scores(configs::Scores {
764 /// decoder: configs::DecoderType::Ultralytics,
765 /// dshape: vec![
766 /// (configs::DimName::Batch, 1),
767 /// (configs::DimName::NumClasses, 80),
768 /// (configs::DimName::NumBoxes, 8400),
769 /// ],
770 /// ..Default::default()
771 /// }))
772 /// .add_output(ConfigOutput::Boxes(configs::Boxes {
773 /// decoder: configs::DecoderType::Ultralytics,
774 /// dshape: vec![
775 /// (configs::DimName::Batch, 1),
776 /// (configs::DimName::BoxCoords, 4),
777 /// (configs::DimName::NumBoxes, 8400),
778 /// ],
779 /// ..Default::default()
780 /// }))
781 /// .build()?;
782 /// # Ok(())
783 /// # }
784 /// ```
785 pub fn add_output(mut self, output: ConfigOutput) -> Self {
786 if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
787 self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
788 }
789 if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
790 config.outputs.push(Self::normalize_output(output));
791 }
792 self
793 }
794
795 /// Sets the decoder version for Ultralytics models.
796 ///
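    /// This is used with `add_output()` to specify the YOLO architecture
    /// version when it cannot be inferred from the output shapes alone.
    ///
    /// Like `add_output()`, this switches the builder to the programmatic
    /// config path. A config previously supplied through
    /// `with_config_json_str()`, `with_config_yaml_str()`, or `with_schema()`
    /// is discarded and replaced by `ConfigOutputs::default()` carrying only
    /// this version. Configs set via `with_config()` or the `with_config_*`
    /// helpers are kept, and only their version is overwritten.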
799 ///
800 /// - `Yolov5`, `Yolov8`, `Yolo11`: Traditional models requiring external
801 /// NMS
802 /// - `Yolo26`: End-to-end models with NMS embedded in the model graph
803 ///
804 /// # Examples
805 /// ```rust
806 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
807 /// # fn main() -> DecoderResult<()> {
808 /// let decoder = DecoderBuilder::new()
809 /// .add_output(ConfigOutput::Detection(configs::Detection {
810 /// decoder: configs::DecoderType::Ultralytics,
811 /// dshape: vec![
812 /// (configs::DimName::Batch, 1),
813 /// (configs::DimName::NumBoxes, 100),
814 /// (configs::DimName::NumFeatures, 6),
815 /// ],
816 /// ..Default::default()
817 /// }))
818 /// .with_decoder_version(configs::DecoderVersion::Yolo26)
819 /// .build()?;
820 /// # Ok(())
821 /// # }
822 /// ```
823 pub fn with_decoder_version(mut self, version: configs::DecoderVersion) -> Self {
824 if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
825 self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
826 }
827 if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
828 config.decoder_version = Some(version);
829 }
830 self
831 }
832
833 /// Normalize an output: if dshape is non-empty, derive shape from it.
834 fn normalize_output(mut output: ConfigOutput) -> ConfigOutput {
835 fn normalize_shape(shape: &mut Vec<usize>, dshape: &[(configs::DimName, usize)]) {
836 if !dshape.is_empty() {
837 *shape = dshape.iter().map(|(_, size)| *size).collect();
838 }
839 }
840 match &mut output {
841 ConfigOutput::Detection(c) => normalize_shape(&mut c.shape, &c.dshape),
842 ConfigOutput::Boxes(c) => normalize_shape(&mut c.shape, &c.dshape),
843 ConfigOutput::Scores(c) => normalize_shape(&mut c.shape, &c.dshape),
844 ConfigOutput::Protos(c) => normalize_shape(&mut c.shape, &c.dshape),
845 ConfigOutput::Segmentation(c) => normalize_shape(&mut c.shape, &c.dshape),
846 ConfigOutput::MaskCoefficients(c) => normalize_shape(&mut c.shape, &c.dshape),
847 ConfigOutput::Mask(c) => normalize_shape(&mut c.shape, &c.dshape),
848 ConfigOutput::Classes(c) => normalize_shape(&mut c.shape, &c.dshape),
849 }
850 output
851 }
852
853 /// Sets the scores threshold of the decoder
854 ///
855 /// # Examples
856 /// ```rust
857 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
858 /// # fn main() -> DecoderResult<()> {
859 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
860 /// let decoder = DecoderBuilder::new()
861 /// .with_config_json_str(config_json)
862 /// .with_score_threshold(0.654)
863 /// .build()?;
864 /// assert_eq!(decoder.score_threshold, 0.654);
865 /// # Ok(())
866 /// # }
867 /// ```
868 pub fn with_score_threshold(mut self, score_threshold: f32) -> Self {
869 self.score_threshold = score_threshold;
870 self
871 }
872
873 /// Sets the IOU threshold of the decoder. Has no effect when NMS is set to
874 /// `None`
875 ///
876 /// # Examples
877 /// ```rust
878 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
879 /// # fn main() -> DecoderResult<()> {
880 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
881 /// let decoder = DecoderBuilder::new()
882 /// .with_config_json_str(config_json)
883 /// .with_iou_threshold(0.654)
884 /// .build()?;
885 /// assert_eq!(decoder.iou_threshold, 0.654);
886 /// # Ok(())
887 /// # }
888 /// ```
889 pub fn with_iou_threshold(mut self, iou_threshold: f32) -> Self {
890 self.iou_threshold = iou_threshold;
891 self
892 }
893
894 /// Sets the NMS mode for the decoder.
895 ///
896 /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS (default): suppress
897 /// overlapping boxes regardless of class label
898 /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
899 /// share the same class label AND overlap above the IoU threshold
900 /// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
901 ///
902 /// # Examples
903 /// ```rust
904 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::Nms};
905 /// # fn main() -> DecoderResult<()> {
906 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
907 /// let decoder = DecoderBuilder::new()
908 /// .with_config_json_str(config_json)
909 /// .with_nms(Some(Nms::ClassAware))
910 /// .build()?;
911 /// assert_eq!(decoder.nms, Some(Nms::ClassAware));
912 /// # Ok(())
913 /// # }
914 /// ```
915 pub fn with_nms(mut self, nms: Option<configs::Nms>) -> Self {
916 self.nms = nms;
917 self
918 }
919
920 /// Sets the maximum number of candidate boxes fed into NMS after score
921 /// filtering. Uses partial sort (O(N)) to select the top-K candidates,
922 /// dramatically reducing the O(N²) NMS cost when many low-confidence
923 /// proposals pass the threshold (common with mAP eval at 0.001).
924 ///
925 /// Default: 300 (matches Ultralytics `max_det`).
926 ///
927 /// # Examples
928 /// ```rust
929 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
930 /// # fn main() -> DecoderResult<()> {
931 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
932 /// let decoder = DecoderBuilder::new()
933 /// .with_config_json_str(config_json)
934 /// .with_pre_nms_top_k(1000)
935 /// .build()?;
936 /// assert_eq!(decoder.pre_nms_top_k, 1000);
937 /// # Ok(())
938 /// # }
939 /// ```
940 pub fn with_pre_nms_top_k(mut self, pre_nms_top_k: usize) -> Self {
941 self.pre_nms_top_k = pre_nms_top_k;
942 self
943 }
944
945 /// Sets the maximum number of detections returned after NMS.
946 /// Matches the Ultralytics `max_det` parameter.
947 ///
948 /// Default: 300.
949 ///
950 /// # Examples
951 /// ```rust
952 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
953 /// # fn main() -> DecoderResult<()> {
954 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
955 /// let decoder = DecoderBuilder::new()
956 /// .with_config_json_str(config_json)
957 /// .with_max_det(100)
958 /// .build()?;
959 /// assert_eq!(decoder.max_det, 100);
960 /// # Ok(())
961 /// # }
962 /// ```
963 pub fn with_max_det(mut self, max_det: usize) -> Self {
964 self.max_det = max_det;
965 self
966 }
967
968 /// Sets the model input dimensions `(width, height)` consumed by the
969 /// EDGEAI-1303 normalization path. Use this when building via
970 /// [`with_config`](Self::with_config) / [`add_output`](Self::add_output)
971 /// (no schema) and the model emits pixel-space boxes that need to be
972 /// divided by `(W, H)` before NMS.
973 ///
974 /// When the builder is also configured with [`with_schema`](Self::with_schema)
975 /// (or `with_config_json_str` / `with_config_yaml_str`) and the schema's
976 /// `input` block carries usable dims, this explicit override **takes
977 /// precedence** so callers can correct schemas with missing or wrong
978 /// input specs without rewriting the schema.
979 ///
980 /// # Examples
981 /// ```rust
982 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
983 /// # fn main() -> DecoderResult<()> {
984 /// # let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
985 /// let decoder = DecoderBuilder::new()
986 /// .with_config_yaml_str(config_yaml)
987 /// .with_input_dims(640, 640)
988 /// .build()?;
989 /// assert_eq!(decoder.input_dims(), Some((640, 640)));
990 /// # Ok(())
991 /// # }
992 /// ```
993 pub fn with_input_dims(mut self, width: usize, height: usize) -> Self {
994 self.input_dims = Some((width, height));
995 self
996 }
997
998 /// Builds the decoder with the given settings. If the config is a JSON or
999 /// YAML string, this will deserialize the JSON or YAML and then parse the
1000 /// configuration information.
1001 ///
1002 /// # Examples
1003 /// ```rust
1004 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
1005 /// # fn main() -> DecoderResult<()> {
1006 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
1007 /// let decoder = DecoderBuilder::new()
1008 /// .with_config_json_str(config_json)
1009 /// .with_score_threshold(0.654)
1010 /// .build()?;
1011 /// # Ok(())
1012 /// # }
1013 /// ```
1014 pub fn build(self) -> Result<Decoder, DecoderError> {
1015 let decode_dtype = self.decode_dtype;
1016 let explicit_input_dims = self.input_dims;
1017 let (config, decode_program, per_scale_plan, schema_input_dims) = match self.config_src {
1018 Some(ConfigSource::Json(s)) => {
1019 Self::build_from_schema(SchemaV2::parse_json(&s)?, decode_dtype)?
1020 }
1021 Some(ConfigSource::Yaml(s)) => {
1022 Self::build_from_schema(SchemaV2::parse_yaml(&s)?, decode_dtype)?
1023 }
1024 Some(ConfigSource::Config(c)) => (c, None, None, None),
1025 Some(ConfigSource::Schema(schema)) => Self::build_from_schema(schema, decode_dtype)?,
1026 None => return Err(DecoderError::NoConfig),
1027 };
1028 // Explicit `with_input_dims(W, H)` overrides any schema-derived
1029 // value so callers can fix schemas with missing or wrong input
1030 // specs without rewriting the schema (EDGEAI-1303).
1031 let input_dims = explicit_input_dims.or(schema_input_dims);
1032
1033 // Enforce the physical-order contract: when dshape is present
1034 // it must describe the same axes as shape in the same order,
1035 // listed from outermost to innermost. Ambiguous-layout roles
1036 // (Protos, Boxes, Scores, MaskCoefficients, Classes, Detection)
1037 // may still omit dshape when shape is already in the decoder's
1038 // canonical order.
1039 for output in &config.outputs {
1040 Decoder::validate_output_layout(output.into())?;
1041 }
1042
1043 // Extract normalized flag from config outputs.
1044 //
1045 // The per-scale subsystem (DFL/LTRB → dist2bbox → sigmoid) emits
1046 // boxes in pixel coordinates by design — `(grid + dist) * stride`
1047 // — independently of any `normalized: true` annotation in the
1048 // schema. The schema's `normalized` flag describes the model's
1049 // training-time convention, not the runtime output coord space
1050 // for this code path. Override to `Some(false)` when the
1051 // per-scale path is active so `Decoder::normalized_boxes()`
1052 // matches what `decode_proto`/`decode` actually produce; the
1053 // legacy / non-per-scale paths still honor the schema flag.
1054 let normalized = if per_scale_plan.is_some() {
1055 Some(false)
1056 } else {
1057 Self::get_normalized(&config.outputs)
1058 };
1059
1060 // Use NMS from config if present, otherwise use builder's NMS setting
1061 let nms = config.nms.or(self.nms);
1062 // When the per-scale path is active, the per_scale subsystem owns
1063 // model decoding entirely — `decode` / `decode_proto` short-circuit
1064 // on `per_scale.is_some()` before reading `model_type`. Skip the
1065 // legacy ModelType validation, which otherwise rejects per-scale
1066 // schemas that carry `decoder_version: yolo26` (an
1067 // "end-to-end" hint) but use the per-scale split outputs rather
1068 // than the end-to-end split-output shape the legacy validator
1069 // expects. We keep a placeholder `ModelType` so the field remains
1070 // valid; it is dead state for per-scale Decoders.
1071 let model_type = if per_scale_plan.is_some() {
1072 // Drop the un-needed config; the per-scale subsystem owns it.
1073 drop(config);
1074 ModelType::PerScale
1075 } else {
1076 Self::get_model_type(config)?
1077 };
1078
1079 let per_scale = per_scale_plan
1080 .map(|plan| std::sync::Mutex::new(crate::per_scale::PerScaleDecoder::new(plan)));
1081
1082 Ok(Decoder {
1083 model_type,
1084 iou_threshold: self.iou_threshold,
1085 score_threshold: self.score_threshold,
1086 nms,
1087 pre_nms_top_k: self.pre_nms_top_k,
1088 max_det: self.max_det,
1089 normalized,
1090 input_dims,
1091 decode_program,
1092 per_scale,
1093 })
1094 }
1095
1096 /// Validate a [`SchemaV2`] and lower it to the (legacy `ConfigOutputs`,
1097 /// optional `DecodeProgram`, optional `PerScalePlan`) tuple the rest
1098 /// of `build()` consumes.
1099 ///
1100 /// Centralises the v2 lowering so JSON, YAML, and direct
1101 /// `with_schema` callers all go through the same validation,
1102 /// merge-program, and per-scale plan construction. `SchemaV2::parse_json`
1103 /// / `parse_yaml` already auto-detect v1 vs v2 input and return a v2
1104 /// schema either way (v1 inputs are upgraded in memory via
1105 /// `from_v1`), so this helper is the sole place that turns a
1106 /// schema into builder-ready state.
1107 #[allow(clippy::type_complexity)]
1108 fn build_from_schema(
1109 schema: SchemaV2,
1110 decode_dtype: DecodeDtype,
1111 ) -> Result<
1112 (
1113 ConfigOutputs,
1114 Option<DecodeProgram>,
1115 Option<PerScalePlan>,
1116 Option<(usize, usize)>,
1117 ),
1118 DecoderError,
1119 > {
1120 schema.validate()?;
1121 let program = DecodeProgram::try_from_schema(&schema)?;
1122 let per_scale = PerScalePlan::try_from_schema(&schema, decode_dtype)?;
1123 // Extract model input (W, H) from `input.shape`/`dshape`. Used by
1124 // the legacy decode path to honour `normalized: false` (see
1125 // EDGEAI-1303). `None` is fine when the schema omits the input
1126 // spec — the decoder falls back to the protobox `>2.0` reject.
1127 let input_dims = schema.input.as_ref().and_then(input_dims_from_spec);
1128 let legacy = schema.to_legacy_config_outputs()?;
1129 Ok((legacy, program, per_scale, input_dims))
1130 }
1131
1132 /// Extracts the normalized flag from config outputs.
1133 /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
1134 /// - `Some(false)`: Boxes are in pixel coordinates
1135 /// - `None`: Unknown (not specified in config), caller must infer
1136 fn get_normalized(outputs: &[ConfigOutput]) -> Option<bool> {
1137 for output in outputs {
1138 match output {
1139 ConfigOutput::Detection(det) => return det.normalized,
1140 ConfigOutput::Boxes(boxes) => return boxes.normalized,
1141 _ => {}
1142 }
1143 }
1144 None // not specified
1145 }
1146
1147 fn get_model_type(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1148 // yolo or modelpack
1149 let mut yolo = false;
1150 let mut modelpack = false;
1151 for c in &configs.outputs {
1152 match c.decoder() {
1153 DecoderType::ModelPack => modelpack = true,
1154 DecoderType::Ultralytics => yolo = true,
1155 }
1156 }
1157 match (modelpack, yolo) {
1158 (true, true) => Err(DecoderError::InvalidConfig(
1159 "Both ModelPack and Yolo outputs found in config".to_string(),
1160 )),
1161 (true, false) => Self::get_model_type_modelpack(configs),
1162 (false, true) => Self::get_model_type_yolo(configs),
1163 (false, false) => Err(DecoderError::InvalidConfig(
1164 "No outputs found in config".to_string(),
1165 )),
1166 }
1167 }
1168
    /// Classifies an Ultralytics (Yolo) config into a concrete [`ModelType`].
    ///
    /// Outputs are bucketed by role. When a role appears more than once, the
    /// last occurrence wins. `Segmentation` and `Mask` outputs are rejected
    /// outright.
    ///
    /// End-to-end layouts (see the comment on `is_end_to_end` below) are
    /// resolved first. Otherwise a combined `Detection` output takes priority
    /// over split `Boxes` + `Scores` outputs. Each candidate layout is checked
    /// by its matching `verify_*` helper before it is returned.
    ///
    /// # Errors
    ///
    /// Returns [`DecoderError::InvalidConfig`] when:
    /// - an unsupported output role is present;
    /// - no recognised combination of outputs is found; or
    /// - the selected layout fails verification.
    fn get_model_type_yolo(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
        // Combined `Detection` output.
        let mut boxes = None;
        let mut protos = None;
        // Split-head outputs (one tensor per role).
        let mut split_boxes = None;
        let mut split_scores = None;
        let mut split_mask_coeff = None;
        let mut split_classes = None;
        for c in configs.outputs {
            match c {
                ConfigOutput::Detection(detection) => boxes = Some(detection),
                ConfigOutput::Segmentation(_) => {
                    return Err(DecoderError::InvalidConfig(
                        "Invalid Segmentation output with Yolo decoder".to_string(),
                    ));
                }
                ConfigOutput::Protos(protos_) => protos = Some(protos_),
                ConfigOutput::Mask(_) => {
                    return Err(DecoderError::InvalidConfig(
                        "Invalid Mask output with Yolo decoder".to_string(),
                    ));
                }
                ConfigOutput::Scores(scores) => split_scores = Some(scores),
                ConfigOutput::Boxes(boxes) => split_boxes = Some(boxes),
                ConfigOutput::MaskCoefficients(mask_coeff) => split_mask_coeff = Some(mask_coeff),
                ConfigOutput::Classes(classes) => split_classes = Some(classes),
            }
        }

        // Use end-to-end model types when either:
        // 1. `decoder_version` is set to an end-to-end version such as Yolo26
        //    (definitive). A set but non-end-to-end version disables the
        //    dshape heuristic below. OR
        // 2. `decoder_version` is not set, but the combined Detection dshape
        //    is exactly (batch, num_boxes, num_features).
        let is_end_to_end_dshape = boxes.as_ref().is_some_and(|b| {
            let dims = b.dshape.iter().map(|(d, _)| *d).collect::<Vec<_>>();
            dims == vec![DimName::Batch, DimName::NumBoxes, DimName::NumFeatures]
        });

        let is_end_to_end = configs
            .decoder_version
            .map(|v| v.is_end_to_end())
            .unwrap_or(is_end_to_end_dshape);

        if is_end_to_end {
            // A combined Detection output takes priority over split outputs.
            if let Some(boxes) = boxes {
                if let Some(protos) = protos {
                    Self::verify_yolo_seg_det_26(&boxes, &protos)?;
                    return Ok(ModelType::YoloEndToEndSegDet { boxes, protos });
                } else {
                    Self::verify_yolo_det_26(&boxes)?;
                    return Ok(ModelType::YoloEndToEndDet { boxes });
                }
            } else if let (Some(split_boxes), Some(split_scores), Some(split_classes)) =
                (split_boxes, split_scores, split_classes)
            {
                // Segmentation needs both mask coefficients and protos;
                // otherwise fall through to detection-only.
                if let (Some(split_mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
                    Self::verify_yolo_split_end_to_end_segdet(
                        &split_boxes,
                        &split_scores,
                        &split_classes,
                        &split_mask_coeff,
                        &protos,
                    )?;
                    return Ok(ModelType::YoloSplitEndToEndSegDet {
                        boxes: split_boxes,
                        scores: split_scores,
                        classes: split_classes,
                        mask_coeff: split_mask_coeff,
                        protos,
                    });
                }
                Self::verify_yolo_split_end_to_end_det(
                    &split_boxes,
                    &split_scores,
                    &split_classes,
                )?;
                return Ok(ModelType::YoloSplitEndToEndDet {
                    boxes: split_boxes,
                    scores: split_scores,
                    classes: split_classes,
                });
            } else {
                return Err(DecoderError::InvalidConfig(
                    "Invalid Yolo end-to-end model outputs".to_string(),
                ));
            }
        }

        // Non-end-to-end layouts: a combined Detection output first, then
        // split Boxes + Scores. The Classes output is ignored here.
        if let Some(boxes) = boxes {
            match (split_mask_coeff, protos) {
                (Some(mask_coeff), Some(protos)) => {
                    // 2-way split: combined detection + separate mask_coeff + protos
                    Self::verify_yolo_seg_det_2way(&boxes, &mask_coeff, &protos)?;
                    Ok(ModelType::YoloSegDet2Way {
                        boxes,
                        mask_coeff,
                        protos,
                    })
                }
                (_, Some(protos)) => {
                    // Unsplit: mask_coefs embedded in detection tensor
                    Self::verify_yolo_seg_det(&boxes, &protos)?;
                    Ok(ModelType::YoloSegDet { boxes, protos })
                }
                _ => {
                    // No protos: detection only. A lone mask_coeff is ignored.
                    Self::verify_yolo_det(&boxes)?;
                    Ok(ModelType::YoloDet { boxes })
                }
            }
        } else if let (Some(boxes), Some(scores)) = (split_boxes, split_scores) {
            if let (Some(mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
                Self::verify_yolo_split_segdet(&boxes, &scores, &mask_coeff, &protos)?;
                Ok(ModelType::YoloSplitSegDet {
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                })
            } else {
                Self::verify_yolo_split_det(&boxes, &scores)?;
                Ok(ModelType::YoloSplitDet { boxes, scores })
            }
        } else {
            Err(DecoderError::InvalidConfig(
                "Invalid Yolo model outputs".to_string(),
            ))
        }
    }
1296
1297 fn verify_yolo_det(detect: &configs::Detection) -> Result<(), DecoderError> {
1298 if detect.shape.len() != 3 {
1299 return Err(DecoderError::InvalidConfig(format!(
1300 "Invalid Yolo Detection shape {:?}",
1301 detect.shape
1302 )));
1303 }
1304
1305 Self::verify_dshapes(
1306 &detect.dshape,
1307 &detect.shape,
1308 "Detection",
1309 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1310 )?;
1311 if !detect.dshape.is_empty() {
1312 Self::get_class_count(&detect.dshape, None, None)?;
1313 } else {
1314 Self::get_class_count_no_dshape(detect.into(), None)?;
1315 }
1316
1317 Ok(())
1318 }
1319
1320 fn verify_yolo_det_26(detect: &configs::Detection) -> Result<(), DecoderError> {
1321 if detect.shape.len() != 3 {
1322 return Err(DecoderError::InvalidConfig(format!(
1323 "Invalid Yolo Detection shape {:?}",
1324 detect.shape
1325 )));
1326 }
1327
1328 Self::verify_dshapes(
1329 &detect.dshape,
1330 &detect.shape,
1331 "Detection",
1332 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1333 )?;
1334
1335 if !detect.shape.contains(&6) {
1336 return Err(DecoderError::InvalidConfig(
1337 "Yolo26 Detection must have 6 features".to_string(),
1338 ));
1339 }
1340
1341 Ok(())
1342 }
1343
1344 fn verify_yolo_seg_det(
1345 detection: &configs::Detection,
1346 protos: &configs::Protos,
1347 ) -> Result<(), DecoderError> {
1348 if detection.shape.len() != 3 {
1349 return Err(DecoderError::InvalidConfig(format!(
1350 "Invalid Yolo Detection shape {:?}",
1351 detection.shape
1352 )));
1353 }
1354 if protos.shape.len() != 4 {
1355 return Err(DecoderError::InvalidConfig(format!(
1356 "Invalid Yolo Protos shape {:?}",
1357 protos.shape
1358 )));
1359 }
1360
1361 Self::verify_dshapes(
1362 &detection.dshape,
1363 &detection.shape,
1364 "Detection",
1365 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1366 )?;
1367 Self::verify_dshapes(
1368 &protos.dshape,
1369 &protos.shape,
1370 "Protos",
1371 &[
1372 DimName::Batch,
1373 DimName::Height,
1374 DimName::Width,
1375 DimName::NumProtos,
1376 ],
1377 )?;
1378
1379 let protos_count = Self::get_protos_count(&protos.dshape)
1380 .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1381 log::debug!("Protos count: {}", protos_count);
1382 log::debug!("Detection dshape: {:?}", detection.dshape);
1383 let classes = if !detection.dshape.is_empty() {
1384 Self::get_class_count(&detection.dshape, Some(protos_count), None)?
1385 } else {
1386 Self::get_class_count_no_dshape(detection.into(), Some(protos_count))?
1387 };
1388
1389 if classes == 0 {
1390 return Err(DecoderError::InvalidConfig(
1391 "Yolo Segmentation Detection has zero classes".to_string(),
1392 ));
1393 }
1394
1395 Ok(())
1396 }
1397
1398 fn verify_yolo_seg_det_2way(
1399 detection: &configs::Detection,
1400 mask_coeff: &configs::MaskCoefficients,
1401 protos: &configs::Protos,
1402 ) -> Result<(), DecoderError> {
1403 if detection.shape.len() != 3 {
1404 return Err(DecoderError::InvalidConfig(format!(
1405 "Invalid Yolo 2-Way Detection shape {:?}",
1406 detection.shape
1407 )));
1408 }
1409 if mask_coeff.shape.len() != 3 {
1410 return Err(DecoderError::InvalidConfig(format!(
1411 "Invalid Yolo 2-Way Mask Coefficients shape {:?}",
1412 mask_coeff.shape
1413 )));
1414 }
1415 if protos.shape.len() != 4 {
1416 return Err(DecoderError::InvalidConfig(format!(
1417 "Invalid Yolo 2-Way Protos shape {:?}",
1418 protos.shape
1419 )));
1420 }
1421
1422 Self::verify_dshapes(
1423 &detection.dshape,
1424 &detection.shape,
1425 "Detection",
1426 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1427 )?;
1428 Self::verify_dshapes(
1429 &mask_coeff.dshape,
1430 &mask_coeff.shape,
1431 "Mask Coefficients",
1432 &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1433 )?;
1434 Self::verify_dshapes(
1435 &protos.dshape,
1436 &protos.shape,
1437 "Protos",
1438 &[
1439 DimName::Batch,
1440 DimName::Height,
1441 DimName::Width,
1442 DimName::NumProtos,
1443 ],
1444 )?;
1445
1446 // Validate num_boxes match between detection and mask_coeff
1447 let det_num = Self::get_box_count(&detection.dshape).unwrap_or(detection.shape[2]);
1448 let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1449 if det_num != mask_num {
1450 return Err(DecoderError::InvalidConfig(format!(
1451 "Yolo 2-Way Detection num_boxes {} incompatible with Mask Coefficients num_boxes {}",
1452 det_num, mask_num
1453 )));
1454 }
1455
1456 // Validate mask_coeff channels match protos channels
1457 let mask_channels = if !mask_coeff.dshape.is_empty() {
1458 Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1459 DecoderError::InvalidConfig(
1460 "Could not find num_protos in mask_coeff config".to_string(),
1461 )
1462 })?
1463 } else {
1464 mask_coeff.shape[1]
1465 };
1466 let proto_channels = if !protos.dshape.is_empty() {
1467 Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1468 DecoderError::InvalidConfig(
1469 "Could not find num_protos in protos config".to_string(),
1470 )
1471 })?
1472 } else {
1473 protos.shape[1].min(protos.shape[3])
1474 };
1475 if mask_channels != proto_channels {
1476 return Err(DecoderError::InvalidConfig(format!(
1477 "Yolo 2-Way Protos channels {} incompatible with Mask Coefficients channels {}",
1478 proto_channels, mask_channels
1479 )));
1480 }
1481
1482 // Validate detection has classes (nc+4 features, no mask_coefs embedded)
1483 if !detection.dshape.is_empty() {
1484 Self::get_class_count(&detection.dshape, None, None)?;
1485 } else {
1486 Self::get_class_count_no_dshape(detection.into(), None)?;
1487 }
1488
1489 Ok(())
1490 }
1491
1492 fn verify_yolo_seg_det_26(
1493 detection: &configs::Detection,
1494 protos: &configs::Protos,
1495 ) -> Result<(), DecoderError> {
1496 if detection.shape.len() != 3 {
1497 return Err(DecoderError::InvalidConfig(format!(
1498 "Invalid Yolo Detection shape {:?}",
1499 detection.shape
1500 )));
1501 }
1502 if protos.shape.len() != 4 {
1503 return Err(DecoderError::InvalidConfig(format!(
1504 "Invalid Yolo Protos shape {:?}",
1505 protos.shape
1506 )));
1507 }
1508
1509 Self::verify_dshapes(
1510 &detection.dshape,
1511 &detection.shape,
1512 "Detection",
1513 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1514 )?;
1515 Self::verify_dshapes(
1516 &protos.dshape,
1517 &protos.shape,
1518 "Protos",
1519 &[
1520 DimName::Batch,
1521 DimName::Height,
1522 DimName::Width,
1523 DimName::NumProtos,
1524 ],
1525 )?;
1526
1527 let protos_count = Self::get_protos_count(&protos.dshape)
1528 .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1529 log::debug!("Protos count: {}", protos_count);
1530 log::debug!("Detection dshape: {:?}", detection.dshape);
1531
1532 if !detection.shape.contains(&(6 + protos_count)) {
1533 return Err(DecoderError::InvalidConfig(format!(
1534 "Yolo26 Segmentation Detection must have num_features be 6 + num_protos = {}",
1535 6 + protos_count
1536 )));
1537 }
1538
1539 Ok(())
1540 }
1541
1542 fn verify_yolo_split_det(
1543 boxes: &configs::Boxes,
1544 scores: &configs::Scores,
1545 ) -> Result<(), DecoderError> {
1546 if boxes.shape.len() != 3 {
1547 return Err(DecoderError::InvalidConfig(format!(
1548 "Invalid Yolo Split Boxes shape {:?}",
1549 boxes.shape
1550 )));
1551 }
1552 if scores.shape.len() != 3 {
1553 return Err(DecoderError::InvalidConfig(format!(
1554 "Invalid Yolo Split Scores shape {:?}",
1555 scores.shape
1556 )));
1557 }
1558
1559 Self::verify_dshapes(
1560 &boxes.dshape,
1561 &boxes.shape,
1562 "Boxes",
1563 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1564 )?;
1565 Self::verify_dshapes(
1566 &scores.dshape,
1567 &scores.shape,
1568 "Scores",
1569 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1570 )?;
1571
1572 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1573 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1574
1575 if boxes_num != scores_num {
1576 return Err(DecoderError::InvalidConfig(format!(
1577 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1578 boxes_num, scores_num
1579 )));
1580 }
1581
1582 Ok(())
1583 }
1584
1585 fn verify_yolo_split_segdet(
1586 boxes: &configs::Boxes,
1587 scores: &configs::Scores,
1588 mask_coeff: &configs::MaskCoefficients,
1589 protos: &configs::Protos,
1590 ) -> Result<(), DecoderError> {
1591 if boxes.shape.len() != 3 {
1592 return Err(DecoderError::InvalidConfig(format!(
1593 "Invalid Yolo Split Boxes shape {:?}",
1594 boxes.shape
1595 )));
1596 }
1597 if scores.shape.len() != 3 {
1598 return Err(DecoderError::InvalidConfig(format!(
1599 "Invalid Yolo Split Scores shape {:?}",
1600 scores.shape
1601 )));
1602 }
1603
1604 if mask_coeff.shape.len() != 3 {
1605 return Err(DecoderError::InvalidConfig(format!(
1606 "Invalid Yolo Split Mask Coefficients shape {:?}",
1607 mask_coeff.shape
1608 )));
1609 }
1610
1611 if protos.shape.len() != 4 {
1612 return Err(DecoderError::InvalidConfig(format!(
1613 "Invalid Yolo Protos shape {:?}",
1614 mask_coeff.shape
1615 )));
1616 }
1617
1618 Self::verify_dshapes(
1619 &boxes.dshape,
1620 &boxes.shape,
1621 "Boxes",
1622 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1623 )?;
1624 Self::verify_dshapes(
1625 &scores.dshape,
1626 &scores.shape,
1627 "Scores",
1628 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1629 )?;
1630 Self::verify_dshapes(
1631 &mask_coeff.dshape,
1632 &mask_coeff.shape,
1633 "Mask Coefficients",
1634 &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1635 )?;
1636 Self::verify_dshapes(
1637 &protos.dshape,
1638 &protos.shape,
1639 "Protos",
1640 &[
1641 DimName::Batch,
1642 DimName::Height,
1643 DimName::Width,
1644 DimName::NumProtos,
1645 ],
1646 )?;
1647
1648 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1649 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1650 let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1651
1652 let mask_channels = if !mask_coeff.dshape.is_empty() {
1653 Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1654 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1655 })?
1656 } else {
1657 mask_coeff.shape[1]
1658 };
1659 let proto_channels = if !protos.dshape.is_empty() {
1660 Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1661 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1662 })?
1663 } else {
1664 protos.shape[1].min(protos.shape[3])
1665 };
1666
1667 if boxes_num != scores_num {
1668 return Err(DecoderError::InvalidConfig(format!(
1669 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1670 boxes_num, scores_num
1671 )));
1672 }
1673
1674 if boxes_num != mask_num {
1675 return Err(DecoderError::InvalidConfig(format!(
1676 "Yolo Split Detection Boxes num {} incompatible with Mask Coefficients num {}",
1677 boxes_num, mask_num
1678 )));
1679 }
1680
1681 if proto_channels != mask_channels {
1682 return Err(DecoderError::InvalidConfig(format!(
1683 "Yolo Protos channels {} incompatible with Mask Coefficients channels {}",
1684 proto_channels, mask_channels
1685 )));
1686 }
1687
1688 Ok(())
1689 }
1690
1691 fn verify_yolo_split_end_to_end_det(
1692 boxes: &configs::Boxes,
1693 scores: &configs::Scores,
1694 classes: &configs::Classes,
1695 ) -> Result<(), DecoderError> {
1696 if boxes.shape.len() != 3 || !boxes.shape.contains(&4) {
1697 return Err(DecoderError::InvalidConfig(format!(
1698 "Split end-to-end boxes must be [batch, N, 4], got {:?}",
1699 boxes.shape
1700 )));
1701 }
1702 if scores.shape.len() != 3 || !scores.shape.contains(&1) {
1703 return Err(DecoderError::InvalidConfig(format!(
1704 "Split end-to-end scores must be [batch, N, 1], got {:?}",
1705 scores.shape
1706 )));
1707 }
1708 if classes.shape.len() != 3 || !classes.shape.contains(&1) {
1709 return Err(DecoderError::InvalidConfig(format!(
1710 "Split end-to-end classes must be [batch, N, 1], got {:?}",
1711 classes.shape
1712 )));
1713 }
1714 Ok(())
1715 }
1716
1717 fn verify_yolo_split_end_to_end_segdet(
1718 boxes: &configs::Boxes,
1719 scores: &configs::Scores,
1720 classes: &configs::Classes,
1721 mask_coeff: &configs::MaskCoefficients,
1722 protos: &configs::Protos,
1723 ) -> Result<(), DecoderError> {
1724 Self::verify_yolo_split_end_to_end_det(boxes, scores, classes)?;
1725 if mask_coeff.shape.len() != 3 {
1726 return Err(DecoderError::InvalidConfig(format!(
1727 "Invalid split end-to-end mask coefficients shape {:?}",
1728 mask_coeff.shape
1729 )));
1730 }
1731 if protos.shape.len() != 4 {
1732 return Err(DecoderError::InvalidConfig(format!(
1733 "Invalid protos shape {:?}",
1734 protos.shape
1735 )));
1736 }
1737 Ok(())
1738 }
1739
1740 fn get_model_type_modelpack(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1741 let mut split_decoders = Vec::new();
1742 let mut segment_ = None;
1743 let mut scores_ = None;
1744 let mut boxes_ = None;
1745 for c in configs.outputs {
1746 match c {
1747 ConfigOutput::Detection(detection) => split_decoders.push(detection),
1748 ConfigOutput::Segmentation(segmentation) => segment_ = Some(segmentation),
1749 ConfigOutput::Mask(_) => {}
1750 ConfigOutput::Protos(_) => {
1751 return Err(DecoderError::InvalidConfig(
1752 "ModelPack should not have protos".to_string(),
1753 ));
1754 }
1755 ConfigOutput::Scores(scores) => scores_ = Some(scores),
1756 ConfigOutput::Boxes(boxes) => boxes_ = Some(boxes),
1757 ConfigOutput::MaskCoefficients(_) => {
1758 return Err(DecoderError::InvalidConfig(
1759 "ModelPack should not have mask coefficients".to_string(),
1760 ));
1761 }
1762 ConfigOutput::Classes(_) => {
1763 return Err(DecoderError::InvalidConfig(
1764 "ModelPack should not have classes output".to_string(),
1765 ));
1766 }
1767 }
1768 }
1769
1770 if let Some(segmentation) = segment_ {
1771 if !split_decoders.is_empty() {
1772 let classes = Self::verify_modelpack_split_det(&split_decoders)?;
1773 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1774 Ok(ModelType::ModelPackSegDetSplit {
1775 detection: split_decoders,
1776 segmentation,
1777 })
1778 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1779 let classes = Self::verify_modelpack_det(&boxes, &scores)?;
1780 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1781 Ok(ModelType::ModelPackSegDet {
1782 boxes,
1783 scores,
1784 segmentation,
1785 })
1786 } else {
1787 Self::verify_modelpack_seg(&segmentation, None)?;
1788 Ok(ModelType::ModelPackSeg { segmentation })
1789 }
1790 } else if !split_decoders.is_empty() {
1791 Self::verify_modelpack_split_det(&split_decoders)?;
1792 Ok(ModelType::ModelPackDetSplit {
1793 detection: split_decoders,
1794 })
1795 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1796 Self::verify_modelpack_det(&boxes, &scores)?;
1797 Ok(ModelType::ModelPackDet { boxes, scores })
1798 } else {
1799 Err(DecoderError::InvalidConfig(
1800 "Invalid ModelPack model outputs".to_string(),
1801 ))
1802 }
1803 }
1804
1805 fn verify_modelpack_det(
1806 boxes: &configs::Boxes,
1807 scores: &configs::Scores,
1808 ) -> Result<usize, DecoderError> {
1809 if boxes.shape.len() != 4 {
1810 return Err(DecoderError::InvalidConfig(format!(
1811 "Invalid ModelPack Boxes shape {:?}",
1812 boxes.shape
1813 )));
1814 }
1815 if scores.shape.len() != 3 {
1816 return Err(DecoderError::InvalidConfig(format!(
1817 "Invalid ModelPack Scores shape {:?}",
1818 scores.shape
1819 )));
1820 }
1821
1822 Self::verify_dshapes(
1823 &boxes.dshape,
1824 &boxes.shape,
1825 "Boxes",
1826 &[
1827 DimName::Batch,
1828 DimName::NumBoxes,
1829 DimName::Padding,
1830 DimName::BoxCoords,
1831 ],
1832 )?;
1833 Self::verify_dshapes(
1834 &scores.dshape,
1835 &scores.shape,
1836 "Scores",
1837 &[DimName::Batch, DimName::NumBoxes, DimName::NumClasses],
1838 )?;
1839
1840 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[1]);
1841 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[1]);
1842
1843 if boxes_num != scores_num {
1844 return Err(DecoderError::InvalidConfig(format!(
1845 "ModelPack Detection Boxes num {} incompatible with Scores num {}",
1846 boxes_num, scores_num
1847 )));
1848 }
1849
1850 let num_classes = if !scores.dshape.is_empty() {
1851 Self::get_class_count(&scores.dshape, None, None)?
1852 } else {
1853 Self::get_class_count_no_dshape(scores.into(), None)?
1854 };
1855
1856 Ok(num_classes)
1857 }
1858
1859 fn verify_modelpack_split_det(boxes: &[configs::Detection]) -> Result<usize, DecoderError> {
1860 let mut num_classes = None;
1861 for b in boxes {
1862 let Some(num_anchors) = b.anchors.as_ref().map(|a| a.len()) else {
1863 return Err(DecoderError::InvalidConfig(
1864 "ModelPack Split Detection missing anchors".to_string(),
1865 ));
1866 };
1867
1868 if num_anchors == 0 {
1869 return Err(DecoderError::InvalidConfig(
1870 "ModelPack Split Detection has zero anchors".to_string(),
1871 ));
1872 }
1873
1874 if b.shape.len() != 4 {
1875 return Err(DecoderError::InvalidConfig(format!(
1876 "Invalid ModelPack Split Detection shape {:?}",
1877 b.shape
1878 )));
1879 }
1880
1881 Self::verify_dshapes(
1882 &b.dshape,
1883 &b.shape,
1884 "Split Detection",
1885 &[
1886 DimName::Batch,
1887 DimName::Height,
1888 DimName::Width,
1889 DimName::NumAnchorsXFeatures,
1890 ],
1891 )?;
1892 let classes = if !b.dshape.is_empty() {
1893 Self::get_class_count(&b.dshape, None, Some(num_anchors))?
1894 } else {
1895 Self::get_class_count_no_dshape(b.into(), None)?
1896 };
1897
1898 match num_classes {
1899 Some(n) => {
1900 if n != classes {
1901 return Err(DecoderError::InvalidConfig(format!(
1902 "ModelPack Split Detection inconsistent number of classes: previous {}, current {}",
1903 n, classes
1904 )));
1905 }
1906 }
1907 None => {
1908 num_classes = Some(classes);
1909 }
1910 }
1911 }
1912
1913 Ok(num_classes.unwrap_or(0))
1914 }
1915
1916 fn verify_modelpack_seg(
1917 segmentation: &configs::Segmentation,
1918 classes: Option<usize>,
1919 ) -> Result<(), DecoderError> {
1920 if segmentation.shape.len() != 4 {
1921 return Err(DecoderError::InvalidConfig(format!(
1922 "Invalid ModelPack Segmentation shape {:?}",
1923 segmentation.shape
1924 )));
1925 }
1926 Self::verify_dshapes(
1927 &segmentation.dshape,
1928 &segmentation.shape,
1929 "Segmentation",
1930 &[
1931 DimName::Batch,
1932 DimName::Height,
1933 DimName::Width,
1934 DimName::NumClasses,
1935 ],
1936 )?;
1937
1938 if let Some(classes) = classes {
1939 let seg_classes = if !segmentation.dshape.is_empty() {
1940 Self::get_class_count(&segmentation.dshape, None, None)?
1941 } else {
1942 Self::get_class_count_no_dshape(segmentation.into(), None)?
1943 };
1944
1945 if seg_classes != classes + 1 {
1946 return Err(DecoderError::InvalidConfig(format!(
1947 "ModelPack Segmentation channels {} incompatible with number of classes {}",
1948 seg_classes, classes
1949 )));
1950 }
1951 }
1952 Ok(())
1953 }
1954
1955 // verifies that dshapes match the given shape
1956 fn verify_dshapes(
1957 dshape: &[(DimName, usize)],
1958 shape: &[usize],
1959 name: &str,
1960 dims: &[DimName],
1961 ) -> Result<(), DecoderError> {
1962 for s in shape {
1963 if *s == 0 {
1964 return Err(DecoderError::InvalidConfig(format!(
1965 "{} shape has zero dimension",
1966 name
1967 )));
1968 }
1969 }
1970
1971 if shape.len() != dims.len() {
1972 return Err(DecoderError::InvalidConfig(format!(
1973 "{} shape length {} does not match expected dims length {}",
1974 name,
1975 shape.len(),
1976 dims.len()
1977 )));
1978 }
1979
1980 if dshape.is_empty() {
1981 return Ok(());
1982 }
1983 // check the dshape lengths match the shape lengths
1984 if dshape.len() != shape.len() {
1985 return Err(DecoderError::InvalidConfig(format!(
1986 "{} dshape length does not match shape length",
1987 name
1988 )));
1989 }
1990
1991 // check the dshape values match the shape values
1992 for ((dim_name, dim_size), shape_size) in dshape.iter().zip(shape) {
1993 if dim_size != shape_size {
1994 return Err(DecoderError::InvalidConfig(format!(
1995 "{} dshape dimension {} size {} does not match shape size {}",
1996 name, dim_name, dim_size, shape_size
1997 )));
1998 }
1999 if *dim_name == DimName::Padding && *dim_size != 1 {
2000 return Err(DecoderError::InvalidConfig(
2001 "Padding dimension size must be 1".to_string(),
2002 ));
2003 }
2004
2005 if *dim_name == DimName::BoxCoords && *dim_size != 4 {
2006 return Err(DecoderError::InvalidConfig(
2007 "BoxCoords dimension size must be 4".to_string(),
2008 ));
2009 }
2010 }
2011
2012 let dims_present = HashSet::<DimName>::from_iter(dshape.iter().map(|(name, _)| *name));
2013 for dim in dims {
2014 if !dims_present.contains(dim) {
2015 return Err(DecoderError::InvalidConfig(format!(
2016 "{} dshape missing required dimension {:?}",
2017 name, dim
2018 )));
2019 }
2020 }
2021
2022 Ok(())
2023 }
2024
2025 fn get_box_count(dshape: &[(DimName, usize)]) -> Option<usize> {
2026 for (dim_name, dim_size) in dshape {
2027 if *dim_name == DimName::NumBoxes {
2028 return Some(*dim_size);
2029 }
2030 }
2031 None
2032 }
2033
2034 fn get_class_count_no_dshape(
2035 config: ConfigOutputRef,
2036 protos: Option<usize>,
2037 ) -> Result<usize, DecoderError> {
2038 match config {
2039 ConfigOutputRef::Detection(detection) => match detection.decoder {
2040 DecoderType::Ultralytics => {
2041 if detection.shape[1] <= 4 + protos.unwrap_or(0) {
2042 return Err(DecoderError::InvalidConfig(format!(
2043 "Invalid shape: Yolo num_features {} must be greater than {}",
2044 detection.shape[1],
2045 4 + protos.unwrap_or(0),
2046 )));
2047 }
2048 Ok(detection.shape[1] - 4 - protos.unwrap_or(0))
2049 }
2050 DecoderType::ModelPack => {
2051 let Some(num_anchors) = detection.anchors.as_ref().map(|a| a.len()) else {
2052 return Err(DecoderError::Internal(
2053 "ModelPack Detection missing anchors".to_string(),
2054 ));
2055 };
2056 let anchors_x_features = detection.shape[3];
2057 if anchors_x_features <= num_anchors * 5 {
2058 return Err(DecoderError::InvalidConfig(format!(
2059 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2060 anchors_x_features,
2061 num_anchors * 5,
2062 )));
2063 }
2064
2065 if !anchors_x_features.is_multiple_of(num_anchors) {
2066 return Err(DecoderError::InvalidConfig(format!(
2067 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2068 anchors_x_features, num_anchors
2069 )));
2070 }
2071 Ok(anchors_x_features / num_anchors - 5)
2072 }
2073 },
2074
2075 ConfigOutputRef::Scores(scores) => match scores.decoder {
2076 DecoderType::Ultralytics => Ok(scores.shape[1]),
2077 DecoderType::ModelPack => Ok(scores.shape[2]),
2078 },
2079 ConfigOutputRef::Segmentation(seg) => Ok(seg.shape[3]),
2080 _ => Err(DecoderError::Internal(
2081 "Attempted to get class count from unsupported config output".to_owned(),
2082 )),
2083 }
2084 }
2085
2086 // get the class count from dshape or calculate from num_features
2087 fn get_class_count(
2088 dshape: &[(DimName, usize)],
2089 protos: Option<usize>,
2090 anchors: Option<usize>,
2091 ) -> Result<usize, DecoderError> {
2092 if dshape.is_empty() {
2093 return Ok(0);
2094 }
2095 // if it has num_classes in dshape, return it
2096 for (dim_name, dim_size) in dshape {
2097 if *dim_name == DimName::NumClasses {
2098 return Ok(*dim_size);
2099 }
2100 }
2101
2102 // number of classes can be calculated from num_features - 4 for yolo. If the
2103 // model has protos, we also subtract the number of protos.
2104 for (dim_name, dim_size) in dshape {
2105 if *dim_name == DimName::NumFeatures {
2106 let protos = protos.unwrap_or(0);
2107 if protos + 4 >= *dim_size {
2108 return Err(DecoderError::InvalidConfig(format!(
2109 "Invalid shape: Yolo num_features {} must be greater than {}",
2110 *dim_size,
2111 protos + 4,
2112 )));
2113 }
2114 return Ok(*dim_size - 4 - protos);
2115 }
2116 }
2117
2118 // number of classes can be calculated from number of anchors for modelpack
2119 // split detection
2120 if let Some(num_anchors) = anchors {
2121 for (dim_name, dim_size) in dshape {
2122 if *dim_name == DimName::NumAnchorsXFeatures {
2123 let anchors_x_features = *dim_size;
2124 if anchors_x_features <= num_anchors * 5 {
2125 return Err(DecoderError::InvalidConfig(format!(
2126 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2127 anchors_x_features,
2128 num_anchors * 5,
2129 )));
2130 }
2131
2132 if !anchors_x_features.is_multiple_of(num_anchors) {
2133 return Err(DecoderError::InvalidConfig(format!(
2134 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2135 anchors_x_features, num_anchors
2136 )));
2137 }
2138 return Ok((anchors_x_features / num_anchors) - 5);
2139 }
2140 }
2141 }
2142 Err(DecoderError::InvalidConfig(
2143 "Cannot determine number of classes from dshape".to_owned(),
2144 ))
2145 }
2146
2147 fn get_protos_count(dshape: &[(DimName, usize)]) -> Option<usize> {
2148 for (dim_name, dim_size) in dshape {
2149 if *dim_name == DimName::NumProtos {
2150 return Some(*dim_size);
2151 }
2152 }
2153 None
2154 }
2155}