1#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
68
69use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
70use num_traits::{AsPrimitive, Float, PrimInt};
71
72pub mod byte;
73pub mod error;
74pub mod float;
75pub mod modelpack;
76pub mod schema;
77pub mod yolo;
78
79mod decoder;
80pub use decoder::*;
81
82pub use configs::{DecoderVersion, Nms};
83pub use error::{DecoderError, DecoderResult};
84
85use crate::{
86 decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
87 yolo::yolo_segmentation_to_mask,
88};
89
90pub trait BBoxTypeTrait {
92 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
94
95 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
97 input: &[B; 4],
98 quant: Quantization,
99 ) -> [A; 4]
100 where
101 f32: AsPrimitive<A>,
102 i32: AsPrimitive<A>;
103
104 #[inline(always)]
115 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
116 input: ArrayView1<B>,
117 ) -> [A; 4] {
118 Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
119 }
120
121 #[inline(always)]
122 fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
124 input: ArrayView1<B>,
125 quant: Quantization,
126 ) -> [A; 4]
127 where
128 f32: AsPrimitive<A>,
129 i32: AsPrimitive<A>,
130 {
131 Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
132 }
133}
134
135#[derive(Debug, Clone, Copy, PartialEq, Eq)]
137pub struct XYXY {}
138
139impl BBoxTypeTrait for XYXY {
140 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
141 input.map(|b| b.as_())
142 }
143
144 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
145 input: &[B; 4],
146 quant: Quantization,
147 ) -> [A; 4]
148 where
149 f32: AsPrimitive<A>,
150 i32: AsPrimitive<A>,
151 {
152 let scale = quant.scale.as_();
153 let zp = quant.zero_point.as_();
154 input.map(|b| (b.as_() - zp) * scale)
155 }
156
157 #[inline(always)]
158 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
159 input: ArrayView1<B>,
160 ) -> [A; 4] {
161 [
162 input[0].as_(),
163 input[1].as_(),
164 input[2].as_(),
165 input[3].as_(),
166 ]
167 }
168}
169
170#[derive(Debug, Clone, Copy, PartialEq, Eq)]
173pub struct XYWH {}
174
175impl BBoxTypeTrait for XYWH {
176 #[inline(always)]
177 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
178 let half = A::one() / (A::one() + A::one());
179 [
180 (input[0].as_()) - (input[2].as_() * half),
181 (input[1].as_()) - (input[3].as_() * half),
182 (input[0].as_()) + (input[2].as_() * half),
183 (input[1].as_()) + (input[3].as_() * half),
184 ]
185 }
186
187 #[inline(always)]
188 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
189 input: &[B; 4],
190 quant: Quantization,
191 ) -> [A; 4]
192 where
193 f32: AsPrimitive<A>,
194 i32: AsPrimitive<A>,
195 {
196 let scale = quant.scale.as_();
197 let half_scale = (quant.scale * 0.5).as_();
198 let zp = quant.zero_point.as_();
199 let [x, y, w, h] = [
200 (input[0].as_() - zp) * scale,
201 (input[1].as_() - zp) * scale,
202 (input[2].as_() - zp) * half_scale,
203 (input[3].as_() - zp) * half_scale,
204 ];
205
206 [x - w, y - h, x + w, y + h]
207 }
208
209 #[inline(always)]
210 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
211 input: ArrayView1<B>,
212 ) -> [A; 4] {
213 let half = A::one() / (A::one() + A::one());
214 [
215 (input[0].as_()) - (input[2].as_() * half),
216 (input[1].as_()) - (input[3].as_() * half),
217 (input[0].as_()) + (input[2].as_() * half),
218 (input[1].as_()) + (input[3].as_() * half),
219 ]
220 }
221}
222
/// Affine quantization parameters.
///
/// A raw integer value `q` is mapped back to a real value as
/// `(q - zero_point) * scale`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quantization {
    /// Multiplier applied after removing the zero point.
    pub scale: f32,
    /// Raw value that corresponds to a real value of 0.
    pub zero_point: i32,
}

impl Quantization {
    /// Creates quantization parameters from an explicit scale and zero point.
    pub fn new(scale: f32, zero_point: i32) -> Self {
        Quantization { zero_point, scale }
    }
}
243
244impl From<QuantTuple> for Quantization {
245 fn from(quant_tuple: QuantTuple) -> Quantization {
256 Quantization {
257 scale: quant_tuple.0,
258 zero_point: quant_tuple.1,
259 }
260 }
261}
262
263impl<S, Z> From<(S, Z)> for Quantization
264where
265 S: AsPrimitive<f32>,
266 Z: AsPrimitive<i32>,
267{
268 fn from((scale, zp): (S, Z)) -> Quantization {
277 Self {
278 scale: scale.as_(),
279 zero_point: zp.as_(),
280 }
281 }
282}
283
284impl Default for Quantization {
285 fn default() -> Self {
294 Self {
295 scale: 1.0,
296 zero_point: 0,
297 }
298 }
299}
300
/// A single decoded detection: a bounding box, its score, and a class label.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct DetectBox {
    /// Box corners; see [`BoundingBox`].
    pub bbox: BoundingBox,
    /// Detection score as a float (for quantized models this is the value
    /// produced by [`dequant_detect_box`]).
    pub score: f32,
    /// Class label index; compared exactly (not within a tolerance) by
    /// [`DetectBox::equal_within_delta`].
    pub label: usize,
}
310
/// Axis-aligned box in corner form.
///
/// Nothing enforces `xmin <= xmax` / `ymin <= ymax`; use
/// [`BoundingBox::to_canonical`] to order the corners. Units depend on the
/// model/config (the tests in this file use normalized coordinates) — confirm
/// per model.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct BoundingBox {
    /// Left edge.
    pub xmin: f32,
    /// Top edge (assuming y grows downward — TODO confirm convention).
    pub ymin: f32,
    /// Right edge.
    pub xmax: f32,
    /// Bottom edge (assuming y grows downward — TODO confirm convention).
    pub ymax: f32,
}
323
324impl BoundingBox {
325 pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
327 Self {
328 xmin,
329 ymin,
330 xmax,
331 ymax,
332 }
333 }
334
335 pub fn to_canonical(&self) -> Self {
344 let xmin = self.xmin.min(self.xmax);
345 let xmax = self.xmin.max(self.xmax);
346 let ymin = self.ymin.min(self.ymax);
347 let ymax = self.ymin.max(self.ymax);
348 BoundingBox {
349 xmin,
350 ymin,
351 xmax,
352 ymax,
353 }
354 }
355}
356
357impl From<BoundingBox> for [f32; 4] {
358 fn from(b: BoundingBox) -> Self {
373 [b.xmin, b.ymin, b.xmax, b.ymax]
374 }
375}
376
377impl From<[f32; 4]> for BoundingBox {
378 fn from(arr: [f32; 4]) -> Self {
381 BoundingBox {
382 xmin: arr[0],
383 ymin: arr[1],
384 xmax: arr[2],
385 ymax: arr[3],
386 }
387 }
388}
389
390impl DetectBox {
391 pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
420 let eq_delta = |a: f32, b: f32| (a - b).abs() <= eps;
421 self.label == rhs.label
422 && eq_delta(self.score, rhs.score)
423 && eq_delta(self.bbox.xmin, rhs.bbox.xmin)
424 && eq_delta(self.bbox.ymin, rhs.bbox.ymin)
425 && eq_delta(self.bbox.xmax, rhs.bbox.xmax)
426 && eq_delta(self.bbox.ymax, rhs.bbox.ymax)
427 }
428}
429
/// A segmentation output together with the region it covers.
///
/// The bounds use the same corner convention as [`BoundingBox`]; the tests in
/// this file use `0.0..=1.0` for a full-frame mask.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Segmentation {
    /// Left edge of the region covered by `segmentation`.
    pub xmin: f32,
    /// Top edge of the region covered by `segmentation`.
    pub ymin: f32,
    /// Right edge of the region covered by `segmentation`.
    pub xmax: f32,
    /// Bottom edge of the region covered by `segmentation`.
    pub ymax: f32,
    /// Mask data. The third axis is the channel/depth axis:
    /// [`segmentation_to_mask`] treats a depth of 1 as a YOLO-style
    /// single-channel mask and a larger depth as per-class ModelPack output.
    /// The first two axes appear to be (height, width) — the tests build
    /// this from a (classes, height, width) tensor via `swap_axes`.
    pub segmentation: Array3<u8>,
}
453
/// Raw prototype-mask data, kept as tensors rather than decoded masks.
///
/// NOTE(review): by name this looks like the YOLO-seg split into per-detection
/// mask coefficients and shared prototype masks; shapes and dtypes are not
/// established here — confirm against the producer of this struct.
#[derive(Debug)]
pub struct ProtoData {
    /// Per-detection mask coefficients (presumably one row per kept box — TODO confirm).
    pub mask_coefficients: edgefirst_tensor::TensorDyn,
    /// Prototype mask tensor the coefficients are combined with (presumably — TODO confirm).
    pub protos: edgefirst_tensor::TensorDyn,
}
485
486pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
504 detect: &DetectBoxQuantized<SCORE>,
505 quant_scores: Quantization,
506) -> DetectBox {
507 let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
508 DetectBox {
509 bbox: detect.bbox,
510 score: quant_scores.scale * detect.score.as_() + scaled_zp,
511 label: detect.label,
512 }
513}
/// A detection whose score is still in its raw quantized integer form.
///
/// Convert to a float [`DetectBox`] with [`dequant_detect_box`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DetectBoxQuantized<
    // Integer type of the raw score; must be castable to `f32` for dequantization.
    SCORE: PrimInt + AsPrimitive<f32>,
> {
    /// Box corners (already in float form).
    pub bbox: BoundingBox,
    /// Raw quantized score; see [`dequant_detect_box`] for the float conversion.
    pub score: SCORE,
    /// Class label index.
    pub label: usize,
}
528
529pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
542 input: ArrayView<T, D>,
543 quant: Quantization,
544) -> Array<F, D>
545where
546 i32: num_traits::AsPrimitive<F>,
547 f32: num_traits::AsPrimitive<F>,
548{
549 let zero_point = quant.zero_point.as_();
550 let scale = quant.scale.as_();
551 if zero_point != F::zero() {
552 let scaled_zero = -zero_point * scale;
553 input.mapv(|d| d.as_() * scale + scaled_zero)
554 } else {
555 input.mapv(|d| d.as_() * scale)
556 }
557}
558
559pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
572 input: &[T],
573 quant: Quantization,
574 output: &mut [F],
575) where
576 f32: num_traits::AsPrimitive<F>,
577 i32: num_traits::AsPrimitive<F>,
578{
579 assert!(input.len() == output.len());
580 let zero_point = quant.zero_point.as_();
581 let scale = quant.scale.as_();
582 if zero_point != F::zero() {
583 let scaled_zero = -zero_point * scale; input
585 .iter()
586 .zip(output)
587 .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
588 } else {
589 input
590 .iter()
591 .zip(output)
592 .for_each(|(d, deq)| *deq = d.as_() * scale);
593 }
594}
595
596pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
610 input: &[T],
611 quant: Quantization,
612 output: &mut [F],
613) where
614 f32: num_traits::AsPrimitive<F>,
615 i32: num_traits::AsPrimitive<F>,
616{
617 assert!(input.len() == output.len());
618 let zero_point = quant.zero_point.as_();
619 let scale = quant.scale.as_();
620
621 let input = input.as_chunks::<4>();
622 let output = output.as_chunks_mut::<4>();
623
624 if zero_point != F::zero() {
625 let scaled_zero = -zero_point * scale; input
628 .0
629 .iter()
630 .zip(output.0)
631 .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
632 input
633 .1
634 .iter()
635 .zip(output.1)
636 .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
637 } else {
638 input
639 .0
640 .iter()
641 .zip(output.0)
642 .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
643 input
644 .1
645 .iter()
646 .zip(output.1)
647 .for_each(|(d, deq)| *deq = d.as_() * scale);
648 }
649}
650
651pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
669 if segmentation.shape()[2] == 0 {
670 return Err(DecoderError::InvalidShape(
671 "Segmentation tensor must have non-zero depth".to_string(),
672 ));
673 }
674 if segmentation.shape()[2] == 1 {
675 yolo_segmentation_to_mask(segmentation, 128)
676 } else {
677 Ok(modelpack_segmentation_to_mask(segmentation))
678 }
679}
680
681fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
683 score
684 .iter()
685 .enumerate()
686 .fold((score[0], 0), |(max, arg_max), (ind, s)| {
687 if max > *s {
688 (max, arg_max)
689 } else {
690 (*s, ind)
691 }
692 })
693}
694#[cfg(test)]
695#[cfg_attr(coverage_nightly, coverage(off))]
696mod decoder_tests {
697 #![allow(clippy::excessive_precision)]
698 use crate::{
699 configs::{DecoderType, DimName, Protos},
700 modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
701 yolo::{
702 decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
703 decode_yolo_segdet_quant,
704 },
705 *,
706 };
707 use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
708 use ndarray::Dimension;
709 use ndarray::{array, s, Array2, Array3, Array4, Axis};
710 use ndarray_stats::DeviationExt;
711 use num_traits::{AsPrimitive, PrimInt};
712
713 fn compare_outputs(
714 boxes: (&[DetectBox], &[DetectBox]),
715 masks: (&[Segmentation], &[Segmentation]),
716 ) {
717 let (boxes0, boxes1) = boxes;
718 let (masks0, masks1) = masks;
719
720 assert_eq!(boxes0.len(), boxes1.len());
721 assert_eq!(masks0.len(), masks1.len());
722
723 for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
724 assert!(
725 b_i8.equal_within_delta(b_f32, 1e-6),
726 "{b_i8:?} is not equal to {b_f32:?}"
727 );
728 }
729
730 for (m_i8, m_f32) in masks0.iter().zip(masks1) {
731 assert_eq!(
732 [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
733 [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
734 );
735 assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
736 let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
737 let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
738 let diff = &mask_i8 - &mask_f32;
739 for x in 0..diff.shape()[0] {
740 for y in 0..diff.shape()[1] {
741 for z in 0..diff.shape()[2] {
742 let val = diff[[x, y, z]];
743 assert!(
744 val.abs() <= 1,
745 "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
746 x,
747 y,
748 z,
749 val
750 );
751 }
752 }
753 }
754 let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
755 assert!(
756 mean_sq_err < 1e-2,
757 "Mean Square Error between masks was greater than 1%: {:.2}%",
758 mean_sq_err * 100.0
759 );
760 }
761 }
762
763 fn load_yolov8_boxes() -> Array3<i8> {
766 let raw = include_bytes!(concat!(
767 env!("CARGO_MANIFEST_DIR"),
768 "/../../testdata/yolov8_boxes_116x8400.bin"
769 ));
770 let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
771 Array3::from_shape_vec((1, 116, 8400), raw.to_vec()).unwrap()
772 }
773
774 fn load_yolov8_protos() -> Array4<i8> {
775 let raw = include_bytes!(concat!(
776 env!("CARGO_MANIFEST_DIR"),
777 "/../../testdata/yolov8_protos_160x160x32.bin"
778 ));
779 let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
780 Array4::from_shape_vec((1, 160, 160, 32), raw.to_vec()).unwrap()
781 }
782
783 fn load_yolov8s_det() -> Array3<i8> {
784 let raw = include_bytes!(concat!(
785 env!("CARGO_MANIFEST_DIR"),
786 "/../../testdata/yolov8s_80_classes.bin"
787 ));
788 let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
789 Array3::from_shape_vec((1, 84, 8400), raw.to_vec()).unwrap()
790 }
791
    /// Decodes the ModelPack boxes/scores fixture three ways — the free
    /// function `decode_modelpack_det`, `Decoder::decode_quantized`, and
    /// `Decoder::decode_float` on dequantized inputs — and checks all agree.
    #[test]
    fn test_decoder_modelpack() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let boxes = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_boxes_1935x1x4.bin"
        ));
        let boxes = ndarray::Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();

        let scores = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_scores_1935x1.bin"
        ));
        let scores = ndarray::Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();

        let quant_boxes = (0.004656755365431309, 21).into();
        let quant_scores = (0.0019603664986789227, 0).into();

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_det(
                configs::Boxes {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_boxes),
                    shape: vec![1, 1935, 1, 4],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::Padding, 1),
                        (DimName::BoxCoords, 4),
                    ],
                    normalized: Some(true),
                },
                configs::Scores {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_scores),
                    shape: vec![1, 1935, 1],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::NumClasses, 1),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        // Re-bind the config-level quantization values as `Quantization` for the
        // free-function and dequantization calls below.
        let quant_boxes = quant_boxes.into();
        let quant_scores = quant_scores.into();

        // Reference result from the low-level decoder on the un-batched views.
        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
        decode_modelpack_det(
            (boxes.slice(s![0, .., 0, ..]), quant_boxes),
            (scores.slice(s![0, .., ..]), quant_scores),
            score_threshold,
            iou_threshold,
            &mut output_boxes,
        );
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.40513772,
                    ymin: 0.6379755,
                    xmax: 0.5122431,
                    ymax: 0.7730214,
                },
                score: 0.4861709,
                label: 0
            },
            1e-6
        ));

        let mut output_boxes1 = Vec::with_capacity(50);
        let mut output_masks1 = Vec::with_capacity(50);

        decoder
            .decode_quantized(
                &[boxes.view().into(), scores.view().into()],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();

        let mut output_boxes_float = Vec::with_capacity(50);
        let mut output_masks_float = Vec::with_capacity(50);

        // Float path: dequantize the same fixtures up front.
        let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
        let scores = dequantize_ndarray(scores.view(), quant_scores);

        decoder
            .decode_float::<f32>(
                &[boxes.view().into_dyn(), scores.view().into_dyn()],
                &mut output_boxes_float,
                &mut output_masks_float,
            )
            .unwrap();

        // Detection-only model: both decoder paths must emit no masks.
        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
        compare_outputs(
            (&output_boxes, &output_boxes_float),
            (&[], &output_masks_float),
        );
    }
897
    /// Decodes the two-level ModelPack split fixtures via
    /// `decode_modelpack_split_quant`, `Decoder::decode_quantized`, and
    /// `Decoder::decode_float`, and checks all three agree.
    #[test]
    fn test_decoder_modelpack_split_u8() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let detect0 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_9x15x18.bin"
        ));
        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

        let detect1 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_17x30x18.bin"
        ));
        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

        let quant0 = (0.08547406643629074, 174).into();
        let quant1 = (0.09929127991199493, 183).into();
        let anchors0 = vec![
            [0.36666667461395264, 0.31481480598449707],
            [0.38749998807907104, 0.4740740656852722],
            [0.5333333611488342, 0.644444465637207],
        ];
        let anchors1 = vec![
            [0.13750000298023224, 0.2074074000120163],
            [0.2541666626930237, 0.21481481194496155],
            [0.23125000298023224, 0.35185185074806213],
        ];

        let detect_config0 = configs::Detection {
            decoder: DecoderType::ModelPack,
            shape: vec![1, 9, 15, 18],
            anchors: Some(anchors0.clone()),
            quantization: Some(quant0),
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::Height, 9),
                (DimName::Width, 15),
                (DimName::NumAnchorsXFeatures, 18),
            ],
            normalized: Some(true),
        };

        let detect_config1 = configs::Detection {
            decoder: DecoderType::ModelPack,
            shape: vec![1, 17, 30, 18],
            anchors: Some(anchors1.clone()),
            quantization: Some(quant1),
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::Height, 17),
                (DimName::Width, 30),
                (DimName::NumAnchorsXFeatures, 18),
            ],
            normalized: Some(true),
        };

        // Low-level configs for the free function, derived from the same descriptions.
        let config0 = (&detect_config0).try_into().unwrap();
        let config1 = (&detect_config1).try_into().unwrap();

        // Configs are handed to the builder in the opposite order to the inputs
        // passed below, presumably to check outputs are matched to configs by
        // shape rather than position — confirm.
        let decoder = DecoderBuilder::default()
            .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        let quant0 = quant0.into();
        let quant1 = quant1.into();

        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        decode_modelpack_split_quant(
            &[
                detect0.slice(s![0, .., .., ..]),
                detect1.slice(s![0, .., .., ..]),
            ],
            &[config0, config1],
            score_threshold,
            iou_threshold,
            &mut output_boxes,
        );
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.43171933,
                    ymin: 0.68243736,
                    xmax: 0.5626645,
                    ymax: 0.808863,
                },
                score: 0.99240804,
                label: 0
            },
            1e-6
        ));

        let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
        let mut output_masks1: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[detect0.view().into(), detect1.view().into()],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();

        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);

        // Float path: dequantize each level with its own parameters.
        let detect0 = dequantize_ndarray(detect0.view(), quant0);
        let detect1 = dequantize_ndarray(detect1.view(), quant1);
        decoder
            .decode_float::<f32>(
                &[detect0.view().into_dyn(), detect1.view().into_dyn()],
                &mut output_boxes1_f32,
                &mut output_masks1_f32,
            )
            .unwrap();

        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
        compare_outputs(
            (&output_boxes, &output_boxes1_f32),
            (&[], &output_masks1_f32),
        );
    }
1022
    /// Builds the split ModelPack decoder from the YAML config fixture and
    /// checks the first decoded box matches the known reference detection.
    #[test]
    fn test_decoder_parse_config_modelpack_split_u8() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let detect0 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_9x15x18.bin"
        ));
        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

        let detect1 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_17x30x18.bin"
        ));
        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

        let decoder = DecoderBuilder::default()
            .with_config_yaml_str(
                include_str!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/../../testdata/modelpack_split.yaml"
                ))
                .to_string(),
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);
        // Inputs are deliberately listed larger-grid first (detect1, detect0).
        decoder
            .decode_quantized(
                &[
                    ArrayViewDQuantized::from(detect1.view()),
                    ArrayViewDQuantized::from(detect0.view()),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();
        // Same reference detection as `test_decoder_modelpack_split_u8`.
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.43171933,
                    ymin: 0.68243736,
                    xmax: 0.5626645,
                    ymax: 0.808863,
                },
                score: 0.99240804,
                label: 0
            },
            1e-6
        ));
    }
