1use ndarray::{ArrayView, ArrayViewD, Dimension};
5use num_traits::{AsPrimitive, Float};
6
7use crate::{DecoderError, DetectBox, ProtoData, Segmentation};
8
9pub mod config;
10pub mod configs;
11
12use configs::ModelType;
13
/// Model-output decoder.
///
/// The `decode*` methods dispatch on the configured [`ModelType`] to choose which output
/// tensors are read and how they are turned into boxes, masks, or prototype data.
/// `model_type` and `normalized` are private; presumably instances are produced through
/// [`DecoderBuilder`] (confirm in `builder`).
#[derive(Debug, Clone, PartialEq)]
pub struct Decoder {
    // Output layout that every `decode*` method matches on.
    model_type: ModelType,
    // IoU threshold; presumably used during NMS in the post-processing helpers — confirm.
    pub iou_threshold: f32,
    // Score threshold; presumably detections below it are discarded — confirm in helpers.
    pub score_threshold: f32,
    // Optional NMS configuration; the meaning of `None` is defined by the helpers — confirm.
    pub nms: Option<configs::Nms>,
    // Reported by `normalized_boxes()`; presumably whether box coordinates are normalized,
    // with `None` meaning "unspecified" — NOTE(review): confirm against the builder.
    normalized: Option<bool>,
}
29
/// Dynamic-dimension view over an integer tensor, tagged with its element type.
///
/// There is one variant per supported integer element type. Each variant wraps an
/// `ArrayViewD`, so the rank is only known at runtime.
#[derive(Debug)]
pub(crate) enum ArrayViewDQuantized<'a> {
    UInt8(ArrayViewD<'a, u8>),
    Int8(ArrayViewD<'a, i8>),
    UInt16(ArrayViewD<'a, u16>),
    Int16(ArrayViewD<'a, i16>),
    UInt32(ArrayViewD<'a, u32>),
    Int32(ArrayViewD<'a, i32>),
}
39
40impl<'a, D> From<ArrayView<'a, u8, D>> for ArrayViewDQuantized<'a>
41where
42 D: Dimension,
43{
44 fn from(arr: ArrayView<'a, u8, D>) -> Self {
45 Self::UInt8(arr.into_dyn())
46 }
47}
48
49impl<'a, D> From<ArrayView<'a, i8, D>> for ArrayViewDQuantized<'a>
50where
51 D: Dimension,
52{
53 fn from(arr: ArrayView<'a, i8, D>) -> Self {
54 Self::Int8(arr.into_dyn())
55 }
56}
57
58impl<'a, D> From<ArrayView<'a, u16, D>> for ArrayViewDQuantized<'a>
59where
60 D: Dimension,
61{
62 fn from(arr: ArrayView<'a, u16, D>) -> Self {
63 Self::UInt16(arr.into_dyn())
64 }
65}
66
67impl<'a, D> From<ArrayView<'a, i16, D>> for ArrayViewDQuantized<'a>
68where
69 D: Dimension,
70{
71 fn from(arr: ArrayView<'a, i16, D>) -> Self {
72 Self::Int16(arr.into_dyn())
73 }
74}
75
76impl<'a, D> From<ArrayView<'a, u32, D>> for ArrayViewDQuantized<'a>
77where
78 D: Dimension,
79{
80 fn from(arr: ArrayView<'a, u32, D>) -> Self {
81 Self::UInt32(arr.into_dyn())
82 }
83}
84
85impl<'a, D> From<ArrayView<'a, i32, D>> for ArrayViewDQuantized<'a>
86where
87 D: Dimension,
88{
89 fn from(arr: ArrayView<'a, i32, D>) -> Self {
90 Self::Int32(arr.into_dyn())
91 }
92}
93
94impl<'a> ArrayViewDQuantized<'a> {
95 pub(crate) fn shape(&self) -> &[usize] {
97 match self {
98 ArrayViewDQuantized::UInt8(a) => a.shape(),
99 ArrayViewDQuantized::Int8(a) => a.shape(),
100 ArrayViewDQuantized::UInt16(a) => a.shape(),
101 ArrayViewDQuantized::Int16(a) => a.shape(),
102 ArrayViewDQuantized::UInt32(a) => a.shape(),
103 ArrayViewDQuantized::Int32(a) => a.shape(),
104 }
105 }
106}
107
/// Evaluates `$body` with `$var` bound to the inner `ArrayViewD` of `$x`, whichever
/// integer element type it holds.
///
/// `$body` is expanded once per variant, so it must type-check for every element type
/// (`u8`, `i8`, `u16`, `i16`, `u32`, `i32`). `$x` is matched as given. Passing a
/// reference therefore binds `$var` to a reference to the inner view.
macro_rules! with_quantized {
    ($x:expr, $var:ident, $body:expr) => {
        match $x {
            ArrayViewDQuantized::UInt8($var) => $body,
            ArrayViewDQuantized::Int8($var) => $body,
            ArrayViewDQuantized::UInt16($var) => $body,
            ArrayViewDQuantized::Int16($var) => $body,
            ArrayViewDQuantized::UInt32($var) => $body,
            ArrayViewDQuantized::Int32($var) => $body,
        }
    };
}
144
145mod builder;
146mod helpers;
147mod postprocess;
148mod tensor_bridge;
149mod tests;
150
151pub use builder::DecoderBuilder;
152pub use config::{ConfigOutput, ConfigOutputRef, ConfigOutputs};
153
154impl Decoder {
155 pub fn model_type(&self) -> &ModelType {
174 &self.model_type
175 }
176
177 pub fn normalized_boxes(&self) -> Option<bool> {
203 self.normalized
204 }
205
206 pub(crate) fn decode_quantized(
210 &self,
211 outputs: &[ArrayViewDQuantized],
212 output_boxes: &mut Vec<DetectBox>,
213 output_masks: &mut Vec<Segmentation>,
214 ) -> Result<(), DecoderError> {
215 output_boxes.clear();
216 output_masks.clear();
217 match &self.model_type {
218 ModelType::ModelPackSegDet {
219 boxes,
220 scores,
221 segmentation,
222 } => {
223 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
224 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
225 }
226 ModelType::ModelPackSegDetSplit {
227 detection,
228 segmentation,
229 } => {
230 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
231 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
232 }
233 ModelType::ModelPackDet { boxes, scores } => {
234 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)
235 }
236 ModelType::ModelPackDetSplit { detection } => {
237 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)
238 }
239 ModelType::ModelPackSeg { segmentation } => {
240 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
241 }
242 ModelType::YoloDet { boxes } => {
243 self.decode_yolo_det_quantized(outputs, boxes, output_boxes)
244 }
245 ModelType::YoloSegDet { boxes, protos } => self.decode_yolo_segdet_quantized(
246 outputs,
247 boxes,
248 protos,
249 output_boxes,
250 output_masks,
251 ),
252 ModelType::YoloSplitDet { boxes, scores } => {
253 self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)
254 }
255 ModelType::YoloSplitSegDet {
256 boxes,
257 scores,
258 mask_coeff,
259 protos,
260 } => self.decode_yolo_split_segdet_quantized(
261 outputs,
262 boxes,
263 scores,
264 mask_coeff,
265 protos,
266 output_boxes,
267 output_masks,
268 ),
269 ModelType::YoloSegDet2Way {
270 boxes,
271 mask_coeff,
272 protos,
273 } => self.decode_yolo_segdet_2way_quantized(
274 outputs,
275 boxes,
276 mask_coeff,
277 protos,
278 output_boxes,
279 output_masks,
280 ),
281 ModelType::YoloEndToEndDet { boxes } => {
282 self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)
283 }
284 ModelType::YoloEndToEndSegDet { boxes, protos } => self
285 .decode_yolo_end_to_end_segdet_quantized(
286 outputs,
287 boxes,
288 protos,
289 output_boxes,
290 output_masks,
291 ),
292 ModelType::YoloSplitEndToEndDet {
293 boxes,
294 scores,
295 classes,
296 } => self.decode_yolo_split_end_to_end_det_quantized(
297 outputs,
298 boxes,
299 scores,
300 classes,
301 output_boxes,
302 ),
303 ModelType::YoloSplitEndToEndSegDet {
304 boxes,
305 scores,
306 classes,
307 mask_coeff,
308 protos,
309 } => self.decode_yolo_split_end_to_end_segdet_quantized(
310 outputs,
311 boxes,
312 scores,
313 classes,
314 mask_coeff,
315 protos,
316 output_boxes,
317 output_masks,
318 ),
319 }
320 }
321
322 pub(crate) fn decode_float<T>(
326 &self,
327 outputs: &[ArrayViewD<T>],
328 output_boxes: &mut Vec<DetectBox>,
329 output_masks: &mut Vec<Segmentation>,
330 ) -> Result<(), DecoderError>
331 where
332 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
333 f32: AsPrimitive<T>,
334 {
335 output_boxes.clear();
336 output_masks.clear();
337 match &self.model_type {
338 ModelType::ModelPackSegDet {
339 boxes,
340 scores,
341 segmentation,
342 } => {
343 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
344 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
345 }
346 ModelType::ModelPackSegDetSplit {
347 detection,
348 segmentation,
349 } => {
350 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
351 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
352 }
353 ModelType::ModelPackDet { boxes, scores } => {
354 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
355 }
356 ModelType::ModelPackDetSplit { detection } => {
357 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
358 }
359 ModelType::ModelPackSeg { segmentation } => {
360 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
361 }
362 ModelType::YoloDet { boxes } => {
363 self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
364 }
365 ModelType::YoloSegDet { boxes, protos } => {
366 self.decode_yolo_segdet_float(outputs, boxes, protos, output_boxes, output_masks)?;
367 }
368 ModelType::YoloSplitDet { boxes, scores } => {
369 self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
370 }
371 ModelType::YoloSplitSegDet {
372 boxes,
373 scores,
374 mask_coeff,
375 protos,
376 } => {
377 self.decode_yolo_split_segdet_float(
378 outputs,
379 boxes,
380 scores,
381 mask_coeff,
382 protos,
383 output_boxes,
384 output_masks,
385 )?;
386 }
387 ModelType::YoloSegDet2Way {
388 boxes,
389 mask_coeff,
390 protos,
391 } => {
392 self.decode_yolo_segdet_2way_float(
393 outputs,
394 boxes,
395 mask_coeff,
396 protos,
397 output_boxes,
398 output_masks,
399 )?;
400 }
401 ModelType::YoloEndToEndDet { boxes } => {
402 self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
403 }
404 ModelType::YoloEndToEndSegDet { boxes, protos } => {
405 self.decode_yolo_end_to_end_segdet_float(
406 outputs,
407 boxes,
408 protos,
409 output_boxes,
410 output_masks,
411 )?;
412 }
413 ModelType::YoloSplitEndToEndDet {
414 boxes,
415 scores,
416 classes,
417 } => {
418 self.decode_yolo_split_end_to_end_det_float(
419 outputs,
420 boxes,
421 scores,
422 classes,
423 output_boxes,
424 )?;
425 }
426 ModelType::YoloSplitEndToEndSegDet {
427 boxes,
428 scores,
429 classes,
430 mask_coeff,
431 protos,
432 } => {
433 self.decode_yolo_split_end_to_end_segdet_float(
434 outputs,
435 boxes,
436 scores,
437 classes,
438 mask_coeff,
439 protos,
440 output_boxes,
441 output_masks,
442 )?;
443 }
444 }
445 Ok(())
446 }
447
448 pub(crate) fn decode_quantized_proto(
455 &self,
456 outputs: &[ArrayViewDQuantized],
457 output_boxes: &mut Vec<DetectBox>,
458 ) -> Result<Option<ProtoData>, DecoderError> {
459 output_boxes.clear();
460 match &self.model_type {
461 ModelType::ModelPackSegDet { .. }
463 | ModelType::ModelPackSegDetSplit { .. }
464 | ModelType::ModelPackDet { .. }
465 | ModelType::ModelPackDetSplit { .. }
466 | ModelType::ModelPackSeg { .. }
467 | ModelType::YoloDet { .. }
468 | ModelType::YoloSplitDet { .. }
469 | ModelType::YoloEndToEndDet { .. }
470 | ModelType::YoloSplitEndToEndDet { .. } => Ok(None),
471
472 ModelType::YoloSegDet { boxes, protos } => {
473 let proto =
474 self.decode_yolo_segdet_quantized_proto(outputs, boxes, protos, output_boxes)?;
475 Ok(Some(proto))
476 }
477 ModelType::YoloSplitSegDet {
478 boxes,
479 scores,
480 mask_coeff,
481 protos,
482 } => {
483 let proto = self.decode_yolo_split_segdet_quantized_proto(
484 outputs,
485 boxes,
486 scores,
487 mask_coeff,
488 protos,
489 output_boxes,
490 )?;
491 Ok(Some(proto))
492 }
493 ModelType::YoloSegDet2Way {
494 boxes,
495 mask_coeff,
496 protos,
497 } => {
498 let proto = self.decode_yolo_segdet_2way_quantized_proto(
499 outputs,
500 boxes,
501 mask_coeff,
502 protos,
503 output_boxes,
504 )?;
505 Ok(Some(proto))
506 }
507 ModelType::YoloEndToEndSegDet { boxes, protos } => {
508 let proto = self.decode_yolo_end_to_end_segdet_quantized_proto(
509 outputs,
510 boxes,
511 protos,
512 output_boxes,
513 )?;
514 Ok(Some(proto))
515 }
516 ModelType::YoloSplitEndToEndSegDet {
517 boxes,
518 scores,
519 classes,
520 mask_coeff,
521 protos,
522 } => {
523 let proto = self.decode_yolo_split_end_to_end_segdet_quantized_proto(
524 outputs,
525 boxes,
526 scores,
527 classes,
528 mask_coeff,
529 protos,
530 output_boxes,
531 )?;
532 Ok(Some(proto))
533 }
534 }
535 }
536
537 pub(crate) fn decode_float_proto<T>(
543 &self,
544 outputs: &[ArrayViewD<T>],
545 output_boxes: &mut Vec<DetectBox>,
546 ) -> Result<Option<ProtoData>, DecoderError>
547 where
548 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
549 f32: AsPrimitive<T>,
550 {
551 output_boxes.clear();
552 match &self.model_type {
553 ModelType::ModelPackSegDet { .. }
555 | ModelType::ModelPackSegDetSplit { .. }
556 | ModelType::ModelPackDet { .. }
557 | ModelType::ModelPackDetSplit { .. }
558 | ModelType::ModelPackSeg { .. }
559 | ModelType::YoloDet { .. }
560 | ModelType::YoloSplitDet { .. }
561 | ModelType::YoloEndToEndDet { .. }
562 | ModelType::YoloSplitEndToEndDet { .. } => Ok(None),
563
564 ModelType::YoloSegDet { boxes, protos } => {
565 let proto =
566 self.decode_yolo_segdet_float_proto(outputs, boxes, protos, output_boxes)?;
567 Ok(Some(proto))
568 }
569 ModelType::YoloSplitSegDet {
570 boxes,
571 scores,
572 mask_coeff,
573 protos,
574 } => {
575 let proto = self.decode_yolo_split_segdet_float_proto(
576 outputs,
577 boxes,
578 scores,
579 mask_coeff,
580 protos,
581 output_boxes,
582 )?;
583 Ok(Some(proto))
584 }
585 ModelType::YoloSegDet2Way {
586 boxes,
587 mask_coeff,
588 protos,
589 } => {
590 let proto = self.decode_yolo_segdet_2way_float_proto(
591 outputs,
592 boxes,
593 mask_coeff,
594 protos,
595 output_boxes,
596 )?;
597 Ok(Some(proto))
598 }
599 ModelType::YoloEndToEndSegDet { boxes, protos } => {
600 let proto = self.decode_yolo_end_to_end_segdet_float_proto(
601 outputs,
602 boxes,
603 protos,
604 output_boxes,
605 )?;
606 Ok(Some(proto))
607 }
608 ModelType::YoloSplitEndToEndSegDet {
609 boxes,
610 scores,
611 classes,
612 mask_coeff,
613 protos,
614 } => {
615 let proto = self.decode_yolo_split_end_to_end_segdet_float_proto(
616 outputs,
617 boxes,
618 scores,
619 classes,
620 mask_coeff,
621 protos,
622 output_boxes,
623 )?;
624 Ok(Some(proto))
625 }
626 }
627 }
628
629 pub fn decode(
650 &self,
651 outputs: &[&edgefirst_tensor::TensorDyn],
652 output_boxes: &mut Vec<DetectBox>,
653 output_masks: &mut Vec<Segmentation>,
654 ) -> Result<(), DecoderError> {
655 let mapped = tensor_bridge::map_tensors(outputs)?;
656 match &mapped {
657 tensor_bridge::MappedOutputs::Quantized(maps) => {
658 let views = tensor_bridge::quantized_views(maps)?;
659 self.decode_quantized(&views, output_boxes, output_masks)
660 }
661 tensor_bridge::MappedOutputs::Float32(maps) => {
662 let views = tensor_bridge::f32_views(maps)?;
663 self.decode_float(&views, output_boxes, output_masks)
664 }
665 tensor_bridge::MappedOutputs::Float64(maps) => {
666 let views = tensor_bridge::f64_views(maps)?;
667 self.decode_float(&views, output_boxes, output_masks)
668 }
669 }
670 }
671
672 pub fn decode_proto(
689 &self,
690 outputs: &[&edgefirst_tensor::TensorDyn],
691 output_boxes: &mut Vec<DetectBox>,
692 ) -> Result<Option<ProtoData>, DecoderError> {
693 let mapped = tensor_bridge::map_tensors(outputs)?;
694 match &mapped {
695 tensor_bridge::MappedOutputs::Quantized(maps) => {
696 let views = tensor_bridge::quantized_views(maps)?;
697 self.decode_quantized_proto(&views, output_boxes)
698 }
699 tensor_bridge::MappedOutputs::Float32(maps) => {
700 let views = tensor_bridge::f32_views(maps)?;
701 self.decode_float_proto(&views, output_boxes)
702 }
703 tensor_bridge::MappedOutputs::Float64(maps) => {
704 let views = tensor_bridge::f64_views(maps)?;
705 self.decode_float_proto(&views, output_boxes)
706 }
707 }
708 }
709}
710
711#[cfg(feature = "tracker")]
712pub use edgefirst_tracker::TrackInfo;
713
714#[cfg(feature = "tracker")]
715pub use edgefirst_tracker::Tracker;
716
717#[cfg(feature = "tracker")]
718impl Decoder {
719 pub(crate) fn decode_tracked_quantized<TR: edgefirst_tracker::Tracker<DetectBox>>(
723 &self,
724 tracker: &mut TR,
725 timestamp: u64,
726 outputs: &[ArrayViewDQuantized],
727 output_boxes: &mut Vec<DetectBox>,
728 output_masks: &mut Vec<Segmentation>,
729 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
730 ) -> Result<(), DecoderError> {
731 output_boxes.clear();
732 output_masks.clear();
733 output_tracks.clear();
734
735 match &self.model_type {
738 ModelType::YoloSegDet { boxes, protos } => self.decode_tracked_yolo_segdet_quantized(
739 tracker,
740 timestamp,
741 outputs,
742 boxes,
743 protos,
744 output_boxes,
745 output_masks,
746 output_tracks,
747 ),
748 ModelType::YoloSplitSegDet {
749 boxes,
750 scores,
751 mask_coeff,
752 protos,
753 } => self.decode_tracked_yolo_split_segdet_quantized(
754 tracker,
755 timestamp,
756 outputs,
757 boxes,
758 scores,
759 mask_coeff,
760 protos,
761 output_boxes,
762 output_masks,
763 output_tracks,
764 ),
765 ModelType::YoloEndToEndSegDet { boxes, protos } => self
766 .decode_tracked_yolo_end_to_end_segdet_quantized(
767 tracker,
768 timestamp,
769 outputs,
770 boxes,
771 protos,
772 output_boxes,
773 output_masks,
774 output_tracks,
775 ),
776 ModelType::YoloSplitEndToEndSegDet {
777 boxes,
778 scores,
779 classes,
780 mask_coeff,
781 protos,
782 } => self.decode_tracked_yolo_split_end_to_end_segdet_quantized(
783 tracker,
784 timestamp,
785 outputs,
786 boxes,
787 scores,
788 classes,
789 mask_coeff,
790 protos,
791 output_boxes,
792 output_masks,
793 output_tracks,
794 ),
795 ModelType::YoloSegDet2Way {
796 boxes,
797 mask_coeff,
798 protos,
799 } => self.decode_tracked_yolo_segdet_2way_quantized(
800 tracker,
801 timestamp,
802 outputs,
803 boxes,
804 mask_coeff,
805 protos,
806 output_boxes,
807 output_masks,
808 output_tracks,
809 ),
810 _ => {
811 self.decode_quantized(outputs, output_boxes, output_masks)?;
812 Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
813 Ok(())
814 }
815 }
816 }
817
818 pub(crate) fn decode_tracked_float<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
828 &self,
829 tracker: &mut TR,
830 timestamp: u64,
831 outputs: &[ArrayViewD<T>],
832 output_boxes: &mut Vec<DetectBox>,
833 output_masks: &mut Vec<Segmentation>,
834 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
835 ) -> Result<(), DecoderError>
836 where
837 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
838 f32: AsPrimitive<T>,
839 {
840 output_boxes.clear();
841 output_masks.clear();
842 output_tracks.clear();
843 match &self.model_type {
844 ModelType::YoloSegDet { boxes, protos } => {
845 self.decode_tracked_yolo_segdet_float(
846 tracker,
847 timestamp,
848 outputs,
849 boxes,
850 protos,
851 output_boxes,
852 output_masks,
853 output_tracks,
854 )?;
855 }
856 ModelType::YoloSplitSegDet {
857 boxes,
858 scores,
859 mask_coeff,
860 protos,
861 } => {
862 self.decode_tracked_yolo_split_segdet_float(
863 tracker,
864 timestamp,
865 outputs,
866 boxes,
867 scores,
868 mask_coeff,
869 protos,
870 output_boxes,
871 output_masks,
872 output_tracks,
873 )?;
874 }
875 ModelType::YoloEndToEndSegDet { boxes, protos } => {
876 self.decode_tracked_yolo_end_to_end_segdet_float(
877 tracker,
878 timestamp,
879 outputs,
880 boxes,
881 protos,
882 output_boxes,
883 output_masks,
884 output_tracks,
885 )?;
886 }
887 ModelType::YoloSplitEndToEndSegDet {
888 boxes,
889 scores,
890 classes,
891 mask_coeff,
892 protos,
893 } => {
894 self.decode_tracked_yolo_split_end_to_end_segdet_float(
895 tracker,
896 timestamp,
897 outputs,
898 boxes,
899 scores,
900 classes,
901 mask_coeff,
902 protos,
903 output_boxes,
904 output_masks,
905 output_tracks,
906 )?;
907 }
908 ModelType::YoloSegDet2Way {
909 boxes,
910 mask_coeff,
911 protos,
912 } => {
913 self.decode_tracked_yolo_segdet_2way_float(
914 tracker,
915 timestamp,
916 outputs,
917 boxes,
918 mask_coeff,
919 protos,
920 output_boxes,
921 output_masks,
922 output_tracks,
923 )?;
924 }
925 _ => {
926 self.decode_float(outputs, output_boxes, output_masks)?;
927 Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
928 }
929 }
930 Ok(())
931 }
932
933 pub(crate) fn decode_tracked_quantized_proto<TR: edgefirst_tracker::Tracker<DetectBox>>(
940 &self,
941 tracker: &mut TR,
942 timestamp: u64,
943 outputs: &[ArrayViewDQuantized],
944 output_boxes: &mut Vec<DetectBox>,
945 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
946 ) -> Result<Option<ProtoData>, DecoderError> {
947 output_boxes.clear();
948 output_tracks.clear();
949 match &self.model_type {
950 ModelType::ModelPackSegDet { .. }
952 | ModelType::ModelPackSegDetSplit { .. }
953 | ModelType::ModelPackDet { .. }
954 | ModelType::ModelPackDetSplit { .. }
955 | ModelType::ModelPackSeg { .. }
956 | ModelType::YoloDet { .. }
957 | ModelType::YoloSplitDet { .. }
958 | ModelType::YoloEndToEndDet { .. }
959 | ModelType::YoloSplitEndToEndDet { .. } => Ok(None),
960
961 ModelType::YoloSegDet { boxes, protos } => {
962 let proto = self.decode_tracked_yolo_segdet_quantized_proto(
963 tracker,
964 timestamp,
965 outputs,
966 boxes,
967 protos,
968 output_boxes,
969 output_tracks,
970 )?;
971 Ok(Some(proto))
972 }
973 ModelType::YoloSplitSegDet {
974 boxes,
975 scores,
976 mask_coeff,
977 protos,
978 } => {
979 let proto = self.decode_tracked_yolo_split_segdet_quantized_proto(
980 tracker,
981 timestamp,
982 outputs,
983 boxes,
984 scores,
985 mask_coeff,
986 protos,
987 output_boxes,
988 output_tracks,
989 )?;
990 Ok(Some(proto))
991 }
992 ModelType::YoloSegDet2Way {
993 boxes,
994 mask_coeff,
995 protos,
996 } => {
997 let proto = self.decode_tracked_yolo_segdet_2way_quantized_proto(
998 tracker,
999 timestamp,
1000 outputs,
1001 boxes,
1002 mask_coeff,
1003 protos,
1004 output_boxes,
1005 output_tracks,
1006 )?;
1007 Ok(Some(proto))
1008 }
1009 ModelType::YoloEndToEndSegDet { boxes, protos } => {
1010 let proto = self.decode_tracked_yolo_end_to_end_segdet_quantized_proto(
1011 tracker,
1012 timestamp,
1013 outputs,
1014 boxes,
1015 protos,
1016 output_boxes,
1017 output_tracks,
1018 )?;
1019 Ok(Some(proto))
1020 }
1021 ModelType::YoloSplitEndToEndSegDet {
1022 boxes,
1023 scores,
1024 classes,
1025 mask_coeff,
1026 protos,
1027 } => {
1028 let proto = self.decode_tracked_yolo_split_end_to_end_segdet_quantized_proto(
1029 tracker,
1030 timestamp,
1031 outputs,
1032 boxes,
1033 scores,
1034 classes,
1035 mask_coeff,
1036 protos,
1037 output_boxes,
1038 output_tracks,
1039 )?;
1040 Ok(Some(proto))
1041 }
1042 }
1043 }
1044
1045 pub(crate) fn decode_tracked_float_proto<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
1051 &self,
1052 tracker: &mut TR,
1053 timestamp: u64,
1054 outputs: &[ArrayViewD<T>],
1055 output_boxes: &mut Vec<DetectBox>,
1056 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1057 ) -> Result<Option<ProtoData>, DecoderError>
1058 where
1059 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
1060 f32: AsPrimitive<T>,
1061 {
1062 output_boxes.clear();
1063 output_tracks.clear();
1064 match &self.model_type {
1065 ModelType::ModelPackSegDet { .. }
1067 | ModelType::ModelPackSegDetSplit { .. }
1068 | ModelType::ModelPackDet { .. }
1069 | ModelType::ModelPackDetSplit { .. }
1070 | ModelType::ModelPackSeg { .. }
1071 | ModelType::YoloDet { .. }
1072 | ModelType::YoloSplitDet { .. }
1073 | ModelType::YoloEndToEndDet { .. }
1074 | ModelType::YoloSplitEndToEndDet { .. } => Ok(None),
1075
1076 ModelType::YoloSegDet { boxes, protos } => {
1077 let proto = self.decode_tracked_yolo_segdet_float_proto(
1078 tracker,
1079 timestamp,
1080 outputs,
1081 boxes,
1082 protos,
1083 output_boxes,
1084 output_tracks,
1085 )?;
1086 Ok(Some(proto))
1087 }
1088 ModelType::YoloSplitSegDet {
1089 boxes,
1090 scores,
1091 mask_coeff,
1092 protos,
1093 } => {
1094 let proto = self.decode_tracked_yolo_split_segdet_float_proto(
1095 tracker,
1096 timestamp,
1097 outputs,
1098 boxes,
1099 scores,
1100 mask_coeff,
1101 protos,
1102 output_boxes,
1103 output_tracks,
1104 )?;
1105 Ok(Some(proto))
1106 }
1107 ModelType::YoloSegDet2Way {
1108 boxes,
1109 mask_coeff,
1110 protos,
1111 } => {
1112 let proto = self.decode_tracked_yolo_segdet_2way_float_proto(
1113 tracker,
1114 timestamp,
1115 outputs,
1116 boxes,
1117 mask_coeff,
1118 protos,
1119 output_boxes,
1120 output_tracks,
1121 )?;
1122 Ok(Some(proto))
1123 }
1124 ModelType::YoloEndToEndSegDet { boxes, protos } => {
1125 let proto = self.decode_tracked_yolo_end_to_end_segdet_float_proto(
1126 tracker,
1127 timestamp,
1128 outputs,
1129 boxes,
1130 protos,
1131 output_boxes,
1132 output_tracks,
1133 )?;
1134 Ok(Some(proto))
1135 }
1136 ModelType::YoloSplitEndToEndSegDet {
1137 boxes,
1138 scores,
1139 classes,
1140 mask_coeff,
1141 protos,
1142 } => {
1143 let proto = self.decode_tracked_yolo_split_end_to_end_segdet_float_proto(
1144 tracker,
1145 timestamp,
1146 outputs,
1147 boxes,
1148 scores,
1149 classes,
1150 mask_coeff,
1151 protos,
1152 output_boxes,
1153 output_tracks,
1154 )?;
1155 Ok(Some(proto))
1156 }
1157 }
1158 }
1159
1160 pub fn decode_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
1184 &self,
1185 tracker: &mut TR,
1186 timestamp: u64,
1187 outputs: &[&edgefirst_tensor::TensorDyn],
1188 output_boxes: &mut Vec<DetectBox>,
1189 output_masks: &mut Vec<Segmentation>,
1190 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1191 ) -> Result<(), DecoderError> {
1192 let mapped = tensor_bridge::map_tensors(outputs)?;
1193 match &mapped {
1194 tensor_bridge::MappedOutputs::Quantized(maps) => {
1195 let views = tensor_bridge::quantized_views(maps)?;
1196 self.decode_tracked_quantized(
1197 tracker,
1198 timestamp,
1199 &views,
1200 output_boxes,
1201 output_masks,
1202 output_tracks,
1203 )
1204 }
1205 tensor_bridge::MappedOutputs::Float32(maps) => {
1206 let views = tensor_bridge::f32_views(maps)?;
1207 self.decode_tracked_float(
1208 tracker,
1209 timestamp,
1210 &views,
1211 output_boxes,
1212 output_masks,
1213 output_tracks,
1214 )
1215 }
1216 tensor_bridge::MappedOutputs::Float64(maps) => {
1217 let views = tensor_bridge::f64_views(maps)?;
1218 self.decode_tracked_float(
1219 tracker,
1220 timestamp,
1221 &views,
1222 output_boxes,
1223 output_masks,
1224 output_tracks,
1225 )
1226 }
1227 }
1228 }
1229
1230 pub fn decode_proto_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
1250 &self,
1251 tracker: &mut TR,
1252 timestamp: u64,
1253 outputs: &[&edgefirst_tensor::TensorDyn],
1254 output_boxes: &mut Vec<DetectBox>,
1255 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1256 ) -> Result<Option<ProtoData>, DecoderError> {
1257 let mapped = tensor_bridge::map_tensors(outputs)?;
1258 match &mapped {
1259 tensor_bridge::MappedOutputs::Quantized(maps) => {
1260 let views = tensor_bridge::quantized_views(maps)?;
1261 self.decode_tracked_quantized_proto(
1262 tracker,
1263 timestamp,
1264 &views,
1265 output_boxes,
1266 output_tracks,
1267 )
1268 }
1269 tensor_bridge::MappedOutputs::Float32(maps) => {
1270 let views = tensor_bridge::f32_views(maps)?;
1271 self.decode_tracked_float_proto(
1272 tracker,
1273 timestamp,
1274 &views,
1275 output_boxes,
1276 output_tracks,
1277 )
1278 }
1279 tensor_bridge::MappedOutputs::Float64(maps) => {
1280 let views = tensor_bridge::f64_views(maps)?;
1281 self.decode_tracked_float_proto(
1282 tracker,
1283 timestamp,
1284 &views,
1285 output_boxes,
1286 output_tracks,
1287 )
1288 }
1289 }
1290 }
1291}