1use std::collections::HashSet;
5
6use ndarray::{s, Array3, ArrayView, ArrayViewD, Dimension};
7use ndarray_stats::QuantileExt;
8use num_traits::{AsPrimitive, Float};
9use serde::{Deserialize, Serialize};
10
11use crate::{
12 configs::{DecoderType, DimName, ModelType, QuantTuple},
13 dequantize_ndarray,
14 modelpack::{
15 decode_modelpack_det, decode_modelpack_float, decode_modelpack_split_float,
16 ModelPackDetectionConfig,
17 },
18 yolo::{
19 decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float, decode_yolo_segdet_quant,
20 decode_yolo_split_det_float, decode_yolo_split_det_quant, decode_yolo_split_segdet_float,
21 impl_yolo_split_segdet_quant_get_boxes, impl_yolo_split_segdet_quant_process_masks,
22 },
23 DecoderError, DecoderVersion, DetectBox, Quantization, Segmentation, XYWH,
24};
25
/// Serializable description of a model's outputs plus optional decoding
/// settings.
///
/// Every field falls back to its `Default` value when it is absent from the
/// serialized input.
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
pub struct ConfigOutputs {
    /// The configured model outputs, in the order they were listed.
    #[serde(default)]
    pub outputs: Vec<ConfigOutput>,
    /// Optional NMS mode. Omitted from serialized output when `None`.
    /// When set, `DecoderBuilder::build` prefers it over the builder's own
    /// NMS setting.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub nms: Option<configs::Nms>,
    /// Optional decoder version. Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub decoder_version: Option<configs::DecoderVersion>,
}
59
/// A single model output entry.
///
/// Serialized as an internally tagged enum: the `type` field selects the
/// variant and the remaining fields populate that variant's payload.
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
pub enum ConfigOutput {
    /// Selected by `type = "detection"`.
    #[serde(rename = "detection")]
    Detection(configs::Detection),
    /// Selected by `type = "masks"` (the tag is plural, the variant is not).
    #[serde(rename = "masks")]
    Mask(configs::Mask),
    /// Selected by `type = "segmentation"`.
    #[serde(rename = "segmentation")]
    Segmentation(configs::Segmentation),
    /// Selected by `type = "protos"`.
    #[serde(rename = "protos")]
    Protos(configs::Protos),
    /// Selected by `type = "scores"`.
    #[serde(rename = "scores")]
    Scores(configs::Scores),
    /// Selected by `type = "boxes"`.
    #[serde(rename = "boxes")]
    Boxes(configs::Boxes),
    /// Selected by `type = "mask_coefficients"`.
    #[serde(rename = "mask_coefficients")]
    MaskCoefficients(configs::MaskCoefficients),
}
78
/// Borrowed counterpart of [`ConfigOutput`]: every variant holds a reference
/// to the matching `configs` struct instead of owning it.
#[derive(Debug, PartialEq, Clone)]
pub enum ConfigOutputRef<'a> {
    /// Reference to a detection output config.
    Detection(&'a configs::Detection),
    /// Reference to a mask output config.
    Mask(&'a configs::Mask),
    /// Reference to a segmentation output config.
    Segmentation(&'a configs::Segmentation),
    /// Reference to a protos output config.
    Protos(&'a configs::Protos),
    /// Reference to a scores output config.
    Scores(&'a configs::Scores),
    /// Reference to a boxes output config.
    Boxes(&'a configs::Boxes),
    /// Reference to a mask-coefficients output config.
    MaskCoefficients(&'a configs::MaskCoefficients),
}
89
90impl<'a> ConfigOutputRef<'a> {
91 fn decoder(&self) -> configs::DecoderType {
92 match self {
93 ConfigOutputRef::Detection(v) => v.decoder,
94 ConfigOutputRef::Mask(v) => v.decoder,
95 ConfigOutputRef::Segmentation(v) => v.decoder,
96 ConfigOutputRef::Protos(v) => v.decoder,
97 ConfigOutputRef::Scores(v) => v.decoder,
98 ConfigOutputRef::Boxes(v) => v.decoder,
99 ConfigOutputRef::MaskCoefficients(v) => v.decoder,
100 }
101 }
102
103 fn dshape(&self) -> &[(DimName, usize)] {
104 match self {
105 ConfigOutputRef::Detection(v) => &v.dshape,
106 ConfigOutputRef::Mask(v) => &v.dshape,
107 ConfigOutputRef::Segmentation(v) => &v.dshape,
108 ConfigOutputRef::Protos(v) => &v.dshape,
109 ConfigOutputRef::Scores(v) => &v.dshape,
110 ConfigOutputRef::Boxes(v) => &v.dshape,
111 ConfigOutputRef::MaskCoefficients(v) => &v.dshape,
112 }
113 }
114}
115
116impl<'a> From<&'a configs::Detection> for ConfigOutputRef<'a> {
117 fn from(v: &'a configs::Detection) -> ConfigOutputRef<'a> {
132 ConfigOutputRef::Detection(v)
133 }
134}
135
136impl<'a> From<&'a configs::Mask> for ConfigOutputRef<'a> {
137 fn from(v: &'a configs::Mask) -> ConfigOutputRef<'a> {
150 ConfigOutputRef::Mask(v)
151 }
152}
153
154impl<'a> From<&'a configs::Segmentation> for ConfigOutputRef<'a> {
155 fn from(v: &'a configs::Segmentation) -> ConfigOutputRef<'a> {
168 ConfigOutputRef::Segmentation(v)
169 }
170}
171
172impl<'a> From<&'a configs::Protos> for ConfigOutputRef<'a> {
173 fn from(v: &'a configs::Protos) -> ConfigOutputRef<'a> {
186 ConfigOutputRef::Protos(v)
187 }
188}
189
190impl<'a> From<&'a configs::Scores> for ConfigOutputRef<'a> {
191 fn from(v: &'a configs::Scores) -> ConfigOutputRef<'a> {
204 ConfigOutputRef::Scores(v)
205 }
206}
207
208impl<'a> From<&'a configs::Boxes> for ConfigOutputRef<'a> {
209 fn from(v: &'a configs::Boxes) -> ConfigOutputRef<'a> {
223 ConfigOutputRef::Boxes(v)
224 }
225}
226
227impl<'a> From<&'a configs::MaskCoefficients> for ConfigOutputRef<'a> {
228 fn from(v: &'a configs::MaskCoefficients) -> ConfigOutputRef<'a> {
241 ConfigOutputRef::MaskCoefficients(v)
242 }
243}
244
245impl ConfigOutput {
246 pub fn shape(&self) -> &[usize] {
263 match self {
264 ConfigOutput::Detection(detection) => &detection.shape,
265 ConfigOutput::Mask(mask) => &mask.shape,
266 ConfigOutput::Segmentation(segmentation) => &segmentation.shape,
267 ConfigOutput::Scores(scores) => &scores.shape,
268 ConfigOutput::Boxes(boxes) => &boxes.shape,
269 ConfigOutput::Protos(protos) => &protos.shape,
270 ConfigOutput::MaskCoefficients(mask_coefficients) => &mask_coefficients.shape,
271 }
272 }
273
274 pub fn decoder(&self) -> &configs::DecoderType {
291 match self {
292 ConfigOutput::Detection(detection) => &detection.decoder,
293 ConfigOutput::Mask(mask) => &mask.decoder,
294 ConfigOutput::Segmentation(segmentation) => &segmentation.decoder,
295 ConfigOutput::Scores(scores) => &scores.decoder,
296 ConfigOutput::Boxes(boxes) => &boxes.decoder,
297 ConfigOutput::Protos(protos) => &protos.decoder,
298 ConfigOutput::MaskCoefficients(mask_coefficients) => &mask_coefficients.decoder,
299 }
300 }
301
302 pub fn quantization(&self) -> Option<QuantTuple> {
319 match self {
320 ConfigOutput::Detection(detection) => detection.quantization,
321 ConfigOutput::Mask(mask) => mask.quantization,
322 ConfigOutput::Segmentation(segmentation) => segmentation.quantization,
323 ConfigOutput::Scores(scores) => scores.quantization,
324 ConfigOutput::Boxes(boxes) => boxes.quantization,
325 ConfigOutput::Protos(protos) => protos.quantization,
326 ConfigOutput::MaskCoefficients(mask_coefficients) => mask_coefficients.quantization,
327 }
328 }
329}
330
331pub mod configs {
332 use std::fmt::Display;
333
334 use serde::{Deserialize, Serialize};
335
    /// Quantization parameters stored as an `(f32, i32)` pair.
    ///
    /// NOTE(review): presumably `(scale, zero_point)` — confirm against the
    /// dequantization code that consumes it.
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy)]
    pub struct QuantTuple(pub f32, pub i32);
338 impl From<QuantTuple> for (f32, i32) {
339 fn from(value: QuantTuple) -> Self {
340 (value.0, value.1)
341 }
342 }
343
344 impl From<(f32, i32)> for QuantTuple {
345 fn from(value: (f32, i32)) -> Self {
346 QuantTuple(value.0, value.1)
347 }
348 }
349
    /// Configuration of a `segmentation` output.
    ///
    /// Rejected by the builder when it appears in a Yolo (`ultralytics`)
    /// config.
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
    pub struct Segmentation {
        /// Decoder family this output belongs to.
        pub decoder: DecoderType,
        /// Optional quantization parameters for this output.
        pub quantization: Option<QuantTuple>,
        /// Dimension sizes of the output.
        pub shape: Vec<usize>,
        /// Named dimensions paired with sizes; empty when not provided.
        #[serde(default)]
        pub dshape: Vec<(DimName, usize)>,
    }
360
    /// Configuration of a `protos` output.
    ///
    /// The builder's Yolo validation requires `shape` to have four entries
    /// and rejects this output in ModelPack configs.
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
    pub struct Protos {
        /// Decoder family this output belongs to.
        pub decoder: DecoderType,
        /// Optional quantization parameters for this output.
        pub quantization: Option<QuantTuple>,
        /// Dimension sizes of the output.
        pub shape: Vec<usize>,
        /// Named dimensions paired with sizes; empty when not provided.
        #[serde(default)]
        pub dshape: Vec<(DimName, usize)>,
    }
371
    /// Configuration of a `mask_coefficients` output.
    ///
    /// The builder's Yolo split validation requires `shape` to have three
    /// entries and rejects this output in ModelPack configs.
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
    pub struct MaskCoefficients {
        /// Decoder family this output belongs to.
        pub decoder: DecoderType,
        /// Optional quantization parameters for this output.
        pub quantization: Option<QuantTuple>,
        /// Dimension sizes of the output.
        pub shape: Vec<usize>,
        /// Named dimensions paired with sizes; empty when not provided.
        #[serde(default)]
        pub dshape: Vec<(DimName, usize)>,
    }
382
    /// Configuration of a `masks` output.
    ///
    /// The builder rejects this output in Yolo configs and ignores it when
    /// classifying ModelPack configs.
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
    pub struct Mask {
        /// Decoder family this output belongs to.
        pub decoder: DecoderType,
        /// Optional quantization parameters for this output.
        pub quantization: Option<QuantTuple>,
        /// Dimension sizes of the output.
        pub shape: Vec<usize>,
        /// Named dimensions paired with sizes; empty when not provided.
        #[serde(default)]
        pub dshape: Vec<(DimName, usize)>,
    }
393
    /// Configuration of a `detection` output.
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
    pub struct Detection {
        /// Optional anchor list as pairs of `f32`. ModelPack split-detection
        /// validation requires it to be present and non-empty.
        pub anchors: Option<Vec<[f32; 2]>>,
        /// Decoder family this output belongs to.
        pub decoder: DecoderType,
        /// Optional quantization parameters for this output.
        pub quantization: Option<QuantTuple>,
        /// Dimension sizes of the output.
        pub shape: Vec<usize>,
        /// Named dimensions paired with sizes; empty when not provided.
        #[serde(default)]
        pub dshape: Vec<(DimName, usize)>,
        /// Optional flag copied into the built decoder's `normalized` field.
        /// NOTE(review): presumably whether box coordinates are normalized —
        /// confirm against the decoder.
        #[serde(default)]
        pub normalized: Option<bool>,
    }
413
    /// Configuration of a `scores` output.
    ///
    /// The builder's Yolo split and ModelPack validations both require
    /// `shape` to have three entries.
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
    pub struct Scores {
        /// Decoder family this output belongs to.
        pub decoder: DecoderType,
        /// Optional quantization parameters for this output.
        pub quantization: Option<QuantTuple>,
        /// Dimension sizes of the output.
        pub shape: Vec<usize>,
        /// Named dimensions paired with sizes; empty when not provided.
        #[serde(default)]
        pub dshape: Vec<(DimName, usize)>,
    }
424
    /// Configuration of a `boxes` output.
    ///
    /// The builder requires `shape` to have three entries for Yolo split
    /// configs and four entries for ModelPack configs.
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
    pub struct Boxes {
        /// Decoder family this output belongs to.
        pub decoder: DecoderType,
        /// Optional quantization parameters for this output.
        pub quantization: Option<QuantTuple>,
        /// Dimension sizes of the output.
        pub shape: Vec<usize>,
        /// Named dimensions paired with sizes; empty when not provided.
        #[serde(default)]
        pub dshape: Vec<(DimName, usize)>,
        /// Optional flag copied into the built decoder's `normalized` field.
        /// NOTE(review): presumably whether box coordinates are normalized —
        /// confirm against the decoder.
        #[serde(default)]
        pub normalized: Option<bool>,
    }
443
    /// Name of a dimension in a `dshape` entry.
    ///
    /// Each variant is (de)serialized under the snake_case string given in
    /// its `rename` attribute; `Display` prints the same string.
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
    pub enum DimName {
        /// Serialized as `batch`.
        #[serde(rename = "batch")]
        Batch,
        /// Serialized as `height`.
        #[serde(rename = "height")]
        Height,
        /// Serialized as `width`.
        #[serde(rename = "width")]
        Width,
        /// Serialized as `num_classes`.
        #[serde(rename = "num_classes")]
        NumClasses,
        /// Serialized as `num_features`.
        #[serde(rename = "num_features")]
        NumFeatures,
        /// Serialized as `num_boxes`.
        #[serde(rename = "num_boxes")]
        NumBoxes,
        /// Serialized as `num_protos`.
        #[serde(rename = "num_protos")]
        NumProtos,
        /// Serialized as `num_anchors_x_features`.
        #[serde(rename = "num_anchors_x_features")]
        NumAnchorsXFeatures,
        /// Serialized as `padding`.
        #[serde(rename = "padding")]
        Padding,
        /// Serialized as `box_coords`.
        #[serde(rename = "box_coords")]
        BoxCoords,
    }
467
468 impl Display for DimName {
469 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
479 match self {
480 DimName::Batch => write!(f, "batch"),
481 DimName::Height => write!(f, "height"),
482 DimName::Width => write!(f, "width"),
483 DimName::NumClasses => write!(f, "num_classes"),
484 DimName::NumFeatures => write!(f, "num_features"),
485 DimName::NumBoxes => write!(f, "num_boxes"),
486 DimName::NumProtos => write!(f, "num_protos"),
487 DimName::NumAnchorsXFeatures => write!(f, "num_anchors_x_features"),
488 DimName::Padding => write!(f, "padding"),
489 DimName::BoxCoords => write!(f, "box_coords"),
490 }
491 }
492 }
493
    /// Decoder family of an output. The builder refuses configs that mix
    /// both families.
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
    pub enum DecoderType {
        /// Serialized as `modelpack`.
        #[serde(rename = "modelpack")]
        ModelPack,
        /// Serialized as `ultralytics`; handled by the Yolo code paths.
        #[serde(rename = "ultralytics")]
        Ultralytics,
    }
501
    /// Optional decoder version carried by a config.
    ///
    /// Each variant has an explicit `rename`, which takes precedence over the
    /// container-level `rename_all`.
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
    #[serde(rename_all = "lowercase")]
    pub enum DecoderVersion {
        /// Serialized as `yolov5`.
        #[serde(rename = "yolov5")]
        Yolov5,
        /// Serialized as `yolov8`.
        #[serde(rename = "yolov8")]
        Yolov8,
        /// Serialized as `yolo11`.
        #[serde(rename = "yolo11")]
        Yolo11,
        /// Serialized as `yolo26`; the only version reported as end-to-end.
        #[serde(rename = "yolo26")]
        Yolo26,
    }
530
531 impl DecoderVersion {
532 pub fn is_end_to_end(&self) -> bool {
535 matches!(self, DecoderVersion::Yolo26)
536 }
537 }
538
    /// NMS mode selector, serialized in snake_case (`class_agnostic`,
    /// `class_aware`).
    ///
    /// NOTE(review): presumably selects whether suppression runs across all
    /// classes or per class — confirm against the NMS implementation.
    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq, Default)]
    #[serde(rename_all = "snake_case")]
    pub enum Nms {
        /// The default mode.
        #[default]
        ClassAgnostic,
        /// The alternative mode.
        ClassAware,
    }
557
    /// The model layout the builder derives from a validated config; each
    /// variant owns the output configs it was built from.
    #[derive(Debug, Clone, PartialEq)]
    pub enum ModelType {
        /// ModelPack config with boxes, scores and a segmentation output.
        ModelPackSegDet {
            boxes: Boxes,
            scores: Scores,
            segmentation: Segmentation,
        },
        /// ModelPack config with one or more detection outputs and a
        /// segmentation output.
        ModelPackSegDetSplit {
            detection: Vec<Detection>,
            segmentation: Segmentation,
        },
        /// ModelPack config with boxes and scores only.
        ModelPackDet {
            boxes: Boxes,
            scores: Scores,
        },
        /// ModelPack config with one or more detection outputs only.
        ModelPackDetSplit {
            detection: Vec<Detection>,
        },
        /// ModelPack config with a segmentation output and no usable
        /// detection outputs.
        ModelPackSeg {
            segmentation: Segmentation,
        },
        /// Yolo config with a single detection output.
        YoloDet {
            boxes: Detection,
        },
        /// Yolo config with a detection output and protos.
        YoloSegDet {
            boxes: Detection,
            protos: Protos,
        },
        /// Yolo config with separate boxes and scores outputs.
        YoloSplitDet {
            boxes: Boxes,
            scores: Scores,
        },
        /// Yolo config with separate boxes, scores, mask coefficients and
        /// protos outputs.
        YoloSplitSegDet {
            boxes: Boxes,
            scores: Scores,
            mask_coeff: MaskCoefficients,
            protos: Protos,
        },
        /// End-to-end Yolo config with a single detection output.
        YoloEndToEndDet {
            boxes: Detection,
        },
        /// End-to-end Yolo config with a detection output and protos.
        YoloEndToEndSegDet {
            boxes: Detection,
            protos: Protos,
        },
    }
610
    /// Element data type identifiers with explicit numeric discriminants.
    ///
    /// Serialized by serde as the lowercased variant name (for example
    /// `uint8`, `float32`), not as the number.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    #[serde(rename_all = "lowercase")]
    pub enum DataType {
        Raw = 0,
        Int8 = 1,
        UInt8 = 2,
        Int16 = 3,
        UInt16 = 4,
        Float16 = 5,
        Int32 = 6,
        UInt32 = 7,
        Float32 = 8,
        Int64 = 9,
        UInt64 = 10,
        Float64 = 11,
        String = 12,
    }
628}
629
/// Builder that collects a configuration source and thresholds, then
/// validates them into a `Decoder` via `build`.
#[derive(Debug, Clone, PartialEq)]
pub struct DecoderBuilder {
    /// Where the configuration comes from; `None` until a `with_config*`
    /// method is called.
    config_src: Option<ConfigSource>,
    /// IoU threshold passed through to the built decoder.
    iou_threshold: f32,
    /// Score threshold passed through to the built decoder.
    score_threshold: f32,
    /// NMS mode used when the config itself does not specify one.
    nms: Option<configs::Nms>,
}
639
/// The form in which a configuration was handed to the builder. Text forms
/// are parsed only when `DecoderBuilder::build` runs.
#[derive(Debug, Clone, PartialEq)]
enum ConfigSource {
    /// Unparsed YAML text.
    Yaml(String),
    /// Unparsed JSON text.
    Json(String),
    /// An already constructed configuration.
    Config(ConfigOutputs),
}
646
647impl Default for DecoderBuilder {
648 fn default() -> Self {
668 Self {
669 config_src: None,
670 iou_threshold: 0.5,
671 score_threshold: 0.5,
672 nms: Some(configs::Nms::ClassAgnostic),
673 }
674 }
675}
676
677impl DecoderBuilder {
678 pub fn new() -> Self {
698 Self::default()
699 }
700
701 pub fn with_config_yaml_str(mut self, yaml_str: String) -> Self {
718 self.config_src.replace(ConfigSource::Yaml(yaml_str));
719 self
720 }
721
722 pub fn with_config_json_str(mut self, json_str: String) -> Self {
739 self.config_src.replace(ConfigSource::Json(json_str));
740 self
741 }
742
743 pub fn with_config(mut self, config: ConfigOutputs) -> Self {
760 self.config_src.replace(ConfigSource::Config(config));
761 self
762 }
763
764 pub fn with_config_yolo_det(
789 mut self,
790 boxes: configs::Detection,
791 version: Option<DecoderVersion>,
792 ) -> Self {
793 let config = ConfigOutputs {
794 outputs: vec![ConfigOutput::Detection(boxes)],
795 decoder_version: version,
796 ..Default::default()
797 };
798 self.config_src.replace(ConfigSource::Config(config));
799 self
800 }
801
802 pub fn with_config_yolo_split_det(
829 mut self,
830 boxes: configs::Boxes,
831 scores: configs::Scores,
832 ) -> Self {
833 let config = ConfigOutputs {
834 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
835 ..Default::default()
836 };
837 self.config_src.replace(ConfigSource::Config(config));
838 self
839 }
840
841 pub fn with_config_yolo_segdet(
873 mut self,
874 boxes: configs::Detection,
875 protos: configs::Protos,
876 version: Option<DecoderVersion>,
877 ) -> Self {
878 let config = ConfigOutputs {
879 outputs: vec![ConfigOutput::Detection(boxes), ConfigOutput::Protos(protos)],
880 decoder_version: version,
881 ..Default::default()
882 };
883 self.config_src.replace(ConfigSource::Config(config));
884 self
885 }
886
887 pub fn with_config_yolo_split_segdet(
926 mut self,
927 boxes: configs::Boxes,
928 scores: configs::Scores,
929 mask_coefficients: configs::MaskCoefficients,
930 protos: configs::Protos,
931 ) -> Self {
932 let config = ConfigOutputs {
933 outputs: vec![
934 ConfigOutput::Boxes(boxes),
935 ConfigOutput::Scores(scores),
936 ConfigOutput::MaskCoefficients(mask_coefficients),
937 ConfigOutput::Protos(protos),
938 ],
939 ..Default::default()
940 };
941 self.config_src.replace(ConfigSource::Config(config));
942 self
943 }
944
945 pub fn with_config_modelpack_det(
972 mut self,
973 boxes: configs::Boxes,
974 scores: configs::Scores,
975 ) -> Self {
976 let config = ConfigOutputs {
977 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
978 ..Default::default()
979 };
980 self.config_src.replace(ConfigSource::Config(config));
981 self
982 }
983
984 pub fn with_config_modelpack_det_split(mut self, boxes: Vec<configs::Detection>) -> Self {
1023 let outputs = boxes.into_iter().map(ConfigOutput::Detection).collect();
1024 let config = ConfigOutputs {
1025 outputs,
1026 ..Default::default()
1027 };
1028 self.config_src.replace(ConfigSource::Config(config));
1029 self
1030 }
1031
1032 pub fn with_config_modelpack_segdet(
1065 mut self,
1066 boxes: configs::Boxes,
1067 scores: configs::Scores,
1068 segmentation: configs::Segmentation,
1069 ) -> Self {
1070 let config = ConfigOutputs {
1071 outputs: vec![
1072 ConfigOutput::Boxes(boxes),
1073 ConfigOutput::Scores(scores),
1074 ConfigOutput::Segmentation(segmentation),
1075 ],
1076 ..Default::default()
1077 };
1078 self.config_src.replace(ConfigSource::Config(config));
1079 self
1080 }
1081
1082 pub fn with_config_modelpack_segdet_split(
1126 mut self,
1127 boxes: Vec<configs::Detection>,
1128 segmentation: configs::Segmentation,
1129 ) -> Self {
1130 let mut outputs = boxes
1131 .into_iter()
1132 .map(ConfigOutput::Detection)
1133 .collect::<Vec<_>>();
1134 outputs.push(ConfigOutput::Segmentation(segmentation));
1135 let config = ConfigOutputs {
1136 outputs,
1137 ..Default::default()
1138 };
1139 self.config_src.replace(ConfigSource::Config(config));
1140 self
1141 }
1142
1143 pub fn with_config_modelpack_seg(mut self, segmentation: configs::Segmentation) -> Self {
1163 let config = ConfigOutputs {
1164 outputs: vec![ConfigOutput::Segmentation(segmentation)],
1165 ..Default::default()
1166 };
1167 self.config_src.replace(ConfigSource::Config(config));
1168 self
1169 }
1170
1171 pub fn with_score_threshold(mut self, score_threshold: f32) -> Self {
1187 self.score_threshold = score_threshold;
1188 self
1189 }
1190
1191 pub fn with_iou_threshold(mut self, iou_threshold: f32) -> Self {
1208 self.iou_threshold = iou_threshold;
1209 self
1210 }
1211
1212 pub fn with_nms(mut self, nms: Option<configs::Nms>) -> Self {
1234 self.nms = nms;
1235 self
1236 }
1237
1238 pub fn build(self) -> Result<Decoder, DecoderError> {
1255 let config = match self.config_src {
1256 Some(ConfigSource::Json(s)) => serde_json::from_str(&s)?,
1257 Some(ConfigSource::Yaml(s)) => serde_yaml::from_str(&s)?,
1258 Some(ConfigSource::Config(c)) => c,
1259 None => return Err(DecoderError::NoConfig),
1260 };
1261
1262 let normalized = Self::get_normalized(&config.outputs);
1264
1265 let nms = config.nms.or(self.nms);
1267 let model_type = Self::get_model_type(config)?;
1268
1269 Ok(Decoder {
1270 model_type,
1271 iou_threshold: self.iou_threshold,
1272 score_threshold: self.score_threshold,
1273 nms,
1274 normalized,
1275 })
1276 }
1277
1278 fn get_normalized(outputs: &[ConfigOutput]) -> Option<bool> {
1283 for output in outputs {
1284 match output {
1285 ConfigOutput::Detection(det) => return det.normalized,
1286 ConfigOutput::Boxes(boxes) => return boxes.normalized,
1287 _ => {}
1288 }
1289 }
1290 None }
1292
1293 fn get_model_type(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1294 let mut yolo = false;
1296 let mut modelpack = false;
1297 for c in &configs.outputs {
1298 match c.decoder() {
1299 DecoderType::ModelPack => modelpack = true,
1300 DecoderType::Ultralytics => yolo = true,
1301 }
1302 }
1303 match (modelpack, yolo) {
1304 (true, true) => Err(DecoderError::InvalidConfig(
1305 "Both ModelPack and Yolo outputs found in config".to_string(),
1306 )),
1307 (true, false) => Self::get_model_type_modelpack(configs),
1308 (false, true) => Self::get_model_type_yolo(configs),
1309 (false, false) => Err(DecoderError::InvalidConfig(
1310 "No outputs found in config".to_string(),
1311 )),
1312 }
1313 }
1314
1315 fn get_model_type_yolo(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1316 let mut boxes = None;
1317 let mut protos = None;
1318 let mut split_boxes = None;
1319 let mut split_scores = None;
1320 let mut split_mask_coeff = None;
1321 for c in configs.outputs {
1322 match c {
1323 ConfigOutput::Detection(detection) => boxes = Some(detection),
1324 ConfigOutput::Segmentation(_) => {
1325 return Err(DecoderError::InvalidConfig(
1326 "Invalid Segmentation output with Yolo decoder".to_string(),
1327 ));
1328 }
1329 ConfigOutput::Protos(protos_) => protos = Some(protos_),
1330 ConfigOutput::Mask(_) => {
1331 return Err(DecoderError::InvalidConfig(
1332 "Invalid Mask output with Yolo decoder".to_string(),
1333 ));
1334 }
1335 ConfigOutput::Scores(scores) => split_scores = Some(scores),
1336 ConfigOutput::Boxes(boxes) => split_boxes = Some(boxes),
1337 ConfigOutput::MaskCoefficients(mask_coeff) => split_mask_coeff = Some(mask_coeff),
1338 }
1339 }
1340
1341 let is_end_to_end_dshape = boxes.as_ref().is_some_and(|b| {
1346 let dims = b.dshape.iter().map(|(d, _)| *d).collect::<Vec<_>>();
1347 dims == vec![DimName::Batch, DimName::NumBoxes, DimName::NumFeatures]
1348 });
1349
1350 let is_end_to_end = configs
1351 .decoder_version
1352 .map(|v| v.is_end_to_end())
1353 .unwrap_or(is_end_to_end_dshape);
1354
1355 if is_end_to_end {
1356 if let Some(boxes) = boxes {
1357 if let Some(protos) = protos {
1358 Self::verify_yolo_seg_det_26(&boxes, &protos)?;
1359 return Ok(ModelType::YoloEndToEndSegDet { boxes, protos });
1360 } else {
1361 Self::verify_yolo_det_26(&boxes)?;
1362 return Ok(ModelType::YoloEndToEndDet { boxes });
1363 }
1364 } else {
1365 return Err(DecoderError::InvalidConfig(
1366 "Invalid Yolo end-to-end model outputs".to_string(),
1367 ));
1368 }
1369 }
1370
1371 if let Some(boxes) = boxes {
1372 if let Some(protos) = protos {
1373 Self::verify_yolo_seg_det(&boxes, &protos)?;
1374 Ok(ModelType::YoloSegDet { boxes, protos })
1375 } else {
1376 Self::verify_yolo_det(&boxes)?;
1377 Ok(ModelType::YoloDet { boxes })
1378 }
1379 } else if let (Some(boxes), Some(scores)) = (split_boxes, split_scores) {
1380 if let (Some(mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
1381 Self::verify_yolo_split_segdet(&boxes, &scores, &mask_coeff, &protos)?;
1382 Ok(ModelType::YoloSplitSegDet {
1383 boxes,
1384 scores,
1385 mask_coeff,
1386 protos,
1387 })
1388 } else {
1389 Self::verify_yolo_split_det(&boxes, &scores)?;
1390 Ok(ModelType::YoloSplitDet { boxes, scores })
1391 }
1392 } else {
1393 Err(DecoderError::InvalidConfig(
1394 "Invalid Yolo model outputs".to_string(),
1395 ))
1396 }
1397 }
1398
1399 fn verify_yolo_det(detect: &configs::Detection) -> Result<(), DecoderError> {
1400 if detect.shape.len() != 3 {
1401 return Err(DecoderError::InvalidConfig(format!(
1402 "Invalid Yolo Detection shape {:?}",
1403 detect.shape
1404 )));
1405 }
1406
1407 Self::verify_dshapes(
1408 &detect.dshape,
1409 &detect.shape,
1410 "Detection",
1411 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1412 )?;
1413 if !detect.dshape.is_empty() {
1414 Self::get_class_count(&detect.dshape, None, None)?;
1415 } else {
1416 Self::get_class_count_no_dshape(detect.into(), None)?;
1417 }
1418
1419 Ok(())
1420 }
1421
1422 fn verify_yolo_det_26(detect: &configs::Detection) -> Result<(), DecoderError> {
1423 if detect.shape.len() != 3 {
1424 return Err(DecoderError::InvalidConfig(format!(
1425 "Invalid Yolo Detection shape {:?}",
1426 detect.shape
1427 )));
1428 }
1429
1430 Self::verify_dshapes(
1431 &detect.dshape,
1432 &detect.shape,
1433 "Detection",
1434 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1435 )?;
1436
1437 if !detect.shape.contains(&6) {
1438 return Err(DecoderError::InvalidConfig(
1439 "Yolo26 Detection must have 6 features".to_string(),
1440 ));
1441 }
1442
1443 Ok(())
1444 }
1445
1446 fn verify_yolo_seg_det(
1447 detection: &configs::Detection,
1448 protos: &configs::Protos,
1449 ) -> Result<(), DecoderError> {
1450 if detection.shape.len() != 3 {
1451 return Err(DecoderError::InvalidConfig(format!(
1452 "Invalid Yolo Detection shape {:?}",
1453 detection.shape
1454 )));
1455 }
1456 if protos.shape.len() != 4 {
1457 return Err(DecoderError::InvalidConfig(format!(
1458 "Invalid Yolo Protos shape {:?}",
1459 protos.shape
1460 )));
1461 }
1462
1463 Self::verify_dshapes(
1464 &detection.dshape,
1465 &detection.shape,
1466 "Detection",
1467 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1468 )?;
1469 Self::verify_dshapes(
1470 &protos.dshape,
1471 &protos.shape,
1472 "Protos",
1473 &[
1474 DimName::Batch,
1475 DimName::Height,
1476 DimName::Width,
1477 DimName::NumProtos,
1478 ],
1479 )?;
1480
1481 let protos_count = Self::get_protos_count(&protos.dshape).unwrap_or(protos.shape[3]);
1482 log::debug!("Protos count: {}", protos_count);
1483 log::debug!("Detection dshape: {:?}", detection.dshape);
1484 let classes = if !detection.dshape.is_empty() {
1485 Self::get_class_count(&detection.dshape, Some(protos_count), None)?
1486 } else {
1487 Self::get_class_count_no_dshape(detection.into(), Some(protos_count))?
1488 };
1489
1490 if classes == 0 {
1491 return Err(DecoderError::InvalidConfig(
1492 "Yolo Segmentation Detection has zero classes".to_string(),
1493 ));
1494 }
1495
1496 Ok(())
1497 }
1498
1499 fn verify_yolo_seg_det_26(
1500 detection: &configs::Detection,
1501 protos: &configs::Protos,
1502 ) -> Result<(), DecoderError> {
1503 if detection.shape.len() != 3 {
1504 return Err(DecoderError::InvalidConfig(format!(
1505 "Invalid Yolo Detection shape {:?}",
1506 detection.shape
1507 )));
1508 }
1509 if protos.shape.len() != 4 {
1510 return Err(DecoderError::InvalidConfig(format!(
1511 "Invalid Yolo Protos shape {:?}",
1512 protos.shape
1513 )));
1514 }
1515
1516 Self::verify_dshapes(
1517 &detection.dshape,
1518 &detection.shape,
1519 "Detection",
1520 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1521 )?;
1522 Self::verify_dshapes(
1523 &protos.dshape,
1524 &protos.shape,
1525 "Protos",
1526 &[
1527 DimName::Batch,
1528 DimName::Height,
1529 DimName::Width,
1530 DimName::NumProtos,
1531 ],
1532 )?;
1533
1534 let protos_count = Self::get_protos_count(&protos.dshape).unwrap_or(protos.shape[3]);
1535 log::debug!("Protos count: {}", protos_count);
1536 log::debug!("Detection dshape: {:?}", detection.dshape);
1537
1538 if !detection.shape.contains(&(6 + protos_count)) {
1539 return Err(DecoderError::InvalidConfig(format!(
1540 "Yolo26 Segmentation Detection must have num_features be 6 + num_protos = {}",
1541 6 + protos_count
1542 )));
1543 }
1544
1545 Ok(())
1546 }
1547
1548 fn verify_yolo_split_det(
1549 boxes: &configs::Boxes,
1550 scores: &configs::Scores,
1551 ) -> Result<(), DecoderError> {
1552 if boxes.shape.len() != 3 {
1553 return Err(DecoderError::InvalidConfig(format!(
1554 "Invalid Yolo Split Boxes shape {:?}",
1555 boxes.shape
1556 )));
1557 }
1558 if scores.shape.len() != 3 {
1559 return Err(DecoderError::InvalidConfig(format!(
1560 "Invalid Yolo Split Scores shape {:?}",
1561 scores.shape
1562 )));
1563 }
1564
1565 Self::verify_dshapes(
1566 &boxes.dshape,
1567 &boxes.shape,
1568 "Boxes",
1569 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1570 )?;
1571 Self::verify_dshapes(
1572 &scores.dshape,
1573 &scores.shape,
1574 "Scores",
1575 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1576 )?;
1577
1578 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1579 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1580
1581 if boxes_num != scores_num {
1582 return Err(DecoderError::InvalidConfig(format!(
1583 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1584 boxes_num, scores_num
1585 )));
1586 }
1587
1588 Ok(())
1589 }
1590
1591 fn verify_yolo_split_segdet(
1592 boxes: &configs::Boxes,
1593 scores: &configs::Scores,
1594 mask_coeff: &configs::MaskCoefficients,
1595 protos: &configs::Protos,
1596 ) -> Result<(), DecoderError> {
1597 if boxes.shape.len() != 3 {
1598 return Err(DecoderError::InvalidConfig(format!(
1599 "Invalid Yolo Split Boxes shape {:?}",
1600 boxes.shape
1601 )));
1602 }
1603 if scores.shape.len() != 3 {
1604 return Err(DecoderError::InvalidConfig(format!(
1605 "Invalid Yolo Split Scores shape {:?}",
1606 scores.shape
1607 )));
1608 }
1609
1610 if mask_coeff.shape.len() != 3 {
1611 return Err(DecoderError::InvalidConfig(format!(
1612 "Invalid Yolo Split Mask Coefficients shape {:?}",
1613 mask_coeff.shape
1614 )));
1615 }
1616
1617 if protos.shape.len() != 4 {
1618 return Err(DecoderError::InvalidConfig(format!(
1619 "Invalid Yolo Protos shape {:?}",
1620 mask_coeff.shape
1621 )));
1622 }
1623
1624 Self::verify_dshapes(
1625 &boxes.dshape,
1626 &boxes.shape,
1627 "Boxes",
1628 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1629 )?;
1630 Self::verify_dshapes(
1631 &scores.dshape,
1632 &scores.shape,
1633 "Scores",
1634 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1635 )?;
1636 Self::verify_dshapes(
1637 &mask_coeff.dshape,
1638 &mask_coeff.shape,
1639 "Mask Coefficients",
1640 &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1641 )?;
1642 Self::verify_dshapes(
1643 &protos.dshape,
1644 &protos.shape,
1645 "Protos",
1646 &[
1647 DimName::Batch,
1648 DimName::Height,
1649 DimName::Width,
1650 DimName::NumProtos,
1651 ],
1652 )?;
1653
1654 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1655 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1656 let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1657
1658 let mask_channels = if !mask_coeff.dshape.is_empty() {
1659 Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1660 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1661 })?
1662 } else {
1663 mask_coeff.shape[1]
1664 };
1665 let proto_channels = if !protos.dshape.is_empty() {
1666 Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1667 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1668 })?
1669 } else {
1670 protos.shape[3]
1671 };
1672
1673 if boxes_num != scores_num {
1674 return Err(DecoderError::InvalidConfig(format!(
1675 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1676 boxes_num, scores_num
1677 )));
1678 }
1679
1680 if boxes_num != mask_num {
1681 return Err(DecoderError::InvalidConfig(format!(
1682 "Yolo Split Detection Boxes num {} incompatible with Mask Coefficients num {}",
1683 boxes_num, mask_num
1684 )));
1685 }
1686
1687 if proto_channels != mask_channels {
1688 return Err(DecoderError::InvalidConfig(format!(
1689 "Yolo Protos channels {} incompatible with Mask Coefficients channels {}",
1690 proto_channels, mask_channels
1691 )));
1692 }
1693
1694 Ok(())
1695 }
1696
1697 fn get_model_type_modelpack(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1698 let mut split_decoders = Vec::new();
1699 let mut segment_ = None;
1700 let mut scores_ = None;
1701 let mut boxes_ = None;
1702 for c in configs.outputs {
1703 match c {
1704 ConfigOutput::Detection(detection) => split_decoders.push(detection),
1705 ConfigOutput::Segmentation(segmentation) => segment_ = Some(segmentation),
1706 ConfigOutput::Mask(_) => {}
1707 ConfigOutput::Protos(_) => {
1708 return Err(DecoderError::InvalidConfig(
1709 "ModelPack should not have protos".to_string(),
1710 ));
1711 }
1712 ConfigOutput::Scores(scores) => scores_ = Some(scores),
1713 ConfigOutput::Boxes(boxes) => boxes_ = Some(boxes),
1714 ConfigOutput::MaskCoefficients(_) => {
1715 return Err(DecoderError::InvalidConfig(
1716 "ModelPack should not have mask coefficients".to_string(),
1717 ));
1718 }
1719 }
1720 }
1721
1722 if let Some(segmentation) = segment_ {
1723 if !split_decoders.is_empty() {
1724 let classes = Self::verify_modelpack_split_det(&split_decoders)?;
1725 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1726 Ok(ModelType::ModelPackSegDetSplit {
1727 detection: split_decoders,
1728 segmentation,
1729 })
1730 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1731 let classes = Self::verify_modelpack_det(&boxes, &scores)?;
1732 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1733 Ok(ModelType::ModelPackSegDet {
1734 boxes,
1735 scores,
1736 segmentation,
1737 })
1738 } else {
1739 Self::verify_modelpack_seg(&segmentation, None)?;
1740 Ok(ModelType::ModelPackSeg { segmentation })
1741 }
1742 } else if !split_decoders.is_empty() {
1743 Self::verify_modelpack_split_det(&split_decoders)?;
1744 Ok(ModelType::ModelPackDetSplit {
1745 detection: split_decoders,
1746 })
1747 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1748 Self::verify_modelpack_det(&boxes, &scores)?;
1749 Ok(ModelType::ModelPackDet { boxes, scores })
1750 } else {
1751 Err(DecoderError::InvalidConfig(
1752 "Invalid ModelPack model outputs".to_string(),
1753 ))
1754 }
1755 }
1756
1757 fn verify_modelpack_det(
1758 boxes: &configs::Boxes,
1759 scores: &configs::Scores,
1760 ) -> Result<usize, DecoderError> {
1761 if boxes.shape.len() != 4 {
1762 return Err(DecoderError::InvalidConfig(format!(
1763 "Invalid ModelPack Boxes shape {:?}",
1764 boxes.shape
1765 )));
1766 }
1767 if scores.shape.len() != 3 {
1768 return Err(DecoderError::InvalidConfig(format!(
1769 "Invalid ModelPack Scores shape {:?}",
1770 scores.shape
1771 )));
1772 }
1773
1774 Self::verify_dshapes(
1775 &boxes.dshape,
1776 &boxes.shape,
1777 "Boxes",
1778 &[
1779 DimName::Batch,
1780 DimName::NumBoxes,
1781 DimName::Padding,
1782 DimName::BoxCoords,
1783 ],
1784 )?;
1785 Self::verify_dshapes(
1786 &scores.dshape,
1787 &scores.shape,
1788 "Scores",
1789 &[DimName::Batch, DimName::NumBoxes, DimName::NumClasses],
1790 )?;
1791
1792 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[1]);
1793 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[1]);
1794
1795 if boxes_num != scores_num {
1796 return Err(DecoderError::InvalidConfig(format!(
1797 "ModelPack Detection Boxes num {} incompatible with Scores num {}",
1798 boxes_num, scores_num
1799 )));
1800 }
1801
1802 let num_classes = if !scores.dshape.is_empty() {
1803 Self::get_class_count(&scores.dshape, None, None)?
1804 } else {
1805 Self::get_class_count_no_dshape(scores.into(), None)?
1806 };
1807
1808 Ok(num_classes)
1809 }
1810
1811 fn verify_modelpack_split_det(boxes: &[configs::Detection]) -> Result<usize, DecoderError> {
1812 let mut num_classes = None;
1813 for b in boxes {
1814 let Some(num_anchors) = b.anchors.as_ref().map(|a| a.len()) else {
1815 return Err(DecoderError::InvalidConfig(
1816 "ModelPack Split Detection missing anchors".to_string(),
1817 ));
1818 };
1819
1820 if num_anchors == 0 {
1821 return Err(DecoderError::InvalidConfig(
1822 "ModelPack Split Detection has zero anchors".to_string(),
1823 ));
1824 }
1825
1826 if b.shape.len() != 4 {
1827 return Err(DecoderError::InvalidConfig(format!(
1828 "Invalid ModelPack Split Detection shape {:?}",
1829 b.shape
1830 )));
1831 }
1832
1833 Self::verify_dshapes(
1834 &b.dshape,
1835 &b.shape,
1836 "Split Detection",
1837 &[
1838 DimName::Batch,
1839 DimName::Height,
1840 DimName::Width,
1841 DimName::NumAnchorsXFeatures,
1842 ],
1843 )?;
1844 let classes = if !b.dshape.is_empty() {
1845 Self::get_class_count(&b.dshape, None, Some(num_anchors))?
1846 } else {
1847 Self::get_class_count_no_dshape(b.into(), None)?
1848 };
1849
1850 match num_classes {
1851 Some(n) => {
1852 if n != classes {
1853 return Err(DecoderError::InvalidConfig(format!(
1854 "ModelPack Split Detection inconsistent number of classes: previous {}, current {}",
1855 n, classes
1856 )));
1857 }
1858 }
1859 None => {
1860 num_classes = Some(classes);
1861 }
1862 }
1863 }
1864
1865 Ok(num_classes.unwrap_or(0))
1866 }
1867
1868 fn verify_modelpack_seg(
1869 segmentation: &configs::Segmentation,
1870 classes: Option<usize>,
1871 ) -> Result<(), DecoderError> {
1872 if segmentation.shape.len() != 4 {
1873 return Err(DecoderError::InvalidConfig(format!(
1874 "Invalid ModelPack Segmentation shape {:?}",
1875 segmentation.shape
1876 )));
1877 }
1878 Self::verify_dshapes(
1879 &segmentation.dshape,
1880 &segmentation.shape,
1881 "Segmentation",
1882 &[
1883 DimName::Batch,
1884 DimName::Height,
1885 DimName::Width,
1886 DimName::NumClasses,
1887 ],
1888 )?;
1889
1890 if let Some(classes) = classes {
1891 let seg_classes = if !segmentation.dshape.is_empty() {
1892 Self::get_class_count(&segmentation.dshape, None, None)?
1893 } else {
1894 Self::get_class_count_no_dshape(segmentation.into(), None)?
1895 };
1896
1897 if seg_classes != classes + 1 {
1898 return Err(DecoderError::InvalidConfig(format!(
1899 "ModelPack Segmentation channels {} incompatible with number of classes {}",
1900 seg_classes, classes
1901 )));
1902 }
1903 }
1904 Ok(())
1905 }
1906
1907 fn verify_dshapes(
1909 dshape: &[(DimName, usize)],
1910 shape: &[usize],
1911 name: &str,
1912 dims: &[DimName],
1913 ) -> Result<(), DecoderError> {
1914 for s in shape {
1915 if *s == 0 {
1916 return Err(DecoderError::InvalidConfig(format!(
1917 "{} shape has zero dimension",
1918 name
1919 )));
1920 }
1921 }
1922
1923 if shape.len() != dims.len() {
1924 return Err(DecoderError::InvalidConfig(format!(
1925 "{} shape length {} does not match expected dims length {}",
1926 name,
1927 shape.len(),
1928 dims.len()
1929 )));
1930 }
1931
1932 if dshape.is_empty() {
1933 return Ok(());
1934 }
1935 if dshape.len() != shape.len() {
1937 return Err(DecoderError::InvalidConfig(format!(
1938 "{} dshape length does not match shape length",
1939 name
1940 )));
1941 }
1942
1943 for ((dim_name, dim_size), shape_size) in dshape.iter().zip(shape) {
1945 if dim_size != shape_size {
1946 return Err(DecoderError::InvalidConfig(format!(
1947 "{} dshape dimension {} size {} does not match shape size {}",
1948 name, dim_name, dim_size, shape_size
1949 )));
1950 }
1951 if *dim_name == DimName::Padding && *dim_size != 1 {
1952 return Err(DecoderError::InvalidConfig(
1953 "Padding dimension size must be 1".to_string(),
1954 ));
1955 }
1956
1957 if *dim_name == DimName::BoxCoords && *dim_size != 4 {
1958 return Err(DecoderError::InvalidConfig(
1959 "BoxCoords dimension size must be 4".to_string(),
1960 ));
1961 }
1962 }
1963
1964 let dims_present = HashSet::<DimName>::from_iter(dshape.iter().map(|(name, _)| *name));
1965 for dim in dims {
1966 if !dims_present.contains(dim) {
1967 return Err(DecoderError::InvalidConfig(format!(
1968 "{} dshape missing required dimension {:?}",
1969 name, dim
1970 )));
1971 }
1972 }
1973
1974 Ok(())
1975 }
1976
1977 fn get_box_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1978 for (dim_name, dim_size) in dshape {
1979 if *dim_name == DimName::NumBoxes {
1980 return Some(*dim_size);
1981 }
1982 }
1983 None
1984 }
1985
1986 fn get_class_count_no_dshape(
1987 config: ConfigOutputRef,
1988 protos: Option<usize>,
1989 ) -> Result<usize, DecoderError> {
1990 match config {
1991 ConfigOutputRef::Detection(detection) => match detection.decoder {
1992 DecoderType::Ultralytics => {
1993 if detection.shape[1] <= 4 + protos.unwrap_or(0) {
1994 return Err(DecoderError::InvalidConfig(format!(
1995 "Invalid shape: Yolo num_features {} must be greater than {}",
1996 detection.shape[1],
1997 4 + protos.unwrap_or(0),
1998 )));
1999 }
2000 Ok(detection.shape[1] - 4 - protos.unwrap_or(0))
2001 }
2002 DecoderType::ModelPack => {
2003 let Some(num_anchors) = detection.anchors.as_ref().map(|a| a.len()) else {
2004 return Err(DecoderError::Internal(
2005 "ModelPack Detection missing anchors".to_string(),
2006 ));
2007 };
2008 let anchors_x_features = detection.shape[3];
2009 if anchors_x_features <= num_anchors * 5 {
2010 return Err(DecoderError::InvalidConfig(format!(
2011 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2012 anchors_x_features,
2013 num_anchors * 5,
2014 )));
2015 }
2016
2017 if !anchors_x_features.is_multiple_of(num_anchors) {
2018 return Err(DecoderError::InvalidConfig(format!(
2019 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2020 anchors_x_features, num_anchors
2021 )));
2022 }
2023 Ok(anchors_x_features / num_anchors - 5)
2024 }
2025 },
2026
2027 ConfigOutputRef::Scores(scores) => match scores.decoder {
2028 DecoderType::Ultralytics => Ok(scores.shape[1]),
2029 DecoderType::ModelPack => Ok(scores.shape[2]),
2030 },
2031 ConfigOutputRef::Segmentation(seg) => Ok(seg.shape[3]),
2032 _ => Err(DecoderError::Internal(
2033 "Attempted to get class count from unsupported config output".to_owned(),
2034 )),
2035 }
2036 }
2037
2038 fn get_class_count(
2040 dshape: &[(DimName, usize)],
2041 protos: Option<usize>,
2042 anchors: Option<usize>,
2043 ) -> Result<usize, DecoderError> {
2044 if dshape.is_empty() {
2045 return Ok(0);
2046 }
2047 for (dim_name, dim_size) in dshape {
2049 if *dim_name == DimName::NumClasses {
2050 return Ok(*dim_size);
2051 }
2052 }
2053
2054 for (dim_name, dim_size) in dshape {
2057 if *dim_name == DimName::NumFeatures {
2058 let protos = protos.unwrap_or(0);
2059 if protos + 4 >= *dim_size {
2060 return Err(DecoderError::InvalidConfig(format!(
2061 "Invalid shape: Yolo num_features {} must be greater than {}",
2062 *dim_size,
2063 protos + 4,
2064 )));
2065 }
2066 return Ok(*dim_size - 4 - protos);
2067 }
2068 }
2069
2070 if let Some(num_anchors) = anchors {
2073 for (dim_name, dim_size) in dshape {
2074 if *dim_name == DimName::NumAnchorsXFeatures {
2075 let anchors_x_features = *dim_size;
2076 if anchors_x_features <= num_anchors * 5 {
2077 return Err(DecoderError::InvalidConfig(format!(
2078 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2079 anchors_x_features,
2080 num_anchors * 5,
2081 )));
2082 }
2083
2084 if !anchors_x_features.is_multiple_of(num_anchors) {
2085 return Err(DecoderError::InvalidConfig(format!(
2086 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2087 anchors_x_features, num_anchors
2088 )));
2089 }
2090 return Ok((anchors_x_features / num_anchors) - 5);
2091 }
2092 }
2093 }
2094 Err(DecoderError::InvalidConfig(
2095 "Cannot determine number of classes from dshape".to_owned(),
2096 ))
2097 }
2098
2099 fn get_protos_count(dshape: &[(DimName, usize)]) -> Option<usize> {
2100 for (dim_name, dim_size) in dshape {
2101 if *dim_name == DimName::NumProtos {
2102 return Some(*dim_size);
2103 }
2104 }
2105 None
2106 }
2107}
2108
/// Decoder bound to one [`ModelType`]; turns raw model output tensors into
/// detection boxes and segmentation masks.
#[derive(Debug, Clone, PartialEq)]
pub struct Decoder {
    /// Model layout that `decode_quantized` and `decode_float` dispatch on.
    model_type: ModelType,
    /// IoU threshold forwarded to every decode routine.
    pub iou_threshold: f32,
    /// Score threshold forwarded to every decode routine.
    pub score_threshold: f32,
    /// Optional NMS setting forwarded to the YOLO decode routines; the
    /// ModelPack routines do not receive it.
    pub nms: Option<configs::Nms>,
    /// Value reported by `normalized_boxes()`.
    // NOTE(review): presumably `None` means "unknown" — confirm against the builder.
    normalized: Option<bool>,
}
2124
/// A dynamically-dimensioned, borrowed view over an integer tensor, tagged
/// with its element type so differently-typed outputs can share one slice.
#[derive(Debug)]
pub enum ArrayViewDQuantized<'a> {
    /// View over `u8` elements.
    UInt8(ArrayViewD<'a, u8>),
    /// View over `i8` elements.
    Int8(ArrayViewD<'a, i8>),
    /// View over `u16` elements.
    UInt16(ArrayViewD<'a, u16>),
    /// View over `i16` elements.
    Int16(ArrayViewD<'a, i16>),
    /// View over `u32` elements.
    UInt32(ArrayViewD<'a, u32>),
    /// View over `i32` elements.
    Int32(ArrayViewD<'a, i32>),
}
2134
2135impl<'a, D> From<ArrayView<'a, u8, D>> for ArrayViewDQuantized<'a>
2136where
2137 D: Dimension,
2138{
2139 fn from(arr: ArrayView<'a, u8, D>) -> Self {
2140 Self::UInt8(arr.into_dyn())
2141 }
2142}
2143
2144impl<'a, D> From<ArrayView<'a, i8, D>> for ArrayViewDQuantized<'a>
2145where
2146 D: Dimension,
2147{
2148 fn from(arr: ArrayView<'a, i8, D>) -> Self {
2149 Self::Int8(arr.into_dyn())
2150 }
2151}
2152
2153impl<'a, D> From<ArrayView<'a, u16, D>> for ArrayViewDQuantized<'a>
2154where
2155 D: Dimension,
2156{
2157 fn from(arr: ArrayView<'a, u16, D>) -> Self {
2158 Self::UInt16(arr.into_dyn())
2159 }
2160}
2161
2162impl<'a, D> From<ArrayView<'a, i16, D>> for ArrayViewDQuantized<'a>
2163where
2164 D: Dimension,
2165{
2166 fn from(arr: ArrayView<'a, i16, D>) -> Self {
2167 Self::Int16(arr.into_dyn())
2168 }
2169}
2170
2171impl<'a, D> From<ArrayView<'a, u32, D>> for ArrayViewDQuantized<'a>
2172where
2173 D: Dimension,
2174{
2175 fn from(arr: ArrayView<'a, u32, D>) -> Self {
2176 Self::UInt32(arr.into_dyn())
2177 }
2178}
2179
2180impl<'a, D> From<ArrayView<'a, i32, D>> for ArrayViewDQuantized<'a>
2181where
2182 D: Dimension,
2183{
2184 fn from(arr: ArrayView<'a, i32, D>) -> Self {
2185 Self::Int32(arr.into_dyn())
2186 }
2187}
2188
2189impl<'a> ArrayViewDQuantized<'a> {
2190 pub fn shape(&self) -> &[usize] {
2204 match self {
2205 ArrayViewDQuantized::UInt8(a) => a.shape(),
2206 ArrayViewDQuantized::Int8(a) => a.shape(),
2207 ArrayViewDQuantized::UInt16(a) => a.shape(),
2208 ArrayViewDQuantized::Int16(a) => a.shape(),
2209 ArrayViewDQuantized::UInt32(a) => a.shape(),
2210 ArrayViewDQuantized::Int32(a) => a.shape(),
2211 }
2212 }
2213}
2214
/// Runs `$body` with `$var` bound to the typed view inside an
/// [`ArrayViewDQuantized`] value `$x`.
///
/// The body is expanded once per variant, so it is type-checked separately
/// for each of the six element types, and the whole invocation evaluates to
/// the value of the executed arm.
macro_rules! with_quantized {
    ($x:expr, $var:ident, $body:expr) => {
        match $x {
            ArrayViewDQuantized::UInt8(x) => {
                // Rebind under the caller-chosen name so `$body` can see it.
                let $var = x;
                $body
            }
            ArrayViewDQuantized::Int8(x) => {
                let $var = x;
                $body
            }
            ArrayViewDQuantized::UInt16(x) => {
                let $var = x;
                $body
            }
            ArrayViewDQuantized::Int16(x) => {
                let $var = x;
                $body
            }
            ArrayViewDQuantized::UInt32(x) => {
                let $var = x;
                $body
            }
            ArrayViewDQuantized::Int32(x) => {
                let $var = x;
                $body
            }
        }
    };
}
2245
2246impl Decoder {
/// Returns the model layout this decoder was configured with.
pub fn model_type(&self) -> &ModelType {
    &self.model_type
}
2268
/// Returns the stored `normalized` flag.
// NOTE(review): presumably reports whether decoded boxes use normalized
// coordinates, with `None` meaning unspecified — confirm against the config docs.
pub fn normalized_boxes(&self) -> Option<bool> {
    self.normalized
}
2297
2298 pub fn decode_quantized(
2348 &self,
2349 outputs: &[ArrayViewDQuantized],
2350 output_boxes: &mut Vec<DetectBox>,
2351 output_masks: &mut Vec<Segmentation>,
2352 ) -> Result<(), DecoderError> {
2353 output_boxes.clear();
2354 output_masks.clear();
2355 match &self.model_type {
2356 ModelType::ModelPackSegDet {
2357 boxes,
2358 scores,
2359 segmentation,
2360 } => {
2361 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
2362 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
2363 }
2364 ModelType::ModelPackSegDetSplit {
2365 detection,
2366 segmentation,
2367 } => {
2368 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
2369 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
2370 }
2371 ModelType::ModelPackDet { boxes, scores } => {
2372 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)
2373 }
2374 ModelType::ModelPackDetSplit { detection } => {
2375 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)
2376 }
2377 ModelType::ModelPackSeg { segmentation } => {
2378 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
2379 }
2380 ModelType::YoloDet { boxes } => {
2381 self.decode_yolo_det_quantized(outputs, boxes, output_boxes)
2382 }
2383 ModelType::YoloSegDet { boxes, protos } => self.decode_yolo_segdet_quantized(
2384 outputs,
2385 boxes,
2386 protos,
2387 output_boxes,
2388 output_masks,
2389 ),
2390 ModelType::YoloSplitDet { boxes, scores } => {
2391 self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)
2392 }
2393 ModelType::YoloSplitSegDet {
2394 boxes,
2395 scores,
2396 mask_coeff,
2397 protos,
2398 } => self.decode_yolo_split_segdet_quantized(
2399 outputs,
2400 boxes,
2401 scores,
2402 mask_coeff,
2403 protos,
2404 output_boxes,
2405 output_masks,
2406 ),
2407 ModelType::YoloEndToEndDet { .. } | ModelType::YoloEndToEndSegDet { .. } => {
2408 Err(DecoderError::InvalidConfig(
2409 "End-to-end models require float decode, not quantized".to_string(),
2410 ))
2411 }
2412 }
2413 }
2414
2415 pub fn decode_float<T>(
2472 &self,
2473 outputs: &[ArrayViewD<T>],
2474 output_boxes: &mut Vec<DetectBox>,
2475 output_masks: &mut Vec<Segmentation>,
2476 ) -> Result<(), DecoderError>
2477 where
2478 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
2479 f32: AsPrimitive<T>,
2480 {
2481 output_boxes.clear();
2482 output_masks.clear();
2483 match &self.model_type {
2484 ModelType::ModelPackSegDet {
2485 boxes,
2486 scores,
2487 segmentation,
2488 } => {
2489 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
2490 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
2491 }
2492 ModelType::ModelPackSegDetSplit {
2493 detection,
2494 segmentation,
2495 } => {
2496 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
2497 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
2498 }
2499 ModelType::ModelPackDet { boxes, scores } => {
2500 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
2501 }
2502 ModelType::ModelPackDetSplit { detection } => {
2503 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
2504 }
2505 ModelType::ModelPackSeg { segmentation } => {
2506 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
2507 }
2508 ModelType::YoloDet { boxes } => {
2509 self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
2510 }
2511 ModelType::YoloSegDet { boxes, protos } => {
2512 self.decode_yolo_segdet_float(outputs, boxes, protos, output_boxes, output_masks)?;
2513 }
2514 ModelType::YoloSplitDet { boxes, scores } => {
2515 self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
2516 }
2517 ModelType::YoloSplitSegDet {
2518 boxes,
2519 scores,
2520 mask_coeff,
2521 protos,
2522 } => {
2523 self.decode_yolo_split_segdet_float(
2524 outputs,
2525 boxes,
2526 scores,
2527 mask_coeff,
2528 protos,
2529 output_boxes,
2530 output_masks,
2531 )?;
2532 }
2533 ModelType::YoloEndToEndDet { boxes } => {
2534 self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
2535 }
2536 ModelType::YoloEndToEndSegDet { boxes, protos } => {
2537 self.decode_yolo_end_to_end_segdet_float(
2538 outputs,
2539 boxes,
2540 protos,
2541 output_boxes,
2542 output_masks,
2543 )?;
2544 }
2545 }
2546 Ok(())
2547 }
2548
2549 fn decode_modelpack_det_quantized(
2550 &self,
2551 outputs: &[ArrayViewDQuantized],
2552 boxes: &configs::Boxes,
2553 scores: &configs::Scores,
2554 output_boxes: &mut Vec<DetectBox>,
2555 ) -> Result<(), DecoderError> {
2556 let (boxes_tensor, ind) =
2557 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
2558 let (scores_tensor, _) =
2559 Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &[ind])?;
2560 let quant_boxes = boxes
2561 .quantization
2562 .map(Quantization::from)
2563 .unwrap_or_default();
2564 let quant_scores = scores
2565 .quantization
2566 .map(Quantization::from)
2567 .unwrap_or_default();
2568
2569 with_quantized!(boxes_tensor, b, {
2570 with_quantized!(scores_tensor, s, {
2571 let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
2572 let boxes_tensor = boxes_tensor.slice(s![0, .., 0, ..]);
2573
2574 let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
2575 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
2576 decode_modelpack_det(
2577 (boxes_tensor, quant_boxes),
2578 (scores_tensor, quant_scores),
2579 self.score_threshold,
2580 self.iou_threshold,
2581 output_boxes,
2582 );
2583 });
2584 });
2585
2586 Ok(())
2587 }
2588
2589 fn decode_modelpack_seg_quantized(
2590 &self,
2591 outputs: &[ArrayViewDQuantized],
2592 segmentation: &configs::Segmentation,
2593 output_masks: &mut Vec<Segmentation>,
2594 ) -> Result<(), DecoderError> {
2595 let (seg, _) = Self::find_outputs_with_shape_quantized(&segmentation.shape, outputs, &[])?;
2596
2597 macro_rules! modelpack_seg {
2598 ($seg:expr, $body:expr) => {{
2599 let seg = Self::swap_axes_if_needed($seg, segmentation.into());
2600 let seg = seg.slice(s![0, .., .., ..]);
2601 seg.mapv($body)
2602 }};
2603 }
2604 use ArrayViewDQuantized::*;
2605 let seg = match seg {
2606 UInt8(s) => {
2607 modelpack_seg!(s, |x| x)
2608 }
2609 Int8(s) => {
2610 modelpack_seg!(s, |x| (x as i16 + 128) as u8)
2611 }
2612 UInt16(s) => {
2613 modelpack_seg!(s, |x| (x >> 8) as u8)
2614 }
2615 Int16(s) => {
2616 modelpack_seg!(s, |x| ((x as i32 + 32768) >> 8) as u8)
2617 }
2618 UInt32(s) => {
2619 modelpack_seg!(s, |x| (x >> 24) as u8)
2620 }
2621 Int32(s) => {
2622 modelpack_seg!(s, |x| ((x as i64 + 2147483648) >> 24) as u8)
2623 }
2624 };
2625
2626 output_masks.push(Segmentation {
2627 xmin: 0.0,
2628 ymin: 0.0,
2629 xmax: 1.0,
2630 ymax: 1.0,
2631 segmentation: seg,
2632 });
2633 Ok(())
2634 }
2635
2636 fn decode_modelpack_det_split_quantized(
2637 &self,
2638 outputs: &[ArrayViewDQuantized],
2639 detection: &[configs::Detection],
2640 output_boxes: &mut Vec<DetectBox>,
2641 ) -> Result<(), DecoderError> {
2642 let new_detection = detection
2643 .iter()
2644 .map(|x| match &x.anchors {
2645 None => Err(DecoderError::InvalidConfig(
2646 "ModelPack Split Detection missing anchors".to_string(),
2647 )),
2648 Some(a) => Ok(ModelPackDetectionConfig {
2649 anchors: a.clone(),
2650 quantization: None,
2651 }),
2652 })
2653 .collect::<Result<Vec<_>, _>>()?;
2654 let new_outputs = Self::match_outputs_to_detect_quantized(detection, outputs)?;
2655
2656 macro_rules! dequant_output {
2657 ($det_tensor:expr, $detection:expr) => {{
2658 let det_tensor = Self::swap_axes_if_needed($det_tensor, $detection.into());
2659 let det_tensor = det_tensor.slice(s![0, .., .., ..]);
2660 if let Some(q) = $detection.quantization {
2661 dequantize_ndarray(det_tensor, q.into())
2662 } else {
2663 det_tensor.map(|x| *x as f32)
2664 }
2665 }};
2666 }
2667
2668 let new_outputs = new_outputs
2669 .iter()
2670 .zip(detection)
2671 .map(|(det_tensor, detection)| {
2672 with_quantized!(det_tensor, d, dequant_output!(d, detection))
2673 })
2674 .collect::<Vec<_>>();
2675
2676 let new_outputs_view = new_outputs
2677 .iter()
2678 .map(|d: &Array3<f32>| d.view())
2679 .collect::<Vec<_>>();
2680 decode_modelpack_split_float(
2681 &new_outputs_view,
2682 &new_detection,
2683 self.score_threshold,
2684 self.iou_threshold,
2685 output_boxes,
2686 );
2687 Ok(())
2688 }
2689
2690 fn decode_yolo_det_quantized(
2691 &self,
2692 outputs: &[ArrayViewDQuantized],
2693 boxes: &configs::Detection,
2694 output_boxes: &mut Vec<DetectBox>,
2695 ) -> Result<(), DecoderError> {
2696 let (boxes_tensor, _) =
2697 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
2698 let quant_boxes = boxes
2699 .quantization
2700 .map(Quantization::from)
2701 .unwrap_or_default();
2702
2703 with_quantized!(boxes_tensor, b, {
2704 let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
2705 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
2706 decode_yolo_det(
2707 (boxes_tensor, quant_boxes),
2708 self.score_threshold,
2709 self.iou_threshold,
2710 self.nms,
2711 output_boxes,
2712 );
2713 });
2714
2715 Ok(())
2716 }
2717
2718 fn decode_yolo_segdet_quantized(
2719 &self,
2720 outputs: &[ArrayViewDQuantized],
2721 boxes: &configs::Detection,
2722 protos: &configs::Protos,
2723 output_boxes: &mut Vec<DetectBox>,
2724 output_masks: &mut Vec<Segmentation>,
2725 ) -> Result<(), DecoderError> {
2726 let (boxes_tensor, ind) =
2727 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
2728 let (protos_tensor, _) =
2729 Self::find_outputs_with_shape_quantized(&protos.shape, outputs, &[ind])?;
2730
2731 let quant_boxes = boxes
2732 .quantization
2733 .map(Quantization::from)
2734 .unwrap_or_default();
2735 let quant_protos = protos
2736 .quantization
2737 .map(Quantization::from)
2738 .unwrap_or_default();
2739
2740 with_quantized!(boxes_tensor, b, {
2741 with_quantized!(protos_tensor, p, {
2742 let box_tensor = Self::swap_axes_if_needed(b, boxes.into());
2743 let box_tensor = box_tensor.slice(s![0, .., ..]);
2744
2745 let protos_tensor = Self::swap_axes_if_needed(p, protos.into());
2746 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
2747 decode_yolo_segdet_quant(
2748 (box_tensor, quant_boxes),
2749 (protos_tensor, quant_protos),
2750 self.score_threshold,
2751 self.iou_threshold,
2752 self.nms,
2753 output_boxes,
2754 output_masks,
2755 );
2756 });
2757 });
2758
2759 Ok(())
2760 }
2761
2762 fn decode_yolo_split_det_quantized(
2763 &self,
2764 outputs: &[ArrayViewDQuantized],
2765 boxes: &configs::Boxes,
2766 scores: &configs::Scores,
2767 output_boxes: &mut Vec<DetectBox>,
2768 ) -> Result<(), DecoderError> {
2769 let (boxes_tensor, ind) =
2770 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
2771 let (scores_tensor, _) =
2772 Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &[ind])?;
2773 let quant_boxes = boxes
2774 .quantization
2775 .map(Quantization::from)
2776 .unwrap_or_default();
2777 let quant_scores = scores
2778 .quantization
2779 .map(Quantization::from)
2780 .unwrap_or_default();
2781
2782 with_quantized!(boxes_tensor, b, {
2783 with_quantized!(scores_tensor, s, {
2784 let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
2785 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
2786
2787 let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
2788 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
2789 decode_yolo_split_det_quant(
2790 (boxes_tensor, quant_boxes),
2791 (scores_tensor, quant_scores),
2792 self.score_threshold,
2793 self.iou_threshold,
2794 self.nms,
2795 output_boxes,
2796 );
2797 });
2798 });
2799
2800 Ok(())
2801 }
2802
2803 #[allow(clippy::too_many_arguments)]
2804 fn decode_yolo_split_segdet_quantized(
2805 &self,
2806 outputs: &[ArrayViewDQuantized],
2807 boxes: &configs::Boxes,
2808 scores: &configs::Scores,
2809 mask_coeff: &configs::MaskCoefficients,
2810 protos: &configs::Protos,
2811 output_boxes: &mut Vec<DetectBox>,
2812 output_masks: &mut Vec<Segmentation>,
2813 ) -> Result<(), DecoderError> {
2814 let quant_boxes = boxes
2815 .quantization
2816 .map(Quantization::from)
2817 .unwrap_or_default();
2818 let quant_scores = scores
2819 .quantization
2820 .map(Quantization::from)
2821 .unwrap_or_default();
2822 let quant_masks = mask_coeff
2823 .quantization
2824 .map(Quantization::from)
2825 .unwrap_or_default();
2826 let quant_protos = protos
2827 .quantization
2828 .map(Quantization::from)
2829 .unwrap_or_default();
2830
2831 let mut skip = vec![];
2832
2833 let (boxes_tensor, ind) =
2834 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &skip)?;
2835 skip.push(ind);
2836
2837 let (scores_tensor, ind) =
2838 Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &skip)?;
2839 skip.push(ind);
2840
2841 let (mask_tensor, ind) =
2842 Self::find_outputs_with_shape_quantized(&mask_coeff.shape, outputs, &skip)?;
2843 skip.push(ind);
2844
2845 let (protos_tensor, _) =
2846 Self::find_outputs_with_shape_quantized(&protos.shape, outputs, &skip)?;
2847
2848 let boxes = with_quantized!(boxes_tensor, b, {
2849 with_quantized!(scores_tensor, s, {
2850 let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
2851 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
2852
2853 let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
2854 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
2855 impl_yolo_split_segdet_quant_get_boxes::<XYWH, _, _>(
2856 (boxes_tensor, quant_boxes),
2857 (scores_tensor, quant_scores),
2858 self.score_threshold,
2859 self.iou_threshold,
2860 self.nms,
2861 output_boxes.capacity(),
2862 )
2863 })
2864 });
2865
2866 with_quantized!(mask_tensor, m, {
2867 with_quantized!(protos_tensor, p, {
2868 let mask_tensor = Self::swap_axes_if_needed(m, mask_coeff.into());
2869 let mask_tensor = mask_tensor.slice(s![0, .., ..]);
2870
2871 let protos_tensor = Self::swap_axes_if_needed(p, protos.into());
2872 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
2873 impl_yolo_split_segdet_quant_process_masks::<_, _>(
2874 boxes,
2875 (mask_tensor, quant_masks),
2876 (protos_tensor, quant_protos),
2877 output_boxes,
2878 output_masks,
2879 )
2880 })
2881 });
2882
2883 Ok(())
2884 }
2885
2886 fn decode_modelpack_det_split_float<D>(
2887 &self,
2888 outputs: &[ArrayViewD<D>],
2889 detection: &[configs::Detection],
2890 output_boxes: &mut Vec<DetectBox>,
2891 ) -> Result<(), DecoderError>
2892 where
2893 D: AsPrimitive<f32>,
2894 {
2895 let new_detection = detection
2896 .iter()
2897 .map(|x| match &x.anchors {
2898 None => Err(DecoderError::InvalidConfig(
2899 "ModelPack Split Detection missing anchors".to_string(),
2900 )),
2901 Some(a) => Ok(ModelPackDetectionConfig {
2902 anchors: a.clone(),
2903 quantization: None,
2904 }),
2905 })
2906 .collect::<Result<Vec<_>, _>>()?;
2907
2908 let new_outputs = Self::match_outputs_to_detect(detection, outputs)?;
2909 let new_outputs = new_outputs
2910 .into_iter()
2911 .map(|x| x.slice(s![0, .., .., ..]))
2912 .collect::<Vec<_>>();
2913
2914 decode_modelpack_split_float(
2915 &new_outputs,
2916 &new_detection,
2917 self.score_threshold,
2918 self.iou_threshold,
2919 output_boxes,
2920 );
2921 Ok(())
2922 }
2923
2924 fn decode_modelpack_seg_float<T>(
2925 &self,
2926 outputs: &[ArrayViewD<T>],
2927 segmentation: &configs::Segmentation,
2928 output_masks: &mut Vec<Segmentation>,
2929 ) -> Result<(), DecoderError>
2930 where
2931 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
2932 f32: AsPrimitive<T>,
2933 {
2934 let (seg, _) = Self::find_outputs_with_shape(&segmentation.shape, outputs, &[])?;
2935
2936 let seg = Self::swap_axes_if_needed(seg, segmentation.into());
2937 let seg = seg.slice(s![0, .., .., ..]);
2938 let u8_max = 255.0_f32.as_();
2939 let max = *seg.max().unwrap_or(&u8_max);
2940 let min = *seg.min().unwrap_or(&0.0_f32.as_());
2941 let seg = seg.mapv(|x| ((x - min) / (max - min) * u8_max).as_());
2942 output_masks.push(Segmentation {
2943 xmin: 0.0,
2944 ymin: 0.0,
2945 xmax: 1.0,
2946 ymax: 1.0,
2947 segmentation: seg,
2948 });
2949 Ok(())
2950 }
2951
2952 fn decode_modelpack_det_float<T>(
2953 &self,
2954 outputs: &[ArrayViewD<T>],
2955 boxes: &configs::Boxes,
2956 scores: &configs::Scores,
2957 output_boxes: &mut Vec<DetectBox>,
2958 ) -> Result<(), DecoderError>
2959 where
2960 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
2961 f32: AsPrimitive<T>,
2962 {
2963 let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
2964
2965 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
2966 let boxes_tensor = boxes_tensor.slice(s![0, .., 0, ..]);
2967
2968 let (scores_tensor, _) = Self::find_outputs_with_shape(&scores.shape, outputs, &[ind])?;
2969 let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
2970 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
2971
2972 decode_modelpack_float(
2973 boxes_tensor,
2974 scores_tensor,
2975 self.score_threshold,
2976 self.iou_threshold,
2977 output_boxes,
2978 );
2979 Ok(())
2980 }
2981
2982 fn decode_yolo_det_float<T>(
2983 &self,
2984 outputs: &[ArrayViewD<T>],
2985 boxes: &configs::Detection,
2986 output_boxes: &mut Vec<DetectBox>,
2987 ) -> Result<(), DecoderError>
2988 where
2989 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
2990 f32: AsPrimitive<T>,
2991 {
2992 let (boxes_tensor, _) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
2993
2994 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
2995 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
2996 decode_yolo_det_float(
2997 boxes_tensor,
2998 self.score_threshold,
2999 self.iou_threshold,
3000 self.nms,
3001 output_boxes,
3002 );
3003 Ok(())
3004 }
3005
3006 fn decode_yolo_segdet_float<T>(
3007 &self,
3008 outputs: &[ArrayViewD<T>],
3009 boxes: &configs::Detection,
3010 protos: &configs::Protos,
3011 output_boxes: &mut Vec<DetectBox>,
3012 output_masks: &mut Vec<Segmentation>,
3013 ) -> Result<(), DecoderError>
3014 where
3015 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3016 f32: AsPrimitive<T>,
3017 {
3018 let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
3019
3020 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3021 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3022
3023 let (protos_tensor, _) = Self::find_outputs_with_shape(&protos.shape, outputs, &[ind])?;
3024
3025 let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos.into());
3026 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3027 decode_yolo_segdet_float(
3028 boxes_tensor,
3029 protos_tensor,
3030 self.score_threshold,
3031 self.iou_threshold,
3032 self.nms,
3033 output_boxes,
3034 output_masks,
3035 );
3036 Ok(())
3037 }
3038
3039 fn decode_yolo_split_det_float<T>(
3040 &self,
3041 outputs: &[ArrayViewD<T>],
3042 boxes: &configs::Boxes,
3043 scores: &configs::Scores,
3044 output_boxes: &mut Vec<DetectBox>,
3045 ) -> Result<(), DecoderError>
3046 where
3047 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3048 f32: AsPrimitive<T>,
3049 {
3050 let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
3051 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3052 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3053
3054 let (scores_tensor, _) = Self::find_outputs_with_shape(&scores.shape, outputs, &[ind])?;
3055
3056 let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
3057 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3058
3059 decode_yolo_split_det_float(
3060 boxes_tensor,
3061 scores_tensor,
3062 self.score_threshold,
3063 self.iou_threshold,
3064 self.nms,
3065 output_boxes,
3066 );
3067 Ok(())
3068 }
3069
3070 #[allow(clippy::too_many_arguments)]
3071 fn decode_yolo_split_segdet_float<T>(
3072 &self,
3073 outputs: &[ArrayViewD<T>],
3074 boxes: &configs::Boxes,
3075 scores: &configs::Scores,
3076 mask_coeff: &configs::MaskCoefficients,
3077 protos: &configs::Protos,
3078 output_boxes: &mut Vec<DetectBox>,
3079 output_masks: &mut Vec<Segmentation>,
3080 ) -> Result<(), DecoderError>
3081 where
3082 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3083 f32: AsPrimitive<T>,
3084 {
3085 let mut skip = vec![];
3086 let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &skip)?;
3087
3088 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3089 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3090 skip.push(ind);
3091
3092 let (scores_tensor, ind) = Self::find_outputs_with_shape(&scores.shape, outputs, &skip)?;
3093
3094 let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
3095 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3096 skip.push(ind);
3097
3098 let (mask_tensor, ind) = Self::find_outputs_with_shape(&mask_coeff.shape, outputs, &skip)?;
3099 let mask_tensor = Self::swap_axes_if_needed(mask_tensor, mask_coeff.into());
3100 let mask_tensor = mask_tensor.slice(s![0, .., ..]);
3101 skip.push(ind);
3102
3103 let (protos_tensor, _) = Self::find_outputs_with_shape(&protos.shape, outputs, &skip)?;
3104 let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos.into());
3105 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3106 decode_yolo_split_segdet_float(
3107 boxes_tensor,
3108 scores_tensor,
3109 mask_tensor,
3110 protos_tensor,
3111 self.score_threshold,
3112 self.iou_threshold,
3113 self.nms,
3114 output_boxes,
3115 output_masks,
3116 );
3117 Ok(())
3118 }
3119
3120 fn decode_yolo_end_to_end_det_float<T>(
3126 &self,
3127 outputs: &[ArrayViewD<T>],
3128 boxes_config: &configs::Detection,
3129 output_boxes: &mut Vec<DetectBox>,
3130 ) -> Result<(), DecoderError>
3131 where
3132 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3133 f32: AsPrimitive<T>,
3134 {
3135 let (det_tensor, _) = Self::find_outputs_with_shape(&boxes_config.shape, outputs, &[])?;
3136 let det_tensor = Self::swap_axes_if_needed(det_tensor, boxes_config.into());
3137 let det_tensor = det_tensor.slice(s![0, .., ..]);
3138
3139 crate::yolo::decode_yolo_end_to_end_det_float(
3140 det_tensor,
3141 self.score_threshold,
3142 output_boxes,
3143 )?;
3144 Ok(())
3145 }
3146
3147 fn decode_yolo_end_to_end_segdet_float<T>(
3155 &self,
3156 outputs: &[ArrayViewD<T>],
3157 boxes_config: &configs::Detection,
3158 protos_config: &configs::Protos,
3159 output_boxes: &mut Vec<DetectBox>,
3160 output_masks: &mut Vec<Segmentation>,
3161 ) -> Result<(), DecoderError>
3162 where
3163 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3164 f32: AsPrimitive<T>,
3165 {
3166 if outputs.len() < 2 {
3167 return Err(DecoderError::InvalidShape(
3168 "End-to-end segdet requires detection and protos outputs".to_string(),
3169 ));
3170 }
3171
3172 let (det_tensor, det_ind) =
3173 Self::find_outputs_with_shape(&boxes_config.shape, outputs, &[])?;
3174 let det_tensor = Self::swap_axes_if_needed(det_tensor, boxes_config.into());
3175 let det_tensor = det_tensor.slice(s![0, .., ..]);
3176
3177 let (protos_tensor, _) =
3178 Self::find_outputs_with_shape(&protos_config.shape, outputs, &[det_ind])?;
3179 let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos_config.into());
3180 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3181
3182 crate::yolo::decode_yolo_end_to_end_segdet_float(
3183 det_tensor,
3184 protos_tensor,
3185 self.score_threshold,
3186 output_boxes,
3187 output_masks,
3188 )?;
3189 Ok(())
3190 }
3191
3192 fn match_outputs_to_detect<'a, 'b, T>(
3193 configs: &[configs::Detection],
3194 outputs: &'a [ArrayViewD<'b, T>],
3195 ) -> Result<Vec<&'a ArrayViewD<'b, T>>, DecoderError> {
3196 let mut new_output_order = Vec::new();
3197 for c in configs {
3198 let mut found = false;
3199 for o in outputs {
3200 if o.shape() == c.shape {
3201 new_output_order.push(o);
3202 found = true;
3203 break;
3204 }
3205 }
3206 if !found {
3207 return Err(DecoderError::InvalidShape(format!(
3208 "Did not find output with shape {:?}",
3209 c.shape
3210 )));
3211 }
3212 }
3213 Ok(new_output_order)
3214 }
3215
3216 fn find_outputs_with_shape<'a, 'b, T>(
3217 shape: &[usize],
3218 outputs: &'a [ArrayViewD<'b, T>],
3219 skip: &[usize],
3220 ) -> Result<(&'a ArrayViewD<'b, T>, usize), DecoderError> {
3221 for (ind, o) in outputs.iter().enumerate() {
3222 if skip.contains(&ind) {
3223 continue;
3224 }
3225 if o.shape() == shape {
3226 return Ok((o, ind));
3227 }
3228 }
3229 Err(DecoderError::InvalidShape(format!(
3230 "Did not find output with shape {:?}",
3231 shape
3232 )))
3233 }
3234
3235 fn find_outputs_with_shape_quantized<'a, 'b>(
3236 shape: &[usize],
3237 outputs: &'a [ArrayViewDQuantized<'b>],
3238 skip: &[usize],
3239 ) -> Result<(&'a ArrayViewDQuantized<'b>, usize), DecoderError> {
3240 for (ind, o) in outputs.iter().enumerate() {
3241 if skip.contains(&ind) {
3242 continue;
3243 }
3244 if o.shape() == shape {
3245 return Ok((o, ind));
3246 }
3247 }
3248 Err(DecoderError::InvalidShape(format!(
3249 "Did not find output with shape {:?}",
3250 shape
3251 )))
3252 }
3253
3254 fn modelpack_det_order(x: DimName) -> usize {
3257 match x {
3258 DimName::Batch => 0,
3259 DimName::NumBoxes => 1,
3260 DimName::Padding => 2,
3261 DimName::BoxCoords => 3,
3262 _ => 1000, }
3264 }
3265
3266 fn yolo_det_order(x: DimName) -> usize {
3269 match x {
3270 DimName::Batch => 0,
3271 DimName::NumFeatures => 1,
3272 DimName::NumBoxes => 2,
3273 _ => 1000, }
3275 }
3276
3277 fn modelpack_boxes_order(x: DimName) -> usize {
3280 match x {
3281 DimName::Batch => 0,
3282 DimName::NumBoxes => 1,
3283 DimName::Padding => 2,
3284 DimName::BoxCoords => 3,
3285 _ => 1000, }
3287 }
3288
3289 fn yolo_boxes_order(x: DimName) -> usize {
3292 match x {
3293 DimName::Batch => 0,
3294 DimName::BoxCoords => 1,
3295 DimName::NumBoxes => 2,
3296 _ => 1000, }
3298 }
3299
3300 fn modelpack_scores_order(x: DimName) -> usize {
3303 match x {
3304 DimName::Batch => 0,
3305 DimName::NumBoxes => 1,
3306 DimName::NumClasses => 2,
3307 _ => 1000, }
3309 }
3310
3311 fn yolo_scores_order(x: DimName) -> usize {
3312 match x {
3313 DimName::Batch => 0,
3314 DimName::NumClasses => 1,
3315 DimName::NumBoxes => 2,
3316 _ => 1000, }
3318 }
3319
3320 fn modelpack_segmentation_order(x: DimName) -> usize {
3323 match x {
3324 DimName::Batch => 0,
3325 DimName::Height => 1,
3326 DimName::Width => 2,
3327 DimName::NumClasses => 3,
3328 _ => 1000, }
3330 }
3331
3332 fn modelpack_mask_order(x: DimName) -> usize {
3335 match x {
3336 DimName::Batch => 0,
3337 DimName::Height => 1,
3338 DimName::Width => 2,
3339 _ => 1000, }
3341 }
3342
3343 fn yolo_protos_order(x: DimName) -> usize {
3346 match x {
3347 DimName::Batch => 0,
3348 DimName::Height => 1,
3349 DimName::Width => 2,
3350 DimName::NumProtos => 3,
3351 _ => 1000, }
3353 }
3354
3355 fn yolo_maskcoefficients_order(x: DimName) -> usize {
3358 match x {
3359 DimName::Batch => 0,
3360 DimName::NumProtos => 1,
3361 DimName::NumBoxes => 2,
3362 _ => 1000, }
3364 }
3365
3366 fn get_order_fn(config: ConfigOutputRef) -> fn(DimName) -> usize {
3367 let decoder_type = config.decoder();
3368 match (config, decoder_type) {
3369 (ConfigOutputRef::Detection(_), DecoderType::ModelPack) => Self::modelpack_det_order,
3370 (ConfigOutputRef::Detection(_), DecoderType::Ultralytics) => Self::yolo_det_order,
3371 (ConfigOutputRef::Boxes(_), DecoderType::ModelPack) => Self::modelpack_boxes_order,
3372 (ConfigOutputRef::Boxes(_), DecoderType::Ultralytics) => Self::yolo_boxes_order,
3373 (ConfigOutputRef::Scores(_), DecoderType::ModelPack) => Self::modelpack_scores_order,
3374 (ConfigOutputRef::Scores(_), DecoderType::Ultralytics) => Self::yolo_scores_order,
3375 (ConfigOutputRef::Segmentation(_), _) => Self::modelpack_segmentation_order,
3376 (ConfigOutputRef::Mask(_), _) => Self::modelpack_mask_order,
3377 (ConfigOutputRef::Protos(_), _) => Self::yolo_protos_order,
3378 (ConfigOutputRef::MaskCoefficients(_), _) => Self::yolo_maskcoefficients_order,
3379 }
3380 }
3381
3382 fn swap_axes_if_needed<'a, T, D: Dimension>(
3383 array: &ArrayView<'a, T, D>,
3384 config: ConfigOutputRef,
3385 ) -> ArrayView<'a, T, D> {
3386 let mut array = array.clone();
3387 if config.dshape().is_empty() {
3388 return array;
3389 }
3390 let order_fn: fn(DimName) -> usize = Self::get_order_fn(config.clone());
3391 let mut current_order: Vec<usize> = config
3392 .dshape()
3393 .iter()
3394 .map(|x| order_fn(x.0))
3395 .collect::<Vec<_>>();
3396
3397 assert_eq!(array.shape().len(), current_order.len());
3398 for i in 0..current_order.len() {
3401 let mut swapped = false;
3402 for j in 0..current_order.len() - 1 - i {
3403 if current_order[j] > current_order[j + 1] {
3404 array.swap_axes(j, j + 1);
3405 current_order.swap(j, j + 1);
3406 swapped = true;
3407 }
3408 }
3409 if !swapped {
3410 break;
3411 }
3412 }
3413 array
3414 }
3415
3416 fn match_outputs_to_detect_quantized<'a, 'b>(
3417 configs: &[configs::Detection],
3418 outputs: &'a [ArrayViewDQuantized<'b>],
3419 ) -> Result<Vec<&'a ArrayViewDQuantized<'b>>, DecoderError> {
3420 let mut new_output_order = Vec::new();
3421 for c in configs {
3422 let mut found = false;
3423 for o in outputs {
3424 if o.shape() == c.shape {
3425 new_output_order.push(o);
3426 found = true;
3427 break;
3428 }
3429 }
3430 if !found {
3431 return Err(DecoderError::InvalidShape(format!(
3432 "Did not find output with shape {:?}",
3433 c.shape
3434 )));
3435 }
3436 }
3437 Ok(new_output_order)
3438 }
3439}
3440
3441#[cfg(test)]
3442#[cfg_attr(coverage_nightly, coverage(off))]
3443mod decoder_builder_tests {
3444 use super::*;
3445
/// A builder that was never given a configuration must fail with `NoConfig`.
#[test]
fn test_decoder_builder_no_config() {
    use crate::DecoderBuilder;

    let outcome = DecoderBuilder::default().build();

    assert!(matches!(outcome, Err(DecoderError::NoConfig)));
}
3452
/// A configuration with an empty output list must be rejected as invalid.
#[test]
fn test_decoder_builder_empty_config() {
    use crate::DecoderBuilder;

    let empty_config = ConfigOutputs {
        outputs: Vec::new(),
        ..Default::default()
    };
    let outcome = DecoderBuilder::default().with_config(empty_config).build();

    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg == "No outputs found in config"
    ));
}
3466
/// A YAML document whose output entry lacks the fields needed to deserialize
/// it must surface as a `Yaml` error when the decoder is built.
#[test]
fn test_malformed_config_yaml() {
    let yaml_text = String::from(
        "
    model_type: yolov8_det
    outputs:
      - shape: [1, 84, 8400]
    ",
    );

    let outcome = DecoderBuilder::new()
        .with_config_yaml_str(yaml_text)
        .build();

    assert!(matches!(outcome, Err(DecoderError::Yaml(_))));
}
3480
/// A JSON document whose output entry lacks the fields needed to deserialize
/// it must surface as a `Json` error when the decoder is built.
#[test]
fn test_malformed_config_json() {
    let malformed_json = String::from(
        r#"
    {
        "model_type": "yolov8_det",
        "outputs": [
            {
                "shape": [1, 84, 8400]
            }
        ]
    }"#,
    );

    let outcome = DecoderBuilder::new()
        .with_config_json_str(malformed_json)
        .build();

    assert!(matches!(outcome, Err(DecoderError::Json(_))));
}
3498
/// Mixing an Ultralytics-typed boxes output with a ModelPack-typed scores
/// output in one configuration must be rejected.
#[test]
fn test_modelpack_and_yolo_config_error() {
    // Boxes claim the Ultralytics decoder ...
    let ultralytics_boxes = configs::Boxes {
        decoder: DecoderType::Ultralytics,
        shape: vec![1, 4, 8400],
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::BoxCoords, 4),
            (DimName::NumBoxes, 8400),
        ],
        normalized: Some(true),
    };
    // ... while scores claim the ModelPack decoder.
    let modelpack_scores = configs::Scores {
        decoder: DecoderType::ModelPack,
        shape: vec![1, 80, 8400],
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumClasses, 80),
            (DimName::NumBoxes, 8400),
        ],
    };

    let outcome = DecoderBuilder::new()
        .with_config_modelpack_det(ultralytics_boxes, modelpack_scores)
        .build();

    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg))
            if msg == "Both ModelPack and Yolo outputs found in config"
    ));
}
3531
/// A YOLO segdet config whose detection output has four dimensions must be
/// rejected with an "Invalid Yolo Detection shape" error.
#[test]
fn test_yolo_invalid_seg_shape() {
    // Detection output with one dimension too many.
    let four_dim_detection = configs::Detection {
        decoder: DecoderType::Ultralytics,
        shape: vec![1, 85, 8400, 1],
        quantization: None,
        anchors: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumFeatures, 85),
            (DimName::NumBoxes, 8400),
            (DimName::Batch, 1),
        ],
        normalized: Some(true),
    };
    let protos = configs::Protos {
        decoder: DecoderType::Ultralytics,
        shape: vec![1, 32, 160, 160],
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumProtos, 32),
            (DimName::Height, 160),
            (DimName::Width, 160),
        ],
    };

    let outcome = DecoderBuilder::new()
        .with_config_yolo_segdet(four_dim_detection, protos, Some(DecoderVersion::Yolo11))
        .build();

    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg)) if msg.starts_with("Invalid Yolo Detection shape")
    ));
}
3568
/// A `Mask` output tagged with the Ultralytics decoder must be rejected.
#[test]
fn test_yolo_invalid_mask() {
    let yolo_mask = configs::Mask {
        shape: vec![1, 160, 160, 1],
        decoder: DecoderType::Ultralytics,
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 160),
            (DimName::Width, 160),
            (DimName::NumFeatures, 1),
        ],
    };
    let config = ConfigOutputs {
        outputs: vec![ConfigOutput::Mask(yolo_mask)],
        ..Default::default()
    };

    let outcome = DecoderBuilder::new().with_config(config).build();

    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg))
            if msg.starts_with("Invalid Mask output with Yolo decoder")
    ));
}
3592
/// A `Segmentation` output tagged with the Ultralytics decoder must be
/// rejected.
#[test]
fn test_yolo_invalid_outputs() {
    let yolo_segmentation = configs::Segmentation {
        shape: vec![1, 84, 8400],
        decoder: DecoderType::Ultralytics,
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumFeatures, 84),
            (DimName::NumBoxes, 8400),
        ],
    };
    let config = ConfigOutputs {
        outputs: vec![ConfigOutput::Segmentation(yolo_segmentation)],
        ..Default::default()
    };

    let outcome = DecoderBuilder::new().with_config(config).build();

    assert!(matches!(
        outcome,
        Err(DecoderError::InvalidConfig(msg))
            if msg == "Invalid Segmentation output with Yolo decoder"
    ));
}
3615
/// Invalid YOLO detection configs must be rejected: a four-dimensional shape,
/// and feature counts that are too small (with and without named dimensions).
#[test]
fn test_yolo_invalid_det() {
    // Builds an Ultralytics detection config that differs only in its shape
    // and named dimensions.
    let detection = |shape, dshape| configs::Detection {
        anchors: None,
        decoder: DecoderType::Ultralytics,
        quantization: None,
        shape,
        dshape,
        normalized: Some(true),
    };

    // Case 1: four dimensions instead of three.
    let result = DecoderBuilder::new()
        .with_config_yolo_det(
            detection(
                vec![1, 84, 8400, 1],
                vec![
                    (DimName::Batch, 1),
                    (DimName::NumFeatures, 84),
                    (DimName::NumBoxes, 8400),
                    (DimName::Batch, 1),
                ],
            ),
            Some(DecoderVersion::Yolo11),
        )
        .build();
    assert!(matches!(
        result,
        Err(DecoderError::InvalidConfig(msg)) if msg.starts_with("Invalid Yolo Detection shape")
    ));

    // Case 2: only three features, with named dimensions (boxes-first layout).
    let result = DecoderBuilder::new()
        .with_config_yolo_det(
            detection(
                vec![1, 8400, 3],
                vec![
                    (DimName::Batch, 1),
                    (DimName::NumBoxes, 8400),
                    (DimName::NumFeatures, 3),
                ],
            ),
            Some(DecoderVersion::Yolo11),
        )
        .build();
    assert!(
        matches!(
            &result,
            Err(DecoderError::InvalidConfig(msg))
                if msg.starts_with("Invalid shape: Yolo num_features 3 must be greater than 4")
        ),
        "{}",
        result.unwrap_err()
    );

    // Case 3: only three features, without named dimensions.
    let result = DecoderBuilder::new()
        .with_config_yolo_det(
            detection(vec![1, 3, 8400], Vec::new()),
            Some(DecoderVersion::Yolo11),
        )
        .build();
    assert!(matches!(
        result,
        Err(DecoderError::InvalidConfig(msg))
            if msg.starts_with("Invalid shape: Yolo num_features 3 must be greater than 4")
    ));
}
3682
/// Invalid YOLO segdet configs must be rejected: a four-dimensional detection
/// shape, a five-dimensional protos shape, and a detection output with too
/// few features for the number of prototypes.
#[test]
fn test_yolo_invalid_segdet() {
    // Builders for Ultralytics configs that differ only in shape and named
    // dimensions.
    let detection = |shape, dshape| configs::Detection {
        decoder: DecoderType::Ultralytics,
        shape,
        quantization: None,
        anchors: None,
        dshape,
        normalized: Some(true),
    };
    let protos = |shape, dshape| configs::Protos {
        decoder: DecoderType::Ultralytics,
        shape,
        quantization: None,
        dshape,
    };
    // Well-formed prototype output shared by the cases that break detection.
    let valid_protos = || {
        protos(
            vec![1, 32, 160, 160],
            vec![
                (DimName::Batch, 1),
                (DimName::NumProtos, 32),
                (DimName::Height, 160),
                (DimName::Width, 160),
            ],
        )
    };

    // Case 1: detection output has four dimensions.
    let result = DecoderBuilder::new()
        .with_config_yolo_segdet(
            detection(
                vec![1, 85, 8400, 1],
                vec![
                    (DimName::Batch, 1),
                    (DimName::NumFeatures, 85),
                    (DimName::NumBoxes, 8400),
                    (DimName::Batch, 1),
                ],
            ),
            valid_protos(),
            Some(DecoderVersion::Yolo11),
        )
        .build();
    assert!(matches!(
        result,
        Err(DecoderError::InvalidConfig(msg)) if msg.starts_with("Invalid Yolo Detection shape")
    ));

    // Case 2: protos output has five dimensions.
    let result = DecoderBuilder::new()
        .with_config_yolo_segdet(
            detection(
                vec![1, 85, 8400],
                vec![
                    (DimName::Batch, 1),
                    (DimName::NumFeatures, 85),
                    (DimName::NumBoxes, 8400),
                ],
            ),
            protos(
                vec![1, 32, 160, 160, 1],
                vec![
                    (DimName::Batch, 1),
                    (DimName::NumProtos, 32),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                    (DimName::Batch, 1),
                ],
            ),
            Some(DecoderVersion::Yolo11),
        )
        .build();
    assert!(matches!(
        result,
        Err(DecoderError::InvalidConfig(msg)) if msg.starts_with("Invalid Yolo Protos shape")
    ));

    // Case 3: 36 features is not enough alongside 32 prototypes.
    let result = DecoderBuilder::new()
        .with_config_yolo_segdet(
            detection(
                vec![1, 8400, 36],
                vec![
                    (DimName::Batch, 1),
                    (DimName::NumBoxes, 8400),
                    (DimName::NumFeatures, 36),
                ],
            ),
            valid_protos(),
            Some(DecoderVersion::Yolo11),
        )
        .build();
    println!("{:?}", result);
    assert!(matches!(
        result,
        Err(DecoderError::InvalidConfig(msg))
            if msg == "Invalid shape: Yolo num_features 36 must be greater than 36"
    ));
}
3783
/// Invalid split YOLO detection configs must be rejected: four-dimensional
/// boxes, four-dimensional scores, mismatching box counts, and a box
/// coordinate dimension that is not 4.
#[test]
fn test_yolo_invalid_split_det() {
    // Builders for Ultralytics configs that differ only in shape and named
    // dimensions.
    let boxes = |shape, dshape| configs::Boxes {
        decoder: DecoderType::Ultralytics,
        shape,
        quantization: None,
        dshape,
        normalized: Some(true),
    };
    let scores = |shape, dshape| configs::Scores {
        decoder: DecoderType::Ultralytics,
        shape,
        quantization: None,
        dshape,
    };

    // Case 1: boxes output has four dimensions.
    let result = DecoderBuilder::new()
        .with_config_yolo_split_det(
            boxes(
                vec![1, 4, 8400, 1],
                vec![
                    (DimName::Batch, 1),
                    (DimName::BoxCoords, 4),
                    (DimName::NumBoxes, 8400),
                    (DimName::Batch, 1),
                ],
            ),
            scores(
                vec![1, 80, 8400],
                vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 80),
                    (DimName::NumBoxes, 8400),
                ],
            ),
        )
        .build();
    assert!(matches!(
        result,
        Err(DecoderError::InvalidConfig(msg)) if msg.starts_with("Invalid Yolo Split Boxes shape")
    ));

    // Case 2: scores output has four dimensions.
    let result = DecoderBuilder::new()
        .with_config_yolo_split_det(
            boxes(
                vec![1, 4, 8400],
                vec![
                    (DimName::Batch, 1),
                    (DimName::BoxCoords, 4),
                    (DimName::NumBoxes, 8400),
                ],
            ),
            scores(
                vec![1, 80, 8400, 1],
                vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 80),
                    (DimName::NumBoxes, 8400),
                    (DimName::Batch, 1),
                ],
            ),
        )
        .build();
    assert!(matches!(
        result,
        Err(DecoderError::InvalidConfig(msg)) if msg.starts_with("Invalid Yolo Split Scores shape")
    ));

    // Case 3: scores describe one more box than the boxes output.
    let result = DecoderBuilder::new()
        .with_config_yolo_split_det(
            boxes(
                vec![1, 8400, 4],
                vec![
                    (DimName::Batch, 1),
                    (DimName::NumBoxes, 8400),
                    (DimName::BoxCoords, 4),
                ],
            ),
            scores(
                vec![1, 8400 + 1, 80],
                vec![
                    (DimName::Batch, 1),
                    (DimName::NumBoxes, 8401),
                    (DimName::NumClasses, 80),
                ],
            ),
        )
        .build();
    assert!(matches!(
        result,
        Err(DecoderError::InvalidConfig(msg))
            if msg.starts_with(
                "Yolo Split Detection Boxes num 8400 incompatible with Scores num 8401"
            )
    ));

    // Case 4: box coordinate dimension is 5 instead of 4.
    let result = DecoderBuilder::new()
        .with_config_yolo_split_det(
            boxes(
                vec![1, 5, 8400],
                vec![
                    (DimName::Batch, 1),
                    (DimName::BoxCoords, 5),
                    (DimName::NumBoxes, 8400),
                ],
            ),
            scores(
                vec![1, 80, 8400],
                vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 80),
                    (DimName::NumBoxes, 8400),
                ],
            ),
        )
        .build();
    assert!(matches!(
        result,
        Err(DecoderError::InvalidConfig(msg))
            if msg.starts_with("BoxCoords dimension size must be 4")
    ));
}
3903
/// Invalid split YOLO segdet configs must be rejected. Starting from four
/// well-formed outputs (boxes, scores, mask coefficients, protos), each case
/// breaks exactly one of them and checks the resulting error message.
#[test]
fn test_yolo_invalid_split_segdet() {
    // Well-formed baseline outputs; every case overrides one of these.
    let valid_boxes = || configs::Boxes {
        decoder: DecoderType::Ultralytics,
        shape: vec![1, 8400, 4],
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumBoxes, 8400),
            (DimName::BoxCoords, 4),
        ],
        normalized: Some(true),
    };
    let valid_scores = || configs::Scores {
        decoder: DecoderType::Ultralytics,
        shape: vec![1, 8400, 80],
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumBoxes, 8400),
            (DimName::NumClasses, 80),
        ],
    };
    let valid_coeffs = || configs::MaskCoefficients {
        decoder: DecoderType::Ultralytics,
        shape: vec![1, 8400, 32],
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumBoxes, 8400),
            (DimName::NumProtos, 32),
        ],
    };
    let valid_protos = || configs::Protos {
        decoder: DecoderType::Ultralytics,
        shape: vec![1, 32, 160, 160],
        quantization: None,
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumProtos, 32),
            (DimName::Height, 160),
            (DimName::Width, 160),
        ],
    };

    // Case 1: boxes output has four dimensions.
    let result = DecoderBuilder::new()
        .with_config_yolo_split_segdet(
            configs::Boxes {
                shape: vec![1, 8400, 4, 1],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumBoxes, 8400),
                    (DimName::BoxCoords, 4),
                    (DimName::Batch, 1),
                ],
                ..valid_boxes()
            },
            valid_scores(),
            valid_coeffs(),
            valid_protos(),
        )
        .build();
    assert!(matches!(
        result,
        Err(DecoderError::InvalidConfig(msg)) if msg.starts_with("Invalid Yolo Split Boxes shape")
    ));

    // Case 2: scores output has four dimensions.
    let result = DecoderBuilder::new()
        .with_config_yolo_split_segdet(
            valid_boxes(),
            configs::Scores {
                shape: vec![1, 8400, 80, 1],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumBoxes, 8400),
                    (DimName::NumClasses, 80),
                    (DimName::Batch, 1),
                ],
                ..valid_scores()
            },
            valid_coeffs(),
            valid_protos(),
        )
        .build();
    assert!(matches!(
        result,
        Err(DecoderError::InvalidConfig(msg)) if msg.starts_with("Invalid Yolo Split Scores shape")
    ));

    // Case 3: mask coefficients output has four dimensions.
    let result = DecoderBuilder::new()
        .with_config_yolo_split_segdet(
            valid_boxes(),
            valid_scores(),
            configs::MaskCoefficients {
                shape: vec![1, 8400, 32, 1],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumBoxes, 8400),
                    (DimName::NumProtos, 32),
                    (DimName::Batch, 1),
                ],
                ..valid_coeffs()
            },
            valid_protos(),
        )
        .build();
    assert!(matches!(
        result,
        Err(DecoderError::InvalidConfig(msg))
            if msg.starts_with("Invalid Yolo Split Mask Coefficients shape")
    ));

    // Case 4: protos output has five dimensions.
    let result = DecoderBuilder::new()
        .with_config_yolo_split_segdet(
            valid_boxes(),
            valid_scores(),
            valid_coeffs(),
            configs::Protos {
                shape: vec![1, 32, 160, 160, 1],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumProtos, 32),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                    (DimName::Batch, 1),
                ],
                ..valid_protos()
            },
        )
        .build();
    assert!(matches!(
        result,
        Err(DecoderError::InvalidConfig(msg)) if msg.starts_with("Invalid Yolo Protos shape")
    ));

    // Case 5: scores describe one more box than the boxes output.
    let result = DecoderBuilder::new()
        .with_config_yolo_split_segdet(
            valid_boxes(),
            configs::Scores {
                shape: vec![1, 8401, 80],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumBoxes, 8401),
                    (DimName::NumClasses, 80),
                ],
                ..valid_scores()
            },
            valid_coeffs(),
            valid_protos(),
        )
        .build();
    assert!(matches!(
        result,
        Err(DecoderError::InvalidConfig(msg))
            if msg.starts_with(
                "Yolo Split Detection Boxes num 8400 incompatible with Scores num 8401"
            )
    ));

    // Case 6: mask coefficients describe one more box than the boxes output.
    let result = DecoderBuilder::new()
        .with_config_yolo_split_segdet(
            valid_boxes(),
            valid_scores(),
            configs::MaskCoefficients {
                shape: vec![1, 8401, 32],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumBoxes, 8401),
                    (DimName::NumProtos, 32),
                ],
                ..valid_coeffs()
            },
            valid_protos(),
        )
        .build();
    assert!(matches!(
        result,
        Err(DecoderError::InvalidConfig(ref msg))
            if msg.starts_with(
                "Yolo Split Detection Boxes num 8400 incompatible with Mask Coefficients num 8401"
            )
    ));

    // Case 7: protos carry 31 channels while the coefficients expect 32.
    let result = DecoderBuilder::new()
        .with_config_yolo_split_segdet(
            valid_boxes(),
            valid_scores(),
            valid_coeffs(),
            configs::Protos {
                shape: vec![1, 31, 160, 160],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumProtos, 31),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                ],
                ..valid_protos()
            },
        )
        .build();
    println!("{:?}", result);
    assert!(matches!(
        result,
        Err(DecoderError::InvalidConfig(ref msg))
            if msg.starts_with(
                "Yolo Protos channels 31 incompatible with Mask Coefficients channels 32"
            )
    ));
}
4261
4262 #[test]
4263 fn test_modelpack_invalid_config() {
4264 let result = DecoderBuilder::new()
4265 .with_config(ConfigOutputs {
4266 outputs: vec![
4267 ConfigOutput::Boxes(configs::Boxes {
4268 decoder: configs::DecoderType::ModelPack,
4269 shape: vec![1, 8400, 1, 4],
4270 quantization: None,
4271 dshape: vec![
4272 (DimName::Batch, 1),
4273 (DimName::NumBoxes, 8400),
4274 (DimName::Padding, 1),
4275 (DimName::BoxCoords, 4),
4276 ],
4277 normalized: Some(true),
4278 }),
4279 ConfigOutput::Scores(configs::Scores {
4280 decoder: configs::DecoderType::ModelPack,
4281 shape: vec![1, 8400, 3],
4282 quantization: None,
4283 dshape: vec![
4284 (DimName::Batch, 1),
4285 (DimName::NumBoxes, 8400),
4286 (DimName::NumClasses, 3),
4287 ],
4288 }),
4289 ConfigOutput::Protos(configs::Protos {
4290 decoder: configs::DecoderType::ModelPack,
4291 shape: vec![1, 8400, 3],
4292 quantization: None,
4293 dshape: vec![
4294 (DimName::Batch, 1),
4295 (DimName::NumBoxes, 8400),
4296 (DimName::NumFeatures, 3),
4297 ],
4298 }),
4299 ],
4300 ..Default::default()
4301 })
4302 .build();
4303
4304 assert!(matches!(
4305 result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack should not have protos"));
4306
4307 let result = DecoderBuilder::new()
4308 .with_config(ConfigOutputs {
4309 outputs: vec![
4310 ConfigOutput::Boxes(configs::Boxes {
4311 decoder: configs::DecoderType::ModelPack,
4312 shape: vec![1, 8400, 1, 4],
4313 quantization: None,
4314 dshape: vec![
4315 (DimName::Batch, 1),
4316 (DimName::NumBoxes, 8400),
4317 (DimName::Padding, 1),
4318 (DimName::BoxCoords, 4),
4319 ],
4320 normalized: Some(true),
4321 }),
4322 ConfigOutput::Scores(configs::Scores {
4323 decoder: configs::DecoderType::ModelPack,
4324 shape: vec![1, 8400, 3],
4325 quantization: None,
4326 dshape: vec![
4327 (DimName::Batch, 1),
4328 (DimName::NumBoxes, 8400),
4329 (DimName::NumClasses, 3),
4330 ],
4331 }),
4332 ConfigOutput::MaskCoefficients(configs::MaskCoefficients {
4333 decoder: configs::DecoderType::ModelPack,
4334 shape: vec![1, 8400, 3],
4335 quantization: None,
4336 dshape: vec![
4337 (DimName::Batch, 1),
4338 (DimName::NumBoxes, 8400),
4339 (DimName::NumProtos, 3),
4340 ],
4341 }),
4342 ],
4343 ..Default::default()
4344 })
4345 .build();
4346
4347 assert!(matches!(
4348 result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack should not have mask coefficients"));
4349
4350 let result = DecoderBuilder::new()
4351 .with_config(ConfigOutputs {
4352 outputs: vec![ConfigOutput::Boxes(configs::Boxes {
4353 decoder: configs::DecoderType::ModelPack,
4354 shape: vec![1, 8400, 1, 4],
4355 quantization: None,
4356 dshape: vec![
4357 (DimName::Batch, 1),
4358 (DimName::NumBoxes, 8400),
4359 (DimName::Padding, 1),
4360 (DimName::BoxCoords, 4),
4361 ],
4362 normalized: Some(true),
4363 })],
4364 ..Default::default()
4365 })
4366 .build();
4367
4368 assert!(matches!(
4369 result, Err(DecoderError::InvalidConfig(s)) if s == "Invalid ModelPack model outputs"));
4370 }
4371
4372 #[test]
4373 fn test_modelpack_invalid_det() {
4374 let result = DecoderBuilder::new()
4375 .with_config_modelpack_det(
4376 configs::Boxes {
4377 decoder: DecoderType::ModelPack,
4378 quantization: None,
4379 shape: vec![1, 4, 8400],
4380 dshape: vec![
4381 (DimName::Batch, 1),
4382 (DimName::BoxCoords, 4),
4383 (DimName::NumBoxes, 8400),
4384 ],
4385 normalized: Some(true),
4386 },
4387 configs::Scores {
4388 decoder: DecoderType::ModelPack,
4389 quantization: None,
4390 shape: vec![1, 80, 8400],
4391 dshape: vec![
4392 (DimName::Batch, 1),
4393 (DimName::NumClasses, 80),
4394 (DimName::NumBoxes, 8400),
4395 ],
4396 },
4397 )
4398 .build();
4399
4400 assert!(matches!(
4401 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Boxes shape")));
4402
4403 let result = DecoderBuilder::new()
4404 .with_config_modelpack_det(
4405 configs::Boxes {
4406 decoder: DecoderType::ModelPack,
4407 quantization: None,
4408 shape: vec![1, 4, 1, 8400],
4409 dshape: vec![
4410 (DimName::Batch, 1),
4411 (DimName::BoxCoords, 4),
4412 (DimName::Padding, 1),
4413 (DimName::NumBoxes, 8400),
4414 ],
4415 normalized: Some(true),
4416 },
4417 configs::Scores {
4418 decoder: DecoderType::ModelPack,
4419 quantization: None,
4420 shape: vec![1, 80, 8400, 1],
4421 dshape: vec![
4422 (DimName::Batch, 1),
4423 (DimName::NumClasses, 80),
4424 (DimName::NumBoxes, 8400),
4425 (DimName::Padding, 1),
4426 ],
4427 },
4428 )
4429 .build();
4430
4431 assert!(matches!(
4432 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Scores shape")));
4433
4434 let result = DecoderBuilder::new()
4435 .with_config_modelpack_det(
4436 configs::Boxes {
4437 decoder: DecoderType::ModelPack,
4438 quantization: None,
4439 shape: vec![1, 4, 2, 8400],
4440 dshape: vec![
4441 (DimName::Batch, 1),
4442 (DimName::BoxCoords, 4),
4443 (DimName::Padding, 2),
4444 (DimName::NumBoxes, 8400),
4445 ],
4446 normalized: Some(true),
4447 },
4448 configs::Scores {
4449 decoder: DecoderType::ModelPack,
4450 quantization: None,
4451 shape: vec![1, 80, 8400],
4452 dshape: vec![
4453 (DimName::Batch, 1),
4454 (DimName::NumClasses, 80),
4455 (DimName::NumBoxes, 8400),
4456 ],
4457 },
4458 )
4459 .build();
4460 assert!(matches!(
4461 result, Err(DecoderError::InvalidConfig(s)) if s == "Padding dimension size must be 1"));
4462
4463 let result = DecoderBuilder::new()
4464 .with_config_modelpack_det(
4465 configs::Boxes {
4466 decoder: DecoderType::ModelPack,
4467 quantization: None,
4468 shape: vec![1, 5, 1, 8400],
4469 dshape: vec![
4470 (DimName::Batch, 1),
4471 (DimName::BoxCoords, 5),
4472 (DimName::Padding, 1),
4473 (DimName::NumBoxes, 8400),
4474 ],
4475 normalized: Some(true),
4476 },
4477 configs::Scores {
4478 decoder: DecoderType::ModelPack,
4479 quantization: None,
4480 shape: vec![1, 80, 8400],
4481 dshape: vec![
4482 (DimName::Batch, 1),
4483 (DimName::NumClasses, 80),
4484 (DimName::NumBoxes, 8400),
4485 ],
4486 },
4487 )
4488 .build();
4489
4490 assert!(matches!(
4491 result, Err(DecoderError::InvalidConfig(s)) if s == "BoxCoords dimension size must be 4"));
4492
4493 let result = DecoderBuilder::new()
4494 .with_config_modelpack_det(
4495 configs::Boxes {
4496 decoder: DecoderType::ModelPack,
4497 quantization: None,
4498 shape: vec![1, 4, 1, 8400],
4499 dshape: vec![
4500 (DimName::Batch, 1),
4501 (DimName::BoxCoords, 4),
4502 (DimName::Padding, 1),
4503 (DimName::NumBoxes, 8400),
4504 ],
4505 normalized: Some(true),
4506 },
4507 configs::Scores {
4508 decoder: DecoderType::ModelPack,
4509 quantization: None,
4510 shape: vec![1, 80, 8401],
4511 dshape: vec![
4512 (DimName::Batch, 1),
4513 (DimName::NumClasses, 80),
4514 (DimName::NumBoxes, 8401),
4515 ],
4516 },
4517 )
4518 .build();
4519
4520 assert!(matches!(
4521 result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Detection Boxes num 8400 incompatible with Scores num 8401"));
4522 }
4523
4524 #[test]
4525 fn test_modelpack_invalid_det_split() {
4526 let result = DecoderBuilder::default()
4527 .with_config_modelpack_det_split(vec![
4528 configs::Detection {
4529 decoder: DecoderType::ModelPack,
4530 shape: vec![1, 17, 30, 18],
4531 anchors: None,
4532 quantization: None,
4533 dshape: vec![
4534 (DimName::Batch, 1),
4535 (DimName::Height, 17),
4536 (DimName::Width, 30),
4537 (DimName::NumAnchorsXFeatures, 18),
4538 ],
4539 normalized: Some(true),
4540 },
4541 configs::Detection {
4542 decoder: DecoderType::ModelPack,
4543 shape: vec![1, 9, 15, 18],
4544 anchors: None,
4545 quantization: None,
4546 dshape: vec![
4547 (DimName::Batch, 1),
4548 (DimName::Height, 9),
4549 (DimName::Width, 15),
4550 (DimName::NumAnchorsXFeatures, 18),
4551 ],
4552 normalized: Some(true),
4553 },
4554 ])
4555 .build();
4556
4557 assert!(matches!(
4558 result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection missing anchors"));
4559
4560 let result = DecoderBuilder::default()
4561 .with_config_modelpack_det_split(vec![configs::Detection {
4562 decoder: DecoderType::ModelPack,
4563 shape: vec![1, 17, 30, 18],
4564 anchors: None,
4565 quantization: None,
4566 dshape: Vec::new(),
4567 normalized: Some(true),
4568 }])
4569 .build();
4570
4571 assert!(matches!(
4572 result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection missing anchors"));
4573
4574 let result = DecoderBuilder::default()
4575 .with_config_modelpack_det_split(vec![configs::Detection {
4576 decoder: DecoderType::ModelPack,
4577 shape: vec![1, 17, 30, 18],
4578 anchors: Some(vec![]),
4579 quantization: None,
4580 dshape: vec![
4581 (DimName::Batch, 1),
4582 (DimName::Height, 17),
4583 (DimName::Width, 30),
4584 (DimName::NumAnchorsXFeatures, 18),
4585 ],
4586 normalized: Some(true),
4587 }])
4588 .build();
4589
4590 assert!(matches!(
4591 result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection has zero anchors"));
4592
4593 let result = DecoderBuilder::default()
4594 .with_config_modelpack_det_split(vec![configs::Detection {
4595 decoder: DecoderType::ModelPack,
4596 shape: vec![1, 17, 30, 18, 1],
4597 anchors: Some(vec![
4598 [0.3666666, 0.3148148],
4599 [0.3874999, 0.474074],
4600 [0.5333333, 0.644444],
4601 ]),
4602 quantization: None,
4603 dshape: vec![
4604 (DimName::Batch, 1),
4605 (DimName::Height, 17),
4606 (DimName::Width, 30),
4607 (DimName::NumAnchorsXFeatures, 18),
4608 (DimName::Padding, 1),
4609 ],
4610 normalized: Some(true),
4611 }])
4612 .build();
4613
4614 assert!(matches!(
4615 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Split Detection shape")));
4616
4617 let result = DecoderBuilder::default()
4618 .with_config_modelpack_det_split(vec![configs::Detection {
4619 decoder: DecoderType::ModelPack,
4620 shape: vec![1, 15, 17, 30],
4621 anchors: Some(vec![
4622 [0.3666666, 0.3148148],
4623 [0.3874999, 0.474074],
4624 [0.5333333, 0.644444],
4625 ]),
4626 quantization: None,
4627 dshape: vec![
4628 (DimName::Batch, 1),
4629 (DimName::NumAnchorsXFeatures, 15),
4630 (DimName::Height, 17),
4631 (DimName::Width, 30),
4632 ],
4633 normalized: Some(true),
4634 }])
4635 .build();
4636
4637 assert!(matches!(
4638 result, Err(DecoderError::InvalidConfig(s)) if s.contains("not greater than number of anchors * 5 =")));
4639
4640 let result = DecoderBuilder::default()
4641 .with_config_modelpack_det_split(vec![configs::Detection {
4642 decoder: DecoderType::ModelPack,
4643 shape: vec![1, 17, 30, 15],
4644 anchors: Some(vec![
4645 [0.3666666, 0.3148148],
4646 [0.3874999, 0.474074],
4647 [0.5333333, 0.644444],
4648 ]),
4649 quantization: None,
4650 dshape: Vec::new(),
4651 normalized: Some(true),
4652 }])
4653 .build();
4654
4655 assert!(matches!(
4656 result, Err(DecoderError::InvalidConfig(s)) if s.contains("not greater than number of anchors * 5 =")));
4657
4658 let result = DecoderBuilder::default()
4659 .with_config_modelpack_det_split(vec![configs::Detection {
4660 decoder: DecoderType::ModelPack,
4661 shape: vec![1, 16, 17, 30],
4662 anchors: Some(vec![
4663 [0.3666666, 0.3148148],
4664 [0.3874999, 0.474074],
4665 [0.5333333, 0.644444],
4666 ]),
4667 quantization: None,
4668 dshape: vec![
4669 (DimName::Batch, 1),
4670 (DimName::NumAnchorsXFeatures, 16),
4671 (DimName::Height, 17),
4672 (DimName::Width, 30),
4673 ],
4674 normalized: Some(true),
4675 }])
4676 .build();
4677
4678 assert!(matches!(
4679 result, Err(DecoderError::InvalidConfig(s)) if s.contains("not a multiple of number of anchors")));
4680
4681 let result = DecoderBuilder::default()
4682 .with_config_modelpack_det_split(vec![configs::Detection {
4683 decoder: DecoderType::ModelPack,
4684 shape: vec![1, 17, 30, 16],
4685 anchors: Some(vec![
4686 [0.3666666, 0.3148148],
4687 [0.3874999, 0.474074],
4688 [0.5333333, 0.644444],
4689 ]),
4690 quantization: None,
4691 dshape: Vec::new(),
4692 normalized: Some(true),
4693 }])
4694 .build();
4695
4696 assert!(matches!(
4697 result, Err(DecoderError::InvalidConfig(s)) if s.contains("not a multiple of number of anchors")));
4698
4699 let result = DecoderBuilder::default()
4700 .with_config_modelpack_det_split(vec![configs::Detection {
4701 decoder: DecoderType::ModelPack,
4702 shape: vec![1, 18, 17, 30],
4703 anchors: Some(vec![
4704 [0.3666666, 0.3148148],
4705 [0.3874999, 0.474074],
4706 [0.5333333, 0.644444],
4707 ]),
4708 quantization: None,
4709 dshape: vec![
4710 (DimName::Batch, 1),
4711 (DimName::NumProtos, 18),
4712 (DimName::Height, 17),
4713 (DimName::Width, 30),
4714 ],
4715 normalized: Some(true),
4716 }])
4717 .build();
4718 assert!(matches!(
4719 result, Err(DecoderError::InvalidConfig(s)) if s.contains("Split Detection dshape missing required dimension NumAnchorsXFeature")));
4720
4721 let result = DecoderBuilder::default()
4722 .with_config_modelpack_det_split(vec![
4723 configs::Detection {
4724 decoder: DecoderType::ModelPack,
4725 shape: vec![1, 17, 30, 18],
4726 anchors: Some(vec![
4727 [0.3666666, 0.3148148],
4728 [0.3874999, 0.474074],
4729 [0.5333333, 0.644444],
4730 ]),
4731 quantization: None,
4732 dshape: vec![
4733 (DimName::Batch, 1),
4734 (DimName::Height, 17),
4735 (DimName::Width, 30),
4736 (DimName::NumAnchorsXFeatures, 18),
4737 ],
4738 normalized: Some(true),
4739 },
4740 configs::Detection {
4741 decoder: DecoderType::ModelPack,
4742 shape: vec![1, 17, 30, 21],
4743 anchors: Some(vec![
4744 [0.3666666, 0.3148148],
4745 [0.3874999, 0.474074],
4746 [0.5333333, 0.644444],
4747 ]),
4748 quantization: None,
4749 dshape: vec![
4750 (DimName::Batch, 1),
4751 (DimName::Height, 17),
4752 (DimName::Width, 30),
4753 (DimName::NumAnchorsXFeatures, 21),
4754 ],
4755 normalized: Some(true),
4756 },
4757 ])
4758 .build();
4759
4760 assert!(matches!(
4761 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("ModelPack Split Detection inconsistent number of classes:")));
4762
4763 let result = DecoderBuilder::default()
4764 .with_config_modelpack_det_split(vec![
4765 configs::Detection {
4766 decoder: DecoderType::ModelPack,
4767 shape: vec![1, 17, 30, 18],
4768 anchors: Some(vec![
4769 [0.3666666, 0.3148148],
4770 [0.3874999, 0.474074],
4771 [0.5333333, 0.644444],
4772 ]),
4773 quantization: None,
4774 dshape: vec![],
4775 normalized: Some(true),
4776 },
4777 configs::Detection {
4778 decoder: DecoderType::ModelPack,
4779 shape: vec![1, 17, 30, 21],
4780 anchors: Some(vec![
4781 [0.3666666, 0.3148148],
4782 [0.3874999, 0.474074],
4783 [0.5333333, 0.644444],
4784 ]),
4785 quantization: None,
4786 dshape: vec![],
4787 normalized: Some(true),
4788 },
4789 ])
4790 .build();
4791
4792 assert!(matches!(
4793 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("ModelPack Split Detection inconsistent number of classes:")));
4794 }
4795
4796 #[test]
4797 fn test_modelpack_invalid_seg() {
4798 let result = DecoderBuilder::new()
4799 .with_config_modelpack_seg(configs::Segmentation {
4800 decoder: DecoderType::ModelPack,
4801 quantization: None,
4802 shape: vec![1, 160, 106, 3, 1],
4803 dshape: vec![
4804 (DimName::Batch, 1),
4805 (DimName::Height, 160),
4806 (DimName::Width, 106),
4807 (DimName::NumClasses, 3),
4808 (DimName::Padding, 1),
4809 ],
4810 })
4811 .build();
4812
4813 assert!(matches!(
4814 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Segmentation shape")));
4815 }
4816
4817 #[test]
4818 fn test_modelpack_invalid_segdet() {
4819 let result = DecoderBuilder::new()
4820 .with_config_modelpack_segdet(
4821 configs::Boxes {
4822 decoder: DecoderType::ModelPack,
4823 quantization: None,
4824 shape: vec![1, 4, 1, 8400],
4825 dshape: vec![
4826 (DimName::Batch, 1),
4827 (DimName::BoxCoords, 4),
4828 (DimName::Padding, 1),
4829 (DimName::NumBoxes, 8400),
4830 ],
4831 normalized: Some(true),
4832 },
4833 configs::Scores {
4834 decoder: DecoderType::ModelPack,
4835 quantization: None,
4836 shape: vec![1, 4, 8400],
4837 dshape: vec![
4838 (DimName::Batch, 1),
4839 (DimName::NumClasses, 4),
4840 (DimName::NumBoxes, 8400),
4841 ],
4842 },
4843 configs::Segmentation {
4844 decoder: DecoderType::ModelPack,
4845 quantization: None,
4846 shape: vec![1, 160, 106, 3],
4847 dshape: vec![
4848 (DimName::Batch, 1),
4849 (DimName::Height, 160),
4850 (DimName::Width, 106),
4851 (DimName::NumClasses, 3),
4852 ],
4853 },
4854 )
4855 .build();
4856
4857 assert!(matches!(
4858 result, Err(DecoderError::InvalidConfig(s)) if s.contains("incompatible with number of classes")));
4859 }
4860
4861 #[test]
4862 fn test_modelpack_invalid_segdet_split() {
4863 let result = DecoderBuilder::new()
4864 .with_config_modelpack_segdet_split(
4865 vec![configs::Detection {
4866 decoder: DecoderType::ModelPack,
4867 shape: vec![1, 17, 30, 18],
4868 anchors: Some(vec![
4869 [0.3666666, 0.3148148],
4870 [0.3874999, 0.474074],
4871 [0.5333333, 0.644444],
4872 ]),
4873 quantization: None,
4874 dshape: vec![
4875 (DimName::Batch, 1),
4876 (DimName::Height, 17),
4877 (DimName::Width, 30),
4878 (DimName::NumAnchorsXFeatures, 18),
4879 ],
4880 normalized: Some(true),
4881 }],
4882 configs::Segmentation {
4883 decoder: DecoderType::ModelPack,
4884 quantization: None,
4885 shape: vec![1, 160, 106, 3],
4886 dshape: vec![
4887 (DimName::Batch, 1),
4888 (DimName::Height, 160),
4889 (DimName::Width, 106),
4890 (DimName::NumClasses, 3),
4891 ],
4892 },
4893 )
4894 .build();
4895
4896 assert!(matches!(
4897 result, Err(DecoderError::InvalidConfig(s)) if s.contains("incompatible with number of classes")));
4898 }
4899
4900 #[test]
4901 fn test_decode_bad_shapes() {
4902 let score_threshold = 0.25;
4903 let iou_threshold = 0.7;
4904 let quant = (0.0040811873, -123);
4905 let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
4906 let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
4907 let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
4908 let out_float: Array3<f32> = dequantize_ndarray(out.view(), quant.into());
4909
4910 let decoder = DecoderBuilder::default()
4911 .with_config_yolo_det(
4912 configs::Detection {
4913 decoder: DecoderType::Ultralytics,
4914 shape: vec![1, 85, 8400],
4915 anchors: None,
4916 quantization: Some(quant.into()),
4917 dshape: vec![
4918 (DimName::Batch, 1),
4919 (DimName::NumFeatures, 85),
4920 (DimName::NumBoxes, 8400),
4921 ],
4922 normalized: Some(true),
4923 },
4924 Some(DecoderVersion::Yolo11),
4925 )
4926 .with_score_threshold(score_threshold)
4927 .with_iou_threshold(iou_threshold)
4928 .build()
4929 .unwrap();
4930
4931 let mut output_boxes: Vec<_> = Vec::with_capacity(50);
4932 let mut output_masks: Vec<_> = Vec::with_capacity(50);
4933 let result =
4934 decoder.decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks);
4935
4936 assert!(matches!(
4937 result, Err(DecoderError::InvalidShape(s)) if s == "Did not find output with shape [1, 85, 8400]"));
4938
4939 let result = decoder.decode_float(
4940 &[out_float.view().into_dyn()],
4941 &mut output_boxes,
4942 &mut output_masks,
4943 );
4944
4945 assert!(matches!(
4946 result, Err(DecoderError::InvalidShape(s)) if s == "Did not find output with shape [1, 85, 8400]"));
4947 }
4948
4949 #[test]
4950 fn test_config_outputs() {
4951 let outputs = [
4952 ConfigOutput::Detection(configs::Detection {
4953 decoder: configs::DecoderType::Ultralytics,
4954 anchors: None,
4955 shape: vec![1, 8400, 85],
4956 quantization: Some(QuantTuple(0.123, 0)),
4957 dshape: vec![
4958 (DimName::Batch, 1),
4959 (DimName::NumBoxes, 8400),
4960 (DimName::NumFeatures, 85),
4961 ],
4962 normalized: Some(true),
4963 }),
4964 ConfigOutput::Mask(configs::Mask {
4965 decoder: configs::DecoderType::Ultralytics,
4966 shape: vec![1, 160, 160, 1],
4967 quantization: Some(QuantTuple(0.223, 0)),
4968 dshape: vec![
4969 (DimName::Batch, 1),
4970 (DimName::Height, 160),
4971 (DimName::Width, 160),
4972 (DimName::NumFeatures, 1),
4973 ],
4974 }),
4975 ConfigOutput::Segmentation(configs::Segmentation {
4976 decoder: configs::DecoderType::Ultralytics,
4977 shape: vec![1, 160, 160, 80],
4978 quantization: Some(QuantTuple(0.323, 0)),
4979 dshape: vec![
4980 (DimName::Batch, 1),
4981 (DimName::Height, 160),
4982 (DimName::Width, 160),
4983 (DimName::NumClasses, 80),
4984 ],
4985 }),
4986 ConfigOutput::Scores(configs::Scores {
4987 decoder: configs::DecoderType::Ultralytics,
4988 shape: vec![1, 8400, 80],
4989 quantization: Some(QuantTuple(0.423, 0)),
4990 dshape: vec![
4991 (DimName::Batch, 1),
4992 (DimName::NumBoxes, 8400),
4993 (DimName::NumClasses, 80),
4994 ],
4995 }),
4996 ConfigOutput::Boxes(configs::Boxes {
4997 decoder: configs::DecoderType::Ultralytics,
4998 shape: vec![1, 8400, 4],
4999 quantization: Some(QuantTuple(0.523, 0)),
5000 dshape: vec![
5001 (DimName::Batch, 1),
5002 (DimName::NumBoxes, 8400),
5003 (DimName::BoxCoords, 4),
5004 ],
5005 normalized: Some(true),
5006 }),
5007 ConfigOutput::Protos(configs::Protos {
5008 decoder: configs::DecoderType::Ultralytics,
5009 shape: vec![1, 32, 160, 160],
5010 quantization: Some(QuantTuple(0.623, 0)),
5011 dshape: vec![
5012 (DimName::Batch, 1),
5013 (DimName::NumProtos, 32),
5014 (DimName::Height, 160),
5015 (DimName::Width, 160),
5016 ],
5017 }),
5018 ConfigOutput::MaskCoefficients(configs::MaskCoefficients {
5019 decoder: configs::DecoderType::Ultralytics,
5020 shape: vec![1, 8400, 32],
5021 quantization: Some(QuantTuple(0.723, 0)),
5022 dshape: vec![
5023 (DimName::Batch, 1),
5024 (DimName::NumBoxes, 8400),
5025 (DimName::NumProtos, 32),
5026 ],
5027 }),
5028 ];
5029
5030 let shapes = outputs.clone().map(|x| x.shape().to_vec());
5031 assert_eq!(
5032 shapes,
5033 [
5034 vec![1, 8400, 85],
5035 vec![1, 160, 160, 1],
5036 vec![1, 160, 160, 80],
5037 vec![1, 8400, 80],
5038 vec![1, 8400, 4],
5039 vec![1, 32, 160, 160],
5040 vec![1, 8400, 32],
5041 ]
5042 );
5043
5044 let quants: [Option<(f32, i32)>; 7] = outputs.map(|x| x.quantization().map(|q| q.into()));
5045 assert_eq!(
5046 quants,
5047 [
5048 Some((0.123, 0)),
5049 Some((0.223, 0)),
5050 Some((0.323, 0)),
5051 Some((0.423, 0)),
5052 Some((0.523, 0)),
5053 Some((0.623, 0)),
5054 Some((0.723, 0)),
5055 ]
5056 );
5057 }
5058
5059 #[test]
5060 fn test_nms_from_config_yaml() {
5061 let yaml_class_agnostic = r#"
5063outputs:
5064 - decoder: ultralytics
5065 type: detection
5066 shape: [1, 84, 8400]
5067 dshape:
5068 - [batch, 1]
5069 - [num_features, 84]
5070 - [num_boxes, 8400]
5071nms: class_agnostic
5072"#;
5073 let decoder = DecoderBuilder::new()
5074 .with_config_yaml_str(yaml_class_agnostic.to_string())
5075 .build()
5076 .unwrap();
5077 assert_eq!(decoder.nms, Some(configs::Nms::ClassAgnostic));
5078
5079 let yaml_class_aware = r#"
5080outputs:
5081 - decoder: ultralytics
5082 type: detection
5083 shape: [1, 84, 8400]
5084 dshape:
5085 - [batch, 1]
5086 - [num_features, 84]
5087 - [num_boxes, 8400]
5088nms: class_aware
5089"#;
5090 let decoder = DecoderBuilder::new()
5091 .with_config_yaml_str(yaml_class_aware.to_string())
5092 .build()
5093 .unwrap();
5094 assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
5095
5096 let decoder = DecoderBuilder::new()
5098 .with_config_yaml_str(yaml_class_aware.to_string())
5099 .with_nms(Some(configs::Nms::ClassAgnostic)) .build()
5101 .unwrap();
5102 assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
5104 }
5105
5106 #[test]
5107 fn test_nms_from_config_json() {
5108 let json_class_aware = r#"{
5110 "outputs": [{
5111 "decoder": "ultralytics",
5112 "type": "detection",
5113 "shape": [1, 84, 8400],
5114 "dshape": [["batch", 1], ["num_features", 84], ["num_boxes", 8400]]
5115 }],
5116 "nms": "class_aware"
5117 }"#;
5118 let decoder = DecoderBuilder::new()
5119 .with_config_json_str(json_class_aware.to_string())
5120 .build()
5121 .unwrap();
5122 assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
5123 }
5124
5125 #[test]
5126 fn test_nms_missing_from_config_uses_builder_default() {
5127 let yaml_no_nms = r#"
5129outputs:
5130 - decoder: ultralytics
5131 type: detection
5132 shape: [1, 84, 8400]
5133 dshape:
5134 - [batch, 1]
5135 - [num_features, 84]
5136 - [num_boxes, 8400]
5137"#;
5138 let decoder = DecoderBuilder::new()
5139 .with_config_yaml_str(yaml_no_nms.to_string())
5140 .build()
5141 .unwrap();
5142 assert_eq!(decoder.nms, Some(configs::Nms::ClassAgnostic));
5144
5145 let decoder = DecoderBuilder::new()
5147 .with_config_yaml_str(yaml_no_nms.to_string())
5148 .with_nms(None) .build()
5150 .unwrap();
5151 assert_eq!(decoder.nms, None);
5152 }
5153
5154 #[test]
5155 fn test_decoder_version_yolo26_end_to_end() {
5156 let yaml = r#"
5158outputs:
5159 - decoder: ultralytics
5160 type: detection
5161 shape: [1, 6, 8400]
5162 dshape:
5163 - [batch, 1]
5164 - [num_features, 6]
5165 - [num_boxes, 8400]
5166decoder_version: yolo26
5167"#;
5168 let decoder = DecoderBuilder::new()
5169 .with_config_yaml_str(yaml.to_string())
5170 .build()
5171 .unwrap();
5172 assert!(matches!(
5173 decoder.model_type,
5174 ModelType::YoloEndToEndDet { .. }
5175 ));
5176
5177 let yaml_with_nms = r#"
5179outputs:
5180 - decoder: ultralytics
5181 type: detection
5182 shape: [1, 6, 8400]
5183 dshape:
5184 - [batch, 1]
5185 - [num_features, 6]
5186 - [num_boxes, 8400]
5187decoder_version: yolo26
5188nms: class_agnostic
5189"#;
5190 let decoder = DecoderBuilder::new()
5191 .with_config_yaml_str(yaml_with_nms.to_string())
5192 .build()
5193 .unwrap();
5194 assert!(matches!(
5195 decoder.model_type,
5196 ModelType::YoloEndToEndDet { .. }
5197 ));
5198 }
5199
5200 #[test]
5201 fn test_decoder_version_yolov8_traditional() {
5202 let yaml = r#"
5204outputs:
5205 - decoder: ultralytics
5206 type: detection
5207 shape: [1, 84, 8400]
5208 dshape:
5209 - [batch, 1]
5210 - [num_features, 84]
5211 - [num_boxes, 8400]
5212decoder_version: yolov8
5213"#;
5214 let decoder = DecoderBuilder::new()
5215 .with_config_yaml_str(yaml.to_string())
5216 .build()
5217 .unwrap();
5218 assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
5219 }
5220
5221 #[test]
5222 fn test_decoder_version_all_versions() {
5223 for version in ["yolov5", "yolov8", "yolo11"] {
5225 let yaml = format!(
5226 r#"
5227outputs:
5228 - decoder: ultralytics
5229 type: detection
5230 shape: [1, 84, 8400]
5231 dshape:
5232 - [batch, 1]
5233 - [num_features, 84]
5234 - [num_boxes, 8400]
5235decoder_version: {}
5236"#,
5237 version
5238 );
5239 let decoder = DecoderBuilder::new()
5240 .with_config_yaml_str(yaml)
5241 .build()
5242 .unwrap();
5243
5244 assert!(
5245 matches!(decoder.model_type, ModelType::YoloDet { .. }),
5246 "Expected traditional for {}",
5247 version
5248 );
5249 }
5250
5251 let yaml = r#"
5252outputs:
5253 - decoder: ultralytics
5254 type: detection
5255 shape: [1, 6, 8400]
5256 dshape:
5257 - [batch, 1]
5258 - [num_features, 6]
5259 - [num_boxes, 8400]
5260decoder_version: yolo26
5261"#
5262 .to_string();
5263
5264 let decoder = DecoderBuilder::new()
5265 .with_config_yaml_str(yaml)
5266 .build()
5267 .unwrap();
5268
5269 assert!(
5270 matches!(decoder.model_type, ModelType::YoloEndToEndDet { .. }),
5271 "Expected end to end for yolo26",
5272 );
5273 }
5274
5275 #[test]
5276 fn test_decoder_version_json() {
5277 let json = r#"{
5279 "outputs": [{
5280 "decoder": "ultralytics",
5281 "type": "detection",
5282 "shape": [1, 6, 8400],
5283 "dshape": [["batch", 1], ["num_features", 6], ["num_boxes", 8400]]
5284 }],
5285 "decoder_version": "yolo26"
5286 }"#;
5287 let decoder = DecoderBuilder::new()
5288 .with_config_json_str(json.to_string())
5289 .build()
5290 .unwrap();
5291 assert!(matches!(
5292 decoder.model_type,
5293 ModelType::YoloEndToEndDet { .. }
5294 ));
5295 }
5296
5297 #[test]
5298 fn test_decoder_version_none_uses_traditional() {
5299 let yaml = r#"
5301outputs:
5302 - decoder: ultralytics
5303 type: detection
5304 shape: [1, 84, 8400]
5305 dshape:
5306 - [batch, 1]
5307 - [num_features, 84]
5308 - [num_boxes, 8400]
5309"#;
5310 let decoder = DecoderBuilder::new()
5311 .with_config_yaml_str(yaml.to_string())
5312 .build()
5313 .unwrap();
5314 assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
5315 }
5316
5317 #[test]
5318 fn test_decoder_version_none_with_nms_none_still_traditional() {
5319 let yaml = r#"
5322outputs:
5323 - decoder: ultralytics
5324 type: detection
5325 shape: [1, 84, 8400]
5326 dshape:
5327 - [batch, 1]
5328 - [num_features, 84]
5329 - [num_boxes, 8400]
5330"#;
5331 let decoder = DecoderBuilder::new()
5332 .with_config_yaml_str(yaml.to_string())
5333 .with_nms(None) .build()
5335 .unwrap();
5336 assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
5339 }
5340
5341 #[test]
5342 fn test_decoder_heuristic_end_to_end_detection() {
5343 let yaml = r#"
5346outputs:
5347 - decoder: ultralytics
5348 type: detection
5349 shape: [1, 300, 6]
5350 dshape:
5351 - [batch, 1]
5352 - [num_boxes, 300]
5353 - [num_features, 6]
5354
5355"#;
5356 let decoder = DecoderBuilder::new()
5357 .with_config_yaml_str(yaml.to_string())
5358 .build()
5359 .unwrap();
5360 assert!(matches!(
5362 decoder.model_type,
5363 ModelType::YoloEndToEndDet { .. }
5364 ));
5365
5366 let yaml = r#"
5367outputs:
5368 - decoder: ultralytics
5369 type: detection
5370 shape: [1, 300, 38]
5371 dshape:
5372 - [batch, 1]
5373 - [num_boxes, 300]
5374 - [num_features, 38]
5375 - decoder: ultralytics
5376 type: protos
5377 shape: [1, 160, 160, 32]
5378 dshape:
5379 - [batch, 1]
5380 - [height, 160]
5381 - [width, 160]
5382 - [num_protos, 32]
5383"#;
5384 let decoder = DecoderBuilder::new()
5385 .with_config_yaml_str(yaml.to_string())
5386 .build()
5387 .unwrap();
5388 assert!(matches!(
5390 decoder.model_type,
5391 ModelType::YoloEndToEndSegDet { .. }
5392 ));
5393
5394 let yaml = r#"
5395outputs:
5396 - decoder: ultralytics
5397 type: detection
5398 shape: [1, 6, 300]
5399 dshape:
5400 - [batch, 1]
5401 - [num_features, 6]
5402 - [num_boxes, 300]
5403"#;
5404 let decoder = DecoderBuilder::new()
5405 .with_config_yaml_str(yaml.to_string())
5406 .build()
5407 .unwrap();
5408 assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
5411
5412 let yaml = r#"
5413outputs:
5414 - decoder: ultralytics
5415 type: detection
5416 shape: [1, 38, 300]
5417 dshape:
5418 - [batch, 1]
5419 - [num_features, 38]
5420 - [num_boxes, 300]
5421
5422 - decoder: ultralytics
5423 type: protos
5424 shape: [1, 160, 160, 32]
5425 dshape:
5426 - [batch, 1]
5427 - [height, 160]
5428 - [width, 160]
5429 - [num_protos, 32]
5430"#;
5431 let decoder = DecoderBuilder::new()
5432 .with_config_yaml_str(yaml.to_string())
5433 .build()
5434 .unwrap();
5435 assert!(matches!(decoder.model_type, ModelType::YoloSegDet { .. }));
5437 }
5438
5439 #[test]
5440 fn test_decoder_version_is_end_to_end() {
5441 assert!(!configs::DecoderVersion::Yolov5.is_end_to_end());
5442 assert!(!configs::DecoderVersion::Yolov8.is_end_to_end());
5443 assert!(!configs::DecoderVersion::Yolo11.is_end_to_end());
5444 assert!(configs::DecoderVersion::Yolo26.is_end_to_end());
5445 }
5446}