1078
    /// Segmentation-only ModelPack decoder: the quantized path must pass the
    /// mask through unchanged (re-laid-out to height x width x classes), and
    /// the float path must yield the same class mask.
    #[test]
    fn test_modelpack_seg() {
        let out = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let out = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
        let quant = (1.0 / 255.0, 0).into();

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_seg(configs::Segmentation {
                decoder: DecoderType::ModelPack,
                quantization: Some(quant),
                shape: vec![1, 2, 160, 160],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 2),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                ],
            })
            .build()
            .unwrap();
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
            .unwrap();

        // Expected mask: drop the batch axis, then reorder
        // (classes, height, width) -> (height, width, classes).
        let mut mask = out.slice(s![0, .., .., ..]);
        mask.swap_axes(0, 1);
        mask.swap_axes(1, 2);
        let mask = [Segmentation {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.0,
            ymax: 1.0,
            segmentation: mask.into_owned(),
        }];
        compare_outputs((&[], &output_boxes), (&mask, &output_masks));

        decoder
            .decode_float::<f32>(
                &[dequantize_ndarray(out.view(), quant.into())
                    .view()
                    .into_dyn()],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // The float path's raw mask values are not compared element-wise here;
        // only the derived class masks below must match.
        compare_outputs((&[], &output_boxes), (&[], &[]));
        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

        assert_eq!(mask0, mask1);
    }
