1use std::collections::HashSet;
5
6use super::config::ConfigOutputRef;
7use super::configs::{self, DecoderType, DecoderVersion, DimName, ModelType};
8use super::{ConfigOutput, ConfigOutputs, Decoder};
9use crate::DecoderError;
10
/// Builder that collects an output configuration and post-processing settings
/// and turns them into a [`Decoder`] via [`DecoderBuilder::build`].
#[derive(Debug, Clone, PartialEq)]
pub struct DecoderBuilder {
    /// Where the model-output configuration comes from; `None` until one of the
    /// `with_config*` / `add_output` / `with_decoder_version` methods sets it.
    config_src: Option<ConfigSource>,
    /// IoU threshold copied into the built [`Decoder`] (default `0.5`).
    iou_threshold: f32,
    /// Score threshold copied into the built [`Decoder`] (default `0.5`).
    score_threshold: f32,
    /// NMS mode; only used by `build` when the configuration itself does not
    /// specify one (`config.nms.or(self.nms)`).
    nms: Option<configs::Nms>,
}
20
/// The origin of the model-output configuration held by a [`DecoderBuilder`].
#[derive(Debug, Clone, PartialEq)]
enum ConfigSource {
    /// A YAML document, deserialized with `serde_yaml` when the builder is built.
    Yaml(String),
    /// A JSON document, deserialized with `serde_json` when the builder is built.
    Json(String),
    /// An already-constructed configuration, used as-is.
    Config(ConfigOutputs),
}
27
28impl Default for DecoderBuilder {
29 fn default() -> Self {
49 Self {
50 config_src: None,
51 iou_threshold: 0.5,
52 score_threshold: 0.5,
53 nms: Some(configs::Nms::ClassAgnostic),
54 }
55 }
56}
57
58impl DecoderBuilder {
59 pub fn new() -> Self {
79 Self::default()
80 }
81
82 pub fn with_config_yaml_str(mut self, yaml_str: String) -> Self {
99 self.config_src.replace(ConfigSource::Yaml(yaml_str));
100 self
101 }
102
103 pub fn with_config_json_str(mut self, json_str: String) -> Self {
120 self.config_src.replace(ConfigSource::Json(json_str));
121 self
122 }
123
124 pub fn with_config(mut self, config: ConfigOutputs) -> Self {
141 self.config_src.replace(ConfigSource::Config(config));
142 self
143 }
144
145 pub fn with_config_yolo_det(
170 mut self,
171 boxes: configs::Detection,
172 version: Option<DecoderVersion>,
173 ) -> Self {
174 let config = ConfigOutputs {
175 outputs: vec![ConfigOutput::Detection(boxes)],
176 decoder_version: version,
177 ..Default::default()
178 };
179 self.config_src.replace(ConfigSource::Config(config));
180 self
181 }
182
183 pub fn with_config_yolo_split_det(
210 mut self,
211 boxes: configs::Boxes,
212 scores: configs::Scores,
213 ) -> Self {
214 let config = ConfigOutputs {
215 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
216 ..Default::default()
217 };
218 self.config_src.replace(ConfigSource::Config(config));
219 self
220 }
221
222 pub fn with_config_yolo_segdet(
254 mut self,
255 boxes: configs::Detection,
256 protos: configs::Protos,
257 version: Option<DecoderVersion>,
258 ) -> Self {
259 let config = ConfigOutputs {
260 outputs: vec![ConfigOutput::Detection(boxes), ConfigOutput::Protos(protos)],
261 decoder_version: version,
262 ..Default::default()
263 };
264 self.config_src.replace(ConfigSource::Config(config));
265 self
266 }
267
268 pub fn with_config_yolo_split_segdet(
307 mut self,
308 boxes: configs::Boxes,
309 scores: configs::Scores,
310 mask_coefficients: configs::MaskCoefficients,
311 protos: configs::Protos,
312 ) -> Self {
313 let config = ConfigOutputs {
314 outputs: vec![
315 ConfigOutput::Boxes(boxes),
316 ConfigOutput::Scores(scores),
317 ConfigOutput::MaskCoefficients(mask_coefficients),
318 ConfigOutput::Protos(protos),
319 ],
320 ..Default::default()
321 };
322 self.config_src.replace(ConfigSource::Config(config));
323 self
324 }
325
326 pub fn with_config_modelpack_det(
353 mut self,
354 boxes: configs::Boxes,
355 scores: configs::Scores,
356 ) -> Self {
357 let config = ConfigOutputs {
358 outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
359 ..Default::default()
360 };
361 self.config_src.replace(ConfigSource::Config(config));
362 self
363 }
364
365 pub fn with_config_modelpack_det_split(mut self, boxes: Vec<configs::Detection>) -> Self {
404 let outputs = boxes.into_iter().map(ConfigOutput::Detection).collect();
405 let config = ConfigOutputs {
406 outputs,
407 ..Default::default()
408 };
409 self.config_src.replace(ConfigSource::Config(config));
410 self
411 }
412
413 pub fn with_config_modelpack_segdet(
446 mut self,
447 boxes: configs::Boxes,
448 scores: configs::Scores,
449 segmentation: configs::Segmentation,
450 ) -> Self {
451 let config = ConfigOutputs {
452 outputs: vec![
453 ConfigOutput::Boxes(boxes),
454 ConfigOutput::Scores(scores),
455 ConfigOutput::Segmentation(segmentation),
456 ],
457 ..Default::default()
458 };
459 self.config_src.replace(ConfigSource::Config(config));
460 self
461 }
462
463 pub fn with_config_modelpack_segdet_split(
507 mut self,
508 boxes: Vec<configs::Detection>,
509 segmentation: configs::Segmentation,
510 ) -> Self {
511 let mut outputs = boxes
512 .into_iter()
513 .map(ConfigOutput::Detection)
514 .collect::<Vec<_>>();
515 outputs.push(ConfigOutput::Segmentation(segmentation));
516 let config = ConfigOutputs {
517 outputs,
518 ..Default::default()
519 };
520 self.config_src.replace(ConfigSource::Config(config));
521 self
522 }
523
524 pub fn with_config_modelpack_seg(mut self, segmentation: configs::Segmentation) -> Self {
544 let config = ConfigOutputs {
545 outputs: vec![ConfigOutput::Segmentation(segmentation)],
546 ..Default::default()
547 };
548 self.config_src.replace(ConfigSource::Config(config));
549 self
550 }
551
552 pub fn add_output(mut self, output: ConfigOutput) -> Self {
594 if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
595 self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
596 }
597 if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
598 config.outputs.push(Self::normalize_output(output));
599 }
600 self
601 }
602
603 pub fn with_decoder_version(mut self, version: configs::DecoderVersion) -> Self {
632 if !matches!(self.config_src, Some(ConfigSource::Config(_))) {
633 self.config_src = Some(ConfigSource::Config(ConfigOutputs::default()));
634 }
635 if let Some(ConfigSource::Config(ref mut config)) = self.config_src {
636 config.decoder_version = Some(version);
637 }
638 self
639 }
640
641 fn normalize_output(mut output: ConfigOutput) -> ConfigOutput {
643 fn normalize_shape(shape: &mut Vec<usize>, dshape: &[(configs::DimName, usize)]) {
644 if !dshape.is_empty() {
645 *shape = dshape.iter().map(|(_, size)| *size).collect();
646 }
647 }
648 match &mut output {
649 ConfigOutput::Detection(c) => normalize_shape(&mut c.shape, &c.dshape),
650 ConfigOutput::Boxes(c) => normalize_shape(&mut c.shape, &c.dshape),
651 ConfigOutput::Scores(c) => normalize_shape(&mut c.shape, &c.dshape),
652 ConfigOutput::Protos(c) => normalize_shape(&mut c.shape, &c.dshape),
653 ConfigOutput::Segmentation(c) => normalize_shape(&mut c.shape, &c.dshape),
654 ConfigOutput::MaskCoefficients(c) => normalize_shape(&mut c.shape, &c.dshape),
655 ConfigOutput::Mask(c) => normalize_shape(&mut c.shape, &c.dshape),
656 ConfigOutput::Classes(c) => normalize_shape(&mut c.shape, &c.dshape),
657 }
658 output
659 }
660
661 pub fn with_score_threshold(mut self, score_threshold: f32) -> Self {
677 self.score_threshold = score_threshold;
678 self
679 }
680
681 pub fn with_iou_threshold(mut self, iou_threshold: f32) -> Self {
698 self.iou_threshold = iou_threshold;
699 self
700 }
701
702 pub fn with_nms(mut self, nms: Option<configs::Nms>) -> Self {
724 self.nms = nms;
725 self
726 }
727
728 pub fn build(self) -> Result<Decoder, DecoderError> {
745 let config = match self.config_src {
746 Some(ConfigSource::Json(s)) => serde_json::from_str(&s)?,
747 Some(ConfigSource::Yaml(s)) => serde_yaml::from_str(&s)?,
748 Some(ConfigSource::Config(c)) => c,
749 None => return Err(DecoderError::NoConfig),
750 };
751
752 let normalized = Self::get_normalized(&config.outputs);
754
755 let nms = config.nms.or(self.nms);
757 let model_type = Self::get_model_type(config)?;
758
759 Ok(Decoder {
760 model_type,
761 iou_threshold: self.iou_threshold,
762 score_threshold: self.score_threshold,
763 nms,
764 normalized,
765 })
766 }
767
768 fn get_normalized(outputs: &[ConfigOutput]) -> Option<bool> {
773 for output in outputs {
774 match output {
775 ConfigOutput::Detection(det) => return det.normalized,
776 ConfigOutput::Boxes(boxes) => return boxes.normalized,
777 _ => {}
778 }
779 }
780 None }
782
783 fn get_model_type(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
784 let mut yolo = false;
786 let mut modelpack = false;
787 for c in &configs.outputs {
788 match c.decoder() {
789 DecoderType::ModelPack => modelpack = true,
790 DecoderType::Ultralytics => yolo = true,
791 }
792 }
793 match (modelpack, yolo) {
794 (true, true) => Err(DecoderError::InvalidConfig(
795 "Both ModelPack and Yolo outputs found in config".to_string(),
796 )),
797 (true, false) => Self::get_model_type_modelpack(configs),
798 (false, true) => Self::get_model_type_yolo(configs),
799 (false, false) => Err(DecoderError::InvalidConfig(
800 "No outputs found in config".to_string(),
801 )),
802 }
803 }
804
    /// Resolves a YOLO (Ultralytics) [`ModelType`] from the configured outputs.
    ///
    /// Outputs are bucketed by kind; a later output of the same kind overwrites
    /// an earlier one. `Segmentation` and `Mask` outputs are rejected outright.
    ///
    /// The model is treated as end-to-end when `configs.decoder_version` says
    /// so. If no version is set, it is end-to-end when the combined detection
    /// output's dshape names are exactly `[Batch, NumBoxes, NumFeatures]`.
    ///
    /// Selection precedence:
    /// * end-to-end: combined detection (+ protos -> seg-det). Otherwise split
    ///   boxes/scores/classes (+ mask coefficients and protos -> seg-det).
    ///   Anything else is an error.
    /// * otherwise: combined detection with mask coefficients and protos
    ///   (2-way), with protos only, or alone. Otherwise split boxes/scores
    ///   (+ mask coefficients and protos). Anything else is an error. A
    ///   `Classes` output is not used on this path.
    ///
    /// # Errors
    /// Returns [`DecoderError::InvalidConfig`] for rejected or incomplete output
    /// sets, or the error from the matching `verify_*` check.
    fn get_model_type_yolo(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
        let mut boxes = None;
        let mut protos = None;
        let mut split_boxes = None;
        let mut split_scores = None;
        let mut split_mask_coeff = None;
        let mut split_classes = None;
        for c in configs.outputs {
            match c {
                ConfigOutput::Detection(detection) => boxes = Some(detection),
                ConfigOutput::Segmentation(_) => {
                    return Err(DecoderError::InvalidConfig(
                        "Invalid Segmentation output with Yolo decoder".to_string(),
                    ));
                }
                ConfigOutput::Protos(protos_) => protos = Some(protos_),
                ConfigOutput::Mask(_) => {
                    return Err(DecoderError::InvalidConfig(
                        "Invalid Mask output with Yolo decoder".to_string(),
                    ));
                }
                ConfigOutput::Scores(scores) => split_scores = Some(scores),
                ConfigOutput::Boxes(boxes) => split_boxes = Some(boxes),
                ConfigOutput::MaskCoefficients(mask_coeff) => split_mask_coeff = Some(mask_coeff),
                ConfigOutput::Classes(classes) => split_classes = Some(classes),
            }
        }

        // Fallback heuristic used only when no decoder version is configured:
        // the combined detection's dshape names must be exactly in this order.
        let is_end_to_end_dshape = boxes.as_ref().is_some_and(|b| {
            let dims = b.dshape.iter().map(|(d, _)| *d).collect::<Vec<_>>();
            dims == vec![DimName::Batch, DimName::NumBoxes, DimName::NumFeatures]
        });

        // An explicit decoder version always overrides the dshape heuristic.
        let is_end_to_end = configs
            .decoder_version
            .map(|v| v.is_end_to_end())
            .unwrap_or(is_end_to_end_dshape);

        if is_end_to_end {
            // A combined detection output takes precedence over split outputs.
            if let Some(boxes) = boxes {
                if let Some(protos) = protos {
                    Self::verify_yolo_seg_det_26(&boxes, &protos)?;
                    return Ok(ModelType::YoloEndToEndSegDet { boxes, protos });
                } else {
                    Self::verify_yolo_det_26(&boxes)?;
                    return Ok(ModelType::YoloEndToEndDet { boxes });
                }
            } else if let (Some(split_boxes), Some(split_scores), Some(split_classes)) =
                (split_boxes, split_scores, split_classes)
            {
                // Segmentation variant only when both mask coefficients and protos exist.
                if let (Some(split_mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
                    Self::verify_yolo_split_end_to_end_segdet(
                        &split_boxes,
                        &split_scores,
                        &split_classes,
                        &split_mask_coeff,
                        &protos,
                    )?;
                    return Ok(ModelType::YoloSplitEndToEndSegDet {
                        boxes: split_boxes,
                        scores: split_scores,
                        classes: split_classes,
                        mask_coeff: split_mask_coeff,
                        protos,
                    });
                }
                Self::verify_yolo_split_end_to_end_det(
                    &split_boxes,
                    &split_scores,
                    &split_classes,
                )?;
                return Ok(ModelType::YoloSplitEndToEndDet {
                    boxes: split_boxes,
                    scores: split_scores,
                    classes: split_classes,
                });
            } else {
                return Err(DecoderError::InvalidConfig(
                    "Invalid Yolo end-to-end model outputs".to_string(),
                ));
            }
        }

        // Non end-to-end models: the combined detection output again takes precedence.
        if let Some(boxes) = boxes {
            match (split_mask_coeff, protos) {
                (Some(mask_coeff), Some(protos)) => {
                    Self::verify_yolo_seg_det_2way(&boxes, &mask_coeff, &protos)?;
                    Ok(ModelType::YoloSegDet2Way {
                        boxes,
                        mask_coeff,
                        protos,
                    })
                }
                // Protos without separate mask coefficients.
                (_, Some(protos)) => {
                    Self::verify_yolo_seg_det(&boxes, &protos)?;
                    Ok(ModelType::YoloSegDet { boxes, protos })
                }
                _ => {
                    Self::verify_yolo_det(&boxes)?;
                    Ok(ModelType::YoloDet { boxes })
                }
            }
        } else if let (Some(boxes), Some(scores)) = (split_boxes, split_scores) {
            if let (Some(mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
                Self::verify_yolo_split_segdet(&boxes, &scores, &mask_coeff, &protos)?;
                Ok(ModelType::YoloSplitSegDet {
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                })
            } else {
                Self::verify_yolo_split_det(&boxes, &scores)?;
                Ok(ModelType::YoloSplitDet { boxes, scores })
            }
        } else {
            Err(DecoderError::InvalidConfig(
                "Invalid Yolo model outputs".to_string(),
            ))
        }
    }
932
933 fn verify_yolo_det(detect: &configs::Detection) -> Result<(), DecoderError> {
934 if detect.shape.len() != 3 {
935 return Err(DecoderError::InvalidConfig(format!(
936 "Invalid Yolo Detection shape {:?}",
937 detect.shape
938 )));
939 }
940
941 Self::verify_dshapes(
942 &detect.dshape,
943 &detect.shape,
944 "Detection",
945 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
946 )?;
947 if !detect.dshape.is_empty() {
948 Self::get_class_count(&detect.dshape, None, None)?;
949 } else {
950 Self::get_class_count_no_dshape(detect.into(), None)?;
951 }
952
953 Ok(())
954 }
955
956 fn verify_yolo_det_26(detect: &configs::Detection) -> Result<(), DecoderError> {
957 if detect.shape.len() != 3 {
958 return Err(DecoderError::InvalidConfig(format!(
959 "Invalid Yolo Detection shape {:?}",
960 detect.shape
961 )));
962 }
963
964 Self::verify_dshapes(
965 &detect.dshape,
966 &detect.shape,
967 "Detection",
968 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
969 )?;
970
971 if !detect.shape.contains(&6) {
972 return Err(DecoderError::InvalidConfig(
973 "Yolo26 Detection must have 6 features".to_string(),
974 ));
975 }
976
977 Ok(())
978 }
979
980 fn verify_yolo_seg_det(
981 detection: &configs::Detection,
982 protos: &configs::Protos,
983 ) -> Result<(), DecoderError> {
984 if detection.shape.len() != 3 {
985 return Err(DecoderError::InvalidConfig(format!(
986 "Invalid Yolo Detection shape {:?}",
987 detection.shape
988 )));
989 }
990 if protos.shape.len() != 4 {
991 return Err(DecoderError::InvalidConfig(format!(
992 "Invalid Yolo Protos shape {:?}",
993 protos.shape
994 )));
995 }
996
997 Self::verify_dshapes(
998 &detection.dshape,
999 &detection.shape,
1000 "Detection",
1001 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1002 )?;
1003 Self::verify_dshapes(
1004 &protos.dshape,
1005 &protos.shape,
1006 "Protos",
1007 &[
1008 DimName::Batch,
1009 DimName::Height,
1010 DimName::Width,
1011 DimName::NumProtos,
1012 ],
1013 )?;
1014
1015 let protos_count = Self::get_protos_count(&protos.dshape)
1016 .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1017 log::debug!("Protos count: {}", protos_count);
1018 log::debug!("Detection dshape: {:?}", detection.dshape);
1019 let classes = if !detection.dshape.is_empty() {
1020 Self::get_class_count(&detection.dshape, Some(protos_count), None)?
1021 } else {
1022 Self::get_class_count_no_dshape(detection.into(), Some(protos_count))?
1023 };
1024
1025 if classes == 0 {
1026 return Err(DecoderError::InvalidConfig(
1027 "Yolo Segmentation Detection has zero classes".to_string(),
1028 ));
1029 }
1030
1031 Ok(())
1032 }
1033
1034 fn verify_yolo_seg_det_2way(
1035 detection: &configs::Detection,
1036 mask_coeff: &configs::MaskCoefficients,
1037 protos: &configs::Protos,
1038 ) -> Result<(), DecoderError> {
1039 if detection.shape.len() != 3 {
1040 return Err(DecoderError::InvalidConfig(format!(
1041 "Invalid Yolo 2-Way Detection shape {:?}",
1042 detection.shape
1043 )));
1044 }
1045 if mask_coeff.shape.len() != 3 {
1046 return Err(DecoderError::InvalidConfig(format!(
1047 "Invalid Yolo 2-Way Mask Coefficients shape {:?}",
1048 mask_coeff.shape
1049 )));
1050 }
1051 if protos.shape.len() != 4 {
1052 return Err(DecoderError::InvalidConfig(format!(
1053 "Invalid Yolo 2-Way Protos shape {:?}",
1054 protos.shape
1055 )));
1056 }
1057
1058 Self::verify_dshapes(
1059 &detection.dshape,
1060 &detection.shape,
1061 "Detection",
1062 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1063 )?;
1064 Self::verify_dshapes(
1065 &mask_coeff.dshape,
1066 &mask_coeff.shape,
1067 "Mask Coefficients",
1068 &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1069 )?;
1070 Self::verify_dshapes(
1071 &protos.dshape,
1072 &protos.shape,
1073 "Protos",
1074 &[
1075 DimName::Batch,
1076 DimName::Height,
1077 DimName::Width,
1078 DimName::NumProtos,
1079 ],
1080 )?;
1081
1082 let det_num = Self::get_box_count(&detection.dshape).unwrap_or(detection.shape[2]);
1084 let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1085 if det_num != mask_num {
1086 return Err(DecoderError::InvalidConfig(format!(
1087 "Yolo 2-Way Detection num_boxes {} incompatible with Mask Coefficients num_boxes {}",
1088 det_num, mask_num
1089 )));
1090 }
1091
1092 let mask_channels = if !mask_coeff.dshape.is_empty() {
1094 Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1095 DecoderError::InvalidConfig(
1096 "Could not find num_protos in mask_coeff config".to_string(),
1097 )
1098 })?
1099 } else {
1100 mask_coeff.shape[1]
1101 };
1102 let proto_channels = if !protos.dshape.is_empty() {
1103 Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1104 DecoderError::InvalidConfig(
1105 "Could not find num_protos in protos config".to_string(),
1106 )
1107 })?
1108 } else {
1109 protos.shape[1].min(protos.shape[3])
1110 };
1111 if mask_channels != proto_channels {
1112 return Err(DecoderError::InvalidConfig(format!(
1113 "Yolo 2-Way Protos channels {} incompatible with Mask Coefficients channels {}",
1114 proto_channels, mask_channels
1115 )));
1116 }
1117
1118 if !detection.dshape.is_empty() {
1120 Self::get_class_count(&detection.dshape, None, None)?;
1121 } else {
1122 Self::get_class_count_no_dshape(detection.into(), None)?;
1123 }
1124
1125 Ok(())
1126 }
1127
1128 fn verify_yolo_seg_det_26(
1129 detection: &configs::Detection,
1130 protos: &configs::Protos,
1131 ) -> Result<(), DecoderError> {
1132 if detection.shape.len() != 3 {
1133 return Err(DecoderError::InvalidConfig(format!(
1134 "Invalid Yolo Detection shape {:?}",
1135 detection.shape
1136 )));
1137 }
1138 if protos.shape.len() != 4 {
1139 return Err(DecoderError::InvalidConfig(format!(
1140 "Invalid Yolo Protos shape {:?}",
1141 protos.shape
1142 )));
1143 }
1144
1145 Self::verify_dshapes(
1146 &detection.dshape,
1147 &detection.shape,
1148 "Detection",
1149 &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1150 )?;
1151 Self::verify_dshapes(
1152 &protos.dshape,
1153 &protos.shape,
1154 "Protos",
1155 &[
1156 DimName::Batch,
1157 DimName::Height,
1158 DimName::Width,
1159 DimName::NumProtos,
1160 ],
1161 )?;
1162
1163 let protos_count = Self::get_protos_count(&protos.dshape)
1164 .unwrap_or_else(|| protos.shape[1].min(protos.shape[3]));
1165 log::debug!("Protos count: {}", protos_count);
1166 log::debug!("Detection dshape: {:?}", detection.dshape);
1167
1168 if !detection.shape.contains(&(6 + protos_count)) {
1169 return Err(DecoderError::InvalidConfig(format!(
1170 "Yolo26 Segmentation Detection must have num_features be 6 + num_protos = {}",
1171 6 + protos_count
1172 )));
1173 }
1174
1175 Ok(())
1176 }
1177
1178 fn verify_yolo_split_det(
1179 boxes: &configs::Boxes,
1180 scores: &configs::Scores,
1181 ) -> Result<(), DecoderError> {
1182 if boxes.shape.len() != 3 {
1183 return Err(DecoderError::InvalidConfig(format!(
1184 "Invalid Yolo Split Boxes shape {:?}",
1185 boxes.shape
1186 )));
1187 }
1188 if scores.shape.len() != 3 {
1189 return Err(DecoderError::InvalidConfig(format!(
1190 "Invalid Yolo Split Scores shape {:?}",
1191 scores.shape
1192 )));
1193 }
1194
1195 Self::verify_dshapes(
1196 &boxes.dshape,
1197 &boxes.shape,
1198 "Boxes",
1199 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1200 )?;
1201 Self::verify_dshapes(
1202 &scores.dshape,
1203 &scores.shape,
1204 "Scores",
1205 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1206 )?;
1207
1208 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1209 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1210
1211 if boxes_num != scores_num {
1212 return Err(DecoderError::InvalidConfig(format!(
1213 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1214 boxes_num, scores_num
1215 )));
1216 }
1217
1218 Ok(())
1219 }
1220
1221 fn verify_yolo_split_segdet(
1222 boxes: &configs::Boxes,
1223 scores: &configs::Scores,
1224 mask_coeff: &configs::MaskCoefficients,
1225 protos: &configs::Protos,
1226 ) -> Result<(), DecoderError> {
1227 if boxes.shape.len() != 3 {
1228 return Err(DecoderError::InvalidConfig(format!(
1229 "Invalid Yolo Split Boxes shape {:?}",
1230 boxes.shape
1231 )));
1232 }
1233 if scores.shape.len() != 3 {
1234 return Err(DecoderError::InvalidConfig(format!(
1235 "Invalid Yolo Split Scores shape {:?}",
1236 scores.shape
1237 )));
1238 }
1239
1240 if mask_coeff.shape.len() != 3 {
1241 return Err(DecoderError::InvalidConfig(format!(
1242 "Invalid Yolo Split Mask Coefficients shape {:?}",
1243 mask_coeff.shape
1244 )));
1245 }
1246
1247 if protos.shape.len() != 4 {
1248 return Err(DecoderError::InvalidConfig(format!(
1249 "Invalid Yolo Protos shape {:?}",
1250 mask_coeff.shape
1251 )));
1252 }
1253
1254 Self::verify_dshapes(
1255 &boxes.dshape,
1256 &boxes.shape,
1257 "Boxes",
1258 &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1259 )?;
1260 Self::verify_dshapes(
1261 &scores.dshape,
1262 &scores.shape,
1263 "Scores",
1264 &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1265 )?;
1266 Self::verify_dshapes(
1267 &mask_coeff.dshape,
1268 &mask_coeff.shape,
1269 "Mask Coefficients",
1270 &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1271 )?;
1272 Self::verify_dshapes(
1273 &protos.dshape,
1274 &protos.shape,
1275 "Protos",
1276 &[
1277 DimName::Batch,
1278 DimName::Height,
1279 DimName::Width,
1280 DimName::NumProtos,
1281 ],
1282 )?;
1283
1284 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1285 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1286 let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1287
1288 let mask_channels = if !mask_coeff.dshape.is_empty() {
1289 Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1290 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1291 })?
1292 } else {
1293 mask_coeff.shape[1]
1294 };
1295 let proto_channels = if !protos.dshape.is_empty() {
1296 Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1297 DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1298 })?
1299 } else {
1300 protos.shape[1].min(protos.shape[3])
1301 };
1302
1303 if boxes_num != scores_num {
1304 return Err(DecoderError::InvalidConfig(format!(
1305 "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1306 boxes_num, scores_num
1307 )));
1308 }
1309
1310 if boxes_num != mask_num {
1311 return Err(DecoderError::InvalidConfig(format!(
1312 "Yolo Split Detection Boxes num {} incompatible with Mask Coefficients num {}",
1313 boxes_num, mask_num
1314 )));
1315 }
1316
1317 if proto_channels != mask_channels {
1318 return Err(DecoderError::InvalidConfig(format!(
1319 "Yolo Protos channels {} incompatible with Mask Coefficients channels {}",
1320 proto_channels, mask_channels
1321 )));
1322 }
1323
1324 Ok(())
1325 }
1326
1327 fn verify_yolo_split_end_to_end_det(
1328 boxes: &configs::Boxes,
1329 scores: &configs::Scores,
1330 classes: &configs::Classes,
1331 ) -> Result<(), DecoderError> {
1332 if boxes.shape.len() != 3 || !boxes.shape.contains(&4) {
1333 return Err(DecoderError::InvalidConfig(format!(
1334 "Split end-to-end boxes must be [batch, N, 4], got {:?}",
1335 boxes.shape
1336 )));
1337 }
1338 if scores.shape.len() != 3 || !scores.shape.contains(&1) {
1339 return Err(DecoderError::InvalidConfig(format!(
1340 "Split end-to-end scores must be [batch, N, 1], got {:?}",
1341 scores.shape
1342 )));
1343 }
1344 if classes.shape.len() != 3 || !classes.shape.contains(&1) {
1345 return Err(DecoderError::InvalidConfig(format!(
1346 "Split end-to-end classes must be [batch, N, 1], got {:?}",
1347 classes.shape
1348 )));
1349 }
1350 Ok(())
1351 }
1352
1353 fn verify_yolo_split_end_to_end_segdet(
1354 boxes: &configs::Boxes,
1355 scores: &configs::Scores,
1356 classes: &configs::Classes,
1357 mask_coeff: &configs::MaskCoefficients,
1358 protos: &configs::Protos,
1359 ) -> Result<(), DecoderError> {
1360 Self::verify_yolo_split_end_to_end_det(boxes, scores, classes)?;
1361 if mask_coeff.shape.len() != 3 {
1362 return Err(DecoderError::InvalidConfig(format!(
1363 "Invalid split end-to-end mask coefficients shape {:?}",
1364 mask_coeff.shape
1365 )));
1366 }
1367 if protos.shape.len() != 4 {
1368 return Err(DecoderError::InvalidConfig(format!(
1369 "Invalid protos shape {:?}",
1370 protos.shape
1371 )));
1372 }
1373 Ok(())
1374 }
1375
1376 fn get_model_type_modelpack(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1377 let mut split_decoders = Vec::new();
1378 let mut segment_ = None;
1379 let mut scores_ = None;
1380 let mut boxes_ = None;
1381 for c in configs.outputs {
1382 match c {
1383 ConfigOutput::Detection(detection) => split_decoders.push(detection),
1384 ConfigOutput::Segmentation(segmentation) => segment_ = Some(segmentation),
1385 ConfigOutput::Mask(_) => {}
1386 ConfigOutput::Protos(_) => {
1387 return Err(DecoderError::InvalidConfig(
1388 "ModelPack should not have protos".to_string(),
1389 ));
1390 }
1391 ConfigOutput::Scores(scores) => scores_ = Some(scores),
1392 ConfigOutput::Boxes(boxes) => boxes_ = Some(boxes),
1393 ConfigOutput::MaskCoefficients(_) => {
1394 return Err(DecoderError::InvalidConfig(
1395 "ModelPack should not have mask coefficients".to_string(),
1396 ));
1397 }
1398 ConfigOutput::Classes(_) => {
1399 return Err(DecoderError::InvalidConfig(
1400 "ModelPack should not have classes output".to_string(),
1401 ));
1402 }
1403 }
1404 }
1405
1406 if let Some(segmentation) = segment_ {
1407 if !split_decoders.is_empty() {
1408 let classes = Self::verify_modelpack_split_det(&split_decoders)?;
1409 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1410 Ok(ModelType::ModelPackSegDetSplit {
1411 detection: split_decoders,
1412 segmentation,
1413 })
1414 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1415 let classes = Self::verify_modelpack_det(&boxes, &scores)?;
1416 Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1417 Ok(ModelType::ModelPackSegDet {
1418 boxes,
1419 scores,
1420 segmentation,
1421 })
1422 } else {
1423 Self::verify_modelpack_seg(&segmentation, None)?;
1424 Ok(ModelType::ModelPackSeg { segmentation })
1425 }
1426 } else if !split_decoders.is_empty() {
1427 Self::verify_modelpack_split_det(&split_decoders)?;
1428 Ok(ModelType::ModelPackDetSplit {
1429 detection: split_decoders,
1430 })
1431 } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1432 Self::verify_modelpack_det(&boxes, &scores)?;
1433 Ok(ModelType::ModelPackDet { boxes, scores })
1434 } else {
1435 Err(DecoderError::InvalidConfig(
1436 "Invalid ModelPack model outputs".to_string(),
1437 ))
1438 }
1439 }
1440
1441 fn verify_modelpack_det(
1442 boxes: &configs::Boxes,
1443 scores: &configs::Scores,
1444 ) -> Result<usize, DecoderError> {
1445 if boxes.shape.len() != 4 {
1446 return Err(DecoderError::InvalidConfig(format!(
1447 "Invalid ModelPack Boxes shape {:?}",
1448 boxes.shape
1449 )));
1450 }
1451 if scores.shape.len() != 3 {
1452 return Err(DecoderError::InvalidConfig(format!(
1453 "Invalid ModelPack Scores shape {:?}",
1454 scores.shape
1455 )));
1456 }
1457
1458 Self::verify_dshapes(
1459 &boxes.dshape,
1460 &boxes.shape,
1461 "Boxes",
1462 &[
1463 DimName::Batch,
1464 DimName::NumBoxes,
1465 DimName::Padding,
1466 DimName::BoxCoords,
1467 ],
1468 )?;
1469 Self::verify_dshapes(
1470 &scores.dshape,
1471 &scores.shape,
1472 "Scores",
1473 &[DimName::Batch, DimName::NumBoxes, DimName::NumClasses],
1474 )?;
1475
1476 let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[1]);
1477 let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[1]);
1478
1479 if boxes_num != scores_num {
1480 return Err(DecoderError::InvalidConfig(format!(
1481 "ModelPack Detection Boxes num {} incompatible with Scores num {}",
1482 boxes_num, scores_num
1483 )));
1484 }
1485
1486 let num_classes = if !scores.dshape.is_empty() {
1487 Self::get_class_count(&scores.dshape, None, None)?
1488 } else {
1489 Self::get_class_count_no_dshape(scores.into(), None)?
1490 };
1491
1492 Ok(num_classes)
1493 }
1494
1495 fn verify_modelpack_split_det(boxes: &[configs::Detection]) -> Result<usize, DecoderError> {
1496 let mut num_classes = None;
1497 for b in boxes {
1498 let Some(num_anchors) = b.anchors.as_ref().map(|a| a.len()) else {
1499 return Err(DecoderError::InvalidConfig(
1500 "ModelPack Split Detection missing anchors".to_string(),
1501 ));
1502 };
1503
1504 if num_anchors == 0 {
1505 return Err(DecoderError::InvalidConfig(
1506 "ModelPack Split Detection has zero anchors".to_string(),
1507 ));
1508 }
1509
1510 if b.shape.len() != 4 {
1511 return Err(DecoderError::InvalidConfig(format!(
1512 "Invalid ModelPack Split Detection shape {:?}",
1513 b.shape
1514 )));
1515 }
1516
1517 Self::verify_dshapes(
1518 &b.dshape,
1519 &b.shape,
1520 "Split Detection",
1521 &[
1522 DimName::Batch,
1523 DimName::Height,
1524 DimName::Width,
1525 DimName::NumAnchorsXFeatures,
1526 ],
1527 )?;
1528 let classes = if !b.dshape.is_empty() {
1529 Self::get_class_count(&b.dshape, None, Some(num_anchors))?
1530 } else {
1531 Self::get_class_count_no_dshape(b.into(), None)?
1532 };
1533
1534 match num_classes {
1535 Some(n) => {
1536 if n != classes {
1537 return Err(DecoderError::InvalidConfig(format!(
1538 "ModelPack Split Detection inconsistent number of classes: previous {}, current {}",
1539 n, classes
1540 )));
1541 }
1542 }
1543 None => {
1544 num_classes = Some(classes);
1545 }
1546 }
1547 }
1548
1549 Ok(num_classes.unwrap_or(0))
1550 }
1551
1552 fn verify_modelpack_seg(
1553 segmentation: &configs::Segmentation,
1554 classes: Option<usize>,
1555 ) -> Result<(), DecoderError> {
1556 if segmentation.shape.len() != 4 {
1557 return Err(DecoderError::InvalidConfig(format!(
1558 "Invalid ModelPack Segmentation shape {:?}",
1559 segmentation.shape
1560 )));
1561 }
1562 Self::verify_dshapes(
1563 &segmentation.dshape,
1564 &segmentation.shape,
1565 "Segmentation",
1566 &[
1567 DimName::Batch,
1568 DimName::Height,
1569 DimName::Width,
1570 DimName::NumClasses,
1571 ],
1572 )?;
1573
1574 if let Some(classes) = classes {
1575 let seg_classes = if !segmentation.dshape.is_empty() {
1576 Self::get_class_count(&segmentation.dshape, None, None)?
1577 } else {
1578 Self::get_class_count_no_dshape(segmentation.into(), None)?
1579 };
1580
1581 if seg_classes != classes + 1 {
1582 return Err(DecoderError::InvalidConfig(format!(
1583 "ModelPack Segmentation channels {} incompatible with number of classes {}",
1584 seg_classes, classes
1585 )));
1586 }
1587 }
1588 Ok(())
1589 }
1590
1591 fn verify_dshapes(
1593 dshape: &[(DimName, usize)],
1594 shape: &[usize],
1595 name: &str,
1596 dims: &[DimName],
1597 ) -> Result<(), DecoderError> {
1598 for s in shape {
1599 if *s == 0 {
1600 return Err(DecoderError::InvalidConfig(format!(
1601 "{} shape has zero dimension",
1602 name
1603 )));
1604 }
1605 }
1606
1607 if shape.len() != dims.len() {
1608 return Err(DecoderError::InvalidConfig(format!(
1609 "{} shape length {} does not match expected dims length {}",
1610 name,
1611 shape.len(),
1612 dims.len()
1613 )));
1614 }
1615
1616 if dshape.is_empty() {
1617 return Ok(());
1618 }
1619 if dshape.len() != shape.len() {
1621 return Err(DecoderError::InvalidConfig(format!(
1622 "{} dshape length does not match shape length",
1623 name
1624 )));
1625 }
1626
1627 for ((dim_name, dim_size), shape_size) in dshape.iter().zip(shape) {
1629 if dim_size != shape_size {
1630 return Err(DecoderError::InvalidConfig(format!(
1631 "{} dshape dimension {} size {} does not match shape size {}",
1632 name, dim_name, dim_size, shape_size
1633 )));
1634 }
1635 if *dim_name == DimName::Padding && *dim_size != 1 {
1636 return Err(DecoderError::InvalidConfig(
1637 "Padding dimension size must be 1".to_string(),
1638 ));
1639 }
1640
1641 if *dim_name == DimName::BoxCoords && *dim_size != 4 {
1642 return Err(DecoderError::InvalidConfig(
1643 "BoxCoords dimension size must be 4".to_string(),
1644 ));
1645 }
1646 }
1647
1648 let dims_present = HashSet::<DimName>::from_iter(dshape.iter().map(|(name, _)| *name));
1649 for dim in dims {
1650 if !dims_present.contains(dim) {
1651 return Err(DecoderError::InvalidConfig(format!(
1652 "{} dshape missing required dimension {:?}",
1653 name, dim
1654 )));
1655 }
1656 }
1657
1658 Ok(())
1659 }
1660
1661 fn get_box_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1662 for (dim_name, dim_size) in dshape {
1663 if *dim_name == DimName::NumBoxes {
1664 return Some(*dim_size);
1665 }
1666 }
1667 None
1668 }
1669
1670 fn get_class_count_no_dshape(
1671 config: ConfigOutputRef,
1672 protos: Option<usize>,
1673 ) -> Result<usize, DecoderError> {
1674 match config {
1675 ConfigOutputRef::Detection(detection) => match detection.decoder {
1676 DecoderType::Ultralytics => {
1677 if detection.shape[1] <= 4 + protos.unwrap_or(0) {
1678 return Err(DecoderError::InvalidConfig(format!(
1679 "Invalid shape: Yolo num_features {} must be greater than {}",
1680 detection.shape[1],
1681 4 + protos.unwrap_or(0),
1682 )));
1683 }
1684 Ok(detection.shape[1] - 4 - protos.unwrap_or(0))
1685 }
1686 DecoderType::ModelPack => {
1687 let Some(num_anchors) = detection.anchors.as_ref().map(|a| a.len()) else {
1688 return Err(DecoderError::Internal(
1689 "ModelPack Detection missing anchors".to_string(),
1690 ));
1691 };
1692 let anchors_x_features = detection.shape[3];
1693 if anchors_x_features <= num_anchors * 5 {
1694 return Err(DecoderError::InvalidConfig(format!(
1695 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
1696 anchors_x_features,
1697 num_anchors * 5,
1698 )));
1699 }
1700
1701 if !anchors_x_features.is_multiple_of(num_anchors) {
1702 return Err(DecoderError::InvalidConfig(format!(
1703 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
1704 anchors_x_features, num_anchors
1705 )));
1706 }
1707 Ok(anchors_x_features / num_anchors - 5)
1708 }
1709 },
1710
1711 ConfigOutputRef::Scores(scores) => match scores.decoder {
1712 DecoderType::Ultralytics => Ok(scores.shape[1]),
1713 DecoderType::ModelPack => Ok(scores.shape[2]),
1714 },
1715 ConfigOutputRef::Segmentation(seg) => Ok(seg.shape[3]),
1716 _ => Err(DecoderError::Internal(
1717 "Attempted to get class count from unsupported config output".to_owned(),
1718 )),
1719 }
1720 }
1721
1722 fn get_class_count(
1724 dshape: &[(DimName, usize)],
1725 protos: Option<usize>,
1726 anchors: Option<usize>,
1727 ) -> Result<usize, DecoderError> {
1728 if dshape.is_empty() {
1729 return Ok(0);
1730 }
1731 for (dim_name, dim_size) in dshape {
1733 if *dim_name == DimName::NumClasses {
1734 return Ok(*dim_size);
1735 }
1736 }
1737
1738 for (dim_name, dim_size) in dshape {
1741 if *dim_name == DimName::NumFeatures {
1742 let protos = protos.unwrap_or(0);
1743 if protos + 4 >= *dim_size {
1744 return Err(DecoderError::InvalidConfig(format!(
1745 "Invalid shape: Yolo num_features {} must be greater than {}",
1746 *dim_size,
1747 protos + 4,
1748 )));
1749 }
1750 return Ok(*dim_size - 4 - protos);
1751 }
1752 }
1753
1754 if let Some(num_anchors) = anchors {
1757 for (dim_name, dim_size) in dshape {
1758 if *dim_name == DimName::NumAnchorsXFeatures {
1759 let anchors_x_features = *dim_size;
1760 if anchors_x_features <= num_anchors * 5 {
1761 return Err(DecoderError::InvalidConfig(format!(
1762 "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
1763 anchors_x_features,
1764 num_anchors * 5,
1765 )));
1766 }
1767
1768 if !anchors_x_features.is_multiple_of(num_anchors) {
1769 return Err(DecoderError::InvalidConfig(format!(
1770 "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
1771 anchors_x_features, num_anchors
1772 )));
1773 }
1774 return Ok((anchors_x_features / num_anchors) - 5);
1775 }
1776 }
1777 }
1778 Err(DecoderError::InvalidConfig(
1779 "Cannot determine number of classes from dshape".to_owned(),
1780 ))
1781 }
1782
1783 fn get_protos_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1784 for (dim_name, dim_size) in dshape {
1785 if *dim_name == DimName::NumProtos {
1786 return Some(*dim_size);
1787 }
1788 }
1789 None
1790 }
1791}