1use ndarray::{ArrayView, ArrayViewD, Dimension};
5use num_traits::{AsPrimitive, Float};
6
7use crate::{DecoderError, DetectBox, ProtoData, Segmentation};
8
9pub mod config;
10pub mod configs;
11
12use configs::ModelType;
13
/// Decoder that turns raw model output tensors into detection boxes and
/// segmentation results.
///
/// The stored [`ModelType`] selects which decode path every `decode*` method
/// takes.
#[derive(Debug, Clone)]
pub struct Decoder {
    /// Output layout of the model; all `decode*` methods dispatch on it.
    model_type: ModelType,
    /// IoU threshold.
    ///
    /// NOTE(review): presumably the overlap limit used during NMS — confirm
    /// against the decode helpers, which are not visible here.
    pub iou_threshold: f32,
    /// Score threshold.
    ///
    /// NOTE(review): presumably the minimum score a detection needs to be
    /// kept — confirm against the decode helpers.
    pub score_threshold: f32,
    /// Optional NMS configuration; `None` semantics are defined by the
    /// decode helpers (TODO confirm whether it disables NMS).
    pub nms: Option<configs::Nms>,
    /// Value returned verbatim by [`Decoder::normalized_boxes`].
    normalized: Option<bool>,
    /// Optional program executed on the raw outputs before decoding.
    ///
    /// When `Some`, `decode` and `decode_proto` run it first and decode its
    /// result through the float path instead of mapping the tensors directly.
    pub(crate) decode_program: Option<merge::DecodeProgram>,
}
34
35impl PartialEq for Decoder {
36 fn eq(&self, other: &Self) -> bool {
37 self.model_type == other.model_type
40 && self.iou_threshold == other.iou_threshold
41 && self.score_threshold == other.score_threshold
42 && self.nms == other.nms
43 && self.normalized == other.normalized
44 && self.decode_program.is_some() == other.decode_program.is_some()
45 }
46}
47
/// Borrowed, dynamically-dimensioned view over an integer tensor, tagged by
/// its element type.
///
/// This is the input element type of the `*_quantized` decode paths.
#[derive(Debug)]
pub(crate) enum ArrayViewDQuantized<'a> {
    /// View over `u8` elements.
    UInt8(ArrayViewD<'a, u8>),
    /// View over `i8` elements.
    Int8(ArrayViewD<'a, i8>),
    /// View over `u16` elements.
    UInt16(ArrayViewD<'a, u16>),
    /// View over `i16` elements.
    Int16(ArrayViewD<'a, i16>),
    /// View over `u32` elements.
    UInt32(ArrayViewD<'a, u32>),
    /// View over `i32` elements.
    Int32(ArrayViewD<'a, i32>),
}
57
58impl<'a, D> From<ArrayView<'a, u8, D>> for ArrayViewDQuantized<'a>
59where
60 D: Dimension,
61{
62 fn from(arr: ArrayView<'a, u8, D>) -> Self {
63 Self::UInt8(arr.into_dyn())
64 }
65}
66
67impl<'a, D> From<ArrayView<'a, i8, D>> for ArrayViewDQuantized<'a>
68where
69 D: Dimension,
70{
71 fn from(arr: ArrayView<'a, i8, D>) -> Self {
72 Self::Int8(arr.into_dyn())
73 }
74}
75
76impl<'a, D> From<ArrayView<'a, u16, D>> for ArrayViewDQuantized<'a>
77where
78 D: Dimension,
79{
80 fn from(arr: ArrayView<'a, u16, D>) -> Self {
81 Self::UInt16(arr.into_dyn())
82 }
83}
84
85impl<'a, D> From<ArrayView<'a, i16, D>> for ArrayViewDQuantized<'a>
86where
87 D: Dimension,
88{
89 fn from(arr: ArrayView<'a, i16, D>) -> Self {
90 Self::Int16(arr.into_dyn())
91 }
92}
93
94impl<'a, D> From<ArrayView<'a, u32, D>> for ArrayViewDQuantized<'a>
95where
96 D: Dimension,
97{
98 fn from(arr: ArrayView<'a, u32, D>) -> Self {
99 Self::UInt32(arr.into_dyn())
100 }
101}
102
103impl<'a, D> From<ArrayView<'a, i32, D>> for ArrayViewDQuantized<'a>
104where
105 D: Dimension,
106{
107 fn from(arr: ArrayView<'a, i32, D>) -> Self {
108 Self::Int32(arr.into_dyn())
109 }
110}
111
112impl<'a> ArrayViewDQuantized<'a> {
113 pub(crate) fn shape(&self) -> &[usize] {
115 match self {
116 ArrayViewDQuantized::UInt8(a) => a.shape(),
117 ArrayViewDQuantized::Int8(a) => a.shape(),
118 ArrayViewDQuantized::UInt16(a) => a.shape(),
119 ArrayViewDQuantized::Int16(a) => a.shape(),
120 ArrayViewDQuantized::UInt32(a) => a.shape(),
121 ArrayViewDQuantized::Int32(a) => a.shape(),
122 }
123 }
124}
125
/// Evaluates `$body` with `$var` bound to the typed view stored inside the
/// [`ArrayViewDQuantized`] value `$x`.
///
/// `$body` is expanded once per variant, so it has to compile for every
/// element type the enum can hold, and all expansions must produce the same
/// type.
macro_rules! with_quantized {
    ($x:expr, $var:ident, $body:expr) => {
        match $x {
            ArrayViewDQuantized::UInt8($var) => $body,
            ArrayViewDQuantized::Int8($var) => $body,
            ArrayViewDQuantized::UInt16($var) => $body,
            ArrayViewDQuantized::Int16($var) => $body,
            ArrayViewDQuantized::UInt32($var) => $body,
            ArrayViewDQuantized::Int32($var) => $body,
        }
    };
}
162
163mod builder;
164mod dfl;
165mod helpers;
166mod merge;
167mod postprocess;
168mod tensor_bridge;
169mod tests;
170
171pub use builder::DecoderBuilder;
172pub use config::{ConfigOutput, ConfigOutputRef, ConfigOutputs};
173
174impl Decoder {
    /// Returns the [`ModelType`] this decoder dispatches on when decoding.
    pub fn model_type(&self) -> &ModelType {
        &self.model_type
    }
196
    /// Returns the decoder's stored `normalized` setting unchanged.
    ///
    /// NOTE(review): presumably `Some(true)` means box coordinates are
    /// normalized, `Some(false)` that they are not, and `None` that it is
    /// unknown — confirm against the builder/config code.
    pub fn normalized_boxes(&self) -> Option<bool> {
        self.normalized
    }
225
226 pub(crate) fn decode_quantized(
230 &self,
231 outputs: &[ArrayViewDQuantized],
232 output_boxes: &mut Vec<DetectBox>,
233 output_masks: &mut Vec<Segmentation>,
234 ) -> Result<(), DecoderError> {
235 output_boxes.clear();
236 output_masks.clear();
237 match &self.model_type {
238 ModelType::ModelPackSegDet {
239 boxes,
240 scores,
241 segmentation,
242 } => {
243 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
244 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
245 }
246 ModelType::ModelPackSegDetSplit {
247 detection,
248 segmentation,
249 } => {
250 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
251 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
252 }
253 ModelType::ModelPackDet { boxes, scores } => {
254 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)
255 }
256 ModelType::ModelPackDetSplit { detection } => {
257 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)
258 }
259 ModelType::ModelPackSeg { segmentation } => {
260 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
261 }
262 ModelType::YoloDet { boxes } => {
263 self.decode_yolo_det_quantized(outputs, boxes, output_boxes)
264 }
265 ModelType::YoloSegDet { boxes, protos } => self.decode_yolo_segdet_quantized(
266 outputs,
267 boxes,
268 protos,
269 output_boxes,
270 output_masks,
271 ),
272 ModelType::YoloSplitDet { boxes, scores } => {
273 self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)
274 }
275 ModelType::YoloSplitSegDet {
276 boxes,
277 scores,
278 mask_coeff,
279 protos,
280 } => self.decode_yolo_split_segdet_quantized(
281 outputs,
282 boxes,
283 scores,
284 mask_coeff,
285 protos,
286 output_boxes,
287 output_masks,
288 ),
289 ModelType::YoloSegDet2Way {
290 boxes,
291 mask_coeff,
292 protos,
293 } => self.decode_yolo_segdet_2way_quantized(
294 outputs,
295 boxes,
296 mask_coeff,
297 protos,
298 output_boxes,
299 output_masks,
300 ),
301 ModelType::YoloEndToEndDet { boxes } => {
302 self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)
303 }
304 ModelType::YoloEndToEndSegDet { boxes, protos } => self
305 .decode_yolo_end_to_end_segdet_quantized(
306 outputs,
307 boxes,
308 protos,
309 output_boxes,
310 output_masks,
311 ),
312 ModelType::YoloSplitEndToEndDet {
313 boxes,
314 scores,
315 classes,
316 } => self.decode_yolo_split_end_to_end_det_quantized(
317 outputs,
318 boxes,
319 scores,
320 classes,
321 output_boxes,
322 ),
323 ModelType::YoloSplitEndToEndSegDet {
324 boxes,
325 scores,
326 classes,
327 mask_coeff,
328 protos,
329 } => self.decode_yolo_split_end_to_end_segdet_quantized(
330 outputs,
331 boxes,
332 scores,
333 classes,
334 mask_coeff,
335 protos,
336 output_boxes,
337 output_masks,
338 ),
339 }
340 }
341
342 pub(crate) fn decode_float<T>(
346 &self,
347 outputs: &[ArrayViewD<T>],
348 output_boxes: &mut Vec<DetectBox>,
349 output_masks: &mut Vec<Segmentation>,
350 ) -> Result<(), DecoderError>
351 where
352 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
353 f32: AsPrimitive<T>,
354 {
355 output_boxes.clear();
356 output_masks.clear();
357 match &self.model_type {
358 ModelType::ModelPackSegDet {
359 boxes,
360 scores,
361 segmentation,
362 } => {
363 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
364 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
365 }
366 ModelType::ModelPackSegDetSplit {
367 detection,
368 segmentation,
369 } => {
370 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
371 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
372 }
373 ModelType::ModelPackDet { boxes, scores } => {
374 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
375 }
376 ModelType::ModelPackDetSplit { detection } => {
377 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
378 }
379 ModelType::ModelPackSeg { segmentation } => {
380 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
381 }
382 ModelType::YoloDet { boxes } => {
383 self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
384 }
385 ModelType::YoloSegDet { boxes, protos } => {
386 self.decode_yolo_segdet_float(outputs, boxes, protos, output_boxes, output_masks)?;
387 }
388 ModelType::YoloSplitDet { boxes, scores } => {
389 self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
390 }
391 ModelType::YoloSplitSegDet {
392 boxes,
393 scores,
394 mask_coeff,
395 protos,
396 } => {
397 self.decode_yolo_split_segdet_float(
398 outputs,
399 boxes,
400 scores,
401 mask_coeff,
402 protos,
403 output_boxes,
404 output_masks,
405 )?;
406 }
407 ModelType::YoloSegDet2Way {
408 boxes,
409 mask_coeff,
410 protos,
411 } => {
412 self.decode_yolo_segdet_2way_float(
413 outputs,
414 boxes,
415 mask_coeff,
416 protos,
417 output_boxes,
418 output_masks,
419 )?;
420 }
421 ModelType::YoloEndToEndDet { boxes } => {
422 self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
423 }
424 ModelType::YoloEndToEndSegDet { boxes, protos } => {
425 self.decode_yolo_end_to_end_segdet_float(
426 outputs,
427 boxes,
428 protos,
429 output_boxes,
430 output_masks,
431 )?;
432 }
433 ModelType::YoloSplitEndToEndDet {
434 boxes,
435 scores,
436 classes,
437 } => {
438 self.decode_yolo_split_end_to_end_det_float(
439 outputs,
440 boxes,
441 scores,
442 classes,
443 output_boxes,
444 )?;
445 }
446 ModelType::YoloSplitEndToEndSegDet {
447 boxes,
448 scores,
449 classes,
450 mask_coeff,
451 protos,
452 } => {
453 self.decode_yolo_split_end_to_end_segdet_float(
454 outputs,
455 boxes,
456 scores,
457 classes,
458 mask_coeff,
459 protos,
460 output_boxes,
461 output_masks,
462 )?;
463 }
464 }
465 Ok(())
466 }
467
468 pub(crate) fn decode_quantized_proto(
475 &self,
476 outputs: &[ArrayViewDQuantized],
477 output_boxes: &mut Vec<DetectBox>,
478 ) -> Result<Option<ProtoData>, DecoderError> {
479 output_boxes.clear();
480 match &self.model_type {
481 ModelType::ModelPackDet { boxes, scores } => {
483 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
484 Ok(None)
485 }
486 ModelType::ModelPackDetSplit { detection } => {
487 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
488 Ok(None)
489 }
490 ModelType::YoloDet { boxes } => {
491 self.decode_yolo_det_quantized(outputs, boxes, output_boxes)?;
492 Ok(None)
493 }
494 ModelType::YoloSplitDet { boxes, scores } => {
495 self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)?;
496 Ok(None)
497 }
498 ModelType::YoloEndToEndDet { boxes } => {
499 self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)?;
500 Ok(None)
501 }
502 ModelType::YoloSplitEndToEndDet {
503 boxes,
504 scores,
505 classes,
506 } => {
507 self.decode_yolo_split_end_to_end_det_quantized(
508 outputs,
509 boxes,
510 scores,
511 classes,
512 output_boxes,
513 )?;
514 Ok(None)
515 }
516 ModelType::ModelPackSegDet { boxes, scores, .. } => {
518 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
519 Ok(None)
520 }
521 ModelType::ModelPackSegDetSplit { detection, .. } => {
522 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
523 Ok(None)
524 }
525 ModelType::ModelPackSeg { .. } => Ok(None),
526
527 ModelType::YoloSegDet { boxes, protos } => {
528 let proto =
529 self.decode_yolo_segdet_quantized_proto(outputs, boxes, protos, output_boxes)?;
530 Ok(Some(proto))
531 }
532 ModelType::YoloSplitSegDet {
533 boxes,
534 scores,
535 mask_coeff,
536 protos,
537 } => {
538 let proto = self.decode_yolo_split_segdet_quantized_proto(
539 outputs,
540 boxes,
541 scores,
542 mask_coeff,
543 protos,
544 output_boxes,
545 )?;
546 Ok(Some(proto))
547 }
548 ModelType::YoloSegDet2Way {
549 boxes,
550 mask_coeff,
551 protos,
552 } => {
553 let proto = self.decode_yolo_segdet_2way_quantized_proto(
554 outputs,
555 boxes,
556 mask_coeff,
557 protos,
558 output_boxes,
559 )?;
560 Ok(Some(proto))
561 }
562 ModelType::YoloEndToEndSegDet { boxes, protos } => {
563 let proto = self.decode_yolo_end_to_end_segdet_quantized_proto(
564 outputs,
565 boxes,
566 protos,
567 output_boxes,
568 )?;
569 Ok(Some(proto))
570 }
571 ModelType::YoloSplitEndToEndSegDet {
572 boxes,
573 scores,
574 classes,
575 mask_coeff,
576 protos,
577 } => {
578 let proto = self.decode_yolo_split_end_to_end_segdet_quantized_proto(
579 outputs,
580 boxes,
581 scores,
582 classes,
583 mask_coeff,
584 protos,
585 output_boxes,
586 )?;
587 Ok(Some(proto))
588 }
589 }
590 }
591
592 pub(crate) fn decode_float_proto<T>(
599 &self,
600 outputs: &[ArrayViewD<T>],
601 output_boxes: &mut Vec<DetectBox>,
602 ) -> Result<Option<ProtoData>, DecoderError>
603 where
604 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + crate::yolo::FloatProtoElem,
605 f32: AsPrimitive<T>,
606 {
607 output_boxes.clear();
608 match &self.model_type {
609 ModelType::ModelPackDet { boxes, scores } => {
611 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
612 Ok(None)
613 }
614 ModelType::ModelPackDetSplit { detection } => {
615 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
616 Ok(None)
617 }
618 ModelType::YoloDet { boxes } => {
619 self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
620 Ok(None)
621 }
622 ModelType::YoloSplitDet { boxes, scores } => {
623 self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
624 Ok(None)
625 }
626 ModelType::YoloEndToEndDet { boxes } => {
627 self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
628 Ok(None)
629 }
630 ModelType::YoloSplitEndToEndDet {
631 boxes,
632 scores,
633 classes,
634 } => {
635 self.decode_yolo_split_end_to_end_det_float(
636 outputs,
637 boxes,
638 scores,
639 classes,
640 output_boxes,
641 )?;
642 Ok(None)
643 }
644 ModelType::ModelPackSegDet { boxes, scores, .. } => {
646 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
647 Ok(None)
648 }
649 ModelType::ModelPackSegDetSplit { detection, .. } => {
650 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
651 Ok(None)
652 }
653 ModelType::ModelPackSeg { .. } => Ok(None),
654
655 ModelType::YoloSegDet { boxes, protos } => {
656 let proto =
657 self.decode_yolo_segdet_float_proto(outputs, boxes, protos, output_boxes)?;
658 Ok(Some(proto))
659 }
660 ModelType::YoloSplitSegDet {
661 boxes,
662 scores,
663 mask_coeff,
664 protos,
665 } => {
666 let proto = self.decode_yolo_split_segdet_float_proto(
667 outputs,
668 boxes,
669 scores,
670 mask_coeff,
671 protos,
672 output_boxes,
673 )?;
674 Ok(Some(proto))
675 }
676 ModelType::YoloSegDet2Way {
677 boxes,
678 mask_coeff,
679 protos,
680 } => {
681 let proto = self.decode_yolo_segdet_2way_float_proto(
682 outputs,
683 boxes,
684 mask_coeff,
685 protos,
686 output_boxes,
687 )?;
688 Ok(Some(proto))
689 }
690 ModelType::YoloEndToEndSegDet { boxes, protos } => {
691 let proto = self.decode_yolo_end_to_end_segdet_float_proto(
692 outputs,
693 boxes,
694 protos,
695 output_boxes,
696 )?;
697 Ok(Some(proto))
698 }
699 ModelType::YoloSplitEndToEndSegDet {
700 boxes,
701 scores,
702 classes,
703 mask_coeff,
704 protos,
705 } => {
706 let proto = self.decode_yolo_split_end_to_end_segdet_float_proto(
707 outputs,
708 boxes,
709 scores,
710 classes,
711 mask_coeff,
712 protos,
713 output_boxes,
714 )?;
715 Ok(Some(proto))
716 }
717 }
718 }
719
720 pub fn decode(
741 &self,
742 outputs: &[&edgefirst_tensor::TensorDyn],
743 output_boxes: &mut Vec<DetectBox>,
744 output_masks: &mut Vec<Segmentation>,
745 ) -> Result<(), DecoderError> {
746 if let Some(program) = &self.decode_program {
749 let merged = program.execute(outputs)?;
750 let views: Vec<_> = merged.iter().map(|a| a.view()).collect();
751 return self.decode_float(&views, output_boxes, output_masks);
752 }
753
754 let mapped = tensor_bridge::map_tensors(outputs)?;
755 match &mapped {
756 tensor_bridge::MappedOutputs::Quantized(maps) => {
757 let views = tensor_bridge::quantized_views(maps)?;
758 self.decode_quantized(&views, output_boxes, output_masks)
759 }
760 tensor_bridge::MappedOutputs::Float16(maps) => {
761 let views = tensor_bridge::f16_views(maps)?;
762 self.decode_float(&views, output_boxes, output_masks)
763 }
764 tensor_bridge::MappedOutputs::Float32(maps) => {
765 let views = tensor_bridge::f32_views(maps)?;
766 self.decode_float(&views, output_boxes, output_masks)
767 }
768 tensor_bridge::MappedOutputs::Float64(maps) => {
769 let views = tensor_bridge::f64_views(maps)?;
770 self.decode_float(&views, output_boxes, output_masks)
771 }
772 }
773 }
774
775 pub fn decode_proto(
793 &self,
794 outputs: &[&edgefirst_tensor::TensorDyn],
795 output_boxes: &mut Vec<DetectBox>,
796 ) -> Result<Option<ProtoData>, DecoderError> {
797 if let Some(program) = &self.decode_program {
800 let merged = program.execute(outputs)?;
801 let views: Vec<_> = merged.iter().map(|a| a.view()).collect();
802 return self.decode_float_proto(&views, output_boxes);
803 }
804
805 let mapped = tensor_bridge::map_tensors(outputs)?;
806 match &mapped {
807 tensor_bridge::MappedOutputs::Quantized(maps) => {
808 let views = tensor_bridge::quantized_views(maps)?;
809 self.decode_quantized_proto(&views, output_boxes)
810 }
811 tensor_bridge::MappedOutputs::Float16(maps) => {
812 let views = tensor_bridge::f16_views(maps)?;
813 self.decode_float_proto(&views, output_boxes)
814 }
815 tensor_bridge::MappedOutputs::Float32(maps) => {
816 let views = tensor_bridge::f32_views(maps)?;
817 self.decode_float_proto(&views, output_boxes)
818 }
819 tensor_bridge::MappedOutputs::Float64(maps) => {
820 let views = tensor_bridge::f64_views(maps)?;
821 self.decode_float_proto(&views, output_boxes)
822 }
823 }
824 }
825}
826
827#[cfg(feature = "tracker")]
828pub use edgefirst_tracker::TrackInfo;
829
830#[cfg(feature = "tracker")]
831pub use edgefirst_tracker::Tracker;
832
833#[cfg(feature = "tracker")]
834impl Decoder {
835 pub(crate) fn decode_tracked_quantized<TR: edgefirst_tracker::Tracker<DetectBox>>(
839 &self,
840 tracker: &mut TR,
841 timestamp: u64,
842 outputs: &[ArrayViewDQuantized],
843 output_boxes: &mut Vec<DetectBox>,
844 output_masks: &mut Vec<Segmentation>,
845 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
846 ) -> Result<(), DecoderError> {
847 output_boxes.clear();
848 output_masks.clear();
849 output_tracks.clear();
850
851 match &self.model_type {
854 ModelType::YoloSegDet { boxes, protos } => self.decode_tracked_yolo_segdet_quantized(
855 tracker,
856 timestamp,
857 outputs,
858 boxes,
859 protos,
860 output_boxes,
861 output_masks,
862 output_tracks,
863 ),
864 ModelType::YoloSplitSegDet {
865 boxes,
866 scores,
867 mask_coeff,
868 protos,
869 } => self.decode_tracked_yolo_split_segdet_quantized(
870 tracker,
871 timestamp,
872 outputs,
873 boxes,
874 scores,
875 mask_coeff,
876 protos,
877 output_boxes,
878 output_masks,
879 output_tracks,
880 ),
881 ModelType::YoloEndToEndSegDet { boxes, protos } => self
882 .decode_tracked_yolo_end_to_end_segdet_quantized(
883 tracker,
884 timestamp,
885 outputs,
886 boxes,
887 protos,
888 output_boxes,
889 output_masks,
890 output_tracks,
891 ),
892 ModelType::YoloSplitEndToEndSegDet {
893 boxes,
894 scores,
895 classes,
896 mask_coeff,
897 protos,
898 } => self.decode_tracked_yolo_split_end_to_end_segdet_quantized(
899 tracker,
900 timestamp,
901 outputs,
902 boxes,
903 scores,
904 classes,
905 mask_coeff,
906 protos,
907 output_boxes,
908 output_masks,
909 output_tracks,
910 ),
911 ModelType::YoloSegDet2Way {
912 boxes,
913 mask_coeff,
914 protos,
915 } => self.decode_tracked_yolo_segdet_2way_quantized(
916 tracker,
917 timestamp,
918 outputs,
919 boxes,
920 mask_coeff,
921 protos,
922 output_boxes,
923 output_masks,
924 output_tracks,
925 ),
926 _ => {
927 self.decode_quantized(outputs, output_boxes, output_masks)?;
928 Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
929 Ok(())
930 }
931 }
932 }
933
934 pub(crate) fn decode_tracked_float<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
944 &self,
945 tracker: &mut TR,
946 timestamp: u64,
947 outputs: &[ArrayViewD<T>],
948 output_boxes: &mut Vec<DetectBox>,
949 output_masks: &mut Vec<Segmentation>,
950 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
951 ) -> Result<(), DecoderError>
952 where
953 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
954 f32: AsPrimitive<T>,
955 {
956 output_boxes.clear();
957 output_masks.clear();
958 output_tracks.clear();
959 match &self.model_type {
960 ModelType::YoloSegDet { boxes, protos } => {
961 self.decode_tracked_yolo_segdet_float(
962 tracker,
963 timestamp,
964 outputs,
965 boxes,
966 protos,
967 output_boxes,
968 output_masks,
969 output_tracks,
970 )?;
971 }
972 ModelType::YoloSplitSegDet {
973 boxes,
974 scores,
975 mask_coeff,
976 protos,
977 } => {
978 self.decode_tracked_yolo_split_segdet_float(
979 tracker,
980 timestamp,
981 outputs,
982 boxes,
983 scores,
984 mask_coeff,
985 protos,
986 output_boxes,
987 output_masks,
988 output_tracks,
989 )?;
990 }
991 ModelType::YoloEndToEndSegDet { boxes, protos } => {
992 self.decode_tracked_yolo_end_to_end_segdet_float(
993 tracker,
994 timestamp,
995 outputs,
996 boxes,
997 protos,
998 output_boxes,
999 output_masks,
1000 output_tracks,
1001 )?;
1002 }
1003 ModelType::YoloSplitEndToEndSegDet {
1004 boxes,
1005 scores,
1006 classes,
1007 mask_coeff,
1008 protos,
1009 } => {
1010 self.decode_tracked_yolo_split_end_to_end_segdet_float(
1011 tracker,
1012 timestamp,
1013 outputs,
1014 boxes,
1015 scores,
1016 classes,
1017 mask_coeff,
1018 protos,
1019 output_boxes,
1020 output_masks,
1021 output_tracks,
1022 )?;
1023 }
1024 ModelType::YoloSegDet2Way {
1025 boxes,
1026 mask_coeff,
1027 protos,
1028 } => {
1029 self.decode_tracked_yolo_segdet_2way_float(
1030 tracker,
1031 timestamp,
1032 outputs,
1033 boxes,
1034 mask_coeff,
1035 protos,
1036 output_boxes,
1037 output_masks,
1038 output_tracks,
1039 )?;
1040 }
1041 _ => {
1042 self.decode_float(outputs, output_boxes, output_masks)?;
1043 Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1044 }
1045 }
1046 Ok(())
1047 }
1048
1049 pub(crate) fn decode_tracked_quantized_proto<TR: edgefirst_tracker::Tracker<DetectBox>>(
1056 &self,
1057 tracker: &mut TR,
1058 timestamp: u64,
1059 outputs: &[ArrayViewDQuantized],
1060 output_boxes: &mut Vec<DetectBox>,
1061 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1062 ) -> Result<Option<ProtoData>, DecoderError> {
1063 output_boxes.clear();
1064 output_tracks.clear();
1065 match &self.model_type {
1066 ModelType::YoloSegDet { boxes, protos } => {
1067 let proto = self.decode_tracked_yolo_segdet_quantized_proto(
1068 tracker,
1069 timestamp,
1070 outputs,
1071 boxes,
1072 protos,
1073 output_boxes,
1074 output_tracks,
1075 )?;
1076 Ok(Some(proto))
1077 }
1078 ModelType::YoloSplitSegDet {
1079 boxes,
1080 scores,
1081 mask_coeff,
1082 protos,
1083 } => {
1084 let proto = self.decode_tracked_yolo_split_segdet_quantized_proto(
1085 tracker,
1086 timestamp,
1087 outputs,
1088 boxes,
1089 scores,
1090 mask_coeff,
1091 protos,
1092 output_boxes,
1093 output_tracks,
1094 )?;
1095 Ok(Some(proto))
1096 }
1097 ModelType::YoloSegDet2Way {
1098 boxes,
1099 mask_coeff,
1100 protos,
1101 } => {
1102 let proto = self.decode_tracked_yolo_segdet_2way_quantized_proto(
1103 tracker,
1104 timestamp,
1105 outputs,
1106 boxes,
1107 mask_coeff,
1108 protos,
1109 output_boxes,
1110 output_tracks,
1111 )?;
1112 Ok(Some(proto))
1113 }
1114 ModelType::YoloEndToEndSegDet { boxes, protos } => {
1115 let proto = self.decode_tracked_yolo_end_to_end_segdet_quantized_proto(
1116 tracker,
1117 timestamp,
1118 outputs,
1119 boxes,
1120 protos,
1121 output_boxes,
1122 output_tracks,
1123 )?;
1124 Ok(Some(proto))
1125 }
1126 ModelType::YoloSplitEndToEndSegDet {
1127 boxes,
1128 scores,
1129 classes,
1130 mask_coeff,
1131 protos,
1132 } => {
1133 let proto = self.decode_tracked_yolo_split_end_to_end_segdet_quantized_proto(
1134 tracker,
1135 timestamp,
1136 outputs,
1137 boxes,
1138 scores,
1139 classes,
1140 mask_coeff,
1141 protos,
1142 output_boxes,
1143 output_tracks,
1144 )?;
1145 Ok(Some(proto))
1146 }
1147 _ => {
1149 let mut masks = Vec::new();
1150 self.decode_quantized(outputs, output_boxes, &mut masks)?;
1151 Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1152 Ok(None)
1153 }
1154 }
1155 }
1156
1157 pub(crate) fn decode_tracked_float_proto<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
1164 &self,
1165 tracker: &mut TR,
1166 timestamp: u64,
1167 outputs: &[ArrayViewD<T>],
1168 output_boxes: &mut Vec<DetectBox>,
1169 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1170 ) -> Result<Option<ProtoData>, DecoderError>
1171 where
1172 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + crate::yolo::FloatProtoElem,
1173 f32: AsPrimitive<T>,
1174 {
1175 output_boxes.clear();
1176 output_tracks.clear();
1177 match &self.model_type {
1178 ModelType::YoloSegDet { boxes, protos } => {
1179 let proto = self.decode_tracked_yolo_segdet_float_proto(
1180 tracker,
1181 timestamp,
1182 outputs,
1183 boxes,
1184 protos,
1185 output_boxes,
1186 output_tracks,
1187 )?;
1188 Ok(Some(proto))
1189 }
1190 ModelType::YoloSplitSegDet {
1191 boxes,
1192 scores,
1193 mask_coeff,
1194 protos,
1195 } => {
1196 let proto = self.decode_tracked_yolo_split_segdet_float_proto(
1197 tracker,
1198 timestamp,
1199 outputs,
1200 boxes,
1201 scores,
1202 mask_coeff,
1203 protos,
1204 output_boxes,
1205 output_tracks,
1206 )?;
1207 Ok(Some(proto))
1208 }
1209 ModelType::YoloSegDet2Way {
1210 boxes,
1211 mask_coeff,
1212 protos,
1213 } => {
1214 let proto = self.decode_tracked_yolo_segdet_2way_float_proto(
1215 tracker,
1216 timestamp,
1217 outputs,
1218 boxes,
1219 mask_coeff,
1220 protos,
1221 output_boxes,
1222 output_tracks,
1223 )?;
1224 Ok(Some(proto))
1225 }
1226 ModelType::YoloEndToEndSegDet { boxes, protos } => {
1227 let proto = self.decode_tracked_yolo_end_to_end_segdet_float_proto(
1228 tracker,
1229 timestamp,
1230 outputs,
1231 boxes,
1232 protos,
1233 output_boxes,
1234 output_tracks,
1235 )?;
1236 Ok(Some(proto))
1237 }
1238 ModelType::YoloSplitEndToEndSegDet {
1239 boxes,
1240 scores,
1241 classes,
1242 mask_coeff,
1243 protos,
1244 } => {
1245 let proto = self.decode_tracked_yolo_split_end_to_end_segdet_float_proto(
1246 tracker,
1247 timestamp,
1248 outputs,
1249 boxes,
1250 scores,
1251 classes,
1252 mask_coeff,
1253 protos,
1254 output_boxes,
1255 output_tracks,
1256 )?;
1257 Ok(Some(proto))
1258 }
1259 _ => {
1261 let mut masks = Vec::new();
1262 self.decode_float(outputs, output_boxes, &mut masks)?;
1263 Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1264 Ok(None)
1265 }
1266 }
1267 }
1268
1269 pub fn decode_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
1293 &self,
1294 tracker: &mut TR,
1295 timestamp: u64,
1296 outputs: &[&edgefirst_tensor::TensorDyn],
1297 output_boxes: &mut Vec<DetectBox>,
1298 output_masks: &mut Vec<Segmentation>,
1299 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1300 ) -> Result<(), DecoderError> {
1301 let mapped = tensor_bridge::map_tensors(outputs)?;
1302 match &mapped {
1303 tensor_bridge::MappedOutputs::Quantized(maps) => {
1304 let views = tensor_bridge::quantized_views(maps)?;
1305 self.decode_tracked_quantized(
1306 tracker,
1307 timestamp,
1308 &views,
1309 output_boxes,
1310 output_masks,
1311 output_tracks,
1312 )
1313 }
1314 tensor_bridge::MappedOutputs::Float16(maps) => {
1315 let views = tensor_bridge::f16_views(maps)?;
1316 self.decode_tracked_float(
1317 tracker,
1318 timestamp,
1319 &views,
1320 output_boxes,
1321 output_masks,
1322 output_tracks,
1323 )
1324 }
1325 tensor_bridge::MappedOutputs::Float32(maps) => {
1326 let views = tensor_bridge::f32_views(maps)?;
1327 self.decode_tracked_float(
1328 tracker,
1329 timestamp,
1330 &views,
1331 output_boxes,
1332 output_masks,
1333 output_tracks,
1334 )
1335 }
1336 tensor_bridge::MappedOutputs::Float64(maps) => {
1337 let views = tensor_bridge::f64_views(maps)?;
1338 self.decode_tracked_float(
1339 tracker,
1340 timestamp,
1341 &views,
1342 output_boxes,
1343 output_masks,
1344 output_tracks,
1345 )
1346 }
1347 }
1348 }
1349
1350 pub fn decode_proto_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
1370 &self,
1371 tracker: &mut TR,
1372 timestamp: u64,
1373 outputs: &[&edgefirst_tensor::TensorDyn],
1374 output_boxes: &mut Vec<DetectBox>,
1375 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1376 ) -> Result<Option<ProtoData>, DecoderError> {
1377 let mapped = tensor_bridge::map_tensors(outputs)?;
1378 match &mapped {
1379 tensor_bridge::MappedOutputs::Quantized(maps) => {
1380 let views = tensor_bridge::quantized_views(maps)?;
1381 self.decode_tracked_quantized_proto(
1382 tracker,
1383 timestamp,
1384 &views,
1385 output_boxes,
1386 output_tracks,
1387 )
1388 }
1389 tensor_bridge::MappedOutputs::Float16(maps) => {
1390 let views = tensor_bridge::f16_views(maps)?;
1391 self.decode_tracked_float_proto(
1392 tracker,
1393 timestamp,
1394 &views,
1395 output_boxes,
1396 output_tracks,
1397 )
1398 }
1399 tensor_bridge::MappedOutputs::Float32(maps) => {
1400 let views = tensor_bridge::f32_views(maps)?;
1401 self.decode_tracked_float_proto(
1402 tracker,
1403 timestamp,
1404 &views,
1405 output_boxes,
1406 output_tracks,
1407 )
1408 }
1409 tensor_bridge::MappedOutputs::Float64(maps) => {
1410 let views = tensor_bridge::f64_views(maps)?;
1411 self.decode_tracked_float_proto(
1412 tracker,
1413 timestamp,
1414 &views,
1415 output_boxes,
1416 output_tracks,
1417 )
1418 }
1419 }
1420 }
1421}