1#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
68
69use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
70use num_traits::{AsPrimitive, Float, PrimInt};
71
72pub mod byte;
73pub mod error;
74pub mod float;
75pub mod modelpack;
76pub mod per_scale;
77pub mod schema;
78pub mod yolo;
79
80mod decoder;
81pub use decoder::*;
82
83pub use configs::{DecoderVersion, Nms};
84pub use error::{DecoderError, DecoderResult};
85pub use per_scale::DecodeDtype;
86
87use crate::{
88 decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
89 yolo::yolo_segmentation_to_mask,
90};
91
92pub trait BBoxTypeTrait {
94 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
96
97 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
99 input: &[B; 4],
100 quant: Quantization,
101 ) -> [A; 4]
102 where
103 f32: AsPrimitive<A>,
104 i32: AsPrimitive<A>;
105
106 #[inline(always)]
117 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
118 input: ArrayView1<B>,
119 ) -> [A; 4] {
120 Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
121 }
122
123 #[inline(always)]
124 fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
126 input: ArrayView1<B>,
127 quant: Quantization,
128 ) -> [A; 4]
129 where
130 f32: AsPrimitive<A>,
131 i32: AsPrimitive<A>,
132 {
133 Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
134 }
135}
136
137#[derive(Debug, Clone, Copy, PartialEq, Eq)]
139pub struct XYXY {}
140
141impl BBoxTypeTrait for XYXY {
142 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
143 input.map(|b| b.as_())
144 }
145
146 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
147 input: &[B; 4],
148 quant: Quantization,
149 ) -> [A; 4]
150 where
151 f32: AsPrimitive<A>,
152 i32: AsPrimitive<A>,
153 {
154 let scale = quant.scale.as_();
155 let zp = quant.zero_point.as_();
156 input.map(|b| (b.as_() - zp) * scale)
157 }
158
159 #[inline(always)]
160 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
161 input: ArrayView1<B>,
162 ) -> [A; 4] {
163 [
164 input[0].as_(),
165 input[1].as_(),
166 input[2].as_(),
167 input[3].as_(),
168 ]
169 }
170}
171
172#[derive(Debug, Clone, Copy, PartialEq, Eq)]
175pub struct XYWH {}
176
177impl BBoxTypeTrait for XYWH {
178 #[inline(always)]
179 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
180 let half = A::one() / (A::one() + A::one());
181 [
182 (input[0].as_()) - (input[2].as_() * half),
183 (input[1].as_()) - (input[3].as_() * half),
184 (input[0].as_()) + (input[2].as_() * half),
185 (input[1].as_()) + (input[3].as_() * half),
186 ]
187 }
188
189 #[inline(always)]
190 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
191 input: &[B; 4],
192 quant: Quantization,
193 ) -> [A; 4]
194 where
195 f32: AsPrimitive<A>,
196 i32: AsPrimitive<A>,
197 {
198 let scale = quant.scale.as_();
199 let half_scale = (quant.scale * 0.5).as_();
200 let zp = quant.zero_point.as_();
201 let [x, y, w, h] = [
202 (input[0].as_() - zp) * scale,
203 (input[1].as_() - zp) * scale,
204 (input[2].as_() - zp) * half_scale,
205 (input[3].as_() - zp) * half_scale,
206 ];
207
208 [x - w, y - h, x + w, y + h]
209 }
210
211 #[inline(always)]
212 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
213 input: ArrayView1<B>,
214 ) -> [A; 4] {
215 let half = A::one() / (A::one() + A::one());
216 [
217 (input[0].as_()) - (input[2].as_() * half),
218 (input[1].as_()) - (input[3].as_() * half),
219 (input[0].as_()) + (input[2].as_() * half),
220 (input[1].as_()) + (input[3].as_() * half),
221 ]
222 }
223}
224
/// Affine quantization parameters.
///
/// Elsewhere in this file a quantized value `q` is turned into a real value
/// as `(q - zero_point) * scale`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quantization {
    /// Multiplicative step between adjacent quantized values.
    pub scale: f32,
    /// Quantized value that maps to real zero.
    pub zero_point: i32,
}

impl Quantization {
    /// Builds quantization parameters from an explicit scale and zero point.
    pub fn new(scale: f32, zero_point: i32) -> Self {
        Quantization { scale, zero_point }
    }

    /// Returns parameters that leave values unchanged: scale `1.0`, zero
    /// point `0`.
    pub fn identity() -> Self {
        Self::new(1.0, 0)
    }
}
263
264impl From<QuantTuple> for Quantization {
265 fn from(quant_tuple: QuantTuple) -> Quantization {
276 Quantization {
277 scale: quant_tuple.0,
278 zero_point: quant_tuple.1,
279 }
280 }
281}
282
283impl<S, Z> From<(S, Z)> for Quantization
284where
285 S: AsPrimitive<f32>,
286 Z: AsPrimitive<i32>,
287{
288 fn from((scale, zp): (S, Z)) -> Quantization {
297 Self {
298 scale: scale.as_(),
299 zero_point: zp.as_(),
300 }
301 }
302}
303
304impl Default for Quantization {
305 fn default() -> Self {
314 Self {
315 scale: 1.0,
316 zero_point: 0,
317 }
318 }
319}
320
/// A single detection: a bounding box with a floating-point score and a
/// class label index.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct DetectBox {
    /// Location of the detection.
    pub bbox: BoundingBox,
    /// Score of the detection as a float (see `dequant_detect_box` for the
    /// conversion from a quantized score).
    pub score: f32,
    /// Index of the detected class.
    pub label: usize,
}
330
/// Axis-aligned box stored as two corners.
///
/// Nothing here enforces `xmin <= xmax` or `ymin <= ymax`; use
/// [`BoundingBox::to_canonical`] to obtain an ordered copy.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct BoundingBox {
    /// X coordinate of the first corner.
    pub xmin: f32,
    /// Y coordinate of the first corner.
    pub ymin: f32,
    /// X coordinate of the second corner.
    pub xmax: f32,
    /// Y coordinate of the second corner.
    pub ymax: f32,
}

impl BoundingBox {
    /// Creates a box from its four coordinates, stored exactly as given.
    pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }

    /// Returns a copy whose `min` fields hold the smaller and whose `max`
    /// fields hold the larger coordinate on each axis.
    pub fn to_canonical(&self) -> Self {
        let (x_lo, x_hi) = (self.xmin.min(self.xmax), self.xmin.max(self.xmax));
        let (y_lo, y_hi) = (self.ymin.min(self.ymax), self.ymin.max(self.ymax));
        Self::new(x_lo, y_lo, x_hi, y_hi)
    }
}

impl From<BoundingBox> for [f32; 4] {
    /// Flattens the box to `[xmin, ymin, xmax, ymax]`.
    fn from(b: BoundingBox) -> Self {
        let BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        } = b;
        [xmin, ymin, xmax, ymax]
    }
}

impl From<[f32; 4]> for BoundingBox {
    /// Reads the array as `[xmin, ymin, xmax, ymax]`.
    fn from(arr: [f32; 4]) -> Self {
        let [xmin, ymin, xmax, ymax] = arr;
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }
}
409
410impl DetectBox {
411 pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
440 let eq_delta = |a: f32, b: f32| (a - b).abs() <= eps;
441 self.label == rhs.label
442 && eq_delta(self.score, rhs.score)
443 && eq_delta(self.bbox.xmin, rhs.bbox.xmin)
444 && eq_delta(self.bbox.ymin, rhs.bbox.ymin)
445 && eq_delta(self.bbox.xmax, rhs.bbox.xmax)
446 && eq_delta(self.bbox.ymax, rhs.bbox.ymax)
447 }
448}
449
/// A segmentation mask together with the rectangle it is associated with.
///
/// `segmentation_to_mask` in this file treats a third-axis length of 1
/// differently from larger lengths, so the last axis of `segmentation`
/// presumably holds per-class (or single-channel) values — confirm against
/// the decoders that produce it.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Segmentation {
    /// Left edge of the associated rectangle (the tests in this file use
    /// `0.0..=1.0` for a full-frame mask; presumably normalized — confirm).
    pub xmin: f32,
    /// Top edge of the associated rectangle.
    pub ymin: f32,
    /// Right edge of the associated rectangle.
    pub xmax: f32,
    /// Bottom edge of the associated rectangle.
    pub ymax: f32,
    /// Three-dimensional `u8` mask data.
    pub segmentation: Array3<u8>,
}
487
/// Memory layout tag carried by [`ProtoData`].
///
/// NOTE(review): the variant names suggest the usual batch/height/width/
/// channel orderings, but nothing in this part of the file shows how the tag
/// is consumed — confirm against the code that reads `ProtoData::layout`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProtoLayout {
    /// Presumably channels-last ordering — confirm.
    Nhwc,
    /// Presumably channels-first ordering — confirm.
    Nchw,
}
504
/// Bundle of two dynamically-typed tensors plus a layout tag.
///
/// NOTE(review): presumably the raw inputs needed to build instance masks
/// (per-detection coefficients and prototype masks) — confirm against the
/// producers and consumers of this type, which are not visible here.
#[derive(Debug)]
pub struct ProtoData {
    /// Tensor of mask coefficients.
    pub mask_coefficients: edgefirst_tensor::TensorDyn,
    /// Tensor of prototypes.
    pub protos: edgefirst_tensor::TensorDyn,
    /// Layout tag; presumably describes `protos` — confirm.
    pub layout: ProtoLayout,
}
541
542pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
560 detect: &DetectBoxQuantized<SCORE>,
561 quant_scores: Quantization,
562) -> DetectBox {
563 let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
564 DetectBox {
565 bbox: detect.bbox,
566 score: quant_scores.scale * detect.score.as_() + scaled_zp,
567 label: detect.label,
568 }
569}
/// A detection whose score is still in its quantized integer form.
///
/// `dequant_detect_box` converts this into a [`DetectBox`]; the box
/// coordinates and label are already in their final representation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DetectBoxQuantized<
    // Integer type of the raw score; must be convertible to `f32`.
    SCORE: PrimInt + AsPrimitive<f32>,
> {
    /// Location of the detection, already as floats.
    pub bbox: BoundingBox,
    /// Raw quantized score.
    pub score: SCORE,
    /// Index of the detected class.
    pub label: usize,
}
584
585pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
598 input: ArrayView<T, D>,
599 quant: Quantization,
600) -> Array<F, D>
601where
602 i32: num_traits::AsPrimitive<F>,
603 f32: num_traits::AsPrimitive<F>,
604{
605 let zero_point = quant.zero_point.as_();
606 let scale = quant.scale.as_();
607 if zero_point != F::zero() {
608 let scaled_zero = -zero_point * scale;
609 input.mapv(|d| d.as_() * scale + scaled_zero)
610 } else {
611 input.mapv(|d| d.as_() * scale)
612 }
613}
614
615pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
628 input: &[T],
629 quant: Quantization,
630 output: &mut [F],
631) where
632 f32: num_traits::AsPrimitive<F>,
633 i32: num_traits::AsPrimitive<F>,
634{
635 assert!(input.len() == output.len());
636 let zero_point = quant.zero_point.as_();
637 let scale = quant.scale.as_();
638 if zero_point != F::zero() {
639 let scaled_zero = -zero_point * scale; input
641 .iter()
642 .zip(output)
643 .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
644 } else {
645 input
646 .iter()
647 .zip(output)
648 .for_each(|(d, deq)| *deq = d.as_() * scale);
649 }
650}
651
652pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
666 input: &[T],
667 quant: Quantization,
668 output: &mut [F],
669) where
670 f32: num_traits::AsPrimitive<F>,
671 i32: num_traits::AsPrimitive<F>,
672{
673 assert!(input.len() == output.len());
674 let zero_point = quant.zero_point.as_();
675 let scale = quant.scale.as_();
676
677 let input = input.as_chunks::<4>();
678 let output = output.as_chunks_mut::<4>();
679
680 if zero_point != F::zero() {
681 let scaled_zero = -zero_point * scale; input
684 .0
685 .iter()
686 .zip(output.0)
687 .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
688 input
689 .1
690 .iter()
691 .zip(output.1)
692 .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
693 } else {
694 input
695 .0
696 .iter()
697 .zip(output.0)
698 .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
699 input
700 .1
701 .iter()
702 .zip(output.1)
703 .for_each(|(d, deq)| *deq = d.as_() * scale);
704 }
705}
706
707pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
725 if segmentation.shape()[2] == 0 {
726 return Err(DecoderError::InvalidShape(
727 "Segmentation tensor must have non-zero depth".to_string(),
728 ));
729 }
730 if segmentation.shape()[2] == 1 {
731 yolo_segmentation_to_mask(segmentation, 128)
732 } else {
733 Ok(modelpack_segmentation_to_mask(segmentation))
734 }
735}
736
737fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
739 score
740 .iter()
741 .enumerate()
742 .fold((score[0], 0), |(max, arg_max), (ind, s)| {
743 if max > *s {
744 (max, arg_max)
745 } else {
746 (*s, ind)
747 }
748 })
749}
750
/// Returns the maximum of `scores` and the index of its last occurrence,
/// using NEON instructions for inputs of 16 or more elements.
///
/// Shorter inputs take a scalar path. Both paths report the last index among
/// equal maxima.
///
/// # Panics
/// Panics if `scores` is empty.
#[cfg(target_arch = "aarch64")]
pub(crate) fn arg_max_i8(scores: &[i8]) -> (i8, usize) {
    use std::arch::aarch64::*;

    let len = scores.len();
    if len < 16 {
        // Scalar path; `>=` moves the index forward on ties.
        let mut best = (scores[0], 0usize);
        for (position, &value) in scores.iter().enumerate().skip(1) {
            if value >= best.0 {
                best = (value, position);
            }
        }
        return best;
    }

    let full_blocks = len / 16;
    // SAFETY: `len >= 16`, so `full_blocks >= 1`. Every load reads 16 bytes
    // starting at `block * 16` with `block < full_blocks`, which ends at or
    // before `full_blocks * 16 <= len`, i.e. inside `scores`.
    let simd_max = unsafe {
        let base = scores.as_ptr();
        let mut lanes = vld1q_s8(base);
        for block in 1..full_blocks {
            lanes = vmaxq_s8(lanes, vld1q_s8(base.add(block * 16)));
        }
        // Horizontal reduction of the 16 lane-wise maxima.
        vmaxvq_s8(lanes)
    };

    // Fold in the elements that did not fill a whole 16-byte block.
    let overall = scores[full_blocks * 16..]
        .iter()
        .fold(simd_max, |acc, &value| if value > acc { value } else { acc });

    // Scan from the back so the last occurrence of the maximum is reported.
    let position = scores
        .iter()
        .rposition(|&value| value == overall)
        .unwrap_or(0);
    (overall, position)
}
805#[cfg(test)]
806#[cfg_attr(coverage_nightly, coverage(off))]
807mod decoder_tests {
808 #![allow(clippy::excessive_precision)]
809 use crate::{
810 configs::{DecoderType, DimName, Protos},
811 modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
812 yolo::{
813 decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
814 decode_yolo_segdet_quant,
815 },
816 *,
817 };
818 use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
819 use ndarray::Dimension;
820 use ndarray::{array, s, Array2, Array3, Array4, Axis};
821 use ndarray_stats::DeviationExt;
822 use num_traits::{AsPrimitive, PrimInt};
823
824 fn compare_outputs(
825 boxes: (&[DetectBox], &[DetectBox]),
826 masks: (&[Segmentation], &[Segmentation]),
827 ) {
828 let (boxes0, boxes1) = boxes;
829 let (masks0, masks1) = masks;
830
831 assert_eq!(boxes0.len(), boxes1.len());
832 assert_eq!(masks0.len(), masks1.len());
833
834 for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
835 assert!(
836 b_i8.equal_within_delta(b_f32, 1e-6),
837 "{b_i8:?} is not equal to {b_f32:?}"
838 );
839 }
840
841 for (m_i8, m_f32) in masks0.iter().zip(masks1) {
842 assert_eq!(
843 [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
844 [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
845 );
846 assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
847 let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
848 let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
849 let diff = &mask_i8 - &mask_f32;
850 for x in 0..diff.shape()[0] {
851 for y in 0..diff.shape()[1] {
852 for z in 0..diff.shape()[2] {
853 let val = diff[[x, y, z]];
854 assert!(
855 val.abs() <= 1,
856 "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
857 x,
858 y,
859 z,
860 val
861 );
862 }
863 }
864 }
865 let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
866 assert!(
867 mean_sq_err < 1e-2,
868 "Mean Square Error between masks was greater than 1%: {:.2}%",
869 mean_sq_err * 100.0
870 );
871 }
872 }
873
874 fn load_yolov8_boxes() -> Array3<i8> {
877 let raw = include_bytes!(concat!(
878 env!("CARGO_MANIFEST_DIR"),
879 "/../../testdata/yolov8_boxes_116x8400.bin"
880 ));
881 let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
882 Array3::from_shape_vec((1, 116, 8400), raw.to_vec()).unwrap()
883 }
884
885 fn load_yolov8_protos() -> Array4<i8> {
886 let raw = include_bytes!(concat!(
887 env!("CARGO_MANIFEST_DIR"),
888 "/../../testdata/yolov8_protos_160x160x32.bin"
889 ));
890 let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
891 Array4::from_shape_vec((1, 160, 160, 32), raw.to_vec()).unwrap()
892 }
893
894 fn load_yolov8s_det() -> Array3<i8> {
895 let raw = include_bytes!(concat!(
896 env!("CARGO_MANIFEST_DIR"),
897 "/../../testdata/yolov8s_80_classes.bin"
898 ));
899 let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
900 Array3::from_shape_vec((1, 84, 8400), raw.to_vec()).unwrap()
901 }
902
    /// ModelPack detection with separate box and score tensors: the direct
    /// decode function, the configured decoder's quantized path and its
    /// float path must all produce the same boxes and no masks.
    #[test]
    fn test_decoder_modelpack() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        // Raw fixture bytes, used as-is (u8) with the shapes from the file names.
        let boxes = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_boxes_1935x1x4.bin"
        ));
        let boxes = ndarray::Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();

        let scores = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_scores_1935x1.bin"
        ));
        let scores = ndarray::Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();

        // Types of these bindings are inferred from their use in the configs below.
        let quant_boxes = (0.004656755365431309, 21).into();
        let quant_scores = (0.0019603664986789227, 0).into();

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_det(
                configs::Boxes {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_boxes),
                    shape: vec![1, 1935, 1, 4],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::Padding, 1),
                        (DimName::BoxCoords, 4),
                    ],
                    normalized: Some(true),
                },
                configs::Scores {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_scores),
                    shape: vec![1, 1935, 1],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::NumClasses, 1),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        // Shadow with the type expected by the direct decode call and by
        // `dequantize_ndarray` below.
        let quant_boxes = quant_boxes.into();
        let quant_scores = quant_scores.into();

        // Reference result: call the ModelPack decode function directly.
        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
        decode_modelpack_det(
            (boxes.slice(s![0, .., 0, ..]), quant_boxes),
            (scores.slice(s![0, .., ..]), quant_scores),
            score_threshold,
            iou_threshold,
            &mut output_boxes,
        );
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.40513772,
                    ymin: 0.6379755,
                    xmax: 0.5122431,
                    ymax: 0.7730214,
                },
                score: 0.4861709,
                label: 0
            },
            1e-6
        ));

        // Same inputs through the configured decoder, quantized path.
        let mut output_boxes1 = Vec::with_capacity(50);
        let mut output_masks1 = Vec::with_capacity(50);

        decoder
            .decode_quantized(
                &[boxes.view().into(), scores.view().into()],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();

        // Same inputs dequantized up front, float path.
        let mut output_boxes_float = Vec::with_capacity(50);
        let mut output_masks_float = Vec::with_capacity(50);

        let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
        let scores = dequantize_ndarray(scores.view(), quant_scores);

        decoder
            .decode_float::<f32>(
                &[boxes.view().into_dyn(), scores.view().into_dyn()],
                &mut output_boxes_float,
                &mut output_masks_float,
            )
            .unwrap();

        // Both decoder paths must match the reference and yield no masks.
        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
        compare_outputs(
            (&output_boxes, &output_boxes_float),
            (&[], &output_masks_float),
        );
    }
1008
    /// ModelPack split detection (two anchor-based output tensors, u8 data):
    /// the direct decode function, the configured decoder's quantized path
    /// and its float path must all produce the same boxes and no masks.
    #[test]
    fn test_decoder_modelpack_split_u8() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let detect0 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_9x15x18.bin"
        ));
        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

        let detect1 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_17x30x18.bin"
        ));
        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

        // Types of these bindings are inferred from their use in the configs below.
        let quant0 = (0.08547406643629074, 174).into();
        let quant1 = (0.09929127991199493, 183).into();
        let anchors0 = vec![
            [0.36666667461395264, 0.31481480598449707],
            [0.38749998807907104, 0.4740740656852722],
            [0.5333333611488342, 0.644444465637207],
        ];
        let anchors1 = vec![
            [0.13750000298023224, 0.2074074000120163],
            [0.2541666626930237, 0.21481481194496155],
            [0.23125000298023224, 0.35185185074806213],
        ];

        let detect_config0 = configs::Detection {
            decoder: DecoderType::ModelPack,
            shape: vec![1, 9, 15, 18],
            anchors: Some(anchors0.clone()),
            quantization: Some(quant0),
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::Height, 9),
                (DimName::Width, 15),
                (DimName::NumAnchorsXFeatures, 18),
            ],
            normalized: Some(true),
        };

        let detect_config1 = configs::Detection {
            decoder: DecoderType::ModelPack,
            shape: vec![1, 17, 30, 18],
            anchors: Some(anchors1.clone()),
            quantization: Some(quant1),
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::Height, 17),
                (DimName::Width, 30),
                (DimName::NumAnchorsXFeatures, 18),
            ],
            normalized: Some(true),
        };

        // Converted forms of the configs, consumed by the direct decode call below.
        let config0 = (&detect_config0).try_into().unwrap();
        let config1 = (&detect_config1).try_into().unwrap();

        // Note the configs are given in the opposite order to the tensors
        // passed to `decode_quantized` / `decode_float` further down.
        let decoder = DecoderBuilder::default()
            .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        // Shadow with the type expected by `dequantize_ndarray` below.
        let quant0 = quant0.into();
        let quant1 = quant1.into();

        // Reference result: call the split decode function directly.
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        decode_modelpack_split_quant(
            &[
                detect0.slice(s![0, .., .., ..]),
                detect1.slice(s![0, .., .., ..]),
            ],
            &[config0, config1],
            score_threshold,
            iou_threshold,
            &mut output_boxes,
        );
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.43171933,
                    ymin: 0.68243736,
                    xmax: 0.5626645,
                    ymax: 0.808863,
                },
                score: 0.99240804,
                label: 0
            },
            1e-6
        ));

        // Same inputs through the configured decoder, quantized path.
        let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
        let mut output_masks1: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[detect0.view().into(), detect1.view().into()],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();

        // Same inputs dequantized up front, float path.
        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);

        let detect0 = dequantize_ndarray(detect0.view(), quant0);
        let detect1 = dequantize_ndarray(detect1.view(), quant1);
        decoder
            .decode_float::<f32>(
                &[detect0.view().into_dyn(), detect1.view().into_dyn()],
                &mut output_boxes1_f32,
                &mut output_masks1_f32,
            )
            .unwrap();

        // Both decoder paths must match the reference and yield no masks.
        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
        compare_outputs(
            (&output_boxes, &output_boxes1_f32),
            (&[], &output_masks1_f32),
        );
    }
