edgefirst_decoder/decoder/builder.rs
1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashSet;
5
6use super::config::ConfigOutputRef;
7use super::configs::{self, DecoderType, DecoderVersion, DimName, ModelType};
8use super::{ConfigOutput, ConfigOutputs, Decoder};
9use crate::DecoderError;
10
11#[derive(Debug, Clone, PartialEq)]
12pub struct DecoderBuilder {
13 config_src: Option<ConfigSource>,
14 iou_threshold: f32,
15 score_threshold: f32,
16 /// NMS mode: Some(mode) applies NMS, None bypasses NMS (for end-to-end
17 /// models)
18 nms: Option<configs::Nms>,
19}
20
21#[derive(Debug, Clone, PartialEq)]
22enum ConfigSource {
23 Yaml(String),
24 Json(String),
25 Config(ConfigOutputs),
26}
27
28impl Default for DecoderBuilder {
29 /// Creates a default DecoderBuilder with no configuration and 0.5 score
30 /// threshold and 0.5 IoU threshold.
31 ///
32 /// A valid configuration must be provided before building the Decoder.
33 ///
34 /// # Examples
35 /// ```rust
36 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
37 /// # fn main() -> DecoderResult<()> {
38 /// # let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
39 /// let decoder = DecoderBuilder::default()
40 /// .with_config_yaml_str(config_yaml)
41 /// .build()?;
42 /// assert_eq!(decoder.score_threshold, 0.5);
43 /// assert_eq!(decoder.iou_threshold, 0.5);
44 ///
45 /// # Ok(())
46 /// # }
47 /// ```
48 fn default() -> Self {
49 Self {
50 config_src: None,
51 iou_threshold: 0.5,
52 score_threshold: 0.5,
53 nms: Some(configs::Nms::ClassAgnostic),
54 }
55 }
56}
57
58impl DecoderBuilder {
59 /// Creates a default DecoderBuilder with no configuration and 0.5 score
60 /// threshold and 0.5 IoU threshold.
61 ///
62 /// A valid configuration must be provided before building the Decoder.
63 ///
64 /// # Examples
65 /// ```rust
66 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
67 /// # fn main() -> DecoderResult<()> {
68 /// # let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
69 /// let decoder = DecoderBuilder::new()
70 /// .with_config_yaml_str(config_yaml)
71 /// .build()?;
72 /// assert_eq!(decoder.score_threshold, 0.5);
73 /// assert_eq!(decoder.iou_threshold, 0.5);
74 ///
75 /// # Ok(())
76 /// # }
77 /// ```
78 pub fn new() -> Self {
79 Self::default()
80 }
81
82 /// Loads a model configuration in YAML format. Does not check if the string
83 /// is a correct configuration file. Use `DecoderBuilder.build()` to
84 /// deserialize the YAML and parse the model configuration.
85 ///
86 /// # Examples
87 /// ```rust
88 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
89 /// # fn main() -> DecoderResult<()> {
90 /// let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
91 /// let decoder = DecoderBuilder::new()
92 /// .with_config_yaml_str(config_yaml)
93 /// .build()?;
94 ///
95 /// # Ok(())
96 /// # }
97 /// ```
98 pub fn with_config_yaml_str(mut self, yaml_str: String) -> Self {
99 self.config_src.replace(ConfigSource::Yaml(yaml_str));
100 self
101 }
102
103 /// Loads a model configuration in JSON format. Does not check if the string
104 /// is a correct configuration file. Use `DecoderBuilder.build()` to
105 /// deserialize the JSON and parse the model configuration.
106 ///
107 /// # Examples
108 /// ```rust
109 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
110 /// # fn main() -> DecoderResult<()> {
111 /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
112 /// let decoder = DecoderBuilder::new()
113 /// .with_config_json_str(config_json)
114 /// .build()?;
115 ///
116 /// # Ok(())
117 /// # }
118 /// ```
119 pub fn with_config_json_str(mut self, json_str: String) -> Self {
120 self.config_src.replace(ConfigSource::Json(json_str));
121 self
122 }
123
124 /// Loads a model configuration. Does not check if the configuration is
125 /// correct. Intended to be used when the user needs control over the
126 /// deserialize of the configuration information. Use
127 /// `DecoderBuilder.build()` to parse the model configuration.
128 ///
129 /// # Examples
130 /// ```rust
131 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
132 /// # fn main() -> DecoderResult<()> {
133 /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json"));
134 /// let config = serde_json::from_str(config_json)?;
135 /// let decoder = DecoderBuilder::new().with_config(config).build()?;
136 ///
137 /// # Ok(())
138 /// # }
139 /// ```
140 pub fn with_config(mut self, config: ConfigOutputs) -> Self {
141 self.config_src.replace(ConfigSource::Config(config));
142 self
143 }
144
145 /// Loads a YOLO detection model configuration. Use
146 /// `DecoderBuilder.build()` to parse the model configuration.
147 ///
148 /// # Examples
149 /// ```rust
150 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
151 /// # fn main() -> DecoderResult<()> {
152 /// let decoder = DecoderBuilder::new()
153 /// .with_config_yolo_det(
154 /// configs::Detection {
155 /// anchors: None,
156 /// decoder: configs::DecoderType::Ultralytics,
157 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
158 /// shape: vec![1, 84, 8400],
159 /// dshape: Vec::new(),
160 /// normalized: Some(true),
161 /// },
162 /// None,
163 /// )
164 /// .build()?;
165 ///
166 /// # Ok(())
167 /// # }
168 /// ```
169 pub fn with_config_yolo_det(
170 mut self,
171 boxes: configs::Detection,
172 version: Option<DecoderVersion>,
173 ) -> Self {
174 let config = ConfigOutputs {
175 outputs: vec![ConfigOutput::Detection(boxes)],
176 decoder_version: version,
177 ..Default::default()
178 };
179 self.config_src.replace(ConfigSource::Config(config));
180 self
181 }
182
183 /// Loads a YOLO split detection model configuration. Use
184 /// `DecoderBuilder.build()` to parse the model configuration.
185 ///
186 /// # Examples
187 /// ```rust
188 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
189 /// # fn main() -> DecoderResult<()> {
190 /// let boxes_config = configs::Boxes {
191 /// decoder: configs::DecoderType::Ultralytics,
192 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
193 /// shape: vec![1, 4, 8400],
194 /// dshape: Vec::new(),
195 /// normalized: Some(true),
196 /// };
197 /// let scores_config = configs::Scores {
198 /// decoder: configs::DecoderType::Ultralytics,
199 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
200 /// shape: vec![1, 80, 8400],
201 /// dshape: Vec::new(),
202 /// };
203 /// let decoder = DecoderBuilder::new()
204 /// .with_config_yolo_split_det(boxes_config, scores_config)
205 /// .build()?;
206 /// # Ok(())
207 /// # }
208 /// ```
209 pub fn with_config_yolo_split_det(
210 mut self,
211 boxes: configs::Boxes,
212 scores: configs::Scores,
213 ) -> Self {
214 let config = ConfigOutputs {
215 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
216 ..Default::default()
217 };
218 self.config_src.replace(ConfigSource::Config(config));
219 self
220 }
221
222 /// Loads a YOLO segmentation model configuration. Use
223 /// `DecoderBuilder.build()` to parse the model configuration.
224 ///
225 /// # Examples
226 /// ```rust
227 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
228 /// # fn main() -> DecoderResult<()> {
229 /// let seg_config = configs::Detection {
230 /// decoder: configs::DecoderType::Ultralytics,
231 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
232 /// shape: vec![1, 116, 8400],
233 /// anchors: None,
234 /// dshape: Vec::new(),
235 /// normalized: Some(true),
236 /// };
237 /// let protos_config = configs::Protos {
238 /// decoder: configs::DecoderType::Ultralytics,
239 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
240 /// shape: vec![1, 160, 160, 32],
241 /// dshape: Vec::new(),
242 /// };
243 /// let decoder = DecoderBuilder::new()
244 /// .with_config_yolo_segdet(
245 /// seg_config,
246 /// protos_config,
247 /// Some(configs::DecoderVersion::Yolov8),
248 /// )
249 /// .build()?;
250 /// # Ok(())
251 /// # }
252 /// ```
253 pub fn with_config_yolo_segdet(
254 mut self,
255 boxes: configs::Detection,
256 protos: configs::Protos,
257 version: Option<DecoderVersion>,
258 ) -> Self {
259 let config = ConfigOutputs {
260 outputs: vec![ConfigOutput::Detection(boxes), ConfigOutput::Protos(protos)],
261 decoder_version: version,
262 ..Default::default()
263 };
264 self.config_src.replace(ConfigSource::Config(config));
265 self
266 }
267
268 /// Loads a YOLO split segmentation model configuration. Use
269 /// `DecoderBuilder.build()` to parse the model configuration.
270 ///
271 /// # Examples
272 /// ```rust
273 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
274 /// # fn main() -> DecoderResult<()> {
275 /// let boxes_config = configs::Boxes {
276 /// decoder: configs::DecoderType::Ultralytics,
277 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
278 /// shape: vec![1, 4, 8400],
279 /// dshape: Vec::new(),
280 /// normalized: Some(true),
281 /// };
282 /// let scores_config = configs::Scores {
283 /// decoder: configs::DecoderType::Ultralytics,
284 /// quantization: Some(configs::QuantTuple(0.012345, 14)),
285 /// shape: vec![1, 80, 8400],
286 /// dshape: Vec::new(),
287 /// };
288 /// let mask_config = configs::MaskCoefficients {
289 /// decoder: configs::DecoderType::Ultralytics,
290 /// quantization: Some(configs::QuantTuple(0.0064123, 125)),
291 /// shape: vec![1, 32, 8400],
292 /// dshape: Vec::new(),
293 /// };
294 /// let protos_config = configs::Protos {
295 /// decoder: configs::DecoderType::Ultralytics,
296 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
297 /// shape: vec![1, 160, 160, 32],
298 /// dshape: Vec::new(),
299 /// };
300 /// let decoder = DecoderBuilder::new()
301 /// .with_config_yolo_split_segdet(boxes_config, scores_config, mask_config, protos_config)
302 /// .build()?;
303 /// # Ok(())
304 /// # }
305 /// ```
306 pub fn with_config_yolo_split_segdet(
307 mut self,
308 boxes: configs::Boxes,
309 scores: configs::Scores,
310 mask_coefficients: configs::MaskCoefficients,
311 protos: configs::Protos,
312 ) -> Self {
313 let config = ConfigOutputs {
314 outputs: vec![
315 ConfigOutput::Boxes(boxes),
316 ConfigOutput::Scores(scores),
317 ConfigOutput::MaskCoefficients(mask_coefficients),
318 ConfigOutput::Protos(protos),
319 ],
320 ..Default::default()
321 };
322 self.config_src.replace(ConfigSource::Config(config));
323 self
324 }
325
326 /// Loads a ModelPack detection model configuration. Use
327 /// `DecoderBuilder.build()` to parse the model configuration.
328 ///
329 /// # Examples
330 /// ```rust
331 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
332 /// # fn main() -> DecoderResult<()> {
333 /// let boxes_config = configs::Boxes {
334 /// decoder: configs::DecoderType::ModelPack,
335 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
336 /// shape: vec![1, 8400, 1, 4],
337 /// dshape: Vec::new(),
338 /// normalized: Some(true),
339 /// };
340 /// let scores_config = configs::Scores {
341 /// decoder: configs::DecoderType::ModelPack,
342 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
343 /// shape: vec![1, 8400, 3],
344 /// dshape: Vec::new(),
345 /// };
346 /// let decoder = DecoderBuilder::new()
347 /// .with_config_modelpack_det(boxes_config, scores_config)
348 /// .build()?;
349 /// # Ok(())
350 /// # }
351 /// ```
352 pub fn with_config_modelpack_det(
353 mut self,
354 boxes: configs::Boxes,
355 scores: configs::Scores,
356 ) -> Self {
357 let config = ConfigOutputs {
358 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
359 ..Default::default()
360 };
361 self.config_src.replace(ConfigSource::Config(config));
362 self
363 }
364
365 /// Loads a ModelPack split detection model configuration. Use
366 /// `DecoderBuilder.build()` to parse the model configuration.
367 ///
368 /// # Examples
369 /// ```rust
370 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
371 /// # fn main() -> DecoderResult<()> {
372 /// let config0 = configs::Detection {
373 /// anchors: Some(vec![
374 /// [0.13750000298023224, 0.2074074000120163],
375 /// [0.2541666626930237, 0.21481481194496155],
376 /// [0.23125000298023224, 0.35185185074806213],
377 /// ]),
378 /// decoder: configs::DecoderType::ModelPack,
379 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
380 /// shape: vec![1, 17, 30, 18],
381 /// dshape: Vec::new(),
382 /// normalized: Some(true),
383 /// };
384 /// let config1 = configs::Detection {
385 /// anchors: Some(vec![
386 /// [0.36666667461395264, 0.31481480598449707],
387 /// [0.38749998807907104, 0.4740740656852722],
388 /// [0.5333333611488342, 0.644444465637207],
389 /// ]),
390 /// decoder: configs::DecoderType::ModelPack,
391 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
392 /// shape: vec![1, 9, 15, 18],
393 /// dshape: Vec::new(),
394 /// normalized: Some(true),
395 /// };
396 ///
397 /// let decoder = DecoderBuilder::new()
398 /// .with_config_modelpack_det_split(vec![config0, config1])
399 /// .build()?;
400 /// # Ok(())
401 /// # }
402 /// ```
403 pub fn with_config_modelpack_det_split(mut self, boxes: Vec<configs::Detection>) -> Self {
404 let outputs = boxes.into_iter().map(ConfigOutput::Detection).collect();
405 let config = ConfigOutputs {
406 outputs,
407 ..Default::default()
408 };
409 self.config_src.replace(ConfigSource::Config(config));
410 self
411 }
412
413 /// Loads a ModelPack segmentation detection model configuration. Use
414 /// `DecoderBuilder.build()` to parse the model configuration.
415 ///
416 /// # Examples
417 /// ```rust
418 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
419 /// # fn main() -> DecoderResult<()> {
420 /// let boxes_config = configs::Boxes {
421 /// decoder: configs::DecoderType::ModelPack,
422 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
423 /// shape: vec![1, 8400, 1, 4],
424 /// dshape: Vec::new(),
425 /// normalized: Some(true),
426 /// };
427 /// let scores_config = configs::Scores {
428 /// decoder: configs::DecoderType::ModelPack,
429 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
430 /// shape: vec![1, 8400, 2],
431 /// dshape: Vec::new(),
432 /// };
433 /// let seg_config = configs::Segmentation {
434 /// decoder: configs::DecoderType::ModelPack,
435 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
436 /// shape: vec![1, 640, 640, 3],
437 /// dshape: Vec::new(),
438 /// };
439 /// let decoder = DecoderBuilder::new()
440 /// .with_config_modelpack_segdet(boxes_config, scores_config, seg_config)
441 /// .build()?;
442 /// # Ok(())
443 /// # }
444 /// ```
445 pub fn with_config_modelpack_segdet(
446 mut self,
447 boxes: configs::Boxes,
448 scores: configs::Scores,
449 segmentation: configs::Segmentation,
450 ) -> Self {
451 let config = ConfigOutputs {
452 outputs: vec![
453 ConfigOutput::Boxes(boxes),
454 ConfigOutput::Scores(scores),
455 ConfigOutput::Segmentation(segmentation),
456 ],
457 ..Default::default()
458 };
459 self.config_src.replace(ConfigSource::Config(config));
460 self
461 }
462
463 /// Loads a ModelPack segmentation split detection model configuration. Use
464 /// `DecoderBuilder.build()` to parse the model configuration.
465 ///
466 /// # Examples
467 /// ```rust
468 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
469 /// # fn main() -> DecoderResult<()> {
470 /// let config0 = configs::Detection {
471 /// anchors: Some(vec![
472 /// [0.36666667461395264, 0.31481480598449707],
473 /// [0.38749998807907104, 0.4740740656852722],
474 /// [0.5333333611488342, 0.644444465637207],
475 /// ]),
476 /// decoder: configs::DecoderType::ModelPack,
477 /// quantization: Some(configs::QuantTuple(0.08547406643629074, 174)),
478 /// shape: vec![1, 9, 15, 18],
479 /// dshape: Vec::new(),
480 /// normalized: Some(true),
481 /// };
482 /// let config1 = configs::Detection {
483 /// anchors: Some(vec![
484 /// [0.13750000298023224, 0.2074074000120163],
485 /// [0.2541666626930237, 0.21481481194496155],
486 /// [0.23125000298023224, 0.35185185074806213],
487 /// ]),
488 /// decoder: configs::DecoderType::ModelPack,
489 /// quantization: Some(configs::QuantTuple(0.09929127991199493, 183)),
490 /// shape: vec![1, 17, 30, 18],
491 /// dshape: Vec::new(),
492 /// normalized: Some(true),
493 /// };
494 /// let seg_config = configs::Segmentation {
495 /// decoder: configs::DecoderType::ModelPack,
496 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
497 /// shape: vec![1, 640, 640, 2],
498 /// dshape: Vec::new(),
499 /// };
500 /// let decoder = DecoderBuilder::new()
501 /// .with_config_modelpack_segdet_split(vec![config0, config1], seg_config)
502 /// .build()?;
503 /// # Ok(())
504 /// # }
505 /// ```
506 pub fn with_config_modelpack_segdet_split(
507 mut self,
508 boxes: Vec<configs::Detection>,
509 segmentation: configs::Segmentation,
510 ) -> Self {
511 let mut outputs = boxes
512 .into_iter()
513 .map(ConfigOutput::Detection)
514 .collect::<Vec<_>>();
515 outputs.push(ConfigOutput::Segmentation(segmentation));
516 let config = ConfigOutputs {
517 outputs,
518 ..Default::default()
519 };
520 self.config_src.replace(ConfigSource::Config(config));
521 self
522 }
523
524 /// Loads a ModelPack segmentation model configuration. Use
525 /// `DecoderBuilder.build()` to parse the model configuration.
526 ///
527 /// # Examples
528 /// ```rust
529 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
530 /// # fn main() -> DecoderResult<()> {
531 /// let seg_config = configs::Segmentation {
532 /// decoder: configs::DecoderType::ModelPack,
533 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
534 /// shape: vec![1, 640, 640, 3],
535 /// dshape: Vec::new(),
536 /// };
537 /// let decoder = DecoderBuilder::new()
538 /// .with_config_modelpack_seg(seg_config)
539 /// .build()?;
540 /// # Ok(())
541 /// # }
542 /// ```
543 pub fn with_config_modelpack_seg(mut self, segmentation: configs::Segmentation) -> Self {
544 let config = ConfigOutputs {
545 outputs: vec![ConfigOutput::Segmentation(segmentation)],
546 ..Default::default()
547 };
548 self.config_src.replace(ConfigSource::Config(config));
549 self
550 }
551
552 /// Add an output to the decoder configuration.
553 ///
554 /// Incrementally builds the model configuration by adding outputs one at
555 /// a time. The decoder resolves the model type from the combination of
556 /// outputs during `build()`.
557 ///
558 /// If `dshape` is non-empty on the output, `shape` is automatically
559 /// derived from it (the size component of each named dimension). This
560 /// prevents conflicts between `shape` and `dshape`.
561 ///
562 /// This uses the programmatic config path. Calling this after
563 /// `with_config_json_str()` or `with_config_yaml_str()` replaces the
564 /// string-based config source.
565 ///
566 /// # Examples
567 /// ```rust
568 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
569 /// # fn main() -> DecoderResult<()> {
570 /// let decoder = DecoderBuilder::new()
571 /// .add_output(ConfigOutput::Scores(configs::Scores {
572 /// decoder: configs::DecoderType::Ultralytics,
573 /// dshape: vec![
574 /// (configs::DimName::Batch, 1),
575 /// (configs::DimName::NumClasses, 80),
576 /// (configs::DimName::NumBoxes, 8400),
577 /// ],
578 /// ..Default::default()
579 /// }))
580 /// .add_output(ConfigOutput::Boxes(configs::Boxes {
581 /// decoder: configs::DecoderType::Ultralytics,
582 /// dshape: vec![
583 /// (configs::DimName::Batch, 1),
584 /// (configs::DimName::BoxCoords, 4),
585 /// (configs::DimName::NumBoxes, 8400),
586 /// ],
587 /// ..Default::default()
588 /// }))
589 /// .build()?;
590 /// # Ok(())
591 /// # }
592 /// ```
593 pub fn add_output(mut self, output: ConfigOutput) -> Self {
594 if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
595 self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
596 }
597 if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
598 config.outputs.push(Self::normalize_output(output));
599 }
600 self
601 }
602
603 /// Sets the decoder version for Ultralytics models.
604 ///
605 /// This is used with `add_output()` to specify the YOLO architecture
606 /// version when it cannot be inferred from the output shapes alone.
607 ///
608 /// - `Yolov5`, `Yolov8`, `Yolo11`: Traditional models requiring external
609 /// NMS
610 /// - `Yolo26`: End-to-end models with NMS embedded in the model graph
611 ///
612 /// # Examples
613 /// ```rust
614 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
615 /// # fn main() -> DecoderResult<()> {
616 /// let decoder = DecoderBuilder::new()
617 /// .add_output(ConfigOutput::Detection(configs::Detection {
618 /// decoder: configs::DecoderType::Ultralytics,
619 /// dshape: vec![
620 /// (configs::DimName::Batch, 1),
621 /// (configs::DimName::NumBoxes, 100),
622 /// (configs::DimName::NumFeatures, 6),
623 /// ],
624 /// ..Default::default()
625 /// }))
626 /// .with_decoder_version(configs::DecoderVersion::Yolo26)
627 /// .build()?;
628 /// # Ok(())
629 /// # }
630 /// ```
631 pub fn with_decoder_version(mut self, version: configs::DecoderVersion) -> Self {
632 if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
633 self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
634 }
635 if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
636 config.decoder_version = Some(version);
637 }
638 self
639 }
640
641 /// Normalize an output: if dshape is non-empty, derive shape from it.
642 fn normalize_output(mut output: ConfigOutput) -> ConfigOutput {
643 fn normalize_shape(shape: &mut Vec<usize>, dshape: &[(configs::DimName, usize)]) {
644 if !dshape.is_empty() {
645 *shape = dshape.iter().map(|(_, size)| *size).collect();
646 }
647 }
648 match &mut output {
649 ConfigOutput::Detection(c) => normalize_shape(&mut c.shape, &c.dshape),
650 ConfigOutput::Boxes(c) => normalize_shape(&mut c.shape, &c.dshape),
651 ConfigOutput::Scores(c) => normalize_shape(&mut c.shape, &c.dshape),
652 ConfigOutput::Protos(c) => normalize_shape(&mut c.shape, &c.dshape),
653 ConfigOutput::Segmentation(c) => normalize_shape(&mut c.shape, &c.dshape),
654 ConfigOutput::MaskCoefficients(c) => normalize_shape(&mut c.shape, &c.dshape),
655 ConfigOutput::Mask(c) => normalize_shape(&mut c.shape, &c.dshape),
656 ConfigOutput::Classes(c) => normalize_shape(&mut c.shape, &c.dshape),
657 }
658 output
659 }
660
661 /// Sets the scores threshold of the decoder
662 ///
663 /// # Examples
664 /// ```rust
665 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
666 /// # fn main() -> DecoderResult<()> {
667 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
668 /// let decoder = DecoderBuilder::new()
669 /// .with_config_json_str(config_json)
670 /// .with_score_threshold(0.654)
671 /// .build()?;
672 /// assert_eq!(decoder.score_threshold, 0.654);
673 /// # Ok(())
674 /// # }
675 /// ```
676 pub fn with_score_threshold(mut self, score_threshold: f32) -> Self {
677 self.score_threshold = score_threshold;
678 self
679 }
680
681 /// Sets the IOU threshold of the decoder. Has no effect when NMS is set to
682 /// `None`
683 ///
684 /// # Examples
685 /// ```rust
686 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
687 /// # fn main() -> DecoderResult<()> {
688 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
689 /// let decoder = DecoderBuilder::new()
690 /// .with_config_json_str(config_json)
691 /// .with_iou_threshold(0.654)
692 /// .build()?;
693 /// assert_eq!(decoder.iou_threshold, 0.654);
694 /// # Ok(())
695 /// # }
696 /// ```
697 pub fn with_iou_threshold(mut self, iou_threshold: f32) -> Self {
698 self.iou_threshold = iou_threshold;
699 self
700 }
701
702 /// Sets the NMS mode for the decoder.
703 ///
704 /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS (default): suppress
705 /// overlapping boxes regardless of class label
706 /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
707 /// share the same class label AND overlap above the IoU threshold
708 /// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
709 ///
710 /// # Examples
711 /// ```rust
712 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::Nms};
713 /// # fn main() -> DecoderResult<()> {
714 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
715 /// let decoder = DecoderBuilder::new()
716 /// .with_config_json_str(config_json)
717 /// .with_nms(Some(Nms::ClassAware))
718 /// .build()?;
719 /// assert_eq!(decoder.nms, Some(Nms::ClassAware));
720 /// # Ok(())
721 /// # }
722 /// ```
723 pub fn with_nms(mut self, nms: Option<configs::Nms>) -> Self {
724 self.nms = nms;
725 self
726 }
727
728 /// Builds the decoder with the given settings. If the config is a JSON or
729 /// YAML string, this will deserialize the JSON or YAML and then parse the
730 /// configuration information.
731 ///
732 /// # Examples
733 /// ```rust
734 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
735 /// # fn main() -> DecoderResult<()> {
736 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
737 /// let decoder = DecoderBuilder::new()
738 /// .with_config_json_str(config_json)
739 /// .with_score_threshold(0.654)
740 /// .build()?;
741 /// # Ok(())
742 /// # }
743 /// ```
744 pub fn build(self) -> Result<Decoder, DecoderError> {
745 let config = match self.config_src {
746 Some(ConfigSource::Json(s)) => serde_json::from_str(&s)?,
747 Some(ConfigSource::Yaml(s)) => serde_yaml::from_str(&s)?,
748 Some(ConfigSource::Config(c)) => c,
749 None => return Err(DecoderError::NoConfig),
750 };
751
752 // Extract normalized flag from config outputs
753 let normalized = Self::get_normalized(&config.outputs);
754
755 // Use NMS from config if present, otherwise use builder's NMS setting
756 let nms = config.nms.or(self.nms);
757 let model_type = Self::get_model_type(config)?;
758
759 Ok(Decoder {
760 model_type,
761 iou_threshold: self.iou_threshold,
762 score_threshold: self.score_threshold,
763 nms,
764 normalized,
765 })
766 }
767
768 /// Extracts the normalized flag from config outputs.
769 /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
770 /// - `Some(false)`: Boxes are in pixel coordinates
771 /// - `None`: Unknown (not specified in config), caller must infer
772 fn get_normalized(outputs: &[ConfigOutput]) -> Option<bool> {
773 for output in outputs {
774 match output {
775 ConfigOutput::Detection(det) => return det.normalized,
776 ConfigOutput::Boxes(boxes) => return boxes.normalized,
777 _ => {}
778 }
779 }
780 None // not specified
781 }
782
783 fn get_model_type(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
784 // yolo or modelpack
785 let mut yolo = false;
786 let mut modelpack = false;
787 for c in &configs.outputs {
788 match c.decoder() {
789 DecoderType::ModelPack => modelpack = true,
790 DecoderType::Ultralytics => yolo = true,
791 }
792 }
793 match (modelpack, yolo) {
794 (true, true) => Err(DecoderError::InvalidConfig(
795 "Both ModelPack and Yolo outputs found in config".to_string(),
796 )),
797 (true, false) => Self::get_model_type_modelpack(configs),
798 (false, true) => Self::get_model_type_yolo(configs),
799 (false, false) => Err(DecoderError::InvalidConfig(
800 "No outputs found in config".to_string(),
801 )),
802 }
803 }
804
805 fn get_model_type_yolo(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
806 let mut boxes = None;
807 let mut protos = None;
808 let mut split_boxes = None;
809 let mut split_scores = None;
810 let mut split_mask_coeff = None;
811 let mut split_classes = None;
812 for c in configs.outputs {
813 match c {
814 ConfigOutput::Detection(detection) => boxes = Some(detection),
815 ConfigOutput::Segmentation(_) => {
816 return Err(DecoderError::InvalidConfig(
817 "Invalid Segmentation output with Yolo decoder".to_string(),
818 ));
819 }
820 ConfigOutput::Protos(protos_) => protos = Some(protos_),
821 ConfigOutput::Mask(_) => {
822 return Err(DecoderError::InvalidConfig(
823 "Invalid Mask output with Yolo decoder".to_string(),
824 ));
825 }
826 ConfigOutput::Scores(scores) => split_scores = Some(scores),
827 ConfigOutput::Boxes(boxes) => split_boxes = Some(boxes),
828 ConfigOutput::MaskCoefficients(mask_coeff) => split_mask_coeff = Some(mask_coeff),
829 ConfigOutput::Classes(classes) => split_classes = Some(classes),
830 }
831 }
832
833 // Use end-to-end model types when:
834 // 1. decoder_version is explicitly set to Yolo26 (definitive), OR
835 // decoder_version is not set but the dshapes are (batch, num_boxes,
836 // num_features)
837 let is_end_to_end_dshape = boxes.as_ref().is_some_and(|b| {
838 let dims = b.dshape.iter().map(|(d, _)| *d).collect::<Vec<_>>();
839 dims == vec![DimName::Batch, DimName::NumBoxes, DimName::NumFeatures]
840 });
841
842 let is_end_to_end = configs
843 .decoder_version
844 .map(|v| v.is_end_to_end())
845 .unwrap_or(is_end_to_end_dshape);
846
847 if is_end_to_end {
848 if let Some(boxes) = boxes {
849 if let Some(protos) = protos {
850 Self::verify_yolo_seg_det_26(&boxes, &protos)?;
851 return Ok(ModelType::YoloEndToEndSegDet { boxes, protos });
852 } else {
853 Self::verify_yolo_det_26(&boxes)?;
854 return Ok(ModelType::YoloEndToEndDet { boxes });
855 }
856 } else if let (Some(split_boxes), Some(split_scores), Some(split_classes)) =
857 (split_boxes, split_scores, split_classes)
858 {
859 if let (Some(split_mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
860 Self::verify_yolo_split_end_to_end_segdet(
861 &split_boxes,
862 &split_scores,
863 &split_classes,
864 &split_mask_coeff,
865 &protos,
866 )?;
867 return Ok(ModelType::YoloSplitEndToEndSegDet {
868 boxes: split_boxes,
869 scores: split_scores,
870 classes: split_classes,
871 mask_coeff: split_mask_coeff,
872 protos,
873 });
874 }
875 Self::verify_yolo_split_end_to_end_det(
876 &split_boxes,
877 &split_scores,
878 &split_classes,
879 )?;
880 return Ok(ModelType::YoloSplitEndToEndDet {
881 boxes: split_boxes,
882 scores: split_scores,
883 classes: split_classes,
884 });
885 } else {
886 return Err(DecoderError::InvalidConfig(
887 "Invalid Yolo end-to-end model outputs".to_string(),
888 ));
889 }
890 }
891
892 if let Some(boxes) = boxes {
893 if let Some(protos) = protos {
894 Self::verify_yolo_seg_det(&boxes, &protos)?;
895 Ok(ModelType::YoloSegDet { boxes, protos })
896 } else {
897 Self::verify_yolo_det(&boxes)?;
898 Ok(ModelType::YoloDet { boxes })
899 }
900 } else if let (Some(boxes), Some(scores)) = (split_boxes, split_scores) {
901 if let (Some(mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
902 Self::verify_yolo_split_segdet(&boxes, &scores, &mask_coeff, &protos)?;
903 Ok(ModelType::YoloSplitSegDet {
904 boxes,
905 scores,
906 mask_coeff,
907 protos,
908 })
909 } else {
910 Self::verify_yolo_split_det(&boxes, &scores)?;
911 Ok(ModelType::YoloSplitDet { boxes, scores })
912 }
913 } else {
914 Err(DecoderError::InvalidConfig(
915 "Invalid Yolo model outputs".to_string(),
916 ))
917 }
918 }
919
920 fn verify_yolo_det(detect: &configs::Detection) -> Result<(), DecoderError> {
921 if detect.shape.len() != 3 {
922 return Err(DecoderError::InvalidConfig(format!(
923 "Invalid Yolo Detection shape {:?}",
924 detect.shape
925 )));
926 }
927
928 Self::verify_dshapes(
929 &detect.dshape,
930 &detect.shape,
931 "Detection",
932 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
933 )?;
934 if !detect.dshape.is_empty() {
935 Self::get_class_count(&detect.dshape, None, None)?;
936 } else {
937 Self::get_class_count_no_dshape(detect.into(), None)?;
938 }
939
940 Ok(())
941 }
942
943 fn verify_yolo_det_26(detect: &configs::Detection) -> Result<(), DecoderError> {
944 if detect.shape.len() != 3 {
945 return Err(DecoderError::InvalidConfig(format!(
946 "Invalid Yolo Detection shape {:?}",
947 detect.shape
948 )));
949 }
950
951 Self::verify_dshapes(
952 &detect.dshape,
953 &detect.shape,
954 "Detection",
955 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
956 )?;
957
958 if !detect.shape.contains(&6) {
959 return Err(DecoderError::InvalidConfig(
960 "Yolo26 Detection must have 6 features".to_string(),
961 ));
962 }
963
964 Ok(())
965 }
966
967 fn verify_yolo_seg_det(
968 detection: &configs::Detection,
969 protos: &configs::Protos,
970 ) -> Result<(), DecoderError> {
971 if detection.shape.len() != 3 {
972 return Err(DecoderError::InvalidConfig(format!(
973 "Invalid Yolo Detection shape {:?}",
974 detection.shape
975 )));
976 }
977 if protos.shape.len() != 4 {
978 return Err(DecoderError::InvalidConfig(format!(
979 "Invalid Yolo Protos shape {:?}",
980 protos.shape
981 )));
982 }
983
984 Self::verify_dshapes(
985 &detection.dshape,
986 &detection.shape,
987 "Detection",
988 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
989 )?;
990 Self::verify_dshapes(
991 &protos.dshape,
992 &protos.shape,
993 "Protos",
994 &[
995 DimName::Batch,
996 DimName::Height,
997 DimName::Width,
998 DimName::NumProtos,
999 ],
1000 )?;
1001
1002 let protos_count = Self::get_protos_count(&protos.dshape)
1003 .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1004 log::debug!("Protos count: {}", protos_count);
1005 log::debug!("Detection dshape: {:?}", detection.dshape);
1006 let classes = if !detection.dshape.is_empty() {
1007 Self::get_class_count(&detection.dshape, Some(protos_count), None)?
1008 } else {
1009 Self::get_class_count_no_dshape(detection.into(), Some(protos_count))?
1010 };
1011
1012 if classes == 0 {
1013 return Err(DecoderError::InvalidConfig(
1014 "Yolo Segmentation Detection has zero classes".to_string(),
1015 ));
1016 }
1017
1018 Ok(())
1019 }
1020
1021 fn verify_yolo_seg_det_26(
1022 detection: &configs::Detection,
1023 protos: &configs::Protos,
1024 ) -> Result<(), DecoderError> {
1025 if detection.shape.len() != 3 {
1026 return Err(DecoderError::InvalidConfig(format!(
1027 "Invalid Yolo Detection shape {:?}",
1028 detection.shape
1029 )));
1030 }
1031 if protos.shape.len() != 4 {
1032 return Err(DecoderError::InvalidConfig(format!(
1033 "Invalid Yolo Protos shape {:?}",
1034 protos.shape
1035 )));
1036 }
1037
1038 Self::verify_dshapes(
1039 &detection.dshape,
1040 &detection.shape,
1041 "Detection",
1042 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1043 )?;
1044 Self::verify_dshapes(
1045 &protos.dshape,
1046 &protos.shape,
1047 "Protos",
1048 &[
1049 DimName::Batch,
1050 DimName::Height,
1051 DimName::Width,
1052 DimName::NumProtos,
1053 ],
1054 )?;
1055
1056 let protos_count = Self::get_protos_count(&protos.dshape)
1057 .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1058 log::debug!("Protos count: {}", protos_count);
1059 log::debug!("Detection dshape: {:?}", detection.dshape);
1060
1061 if !detection.shape.contains(&(6 + protos_count)) {
1062 return Err(DecoderError::InvalidConfig(format!(
1063 "Yolo26 Segmentation Detection must have num_features be 6 + num_protos = {}",
1064 6 + protos_count
1065 )));
1066 }
1067
1068 Ok(())
1069 }
1070
1071 fn verify_yolo_split_det(
1072 boxes: &configs::Boxes,
1073 scores: &configs::Scores,
1074 ) -> Result<(), DecoderError> {
1075 if boxes.shape.len() != 3 {
1076 return Err(DecoderError::InvalidConfig(format!(
1077 "Invalid Yolo Split Boxes shape {:?}",
1078 boxes.shape
1079 )));
1080 }
1081 if scores.shape.len() != 3 {
1082 return Err(DecoderError::InvalidConfig(format!(
1083 "Invalid Yolo Split Scores shape {:?}",
1084 scores.shape
1085 )));
1086 }
1087
1088 Self::verify_dshapes(
1089 &boxes.dshape,
1090 &boxes.shape,
1091 "Boxes",
1092 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1093 )?;
1094 Self::verify_dshapes(
1095 &scores.dshape,
1096 &scores.shape,
1097 "Scores",
1098 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1099 )?;
1100
1101 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1102 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1103
1104 if boxes_num != scores_num {
1105 return Err(DecoderError::InvalidConfig(format!(
1106 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1107 boxes_num, scores_num
1108 )));
1109 }
1110
1111 Ok(())
1112 }
1113
1114 fn verify_yolo_split_segdet(
1115 boxes: &configs::Boxes,
1116 scores: &configs::Scores,
1117 mask_coeff: &configs::MaskCoefficients,
1118 protos: &configs::Protos,
1119 ) -> Result<(), DecoderError> {
1120 if boxes.shape.len() != 3 {
1121 return Err(DecoderError::InvalidConfig(format!(
1122 "Invalid Yolo Split Boxes shape {:?}",
1123 boxes.shape
1124 )));
1125 }
1126 if scores.shape.len() != 3 {
1127 return Err(DecoderError::InvalidConfig(format!(
1128 "Invalid Yolo Split Scores shape {:?}",
1129 scores.shape
1130 )));
1131 }
1132
1133 if mask_coeff.shape.len() != 3 {
1134 return Err(DecoderError::InvalidConfig(format!(
1135 "Invalid Yolo Split Mask Coefficients shape {:?}",
1136 mask_coeff.shape
1137 )));
1138 }
1139
1140 if protos.shape.len() != 4 {
1141 return Err(DecoderError::InvalidConfig(format!(
1142 "Invalid Yolo Protos shape {:?}",
1143 mask_coeff.shape
1144 )));
1145 }
1146
1147 Self::verify_dshapes(
1148 &boxes.dshape,
1149 &boxes.shape,
1150 "Boxes",
1151 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1152 )?;
1153 Self::verify_dshapes(
1154 &scores.dshape,
1155 &scores.shape,
1156 "Scores",
1157 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1158 )?;
1159 Self::verify_dshapes(
1160 &mask_coeff.dshape,
1161 &mask_coeff.shape,
1162 "Mask Coefficients",
1163 &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1164 )?;
1165 Self::verify_dshapes(
1166 &protos.dshape,
1167 &protos.shape,
1168 "Protos",
1169 &[
1170 DimName::Batch,
1171 DimName::Height,
1172 DimName::Width,
1173 DimName::NumProtos,
1174 ],
1175 )?;
1176
1177 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1178 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1179 let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1180
1181 let mask_channels = if !mask_coeff.dshape.is_empty() {
1182 Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1183 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1184 })?
1185 } else {
1186 mask_coeff.shape[1]
1187 };
1188 let proto_channels = if !protos.dshape.is_empty() {
1189 Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1190 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1191 })?
1192 } else {
1193 protos.shape[1].min(protos.shape[3])
1194 };
1195
1196 if boxes_num != scores_num {
1197 return Err(DecoderError::InvalidConfig(format!(
1198 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1199 boxes_num, scores_num
1200 )));
1201 }
1202
1203 if boxes_num != mask_num {
1204 return Err(DecoderError::InvalidConfig(format!(
1205 "Yolo Split Detection Boxes num {} incompatible with Mask Coefficients num {}",
1206 boxes_num, mask_num
1207 )));
1208 }
1209
1210 if proto_channels != mask_channels {
1211 return Err(DecoderError::InvalidConfig(format!(
1212 "Yolo Protos channels {} incompatible with Mask Coefficients channels {}",
1213 proto_channels, mask_channels
1214 )));
1215 }
1216
1217 Ok(())
1218 }
1219
1220 fn verify_yolo_split_end_to_end_det(
1221 boxes: &configs::Boxes,
1222 scores: &configs::Scores,
1223 classes: &configs::Classes,
1224 ) -> Result<(), DecoderError> {
1225 if boxes.shape.len() != 3 || !boxes.shape.contains(&4) {
1226 return Err(DecoderError::InvalidConfig(format!(
1227 "Split end-to-end boxes must be [batch, N, 4], got {:?}",
1228 boxes.shape
1229 )));
1230 }
1231 if scores.shape.len() != 3 || !scores.shape.contains(&1) {
1232 return Err(DecoderError::InvalidConfig(format!(
1233 "Split end-to-end scores must be [batch, N, 1], got {:?}",
1234 scores.shape
1235 )));
1236 }
1237 if classes.shape.len() != 3 || !classes.shape.contains(&1) {
1238 return Err(DecoderError::InvalidConfig(format!(
1239 "Split end-to-end classes must be [batch, N, 1], got {:?}",
1240 classes.shape
1241 )));
1242 }
1243 Ok(())
1244 }
1245
1246 fn verify_yolo_split_end_to_end_segdet(
1247 boxes: &configs::Boxes,
1248 scores: &configs::Scores,
1249 classes: &configs::Classes,
1250 mask_coeff: &configs::MaskCoefficients,
1251 protos: &configs::Protos,
1252 ) -> Result<(), DecoderError> {
1253 Self::verify_yolo_split_end_to_end_det(boxes, scores, classes)?;
1254 if mask_coeff.shape.len() != 3 {
1255 return Err(DecoderError::InvalidConfig(format!(
1256 "Invalid split end-to-end mask coefficients shape {:?}",
1257 mask_coeff.shape
1258 )));
1259 }
1260 if protos.shape.len() != 4 {
1261 return Err(DecoderError::InvalidConfig(format!(
1262 "Invalid protos shape {:?}",
1263 protos.shape
1264 )));
1265 }
1266 Ok(())
1267 }
1268
1269 fn get_model_type_modelpack(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1270 let mut split_decoders = Vec::new();
1271 let mut segment_ = None;
1272 let mut scores_ = None;
1273 let mut boxes_ = None;
1274 for c in configs.outputs {
1275 match c {
1276 ConfigOutput::Detection(detection) => split_decoders.push(detection),
1277 ConfigOutput::Segmentation(segmentation) => segment_ = Some(segmentation),
1278 ConfigOutput::Mask(_) => {}
1279 ConfigOutput::Protos(_) => {
1280 return Err(DecoderError::InvalidConfig(
1281 "ModelPack should not have protos".to_string(),
1282 ));
1283 }
1284 ConfigOutput::Scores(scores) => scores_ = Some(scores),
1285 ConfigOutput::Boxes(boxes) => boxes_ = Some(boxes),
1286 ConfigOutput::MaskCoefficients(_) => {
1287 return Err(DecoderError::InvalidConfig(
1288 "ModelPack should not have mask coefficients".to_string(),
1289 ));
1290 }
1291 ConfigOutput::Classes(_) => {
1292 return Err(DecoderError::InvalidConfig(
1293 "ModelPack should not have classes output".to_string(),
1294 ));
1295 }
1296 }
1297 }
1298
1299 if let Some(segmentation) = segment_ {
1300 if !split_decoders.is_empty() {
1301 let classes = Self::verify_modelpack_split_det(&split_decoders)?;
1302 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1303 Ok(ModelType::ModelPackSegDetSplit {
1304 detection: split_decoders,
1305 segmentation,
1306 })
1307 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1308 let classes = Self::verify_modelpack_det(&boxes, &scores)?;
1309 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1310 Ok(ModelType::ModelPackSegDet {
1311 boxes,
1312 scores,
1313 segmentation,
1314 })
1315 } else {
1316 Self::verify_modelpack_seg(&segmentation, None)?;
1317 Ok(ModelType::ModelPackSeg { segmentation })
1318 }
1319 } else if !split_decoders.is_empty() {
1320 Self::verify_modelpack_split_det(&split_decoders)?;
1321 Ok(ModelType::ModelPackDetSplit {
1322 detection: split_decoders,
1323 })
1324 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1325 Self::verify_modelpack_det(&boxes, &scores)?;
1326 Ok(ModelType::ModelPackDet { boxes, scores })
1327 } else {
1328 Err(DecoderError::InvalidConfig(
1329 "Invalid ModelPack model outputs".to_string(),
1330 ))
1331 }
1332 }
1333
1334 fn verify_modelpack_det(
1335 boxes: &configs::Boxes,
1336 scores: &configs::Scores,
1337 ) -> Result<usize, DecoderError> {
1338 if boxes.shape.len() != 4 {
1339 return Err(DecoderError::InvalidConfig(format!(
1340 "Invalid ModelPack Boxes shape {:?}",
1341 boxes.shape
1342 )));
1343 }
1344 if scores.shape.len() != 3 {
1345 return Err(DecoderError::InvalidConfig(format!(
1346 "Invalid ModelPack Scores shape {:?}",
1347 scores.shape
1348 )));
1349 }
1350
1351 Self::verify_dshapes(
1352 &boxes.dshape,
1353 &boxes.shape,
1354 "Boxes",
1355 &[
1356 DimName::Batch,
1357 DimName::NumBoxes,
1358 DimName::Padding,
1359 DimName::BoxCoords,
1360 ],
1361 )?;
1362 Self::verify_dshapes(
1363 &scores.dshape,
1364 &scores.shape,
1365 "Scores",
1366 &[DimName::Batch, DimName::NumBoxes, DimName::NumClasses],
1367 )?;
1368
1369 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[1]);
1370 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[1]);
1371
1372 if boxes_num != scores_num {
1373 return Err(DecoderError::InvalidConfig(format!(
1374 "ModelPack Detection Boxes num {} incompatible with Scores num {}",
1375 boxes_num, scores_num
1376 )));
1377 }
1378
1379 let num_classes = if !scores.dshape.is_empty() {
1380 Self::get_class_count(&scores.dshape, None, None)?
1381 } else {
1382 Self::get_class_count_no_dshape(scores.into(), None)?
1383 };
1384
1385 Ok(num_classes)
1386 }
1387
1388 fn verify_modelpack_split_det(boxes: &[configs::Detection]) -> Result<usize, DecoderError> {
1389 let mut num_classes = None;
1390 for b in boxes {
1391 let Some(num_anchors) = b.anchors.as_ref().map(|a| a.len()) else {
1392 return Err(DecoderError::InvalidConfig(
1393 "ModelPack Split Detection missing anchors".to_string(),
1394 ));
1395 };
1396
1397 if num_anchors == 0 {
1398 return Err(DecoderError::InvalidConfig(
1399 "ModelPack Split Detection has zero anchors".to_string(),
1400 ));
1401 }
1402
1403 if b.shape.len() != 4 {
1404 return Err(DecoderError::InvalidConfig(format!(
1405 "Invalid ModelPack Split Detection shape {:?}",
1406 b.shape
1407 )));
1408 }
1409
1410 Self::verify_dshapes(
1411 &b.dshape,
1412 &b.shape,
1413 "Split Detection",
1414 &[
1415 DimName::Batch,
1416 DimName::Height,
1417 DimName::Width,
1418 DimName::NumAnchorsXFeatures,
1419 ],
1420 )?;
1421 let classes = if !b.dshape.is_empty() {
1422 Self::get_class_count(&b.dshape, None, Some(num_anchors))?
1423 } else {
1424 Self::get_class_count_no_dshape(b.into(), None)?
1425 };
1426
1427 match num_classes {
1428 Some(n) => {
1429 if n != classes {
1430 return Err(DecoderError::InvalidConfig(format!(
1431 "ModelPack Split Detection inconsistent number of classes: previous {}, current {}",
1432 n, classes
1433 )));
1434 }
1435 }
1436 None => {
1437 num_classes = Some(classes);
1438 }
1439 }
1440 }
1441
1442 Ok(num_classes.unwrap_or(0))
1443 }
1444
1445 fn verify_modelpack_seg(
1446 segmentation: &configs::Segmentation,
1447 classes: Option<usize>,
1448 ) -> Result<(), DecoderError> {
1449 if segmentation.shape.len() != 4 {
1450 return Err(DecoderError::InvalidConfig(format!(
1451 "Invalid ModelPack Segmentation shape {:?}",
1452 segmentation.shape
1453 )));
1454 }
1455 Self::verify_dshapes(
1456 &segmentation.dshape,
1457 &segmentation.shape,
1458 "Segmentation",
1459 &[
1460 DimName::Batch,
1461 DimName::Height,
1462 DimName::Width,
1463 DimName::NumClasses,
1464 ],
1465 )?;
1466
1467 if let Some(classes) = classes {
1468 let seg_classes = if !segmentation.dshape.is_empty() {
1469 Self::get_class_count(&segmentation.dshape, None, None)?
1470 } else {
1471 Self::get_class_count_no_dshape(segmentation.into(), None)?
1472 };
1473
1474 if seg_classes != classes + 1 {
1475 return Err(DecoderError::InvalidConfig(format!(
1476 "ModelPack Segmentation channels {} incompatible with number of classes {}",
1477 seg_classes, classes
1478 )));
1479 }
1480 }
1481 Ok(())
1482 }
1483
1484 // verifies that dshapes match the given shape
1485 fn verify_dshapes(
1486 dshape: &[(DimName, usize)],
1487 shape: &[usize],
1488 name: &str,
1489 dims: &[DimName],
1490 ) -> Result<(), DecoderError> {
1491 for s in shape {
1492 if *s == 0 {
1493 return Err(DecoderError::InvalidConfig(format!(
1494 "{} shape has zero dimension",
1495 name
1496 )));
1497 }
1498 }
1499
1500 if shape.len() != dims.len() {
1501 return Err(DecoderError::InvalidConfig(format!(
1502 "{} shape length {} does not match expected dims length {}",
1503 name,
1504 shape.len(),
1505 dims.len()
1506 )));
1507 }
1508
1509 if dshape.is_empty() {
1510 return Ok(());
1511 }
1512 // check the dshape lengths match the shape lengths
1513 if dshape.len() != shape.len() {
1514 return Err(DecoderError::InvalidConfig(format!(
1515 "{} dshape length does not match shape length",
1516 name
1517 )));
1518 }
1519
1520 // check the dshape values match the shape values
1521 for ((dim_name, dim_size), shape_size) in dshape.iter().zip(shape) {
1522 if dim_size != shape_size {
1523 return Err(DecoderError::InvalidConfig(format!(
1524 "{} dshape dimension {} size {} does not match shape size {}",
1525 name, dim_name, dim_size, shape_size
1526 )));
1527 }
1528 if *dim_name == DimName::Padding && *dim_size != 1 {
1529 return Err(DecoderError::InvalidConfig(
1530 "Padding dimension size must be 1".to_string(),
1531 ));
1532 }
1533
1534 if *dim_name == DimName::BoxCoords && *dim_size != 4 {
1535 return Err(DecoderError::InvalidConfig(
1536 "BoxCoords dimension size must be 4".to_string(),
1537 ));
1538 }
1539 }
1540
1541 let dims_present = HashSet::<DimName>::from_iter(dshape.iter().map(|(name, _)| *name));
1542 for dim in dims {
1543 if !dims_present.contains(dim) {
1544 return Err(DecoderError::InvalidConfig(format!(
1545 "{} dshape missing required dimension {:?}",
1546 name, dim
1547 )));
1548 }
1549 }
1550
1551 Ok(())
1552 }
1553
1554 fn get_box_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1555 for (dim_name, dim_size) in dshape {
1556 if *dim_name == DimName::NumBoxes {
1557 return Some(*dim_size);
1558 }
1559 }
1560 None
1561 }
1562
1563 fn get_class_count_no_dshape(
1564 config: ConfigOutputRef,
1565 protos: Option<usize>,
1566 ) -> Result<usize, DecoderError> {
1567 match config {
1568 ConfigOutputRef::Detection(detection) => match detection.decoder {
1569 DecoderType::Ultralytics => {
1570 if detection.shape[1] <= 4 + protos.unwrap_or(0) {
1571 return Err(DecoderError::InvalidConfig(format!(
1572 "Invalid shape: Yolo num_features {} must be greater than {}",
1573 detection.shape[1],
1574 4 + protos.unwrap_or(0),
1575 )));
1576 }
1577 Ok(detection.shape[1] - 4 - protos.unwrap_or(0))
1578 }
1579 DecoderType::ModelPack => {
1580 let Some(num_anchors) = detection.anchors.as_ref().map(|a| a.len()) else {
1581 return Err(DecoderError::Internal(
1582 "ModelPack Detection missing anchors".to_string(),
1583 ));
1584 };
1585 let anchors_x_features = detection.shape[3];
1586 if anchors_x_features <= num_anchors * 5 {
1587 return Err(DecoderError::InvalidConfig(format!(
1588 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
1589 anchors_x_features,
1590 num_anchors * 5,
1591 )));
1592 }
1593
1594 if !anchors_x_features.is_multiple_of(num_anchors) {
1595 return Err(DecoderError::InvalidConfig(format!(
1596 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
1597 anchors_x_features, num_anchors
1598 )));
1599 }
1600 Ok(anchors_x_features / num_anchors - 5)
1601 }
1602 },
1603
1604 ConfigOutputRef::Scores(scores) => match scores.decoder {
1605 DecoderType::Ultralytics => Ok(scores.shape[1]),
1606 DecoderType::ModelPack => Ok(scores.shape[2]),
1607 },
1608 ConfigOutputRef::Segmentation(seg) => Ok(seg.shape[3]),
1609 _ => Err(DecoderError::Internal(
1610 "Attempted to get class count from unsupported config output".to_owned(),
1611 )),
1612 }
1613 }
1614
1615 // get the class count from dshape or calculate from num_features
1616 fn get_class_count(
1617 dshape: &[(DimName, usize)],
1618 protos: Option<usize>,
1619 anchors: Option<usize>,
1620 ) -> Result<usize, DecoderError> {
1621 if dshape.is_empty() {
1622 return Ok(0);
1623 }
1624 // if it has num_classes in dshape, return it
1625 for (dim_name, dim_size) in dshape {
1626 if *dim_name == DimName::NumClasses {
1627 return Ok(*dim_size);
1628 }
1629 }
1630
1631 // number of classes can be calculated from num_features - 4 for yolo. If the
1632 // model has protos, we also subtract the number of protos.
1633 for (dim_name, dim_size) in dshape {
1634 if *dim_name == DimName::NumFeatures {
1635 let protos = protos.unwrap_or(0);
1636 if protos + 4 >= *dim_size {
1637 return Err(DecoderError::InvalidConfig(format!(
1638 "Invalid shape: Yolo num_features {} must be greater than {}",
1639 *dim_size,
1640 protos + 4,
1641 )));
1642 }
1643 return Ok(*dim_size - 4 - protos);
1644 }
1645 }
1646
1647 // number of classes can be calculated from number of anchors for modelpack
1648 // split detection
1649 if let Some(num_anchors) = anchors {
1650 for (dim_name, dim_size) in dshape {
1651 if *dim_name == DimName::NumAnchorsXFeatures {
1652 let anchors_x_features = *dim_size;
1653 if anchors_x_features <= num_anchors * 5 {
1654 return Err(DecoderError::InvalidConfig(format!(
1655 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
1656 anchors_x_features,
1657 num_anchors * 5,
1658 )));
1659 }
1660
1661 if !anchors_x_features.is_multiple_of(num_anchors) {
1662 return Err(DecoderError::InvalidConfig(format!(
1663 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
1664 anchors_x_features, num_anchors
1665 )));
1666 }
1667 return Ok((anchors_x_features / num_anchors) - 5);
1668 }
1669 }
1670 }
1671 Err(DecoderError::InvalidConfig(
1672 "Cannot determine number of classes from dshape".to_owned(),
1673 ))
1674 }
1675
1676 fn get_protos_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1677 for (dim_name, dim_size) in dshape {
1678 if *dim_name == DimName::NumProtos {
1679 return Some(*dim_size);
1680 }
1681 }
1682 None
1683 }
1684}