edgefirst_decoder/decoder/builder.rs
1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashSet;
5
6use super::config::ConfigOutputRef;
7use super::configs::{self, DecoderType, DecoderVersion, DimName, ModelType};
8use super::merge::DecodeProgram;
9use super::{ConfigOutput, ConfigOutputs, Decoder};
10use crate::schema::SchemaV2;
11use crate::DecoderError;
12
13#[derive(Debug, Clone, PartialEq)]
14pub struct DecoderBuilder {
15 config_src: Option<ConfigSource>,
16 iou_threshold: f32,
17 score_threshold: f32,
18 /// NMS mode: Some(mode) applies NMS, None bypasses NMS (for end-to-end
19 /// models)
20 nms: Option<configs::Nms>,
21}
22
23#[derive(Debug, Clone, PartialEq)]
24enum ConfigSource {
25 Yaml(String),
26 Json(String),
27 Config(ConfigOutputs),
28 /// Schema v2 metadata. During build the schema is either converted
29 /// to a legacy [`ConfigOutputs`] (flat case) or used to construct a
30 /// [`DecodeProgram`] that performs per-child dequant + merge at
31 /// decode time.
32 Schema(SchemaV2),
33}
34
35impl Default for DecoderBuilder {
36 /// Creates a default DecoderBuilder with no configuration and 0.5 score
37 /// threshold and 0.5 IoU threshold.
38 ///
39 /// A valid configuration must be provided before building the Decoder.
40 ///
41 /// # Examples
42 /// ```rust
43 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
44 /// # fn main() -> DecoderResult<()> {
45 /// # let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
46 /// let decoder = DecoderBuilder::default()
47 /// .with_config_yaml_str(config_yaml)
48 /// .build()?;
49 /// assert_eq!(decoder.score_threshold, 0.5);
50 /// assert_eq!(decoder.iou_threshold, 0.5);
51 ///
52 /// # Ok(())
53 /// # }
54 /// ```
55 fn default() -> Self {
56 Self {
57 config_src: None,
58 iou_threshold: 0.5,
59 score_threshold: 0.5,
60 nms: Some(configs::Nms::ClassAgnostic),
61 }
62 }
63}
64
65impl DecoderBuilder {
66 /// Creates a default DecoderBuilder with no configuration and 0.5 score
67 /// threshold and 0.5 IoU threshold.
68 ///
69 /// A valid configuration must be provided before building the Decoder.
70 ///
71 /// # Examples
72 /// ```rust
73 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
74 /// # fn main() -> DecoderResult<()> {
75 /// # let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
76 /// let decoder = DecoderBuilder::new()
77 /// .with_config_yaml_str(config_yaml)
78 /// .build()?;
79 /// assert_eq!(decoder.score_threshold, 0.5);
80 /// assert_eq!(decoder.iou_threshold, 0.5);
81 ///
82 /// # Ok(())
83 /// # }
84 /// ```
85 pub fn new() -> Self {
86 Self::default()
87 }
88
89 /// Loads a model configuration in YAML format. Does not check if the string
90 /// is a correct configuration file. Use `DecoderBuilder.build()` to
91 /// deserialize the YAML and parse the model configuration.
92 ///
93 /// # Examples
94 /// ```rust
95 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
96 /// # fn main() -> DecoderResult<()> {
97 /// let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
98 /// let decoder = DecoderBuilder::new()
99 /// .with_config_yaml_str(config_yaml)
100 /// .build()?;
101 ///
102 /// # Ok(())
103 /// # }
104 /// ```
105 pub fn with_config_yaml_str(mut self, yaml_str: String) -> Self {
106 self.config_src.replace(ConfigSource::Yaml(yaml_str));
107 self
108 }
109
110 /// Loads a model configuration in JSON format. Does not check if the string
111 /// is a correct configuration file. Use `DecoderBuilder.build()` to
112 /// deserialize the JSON and parse the model configuration.
113 ///
114 /// # Examples
115 /// ```rust
116 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
117 /// # fn main() -> DecoderResult<()> {
118 /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
119 /// let decoder = DecoderBuilder::new()
120 /// .with_config_json_str(config_json)
121 /// .build()?;
122 ///
123 /// # Ok(())
124 /// # }
125 /// ```
126 pub fn with_config_json_str(mut self, json_str: String) -> Self {
127 self.config_src.replace(ConfigSource::Json(json_str));
128 self
129 }
130
131 /// Loads a model configuration. Does not check if the configuration is
132 /// correct. Intended to be used when the user needs control over the
133 /// deserialize of the configuration information. Use
134 /// `DecoderBuilder.build()` to parse the model configuration.
135 ///
136 /// # Examples
137 /// ```rust
138 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
139 /// # fn main() -> DecoderResult<()> {
140 /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json"));
141 /// let config = serde_json::from_str(config_json)?;
142 /// let decoder = DecoderBuilder::new().with_config(config).build()?;
143 ///
144 /// # Ok(())
145 /// # }
146 /// ```
147 pub fn with_config(mut self, config: ConfigOutputs) -> Self {
148 self.config_src.replace(ConfigSource::Config(config));
149 self
150 }
151
152 /// Configure the decoder from a schema v2 metadata document.
153 ///
154 /// Accepts a [`SchemaV2`] as produced by [`SchemaV2::parse_json`],
155 /// [`SchemaV2::parse_yaml`], [`SchemaV2::parse_file`], or
156 /// constructed programmatically. The builder validates the schema,
157 /// compiles a [`DecodeProgram`] for any split logical outputs
158 /// (per-scale or channel sub-splits), and downconverts the
159 /// logical-level semantics to the legacy [`ConfigOutputs`]
160 /// representation consumed by the existing decoder dispatch.
161 ///
162 /// # Examples
163 ///
164 /// ```rust,no_run
165 /// use edgefirst_decoder::{DecoderBuilder, DecoderResult};
166 /// use edgefirst_decoder::schema::SchemaV2;
167 ///
168 /// # fn main() -> DecoderResult<()> {
169 /// let schema = SchemaV2::parse_file("model/edgefirst.json")?;
170 /// let decoder = DecoderBuilder::new()
171 /// .with_schema(schema)
172 /// .with_score_threshold(0.25)
173 /// .build()?;
174 /// # Ok(())
175 /// # }
176 /// ```
177 pub fn with_schema(mut self, schema: SchemaV2) -> Self {
178 self.config_src.replace(ConfigSource::Schema(schema));
179 self
180 }
181
182 /// Loads a YOLO detection model configuration. Use
183 /// `DecoderBuilder.build()` to parse the model configuration.
184 ///
185 /// # Examples
186 /// ```rust
187 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
188 /// # fn main() -> DecoderResult<()> {
189 /// let decoder = DecoderBuilder::new()
190 /// .with_config_yolo_det(
191 /// configs::Detection {
192 /// anchors: None,
193 /// decoder: configs::DecoderType::Ultralytics,
194 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
195 /// shape: vec![1, 84, 8400],
196 /// dshape: Vec::new(),
197 /// normalized: Some(true),
198 /// },
199 /// None,
200 /// )
201 /// .build()?;
202 ///
203 /// # Ok(())
204 /// # }
205 /// ```
206 pub fn with_config_yolo_det(
207 mut self,
208 boxes: configs::Detection,
209 version: Option<DecoderVersion>,
210 ) -> Self {
211 let config = ConfigOutputs {
212 outputs: vec![ConfigOutput::Detection(boxes)],
213 decoder_version: version,
214 ..Default::default()
215 };
216 self.config_src.replace(ConfigSource::Config(config));
217 self
218 }
219
220 /// Loads a YOLO split detection model configuration. Use
221 /// `DecoderBuilder.build()` to parse the model configuration.
222 ///
223 /// # Examples
224 /// ```rust
225 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
226 /// # fn main() -> DecoderResult<()> {
227 /// let boxes_config = configs::Boxes {
228 /// decoder: configs::DecoderType::Ultralytics,
229 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
230 /// shape: vec![1, 4, 8400],
231 /// dshape: Vec::new(),
232 /// normalized: Some(true),
233 /// };
234 /// let scores_config = configs::Scores {
235 /// decoder: configs::DecoderType::Ultralytics,
236 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
237 /// shape: vec![1, 80, 8400],
238 /// dshape: Vec::new(),
239 /// };
240 /// let decoder = DecoderBuilder::new()
241 /// .with_config_yolo_split_det(boxes_config, scores_config)
242 /// .build()?;
243 /// # Ok(())
244 /// # }
245 /// ```
246 pub fn with_config_yolo_split_det(
247 mut self,
248 boxes: configs::Boxes,
249 scores: configs::Scores,
250 ) -> Self {
251 let config = ConfigOutputs {
252 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
253 ..Default::default()
254 };
255 self.config_src.replace(ConfigSource::Config(config));
256 self
257 }
258
259 /// Loads a YOLO segmentation model configuration. Use
260 /// `DecoderBuilder.build()` to parse the model configuration.
261 ///
262 /// # Examples
263 /// ```rust
264 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
265 /// # fn main() -> DecoderResult<()> {
266 /// let seg_config = configs::Detection {
267 /// decoder: configs::DecoderType::Ultralytics,
268 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
269 /// shape: vec![1, 116, 8400],
270 /// anchors: None,
271 /// dshape: Vec::new(),
272 /// normalized: Some(true),
273 /// };
274 /// let protos_config = configs::Protos {
275 /// decoder: configs::DecoderType::Ultralytics,
276 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
277 /// shape: vec![1, 160, 160, 32],
278 /// dshape: Vec::new(),
279 /// };
280 /// let decoder = DecoderBuilder::new()
281 /// .with_config_yolo_segdet(
282 /// seg_config,
283 /// protos_config,
284 /// Some(configs::DecoderVersion::Yolov8),
285 /// )
286 /// .build()?;
287 /// # Ok(())
288 /// # }
289 /// ```
290 pub fn with_config_yolo_segdet(
291 mut self,
292 boxes: configs::Detection,
293 protos: configs::Protos,
294 version: Option<DecoderVersion>,
295 ) -> Self {
296 let config = ConfigOutputs {
297 outputs: vec![ConfigOutput::Detection(boxes), ConfigOutput::Protos(protos)],
298 decoder_version: version,
299 ..Default::default()
300 };
301 self.config_src.replace(ConfigSource::Config(config));
302 self
303 }
304
305 /// Loads a YOLO split segmentation model configuration. Use
306 /// `DecoderBuilder.build()` to parse the model configuration.
307 ///
308 /// # Examples
309 /// ```rust
310 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
311 /// # fn main() -> DecoderResult<()> {
312 /// let boxes_config = configs::Boxes {
313 /// decoder: configs::DecoderType::Ultralytics,
314 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
315 /// shape: vec![1, 4, 8400],
316 /// dshape: Vec::new(),
317 /// normalized: Some(true),
318 /// };
319 /// let scores_config = configs::Scores {
320 /// decoder: configs::DecoderType::Ultralytics,
321 /// quantization: Some(configs::QuantTuple(0.012345, 14)),
322 /// shape: vec![1, 80, 8400],
323 /// dshape: Vec::new(),
324 /// };
325 /// let mask_config = configs::MaskCoefficients {
326 /// decoder: configs::DecoderType::Ultralytics,
327 /// quantization: Some(configs::QuantTuple(0.0064123, 125)),
328 /// shape: vec![1, 32, 8400],
329 /// dshape: Vec::new(),
330 /// };
331 /// let protos_config = configs::Protos {
332 /// decoder: configs::DecoderType::Ultralytics,
333 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
334 /// shape: vec![1, 160, 160, 32],
335 /// dshape: Vec::new(),
336 /// };
337 /// let decoder = DecoderBuilder::new()
338 /// .with_config_yolo_split_segdet(boxes_config, scores_config, mask_config, protos_config)
339 /// .build()?;
340 /// # Ok(())
341 /// # }
342 /// ```
343 pub fn with_config_yolo_split_segdet(
344 mut self,
345 boxes: configs::Boxes,
346 scores: configs::Scores,
347 mask_coefficients: configs::MaskCoefficients,
348 protos: configs::Protos,
349 ) -> Self {
350 let config = ConfigOutputs {
351 outputs: vec![
352 ConfigOutput::Boxes(boxes),
353 ConfigOutput::Scores(scores),
354 ConfigOutput::MaskCoefficients(mask_coefficients),
355 ConfigOutput::Protos(protos),
356 ],
357 ..Default::default()
358 };
359 self.config_src.replace(ConfigSource::Config(config));
360 self
361 }
362
363 /// Loads a ModelPack detection model configuration. Use
364 /// `DecoderBuilder.build()` to parse the model configuration.
365 ///
366 /// # Examples
367 /// ```rust
368 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
369 /// # fn main() -> DecoderResult<()> {
370 /// let boxes_config = configs::Boxes {
371 /// decoder: configs::DecoderType::ModelPack,
372 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
373 /// shape: vec![1, 8400, 1, 4],
374 /// dshape: Vec::new(),
375 /// normalized: Some(true),
376 /// };
377 /// let scores_config = configs::Scores {
378 /// decoder: configs::DecoderType::ModelPack,
379 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
380 /// shape: vec![1, 8400, 3],
381 /// dshape: Vec::new(),
382 /// };
383 /// let decoder = DecoderBuilder::new()
384 /// .with_config_modelpack_det(boxes_config, scores_config)
385 /// .build()?;
386 /// # Ok(())
387 /// # }
388 /// ```
389 pub fn with_config_modelpack_det(
390 mut self,
391 boxes: configs::Boxes,
392 scores: configs::Scores,
393 ) -> Self {
394 let config = ConfigOutputs {
395 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
396 ..Default::default()
397 };
398 self.config_src.replace(ConfigSource::Config(config));
399 self
400 }
401
402 /// Loads a ModelPack split detection model configuration. Use
403 /// `DecoderBuilder.build()` to parse the model configuration.
404 ///
405 /// # Examples
406 /// ```rust
407 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
408 /// # fn main() -> DecoderResult<()> {
409 /// let config0 = configs::Detection {
410 /// anchors: Some(vec![
411 /// [0.13750000298023224, 0.2074074000120163],
412 /// [0.2541666626930237, 0.21481481194496155],
413 /// [0.23125000298023224, 0.35185185074806213],
414 /// ]),
415 /// decoder: configs::DecoderType::ModelPack,
416 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
417 /// shape: vec![1, 17, 30, 18],
418 /// dshape: Vec::new(),
419 /// normalized: Some(true),
420 /// };
421 /// let config1 = configs::Detection {
422 /// anchors: Some(vec![
423 /// [0.36666667461395264, 0.31481480598449707],
424 /// [0.38749998807907104, 0.4740740656852722],
425 /// [0.5333333611488342, 0.644444465637207],
426 /// ]),
427 /// decoder: configs::DecoderType::ModelPack,
428 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
429 /// shape: vec![1, 9, 15, 18],
430 /// dshape: Vec::new(),
431 /// normalized: Some(true),
432 /// };
433 ///
434 /// let decoder = DecoderBuilder::new()
435 /// .with_config_modelpack_det_split(vec![config0, config1])
436 /// .build()?;
437 /// # Ok(())
438 /// # }
439 /// ```
440 pub fn with_config_modelpack_det_split(mut self, boxes: Vec<configs::Detection>) -> Self {
441 let outputs = boxes.into_iter().map(ConfigOutput::Detection).collect();
442 let config = ConfigOutputs {
443 outputs,
444 ..Default::default()
445 };
446 self.config_src.replace(ConfigSource::Config(config));
447 self
448 }
449
450 /// Loads a ModelPack segmentation detection model configuration. Use
451 /// `DecoderBuilder.build()` to parse the model configuration.
452 ///
453 /// # Examples
454 /// ```rust
455 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
456 /// # fn main() -> DecoderResult<()> {
457 /// let boxes_config = configs::Boxes {
458 /// decoder: configs::DecoderType::ModelPack,
459 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
460 /// shape: vec![1, 8400, 1, 4],
461 /// dshape: Vec::new(),
462 /// normalized: Some(true),
463 /// };
464 /// let scores_config = configs::Scores {
465 /// decoder: configs::DecoderType::ModelPack,
466 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
467 /// shape: vec![1, 8400, 2],
468 /// dshape: Vec::new(),
469 /// };
470 /// let seg_config = configs::Segmentation {
471 /// decoder: configs::DecoderType::ModelPack,
472 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
473 /// shape: vec![1, 640, 640, 3],
474 /// dshape: Vec::new(),
475 /// };
476 /// let decoder = DecoderBuilder::new()
477 /// .with_config_modelpack_segdet(boxes_config, scores_config, seg_config)
478 /// .build()?;
479 /// # Ok(())
480 /// # }
481 /// ```
482 pub fn with_config_modelpack_segdet(
483 mut self,
484 boxes: configs::Boxes,
485 scores: configs::Scores,
486 segmentation: configs::Segmentation,
487 ) -> Self {
488 let config = ConfigOutputs {
489 outputs: vec![
490 ConfigOutput::Boxes(boxes),
491 ConfigOutput::Scores(scores),
492 ConfigOutput::Segmentation(segmentation),
493 ],
494 ..Default::default()
495 };
496 self.config_src.replace(ConfigSource::Config(config));
497 self
498 }
499
500 /// Loads a ModelPack segmentation split detection model configuration. Use
501 /// `DecoderBuilder.build()` to parse the model configuration.
502 ///
503 /// # Examples
504 /// ```rust
505 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
506 /// # fn main() -> DecoderResult<()> {
507 /// let config0 = configs::Detection {
508 /// anchors: Some(vec![
509 /// [0.36666667461395264, 0.31481480598449707],
510 /// [0.38749998807907104, 0.4740740656852722],
511 /// [0.5333333611488342, 0.644444465637207],
512 /// ]),
513 /// decoder: configs::DecoderType::ModelPack,
514 /// quantization: Some(configs::QuantTuple(0.08547406643629074, 174)),
515 /// shape: vec![1, 9, 15, 18],
516 /// dshape: Vec::new(),
517 /// normalized: Some(true),
518 /// };
519 /// let config1 = configs::Detection {
520 /// anchors: Some(vec![
521 /// [0.13750000298023224, 0.2074074000120163],
522 /// [0.2541666626930237, 0.21481481194496155],
523 /// [0.23125000298023224, 0.35185185074806213],
524 /// ]),
525 /// decoder: configs::DecoderType::ModelPack,
526 /// quantization: Some(configs::QuantTuple(0.09929127991199493, 183)),
527 /// shape: vec![1, 17, 30, 18],
528 /// dshape: Vec::new(),
529 /// normalized: Some(true),
530 /// };
531 /// let seg_config = configs::Segmentation {
532 /// decoder: configs::DecoderType::ModelPack,
533 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
534 /// shape: vec![1, 640, 640, 2],
535 /// dshape: Vec::new(),
536 /// };
537 /// let decoder = DecoderBuilder::new()
538 /// .with_config_modelpack_segdet_split(vec![config0, config1], seg_config)
539 /// .build()?;
540 /// # Ok(())
541 /// # }
542 /// ```
543 pub fn with_config_modelpack_segdet_split(
544 mut self,
545 boxes: Vec<configs::Detection>,
546 segmentation: configs::Segmentation,
547 ) -> Self {
548 let mut outputs = boxes
549 .into_iter()
550 .map(ConfigOutput::Detection)
551 .collect::<Vec<_>>();
552 outputs.push(ConfigOutput::Segmentation(segmentation));
553 let config = ConfigOutputs {
554 outputs,
555 ..Default::default()
556 };
557 self.config_src.replace(ConfigSource::Config(config));
558 self
559 }
560
561 /// Loads a ModelPack segmentation model configuration. Use
562 /// `DecoderBuilder.build()` to parse the model configuration.
563 ///
564 /// # Examples
565 /// ```rust
566 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
567 /// # fn main() -> DecoderResult<()> {
568 /// let seg_config = configs::Segmentation {
569 /// decoder: configs::DecoderType::ModelPack,
570 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
571 /// shape: vec![1, 640, 640, 3],
572 /// dshape: Vec::new(),
573 /// };
574 /// let decoder = DecoderBuilder::new()
575 /// .with_config_modelpack_seg(seg_config)
576 /// .build()?;
577 /// # Ok(())
578 /// # }
579 /// ```
580 pub fn with_config_modelpack_seg(mut self, segmentation: configs::Segmentation) -> Self {
581 let config = ConfigOutputs {
582 outputs: vec![ConfigOutput::Segmentation(segmentation)],
583 ..Default::default()
584 };
585 self.config_src.replace(ConfigSource::Config(config));
586 self
587 }
588
589 /// Add an output to the decoder configuration.
590 ///
591 /// Incrementally builds the model configuration by adding outputs one at
592 /// a time. The decoder resolves the model type from the combination of
593 /// outputs during `build()`.
594 ///
595 /// If `dshape` is non-empty on the output, `shape` is automatically
596 /// derived from it (the size component of each named dimension). This
597 /// prevents conflicts between `shape` and `dshape`.
598 ///
599 /// This uses the programmatic config path. Calling this after
600 /// `with_config_json_str()` or `with_config_yaml_str()` replaces the
601 /// string-based config source.
602 ///
603 /// # Examples
604 /// ```rust
605 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
606 /// # fn main() -> DecoderResult<()> {
607 /// let decoder = DecoderBuilder::new()
608 /// .add_output(ConfigOutput::Scores(configs::Scores {
609 /// decoder: configs::DecoderType::Ultralytics,
610 /// dshape: vec![
611 /// (configs::DimName::Batch, 1),
612 /// (configs::DimName::NumClasses, 80),
613 /// (configs::DimName::NumBoxes, 8400),
614 /// ],
615 /// ..Default::default()
616 /// }))
617 /// .add_output(ConfigOutput::Boxes(configs::Boxes {
618 /// decoder: configs::DecoderType::Ultralytics,
619 /// dshape: vec![
620 /// (configs::DimName::Batch, 1),
621 /// (configs::DimName::BoxCoords, 4),
622 /// (configs::DimName::NumBoxes, 8400),
623 /// ],
624 /// ..Default::default()
625 /// }))
626 /// .build()?;
627 /// # Ok(())
628 /// # }
629 /// ```
630 pub fn add_output(mut self, output: ConfigOutput) -> Self {
631 if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
632 self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
633 }
634 if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
635 config.outputs.push(Self::normalize_output(output));
636 }
637 self
638 }
639
640 /// Sets the decoder version for Ultralytics models.
641 ///
642 /// This is used with `add_output()` to specify the YOLO architecture
643 /// version when it cannot be inferred from the output shapes alone.
644 ///
645 /// - `Yolov5`, `Yolov8`, `Yolo11`: Traditional models requiring external
646 /// NMS
647 /// - `Yolo26`: End-to-end models with NMS embedded in the model graph
648 ///
649 /// # Examples
650 /// ```rust
651 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
652 /// # fn main() -> DecoderResult<()> {
653 /// let decoder = DecoderBuilder::new()
654 /// .add_output(ConfigOutput::Detection(configs::Detection {
655 /// decoder: configs::DecoderType::Ultralytics,
656 /// dshape: vec![
657 /// (configs::DimName::Batch, 1),
658 /// (configs::DimName::NumBoxes, 100),
659 /// (configs::DimName::NumFeatures, 6),
660 /// ],
661 /// ..Default::default()
662 /// }))
663 /// .with_decoder_version(configs::DecoderVersion::Yolo26)
664 /// .build()?;
665 /// # Ok(())
666 /// # }
667 /// ```
668 pub fn with_decoder_version(mut self, version: configs::DecoderVersion) -> Self {
669 if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
670 self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
671 }
672 if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
673 config.decoder_version = Some(version);
674 }
675 self
676 }
677
678 /// Normalize an output: if dshape is non-empty, derive shape from it.
679 fn normalize_output(mut output: ConfigOutput) -> ConfigOutput {
680 fn normalize_shape(shape: &mut Vec<usize>, dshape: &[(configs::DimName, usize)]) {
681 if !dshape.is_empty() {
682 *shape = dshape.iter().map(|(_, size)| *size).collect();
683 }
684 }
685 match &mut output {
686 ConfigOutput::Detection(c) => normalize_shape(&mut c.shape, &c.dshape),
687 ConfigOutput::Boxes(c) => normalize_shape(&mut c.shape, &c.dshape),
688 ConfigOutput::Scores(c) => normalize_shape(&mut c.shape, &c.dshape),
689 ConfigOutput::Protos(c) => normalize_shape(&mut c.shape, &c.dshape),
690 ConfigOutput::Segmentation(c) => normalize_shape(&mut c.shape, &c.dshape),
691 ConfigOutput::MaskCoefficients(c) => normalize_shape(&mut c.shape, &c.dshape),
692 ConfigOutput::Mask(c) => normalize_shape(&mut c.shape, &c.dshape),
693 ConfigOutput::Classes(c) => normalize_shape(&mut c.shape, &c.dshape),
694 }
695 output
696 }
697
698 /// Sets the scores threshold of the decoder
699 ///
700 /// # Examples
701 /// ```rust
702 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
703 /// # fn main() -> DecoderResult<()> {
704 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
705 /// let decoder = DecoderBuilder::new()
706 /// .with_config_json_str(config_json)
707 /// .with_score_threshold(0.654)
708 /// .build()?;
709 /// assert_eq!(decoder.score_threshold, 0.654);
710 /// # Ok(())
711 /// # }
712 /// ```
713 pub fn with_score_threshold(mut self, score_threshold: f32) -> Self {
714 self.score_threshold = score_threshold;
715 self
716 }
717
718 /// Sets the IOU threshold of the decoder. Has no effect when NMS is set to
719 /// `None`
720 ///
721 /// # Examples
722 /// ```rust
723 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
724 /// # fn main() -> DecoderResult<()> {
725 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
726 /// let decoder = DecoderBuilder::new()
727 /// .with_config_json_str(config_json)
728 /// .with_iou_threshold(0.654)
729 /// .build()?;
730 /// assert_eq!(decoder.iou_threshold, 0.654);
731 /// # Ok(())
732 /// # }
733 /// ```
734 pub fn with_iou_threshold(mut self, iou_threshold: f32) -> Self {
735 self.iou_threshold = iou_threshold;
736 self
737 }
738
739 /// Sets the NMS mode for the decoder.
740 ///
741 /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS (default): suppress
742 /// overlapping boxes regardless of class label
743 /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
744 /// share the same class label AND overlap above the IoU threshold
745 /// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
746 ///
747 /// # Examples
748 /// ```rust
749 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::Nms};
750 /// # fn main() -> DecoderResult<()> {
751 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
752 /// let decoder = DecoderBuilder::new()
753 /// .with_config_json_str(config_json)
754 /// .with_nms(Some(Nms::ClassAware))
755 /// .build()?;
756 /// assert_eq!(decoder.nms, Some(Nms::ClassAware));
757 /// # Ok(())
758 /// # }
759 /// ```
760 pub fn with_nms(mut self, nms: Option<configs::Nms>) -> Self {
761 self.nms = nms;
762 self
763 }
764
765 /// Builds the decoder with the given settings. If the config is a JSON or
766 /// YAML string, this will deserialize the JSON or YAML and then parse the
767 /// configuration information.
768 ///
769 /// # Examples
770 /// ```rust
771 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
772 /// # fn main() -> DecoderResult<()> {
773 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
774 /// let decoder = DecoderBuilder::new()
775 /// .with_config_json_str(config_json)
776 /// .with_score_threshold(0.654)
777 /// .build()?;
778 /// # Ok(())
779 /// # }
780 /// ```
781 pub fn build(self) -> Result<Decoder, DecoderError> {
782 let (config, decode_program) = match self.config_src {
783 Some(ConfigSource::Json(s)) => Self::build_from_schema(SchemaV2::parse_json(&s)?)?,
784 Some(ConfigSource::Yaml(s)) => Self::build_from_schema(SchemaV2::parse_yaml(&s)?)?,
785 Some(ConfigSource::Config(c)) => (c, None),
786 Some(ConfigSource::Schema(schema)) => Self::build_from_schema(schema)?,
787 None => return Err(DecoderError::NoConfig),
788 };
789
790 // Extract normalized flag from config outputs
791 let normalized = Self::get_normalized(&config.outputs);
792
793 // Use NMS from config if present, otherwise use builder's NMS setting
794 let nms = config.nms.or(self.nms);
795 let model_type = Self::get_model_type(config)?;
796
797 Ok(Decoder {
798 model_type,
799 iou_threshold: self.iou_threshold,
800 score_threshold: self.score_threshold,
801 nms,
802 normalized,
803 decode_program,
804 })
805 }
806
807 /// Validate a [`SchemaV2`] and lower it to the (legacy `ConfigOutputs`,
808 /// optional `DecodeProgram`) pair the rest of `build()` consumes.
809 ///
810 /// Centralises the v2 lowering so JSON, YAML, and direct
811 /// `with_schema` callers all go through the same validation and
812 /// merge-program construction. `SchemaV2::parse_json` /
813 /// `parse_yaml` already auto-detect v1 vs v2 input and return a v2
814 /// schema either way (v1 inputs are upgraded in memory via
815 /// `from_v1`), so this helper is the sole place that turns a
816 /// schema into builder-ready state.
817 fn build_from_schema(
818 schema: SchemaV2,
819 ) -> Result<(ConfigOutputs, Option<DecodeProgram>), DecoderError> {
820 schema.validate()?;
821 let program = DecodeProgram::try_from_schema(&schema)?;
822 let legacy = schema.to_legacy_config_outputs()?;
823 Ok((legacy, program))
824 }
825
826 /// Extracts the normalized flag from config outputs.
827 /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
828 /// - `Some(false)`: Boxes are in pixel coordinates
829 /// - `None`: Unknown (not specified in config), caller must infer
830 fn get_normalized(outputs: &[ConfigOutput]) -> Option<bool> {
831 for output in outputs {
832 match output {
833 ConfigOutput::Detection(det) => return det.normalized,
834 ConfigOutput::Boxes(boxes) => return boxes.normalized,
835 _ => {}
836 }
837 }
838 None // not specified
839 }
840
841 fn get_model_type(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
842 // yolo or modelpack
843 let mut yolo = false;
844 let mut modelpack = false;
845 for c in &configs.outputs {
846 match c.decoder() {
847 DecoderType::ModelPack => modelpack = true,
848 DecoderType::Ultralytics => yolo = true,
849 }
850 }
851 match (modelpack, yolo) {
852 (true, true) => Err(DecoderError::InvalidConfig(
853 "Both ModelPack and Yolo outputs found in config".to_string(),
854 )),
855 (true, false) => Self::get_model_type_modelpack(configs),
856 (false, true) => Self::get_model_type_yolo(configs),
857 (false, false) => Err(DecoderError::InvalidConfig(
858 "No outputs found in config".to_string(),
859 )),
860 }
861 }
862
863 fn get_model_type_yolo(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
864 let mut boxes = None;
865 let mut protos = None;
866 let mut split_boxes = None;
867 let mut split_scores = None;
868 let mut split_mask_coeff = None;
869 let mut split_classes = None;
870 for c in configs.outputs {
871 match c {
872 ConfigOutput::Detection(detection) => boxes = Some(detection),
873 ConfigOutput::Segmentation(_) => {
874 return Err(DecoderError::InvalidConfig(
875 "Invalid Segmentation output with Yolo decoder".to_string(),
876 ));
877 }
878 ConfigOutput::Protos(protos_) => protos = Some(protos_),
879 ConfigOutput::Mask(_) => {
880 return Err(DecoderError::InvalidConfig(
881 "Invalid Mask output with Yolo decoder".to_string(),
882 ));
883 }
884 ConfigOutput::Scores(scores) => split_scores = Some(scores),
885 ConfigOutput::Boxes(boxes) => split_boxes = Some(boxes),
886 ConfigOutput::MaskCoefficients(mask_coeff) => split_mask_coeff = Some(mask_coeff),
887 ConfigOutput::Classes(classes) => split_classes = Some(classes),
888 }
889 }
890
891 // Use end-to-end model types when:
892 // 1. decoder_version is explicitly set to Yolo26 (definitive), OR
893 // decoder_version is not set but the dshapes are (batch, num_boxes,
894 // num_features)
895 let is_end_to_end_dshape = boxes.as_ref().is_some_and(|b| {
896 let dims = b.dshape.iter().map(|(d, _)| *d).collect::<Vec<_>>();
897 dims == vec![DimName::Batch, DimName::NumBoxes, DimName::NumFeatures]
898 });
899
900 let is_end_to_end = configs
901 .decoder_version
902 .map(|v| v.is_end_to_end())
903 .unwrap_or(is_end_to_end_dshape);
904
905 if is_end_to_end {
906 if let Some(boxes) = boxes {
907 if let Some(protos) = protos {
908 Self::verify_yolo_seg_det_26(&boxes, &protos)?;
909 return Ok(ModelType::YoloEndToEndSegDet { boxes, protos });
910 } else {
911 Self::verify_yolo_det_26(&boxes)?;
912 return Ok(ModelType::YoloEndToEndDet { boxes });
913 }
914 } else if let (Some(split_boxes), Some(split_scores), Some(split_classes)) =
915 (split_boxes, split_scores, split_classes)
916 {
917 if let (Some(split_mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
918 Self::verify_yolo_split_end_to_end_segdet(
919 &split_boxes,
920 &split_scores,
921 &split_classes,
922 &split_mask_coeff,
923 &protos,
924 )?;
925 return Ok(ModelType::YoloSplitEndToEndSegDet {
926 boxes: split_boxes,
927 scores: split_scores,
928 classes: split_classes,
929 mask_coeff: split_mask_coeff,
930 protos,
931 });
932 }
933 Self::verify_yolo_split_end_to_end_det(
934 &split_boxes,
935 &split_scores,
936 &split_classes,
937 )?;
938 return Ok(ModelType::YoloSplitEndToEndDet {
939 boxes: split_boxes,
940 scores: split_scores,
941 classes: split_classes,
942 });
943 } else {
944 return Err(DecoderError::InvalidConfig(
945 "Invalid Yolo end-to-end model outputs".to_string(),
946 ));
947 }
948 }
949
950 if let Some(boxes) = boxes {
951 match (split_mask_coeff, protos) {
952 (Some(mask_coeff), Some(protos)) => {
953 // 2-way split: combined detection + separate mask_coeff + protos
954 Self::verify_yolo_seg_det_2way(&boxes, &mask_coeff, &protos)?;
955 Ok(ModelType::YoloSegDet2Way {
956 boxes,
957 mask_coeff,
958 protos,
959 })
960 }
961 (_, Some(protos)) => {
962 // Unsplit: mask_coefs embedded in detection tensor
963 Self::verify_yolo_seg_det(&boxes, &protos)?;
964 Ok(ModelType::YoloSegDet { boxes, protos })
965 }
966 _ => {
967 Self::verify_yolo_det(&boxes)?;
968 Ok(ModelType::YoloDet { boxes })
969 }
970 }
971 } else if let (Some(boxes), Some(scores)) = (split_boxes, split_scores) {
972 if let (Some(mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
973 Self::verify_yolo_split_segdet(&boxes, &scores, &mask_coeff, &protos)?;
974 Ok(ModelType::YoloSplitSegDet {
975 boxes,
976 scores,
977 mask_coeff,
978 protos,
979 })
980 } else {
981 Self::verify_yolo_split_det(&boxes, &scores)?;
982 Ok(ModelType::YoloSplitDet { boxes, scores })
983 }
984 } else {
985 Err(DecoderError::InvalidConfig(
986 "Invalid Yolo model outputs".to_string(),
987 ))
988 }
989 }
990
991 fn verify_yolo_det(detect: &configs::Detection) -> Result<(), DecoderError> {
992 if detect.shape.len() != 3 {
993 return Err(DecoderError::InvalidConfig(format!(
994 "Invalid Yolo Detection shape {:?}",
995 detect.shape
996 )));
997 }
998
999 Self::verify_dshapes(
1000 &detect.dshape,
1001 &detect.shape,
1002 "Detection",
1003 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1004 )?;
1005 if !detect.dshape.is_empty() {
1006 Self::get_class_count(&detect.dshape, None, None)?;
1007 } else {
1008 Self::get_class_count_no_dshape(detect.into(), None)?;
1009 }
1010
1011 Ok(())
1012 }
1013
1014 fn verify_yolo_det_26(detect: &configs::Detection) -> Result<(), DecoderError> {
1015 if detect.shape.len() != 3 {
1016 return Err(DecoderError::InvalidConfig(format!(
1017 "Invalid Yolo Detection shape {:?}",
1018 detect.shape
1019 )));
1020 }
1021
1022 Self::verify_dshapes(
1023 &detect.dshape,
1024 &detect.shape,
1025 "Detection",
1026 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1027 )?;
1028
1029 if !detect.shape.contains(&6) {
1030 return Err(DecoderError::InvalidConfig(
1031 "Yolo26 Detection must have 6 features".to_string(),
1032 ));
1033 }
1034
1035 Ok(())
1036 }
1037
1038 fn verify_yolo_seg_det(
1039 detection: &configs::Detection,
1040 protos: &configs::Protos,
1041 ) -> Result<(), DecoderError> {
1042 if detection.shape.len() != 3 {
1043 return Err(DecoderError::InvalidConfig(format!(
1044 "Invalid Yolo Detection shape {:?}",
1045 detection.shape
1046 )));
1047 }
1048 if protos.shape.len() != 4 {
1049 return Err(DecoderError::InvalidConfig(format!(
1050 "Invalid Yolo Protos shape {:?}",
1051 protos.shape
1052 )));
1053 }
1054
1055 Self::verify_dshapes(
1056 &detection.dshape,
1057 &detection.shape,
1058 "Detection",
1059 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1060 )?;
1061 Self::verify_dshapes(
1062 &protos.dshape,
1063 &protos.shape,
1064 "Protos",
1065 &[
1066 DimName::Batch,
1067 DimName::Height,
1068 DimName::Width,
1069 DimName::NumProtos,
1070 ],
1071 )?;
1072
1073 let protos_count = Self::get_protos_count(&protos.dshape)
1074 .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1075 log::debug!("Protos count: {}", protos_count);
1076 log::debug!("Detection dshape: {:?}", detection.dshape);
1077 let classes = if !detection.dshape.is_empty() {
1078 Self::get_class_count(&detection.dshape, Some(protos_count), None)?
1079 } else {
1080 Self::get_class_count_no_dshape(detection.into(), Some(protos_count))?
1081 };
1082
1083 if classes == 0 {
1084 return Err(DecoderError::InvalidConfig(
1085 "Yolo Segmentation Detection has zero classes".to_string(),
1086 ));
1087 }
1088
1089 Ok(())
1090 }
1091
1092 fn verify_yolo_seg_det_2way(
1093 detection: &configs::Detection,
1094 mask_coeff: &configs::MaskCoefficients,
1095 protos: &configs::Protos,
1096 ) -> Result<(), DecoderError> {
1097 if detection.shape.len() != 3 {
1098 return Err(DecoderError::InvalidConfig(format!(
1099 "Invalid Yolo 2-Way Detection shape {:?}",
1100 detection.shape
1101 )));
1102 }
1103 if mask_coeff.shape.len() != 3 {
1104 return Err(DecoderError::InvalidConfig(format!(
1105 "Invalid Yolo 2-Way Mask Coefficients shape {:?}",
1106 mask_coeff.shape
1107 )));
1108 }
1109 if protos.shape.len() != 4 {
1110 return Err(DecoderError::InvalidConfig(format!(
1111 "Invalid Yolo 2-Way Protos shape {:?}",
1112 protos.shape
1113 )));
1114 }
1115
1116 Self::verify_dshapes(
1117 &detection.dshape,
1118 &detection.shape,
1119 "Detection",
1120 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1121 )?;
1122 Self::verify_dshapes(
1123 &mask_coeff.dshape,
1124 &mask_coeff.shape,
1125 "Mask Coefficients",
1126 &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1127 )?;
1128 Self::verify_dshapes(
1129 &protos.dshape,
1130 &protos.shape,
1131 "Protos",
1132 &[
1133 DimName::Batch,
1134 DimName::Height,
1135 DimName::Width,
1136 DimName::NumProtos,
1137 ],
1138 )?;
1139
1140 // Validate num_boxes match between detection and mask_coeff
1141 let det_num = Self::get_box_count(&detection.dshape).unwrap_or(detection.shape[2]);
1142 let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1143 if det_num != mask_num {
1144 return Err(DecoderError::InvalidConfig(format!(
1145 "Yolo 2-Way Detection num_boxes {} incompatible with Mask Coefficients num_boxes {}",
1146 det_num, mask_num
1147 )));
1148 }
1149
1150 // Validate mask_coeff channels match protos channels
1151 let mask_channels = if !mask_coeff.dshape.is_empty() {
1152 Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1153 DecoderError::InvalidConfig(
1154 "Could not find num_protos in mask_coeff config".to_string(),
1155 )
1156 })?
1157 } else {
1158 mask_coeff.shape[1]
1159 };
1160 let proto_channels = if !protos.dshape.is_empty() {
1161 Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1162 DecoderError::InvalidConfig(
1163 "Could not find num_protos in protos config".to_string(),
1164 )
1165 })?
1166 } else {
1167 protos.shape[1].min(protos.shape[3])
1168 };
1169 if mask_channels != proto_channels {
1170 return Err(DecoderError::InvalidConfig(format!(
1171 "Yolo 2-Way Protos channels {} incompatible with Mask Coefficients channels {}",
1172 proto_channels, mask_channels
1173 )));
1174 }
1175
1176 // Validate detection has classes (nc+4 features, no mask_coefs embedded)
1177 if !detection.dshape.is_empty() {
1178 Self::get_class_count(&detection.dshape, None, None)?;
1179 } else {
1180 Self::get_class_count_no_dshape(detection.into(), None)?;
1181 }
1182
1183 Ok(())
1184 }
1185
1186 fn verify_yolo_seg_det_26(
1187 detection: &configs::Detection,
1188 protos: &configs::Protos,
1189 ) -> Result<(), DecoderError> {
1190 if detection.shape.len() != 3 {
1191 return Err(DecoderError::InvalidConfig(format!(
1192 "Invalid Yolo Detection shape {:?}",
1193 detection.shape
1194 )));
1195 }
1196 if protos.shape.len() != 4 {
1197 return Err(DecoderError::InvalidConfig(format!(
1198 "Invalid Yolo Protos shape {:?}",
1199 protos.shape
1200 )));
1201 }
1202
1203 Self::verify_dshapes(
1204 &detection.dshape,
1205 &detection.shape,
1206 "Detection",
1207 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1208 )?;
1209 Self::verify_dshapes(
1210 &protos.dshape,
1211 &protos.shape,
1212 "Protos",
1213 &[
1214 DimName::Batch,
1215 DimName::Height,
1216 DimName::Width,
1217 DimName::NumProtos,
1218 ],
1219 )?;
1220
1221 let protos_count = Self::get_protos_count(&protos.dshape)
1222 .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1223 log::debug!("Protos count: {}", protos_count);
1224 log::debug!("Detection dshape: {:?}", detection.dshape);
1225
1226 if !detection.shape.contains(&(6 + protos_count)) {
1227 return Err(DecoderError::InvalidConfig(format!(
1228 "Yolo26 Segmentation Detection must have num_features be 6 + num_protos = {}",
1229 6 + protos_count
1230 )));
1231 }
1232
1233 Ok(())
1234 }
1235
1236 fn verify_yolo_split_det(
1237 boxes: &configs::Boxes,
1238 scores: &configs::Scores,
1239 ) -> Result<(), DecoderError> {
1240 if boxes.shape.len() != 3 {
1241 return Err(DecoderError::InvalidConfig(format!(
1242 "Invalid Yolo Split Boxes shape {:?}",
1243 boxes.shape
1244 )));
1245 }
1246 if scores.shape.len() != 3 {
1247 return Err(DecoderError::InvalidConfig(format!(
1248 "Invalid Yolo Split Scores shape {:?}",
1249 scores.shape
1250 )));
1251 }
1252
1253 Self::verify_dshapes(
1254 &boxes.dshape,
1255 &boxes.shape,
1256 "Boxes",
1257 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1258 )?;
1259 Self::verify_dshapes(
1260 &scores.dshape,
1261 &scores.shape,
1262 "Scores",
1263 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1264 )?;
1265
1266 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1267 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1268
1269 if boxes_num != scores_num {
1270 return Err(DecoderError::InvalidConfig(format!(
1271 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1272 boxes_num, scores_num
1273 )));
1274 }
1275
1276 Ok(())
1277 }
1278
1279 fn verify_yolo_split_segdet(
1280 boxes: &configs::Boxes,
1281 scores: &configs::Scores,
1282 mask_coeff: &configs::MaskCoefficients,
1283 protos: &configs::Protos,
1284 ) -> Result<(), DecoderError> {
1285 if boxes.shape.len() != 3 {
1286 return Err(DecoderError::InvalidConfig(format!(
1287 "Invalid Yolo Split Boxes shape {:?}",
1288 boxes.shape
1289 )));
1290 }
1291 if scores.shape.len() != 3 {
1292 return Err(DecoderError::InvalidConfig(format!(
1293 "Invalid Yolo Split Scores shape {:?}",
1294 scores.shape
1295 )));
1296 }
1297
1298 if mask_coeff.shape.len() != 3 {
1299 return Err(DecoderError::InvalidConfig(format!(
1300 "Invalid Yolo Split Mask Coefficients shape {:?}",
1301 mask_coeff.shape
1302 )));
1303 }
1304
1305 if protos.shape.len() != 4 {
1306 return Err(DecoderError::InvalidConfig(format!(
1307 "Invalid Yolo Protos shape {:?}",
1308 mask_coeff.shape
1309 )));
1310 }
1311
1312 Self::verify_dshapes(
1313 &boxes.dshape,
1314 &boxes.shape,
1315 "Boxes",
1316 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1317 )?;
1318 Self::verify_dshapes(
1319 &scores.dshape,
1320 &scores.shape,
1321 "Scores",
1322 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1323 )?;
1324 Self::verify_dshapes(
1325 &mask_coeff.dshape,
1326 &mask_coeff.shape,
1327 "Mask Coefficients",
1328 &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1329 )?;
1330 Self::verify_dshapes(
1331 &protos.dshape,
1332 &protos.shape,
1333 "Protos",
1334 &[
1335 DimName::Batch,
1336 DimName::Height,
1337 DimName::Width,
1338 DimName::NumProtos,
1339 ],
1340 )?;
1341
1342 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1343 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1344 let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1345
1346 let mask_channels = if !mask_coeff.dshape.is_empty() {
1347 Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1348 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1349 })?
1350 } else {
1351 mask_coeff.shape[1]
1352 };
1353 let proto_channels = if !protos.dshape.is_empty() {
1354 Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1355 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1356 })?
1357 } else {
1358 protos.shape[1].min(protos.shape[3])
1359 };
1360
1361 if boxes_num != scores_num {
1362 return Err(DecoderError::InvalidConfig(format!(
1363 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1364 boxes_num, scores_num
1365 )));
1366 }
1367
1368 if boxes_num != mask_num {
1369 return Err(DecoderError::InvalidConfig(format!(
1370 "Yolo Split Detection Boxes num {} incompatible with Mask Coefficients num {}",
1371 boxes_num, mask_num
1372 )));
1373 }
1374
1375 if proto_channels != mask_channels {
1376 return Err(DecoderError::InvalidConfig(format!(
1377 "Yolo Protos channels {} incompatible with Mask Coefficients channels {}",
1378 proto_channels, mask_channels
1379 )));
1380 }
1381
1382 Ok(())
1383 }
1384
1385 fn verify_yolo_split_end_to_end_det(
1386 boxes: &configs::Boxes,
1387 scores: &configs::Scores,
1388 classes: &configs::Classes,
1389 ) -> Result<(), DecoderError> {
1390 if boxes.shape.len() != 3 || !boxes.shape.contains(&4) {
1391 return Err(DecoderError::InvalidConfig(format!(
1392 "Split end-to-end boxes must be [batch, N, 4], got {:?}",
1393 boxes.shape
1394 )));
1395 }
1396 if scores.shape.len() != 3 || !scores.shape.contains(&1) {
1397 return Err(DecoderError::InvalidConfig(format!(
1398 "Split end-to-end scores must be [batch, N, 1], got {:?}",
1399 scores.shape
1400 )));
1401 }
1402 if classes.shape.len() != 3 || !classes.shape.contains(&1) {
1403 return Err(DecoderError::InvalidConfig(format!(
1404 "Split end-to-end classes must be [batch, N, 1], got {:?}",
1405 classes.shape
1406 )));
1407 }
1408 Ok(())
1409 }
1410
1411 fn verify_yolo_split_end_to_end_segdet(
1412 boxes: &configs::Boxes,
1413 scores: &configs::Scores,
1414 classes: &configs::Classes,
1415 mask_coeff: &configs::MaskCoefficients,
1416 protos: &configs::Protos,
1417 ) -> Result<(), DecoderError> {
1418 Self::verify_yolo_split_end_to_end_det(boxes, scores, classes)?;
1419 if mask_coeff.shape.len() != 3 {
1420 return Err(DecoderError::InvalidConfig(format!(
1421 "Invalid split end-to-end mask coefficients shape {:?}",
1422 mask_coeff.shape
1423 )));
1424 }
1425 if protos.shape.len() != 4 {
1426 return Err(DecoderError::InvalidConfig(format!(
1427 "Invalid protos shape {:?}",
1428 protos.shape
1429 )));
1430 }
1431 Ok(())
1432 }
1433
1434 fn get_model_type_modelpack(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1435 let mut split_decoders = Vec::new();
1436 let mut segment_ = None;
1437 let mut scores_ = None;
1438 let mut boxes_ = None;
1439 for c in configs.outputs {
1440 match c {
1441 ConfigOutput::Detection(detection) => split_decoders.push(detection),
1442 ConfigOutput::Segmentation(segmentation) => segment_ = Some(segmentation),
1443 ConfigOutput::Mask(_) => {}
1444 ConfigOutput::Protos(_) => {
1445 return Err(DecoderError::InvalidConfig(
1446 "ModelPack should not have protos".to_string(),
1447 ));
1448 }
1449 ConfigOutput::Scores(scores) => scores_ = Some(scores),
1450 ConfigOutput::Boxes(boxes) => boxes_ = Some(boxes),
1451 ConfigOutput::MaskCoefficients(_) => {
1452 return Err(DecoderError::InvalidConfig(
1453 "ModelPack should not have mask coefficients".to_string(),
1454 ));
1455 }
1456 ConfigOutput::Classes(_) => {
1457 return Err(DecoderError::InvalidConfig(
1458 "ModelPack should not have classes output".to_string(),
1459 ));
1460 }
1461 }
1462 }
1463
1464 if let Some(segmentation) = segment_ {
1465 if !split_decoders.is_empty() {
1466 let classes = Self::verify_modelpack_split_det(&split_decoders)?;
1467 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1468 Ok(ModelType::ModelPackSegDetSplit {
1469 detection: split_decoders,
1470 segmentation,
1471 })
1472 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1473 let classes = Self::verify_modelpack_det(&boxes, &scores)?;
1474 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1475 Ok(ModelType::ModelPackSegDet {
1476 boxes,
1477 scores,
1478 segmentation,
1479 })
1480 } else {
1481 Self::verify_modelpack_seg(&segmentation, None)?;
1482 Ok(ModelType::ModelPackSeg { segmentation })
1483 }
1484 } else if !split_decoders.is_empty() {
1485 Self::verify_modelpack_split_det(&split_decoders)?;
1486 Ok(ModelType::ModelPackDetSplit {
1487 detection: split_decoders,
1488 })
1489 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1490 Self::verify_modelpack_det(&boxes, &scores)?;
1491 Ok(ModelType::ModelPackDet { boxes, scores })
1492 } else {
1493 Err(DecoderError::InvalidConfig(
1494 "Invalid ModelPack model outputs".to_string(),
1495 ))
1496 }
1497 }
1498
1499 fn verify_modelpack_det(
1500 boxes: &configs::Boxes,
1501 scores: &configs::Scores,
1502 ) -> Result<usize, DecoderError> {
1503 if boxes.shape.len() != 4 {
1504 return Err(DecoderError::InvalidConfig(format!(
1505 "Invalid ModelPack Boxes shape {:?}",
1506 boxes.shape
1507 )));
1508 }
1509 if scores.shape.len() != 3 {
1510 return Err(DecoderError::InvalidConfig(format!(
1511 "Invalid ModelPack Scores shape {:?}",
1512 scores.shape
1513 )));
1514 }
1515
1516 Self::verify_dshapes(
1517 &boxes.dshape,
1518 &boxes.shape,
1519 "Boxes",
1520 &[
1521 DimName::Batch,
1522 DimName::NumBoxes,
1523 DimName::Padding,
1524 DimName::BoxCoords,
1525 ],
1526 )?;
1527 Self::verify_dshapes(
1528 &scores.dshape,
1529 &scores.shape,
1530 "Scores",
1531 &[DimName::Batch, DimName::NumBoxes, DimName::NumClasses],
1532 )?;
1533
1534 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[1]);
1535 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[1]);
1536
1537 if boxes_num != scores_num {
1538 return Err(DecoderError::InvalidConfig(format!(
1539 "ModelPack Detection Boxes num {} incompatible with Scores num {}",
1540 boxes_num, scores_num
1541 )));
1542 }
1543
1544 let num_classes = if !scores.dshape.is_empty() {
1545 Self::get_class_count(&scores.dshape, None, None)?
1546 } else {
1547 Self::get_class_count_no_dshape(scores.into(), None)?
1548 };
1549
1550 Ok(num_classes)
1551 }
1552
1553 fn verify_modelpack_split_det(boxes: &[configs::Detection]) -> Result<usize, DecoderError> {
1554 let mut num_classes = None;
1555 for b in boxes {
1556 let Some(num_anchors) = b.anchors.as_ref().map(|a| a.len()) else {
1557 return Err(DecoderError::InvalidConfig(
1558 "ModelPack Split Detection missing anchors".to_string(),
1559 ));
1560 };
1561
1562 if num_anchors == 0 {
1563 return Err(DecoderError::InvalidConfig(
1564 "ModelPack Split Detection has zero anchors".to_string(),
1565 ));
1566 }
1567
1568 if b.shape.len() != 4 {
1569 return Err(DecoderError::InvalidConfig(format!(
1570 "Invalid ModelPack Split Detection shape {:?}",
1571 b.shape
1572 )));
1573 }
1574
1575 Self::verify_dshapes(
1576 &b.dshape,
1577 &b.shape,
1578 "Split Detection",
1579 &[
1580 DimName::Batch,
1581 DimName::Height,
1582 DimName::Width,
1583 DimName::NumAnchorsXFeatures,
1584 ],
1585 )?;
1586 let classes = if !b.dshape.is_empty() {
1587 Self::get_class_count(&b.dshape, None, Some(num_anchors))?
1588 } else {
1589 Self::get_class_count_no_dshape(b.into(), None)?
1590 };
1591
1592 match num_classes {
1593 Some(n) => {
1594 if n != classes {
1595 return Err(DecoderError::InvalidConfig(format!(
1596 "ModelPack Split Detection inconsistent number of classes: previous {}, current {}",
1597 n, classes
1598 )));
1599 }
1600 }
1601 None => {
1602 num_classes = Some(classes);
1603 }
1604 }
1605 }
1606
1607 Ok(num_classes.unwrap_or(0))
1608 }
1609
1610 fn verify_modelpack_seg(
1611 segmentation: &configs::Segmentation,
1612 classes: Option<usize>,
1613 ) -> Result<(), DecoderError> {
1614 if segmentation.shape.len() != 4 {
1615 return Err(DecoderError::InvalidConfig(format!(
1616 "Invalid ModelPack Segmentation shape {:?}",
1617 segmentation.shape
1618 )));
1619 }
1620 Self::verify_dshapes(
1621 &segmentation.dshape,
1622 &segmentation.shape,
1623 "Segmentation",
1624 &[
1625 DimName::Batch,
1626 DimName::Height,
1627 DimName::Width,
1628 DimName::NumClasses,
1629 ],
1630 )?;
1631
1632 if let Some(classes) = classes {
1633 let seg_classes = if !segmentation.dshape.is_empty() {
1634 Self::get_class_count(&segmentation.dshape, None, None)?
1635 } else {
1636 Self::get_class_count_no_dshape(segmentation.into(), None)?
1637 };
1638
1639 if seg_classes != classes + 1 {
1640 return Err(DecoderError::InvalidConfig(format!(
1641 "ModelPack Segmentation channels {} incompatible with number of classes {}",
1642 seg_classes, classes
1643 )));
1644 }
1645 }
1646 Ok(())
1647 }
1648
1649 // verifies that dshapes match the given shape
1650 fn verify_dshapes(
1651 dshape: &[(DimName, usize)],
1652 shape: &[usize],
1653 name: &str,
1654 dims: &[DimName],
1655 ) -> Result<(), DecoderError> {
1656 for s in shape {
1657 if *s == 0 {
1658 return Err(DecoderError::InvalidConfig(format!(
1659 "{} shape has zero dimension",
1660 name
1661 )));
1662 }
1663 }
1664
1665 if shape.len() != dims.len() {
1666 return Err(DecoderError::InvalidConfig(format!(
1667 "{} shape length {} does not match expected dims length {}",
1668 name,
1669 shape.len(),
1670 dims.len()
1671 )));
1672 }
1673
1674 if dshape.is_empty() {
1675 return Ok(());
1676 }
1677 // check the dshape lengths match the shape lengths
1678 if dshape.len() != shape.len() {
1679 return Err(DecoderError::InvalidConfig(format!(
1680 "{} dshape length does not match shape length",
1681 name
1682 )));
1683 }
1684
1685 // check the dshape values match the shape values
1686 for ((dim_name, dim_size), shape_size) in dshape.iter().zip(shape) {
1687 if dim_size != shape_size {
1688 return Err(DecoderError::InvalidConfig(format!(
1689 "{} dshape dimension {} size {} does not match shape size {}",
1690 name, dim_name, dim_size, shape_size
1691 )));
1692 }
1693 if *dim_name == DimName::Padding && *dim_size != 1 {
1694 return Err(DecoderError::InvalidConfig(
1695 "Padding dimension size must be 1".to_string(),
1696 ));
1697 }
1698
1699 if *dim_name == DimName::BoxCoords && *dim_size != 4 {
1700 return Err(DecoderError::InvalidConfig(
1701 "BoxCoords dimension size must be 4".to_string(),
1702 ));
1703 }
1704 }
1705
1706 let dims_present = HashSet::<DimName>::from_iter(dshape.iter().map(|(name, _)| *name));
1707 for dim in dims {
1708 if !dims_present.contains(dim) {
1709 return Err(DecoderError::InvalidConfig(format!(
1710 "{} dshape missing required dimension {:?}",
1711 name, dim
1712 )));
1713 }
1714 }
1715
1716 Ok(())
1717 }
1718
1719 fn get_box_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1720 for (dim_name, dim_size) in dshape {
1721 if *dim_name == DimName::NumBoxes {
1722 return Some(*dim_size);
1723 }
1724 }
1725 None
1726 }
1727
1728 fn get_class_count_no_dshape(
1729 config: ConfigOutputRef,
1730 protos: Option<usize>,
1731 ) -> Result<usize, DecoderError> {
1732 match config {
1733 ConfigOutputRef::Detection(detection) => match detection.decoder {
1734 DecoderType::Ultralytics => {
1735 if detection.shape[1] <= 4 + protos.unwrap_or(0) {
1736 return Err(DecoderError::InvalidConfig(format!(
1737 "Invalid shape: Yolo num_features {} must be greater than {}",
1738 detection.shape[1],
1739 4 + protos.unwrap_or(0),
1740 )));
1741 }
1742 Ok(detection.shape[1] - 4 - protos.unwrap_or(0))
1743 }
1744 DecoderType::ModelPack => {
1745 let Some(num_anchors) = detection.anchors.as_ref().map(|a| a.len()) else {
1746 return Err(DecoderError::Internal(
1747 "ModelPack Detection missing anchors".to_string(),
1748 ));
1749 };
1750 let anchors_x_features = detection.shape[3];
1751 if anchors_x_features <= num_anchors * 5 {
1752 return Err(DecoderError::InvalidConfig(format!(
1753 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
1754 anchors_x_features,
1755 num_anchors * 5,
1756 )));
1757 }
1758
1759 if !anchors_x_features.is_multiple_of(num_anchors) {
1760 return Err(DecoderError::InvalidConfig(format!(
1761 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
1762 anchors_x_features, num_anchors
1763 )));
1764 }
1765 Ok(anchors_x_features / num_anchors - 5)
1766 }
1767 },
1768
1769 ConfigOutputRef::Scores(scores) => match scores.decoder {
1770 DecoderType::Ultralytics => Ok(scores.shape[1]),
1771 DecoderType::ModelPack => Ok(scores.shape[2]),
1772 },
1773 ConfigOutputRef::Segmentation(seg) => Ok(seg.shape[3]),
1774 _ => Err(DecoderError::Internal(
1775 "Attempted to get class count from unsupported config output".to_owned(),
1776 )),
1777 }
1778 }
1779
1780 // get the class count from dshape or calculate from num_features
1781 fn get_class_count(
1782 dshape: &[(DimName, usize)],
1783 protos: Option<usize>,
1784 anchors: Option<usize>,
1785 ) -> Result<usize, DecoderError> {
1786 if dshape.is_empty() {
1787 return Ok(0);
1788 }
1789 // if it has num_classes in dshape, return it
1790 for (dim_name, dim_size) in dshape {
1791 if *dim_name == DimName::NumClasses {
1792 return Ok(*dim_size);
1793 }
1794 }
1795
1796 // number of classes can be calculated from num_features - 4 for yolo. If the
1797 // model has protos, we also subtract the number of protos.
1798 for (dim_name, dim_size) in dshape {
1799 if *dim_name == DimName::NumFeatures {
1800 let protos = protos.unwrap_or(0);
1801 if protos + 4 >= *dim_size {
1802 return Err(DecoderError::InvalidConfig(format!(
1803 "Invalid shape: Yolo num_features {} must be greater than {}",
1804 *dim_size,
1805 protos + 4,
1806 )));
1807 }
1808 return Ok(*dim_size - 4 - protos);
1809 }
1810 }
1811
1812 // number of classes can be calculated from number of anchors for modelpack
1813 // split detection
1814 if let Some(num_anchors) = anchors {
1815 for (dim_name, dim_size) in dshape {
1816 if *dim_name == DimName::NumAnchorsXFeatures {
1817 let anchors_x_features = *dim_size;
1818 if anchors_x_features <= num_anchors * 5 {
1819 return Err(DecoderError::InvalidConfig(format!(
1820 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
1821 anchors_x_features,
1822 num_anchors * 5,
1823 )));
1824 }
1825
1826 if !anchors_x_features.is_multiple_of(num_anchors) {
1827 return Err(DecoderError::InvalidConfig(format!(
1828 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
1829 anchors_x_features, num_anchors
1830 )));
1831 }
1832 return Ok((anchors_x_features / num_anchors) - 5);
1833 }
1834 }
1835 }
1836 Err(DecoderError::InvalidConfig(
1837 "Cannot determine number of classes from dshape".to_owned(),
1838 ))
1839 }
1840
1841 fn get_protos_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1842 for (dim_name, dim_size) in dshape {
1843 if *dim_name == DimName::NumProtos {
1844 return Some(*dim_size);
1845 }
1846 }
1847 None
1848 }
1849}