1133
    /// Builds the ModelPack split decoder from a YAML configuration string
    /// and checks the first decoded box against a fixed expected value.
    #[test]
    fn test_decoder_parse_config_modelpack_split_u8() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let detect0 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_9x15x18.bin"
        ));
        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

        let detect1 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_17x30x18.bin"
        ));
        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

        // Decoder configuration comes from the YAML fixture rather than
        // hand-built config structs.
        let decoder = DecoderBuilder::default()
            .with_config_yaml_str(
                include_str!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/../../testdata/modelpack_split.yaml"
                ))
                .to_string(),
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);
        // Tensors are passed as detect1 then detect0.
        decoder
            .decode_quantized(
                &[
                    ArrayViewDQuantized::from(detect1.view()),
                    ArrayViewDQuantized::from(detect0.view()),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.43171933,
                    ymin: 0.68243736,
                    xmax: 0.5626645,
                    ymax: 0.808863,
                },
                score: 0.99240804,
                label: 0
            },
            1e-6
        ));
    }
1189
    /// ModelPack segmentation-only decoding: the quantized path must
    /// reproduce the input tensor (with axes reordered) as the mask, and the
    /// float path must reduce to the same 2-D mask.
    #[test]
    fn test_modelpack_seg() {
        let out = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let out = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
        let quant = (1.0 / 255.0, 0).into();

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_seg(configs::Segmentation {
                decoder: DecoderType::ModelPack,
                quantization: Some(quant),
                shape: vec![1, 2, 160, 160],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 2),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                ],
            })
            .build()
            .unwrap();
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
            .unwrap();

        // Expected mask: drop the batch axis and move the first remaining
        // axis to the end via two axis swaps.
        let mut mask = out.slice(s![0, .., .., ..]);
        mask.swap_axes(0, 1);
        mask.swap_axes(1, 2);
        let mask = [Segmentation {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.0,
            ymax: 1.0,
            segmentation: mask.into_owned(),
        }];
        // No boxes are expected from a segmentation-only decoder.
        compare_outputs((&[], &output_boxes), (&mask, &output_masks));

        decoder
            .decode_float::<f32>(
                &[dequantize_ndarray(out.view(), quant.into())
                    .view()
                    .into_dyn()],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Only the (empty) boxes are compared directly for the float path;
        // the masks are compared after reduction with `segmentation_to_mask`.
        compare_outputs((&[], &output_boxes), (&[], &[]));
        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

        assert_eq!(mask0, mask1);
    }
    /// ModelPack segmentation decoding across integer input types: the same
    /// fixture re-expressed as i8, u16, i16, u32 and i32 must reduce to the
    /// same 2-D mask as the original u8 data.
    #[test]
    fn test_modelpack_seg_quant() {
        let out = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let out_u8 = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
        // Signed variants subtract half the range; wider variants shift the
        // u8 value into the top byte.
        let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
        let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
        let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
        let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
        let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);

        let quant = (1.0 / 255.0, 0).into();

        // One decoder, with a single quantization config, is reused for
        // every input type.
        let decoder = DecoderBuilder::default()
            .with_config_modelpack_seg(configs::Segmentation {
                decoder: DecoderType::ModelPack,
                quantization: Some(quant),
                shape: vec![1, 2, 160, 160],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 2),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                ],
            })
            .build()
            .unwrap();
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_u8.view().into()],
                &mut output_boxes,
                &mut output_masks_u8,
            )
            .unwrap();

        let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_i8.view().into()],
                &mut output_boxes,
                &mut output_masks_i8,
            )
            .unwrap();

        let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_u16.view().into()],
                &mut output_boxes,
                &mut output_masks_u16,
            )
            .unwrap();

        let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_i16.view().into()],
                &mut output_boxes,
                &mut output_masks_i16,
            )
            .unwrap();

        let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_u32.view().into()],
                &mut output_boxes,
                &mut output_masks_u32,
            )
            .unwrap();

        let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_i32.view().into()],
                &mut output_boxes,
                &mut output_masks_i32,
            )
            .unwrap();

        // No boxes are expected; masks are compared after reduction with
        // `segmentation_to_mask`, always against the u8 result.
        compare_outputs((&[], &output_boxes), (&[], &[]));
        let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
        let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
        let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
        let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
        let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
        let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();
        assert_eq!(mask_u8, mask_i8);
        assert_eq!(mask_u8, mask_u16);
        assert_eq!(mask_u8, mask_i16);
        assert_eq!(mask_u8, mask_u32);
        assert_eq!(mask_u8, mask_i32);
    }
