edgefirst_decoder/decoder/builder.rs
1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashSet;
5
6use super::config::ConfigOutputRef;
7use super::configs::{self, DecoderType, DecoderVersion, DimName, ModelType};
8use super::merge::DecodeProgram;
9use super::{ConfigOutput, ConfigOutputs, Decoder};
10use crate::schema::SchemaV2;
11use crate::DecoderError;
12
13#[derive(Debug, Clone, PartialEq)]
14pub struct DecoderBuilder {
15 config_src: Option<ConfigSource>,
16 iou_threshold: f32,
17 score_threshold: f32,
18 /// NMS mode: Some(mode) applies NMS, None bypasses NMS (for end-to-end
19 /// models)
20 nms: Option<configs::Nms>,
21}
22
23#[derive(Debug, Clone, PartialEq)]
24enum ConfigSource {
25 Yaml(String),
26 Json(String),
27 Config(ConfigOutputs),
28 /// Schema v2 metadata. During build the schema is either converted
29 /// to a legacy [`ConfigOutputs`] (flat case) or used to construct a
30 /// [`DecodeProgram`] that performs per-child dequant + merge at
31 /// decode time.
32 Schema(SchemaV2),
33}
34
35impl Default for DecoderBuilder {
36 /// Creates a default DecoderBuilder with no configuration and 0.5 score
37 /// threshold and 0.5 IoU threshold.
38 ///
39 /// A valid configuration must be provided before building the Decoder.
40 ///
41 /// # Examples
42 /// ```rust
43 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
44 /// # fn main() -> DecoderResult<()> {
45 /// # let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
46 /// let decoder = DecoderBuilder::default()
47 /// .with_config_yaml_str(config_yaml)
48 /// .build()?;
49 /// assert_eq!(decoder.score_threshold, 0.5);
50 /// assert_eq!(decoder.iou_threshold, 0.5);
51 ///
52 /// # Ok(())
53 /// # }
54 /// ```
55 fn default() -> Self {
56 Self {
57 config_src: None,
58 iou_threshold: 0.5,
59 score_threshold: 0.5,
60 nms: Some(configs::Nms::ClassAgnostic),
61 }
62 }
63}
64
65impl DecoderBuilder {
66 /// Creates a default DecoderBuilder with no configuration and 0.5 score
67 /// threshold and 0.5 IoU threshold.
68 ///
69 /// A valid configuration must be provided before building the Decoder.
70 ///
71 /// # Examples
72 /// ```rust
73 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
74 /// # fn main() -> DecoderResult<()> {
75 /// # let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
76 /// let decoder = DecoderBuilder::new()
77 /// .with_config_yaml_str(config_yaml)
78 /// .build()?;
79 /// assert_eq!(decoder.score_threshold, 0.5);
80 /// assert_eq!(decoder.iou_threshold, 0.5);
81 ///
82 /// # Ok(())
83 /// # }
84 /// ```
85 pub fn new() -> Self {
86 Self::default()
87 }
88
89 /// Loads a model configuration in YAML format. Does not check if the string
90 /// is a correct configuration file. Use `DecoderBuilder.build()` to
91 /// deserialize the YAML and parse the model configuration.
92 ///
93 /// # Examples
94 /// ```rust
95 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
96 /// # fn main() -> DecoderResult<()> {
97 /// let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
98 /// let decoder = DecoderBuilder::new()
99 /// .with_config_yaml_str(config_yaml)
100 /// .build()?;
101 ///
102 /// # Ok(())
103 /// # }
104 /// ```
105 pub fn with_config_yaml_str(mut self, yaml_str: String) -> Self {
106 self.config_src.replace(ConfigSource::Yaml(yaml_str));
107 self
108 }
109
110 /// Loads a model configuration in JSON format. Does not check if the string
111 /// is a correct configuration file. Use `DecoderBuilder.build()` to
112 /// deserialize the JSON and parse the model configuration.
113 ///
114 /// # Examples
115 /// ```rust
116 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
117 /// # fn main() -> DecoderResult<()> {
118 /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
119 /// let decoder = DecoderBuilder::new()
120 /// .with_config_json_str(config_json)
121 /// .build()?;
122 ///
123 /// # Ok(())
124 /// # }
125 /// ```
126 pub fn with_config_json_str(mut self, json_str: String) -> Self {
127 self.config_src.replace(ConfigSource::Json(json_str));
128 self
129 }
130
131 /// Loads a model configuration. Does not check if the configuration is
132 /// correct. Intended to be used when the user needs control over the
133 /// deserialize of the configuration information. Use
134 /// `DecoderBuilder.build()` to parse the model configuration.
135 ///
136 /// # Examples
137 /// ```rust
138 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
139 /// # fn main() -> DecoderResult<()> {
140 /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json"));
141 /// let config = serde_json::from_str(config_json)?;
142 /// let decoder = DecoderBuilder::new().with_config(config).build()?;
143 ///
144 /// # Ok(())
145 /// # }
146 /// ```
147 pub fn with_config(mut self, config: ConfigOutputs) -> Self {
148 self.config_src.replace(ConfigSource::Config(config));
149 self
150 }
151
152 /// Configure the decoder from a schema v2 metadata document.
153 ///
154 /// Accepts a [`SchemaV2`] as produced by [`SchemaV2::parse_json`],
155 /// [`SchemaV2::parse_yaml`], [`SchemaV2::parse_file`], or
156 /// constructed programmatically. The builder validates the schema,
157 /// compiles a [`DecodeProgram`] for any split logical outputs
158 /// (per-scale or channel sub-splits), and downconverts the
159 /// logical-level semantics to the legacy [`ConfigOutputs`]
160 /// representation consumed by the existing decoder dispatch.
161 ///
162 /// # Examples
163 ///
164 /// ```rust,no_run
165 /// use edgefirst_decoder::{DecoderBuilder, DecoderResult};
166 /// use edgefirst_decoder::schema::SchemaV2;
167 ///
168 /// # fn main() -> DecoderResult<()> {
169 /// let schema = SchemaV2::parse_file("model/edgefirst.json")?;
170 /// let decoder = DecoderBuilder::new()
171 /// .with_schema(schema)
172 /// .with_score_threshold(0.25)
173 /// .build()?;
174 /// # Ok(())
175 /// # }
176 /// ```
177 pub fn with_schema(mut self, schema: SchemaV2) -> Self {
178 self.config_src.replace(ConfigSource::Schema(schema));
179 self
180 }
181
182 /// Loads a YOLO detection model configuration. Use
183 /// `DecoderBuilder.build()` to parse the model configuration.
184 ///
185 /// # Examples
186 /// ```rust
187 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
188 /// # fn main() -> DecoderResult<()> {
189 /// let decoder = DecoderBuilder::new()
190 /// .with_config_yolo_det(
191 /// configs::Detection {
192 /// anchors: None,
193 /// decoder: configs::DecoderType::Ultralytics,
194 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
195 /// shape: vec![1, 84, 8400],
196 /// dshape: Vec::new(),
197 /// normalized: Some(true),
198 /// },
199 /// None,
200 /// )
201 /// .build()?;
202 ///
203 /// # Ok(())
204 /// # }
205 /// ```
206 pub fn with_config_yolo_det(
207 mut self,
208 boxes: configs::Detection,
209 version: Option<DecoderVersion>,
210 ) -> Self {
211 let config = ConfigOutputs {
212 outputs: vec![ConfigOutput::Detection(boxes)],
213 decoder_version: version,
214 ..Default::default()
215 };
216 self.config_src.replace(ConfigSource::Config(config));
217 self
218 }
219
220 /// Loads a YOLO split detection model configuration. Use
221 /// `DecoderBuilder.build()` to parse the model configuration.
222 ///
223 /// # Examples
224 /// ```rust
225 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
226 /// # fn main() -> DecoderResult<()> {
227 /// let boxes_config = configs::Boxes {
228 /// decoder: configs::DecoderType::Ultralytics,
229 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
230 /// shape: vec![1, 4, 8400],
231 /// dshape: Vec::new(),
232 /// normalized: Some(true),
233 /// };
234 /// let scores_config = configs::Scores {
235 /// decoder: configs::DecoderType::Ultralytics,
236 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
237 /// shape: vec![1, 80, 8400],
238 /// dshape: Vec::new(),
239 /// };
240 /// let decoder = DecoderBuilder::new()
241 /// .with_config_yolo_split_det(boxes_config, scores_config)
242 /// .build()?;
243 /// # Ok(())
244 /// # }
245 /// ```
246 pub fn with_config_yolo_split_det(
247 mut self,
248 boxes: configs::Boxes,
249 scores: configs::Scores,
250 ) -> Self {
251 let config = ConfigOutputs {
252 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
253 ..Default::default()
254 };
255 self.config_src.replace(ConfigSource::Config(config));
256 self
257 }
258
259 /// Loads a YOLO segmentation model configuration. Use
260 /// `DecoderBuilder.build()` to parse the model configuration.
261 ///
262 /// # Examples
263 /// ```rust
264 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
265 /// # fn main() -> DecoderResult<()> {
266 /// let seg_config = configs::Detection {
267 /// decoder: configs::DecoderType::Ultralytics,
268 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
269 /// shape: vec![1, 116, 8400],
270 /// anchors: None,
271 /// dshape: Vec::new(),
272 /// normalized: Some(true),
273 /// };
274 /// let protos_config = configs::Protos {
275 /// decoder: configs::DecoderType::Ultralytics,
276 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
277 /// shape: vec![1, 160, 160, 32],
278 /// dshape: Vec::new(),
279 /// };
280 /// let decoder = DecoderBuilder::new()
281 /// .with_config_yolo_segdet(
282 /// seg_config,
283 /// protos_config,
284 /// Some(configs::DecoderVersion::Yolov8),
285 /// )
286 /// .build()?;
287 /// # Ok(())
288 /// # }
289 /// ```
290 pub fn with_config_yolo_segdet(
291 mut self,
292 boxes: configs::Detection,
293 protos: configs::Protos,
294 version: Option<DecoderVersion>,
295 ) -> Self {
296 let config = ConfigOutputs {
297 outputs: vec![ConfigOutput::Detection(boxes), ConfigOutput::Protos(protos)],
298 decoder_version: version,
299 ..Default::default()
300 };
301 self.config_src.replace(ConfigSource::Config(config));
302 self
303 }
304
305 /// Loads a YOLO split segmentation model configuration. Use
306 /// `DecoderBuilder.build()` to parse the model configuration.
307 ///
308 /// # Examples
309 /// ```rust
310 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
311 /// # fn main() -> DecoderResult<()> {
312 /// let boxes_config = configs::Boxes {
313 /// decoder: configs::DecoderType::Ultralytics,
314 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
315 /// shape: vec![1, 4, 8400],
316 /// dshape: Vec::new(),
317 /// normalized: Some(true),
318 /// };
319 /// let scores_config = configs::Scores {
320 /// decoder: configs::DecoderType::Ultralytics,
321 /// quantization: Some(configs::QuantTuple(0.012345, 14)),
322 /// shape: vec![1, 80, 8400],
323 /// dshape: Vec::new(),
324 /// };
325 /// let mask_config = configs::MaskCoefficients {
326 /// decoder: configs::DecoderType::Ultralytics,
327 /// quantization: Some(configs::QuantTuple(0.0064123, 125)),
328 /// shape: vec![1, 32, 8400],
329 /// dshape: Vec::new(),
330 /// };
331 /// let protos_config = configs::Protos {
332 /// decoder: configs::DecoderType::Ultralytics,
333 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
334 /// shape: vec![1, 160, 160, 32],
335 /// dshape: Vec::new(),
336 /// };
337 /// let decoder = DecoderBuilder::new()
338 /// .with_config_yolo_split_segdet(boxes_config, scores_config, mask_config, protos_config)
339 /// .build()?;
340 /// # Ok(())
341 /// # }
342 /// ```
343 pub fn with_config_yolo_split_segdet(
344 mut self,
345 boxes: configs::Boxes,
346 scores: configs::Scores,
347 mask_coefficients: configs::MaskCoefficients,
348 protos: configs::Protos,
349 ) -> Self {
350 let config = ConfigOutputs {
351 outputs: vec![
352 ConfigOutput::Boxes(boxes),
353 ConfigOutput::Scores(scores),
354 ConfigOutput::MaskCoefficients(mask_coefficients),
355 ConfigOutput::Protos(protos),
356 ],
357 ..Default::default()
358 };
359 self.config_src.replace(ConfigSource::Config(config));
360 self
361 }
362
363 /// Loads a ModelPack detection model configuration. Use
364 /// `DecoderBuilder.build()` to parse the model configuration.
365 ///
366 /// # Examples
367 /// ```rust
368 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
369 /// # fn main() -> DecoderResult<()> {
370 /// let boxes_config = configs::Boxes {
371 /// decoder: configs::DecoderType::ModelPack,
372 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
373 /// shape: vec![1, 8400, 1, 4],
374 /// dshape: Vec::new(),
375 /// normalized: Some(true),
376 /// };
377 /// let scores_config = configs::Scores {
378 /// decoder: configs::DecoderType::ModelPack,
379 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
380 /// shape: vec![1, 8400, 3],
381 /// dshape: Vec::new(),
382 /// };
383 /// let decoder = DecoderBuilder::new()
384 /// .with_config_modelpack_det(boxes_config, scores_config)
385 /// .build()?;
386 /// # Ok(())
387 /// # }
388 /// ```
389 pub fn with_config_modelpack_det(
390 mut self,
391 boxes: configs::Boxes,
392 scores: configs::Scores,
393 ) -> Self {
394 let config = ConfigOutputs {
395 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
396 ..Default::default()
397 };
398 self.config_src.replace(ConfigSource::Config(config));
399 self
400 }
401
402 /// Loads a ModelPack split detection model configuration. Use
403 /// `DecoderBuilder.build()` to parse the model configuration.
404 ///
405 /// # Examples
406 /// ```rust
407 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
408 /// # fn main() -> DecoderResult<()> {
409 /// let config0 = configs::Detection {
410 /// anchors: Some(vec![
411 /// [0.13750000298023224, 0.2074074000120163],
412 /// [0.2541666626930237, 0.21481481194496155],
413 /// [0.23125000298023224, 0.35185185074806213],
414 /// ]),
415 /// decoder: configs::DecoderType::ModelPack,
416 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
417 /// shape: vec![1, 17, 30, 18],
418 /// dshape: Vec::new(),
419 /// normalized: Some(true),
420 /// };
421 /// let config1 = configs::Detection {
422 /// anchors: Some(vec![
423 /// [0.36666667461395264, 0.31481480598449707],
424 /// [0.38749998807907104, 0.4740740656852722],
425 /// [0.5333333611488342, 0.644444465637207],
426 /// ]),
427 /// decoder: configs::DecoderType::ModelPack,
428 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
429 /// shape: vec![1, 9, 15, 18],
430 /// dshape: Vec::new(),
431 /// normalized: Some(true),
432 /// };
433 ///
434 /// let decoder = DecoderBuilder::new()
435 /// .with_config_modelpack_det_split(vec![config0, config1])
436 /// .build()?;
437 /// # Ok(())
438 /// # }
439 /// ```
440 pub fn with_config_modelpack_det_split(mut self, boxes: Vec<configs::Detection>) -> Self {
441 let outputs = boxes.into_iter().map(ConfigOutput::Detection).collect();
442 let config = ConfigOutputs {
443 outputs,
444 ..Default::default()
445 };
446 self.config_src.replace(ConfigSource::Config(config));
447 self
448 }
449
450 /// Loads a ModelPack segmentation detection model configuration. Use
451 /// `DecoderBuilder.build()` to parse the model configuration.
452 ///
453 /// # Examples
454 /// ```rust
455 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
456 /// # fn main() -> DecoderResult<()> {
457 /// let boxes_config = configs::Boxes {
458 /// decoder: configs::DecoderType::ModelPack,
459 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
460 /// shape: vec![1, 8400, 1, 4],
461 /// dshape: Vec::new(),
462 /// normalized: Some(true),
463 /// };
464 /// let scores_config = configs::Scores {
465 /// decoder: configs::DecoderType::ModelPack,
466 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
467 /// shape: vec![1, 8400, 2],
468 /// dshape: Vec::new(),
469 /// };
470 /// let seg_config = configs::Segmentation {
471 /// decoder: configs::DecoderType::ModelPack,
472 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
473 /// shape: vec![1, 640, 640, 3],
474 /// dshape: Vec::new(),
475 /// };
476 /// let decoder = DecoderBuilder::new()
477 /// .with_config_modelpack_segdet(boxes_config, scores_config, seg_config)
478 /// .build()?;
479 /// # Ok(())
480 /// # }
481 /// ```
482 pub fn with_config_modelpack_segdet(
483 mut self,
484 boxes: configs::Boxes,
485 scores: configs::Scores,
486 segmentation: configs::Segmentation,
487 ) -> Self {
488 let config = ConfigOutputs {
489 outputs: vec![
490 ConfigOutput::Boxes(boxes),
491 ConfigOutput::Scores(scores),
492 ConfigOutput::Segmentation(segmentation),
493 ],
494 ..Default::default()
495 };
496 self.config_src.replace(ConfigSource::Config(config));
497 self
498 }
499
500 /// Loads a ModelPack segmentation split detection model configuration. Use
501 /// `DecoderBuilder.build()` to parse the model configuration.
502 ///
503 /// # Examples
504 /// ```rust
505 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
506 /// # fn main() -> DecoderResult<()> {
507 /// let config0 = configs::Detection {
508 /// anchors: Some(vec![
509 /// [0.36666667461395264, 0.31481480598449707],
510 /// [0.38749998807907104, 0.4740740656852722],
511 /// [0.5333333611488342, 0.644444465637207],
512 /// ]),
513 /// decoder: configs::DecoderType::ModelPack,
514 /// quantization: Some(configs::QuantTuple(0.08547406643629074, 174)),
515 /// shape: vec![1, 9, 15, 18],
516 /// dshape: Vec::new(),
517 /// normalized: Some(true),
518 /// };
519 /// let config1 = configs::Detection {
520 /// anchors: Some(vec![
521 /// [0.13750000298023224, 0.2074074000120163],
522 /// [0.2541666626930237, 0.21481481194496155],
523 /// [0.23125000298023224, 0.35185185074806213],
524 /// ]),
525 /// decoder: configs::DecoderType::ModelPack,
526 /// quantization: Some(configs::QuantTuple(0.09929127991199493, 183)),
527 /// shape: vec![1, 17, 30, 18],
528 /// dshape: Vec::new(),
529 /// normalized: Some(true),
530 /// };
531 /// let seg_config = configs::Segmentation {
532 /// decoder: configs::DecoderType::ModelPack,
533 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
534 /// shape: vec![1, 640, 640, 2],
535 /// dshape: Vec::new(),
536 /// };
537 /// let decoder = DecoderBuilder::new()
538 /// .with_config_modelpack_segdet_split(vec![config0, config1], seg_config)
539 /// .build()?;
540 /// # Ok(())
541 /// # }
542 /// ```
543 pub fn with_config_modelpack_segdet_split(
544 mut self,
545 boxes: Vec<configs::Detection>,
546 segmentation: configs::Segmentation,
547 ) -> Self {
548 let mut outputs = boxes
549 .into_iter()
550 .map(ConfigOutput::Detection)
551 .collect::<Vec<_>>();
552 outputs.push(ConfigOutput::Segmentation(segmentation));
553 let config = ConfigOutputs {
554 outputs,
555 ..Default::default()
556 };
557 self.config_src.replace(ConfigSource::Config(config));
558 self
559 }
560
561 /// Loads a ModelPack segmentation model configuration. Use
562 /// `DecoderBuilder.build()` to parse the model configuration.
563 ///
564 /// # Examples
565 /// ```rust
566 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
567 /// # fn main() -> DecoderResult<()> {
568 /// let seg_config = configs::Segmentation {
569 /// decoder: configs::DecoderType::ModelPack,
570 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
571 /// shape: vec![1, 640, 640, 3],
572 /// dshape: Vec::new(),
573 /// };
574 /// let decoder = DecoderBuilder::new()
575 /// .with_config_modelpack_seg(seg_config)
576 /// .build()?;
577 /// # Ok(())
578 /// # }
579 /// ```
580 pub fn with_config_modelpack_seg(mut self, segmentation: configs::Segmentation) -> Self {
581 let config = ConfigOutputs {
582 outputs: vec![ConfigOutput::Segmentation(segmentation)],
583 ..Default::default()
584 };
585 self.config_src.replace(ConfigSource::Config(config));
586 self
587 }
588
589 /// Add an output to the decoder configuration.
590 ///
591 /// Incrementally builds the model configuration by adding outputs one at
592 /// a time. The decoder resolves the model type from the combination of
593 /// outputs during `build()`.
594 ///
595 /// If `dshape` is non-empty on the output, `shape` is automatically
596 /// derived from it (the size component of each named dimension). This
597 /// prevents conflicts between `shape` and `dshape`.
598 ///
599 /// This uses the programmatic config path. Calling this after
600 /// `with_config_json_str()` or `with_config_yaml_str()` replaces the
601 /// string-based config source.
602 ///
603 /// # Examples
604 /// ```rust
605 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
606 /// # fn main() -> DecoderResult<()> {
607 /// let decoder = DecoderBuilder::new()
608 /// .add_output(ConfigOutput::Scores(configs::Scores {
609 /// decoder: configs::DecoderType::Ultralytics,
610 /// dshape: vec![
611 /// (configs::DimName::Batch, 1),
612 /// (configs::DimName::NumClasses, 80),
613 /// (configs::DimName::NumBoxes, 8400),
614 /// ],
615 /// ..Default::default()
616 /// }))
617 /// .add_output(ConfigOutput::Boxes(configs::Boxes {
618 /// decoder: configs::DecoderType::Ultralytics,
619 /// dshape: vec![
620 /// (configs::DimName::Batch, 1),
621 /// (configs::DimName::BoxCoords, 4),
622 /// (configs::DimName::NumBoxes, 8400),
623 /// ],
624 /// ..Default::default()
625 /// }))
626 /// .build()?;
627 /// # Ok(())
628 /// # }
629 /// ```
630 pub fn add_output(mut self, output: ConfigOutput) -> Self {
631 if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
632 self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
633 }
634 if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
635 config.outputs.push(Self::normalize_output(output));
636 }
637 self
638 }
639
640 /// Sets the decoder version for Ultralytics models.
641 ///
642 /// This is used with `add_output()` to specify the YOLO architecture
643 /// version when it cannot be inferred from the output shapes alone.
644 ///
645 /// - `Yolov5`, `Yolov8`, `Yolo11`: Traditional models requiring external
646 /// NMS
647 /// - `Yolo26`: End-to-end models with NMS embedded in the model graph
648 ///
649 /// # Examples
650 /// ```rust
651 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
652 /// # fn main() -> DecoderResult<()> {
653 /// let decoder = DecoderBuilder::new()
654 /// .add_output(ConfigOutput::Detection(configs::Detection {
655 /// decoder: configs::DecoderType::Ultralytics,
656 /// dshape: vec![
657 /// (configs::DimName::Batch, 1),
658 /// (configs::DimName::NumBoxes, 100),
659 /// (configs::DimName::NumFeatures, 6),
660 /// ],
661 /// ..Default::default()
662 /// }))
663 /// .with_decoder_version(configs::DecoderVersion::Yolo26)
664 /// .build()?;
665 /// # Ok(())
666 /// # }
667 /// ```
668 pub fn with_decoder_version(mut self, version: configs::DecoderVersion) -> Self {
669 if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
670 self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
671 }
672 if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
673 config.decoder_version = Some(version);
674 }
675 self
676 }
677
678 /// Normalize an output: if dshape is non-empty, derive shape from it.
679 fn normalize_output(mut output: ConfigOutput) -> ConfigOutput {
680 fn normalize_shape(shape: &mut Vec<usize>, dshape: &[(configs::DimName, usize)]) {
681 if !dshape.is_empty() {
682 *shape = dshape.iter().map(|(_, size)| *size).collect();
683 }
684 }
685 match &mut output {
686 ConfigOutput::Detection(c) => normalize_shape(&mut c.shape, &c.dshape),
687 ConfigOutput::Boxes(c) => normalize_shape(&mut c.shape, &c.dshape),
688 ConfigOutput::Scores(c) => normalize_shape(&mut c.shape, &c.dshape),
689 ConfigOutput::Protos(c) => normalize_shape(&mut c.shape, &c.dshape),
690 ConfigOutput::Segmentation(c) => normalize_shape(&mut c.shape, &c.dshape),
691 ConfigOutput::MaskCoefficients(c) => normalize_shape(&mut c.shape, &c.dshape),
692 ConfigOutput::Mask(c) => normalize_shape(&mut c.shape, &c.dshape),
693 ConfigOutput::Classes(c) => normalize_shape(&mut c.shape, &c.dshape),
694 }
695 output
696 }
697
698 /// Sets the scores threshold of the decoder
699 ///
700 /// # Examples
701 /// ```rust
702 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
703 /// # fn main() -> DecoderResult<()> {
704 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
705 /// let decoder = DecoderBuilder::new()
706 /// .with_config_json_str(config_json)
707 /// .with_score_threshold(0.654)
708 /// .build()?;
709 /// assert_eq!(decoder.score_threshold, 0.654);
710 /// # Ok(())
711 /// # }
712 /// ```
713 pub fn with_score_threshold(mut self, score_threshold: f32) -> Self {
714 self.score_threshold = score_threshold;
715 self
716 }
717
718 /// Sets the IOU threshold of the decoder. Has no effect when NMS is set to
719 /// `None`
720 ///
721 /// # Examples
722 /// ```rust
723 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
724 /// # fn main() -> DecoderResult<()> {
725 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
726 /// let decoder = DecoderBuilder::new()
727 /// .with_config_json_str(config_json)
728 /// .with_iou_threshold(0.654)
729 /// .build()?;
730 /// assert_eq!(decoder.iou_threshold, 0.654);
731 /// # Ok(())
732 /// # }
733 /// ```
734 pub fn with_iou_threshold(mut self, iou_threshold: f32) -> Self {
735 self.iou_threshold = iou_threshold;
736 self
737 }
738
739 /// Sets the NMS mode for the decoder.
740 ///
741 /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS (default): suppress
742 /// overlapping boxes regardless of class label
743 /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
744 /// share the same class label AND overlap above the IoU threshold
745 /// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
746 ///
747 /// # Examples
748 /// ```rust
749 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::Nms};
750 /// # fn main() -> DecoderResult<()> {
751 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
752 /// let decoder = DecoderBuilder::new()
753 /// .with_config_json_str(config_json)
754 /// .with_nms(Some(Nms::ClassAware))
755 /// .build()?;
756 /// assert_eq!(decoder.nms, Some(Nms::ClassAware));
757 /// # Ok(())
758 /// # }
759 /// ```
760 pub fn with_nms(mut self, nms: Option<configs::Nms>) -> Self {
761 self.nms = nms;
762 self
763 }
764
765 /// Builds the decoder with the given settings. If the config is a JSON or
766 /// YAML string, this will deserialize the JSON or YAML and then parse the
767 /// configuration information.
768 ///
769 /// # Examples
770 /// ```rust
771 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
772 /// # fn main() -> DecoderResult<()> {
773 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
774 /// let decoder = DecoderBuilder::new()
775 /// .with_config_json_str(config_json)
776 /// .with_score_threshold(0.654)
777 /// .build()?;
778 /// # Ok(())
779 /// # }
780 /// ```
781 pub fn build(self) -> Result<Decoder, DecoderError> {
782 let (config, decode_program) = match self.config_src {
783 Some(ConfigSource::Json(s)) => Self::build_from_schema(SchemaV2::parse_json(&s)?)?,
784 Some(ConfigSource::Yaml(s)) => Self::build_from_schema(SchemaV2::parse_yaml(&s)?)?,
785 Some(ConfigSource::Config(c)) => (c, None),
786 Some(ConfigSource::Schema(schema)) => Self::build_from_schema(schema)?,
787 None => return Err(DecoderError::NoConfig),
788 };
789
790 // Enforce the physical-order contract: when dshape is present
791 // it must describe the same axes as shape in the same order,
792 // listed from outermost to innermost. Ambiguous-layout roles
793 // (Protos, Boxes, Scores, MaskCoefficients, Classes, Detection)
794 // may still omit dshape when shape is already in the decoder's
795 // canonical order.
796 for output in &config.outputs {
797 Decoder::validate_output_layout(output.into())?;
798 }
799
800 // Extract normalized flag from config outputs
801 let normalized = Self::get_normalized(&config.outputs);
802
803 // Use NMS from config if present, otherwise use builder's NMS setting
804 let nms = config.nms.or(self.nms);
805 let model_type = Self::get_model_type(config)?;
806
807 Ok(Decoder {
808 model_type,
809 iou_threshold: self.iou_threshold,
810 score_threshold: self.score_threshold,
811 nms,
812 normalized,
813 decode_program,
814 })
815 }
816
817 /// Validate a [`SchemaV2`] and lower it to the (legacy `ConfigOutputs`,
818 /// optional `DecodeProgram`) pair the rest of `build()` consumes.
819 ///
820 /// Centralises the v2 lowering so JSON, YAML, and direct
821 /// `with_schema` callers all go through the same validation and
822 /// merge-program construction. `SchemaV2::parse_json` /
823 /// `parse_yaml` already auto-detect v1 vs v2 input and return a v2
824 /// schema either way (v1 inputs are upgraded in memory via
825 /// `from_v1`), so this helper is the sole place that turns a
826 /// schema into builder-ready state.
827 fn build_from_schema(
828 schema: SchemaV2,
829 ) -> Result<(ConfigOutputs, Option<DecodeProgram>), DecoderError> {
830 schema.validate()?;
831 let program = DecodeProgram::try_from_schema(&schema)?;
832 let legacy = schema.to_legacy_config_outputs()?;
833 Ok((legacy, program))
834 }
835
836 /// Extracts the normalized flag from config outputs.
837 /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
838 /// - `Some(false)`: Boxes are in pixel coordinates
839 /// - `None`: Unknown (not specified in config), caller must infer
840 fn get_normalized(outputs: &[ConfigOutput]) -> Option<bool> {
841 for output in outputs {
842 match output {
843 ConfigOutput::Detection(det) => return det.normalized,
844 ConfigOutput::Boxes(boxes) => return boxes.normalized,
845 _ => {}
846 }
847 }
848 None // not specified
849 }
850
851 fn get_model_type(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
852 // yolo or modelpack
853 let mut yolo = false;
854 let mut modelpack = false;
855 for c in &configs.outputs {
856 match c.decoder() {
857 DecoderType::ModelPack => modelpack = true,
858 DecoderType::Ultralytics => yolo = true,
859 }
860 }
861 match (modelpack, yolo) {
862 (true, true) => Err(DecoderError::InvalidConfig(
863 "Both ModelPack and Yolo outputs found in config".to_string(),
864 )),
865 (true, false) => Self::get_model_type_modelpack(configs),
866 (false, true) => Self::get_model_type_yolo(configs),
867 (false, false) => Err(DecoderError::InvalidConfig(
868 "No outputs found in config".to_string(),
869 )),
870 }
871 }
872
    /// Classifies an Ultralytics (YOLO) output configuration into a concrete
    /// [`ModelType`] and validates the shapes of the outputs it selects.
    ///
    /// Outputs are bucketed by kind. If the same kind appears more than once,
    /// the later output replaces the earlier one. End-to-end layouts are chosen
    /// when `decoder_version` reports an end-to-end model. When no version is
    /// set, they are chosen if the combined Detection dshape is ordered
    /// `[Batch, NumBoxes, NumFeatures]`.
    ///
    /// # Errors
    /// Returns [`DecoderError::InvalidConfig`] for Segmentation or Mask
    /// outputs, and for output combinations that match no supported layout.
    /// Errors from the selected layout's `verify_*` function are propagated.
    fn get_model_type_yolo(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
        let mut boxes = None;
        let mut protos = None;
        let mut split_boxes = None;
        let mut split_scores = None;
        let mut split_mask_coeff = None;
        let mut split_classes = None;
        for c in configs.outputs {
            match c {
                ConfigOutput::Detection(detection) => boxes = Some(detection),
                ConfigOutput::Segmentation(_) => {
                    return Err(DecoderError::InvalidConfig(
                        "Invalid Segmentation output with Yolo decoder".to_string(),
                    ));
                }
                ConfigOutput::Protos(protos_) => protos = Some(protos_),
                ConfigOutput::Mask(_) => {
                    return Err(DecoderError::InvalidConfig(
                        "Invalid Mask output with Yolo decoder".to_string(),
                    ));
                }
                ConfigOutput::Scores(scores) => split_scores = Some(scores),
                ConfigOutput::Boxes(boxes) => split_boxes = Some(boxes),
                ConfigOutput::MaskCoefficients(mask_coeff) => split_mask_coeff = Some(mask_coeff),
                ConfigOutput::Classes(classes) => split_classes = Some(classes),
            }
        }

        // Use end-to-end model types when either:
        // 1. decoder_version is set and reports an end-to-end model (e.g.
        //    Yolo26). An explicit version is definitive, OR
        // 2. decoder_version is not set and the combined Detection dshape is
        //    exactly (batch, num_boxes, num_features).
        let is_end_to_end_dshape = boxes.as_ref().is_some_and(|b| {
            let dims = b.dshape.iter().map(|(d, _)| *d).collect::<Vec<_>>();
            dims == vec![DimName::Batch, DimName::NumBoxes, DimName::NumFeatures]
        });

        let is_end_to_end = configs
            .decoder_version
            .map(|v| v.is_end_to_end())
            .unwrap_or(is_end_to_end_dshape);

        if is_end_to_end {
            // A combined Detection output takes priority over split outputs.
            // Any MaskCoefficients output is ignored on this path.
            if let Some(boxes) = boxes {
                if let Some(protos) = protos {
                    Self::verify_yolo_seg_det_26(&boxes, &protos)?;
                    return Ok(ModelType::YoloEndToEndSegDet { boxes, protos });
                } else {
                    Self::verify_yolo_det_26(&boxes)?;
                    return Ok(ModelType::YoloEndToEndDet { boxes });
                }
            } else if let (Some(split_boxes), Some(split_scores), Some(split_classes)) =
                (split_boxes, split_scores, split_classes)
            {
                // Split end-to-end segmentation needs BOTH mask coefficients
                // and protos; otherwise fall back to split end-to-end detection.
                if let (Some(split_mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
                    Self::verify_yolo_split_end_to_end_segdet(
                        &split_boxes,
                        &split_scores,
                        &split_classes,
                        &split_mask_coeff,
                        &protos,
                    )?;
                    return Ok(ModelType::YoloSplitEndToEndSegDet {
                        boxes: split_boxes,
                        scores: split_scores,
                        classes: split_classes,
                        mask_coeff: split_mask_coeff,
                        protos,
                    });
                }
                Self::verify_yolo_split_end_to_end_det(
                    &split_boxes,
                    &split_scores,
                    &split_classes,
                )?;
                return Ok(ModelType::YoloSplitEndToEndDet {
                    boxes: split_boxes,
                    scores: split_scores,
                    classes: split_classes,
                });
            } else {
                return Err(DecoderError::InvalidConfig(
                    "Invalid Yolo end-to-end model outputs".to_string(),
                ));
            }
        }

        // Non end-to-end layouts. A Classes output is not used on this path.
        if let Some(boxes) = boxes {
            match (split_mask_coeff, protos) {
                (Some(mask_coeff), Some(protos)) => {
                    // 2-way split: combined detection + separate mask_coeff + protos
                    Self::verify_yolo_seg_det_2way(&boxes, &mask_coeff, &protos)?;
                    Ok(ModelType::YoloSegDet2Way {
                        boxes,
                        mask_coeff,
                        protos,
                    })
                }
                (_, Some(protos)) => {
                    // Unsplit: mask_coefs embedded in detection tensor
                    Self::verify_yolo_seg_det(&boxes, &protos)?;
                    Ok(ModelType::YoloSegDet { boxes, protos })
                }
                // No protos: plain detection. A MaskCoefficients output
                // without Protos is ignored here.
                _ => {
                    Self::verify_yolo_det(&boxes)?;
                    Ok(ModelType::YoloDet { boxes })
                }
            }
        } else if let (Some(boxes), Some(scores)) = (split_boxes, split_scores) {
            if let (Some(mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
                Self::verify_yolo_split_segdet(&boxes, &scores, &mask_coeff, &protos)?;
                Ok(ModelType::YoloSplitSegDet {
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                })
            } else {
                Self::verify_yolo_split_det(&boxes, &scores)?;
                Ok(ModelType::YoloSplitDet { boxes, scores })
            }
        } else {
            Err(DecoderError::InvalidConfig(
                "Invalid Yolo model outputs".to_string(),
            ))
        }
    }
1000
1001 fn verify_yolo_det(detect: &configs::Detection) -> Result<(), DecoderError> {
1002 if detect.shape.len() != 3 {
1003 return Err(DecoderError::InvalidConfig(format!(
1004 "Invalid Yolo Detection shape {:?}",
1005 detect.shape
1006 )));
1007 }
1008
1009 Self::verify_dshapes(
1010 &detect.dshape,
1011 &detect.shape,
1012 "Detection",
1013 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1014 )?;
1015 if !detect.dshape.is_empty() {
1016 Self::get_class_count(&detect.dshape, None, None)?;
1017 } else {
1018 Self::get_class_count_no_dshape(detect.into(), None)?;
1019 }
1020
1021 Ok(())
1022 }
1023
1024 fn verify_yolo_det_26(detect: &configs::Detection) -> Result<(), DecoderError> {
1025 if detect.shape.len() != 3 {
1026 return Err(DecoderError::InvalidConfig(format!(
1027 "Invalid Yolo Detection shape {:?}",
1028 detect.shape
1029 )));
1030 }
1031
1032 Self::verify_dshapes(
1033 &detect.dshape,
1034 &detect.shape,
1035 "Detection",
1036 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1037 )?;
1038
1039 if !detect.shape.contains(&6) {
1040 return Err(DecoderError::InvalidConfig(
1041 "Yolo26 Detection must have 6 features".to_string(),
1042 ));
1043 }
1044
1045 Ok(())
1046 }
1047
1048 fn verify_yolo_seg_det(
1049 detection: &configs::Detection,
1050 protos: &configs::Protos,
1051 ) -> Result<(), DecoderError> {
1052 if detection.shape.len() != 3 {
1053 return Err(DecoderError::InvalidConfig(format!(
1054 "Invalid Yolo Detection shape {:?}",
1055 detection.shape
1056 )));
1057 }
1058 if protos.shape.len() != 4 {
1059 return Err(DecoderError::InvalidConfig(format!(
1060 "Invalid Yolo Protos shape {:?}",
1061 protos.shape
1062 )));
1063 }
1064
1065 Self::verify_dshapes(
1066 &detection.dshape,
1067 &detection.shape,
1068 "Detection",
1069 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1070 )?;
1071 Self::verify_dshapes(
1072 &protos.dshape,
1073 &protos.shape,
1074 "Protos",
1075 &[
1076 DimName::Batch,
1077 DimName::Height,
1078 DimName::Width,
1079 DimName::NumProtos,
1080 ],
1081 )?;
1082
1083 let protos_count = Self::get_protos_count(&protos.dshape)
1084 .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1085 log::debug!("Protos count: {}", protos_count);
1086 log::debug!("Detection dshape: {:?}", detection.dshape);
1087 let classes = if !detection.dshape.is_empty() {
1088 Self::get_class_count(&detection.dshape, Some(protos_count), None)?
1089 } else {
1090 Self::get_class_count_no_dshape(detection.into(), Some(protos_count))?
1091 };
1092
1093 if classes == 0 {
1094 return Err(DecoderError::InvalidConfig(
1095 "Yolo Segmentation Detection has zero classes".to_string(),
1096 ));
1097 }
1098
1099 Ok(())
1100 }
1101
1102 fn verify_yolo_seg_det_2way(
1103 detection: &configs::Detection,
1104 mask_coeff: &configs::MaskCoefficients,
1105 protos: &configs::Protos,
1106 ) -> Result<(), DecoderError> {
1107 if detection.shape.len() != 3 {
1108 return Err(DecoderError::InvalidConfig(format!(
1109 "Invalid Yolo 2-Way Detection shape {:?}",
1110 detection.shape
1111 )));
1112 }
1113 if mask_coeff.shape.len() != 3 {
1114 return Err(DecoderError::InvalidConfig(format!(
1115 "Invalid Yolo 2-Way Mask Coefficients shape {:?}",
1116 mask_coeff.shape
1117 )));
1118 }
1119 if protos.shape.len() != 4 {
1120 return Err(DecoderError::InvalidConfig(format!(
1121 "Invalid Yolo 2-Way Protos shape {:?}",
1122 protos.shape
1123 )));
1124 }
1125
1126 Self::verify_dshapes(
1127 &detection.dshape,
1128 &detection.shape,
1129 "Detection",
1130 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1131 )?;
1132 Self::verify_dshapes(
1133 &mask_coeff.dshape,
1134 &mask_coeff.shape,
1135 "Mask Coefficients",
1136 &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1137 )?;
1138 Self::verify_dshapes(
1139 &protos.dshape,
1140 &protos.shape,
1141 "Protos",
1142 &[
1143 DimName::Batch,
1144 DimName::Height,
1145 DimName::Width,
1146 DimName::NumProtos,
1147 ],
1148 )?;
1149
1150 // Validate num_boxes match between detection and mask_coeff
1151 let det_num = Self::get_box_count(&detection.dshape).unwrap_or(detection.shape[2]);
1152 let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1153 if det_num != mask_num {
1154 return Err(DecoderError::InvalidConfig(format!(
1155 "Yolo 2-Way Detection num_boxes {} incompatible with Mask Coefficients num_boxes {}",
1156 det_num, mask_num
1157 )));
1158 }
1159
1160 // Validate mask_coeff channels match protos channels
1161 let mask_channels = if !mask_coeff.dshape.is_empty() {
1162 Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1163 DecoderError::InvalidConfig(
1164 "Could not find num_protos in mask_coeff config".to_string(),
1165 )
1166 })?
1167 } else {
1168 mask_coeff.shape[1]
1169 };
1170 let proto_channels = if !protos.dshape.is_empty() {
1171 Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1172 DecoderError::InvalidConfig(
1173 "Could not find num_protos in protos config".to_string(),
1174 )
1175 })?
1176 } else {
1177 protos.shape[1].min(protos.shape[3])
1178 };
1179 if mask_channels != proto_channels {
1180 return Err(DecoderError::InvalidConfig(format!(
1181 "Yolo 2-Way Protos channels {} incompatible with Mask Coefficients channels {}",
1182 proto_channels, mask_channels
1183 )));
1184 }
1185
1186 // Validate detection has classes (nc+4 features, no mask_coefs embedded)
1187 if !detection.dshape.is_empty() {
1188 Self::get_class_count(&detection.dshape, None, None)?;
1189 } else {
1190 Self::get_class_count_no_dshape(detection.into(), None)?;
1191 }
1192
1193 Ok(())
1194 }
1195
1196 fn verify_yolo_seg_det_26(
1197 detection: &configs::Detection,
1198 protos: &configs::Protos,
1199 ) -> Result<(), DecoderError> {
1200 if detection.shape.len() != 3 {
1201 return Err(DecoderError::InvalidConfig(format!(
1202 "Invalid Yolo Detection shape {:?}",
1203 detection.shape
1204 )));
1205 }
1206 if protos.shape.len() != 4 {
1207 return Err(DecoderError::InvalidConfig(format!(
1208 "Invalid Yolo Protos shape {:?}",
1209 protos.shape
1210 )));
1211 }
1212
1213 Self::verify_dshapes(
1214 &detection.dshape,
1215 &detection.shape,
1216 "Detection",
1217 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1218 )?;
1219 Self::verify_dshapes(
1220 &protos.dshape,
1221 &protos.shape,
1222 "Protos",
1223 &[
1224 DimName::Batch,
1225 DimName::Height,
1226 DimName::Width,
1227 DimName::NumProtos,
1228 ],
1229 )?;
1230
1231 let protos_count = Self::get_protos_count(&protos.dshape)
1232 .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1233 log::debug!("Protos count: {}", protos_count);
1234 log::debug!("Detection dshape: {:?}", detection.dshape);
1235
1236 if !detection.shape.contains(&(6 + protos_count)) {
1237 return Err(DecoderError::InvalidConfig(format!(
1238 "Yolo26 Segmentation Detection must have num_features be 6 + num_protos = {}",
1239 6 + protos_count
1240 )));
1241 }
1242
1243 Ok(())
1244 }
1245
1246 fn verify_yolo_split_det(
1247 boxes: &configs::Boxes,
1248 scores: &configs::Scores,
1249 ) -> Result<(), DecoderError> {
1250 if boxes.shape.len() != 3 {
1251 return Err(DecoderError::InvalidConfig(format!(
1252 "Invalid Yolo Split Boxes shape {:?}",
1253 boxes.shape
1254 )));
1255 }
1256 if scores.shape.len() != 3 {
1257 return Err(DecoderError::InvalidConfig(format!(
1258 "Invalid Yolo Split Scores shape {:?}",
1259 scores.shape
1260 )));
1261 }
1262
1263 Self::verify_dshapes(
1264 &boxes.dshape,
1265 &boxes.shape,
1266 "Boxes",
1267 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1268 )?;
1269 Self::verify_dshapes(
1270 &scores.dshape,
1271 &scores.shape,
1272 "Scores",
1273 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1274 )?;
1275
1276 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1277 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1278
1279 if boxes_num != scores_num {
1280 return Err(DecoderError::InvalidConfig(format!(
1281 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1282 boxes_num, scores_num
1283 )));
1284 }
1285
1286 Ok(())
1287 }
1288
1289 fn verify_yolo_split_segdet(
1290 boxes: &configs::Boxes,
1291 scores: &configs::Scores,
1292 mask_coeff: &configs::MaskCoefficients,
1293 protos: &configs::Protos,
1294 ) -> Result<(), DecoderError> {
1295 if boxes.shape.len() != 3 {
1296 return Err(DecoderError::InvalidConfig(format!(
1297 "Invalid Yolo Split Boxes shape {:?}",
1298 boxes.shape
1299 )));
1300 }
1301 if scores.shape.len() != 3 {
1302 return Err(DecoderError::InvalidConfig(format!(
1303 "Invalid Yolo Split Scores shape {:?}",
1304 scores.shape
1305 )));
1306 }
1307
1308 if mask_coeff.shape.len() != 3 {
1309 return Err(DecoderError::InvalidConfig(format!(
1310 "Invalid Yolo Split Mask Coefficients shape {:?}",
1311 mask_coeff.shape
1312 )));
1313 }
1314
1315 if protos.shape.len() != 4 {
1316 return Err(DecoderError::InvalidConfig(format!(
1317 "Invalid Yolo Protos shape {:?}",
1318 mask_coeff.shape
1319 )));
1320 }
1321
1322 Self::verify_dshapes(
1323 &boxes.dshape,
1324 &boxes.shape,
1325 "Boxes",
1326 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1327 )?;
1328 Self::verify_dshapes(
1329 &scores.dshape,
1330 &scores.shape,
1331 "Scores",
1332 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1333 )?;
1334 Self::verify_dshapes(
1335 &mask_coeff.dshape,
1336 &mask_coeff.shape,
1337 "Mask Coefficients",
1338 &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1339 )?;
1340 Self::verify_dshapes(
1341 &protos.dshape,
1342 &protos.shape,
1343 "Protos",
1344 &[
1345 DimName::Batch,
1346 DimName::Height,
1347 DimName::Width,
1348 DimName::NumProtos,
1349 ],
1350 )?;
1351
1352 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1353 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1354 let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1355
1356 let mask_channels = if !mask_coeff.dshape.is_empty() {
1357 Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1358 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1359 })?
1360 } else {
1361 mask_coeff.shape[1]
1362 };
1363 let proto_channels = if !protos.dshape.is_empty() {
1364 Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1365 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1366 })?
1367 } else {
1368 protos.shape[1].min(protos.shape[3])
1369 };
1370
1371 if boxes_num != scores_num {
1372 return Err(DecoderError::InvalidConfig(format!(
1373 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1374 boxes_num, scores_num
1375 )));
1376 }
1377
1378 if boxes_num != mask_num {
1379 return Err(DecoderError::InvalidConfig(format!(
1380 "Yolo Split Detection Boxes num {} incompatible with Mask Coefficients num {}",
1381 boxes_num, mask_num
1382 )));
1383 }
1384
1385 if proto_channels != mask_channels {
1386 return Err(DecoderError::InvalidConfig(format!(
1387 "Yolo Protos channels {} incompatible with Mask Coefficients channels {}",
1388 proto_channels, mask_channels
1389 )));
1390 }
1391
1392 Ok(())
1393 }
1394
1395 fn verify_yolo_split_end_to_end_det(
1396 boxes: &configs::Boxes,
1397 scores: &configs::Scores,
1398 classes: &configs::Classes,
1399 ) -> Result<(), DecoderError> {
1400 if boxes.shape.len() != 3 || !boxes.shape.contains(&4) {
1401 return Err(DecoderError::InvalidConfig(format!(
1402 "Split end-to-end boxes must be [batch, N, 4], got {:?}",
1403 boxes.shape
1404 )));
1405 }
1406 if scores.shape.len() != 3 || !scores.shape.contains(&1) {
1407 return Err(DecoderError::InvalidConfig(format!(
1408 "Split end-to-end scores must be [batch, N, 1], got {:?}",
1409 scores.shape
1410 )));
1411 }
1412 if classes.shape.len() != 3 || !classes.shape.contains(&1) {
1413 return Err(DecoderError::InvalidConfig(format!(
1414 "Split end-to-end classes must be [batch, N, 1], got {:?}",
1415 classes.shape
1416 )));
1417 }
1418 Ok(())
1419 }
1420
1421 fn verify_yolo_split_end_to_end_segdet(
1422 boxes: &configs::Boxes,
1423 scores: &configs::Scores,
1424 classes: &configs::Classes,
1425 mask_coeff: &configs::MaskCoefficients,
1426 protos: &configs::Protos,
1427 ) -> Result<(), DecoderError> {
1428 Self::verify_yolo_split_end_to_end_det(boxes, scores, classes)?;
1429 if mask_coeff.shape.len() != 3 {
1430 return Err(DecoderError::InvalidConfig(format!(
1431 "Invalid split end-to-end mask coefficients shape {:?}",
1432 mask_coeff.shape
1433 )));
1434 }
1435 if protos.shape.len() != 4 {
1436 return Err(DecoderError::InvalidConfig(format!(
1437 "Invalid protos shape {:?}",
1438 protos.shape
1439 )));
1440 }
1441 Ok(())
1442 }
1443
1444 fn get_model_type_modelpack(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1445 let mut split_decoders = Vec::new();
1446 let mut segment_ = None;
1447 let mut scores_ = None;
1448 let mut boxes_ = None;
1449 for c in configs.outputs {
1450 match c {
1451 ConfigOutput::Detection(detection) => split_decoders.push(detection),
1452 ConfigOutput::Segmentation(segmentation) => segment_ = Some(segmentation),
1453 ConfigOutput::Mask(_) => {}
1454 ConfigOutput::Protos(_) => {
1455 return Err(DecoderError::InvalidConfig(
1456 "ModelPack should not have protos".to_string(),
1457 ));
1458 }
1459 ConfigOutput::Scores(scores) => scores_ = Some(scores),
1460 ConfigOutput::Boxes(boxes) => boxes_ = Some(boxes),
1461 ConfigOutput::MaskCoefficients(_) => {
1462 return Err(DecoderError::InvalidConfig(
1463 "ModelPack should not have mask coefficients".to_string(),
1464 ));
1465 }
1466 ConfigOutput::Classes(_) => {
1467 return Err(DecoderError::InvalidConfig(
1468 "ModelPack should not have classes output".to_string(),
1469 ));
1470 }
1471 }
1472 }
1473
1474 if let Some(segmentation) = segment_ {
1475 if !split_decoders.is_empty() {
1476 let classes = Self::verify_modelpack_split_det(&split_decoders)?;
1477 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1478 Ok(ModelType::ModelPackSegDetSplit {
1479 detection: split_decoders,
1480 segmentation,
1481 })
1482 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1483 let classes = Self::verify_modelpack_det(&boxes, &scores)?;
1484 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1485 Ok(ModelType::ModelPackSegDet {
1486 boxes,
1487 scores,
1488 segmentation,
1489 })
1490 } else {
1491 Self::verify_modelpack_seg(&segmentation, None)?;
1492 Ok(ModelType::ModelPackSeg { segmentation })
1493 }
1494 } else if !split_decoders.is_empty() {
1495 Self::verify_modelpack_split_det(&split_decoders)?;
1496 Ok(ModelType::ModelPackDetSplit {
1497 detection: split_decoders,
1498 })
1499 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1500 Self::verify_modelpack_det(&boxes, &scores)?;
1501 Ok(ModelType::ModelPackDet { boxes, scores })
1502 } else {
1503 Err(DecoderError::InvalidConfig(
1504 "Invalid ModelPack model outputs".to_string(),
1505 ))
1506 }
1507 }
1508
1509 fn verify_modelpack_det(
1510 boxes: &configs::Boxes,
1511 scores: &configs::Scores,
1512 ) -> Result<usize, DecoderError> {
1513 if boxes.shape.len() != 4 {
1514 return Err(DecoderError::InvalidConfig(format!(
1515 "Invalid ModelPack Boxes shape {:?}",
1516 boxes.shape
1517 )));
1518 }
1519 if scores.shape.len() != 3 {
1520 return Err(DecoderError::InvalidConfig(format!(
1521 "Invalid ModelPack Scores shape {:?}",
1522 scores.shape
1523 )));
1524 }
1525
1526 Self::verify_dshapes(
1527 &boxes.dshape,
1528 &boxes.shape,
1529 "Boxes",
1530 &[
1531 DimName::Batch,
1532 DimName::NumBoxes,
1533 DimName::Padding,
1534 DimName::BoxCoords,
1535 ],
1536 )?;
1537 Self::verify_dshapes(
1538 &scores.dshape,
1539 &scores.shape,
1540 "Scores",
1541 &[DimName::Batch, DimName::NumBoxes, DimName::NumClasses],
1542 )?;
1543
1544 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[1]);
1545 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[1]);
1546
1547 if boxes_num != scores_num {
1548 return Err(DecoderError::InvalidConfig(format!(
1549 "ModelPack Detection Boxes num {} incompatible with Scores num {}",
1550 boxes_num, scores_num
1551 )));
1552 }
1553
1554 let num_classes = if !scores.dshape.is_empty() {
1555 Self::get_class_count(&scores.dshape, None, None)?
1556 } else {
1557 Self::get_class_count_no_dshape(scores.into(), None)?
1558 };
1559
1560 Ok(num_classes)
1561 }
1562
1563 fn verify_modelpack_split_det(boxes: &[configs::Detection]) -> Result<usize, DecoderError> {
1564 let mut num_classes = None;
1565 for b in boxes {
1566 let Some(num_anchors) = b.anchors.as_ref().map(|a| a.len()) else {
1567 return Err(DecoderError::InvalidConfig(
1568 "ModelPack Split Detection missing anchors".to_string(),
1569 ));
1570 };
1571
1572 if num_anchors == 0 {
1573 return Err(DecoderError::InvalidConfig(
1574 "ModelPack Split Detection has zero anchors".to_string(),
1575 ));
1576 }
1577
1578 if b.shape.len() != 4 {
1579 return Err(DecoderError::InvalidConfig(format!(
1580 "Invalid ModelPack Split Detection shape {:?}",
1581 b.shape
1582 )));
1583 }
1584
1585 Self::verify_dshapes(
1586 &b.dshape,
1587 &b.shape,
1588 "Split Detection",
1589 &[
1590 DimName::Batch,
1591 DimName::Height,
1592 DimName::Width,
1593 DimName::NumAnchorsXFeatures,
1594 ],
1595 )?;
1596 let classes = if !b.dshape.is_empty() {
1597 Self::get_class_count(&b.dshape, None, Some(num_anchors))?
1598 } else {
1599 Self::get_class_count_no_dshape(b.into(), None)?
1600 };
1601
1602 match num_classes {
1603 Some(n) => {
1604 if n != classes {
1605 return Err(DecoderError::InvalidConfig(format!(
1606 "ModelPack Split Detection inconsistent number of classes: previous {}, current {}",
1607 n, classes
1608 )));
1609 }
1610 }
1611 None => {
1612 num_classes = Some(classes);
1613 }
1614 }
1615 }
1616
1617 Ok(num_classes.unwrap_or(0))
1618 }
1619
1620 fn verify_modelpack_seg(
1621 segmentation: &configs::Segmentation,
1622 classes: Option<usize>,
1623 ) -> Result<(), DecoderError> {
1624 if segmentation.shape.len() != 4 {
1625 return Err(DecoderError::InvalidConfig(format!(
1626 "Invalid ModelPack Segmentation shape {:?}",
1627 segmentation.shape
1628 )));
1629 }
1630 Self::verify_dshapes(
1631 &segmentation.dshape,
1632 &segmentation.shape,
1633 "Segmentation",
1634 &[
1635 DimName::Batch,
1636 DimName::Height,
1637 DimName::Width,
1638 DimName::NumClasses,
1639 ],
1640 )?;
1641
1642 if let Some(classes) = classes {
1643 let seg_classes = if !segmentation.dshape.is_empty() {
1644 Self::get_class_count(&segmentation.dshape, None, None)?
1645 } else {
1646 Self::get_class_count_no_dshape(segmentation.into(), None)?
1647 };
1648
1649 if seg_classes != classes + 1 {
1650 return Err(DecoderError::InvalidConfig(format!(
1651 "ModelPack Segmentation channels {} incompatible with number of classes {}",
1652 seg_classes, classes
1653 )));
1654 }
1655 }
1656 Ok(())
1657 }
1658
1659 // verifies that dshapes match the given shape
1660 fn verify_dshapes(
1661 dshape: &[(DimName, usize)],
1662 shape: &[usize],
1663 name: &str,
1664 dims: &[DimName],
1665 ) -> Result<(), DecoderError> {
1666 for s in shape {
1667 if *s == 0 {
1668 return Err(DecoderError::InvalidConfig(format!(
1669 "{} shape has zero dimension",
1670 name
1671 )));
1672 }
1673 }
1674
1675 if shape.len() != dims.len() {
1676 return Err(DecoderError::InvalidConfig(format!(
1677 "{} shape length {} does not match expected dims length {}",
1678 name,
1679 shape.len(),
1680 dims.len()
1681 )));
1682 }
1683
1684 if dshape.is_empty() {
1685 return Ok(());
1686 }
1687 // check the dshape lengths match the shape lengths
1688 if dshape.len() != shape.len() {
1689 return Err(DecoderError::InvalidConfig(format!(
1690 "{} dshape length does not match shape length",
1691 name
1692 )));
1693 }
1694
1695 // check the dshape values match the shape values
1696 for ((dim_name, dim_size), shape_size) in dshape.iter().zip(shape) {
1697 if dim_size != shape_size {
1698 return Err(DecoderError::InvalidConfig(format!(
1699 "{} dshape dimension {} size {} does not match shape size {}",
1700 name, dim_name, dim_size, shape_size
1701 )));
1702 }
1703 if *dim_name == DimName::Padding && *dim_size != 1 {
1704 return Err(DecoderError::InvalidConfig(
1705 "Padding dimension size must be 1".to_string(),
1706 ));
1707 }
1708
1709 if *dim_name == DimName::BoxCoords && *dim_size != 4 {
1710 return Err(DecoderError::InvalidConfig(
1711 "BoxCoords dimension size must be 4".to_string(),
1712 ));
1713 }
1714 }
1715
1716 let dims_present = HashSet::<DimName>::from_iter(dshape.iter().map(|(name, _)| *name));
1717 for dim in dims {
1718 if !dims_present.contains(dim) {
1719 return Err(DecoderError::InvalidConfig(format!(
1720 "{} dshape missing required dimension {:?}",
1721 name, dim
1722 )));
1723 }
1724 }
1725
1726 Ok(())
1727 }
1728
1729 fn get_box_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1730 for (dim_name, dim_size) in dshape {
1731 if *dim_name == DimName::NumBoxes {
1732 return Some(*dim_size);
1733 }
1734 }
1735 None
1736 }
1737
1738 fn get_class_count_no_dshape(
1739 config: ConfigOutputRef,
1740 protos: Option<usize>,
1741 ) -> Result<usize, DecoderError> {
1742 match config {
1743 ConfigOutputRef::Detection(detection) => match detection.decoder {
1744 DecoderType::Ultralytics => {
1745 if detection.shape[1] <= 4 + protos.unwrap_or(0) {
1746 return Err(DecoderError::InvalidConfig(format!(
1747 "Invalid shape: Yolo num_features {} must be greater than {}",
1748 detection.shape[1],
1749 4 + protos.unwrap_or(0),
1750 )));
1751 }
1752 Ok(detection.shape[1] - 4 - protos.unwrap_or(0))
1753 }
1754 DecoderType::ModelPack => {
1755 let Some(num_anchors) = detection.anchors.as_ref().map(|a| a.len()) else {
1756 return Err(DecoderError::Internal(
1757 "ModelPack Detection missing anchors".to_string(),
1758 ));
1759 };
1760 let anchors_x_features = detection.shape[3];
1761 if anchors_x_features <= num_anchors * 5 {
1762 return Err(DecoderError::InvalidConfig(format!(
1763 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
1764 anchors_x_features,
1765 num_anchors * 5,
1766 )));
1767 }
1768
1769 if !anchors_x_features.is_multiple_of(num_anchors) {
1770 return Err(DecoderError::InvalidConfig(format!(
1771 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
1772 anchors_x_features, num_anchors
1773 )));
1774 }
1775 Ok(anchors_x_features / num_anchors - 5)
1776 }
1777 },
1778
1779 ConfigOutputRef::Scores(scores) => match scores.decoder {
1780 DecoderType::Ultralytics => Ok(scores.shape[1]),
1781 DecoderType::ModelPack => Ok(scores.shape[2]),
1782 },
1783 ConfigOutputRef::Segmentation(seg) => Ok(seg.shape[3]),
1784 _ => Err(DecoderError::Internal(
1785 "Attempted to get class count from unsupported config output".to_owned(),
1786 )),
1787 }
1788 }
1789
1790 // get the class count from dshape or calculate from num_features
1791 fn get_class_count(
1792 dshape: &[(DimName, usize)],
1793 protos: Option<usize>,
1794 anchors: Option<usize>,
1795 ) -> Result<usize, DecoderError> {
1796 if dshape.is_empty() {
1797 return Ok(0);
1798 }
1799 // if it has num_classes in dshape, return it
1800 for (dim_name, dim_size) in dshape {
1801 if *dim_name == DimName::NumClasses {
1802 return Ok(*dim_size);
1803 }
1804 }
1805
1806 // number of classes can be calculated from num_features - 4 for yolo. If the
1807 // model has protos, we also subtract the number of protos.
1808 for (dim_name, dim_size) in dshape {
1809 if *dim_name == DimName::NumFeatures {
1810 let protos = protos.unwrap_or(0);
1811 if protos + 4 >= *dim_size {
1812 return Err(DecoderError::InvalidConfig(format!(
1813 "Invalid shape: Yolo num_features {} must be greater than {}",
1814 *dim_size,
1815 protos + 4,
1816 )));
1817 }
1818 return Ok(*dim_size - 4 - protos);
1819 }
1820 }
1821
1822 // number of classes can be calculated from number of anchors for modelpack
1823 // split detection
1824 if let Some(num_anchors) = anchors {
1825 for (dim_name, dim_size) in dshape {
1826 if *dim_name == DimName::NumAnchorsXFeatures {
1827 let anchors_x_features = *dim_size;
1828 if anchors_x_features <= num_anchors * 5 {
1829 return Err(DecoderError::InvalidConfig(format!(
1830 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
1831 anchors_x_features,
1832 num_anchors * 5,
1833 )));
1834 }
1835
1836 if !anchors_x_features.is_multiple_of(num_anchors) {
1837 return Err(DecoderError::InvalidConfig(format!(
1838 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
1839 anchors_x_features, num_anchors
1840 )));
1841 }
1842 return Ok((anchors_x_features / num_anchors) - 5);
1843 }
1844 }
1845 }
1846 Err(DecoderError::InvalidConfig(
1847 "Cannot determine number of classes from dshape".to_owned(),
1848 ))
1849 }
1850
1851 fn get_protos_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1852 for (dim_name, dim_size) in dshape {
1853 if *dim_name == DimName::NumProtos {
1854 return Some(*dim_size);
1855 }
1856 }
1857 None
1858 }
1859}