1use std::collections::HashSet;
5
6use ndarray::{s, Array3, ArrayView, ArrayViewD, Dimension};
7use ndarray_stats::QuantileExt;
8use num_traits::{AsPrimitive, Float};
9use serde::{Deserialize, Serialize};
10
11use crate::{
12 configs::{DecoderType, DimName, ModelType, QuantTuple},
13 dequantize_ndarray,
14 modelpack::{
15 decode_modelpack_det, decode_modelpack_float, decode_modelpack_split_float,
16 ModelPackDetectionConfig,
17 },
18 yolo::{
19 decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float, decode_yolo_segdet_quant,
20 decode_yolo_split_det_float, decode_yolo_split_det_quant, decode_yolo_split_segdet_float,
21 impl_yolo_split_segdet_quant_get_boxes, impl_yolo_split_segdet_quant_process_masks,
22 },
23 DecoderError, DecoderVersion, DetectBox, Quantization, Segmentation, XYWH,
24};
25
/// Top-level decoder configuration: the model's output tensors plus optional
/// decoder-wide settings.
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
pub struct ConfigOutputs {
    /// Output tensor descriptions; an empty list when the key is absent.
    #[serde(default)]
    pub outputs: Vec<ConfigOutput>,
    /// NMS mode. When set, it takes precedence over the value given to
    /// `DecoderBuilder::with_nms` (see `DecoderBuilder::build`). Omitted from
    /// serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub nms: Option<configs::Nms>,
    /// Decoder version. When set, it alone decides whether a YOLO model is
    /// treated as end-to-end; when `None`, the detection `dshape` layout is
    /// inspected instead. Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub decoder_version: Option<configs::DecoderVersion>,
}
59
/// One model output description.
///
/// Internally tagged: the serialized form carries a `type` field whose value
/// selects the variant (e.g. `"detection"`, `"protos"`, `"mask_coefficients"`).
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
pub enum ConfigOutput {
    #[serde(rename = "detection")]
    Detection(configs::Detection),
    #[serde(rename = "masks")]
    Mask(configs::Mask),
    #[serde(rename = "segmentation")]
    Segmentation(configs::Segmentation),
    #[serde(rename = "protos")]
    Protos(configs::Protos),
    #[serde(rename = "scores")]
    Scores(configs::Scores),
    #[serde(rename = "boxes")]
    Boxes(configs::Boxes),
    #[serde(rename = "mask_coefficients")]
    MaskCoefficients(configs::MaskCoefficients),
}
78
/// Borrowed view of a single output config.
///
/// Lets helpers accept any output kind by reference; the `From` impls below
/// allow `(&config).into()` at call sites such as the class-count helpers.
#[derive(Debug, PartialEq, Clone)]
pub enum ConfigOutputRef<'a> {
    Detection(&'a configs::Detection),
    Mask(&'a configs::Mask),
    Segmentation(&'a configs::Segmentation),
    Protos(&'a configs::Protos),
    Scores(&'a configs::Scores),
    Boxes(&'a configs::Boxes),
    MaskCoefficients(&'a configs::MaskCoefficients),
}
89
90impl<'a> ConfigOutputRef<'a> {
91 fn decoder(&self) -> configs::DecoderType {
92 match self {
93 ConfigOutputRef::Detection(v) => v.decoder,
94 ConfigOutputRef::Mask(v) => v.decoder,
95 ConfigOutputRef::Segmentation(v) => v.decoder,
96 ConfigOutputRef::Protos(v) => v.decoder,
97 ConfigOutputRef::Scores(v) => v.decoder,
98 ConfigOutputRef::Boxes(v) => v.decoder,
99 ConfigOutputRef::MaskCoefficients(v) => v.decoder,
100 }
101 }
102
103 fn dshape(&self) -> &[(DimName, usize)] {
104 match self {
105 ConfigOutputRef::Detection(v) => &v.dshape,
106 ConfigOutputRef::Mask(v) => &v.dshape,
107 ConfigOutputRef::Segmentation(v) => &v.dshape,
108 ConfigOutputRef::Protos(v) => &v.dshape,
109 ConfigOutputRef::Scores(v) => &v.dshape,
110 ConfigOutputRef::Boxes(v) => &v.dshape,
111 ConfigOutputRef::MaskCoefficients(v) => &v.dshape,
112 }
113 }
114}
115
116impl<'a> From<&'a configs::Detection> for ConfigOutputRef<'a> {
117 fn from(v: &'a configs::Detection) -> ConfigOutputRef<'a> {
132 ConfigOutputRef::Detection(v)
133 }
134}
135
136impl<'a> From<&'a configs::Mask> for ConfigOutputRef<'a> {
137 fn from(v: &'a configs::Mask) -> ConfigOutputRef<'a> {
150 ConfigOutputRef::Mask(v)
151 }
152}
153
154impl<'a> From<&'a configs::Segmentation> for ConfigOutputRef<'a> {
155 fn from(v: &'a configs::Segmentation) -> ConfigOutputRef<'a> {
168 ConfigOutputRef::Segmentation(v)
169 }
170}
171
172impl<'a> From<&'a configs::Protos> for ConfigOutputRef<'a> {
173 fn from(v: &'a configs::Protos) -> ConfigOutputRef<'a> {
186 ConfigOutputRef::Protos(v)
187 }
188}
189
190impl<'a> From<&'a configs::Scores> for ConfigOutputRef<'a> {
191 fn from(v: &'a configs::Scores) -> ConfigOutputRef<'a> {
204 ConfigOutputRef::Scores(v)
205 }
206}
207
208impl<'a> From<&'a configs::Boxes> for ConfigOutputRef<'a> {
209 fn from(v: &'a configs::Boxes) -> ConfigOutputRef<'a> {
223 ConfigOutputRef::Boxes(v)
224 }
225}
226
227impl<'a> From<&'a configs::MaskCoefficients> for ConfigOutputRef<'a> {
228 fn from(v: &'a configs::MaskCoefficients) -> ConfigOutputRef<'a> {
241 ConfigOutputRef::MaskCoefficients(v)
242 }
243}
244
245impl ConfigOutput {
246 pub fn shape(&self) -> &[usize] {
263 match self {
264 ConfigOutput::Detection(detection) => &detection.shape,
265 ConfigOutput::Mask(mask) => &mask.shape,
266 ConfigOutput::Segmentation(segmentation) => &segmentation.shape,
267 ConfigOutput::Scores(scores) => &scores.shape,
268 ConfigOutput::Boxes(boxes) => &boxes.shape,
269 ConfigOutput::Protos(protos) => &protos.shape,
270 ConfigOutput::MaskCoefficients(mask_coefficients) => &mask_coefficients.shape,
271 }
272 }
273
274 pub fn decoder(&self) -> &configs::DecoderType {
291 match self {
292 ConfigOutput::Detection(detection) => &detection.decoder,
293 ConfigOutput::Mask(mask) => &mask.decoder,
294 ConfigOutput::Segmentation(segmentation) => &segmentation.decoder,
295 ConfigOutput::Scores(scores) => &scores.decoder,
296 ConfigOutput::Boxes(boxes) => &boxes.decoder,
297 ConfigOutput::Protos(protos) => &protos.decoder,
298 ConfigOutput::MaskCoefficients(mask_coefficients) => &mask_coefficients.decoder,
299 }
300 }
301
302 pub fn quantization(&self) -> Option<QuantTuple> {
319 match self {
320 ConfigOutput::Detection(detection) => detection.quantization,
321 ConfigOutput::Mask(mask) => mask.quantization,
322 ConfigOutput::Segmentation(segmentation) => segmentation.quantization,
323 ConfigOutput::Scores(scores) => scores.quantization,
324 ConfigOutput::Boxes(boxes) => boxes.quantization,
325 ConfigOutput::Protos(protos) => protos.quantization,
326 ConfigOutput::MaskCoefficients(mask_coefficients) => mask_coefficients.quantization,
327 }
328 }
329}
330
331pub mod configs {
332 use std::fmt::Display;
333
334 use serde::{Deserialize, Serialize};
335
336 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy)]
337 pub struct QuantTuple(pub f32, pub i32);
338 impl From<QuantTuple> for (f32, i32) {
339 fn from(value: QuantTuple) -> Self {
340 (value.0, value.1)
341 }
342 }
343
344 impl From<(f32, i32)> for QuantTuple {
345 fn from(value: (f32, i32)) -> Self {
346 QuantTuple(value.0, value.1)
347 }
348 }
349
350 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
351 pub struct Segmentation {
352 pub decoder: DecoderType,
353 pub quantization: Option<QuantTuple>,
354 pub shape: Vec<usize>,
355 #[serde(default)]
358 pub dshape: Vec<(DimName, usize)>,
359 }
360
361 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
362 pub struct Protos {
363 pub decoder: DecoderType,
364 #[serde(flatten)]
365 pub quantization: Option<QuantTuple>,
366 pub shape: Vec<usize>,
367 #[serde(default)]
370 pub dshape: Vec<(DimName, usize)>,
371 }
372
373 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
374 pub struct MaskCoefficients {
375 pub decoder: DecoderType,
376 pub quantization: Option<QuantTuple>,
377 pub shape: Vec<usize>,
378 #[serde(default)]
381 pub dshape: Vec<(DimName, usize)>,
382 }
383
384 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
385 pub struct Mask {
386 pub decoder: DecoderType,
387 pub quantization: Option<QuantTuple>,
388 pub shape: Vec<usize>,
389 #[serde(default)]
392 pub dshape: Vec<(DimName, usize)>,
393 }
394
395 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
396 pub struct Detection {
397 pub anchors: Option<Vec<[f32; 2]>>,
398 pub decoder: DecoderType,
399 pub quantization: Option<QuantTuple>,
400 pub shape: Vec<usize>,
401 #[serde(default)]
404 pub dshape: Vec<(DimName, usize)>,
405 #[serde(default)]
412 pub normalized: Option<bool>,
413 }
414
415 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
416 pub struct Scores {
417 pub decoder: DecoderType,
418 pub quantization: Option<QuantTuple>,
419 pub shape: Vec<usize>,
420 #[serde(default)]
423 pub dshape: Vec<(DimName, usize)>,
424 }
425
426 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
427 pub struct Boxes {
428 pub decoder: DecoderType,
429 pub quantization: Option<QuantTuple>,
430 pub shape: Vec<usize>,
431 #[serde(default)]
434 pub dshape: Vec<(DimName, usize)>,
435 #[serde(default)]
442 pub normalized: Option<bool>,
443 }
444
445 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
446 pub enum DimName {
447 #[serde(rename = "batch")]
448 Batch,
449 #[serde(rename = "height")]
450 Height,
451 #[serde(rename = "width")]
452 Width,
453 #[serde(rename = "num_classes")]
454 NumClasses,
455 #[serde(rename = "num_features")]
456 NumFeatures,
457 #[serde(rename = "num_boxes")]
458 NumBoxes,
459 #[serde(rename = "num_protos")]
460 NumProtos,
461 #[serde(rename = "num_anchors_x_features")]
462 NumAnchorsXFeatures,
463 #[serde(rename = "padding")]
464 Padding,
465 #[serde(rename = "box_coords")]
466 BoxCoords,
467 }
468
469 impl Display for DimName {
470 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
480 match self {
481 DimName::Batch => write!(f, "batch"),
482 DimName::Height => write!(f, "height"),
483 DimName::Width => write!(f, "width"),
484 DimName::NumClasses => write!(f, "num_classes"),
485 DimName::NumFeatures => write!(f, "num_features"),
486 DimName::NumBoxes => write!(f, "num_boxes"),
487 DimName::NumProtos => write!(f, "num_protos"),
488 DimName::NumAnchorsXFeatures => write!(f, "num_anchors_x_features"),
489 DimName::Padding => write!(f, "padding"),
490 DimName::BoxCoords => write!(f, "box_coords"),
491 }
492 }
493 }
494
495 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
496 pub enum DecoderType {
497 #[serde(rename = "modelpack")]
498 ModelPack,
499 #[serde(rename = "ultralytics")]
500 Ultralytics,
501 }
502
503 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
515 #[serde(rename_all = "lowercase")]
516 pub enum DecoderVersion {
517 #[serde(rename = "yolov5")]
519 Yolov5,
520 #[serde(rename = "yolov8")]
522 Yolov8,
523 #[serde(rename = "yolo11")]
525 Yolo11,
526 #[serde(rename = "yolo26")]
529 Yolo26,
530 }
531
532 impl DecoderVersion {
533 pub fn is_end_to_end(&self) -> bool {
536 matches!(self, DecoderVersion::Yolo26)
537 }
538 }
539
540 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq, Default)]
549 #[serde(rename_all = "snake_case")]
550 pub enum Nms {
551 #[default]
554 ClassAgnostic,
555 ClassAware,
557 }
558
559 #[derive(Debug, Clone, PartialEq)]
560 pub enum ModelType {
561 ModelPackSegDet {
562 boxes: Boxes,
563 scores: Scores,
564 segmentation: Segmentation,
565 },
566 ModelPackSegDetSplit {
567 detection: Vec<Detection>,
568 segmentation: Segmentation,
569 },
570 ModelPackDet {
571 boxes: Boxes,
572 scores: Scores,
573 },
574 ModelPackDetSplit {
575 detection: Vec<Detection>,
576 },
577 ModelPackSeg {
578 segmentation: Segmentation,
579 },
580 YoloDet {
581 boxes: Detection,
582 },
583 YoloSegDet {
584 boxes: Detection,
585 protos: Protos,
586 },
587 YoloSplitDet {
588 boxes: Boxes,
589 scores: Scores,
590 },
591 YoloSplitSegDet {
592 boxes: Boxes,
593 scores: Scores,
594 mask_coeff: MaskCoefficients,
595 protos: Protos,
596 },
597 YoloEndToEndDet {
601 boxes: Detection,
602 },
603 YoloEndToEndSegDet {
607 boxes: Detection,
608 protos: Protos,
609 },
610 }
611
612 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
613 #[serde(rename_all = "lowercase")]
614 pub enum DataType {
615 Raw = 0,
616 Int8 = 1,
617 UInt8 = 2,
618 Int16 = 3,
619 UInt16 = 4,
620 Float16 = 5,
621 Int32 = 6,
622 UInt32 = 7,
623 Float32 = 8,
624 Int64 = 9,
625 UInt64 = 10,
626 Float64 = 11,
627 String = 12,
628 }
629}
630
/// Builder that assembles a `Decoder` from an output configuration and
/// post-processing thresholds.
#[derive(Debug, Clone, PartialEq)]
pub struct DecoderBuilder {
    // Where `build` reads the configuration from; `None` makes `build` fail
    // with `DecoderError::NoConfig`.
    config_src: Option<ConfigSource>,
    // IoU threshold passed to the decoder (0.5 by default).
    iou_threshold: f32,
    // Score threshold passed to the decoder (0.5 by default).
    score_threshold: f32,
    // NMS mode used when the config itself does not set `nms`.
    nms: Option<configs::Nms>,
}
640
/// Source of the output configuration consumed by `DecoderBuilder::build`.
#[derive(Debug, Clone, PartialEq)]
enum ConfigSource {
    // YAML text, parsed with `serde_yaml` at build time.
    Yaml(String),
    // JSON text, parsed with `serde_json` at build time.
    Json(String),
    // An already-constructed configuration, used as-is.
    Config(ConfigOutputs),
}
647
648impl Default for DecoderBuilder {
649 fn default() -> Self {
669 Self {
670 config_src: None,
671 iou_threshold: 0.5,
672 score_threshold: 0.5,
673 nms: Some(configs::Nms::ClassAgnostic),
674 }
675 }
676}
677
678impl DecoderBuilder {
679 pub fn new() -> Self {
699 Self::default()
700 }
701
702 pub fn with_config_yaml_str(mut self, yaml_str: String) -> Self {
719 self.config_src.replace(ConfigSource::Yaml(yaml_str));
720 self
721 }
722
723 pub fn with_config_json_str(mut self, json_str: String) -> Self {
740 self.config_src.replace(ConfigSource::Json(json_str));
741 self
742 }
743
744 pub fn with_config(mut self, config: ConfigOutputs) -> Self {
761 self.config_src.replace(ConfigSource::Config(config));
762 self
763 }
764
765 pub fn with_config_yolo_det(
790 mut self,
791 boxes: configs::Detection,
792 version: Option<DecoderVersion>,
793 ) -> Self {
794 let config = ConfigOutputs {
795 outputs: vec![ConfigOutput::Detection(boxes)],
796 decoder_version: version,
797 ..Default::default()
798 };
799 self.config_src.replace(ConfigSource::Config(config));
800 self
801 }
802
803 pub fn with_config_yolo_split_det(
830 mut self,
831 boxes: configs::Boxes,
832 scores: configs::Scores,
833 ) -> Self {
834 let config = ConfigOutputs {
835 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
836 ..Default::default()
837 };
838 self.config_src.replace(ConfigSource::Config(config));
839 self
840 }
841
842 pub fn with_config_yolo_segdet(
874 mut self,
875 boxes: configs::Detection,
876 protos: configs::Protos,
877 version: Option<DecoderVersion>,
878 ) -> Self {
879 let config = ConfigOutputs {
880 outputs: vec![ConfigOutput::Detection(boxes), ConfigOutput::Protos(protos)],
881 decoder_version: version,
882 ..Default::default()
883 };
884 self.config_src.replace(ConfigSource::Config(config));
885 self
886 }
887
888 pub fn with_config_yolo_split_segdet(
927 mut self,
928 boxes: configs::Boxes,
929 scores: configs::Scores,
930 mask_coefficients: configs::MaskCoefficients,
931 protos: configs::Protos,
932 ) -> Self {
933 let config = ConfigOutputs {
934 outputs: vec![
935 ConfigOutput::Boxes(boxes),
936 ConfigOutput::Scores(scores),
937 ConfigOutput::MaskCoefficients(mask_coefficients),
938 ConfigOutput::Protos(protos),
939 ],
940 ..Default::default()
941 };
942 self.config_src.replace(ConfigSource::Config(config));
943 self
944 }
945
946 pub fn with_config_modelpack_det(
973 mut self,
974 boxes: configs::Boxes,
975 scores: configs::Scores,
976 ) -> Self {
977 let config = ConfigOutputs {
978 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
979 ..Default::default()
980 };
981 self.config_src.replace(ConfigSource::Config(config));
982 self
983 }
984
985 pub fn with_config_modelpack_det_split(mut self, boxes: Vec<configs::Detection>) -> Self {
1024 let outputs = boxes.into_iter().map(ConfigOutput::Detection).collect();
1025 let config = ConfigOutputs {
1026 outputs,
1027 ..Default::default()
1028 };
1029 self.config_src.replace(ConfigSource::Config(config));
1030 self
1031 }
1032
1033 pub fn with_config_modelpack_segdet(
1066 mut self,
1067 boxes: configs::Boxes,
1068 scores: configs::Scores,
1069 segmentation: configs::Segmentation,
1070 ) -> Self {
1071 let config = ConfigOutputs {
1072 outputs: vec![
1073 ConfigOutput::Boxes(boxes),
1074 ConfigOutput::Scores(scores),
1075 ConfigOutput::Segmentation(segmentation),
1076 ],
1077 ..Default::default()
1078 };
1079 self.config_src.replace(ConfigSource::Config(config));
1080 self
1081 }
1082
1083 pub fn with_config_modelpack_segdet_split(
1127 mut self,
1128 boxes: Vec<configs::Detection>,
1129 segmentation: configs::Segmentation,
1130 ) -> Self {
1131 let mut outputs = boxes
1132 .into_iter()
1133 .map(ConfigOutput::Detection)
1134 .collect::<Vec<_>>();
1135 outputs.push(ConfigOutput::Segmentation(segmentation));
1136 let config = ConfigOutputs {
1137 outputs,
1138 ..Default::default()
1139 };
1140 self.config_src.replace(ConfigSource::Config(config));
1141 self
1142 }
1143
1144 pub fn with_config_modelpack_seg(mut self, segmentation: configs::Segmentation) -> Self {
1164 let config = ConfigOutputs {
1165 outputs: vec![ConfigOutput::Segmentation(segmentation)],
1166 ..Default::default()
1167 };
1168 self.config_src.replace(ConfigSource::Config(config));
1169 self
1170 }
1171
1172 pub fn with_score_threshold(mut self, score_threshold: f32) -> Self {
1188 self.score_threshold = score_threshold;
1189 self
1190 }
1191
1192 pub fn with_iou_threshold(mut self, iou_threshold: f32) -> Self {
1209 self.iou_threshold = iou_threshold;
1210 self
1211 }
1212
1213 pub fn with_nms(mut self, nms: Option<configs::Nms>) -> Self {
1235 self.nms = nms;
1236 self
1237 }
1238
1239 pub fn build(self) -> Result<Decoder, DecoderError> {
1256 let config = match self.config_src {
1257 Some(ConfigSource::Json(s)) => serde_json::from_str(&s)?,
1258 Some(ConfigSource::Yaml(s)) => serde_yaml::from_str(&s)?,
1259 Some(ConfigSource::Config(c)) => c,
1260 None => return Err(DecoderError::NoConfig),
1261 };
1262
1263 let normalized = Self::get_normalized(&config.outputs);
1265
1266 let nms = config.nms.or(self.nms);
1268 let model_type = Self::get_model_type(config)?;
1269
1270 Ok(Decoder {
1271 model_type,
1272 iou_threshold: self.iou_threshold,
1273 score_threshold: self.score_threshold,
1274 nms,
1275 normalized,
1276 })
1277 }
1278
1279 fn get_normalized(outputs: &[ConfigOutput]) -> Option<bool> {
1284 for output in outputs {
1285 match output {
1286 ConfigOutput::Detection(det) => return det.normalized,
1287 ConfigOutput::Boxes(boxes) => return boxes.normalized,
1288 _ => {}
1289 }
1290 }
1291 None }
1293
1294 fn get_model_type(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1295 let mut yolo = false;
1297 let mut modelpack = false;
1298 for c in &configs.outputs {
1299 match c.decoder() {
1300 DecoderType::ModelPack => modelpack = true,
1301 DecoderType::Ultralytics => yolo = true,
1302 }
1303 }
1304 match (modelpack, yolo) {
1305 (true, true) => Err(DecoderError::InvalidConfig(
1306 "Both ModelPack and Yolo outputs found in config".to_string(),
1307 )),
1308 (true, false) => Self::get_model_type_modelpack(configs),
1309 (false, true) => Self::get_model_type_yolo(configs),
1310 (false, false) => Err(DecoderError::InvalidConfig(
1311 "No outputs found in config".to_string(),
1312 )),
1313 }
1314 }
1315
1316 fn get_model_type_yolo(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1317 let mut boxes = None;
1318 let mut protos = None;
1319 let mut split_boxes = None;
1320 let mut split_scores = None;
1321 let mut split_mask_coeff = None;
1322 for c in configs.outputs {
1323 match c {
1324 ConfigOutput::Detection(detection) => boxes = Some(detection),
1325 ConfigOutput::Segmentation(_) => {
1326 return Err(DecoderError::InvalidConfig(
1327 "Invalid Segmentation output with Yolo decoder".to_string(),
1328 ));
1329 }
1330 ConfigOutput::Protos(protos_) => protos = Some(protos_),
1331 ConfigOutput::Mask(_) => {
1332 return Err(DecoderError::InvalidConfig(
1333 "Invalid Mask output with Yolo decoder".to_string(),
1334 ));
1335 }
1336 ConfigOutput::Scores(scores) => split_scores = Some(scores),
1337 ConfigOutput::Boxes(boxes) => split_boxes = Some(boxes),
1338 ConfigOutput::MaskCoefficients(mask_coeff) => split_mask_coeff = Some(mask_coeff),
1339 }
1340 }
1341
1342 let is_end_to_end_dshape = boxes.as_ref().is_some_and(|b| {
1347 let dims = b.dshape.iter().map(|(d, _)| *d).collect::<Vec<_>>();
1348 dims == vec![DimName::Batch, DimName::NumBoxes, DimName::NumFeatures]
1349 });
1350
1351 let is_end_to_end = configs
1352 .decoder_version
1353 .map(|v| v.is_end_to_end())
1354 .unwrap_or(is_end_to_end_dshape);
1355
1356 if is_end_to_end {
1357 if let Some(boxes) = boxes {
1358 if let Some(protos) = protos {
1359 Self::verify_yolo_seg_det_26(&boxes, &protos)?;
1360 return Ok(ModelType::YoloEndToEndSegDet { boxes, protos });
1361 } else {
1362 Self::verify_yolo_det_26(&boxes)?;
1363 return Ok(ModelType::YoloEndToEndDet { boxes });
1364 }
1365 } else {
1366 return Err(DecoderError::InvalidConfig(
1367 "Invalid Yolo end-to-end model outputs".to_string(),
1368 ));
1369 }
1370 }
1371
1372 if let Some(boxes) = boxes {
1373 if let Some(protos) = protos {
1374 Self::verify_yolo_seg_det(&boxes, &protos)?;
1375 Ok(ModelType::YoloSegDet { boxes, protos })
1376 } else {
1377 Self::verify_yolo_det(&boxes)?;
1378 Ok(ModelType::YoloDet { boxes })
1379 }
1380 } else if let (Some(boxes), Some(scores)) = (split_boxes, split_scores) {
1381 if let (Some(mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
1382 Self::verify_yolo_split_segdet(&boxes, &scores, &mask_coeff, &protos)?;
1383 Ok(ModelType::YoloSplitSegDet {
1384 boxes,
1385 scores,
1386 mask_coeff,
1387 protos,
1388 })
1389 } else {
1390 Self::verify_yolo_split_det(&boxes, &scores)?;
1391 Ok(ModelType::YoloSplitDet { boxes, scores })
1392 }
1393 } else {
1394 Err(DecoderError::InvalidConfig(
1395 "Invalid Yolo model outputs".to_string(),
1396 ))
1397 }
1398 }
1399
1400 fn verify_yolo_det(detect: &configs::Detection) -> Result<(), DecoderError> {
1401 if detect.shape.len() != 3 {
1402 return Err(DecoderError::InvalidConfig(format!(
1403 "Invalid Yolo Detection shape {:?}",
1404 detect.shape
1405 )));
1406 }
1407
1408 Self::verify_dshapes(
1409 &detect.dshape,
1410 &detect.shape,
1411 "Detection",
1412 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1413 )?;
1414 if !detect.dshape.is_empty() {
1415 Self::get_class_count(&detect.dshape, None, None)?;
1416 } else {
1417 Self::get_class_count_no_dshape(detect.into(), None)?;
1418 }
1419
1420 Ok(())
1421 }
1422
1423 fn verify_yolo_det_26(detect: &configs::Detection) -> Result<(), DecoderError> {
1424 if detect.shape.len() != 3 {
1425 return Err(DecoderError::InvalidConfig(format!(
1426 "Invalid Yolo Detection shape {:?}",
1427 detect.shape
1428 )));
1429 }
1430
1431 Self::verify_dshapes(
1432 &detect.dshape,
1433 &detect.shape,
1434 "Detection",
1435 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1436 )?;
1437
1438 if !detect.shape.contains(&6) {
1439 return Err(DecoderError::InvalidConfig(
1440 "Yolo26 Detection must have 6 features".to_string(),
1441 ));
1442 }
1443
1444 Ok(())
1445 }
1446
1447 fn verify_yolo_seg_det(
1448 detection: &configs::Detection,
1449 protos: &configs::Protos,
1450 ) -> Result<(), DecoderError> {
1451 if detection.shape.len() != 3 {
1452 return Err(DecoderError::InvalidConfig(format!(
1453 "Invalid Yolo Detection shape {:?}",
1454 detection.shape
1455 )));
1456 }
1457 if protos.shape.len() != 4 {
1458 return Err(DecoderError::InvalidConfig(format!(
1459 "Invalid Yolo Protos shape {:?}",
1460 protos.shape
1461 )));
1462 }
1463
1464 Self::verify_dshapes(
1465 &detection.dshape,
1466 &detection.shape,
1467 "Detection",
1468 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1469 )?;
1470 Self::verify_dshapes(
1471 &protos.dshape,
1472 &protos.shape,
1473 "Protos",
1474 &[
1475 DimName::Batch,
1476 DimName::Height,
1477 DimName::Width,
1478 DimName::NumProtos,
1479 ],
1480 )?;
1481
1482 let protos_count = Self::get_protos_count(&protos.dshape).unwrap_or(protos.shape[3]);
1483 log::debug!("Protos count: {}", protos_count);
1484 log::debug!("Detection dshape: {:?}", detection.dshape);
1485 let classes = if !detection.dshape.is_empty() {
1486 Self::get_class_count(&detection.dshape, Some(protos_count), None)?
1487 } else {
1488 Self::get_class_count_no_dshape(detection.into(), Some(protos_count))?
1489 };
1490
1491 if classes == 0 {
1492 return Err(DecoderError::InvalidConfig(
1493 "Yolo Segmentation Detection has zero classes".to_string(),
1494 ));
1495 }
1496
1497 Ok(())
1498 }
1499
1500 fn verify_yolo_seg_det_26(
1501 detection: &configs::Detection,
1502 protos: &configs::Protos,
1503 ) -> Result<(), DecoderError> {
1504 if detection.shape.len() != 3 {
1505 return Err(DecoderError::InvalidConfig(format!(
1506 "Invalid Yolo Detection shape {:?}",
1507 detection.shape
1508 )));
1509 }
1510 if protos.shape.len() != 4 {
1511 return Err(DecoderError::InvalidConfig(format!(
1512 "Invalid Yolo Protos shape {:?}",
1513 protos.shape
1514 )));
1515 }
1516
1517 Self::verify_dshapes(
1518 &detection.dshape,
1519 &detection.shape,
1520 "Detection",
1521 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1522 )?;
1523 Self::verify_dshapes(
1524 &protos.dshape,
1525 &protos.shape,
1526 "Protos",
1527 &[
1528 DimName::Batch,
1529 DimName::Height,
1530 DimName::Width,
1531 DimName::NumProtos,
1532 ],
1533 )?;
1534
1535 let protos_count = Self::get_protos_count(&protos.dshape).unwrap_or(protos.shape[3]);
1536 log::debug!("Protos count: {}", protos_count);
1537 log::debug!("Detection dshape: {:?}", detection.dshape);
1538
1539 if !detection.shape.contains(&(6 + protos_count)) {
1540 return Err(DecoderError::InvalidConfig(format!(
1541 "Yolo26 Segmentation Detection must have num_features be 6 + num_protos = {}",
1542 6 + protos_count
1543 )));
1544 }
1545
1546 Ok(())
1547 }
1548
1549 fn verify_yolo_split_det(
1550 boxes: &configs::Boxes,
1551 scores: &configs::Scores,
1552 ) -> Result<(), DecoderError> {
1553 if boxes.shape.len() != 3 {
1554 return Err(DecoderError::InvalidConfig(format!(
1555 "Invalid Yolo Split Boxes shape {:?}",
1556 boxes.shape
1557 )));
1558 }
1559 if scores.shape.len() != 3 {
1560 return Err(DecoderError::InvalidConfig(format!(
1561 "Invalid Yolo Split Scores shape {:?}",
1562 scores.shape
1563 )));
1564 }
1565
1566 Self::verify_dshapes(
1567 &boxes.dshape,
1568 &boxes.shape,
1569 "Boxes",
1570 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1571 )?;
1572 Self::verify_dshapes(
1573 &scores.dshape,
1574 &scores.shape,
1575 "Scores",
1576 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1577 )?;
1578
1579 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1580 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1581
1582 if boxes_num != scores_num {
1583 return Err(DecoderError::InvalidConfig(format!(
1584 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1585 boxes_num, scores_num
1586 )));
1587 }
1588
1589 Ok(())
1590 }
1591
1592 fn verify_yolo_split_segdet(
1593 boxes: &configs::Boxes,
1594 scores: &configs::Scores,
1595 mask_coeff: &configs::MaskCoefficients,
1596 protos: &configs::Protos,
1597 ) -> Result<(), DecoderError> {
1598 if boxes.shape.len() != 3 {
1599 return Err(DecoderError::InvalidConfig(format!(
1600 "Invalid Yolo Split Boxes shape {:?}",
1601 boxes.shape
1602 )));
1603 }
1604 if scores.shape.len() != 3 {
1605 return Err(DecoderError::InvalidConfig(format!(
1606 "Invalid Yolo Split Scores shape {:?}",
1607 scores.shape
1608 )));
1609 }
1610
1611 if mask_coeff.shape.len() != 3 {
1612 return Err(DecoderError::InvalidConfig(format!(
1613 "Invalid Yolo Split Mask Coefficients shape {:?}",
1614 mask_coeff.shape
1615 )));
1616 }
1617
1618 if protos.shape.len() != 4 {
1619 return Err(DecoderError::InvalidConfig(format!(
1620 "Invalid Yolo Protos shape {:?}",
1621 mask_coeff.shape
1622 )));
1623 }
1624
1625 Self::verify_dshapes(
1626 &boxes.dshape,
1627 &boxes.shape,
1628 "Boxes",
1629 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1630 )?;
1631 Self::verify_dshapes(
1632 &scores.dshape,
1633 &scores.shape,
1634 "Scores",
1635 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1636 )?;
1637 Self::verify_dshapes(
1638 &mask_coeff.dshape,
1639 &mask_coeff.shape,
1640 "Mask Coefficients",
1641 &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1642 )?;
1643 Self::verify_dshapes(
1644 &protos.dshape,
1645 &protos.shape,
1646 "Protos",
1647 &[
1648 DimName::Batch,
1649 DimName::Height,
1650 DimName::Width,
1651 DimName::NumProtos,
1652 ],
1653 )?;
1654
1655 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1656 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1657 let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1658
1659 let mask_channels = if !mask_coeff.dshape.is_empty() {
1660 Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1661 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1662 })?
1663 } else {
1664 mask_coeff.shape[1]
1665 };
1666 let proto_channels = if !protos.dshape.is_empty() {
1667 Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1668 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1669 })?
1670 } else {
1671 protos.shape[3]
1672 };
1673
1674 if boxes_num != scores_num {
1675 return Err(DecoderError::InvalidConfig(format!(
1676 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1677 boxes_num, scores_num
1678 )));
1679 }
1680
1681 if boxes_num != mask_num {
1682 return Err(DecoderError::InvalidConfig(format!(
1683 "Yolo Split Detection Boxes num {} incompatible with Mask Coefficients num {}",
1684 boxes_num, mask_num
1685 )));
1686 }
1687
1688 if proto_channels != mask_channels {
1689 return Err(DecoderError::InvalidConfig(format!(
1690 "Yolo Protos channels {} incompatible with Mask Coefficients channels {}",
1691 proto_channels, mask_channels
1692 )));
1693 }
1694
1695 Ok(())
1696 }
1697
1698 fn get_model_type_modelpack(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1699 let mut split_decoders = Vec::new();
1700 let mut segment_ = None;
1701 let mut scores_ = None;
1702 let mut boxes_ = None;
1703 for c in configs.outputs {
1704 match c {
1705 ConfigOutput::Detection(detection) => split_decoders.push(detection),
1706 ConfigOutput::Segmentation(segmentation) => segment_ = Some(segmentation),
1707 ConfigOutput::Mask(_) => {}
1708 ConfigOutput::Protos(_) => {
1709 return Err(DecoderError::InvalidConfig(
1710 "ModelPack should not have protos".to_string(),
1711 ));
1712 }
1713 ConfigOutput::Scores(scores) => scores_ = Some(scores),
1714 ConfigOutput::Boxes(boxes) => boxes_ = Some(boxes),
1715 ConfigOutput::MaskCoefficients(_) => {
1716 return Err(DecoderError::InvalidConfig(
1717 "ModelPack should not have mask coefficients".to_string(),
1718 ));
1719 }
1720 }
1721 }
1722
1723 if let Some(segmentation) = segment_ {
1724 if !split_decoders.is_empty() {
1725 let classes = Self::verify_modelpack_split_det(&split_decoders)?;
1726 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1727 Ok(ModelType::ModelPackSegDetSplit {
1728 detection: split_decoders,
1729 segmentation,
1730 })
1731 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1732 let classes = Self::verify_modelpack_det(&boxes, &scores)?;
1733 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1734 Ok(ModelType::ModelPackSegDet {
1735 boxes,
1736 scores,
1737 segmentation,
1738 })
1739 } else {
1740 Self::verify_modelpack_seg(&segmentation, None)?;
1741 Ok(ModelType::ModelPackSeg { segmentation })
1742 }
1743 } else if !split_decoders.is_empty() {
1744 Self::verify_modelpack_split_det(&split_decoders)?;
1745 Ok(ModelType::ModelPackDetSplit {
1746 detection: split_decoders,
1747 })
1748 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1749 Self::verify_modelpack_det(&boxes, &scores)?;
1750 Ok(ModelType::ModelPackDet { boxes, scores })
1751 } else {
1752 Err(DecoderError::InvalidConfig(
1753 "Invalid ModelPack model outputs".to_string(),
1754 ))
1755 }
1756 }
1757
1758 fn verify_modelpack_det(
1759 boxes: &configs::Boxes,
1760 scores: &configs::Scores,
1761 ) -> Result<usize, DecoderError> {
1762 if boxes.shape.len() != 4 {
1763 return Err(DecoderError::InvalidConfig(format!(
1764 "Invalid ModelPack Boxes shape {:?}",
1765 boxes.shape
1766 )));
1767 }
1768 if scores.shape.len() != 3 {
1769 return Err(DecoderError::InvalidConfig(format!(
1770 "Invalid ModelPack Scores shape {:?}",
1771 scores.shape
1772 )));
1773 }
1774
1775 Self::verify_dshapes(
1776 &boxes.dshape,
1777 &boxes.shape,
1778 "Boxes",
1779 &[
1780 DimName::Batch,
1781 DimName::NumBoxes,
1782 DimName::Padding,
1783 DimName::BoxCoords,
1784 ],
1785 )?;
1786 Self::verify_dshapes(
1787 &scores.dshape,
1788 &scores.shape,
1789 "Scores",
1790 &[DimName::Batch, DimName::NumBoxes, DimName::NumClasses],
1791 )?;
1792
1793 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[1]);
1794 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[1]);
1795
1796 if boxes_num != scores_num {
1797 return Err(DecoderError::InvalidConfig(format!(
1798 "ModelPack Detection Boxes num {} incompatible with Scores num {}",
1799 boxes_num, scores_num
1800 )));
1801 }
1802
1803 let num_classes = if !scores.dshape.is_empty() {
1804 Self::get_class_count(&scores.dshape, None, None)?
1805 } else {
1806 Self::get_class_count_no_dshape(scores.into(), None)?
1807 };
1808
1809 Ok(num_classes)
1810 }
1811
1812 fn verify_modelpack_split_det(boxes: &[configs::Detection]) -> Result<usize, DecoderError> {
1813 let mut num_classes = None;
1814 for b in boxes {
1815 let Some(num_anchors) = b.anchors.as_ref().map(|a| a.len()) else {
1816 return Err(DecoderError::InvalidConfig(
1817 "ModelPack Split Detection missing anchors".to_string(),
1818 ));
1819 };
1820
1821 if num_anchors == 0 {
1822 return Err(DecoderError::InvalidConfig(
1823 "ModelPack Split Detection has zero anchors".to_string(),
1824 ));
1825 }
1826
1827 if b.shape.len() != 4 {
1828 return Err(DecoderError::InvalidConfig(format!(
1829 "Invalid ModelPack Split Detection shape {:?}",
1830 b.shape
1831 )));
1832 }
1833
1834 Self::verify_dshapes(
1835 &b.dshape,
1836 &b.shape,
1837 "Split Detection",
1838 &[
1839 DimName::Batch,
1840 DimName::Height,
1841 DimName::Width,
1842 DimName::NumAnchorsXFeatures,
1843 ],
1844 )?;
1845 let classes = if !b.dshape.is_empty() {
1846 Self::get_class_count(&b.dshape, None, Some(num_anchors))?
1847 } else {
1848 Self::get_class_count_no_dshape(b.into(), None)?
1849 };
1850
1851 match num_classes {
1852 Some(n) => {
1853 if n != classes {
1854 return Err(DecoderError::InvalidConfig(format!(
1855 "ModelPack Split Detection inconsistent number of classes: previous {}, current {}",
1856 n, classes
1857 )));
1858 }
1859 }
1860 None => {
1861 num_classes = Some(classes);
1862 }
1863 }
1864 }
1865
1866 Ok(num_classes.unwrap_or(0))
1867 }
1868
1869 fn verify_modelpack_seg(
1870 segmentation: &configs::Segmentation,
1871 classes: Option<usize>,
1872 ) -> Result<(), DecoderError> {
1873 if segmentation.shape.len() != 4 {
1874 return Err(DecoderError::InvalidConfig(format!(
1875 "Invalid ModelPack Segmentation shape {:?}",
1876 segmentation.shape
1877 )));
1878 }
1879 Self::verify_dshapes(
1880 &segmentation.dshape,
1881 &segmentation.shape,
1882 "Segmentation",
1883 &[
1884 DimName::Batch,
1885 DimName::Height,
1886 DimName::Width,
1887 DimName::NumClasses,
1888 ],
1889 )?;
1890
1891 if let Some(classes) = classes {
1892 let seg_classes = if !segmentation.dshape.is_empty() {
1893 Self::get_class_count(&segmentation.dshape, None, None)?
1894 } else {
1895 Self::get_class_count_no_dshape(segmentation.into(), None)?
1896 };
1897
1898 if seg_classes != classes + 1 {
1899 return Err(DecoderError::InvalidConfig(format!(
1900 "ModelPack Segmentation channels {} incompatible with number of classes {}",
1901 seg_classes, classes
1902 )));
1903 }
1904 }
1905 Ok(())
1906 }
1907
1908 fn verify_dshapes(
1910 dshape: &[(DimName, usize)],
1911 shape: &[usize],
1912 name: &str,
1913 dims: &[DimName],
1914 ) -> Result<(), DecoderError> {
1915 for s in shape {
1916 if *s == 0 {
1917 return Err(DecoderError::InvalidConfig(format!(
1918 "{} shape has zero dimension",
1919 name
1920 )));
1921 }
1922 }
1923
1924 if shape.len() != dims.len() {
1925 return Err(DecoderError::InvalidConfig(format!(
1926 "{} shape length {} does not match expected dims length {}",
1927 name,
1928 shape.len(),
1929 dims.len()
1930 )));
1931 }
1932
1933 if dshape.is_empty() {
1934 return Ok(());
1935 }
1936 if dshape.len() != shape.len() {
1938 return Err(DecoderError::InvalidConfig(format!(
1939 "{} dshape length does not match shape length",
1940 name
1941 )));
1942 }
1943
1944 for ((dim_name, dim_size), shape_size) in dshape.iter().zip(shape) {
1946 if dim_size != shape_size {
1947 return Err(DecoderError::InvalidConfig(format!(
1948 "{} dshape dimension {} size {} does not match shape size {}",
1949 name, dim_name, dim_size, shape_size
1950 )));
1951 }
1952 if *dim_name == DimName::Padding && *dim_size != 1 {
1953 return Err(DecoderError::InvalidConfig(
1954 "Padding dimension size must be 1".to_string(),
1955 ));
1956 }
1957
1958 if *dim_name == DimName::BoxCoords && *dim_size != 4 {
1959 return Err(DecoderError::InvalidConfig(
1960 "BoxCoords dimension size must be 4".to_string(),
1961 ));
1962 }
1963 }
1964
1965 let dims_present = HashSet::<DimName>::from_iter(dshape.iter().map(|(name, _)| *name));
1966 for dim in dims {
1967 if !dims_present.contains(dim) {
1968 return Err(DecoderError::InvalidConfig(format!(
1969 "{} dshape missing required dimension {:?}",
1970 name, dim
1971 )));
1972 }
1973 }
1974
1975 Ok(())
1976 }
1977
1978 fn get_box_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1979 for (dim_name, dim_size) in dshape {
1980 if *dim_name == DimName::NumBoxes {
1981 return Some(*dim_size);
1982 }
1983 }
1984 None
1985 }
1986
1987 fn get_class_count_no_dshape(
1988 config: ConfigOutputRef,
1989 protos: Option<usize>,
1990 ) -> Result<usize, DecoderError> {
1991 match config {
1992 ConfigOutputRef::Detection(detection) => match detection.decoder {
1993 DecoderType::Ultralytics => {
1994 if detection.shape[1] <= 4 + protos.unwrap_or(0) {
1995 return Err(DecoderError::InvalidConfig(format!(
1996 "Invalid shape: Yolo num_features {} must be greater than {}",
1997 detection.shape[1],
1998 4 + protos.unwrap_or(0),
1999 )));
2000 }
2001 Ok(detection.shape[1] - 4 - protos.unwrap_or(0))
2002 }
2003 DecoderType::ModelPack => {
2004 let Some(num_anchors) = detection.anchors.as_ref().map(|a| a.len()) else {
2005 return Err(DecoderError::Internal(
2006 "ModelPack Detection missing anchors".to_string(),
2007 ));
2008 };
2009 let anchors_x_features = detection.shape[3];
2010 if anchors_x_features <= num_anchors * 5 {
2011 return Err(DecoderError::InvalidConfig(format!(
2012 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2013 anchors_x_features,
2014 num_anchors * 5,
2015 )));
2016 }
2017
2018 if !anchors_x_features.is_multiple_of(num_anchors) {
2019 return Err(DecoderError::InvalidConfig(format!(
2020 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2021 anchors_x_features, num_anchors
2022 )));
2023 }
2024 Ok(anchors_x_features / num_anchors - 5)
2025 }
2026 },
2027
2028 ConfigOutputRef::Scores(scores) => match scores.decoder {
2029 DecoderType::Ultralytics => Ok(scores.shape[1]),
2030 DecoderType::ModelPack => Ok(scores.shape[2]),
2031 },
2032 ConfigOutputRef::Segmentation(seg) => Ok(seg.shape[3]),
2033 _ => Err(DecoderError::Internal(
2034 "Attempted to get class count from unsupported config output".to_owned(),
2035 )),
2036 }
2037 }
2038
2039 fn get_class_count(
2041 dshape: &[(DimName, usize)],
2042 protos: Option<usize>,
2043 anchors: Option<usize>,
2044 ) -> Result<usize, DecoderError> {
2045 if dshape.is_empty() {
2046 return Ok(0);
2047 }
2048 for (dim_name, dim_size) in dshape {
2050 if *dim_name == DimName::NumClasses {
2051 return Ok(*dim_size);
2052 }
2053 }
2054
2055 for (dim_name, dim_size) in dshape {
2058 if *dim_name == DimName::NumFeatures {
2059 let protos = protos.unwrap_or(0);
2060 if protos + 4 >= *dim_size {
2061 return Err(DecoderError::InvalidConfig(format!(
2062 "Invalid shape: Yolo num_features {} must be greater than {}",
2063 *dim_size,
2064 protos + 4,
2065 )));
2066 }
2067 return Ok(*dim_size - 4 - protos);
2068 }
2069 }
2070
2071 if let Some(num_anchors) = anchors {
2074 for (dim_name, dim_size) in dshape {
2075 if *dim_name == DimName::NumAnchorsXFeatures {
2076 let anchors_x_features = *dim_size;
2077 if anchors_x_features <= num_anchors * 5 {
2078 return Err(DecoderError::InvalidConfig(format!(
2079 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2080 anchors_x_features,
2081 num_anchors * 5,
2082 )));
2083 }
2084
2085 if !anchors_x_features.is_multiple_of(num_anchors) {
2086 return Err(DecoderError::InvalidConfig(format!(
2087 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2088 anchors_x_features, num_anchors
2089 )));
2090 }
2091 return Ok((anchors_x_features / num_anchors) - 5);
2092 }
2093 }
2094 }
2095 Err(DecoderError::InvalidConfig(
2096 "Cannot determine number of classes from dshape".to_owned(),
2097 ))
2098 }
2099
2100 fn get_protos_count(dshape: &[(DimName, usize)]) -> Option<usize> {
2101 for (dim_name, dim_size) in dshape {
2102 if *dim_name == DimName::NumProtos {
2103 return Some(*dim_size);
2104 }
2105 }
2106 None
2107 }
2108}
2109
/// Turns raw model output tensors into [`DetectBox`]es and/or
/// [`Segmentation`] masks.
///
/// The decode path is chosen by matching on the configured [`ModelType`] in
/// `decode_quantized` / `decode_float`.
#[derive(Debug, Clone, PartialEq)]
pub struct Decoder {
    /// Output layout this decoder was built for; the decode entry points
    /// match on it to pick a per-model helper.
    model_type: ModelType,
    /// IoU threshold forwarded to every box decoder called from this type.
    pub iou_threshold: f32,
    /// Score threshold forwarded to every box decoder called from this type.
    pub score_threshold: f32,
    /// Optional NMS settings. The YOLO decode helpers here forward it; the
    /// ModelPack ones do not.
    pub nms: Option<configs::Nms>,
    /// Returned as-is by `normalized_boxes()`. NOTE(review): presumably
    /// records whether decoded boxes use normalized coordinates; confirm
    /// where it is set.
    normalized: Option<bool>,
}
2125
/// Dynamically-shaped view over an integer (quantized) tensor whose element
/// type is only known at runtime.
///
/// Each variant wraps an [`ArrayViewD`] of one supported integer type. The
/// `with_quantized!` macro runs the same generic code on whichever variant
/// is present.
#[derive(Debug)]
pub enum ArrayViewDQuantized<'a> {
    /// `u8` elements.
    UInt8(ArrayViewD<'a, u8>),
    /// `i8` elements.
    Int8(ArrayViewD<'a, i8>),
    /// `u16` elements.
    UInt16(ArrayViewD<'a, u16>),
    /// `i16` elements.
    Int16(ArrayViewD<'a, i16>),
    /// `u32` elements.
    UInt32(ArrayViewD<'a, u32>),
    /// `i32` elements.
    Int32(ArrayViewD<'a, i32>),
}
2135
2136impl<'a, D> From<ArrayView<'a, u8, D>> for ArrayViewDQuantized<'a>
2137where
2138 D: Dimension,
2139{
2140 fn from(arr: ArrayView<'a, u8, D>) -> Self {
2141 Self::UInt8(arr.into_dyn())
2142 }
2143}
2144
2145impl<'a, D> From<ArrayView<'a, i8, D>> for ArrayViewDQuantized<'a>
2146where
2147 D: Dimension,
2148{
2149 fn from(arr: ArrayView<'a, i8, D>) -> Self {
2150 Self::Int8(arr.into_dyn())
2151 }
2152}
2153
2154impl<'a, D> From<ArrayView<'a, u16, D>> for ArrayViewDQuantized<'a>
2155where
2156 D: Dimension,
2157{
2158 fn from(arr: ArrayView<'a, u16, D>) -> Self {
2159 Self::UInt16(arr.into_dyn())
2160 }
2161}
2162
2163impl<'a, D> From<ArrayView<'a, i16, D>> for ArrayViewDQuantized<'a>
2164where
2165 D: Dimension,
2166{
2167 fn from(arr: ArrayView<'a, i16, D>) -> Self {
2168 Self::Int16(arr.into_dyn())
2169 }
2170}
2171
2172impl<'a, D> From<ArrayView<'a, u32, D>> for ArrayViewDQuantized<'a>
2173where
2174 D: Dimension,
2175{
2176 fn from(arr: ArrayView<'a, u32, D>) -> Self {
2177 Self::UInt32(arr.into_dyn())
2178 }
2179}
2180
2181impl<'a, D> From<ArrayView<'a, i32, D>> for ArrayViewDQuantized<'a>
2182where
2183 D: Dimension,
2184{
2185 fn from(arr: ArrayView<'a, i32, D>) -> Self {
2186 Self::Int32(arr.into_dyn())
2187 }
2188}
2189
2190impl<'a> ArrayViewDQuantized<'a> {
2191 pub fn shape(&self) -> &[usize] {
2205 match self {
2206 ArrayViewDQuantized::UInt8(a) => a.shape(),
2207 ArrayViewDQuantized::Int8(a) => a.shape(),
2208 ArrayViewDQuantized::UInt16(a) => a.shape(),
2209 ArrayViewDQuantized::Int16(a) => a.shape(),
2210 ArrayViewDQuantized::UInt32(a) => a.shape(),
2211 ArrayViewDQuantized::Int32(a) => a.shape(),
2212 }
2213 }
2214}
2215
// Evaluates `$body` with `$var` bound to the typed payload of whichever
// `ArrayViewDQuantized` variant `$x` holds. `$body` is expanded once per
// variant (so it can call generic code for each element type), and every
// expansion must produce the same type. Passing a reference for `$x` binds
// `$var` by reference through ordinary match ergonomics.
macro_rules! with_quantized {
    ($x:expr, $var:ident, $body:expr) => {
        match $x {
            ArrayViewDQuantized::UInt8($var) => $body,
            ArrayViewDQuantized::Int8($var) => $body,
            ArrayViewDQuantized::UInt16($var) => $body,
            ArrayViewDQuantized::Int16($var) => $body,
            ArrayViewDQuantized::UInt32($var) => $body,
            ArrayViewDQuantized::Int32($var) => $body,
        }
    };
}
2246
2247impl Decoder {
2248 pub fn model_type(&self) -> &ModelType {
2267 &self.model_type
2268 }
2269
2270 pub fn normalized_boxes(&self) -> Option<bool> {
2296 self.normalized
2297 }
2298
2299 pub fn decode_quantized(
2349 &self,
2350 outputs: &[ArrayViewDQuantized],
2351 output_boxes: &mut Vec<DetectBox>,
2352 output_masks: &mut Vec<Segmentation>,
2353 ) -> Result<(), DecoderError> {
2354 output_boxes.clear();
2355 output_masks.clear();
2356 match &self.model_type {
2357 ModelType::ModelPackSegDet {
2358 boxes,
2359 scores,
2360 segmentation,
2361 } => {
2362 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
2363 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
2364 }
2365 ModelType::ModelPackSegDetSplit {
2366 detection,
2367 segmentation,
2368 } => {
2369 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
2370 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
2371 }
2372 ModelType::ModelPackDet { boxes, scores } => {
2373 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)
2374 }
2375 ModelType::ModelPackDetSplit { detection } => {
2376 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)
2377 }
2378 ModelType::ModelPackSeg { segmentation } => {
2379 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
2380 }
2381 ModelType::YoloDet { boxes } => {
2382 self.decode_yolo_det_quantized(outputs, boxes, output_boxes)
2383 }
2384 ModelType::YoloSegDet { boxes, protos } => self.decode_yolo_segdet_quantized(
2385 outputs,
2386 boxes,
2387 protos,
2388 output_boxes,
2389 output_masks,
2390 ),
2391 ModelType::YoloSplitDet { boxes, scores } => {
2392 self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)
2393 }
2394 ModelType::YoloSplitSegDet {
2395 boxes,
2396 scores,
2397 mask_coeff,
2398 protos,
2399 } => self.decode_yolo_split_segdet_quantized(
2400 outputs,
2401 boxes,
2402 scores,
2403 mask_coeff,
2404 protos,
2405 output_boxes,
2406 output_masks,
2407 ),
2408 ModelType::YoloEndToEndDet { .. } | ModelType::YoloEndToEndSegDet { .. } => {
2409 Err(DecoderError::InvalidConfig(
2410 "End-to-end models require float decode, not quantized".to_string(),
2411 ))
2412 }
2413 }
2414 }
2415
2416 pub fn decode_float<T>(
2473 &self,
2474 outputs: &[ArrayViewD<T>],
2475 output_boxes: &mut Vec<DetectBox>,
2476 output_masks: &mut Vec<Segmentation>,
2477 ) -> Result<(), DecoderError>
2478 where
2479 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
2480 f32: AsPrimitive<T>,
2481 {
2482 output_boxes.clear();
2483 output_masks.clear();
2484 match &self.model_type {
2485 ModelType::ModelPackSegDet {
2486 boxes,
2487 scores,
2488 segmentation,
2489 } => {
2490 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
2491 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
2492 }
2493 ModelType::ModelPackSegDetSplit {
2494 detection,
2495 segmentation,
2496 } => {
2497 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
2498 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
2499 }
2500 ModelType::ModelPackDet { boxes, scores } => {
2501 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
2502 }
2503 ModelType::ModelPackDetSplit { detection } => {
2504 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
2505 }
2506 ModelType::ModelPackSeg { segmentation } => {
2507 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
2508 }
2509 ModelType::YoloDet { boxes } => {
2510 self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
2511 }
2512 ModelType::YoloSegDet { boxes, protos } => {
2513 self.decode_yolo_segdet_float(outputs, boxes, protos, output_boxes, output_masks)?;
2514 }
2515 ModelType::YoloSplitDet { boxes, scores } => {
2516 self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
2517 }
2518 ModelType::YoloSplitSegDet {
2519 boxes,
2520 scores,
2521 mask_coeff,
2522 protos,
2523 } => {
2524 self.decode_yolo_split_segdet_float(
2525 outputs,
2526 boxes,
2527 scores,
2528 mask_coeff,
2529 protos,
2530 output_boxes,
2531 output_masks,
2532 )?;
2533 }
2534 ModelType::YoloEndToEndDet { boxes } => {
2535 self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
2536 }
2537 ModelType::YoloEndToEndSegDet { boxes, protos } => {
2538 self.decode_yolo_end_to_end_segdet_float(
2539 outputs,
2540 boxes,
2541 protos,
2542 output_boxes,
2543 output_masks,
2544 )?;
2545 }
2546 }
2547 Ok(())
2548 }
2549
2550 fn decode_modelpack_det_quantized(
2551 &self,
2552 outputs: &[ArrayViewDQuantized],
2553 boxes: &configs::Boxes,
2554 scores: &configs::Scores,
2555 output_boxes: &mut Vec<DetectBox>,
2556 ) -> Result<(), DecoderError> {
2557 let (boxes_tensor, ind) =
2558 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
2559 let (scores_tensor, _) =
2560 Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &[ind])?;
2561 let quant_boxes = boxes
2562 .quantization
2563 .map(Quantization::from)
2564 .unwrap_or_default();
2565 let quant_scores = scores
2566 .quantization
2567 .map(Quantization::from)
2568 .unwrap_or_default();
2569
2570 with_quantized!(boxes_tensor, b, {
2571 with_quantized!(scores_tensor, s, {
2572 let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
2573 let boxes_tensor = boxes_tensor.slice(s![0, .., 0, ..]);
2574
2575 let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
2576 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
2577 decode_modelpack_det(
2578 (boxes_tensor, quant_boxes),
2579 (scores_tensor, quant_scores),
2580 self.score_threshold,
2581 self.iou_threshold,
2582 output_boxes,
2583 );
2584 });
2585 });
2586
2587 Ok(())
2588 }
2589
2590 fn decode_modelpack_seg_quantized(
2591 &self,
2592 outputs: &[ArrayViewDQuantized],
2593 segmentation: &configs::Segmentation,
2594 output_masks: &mut Vec<Segmentation>,
2595 ) -> Result<(), DecoderError> {
2596 let (seg, _) = Self::find_outputs_with_shape_quantized(&segmentation.shape, outputs, &[])?;
2597
2598 macro_rules! modelpack_seg {
2599 ($seg:expr, $body:expr) => {{
2600 let seg = Self::swap_axes_if_needed($seg, segmentation.into());
2601 let seg = seg.slice(s![0, .., .., ..]);
2602 seg.mapv($body)
2603 }};
2604 }
2605 use ArrayViewDQuantized::*;
2606 let seg = match seg {
2607 UInt8(s) => {
2608 modelpack_seg!(s, |x| x)
2609 }
2610 Int8(s) => {
2611 modelpack_seg!(s, |x| (x as i16 + 128) as u8)
2612 }
2613 UInt16(s) => {
2614 modelpack_seg!(s, |x| (x >> 8) as u8)
2615 }
2616 Int16(s) => {
2617 modelpack_seg!(s, |x| ((x as i32 + 32768) >> 8) as u8)
2618 }
2619 UInt32(s) => {
2620 modelpack_seg!(s, |x| (x >> 24) as u8)
2621 }
2622 Int32(s) => {
2623 modelpack_seg!(s, |x| ((x as i64 + 2147483648) >> 24) as u8)
2624 }
2625 };
2626
2627 output_masks.push(Segmentation {
2628 xmin: 0.0,
2629 ymin: 0.0,
2630 xmax: 1.0,
2631 ymax: 1.0,
2632 segmentation: seg,
2633 });
2634 Ok(())
2635 }
2636
2637 fn decode_modelpack_det_split_quantized(
2638 &self,
2639 outputs: &[ArrayViewDQuantized],
2640 detection: &[configs::Detection],
2641 output_boxes: &mut Vec<DetectBox>,
2642 ) -> Result<(), DecoderError> {
2643 let new_detection = detection
2644 .iter()
2645 .map(|x| match &x.anchors {
2646 None => Err(DecoderError::InvalidConfig(
2647 "ModelPack Split Detection missing anchors".to_string(),
2648 )),
2649 Some(a) => Ok(ModelPackDetectionConfig {
2650 anchors: a.clone(),
2651 quantization: None,
2652 }),
2653 })
2654 .collect::<Result<Vec<_>, _>>()?;
2655 let new_outputs = Self::match_outputs_to_detect_quantized(detection, outputs)?;
2656
2657 macro_rules! dequant_output {
2658 ($det_tensor:expr, $detection:expr) => {{
2659 let det_tensor = Self::swap_axes_if_needed($det_tensor, $detection.into());
2660 let det_tensor = det_tensor.slice(s![0, .., .., ..]);
2661 if let Some(q) = $detection.quantization {
2662 dequantize_ndarray(det_tensor, q.into())
2663 } else {
2664 det_tensor.map(|x| *x as f32)
2665 }
2666 }};
2667 }
2668
2669 let new_outputs = new_outputs
2670 .iter()
2671 .zip(detection)
2672 .map(|(det_tensor, detection)| {
2673 with_quantized!(det_tensor, d, dequant_output!(d, detection))
2674 })
2675 .collect::<Vec<_>>();
2676
2677 let new_outputs_view = new_outputs
2678 .iter()
2679 .map(|d: &Array3<f32>| d.view())
2680 .collect::<Vec<_>>();
2681 decode_modelpack_split_float(
2682 &new_outputs_view,
2683 &new_detection,
2684 self.score_threshold,
2685 self.iou_threshold,
2686 output_boxes,
2687 );
2688 Ok(())
2689 }
2690
2691 fn decode_yolo_det_quantized(
2692 &self,
2693 outputs: &[ArrayViewDQuantized],
2694 boxes: &configs::Detection,
2695 output_boxes: &mut Vec<DetectBox>,
2696 ) -> Result<(), DecoderError> {
2697 let (boxes_tensor, _) =
2698 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
2699 let quant_boxes = boxes
2700 .quantization
2701 .map(Quantization::from)
2702 .unwrap_or_default();
2703
2704 with_quantized!(boxes_tensor, b, {
2705 let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
2706 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
2707 decode_yolo_det(
2708 (boxes_tensor, quant_boxes),
2709 self.score_threshold,
2710 self.iou_threshold,
2711 self.nms,
2712 output_boxes,
2713 );
2714 });
2715
2716 Ok(())
2717 }
2718
2719 fn decode_yolo_segdet_quantized(
2720 &self,
2721 outputs: &[ArrayViewDQuantized],
2722 boxes: &configs::Detection,
2723 protos: &configs::Protos,
2724 output_boxes: &mut Vec<DetectBox>,
2725 output_masks: &mut Vec<Segmentation>,
2726 ) -> Result<(), DecoderError> {
2727 let (boxes_tensor, ind) =
2728 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
2729 let (protos_tensor, _) =
2730 Self::find_outputs_with_shape_quantized(&protos.shape, outputs, &[ind])?;
2731
2732 let quant_boxes = boxes
2733 .quantization
2734 .map(Quantization::from)
2735 .unwrap_or_default();
2736 let quant_protos = protos
2737 .quantization
2738 .map(Quantization::from)
2739 .unwrap_or_default();
2740
2741 with_quantized!(boxes_tensor, b, {
2742 with_quantized!(protos_tensor, p, {
2743 let box_tensor = Self::swap_axes_if_needed(b, boxes.into());
2744 let box_tensor = box_tensor.slice(s![0, .., ..]);
2745
2746 let protos_tensor = Self::swap_axes_if_needed(p, protos.into());
2747 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
2748 decode_yolo_segdet_quant(
2749 (box_tensor, quant_boxes),
2750 (protos_tensor, quant_protos),
2751 self.score_threshold,
2752 self.iou_threshold,
2753 self.nms,
2754 output_boxes,
2755 output_masks,
2756 );
2757 });
2758 });
2759
2760 Ok(())
2761 }
2762
2763 fn decode_yolo_split_det_quantized(
2764 &self,
2765 outputs: &[ArrayViewDQuantized],
2766 boxes: &configs::Boxes,
2767 scores: &configs::Scores,
2768 output_boxes: &mut Vec<DetectBox>,
2769 ) -> Result<(), DecoderError> {
2770 let (boxes_tensor, ind) =
2771 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
2772 let (scores_tensor, _) =
2773 Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &[ind])?;
2774 let quant_boxes = boxes
2775 .quantization
2776 .map(Quantization::from)
2777 .unwrap_or_default();
2778 let quant_scores = scores
2779 .quantization
2780 .map(Quantization::from)
2781 .unwrap_or_default();
2782
2783 with_quantized!(boxes_tensor, b, {
2784 with_quantized!(scores_tensor, s, {
2785 let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
2786 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
2787
2788 let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
2789 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
2790 decode_yolo_split_det_quant(
2791 (boxes_tensor, quant_boxes),
2792 (scores_tensor, quant_scores),
2793 self.score_threshold,
2794 self.iou_threshold,
2795 self.nms,
2796 output_boxes,
2797 );
2798 });
2799 });
2800
2801 Ok(())
2802 }
2803
2804 #[allow(clippy::too_many_arguments)]
2805 fn decode_yolo_split_segdet_quantized(
2806 &self,
2807 outputs: &[ArrayViewDQuantized],
2808 boxes: &configs::Boxes,
2809 scores: &configs::Scores,
2810 mask_coeff: &configs::MaskCoefficients,
2811 protos: &configs::Protos,
2812 output_boxes: &mut Vec<DetectBox>,
2813 output_masks: &mut Vec<Segmentation>,
2814 ) -> Result<(), DecoderError> {
2815 let quant_boxes = boxes
2816 .quantization
2817 .map(Quantization::from)
2818 .unwrap_or_default();
2819 let quant_scores = scores
2820 .quantization
2821 .map(Quantization::from)
2822 .unwrap_or_default();
2823 let quant_masks = mask_coeff
2824 .quantization
2825 .map(Quantization::from)
2826 .unwrap_or_default();
2827 let quant_protos = protos
2828 .quantization
2829 .map(Quantization::from)
2830 .unwrap_or_default();
2831
2832 let mut skip = vec![];
2833
2834 let (boxes_tensor, ind) =
2835 Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &skip)?;
2836 skip.push(ind);
2837
2838 let (scores_tensor, ind) =
2839 Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &skip)?;
2840 skip.push(ind);
2841
2842 let (mask_tensor, ind) =
2843 Self::find_outputs_with_shape_quantized(&mask_coeff.shape, outputs, &skip)?;
2844 skip.push(ind);
2845
2846 let (protos_tensor, _) =
2847 Self::find_outputs_with_shape_quantized(&protos.shape, outputs, &skip)?;
2848
2849 let boxes = with_quantized!(boxes_tensor, b, {
2850 with_quantized!(scores_tensor, s, {
2851 let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
2852 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
2853
2854 let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
2855 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
2856 impl_yolo_split_segdet_quant_get_boxes::<XYWH, _, _>(
2857 (boxes_tensor, quant_boxes),
2858 (scores_tensor, quant_scores),
2859 self.score_threshold,
2860 self.iou_threshold,
2861 self.nms,
2862 output_boxes.capacity(),
2863 )
2864 })
2865 });
2866
2867 with_quantized!(mask_tensor, m, {
2868 with_quantized!(protos_tensor, p, {
2869 let mask_tensor = Self::swap_axes_if_needed(m, mask_coeff.into());
2870 let mask_tensor = mask_tensor.slice(s![0, .., ..]);
2871
2872 let protos_tensor = Self::swap_axes_if_needed(p, protos.into());
2873 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
2874 impl_yolo_split_segdet_quant_process_masks::<_, _>(
2875 boxes,
2876 (mask_tensor, quant_masks),
2877 (protos_tensor, quant_protos),
2878 output_boxes,
2879 output_masks,
2880 )
2881 })
2882 });
2883
2884 Ok(())
2885 }
2886
2887 fn decode_modelpack_det_split_float<D>(
2888 &self,
2889 outputs: &[ArrayViewD<D>],
2890 detection: &[configs::Detection],
2891 output_boxes: &mut Vec<DetectBox>,
2892 ) -> Result<(), DecoderError>
2893 where
2894 D: AsPrimitive<f32>,
2895 {
2896 let new_detection = detection
2897 .iter()
2898 .map(|x| match &x.anchors {
2899 None => Err(DecoderError::InvalidConfig(
2900 "ModelPack Split Detection missing anchors".to_string(),
2901 )),
2902 Some(a) => Ok(ModelPackDetectionConfig {
2903 anchors: a.clone(),
2904 quantization: None,
2905 }),
2906 })
2907 .collect::<Result<Vec<_>, _>>()?;
2908
2909 let new_outputs = Self::match_outputs_to_detect(detection, outputs)?;
2910 let new_outputs = new_outputs
2911 .into_iter()
2912 .map(|x| x.slice(s![0, .., .., ..]))
2913 .collect::<Vec<_>>();
2914
2915 decode_modelpack_split_float(
2916 &new_outputs,
2917 &new_detection,
2918 self.score_threshold,
2919 self.iou_threshold,
2920 output_boxes,
2921 );
2922 Ok(())
2923 }
2924
2925 fn decode_modelpack_seg_float<T>(
2926 &self,
2927 outputs: &[ArrayViewD<T>],
2928 segmentation: &configs::Segmentation,
2929 output_masks: &mut Vec<Segmentation>,
2930 ) -> Result<(), DecoderError>
2931 where
2932 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
2933 f32: AsPrimitive<T>,
2934 {
2935 let (seg, _) = Self::find_outputs_with_shape(&segmentation.shape, outputs, &[])?;
2936
2937 let seg = Self::swap_axes_if_needed(seg, segmentation.into());
2938 let seg = seg.slice(s![0, .., .., ..]);
2939 let u8_max = 255.0_f32.as_();
2940 let max = *seg.max().unwrap_or(&u8_max);
2941 let min = *seg.min().unwrap_or(&0.0_f32.as_());
2942 let seg = seg.mapv(|x| ((x - min) / (max - min) * u8_max).as_());
2943 output_masks.push(Segmentation {
2944 xmin: 0.0,
2945 ymin: 0.0,
2946 xmax: 1.0,
2947 ymax: 1.0,
2948 segmentation: seg,
2949 });
2950 Ok(())
2951 }
2952
2953 fn decode_modelpack_det_float<T>(
2954 &self,
2955 outputs: &[ArrayViewD<T>],
2956 boxes: &configs::Boxes,
2957 scores: &configs::Scores,
2958 output_boxes: &mut Vec<DetectBox>,
2959 ) -> Result<(), DecoderError>
2960 where
2961 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
2962 f32: AsPrimitive<T>,
2963 {
2964 let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
2965
2966 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
2967 let boxes_tensor = boxes_tensor.slice(s![0, .., 0, ..]);
2968
2969 let (scores_tensor, _) = Self::find_outputs_with_shape(&scores.shape, outputs, &[ind])?;
2970 let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
2971 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
2972
2973 decode_modelpack_float(
2974 boxes_tensor,
2975 scores_tensor,
2976 self.score_threshold,
2977 self.iou_threshold,
2978 output_boxes,
2979 );
2980 Ok(())
2981 }
2982
2983 fn decode_yolo_det_float<T>(
2984 &self,
2985 outputs: &[ArrayViewD<T>],
2986 boxes: &configs::Detection,
2987 output_boxes: &mut Vec<DetectBox>,
2988 ) -> Result<(), DecoderError>
2989 where
2990 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
2991 f32: AsPrimitive<T>,
2992 {
2993 let (boxes_tensor, _) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
2994
2995 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
2996 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
2997 decode_yolo_det_float(
2998 boxes_tensor,
2999 self.score_threshold,
3000 self.iou_threshold,
3001 self.nms,
3002 output_boxes,
3003 );
3004 Ok(())
3005 }
3006
3007 fn decode_yolo_segdet_float<T>(
3008 &self,
3009 outputs: &[ArrayViewD<T>],
3010 boxes: &configs::Detection,
3011 protos: &configs::Protos,
3012 output_boxes: &mut Vec<DetectBox>,
3013 output_masks: &mut Vec<Segmentation>,
3014 ) -> Result<(), DecoderError>
3015 where
3016 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3017 f32: AsPrimitive<T>,
3018 {
3019 let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
3020
3021 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3022 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3023
3024 let (protos_tensor, _) = Self::find_outputs_with_shape(&protos.shape, outputs, &[ind])?;
3025
3026 let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos.into());
3027 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3028 decode_yolo_segdet_float(
3029 boxes_tensor,
3030 protos_tensor,
3031 self.score_threshold,
3032 self.iou_threshold,
3033 self.nms,
3034 output_boxes,
3035 output_masks,
3036 );
3037 Ok(())
3038 }
3039
3040 fn decode_yolo_split_det_float<T>(
3041 &self,
3042 outputs: &[ArrayViewD<T>],
3043 boxes: &configs::Boxes,
3044 scores: &configs::Scores,
3045 output_boxes: &mut Vec<DetectBox>,
3046 ) -> Result<(), DecoderError>
3047 where
3048 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3049 f32: AsPrimitive<T>,
3050 {
3051 let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
3052 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3053 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3054
3055 let (scores_tensor, _) = Self::find_outputs_with_shape(&scores.shape, outputs, &[ind])?;
3056
3057 let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
3058 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3059
3060 decode_yolo_split_det_float(
3061 boxes_tensor,
3062 scores_tensor,
3063 self.score_threshold,
3064 self.iou_threshold,
3065 self.nms,
3066 output_boxes,
3067 );
3068 Ok(())
3069 }
3070
3071 #[allow(clippy::too_many_arguments)]
3072 fn decode_yolo_split_segdet_float<T>(
3073 &self,
3074 outputs: &[ArrayViewD<T>],
3075 boxes: &configs::Boxes,
3076 scores: &configs::Scores,
3077 mask_coeff: &configs::MaskCoefficients,
3078 protos: &configs::Protos,
3079 output_boxes: &mut Vec<DetectBox>,
3080 output_masks: &mut Vec<Segmentation>,
3081 ) -> Result<(), DecoderError>
3082 where
3083 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3084 f32: AsPrimitive<T>,
3085 {
3086 let mut skip = vec![];
3087 let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &skip)?;
3088
3089 let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3090 let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3091 skip.push(ind);
3092
3093 let (scores_tensor, ind) = Self::find_outputs_with_shape(&scores.shape, outputs, &skip)?;
3094
3095 let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
3096 let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3097 skip.push(ind);
3098
3099 let (mask_tensor, ind) = Self::find_outputs_with_shape(&mask_coeff.shape, outputs, &skip)?;
3100 let mask_tensor = Self::swap_axes_if_needed(mask_tensor, mask_coeff.into());
3101 let mask_tensor = mask_tensor.slice(s![0, .., ..]);
3102 skip.push(ind);
3103
3104 let (protos_tensor, _) = Self::find_outputs_with_shape(&protos.shape, outputs, &skip)?;
3105 let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos.into());
3106 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3107 decode_yolo_split_segdet_float(
3108 boxes_tensor,
3109 scores_tensor,
3110 mask_tensor,
3111 protos_tensor,
3112 self.score_threshold,
3113 self.iou_threshold,
3114 self.nms,
3115 output_boxes,
3116 output_masks,
3117 );
3118 Ok(())
3119 }
3120
3121 fn decode_yolo_end_to_end_det_float<T>(
3127 &self,
3128 outputs: &[ArrayViewD<T>],
3129 boxes_config: &configs::Detection,
3130 output_boxes: &mut Vec<DetectBox>,
3131 ) -> Result<(), DecoderError>
3132 where
3133 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3134 f32: AsPrimitive<T>,
3135 {
3136 let (det_tensor, _) = Self::find_outputs_with_shape(&boxes_config.shape, outputs, &[])?;
3137 let det_tensor = Self::swap_axes_if_needed(det_tensor, boxes_config.into());
3138 let det_tensor = det_tensor.slice(s![0, .., ..]);
3139
3140 crate::yolo::decode_yolo_end_to_end_det_float(
3141 det_tensor,
3142 self.score_threshold,
3143 output_boxes,
3144 )?;
3145 Ok(())
3146 }
3147
3148 fn decode_yolo_end_to_end_segdet_float<T>(
3156 &self,
3157 outputs: &[ArrayViewD<T>],
3158 boxes_config: &configs::Detection,
3159 protos_config: &configs::Protos,
3160 output_boxes: &mut Vec<DetectBox>,
3161 output_masks: &mut Vec<Segmentation>,
3162 ) -> Result<(), DecoderError>
3163 where
3164 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3165 f32: AsPrimitive<T>,
3166 {
3167 if outputs.len() < 2 {
3168 return Err(DecoderError::InvalidShape(
3169 "End-to-end segdet requires detection and protos outputs".to_string(),
3170 ));
3171 }
3172
3173 let (det_tensor, det_ind) =
3174 Self::find_outputs_with_shape(&boxes_config.shape, outputs, &[])?;
3175 let det_tensor = Self::swap_axes_if_needed(det_tensor, boxes_config.into());
3176 let det_tensor = det_tensor.slice(s![0, .., ..]);
3177
3178 let (protos_tensor, _) =
3179 Self::find_outputs_with_shape(&protos_config.shape, outputs, &[det_ind])?;
3180 let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos_config.into());
3181 let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3182
3183 crate::yolo::decode_yolo_end_to_end_segdet_float(
3184 det_tensor,
3185 protos_tensor,
3186 self.score_threshold,
3187 output_boxes,
3188 output_masks,
3189 )?;
3190 Ok(())
3191 }
3192
3193 fn match_outputs_to_detect<'a, 'b, T>(
3194 configs: &[configs::Detection],
3195 outputs: &'a [ArrayViewD<'b, T>],
3196 ) -> Result<Vec<&'a ArrayViewD<'b, T>>, DecoderError> {
3197 let mut new_output_order = Vec::new();
3198 for c in configs {
3199 let mut found = false;
3200 for o in outputs {
3201 if o.shape() == c.shape {
3202 new_output_order.push(o);
3203 found = true;
3204 break;
3205 }
3206 }
3207 if !found {
3208 return Err(DecoderError::InvalidShape(format!(
3209 "Did not find output with shape {:?}",
3210 c.shape
3211 )));
3212 }
3213 }
3214 Ok(new_output_order)
3215 }
3216
3217 fn find_outputs_with_shape<'a, 'b, T>(
3218 shape: &[usize],
3219 outputs: &'a [ArrayViewD<'b, T>],
3220 skip: &[usize],
3221 ) -> Result<(&'a ArrayViewD<'b, T>, usize), DecoderError> {
3222 for (ind, o) in outputs.iter().enumerate() {
3223 if skip.contains(&ind) {
3224 continue;
3225 }
3226 if o.shape() == shape {
3227 return Ok((o, ind));
3228 }
3229 }
3230 Err(DecoderError::InvalidShape(format!(
3231 "Did not find output with shape {:?}",
3232 shape
3233 )))
3234 }
3235
3236 fn find_outputs_with_shape_quantized<'a, 'b>(
3237 shape: &[usize],
3238 outputs: &'a [ArrayViewDQuantized<'b>],
3239 skip: &[usize],
3240 ) -> Result<(&'a ArrayViewDQuantized<'b>, usize), DecoderError> {
3241 for (ind, o) in outputs.iter().enumerate() {
3242 if skip.contains(&ind) {
3243 continue;
3244 }
3245 if o.shape() == shape {
3246 return Ok((o, ind));
3247 }
3248 }
3249 Err(DecoderError::InvalidShape(format!(
3250 "Did not find output with shape {:?}",
3251 shape
3252 )))
3253 }
3254
3255 fn modelpack_det_order(x: DimName) -> usize {
3258 match x {
3259 DimName::Batch => 0,
3260 DimName::NumBoxes => 1,
3261 DimName::Padding => 2,
3262 DimName::BoxCoords => 3,
3263 _ => 1000, }
3265 }
3266
3267 fn yolo_det_order(x: DimName) -> usize {
3270 match x {
3271 DimName::Batch => 0,
3272 DimName::NumFeatures => 1,
3273 DimName::NumBoxes => 2,
3274 _ => 1000, }
3276 }
3277
3278 fn modelpack_boxes_order(x: DimName) -> usize {
3281 match x {
3282 DimName::Batch => 0,
3283 DimName::NumBoxes => 1,
3284 DimName::Padding => 2,
3285 DimName::BoxCoords => 3,
3286 _ => 1000, }
3288 }
3289
3290 fn yolo_boxes_order(x: DimName) -> usize {
3293 match x {
3294 DimName::Batch => 0,
3295 DimName::BoxCoords => 1,
3296 DimName::NumBoxes => 2,
3297 _ => 1000, }
3299 }
3300
3301 fn modelpack_scores_order(x: DimName) -> usize {
3304 match x {
3305 DimName::Batch => 0,
3306 DimName::NumBoxes => 1,
3307 DimName::NumClasses => 2,
3308 _ => 1000, }
3310 }
3311
3312 fn yolo_scores_order(x: DimName) -> usize {
3313 match x {
3314 DimName::Batch => 0,
3315 DimName::NumClasses => 1,
3316 DimName::NumBoxes => 2,
3317 _ => 1000, }
3319 }
3320
3321 fn modelpack_segmentation_order(x: DimName) -> usize {
3324 match x {
3325 DimName::Batch => 0,
3326 DimName::Height => 1,
3327 DimName::Width => 2,
3328 DimName::NumClasses => 3,
3329 _ => 1000, }
3331 }
3332
3333 fn modelpack_mask_order(x: DimName) -> usize {
3336 match x {
3337 DimName::Batch => 0,
3338 DimName::Height => 1,
3339 DimName::Width => 2,
3340 _ => 1000, }
3342 }
3343
3344 fn yolo_protos_order(x: DimName) -> usize {
3347 match x {
3348 DimName::Batch => 0,
3349 DimName::Height => 1,
3350 DimName::Width => 2,
3351 DimName::NumProtos => 3,
3352 _ => 1000, }
3354 }
3355
3356 fn yolo_maskcoefficients_order(x: DimName) -> usize {
3359 match x {
3360 DimName::Batch => 0,
3361 DimName::NumProtos => 1,
3362 DimName::NumBoxes => 2,
3363 _ => 1000, }
3365 }
3366
3367 fn get_order_fn(config: ConfigOutputRef) -> fn(DimName) -> usize {
3368 let decoder_type = config.decoder();
3369 match (config, decoder_type) {
3370 (ConfigOutputRef::Detection(_), DecoderType::ModelPack) => Self::modelpack_det_order,
3371 (ConfigOutputRef::Detection(_), DecoderType::Ultralytics) => Self::yolo_det_order,
3372 (ConfigOutputRef::Boxes(_), DecoderType::ModelPack) => Self::modelpack_boxes_order,
3373 (ConfigOutputRef::Boxes(_), DecoderType::Ultralytics) => Self::yolo_boxes_order,
3374 (ConfigOutputRef::Scores(_), DecoderType::ModelPack) => Self::modelpack_scores_order,
3375 (ConfigOutputRef::Scores(_), DecoderType::Ultralytics) => Self::yolo_scores_order,
3376 (ConfigOutputRef::Segmentation(_), _) => Self::modelpack_segmentation_order,
3377 (ConfigOutputRef::Mask(_), _) => Self::modelpack_mask_order,
3378 (ConfigOutputRef::Protos(_), _) => Self::yolo_protos_order,
3379 (ConfigOutputRef::MaskCoefficients(_), _) => Self::yolo_maskcoefficients_order,
3380 }
3381 }
3382
3383 fn swap_axes_if_needed<'a, T, D: Dimension>(
3384 array: &ArrayView<'a, T, D>,
3385 config: ConfigOutputRef,
3386 ) -> ArrayView<'a, T, D> {
3387 let mut array = array.clone();
3388 if config.dshape().is_empty() {
3389 return array;
3390 }
3391 let order_fn: fn(DimName) -> usize = Self::get_order_fn(config.clone());
3392 let mut current_order: Vec<usize> = config
3393 .dshape()
3394 .iter()
3395 .map(|x| order_fn(x.0))
3396 .collect::<Vec<_>>();
3397
3398 assert_eq!(array.shape().len(), current_order.len());
3399 for i in 0..current_order.len() {
3402 let mut swapped = false;
3403 for j in 0..current_order.len() - 1 - i {
3404 if current_order[j] > current_order[j + 1] {
3405 array.swap_axes(j, j + 1);
3406 current_order.swap(j, j + 1);
3407 swapped = true;
3408 }
3409 }
3410 if !swapped {
3411 break;
3412 }
3413 }
3414 array
3415 }
3416
3417 fn match_outputs_to_detect_quantized<'a, 'b>(
3418 configs: &[configs::Detection],
3419 outputs: &'a [ArrayViewDQuantized<'b>],
3420 ) -> Result<Vec<&'a ArrayViewDQuantized<'b>>, DecoderError> {
3421 let mut new_output_order = Vec::new();
3422 for c in configs {
3423 let mut found = false;
3424 for o in outputs {
3425 if o.shape() == c.shape {
3426 new_output_order.push(o);
3427 found = true;
3428 break;
3429 }
3430 }
3431 if !found {
3432 return Err(DecoderError::InvalidShape(format!(
3433 "Did not find output with shape {:?}",
3434 c.shape
3435 )));
3436 }
3437 }
3438 Ok(new_output_order)
3439 }
3440}
3441
3442#[cfg(test)]
3443#[cfg_attr(coverage_nightly, coverage(off))]
3444mod decoder_builder_tests {
3445 use super::*;
3446
3447 #[test]
3448 fn test_decoder_builder_no_config() {
3449 use crate::DecoderBuilder;
3450 let result = DecoderBuilder::default().build();
3451 assert!(matches!(result, Err(DecoderError::NoConfig)));
3452 }
3453
3454 #[test]
3455 fn test_decoder_builder_empty_config() {
3456 use crate::DecoderBuilder;
3457 let result = DecoderBuilder::default()
3458 .with_config(ConfigOutputs {
3459 outputs: vec![],
3460 ..Default::default()
3461 })
3462 .build();
3463 assert!(
3464 matches!(result, Err(DecoderError::InvalidConfig(s)) if s == "No outputs found in config")
3465 );
3466 }
3467
    #[test]
    fn test_malformed_config_yaml() {
        // The single output entry carries no `type` key. `ConfigOutput` is internally tagged on
        // `type`, so this entry presumably cannot be deserialized. The builder is expected to
        // surface that as a YAML error.
        let malformed_yaml = "
        model_type: yolov8_det
        outputs:
          - shape: [1, 84, 8400]
        "
        .to_owned();
        let result = DecoderBuilder::new()
            .with_config_yaml_str(malformed_yaml)
            .build();
        assert!(matches!(result, Err(DecoderError::Yaml(_))));
    }
3481
3482 #[test]
3483 fn test_malformed_config_json() {
3484 let malformed_yaml = "
3485 {
3486 \"model_type\": \"yolov8_det\",
3487 \"outputs\": [
3488 {
3489 \"shape\": [1, 84, 8400]
3490 }
3491 ]
3492 }"
3493 .to_owned();
3494 let result = DecoderBuilder::new()
3495 .with_config_json_str(malformed_yaml)
3496 .build();
3497 assert!(matches!(result, Err(DecoderError::Json(_))));
3498 }
3499
3500 #[test]
3501 fn test_modelpack_and_yolo_config_error() {
3502 let result = DecoderBuilder::new()
3503 .with_config_modelpack_det(
3504 configs::Boxes {
3505 decoder: configs::DecoderType::Ultralytics,
3506 shape: vec![1, 4, 8400],
3507 quantization: None,
3508 dshape: vec![
3509 (DimName::Batch, 1),
3510 (DimName::BoxCoords, 4),
3511 (DimName::NumBoxes, 8400),
3512 ],
3513 normalized: Some(true),
3514 },
3515 configs::Scores {
3516 decoder: configs::DecoderType::ModelPack,
3517 shape: vec![1, 80, 8400],
3518 quantization: None,
3519 dshape: vec![
3520 (DimName::Batch, 1),
3521 (DimName::NumClasses, 80),
3522 (DimName::NumBoxes, 8400),
3523 ],
3524 },
3525 )
3526 .build();
3527
3528 assert!(matches!(
3529 result, Err(DecoderError::InvalidConfig(s)) if s == "Both ModelPack and Yolo outputs found in config"
3530 ));
3531 }
3532
3533 #[test]
3534 fn test_yolo_invalid_seg_shape() {
3535 let result = DecoderBuilder::new()
3536 .with_config_yolo_segdet(
3537 configs::Detection {
3538 decoder: configs::DecoderType::Ultralytics,
3539 shape: vec![1, 85, 8400, 1], quantization: None,
3541 anchors: None,
3542 dshape: vec![
3543 (DimName::Batch, 1),
3544 (DimName::NumFeatures, 85),
3545 (DimName::NumBoxes, 8400),
3546 (DimName::Batch, 1),
3547 ],
3548 normalized: Some(true),
3549 },
3550 configs::Protos {
3551 decoder: configs::DecoderType::Ultralytics,
3552 shape: vec![1, 32, 160, 160],
3553 quantization: None,
3554 dshape: vec![
3555 (DimName::Batch, 1),
3556 (DimName::NumProtos, 32),
3557 (DimName::Height, 160),
3558 (DimName::Width, 160),
3559 ],
3560 },
3561 Some(DecoderVersion::Yolo11),
3562 )
3563 .build();
3564
3565 assert!(matches!(
3566 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Detection shape")
3567 ));
3568 }
3569
3570 #[test]
3571 fn test_yolo_invalid_mask() {
3572 let result = DecoderBuilder::new()
3573 .with_config(ConfigOutputs {
3574 outputs: vec![ConfigOutput::Mask(configs::Mask {
3575 shape: vec![1, 160, 160, 1],
3576 decoder: configs::DecoderType::Ultralytics,
3577 quantization: None,
3578 dshape: vec![
3579 (DimName::Batch, 1),
3580 (DimName::Height, 160),
3581 (DimName::Width, 160),
3582 (DimName::NumFeatures, 1),
3583 ],
3584 })],
3585 ..Default::default()
3586 })
3587 .build();
3588
3589 assert!(matches!(
3590 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Mask output with Yolo decoder")
3591 ));
3592 }
3593
3594 #[test]
3595 fn test_yolo_invalid_outputs() {
3596 let result = DecoderBuilder::new()
3597 .with_config(ConfigOutputs {
3598 outputs: vec![ConfigOutput::Segmentation(configs::Segmentation {
3599 shape: vec![1, 84, 8400],
3600 decoder: configs::DecoderType::Ultralytics,
3601 quantization: None,
3602 dshape: vec![
3603 (DimName::Batch, 1),
3604 (DimName::NumFeatures, 84),
3605 (DimName::NumBoxes, 8400),
3606 ],
3607 })],
3608 ..Default::default()
3609 })
3610 .build();
3611
3612 assert!(
3613 matches!(result, Err(DecoderError::InvalidConfig(s)) if s == "Invalid Segmentation output with Yolo decoder")
3614 );
3615 }
3616
3617 #[test]
3618 fn test_yolo_invalid_det() {
3619 let result = DecoderBuilder::new()
3620 .with_config_yolo_det(
3621 configs::Detection {
3622 anchors: None,
3623 decoder: DecoderType::Ultralytics,
3624 quantization: None,
3625 shape: vec![1, 84, 8400, 1], dshape: vec![
3627 (DimName::Batch, 1),
3628 (DimName::NumFeatures, 84),
3629 (DimName::NumBoxes, 8400),
3630 (DimName::Batch, 1),
3631 ],
3632 normalized: Some(true),
3633 },
3634 Some(DecoderVersion::Yolo11),
3635 )
3636 .build();
3637
3638 assert!(matches!(
3639 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Detection shape")));
3640
3641 let result = DecoderBuilder::new()
3642 .with_config_yolo_det(
3643 configs::Detection {
3644 anchors: None,
3645 decoder: DecoderType::Ultralytics,
3646 quantization: None,
3647 shape: vec![1, 8400, 3], dshape: vec![
3649 (DimName::Batch, 1),
3650 (DimName::NumBoxes, 8400),
3651 (DimName::NumFeatures, 3),
3652 ],
3653 normalized: Some(true),
3654 },
3655 Some(DecoderVersion::Yolo11),
3656 )
3657 .build();
3658
3659 assert!(
3660 matches!(
3661 &result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid shape: Yolo num_features 3 must be greater than 4")),
3662 "{}",
3663 result.unwrap_err()
3664 );
3665
3666 let result = DecoderBuilder::new()
3667 .with_config_yolo_det(
3668 configs::Detection {
3669 anchors: None,
3670 decoder: DecoderType::Ultralytics,
3671 quantization: None,
3672 shape: vec![1, 3, 8400], dshape: Vec::new(),
3674 normalized: Some(true),
3675 },
3676 Some(DecoderVersion::Yolo11),
3677 )
3678 .build();
3679
3680 assert!(matches!(
3681 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid shape: Yolo num_features 3 must be greater than 4")));
3682 }
3683
3684 #[test]
3685 fn test_yolo_invalid_segdet() {
3686 let result = DecoderBuilder::new()
3687 .with_config_yolo_segdet(
3688 configs::Detection {
3689 decoder: configs::DecoderType::Ultralytics,
3690 shape: vec![1, 85, 8400, 1], quantization: None,
3692 anchors: None,
3693 dshape: vec![
3694 (DimName::Batch, 1),
3695 (DimName::NumFeatures, 85),
3696 (DimName::NumBoxes, 8400),
3697 (DimName::Batch, 1),
3698 ],
3699 normalized: Some(true),
3700 },
3701 configs::Protos {
3702 decoder: configs::DecoderType::Ultralytics,
3703 shape: vec![1, 32, 160, 160],
3704 quantization: None,
3705 dshape: vec![
3706 (DimName::Batch, 1),
3707 (DimName::NumProtos, 32),
3708 (DimName::Height, 160),
3709 (DimName::Width, 160),
3710 ],
3711 },
3712 Some(DecoderVersion::Yolo11),
3713 )
3714 .build();
3715
3716 assert!(matches!(
3717 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Detection shape")));
3718
3719 let result = DecoderBuilder::new()
3720 .with_config_yolo_segdet(
3721 configs::Detection {
3722 decoder: configs::DecoderType::Ultralytics,
3723 shape: vec![1, 85, 8400],
3724 quantization: None,
3725 anchors: None,
3726 dshape: vec![
3727 (DimName::Batch, 1),
3728 (DimName::NumFeatures, 85),
3729 (DimName::NumBoxes, 8400),
3730 ],
3731 normalized: Some(true),
3732 },
3733 configs::Protos {
3734 decoder: configs::DecoderType::Ultralytics,
3735 shape: vec![1, 32, 160, 160, 1], dshape: vec![
3737 (DimName::Batch, 1),
3738 (DimName::NumProtos, 32),
3739 (DimName::Height, 160),
3740 (DimName::Width, 160),
3741 (DimName::Batch, 1),
3742 ],
3743 quantization: None,
3744 },
3745 Some(DecoderVersion::Yolo11),
3746 )
3747 .build();
3748
3749 assert!(matches!(
3750 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Protos shape")));
3751
3752 let result = DecoderBuilder::new()
3753 .with_config_yolo_segdet(
3754 configs::Detection {
3755 decoder: configs::DecoderType::Ultralytics,
3756 shape: vec![1, 8400, 36], quantization: None,
3758 anchors: None,
3759 dshape: vec![
3760 (DimName::Batch, 1),
3761 (DimName::NumBoxes, 8400),
3762 (DimName::NumFeatures, 36),
3763 ],
3764 normalized: Some(true),
3765 },
3766 configs::Protos {
3767 decoder: configs::DecoderType::Ultralytics,
3768 shape: vec![1, 32, 160, 160],
3769 quantization: None,
3770 dshape: vec![
3771 (DimName::Batch, 1),
3772 (DimName::NumProtos, 32),
3773 (DimName::Height, 160),
3774 (DimName::Width, 160),
3775 ],
3776 },
3777 Some(DecoderVersion::Yolo11),
3778 )
3779 .build();
3780 println!("{:?}", result);
3781 assert!(matches!(
3782 result, Err(DecoderError::InvalidConfig(s)) if s == "Invalid shape: Yolo num_features 36 must be greater than 36"));
3783 }
3784
3785 #[test]
3786 fn test_yolo_invalid_split_det() {
3787 let result = DecoderBuilder::new()
3788 .with_config_yolo_split_det(
3789 configs::Boxes {
3790 decoder: configs::DecoderType::Ultralytics,
3791 shape: vec![1, 4, 8400, 1], quantization: None,
3793 dshape: vec![
3794 (DimName::Batch, 1),
3795 (DimName::BoxCoords, 4),
3796 (DimName::NumBoxes, 8400),
3797 (DimName::Batch, 1),
3798 ],
3799 normalized: Some(true),
3800 },
3801 configs::Scores {
3802 decoder: configs::DecoderType::Ultralytics,
3803 shape: vec![1, 80, 8400],
3804 quantization: None,
3805 dshape: vec![
3806 (DimName::Batch, 1),
3807 (DimName::NumClasses, 80),
3808 (DimName::NumBoxes, 8400),
3809 ],
3810 },
3811 )
3812 .build();
3813
3814 assert!(matches!(
3815 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Boxes shape")));
3816
3817 let result = DecoderBuilder::new()
3818 .with_config_yolo_split_det(
3819 configs::Boxes {
3820 decoder: configs::DecoderType::Ultralytics,
3821 shape: vec![1, 4, 8400],
3822 quantization: None,
3823 dshape: vec![
3824 (DimName::Batch, 1),
3825 (DimName::BoxCoords, 4),
3826 (DimName::NumBoxes, 8400),
3827 ],
3828 normalized: Some(true),
3829 },
3830 configs::Scores {
3831 decoder: configs::DecoderType::Ultralytics,
3832 shape: vec![1, 80, 8400, 1], quantization: None,
3834 dshape: vec![
3835 (DimName::Batch, 1),
3836 (DimName::NumClasses, 80),
3837 (DimName::NumBoxes, 8400),
3838 (DimName::Batch, 1),
3839 ],
3840 },
3841 )
3842 .build();
3843
3844 assert!(matches!(
3845 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Scores shape")));
3846
3847 let result = DecoderBuilder::new()
3848 .with_config_yolo_split_det(
3849 configs::Boxes {
3850 decoder: configs::DecoderType::Ultralytics,
3851 shape: vec![1, 8400, 4],
3852 quantization: None,
3853 dshape: vec![
3854 (DimName::Batch, 1),
3855 (DimName::NumBoxes, 8400),
3856 (DimName::BoxCoords, 4),
3857 ],
3858 normalized: Some(true),
3859 },
3860 configs::Scores {
3861 decoder: configs::DecoderType::Ultralytics,
3862 shape: vec![1, 8400 + 1, 80], quantization: None,
3864 dshape: vec![
3865 (DimName::Batch, 1),
3866 (DimName::NumBoxes, 8401),
3867 (DimName::NumClasses, 80),
3868 ],
3869 },
3870 )
3871 .build();
3872
3873 assert!(matches!(
3874 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Yolo Split Detection Boxes num 8400 incompatible with Scores num 8401")));
3875
3876 let result = DecoderBuilder::new()
3877 .with_config_yolo_split_det(
3878 configs::Boxes {
3879 decoder: configs::DecoderType::Ultralytics,
3880 shape: vec![1, 5, 8400], quantization: None,
3882 dshape: vec![
3883 (DimName::Batch, 1),
3884 (DimName::BoxCoords, 5),
3885 (DimName::NumBoxes, 8400),
3886 ],
3887 normalized: Some(true),
3888 },
3889 configs::Scores {
3890 decoder: configs::DecoderType::Ultralytics,
3891 shape: vec![1, 80, 8400],
3892 quantization: None,
3893 dshape: vec![
3894 (DimName::Batch, 1),
3895 (DimName::NumClasses, 80),
3896 (DimName::NumBoxes, 8400),
3897 ],
3898 },
3899 )
3900 .build();
3901 assert!(matches!(
3902 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("BoxCoords dimension size must be 4")));
3903 }
3904
3905 #[test]
3906 fn test_yolo_invalid_split_segdet() {
3907 let result = DecoderBuilder::new()
3908 .with_config_yolo_split_segdet(
3909 configs::Boxes {
3910 decoder: configs::DecoderType::Ultralytics,
3911 shape: vec![1, 8400, 4, 1],
3912 quantization: None,
3913 dshape: vec![
3914 (DimName::Batch, 1),
3915 (DimName::NumBoxes, 8400),
3916 (DimName::BoxCoords, 4),
3917 (DimName::Batch, 1),
3918 ],
3919 normalized: Some(true),
3920 },
3921 configs::Scores {
3922 decoder: configs::DecoderType::Ultralytics,
3923 shape: vec![1, 8400, 80],
3924
3925 quantization: None,
3926 dshape: vec![
3927 (DimName::Batch, 1),
3928 (DimName::NumBoxes, 8400),
3929 (DimName::NumClasses, 80),
3930 ],
3931 },
3932 configs::MaskCoefficients {
3933 decoder: configs::DecoderType::Ultralytics,
3934 shape: vec![1, 8400, 32],
3935 quantization: None,
3936 dshape: vec![
3937 (DimName::Batch, 1),
3938 (DimName::NumBoxes, 8400),
3939 (DimName::NumProtos, 32),
3940 ],
3941 },
3942 configs::Protos {
3943 decoder: configs::DecoderType::Ultralytics,
3944 shape: vec![1, 32, 160, 160],
3945 quantization: None,
3946 dshape: vec![
3947 (DimName::Batch, 1),
3948 (DimName::NumProtos, 32),
3949 (DimName::Height, 160),
3950 (DimName::Width, 160),
3951 ],
3952 },
3953 )
3954 .build();
3955
3956 assert!(matches!(
3957 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Boxes shape")));
3958
3959 let result = DecoderBuilder::new()
3960 .with_config_yolo_split_segdet(
3961 configs::Boxes {
3962 decoder: configs::DecoderType::Ultralytics,
3963 shape: vec![1, 8400, 4],
3964 quantization: None,
3965 dshape: vec![
3966 (DimName::Batch, 1),
3967 (DimName::NumBoxes, 8400),
3968 (DimName::BoxCoords, 4),
3969 ],
3970 normalized: Some(true),
3971 },
3972 configs::Scores {
3973 decoder: configs::DecoderType::Ultralytics,
3974 shape: vec![1, 8400, 80, 1],
3975 quantization: None,
3976 dshape: vec![
3977 (DimName::Batch, 1),
3978 (DimName::NumBoxes, 8400),
3979 (DimName::NumClasses, 80),
3980 (DimName::Batch, 1),
3981 ],
3982 },
3983 configs::MaskCoefficients {
3984 decoder: configs::DecoderType::Ultralytics,
3985 shape: vec![1, 8400, 32],
3986 quantization: None,
3987 dshape: vec![
3988 (DimName::Batch, 1),
3989 (DimName::NumBoxes, 8400),
3990 (DimName::NumProtos, 32),
3991 ],
3992 },
3993 configs::Protos {
3994 decoder: configs::DecoderType::Ultralytics,
3995 shape: vec![1, 32, 160, 160],
3996 quantization: None,
3997 dshape: vec![
3998 (DimName::Batch, 1),
3999 (DimName::NumProtos, 32),
4000 (DimName::Height, 160),
4001 (DimName::Width, 160),
4002 ],
4003 },
4004 )
4005 .build();
4006
4007 assert!(matches!(
4008 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Scores shape")));
4009
4010 let result = DecoderBuilder::new()
4011 .with_config_yolo_split_segdet(
4012 configs::Boxes {
4013 decoder: configs::DecoderType::Ultralytics,
4014 shape: vec![1, 8400, 4],
4015 quantization: None,
4016 dshape: vec![
4017 (DimName::Batch, 1),
4018 (DimName::NumBoxes, 8400),
4019 (DimName::BoxCoords, 4),
4020 ],
4021 normalized: Some(true),
4022 },
4023 configs::Scores {
4024 decoder: configs::DecoderType::Ultralytics,
4025 shape: vec![1, 8400, 80],
4026 quantization: None,
4027 dshape: vec![
4028 (DimName::Batch, 1),
4029 (DimName::NumBoxes, 8400),
4030 (DimName::NumClasses, 80),
4031 ],
4032 },
4033 configs::MaskCoefficients {
4034 decoder: configs::DecoderType::Ultralytics,
4035 shape: vec![1, 8400, 32, 1],
4036 quantization: None,
4037 dshape: vec![
4038 (DimName::Batch, 1),
4039 (DimName::NumBoxes, 8400),
4040 (DimName::NumProtos, 32),
4041 (DimName::Batch, 1),
4042 ],
4043 },
4044 configs::Protos {
4045 decoder: configs::DecoderType::Ultralytics,
4046 shape: vec![1, 32, 160, 160],
4047 quantization: None,
4048 dshape: vec![
4049 (DimName::Batch, 1),
4050 (DimName::NumProtos, 32),
4051 (DimName::Height, 160),
4052 (DimName::Width, 160),
4053 ],
4054 },
4055 )
4056 .build();
4057
4058 assert!(matches!(
4059 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Mask Coefficients shape")));
4060
4061 let result = DecoderBuilder::new()
4062 .with_config_yolo_split_segdet(
4063 configs::Boxes {
4064 decoder: configs::DecoderType::Ultralytics,
4065 shape: vec![1, 8400, 4],
4066 quantization: None,
4067 dshape: vec![
4068 (DimName::Batch, 1),
4069 (DimName::NumBoxes, 8400),
4070 (DimName::BoxCoords, 4),
4071 ],
4072 normalized: Some(true),
4073 },
4074 configs::Scores {
4075 decoder: configs::DecoderType::Ultralytics,
4076 shape: vec![1, 8400, 80],
4077 quantization: None,
4078 dshape: vec![
4079 (DimName::Batch, 1),
4080 (DimName::NumBoxes, 8400),
4081 (DimName::NumClasses, 80),
4082 ],
4083 },
4084 configs::MaskCoefficients {
4085 decoder: configs::DecoderType::Ultralytics,
4086 shape: vec![1, 8400, 32],
4087 quantization: None,
4088 dshape: vec![
4089 (DimName::Batch, 1),
4090 (DimName::NumBoxes, 8400),
4091 (DimName::NumProtos, 32),
4092 ],
4093 },
4094 configs::Protos {
4095 decoder: configs::DecoderType::Ultralytics,
4096 shape: vec![1, 32, 160, 160, 1],
4097 quantization: None,
4098 dshape: vec![
4099 (DimName::Batch, 1),
4100 (DimName::NumProtos, 32),
4101 (DimName::Height, 160),
4102 (DimName::Width, 160),
4103 (DimName::Batch, 1),
4104 ],
4105 },
4106 )
4107 .build();
4108
4109 assert!(matches!(
4110 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Protos shape")));
4111
4112 let result = DecoderBuilder::new()
4113 .with_config_yolo_split_segdet(
4114 configs::Boxes {
4115 decoder: configs::DecoderType::Ultralytics,
4116 shape: vec![1, 8400, 4],
4117 quantization: None,
4118 dshape: vec![
4119 (DimName::Batch, 1),
4120 (DimName::NumBoxes, 8400),
4121 (DimName::BoxCoords, 4),
4122 ],
4123 normalized: Some(true),
4124 },
4125 configs::Scores {
4126 decoder: configs::DecoderType::Ultralytics,
4127 shape: vec![1, 8401, 80],
4128 quantization: None,
4129 dshape: vec![
4130 (DimName::Batch, 1),
4131 (DimName::NumBoxes, 8401),
4132 (DimName::NumClasses, 80),
4133 ],
4134 },
4135 configs::MaskCoefficients {
4136 decoder: configs::DecoderType::Ultralytics,
4137 shape: vec![1, 8400, 32],
4138 quantization: None,
4139 dshape: vec![
4140 (DimName::Batch, 1),
4141 (DimName::NumBoxes, 8400),
4142 (DimName::NumProtos, 32),
4143 ],
4144 },
4145 configs::Protos {
4146 decoder: configs::DecoderType::Ultralytics,
4147 shape: vec![1, 32, 160, 160],
4148 quantization: None,
4149 dshape: vec![
4150 (DimName::Batch, 1),
4151 (DimName::NumProtos, 32),
4152 (DimName::Height, 160),
4153 (DimName::Width, 160),
4154 ],
4155 },
4156 )
4157 .build();
4158
4159 assert!(matches!(
4160 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Yolo Split Detection Boxes num 8400 incompatible with Scores num 8401")));
4161
4162 let result = DecoderBuilder::new()
4163 .with_config_yolo_split_segdet(
4164 configs::Boxes {
4165 decoder: configs::DecoderType::Ultralytics,
4166 shape: vec![1, 8400, 4],
4167 quantization: None,
4168 dshape: vec![
4169 (DimName::Batch, 1),
4170 (DimName::NumBoxes, 8400),
4171 (DimName::BoxCoords, 4),
4172 ],
4173 normalized: Some(true),
4174 },
4175 configs::Scores {
4176 decoder: configs::DecoderType::Ultralytics,
4177 shape: vec![1, 8400, 80],
4178 quantization: None,
4179 dshape: vec![
4180 (DimName::Batch, 1),
4181 (DimName::NumBoxes, 8400),
4182 (DimName::NumClasses, 80),
4183 ],
4184 },
4185 configs::MaskCoefficients {
4186 decoder: configs::DecoderType::Ultralytics,
4187 shape: vec![1, 8401, 32],
4188
4189 quantization: None,
4190 dshape: vec![
4191 (DimName::Batch, 1),
4192 (DimName::NumBoxes, 8401),
4193 (DimName::NumProtos, 32),
4194 ],
4195 },
4196 configs::Protos {
4197 decoder: configs::DecoderType::Ultralytics,
4198 shape: vec![1, 32, 160, 160],
4199 quantization: None,
4200 dshape: vec![
4201 (DimName::Batch, 1),
4202 (DimName::NumProtos, 32),
4203 (DimName::Height, 160),
4204 (DimName::Width, 160),
4205 ],
4206 },
4207 )
4208 .build();
4209
4210 assert!(matches!(
4211 result, Err(DecoderError::InvalidConfig(ref s)) if s.starts_with("Yolo Split Detection Boxes num 8400 incompatible with Mask Coefficients num 8401")));
4212 let result = DecoderBuilder::new()
4213 .with_config_yolo_split_segdet(
4214 configs::Boxes {
4215 decoder: configs::DecoderType::Ultralytics,
4216 shape: vec![1, 8400, 4],
4217 quantization: None,
4218 dshape: vec![
4219 (DimName::Batch, 1),
4220 (DimName::NumBoxes, 8400),
4221 (DimName::BoxCoords, 4),
4222 ],
4223 normalized: Some(true),
4224 },
4225 configs::Scores {
4226 decoder: configs::DecoderType::Ultralytics,
4227 shape: vec![1, 8400, 80],
4228 quantization: None,
4229 dshape: vec![
4230 (DimName::Batch, 1),
4231 (DimName::NumBoxes, 8400),
4232 (DimName::NumClasses, 80),
4233 ],
4234 },
4235 configs::MaskCoefficients {
4236 decoder: configs::DecoderType::Ultralytics,
4237 shape: vec![1, 8400, 32],
4238 quantization: None,
4239 dshape: vec![
4240 (DimName::Batch, 1),
4241 (DimName::NumBoxes, 8400),
4242 (DimName::NumProtos, 32),
4243 ],
4244 },
4245 configs::Protos {
4246 decoder: configs::DecoderType::Ultralytics,
4247 shape: vec![1, 31, 160, 160],
4248 quantization: None,
4249 dshape: vec![
4250 (DimName::Batch, 1),
4251 (DimName::NumProtos, 31),
4252 (DimName::Height, 160),
4253 (DimName::Width, 160),
4254 ],
4255 },
4256 )
4257 .build();
4258 println!("{:?}", result);
4259 assert!(matches!(
4260 result, Err(DecoderError::InvalidConfig(ref s)) if s.starts_with( "Yolo Protos channels 31 incompatible with Mask Coefficients channels 32")));
4261 }
4262
4263 #[test]
4264 fn test_modelpack_invalid_config() {
4265 let result = DecoderBuilder::new()
4266 .with_config(ConfigOutputs {
4267 outputs: vec![
4268 ConfigOutput::Boxes(configs::Boxes {
4269 decoder: configs::DecoderType::ModelPack,
4270 shape: vec![1, 8400, 1, 4],
4271 quantization: None,
4272 dshape: vec![
4273 (DimName::Batch, 1),
4274 (DimName::NumBoxes, 8400),
4275 (DimName::Padding, 1),
4276 (DimName::BoxCoords, 4),
4277 ],
4278 normalized: Some(true),
4279 }),
4280 ConfigOutput::Scores(configs::Scores {
4281 decoder: configs::DecoderType::ModelPack,
4282 shape: vec![1, 8400, 3],
4283 quantization: None,
4284 dshape: vec![
4285 (DimName::Batch, 1),
4286 (DimName::NumBoxes, 8400),
4287 (DimName::NumClasses, 3),
4288 ],
4289 }),
4290 ConfigOutput::Protos(configs::Protos {
4291 decoder: configs::DecoderType::ModelPack,
4292 shape: vec![1, 8400, 3],
4293 quantization: None,
4294 dshape: vec![
4295 (DimName::Batch, 1),
4296 (DimName::NumBoxes, 8400),
4297 (DimName::NumFeatures, 3),
4298 ],
4299 }),
4300 ],
4301 ..Default::default()
4302 })
4303 .build();
4304
4305 assert!(matches!(
4306 result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack should not have protos"));
4307
4308 let result = DecoderBuilder::new()
4309 .with_config(ConfigOutputs {
4310 outputs: vec![
4311 ConfigOutput::Boxes(configs::Boxes {
4312 decoder: configs::DecoderType::ModelPack,
4313 shape: vec![1, 8400, 1, 4],
4314 quantization: None,
4315 dshape: vec![
4316 (DimName::Batch, 1),
4317 (DimName::NumBoxes, 8400),
4318 (DimName::Padding, 1),
4319 (DimName::BoxCoords, 4),
4320 ],
4321 normalized: Some(true),
4322 }),
4323 ConfigOutput::Scores(configs::Scores {
4324 decoder: configs::DecoderType::ModelPack,
4325 shape: vec![1, 8400, 3],
4326 quantization: None,
4327 dshape: vec![
4328 (DimName::Batch, 1),
4329 (DimName::NumBoxes, 8400),
4330 (DimName::NumClasses, 3),
4331 ],
4332 }),
4333 ConfigOutput::MaskCoefficients(configs::MaskCoefficients {
4334 decoder: configs::DecoderType::ModelPack,
4335 shape: vec![1, 8400, 3],
4336 quantization: None,
4337 dshape: vec![
4338 (DimName::Batch, 1),
4339 (DimName::NumBoxes, 8400),
4340 (DimName::NumProtos, 3),
4341 ],
4342 }),
4343 ],
4344 ..Default::default()
4345 })
4346 .build();
4347
4348 assert!(matches!(
4349 result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack should not have mask coefficients"));
4350
4351 let result = DecoderBuilder::new()
4352 .with_config(ConfigOutputs {
4353 outputs: vec![ConfigOutput::Boxes(configs::Boxes {
4354 decoder: configs::DecoderType::ModelPack,
4355 shape: vec![1, 8400, 1, 4],
4356 quantization: None,
4357 dshape: vec![
4358 (DimName::Batch, 1),
4359 (DimName::NumBoxes, 8400),
4360 (DimName::Padding, 1),
4361 (DimName::BoxCoords, 4),
4362 ],
4363 normalized: Some(true),
4364 })],
4365 ..Default::default()
4366 })
4367 .build();
4368
4369 assert!(matches!(
4370 result, Err(DecoderError::InvalidConfig(s)) if s == "Invalid ModelPack model outputs"));
4371 }
4372
4373 #[test]
4374 fn test_modelpack_invalid_det() {
4375 let result = DecoderBuilder::new()
4376 .with_config_modelpack_det(
4377 configs::Boxes {
4378 decoder: DecoderType::ModelPack,
4379 quantization: None,
4380 shape: vec![1, 4, 8400],
4381 dshape: vec![
4382 (DimName::Batch, 1),
4383 (DimName::BoxCoords, 4),
4384 (DimName::NumBoxes, 8400),
4385 ],
4386 normalized: Some(true),
4387 },
4388 configs::Scores {
4389 decoder: DecoderType::ModelPack,
4390 quantization: None,
4391 shape: vec![1, 80, 8400],
4392 dshape: vec![
4393 (DimName::Batch, 1),
4394 (DimName::NumClasses, 80),
4395 (DimName::NumBoxes, 8400),
4396 ],
4397 },
4398 )
4399 .build();
4400
4401 assert!(matches!(
4402 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Boxes shape")));
4403
4404 let result = DecoderBuilder::new()
4405 .with_config_modelpack_det(
4406 configs::Boxes {
4407 decoder: DecoderType::ModelPack,
4408 quantization: None,
4409 shape: vec![1, 4, 1, 8400],
4410 dshape: vec![
4411 (DimName::Batch, 1),
4412 (DimName::BoxCoords, 4),
4413 (DimName::Padding, 1),
4414 (DimName::NumBoxes, 8400),
4415 ],
4416 normalized: Some(true),
4417 },
4418 configs::Scores {
4419 decoder: DecoderType::ModelPack,
4420 quantization: None,
4421 shape: vec![1, 80, 8400, 1],
4422 dshape: vec![
4423 (DimName::Batch, 1),
4424 (DimName::NumClasses, 80),
4425 (DimName::NumBoxes, 8400),
4426 (DimName::Padding, 1),
4427 ],
4428 },
4429 )
4430 .build();
4431
4432 assert!(matches!(
4433 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Scores shape")));
4434
4435 let result = DecoderBuilder::new()
4436 .with_config_modelpack_det(
4437 configs::Boxes {
4438 decoder: DecoderType::ModelPack,
4439 quantization: None,
4440 shape: vec![1, 4, 2, 8400],
4441 dshape: vec![
4442 (DimName::Batch, 1),
4443 (DimName::BoxCoords, 4),
4444 (DimName::Padding, 2),
4445 (DimName::NumBoxes, 8400),
4446 ],
4447 normalized: Some(true),
4448 },
4449 configs::Scores {
4450 decoder: DecoderType::ModelPack,
4451 quantization: None,
4452 shape: vec![1, 80, 8400],
4453 dshape: vec![
4454 (DimName::Batch, 1),
4455 (DimName::NumClasses, 80),
4456 (DimName::NumBoxes, 8400),
4457 ],
4458 },
4459 )
4460 .build();
4461 assert!(matches!(
4462 result, Err(DecoderError::InvalidConfig(s)) if s == "Padding dimension size must be 1"));
4463
4464 let result = DecoderBuilder::new()
4465 .with_config_modelpack_det(
4466 configs::Boxes {
4467 decoder: DecoderType::ModelPack,
4468 quantization: None,
4469 shape: vec![1, 5, 1, 8400],
4470 dshape: vec![
4471 (DimName::Batch, 1),
4472 (DimName::BoxCoords, 5),
4473 (DimName::Padding, 1),
4474 (DimName::NumBoxes, 8400),
4475 ],
4476 normalized: Some(true),
4477 },
4478 configs::Scores {
4479 decoder: DecoderType::ModelPack,
4480 quantization: None,
4481 shape: vec![1, 80, 8400],
4482 dshape: vec![
4483 (DimName::Batch, 1),
4484 (DimName::NumClasses, 80),
4485 (DimName::NumBoxes, 8400),
4486 ],
4487 },
4488 )
4489 .build();
4490
4491 assert!(matches!(
4492 result, Err(DecoderError::InvalidConfig(s)) if s == "BoxCoords dimension size must be 4"));
4493
4494 let result = DecoderBuilder::new()
4495 .with_config_modelpack_det(
4496 configs::Boxes {
4497 decoder: DecoderType::ModelPack,
4498 quantization: None,
4499 shape: vec![1, 4, 1, 8400],
4500 dshape: vec![
4501 (DimName::Batch, 1),
4502 (DimName::BoxCoords, 4),
4503 (DimName::Padding, 1),
4504 (DimName::NumBoxes, 8400),
4505 ],
4506 normalized: Some(true),
4507 },
4508 configs::Scores {
4509 decoder: DecoderType::ModelPack,
4510 quantization: None,
4511 shape: vec![1, 80, 8401],
4512 dshape: vec![
4513 (DimName::Batch, 1),
4514 (DimName::NumClasses, 80),
4515 (DimName::NumBoxes, 8401),
4516 ],
4517 },
4518 )
4519 .build();
4520
4521 assert!(matches!(
4522 result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Detection Boxes num 8400 incompatible with Scores num 8401"));
4523 }
4524
4525 #[test]
4526 fn test_modelpack_invalid_det_split() {
4527 let result = DecoderBuilder::default()
4528 .with_config_modelpack_det_split(vec![
4529 configs::Detection {
4530 decoder: DecoderType::ModelPack,
4531 shape: vec![1, 17, 30, 18],
4532 anchors: None,
4533 quantization: None,
4534 dshape: vec![
4535 (DimName::Batch, 1),
4536 (DimName::Height, 17),
4537 (DimName::Width, 30),
4538 (DimName::NumAnchorsXFeatures, 18),
4539 ],
4540 normalized: Some(true),
4541 },
4542 configs::Detection {
4543 decoder: DecoderType::ModelPack,
4544 shape: vec![1, 9, 15, 18],
4545 anchors: None,
4546 quantization: None,
4547 dshape: vec![
4548 (DimName::Batch, 1),
4549 (DimName::Height, 9),
4550 (DimName::Width, 15),
4551 (DimName::NumAnchorsXFeatures, 18),
4552 ],
4553 normalized: Some(true),
4554 },
4555 ])
4556 .build();
4557
4558 assert!(matches!(
4559 result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection missing anchors"));
4560
4561 let result = DecoderBuilder::default()
4562 .with_config_modelpack_det_split(vec![configs::Detection {
4563 decoder: DecoderType::ModelPack,
4564 shape: vec![1, 17, 30, 18],
4565 anchors: None,
4566 quantization: None,
4567 dshape: Vec::new(),
4568 normalized: Some(true),
4569 }])
4570 .build();
4571
4572 assert!(matches!(
4573 result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection missing anchors"));
4574
4575 let result = DecoderBuilder::default()
4576 .with_config_modelpack_det_split(vec![configs::Detection {
4577 decoder: DecoderType::ModelPack,
4578 shape: vec![1, 17, 30, 18],
4579 anchors: Some(vec![]),
4580 quantization: None,
4581 dshape: vec![
4582 (DimName::Batch, 1),
4583 (DimName::Height, 17),
4584 (DimName::Width, 30),
4585 (DimName::NumAnchorsXFeatures, 18),
4586 ],
4587 normalized: Some(true),
4588 }])
4589 .build();
4590
4591 assert!(matches!(
4592 result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection has zero anchors"));
4593
4594 let result = DecoderBuilder::default()
4595 .with_config_modelpack_det_split(vec![configs::Detection {
4596 decoder: DecoderType::ModelPack,
4597 shape: vec![1, 17, 30, 18, 1],
4598 anchors: Some(vec![
4599 [0.3666666, 0.3148148],
4600 [0.3874999, 0.474074],
4601 [0.5333333, 0.644444],
4602 ]),
4603 quantization: None,
4604 dshape: vec![
4605 (DimName::Batch, 1),
4606 (DimName::Height, 17),
4607 (DimName::Width, 30),
4608 (DimName::NumAnchorsXFeatures, 18),
4609 (DimName::Padding, 1),
4610 ],
4611 normalized: Some(true),
4612 }])
4613 .build();
4614
4615 assert!(matches!(
4616 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Split Detection shape")));
4617
4618 let result = DecoderBuilder::default()
4619 .with_config_modelpack_det_split(vec![configs::Detection {
4620 decoder: DecoderType::ModelPack,
4621 shape: vec![1, 15, 17, 30],
4622 anchors: Some(vec![
4623 [0.3666666, 0.3148148],
4624 [0.3874999, 0.474074],
4625 [0.5333333, 0.644444],
4626 ]),
4627 quantization: None,
4628 dshape: vec![
4629 (DimName::Batch, 1),
4630 (DimName::NumAnchorsXFeatures, 15),
4631 (DimName::Height, 17),
4632 (DimName::Width, 30),
4633 ],
4634 normalized: Some(true),
4635 }])
4636 .build();
4637
4638 assert!(matches!(
4639 result, Err(DecoderError::InvalidConfig(s)) if s.contains("not greater than number of anchors * 5 =")));
4640
4641 let result = DecoderBuilder::default()
4642 .with_config_modelpack_det_split(vec![configs::Detection {
4643 decoder: DecoderType::ModelPack,
4644 shape: vec![1, 17, 30, 15],
4645 anchors: Some(vec![
4646 [0.3666666, 0.3148148],
4647 [0.3874999, 0.474074],
4648 [0.5333333, 0.644444],
4649 ]),
4650 quantization: None,
4651 dshape: Vec::new(),
4652 normalized: Some(true),
4653 }])
4654 .build();
4655
4656 assert!(matches!(
4657 result, Err(DecoderError::InvalidConfig(s)) if s.contains("not greater than number of anchors * 5 =")));
4658
4659 let result = DecoderBuilder::default()
4660 .with_config_modelpack_det_split(vec![configs::Detection {
4661 decoder: DecoderType::ModelPack,
4662 shape: vec![1, 16, 17, 30],
4663 anchors: Some(vec![
4664 [0.3666666, 0.3148148],
4665 [0.3874999, 0.474074],
4666 [0.5333333, 0.644444],
4667 ]),
4668 quantization: None,
4669 dshape: vec![
4670 (DimName::Batch, 1),
4671 (DimName::NumAnchorsXFeatures, 16),
4672 (DimName::Height, 17),
4673 (DimName::Width, 30),
4674 ],
4675 normalized: Some(true),
4676 }])
4677 .build();
4678
4679 assert!(matches!(
4680 result, Err(DecoderError::InvalidConfig(s)) if s.contains("not a multiple of number of anchors")));
4681
4682 let result = DecoderBuilder::default()
4683 .with_config_modelpack_det_split(vec![configs::Detection {
4684 decoder: DecoderType::ModelPack,
4685 shape: vec![1, 17, 30, 16],
4686 anchors: Some(vec![
4687 [0.3666666, 0.3148148],
4688 [0.3874999, 0.474074],
4689 [0.5333333, 0.644444],
4690 ]),
4691 quantization: None,
4692 dshape: Vec::new(),
4693 normalized: Some(true),
4694 }])
4695 .build();
4696
4697 assert!(matches!(
4698 result, Err(DecoderError::InvalidConfig(s)) if s.contains("not a multiple of number of anchors")));
4699
4700 let result = DecoderBuilder::default()
4701 .with_config_modelpack_det_split(vec![configs::Detection {
4702 decoder: DecoderType::ModelPack,
4703 shape: vec![1, 18, 17, 30],
4704 anchors: Some(vec![
4705 [0.3666666, 0.3148148],
4706 [0.3874999, 0.474074],
4707 [0.5333333, 0.644444],
4708 ]),
4709 quantization: None,
4710 dshape: vec![
4711 (DimName::Batch, 1),
4712 (DimName::NumProtos, 18),
4713 (DimName::Height, 17),
4714 (DimName::Width, 30),
4715 ],
4716 normalized: Some(true),
4717 }])
4718 .build();
4719 assert!(matches!(
4720 result, Err(DecoderError::InvalidConfig(s)) if s.contains("Split Detection dshape missing required dimension NumAnchorsXFeature")));
4721
4722 let result = DecoderBuilder::default()
4723 .with_config_modelpack_det_split(vec![
4724 configs::Detection {
4725 decoder: DecoderType::ModelPack,
4726 shape: vec![1, 17, 30, 18],
4727 anchors: Some(vec![
4728 [0.3666666, 0.3148148],
4729 [0.3874999, 0.474074],
4730 [0.5333333, 0.644444],
4731 ]),
4732 quantization: None,
4733 dshape: vec![
4734 (DimName::Batch, 1),
4735 (DimName::Height, 17),
4736 (DimName::Width, 30),
4737 (DimName::NumAnchorsXFeatures, 18),
4738 ],
4739 normalized: Some(true),
4740 },
4741 configs::Detection {
4742 decoder: DecoderType::ModelPack,
4743 shape: vec![1, 17, 30, 21],
4744 anchors: Some(vec![
4745 [0.3666666, 0.3148148],
4746 [0.3874999, 0.474074],
4747 [0.5333333, 0.644444],
4748 ]),
4749 quantization: None,
4750 dshape: vec![
4751 (DimName::Batch, 1),
4752 (DimName::Height, 17),
4753 (DimName::Width, 30),
4754 (DimName::NumAnchorsXFeatures, 21),
4755 ],
4756 normalized: Some(true),
4757 },
4758 ])
4759 .build();
4760
4761 assert!(matches!(
4762 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("ModelPack Split Detection inconsistent number of classes:")));
4763
4764 let result = DecoderBuilder::default()
4765 .with_config_modelpack_det_split(vec![
4766 configs::Detection {
4767 decoder: DecoderType::ModelPack,
4768 shape: vec![1, 17, 30, 18],
4769 anchors: Some(vec![
4770 [0.3666666, 0.3148148],
4771 [0.3874999, 0.474074],
4772 [0.5333333, 0.644444],
4773 ]),
4774 quantization: None,
4775 dshape: vec![],
4776 normalized: Some(true),
4777 },
4778 configs::Detection {
4779 decoder: DecoderType::ModelPack,
4780 shape: vec![1, 17, 30, 21],
4781 anchors: Some(vec![
4782 [0.3666666, 0.3148148],
4783 [0.3874999, 0.474074],
4784 [0.5333333, 0.644444],
4785 ]),
4786 quantization: None,
4787 dshape: vec![],
4788 normalized: Some(true),
4789 },
4790 ])
4791 .build();
4792
4793 assert!(matches!(
4794 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("ModelPack Split Detection inconsistent number of classes:")));
4795 }
4796
4797 #[test]
4798 fn test_modelpack_invalid_seg() {
4799 let result = DecoderBuilder::new()
4800 .with_config_modelpack_seg(configs::Segmentation {
4801 decoder: DecoderType::ModelPack,
4802 quantization: None,
4803 shape: vec![1, 160, 106, 3, 1],
4804 dshape: vec![
4805 (DimName::Batch, 1),
4806 (DimName::Height, 160),
4807 (DimName::Width, 106),
4808 (DimName::NumClasses, 3),
4809 (DimName::Padding, 1),
4810 ],
4811 })
4812 .build();
4813
4814 assert!(matches!(
4815 result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Segmentation shape")));
4816 }
4817
4818 #[test]
4819 fn test_modelpack_invalid_segdet() {
4820 let result = DecoderBuilder::new()
4821 .with_config_modelpack_segdet(
4822 configs::Boxes {
4823 decoder: DecoderType::ModelPack,
4824 quantization: None,
4825 shape: vec![1, 4, 1, 8400],
4826 dshape: vec![
4827 (DimName::Batch, 1),
4828 (DimName::BoxCoords, 4),
4829 (DimName::Padding, 1),
4830 (DimName::NumBoxes, 8400),
4831 ],
4832 normalized: Some(true),
4833 },
4834 configs::Scores {
4835 decoder: DecoderType::ModelPack,
4836 quantization: None,
4837 shape: vec![1, 4, 8400],
4838 dshape: vec![
4839 (DimName::Batch, 1),
4840 (DimName::NumClasses, 4),
4841 (DimName::NumBoxes, 8400),
4842 ],
4843 },
4844 configs::Segmentation {
4845 decoder: DecoderType::ModelPack,
4846 quantization: None,
4847 shape: vec![1, 160, 106, 3],
4848 dshape: vec![
4849 (DimName::Batch, 1),
4850 (DimName::Height, 160),
4851 (DimName::Width, 106),
4852 (DimName::NumClasses, 3),
4853 ],
4854 },
4855 )
4856 .build();
4857
4858 assert!(matches!(
4859 result, Err(DecoderError::InvalidConfig(s)) if s.contains("incompatible with number of classes")));
4860 }
4861
4862 #[test]
4863 fn test_modelpack_invalid_segdet_split() {
4864 let result = DecoderBuilder::new()
4865 .with_config_modelpack_segdet_split(
4866 vec![configs::Detection {
4867 decoder: DecoderType::ModelPack,
4868 shape: vec![1, 17, 30, 18],
4869 anchors: Some(vec![
4870 [0.3666666, 0.3148148],
4871 [0.3874999, 0.474074],
4872 [0.5333333, 0.644444],
4873 ]),
4874 quantization: None,
4875 dshape: vec![
4876 (DimName::Batch, 1),
4877 (DimName::Height, 17),
4878 (DimName::Width, 30),
4879 (DimName::NumAnchorsXFeatures, 18),
4880 ],
4881 normalized: Some(true),
4882 }],
4883 configs::Segmentation {
4884 decoder: DecoderType::ModelPack,
4885 quantization: None,
4886 shape: vec![1, 160, 106, 3],
4887 dshape: vec![
4888 (DimName::Batch, 1),
4889 (DimName::Height, 160),
4890 (DimName::Width, 106),
4891 (DimName::NumClasses, 3),
4892 ],
4893 },
4894 )
4895 .build();
4896
4897 assert!(matches!(
4898 result, Err(DecoderError::InvalidConfig(s)) if s.contains("incompatible with number of classes")));
4899 }
4900
4901 #[test]
4902 fn test_decode_bad_shapes() {
4903 let score_threshold = 0.25;
4904 let iou_threshold = 0.7;
4905 let quant = (0.0040811873, -123);
4906 let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
4907 let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
4908 let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
4909 let out_float: Array3<f32> = dequantize_ndarray(out.view(), quant.into());
4910
4911 let decoder = DecoderBuilder::default()
4912 .with_config_yolo_det(
4913 configs::Detection {
4914 decoder: DecoderType::Ultralytics,
4915 shape: vec![1, 85, 8400],
4916 anchors: None,
4917 quantization: Some(quant.into()),
4918 dshape: vec![
4919 (DimName::Batch, 1),
4920 (DimName::NumFeatures, 85),
4921 (DimName::NumBoxes, 8400),
4922 ],
4923 normalized: Some(true),
4924 },
4925 Some(DecoderVersion::Yolo11),
4926 )
4927 .with_score_threshold(score_threshold)
4928 .with_iou_threshold(iou_threshold)
4929 .build()
4930 .unwrap();
4931
4932 let mut output_boxes: Vec<_> = Vec::with_capacity(50);
4933 let mut output_masks: Vec<_> = Vec::with_capacity(50);
4934 let result =
4935 decoder.decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks);
4936
4937 assert!(matches!(
4938 result, Err(DecoderError::InvalidShape(s)) if s == "Did not find output with shape [1, 85, 8400]"));
4939
4940 let result = decoder.decode_float(
4941 &[out_float.view().into_dyn()],
4942 &mut output_boxes,
4943 &mut output_masks,
4944 );
4945
4946 assert!(matches!(
4947 result, Err(DecoderError::InvalidShape(s)) if s == "Did not find output with shape [1, 85, 8400]"));
4948 }
4949
4950 #[test]
4951 fn test_config_outputs() {
4952 let outputs = [
4953 ConfigOutput::Detection(configs::Detection {
4954 decoder: configs::DecoderType::Ultralytics,
4955 anchors: None,
4956 shape: vec![1, 8400, 85],
4957 quantization: Some(QuantTuple(0.123, 0)),
4958 dshape: vec![
4959 (DimName::Batch, 1),
4960 (DimName::NumBoxes, 8400),
4961 (DimName::NumFeatures, 85),
4962 ],
4963 normalized: Some(true),
4964 }),
4965 ConfigOutput::Mask(configs::Mask {
4966 decoder: configs::DecoderType::Ultralytics,
4967 shape: vec![1, 160, 160, 1],
4968 quantization: Some(QuantTuple(0.223, 0)),
4969 dshape: vec![
4970 (DimName::Batch, 1),
4971 (DimName::Height, 160),
4972 (DimName::Width, 160),
4973 (DimName::NumFeatures, 1),
4974 ],
4975 }),
4976 ConfigOutput::Segmentation(configs::Segmentation {
4977 decoder: configs::DecoderType::Ultralytics,
4978 shape: vec![1, 160, 160, 80],
4979 quantization: Some(QuantTuple(0.323, 0)),
4980 dshape: vec![
4981 (DimName::Batch, 1),
4982 (DimName::Height, 160),
4983 (DimName::Width, 160),
4984 (DimName::NumClasses, 80),
4985 ],
4986 }),
4987 ConfigOutput::Scores(configs::Scores {
4988 decoder: configs::DecoderType::Ultralytics,
4989 shape: vec![1, 8400, 80],
4990 quantization: Some(QuantTuple(0.423, 0)),
4991 dshape: vec![
4992 (DimName::Batch, 1),
4993 (DimName::NumBoxes, 8400),
4994 (DimName::NumClasses, 80),
4995 ],
4996 }),
4997 ConfigOutput::Boxes(configs::Boxes {
4998 decoder: configs::DecoderType::Ultralytics,
4999 shape: vec![1, 8400, 4],
5000 quantization: Some(QuantTuple(0.523, 0)),
5001 dshape: vec![
5002 (DimName::Batch, 1),
5003 (DimName::NumBoxes, 8400),
5004 (DimName::BoxCoords, 4),
5005 ],
5006 normalized: Some(true),
5007 }),
5008 ConfigOutput::Protos(configs::Protos {
5009 decoder: configs::DecoderType::Ultralytics,
5010 shape: vec![1, 32, 160, 160],
5011 quantization: Some(QuantTuple(0.623, 0)),
5012 dshape: vec![
5013 (DimName::Batch, 1),
5014 (DimName::NumProtos, 32),
5015 (DimName::Height, 160),
5016 (DimName::Width, 160),
5017 ],
5018 }),
5019 ConfigOutput::MaskCoefficients(configs::MaskCoefficients {
5020 decoder: configs::DecoderType::Ultralytics,
5021 shape: vec![1, 8400, 32],
5022 quantization: Some(QuantTuple(0.723, 0)),
5023 dshape: vec![
5024 (DimName::Batch, 1),
5025 (DimName::NumBoxes, 8400),
5026 (DimName::NumProtos, 32),
5027 ],
5028 }),
5029 ];
5030
5031 let shapes = outputs.clone().map(|x| x.shape().to_vec());
5032 assert_eq!(
5033 shapes,
5034 [
5035 vec![1, 8400, 85],
5036 vec![1, 160, 160, 1],
5037 vec![1, 160, 160, 80],
5038 vec![1, 8400, 80],
5039 vec![1, 8400, 4],
5040 vec![1, 32, 160, 160],
5041 vec![1, 8400, 32],
5042 ]
5043 );
5044
5045 let quants: [Option<(f32, i32)>; 7] = outputs.map(|x| x.quantization().map(|q| q.into()));
5046 assert_eq!(
5047 quants,
5048 [
5049 Some((0.123, 0)),
5050 Some((0.223, 0)),
5051 Some((0.323, 0)),
5052 Some((0.423, 0)),
5053 Some((0.523, 0)),
5054 Some((0.623, 0)),
5055 Some((0.723, 0)),
5056 ]
5057 );
5058 }
5059
5060 #[test]
5061 fn test_nms_from_config_yaml() {
5062 let yaml_class_agnostic = r#"
5064outputs:
5065 - decoder: ultralytics
5066 type: detection
5067 shape: [1, 84, 8400]
5068 dshape:
5069 - [batch, 1]
5070 - [num_features, 84]
5071 - [num_boxes, 8400]
5072nms: class_agnostic
5073"#;
5074 let decoder = DecoderBuilder::new()
5075 .with_config_yaml_str(yaml_class_agnostic.to_string())
5076 .build()
5077 .unwrap();
5078 assert_eq!(decoder.nms, Some(configs::Nms::ClassAgnostic));
5079
5080 let yaml_class_aware = r#"
5081outputs:
5082 - decoder: ultralytics
5083 type: detection
5084 shape: [1, 84, 8400]
5085 dshape:
5086 - [batch, 1]
5087 - [num_features, 84]
5088 - [num_boxes, 8400]
5089nms: class_aware
5090"#;
5091 let decoder = DecoderBuilder::new()
5092 .with_config_yaml_str(yaml_class_aware.to_string())
5093 .build()
5094 .unwrap();
5095 assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
5096
5097 let decoder = DecoderBuilder::new()
5099 .with_config_yaml_str(yaml_class_aware.to_string())
5100 .with_nms(Some(configs::Nms::ClassAgnostic)) .build()
5102 .unwrap();
5103 assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
5105 }
5106
5107 #[test]
5108 fn test_nms_from_config_json() {
5109 let json_class_aware = r#"{
5111 "outputs": [{
5112 "decoder": "ultralytics",
5113 "type": "detection",
5114 "shape": [1, 84, 8400],
5115 "dshape": [["batch", 1], ["num_features", 84], ["num_boxes", 8400]]
5116 }],
5117 "nms": "class_aware"
5118 }"#;
5119 let decoder = DecoderBuilder::new()
5120 .with_config_json_str(json_class_aware.to_string())
5121 .build()
5122 .unwrap();
5123 assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
5124 }
5125
5126 #[test]
5127 fn test_nms_missing_from_config_uses_builder_default() {
5128 let yaml_no_nms = r#"
5130outputs:
5131 - decoder: ultralytics
5132 type: detection
5133 shape: [1, 84, 8400]
5134 dshape:
5135 - [batch, 1]
5136 - [num_features, 84]
5137 - [num_boxes, 8400]
5138"#;
5139 let decoder = DecoderBuilder::new()
5140 .with_config_yaml_str(yaml_no_nms.to_string())
5141 .build()
5142 .unwrap();
5143 assert_eq!(decoder.nms, Some(configs::Nms::ClassAgnostic));
5145
5146 let decoder = DecoderBuilder::new()
5148 .with_config_yaml_str(yaml_no_nms.to_string())
5149 .with_nms(None) .build()
5151 .unwrap();
5152 assert_eq!(decoder.nms, None);
5153 }
5154
5155 #[test]
5156 fn test_decoder_version_yolo26_end_to_end() {
5157 let yaml = r#"
5159outputs:
5160 - decoder: ultralytics
5161 type: detection
5162 shape: [1, 6, 8400]
5163 dshape:
5164 - [batch, 1]
5165 - [num_features, 6]
5166 - [num_boxes, 8400]
5167decoder_version: yolo26
5168"#;
5169 let decoder = DecoderBuilder::new()
5170 .with_config_yaml_str(yaml.to_string())
5171 .build()
5172 .unwrap();
5173 assert!(matches!(
5174 decoder.model_type,
5175 ModelType::YoloEndToEndDet { .. }
5176 ));
5177
5178 let yaml_with_nms = r#"
5180outputs:
5181 - decoder: ultralytics
5182 type: detection
5183 shape: [1, 6, 8400]
5184 dshape:
5185 - [batch, 1]
5186 - [num_features, 6]
5187 - [num_boxes, 8400]
5188decoder_version: yolo26
5189nms: class_agnostic
5190"#;
5191 let decoder = DecoderBuilder::new()
5192 .with_config_yaml_str(yaml_with_nms.to_string())
5193 .build()
5194 .unwrap();
5195 assert!(matches!(
5196 decoder.model_type,
5197 ModelType::YoloEndToEndDet { .. }
5198 ));
5199 }
5200
5201 #[test]
5202 fn test_decoder_version_yolov8_traditional() {
5203 let yaml = r#"
5205outputs:
5206 - decoder: ultralytics
5207 type: detection
5208 shape: [1, 84, 8400]
5209 dshape:
5210 - [batch, 1]
5211 - [num_features, 84]
5212 - [num_boxes, 8400]
5213decoder_version: yolov8
5214"#;
5215 let decoder = DecoderBuilder::new()
5216 .with_config_yaml_str(yaml.to_string())
5217 .build()
5218 .unwrap();
5219 assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
5220 }
5221
5222 #[test]
5223 fn test_decoder_version_all_versions() {
5224 for version in ["yolov5", "yolov8", "yolo11"] {
5226 let yaml = format!(
5227 r#"
5228outputs:
5229 - decoder: ultralytics
5230 type: detection
5231 shape: [1, 84, 8400]
5232 dshape:
5233 - [batch, 1]
5234 - [num_features, 84]
5235 - [num_boxes, 8400]
5236decoder_version: {}
5237"#,
5238 version
5239 );
5240 let decoder = DecoderBuilder::new()
5241 .with_config_yaml_str(yaml)
5242 .build()
5243 .unwrap();
5244
5245 assert!(
5246 matches!(decoder.model_type, ModelType::YoloDet { .. }),
5247 "Expected traditional for {}",
5248 version
5249 );
5250 }
5251
5252 let yaml = r#"
5253outputs:
5254 - decoder: ultralytics
5255 type: detection
5256 shape: [1, 6, 8400]
5257 dshape:
5258 - [batch, 1]
5259 - [num_features, 6]
5260 - [num_boxes, 8400]
5261decoder_version: yolo26
5262"#
5263 .to_string();
5264
5265 let decoder = DecoderBuilder::new()
5266 .with_config_yaml_str(yaml)
5267 .build()
5268 .unwrap();
5269
5270 assert!(
5271 matches!(decoder.model_type, ModelType::YoloEndToEndDet { .. }),
5272 "Expected end to end for yolo26",
5273 );
5274 }
5275
5276 #[test]
5277 fn test_decoder_version_json() {
5278 let json = r#"{
5280 "outputs": [{
5281 "decoder": "ultralytics",
5282 "type": "detection",
5283 "shape": [1, 6, 8400],
5284 "dshape": [["batch", 1], ["num_features", 6], ["num_boxes", 8400]]
5285 }],
5286 "decoder_version": "yolo26"
5287 }"#;
5288 let decoder = DecoderBuilder::new()
5289 .with_config_json_str(json.to_string())
5290 .build()
5291 .unwrap();
5292 assert!(matches!(
5293 decoder.model_type,
5294 ModelType::YoloEndToEndDet { .. }
5295 ));
5296 }
5297
5298 #[test]
5299 fn test_decoder_version_none_uses_traditional() {
5300 let yaml = r#"
5302outputs:
5303 - decoder: ultralytics
5304 type: detection
5305 shape: [1, 84, 8400]
5306 dshape:
5307 - [batch, 1]
5308 - [num_features, 84]
5309 - [num_boxes, 8400]
5310"#;
5311 let decoder = DecoderBuilder::new()
5312 .with_config_yaml_str(yaml.to_string())
5313 .build()
5314 .unwrap();
5315 assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
5316 }
5317
5318 #[test]
5319 fn test_decoder_version_none_with_nms_none_still_traditional() {
5320 let yaml = r#"
5323outputs:
5324 - decoder: ultralytics
5325 type: detection
5326 shape: [1, 84, 8400]
5327 dshape:
5328 - [batch, 1]
5329 - [num_features, 84]
5330 - [num_boxes, 8400]
5331"#;
5332 let decoder = DecoderBuilder::new()
5333 .with_config_yaml_str(yaml.to_string())
5334 .with_nms(None) .build()
5336 .unwrap();
5337 assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
5340 }
5341
5342 #[test]
5343 fn test_decoder_heuristic_end_to_end_detection() {
5344 let yaml = r#"
5347outputs:
5348 - decoder: ultralytics
5349 type: detection
5350 shape: [1, 300, 6]
5351 dshape:
5352 - [batch, 1]
5353 - [num_boxes, 300]
5354 - [num_features, 6]
5355
5356"#;
5357 let decoder = DecoderBuilder::new()
5358 .with_config_yaml_str(yaml.to_string())
5359 .build()
5360 .unwrap();
5361 assert!(matches!(
5363 decoder.model_type,
5364 ModelType::YoloEndToEndDet { .. }
5365 ));
5366
5367 let yaml = r#"
5368outputs:
5369 - decoder: ultralytics
5370 type: detection
5371 shape: [1, 300, 38]
5372 dshape:
5373 - [batch, 1]
5374 - [num_boxes, 300]
5375 - [num_features, 38]
5376 - decoder: ultralytics
5377 type: protos
5378 shape: [1, 160, 160, 32]
5379 dshape:
5380 - [batch, 1]
5381 - [height, 160]
5382 - [width, 160]
5383 - [num_protos, 32]
5384"#;
5385 let decoder = DecoderBuilder::new()
5386 .with_config_yaml_str(yaml.to_string())
5387 .build()
5388 .unwrap();
5389 assert!(matches!(
5391 decoder.model_type,
5392 ModelType::YoloEndToEndSegDet { .. }
5393 ));
5394
5395 let yaml = r#"
5396outputs:
5397 - decoder: ultralytics
5398 type: detection
5399 shape: [1, 6, 300]
5400 dshape:
5401 - [batch, 1]
5402 - [num_features, 6]
5403 - [num_boxes, 300]
5404"#;
5405 let decoder = DecoderBuilder::new()
5406 .with_config_yaml_str(yaml.to_string())
5407 .build()
5408 .unwrap();
5409 assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
5412
5413 let yaml = r#"
5414outputs:
5415 - decoder: ultralytics
5416 type: detection
5417 shape: [1, 38, 300]
5418 dshape:
5419 - [batch, 1]
5420 - [num_features, 38]
5421 - [num_boxes, 300]
5422
5423 - decoder: ultralytics
5424 type: protos
5425 shape: [1, 160, 160, 32]
5426 dshape:
5427 - [batch, 1]
5428 - [height, 160]
5429 - [width, 160]
5430 - [num_protos, 32]
5431"#;
5432 let decoder = DecoderBuilder::new()
5433 .with_config_yaml_str(yaml.to_string())
5434 .build()
5435 .unwrap();
5436 assert!(matches!(decoder.model_type, ModelType::YoloSegDet { .. }));
5438 }
5439
5440 #[test]
5441 fn test_decoder_version_is_end_to_end() {
5442 assert!(!configs::DecoderVersion::Yolov5.is_end_to_end());
5443 assert!(!configs::DecoderVersion::Yolov8.is_end_to_end());
5444 assert!(!configs::DecoderVersion::Yolo11.is_end_to_end());
5445 assert!(configs::DecoderVersion::Yolo26.is_end_to_end());
5446 }
5447}