1use ndarray::{ArrayView, ArrayViewD, Dimension};
5use num_traits::{AsPrimitive, Float};
6
7use crate::{DecoderError, DetectBox, ProtoData, Segmentation};
8
9pub mod config;
10pub mod configs;
11
12use configs::ModelType;
13
/// Post-processing decoder that turns raw model outputs into detection boxes,
/// segmentation masks, or prototype data according to its configured
/// [`ModelType`], with optional decode-program and per-scale paths.
#[derive(Debug)]
pub struct Decoder {
    /// Output layout used to dispatch the legacy (non per-scale) decode paths.
    model_type: ModelType,
    /// IoU threshold; forwarded to the per-scale post-processing bridge
    /// (presumably the NMS overlap threshold — confirm in the decode helpers).
    pub iou_threshold: f32,
    /// Score threshold; forwarded to the per-scale post-processing bridge
    /// (presumably the minimum confidence kept — confirm).
    pub score_threshold: f32,
    /// Optional NMS configuration; what `None` means is decided by the decode
    /// helpers, not in this file.
    pub nms: Option<configs::Nms>,
    /// Candidate cap forwarded to the per-scale bridge; by its name, the number
    /// of candidates retained before NMS — TODO confirm.
    pub pre_nms_top_k: usize,
    /// Detection cap forwarded to the per-scale bridge; by its name, the maximum
    /// number of detections produced — TODO confirm.
    pub max_det: usize,
    /// Box-normalization setting reported by `normalized_boxes()`; `None` when
    /// not configured.
    normalized: Option<bool>,
    /// Model input dimensions reported by `input_dims()`; `None` when not
    /// configured. NOTE(review): the axis order of the tuple is not established here.
    input_dims: Option<(usize, usize)>,
    /// When set (and `per_scale` is not), `decode`/`decode_proto` run this
    /// program over the raw outputs and decode the merged result on the float path.
    pub(crate) decode_program: Option<merge::DecodeProgram>,
    /// When set, takes precedence over every other decode path. Locked for each
    /// per-scale run; not carried over by `Clone`.
    pub(crate) per_scale: Option<std::sync::Mutex<crate::per_scale::PerScaleDecoder>>,
}
61
62impl PartialEq for Decoder {
63 fn eq(&self, other: &Self) -> bool {
64 self.model_type == other.model_type
67 && self.iou_threshold == other.iou_threshold
68 && self.score_threshold == other.score_threshold
69 && self.nms == other.nms
70 && self.pre_nms_top_k == other.pre_nms_top_k
71 && self.max_det == other.max_det
72 && self.normalized == other.normalized
73 && self.input_dims == other.input_dims
74 && self.decode_program.is_some() == other.decode_program.is_some()
75 && self.per_scale.is_some() == other.per_scale.is_some()
76 }
77}
78
79impl Clone for Decoder {
80 fn clone(&self) -> Self {
87 Self {
88 model_type: self.model_type.clone(),
89 iou_threshold: self.iou_threshold,
90 score_threshold: self.score_threshold,
91 nms: self.nms,
92 pre_nms_top_k: self.pre_nms_top_k,
93 max_det: self.max_det,
94 normalized: self.normalized,
95 input_dims: self.input_dims,
96 decode_program: self.decode_program.clone(),
97 per_scale: None,
98 }
99 }
100}
101
/// Borrowed, dynamic-rank view over an integer-typed output tensor, tagged by
/// element type so callers can branch on it (see `with_quantized!`).
#[derive(Debug)]
pub(crate) enum ArrayViewDQuantized<'a> {
    /// `u8` elements.
    UInt8(ArrayViewD<'a, u8>),
    /// `i8` elements.
    Int8(ArrayViewD<'a, i8>),
    /// `u16` elements.
    UInt16(ArrayViewD<'a, u16>),
    /// `i16` elements.
    Int16(ArrayViewD<'a, i16>),
    /// `u32` elements.
    UInt32(ArrayViewD<'a, u32>),
    /// `i32` elements.
    Int32(ArrayViewD<'a, i32>),
}
111
112impl<'a, D> From<ArrayView<'a, u8, D>> for ArrayViewDQuantized<'a>
113where
114 D: Dimension,
115{
116 fn from(arr: ArrayView<'a, u8, D>) -> Self {
117 Self::UInt8(arr.into_dyn())
118 }
119}
120
121impl<'a, D> From<ArrayView<'a, i8, D>> for ArrayViewDQuantized<'a>
122where
123 D: Dimension,
124{
125 fn from(arr: ArrayView<'a, i8, D>) -> Self {
126 Self::Int8(arr.into_dyn())
127 }
128}
129
130impl<'a, D> From<ArrayView<'a, u16, D>> for ArrayViewDQuantized<'a>
131where
132 D: Dimension,
133{
134 fn from(arr: ArrayView<'a, u16, D>) -> Self {
135 Self::UInt16(arr.into_dyn())
136 }
137}
138
139impl<'a, D> From<ArrayView<'a, i16, D>> for ArrayViewDQuantized<'a>
140where
141 D: Dimension,
142{
143 fn from(arr: ArrayView<'a, i16, D>) -> Self {
144 Self::Int16(arr.into_dyn())
145 }
146}
147
148impl<'a, D> From<ArrayView<'a, u32, D>> for ArrayViewDQuantized<'a>
149where
150 D: Dimension,
151{
152 fn from(arr: ArrayView<'a, u32, D>) -> Self {
153 Self::UInt32(arr.into_dyn())
154 }
155}
156
157impl<'a, D> From<ArrayView<'a, i32, D>> for ArrayViewDQuantized<'a>
158where
159 D: Dimension,
160{
161 fn from(arr: ArrayView<'a, i32, D>) -> Self {
162 Self::Int32(arr.into_dyn())
163 }
164}
165
166impl<'a> ArrayViewDQuantized<'a> {
167 pub(crate) fn shape(&self) -> &[usize] {
169 match self {
170 ArrayViewDQuantized::UInt8(a) => a.shape(),
171 ArrayViewDQuantized::Int8(a) => a.shape(),
172 ArrayViewDQuantized::UInt16(a) => a.shape(),
173 ArrayViewDQuantized::Int16(a) => a.shape(),
174 ArrayViewDQuantized::UInt32(a) => a.shape(),
175 ArrayViewDQuantized::Int32(a) => a.shape(),
176 }
177 }
178}
179
macro_rules! with_quantized {
    // Usage: `with_quantized!(view_expr, name, body)`.
    // Matches `view_expr` against every `ArrayViewDQuantized` variant, binds the
    // inner typed view to `name`, and evaluates `body` in that arm. `body` is
    // therefore instantiated once per element type.
    ($x:expr, $var:ident, $body:expr) => {
        match $x {
            ArrayViewDQuantized::UInt8($var) => $body,
            ArrayViewDQuantized::Int8($var) => $body,
            ArrayViewDQuantized::UInt16($var) => $body,
            ArrayViewDQuantized::Int16($var) => $body,
            ArrayViewDQuantized::UInt32($var) => $body,
            ArrayViewDQuantized::Int32($var) => $body,
        }
    };
}
216
217mod builder;
218mod dfl;
219mod helpers;
220mod merge;
221mod per_scale_bridge;
222mod postprocess;
223mod tensor_bridge;
224mod tests;
225
226pub use builder::DecoderBuilder;
227pub use config::{ConfigOutput, ConfigOutputRef, ConfigOutputs};
228
229impl Decoder {
230 fn decode_path_label(&self) -> &'static str {
236 if self.per_scale.is_some() {
237 "per_scale"
238 } else if self.decode_program.is_some() {
239 "decode_program"
240 } else {
241 "legacy"
242 }
243 }
244
245 pub fn model_type(&self) -> &ModelType {
264 &self.model_type
265 }
266
267 pub fn normalized_boxes(&self) -> Option<bool> {
293 self.normalized
294 }
295
296 pub fn input_dims(&self) -> Option<(usize, usize)> {
335 self.input_dims
336 }
337
338 pub(crate) fn decode_quantized(
342 &self,
343 outputs: &[ArrayViewDQuantized],
344 output_boxes: &mut Vec<DetectBox>,
345 output_masks: &mut Vec<Segmentation>,
346 ) -> Result<(), DecoderError> {
347 output_boxes.clear();
348 output_masks.clear();
349 match &self.model_type {
350 ModelType::ModelPackSegDet {
351 boxes,
352 scores,
353 segmentation,
354 } => {
355 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
356 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
357 }
358 ModelType::ModelPackSegDetSplit {
359 detection,
360 segmentation,
361 } => {
362 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
363 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
364 }
365 ModelType::ModelPackDet { boxes, scores } => {
366 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)
367 }
368 ModelType::ModelPackDetSplit { detection } => {
369 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)
370 }
371 ModelType::ModelPackSeg { segmentation } => {
372 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
373 }
374 ModelType::YoloDet { boxes } => {
375 self.decode_yolo_det_quantized(outputs, boxes, output_boxes)
376 }
377 ModelType::YoloSegDet { boxes, protos } => self.decode_yolo_segdet_quantized(
378 outputs,
379 boxes,
380 protos,
381 output_boxes,
382 output_masks,
383 ),
384 ModelType::YoloSplitDet { boxes, scores } => {
385 self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)
386 }
387 ModelType::YoloSplitSegDet {
388 boxes,
389 scores,
390 mask_coeff,
391 protos,
392 } => self.decode_yolo_split_segdet_quantized(
393 outputs,
394 boxes,
395 scores,
396 mask_coeff,
397 protos,
398 output_boxes,
399 output_masks,
400 ),
401 ModelType::YoloSegDet2Way {
402 boxes,
403 mask_coeff,
404 protos,
405 } => self.decode_yolo_segdet_2way_quantized(
406 outputs,
407 boxes,
408 mask_coeff,
409 protos,
410 output_boxes,
411 output_masks,
412 ),
413 ModelType::YoloEndToEndDet { boxes } => {
414 self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)
415 }
416 ModelType::YoloEndToEndSegDet { boxes, protos } => self
417 .decode_yolo_end_to_end_segdet_quantized(
418 outputs,
419 boxes,
420 protos,
421 output_boxes,
422 output_masks,
423 ),
424 ModelType::YoloSplitEndToEndDet {
425 boxes,
426 scores,
427 classes,
428 } => self.decode_yolo_split_end_to_end_det_quantized(
429 outputs,
430 boxes,
431 scores,
432 classes,
433 output_boxes,
434 ),
435 ModelType::YoloSplitEndToEndSegDet {
436 boxes,
437 scores,
438 classes,
439 mask_coeff,
440 protos,
441 } => self.decode_yolo_split_end_to_end_segdet_quantized(
442 outputs,
443 boxes,
444 scores,
445 classes,
446 mask_coeff,
447 protos,
448 output_boxes,
449 output_masks,
450 ),
451 ModelType::PerScale => Err(DecoderError::Internal(
452 "per-scale path must be intercepted before ModelType dispatch".into(),
453 )),
454 }
455 }
456
457 pub(crate) fn decode_float<T>(
461 &self,
462 outputs: &[ArrayViewD<T>],
463 output_boxes: &mut Vec<DetectBox>,
464 output_masks: &mut Vec<Segmentation>,
465 ) -> Result<(), DecoderError>
466 where
467 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
468 f32: AsPrimitive<T>,
469 {
470 output_boxes.clear();
471 output_masks.clear();
472 match &self.model_type {
473 ModelType::ModelPackSegDet {
474 boxes,
475 scores,
476 segmentation,
477 } => {
478 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
479 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
480 }
481 ModelType::ModelPackSegDetSplit {
482 detection,
483 segmentation,
484 } => {
485 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
486 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
487 }
488 ModelType::ModelPackDet { boxes, scores } => {
489 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
490 }
491 ModelType::ModelPackDetSplit { detection } => {
492 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
493 }
494 ModelType::ModelPackSeg { segmentation } => {
495 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
496 }
497 ModelType::YoloDet { boxes } => {
498 self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
499 }
500 ModelType::YoloSegDet { boxes, protos } => {
501 self.decode_yolo_segdet_float(outputs, boxes, protos, output_boxes, output_masks)?;
502 }
503 ModelType::YoloSplitDet { boxes, scores } => {
504 self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
505 }
506 ModelType::YoloSplitSegDet {
507 boxes,
508 scores,
509 mask_coeff,
510 protos,
511 } => {
512 self.decode_yolo_split_segdet_float(
513 outputs,
514 boxes,
515 scores,
516 mask_coeff,
517 protos,
518 output_boxes,
519 output_masks,
520 )?;
521 }
522 ModelType::YoloSegDet2Way {
523 boxes,
524 mask_coeff,
525 protos,
526 } => {
527 self.decode_yolo_segdet_2way_float(
528 outputs,
529 boxes,
530 mask_coeff,
531 protos,
532 output_boxes,
533 output_masks,
534 )?;
535 }
536 ModelType::YoloEndToEndDet { boxes } => {
537 self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
538 }
539 ModelType::YoloEndToEndSegDet { boxes, protos } => {
540 self.decode_yolo_end_to_end_segdet_float(
541 outputs,
542 boxes,
543 protos,
544 output_boxes,
545 output_masks,
546 )?;
547 }
548 ModelType::YoloSplitEndToEndDet {
549 boxes,
550 scores,
551 classes,
552 } => {
553 self.decode_yolo_split_end_to_end_det_float(
554 outputs,
555 boxes,
556 scores,
557 classes,
558 output_boxes,
559 )?;
560 }
561 ModelType::YoloSplitEndToEndSegDet {
562 boxes,
563 scores,
564 classes,
565 mask_coeff,
566 protos,
567 } => {
568 self.decode_yolo_split_end_to_end_segdet_float(
569 outputs,
570 boxes,
571 scores,
572 classes,
573 mask_coeff,
574 protos,
575 output_boxes,
576 output_masks,
577 )?;
578 }
579 ModelType::PerScale => {
580 return Err(DecoderError::Internal(
581 "per-scale path must be intercepted before ModelType dispatch".into(),
582 ));
583 }
584 }
585 Ok(())
586 }
587
588 pub(crate) fn decode_quantized_proto(
595 &self,
596 outputs: &[ArrayViewDQuantized],
597 output_boxes: &mut Vec<DetectBox>,
598 ) -> Result<Option<ProtoData>, DecoderError> {
599 output_boxes.clear();
600 match &self.model_type {
601 ModelType::ModelPackDet { boxes, scores } => {
603 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
604 Ok(None)
605 }
606 ModelType::ModelPackDetSplit { detection } => {
607 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
608 Ok(None)
609 }
610 ModelType::YoloDet { boxes } => {
611 self.decode_yolo_det_quantized(outputs, boxes, output_boxes)?;
612 Ok(None)
613 }
614 ModelType::YoloSplitDet { boxes, scores } => {
615 self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)?;
616 Ok(None)
617 }
618 ModelType::YoloEndToEndDet { boxes } => {
619 self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)?;
620 Ok(None)
621 }
622 ModelType::YoloSplitEndToEndDet {
623 boxes,
624 scores,
625 classes,
626 } => {
627 self.decode_yolo_split_end_to_end_det_quantized(
628 outputs,
629 boxes,
630 scores,
631 classes,
632 output_boxes,
633 )?;
634 Ok(None)
635 }
636 ModelType::ModelPackSegDet { boxes, scores, .. } => {
638 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
639 Ok(None)
640 }
641 ModelType::ModelPackSegDetSplit { detection, .. } => {
642 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
643 Ok(None)
644 }
645 ModelType::ModelPackSeg { .. } => Ok(None),
646
647 ModelType::YoloSegDet { boxes, protos } => {
648 let proto =
649 self.decode_yolo_segdet_quantized_proto(outputs, boxes, protos, output_boxes)?;
650 Ok(Some(proto))
651 }
652 ModelType::YoloSplitSegDet {
653 boxes,
654 scores,
655 mask_coeff,
656 protos,
657 } => {
658 let proto = self.decode_yolo_split_segdet_quantized_proto(
659 outputs,
660 boxes,
661 scores,
662 mask_coeff,
663 protos,
664 output_boxes,
665 )?;
666 Ok(Some(proto))
667 }
668 ModelType::YoloSegDet2Way {
669 boxes,
670 mask_coeff,
671 protos,
672 } => {
673 let proto = self.decode_yolo_segdet_2way_quantized_proto(
674 outputs,
675 boxes,
676 mask_coeff,
677 protos,
678 output_boxes,
679 )?;
680 Ok(Some(proto))
681 }
682 ModelType::YoloEndToEndSegDet { boxes, protos } => {
683 let proto = self.decode_yolo_end_to_end_segdet_quantized_proto(
684 outputs,
685 boxes,
686 protos,
687 output_boxes,
688 )?;
689 Ok(Some(proto))
690 }
691 ModelType::YoloSplitEndToEndSegDet {
692 boxes,
693 scores,
694 classes,
695 mask_coeff,
696 protos,
697 } => {
698 let proto = self.decode_yolo_split_end_to_end_segdet_quantized_proto(
699 outputs,
700 boxes,
701 scores,
702 classes,
703 mask_coeff,
704 protos,
705 output_boxes,
706 )?;
707 Ok(Some(proto))
708 }
709 ModelType::PerScale => Err(DecoderError::Internal(
710 "per-scale path must be intercepted before ModelType dispatch".into(),
711 )),
712 }
713 }
714
715 pub(crate) fn decode_float_proto<T>(
722 &self,
723 outputs: &[ArrayViewD<T>],
724 output_boxes: &mut Vec<DetectBox>,
725 ) -> Result<Option<ProtoData>, DecoderError>
726 where
727 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + crate::yolo::FloatProtoElem,
728 f32: AsPrimitive<T>,
729 {
730 output_boxes.clear();
731 match &self.model_type {
732 ModelType::ModelPackDet { boxes, scores } => {
734 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
735 Ok(None)
736 }
737 ModelType::ModelPackDetSplit { detection } => {
738 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
739 Ok(None)
740 }
741 ModelType::YoloDet { boxes } => {
742 self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
743 Ok(None)
744 }
745 ModelType::YoloSplitDet { boxes, scores } => {
746 self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
747 Ok(None)
748 }
749 ModelType::YoloEndToEndDet { boxes } => {
750 self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
751 Ok(None)
752 }
753 ModelType::YoloSplitEndToEndDet {
754 boxes,
755 scores,
756 classes,
757 } => {
758 self.decode_yolo_split_end_to_end_det_float(
759 outputs,
760 boxes,
761 scores,
762 classes,
763 output_boxes,
764 )?;
765 Ok(None)
766 }
767 ModelType::ModelPackSegDet { boxes, scores, .. } => {
769 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
770 Ok(None)
771 }
772 ModelType::ModelPackSegDetSplit { detection, .. } => {
773 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
774 Ok(None)
775 }
776 ModelType::ModelPackSeg { .. } => Ok(None),
777
778 ModelType::YoloSegDet { boxes, protos } => {
779 let proto =
780 self.decode_yolo_segdet_float_proto(outputs, boxes, protos, output_boxes)?;
781 Ok(Some(proto))
782 }
783 ModelType::YoloSplitSegDet {
784 boxes,
785 scores,
786 mask_coeff,
787 protos,
788 } => {
789 let proto = self.decode_yolo_split_segdet_float_proto(
790 outputs,
791 boxes,
792 scores,
793 mask_coeff,
794 protos,
795 output_boxes,
796 )?;
797 Ok(Some(proto))
798 }
799 ModelType::YoloSegDet2Way {
800 boxes,
801 mask_coeff,
802 protos,
803 } => {
804 let proto = self.decode_yolo_segdet_2way_float_proto(
805 outputs,
806 boxes,
807 mask_coeff,
808 protos,
809 output_boxes,
810 )?;
811 Ok(Some(proto))
812 }
813 ModelType::YoloEndToEndSegDet { boxes, protos } => {
814 let proto = self.decode_yolo_end_to_end_segdet_float_proto(
815 outputs,
816 boxes,
817 protos,
818 output_boxes,
819 )?;
820 Ok(Some(proto))
821 }
822 ModelType::YoloSplitEndToEndSegDet {
823 boxes,
824 scores,
825 classes,
826 mask_coeff,
827 protos,
828 } => {
829 let proto = self.decode_yolo_split_end_to_end_segdet_float_proto(
830 outputs,
831 boxes,
832 scores,
833 classes,
834 mask_coeff,
835 protos,
836 output_boxes,
837 )?;
838 Ok(Some(proto))
839 }
840 ModelType::PerScale => Err(DecoderError::Internal(
841 "per-scale path must be intercepted before ModelType dispatch".into(),
842 )),
843 }
844 }
845
846 pub fn decode(
877 &self,
878 outputs: &[&edgefirst_tensor::TensorDyn],
879 output_boxes: &mut Vec<DetectBox>,
880 output_masks: &mut Vec<Segmentation>,
881 ) -> Result<(), DecoderError> {
882 let path = self.decode_path_label();
883 let _span = tracing::trace_span!("Decoder::decode", path = path, n_outputs = outputs.len())
884 .entered();
885 if let Some(per_scale_mutex) = &self.per_scale {
888 let mut ps = per_scale_mutex
889 .lock()
890 .map_err(|e| DecoderError::Internal(format!("per_scale mutex poisoned: {e}")))?;
891 let decoded = ps.run(outputs)?;
892 return per_scale_bridge::per_scale_to_masks(
893 &decoded,
894 output_boxes,
895 output_masks,
896 self.iou_threshold,
897 self.score_threshold,
898 self.nms,
899 self.pre_nms_top_k,
900 self.max_det,
901 self.normalized,
902 self.input_dims,
903 );
904 }
905
906 if let Some(program) = &self.decode_program {
909 let merged = program.execute(outputs)?;
910 let views: Vec<_> = merged.iter().map(|a| a.view()).collect();
911 return self.decode_float(&views, output_boxes, output_masks);
912 }
913
914 let mapped = tensor_bridge::map_tensors(outputs)?;
915 match &mapped {
916 tensor_bridge::MappedOutputs::Quantized(maps) => {
917 let views = tensor_bridge::quantized_views(maps)?;
918 self.decode_quantized(&views, output_boxes, output_masks)
919 }
920 tensor_bridge::MappedOutputs::Float16(maps) => {
921 let views = tensor_bridge::f16_views(maps)?;
922 self.decode_float(&views, output_boxes, output_masks)
923 }
924 tensor_bridge::MappedOutputs::Float32(maps) => {
925 let views = tensor_bridge::f32_views(maps)?;
926 self.decode_float(&views, output_boxes, output_masks)
927 }
928 tensor_bridge::MappedOutputs::Float64(maps) => {
929 let views = tensor_bridge::f64_views(maps)?;
930 self.decode_float(&views, output_boxes, output_masks)
931 }
932 }
933 }
934
935 pub fn decode_proto(
961 &self,
962 outputs: &[&edgefirst_tensor::TensorDyn],
963 output_boxes: &mut Vec<DetectBox>,
964 ) -> Result<Option<ProtoData>, DecoderError> {
965 let path = self.decode_path_label();
966 let _span = tracing::trace_span!(
967 "Decoder::decode_proto",
968 path = path,
969 n_outputs = outputs.len()
970 )
971 .entered();
972 if let Some(per_scale_mutex) = &self.per_scale {
975 let mut ps = per_scale_mutex
976 .lock()
977 .map_err(|e| DecoderError::Internal(format!("per_scale mutex poisoned: {e}")))?;
978 let decoded = ps.run(outputs)?;
979 return per_scale_bridge::per_scale_to_proto_data(
980 &decoded,
981 output_boxes,
982 self.iou_threshold,
983 self.score_threshold,
984 self.nms,
985 self.pre_nms_top_k,
986 self.max_det,
987 self.normalized,
988 self.input_dims,
989 );
990 }
991
992 if let Some(program) = &self.decode_program {
995 let merged = program.execute(outputs)?;
996 let views: Vec<_> = merged.iter().map(|a| a.view()).collect();
997 return self.decode_float_proto(&views, output_boxes);
998 }
999
1000 let mapped = tensor_bridge::map_tensors(outputs)?;
1001 let result = match &mapped {
1002 tensor_bridge::MappedOutputs::Quantized(maps) => {
1003 let views = tensor_bridge::quantized_views(maps)?;
1004 self.decode_quantized_proto(&views, output_boxes)
1005 }
1006 tensor_bridge::MappedOutputs::Float16(maps) => {
1007 let views = tensor_bridge::f16_views(maps)?;
1008 self.decode_float_proto(&views, output_boxes)
1009 }
1010 tensor_bridge::MappedOutputs::Float32(maps) => {
1011 let views = tensor_bridge::f32_views(maps)?;
1012 self.decode_float_proto(&views, output_boxes)
1013 }
1014 tensor_bridge::MappedOutputs::Float64(maps) => {
1015 let views = tensor_bridge::f64_views(maps)?;
1016 self.decode_float_proto(&views, output_boxes)
1017 }
1018 };
1019 result
1020 }
1021
1022 #[doc(hidden)]
1029 pub fn _testing_run_per_scale_pre_nms(
1030 &self,
1031 outputs: &[&edgefirst_tensor::TensorDyn],
1032 ) -> Result<crate::per_scale::PreNmsCapture, crate::error::DecoderError> {
1033 let mutex = self.per_scale.as_ref().ok_or_else(|| {
1034 crate::error::DecoderError::Internal("decoder not configured for per-scale".into())
1035 })?;
1036 let mut ps = mutex.lock().map_err(|e| {
1037 crate::error::DecoderError::Internal(format!("per_scale mutex poisoned: {e}"))
1038 })?;
1039 {
1041 ps.run(outputs)?;
1042 }
1043 let total_anchors = ps.plan.total_anchors;
1044 let num_classes = ps.plan.num_classes;
1045 let num_mc = ps.plan.num_mask_coefs;
1046 Ok(ps
1047 .buffers
1048 .snapshot_owned_f32(total_anchors, num_classes, num_mc))
1049 }
1050}
1051
1052#[cfg(feature = "tracker")]
1053pub use edgefirst_tracker::TrackInfo;
1054
1055#[cfg(feature = "tracker")]
1056pub use edgefirst_tracker::Tracker;
1057
1058#[cfg(feature = "tracker")]
1059impl Decoder {
1060 pub(crate) fn decode_tracked_quantized<TR: edgefirst_tracker::Tracker<DetectBox>>(
1064 &self,
1065 tracker: &mut TR,
1066 timestamp: u64,
1067 outputs: &[ArrayViewDQuantized],
1068 output_boxes: &mut Vec<DetectBox>,
1069 output_masks: &mut Vec<Segmentation>,
1070 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1071 ) -> Result<(), DecoderError> {
1072 output_boxes.clear();
1073 output_masks.clear();
1074 output_tracks.clear();
1075
1076 match &self.model_type {
1079 ModelType::YoloSegDet { boxes, protos } => self.decode_tracked_yolo_segdet_quantized(
1080 tracker,
1081 timestamp,
1082 outputs,
1083 boxes,
1084 protos,
1085 output_boxes,
1086 output_masks,
1087 output_tracks,
1088 ),
1089 ModelType::YoloSplitSegDet {
1090 boxes,
1091 scores,
1092 mask_coeff,
1093 protos,
1094 } => self.decode_tracked_yolo_split_segdet_quantized(
1095 tracker,
1096 timestamp,
1097 outputs,
1098 boxes,
1099 scores,
1100 mask_coeff,
1101 protos,
1102 output_boxes,
1103 output_masks,
1104 output_tracks,
1105 ),
1106 ModelType::YoloEndToEndSegDet { boxes, protos } => self
1107 .decode_tracked_yolo_end_to_end_segdet_quantized(
1108 tracker,
1109 timestamp,
1110 outputs,
1111 boxes,
1112 protos,
1113 output_boxes,
1114 output_masks,
1115 output_tracks,
1116 ),
1117 ModelType::YoloSplitEndToEndSegDet {
1118 boxes,
1119 scores,
1120 classes,
1121 mask_coeff,
1122 protos,
1123 } => self.decode_tracked_yolo_split_end_to_end_segdet_quantized(
1124 tracker,
1125 timestamp,
1126 outputs,
1127 boxes,
1128 scores,
1129 classes,
1130 mask_coeff,
1131 protos,
1132 output_boxes,
1133 output_masks,
1134 output_tracks,
1135 ),
1136 ModelType::YoloSegDet2Way {
1137 boxes,
1138 mask_coeff,
1139 protos,
1140 } => self.decode_tracked_yolo_segdet_2way_quantized(
1141 tracker,
1142 timestamp,
1143 outputs,
1144 boxes,
1145 mask_coeff,
1146 protos,
1147 output_boxes,
1148 output_masks,
1149 output_tracks,
1150 ),
1151 _ => {
1152 self.decode_quantized(outputs, output_boxes, output_masks)?;
1153 Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1154 Ok(())
1155 }
1156 }
1157 }
1158
1159 pub(crate) fn decode_tracked_float<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
1169 &self,
1170 tracker: &mut TR,
1171 timestamp: u64,
1172 outputs: &[ArrayViewD<T>],
1173 output_boxes: &mut Vec<DetectBox>,
1174 output_masks: &mut Vec<Segmentation>,
1175 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1176 ) -> Result<(), DecoderError>
1177 where
1178 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
1179 f32: AsPrimitive<T>,
1180 {
1181 output_boxes.clear();
1182 output_masks.clear();
1183 output_tracks.clear();
1184 match &self.model_type {
1185 ModelType::YoloSegDet { boxes, protos } => {
1186 self.decode_tracked_yolo_segdet_float(
1187 tracker,
1188 timestamp,
1189 outputs,
1190 boxes,
1191 protos,
1192 output_boxes,
1193 output_masks,
1194 output_tracks,
1195 )?;
1196 }
1197 ModelType::YoloSplitSegDet {
1198 boxes,
1199 scores,
1200 mask_coeff,
1201 protos,
1202 } => {
1203 self.decode_tracked_yolo_split_segdet_float(
1204 tracker,
1205 timestamp,
1206 outputs,
1207 boxes,
1208 scores,
1209 mask_coeff,
1210 protos,
1211 output_boxes,
1212 output_masks,
1213 output_tracks,
1214 )?;
1215 }
1216 ModelType::YoloEndToEndSegDet { boxes, protos } => {
1217 self.decode_tracked_yolo_end_to_end_segdet_float(
1218 tracker,
1219 timestamp,
1220 outputs,
1221 boxes,
1222 protos,
1223 output_boxes,
1224 output_masks,
1225 output_tracks,
1226 )?;
1227 }
1228 ModelType::YoloSplitEndToEndSegDet {
1229 boxes,
1230 scores,
1231 classes,
1232 mask_coeff,
1233 protos,
1234 } => {
1235 self.decode_tracked_yolo_split_end_to_end_segdet_float(
1236 tracker,
1237 timestamp,
1238 outputs,
1239 boxes,
1240 scores,
1241 classes,
1242 mask_coeff,
1243 protos,
1244 output_boxes,
1245 output_masks,
1246 output_tracks,
1247 )?;
1248 }
1249 ModelType::YoloSegDet2Way {
1250 boxes,
1251 mask_coeff,
1252 protos,
1253 } => {
1254 self.decode_tracked_yolo_segdet_2way_float(
1255 tracker,
1256 timestamp,
1257 outputs,
1258 boxes,
1259 mask_coeff,
1260 protos,
1261 output_boxes,
1262 output_masks,
1263 output_tracks,
1264 )?;
1265 }
1266 _ => {
1267 self.decode_float(outputs, output_boxes, output_masks)?;
1268 Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1269 }
1270 }
1271 Ok(())
1272 }
1273
1274 pub(crate) fn decode_tracked_quantized_proto<TR: edgefirst_tracker::Tracker<DetectBox>>(
1281 &self,
1282 tracker: &mut TR,
1283 timestamp: u64,
1284 outputs: &[ArrayViewDQuantized],
1285 output_boxes: &mut Vec<DetectBox>,
1286 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1287 ) -> Result<Option<ProtoData>, DecoderError> {
1288 output_boxes.clear();
1289 output_tracks.clear();
1290 match &self.model_type {
1291 ModelType::YoloSegDet { boxes, protos } => {
1292 let proto = self.decode_tracked_yolo_segdet_quantized_proto(
1293 tracker,
1294 timestamp,
1295 outputs,
1296 boxes,
1297 protos,
1298 output_boxes,
1299 output_tracks,
1300 )?;
1301 Ok(Some(proto))
1302 }
1303 ModelType::YoloSplitSegDet {
1304 boxes,
1305 scores,
1306 mask_coeff,
1307 protos,
1308 } => {
1309 let proto = self.decode_tracked_yolo_split_segdet_quantized_proto(
1310 tracker,
1311 timestamp,
1312 outputs,
1313 boxes,
1314 scores,
1315 mask_coeff,
1316 protos,
1317 output_boxes,
1318 output_tracks,
1319 )?;
1320 Ok(Some(proto))
1321 }
1322 ModelType::YoloSegDet2Way {
1323 boxes,
1324 mask_coeff,
1325 protos,
1326 } => {
1327 let proto = self.decode_tracked_yolo_segdet_2way_quantized_proto(
1328 tracker,
1329 timestamp,
1330 outputs,
1331 boxes,
1332 mask_coeff,
1333 protos,
1334 output_boxes,
1335 output_tracks,
1336 )?;
1337 Ok(Some(proto))
1338 }
1339 ModelType::YoloEndToEndSegDet { boxes, protos } => {
1340 let proto = self.decode_tracked_yolo_end_to_end_segdet_quantized_proto(
1341 tracker,
1342 timestamp,
1343 outputs,
1344 boxes,
1345 protos,
1346 output_boxes,
1347 output_tracks,
1348 )?;
1349 Ok(Some(proto))
1350 }
1351 ModelType::YoloSplitEndToEndSegDet {
1352 boxes,
1353 scores,
1354 classes,
1355 mask_coeff,
1356 protos,
1357 } => {
1358 let proto = self.decode_tracked_yolo_split_end_to_end_segdet_quantized_proto(
1359 tracker,
1360 timestamp,
1361 outputs,
1362 boxes,
1363 scores,
1364 classes,
1365 mask_coeff,
1366 protos,
1367 output_boxes,
1368 output_tracks,
1369 )?;
1370 Ok(Some(proto))
1371 }
1372 _ => {
1374 let mut masks = Vec::new();
1375 self.decode_quantized(outputs, output_boxes, &mut masks)?;
1376 Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1377 Ok(None)
1378 }
1379 }
1380 }
1381
1382 pub(crate) fn decode_tracked_float_proto<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
1389 &self,
1390 tracker: &mut TR,
1391 timestamp: u64,
1392 outputs: &[ArrayViewD<T>],
1393 output_boxes: &mut Vec<DetectBox>,
1394 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1395 ) -> Result<Option<ProtoData>, DecoderError>
1396 where
1397 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + crate::yolo::FloatProtoElem,
1398 f32: AsPrimitive<T>,
1399 {
1400 output_boxes.clear();
1401 output_tracks.clear();
1402 match &self.model_type {
1403 ModelType::YoloSegDet { boxes, protos } => {
1404 let proto = self.decode_tracked_yolo_segdet_float_proto(
1405 tracker,
1406 timestamp,
1407 outputs,
1408 boxes,
1409 protos,
1410 output_boxes,
1411 output_tracks,
1412 )?;
1413 Ok(Some(proto))
1414 }
1415 ModelType::YoloSplitSegDet {
1416 boxes,
1417 scores,
1418 mask_coeff,
1419 protos,
1420 } => {
1421 let proto = self.decode_tracked_yolo_split_segdet_float_proto(
1422 tracker,
1423 timestamp,
1424 outputs,
1425 boxes,
1426 scores,
1427 mask_coeff,
1428 protos,
1429 output_boxes,
1430 output_tracks,
1431 )?;
1432 Ok(Some(proto))
1433 }
1434 ModelType::YoloSegDet2Way {
1435 boxes,
1436 mask_coeff,
1437 protos,
1438 } => {
1439 let proto = self.decode_tracked_yolo_segdet_2way_float_proto(
1440 tracker,
1441 timestamp,
1442 outputs,
1443 boxes,
1444 mask_coeff,
1445 protos,
1446 output_boxes,
1447 output_tracks,
1448 )?;
1449 Ok(Some(proto))
1450 }
1451 ModelType::YoloEndToEndSegDet { boxes, protos } => {
1452 let proto = self.decode_tracked_yolo_end_to_end_segdet_float_proto(
1453 tracker,
1454 timestamp,
1455 outputs,
1456 boxes,
1457 protos,
1458 output_boxes,
1459 output_tracks,
1460 )?;
1461 Ok(Some(proto))
1462 }
1463 ModelType::YoloSplitEndToEndSegDet {
1464 boxes,
1465 scores,
1466 classes,
1467 mask_coeff,
1468 protos,
1469 } => {
1470 let proto = self.decode_tracked_yolo_split_end_to_end_segdet_float_proto(
1471 tracker,
1472 timestamp,
1473 outputs,
1474 boxes,
1475 scores,
1476 classes,
1477 mask_coeff,
1478 protos,
1479 output_boxes,
1480 output_tracks,
1481 )?;
1482 Ok(Some(proto))
1483 }
1484 _ => {
1486 let mut masks = Vec::new();
1487 self.decode_float(outputs, output_boxes, &mut masks)?;
1488 Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1489 Ok(None)
1490 }
1491 }
1492 }
1493
1494 pub fn decode_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
1518 &self,
1519 tracker: &mut TR,
1520 timestamp: u64,
1521 outputs: &[&edgefirst_tensor::TensorDyn],
1522 output_boxes: &mut Vec<DetectBox>,
1523 output_masks: &mut Vec<Segmentation>,
1524 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1525 ) -> Result<(), DecoderError> {
1526 if self.per_scale.is_some() {
1530 output_tracks.clear();
1531 self.decode(outputs, output_boxes, output_masks)?;
1532 Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1533 return Ok(());
1534 }
1535
1536 let mapped = tensor_bridge::map_tensors(outputs)?;
1537 match &mapped {
1538 tensor_bridge::MappedOutputs::Quantized(maps) => {
1539 let views = tensor_bridge::quantized_views(maps)?;
1540 self.decode_tracked_quantized(
1541 tracker,
1542 timestamp,
1543 &views,
1544 output_boxes,
1545 output_masks,
1546 output_tracks,
1547 )
1548 }
1549 tensor_bridge::MappedOutputs::Float16(maps) => {
1550 let views = tensor_bridge::f16_views(maps)?;
1551 self.decode_tracked_float(
1552 tracker,
1553 timestamp,
1554 &views,
1555 output_boxes,
1556 output_masks,
1557 output_tracks,
1558 )
1559 }
1560 tensor_bridge::MappedOutputs::Float32(maps) => {
1561 let views = tensor_bridge::f32_views(maps)?;
1562 self.decode_tracked_float(
1563 tracker,
1564 timestamp,
1565 &views,
1566 output_boxes,
1567 output_masks,
1568 output_tracks,
1569 )
1570 }
1571 tensor_bridge::MappedOutputs::Float64(maps) => {
1572 let views = tensor_bridge::f64_views(maps)?;
1573 self.decode_tracked_float(
1574 tracker,
1575 timestamp,
1576 &views,
1577 output_boxes,
1578 output_masks,
1579 output_tracks,
1580 )
1581 }
1582 }
1583 }
1584
1585 pub fn decode_proto_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
1605 &self,
1606 tracker: &mut TR,
1607 timestamp: u64,
1608 outputs: &[&edgefirst_tensor::TensorDyn],
1609 output_boxes: &mut Vec<DetectBox>,
1610 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1611 ) -> Result<Option<ProtoData>, DecoderError> {
1612 if self.per_scale.is_some() {
1615 output_tracks.clear();
1616 let proto = self.decode_proto(outputs, output_boxes)?;
1617 Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1618 return Ok(proto);
1619 }
1620
1621 let mapped = tensor_bridge::map_tensors(outputs)?;
1622 match &mapped {
1623 tensor_bridge::MappedOutputs::Quantized(maps) => {
1624 let views = tensor_bridge::quantized_views(maps)?;
1625 self.decode_tracked_quantized_proto(
1626 tracker,
1627 timestamp,
1628 &views,
1629 output_boxes,
1630 output_tracks,
1631 )
1632 }
1633 tensor_bridge::MappedOutputs::Float16(maps) => {
1634 let views = tensor_bridge::f16_views(maps)?;
1635 self.decode_tracked_float_proto(
1636 tracker,
1637 timestamp,
1638 &views,
1639 output_boxes,
1640 output_tracks,
1641 )
1642 }
1643 tensor_bridge::MappedOutputs::Float32(maps) => {
1644 let views = tensor_bridge::f32_views(maps)?;
1645 self.decode_tracked_float_proto(
1646 tracker,
1647 timestamp,
1648 &views,
1649 output_boxes,
1650 output_tracks,
1651 )
1652 }
1653 tensor_bridge::MappedOutputs::Float64(maps) => {
1654 let views = tensor_bridge::f64_views(maps)?;
1655 self.decode_tracked_float_proto(
1656 tracker,
1657 timestamp,
1658 &views,
1659 output_boxes,
1660 output_tracks,
1661 )
1662 }
1663 }
1664 }
1665}