1141 #[test]
1142 fn test_modelpack_seg_quant() {
1143 let out = include_bytes!(concat!(
1144 env!("CARGO_MANIFEST_DIR"),
1145 "/../../testdata/modelpack_seg_2x160x160.bin"
1146 ));
1147 let out_u8 = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1148 let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
1149 let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
1150 let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
1151 let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
1152 let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);
1153
1154 let quant = (1.0 / 255.0, 0).into();
1155
1156 let decoder = DecoderBuilder::default()
1157 .with_config_modelpack_seg(configs::Segmentation {
1158 decoder: DecoderType::ModelPack,
1159 quantization: Some(quant),
1160 shape: vec![1, 2, 160, 160],
1161 dshape: vec![
1162 (DimName::Batch, 1),
1163 (DimName::NumClasses, 2),
1164 (DimName::Height, 160),
1165 (DimName::Width, 160),
1166 ],
1167 })
1168 .build()
1169 .unwrap();
1170 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1171 let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
1172 decoder
1173 .decode_quantized(
1174 &[out_u8.view().into()],
1175 &mut output_boxes,
1176 &mut output_masks_u8,
1177 )
1178 .unwrap();
1179
1180 let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
1181 decoder
1182 .decode_quantized(
1183 &[out_i8.view().into()],
1184 &mut output_boxes,
1185 &mut output_masks_i8,
1186 )
1187 .unwrap();
1188
1189 let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
1190 decoder
1191 .decode_quantized(
1192 &[out_u16.view().into()],
1193 &mut output_boxes,
1194 &mut output_masks_u16,
1195 )
1196 .unwrap();
1197
1198 let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
1199 decoder
1200 .decode_quantized(
1201 &[out_i16.view().into()],
1202 &mut output_boxes,
1203 &mut output_masks_i16,
1204 )
1205 .unwrap();
1206
1207 let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
1208 decoder
1209 .decode_quantized(
1210 &[out_u32.view().into()],
1211 &mut output_boxes,
1212 &mut output_masks_u32,
1213 )
1214 .unwrap();
1215
1216 let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
1217 decoder
1218 .decode_quantized(
1219 &[out_i32.view().into()],
1220 &mut output_boxes,
1221 &mut output_masks_i32,
1222 )
1223 .unwrap();
1224
1225 compare_outputs((&[], &output_boxes), (&[], &[]));
1226 let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
1227 let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
1228 let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
1229 let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
1230 let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
1231 let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();
1232 assert_eq!(mask_u8, mask_i8);
1233 assert_eq!(mask_u8, mask_u16);
1234 assert_eq!(mask_u8, mask_i16);
1235 assert_eq!(mask_u8, mask_u32);
1236 assert_eq!(mask_u8, mask_i32);
1237 }
1238
    /// Combined ModelPack detection + segmentation: checks the reference box,
    /// the pass-through mask on the quantized path, and that the float path
    /// yields the same box and class mask.
    #[test]
    fn test_modelpack_segdet() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;

        let boxes = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_boxes_1935x1x4.bin"
        ));
        let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();

        let scores = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_scores_1935x1.bin"
        ));
        let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();

        let seg = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

        let quant_boxes = (0.004656755365431309, 21).into();
        let quant_scores = (0.0019603664986789227, 0).into();
        let quant_seg = (1.0 / 255.0, 0).into();

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_segdet(
                configs::Boxes {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_boxes),
                    shape: vec![1, 1935, 1, 4],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::Padding, 1),
                        (DimName::BoxCoords, 4),
                    ],
                    normalized: Some(true),
                },
                configs::Scores {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_scores),
                    shape: vec![1, 1935, 1],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::NumClasses, 1),
                    ],
                },
                configs::Segmentation {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_seg),
                    shape: vec![1, 2, 160, 160],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumClasses, 2),
                        (DimName::Height, 160),
                        (DimName::Width, 160),
                    ],
                },
            )
            .with_iou_threshold(iou_threshold)
            .with_score_threshold(score_threshold)
            .build()
            .unwrap();
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);
        // Inputs are passed in a different order (scores first) than the
        // builder arguments (boxes first).
        decoder
            .decode_quantized(
                &[scores.view().into(), boxes.view().into(), seg.view().into()],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Expected mask: (classes, height, width) -> (height, width, classes).
        let mut mask = seg.slice(s![0, .., .., ..]);
        mask.swap_axes(0, 1);
        mask.swap_axes(1, 2);
        let mask = [Segmentation {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.0,
            ymax: 1.0,
            segmentation: mask.into_owned(),
        }];
        let correct_boxes = [DetectBox {
            bbox: BoundingBox {
                xmin: 0.40513772,
                ymin: 0.6379755,
                xmax: 0.5122431,
                ymax: 0.7730214,
            },
            score: 0.4861709,
            label: 0,
        }];
        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));

        let scores = dequantize_ndarray(scores.view(), quant_scores.into());
        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
        decoder
            .decode_float::<f32>(
                &[
                    scores.view().into_dyn(),
                    boxes.view().into_dyn(),
                    seg.view().into_dyn(),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Boxes are checked directly; masks only via the derived class mask below.
        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

        assert_eq!(mask0, mask1);
    }
1364
    /// Split (two-level) ModelPack detection plus segmentation: checks the
    /// reference box and pass-through mask on the quantized path, then the
    /// same box and class mask on the float path.
    #[test]
    fn test_modelpack_segdet_split() {
        let score_threshold = 0.8;
        let iou_threshold = 0.5;

        let seg = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

        let detect0 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_9x15x18.bin"
        ));
        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

        let detect1 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_17x30x18.bin"
        ));
        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

        let quant0 = (0.08547406643629074, 174).into();
        let quant1 = (0.09929127991199493, 183).into();
        let quant_seg = (1.0 / 255.0, 0).into();

        let anchors0 = vec![
            [0.36666667461395264, 0.31481480598449707],
            [0.38749998807907104, 0.4740740656852722],
            [0.5333333611488342, 0.644444465637207],
        ];
        let anchors1 = vec![
            [0.13750000298023224, 0.2074074000120163],
            [0.2541666626930237, 0.21481481194496155],
            [0.23125000298023224, 0.35185185074806213],
        ];

        // Detection configs listed larger-grid first, the reverse of the input order below.
        let decoder = DecoderBuilder::default()
            .with_config_modelpack_segdet_split(
                vec![
                    configs::Detection {
                        decoder: DecoderType::ModelPack,
                        shape: vec![1, 17, 30, 18],
                        anchors: Some(anchors1),
                        quantization: Some(quant1),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::Height, 17),
                            (DimName::Width, 30),
                            (DimName::NumAnchorsXFeatures, 18),
                        ],
                        normalized: Some(true),
                    },
                    configs::Detection {
                        decoder: DecoderType::ModelPack,
                        shape: vec![1, 9, 15, 18],
                        anchors: Some(anchors0),
                        quantization: Some(quant0),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::Height, 9),
                            (DimName::Width, 15),
                            (DimName::NumAnchorsXFeatures, 18),
                        ],
                        normalized: Some(true),
                    },
                ],
                configs::Segmentation {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_seg),
                    shape: vec![1, 2, 160, 160],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumClasses, 2),
                        (DimName::Height, 160),
                        (DimName::Width, 160),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[
                    detect0.view().into(),
                    detect1.view().into(),
                    seg.view().into(),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Expected mask: (classes, height, width) -> (height, width, classes).
        let mut mask = seg.slice(s![0, .., .., ..]);
        mask.swap_axes(0, 1);
        mask.swap_axes(1, 2);
        let mask = [Segmentation {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.0,
            ymax: 1.0,
            segmentation: mask.into_owned(),
        }];
        let correct_boxes = [DetectBox {
            bbox: BoundingBox {
                xmin: 0.43171933,
                ymin: 0.68243736,
                xmax: 0.5626645,
                ymax: 0.808863,
            },
            score: 0.99240804,
            label: 0,
        }];
        println!("Output Boxes: {:?}", output_boxes);
        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));

        let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
        let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
        decoder
            .decode_float::<f32>(
                &[
                    detect0.view().into_dyn(),
                    detect1.view().into_dyn(),
                    seg.view().into_dyn(),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Boxes are checked directly; masks only via the derived class mask below.
        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

        assert_eq!(mask0, mask1);
    }
1512
/// Cross-checks `dequantize_cpu_chunked` against the scalar `dequantize_cpu`
/// over the whole YOLOv8s detection fixture, once with a non-zero zero point
/// and once with a zero zero point.
#[test]
fn test_dequant_chunked() {
    // One extra trailing element makes the length odd (84 * 8400 + 1), so any
    // even chunk width leaves a remainder the chunked path must handle.
    let (mut raw, _offset) = load_yolov8s_det().into_raw_vec_and_offset();
    raw.push(123);

    let len = 84 * 8400 + 1;
    let mut scalar_out = vec![0.0; len];
    let mut chunked_out = vec![0.0; len];

    for zero_point in [-123, 0] {
        let quant = Quantization::new(0.0040811873, zero_point);
        dequantize_cpu(&raw, quant, &mut scalar_out);
        dequantize_cpu_chunked(&raw, quant, &mut chunked_out);
        assert_eq!(scalar_out, chunked_out);
    }
}
1532
/// Checks both dequantization paths against hand-computed values of
/// `(q - zero_point) * scale` for several scale / zero-point combinations.
#[test]
fn test_dequant_ground_truth() {
    let input: Vec<i8> = vec![0, 127, -128, 64];
    let mut output = vec![0.0f32; 4];
    let mut output_chunked = vec![0.0f32; 4];

    // zero_point = -128 maps the i8 range onto [0, 255] before scaling.
    // 0.1 is not exactly representable in f32, so compare with a tolerance.
    let quant = Quantization::new(0.1, -128);
    dequantize_cpu(&input, quant, &mut output);
    dequantize_cpu_chunked(&input, quant, &mut output_chunked);
    let expected: Vec<f32> = vec![12.8, 25.5, 0.0, 19.2];
    output
        .iter()
        .zip(&expected)
        .enumerate()
        .for_each(|(i, (&out, &exp))| {
            assert!((out - exp).abs() < 1e-5, "cpu[{i}]: {out} != {exp}");
        });
    output_chunked
        .iter()
        .zip(&expected)
        .enumerate()
        .for_each(|(i, (&out, &exp))| {
            assert!((out - exp).abs() < 1e-5, "chunked[{i}]: {out} != {exp}");
        });

    // zero_point = 0 with scales 1.0 and 0.5: results must be exact.
    for (scale, exact) in [
        (1.0, vec![0.0, 127.0, -128.0, 64.0]),
        (0.5, vec![0.0, 63.5, -64.0, 32.0]),
    ] {
        let quant = Quantization::new(scale, 0);
        dequantize_cpu(&input, quant, &mut output);
        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
        let expected: Vec<f32> = exact;
        assert_eq!(output, expected);
        assert_eq!(output_chunked, expected);
    }

    // A realistic scale / zero point, checked against the formula directly.
    let quant = Quantization::new(0.021287762, 31);
    let input: Vec<i8> = vec![-128, -1, 0, 1, 31, 127];
    let mut output = vec![0.0f32; 6];
    let mut output_chunked = vec![0.0f32; 6];
    dequantize_cpu(&input, quant, &mut output);
    dequantize_cpu_chunked(&input, quant, &mut output_chunked);
    for (i, ((&q, &cpu), &chunked)) in input
        .iter()
        .zip(&output)
        .zip(&output_chunked)
        .enumerate()
    {
        let expected = (q as f32 - 31.0) * 0.021287762;
        assert!(
            (cpu - expected).abs() < 1e-5,
            "cpu[{i}]: {} != {expected}",
            cpu
        );
        assert!(
            (chunked - expected).abs() < 1e-5,
            "chunked[{i}]: {} != {expected}",
            chunked
        );
    }
}
1594
/// Decodes the quantized YOLOv8s detection fixture with the free function
/// `decode_yolo_det` (pinning its first two boxes), then checks that a
/// config-built decoder produces the same boxes from both the quantized
/// tensor and its dequantized f32 copy.
#[test]
fn test_decoder_yolo_det() {
    let (score_threshold, iou_threshold) = (0.25, 0.7);
    let out = load_yolov8s_det();
    let quant = (0.0040811873, -123).into();

    let detection = configs::Detection {
        decoder: DecoderType::Ultralytics,
        shape: vec![1, 84, 8400],
        anchors: None,
        quantization: Some(quant),
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumFeatures, 84),
            (DimName::NumBoxes, 8400),
        ],
        normalized: Some(true),
    };
    let decoder = DecoderBuilder::default()
        .with_config_yolo_det(detection, Some(DecoderVersion::Yolo11))
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Reference result from the free function.
    let mut reference: Vec<_> = Vec::with_capacity(50);
    decode_yolo_det(
        (out.slice(s![0, .., ..]), quant.into()),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut reference,
    );

    let pinned = [
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.5285137,
                ymin: 0.05305544,
                xmax: 0.87541467,
                ymax: 0.9998909,
            },
            score: 0.5591227,
            label: 0,
        },
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.130598,
                ymin: 0.43260583,
                xmax: 0.35098213,
                ymax: 0.9958097,
            },
            score: 0.33057618,
            label: 75,
        },
    ];
    for (i, want) in pinned.iter().enumerate() {
        assert!(reference[i].equal_within_delta(want, 1e-6));
    }

    // Config-driven decoder on the quantized tensor.
    let mut q_boxes: Vec<_> = Vec::with_capacity(50);
    let mut q_masks: Vec<_> = Vec::with_capacity(50);
    decoder
        .decode_quantized(&[out.view().into()], &mut q_boxes, &mut q_masks)
        .unwrap();

    // Config-driven decoder on the dequantized tensor.
    let out_f32 = dequantize_ndarray(out.view(), quant.into());
    let mut f_boxes: Vec<_> = Vec::with_capacity(50);
    let mut f_masks: Vec<_> = Vec::with_capacity(50);
    decoder
        .decode_float::<f32>(&[out_f32.view().into_dyn()], &mut f_boxes, &mut f_masks)
        .unwrap();

    compare_outputs((&reference, &q_boxes), (&[], &q_masks));
    compare_outputs((&reference, &f_boxes), (&[], &f_masks));
}
1679
/// Decodes the dequantized YOLOv8-seg fixture with `decode_yolo_segdet_float`.
/// Checks the two detections, each mask's bounds relative to its box, and the
/// second mask's pixels against `yolov8_mask_results.bin`.
#[test]
fn test_decoder_masks() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let quant_boxes = Quantization::new(0.021287761628627777, 31);
    let quant_protos = Quantization::new(0.02491161972284317, -117);
    let raw_boxes = load_yolov8_boxes();
    let raw_protos = load_yolov8_protos();
    let protos = dequantize_ndarray::<_, _, f32>(raw_protos.view(), quant_protos);
    let seg = dequantize_ndarray::<_, _, f32>(raw_boxes.view(), quant_boxes);

    let mut output_boxes: Vec<_> = Vec::with_capacity(10);
    let mut output_masks: Vec<_> = Vec::with_capacity(10);
    decode_yolo_segdet_float(
        seg.slice(s![0, .., ..]),
        protos.slice(s![0, .., .., ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes,
        &mut output_masks,
    )
    .unwrap();
    assert_eq!(output_boxes.len(), 2);
    assert_eq!(output_boxes.len(), output_masks.len());

    // Both mask corners must sit at or before the matching box corners.
    for (det, mask) in output_boxes.iter().zip(&output_masks) {
        let b = &det.bbox;
        assert!(b.xmin >= mask.xmin);
        assert!(b.ymin >= mask.ymin);
        assert!(b.xmax >= mask.xmax);
        assert!(b.ymax >= mask.ymax);
    }

    // Tolerance of 1/160, presumably one cell of the 160x160 proto grid.
    let tolerance = 1.0 / 160.0;
    let expected = [
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.08515105,
                ymin: 0.7131401,
                xmax: 0.29802868,
                ymax: 0.8195788,
            },
            score: 0.91537374,
            label: 23,
        },
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.59605736,
                ymin: 0.25545314,
                xmax: 0.93666154,
                ymax: 0.72378385,
            },
            score: 0.91537374,
            label: 23,
        },
    ];
    for (got, want) in output_boxes.iter().zip(&expected) {
        assert!(got.equal_within_delta(want, tolerance));
    }

    let reference_bytes = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_mask_results.bin"
    ));
    let full_mask =
        ndarray::Array2::from_shape_vec((160, 160), reference_bytes.to_vec()).unwrap();

    // Crop the 160x160 reference to the second mask's (truncated) pixel bounds.
    let second = &output_masks[1];
    let cropped_mask = full_mask.slice(ndarray::s![
        (second.ymin * 160.0) as usize..(second.ymax * 160.0) as usize,
        (second.xmin * 160.0) as usize..(second.xmax * 160.0) as usize,
    ]);
    assert_eq!(
        cropped_mask,
        segmentation_to_mask(second.segmentation.view()).unwrap()
    );
}
1756
/// Feeds NCHW protos (`[1, 32, 160, 160]`, `NumProtos` on axis 1) through the
/// config-driven decoder. Checks that its boxes and masks match a direct
/// `decode_yolo_segdet_float` call on the same protos in HWC order.
#[test]
fn test_decoder_masks_nchw_protos() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
    let quant_boxes = Quantization::new(0.021287761628627777, 31);

    let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
    let quant_protos = Quantization::new(0.02491161972284317, -117);
    let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);

    let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);

    // Reference: direct float decode with HWC protos.
    let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
    let mut ref_masks: Vec<_> = Vec::with_capacity(10);
    decode_yolo_segdet_float(
        seg.view(),
        protos_f32_hwc.view(),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut ref_boxes,
        &mut ref_masks,
    )
    .unwrap();
    assert_eq!(ref_boxes.len(), 2);

    // HWC -> CHW. `permuted_axes` only reorders strides. `to_owned()` on such
    // a memory-contiguous view keeps those permuted strides, so the data would
    // still be HWC in memory. `as_standard_layout()` forces a row-major copy,
    // giving a buffer that is physically CHW, like a real NCHW model output.
    let protos_f32_chw = protos_f32_hwc
        .view()
        .permuted_axes([2, 0, 1])
        .as_standard_layout()
        .into_owned();
    assert!(protos_f32_chw.is_standard_layout());
    let protos_nchw = protos_f32_chw.insert_axis(ndarray::Axis(0));
    let seg_3d = seg.insert_axis(ndarray::Axis(0));

    let decoder = DecoderBuilder::default()
        .with_config_yolo_segdet(
            configs::Detection {
                decoder: configs::DecoderType::Ultralytics,
                quantization: None,
                shape: vec![1, 116, 8400],
                dshape: vec![
                    (configs::DimName::Batch, 1),
                    (configs::DimName::NumFeatures, 116),
                    (configs::DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
                anchors: None,
            },
            configs::Protos {
                decoder: configs::DecoderType::Ultralytics,
                quantization: None,
                shape: vec![1, 32, 160, 160],
                dshape: vec![
                    (configs::DimName::Batch, 1),
                    (configs::DimName::NumProtos, 32),
                    (configs::DimName::Height, 160),
                    (configs::DimName::Width, 160),
                ],
            },
            None,
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let mut cfg_boxes: Vec<_> = Vec::with_capacity(10);
    let mut cfg_masks: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_float(
            &[seg_3d.view().into_dyn(), protos_nchw.view().into_dyn()],
            &mut cfg_boxes,
            &mut cfg_masks,
        )
        .unwrap();

    assert_eq!(
        cfg_boxes.len(),
        ref_boxes.len(),
        "config path produced {} boxes, reference produced {}",
        cfg_boxes.len(),
        ref_boxes.len()
    );
    // Without this, `zip` below would silently ignore missing or extra masks.
    assert_eq!(
        cfg_masks.len(),
        ref_masks.len(),
        "config path produced {} masks, reference produced {}",
        cfg_masks.len(),
        ref_masks.len()
    );

    for (i, (cb, rb)) in cfg_boxes.iter().zip(&ref_boxes).enumerate() {
        assert!(
            cb.equal_within_delta(rb, 0.01),
            "box {i} mismatch: config={cb:?}, reference={rb:?}"
        );
    }

    for (i, (cm, rm)) in cfg_masks.iter().zip(&ref_masks).enumerate() {
        let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
        let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
        assert_eq!(
            cm_arr, rm_arr,
            "mask {i} pixel mismatch between config-driven and reference paths"
        );
    }
}
1880
/// Decodes the quantized YOLOv8-seg fixture (boxes `[1, 116, 8400]`, protos
/// `[1, 160, 160, 32]`) four ways and checks they all agree:
/// 1. free function `decode_yolo_segdet_quant` on the quantized tensors,
/// 2. config-built decoder via `decode_quantized`,
/// 3. free function `decode_yolo_segdet_float` on dequantized tensors,
/// 4. config-built decoder via `decode_float` on the same dequantized tensors.
#[test]
fn test_decoder_masks_i8() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;
    let boxes = load_yolov8_boxes();
    // Typed by its use in the `Detection` config below.
    let quant_boxes = (0.021287761628627777, 31).into();

    let protos = load_yolov8_protos();
    let quant_protos = (0.02491161972284317, -117).into();
    let mut output_boxes: Vec<_> = Vec::with_capacity(500);
    let mut output_masks: Vec<_> = Vec::with_capacity(500);

    let decoder = DecoderBuilder::default()
        .with_config_yolo_segdet(
            configs::Detection {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(quant_boxes),
                shape: vec![1, 116, 8400],
                anchors: None,
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumFeatures, 116),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            Protos {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(quant_protos),
                shape: vec![1, 160, 160, 32],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                    (DimName::NumProtos, 32),
                ],
            },
            Some(DecoderVersion::Yolo11),
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Re-bind via `.into()` into the parameter type the free functions take.
    let quant_boxes = quant_boxes.into();
    let quant_protos = quant_protos.into();

    // Path 1: free function on the quantized tensors (batch index 0).
    decode_yolo_segdet_quant(
        (boxes.slice(s![0, .., ..]), quant_boxes),
        (protos.slice(s![0, .., .., ..]), quant_protos),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes,
        &mut output_masks,
    )
    .unwrap();

    let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
    let mut output_masks1: Vec<_> = Vec::with_capacity(500);

    // Path 2: config-built decoder on the quantized tensors.
    decoder
        .decode_quantized(
            &[boxes.view().into(), protos.view().into()],
            &mut output_boxes1,
            &mut output_masks1,
        )
        .unwrap();

    let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
    let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);

    // Path 3: free function on the dequantized f32 tensors.
    let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
    let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
    decode_yolo_segdet_float(
        seg.slice(s![0, .., ..]),
        protos.slice(s![0, .., .., ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes_f32,
        &mut output_masks_f32,
    )
    .unwrap();

    let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
    let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);

    // Path 4: config-built decoder on the dequantized f32 tensors.
    decoder
        .decode_float(
            &[seg.view().into_dyn(), protos.view().into_dyn()],
            &mut output_boxes1_f32,
            &mut output_masks1_f32,
        )
        .unwrap();

    // 1 vs 2: free function vs decoder, quantized.
    compare_outputs(
        (&output_boxes, &output_boxes1),
        (&output_masks, &output_masks1),
    );

    // 1 vs 3: quantized vs float free functions.
    compare_outputs(
        (&output_boxes, &output_boxes_f32),
        (&output_masks, &output_masks_f32),
    );

    // 3 vs 4: free function vs decoder, float.
    compare_outputs(
        (&output_boxes_f32, &output_boxes1_f32),
        (&output_masks_f32, &output_masks1_f32),
    );
}
1992
/// Feeds the YOLOv8 detection rows to a split `Boxes` + `Scores` decoder as
/// two i16 tensors. Checks the result against `decode_yolo_det_float` on the
/// combined dequantized tensor, for both quantized and float decoder inputs.
#[test]
fn test_decoder_yolo_split() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    // Widen every value to i16 (x * 256). Dividing the scale by 256 and
    // multiplying the zero point by 256 cancels that, so dequantized values
    // are unchanged.
    let widened: Vec<_> = load_yolov8_boxes()
        .iter()
        .map(|x| *x as i16 * 256)
        .collect();
    let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), widened).unwrap();
    let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);

    // Rows 0..4 hold box coordinates; rows 4..84 hold the 80 class scores.
    let boxes_cfg = configs::Boxes {
        decoder: configs::DecoderType::Ultralytics,
        quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
        shape: vec![1, 4, 8400],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::BoxCoords, 4),
            (DimName::NumBoxes, 8400),
        ],
        normalized: Some(true),
    };
    let scores_cfg = configs::Scores {
        decoder: configs::DecoderType::Ultralytics,
        quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
        shape: vec![1, 80, 8400],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumClasses, 80),
            (DimName::NumBoxes, 8400),
        ],
    };
    let decoder = DecoderBuilder::default()
        .with_config_yolo_split_det(boxes_cfg, scores_cfg)
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let mut q_boxes: Vec<_> = Vec::with_capacity(500);
    let mut q_masks: Vec<_> = Vec::with_capacity(500);
    decoder
        .decode_quantized(
            &[
                boxes.slice(s![.., ..4, ..]).into(),
                boxes.slice(s![.., 4..84, ..]).into(),
            ],
            &mut q_boxes,
            &mut q_masks,
        )
        .unwrap();

    // Reference: single-tensor float decode over the first 84 rows.
    let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
    let mut reference: Vec<_> = Vec::with_capacity(500);
    decode_yolo_det_float(
        seg.slice(s![0, ..84, ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut reference,
    );

    let mut f_boxes: Vec<_> = Vec::with_capacity(500);
    let mut f_masks: Vec<_> = Vec::with_capacity(500);
    decoder
        .decode_float(
            &[
                seg.slice(s![.., ..4, ..]).into_dyn(),
                seg.slice(s![.., 4..84, ..]).into_dyn(),
            ],
            &mut f_boxes,
            &mut f_masks,
        )
        .unwrap();

    compare_outputs((&q_boxes, &reference), (&q_masks, &[]));
    compare_outputs((&reference, &f_boxes), (&[], &f_masks));
}
2072
/// Split-tensor segmentation decoding with mixed integer widths. Boxes, scores
/// and mask coefficients are widened to i16; protos keep the fixture's own
/// element type (presumably i8 — confirm against `load_yolov8_protos`).
/// Checks the config-built decoder (quantized and float inputs) against
/// `decode_yolo_segdet_float` on the combined dequantized tensors.
#[test]
fn test_decoder_masks_config_mixed() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;
    let boxes_raw = load_yolov8_boxes();
    // x * 256 widens to i16; scale / 256 and zero point * 256 below keep the
    // dequantized values identical to the original fixture.
    let boxes: Vec<_> = boxes_raw.iter().map(|x| *x as i16 * 256).collect();
    let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();

    let quant_boxes = (0.021287761628627777 / 256.0, 31 * 256);

    let protos = load_yolov8_protos();
    let quant_protos = (0.02491161972284317, -117);

    let decoder = build_yolo_split_segdet_decoder(
        score_threshold,
        iou_threshold,
        quant_boxes,
        quant_protos,
    );
    let mut output_boxes: Vec<_> = Vec::with_capacity(500);
    let mut output_masks: Vec<_> = Vec::with_capacity(500);

    // Rows 0..4: boxes, 4..84: class scores, 84..: mask coefficients.
    decoder
        .decode_quantized(
            &[
                boxes.slice(s![.., ..4, ..]).into(),
                boxes.slice(s![.., 4..84, ..]).into(),
                boxes.slice(s![.., 84.., ..]).into(),
                protos.view().into(),
            ],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    // Reference: combined-tensor float decode.
    let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos.into());
    let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes.into());
    let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
    let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
    decode_yolo_segdet_float(
        seg.slice(s![0, .., ..]),
        protos.slice(s![0, .., .., ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes_f32,
        &mut output_masks_f32,
    )
    .unwrap();

    let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
    let mut output_masks1: Vec<_> = Vec::with_capacity(500);

    // Same split layout, but fed as f32 through `decode_float`.
    decoder
        .decode_float(
            &[
                seg.slice(s![.., ..4, ..]).into_dyn(),
                seg.slice(s![.., 4..84, ..]).into_dyn(),
                seg.slice(s![.., 84.., ..]).into_dyn(),
                protos.view().into_dyn(),
            ],
            &mut output_boxes1,
            &mut output_masks1,
        )
        .unwrap();
    compare_outputs(
        (&output_boxes, &output_boxes_f32),
        (&output_masks, &output_masks_f32),
    );
    compare_outputs(
        (&output_boxes_f32, &output_boxes1),
        (&output_masks_f32, &output_masks1),
    );
}
2147
2148 fn build_yolo_split_segdet_decoder(
2149 score_threshold: f32,
2150 iou_threshold: f32,
2151 quant_boxes: (f32, i32),
2152 quant_protos: (f32, i32),
2153 ) -> crate::Decoder {
2154 DecoderBuilder::default()
2155 .with_config_yolo_split_segdet(
2156 configs::Boxes {
2157 decoder: configs::DecoderType::Ultralytics,
2158 quantization: Some(quant_boxes.into()),
2159 shape: vec![1, 4, 8400],
2160 dshape: vec![
2161 (DimName::Batch, 1),
2162 (DimName::BoxCoords, 4),
2163 (DimName::NumBoxes, 8400),
2164 ],
2165 normalized: Some(true),
2166 },
2167 configs::Scores {
2168 decoder: configs::DecoderType::Ultralytics,
2169 quantization: Some(quant_boxes.into()),
2170 shape: vec![1, 80, 8400],
2171 dshape: vec![
2172 (DimName::Batch, 1),
2173 (DimName::NumClasses, 80),
2174 (DimName::NumBoxes, 8400),
2175 ],
2176 },
2177 configs::MaskCoefficients {
2178 decoder: configs::DecoderType::Ultralytics,
2179 quantization: Some(quant_boxes.into()),
2180 shape: vec![1, 32, 8400],
2181 dshape: vec![
2182 (DimName::Batch, 1),
2183 (DimName::NumProtos, 32),
2184 (DimName::NumBoxes, 8400),
2185 ],
2186 },
2187 configs::Protos {
2188 decoder: configs::DecoderType::Ultralytics,
2189 quantization: Some(quant_protos.into()),
2190 shape: vec![1, 160, 160, 32],
2191 dshape: vec![
2192 (DimName::Batch, 1),
2193 (DimName::Height, 160),
2194 (DimName::Width, 160),
2195 (DimName::NumProtos, 32),
2196 ],
2197 },
2198 )
2199 .with_score_threshold(score_threshold)
2200 .with_iou_threshold(iou_threshold)
2201 .build()
2202 .unwrap()
2203 }
2204
2205 fn build_yolov8_seg_decoder(score_threshold: f32, iou_threshold: f32) -> crate::Decoder {
2206 let config_yaml = include_str!(concat!(
2207 env!("CARGO_MANIFEST_DIR"),
2208 "/../../testdata/yolov8_seg.yaml"
2209 ));
2210 DecoderBuilder::default()
2211 .with_config_yaml_str(config_yaml.to_string())
2212 .with_score_threshold(score_threshold)
2213 .with_iou_threshold(iou_threshold)
2214 .build()
2215 .unwrap()
2216 }
/// Split-tensor segmentation decoding with every tensor widened to i32.
/// Compares the config-built quantized decode against
/// `decode_yolo_segdet_float` on the dequantized combined tensors.
#[test]
fn test_decoder_masks_config_i32() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;
    let boxes_raw = load_yolov8_boxes();
    // Widen by 2^23. Assuming the fixtures are i8, |x| * 2^23 <= 2^30 fits in
    // i32. Scale / 2^23 and zero point * 2^23 below undo the widening.
    let scale = 1 << 23;
    let boxes: Vec<_> = boxes_raw.iter().map(|x| *x as i32 * scale).collect();
    let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();

    let quant_boxes = (0.021287761628627777 / scale as f32, 31 * scale);

    let protos_raw = load_yolov8_protos();
    let protos: Vec<_> = protos_raw.iter().map(|x| *x as i32 * scale).collect();
    let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
    let quant_protos = (0.02491161972284317 / scale as f32, -117 * scale);

    let decoder = build_yolo_split_segdet_decoder(
        score_threshold,
        iou_threshold,
        quant_boxes,
        quant_protos,
    );

    let mut output_boxes: Vec<_> = Vec::with_capacity(500);
    let mut output_masks: Vec<_> = Vec::with_capacity(500);

    // Rows 0..4: boxes, 4..84: class scores, 84..: mask coefficients.
    decoder
        .decode_quantized(
            &[
                boxes.slice(s![.., ..4, ..]).into(),
                boxes.slice(s![.., 4..84, ..]).into(),
                boxes.slice(s![.., 84.., ..]).into(),
                protos.view().into(),
            ],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    // Reference: combined-tensor float decode.
    let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos.into());
    let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes.into());
    let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
    let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
    decode_yolo_segdet_float(
        seg.slice(s![0, .., ..]),
        protos.slice(s![0, .., .., ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes_f32,
        &mut output_masks_f32,
    )
    .unwrap();

    assert_eq!(output_boxes.len(), output_boxes_f32.len());
    assert_eq!(output_masks.len(), output_masks_f32.len());

    compare_outputs(
        (&output_boxes, &output_boxes_f32),
        (&output_masks, &output_masks_f32),
    );
}
2279
/// Runs two independent decoding scenarios concurrently on 8 threads (4 of
/// each): YOLO detection and ModelPack split detection + segmentation. Each
/// thread builds its own decoder and decodes 100 times, checking the result
/// every iteration. Presumably this targets hidden shared state between
/// decoder instances — confirm the intent against the decoder implementation.
#[test]
fn test_context_switch() {
    // Scenario 1: the YOLOv8s detection fixture, same expected boxes as
    // `test_decoder_yolo_det` (here with `normalized: None` and no version).
    let yolo_det = || {
        let score_threshold = 0.25;
        let iou_threshold = 0.7;
        let out = load_yolov8s_det();
        let quant = (0.0040811873, -123).into();

        let decoder = DecoderBuilder::default()
            .with_config_yolo_det(
                configs::Detection {
                    decoder: DecoderType::Ultralytics,
                    shape: vec![1, 84, 8400],
                    anchors: None,
                    quantization: Some(quant),
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumFeatures, 84),
                        (DimName::NumBoxes, 8400),
                    ],
                    normalized: None,
                },
                None,
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
        let mut output_masks: Vec<_> = Vec::with_capacity(50);

        // The output vecs are reused across iterations.
        for _ in 0..100 {
            decoder
                .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
                .unwrap();

            assert!(output_boxes[0].equal_within_delta(
                &DetectBox {
                    bbox: BoundingBox {
                        xmin: 0.5285137,
                        ymin: 0.05305544,
                        xmax: 0.87541467,
                        ymax: 0.9998909,
                    },
                    score: 0.5591227,
                    label: 0
                },
                1e-6
            ));

            assert!(output_boxes[1].equal_within_delta(
                &DetectBox {
                    bbox: BoundingBox {
                        xmin: 0.130598,
                        ymin: 0.43260583,
                        xmax: 0.35098213,
                        ymax: 0.9958097,
                    },
                    score: 0.33057618,
                    label: 75
                },
                1e-6
            ));
            // A detection-only config must never produce masks.
            assert!(output_masks.is_empty());
        }
    };

    // Scenario 2: ModelPack with two detection heads plus a segmentation head.
    let modelpack_det_split = || {
        let score_threshold = 0.8;
        let iou_threshold = 0.5;

        let seg = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

        let detect0 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_9x15x18.bin"
        ));
        let detect0 =
            ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

        let detect1 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_17x30x18.bin"
        ));
        let detect1 =
            ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

        // Expected mask: the raw segmentation tensor with the class axis
        // swapped to last ([2, 160, 160] -> [160, 160, 2]) covering [0, 1]^2.
        let mut mask = seg.slice(s![0, .., .., ..]);
        mask.swap_axes(0, 1);
        mask.swap_axes(1, 2);
        let mask = [Segmentation {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.0,
            ymax: 1.0,
            segmentation: mask.into_owned(),
        }];
        let correct_boxes = [DetectBox {
            bbox: BoundingBox {
                xmin: 0.43171933,
                ymin: 0.68243736,
                xmax: 0.5626645,
                ymax: 0.808863,
            },
            score: 0.99240804,
            label: 0,
        }];

        let quant0 = (0.08547406643629074, 174).into();
        let quant1 = (0.09929127991199493, 183).into();
        let quant_seg = (1.0 / 255.0, 0).into();

        let anchors0 = vec![
            [0.36666667461395264, 0.31481480598449707],
            [0.38749998807907104, 0.4740740656852722],
            [0.5333333611488342, 0.644444465637207],
        ];
        let anchors1 = vec![
            [0.13750000298023224, 0.2074074000120163],
            [0.2541666626930237, 0.21481481194496155],
            [0.23125000298023224, 0.35185185074806213],
        ];

        // Note: the detection configs list the 17x30 head first, while the
        // inputs below are passed with the 9x15 head first.
        let decoder = DecoderBuilder::default()
            .with_config_modelpack_segdet_split(
                vec![
                    configs::Detection {
                        decoder: DecoderType::ModelPack,
                        shape: vec![1, 17, 30, 18],
                        anchors: Some(anchors1),
                        quantization: Some(quant1),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::Height, 17),
                            (DimName::Width, 30),
                            (DimName::NumAnchorsXFeatures, 18),
                        ],
                        normalized: None,
                    },
                    configs::Detection {
                        decoder: DecoderType::ModelPack,
                        shape: vec![1, 9, 15, 18],
                        anchors: Some(anchors0),
                        quantization: Some(quant0),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::Height, 9),
                            (DimName::Width, 15),
                            (DimName::NumAnchorsXFeatures, 18),
                        ],
                        normalized: None,
                    },
                ],
                configs::Segmentation {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_seg),
                    shape: vec![1, 2, 160, 160],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumClasses, 2),
                        (DimName::Height, 160),
                        (DimName::Width, 160),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);

        for _ in 0..100 {
            decoder
                .decode_quantized(
                    &[
                        detect0.view().into(),
                        detect1.view().into(),
                        seg.view().into(),
                    ],
                    &mut output_boxes,
                    &mut output_masks,
                )
                .unwrap();

            compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
        }
    };

    // Both closures capture nothing, so they are `Copy` and can be handed to
    // `thread::spawn` repeatedly. The scenarios are interleaved.
    let handles = vec![
        std::thread::spawn(yolo_det),
        std::thread::spawn(modelpack_det_split),
        std::thread::spawn(yolo_det),
        std::thread::spawn(modelpack_det_split),
        std::thread::spawn(yolo_det),
        std::thread::spawn(modelpack_det_split),
        std::thread::spawn(yolo_det),
        std::thread::spawn(modelpack_det_split),
    ];
    // `unwrap` re-raises any assertion panic from a worker thread.
    for handle in handles {
        handle.join().unwrap();
    }
}
2489
/// The same four values converted to corner form by `XYWH` (centre + size:
/// [10, 20, 20, 20] -> [0, 10, 20, 30]) and by `XYXY` (passed through).
#[test]
fn test_ndarray_to_xyxy_float() {
    let coords = array![10.0_f32, 20.0, 20.0, 20.0];

    let from_xywh: [f32; 4] = XYWH::ndarray_to_xyxy_float(coords.view());
    assert_eq!(from_xywh, [0.0_f32, 10.0, 20.0, 30.0]);

    let from_xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(coords.view());
    assert_eq!(from_xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
}
2500
/// Class-aware float NMS keeps overlapping boxes whose labels differ and
/// drops the lower-scoring one when the labels match.
#[test]
fn test_class_aware_nms_float() {
    use crate::float::nms_class_aware_float;

    // Two boxes offset by 0.1: IoU = 0.16 / 0.34 ≈ 0.47, above the 0.3
    // threshold. Only the second box's label varies.
    let pair = |second_label| {
        [
            DetectBox {
                bbox: BoundingBox {
                    xmin: 0.0,
                    ymin: 0.0,
                    xmax: 0.5,
                    ymax: 0.5,
                },
                score: 0.9,
                label: 0,
            },
            DetectBox {
                bbox: BoundingBox {
                    xmin: 0.1,
                    ymin: 0.1,
                    xmax: 0.6,
                    ymax: 0.6,
                },
                score: 0.8,
                label: second_label,
            },
        ]
    };

    let kept = nms_class_aware_float(0.3, Vec::from(pair(1)));
    assert_eq!(
        kept.len(),
        2,
        "Class-aware NMS should keep both boxes with different classes"
    );

    let kept = nms_class_aware_float(0.3, Vec::from(pair(0)));
    assert_eq!(
        kept.len(),
        1,
        "Class-aware NMS should suppress overlapping box with same class"
    );
    let survivor = &kept[0];
    assert_eq!(survivor.label, 0);
    assert!((survivor.score - 0.9).abs() < 1e-6);
}
2571
/// The same two overlapping, differently labelled boxes: class-agnostic NMS
/// reduces them to one box, class-aware NMS keeps both.
#[test]
fn test_class_agnostic_vs_aware_nms() {
    use crate::float::{nms_class_aware_float, nms_float};

    // IoU = 0.16 / 0.34 ≈ 0.47, above the 0.3 threshold used below.
    let overlapping = [
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.0,
                ymin: 0.0,
                xmax: 0.5,
                ymax: 0.5,
            },
            score: 0.9,
            label: 0,
        },
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.1,
                ymin: 0.1,
                xmax: 0.6,
                ymax: 0.6,
            },
            score: 0.8,
            label: 1,
        },
    ];

    let agnostic_result = nms_float(0.3, overlapping.to_vec());
    assert_eq!(
        agnostic_result.len(),
        1,
        "Class-agnostic NMS should suppress overlapping boxes"
    );

    let aware_result = nms_class_aware_float(0.3, Vec::from(overlapping));
    assert_eq!(
        aware_result.len(),
        2,
        "Class-aware NMS should keep boxes with different classes"
    );
}
2616
/// Class-aware integer NMS keeps overlapping boxes whose labels differ and
/// suppresses the lower-scoring one when the labels match.
///
/// The boxes overlap with IoU = 0.16 / 0.34 ≈ 0.47, so the threshold must be
/// below that for the test to discriminate. At 0.5 no box would be suppressed
/// by any NMS variant, and "both kept" would prove nothing. 0.3 matches
/// `test_class_aware_nms_float`.
#[test]
fn test_class_aware_nms_int() {
    use crate::byte::nms_class_aware_int;

    let iou_threshold = 0.3;

    let boxes = vec![
        DetectBoxQuantized {
            bbox: BoundingBox {
                xmin: 0.0,
                ymin: 0.0,
                xmax: 0.5,
                ymax: 0.5,
            },
            score: 200_u8,
            label: 0,
        },
        DetectBoxQuantized {
            bbox: BoundingBox {
                xmin: 0.1,
                ymin: 0.1,
                xmax: 0.6,
                ymax: 0.6,
            },
            score: 180_u8,
            label: 1,
        },
    ];

    let result = nms_class_aware_int(iou_threshold, boxes);
    assert_eq!(
        result.len(),
        2,
        "Class-aware NMS (int) should keep boxes with different classes"
    );

    // Same geometry, same label: the 180-score box must be suppressed.
    let same_class_boxes = vec![
        DetectBoxQuantized {
            bbox: BoundingBox {
                xmin: 0.0,
                ymin: 0.0,
                xmax: 0.5,
                ymax: 0.5,
            },
            score: 200_u8,
            label: 0,
        },
        DetectBoxQuantized {
            bbox: BoundingBox {
                xmin: 0.1,
                ymin: 0.1,
                xmax: 0.6,
                ymax: 0.6,
            },
            score: 180_u8,
            label: 0,
        },
    ];

    let result = nms_class_aware_int(iou_threshold, same_class_boxes);
    assert_eq!(
        result.len(),
        1,
        "Class-aware NMS (int) should suppress overlapping box with same class"
    );
    assert_eq!(result[0].label, 0);
    assert_eq!(result[0].score, 200_u8);
}
2653
2654 #[test]
2655 fn test_nms_enum_default() {
2656 let default_nms: configs::Nms = Default::default();
2658 assert_eq!(default_nms, configs::Nms::ClassAgnostic);
2659 }
2660
2661 #[test]
2662 fn test_decoder_nms_mode() {
2663 let decoder = DecoderBuilder::default()
2665 .with_config_yolo_det(
2666 configs::Detection {
2667 anchors: None,
2668 decoder: DecoderType::Ultralytics,
2669 quantization: None,
2670 shape: vec![1, 84, 8400],
2671 dshape: Vec::new(),
2672 normalized: Some(true),
2673 },
2674 None,
2675 )
2676 .with_nms(Some(configs::Nms::ClassAware))
2677 .build()
2678 .unwrap();
2679
2680 assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
2681 }
2682
2683 #[test]
2684 fn test_decoder_nms_bypass() {
2685 let decoder = DecoderBuilder::default()
2687 .with_config_yolo_det(
2688 configs::Detection {
2689 anchors: None,
2690 decoder: DecoderType::Ultralytics,
2691 quantization: None,
2692 shape: vec![1, 84, 8400],
2693 dshape: Vec::new(),
2694 normalized: Some(true),
2695 },
2696 None,
2697 )
2698 .with_nms(None)
2699 .build()
2700 .unwrap();
2701
2702 assert_eq!(decoder.nms, None);
2703 }
2704
2705 #[test]
2706 fn test_decoder_normalized_boxes_true() {
2707 let decoder = DecoderBuilder::default()
2709 .with_config_yolo_det(
2710 configs::Detection {
2711 anchors: None,
2712 decoder: DecoderType::Ultralytics,
2713 quantization: None,
2714 shape: vec![1, 84, 8400],
2715 dshape: Vec::new(),
2716 normalized: Some(true),
2717 },
2718 None,
2719 )
2720 .build()
2721 .unwrap();
2722
2723 assert_eq!(decoder.normalized_boxes(), Some(true));
2724 }
2725
2726 #[test]
2727 fn test_decoder_normalized_boxes_false() {
2728 let decoder = DecoderBuilder::default()
2731 .with_config_yolo_det(
2732 configs::Detection {
2733 anchors: None,
2734 decoder: DecoderType::Ultralytics,
2735 quantization: None,
2736 shape: vec![1, 84, 8400],
2737 dshape: Vec::new(),
2738 normalized: Some(false),
2739 },
2740 None,
2741 )
2742 .build()
2743 .unwrap();
2744
2745 assert_eq!(decoder.normalized_boxes(), Some(false));
2746 }
2747
2748 #[test]
2749 fn test_decoder_normalized_boxes_unknown() {
2750 let decoder = DecoderBuilder::default()
2752 .with_config_yolo_det(
2753 configs::Detection {
2754 anchors: None,
2755 decoder: DecoderType::Ultralytics,
2756 quantization: None,
2757 shape: vec![1, 84, 8400],
2758 dshape: Vec::new(),
2759 normalized: None,
2760 },
2761 Some(DecoderVersion::Yolo11),
2762 )
2763 .build()
2764 .unwrap();
2765
2766 assert_eq!(decoder.normalized_boxes(), None);
2767 }
2768
2769 pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
2770 input: ArrayView<F, D>,
2771 quant: Quantization,
2772 ) -> Array<T, D>
2773 where
2774 i32: num_traits::AsPrimitive<F>,
2775 f32: num_traits::AsPrimitive<F>,
2776 {
2777 let zero_point = quant.zero_point.as_();
2778 let div_scale = F::one() / quant.scale.as_();
2779 if zero_point != F::zero() {
2780 input.mapv(|d| (d * div_scale + zero_point).round().as_())
2781 } else {
2782 input.mapv(|d| (d * div_scale).round().as_())
2783 }
2784 }
2785
2786 fn real_data_expected_boxes() -> [DetectBox; 2] {
2787 [
2788 DetectBox {
2789 bbox: BoundingBox {
2790 xmin: 0.08515105,
2791 ymin: 0.7131401,
2792 xmax: 0.29802868,
2793 ymax: 0.8195788,
2794 },
2795 score: 0.91537374,
2796 label: 23,
2797 },
2798 DetectBox {
2799 bbox: BoundingBox {
2800 xmin: 0.59605736,
2801 ymin: 0.25545314,
2802 xmax: 0.93666154,
2803 ymax: 0.72378385,
2804 },
2805 score: 0.91537374,
2806 label: 23,
2807 },
2808 ]
2809 }
2810
2811 fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
2812 [DetectBox {
2813 bbox: BoundingBox {
2814 xmin: 0.12549022,
2815 ymin: 0.12549022,
2816 xmax: 0.23529413,
2817 ymax: 0.23529413,
2818 },
2819 score: 0.98823535,
2820 label: 2,
2821 }]
2822 }
2823
2824 fn e2e_expected_boxes_float() -> [DetectBox; 1] {
2825 [DetectBox {
2826 bbox: BoundingBox {
2827 xmin: 0.1234,
2828 ymin: 0.1234,
2829 xmax: 0.2345,
2830 ymax: 0.2345,
2831 },
2832 score: 0.9876,
2833 label: 2,
2834 }]
2835 }
2836
2837 macro_rules! real_data_proto_test {
2838 ($name:ident, quantized, $layout:ident) => {
2839 #[test]
2840 fn $name() {
2841 let is_split = matches!(stringify!($layout), "split");
2842
2843 let score_threshold = 0.45;
2844 let iou_threshold = 0.45;
2845 let quant_boxes = (0.021287762_f32, 31_i32);
2846 let quant_protos = (0.02491162_f32, -117_i32);
2847
2848 let raw_boxes = include_bytes!(concat!(
2849 env!("CARGO_MANIFEST_DIR"),
2850 "/../../testdata/yolov8_boxes_116x8400.bin"
2851 ));
2852 let raw_boxes = unsafe {
2853 std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
2854 };
2855 let boxes_i8 =
2856 ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
2857
2858 let raw_protos = include_bytes!(concat!(
2859 env!("CARGO_MANIFEST_DIR"),
2860 "/../../testdata/yolov8_protos_160x160x32.bin"
2861 ));
2862 let raw_protos = unsafe {
2863 std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
2864 };
2865 let protos_i8 =
2866 ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
2867 .unwrap();
2868
2869 let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
2871 let scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
2872 let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
2873 let boxes_combined = boxes_i8;
2874
2875 let decoder = if is_split {
2876 build_yolo_split_segdet_decoder(
2877 score_threshold,
2878 iou_threshold,
2879 quant_boxes,
2880 quant_protos,
2881 )
2882 } else {
2883 build_yolov8_seg_decoder(score_threshold, iou_threshold)
2884 };
2885
2886 let expected = real_data_expected_boxes();
2887 let mut output_boxes = Vec::with_capacity(50);
2888
2889 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
2890 vec![
2891 boxes_split.view().into(),
2892 scores_split.view().into(),
2893 mask_split.view().into(),
2894 protos_i8.view().into(),
2895 ]
2896 } else {
2897 vec![boxes_combined.view().into(), protos_i8.view().into()]
2898 };
2899 decoder
2900 .decode_quantized_proto(&inputs, &mut output_boxes)
2901 .unwrap();
2902
2903 assert_eq!(output_boxes.len(), 2);
2904 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
2905 assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
2906 }
2907 };
2908 ($name:ident, float, $layout:ident) => {
2909 #[test]
2910 fn $name() {
2911 let is_split = matches!(stringify!($layout), "split");
2912
2913 let score_threshold = 0.45;
2914 let iou_threshold = 0.45;
2915 let quant_boxes = (0.021287762_f32, 31_i32);
2916 let quant_protos = (0.02491162_f32, -117_i32);
2917
2918 let raw_boxes = include_bytes!(concat!(
2919 env!("CARGO_MANIFEST_DIR"),
2920 "/../../testdata/yolov8_boxes_116x8400.bin"
2921 ));
2922 let raw_boxes = unsafe {
2923 std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
2924 };
2925 let boxes_i8 =
2926 ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
2927 let boxes_f32: Array3<f32> =
2928 dequantize_ndarray(boxes_i8.view(), quant_boxes.into());
2929
2930 let raw_protos = include_bytes!(concat!(
2931 env!("CARGO_MANIFEST_DIR"),
2932 "/../../testdata/yolov8_protos_160x160x32.bin"
2933 ));
2934 let raw_protos = unsafe {
2935 std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
2936 };
2937 let protos_i8 =
2938 ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
2939 .unwrap();
2940 let protos_f32: Array4<f32> =
2941 dequantize_ndarray(protos_i8.view(), quant_protos.into());
2942
2943 let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
2945 let scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
2946 let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
2947 let boxes_combined = boxes_f32;
2948
2949 let decoder = if is_split {
2950 build_yolo_split_segdet_decoder(
2951 score_threshold,
2952 iou_threshold,
2953 quant_boxes,
2954 quant_protos,
2955 )
2956 } else {
2957 build_yolov8_seg_decoder(score_threshold, iou_threshold)
2958 };
2959
2960 let expected = real_data_expected_boxes();
2961 let mut output_boxes = Vec::with_capacity(50);
2962
2963 let inputs = if is_split {
2964 vec![
2965 boxes_split.view().into_dyn(),
2966 scores_split.view().into_dyn(),
2967 mask_split.view().into_dyn(),
2968 protos_f32.view().into_dyn(),
2969 ]
2970 } else {
2971 vec![
2972 boxes_combined.view().into_dyn(),
2973 protos_f32.view().into_dyn(),
2974 ]
2975 };
2976 decoder
2977 .decode_float_proto(&inputs, &mut output_boxes)
2978 .unwrap();
2979
2980 assert_eq!(output_boxes.len(), 2);
2981 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
2982 assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
2983 }
2984 };
2985 }
2986
2987 real_data_proto_test!(test_decoder_segdet_proto, quantized, combined);
2988 real_data_proto_test!(test_decoder_segdet_proto_float, float, combined);
2989 real_data_proto_test!(test_decoder_segdet_split_proto, quantized, split);
2990 real_data_proto_test!(test_decoder_segdet_split_proto_float, float, split);
2991
2992 const E2E_COMBINED_DET_CONFIG: &str = "
2993decoder_version: yolo26
2994outputs:
2995 - type: detection
2996 decoder: ultralytics
2997 quantization: [0.00784313725490196, 0]
2998 shape: [1, 10, 6]
2999 dshape:
3000 - [batch, 1]
3001 - [num_boxes, 10]
3002 - [num_features, 6]
3003 normalized: true
3004";
3005
3006 const E2E_COMBINED_SEGDET_CONFIG: &str = "
3007decoder_version: yolo26
3008outputs:
3009 - type: detection
3010 decoder: ultralytics
3011 quantization: [0.00784313725490196, 0]
3012 shape: [1, 10, 38]
3013 dshape:
3014 - [batch, 1]
3015 - [num_boxes, 10]
3016 - [num_features, 38]
3017 normalized: true
3018 - type: protos
3019 decoder: ultralytics
3020 quantization: [0.0039215686274509803921568627451, 128]
3021 shape: [1, 160, 160, 32]
3022 dshape:
3023 - [batch, 1]
3024 - [height, 160]
3025 - [width, 160]
3026 - [num_protos, 32]
3027";
3028
3029 const E2E_SPLIT_DET_CONFIG: &str = "
3030decoder_version: yolo26
3031outputs:
3032 - type: boxes
3033 decoder: ultralytics
3034 quantization: [0.00784313725490196, 0]
3035 shape: [1, 10, 4]
3036 dshape:
3037 - [batch, 1]
3038 - [num_boxes, 10]
3039 - [box_coords, 4]
3040 normalized: true
3041 - type: scores
3042 decoder: ultralytics
3043 quantization: [0.00784313725490196, 0]
3044 shape: [1, 10, 1]
3045 dshape:
3046 - [batch, 1]
3047 - [num_boxes, 10]
3048 - [num_classes, 1]
3049 - type: classes
3050 decoder: ultralytics
3051 quantization: [0.00784313725490196, 0]
3052 shape: [1, 10, 1]
3053 dshape:
3054 - [batch, 1]
3055 - [num_boxes, 10]
3056 - [num_classes, 1]
3057";
3058
3059 const E2E_SPLIT_SEGDET_CONFIG: &str = "
3060decoder_version: yolo26
3061outputs:
3062 - type: boxes
3063 decoder: ultralytics
3064 quantization: [0.00784313725490196, 0]
3065 shape: [1, 10, 4]
3066 dshape:
3067 - [batch, 1]
3068 - [num_boxes, 10]
3069 - [box_coords, 4]
3070 normalized: true
3071 - type: scores
3072 decoder: ultralytics
3073 quantization: [0.00784313725490196, 0]
3074 shape: [1, 10, 1]
3075 dshape:
3076 - [batch, 1]
3077 - [num_boxes, 10]
3078 - [num_classes, 1]
3079 - type: classes
3080 decoder: ultralytics
3081 quantization: [0.00784313725490196, 0]
3082 shape: [1, 10, 1]
3083 dshape:
3084 - [batch, 1]
3085 - [num_boxes, 10]
3086 - [num_classes, 1]
3087 - type: mask_coefficients
3088 decoder: ultralytics
3089 quantization: [0.00784313725490196, 0]
3090 shape: [1, 10, 32]
3091 dshape:
3092 - [batch, 1]
3093 - [num_boxes, 10]
3094 - [num_protos, 32]
3095 - type: protos
3096 decoder: ultralytics
3097 quantization: [0.0039215686274509803921568627451, 128]
3098 shape: [1, 160, 160, 32]
3099 dshape:
3100 - [batch, 1]
3101 - [height, 160]
3102 - [width, 160]
3103 - [num_protos, 32]
3104";
3105
3106 macro_rules! e2e_segdet_test {
3107 ($name:ident, quantized, $layout:ident, $output:ident) => {
3108 #[test]
3109 fn $name() {
3110 let is_split = matches!(stringify!($layout), "split");
3111 let is_proto = matches!(stringify!($output), "proto");
3112
3113 let score_threshold = 0.45;
3114 let iou_threshold = 0.45;
3115
3116 let mut boxes = Array2::zeros((10, 4));
3117 let mut scores = Array2::zeros((10, 1));
3118 let mut classes = Array2::zeros((10, 1));
3119 let mask = Array2::zeros((10, 32));
3120 let protos = Array3::<f64>::zeros((160, 160, 32));
3121 let protos = protos.insert_axis(Axis(0));
3122 let protos_quant = (1.0 / 255.0, 0.0);
3123 let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
3124
3125 boxes
3126 .slice_mut(s![0, ..])
3127 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3128 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3129 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3130
3131 let detect_quant = (2.0 / 255.0, 0.0);
3132
3133 let decoder = if is_split {
3134 DecoderBuilder::default()
3135 .with_config_yaml_str(E2E_SPLIT_SEGDET_CONFIG.to_string())
3136 .with_score_threshold(score_threshold)
3137 .with_iou_threshold(iou_threshold)
3138 .build()
3139 .unwrap()
3140 } else {
3141 DecoderBuilder::default()
3142 .with_config_yaml_str(E2E_COMBINED_SEGDET_CONFIG.to_string())
3143 .with_score_threshold(score_threshold)
3144 .with_iou_threshold(iou_threshold)
3145 .build()
3146 .unwrap()
3147 };
3148
3149 let expected = e2e_expected_boxes_quant();
3150 let mut output_boxes = Vec::with_capacity(50);
3151
3152 if is_split {
3153 let boxes = boxes.insert_axis(Axis(0));
3154 let scores = scores.insert_axis(Axis(0));
3155 let classes = classes.insert_axis(Axis(0));
3156 let mask = mask.insert_axis(Axis(0));
3157
3158 let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
3159 let scores: Array3<u8> = quantize_ndarray(scores.view(), detect_quant.into());
3160 let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
3161 let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
3162
3163 if is_proto {
3164 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3165 boxes.view().into(),
3166 scores.view().into(),
3167 classes.view().into(),
3168 mask.view().into(),
3169 protos.view().into(),
3170 ];
3171 decoder
3172 .decode_quantized_proto(&inputs, &mut output_boxes)
3173 .unwrap();
3174
3175 assert_eq!(output_boxes.len(), 1);
3176 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3177 } else {
3178 let mut output_masks = Vec::with_capacity(50);
3179 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3180 boxes.view().into(),
3181 scores.view().into(),
3182 classes.view().into(),
3183 mask.view().into(),
3184 protos.view().into(),
3185 ];
3186 decoder
3187 .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
3188 .unwrap();
3189
3190 assert_eq!(output_boxes.len(), 1);
3191 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3192 }
3193 } else {
3194 let detect = ndarray::concatenate![
3196 Axis(1),
3197 boxes.view(),
3198 scores.view(),
3199 classes.view(),
3200 mask.view()
3201 ];
3202 let detect = detect.insert_axis(Axis(0));
3203 assert_eq!(detect.shape(), &[1, 10, 38]);
3204 let detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
3205
3206 if is_proto {
3207 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3208 vec![detect.view().into(), protos.view().into()];
3209 decoder
3210 .decode_quantized_proto(&inputs, &mut output_boxes)
3211 .unwrap();
3212
3213 assert_eq!(output_boxes.len(), 1);
3214 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3215 } else {
3216 let mut output_masks = Vec::with_capacity(50);
3217 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3218 vec![detect.view().into(), protos.view().into()];
3219 decoder
3220 .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
3221 .unwrap();
3222
3223 assert_eq!(output_boxes.len(), 1);
3224 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3225 }
3226 }
3227 }
3228 };
3229 ($name:ident, float, $layout:ident, $output:ident) => {
3230 #[test]
3231 fn $name() {
3232 let is_split = matches!(stringify!($layout), "split");
3233 let is_proto = matches!(stringify!($output), "proto");
3234
3235 let score_threshold = 0.45;
3236 let iou_threshold = 0.45;
3237
3238 let mut boxes = Array2::zeros((10, 4));
3239 let mut scores = Array2::zeros((10, 1));
3240 let mut classes = Array2::zeros((10, 1));
3241 let mask: Array2<f64> = Array2::zeros((10, 32));
3242 let protos = Array3::<f64>::zeros((160, 160, 32));
3243 let protos = protos.insert_axis(Axis(0));
3244
3245 boxes
3246 .slice_mut(s![0, ..])
3247 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3248 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3249 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3250
3251 let decoder = if is_split {
3252 DecoderBuilder::default()
3253 .with_config_yaml_str(E2E_SPLIT_SEGDET_CONFIG.to_string())
3254 .with_score_threshold(score_threshold)
3255 .with_iou_threshold(iou_threshold)
3256 .build()
3257 .unwrap()
3258 } else {
3259 DecoderBuilder::default()
3260 .with_config_yaml_str(E2E_COMBINED_SEGDET_CONFIG.to_string())
3261 .with_score_threshold(score_threshold)
3262 .with_iou_threshold(iou_threshold)
3263 .build()
3264 .unwrap()
3265 };
3266
3267 let expected = e2e_expected_boxes_float();
3268 let mut output_boxes = Vec::with_capacity(50);
3269
3270 if is_split {
3271 let boxes = boxes.insert_axis(Axis(0));
3272 let scores = scores.insert_axis(Axis(0));
3273 let classes = classes.insert_axis(Axis(0));
3274 let mask = mask.insert_axis(Axis(0));
3275
3276 if is_proto {
3277 let inputs = vec![
3278 boxes.view().into_dyn(),
3279 scores.view().into_dyn(),
3280 classes.view().into_dyn(),
3281 mask.view().into_dyn(),
3282 protos.view().into_dyn(),
3283 ];
3284 decoder
3285 .decode_float_proto(&inputs, &mut output_boxes)
3286 .unwrap();
3287
3288 assert_eq!(output_boxes.len(), 1);
3289 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3290 } else {
3291 let mut output_masks = Vec::with_capacity(50);
3292 let inputs = vec![
3293 boxes.view().into_dyn(),
3294 scores.view().into_dyn(),
3295 classes.view().into_dyn(),
3296 mask.view().into_dyn(),
3297 protos.view().into_dyn(),
3298 ];
3299 decoder
3300 .decode_float(&inputs, &mut output_boxes, &mut output_masks)
3301 .unwrap();
3302
3303 assert_eq!(output_boxes.len(), 1);
3304 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3305 }
3306 } else {
3307 let detect = ndarray::concatenate![
3309 Axis(1),
3310 boxes.view(),
3311 scores.view(),
3312 classes.view(),
3313 mask.view()
3314 ];
3315 let detect = detect.insert_axis(Axis(0));
3316 assert_eq!(detect.shape(), &[1, 10, 38]);
3317
3318 if is_proto {
3319 let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
3320 decoder
3321 .decode_float_proto(&inputs, &mut output_boxes)
3322 .unwrap();
3323
3324 assert_eq!(output_boxes.len(), 1);
3325 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3326 } else {
3327 let mut output_masks = Vec::with_capacity(50);
3328 let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
3329 decoder
3330 .decode_float(&inputs, &mut output_boxes, &mut output_masks)
3331 .unwrap();
3332
3333 assert_eq!(output_boxes.len(), 1);
3334 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3335 }
3336 }
3337 }
3338 };
3339 }
3340
3341 e2e_segdet_test!(test_decoder_end_to_end_segdet, quantized, combined, masks);
3342 e2e_segdet_test!(test_decoder_end_to_end_segdet_float, float, combined, masks);
3343 e2e_segdet_test!(
3344 test_decoder_end_to_end_segdet_proto,
3345 quantized,
3346 combined,
3347 proto
3348 );
3349 e2e_segdet_test!(
3350 test_decoder_end_to_end_segdet_proto_float,
3351 float,
3352 combined,
3353 proto
3354 );
3355 e2e_segdet_test!(
3356 test_decoder_end_to_end_segdet_split,
3357 quantized,
3358 split,
3359 masks
3360 );
3361 e2e_segdet_test!(
3362 test_decoder_end_to_end_segdet_split_float,
3363 float,
3364 split,
3365 masks
3366 );
3367 e2e_segdet_test!(
3368 test_decoder_end_to_end_segdet_split_proto,
3369 quantized,
3370 split,
3371 proto
3372 );
3373 e2e_segdet_test!(
3374 test_decoder_end_to_end_segdet_split_proto_float,
3375 float,
3376 split,
3377 proto
3378 );
3379
3380 macro_rules! e2e_det_test {
3381 ($name:ident, quantized, $layout:ident) => {
3382 #[test]
3383 fn $name() {
3384 let is_split = matches!(stringify!($layout), "split");
3385
3386 let score_threshold = 0.45;
3387 let iou_threshold = 0.45;
3388
3389 let mut boxes = Array3::zeros((1, 10, 4));
3390 let mut scores = Array3::zeros((1, 10, 1));
3391 let mut classes = Array3::zeros((1, 10, 1));
3392
3393 boxes
3394 .slice_mut(s![0, 0, ..])
3395 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3396 scores.slice_mut(s![0, 0, ..]).assign(&array![0.9876]);
3397 classes.slice_mut(s![0, 0, ..]).assign(&array![2.0]);
3398
3399 let detect_quant = (2.0 / 255.0, 0_i32);
3400
3401 let decoder = if is_split {
3402 DecoderBuilder::default()
3403 .with_config_yaml_str(E2E_SPLIT_DET_CONFIG.to_string())
3404 .with_score_threshold(score_threshold)
3405 .with_iou_threshold(iou_threshold)
3406 .build()
3407 .unwrap()
3408 } else {
3409 DecoderBuilder::default()
3410 .with_config_yaml_str(E2E_COMBINED_DET_CONFIG.to_string())
3411 .with_score_threshold(score_threshold)
3412 .with_iou_threshold(iou_threshold)
3413 .build()
3414 .unwrap()
3415 };
3416
3417 let expected = e2e_expected_boxes_quant();
3418 let mut output_boxes = Vec::with_capacity(50);
3419
3420 if is_split {
3421 let boxes: Array<u8, _> = quantize_ndarray(boxes.view(), detect_quant.into());
3422 let scores: Array<u8, _> = quantize_ndarray(scores.view(), detect_quant.into());
3423 let classes: Array<u8, _> =
3424 quantize_ndarray(classes.view(), detect_quant.into());
3425 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3426 boxes.view().into(),
3427 scores.view().into(),
3428 classes.view().into(),
3429 ];
3430 decoder
3431 .decode_quantized(&inputs, &mut output_boxes, &mut Vec::new())
3432 .unwrap();
3433 } else {
3434 let detect =
3435 ndarray::concatenate![Axis(2), boxes.view(), scores.view(), classes.view()];
3436 assert_eq!(detect.shape(), &[1, 10, 6]);
3437 let detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
3438 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3439 vec![detect.view().into()];
3440 decoder
3441 .decode_quantized(&inputs, &mut output_boxes, &mut Vec::new())
3442 .unwrap();
3443 }
3444
3445 assert_eq!(output_boxes.len(), 1);
3446 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3447 }
3448 };
3449 ($name:ident, float, $layout:ident) => {
3450 #[test]
3451 fn $name() {
3452 let is_split = matches!(stringify!($layout), "split");
3453
3454 let score_threshold = 0.45;
3455 let iou_threshold = 0.45;
3456
3457 let mut boxes = Array3::zeros((1, 10, 4));
3458 let mut scores = Array3::zeros((1, 10, 1));
3459 let mut classes = Array3::zeros((1, 10, 1));
3460
3461 boxes
3462 .slice_mut(s![0, 0, ..])
3463 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3464 scores.slice_mut(s![0, 0, ..]).assign(&array![0.9876]);
3465 classes.slice_mut(s![0, 0, ..]).assign(&array![2.0]);
3466
3467 let decoder = if is_split {
3468 DecoderBuilder::default()
3469 .with_config_yaml_str(E2E_SPLIT_DET_CONFIG.to_string())
3470 .with_score_threshold(score_threshold)
3471 .with_iou_threshold(iou_threshold)
3472 .build()
3473 .unwrap()
3474 } else {
3475 DecoderBuilder::default()
3476 .with_config_yaml_str(E2E_COMBINED_DET_CONFIG.to_string())
3477 .with_score_threshold(score_threshold)
3478 .with_iou_threshold(iou_threshold)
3479 .build()
3480 .unwrap()
3481 };
3482
3483 let expected = e2e_expected_boxes_float();
3484 let mut output_boxes = Vec::with_capacity(50);
3485
3486 if is_split {
3487 let inputs = vec![
3488 boxes.view().into_dyn(),
3489 scores.view().into_dyn(),
3490 classes.view().into_dyn(),
3491 ];
3492 decoder
3493 .decode_float(&inputs, &mut output_boxes, &mut Vec::new())
3494 .unwrap();
3495 } else {
3496 let detect =
3497 ndarray::concatenate![Axis(2), boxes.view(), scores.view(), classes.view()];
3498 assert_eq!(detect.shape(), &[1, 10, 6]);
3499 let inputs = vec![detect.view().into_dyn()];
3500 decoder
3501 .decode_float(&inputs, &mut output_boxes, &mut Vec::new())
3502 .unwrap();
3503 }
3504
3505 assert_eq!(output_boxes.len(), 1);
3506 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3507 }
3508 };
3509 }
3510
3511 e2e_det_test!(test_decoder_end_to_end_combined_det, quantized, combined);
3512 e2e_det_test!(test_decoder_end_to_end_combined_det_float, float, combined);
3513 e2e_det_test!(test_decoder_end_to_end_split_det, quantized, split);
3514 e2e_det_test!(test_decoder_end_to_end_split_det_float, float, split);
3515
3516 #[test]
3517 fn test_decode_tensor() {
3518 let score_threshold = 0.45;
3519 let iou_threshold = 0.45;
3520
3521 let raw_boxes = include_bytes!(concat!(
3522 env!("CARGO_MANIFEST_DIR"),
3523 "/../../testdata/yolov8_boxes_116x8400.bin"
3524 ));
3525 let raw_boxes =
3526 unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3527 let boxes_i8: Tensor<i8> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3528 boxes_i8
3529 .map()
3530 .unwrap()
3531 .as_mut_slice()
3532 .copy_from_slice(raw_boxes);
3533 let boxes_i8 = boxes_i8.into();
3534
3535 let raw_protos = include_bytes!(concat!(
3536 env!("CARGO_MANIFEST_DIR"),
3537 "/../../testdata/yolov8_protos_160x160x32.bin"
3538 ));
3539 let raw_protos = unsafe {
3540 std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3541 };
3542 let protos_i8: Tensor<i8> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3543 protos_i8
3544 .map()
3545 .unwrap()
3546 .as_mut_slice()
3547 .copy_from_slice(raw_protos);
3548 let protos_i8 = protos_i8.into();
3549
3550 let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3551 let expected = real_data_expected_boxes();
3552 let mut output_boxes = Vec::with_capacity(50);
3553
3554 decoder
3555 .decode(&[&boxes_i8, &protos_i8], &mut output_boxes, &mut Vec::new())
3556 .unwrap();
3557
3558 assert_eq!(output_boxes.len(), 2);
3559 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3560 assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3561 }
3562
3563 #[test]
3564 fn test_decode_tensor_f32() {
3565 let score_threshold = 0.45;
3566 let iou_threshold = 0.45;
3567
3568 let quant_boxes = (0.021287762_f32, 31_i32);
3569 let quant_protos = (0.02491162_f32, -117_i32);
3570 let raw_boxes = include_bytes!(concat!(
3571 env!("CARGO_MANIFEST_DIR"),
3572 "/../../testdata/yolov8_boxes_116x8400.bin"
3573 ));
3574 let raw_boxes =
3575 unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3576 let mut raw_boxes_f32 = vec![0f32; raw_boxes.len()];
3577 dequantize_cpu(raw_boxes, quant_boxes.into(), &mut raw_boxes_f32);
3578 let boxes_f32: Tensor<f32> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3579 boxes_f32
3580 .map()
3581 .unwrap()
3582 .as_mut_slice()
3583 .copy_from_slice(&raw_boxes_f32);
3584 let boxes_f32 = boxes_f32.into();
3585
3586 let raw_protos = include_bytes!(concat!(
3587 env!("CARGO_MANIFEST_DIR"),
3588 "/../../testdata/yolov8_protos_160x160x32.bin"
3589 ));
3590 let raw_protos = unsafe {
3591 std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3592 };
3593 let mut raw_protos_f32 = vec![0f32; raw_protos.len()];
3594 dequantize_cpu(raw_protos, quant_protos.into(), &mut raw_protos_f32);
3595 let protos_f32: Tensor<f32> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3596 protos_f32
3597 .map()
3598 .unwrap()
3599 .as_mut_slice()
3600 .copy_from_slice(&raw_protos_f32);
3601 let protos_f32 = protos_f32.into();
3602
3603 let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3604
3605 let expected = real_data_expected_boxes();
3606 let mut output_boxes = Vec::with_capacity(50);
3607
3608 decoder
3609 .decode(
3610 &[&boxes_f32, &protos_f32],
3611 &mut output_boxes,
3612 &mut Vec::new(),
3613 )
3614 .unwrap();
3615
3616 assert_eq!(output_boxes.len(), 2);
3617 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3618 assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3619 }
3620
3621 #[test]
3622 fn test_decode_tensor_f64() {
3623 let score_threshold = 0.45;
3624 let iou_threshold = 0.45;
3625
3626 let quant_boxes = (0.021287762_f32, 31_i32);
3627 let quant_protos = (0.02491162_f32, -117_i32);
3628 let raw_boxes = include_bytes!(concat!(
3629 env!("CARGO_MANIFEST_DIR"),
3630 "/../../testdata/yolov8_boxes_116x8400.bin"
3631 ));
3632 let raw_boxes =
3633 unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3634 let mut raw_boxes_f64 = vec![0f64; raw_boxes.len()];
3635 dequantize_cpu(raw_boxes, quant_boxes.into(), &mut raw_boxes_f64);
3636 let boxes_f64: Tensor<f64> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3637 boxes_f64
3638 .map()
3639 .unwrap()
3640 .as_mut_slice()
3641 .copy_from_slice(&raw_boxes_f64);
3642 let boxes_f64 = boxes_f64.into();
3643
3644 let raw_protos = include_bytes!(concat!(
3645 env!("CARGO_MANIFEST_DIR"),
3646 "/../../testdata/yolov8_protos_160x160x32.bin"
3647 ));
3648 let raw_protos = unsafe {
3649 std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3650 };
3651 let mut raw_protos_f64 = vec![0f64; raw_protos.len()];
3652 dequantize_cpu(raw_protos, quant_protos.into(), &mut raw_protos_f64);
3653 let protos_f64: Tensor<f64> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3654 protos_f64
3655 .map()
3656 .unwrap()
3657 .as_mut_slice()
3658 .copy_from_slice(&raw_protos_f64);
3659 let protos_f64 = protos_f64.into();
3660
3661 let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3662
3663 let expected = real_data_expected_boxes();
3664 let mut output_boxes = Vec::with_capacity(50);
3665
3666 decoder
3667 .decode(
3668 &[&boxes_f64, &protos_f64],
3669 &mut output_boxes,
3670 &mut Vec::new(),
3671 )
3672 .unwrap();
3673
3674 assert_eq!(output_boxes.len(), 2);
3675 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3676 assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3677 }
3678
/// Runs the real YOLOv8 segmentation fixtures (int8 boxes + int8 protos)
/// through `decode_proto` and checks the detections as well as the shape of
/// the returned mask coefficients (one row of 32 per detection).
#[test]
fn test_decode_tensor_proto() {
    // Copies raw fixture bytes into an i8 tensor buffer. `as i8` keeps each
    // byte's bit pattern (the same reinterpretation a `*const u8 -> *const i8`
    // pointer cast performs), so no `unsafe` block is needed.
    fn fill_from_bytes(dst: &mut [i8], src: &[u8]) {
        assert_eq!(dst.len(), src.len(), "fixture size must match tensor size");
        for (d, &s) in dst.iter_mut().zip(src) {
            *d = s as i8;
        }
    }

    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let raw_boxes = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ));
    let boxes_i8: Tensor<i8> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
    fill_from_bytes(boxes_i8.map().unwrap().as_mut_slice(), raw_boxes);
    let boxes_i8 = boxes_i8.into();

    let raw_protos = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    ));
    let protos_i8: Tensor<i8> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
    fill_from_bytes(protos_i8.map().unwrap().as_mut_slice(), raw_protos);
    let protos_i8 = protos_i8.into();

    let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);

    let expected = real_data_expected_boxes();
    let mut output_boxes = Vec::with_capacity(50);

    let proto_data = decoder
        .decode_proto(&[&boxes_i8, &protos_i8], &mut output_boxes)
        .unwrap();

    assert_eq!(output_boxes.len(), 2);
    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));

    let proto_data = proto_data.expect("segmentation model should return ProtoData");
    let coeffs_shape = proto_data.mask_coefficients.shape();
    assert_eq!(
        coeffs_shape[0],
        output_boxes.len(),
        "mask_coefficients count must match detection count"
    );
    assert_eq!(
        coeffs_shape[1], 32,
        "each detection should have 32 mask coefficients"
    );
}
3738
/// A TFLite-style NHWC protos tensor declared through `dshape` must decode to
/// the same boxes, and pixel-identical masks, as `decode_yolo_segdet_float`
/// called directly on the HWC protos.
#[test]
fn test_physical_order_tflite_nhwc_protos() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
    let quant_protos = Quantization::new(0.02491161972284317, -117);
    let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);

    let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
    let quant_boxes = Quantization::new(0.021287761628627777, 31);
    let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);

    // Reference result: the low-level decoder fed the HWC protos directly.
    let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
    let mut ref_masks: Vec<_> = Vec::with_capacity(10);
    decode_yolo_segdet_float(
        seg.view(),
        protos_f32_hwc.view(),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut ref_boxes,
        &mut ref_masks,
    )
    .unwrap();
    // Without reference detections the count check and the zip-based loops
    // below would all pass trivially.
    assert!(
        !ref_boxes.is_empty(),
        "reference decode produced no detections; comparison would be vacuous"
    );

    // Add the leading batch axis. Neither HWC array is used afterwards, so
    // they can be moved instead of cloned.
    let protos_nhwc = protos_f32_hwc.insert_axis(Axis(0));
    let seg_3d = seg.insert_axis(Axis(0));

    let decoder = DecoderBuilder::default()
        .with_config_yolo_segdet(
            configs::Detection {
                decoder: configs::DecoderType::Ultralytics,
                quantization: None,
                shape: vec![1, 116, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumFeatures, 116),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
                anchors: None,
            },
            configs::Protos {
                decoder: configs::DecoderType::Ultralytics,
                quantization: None,
                shape: vec![1, 160, 160, 32],
                // Physical NHWC order, declared axis by axis.
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                    (DimName::NumProtos, 32),
                ],
            },
            None,
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .expect("config with NHWC protos dshape must build");

    let mut cfg_boxes = Vec::with_capacity(10);
    let mut cfg_masks = Vec::with_capacity(10);
    decoder
        .decode_float(
            &[seg_3d.view().into_dyn(), protos_nhwc.view().into_dyn()],
            &mut cfg_boxes,
            &mut cfg_masks,
        )
        .unwrap();

    assert_eq!(cfg_boxes.len(), ref_boxes.len(), "box count mismatch");
    // `zip` below stops at the shorter list, so check mask counts explicitly.
    assert_eq!(cfg_masks.len(), ref_masks.len(), "mask count mismatch");
    for (c, r) in cfg_boxes.iter().zip(&ref_boxes) {
        assert!(
            c.equal_within_delta(r, 0.01),
            "NHWC-declared box does not match reference: {c:?} vs {r:?}"
        );
    }
    for (cm, rm) in cfg_masks.iter().zip(&ref_masks) {
        let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
        let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
        assert_eq!(
            cm_arr, rm_arr,
            "NHWC-declared mask must match reference pixel-for-pixel"
        );
    }
}
3848
/// Split boxes/scores declared in Ara-2 "anchor-first" physical order
/// (`[1, N, 4]` / `[1, N, 80]`) must decode to the same detections as the
/// canonical features-first layout (`[1, 4, N]` / `[1, 80, N]`).
#[test]
fn test_physical_order_ara2_anchor_first_split_boxes() {
    use configs::{Boxes, Scores};

    const N: usize = 8400;

    // Class-agnostic split-det decoder (score and IoU thresholds 0.5). Each
    // tensor's `shape` is the list of sizes from its dshape.
    fn split_det_decoder(
        boxes_dshape: Vec<(DimName, usize)>,
        scores_dshape: Vec<(DimName, usize)>,
        build_msg: &str,
    ) -> crate::Decoder {
        fn sizes(dshape: &[(DimName, usize)]) -> Vec<usize> {
            dshape.iter().map(|(_, size)| *size).collect()
        }

        DecoderBuilder::default()
            .with_config_yolo_split_det(
                Boxes {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: None,
                    shape: sizes(&boxes_dshape),
                    dshape: boxes_dshape,
                    normalized: Some(true),
                },
                Scores {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: None,
                    shape: sizes(&scores_dshape),
                    dshape: scores_dshape,
                },
            )
            .with_score_threshold(0.5)
            .with_iou_threshold(0.5)
            .with_nms(Some(configs::Nms::ClassAgnostic))
            .build()
            .expect(build_msg)
    }

    let anchor = 42usize;

    // Canonical features-first tensors holding one seeded detection:
    // box values (0.4, 0.5, 0.2, 0.2) and a class-0 score of 0.9 at `anchor`.
    let mut boxes_canonical = Array3::<f32>::zeros((1, 4, N));
    for (coord, &value) in [0.4f32, 0.5, 0.2, 0.2].iter().enumerate() {
        boxes_canonical[[0, coord, anchor]] = value;
    }
    let mut scores_canonical = Array3::<f32>::zeros((1, 80, N));
    scores_canonical[[0, 0, anchor]] = 0.9;

    let reference = split_det_decoder(
        vec![
            (DimName::Batch, 1),
            (DimName::BoxCoords, 4),
            (DimName::NumBoxes, N),
        ],
        vec![
            (DimName::Batch, 1),
            (DimName::NumClasses, 80),
            (DimName::NumBoxes, N),
        ],
        "reference canonical split decoder must build",
    );

    let mut ref_boxes = Vec::with_capacity(4);
    let mut ref_masks = Vec::with_capacity(0);
    reference
        .decode_float(
            &[
                boxes_canonical.view().into_dyn(),
                scores_canonical.view().into_dyn(),
            ],
            &mut ref_boxes,
            &mut ref_masks,
        )
        .unwrap();
    assert_eq!(ref_boxes.len(), 1, "reference should produce one box");

    // Same values, physically transposed into anchor-first order.
    let boxes_ara2 = boxes_canonical.view().permuted_axes([0, 2, 1]).to_owned();
    let scores_ara2 = scores_canonical.view().permuted_axes([0, 2, 1]).to_owned();

    let ara2 = split_det_decoder(
        vec![
            (DimName::Batch, 1),
            (DimName::NumBoxes, N),
            (DimName::BoxCoords, 4),
        ],
        vec![
            (DimName::Batch, 1),
            (DimName::NumBoxes, N),
            (DimName::NumClasses, 80),
        ],
        "Ara-2 anchor-first decoder must build",
    );

    let mut ara2_boxes = Vec::with_capacity(4);
    let mut ara2_masks = Vec::with_capacity(0);
    ara2.decode_float(
        &[boxes_ara2.view().into_dyn(), scores_ara2.view().into_dyn()],
        &mut ara2_boxes,
        &mut ara2_masks,
    )
    .unwrap();

    assert_eq!(
        ara2_boxes.len(),
        ref_boxes.len(),
        "Ara-2 anchor-first declaration must produce the same number \
         of boxes as the canonical features-first reference"
    );
    for (a, r) in ara2_boxes.iter().zip(&ref_boxes) {
        assert!(
            a.equal_within_delta(r, 1e-4),
            "Ara-2 box differs from reference: {a:?} vs {r:?}"
        );
    }
}
3980
/// A protos dshape whose per-axis sizes disagree with the declared shape
/// (NHWC names over an NCHW `[1, 32, 160, 160]` shape) must be rejected by
/// `build()` with an `InvalidConfig` error.
#[test]
fn test_physical_order_rejects_shape_dshape_mismatch() {
    let detection = configs::Detection {
        decoder: configs::DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 116, 8400],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumFeatures, 116),
            (DimName::NumBoxes, 8400),
        ],
        normalized: Some(true),
        anchors: None,
    };

    // shape[1] is 32 while dshape index 1 claims Height = 160.
    let protos = configs::Protos {
        decoder: configs::DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 32, 160, 160],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 160),
            (DimName::Width, 160),
            (DimName::NumProtos, 32),
        ],
    };

    let outcome = DecoderBuilder::default()
        .with_config_yolo_segdet(detection, protos, None)
        .build();

    let msg = match outcome {
        Err(DecoderError::InvalidConfig(msg)) => msg,
        other => panic!("expected InvalidConfig, got {other:?}"),
    };
    assert!(
        msg.contains("does not match shape"),
        "expected shape/dshape size mismatch error, got: {msg}"
    );
}
4028
/// A dshape that names the same axis twice must be rejected at build time.
/// Two cases are covered: a repeated `BoxCoords`, whose index 2 also disagrees
/// with `shape`, and a repeated `Batch` on an otherwise size-consistent 4-D
/// shape.
#[test]
fn test_physical_order_rejects_duplicate_dshape_axis() {
    // The valid scores config shared by both cases; built fresh per use.
    let scores = || configs::Scores {
        decoder: configs::DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 80, 8400],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumClasses, 80),
            (DimName::NumBoxes, 8400),
        ],
    };

    // Case 1: BoxCoords repeated. Either the duplicate check or the
    // positional size check may report it first.
    let repeated_coords = configs::Boxes {
        decoder: configs::DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 4, 8400],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::BoxCoords, 4),
            (DimName::BoxCoords, 4),
        ],
        normalized: Some(true),
    };
    let msg = match DecoderBuilder::default()
        .with_config_yolo_split_det(repeated_coords, scores())
        .build()
    {
        Err(DecoderError::InvalidConfig(msg)) => msg,
        other => panic!("expected InvalidConfig, got {other:?}"),
    };
    assert!(
        msg.contains("appears at both index") || msg.contains("does not match shape"),
        "expected positional or duplicate-axis error, got: {msg}"
    );

    // Case 2: Batch repeated. All sizes line up with `shape`, so only the
    // duplicate-axis check can reject it.
    let repeated_batch = configs::Boxes {
        decoder: configs::DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 1, 4, 8400],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Batch, 1),
            (DimName::BoxCoords, 4),
            (DimName::NumBoxes, 8400),
        ],
        normalized: Some(true),
    };
    let msg = match DecoderBuilder::default()
        .with_config_yolo_split_det(repeated_batch, scores())
        .build()
    {
        Err(DecoderError::InvalidConfig(msg)) => msg,
        other => panic!("expected InvalidConfig, got {other:?}"),
    };
    assert!(
        msg.contains("appears at both index"),
        "expected duplicate-axis error, got: {msg}"
    );
}
4113
/// Leaving `dshape` empty on both outputs must decode the NHWC fixtures to
/// the same boxes, and pixel-identical masks, as fully populated dshapes.
#[test]
fn test_physical_order_dshape_omitted_decodes_numerically() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
    let quant_protos = Quantization::new(0.02491161972284317, -117);
    let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);

    let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
    let quant_boxes = Quantization::new(0.021287761628627777, 31);
    let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);

    // Add the leading batch axis. The HWC arrays are not used afterwards,
    // so they are moved rather than cloned.
    let protos_nhwc = protos_f32_hwc.insert_axis(Axis(0));
    let seg_3d = seg.insert_axis(Axis(0));

    // The same segdet config, differing only in the supplied dshapes.
    let build_decoder = |det_dshape: Vec<(DimName, usize)>,
                         proto_dshape: Vec<(DimName, usize)>| {
        DecoderBuilder::default()
            .with_config_yolo_segdet(
                configs::Detection {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: None,
                    shape: vec![1, 116, 8400],
                    dshape: det_dshape,
                    normalized: Some(true),
                    anchors: None,
                },
                configs::Protos {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: None,
                    shape: vec![1, 160, 160, 32],
                    dshape: proto_dshape,
                },
                None,
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap()
    };

    let dshaped = build_decoder(
        vec![
            (DimName::Batch, 1),
            (DimName::NumFeatures, 116),
            (DimName::NumBoxes, 8400),
        ],
        vec![
            (DimName::Batch, 1),
            (DimName::Height, 160),
            (DimName::Width, 160),
            (DimName::NumProtos, 32),
        ],
    );
    let mut dshaped_boxes = Vec::new();
    let mut dshaped_masks = Vec::new();
    dshaped
        .decode_float(
            &[seg_3d.view().into_dyn(), protos_nhwc.view().into_dyn()],
            &mut dshaped_boxes,
            &mut dshaped_masks,
        )
        .unwrap();
    // Without detections every comparison below would pass trivially.
    assert!(
        !dshaped_boxes.is_empty(),
        "dshape-populated decode produced no detections; comparison would be vacuous"
    );

    let bare = build_decoder(vec![], vec![]);
    let mut bare_boxes = Vec::new();
    let mut bare_masks = Vec::new();
    bare.decode_float(
        &[seg_3d.view().into_dyn(), protos_nhwc.view().into_dyn()],
        &mut bare_boxes,
        &mut bare_masks,
    )
    .unwrap();

    assert_eq!(bare_boxes.len(), dshaped_boxes.len());
    // `zip` below stops at the shorter list, so check mask counts explicitly.
    assert_eq!(bare_masks.len(), dshaped_masks.len(), "mask count mismatch");
    for (b, d) in bare_boxes.iter().zip(&dshaped_boxes) {
        assert!(
            b.equal_within_delta(d, 1e-4),
            "dshape-omitted box {b:?} differs from dshape-populated {d:?}"
        );
    }
    for (bm, dm) in bare_masks.iter().zip(&dshaped_masks) {
        let bm_arr = segmentation_to_mask(bm.segmentation.view()).unwrap();
        let dm_arr = segmentation_to_mask(dm.segmentation.view()).unwrap();
        assert_eq!(
            bm_arr, dm_arr,
            "dshape-omitted mask must match dshape-populated pixel-for-pixel"
        );
    }
}
4213
/// A 4-D anchor-first JSON schema with a unit `padding` axis
/// (`[1, 8400, 1, 4]` / `[1, 8400, 1, 80]`) must build and then decode 3-D
/// `[1, N, C]` inputs, recovering the single seeded detection.
#[test]
fn test_physical_order_ara2_4d_anchor_first_with_padding() {
    const N: usize = 8400;
    let anchor = 42usize;

    // One seeded anchor with box values (0.4, 0.5, 0.2, 0.2) and a class-0
    // score of 0.9. Every other anchor stays all-zero.
    let mut boxes = Array3::<f32>::zeros((1, N, 4));
    for (coord, &value) in [0.4f32, 0.5, 0.2, 0.2].iter().enumerate() {
        boxes[[0, anchor, coord]] = value;
    }
    let mut scores = Array3::<f32>::zeros((1, N, 80));
    scores[[0, anchor, 0]] = 0.9;

    let json = r#"{
        "schema_version": 2,
        "decoder_version": "yolov8",
        "nms": "class_agnostic",
        "outputs": [
            {"name": "boxes", "type": "boxes",
             "shape": [1, 8400, 1, 4],
             "dshape": [{"batch":1},{"num_boxes":8400},{"padding":1},{"box_coords":4}],
             "encoding": "direct",
             "decoder": "ultralytics",
             "normalized": true},
            {"name": "scores", "type": "scores",
             "shape": [1, 8400, 1, 80],
             "dshape": [{"batch":1},{"num_boxes":8400},{"padding":1},{"num_classes":80}],
             "decoder": "ultralytics",
             "score_format": "per_class"}
        ]
    }"#;

    let decoder = DecoderBuilder::default()
        .with_config_json_str(json.to_string())
        .with_score_threshold(0.5)
        .with_iou_threshold(0.5)
        .build()
        .expect("4D anchor-first schema should build via squeeze_padding_dims");

    let mut out_boxes = Vec::with_capacity(4);
    let mut out_masks = Vec::with_capacity(0);
    decoder
        .decode_float(
            &[boxes.view().into_dyn(), scores.view().into_dyn()],
            &mut out_boxes,
            &mut out_masks,
        )
        .unwrap();

    assert_eq!(
        out_boxes.len(),
        1,
        "4D anchor-first with padding should decode exactly one box from the seeded anchor"
    );

    let b = &out_boxes[0];
    // Expected corners are centre +/- half the size of the seeded box.
    let corners = [
        ("xmin", b.bbox.xmin, 0.3),
        ("ymin", b.bbox.ymin, 0.4),
        ("xmax", b.bbox.xmax, 0.5),
        ("ymax", b.bbox.ymax, 0.6),
    ];
    for (corner, actual, wanted) in corners {
        assert!((actual - wanted).abs() < 1e-3, "{corner} wrong: {b:?}");
    }
    assert_eq!(b.label, 0);
    assert!(b.score > 0.85, "score {}: {b:?}", b.score);
}
4293}
4294
4295#[cfg(feature = "tracker")]
4296#[cfg(test)]
4297#[cfg_attr(coverage_nightly, coverage(off))]
4298mod decoder_tracked_tests {
4299
4300 use edgefirst_tracker::{ByteTrackBuilder, Tracker};
4301 use ndarray::{array, s, Array, Array2, Array3, Array4, ArrayView, Axis, Dimension};
4302 use num_traits::{AsPrimitive, Float, PrimInt};
4303 use rand::{RngExt, SeedableRng};
4304 use rand_distr::StandardNormal;
4305
4306 use crate::{
4307 configs::{self, DimName},
4308 dequantize_ndarray, BoundingBox, DecoderBuilder, DetectBox, Quantization,
4309 };
4310
4311 pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
4312 input: ArrayView<F, D>,
4313 quant: Quantization,
4314 ) -> Array<T, D>
4315 where
4316 i32: num_traits::AsPrimitive<F>,
4317 f32: num_traits::AsPrimitive<F>,
4318 {
4319 let zero_point = quant.zero_point.as_();
4320 let div_scale = F::one() / quant.scale.as_();
4321 if zero_point != F::zero() {
4322 input.mapv(|d| (d * div_scale + zero_point).round().as_())
4323 } else {
4324 input.mapv(|d| (d * div_scale).round().as_())
4325 }
4326 }
4327
/// Decodes the int8 `yolov8s_80_classes.bin` fixture through a
/// ByteTrack-backed decoder, then runs 100 more frames.
///
/// In each extra frame, every anchor's values in feature rows 0 and 1 (the x
/// and y box values) get clamped Gaussian noise of roughly +/-1e-2 in
/// dequantized units. Every frame must still report the two expected boxes
/// (within 5e-3) and stay within 1e-3 of the previous frame's boxes.
#[test]
fn test_decoder_tracked_random_jitter() {
    use crate::configs::{DecoderType, Nms};
    use crate::DecoderBuilder;

    let score_threshold = 0.25;
    let iou_threshold = 0.1;

    let raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8s_80_classes.bin"
    ));
    // Each byte is an int8 value; `as i8` keeps the bit pattern unchanged.
    let out = Array3::from_shape_vec(
        (1, 84, 8400),
        raw.iter().map(|&byte| byte as i8).collect(),
    )
    .unwrap();
    let quant = (0.0040811873, -123).into();

    let decoder = DecoderBuilder::default()
        .with_config_yolo_det(
            crate::configs::Detection {
                decoder: DecoderType::Ultralytics,
                shape: vec![1, 84, 8400],
                anchors: None,
                quantization: Some(quant),
                dshape: vec![
                    (crate::configs::DimName::Batch, 1),
                    (crate::configs::DimName::NumFeatures, 84),
                    (crate::configs::DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            None,
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .with_nms(Some(Nms::ClassAgnostic))
        .build()
        .unwrap();

    let expected_boxes = [
        crate::DetectBox {
            bbox: crate::BoundingBox {
                xmin: 0.5285137,
                ymin: 0.05305544,
                xmax: 0.87541467,
                ymax: 0.9998909,
            },
            score: 0.5591227,
            label: 0,
        },
        crate::DetectBox {
            bbox: crate::BoundingBox {
                xmin: 0.130598,
                ymin: 0.43260583,
                xmax: 0.35098213,
                ymax: 0.9958097,
            },
            score: 0.33057618,
            label: 75,
        },
    ];

    let mut tracker = ByteTrackBuilder::new()
        .track_update(0.1)
        .track_high_conf(0.3)
        .build();

    let mut output_boxes = Vec::with_capacity(50);
    let mut output_masks = Vec::with_capacity(50);
    let mut output_tracks = Vec::with_capacity(50);

    // Frame 0: the unmodified fixture.
    decoder
        .decode_tracked_quantized(
            &mut tracker,
            0,
            &[out.view().into()],
            &mut output_boxes,
            &mut output_masks,
            &mut output_tracks,
        )
        .unwrap();

    assert_eq!(output_boxes.len(), 2);
    assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
    assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));

    let mut previous = output_boxes.clone();
    let mut rng = rand::rngs::StdRng::seed_from_u64(0xAB_BEEF);

    for frame in 1..=100 {
        let mut jittered = out.clone();
        // Row 0 is fully jittered before row 1, which fixes the order in
        // which the seeded RNG is consumed.
        for row in 0..2usize {
            for value in jittered.slice_mut(s![0, row, ..]).iter_mut() {
                let r: f32 = rng.sample(StandardNormal);
                let r = r.clamp(-2.0, 2.0) / 2.0;
                *value = value.saturating_add((r * 1e-2 / quant.0) as i8);
            }
        }

        decoder
            .decode_tracked_quantized(
                &mut tracker,
                100_000_000 * frame / 3,
                &[jittered.view().into()],
                &mut output_boxes,
                &mut output_masks,
                &mut output_tracks,
            )
            .unwrap();

        assert_eq!(output_boxes.len(), 2);
        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 5e-3));
        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 5e-3));

        assert!(output_boxes[0].equal_within_delta(&previous[0], 1e-3));
        assert!(output_boxes[1].equal_within_delta(&previous[1], 1e-3));
        previous = output_boxes.clone();
    }
}
4452
4453 fn real_data_expected_boxes() -> [DetectBox; 2] {
4456 [
4457 DetectBox {
4458 bbox: BoundingBox {
4459 xmin: 0.08515105,
4460 ymin: 0.7131401,
4461 xmax: 0.29802868,
4462 ymax: 0.8195788,
4463 },
4464 score: 0.91537374,
4465 label: 23,
4466 },
4467 DetectBox {
4468 bbox: BoundingBox {
4469 xmin: 0.59605736,
4470 ymin: 0.25545314,
4471 xmax: 0.93666154,
4472 ymax: 0.72378385,
4473 },
4474 score: 0.91537374,
4475 label: 23,
4476 },
4477 ]
4478 }
4479
4480 fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
4481 [DetectBox {
4482 bbox: BoundingBox {
4483 xmin: 0.12549022,
4484 ymin: 0.12549022,
4485 xmax: 0.23529413,
4486 ymax: 0.23529413,
4487 },
4488 score: 0.98823535,
4489 label: 2,
4490 }]
4491 }
4492
4493 fn e2e_expected_boxes_float() -> [DetectBox; 1] {
4494 [DetectBox {
4495 bbox: BoundingBox {
4496 xmin: 0.1234,
4497 ymin: 0.1234,
4498 xmax: 0.2345,
4499 ymax: 0.2345,
4500 },
4501 score: 0.9876,
4502 label: 2,
4503 }]
4504 }
4505
4506 fn build_yolo_split_segdet_decoder(
4507 score_threshold: f32,
4508 iou_threshold: f32,
4509 quant_boxes: (f32, i32),
4510 quant_protos: (f32, i32),
4511 ) -> crate::Decoder {
4512 DecoderBuilder::default()
4513 .with_config_yolo_split_segdet(
4514 configs::Boxes {
4515 decoder: configs::DecoderType::Ultralytics,
4516 quantization: Some(quant_boxes.into()),
4517 shape: vec![1, 4, 8400],
4518 dshape: vec![
4519 (DimName::Batch, 1),
4520 (DimName::BoxCoords, 4),
4521 (DimName::NumBoxes, 8400),
4522 ],
4523 normalized: Some(true),
4524 },
4525 configs::Scores {
4526 decoder: configs::DecoderType::Ultralytics,
4527 quantization: Some(quant_boxes.into()),
4528 shape: vec![1, 80, 8400],
4529 dshape: vec![
4530 (DimName::Batch, 1),
4531 (DimName::NumClasses, 80),
4532 (DimName::NumBoxes, 8400),
4533 ],
4534 },
4535 configs::MaskCoefficients {
4536 decoder: configs::DecoderType::Ultralytics,
4537 quantization: Some(quant_boxes.into()),
4538 shape: vec![1, 32, 8400],
4539 dshape: vec![
4540 (DimName::Batch, 1),
4541 (DimName::NumProtos, 32),
4542 (DimName::NumBoxes, 8400),
4543 ],
4544 },
4545 configs::Protos {
4546 decoder: configs::DecoderType::Ultralytics,
4547 quantization: Some(quant_protos.into()),
4548 shape: vec![1, 160, 160, 32],
4549 dshape: vec![
4550 (DimName::Batch, 1),
4551 (DimName::Height, 160),
4552 (DimName::Width, 160),
4553 (DimName::NumProtos, 32),
4554 ],
4555 },
4556 )
4557 .with_score_threshold(score_threshold)
4558 .with_iou_threshold(iou_threshold)
4559 .build()
4560 .unwrap()
4561 }
4562
4563 fn build_yolov8_seg_decoder(score_threshold: f32, iou_threshold: f32) -> crate::Decoder {
4564 let config_yaml = include_str!(concat!(
4565 env!("CARGO_MANIFEST_DIR"),
4566 "/../../testdata/yolov8_seg.yaml"
4567 ));
4568 DecoderBuilder::default()
4569 .with_config_yaml_str(config_yaml.to_string())
4570 .with_score_threshold(score_threshold)
4571 .with_iou_threshold(iou_threshold)
4572 .build()
4573 .unwrap()
4574 }
4575
/// Generates one real-data tracked-decode test.
///
/// * `$name`: name of the generated `#[test]` function.
/// * `quantized` / `float`: feed the int8 fixtures directly, or dequantize
///   them to float first.
/// * `$layout`: `split` feeds separate boxes / scores / mask-coefficient /
///   protos tensors to the split decoder. Any other ident feeds the combined
///   `[1, 116, 8400]` tensor plus protos to the YAML-configured decoder.
/// * `$output`: `proto` uses the `*_proto` decode entry points. Any other
///   ident decodes full masks.
///
/// Frame 1 (timestamp 0) must yield the two reference detections.
///
/// Frame 2 (timestamp `100_000_000 / 3`) drives every class score to its
/// minimum (`i8::MIN` quantized, `0.0` float). The tracker must still report
/// exactly two boxes at their frame-1 positions, and no masks or mask
/// coefficients may be produced.
macro_rules! real_data_tracked_test {
    ($name:ident, quantized, $layout:ident, $output:ident) => {
        #[test]
        fn $name() {
            let is_split = matches!(stringify!($layout), "split");
            let is_proto = matches!(stringify!($output), "proto");

            let score_threshold = 0.45;
            let iou_threshold = 0.45;
            let quant_boxes = (0.021287762_f32, 31_i32);
            let quant_protos = (0.02491162_f32, -117_i32);

            // The fixtures are raw int8 bytes. `as i8` reinterprets each byte
            // (two's complement) without an `unsafe` slice cast.
            let raw_boxes = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/yolov8_boxes_116x8400.bin"
            ));
            let boxes_i8 = ndarray::Array3::from_shape_vec(
                (1, 116, 8400),
                raw_boxes.iter().map(|&b| b as i8).collect(),
            )
            .unwrap();

            let raw_protos = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/yolov8_protos_160x160x32.bin"
            ));
            let protos_i8 = ndarray::Array4::from_shape_vec(
                (1, 160, 160, 32),
                raw_protos.iter().map(|&b| b as i8).collect(),
            )
            .unwrap();

            // Split slices of the combined tensor: rows ..4, 4..84 and 84..
            // become the boxes, scores and mask-coefficient tensors.
            let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
            let mut scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
            let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
            let mut boxes_combined = boxes_i8;

            let decoder = if is_split {
                build_yolo_split_segdet_decoder(
                    score_threshold,
                    iou_threshold,
                    quant_boxes,
                    quant_protos,
                )
            } else {
                build_yolov8_seg_decoder(score_threshold, iou_threshold)
            };

            // Decoder inputs built from the current array contents. It is used
            // as `&inputs!()`, so the borrows end with each decode statement
            // and the scores can be mutated between frames.
            macro_rules! inputs {
                () => {{
                    let views: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
                        vec![
                            boxes_split.view().into(),
                            scores_split.view().into(),
                            mask_split.view().into(),
                            protos_i8.view().into(),
                        ]
                    } else {
                        vec![boxes_combined.view().into(), protos_i8.view().into()]
                    };
                    views
                }};
            }

            let expected = real_data_expected_boxes();
            let mut tracker = ByteTrackBuilder::new()
                .track_update(0.1)
                .track_high_conf(0.7)
                .build();
            let mut output_boxes = Vec::with_capacity(50);
            let mut output_masks = Vec::with_capacity(50);
            let mut output_tracks = Vec::with_capacity(50);

            // Frame 1: real scores yield both reference detections.
            if is_proto {
                decoder
                    .decode_tracked_quantized_proto(
                        &mut tracker,
                        0,
                        &inputs!(),
                        &mut output_boxes,
                        &mut output_tracks,
                    )
                    .unwrap();
            } else {
                decoder
                    .decode_tracked_quantized(
                        &mut tracker,
                        0,
                        &inputs!(),
                        &mut output_boxes,
                        &mut output_masks,
                        &mut output_tracks,
                    )
                    .unwrap();
            }
            assert_eq!(output_boxes.len(), 2);
            assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
            assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));

            // Frame 2: push every class score to the minimum so nothing new
            // is detected.
            if is_split {
                for score in scores_split.iter_mut() {
                    *score = i8::MIN;
                }
            } else {
                for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
                    *score = i8::MIN;
                }
            }

            let proto_result = if is_proto {
                decoder
                    .decode_tracked_quantized_proto(
                        &mut tracker,
                        100_000_000 / 3,
                        &inputs!(),
                        &mut output_boxes,
                        &mut output_tracks,
                    )
                    .unwrap()
            } else {
                decoder
                    .decode_tracked_quantized(
                        &mut tracker,
                        100_000_000 / 3,
                        &inputs!(),
                        &mut output_boxes,
                        &mut output_masks,
                        &mut output_tracks,
                    )
                    .unwrap();
                None
            };

            // Only the two existing tracks may be reported, at their frame-1
            // positions.
            assert_eq!(output_boxes.len(), 2, "both tracks should still be reported");
            assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
            assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
            if is_proto {
                assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
            } else {
                assert!(output_masks.is_empty());
            }
        }
    };
    ($name:ident, float, $layout:ident, $output:ident) => {
        #[test]
        fn $name() {
            let is_split = matches!(stringify!($layout), "split");
            let is_proto = matches!(stringify!($output), "proto");

            let score_threshold = 0.45;
            let iou_threshold = 0.45;
            let quant_boxes = (0.021287762_f32, 31_i32);
            let quant_protos = (0.02491162_f32, -117_i32);

            // The fixtures are raw int8 bytes. `as i8` reinterprets each byte
            // (two's complement) without an `unsafe` slice cast; the result is
            // then dequantized.
            let raw_boxes = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/yolov8_boxes_116x8400.bin"
            ));
            let boxes_i8 = ndarray::Array3::from_shape_vec(
                (1, 116, 8400),
                raw_boxes.iter().map(|&b| b as i8).collect(),
            )
            .unwrap();
            let boxes_f32 = dequantize_ndarray(boxes_i8.view(), quant_boxes.into());

            let raw_protos = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/yolov8_protos_160x160x32.bin"
            ));
            let protos_i8 = ndarray::Array4::from_shape_vec(
                (1, 160, 160, 32),
                raw_protos.iter().map(|&b| b as i8).collect(),
            )
            .unwrap();
            let protos_f32 = dequantize_ndarray(protos_i8.view(), quant_protos.into());

            // Split slices of the combined tensor: rows ..4, 4..84 and 84..
            // become the boxes, scores and mask-coefficient tensors.
            let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
            let mut scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
            let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
            let mut boxes_combined = boxes_f32;

            let decoder = if is_split {
                build_yolo_split_segdet_decoder(
                    score_threshold,
                    iou_threshold,
                    quant_boxes,
                    quant_protos,
                )
            } else {
                build_yolov8_seg_decoder(score_threshold, iou_threshold)
            };

            // Decoder inputs built from the current array contents. It is used
            // as `&inputs!()`, so the borrows end with each decode statement
            // and the scores can be mutated between frames.
            macro_rules! inputs {
                () => {{
                    if is_split {
                        vec![
                            boxes_split.view().into_dyn(),
                            scores_split.view().into_dyn(),
                            mask_split.view().into_dyn(),
                            protos_f32.view().into_dyn(),
                        ]
                    } else {
                        vec![
                            boxes_combined.view().into_dyn(),
                            protos_f32.view().into_dyn(),
                        ]
                    }
                }};
            }

            let expected = real_data_expected_boxes();
            let mut tracker = ByteTrackBuilder::new()
                .track_update(0.1)
                .track_high_conf(0.7)
                .build();
            let mut output_boxes = Vec::with_capacity(50);
            let mut output_masks = Vec::with_capacity(50);
            let mut output_tracks = Vec::with_capacity(50);

            // Frame 1: real scores yield both reference detections.
            if is_proto {
                decoder
                    .decode_tracked_float_proto(
                        &mut tracker,
                        0,
                        &inputs!(),
                        &mut output_boxes,
                        &mut output_tracks,
                    )
                    .unwrap();
            } else {
                decoder
                    .decode_tracked_float(
                        &mut tracker,
                        0,
                        &inputs!(),
                        &mut output_boxes,
                        &mut output_masks,
                        &mut output_tracks,
                    )
                    .unwrap();
            }
            assert_eq!(output_boxes.len(), 2);
            assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
            assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));

            // Frame 2: zero every class score so nothing new is detected.
            if is_split {
                for score in scores_split.iter_mut() {
                    *score = 0.0;
                }
            } else {
                for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
                    *score = 0.0;
                }
            }

            let proto_result = if is_proto {
                decoder
                    .decode_tracked_float_proto(
                        &mut tracker,
                        100_000_000 / 3,
                        &inputs!(),
                        &mut output_boxes,
                        &mut output_tracks,
                    )
                    .unwrap()
            } else {
                decoder
                    .decode_tracked_float(
                        &mut tracker,
                        100_000_000 / 3,
                        &inputs!(),
                        &mut output_boxes,
                        &mut output_masks,
                        &mut output_tracks,
                    )
                    .unwrap();
                None
            };

            // Only the two existing tracks may be reported, at their frame-1
            // positions.
            assert_eq!(output_boxes.len(), 2, "both tracks should still be reported");
            assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
            assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
            if is_proto {
                assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
            } else {
                assert!(output_masks.is_empty());
            }
        }
    };
}
4966
4967 real_data_tracked_test!(test_decoder_tracked_segdet, quantized, combined, masks);
4968 real_data_tracked_test!(test_decoder_tracked_segdet_float, float, combined, masks);
4969 real_data_tracked_test!(
4970 test_decoder_tracked_segdet_proto,
4971 quantized,
4972 combined,
4973 proto
4974 );
4975 real_data_tracked_test!(
4976 test_decoder_tracked_segdet_proto_float,
4977 float,
4978 combined,
4979 proto
4980 );
4981 real_data_tracked_test!(test_decoder_tracked_segdet_split, quantized, split, masks);
4982 real_data_tracked_test!(test_decoder_tracked_segdet_split_float, float, split, masks);
4983 real_data_tracked_test!(
4984 test_decoder_tracked_segdet_split_proto,
4985 quantized,
4986 split,
4987 proto
4988 );
4989 real_data_tracked_test!(
4990 test_decoder_tracked_segdet_split_proto_float,
4991 float,
4992 split,
4993 proto
4994 );
4995
    /// YAML decoder configuration (`decoder_version: yolo26`) for the end-to-end
    /// tracking tests using the *combined* layout.
    ///
    /// It declares two outputs:
    /// - a single `detection` output of shape `[1, 10, 38]` (10 boxes × 38 features);
    /// - a `protos` output of shape `[1, 160, 160, 32]`.
    ///
    /// The `quantization` pairs are presumably `[scale, zero_point]`:
    /// `0.00784313725490196` equals 2/255 and `0.00392156...` equals 1/255.
    const E2E_COMBINED_CONFIG: &str = "