1349
    /// ModelPack combined detection + segmentation: the quantized path must
    /// yield the expected single box and a mask equal to the reordered input
    /// tensor; the float path must yield the same box and reduce to the same
    /// 2-D mask.
    #[test]
    fn test_modelpack_segdet() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;

        let boxes = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_boxes_1935x1x4.bin"
        ));
        let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();

        let scores = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_scores_1935x1.bin"
        ));
        let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();

        let seg = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

        // Types of these bindings are inferred from their use in the configs below.
        let quant_boxes = (0.004656755365431309, 21).into();
        let quant_scores = (0.0019603664986789227, 0).into();
        let quant_seg = (1.0 / 255.0, 0).into();

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_segdet(
                configs::Boxes {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_boxes),
                    shape: vec![1, 1935, 1, 4],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::Padding, 1),
                        (DimName::BoxCoords, 4),
                    ],
                    normalized: Some(true),
                },
                configs::Scores {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_scores),
                    shape: vec![1, 1935, 1],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::NumClasses, 1),
                    ],
                },
                configs::Segmentation {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_seg),
                    shape: vec![1, 2, 160, 160],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumClasses, 2),
                        (DimName::Height, 160),
                        (DimName::Width, 160),
                    ],
                },
            )
            .with_iou_threshold(iou_threshold)
            .with_score_threshold(score_threshold)
            .build()
            .unwrap();
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);
        // Tensors are passed as scores, boxes, seg — not the order in which
        // the configs were given to the builder.
        decoder
            .decode_quantized(
                &[scores.view().into(), boxes.view().into(), seg.view().into()],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Expected mask: drop the batch axis and move the first remaining
        // axis to the end via two axis swaps.
        let mut mask = seg.slice(s![0, .., .., ..]);
        mask.swap_axes(0, 1);
        mask.swap_axes(1, 2);
        let mask = [Segmentation {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.0,
            ymax: 1.0,
            segmentation: mask.into_owned(),
        }];
        let correct_boxes = [DetectBox {
            bbox: BoundingBox {
                xmin: 0.40513772,
                ymin: 0.6379755,
                xmax: 0.5122431,
                ymax: 0.7730214,
            },
            score: 0.4861709,
            label: 0,
        }];
        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));

        // Float path: dequantize every tensor up front.
        let scores = dequantize_ndarray(scores.view(), quant_scores.into());
        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
        decoder
            .decode_float::<f32>(
                &[
                    scores.view().into_dyn(),
                    boxes.view().into_dyn(),
                    seg.view().into_dyn(),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Only the boxes are compared directly for the float path; the masks
        // are compared after reduction with `segmentation_to_mask`.
        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

        assert_eq!(mask0, mask1);
    }
1475
#[test]
/// Decodes a ModelPack model made of two split detection heads plus a
/// segmentation head, first from the quantized tensors and then from their
/// dequantized `f32` counterparts, and checks both runs against one reference
/// box and the reference mask.
fn test_modelpack_segdet_split() {
    let score_threshold = 0.8;
    let iou_threshold = 0.5;

    // Segmentation head output, reshaped to (1, 2, 160, 160).
    let seg = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_seg_2x160x160.bin"
    ));
    let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

    // Detection head on a 9x15 grid with 18 values per cell.
    let detect0 = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split_9x15x18.bin"
    ));
    let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

    // Detection head on a 17x30 grid with 18 values per cell.
    let detect1 = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split_17x30x18.bin"
    ));
    let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

    // Quantization parameters as (scale, zero_point) pairs. The concrete type
    // of each binding is inferred from its first use in the configs below, so
    // these must stay ahead of the later `.into()` conversions.
    let quant0 = (0.08547406643629074, 174).into();
    let quant1 = (0.09929127991199493, 183).into();
    let quant_seg = (1.0 / 255.0, 0).into();

    // Three anchors per detection head.
    let anchors0 = vec![
        [0.36666667461395264, 0.31481480598449707],
        [0.38749998807907104, 0.4740740656852722],
        [0.5333333611488342, 0.644444465637207],
    ];
    let anchors1 = vec![
        [0.13750000298023224, 0.2074074000120163],
        [0.2541666626930237, 0.21481481194496155],
        [0.23125000298023224, 0.35185185074806213],
    ];

    // NOTE(review): the 17x30 head is configured first although the inputs
    // below are passed 9x15 first; presumably the decoder matches tensors to
    // configs by shape rather than by position - confirm.
    let decoder = DecoderBuilder::default()
        .with_config_modelpack_segdet_split(
            vec![
                configs::Detection {
                    decoder: DecoderType::ModelPack,
                    shape: vec![1, 17, 30, 18],
                    anchors: Some(anchors1),
                    quantization: Some(quant1),
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::Height, 17),
                        (DimName::Width, 30),
                        (DimName::NumAnchorsXFeatures, 18),
                    ],
                    normalized: Some(true),
                },
                configs::Detection {
                    decoder: DecoderType::ModelPack,
                    shape: vec![1, 9, 15, 18],
                    anchors: Some(anchors0),
                    quantization: Some(quant0),
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::Height, 9),
                        (DimName::Width, 15),
                        (DimName::NumAnchorsXFeatures, 18),
                    ],
                    normalized: Some(true),
                },
            ],
            configs::Segmentation {
                decoder: DecoderType::ModelPack,
                quantization: Some(quant_seg),
                shape: vec![1, 2, 160, 160],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 2),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                ],
            },
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();
    let mut output_boxes: Vec<_> = Vec::with_capacity(10);
    let mut output_masks: Vec<_> = Vec::with_capacity(10);
    // First pass: decode straight from the quantized tensors.
    decoder
        .decode_quantized(
            &[
                detect0.view().into(),
                detect1.view().into(),
                seg.view().into(),
            ],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    // Reference mask: drop the batch axis, then reorder the remaining axes
    // from (classes, height, width) to (height, width, classes).
    let mut mask = seg.slice(s![0, .., .., ..]);
    mask.swap_axes(0, 1);
    mask.swap_axes(1, 2);
    // The reference segmentation spans the whole normalized frame.
    let mask = [Segmentation {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 1.0,
        ymax: 1.0,
        segmentation: mask.into_owned(),
    }];
    let correct_boxes = [DetectBox {
        bbox: BoundingBox {
            xmin: 0.43171933,
            ymin: 0.68243736,
            xmax: 0.5626645,
            ymax: 0.808863,
        },
        score: 0.99240804,
        label: 0,
    }];
    println!("Output Boxes: {:?}", output_boxes);
    compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));

    // Second pass: dequantize every tensor and decode through the float path,
    // reusing the same output vectors.
    let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
    let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
    let seg = dequantize_ndarray(seg.view(), quant_seg.into());
    decoder
        .decode_float::<f32>(
            &[
                detect0.view().into_dyn(),
                detect1.view().into_dyn(),
                seg.view().into_dyn(),
            ],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    // Only the boxes go through `compare_outputs` here; the masks are compared
    // below after both sides pass through `segmentation_to_mask`.
    // NOTE(review): presumably the raw float mask values differ from the
    // quantized reference while the derived masks agree - confirm.
    compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
    let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
    let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

    assert_eq!(mask0, mask1);
}
1623
#[test]
/// `dequantize_cpu_chunked` must produce exactly the same values as
/// `dequantize_cpu`, both with a non-zero and with a zero zero-point.
fn test_dequant_chunked() {
    let (mut quantized, _offset) = load_yolov8s_det().into_raw_vec_and_offset();
    // One extra element on top of the 84 * 8400 model outputs.
    // NOTE(review): presumably this makes the length odd so the chunked
    // implementation has to handle a remainder - confirm.
    quantized.push(123);

    let mut reference = vec![0.0; 84 * 8400 + 1];
    let mut chunked = vec![0.0; 84 * 8400 + 1];

    // The output buffers are reused between the two runs.
    for zero_point in [-123, 0] {
        let quant = Quantization::new(0.0040811873, zero_point);
        dequantize_cpu(&quantized, quant, &mut reference);
        dequantize_cpu_chunked(&quantized, quant, &mut chunked);
        assert_eq!(reference, chunked);
    }
}
1643
#[test]
/// Checks `dequantize_cpu` and `dequantize_cpu_chunked` against hand-computed
/// values of `(q - zero_point) * scale`.
fn test_dequant_ground_truth() {
    let input: Vec<i8> = vec![0, 127, -128, 64];
    let mut output = vec![0.0f32; 4];
    let mut output_chunked = vec![0.0f32; 4];

    // scale 0.1, zero point -128: compared with a tolerance because the
    // expected values are not exactly representable.
    let quant = Quantization::new(0.1, -128);
    dequantize_cpu(&input, quant, &mut output);
    dequantize_cpu_chunked(&input, quant, &mut output_chunked);
    let expected: Vec<f32> = vec![12.8, 25.5, 0.0, 19.2];
    for (i, (out, exp)) in output
        .iter()
        .copied()
        .zip(expected.iter().copied())
        .enumerate()
    {
        assert!((out - exp).abs() < 1e-5, "cpu[{i}]: {out} != {exp}");
    }
    for (i, (out, exp)) in output_chunked
        .iter()
        .copied()
        .zip(expected.iter().copied())
        .enumerate()
    {
        assert!((out - exp).abs() < 1e-5, "chunked[{i}]: {out} != {exp}");
    }

    // Zero zero-point with scales 1.0 and 0.5: results are compared exactly.
    for (scale, expected) in [
        (1.0, vec![0.0, 127.0, -128.0, 64.0]),
        (0.5, vec![0.0, 63.5, -64.0, 32.0]),
    ] {
        let quant = Quantization::new(scale, 0);
        dequantize_cpu(&input, quant, &mut output);
        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
        assert_eq!(output, expected);
        assert_eq!(output_chunked, expected);
    }

    // Same scale and zero point as the box tensor used by other tests in
    // this file, sampled at the extremes and around the zero point.
    let quant = Quantization::new(0.021287762, 31);
    let input: Vec<i8> = vec![-128, -1, 0, 1, 31, 127];
    let mut output = vec![0.0f32; 6];
    let mut output_chunked = vec![0.0f32; 6];
    dequantize_cpu(&input, quant, &mut output);
    dequantize_cpu_chunked(&input, quant, &mut output_chunked);
    for (i, ((&raw, &cpu), &chunked)) in input
        .iter()
        .zip(output.iter())
        .zip(output_chunked.iter())
        .enumerate()
    {
        let expected = (raw as f32 - 31.0) * 0.021287762;
        assert!(
            (cpu - expected).abs() < 1e-5,
            "cpu[{i}]: {} != {expected}",
            cpu
        );
        assert!(
            (chunked - expected).abs() < 1e-5,
            "chunked[{i}]: {} != {expected}",
            chunked
        );
    }
}
1705
#[test]
/// Decodes a YOLO detection tensor three ways (the free `decode_yolo_det`
/// function, `Decoder::decode_quantized`, `Decoder::decode_float`) and checks
/// that all three agree, with the first two boxes pinned to known values.
fn test_decoder_yolo_det() {
    let score_threshold = 0.25;
    let iou_threshold = 0.7;
    let raw = load_yolov8s_det();
    // The concrete type is inferred from the config field below.
    let quant = (0.0040811873, -123).into();

    let detection_cfg = configs::Detection {
        decoder: DecoderType::Ultralytics,
        shape: vec![1, 84, 8400],
        anchors: None,
        quantization: Some(quant),
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumFeatures, 84),
            (DimName::NumBoxes, 8400),
        ],
        normalized: Some(true),
    };
    let decoder = DecoderBuilder::default()
        .with_config_yolo_det(detection_cfg, Some(DecoderVersion::Yolo11))
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let expected_first = DetectBox {
        bbox: BoundingBox {
            xmin: 0.5285137,
            ymin: 0.05305544,
            xmax: 0.87541467,
            ymax: 0.9998909,
        },
        score: 0.5591227,
        label: 0,
    };
    let expected_second = DetectBox {
        bbox: BoundingBox {
            xmin: 0.130598,
            ymin: 0.43260583,
            xmax: 0.35098213,
            ymax: 0.9958097,
        },
        score: 0.33057618,
        label: 75,
    };

    // Reference run through the free function on the batch-0 slice.
    let mut reference_boxes: Vec<_> = Vec::with_capacity(50);
    decode_yolo_det(
        (raw.slice(s![0, .., ..]), quant.into()),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut reference_boxes,
    );
    assert!(reference_boxes[0].equal_within_delta(&expected_first, 1e-6));
    assert!(reference_boxes[1].equal_within_delta(&expected_second, 1e-6));

    // Config-driven decoder, quantized input.
    let mut quantized_boxes: Vec<_> = Vec::with_capacity(50);
    let mut quantized_masks: Vec<_> = Vec::with_capacity(50);
    decoder
        .decode_quantized(
            &[raw.view().into()],
            &mut quantized_boxes,
            &mut quantized_masks,
        )
        .unwrap();

    // Config-driven decoder, dequantized input.
    let raw_f32 = dequantize_ndarray(raw.view(), quant.into());
    let mut float_boxes: Vec<_> = Vec::with_capacity(50);
    let mut float_masks: Vec<_> = Vec::with_capacity(50);
    decoder
        .decode_float::<f32>(
            &[raw_f32.view().into_dyn()],
            &mut float_boxes,
            &mut float_masks,
        )
        .unwrap();

    compare_outputs((&reference_boxes, &quantized_boxes), (&[], &quantized_masks));
    compare_outputs((&reference_boxes, &float_boxes), (&[], &float_masks));
}
1790
#[test]
/// Runs `decode_yolo_segdet_float` on dequantized YOLO boxes and protos and
/// checks the two detections, the box/mask extents, and the second mask
/// against a stored 160x160 reference mask.
fn test_decoder_masks() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let boxes_raw = load_yolov8_boxes();
    let quant_boxes = Quantization::new(0.021287761628627777, 31);
    let protos_raw = load_yolov8_protos();
    let quant_protos = Quantization::new(0.02491161972284317, -117);

    let protos_f32 = dequantize_ndarray::<_, _, f32>(protos_raw.view(), quant_protos);
    let boxes_f32 = dequantize_ndarray::<_, _, f32>(boxes_raw.view(), quant_boxes);

    let mut output_boxes: Vec<_> = Vec::with_capacity(10);
    let mut output_masks: Vec<_> = Vec::with_capacity(10);
    decode_yolo_segdet_float(
        boxes_f32.slice(s![0, .., ..]),
        protos_f32.slice(s![0, .., .., ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes,
        &mut output_masks,
    )
    .unwrap();
    assert_eq!(output_boxes.len(), 2);
    assert_eq!(output_boxes.len(), output_masks.len());

    // Every mask extent must enclose its detection box.
    for (det, seg_mask) in output_boxes.iter().zip(output_masks.iter()) {
        assert!(det.bbox.xmin >= seg_mask.xmin);
        assert!(det.bbox.ymin >= seg_mask.ymin);
        assert!(det.bbox.xmax <= seg_mask.xmax);
        assert!(det.bbox.ymax <= seg_mask.ymax);
    }

    // Boxes are compared with a tolerance of one cell of the 160x160 grid.
    let expected_first = DetectBox {
        bbox: BoundingBox {
            xmin: 0.08515105,
            ymin: 0.7131401,
            xmax: 0.29802868,
            ymax: 0.8195788,
        },
        score: 0.91537374,
        label: 23,
    };
    let expected_second = DetectBox {
        bbox: BoundingBox {
            xmin: 0.59605736,
            ymin: 0.25545314,
            xmax: 0.93666154,
            ymax: 0.72378385,
        },
        score: 0.91537374,
        label: 23,
    };
    assert!(output_boxes[0].equal_within_delta(&expected_first, 1.0 / 160.0));
    assert!(output_boxes[1].equal_within_delta(&expected_second, 1.0 / 160.0));

    let full_mask = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_mask_results.bin"
    ));
    let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();

    // Crop the stored full-frame mask to the extent of the second mask,
    // truncating the scaled coordinates to pixel indices.
    let second = &output_masks[1];
    let top = (second.ymin * 160.0) as usize;
    let bottom = (second.ymax * 160.0) as usize;
    let left = (second.xmin * 160.0) as usize;
    let right = (second.xmax * 160.0) as usize;
    let cropped_mask = full_mask.slice(ndarray::s![top..bottom, left..right]);

    assert_eq!(
        cropped_mask,
        segmentation_to_mask(second.segmentation.view()).unwrap()
    );
}
1869
#[test]
/// Feeds the config-driven decoder protos declared as (batch, protos, height,
/// width) and checks that it yields the same boxes and masks as
/// `decode_yolo_segdet_float` run directly on the (height, width, protos)
/// layout.
fn test_decoder_masks_nchw_protos() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    // Dequantized inputs with the batch axis removed.
    let quant_boxes = Quantization::new(0.021287761628627777, 31);
    let quant_protos = Quantization::new(0.02491161972284317, -117);
    let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
    let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
    let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
    let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);

    // Reference result from the direct float decode.
    let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
    let mut ref_masks: Vec<_> = Vec::with_capacity(10);
    decode_yolo_segdet_float(
        seg.view(),
        protos_f32_hwc.view(),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut ref_boxes,
        &mut ref_masks,
    )
    .unwrap();
    assert_eq!(ref_boxes.len(), 2);

    // Move the proto axis to the front, materialize the permuted data, and
    // restore a leading batch axis on both tensors.
    let protos_nchw = protos_f32_hwc
        .view()
        .permuted_axes([2, 0, 1])
        .to_owned()
        .insert_axis(ndarray::Axis(0));
    let seg_3d = seg.insert_axis(ndarray::Axis(0));

    let detection_cfg = configs::Detection {
        decoder: configs::DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 116, 8400],
        dshape: vec![
            (configs::DimName::Batch, 1),
            (configs::DimName::NumFeatures, 116),
            (configs::DimName::NumBoxes, 8400),
        ],
        normalized: Some(true),
        anchors: None,
    };
    let protos_cfg = configs::Protos {
        decoder: configs::DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 32, 160, 160],
        dshape: vec![
            (configs::DimName::Batch, 1),
            (configs::DimName::NumProtos, 32),
            (configs::DimName::Height, 160),
            (configs::DimName::Width, 160),
        ],
    };
    let decoder = DecoderBuilder::default()
        .with_config_yolo_segdet(detection_cfg, protos_cfg, None)
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let mut cfg_boxes: Vec<_> = Vec::with_capacity(10);
    let mut cfg_masks: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_float(
            &[seg_3d.view().into_dyn(), protos_nchw.view().into_dyn()],
            &mut cfg_boxes,
            &mut cfg_masks,
        )
        .unwrap();

    assert_eq!(
        cfg_boxes.len(),
        ref_boxes.len(),
        "config path produced {} boxes, reference produced {}",
        cfg_boxes.len(),
        ref_boxes.len()
    );

    for (i, (cb, rb)) in cfg_boxes.iter().zip(ref_boxes.iter()).enumerate() {
        assert!(
            cb.equal_within_delta(rb, 0.01),
            "box {i} mismatch: config={cb:?}, reference={rb:?}"
        );
    }

    // Masks must agree pixel for pixel.
    for (i, (cfg_mask, ref_mask)) in cfg_masks.iter().zip(ref_masks.iter()).enumerate() {
        let cm_arr = segmentation_to_mask(cfg_mask.segmentation.view()).unwrap();
        let rm_arr = segmentation_to_mask(ref_mask.segmentation.view()).unwrap();
        assert_eq!(
            cm_arr, rm_arr,
            "mask {i} pixel mismatch between config-driven and reference paths"
        );
    }
}
1993
#[test]
/// Cross-checks four decode paths on the same YOLO segmentation data: the
/// free quantized and float functions, and the config-driven decoder with
/// quantized and with dequantized inputs.
fn test_decoder_masks_i8() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;
    let boxes = load_yolov8_boxes();
    // The type of each quant binding is inferred from the config field it is
    // stored in below.
    let quant_boxes = (0.021287761628627777, 31).into();

    let protos = load_yolov8_protos();
    let quant_protos = (0.02491161972284317, -117).into();
    let mut output_boxes: Vec<_> = Vec::with_capacity(500);
    let mut output_masks: Vec<_> = Vec::with_capacity(500);

    // Protos are declared as (batch, height, width, protos).
    let decoder = DecoderBuilder::default()
        .with_config_yolo_segdet(
            configs::Detection {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(quant_boxes),
                shape: vec![1, 116, 8400],
                anchors: None,
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumFeatures, 116),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            Protos {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(quant_protos),
                shape: vec![1, 160, 160, 32],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                    (DimName::NumProtos, 32),
                ],
            },
            Some(DecoderVersion::Yolo11),
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Re-bind the quantization parameters in the type expected by the free
    // decode and dequantize functions (inferred from the calls below).
    let quant_boxes = quant_boxes.into();
    let quant_protos = quant_protos.into();

    // Path 1: free function on the quantized tensors (batch 0).
    decode_yolo_segdet_quant(
        (boxes.slice(s![0, .., ..]), quant_boxes),
        (protos.slice(s![0, .., .., ..]), quant_protos),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes,
        &mut output_masks,
    )
    .unwrap();

    let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
    let mut output_masks1: Vec<_> = Vec::with_capacity(500);

    // Path 2: config-driven decoder on the quantized tensors.
    decoder
        .decode_quantized(
            &[boxes.view().into(), protos.view().into()],
            &mut output_boxes1,
            &mut output_masks1,
        )
        .unwrap();

    let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
    let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);

    let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
    let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
    // Path 3: free function on the dequantized tensors (batch 0).
    decode_yolo_segdet_float(
        seg.slice(s![0, .., ..]),
        protos.slice(s![0, .., .., ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes_f32,
        &mut output_masks_f32,
    )
    .unwrap();

    let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
    let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);

    // Path 4: config-driven decoder on the dequantized tensors.
    decoder
        .decode_float(
            &[seg.view().into_dyn(), protos.view().into_dyn()],
            &mut output_boxes1_f32,
            &mut output_masks1_f32,
        )
        .unwrap();

    // Free quantized vs config-driven quantized.
    compare_outputs(
        (&output_boxes, &output_boxes1),
        (&output_masks, &output_masks1),
    );

    // Free quantized vs free float.
    compare_outputs(
        (&output_boxes, &output_boxes_f32),
        (&output_masks, &output_masks_f32),
    );

    // Free float vs config-driven float.
    compare_outputs(
        (&output_boxes_f32, &output_boxes1_f32),
        (&output_masks_f32, &output_masks1_f32),
    );
}
2105
#[test]
/// Decodes a YOLO model whose boxes and scores are supplied as separate
/// 16-bit tensors and checks the quantized and float decoder paths against
/// `decode_yolo_det_float` on the combined tensor.
fn test_decoder_yolo_split() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    // Widen every stored value to i16 and multiply by 256; the scale is
    // divided and the zero point multiplied by 256 to match.
    let widened: Vec<_> = load_yolov8_boxes()
        .iter()
        .map(|&v| i16::from(v) * 256)
        .collect();
    let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), widened).unwrap();
    let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);

    let boxes_cfg = configs::Boxes {
        decoder: configs::DecoderType::Ultralytics,
        quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
        shape: vec![1, 4, 8400],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::BoxCoords, 4),
            (DimName::NumBoxes, 8400),
        ],
        normalized: Some(true),
    };
    let scores_cfg = configs::Scores {
        decoder: configs::DecoderType::Ultralytics,
        quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
        shape: vec![1, 80, 8400],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumClasses, 80),
            (DimName::NumBoxes, 8400),
        ],
    };
    let decoder = DecoderBuilder::default()
        .with_config_yolo_split_det(boxes_cfg, scores_cfg)
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Quantized path: rows 0..4 are the boxes, rows 4..84 the scores.
    let mut quantized_boxes: Vec<_> = Vec::with_capacity(500);
    let mut quantized_masks: Vec<_> = Vec::with_capacity(500);
    decoder
        .decode_quantized(
            &[
                boxes.slice(s![.., ..4, ..]).into(),
                boxes.slice(s![.., 4..84, ..]).into(),
            ],
            &mut quantized_boxes,
            &mut quantized_masks,
        )
        .unwrap();

    // Reference: free float decode on the first 84 rows of batch 0.
    let boxes_f32 = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
    let mut reference_boxes: Vec<_> = Vec::with_capacity(500);
    decode_yolo_det_float(
        boxes_f32.slice(s![0, ..84, ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut reference_boxes,
    );

    // Float path through the config-driven decoder.
    let mut float_boxes: Vec<_> = Vec::with_capacity(500);
    let mut float_masks: Vec<_> = Vec::with_capacity(500);
    decoder
        .decode_float(
            &[
                boxes_f32.slice(s![.., ..4, ..]).into_dyn(),
                boxes_f32.slice(s![.., 4..84, ..]).into_dyn(),
            ],
            &mut float_boxes,
            &mut float_masks,
        )
        .unwrap();

    compare_outputs((&quantized_boxes, &reference_boxes), (&quantized_masks, &[]));
    compare_outputs((&reference_boxes, &float_boxes), (&[], &float_masks));
}
2185
#[test]
/// Split YOLO segmentation decode with mixed input widths: 16-bit boxes,
/// scores and mask coefficients together with the protos as loaded. Both the
/// quantized and float decoder paths are checked against
/// `decode_yolo_segdet_float`.
fn test_decoder_masks_config_mixed() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    // Widen every stored value to i16 and multiply by 256; the scale is
    // divided and the zero point multiplied by 256 to match.
    let widened: Vec<_> = load_yolov8_boxes()
        .iter()
        .map(|&v| i16::from(v) * 256)
        .collect();
    let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), widened).unwrap();
    let quant_boxes = (0.021287761628627777 / 256.0, 31 * 256);

    let protos = load_yolov8_protos();
    let quant_protos = (0.02491161972284317, -117);

    let decoder = build_yolo_split_segdet_decoder(
        score_threshold,
        iou_threshold,
        quant_boxes,
        quant_protos,
    );

    // Quantized path: rows 0..4 boxes, 4..84 scores, 84.. mask coefficients.
    let mut quantized_boxes: Vec<_> = Vec::with_capacity(500);
    let mut quantized_masks: Vec<_> = Vec::with_capacity(500);
    decoder
        .decode_quantized(
            &[
                boxes.slice(s![.., ..4, ..]).into(),
                boxes.slice(s![.., 4..84, ..]).into(),
                boxes.slice(s![.., 84.., ..]).into(),
                protos.view().into(),
            ],
            &mut quantized_boxes,
            &mut quantized_masks,
        )
        .unwrap();

    // Reference: free float decode on the dequantized tensors (batch 0).
    let protos_f32 = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos.into());
    let boxes_f32 = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes.into());
    let mut reference_boxes: Vec<_> = Vec::with_capacity(500);
    let mut reference_masks: Vec<_> = Vec::with_capacity(500);
    decode_yolo_segdet_float(
        boxes_f32.slice(s![0, .., ..]),
        protos_f32.slice(s![0, .., .., ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut reference_boxes,
        &mut reference_masks,
    )
    .unwrap();

    // Float path through the config-driven decoder.
    let mut float_boxes: Vec<_> = Vec::with_capacity(500);
    let mut float_masks: Vec<_> = Vec::with_capacity(500);
    decoder
        .decode_float(
            &[
                boxes_f32.slice(s![.., ..4, ..]).into_dyn(),
                boxes_f32.slice(s![.., 4..84, ..]).into_dyn(),
                boxes_f32.slice(s![.., 84.., ..]).into_dyn(),
                protos_f32.view().into_dyn(),
            ],
            &mut float_boxes,
            &mut float_masks,
        )
        .unwrap();

    compare_outputs(
        (&quantized_boxes, &reference_boxes),
        (&quantized_masks, &reference_masks),
    );
    compare_outputs(
        (&reference_boxes, &float_boxes),
        (&reference_masks, &float_masks),
    );
}
2260
2261 fn build_yolo_split_segdet_decoder(
2262 score_threshold: f32,
2263 iou_threshold: f32,
2264 quant_boxes: (f32, i32),
2265 quant_protos: (f32, i32),
2266 ) -> crate::Decoder {
2267 DecoderBuilder::default()
2268 .with_config_yolo_split_segdet(
2269 configs::Boxes {
2270 decoder: configs::DecoderType::Ultralytics,
2271 quantization: Some(quant_boxes.into()),
2272 shape: vec![1, 4, 8400],
2273 dshape: vec![
2274 (DimName::Batch, 1),
2275 (DimName::BoxCoords, 4),
2276 (DimName::NumBoxes, 8400),
2277 ],
2278 normalized: Some(true),
2279 },
2280 configs::Scores {
2281 decoder: configs::DecoderType::Ultralytics,
2282 quantization: Some(quant_boxes.into()),
2283 shape: vec![1, 80, 8400],
2284 dshape: vec![
2285 (DimName::Batch, 1),
2286 (DimName::NumClasses, 80),
2287 (DimName::NumBoxes, 8400),
2288 ],
2289 },
2290 configs::MaskCoefficients {
2291 decoder: configs::DecoderType::Ultralytics,
2292 quantization: Some(quant_boxes.into()),
2293 shape: vec![1, 32, 8400],
2294 dshape: vec![
2295 (DimName::Batch, 1),
2296 (DimName::NumProtos, 32),
2297 (DimName::NumBoxes, 8400),
2298 ],
2299 },
2300 configs::Protos {
2301 decoder: configs::DecoderType::Ultralytics,
2302 quantization: Some(quant_protos.into()),
2303 shape: vec![1, 160, 160, 32],
2304 dshape: vec![
2305 (DimName::Batch, 1),
2306 (DimName::Height, 160),
2307 (DimName::Width, 160),
2308 (DimName::NumProtos, 32),
2309 ],
2310 },
2311 )
2312 .with_score_threshold(score_threshold)
2313 .with_iou_threshold(iou_threshold)
2314 .build()
2315 .unwrap()
2316 }
2317
2318 fn build_yolov8_seg_decoder(score_threshold: f32, iou_threshold: f32) -> crate::Decoder {
2319 let config_yaml = include_str!(concat!(
2320 env!("CARGO_MANIFEST_DIR"),
2321 "/../../testdata/yolov8_seg.yaml"
2322 ));
2323 DecoderBuilder::default()
2324 .with_config_yaml_str(config_yaml.to_string())
2325 .with_score_threshold(score_threshold)
2326 .with_iou_threshold(iou_threshold)
2327 .build()
2328 .unwrap()
2329 }
#[test]
/// Split YOLO segmentation decode with every tensor widened to `i32`: the
/// quantized decoder path must match `decode_yolo_segdet_float` run on the
/// dequantized data.
fn test_decoder_masks_config_i32() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    // Every stored value is multiplied by 2^23; scales are divided and zero
    // points multiplied by the same factor to match.
    let scale = 1 << 23;

    let widened_boxes: Vec<_> = load_yolov8_boxes()
        .iter()
        .map(|&v| i32::from(v) * scale)
        .collect();
    let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), widened_boxes).unwrap();
    let quant_boxes = (0.021287761628627777 / scale as f32, 31 * scale);

    let widened_protos: Vec<_> = load_yolov8_protos()
        .iter()
        .map(|&v| i32::from(v) * scale)
        .collect();
    let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), widened_protos).unwrap();
    let quant_protos = (0.02491161972284317 / scale as f32, -117 * scale);

    let decoder = build_yolo_split_segdet_decoder(
        score_threshold,
        iou_threshold,
        quant_boxes,
        quant_protos,
    );

    // Quantized path: rows 0..4 boxes, 4..84 scores, 84.. mask coefficients.
    let mut quantized_boxes: Vec<_> = Vec::with_capacity(500);
    let mut quantized_masks: Vec<_> = Vec::with_capacity(500);
    decoder
        .decode_quantized(
            &[
                boxes.slice(s![.., ..4, ..]).into(),
                boxes.slice(s![.., 4..84, ..]).into(),
                boxes.slice(s![.., 84.., ..]).into(),
                protos.view().into(),
            ],
            &mut quantized_boxes,
            &mut quantized_masks,
        )
        .unwrap();

    // Reference: free float decode on the dequantized tensors (batch 0).
    let protos_f32 = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos.into());
    let boxes_f32 = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes.into());
    let mut reference_boxes: Vec<_> = Vec::with_capacity(500);
    let mut reference_masks: Vec<Segmentation> = Vec::with_capacity(500);
    decode_yolo_segdet_float(
        boxes_f32.slice(s![0, .., ..]),
        protos_f32.slice(s![0, .., .., ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut reference_boxes,
        &mut reference_masks,
    )
    .unwrap();

    assert_eq!(quantized_boxes.len(), reference_boxes.len());
    assert_eq!(quantized_masks.len(), reference_masks.len());

    compare_outputs(
        (&quantized_boxes, &reference_boxes),
        (&quantized_masks, &reference_masks),
    );
}
2392
#[test]
/// Runs a YOLO detection decode and a ModelPack split segmentation decode
/// concurrently, four threads each and 100 decodes per thread, checking every
/// result against known reference values.
fn test_context_switch() {
    // Workload 1: YOLO detection. Builds its own decoder, then decodes the
    // same tensor repeatedly into the same output vectors.
    let yolo_det = || {
        let score_threshold = 0.25;
        let iou_threshold = 0.7;
        let out = load_yolov8s_det();
        let quant = (0.0040811873, -123).into();

        let decoder = DecoderBuilder::default()
            .with_config_yolo_det(
                configs::Detection {
                    decoder: DecoderType::Ultralytics,
                    shape: vec![1, 84, 8400],
                    anchors: None,
                    quantization: Some(quant),
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumFeatures, 84),
                        (DimName::NumBoxes, 8400),
                    ],
                    normalized: None,
                },
                None,
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
        let mut output_masks: Vec<_> = Vec::with_capacity(50);

        for _ in 0..100 {
            decoder
                .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
                .unwrap();

            // The first two boxes must match the reference on every pass.
            assert!(output_boxes[0].equal_within_delta(
                &DetectBox {
                    bbox: BoundingBox {
                        xmin: 0.5285137,
                        ymin: 0.05305544,
                        xmax: 0.87541467,
                        ymax: 0.9998909,
                    },
                    score: 0.5591227,
                    label: 0
                },
                1e-6
            ));

            assert!(output_boxes[1].equal_within_delta(
                &DetectBox {
                    bbox: BoundingBox {
                        xmin: 0.130598,
                        ymin: 0.43260583,
                        xmax: 0.35098213,
                        ymax: 0.9958097,
                    },
                    score: 0.33057618,
                    label: 75
                },
                1e-6
            ));
            // A detection-only config must not produce any masks.
            assert!(output_masks.is_empty());
        }
    };

    // Workload 2: ModelPack split detection plus segmentation, using the same
    // test data and reference values as `test_modelpack_segdet_split`.
    let modelpack_det_split = || {
        let score_threshold = 0.8;
        let iou_threshold = 0.5;

        let seg = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

        let detect0 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_9x15x18.bin"
        ));
        let detect0 =
            ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

        let detect1 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_17x30x18.bin"
        ));
        let detect1 =
            ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

        // Reference mask: drop the batch axis, then reorder the remaining
        // axes from (classes, height, width) to (height, width, classes).
        let mut mask = seg.slice(s![0, .., .., ..]);
        mask.swap_axes(0, 1);
        mask.swap_axes(1, 2);
        let mask = [Segmentation {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.0,
            ymax: 1.0,
            segmentation: mask.into_owned(),
        }];
        let correct_boxes = [DetectBox {
            bbox: BoundingBox {
                xmin: 0.43171933,
                ymin: 0.68243736,
                xmax: 0.5626645,
                ymax: 0.808863,
            },
            score: 0.99240804,
            label: 0,
        }];

        // (scale, zero_point) pairs; the concrete type is inferred from the
        // config fields below.
        let quant0 = (0.08547406643629074, 174).into();
        let quant1 = (0.09929127991199493, 183).into();
        let quant_seg = (1.0 / 255.0, 0).into();

        let anchors0 = vec![
            [0.36666667461395264, 0.31481480598449707],
            [0.38749998807907104, 0.4740740656852722],
            [0.5333333611488342, 0.644444465637207],
        ];
        let anchors1 = vec![
            [0.13750000298023224, 0.2074074000120163],
            [0.2541666626930237, 0.21481481194496155],
            [0.23125000298023224, 0.35185185074806213],
        ];

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_segdet_split(
                vec![
                    configs::Detection {
                        decoder: DecoderType::ModelPack,
                        shape: vec![1, 17, 30, 18],
                        anchors: Some(anchors1),
                        quantization: Some(quant1),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::Height, 17),
                            (DimName::Width, 30),
                            (DimName::NumAnchorsXFeatures, 18),
                        ],
                        normalized: None,
                    },
                    configs::Detection {
                        decoder: DecoderType::ModelPack,
                        shape: vec![1, 9, 15, 18],
                        anchors: Some(anchors0),
                        quantization: Some(quant0),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::Height, 9),
                            (DimName::Width, 15),
                            (DimName::NumAnchorsXFeatures, 18),
                        ],
                        normalized: None,
                    },
                ],
                configs::Segmentation {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_seg),
                    shape: vec![1, 2, 160, 160],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumClasses, 2),
                        (DimName::Height, 160),
                        (DimName::Width, 160),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);

        for _ in 0..100 {
            decoder
                .decode_quantized(
                    &[
                        detect0.view().into(),
                        detect1.view().into(),
                        seg.view().into(),
                    ],
                    &mut output_boxes,
                    &mut output_masks,
                )
                .unwrap();

            compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
        }
    };

    // Both closures capture nothing, so each can be handed to several
    // threads. The two workloads are spawned alternately.
    let handles = vec![
        std::thread::spawn(yolo_det),
        std::thread::spawn(modelpack_det_split),
        std::thread::spawn(yolo_det),
        std::thread::spawn(modelpack_det_split),
        std::thread::spawn(yolo_det),
        std::thread::spawn(modelpack_det_split),
        std::thread::spawn(yolo_det),
        std::thread::spawn(modelpack_det_split),
    ];
    // A panic (failed assertion) inside any thread surfaces here as an `Err`
    // from `join`, which the `unwrap` turns into a test failure.
    for handle in handles {
        handle.join().unwrap();
    }
}
2602
#[test]
/// `ndarray_to_xyxy_float` converts a 1-D array view into `[xmin, ymin, xmax,
/// ymax]` according to the box format it is called on.
fn test_ndarray_to_xyxy_float() {
    let input = array![10.0_f32, 20.0, 20.0, 20.0];

    // As XYWH the expected corners correspond to a 20x20 box centred on
    // (10, 20).
    let from_xywh: [f32; 4] = XYWH::ndarray_to_xyxy_float(input.view());
    assert_eq!(from_xywh, [0.0_f32, 10.0, 20.0, 30.0]);

    // As XYXY the values pass through unchanged.
    let from_xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(input.view());
    assert_eq!(from_xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
}
2613
#[test]
/// Class-aware float NMS keeps overlapping boxes whose labels differ and
/// suppresses the lower-scoring one when the labels match.
fn test_class_aware_nms_float() {
    use crate::float::nms_class_aware_float;

    // Two boxes with IoU = 0.16 / 0.34 (about 0.47), above the 0.3 threshold
    // used below, but carrying different labels.
    let boxes = vec![
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.0,
                ymin: 0.0,
                xmax: 0.5,
                ymax: 0.5,
            },
            score: 0.9,
            label: 0,
        },
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.1,
                ymin: 0.1,
                xmax: 0.6,
                ymax: 0.6,
            },
            score: 0.8,
            label: 1,
        },
    ];

    // `boxes` is not used again, so it is moved into the call rather than
    // cloned.
    let result = nms_class_aware_float(0.3, boxes);
    assert_eq!(
        result.len(),
        2,
        "Class-aware NMS should keep both boxes with different classes"
    );

    // Same geometry and scores, but now both boxes share label 0.
    let same_class_boxes = vec![
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.0,
                ymin: 0.0,
                xmax: 0.5,
                ymax: 0.5,
            },
            score: 0.9,
            label: 0,
        },
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.1,
                ymin: 0.1,
                xmax: 0.6,
                ymax: 0.6,
            },
            score: 0.8,
            label: 0,
        },
    ];

    let result = nms_class_aware_float(0.3, same_class_boxes);
    assert_eq!(
        result.len(),
        1,
        "Class-aware NMS should suppress overlapping box with same class"
    );
    // The survivor must be the higher-scoring box.
    assert_eq!(result[0].label, 0);
    assert!((result[0].score - 0.9).abs() < 1e-6);
}
2684
2685 #[test]
2686 fn test_class_agnostic_vs_aware_nms() {
2687 use crate::float::{nms_class_aware_float, nms_float};
2688
2689 let boxes = vec![
2691 DetectBox {
2692 bbox: BoundingBox {
2693 xmin: 0.0,
2694 ymin: 0.0,
2695 xmax: 0.5,
2696 ymax: 0.5,
2697 },
2698 score: 0.9,
2699 label: 0,
2700 },
2701 DetectBox {
2702 bbox: BoundingBox {
2703 xmin: 0.1,
2704 ymin: 0.1,
2705 xmax: 0.6,
2706 ymax: 0.6,
2707 },
2708 score: 0.8,
2709 label: 1,
2710 },
2711 ];
2712
2713 let agnostic_result = nms_float(0.3, boxes.clone());
2715 assert_eq!(
2716 agnostic_result.len(),
2717 1,
2718 "Class-agnostic NMS should suppress overlapping boxes"
2719 );
2720
2721 let aware_result = nms_class_aware_float(0.3, boxes);
2723 assert_eq!(
2724 aware_result.len(),
2725 2,
2726 "Class-aware NMS should keep boxes with different classes"
2727 );
2728 }
2729
2730 #[test]
2731 fn test_class_aware_nms_int() {
2732 use crate::byte::nms_class_aware_int;
2733
2734 let boxes = vec![
2736 DetectBoxQuantized {
2737 bbox: BoundingBox {
2738 xmin: 0.0,
2739 ymin: 0.0,
2740 xmax: 0.5,
2741 ymax: 0.5,
2742 },
2743 score: 200_u8,
2744 label: 0,
2745 },
2746 DetectBoxQuantized {
2747 bbox: BoundingBox {
2748 xmin: 0.1,
2749 ymin: 0.1,
2750 xmax: 0.6,
2751 ymax: 0.6,
2752 },
2753 score: 180_u8,
2754 label: 1, },
2756 ];
2757
2758 let result = nms_class_aware_int(0.5, boxes);
2760 assert_eq!(
2761 result.len(),
2762 2,
2763 "Class-aware NMS (int) should keep boxes with different classes"
2764 );
2765 }
2766
2767 #[test]
2768 fn test_nms_enum_default() {
2769 let default_nms: configs::Nms = Default::default();
2771 assert_eq!(default_nms, configs::Nms::ClassAgnostic);
2772 }
2773
2774 #[test]
2775 fn test_decoder_nms_mode() {
2776 let decoder = DecoderBuilder::default()
2778 .with_config_yolo_det(
2779 configs::Detection {
2780 anchors: None,
2781 decoder: DecoderType::Ultralytics,
2782 quantization: None,
2783 shape: vec![1, 84, 8400],
2784 dshape: Vec::new(),
2785 normalized: Some(true),
2786 },
2787 None,
2788 )
2789 .with_nms(Some(configs::Nms::ClassAware))
2790 .build()
2791 .unwrap();
2792
2793 assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
2794 }
2795
2796 #[test]
2797 fn test_decoder_nms_bypass() {
2798 let decoder = DecoderBuilder::default()
2800 .with_config_yolo_det(
2801 configs::Detection {
2802 anchors: None,
2803 decoder: DecoderType::Ultralytics,
2804 quantization: None,
2805 shape: vec![1, 84, 8400],
2806 dshape: Vec::new(),
2807 normalized: Some(true),
2808 },
2809 None,
2810 )
2811 .with_nms(None)
2812 .build()
2813 .unwrap();
2814
2815 assert_eq!(decoder.nms, None);
2816 }
2817
2818 #[test]
2819 fn test_decoder_normalized_boxes_true() {
2820 let decoder = DecoderBuilder::default()
2822 .with_config_yolo_det(
2823 configs::Detection {
2824 anchors: None,
2825 decoder: DecoderType::Ultralytics,
2826 quantization: None,
2827 shape: vec![1, 84, 8400],
2828 dshape: Vec::new(),
2829 normalized: Some(true),
2830 },
2831 None,
2832 )
2833 .build()
2834 .unwrap();
2835
2836 assert_eq!(decoder.normalized_boxes(), Some(true));
2837 }
2838
2839 #[test]
2840 fn test_decoder_normalized_boxes_false() {
2841 let decoder = DecoderBuilder::default()
2844 .with_config_yolo_det(
2845 configs::Detection {
2846 anchors: None,
2847 decoder: DecoderType::Ultralytics,
2848 quantization: None,
2849 shape: vec![1, 84, 8400],
2850 dshape: Vec::new(),
2851 normalized: Some(false),
2852 },
2853 None,
2854 )
2855 .build()
2856 .unwrap();
2857
2858 assert_eq!(decoder.normalized_boxes(), Some(false));
2859 }
2860
2861 #[test]
2862 fn test_decoder_normalized_boxes_unknown() {
2863 let decoder = DecoderBuilder::default()
2865 .with_config_yolo_det(
2866 configs::Detection {
2867 anchors: None,
2868 decoder: DecoderType::Ultralytics,
2869 quantization: None,
2870 shape: vec![1, 84, 8400],
2871 dshape: Vec::new(),
2872 normalized: None,
2873 },
2874 Some(DecoderVersion::Yolo11),
2875 )
2876 .build()
2877 .unwrap();
2878
2879 assert_eq!(decoder.normalized_boxes(), None);
2880 }
2881
2882 pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
2883 input: ArrayView<F, D>,
2884 quant: Quantization,
2885 ) -> Array<T, D>
2886 where
2887 i32: num_traits::AsPrimitive<F>,
2888 f32: num_traits::AsPrimitive<F>,
2889 {
2890 let zero_point = quant.zero_point.as_();
2891 let div_scale = F::one() / quant.scale.as_();
2892 if zero_point != F::zero() {
2893 input.mapv(|d| (d * div_scale + zero_point).round().as_())
2894 } else {
2895 input.mapv(|d| (d * div_scale).round().as_())
2896 }
2897 }
2898
2899 fn real_data_expected_boxes() -> [DetectBox; 2] {
2900 [
2901 DetectBox {
2902 bbox: BoundingBox {
2903 xmin: 0.08515105,
2904 ymin: 0.7131401,
2905 xmax: 0.29802868,
2906 ymax: 0.8195788,
2907 },
2908 score: 0.91537374,
2909 label: 23,
2910 },
2911 DetectBox {
2912 bbox: BoundingBox {
2913 xmin: 0.59605736,
2914 ymin: 0.25545314,
2915 xmax: 0.93666154,
2916 ymax: 0.72378385,
2917 },
2918 score: 0.91537374,
2919 label: 23,
2920 },
2921 ]
2922 }
2923
2924 fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
2925 [DetectBox {
2926 bbox: BoundingBox {
2927 xmin: 0.12549022,
2928 ymin: 0.12549022,
2929 xmax: 0.23529413,
2930 ymax: 0.23529413,
2931 },
2932 score: 0.98823535,
2933 label: 2,
2934 }]
2935 }
2936
2937 fn e2e_expected_boxes_float() -> [DetectBox; 1] {
2938 [DetectBox {
2939 bbox: BoundingBox {
2940 xmin: 0.1234,
2941 ymin: 0.1234,
2942 xmax: 0.2345,
2943 ymax: 0.2345,
2944 },
2945 score: 0.9876,
2946 label: 2,
2947 }]
2948 }
2949
2950 macro_rules! real_data_proto_test {
2951 ($name:ident, quantized, $layout:ident) => {
2952 #[test]
2953 fn $name() {
2954 let is_split = matches!(stringify!($layout), "split");
2955
2956 let score_threshold = 0.45;
2957 let iou_threshold = 0.45;
2958 let quant_boxes = (0.021287762_f32, 31_i32);
2959 let quant_protos = (0.02491162_f32, -117_i32);
2960
2961 let raw_boxes = include_bytes!(concat!(
2962 env!("CARGO_MANIFEST_DIR"),
2963 "/../../testdata/yolov8_boxes_116x8400.bin"
2964 ));
2965 let raw_boxes = unsafe {
2966 std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
2967 };
2968 let boxes_i8 =
2969 ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
2970
2971 let raw_protos = include_bytes!(concat!(
2972 env!("CARGO_MANIFEST_DIR"),
2973 "/../../testdata/yolov8_protos_160x160x32.bin"
2974 ));
2975 let raw_protos = unsafe {
2976 std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
2977 };
2978 let protos_i8 =
2979 ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
2980 .unwrap();
2981
2982 let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
2984 let scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
2985 let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
2986 let boxes_combined = boxes_i8;
2987
2988 let decoder = if is_split {
2989 build_yolo_split_segdet_decoder(
2990 score_threshold,
2991 iou_threshold,
2992 quant_boxes,
2993 quant_protos,
2994 )
2995 } else {
2996 build_yolov8_seg_decoder(score_threshold, iou_threshold)
2997 };
2998
2999 let expected = real_data_expected_boxes();
3000 let mut output_boxes = Vec::with_capacity(50);
3001
3002 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
3003 vec![
3004 boxes_split.view().into(),
3005 scores_split.view().into(),
3006 mask_split.view().into(),
3007 protos_i8.view().into(),
3008 ]
3009 } else {
3010 vec![boxes_combined.view().into(), protos_i8.view().into()]
3011 };
3012 decoder
3013 .decode_quantized_proto(&inputs, &mut output_boxes)
3014 .unwrap();
3015
3016 assert_eq!(output_boxes.len(), 2);
3017 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3018 assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3019 }
3020 };
3021 ($name:ident, float, $layout:ident) => {
3022 #[test]
3023 fn $name() {
3024 let is_split = matches!(stringify!($layout), "split");
3025
3026 let score_threshold = 0.45;
3027 let iou_threshold = 0.45;
3028 let quant_boxes = (0.021287762_f32, 31_i32);
3029 let quant_protos = (0.02491162_f32, -117_i32);
3030
3031 let raw_boxes = include_bytes!(concat!(
3032 env!("CARGO_MANIFEST_DIR"),
3033 "/../../testdata/yolov8_boxes_116x8400.bin"
3034 ));
3035 let raw_boxes = unsafe {
3036 std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
3037 };
3038 let boxes_i8 =
3039 ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
3040 let boxes_f32: Array3<f32> =
3041 dequantize_ndarray(boxes_i8.view(), quant_boxes.into());
3042
3043 let raw_protos = include_bytes!(concat!(
3044 env!("CARGO_MANIFEST_DIR"),
3045 "/../../testdata/yolov8_protos_160x160x32.bin"
3046 ));
3047 let raw_protos = unsafe {
3048 std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3049 };
3050 let protos_i8 =
3051 ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
3052 .unwrap();
3053 let protos_f32: Array4<f32> =
3054 dequantize_ndarray(protos_i8.view(), quant_protos.into());
3055
3056 let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
3058 let scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
3059 let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
3060 let boxes_combined = boxes_f32;
3061
3062 let decoder = if is_split {
3063 build_yolo_split_segdet_decoder(
3064 score_threshold,
3065 iou_threshold,
3066 quant_boxes,
3067 quant_protos,
3068 )
3069 } else {
3070 build_yolov8_seg_decoder(score_threshold, iou_threshold)
3071 };
3072
3073 let expected = real_data_expected_boxes();
3074 let mut output_boxes = Vec::with_capacity(50);
3075
3076 let inputs = if is_split {
3077 vec![
3078 boxes_split.view().into_dyn(),
3079 scores_split.view().into_dyn(),
3080 mask_split.view().into_dyn(),
3081 protos_f32.view().into_dyn(),
3082 ]
3083 } else {
3084 vec![
3085 boxes_combined.view().into_dyn(),
3086 protos_f32.view().into_dyn(),
3087 ]
3088 };
3089 decoder
3090 .decode_float_proto(&inputs, &mut output_boxes)
3091 .unwrap();
3092
3093 assert_eq!(output_boxes.len(), 2);
3094 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3095 assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3096 }
3097 };
3098 }
3099
    // Real-data proto decoding: {quantized, float} x {combined, split} input layouts.
    real_data_proto_test!(test_decoder_segdet_proto, quantized, combined);
    real_data_proto_test!(test_decoder_segdet_proto_float, float, combined);
    real_data_proto_test!(test_decoder_segdet_split_proto, quantized, split);
    real_data_proto_test!(test_decoder_segdet_split_proto_float, float, split);
