1use std::collections::HashSet;
5
6use ndarray::{s, Array3, ArrayView, ArrayViewD, Dimension};
7use ndarray_stats::QuantileExt;
8use num_traits::{AsPrimitive, Float};
9use serde::{Deserialize, Serialize};
10
11use crate::{
12 configs::{DecoderType, DimName, ModelType, QuantTuple},
13 dequantize_ndarray,
14 modelpack::{
15 decode_modelpack_det, decode_modelpack_float, decode_modelpack_split_float,
16 ModelPackDetectionConfig,
17 },
18 yolo::{
19 decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float, decode_yolo_segdet_quant,
20 decode_yolo_split_det_float, decode_yolo_split_det_quant, decode_yolo_split_segdet_float,
21 impl_yolo_split_segdet_quant_get_boxes, impl_yolo_split_segdet_quant_process_masks,
22 },
23 DecoderError, DecoderVersion, DetectBox, ProtoData, Quantization, Segmentation, XYWH,
24};
25
/// Top-level decoder configuration: the list of model outputs plus optional
/// NMS and decoder-version overrides.
///
/// Usually deserialized from JSON/YAML in `DecoderBuilder::build`, or assembled
/// in code through the `DecoderBuilder::with_config_*` / `add_output` helpers.
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
pub struct ConfigOutputs {
    /// Model output descriptions; a missing key deserializes to an empty list.
    #[serde(default)]
    pub outputs: Vec<ConfigOutput>,
    /// NMS mode. When `Some`, it takes precedence over the builder's `with_nms`
    /// value (`config.nms.or(builder.nms)`); omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub nms: Option<configs::Nms>,
    /// Explicit decoder version. When `None`, the YOLO path infers end-to-end
    /// decoding from the detection output's `dshape`; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub decoder_version: Option<configs::DecoderVersion>,
}
59
/// One model output in the decoder configuration.
///
/// Internally tagged: the serialized form carries a `type` field whose value
/// selects the variant (e.g. `type: detection`, `type: mask_coefficients`).
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
pub enum ConfigOutput {
    /// Combined detection tensor (`type: detection`).
    #[serde(rename = "detection")]
    Detection(configs::Detection),
    /// Mask output (`type: masks`); rejected by the YOLO decoder path.
    #[serde(rename = "masks")]
    Mask(configs::Mask),
    /// Segmentation output (`type: segmentation`); rejected by the YOLO decoder path.
    #[serde(rename = "segmentation")]
    Segmentation(configs::Segmentation),
    /// Prototype-mask output (`type: protos`).
    #[serde(rename = "protos")]
    Protos(configs::Protos),
    /// Split score output (`type: scores`).
    #[serde(rename = "scores")]
    Scores(configs::Scores),
    /// Split box output (`type: boxes`).
    #[serde(rename = "boxes")]
    Boxes(configs::Boxes),
    /// Split mask-coefficient output (`type: mask_coefficients`).
    #[serde(rename = "mask_coefficients")]
    MaskCoefficients(configs::MaskCoefficients),
    /// Split class output (`type: classes`), used by split end-to-end models.
    #[serde(rename = "classes")]
    Classes(configs::Classes),
}
80
/// Borrowed counterpart of [`ConfigOutput`]: one variant per output kind, each
/// holding a shared reference to the corresponding config struct.
#[derive(Debug, PartialEq, Clone)]
pub enum ConfigOutputRef<'a> {
    Detection(&'a configs::Detection),
    Mask(&'a configs::Mask),
    Segmentation(&'a configs::Segmentation),
    Protos(&'a configs::Protos),
    Scores(&'a configs::Scores),
    Boxes(&'a configs::Boxes),
    MaskCoefficients(&'a configs::MaskCoefficients),
    Classes(&'a configs::Classes),
}
92
93impl<'a> ConfigOutputRef<'a> {
94 fn decoder(&self) -> configs::DecoderType {
95 match self {
96 ConfigOutputRef::Detection(v) => v.decoder,
97 ConfigOutputRef::Mask(v) => v.decoder,
98 ConfigOutputRef::Segmentation(v) => v.decoder,
99 ConfigOutputRef::Protos(v) => v.decoder,
100 ConfigOutputRef::Scores(v) => v.decoder,
101 ConfigOutputRef::Boxes(v) => v.decoder,
102 ConfigOutputRef::MaskCoefficients(v) => v.decoder,
103 ConfigOutputRef::Classes(v) => v.decoder,
104 }
105 }
106
107 fn dshape(&self) -> &[(DimName, usize)] {
108 match self {
109 ConfigOutputRef::Detection(v) => &v.dshape,
110 ConfigOutputRef::Mask(v) => &v.dshape,
111 ConfigOutputRef::Segmentation(v) => &v.dshape,
112 ConfigOutputRef::Protos(v) => &v.dshape,
113 ConfigOutputRef::Scores(v) => &v.dshape,
114 ConfigOutputRef::Boxes(v) => &v.dshape,
115 ConfigOutputRef::MaskCoefficients(v) => &v.dshape,
116 ConfigOutputRef::Classes(v) => &v.dshape,
117 }
118 }
119}
120
121impl<'a> From<&'a configs::Detection> for ConfigOutputRef<'a> {
122 fn from(v: &'a configs::Detection) -> ConfigOutputRef<'a> {
137 ConfigOutputRef::Detection(v)
138 }
139}
140
141impl<'a> From<&'a configs::Mask> for ConfigOutputRef<'a> {
142 fn from(v: &'a configs::Mask) -> ConfigOutputRef<'a> {
155 ConfigOutputRef::Mask(v)
156 }
157}
158
159impl<'a> From<&'a configs::Segmentation> for ConfigOutputRef<'a> {
160 fn from(v: &'a configs::Segmentation) -> ConfigOutputRef<'a> {
173 ConfigOutputRef::Segmentation(v)
174 }
175}
176
177impl<'a> From<&'a configs::Protos> for ConfigOutputRef<'a> {
178 fn from(v: &'a configs::Protos) -> ConfigOutputRef<'a> {
191 ConfigOutputRef::Protos(v)
192 }
193}
194
195impl<'a> From<&'a configs::Scores> for ConfigOutputRef<'a> {
196 fn from(v: &'a configs::Scores) -> ConfigOutputRef<'a> {
209 ConfigOutputRef::Scores(v)
210 }
211}
212
213impl<'a> From<&'a configs::Boxes> for ConfigOutputRef<'a> {
214 fn from(v: &'a configs::Boxes) -> ConfigOutputRef<'a> {
228 ConfigOutputRef::Boxes(v)
229 }
230}
231
232impl<'a> From<&'a configs::MaskCoefficients> for ConfigOutputRef<'a> {
233 fn from(v: &'a configs::MaskCoefficients) -> ConfigOutputRef<'a> {
246 ConfigOutputRef::MaskCoefficients(v)
247 }
248}
249
250impl<'a> From<&'a configs::Classes> for ConfigOutputRef<'a> {
251 fn from(v: &'a configs::Classes) -> ConfigOutputRef<'a> {
252 ConfigOutputRef::Classes(v)
253 }
254}
255
256impl ConfigOutput {
257 pub fn shape(&self) -> &[usize] {
274 match self {
275 ConfigOutput::Detection(detection) => &detection.shape,
276 ConfigOutput::Mask(mask) => &mask.shape,
277 ConfigOutput::Segmentation(segmentation) => &segmentation.shape,
278 ConfigOutput::Scores(scores) => &scores.shape,
279 ConfigOutput::Boxes(boxes) => &boxes.shape,
280 ConfigOutput::Protos(protos) => &protos.shape,
281 ConfigOutput::MaskCoefficients(mask_coefficients) => &mask_coefficients.shape,
282 ConfigOutput::Classes(classes) => &classes.shape,
283 }
284 }
285
286 pub fn decoder(&self) -> &configs::DecoderType {
303 match self {
304 ConfigOutput::Detection(detection) => &detection.decoder,
305 ConfigOutput::Mask(mask) => &mask.decoder,
306 ConfigOutput::Segmentation(segmentation) => &segmentation.decoder,
307 ConfigOutput::Scores(scores) => &scores.decoder,
308 ConfigOutput::Boxes(boxes) => &boxes.decoder,
309 ConfigOutput::Protos(protos) => &protos.decoder,
310 ConfigOutput::MaskCoefficients(mask_coefficients) => &mask_coefficients.decoder,
311 ConfigOutput::Classes(classes) => &classes.decoder,
312 }
313 }
314
315 pub fn quantization(&self) -> Option<QuantTuple> {
332 match self {
333 ConfigOutput::Detection(detection) => detection.quantization,
334 ConfigOutput::Mask(mask) => mask.quantization,
335 ConfigOutput::Segmentation(segmentation) => segmentation.quantization,
336 ConfigOutput::Scores(scores) => scores.quantization,
337 ConfigOutput::Boxes(boxes) => boxes.quantization,
338 ConfigOutput::Protos(protos) => protos.quantization,
339 ConfigOutput::MaskCoefficients(mask_coefficients) => mask_coefficients.quantization,
340 ConfigOutput::Classes(classes) => classes.quantization,
341 }
342 }
343}
344
345pub mod configs {
346 use std::collections::HashMap;
347 use std::fmt::Display;
348
349 use serde::{Deserialize, Serialize};
350
351 pub fn deserialize_dshape<'de, D>(deserializer: D) -> Result<Vec<(DimName, usize)>, D::Error>
357 where
358 D: serde::Deserializer<'de>,
359 {
360 #[derive(Deserialize)]
361 #[serde(untagged)]
362 enum DShapeItem {
363 Tuple(DimName, usize),
364 Map(HashMap<DimName, usize>),
365 }
366
367 let items: Vec<DShapeItem> = Vec::deserialize(deserializer)?;
368 items
369 .into_iter()
370 .map(|item| match item {
371 DShapeItem::Tuple(name, size) => Ok((name, size)),
372 DShapeItem::Map(map) => {
373 if map.len() != 1 {
374 return Err(serde::de::Error::custom(
375 "dshape map entry must have exactly one key",
376 ));
377 }
378 let (name, size) = map.into_iter().next().unwrap();
379 Ok((name, size))
380 }
381 })
382 .collect()
383 }
384
385 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy)]
386 pub struct QuantTuple(pub f32, pub i32);
387 impl From<QuantTuple> for (f32, i32) {
388 fn from(value: QuantTuple) -> Self {
389 (value.0, value.1)
390 }
391 }
392
393 impl From<(f32, i32)> for QuantTuple {
394 fn from(value: (f32, i32)) -> Self {
395 QuantTuple(value.0, value.1)
396 }
397 }
398
399 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
400 pub struct Segmentation {
401 #[serde(default)]
402 pub decoder: DecoderType,
403 #[serde(default)]
404 pub quantization: Option<QuantTuple>,
405 #[serde(default)]
406 pub shape: Vec<usize>,
407 #[serde(default, deserialize_with = "deserialize_dshape")]
408 pub dshape: Vec<(DimName, usize)>,
409 }
410
411 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
412 pub struct Protos {
413 #[serde(default)]
414 pub decoder: DecoderType,
415 #[serde(default)]
416 pub quantization: Option<QuantTuple>,
417 #[serde(default)]
418 pub shape: Vec<usize>,
419 #[serde(default, deserialize_with = "deserialize_dshape")]
420 pub dshape: Vec<(DimName, usize)>,
421 }
422
423 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
424 pub struct MaskCoefficients {
425 #[serde(default)]
426 pub decoder: DecoderType,
427 #[serde(default)]
428 pub quantization: Option<QuantTuple>,
429 #[serde(default)]
430 pub shape: Vec<usize>,
431 #[serde(default, deserialize_with = "deserialize_dshape")]
432 pub dshape: Vec<(DimName, usize)>,
433 }
434
435 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
436 pub struct Mask {
437 #[serde(default)]
438 pub decoder: DecoderType,
439 #[serde(default)]
440 pub quantization: Option<QuantTuple>,
441 #[serde(default)]
442 pub shape: Vec<usize>,
443 #[serde(default, deserialize_with = "deserialize_dshape")]
444 pub dshape: Vec<(DimName, usize)>,
445 }
446
447 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
448 pub struct Detection {
449 #[serde(default)]
450 pub anchors: Option<Vec<[f32; 2]>>,
451 #[serde(default)]
452 pub decoder: DecoderType,
453 #[serde(default)]
454 pub quantization: Option<QuantTuple>,
455 #[serde(default)]
456 pub shape: Vec<usize>,
457 #[serde(default, deserialize_with = "deserialize_dshape")]
458 pub dshape: Vec<(DimName, usize)>,
459 #[serde(default)]
466 pub normalized: Option<bool>,
467 }
468
469 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
470 pub struct Scores {
471 #[serde(default)]
472 pub decoder: DecoderType,
473 #[serde(default)]
474 pub quantization: Option<QuantTuple>,
475 #[serde(default)]
476 pub shape: Vec<usize>,
477 #[serde(default, deserialize_with = "deserialize_dshape")]
478 pub dshape: Vec<(DimName, usize)>,
479 }
480
481 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
482 pub struct Boxes {
483 #[serde(default)]
484 pub decoder: DecoderType,
485 #[serde(default)]
486 pub quantization: Option<QuantTuple>,
487 #[serde(default)]
488 pub shape: Vec<usize>,
489 #[serde(default, deserialize_with = "deserialize_dshape")]
490 pub dshape: Vec<(DimName, usize)>,
491 #[serde(default)]
498 pub normalized: Option<bool>,
499 }
500
501 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
502 pub struct Classes {
503 #[serde(default)]
504 pub decoder: DecoderType,
505 #[serde(default)]
506 pub quantization: Option<QuantTuple>,
507 #[serde(default)]
508 pub shape: Vec<usize>,
509 #[serde(default, deserialize_with = "deserialize_dshape")]
510 pub dshape: Vec<(DimName, usize)>,
511 }
512
513 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
514 pub enum DimName {
515 #[serde(rename = "batch")]
516 Batch,
517 #[serde(rename = "height")]
518 Height,
519 #[serde(rename = "width")]
520 Width,
521 #[serde(rename = "num_classes")]
522 NumClasses,
523 #[serde(rename = "num_features")]
524 NumFeatures,
525 #[serde(rename = "num_boxes")]
526 NumBoxes,
527 #[serde(rename = "num_protos")]
528 NumProtos,
529 #[serde(rename = "num_anchors_x_features")]
530 NumAnchorsXFeatures,
531 #[serde(rename = "padding")]
532 Padding,
533 #[serde(rename = "box_coords")]
534 BoxCoords,
535 }
536
537 impl Display for DimName {
538 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
548 match self {
549 DimName::Batch => write!(f, "batch"),
550 DimName::Height => write!(f, "height"),
551 DimName::Width => write!(f, "width"),
552 DimName::NumClasses => write!(f, "num_classes"),
553 DimName::NumFeatures => write!(f, "num_features"),
554 DimName::NumBoxes => write!(f, "num_boxes"),
555 DimName::NumProtos => write!(f, "num_protos"),
556 DimName::NumAnchorsXFeatures => write!(f, "num_anchors_x_features"),
557 DimName::Padding => write!(f, "padding"),
558 DimName::BoxCoords => write!(f, "box_coords"),
559 }
560 }
561 }
562
563 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq, Default)]
564 pub enum DecoderType {
565 #[serde(rename = "modelpack")]
566 ModelPack,
567 #[default]
568 #[serde(rename = "ultralytics", alias = "yolov8")]
569 Ultralytics,
570 }
571
572 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
584 #[serde(rename_all = "lowercase")]
585 pub enum DecoderVersion {
586 #[serde(rename = "yolov5")]
588 Yolov5,
589 #[serde(rename = "yolov8")]
591 Yolov8,
592 #[serde(rename = "yolo11")]
594 Yolo11,
595 #[serde(rename = "yolo26")]
598 Yolo26,
599 }
600
601 impl DecoderVersion {
602 pub fn is_end_to_end(&self) -> bool {
605 matches!(self, DecoderVersion::Yolo26)
606 }
607 }
608
609 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq, Default)]
618 #[serde(rename_all = "snake_case")]
619 pub enum Nms {
620 #[default]
623 ClassAgnostic,
624 ClassAware,
626 }
627
628 #[derive(Debug, Clone, PartialEq)]
629 pub enum ModelType {
630 ModelPackSegDet {
631 boxes: Boxes,
632 scores: Scores,
633 segmentation: Segmentation,
634 },
635 ModelPackSegDetSplit {
636 detection: Vec<Detection>,
637 segmentation: Segmentation,
638 },
639 ModelPackDet {
640 boxes: Boxes,
641 scores: Scores,
642 },
643 ModelPackDetSplit {
644 detection: Vec<Detection>,
645 },
646 ModelPackSeg {
647 segmentation: Segmentation,
648 },
649 YoloDet {
650 boxes: Detection,
651 },
652 YoloSegDet {
653 boxes: Detection,
654 protos: Protos,
655 },
656 YoloSplitDet {
657 boxes: Boxes,
658 scores: Scores,
659 },
660 YoloSplitSegDet {
661 boxes: Boxes,
662 scores: Scores,
663 mask_coeff: MaskCoefficients,
664 protos: Protos,
665 },
666 YoloEndToEndDet {
670 boxes: Detection,
671 },
672 YoloEndToEndSegDet {
676 boxes: Detection,
677 protos: Protos,
678 },
679 YoloSplitEndToEndDet {
683 boxes: Boxes,
684 scores: Scores,
685 classes: Classes,
686 },
687 YoloSplitEndToEndSegDet {
690 boxes: Boxes,
691 scores: Scores,
692 classes: Classes,
693 mask_coeff: MaskCoefficients,
694 protos: Protos,
695 },
696 }
697
698 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
699 #[serde(rename_all = "lowercase")]
700 pub enum DataType {
701 Raw = 0,
702 Int8 = 1,
703 UInt8 = 2,
704 Int16 = 3,
705 UInt16 = 4,
706 Float16 = 5,
707 Int32 = 6,
708 UInt32 = 7,
709 Float32 = 8,
710 Int64 = 9,
711 UInt64 = 10,
712 Float64 = 11,
713 String = 12,
714 }
715}
716
/// Builder for a `Decoder`: collects an output configuration (YAML, JSON, or
/// in-code) plus thresholds and an NMS fallback, then resolves everything in
/// [`DecoderBuilder::build`].
#[derive(Debug, Clone, PartialEq)]
pub struct DecoderBuilder {
    /// Source of the output configuration; `None` makes `build` fail with `NoConfig`.
    config_src: Option<ConfigSource>,
    /// IoU threshold forwarded to the decoder (default 0.5).
    iou_threshold: f32,
    /// Score threshold forwarded to the decoder (default 0.5).
    score_threshold: f32,
    /// NMS mode used when the configuration itself does not set `nms`.
    nms: Option<configs::Nms>,
}
726
/// Where a [`DecoderBuilder`] takes its output configuration from.
#[derive(Debug, Clone, PartialEq)]
enum ConfigSource {
    /// YAML text, parsed with `serde_yaml` during `build`.
    Yaml(String),
    /// JSON text, parsed with `serde_json` during `build`.
    Json(String),
    /// An already-constructed configuration, used as-is during `build`.
    Config(ConfigOutputs),
}
733
734impl Default for DecoderBuilder {
735 fn default() -> Self {
755 Self {
756 config_src: None,
757 iou_threshold: 0.5,
758 score_threshold: 0.5,
759 nms: Some(configs::Nms::ClassAgnostic),
760 }
761 }
762}
763
764impl DecoderBuilder {
765 pub fn new() -> Self {
785 Self::default()
786 }
787
788 pub fn with_config_yaml_str(mut self, yaml_str: String) -> Self {
805 self.config_src.replace(ConfigSource::Yaml(yaml_str));
806 self
807 }
808
809 pub fn with_config_json_str(mut self, json_str: String) -> Self {
826 self.config_src.replace(ConfigSource::Json(json_str));
827 self
828 }
829
830 pub fn with_config(mut self, config: ConfigOutputs) -> Self {
847 self.config_src.replace(ConfigSource::Config(config));
848 self
849 }
850
851 pub fn with_config_yolo_det(
876 mut self,
877 boxes: configs::Detection,
878 version: Option<DecoderVersion>,
879 ) -> Self {
880 let config = ConfigOutputs {
881 outputs: vec![ConfigOutput::Detection(boxes)],
882 decoder_version: version,
883 ..Default::default()
884 };
885 self.config_src.replace(ConfigSource::Config(config));
886 self
887 }
888
889 pub fn with_config_yolo_split_det(
916 mut self,
917 boxes: configs::Boxes,
918 scores: configs::Scores,
919 ) -> Self {
920 let config = ConfigOutputs {
921 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
922 ..Default::default()
923 };
924 self.config_src.replace(ConfigSource::Config(config));
925 self
926 }
927
928 pub fn with_config_yolo_segdet(
960 mut self,
961 boxes: configs::Detection,
962 protos: configs::Protos,
963 version: Option<DecoderVersion>,
964 ) -> Self {
965 let config = ConfigOutputs {
966 outputs: vec![ConfigOutput::Detection(boxes), ConfigOutput::Protos(protos)],
967 decoder_version: version,
968 ..Default::default()
969 };
970 self.config_src.replace(ConfigSource::Config(config));
971 self
972 }
973
974 pub fn with_config_yolo_split_segdet(
1013 mut self,
1014 boxes: configs::Boxes,
1015 scores: configs::Scores,
1016 mask_coefficients: configs::MaskCoefficients,
1017 protos: configs::Protos,
1018 ) -> Self {
1019 let config = ConfigOutputs {
1020 outputs: vec![
1021 ConfigOutput::Boxes(boxes),
1022 ConfigOutput::Scores(scores),
1023 ConfigOutput::MaskCoefficients(mask_coefficients),
1024 ConfigOutput::Protos(protos),
1025 ],
1026 ..Default::default()
1027 };
1028 self.config_src.replace(ConfigSource::Config(config));
1029 self
1030 }
1031
1032 pub fn with_config_modelpack_det(
1059 mut self,
1060 boxes: configs::Boxes,
1061 scores: configs::Scores,
1062 ) -> Self {
1063 let config = ConfigOutputs {
1064 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
1065 ..Default::default()
1066 };
1067 self.config_src.replace(ConfigSource::Config(config));
1068 self
1069 }
1070
1071 pub fn with_config_modelpack_det_split(mut self, boxes: Vec<configs::Detection>) -> Self {
1110 let outputs = boxes.into_iter().map(ConfigOutput::Detection).collect();
1111 let config = ConfigOutputs {
1112 outputs,
1113 ..Default::default()
1114 };
1115 self.config_src.replace(ConfigSource::Config(config));
1116 self
1117 }
1118
1119 pub fn with_config_modelpack_segdet(
1152 mut self,
1153 boxes: configs::Boxes,
1154 scores: configs::Scores,
1155 segmentation: configs::Segmentation,
1156 ) -> Self {
1157 let config = ConfigOutputs {
1158 outputs: vec![
1159 ConfigOutput::Boxes(boxes),
1160 ConfigOutput::Scores(scores),
1161 ConfigOutput::Segmentation(segmentation),
1162 ],
1163 ..Default::default()
1164 };
1165 self.config_src.replace(ConfigSource::Config(config));
1166 self
1167 }
1168
1169 pub fn with_config_modelpack_segdet_split(
1213 mut self,
1214 boxes: Vec<configs::Detection>,
1215 segmentation: configs::Segmentation,
1216 ) -> Self {
1217 let mut outputs = boxes
1218 .into_iter()
1219 .map(ConfigOutput::Detection)
1220 .collect::<Vec<_>>();
1221 outputs.push(ConfigOutput::Segmentation(segmentation));
1222 let config = ConfigOutputs {
1223 outputs,
1224 ..Default::default()
1225 };
1226 self.config_src.replace(ConfigSource::Config(config));
1227 self
1228 }
1229
1230 pub fn with_config_modelpack_seg(mut self, segmentation: configs::Segmentation) -> Self {
1250 let config = ConfigOutputs {
1251 outputs: vec![ConfigOutput::Segmentation(segmentation)],
1252 ..Default::default()
1253 };
1254 self.config_src.replace(ConfigSource::Config(config));
1255 self
1256 }
1257
1258 pub fn add_output(mut self, output: ConfigOutput) -> Self {
1300 if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
1301 self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
1302 }
1303 if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
1304 config.outputs.push(Self::normalize_output(output));
1305 }
1306 self
1307 }
1308
1309 pub fn with_decoder_version(mut self, version: configs::DecoderVersion) -> Self {
1338 if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
1339 self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
1340 }
1341 if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
1342 config.decoder_version = Some(version);
1343 }
1344 self
1345 }
1346
1347 fn normalize_output(mut output: ConfigOutput) -> ConfigOutput {
1349 fn normalize_shape(shape: &mut Vec<usize>, dshape: &[(configs::DimName, usize)]) {
1350 if !dshape.is_empty() {
1351 *shape = dshape.iter().map(|(_, size)| *size).collect();
1352 }
1353 }
1354 match &mut output {
1355 ConfigOutput::Detection(c) => normalize_shape(&mut c.shape, &c.dshape),
1356 ConfigOutput::Boxes(c) => normalize_shape(&mut c.shape, &c.dshape),
1357 ConfigOutput::Scores(c) => normalize_shape(&mut c.shape, &c.dshape),
1358 ConfigOutput::Protos(c) => normalize_shape(&mut c.shape, &c.dshape),
1359 ConfigOutput::Segmentation(c) => normalize_shape(&mut c.shape, &c.dshape),
1360 ConfigOutput::MaskCoefficients(c) => normalize_shape(&mut c.shape, &c.dshape),
1361 ConfigOutput::Mask(c) => normalize_shape(&mut c.shape, &c.dshape),
1362 ConfigOutput::Classes(c) => normalize_shape(&mut c.shape, &c.dshape),
1363 }
1364 output
1365 }
1366
1367 pub fn with_score_threshold(mut self, score_threshold: f32) -> Self {
1383 self.score_threshold = score_threshold;
1384 self
1385 }
1386
1387 pub fn with_iou_threshold(mut self, iou_threshold: f32) -> Self {
1404 self.iou_threshold = iou_threshold;
1405 self
1406 }
1407
1408 pub fn with_nms(mut self, nms: Option<configs::Nms>) -> Self {
1430 self.nms = nms;
1431 self
1432 }
1433
1434 pub fn build(self) -> Result<Decoder, DecoderError> {
1451 let config = match self.config_src {
1452 Some(ConfigSource::Json(s)) => serde_json::from_str(&s)?,
1453 Some(ConfigSource::Yaml(s)) => serde_yaml::from_str(&s)?,
1454 Some(ConfigSource::Config(c)) => c,
1455 None => return Err(DecoderError::NoConfig),
1456 };
1457
1458 let normalized = Self::get_normalized(&config.outputs);
1460
1461 let nms = config.nms.or(self.nms);
1463 let model_type = Self::get_model_type(config)?;
1464
1465 Ok(Decoder {
1466 model_type,
1467 iou_threshold: self.iou_threshold,
1468 score_threshold: self.score_threshold,
1469 nms,
1470 normalized,
1471 })
1472 }
1473
1474 fn get_normalized(outputs: &[ConfigOutput]) -> Option<bool> {
1479 for output in outputs {
1480 match output {
1481 ConfigOutput::Detection(det) => return det.normalized,
1482 ConfigOutput::Boxes(boxes) => return boxes.normalized,
1483 _ => {}
1484 }
1485 }
1486 None }
1488
1489 fn get_model_type(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1490 let mut yolo = false;
1492 let mut modelpack = false;
1493 for c in &configs.outputs {
1494 match c.decoder() {
1495 DecoderType::ModelPack => modelpack = true,
1496 DecoderType::Ultralytics => yolo = true,
1497 }
1498 }
1499 match (modelpack, yolo) {
1500 (true, true) => Err(DecoderError::InvalidConfig(
1501 "Both ModelPack and Yolo outputs found in config".to_string(),
1502 )),
1503 (true, false) => Self::get_model_type_modelpack(configs),
1504 (false, true) => Self::get_model_type_yolo(configs),
1505 (false, false) => Err(DecoderError::InvalidConfig(
1506 "No outputs found in config".to_string(),
1507 )),
1508 }
1509 }
1510
1511 fn get_model_type_yolo(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1512 let mut boxes = None;
1513 let mut protos = None;
1514 let mut split_boxes = None;
1515 let mut split_scores = None;
1516 let mut split_mask_coeff = None;
1517 let mut split_classes = None;
1518 for c in configs.outputs {
1519 match c {
1520 ConfigOutput::Detection(detection) => boxes = Some(detection),
1521 ConfigOutput::Segmentation(_) => {
1522 return Err(DecoderError::InvalidConfig(
1523 "Invalid Segmentation output with Yolo decoder".to_string(),
1524 ));
1525 }
1526 ConfigOutput::Protos(protos_) => protos = Some(protos_),
1527 ConfigOutput::Mask(_) => {
1528 return Err(DecoderError::InvalidConfig(
1529 "Invalid Mask output with Yolo decoder".to_string(),
1530 ));
1531 }
1532 ConfigOutput::Scores(scores) => split_scores = Some(scores),
1533 ConfigOutput::Boxes(boxes) => split_boxes = Some(boxes),
1534 ConfigOutput::MaskCoefficients(mask_coeff) => split_mask_coeff = Some(mask_coeff),
1535 ConfigOutput::Classes(classes) => split_classes = Some(classes),
1536 }
1537 }
1538
1539 let is_end_to_end_dshape = boxes.as_ref().is_some_and(|b| {
1544 let dims = b.dshape.iter().map(|(d, _)| *d).collect::<Vec<_>>();
1545 dims == vec![DimName::Batch, DimName::NumBoxes, DimName::NumFeatures]
1546 });
1547
1548 let is_end_to_end = configs
1549 .decoder_version
1550 .map(|v| v.is_end_to_end())
1551 .unwrap_or(is_end_to_end_dshape);
1552
1553 if is_end_to_end {
1554 if let Some(boxes) = boxes {
1555 if let Some(protos) = protos {
1556 Self::verify_yolo_seg_det_26(&boxes, &protos)?;
1557 return Ok(ModelType::YoloEndToEndSegDet { boxes, protos });
1558 } else {
1559 Self::verify_yolo_det_26(&boxes)?;
1560 return Ok(ModelType::YoloEndToEndDet { boxes });
1561 }
1562 } else if let (Some(split_boxes), Some(split_scores), Some(split_classes)) =
1563 (split_boxes, split_scores, split_classes)
1564 {
1565 if let (Some(split_mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
1566 Self::verify_yolo_split_end_to_end_segdet(
1567 &split_boxes,
1568 &split_scores,
1569 &split_classes,
1570 &split_mask_coeff,
1571 &protos,
1572 )?;
1573 return Ok(ModelType::YoloSplitEndToEndSegDet {
1574 boxes: split_boxes,
1575 scores: split_scores,
1576 classes: split_classes,
1577 mask_coeff: split_mask_coeff,
1578 protos,
1579 });
1580 }
1581 Self::verify_yolo_split_end_to_end_det(
1582 &split_boxes,
1583 &split_scores,
1584 &split_classes,
1585 )?;
1586 return Ok(ModelType::YoloSplitEndToEndDet {
1587 boxes: split_boxes,
1588 scores: split_scores,
1589 classes: split_classes,
1590 });
1591 } else {
1592 return Err(DecoderError::InvalidConfig(
1593 "Invalid Yolo end-to-end model outputs".to_string(),
1594 ));
1595 }
1596 }
1597
1598 if let Some(boxes) = boxes {
1599 if let Some(protos) = protos {
1600 Self::verify_yolo_seg_det(&boxes, &protos)?;
1601 Ok(ModelType::YoloSegDet { boxes, protos })
1602 } else {
1603 Self::verify_yolo_det(&boxes)?;
1604 Ok(ModelType::YoloDet { boxes })
1605 }
1606 } else if let (Some(boxes), Some(scores)) = (split_boxes, split_scores) {
1607 if let (Some(mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
1608 Self::verify_yolo_split_segdet(&boxes, &scores, &mask_coeff, &protos)?;
1609 Ok(ModelType::YoloSplitSegDet {
1610 boxes,
1611 scores,
1612 mask_coeff,
1613 protos,
1614 })
1615 } else {
1616 Self::verify_yolo_split_det(&boxes, &scores)?;
1617 Ok(ModelType::YoloSplitDet { boxes, scores })
1618 }
1619 } else {
1620 Err(DecoderError::InvalidConfig(
1621 "Invalid Yolo model outputs".to_string(),
1622 ))
1623 }
1624 }
1625
1626 fn verify_yolo_det(detect: &configs::Detection) -> Result<(), DecoderError> {
1627 if detect.shape.len() != 3 {
1628 return Err(DecoderError::InvalidConfig(format!(
1629 "Invalid Yolo Detection shape {:?}",
1630 detect.shape
1631 )));
1632 }
1633
1634 Self::verify_dshapes(
1635 &detect.dshape,
1636 &detect.shape,
1637 "Detection",
1638 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1639 )?;
1640 if !detect.dshape.is_empty() {
1641 Self::get_class_count(&detect.dshape, None, None)?;
1642 } else {
1643 Self::get_class_count_no_dshape(detect.into(), None)?;
1644 }
1645
1646 Ok(())
1647 }
1648
1649 fn verify_yolo_det_26(detect: &configs::Detection) -> Result<(), DecoderError> {
1650 if detect.shape.len() != 3 {
1651 return Err(DecoderError::InvalidConfig(format!(
1652 "Invalid Yolo Detection shape {:?}",
1653 detect.shape
1654 )));
1655 }
1656
1657 Self::verify_dshapes(
1658 &detect.dshape,
1659 &detect.shape,
1660 "Detection",
1661 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1662 )?;
1663
1664 if !detect.shape.contains(&6) {
1665 return Err(DecoderError::InvalidConfig(
1666 "Yolo26 Detection must have 6 features".to_string(),
1667 ));
1668 }
1669
1670 Ok(())
1671 }
1672
1673 fn verify_yolo_seg_det(
1674 detection: &configs::Detection,
1675 protos: &configs::Protos,
1676 ) -> Result<(), DecoderError> {
1677 if detection.shape.len() != 3 {
1678 return Err(DecoderError::InvalidConfig(format!(
1679 "Invalid Yolo Detection shape {:?}",
1680 detection.shape
1681 )));
1682 }
1683 if protos.shape.len() != 4 {
1684 return Err(DecoderError::InvalidConfig(format!(
1685 "Invalid Yolo Protos shape {:?}",
1686 protos.shape
1687 )));
1688 }
1689
1690 Self::verify_dshapes(
1691 &detection.dshape,
1692 &detection.shape,
1693 "Detection",
1694 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1695 )?;
1696 Self::verify_dshapes(
1697 &protos.dshape,
1698 &protos.shape,
1699 "Protos",
1700 &[
1701 DimName::Batch,
1702 DimName::Height,
1703 DimName::Width,
1704 DimName::NumProtos,
1705 ],
1706 )?;
1707
1708 let protos_count = Self::get_protos_count(&protos.dshape)
1709 .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1710 log::debug!("Protos count: {}", protos_count);
1711 log::debug!("Detection dshape: {:?}", detection.dshape);
1712 let classes = if !detection.dshape.is_empty() {
1713 Self::get_class_count(&detection.dshape, Some(protos_count), None)?
1714 } else {
1715 Self::get_class_count_no_dshape(detection.into(), Some(protos_count))?
1716 };
1717
1718 if classes == 0 {
1719 return Err(DecoderError::InvalidConfig(
1720 "Yolo Segmentation Detection has zero classes".to_string(),
1721 ));
1722 }
1723
1724 Ok(())
1725 }
1726
1727 fn verify_yolo_seg_det_26(
1728 detection: &configs::Detection,
1729 protos: &configs::Protos,
1730 ) -> Result<(), DecoderError> {
1731 if detection.shape.len() != 3 {
1732 return Err(DecoderError::InvalidConfig(format!(
1733 "Invalid Yolo Detection shape {:?}",
1734 detection.shape
1735 )));
1736 }
1737 if protos.shape.len() != 4 {
1738 return Err(DecoderError::InvalidConfig(format!(
1739 "Invalid Yolo Protos shape {:?}",
1740 protos.shape
1741 )));
1742 }
1743
1744 Self::verify_dshapes(
1745 &detection.dshape,
1746 &detection.shape,
1747 "Detection",
1748 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1749 )?;
1750 Self::verify_dshapes(
1751 &protos.dshape,
1752 &protos.shape,
1753 "Protos",
1754 &[
1755 DimName::Batch,
1756 DimName::Height,
1757 DimName::Width,
1758 DimName::NumProtos,
1759 ],
1760 )?;
1761
1762 let protos_count = Self::get_protos_count(&protos.dshape)
1763 .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1764 log::debug!("Protos count: {}", protos_count);
1765 log::debug!("Detection dshape: {:?}", detection.dshape);
1766
1767 if !detection.shape.contains(&(6 + protos_count)) {
1768 return Err(DecoderError::InvalidConfig(format!(
1769 "Yolo26 Segmentation Detection must have num_features be 6 + num_protos = {}",
1770 6 + protos_count
1771 )));
1772 }
1773
1774 Ok(())
1775 }
1776
1777 fn verify_yolo_split_det(
1778 boxes: &configs::Boxes,
1779 scores: &configs::Scores,
1780 ) -> Result<(), DecoderError> {
1781 if boxes.shape.len() != 3 {
1782 return Err(DecoderError::InvalidConfig(format!(
1783 "Invalid Yolo Split Boxes shape {:?}",
1784 boxes.shape
1785 )));
1786 }
1787 if scores.shape.len() != 3 {
1788 return Err(DecoderError::InvalidConfig(format!(
1789 "Invalid Yolo Split Scores shape {:?}",
1790 scores.shape
1791 )));
1792 }
1793
1794 Self::verify_dshapes(
1795 &boxes.dshape,
1796 &boxes.shape,
1797 "Boxes",
1798 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1799 )?;
1800 Self::verify_dshapes(
1801 &scores.dshape,
1802 &scores.shape,
1803 "Scores",
1804 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1805 )?;
1806
1807 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1808 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1809
1810 if boxes_num != scores_num {
1811 return Err(DecoderError::InvalidConfig(format!(
1812 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1813 boxes_num, scores_num
1814 )));
1815 }
1816
1817 Ok(())
1818 }
1819
1820 fn verify_yolo_split_segdet(
1821 boxes: &configs::Boxes,
1822 scores: &configs::Scores,
1823 mask_coeff: &configs::MaskCoefficients,
1824 protos: &configs::Protos,
1825 ) -> Result<(), DecoderError> {
1826 if boxes.shape.len() != 3 {
1827 return Err(DecoderError::InvalidConfig(format!(
1828 "Invalid Yolo Split Boxes shape {:?}",
1829 boxes.shape
1830 )));
1831 }
1832 if scores.shape.len() != 3 {
1833 return Err(DecoderError::InvalidConfig(format!(
1834 "Invalid Yolo Split Scores shape {:?}",
1835 scores.shape
1836 )));
1837 }
1838
1839 if mask_coeff.shape.len() != 3 {
1840 return Err(DecoderError::InvalidConfig(format!(
1841 "Invalid Yolo Split Mask Coefficients shape {:?}",
1842 mask_coeff.shape
1843 )));
1844 }
1845
1846 if protos.shape.len() != 4 {
1847 return Err(DecoderError::InvalidConfig(format!(
1848 "Invalid Yolo Protos shape {:?}",
1849 mask_coeff.shape
1850 )));
1851 }
1852
1853 Self::verify_dshapes(
1854 &boxes.dshape,
1855 &boxes.shape,
1856 "Boxes",
1857 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1858 )?;
1859 Self::verify_dshapes(
1860 &scores.dshape,
1861 &scores.shape,
1862 "Scores",
1863 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1864 )?;
1865 Self::verify_dshapes(
1866 &mask_coeff.dshape,
1867 &mask_coeff.shape,
1868 "Mask Coefficients",
1869 &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1870 )?;
1871 Self::verify_dshapes(
1872 &protos.dshape,
1873 &protos.shape,
1874 "Protos",
1875 &[
1876 DimName::Batch,
1877 DimName::Height,
1878 DimName::Width,
1879 DimName::NumProtos,
1880 ],
1881 )?;
1882
1883 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1884 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1885 let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1886
1887 let mask_channels = if !mask_coeff.dshape.is_empty() {
1888 Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1889 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1890 })?
1891 } else {
1892 mask_coeff.shape[1]
1893 };
1894 let proto_channels = if !protos.dshape.is_empty() {
1895 Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1896 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1897 })?
1898 } else {
1899 protos.shape[1].min(protos.shape[3])
1900 };
1901
1902 if boxes_num != scores_num {
1903 return Err(DecoderError::InvalidConfig(format!(
1904 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1905 boxes_num, scores_num
1906 )));
1907 }
1908
1909 if boxes_num != mask_num {
1910 return Err(DecoderError::InvalidConfig(format!(
1911 "Yolo Split Detection Boxes num {} incompatible with Mask Coefficients num {}",
1912 boxes_num, mask_num
1913 )));
1914 }
1915
1916 if proto_channels != mask_channels {
1917 return Err(DecoderError::InvalidConfig(format!(
1918 "Yolo Protos channels {} incompatible with Mask Coefficients channels {}",
1919 proto_channels, mask_channels
1920 )));
1921 }
1922
1923 Ok(())
1924 }
1925
1926 fn verify_yolo_split_end_to_end_det(
1927 boxes: &configs::Boxes,
1928 scores: &configs::Scores,
1929 classes: &configs::Classes,
1930 ) -> Result<(), DecoderError> {
1931 if boxes.shape.len() != 3 || !boxes.shape.contains(&4) {
1932 return Err(DecoderError::InvalidConfig(format!(
1933 "Split end-to-end boxes must be [batch, N, 4], got {:?}",
1934 boxes.shape
1935 )));
1936 }
1937 if scores.shape.len() != 3 || !scores.shape.contains(&1) {
1938 return Err(DecoderError::InvalidConfig(format!(
1939 "Split end-to-end scores must be [batch, N, 1], got {:?}",
1940 scores.shape
1941 )));
1942 }
1943 if classes.shape.len() != 3 || !classes.shape.contains(&1) {
1944 return Err(DecoderError::InvalidConfig(format!(
1945 "Split end-to-end classes must be [batch, N, 1], got {:?}",
1946 classes.shape
1947 )));
1948 }
1949 Ok(())
1950 }
1951
1952 fn verify_yolo_split_end_to_end_segdet(
1953 boxes: &configs::Boxes,
1954 scores: &configs::Scores,
1955 classes: &configs::Classes,
1956 mask_coeff: &configs::MaskCoefficients,
1957 protos: &configs::Protos,
1958 ) -> Result<(), DecoderError> {
1959 Self::verify_yolo_split_end_to_end_det(boxes, scores, classes)?;
1960 if mask_coeff.shape.len() != 3 {
1961 return Err(DecoderError::InvalidConfig(format!(
1962 "Invalid split end-to-end mask coefficients shape {:?}",
1963 mask_coeff.shape
1964 )));
1965 }
1966 if protos.shape.len() != 4 {
1967 return Err(DecoderError::InvalidConfig(format!(
1968 "Invalid protos shape {:?}",
1969 protos.shape
1970 )));
1971 }
1972 Ok(())
1973 }
1974
1975 fn get_model_type_modelpack(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1976 let mut split_decoders = Vec::new();
1977 let mut segment_ = None;
1978 let mut scores_ = None;
1979 let mut boxes_ = None;
1980 for c in configs.outputs {
1981 match c {
1982 ConfigOutput::Detection(detection) => split_decoders.push(detection),
1983 ConfigOutput::Segmentation(segmentation) => segment_ = Some(segmentation),
1984 ConfigOutput::Mask(_) => {}
1985 ConfigOutput::Protos(_) => {
1986 return Err(DecoderError::InvalidConfig(
1987 "ModelPack should not have protos".to_string(),
1988 ));
1989 }
1990 ConfigOutput::Scores(scores) => scores_ = Some(scores),
1991 ConfigOutput::Boxes(boxes) => boxes_ = Some(boxes),
1992 ConfigOutput::MaskCoefficients(_) => {
1993 return Err(DecoderError::InvalidConfig(
1994 "ModelPack should not have mask coefficients".to_string(),
1995 ));
1996 }
1997 ConfigOutput::Classes(_) => {
1998 return Err(DecoderError::InvalidConfig(
1999 "ModelPack should not have classes output".to_string(),
2000 ));
2001 }
2002 }
2003 }
2004
2005 if let Some(segmentation) = segment_ {
2006 if !split_decoders.is_empty() {
2007 let classes = Self::verify_modelpack_split_det(&split_decoders)?;
2008 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
2009 Ok(ModelType::ModelPackSegDetSplit {
2010 detection: split_decoders,
2011 segmentation,
2012 })
2013 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
2014 let classes = Self::verify_modelpack_det(&boxes, &scores)?;
2015 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
2016 Ok(ModelType::ModelPackSegDet {
2017 boxes,
2018 scores,
2019 segmentation,
2020 })
2021 } else {
2022 Self::verify_modelpack_seg(&segmentation, None)?;
2023 Ok(ModelType::ModelPackSeg { segmentation })
2024 }
2025 } else if !split_decoders.is_empty() {
2026 Self::verify_modelpack_split_det(&split_decoders)?;
2027 Ok(ModelType::ModelPackDetSplit {
2028 detection: split_decoders,
2029 })
2030 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
2031 Self::verify_modelpack_det(&boxes, &scores)?;
2032 Ok(ModelType::ModelPackDet { boxes, scores })
2033 } else {
2034 Err(DecoderError::InvalidConfig(
2035 "Invalid ModelPack model outputs".to_string(),
2036 ))
2037 }
2038 }
2039
2040 fn verify_modelpack_det(
2041 boxes: &configs::Boxes,
2042 scores: &configs::Scores,
2043 ) -> Result<usize, DecoderError> {
2044 if boxes.shape.len() != 4 {
2045 return Err(DecoderError::InvalidConfig(format!(
2046 "Invalid ModelPack Boxes shape {:?}",
2047 boxes.shape
2048 )));
2049 }
2050 if scores.shape.len() != 3 {
2051 return Err(DecoderError::InvalidConfig(format!(
2052 "Invalid ModelPack Scores shape {:?}",
2053 scores.shape
2054 )));
2055 }
2056
2057 Self::verify_dshapes(
2058 &boxes.dshape,
2059 &boxes.shape,
2060 "Boxes",
2061 &[
2062 DimName::Batch,
2063 DimName::NumBoxes,
2064 DimName::Padding,
2065 DimName::BoxCoords,
2066 ],
2067 )?;
2068 Self::verify_dshapes(
2069 &scores.dshape,
2070 &scores.shape,
2071 "Scores",
2072 &[DimName::Batch, DimName::NumBoxes, DimName::NumClasses],
2073 )?;
2074
2075 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[1]);
2076 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[1]);
2077
2078 if boxes_num != scores_num {
2079 return Err(DecoderError::InvalidConfig(format!(
2080 "ModelPack Detection Boxes num {} incompatible with Scores num {}",
2081 boxes_num, scores_num
2082 )));
2083 }
2084
2085 let num_classes = if !scores.dshape.is_empty() {
2086 Self::get_class_count(&scores.dshape, None, None)?
2087 } else {
2088 Self::get_class_count_no_dshape(scores.into(), None)?
2089 };
2090
2091 Ok(num_classes)
2092 }
2093
2094 fn verify_modelpack_split_det(boxes: &[configs::Detection]) -> Result<usize, DecoderError> {
2095 let mut num_classes = None;
2096 for b in boxes {
2097 let Some(num_anchors) = b.anchors.as_ref().map(|a| a.len()) else {
2098 return Err(DecoderError::InvalidConfig(
2099 "ModelPack Split Detection missing anchors".to_string(),
2100 ));
2101 };
2102
2103 if num_anchors == 0 {
2104 return Err(DecoderError::InvalidConfig(
2105 "ModelPack Split Detection has zero anchors".to_string(),
2106 ));
2107 }
2108
2109 if b.shape.len() != 4 {
2110 return Err(DecoderError::InvalidConfig(format!(
2111 "Invalid ModelPack Split Detection shape {:?}",
2112 b.shape
2113 )));
2114 }
2115
2116 Self::verify_dshapes(
2117 &b.dshape,
2118 &b.shape,
2119 "Split Detection",
2120 &[
2121 DimName::Batch,
2122 DimName::Height,
2123 DimName::Width,
2124 DimName::NumAnchorsXFeatures,
2125 ],
2126 )?;
2127 let classes = if !b.dshape.is_empty() {
2128 Self::get_class_count(&b.dshape, None, Some(num_anchors))?
2129 } else {
2130 Self::get_class_count_no_dshape(b.into(), None)?
2131 };
2132
2133 match num_classes {
2134 Some(n) => {
2135 if n != classes {
2136 return Err(DecoderError::InvalidConfig(format!(
2137 "ModelPack Split Detection inconsistent number of classes: previous {}, current {}",
2138 n, classes
2139 )));
2140 }
2141 }
2142 None => {
2143 num_classes = Some(classes);
2144 }
2145 }
2146 }
2147
2148 Ok(num_classes.unwrap_or(0))
2149 }
2150
2151 fn verify_modelpack_seg(
2152 segmentation: &configs::Segmentation,
2153 classes: Option<usize>,
2154 ) -> Result<(), DecoderError> {
2155 if segmentation.shape.len() != 4 {
2156 return Err(DecoderError::InvalidConfig(format!(
2157 "Invalid ModelPack Segmentation shape {:?}",
2158 segmentation.shape
2159 )));
2160 }
2161 Self::verify_dshapes(
2162 &segmentation.dshape,
2163 &segmentation.shape,
2164 "Segmentation",
2165 &[
2166 DimName::Batch,
2167 DimName::Height,
2168 DimName::Width,
2169 DimName::NumClasses,
2170 ],
2171 )?;
2172
2173 if let Some(classes) = classes {
2174 let seg_classes = if !segmentation.dshape.is_empty() {
2175 Self::get_class_count(&segmentation.dshape, None, None)?
2176 } else {
2177 Self::get_class_count_no_dshape(segmentation.into(), None)?
2178 };
2179
2180 if seg_classes != classes + 1 {
2181 return Err(DecoderError::InvalidConfig(format!(
2182 "ModelPack Segmentation channels {} incompatible with number of classes {}",
2183 seg_classes, classes
2184 )));
2185 }
2186 }
2187 Ok(())
2188 }
2189
2190 fn verify_dshapes(
2192 dshape: &[(DimName, usize)],
2193 shape: &[usize],
2194 name: &str,
2195 dims: &[DimName],
2196 ) -> Result<(), DecoderError> {
2197 for s in shape {
2198 if *s == 0 {
2199 return Err(DecoderError::InvalidConfig(format!(
2200 "{} shape has zero dimension",
2201 name
2202 )));
2203 }
2204 }
2205
2206 if shape.len() != dims.len() {
2207 return Err(DecoderError::InvalidConfig(format!(
2208 "{} shape length {} does not match expected dims length {}",
2209 name,
2210 shape.len(),
2211 dims.len()
2212 )));
2213 }
2214
2215 if dshape.is_empty() {
2216 return Ok(());
2217 }
2218 if dshape.len() != shape.len() {
2220 return Err(DecoderError::InvalidConfig(format!(
2221 "{} dshape length does not match shape length",
2222 name
2223 )));
2224 }
2225
2226 for ((dim_name, dim_size), shape_size) in dshape.iter().zip(shape) {
2228 if dim_size != shape_size {
2229 return Err(DecoderError::InvalidConfig(format!(
2230 "{} dshape dimension {} size {} does not match shape size {}",
2231 name, dim_name, dim_size, shape_size
2232 )));
2233 }
2234 if *dim_name == DimName::Padding && *dim_size != 1 {
2235 return Err(DecoderError::InvalidConfig(
2236 "Padding dimension size must be 1".to_string(),
2237 ));
2238 }
2239
2240 if *dim_name == DimName::BoxCoords && *dim_size != 4 {
2241 return Err(DecoderError::InvalidConfig(
2242 "BoxCoords dimension size must be 4".to_string(),
2243 ));
2244 }
2245 }
2246
2247 let dims_present = HashSet::<DimName>::from_iter(dshape.iter().map(|(name, _)| *name));
2248 for dim in dims {
2249 if !dims_present.contains(dim) {
2250 return Err(DecoderError::InvalidConfig(format!(
2251 "{} dshape missing required dimension {:?}",
2252 name, dim
2253 )));
2254 }
2255 }
2256
2257 Ok(())
2258 }
2259
2260 fn get_box_count(dshape: &[(DimName, usize)]) -> Option<usize> {
2261 for (dim_name, dim_size) in dshape {
2262 if *dim_name == DimName::NumBoxes {
2263 return Some(*dim_size);
2264 }
2265 }
2266 None
2267 }
2268
2269 fn get_class_count_no_dshape(
2270 config: ConfigOutputRef,
2271 protos: Option<usize>,
2272 ) -> Result<usize, DecoderError> {
2273 match config {
2274 ConfigOutputRef::Detection(detection) => match detection.decoder {
2275 DecoderType::Ultralytics => {
2276 if detection.shape[1] <= 4 + protos.unwrap_or(0) {
2277 return Err(DecoderError::InvalidConfig(format!(
2278 "Invalid shape: Yolo num_features {} must be greater than {}",
2279 detection.shape[1],
2280 4 + protos.unwrap_or(0),
2281 )));
2282 }
2283 Ok(detection.shape[1] - 4 - protos.unwrap_or(0))
2284 }
2285 DecoderType::ModelPack => {
2286 let Some(num_anchors) = detection.anchors.as_ref().map(|a| a.len()) else {
2287 return Err(DecoderError::Internal(
2288 "ModelPack Detection missing anchors".to_string(),
2289 ));
2290 };
2291 let anchors_x_features = detection.shape[3];
2292 if anchors_x_features <= num_anchors * 5 {
2293 return Err(DecoderError::InvalidConfig(format!(
2294 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2295 anchors_x_features,
2296 num_anchors * 5,
2297 )));
2298 }
2299
2300 if !anchors_x_features.is_multiple_of(num_anchors) {
2301 return Err(DecoderError::InvalidConfig(format!(
2302 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2303 anchors_x_features, num_anchors
2304 )));
2305 }
2306 Ok(anchors_x_features / num_anchors - 5)
2307 }
2308 },
2309
2310 ConfigOutputRef::Scores(scores) => match scores.decoder {
2311 DecoderType::Ultralytics => Ok(scores.shape[1]),
2312 DecoderType::ModelPack => Ok(scores.shape[2]),
2313 },
2314 ConfigOutputRef::Segmentation(seg) => Ok(seg.shape[3]),
2315 _ => Err(DecoderError::Internal(
2316 "Attempted to get class count from unsupported config output".to_owned(),
2317 )),
2318 }
2319 }
2320
2321 fn get_class_count(
2323 dshape: &[(DimName, usize)],
2324 protos: Option<usize>,
2325 anchors: Option<usize>,
2326 ) -> Result<usize, DecoderError> {
2327 if dshape.is_empty() {
2328 return Ok(0);
2329 }
2330 for (dim_name, dim_size) in dshape {
2332 if *dim_name == DimName::NumClasses {
2333 return Ok(*dim_size);
2334 }
2335 }
2336
2337 for (dim_name, dim_size) in dshape {
2340 if *dim_name == DimName::NumFeatures {
2341 let protos = protos.unwrap_or(0);
2342 if protos + 4 >= *dim_size {
2343 return Err(DecoderError::InvalidConfig(format!(
2344 "Invalid shape: Yolo num_features {} must be greater than {}",
2345 *dim_size,
2346 protos + 4,
2347 )));
2348 }
2349 return Ok(*dim_size - 4 - protos);
2350 }
2351 }
2352
2353 if let Some(num_anchors) = anchors {
2356 for (dim_name, dim_size) in dshape {
2357 if *dim_name == DimName::NumAnchorsXFeatures {
2358 let anchors_x_features = *dim_size;
2359 if anchors_x_features <= num_anchors * 5 {
2360 return Err(DecoderError::InvalidConfig(format!(
2361 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2362 anchors_x_features,
2363 num_anchors * 5,
2364 )));
2365 }
2366
2367 if !anchors_x_features.is_multiple_of(num_anchors) {
2368 return Err(DecoderError::InvalidConfig(format!(
2369 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2370 anchors_x_features, num_anchors
2371 )));
2372 }
2373 return Ok((anchors_x_features / num_anchors) - 5);
2374 }
2375 }
2376 }
2377 Err(DecoderError::InvalidConfig(
2378 "Cannot determine number of classes from dshape".to_owned(),
2379 ))
2380 }
2381
2382 fn get_protos_count(dshape: &[(DimName, usize)]) -> Option<usize> {
2383 for (dim_name, dim_size) in dshape {
2384 if *dim_name == DimName::NumProtos {
2385 return Some(*dim_size);
2386 }
2387 }
2388 None
2389 }
2390}
2391
/// Decodes raw model outputs into detection boxes and segmentation masks.
///
/// The decoding strategy is selected by `model_type`. The score and IoU
/// thresholds and the optional NMS settings are forwarded to the
/// model-specific decode routines.
#[derive(Debug, Clone, PartialEq)]
pub struct Decoder {
    /// Validated output layout that every `decode_*` method dispatches on.
    model_type: ModelType,
    /// IoU threshold passed to the decode routines (presumably for NMS).
    pub iou_threshold: f32,
    /// Score threshold passed to the decode routines (presumably the minimum
    /// score for a detection to be kept).
    pub score_threshold: f32,
    /// Optional NMS configuration forwarded to the YOLO decoders.
    pub nms: Option<configs::Nms>,
    /// Returned by `normalized_boxes()`; presumably whether box coordinates are
    /// normalized, with `None` meaning unknown. TODO confirm against the builder.
    normalized: Option<bool>,
}
2407
/// A dynamically-dimensioned `ndarray` view over raw integer (quantized) model
/// output, tagged with its element type.
///
/// Views of any static dimensionality convert into this enum through the
/// `From<ArrayView<..>>` impls, which erase the dimension via `into_dyn`.
#[derive(Debug)]
pub enum ArrayViewDQuantized<'a> {
    /// View over `u8` elements.
    UInt8(ArrayViewD<'a, u8>),
    /// View over `i8` elements.
    Int8(ArrayViewD<'a, i8>),
    /// View over `u16` elements.
    UInt16(ArrayViewD<'a, u16>),
    /// View over `i16` elements.
    Int16(ArrayViewD<'a, i16>),
    /// View over `u32` elements.
    UInt32(ArrayViewD<'a, u32>),
    /// View over `i32` elements.
    Int32(ArrayViewD<'a, i32>),
}
2417
2418impl<'a, D> From<ArrayView<'a, u8, D>> for ArrayViewDQuantized<'a>
2419where
2420 D: Dimension,
2421{
2422 fn from(arr: ArrayView<'a, u8, D>) -> Self {
2423 Self::UInt8(arr.into_dyn())
2424 }
2425}
2426
2427impl<'a, D> From<ArrayView<'a, i8, D>> for ArrayViewDQuantized<'a>
2428where
2429 D: Dimension,
2430{
2431 fn from(arr: ArrayView<'a, i8, D>) -> Self {
2432 Self::Int8(arr.into_dyn())
2433 }
2434}
2435
2436impl<'a, D> From<ArrayView<'a, u16, D>> for ArrayViewDQuantized<'a>
2437where
2438 D: Dimension,
2439{
2440 fn from(arr: ArrayView<'a, u16, D>) -> Self {
2441 Self::UInt16(arr.into_dyn())
2442 }
2443}
2444
2445impl<'a, D> From<ArrayView<'a, i16, D>> for ArrayViewDQuantized<'a>
2446where
2447 D: Dimension,
2448{
2449 fn from(arr: ArrayView<'a, i16, D>) -> Self {
2450 Self::Int16(arr.into_dyn())
2451 }
2452}
2453
2454impl<'a, D> From<ArrayView<'a, u32, D>> for ArrayViewDQuantized<'a>
2455where
2456 D: Dimension,
2457{
2458 fn from(arr: ArrayView<'a, u32, D>) -> Self {
2459 Self::UInt32(arr.into_dyn())
2460 }
2461}
2462
2463impl<'a, D> From<ArrayView<'a, i32, D>> for ArrayViewDQuantized<'a>
2464where
2465 D: Dimension,
2466{
2467 fn from(arr: ArrayView<'a, i32, D>) -> Self {
2468 Self::Int32(arr.into_dyn())
2469 }
2470}
2471
2472impl<'a> ArrayViewDQuantized<'a> {
2473 pub fn shape(&self) -> &[usize] {
2487 match self {
2488 ArrayViewDQuantized::UInt8(a) => a.shape(),
2489 ArrayViewDQuantized::Int8(a) => a.shape(),
2490 ArrayViewDQuantized::UInt16(a) => a.shape(),
2491 ArrayViewDQuantized::Int16(a) => a.shape(),
2492 ArrayViewDQuantized::UInt32(a) => a.shape(),
2493 ArrayViewDQuantized::Int32(a) => a.shape(),
2494 }
2495 }
2496}
2497
/// Evaluates `$body` with `$var` bound to the typed view inside an
/// [`ArrayViewDQuantized`].
///
/// `$body` is expanded once per element type (`u8`, `i8`, `u16`, `i16`, `u32`,
/// `i32`), so it must type-check for every one of them. Nesting two invocations
/// expands the inner body for every pair of types. When `$x` is a reference,
/// `$var` is bound by reference (default binding mode).
macro_rules! with_quantized {
    ($x:expr, $var:ident, $body:expr) => {
        match $x {
            ArrayViewDQuantized::UInt8(x) => {
                let $var = x;
                $body
            }
            ArrayViewDQuantized::Int8(x) => {
                let $var = x;
                $body
            }
            ArrayViewDQuantized::UInt16(x) => {
                let $var = x;
                $body
            }
            ArrayViewDQuantized::Int16(x) => {
                let $var = x;
                $body
            }
            ArrayViewDQuantized::UInt32(x) => {
                let $var = x;
                $body
            }
            ArrayViewDQuantized::Int32(x) => {
                let $var = x;
                $body
            }
        }
    };
}
2534
2535impl Decoder {
2536 pub fn model_type(&self) -> &ModelType {
2555 &self.model_type
2556 }
2557
2558 pub fn normalized_boxes(&self) -> Option<bool> {
2584 self.normalized
2585 }
2586
2587 pub fn decode_quantized(
2637 &self,
2638 outputs: &[ArrayViewDQuantized],
2639 output_boxes: &mut Vec<DetectBox>,
2640 output_masks: &mut Vec<Segmentation>,
2641 ) -> Result<(), DecoderError> {
2642 output_boxes.clear();
2643 output_masks.clear();
2644 match &self.model_type {
2645 ModelType::ModelPackSegDet {
2646 boxes,
2647 scores,
2648 segmentation,
2649 } => {
2650 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
2651 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
2652 }
2653 ModelType::ModelPackSegDetSplit {
2654 detection,
2655 segmentation,
2656 } => {
2657 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
2658 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
2659 }
2660 ModelType::ModelPackDet { boxes, scores } => {
2661 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)
2662 }
2663 ModelType::ModelPackDetSplit { detection } => {
2664 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)
2665 }
2666 ModelType::ModelPackSeg { segmentation } => {
2667 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
2668 }
2669 ModelType::YoloDet { boxes } => {
2670 self.decode_yolo_det_quantized(outputs, boxes, output_boxes)
2671 }
2672 ModelType::YoloSegDet { boxes, protos } => self.decode_yolo_segdet_quantized(
2673 outputs,
2674 boxes,
2675 protos,
2676 output_boxes,
2677 output_masks,
2678 ),
2679 ModelType::YoloSplitDet { boxes, scores } => {
2680 self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)
2681 }
2682 ModelType::YoloSplitSegDet {
2683 boxes,
2684 scores,
2685 mask_coeff,
2686 protos,
2687 } => self.decode_yolo_split_segdet_quantized(
2688 outputs,
2689 boxes,
2690 scores,
2691 mask_coeff,
2692 protos,
2693 output_boxes,
2694 output_masks,
2695 ),
2696 ModelType::YoloEndToEndDet { boxes } => {
2697 self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)
2698 }
2699 ModelType::YoloEndToEndSegDet { boxes, protos } => self
2700 .decode_yolo_end_to_end_segdet_quantized(
2701 outputs,
2702 boxes,
2703 protos,
2704 output_boxes,
2705 output_masks,
2706 ),
2707 ModelType::YoloSplitEndToEndDet {
2708 boxes,
2709 scores,
2710 classes,
2711 } => self.decode_yolo_split_end_to_end_det_quantized(
2712 outputs,
2713 boxes,
2714 scores,
2715 classes,
2716 output_boxes,
2717 ),
2718 ModelType::YoloSplitEndToEndSegDet {
2719 boxes,
2720 scores,
2721 classes,
2722 mask_coeff,
2723 protos,
2724 } => self.decode_yolo_split_end_to_end_segdet_quantized(
2725 outputs,
2726 boxes,
2727 scores,
2728 classes,
2729 mask_coeff,
2730 protos,
2731 output_boxes,
2732 output_masks,
2733 ),
2734 }
2735 }
2736
    /// Decodes floating-point model outputs into boxes and, for segmentation
    /// models, masks.
    ///
    /// Both output vectors are cleared before decoding. Work is dispatched on
    /// the configured [`ModelType`]; combined ModelPack models decode detection
    /// before segmentation.
    ///
    /// # Errors
    /// Propagates the first error reported by the model-specific decoder.
    pub fn decode_float<T>(
        &self,
        outputs: &[ArrayViewD<T>],
        output_boxes: &mut Vec<DetectBox>,
        output_masks: &mut Vec<Segmentation>,
    ) -> Result<(), DecoderError>
    where
        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
        f32: AsPrimitive<T>,
    {
        output_boxes.clear();
        output_masks.clear();
        match &self.model_type {
            // Detection first, then segmentation; an error in either aborts the call.
            ModelType::ModelPackSegDet {
                boxes,
                scores,
                segmentation,
            } => {
                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
            }
            ModelType::ModelPackSegDetSplit {
                detection,
                segmentation,
            } => {
                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
            }
            ModelType::ModelPackDet { boxes, scores } => {
                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
            }
            ModelType::ModelPackDetSplit { detection } => {
                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
            }
            ModelType::ModelPackSeg { segmentation } => {
                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
            }
            ModelType::YoloDet { boxes } => {
                self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
            }
            ModelType::YoloSegDet { boxes, protos } => {
                self.decode_yolo_segdet_float(outputs, boxes, protos, output_boxes, output_masks)?;
            }
            ModelType::YoloSplitDet { boxes, scores } => {
                self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                self.decode_yolo_split_segdet_float(
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                )?;
            }
            ModelType::YoloEndToEndDet { boxes } => {
                self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                self.decode_yolo_end_to_end_segdet_float(
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_masks,
                )?;
            }
            ModelType::YoloSplitEndToEndDet {
                boxes,
                scores,
                classes,
            } => {
                self.decode_yolo_split_end_to_end_det_float(
                    outputs,
                    boxes,
                    scores,
                    classes,
                    output_boxes,
                )?;
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                self.decode_yolo_split_end_to_end_segdet_float(
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                )?;
            }
        }
        Ok(())
    }
2901
2902 pub fn decode_quantized_proto(
2909 &self,
2910 outputs: &[ArrayViewDQuantized],
2911 output_boxes: &mut Vec<DetectBox>,
2912 ) -> Result<Option<ProtoData>, DecoderError> {
2913 output_boxes.clear();
2914 match &self.model_type {
2915 ModelType::ModelPackSegDet { .. }
2917 | ModelType::ModelPackSegDetSplit { .. }
2918 | ModelType::ModelPackDet { .. }
2919 | ModelType::ModelPackDetSplit { .. }
2920 | ModelType::ModelPackSeg { .. }
2921 | ModelType::YoloDet { .. }
2922 | ModelType::YoloSplitDet { .. }
2923 | ModelType::YoloEndToEndDet { .. }
2924 | ModelType::YoloSplitEndToEndDet { .. } => Ok(None),
2925
2926 ModelType::YoloSegDet { boxes, protos } => {
2927 let proto =
2928 self.decode_yolo_segdet_quantized_proto(outputs, boxes, protos, output_boxes)?;
2929 Ok(Some(proto))
2930 }
2931 ModelType::YoloSplitSegDet {
2932 boxes,
2933 scores,
2934 mask_coeff,
2935 protos,
2936 } => {
2937 let proto = self.decode_yolo_split_segdet_quantized_proto(
2938 outputs,
2939 boxes,
2940 scores,
2941 mask_coeff,
2942 protos,
2943 output_boxes,
2944 )?;
2945 Ok(Some(proto))
2946 }
2947 ModelType::YoloEndToEndSegDet { boxes, protos } => {
2948 let proto = self.decode_yolo_end_to_end_segdet_quantized_proto(
2949 outputs,
2950 boxes,
2951 protos,
2952 output_boxes,
2953 )?;
2954 Ok(Some(proto))
2955 }
2956 ModelType::YoloSplitEndToEndSegDet {
2957 boxes,
2958 scores,
2959 classes,
2960 mask_coeff,
2961 protos,
2962 } => {
2963 let proto = self.decode_yolo_split_end_to_end_segdet_quantized_proto(
2964 outputs,
2965 boxes,
2966 scores,
2967 classes,
2968 mask_coeff,
2969 protos,
2970 output_boxes,
2971 )?;
2972 Ok(Some(proto))
2973 }
2974 }
2975 }
2976
2977 pub fn decode_float_proto<T>(
2983 &self,
2984 outputs: &[ArrayViewD<T>],
2985 output_boxes: &mut Vec<DetectBox>,
2986 ) -> Result<Option<ProtoData>, DecoderError>
2987 where
2988 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
2989 f32: AsPrimitive<T>,
2990 {
2991 output_boxes.clear();
2992 match &self.model_type {
2993 ModelType::ModelPackSegDet { .. }
2995 | ModelType::ModelPackSegDetSplit { .. }
2996 | ModelType::ModelPackDet { .. }
2997 | ModelType::ModelPackDetSplit { .. }
2998 | ModelType::ModelPackSeg { .. }
2999 | ModelType::YoloDet { .. }
3000 | ModelType::YoloSplitDet { .. }
3001 | ModelType::YoloEndToEndDet { .. }
3002 | ModelType::YoloSplitEndToEndDet { .. } => Ok(None),
3003
3004 ModelType::YoloSegDet { boxes, protos } => {
3005 let proto =
3006 self.decode_yolo_segdet_float_proto(outputs, boxes, protos, output_boxes)?;
3007 Ok(Some(proto))
3008 }
3009 ModelType::YoloSplitSegDet {
3010 boxes,
3011 scores,
3012 mask_coeff,
3013 protos,
3014 } => {
3015 let proto = self.decode_yolo_split_segdet_float_proto(
3016 outputs,
3017 boxes,
3018 scores,
3019 mask_coeff,
3020 protos,
3021 output_boxes,
3022 )?;
3023 Ok(Some(proto))
3024 }
3025 ModelType::YoloEndToEndSegDet { boxes, protos } => {
3026 let proto = self.decode_yolo_end_to_end_segdet_float_proto(
3027 outputs,
3028 boxes,
3029 protos,
3030 output_boxes,
3031 )?;
3032 Ok(Some(proto))
3033 }
3034 ModelType::YoloSplitEndToEndSegDet {
3035 boxes,
3036 scores,
3037 classes,
3038 mask_coeff,
3039 protos,
3040 } => {
3041 let proto = self.decode_yolo_split_end_to_end_segdet_float_proto(
3042 outputs,
3043 boxes,
3044 scores,
3045 classes,
3046 mask_coeff,
3047 protos,
3048 output_boxes,
3049 )?;
3050 Ok(Some(proto))
3051 }
3052 }
3053 }
3054
3055 fn decode_modelpack_det_quantized(
3056 &self,
3057 outputs: &[ArrayViewDQuantized],
3058 boxes: &configs::Boxes,
3059 scores: &configs::Scores,
3060 output_boxes: &mut Vec<DetectBox>,
3061 ) -> Result<(), DecoderError> {
3062 let (boxes_tensor, ind) =
3063 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
3064 let (scores_tensor, _) =
3065 Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &[ind])?;
3066 let quant_boxes = boxes
3067 .quantization
3068 .map(Quantization::from)
3069 .unwrap_or_default();
3070 let quant_scores = scores
3071 .quantization
3072 .map(Quantization::from)
3073 .unwrap_or_default();
3074
3075 with_quantized!(boxes_tensor, b, {
3076 with_quantized!(scores_tensor, s, {
3077 let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
3078 let boxes_tensor = boxes_tensor.slice(s![0, .., 0, ..]);
3079
3080 let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
3081 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3082 decode_modelpack_det(
3083 (boxes_tensor, quant_boxes),
3084 (scores_tensor, quant_scores),
3085 self.score_threshold,
3086 self.iou_threshold,
3087 output_boxes,
3088 );
3089 });
3090 });
3091
3092 Ok(())
3093 }
3094
    /// Decodes a quantized ModelPack segmentation output into one `Segmentation`
    /// that covers the whole `[0, 1] x [0, 1]` box.
    ///
    /// Raw integers are mapped onto `u8` by a fixed per-type rule: signed types
    /// are first shifted into the unsigned range, and wider types keep only
    /// their most significant byte. The output's quantization parameters are
    /// not consulted here.
    ///
    /// # Errors
    /// Propagates any error from `find_outputs_with_shape_quantized`.
    fn decode_modelpack_seg_quantized(
        &self,
        outputs: &[ArrayViewDQuantized],
        segmentation: &configs::Segmentation,
        output_masks: &mut Vec<Segmentation>,
    ) -> Result<(), DecoderError> {
        let (seg, _) = Self::find_outputs_with_shape_quantized(&segmentation.shape, outputs, &[])?;

        // Reorder axes via `swap_axes_if_needed`, keep index 0 of the first axis,
        // then map every element through `$body`.
        macro_rules! modelpack_seg {
            ($seg:expr, $body:expr) => {{
                let seg = Self::swap_axes_if_needed($seg, segmentation.into());
                let seg = seg.slice(s![0, .., .., ..]);
                seg.mapv($body)
            }};
        }
        use ArrayViewDQuantized::*;
        let seg = match seg {
            UInt8(s) => {
                modelpack_seg!(s, |x| x)
            }
            // Shift by 128 so i8::MIN maps to 0 and i8::MAX to 255.
            Int8(s) => {
                modelpack_seg!(s, |x| (x as i16 + 128) as u8)
            }
            // Keep the most significant byte.
            UInt16(s) => {
                modelpack_seg!(s, |x| (x >> 8) as u8)
            }
            // Shift into the u16 range, then keep the most significant byte.
            Int16(s) => {
                modelpack_seg!(s, |x| ((x as i32 + 32768) >> 8) as u8)
            }
            UInt32(s) => {
                modelpack_seg!(s, |x| (x >> 24) as u8)
            }
            Int32(s) => {
                modelpack_seg!(s, |x| ((x as i64 + 2147483648) >> 24) as u8)
            }
        };

        // Full-frame box (presumably normalized image coordinates).
        output_masks.push(Segmentation {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.0,
            ymax: 1.0,
            segmentation: seg,
        });
        Ok(())
    }
3141
    /// Decodes quantized ModelPack split-detection outputs into `output_boxes`.
    ///
    /// Each detection tensor is converted to `f32` (dequantized with its
    /// config's quantization when present, otherwise cast directly), and then
    /// the float split decoder runs on the result.
    ///
    /// # Errors
    /// Returns `DecoderError::InvalidConfig` when a detection config has no
    /// anchors, and propagates errors from `match_outputs_to_detect_quantized`.
    fn decode_modelpack_det_split_quantized(
        &self,
        outputs: &[ArrayViewDQuantized],
        detection: &[configs::Detection],
        output_boxes: &mut Vec<DetectBox>,
    ) -> Result<(), DecoderError> {
        // The float decoder receives already-dequantized data, hence `quantization: None`.
        let new_detection = detection
            .iter()
            .map(|x| match &x.anchors {
                None => Err(DecoderError::InvalidConfig(
                    "ModelPack Split Detection missing anchors".to_string(),
                )),
                Some(a) => Ok(ModelPackDetectionConfig {
                    anchors: a.clone(),
                    quantization: None,
                }),
            })
            .collect::<Result<Vec<_>, _>>()?;
        let new_outputs = Self::match_outputs_to_detect_quantized(detection, outputs)?;

        // Reorder axes, keep index 0 of the first axis, and convert to f32.
        macro_rules! dequant_output {
            ($det_tensor:expr, $detection:expr) => {{
                let det_tensor = Self::swap_axes_if_needed($det_tensor, $detection.into());
                let det_tensor = det_tensor.slice(s![0, .., .., ..]);
                if let Some(q) = $detection.quantization {
                    dequantize_ndarray(det_tensor, q.into())
                } else {
                    det_tensor.map(|x| *x as f32)
                }
            }};
        }

        // Outputs are paired with configs positionally, as returned by
        // `match_outputs_to_detect_quantized`.
        let new_outputs = new_outputs
            .iter()
            .zip(detection)
            .map(|(det_tensor, detection)| {
                with_quantized!(det_tensor, d, dequant_output!(d, detection))
            })
            .collect::<Vec<_>>();

        let new_outputs_view = new_outputs
            .iter()
            .map(|d: &Array3<f32>| d.view())
            .collect::<Vec<_>>();
        decode_modelpack_split_float(
            &new_outputs_view,
            &new_detection,
            self.score_threshold,
            self.iou_threshold,
            output_boxes,
        );
        Ok(())
    }
3195
3196 fn decode_yolo_det_quantized(
3197 &self,
3198 outputs: &[ArrayViewDQuantized],
3199 boxes: &configs::Detection,
3200 output_boxes: &mut Vec<DetectBox>,
3201 ) -> Result<(), DecoderError> {
3202 let (boxes_tensor, _) =
3203 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
3204 let quant_boxes = boxes
3205 .quantization
3206 .map(Quantization::from)
3207 .unwrap_or_default();
3208
3209 with_quantized!(boxes_tensor, b, {
3210 let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
3211 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3212 decode_yolo_det(
3213 (boxes_tensor, quant_boxes),
3214 self.score_threshold,
3215 self.iou_threshold,
3216 self.nms,
3217 output_boxes,
3218 );
3219 });
3220
3221 Ok(())
3222 }
3223
3224 fn decode_yolo_segdet_quantized(
3225 &self,
3226 outputs: &[ArrayViewDQuantized],
3227 boxes: &configs::Detection,
3228 protos: &configs::Protos,
3229 output_boxes: &mut Vec<DetectBox>,
3230 output_masks: &mut Vec<Segmentation>,
3231 ) -> Result<(), DecoderError> {
3232 let (boxes_tensor, ind) =
3233 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
3234 let (protos_tensor, _) =
3235 Self::find_outputs_with_shape_quantized(&protos.shape, outputs, &[ind])?;
3236
3237 let quant_boxes = boxes
3238 .quantization
3239 .map(Quantization::from)
3240 .unwrap_or_default();
3241 let quant_protos = protos
3242 .quantization
3243 .map(Quantization::from)
3244 .unwrap_or_default();
3245
3246 with_quantized!(boxes_tensor, b, {
3247 with_quantized!(protos_tensor, p, {
3248 let box_tensor = Self::swap_axes_if_needed(b, boxes.into());
3249 let box_tensor = box_tensor.slice(s![0, .., ..]);
3250
3251 let protos_tensor = Self::swap_axes_if_needed(p, protos.into());
3252 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3253 let protos_tensor = Self::protos_to_hwc(protos_tensor, protos);
3254 decode_yolo_segdet_quant(
3255 (box_tensor, quant_boxes),
3256 (protos_tensor, quant_protos),
3257 self.score_threshold,
3258 self.iou_threshold,
3259 self.nms,
3260 output_boxes,
3261 output_masks,
3262 )
3263 })
3264 })
3265 }
3266
3267 fn decode_yolo_split_det_quantized(
3268 &self,
3269 outputs: &[ArrayViewDQuantized],
3270 boxes: &configs::Boxes,
3271 scores: &configs::Scores,
3272 output_boxes: &mut Vec<DetectBox>,
3273 ) -> Result<(), DecoderError> {
3274 let (boxes_tensor, ind) =
3275 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
3276 let (scores_tensor, _) =
3277 Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &[ind])?;
3278 let quant_boxes = boxes
3279 .quantization
3280 .map(Quantization::from)
3281 .unwrap_or_default();
3282 let quant_scores = scores
3283 .quantization
3284 .map(Quantization::from)
3285 .unwrap_or_default();
3286
3287 with_quantized!(boxes_tensor, b, {
3288 with_quantized!(scores_tensor, s, {
3289 let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
3290 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3291
3292 let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
3293 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3294 decode_yolo_split_det_quant(
3295 (boxes_tensor, quant_boxes),
3296 (scores_tensor, quant_scores),
3297 self.score_threshold,
3298 self.iou_threshold,
3299 self.nms,
3300 output_boxes,
3301 );
3302 });
3303 });
3304
3305 Ok(())
3306 }
3307
3308 #[allow(clippy::too_many_arguments)]
3309 fn decode_yolo_split_segdet_quantized(
3310 &self,
3311 outputs: &[ArrayViewDQuantized],
3312 boxes: &configs::Boxes,
3313 scores: &configs::Scores,
3314 mask_coeff: &configs::MaskCoefficients,
3315 protos: &configs::Protos,
3316 output_boxes: &mut Vec<DetectBox>,
3317 output_masks: &mut Vec<Segmentation>,
3318 ) -> Result<(), DecoderError> {
3319 let quant_boxes = boxes
3320 .quantization
3321 .map(Quantization::from)
3322 .unwrap_or_default();
3323 let quant_scores = scores
3324 .quantization
3325 .map(Quantization::from)
3326 .unwrap_or_default();
3327 let quant_masks = mask_coeff
3328 .quantization
3329 .map(Quantization::from)
3330 .unwrap_or_default();
3331 let quant_protos = protos
3332 .quantization
3333 .map(Quantization::from)
3334 .unwrap_or_default();
3335
3336 let mut skip = vec![];
3337
3338 let (boxes_tensor, ind) =
3339 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &skip)?;
3340 skip.push(ind);
3341
3342 let (scores_tensor, ind) =
3343 Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &skip)?;
3344 skip.push(ind);
3345
3346 let (mask_tensor, ind) =
3347 Self::find_outputs_with_shape_quantized(&mask_coeff.shape, outputs, &skip)?;
3348 skip.push(ind);
3349
3350 let (protos_tensor, _) =
3351 Self::find_outputs_with_shape_quantized(&protos.shape, outputs, &skip)?;
3352
3353 let boxes = with_quantized!(boxes_tensor, b, {
3354 with_quantized!(scores_tensor, s, {
3355 let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
3356 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3357
3358 let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
3359 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3360 impl_yolo_split_segdet_quant_get_boxes::<XYWH, _, _>(
3361 (boxes_tensor, quant_boxes),
3362 (scores_tensor, quant_scores),
3363 self.score_threshold,
3364 self.iou_threshold,
3365 self.nms,
3366 output_boxes.capacity(),
3367 )
3368 })
3369 });
3370
3371 with_quantized!(mask_tensor, m, {
3372 with_quantized!(protos_tensor, p, {
3373 let mask_tensor = Self::swap_axes_if_needed(m, mask_coeff.into());
3374 let mask_tensor = mask_tensor.slice(s![0, .., ..]);
3375
3376 let protos_tensor = Self::swap_axes_if_needed(p, protos.into());
3377 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3378 let protos_tensor = Self::protos_to_hwc(protos_tensor, protos);
3379 impl_yolo_split_segdet_quant_process_masks::<_, _>(
3380 boxes,
3381 (mask_tensor, quant_masks),
3382 (protos_tensor, quant_protos),
3383 output_boxes,
3384 output_masks,
3385 )
3386 })
3387 })
3388 }
3389
3390 fn decode_modelpack_det_split_float<D>(
3391 &self,
3392 outputs: &[ArrayViewD<D>],
3393 detection: &[configs::Detection],
3394 output_boxes: &mut Vec<DetectBox>,
3395 ) -> Result<(), DecoderError>
3396 where
3397 D: AsPrimitive<f32>,
3398 {
3399 let new_detection = detection
3400 .iter()
3401 .map(|x| match &x.anchors {
3402 None => Err(DecoderError::InvalidConfig(
3403 "ModelPack Split Detection missing anchors".to_string(),
3404 )),
3405 Some(a) => Ok(ModelPackDetectionConfig {
3406 anchors: a.clone(),
3407 quantization: None,
3408 }),
3409 })
3410 .collect::<Result<Vec<_>, _>>()?;
3411
3412 let new_outputs = Self::match_outputs_to_detect(detection, outputs)?;
3413 let new_outputs = new_outputs
3414 .into_iter()
3415 .map(|x| x.slice(s![0, .., .., ..]))
3416 .collect::<Vec<_>>();
3417
3418 decode_modelpack_split_float(
3419 &new_outputs,
3420 &new_detection,
3421 self.score_threshold,
3422 self.iou_threshold,
3423 output_boxes,
3424 );
3425 Ok(())
3426 }
3427
3428 fn decode_modelpack_seg_float<T>(
3429 &self,
3430 outputs: &[ArrayViewD<T>],
3431 segmentation: &configs::Segmentation,
3432 output_masks: &mut Vec<Segmentation>,
3433 ) -> Result<(), DecoderError>
3434 where
3435 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
3436 f32: AsPrimitive<T>,
3437 {
3438 let (seg, _) = Self::find_outputs_with_shape(&segmentation.shape, outputs, &[])?;
3439
3440 let seg = Self::swap_axes_if_needed(seg, segmentation.into());
3441 let seg = seg.slice(s![0, .., .., ..]);
3442 let u8_max = 255.0_f32.as_();
3443 let max = *seg.max().unwrap_or(&u8_max);
3444 let min = *seg.min().unwrap_or(&0.0_f32.as_());
3445 let seg = seg.mapv(|x| ((x - min) / (max - min) * u8_max).as_());
3446 output_masks.push(Segmentation {
3447 xmin: 0.0,
3448 ymin: 0.0,
3449 xmax: 1.0,
3450 ymax: 1.0,
3451 segmentation: seg,
3452 });
3453 Ok(())
3454 }
3455
3456 fn decode_modelpack_det_float<T>(
3457 &self,
3458 outputs: &[ArrayViewD<T>],
3459 boxes: &configs::Boxes,
3460 scores: &configs::Scores,
3461 output_boxes: &mut Vec<DetectBox>,
3462 ) -> Result<(), DecoderError>
3463 where
3464 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3465 f32: AsPrimitive<T>,
3466 {
3467 let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
3468
3469 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3470 let boxes_tensor = boxes_tensor.slice(s![0, .., 0, ..]);
3471
3472 let (scores_tensor, _) = Self::find_outputs_with_shape(&scores.shape, outputs, &[ind])?;
3473 let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
3474 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3475
3476 decode_modelpack_float(
3477 boxes_tensor,
3478 scores_tensor,
3479 self.score_threshold,
3480 self.iou_threshold,
3481 output_boxes,
3482 );
3483 Ok(())
3484 }
3485
3486 fn decode_yolo_det_float<T>(
3487 &self,
3488 outputs: &[ArrayViewD<T>],
3489 boxes: &configs::Detection,
3490 output_boxes: &mut Vec<DetectBox>,
3491 ) -> Result<(), DecoderError>
3492 where
3493 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3494 f32: AsPrimitive<T>,
3495 {
3496 let (boxes_tensor, _) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
3497
3498 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3499 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3500 decode_yolo_det_float(
3501 boxes_tensor,
3502 self.score_threshold,
3503 self.iou_threshold,
3504 self.nms,
3505 output_boxes,
3506 );
3507 Ok(())
3508 }
3509
3510 fn decode_yolo_segdet_float<T>(
3511 &self,
3512 outputs: &[ArrayViewD<T>],
3513 boxes: &configs::Detection,
3514 protos: &configs::Protos,
3515 output_boxes: &mut Vec<DetectBox>,
3516 output_masks: &mut Vec<Segmentation>,
3517 ) -> Result<(), DecoderError>
3518 where
3519 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3520 f32: AsPrimitive<T>,
3521 {
3522 let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
3523
3524 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3525 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3526
3527 let (protos_tensor, _) = Self::find_outputs_with_shape(&protos.shape, outputs, &[ind])?;
3528
3529 let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos.into());
3530 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3531 let protos_tensor = Self::protos_to_hwc(protos_tensor, protos);
3532 decode_yolo_segdet_float(
3533 boxes_tensor,
3534 protos_tensor,
3535 self.score_threshold,
3536 self.iou_threshold,
3537 self.nms,
3538 output_boxes,
3539 output_masks,
3540 )
3541 }
3542
3543 fn decode_yolo_split_det_float<T>(
3544 &self,
3545 outputs: &[ArrayViewD<T>],
3546 boxes: &configs::Boxes,
3547 scores: &configs::Scores,
3548 output_boxes: &mut Vec<DetectBox>,
3549 ) -> Result<(), DecoderError>
3550 where
3551 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3552 f32: AsPrimitive<T>,
3553 {
3554 let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
3555 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3556 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3557
3558 let (scores_tensor, _) = Self::find_outputs_with_shape(&scores.shape, outputs, &[ind])?;
3559
3560 let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
3561 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3562
3563 decode_yolo_split_det_float(
3564 boxes_tensor,
3565 scores_tensor,
3566 self.score_threshold,
3567 self.iou_threshold,
3568 self.nms,
3569 output_boxes,
3570 );
3571 Ok(())
3572 }
3573
3574 #[allow(clippy::too_many_arguments)]
3575 fn decode_yolo_split_segdet_float<T>(
3576 &self,
3577 outputs: &[ArrayViewD<T>],
3578 boxes: &configs::Boxes,
3579 scores: &configs::Scores,
3580 mask_coeff: &configs::MaskCoefficients,
3581 protos: &configs::Protos,
3582 output_boxes: &mut Vec<DetectBox>,
3583 output_masks: &mut Vec<Segmentation>,
3584 ) -> Result<(), DecoderError>
3585 where
3586 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3587 f32: AsPrimitive<T>,
3588 {
3589 let mut skip = vec![];
3590 let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &skip)?;
3591
3592 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3593 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3594 skip.push(ind);
3595
3596 let (scores_tensor, ind) = Self::find_outputs_with_shape(&scores.shape, outputs, &skip)?;
3597
3598 let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
3599 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3600 skip.push(ind);
3601
3602 let (mask_tensor, ind) = Self::find_outputs_with_shape(&mask_coeff.shape, outputs, &skip)?;
3603 let mask_tensor = Self::swap_axes_if_needed(mask_tensor, mask_coeff.into());
3604 let mask_tensor = mask_tensor.slice(s![0, .., ..]);
3605 skip.push(ind);
3606
3607 let (protos_tensor, _) = Self::find_outputs_with_shape(&protos.shape, outputs, &skip)?;
3608 let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos.into());
3609 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3610 let protos_tensor = Self::protos_to_hwc(protos_tensor, protos);
3611 decode_yolo_split_segdet_float(
3612 boxes_tensor,
3613 scores_tensor,
3614 mask_tensor,
3615 protos_tensor,
3616 self.score_threshold,
3617 self.iou_threshold,
3618 self.nms,
3619 output_boxes,
3620 output_masks,
3621 )
3622 }
3623
3624 fn decode_yolo_end_to_end_det_float<T>(
3630 &self,
3631 outputs: &[ArrayViewD<T>],
3632 boxes_config: &configs::Detection,
3633 output_boxes: &mut Vec<DetectBox>,
3634 ) -> Result<(), DecoderError>
3635 where
3636 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3637 f32: AsPrimitive<T>,
3638 {
3639 let (det_tensor, _) = Self::find_outputs_with_shape(&boxes_config.shape, outputs, &[])?;
3640 let det_tensor = Self::swap_axes_if_needed(det_tensor, boxes_config.into());
3641 let det_tensor = det_tensor.slice(s![0, .., ..]);
3642
3643 crate::yolo::decode_yolo_end_to_end_det_float(
3644 det_tensor,
3645 self.score_threshold,
3646 output_boxes,
3647 )?;
3648 Ok(())
3649 }
3650
3651 fn decode_yolo_end_to_end_segdet_float<T>(
3659 &self,
3660 outputs: &[ArrayViewD<T>],
3661 boxes_config: &configs::Detection,
3662 protos_config: &configs::Protos,
3663 output_boxes: &mut Vec<DetectBox>,
3664 output_masks: &mut Vec<Segmentation>,
3665 ) -> Result<(), DecoderError>
3666 where
3667 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3668 f32: AsPrimitive<T>,
3669 {
3670 if outputs.len() < 2 {
3671 return Err(DecoderError::InvalidShape(
3672 "End-to-end segdet requires detection and protos outputs".to_string(),
3673 ));
3674 }
3675
3676 let (det_tensor, det_ind) =
3677 Self::find_outputs_with_shape(&boxes_config.shape, outputs, &[])?;
3678 let det_tensor = Self::swap_axes_if_needed(det_tensor, boxes_config.into());
3679 let det_tensor = det_tensor.slice(s![0, .., ..]);
3680
3681 let (protos_tensor, _) =
3682 Self::find_outputs_with_shape(&protos_config.shape, outputs, &[det_ind])?;
3683 let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos_config.into());
3684 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3685 let protos_tensor = Self::protos_to_hwc(protos_tensor, protos_config);
3686
3687 crate::yolo::decode_yolo_end_to_end_segdet_float(
3688 det_tensor,
3689 protos_tensor,
3690 self.score_threshold,
3691 output_boxes,
3692 output_masks,
3693 )?;
3694 Ok(())
3695 }
3696
3697 fn decode_yolo_end_to_end_det_quantized(
3700 &self,
3701 outputs: &[ArrayViewDQuantized],
3702 boxes_config: &configs::Detection,
3703 output_boxes: &mut Vec<DetectBox>,
3704 ) -> Result<(), DecoderError> {
3705 let (det_tensor, _) =
3706 Self::find_outputs_with_shape_quantized(&boxes_config.shape, outputs, &[])?;
3707 let quant = boxes_config
3708 .quantization
3709 .map(Quantization::from)
3710 .unwrap_or_default();
3711
3712 with_quantized!(det_tensor, d, {
3713 let d = Self::swap_axes_if_needed(d, boxes_config.into());
3714 let d = d.slice(s![0, .., ..]);
3715 let dequant = d.map(|v| {
3716 let val: f32 = v.as_();
3717 (val - quant.zero_point as f32) * quant.scale
3718 });
3719 crate::yolo::decode_yolo_end_to_end_det_float(
3720 dequant.view(),
3721 self.score_threshold,
3722 output_boxes,
3723 )?;
3724 });
3725 Ok(())
3726 }
3727
3728 #[allow(clippy::too_many_arguments)]
3730 fn decode_yolo_end_to_end_segdet_quantized(
3731 &self,
3732 outputs: &[ArrayViewDQuantized],
3733 boxes_config: &configs::Detection,
3734 protos_config: &configs::Protos,
3735 output_boxes: &mut Vec<DetectBox>,
3736 output_masks: &mut Vec<Segmentation>,
3737 ) -> Result<(), DecoderError> {
3738 let (det_tensor, det_ind) =
3739 Self::find_outputs_with_shape_quantized(&boxes_config.shape, outputs, &[])?;
3740 let (protos_tensor, _) =
3741 Self::find_outputs_with_shape_quantized(&protos_config.shape, outputs, &[det_ind])?;
3742
3743 let quant_det = boxes_config
3744 .quantization
3745 .map(Quantization::from)
3746 .unwrap_or_default();
3747 let quant_protos = protos_config
3748 .quantization
3749 .map(Quantization::from)
3750 .unwrap_or_default();
3751
3752 macro_rules! dequant_3d {
3755 ($tensor:expr, $config:expr, $quant:expr) => {{
3756 with_quantized!($tensor, t, {
3757 let t = Self::swap_axes_if_needed(t, $config.into());
3758 let t = t.slice(s![0, .., ..]);
3759 t.map(|v| {
3760 let val: f32 = v.as_();
3761 (val - $quant.zero_point as f32) * $quant.scale
3762 })
3763 })
3764 }};
3765 }
3766 macro_rules! dequant_4d {
3767 ($tensor:expr, $config:expr, $quant:expr) => {{
3768 with_quantized!($tensor, t, {
3769 let t = Self::swap_axes_if_needed(t, $config.into());
3770 let t = t.slice(s![0, .., .., ..]);
3771 t.map(|v| {
3772 let val: f32 = v.as_();
3773 (val - $quant.zero_point as f32) * $quant.scale
3774 })
3775 })
3776 }};
3777 }
3778
3779 let dequant_d = dequant_3d!(det_tensor, boxes_config, quant_det);
3780 let dequant_p = dequant_4d!(protos_tensor, protos_config, quant_protos);
3781
3782 crate::yolo::decode_yolo_end_to_end_segdet_float(
3783 dequant_d.view(),
3784 dequant_p.view(),
3785 self.score_threshold,
3786 output_boxes,
3787 output_masks,
3788 )?;
3789 Ok(())
3790 }
3791
3792 fn decode_yolo_split_end_to_end_det_float<T>(
3794 &self,
3795 outputs: &[ArrayViewD<T>],
3796 boxes_config: &configs::Boxes,
3797 scores_config: &configs::Scores,
3798 classes_config: &configs::Classes,
3799 output_boxes: &mut Vec<DetectBox>,
3800 ) -> Result<(), DecoderError>
3801 where
3802 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3803 f32: AsPrimitive<T>,
3804 {
3805 let mut skip = vec![];
3806 let (boxes_tensor, ind) =
3807 Self::find_outputs_with_shape(&boxes_config.shape, outputs, &skip)?;
3808 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes_config.into());
3809 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3810 skip.push(ind);
3811
3812 let (scores_tensor, ind) =
3813 Self::find_outputs_with_shape(&scores_config.shape, outputs, &skip)?;
3814 let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores_config.into());
3815 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3816 skip.push(ind);
3817
3818 let (classes_tensor, _) =
3819 Self::find_outputs_with_shape(&classes_config.shape, outputs, &skip)?;
3820 let classes_tensor = Self::swap_axes_if_needed(classes_tensor, classes_config.into());
3821 let classes_tensor = classes_tensor.slice(s![0, .., ..]);
3822
3823 crate::yolo::decode_yolo_split_end_to_end_det_float(
3824 boxes_tensor,
3825 scores_tensor,
3826 classes_tensor,
3827 self.score_threshold,
3828 output_boxes,
3829 )?;
3830 Ok(())
3831 }
3832
3833 #[allow(clippy::too_many_arguments)]
3835 fn decode_yolo_split_end_to_end_segdet_float<T>(
3836 &self,
3837 outputs: &[ArrayViewD<T>],
3838 boxes_config: &configs::Boxes,
3839 scores_config: &configs::Scores,
3840 classes_config: &configs::Classes,
3841 mask_coeff_config: &configs::MaskCoefficients,
3842 protos_config: &configs::Protos,
3843 output_boxes: &mut Vec<DetectBox>,
3844 output_masks: &mut Vec<Segmentation>,
3845 ) -> Result<(), DecoderError>
3846 where
3847 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3848 f32: AsPrimitive<T>,
3849 {
3850 let mut skip = vec![];
3851 let (boxes_tensor, ind) =
3852 Self::find_outputs_with_shape(&boxes_config.shape, outputs, &skip)?;
3853 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes_config.into());
3854 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3855 skip.push(ind);
3856
3857 let (scores_tensor, ind) =
3858 Self::find_outputs_with_shape(&scores_config.shape, outputs, &skip)?;
3859 let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores_config.into());
3860 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3861 skip.push(ind);
3862
3863 let (classes_tensor, ind) =
3864 Self::find_outputs_with_shape(&classes_config.shape, outputs, &skip)?;
3865 let classes_tensor = Self::swap_axes_if_needed(classes_tensor, classes_config.into());
3866 let classes_tensor = classes_tensor.slice(s![0, .., ..]);
3867 skip.push(ind);
3868
3869 let (mask_tensor, ind) =
3870 Self::find_outputs_with_shape(&mask_coeff_config.shape, outputs, &skip)?;
3871 let mask_tensor = Self::swap_axes_if_needed(mask_tensor, mask_coeff_config.into());
3872 let mask_tensor = mask_tensor.slice(s![0, .., ..]);
3873 skip.push(ind);
3874
3875 let (protos_tensor, _) =
3876 Self::find_outputs_with_shape(&protos_config.shape, outputs, &skip)?;
3877 let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos_config.into());
3878 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3879 let protos_tensor = Self::protos_to_hwc(protos_tensor, protos_config);
3880
3881 crate::yolo::decode_yolo_split_end_to_end_segdet_float(
3882 boxes_tensor,
3883 scores_tensor,
3884 classes_tensor,
3885 mask_tensor,
3886 protos_tensor,
3887 self.score_threshold,
3888 output_boxes,
3889 output_masks,
3890 )?;
3891 Ok(())
3892 }
3893
3894 fn decode_yolo_split_end_to_end_det_quantized(
3897 &self,
3898 outputs: &[ArrayViewDQuantized],
3899 boxes_config: &configs::Boxes,
3900 scores_config: &configs::Scores,
3901 classes_config: &configs::Classes,
3902 output_boxes: &mut Vec<DetectBox>,
3903 ) -> Result<(), DecoderError> {
3904 let mut skip = vec![];
3905 let (boxes_tensor, ind) =
3906 Self::find_outputs_with_shape_quantized(&boxes_config.shape, outputs, &skip)?;
3907 skip.push(ind);
3908 let (scores_tensor, ind) =
3909 Self::find_outputs_with_shape_quantized(&scores_config.shape, outputs, &skip)?;
3910 skip.push(ind);
3911 let (classes_tensor, _) =
3912 Self::find_outputs_with_shape_quantized(&classes_config.shape, outputs, &skip)?;
3913
3914 let quant_boxes = boxes_config
3915 .quantization
3916 .map(Quantization::from)
3917 .unwrap_or_default();
3918 let quant_scores = scores_config
3919 .quantization
3920 .map(Quantization::from)
3921 .unwrap_or_default();
3922 let quant_classes = classes_config
3923 .quantization
3924 .map(Quantization::from)
3925 .unwrap_or_default();
3926
3927 macro_rules! dequant_3d {
3930 ($tensor:expr, $config:expr, $quant:expr) => {{
3931 with_quantized!($tensor, t, {
3932 let t = Self::swap_axes_if_needed(t, $config.into());
3933 let t = t.slice(s![0, .., ..]);
3934 t.map(|v| {
3935 let val: f32 = v.as_();
3936 (val - $quant.zero_point as f32) * $quant.scale
3937 })
3938 })
3939 }};
3940 }
3941
3942 let dequant_b = dequant_3d!(boxes_tensor, boxes_config, quant_boxes);
3943 let dequant_s = dequant_3d!(scores_tensor, scores_config, quant_scores);
3944 let dequant_c = dequant_3d!(classes_tensor, classes_config, quant_classes);
3945
3946 crate::yolo::decode_yolo_split_end_to_end_det_float(
3947 dequant_b.view(),
3948 dequant_s.view(),
3949 dequant_c.view(),
3950 self.score_threshold,
3951 output_boxes,
3952 )?;
3953 Ok(())
3954 }
3955
3956 #[allow(clippy::too_many_arguments)]
3958 fn decode_yolo_split_end_to_end_segdet_quantized(
3959 &self,
3960 outputs: &[ArrayViewDQuantized],
3961 boxes_config: &configs::Boxes,
3962 scores_config: &configs::Scores,
3963 classes_config: &configs::Classes,
3964 mask_coeff_config: &configs::MaskCoefficients,
3965 protos_config: &configs::Protos,
3966 output_boxes: &mut Vec<DetectBox>,
3967 output_masks: &mut Vec<Segmentation>,
3968 ) -> Result<(), DecoderError> {
3969 let mut skip = vec![];
3970 let (boxes_tensor, ind) =
3971 Self::find_outputs_with_shape_quantized(&boxes_config.shape, outputs, &skip)?;
3972 skip.push(ind);
3973 let (scores_tensor, ind) =
3974 Self::find_outputs_with_shape_quantized(&scores_config.shape, outputs, &skip)?;
3975 skip.push(ind);
3976 let (classes_tensor, ind) =
3977 Self::find_outputs_with_shape_quantized(&classes_config.shape, outputs, &skip)?;
3978 skip.push(ind);
3979 let (mask_tensor, ind) =
3980 Self::find_outputs_with_shape_quantized(&mask_coeff_config.shape, outputs, &skip)?;
3981 skip.push(ind);
3982 let (protos_tensor, _) =
3983 Self::find_outputs_with_shape_quantized(&protos_config.shape, outputs, &skip)?;
3984
3985 let quant_boxes = boxes_config
3986 .quantization
3987 .map(Quantization::from)
3988 .unwrap_or_default();
3989 let quant_scores = scores_config
3990 .quantization
3991 .map(Quantization::from)
3992 .unwrap_or_default();
3993 let quant_classes = classes_config
3994 .quantization
3995 .map(Quantization::from)
3996 .unwrap_or_default();
3997 let quant_masks = mask_coeff_config
3998 .quantization
3999 .map(Quantization::from)
4000 .unwrap_or_default();
4001 let quant_protos = protos_config
4002 .quantization
4003 .map(Quantization::from)
4004 .unwrap_or_default();
4005
4006 macro_rules! dequant_3d {
4009 ($tensor:expr, $config:expr, $quant:expr) => {{
4010 with_quantized!($tensor, t, {
4011 let t = Self::swap_axes_if_needed(t, $config.into());
4012 let t = t.slice(s![0, .., ..]);
4013 t.map(|v| {
4014 let val: f32 = v.as_();
4015 (val - $quant.zero_point as f32) * $quant.scale
4016 })
4017 })
4018 }};
4019 }
4020 macro_rules! dequant_4d {
4021 ($tensor:expr, $config:expr, $quant:expr) => {{
4022 with_quantized!($tensor, t, {
4023 let t = Self::swap_axes_if_needed(t, $config.into());
4024 let t = t.slice(s![0, .., .., ..]);
4025 t.map(|v| {
4026 let val: f32 = v.as_();
4027 (val - $quant.zero_point as f32) * $quant.scale
4028 })
4029 })
4030 }};
4031 }
4032
4033 let dequant_b = dequant_3d!(boxes_tensor, boxes_config, quant_boxes);
4034 let dequant_s = dequant_3d!(scores_tensor, scores_config, quant_scores);
4035 let dequant_c = dequant_3d!(classes_tensor, classes_config, quant_classes);
4036 let dequant_m = dequant_3d!(mask_tensor, mask_coeff_config, quant_masks);
4037 let dequant_p = dequant_4d!(protos_tensor, protos_config, quant_protos);
4038
4039 crate::yolo::decode_yolo_split_end_to_end_segdet_float(
4040 dequant_b.view(),
4041 dequant_s.view(),
4042 dequant_c.view(),
4043 dequant_m.view(),
4044 dequant_p.view(),
4045 self.score_threshold,
4046 output_boxes,
4047 output_masks,
4048 )?;
4049 Ok(())
4050 }
4051
4052 fn decode_yolo_segdet_quantized_proto(
4057 &self,
4058 outputs: &[ArrayViewDQuantized],
4059 boxes: &configs::Detection,
4060 protos: &configs::Protos,
4061 output_boxes: &mut Vec<DetectBox>,
4062 ) -> Result<ProtoData, DecoderError> {
4063 let (boxes_tensor, ind) =
4064 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
4065 let (protos_tensor, _) =
4066 Self::find_outputs_with_shape_quantized(&protos.shape, outputs, &[ind])?;
4067
4068 let quant_boxes = boxes
4069 .quantization
4070 .map(Quantization::from)
4071 .unwrap_or_default();
4072 let quant_protos = protos
4073 .quantization
4074 .map(Quantization::from)
4075 .unwrap_or_default();
4076
4077 let proto = with_quantized!(boxes_tensor, b, {
4078 with_quantized!(protos_tensor, p, {
4079 let box_tensor = Self::swap_axes_if_needed(b, boxes.into());
4080 let box_tensor = box_tensor.slice(s![0, .., ..]);
4081
4082 let protos_tensor = Self::swap_axes_if_needed(p, protos.into());
4083 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
4084 let protos_tensor = Self::protos_to_hwc(protos_tensor, protos);
4085 crate::yolo::impl_yolo_segdet_quant_proto::<XYWH, _, _>(
4086 (box_tensor, quant_boxes),
4087 (protos_tensor, quant_protos),
4088 self.score_threshold,
4089 self.iou_threshold,
4090 self.nms,
4091 output_boxes,
4092 )
4093 })
4094 });
4095 Ok(proto)
4096 }
4097
4098 fn decode_yolo_segdet_float_proto<T>(
4099 &self,
4100 outputs: &[ArrayViewD<T>],
4101 boxes: &configs::Detection,
4102 protos: &configs::Protos,
4103 output_boxes: &mut Vec<DetectBox>,
4104 ) -> Result<ProtoData, DecoderError>
4105 where
4106 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
4107 f32: AsPrimitive<T>,
4108 {
4109 let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
4110 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
4111 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
4112
4113 let (protos_tensor, _) = Self::find_outputs_with_shape(&protos.shape, outputs, &[ind])?;
4114 let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos.into());
4115 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
4116 let protos_tensor = Self::protos_to_hwc(protos_tensor, protos);
4117
4118 Ok(crate::yolo::impl_yolo_segdet_float_proto::<XYWH, _, _>(
4119 boxes_tensor,
4120 protos_tensor,
4121 self.score_threshold,
4122 self.iou_threshold,
4123 self.nms,
4124 output_boxes,
4125 ))
4126 }
4127
4128 #[allow(clippy::too_many_arguments)]
4129 fn decode_yolo_split_segdet_quantized_proto(
4130 &self,
4131 outputs: &[ArrayViewDQuantized],
4132 boxes: &configs::Boxes,
4133 scores: &configs::Scores,
4134 mask_coeff: &configs::MaskCoefficients,
4135 protos: &configs::Protos,
4136 output_boxes: &mut Vec<DetectBox>,
4137 ) -> Result<ProtoData, DecoderError> {
4138 let quant_boxes = boxes
4139 .quantization
4140 .map(Quantization::from)
4141 .unwrap_or_default();
4142 let quant_scores = scores
4143 .quantization
4144 .map(Quantization::from)
4145 .unwrap_or_default();
4146 let quant_masks = mask_coeff
4147 .quantization
4148 .map(Quantization::from)
4149 .unwrap_or_default();
4150 let quant_protos = protos
4151 .quantization
4152 .map(Quantization::from)
4153 .unwrap_or_default();
4154
4155 let mut skip = vec![];
4156
4157 let (boxes_tensor, ind) =
4158 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &skip)?;
4159 skip.push(ind);
4160
4161 let (scores_tensor, ind) =
4162 Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &skip)?;
4163 skip.push(ind);
4164
4165 let (mask_tensor, ind) =
4166 Self::find_outputs_with_shape_quantized(&mask_coeff.shape, outputs, &skip)?;
4167 skip.push(ind);
4168
4169 let (protos_tensor, _) =
4170 Self::find_outputs_with_shape_quantized(&protos.shape, outputs, &skip)?;
4171
4172 let det_indices = with_quantized!(boxes_tensor, b, {
4174 with_quantized!(scores_tensor, s, {
4175 let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
4176 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
4177
4178 let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
4179 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
4180
4181 impl_yolo_split_segdet_quant_get_boxes::<XYWH, _, _>(
4182 (boxes_tensor, quant_boxes),
4183 (scores_tensor, quant_scores),
4184 self.score_threshold,
4185 self.iou_threshold,
4186 self.nms,
4187 output_boxes.capacity(),
4188 )
4189 })
4190 });
4191
4192 let proto = with_quantized!(mask_tensor, m, {
4194 with_quantized!(protos_tensor, p, {
4195 let mask_tensor = Self::swap_axes_if_needed(m, mask_coeff.into());
4196 let mask_tensor = mask_tensor.slice(s![0, .., ..]);
4197 let mask_tensor = mask_tensor.reversed_axes();
4198
4199 let protos_tensor = Self::swap_axes_if_needed(p, protos.into());
4200 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
4201 let protos_tensor = Self::protos_to_hwc(protos_tensor, protos);
4202
4203 crate::yolo::extract_proto_data_quant(
4204 det_indices,
4205 mask_tensor,
4206 quant_masks,
4207 protos_tensor,
4208 quant_protos,
4209 output_boxes,
4210 )
4211 })
4212 });
4213 Ok(proto)
4214 }
4215
4216 #[allow(clippy::too_many_arguments)]
4217 fn decode_yolo_split_segdet_float_proto<T>(
4218 &self,
4219 outputs: &[ArrayViewD<T>],
4220 boxes: &configs::Boxes,
4221 scores: &configs::Scores,
4222 mask_coeff: &configs::MaskCoefficients,
4223 protos: &configs::Protos,
4224 output_boxes: &mut Vec<DetectBox>,
4225 ) -> Result<ProtoData, DecoderError>
4226 where
4227 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
4228 f32: AsPrimitive<T>,
4229 {
4230 let mut skip = vec![];
4231 let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &skip)?;
4232 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
4233 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
4234 skip.push(ind);
4235
4236 let (scores_tensor, ind) = Self::find_outputs_with_shape(&scores.shape, outputs, &skip)?;
4237 let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
4238 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
4239 skip.push(ind);
4240
4241 let (mask_tensor, ind) = Self::find_outputs_with_shape(&mask_coeff.shape, outputs, &skip)?;
4242 let mask_tensor = Self::swap_axes_if_needed(mask_tensor, mask_coeff.into());
4243 let mask_tensor = mask_tensor.slice(s![0, .., ..]);
4244 skip.push(ind);
4245
4246 let (protos_tensor, _) = Self::find_outputs_with_shape(&protos.shape, outputs, &skip)?;
4247 let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos.into());
4248 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
4249 let protos_tensor = Self::protos_to_hwc(protos_tensor, protos);
4250
4251 Ok(crate::yolo::impl_yolo_split_segdet_float_proto::<
4252 XYWH,
4253 _,
4254 _,
4255 _,
4256 _,
4257 >(
4258 boxes_tensor,
4259 scores_tensor,
4260 mask_tensor,
4261 protos_tensor,
4262 self.score_threshold,
4263 self.iou_threshold,
4264 self.nms,
4265 output_boxes,
4266 ))
4267 }
4268
4269 fn decode_yolo_end_to_end_segdet_float_proto<T>(
4270 &self,
4271 outputs: &[ArrayViewD<T>],
4272 boxes_config: &configs::Detection,
4273 protos_config: &configs::Protos,
4274 output_boxes: &mut Vec<DetectBox>,
4275 ) -> Result<ProtoData, DecoderError>
4276 where
4277 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
4278 f32: AsPrimitive<T>,
4279 {
4280 if outputs.len() < 2 {
4281 return Err(DecoderError::InvalidShape(
4282 "End-to-end segdet requires detection and protos outputs".to_string(),
4283 ));
4284 }
4285
4286 let (det_tensor, det_ind) =
4287 Self::find_outputs_with_shape(&boxes_config.shape, outputs, &[])?;
4288 let det_tensor = Self::swap_axes_if_needed(det_tensor, boxes_config.into());
4289 let det_tensor = det_tensor.slice(s![0, .., ..]);
4290
4291 let (protos_tensor, _) =
4292 Self::find_outputs_with_shape(&protos_config.shape, outputs, &[det_ind])?;
4293 let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos_config.into());
4294 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
4295 let protos_tensor = Self::protos_to_hwc(protos_tensor, protos_config);
4296
4297 crate::yolo::decode_yolo_end_to_end_segdet_float_proto(
4298 det_tensor,
4299 protos_tensor,
4300 self.score_threshold,
4301 output_boxes,
4302 )
4303 }
4304
4305 fn decode_yolo_end_to_end_segdet_quantized_proto(
4306 &self,
4307 outputs: &[ArrayViewDQuantized],
4308 boxes_config: &configs::Detection,
4309 protos_config: &configs::Protos,
4310 output_boxes: &mut Vec<DetectBox>,
4311 ) -> Result<ProtoData, DecoderError> {
4312 let (det_tensor, det_ind) =
4313 Self::find_outputs_with_shape_quantized(&boxes_config.shape, outputs, &[])?;
4314 let (protos_tensor, _) =
4315 Self::find_outputs_with_shape_quantized(&protos_config.shape, outputs, &[det_ind])?;
4316
4317 let quant_det = boxes_config
4318 .quantization
4319 .map(Quantization::from)
4320 .unwrap_or_default();
4321 let quant_protos = protos_config
4322 .quantization
4323 .map(Quantization::from)
4324 .unwrap_or_default();
4325
4326 macro_rules! dequant_3d {
4329 ($tensor:expr, $config:expr, $quant:expr) => {{
4330 with_quantized!($tensor, t, {
4331 let t = Self::swap_axes_if_needed(t, $config.into());
4332 let t = t.slice(s![0, .., ..]);
4333 t.map(|v| {
4334 let val: f32 = v.as_();
4335 (val - $quant.zero_point as f32) * $quant.scale
4336 })
4337 })
4338 }};
4339 }
4340 macro_rules! dequant_4d {
4341 ($tensor:expr, $config:expr, $quant:expr) => {{
4342 with_quantized!($tensor, t, {
4343 let t = Self::swap_axes_if_needed(t, $config.into());
4344 let t = t.slice(s![0, .., .., ..]);
4345 t.map(|v| {
4346 let val: f32 = v.as_();
4347 (val - $quant.zero_point as f32) * $quant.scale
4348 })
4349 })
4350 }};
4351 }
4352
4353 let dequant_d = dequant_3d!(det_tensor, boxes_config, quant_det);
4354 let dequant_p = dequant_4d!(protos_tensor, protos_config, quant_protos);
4355
4356 let proto = crate::yolo::decode_yolo_end_to_end_segdet_float_proto(
4357 dequant_d.view(),
4358 dequant_p.view(),
4359 self.score_threshold,
4360 output_boxes,
4361 )?;
4362 Ok(proto)
4363 }
4364
4365 #[allow(clippy::too_many_arguments)]
4366 fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
4367 &self,
4368 outputs: &[ArrayViewD<T>],
4369 boxes_config: &configs::Boxes,
4370 scores_config: &configs::Scores,
4371 classes_config: &configs::Classes,
4372 mask_coeff_config: &configs::MaskCoefficients,
4373 protos_config: &configs::Protos,
4374 output_boxes: &mut Vec<DetectBox>,
4375 ) -> Result<ProtoData, DecoderError>
4376 where
4377 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
4378 f32: AsPrimitive<T>,
4379 {
4380 let mut skip = vec![];
4381 let (boxes_tensor, ind) =
4382 Self::find_outputs_with_shape(&boxes_config.shape, outputs, &skip)?;
4383 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes_config.into());
4384 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
4385 skip.push(ind);
4386
4387 let (scores_tensor, ind) =
4388 Self::find_outputs_with_shape(&scores_config.shape, outputs, &skip)?;
4389 let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores_config.into());
4390 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
4391 skip.push(ind);
4392
4393 let (classes_tensor, ind) =
4394 Self::find_outputs_with_shape(&classes_config.shape, outputs, &skip)?;
4395 let classes_tensor = Self::swap_axes_if_needed(classes_tensor, classes_config.into());
4396 let classes_tensor = classes_tensor.slice(s![0, .., ..]);
4397 skip.push(ind);
4398
4399 let (mask_tensor, ind) =
4400 Self::find_outputs_with_shape(&mask_coeff_config.shape, outputs, &skip)?;
4401 let mask_tensor = Self::swap_axes_if_needed(mask_tensor, mask_coeff_config.into());
4402 let mask_tensor = mask_tensor.slice(s![0, .., ..]);
4403 skip.push(ind);
4404
4405 let (protos_tensor, _) =
4406 Self::find_outputs_with_shape(&protos_config.shape, outputs, &skip)?;
4407 let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos_config.into());
4408 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
4409 let protos_tensor = Self::protos_to_hwc(protos_tensor, protos_config);
4410
4411 crate::yolo::decode_yolo_split_end_to_end_segdet_float_proto(
4412 boxes_tensor,
4413 scores_tensor,
4414 classes_tensor,
4415 mask_tensor,
4416 protos_tensor,
4417 self.score_threshold,
4418 output_boxes,
4419 )
4420 }
4421
4422 #[allow(clippy::too_many_arguments)]
4423 fn decode_yolo_split_end_to_end_segdet_quantized_proto(
4424 &self,
4425 outputs: &[ArrayViewDQuantized],
4426 boxes_config: &configs::Boxes,
4427 scores_config: &configs::Scores,
4428 classes_config: &configs::Classes,
4429 mask_coeff_config: &configs::MaskCoefficients,
4430 protos_config: &configs::Protos,
4431 output_boxes: &mut Vec<DetectBox>,
4432 ) -> Result<ProtoData, DecoderError> {
4433 let mut skip = vec![];
4434 let (boxes_tensor, ind) =
4435 Self::find_outputs_with_shape_quantized(&boxes_config.shape, outputs, &skip)?;
4436 skip.push(ind);
4437 let (scores_tensor, ind) =
4438 Self::find_outputs_with_shape_quantized(&scores_config.shape, outputs, &skip)?;
4439 skip.push(ind);
4440 let (classes_tensor, ind) =
4441 Self::find_outputs_with_shape_quantized(&classes_config.shape, outputs, &skip)?;
4442 skip.push(ind);
4443 let (mask_tensor, ind) =
4444 Self::find_outputs_with_shape_quantized(&mask_coeff_config.shape, outputs, &skip)?;
4445 skip.push(ind);
4446 let (protos_tensor, _) =
4447 Self::find_outputs_with_shape_quantized(&protos_config.shape, outputs, &skip)?;
4448
4449 let quant_boxes = boxes_config
4450 .quantization
4451 .map(Quantization::from)
4452 .unwrap_or_default();
4453 let quant_scores = scores_config
4454 .quantization
4455 .map(Quantization::from)
4456 .unwrap_or_default();
4457 let quant_classes = classes_config
4458 .quantization
4459 .map(Quantization::from)
4460 .unwrap_or_default();
4461 let quant_masks = mask_coeff_config
4462 .quantization
4463 .map(Quantization::from)
4464 .unwrap_or_default();
4465 let quant_protos = protos_config
4466 .quantization
4467 .map(Quantization::from)
4468 .unwrap_or_default();
4469
4470 macro_rules! dequant_3d {
4471 ($tensor:expr, $config:expr, $quant:expr) => {{
4472 with_quantized!($tensor, t, {
4473 let t = Self::swap_axes_if_needed(t, $config.into());
4474 let t = t.slice(s![0, .., ..]);
4475 t.map(|v| {
4476 let val: f32 = v.as_();
4477 (val - $quant.zero_point as f32) * $quant.scale
4478 })
4479 })
4480 }};
4481 }
4482 macro_rules! dequant_4d {
4483 ($tensor:expr, $config:expr, $quant:expr) => {{
4484 with_quantized!($tensor, t, {
4485 let t = Self::swap_axes_if_needed(t, $config.into());
4486 let t = t.slice(s![0, .., .., ..]);
4487 t.map(|v| {
4488 let val: f32 = v.as_();
4489 (val - $quant.zero_point as f32) * $quant.scale
4490 })
4491 })
4492 }};
4493 }
4494
4495 let dequant_b = dequant_3d!(boxes_tensor, boxes_config, quant_boxes);
4496 let dequant_s = dequant_3d!(scores_tensor, scores_config, quant_scores);
4497 let dequant_c = dequant_3d!(classes_tensor, classes_config, quant_classes);
4498 let dequant_m = dequant_3d!(mask_tensor, mask_coeff_config, quant_masks);
4499 let dequant_p = dequant_4d!(protos_tensor, protos_config, quant_protos);
4500
4501 crate::yolo::decode_yolo_split_end_to_end_segdet_float_proto(
4502 dequant_b.view(),
4503 dequant_s.view(),
4504 dequant_c.view(),
4505 dequant_m.view(),
4506 dequant_p.view(),
4507 self.score_threshold,
4508 output_boxes,
4509 )
4510 }
4511
4512 fn match_outputs_to_detect<'a, 'b, T>(
4513 configs: &[configs::Detection],
4514 outputs: &'a [ArrayViewD<'b, T>],
4515 ) -> Result<Vec<&'a ArrayViewD<'b, T>>, DecoderError> {
4516 let mut new_output_order = Vec::new();
4517 for c in configs {
4518 let mut found = false;
4519 for o in outputs {
4520 if o.shape() == c.shape {
4521 new_output_order.push(o);
4522 found = true;
4523 break;
4524 }
4525 }
4526 if !found {
4527 return Err(DecoderError::InvalidShape(format!(
4528 "Did not find output with shape {:?}",
4529 c.shape
4530 )));
4531 }
4532 }
4533 Ok(new_output_order)
4534 }
4535
4536 fn find_outputs_with_shape<'a, 'b, T>(
4537 shape: &[usize],
4538 outputs: &'a [ArrayViewD<'b, T>],
4539 skip: &[usize],
4540 ) -> Result<(&'a ArrayViewD<'b, T>, usize), DecoderError> {
4541 for (ind, o) in outputs.iter().enumerate() {
4542 if skip.contains(&ind) {
4543 continue;
4544 }
4545 if o.shape() == shape {
4546 return Ok((o, ind));
4547 }
4548 }
4549 Err(DecoderError::InvalidShape(format!(
4550 "Did not find output with shape {:?}",
4551 shape
4552 )))
4553 }
4554
4555 fn find_outputs_with_shape_quantized<'a, 'b>(
4556 shape: &[usize],
4557 outputs: &'a [ArrayViewDQuantized<'b>],
4558 skip: &[usize],
4559 ) -> Result<(&'a ArrayViewDQuantized<'b>, usize), DecoderError> {
4560 for (ind, o) in outputs.iter().enumerate() {
4561 if skip.contains(&ind) {
4562 continue;
4563 }
4564 if o.shape() == shape {
4565 return Ok((o, ind));
4566 }
4567 }
4568 Err(DecoderError::InvalidShape(format!(
4569 "Did not find output with shape {:?}",
4570 shape
4571 )))
4572 }
4573
4574 fn modelpack_det_order(x: DimName) -> usize {
4577 match x {
4578 DimName::Batch => 0,
4579 DimName::NumBoxes => 1,
4580 DimName::Padding => 2,
4581 DimName::BoxCoords => 3,
4582 _ => 1000, }
4584 }
4585
4586 fn yolo_det_order(x: DimName) -> usize {
4589 match x {
4590 DimName::Batch => 0,
4591 DimName::NumFeatures => 1,
4592 DimName::NumBoxes => 2,
4593 _ => 1000, }
4595 }
4596
4597 fn modelpack_boxes_order(x: DimName) -> usize {
4600 match x {
4601 DimName::Batch => 0,
4602 DimName::NumBoxes => 1,
4603 DimName::Padding => 2,
4604 DimName::BoxCoords => 3,
4605 _ => 1000, }
4607 }
4608
4609 fn yolo_boxes_order(x: DimName) -> usize {
4612 match x {
4613 DimName::Batch => 0,
4614 DimName::BoxCoords => 1,
4615 DimName::NumBoxes => 2,
4616 _ => 1000, }
4618 }
4619
4620 fn modelpack_scores_order(x: DimName) -> usize {
4623 match x {
4624 DimName::Batch => 0,
4625 DimName::NumBoxes => 1,
4626 DimName::NumClasses => 2,
4627 _ => 1000, }
4629 }
4630
4631 fn yolo_scores_order(x: DimName) -> usize {
4632 match x {
4633 DimName::Batch => 0,
4634 DimName::NumClasses => 1,
4635 DimName::NumBoxes => 2,
4636 _ => 1000, }
4638 }
4639
4640 fn modelpack_segmentation_order(x: DimName) -> usize {
4643 match x {
4644 DimName::Batch => 0,
4645 DimName::Height => 1,
4646 DimName::Width => 2,
4647 DimName::NumClasses => 3,
4648 _ => 1000, }
4650 }
4651
4652 fn modelpack_mask_order(x: DimName) -> usize {
4655 match x {
4656 DimName::Batch => 0,
4657 DimName::Height => 1,
4658 DimName::Width => 2,
4659 _ => 1000, }
4661 }
4662
4663 fn yolo_protos_order(x: DimName) -> usize {
4666 match x {
4667 DimName::Batch => 0,
4668 DimName::Height => 1,
4669 DimName::Width => 2,
4670 DimName::NumProtos => 3,
4671 _ => 1000, }
4673 }
4674
4675 fn yolo_maskcoefficients_order(x: DimName) -> usize {
4678 match x {
4679 DimName::Batch => 0,
4680 DimName::NumProtos => 1,
4681 DimName::NumBoxes => 2,
4682 _ => 1000, }
4684 }
4685
4686 fn get_order_fn(config: ConfigOutputRef) -> fn(DimName) -> usize {
4687 let decoder_type = config.decoder();
4688 match (config, decoder_type) {
4689 (ConfigOutputRef::Detection(_), DecoderType::ModelPack) => Self::modelpack_det_order,
4690 (ConfigOutputRef::Detection(_), DecoderType::Ultralytics) => Self::yolo_det_order,
4691 (ConfigOutputRef::Boxes(_), DecoderType::ModelPack) => Self::modelpack_boxes_order,
4692 (ConfigOutputRef::Boxes(_), DecoderType::Ultralytics) => Self::yolo_boxes_order,
4693 (ConfigOutputRef::Scores(_), DecoderType::ModelPack) => Self::modelpack_scores_order,
4694 (ConfigOutputRef::Scores(_), DecoderType::Ultralytics) => Self::yolo_scores_order,
4695 (ConfigOutputRef::Segmentation(_), _) => Self::modelpack_segmentation_order,
4696 (ConfigOutputRef::Mask(_), _) => Self::modelpack_mask_order,
4697 (ConfigOutputRef::Protos(_), _) => Self::yolo_protos_order,
4698 (ConfigOutputRef::MaskCoefficients(_), _) => Self::yolo_maskcoefficients_order,
4699 (ConfigOutputRef::Classes(_), _) => Self::yolo_scores_order,
4700 }
4701 }
4702
4703 fn protos_to_hwc<'a, T>(
4718 protos: ArrayView<'a, T, ndarray::Ix3>,
4719 config: &configs::Protos,
4720 ) -> ArrayView<'a, T, ndarray::Ix3> {
4721 if config.dshape.is_empty() {
4722 let (d0, d1, d2) = protos.dim();
4723 log::warn!(
4724 "protos_to_hwc: no dshape configured, using size heuristic on \
4725 shape ({d0}, {d1}, {d2}); set dshape in config for reliable ordering"
4726 );
4727 if d0 < d1 && d0 < d2 {
4728 protos.permuted_axes([1, 2, 0])
4730 } else {
4731 protos
4733 }
4734 } else {
4735 protos
4736 }
4737 }
4738
4739 fn swap_axes_if_needed<'a, T, D: Dimension>(
4740 array: &ArrayView<'a, T, D>,
4741 config: ConfigOutputRef,
4742 ) -> ArrayView<'a, T, D> {
4743 let mut array = array.clone();
4744 if config.dshape().is_empty() {
4745 return array;
4746 }
4747 let order_fn: fn(DimName) -> usize = Self::get_order_fn(config.clone());
4748 let mut current_order: Vec<usize> = config
4749 .dshape()
4750 .iter()
4751 .map(|x| order_fn(x.0))
4752 .collect::<Vec<_>>();
4753
4754 assert_eq!(array.shape().len(), current_order.len());
4755 for i in 0..current_order.len() {
4758 let mut swapped = false;
4759 for j in 0..current_order.len() - 1 - i {
4760 if current_order[j] > current_order[j + 1] {
4761 array.swap_axes(j, j + 1);
4762 current_order.swap(j, j + 1);
4763 swapped = true;
4764 }
4765 }
4766 if !swapped {
4767 break;
4768 }
4769 }
4770 array
4771 }
4772
4773 fn match_outputs_to_detect_quantized<'a, 'b>(
4774 configs: &[configs::Detection],
4775 outputs: &'a [ArrayViewDQuantized<'b>],
4776 ) -> Result<Vec<&'a ArrayViewDQuantized<'b>>, DecoderError> {
4777 let mut new_output_order = Vec::new();
4778 for c in configs {
4779 let mut found = false;
4780 for o in outputs {
4781 if o.shape() == c.shape {
4782 new_output_order.push(o);
4783 found = true;
4784 break;
4785 }
4786 }
4787 if !found {
4788 return Err(DecoderError::InvalidShape(format!(
4789 "Did not find output with shape {:?}",
4790 c.shape
4791 )));
4792 }
4793 }
4794 Ok(new_output_order)
4795 }
4796}
4797
4798#[cfg(test)]
4799#[cfg_attr(coverage_nightly, coverage(off))]
4800mod decoder_builder_tests {
4801 use super::*;
4802
#[test]
fn test_decoder_builder_no_config() {
    use crate::DecoderBuilder;
    // Building without ever supplying a config must be rejected.
    let outcome = DecoderBuilder::default().build();
    let is_no_config = matches!(outcome, Err(DecoderError::NoConfig));
    assert!(is_no_config);
}
4809
#[test]
fn test_decoder_builder_empty_config() {
    use crate::DecoderBuilder;
    // A config whose output list is empty is rejected with a specific message.
    let empty = ConfigOutputs {
        outputs: Vec::new(),
        ..Default::default()
    };
    let outcome = DecoderBuilder::default().with_config(empty).build();
    let rejected = matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg == "No outputs found in config"
    );
    assert!(rejected);
}
4823
#[test]
fn test_malformed_config_yaml() {
    // The single output entry lacks the `type` tag that `ConfigOutput` is
    // discriminated on, so YAML deserialization has to fail.
    let config_text = concat!(
        "model_type: yolov8_det\n",
        "outputs:\n",
        "  - shape: [1, 84, 8400]\n",
    )
    .to_owned();
    let outcome = DecoderBuilder::new()
        .with_config_yaml_str(config_text)
        .build();
    let failed_to_parse = matches!(outcome, Err(DecoderError::Yaml(_)));
    assert!(failed_to_parse);
}
4837
#[test]
fn test_malformed_config_json() {
    // The single output entry lacks the `type` tag that `ConfigOutput` is
    // discriminated on, so JSON deserialization has to fail. (The input was
    // previously held in a variable misleadingly named `malformed_yaml`.)
    let malformed_json = "
    {
        \"model_type\": \"yolov8_det\",
        \"outputs\": [
            {
                \"shape\": [1, 84, 8400]
            }
        ]
    }"
    .to_owned();
    let result = DecoderBuilder::new()
        .with_config_json_str(malformed_json)
        .build();
    assert!(matches!(result, Err(DecoderError::Json(_))));
}
4855
#[test]
fn test_modelpack_and_yolo_config_error() {
    // Boxes tagged Ultralytics combined with scores tagged ModelPack must be
    // rejected as a mixed configuration.
    let boxes = configs::Boxes {
        decoder: configs::DecoderType::Ultralytics,
        shape: vec![1, 4, 8400],
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::BoxCoords, 4),
            (DimName::NumBoxes, 8400),
        ],
        normalized: Some(true),
    };
    let scores = configs::Scores {
        decoder: configs::DecoderType::ModelPack,
        shape: vec![1, 80, 8400],
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumClasses, 80),
            (DimName::NumBoxes, 8400),
        ],
    };

    let outcome = DecoderBuilder::new()
        .with_config_modelpack_det(boxes, scores)
        .build();

    let rejected = matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg == "Both ModelPack and Yolo outputs found in config"
    );
    assert!(rejected);
}
4888
#[test]
fn test_yolo_invalid_seg_shape() {
    // A rank-4 detection tensor must be rejected even with valid protos.
    let detection = configs::Detection {
        decoder: configs::DecoderType::Ultralytics,
        shape: vec![1, 85, 8400, 1],
        quantization: None,
        anchors: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumFeatures, 85),
            (DimName::NumBoxes, 8400),
            (DimName::Batch, 1),
        ],
        normalized: Some(true),
    };
    let protos = configs::Protos {
        decoder: configs::DecoderType::Ultralytics,
        shape: vec![1, 32, 160, 160],
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumProtos, 32),
            (DimName::Height, 160),
            (DimName::Width, 160),
        ],
    };

    let outcome = DecoderBuilder::new()
        .with_config_yolo_segdet(detection, protos, Some(DecoderVersion::Yolo11))
        .build();

    let rejected = matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg.starts_with("Invalid Yolo Detection shape")
    );
    assert!(rejected);
}
4925
#[test]
fn test_yolo_invalid_mask() {
    // A `Mask` output paired with the Ultralytics decoder must be rejected.
    let mask = configs::Mask {
        shape: vec![1, 160, 160, 1],
        decoder: configs::DecoderType::Ultralytics,
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 160),
            (DimName::Width, 160),
            (DimName::NumFeatures, 1),
        ],
    };
    let config = ConfigOutputs {
        outputs: vec![ConfigOutput::Mask(mask)],
        ..Default::default()
    };

    let outcome = DecoderBuilder::new().with_config(config).build();

    let rejected = matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg.starts_with("Invalid Mask output with Yolo decoder")
    );
    assert!(rejected);
}
4949
#[test]
fn test_yolo_invalid_outputs() {
    // A `Segmentation` output paired with the Ultralytics decoder must be rejected.
    let segmentation = configs::Segmentation {
        shape: vec![1, 84, 8400],
        decoder: configs::DecoderType::Ultralytics,
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumFeatures, 84),
            (DimName::NumBoxes, 8400),
        ],
    };
    let config = ConfigOutputs {
        outputs: vec![ConfigOutput::Segmentation(segmentation)],
        ..Default::default()
    };

    let outcome = DecoderBuilder::new().with_config(config).build();

    let rejected = matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg == "Invalid Segmentation output with Yolo decoder"
    );
    assert!(rejected);
}
4972
#[test]
fn test_yolo_invalid_det() {
    let build_yolo11_det = |detection: configs::Detection| {
        DecoderBuilder::new()
            .with_config_yolo_det(detection, Some(DecoderVersion::Yolo11))
            .build()
    };

    // Rank-4 detection shape.
    let rank4 = build_yolo11_det(configs::Detection {
        anchors: None,
        decoder: DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 84, 8400, 1],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumFeatures, 84),
            (DimName::NumBoxes, 8400),
            (DimName::Batch, 1),
        ],
        normalized: Some(true),
    });
    assert!(matches!(
        rank4,
        Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Detection shape")
    ));

    // Boxes-major layout with only 3 features.
    let boxes_major = build_yolo11_det(configs::Detection {
        anchors: None,
        decoder: DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 8400, 3],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumBoxes, 8400),
            (DimName::NumFeatures, 3),
        ],
        normalized: Some(true),
    });
    assert!(
        matches!(
            &boxes_major,
            Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid shape: Yolo num_features 3 must be greater than 4")
        ),
        "{}",
        boxes_major.unwrap_err()
    );

    // Only 3 features again, this time without a dshape.
    let no_dshape = build_yolo11_det(configs::Detection {
        anchors: None,
        decoder: DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 3, 8400],
        dshape: Vec::new(),
        normalized: Some(true),
    });
    assert!(matches!(
        no_dshape,
        Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid shape: Yolo num_features 3 must be greater than 4")
    ));
}
5039
#[test]
fn test_yolo_invalid_segdet() {
    // Builds a YOLO11 seg-det decoder from the given detection/protos configs.
    let build_yolo11_segdet = |detection: configs::Detection, protos: configs::Protos| {
        DecoderBuilder::new()
            .with_config_yolo_segdet(detection, protos, Some(DecoderVersion::Yolo11))
            .build()
    };
    // A well-formed 32x160x160 protos config, created fresh for each case.
    let valid_protos = || configs::Protos {
        decoder: configs::DecoderType::Ultralytics,
        shape: vec![1, 32, 160, 160],
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumProtos, 32),
            (DimName::Height, 160),
            (DimName::Width, 160),
        ],
    };

    // Rank-4 detection shape.
    let result = build_yolo11_segdet(
        configs::Detection {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 85, 8400, 1],
            quantization: None,
            anchors: None,
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumFeatures, 85),
                (DimName::NumBoxes, 8400),
                (DimName::Batch, 1),
            ],
            normalized: Some(true),
        },
        valid_protos(),
    );
    assert!(
        matches!(
            &result,
            Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Detection shape")
        ),
        "{result:?}"
    );

    // Rank-5 protos shape.
    let result = build_yolo11_segdet(
        configs::Detection {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 85, 8400],
            quantization: None,
            anchors: None,
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumFeatures, 85),
                (DimName::NumBoxes, 8400),
            ],
            normalized: Some(true),
        },
        configs::Protos {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 32, 160, 160, 1],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumProtos, 32),
                (DimName::Height, 160),
                (DimName::Width, 160),
                (DimName::Batch, 1),
            ],
            quantization: None,
        },
    );
    assert!(
        matches!(
            &result,
            Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Protos shape")
        ),
        "{result:?}"
    );

    // 36 features leave no room for class scores next to 32 mask coefficients.
    let result = build_yolo11_segdet(
        configs::Detection {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 8400, 36],
            quantization: None,
            anchors: None,
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumBoxes, 8400),
                (DimName::NumFeatures, 36),
            ],
            normalized: Some(true),
        },
        valid_protos(),
    );
    // The result is reported through the assertion message on failure instead
    // of an unconditional debug `println!`.
    assert!(
        matches!(
            &result,
            Err(DecoderError::InvalidConfig(s)) if s == "Invalid shape: Yolo num_features 36 must be greater than 36"
        ),
        "{result:?}"
    );
}
5140
#[test]
fn test_yolo_invalid_split_det() {
    // True when `outcome` is an InvalidConfig error whose message has `prefix`.
    fn rejected_with<T>(outcome: &Result<T, DecoderError>, prefix: &str) -> bool {
        matches!(outcome, Err(DecoderError::InvalidConfig(s)) if s.starts_with(prefix))
    }
    let build_split_det = |boxes: configs::Boxes, scores: configs::Scores| {
        DecoderBuilder::new()
            .with_config_yolo_split_det(boxes, scores)
            .build()
    };

    // Rank-4 boxes.
    let outcome = build_split_det(
        configs::Boxes {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 4, 8400, 1],
            quantization: None,
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::BoxCoords, 4),
                (DimName::NumBoxes, 8400),
                (DimName::Batch, 1),
            ],
            normalized: Some(true),
        },
        configs::Scores {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 80, 8400],
            quantization: None,
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumClasses, 80),
                (DimName::NumBoxes, 8400),
            ],
        },
    );
    assert!(rejected_with(&outcome, "Invalid Yolo Split Boxes shape"));

    // Rank-4 scores.
    let outcome = build_split_det(
        configs::Boxes {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 4, 8400],
            quantization: None,
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::BoxCoords, 4),
                (DimName::NumBoxes, 8400),
            ],
            normalized: Some(true),
        },
        configs::Scores {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 80, 8400, 1],
            quantization: None,
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumClasses, 80),
                (DimName::NumBoxes, 8400),
                (DimName::Batch, 1),
            ],
        },
    );
    assert!(rejected_with(&outcome, "Invalid Yolo Split Scores shape"));

    // Box counts disagree between boxes (8400) and scores (8401).
    let outcome = build_split_det(
        configs::Boxes {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 8400, 4],
            quantization: None,
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumBoxes, 8400),
                (DimName::BoxCoords, 4),
            ],
            normalized: Some(true),
        },
        configs::Scores {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 8400 + 1, 80],
            quantization: None,
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumBoxes, 8401),
                (DimName::NumClasses, 80),
            ],
        },
    );
    assert!(rejected_with(
        &outcome,
        "Yolo Split Detection Boxes num 8400 incompatible with Scores num 8401"
    ));

    // Five box coordinates instead of four.
    let outcome = build_split_det(
        configs::Boxes {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 5, 8400],
            quantization: None,
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::BoxCoords, 5),
                (DimName::NumBoxes, 8400),
            ],
            normalized: Some(true),
        },
        configs::Scores {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 80, 8400],
            quantization: None,
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumClasses, 80),
                (DimName::NumBoxes, 8400),
            ],
        },
    );
    assert!(rejected_with(&outcome, "BoxCoords dimension size must be 4"));
}
5260
5261 #[test]
5262 fn test_yolo_invalid_split_segdet() {
5263 let result = DecoderBuilder::new()
5264 .with_config_yolo_split_segdet(
5265 configs::Boxes {
5266 decoder: configs::DecoderType::Ultralytics,
5267 shape: vec![1, 8400, 4, 1],
5268 quantization: None,
5269 dshape: vec![
5270 (DimName::Batch, 1),
5271 (DimName::NumBoxes, 8400),
5272 (DimName::BoxCoords, 4),
5273 (DimName::Batch, 1),
5274 ],
5275 normalized: Some(true),
5276 },
5277 configs::Scores {
5278 decoder: configs::DecoderType::Ultralytics,
5279 shape: vec![1, 8400, 80],
5280
5281 quantization: None,
5282 dshape: vec![
5283 (DimName::Batch, 1),
5284 (DimName::NumBoxes, 8400),
5285 (DimName::NumClasses, 80),
5286 ],
5287 },
5288 configs::MaskCoefficients {
5289 decoder: configs::DecoderType::Ultralytics,
5290 shape: vec![1, 8400, 32],
5291 quantization: None,
5292 dshape: vec![
5293 (DimName::Batch, 1),
5294 (DimName::NumBoxes, 8400),
5295 (DimName::NumProtos, 32),
5296 ],
5297 },
5298 configs::Protos {
5299 decoder: configs::DecoderType::Ultralytics,
5300 shape: vec![1, 32, 160, 160],
5301 quantization: None,
5302 dshape: vec![
5303 (DimName::Batch, 1),
5304 (DimName::NumProtos, 32),
5305 (DimName::Height, 160),
5306 (DimName::Width, 160),
5307 ],
5308 },
5309 )
5310 .build();
5311
5312 assert!(matches!(
5313 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Boxes shape")));
5314
5315 let result = DecoderBuilder::new()
5316 .with_config_yolo_split_segdet(
5317 configs::Boxes {
5318 decoder: configs::DecoderType::Ultralytics,
5319 shape: vec![1, 8400, 4],
5320 quantization: None,
5321 dshape: vec![
5322 (DimName::Batch, 1),
5323 (DimName::NumBoxes, 8400),
5324 (DimName::BoxCoords, 4),
5325 ],
5326 normalized: Some(true),
5327 },
5328 configs::Scores {
5329 decoder: configs::DecoderType::Ultralytics,
5330 shape: vec![1, 8400, 80, 1],
5331 quantization: None,
5332 dshape: vec![
5333 (DimName::Batch, 1),
5334 (DimName::NumBoxes, 8400),
5335 (DimName::NumClasses, 80),
5336 (DimName::Batch, 1),
5337 ],
5338 },
5339 configs::MaskCoefficients {
5340 decoder: configs::DecoderType::Ultralytics,
5341 shape: vec![1, 8400, 32],
5342 quantization: None,
5343 dshape: vec![
5344 (DimName::Batch, 1),
5345 (DimName::NumBoxes, 8400),
5346 (DimName::NumProtos, 32),
5347 ],
5348 },
5349 configs::Protos {
5350 decoder: configs::DecoderType::Ultralytics,
5351 shape: vec![1, 32, 160, 160],
5352 quantization: None,
5353 dshape: vec![
5354 (DimName::Batch, 1),
5355 (DimName::NumProtos, 32),
5356 (DimName::Height, 160),
5357 (DimName::Width, 160),
5358 ],
5359 },
5360 )
5361 .build();
5362
5363 assert!(matches!(
5364 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Scores shape")));
5365
5366 let result = DecoderBuilder::new()
5367 .with_config_yolo_split_segdet(
5368 configs::Boxes {
5369 decoder: configs::DecoderType::Ultralytics,
5370 shape: vec![1, 8400, 4],
5371 quantization: None,
5372 dshape: vec![
5373 (DimName::Batch, 1),
5374 (DimName::NumBoxes, 8400),
5375 (DimName::BoxCoords, 4),
5376 ],
5377 normalized: Some(true),
5378 },
5379 configs::Scores {
5380 decoder: configs::DecoderType::Ultralytics,
5381 shape: vec![1, 8400, 80],
5382 quantization: None,
5383 dshape: vec![
5384 (DimName::Batch, 1),
5385 (DimName::NumBoxes, 8400),
5386 (DimName::NumClasses, 80),
5387 ],
5388 },
5389 configs::MaskCoefficients {
5390 decoder: configs::DecoderType::Ultralytics,
5391 shape: vec![1, 8400, 32, 1],
5392 quantization: None,
5393 dshape: vec![
5394 (DimName::Batch, 1),
5395 (DimName::NumBoxes, 8400),
5396 (DimName::NumProtos, 32),
5397 (DimName::Batch, 1),
5398 ],
5399 },
5400 configs::Protos {
5401 decoder: configs::DecoderType::Ultralytics,
5402 shape: vec![1, 32, 160, 160],
5403 quantization: None,
5404 dshape: vec![
5405 (DimName::Batch, 1),
5406 (DimName::NumProtos, 32),
5407 (DimName::Height, 160),
5408 (DimName::Width, 160),
5409 ],
5410 },
5411 )
5412 .build();
5413
5414 assert!(matches!(
5415 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Mask Coefficients shape")));
5416
5417 let result = DecoderBuilder::new()
5418 .with_config_yolo_split_segdet(
5419 configs::Boxes {
5420 decoder: configs::DecoderType::Ultralytics,
5421 shape: vec![1, 8400, 4],
5422 quantization: None,
5423 dshape: vec![
5424 (DimName::Batch, 1),
5425 (DimName::NumBoxes, 8400),
5426 (DimName::BoxCoords, 4),
5427 ],
5428 normalized: Some(true),
5429 },
5430 configs::Scores {
5431 decoder: configs::DecoderType::Ultralytics,
5432 shape: vec![1, 8400, 80],
5433 quantization: None,
5434 dshape: vec![
5435 (DimName::Batch, 1),
5436 (DimName::NumBoxes, 8400),
5437 (DimName::NumClasses, 80),
5438 ],
5439 },
5440 configs::MaskCoefficients {
5441 decoder: configs::DecoderType::Ultralytics,
5442 shape: vec![1, 8400, 32],
5443 quantization: None,
5444 dshape: vec![
5445 (DimName::Batch, 1),
5446 (DimName::NumBoxes, 8400),
5447 (DimName::NumProtos, 32),
5448 ],
5449 },
5450 configs::Protos {
5451 decoder: configs::DecoderType::Ultralytics,
5452 shape: vec![1, 32, 160, 160, 1],
5453 quantization: None,
5454 dshape: vec![
5455 (DimName::Batch, 1),
5456 (DimName::NumProtos, 32),
5457 (DimName::Height, 160),
5458 (DimName::Width, 160),
5459 (DimName::Batch, 1),
5460 ],
5461 },
5462 )
5463 .build();
5464
5465 assert!(matches!(
5466 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Protos shape")));
5467
5468 let result = DecoderBuilder::new()
5469 .with_config_yolo_split_segdet(
5470 configs::Boxes {
5471 decoder: configs::DecoderType::Ultralytics,
5472 shape: vec![1, 8400, 4],
5473 quantization: None,
5474 dshape: vec![
5475 (DimName::Batch, 1),
5476 (DimName::NumBoxes, 8400),
5477 (DimName::BoxCoords, 4),
5478 ],
5479 normalized: Some(true),
5480 },
5481 configs::Scores {
5482 decoder: configs::DecoderType::Ultralytics,
5483 shape: vec![1, 8401, 80],
5484 quantization: None,
5485 dshape: vec![
5486 (DimName::Batch, 1),
5487 (DimName::NumBoxes, 8401),
5488 (DimName::NumClasses, 80),
5489 ],
5490 },
5491 configs::MaskCoefficients {
5492 decoder: configs::DecoderType::Ultralytics,
5493 shape: vec![1, 8400, 32],
5494 quantization: None,
5495 dshape: vec![
5496 (DimName::Batch, 1),
5497 (DimName::NumBoxes, 8400),
5498 (DimName::NumProtos, 32),
5499 ],
5500 },
5501 configs::Protos {
5502 decoder: configs::DecoderType::Ultralytics,
5503 shape: vec![1, 32, 160, 160],
5504 quantization: None,
5505 dshape: vec![
5506 (DimName::Batch, 1),
5507 (DimName::NumProtos, 32),
5508 (DimName::Height, 160),
5509 (DimName::Width, 160),
5510 ],
5511 },
5512 )
5513 .build();
5514
5515 assert!(matches!(
5516 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Yolo Split Detection Boxes num 8400 incompatible with Scores num 8401")));
5517
5518 let result = DecoderBuilder::new()
5519 .with_config_yolo_split_segdet(
5520 configs::Boxes {
5521 decoder: configs::DecoderType::Ultralytics,
5522 shape: vec![1, 8400, 4],
5523 quantization: None,
5524 dshape: vec![
5525 (DimName::Batch, 1),
5526 (DimName::NumBoxes, 8400),
5527 (DimName::BoxCoords, 4),
5528 ],
5529 normalized: Some(true),
5530 },
5531 configs::Scores {
5532 decoder: configs::DecoderType::Ultralytics,
5533 shape: vec![1, 8400, 80],
5534 quantization: None,
5535 dshape: vec![
5536 (DimName::Batch, 1),
5537 (DimName::NumBoxes, 8400),
5538 (DimName::NumClasses, 80),
5539 ],
5540 },
5541 configs::MaskCoefficients {
5542 decoder: configs::DecoderType::Ultralytics,
5543 shape: vec![1, 8401, 32],
5544
5545 quantization: None,
5546 dshape: vec![
5547 (DimName::Batch, 1),
5548 (DimName::NumBoxes, 8401),
5549 (DimName::NumProtos, 32),
5550 ],
5551 },
5552 configs::Protos {
5553 decoder: configs::DecoderType::Ultralytics,
5554 shape: vec![1, 32, 160, 160],
5555 quantization: None,
5556 dshape: vec![
5557 (DimName::Batch, 1),
5558 (DimName::NumProtos, 32),
5559 (DimName::Height, 160),
5560 (DimName::Width, 160),
5561 ],
5562 },
5563 )
5564 .build();
5565
5566 assert!(matches!(
5567 result, Err(DecoderError::InvalidConfig(ref s)) if s.starts_with("Yolo Split Detection Boxes num 8400 incompatible with Mask Coefficients num 8401")));
5568 let result = DecoderBuilder::new()
5569 .with_config_yolo_split_segdet(
5570 configs::Boxes {
5571 decoder: configs::DecoderType::Ultralytics,
5572 shape: vec![1, 8400, 4],
5573 quantization: None,
5574 dshape: vec![
5575 (DimName::Batch, 1),
5576 (DimName::NumBoxes, 8400),
5577 (DimName::BoxCoords, 4),
5578 ],
5579 normalized: Some(true),
5580 },
5581 configs::Scores {
5582 decoder: configs::DecoderType::Ultralytics,
5583 shape: vec![1, 8400, 80],
5584 quantization: None,
5585 dshape: vec![
5586 (DimName::Batch, 1),
5587 (DimName::NumBoxes, 8400),
5588 (DimName::NumClasses, 80),
5589 ],
5590 },
5591 configs::MaskCoefficients {
5592 decoder: configs::DecoderType::Ultralytics,
5593 shape: vec![1, 8400, 32],
5594 quantization: None,
5595 dshape: vec![
5596 (DimName::Batch, 1),
5597 (DimName::NumBoxes, 8400),
5598 (DimName::NumProtos, 32),
5599 ],
5600 },
5601 configs::Protos {
5602 decoder: configs::DecoderType::Ultralytics,
5603 shape: vec![1, 31, 160, 160],
5604 quantization: None,
5605 dshape: vec![
5606 (DimName::Batch, 1),
5607 (DimName::NumProtos, 31),
5608 (DimName::Height, 160),
5609 (DimName::Width, 160),
5610 ],
5611 },
5612 )
5613 .build();
5614 println!("{:?}", result);
5615 assert!(matches!(
5616 result, Err(DecoderError::InvalidConfig(ref s)) if s.starts_with( "Yolo Protos channels 31 incompatible with Mask Coefficients channels 32")));
5617 }
5618
#[test]
fn test_modelpack_invalid_config() {
    // Every scenario below shares the same ModelPack boxes/scores pair; only the
    // extra (or missing) output varies, so the shared pieces are built on demand.
    let boxes = || {
        ConfigOutput::Boxes(configs::Boxes {
            decoder: configs::DecoderType::ModelPack,
            shape: vec![1, 8400, 1, 4],
            quantization: None,
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumBoxes, 8400),
                (DimName::Padding, 1),
                (DimName::BoxCoords, 4),
            ],
            normalized: Some(true),
        })
    };
    let scores = || {
        ConfigOutput::Scores(configs::Scores {
            decoder: configs::DecoderType::ModelPack,
            shape: vec![1, 8400, 3],
            quantization: None,
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumBoxes, 8400),
                (DimName::NumClasses, 3),
            ],
        })
    };
    // Builds a decoder from the given outputs with all other config fields defaulted.
    let build = |outputs: Vec<ConfigOutput>| {
        DecoderBuilder::new()
            .with_config(ConfigOutputs {
                outputs,
                ..Default::default()
            })
            .build()
    };

    // A protos output is not accepted for a ModelPack model.
    let protos = ConfigOutput::Protos(configs::Protos {
        decoder: configs::DecoderType::ModelPack,
        shape: vec![1, 8400, 3],
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumBoxes, 8400),
            (DimName::NumFeatures, 3),
        ],
    });
    assert!(matches!(
        build(vec![boxes(), scores(), protos]),
        Err(DecoderError::InvalidConfig(msg)) if msg == "ModelPack should not have protos"
    ));

    // Neither is a mask-coefficients output.
    let coefficients = ConfigOutput::MaskCoefficients(configs::MaskCoefficients {
        decoder: configs::DecoderType::ModelPack,
        shape: vec![1, 8400, 3],
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumBoxes, 8400),
            (DimName::NumProtos, 3),
        ],
    });
    assert!(matches!(
        build(vec![boxes(), scores(), coefficients]),
        Err(DecoderError::InvalidConfig(msg)) if msg == "ModelPack should not have mask coefficients"
    ));

    // Boxes on their own do not form a recognised ModelPack output set.
    assert!(matches!(
        build(vec![boxes()]),
        Err(DecoderError::InvalidConfig(msg)) if msg == "Invalid ModelPack model outputs"
    ));
}
5728
#[test]
fn test_modelpack_invalid_det() {
    // Thin constructors so each case only spells out the shapes under test.
    let boxes = |shape, dshape| configs::Boxes {
        decoder: DecoderType::ModelPack,
        quantization: None,
        shape,
        dshape,
        normalized: Some(true),
    };
    let scores = |shape, dshape| configs::Scores {
        decoder: DecoderType::ModelPack,
        quantization: None,
        shape,
        dshape,
    };
    // The well-formed 80-class score tensor reused by most cases.
    let scores_80x8400 = || {
        scores(
            vec![1, 80, 8400],
            vec![
                (DimName::Batch, 1),
                (DimName::NumClasses, 80),
                (DimName::NumBoxes, 8400),
            ],
        )
    };
    let build = |b: configs::Boxes, s: configs::Scores| {
        DecoderBuilder::new()
            .with_config_modelpack_det(b, s)
            .build()
    };

    // Boxes lacking the padding axis (rank 3) are rejected.
    let outcome = build(
        boxes(
            vec![1, 4, 8400],
            vec![
                (DimName::Batch, 1),
                (DimName::BoxCoords, 4),
                (DimName::NumBoxes, 8400),
            ],
        ),
        scores_80x8400(),
    );
    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg.starts_with("Invalid ModelPack Boxes shape")
    ));

    // Scores with a trailing padding axis (rank 4) are rejected.
    let outcome = build(
        boxes(
            vec![1, 4, 1, 8400],
            vec![
                (DimName::Batch, 1),
                (DimName::BoxCoords, 4),
                (DimName::Padding, 1),
                (DimName::NumBoxes, 8400),
            ],
        ),
        scores(
            vec![1, 80, 8400, 1],
            vec![
                (DimName::Batch, 1),
                (DimName::NumClasses, 80),
                (DimName::NumBoxes, 8400),
                (DimName::Padding, 1),
            ],
        ),
    );
    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg.starts_with("Invalid ModelPack Scores shape")
    ));

    // The padding axis must have size 1.
    let outcome = build(
        boxes(
            vec![1, 4, 2, 8400],
            vec![
                (DimName::Batch, 1),
                (DimName::BoxCoords, 4),
                (DimName::Padding, 2),
                (DimName::NumBoxes, 8400),
            ],
        ),
        scores_80x8400(),
    );
    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg == "Padding dimension size must be 1"
    ));

    // The box-coordinate axis must have size 4.
    let outcome = build(
        boxes(
            vec![1, 5, 1, 8400],
            vec![
                (DimName::Batch, 1),
                (DimName::BoxCoords, 5),
                (DimName::Padding, 1),
                (DimName::NumBoxes, 8400),
            ],
        ),
        scores_80x8400(),
    );
    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg == "BoxCoords dimension size must be 4"
    ));

    // Box and score tensors must describe the same number of candidates.
    let outcome = build(
        boxes(
            vec![1, 4, 1, 8400],
            vec![
                (DimName::Batch, 1),
                (DimName::BoxCoords, 4),
                (DimName::Padding, 1),
                (DimName::NumBoxes, 8400),
            ],
        ),
        scores(
            vec![1, 80, 8401],
            vec![
                (DimName::Batch, 1),
                (DimName::NumClasses, 80),
                (DimName::NumBoxes, 8401),
            ],
        ),
    );
    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg == "ModelPack Detection Boxes num 8400 incompatible with Scores num 8401"
    ));
}
5880
#[test]
fn test_modelpack_invalid_det_split() {
    // The three-anchor list used by every case that needs a non-empty anchor set.
    let three_anchors = || {
        Some(vec![
            [0.3666666, 0.3148148],
            [0.3874999, 0.474074],
            [0.5333333, 0.644444],
        ])
    };
    // A ModelPack split-detection head; only shape, anchors and dshape vary per case.
    let detection = |shape, anchors, dshape| configs::Detection {
        decoder: DecoderType::ModelPack,
        shape,
        anchors,
        quantization: None,
        dshape,
        normalized: Some(true),
    };
    let build = |heads: Vec<configs::Detection>| {
        DecoderBuilder::default()
            .with_config_modelpack_det_split(heads)
            .build()
    };

    // Heads without anchors are rejected (with a full dshape on two heads)...
    let outcome = build(vec![
        detection(
            vec![1, 17, 30, 18],
            None,
            vec![
                (DimName::Batch, 1),
                (DimName::Height, 17),
                (DimName::Width, 30),
                (DimName::NumAnchorsXFeatures, 18),
            ],
        ),
        detection(
            vec![1, 9, 15, 18],
            None,
            vec![
                (DimName::Batch, 1),
                (DimName::Height, 9),
                (DimName::Width, 15),
                (DimName::NumAnchorsXFeatures, 18),
            ],
        ),
    ]);
    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg == "ModelPack Split Detection missing anchors"
    ));

    // ...and with no dshape on a single head.
    let outcome = build(vec![detection(vec![1, 17, 30, 18], None, Vec::new())]);
    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg == "ModelPack Split Detection missing anchors"
    ));

    // An explicitly empty anchor list is reported separately from a missing one.
    let outcome = build(vec![detection(
        vec![1, 17, 30, 18],
        Some(vec![]),
        vec![
            (DimName::Batch, 1),
            (DimName::Height, 17),
            (DimName::Width, 30),
            (DimName::NumAnchorsXFeatures, 18),
        ],
    )]);
    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg == "ModelPack Split Detection has zero anchors"
    ));

    // A rank-5 head (extra padding axis) is rejected.
    let outcome = build(vec![detection(
        vec![1, 17, 30, 18, 1],
        three_anchors(),
        vec![
            (DimName::Batch, 1),
            (DimName::Height, 17),
            (DimName::Width, 30),
            (DimName::NumAnchorsXFeatures, 18),
            (DimName::Padding, 1),
        ],
    )]);
    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg.starts_with("Invalid ModelPack Split Detection shape")
    ));

    // Malformed channel counts with a valid anchor list, each paired with a
    // fragment of the expected error. Empty dshapes cover configs that omit
    // dimension names.
    for (shape, dshape, needle) in [
        (
            vec![1, 15, 17, 30],
            vec![
                (DimName::Batch, 1),
                (DimName::NumAnchorsXFeatures, 15),
                (DimName::Height, 17),
                (DimName::Width, 30),
            ],
            "not greater than number of anchors * 5 =",
        ),
        (
            vec![1, 17, 30, 15],
            Vec::new(),
            "not greater than number of anchors * 5 =",
        ),
        (
            vec![1, 16, 17, 30],
            vec![
                (DimName::Batch, 1),
                (DimName::NumAnchorsXFeatures, 16),
                (DimName::Height, 17),
                (DimName::Width, 30),
            ],
            "not a multiple of number of anchors",
        ),
        (
            vec![1, 17, 30, 16],
            Vec::new(),
            "not a multiple of number of anchors",
        ),
        (
            vec![1, 18, 17, 30],
            vec![
                (DimName::Batch, 1),
                (DimName::NumProtos, 18),
                (DimName::Height, 17),
                (DimName::Width, 30),
            ],
            "Split Detection dshape missing required dimension NumAnchorsXFeature",
        ),
    ] {
        let outcome = build(vec![detection(shape, three_anchors(), dshape)]);
        assert!(matches!(
            outcome,
            Err(DecoderError::InvalidConfig(msg)) if msg.contains(needle)
        ));
    }

    // Two heads with different channel counts (18 vs 21) disagree on the class
    // count, both with and without dimension names.
    for (first_dshape, second_dshape) in [
        (
            vec![
                (DimName::Batch, 1),
                (DimName::Height, 17),
                (DimName::Width, 30),
                (DimName::NumAnchorsXFeatures, 18),
            ],
            vec![
                (DimName::Batch, 1),
                (DimName::Height, 17),
                (DimName::Width, 30),
                (DimName::NumAnchorsXFeatures, 21),
            ],
        ),
        (Vec::new(), Vec::new()),
    ] {
        let outcome = build(vec![
            detection(vec![1, 17, 30, 18], three_anchors(), first_dshape),
            detection(vec![1, 17, 30, 21], three_anchors(), second_dshape),
        ]);
        assert!(matches!(
            outcome,
            Err(DecoderError::InvalidConfig(msg)) if msg.starts_with("ModelPack Split Detection inconsistent number of classes:")
        ));
    }
}
6152
#[test]
fn test_modelpack_invalid_seg() {
    // A trailing padding axis makes the segmentation tensor rank 5, which is
    // expected to be rejected.
    let segmentation = configs::Segmentation {
        decoder: DecoderType::ModelPack,
        quantization: None,
        shape: vec![1, 160, 106, 3, 1],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 160),
            (DimName::Width, 106),
            (DimName::NumClasses, 3),
            (DimName::Padding, 1),
        ],
    };

    let outcome = DecoderBuilder::new()
        .with_config_modelpack_seg(segmentation)
        .build();

    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg.starts_with("Invalid ModelPack Segmentation shape")
    ));
}
6173
#[test]
fn test_modelpack_invalid_segdet() {
    // Detection scores declare 4 classes while the segmentation map declares 3;
    // the builder is expected to refuse the mismatch.
    let boxes = configs::Boxes {
        decoder: DecoderType::ModelPack,
        quantization: None,
        shape: vec![1, 4, 1, 8400],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::BoxCoords, 4),
            (DimName::Padding, 1),
            (DimName::NumBoxes, 8400),
        ],
        normalized: Some(true),
    };
    let scores = configs::Scores {
        decoder: DecoderType::ModelPack,
        quantization: None,
        shape: vec![1, 4, 8400],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumClasses, 4),
            (DimName::NumBoxes, 8400),
        ],
    };
    let segmentation = configs::Segmentation {
        decoder: DecoderType::ModelPack,
        quantization: None,
        shape: vec![1, 160, 106, 3],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 160),
            (DimName::Width, 106),
            (DimName::NumClasses, 3),
        ],
    };

    let outcome = DecoderBuilder::new()
        .with_config_modelpack_segdet(boxes, scores, segmentation)
        .build();

    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg.contains("incompatible with number of classes")
    ));
}
6217
#[test]
fn test_modelpack_invalid_segdet_split() {
    // One 18-channel, three-anchor detection head paired with a 3-class
    // segmentation output; the builder is expected to report a class-count mismatch.
    let heads = vec![configs::Detection {
        decoder: DecoderType::ModelPack,
        shape: vec![1, 17, 30, 18],
        anchors: Some(vec![
            [0.3666666, 0.3148148],
            [0.3874999, 0.474074],
            [0.5333333, 0.644444],
        ]),
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 17),
            (DimName::Width, 30),
            (DimName::NumAnchorsXFeatures, 18),
        ],
        normalized: Some(true),
    }];
    let segmentation = configs::Segmentation {
        decoder: DecoderType::ModelPack,
        quantization: None,
        shape: vec![1, 160, 106, 3],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 160),
            (DimName::Width, 106),
            (DimName::NumClasses, 3),
        ],
    };

    let outcome = DecoderBuilder::new()
        .with_config_modelpack_segdet_split(heads, segmentation)
        .build();

    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg.contains("incompatible with number of classes")
    ));
}
6256
#[test]
fn test_decode_bad_shapes() {
    let score_threshold = 0.25;
    let iou_threshold = 0.7;
    let quant = (0.0040811873, -123);

    // The fixture bytes are read as signed int8 values and reshaped to (1, 84, 8400).
    // `u8 as i8` keeps the bit pattern, so this matches reinterpreting the byte
    // slice through a pointer cast, but without `unsafe` and without an extra copy.
    let raw: Vec<i8> = include_bytes!("../../../testdata/yolov8s_80_classes.bin")
        .iter()
        .map(|&byte| byte as i8)
        .collect();
    let out = Array3::from_shape_vec((1, 84, 8400), raw).unwrap();
    let out_float: Array3<f32> = dequantize_ndarray(out.view(), quant.into());

    // The config deliberately declares 85 features while the tensor carries 84,
    // so neither decode path should find a matching output.
    let decoder = DecoderBuilder::default()
        .with_config_yolo_det(
            configs::Detection {
                decoder: DecoderType::Ultralytics,
                shape: vec![1, 85, 8400],
                anchors: None,
                quantization: Some(quant.into()),
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumFeatures, 85),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            Some(DecoderVersion::Yolo11),
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let mut output_boxes: Vec<_> = Vec::with_capacity(50);
    let mut output_masks: Vec<_> = Vec::with_capacity(50);

    // Quantized path.
    let result =
        decoder.decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks);
    assert!(matches!(
        result, Err(DecoderError::InvalidShape(s)) if s == "Did not find output with shape [1, 85, 8400]"));

    // Float path.
    let result = decoder.decode_float(
        &[out_float.view().into_dyn()],
        &mut output_boxes,
        &mut output_masks,
    );
    assert!(matches!(
        result, Err(DecoderError::InvalidShape(s)) if s == "Did not find output with shape [1, 85, 8400]"));
}
6305
#[test]
fn test_config_outputs() {
    // One output of each kind, each tagged with a distinct quantization scale so
    // the accessor results can be traced back to their variant.
    let all_kinds = [
        ConfigOutput::Detection(configs::Detection {
            decoder: configs::DecoderType::Ultralytics,
            anchors: None,
            shape: vec![1, 8400, 85],
            quantization: Some(QuantTuple(0.123, 0)),
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumBoxes, 8400),
                (DimName::NumFeatures, 85),
            ],
            normalized: Some(true),
        }),
        ConfigOutput::Mask(configs::Mask {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 160, 160, 1],
            quantization: Some(QuantTuple(0.223, 0)),
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::Height, 160),
                (DimName::Width, 160),
                (DimName::NumFeatures, 1),
            ],
        }),
        ConfigOutput::Segmentation(configs::Segmentation {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 160, 160, 80],
            quantization: Some(QuantTuple(0.323, 0)),
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::Height, 160),
                (DimName::Width, 160),
                (DimName::NumClasses, 80),
            ],
        }),
        ConfigOutput::Scores(configs::Scores {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 8400, 80],
            quantization: Some(QuantTuple(0.423, 0)),
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumBoxes, 8400),
                (DimName::NumClasses, 80),
            ],
        }),
        ConfigOutput::Boxes(configs::Boxes {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 8400, 4],
            quantization: Some(QuantTuple(0.523, 0)),
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumBoxes, 8400),
                (DimName::BoxCoords, 4),
            ],
            normalized: Some(true),
        }),
        ConfigOutput::Protos(configs::Protos {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 32, 160, 160],
            quantization: Some(QuantTuple(0.623, 0)),
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumProtos, 32),
                (DimName::Height, 160),
                (DimName::Width, 160),
            ],
        }),
        ConfigOutput::MaskCoefficients(configs::MaskCoefficients {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 8400, 32],
            quantization: Some(QuantTuple(0.723, 0)),
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumBoxes, 8400),
                (DimName::NumProtos, 32),
            ],
        }),
    ];

    // `shape()` is read from a clone so the array is still available for the
    // quantization check that consumes it below.
    let expected_shapes = [
        vec![1, 8400, 85],
        vec![1, 160, 160, 1],
        vec![1, 160, 160, 80],
        vec![1, 8400, 80],
        vec![1, 8400, 4],
        vec![1, 32, 160, 160],
        vec![1, 8400, 32],
    ];
    let shapes = all_kinds.clone().map(|output| output.shape().to_vec());
    assert_eq!(shapes, expected_shapes);

    let expected_quants: [Option<(f32, i32)>; 7] = [
        Some((0.123, 0)),
        Some((0.223, 0)),
        Some((0.323, 0)),
        Some((0.423, 0)),
        Some((0.523, 0)),
        Some((0.623, 0)),
        Some((0.723, 0)),
    ];
    let quants: [Option<(f32, i32)>; 7] =
        all_kinds.map(|output| output.quantization().map(|q| q.into()));
    assert_eq!(quants, expected_quants);
}
6415
#[test]
fn test_nms_from_config_yaml() {
    let yaml_class_agnostic = r#"
outputs:
  - decoder: ultralytics
    type: detection
    shape: [1, 84, 8400]
    dshape:
      - [batch, 1]
      - [num_features, 84]
      - [num_boxes, 8400]
nms: class_agnostic
"#;
    let yaml_class_aware = r#"
outputs:
  - decoder: ultralytics
    type: detection
    shape: [1, 84, 8400]
    dshape:
      - [batch, 1]
      - [num_features, 84]
      - [num_boxes, 8400]
nms: class_aware
"#;

    // Parses a YAML config and builds it with otherwise-default builder settings.
    let from_yaml = |yaml: &str| {
        DecoderBuilder::new()
            .with_config_yaml_str(yaml.to_string())
            .build()
            .unwrap()
    };

    let agnostic = from_yaml(yaml_class_agnostic);
    assert_eq!(agnostic.nms, Some(configs::Nms::ClassAgnostic));

    let aware = from_yaml(yaml_class_aware);
    assert_eq!(aware.nms, Some(configs::Nms::ClassAware));

    // The `nms` value from the config takes precedence over the one passed to
    // the builder.
    let overridden = DecoderBuilder::new()
        .with_config_yaml_str(yaml_class_aware.to_string())
        .with_nms(Some(configs::Nms::ClassAgnostic))
        .build()
        .unwrap();
    assert_eq!(overridden.nms, Some(configs::Nms::ClassAware));
}
6462
#[test]
fn test_nms_from_config_json() {
    // An explicit `nms` key in a JSON config is carried through to the decoder.
    let json_class_aware = r#"{
        "outputs": [{
            "decoder": "ultralytics",
            "type": "detection",
            "shape": [1, 84, 8400],
            "dshape": [["batch", 1], ["num_features", 84], ["num_boxes", 8400]]
        }],
        "nms": "class_aware"
    }"#;

    let built = DecoderBuilder::new()
        .with_config_json_str(json_class_aware.to_string())
        .build()
        .unwrap();

    assert_eq!(built.nms, Some(configs::Nms::ClassAware));
}
6481
#[test]
fn test_nms_missing_from_config_uses_builder_default() {
    let yaml_no_nms = r#"
outputs:
  - decoder: ultralytics
    type: detection
    shape: [1, 84, 8400]
    dshape:
      - [batch, 1]
      - [num_features, 84]
      - [num_boxes, 8400]
"#;

    // With no `nms` key in the config, the builder's default (class-agnostic) applies.
    let with_default = DecoderBuilder::new()
        .with_config_yaml_str(yaml_no_nms.to_string())
        .build()
        .unwrap();
    assert_eq!(with_default.nms, Some(configs::Nms::ClassAgnostic));

    // An explicit `with_nms(None)` on the builder is honoured when the config is silent.
    let with_none = DecoderBuilder::new()
        .with_config_yaml_str(yaml_no_nms.to_string())
        .with_nms(None)
        .build()
        .unwrap();
    assert_eq!(with_none.nms, None);
}
6510
#[test]
fn test_decoder_version_yolo26_end_to_end() {
    let yaml = r#"
outputs:
  - decoder: ultralytics
    type: detection
    shape: [1, 6, 8400]
    dshape:
      - [batch, 1]
      - [num_features, 6]
      - [num_boxes, 8400]
decoder_version: yolo26
"#;
    let yaml_with_nms = r#"
outputs:
  - decoder: ultralytics
    type: detection
    shape: [1, 6, 8400]
    dshape:
      - [batch, 1]
      - [num_features, 6]
      - [num_boxes, 8400]
decoder_version: yolo26
nms: class_agnostic
"#;

    // `decoder_version: yolo26` selects the end-to-end detection model type,
    // whether or not the config also carries an `nms` entry.
    for config in [yaml, yaml_with_nms] {
        let decoder = DecoderBuilder::new()
            .with_config_yaml_str(config.to_string())
            .build()
            .unwrap();
        assert!(matches!(
            decoder.model_type,
            ModelType::YoloEndToEndDet { .. }
        ));
    }
}
6556
6557 #[test]
6558 fn test_decoder_version_yolov8_traditional() {
6559 let yaml = r#"
6561outputs:
6562 - decoder: ultralytics
6563 type: detection
6564 shape: [1, 84, 8400]
6565 dshape:
6566 - [batch, 1]
6567 - [num_features, 84]
6568 - [num_boxes, 8400]
6569decoder_version: yolov8
6570"#;
6571 let decoder = DecoderBuilder::new()
6572 .with_config_yaml_str(yaml.to_string())
6573 .build()
6574 .unwrap();
6575 assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
6576 }
6577
6578 #[test]
6579 fn test_decoder_version_all_versions() {
6580 for version in ["yolov5", "yolov8", "yolo11"] {
6582 let yaml = format!(
6583 r#"
6584outputs:
6585 - decoder: ultralytics
6586 type: detection
6587 shape: [1, 84, 8400]
6588 dshape:
6589 - [batch, 1]
6590 - [num_features, 84]
6591 - [num_boxes, 8400]
6592decoder_version: {}
6593"#,
6594 version
6595 );
6596 let decoder = DecoderBuilder::new()
6597 .with_config_yaml_str(yaml)
6598 .build()
6599 .unwrap();
6600
6601 assert!(
6602 matches!(decoder.model_type, ModelType::YoloDet { .. }),
6603 "Expected traditional for {}",
6604 version
6605 );
6606 }
6607
6608 let yaml = r#"
6609outputs:
6610 - decoder: ultralytics
6611 type: detection
6612 shape: [1, 6, 8400]
6613 dshape:
6614 - [batch, 1]
6615 - [num_features, 6]
6616 - [num_boxes, 8400]
6617decoder_version: yolo26
6618"#
6619 .to_string();
6620
6621 let decoder = DecoderBuilder::new()
6622 .with_config_yaml_str(yaml)
6623 .build()
6624 .unwrap();
6625
6626 assert!(
6627 matches!(decoder.model_type, ModelType::YoloEndToEndDet { .. }),
6628 "Expected end to end for yolo26",
6629 );
6630 }
6631
6632 #[test]
6633 fn test_decoder_version_json() {
6634 let json = r#"{
6636 "outputs": [{
6637 "decoder": "ultralytics",
6638 "type": "detection",
6639 "shape": [1, 6, 8400],
6640 "dshape": [["batch", 1], ["num_features", 6], ["num_boxes", 8400]]
6641 }],
6642 "decoder_version": "yolo26"
6643 }"#;
6644 let decoder = DecoderBuilder::new()
6645 .with_config_json_str(json.to_string())
6646 .build()
6647 .unwrap();
6648 assert!(matches!(
6649 decoder.model_type,
6650 ModelType::YoloEndToEndDet { .. }
6651 ));
6652 }
6653
6654 #[test]
6655 fn test_decoder_version_none_uses_traditional() {
6656 let yaml = r#"
6658outputs:
6659 - decoder: ultralytics
6660 type: detection
6661 shape: [1, 84, 8400]
6662 dshape:
6663 - [batch, 1]
6664 - [num_features, 84]
6665 - [num_boxes, 8400]
6666"#;
6667 let decoder = DecoderBuilder::new()
6668 .with_config_yaml_str(yaml.to_string())
6669 .build()
6670 .unwrap();
6671 assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
6672 }
6673
6674 #[test]
6675 fn test_decoder_version_none_with_nms_none_still_traditional() {
6676 let yaml = r#"
6679outputs:
6680 - decoder: ultralytics
6681 type: detection
6682 shape: [1, 84, 8400]
6683 dshape:
6684 - [batch, 1]
6685 - [num_features, 84]
6686 - [num_boxes, 8400]
6687"#;
6688 let decoder = DecoderBuilder::new()
6689 .with_config_yaml_str(yaml.to_string())
6690 .with_nms(None) .build()
6692 .unwrap();
6693 assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
6696 }
6697
6698 #[test]
6699 fn test_decoder_heuristic_end_to_end_detection() {
6700 let yaml = r#"
6703outputs:
6704 - decoder: ultralytics
6705 type: detection
6706 shape: [1, 300, 6]
6707 dshape:
6708 - [batch, 1]
6709 - [num_boxes, 300]
6710 - [num_features, 6]
6711
6712"#;
6713 let decoder = DecoderBuilder::new()
6714 .with_config_yaml_str(yaml.to_string())
6715 .build()
6716 .unwrap();
6717 assert!(matches!(
6719 decoder.model_type,
6720 ModelType::YoloEndToEndDet { .. }
6721 ));
6722
6723 let yaml = r#"
6724outputs:
6725 - decoder: ultralytics
6726 type: detection
6727 shape: [1, 300, 38]
6728 dshape:
6729 - [batch, 1]
6730 - [num_boxes, 300]
6731 - [num_features, 38]
6732 - decoder: ultralytics
6733 type: protos
6734 shape: [1, 160, 160, 32]
6735 dshape:
6736 - [batch, 1]
6737 - [height, 160]
6738 - [width, 160]
6739 - [num_protos, 32]
6740"#;
6741 let decoder = DecoderBuilder::new()
6742 .with_config_yaml_str(yaml.to_string())
6743 .build()
6744 .unwrap();
6745 assert!(matches!(
6747 decoder.model_type,
6748 ModelType::YoloEndToEndSegDet { .. }
6749 ));
6750
6751 let yaml = r#"
6752outputs:
6753 - decoder: ultralytics
6754 type: detection
6755 shape: [1, 6, 300]
6756 dshape:
6757 - [batch, 1]
6758 - [num_features, 6]
6759 - [num_boxes, 300]
6760"#;
6761 let decoder = DecoderBuilder::new()
6762 .with_config_yaml_str(yaml.to_string())
6763 .build()
6764 .unwrap();
6765 assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
6768
6769 let yaml = r#"
6770outputs:
6771 - decoder: ultralytics
6772 type: detection
6773 shape: [1, 38, 300]
6774 dshape:
6775 - [batch, 1]
6776 - [num_features, 38]
6777 - [num_boxes, 300]
6778
6779 - decoder: ultralytics
6780 type: protos
6781 shape: [1, 160, 160, 32]
6782 dshape:
6783 - [batch, 1]
6784 - [height, 160]
6785 - [width, 160]
6786 - [num_protos, 32]
6787"#;
6788 let decoder = DecoderBuilder::new()
6789 .with_config_yaml_str(yaml.to_string())
6790 .build()
6791 .unwrap();
6792 assert!(matches!(decoder.model_type, ModelType::YoloSegDet { .. }));
6794 }
6795
6796 #[test]
6797 fn test_decoder_version_is_end_to_end() {
6798 assert!(!configs::DecoderVersion::Yolov5.is_end_to_end());
6799 assert!(!configs::DecoderVersion::Yolov8.is_end_to_end());
6800 assert!(!configs::DecoderVersion::Yolo11.is_end_to_end());
6801 assert!(configs::DecoderVersion::Yolo26.is_end_to_end());
6802 }
6803
6804 #[test]
6805 fn test_dshape_dict_format() {
6806 let json = r#"{
6808 "decoder": "ultralytics",
6809 "shape": [1, 84, 8400],
6810 "dshape": [{"batch": 1}, {"num_features": 84}, {"num_boxes": 8400}]
6811 }"#;
6812 let det: configs::Detection = serde_json::from_str(json).unwrap();
6813 assert_eq!(det.dshape.len(), 3);
6814 assert_eq!(det.dshape[0], (configs::DimName::Batch, 1));
6815 assert_eq!(det.dshape[1], (configs::DimName::NumFeatures, 84));
6816 assert_eq!(det.dshape[2], (configs::DimName::NumBoxes, 8400));
6817 }
6818
6819 #[test]
6820 fn test_dshape_tuple_format() {
6821 let json = r#"{
6823 "decoder": "ultralytics",
6824 "shape": [1, 84, 8400],
6825 "dshape": [["batch", 1], ["num_features", 84], ["num_boxes", 8400]]
6826 }"#;
6827 let det: configs::Detection = serde_json::from_str(json).unwrap();
6828 assert_eq!(det.dshape.len(), 3);
6829 assert_eq!(det.dshape[0], (configs::DimName::Batch, 1));
6830 assert_eq!(det.dshape[1], (configs::DimName::NumFeatures, 84));
6831 assert_eq!(det.dshape[2], (configs::DimName::NumBoxes, 8400));
6832 }
6833
6834 #[test]
6835 fn test_dshape_empty_default() {
6836 let json = r#"{
6838 "decoder": "ultralytics",
6839 "shape": [1, 84, 8400]
6840 }"#;
6841 let det: configs::Detection = serde_json::from_str(json).unwrap();
6842 assert!(det.dshape.is_empty());
6843 }
6844
6845 #[test]
6846 fn test_dshape_dict_format_protos() {
6847 let json = r#"{
6848 "decoder": "ultralytics",
6849 "shape": [1, 32, 160, 160],
6850 "dshape": [{"batch": 1}, {"num_protos": 32}, {"height": 160}, {"width": 160}]
6851 }"#;
6852 let protos: configs::Protos = serde_json::from_str(json).unwrap();
6853 assert_eq!(protos.dshape.len(), 4);
6854 assert_eq!(protos.dshape[0], (configs::DimName::Batch, 1));
6855 assert_eq!(protos.dshape[1], (configs::DimName::NumProtos, 32));
6856 }
6857
6858 #[test]
6859 fn test_dshape_dict_format_boxes() {
6860 let json = r#"{
6861 "decoder": "ultralytics",
6862 "shape": [1, 8400, 4],
6863 "dshape": [{"batch": 1}, {"num_boxes": 8400}, {"box_coords": 4}]
6864 }"#;
6865 let boxes: configs::Boxes = serde_json::from_str(json).unwrap();
6866 assert_eq!(boxes.dshape.len(), 3);
6867 assert_eq!(boxes.dshape[2], (configs::DimName::BoxCoords, 4));
6868 }
6869}