decoder_version: yolo26
outputs:
  - type: detection
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 38]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_features, 38]
    normalized: true
  - type: protos
    decoder: ultralytics
    quantization: [0.0039215686274509803921568627451, 128]
    shape: [1, 160, 160, 32]
    dshape:
      - [batch, 1]
      - [height, 160]
      - [width, 160]
      - [num_protos, 32]
";
5023
    /// YAML decoder configuration (`decoder_version: yolo26`) for the end-to-end
    /// tracking tests using the *split* layout.
    ///
    /// Boxes, scores, classes and mask coefficients are separate outputs:
    /// - `boxes`: `[1, 10, 4]`
    /// - `scores`: `[1, 10, 1]`
    /// - `classes`: `[1, 10, 1]`
    /// - `mask_coefficients`: `[1, 10, 32]`
    ///
    /// These are followed by the same `[1, 160, 160, 32]` `protos` output as the
    /// combined configuration. Quantization pairs are presumably `[scale, zero_point]`.
    const E2E_SPLIT_CONFIG: &str = "
decoder_version: yolo26
outputs:
  - type: boxes
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 4]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [box_coords, 4]
    normalized: true
  - type: scores
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 1]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_classes, 1]
  - type: classes
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 1]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_classes, 1]
  - type: mask_coefficients
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 32]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_protos, 32]
  - type: protos
    decoder: ultralytics
    quantization: [0.0039215686274509803921568627451, 128]
    shape: [1, 160, 160, 32]
    dshape:
      - [batch, 1]
      - [height, 160]
      - [width, 160]
      - [num_protos, 32]
