1use std::collections::HashSet;
5
6use ndarray::{s, Array3, ArrayView, ArrayViewD, Dimension};
7use ndarray_stats::QuantileExt;
8use num_traits::{AsPrimitive, Float};
9use serde::{Deserialize, Serialize};
10
11use crate::{
12 configs::{DecoderType, DimName, ModelType, QuantTuple},
13 dequantize_ndarray,
14 modelpack::{
15 decode_modelpack_det, decode_modelpack_float, decode_modelpack_split_float,
16 ModelPackDetectionConfig,
17 },
18 yolo::{
19 decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float, decode_yolo_segdet_quant,
20 decode_yolo_split_det_float, decode_yolo_split_det_quant, decode_yolo_split_segdet_float,
21 impl_yolo_split_segdet_quant_get_boxes, impl_yolo_split_segdet_quant_process_masks,
22 },
23 DecoderError, DecoderVersion, DetectBox, ProtoData, Quantization, Segmentation, XYWH,
24};
25
/// Top-level decoder configuration: the model's output tensor descriptions
/// plus optional NMS and YOLO-version overrides.
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
pub struct ConfigOutputs {
    /// Output tensor descriptions; empty when omitted.
    #[serde(default)]
    pub outputs: Vec<ConfigOutput>,
    /// NMS mode from the configuration. When set, `DecoderBuilder::build`
    /// prefers it over the builder's own NMS setting.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub nms: Option<configs::Nms>,
    /// Explicit YOLO generation. When absent, the YOLO resolver infers
    /// end-to-end layouts from the detection output's `dshape` instead.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub decoder_version: Option<configs::DecoderVersion>,
}
59
/// Description of one model output tensor.
///
/// Serialized as an internally tagged object: the `type` field selects the
/// variant (`detection`, `masks`, `segmentation`, `protos`, `scores`,
/// `boxes`, `mask_coefficients`, `classes`) and the remaining fields populate
/// that variant's config struct.
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
pub enum ConfigOutput {
    #[serde(rename = "detection")]
    Detection(configs::Detection),
    #[serde(rename = "masks")]
    Mask(configs::Mask),
    #[serde(rename = "segmentation")]
    Segmentation(configs::Segmentation),
    #[serde(rename = "protos")]
    Protos(configs::Protos),
    #[serde(rename = "scores")]
    Scores(configs::Scores),
    #[serde(rename = "boxes")]
    Boxes(configs::Boxes),
    #[serde(rename = "mask_coefficients")]
    MaskCoefficients(configs::MaskCoefficients),
    #[serde(rename = "classes")]
    Classes(configs::Classes),
}
80
/// Borrowed counterpart of [`ConfigOutput`], with one variant per output kind.
///
/// Lets helpers read fields common to every output config (decoder family,
/// `dshape`) without taking ownership of, or cloning, the config.
#[derive(Debug, PartialEq, Clone)]
pub enum ConfigOutputRef<'a> {
    Detection(&'a configs::Detection),
    Mask(&'a configs::Mask),
    Segmentation(&'a configs::Segmentation),
    Protos(&'a configs::Protos),
    Scores(&'a configs::Scores),
    Boxes(&'a configs::Boxes),
    MaskCoefficients(&'a configs::MaskCoefficients),
    Classes(&'a configs::Classes),
}
92
93impl<'a> ConfigOutputRef<'a> {
94 fn decoder(&self) -> configs::DecoderType {
95 match self {
96 ConfigOutputRef::Detection(v) => v.decoder,
97 ConfigOutputRef::Mask(v) => v.decoder,
98 ConfigOutputRef::Segmentation(v) => v.decoder,
99 ConfigOutputRef::Protos(v) => v.decoder,
100 ConfigOutputRef::Scores(v) => v.decoder,
101 ConfigOutputRef::Boxes(v) => v.decoder,
102 ConfigOutputRef::MaskCoefficients(v) => v.decoder,
103 ConfigOutputRef::Classes(v) => v.decoder,
104 }
105 }
106
107 fn dshape(&self) -> &[(DimName, usize)] {
108 match self {
109 ConfigOutputRef::Detection(v) => &v.dshape,
110 ConfigOutputRef::Mask(v) => &v.dshape,
111 ConfigOutputRef::Segmentation(v) => &v.dshape,
112 ConfigOutputRef::Protos(v) => &v.dshape,
113 ConfigOutputRef::Scores(v) => &v.dshape,
114 ConfigOutputRef::Boxes(v) => &v.dshape,
115 ConfigOutputRef::MaskCoefficients(v) => &v.dshape,
116 ConfigOutputRef::Classes(v) => &v.dshape,
117 }
118 }
119}
120
121impl<'a> From<&'a configs::Detection> for ConfigOutputRef<'a> {
122 fn from(v: &'a configs::Detection) -> ConfigOutputRef<'a> {
137 ConfigOutputRef::Detection(v)
138 }
139}
140
141impl<'a> From<&'a configs::Mask> for ConfigOutputRef<'a> {
142 fn from(v: &'a configs::Mask) -> ConfigOutputRef<'a> {
155 ConfigOutputRef::Mask(v)
156 }
157}
158
159impl<'a> From<&'a configs::Segmentation> for ConfigOutputRef<'a> {
160 fn from(v: &'a configs::Segmentation) -> ConfigOutputRef<'a> {
173 ConfigOutputRef::Segmentation(v)
174 }
175}
176
177impl<'a> From<&'a configs::Protos> for ConfigOutputRef<'a> {
178 fn from(v: &'a configs::Protos) -> ConfigOutputRef<'a> {
191 ConfigOutputRef::Protos(v)
192 }
193}
194
195impl<'a> From<&'a configs::Scores> for ConfigOutputRef<'a> {
196 fn from(v: &'a configs::Scores) -> ConfigOutputRef<'a> {
209 ConfigOutputRef::Scores(v)
210 }
211}
212
213impl<'a> From<&'a configs::Boxes> for ConfigOutputRef<'a> {
214 fn from(v: &'a configs::Boxes) -> ConfigOutputRef<'a> {
228 ConfigOutputRef::Boxes(v)
229 }
230}
231
232impl<'a> From<&'a configs::MaskCoefficients> for ConfigOutputRef<'a> {
233 fn from(v: &'a configs::MaskCoefficients) -> ConfigOutputRef<'a> {
246 ConfigOutputRef::MaskCoefficients(v)
247 }
248}
249
250impl<'a> From<&'a configs::Classes> for ConfigOutputRef<'a> {
251 fn from(v: &'a configs::Classes) -> ConfigOutputRef<'a> {
252 ConfigOutputRef::Classes(v)
253 }
254}
255
256impl ConfigOutput {
257 pub fn shape(&self) -> &[usize] {
274 match self {
275 ConfigOutput::Detection(detection) => &detection.shape,
276 ConfigOutput::Mask(mask) => &mask.shape,
277 ConfigOutput::Segmentation(segmentation) => &segmentation.shape,
278 ConfigOutput::Scores(scores) => &scores.shape,
279 ConfigOutput::Boxes(boxes) => &boxes.shape,
280 ConfigOutput::Protos(protos) => &protos.shape,
281 ConfigOutput::MaskCoefficients(mask_coefficients) => &mask_coefficients.shape,
282 ConfigOutput::Classes(classes) => &classes.shape,
283 }
284 }
285
286 pub fn decoder(&self) -> &configs::DecoderType {
303 match self {
304 ConfigOutput::Detection(detection) => &detection.decoder,
305 ConfigOutput::Mask(mask) => &mask.decoder,
306 ConfigOutput::Segmentation(segmentation) => &segmentation.decoder,
307 ConfigOutput::Scores(scores) => &scores.decoder,
308 ConfigOutput::Boxes(boxes) => &boxes.decoder,
309 ConfigOutput::Protos(protos) => &protos.decoder,
310 ConfigOutput::MaskCoefficients(mask_coefficients) => &mask_coefficients.decoder,
311 ConfigOutput::Classes(classes) => &classes.decoder,
312 }
313 }
314
315 pub fn quantization(&self) -> Option<QuantTuple> {
332 match self {
333 ConfigOutput::Detection(detection) => detection.quantization,
334 ConfigOutput::Mask(mask) => mask.quantization,
335 ConfigOutput::Segmentation(segmentation) => segmentation.quantization,
336 ConfigOutput::Scores(scores) => scores.quantization,
337 ConfigOutput::Boxes(boxes) => boxes.quantization,
338 ConfigOutput::Protos(protos) => protos.quantization,
339 ConfigOutput::MaskCoefficients(mask_coefficients) => mask_coefficients.quantization,
340 ConfigOutput::Classes(classes) => classes.quantization,
341 }
342 }
343}
344
/// Serde-facing configuration types that describe model outputs.
///
/// Also holds the enums used to name tensor dimensions, decoder families,
/// YOLO generations, NMS modes, resolved model layouts and element types.
pub mod configs {
    use std::collections::HashMap;
    use std::fmt::Display;

    use serde::{Deserialize, Serialize};

    /// `deserialize_with` helper for the `dshape` fields.
    ///
    /// Accepts a list whose entries are either `[name, size]` pairs or
    /// single-key `{name: size}` maps. Both forms are flattened into
    /// `(DimName, usize)` tuples, keeping list order.
    ///
    /// # Errors
    /// Fails if the input is not a list, if an entry matches neither form, or
    /// if a map entry does not have exactly one key.
    pub fn deserialize_dshape<'de, D>(deserializer: D) -> Result<Vec<(DimName, usize)>, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // `untagged`: each entry is tried as a 2-tuple first, then as a map.
        #[derive(Deserialize)]
        #[serde(untagged)]
        enum DShapeItem {
            Tuple(DimName, usize),
            Map(HashMap<DimName, usize>),
        }

        let items: Vec<DShapeItem> = Vec::deserialize(deserializer)?;
        items
            .into_iter()
            .map(|item| match item {
                DShapeItem::Tuple(name, size) => Ok((name, size)),
                DShapeItem::Map(map) => {
                    if map.len() != 1 {
                        return Err(serde::de::Error::custom(
                            "dshape map entry must have exactly one key",
                        ));
                    }
                    // Cannot fail: the length check above guarantees one entry.
                    let (name, size) = map.into_iter().next().unwrap();
                    Ok((name, size))
                }
            })
            // Collecting into `Result` stops at the first invalid entry.
            .collect()
    }

    /// Quantization parameters stored as an `(f32, i32)` pair.
    /// NOTE(review): presumably `(scale, zero_point)`; confirm against the
    /// dequantization code.
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy)]
    pub struct QuantTuple(pub f32, pub i32);
    impl From<QuantTuple> for (f32, i32) {
        fn from(value: QuantTuple) -> Self {
            (value.0, value.1)
        }
    }

    impl From<(f32, i32)> for QuantTuple {
        fn from(value: (f32, i32)) -> Self {
            QuantTuple(value.0, value.1)
        }
    }

    /// Config for a `segmentation` output. Every field falls back to its
    /// default when omitted.
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
    pub struct Segmentation {
        /// Decoder family this output belongs to (defaults to `Ultralytics`).
        #[serde(default)]
        pub decoder: DecoderType,
        /// Quantization parameters; `None` when omitted.
        #[serde(default)]
        pub quantization: Option<QuantTuple>,
        /// Positional tensor shape; empty when omitted.
        #[serde(default)]
        pub shape: Vec<usize>,
        /// Named dimensions as `(name, size)` pairs in axis order; empty when
        /// omitted. Accepts the formats handled by [`deserialize_dshape`].
        #[serde(default, deserialize_with = "deserialize_dshape")]
        pub dshape: Vec<(DimName, usize)>,
    }

    /// Config for a `protos` output. The fields mean the same as in
    /// [`Segmentation`].
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
    pub struct Protos {
        #[serde(default)]
        pub decoder: DecoderType,
        #[serde(default)]
        pub quantization: Option<QuantTuple>,
        #[serde(default)]
        pub shape: Vec<usize>,
        #[serde(default, deserialize_with = "deserialize_dshape")]
        pub dshape: Vec<(DimName, usize)>,
    }

    /// Config for a `mask_coefficients` output. The fields mean the same as
    /// in [`Segmentation`].
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
    pub struct MaskCoefficients {
        #[serde(default)]
        pub decoder: DecoderType,
        #[serde(default)]
        pub quantization: Option<QuantTuple>,
        #[serde(default)]
        pub shape: Vec<usize>,
        #[serde(default, deserialize_with = "deserialize_dshape")]
        pub dshape: Vec<(DimName, usize)>,
    }

    /// Config for a `masks` output. The fields mean the same as in
    /// [`Segmentation`].
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
    pub struct Mask {
        #[serde(default)]
        pub decoder: DecoderType,
        #[serde(default)]
        pub quantization: Option<QuantTuple>,
        #[serde(default)]
        pub shape: Vec<usize>,
        #[serde(default, deserialize_with = "deserialize_dshape")]
        pub dshape: Vec<(DimName, usize)>,
    }

    /// Config for a `detection` output, i.e. a combined box/score tensor.
    /// Shared fields mean the same as in [`Segmentation`].
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
    pub struct Detection {
        /// Optional anchor list of `[f32; 2]` entries.
        /// NOTE(review): presumably width/height pairs; confirm.
        #[serde(default)]
        pub anchors: Option<Vec<[f32; 2]>>,
        #[serde(default)]
        pub decoder: DecoderType,
        #[serde(default)]
        pub quantization: Option<QuantTuple>,
        #[serde(default)]
        pub shape: Vec<usize>,
        #[serde(default, deserialize_with = "deserialize_dshape")]
        pub dshape: Vec<(DimName, usize)>,
        /// Whether box coordinates are normalized; `None` when unspecified.
        #[serde(default)]
        pub normalized: Option<bool>,
    }

    /// Config for a `scores` output. The fields mean the same as in
    /// [`Segmentation`].
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
    pub struct Scores {
        #[serde(default)]
        pub decoder: DecoderType,
        #[serde(default)]
        pub quantization: Option<QuantTuple>,
        #[serde(default)]
        pub shape: Vec<usize>,
        #[serde(default, deserialize_with = "deserialize_dshape")]
        pub dshape: Vec<(DimName, usize)>,
    }

    /// Config for a `boxes` output. Shared fields mean the same as in
    /// [`Segmentation`].
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
    pub struct Boxes {
        #[serde(default)]
        pub decoder: DecoderType,
        #[serde(default)]
        pub quantization: Option<QuantTuple>,
        #[serde(default)]
        pub shape: Vec<usize>,
        #[serde(default, deserialize_with = "deserialize_dshape")]
        pub dshape: Vec<(DimName, usize)>,
        /// Whether box coordinates are normalized; `None` when unspecified.
        #[serde(default)]
        pub normalized: Option<bool>,
    }

    /// Config for a `classes` output. The fields mean the same as in
    /// [`Segmentation`].
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
    pub struct Classes {
        #[serde(default)]
        pub decoder: DecoderType,
        #[serde(default)]
        pub quantization: Option<QuantTuple>,
        #[serde(default)]
        pub shape: Vec<usize>,
        #[serde(default, deserialize_with = "deserialize_dshape")]
        pub dshape: Vec<(DimName, usize)>,
    }

    /// Dimension names usable in `dshape` entries. They are serialized in
    /// snake_case (for example `num_classes`).
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
    pub enum DimName {
        #[serde(rename = "batch")]
        Batch,
        #[serde(rename = "height")]
        Height,
        #[serde(rename = "width")]
        Width,
        #[serde(rename = "num_classes")]
        NumClasses,
        #[serde(rename = "num_features")]
        NumFeatures,
        #[serde(rename = "num_boxes")]
        NumBoxes,
        #[serde(rename = "num_protos")]
        NumProtos,
        #[serde(rename = "num_anchors_x_features")]
        NumAnchorsXFeatures,
        #[serde(rename = "padding")]
        Padding,
        #[serde(rename = "box_coords")]
        BoxCoords,
    }

    impl Display for DimName {
        /// Writes the same name that serde uses for the variant.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                DimName::Batch => write!(f, "batch"),
                DimName::Height => write!(f, "height"),
                DimName::Width => write!(f, "width"),
                DimName::NumClasses => write!(f, "num_classes"),
                DimName::NumFeatures => write!(f, "num_features"),
                DimName::NumBoxes => write!(f, "num_boxes"),
                DimName::NumProtos => write!(f, "num_protos"),
                DimName::NumAnchorsXFeatures => write!(f, "num_anchors_x_features"),
                DimName::Padding => write!(f, "padding"),
                DimName::BoxCoords => write!(f, "box_coords"),
            }
        }
    }

    /// Decoder family an output is meant for. Defaults to `Ultralytics`;
    /// `yolov8` is also accepted as an alias when deserializing.
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq, Default)]
    pub enum DecoderType {
        #[serde(rename = "modelpack")]
        ModelPack,
        #[default]
        #[serde(rename = "ultralytics", alias = "yolov8")]
        Ultralytics,
    }

    /// YOLO model generation. Serialized as `yolov5`, `yolov8`, `yolo11` or
    /// `yolo26`.
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
    #[serde(rename_all = "lowercase")]
    pub enum DecoderVersion {
        #[serde(rename = "yolov5")]
        Yolov5,
        #[serde(rename = "yolov8")]
        Yolov8,
        #[serde(rename = "yolo11")]
        Yolo11,
        #[serde(rename = "yolo26")]
        Yolo26,
    }

    impl DecoderVersion {
        /// Returns `true` only for `Yolo26`, the one generation the model-type
        /// resolver treats as having end-to-end outputs.
        pub fn is_end_to_end(&self) -> bool {
            matches!(self, DecoderVersion::Yolo26)
        }
    }

    /// NMS mode, serialized as `class_agnostic` or `class_aware`. The default
    /// is `ClassAgnostic`.
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq, Default)]
    #[serde(rename_all = "snake_case")]
    pub enum Nms {
        #[default]
        ClassAgnostic,
        ClassAware,
    }

    /// Model layout resolved from a set of outputs. Each variant owns exactly
    /// the output configs its decoder needs.
    ///
    /// - `Split` variants take separate boxes/scores (and classes/mask-coefficient)
    ///   tensors instead of one combined `Detection`.
    /// - `EndToEnd` variants are chosen for YOLO models whose decoder version,
    ///   or detection `dshape`, marks them as end-to-end.
    #[derive(Debug, Clone, PartialEq)]
    pub enum ModelType {
        ModelPackSegDet {
            boxes: Boxes,
            scores: Scores,
            segmentation: Segmentation,
        },
        ModelPackSegDetSplit {
            detection: Vec<Detection>,
            segmentation: Segmentation,
        },
        ModelPackDet {
            boxes: Boxes,
            scores: Scores,
        },
        ModelPackDetSplit {
            detection: Vec<Detection>,
        },
        ModelPackSeg {
            segmentation: Segmentation,
        },
        YoloDet {
            boxes: Detection,
        },
        YoloSegDet {
            boxes: Detection,
            protos: Protos,
        },
        YoloSplitDet {
            boxes: Boxes,
            scores: Scores,
        },
        YoloSplitSegDet {
            boxes: Boxes,
            scores: Scores,
            mask_coeff: MaskCoefficients,
            protos: Protos,
        },
        YoloEndToEndDet {
            boxes: Detection,
        },
        YoloEndToEndSegDet {
            boxes: Detection,
            protos: Protos,
        },
        YoloSplitEndToEndDet {
            boxes: Boxes,
            scores: Scores,
            classes: Classes,
        },
        YoloSplitEndToEndSegDet {
            boxes: Boxes,
            scores: Scores,
            classes: Classes,
            mask_coeff: MaskCoefficients,
            protos: Protos,
        },
    }

    /// Tensor element type, serialized in lowercase (for example `uint8` or
    /// `float32`). The explicit discriminants give every variant a fixed
    /// integer code.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    #[serde(rename_all = "lowercase")]
    pub enum DataType {
        Raw = 0,
        Int8 = 1,
        UInt8 = 2,
        Int16 = 3,
        UInt16 = 4,
        Float16 = 5,
        Int32 = 6,
        UInt32 = 7,
        Float32 = 8,
        Int64 = 9,
        UInt64 = 10,
        Float64 = 11,
        String = 12,
    }
}
716
/// Builder for [`Decoder`].
///
/// Holds the output-configuration source together with the IoU/score
/// thresholds and NMS mode that `DecoderBuilder::build` copies into the
/// decoder.
#[derive(Debug, Clone, PartialEq)]
pub struct DecoderBuilder {
    /// Source of the output configuration. `build` returns
    /// `DecoderError::NoConfig` while this is `None`.
    config_src: Option<ConfigSource>,
    /// IoU threshold copied into the built decoder.
    iou_threshold: f32,
    /// Score threshold copied into the built decoder.
    score_threshold: f32,
    /// NMS mode used when the configuration itself does not set one.
    nms: Option<configs::Nms>,
}
726
/// Where a `DecoderBuilder` gets its output configuration from.
#[derive(Debug, Clone, PartialEq)]
enum ConfigSource {
    /// YAML text, parsed into `ConfigOutputs` when the builder is built.
    Yaml(String),
    /// JSON text, parsed into `ConfigOutputs` when the builder is built.
    Json(String),
    /// An already-structured configuration, used as-is.
    Config(ConfigOutputs),
}
733
734impl Default for DecoderBuilder {
735 fn default() -> Self {
755 Self {
756 config_src: None,
757 iou_threshold: 0.5,
758 score_threshold: 0.5,
759 nms: Some(configs::Nms::ClassAgnostic),
760 }
761 }
762}
763
764impl DecoderBuilder {
765 pub fn new() -> Self {
785 Self::default()
786 }
787
788 pub fn with_config_yaml_str(mut self, yaml_str: String) -> Self {
805 self.config_src.replace(ConfigSource::Yaml(yaml_str));
806 self
807 }
808
809 pub fn with_config_json_str(mut self, json_str: String) -> Self {
826 self.config_src.replace(ConfigSource::Json(json_str));
827 self
828 }
829
830 pub fn with_config(mut self, config: ConfigOutputs) -> Self {
847 self.config_src.replace(ConfigSource::Config(config));
848 self
849 }
850
851 pub fn with_config_yolo_det(
876 mut self,
877 boxes: configs::Detection,
878 version: Option<DecoderVersion>,
879 ) -> Self {
880 let config = ConfigOutputs {
881 outputs: vec![ConfigOutput::Detection(boxes)],
882 decoder_version: version,
883 ..Default::default()
884 };
885 self.config_src.replace(ConfigSource::Config(config));
886 self
887 }
888
889 pub fn with_config_yolo_split_det(
916 mut self,
917 boxes: configs::Boxes,
918 scores: configs::Scores,
919 ) -> Self {
920 let config = ConfigOutputs {
921 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
922 ..Default::default()
923 };
924 self.config_src.replace(ConfigSource::Config(config));
925 self
926 }
927
928 pub fn with_config_yolo_segdet(
960 mut self,
961 boxes: configs::Detection,
962 protos: configs::Protos,
963 version: Option<DecoderVersion>,
964 ) -> Self {
965 let config = ConfigOutputs {
966 outputs: vec![ConfigOutput::Detection(boxes), ConfigOutput::Protos(protos)],
967 decoder_version: version,
968 ..Default::default()
969 };
970 self.config_src.replace(ConfigSource::Config(config));
971 self
972 }
973
974 pub fn with_config_yolo_split_segdet(
1013 mut self,
1014 boxes: configs::Boxes,
1015 scores: configs::Scores,
1016 mask_coefficients: configs::MaskCoefficients,
1017 protos: configs::Protos,
1018 ) -> Self {
1019 let config = ConfigOutputs {
1020 outputs: vec![
1021 ConfigOutput::Boxes(boxes),
1022 ConfigOutput::Scores(scores),
1023 ConfigOutput::MaskCoefficients(mask_coefficients),
1024 ConfigOutput::Protos(protos),
1025 ],
1026 ..Default::default()
1027 };
1028 self.config_src.replace(ConfigSource::Config(config));
1029 self
1030 }
1031
1032 pub fn with_config_modelpack_det(
1059 mut self,
1060 boxes: configs::Boxes,
1061 scores: configs::Scores,
1062 ) -> Self {
1063 let config = ConfigOutputs {
1064 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
1065 ..Default::default()
1066 };
1067 self.config_src.replace(ConfigSource::Config(config));
1068 self
1069 }
1070
1071 pub fn with_config_modelpack_det_split(mut self, boxes: Vec<configs::Detection>) -> Self {
1110 let outputs = boxes.into_iter().map(ConfigOutput::Detection).collect();
1111 let config = ConfigOutputs {
1112 outputs,
1113 ..Default::default()
1114 };
1115 self.config_src.replace(ConfigSource::Config(config));
1116 self
1117 }
1118
1119 pub fn with_config_modelpack_segdet(
1152 mut self,
1153 boxes: configs::Boxes,
1154 scores: configs::Scores,
1155 segmentation: configs::Segmentation,
1156 ) -> Self {
1157 let config = ConfigOutputs {
1158 outputs: vec![
1159 ConfigOutput::Boxes(boxes),
1160 ConfigOutput::Scores(scores),
1161 ConfigOutput::Segmentation(segmentation),
1162 ],
1163 ..Default::default()
1164 };
1165 self.config_src.replace(ConfigSource::Config(config));
1166 self
1167 }
1168
1169 pub fn with_config_modelpack_segdet_split(
1213 mut self,
1214 boxes: Vec<configs::Detection>,
1215 segmentation: configs::Segmentation,
1216 ) -> Self {
1217 let mut outputs = boxes
1218 .into_iter()
1219 .map(ConfigOutput::Detection)
1220 .collect::<Vec<_>>();
1221 outputs.push(ConfigOutput::Segmentation(segmentation));
1222 let config = ConfigOutputs {
1223 outputs,
1224 ..Default::default()
1225 };
1226 self.config_src.replace(ConfigSource::Config(config));
1227 self
1228 }
1229
1230 pub fn with_config_modelpack_seg(mut self, segmentation: configs::Segmentation) -> Self {
1250 let config = ConfigOutputs {
1251 outputs: vec![ConfigOutput::Segmentation(segmentation)],
1252 ..Default::default()
1253 };
1254 self.config_src.replace(ConfigSource::Config(config));
1255 self
1256 }
1257
1258 pub fn add_output(mut self, output: ConfigOutput) -> Self {
1300 if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
1301 self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
1302 }
1303 if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
1304 config.outputs.push(Self::normalize_output(output));
1305 }
1306 self
1307 }
1308
1309 pub fn with_decoder_version(mut self, version: configs::DecoderVersion) -> Self {
1338 if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
1339 self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
1340 }
1341 if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
1342 config.decoder_version = Some(version);
1343 }
1344 self
1345 }
1346
1347 fn normalize_output(mut output: ConfigOutput) -> ConfigOutput {
1349 fn normalize_shape(shape: &mut Vec<usize>, dshape: &[(configs::DimName, usize)]) {
1350 if !dshape.is_empty() {
1351 *shape = dshape.iter().map(|(_, size)| *size).collect();
1352 }
1353 }
1354 match &mut output {
1355 ConfigOutput::Detection(c) => normalize_shape(&mut c.shape, &c.dshape),
1356 ConfigOutput::Boxes(c) => normalize_shape(&mut c.shape, &c.dshape),
1357 ConfigOutput::Scores(c) => normalize_shape(&mut c.shape, &c.dshape),
1358 ConfigOutput::Protos(c) => normalize_shape(&mut c.shape, &c.dshape),
1359 ConfigOutput::Segmentation(c) => normalize_shape(&mut c.shape, &c.dshape),
1360 ConfigOutput::MaskCoefficients(c) => normalize_shape(&mut c.shape, &c.dshape),
1361 ConfigOutput::Mask(c) => normalize_shape(&mut c.shape, &c.dshape),
1362 ConfigOutput::Classes(c) => normalize_shape(&mut c.shape, &c.dshape),
1363 }
1364 output
1365 }
1366
1367 pub fn with_score_threshold(mut self, score_threshold: f32) -> Self {
1383 self.score_threshold = score_threshold;
1384 self
1385 }
1386
1387 pub fn with_iou_threshold(mut self, iou_threshold: f32) -> Self {
1404 self.iou_threshold = iou_threshold;
1405 self
1406 }
1407
1408 pub fn with_nms(mut self, nms: Option<configs::Nms>) -> Self {
1430 self.nms = nms;
1431 self
1432 }
1433
1434 pub fn build(self) -> Result<Decoder, DecoderError> {
1451 let config = match self.config_src {
1452 Some(ConfigSource::Json(s)) => serde_json::from_str(&s)?,
1453 Some(ConfigSource::Yaml(s)) => serde_yaml::from_str(&s)?,
1454 Some(ConfigSource::Config(c)) => c,
1455 None => return Err(DecoderError::NoConfig),
1456 };
1457
1458 let normalized = Self::get_normalized(&config.outputs);
1460
1461 let nms = config.nms.or(self.nms);
1463 let model_type = Self::get_model_type(config)?;
1464
1465 Ok(Decoder {
1466 model_type,
1467 iou_threshold: self.iou_threshold,
1468 score_threshold: self.score_threshold,
1469 nms,
1470 normalized,
1471 })
1472 }
1473
1474 fn get_normalized(outputs: &[ConfigOutput]) -> Option<bool> {
1479 for output in outputs {
1480 match output {
1481 ConfigOutput::Detection(det) => return det.normalized,
1482 ConfigOutput::Boxes(boxes) => return boxes.normalized,
1483 _ => {}
1484 }
1485 }
1486 None }
1488
1489 fn get_model_type(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1490 let mut yolo = false;
1492 let mut modelpack = false;
1493 for c in &configs.outputs {
1494 match c.decoder() {
1495 DecoderType::ModelPack => modelpack = true,
1496 DecoderType::Ultralytics => yolo = true,
1497 }
1498 }
1499 match (modelpack, yolo) {
1500 (true, true) => Err(DecoderError::InvalidConfig(
1501 "Both ModelPack and Yolo outputs found in config".to_string(),
1502 )),
1503 (true, false) => Self::get_model_type_modelpack(configs),
1504 (false, true) => Self::get_model_type_yolo(configs),
1505 (false, false) => Err(DecoderError::InvalidConfig(
1506 "No outputs found in config".to_string(),
1507 )),
1508 }
1509 }
1510
1511 fn get_model_type_yolo(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1512 let mut boxes = None;
1513 let mut protos = None;
1514 let mut split_boxes = None;
1515 let mut split_scores = None;
1516 let mut split_mask_coeff = None;
1517 let mut split_classes = None;
1518 for c in configs.outputs {
1519 match c {
1520 ConfigOutput::Detection(detection) => boxes = Some(detection),
1521 ConfigOutput::Segmentation(_) => {
1522 return Err(DecoderError::InvalidConfig(
1523 "Invalid Segmentation output with Yolo decoder".to_string(),
1524 ));
1525 }
1526 ConfigOutput::Protos(protos_) => protos = Some(protos_),
1527 ConfigOutput::Mask(_) => {
1528 return Err(DecoderError::InvalidConfig(
1529 "Invalid Mask output with Yolo decoder".to_string(),
1530 ));
1531 }
1532 ConfigOutput::Scores(scores) => split_scores = Some(scores),
1533 ConfigOutput::Boxes(boxes) => split_boxes = Some(boxes),
1534 ConfigOutput::MaskCoefficients(mask_coeff) => split_mask_coeff = Some(mask_coeff),
1535 ConfigOutput::Classes(classes) => split_classes = Some(classes),
1536 }
1537 }
1538
1539 let is_end_to_end_dshape = boxes.as_ref().is_some_and(|b| {
1544 let dims = b.dshape.iter().map(|(d, _)| *d).collect::<Vec<_>>();
1545 dims == vec![DimName::Batch, DimName::NumBoxes, DimName::NumFeatures]
1546 });
1547
1548 let is_end_to_end = configs
1549 .decoder_version
1550 .map(|v| v.is_end_to_end())
1551 .unwrap_or(is_end_to_end_dshape);
1552
1553 if is_end_to_end {
1554 if let Some(boxes) = boxes {
1555 if let Some(protos) = protos {
1556 Self::verify_yolo_seg_det_26(&boxes, &protos)?;
1557 return Ok(ModelType::YoloEndToEndSegDet { boxes, protos });
1558 } else {
1559 Self::verify_yolo_det_26(&boxes)?;
1560 return Ok(ModelType::YoloEndToEndDet { boxes });
1561 }
1562 } else if let (Some(split_boxes), Some(split_scores), Some(split_classes)) =
1563 (split_boxes, split_scores, split_classes)
1564 {
1565 if let (Some(split_mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
1566 Self::verify_yolo_split_end_to_end_segdet(
1567 &split_boxes,
1568 &split_scores,
1569 &split_classes,
1570 &split_mask_coeff,
1571 &protos,
1572 )?;
1573 return Ok(ModelType::YoloSplitEndToEndSegDet {
1574 boxes: split_boxes,
1575 scores: split_scores,
1576 classes: split_classes,
1577 mask_coeff: split_mask_coeff,
1578 protos,
1579 });
1580 }
1581 Self::verify_yolo_split_end_to_end_det(
1582 &split_boxes,
1583 &split_scores,
1584 &split_classes,
1585 )?;
1586 return Ok(ModelType::YoloSplitEndToEndDet {
1587 boxes: split_boxes,
1588 scores: split_scores,
1589 classes: split_classes,
1590 });
1591 } else {
1592 return Err(DecoderError::InvalidConfig(
1593 "Invalid Yolo end-to-end model outputs".to_string(),
1594 ));
1595 }
1596 }
1597
1598 if let Some(boxes) = boxes {
1599 if let Some(protos) = protos {
1600 Self::verify_yolo_seg_det(&boxes, &protos)?;
1601 Ok(ModelType::YoloSegDet { boxes, protos })
1602 } else {
1603 Self::verify_yolo_det(&boxes)?;
1604 Ok(ModelType::YoloDet { boxes })
1605 }
1606 } else if let (Some(boxes), Some(scores)) = (split_boxes, split_scores) {
1607 if let (Some(mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
1608 Self::verify_yolo_split_segdet(&boxes, &scores, &mask_coeff, &protos)?;
1609 Ok(ModelType::YoloSplitSegDet {
1610 boxes,
1611 scores,
1612 mask_coeff,
1613 protos,
1614 })
1615 } else {
1616 Self::verify_yolo_split_det(&boxes, &scores)?;
1617 Ok(ModelType::YoloSplitDet { boxes, scores })
1618 }
1619 } else {
1620 Err(DecoderError::InvalidConfig(
1621 "Invalid Yolo model outputs".to_string(),
1622 ))
1623 }
1624 }
1625
1626 fn verify_yolo_det(detect: &configs::Detection) -> Result<(), DecoderError> {
1627 if detect.shape.len() != 3 {
1628 return Err(DecoderError::InvalidConfig(format!(
1629 "Invalid Yolo Detection shape {:?}",
1630 detect.shape
1631 )));
1632 }
1633
1634 Self::verify_dshapes(
1635 &detect.dshape,
1636 &detect.shape,
1637 "Detection",
1638 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1639 )?;
1640 if !detect.dshape.is_empty() {
1641 Self::get_class_count(&detect.dshape, None, None)?;
1642 } else {
1643 Self::get_class_count_no_dshape(detect.into(), None)?;
1644 }
1645
1646 Ok(())
1647 }
1648
1649 fn verify_yolo_det_26(detect: &configs::Detection) -> Result<(), DecoderError> {
1650 if detect.shape.len() != 3 {
1651 return Err(DecoderError::InvalidConfig(format!(
1652 "Invalid Yolo Detection shape {:?}",
1653 detect.shape
1654 )));
1655 }
1656
1657 Self::verify_dshapes(
1658 &detect.dshape,
1659 &detect.shape,
1660 "Detection",
1661 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1662 )?;
1663
1664 if !detect.shape.contains(&6) {
1665 return Err(DecoderError::InvalidConfig(
1666 "Yolo26 Detection must have 6 features".to_string(),
1667 ));
1668 }
1669
1670 Ok(())
1671 }
1672
1673 fn verify_yolo_seg_det(
1674 detection: &configs::Detection,
1675 protos: &configs::Protos,
1676 ) -> Result<(), DecoderError> {
1677 if detection.shape.len() != 3 {
1678 return Err(DecoderError::InvalidConfig(format!(
1679 "Invalid Yolo Detection shape {:?}",
1680 detection.shape
1681 )));
1682 }
1683 if protos.shape.len() != 4 {
1684 return Err(DecoderError::InvalidConfig(format!(
1685 "Invalid Yolo Protos shape {:?}",
1686 protos.shape
1687 )));
1688 }
1689
1690 Self::verify_dshapes(
1691 &detection.dshape,
1692 &detection.shape,
1693 "Detection",
1694 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1695 )?;
1696 Self::verify_dshapes(
1697 &protos.dshape,
1698 &protos.shape,
1699 "Protos",
1700 &[
1701 DimName::Batch,
1702 DimName::Height,
1703 DimName::Width,
1704 DimName::NumProtos,
1705 ],
1706 )?;
1707
1708 let protos_count = Self::get_protos_count(&protos.dshape).unwrap_or(protos.shape[3]);
1709 log::debug!("Protos count: {}", protos_count);
1710 log::debug!("Detection dshape: {:?}", detection.dshape);
1711 let classes = if !detection.dshape.is_empty() {
1712 Self::get_class_count(&detection.dshape, Some(protos_count), None)?
1713 } else {
1714 Self::get_class_count_no_dshape(detection.into(), Some(protos_count))?
1715 };
1716
1717 if classes == 0 {
1718 return Err(DecoderError::InvalidConfig(
1719 "Yolo Segmentation Detection has zero classes".to_string(),
1720 ));
1721 }
1722
1723 Ok(())
1724 }
1725
1726 fn verify_yolo_seg_det_26(
1727 detection: &configs::Detection,
1728 protos: &configs::Protos,
1729 ) -> Result<(), DecoderError> {
1730 if detection.shape.len() != 3 {
1731 return Err(DecoderError::InvalidConfig(format!(
1732 "Invalid Yolo Detection shape {:?}",
1733 detection.shape
1734 )));
1735 }
1736 if protos.shape.len() != 4 {
1737 return Err(DecoderError::InvalidConfig(format!(
1738 "Invalid Yolo Protos shape {:?}",
1739 protos.shape
1740 )));
1741 }
1742
1743 Self::verify_dshapes(
1744 &detection.dshape,
1745 &detection.shape,
1746 "Detection",
1747 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1748 )?;
1749 Self::verify_dshapes(
1750 &protos.dshape,
1751 &protos.shape,
1752 "Protos",
1753 &[
1754 DimName::Batch,
1755 DimName::Height,
1756 DimName::Width,
1757 DimName::NumProtos,
1758 ],
1759 )?;
1760
1761 let protos_count = Self::get_protos_count(&protos.dshape).unwrap_or(protos.shape[3]);
1762 log::debug!("Protos count: {}", protos_count);
1763 log::debug!("Detection dshape: {:?}", detection.dshape);
1764
1765 if !detection.shape.contains(&(6 + protos_count)) {
1766 return Err(DecoderError::InvalidConfig(format!(
1767 "Yolo26 Segmentation Detection must have num_features be 6 + num_protos = {}",
1768 6 + protos_count
1769 )));
1770 }
1771
1772 Ok(())
1773 }
1774
1775 fn verify_yolo_split_det(
1776 boxes: &configs::Boxes,
1777 scores: &configs::Scores,
1778 ) -> Result<(), DecoderError> {
1779 if boxes.shape.len() != 3 {
1780 return Err(DecoderError::InvalidConfig(format!(
1781 "Invalid Yolo Split Boxes shape {:?}",
1782 boxes.shape
1783 )));
1784 }
1785 if scores.shape.len() != 3 {
1786 return Err(DecoderError::InvalidConfig(format!(
1787 "Invalid Yolo Split Scores shape {:?}",
1788 scores.shape
1789 )));
1790 }
1791
1792 Self::verify_dshapes(
1793 &boxes.dshape,
1794 &boxes.shape,
1795 "Boxes",
1796 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1797 )?;
1798 Self::verify_dshapes(
1799 &scores.dshape,
1800 &scores.shape,
1801 "Scores",
1802 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1803 )?;
1804
1805 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1806 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1807
1808 if boxes_num != scores_num {
1809 return Err(DecoderError::InvalidConfig(format!(
1810 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1811 boxes_num, scores_num
1812 )));
1813 }
1814
1815 Ok(())
1816 }
1817
1818 fn verify_yolo_split_segdet(
1819 boxes: &configs::Boxes,
1820 scores: &configs::Scores,
1821 mask_coeff: &configs::MaskCoefficients,
1822 protos: &configs::Protos,
1823 ) -> Result<(), DecoderError> {
1824 if boxes.shape.len() != 3 {
1825 return Err(DecoderError::InvalidConfig(format!(
1826 "Invalid Yolo Split Boxes shape {:?}",
1827 boxes.shape
1828 )));
1829 }
1830 if scores.shape.len() != 3 {
1831 return Err(DecoderError::InvalidConfig(format!(
1832 "Invalid Yolo Split Scores shape {:?}",
1833 scores.shape
1834 )));
1835 }
1836
1837 if mask_coeff.shape.len() != 3 {
1838 return Err(DecoderError::InvalidConfig(format!(
1839 "Invalid Yolo Split Mask Coefficients shape {:?}",
1840 mask_coeff.shape
1841 )));
1842 }
1843
1844 if protos.shape.len() != 4 {
1845 return Err(DecoderError::InvalidConfig(format!(
1846 "Invalid Yolo Protos shape {:?}",
1847 mask_coeff.shape
1848 )));
1849 }
1850
1851 Self::verify_dshapes(
1852 &boxes.dshape,
1853 &boxes.shape,
1854 "Boxes",
1855 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1856 )?;
1857 Self::verify_dshapes(
1858 &scores.dshape,
1859 &scores.shape,
1860 "Scores",
1861 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1862 )?;
1863 Self::verify_dshapes(
1864 &mask_coeff.dshape,
1865 &mask_coeff.shape,
1866 "Mask Coefficients",
1867 &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1868 )?;
1869 Self::verify_dshapes(
1870 &protos.dshape,
1871 &protos.shape,
1872 "Protos",
1873 &[
1874 DimName::Batch,
1875 DimName::Height,
1876 DimName::Width,
1877 DimName::NumProtos,
1878 ],
1879 )?;
1880
1881 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1882 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1883 let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1884
1885 let mask_channels = if !mask_coeff.dshape.is_empty() {
1886 Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1887 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1888 })?
1889 } else {
1890 mask_coeff.shape[1]
1891 };
1892 let proto_channels = if !protos.dshape.is_empty() {
1893 Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1894 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1895 })?
1896 } else {
1897 protos.shape[3]
1898 };
1899
1900 if boxes_num != scores_num {
1901 return Err(DecoderError::InvalidConfig(format!(
1902 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1903 boxes_num, scores_num
1904 )));
1905 }
1906
1907 if boxes_num != mask_num {
1908 return Err(DecoderError::InvalidConfig(format!(
1909 "Yolo Split Detection Boxes num {} incompatible with Mask Coefficients num {}",
1910 boxes_num, mask_num
1911 )));
1912 }
1913
1914 if proto_channels != mask_channels {
1915 return Err(DecoderError::InvalidConfig(format!(
1916 "Yolo Protos channels {} incompatible with Mask Coefficients channels {}",
1917 proto_channels, mask_channels
1918 )));
1919 }
1920
1921 Ok(())
1922 }
1923
1924 fn verify_yolo_split_end_to_end_det(
1925 boxes: &configs::Boxes,
1926 scores: &configs::Scores,
1927 classes: &configs::Classes,
1928 ) -> Result<(), DecoderError> {
1929 if boxes.shape.len() != 3 || !boxes.shape.contains(&4) {
1930 return Err(DecoderError::InvalidConfig(format!(
1931 "Split end-to-end boxes must be [batch, N, 4], got {:?}",
1932 boxes.shape
1933 )));
1934 }
1935 if scores.shape.len() != 3 || !scores.shape.contains(&1) {
1936 return Err(DecoderError::InvalidConfig(format!(
1937 "Split end-to-end scores must be [batch, N, 1], got {:?}",
1938 scores.shape
1939 )));
1940 }
1941 if classes.shape.len() != 3 || !classes.shape.contains(&1) {
1942 return Err(DecoderError::InvalidConfig(format!(
1943 "Split end-to-end classes must be [batch, N, 1], got {:?}",
1944 classes.shape
1945 )));
1946 }
1947 Ok(())
1948 }
1949
1950 fn verify_yolo_split_end_to_end_segdet(
1951 boxes: &configs::Boxes,
1952 scores: &configs::Scores,
1953 classes: &configs::Classes,
1954 mask_coeff: &configs::MaskCoefficients,
1955 protos: &configs::Protos,
1956 ) -> Result<(), DecoderError> {
1957 Self::verify_yolo_split_end_to_end_det(boxes, scores, classes)?;
1958 if mask_coeff.shape.len() != 3 {
1959 return Err(DecoderError::InvalidConfig(format!(
1960 "Invalid split end-to-end mask coefficients shape {:?}",
1961 mask_coeff.shape
1962 )));
1963 }
1964 if protos.shape.len() != 4 {
1965 return Err(DecoderError::InvalidConfig(format!(
1966 "Invalid protos shape {:?}",
1967 protos.shape
1968 )));
1969 }
1970 Ok(())
1971 }
1972
1973 fn get_model_type_modelpack(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1974 let mut split_decoders = Vec::new();
1975 let mut segment_ = None;
1976 let mut scores_ = None;
1977 let mut boxes_ = None;
1978 for c in configs.outputs {
1979 match c {
1980 ConfigOutput::Detection(detection) => split_decoders.push(detection),
1981 ConfigOutput::Segmentation(segmentation) => segment_ = Some(segmentation),
1982 ConfigOutput::Mask(_) => {}
1983 ConfigOutput::Protos(_) => {
1984 return Err(DecoderError::InvalidConfig(
1985 "ModelPack should not have protos".to_string(),
1986 ));
1987 }
1988 ConfigOutput::Scores(scores) => scores_ = Some(scores),
1989 ConfigOutput::Boxes(boxes) => boxes_ = Some(boxes),
1990 ConfigOutput::MaskCoefficients(_) => {
1991 return Err(DecoderError::InvalidConfig(
1992 "ModelPack should not have mask coefficients".to_string(),
1993 ));
1994 }
1995 ConfigOutput::Classes(_) => {
1996 return Err(DecoderError::InvalidConfig(
1997 "ModelPack should not have classes output".to_string(),
1998 ));
1999 }
2000 }
2001 }
2002
2003 if let Some(segmentation) = segment_ {
2004 if !split_decoders.is_empty() {
2005 let classes = Self::verify_modelpack_split_det(&split_decoders)?;
2006 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
2007 Ok(ModelType::ModelPackSegDetSplit {
2008 detection: split_decoders,
2009 segmentation,
2010 })
2011 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
2012 let classes = Self::verify_modelpack_det(&boxes, &scores)?;
2013 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
2014 Ok(ModelType::ModelPackSegDet {
2015 boxes,
2016 scores,
2017 segmentation,
2018 })
2019 } else {
2020 Self::verify_modelpack_seg(&segmentation, None)?;
2021 Ok(ModelType::ModelPackSeg { segmentation })
2022 }
2023 } else if !split_decoders.is_empty() {
2024 Self::verify_modelpack_split_det(&split_decoders)?;
2025 Ok(ModelType::ModelPackDetSplit {
2026 detection: split_decoders,
2027 })
2028 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
2029 Self::verify_modelpack_det(&boxes, &scores)?;
2030 Ok(ModelType::ModelPackDet { boxes, scores })
2031 } else {
2032 Err(DecoderError::InvalidConfig(
2033 "Invalid ModelPack model outputs".to_string(),
2034 ))
2035 }
2036 }
2037
2038 fn verify_modelpack_det(
2039 boxes: &configs::Boxes,
2040 scores: &configs::Scores,
2041 ) -> Result<usize, DecoderError> {
2042 if boxes.shape.len() != 4 {
2043 return Err(DecoderError::InvalidConfig(format!(
2044 "Invalid ModelPack Boxes shape {:?}",
2045 boxes.shape
2046 )));
2047 }
2048 if scores.shape.len() != 3 {
2049 return Err(DecoderError::InvalidConfig(format!(
2050 "Invalid ModelPack Scores shape {:?}",
2051 scores.shape
2052 )));
2053 }
2054
2055 Self::verify_dshapes(
2056 &boxes.dshape,
2057 &boxes.shape,
2058 "Boxes",
2059 &[
2060 DimName::Batch,
2061 DimName::NumBoxes,
2062 DimName::Padding,
2063 DimName::BoxCoords,
2064 ],
2065 )?;
2066 Self::verify_dshapes(
2067 &scores.dshape,
2068 &scores.shape,
2069 "Scores",
2070 &[DimName::Batch, DimName::NumBoxes, DimName::NumClasses],
2071 )?;
2072
2073 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[1]);
2074 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[1]);
2075
2076 if boxes_num != scores_num {
2077 return Err(DecoderError::InvalidConfig(format!(
2078 "ModelPack Detection Boxes num {} incompatible with Scores num {}",
2079 boxes_num, scores_num
2080 )));
2081 }
2082
2083 let num_classes = if !scores.dshape.is_empty() {
2084 Self::get_class_count(&scores.dshape, None, None)?
2085 } else {
2086 Self::get_class_count_no_dshape(scores.into(), None)?
2087 };
2088
2089 Ok(num_classes)
2090 }
2091
2092 fn verify_modelpack_split_det(boxes: &[configs::Detection]) -> Result<usize, DecoderError> {
2093 let mut num_classes = None;
2094 for b in boxes {
2095 let Some(num_anchors) = b.anchors.as_ref().map(|a| a.len()) else {
2096 return Err(DecoderError::InvalidConfig(
2097 "ModelPack Split Detection missing anchors".to_string(),
2098 ));
2099 };
2100
2101 if num_anchors == 0 {
2102 return Err(DecoderError::InvalidConfig(
2103 "ModelPack Split Detection has zero anchors".to_string(),
2104 ));
2105 }
2106
2107 if b.shape.len() != 4 {
2108 return Err(DecoderError::InvalidConfig(format!(
2109 "Invalid ModelPack Split Detection shape {:?}",
2110 b.shape
2111 )));
2112 }
2113
2114 Self::verify_dshapes(
2115 &b.dshape,
2116 &b.shape,
2117 "Split Detection",
2118 &[
2119 DimName::Batch,
2120 DimName::Height,
2121 DimName::Width,
2122 DimName::NumAnchorsXFeatures,
2123 ],
2124 )?;
2125 let classes = if !b.dshape.is_empty() {
2126 Self::get_class_count(&b.dshape, None, Some(num_anchors))?
2127 } else {
2128 Self::get_class_count_no_dshape(b.into(), None)?
2129 };
2130
2131 match num_classes {
2132 Some(n) => {
2133 if n != classes {
2134 return Err(DecoderError::InvalidConfig(format!(
2135 "ModelPack Split Detection inconsistent number of classes: previous {}, current {}",
2136 n, classes
2137 )));
2138 }
2139 }
2140 None => {
2141 num_classes = Some(classes);
2142 }
2143 }
2144 }
2145
2146 Ok(num_classes.unwrap_or(0))
2147 }
2148
2149 fn verify_modelpack_seg(
2150 segmentation: &configs::Segmentation,
2151 classes: Option<usize>,
2152 ) -> Result<(), DecoderError> {
2153 if segmentation.shape.len() != 4 {
2154 return Err(DecoderError::InvalidConfig(format!(
2155 "Invalid ModelPack Segmentation shape {:?}",
2156 segmentation.shape
2157 )));
2158 }
2159 Self::verify_dshapes(
2160 &segmentation.dshape,
2161 &segmentation.shape,
2162 "Segmentation",
2163 &[
2164 DimName::Batch,
2165 DimName::Height,
2166 DimName::Width,
2167 DimName::NumClasses,
2168 ],
2169 )?;
2170
2171 if let Some(classes) = classes {
2172 let seg_classes = if !segmentation.dshape.is_empty() {
2173 Self::get_class_count(&segmentation.dshape, None, None)?
2174 } else {
2175 Self::get_class_count_no_dshape(segmentation.into(), None)?
2176 };
2177
2178 if seg_classes != classes + 1 {
2179 return Err(DecoderError::InvalidConfig(format!(
2180 "ModelPack Segmentation channels {} incompatible with number of classes {}",
2181 seg_classes, classes
2182 )));
2183 }
2184 }
2185 Ok(())
2186 }
2187
2188 fn verify_dshapes(
2190 dshape: &[(DimName, usize)],
2191 shape: &[usize],
2192 name: &str,
2193 dims: &[DimName],
2194 ) -> Result<(), DecoderError> {
2195 for s in shape {
2196 if *s == 0 {
2197 return Err(DecoderError::InvalidConfig(format!(
2198 "{} shape has zero dimension",
2199 name
2200 )));
2201 }
2202 }
2203
2204 if shape.len() != dims.len() {
2205 return Err(DecoderError::InvalidConfig(format!(
2206 "{} shape length {} does not match expected dims length {}",
2207 name,
2208 shape.len(),
2209 dims.len()
2210 )));
2211 }
2212
2213 if dshape.is_empty() {
2214 return Ok(());
2215 }
2216 if dshape.len() != shape.len() {
2218 return Err(DecoderError::InvalidConfig(format!(
2219 "{} dshape length does not match shape length",
2220 name
2221 )));
2222 }
2223
2224 for ((dim_name, dim_size), shape_size) in dshape.iter().zip(shape) {
2226 if dim_size != shape_size {
2227 return Err(DecoderError::InvalidConfig(format!(
2228 "{} dshape dimension {} size {} does not match shape size {}",
2229 name, dim_name, dim_size, shape_size
2230 )));
2231 }
2232 if *dim_name == DimName::Padding && *dim_size != 1 {
2233 return Err(DecoderError::InvalidConfig(
2234 "Padding dimension size must be 1".to_string(),
2235 ));
2236 }
2237
2238 if *dim_name == DimName::BoxCoords && *dim_size != 4 {
2239 return Err(DecoderError::InvalidConfig(
2240 "BoxCoords dimension size must be 4".to_string(),
2241 ));
2242 }
2243 }
2244
2245 let dims_present = HashSet::<DimName>::from_iter(dshape.iter().map(|(name, _)| *name));
2246 for dim in dims {
2247 if !dims_present.contains(dim) {
2248 return Err(DecoderError::InvalidConfig(format!(
2249 "{} dshape missing required dimension {:?}",
2250 name, dim
2251 )));
2252 }
2253 }
2254
2255 Ok(())
2256 }
2257
2258 fn get_box_count(dshape: &[(DimName, usize)]) -> Option<usize> {
2259 for (dim_name, dim_size) in dshape {
2260 if *dim_name == DimName::NumBoxes {
2261 return Some(*dim_size);
2262 }
2263 }
2264 None
2265 }
2266
2267 fn get_class_count_no_dshape(
2268 config: ConfigOutputRef,
2269 protos: Option<usize>,
2270 ) -> Result<usize, DecoderError> {
2271 match config {
2272 ConfigOutputRef::Detection(detection) => match detection.decoder {
2273 DecoderType::Ultralytics => {
2274 if detection.shape[1] <= 4 + protos.unwrap_or(0) {
2275 return Err(DecoderError::InvalidConfig(format!(
2276 "Invalid shape: Yolo num_features {} must be greater than {}",
2277 detection.shape[1],
2278 4 + protos.unwrap_or(0),
2279 )));
2280 }
2281 Ok(detection.shape[1] - 4 - protos.unwrap_or(0))
2282 }
2283 DecoderType::ModelPack => {
2284 let Some(num_anchors) = detection.anchors.as_ref().map(|a| a.len()) else {
2285 return Err(DecoderError::Internal(
2286 "ModelPack Detection missing anchors".to_string(),
2287 ));
2288 };
2289 let anchors_x_features = detection.shape[3];
2290 if anchors_x_features <= num_anchors * 5 {
2291 return Err(DecoderError::InvalidConfig(format!(
2292 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2293 anchors_x_features,
2294 num_anchors * 5,
2295 )));
2296 }
2297
2298 if !anchors_x_features.is_multiple_of(num_anchors) {
2299 return Err(DecoderError::InvalidConfig(format!(
2300 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2301 anchors_x_features, num_anchors
2302 )));
2303 }
2304 Ok(anchors_x_features / num_anchors - 5)
2305 }
2306 },
2307
2308 ConfigOutputRef::Scores(scores) => match scores.decoder {
2309 DecoderType::Ultralytics => Ok(scores.shape[1]),
2310 DecoderType::ModelPack => Ok(scores.shape[2]),
2311 },
2312 ConfigOutputRef::Segmentation(seg) => Ok(seg.shape[3]),
2313 _ => Err(DecoderError::Internal(
2314 "Attempted to get class count from unsupported config output".to_owned(),
2315 )),
2316 }
2317 }
2318
2319 fn get_class_count(
2321 dshape: &[(DimName, usize)],
2322 protos: Option<usize>,
2323 anchors: Option<usize>,
2324 ) -> Result<usize, DecoderError> {
2325 if dshape.is_empty() {
2326 return Ok(0);
2327 }
2328 for (dim_name, dim_size) in dshape {
2330 if *dim_name == DimName::NumClasses {
2331 return Ok(*dim_size);
2332 }
2333 }
2334
2335 for (dim_name, dim_size) in dshape {
2338 if *dim_name == DimName::NumFeatures {
2339 let protos = protos.unwrap_or(0);
2340 if protos + 4 >= *dim_size {
2341 return Err(DecoderError::InvalidConfig(format!(
2342 "Invalid shape: Yolo num_features {} must be greater than {}",
2343 *dim_size,
2344 protos + 4,
2345 )));
2346 }
2347 return Ok(*dim_size - 4 - protos);
2348 }
2349 }
2350
2351 if let Some(num_anchors) = anchors {
2354 for (dim_name, dim_size) in dshape {
2355 if *dim_name == DimName::NumAnchorsXFeatures {
2356 let anchors_x_features = *dim_size;
2357 if anchors_x_features <= num_anchors * 5 {
2358 return Err(DecoderError::InvalidConfig(format!(
2359 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2360 anchors_x_features,
2361 num_anchors * 5,
2362 )));
2363 }
2364
2365 if !anchors_x_features.is_multiple_of(num_anchors) {
2366 return Err(DecoderError::InvalidConfig(format!(
2367 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2368 anchors_x_features, num_anchors
2369 )));
2370 }
2371 return Ok((anchors_x_features / num_anchors) - 5);
2372 }
2373 }
2374 }
2375 Err(DecoderError::InvalidConfig(
2376 "Cannot determine number of classes from dshape".to_owned(),
2377 ))
2378 }
2379
2380 fn get_protos_count(dshape: &[(DimName, usize)]) -> Option<usize> {
2381 for (dim_name, dim_size) in dshape {
2382 if *dim_name == DimName::NumProtos {
2383 return Some(*dim_size);
2384 }
2385 }
2386 None
2387 }
2388}
2389
/// Decodes raw model outputs (quantized or float) into detection boxes and,
/// for segmentation layouts, segmentation masks.
///
/// The decode path is chosen by `model_type`. The thresholds and NMS setting
/// stored here are forwarded to the per-model decode functions.
#[derive(Debug, Clone, PartialEq)]
pub struct Decoder {
    /// Output layout; `decode_quantized` / `decode_float` dispatch on it.
    model_type: ModelType,
    /// IoU threshold forwarded to the per-model decode functions.
    pub iou_threshold: f32,
    /// Score threshold forwarded to the per-model decode functions.
    /// NOTE(review): presumably candidates scoring below it are discarded — confirm.
    pub score_threshold: f32,
    /// Optional NMS setting forwarded to the YOLO decode functions
    /// (e.g. `decode_yolo_det`).
    pub nms: Option<configs::Nms>,
    /// Returned verbatim by `normalized_boxes()`.
    /// NOTE(review): presumably whether box coordinates are normalized; `None`
    /// when unknown — confirm against the code that sets it.
    normalized: Option<bool>,
}
2405
/// A dynamic-dimension `ndarray` view over one of the supported integer element
/// types, so that quantized outputs of different dtypes can be passed together
/// as a single `&[ArrayViewDQuantized]`.
#[derive(Debug)]
pub enum ArrayViewDQuantized<'a> {
    /// View with `u8` elements.
    UInt8(ArrayViewD<'a, u8>),
    /// View with `i8` elements.
    Int8(ArrayViewD<'a, i8>),
    /// View with `u16` elements.
    UInt16(ArrayViewD<'a, u16>),
    /// View with `i16` elements.
    Int16(ArrayViewD<'a, i16>),
    /// View with `u32` elements.
    UInt32(ArrayViewD<'a, u32>),
    /// View with `i32` elements.
    Int32(ArrayViewD<'a, i32>),
}
2415
2416impl<'a, D> From<ArrayView<'a, u8, D>> for ArrayViewDQuantized<'a>
2417where
2418 D: Dimension,
2419{
2420 fn from(arr: ArrayView<'a, u8, D>) -> Self {
2421 Self::UInt8(arr.into_dyn())
2422 }
2423}
2424
2425impl<'a, D> From<ArrayView<'a, i8, D>> for ArrayViewDQuantized<'a>
2426where
2427 D: Dimension,
2428{
2429 fn from(arr: ArrayView<'a, i8, D>) -> Self {
2430 Self::Int8(arr.into_dyn())
2431 }
2432}
2433
2434impl<'a, D> From<ArrayView<'a, u16, D>> for ArrayViewDQuantized<'a>
2435where
2436 D: Dimension,
2437{
2438 fn from(arr: ArrayView<'a, u16, D>) -> Self {
2439 Self::UInt16(arr.into_dyn())
2440 }
2441}
2442
2443impl<'a, D> From<ArrayView<'a, i16, D>> for ArrayViewDQuantized<'a>
2444where
2445 D: Dimension,
2446{
2447 fn from(arr: ArrayView<'a, i16, D>) -> Self {
2448 Self::Int16(arr.into_dyn())
2449 }
2450}
2451
2452impl<'a, D> From<ArrayView<'a, u32, D>> for ArrayViewDQuantized<'a>
2453where
2454 D: Dimension,
2455{
2456 fn from(arr: ArrayView<'a, u32, D>) -> Self {
2457 Self::UInt32(arr.into_dyn())
2458 }
2459}
2460
2461impl<'a, D> From<ArrayView<'a, i32, D>> for ArrayViewDQuantized<'a>
2462where
2463 D: Dimension,
2464{
2465 fn from(arr: ArrayView<'a, i32, D>) -> Self {
2466 Self::Int32(arr.into_dyn())
2467 }
2468}
2469
2470impl<'a> ArrayViewDQuantized<'a> {
2471 pub fn shape(&self) -> &[usize] {
2485 match self {
2486 ArrayViewDQuantized::UInt8(a) => a.shape(),
2487 ArrayViewDQuantized::Int8(a) => a.shape(),
2488 ArrayViewDQuantized::UInt16(a) => a.shape(),
2489 ArrayViewDQuantized::Int16(a) => a.shape(),
2490 ArrayViewDQuantized::UInt32(a) => a.shape(),
2491 ArrayViewDQuantized::Int32(a) => a.shape(),
2492 }
2493 }
2494}
2495
/// Expands `$body` once per element type of an `ArrayViewDQuantized`.
///
/// `$x` is matched against every variant. In each arm the inner `ArrayViewD`
/// binding is exposed as `$var` and `$body` is expanded with it, so the same
/// body is instantiated for u8, i8, u16, i16, u32 and i32. The macro evaluates
/// to the value of `$body`, so every expansion must produce the same type.
macro_rules! with_quantized {
    ($x:expr, $var:ident, $body:expr) => {
        match $x {
            ArrayViewDQuantized::UInt8(x) => {
                let $var = x;
                $body
            }
            ArrayViewDQuantized::Int8(x) => {
                let $var = x;
                $body
            }
            ArrayViewDQuantized::UInt16(x) => {
                let $var = x;
                $body
            }
            ArrayViewDQuantized::Int16(x) => {
                let $var = x;
                $body
            }
            ArrayViewDQuantized::UInt32(x) => {
                let $var = x;
                $body
            }
            ArrayViewDQuantized::Int32(x) => {
                let $var = x;
                $body
            }
        }
    };
}
2532
2533impl Decoder {
    /// Returns the output layout this decoder was configured for. This layout
    /// selects the decode routine used by `decode_quantized` / `decode_float`.
    pub fn model_type(&self) -> &ModelType {
        &self.model_type
    }

    /// Returns the stored `normalized` flag unchanged.
    ///
    /// NOTE(review): presumably `Some(true)` means box coordinates are
    /// normalized and `None` means it is unknown — confirm against the code that
    /// sets this field.
    pub fn normalized_boxes(&self) -> Option<bool> {
        self.normalized
    }
2584
2585 pub fn decode_quantized(
2635 &self,
2636 outputs: &[ArrayViewDQuantized],
2637 output_boxes: &mut Vec<DetectBox>,
2638 output_masks: &mut Vec<Segmentation>,
2639 ) -> Result<(), DecoderError> {
2640 output_boxes.clear();
2641 output_masks.clear();
2642 match &self.model_type {
2643 ModelType::ModelPackSegDet {
2644 boxes,
2645 scores,
2646 segmentation,
2647 } => {
2648 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
2649 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
2650 }
2651 ModelType::ModelPackSegDetSplit {
2652 detection,
2653 segmentation,
2654 } => {
2655 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
2656 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
2657 }
2658 ModelType::ModelPackDet { boxes, scores } => {
2659 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)
2660 }
2661 ModelType::ModelPackDetSplit { detection } => {
2662 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)
2663 }
2664 ModelType::ModelPackSeg { segmentation } => {
2665 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
2666 }
2667 ModelType::YoloDet { boxes } => {
2668 self.decode_yolo_det_quantized(outputs, boxes, output_boxes)
2669 }
2670 ModelType::YoloSegDet { boxes, protos } => self.decode_yolo_segdet_quantized(
2671 outputs,
2672 boxes,
2673 protos,
2674 output_boxes,
2675 output_masks,
2676 ),
2677 ModelType::YoloSplitDet { boxes, scores } => {
2678 self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)
2679 }
2680 ModelType::YoloSplitSegDet {
2681 boxes,
2682 scores,
2683 mask_coeff,
2684 protos,
2685 } => self.decode_yolo_split_segdet_quantized(
2686 outputs,
2687 boxes,
2688 scores,
2689 mask_coeff,
2690 protos,
2691 output_boxes,
2692 output_masks,
2693 ),
2694 ModelType::YoloEndToEndDet { boxes } => {
2695 self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)
2696 }
2697 ModelType::YoloEndToEndSegDet { boxes, protos } => self
2698 .decode_yolo_end_to_end_segdet_quantized(
2699 outputs,
2700 boxes,
2701 protos,
2702 output_boxes,
2703 output_masks,
2704 ),
2705 ModelType::YoloSplitEndToEndDet {
2706 boxes,
2707 scores,
2708 classes,
2709 } => self.decode_yolo_split_end_to_end_det_quantized(
2710 outputs,
2711 boxes,
2712 scores,
2713 classes,
2714 output_boxes,
2715 ),
2716 ModelType::YoloSplitEndToEndSegDet {
2717 boxes,
2718 scores,
2719 classes,
2720 mask_coeff,
2721 protos,
2722 } => self.decode_yolo_split_end_to_end_segdet_quantized(
2723 outputs,
2724 boxes,
2725 scores,
2726 classes,
2727 mask_coeff,
2728 protos,
2729 output_boxes,
2730 output_masks,
2731 ),
2732 }
2733 }
2734
    /// Decodes float model outputs into boxes and, for segmentation layouts, masks.
    ///
    /// Both output vectors are cleared first, then filled by the decode routine
    /// that matches `self.model_type`. Layouts without a segmentation head never
    /// receive `output_masks`, so it stays empty for them.
    ///
    /// # Errors
    /// Propagates any `DecoderError` returned by the selected decode routine.
    pub fn decode_float<T>(
        &self,
        outputs: &[ArrayViewD<T>],
        output_boxes: &mut Vec<DetectBox>,
        output_masks: &mut Vec<Segmentation>,
    ) -> Result<(), DecoderError>
    where
        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
        f32: AsPrimitive<T>,
    {
        output_boxes.clear();
        output_masks.clear();
        match &self.model_type {
            // ModelPack detection + segmentation: boxes first, then the mask.
            ModelType::ModelPackSegDet {
                boxes,
                scores,
                segmentation,
            } => {
                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
            }
            ModelType::ModelPackSegDetSplit {
                detection,
                segmentation,
            } => {
                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
            }
            ModelType::ModelPackDet { boxes, scores } => {
                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
            }
            ModelType::ModelPackDetSplit { detection } => {
                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
            }
            ModelType::ModelPackSeg { segmentation } => {
                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
            }
            ModelType::YoloDet { boxes } => {
                self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
            }
            ModelType::YoloSegDet { boxes, protos } => {
                self.decode_yolo_segdet_float(outputs, boxes, protos, output_boxes, output_masks)?;
            }
            ModelType::YoloSplitDet { boxes, scores } => {
                self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                self.decode_yolo_split_segdet_float(
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                )?;
            }
            ModelType::YoloEndToEndDet { boxes } => {
                self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                self.decode_yolo_end_to_end_segdet_float(
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_masks,
                )?;
            }
            ModelType::YoloSplitEndToEndDet {
                boxes,
                scores,
                classes,
            } => {
                self.decode_yolo_split_end_to_end_det_float(
                    outputs,
                    boxes,
                    scores,
                    classes,
                    output_boxes,
                )?;
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                self.decode_yolo_split_end_to_end_segdet_float(
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                )?;
            }
        }
        Ok(())
    }
2899
2900 pub fn decode_quantized_proto(
2907 &self,
2908 outputs: &[ArrayViewDQuantized],
2909 output_boxes: &mut Vec<DetectBox>,
2910 ) -> Result<Option<ProtoData>, DecoderError> {
2911 output_boxes.clear();
2912 match &self.model_type {
2913 ModelType::ModelPackSegDet { .. }
2915 | ModelType::ModelPackSegDetSplit { .. }
2916 | ModelType::ModelPackDet { .. }
2917 | ModelType::ModelPackDetSplit { .. }
2918 | ModelType::ModelPackSeg { .. }
2919 | ModelType::YoloDet { .. }
2920 | ModelType::YoloSplitDet { .. }
2921 | ModelType::YoloEndToEndDet { .. }
2922 | ModelType::YoloSplitEndToEndDet { .. } => Ok(None),
2923
2924 ModelType::YoloSegDet { boxes, protos } => {
2925 let proto =
2926 self.decode_yolo_segdet_quantized_proto(outputs, boxes, protos, output_boxes)?;
2927 Ok(Some(proto))
2928 }
2929 ModelType::YoloSplitSegDet {
2930 boxes,
2931 scores,
2932 mask_coeff,
2933 protos,
2934 } => {
2935 let proto = self.decode_yolo_split_segdet_quantized_proto(
2936 outputs,
2937 boxes,
2938 scores,
2939 mask_coeff,
2940 protos,
2941 output_boxes,
2942 )?;
2943 Ok(Some(proto))
2944 }
2945 ModelType::YoloEndToEndSegDet { boxes, protos } => {
2946 let proto = self.decode_yolo_end_to_end_segdet_quantized_proto(
2947 outputs,
2948 boxes,
2949 protos,
2950 output_boxes,
2951 )?;
2952 Ok(Some(proto))
2953 }
2954 ModelType::YoloSplitEndToEndSegDet {
2955 boxes,
2956 scores,
2957 classes,
2958 mask_coeff,
2959 protos,
2960 } => {
2961 let proto = self.decode_yolo_split_end_to_end_segdet_quantized_proto(
2962 outputs,
2963 boxes,
2964 scores,
2965 classes,
2966 mask_coeff,
2967 protos,
2968 output_boxes,
2969 )?;
2970 Ok(Some(proto))
2971 }
2972 }
2973 }
2974
2975 pub fn decode_float_proto<T>(
2981 &self,
2982 outputs: &[ArrayViewD<T>],
2983 output_boxes: &mut Vec<DetectBox>,
2984 ) -> Result<Option<ProtoData>, DecoderError>
2985 where
2986 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
2987 f32: AsPrimitive<T>,
2988 {
2989 output_boxes.clear();
2990 match &self.model_type {
2991 ModelType::ModelPackSegDet { .. }
2993 | ModelType::ModelPackSegDetSplit { .. }
2994 | ModelType::ModelPackDet { .. }
2995 | ModelType::ModelPackDetSplit { .. }
2996 | ModelType::ModelPackSeg { .. }
2997 | ModelType::YoloDet { .. }
2998 | ModelType::YoloSplitDet { .. }
2999 | ModelType::YoloEndToEndDet { .. }
3000 | ModelType::YoloSplitEndToEndDet { .. } => Ok(None),
3001
3002 ModelType::YoloSegDet { boxes, protos } => {
3003 let proto =
3004 self.decode_yolo_segdet_float_proto(outputs, boxes, protos, output_boxes)?;
3005 Ok(Some(proto))
3006 }
3007 ModelType::YoloSplitSegDet {
3008 boxes,
3009 scores,
3010 mask_coeff,
3011 protos,
3012 } => {
3013 let proto = self.decode_yolo_split_segdet_float_proto(
3014 outputs,
3015 boxes,
3016 scores,
3017 mask_coeff,
3018 protos,
3019 output_boxes,
3020 )?;
3021 Ok(Some(proto))
3022 }
3023 ModelType::YoloEndToEndSegDet { boxes, protos } => {
3024 let proto = self.decode_yolo_end_to_end_segdet_float_proto(
3025 outputs,
3026 boxes,
3027 protos,
3028 output_boxes,
3029 )?;
3030 Ok(Some(proto))
3031 }
3032 ModelType::YoloSplitEndToEndSegDet {
3033 boxes,
3034 scores,
3035 classes,
3036 mask_coeff,
3037 protos,
3038 } => {
3039 let proto = self.decode_yolo_split_end_to_end_segdet_float_proto(
3040 outputs,
3041 boxes,
3042 scores,
3043 classes,
3044 mask_coeff,
3045 protos,
3046 output_boxes,
3047 )?;
3048 Ok(Some(proto))
3049 }
3050 }
3051 }
3052
3053 fn decode_modelpack_det_quantized(
3054 &self,
3055 outputs: &[ArrayViewDQuantized],
3056 boxes: &configs::Boxes,
3057 scores: &configs::Scores,
3058 output_boxes: &mut Vec<DetectBox>,
3059 ) -> Result<(), DecoderError> {
3060 let (boxes_tensor, ind) =
3061 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
3062 let (scores_tensor, _) =
3063 Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &[ind])?;
3064 let quant_boxes = boxes
3065 .quantization
3066 .map(Quantization::from)
3067 .unwrap_or_default();
3068 let quant_scores = scores
3069 .quantization
3070 .map(Quantization::from)
3071 .unwrap_or_default();
3072
3073 with_quantized!(boxes_tensor, b, {
3074 with_quantized!(scores_tensor, s, {
3075 let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
3076 let boxes_tensor = boxes_tensor.slice(s![0, .., 0, ..]);
3077
3078 let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
3079 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3080 decode_modelpack_det(
3081 (boxes_tensor, quant_boxes),
3082 (scores_tensor, quant_scores),
3083 self.score_threshold,
3084 self.iou_threshold,
3085 output_boxes,
3086 );
3087 });
3088 });
3089
3090 Ok(())
3091 }
3092
3093 fn decode_modelpack_seg_quantized(
3094 &self,
3095 outputs: &[ArrayViewDQuantized],
3096 segmentation: &configs::Segmentation,
3097 output_masks: &mut Vec<Segmentation>,
3098 ) -> Result<(), DecoderError> {
3099 let (seg, _) = Self::find_outputs_with_shape_quantized(&segmentation.shape, outputs, &[])?;
3100
3101 macro_rules! modelpack_seg {
3102 ($seg:expr, $body:expr) => {{
3103 let seg = Self::swap_axes_if_needed($seg, segmentation.into());
3104 let seg = seg.slice(s![0, .., .., ..]);
3105 seg.mapv($body)
3106 }};
3107 }
3108 use ArrayViewDQuantized::*;
3109 let seg = match seg {
3110 UInt8(s) => {
3111 modelpack_seg!(s, |x| x)
3112 }
3113 Int8(s) => {
3114 modelpack_seg!(s, |x| (x as i16 + 128) as u8)
3115 }
3116 UInt16(s) => {
3117 modelpack_seg!(s, |x| (x >> 8) as u8)
3118 }
3119 Int16(s) => {
3120 modelpack_seg!(s, |x| ((x as i32 + 32768) >> 8) as u8)
3121 }
3122 UInt32(s) => {
3123 modelpack_seg!(s, |x| (x >> 24) as u8)
3124 }
3125 Int32(s) => {
3126 modelpack_seg!(s, |x| ((x as i64 + 2147483648) >> 24) as u8)
3127 }
3128 };
3129
3130 output_masks.push(Segmentation {
3131 xmin: 0.0,
3132 ymin: 0.0,
3133 xmax: 1.0,
3134 ymax: 1.0,
3135 segmentation: seg,
3136 });
3137 Ok(())
3138 }
3139
3140 fn decode_modelpack_det_split_quantized(
3141 &self,
3142 outputs: &[ArrayViewDQuantized],
3143 detection: &[configs::Detection],
3144 output_boxes: &mut Vec<DetectBox>,
3145 ) -> Result<(), DecoderError> {
3146 let new_detection = detection
3147 .iter()
3148 .map(|x| match &x.anchors {
3149 None => Err(DecoderError::InvalidConfig(
3150 "ModelPack Split Detection missing anchors".to_string(),
3151 )),
3152 Some(a) => Ok(ModelPackDetectionConfig {
3153 anchors: a.clone(),
3154 quantization: None,
3155 }),
3156 })
3157 .collect::<Result<Vec<_>, _>>()?;
3158 let new_outputs = Self::match_outputs_to_detect_quantized(detection, outputs)?;
3159
3160 macro_rules! dequant_output {
3161 ($det_tensor:expr, $detection:expr) => {{
3162 let det_tensor = Self::swap_axes_if_needed($det_tensor, $detection.into());
3163 let det_tensor = det_tensor.slice(s![0, .., .., ..]);
3164 if let Some(q) = $detection.quantization {
3165 dequantize_ndarray(det_tensor, q.into())
3166 } else {
3167 det_tensor.map(|x| *x as f32)
3168 }
3169 }};
3170 }
3171
3172 let new_outputs = new_outputs
3173 .iter()
3174 .zip(detection)
3175 .map(|(det_tensor, detection)| {
3176 with_quantized!(det_tensor, d, dequant_output!(d, detection))
3177 })
3178 .collect::<Vec<_>>();
3179
3180 let new_outputs_view = new_outputs
3181 .iter()
3182 .map(|d: &Array3<f32>| d.view())
3183 .collect::<Vec<_>>();
3184 decode_modelpack_split_float(
3185 &new_outputs_view,
3186 &new_detection,
3187 self.score_threshold,
3188 self.iou_threshold,
3189 output_boxes,
3190 );
3191 Ok(())
3192 }
3193
3194 fn decode_yolo_det_quantized(
3195 &self,
3196 outputs: &[ArrayViewDQuantized],
3197 boxes: &configs::Detection,
3198 output_boxes: &mut Vec<DetectBox>,
3199 ) -> Result<(), DecoderError> {
3200 let (boxes_tensor, _) =
3201 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
3202 let quant_boxes = boxes
3203 .quantization
3204 .map(Quantization::from)
3205 .unwrap_or_default();
3206
3207 with_quantized!(boxes_tensor, b, {
3208 let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
3209 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3210 decode_yolo_det(
3211 (boxes_tensor, quant_boxes),
3212 self.score_threshold,
3213 self.iou_threshold,
3214 self.nms,
3215 output_boxes,
3216 );
3217 });
3218
3219 Ok(())
3220 }
3221
3222 fn decode_yolo_segdet_quantized(
3223 &self,
3224 outputs: &[ArrayViewDQuantized],
3225 boxes: &configs::Detection,
3226 protos: &configs::Protos,
3227 output_boxes: &mut Vec<DetectBox>,
3228 output_masks: &mut Vec<Segmentation>,
3229 ) -> Result<(), DecoderError> {
3230 let (boxes_tensor, ind) =
3231 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
3232 let (protos_tensor, _) =
3233 Self::find_outputs_with_shape_quantized(&protos.shape, outputs, &[ind])?;
3234
3235 let quant_boxes = boxes
3236 .quantization
3237 .map(Quantization::from)
3238 .unwrap_or_default();
3239 let quant_protos = protos
3240 .quantization
3241 .map(Quantization::from)
3242 .unwrap_or_default();
3243
3244 with_quantized!(boxes_tensor, b, {
3245 with_quantized!(protos_tensor, p, {
3246 let box_tensor = Self::swap_axes_if_needed(b, boxes.into());
3247 let box_tensor = box_tensor.slice(s![0, .., ..]);
3248
3249 let protos_tensor = Self::swap_axes_if_needed(p, protos.into());
3250 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3251 decode_yolo_segdet_quant(
3252 (box_tensor, quant_boxes),
3253 (protos_tensor, quant_protos),
3254 self.score_threshold,
3255 self.iou_threshold,
3256 self.nms,
3257 output_boxes,
3258 output_masks,
3259 );
3260 });
3261 });
3262
3263 Ok(())
3264 }
3265
3266 fn decode_yolo_split_det_quantized(
3267 &self,
3268 outputs: &[ArrayViewDQuantized],
3269 boxes: &configs::Boxes,
3270 scores: &configs::Scores,
3271 output_boxes: &mut Vec<DetectBox>,
3272 ) -> Result<(), DecoderError> {
3273 let (boxes_tensor, ind) =
3274 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
3275 let (scores_tensor, _) =
3276 Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &[ind])?;
3277 let quant_boxes = boxes
3278 .quantization
3279 .map(Quantization::from)
3280 .unwrap_or_default();
3281 let quant_scores = scores
3282 .quantization
3283 .map(Quantization::from)
3284 .unwrap_or_default();
3285
3286 with_quantized!(boxes_tensor, b, {
3287 with_quantized!(scores_tensor, s, {
3288 let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
3289 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3290
3291 let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
3292 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3293 decode_yolo_split_det_quant(
3294 (boxes_tensor, quant_boxes),
3295 (scores_tensor, quant_scores),
3296 self.score_threshold,
3297 self.iou_threshold,
3298 self.nms,
3299 output_boxes,
3300 );
3301 });
3302 });
3303
3304 Ok(())
3305 }
3306
3307 #[allow(clippy::too_many_arguments)]
3308 fn decode_yolo_split_segdet_quantized(
3309 &self,
3310 outputs: &[ArrayViewDQuantized],
3311 boxes: &configs::Boxes,
3312 scores: &configs::Scores,
3313 mask_coeff: &configs::MaskCoefficients,
3314 protos: &configs::Protos,
3315 output_boxes: &mut Vec<DetectBox>,
3316 output_masks: &mut Vec<Segmentation>,
3317 ) -> Result<(), DecoderError> {
3318 let quant_boxes = boxes
3319 .quantization
3320 .map(Quantization::from)
3321 .unwrap_or_default();
3322 let quant_scores = scores
3323 .quantization
3324 .map(Quantization::from)
3325 .unwrap_or_default();
3326 let quant_masks = mask_coeff
3327 .quantization
3328 .map(Quantization::from)
3329 .unwrap_or_default();
3330 let quant_protos = protos
3331 .quantization
3332 .map(Quantization::from)
3333 .unwrap_or_default();
3334
3335 let mut skip = vec![];
3336
3337 let (boxes_tensor, ind) =
3338 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &skip)?;
3339 skip.push(ind);
3340
3341 let (scores_tensor, ind) =
3342 Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &skip)?;
3343 skip.push(ind);
3344
3345 let (mask_tensor, ind) =
3346 Self::find_outputs_with_shape_quantized(&mask_coeff.shape, outputs, &skip)?;
3347 skip.push(ind);
3348
3349 let (protos_tensor, _) =
3350 Self::find_outputs_with_shape_quantized(&protos.shape, outputs, &skip)?;
3351
3352 let boxes = with_quantized!(boxes_tensor, b, {
3353 with_quantized!(scores_tensor, s, {
3354 let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
3355 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3356
3357 let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
3358 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3359 impl_yolo_split_segdet_quant_get_boxes::<XYWH, _, _>(
3360 (boxes_tensor, quant_boxes),
3361 (scores_tensor, quant_scores),
3362 self.score_threshold,
3363 self.iou_threshold,
3364 self.nms,
3365 output_boxes.capacity(),
3366 )
3367 })
3368 });
3369
3370 with_quantized!(mask_tensor, m, {
3371 with_quantized!(protos_tensor, p, {
3372 let mask_tensor = Self::swap_axes_if_needed(m, mask_coeff.into());
3373 let mask_tensor = mask_tensor.slice(s![0, .., ..]);
3374
3375 let protos_tensor = Self::swap_axes_if_needed(p, protos.into());
3376 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3377 impl_yolo_split_segdet_quant_process_masks::<_, _>(
3378 boxes,
3379 (mask_tensor, quant_masks),
3380 (protos_tensor, quant_protos),
3381 output_boxes,
3382 output_masks,
3383 )
3384 })
3385 });
3386
3387 Ok(())
3388 }
3389
3390 fn decode_modelpack_det_split_float<D>(
3391 &self,
3392 outputs: &[ArrayViewD<D>],
3393 detection: &[configs::Detection],
3394 output_boxes: &mut Vec<DetectBox>,
3395 ) -> Result<(), DecoderError>
3396 where
3397 D: AsPrimitive<f32>,
3398 {
3399 let new_detection = detection
3400 .iter()
3401 .map(|x| match &x.anchors {
3402 None => Err(DecoderError::InvalidConfig(
3403 "ModelPack Split Detection missing anchors".to_string(),
3404 )),
3405 Some(a) => Ok(ModelPackDetectionConfig {
3406 anchors: a.clone(),
3407 quantization: None,
3408 }),
3409 })
3410 .collect::<Result<Vec<_>, _>>()?;
3411
3412 let new_outputs = Self::match_outputs_to_detect(detection, outputs)?;
3413 let new_outputs = new_outputs
3414 .into_iter()
3415 .map(|x| x.slice(s![0, .., .., ..]))
3416 .collect::<Vec<_>>();
3417
3418 decode_modelpack_split_float(
3419 &new_outputs,
3420 &new_detection,
3421 self.score_threshold,
3422 self.iou_threshold,
3423 output_boxes,
3424 );
3425 Ok(())
3426 }
3427
3428 fn decode_modelpack_seg_float<T>(
3429 &self,
3430 outputs: &[ArrayViewD<T>],
3431 segmentation: &configs::Segmentation,
3432 output_masks: &mut Vec<Segmentation>,
3433 ) -> Result<(), DecoderError>
3434 where
3435 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
3436 f32: AsPrimitive<T>,
3437 {
3438 let (seg, _) = Self::find_outputs_with_shape(&segmentation.shape, outputs, &[])?;
3439
3440 let seg = Self::swap_axes_if_needed(seg, segmentation.into());
3441 let seg = seg.slice(s![0, .., .., ..]);
3442 let u8_max = 255.0_f32.as_();
3443 let max = *seg.max().unwrap_or(&u8_max);
3444 let min = *seg.min().unwrap_or(&0.0_f32.as_());
3445 let seg = seg.mapv(|x| ((x - min) / (max - min) * u8_max).as_());
3446 output_masks.push(Segmentation {
3447 xmin: 0.0,
3448 ymin: 0.0,
3449 xmax: 1.0,
3450 ymax: 1.0,
3451 segmentation: seg,
3452 });
3453 Ok(())
3454 }
3455
3456 fn decode_modelpack_det_float<T>(
3457 &self,
3458 outputs: &[ArrayViewD<T>],
3459 boxes: &configs::Boxes,
3460 scores: &configs::Scores,
3461 output_boxes: &mut Vec<DetectBox>,
3462 ) -> Result<(), DecoderError>
3463 where
3464 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3465 f32: AsPrimitive<T>,
3466 {
3467 let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
3468
3469 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3470 let boxes_tensor = boxes_tensor.slice(s![0, .., 0, ..]);
3471
3472 let (scores_tensor, _) = Self::find_outputs_with_shape(&scores.shape, outputs, &[ind])?;
3473 let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
3474 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3475
3476 decode_modelpack_float(
3477 boxes_tensor,
3478 scores_tensor,
3479 self.score_threshold,
3480 self.iou_threshold,
3481 output_boxes,
3482 );
3483 Ok(())
3484 }
3485
3486 fn decode_yolo_det_float<T>(
3487 &self,
3488 outputs: &[ArrayViewD<T>],
3489 boxes: &configs::Detection,
3490 output_boxes: &mut Vec<DetectBox>,
3491 ) -> Result<(), DecoderError>
3492 where
3493 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3494 f32: AsPrimitive<T>,
3495 {
3496 let (boxes_tensor, _) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
3497
3498 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3499 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3500 decode_yolo_det_float(
3501 boxes_tensor,
3502 self.score_threshold,
3503 self.iou_threshold,
3504 self.nms,
3505 output_boxes,
3506 );
3507 Ok(())
3508 }
3509
3510 fn decode_yolo_segdet_float<T>(
3511 &self,
3512 outputs: &[ArrayViewD<T>],
3513 boxes: &configs::Detection,
3514 protos: &configs::Protos,
3515 output_boxes: &mut Vec<DetectBox>,
3516 output_masks: &mut Vec<Segmentation>,
3517 ) -> Result<(), DecoderError>
3518 where
3519 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3520 f32: AsPrimitive<T>,
3521 {
3522 let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
3523
3524 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3525 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3526
3527 let (protos_tensor, _) = Self::find_outputs_with_shape(&protos.shape, outputs, &[ind])?;
3528
3529 let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos.into());
3530 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3531 decode_yolo_segdet_float(
3532 boxes_tensor,
3533 protos_tensor,
3534 self.score_threshold,
3535 self.iou_threshold,
3536 self.nms,
3537 output_boxes,
3538 output_masks,
3539 );
3540 Ok(())
3541 }
3542
3543 fn decode_yolo_split_det_float<T>(
3544 &self,
3545 outputs: &[ArrayViewD<T>],
3546 boxes: &configs::Boxes,
3547 scores: &configs::Scores,
3548 output_boxes: &mut Vec<DetectBox>,
3549 ) -> Result<(), DecoderError>
3550 where
3551 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3552 f32: AsPrimitive<T>,
3553 {
3554 let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
3555 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3556 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3557
3558 let (scores_tensor, _) = Self::find_outputs_with_shape(&scores.shape, outputs, &[ind])?;
3559
3560 let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
3561 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3562
3563 decode_yolo_split_det_float(
3564 boxes_tensor,
3565 scores_tensor,
3566 self.score_threshold,
3567 self.iou_threshold,
3568 self.nms,
3569 output_boxes,
3570 );
3571 Ok(())
3572 }
3573
3574 #[allow(clippy::too_many_arguments)]
3575 fn decode_yolo_split_segdet_float<T>(
3576 &self,
3577 outputs: &[ArrayViewD<T>],
3578 boxes: &configs::Boxes,
3579 scores: &configs::Scores,
3580 mask_coeff: &configs::MaskCoefficients,
3581 protos: &configs::Protos,
3582 output_boxes: &mut Vec<DetectBox>,
3583 output_masks: &mut Vec<Segmentation>,
3584 ) -> Result<(), DecoderError>
3585 where
3586 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3587 f32: AsPrimitive<T>,
3588 {
3589 let mut skip = vec![];
3590 let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &skip)?;
3591
3592 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3593 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3594 skip.push(ind);
3595
3596 let (scores_tensor, ind) = Self::find_outputs_with_shape(&scores.shape, outputs, &skip)?;
3597
3598 let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
3599 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3600 skip.push(ind);
3601
3602 let (mask_tensor, ind) = Self::find_outputs_with_shape(&mask_coeff.shape, outputs, &skip)?;
3603 let mask_tensor = Self::swap_axes_if_needed(mask_tensor, mask_coeff.into());
3604 let mask_tensor = mask_tensor.slice(s![0, .., ..]);
3605 skip.push(ind);
3606
3607 let (protos_tensor, _) = Self::find_outputs_with_shape(&protos.shape, outputs, &skip)?;
3608 let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos.into());
3609 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3610 decode_yolo_split_segdet_float(
3611 boxes_tensor,
3612 scores_tensor,
3613 mask_tensor,
3614 protos_tensor,
3615 self.score_threshold,
3616 self.iou_threshold,
3617 self.nms,
3618 output_boxes,
3619 output_masks,
3620 );
3621 Ok(())
3622 }
3623
3624 fn decode_yolo_end_to_end_det_float<T>(
3630 &self,
3631 outputs: &[ArrayViewD<T>],
3632 boxes_config: &configs::Detection,
3633 output_boxes: &mut Vec<DetectBox>,
3634 ) -> Result<(), DecoderError>
3635 where
3636 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3637 f32: AsPrimitive<T>,
3638 {
3639 let (det_tensor, _) = Self::find_outputs_with_shape(&boxes_config.shape, outputs, &[])?;
3640 let det_tensor = Self::swap_axes_if_needed(det_tensor, boxes_config.into());
3641 let det_tensor = det_tensor.slice(s![0, .., ..]);
3642
3643 crate::yolo::decode_yolo_end_to_end_det_float(
3644 det_tensor,
3645 self.score_threshold,
3646 output_boxes,
3647 )?;
3648 Ok(())
3649 }
3650
3651 fn decode_yolo_end_to_end_segdet_float<T>(
3659 &self,
3660 outputs: &[ArrayViewD<T>],
3661 boxes_config: &configs::Detection,
3662 protos_config: &configs::Protos,
3663 output_boxes: &mut Vec<DetectBox>,
3664 output_masks: &mut Vec<Segmentation>,
3665 ) -> Result<(), DecoderError>
3666 where
3667 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3668 f32: AsPrimitive<T>,
3669 {
3670 if outputs.len() < 2 {
3671 return Err(DecoderError::InvalidShape(
3672 "End-to-end segdet requires detection and protos outputs".to_string(),
3673 ));
3674 }
3675
3676 let (det_tensor, det_ind) =
3677 Self::find_outputs_with_shape(&boxes_config.shape, outputs, &[])?;
3678 let det_tensor = Self::swap_axes_if_needed(det_tensor, boxes_config.into());
3679 let det_tensor = det_tensor.slice(s![0, .., ..]);
3680
3681 let (protos_tensor, _) =
3682 Self::find_outputs_with_shape(&protos_config.shape, outputs, &[det_ind])?;
3683 let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos_config.into());
3684 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3685
3686 crate::yolo::decode_yolo_end_to_end_segdet_float(
3687 det_tensor,
3688 protos_tensor,
3689 self.score_threshold,
3690 output_boxes,
3691 output_masks,
3692 )?;
3693 Ok(())
3694 }
3695
3696 fn decode_yolo_end_to_end_det_quantized(
3699 &self,
3700 outputs: &[ArrayViewDQuantized],
3701 boxes_config: &configs::Detection,
3702 output_boxes: &mut Vec<DetectBox>,
3703 ) -> Result<(), DecoderError> {
3704 let (det_tensor, _) =
3705 Self::find_outputs_with_shape_quantized(&boxes_config.shape, outputs, &[])?;
3706 let quant = boxes_config
3707 .quantization
3708 .map(Quantization::from)
3709 .unwrap_or_default();
3710
3711 with_quantized!(det_tensor, d, {
3712 let d = Self::swap_axes_if_needed(d, boxes_config.into());
3713 let d = d.slice(s![0, .., ..]);
3714 let dequant = d.map(|v| {
3715 let val: f32 = v.as_();
3716 (val - quant.zero_point as f32) * quant.scale
3717 });
3718 crate::yolo::decode_yolo_end_to_end_det_float(
3719 dequant.view(),
3720 self.score_threshold,
3721 output_boxes,
3722 )?;
3723 });
3724 Ok(())
3725 }
3726
3727 #[allow(clippy::too_many_arguments)]
3729 fn decode_yolo_end_to_end_segdet_quantized(
3730 &self,
3731 outputs: &[ArrayViewDQuantized],
3732 boxes_config: &configs::Detection,
3733 protos_config: &configs::Protos,
3734 output_boxes: &mut Vec<DetectBox>,
3735 output_masks: &mut Vec<Segmentation>,
3736 ) -> Result<(), DecoderError> {
3737 let (det_tensor, det_ind) =
3738 Self::find_outputs_with_shape_quantized(&boxes_config.shape, outputs, &[])?;
3739 let (protos_tensor, _) =
3740 Self::find_outputs_with_shape_quantized(&protos_config.shape, outputs, &[det_ind])?;
3741
3742 let quant_det = boxes_config
3743 .quantization
3744 .map(Quantization::from)
3745 .unwrap_or_default();
3746 let quant_protos = protos_config
3747 .quantization
3748 .map(Quantization::from)
3749 .unwrap_or_default();
3750
3751 macro_rules! dequant_3d {
3754 ($tensor:expr, $config:expr, $quant:expr) => {{
3755 with_quantized!($tensor, t, {
3756 let t = Self::swap_axes_if_needed(t, $config.into());
3757 let t = t.slice(s![0, .., ..]);
3758 t.map(|v| {
3759 let val: f32 = v.as_();
3760 (val - $quant.zero_point as f32) * $quant.scale
3761 })
3762 })
3763 }};
3764 }
3765 macro_rules! dequant_4d {
3766 ($tensor:expr, $config:expr, $quant:expr) => {{
3767 with_quantized!($tensor, t, {
3768 let t = Self::swap_axes_if_needed(t, $config.into());
3769 let t = t.slice(s![0, .., .., ..]);
3770 t.map(|v| {
3771 let val: f32 = v.as_();
3772 (val - $quant.zero_point as f32) * $quant.scale
3773 })
3774 })
3775 }};
3776 }
3777
3778 let dequant_d = dequant_3d!(det_tensor, boxes_config, quant_det);
3779 let dequant_p = dequant_4d!(protos_tensor, protos_config, quant_protos);
3780
3781 crate::yolo::decode_yolo_end_to_end_segdet_float(
3782 dequant_d.view(),
3783 dequant_p.view(),
3784 self.score_threshold,
3785 output_boxes,
3786 output_masks,
3787 )?;
3788 Ok(())
3789 }
3790
3791 fn decode_yolo_split_end_to_end_det_float<T>(
3793 &self,
3794 outputs: &[ArrayViewD<T>],
3795 boxes_config: &configs::Boxes,
3796 scores_config: &configs::Scores,
3797 classes_config: &configs::Classes,
3798 output_boxes: &mut Vec<DetectBox>,
3799 ) -> Result<(), DecoderError>
3800 where
3801 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3802 f32: AsPrimitive<T>,
3803 {
3804 let mut skip = vec![];
3805 let (boxes_tensor, ind) =
3806 Self::find_outputs_with_shape(&boxes_config.shape, outputs, &skip)?;
3807 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes_config.into());
3808 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3809 skip.push(ind);
3810
3811 let (scores_tensor, ind) =
3812 Self::find_outputs_with_shape(&scores_config.shape, outputs, &skip)?;
3813 let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores_config.into());
3814 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3815 skip.push(ind);
3816
3817 let (classes_tensor, _) =
3818 Self::find_outputs_with_shape(&classes_config.shape, outputs, &skip)?;
3819 let classes_tensor = Self::swap_axes_if_needed(classes_tensor, classes_config.into());
3820 let classes_tensor = classes_tensor.slice(s![0, .., ..]);
3821
3822 crate::yolo::decode_yolo_split_end_to_end_det_float(
3823 boxes_tensor,
3824 scores_tensor,
3825 classes_tensor,
3826 self.score_threshold,
3827 output_boxes,
3828 )?;
3829 Ok(())
3830 }
3831
3832 #[allow(clippy::too_many_arguments)]
3834 fn decode_yolo_split_end_to_end_segdet_float<T>(
3835 &self,
3836 outputs: &[ArrayViewD<T>],
3837 boxes_config: &configs::Boxes,
3838 scores_config: &configs::Scores,
3839 classes_config: &configs::Classes,
3840 mask_coeff_config: &configs::MaskCoefficients,
3841 protos_config: &configs::Protos,
3842 output_boxes: &mut Vec<DetectBox>,
3843 output_masks: &mut Vec<Segmentation>,
3844 ) -> Result<(), DecoderError>
3845 where
3846 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3847 f32: AsPrimitive<T>,
3848 {
3849 let mut skip = vec![];
3850 let (boxes_tensor, ind) =
3851 Self::find_outputs_with_shape(&boxes_config.shape, outputs, &skip)?;
3852 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes_config.into());
3853 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3854 skip.push(ind);
3855
3856 let (scores_tensor, ind) =
3857 Self::find_outputs_with_shape(&scores_config.shape, outputs, &skip)?;
3858 let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores_config.into());
3859 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3860 skip.push(ind);
3861
3862 let (classes_tensor, ind) =
3863 Self::find_outputs_with_shape(&classes_config.shape, outputs, &skip)?;
3864 let classes_tensor = Self::swap_axes_if_needed(classes_tensor, classes_config.into());
3865 let classes_tensor = classes_tensor.slice(s![0, .., ..]);
3866 skip.push(ind);
3867
3868 let (mask_tensor, ind) =
3869 Self::find_outputs_with_shape(&mask_coeff_config.shape, outputs, &skip)?;
3870 let mask_tensor = Self::swap_axes_if_needed(mask_tensor, mask_coeff_config.into());
3871 let mask_tensor = mask_tensor.slice(s![0, .., ..]);
3872 skip.push(ind);
3873
3874 let (protos_tensor, _) =
3875 Self::find_outputs_with_shape(&protos_config.shape, outputs, &skip)?;
3876 let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos_config.into());
3877 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3878
3879 crate::yolo::decode_yolo_split_end_to_end_segdet_float(
3880 boxes_tensor,
3881 scores_tensor,
3882 classes_tensor,
3883 mask_tensor,
3884 protos_tensor,
3885 self.score_threshold,
3886 output_boxes,
3887 output_masks,
3888 )?;
3889 Ok(())
3890 }
3891
3892 fn decode_yolo_split_end_to_end_det_quantized(
3895 &self,
3896 outputs: &[ArrayViewDQuantized],
3897 boxes_config: &configs::Boxes,
3898 scores_config: &configs::Scores,
3899 classes_config: &configs::Classes,
3900 output_boxes: &mut Vec<DetectBox>,
3901 ) -> Result<(), DecoderError> {
3902 let mut skip = vec![];
3903 let (boxes_tensor, ind) =
3904 Self::find_outputs_with_shape_quantized(&boxes_config.shape, outputs, &skip)?;
3905 skip.push(ind);
3906 let (scores_tensor, ind) =
3907 Self::find_outputs_with_shape_quantized(&scores_config.shape, outputs, &skip)?;
3908 skip.push(ind);
3909 let (classes_tensor, _) =
3910 Self::find_outputs_with_shape_quantized(&classes_config.shape, outputs, &skip)?;
3911
3912 let quant_boxes = boxes_config
3913 .quantization
3914 .map(Quantization::from)
3915 .unwrap_or_default();
3916 let quant_scores = scores_config
3917 .quantization
3918 .map(Quantization::from)
3919 .unwrap_or_default();
3920 let quant_classes = classes_config
3921 .quantization
3922 .map(Quantization::from)
3923 .unwrap_or_default();
3924
3925 macro_rules! dequant_3d {
3928 ($tensor:expr, $config:expr, $quant:expr) => {{
3929 with_quantized!($tensor, t, {
3930 let t = Self::swap_axes_if_needed(t, $config.into());
3931 let t = t.slice(s![0, .., ..]);
3932 t.map(|v| {
3933 let val: f32 = v.as_();
3934 (val - $quant.zero_point as f32) * $quant.scale
3935 })
3936 })
3937 }};
3938 }
3939
3940 let dequant_b = dequant_3d!(boxes_tensor, boxes_config, quant_boxes);
3941 let dequant_s = dequant_3d!(scores_tensor, scores_config, quant_scores);
3942 let dequant_c = dequant_3d!(classes_tensor, classes_config, quant_classes);
3943
3944 crate::yolo::decode_yolo_split_end_to_end_det_float(
3945 dequant_b.view(),
3946 dequant_s.view(),
3947 dequant_c.view(),
3948 self.score_threshold,
3949 output_boxes,
3950 )?;
3951 Ok(())
3952 }
3953
3954 #[allow(clippy::too_many_arguments)]
3956 fn decode_yolo_split_end_to_end_segdet_quantized(
3957 &self,
3958 outputs: &[ArrayViewDQuantized],
3959 boxes_config: &configs::Boxes,
3960 scores_config: &configs::Scores,
3961 classes_config: &configs::Classes,
3962 mask_coeff_config: &configs::MaskCoefficients,
3963 protos_config: &configs::Protos,
3964 output_boxes: &mut Vec<DetectBox>,
3965 output_masks: &mut Vec<Segmentation>,
3966 ) -> Result<(), DecoderError> {
3967 let mut skip = vec![];
3968 let (boxes_tensor, ind) =
3969 Self::find_outputs_with_shape_quantized(&boxes_config.shape, outputs, &skip)?;
3970 skip.push(ind);
3971 let (scores_tensor, ind) =
3972 Self::find_outputs_with_shape_quantized(&scores_config.shape, outputs, &skip)?;
3973 skip.push(ind);
3974 let (classes_tensor, ind) =
3975 Self::find_outputs_with_shape_quantized(&classes_config.shape, outputs, &skip)?;
3976 skip.push(ind);
3977 let (mask_tensor, ind) =
3978 Self::find_outputs_with_shape_quantized(&mask_coeff_config.shape, outputs, &skip)?;
3979 skip.push(ind);
3980 let (protos_tensor, _) =
3981 Self::find_outputs_with_shape_quantized(&protos_config.shape, outputs, &skip)?;
3982
3983 let quant_boxes = boxes_config
3984 .quantization
3985 .map(Quantization::from)
3986 .unwrap_or_default();
3987 let quant_scores = scores_config
3988 .quantization
3989 .map(Quantization::from)
3990 .unwrap_or_default();
3991 let quant_classes = classes_config
3992 .quantization
3993 .map(Quantization::from)
3994 .unwrap_or_default();
3995 let quant_masks = mask_coeff_config
3996 .quantization
3997 .map(Quantization::from)
3998 .unwrap_or_default();
3999 let quant_protos = protos_config
4000 .quantization
4001 .map(Quantization::from)
4002 .unwrap_or_default();
4003
4004 macro_rules! dequant_3d {
4007 ($tensor:expr, $config:expr, $quant:expr) => {{
4008 with_quantized!($tensor, t, {
4009 let t = Self::swap_axes_if_needed(t, $config.into());
4010 let t = t.slice(s![0, .., ..]);
4011 t.map(|v| {
4012 let val: f32 = v.as_();
4013 (val - $quant.zero_point as f32) * $quant.scale
4014 })
4015 })
4016 }};
4017 }
4018 macro_rules! dequant_4d {
4019 ($tensor:expr, $config:expr, $quant:expr) => {{
4020 with_quantized!($tensor, t, {
4021 let t = Self::swap_axes_if_needed(t, $config.into());
4022 let t = t.slice(s![0, .., .., ..]);
4023 t.map(|v| {
4024 let val: f32 = v.as_();
4025 (val - $quant.zero_point as f32) * $quant.scale
4026 })
4027 })
4028 }};
4029 }
4030
4031 let dequant_b = dequant_3d!(boxes_tensor, boxes_config, quant_boxes);
4032 let dequant_s = dequant_3d!(scores_tensor, scores_config, quant_scores);
4033 let dequant_c = dequant_3d!(classes_tensor, classes_config, quant_classes);
4034 let dequant_m = dequant_3d!(mask_tensor, mask_coeff_config, quant_masks);
4035 let dequant_p = dequant_4d!(protos_tensor, protos_config, quant_protos);
4036
4037 crate::yolo::decode_yolo_split_end_to_end_segdet_float(
4038 dequant_b.view(),
4039 dequant_s.view(),
4040 dequant_c.view(),
4041 dequant_m.view(),
4042 dequant_p.view(),
4043 self.score_threshold,
4044 output_boxes,
4045 output_masks,
4046 )?;
4047 Ok(())
4048 }
4049
4050 fn decode_yolo_segdet_quantized_proto(
4055 &self,
4056 outputs: &[ArrayViewDQuantized],
4057 boxes: &configs::Detection,
4058 protos: &configs::Protos,
4059 output_boxes: &mut Vec<DetectBox>,
4060 ) -> Result<ProtoData, DecoderError> {
4061 let (boxes_tensor, ind) =
4062 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
4063 let (protos_tensor, _) =
4064 Self::find_outputs_with_shape_quantized(&protos.shape, outputs, &[ind])?;
4065
4066 let quant_boxes = boxes
4067 .quantization
4068 .map(Quantization::from)
4069 .unwrap_or_default();
4070 let quant_protos = protos
4071 .quantization
4072 .map(Quantization::from)
4073 .unwrap_or_default();
4074
4075 let proto = with_quantized!(boxes_tensor, b, {
4076 with_quantized!(protos_tensor, p, {
4077 let box_tensor = Self::swap_axes_if_needed(b, boxes.into());
4078 let box_tensor = box_tensor.slice(s![0, .., ..]);
4079
4080 let protos_tensor = Self::swap_axes_if_needed(p, protos.into());
4081 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
4082 crate::yolo::impl_yolo_segdet_quant_proto::<XYWH, _, _>(
4083 (box_tensor, quant_boxes),
4084 (protos_tensor, quant_protos),
4085 self.score_threshold,
4086 self.iou_threshold,
4087 self.nms,
4088 output_boxes,
4089 )
4090 })
4091 });
4092 Ok(proto)
4093 }
4094
4095 fn decode_yolo_segdet_float_proto<T>(
4096 &self,
4097 outputs: &[ArrayViewD<T>],
4098 boxes: &configs::Detection,
4099 protos: &configs::Protos,
4100 output_boxes: &mut Vec<DetectBox>,
4101 ) -> Result<ProtoData, DecoderError>
4102 where
4103 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
4104 f32: AsPrimitive<T>,
4105 {
4106 let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
4107 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
4108 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
4109
4110 let (protos_tensor, _) = Self::find_outputs_with_shape(&protos.shape, outputs, &[ind])?;
4111 let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos.into());
4112 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
4113
4114 Ok(crate::yolo::impl_yolo_segdet_float_proto::<XYWH, _, _>(
4115 boxes_tensor,
4116 protos_tensor,
4117 self.score_threshold,
4118 self.iou_threshold,
4119 self.nms,
4120 output_boxes,
4121 ))
4122 }
4123
4124 #[allow(clippy::too_many_arguments)]
4125 fn decode_yolo_split_segdet_quantized_proto(
4126 &self,
4127 outputs: &[ArrayViewDQuantized],
4128 boxes: &configs::Boxes,
4129 scores: &configs::Scores,
4130 mask_coeff: &configs::MaskCoefficients,
4131 protos: &configs::Protos,
4132 output_boxes: &mut Vec<DetectBox>,
4133 ) -> Result<ProtoData, DecoderError> {
4134 let quant_boxes = boxes
4135 .quantization
4136 .map(Quantization::from)
4137 .unwrap_or_default();
4138 let quant_scores = scores
4139 .quantization
4140 .map(Quantization::from)
4141 .unwrap_or_default();
4142 let quant_masks = mask_coeff
4143 .quantization
4144 .map(Quantization::from)
4145 .unwrap_or_default();
4146 let quant_protos = protos
4147 .quantization
4148 .map(Quantization::from)
4149 .unwrap_or_default();
4150
4151 let mut skip = vec![];
4152
4153 let (boxes_tensor, ind) =
4154 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &skip)?;
4155 skip.push(ind);
4156
4157 let (scores_tensor, ind) =
4158 Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &skip)?;
4159 skip.push(ind);
4160
4161 let (mask_tensor, ind) =
4162 Self::find_outputs_with_shape_quantized(&mask_coeff.shape, outputs, &skip)?;
4163 skip.push(ind);
4164
4165 let (protos_tensor, _) =
4166 Self::find_outputs_with_shape_quantized(&protos.shape, outputs, &skip)?;
4167
4168 let det_indices = with_quantized!(boxes_tensor, b, {
4170 with_quantized!(scores_tensor, s, {
4171 let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
4172 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
4173
4174 let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
4175 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
4176
4177 impl_yolo_split_segdet_quant_get_boxes::<XYWH, _, _>(
4178 (boxes_tensor, quant_boxes),
4179 (scores_tensor, quant_scores),
4180 self.score_threshold,
4181 self.iou_threshold,
4182 self.nms,
4183 output_boxes.capacity(),
4184 )
4185 })
4186 });
4187
4188 let proto = with_quantized!(mask_tensor, m, {
4190 with_quantized!(protos_tensor, p, {
4191 let mask_tensor = Self::swap_axes_if_needed(m, mask_coeff.into());
4192 let mask_tensor = mask_tensor.slice(s![0, .., ..]);
4193 let mask_tensor = mask_tensor.reversed_axes();
4194
4195 let protos_tensor = Self::swap_axes_if_needed(p, protos.into());
4196 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
4197
4198 crate::yolo::extract_proto_data_quant(
4199 det_indices,
4200 mask_tensor,
4201 quant_masks,
4202 protos_tensor,
4203 quant_protos,
4204 output_boxes,
4205 )
4206 })
4207 });
4208 Ok(proto)
4209 }
4210
4211 #[allow(clippy::too_many_arguments)]
4212 fn decode_yolo_split_segdet_float_proto<T>(
4213 &self,
4214 outputs: &[ArrayViewD<T>],
4215 boxes: &configs::Boxes,
4216 scores: &configs::Scores,
4217 mask_coeff: &configs::MaskCoefficients,
4218 protos: &configs::Protos,
4219 output_boxes: &mut Vec<DetectBox>,
4220 ) -> Result<ProtoData, DecoderError>
4221 where
4222 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
4223 f32: AsPrimitive<T>,
4224 {
4225 let mut skip = vec![];
4226 let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &skip)?;
4227 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
4228 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
4229 skip.push(ind);
4230
4231 let (scores_tensor, ind) = Self::find_outputs_with_shape(&scores.shape, outputs, &skip)?;
4232 let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
4233 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
4234 skip.push(ind);
4235
4236 let (mask_tensor, ind) = Self::find_outputs_with_shape(&mask_coeff.shape, outputs, &skip)?;
4237 let mask_tensor = Self::swap_axes_if_needed(mask_tensor, mask_coeff.into());
4238 let mask_tensor = mask_tensor.slice(s![0, .., ..]);
4239 skip.push(ind);
4240
4241 let (protos_tensor, _) = Self::find_outputs_with_shape(&protos.shape, outputs, &skip)?;
4242 let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos.into());
4243 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
4244
4245 Ok(crate::yolo::impl_yolo_split_segdet_float_proto::<
4246 XYWH,
4247 _,
4248 _,
4249 _,
4250 _,
4251 >(
4252 boxes_tensor,
4253 scores_tensor,
4254 mask_tensor,
4255 protos_tensor,
4256 self.score_threshold,
4257 self.iou_threshold,
4258 self.nms,
4259 output_boxes,
4260 ))
4261 }
4262
4263 fn decode_yolo_end_to_end_segdet_float_proto<T>(
4264 &self,
4265 outputs: &[ArrayViewD<T>],
4266 boxes_config: &configs::Detection,
4267 protos_config: &configs::Protos,
4268 output_boxes: &mut Vec<DetectBox>,
4269 ) -> Result<ProtoData, DecoderError>
4270 where
4271 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
4272 f32: AsPrimitive<T>,
4273 {
4274 if outputs.len() < 2 {
4275 return Err(DecoderError::InvalidShape(
4276 "End-to-end segdet requires detection and protos outputs".to_string(),
4277 ));
4278 }
4279
4280 let (det_tensor, det_ind) =
4281 Self::find_outputs_with_shape(&boxes_config.shape, outputs, &[])?;
4282 let det_tensor = Self::swap_axes_if_needed(det_tensor, boxes_config.into());
4283 let det_tensor = det_tensor.slice(s![0, .., ..]);
4284
4285 let (protos_tensor, _) =
4286 Self::find_outputs_with_shape(&protos_config.shape, outputs, &[det_ind])?;
4287 let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos_config.into());
4288 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
4289
4290 crate::yolo::decode_yolo_end_to_end_segdet_float_proto(
4291 det_tensor,
4292 protos_tensor,
4293 self.score_threshold,
4294 output_boxes,
4295 )
4296 }
4297
4298 fn decode_yolo_end_to_end_segdet_quantized_proto(
4299 &self,
4300 outputs: &[ArrayViewDQuantized],
4301 boxes_config: &configs::Detection,
4302 protos_config: &configs::Protos,
4303 output_boxes: &mut Vec<DetectBox>,
4304 ) -> Result<ProtoData, DecoderError> {
4305 let (det_tensor, det_ind) =
4306 Self::find_outputs_with_shape_quantized(&boxes_config.shape, outputs, &[])?;
4307 let (protos_tensor, _) =
4308 Self::find_outputs_with_shape_quantized(&protos_config.shape, outputs, &[det_ind])?;
4309
4310 let quant_det = boxes_config
4311 .quantization
4312 .map(Quantization::from)
4313 .unwrap_or_default();
4314 let quant_protos = protos_config
4315 .quantization
4316 .map(Quantization::from)
4317 .unwrap_or_default();
4318
4319 macro_rules! dequant_3d {
4322 ($tensor:expr, $config:expr, $quant:expr) => {{
4323 with_quantized!($tensor, t, {
4324 let t = Self::swap_axes_if_needed(t, $config.into());
4325 let t = t.slice(s![0, .., ..]);
4326 t.map(|v| {
4327 let val: f32 = v.as_();
4328 (val - $quant.zero_point as f32) * $quant.scale
4329 })
4330 })
4331 }};
4332 }
4333 macro_rules! dequant_4d {
4334 ($tensor:expr, $config:expr, $quant:expr) => {{
4335 with_quantized!($tensor, t, {
4336 let t = Self::swap_axes_if_needed(t, $config.into());
4337 let t = t.slice(s![0, .., .., ..]);
4338 t.map(|v| {
4339 let val: f32 = v.as_();
4340 (val - $quant.zero_point as f32) * $quant.scale
4341 })
4342 })
4343 }};
4344 }
4345
4346 let dequant_d = dequant_3d!(det_tensor, boxes_config, quant_det);
4347 let dequant_p = dequant_4d!(protos_tensor, protos_config, quant_protos);
4348
4349 let proto = crate::yolo::decode_yolo_end_to_end_segdet_float_proto(
4350 dequant_d.view(),
4351 dequant_p.view(),
4352 self.score_threshold,
4353 output_boxes,
4354 )?;
4355 Ok(proto)
4356 }
4357
4358 #[allow(clippy::too_many_arguments)]
4359 fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
4360 &self,
4361 outputs: &[ArrayViewD<T>],
4362 boxes_config: &configs::Boxes,
4363 scores_config: &configs::Scores,
4364 classes_config: &configs::Classes,
4365 mask_coeff_config: &configs::MaskCoefficients,
4366 protos_config: &configs::Protos,
4367 output_boxes: &mut Vec<DetectBox>,
4368 ) -> Result<ProtoData, DecoderError>
4369 where
4370 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
4371 f32: AsPrimitive<T>,
4372 {
4373 let mut skip = vec![];
4374 let (boxes_tensor, ind) =
4375 Self::find_outputs_with_shape(&boxes_config.shape, outputs, &skip)?;
4376 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes_config.into());
4377 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
4378 skip.push(ind);
4379
4380 let (scores_tensor, ind) =
4381 Self::find_outputs_with_shape(&scores_config.shape, outputs, &skip)?;
4382 let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores_config.into());
4383 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
4384 skip.push(ind);
4385
4386 let (classes_tensor, ind) =
4387 Self::find_outputs_with_shape(&classes_config.shape, outputs, &skip)?;
4388 let classes_tensor = Self::swap_axes_if_needed(classes_tensor, classes_config.into());
4389 let classes_tensor = classes_tensor.slice(s![0, .., ..]);
4390 skip.push(ind);
4391
4392 let (mask_tensor, ind) =
4393 Self::find_outputs_with_shape(&mask_coeff_config.shape, outputs, &skip)?;
4394 let mask_tensor = Self::swap_axes_if_needed(mask_tensor, mask_coeff_config.into());
4395 let mask_tensor = mask_tensor.slice(s![0, .., ..]);
4396 skip.push(ind);
4397
4398 let (protos_tensor, _) =
4399 Self::find_outputs_with_shape(&protos_config.shape, outputs, &skip)?;
4400 let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos_config.into());
4401 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
4402
4403 crate::yolo::decode_yolo_split_end_to_end_segdet_float_proto(
4404 boxes_tensor,
4405 scores_tensor,
4406 classes_tensor,
4407 mask_tensor,
4408 protos_tensor,
4409 self.score_threshold,
4410 output_boxes,
4411 )
4412 }
4413
4414 #[allow(clippy::too_many_arguments)]
4415 fn decode_yolo_split_end_to_end_segdet_quantized_proto(
4416 &self,
4417 outputs: &[ArrayViewDQuantized],
4418 boxes_config: &configs::Boxes,
4419 scores_config: &configs::Scores,
4420 classes_config: &configs::Classes,
4421 mask_coeff_config: &configs::MaskCoefficients,
4422 protos_config: &configs::Protos,
4423 output_boxes: &mut Vec<DetectBox>,
4424 ) -> Result<ProtoData, DecoderError> {
4425 let mut skip = vec![];
4426 let (boxes_tensor, ind) =
4427 Self::find_outputs_with_shape_quantized(&boxes_config.shape, outputs, &skip)?;
4428 skip.push(ind);
4429 let (scores_tensor, ind) =
4430 Self::find_outputs_with_shape_quantized(&scores_config.shape, outputs, &skip)?;
4431 skip.push(ind);
4432 let (classes_tensor, ind) =
4433 Self::find_outputs_with_shape_quantized(&classes_config.shape, outputs, &skip)?;
4434 skip.push(ind);
4435 let (mask_tensor, ind) =
4436 Self::find_outputs_with_shape_quantized(&mask_coeff_config.shape, outputs, &skip)?;
4437 skip.push(ind);
4438 let (protos_tensor, _) =
4439 Self::find_outputs_with_shape_quantized(&protos_config.shape, outputs, &skip)?;
4440
4441 let quant_boxes = boxes_config
4442 .quantization
4443 .map(Quantization::from)
4444 .unwrap_or_default();
4445 let quant_scores = scores_config
4446 .quantization
4447 .map(Quantization::from)
4448 .unwrap_or_default();
4449 let quant_classes = classes_config
4450 .quantization
4451 .map(Quantization::from)
4452 .unwrap_or_default();
4453 let quant_masks = mask_coeff_config
4454 .quantization
4455 .map(Quantization::from)
4456 .unwrap_or_default();
4457 let quant_protos = protos_config
4458 .quantization
4459 .map(Quantization::from)
4460 .unwrap_or_default();
4461
4462 macro_rules! dequant_3d {
4463 ($tensor:expr, $config:expr, $quant:expr) => {{
4464 with_quantized!($tensor, t, {
4465 let t = Self::swap_axes_if_needed(t, $config.into());
4466 let t = t.slice(s![0, .., ..]);
4467 t.map(|v| {
4468 let val: f32 = v.as_();
4469 (val - $quant.zero_point as f32) * $quant.scale
4470 })
4471 })
4472 }};
4473 }
4474 macro_rules! dequant_4d {
4475 ($tensor:expr, $config:expr, $quant:expr) => {{
4476 with_quantized!($tensor, t, {
4477 let t = Self::swap_axes_if_needed(t, $config.into());
4478 let t = t.slice(s![0, .., .., ..]);
4479 t.map(|v| {
4480 let val: f32 = v.as_();
4481 (val - $quant.zero_point as f32) * $quant.scale
4482 })
4483 })
4484 }};
4485 }
4486
4487 let dequant_b = dequant_3d!(boxes_tensor, boxes_config, quant_boxes);
4488 let dequant_s = dequant_3d!(scores_tensor, scores_config, quant_scores);
4489 let dequant_c = dequant_3d!(classes_tensor, classes_config, quant_classes);
4490 let dequant_m = dequant_3d!(mask_tensor, mask_coeff_config, quant_masks);
4491 let dequant_p = dequant_4d!(protos_tensor, protos_config, quant_protos);
4492
4493 crate::yolo::decode_yolo_split_end_to_end_segdet_float_proto(
4494 dequant_b.view(),
4495 dequant_s.view(),
4496 dequant_c.view(),
4497 dequant_m.view(),
4498 dequant_p.view(),
4499 self.score_threshold,
4500 output_boxes,
4501 )
4502 }
4503
4504 fn match_outputs_to_detect<'a, 'b, T>(
4505 configs: &[configs::Detection],
4506 outputs: &'a [ArrayViewD<'b, T>],
4507 ) -> Result<Vec<&'a ArrayViewD<'b, T>>, DecoderError> {
4508 let mut new_output_order = Vec::new();
4509 for c in configs {
4510 let mut found = false;
4511 for o in outputs {
4512 if o.shape() == c.shape {
4513 new_output_order.push(o);
4514 found = true;
4515 break;
4516 }
4517 }
4518 if !found {
4519 return Err(DecoderError::InvalidShape(format!(
4520 "Did not find output with shape {:?}",
4521 c.shape
4522 )));
4523 }
4524 }
4525 Ok(new_output_order)
4526 }
4527
4528 fn find_outputs_with_shape<'a, 'b, T>(
4529 shape: &[usize],
4530 outputs: &'a [ArrayViewD<'b, T>],
4531 skip: &[usize],
4532 ) -> Result<(&'a ArrayViewD<'b, T>, usize), DecoderError> {
4533 for (ind, o) in outputs.iter().enumerate() {
4534 if skip.contains(&ind) {
4535 continue;
4536 }
4537 if o.shape() == shape {
4538 return Ok((o, ind));
4539 }
4540 }
4541 Err(DecoderError::InvalidShape(format!(
4542 "Did not find output with shape {:?}",
4543 shape
4544 )))
4545 }
4546
4547 fn find_outputs_with_shape_quantized<'a, 'b>(
4548 shape: &[usize],
4549 outputs: &'a [ArrayViewDQuantized<'b>],
4550 skip: &[usize],
4551 ) -> Result<(&'a ArrayViewDQuantized<'b>, usize), DecoderError> {
4552 for (ind, o) in outputs.iter().enumerate() {
4553 if skip.contains(&ind) {
4554 continue;
4555 }
4556 if o.shape() == shape {
4557 return Ok((o, ind));
4558 }
4559 }
4560 Err(DecoderError::InvalidShape(format!(
4561 "Did not find output with shape {:?}",
4562 shape
4563 )))
4564 }
4565
4566 fn modelpack_det_order(x: DimName) -> usize {
4569 match x {
4570 DimName::Batch => 0,
4571 DimName::NumBoxes => 1,
4572 DimName::Padding => 2,
4573 DimName::BoxCoords => 3,
4574 _ => 1000, }
4576 }
4577
4578 fn yolo_det_order(x: DimName) -> usize {
4581 match x {
4582 DimName::Batch => 0,
4583 DimName::NumFeatures => 1,
4584 DimName::NumBoxes => 2,
4585 _ => 1000, }
4587 }
4588
4589 fn modelpack_boxes_order(x: DimName) -> usize {
4592 match x {
4593 DimName::Batch => 0,
4594 DimName::NumBoxes => 1,
4595 DimName::Padding => 2,
4596 DimName::BoxCoords => 3,
4597 _ => 1000, }
4599 }
4600
4601 fn yolo_boxes_order(x: DimName) -> usize {
4604 match x {
4605 DimName::Batch => 0,
4606 DimName::BoxCoords => 1,
4607 DimName::NumBoxes => 2,
4608 _ => 1000, }
4610 }
4611
4612 fn modelpack_scores_order(x: DimName) -> usize {
4615 match x {
4616 DimName::Batch => 0,
4617 DimName::NumBoxes => 1,
4618 DimName::NumClasses => 2,
4619 _ => 1000, }
4621 }
4622
4623 fn yolo_scores_order(x: DimName) -> usize {
4624 match x {
4625 DimName::Batch => 0,
4626 DimName::NumClasses => 1,
4627 DimName::NumBoxes => 2,
4628 _ => 1000, }
4630 }
4631
4632 fn modelpack_segmentation_order(x: DimName) -> usize {
4635 match x {
4636 DimName::Batch => 0,
4637 DimName::Height => 1,
4638 DimName::Width => 2,
4639 DimName::NumClasses => 3,
4640 _ => 1000, }
4642 }
4643
4644 fn modelpack_mask_order(x: DimName) -> usize {
4647 match x {
4648 DimName::Batch => 0,
4649 DimName::Height => 1,
4650 DimName::Width => 2,
4651 _ => 1000, }
4653 }
4654
4655 fn yolo_protos_order(x: DimName) -> usize {
4658 match x {
4659 DimName::Batch => 0,
4660 DimName::Height => 1,
4661 DimName::Width => 2,
4662 DimName::NumProtos => 3,
4663 _ => 1000, }
4665 }
4666
4667 fn yolo_maskcoefficients_order(x: DimName) -> usize {
4670 match x {
4671 DimName::Batch => 0,
4672 DimName::NumProtos => 1,
4673 DimName::NumBoxes => 2,
4674 _ => 1000, }
4676 }
4677
4678 fn get_order_fn(config: ConfigOutputRef) -> fn(DimName) -> usize {
4679 let decoder_type = config.decoder();
4680 match (config, decoder_type) {
4681 (ConfigOutputRef::Detection(_), DecoderType::ModelPack) => Self::modelpack_det_order,
4682 (ConfigOutputRef::Detection(_), DecoderType::Ultralytics) => Self::yolo_det_order,
4683 (ConfigOutputRef::Boxes(_), DecoderType::ModelPack) => Self::modelpack_boxes_order,
4684 (ConfigOutputRef::Boxes(_), DecoderType::Ultralytics) => Self::yolo_boxes_order,
4685 (ConfigOutputRef::Scores(_), DecoderType::ModelPack) => Self::modelpack_scores_order,
4686 (ConfigOutputRef::Scores(_), DecoderType::Ultralytics) => Self::yolo_scores_order,
4687 (ConfigOutputRef::Segmentation(_), _) => Self::modelpack_segmentation_order,
4688 (ConfigOutputRef::Mask(_), _) => Self::modelpack_mask_order,
4689 (ConfigOutputRef::Protos(_), _) => Self::yolo_protos_order,
4690 (ConfigOutputRef::MaskCoefficients(_), _) => Self::yolo_maskcoefficients_order,
4691 (ConfigOutputRef::Classes(_), _) => Self::yolo_scores_order,
4692 }
4693 }
4694
4695 fn swap_axes_if_needed<'a, T, D: Dimension>(
4696 array: &ArrayView<'a, T, D>,
4697 config: ConfigOutputRef,
4698 ) -> ArrayView<'a, T, D> {
4699 let mut array = array.clone();
4700 if config.dshape().is_empty() {
4701 return array;
4702 }
4703 let order_fn: fn(DimName) -> usize = Self::get_order_fn(config.clone());
4704 let mut current_order: Vec<usize> = config
4705 .dshape()
4706 .iter()
4707 .map(|x| order_fn(x.0))
4708 .collect::<Vec<_>>();
4709
4710 assert_eq!(array.shape().len(), current_order.len());
4711 for i in 0..current_order.len() {
4714 let mut swapped = false;
4715 for j in 0..current_order.len() - 1 - i {
4716 if current_order[j] > current_order[j + 1] {
4717 array.swap_axes(j, j + 1);
4718 current_order.swap(j, j + 1);
4719 swapped = true;
4720 }
4721 }
4722 if !swapped {
4723 break;
4724 }
4725 }
4726 array
4727 }
4728
4729 fn match_outputs_to_detect_quantized<'a, 'b>(
4730 configs: &[configs::Detection],
4731 outputs: &'a [ArrayViewDQuantized<'b>],
4732 ) -> Result<Vec<&'a ArrayViewDQuantized<'b>>, DecoderError> {
4733 let mut new_output_order = Vec::new();
4734 for c in configs {
4735 let mut found = false;
4736 for o in outputs {
4737 if o.shape() == c.shape {
4738 new_output_order.push(o);
4739 found = true;
4740 break;
4741 }
4742 }
4743 if !found {
4744 return Err(DecoderError::InvalidShape(format!(
4745 "Did not find output with shape {:?}",
4746 c.shape
4747 )));
4748 }
4749 }
4750 Ok(new_output_order)
4751 }
4752}
4753
4754#[cfg(test)]
4755#[cfg_attr(coverage_nightly, coverage(off))]
4756mod decoder_builder_tests {
4757 use super::*;
4758
#[test]
fn test_decoder_builder_no_config() {
    use crate::DecoderBuilder;
    // A builder that never received any config must refuse to build.
    let outcome = DecoderBuilder::default().build();
    assert!(matches!(outcome, Err(DecoderError::NoConfig)));
}
4765
#[test]
fn test_decoder_builder_empty_config() {
    use crate::DecoderBuilder;
    // A config object with an empty output list is rejected explicitly.
    let empty = ConfigOutputs {
        outputs: Vec::new(),
        ..Default::default()
    };
    let outcome = DecoderBuilder::default().with_config(empty).build();
    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg == "No outputs found in config"
    ));
}
4779
#[test]
fn test_malformed_config_yaml() {
    // The output entry has no `type` tag, which `ConfigOutput` requires, so
    // deserialisation must surface a YAML error.
    let yaml_text = String::from(
        "
model_type: yolov8_det
outputs:
  - shape: [1, 84, 8400]
",
    );
    let outcome = DecoderBuilder::new().with_config_yaml_str(yaml_text).build();
    assert!(matches!(outcome, Err(DecoderError::Yaml(_))));
}
4793
#[test]
fn test_malformed_config_json() {
    // The output entry has no `type` tag, which `ConfigOutput` requires, so
    // deserialisation must surface a JSON error. (The payload variable was
    // previously misnamed `malformed_yaml`.)
    let malformed_json = r#"
{
    "model_type": "yolov8_det",
    "outputs": [
        {
            "shape": [1, 84, 8400]
        }
    ]
}"#
    .to_owned();
    let result = DecoderBuilder::new()
        .with_config_json_str(malformed_json)
        .build();
    assert!(matches!(result, Err(DecoderError::Json(_))));
}
4811
#[test]
fn test_modelpack_and_yolo_config_error() {
    // Boxes declare the Ultralytics decoder while scores declare ModelPack;
    // mixing decoder families in one config must be rejected.
    let yolo_boxes = configs::Boxes {
        decoder: configs::DecoderType::Ultralytics,
        shape: vec![1, 4, 8400],
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::BoxCoords, 4),
            (DimName::NumBoxes, 8400),
        ],
        normalized: Some(true),
    };
    let modelpack_scores = configs::Scores {
        decoder: configs::DecoderType::ModelPack,
        shape: vec![1, 80, 8400],
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumClasses, 80),
            (DimName::NumBoxes, 8400),
        ],
    };

    let outcome = DecoderBuilder::new()
        .with_config_modelpack_det(yolo_boxes, modelpack_scores)
        .build();

    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg))
            if msg == "Both ModelPack and Yolo outputs found in config"
    ));
}
4844
#[test]
fn test_yolo_invalid_seg_shape() {
    // A rank-4 detection tensor is not a valid YOLO detection shape.
    let detection = configs::Detection {
        decoder: configs::DecoderType::Ultralytics,
        shape: vec![1, 85, 8400, 1],
        quantization: None,
        anchors: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumFeatures, 85),
            (DimName::NumBoxes, 8400),
            (DimName::Batch, 1),
        ],
        normalized: Some(true),
    };
    let protos = configs::Protos {
        decoder: configs::DecoderType::Ultralytics,
        shape: vec![1, 32, 160, 160],
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumProtos, 32),
            (DimName::Height, 160),
            (DimName::Width, 160),
        ],
    };

    let outcome = DecoderBuilder::new()
        .with_config_yolo_segdet(detection, protos, Some(DecoderVersion::Yolo11))
        .build();

    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg.starts_with("Invalid Yolo Detection shape")
    ));
}
4881
#[test]
fn test_yolo_invalid_mask() {
    // Mask outputs are a ModelPack concept; declaring one with the
    // Ultralytics decoder must fail.
    let mask = configs::Mask {
        shape: vec![1, 160, 160, 1],
        decoder: configs::DecoderType::Ultralytics,
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 160),
            (DimName::Width, 160),
            (DimName::NumFeatures, 1),
        ],
    };
    let config = ConfigOutputs {
        outputs: vec![ConfigOutput::Mask(mask)],
        ..Default::default()
    };

    let outcome = DecoderBuilder::new().with_config(config).build();

    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg))
            if msg.starts_with("Invalid Mask output with Yolo decoder")
    ));
}
4905
#[test]
fn test_yolo_invalid_outputs() {
    // Segmentation outputs are not accepted with the Ultralytics decoder.
    let segmentation = configs::Segmentation {
        shape: vec![1, 84, 8400],
        decoder: configs::DecoderType::Ultralytics,
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumFeatures, 84),
            (DimName::NumBoxes, 8400),
        ],
    };
    let config = ConfigOutputs {
        outputs: vec![ConfigOutput::Segmentation(segmentation)],
        ..Default::default()
    };

    let outcome = DecoderBuilder::new().with_config(config).build();

    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg))
            if msg == "Invalid Segmentation output with Yolo decoder"
    ));
}
4928
#[test]
fn test_yolo_invalid_det() {
    // Case 1: rank-4 detection tensor is rejected as an invalid shape.
    let result = DecoderBuilder::new()
        .with_config_yolo_det(
            configs::Detection {
                anchors: None,
                decoder: DecoderType::Ultralytics,
                quantization: None,
                shape: vec![1, 84, 8400, 1],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumFeatures, 84),
                    (DimName::NumBoxes, 8400),
                    (DimName::Batch, 1),
                ],
                normalized: Some(true),
            },
            Some(DecoderVersion::Yolo11),
        )
        .build();

    assert!(matches!(
        result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Detection shape")));

    // Case 2: only 3 features (must exceed the 4 box coordinates), with an
    // explicit dshape.
    let result = DecoderBuilder::new()
        .with_config_yolo_det(
            configs::Detection {
                anchors: None,
                decoder: DecoderType::Ultralytics,
                quantization: None,
                shape: vec![1, 8400, 3],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumBoxes, 8400),
                    (DimName::NumFeatures, 3),
                ],
                normalized: Some(true),
            },
            Some(DecoderVersion::Yolo11),
        )
        .build();

    // Print the whole `Result` on failure. The previous `result.unwrap_err()`
    // would itself panic when `result` was `Ok`, hiding the real failure.
    assert!(
        matches!(
            &result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid shape: Yolo num_features 3 must be greater than 4")),
        "{:?}",
        result
    );

    // Case 3: same feature count, but without a dshape.
    let result = DecoderBuilder::new()
        .with_config_yolo_det(
            configs::Detection {
                anchors: None,
                decoder: DecoderType::Ultralytics,
                quantization: None,
                shape: vec![1, 3, 8400],
                dshape: Vec::new(),
                normalized: Some(true),
            },
            Some(DecoderVersion::Yolo11),
        )
        .build();

    assert!(matches!(
        result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid shape: Yolo num_features 3 must be greater than 4")));
}
4995
#[test]
fn test_yolo_invalid_segdet() {
    // Case 1: rank-4 detection tensor.
    let result = DecoderBuilder::new()
        .with_config_yolo_segdet(
            configs::Detection {
                decoder: configs::DecoderType::Ultralytics,
                shape: vec![1, 85, 8400, 1],
                quantization: None,
                anchors: None,
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumFeatures, 85),
                    (DimName::NumBoxes, 8400),
                    (DimName::Batch, 1),
                ],
                normalized: Some(true),
            },
            configs::Protos {
                decoder: configs::DecoderType::Ultralytics,
                shape: vec![1, 32, 160, 160],
                quantization: None,
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumProtos, 32),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                ],
            },
            Some(DecoderVersion::Yolo11),
        )
        .build();

    assert!(matches!(
        result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Detection shape")));

    // Case 2: valid detection, but a rank-5 protos tensor.
    let result = DecoderBuilder::new()
        .with_config_yolo_segdet(
            configs::Detection {
                decoder: configs::DecoderType::Ultralytics,
                shape: vec![1, 85, 8400],
                quantization: None,
                anchors: None,
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumFeatures, 85),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            configs::Protos {
                decoder: configs::DecoderType::Ultralytics,
                shape: vec![1, 32, 160, 160, 1],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumProtos, 32),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                    (DimName::Batch, 1),
                ],
                quantization: None,
            },
            Some(DecoderVersion::Yolo11),
        )
        .build();

    assert!(matches!(
        result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Protos shape")));

    // Case 3: 36 features leave no room for class scores alongside the
    // 4 box coordinates and 32 mask coefficients.
    let result = DecoderBuilder::new()
        .with_config_yolo_segdet(
            configs::Detection {
                decoder: configs::DecoderType::Ultralytics,
                shape: vec![1, 8400, 36],
                quantization: None,
                anchors: None,
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumBoxes, 8400),
                    (DimName::NumFeatures, 36),
                ],
                normalized: Some(true),
            },
            configs::Protos {
                decoder: configs::DecoderType::Ultralytics,
                shape: vec![1, 32, 160, 160],
                quantization: None,
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumProtos, 32),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                ],
            },
            Some(DecoderVersion::Yolo11),
        )
        .build();
    // Leftover debug `println!` removed; the assertion message carries the
    // result instead.
    assert!(
        matches!(
            &result, Err(DecoderError::InvalidConfig(s)) if s == "Invalid shape: Yolo num_features 36 must be greater than 36"),
        "{:?}",
        result
    );
}
5096
#[test]
fn test_yolo_invalid_split_det() {
    // Every case is a split (boxes + scores) YOLO config that the builder
    // must reject with a specific `InvalidConfig` message.
    let build = |boxes: configs::Boxes, scores: configs::Scores| {
        DecoderBuilder::new()
            .with_config_yolo_split_det(boxes, scores)
            .build()
    };
    // Well-formed `[1, 80, 8400]` scores, used where only boxes are broken.
    let standard_scores = || configs::Scores {
        decoder: configs::DecoderType::Ultralytics,
        shape: vec![1, 80, 8400],
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumClasses, 80),
            (DimName::NumBoxes, 8400),
        ],
    };

    // Rank-4 boxes tensor.
    let outcome = build(
        configs::Boxes {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 4, 8400, 1],
            quantization: None,
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::BoxCoords, 4),
                (DimName::NumBoxes, 8400),
                (DimName::Batch, 1),
            ],
            normalized: Some(true),
        },
        standard_scores(),
    );
    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg.starts_with("Invalid Yolo Split Boxes shape")
    ));

    // Rank-4 scores tensor.
    let outcome = build(
        configs::Boxes {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 4, 8400],
            quantization: None,
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::BoxCoords, 4),
                (DimName::NumBoxes, 8400),
            ],
            normalized: Some(true),
        },
        configs::Scores {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 80, 8400, 1],
            quantization: None,
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumClasses, 80),
                (DimName::NumBoxes, 8400),
                (DimName::Batch, 1),
            ],
        },
    );
    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg.starts_with("Invalid Yolo Split Scores shape")
    ));

    // Box count (8400) disagrees with score count (8401).
    let outcome = build(
        configs::Boxes {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 8400, 4],
            quantization: None,
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumBoxes, 8400),
                (DimName::BoxCoords, 4),
            ],
            normalized: Some(true),
        },
        configs::Scores {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 8400 + 1, 80],
            quantization: None,
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumBoxes, 8401),
                (DimName::NumClasses, 80),
            ],
        },
    );
    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg))
            if msg.starts_with("Yolo Split Detection Boxes num 8400 incompatible with Scores num 8401")
    ));

    // Five box coordinates instead of four.
    let outcome = build(
        configs::Boxes {
            decoder: configs::DecoderType::Ultralytics,
            shape: vec![1, 5, 8400],
            quantization: None,
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::BoxCoords, 5),
                (DimName::NumBoxes, 8400),
            ],
            normalized: Some(true),
        },
        standard_scores(),
    );
    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg.starts_with("BoxCoords dimension size must be 4")
    ));
}
5216
5217 #[test]
5218 fn test_yolo_invalid_split_segdet() {
5219 let result = DecoderBuilder::new()
5220 .with_config_yolo_split_segdet(
5221 configs::Boxes {
5222 decoder: configs::DecoderType::Ultralytics,
5223 shape: vec![1, 8400, 4, 1],
5224 quantization: None,
5225 dshape: vec![
5226 (DimName::Batch, 1),
5227 (DimName::NumBoxes, 8400),
5228 (DimName::BoxCoords, 4),
5229 (DimName::Batch, 1),
5230 ],
5231 normalized: Some(true),
5232 },
5233 configs::Scores {
5234 decoder: configs::DecoderType::Ultralytics,
5235 shape: vec![1, 8400, 80],
5236
5237 quantization: None,
5238 dshape: vec![
5239 (DimName::Batch, 1),
5240 (DimName::NumBoxes, 8400),
5241 (DimName::NumClasses, 80),
5242 ],
5243 },
5244 configs::MaskCoefficients {
5245 decoder: configs::DecoderType::Ultralytics,
5246 shape: vec![1, 8400, 32],
5247 quantization: None,
5248 dshape: vec![
5249 (DimName::Batch, 1),
5250 (DimName::NumBoxes, 8400),
5251 (DimName::NumProtos, 32),
5252 ],
5253 },
5254 configs::Protos {
5255 decoder: configs::DecoderType::Ultralytics,
5256 shape: vec![1, 32, 160, 160],
5257 quantization: None,
5258 dshape: vec![
5259 (DimName::Batch, 1),
5260 (DimName::NumProtos, 32),
5261 (DimName::Height, 160),
5262 (DimName::Width, 160),
5263 ],
5264 },
5265 )
5266 .build();
5267
5268 assert!(matches!(
5269 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Boxes shape")));
5270
5271 let result = DecoderBuilder::new()
5272 .with_config_yolo_split_segdet(
5273 configs::Boxes {
5274 decoder: configs::DecoderType::Ultralytics,
5275 shape: vec![1, 8400, 4],
5276 quantization: None,
5277 dshape: vec![
5278 (DimName::Batch, 1),
5279 (DimName::NumBoxes, 8400),
5280 (DimName::BoxCoords, 4),
5281 ],
5282 normalized: Some(true),
5283 },
5284 configs::Scores {
5285 decoder: configs::DecoderType::Ultralytics,
5286 shape: vec![1, 8400, 80, 1],
5287 quantization: None,
5288 dshape: vec![
5289 (DimName::Batch, 1),
5290 (DimName::NumBoxes, 8400),
5291 (DimName::NumClasses, 80),
5292 (DimName::Batch, 1),
5293 ],
5294 },
5295 configs::MaskCoefficients {
5296 decoder: configs::DecoderType::Ultralytics,
5297 shape: vec![1, 8400, 32],
5298 quantization: None,
5299 dshape: vec![
5300 (DimName::Batch, 1),
5301 (DimName::NumBoxes, 8400),
5302 (DimName::NumProtos, 32),
5303 ],
5304 },
5305 configs::Protos {
5306 decoder: configs::DecoderType::Ultralytics,
5307 shape: vec![1, 32, 160, 160],
5308 quantization: None,
5309 dshape: vec![
5310 (DimName::Batch, 1),
5311 (DimName::NumProtos, 32),
5312 (DimName::Height, 160),
5313 (DimName::Width, 160),
5314 ],
5315 },
5316 )
5317 .build();
5318
5319 assert!(matches!(
5320 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Scores shape")));
5321
5322 let result = DecoderBuilder::new()
5323 .with_config_yolo_split_segdet(
5324 configs::Boxes {
5325 decoder: configs::DecoderType::Ultralytics,
5326 shape: vec![1, 8400, 4],
5327 quantization: None,
5328 dshape: vec![
5329 (DimName::Batch, 1),
5330 (DimName::NumBoxes, 8400),
5331 (DimName::BoxCoords, 4),
5332 ],
5333 normalized: Some(true),
5334 },
5335 configs::Scores {
5336 decoder: configs::DecoderType::Ultralytics,
5337 shape: vec![1, 8400, 80],
5338 quantization: None,
5339 dshape: vec![
5340 (DimName::Batch, 1),
5341 (DimName::NumBoxes, 8400),
5342 (DimName::NumClasses, 80),
5343 ],
5344 },
5345 configs::MaskCoefficients {
5346 decoder: configs::DecoderType::Ultralytics,
5347 shape: vec![1, 8400, 32, 1],
5348 quantization: None,
5349 dshape: vec![
5350 (DimName::Batch, 1),
5351 (DimName::NumBoxes, 8400),
5352 (DimName::NumProtos, 32),
5353 (DimName::Batch, 1),
5354 ],
5355 },
5356 configs::Protos {
5357 decoder: configs::DecoderType::Ultralytics,
5358 shape: vec![1, 32, 160, 160],
5359 quantization: None,
5360 dshape: vec![
5361 (DimName::Batch, 1),
5362 (DimName::NumProtos, 32),
5363 (DimName::Height, 160),
5364 (DimName::Width, 160),
5365 ],
5366 },
5367 )
5368 .build();
5369
5370 assert!(matches!(
5371 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Mask Coefficients shape")));
5372
5373 let result = DecoderBuilder::new()
5374 .with_config_yolo_split_segdet(
5375 configs::Boxes {
5376 decoder: configs::DecoderType::Ultralytics,
5377 shape: vec![1, 8400, 4],
5378 quantization: None,
5379 dshape: vec![
5380 (DimName::Batch, 1),
5381 (DimName::NumBoxes, 8400),
5382 (DimName::BoxCoords, 4),
5383 ],
5384 normalized: Some(true),
5385 },
5386 configs::Scores {
5387 decoder: configs::DecoderType::Ultralytics,
5388 shape: vec![1, 8400, 80],
5389 quantization: None,
5390 dshape: vec![
5391 (DimName::Batch, 1),
5392 (DimName::NumBoxes, 8400),
5393 (DimName::NumClasses, 80),
5394 ],
5395 },
5396 configs::MaskCoefficients {
5397 decoder: configs::DecoderType::Ultralytics,
5398 shape: vec![1, 8400, 32],
5399 quantization: None,
5400 dshape: vec![
5401 (DimName::Batch, 1),
5402 (DimName::NumBoxes, 8400),
5403 (DimName::NumProtos, 32),
5404 ],
5405 },
5406 configs::Protos {
5407 decoder: configs::DecoderType::Ultralytics,
5408 shape: vec![1, 32, 160, 160, 1],
5409 quantization: None,
5410 dshape: vec![
5411 (DimName::Batch, 1),
5412 (DimName::NumProtos, 32),
5413 (DimName::Height, 160),
5414 (DimName::Width, 160),
5415 (DimName::Batch, 1),
5416 ],
5417 },
5418 )
5419 .build();
5420
5421 assert!(matches!(
5422 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Protos shape")));
5423
5424 let result = DecoderBuilder::new()
5425 .with_config_yolo_split_segdet(
5426 configs::Boxes {
5427 decoder: configs::DecoderType::Ultralytics,
5428 shape: vec![1, 8400, 4],
5429 quantization: None,
5430 dshape: vec![
5431 (DimName::Batch, 1),
5432 (DimName::NumBoxes, 8400),
5433 (DimName::BoxCoords, 4),
5434 ],
5435 normalized: Some(true),
5436 },
5437 configs::Scores {
5438 decoder: configs::DecoderType::Ultralytics,
5439 shape: vec![1, 8401, 80],
5440 quantization: None,
5441 dshape: vec![
5442 (DimName::Batch, 1),
5443 (DimName::NumBoxes, 8401),
5444 (DimName::NumClasses, 80),
5445 ],
5446 },
5447 configs::MaskCoefficients {
5448 decoder: configs::DecoderType::Ultralytics,
5449 shape: vec![1, 8400, 32],
5450 quantization: None,
5451 dshape: vec![
5452 (DimName::Batch, 1),
5453 (DimName::NumBoxes, 8400),
5454 (DimName::NumProtos, 32),
5455 ],
5456 },
5457 configs::Protos {
5458 decoder: configs::DecoderType::Ultralytics,
5459 shape: vec![1, 32, 160, 160],
5460 quantization: None,
5461 dshape: vec![
5462 (DimName::Batch, 1),
5463 (DimName::NumProtos, 32),
5464 (DimName::Height, 160),
5465 (DimName::Width, 160),
5466 ],
5467 },
5468 )
5469 .build();
5470
5471 assert!(matches!(
5472 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Yolo Split Detection Boxes num 8400 incompatible with Scores num 8401")));
5473
5474 let result = DecoderBuilder::new()
5475 .with_config_yolo_split_segdet(
5476 configs::Boxes {
5477 decoder: configs::DecoderType::Ultralytics,
5478 shape: vec![1, 8400, 4],
5479 quantization: None,
5480 dshape: vec![
5481 (DimName::Batch, 1),
5482 (DimName::NumBoxes, 8400),
5483 (DimName::BoxCoords, 4),
5484 ],
5485 normalized: Some(true),
5486 },
5487 configs::Scores {
5488 decoder: configs::DecoderType::Ultralytics,
5489 shape: vec![1, 8400, 80],
5490 quantization: None,
5491 dshape: vec![
5492 (DimName::Batch, 1),
5493 (DimName::NumBoxes, 8400),
5494 (DimName::NumClasses, 80),
5495 ],
5496 },
5497 configs::MaskCoefficients {
5498 decoder: configs::DecoderType::Ultralytics,
5499 shape: vec![1, 8401, 32],
5500
5501 quantization: None,
5502 dshape: vec![
5503 (DimName::Batch, 1),
5504 (DimName::NumBoxes, 8401),
5505 (DimName::NumProtos, 32),
5506 ],
5507 },
5508 configs::Protos {
5509 decoder: configs::DecoderType::Ultralytics,
5510 shape: vec![1, 32, 160, 160],
5511 quantization: None,
5512 dshape: vec![
5513 (DimName::Batch, 1),
5514 (DimName::NumProtos, 32),
5515 (DimName::Height, 160),
5516 (DimName::Width, 160),
5517 ],
5518 },
5519 )
5520 .build();
5521
5522 assert!(matches!(
5523 result, Err(DecoderError::InvalidConfig(ref s)) if s.starts_with("Yolo Split Detection Boxes num 8400 incompatible with Mask Coefficients num 8401")));
5524 let result = DecoderBuilder::new()
5525 .with_config_yolo_split_segdet(
5526 configs::Boxes {
5527 decoder: configs::DecoderType::Ultralytics,
5528 shape: vec![1, 8400, 4],
5529 quantization: None,
5530 dshape: vec![
5531 (DimName::Batch, 1),
5532 (DimName::NumBoxes, 8400),
5533 (DimName::BoxCoords, 4),
5534 ],
5535 normalized: Some(true),
5536 },
5537 configs::Scores {
5538 decoder: configs::DecoderType::Ultralytics,
5539 shape: vec![1, 8400, 80],
5540 quantization: None,
5541 dshape: vec![
5542 (DimName::Batch, 1),
5543 (DimName::NumBoxes, 8400),
5544 (DimName::NumClasses, 80),
5545 ],
5546 },
5547 configs::MaskCoefficients {
5548 decoder: configs::DecoderType::Ultralytics,
5549 shape: vec![1, 8400, 32],
5550 quantization: None,
5551 dshape: vec![
5552 (DimName::Batch, 1),
5553 (DimName::NumBoxes, 8400),
5554 (DimName::NumProtos, 32),
5555 ],
5556 },
5557 configs::Protos {
5558 decoder: configs::DecoderType::Ultralytics,
5559 shape: vec![1, 31, 160, 160],
5560 quantization: None,
5561 dshape: vec![
5562 (DimName::Batch, 1),
5563 (DimName::NumProtos, 31),
5564 (DimName::Height, 160),
5565 (DimName::Width, 160),
5566 ],
5567 },
5568 )
5569 .build();
5570 println!("{:?}", result);
5571 assert!(matches!(
5572 result, Err(DecoderError::InvalidConfig(ref s)) if s.starts_with( "Yolo Protos channels 31 incompatible with Mask Coefficients channels 32")));
5573 }
5574
5575 #[test]
5576 fn test_modelpack_invalid_config() {
5577 let result = DecoderBuilder::new()
5578 .with_config(ConfigOutputs {
5579 outputs: vec![
5580 ConfigOutput::Boxes(configs::Boxes {
5581 decoder: configs::DecoderType::ModelPack,
5582 shape: vec![1, 8400, 1, 4],
5583 quantization: None,
5584 dshape: vec![
5585 (DimName::Batch, 1),
5586 (DimName::NumBoxes, 8400),
5587 (DimName::Padding, 1),
5588 (DimName::BoxCoords, 4),
5589 ],
5590 normalized: Some(true),
5591 }),
5592 ConfigOutput::Scores(configs::Scores {
5593 decoder: configs::DecoderType::ModelPack,
5594 shape: vec![1, 8400, 3],
5595 quantization: None,
5596 dshape: vec![
5597 (DimName::Batch, 1),
5598 (DimName::NumBoxes, 8400),
5599 (DimName::NumClasses, 3),
5600 ],
5601 }),
5602 ConfigOutput::Protos(configs::Protos {
5603 decoder: configs::DecoderType::ModelPack,
5604 shape: vec![1, 8400, 3],
5605 quantization: None,
5606 dshape: vec![
5607 (DimName::Batch, 1),
5608 (DimName::NumBoxes, 8400),
5609 (DimName::NumFeatures, 3),
5610 ],
5611 }),
5612 ],
5613 ..Default::default()
5614 })
5615 .build();
5616
5617 assert!(matches!(
5618 result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack should not have protos"));
5619
5620 let result = DecoderBuilder::new()
5621 .with_config(ConfigOutputs {
5622 outputs: vec![
5623 ConfigOutput::Boxes(configs::Boxes {
5624 decoder: configs::DecoderType::ModelPack,
5625 shape: vec![1, 8400, 1, 4],
5626 quantization: None,
5627 dshape: vec![
5628 (DimName::Batch, 1),
5629 (DimName::NumBoxes, 8400),
5630 (DimName::Padding, 1),
5631 (DimName::BoxCoords, 4),
5632 ],
5633 normalized: Some(true),
5634 }),
5635 ConfigOutput::Scores(configs::Scores {
5636 decoder: configs::DecoderType::ModelPack,
5637 shape: vec![1, 8400, 3],
5638 quantization: None,
5639 dshape: vec![
5640 (DimName::Batch, 1),
5641 (DimName::NumBoxes, 8400),
5642 (DimName::NumClasses, 3),
5643 ],
5644 }),
5645 ConfigOutput::MaskCoefficients(configs::MaskCoefficients {
5646 decoder: configs::DecoderType::ModelPack,
5647 shape: vec![1, 8400, 3],
5648 quantization: None,
5649 dshape: vec![
5650 (DimName::Batch, 1),
5651 (DimName::NumBoxes, 8400),
5652 (DimName::NumProtos, 3),
5653 ],
5654 }),
5655 ],
5656 ..Default::default()
5657 })
5658 .build();
5659
5660 assert!(matches!(
5661 result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack should not have mask coefficients"));
5662
5663 let result = DecoderBuilder::new()
5664 .with_config(ConfigOutputs {
5665 outputs: vec![ConfigOutput::Boxes(configs::Boxes {
5666 decoder: configs::DecoderType::ModelPack,
5667 shape: vec![1, 8400, 1, 4],
5668 quantization: None,
5669 dshape: vec![
5670 (DimName::Batch, 1),
5671 (DimName::NumBoxes, 8400),
5672 (DimName::Padding, 1),
5673 (DimName::BoxCoords, 4),
5674 ],
5675 normalized: Some(true),
5676 })],
5677 ..Default::default()
5678 })
5679 .build();
5680
5681 assert!(matches!(
5682 result, Err(DecoderError::InvalidConfig(s)) if s == "Invalid ModelPack model outputs"));
5683 }
5684
5685 #[test]
5686 fn test_modelpack_invalid_det() {
5687 let result = DecoderBuilder::new()
5688 .with_config_modelpack_det(
5689 configs::Boxes {
5690 decoder: DecoderType::ModelPack,
5691 quantization: None,
5692 shape: vec![1, 4, 8400],
5693 dshape: vec![
5694 (DimName::Batch, 1),
5695 (DimName::BoxCoords, 4),
5696 (DimName::NumBoxes, 8400),
5697 ],
5698 normalized: Some(true),
5699 },
5700 configs::Scores {
5701 decoder: DecoderType::ModelPack,
5702 quantization: None,
5703 shape: vec![1, 80, 8400],
5704 dshape: vec![
5705 (DimName::Batch, 1),
5706 (DimName::NumClasses, 80),
5707 (DimName::NumBoxes, 8400),
5708 ],
5709 },
5710 )
5711 .build();
5712
5713 assert!(matches!(
5714 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Boxes shape")));
5715
5716 let result = DecoderBuilder::new()
5717 .with_config_modelpack_det(
5718 configs::Boxes {
5719 decoder: DecoderType::ModelPack,
5720 quantization: None,
5721 shape: vec![1, 4, 1, 8400],
5722 dshape: vec![
5723 (DimName::Batch, 1),
5724 (DimName::BoxCoords, 4),
5725 (DimName::Padding, 1),
5726 (DimName::NumBoxes, 8400),
5727 ],
5728 normalized: Some(true),
5729 },
5730 configs::Scores {
5731 decoder: DecoderType::ModelPack,
5732 quantization: None,
5733 shape: vec![1, 80, 8400, 1],
5734 dshape: vec![
5735 (DimName::Batch, 1),
5736 (DimName::NumClasses, 80),
5737 (DimName::NumBoxes, 8400),
5738 (DimName::Padding, 1),
5739 ],
5740 },
5741 )
5742 .build();
5743
5744 assert!(matches!(
5745 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Scores shape")));
5746
5747 let result = DecoderBuilder::new()
5748 .with_config_modelpack_det(
5749 configs::Boxes {
5750 decoder: DecoderType::ModelPack,
5751 quantization: None,
5752 shape: vec![1, 4, 2, 8400],
5753 dshape: vec![
5754 (DimName::Batch, 1),
5755 (DimName::BoxCoords, 4),
5756 (DimName::Padding, 2),
5757 (DimName::NumBoxes, 8400),
5758 ],
5759 normalized: Some(true),
5760 },
5761 configs::Scores {
5762 decoder: DecoderType::ModelPack,
5763 quantization: None,
5764 shape: vec![1, 80, 8400],
5765 dshape: vec![
5766 (DimName::Batch, 1),
5767 (DimName::NumClasses, 80),
5768 (DimName::NumBoxes, 8400),
5769 ],
5770 },
5771 )
5772 .build();
5773 assert!(matches!(
5774 result, Err(DecoderError::InvalidConfig(s)) if s == "Padding dimension size must be 1"));
5775
5776 let result = DecoderBuilder::new()
5777 .with_config_modelpack_det(
5778 configs::Boxes {
5779 decoder: DecoderType::ModelPack,
5780 quantization: None,
5781 shape: vec![1, 5, 1, 8400],
5782 dshape: vec![
5783 (DimName::Batch, 1),
5784 (DimName::BoxCoords, 5),
5785 (DimName::Padding, 1),
5786 (DimName::NumBoxes, 8400),
5787 ],
5788 normalized: Some(true),
5789 },
5790 configs::Scores {
5791 decoder: DecoderType::ModelPack,
5792 quantization: None,
5793 shape: vec![1, 80, 8400],
5794 dshape: vec![
5795 (DimName::Batch, 1),
5796 (DimName::NumClasses, 80),
5797 (DimName::NumBoxes, 8400),
5798 ],
5799 },
5800 )
5801 .build();
5802
5803 assert!(matches!(
5804 result, Err(DecoderError::InvalidConfig(s)) if s == "BoxCoords dimension size must be 4"));
5805
5806 let result = DecoderBuilder::new()
5807 .with_config_modelpack_det(
5808 configs::Boxes {
5809 decoder: DecoderType::ModelPack,
5810 quantization: None,
5811 shape: vec![1, 4, 1, 8400],
5812 dshape: vec![
5813 (DimName::Batch, 1),
5814 (DimName::BoxCoords, 4),
5815 (DimName::Padding, 1),
5816 (DimName::NumBoxes, 8400),
5817 ],
5818 normalized: Some(true),
5819 },
5820 configs::Scores {
5821 decoder: DecoderType::ModelPack,
5822 quantization: None,
5823 shape: vec![1, 80, 8401],
5824 dshape: vec![
5825 (DimName::Batch, 1),
5826 (DimName::NumClasses, 80),
5827 (DimName::NumBoxes, 8401),
5828 ],
5829 },
5830 )
5831 .build();
5832
5833 assert!(matches!(
5834 result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Detection Boxes num 8400 incompatible with Scores num 8401"));
5835 }
5836
5837 #[test]
5838 fn test_modelpack_invalid_det_split() {
5839 let result = DecoderBuilder::default()
5840 .with_config_modelpack_det_split(vec![
5841 configs::Detection {
5842 decoder: DecoderType::ModelPack,
5843 shape: vec![1, 17, 30, 18],
5844 anchors: None,
5845 quantization: None,
5846 dshape: vec![
5847 (DimName::Batch, 1),
5848 (DimName::Height, 17),
5849 (DimName::Width, 30),
5850 (DimName::NumAnchorsXFeatures, 18),
5851 ],
5852 normalized: Some(true),
5853 },
5854 configs::Detection {
5855 decoder: DecoderType::ModelPack,
5856 shape: vec![1, 9, 15, 18],
5857 anchors: None,
5858 quantization: None,
5859 dshape: vec![
5860 (DimName::Batch, 1),
5861 (DimName::Height, 9),
5862 (DimName::Width, 15),
5863 (DimName::NumAnchorsXFeatures, 18),
5864 ],
5865 normalized: Some(true),
5866 },
5867 ])
5868 .build();
5869
5870 assert!(matches!(
5871 result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection missing anchors"));
5872
5873 let result = DecoderBuilder::default()
5874 .with_config_modelpack_det_split(vec![configs::Detection {
5875 decoder: DecoderType::ModelPack,
5876 shape: vec![1, 17, 30, 18],
5877 anchors: None,
5878 quantization: None,
5879 dshape: Vec::new(),
5880 normalized: Some(true),
5881 }])
5882 .build();
5883
5884 assert!(matches!(
5885 result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection missing anchors"));
5886
5887 let result = DecoderBuilder::default()
5888 .with_config_modelpack_det_split(vec![configs::Detection {
5889 decoder: DecoderType::ModelPack,
5890 shape: vec![1, 17, 30, 18],
5891 anchors: Some(vec![]),
5892 quantization: None,
5893 dshape: vec![
5894 (DimName::Batch, 1),
5895 (DimName::Height, 17),
5896 (DimName::Width, 30),
5897 (DimName::NumAnchorsXFeatures, 18),
5898 ],
5899 normalized: Some(true),
5900 }])
5901 .build();
5902
5903 assert!(matches!(
5904 result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection has zero anchors"));
5905
5906 let result = DecoderBuilder::default()
5907 .with_config_modelpack_det_split(vec![configs::Detection {
5908 decoder: DecoderType::ModelPack,
5909 shape: vec![1, 17, 30, 18, 1],
5910 anchors: Some(vec![
5911 [0.3666666, 0.3148148],
5912 [0.3874999, 0.474074],
5913 [0.5333333, 0.644444],
5914 ]),
5915 quantization: None,
5916 dshape: vec![
5917 (DimName::Batch, 1),
5918 (DimName::Height, 17),
5919 (DimName::Width, 30),
5920 (DimName::NumAnchorsXFeatures, 18),
5921 (DimName::Padding, 1),
5922 ],
5923 normalized: Some(true),
5924 }])
5925 .build();
5926
5927 assert!(matches!(
5928 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Split Detection shape")));
5929
5930 let result = DecoderBuilder::default()
5931 .with_config_modelpack_det_split(vec![configs::Detection {
5932 decoder: DecoderType::ModelPack,
5933 shape: vec![1, 15, 17, 30],
5934 anchors: Some(vec![
5935 [0.3666666, 0.3148148],
5936 [0.3874999, 0.474074],
5937 [0.5333333, 0.644444],
5938 ]),
5939 quantization: None,
5940 dshape: vec![
5941 (DimName::Batch, 1),
5942 (DimName::NumAnchorsXFeatures, 15),
5943 (DimName::Height, 17),
5944 (DimName::Width, 30),
5945 ],
5946 normalized: Some(true),
5947 }])
5948 .build();
5949
5950 assert!(matches!(
5951 result, Err(DecoderError::InvalidConfig(s)) if s.contains("not greater than number of anchors * 5 =")));
5952
5953 let result = DecoderBuilder::default()
5954 .with_config_modelpack_det_split(vec![configs::Detection {
5955 decoder: DecoderType::ModelPack,
5956 shape: vec![1, 17, 30, 15],
5957 anchors: Some(vec![
5958 [0.3666666, 0.3148148],
5959 [0.3874999, 0.474074],
5960 [0.5333333, 0.644444],
5961 ]),
5962 quantization: None,
5963 dshape: Vec::new(),
5964 normalized: Some(true),
5965 }])
5966 .build();
5967
5968 assert!(matches!(
5969 result, Err(DecoderError::InvalidConfig(s)) if s.contains("not greater than number of anchors * 5 =")));
5970
5971 let result = DecoderBuilder::default()
5972 .with_config_modelpack_det_split(vec![configs::Detection {
5973 decoder: DecoderType::ModelPack,
5974 shape: vec![1, 16, 17, 30],
5975 anchors: Some(vec![
5976 [0.3666666, 0.3148148],
5977 [0.3874999, 0.474074],
5978 [0.5333333, 0.644444],
5979 ]),
5980 quantization: None,
5981 dshape: vec![
5982 (DimName::Batch, 1),
5983 (DimName::NumAnchorsXFeatures, 16),
5984 (DimName::Height, 17),
5985 (DimName::Width, 30),
5986 ],
5987 normalized: Some(true),
5988 }])
5989 .build();
5990
5991 assert!(matches!(
5992 result, Err(DecoderError::InvalidConfig(s)) if s.contains("not a multiple of number of anchors")));
5993
5994 let result = DecoderBuilder::default()
5995 .with_config_modelpack_det_split(vec![configs::Detection {
5996 decoder: DecoderType::ModelPack,
5997 shape: vec![1, 17, 30, 16],
5998 anchors: Some(vec![
5999 [0.3666666, 0.3148148],
6000 [0.3874999, 0.474074],
6001 [0.5333333, 0.644444],
6002 ]),
6003 quantization: None,
6004 dshape: Vec::new(),
6005 normalized: Some(true),
6006 }])
6007 .build();
6008
6009 assert!(matches!(
6010 result, Err(DecoderError::InvalidConfig(s)) if s.contains("not a multiple of number of anchors")));
6011
6012 let result = DecoderBuilder::default()
6013 .with_config_modelpack_det_split(vec![configs::Detection {
6014 decoder: DecoderType::ModelPack,
6015 shape: vec![1, 18, 17, 30],
6016 anchors: Some(vec![
6017 [0.3666666, 0.3148148],
6018 [0.3874999, 0.474074],
6019 [0.5333333, 0.644444],
6020 ]),
6021 quantization: None,
6022 dshape: vec![
6023 (DimName::Batch, 1),
6024 (DimName::NumProtos, 18),
6025 (DimName::Height, 17),
6026 (DimName::Width, 30),
6027 ],
6028 normalized: Some(true),
6029 }])
6030 .build();
6031 assert!(matches!(
6032 result, Err(DecoderError::InvalidConfig(s)) if s.contains("Split Detection dshape missing required dimension NumAnchorsXFeature")));
6033
6034 let result = DecoderBuilder::default()
6035 .with_config_modelpack_det_split(vec![
6036 configs::Detection {
6037 decoder: DecoderType::ModelPack,
6038 shape: vec![1, 17, 30, 18],
6039 anchors: Some(vec![
6040 [0.3666666, 0.3148148],
6041 [0.3874999, 0.474074],
6042 [0.5333333, 0.644444],
6043 ]),
6044 quantization: None,
6045 dshape: vec![
6046 (DimName::Batch, 1),
6047 (DimName::Height, 17),
6048 (DimName::Width, 30),
6049 (DimName::NumAnchorsXFeatures, 18),
6050 ],
6051 normalized: Some(true),
6052 },
6053 configs::Detection {
6054 decoder: DecoderType::ModelPack,
6055 shape: vec![1, 17, 30, 21],
6056 anchors: Some(vec![
6057 [0.3666666, 0.3148148],
6058 [0.3874999, 0.474074],
6059 [0.5333333, 0.644444],
6060 ]),
6061 quantization: None,
6062 dshape: vec![
6063 (DimName::Batch, 1),
6064 (DimName::Height, 17),
6065 (DimName::Width, 30),
6066 (DimName::NumAnchorsXFeatures, 21),
6067 ],
6068 normalized: Some(true),
6069 },
6070 ])
6071 .build();
6072
6073 assert!(matches!(
6074 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("ModelPack Split Detection inconsistent number of classes:")));
6075
6076 let result = DecoderBuilder::default()
6077 .with_config_modelpack_det_split(vec![
6078 configs::Detection {
6079 decoder: DecoderType::ModelPack,
6080 shape: vec![1, 17, 30, 18],
6081 anchors: Some(vec![
6082 [0.3666666, 0.3148148],
6083 [0.3874999, 0.474074],
6084 [0.5333333, 0.644444],
6085 ]),
6086 quantization: None,
6087 dshape: vec![],
6088 normalized: Some(true),
6089 },
6090 configs::Detection {
6091 decoder: DecoderType::ModelPack,
6092 shape: vec![1, 17, 30, 21],
6093 anchors: Some(vec![
6094 [0.3666666, 0.3148148],
6095 [0.3874999, 0.474074],
6096 [0.5333333, 0.644444],
6097 ]),
6098 quantization: None,
6099 dshape: vec![],
6100 normalized: Some(true),
6101 },
6102 ])
6103 .build();
6104
6105 assert!(matches!(
6106 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("ModelPack Split Detection inconsistent number of classes:")));
6107 }
6108
6109 #[test]
6110 fn test_modelpack_invalid_seg() {
6111 let result = DecoderBuilder::new()
6112 .with_config_modelpack_seg(configs::Segmentation {
6113 decoder: DecoderType::ModelPack,
6114 quantization: None,
6115 shape: vec![1, 160, 106, 3, 1],
6116 dshape: vec![
6117 (DimName::Batch, 1),
6118 (DimName::Height, 160),
6119 (DimName::Width, 106),
6120 (DimName::NumClasses, 3),
6121 (DimName::Padding, 1),
6122 ],
6123 })
6124 .build();
6125
6126 assert!(matches!(
6127 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Segmentation shape")));
6128 }
6129
6130 #[test]
6131 fn test_modelpack_invalid_segdet() {
6132 let result = DecoderBuilder::new()
6133 .with_config_modelpack_segdet(
6134 configs::Boxes {
6135 decoder: DecoderType::ModelPack,
6136 quantization: None,
6137 shape: vec![1, 4, 1, 8400],
6138 dshape: vec![
6139 (DimName::Batch, 1),
6140 (DimName::BoxCoords, 4),
6141 (DimName::Padding, 1),
6142 (DimName::NumBoxes, 8400),
6143 ],
6144 normalized: Some(true),
6145 },
6146 configs::Scores {
6147 decoder: DecoderType::ModelPack,
6148 quantization: None,
6149 shape: vec![1, 4, 8400],
6150 dshape: vec![
6151 (DimName::Batch, 1),
6152 (DimName::NumClasses, 4),
6153 (DimName::NumBoxes, 8400),
6154 ],
6155 },
6156 configs::Segmentation {
6157 decoder: DecoderType::ModelPack,
6158 quantization: None,
6159 shape: vec![1, 160, 106, 3],
6160 dshape: vec![
6161 (DimName::Batch, 1),
6162 (DimName::Height, 160),
6163 (DimName::Width, 106),
6164 (DimName::NumClasses, 3),
6165 ],
6166 },
6167 )
6168 .build();
6169
6170 assert!(matches!(
6171 result, Err(DecoderError::InvalidConfig(s)) if s.contains("incompatible with number of classes")));
6172 }
6173
6174 #[test]
6175 fn test_modelpack_invalid_segdet_split() {
6176 let result = DecoderBuilder::new()
6177 .with_config_modelpack_segdet_split(
6178 vec![configs::Detection {
6179 decoder: DecoderType::ModelPack,
6180 shape: vec![1, 17, 30, 18],
6181 anchors: Some(vec![
6182 [0.3666666, 0.3148148],
6183 [0.3874999, 0.474074],
6184 [0.5333333, 0.644444],
6185 ]),
6186 quantization: None,
6187 dshape: vec![
6188 (DimName::Batch, 1),
6189 (DimName::Height, 17),
6190 (DimName::Width, 30),
6191 (DimName::NumAnchorsXFeatures, 18),
6192 ],
6193 normalized: Some(true),
6194 }],
6195 configs::Segmentation {
6196 decoder: DecoderType::ModelPack,
6197 quantization: None,
6198 shape: vec![1, 160, 106, 3],
6199 dshape: vec![
6200 (DimName::Batch, 1),
6201 (DimName::Height, 160),
6202 (DimName::Width, 106),
6203 (DimName::NumClasses, 3),
6204 ],
6205 },
6206 )
6207 .build();
6208
6209 assert!(matches!(
6210 result, Err(DecoderError::InvalidConfig(s)) if s.contains("incompatible with number of classes")));
6211 }
6212
6213 #[test]
6214 fn test_decode_bad_shapes() {
6215 let score_threshold = 0.25;
6216 let iou_threshold = 0.7;
6217 let quant = (0.0040811873, -123);
6218 let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
6219 let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
6220 let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
6221 let out_float: Array3<f32> = dequantize_ndarray(out.view(), quant.into());
6222
6223 let decoder = DecoderBuilder::default()
6224 .with_config_yolo_det(
6225 configs::Detection {
6226 decoder: DecoderType::Ultralytics,
6227 shape: vec![1, 85, 8400],
6228 anchors: None,
6229 quantization: Some(quant.into()),
6230 dshape: vec![
6231 (DimName::Batch, 1),
6232 (DimName::NumFeatures, 85),
6233 (DimName::NumBoxes, 8400),
6234 ],
6235 normalized: Some(true),
6236 },
6237 Some(DecoderVersion::Yolo11),
6238 )
6239 .with_score_threshold(score_threshold)
6240 .with_iou_threshold(iou_threshold)
6241 .build()
6242 .unwrap();
6243
6244 let mut output_boxes: Vec<_> = Vec::with_capacity(50);
6245 let mut output_masks: Vec<_> = Vec::with_capacity(50);
6246 let result =
6247 decoder.decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks);
6248
6249 assert!(matches!(
6250 result, Err(DecoderError::InvalidShape(s)) if s == "Did not find output with shape [1, 85, 8400]"));
6251
6252 let result = decoder.decode_float(
6253 &[out_float.view().into_dyn()],
6254 &mut output_boxes,
6255 &mut output_masks,
6256 );
6257
6258 assert!(matches!(
6259 result, Err(DecoderError::InvalidShape(s)) if s == "Did not find output with shape [1, 85, 8400]"));
6260 }
6261
6262 #[test]
6263 fn test_config_outputs() {
6264 let outputs = [
6265 ConfigOutput::Detection(configs::Detection {
6266 decoder: configs::DecoderType::Ultralytics,
6267 anchors: None,
6268 shape: vec![1, 8400, 85],
6269 quantization: Some(QuantTuple(0.123, 0)),
6270 dshape: vec![
6271 (DimName::Batch, 1),
6272 (DimName::NumBoxes, 8400),
6273 (DimName::NumFeatures, 85),
6274 ],
6275 normalized: Some(true),
6276 }),
6277 ConfigOutput::Mask(configs::Mask {
6278 decoder: configs::DecoderType::Ultralytics,
6279 shape: vec![1, 160, 160, 1],
6280 quantization: Some(QuantTuple(0.223, 0)),
6281 dshape: vec![
6282 (DimName::Batch, 1),
6283 (DimName::Height, 160),
6284 (DimName::Width, 160),
6285 (DimName::NumFeatures, 1),
6286 ],
6287 }),
6288 ConfigOutput::Segmentation(configs::Segmentation {
6289 decoder: configs::DecoderType::Ultralytics,
6290 shape: vec![1, 160, 160, 80],
6291 quantization: Some(QuantTuple(0.323, 0)),
6292 dshape: vec![
6293 (DimName::Batch, 1),
6294 (DimName::Height, 160),
6295 (DimName::Width, 160),
6296 (DimName::NumClasses, 80),
6297 ],
6298 }),
6299 ConfigOutput::Scores(configs::Scores {
6300 decoder: configs::DecoderType::Ultralytics,
6301 shape: vec![1, 8400, 80],
6302 quantization: Some(QuantTuple(0.423, 0)),
6303 dshape: vec![
6304 (DimName::Batch, 1),
6305 (DimName::NumBoxes, 8400),
6306 (DimName::NumClasses, 80),
6307 ],
6308 }),
6309 ConfigOutput::Boxes(configs::Boxes {
6310 decoder: configs::DecoderType::Ultralytics,
6311 shape: vec![1, 8400, 4],
6312 quantization: Some(QuantTuple(0.523, 0)),
6313 dshape: vec![
6314 (DimName::Batch, 1),
6315 (DimName::NumBoxes, 8400),
6316 (DimName::BoxCoords, 4),
6317 ],
6318 normalized: Some(true),
6319 }),
6320 ConfigOutput::Protos(configs::Protos {
6321 decoder: configs::DecoderType::Ultralytics,
6322 shape: vec![1, 32, 160, 160],
6323 quantization: Some(QuantTuple(0.623, 0)),
6324 dshape: vec![
6325 (DimName::Batch, 1),
6326 (DimName::NumProtos, 32),
6327 (DimName::Height, 160),
6328 (DimName::Width, 160),
6329 ],
6330 }),
6331 ConfigOutput::MaskCoefficients(configs::MaskCoefficients {
6332 decoder: configs::DecoderType::Ultralytics,
6333 shape: vec![1, 8400, 32],
6334 quantization: Some(QuantTuple(0.723, 0)),
6335 dshape: vec![
6336 (DimName::Batch, 1),
6337 (DimName::NumBoxes, 8400),
6338 (DimName::NumProtos, 32),
6339 ],
6340 }),
6341 ];
6342
6343 let shapes = outputs.clone().map(|x| x.shape().to_vec());
6344 assert_eq!(
6345 shapes,
6346 [
6347 vec![1, 8400, 85],
6348 vec![1, 160, 160, 1],
6349 vec![1, 160, 160, 80],
6350 vec![1, 8400, 80],
6351 vec![1, 8400, 4],
6352 vec![1, 32, 160, 160],
6353 vec![1, 8400, 32],
6354 ]
6355 );
6356
6357 let quants: [Option<(f32, i32)>; 7] = outputs.map(|x| x.quantization().map(|q| q.into()));
6358 assert_eq!(
6359 quants,
6360 [
6361 Some((0.123, 0)),
6362 Some((0.223, 0)),
6363 Some((0.323, 0)),
6364 Some((0.423, 0)),
6365 Some((0.523, 0)),
6366 Some((0.623, 0)),
6367 Some((0.723, 0)),
6368 ]
6369 );
6370 }
6371
6372 #[test]
6373 fn test_nms_from_config_yaml() {
6374 let yaml_class_agnostic = r#"
6376outputs:
6377 - decoder: ultralytics
6378 type: detection
6379 shape: [1, 84, 8400]
6380 dshape:
6381 - [batch, 1]
6382 - [num_features, 84]
6383 - [num_boxes, 8400]
6384nms: class_agnostic
6385"#;
6386 let decoder = DecoderBuilder::new()
6387 .with_config_yaml_str(yaml_class_agnostic.to_string())
6388 .build()
6389 .unwrap();
6390 assert_eq!(decoder.nms, Some(configs::Nms::ClassAgnostic));
6391
6392 let yaml_class_aware = r#"
6393outputs:
6394 - decoder: ultralytics
6395 type: detection
6396 shape: [1, 84, 8400]
6397 dshape:
6398 - [batch, 1]
6399 - [num_features, 84]
6400 - [num_boxes, 8400]
6401nms: class_aware
6402"#;
6403 let decoder = DecoderBuilder::new()
6404 .with_config_yaml_str(yaml_class_aware.to_string())
6405 .build()
6406 .unwrap();
6407 assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
6408
6409 let decoder = DecoderBuilder::new()
6411 .with_config_yaml_str(yaml_class_aware.to_string())
6412 .with_nms(Some(configs::Nms::ClassAgnostic)) .build()
6414 .unwrap();
6415 assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
6417 }
6418
6419 #[test]
6420 fn test_nms_from_config_json() {
6421 let json_class_aware = r#"{
6423 "outputs": [{
6424 "decoder": "ultralytics",
6425 "type": "detection",
6426 "shape": [1, 84, 8400],
6427 "dshape": [["batch", 1], ["num_features", 84], ["num_boxes", 8400]]
6428 }],
6429 "nms": "class_aware"
6430 }"#;
6431 let decoder = DecoderBuilder::new()
6432 .with_config_json_str(json_class_aware.to_string())
6433 .build()
6434 .unwrap();
6435 assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
6436 }
6437
6438 #[test]
6439 fn test_nms_missing_from_config_uses_builder_default() {
6440 let yaml_no_nms = r#"
6442outputs:
6443 - decoder: ultralytics
6444 type: detection
6445 shape: [1, 84, 8400]
6446 dshape:
6447 - [batch, 1]
6448 - [num_features, 84]
6449 - [num_boxes, 8400]
6450"#;
6451 let decoder = DecoderBuilder::new()
6452 .with_config_yaml_str(yaml_no_nms.to_string())
6453 .build()
6454 .unwrap();
6455 assert_eq!(decoder.nms, Some(configs::Nms::ClassAgnostic));
6457
6458 let decoder = DecoderBuilder::new()
6460 .with_config_yaml_str(yaml_no_nms.to_string())
6461 .with_nms(None) .build()
6463 .unwrap();
6464 assert_eq!(decoder.nms, None);
6465 }
6466
6467 #[test]
6468 fn test_decoder_version_yolo26_end_to_end() {
6469 let yaml = r#"
6471outputs:
6472 - decoder: ultralytics
6473 type: detection
6474 shape: [1, 6, 8400]
6475 dshape:
6476 - [batch, 1]
6477 - [num_features, 6]
6478 - [num_boxes, 8400]
6479decoder_version: yolo26
6480"#;
6481 let decoder = DecoderBuilder::new()
6482 .with_config_yaml_str(yaml.to_string())
6483 .build()
6484 .unwrap();
6485 assert!(matches!(
6486 decoder.model_type,
6487 ModelType::YoloEndToEndDet { .. }
6488 ));
6489
6490 let yaml_with_nms = r#"
6492outputs:
6493 - decoder: ultralytics
6494 type: detection
6495 shape: [1, 6, 8400]
6496 dshape:
6497 - [batch, 1]
6498 - [num_features, 6]
6499 - [num_boxes, 8400]
6500decoder_version: yolo26
6501nms: class_agnostic
6502"#;
6503 let decoder = DecoderBuilder::new()
6504 .with_config_yaml_str(yaml_with_nms.to_string())
6505 .build()
6506 .unwrap();
6507 assert!(matches!(
6508 decoder.model_type,
6509 ModelType::YoloEndToEndDet { .. }
6510 ));
6511 }
6512
6513 #[test]
6514 fn test_decoder_version_yolov8_traditional() {
6515 let yaml = r#"
6517outputs:
6518 - decoder: ultralytics
6519 type: detection
6520 shape: [1, 84, 8400]
6521 dshape:
6522 - [batch, 1]
6523 - [num_features, 84]
6524 - [num_boxes, 8400]
6525decoder_version: yolov8
6526"#;
6527 let decoder = DecoderBuilder::new()
6528 .with_config_yaml_str(yaml.to_string())
6529 .build()
6530 .unwrap();
6531 assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
6532 }
6533
6534 #[test]
6535 fn test_decoder_version_all_versions() {
6536 for version in ["yolov5", "yolov8", "yolo11"] {
6538 let yaml = format!(
6539 r#"
6540outputs:
6541 - decoder: ultralytics
6542 type: detection
6543 shape: [1, 84, 8400]
6544 dshape:
6545 - [batch, 1]
6546 - [num_features, 84]
6547 - [num_boxes, 8400]
6548decoder_version: {}
6549"#,
6550 version
6551 );
6552 let decoder = DecoderBuilder::new()
6553 .with_config_yaml_str(yaml)
6554 .build()
6555 .unwrap();
6556
6557 assert!(
6558 matches!(decoder.model_type, ModelType::YoloDet { .. }),
6559 "Expected traditional for {}",
6560 version
6561 );
6562 }
6563
6564 let yaml = r#"
6565outputs:
6566 - decoder: ultralytics
6567 type: detection
6568 shape: [1, 6, 8400]
6569 dshape:
6570 - [batch, 1]
6571 - [num_features, 6]
6572 - [num_boxes, 8400]
6573decoder_version: yolo26
6574"#
6575 .to_string();
6576
6577 let decoder = DecoderBuilder::new()
6578 .with_config_yaml_str(yaml)
6579 .build()
6580 .unwrap();
6581
6582 assert!(
6583 matches!(decoder.model_type, ModelType::YoloEndToEndDet { .. }),
6584 "Expected end to end for yolo26",
6585 );
6586 }
6587
6588 #[test]
6589 fn test_decoder_version_json() {
6590 let json = r#"{
6592 "outputs": [{
6593 "decoder": "ultralytics",
6594 "type": "detection",
6595 "shape": [1, 6, 8400],
6596 "dshape": [["batch", 1], ["num_features", 6], ["num_boxes", 8400]]
6597 }],
6598 "decoder_version": "yolo26"
6599 }"#;
6600 let decoder = DecoderBuilder::new()
6601 .with_config_json_str(json.to_string())
6602 .build()
6603 .unwrap();
6604 assert!(matches!(
6605 decoder.model_type,
6606 ModelType::YoloEndToEndDet { .. }
6607 ));
6608 }
6609
6610 #[test]
6611 fn test_decoder_version_none_uses_traditional() {
6612 let yaml = r#"
6614outputs:
6615 - decoder: ultralytics
6616 type: detection
6617 shape: [1, 84, 8400]
6618 dshape:
6619 - [batch, 1]
6620 - [num_features, 84]
6621 - [num_boxes, 8400]
6622"#;
6623 let decoder = DecoderBuilder::new()
6624 .with_config_yaml_str(yaml.to_string())
6625 .build()
6626 .unwrap();
6627 assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
6628 }
6629
6630 #[test]
6631 fn test_decoder_version_none_with_nms_none_still_traditional() {
6632 let yaml = r#"
6635outputs:
6636 - decoder: ultralytics
6637 type: detection
6638 shape: [1, 84, 8400]
6639 dshape:
6640 - [batch, 1]
6641 - [num_features, 84]
6642 - [num_boxes, 8400]
6643"#;
6644 let decoder = DecoderBuilder::new()
6645 .with_config_yaml_str(yaml.to_string())
6646 .with_nms(None) .build()
6648 .unwrap();
6649 assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
6652 }
6653
6654 #[test]
6655 fn test_decoder_heuristic_end_to_end_detection() {
6656 let yaml = r#"
6659outputs:
6660 - decoder: ultralytics
6661 type: detection
6662 shape: [1, 300, 6]
6663 dshape:
6664 - [batch, 1]
6665 - [num_boxes, 300]
6666 - [num_features, 6]
6667
6668"#;
6669 let decoder = DecoderBuilder::new()
6670 .with_config_yaml_str(yaml.to_string())
6671 .build()
6672 .unwrap();
6673 assert!(matches!(
6675 decoder.model_type,
6676 ModelType::YoloEndToEndDet { .. }
6677 ));
6678
6679 let yaml = r#"
6680outputs:
6681 - decoder: ultralytics
6682 type: detection
6683 shape: [1, 300, 38]
6684 dshape:
6685 - [batch, 1]
6686 - [num_boxes, 300]
6687 - [num_features, 38]
6688 - decoder: ultralytics
6689 type: protos
6690 shape: [1, 160, 160, 32]
6691 dshape:
6692 - [batch, 1]
6693 - [height, 160]
6694 - [width, 160]
6695 - [num_protos, 32]
6696"#;
6697 let decoder = DecoderBuilder::new()
6698 .with_config_yaml_str(yaml.to_string())
6699 .build()
6700 .unwrap();
6701 assert!(matches!(
6703 decoder.model_type,
6704 ModelType::YoloEndToEndSegDet { .. }
6705 ));
6706
6707 let yaml = r#"
6708outputs:
6709 - decoder: ultralytics
6710 type: detection
6711 shape: [1, 6, 300]
6712 dshape:
6713 - [batch, 1]
6714 - [num_features, 6]
6715 - [num_boxes, 300]
6716"#;
6717 let decoder = DecoderBuilder::new()
6718 .with_config_yaml_str(yaml.to_string())
6719 .build()
6720 .unwrap();
6721 assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
6724
6725 let yaml = r#"
6726outputs:
6727 - decoder: ultralytics
6728 type: detection
6729 shape: [1, 38, 300]
6730 dshape:
6731 - [batch, 1]
6732 - [num_features, 38]
6733 - [num_boxes, 300]
6734
6735 - decoder: ultralytics
6736 type: protos
6737 shape: [1, 160, 160, 32]
6738 dshape:
6739 - [batch, 1]
6740 - [height, 160]
6741 - [width, 160]
6742 - [num_protos, 32]
6743"#;
6744 let decoder = DecoderBuilder::new()
6745 .with_config_yaml_str(yaml.to_string())
6746 .build()
6747 .unwrap();
6748 assert!(matches!(decoder.model_type, ModelType::YoloSegDet { .. }));
6750 }
6751
6752 #[test]
6753 fn test_decoder_version_is_end_to_end() {
6754 assert!(!configs::DecoderVersion::Yolov5.is_end_to_end());
6755 assert!(!configs::DecoderVersion::Yolov8.is_end_to_end());
6756 assert!(!configs::DecoderVersion::Yolo11.is_end_to_end());
6757 assert!(configs::DecoderVersion::Yolo26.is_end_to_end());
6758 }
6759
6760 #[test]
6761 fn test_dshape_dict_format() {
6762 let json = r#"{
6764 "decoder": "ultralytics",
6765 "shape": [1, 84, 8400],
6766 "dshape": [{"batch": 1}, {"num_features": 84}, {"num_boxes": 8400}]
6767 }"#;
6768 let det: configs::Detection = serde_json::from_str(json).unwrap();
6769 assert_eq!(det.dshape.len(), 3);
6770 assert_eq!(det.dshape[0], (configs::DimName::Batch, 1));
6771 assert_eq!(det.dshape[1], (configs::DimName::NumFeatures, 84));
6772 assert_eq!(det.dshape[2], (configs::DimName::NumBoxes, 8400));
6773 }
6774
6775 #[test]
6776 fn test_dshape_tuple_format() {
6777 let json = r#"{
6779 "decoder": "ultralytics",
6780 "shape": [1, 84, 8400],
6781 "dshape": [["batch", 1], ["num_features", 84], ["num_boxes", 8400]]
6782 }"#;
6783 let det: configs::Detection = serde_json::from_str(json).unwrap();
6784 assert_eq!(det.dshape.len(), 3);
6785 assert_eq!(det.dshape[0], (configs::DimName::Batch, 1));
6786 assert_eq!(det.dshape[1], (configs::DimName::NumFeatures, 84));
6787 assert_eq!(det.dshape[2], (configs::DimName::NumBoxes, 8400));
6788 }
6789
6790 #[test]
6791 fn test_dshape_empty_default() {
6792 let json = r#"{
6794 "decoder": "ultralytics",
6795 "shape": [1, 84, 8400]
6796 }"#;
6797 let det: configs::Detection = serde_json::from_str(json).unwrap();
6798 assert!(det.dshape.is_empty());
6799 }
6800
6801 #[test]
6802 fn test_dshape_dict_format_protos() {
6803 let json = r#"{
6804 "decoder": "ultralytics",
6805 "shape": [1, 32, 160, 160],
6806 "dshape": [{"batch": 1}, {"num_protos": 32}, {"height": 160}, {"width": 160}]
6807 }"#;
6808 let protos: configs::Protos = serde_json::from_str(json).unwrap();
6809 assert_eq!(protos.dshape.len(), 4);
6810 assert_eq!(protos.dshape[0], (configs::DimName::Batch, 1));
6811 assert_eq!(protos.dshape[1], (configs::DimName::NumProtos, 32));
6812 }
6813
6814 #[test]
6815 fn test_dshape_dict_format_boxes() {
6816 let json = r#"{
6817 "decoder": "ultralytics",
6818 "shape": [1, 8400, 4],
6819 "dshape": [{"batch": 1}, {"num_boxes": 8400}, {"box_coords": 4}]
6820 }"#;
6821 let boxes: configs::Boxes = serde_json::from_str(json).unwrap();
6822 assert_eq!(boxes.dshape.len(), 3);
6823 assert_eq!(boxes.dshape[2], (configs::DimName::BoxCoords, 4));
6824 }
6825}