edgefirst_decoder/decoder/builder.rs
1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashSet;
5
6use super::config::ConfigOutputRef;
7use super::configs::{self, DecoderType, DecoderVersion, DimName, ModelType};
8use super::merge::DecodeProgram;
9use super::{ConfigOutput, ConfigOutputs, Decoder};
10use crate::schema::SchemaV2;
11use crate::DecoderError;
12
13#[derive(Debug, Clone, PartialEq)]
14pub struct DecoderBuilder {
15 config_src: Option<ConfigSource>,
16 iou_threshold: f32,
17 score_threshold: f32,
18 /// NMS mode: Some(mode) applies NMS, None bypasses NMS (for end-to-end
19 /// models)
20 nms: Option<configs::Nms>,
21 pre_nms_top_k: usize,
22 max_det: usize,
23}
24
25#[derive(Debug, Clone, PartialEq)]
26enum ConfigSource {
27 Yaml(String),
28 Json(String),
29 Config(ConfigOutputs),
30 /// Schema v2 metadata. During build the schema is either converted
31 /// to a legacy [`ConfigOutputs`] (flat case) or used to construct a
32 /// [`DecodeProgram`] that performs per-child dequant + merge at
33 /// decode time.
34 Schema(SchemaV2),
35}
36
37impl Default for DecoderBuilder {
38 /// Creates a default DecoderBuilder with no configuration and 0.5 score
39 /// threshold and 0.5 IoU threshold.
40 ///
41 /// A valid configuration must be provided before building the Decoder.
42 ///
43 /// # Examples
44 /// ```rust
45 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
46 /// # fn main() -> DecoderResult<()> {
47 /// # let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
48 /// let decoder = DecoderBuilder::default()
49 /// .with_config_yaml_str(config_yaml)
50 /// .build()?;
51 /// assert_eq!(decoder.score_threshold, 0.5);
52 /// assert_eq!(decoder.iou_threshold, 0.5);
53 ///
54 /// # Ok(())
55 /// # }
56 /// ```
57 fn default() -> Self {
58 Self {
59 config_src: None,
60 iou_threshold: 0.5,
61 score_threshold: 0.5,
62 nms: Some(configs::Nms::ClassAgnostic),
63 pre_nms_top_k: 300,
64 max_det: 300,
65 }
66 }
67}
68
69impl DecoderBuilder {
70 /// Creates a default DecoderBuilder with no configuration and 0.5 score
71 /// threshold and 0.5 IoU threshold.
72 ///
73 /// A valid configuration must be provided before building the Decoder.
74 ///
75 /// # Examples
76 /// ```rust
77 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
78 /// # fn main() -> DecoderResult<()> {
79 /// # let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
80 /// let decoder = DecoderBuilder::new()
81 /// .with_config_yaml_str(config_yaml)
82 /// .build()?;
83 /// assert_eq!(decoder.score_threshold, 0.5);
84 /// assert_eq!(decoder.iou_threshold, 0.5);
85 ///
86 /// # Ok(())
87 /// # }
88 /// ```
89 pub fn new() -> Self {
90 Self::default()
91 }
92
93 /// Loads a model configuration in YAML format. Does not check if the string
94 /// is a correct configuration file. Use `DecoderBuilder.build()` to
95 /// deserialize the YAML and parse the model configuration.
96 ///
97 /// # Examples
98 /// ```rust
99 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
100 /// # fn main() -> DecoderResult<()> {
101 /// let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
102 /// let decoder = DecoderBuilder::new()
103 /// .with_config_yaml_str(config_yaml)
104 /// .build()?;
105 ///
106 /// # Ok(())
107 /// # }
108 /// ```
109 pub fn with_config_yaml_str(mut self, yaml_str: String) -> Self {
110 self.config_src.replace(ConfigSource::Yaml(yaml_str));
111 self
112 }
113
114 /// Loads a model configuration in JSON format. Does not check if the string
115 /// is a correct configuration file. Use `DecoderBuilder.build()` to
116 /// deserialize the JSON and parse the model configuration.
117 ///
118 /// # Examples
119 /// ```rust
120 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
121 /// # fn main() -> DecoderResult<()> {
122 /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
123 /// let decoder = DecoderBuilder::new()
124 /// .with_config_json_str(config_json)
125 /// .build()?;
126 ///
127 /// # Ok(())
128 /// # }
129 /// ```
130 pub fn with_config_json_str(mut self, json_str: String) -> Self {
131 self.config_src.replace(ConfigSource::Json(json_str));
132 self
133 }
134
135 /// Loads a model configuration. Does not check if the configuration is
136 /// correct. Intended to be used when the user needs control over the
137 /// deserialize of the configuration information. Use
138 /// `DecoderBuilder.build()` to parse the model configuration.
139 ///
140 /// # Examples
141 /// ```rust
142 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
143 /// # fn main() -> DecoderResult<()> {
144 /// let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json"));
145 /// let config = serde_json::from_str(config_json)?;
146 /// let decoder = DecoderBuilder::new().with_config(config).build()?;
147 ///
148 /// # Ok(())
149 /// # }
150 /// ```
151 pub fn with_config(mut self, config: ConfigOutputs) -> Self {
152 self.config_src.replace(ConfigSource::Config(config));
153 self
154 }
155
156 /// Configure the decoder from a schema v2 metadata document.
157 ///
158 /// Accepts a [`SchemaV2`] as produced by [`SchemaV2::parse_json`],
159 /// [`SchemaV2::parse_yaml`], [`SchemaV2::parse_file`], or
160 /// constructed programmatically. The builder validates the schema,
161 /// compiles a [`DecodeProgram`] for any split logical outputs
162 /// (per-scale or channel sub-splits), and downconverts the
163 /// logical-level semantics to the legacy [`ConfigOutputs`]
164 /// representation consumed by the existing decoder dispatch.
165 ///
166 /// # Examples
167 ///
168 /// ```rust,no_run
169 /// use edgefirst_decoder::{DecoderBuilder, DecoderResult};
170 /// use edgefirst_decoder::schema::SchemaV2;
171 ///
172 /// # fn main() -> DecoderResult<()> {
173 /// let schema = SchemaV2::parse_file("model/edgefirst.json")?;
174 /// let decoder = DecoderBuilder::new()
175 /// .with_schema(schema)
176 /// .with_score_threshold(0.25)
177 /// .build()?;
178 /// # Ok(())
179 /// # }
180 /// ```
181 pub fn with_schema(mut self, schema: SchemaV2) -> Self {
182 self.config_src.replace(ConfigSource::Schema(schema));
183 self
184 }
185
186 /// Loads a YOLO detection model configuration. Use
187 /// `DecoderBuilder.build()` to parse the model configuration.
188 ///
189 /// # Examples
190 /// ```rust
191 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
192 /// # fn main() -> DecoderResult<()> {
193 /// let decoder = DecoderBuilder::new()
194 /// .with_config_yolo_det(
195 /// configs::Detection {
196 /// anchors: None,
197 /// decoder: configs::DecoderType::Ultralytics,
198 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
199 /// shape: vec![1, 84, 8400],
200 /// dshape: Vec::new(),
201 /// normalized: Some(true),
202 /// },
203 /// None,
204 /// )
205 /// .build()?;
206 ///
207 /// # Ok(())
208 /// # }
209 /// ```
210 pub fn with_config_yolo_det(
211 mut self,
212 boxes: configs::Detection,
213 version: Option<DecoderVersion>,
214 ) -> Self {
215 let config = ConfigOutputs {
216 outputs: vec![ConfigOutput::Detection(boxes)],
217 decoder_version: version,
218 ..Default::default()
219 };
220 self.config_src.replace(ConfigSource::Config(config));
221 self
222 }
223
224 /// Loads a YOLO split detection model configuration. Use
225 /// `DecoderBuilder.build()` to parse the model configuration.
226 ///
227 /// # Examples
228 /// ```rust
229 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
230 /// # fn main() -> DecoderResult<()> {
231 /// let boxes_config = configs::Boxes {
232 /// decoder: configs::DecoderType::Ultralytics,
233 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
234 /// shape: vec![1, 4, 8400],
235 /// dshape: Vec::new(),
236 /// normalized: Some(true),
237 /// };
238 /// let scores_config = configs::Scores {
239 /// decoder: configs::DecoderType::Ultralytics,
240 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
241 /// shape: vec![1, 80, 8400],
242 /// dshape: Vec::new(),
243 /// };
244 /// let decoder = DecoderBuilder::new()
245 /// .with_config_yolo_split_det(boxes_config, scores_config)
246 /// .build()?;
247 /// # Ok(())
248 /// # }
249 /// ```
250 pub fn with_config_yolo_split_det(
251 mut self,
252 boxes: configs::Boxes,
253 scores: configs::Scores,
254 ) -> Self {
255 let config = ConfigOutputs {
256 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
257 ..Default::default()
258 };
259 self.config_src.replace(ConfigSource::Config(config));
260 self
261 }
262
263 /// Loads a YOLO segmentation model configuration. Use
264 /// `DecoderBuilder.build()` to parse the model configuration.
265 ///
266 /// # Examples
267 /// ```rust
268 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
269 /// # fn main() -> DecoderResult<()> {
270 /// let seg_config = configs::Detection {
271 /// decoder: configs::DecoderType::Ultralytics,
272 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
273 /// shape: vec![1, 116, 8400],
274 /// anchors: None,
275 /// dshape: Vec::new(),
276 /// normalized: Some(true),
277 /// };
278 /// let protos_config = configs::Protos {
279 /// decoder: configs::DecoderType::Ultralytics,
280 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
281 /// shape: vec![1, 160, 160, 32],
282 /// dshape: Vec::new(),
283 /// };
284 /// let decoder = DecoderBuilder::new()
285 /// .with_config_yolo_segdet(
286 /// seg_config,
287 /// protos_config,
288 /// Some(configs::DecoderVersion::Yolov8),
289 /// )
290 /// .build()?;
291 /// # Ok(())
292 /// # }
293 /// ```
294 pub fn with_config_yolo_segdet(
295 mut self,
296 boxes: configs::Detection,
297 protos: configs::Protos,
298 version: Option<DecoderVersion>,
299 ) -> Self {
300 let config = ConfigOutputs {
301 outputs: vec![ConfigOutput::Detection(boxes), ConfigOutput::Protos(protos)],
302 decoder_version: version,
303 ..Default::default()
304 };
305 self.config_src.replace(ConfigSource::Config(config));
306 self
307 }
308
309 /// Loads a YOLO split segmentation model configuration. Use
310 /// `DecoderBuilder.build()` to parse the model configuration.
311 ///
312 /// # Examples
313 /// ```rust
314 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
315 /// # fn main() -> DecoderResult<()> {
316 /// let boxes_config = configs::Boxes {
317 /// decoder: configs::DecoderType::Ultralytics,
318 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
319 /// shape: vec![1, 4, 8400],
320 /// dshape: Vec::new(),
321 /// normalized: Some(true),
322 /// };
323 /// let scores_config = configs::Scores {
324 /// decoder: configs::DecoderType::Ultralytics,
325 /// quantization: Some(configs::QuantTuple(0.012345, 14)),
326 /// shape: vec![1, 80, 8400],
327 /// dshape: Vec::new(),
328 /// };
329 /// let mask_config = configs::MaskCoefficients {
330 /// decoder: configs::DecoderType::Ultralytics,
331 /// quantization: Some(configs::QuantTuple(0.0064123, 125)),
332 /// shape: vec![1, 32, 8400],
333 /// dshape: Vec::new(),
334 /// };
335 /// let protos_config = configs::Protos {
336 /// decoder: configs::DecoderType::Ultralytics,
337 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
338 /// shape: vec![1, 160, 160, 32],
339 /// dshape: Vec::new(),
340 /// };
341 /// let decoder = DecoderBuilder::new()
342 /// .with_config_yolo_split_segdet(boxes_config, scores_config, mask_config, protos_config)
343 /// .build()?;
344 /// # Ok(())
345 /// # }
346 /// ```
347 pub fn with_config_yolo_split_segdet(
348 mut self,
349 boxes: configs::Boxes,
350 scores: configs::Scores,
351 mask_coefficients: configs::MaskCoefficients,
352 protos: configs::Protos,
353 ) -> Self {
354 let config = ConfigOutputs {
355 outputs: vec![
356 ConfigOutput::Boxes(boxes),
357 ConfigOutput::Scores(scores),
358 ConfigOutput::MaskCoefficients(mask_coefficients),
359 ConfigOutput::Protos(protos),
360 ],
361 ..Default::default()
362 };
363 self.config_src.replace(ConfigSource::Config(config));
364 self
365 }
366
367 /// Loads a ModelPack detection model configuration. Use
368 /// `DecoderBuilder.build()` to parse the model configuration.
369 ///
370 /// # Examples
371 /// ```rust
372 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
373 /// # fn main() -> DecoderResult<()> {
374 /// let boxes_config = configs::Boxes {
375 /// decoder: configs::DecoderType::ModelPack,
376 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
377 /// shape: vec![1, 8400, 1, 4],
378 /// dshape: Vec::new(),
379 /// normalized: Some(true),
380 /// };
381 /// let scores_config = configs::Scores {
382 /// decoder: configs::DecoderType::ModelPack,
383 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
384 /// shape: vec![1, 8400, 3],
385 /// dshape: Vec::new(),
386 /// };
387 /// let decoder = DecoderBuilder::new()
388 /// .with_config_modelpack_det(boxes_config, scores_config)
389 /// .build()?;
390 /// # Ok(())
391 /// # }
392 /// ```
393 pub fn with_config_modelpack_det(
394 mut self,
395 boxes: configs::Boxes,
396 scores: configs::Scores,
397 ) -> Self {
398 let config = ConfigOutputs {
399 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
400 ..Default::default()
401 };
402 self.config_src.replace(ConfigSource::Config(config));
403 self
404 }
405
406 /// Loads a ModelPack split detection model configuration. Use
407 /// `DecoderBuilder.build()` to parse the model configuration.
408 ///
409 /// # Examples
410 /// ```rust
411 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
412 /// # fn main() -> DecoderResult<()> {
413 /// let config0 = configs::Detection {
414 /// anchors: Some(vec![
415 /// [0.13750000298023224, 0.2074074000120163],
416 /// [0.2541666626930237, 0.21481481194496155],
417 /// [0.23125000298023224, 0.35185185074806213],
418 /// ]),
419 /// decoder: configs::DecoderType::ModelPack,
420 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
421 /// shape: vec![1, 17, 30, 18],
422 /// dshape: Vec::new(),
423 /// normalized: Some(true),
424 /// };
425 /// let config1 = configs::Detection {
426 /// anchors: Some(vec![
427 /// [0.36666667461395264, 0.31481480598449707],
428 /// [0.38749998807907104, 0.4740740656852722],
429 /// [0.5333333611488342, 0.644444465637207],
430 /// ]),
431 /// decoder: configs::DecoderType::ModelPack,
432 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
433 /// shape: vec![1, 9, 15, 18],
434 /// dshape: Vec::new(),
435 /// normalized: Some(true),
436 /// };
437 ///
438 /// let decoder = DecoderBuilder::new()
439 /// .with_config_modelpack_det_split(vec![config0, config1])
440 /// .build()?;
441 /// # Ok(())
442 /// # }
443 /// ```
444 pub fn with_config_modelpack_det_split(mut self, boxes: Vec<configs::Detection>) -> Self {
445 let outputs = boxes.into_iter().map(ConfigOutput::Detection).collect();
446 let config = ConfigOutputs {
447 outputs,
448 ..Default::default()
449 };
450 self.config_src.replace(ConfigSource::Config(config));
451 self
452 }
453
454 /// Loads a ModelPack segmentation detection model configuration. Use
455 /// `DecoderBuilder.build()` to parse the model configuration.
456 ///
457 /// # Examples
458 /// ```rust
459 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
460 /// # fn main() -> DecoderResult<()> {
461 /// let boxes_config = configs::Boxes {
462 /// decoder: configs::DecoderType::ModelPack,
463 /// quantization: Some(configs::QuantTuple(0.012345, 26)),
464 /// shape: vec![1, 8400, 1, 4],
465 /// dshape: Vec::new(),
466 /// normalized: Some(true),
467 /// };
468 /// let scores_config = configs::Scores {
469 /// decoder: configs::DecoderType::ModelPack,
470 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
471 /// shape: vec![1, 8400, 2],
472 /// dshape: Vec::new(),
473 /// };
474 /// let seg_config = configs::Segmentation {
475 /// decoder: configs::DecoderType::ModelPack,
476 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
477 /// shape: vec![1, 640, 640, 3],
478 /// dshape: Vec::new(),
479 /// };
480 /// let decoder = DecoderBuilder::new()
481 /// .with_config_modelpack_segdet(boxes_config, scores_config, seg_config)
482 /// .build()?;
483 /// # Ok(())
484 /// # }
485 /// ```
486 pub fn with_config_modelpack_segdet(
487 mut self,
488 boxes: configs::Boxes,
489 scores: configs::Scores,
490 segmentation: configs::Segmentation,
491 ) -> Self {
492 let config = ConfigOutputs {
493 outputs: vec![
494 ConfigOutput::Boxes(boxes),
495 ConfigOutput::Scores(scores),
496 ConfigOutput::Segmentation(segmentation),
497 ],
498 ..Default::default()
499 };
500 self.config_src.replace(ConfigSource::Config(config));
501 self
502 }
503
504 /// Loads a ModelPack segmentation split detection model configuration. Use
505 /// `DecoderBuilder.build()` to parse the model configuration.
506 ///
507 /// # Examples
508 /// ```rust
509 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
510 /// # fn main() -> DecoderResult<()> {
511 /// let config0 = configs::Detection {
512 /// anchors: Some(vec![
513 /// [0.36666667461395264, 0.31481480598449707],
514 /// [0.38749998807907104, 0.4740740656852722],
515 /// [0.5333333611488342, 0.644444465637207],
516 /// ]),
517 /// decoder: configs::DecoderType::ModelPack,
518 /// quantization: Some(configs::QuantTuple(0.08547406643629074, 174)),
519 /// shape: vec![1, 9, 15, 18],
520 /// dshape: Vec::new(),
521 /// normalized: Some(true),
522 /// };
523 /// let config1 = configs::Detection {
524 /// anchors: Some(vec![
525 /// [0.13750000298023224, 0.2074074000120163],
526 /// [0.2541666626930237, 0.21481481194496155],
527 /// [0.23125000298023224, 0.35185185074806213],
528 /// ]),
529 /// decoder: configs::DecoderType::ModelPack,
530 /// quantization: Some(configs::QuantTuple(0.09929127991199493, 183)),
531 /// shape: vec![1, 17, 30, 18],
532 /// dshape: Vec::new(),
533 /// normalized: Some(true),
534 /// };
535 /// let seg_config = configs::Segmentation {
536 /// decoder: configs::DecoderType::ModelPack,
537 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
538 /// shape: vec![1, 640, 640, 2],
539 /// dshape: Vec::new(),
540 /// };
541 /// let decoder = DecoderBuilder::new()
542 /// .with_config_modelpack_segdet_split(vec![config0, config1], seg_config)
543 /// .build()?;
544 /// # Ok(())
545 /// # }
546 /// ```
547 pub fn with_config_modelpack_segdet_split(
548 mut self,
549 boxes: Vec<configs::Detection>,
550 segmentation: configs::Segmentation,
551 ) -> Self {
552 let mut outputs = boxes
553 .into_iter()
554 .map(ConfigOutput::Detection)
555 .collect::<Vec<_>>();
556 outputs.push(ConfigOutput::Segmentation(segmentation));
557 let config = ConfigOutputs {
558 outputs,
559 ..Default::default()
560 };
561 self.config_src.replace(ConfigSource::Config(config));
562 self
563 }
564
565 /// Loads a ModelPack segmentation model configuration. Use
566 /// `DecoderBuilder.build()` to parse the model configuration.
567 ///
568 /// # Examples
569 /// ```rust
570 /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
571 /// # fn main() -> DecoderResult<()> {
572 /// let seg_config = configs::Segmentation {
573 /// decoder: configs::DecoderType::ModelPack,
574 /// quantization: Some(configs::QuantTuple(0.0064123, -31)),
575 /// shape: vec![1, 640, 640, 3],
576 /// dshape: Vec::new(),
577 /// };
578 /// let decoder = DecoderBuilder::new()
579 /// .with_config_modelpack_seg(seg_config)
580 /// .build()?;
581 /// # Ok(())
582 /// # }
583 /// ```
584 pub fn with_config_modelpack_seg(mut self, segmentation: configs::Segmentation) -> Self {
585 let config = ConfigOutputs {
586 outputs: vec![ConfigOutput::Segmentation(segmentation)],
587 ..Default::default()
588 };
589 self.config_src.replace(ConfigSource::Config(config));
590 self
591 }
592
593 /// Add an output to the decoder configuration.
594 ///
595 /// Incrementally builds the model configuration by adding outputs one at
596 /// a time. The decoder resolves the model type from the combination of
597 /// outputs during `build()`.
598 ///
599 /// If `dshape` is non-empty on the output, `shape` is automatically
600 /// derived from it (the size component of each named dimension). This
601 /// prevents conflicts between `shape` and `dshape`.
602 ///
603 /// This uses the programmatic config path. Calling this after
604 /// `with_config_json_str()` or `with_config_yaml_str()` replaces the
605 /// string-based config source.
606 ///
607 /// # Examples
608 /// ```rust
609 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
610 /// # fn main() -> DecoderResult<()> {
611 /// let decoder = DecoderBuilder::new()
612 /// .add_output(ConfigOutput::Scores(configs::Scores {
613 /// decoder: configs::DecoderType::Ultralytics,
614 /// dshape: vec![
615 /// (configs::DimName::Batch, 1),
616 /// (configs::DimName::NumClasses, 80),
617 /// (configs::DimName::NumBoxes, 8400),
618 /// ],
619 /// ..Default::default()
620 /// }))
621 /// .add_output(ConfigOutput::Boxes(configs::Boxes {
622 /// decoder: configs::DecoderType::Ultralytics,
623 /// dshape: vec![
624 /// (configs::DimName::Batch, 1),
625 /// (configs::DimName::BoxCoords, 4),
626 /// (configs::DimName::NumBoxes, 8400),
627 /// ],
628 /// ..Default::default()
629 /// }))
630 /// .build()?;
631 /// # Ok(())
632 /// # }
633 /// ```
634 pub fn add_output(mut self, output: ConfigOutput) -> Self {
635 if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
636 self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
637 }
638 if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
639 config.outputs.push(Self::normalize_output(output));
640 }
641 self
642 }
643
644 /// Sets the decoder version for Ultralytics models.
645 ///
646 /// This is used with `add_output()` to specify the YOLO architecture
647 /// version when it cannot be inferred from the output shapes alone.
648 ///
649 /// - `Yolov5`, `Yolov8`, `Yolo11`: Traditional models requiring external
650 /// NMS
651 /// - `Yolo26`: End-to-end models with NMS embedded in the model graph
652 ///
653 /// # Examples
654 /// ```rust
655 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutput, configs};
656 /// # fn main() -> DecoderResult<()> {
657 /// let decoder = DecoderBuilder::new()
658 /// .add_output(ConfigOutput::Detection(configs::Detection {
659 /// decoder: configs::DecoderType::Ultralytics,
660 /// dshape: vec![
661 /// (configs::DimName::Batch, 1),
662 /// (configs::DimName::NumBoxes, 100),
663 /// (configs::DimName::NumFeatures, 6),
664 /// ],
665 /// ..Default::default()
666 /// }))
667 /// .with_decoder_version(configs::DecoderVersion::Yolo26)
668 /// .build()?;
669 /// # Ok(())
670 /// # }
671 /// ```
672 pub fn with_decoder_version(mut self, version: configs::DecoderVersion) -> Self {
673 if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
674 self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
675 }
676 if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
677 config.decoder_version = Some(version);
678 }
679 self
680 }
681
682 /// Normalize an output: if dshape is non-empty, derive shape from it.
683 fn normalize_output(mut output: ConfigOutput) -> ConfigOutput {
684 fn normalize_shape(shape: &mut Vec<usize>, dshape: &[(configs::DimName, usize)]) {
685 if !dshape.is_empty() {
686 *shape = dshape.iter().map(|(_, size)| *size).collect();
687 }
688 }
689 match &mut output {
690 ConfigOutput::Detection(c) => normalize_shape(&mut c.shape, &c.dshape),
691 ConfigOutput::Boxes(c) => normalize_shape(&mut c.shape, &c.dshape),
692 ConfigOutput::Scores(c) => normalize_shape(&mut c.shape, &c.dshape),
693 ConfigOutput::Protos(c) => normalize_shape(&mut c.shape, &c.dshape),
694 ConfigOutput::Segmentation(c) => normalize_shape(&mut c.shape, &c.dshape),
695 ConfigOutput::MaskCoefficients(c) => normalize_shape(&mut c.shape, &c.dshape),
696 ConfigOutput::Mask(c) => normalize_shape(&mut c.shape, &c.dshape),
697 ConfigOutput::Classes(c) => normalize_shape(&mut c.shape, &c.dshape),
698 }
699 output
700 }
701
702 /// Sets the scores threshold of the decoder
703 ///
704 /// # Examples
705 /// ```rust
706 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
707 /// # fn main() -> DecoderResult<()> {
708 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
709 /// let decoder = DecoderBuilder::new()
710 /// .with_config_json_str(config_json)
711 /// .with_score_threshold(0.654)
712 /// .build()?;
713 /// assert_eq!(decoder.score_threshold, 0.654);
714 /// # Ok(())
715 /// # }
716 /// ```
717 pub fn with_score_threshold(mut self, score_threshold: f32) -> Self {
718 self.score_threshold = score_threshold;
719 self
720 }
721
722 /// Sets the IOU threshold of the decoder. Has no effect when NMS is set to
723 /// `None`
724 ///
725 /// # Examples
726 /// ```rust
727 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
728 /// # fn main() -> DecoderResult<()> {
729 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
730 /// let decoder = DecoderBuilder::new()
731 /// .with_config_json_str(config_json)
732 /// .with_iou_threshold(0.654)
733 /// .build()?;
734 /// assert_eq!(decoder.iou_threshold, 0.654);
735 /// # Ok(())
736 /// # }
737 /// ```
738 pub fn with_iou_threshold(mut self, iou_threshold: f32) -> Self {
739 self.iou_threshold = iou_threshold;
740 self
741 }
742
743 /// Sets the NMS mode for the decoder.
744 ///
745 /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS (default): suppress
746 /// overlapping boxes regardless of class label
747 /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
748 /// share the same class label AND overlap above the IoU threshold
749 /// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
750 ///
751 /// # Examples
752 /// ```rust
753 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::Nms};
754 /// # fn main() -> DecoderResult<()> {
755 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
756 /// let decoder = DecoderBuilder::new()
757 /// .with_config_json_str(config_json)
758 /// .with_nms(Some(Nms::ClassAware))
759 /// .build()?;
760 /// assert_eq!(decoder.nms, Some(Nms::ClassAware));
761 /// # Ok(())
762 /// # }
763 /// ```
764 pub fn with_nms(mut self, nms: Option<configs::Nms>) -> Self {
765 self.nms = nms;
766 self
767 }
768
769 /// Sets the maximum number of candidate boxes fed into NMS after score
770 /// filtering. Uses partial sort (O(N)) to select the top-K candidates,
771 /// dramatically reducing the O(N²) NMS cost when many low-confidence
772 /// proposals pass the threshold (common with mAP eval at 0.001).
773 ///
774 /// Default: 300 (matches Ultralytics `max_det`).
775 ///
776 /// # Examples
777 /// ```rust
778 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
779 /// # fn main() -> DecoderResult<()> {
780 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
781 /// let decoder = DecoderBuilder::new()
782 /// .with_config_json_str(config_json)
783 /// .with_pre_nms_top_k(1000)
784 /// .build()?;
785 /// assert_eq!(decoder.pre_nms_top_k, 1000);
786 /// # Ok(())
787 /// # }
788 /// ```
789 pub fn with_pre_nms_top_k(mut self, pre_nms_top_k: usize) -> Self {
790 self.pre_nms_top_k = pre_nms_top_k;
791 self
792 }
793
794 /// Sets the maximum number of detections returned after NMS.
795 /// Matches the Ultralytics `max_det` parameter.
796 ///
797 /// Default: 300.
798 ///
799 /// # Examples
800 /// ```rust
801 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
802 /// # fn main() -> DecoderResult<()> {
803 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
804 /// let decoder = DecoderBuilder::new()
805 /// .with_config_json_str(config_json)
806 /// .with_max_det(100)
807 /// .build()?;
808 /// assert_eq!(decoder.max_det, 100);
809 /// # Ok(())
810 /// # }
811 /// ```
812 pub fn with_max_det(mut self, max_det: usize) -> Self {
813 self.max_det = max_det;
814 self
815 }
816
817 /// Builds the decoder with the given settings. If the config is a JSON or
818 /// YAML string, this will deserialize the JSON or YAML and then parse the
819 /// configuration information.
820 ///
821 /// # Examples
822 /// ```rust
823 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
824 /// # fn main() -> DecoderResult<()> {
825 /// # let config_json = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.json")).to_string();
826 /// let decoder = DecoderBuilder::new()
827 /// .with_config_json_str(config_json)
828 /// .with_score_threshold(0.654)
829 /// .build()?;
830 /// # Ok(())
831 /// # }
832 /// ```
833 pub fn build(self) -> Result<Decoder, DecoderError> {
834 let (config, decode_program) = match self.config_src {
835 Some(ConfigSource::Json(s)) => Self::build_from_schema(SchemaV2::parse_json(&s)?)?,
836 Some(ConfigSource::Yaml(s)) => Self::build_from_schema(SchemaV2::parse_yaml(&s)?)?,
837 Some(ConfigSource::Config(c)) => (c, None),
838 Some(ConfigSource::Schema(schema)) => Self::build_from_schema(schema)?,
839 None => return Err(DecoderError::NoConfig),
840 };
841
842 // Enforce the physical-order contract: when dshape is present
843 // it must describe the same axes as shape in the same order,
844 // listed from outermost to innermost. Ambiguous-layout roles
845 // (Protos, Boxes, Scores, MaskCoefficients, Classes, Detection)
846 // may still omit dshape when shape is already in the decoder's
847 // canonical order.
848 for output in &config.outputs {
849 Decoder::validate_output_layout(output.into())?;
850 }
851
852 // Extract normalized flag from config outputs
853 let normalized = Self::get_normalized(&config.outputs);
854
855 // Use NMS from config if present, otherwise use builder's NMS setting
856 let nms = config.nms.or(self.nms);
857 let model_type = Self::get_model_type(config)?;
858
859 Ok(Decoder {
860 model_type,
861 iou_threshold: self.iou_threshold,
862 score_threshold: self.score_threshold,
863 nms,
864 pre_nms_top_k: self.pre_nms_top_k,
865 max_det: self.max_det,
866 normalized,
867 decode_program,
868 })
869 }
870
871 /// Validate a [`SchemaV2`] and lower it to the (legacy `ConfigOutputs`,
872 /// optional `DecodeProgram`) pair the rest of `build()` consumes.
873 ///
874 /// Centralises the v2 lowering so JSON, YAML, and direct
875 /// `with_schema` callers all go through the same validation and
876 /// merge-program construction. `SchemaV2::parse_json` /
877 /// `parse_yaml` already auto-detect v1 vs v2 input and return a v2
878 /// schema either way (v1 inputs are upgraded in memory via
879 /// `from_v1`), so this helper is the sole place that turns a
880 /// schema into builder-ready state.
881 fn build_from_schema(
882 schema: SchemaV2,
883 ) -> Result<(ConfigOutputs, Option<DecodeProgram>), DecoderError> {
884 schema.validate()?;
885 let program = DecodeProgram::try_from_schema(&schema)?;
886 let legacy = schema.to_legacy_config_outputs()?;
887 Ok((legacy, program))
888 }
889
890 /// Extracts the normalized flag from config outputs.
891 /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
892 /// - `Some(false)`: Boxes are in pixel coordinates
893 /// - `None`: Unknown (not specified in config), caller must infer
894 fn get_normalized(outputs: &[ConfigOutput]) -> Option<bool> {
895 for output in outputs {
896 match output {
897 ConfigOutput::Detection(det) => return det.normalized,
898 ConfigOutput::Boxes(boxes) => return boxes.normalized,
899 _ => {}
900 }
901 }
902 None // not specified
903 }
904
905 fn get_model_type(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
906 // yolo or modelpack
907 let mut yolo = false;
908 let mut modelpack = false;
909 for c in &configs.outputs {
910 match c.decoder() {
911 DecoderType::ModelPack => modelpack = true,
912 DecoderType::Ultralytics => yolo = true,
913 }
914 }
915 match (modelpack, yolo) {
916 (true, true) => Err(DecoderError::InvalidConfig(
917 "Both ModelPack and Yolo outputs found in config".to_string(),
918 )),
919 (true, false) => Self::get_model_type_modelpack(configs),
920 (false, true) => Self::get_model_type_yolo(configs),
921 (false, false) => Err(DecoderError::InvalidConfig(
922 "No outputs found in config".to_string(),
923 )),
924 }
925 }
926
927 fn get_model_type_yolo(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
928 let mut boxes = None;
929 let mut protos = None;
930 let mut split_boxes = None;
931 let mut split_scores = None;
932 let mut split_mask_coeff = None;
933 let mut split_classes = None;
934 for c in configs.outputs {
935 match c {
936 ConfigOutput::Detection(detection) => boxes = Some(detection),
937 ConfigOutput::Segmentation(_) => {
938 return Err(DecoderError::InvalidConfig(
939 "Invalid Segmentation output with Yolo decoder".to_string(),
940 ));
941 }
942 ConfigOutput::Protos(protos_) => protos = Some(protos_),
943 ConfigOutput::Mask(_) => {
944 return Err(DecoderError::InvalidConfig(
945 "Invalid Mask output with Yolo decoder".to_string(),
946 ));
947 }
948 ConfigOutput::Scores(scores) => split_scores = Some(scores),
949 ConfigOutput::Boxes(boxes) => split_boxes = Some(boxes),
950 ConfigOutput::MaskCoefficients(mask_coeff) => split_mask_coeff = Some(mask_coeff),
951 ConfigOutput::Classes(classes) => split_classes = Some(classes),
952 }
953 }
954
955 // Use end-to-end model types when:
956 // 1. decoder_version is explicitly set to Yolo26 (definitive), OR
957 // decoder_version is not set but the dshapes are (batch, num_boxes,
958 // num_features)
959 let is_end_to_end_dshape = boxes.as_ref().is_some_and(|b| {
960 let dims = b.dshape.iter().map(|(d, _)| *d).collect::<Vec<_>>();
961 dims == vec![DimName::Batch, DimName::NumBoxes, DimName::NumFeatures]
962 });
963
964 let is_end_to_end = configs
965 .decoder_version
966 .map(|v| v.is_end_to_end())
967 .unwrap_or(is_end_to_end_dshape);
968
969 if is_end_to_end {
970 if let Some(boxes) = boxes {
971 if let Some(protos) = protos {
972 Self::verify_yolo_seg_det_26(&boxes, &protos)?;
973 return Ok(ModelType::YoloEndToEndSegDet { boxes, protos });
974 } else {
975 Self::verify_yolo_det_26(&boxes)?;
976 return Ok(ModelType::YoloEndToEndDet { boxes });
977 }
978 } else if let (Some(split_boxes), Some(split_scores), Some(split_classes)) =
979 (split_boxes, split_scores, split_classes)
980 {
981 if let (Some(split_mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
982 Self::verify_yolo_split_end_to_end_segdet(
983 &split_boxes,
984 &split_scores,
985 &split_classes,
986 &split_mask_coeff,
987 &protos,
988 )?;
989 return Ok(ModelType::YoloSplitEndToEndSegDet {
990 boxes: split_boxes,
991 scores: split_scores,
992 classes: split_classes,
993 mask_coeff: split_mask_coeff,
994 protos,
995 });
996 }
997 Self::verify_yolo_split_end_to_end_det(
998 &split_boxes,
999 &split_scores,
1000 &split_classes,
1001 )?;
1002 return Ok(ModelType::YoloSplitEndToEndDet {
1003 boxes: split_boxes,
1004 scores: split_scores,
1005 classes: split_classes,
1006 });
1007 } else {
1008 return Err(DecoderError::InvalidConfig(
1009 "Invalid Yolo end-to-end model outputs".to_string(),
1010 ));
1011 }
1012 }
1013
1014 if let Some(boxes) = boxes {
1015 match (split_mask_coeff, protos) {
1016 (Some(mask_coeff), Some(protos)) => {
1017 // 2-way split: combined detection + separate mask_coeff + protos
1018 Self::verify_yolo_seg_det_2way(&boxes, &mask_coeff, &protos)?;
1019 Ok(ModelType::YoloSegDet2Way {
1020 boxes,
1021 mask_coeff,
1022 protos,
1023 })
1024 }
1025 (_, Some(protos)) => {
1026 // Unsplit: mask_coefs embedded in detection tensor
1027 Self::verify_yolo_seg_det(&boxes, &protos)?;
1028 Ok(ModelType::YoloSegDet { boxes, protos })
1029 }
1030 _ => {
1031 Self::verify_yolo_det(&boxes)?;
1032 Ok(ModelType::YoloDet { boxes })
1033 }
1034 }
1035 } else if let (Some(boxes), Some(scores)) = (split_boxes, split_scores) {
1036 if let (Some(mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
1037 Self::verify_yolo_split_segdet(&boxes, &scores, &mask_coeff, &protos)?;
1038 Ok(ModelType::YoloSplitSegDet {
1039 boxes,
1040 scores,
1041 mask_coeff,
1042 protos,
1043 })
1044 } else {
1045 Self::verify_yolo_split_det(&boxes, &scores)?;
1046 Ok(ModelType::YoloSplitDet { boxes, scores })
1047 }
1048 } else {
1049 Err(DecoderError::InvalidConfig(
1050 "Invalid Yolo model outputs".to_string(),
1051 ))
1052 }
1053 }
1054
1055 fn verify_yolo_det(detect: &configs::Detection) -> Result<(), DecoderError> {
1056 if detect.shape.len() != 3 {
1057 return Err(DecoderError::InvalidConfig(format!(
1058 "Invalid Yolo Detection shape {:?}",
1059 detect.shape
1060 )));
1061 }
1062
1063 Self::verify_dshapes(
1064 &detect.dshape,
1065 &detect.shape,
1066 "Detection",
1067 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1068 )?;
1069 if !detect.dshape.is_empty() {
1070 Self::get_class_count(&detect.dshape, None, None)?;
1071 } else {
1072 Self::get_class_count_no_dshape(detect.into(), None)?;
1073 }
1074
1075 Ok(())
1076 }
1077
1078 fn verify_yolo_det_26(detect: &configs::Detection) -> Result<(), DecoderError> {
1079 if detect.shape.len() != 3 {
1080 return Err(DecoderError::InvalidConfig(format!(
1081 "Invalid Yolo Detection shape {:?}",
1082 detect.shape
1083 )));
1084 }
1085
1086 Self::verify_dshapes(
1087 &detect.dshape,
1088 &detect.shape,
1089 "Detection",
1090 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1091 )?;
1092
1093 if !detect.shape.contains(&6) {
1094 return Err(DecoderError::InvalidConfig(
1095 "Yolo26 Detection must have 6 features".to_string(),
1096 ));
1097 }
1098
1099 Ok(())
1100 }
1101
1102 fn verify_yolo_seg_det(
1103 detection: &configs::Detection,
1104 protos: &configs::Protos,
1105 ) -> Result<(), DecoderError> {
1106 if detection.shape.len() != 3 {
1107 return Err(DecoderError::InvalidConfig(format!(
1108 "Invalid Yolo Detection shape {:?}",
1109 detection.shape
1110 )));
1111 }
1112 if protos.shape.len() != 4 {
1113 return Err(DecoderError::InvalidConfig(format!(
1114 "Invalid Yolo Protos shape {:?}",
1115 protos.shape
1116 )));
1117 }
1118
1119 Self::verify_dshapes(
1120 &detection.dshape,
1121 &detection.shape,
1122 "Detection",
1123 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1124 )?;
1125 Self::verify_dshapes(
1126 &protos.dshape,
1127 &protos.shape,
1128 "Protos",
1129 &[
1130 DimName::Batch,
1131 DimName::Height,
1132 DimName::Width,
1133 DimName::NumProtos,
1134 ],
1135 )?;
1136
1137 let protos_count = Self::get_protos_count(&protos.dshape)
1138 .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1139 log::debug!("Protos count: {}", protos_count);
1140 log::debug!("Detection dshape: {:?}", detection.dshape);
1141 let classes = if !detection.dshape.is_empty() {
1142 Self::get_class_count(&detection.dshape, Some(protos_count), None)?
1143 } else {
1144 Self::get_class_count_no_dshape(detection.into(), Some(protos_count))?
1145 };
1146
1147 if classes == 0 {
1148 return Err(DecoderError::InvalidConfig(
1149 "Yolo Segmentation Detection has zero classes".to_string(),
1150 ));
1151 }
1152
1153 Ok(())
1154 }
1155
1156 fn verify_yolo_seg_det_2way(
1157 detection: &configs::Detection,
1158 mask_coeff: &configs::MaskCoefficients,
1159 protos: &configs::Protos,
1160 ) -> Result<(), DecoderError> {
1161 if detection.shape.len() != 3 {
1162 return Err(DecoderError::InvalidConfig(format!(
1163 "Invalid Yolo 2-Way Detection shape {:?}",
1164 detection.shape
1165 )));
1166 }
1167 if mask_coeff.shape.len() != 3 {
1168 return Err(DecoderError::InvalidConfig(format!(
1169 "Invalid Yolo 2-Way Mask Coefficients shape {:?}",
1170 mask_coeff.shape
1171 )));
1172 }
1173 if protos.shape.len() != 4 {
1174 return Err(DecoderError::InvalidConfig(format!(
1175 "Invalid Yolo 2-Way Protos shape {:?}",
1176 protos.shape
1177 )));
1178 }
1179
1180 Self::verify_dshapes(
1181 &detection.dshape,
1182 &detection.shape,
1183 "Detection",
1184 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1185 )?;
1186 Self::verify_dshapes(
1187 &mask_coeff.dshape,
1188 &mask_coeff.shape,
1189 "Mask Coefficients",
1190 &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1191 )?;
1192 Self::verify_dshapes(
1193 &protos.dshape,
1194 &protos.shape,
1195 "Protos",
1196 &[
1197 DimName::Batch,
1198 DimName::Height,
1199 DimName::Width,
1200 DimName::NumProtos,
1201 ],
1202 )?;
1203
1204 // Validate num_boxes match between detection and mask_coeff
1205 let det_num = Self::get_box_count(&detection.dshape).unwrap_or(detection.shape[2]);
1206 let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1207 if det_num != mask_num {
1208 return Err(DecoderError::InvalidConfig(format!(
1209 "Yolo 2-Way Detection num_boxes {} incompatible with Mask Coefficients num_boxes {}",
1210 det_num, mask_num
1211 )));
1212 }
1213
1214 // Validate mask_coeff channels match protos channels
1215 let mask_channels = if !mask_coeff.dshape.is_empty() {
1216 Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1217 DecoderError::InvalidConfig(
1218 "Could not find num_protos in mask_coeff config".to_string(),
1219 )
1220 })?
1221 } else {
1222 mask_coeff.shape[1]
1223 };
1224 let proto_channels = if !protos.dshape.is_empty() {
1225 Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1226 DecoderError::InvalidConfig(
1227 "Could not find num_protos in protos config".to_string(),
1228 )
1229 })?
1230 } else {
1231 protos.shape[1].min(protos.shape[3])
1232 };
1233 if mask_channels != proto_channels {
1234 return Err(DecoderError::InvalidConfig(format!(
1235 "Yolo 2-Way Protos channels {} incompatible with Mask Coefficients channels {}",
1236 proto_channels, mask_channels
1237 )));
1238 }
1239
1240 // Validate detection has classes (nc+4 features, no mask_coefs embedded)
1241 if !detection.dshape.is_empty() {
1242 Self::get_class_count(&detection.dshape, None, None)?;
1243 } else {
1244 Self::get_class_count_no_dshape(detection.into(), None)?;
1245 }
1246
1247 Ok(())
1248 }
1249
1250 fn verify_yolo_seg_det_26(
1251 detection: &configs::Detection,
1252 protos: &configs::Protos,
1253 ) -> Result<(), DecoderError> {
1254 if detection.shape.len() != 3 {
1255 return Err(DecoderError::InvalidConfig(format!(
1256 "Invalid Yolo Detection shape {:?}",
1257 detection.shape
1258 )));
1259 }
1260 if protos.shape.len() != 4 {
1261 return Err(DecoderError::InvalidConfig(format!(
1262 "Invalid Yolo Protos shape {:?}",
1263 protos.shape
1264 )));
1265 }
1266
1267 Self::verify_dshapes(
1268 &detection.dshape,
1269 &detection.shape,
1270 "Detection",
1271 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1272 )?;
1273 Self::verify_dshapes(
1274 &protos.dshape,
1275 &protos.shape,
1276 "Protos",
1277 &[
1278 DimName::Batch,
1279 DimName::Height,
1280 DimName::Width,
1281 DimName::NumProtos,
1282 ],
1283 )?;
1284
1285 let protos_count = Self::get_protos_count(&protos.dshape)
1286 .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1287 log::debug!("Protos count: {}", protos_count);
1288 log::debug!("Detection dshape: {:?}", detection.dshape);
1289
1290 if !detection.shape.contains(&(6 + protos_count)) {
1291 return Err(DecoderError::InvalidConfig(format!(
1292 "Yolo26 Segmentation Detection must have num_features be 6 + num_protos = {}",
1293 6 + protos_count
1294 )));
1295 }
1296
1297 Ok(())
1298 }
1299
1300 fn verify_yolo_split_det(
1301 boxes: &configs::Boxes,
1302 scores: &configs::Scores,
1303 ) -> Result<(), DecoderError> {
1304 if boxes.shape.len() != 3 {
1305 return Err(DecoderError::InvalidConfig(format!(
1306 "Invalid Yolo Split Boxes shape {:?}",
1307 boxes.shape
1308 )));
1309 }
1310 if scores.shape.len() != 3 {
1311 return Err(DecoderError::InvalidConfig(format!(
1312 "Invalid Yolo Split Scores shape {:?}",
1313 scores.shape
1314 )));
1315 }
1316
1317 Self::verify_dshapes(
1318 &boxes.dshape,
1319 &boxes.shape,
1320 "Boxes",
1321 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1322 )?;
1323 Self::verify_dshapes(
1324 &scores.dshape,
1325 &scores.shape,
1326 "Scores",
1327 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1328 )?;
1329
1330 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1331 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1332
1333 if boxes_num != scores_num {
1334 return Err(DecoderError::InvalidConfig(format!(
1335 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1336 boxes_num, scores_num
1337 )));
1338 }
1339
1340 Ok(())
1341 }
1342
1343 fn verify_yolo_split_segdet(
1344 boxes: &configs::Boxes,
1345 scores: &configs::Scores,
1346 mask_coeff: &configs::MaskCoefficients,
1347 protos: &configs::Protos,
1348 ) -> Result<(), DecoderError> {
1349 if boxes.shape.len() != 3 {
1350 return Err(DecoderError::InvalidConfig(format!(
1351 "Invalid Yolo Split Boxes shape {:?}",
1352 boxes.shape
1353 )));
1354 }
1355 if scores.shape.len() != 3 {
1356 return Err(DecoderError::InvalidConfig(format!(
1357 "Invalid Yolo Split Scores shape {:?}",
1358 scores.shape
1359 )));
1360 }
1361
1362 if mask_coeff.shape.len() != 3 {
1363 return Err(DecoderError::InvalidConfig(format!(
1364 "Invalid Yolo Split Mask Coefficients shape {:?}",
1365 mask_coeff.shape
1366 )));
1367 }
1368
1369 if protos.shape.len() != 4 {
1370 return Err(DecoderError::InvalidConfig(format!(
1371 "Invalid Yolo Protos shape {:?}",
1372 mask_coeff.shape
1373 )));
1374 }
1375
1376 Self::verify_dshapes(
1377 &boxes.dshape,
1378 &boxes.shape,
1379 "Boxes",
1380 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1381 )?;
1382 Self::verify_dshapes(
1383 &scores.dshape,
1384 &scores.shape,
1385 "Scores",
1386 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1387 )?;
1388 Self::verify_dshapes(
1389 &mask_coeff.dshape,
1390 &mask_coeff.shape,
1391 "Mask Coefficients",
1392 &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1393 )?;
1394 Self::verify_dshapes(
1395 &protos.dshape,
1396 &protos.shape,
1397 "Protos",
1398 &[
1399 DimName::Batch,
1400 DimName::Height,
1401 DimName::Width,
1402 DimName::NumProtos,
1403 ],
1404 )?;
1405
1406 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1407 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1408 let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1409
1410 let mask_channels = if !mask_coeff.dshape.is_empty() {
1411 Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1412 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1413 })?
1414 } else {
1415 mask_coeff.shape[1]
1416 };
1417 let proto_channels = if !protos.dshape.is_empty() {
1418 Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1419 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1420 })?
1421 } else {
1422 protos.shape[1].min(protos.shape[3])
1423 };
1424
1425 if boxes_num != scores_num {
1426 return Err(DecoderError::InvalidConfig(format!(
1427 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1428 boxes_num, scores_num
1429 )));
1430 }
1431
1432 if boxes_num != mask_num {
1433 return Err(DecoderError::InvalidConfig(format!(
1434 "Yolo Split Detection Boxes num {} incompatible with Mask Coefficients num {}",
1435 boxes_num, mask_num
1436 )));
1437 }
1438
1439 if proto_channels != mask_channels {
1440 return Err(DecoderError::InvalidConfig(format!(
1441 "Yolo Protos channels {} incompatible with Mask Coefficients channels {}",
1442 proto_channels, mask_channels
1443 )));
1444 }
1445
1446 Ok(())
1447 }
1448
1449 fn verify_yolo_split_end_to_end_det(
1450 boxes: &configs::Boxes,
1451 scores: &configs::Scores,
1452 classes: &configs::Classes,
1453 ) -> Result<(), DecoderError> {
1454 if boxes.shape.len() != 3 || !boxes.shape.contains(&4) {
1455 return Err(DecoderError::InvalidConfig(format!(
1456 "Split end-to-end boxes must be [batch, N, 4], got {:?}",
1457 boxes.shape
1458 )));
1459 }
1460 if scores.shape.len() != 3 || !scores.shape.contains(&1) {
1461 return Err(DecoderError::InvalidConfig(format!(
1462 "Split end-to-end scores must be [batch, N, 1], got {:?}",
1463 scores.shape
1464 )));
1465 }
1466 if classes.shape.len() != 3 || !classes.shape.contains(&1) {
1467 return Err(DecoderError::InvalidConfig(format!(
1468 "Split end-to-end classes must be [batch, N, 1], got {:?}",
1469 classes.shape
1470 )));
1471 }
1472 Ok(())
1473 }
1474
1475 fn verify_yolo_split_end_to_end_segdet(
1476 boxes: &configs::Boxes,
1477 scores: &configs::Scores,
1478 classes: &configs::Classes,
1479 mask_coeff: &configs::MaskCoefficients,
1480 protos: &configs::Protos,
1481 ) -> Result<(), DecoderError> {
1482 Self::verify_yolo_split_end_to_end_det(boxes, scores, classes)?;
1483 if mask_coeff.shape.len() != 3 {
1484 return Err(DecoderError::InvalidConfig(format!(
1485 "Invalid split end-to-end mask coefficients shape {:?}",
1486 mask_coeff.shape
1487 )));
1488 }
1489 if protos.shape.len() != 4 {
1490 return Err(DecoderError::InvalidConfig(format!(
1491 "Invalid protos shape {:?}",
1492 protos.shape
1493 )));
1494 }
1495 Ok(())
1496 }
1497
1498 fn get_model_type_modelpack(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1499 let mut split_decoders = Vec::new();
1500 let mut segment_ = None;
1501 let mut scores_ = None;
1502 let mut boxes_ = None;
1503 for c in configs.outputs {
1504 match c {
1505 ConfigOutput::Detection(detection) => split_decoders.push(detection),
1506 ConfigOutput::Segmentation(segmentation) => segment_ = Some(segmentation),
1507 ConfigOutput::Mask(_) => {}
1508 ConfigOutput::Protos(_) => {
1509 return Err(DecoderError::InvalidConfig(
1510 "ModelPack should not have protos".to_string(),
1511 ));
1512 }
1513 ConfigOutput::Scores(scores) => scores_ = Some(scores),
1514 ConfigOutput::Boxes(boxes) => boxes_ = Some(boxes),
1515 ConfigOutput::MaskCoefficients(_) => {
1516 return Err(DecoderError::InvalidConfig(
1517 "ModelPack should not have mask coefficients".to_string(),
1518 ));
1519 }
1520 ConfigOutput::Classes(_) => {
1521 return Err(DecoderError::InvalidConfig(
1522 "ModelPack should not have classes output".to_string(),
1523 ));
1524 }
1525 }
1526 }
1527
1528 if let Some(segmentation) = segment_ {
1529 if !split_decoders.is_empty() {
1530 let classes = Self::verify_modelpack_split_det(&split_decoders)?;
1531 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1532 Ok(ModelType::ModelPackSegDetSplit {
1533 detection: split_decoders,
1534 segmentation,
1535 })
1536 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1537 let classes = Self::verify_modelpack_det(&boxes, &scores)?;
1538 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1539 Ok(ModelType::ModelPackSegDet {
1540 boxes,
1541 scores,
1542 segmentation,
1543 })
1544 } else {
1545 Self::verify_modelpack_seg(&segmentation, None)?;
1546 Ok(ModelType::ModelPackSeg { segmentation })
1547 }
1548 } else if !split_decoders.is_empty() {
1549 Self::verify_modelpack_split_det(&split_decoders)?;
1550 Ok(ModelType::ModelPackDetSplit {
1551 detection: split_decoders,
1552 })
1553 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1554 Self::verify_modelpack_det(&boxes, &scores)?;
1555 Ok(ModelType::ModelPackDet { boxes, scores })
1556 } else {
1557 Err(DecoderError::InvalidConfig(
1558 "Invalid ModelPack model outputs".to_string(),
1559 ))
1560 }
1561 }
1562
1563 fn verify_modelpack_det(
1564 boxes: &configs::Boxes,
1565 scores: &configs::Scores,
1566 ) -> Result<usize, DecoderError> {
1567 if boxes.shape.len() != 4 {
1568 return Err(DecoderError::InvalidConfig(format!(
1569 "Invalid ModelPack Boxes shape {:?}",
1570 boxes.shape
1571 )));
1572 }
1573 if scores.shape.len() != 3 {
1574 return Err(DecoderError::InvalidConfig(format!(
1575 "Invalid ModelPack Scores shape {:?}",
1576 scores.shape
1577 )));
1578 }
1579
1580 Self::verify_dshapes(
1581 &boxes.dshape,
1582 &boxes.shape,
1583 "Boxes",
1584 &[
1585 DimName::Batch,
1586 DimName::NumBoxes,
1587 DimName::Padding,
1588 DimName::BoxCoords,
1589 ],
1590 )?;
1591 Self::verify_dshapes(
1592 &scores.dshape,
1593 &scores.shape,
1594 "Scores",
1595 &[DimName::Batch, DimName::NumBoxes, DimName::NumClasses],
1596 )?;
1597
1598 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[1]);
1599 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[1]);
1600
1601 if boxes_num != scores_num {
1602 return Err(DecoderError::InvalidConfig(format!(
1603 "ModelPack Detection Boxes num {} incompatible with Scores num {}",
1604 boxes_num, scores_num
1605 )));
1606 }
1607
1608 let num_classes = if !scores.dshape.is_empty() {
1609 Self::get_class_count(&scores.dshape, None, None)?
1610 } else {
1611 Self::get_class_count_no_dshape(scores.into(), None)?
1612 };
1613
1614 Ok(num_classes)
1615 }
1616
1617 fn verify_modelpack_split_det(boxes: &[configs::Detection]) -> Result<usize, DecoderError> {
1618 let mut num_classes = None;
1619 for b in boxes {
1620 let Some(num_anchors) = b.anchors.as_ref().map(|a| a.len()) else {
1621 return Err(DecoderError::InvalidConfig(
1622 "ModelPack Split Detection missing anchors".to_string(),
1623 ));
1624 };
1625
1626 if num_anchors == 0 {
1627 return Err(DecoderError::InvalidConfig(
1628 "ModelPack Split Detection has zero anchors".to_string(),
1629 ));
1630 }
1631
1632 if b.shape.len() != 4 {
1633 return Err(DecoderError::InvalidConfig(format!(
1634 "Invalid ModelPack Split Detection shape {:?}",
1635 b.shape
1636 )));
1637 }
1638
1639 Self::verify_dshapes(
1640 &b.dshape,
1641 &b.shape,
1642 "Split Detection",
1643 &[
1644 DimName::Batch,
1645 DimName::Height,
1646 DimName::Width,
1647 DimName::NumAnchorsXFeatures,
1648 ],
1649 )?;
1650 let classes = if !b.dshape.is_empty() {
1651 Self::get_class_count(&b.dshape, None, Some(num_anchors))?
1652 } else {
1653 Self::get_class_count_no_dshape(b.into(), None)?
1654 };
1655
1656 match num_classes {
1657 Some(n) => {
1658 if n != classes {
1659 return Err(DecoderError::InvalidConfig(format!(
1660 "ModelPack Split Detection inconsistent number of classes: previous {}, current {}",
1661 n, classes
1662 )));
1663 }
1664 }
1665 None => {
1666 num_classes = Some(classes);
1667 }
1668 }
1669 }
1670
1671 Ok(num_classes.unwrap_or(0))
1672 }
1673
1674 fn verify_modelpack_seg(
1675 segmentation: &configs::Segmentation,
1676 classes: Option<usize>,
1677 ) -> Result<(), DecoderError> {
1678 if segmentation.shape.len() != 4 {
1679 return Err(DecoderError::InvalidConfig(format!(
1680 "Invalid ModelPack Segmentation shape {:?}",
1681 segmentation.shape
1682 )));
1683 }
1684 Self::verify_dshapes(
1685 &segmentation.dshape,
1686 &segmentation.shape,
1687 "Segmentation",
1688 &[
1689 DimName::Batch,
1690 DimName::Height,
1691 DimName::Width,
1692 DimName::NumClasses,
1693 ],
1694 )?;
1695
1696 if let Some(classes) = classes {
1697 let seg_classes = if !segmentation.dshape.is_empty() {
1698 Self::get_class_count(&segmentation.dshape, None, None)?
1699 } else {
1700 Self::get_class_count_no_dshape(segmentation.into(), None)?
1701 };
1702
1703 if seg_classes != classes + 1 {
1704 return Err(DecoderError::InvalidConfig(format!(
1705 "ModelPack Segmentation channels {} incompatible with number of classes {}",
1706 seg_classes, classes
1707 )));
1708 }
1709 }
1710 Ok(())
1711 }
1712
1713 // verifies that dshapes match the given shape
1714 fn verify_dshapes(
1715 dshape: &[(DimName, usize)],
1716 shape: &[usize],
1717 name: &str,
1718 dims: &[DimName],
1719 ) -> Result<(), DecoderError> {
1720 for s in shape {
1721 if *s == 0 {
1722 return Err(DecoderError::InvalidConfig(format!(
1723 "{} shape has zero dimension",
1724 name
1725 )));
1726 }
1727 }
1728
1729 if shape.len() != dims.len() {
1730 return Err(DecoderError::InvalidConfig(format!(
1731 "{} shape length {} does not match expected dims length {}",
1732 name,
1733 shape.len(),
1734 dims.len()
1735 )));
1736 }
1737
1738 if dshape.is_empty() {
1739 return Ok(());
1740 }
1741 // check the dshape lengths match the shape lengths
1742 if dshape.len() != shape.len() {
1743 return Err(DecoderError::InvalidConfig(format!(
1744 "{} dshape length does not match shape length",
1745 name
1746 )));
1747 }
1748
1749 // check the dshape values match the shape values
1750 for ((dim_name, dim_size), shape_size) in dshape.iter().zip(shape) {
1751 if dim_size != shape_size {
1752 return Err(DecoderError::InvalidConfig(format!(
1753 "{} dshape dimension {} size {} does not match shape size {}",
1754 name, dim_name, dim_size, shape_size
1755 )));
1756 }
1757 if *dim_name == DimName::Padding && *dim_size != 1 {
1758 return Err(DecoderError::InvalidConfig(
1759 "Padding dimension size must be 1".to_string(),
1760 ));
1761 }
1762
1763 if *dim_name == DimName::BoxCoords && *dim_size != 4 {
1764 return Err(DecoderError::InvalidConfig(
1765 "BoxCoords dimension size must be 4".to_string(),
1766 ));
1767 }
1768 }
1769
1770 let dims_present = HashSet::<DimName>::from_iter(dshape.iter().map(|(name, _)| *name));
1771 for dim in dims {
1772 if !dims_present.contains(dim) {
1773 return Err(DecoderError::InvalidConfig(format!(
1774 "{} dshape missing required dimension {:?}",
1775 name, dim
1776 )));
1777 }
1778 }
1779
1780 Ok(())
1781 }
1782
1783 fn get_box_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1784 for (dim_name, dim_size) in dshape {
1785 if *dim_name == DimName::NumBoxes {
1786 return Some(*dim_size);
1787 }
1788 }
1789 None
1790 }
1791
1792 fn get_class_count_no_dshape(
1793 config: ConfigOutputRef,
1794 protos: Option<usize>,
1795 ) -> Result<usize, DecoderError> {
1796 match config {
1797 ConfigOutputRef::Detection(detection) => match detection.decoder {
1798 DecoderType::Ultralytics => {
1799 if detection.shape[1] <= 4 + protos.unwrap_or(0) {
1800 return Err(DecoderError::InvalidConfig(format!(
1801 "Invalid shape: Yolo num_features {} must be greater than {}",
1802 detection.shape[1],
1803 4 + protos.unwrap_or(0),
1804 )));
1805 }
1806 Ok(detection.shape[1] - 4 - protos.unwrap_or(0))
1807 }
1808 DecoderType::ModelPack => {
1809 let Some(num_anchors) = detection.anchors.as_ref().map(|a| a.len()) else {
1810 return Err(DecoderError::Internal(
1811 "ModelPack Detection missing anchors".to_string(),
1812 ));
1813 };
1814 let anchors_x_features = detection.shape[3];
1815 if anchors_x_features <= num_anchors * 5 {
1816 return Err(DecoderError::InvalidConfig(format!(
1817 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
1818 anchors_x_features,
1819 num_anchors * 5,
1820 )));
1821 }
1822
1823 if !anchors_x_features.is_multiple_of(num_anchors) {
1824 return Err(DecoderError::InvalidConfig(format!(
1825 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
1826 anchors_x_features, num_anchors
1827 )));
1828 }
1829 Ok(anchors_x_features / num_anchors - 5)
1830 }
1831 },
1832
1833 ConfigOutputRef::Scores(scores) => match scores.decoder {
1834 DecoderType::Ultralytics => Ok(scores.shape[1]),
1835 DecoderType::ModelPack => Ok(scores.shape[2]),
1836 },
1837 ConfigOutputRef::Segmentation(seg) => Ok(seg.shape[3]),
1838 _ => Err(DecoderError::Internal(
1839 "Attempted to get class count from unsupported config output".to_owned(),
1840 )),
1841 }
1842 }
1843
1844 // get the class count from dshape or calculate from num_features
1845 fn get_class_count(
1846 dshape: &[(DimName, usize)],
1847 protos: Option<usize>,
1848 anchors: Option<usize>,
1849 ) -> Result<usize, DecoderError> {
1850 if dshape.is_empty() {
1851 return Ok(0);
1852 }
1853 // if it has num_classes in dshape, return it
1854 for (dim_name, dim_size) in dshape {
1855 if *dim_name == DimName::NumClasses {
1856 return Ok(*dim_size);
1857 }
1858 }
1859
1860 // number of classes can be calculated from num_features - 4 for yolo. If the
1861 // model has protos, we also subtract the number of protos.
1862 for (dim_name, dim_size) in dshape {
1863 if *dim_name == DimName::NumFeatures {
1864 let protos = protos.unwrap_or(0);
1865 if protos + 4 >= *dim_size {
1866 return Err(DecoderError::InvalidConfig(format!(
1867 "Invalid shape: Yolo num_features {} must be greater than {}",
1868 *dim_size,
1869 protos + 4,
1870 )));
1871 }
1872 return Ok(*dim_size - 4 - protos);
1873 }
1874 }
1875
1876 // number of classes can be calculated from number of anchors for modelpack
1877 // split detection
1878 if let Some(num_anchors) = anchors {
1879 for (dim_name, dim_size) in dshape {
1880 if *dim_name == DimName::NumAnchorsXFeatures {
1881 let anchors_x_features = *dim_size;
1882 if anchors_x_features <= num_anchors * 5 {
1883 return Err(DecoderError::InvalidConfig(format!(
1884 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
1885 anchors_x_features,
1886 num_anchors * 5,
1887 )));
1888 }
1889
1890 if !anchors_x_features.is_multiple_of(num_anchors) {
1891 return Err(DecoderError::InvalidConfig(format!(
1892 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
1893 anchors_x_features, num_anchors
1894 )));
1895 }
1896 return Ok((anchors_x_features / num_anchors) - 5);
1897 }
1898 }
1899 }
1900 Err(DecoderError::InvalidConfig(
1901 "Cannot determine number of classes from dshape".to_owned(),
1902 ))
1903 }
1904
1905 fn get_protos_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1906 for (dim_name, dim_size) in dshape {
1907 if *dim_name == DimName::NumProtos {
1908 return Some(*dim_size);
1909 }
1910 }
1911 None
1912 }
1913}