1use ndarray::{ArrayView, ArrayViewD, Dimension};
5use num_traits::{AsPrimitive, Float};
6
7use crate::{DecoderError, DetectBox, ProtoData, Segmentation};
8
9pub mod config;
10pub mod configs;
11
12use configs::ModelType;
13
/// Post-processing decoder that turns raw model output tensors into detection
/// boxes, segmentation masks, or prototype data, depending on `model_type`.
///
/// `PartialEq` is implemented by hand below rather than derived.
#[derive(Debug, Clone)]
pub struct Decoder {
    /// Output layout of the model; every `decode*` method dispatches on it.
    model_type: ModelType,
    /// IoU threshold. NOTE(review): presumably consumed by NMS inside the
    /// model-specific decoders — confirm there.
    pub iou_threshold: f32,
    /// Score threshold. NOTE(review): presumably the minimum score a detection
    /// needs to be kept — confirm in the model-specific decoders.
    pub score_threshold: f32,
    /// Optional NMS configuration.
    pub nms: Option<configs::Nms>,
    /// Returned as-is by `normalized_boxes()`.
    normalized: Option<bool>,
    /// Optional tensor-merging program. When set, `decode` runs it on the raw
    /// output tensors and decodes the merged arrays on the float path.
    pub(crate) decode_program: Option<merge::DecodeProgram>,
}
34
35impl PartialEq for Decoder {
36 fn eq(&self, other: &Self) -> bool {
37 self.model_type == other.model_type
40 && self.iou_threshold == other.iou_threshold
41 && self.score_threshold == other.score_threshold
42 && self.nms == other.nms
43 && self.normalized == other.normalized
44 && self.decode_program.is_some() == other.decode_program.is_some()
45 }
46}
47
/// Dynamic-rank, read-only view over an integer tensor, tagged with its
/// element type so quantized outputs of different widths can share one slice.
#[derive(Debug)]
pub(crate) enum ArrayViewDQuantized<'a> {
    /// `u8` elements.
    UInt8(ArrayViewD<'a, u8>),
    /// `i8` elements.
    Int8(ArrayViewD<'a, i8>),
    /// `u16` elements.
    UInt16(ArrayViewD<'a, u16>),
    /// `i16` elements.
    Int16(ArrayViewD<'a, i16>),
    /// `u32` elements.
    UInt32(ArrayViewD<'a, u32>),
    /// `i32` elements.
    Int32(ArrayViewD<'a, i32>),
}
57
58impl<'a, D> From<ArrayView<'a, u8, D>> for ArrayViewDQuantized<'a>
59where
60 D: Dimension,
61{
62 fn from(arr: ArrayView<'a, u8, D>) -> Self {
63 Self::UInt8(arr.into_dyn())
64 }
65}
66
67impl<'a, D> From<ArrayView<'a, i8, D>> for ArrayViewDQuantized<'a>
68where
69 D: Dimension,
70{
71 fn from(arr: ArrayView<'a, i8, D>) -> Self {
72 Self::Int8(arr.into_dyn())
73 }
74}
75
76impl<'a, D> From<ArrayView<'a, u16, D>> for ArrayViewDQuantized<'a>
77where
78 D: Dimension,
79{
80 fn from(arr: ArrayView<'a, u16, D>) -> Self {
81 Self::UInt16(arr.into_dyn())
82 }
83}
84
85impl<'a, D> From<ArrayView<'a, i16, D>> for ArrayViewDQuantized<'a>
86where
87 D: Dimension,
88{
89 fn from(arr: ArrayView<'a, i16, D>) -> Self {
90 Self::Int16(arr.into_dyn())
91 }
92}
93
94impl<'a, D> From<ArrayView<'a, u32, D>> for ArrayViewDQuantized<'a>
95where
96 D: Dimension,
97{
98 fn from(arr: ArrayView<'a, u32, D>) -> Self {
99 Self::UInt32(arr.into_dyn())
100 }
101}
102
103impl<'a, D> From<ArrayView<'a, i32, D>> for ArrayViewDQuantized<'a>
104where
105 D: Dimension,
106{
107 fn from(arr: ArrayView<'a, i32, D>) -> Self {
108 Self::Int32(arr.into_dyn())
109 }
110}
111
112impl<'a> ArrayViewDQuantized<'a> {
113 pub(crate) fn shape(&self) -> &[usize] {
115 match self {
116 ArrayViewDQuantized::UInt8(a) => a.shape(),
117 ArrayViewDQuantized::Int8(a) => a.shape(),
118 ArrayViewDQuantized::UInt16(a) => a.shape(),
119 ArrayViewDQuantized::Int16(a) => a.shape(),
120 ArrayViewDQuantized::UInt32(a) => a.shape(),
121 ArrayViewDQuantized::Int32(a) => a.shape(),
122 }
123 }
124}
125
/// Evaluates `$body` once for whichever element type `$x` holds.
///
/// `$x` must evaluate to an `ArrayViewDQuantized` (or a reference to one). In
/// each arm the inner `ArrayViewD<_>` is bound to the caller-chosen name
/// `$var`, so `$body` is type-checked separately for every element type.
///
/// Defined before the `mod` declarations below so those submodules can use it
/// (`macro_rules!` scoping is textual).
macro_rules! with_quantized {
    ($x:expr, $var:ident, $body:expr) => {
        match $x {
            ArrayViewDQuantized::UInt8($var) => $body,
            ArrayViewDQuantized::Int8($var) => $body,
            ArrayViewDQuantized::UInt16($var) => $body,
            ArrayViewDQuantized::Int16($var) => $body,
            ArrayViewDQuantized::UInt32($var) => $body,
            ArrayViewDQuantized::Int32($var) => $body,
        }
    };
}
162
163mod builder;
164mod dfl;
165mod helpers;
166mod merge;
167mod postprocess;
168mod tensor_bridge;
169mod tests;
170
171pub use builder::DecoderBuilder;
172pub use config::{ConfigOutput, ConfigOutputRef, ConfigOutputs};
173
174impl Decoder {
175 pub fn model_type(&self) -> &ModelType {
194 &self.model_type
195 }
196
197 pub fn normalized_boxes(&self) -> Option<bool> {
223 self.normalized
224 }
225
226 pub(crate) fn decode_quantized(
230 &self,
231 outputs: &[ArrayViewDQuantized],
232 output_boxes: &mut Vec<DetectBox>,
233 output_masks: &mut Vec<Segmentation>,
234 ) -> Result<(), DecoderError> {
235 output_boxes.clear();
236 output_masks.clear();
237 match &self.model_type {
238 ModelType::ModelPackSegDet {
239 boxes,
240 scores,
241 segmentation,
242 } => {
243 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
244 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
245 }
246 ModelType::ModelPackSegDetSplit {
247 detection,
248 segmentation,
249 } => {
250 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
251 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
252 }
253 ModelType::ModelPackDet { boxes, scores } => {
254 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)
255 }
256 ModelType::ModelPackDetSplit { detection } => {
257 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)
258 }
259 ModelType::ModelPackSeg { segmentation } => {
260 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
261 }
262 ModelType::YoloDet { boxes } => {
263 self.decode_yolo_det_quantized(outputs, boxes, output_boxes)
264 }
265 ModelType::YoloSegDet { boxes, protos } => self.decode_yolo_segdet_quantized(
266 outputs,
267 boxes,
268 protos,
269 output_boxes,
270 output_masks,
271 ),
272 ModelType::YoloSplitDet { boxes, scores } => {
273 self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)
274 }
275 ModelType::YoloSplitSegDet {
276 boxes,
277 scores,
278 mask_coeff,
279 protos,
280 } => self.decode_yolo_split_segdet_quantized(
281 outputs,
282 boxes,
283 scores,
284 mask_coeff,
285 protos,
286 output_boxes,
287 output_masks,
288 ),
289 ModelType::YoloSegDet2Way {
290 boxes,
291 mask_coeff,
292 protos,
293 } => self.decode_yolo_segdet_2way_quantized(
294 outputs,
295 boxes,
296 mask_coeff,
297 protos,
298 output_boxes,
299 output_masks,
300 ),
301 ModelType::YoloEndToEndDet { boxes } => {
302 self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)
303 }
304 ModelType::YoloEndToEndSegDet { boxes, protos } => self
305 .decode_yolo_end_to_end_segdet_quantized(
306 outputs,
307 boxes,
308 protos,
309 output_boxes,
310 output_masks,
311 ),
312 ModelType::YoloSplitEndToEndDet {
313 boxes,
314 scores,
315 classes,
316 } => self.decode_yolo_split_end_to_end_det_quantized(
317 outputs,
318 boxes,
319 scores,
320 classes,
321 output_boxes,
322 ),
323 ModelType::YoloSplitEndToEndSegDet {
324 boxes,
325 scores,
326 classes,
327 mask_coeff,
328 protos,
329 } => self.decode_yolo_split_end_to_end_segdet_quantized(
330 outputs,
331 boxes,
332 scores,
333 classes,
334 mask_coeff,
335 protos,
336 output_boxes,
337 output_masks,
338 ),
339 }
340 }
341
342 pub(crate) fn decode_float<T>(
346 &self,
347 outputs: &[ArrayViewD<T>],
348 output_boxes: &mut Vec<DetectBox>,
349 output_masks: &mut Vec<Segmentation>,
350 ) -> Result<(), DecoderError>
351 where
352 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
353 f32: AsPrimitive<T>,
354 {
355 output_boxes.clear();
356 output_masks.clear();
357 match &self.model_type {
358 ModelType::ModelPackSegDet {
359 boxes,
360 scores,
361 segmentation,
362 } => {
363 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
364 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
365 }
366 ModelType::ModelPackSegDetSplit {
367 detection,
368 segmentation,
369 } => {
370 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
371 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
372 }
373 ModelType::ModelPackDet { boxes, scores } => {
374 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
375 }
376 ModelType::ModelPackDetSplit { detection } => {
377 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
378 }
379 ModelType::ModelPackSeg { segmentation } => {
380 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
381 }
382 ModelType::YoloDet { boxes } => {
383 self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
384 }
385 ModelType::YoloSegDet { boxes, protos } => {
386 self.decode_yolo_segdet_float(outputs, boxes, protos, output_boxes, output_masks)?;
387 }
388 ModelType::YoloSplitDet { boxes, scores } => {
389 self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
390 }
391 ModelType::YoloSplitSegDet {
392 boxes,
393 scores,
394 mask_coeff,
395 protos,
396 } => {
397 self.decode_yolo_split_segdet_float(
398 outputs,
399 boxes,
400 scores,
401 mask_coeff,
402 protos,
403 output_boxes,
404 output_masks,
405 )?;
406 }
407 ModelType::YoloSegDet2Way {
408 boxes,
409 mask_coeff,
410 protos,
411 } => {
412 self.decode_yolo_segdet_2way_float(
413 outputs,
414 boxes,
415 mask_coeff,
416 protos,
417 output_boxes,
418 output_masks,
419 )?;
420 }
421 ModelType::YoloEndToEndDet { boxes } => {
422 self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
423 }
424 ModelType::YoloEndToEndSegDet { boxes, protos } => {
425 self.decode_yolo_end_to_end_segdet_float(
426 outputs,
427 boxes,
428 protos,
429 output_boxes,
430 output_masks,
431 )?;
432 }
433 ModelType::YoloSplitEndToEndDet {
434 boxes,
435 scores,
436 classes,
437 } => {
438 self.decode_yolo_split_end_to_end_det_float(
439 outputs,
440 boxes,
441 scores,
442 classes,
443 output_boxes,
444 )?;
445 }
446 ModelType::YoloSplitEndToEndSegDet {
447 boxes,
448 scores,
449 classes,
450 mask_coeff,
451 protos,
452 } => {
453 self.decode_yolo_split_end_to_end_segdet_float(
454 outputs,
455 boxes,
456 scores,
457 classes,
458 mask_coeff,
459 protos,
460 output_boxes,
461 output_masks,
462 )?;
463 }
464 }
465 Ok(())
466 }
467
468 pub(crate) fn decode_quantized_proto(
475 &self,
476 outputs: &[ArrayViewDQuantized],
477 output_boxes: &mut Vec<DetectBox>,
478 ) -> Result<Option<ProtoData>, DecoderError> {
479 output_boxes.clear();
480 match &self.model_type {
481 ModelType::ModelPackDet { boxes, scores } => {
483 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
484 Ok(None)
485 }
486 ModelType::ModelPackDetSplit { detection } => {
487 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
488 Ok(None)
489 }
490 ModelType::YoloDet { boxes } => {
491 self.decode_yolo_det_quantized(outputs, boxes, output_boxes)?;
492 Ok(None)
493 }
494 ModelType::YoloSplitDet { boxes, scores } => {
495 self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)?;
496 Ok(None)
497 }
498 ModelType::YoloEndToEndDet { boxes } => {
499 self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)?;
500 Ok(None)
501 }
502 ModelType::YoloSplitEndToEndDet {
503 boxes,
504 scores,
505 classes,
506 } => {
507 self.decode_yolo_split_end_to_end_det_quantized(
508 outputs,
509 boxes,
510 scores,
511 classes,
512 output_boxes,
513 )?;
514 Ok(None)
515 }
516 ModelType::ModelPackSegDet { boxes, scores, .. } => {
518 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
519 Ok(None)
520 }
521 ModelType::ModelPackSegDetSplit { detection, .. } => {
522 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
523 Ok(None)
524 }
525 ModelType::ModelPackSeg { .. } => Ok(None),
526
527 ModelType::YoloSegDet { boxes, protos } => {
528 let proto =
529 self.decode_yolo_segdet_quantized_proto(outputs, boxes, protos, output_boxes)?;
530 Ok(Some(proto))
531 }
532 ModelType::YoloSplitSegDet {
533 boxes,
534 scores,
535 mask_coeff,
536 protos,
537 } => {
538 let proto = self.decode_yolo_split_segdet_quantized_proto(
539 outputs,
540 boxes,
541 scores,
542 mask_coeff,
543 protos,
544 output_boxes,
545 )?;
546 Ok(Some(proto))
547 }
548 ModelType::YoloSegDet2Way {
549 boxes,
550 mask_coeff,
551 protos,
552 } => {
553 let proto = self.decode_yolo_segdet_2way_quantized_proto(
554 outputs,
555 boxes,
556 mask_coeff,
557 protos,
558 output_boxes,
559 )?;
560 Ok(Some(proto))
561 }
562 ModelType::YoloEndToEndSegDet { boxes, protos } => {
563 let proto = self.decode_yolo_end_to_end_segdet_quantized_proto(
564 outputs,
565 boxes,
566 protos,
567 output_boxes,
568 )?;
569 Ok(Some(proto))
570 }
571 ModelType::YoloSplitEndToEndSegDet {
572 boxes,
573 scores,
574 classes,
575 mask_coeff,
576 protos,
577 } => {
578 let proto = self.decode_yolo_split_end_to_end_segdet_quantized_proto(
579 outputs,
580 boxes,
581 scores,
582 classes,
583 mask_coeff,
584 protos,
585 output_boxes,
586 )?;
587 Ok(Some(proto))
588 }
589 }
590 }
591
592 pub(crate) fn decode_float_proto<T>(
599 &self,
600 outputs: &[ArrayViewD<T>],
601 output_boxes: &mut Vec<DetectBox>,
602 ) -> Result<Option<ProtoData>, DecoderError>
603 where
604 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
605 f32: AsPrimitive<T>,
606 {
607 output_boxes.clear();
608 match &self.model_type {
609 ModelType::ModelPackDet { boxes, scores } => {
611 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
612 Ok(None)
613 }
614 ModelType::ModelPackDetSplit { detection } => {
615 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
616 Ok(None)
617 }
618 ModelType::YoloDet { boxes } => {
619 self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
620 Ok(None)
621 }
622 ModelType::YoloSplitDet { boxes, scores } => {
623 self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
624 Ok(None)
625 }
626 ModelType::YoloEndToEndDet { boxes } => {
627 self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
628 Ok(None)
629 }
630 ModelType::YoloSplitEndToEndDet {
631 boxes,
632 scores,
633 classes,
634 } => {
635 self.decode_yolo_split_end_to_end_det_float(
636 outputs,
637 boxes,
638 scores,
639 classes,
640 output_boxes,
641 )?;
642 Ok(None)
643 }
644 ModelType::ModelPackSegDet { boxes, scores, .. } => {
646 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
647 Ok(None)
648 }
649 ModelType::ModelPackSegDetSplit { detection, .. } => {
650 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
651 Ok(None)
652 }
653 ModelType::ModelPackSeg { .. } => Ok(None),
654
655 ModelType::YoloSegDet { boxes, protos } => {
656 let proto =
657 self.decode_yolo_segdet_float_proto(outputs, boxes, protos, output_boxes)?;
658 Ok(Some(proto))
659 }
660 ModelType::YoloSplitSegDet {
661 boxes,
662 scores,
663 mask_coeff,
664 protos,
665 } => {
666 let proto = self.decode_yolo_split_segdet_float_proto(
667 outputs,
668 boxes,
669 scores,
670 mask_coeff,
671 protos,
672 output_boxes,
673 )?;
674 Ok(Some(proto))
675 }
676 ModelType::YoloSegDet2Way {
677 boxes,
678 mask_coeff,
679 protos,
680 } => {
681 let proto = self.decode_yolo_segdet_2way_float_proto(
682 outputs,
683 boxes,
684 mask_coeff,
685 protos,
686 output_boxes,
687 )?;
688 Ok(Some(proto))
689 }
690 ModelType::YoloEndToEndSegDet { boxes, protos } => {
691 let proto = self.decode_yolo_end_to_end_segdet_float_proto(
692 outputs,
693 boxes,
694 protos,
695 output_boxes,
696 )?;
697 Ok(Some(proto))
698 }
699 ModelType::YoloSplitEndToEndSegDet {
700 boxes,
701 scores,
702 classes,
703 mask_coeff,
704 protos,
705 } => {
706 let proto = self.decode_yolo_split_end_to_end_segdet_float_proto(
707 outputs,
708 boxes,
709 scores,
710 classes,
711 mask_coeff,
712 protos,
713 output_boxes,
714 )?;
715 Ok(Some(proto))
716 }
717 }
718 }
719
720 pub fn decode(
741 &self,
742 outputs: &[&edgefirst_tensor::TensorDyn],
743 output_boxes: &mut Vec<DetectBox>,
744 output_masks: &mut Vec<Segmentation>,
745 ) -> Result<(), DecoderError> {
746 if let Some(program) = &self.decode_program {
749 let merged = program.execute(outputs)?;
750 let views: Vec<_> = merged.iter().map(|a| a.view()).collect();
751 return self.decode_float(&views, output_boxes, output_masks);
752 }
753
754 let mapped = tensor_bridge::map_tensors(outputs)?;
755 match &mapped {
756 tensor_bridge::MappedOutputs::Quantized(maps) => {
757 let views = tensor_bridge::quantized_views(maps)?;
758 self.decode_quantized(&views, output_boxes, output_masks)
759 }
760 tensor_bridge::MappedOutputs::Float32(maps) => {
761 let views = tensor_bridge::f32_views(maps)?;
762 self.decode_float(&views, output_boxes, output_masks)
763 }
764 tensor_bridge::MappedOutputs::Float64(maps) => {
765 let views = tensor_bridge::f64_views(maps)?;
766 self.decode_float(&views, output_boxes, output_masks)
767 }
768 }
769 }
770
771 pub fn decode_proto(
789 &self,
790 outputs: &[&edgefirst_tensor::TensorDyn],
791 output_boxes: &mut Vec<DetectBox>,
792 ) -> Result<Option<ProtoData>, DecoderError> {
793 let mapped = tensor_bridge::map_tensors(outputs)?;
794 match &mapped {
795 tensor_bridge::MappedOutputs::Quantized(maps) => {
796 let views = tensor_bridge::quantized_views(maps)?;
797 self.decode_quantized_proto(&views, output_boxes)
798 }
799 tensor_bridge::MappedOutputs::Float32(maps) => {
800 let views = tensor_bridge::f32_views(maps)?;
801 self.decode_float_proto(&views, output_boxes)
802 }
803 tensor_bridge::MappedOutputs::Float64(maps) => {
804 let views = tensor_bridge::f64_views(maps)?;
805 self.decode_float_proto(&views, output_boxes)
806 }
807 }
808 }
809}
810
811#[cfg(feature = "tracker")]
812pub use edgefirst_tracker::TrackInfo;
813
814#[cfg(feature = "tracker")]
815pub use edgefirst_tracker::Tracker;
816
#[cfg(feature = "tracker")]
impl Decoder {
    /// Decodes quantized outputs and updates `tracker` with the result.
    ///
    /// `output_boxes`, `output_masks` and `output_tracks` are cleared first.
    /// YOLO segmentation layouts use dedicated tracked decoders. Every other
    /// layout is decoded with [`Self::decode_quantized`], and the boxes are
    /// then handed to `Self::update_tracker` together with `timestamp`.
    ///
    /// # Errors
    ///
    /// Propagates any [`DecoderError`] from decoding.
    pub(crate) fn decode_tracked_quantized<TR: edgefirst_tracker::Tracker<DetectBox>>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewDQuantized],
        output_boxes: &mut Vec<DetectBox>,
        output_masks: &mut Vec<Segmentation>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<(), DecoderError> {
        output_boxes.clear();
        output_masks.clear();
        output_tracks.clear();

        match &self.model_type {
            ModelType::YoloSegDet { boxes, protos } => self.decode_tracked_yolo_segdet_quantized(
                tracker,
                timestamp,
                outputs,
                boxes,
                protos,
                output_boxes,
                output_masks,
                output_tracks,
            ),
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => self.decode_tracked_yolo_split_segdet_quantized(
                tracker,
                timestamp,
                outputs,
                boxes,
                scores,
                mask_coeff,
                protos,
                output_boxes,
                output_masks,
                output_tracks,
            ),
            ModelType::YoloEndToEndSegDet { boxes, protos } => self
                .decode_tracked_yolo_end_to_end_segdet_quantized(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                ),
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => self.decode_tracked_yolo_split_end_to_end_segdet_quantized(
                tracker,
                timestamp,
                outputs,
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
                output_boxes,
                output_masks,
                output_tracks,
            ),
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => self.decode_tracked_yolo_segdet_2way_quantized(
                tracker,
                timestamp,
                outputs,
                boxes,
                mask_coeff,
                protos,
                output_boxes,
                output_masks,
                output_tracks,
            ),
            // Remaining layouts: full untracked decode, then a generic tracker update.
            _ => {
                self.decode_quantized(outputs, output_boxes, output_masks)?;
                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
                Ok(())
            }
        }
    }

    /// Floating-point counterpart of [`Self::decode_tracked_quantized`].
    ///
    /// All three output vectors are cleared first. YOLO segmentation layouts
    /// use dedicated tracked decoders. Other layouts are decoded with
    /// [`Self::decode_float`] and then passed to `Self::update_tracker`.
    ///
    /// # Errors
    ///
    /// Propagates any [`DecoderError`] from decoding.
    pub(crate) fn decode_tracked_float<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewD<T>],
        output_boxes: &mut Vec<DetectBox>,
        output_masks: &mut Vec<Segmentation>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<(), DecoderError>
    where
        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
        f32: AsPrimitive<T>,
    {
        output_boxes.clear();
        output_masks.clear();
        output_tracks.clear();
        match &self.model_type {
            ModelType::YoloSegDet { boxes, protos } => {
                self.decode_tracked_yolo_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                self.decode_tracked_yolo_split_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                self.decode_tracked_yolo_end_to_end_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                self.decode_tracked_yolo_split_end_to_end_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => {
                self.decode_tracked_yolo_segdet_2way_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            // Remaining layouts: full untracked decode, then a generic tracker update.
            _ => {
                self.decode_float(outputs, output_boxes, output_masks)?;
                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
            }
        }
        Ok(())
    }

    /// Tracked variant of [`Self::decode_quantized_proto`].
    ///
    /// `output_boxes` and `output_tracks` are cleared first. YOLO segmentation
    /// layouts use dedicated tracked proto decoders and return `Ok(Some(_))`.
    /// Other layouts decode their boxes via [`Self::decode_quantized_proto`],
    /// which does not build segmentation masks, and then call
    /// `Self::update_tracker`.
    ///
    /// # Errors
    ///
    /// Propagates any [`DecoderError`] from decoding.
    pub(crate) fn decode_tracked_quantized_proto<TR: edgefirst_tracker::Tracker<DetectBox>>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewDQuantized],
        output_boxes: &mut Vec<DetectBox>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<Option<ProtoData>, DecoderError> {
        output_boxes.clear();
        output_tracks.clear();
        match &self.model_type {
            ModelType::YoloSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_segdet_2way_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_end_to_end_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_end_to_end_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            _ => {
                // Use the proto decoder rather than the full decoder. It yields
                // the same boxes but does not build segmentation masks, which
                // would be discarded here, and it cannot fail on mask decoding.
                let proto = self.decode_quantized_proto(outputs, output_boxes)?;
                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
                Ok(proto)
            }
        }
    }

    /// Floating-point counterpart of [`Self::decode_tracked_quantized_proto`].
    ///
    /// `output_boxes` and `output_tracks` are cleared first. YOLO segmentation
    /// layouts return `Ok(Some(proto))`. Other layouts decode boxes via
    /// [`Self::decode_float_proto`] (no masks are built) and then call
    /// `Self::update_tracker`.
    ///
    /// # Errors
    ///
    /// Propagates any [`DecoderError`] from decoding.
    pub(crate) fn decode_tracked_float_proto<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewD<T>],
        output_boxes: &mut Vec<DetectBox>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<Option<ProtoData>, DecoderError>
    where
        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
        f32: AsPrimitive<T>,
    {
        output_boxes.clear();
        output_tracks.clear();
        match &self.model_type {
            ModelType::YoloSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_segdet_2way_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_end_to_end_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_end_to_end_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            _ => {
                // Use the proto decoder rather than the full decoder. It yields
                // the same boxes but does not build segmentation masks, which
                // would be discarded here, and it cannot fail on mask decoding.
                let proto = self.decode_float_proto(outputs, output_boxes)?;
                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
                Ok(proto)
            }
        }
    }

    /// Decodes output tensors and updates `tracker` with the detections.
    ///
    /// As in [`Self::decode`], a configured decode program is applied first
    /// and its merged arrays go through the float path. Otherwise the tensors
    /// are mapped and dispatched by element type.
    ///
    /// # Errors
    ///
    /// Returns an error if merging or mapping the tensors fails, or if
    /// decoding fails.
    pub fn decode_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[&edgefirst_tensor::TensorDyn],
        output_boxes: &mut Vec<DetectBox>,
        output_masks: &mut Vec<Segmentation>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<(), DecoderError> {
        // Keep parity with `decode`: raw tensors must be merged before decoding.
        if let Some(program) = &self.decode_program {
            let merged = program.execute(outputs)?;
            let views: Vec<_> = merged.iter().map(|a| a.view()).collect();
            return self.decode_tracked_float(
                tracker,
                timestamp,
                &views,
                output_boxes,
                output_masks,
                output_tracks,
            );
        }

        let mapped = tensor_bridge::map_tensors(outputs)?;
        match &mapped {
            tensor_bridge::MappedOutputs::Quantized(maps) => {
                let views = tensor_bridge::quantized_views(maps)?;
                self.decode_tracked_quantized(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float32(maps) => {
                let views = tensor_bridge::f32_views(maps)?;
                self.decode_tracked_float(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float64(maps) => {
                let views = tensor_bridge::f64_views(maps)?;
                self.decode_tracked_float(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )
            }
        }
    }

    /// Tracked variant of [`Self::decode_proto`]: decodes boxes, updates
    /// `tracker`, and returns prototype data for YOLO segmentation layouts.
    ///
    /// A configured decode program is applied first, as in [`Self::decode`].
    ///
    /// # Errors
    ///
    /// Returns an error if merging or mapping the tensors fails, or if
    /// decoding fails.
    pub fn decode_proto_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[&edgefirst_tensor::TensorDyn],
        output_boxes: &mut Vec<DetectBox>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<Option<ProtoData>, DecoderError> {
        // Keep parity with `decode`: raw tensors must be merged before decoding.
        if let Some(program) = &self.decode_program {
            let merged = program.execute(outputs)?;
            let views: Vec<_> = merged.iter().map(|a| a.view()).collect();
            return self.decode_tracked_float_proto(
                tracker,
                timestamp,
                &views,
                output_boxes,
                output_tracks,
            );
        }

        let mapped = tensor_bridge::map_tensors(outputs)?;
        match &mapped {
            tensor_bridge::MappedOutputs::Quantized(maps) => {
                let views = tensor_bridge::quantized_views(maps)?;
                self.decode_tracked_quantized_proto(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float32(maps) => {
                let views = tensor_bridge::f32_views(maps)?;
                self.decode_tracked_float_proto(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float64(maps) => {
                let views = tensor_bridge::f64_views(maps)?;
                self.decode_tracked_float_proto(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_tracks,
                )
            }
        }
    }
}