3104
    /// YAML decoder configuration (`decoder_version: yolo26`) with a single combined
    /// `detection` output of shape `[1, 10, 6]`.
    const E2E_COMBINED_DET_CONFIG: &str = "
decoder_version: yolo26
outputs:
  - type: detection
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 6]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_features, 6]
    normalized: true
";
3118
    /// YAML decoder configuration (`decoder_version: yolo26`) with a combined
    /// `detection` output of shape `[1, 10, 38]` plus a `protos` output of shape
    /// `[1, 160, 160, 32]`.
    const E2E_COMBINED_SEGDET_CONFIG: &str = "
decoder_version: yolo26
outputs:
  - type: detection
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 38]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_features, 38]
    normalized: true
  - type: protos
    decoder: ultralytics
    quantization: [0.0039215686274509803921568627451, 128]
    shape: [1, 160, 160, 32]
    dshape:
      - [batch, 1]
      - [height, 160]
      - [width, 160]
      - [num_protos, 32]
";
3141
    /// YAML decoder configuration (`decoder_version: yolo26`) with separate `boxes`
    /// (`[1, 10, 4]`), `scores` (`[1, 10, 1]`) and `classes` (`[1, 10, 1]`) outputs.
    const E2E_SPLIT_DET_CONFIG: &str = "
decoder_version: yolo26
outputs:
  - type: boxes
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 4]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [box_coords, 4]
    normalized: true
  - type: scores
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 1]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_classes, 1]
  - type: classes
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 1]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_classes, 1]
";
3171
    /// YAML decoder configuration (`decoder_version: yolo26`) with separate `boxes`,
    /// `scores`, `classes` and `mask_coefficients` (`[1, 10, 32]`) outputs plus a
    /// `protos` output of shape `[1, 160, 160, 32]`.
    const E2E_SPLIT_SEGDET_CONFIG: &str = "
decoder_version: yolo26
outputs:
  - type: boxes
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 4]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [box_coords, 4]
    normalized: true
  - type: scores
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 1]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_classes, 1]
  - type: classes
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 1]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_classes, 1]
  - type: mask_coefficients
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 32]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_protos, 32]
  - type: protos
    decoder: ultralytics
    quantization: [0.0039215686274509803921568627451, 128]
    shape: [1, 160, 160, 32]
    dshape:
      - [batch, 1]
      - [height, 160]
      - [width, 160]
      - [num_protos, 32]
";
3218
3219 macro_rules! e2e_segdet_test {
3220 ($name:ident, quantized, $layout:ident, $output:ident) => {
3221 #[test]
3222 fn $name() {
3223 let is_split = matches!(stringify!($layout), "split");
3224 let is_proto = matches!(stringify!($output), "proto");
3225
3226 let score_threshold = 0.45;
3227 let iou_threshold = 0.45;
3228
3229 let mut boxes = Array2::zeros((10, 4));
3230 let mut scores = Array2::zeros((10, 1));
3231 let mut classes = Array2::zeros((10, 1));
3232 let mask = Array2::zeros((10, 32));
3233 let protos = Array3::<f64>::zeros((160, 160, 32));
3234 let protos = protos.insert_axis(Axis(0));
3235 let protos_quant = (1.0 / 255.0, 0.0);
3236 let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
3237
3238 boxes
3239 .slice_mut(s![0, ..])
3240 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3241 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3242 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3243
3244 let detect_quant = (2.0 / 255.0, 0.0);
3245
3246 let decoder = if is_split {
3247 DecoderBuilder::default()
3248 .with_config_yaml_str(E2E_SPLIT_SEGDET_CONFIG.to_string())
3249 .with_score_threshold(score_threshold)
3250 .with_iou_threshold(iou_threshold)
3251 .build()
3252 .unwrap()
3253 } else {
3254 DecoderBuilder::default()
3255 .with_config_yaml_str(E2E_COMBINED_SEGDET_CONFIG.to_string())
3256 .with_score_threshold(score_threshold)
3257 .with_iou_threshold(iou_threshold)
3258 .build()
3259 .unwrap()
3260 };
3261
3262 let expected = e2e_expected_boxes_quant();
3263 let mut output_boxes = Vec::with_capacity(50);
3264
3265 if is_split {
3266 let boxes = boxes.insert_axis(Axis(0));
3267 let scores = scores.insert_axis(Axis(0));
3268 let classes = classes.insert_axis(Axis(0));
3269 let mask = mask.insert_axis(Axis(0));
3270
3271 let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
3272 let scores: Array3<u8> = quantize_ndarray(scores.view(), detect_quant.into());
3273 let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
3274 let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
3275
3276 if is_proto {
3277 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3278 boxes.view().into(),
3279 scores.view().into(),
3280 classes.view().into(),
3281 mask.view().into(),
3282 protos.view().into(),
3283 ];
3284 decoder
3285 .decode_quantized_proto(&inputs, &mut output_boxes)
3286 .unwrap();
3287
3288 assert_eq!(output_boxes.len(), 1);
3289 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3290 } else {
3291 let mut output_masks = Vec::with_capacity(50);
3292 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3293 boxes.view().into(),
3294 scores.view().into(),
3295 classes.view().into(),
3296 mask.view().into(),
3297 protos.view().into(),
3298 ];
3299 decoder
3300 .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
3301 .unwrap();
3302
3303 assert_eq!(output_boxes.len(), 1);
3304 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3305 }
3306 } else {
3307 let detect = ndarray::concatenate![
3309 Axis(1),
3310 boxes.view(),
3311 scores.view(),
3312 classes.view(),
3313 mask.view()
3314 ];
3315 let detect = detect.insert_axis(Axis(0));
3316 assert_eq!(detect.shape(), &[1, 10, 38]);
3317 let detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
3318
3319 if is_proto {
3320 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3321 vec![detect.view().into(), protos.view().into()];
3322 decoder
3323 .decode_quantized_proto(&inputs, &mut output_boxes)
3324 .unwrap();
3325
3326 assert_eq!(output_boxes.len(), 1);
3327 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3328 } else {
3329 let mut output_masks = Vec::with_capacity(50);
3330 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3331 vec![detect.view().into(), protos.view().into()];
3332 decoder
3333 .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
3334 .unwrap();
3335
3336 assert_eq!(output_boxes.len(), 1);
3337 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3338 }
3339 }
3340 }
3341 };
3342 ($name:ident, float, $layout:ident, $output:ident) => {
3343 #[test]
3344 fn $name() {
3345 let is_split = matches!(stringify!($layout), "split");
3346 let is_proto = matches!(stringify!($output), "proto");
3347
3348 let score_threshold = 0.45;
3349 let iou_threshold = 0.45;
3350
3351 let mut boxes = Array2::zeros((10, 4));
3352 let mut scores = Array2::zeros((10, 1));
3353 let mut classes = Array2::zeros((10, 1));
3354 let mask: Array2<f64> = Array2::zeros((10, 32));
3355 let protos = Array3::<f64>::zeros((160, 160, 32));
3356 let protos = protos.insert_axis(Axis(0));
3357
3358 boxes
3359 .slice_mut(s![0, ..])
3360 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3361 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3362 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3363
3364 let decoder = if is_split {
3365 DecoderBuilder::default()
3366 .with_config_yaml_str(E2E_SPLIT_SEGDET_CONFIG.to_string())
3367 .with_score_threshold(score_threshold)
3368 .with_iou_threshold(iou_threshold)
3369 .build()
3370 .unwrap()
3371 } else {
3372 DecoderBuilder::default()
3373 .with_config_yaml_str(E2E_COMBINED_SEGDET_CONFIG.to_string())
3374 .with_score_threshold(score_threshold)
3375 .with_iou_threshold(iou_threshold)
3376 .build()
3377 .unwrap()
3378 };
3379
3380 let expected = e2e_expected_boxes_float();
3381 let mut output_boxes = Vec::with_capacity(50);
3382
3383 if is_split {
3384 let boxes = boxes.insert_axis(Axis(0));
3385 let scores = scores.insert_axis(Axis(0));
3386 let classes = classes.insert_axis(Axis(0));
3387 let mask = mask.insert_axis(Axis(0));
3388
3389 if is_proto {
3390 let inputs = vec![
3391 boxes.view().into_dyn(),
3392 scores.view().into_dyn(),
3393 classes.view().into_dyn(),
3394 mask.view().into_dyn(),
3395 protos.view().into_dyn(),
3396 ];
3397 decoder
3398 .decode_float_proto(&inputs, &mut output_boxes)
3399 .unwrap();
3400
3401 assert_eq!(output_boxes.len(), 1);
3402 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3403 } else {
3404 let mut output_masks = Vec::with_capacity(50);
3405 let inputs = vec![
3406 boxes.view().into_dyn(),
3407 scores.view().into_dyn(),
3408 classes.view().into_dyn(),
3409 mask.view().into_dyn(),
3410 protos.view().into_dyn(),
3411 ];
3412 decoder
3413 .decode_float(&inputs, &mut output_boxes, &mut output_masks)
3414 .unwrap();
3415
3416 assert_eq!(output_boxes.len(), 1);
3417 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3418 }
3419 } else {
3420 let detect = ndarray::concatenate![
3422 Axis(1),
3423 boxes.view(),
3424 scores.view(),
3425 classes.view(),
3426 mask.view()
3427 ];
3428 let detect = detect.insert_axis(Axis(0));
3429 assert_eq!(detect.shape(), &[1, 10, 38]);
3430
3431 if is_proto {
3432 let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
3433 decoder
3434 .decode_float_proto(&inputs, &mut output_boxes)
3435 .unwrap();
3436
3437 assert_eq!(output_boxes.len(), 1);
3438 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3439 } else {
3440 let mut output_masks = Vec::with_capacity(50);
3441 let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
3442 decoder
3443 .decode_float(&inputs, &mut output_boxes, &mut output_masks)
3444 .unwrap();
3445
3446 assert_eq!(output_boxes.len(), 1);
3447 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3448 }
3449 }
3450 }
3451 };
3452 }
3453
    // End-to-end segmentation+detection matrix:
    // {quantized, float} x {combined, split} layouts x {masks, proto} outputs.
    e2e_segdet_test!(test_decoder_end_to_end_segdet, quantized, combined, masks);
    e2e_segdet_test!(test_decoder_end_to_end_segdet_float, float, combined, masks);
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_proto,
        quantized,
        combined,
        proto
    );
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_proto_float,
        float,
        combined,
        proto
    );
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_split,
        quantized,
        split,
        masks
    );
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_split_float,
        float,
        split,
        masks
    );
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_split_proto,
        quantized,
        split,
        proto
    );
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_split_proto_float,
        float,
        split,
        proto
    );
