1use ndarray::{ArrayView, ArrayViewD, Dimension};
5use num_traits::{AsPrimitive, Float};
6
7use crate::{DecoderError, DetectBox, ProtoData, Segmentation};
8
9pub mod config;
10pub mod configs;
11
12use configs::ModelType;
13
/// Post-processing decoder for model output tensors.
///
/// The output layout is described by [`ModelType`]; every `decode*` method
/// dispatches on it to a layout-specific routine.
#[derive(Debug, Clone)]
pub struct Decoder {
    /// Output layout of the model; selects which decode routine runs.
    model_type: ModelType,
    /// IoU threshold — presumably used for box suppression in the
    /// post-processing submodules (TODO confirm).
    pub iou_threshold: f32,
    /// Score threshold — presumably candidates scoring below it are dropped
    /// (TODO confirm in `postprocess`).
    pub score_threshold: f32,
    /// NMS configuration; `None` presumably disables or defaults NMS
    /// (TODO confirm).
    pub nms: Option<configs::Nms>,
    /// Presumably the number of top-scoring candidates kept before NMS
    /// (TODO confirm).
    pub pre_nms_top_k: usize,
    /// Presumably the maximum number of detections emitted (TODO confirm).
    pub max_det: usize,
    /// Box-normalization flag, exposed through `normalized_boxes()`.
    normalized: Option<bool>,
    /// Optional merge program. When set, `decode`/`decode_proto` execute it on
    /// the raw tensors and decode the merged arrays through the float path.
    pub(crate) decode_program: Option<merge::DecodeProgram>,
}
49
50impl PartialEq for Decoder {
51 fn eq(&self, other: &Self) -> bool {
52 self.model_type == other.model_type
55 && self.iou_threshold == other.iou_threshold
56 && self.score_threshold == other.score_threshold
57 && self.nms == other.nms
58 && self.pre_nms_top_k == other.pre_nms_top_k
59 && self.max_det == other.max_det
60 && self.normalized == other.normalized
61 && self.decode_program.is_some() == other.decode_program.is_some()
62 }
63}
64
/// Dynamic-rank view over an integer (quantized) tensor of any supported
/// element type.
///
/// Each variant wraps an `ArrayViewD` of the named integer type; the
/// `with_quantized!` macro dispatches generic code over all variants.
#[derive(Debug)]
pub(crate) enum ArrayViewDQuantized<'a> {
    /// Unsigned 8-bit elements.
    UInt8(ArrayViewD<'a, u8>),
    /// Signed 8-bit elements.
    Int8(ArrayViewD<'a, i8>),
    /// Unsigned 16-bit elements.
    UInt16(ArrayViewD<'a, u16>),
    /// Signed 16-bit elements.
    Int16(ArrayViewD<'a, i16>),
    /// Unsigned 32-bit elements.
    UInt32(ArrayViewD<'a, u32>),
    /// Signed 32-bit elements.
    Int32(ArrayViewD<'a, i32>),
}
74
75impl<'a, D> From<ArrayView<'a, u8, D>> for ArrayViewDQuantized<'a>
76where
77 D: Dimension,
78{
79 fn from(arr: ArrayView<'a, u8, D>) -> Self {
80 Self::UInt8(arr.into_dyn())
81 }
82}
83
84impl<'a, D> From<ArrayView<'a, i8, D>> for ArrayViewDQuantized<'a>
85where
86 D: Dimension,
87{
88 fn from(arr: ArrayView<'a, i8, D>) -> Self {
89 Self::Int8(arr.into_dyn())
90 }
91}
92
93impl<'a, D> From<ArrayView<'a, u16, D>> for ArrayViewDQuantized<'a>
94where
95 D: Dimension,
96{
97 fn from(arr: ArrayView<'a, u16, D>) -> Self {
98 Self::UInt16(arr.into_dyn())
99 }
100}
101
102impl<'a, D> From<ArrayView<'a, i16, D>> for ArrayViewDQuantized<'a>
103where
104 D: Dimension,
105{
106 fn from(arr: ArrayView<'a, i16, D>) -> Self {
107 Self::Int16(arr.into_dyn())
108 }
109}
110
111impl<'a, D> From<ArrayView<'a, u32, D>> for ArrayViewDQuantized<'a>
112where
113 D: Dimension,
114{
115 fn from(arr: ArrayView<'a, u32, D>) -> Self {
116 Self::UInt32(arr.into_dyn())
117 }
118}
119
120impl<'a, D> From<ArrayView<'a, i32, D>> for ArrayViewDQuantized<'a>
121where
122 D: Dimension,
123{
124 fn from(arr: ArrayView<'a, i32, D>) -> Self {
125 Self::Int32(arr.into_dyn())
126 }
127}
128
129impl<'a> ArrayViewDQuantized<'a> {
130 pub(crate) fn shape(&self) -> &[usize] {
132 match self {
133 ArrayViewDQuantized::UInt8(a) => a.shape(),
134 ArrayViewDQuantized::Int8(a) => a.shape(),
135 ArrayViewDQuantized::UInt16(a) => a.shape(),
136 ArrayViewDQuantized::Int16(a) => a.shape(),
137 ArrayViewDQuantized::UInt32(a) => a.shape(),
138 ArrayViewDQuantized::Int32(a) => a.shape(),
139 }
140 }
141}
142
/// Evaluates `$body` for whichever variant of `ArrayViewDQuantized` `$x` holds.
///
/// The variant's payload is bound to the caller-chosen identifier `$var`, so
/// `$body` is written once and instantiated per element type. Every
/// instantiation must produce the same result type.
macro_rules! with_quantized {
    ($x:expr, $var:ident, $body:expr) => {
        match $x {
            ArrayViewDQuantized::UInt8($var) => $body,
            ArrayViewDQuantized::Int8($var) => $body,
            ArrayViewDQuantized::UInt16($var) => $body,
            ArrayViewDQuantized::Int16($var) => $body,
            ArrayViewDQuantized::UInt32($var) => $body,
            ArrayViewDQuantized::Int32($var) => $body,
        }
    };
}
179
180mod builder;
181mod dfl;
182mod helpers;
183mod merge;
184mod postprocess;
185mod tensor_bridge;
186mod tests;
187
188pub use builder::DecoderBuilder;
189pub use config::{ConfigOutput, ConfigOutputRef, ConfigOutputs};
190
191impl Decoder {
192 pub fn model_type(&self) -> &ModelType {
211 &self.model_type
212 }
213
214 pub fn normalized_boxes(&self) -> Option<bool> {
240 self.normalized
241 }
242
243 pub(crate) fn decode_quantized(
247 &self,
248 outputs: &[ArrayViewDQuantized],
249 output_boxes: &mut Vec<DetectBox>,
250 output_masks: &mut Vec<Segmentation>,
251 ) -> Result<(), DecoderError> {
252 output_boxes.clear();
253 output_masks.clear();
254 match &self.model_type {
255 ModelType::ModelPackSegDet {
256 boxes,
257 scores,
258 segmentation,
259 } => {
260 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
261 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
262 }
263 ModelType::ModelPackSegDetSplit {
264 detection,
265 segmentation,
266 } => {
267 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
268 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
269 }
270 ModelType::ModelPackDet { boxes, scores } => {
271 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)
272 }
273 ModelType::ModelPackDetSplit { detection } => {
274 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)
275 }
276 ModelType::ModelPackSeg { segmentation } => {
277 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
278 }
279 ModelType::YoloDet { boxes } => {
280 self.decode_yolo_det_quantized(outputs, boxes, output_boxes)
281 }
282 ModelType::YoloSegDet { boxes, protos } => self.decode_yolo_segdet_quantized(
283 outputs,
284 boxes,
285 protos,
286 output_boxes,
287 output_masks,
288 ),
289 ModelType::YoloSplitDet { boxes, scores } => {
290 self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)
291 }
292 ModelType::YoloSplitSegDet {
293 boxes,
294 scores,
295 mask_coeff,
296 protos,
297 } => self.decode_yolo_split_segdet_quantized(
298 outputs,
299 boxes,
300 scores,
301 mask_coeff,
302 protos,
303 output_boxes,
304 output_masks,
305 ),
306 ModelType::YoloSegDet2Way {
307 boxes,
308 mask_coeff,
309 protos,
310 } => self.decode_yolo_segdet_2way_quantized(
311 outputs,
312 boxes,
313 mask_coeff,
314 protos,
315 output_boxes,
316 output_masks,
317 ),
318 ModelType::YoloEndToEndDet { boxes } => {
319 self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)
320 }
321 ModelType::YoloEndToEndSegDet { boxes, protos } => self
322 .decode_yolo_end_to_end_segdet_quantized(
323 outputs,
324 boxes,
325 protos,
326 output_boxes,
327 output_masks,
328 ),
329 ModelType::YoloSplitEndToEndDet {
330 boxes,
331 scores,
332 classes,
333 } => self.decode_yolo_split_end_to_end_det_quantized(
334 outputs,
335 boxes,
336 scores,
337 classes,
338 output_boxes,
339 ),
340 ModelType::YoloSplitEndToEndSegDet {
341 boxes,
342 scores,
343 classes,
344 mask_coeff,
345 protos,
346 } => self.decode_yolo_split_end_to_end_segdet_quantized(
347 outputs,
348 boxes,
349 scores,
350 classes,
351 mask_coeff,
352 protos,
353 output_boxes,
354 output_masks,
355 ),
356 }
357 }
358
359 pub(crate) fn decode_float<T>(
363 &self,
364 outputs: &[ArrayViewD<T>],
365 output_boxes: &mut Vec<DetectBox>,
366 output_masks: &mut Vec<Segmentation>,
367 ) -> Result<(), DecoderError>
368 where
369 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
370 f32: AsPrimitive<T>,
371 {
372 output_boxes.clear();
373 output_masks.clear();
374 match &self.model_type {
375 ModelType::ModelPackSegDet {
376 boxes,
377 scores,
378 segmentation,
379 } => {
380 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
381 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
382 }
383 ModelType::ModelPackSegDetSplit {
384 detection,
385 segmentation,
386 } => {
387 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
388 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
389 }
390 ModelType::ModelPackDet { boxes, scores } => {
391 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
392 }
393 ModelType::ModelPackDetSplit { detection } => {
394 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
395 }
396 ModelType::ModelPackSeg { segmentation } => {
397 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
398 }
399 ModelType::YoloDet { boxes } => {
400 self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
401 }
402 ModelType::YoloSegDet { boxes, protos } => {
403 self.decode_yolo_segdet_float(outputs, boxes, protos, output_boxes, output_masks)?;
404 }
405 ModelType::YoloSplitDet { boxes, scores } => {
406 self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
407 }
408 ModelType::YoloSplitSegDet {
409 boxes,
410 scores,
411 mask_coeff,
412 protos,
413 } => {
414 self.decode_yolo_split_segdet_float(
415 outputs,
416 boxes,
417 scores,
418 mask_coeff,
419 protos,
420 output_boxes,
421 output_masks,
422 )?;
423 }
424 ModelType::YoloSegDet2Way {
425 boxes,
426 mask_coeff,
427 protos,
428 } => {
429 self.decode_yolo_segdet_2way_float(
430 outputs,
431 boxes,
432 mask_coeff,
433 protos,
434 output_boxes,
435 output_masks,
436 )?;
437 }
438 ModelType::YoloEndToEndDet { boxes } => {
439 self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
440 }
441 ModelType::YoloEndToEndSegDet { boxes, protos } => {
442 self.decode_yolo_end_to_end_segdet_float(
443 outputs,
444 boxes,
445 protos,
446 output_boxes,
447 output_masks,
448 )?;
449 }
450 ModelType::YoloSplitEndToEndDet {
451 boxes,
452 scores,
453 classes,
454 } => {
455 self.decode_yolo_split_end_to_end_det_float(
456 outputs,
457 boxes,
458 scores,
459 classes,
460 output_boxes,
461 )?;
462 }
463 ModelType::YoloSplitEndToEndSegDet {
464 boxes,
465 scores,
466 classes,
467 mask_coeff,
468 protos,
469 } => {
470 self.decode_yolo_split_end_to_end_segdet_float(
471 outputs,
472 boxes,
473 scores,
474 classes,
475 mask_coeff,
476 protos,
477 output_boxes,
478 output_masks,
479 )?;
480 }
481 }
482 Ok(())
483 }
484
485 pub(crate) fn decode_quantized_proto(
492 &self,
493 outputs: &[ArrayViewDQuantized],
494 output_boxes: &mut Vec<DetectBox>,
495 ) -> Result<Option<ProtoData>, DecoderError> {
496 output_boxes.clear();
497 match &self.model_type {
498 ModelType::ModelPackDet { boxes, scores } => {
500 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
501 Ok(None)
502 }
503 ModelType::ModelPackDetSplit { detection } => {
504 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
505 Ok(None)
506 }
507 ModelType::YoloDet { boxes } => {
508 self.decode_yolo_det_quantized(outputs, boxes, output_boxes)?;
509 Ok(None)
510 }
511 ModelType::YoloSplitDet { boxes, scores } => {
512 self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)?;
513 Ok(None)
514 }
515 ModelType::YoloEndToEndDet { boxes } => {
516 self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)?;
517 Ok(None)
518 }
519 ModelType::YoloSplitEndToEndDet {
520 boxes,
521 scores,
522 classes,
523 } => {
524 self.decode_yolo_split_end_to_end_det_quantized(
525 outputs,
526 boxes,
527 scores,
528 classes,
529 output_boxes,
530 )?;
531 Ok(None)
532 }
533 ModelType::ModelPackSegDet { boxes, scores, .. } => {
535 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
536 Ok(None)
537 }
538 ModelType::ModelPackSegDetSplit { detection, .. } => {
539 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
540 Ok(None)
541 }
542 ModelType::ModelPackSeg { .. } => Ok(None),
543
544 ModelType::YoloSegDet { boxes, protos } => {
545 let proto =
546 self.decode_yolo_segdet_quantized_proto(outputs, boxes, protos, output_boxes)?;
547 Ok(Some(proto))
548 }
549 ModelType::YoloSplitSegDet {
550 boxes,
551 scores,
552 mask_coeff,
553 protos,
554 } => {
555 let proto = self.decode_yolo_split_segdet_quantized_proto(
556 outputs,
557 boxes,
558 scores,
559 mask_coeff,
560 protos,
561 output_boxes,
562 )?;
563 Ok(Some(proto))
564 }
565 ModelType::YoloSegDet2Way {
566 boxes,
567 mask_coeff,
568 protos,
569 } => {
570 let proto = self.decode_yolo_segdet_2way_quantized_proto(
571 outputs,
572 boxes,
573 mask_coeff,
574 protos,
575 output_boxes,
576 )?;
577 Ok(Some(proto))
578 }
579 ModelType::YoloEndToEndSegDet { boxes, protos } => {
580 let proto = self.decode_yolo_end_to_end_segdet_quantized_proto(
581 outputs,
582 boxes,
583 protos,
584 output_boxes,
585 )?;
586 Ok(Some(proto))
587 }
588 ModelType::YoloSplitEndToEndSegDet {
589 boxes,
590 scores,
591 classes,
592 mask_coeff,
593 protos,
594 } => {
595 let proto = self.decode_yolo_split_end_to_end_segdet_quantized_proto(
596 outputs,
597 boxes,
598 scores,
599 classes,
600 mask_coeff,
601 protos,
602 output_boxes,
603 )?;
604 Ok(Some(proto))
605 }
606 }
607 }
608
609 pub(crate) fn decode_float_proto<T>(
616 &self,
617 outputs: &[ArrayViewD<T>],
618 output_boxes: &mut Vec<DetectBox>,
619 ) -> Result<Option<ProtoData>, DecoderError>
620 where
621 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + crate::yolo::FloatProtoElem,
622 f32: AsPrimitive<T>,
623 {
624 output_boxes.clear();
625 match &self.model_type {
626 ModelType::ModelPackDet { boxes, scores } => {
628 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
629 Ok(None)
630 }
631 ModelType::ModelPackDetSplit { detection } => {
632 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
633 Ok(None)
634 }
635 ModelType::YoloDet { boxes } => {
636 self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
637 Ok(None)
638 }
639 ModelType::YoloSplitDet { boxes, scores } => {
640 self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
641 Ok(None)
642 }
643 ModelType::YoloEndToEndDet { boxes } => {
644 self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
645 Ok(None)
646 }
647 ModelType::YoloSplitEndToEndDet {
648 boxes,
649 scores,
650 classes,
651 } => {
652 self.decode_yolo_split_end_to_end_det_float(
653 outputs,
654 boxes,
655 scores,
656 classes,
657 output_boxes,
658 )?;
659 Ok(None)
660 }
661 ModelType::ModelPackSegDet { boxes, scores, .. } => {
663 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
664 Ok(None)
665 }
666 ModelType::ModelPackSegDetSplit { detection, .. } => {
667 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
668 Ok(None)
669 }
670 ModelType::ModelPackSeg { .. } => Ok(None),
671
672 ModelType::YoloSegDet { boxes, protos } => {
673 let proto =
674 self.decode_yolo_segdet_float_proto(outputs, boxes, protos, output_boxes)?;
675 Ok(Some(proto))
676 }
677 ModelType::YoloSplitSegDet {
678 boxes,
679 scores,
680 mask_coeff,
681 protos,
682 } => {
683 let proto = self.decode_yolo_split_segdet_float_proto(
684 outputs,
685 boxes,
686 scores,
687 mask_coeff,
688 protos,
689 output_boxes,
690 )?;
691 Ok(Some(proto))
692 }
693 ModelType::YoloSegDet2Way {
694 boxes,
695 mask_coeff,
696 protos,
697 } => {
698 let proto = self.decode_yolo_segdet_2way_float_proto(
699 outputs,
700 boxes,
701 mask_coeff,
702 protos,
703 output_boxes,
704 )?;
705 Ok(Some(proto))
706 }
707 ModelType::YoloEndToEndSegDet { boxes, protos } => {
708 let proto = self.decode_yolo_end_to_end_segdet_float_proto(
709 outputs,
710 boxes,
711 protos,
712 output_boxes,
713 )?;
714 Ok(Some(proto))
715 }
716 ModelType::YoloSplitEndToEndSegDet {
717 boxes,
718 scores,
719 classes,
720 mask_coeff,
721 protos,
722 } => {
723 let proto = self.decode_yolo_split_end_to_end_segdet_float_proto(
724 outputs,
725 boxes,
726 scores,
727 classes,
728 mask_coeff,
729 protos,
730 output_boxes,
731 )?;
732 Ok(Some(proto))
733 }
734 }
735 }
736
737 pub fn decode(
758 &self,
759 outputs: &[&edgefirst_tensor::TensorDyn],
760 output_boxes: &mut Vec<DetectBox>,
761 output_masks: &mut Vec<Segmentation>,
762 ) -> Result<(), DecoderError> {
763 if let Some(program) = &self.decode_program {
766 let merged = program.execute(outputs)?;
767 let views: Vec<_> = merged.iter().map(|a| a.view()).collect();
768 return self.decode_float(&views, output_boxes, output_masks);
769 }
770
771 let mapped = tensor_bridge::map_tensors(outputs)?;
772 match &mapped {
773 tensor_bridge::MappedOutputs::Quantized(maps) => {
774 let views = tensor_bridge::quantized_views(maps)?;
775 self.decode_quantized(&views, output_boxes, output_masks)
776 }
777 tensor_bridge::MappedOutputs::Float16(maps) => {
778 let views = tensor_bridge::f16_views(maps)?;
779 self.decode_float(&views, output_boxes, output_masks)
780 }
781 tensor_bridge::MappedOutputs::Float32(maps) => {
782 let views = tensor_bridge::f32_views(maps)?;
783 self.decode_float(&views, output_boxes, output_masks)
784 }
785 tensor_bridge::MappedOutputs::Float64(maps) => {
786 let views = tensor_bridge::f64_views(maps)?;
787 self.decode_float(&views, output_boxes, output_masks)
788 }
789 }
790 }
791
792 pub fn decode_proto(
810 &self,
811 outputs: &[&edgefirst_tensor::TensorDyn],
812 output_boxes: &mut Vec<DetectBox>,
813 ) -> Result<Option<ProtoData>, DecoderError> {
814 if let Some(program) = &self.decode_program {
817 let merged = program.execute(outputs)?;
818 let views: Vec<_> = merged.iter().map(|a| a.view()).collect();
819 return self.decode_float_proto(&views, output_boxes);
820 }
821
822 let mapped = tensor_bridge::map_tensors(outputs)?;
823 let result = match &mapped {
824 tensor_bridge::MappedOutputs::Quantized(maps) => {
825 let views = tensor_bridge::quantized_views(maps)?;
826 self.decode_quantized_proto(&views, output_boxes)
827 }
828 tensor_bridge::MappedOutputs::Float16(maps) => {
829 let views = tensor_bridge::f16_views(maps)?;
830 self.decode_float_proto(&views, output_boxes)
831 }
832 tensor_bridge::MappedOutputs::Float32(maps) => {
833 let views = tensor_bridge::f32_views(maps)?;
834 self.decode_float_proto(&views, output_boxes)
835 }
836 tensor_bridge::MappedOutputs::Float64(maps) => {
837 let views = tensor_bridge::f64_views(maps)?;
838 self.decode_float_proto(&views, output_boxes)
839 }
840 };
841 result
842 }
843}
844
845#[cfg(feature = "tracker")]
846pub use edgefirst_tracker::TrackInfo;
847
848#[cfg(feature = "tracker")]
849pub use edgefirst_tracker::Tracker;
850
#[cfg(feature = "tracker")]
impl Decoder {
    /// Tracked counterpart of `decode_quantized`.
    ///
    /// Clears `output_boxes`, `output_masks` and `output_tracks`. YOLO
    /// segmentation layouts use dedicated tracked routines that receive the
    /// tracker and timestamp directly. Every other layout is decoded with
    /// `decode_quantized`, and the resulting boxes are then passed to
    /// `Self::update_tracker`.
    ///
    /// # Errors
    /// Propagates any [`DecoderError`] from the selected routine.
    pub(crate) fn decode_tracked_quantized<TR: edgefirst_tracker::Tracker<DetectBox>>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewDQuantized],
        output_boxes: &mut Vec<DetectBox>,
        output_masks: &mut Vec<Segmentation>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<(), DecoderError> {
        output_boxes.clear();
        output_masks.clear();
        output_tracks.clear();

        match &self.model_type {
            ModelType::YoloSegDet { boxes, protos } => self.decode_tracked_yolo_segdet_quantized(
                tracker,
                timestamp,
                outputs,
                boxes,
                protos,
                output_boxes,
                output_masks,
                output_tracks,
            ),
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => self.decode_tracked_yolo_split_segdet_quantized(
                tracker,
                timestamp,
                outputs,
                boxes,
                scores,
                mask_coeff,
                protos,
                output_boxes,
                output_masks,
                output_tracks,
            ),
            ModelType::YoloEndToEndSegDet { boxes, protos } => self
                .decode_tracked_yolo_end_to_end_segdet_quantized(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                ),
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => self.decode_tracked_yolo_split_end_to_end_segdet_quantized(
                tracker,
                timestamp,
                outputs,
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
                output_boxes,
                output_masks,
                output_tracks,
            ),
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => self.decode_tracked_yolo_segdet_2way_quantized(
                tracker,
                timestamp,
                outputs,
                boxes,
                mask_coeff,
                protos,
                output_boxes,
                output_masks,
                output_tracks,
            ),
            _ => {
                self.decode_quantized(outputs, output_boxes, output_masks)?;
                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
                Ok(())
            }
        }
    }

    /// Tracked counterpart of `decode_float`.
    ///
    /// Clears all three output vectors. YOLO segmentation layouts use
    /// dedicated tracked routines. Other layouts are decoded with
    /// `decode_float`, and the resulting boxes are passed to
    /// `Self::update_tracker`.
    ///
    /// # Errors
    /// Propagates any [`DecoderError`] from the selected routine.
    pub(crate) fn decode_tracked_float<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewD<T>],
        output_boxes: &mut Vec<DetectBox>,
        output_masks: &mut Vec<Segmentation>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<(), DecoderError>
    where
        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
        f32: AsPrimitive<T>,
    {
        output_boxes.clear();
        output_masks.clear();
        output_tracks.clear();
        match &self.model_type {
            ModelType::YoloSegDet { boxes, protos } => {
                self.decode_tracked_yolo_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                self.decode_tracked_yolo_split_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                self.decode_tracked_yolo_end_to_end_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                self.decode_tracked_yolo_split_end_to_end_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => {
                self.decode_tracked_yolo_segdet_2way_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            _ => {
                self.decode_float(outputs, output_boxes, output_masks)?;
                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
            }
        }
        Ok(())
    }

    /// Tracked counterpart of `decode_quantized_proto`.
    ///
    /// YOLO segmentation layouts return `Some(ProtoData)` from dedicated
    /// tracked routines. Every other layout is decoded with
    /// `decode_quantized_proto` (detections only, no mask decoding), and the
    /// boxes are then passed to `Self::update_tracker`.
    ///
    /// # Errors
    /// Propagates any [`DecoderError`] from the selected routine.
    pub(crate) fn decode_tracked_quantized_proto<TR: edgefirst_tracker::Tracker<DetectBox>>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewDQuantized],
        output_boxes: &mut Vec<DetectBox>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<Option<ProtoData>, DecoderError> {
        output_boxes.clear();
        output_tracks.clear();
        match &self.model_type {
            ModelType::YoloSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_segdet_2way_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_end_to_end_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_end_to_end_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            _ => {
                // Reuse the untracked proto path: it decodes only detections.
                // Going through `decode_quantized` here would decode ModelPack
                // segmentation masks only to throw them away.
                let proto = self.decode_quantized_proto(outputs, output_boxes)?;
                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
                Ok(proto)
            }
        }
    }

    /// Tracked counterpart of `decode_float_proto`.
    ///
    /// YOLO segmentation layouts return `Some(ProtoData)` from dedicated
    /// tracked routines. Every other layout is decoded with
    /// `decode_float_proto` (detections only), and the boxes are then passed
    /// to `Self::update_tracker`.
    ///
    /// # Errors
    /// Propagates any [`DecoderError`] from the selected routine.
    pub(crate) fn decode_tracked_float_proto<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewD<T>],
        output_boxes: &mut Vec<DetectBox>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<Option<ProtoData>, DecoderError>
    where
        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + crate::yolo::FloatProtoElem,
        f32: AsPrimitive<T>,
    {
        output_boxes.clear();
        output_tracks.clear();
        match &self.model_type {
            ModelType::YoloSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_segdet_2way_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_end_to_end_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_end_to_end_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            _ => {
                // Detections only; avoids decoding masks that would be discarded.
                let proto = self.decode_float_proto(outputs, output_boxes)?;
                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
                Ok(proto)
            }
        }
    }

    /// Decodes raw output tensors and updates `tracker` with the results.
    ///
    /// The merge program is handled as in `decode`: if one is configured, it
    /// is executed first and the merged arrays go through the float path.
    /// Otherwise the tensors are mapped and dispatched by element type.
    ///
    /// # Errors
    /// Returns a [`DecoderError`] if merging, mapping or decoding fails.
    pub fn decode_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[&edgefirst_tensor::TensorDyn],
        output_boxes: &mut Vec<DetectBox>,
        output_masks: &mut Vec<Segmentation>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<(), DecoderError> {
        // Without this, a decoder that needs its merge program would receive
        // the raw, unmerged tensors on the tracked path.
        if let Some(program) = &self.decode_program {
            let merged = program.execute(outputs)?;
            let views: Vec<_> = merged.iter().map(|a| a.view()).collect();
            return self.decode_tracked_float(
                tracker,
                timestamp,
                &views,
                output_boxes,
                output_masks,
                output_tracks,
            );
        }

        let mapped = tensor_bridge::map_tensors(outputs)?;
        match &mapped {
            tensor_bridge::MappedOutputs::Quantized(maps) => {
                let views = tensor_bridge::quantized_views(maps)?;
                self.decode_tracked_quantized(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float16(maps) => {
                let views = tensor_bridge::f16_views(maps)?;
                self.decode_tracked_float(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float32(maps) => {
                let views = tensor_bridge::f32_views(maps)?;
                self.decode_tracked_float(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float64(maps) => {
                let views = tensor_bridge::f64_views(maps)?;
                self.decode_tracked_float(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )
            }
        }
    }

    /// Tracked counterpart of `decode_proto`.
    ///
    /// If a merge program is configured, it is executed first, as in
    /// `decode_proto`. Returns `Some(ProtoData)` only for YOLO segmentation
    /// layouts.
    ///
    /// # Errors
    /// Returns a [`DecoderError`] if merging, mapping or decoding fails.
    pub fn decode_proto_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[&edgefirst_tensor::TensorDyn],
        output_boxes: &mut Vec<DetectBox>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<Option<ProtoData>, DecoderError> {
        if let Some(program) = &self.decode_program {
            let merged = program.execute(outputs)?;
            let views: Vec<_> = merged.iter().map(|a| a.view()).collect();
            return self.decode_tracked_float_proto(
                tracker,
                timestamp,
                &views,
                output_boxes,
                output_tracks,
            );
        }

        let mapped = tensor_bridge::map_tensors(outputs)?;
        match &mapped {
            tensor_bridge::MappedOutputs::Quantized(maps) => {
                let views = tensor_bridge::quantized_views(maps)?;
                self.decode_tracked_quantized_proto(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float16(maps) => {
                let views = tensor_bridge::f16_views(maps)?;
                self.decode_tracked_float_proto(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float32(maps) => {
                let views = tensor_bridge::f32_views(maps)?;
                self.decode_tracked_float_proto(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float64(maps) => {
                let views = tensor_bridge::f64_views(maps)?;
                self.decode_tracked_float_proto(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_tracks,
                )
            }
        }
    }
}