1use ndarray::{ArrayView, ArrayViewD, Dimension};
5use num_traits::{AsPrimitive, Float};
6
7use crate::{DecoderError, DetectBox, ProtoData, Segmentation};
8
9pub mod config;
10pub mod configs;
11
12use configs::ModelType;
13
/// Post-processing decoder that turns raw model output tensors into
/// detection boxes, segmentation masks and (optionally) prototype data.
///
/// The variant stored in `model_type` selects which decode routine every
/// `decode*` method dispatches to. Instances are presumably constructed via
/// [`DecoderBuilder`] (TODO confirm there is no other constructor).
#[derive(Debug, Clone, PartialEq)]
pub struct Decoder {
    /// Output layout of the model; selects the decode path in every `decode*` method.
    model_type: ModelType,
    /// IoU threshold — presumably used by non-maximum suppression (confirm in `postprocess`).
    pub iou_threshold: f32,
    /// Score threshold — presumably the minimum confidence a detection must reach (confirm).
    pub score_threshold: f32,
    /// Optional NMS configuration; the meaning of `None` is defined by the
    /// decode routines outside this file (TODO confirm).
    pub nms: Option<configs::Nms>,
    /// Box-normalization flag, `None` when unspecified; exposed through
    /// [`Decoder::normalized_boxes`].
    normalized: Option<bool>,
}
29
/// Borrowed, dynamically-dimensioned ndarray view over one of several integer
/// element types, used as the input of the quantized decode paths.
///
/// Values are normally built through the `From<ArrayView<'a, T, D>>` impls,
/// which convert any dimensionality `D` to a dynamic-dimension view.
#[derive(Debug)]
pub(crate) enum ArrayViewDQuantized<'a> {
    /// Unsigned 8-bit elements.
    UInt8(ArrayViewD<'a, u8>),
    /// Signed 8-bit elements.
    Int8(ArrayViewD<'a, i8>),
    /// Unsigned 16-bit elements.
    UInt16(ArrayViewD<'a, u16>),
    /// Signed 16-bit elements.
    Int16(ArrayViewD<'a, i16>),
    /// Unsigned 32-bit elements.
    UInt32(ArrayViewD<'a, u32>),
    /// Signed 32-bit elements.
    Int32(ArrayViewD<'a, i32>),
}
39
40impl<'a, D> From<ArrayView<'a, u8, D>> for ArrayViewDQuantized<'a>
41where
42 D: Dimension,
43{
44 fn from(arr: ArrayView<'a, u8, D>) -> Self {
45 Self::UInt8(arr.into_dyn())
46 }
47}
48
49impl<'a, D> From<ArrayView<'a, i8, D>> for ArrayViewDQuantized<'a>
50where
51 D: Dimension,
52{
53 fn from(arr: ArrayView<'a, i8, D>) -> Self {
54 Self::Int8(arr.into_dyn())
55 }
56}
57
58impl<'a, D> From<ArrayView<'a, u16, D>> for ArrayViewDQuantized<'a>
59where
60 D: Dimension,
61{
62 fn from(arr: ArrayView<'a, u16, D>) -> Self {
63 Self::UInt16(arr.into_dyn())
64 }
65}
66
67impl<'a, D> From<ArrayView<'a, i16, D>> for ArrayViewDQuantized<'a>
68where
69 D: Dimension,
70{
71 fn from(arr: ArrayView<'a, i16, D>) -> Self {
72 Self::Int16(arr.into_dyn())
73 }
74}
75
76impl<'a, D> From<ArrayView<'a, u32, D>> for ArrayViewDQuantized<'a>
77where
78 D: Dimension,
79{
80 fn from(arr: ArrayView<'a, u32, D>) -> Self {
81 Self::UInt32(arr.into_dyn())
82 }
83}
84
85impl<'a, D> From<ArrayView<'a, i32, D>> for ArrayViewDQuantized<'a>
86where
87 D: Dimension,
88{
89 fn from(arr: ArrayView<'a, i32, D>) -> Self {
90 Self::Int32(arr.into_dyn())
91 }
92}
93
94impl<'a> ArrayViewDQuantized<'a> {
95 pub(crate) fn shape(&self) -> &[usize] {
97 match self {
98 ArrayViewDQuantized::UInt8(a) => a.shape(),
99 ArrayViewDQuantized::Int8(a) => a.shape(),
100 ArrayViewDQuantized::UInt16(a) => a.shape(),
101 ArrayViewDQuantized::Int16(a) => a.shape(),
102 ArrayViewDQuantized::UInt32(a) => a.shape(),
103 ArrayViewDQuantized::Int32(a) => a.shape(),
104 }
105 }
106}
107
/// Evaluates `$body` with `$var` bound to the inner `ArrayViewD` of the
/// [`ArrayViewDQuantized`] expression `$x`, whichever variant it holds.
///
/// `$x` is evaluated once. `$body` is expanded once per variant, so it is
/// type-checked separately for each element type (`u8`, `i8`, `u16`, `i16`,
/// `u32`, `i32`). The macro evaluates to the value of `$body`.
macro_rules! with_quantized {
    ($x:expr, $var:ident, $body:expr) => {
        match $x {
            ArrayViewDQuantized::UInt8($var) => $body,
            ArrayViewDQuantized::Int8($var) => $body,
            ArrayViewDQuantized::UInt16($var) => $body,
            ArrayViewDQuantized::Int16($var) => $body,
            ArrayViewDQuantized::UInt32($var) => $body,
            ArrayViewDQuantized::Int32($var) => $body,
        }
    };
}
144
145mod builder;
146mod helpers;
147mod postprocess;
148mod tensor_bridge;
149mod tests;
150
151pub use builder::DecoderBuilder;
152pub use config::{ConfigOutput, ConfigOutputRef, ConfigOutputs};
153
154impl Decoder {
155 pub fn model_type(&self) -> &ModelType {
174 &self.model_type
175 }
176
177 pub fn normalized_boxes(&self) -> Option<bool> {
203 self.normalized
204 }
205
206 pub(crate) fn decode_quantized(
210 &self,
211 outputs: &[ArrayViewDQuantized],
212 output_boxes: &mut Vec<DetectBox>,
213 output_masks: &mut Vec<Segmentation>,
214 ) -> Result<(), DecoderError> {
215 output_boxes.clear();
216 output_masks.clear();
217 match &self.model_type {
218 ModelType::ModelPackSegDet {
219 boxes,
220 scores,
221 segmentation,
222 } => {
223 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
224 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
225 }
226 ModelType::ModelPackSegDetSplit {
227 detection,
228 segmentation,
229 } => {
230 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
231 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
232 }
233 ModelType::ModelPackDet { boxes, scores } => {
234 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)
235 }
236 ModelType::ModelPackDetSplit { detection } => {
237 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)
238 }
239 ModelType::ModelPackSeg { segmentation } => {
240 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
241 }
242 ModelType::YoloDet { boxes } => {
243 self.decode_yolo_det_quantized(outputs, boxes, output_boxes)
244 }
245 ModelType::YoloSegDet { boxes, protos } => self.decode_yolo_segdet_quantized(
246 outputs,
247 boxes,
248 protos,
249 output_boxes,
250 output_masks,
251 ),
252 ModelType::YoloSplitDet { boxes, scores } => {
253 self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)
254 }
255 ModelType::YoloSplitSegDet {
256 boxes,
257 scores,
258 mask_coeff,
259 protos,
260 } => self.decode_yolo_split_segdet_quantized(
261 outputs,
262 boxes,
263 scores,
264 mask_coeff,
265 protos,
266 output_boxes,
267 output_masks,
268 ),
269 ModelType::YoloSegDet2Way {
270 boxes,
271 mask_coeff,
272 protos,
273 } => self.decode_yolo_segdet_2way_quantized(
274 outputs,
275 boxes,
276 mask_coeff,
277 protos,
278 output_boxes,
279 output_masks,
280 ),
281 ModelType::YoloEndToEndDet { boxes } => {
282 self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)
283 }
284 ModelType::YoloEndToEndSegDet { boxes, protos } => self
285 .decode_yolo_end_to_end_segdet_quantized(
286 outputs,
287 boxes,
288 protos,
289 output_boxes,
290 output_masks,
291 ),
292 ModelType::YoloSplitEndToEndDet {
293 boxes,
294 scores,
295 classes,
296 } => self.decode_yolo_split_end_to_end_det_quantized(
297 outputs,
298 boxes,
299 scores,
300 classes,
301 output_boxes,
302 ),
303 ModelType::YoloSplitEndToEndSegDet {
304 boxes,
305 scores,
306 classes,
307 mask_coeff,
308 protos,
309 } => self.decode_yolo_split_end_to_end_segdet_quantized(
310 outputs,
311 boxes,
312 scores,
313 classes,
314 mask_coeff,
315 protos,
316 output_boxes,
317 output_masks,
318 ),
319 }
320 }
321
322 pub(crate) fn decode_float<T>(
326 &self,
327 outputs: &[ArrayViewD<T>],
328 output_boxes: &mut Vec<DetectBox>,
329 output_masks: &mut Vec<Segmentation>,
330 ) -> Result<(), DecoderError>
331 where
332 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
333 f32: AsPrimitive<T>,
334 {
335 output_boxes.clear();
336 output_masks.clear();
337 match &self.model_type {
338 ModelType::ModelPackSegDet {
339 boxes,
340 scores,
341 segmentation,
342 } => {
343 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
344 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
345 }
346 ModelType::ModelPackSegDetSplit {
347 detection,
348 segmentation,
349 } => {
350 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
351 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
352 }
353 ModelType::ModelPackDet { boxes, scores } => {
354 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
355 }
356 ModelType::ModelPackDetSplit { detection } => {
357 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
358 }
359 ModelType::ModelPackSeg { segmentation } => {
360 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
361 }
362 ModelType::YoloDet { boxes } => {
363 self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
364 }
365 ModelType::YoloSegDet { boxes, protos } => {
366 self.decode_yolo_segdet_float(outputs, boxes, protos, output_boxes, output_masks)?;
367 }
368 ModelType::YoloSplitDet { boxes, scores } => {
369 self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
370 }
371 ModelType::YoloSplitSegDet {
372 boxes,
373 scores,
374 mask_coeff,
375 protos,
376 } => {
377 self.decode_yolo_split_segdet_float(
378 outputs,
379 boxes,
380 scores,
381 mask_coeff,
382 protos,
383 output_boxes,
384 output_masks,
385 )?;
386 }
387 ModelType::YoloSegDet2Way {
388 boxes,
389 mask_coeff,
390 protos,
391 } => {
392 self.decode_yolo_segdet_2way_float(
393 outputs,
394 boxes,
395 mask_coeff,
396 protos,
397 output_boxes,
398 output_masks,
399 )?;
400 }
401 ModelType::YoloEndToEndDet { boxes } => {
402 self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
403 }
404 ModelType::YoloEndToEndSegDet { boxes, protos } => {
405 self.decode_yolo_end_to_end_segdet_float(
406 outputs,
407 boxes,
408 protos,
409 output_boxes,
410 output_masks,
411 )?;
412 }
413 ModelType::YoloSplitEndToEndDet {
414 boxes,
415 scores,
416 classes,
417 } => {
418 self.decode_yolo_split_end_to_end_det_float(
419 outputs,
420 boxes,
421 scores,
422 classes,
423 output_boxes,
424 )?;
425 }
426 ModelType::YoloSplitEndToEndSegDet {
427 boxes,
428 scores,
429 classes,
430 mask_coeff,
431 protos,
432 } => {
433 self.decode_yolo_split_end_to_end_segdet_float(
434 outputs,
435 boxes,
436 scores,
437 classes,
438 mask_coeff,
439 protos,
440 output_boxes,
441 output_masks,
442 )?;
443 }
444 }
445 Ok(())
446 }
447
448 pub(crate) fn decode_quantized_proto(
455 &self,
456 outputs: &[ArrayViewDQuantized],
457 output_boxes: &mut Vec<DetectBox>,
458 ) -> Result<Option<ProtoData>, DecoderError> {
459 output_boxes.clear();
460 match &self.model_type {
461 ModelType::ModelPackDet { boxes, scores } => {
463 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
464 Ok(None)
465 }
466 ModelType::ModelPackDetSplit { detection } => {
467 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
468 Ok(None)
469 }
470 ModelType::YoloDet { boxes } => {
471 self.decode_yolo_det_quantized(outputs, boxes, output_boxes)?;
472 Ok(None)
473 }
474 ModelType::YoloSplitDet { boxes, scores } => {
475 self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)?;
476 Ok(None)
477 }
478 ModelType::YoloEndToEndDet { boxes } => {
479 self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)?;
480 Ok(None)
481 }
482 ModelType::YoloSplitEndToEndDet {
483 boxes,
484 scores,
485 classes,
486 } => {
487 self.decode_yolo_split_end_to_end_det_quantized(
488 outputs,
489 boxes,
490 scores,
491 classes,
492 output_boxes,
493 )?;
494 Ok(None)
495 }
496 ModelType::ModelPackSegDet { boxes, scores, .. } => {
498 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
499 Ok(None)
500 }
501 ModelType::ModelPackSegDetSplit { detection, .. } => {
502 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
503 Ok(None)
504 }
505 ModelType::ModelPackSeg { .. } => Ok(None),
506
507 ModelType::YoloSegDet { boxes, protos } => {
508 let proto =
509 self.decode_yolo_segdet_quantized_proto(outputs, boxes, protos, output_boxes)?;
510 Ok(Some(proto))
511 }
512 ModelType::YoloSplitSegDet {
513 boxes,
514 scores,
515 mask_coeff,
516 protos,
517 } => {
518 let proto = self.decode_yolo_split_segdet_quantized_proto(
519 outputs,
520 boxes,
521 scores,
522 mask_coeff,
523 protos,
524 output_boxes,
525 )?;
526 Ok(Some(proto))
527 }
528 ModelType::YoloSegDet2Way {
529 boxes,
530 mask_coeff,
531 protos,
532 } => {
533 let proto = self.decode_yolo_segdet_2way_quantized_proto(
534 outputs,
535 boxes,
536 mask_coeff,
537 protos,
538 output_boxes,
539 )?;
540 Ok(Some(proto))
541 }
542 ModelType::YoloEndToEndSegDet { boxes, protos } => {
543 let proto = self.decode_yolo_end_to_end_segdet_quantized_proto(
544 outputs,
545 boxes,
546 protos,
547 output_boxes,
548 )?;
549 Ok(Some(proto))
550 }
551 ModelType::YoloSplitEndToEndSegDet {
552 boxes,
553 scores,
554 classes,
555 mask_coeff,
556 protos,
557 } => {
558 let proto = self.decode_yolo_split_end_to_end_segdet_quantized_proto(
559 outputs,
560 boxes,
561 scores,
562 classes,
563 mask_coeff,
564 protos,
565 output_boxes,
566 )?;
567 Ok(Some(proto))
568 }
569 }
570 }
571
572 pub(crate) fn decode_float_proto<T>(
579 &self,
580 outputs: &[ArrayViewD<T>],
581 output_boxes: &mut Vec<DetectBox>,
582 ) -> Result<Option<ProtoData>, DecoderError>
583 where
584 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
585 f32: AsPrimitive<T>,
586 {
587 output_boxes.clear();
588 match &self.model_type {
589 ModelType::ModelPackDet { boxes, scores } => {
591 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
592 Ok(None)
593 }
594 ModelType::ModelPackDetSplit { detection } => {
595 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
596 Ok(None)
597 }
598 ModelType::YoloDet { boxes } => {
599 self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
600 Ok(None)
601 }
602 ModelType::YoloSplitDet { boxes, scores } => {
603 self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
604 Ok(None)
605 }
606 ModelType::YoloEndToEndDet { boxes } => {
607 self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
608 Ok(None)
609 }
610 ModelType::YoloSplitEndToEndDet {
611 boxes,
612 scores,
613 classes,
614 } => {
615 self.decode_yolo_split_end_to_end_det_float(
616 outputs,
617 boxes,
618 scores,
619 classes,
620 output_boxes,
621 )?;
622 Ok(None)
623 }
624 ModelType::ModelPackSegDet { boxes, scores, .. } => {
626 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
627 Ok(None)
628 }
629 ModelType::ModelPackSegDetSplit { detection, .. } => {
630 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
631 Ok(None)
632 }
633 ModelType::ModelPackSeg { .. } => Ok(None),
634
635 ModelType::YoloSegDet { boxes, protos } => {
636 let proto =
637 self.decode_yolo_segdet_float_proto(outputs, boxes, protos, output_boxes)?;
638 Ok(Some(proto))
639 }
640 ModelType::YoloSplitSegDet {
641 boxes,
642 scores,
643 mask_coeff,
644 protos,
645 } => {
646 let proto = self.decode_yolo_split_segdet_float_proto(
647 outputs,
648 boxes,
649 scores,
650 mask_coeff,
651 protos,
652 output_boxes,
653 )?;
654 Ok(Some(proto))
655 }
656 ModelType::YoloSegDet2Way {
657 boxes,
658 mask_coeff,
659 protos,
660 } => {
661 let proto = self.decode_yolo_segdet_2way_float_proto(
662 outputs,
663 boxes,
664 mask_coeff,
665 protos,
666 output_boxes,
667 )?;
668 Ok(Some(proto))
669 }
670 ModelType::YoloEndToEndSegDet { boxes, protos } => {
671 let proto = self.decode_yolo_end_to_end_segdet_float_proto(
672 outputs,
673 boxes,
674 protos,
675 output_boxes,
676 )?;
677 Ok(Some(proto))
678 }
679 ModelType::YoloSplitEndToEndSegDet {
680 boxes,
681 scores,
682 classes,
683 mask_coeff,
684 protos,
685 } => {
686 let proto = self.decode_yolo_split_end_to_end_segdet_float_proto(
687 outputs,
688 boxes,
689 scores,
690 classes,
691 mask_coeff,
692 protos,
693 output_boxes,
694 )?;
695 Ok(Some(proto))
696 }
697 }
698 }
699
700 pub fn decode(
721 &self,
722 outputs: &[&edgefirst_tensor::TensorDyn],
723 output_boxes: &mut Vec<DetectBox>,
724 output_masks: &mut Vec<Segmentation>,
725 ) -> Result<(), DecoderError> {
726 let mapped = tensor_bridge::map_tensors(outputs)?;
727 match &mapped {
728 tensor_bridge::MappedOutputs::Quantized(maps) => {
729 let views = tensor_bridge::quantized_views(maps)?;
730 self.decode_quantized(&views, output_boxes, output_masks)
731 }
732 tensor_bridge::MappedOutputs::Float32(maps) => {
733 let views = tensor_bridge::f32_views(maps)?;
734 self.decode_float(&views, output_boxes, output_masks)
735 }
736 tensor_bridge::MappedOutputs::Float64(maps) => {
737 let views = tensor_bridge::f64_views(maps)?;
738 self.decode_float(&views, output_boxes, output_masks)
739 }
740 }
741 }
742
743 pub fn decode_proto(
761 &self,
762 outputs: &[&edgefirst_tensor::TensorDyn],
763 output_boxes: &mut Vec<DetectBox>,
764 ) -> Result<Option<ProtoData>, DecoderError> {
765 let mapped = tensor_bridge::map_tensors(outputs)?;
766 match &mapped {
767 tensor_bridge::MappedOutputs::Quantized(maps) => {
768 let views = tensor_bridge::quantized_views(maps)?;
769 self.decode_quantized_proto(&views, output_boxes)
770 }
771 tensor_bridge::MappedOutputs::Float32(maps) => {
772 let views = tensor_bridge::f32_views(maps)?;
773 self.decode_float_proto(&views, output_boxes)
774 }
775 tensor_bridge::MappedOutputs::Float64(maps) => {
776 let views = tensor_bridge::f64_views(maps)?;
777 self.decode_float_proto(&views, output_boxes)
778 }
779 }
780 }
781}
782
783#[cfg(feature = "tracker")]
784pub use edgefirst_tracker::TrackInfo;
785
786#[cfg(feature = "tracker")]
787pub use edgefirst_tracker::Tracker;
788
#[cfg(feature = "tracker")]
impl Decoder {
    /// Decodes quantized output views and associates the detections with
    /// tracks through `tracker`.
    ///
    /// All three output vectors are cleared first. The prototype-based YOLO
    /// segmentation layouts (`YoloSegDet`, `YoloSplitSegDet`,
    /// `YoloEndToEndSegDet`, `YoloSplitEndToEndSegDet`, `YoloSegDet2Way`) use
    /// dedicated tracked decoders. Every other layout is decoded with
    /// `decode_quantized`, and its boxes are then passed to `update_tracker`.
    ///
    /// `timestamp` is forwarded unchanged; its unit is defined by the tracker
    /// (TODO confirm).
    ///
    /// # Errors
    ///
    /// Propagates any [`DecoderError`] from the underlying decoder.
    pub(crate) fn decode_tracked_quantized<TR: edgefirst_tracker::Tracker<DetectBox>>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewDQuantized],
        output_boxes: &mut Vec<DetectBox>,
        output_masks: &mut Vec<Segmentation>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<(), DecoderError> {
        output_boxes.clear();
        output_masks.clear();
        output_tracks.clear();

        match &self.model_type {
            ModelType::YoloSegDet { boxes, protos } => self.decode_tracked_yolo_segdet_quantized(
                tracker,
                timestamp,
                outputs,
                boxes,
                protos,
                output_boxes,
                output_masks,
                output_tracks,
            ),
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => self.decode_tracked_yolo_split_segdet_quantized(
                tracker,
                timestamp,
                outputs,
                boxes,
                scores,
                mask_coeff,
                protos,
                output_boxes,
                output_masks,
                output_tracks,
            ),
            ModelType::YoloEndToEndSegDet { boxes, protos } => self
                .decode_tracked_yolo_end_to_end_segdet_quantized(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                ),
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => self.decode_tracked_yolo_split_end_to_end_segdet_quantized(
                tracker,
                timestamp,
                outputs,
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
                output_boxes,
                output_masks,
                output_tracks,
            ),
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => self.decode_tracked_yolo_segdet_2way_quantized(
                tracker,
                timestamp,
                outputs,
                boxes,
                mask_coeff,
                protos,
                output_boxes,
                output_masks,
                output_tracks,
            ),
            _ => {
                // No dedicated tracked decoder: decode normally, then track the boxes.
                self.decode_quantized(outputs, output_boxes, output_masks)?;
                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
                Ok(())
            }
        }
    }

    /// Floating-point counterpart of `decode_tracked_quantized`.
    ///
    /// All three output vectors are cleared first. The prototype-based YOLO
    /// segmentation layouts use dedicated tracked decoders. Every other layout
    /// is decoded with `decode_float`, and its boxes are then passed to
    /// `update_tracker`.
    ///
    /// # Errors
    ///
    /// Propagates any [`DecoderError`] from the underlying decoder.
    pub(crate) fn decode_tracked_float<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewD<T>],
        output_boxes: &mut Vec<DetectBox>,
        output_masks: &mut Vec<Segmentation>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<(), DecoderError>
    where
        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
        f32: AsPrimitive<T>,
    {
        output_boxes.clear();
        output_masks.clear();
        output_tracks.clear();
        match &self.model_type {
            ModelType::YoloSegDet { boxes, protos } => {
                self.decode_tracked_yolo_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                self.decode_tracked_yolo_split_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                self.decode_tracked_yolo_end_to_end_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                self.decode_tracked_yolo_split_end_to_end_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => {
                self.decode_tracked_yolo_segdet_2way_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            _ => {
                // No dedicated tracked decoder: decode normally, then track the boxes.
                self.decode_float(outputs, output_boxes, output_masks)?;
                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
            }
        }
        Ok(())
    }

    /// Tracked variant of `decode_quantized_proto`.
    ///
    /// `output_boxes` and `output_tracks` are cleared first. The prototype-based
    /// YOLO segmentation layouts return `Some(ProtoData)` from their dedicated
    /// tracked decoders. Every other layout is decoded through
    /// `decode_quantized_proto`, so it behaves exactly like the untracked proto
    /// path: only detections are decoded, and ModelPack segmentation is skipped
    /// instead of being decoded and thrown away. The boxes are then passed to
    /// `update_tracker`, and the result (`None` for these layouts) is returned.
    ///
    /// # Errors
    ///
    /// Propagates any [`DecoderError`] from the underlying decoder.
    pub(crate) fn decode_tracked_quantized_proto<TR: edgefirst_tracker::Tracker<DetectBox>>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewDQuantized],
        output_boxes: &mut Vec<DetectBox>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<Option<ProtoData>, DecoderError> {
        output_boxes.clear();
        output_tracks.clear();
        match &self.model_type {
            ModelType::YoloSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_segdet_2way_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_end_to_end_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_end_to_end_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            _ => {
                // Use the proto path (not `decode_quantized`) so ModelPack
                // segmentation is not decoded into a mask buffer that would
                // only be discarded; the detections it yields are identical.
                let proto = self.decode_quantized_proto(outputs, output_boxes)?;
                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
                Ok(proto)
            }
        }
    }

    /// Tracked variant of `decode_float_proto`.
    ///
    /// `output_boxes` and `output_tracks` are cleared first. The prototype-based
    /// YOLO segmentation layouts return `Some(ProtoData)` from their dedicated
    /// tracked decoders. Every other layout is decoded through
    /// `decode_float_proto` (detections only; ModelPack segmentation is
    /// skipped). Its boxes are then passed to `update_tracker`, and the result
    /// (`None` for these layouts) is returned.
    ///
    /// # Errors
    ///
    /// Propagates any [`DecoderError`] from the underlying decoder.
    pub(crate) fn decode_tracked_float_proto<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewD<T>],
        output_boxes: &mut Vec<DetectBox>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<Option<ProtoData>, DecoderError>
    where
        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
        f32: AsPrimitive<T>,
    {
        output_boxes.clear();
        output_tracks.clear();
        match &self.model_type {
            ModelType::YoloSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_segdet_2way_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_end_to_end_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_end_to_end_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            _ => {
                // Use the proto path (not `decode_float`) so ModelPack
                // segmentation is not decoded into a mask buffer that would
                // only be discarded; the detections it yields are identical.
                let proto = self.decode_float_proto(outputs, output_boxes)?;
                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
                Ok(proto)
            }
        }
    }

    /// Decodes model output tensors and updates `tracker` with the resulting
    /// detections. The boxes, masks and track information are written to the
    /// output vectors, which are cleared first.
    ///
    /// The tensors are mapped once. The integer, `f32` or `f64` path is then
    /// chosen from the kind of mapping `tensor_bridge::map_tensors` reports.
    /// `timestamp` is forwarded to the tracker unchanged (unit defined by the
    /// tracker — TODO confirm).
    ///
    /// # Errors
    ///
    /// Returns a [`DecoderError`] if the tensors cannot be mapped or viewed, or
    /// if decoding fails.
    pub fn decode_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[&edgefirst_tensor::TensorDyn],
        output_boxes: &mut Vec<DetectBox>,
        output_masks: &mut Vec<Segmentation>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<(), DecoderError> {
        let mapped = tensor_bridge::map_tensors(outputs)?;
        match &mapped {
            tensor_bridge::MappedOutputs::Quantized(maps) => {
                let views = tensor_bridge::quantized_views(maps)?;
                self.decode_tracked_quantized(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float32(maps) => {
                let views = tensor_bridge::f32_views(maps)?;
                self.decode_tracked_float(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float64(maps) => {
                let views = tensor_bridge::f64_views(maps)?;
                self.decode_tracked_float(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )
            }
        }
    }

    /// Tracked counterpart of [`Decoder::decode_proto`].
    ///
    /// Decodes detection boxes and updates `tracker` with them. For
    /// prototype-based YOLO segmentation layouts it returns the raw prototype
    /// data; all other layouts return `Ok(None)`.
    ///
    /// # Errors
    ///
    /// Returns a [`DecoderError`] if the tensors cannot be mapped or viewed, or
    /// if decoding fails.
    pub fn decode_proto_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[&edgefirst_tensor::TensorDyn],
        output_boxes: &mut Vec<DetectBox>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<Option<ProtoData>, DecoderError> {
        let mapped = tensor_bridge::map_tensors(outputs)?;
        match &mapped {
            tensor_bridge::MappedOutputs::Quantized(maps) => {
                let views = tensor_bridge::quantized_views(maps)?;
                self.decode_tracked_quantized_proto(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float32(maps) => {
                let views = tensor_bridge::f32_views(maps)?;
                self.decode_tracked_float_proto(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float64(maps) => {
                let views = tensor_bridge::f64_views(maps)?;
                self.decode_tracked_float_proto(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_tracks,
                )
            }
        }
    }
}
1356}