3492
3493 macro_rules! e2e_det_test {
3494 ($name:ident, quantized, $layout:ident) => {
3495 #[test]
3496 fn $name() {
3497 let is_split = matches!(stringify!($layout), "split");
3498
3499 let score_threshold = 0.45;
3500 let iou_threshold = 0.45;
3501
3502 let mut boxes = Array3::zeros((1, 10, 4));
3503 let mut scores = Array3::zeros((1, 10, 1));
3504 let mut classes = Array3::zeros((1, 10, 1));
3505
3506 boxes
3507 .slice_mut(s![0, 0, ..])
3508 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3509 scores.slice_mut(s![0, 0, ..]).assign(&array![0.9876]);
3510 classes.slice_mut(s![0, 0, ..]).assign(&array![2.0]);
3511
3512 let detect_quant = (2.0 / 255.0, 0_i32);
3513
3514 let decoder = if is_split {
3515 DecoderBuilder::default()
3516 .with_config_yaml_str(E2E_SPLIT_DET_CONFIG.to_string())
3517 .with_score_threshold(score_threshold)
3518 .with_iou_threshold(iou_threshold)
3519 .build()
3520 .unwrap()
3521 } else {
3522 DecoderBuilder::default()
3523 .with_config_yaml_str(E2E_COMBINED_DET_CONFIG.to_string())
3524 .with_score_threshold(score_threshold)
3525 .with_iou_threshold(iou_threshold)
3526 .build()
3527 .unwrap()
3528 };
3529
3530 let expected = e2e_expected_boxes_quant();
3531 let mut output_boxes = Vec::with_capacity(50);
3532
3533 if is_split {
3534 let boxes: Array<u8, _> = quantize_ndarray(boxes.view(), detect_quant.into());
3535 let scores: Array<u8, _> = quantize_ndarray(scores.view(), detect_quant.into());
3536 let classes: Array<u8, _> =
3537 quantize_ndarray(classes.view(), detect_quant.into());
3538 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3539 boxes.view().into(),
3540 scores.view().into(),
3541 classes.view().into(),
3542 ];
3543 decoder
3544 .decode_quantized(&inputs, &mut output_boxes, &mut Vec::new())
3545 .unwrap();
3546 } else {
3547 let detect =
3548 ndarray::concatenate![Axis(2), boxes.view(), scores.view(), classes.view()];
3549 assert_eq!(detect.shape(), &[1, 10, 6]);
3550 let detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
3551 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3552 vec![detect.view().into()];
3553 decoder
3554 .decode_quantized(&inputs, &mut output_boxes, &mut Vec::new())
3555 .unwrap();
3556 }
3557
3558 assert_eq!(output_boxes.len(), 1);
3559 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3560 }
3561 };
3562 ($name:ident, float, $layout:ident) => {
3563 #[test]
3564 fn $name() {
3565 let is_split = matches!(stringify!($layout), "split");
3566
3567 let score_threshold = 0.45;
3568 let iou_threshold = 0.45;
3569
3570 let mut boxes = Array3::zeros((1, 10, 4));
3571 let mut scores = Array3::zeros((1, 10, 1));
3572 let mut classes = Array3::zeros((1, 10, 1));
3573
3574 boxes
3575 .slice_mut(s![0, 0, ..])
3576 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3577 scores.slice_mut(s![0, 0, ..]).assign(&array![0.9876]);
3578 classes.slice_mut(s![0, 0, ..]).assign(&array![2.0]);
3579
3580 let decoder = if is_split {
3581 DecoderBuilder::default()
3582 .with_config_yaml_str(E2E_SPLIT_DET_CONFIG.to_string())
3583 .with_score_threshold(score_threshold)
3584 .with_iou_threshold(iou_threshold)
3585 .build()
3586 .unwrap()
3587 } else {
3588 DecoderBuilder::default()
3589 .with_config_yaml_str(E2E_COMBINED_DET_CONFIG.to_string())
3590 .with_score_threshold(score_threshold)
3591 .with_iou_threshold(iou_threshold)
3592 .build()
3593 .unwrap()
3594 };
3595
3596 let expected = e2e_expected_boxes_float();
3597 let mut output_boxes = Vec::with_capacity(50);
3598
3599 if is_split {
3600 let inputs = vec![
3601 boxes.view().into_dyn(),
3602 scores.view().into_dyn(),
3603 classes.view().into_dyn(),
3604 ];
3605 decoder
3606 .decode_float(&inputs, &mut output_boxes, &mut Vec::new())
3607 .unwrap();
3608 } else {
3609 let detect =
3610 ndarray::concatenate![Axis(2), boxes.view(), scores.view(), classes.view()];
3611 assert_eq!(detect.shape(), &[1, 10, 6]);
3612 let inputs = vec![detect.view().into_dyn()];
3613 decoder
3614 .decode_float(&inputs, &mut output_boxes, &mut Vec::new())
3615 .unwrap();
3616 }
3617
3618 assert_eq!(output_boxes.len(), 1);
3619 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3620 }
3621 };
3622 }
3623
    // End-to-end detection-only matrix: {quantized, float} x {combined, split} layouts.
    e2e_det_test!(test_decoder_end_to_end_combined_det, quantized, combined);
    e2e_det_test!(test_decoder_end_to_end_combined_det_float, float, combined);
    e2e_det_test!(test_decoder_end_to_end_split_det, quantized, split);
    e2e_det_test!(test_decoder_end_to_end_split_det_float, float, split);