";
5070
5071 macro_rules! e2e_tracked_test {
5072 ($name:ident, quantized, $layout:ident, $output:ident) => {
5073 #[test]
5074 fn $name() {
5075 let is_split = matches!(stringify!($layout), "split");
5076 let is_proto = matches!(stringify!($output), "proto");
5077
5078 let score_threshold = 0.45;
5079 let iou_threshold = 0.45;
5080
5081 let mut boxes = Array2::zeros((10, 4));
5082 let mut scores = Array2::zeros((10, 1));
5083 let mut classes = Array2::zeros((10, 1));
5084 let mask = Array2::zeros((10, 32));
5085 let protos = Array3::<f64>::zeros((160, 160, 32));
5086 let protos = protos.insert_axis(Axis(0));
5087 let protos_quant = (1.0 / 255.0, 0.0);
5088 let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
5089
5090 boxes
5091 .slice_mut(s![0, ..])
5092 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5093 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5094 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5095
5096 let detect_quant = (2.0 / 255.0, 0.0);
5097
5098 let decoder = if is_split {
5099 DecoderBuilder::default()
5100 .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5101 .with_score_threshold(score_threshold)
5102 .with_iou_threshold(iou_threshold)
5103 .build()
5104 .unwrap()
5105 } else {
5106 DecoderBuilder::default()
5107 .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5108 .with_score_threshold(score_threshold)
5109 .with_iou_threshold(iou_threshold)
5110 .build()
5111 .unwrap()
5112 };
5113
5114 let expected = e2e_expected_boxes_quant();
5115 let mut tracker = ByteTrackBuilder::new()
5116 .track_update(0.1)
5117 .track_high_conf(0.7)
5118 .build();
5119 let mut output_boxes = Vec::with_capacity(50);
5120 let mut output_tracks = Vec::with_capacity(50);
5121
5122 if is_split {
5123 let boxes = boxes.insert_axis(Axis(0));
5124 let scores = scores.insert_axis(Axis(0));
5125 let classes = classes.insert_axis(Axis(0));
5126 let mask = mask.insert_axis(Axis(0));
5127
5128 let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
5129 let mut scores: Array3<u8> =
5130 quantize_ndarray(scores.view(), detect_quant.into());
5131 let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
5132 let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
5133
5134 if is_proto {
5135 {
5136 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5137 boxes.view().into(),
5138 scores.view().into(),
5139 classes.view().into(),
5140 mask.view().into(),
5141 protos.view().into(),
5142 ];
5143 decoder
5144 .decode_tracked_quantized_proto(
5145 &mut tracker,
5146 0,
5147 &inputs,
5148 &mut output_boxes,
5149 &mut output_tracks,
5150 )
5151 .unwrap();
5152 }
5153 assert_eq!(output_boxes.len(), 1);
5154 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5155
5156 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5157 *score = u8::MIN;
5158 }
5159 let proto_result = {
5160 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5161 boxes.view().into(),
5162 scores.view().into(),
5163 classes.view().into(),
5164 mask.view().into(),
5165 protos.view().into(),
5166 ];
5167 decoder
5168 .decode_tracked_quantized_proto(
5169 &mut tracker,
5170 100_000_000 / 3,
5171 &inputs,
5172 &mut output_boxes,
5173 &mut output_tracks,
5174 )
5175 .unwrap()
5176 };
5177 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5178 assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5179 } else {
5180 let mut output_masks = Vec::with_capacity(50);
5181 {
5182 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5183 boxes.view().into(),
5184 scores.view().into(),
5185 classes.view().into(),
5186 mask.view().into(),
5187 protos.view().into(),
5188 ];
5189 decoder
5190 .decode_tracked_quantized(
5191 &mut tracker,
5192 0,
5193 &inputs,
5194 &mut output_boxes,
5195 &mut output_masks,
5196 &mut output_tracks,
5197 )
5198 .unwrap();
5199 }
5200 assert_eq!(output_boxes.len(), 1);
5201 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5202
5203 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5204 *score = u8::MIN;
5205 }
5206 {
5207 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5208 boxes.view().into(),
5209 scores.view().into(),
5210 classes.view().into(),
5211 mask.view().into(),
5212 protos.view().into(),
5213 ];
5214 decoder
5215 .decode_tracked_quantized(
5216 &mut tracker,
5217 100_000_000 / 3,
5218 &inputs,
5219 &mut output_boxes,
5220 &mut output_masks,
5221 &mut output_tracks,
5222 )
5223 .unwrap();
5224 }
5225 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5226 assert!(output_masks.is_empty());
5227 }
5228 } else {
5229 let detect = ndarray::concatenate![
5231 Axis(1),
5232 boxes.view(),
5233 scores.view(),
5234 classes.view(),
5235 mask.view()
5236 ];
5237 let detect = detect.insert_axis(Axis(0));
5238 assert_eq!(detect.shape(), &[1, 10, 38]);
5239 let mut detect: Array3<u8> =
5240 quantize_ndarray(detect.view(), detect_quant.into());
5241
5242 if is_proto {
5243 {
5244 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5245 vec![detect.view().into(), protos.view().into()];
5246 decoder
5247 .decode_tracked_quantized_proto(
5248 &mut tracker,
5249 0,
5250 &inputs,
5251 &mut output_boxes,
5252 &mut output_tracks,
5253 )
5254 .unwrap();
5255 }
5256 assert_eq!(output_boxes.len(), 1);
5257 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5258
5259 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5260 *score = u8::MIN;
5261 }
5262 let proto_result = {
5263 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5264 vec![detect.view().into(), protos.view().into()];
5265 decoder
5266 .decode_tracked_quantized_proto(
5267 &mut tracker,
5268 100_000_000 / 3,
5269 &inputs,
5270 &mut output_boxes,
5271 &mut output_tracks,
5272 )
5273 .unwrap()
5274 };
5275 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5276 assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5277 } else {
5278 let mut output_masks = Vec::with_capacity(50);
5279 {
5280 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5281 vec![detect.view().into(), protos.view().into()];
5282 decoder
5283 .decode_tracked_quantized(
5284 &mut tracker,
5285 0,
5286 &inputs,
5287 &mut output_boxes,
5288 &mut output_masks,
5289 &mut output_tracks,
5290 )
5291 .unwrap();
5292 }
5293 assert_eq!(output_boxes.len(), 1);
5294 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5295
5296 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5297 *score = u8::MIN;
5298 }
5299 {
5300 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5301 vec![detect.view().into(), protos.view().into()];
5302 decoder
5303 .decode_tracked_quantized(
5304 &mut tracker,
5305 100_000_000 / 3,
5306 &inputs,
5307 &mut output_boxes,
5308 &mut output_masks,
5309 &mut output_tracks,
5310 )
5311 .unwrap();
5312 }
5313 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5314 assert!(output_masks.is_empty());
5315 }
5316 }
5317 }
5318 };
5319 ($name:ident, float, $layout:ident, $output:ident) => {
5320 #[test]
5321 fn $name() {
5322 let is_split = matches!(stringify!($layout), "split");
5323 let is_proto = matches!(stringify!($output), "proto");
5324
5325 let score_threshold = 0.45;
5326 let iou_threshold = 0.45;
5327
5328 let mut boxes = Array2::zeros((10, 4));
5329 let mut scores = Array2::zeros((10, 1));
5330 let mut classes = Array2::zeros((10, 1));
5331 let mask: Array2<f64> = Array2::zeros((10, 32));
5332 let protos = Array3::<f64>::zeros((160, 160, 32));
5333 let protos = protos.insert_axis(Axis(0));
5334
5335 boxes
5336 .slice_mut(s![0, ..])
5337 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5338 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5339 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5340
5341 let decoder = if is_split {
5342 DecoderBuilder::default()
5343 .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5344 .with_score_threshold(score_threshold)
5345 .with_iou_threshold(iou_threshold)
5346 .build()
5347 .unwrap()
5348 } else {
5349 DecoderBuilder::default()
5350 .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5351 .with_score_threshold(score_threshold)
5352 .with_iou_threshold(iou_threshold)
5353 .build()
5354 .unwrap()
5355 };
5356
5357 let expected = e2e_expected_boxes_float();
5358 let mut tracker = ByteTrackBuilder::new()
5359 .track_update(0.1)
5360 .track_high_conf(0.7)
5361 .build();
5362 let mut output_boxes = Vec::with_capacity(50);
5363 let mut output_tracks = Vec::with_capacity(50);
5364
5365 if is_split {
5366 let boxes = boxes.insert_axis(Axis(0));
5367 let mut scores = scores.insert_axis(Axis(0));
5368 let classes = classes.insert_axis(Axis(0));
5369 let mask = mask.insert_axis(Axis(0));
5370
5371 if is_proto {
5372 {
5373 let inputs = vec![
5374 boxes.view().into_dyn(),
5375 scores.view().into_dyn(),
5376 classes.view().into_dyn(),
5377 mask.view().into_dyn(),
5378 protos.view().into_dyn(),
5379 ];
5380 decoder
5381 .decode_tracked_float_proto(
5382 &mut tracker,
5383 0,
5384 &inputs,
5385 &mut output_boxes,
5386 &mut output_tracks,
5387 )
5388 .unwrap();
5389 }
5390 assert_eq!(output_boxes.len(), 1);
5391 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5392
5393 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5394 *score = 0.0;
5395 }
5396 let proto_result = {
5397 let inputs = vec![
5398 boxes.view().into_dyn(),
5399 scores.view().into_dyn(),
5400 classes.view().into_dyn(),
5401 mask.view().into_dyn(),
5402 protos.view().into_dyn(),
5403 ];
5404 decoder
5405 .decode_tracked_float_proto(
5406 &mut tracker,
5407 100_000_000 / 3,
5408 &inputs,
5409 &mut output_boxes,
5410 &mut output_tracks,
5411 )
5412 .unwrap()
5413 };
5414 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5415 assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5416 } else {
5417 let mut output_masks = Vec::with_capacity(50);
5418 {
5419 let inputs = vec![
5420 boxes.view().into_dyn(),
5421 scores.view().into_dyn(),
5422 classes.view().into_dyn(),
5423 mask.view().into_dyn(),
5424 protos.view().into_dyn(),
5425 ];
5426 decoder
5427 .decode_tracked_float(
5428 &mut tracker,
5429 0,
5430 &inputs,
5431 &mut output_boxes,
5432 &mut output_masks,
5433 &mut output_tracks,
5434 )
5435 .unwrap();
5436 }
5437 assert_eq!(output_boxes.len(), 1);
5438 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5439
5440 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5441 *score = 0.0;
5442 }
5443 {
5444 let inputs = vec![
5445 boxes.view().into_dyn(),
5446 scores.view().into_dyn(),
5447 classes.view().into_dyn(),
5448 mask.view().into_dyn(),
5449 protos.view().into_dyn(),
5450 ];
5451 decoder
5452 .decode_tracked_float(
5453 &mut tracker,
5454 100_000_000 / 3,
5455 &inputs,
5456 &mut output_boxes,
5457 &mut output_masks,
5458 &mut output_tracks,
5459 )
5460 .unwrap();
5461 }
5462 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5463 assert!(output_masks.is_empty());
5464 }
5465 } else {
5466 let detect = ndarray::concatenate![
5468 Axis(1),
5469 boxes.view(),
5470 scores.view(),
5471 classes.view(),
5472 mask.view()
5473 ];
5474 let mut detect = detect.insert_axis(Axis(0));
5475 assert_eq!(detect.shape(), &[1, 10, 38]);
5476
5477 if is_proto {
5478 {
5479 let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5480 decoder
5481 .decode_tracked_float_proto(
5482 &mut tracker,
5483 0,
5484 &inputs,
5485 &mut output_boxes,
5486 &mut output_tracks,
5487 )
5488 .unwrap();
5489 }
5490 assert_eq!(output_boxes.len(), 1);
5491 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5492
5493 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5494 *score = 0.0;
5495 }
5496 let proto_result = {
5497 let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5498 decoder
5499 .decode_tracked_float_proto(
5500 &mut tracker,
5501 100_000_000 / 3,
5502 &inputs,
5503 &mut output_boxes,
5504 &mut output_tracks,
5505 )
5506 .unwrap()
5507 };
5508 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5509 assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5510 } else {
5511 let mut output_masks = Vec::with_capacity(50);
5512 {
5513 let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5514 decoder
5515 .decode_tracked_float(
5516 &mut tracker,
5517 0,
5518 &inputs,
5519 &mut output_boxes,
5520 &mut output_masks,
5521 &mut output_tracks,
5522 )
5523 .unwrap();
5524 }
5525 assert_eq!(output_boxes.len(), 1);
5526 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5527
5528 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5529 *score = 0.0;
5530 }
5531 {
5532 let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5533 decoder
5534 .decode_tracked_float(
5535 &mut tracker,
5536 100_000_000 / 3,
5537 &inputs,
5538 &mut output_boxes,
5539 &mut output_masks,
5540 &mut output_tracks,
5541 )
5542 .unwrap();
5543 }
5544 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5545 assert!(output_masks.is_empty());
5546 }
5547 }
5548 }
5549 };
5550 }
5551
5552 e2e_tracked_test!(
5553 test_decoder_tracked_end_to_end_segdet,
5554 quantized,
5555 combined,
5556 masks
5557 );
5558 e2e_tracked_test!(
5559 test_decoder_tracked_end_to_end_segdet_float,
5560 float,
5561 combined,
5562 masks
5563 );
5564 e2e_tracked_test!(
5565 test_decoder_tracked_end_to_end_segdet_proto,
5566 quantized,
5567 combined,
5568 proto
5569 );
5570 e2e_tracked_test!(
5571 test_decoder_tracked_end_to_end_segdet_proto_float,
5572 float,
5573 combined,
5574 proto
5575 );
5576 e2e_tracked_test!(
5577 test_decoder_tracked_end_to_end_segdet_split,
5578 quantized,
5579 split,
5580 masks
5581 );
5582 e2e_tracked_test!(
5583 test_decoder_tracked_end_to_end_segdet_split_float,
5584 float,
5585 split,
5586 masks
5587 );
5588 e2e_tracked_test!(
5589 test_decoder_tracked_end_to_end_segdet_split_proto,
5590 quantized,
5591 split,
5592 proto
5593 );
5594 e2e_tracked_test!(
5595 test_decoder_tracked_end_to_end_segdet_split_proto_float,
5596 float,
5597 split,
5598 proto
5599 );
5600
5601 macro_rules! e2e_tracked_tensor_test {
5607 ($name:ident, quantized, $layout:ident, $output:ident) => {
5608 #[test]
5609 fn $name() {
5610 use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
5611
5612 let is_split = matches!(stringify!($layout), "split");
5613 let is_proto = matches!(stringify!($output), "proto");
5614
5615 let score_threshold = 0.45;
5616 let iou_threshold = 0.45;
5617
5618 let mut boxes = Array2::zeros((10, 4));
5619 let mut scores = Array2::zeros((10, 1));
5620 let mut classes = Array2::zeros((10, 1));
5621 let mask = Array2::zeros((10, 32));
5622 let protos_f64 = Array3::<f64>::zeros((160, 160, 32));
5623 let protos_f64 = protos_f64.insert_axis(Axis(0));
5624 let protos_quant = (1.0 / 255.0, 0.0);
5625 let protos_u8: Array4<u8> =
5626 quantize_ndarray(protos_f64.view(), protos_quant.into());
5627
5628 boxes
5629 .slice_mut(s![0, ..])
5630 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5631 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5632 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5633
5634 let detect_quant = (2.0 / 255.0, 0.0);
5635
5636 let decoder = if is_split {
5637 DecoderBuilder::default()
5638 .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5639 .with_score_threshold(score_threshold)
5640 .with_iou_threshold(iou_threshold)
5641 .build()
5642 .unwrap()
5643 } else {
5644 DecoderBuilder::default()
5645 .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5646 .with_score_threshold(score_threshold)
5647 .with_iou_threshold(iou_threshold)
5648 .build()
5649 .unwrap()
5650 };
5651
5652 let make_u8_tensor =
5654 |shape: &[usize], data: &[u8]| -> edgefirst_tensor::TensorDyn {
5655 let t = Tensor::<u8>::new(shape, None, None).unwrap();
5656 t.map().unwrap().as_mut_slice()[..data.len()].copy_from_slice(data);
5657 t.into()
5658 };
5659
5660 let expected = e2e_expected_boxes_quant();
5661 let mut tracker = ByteTrackBuilder::new()
5662 .track_update(0.1)
5663 .track_high_conf(0.7)
5664 .build();
5665 let mut output_boxes = Vec::with_capacity(50);
5666 let mut output_tracks = Vec::with_capacity(50);
5667
5668 let protos_td = make_u8_tensor(protos_u8.shape(), protos_u8.as_slice().unwrap());
5669
5670 if is_split {
5671 let boxes = boxes.insert_axis(Axis(0));
5672 let scores = scores.insert_axis(Axis(0));
5673 let classes = classes.insert_axis(Axis(0));
5674 let mask = mask.insert_axis(Axis(0));
5675
5676 let boxes_q: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
5677 let mut scores_q: Array3<u8> =
5678 quantize_ndarray(scores.view(), detect_quant.into());
5679 let classes_q: Array3<u8> =
5680 quantize_ndarray(classes.view(), detect_quant.into());
5681 let mask_q: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
5682
5683 let boxes_td = make_u8_tensor(boxes_q.shape(), boxes_q.as_slice().unwrap());
5684 let classes_td =
5685 make_u8_tensor(classes_q.shape(), classes_q.as_slice().unwrap());
5686 let mask_td = make_u8_tensor(mask_q.shape(), mask_q.as_slice().unwrap());
5687
5688 if is_proto {
5689 let scores_td =
5690 make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5691 decoder
5692 .decode_proto_tracked(
5693 &mut tracker,
5694 0,
5695 &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5696 &mut output_boxes,
5697 &mut output_tracks,
5698 )
5699 .unwrap();
5700
5701 assert_eq!(output_boxes.len(), 1);
5702 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5703
5704 for score in scores_q.slice_mut(s![.., .., ..]).iter_mut() {
5705 *score = u8::MIN;
5706 }
5707 let scores_td =
5708 make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5709 let proto_result = decoder
5710 .decode_proto_tracked(
5711 &mut tracker,
5712 100_000_000 / 3,
5713 &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5714 &mut output_boxes,
5715 &mut output_tracks,
5716 )
5717 .unwrap();
5718 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5719 assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5720 } else {
5721 let scores_td =
5722 make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5723 let mut output_masks = Vec::with_capacity(50);
5724 decoder
5725 .decode_tracked(
5726 &mut tracker,
5727 0,
5728 &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5729 &mut output_boxes,
5730 &mut output_masks,
5731 &mut output_tracks,
5732 )
5733 .unwrap();
5734
5735 assert_eq!(output_boxes.len(), 1);
5736 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5737
5738 for score in scores_q.slice_mut(s![.., .., ..]).iter_mut() {
5739 *score = u8::MIN;
5740 }
5741 let scores_td =
5742 make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5743 decoder
5744 .decode_tracked(
5745 &mut tracker,
5746 100_000_000 / 3,
5747 &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5748 &mut output_boxes,
5749 &mut output_masks,
5750 &mut output_tracks,
5751 )
5752 .unwrap();
5753 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5754 assert!(output_masks.is_empty());
5755 }
5756 } else {
5757 let detect = ndarray::concatenate![
5759 Axis(1),
5760 boxes.view(),
5761 scores.view(),
5762 classes.view(),
5763 mask.view()
5764 ];
5765 let detect = detect.insert_axis(Axis(0));
5766 assert_eq!(detect.shape(), &[1, 10, 38]);
5767 let detect =
5769 Array3::from_shape_vec(detect.raw_dim(), detect.iter().copied().collect())
5770 .unwrap();
5771 let mut detect_q: Array3<u8> =
5772 quantize_ndarray(detect.view(), detect_quant.into());
5773
5774 if is_proto {
5775 let detect_td =
5776 make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5777 decoder
5778 .decode_proto_tracked(
5779 &mut tracker,
5780 0,
5781 &[&detect_td, &protos_td],
5782 &mut output_boxes,
5783 &mut output_tracks,
5784 )
5785 .unwrap();
5786
5787 assert_eq!(output_boxes.len(), 1);
5788 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5789
5790 for score in detect_q.slice_mut(s![.., .., 4]).iter_mut() {
5791 *score = u8::MIN;
5792 }
5793 let detect_td =
5794 make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5795 let proto_result = decoder
5796 .decode_proto_tracked(
5797 &mut tracker,
5798 100_000_000 / 3,
5799 &[&detect_td, &protos_td],
5800 &mut output_boxes,
5801 &mut output_tracks,
5802 )
5803 .unwrap();
5804 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5805 assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5806 } else {
5807 let detect_td =
5808 make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5809 let mut output_masks = Vec::with_capacity(50);
5810 decoder
5811 .decode_tracked(
5812 &mut tracker,
5813 0,
5814 &[&detect_td, &protos_td],
5815 &mut output_boxes,
5816 &mut output_masks,
5817 &mut output_tracks,
5818 )
5819 .unwrap();
5820
5821 assert_eq!(output_boxes.len(), 1);
5822 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5823
5824 for score in detect_q.slice_mut(s![.., .., 4]).iter_mut() {
5825 *score = u8::MIN;
5826 }
5827 let detect_td =
5828 make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5829 decoder
5830 .decode_tracked(
5831 &mut tracker,
5832 100_000_000 / 3,
5833 &[&detect_td, &protos_td],
5834 &mut output_boxes,
5835 &mut output_masks,
5836 &mut output_tracks,
5837 )
5838 .unwrap();
5839 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5840 assert!(output_masks.is_empty());
5841 }
5842 }
5843 }
5844 };
5845 ($name:ident, float, $layout:ident, $output:ident) => {
5846 #[test]
5847 fn $name() {
5848 use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
5849
5850 let is_split = matches!(stringify!($layout), "split");
5851 let is_proto = matches!(stringify!($output), "proto");
5852
5853 let score_threshold = 0.45;
5854 let iou_threshold = 0.45;
5855
5856 let mut boxes = Array2::zeros((10, 4));
5857 let mut scores = Array2::zeros((10, 1));
5858 let mut classes = Array2::zeros((10, 1));
5859 let mask: Array2<f64> = Array2::zeros((10, 32));
5860 let protos = Array3::<f64>::zeros((160, 160, 32));
5861 let protos = protos.insert_axis(Axis(0));
5862
5863 boxes
5864 .slice_mut(s![0, ..])
5865 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5866 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5867 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5868
5869 let decoder = if is_split {
5870 DecoderBuilder::default()
5871 .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5872 .with_score_threshold(score_threshold)
5873 .with_iou_threshold(iou_threshold)
5874 .build()
5875 .unwrap()
5876 } else {
5877 DecoderBuilder::default()
5878 .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5879 .with_score_threshold(score_threshold)
5880 .with_iou_threshold(iou_threshold)
5881 .build()
5882 .unwrap()
5883 };
5884
5885 let make_f64_tensor =
5887 |shape: &[usize], data: &[f64]| -> edgefirst_tensor::TensorDyn {
5888 let t = Tensor::<f64>::new(shape, None, None).unwrap();
5889 t.map().unwrap().as_mut_slice()[..data.len()].copy_from_slice(data);
5890 t.into()
5891 };
5892
5893 let expected = e2e_expected_boxes_float();
5894 let mut tracker = ByteTrackBuilder::new()
5895 .track_update(0.1)
5896 .track_high_conf(0.7)
5897 .build();
5898 let mut output_boxes = Vec::with_capacity(50);
5899 let mut output_tracks = Vec::with_capacity(50);
5900
5901 let protos_td = make_f64_tensor(protos.shape(), protos.as_slice().unwrap());
5902
5903 if is_split {
5904 let boxes = boxes.insert_axis(Axis(0));
5905 let mut scores = scores.insert_axis(Axis(0));
5906 let classes = classes.insert_axis(Axis(0));
5907 let mask = mask.insert_axis(Axis(0));
5908
5909 let boxes_td = make_f64_tensor(boxes.shape(), boxes.as_slice().unwrap());
5910 let classes_td = make_f64_tensor(classes.shape(), classes.as_slice().unwrap());
5911 let mask_td = make_f64_tensor(mask.shape(), mask.as_slice().unwrap());
5912
5913 if is_proto {
5914 let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
5915 decoder
5916 .decode_proto_tracked(
5917 &mut tracker,
5918 0,
5919 &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5920 &mut output_boxes,
5921 &mut output_tracks,
5922 )
5923 .unwrap();
5924
5925 assert_eq!(output_boxes.len(), 1);
5926 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5927
5928 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5929 *score = 0.0;
5930 }
5931 let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
5932 let proto_result = decoder
5933 .decode_proto_tracked(
5934 &mut tracker,
5935 100_000_000 / 3,
5936 &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5937 &mut output_boxes,
5938 &mut output_tracks,
5939 )
5940 .unwrap();
5941 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5942 assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5943 } else {
5944 let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
5945 let mut output_masks = Vec::with_capacity(50);
5946 decoder
5947 .decode_tracked(
5948 &mut tracker,
5949 0,
5950 &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5951 &mut output_boxes,
5952 &mut output_masks,
5953 &mut output_tracks,
5954 )
5955 .unwrap();
5956
5957 assert_eq!(output_boxes.len(), 1);
5958 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5959
5960 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5961 *score = 0.0;
5962 }
5963 let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
5964 decoder
5965 .decode_tracked(
5966 &mut tracker,
5967 100_000_000 / 3,
5968 &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5969 &mut output_boxes,
5970 &mut output_masks,
5971 &mut output_tracks,
5972 )
5973 .unwrap();
5974 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5975 assert!(output_masks.is_empty());
5976 }
5977 } else {
5978 let detect = ndarray::concatenate![
5980 Axis(1),
5981 boxes.view(),
5982 scores.view(),
5983 classes.view(),
5984 mask.view()
5985 ];
5986 let detect = detect.insert_axis(Axis(0));
5987 assert_eq!(detect.shape(), &[1, 10, 38]);
5988 let mut detect =
5990 Array3::from_shape_vec(detect.raw_dim(), detect.iter().copied().collect())
5991 .unwrap();
5992
5993 if is_proto {
5994 let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
5995 decoder
5996 .decode_proto_tracked(
5997 &mut tracker,
5998 0,
5999 &[&detect_td, &protos_td],
6000 &mut output_boxes,
6001 &mut output_tracks,
6002 )
6003 .unwrap();
6004
6005 assert_eq!(output_boxes.len(), 1);
6006 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
6007
6008 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
6009 *score = 0.0;
6010 }
6011 let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
6012 let proto_result = decoder
6013 .decode_proto_tracked(
6014 &mut tracker,
6015 100_000_000 / 3,
6016 &[&detect_td, &protos_td],
6017 &mut output_boxes,
6018 &mut output_tracks,
6019 )
6020 .unwrap();
6021 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
6022 assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
6023 } else {
6024 let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
6025 let mut output_masks = Vec::with_capacity(50);
6026 decoder
6027 .decode_tracked(
6028 &mut tracker,
6029 0,
6030 &[&detect_td, &protos_td],
6031 &mut output_boxes,
6032 &mut output_masks,
6033 &mut output_tracks,
6034 )
6035 .unwrap();
6036
6037 assert_eq!(output_boxes.len(), 1);
6038 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
6039
6040 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
6041 *score = 0.0;
6042 }
6043 let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
6044 decoder
6045 .decode_tracked(
6046 &mut tracker,
6047 100_000_000 / 3,
6048 &[&detect_td, &protos_td],
6049 &mut output_boxes,
6050 &mut output_masks,
6051 &mut output_tracks,
6052 )
6053 .unwrap();
6054 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
6055 assert!(output_masks.is_empty());
6056 }
6057 }
6058 }
6059 };
6060 }
6061
6062 e2e_tracked_tensor_test!(
6063 test_decoder_tracked_tensor_end_to_end_segdet,
6064 quantized,
6065 combined,
6066 masks
6067 );
6068 e2e_tracked_tensor_test!(
6069 test_decoder_tracked_tensor_end_to_end_segdet_float,
6070 float,
6071 combined,
6072 masks
6073 );
6074 e2e_tracked_tensor_test!(
6075 test_decoder_tracked_tensor_end_to_end_segdet_proto,
6076 quantized,
6077 combined,
6078 proto
6079 );
6080 e2e_tracked_tensor_test!(
6081 test_decoder_tracked_tensor_end_to_end_segdet_proto_float,
6082 float,
6083 combined,
6084 proto
6085 );
6086 e2e_tracked_tensor_test!(
6087 test_decoder_tracked_tensor_end_to_end_segdet_split,
6088 quantized,
6089 split,
6090 masks
6091 );
6092 e2e_tracked_tensor_test!(
6093 test_decoder_tracked_tensor_end_to_end_segdet_split_float,
6094 float,
6095 split,
6096 masks
6097 );
6098 e2e_tracked_tensor_test!(
6099 test_decoder_tracked_tensor_end_to_end_segdet_split_proto,
6100 quantized,
6101 split,
6102 proto
6103 );
6104 e2e_tracked_tensor_test!(
6105 test_decoder_tracked_tensor_end_to_end_segdet_split_proto_float,
6106 float,
6107 split,
6108 proto
6109 );
6110
6111 #[test]
6112 fn test_decoder_tracked_linear_motion() {
6113 use crate::configs::{DecoderType, Nms};
6114 use crate::DecoderBuilder;
6115
6116 let score_threshold = 0.25;
6117 let iou_threshold = 0.1;
6118 let out = include_bytes!(concat!(
6119 env!("CARGO_MANIFEST_DIR"),
6120 "/../../testdata/yolov8s_80_classes.bin"
6121 ));
6122 let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
6123 let mut out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
6124 let quant = (0.0040811873, -123).into();
6125
6126 let decoder = DecoderBuilder::default()
6127 .with_config_yolo_det(
6128 crate::configs::Detection {
6129 decoder: DecoderType::Ultralytics,
6130 shape: vec![1, 84, 8400],
6131 anchors: None,
6132 quantization: Some(quant),
6133 dshape: vec![
6134 (crate::configs::DimName::Batch, 1),
6135 (crate::configs::DimName::NumFeatures, 84),
6136 (crate::configs::DimName::NumBoxes, 8400),
6137 ],
6138 normalized: Some(true),
6139 },
6140 None,
6141 )
6142 .with_score_threshold(score_threshold)
6143 .with_iou_threshold(iou_threshold)
6144 .with_nms(Some(Nms::ClassAgnostic))
6145 .build()
6146 .unwrap();
6147
6148 let mut expected_boxes = [
6149 DetectBox {
6150 bbox: BoundingBox {
6151 xmin: 0.5285137,
6152 ymin: 0.05305544,
6153 xmax: 0.87541467,
6154 ymax: 0.9998909,
6155 },
6156 score: 0.5591227,
6157 label: 0,
6158 },
6159 DetectBox {
6160 bbox: BoundingBox {
6161 xmin: 0.130598,
6162 ymin: 0.43260583,
6163 xmax: 0.35098213,
6164 ymax: 0.9958097,
6165 },
6166 score: 0.33057618,
6167 label: 75,
6168 },
6169 ];
6170
6171 let mut tracker = ByteTrackBuilder::new()
6172 .track_update(0.1)
6173 .track_high_conf(0.3)
6174 .build();
6175
6176 let mut output_boxes = Vec::with_capacity(50);
6177 let mut output_masks = Vec::with_capacity(50);
6178 let mut output_tracks = Vec::with_capacity(50);
6179
6180 decoder
6181 .decode_tracked_quantized(
6182 &mut tracker,
6183 0,
6184 &[out.view().into()],
6185 &mut output_boxes,
6186 &mut output_masks,
6187 &mut output_tracks,
6188 )
6189 .unwrap();
6190
6191 assert_eq!(output_boxes.len(), 2);
6192 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
6193 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
6194
6195 for i in 1..=100 {
6196 let mut out = out.clone();
6197 let mut x_values = out.slice_mut(s![0, 0, ..]);
6199 for x in x_values.iter_mut() {
6200 *x = x.saturating_add((i as f32 * 1e-3 / quant.0).round() as i8);
6201 }
6202
6203 decoder
6204 .decode_tracked_quantized(
6205 &mut tracker,
6206 100_000_000 * i / 3, &[out.view().into()],
6208 &mut output_boxes,
6209 &mut output_masks,
6210 &mut output_tracks,
6211 )
6212 .unwrap();
6213
6214 assert_eq!(output_boxes.len(), 2);
6215 }
6216 let tracks = tracker.get_active_tracks();
6217 let predicted_boxes: Vec<_> = tracks
6218 .iter()
6219 .map(|track| {
6220 let mut l = track.last_box;
6221 l.bbox = track.info.tracked_location.into();
6222 l
6223 })
6224 .collect();
6225 expected_boxes[0].bbox.xmin += 0.1; expected_boxes[0].bbox.xmax += 0.1;
6227 expected_boxes[1].bbox.xmin += 0.1;
6228 expected_boxes[1].bbox.xmax += 0.1;
6229
6230 assert!(predicted_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
6231 assert!(predicted_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
6232
6233 let mut scores_values = out.slice_mut(s![0, 4.., ..]);
6235 for score in scores_values.iter_mut() {
6236 *score = i8::MIN; }
6238 decoder
6239 .decode_tracked_quantized(
6240 &mut tracker,
6241 100_000_000 * 101 / 3,
6242 &[out.view().into()],
6243 &mut output_boxes,
6244 &mut output_masks,
6245 &mut output_tracks,
6246 )
6247 .unwrap();
6248 expected_boxes[0].bbox.xmin += 0.001; expected_boxes[0].bbox.xmax += 0.001;
6250 expected_boxes[1].bbox.xmin += 0.001;
6251 expected_boxes[1].bbox.xmax += 0.001;
6252
6253 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
6254 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
6255 }
6256
6257 #[test]
6258 fn test_decoder_tracked_end_to_end_float() {
6259 let score_threshold = 0.45;
6260 let iou_threshold = 0.45;
6261
6262 let mut boxes = Array2::zeros((10, 4));
6263 let mut scores = Array2::zeros((10, 1));
6264 let mut classes = Array2::zeros((10, 1));
6265
6266 boxes
6267 .slice_mut(s![0, ..,])
6268 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
6269 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
6270 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
6271
6272 let detect = ndarray::concatenate![Axis(1), boxes.view(), scores.view(), classes.view(),];
6273 let mut detect = detect.insert_axis(Axis(0));
6274 assert_eq!(detect.shape(), &[1, 10, 6]);
6275 let config = "
6276decoder_version: yolo26
6277outputs:
6278 - type: detection
6279 decoder: ultralytics
6280 quantization: [0.00784313725490196, 0]
6281 shape: [1, 10, 6]
6282 dshape:
6283 - [batch, 1]
6284 - [num_boxes, 10]
6285 - [num_features, 6]
6286 normalized: true
6287";
6288
6289 let decoder = DecoderBuilder::default()
6290 .with_config_yaml_str(config.to_string())
6291 .with_score_threshold(score_threshold)
6292 .with_iou_threshold(iou_threshold)
6293 .build()
6294 .unwrap();
6295
6296 let expected_boxes = [DetectBox {
6297 bbox: BoundingBox {
6298 xmin: 0.1234,
6299 ymin: 0.1234,
6300 xmax: 0.2345,
6301 ymax: 0.2345,
6302 },
6303 score: 0.9876,
6304 label: 2,
6305 }];
6306
6307 let mut tracker = ByteTrackBuilder::new()
6308 .track_update(0.1)
6309 .track_high_conf(0.7)
6310 .build();
6311
6312 let mut output_boxes = Vec::with_capacity(50);
6313 let mut output_masks = Vec::with_capacity(50);
6314 let mut output_tracks = Vec::with_capacity(50);
6315
6316 decoder
6317 .decode_tracked_float(
6318 &mut tracker,
6319 0,
6320 &[detect.view().into_dyn()],
6321 &mut output_boxes,
6322 &mut output_masks,
6323 &mut output_tracks,
6324 )
6325 .unwrap();
6326
6327 assert_eq!(output_boxes.len(), 1);
6328 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
6329
6330 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
6333 *score = 0.0; }
6335
6336 decoder
6337 .decode_tracked_float(
6338 &mut tracker,
6339 100_000_000 / 3,
6340 &[detect.view().into_dyn()],
6341 &mut output_boxes,
6342 &mut output_masks,
6343 &mut output_tracks,
6344 )
6345 .unwrap();
6346 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
6347 }
6348}