3628
3629 #[test]
3630 fn test_decode_tensor() {
3631 let score_threshold = 0.45;
3632 let iou_threshold = 0.45;
3633
3634 let raw_boxes = include_bytes!(concat!(
3635 env!("CARGO_MANIFEST_DIR"),
3636 "/../../testdata/yolov8_boxes_116x8400.bin"
3637 ));
3638 let raw_boxes =
3639 unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3640 let boxes_i8: Tensor<i8> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3641 boxes_i8
3642 .map()
3643 .unwrap()
3644 .as_mut_slice()
3645 .copy_from_slice(raw_boxes);
3646 let boxes_i8 = boxes_i8.into();
3647
3648 let raw_protos = include_bytes!(concat!(
3649 env!("CARGO_MANIFEST_DIR"),
3650 "/../../testdata/yolov8_protos_160x160x32.bin"
3651 ));
3652 let raw_protos = unsafe {
3653 std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3654 };
3655 let protos_i8: Tensor<i8> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3656 protos_i8
3657 .map()
3658 .unwrap()
3659 .as_mut_slice()
3660 .copy_from_slice(raw_protos);
3661 let protos_i8 = protos_i8.into();
3662
3663 let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3664 let expected = real_data_expected_boxes();
3665 let mut output_boxes = Vec::with_capacity(50);
3666
3667 decoder
3668 .decode(&[&boxes_i8, &protos_i8], &mut output_boxes, &mut Vec::new())
3669 .unwrap();
3670
3671 assert_eq!(output_boxes.len(), 2);
3672 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3673 assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3674 }
3675
3676 #[test]
3677 fn test_decode_tensor_f32() {
3678 let score_threshold = 0.45;
3679 let iou_threshold = 0.45;
3680
3681 let quant_boxes = (0.021287762_f32, 31_i32);
3682 let quant_protos = (0.02491162_f32, -117_i32);
3683 let raw_boxes = include_bytes!(concat!(
3684 env!("CARGO_MANIFEST_DIR"),
3685 "/../../testdata/yolov8_boxes_116x8400.bin"
3686 ));
3687 let raw_boxes =
3688 unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3689 let mut raw_boxes_f32 = vec![0f32; raw_boxes.len()];
3690 dequantize_cpu(raw_boxes, quant_boxes.into(), &mut raw_boxes_f32);
3691 let boxes_f32: Tensor<f32> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3692 boxes_f32
3693 .map()
3694 .unwrap()
3695 .as_mut_slice()
3696 .copy_from_slice(&raw_boxes_f32);
3697 let boxes_f32 = boxes_f32.into();
3698
3699 let raw_protos = include_bytes!(concat!(
3700 env!("CARGO_MANIFEST_DIR"),
3701 "/../../testdata/yolov8_protos_160x160x32.bin"
3702 ));
3703 let raw_protos = unsafe {
3704 std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3705 };
3706 let mut raw_protos_f32 = vec![0f32; raw_protos.len()];
3707 dequantize_cpu(raw_protos, quant_protos.into(), &mut raw_protos_f32);
3708 let protos_f32: Tensor<f32> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3709 protos_f32
3710 .map()
3711 .unwrap()
3712 .as_mut_slice()
3713 .copy_from_slice(&raw_protos_f32);
3714 let protos_f32 = protos_f32.into();
3715
3716 let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3717
3718 let expected = real_data_expected_boxes();
3719 let mut output_boxes = Vec::with_capacity(50);
3720
3721 decoder
3722 .decode(
3723 &[&boxes_f32, &protos_f32],
3724 &mut output_boxes,
3725 &mut Vec::new(),
3726 )
3727 .unwrap();
3728
3729 assert_eq!(output_boxes.len(), 2);
3730 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3731 assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3732 }
3733
3734 #[test]
3735 fn test_decode_tensor_f64() {
3736 let score_threshold = 0.45;
3737 let iou_threshold = 0.45;
3738
3739 let quant_boxes = (0.021287762_f32, 31_i32);
3740 let quant_protos = (0.02491162_f32, -117_i32);
3741 let raw_boxes = include_bytes!(concat!(
3742 env!("CARGO_MANIFEST_DIR"),
3743 "/../../testdata/yolov8_boxes_116x8400.bin"
3744 ));
3745 let raw_boxes =
3746 unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3747 let mut raw_boxes_f64 = vec![0f64; raw_boxes.len()];
3748 dequantize_cpu(raw_boxes, quant_boxes.into(), &mut raw_boxes_f64);
3749 let boxes_f64: Tensor<f64> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3750 boxes_f64
3751 .map()
3752 .unwrap()
3753 .as_mut_slice()
3754 .copy_from_slice(&raw_boxes_f64);
3755 let boxes_f64 = boxes_f64.into();
3756
3757 let raw_protos = include_bytes!(concat!(
3758 env!("CARGO_MANIFEST_DIR"),
3759 "/../../testdata/yolov8_protos_160x160x32.bin"
3760 ));
3761 let raw_protos = unsafe {
3762 std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3763 };
3764 let mut raw_protos_f64 = vec![0f64; raw_protos.len()];
3765 dequantize_cpu(raw_protos, quant_protos.into(), &mut raw_protos_f64);
3766 let protos_f64: Tensor<f64> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3767 protos_f64
3768 .map()
3769 .unwrap()
3770 .as_mut_slice()
3771 .copy_from_slice(&raw_protos_f64);
3772 let protos_f64 = protos_f64.into();
3773
3774 let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3775
3776 let expected = real_data_expected_boxes();
3777 let mut output_boxes = Vec::with_capacity(50);
3778
3779 decoder
3780 .decode(
3781 &[&boxes_f64, &protos_f64],
3782 &mut output_boxes,
3783 &mut Vec::new(),
3784 )
3785 .unwrap();
3786
3787 assert_eq!(output_boxes.len(), 2);
3788 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3789 assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3790 }
3791
/// Decodes the captured YOLOv8-seg outputs from `i8` tensors through
/// `decode_proto`, then checks the two expected boxes and the shape of the
/// returned mask coefficients.
#[test]
fn test_decode_tensor_proto() {
    /// Reinterprets embedded fixture bytes as signed 8-bit samples.
    fn as_i8(bytes: &'static [u8]) -> &'static [i8] {
        // SAFETY: `u8` and `i8` share size and alignment and every bit
        // pattern is valid for both; the pointer and length come from a
        // live `'static` slice, so the reinterpreted slice is in bounds.
        unsafe { std::slice::from_raw_parts(bytes.as_ptr().cast::<i8>(), bytes.len()) }
    }

    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    // Detection head output, copied into a freshly allocated tensor.
    let box_samples = as_i8(include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    )));
    let boxes_i8 = {
        let tensor: Tensor<i8> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
        tensor
            .map()
            .unwrap()
            .as_mut_slice()
            .copy_from_slice(box_samples);
        tensor.into()
    };

    // Prototype mask output, handled the same way.
    let proto_samples = as_i8(include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    )));
    let protos_i8 = {
        let tensor: Tensor<i8> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
        tensor
            .map()
            .unwrap()
            .as_mut_slice()
            .copy_from_slice(proto_samples);
        tensor.into()
    };

    let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
    let expected = real_data_expected_boxes();

    let mut output_boxes = Vec::with_capacity(50);
    let decoded = decoder
        .decode_proto(&[&boxes_i8, &protos_i8], &mut output_boxes)
        .unwrap();

    assert_eq!(output_boxes.len(), 2);
    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));

    // The coefficient matrix must carry one row of 32 values per detection.
    let proto_data = decoded.expect("segmentation model should return ProtoData");
    let coeffs_shape = proto_data.mask_coefficients.shape();
    assert_eq!(
        coeffs_shape[0],
        output_boxes.len(),
        "mask_coefficients count must match detection count"
    );
    assert_eq!(
        coeffs_shape[1], 32,
        "each detection should have 32 mask coefficients"
    );
}
3851
/// Checks that a config-driven decoder whose protos are declared as
/// `[batch, height, width, num_protos]` reproduces the boxes and masks of a
/// direct `decode_yolo_segdet_float` call on the same dequantized data.
#[test]
fn test_physical_order_tflite_nhwc_protos() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    // Dequantize the captured protos (batch axis dropped) and detections.
    let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
    let quant_protos = Quantization::new(0.02491161972284317, -117);
    let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);

    let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
    let quant_boxes = Quantization::new(0.021287761628627777, 31);
    let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);

    // Reference result: call the low-level decode function directly.
    let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
    let mut ref_masks: Vec<_> = Vec::with_capacity(10);
    decode_yolo_segdet_float(
        seg.view(),
        protos_f32_hwc.view(),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut ref_boxes,
        &mut ref_masks,
    )
    .unwrap();

    // Re-attach the batch axis for the config-driven decoder. Both arrays
    // are consumed here; neither is needed in its batch-less form again, so
    // no copy of the proto data is made.
    let protos_nhwc = protos_f32_hwc.insert_axis(Axis(0));
    let seg_3d = seg.insert_axis(Axis(0));

    let decoder = DecoderBuilder::default()
        .with_config_yolo_segdet(
            configs::Detection {
                decoder: configs::DecoderType::Ultralytics,
                quantization: None,
                shape: vec![1, 116, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumFeatures, 116),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
                anchors: None,
            },
            configs::Protos {
                decoder: configs::DecoderType::Ultralytics,
                quantization: None,
                shape: vec![1, 160, 160, 32],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                    (DimName::NumProtos, 32),
                ],
            },
            None,
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .expect("config with NHWC protos dshape must build");

    let mut cfg_boxes = Vec::with_capacity(10);
    let mut cfg_masks = Vec::with_capacity(10);
    decoder
        .decode_float(
            &[seg_3d.view().into_dyn(), protos_nhwc.view().into_dyn()],
            &mut cfg_boxes,
            &mut cfg_masks,
        )
        .unwrap();

    assert_eq!(cfg_boxes.len(), ref_boxes.len(), "box count mismatch");
    for (c, r) in cfg_boxes.iter().zip(&ref_boxes) {
        assert!(
            c.equal_within_delta(r, 0.01),
            "NHWC-declared box does not match reference: {c:?} vs {r:?}"
        );
    }

    // `zip` stops at the shorter input, so without this check a missing or
    // empty mask list would let the pixel comparison pass vacuously.
    assert_eq!(cfg_masks.len(), ref_masks.len(), "mask count mismatch");
    for (cm, rm) in cfg_masks.iter().zip(&ref_masks) {
        let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
        let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
        assert_eq!(
            cm_arr, rm_arr,
            "NHWC-declared mask must match reference pixel-for-pixel"
        );
    }
}
3961
/// Checks that a split boxes/scores decoder declared anchor-first
/// (`[batch, num_boxes, features]`) produces the same detections as the
/// features-first declaration when fed the same data with axes 1 and 2
/// swapped.
#[test]
fn test_physical_order_ara2_anchor_first_split_boxes() {
    use configs::{Boxes, Scores};

    const N: usize = 8400;
    let target_anchor = 42usize;

    // Seed a single anchor with box values and one confident class; every
    // other score stays at zero, below the 0.5 threshold used here.
    let mut boxes_canonical = Array3::<f32>::zeros((1, 4, N));
    for (coord, &value) in [0.4f32, 0.5, 0.2, 0.2].iter().enumerate() {
        boxes_canonical[[0, coord, target_anchor]] = value;
    }
    let mut scores_canonical = Array3::<f32>::zeros((1, 80, N));
    scores_canonical[[0, 0, target_anchor]] = 0.9;

    // Builds a split-detection decoder for the given axis declarations. The
    // physical `shape` is the size column of the matching `dshape`.
    let build_split = |boxes_dshape: Vec<(DimName, usize)>,
                       scores_dshape: Vec<(DimName, usize)>| {
        let sizes =
            |dims: &[(DimName, usize)]| dims.iter().map(|dim| dim.1).collect::<Vec<usize>>();
        DecoderBuilder::default()
            .with_config_yolo_split_det(
                Boxes {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: None,
                    shape: sizes(&boxes_dshape),
                    dshape: boxes_dshape,
                    normalized: Some(true),
                },
                Scores {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: None,
                    shape: sizes(&scores_dshape),
                    dshape: scores_dshape,
                },
            )
            .with_score_threshold(0.5)
            .with_iou_threshold(0.5)
            .with_nms(Some(configs::Nms::ClassAgnostic))
            .build()
    };

    // Reference: features-first declaration on the features-first data.
    let ref_decoder = build_split(
        vec![
            (DimName::Batch, 1),
            (DimName::BoxCoords, 4),
            (DimName::NumBoxes, N),
        ],
        vec![
            (DimName::Batch, 1),
            (DimName::NumClasses, 80),
            (DimName::NumBoxes, N),
        ],
    )
    .expect("reference canonical split decoder must build");

    let mut ref_boxes = Vec::with_capacity(4);
    let mut ref_masks = Vec::new();
    ref_decoder
        .decode_float(
            &[
                boxes_canonical.view().into_dyn(),
                scores_canonical.view().into_dyn(),
            ],
            &mut ref_boxes,
            &mut ref_masks,
        )
        .unwrap();
    assert_eq!(ref_boxes.len(), 1, "reference should produce one box");

    // Same values, physically re-laid-out with the anchor axis first.
    let boxes_ara2 = boxes_canonical.view().permuted_axes([0, 2, 1]).to_owned();
    let scores_ara2 = scores_canonical.view().permuted_axes([0, 2, 1]).to_owned();

    let ara2_decoder = build_split(
        vec![
            (DimName::Batch, 1),
            (DimName::NumBoxes, N),
            (DimName::BoxCoords, 4),
        ],
        vec![
            (DimName::Batch, 1),
            (DimName::NumBoxes, N),
            (DimName::NumClasses, 80),
        ],
    )
    .expect("Ara-2 anchor-first decoder must build");

    let mut ara2_boxes = Vec::with_capacity(4);
    let mut ara2_masks = Vec::new();
    ara2_decoder
        .decode_float(
            &[boxes_ara2.view().into_dyn(), scores_ara2.view().into_dyn()],
            &mut ara2_boxes,
            &mut ara2_masks,
        )
        .unwrap();

    assert_eq!(
        ara2_boxes.len(),
        ref_boxes.len(),
        "Ara-2 anchor-first declaration must produce the same number \
         of boxes as the canonical features-first reference"
    );
    for (a, r) in ara2_boxes.iter().zip(&ref_boxes) {
        assert!(
            a.equal_within_delta(r, 1e-4),
            "Ara-2 box differs from reference: {a:?} vs {r:?}"
        );
    }
}
4093
/// Checks that building a decoder fails with `InvalidConfig` when the protos
/// `shape` and `dshape` disagree on the size of an axis.
#[test]
fn test_physical_order_rejects_shape_dshape_mismatch() {
    let detection = configs::Detection {
        decoder: configs::DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 116, 8400],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumFeatures, 116),
            (DimName::NumBoxes, 8400),
        ],
        normalized: Some(true),
        anchors: None,
    };

    // `shape` has 32 on axis 1 and 160 on axis 3, while `dshape` declares
    // those axes as 160 and 32 respectively.
    let mismatched_protos = configs::Protos {
        decoder: configs::DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 32, 160, 160],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 160),
            (DimName::Width, 160),
            (DimName::NumProtos, 32),
        ],
    };

    let outcome = DecoderBuilder::default()
        .with_config_yolo_segdet(detection, mismatched_protos, None)
        .build();

    match outcome {
        Err(DecoderError::InvalidConfig(msg)) => assert!(
            msg.contains("does not match shape"),
            "expected shape/dshape size mismatch error, got: {msg}"
        ),
        other => panic!("expected InvalidConfig, got {other:?}"),
    }
}
4141
/// Checks that a boxes `dshape` naming the same axis twice is rejected with
/// `InvalidConfig`, in two variants: a duplicated `BoxCoords` axis and a
/// duplicated `Batch` axis.
#[test]
fn test_physical_order_rejects_duplicate_dshape_axis() {
    // Builds a split-detection decoder from the given boxes layout paired
    // with a valid features-first scores output.
    let build_with_boxes = |shape: Vec<usize>, dshape: Vec<(DimName, usize)>| {
        DecoderBuilder::default()
            .with_config_yolo_split_det(
                configs::Boxes {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: None,
                    shape,
                    dshape,
                    normalized: Some(true),
                },
                configs::Scores {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: None,
                    shape: vec![1, 80, 8400],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumClasses, 80),
                        (DimName::NumBoxes, 8400),
                    ],
                },
            )
            .build()
    };

    // Variant 1: `BoxCoords` appears twice. The declared size of the last
    // axis (4) also disagrees with `shape` (8400), so either error text is
    // accepted.
    let duplicate_coords = build_with_boxes(
        vec![1, 4, 8400],
        vec![
            (DimName::Batch, 1),
            (DimName::BoxCoords, 4),
            (DimName::BoxCoords, 4),
        ],
    );
    match duplicate_coords {
        Err(DecoderError::InvalidConfig(msg)) => assert!(
            msg.contains("appears at both index") || msg.contains("does not match shape"),
            "expected positional or duplicate-axis error, got: {msg}"
        ),
        other => panic!("expected InvalidConfig, got {other:?}"),
    }

    // Variant 2: `Batch` appears twice while every declared size agrees with
    // `shape`, so only the duplicate-axis error text is accepted.
    let duplicate_batch = build_with_boxes(
        vec![1, 1, 4, 8400],
        vec![
            (DimName::Batch, 1),
            (DimName::Batch, 1),
            (DimName::BoxCoords, 4),
            (DimName::NumBoxes, 8400),
        ],
    );
    match duplicate_batch {
        Err(DecoderError::InvalidConfig(msg)) => assert!(
            msg.contains("appears at both index"),
            "expected duplicate-axis error, got: {msg}"
        ),
        other => panic!("expected InvalidConfig, got {other:?}"),
    }
}
4226
/// Checks that leaving `dshape` empty on both outputs decodes to the same
/// boxes and masks as spelling the axis names out explicitly.
#[test]
fn test_physical_order_dshape_omitted_decodes_numerically() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    // Dequantize the captured protos (batch axis dropped) and detections.
    let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
    let quant_protos = Quantization::new(0.02491161972284317, -117);
    let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);

    let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
    let quant_boxes = Quantization::new(0.021287761628627777, 31);
    let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);

    // Re-attach the batch axis. The batch-less proto array is not used
    // again, so it is consumed rather than cloned.
    let protos_nhwc = protos_f32_hwc.insert_axis(Axis(0));
    let seg_3d = seg.insert_axis(Axis(0));

    // Builds a seg-det decoder that differs only in the `dshape` passed for
    // each output; the physical shapes are fixed.
    let build_decoder = |det_dshape: Vec<(DimName, usize)>,
                         proto_dshape: Vec<(DimName, usize)>| {
        DecoderBuilder::default()
            .with_config_yolo_segdet(
                configs::Detection {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: None,
                    shape: vec![1, 116, 8400],
                    dshape: det_dshape,
                    normalized: Some(true),
                    anchors: None,
                },
                configs::Protos {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: None,
                    shape: vec![1, 160, 160, 32],
                    dshape: proto_dshape,
                },
                None,
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap()
    };

    // Decoder with every axis named explicitly.
    let dshaped = build_decoder(
        vec![
            (DimName::Batch, 1),
            (DimName::NumFeatures, 116),
            (DimName::NumBoxes, 8400),
        ],
        vec![
            (DimName::Batch, 1),
            (DimName::Height, 160),
            (DimName::Width, 160),
            (DimName::NumProtos, 32),
        ],
    );
    let mut dshaped_boxes = Vec::new();
    let mut dshaped_masks = Vec::new();
    dshaped
        .decode_float(
            &[seg_3d.view().into_dyn(), protos_nhwc.view().into_dyn()],
            &mut dshaped_boxes,
            &mut dshaped_masks,
        )
        .unwrap();

    // Decoder with no axis names at all.
    let bare = build_decoder(vec![], vec![]);
    let mut bare_boxes = Vec::new();
    let mut bare_masks = Vec::new();
    bare.decode_float(
        &[seg_3d.view().into_dyn(), protos_nhwc.view().into_dyn()],
        &mut bare_boxes,
        &mut bare_masks,
    )
    .unwrap();

    assert_eq!(bare_boxes.len(), dshaped_boxes.len());
    for (b, d) in bare_boxes.iter().zip(&dshaped_boxes) {
        assert!(
            b.equal_within_delta(d, 1e-4),
            "dshape-omitted box {b:?} differs from dshape-populated {d:?}"
        );
    }

    // `zip` stops at the shorter input, so without this check a missing or
    // empty mask list would let the pixel comparison pass vacuously.
    assert_eq!(
        bare_masks.len(),
        dshaped_masks.len(),
        "dshape-omitted mask count differs from dshape-populated mask count"
    );
    for (bm, dm) in bare_masks.iter().zip(&dshaped_masks) {
        let bm_arr = segmentation_to_mask(bm.segmentation.view()).unwrap();
        let dm_arr = segmentation_to_mask(dm.segmentation.view()).unwrap();
        assert_eq!(
            bm_arr, dm_arr,
            "dshape-omitted mask must match dshape-populated pixel-for-pixel"
        );
    }
}
4326
/// Builds a decoder from a JSON schema whose outputs are declared 4-D and
/// anchor-first with a size-1 `padding` axis, then decodes 3-D anchor-first
/// arrays and checks the single seeded detection.
#[test]
fn test_physical_order_ara2_4d_anchor_first_with_padding() {
    const N: usize = 8400;
    let target = 42usize;

    // One seeded anchor: four box values plus a 0.9 score for class 0.
    let mut boxes = Array3::<f32>::zeros((1, N, 4));
    for (coord, &value) in [0.4f32, 0.5, 0.2, 0.2].iter().enumerate() {
        boxes[[0, target, coord]] = value;
    }
    let mut scores = Array3::<f32>::zeros((1, N, 80));
    scores[[0, target, 0]] = 0.9;

    let json = r#"{
        "schema_version": 2,
        "decoder_version": "yolov8",
        "nms": "class_agnostic",
        "outputs": [
            {"name": "boxes", "type": "boxes",
             "shape": [1, 8400, 1, 4],
             "dshape": [{"batch":1},{"num_boxes":8400},{"padding":1},{"box_coords":4}],
             "encoding": "direct",
             "decoder": "ultralytics",
             "normalized": true},
            {"name": "scores", "type": "scores",
             "shape": [1, 8400, 1, 80],
             "dshape": [{"batch":1},{"num_boxes":8400},{"padding":1},{"num_classes":80}],
             "decoder": "ultralytics",
             "score_format": "per_class"}
        ]
    }"#;
    let decoder = DecoderBuilder::default()
        .with_config_json_str(json.to_string())
        .with_score_threshold(0.5)
        .with_iou_threshold(0.5)
        .build()
        .expect("4D anchor-first schema should build via squeeze_padding_dims");

    let mut out_boxes = Vec::with_capacity(4);
    let mut out_masks = Vec::new();
    let inputs = [boxes.view().into_dyn(), scores.view().into_dyn()];
    decoder
        .decode_float(&inputs, &mut out_boxes, &mut out_masks)
        .unwrap();

    assert_eq!(
        out_boxes.len(),
        1,
        "4D anchor-first with padding should decode exactly one box from the seeded anchor"
    );

    // Expected corners are (0.3, 0.4)-(0.5, 0.6); this matches reading the
    // seeded values as centre (0.4, 0.5) with size 0.2 x 0.2.
    let b = &out_boxes[0];
    assert!((b.bbox.xmin - 0.3).abs() < 1e-3, "xmin wrong: {b:?}");
    assert!((b.bbox.ymin - 0.4).abs() < 1e-3, "ymin wrong: {b:?}");
    assert!((b.bbox.xmax - 0.5).abs() < 1e-3, "xmax wrong: {b:?}");
    assert!((b.bbox.ymax - 0.6).abs() < 1e-3, "ymax wrong: {b:?}");
    assert_eq!(b.label, 0);
    assert!(b.score > 0.85, "score {}: {b:?}", b.score);
}
4406}
4407
4408#[cfg(feature = "tracker")]
4409#[cfg(test)]
4410#[cfg_attr(coverage_nightly, coverage(off))]
4411mod decoder_tracked_tests {
4412
4413 use edgefirst_tracker::{ByteTrackBuilder, Tracker};
4414 use ndarray::{array, s, Array, Array2, Array3, Array4, ArrayView, Axis, Dimension};
4415 use num_traits::{AsPrimitive, Float, PrimInt};
4416 use rand::{RngExt, SeedableRng};
4417 use rand_distr::StandardNormal;
4418
4419 use crate::{
4420 configs::{self, DimName},
4421 dequantize_ndarray, BoundingBox, DecoderBuilder, DetectBox, Quantization,
4422 };
4423
4424 pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
4425 input: ArrayView<F, D>,
4426 quant: Quantization,
4427 ) -> Array<T, D>
4428 where
4429 i32: num_traits::AsPrimitive<F>,
4430 f32: num_traits::AsPrimitive<F>,
4431 {
4432 let zero_point = quant.zero_point.as_();
4433 let div_scale = F::one() / quant.scale.as_();
4434 if zero_point != F::zero() {
4435 input.mapv(|d| (d * div_scale + zero_point).round().as_())
4436 } else {
4437 input.mapv(|d| (d * div_scale).round().as_())
4438 }
4439 }
4440
/// Runs the tracked quantized decoder on a captured YOLOv8 output, then on
/// 100 further frames in which rows 0 and 1 of the output are perturbed with
/// clamped Gaussian noise, and checks that the two reported boxes stay close
/// to both the expected values and the previous frame.
#[test]
fn test_decoder_tracked_random_jitter() {
    use crate::configs::{DecoderType, Nms};

    let score_threshold = 0.25;
    let iou_threshold = 0.1;

    let raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8s_80_classes.bin"
    ));
    // SAFETY: `u8` and `i8` share size and alignment and every bit pattern
    // is valid for both; pointer and length come from the embedded array.
    let samples = unsafe { std::slice::from_raw_parts(raw.as_ptr().cast::<i8>(), raw.len()) };
    let out = Array3::from_shape_vec((1, 84, 8400), samples.to_vec()).unwrap();
    let quant = (0.0040811873, -123).into();

    let detection = configs::Detection {
        decoder: DecoderType::Ultralytics,
        shape: vec![1, 84, 8400],
        anchors: None,
        quantization: Some(quant),
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumFeatures, 84),
            (DimName::NumBoxes, 8400),
        ],
        normalized: Some(true),
    };
    let decoder = DecoderBuilder::default()
        .with_config_yolo_det(detection, None)
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .with_nms(Some(Nms::ClassAgnostic))
        .build()
        .unwrap();

    // Fixed seed keeps the noise sequence reproducible between runs.
    let mut rng = rand::rngs::StdRng::seed_from_u64(0xAB_BEEF);

    let expected_boxes = [
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.5285137,
                ymin: 0.05305544,
                xmax: 0.87541467,
                ymax: 0.9998909,
            },
            score: 0.5591227,
            label: 0,
        },
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.130598,
                ymin: 0.43260583,
                xmax: 0.35098213,
                ymax: 0.9958097,
            },
            score: 0.33057618,
            label: 75,
        },
    ];

    let mut tracker = ByteTrackBuilder::new()
        .track_update(0.1)
        .track_high_conf(0.3)
        .build();

    let mut output_boxes = Vec::with_capacity(50);
    let mut output_masks = Vec::with_capacity(50);
    let mut output_tracks = Vec::with_capacity(50);

    // First frame: unmodified data must reproduce the expected boxes.
    decoder
        .decode_tracked_quantized(
            &mut tracker,
            0,
            &[out.view().into()],
            &mut output_boxes,
            &mut output_masks,
            &mut output_tracks,
        )
        .unwrap();

    assert_eq!(output_boxes.len(), 2);
    assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
    assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));

    let mut last_boxes = output_boxes.clone();

    for frame in 1..=100 {
        // Every frame starts again from the pristine capture.
        let mut jittered = out.clone();

        // Perturb all of row 0, then all of row 1, so the RNG draw order is
        // row-major. NOTE(review): rows 0 and 1 are presumably the box
        // centre x and y values; confirm against the model layout.
        for row in 0..2usize {
            for cell in jittered.slice_mut(s![0, row, ..]).iter_mut() {
                let noise: f32 = rng.sample(StandardNormal);
                let noise = noise.clamp(-2.0, 2.0) / 2.0;
                *cell = cell.saturating_add((noise * 1e-2 / quant.0) as i8);
            }
        }

        // NOTE(review): timestamps advance by 100_000_000 / 3 per frame,
        // presumably nanoseconds at about 30 fps; confirm the unit.
        decoder
            .decode_tracked_quantized(
                &mut tracker,
                100_000_000 * frame / 3,
                &[jittered.view().into()],
                &mut output_boxes,
                &mut output_masks,
                &mut output_tracks,
            )
            .unwrap();

        assert_eq!(output_boxes.len(), 2);
        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 5e-3));
        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 5e-3));

        // Frame-to-frame movement must be tighter than the absolute drift.
        assert!(output_boxes[0].equal_within_delta(&last_boxes[0], 1e-3));
        assert!(output_boxes[1].equal_within_delta(&last_boxes[1], 1e-3));
        last_boxes = output_boxes.clone();
    }
}
4565
4566 fn real_data_expected_boxes() -> [DetectBox; 2] {
4569 [
4570 DetectBox {
4571 bbox: BoundingBox {
4572 xmin: 0.08515105,
4573 ymin: 0.7131401,
4574 xmax: 0.29802868,
4575 ymax: 0.8195788,
4576 },
4577 score: 0.91537374,
4578 label: 23,
4579 },
4580 DetectBox {
4581 bbox: BoundingBox {
4582 xmin: 0.59605736,
4583 ymin: 0.25545314,
4584 xmax: 0.93666154,
4585 ymax: 0.72378385,
4586 },
4587 score: 0.91537374,
4588 label: 23,
4589 },
4590 ]
4591 }
4592
4593 fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
4594 [DetectBox {
4595 bbox: BoundingBox {
4596 xmin: 0.12549022,
4597 ymin: 0.12549022,
4598 xmax: 0.23529413,
4599 ymax: 0.23529413,
4600 },
4601 score: 0.98823535,
4602 label: 2,
4603 }]
4604 }
4605
4606 fn e2e_expected_boxes_float() -> [DetectBox; 1] {
4607 [DetectBox {
4608 bbox: BoundingBox {
4609 xmin: 0.1234,
4610 ymin: 0.1234,
4611 xmax: 0.2345,
4612 ymax: 0.2345,
4613 },
4614 score: 0.9876,
4615 label: 2,
4616 }]
4617 }
4618
4619 fn build_yolo_split_segdet_decoder(
4620 score_threshold: f32,
4621 iou_threshold: f32,
4622 quant_boxes: (f32, i32),
4623 quant_protos: (f32, i32),
4624 ) -> crate::Decoder {
4625 DecoderBuilder::default()
4626 .with_config_yolo_split_segdet(
4627 configs::Boxes {
4628 decoder: configs::DecoderType::Ultralytics,
4629 quantization: Some(quant_boxes.into()),
4630 shape: vec![1, 4, 8400],
4631 dshape: vec![
4632 (DimName::Batch, 1),
4633 (DimName::BoxCoords, 4),
4634 (DimName::NumBoxes, 8400),
4635 ],
4636 normalized: Some(true),
4637 },
4638 configs::Scores {
4639 decoder: configs::DecoderType::Ultralytics,
4640 quantization: Some(quant_boxes.into()),
4641 shape: vec![1, 80, 8400],
4642 dshape: vec![
4643 (DimName::Batch, 1),
4644 (DimName::NumClasses, 80),
4645 (DimName::NumBoxes, 8400),
4646 ],
4647 },
4648 configs::MaskCoefficients {
4649 decoder: configs::DecoderType::Ultralytics,
4650 quantization: Some(quant_boxes.into()),
4651 shape: vec![1, 32, 8400],
4652 dshape: vec![
4653 (DimName::Batch, 1),
4654 (DimName::NumProtos, 32),
4655 (DimName::NumBoxes, 8400),
4656 ],
4657 },
4658 configs::Protos {
4659 decoder: configs::DecoderType::Ultralytics,
4660 quantization: Some(quant_protos.into()),
4661 shape: vec![1, 160, 160, 32],
4662 dshape: vec![
4663 (DimName::Batch, 1),
4664 (DimName::Height, 160),
4665 (DimName::Width, 160),
4666 (DimName::NumProtos, 32),
4667 ],
4668 },
4669 )
4670 .with_score_threshold(score_threshold)
4671 .with_iou_threshold(iou_threshold)
4672 .build()
4673 .unwrap()
4674 }
4675
4676 fn build_yolov8_seg_decoder(score_threshold: f32, iou_threshold: f32) -> crate::Decoder {
4677 let config_yaml = include_str!(concat!(
4678 env!("CARGO_MANIFEST_DIR"),
4679 "/../../testdata/yolov8_seg.yaml"
4680 ));
4681 DecoderBuilder::default()
4682 .with_config_yaml_str(config_yaml.to_string())
4683 .with_score_threshold(score_threshold)
4684 .with_iou_threshold(iou_threshold)
4685 .build()
4686 .unwrap()
4687 }
4688
4689 macro_rules! real_data_tracked_test {
4696 ($name:ident, quantized, $layout:ident, $output:ident) => {
4697 #[test]
4698 fn $name() {
4699 let is_split = matches!(stringify!($layout), "split");
4700 let is_proto = matches!(stringify!($output), "proto");
4701
4702 let score_threshold = 0.45;
4703 let iou_threshold = 0.45;
4704 let quant_boxes = (0.021287762_f32, 31_i32);
4705 let quant_protos = (0.02491162_f32, -117_i32);
4706
4707 let raw_boxes = include_bytes!(concat!(
4708 env!("CARGO_MANIFEST_DIR"),
4709 "/../../testdata/yolov8_boxes_116x8400.bin"
4710 ));
4711 let raw_boxes = unsafe {
4712 std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
4713 };
4714 let boxes_i8 =
4715 ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
4716
4717 let raw_protos = include_bytes!(concat!(
4718 env!("CARGO_MANIFEST_DIR"),
4719 "/../../testdata/yolov8_protos_160x160x32.bin"
4720 ));
4721 let raw_protos = unsafe {
4722 std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
4723 };
4724 let protos_i8 =
4725 ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
4726 .unwrap();
4727
4728 let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
4730 let mut scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
4731 let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
4732 let mut boxes_combined = boxes_i8;
4733
4734 let decoder = if is_split {
4735 build_yolo_split_segdet_decoder(
4736 score_threshold,
4737 iou_threshold,
4738 quant_boxes,
4739 quant_protos,
4740 )
4741 } else {
4742 build_yolov8_seg_decoder(score_threshold, iou_threshold)
4743 };
4744
4745 let expected = real_data_expected_boxes();
4746 let mut tracker = ByteTrackBuilder::new()
4747 .track_update(0.1)
4748 .track_high_conf(0.7)
4749 .build();
4750 let mut output_boxes = Vec::with_capacity(50);
4751 let mut output_tracks = Vec::with_capacity(50);
4752
4753 if is_proto {
4755 {
4756 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4757 vec![
4758 boxes_split.view().into(),
4759 scores_split.view().into(),
4760 mask_split.view().into(),
4761 protos_i8.view().into(),
4762 ]
4763 } else {
4764 vec![boxes_combined.view().into(), protos_i8.view().into()]
4765 };
4766 decoder
4767 .decode_tracked_quantized_proto(
4768 &mut tracker,
4769 0,
4770 &inputs,
4771 &mut output_boxes,
4772 &mut output_tracks,
4773 )
4774 .unwrap();
4775 }
4776 assert_eq!(output_boxes.len(), 2);
4777 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4778 assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4779
4780 if is_split {
4782 for score in scores_split.iter_mut() {
4783 *score = i8::MIN;
4784 }
4785 } else {
4786 for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4787 *score = i8::MIN;
4788 }
4789 }
4790
4791 let proto_result = {
4792 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4793 vec![
4794 boxes_split.view().into(),
4795 scores_split.view().into(),
4796 mask_split.view().into(),
4797 protos_i8.view().into(),
4798 ]
4799 } else {
4800 vec![boxes_combined.view().into(), protos_i8.view().into()]
4801 };
4802 decoder
4803 .decode_tracked_quantized_proto(
4804 &mut tracker,
4805 100_000_000 / 3,
4806 &inputs,
4807 &mut output_boxes,
4808 &mut output_tracks,
4809 )
4810 .unwrap()
4811 };
4812 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4813 assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4814 assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
4815 } else {
4816 let mut output_masks = Vec::with_capacity(50);
4817 {
4818 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4819 vec![
4820 boxes_split.view().into(),
4821 scores_split.view().into(),
4822 mask_split.view().into(),
4823 protos_i8.view().into(),
4824 ]
4825 } else {
4826 vec![boxes_combined.view().into(), protos_i8.view().into()]
4827 };
4828 decoder
4829 .decode_tracked_quantized(
4830 &mut tracker,
4831 0,
4832 &inputs,
4833 &mut output_boxes,
4834 &mut output_masks,
4835 &mut output_tracks,
4836 )
4837 .unwrap();
4838 }
4839 assert_eq!(output_boxes.len(), 2);
4840 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4841 assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4842
4843 if is_split {
4844 for score in scores_split.iter_mut() {
4845 *score = i8::MIN;
4846 }
4847 } else {
4848 for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4849 *score = i8::MIN;
4850 }
4851 }
4852
4853 {
4854 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4855 vec![
4856 boxes_split.view().into(),
4857 scores_split.view().into(),
4858 mask_split.view().into(),
4859 protos_i8.view().into(),
4860 ]
4861 } else {
4862 vec![boxes_combined.view().into(), protos_i8.view().into()]
4863 };
4864 decoder
4865 .decode_tracked_quantized(
4866 &mut tracker,
4867 100_000_000 / 3,
4868 &inputs,
4869 &mut output_boxes,
4870 &mut output_masks,
4871 &mut output_tracks,
4872 )
4873 .unwrap();
4874 }
4875 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4876 assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4877 assert!(output_masks.is_empty());
4878 }
4879 }
4880 };
4881 ($name:ident, float, $layout:ident, $output:ident) => {
4882 #[test]
4883 fn $name() {
4884 let is_split = matches!(stringify!($layout), "split");
4885 let is_proto = matches!(stringify!($output), "proto");
4886
4887 let score_threshold = 0.45;
4888 let iou_threshold = 0.45;
4889 let quant_boxes = (0.021287762_f32, 31_i32);
4890 let quant_protos = (0.02491162_f32, -117_i32);
4891
4892 let raw_boxes = include_bytes!(concat!(
4893 env!("CARGO_MANIFEST_DIR"),
4894 "/../../testdata/yolov8_boxes_116x8400.bin"
4895 ));
4896 let raw_boxes = unsafe {
4897 std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
4898 };
4899 let boxes_i8 =
4900 ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
4901 let boxes_f32 = dequantize_ndarray(boxes_i8.view(), quant_boxes.into());
4902
4903 let raw_protos = include_bytes!(concat!(
4904 env!("CARGO_MANIFEST_DIR"),
4905 "/../../testdata/yolov8_protos_160x160x32.bin"
4906 ));
4907 let raw_protos = unsafe {
4908 std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
4909 };
4910 let protos_i8 =
4911 ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
4912 .unwrap();
4913 let protos_f32 = dequantize_ndarray(protos_i8.view(), quant_protos.into());
4914
4915 let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
4917 let mut scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
4918 let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
4919 let mut boxes_combined = boxes_f32;
4920
4921 let decoder = if is_split {
4922 build_yolo_split_segdet_decoder(
4923 score_threshold,
4924 iou_threshold,
4925 quant_boxes,
4926 quant_protos,
4927 )
4928 } else {
4929 build_yolov8_seg_decoder(score_threshold, iou_threshold)
4930 };
4931
4932 let expected = real_data_expected_boxes();
4933 let mut tracker = ByteTrackBuilder::new()
4934 .track_update(0.1)
4935 .track_high_conf(0.7)
4936 .build();
4937 let mut output_boxes = Vec::with_capacity(50);
4938 let mut output_tracks = Vec::with_capacity(50);
4939
4940 if is_proto {
4941 {
4942 let inputs = if is_split {
4943 vec![
4944 boxes_split.view().into_dyn(),
4945 scores_split.view().into_dyn(),
4946 mask_split.view().into_dyn(),
4947 protos_f32.view().into_dyn(),
4948 ]
4949 } else {
4950 vec![
4951 boxes_combined.view().into_dyn(),
4952 protos_f32.view().into_dyn(),
4953 ]
4954 };
4955 decoder
4956 .decode_tracked_float_proto(
4957 &mut tracker,
4958 0,
4959 &inputs,
4960 &mut output_boxes,
4961 &mut output_tracks,
4962 )
4963 .unwrap();
4964 }
4965 assert_eq!(output_boxes.len(), 2);
4966 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4967 assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4968
4969 if is_split {
4970 for score in scores_split.iter_mut() {
4971 *score = 0.0;
4972 }
4973 } else {
4974 for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4975 *score = 0.0;
4976 }
4977 }
4978
4979 let proto_result = {
4980 let inputs = if is_split {
4981 vec![
4982 boxes_split.view().into_dyn(),
4983 scores_split.view().into_dyn(),
4984 mask_split.view().into_dyn(),
4985 protos_f32.view().into_dyn(),
4986 ]
4987 } else {
4988 vec![
4989 boxes_combined.view().into_dyn(),
4990 protos_f32.view().into_dyn(),
4991 ]
4992 };
4993 decoder
4994 .decode_tracked_float_proto(
4995 &mut tracker,
4996 100_000_000 / 3,
4997 &inputs,
4998 &mut output_boxes,
4999 &mut output_tracks,
5000 )
5001 .unwrap()
5002 };
5003 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5004 assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
5005 assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5006 } else {
5007 let mut output_masks = Vec::with_capacity(50);
5008 {
5009 let inputs = if is_split {
5010 vec![
5011 boxes_split.view().into_dyn(),
5012 scores_split.view().into_dyn(),
5013 mask_split.view().into_dyn(),
5014 protos_f32.view().into_dyn(),
5015 ]
5016 } else {
5017 vec![
5018 boxes_combined.view().into_dyn(),
5019 protos_f32.view().into_dyn(),
5020 ]
5021 };
5022 decoder
5023 .decode_tracked_float(
5024 &mut tracker,
5025 0,
5026 &inputs,
5027 &mut output_boxes,
5028 &mut output_masks,
5029 &mut output_tracks,
5030 )
5031 .unwrap();
5032 }
5033 assert_eq!(output_boxes.len(), 2);
5034 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5035 assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
5036
5037 if is_split {
5038 for score in scores_split.iter_mut() {
5039 *score = 0.0;
5040 }
5041 } else {
5042 for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
5043 *score = 0.0;
5044 }
5045 }
5046
5047 {
5048 let inputs = if is_split {
5049 vec![
5050 boxes_split.view().into_dyn(),
5051 scores_split.view().into_dyn(),
5052 mask_split.view().into_dyn(),
5053 protos_f32.view().into_dyn(),
5054 ]
5055 } else {
5056 vec![
5057 boxes_combined.view().into_dyn(),
5058 protos_f32.view().into_dyn(),
5059 ]
5060 };
5061 decoder
5062 .decode_tracked_float(
5063 &mut tracker,
5064 100_000_000 / 3,
5065 &inputs,
5066 &mut output_boxes,
5067 &mut output_masks,
5068 &mut output_tracks,
5069 )
5070 .unwrap();
5071 }
5072 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5073 assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
5074 assert!(output_masks.is_empty());
5075 }
5076 }
5077 };
5078 }
5079
// Instantiate the real-data tracked segmentation tests: one generated test per
// combination of dtype (quantized / float), layout (combined / split) and
// output (masks / proto) selector.
real_data_tracked_test!(test_decoder_tracked_segdet, quantized, combined, masks);
real_data_tracked_test!(test_decoder_tracked_segdet_float, float, combined, masks);
real_data_tracked_test!(
    test_decoder_tracked_segdet_proto,
    quantized,
    combined,
    proto
);
real_data_tracked_test!(
    test_decoder_tracked_segdet_proto_float,
    float,
    combined,
    proto
);
real_data_tracked_test!(test_decoder_tracked_segdet_split, quantized, split, masks);
real_data_tracked_test!(test_decoder_tracked_segdet_split_float, float, split, masks);
real_data_tracked_test!(
    test_decoder_tracked_segdet_split_proto,
    quantized,
    split,
    proto
);
real_data_tracked_test!(
    test_decoder_tracked_segdet_split_proto_float,
    float,
    split,
    proto
);
5108
/// YAML decoder configuration for the "combined" end-to-end test layout.
///
/// Declares two outputs: a single `detection` tensor of shape `[1, 10, 38]`
/// and a `protos` tensor of shape `[1, 160, 160, 32]`. The tests that use this
/// config build the 38 features per box by concatenating 4 box coordinates,
/// 1 score, 1 class and 32 mask coefficients.
// NOTE(review): each `quantization` pair looks like `[scale, zero_point]` — confirm
// against the config schema.
const E2E_COMBINED_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 38]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 38]
   normalized: true
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";
5136
/// YAML decoder configuration for the "split" end-to-end test layout.
///
/// Declares five outputs, in this order: `boxes` (`[1, 10, 4]`), `scores`
/// (`[1, 10, 1]`), `classes` (`[1, 10, 1]`), `mask_coefficients`
/// (`[1, 10, 32]`) and `protos` (`[1, 160, 160, 32]`). The tests that use this
/// config pass their input tensors in the same order.
// NOTE(review): each `quantization` pair looks like `[scale, zero_point]` — confirm
// against the config schema.
const E2E_SPLIT_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: boxes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 4]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [box_coords, 4]
   normalized: true
 - type: scores
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: classes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: mask_coefficients
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 32]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_protos, 32]
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";
5183
/// Generates a `#[test]` that runs tracked decoding of synthetic end-to-end
/// (`decoder_version: yolo26`) model outputs through the ndarray-view API.
///
/// Invocation: `e2e_tracked_test!(test_name, <dtype>, <layout>, <output>)`
///
/// * `<dtype>`: `quantized` feeds `u8` tensors to
///   `decode_tracked_quantized[_proto]`; `float` feeds `f64` tensors to
///   `decode_tracked_float[_proto]`.
/// * `<layout>`: `combined` uses `E2E_COMBINED_CONFIG` with one `[1, 10, 38]`
///   detection tensor; `split` uses `E2E_SPLIT_CONFIG` with separate boxes,
///   scores, classes and mask-coefficient tensors.
/// * `<output>`: `masks` decodes into an output mask vector; `proto` calls the
///   `_proto` decode variant and inspects the returned proto data.
///
/// Any other `<layout>` / `<output>` identifier makes the generated test panic
/// instead of silently falling back to `combined` / `masks`.
///
/// Every generated test decodes two frames with the same tracker:
/// 1. timestamp `0`, with one box above the score threshold; exactly one box
///    is expected, within `1/160` of the reference box.
/// 2. timestamp `100_000_000 / 3`, with every score zeroed; the first output
///    box must still match the reference within `1e-6`, while no masks
///    (`masks`) or zero mask-coefficient rows (`proto`) are returned.
macro_rules! e2e_tracked_test {
    ($name:ident, quantized, $layout:ident, $output:ident) => {
        #[test]
        fn $name() {
            // Validate the selectors so a typo cannot quietly run another variant.
            let is_split = match stringify!($layout) {
                "split" => true,
                "combined" => false,
                other => panic!(
                    "unknown layout selector `{}`; expected `combined` or `split`",
                    other
                ),
            };
            let is_proto = match stringify!($output) {
                "proto" => true,
                "masks" => false,
                other => panic!(
                    "unknown output selector `{}`; expected `masks` or `proto`",
                    other
                ),
            };

            let score_threshold = 0.45;
            let iou_threshold = 0.45;

            // Synthetic model output: ten candidate rows, only row 0 populated.
            let mut boxes = Array2::zeros((10, 4));
            let mut scores = Array2::zeros((10, 1));
            let mut classes = Array2::zeros((10, 1));
            let mask = Array2::zeros((10, 32));
            let protos = Array3::<f64>::zeros((160, 160, 32));
            let protos = protos.insert_axis(Axis(0));
            let protos_quant = (1.0 / 255.0, 0.0);
            let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());

            boxes
                .slice_mut(s![0, ..])
                .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
            scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
            classes.slice_mut(s![0, ..]).assign(&array![2.0]);

            let detect_quant = (2.0 / 255.0, 0.0);

            let config = if is_split {
                E2E_SPLIT_CONFIG
            } else {
                E2E_COMBINED_CONFIG
            };
            let decoder = DecoderBuilder::default()
                .with_config_yaml_str(config.to_string())
                .with_score_threshold(score_threshold)
                .with_iou_threshold(iou_threshold)
                .build()
                .unwrap();

            let expected = e2e_expected_boxes_quant();
            let mut tracker = ByteTrackBuilder::new()
                .track_update(0.1)
                .track_high_conf(0.7)
                .build();
            let mut output_boxes = Vec::with_capacity(50);
            let mut output_masks = Vec::with_capacity(50);
            let mut output_tracks = Vec::with_capacity(50);

            // Quantized detection tensors in the order the selected config lists them.
            // Split: [boxes, scores, classes, mask coefficients]; combined: [detect].
            let mut detection: Vec<Array3<u8>> = if is_split {
                let boxes = boxes.insert_axis(Axis(0));
                let scores = scores.insert_axis(Axis(0));
                let classes = classes.insert_axis(Axis(0));
                let mask = mask.insert_axis(Axis(0));
                vec![
                    quantize_ndarray(boxes.view(), detect_quant.into()),
                    quantize_ndarray(scores.view(), detect_quant.into()),
                    quantize_ndarray(classes.view(), detect_quant.into()),
                    quantize_ndarray(mask.view(), detect_quant.into()),
                ]
            } else {
                let detect = ndarray::concatenate![
                    Axis(1),
                    boxes.view(),
                    scores.view(),
                    classes.view(),
                    mask.view()
                ];
                let detect = detect.insert_axis(Axis(0));
                assert_eq!(detect.shape(), &[1, 10, 38]);
                vec![quantize_ndarray(detect.view(), detect_quant.into())]
            };

            // Frame 0: the populated row is above the score threshold.
            {
                // The views borrow `detection`, so they are scoped to end before
                // the scores are mutated below.
                let mut inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
                    detection.iter().map(|t| t.view().into()).collect();
                inputs.push(protos.view().into());
                if is_proto {
                    decoder
                        .decode_tracked_quantized_proto(
                            &mut tracker,
                            0,
                            &inputs,
                            &mut output_boxes,
                            &mut output_tracks,
                        )
                        .unwrap();
                } else {
                    decoder
                        .decode_tracked_quantized(
                            &mut tracker,
                            0,
                            &inputs,
                            &mut output_boxes,
                            &mut output_masks,
                            &mut output_tracks,
                        )
                        .unwrap();
                }
            }
            assert_eq!(output_boxes.len(), 1);
            assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));

            // Zero every score: the whole scores tensor when split, feature
            // column 4 of the detection tensor when combined.
            if is_split {
                detection[1].fill(u8::MIN);
            } else {
                detection[0].slice_mut(s![.., .., 4]).fill(u8::MIN);
            }

            // Frame 1: no scores above threshold; the box must still be reported.
            {
                let mut inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
                    detection.iter().map(|t| t.view().into()).collect();
                inputs.push(protos.view().into());
                if is_proto {
                    let proto_result = decoder
                        .decode_tracked_quantized_proto(
                            &mut tracker,
                            100_000_000 / 3,
                            &inputs,
                            &mut output_boxes,
                            &mut output_tracks,
                        )
                        .unwrap();
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
                } else {
                    decoder
                        .decode_tracked_quantized(
                            &mut tracker,
                            100_000_000 / 3,
                            &inputs,
                            &mut output_boxes,
                            &mut output_masks,
                            &mut output_tracks,
                        )
                        .unwrap();
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                    assert!(output_masks.is_empty());
                }
            }
        }
    };
    ($name:ident, float, $layout:ident, $output:ident) => {
        #[test]
        fn $name() {
            // Validate the selectors so a typo cannot quietly run another variant.
            let is_split = match stringify!($layout) {
                "split" => true,
                "combined" => false,
                other => panic!(
                    "unknown layout selector `{}`; expected `combined` or `split`",
                    other
                ),
            };
            let is_proto = match stringify!($output) {
                "proto" => true,
                "masks" => false,
                other => panic!(
                    "unknown output selector `{}`; expected `masks` or `proto`",
                    other
                ),
            };

            let score_threshold = 0.45;
            let iou_threshold = 0.45;

            // Synthetic model output: ten candidate rows, only row 0 populated.
            let mut boxes = Array2::zeros((10, 4));
            let mut scores = Array2::zeros((10, 1));
            let mut classes = Array2::zeros((10, 1));
            let mask: Array2<f64> = Array2::zeros((10, 32));
            let protos = Array3::<f64>::zeros((160, 160, 32));
            let protos = protos.insert_axis(Axis(0));

            boxes
                .slice_mut(s![0, ..])
                .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
            scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
            classes.slice_mut(s![0, ..]).assign(&array![2.0]);

            let config = if is_split {
                E2E_SPLIT_CONFIG
            } else {
                E2E_COMBINED_CONFIG
            };
            let decoder = DecoderBuilder::default()
                .with_config_yaml_str(config.to_string())
                .with_score_threshold(score_threshold)
                .with_iou_threshold(iou_threshold)
                .build()
                .unwrap();

            let expected = e2e_expected_boxes_float();
            let mut tracker = ByteTrackBuilder::new()
                .track_update(0.1)
                .track_high_conf(0.7)
                .build();
            let mut output_boxes = Vec::with_capacity(50);
            let mut output_masks = Vec::with_capacity(50);
            let mut output_tracks = Vec::with_capacity(50);

            // Detection tensors in the order the selected config lists them.
            // Split: [boxes, scores, classes, mask coefficients]; combined: [detect].
            let mut detection: Vec<Array3<f64>> = if is_split {
                vec![
                    boxes.insert_axis(Axis(0)),
                    scores.insert_axis(Axis(0)),
                    classes.insert_axis(Axis(0)),
                    mask.insert_axis(Axis(0)),
                ]
            } else {
                let detect = ndarray::concatenate![
                    Axis(1),
                    boxes.view(),
                    scores.view(),
                    classes.view(),
                    mask.view()
                ];
                let detect = detect.insert_axis(Axis(0));
                assert_eq!(detect.shape(), &[1, 10, 38]);
                vec![detect]
            };

            // Frame 0: the populated row is above the score threshold.
            {
                // The views borrow `detection`, so they are scoped to end before
                // the scores are mutated below.
                let mut inputs: Vec<_> = detection.iter().map(|t| t.view().into_dyn()).collect();
                inputs.push(protos.view().into_dyn());
                if is_proto {
                    decoder
                        .decode_tracked_float_proto(
                            &mut tracker,
                            0,
                            &inputs,
                            &mut output_boxes,
                            &mut output_tracks,
                        )
                        .unwrap();
                } else {
                    decoder
                        .decode_tracked_float(
                            &mut tracker,
                            0,
                            &inputs,
                            &mut output_boxes,
                            &mut output_masks,
                            &mut output_tracks,
                        )
                        .unwrap();
                }
            }
            assert_eq!(output_boxes.len(), 1);
            assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));

            // Zero every score: the whole scores tensor when split, feature
            // column 4 of the detection tensor when combined.
            if is_split {
                detection[1].fill(0.0);
            } else {
                detection[0].slice_mut(s![.., .., 4]).fill(0.0);
            }

            // Frame 1: no scores above threshold; the box must still be reported.
            {
                let mut inputs: Vec<_> = detection.iter().map(|t| t.view().into_dyn()).collect();
                inputs.push(protos.view().into_dyn());
                if is_proto {
                    let proto_result = decoder
                        .decode_tracked_float_proto(
                            &mut tracker,
                            100_000_000 / 3,
                            &inputs,
                            &mut output_boxes,
                            &mut output_tracks,
                        )
                        .unwrap();
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
                } else {
                    decoder
                        .decode_tracked_float(
                            &mut tracker,
                            100_000_000 / 3,
                            &inputs,
                            &mut output_boxes,
                            &mut output_masks,
                            &mut output_tracks,
                        )
                        .unwrap();
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                    assert!(output_masks.is_empty());
                }
            }
        }
    };
}
5664
// Instantiate the end-to-end tracked tests (ndarray-view API): one generated
// test per combination of dtype (quantized / float), layout (combined / split)
// and output (masks / proto) selector.
e2e_tracked_test!(
    test_decoder_tracked_end_to_end_segdet,
    quantized,
    combined,
    masks
);
e2e_tracked_test!(
    test_decoder_tracked_end_to_end_segdet_float,
    float,
    combined,
    masks
);
e2e_tracked_test!(
    test_decoder_tracked_end_to_end_segdet_proto,
    quantized,
    combined,
    proto
);
e2e_tracked_test!(
    test_decoder_tracked_end_to_end_segdet_proto_float,
    float,
    combined,
    proto
);
e2e_tracked_test!(
    test_decoder_tracked_end_to_end_segdet_split,
    quantized,
    split,
    masks
);
e2e_tracked_test!(
    test_decoder_tracked_end_to_end_segdet_split_float,
    float,
    split,
    masks
);
e2e_tracked_test!(
    test_decoder_tracked_end_to_end_segdet_split_proto,
    quantized,
    split,
    proto
);
e2e_tracked_test!(
    test_decoder_tracked_end_to_end_segdet_split_proto_float,
    float,
    split,
    proto
);
5713
/// Generates a `#[test]` that runs tracked decoding of synthetic end-to-end
/// (`decoder_version: yolo26`) model outputs through the tensor API
/// (`decode_tracked` / `decode_proto_tracked` on `edgefirst_tensor::TensorDyn`).
///
/// Invocation: `e2e_tracked_tensor_test!(test_name, <dtype>, <layout>, <output>)`
///
/// * `<dtype>`: `quantized` wraps `u8` data in the input tensors; `float`
///   wraps `f64` data.
/// * `<layout>`: `combined` uses `E2E_COMBINED_CONFIG` with one `[1, 10, 38]`
///   detection tensor; `split` uses `E2E_SPLIT_CONFIG` with separate boxes,
///   scores, classes and mask-coefficient tensors.
/// * `<output>`: `masks` calls `decode_tracked`; `proto` calls
///   `decode_proto_tracked` and inspects the returned proto data.
///
/// Any other `<layout>` / `<output>` identifier makes the generated test panic
/// instead of silently falling back to `combined` / `masks`.
///
/// Every generated test decodes two frames with the same tracker:
/// 1. timestamp `0`, with one box above the score threshold; exactly one box
///    is expected, within `1/160` of the reference box.
/// 2. timestamp `100_000_000 / 3`, with every score zeroed; the first output
///    box must still match the reference within `1e-6`, while no masks
///    (`masks`) or zero mask-coefficient rows (`proto`) are returned.
macro_rules! e2e_tracked_tensor_test {
    ($name:ident, quantized, $layout:ident, $output:ident) => {
        #[test]
        fn $name() {
            use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};

            // Validate the selectors so a typo cannot quietly run another variant.
            let is_split = match stringify!($layout) {
                "split" => true,
                "combined" => false,
                other => panic!(
                    "unknown layout selector `{}`; expected `combined` or `split`",
                    other
                ),
            };
            let is_proto = match stringify!($output) {
                "proto" => true,
                "masks" => false,
                other => panic!(
                    "unknown output selector `{}`; expected `masks` or `proto`",
                    other
                ),
            };

            let score_threshold = 0.45;
            let iou_threshold = 0.45;

            // Synthetic model output: ten candidate rows, only row 0 populated.
            let mut boxes = Array2::zeros((10, 4));
            let mut scores = Array2::zeros((10, 1));
            let mut classes = Array2::zeros((10, 1));
            let mask = Array2::zeros((10, 32));
            let protos_f64 = Array3::<f64>::zeros((160, 160, 32));
            let protos_f64 = protos_f64.insert_axis(Axis(0));
            let protos_quant = (1.0 / 255.0, 0.0);
            let protos_u8: Array4<u8> = quantize_ndarray(protos_f64.view(), protos_quant.into());

            boxes
                .slice_mut(s![0, ..])
                .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
            scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
            classes.slice_mut(s![0, ..]).assign(&array![2.0]);

            let detect_quant = (2.0 / 255.0, 0.0);

            let config = if is_split {
                E2E_SPLIT_CONFIG
            } else {
                E2E_COMBINED_CONFIG
            };
            let decoder = DecoderBuilder::default()
                .with_config_yaml_str(config.to_string())
                .with_score_threshold(score_threshold)
                .with_iou_threshold(iou_threshold)
                .build()
                .unwrap();

            // Allocates a `u8` tensor of `shape` and copies `data` into its mapping.
            let make_u8_tensor = |shape: &[usize], data: &[u8]| -> edgefirst_tensor::TensorDyn {
                let t = Tensor::<u8>::new(shape, None, None).unwrap();
                t.map().unwrap().as_mut_slice()[..data.len()].copy_from_slice(data);
                t.into()
            };

            let expected = e2e_expected_boxes_quant();
            let mut tracker = ByteTrackBuilder::new()
                .track_update(0.1)
                .track_high_conf(0.7)
                .build();
            let mut output_boxes = Vec::with_capacity(50);
            let mut output_masks = Vec::with_capacity(50);
            let mut output_tracks = Vec::with_capacity(50);

            let protos_td = make_u8_tensor(protos_u8.shape(), protos_u8.as_slice().unwrap());

            // Quantized detection arrays in the order the selected config lists them.
            // Split: [boxes, scores, classes, mask coefficients]; combined: [detect].
            let mut detection_q: Vec<Array3<u8>> = if is_split {
                let boxes = boxes.insert_axis(Axis(0));
                let scores = scores.insert_axis(Axis(0));
                let classes = classes.insert_axis(Axis(0));
                let mask = mask.insert_axis(Axis(0));
                vec![
                    quantize_ndarray(boxes.view(), detect_quant.into()),
                    quantize_ndarray(scores.view(), detect_quant.into()),
                    quantize_ndarray(classes.view(), detect_quant.into()),
                    quantize_ndarray(mask.view(), detect_quant.into()),
                ]
            } else {
                let detect = ndarray::concatenate![
                    Axis(1),
                    boxes.view(),
                    scores.view(),
                    classes.view(),
                    mask.view()
                ];
                let detect = detect.insert_axis(Axis(0));
                assert_eq!(detect.shape(), &[1, 10, 38]);
                // Rebuild the array from its elements in logical order; presumably
                // this guarantees the contiguous layout `as_slice()` needs below.
                let detect =
                    Array3::from_shape_vec(detect.raw_dim(), detect.iter().copied().collect())
                        .unwrap();
                vec![quantize_ndarray(detect.view(), detect_quant.into())]
            };
            // Index of the tensor carrying the scores: `scores` itself when split,
            // the single detection tensor when combined.
            let score_tensor = if is_split { 1 } else { 0 };

            let mut detection_td: Vec<edgefirst_tensor::TensorDyn> = detection_q
                .iter()
                .map(|a| make_u8_tensor(a.shape(), a.as_slice().unwrap()))
                .collect();

            // Frame 0: the populated row is above the score threshold.
            {
                let inputs: Vec<&edgefirst_tensor::TensorDyn> = detection_td
                    .iter()
                    .chain(std::iter::once(&protos_td))
                    .collect();
                if is_proto {
                    decoder
                        .decode_proto_tracked(
                            &mut tracker,
                            0,
                            &inputs,
                            &mut output_boxes,
                            &mut output_tracks,
                        )
                        .unwrap();
                } else {
                    decoder
                        .decode_tracked(
                            &mut tracker,
                            0,
                            &inputs,
                            &mut output_boxes,
                            &mut output_masks,
                            &mut output_tracks,
                        )
                        .unwrap();
                }
            }
            assert_eq!(output_boxes.len(), 1);
            assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));

            // Zero every score, then rebuild only the tensor that carries them;
            // the tensors hold copies, so the array edit alone is not enough.
            if is_split {
                detection_q[score_tensor].fill(u8::MIN);
            } else {
                detection_q[score_tensor]
                    .slice_mut(s![.., .., 4])
                    .fill(u8::MIN);
            }
            detection_td[score_tensor] = make_u8_tensor(
                detection_q[score_tensor].shape(),
                detection_q[score_tensor].as_slice().unwrap(),
            );

            // Frame 1: no scores above threshold; the box must still be reported.
            {
                let inputs: Vec<&edgefirst_tensor::TensorDyn> = detection_td
                    .iter()
                    .chain(std::iter::once(&protos_td))
                    .collect();
                if is_proto {
                    let proto_result = decoder
                        .decode_proto_tracked(
                            &mut tracker,
                            100_000_000 / 3,
                            &inputs,
                            &mut output_boxes,
                            &mut output_tracks,
                        )
                        .unwrap();
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
                } else {
                    decoder
                        .decode_tracked(
                            &mut tracker,
                            100_000_000 / 3,
                            &inputs,
                            &mut output_boxes,
                            &mut output_masks,
                            &mut output_tracks,
                        )
                        .unwrap();
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                    assert!(output_masks.is_empty());
                }
            }
        }
    };
    ($name:ident, float, $layout:ident, $output:ident) => {
        #[test]
        fn $name() {
            use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};

            // Validate the selectors so a typo cannot quietly run another variant.
            let is_split = match stringify!($layout) {
                "split" => true,
                "combined" => false,
                other => panic!(
                    "unknown layout selector `{}`; expected `combined` or `split`",
                    other
                ),
            };
            let is_proto = match stringify!($output) {
                "proto" => true,
                "masks" => false,
                other => panic!(
                    "unknown output selector `{}`; expected `masks` or `proto`",
                    other
                ),
            };

            let score_threshold = 0.45;
            let iou_threshold = 0.45;

            // Synthetic model output: ten candidate rows, only row 0 populated.
            let mut boxes = Array2::zeros((10, 4));
            let mut scores = Array2::zeros((10, 1));
            let mut classes = Array2::zeros((10, 1));
            let mask: Array2<f64> = Array2::zeros((10, 32));
            let protos = Array3::<f64>::zeros((160, 160, 32));
            let protos = protos.insert_axis(Axis(0));

            boxes
                .slice_mut(s![0, ..])
                .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
            scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
            classes.slice_mut(s![0, ..]).assign(&array![2.0]);

            let config = if is_split {
                E2E_SPLIT_CONFIG
            } else {
                E2E_COMBINED_CONFIG
            };
            let decoder = DecoderBuilder::default()
                .with_config_yaml_str(config.to_string())
                .with_score_threshold(score_threshold)
                .with_iou_threshold(iou_threshold)
                .build()
                .unwrap();

            // Allocates an `f64` tensor of `shape` and copies `data` into its mapping.
            let make_f64_tensor = |shape: &[usize], data: &[f64]| -> edgefirst_tensor::TensorDyn {
                let t = Tensor::<f64>::new(shape, None, None).unwrap();
                t.map().unwrap().as_mut_slice()[..data.len()].copy_from_slice(data);
                t.into()
            };

            let expected = e2e_expected_boxes_float();
            let mut tracker = ByteTrackBuilder::new()
                .track_update(0.1)
                .track_high_conf(0.7)
                .build();
            let mut output_boxes = Vec::with_capacity(50);
            let mut output_masks = Vec::with_capacity(50);
            let mut output_tracks = Vec::with_capacity(50);

            let protos_td = make_f64_tensor(protos.shape(), protos.as_slice().unwrap());

            // Detection arrays in the order the selected config lists them.
            // Split: [boxes, scores, classes, mask coefficients]; combined: [detect].
            let mut detection: Vec<Array3<f64>> = if is_split {
                vec![
                    boxes.insert_axis(Axis(0)),
                    scores.insert_axis(Axis(0)),
                    classes.insert_axis(Axis(0)),
                    mask.insert_axis(Axis(0)),
                ]
            } else {
                let detect = ndarray::concatenate![
                    Axis(1),
                    boxes.view(),
                    scores.view(),
                    classes.view(),
                    mask.view()
                ];
                let detect = detect.insert_axis(Axis(0));
                assert_eq!(detect.shape(), &[1, 10, 38]);
                // Rebuild the array from its elements in logical order; presumably
                // this guarantees the contiguous layout `as_slice()` needs below.
                vec![
                    Array3::from_shape_vec(detect.raw_dim(), detect.iter().copied().collect())
                        .unwrap(),
                ]
            };
            // Index of the tensor carrying the scores: `scores` itself when split,
            // the single detection tensor when combined.
            let score_tensor = if is_split { 1 } else { 0 };

            let mut detection_td: Vec<edgefirst_tensor::TensorDyn> = detection
                .iter()
                .map(|a| make_f64_tensor(a.shape(), a.as_slice().unwrap()))
                .collect();

            // Frame 0: the populated row is above the score threshold.
            {
                let inputs: Vec<&edgefirst_tensor::TensorDyn> = detection_td
                    .iter()
                    .chain(std::iter::once(&protos_td))
                    .collect();
                if is_proto {
                    decoder
                        .decode_proto_tracked(
                            &mut tracker,
                            0,
                            &inputs,
                            &mut output_boxes,
                            &mut output_tracks,
                        )
                        .unwrap();
                } else {
                    decoder
                        .decode_tracked(
                            &mut tracker,
                            0,
                            &inputs,
                            &mut output_boxes,
                            &mut output_masks,
                            &mut output_tracks,
                        )
                        .unwrap();
                }
            }
            assert_eq!(output_boxes.len(), 1);
            assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));

            // Zero every score, then rebuild only the tensor that carries them;
            // the tensors hold copies, so the array edit alone is not enough.
            if is_split {
                detection[score_tensor].fill(0.0);
            } else {
                detection[score_tensor].slice_mut(s![.., .., 4]).fill(0.0);
            }
            detection_td[score_tensor] = make_f64_tensor(
                detection[score_tensor].shape(),
                detection[score_tensor].as_slice().unwrap(),
            );

            // Frame 1: no scores above threshold; the box must still be reported.
            {
                let inputs: Vec<&edgefirst_tensor::TensorDyn> = detection_td
                    .iter()
                    .chain(std::iter::once(&protos_td))
                    .collect();
                if is_proto {
                    let proto_result = decoder
                        .decode_proto_tracked(
                            &mut tracker,
                            100_000_000 / 3,
                            &inputs,
                            &mut output_boxes,
                            &mut output_tracks,
                        )
                        .unwrap();
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
                } else {
                    decoder
                        .decode_tracked(
                            &mut tracker,
                            100_000_000 / 3,
                            &inputs,
                            &mut output_boxes,
                            &mut output_masks,
                            &mut output_tracks,
                        )
                        .unwrap();
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                    assert!(output_masks.is_empty());
                }
            }
        }
    };
}
6174
    // Instantiate `e2e_tracked_tensor_test!` once for every combination of
    // {quantized, float} x {combined, split} x {masks, proto}: eight
    // invocations in total.
    //
    // The first argument is presumably the name of the generated test
    // function; the remaining three select the variant exercised by the macro
    // body (see the macro definition for how each keyword is interpreted).

    // Combined layout, mask output.
    e2e_tracked_tensor_test!(
        test_decoder_tracked_tensor_end_to_end_segdet,
        quantized,
        combined,
        masks
    );
    e2e_tracked_tensor_test!(
        test_decoder_tracked_tensor_end_to_end_segdet_float,
        float,
        combined,
        masks
    );
    // Combined layout, proto output.
    e2e_tracked_tensor_test!(
        test_decoder_tracked_tensor_end_to_end_segdet_proto,
        quantized,
        combined,
        proto
    );
    e2e_tracked_tensor_test!(
        test_decoder_tracked_tensor_end_to_end_segdet_proto_float,
        float,
        combined,
        proto
    );
    // Split layout, mask output.
    e2e_tracked_tensor_test!(
        test_decoder_tracked_tensor_end_to_end_segdet_split,
        quantized,
        split,
        masks
    );
    e2e_tracked_tensor_test!(
        test_decoder_tracked_tensor_end_to_end_segdet_split_float,
        float,
        split,
        masks
    );
    // Split layout, proto output.
    e2e_tracked_tensor_test!(
        test_decoder_tracked_tensor_end_to_end_segdet_split_proto,
        quantized,
        split,
        proto
    );
    e2e_tracked_tensor_test!(
        test_decoder_tracked_tensor_end_to_end_segdet_split_proto_float,
        float,
        split,
        proto
    );
6223
6224 #[test]
6225 fn test_decoder_tracked_linear_motion() {
6226 use crate::configs::{DecoderType, Nms};
6227 use crate::DecoderBuilder;
6228
6229 let score_threshold = 0.25;
6230 let iou_threshold = 0.1;
6231 let out = include_bytes!(concat!(
6232 env!("CARGO_MANIFEST_DIR"),
6233 "/../../testdata/yolov8s_80_classes.bin"
6234 ));
6235 let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
6236 let mut out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
6237 let quant = (0.0040811873, -123).into();
6238
6239 let decoder = DecoderBuilder::default()
6240 .with_config_yolo_det(
6241 crate::configs::Detection {
6242 decoder: DecoderType::Ultralytics,
6243 shape: vec![1, 84, 8400],
6244 anchors: None,
6245 quantization: Some(quant),
6246 dshape: vec![
6247 (crate::configs::DimName::Batch, 1),
6248 (crate::configs::DimName::NumFeatures, 84),
6249 (crate::configs::DimName::NumBoxes, 8400),
6250 ],
6251 normalized: Some(true),
6252 },
6253 None,
6254 )
6255 .with_score_threshold(score_threshold)
6256 .with_iou_threshold(iou_threshold)
6257 .with_nms(Some(Nms::ClassAgnostic))
6258 .build()
6259 .unwrap();
6260
6261 let mut expected_boxes = [
6262 DetectBox {
6263 bbox: BoundingBox {
6264 xmin: 0.5285137,
6265 ymin: 0.05305544,
6266 xmax: 0.87541467,
6267 ymax: 0.9998909,
6268 },
6269 score: 0.5591227,
6270 label: 0,
6271 },
6272 DetectBox {
6273 bbox: BoundingBox {
6274 xmin: 0.130598,
6275 ymin: 0.43260583,
6276 xmax: 0.35098213,
6277 ymax: 0.9958097,
6278 },
6279 score: 0.33057618,
6280 label: 75,
6281 },
6282 ];
6283
6284 let mut tracker = ByteTrackBuilder::new()
6285 .track_update(0.1)
6286 .track_high_conf(0.3)
6287 .build();
6288
6289 let mut output_boxes = Vec::with_capacity(50);
6290 let mut output_masks = Vec::with_capacity(50);
6291 let mut output_tracks = Vec::with_capacity(50);
6292
6293 decoder
6294 .decode_tracked_quantized(
6295 &mut tracker,
6296 0,
6297 &[out.view().into()],
6298 &mut output_boxes,
6299 &mut output_masks,
6300 &mut output_tracks,
6301 )
6302 .unwrap();
6303
6304 assert_eq!(output_boxes.len(), 2);
6305 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
6306 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
6307
6308 for i in 1..=100 {
6309 let mut out = out.clone();
6310 let mut x_values = out.slice_mut(s![0, 0, ..]);
6312 for x in x_values.iter_mut() {
6313 *x = x.saturating_add((i as f32 * 1e-3 / quant.0).round() as i8);
6314 }
6315
6316 decoder
6317 .decode_tracked_quantized(
6318 &mut tracker,
6319 100_000_000 * i / 3, &[out.view().into()],
6321 &mut output_boxes,
6322 &mut output_masks,
6323 &mut output_tracks,
6324 )
6325 .unwrap();
6326
6327 assert_eq!(output_boxes.len(), 2);
6328 }
6329 let tracks = tracker.get_active_tracks();
6330 let predicted_boxes: Vec<_> = tracks
6331 .iter()
6332 .map(|track| {
6333 let mut l = track.last_box;
6334 l.bbox = track.info.tracked_location.into();
6335 l
6336 })
6337 .collect();
6338 expected_boxes[0].bbox.xmin += 0.1; expected_boxes[0].bbox.xmax += 0.1;
6340 expected_boxes[1].bbox.xmin += 0.1;
6341 expected_boxes[1].bbox.xmax += 0.1;
6342
6343 assert!(predicted_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
6344 assert!(predicted_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
6345
6346 let mut scores_values = out.slice_mut(s![0, 4.., ..]);
6348 for score in scores_values.iter_mut() {
6349 *score = i8::MIN; }
6351 decoder
6352 .decode_tracked_quantized(
6353 &mut tracker,
6354 100_000_000 * 101 / 3,
6355 &[out.view().into()],
6356 &mut output_boxes,
6357 &mut output_masks,
6358 &mut output_tracks,
6359 )
6360 .unwrap();
6361 expected_boxes[0].bbox.xmin += 0.001; expected_boxes[0].bbox.xmax += 0.001;
6363 expected_boxes[1].bbox.xmin += 0.001;
6364 expected_boxes[1].bbox.xmax += 0.001;
6365
6366 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
6367 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
6368 }
6369
6370 #[test]
6371 fn test_decoder_tracked_end_to_end_float() {
6372 let score_threshold = 0.45;
6373 let iou_threshold = 0.45;
6374
6375 let mut boxes = Array2::zeros((10, 4));
6376 let mut scores = Array2::zeros((10, 1));
6377 let mut classes = Array2::zeros((10, 1));
6378
6379 boxes
6380 .slice_mut(s![0, ..,])
6381 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
6382 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
6383 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
6384
6385 let detect = ndarray::concatenate![Axis(1), boxes.view(), scores.view(), classes.view(),];
6386 let mut detect = detect.insert_axis(Axis(0));
6387 assert_eq!(detect.shape(), &[1, 10, 6]);
6388 let config = "
6389decoder_version: yolo26
6390outputs:
6391 - type: detection
6392 decoder: ultralytics
6393 quantization: [0.00784313725490196, 0]
6394 shape: [1, 10, 6]
6395 dshape:
6396 - [batch, 1]
6397 - [num_boxes, 10]
6398 - [num_features, 6]
6399 normalized: true
6400";
6401
6402 let decoder = DecoderBuilder::default()
6403 .with_config_yaml_str(config.to_string())
6404 .with_score_threshold(score_threshold)
6405 .with_iou_threshold(iou_threshold)
6406 .build()
6407 .unwrap();
6408
6409 let expected_boxes = [DetectBox {
6410 bbox: BoundingBox {
6411 xmin: 0.1234,
6412 ymin: 0.1234,
6413 xmax: 0.2345,
6414 ymax: 0.2345,
6415 },
6416 score: 0.9876,
6417 label: 2,
6418 }];
6419
6420 let mut tracker = ByteTrackBuilder::new()
6421 .track_update(0.1)
6422 .track_high_conf(0.7)
6423 .build();
6424
6425 let mut output_boxes = Vec::with_capacity(50);
6426 let mut output_masks = Vec::with_capacity(50);
6427 let mut output_tracks = Vec::with_capacity(50);
6428
6429 decoder
6430 .decode_tracked_float(
6431 &mut tracker,
6432 0,
6433 &[detect.view().into_dyn()],
6434 &mut output_boxes,
6435 &mut output_masks,
6436 &mut output_tracks,
6437 )
6438 .unwrap();
6439
6440 assert_eq!(output_boxes.len(), 1);
6441 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
6442
6443 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
6446 *score = 0.0; }
6448
6449 decoder
6450 .decode_tracked_float(
6451 &mut tracker,
6452 100_000_000 / 3,
6453 &[detect.view().into_dyn()],
6454 &mut output_boxes,
6455 &mut output_masks,
6456 &mut output_tracks,
6457 )
6458 .unwrap();
6459 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
6460 }
6461}