1#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
68
69use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
70use num_traits::{AsPrimitive, Float, PrimInt};
71
72pub mod byte;
73pub mod error;
74pub mod float;
75pub mod modelpack;
76pub mod yolo;
77
78mod decoder;
79pub use decoder::*;
80
81pub use configs::{DecoderVersion, Nms};
82pub use error::{DecoderError, DecoderResult};
83
84use crate::{
85 decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
86 yolo::yolo_segmentation_to_mask,
87};
88
89pub trait BBoxTypeTrait {
91 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
93
94 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
96 input: &[B; 4],
97 quant: Quantization,
98 ) -> [A; 4]
99 where
100 f32: AsPrimitive<A>,
101 i32: AsPrimitive<A>;
102
103 #[inline(always)]
114 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
115 input: ArrayView1<B>,
116 ) -> [A; 4] {
117 Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
118 }
119
120 #[inline(always)]
121 fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
123 input: ArrayView1<B>,
124 quant: Quantization,
125 ) -> [A; 4]
126 where
127 f32: AsPrimitive<A>,
128 i32: AsPrimitive<A>,
129 {
130 Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
131 }
132}
133
134#[derive(Debug, Clone, Copy, PartialEq, Eq)]
136pub struct XYXY {}
137
138impl BBoxTypeTrait for XYXY {
139 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
140 input.map(|b| b.as_())
141 }
142
143 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
144 input: &[B; 4],
145 quant: Quantization,
146 ) -> [A; 4]
147 where
148 f32: AsPrimitive<A>,
149 i32: AsPrimitive<A>,
150 {
151 let scale = quant.scale.as_();
152 let zp = quant.zero_point.as_();
153 input.map(|b| (b.as_() - zp) * scale)
154 }
155
156 #[inline(always)]
157 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
158 input: ArrayView1<B>,
159 ) -> [A; 4] {
160 [
161 input[0].as_(),
162 input[1].as_(),
163 input[2].as_(),
164 input[3].as_(),
165 ]
166 }
167}
168
169#[derive(Debug, Clone, Copy, PartialEq, Eq)]
172pub struct XYWH {}
173
174impl BBoxTypeTrait for XYWH {
175 #[inline(always)]
176 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
177 let half = A::one() / (A::one() + A::one());
178 [
179 (input[0].as_()) - (input[2].as_() * half),
180 (input[1].as_()) - (input[3].as_() * half),
181 (input[0].as_()) + (input[2].as_() * half),
182 (input[1].as_()) + (input[3].as_() * half),
183 ]
184 }
185
186 #[inline(always)]
187 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
188 input: &[B; 4],
189 quant: Quantization,
190 ) -> [A; 4]
191 where
192 f32: AsPrimitive<A>,
193 i32: AsPrimitive<A>,
194 {
195 let scale = quant.scale.as_();
196 let half_scale = (quant.scale * 0.5).as_();
197 let zp = quant.zero_point.as_();
198 let [x, y, w, h] = [
199 (input[0].as_() - zp) * scale,
200 (input[1].as_() - zp) * scale,
201 (input[2].as_() - zp) * half_scale,
202 (input[3].as_() - zp) * half_scale,
203 ];
204
205 [x - w, y - h, x + w, y + h]
206 }
207
208 #[inline(always)]
209 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
210 input: ArrayView1<B>,
211 ) -> [A; 4] {
212 let half = A::one() / (A::one() + A::one());
213 [
214 (input[0].as_()) - (input[2].as_() * half),
215 (input[1].as_()) - (input[3].as_() * half),
216 (input[0].as_()) + (input[2].as_() * half),
217 (input[1].as_()) + (input[3].as_() * half),
218 ]
219 }
220}
221
/// Affine quantization parameters.
///
/// Throughout this crate a quantized value `q` maps to the real value
/// `(q - zero_point) * scale`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quantization {
    /// Multiplier applied after the zero point is subtracted.
    pub scale: f32,
    /// Quantized value that represents real zero.
    pub zero_point: i32,
}

impl Quantization {
    /// Builds quantization parameters from an explicit scale and zero point.
    pub fn new(scale: f32, zero_point: i32) -> Self {
        Quantization { zero_point, scale }
    }
}
242
243impl From<QuantTuple> for Quantization {
244 fn from(quant_tuple: QuantTuple) -> Quantization {
255 Quantization {
256 scale: quant_tuple.0,
257 zero_point: quant_tuple.1,
258 }
259 }
260}
261
262impl<S, Z> From<(S, Z)> for Quantization
263where
264 S: AsPrimitive<f32>,
265 Z: AsPrimitive<i32>,
266{
267 fn from((scale, zp): (S, Z)) -> Quantization {
276 Self {
277 scale: scale.as_(),
278 zero_point: zp.as_(),
279 }
280 }
281}
282
283impl Default for Quantization {
284 fn default() -> Self {
293 Self {
294 scale: 1.0,
295 zero_point: 0,
296 }
297 }
298}
299
/// A single detection: its box, its score and its label.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct DetectBox {
    /// Box corner coordinates.
    pub bbox: BoundingBox,
    /// Detection score.
    pub score: f32,
    /// Label identifier (presumably the class index).
    pub label: usize,
}

/// Axis-aligned box stored as its minimum and maximum corners.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct BoundingBox {
    /// Minimum x coordinate.
    pub xmin: f32,
    /// Minimum y coordinate.
    pub ymin: f32,
    /// Maximum x coordinate.
    pub xmax: f32,
    /// Maximum y coordinate.
    pub ymax: f32,
}

impl BoundingBox {
    /// Creates a box from the given corner values exactly as passed; no
    /// reordering is done (see [`BoundingBox::to_canonical`] for that).
    pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }

    /// Returns a copy where each `*min` field holds the smaller and each
    /// `*max` field the larger of its coordinate pair, repairing boxes whose
    /// corners are swapped.
    pub fn to_canonical(&self) -> Self {
        let (xmin, xmax) = (self.xmin.min(self.xmax), self.xmin.max(self.xmax));
        let (ymin, ymax) = (self.ymin.min(self.ymax), self.ymin.max(self.ymax));
        Self::new(xmin, ymin, xmax, ymax)
    }
}

impl From<BoundingBox> for [f32; 4] {
    /// Flattens the box into `[xmin, ymin, xmax, ymax]`.
    fn from(b: BoundingBox) -> Self {
        let BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        } = b;
        [xmin, ymin, xmax, ymax]
    }
}

impl From<[f32; 4]> for BoundingBox {
    /// Reads `[xmin, ymin, xmax, ymax]` into a box without reordering.
    fn from(arr: [f32; 4]) -> Self {
        let [xmin, ymin, xmax, ymax] = arr;
        BoundingBox::new(xmin, ymin, xmax, ymax)
    }
}

impl DetectBox {
    /// Approximate equality.
    ///
    /// Returns `true` when both detections share the same `label` and the
    /// score plus each of the four box coordinates differ by at most `eps`
    /// (absolute difference). A NaN in any compared value, or in `eps`,
    /// makes the result `false`.
    pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
        if self.label != rhs.label {
            return false;
        }
        let flatten = |d: &DetectBox| {
            [d.score, d.bbox.xmin, d.bbox.ymin, d.bbox.xmax, d.bbox.ymax]
        };
        flatten(self)
            .iter()
            .zip(flatten(rhs).iter())
            .all(|(a, b)| (a - b).abs() <= eps)
    }
}
428
429#[derive(Debug, Clone, PartialEq, Default)]
432pub struct Segmentation {
433 pub xmin: f32,
435 pub ymin: f32,
437 pub xmax: f32,
439 pub ymax: f32,
441 pub segmentation: Array3<u8>,
451}
452
453#[derive(Debug, Clone)]
459pub enum ProtoTensor {
460 Quantized {
464 protos: Array3<i8>,
465 quantization: Quantization,
466 },
467 Float(Array3<f32>),
469}
470
471impl ProtoTensor {
472 pub fn is_quantized(&self) -> bool {
474 matches!(self, ProtoTensor::Quantized { .. })
475 }
476
477 pub fn dim(&self) -> (usize, usize, usize) {
479 match self {
480 ProtoTensor::Quantized { protos, .. } => protos.dim(),
481 ProtoTensor::Float(arr) => arr.dim(),
482 }
483 }
484
485 pub fn as_f32(&self) -> std::borrow::Cow<'_, Array3<f32>> {
488 match self {
489 ProtoTensor::Float(arr) => std::borrow::Cow::Borrowed(arr),
490 ProtoTensor::Quantized {
491 protos,
492 quantization,
493 } => {
494 let scale = quantization.scale;
495 let zp = quantization.zero_point as f32;
496 std::borrow::Cow::Owned(protos.map(|&v| (v as f32 - zp) * scale))
497 }
498 }
499 }
500}
501
502#[derive(Debug, Clone)]
508pub struct ProtoData {
509 pub mask_coefficients: Vec<Vec<f32>>,
511 pub protos: ProtoTensor,
513}
514
515pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
533 detect: &DetectBoxQuantized<SCORE>,
534 quant_scores: Quantization,
535) -> DetectBox {
536 let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
537 DetectBox {
538 bbox: detect.bbox,
539 score: quant_scores.scale * detect.score.as_() + scaled_zp,
540 label: detect.label,
541 }
542}
543#[derive(Debug, Clone, Copy, PartialEq)]
545pub struct DetectBoxQuantized<
546 SCORE: PrimInt + AsPrimitive<f32>,
548> {
549 pub bbox: BoundingBox,
551 pub score: SCORE,
554 pub label: usize,
556}
557
558pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
571 input: ArrayView<T, D>,
572 quant: Quantization,
573) -> Array<F, D>
574where
575 i32: num_traits::AsPrimitive<F>,
576 f32: num_traits::AsPrimitive<F>,
577{
578 let zero_point = quant.zero_point.as_();
579 let scale = quant.scale.as_();
580 if zero_point != F::zero() {
581 let scaled_zero = -zero_point * scale;
582 input.mapv(|d| d.as_() * scale + scaled_zero)
583 } else {
584 input.mapv(|d| d.as_() * scale)
585 }
586}
587
588pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
601 input: &[T],
602 quant: Quantization,
603 output: &mut [F],
604) where
605 f32: num_traits::AsPrimitive<F>,
606 i32: num_traits::AsPrimitive<F>,
607{
608 assert!(input.len() == output.len());
609 let zero_point = quant.zero_point.as_();
610 let scale = quant.scale.as_();
611 if zero_point != F::zero() {
612 let scaled_zero = -zero_point * scale; input
614 .iter()
615 .zip(output)
616 .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
617 } else {
618 input
619 .iter()
620 .zip(output)
621 .for_each(|(d, deq)| *deq = d.as_() * scale);
622 }
623}
624
625pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
639 input: &[T],
640 quant: Quantization,
641 output: &mut [F],
642) where
643 f32: num_traits::AsPrimitive<F>,
644 i32: num_traits::AsPrimitive<F>,
645{
646 assert!(input.len() == output.len());
647 let zero_point = quant.zero_point.as_();
648 let scale = quant.scale.as_();
649
650 let input = input.as_chunks::<4>();
651 let output = output.as_chunks_mut::<4>();
652
653 if zero_point != F::zero() {
654 let scaled_zero = -zero_point * scale; input
657 .0
658 .iter()
659 .zip(output.0)
660 .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
661 input
662 .1
663 .iter()
664 .zip(output.1)
665 .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
666 } else {
667 input
668 .0
669 .iter()
670 .zip(output.0)
671 .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
672 input
673 .1
674 .iter()
675 .zip(output.1)
676 .for_each(|(d, deq)| *deq = d.as_() * scale);
677 }
678}
679
680pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
698 if segmentation.shape()[2] == 0 {
699 return Err(DecoderError::InvalidShape(
700 "Segmentation tensor must have non-zero depth".to_string(),
701 ));
702 }
703 if segmentation.shape()[2] == 1 {
704 yolo_segmentation_to_mask(segmentation, 128)
705 } else {
706 Ok(modelpack_segmentation_to_mask(segmentation))
707 }
708}
709
710fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
712 score
713 .iter()
714 .enumerate()
715 .fold((score[0], 0), |(max, arg_max), (ind, s)| {
716 if max > *s {
717 (max, arg_max)
718 } else {
719 (*s, ind)
720 }
721 })
722}
723#[cfg(test)]
724#[cfg_attr(coverage_nightly, coverage(off))]
725mod decoder_tests {
726 #![allow(clippy::excessive_precision)]
727 use crate::{
728 configs::{DecoderType, DimName, Protos},
729 modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
730 yolo::{
731 decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
732 decode_yolo_segdet_quant,
733 },
734 *,
735 };
736 use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
737 use ndarray::Dimension;
738 use ndarray::{array, s, Array2, Array3, Array4, Axis};
739 use ndarray_stats::DeviationExt;
740 use num_traits::{AsPrimitive, PrimInt};
741
742 fn compare_outputs(
743 boxes: (&[DetectBox], &[DetectBox]),
744 masks: (&[Segmentation], &[Segmentation]),
745 ) {
746 let (boxes0, boxes1) = boxes;
747 let (masks0, masks1) = masks;
748
749 assert_eq!(boxes0.len(), boxes1.len());
750 assert_eq!(masks0.len(), masks1.len());
751
752 for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
753 assert!(
754 b_i8.equal_within_delta(b_f32, 1e-6),
755 "{b_i8:?} is not equal to {b_f32:?}"
756 );
757 }
758
759 for (m_i8, m_f32) in masks0.iter().zip(masks1) {
760 assert_eq!(
761 [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
762 [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
763 );
764 assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
765 let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
766 let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
767 let diff = &mask_i8 - &mask_f32;
768 for x in 0..diff.shape()[0] {
769 for y in 0..diff.shape()[1] {
770 for z in 0..diff.shape()[2] {
771 let val = diff[[x, y, z]];
772 assert!(
773 val.abs() <= 1,
774 "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
775 x,
776 y,
777 z,
778 val
779 );
780 }
781 }
782 }
783 let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
784 assert!(
785 mean_sq_err < 1e-2,
786 "Mean Square Error between masks was greater than 1%: {:.2}%",
787 mean_sq_err * 100.0
788 );
789 }
790 }
791
792 fn load_yolov8_boxes() -> Array3<i8> {
795 let raw = include_bytes!(concat!(
796 env!("CARGO_MANIFEST_DIR"),
797 "/../../testdata/yolov8_boxes_116x8400.bin"
798 ));
799 let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
800 Array3::from_shape_vec((1, 116, 8400), raw.to_vec()).unwrap()
801 }
802
803 fn load_yolov8_protos() -> Array4<i8> {
804 let raw = include_bytes!(concat!(
805 env!("CARGO_MANIFEST_DIR"),
806 "/../../testdata/yolov8_protos_160x160x32.bin"
807 ));
808 let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
809 Array4::from_shape_vec((1, 160, 160, 32), raw.to_vec()).unwrap()
810 }
811
812 fn load_yolov8s_det() -> Array3<i8> {
813 let raw = include_bytes!(concat!(
814 env!("CARGO_MANIFEST_DIR"),
815 "/../../testdata/yolov8s_80_classes.bin"
816 ));
817 let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
818 Array3::from_shape_vec((1, 84, 8400), raw.to_vec()).unwrap()
819 }
820
821 #[test]
822 fn test_decoder_modelpack() {
823 let score_threshold = 0.45;
824 let iou_threshold = 0.45;
825 let boxes = include_bytes!(concat!(
826 env!("CARGO_MANIFEST_DIR"),
827 "/../../testdata/modelpack_boxes_1935x1x4.bin"
828 ));
829 let boxes = ndarray::Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
830
831 let scores = include_bytes!(concat!(
832 env!("CARGO_MANIFEST_DIR"),
833 "/../../testdata/modelpack_scores_1935x1.bin"
834 ));
835 let scores = ndarray::Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
836
837 let quant_boxes = (0.004656755365431309, 21).into();
838 let quant_scores = (0.0019603664986789227, 0).into();
839
840 let decoder = DecoderBuilder::default()
841 .with_config_modelpack_det(
842 configs::Boxes {
843 decoder: DecoderType::ModelPack,
844 quantization: Some(quant_boxes),
845 shape: vec![1, 1935, 1, 4],
846 dshape: vec![
847 (DimName::Batch, 1),
848 (DimName::NumBoxes, 1935),
849 (DimName::Padding, 1),
850 (DimName::BoxCoords, 4),
851 ],
852 normalized: Some(true),
853 },
854 configs::Scores {
855 decoder: DecoderType::ModelPack,
856 quantization: Some(quant_scores),
857 shape: vec![1, 1935, 1],
858 dshape: vec![
859 (DimName::Batch, 1),
860 (DimName::NumBoxes, 1935),
861 (DimName::NumClasses, 1),
862 ],
863 },
864 )
865 .with_score_threshold(score_threshold)
866 .with_iou_threshold(iou_threshold)
867 .build()
868 .unwrap();
869
870 let quant_boxes = quant_boxes.into();
871 let quant_scores = quant_scores.into();
872
873 let mut output_boxes: Vec<_> = Vec::with_capacity(50);
874 decode_modelpack_det(
875 (boxes.slice(s![0, .., 0, ..]), quant_boxes),
876 (scores.slice(s![0, .., ..]), quant_scores),
877 score_threshold,
878 iou_threshold,
879 &mut output_boxes,
880 );
881 assert!(output_boxes[0].equal_within_delta(
882 &DetectBox {
883 bbox: BoundingBox {
884 xmin: 0.40513772,
885 ymin: 0.6379755,
886 xmax: 0.5122431,
887 ymax: 0.7730214,
888 },
889 score: 0.4861709,
890 label: 0
891 },
892 1e-6
893 ));
894
895 let mut output_boxes1 = Vec::with_capacity(50);
896 let mut output_masks1 = Vec::with_capacity(50);
897
898 decoder
899 .decode_quantized(
900 &[boxes.view().into(), scores.view().into()],
901 &mut output_boxes1,
902 &mut output_masks1,
903 )
904 .unwrap();
905
906 let mut output_boxes_float = Vec::with_capacity(50);
907 let mut output_masks_float = Vec::with_capacity(50);
908
909 let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
910 let scores = dequantize_ndarray(scores.view(), quant_scores);
911
912 decoder
913 .decode_float::<f32>(
914 &[boxes.view().into_dyn(), scores.view().into_dyn()],
915 &mut output_boxes_float,
916 &mut output_masks_float,
917 )
918 .unwrap();
919
920 compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
921 compare_outputs(
922 (&output_boxes, &output_boxes_float),
923 (&[], &output_masks_float),
924 );
925 }
926
927 #[test]
928 fn test_decoder_modelpack_split_u8() {
929 let score_threshold = 0.45;
930 let iou_threshold = 0.45;
931 let detect0 = include_bytes!(concat!(
932 env!("CARGO_MANIFEST_DIR"),
933 "/../../testdata/modelpack_split_9x15x18.bin"
934 ));
935 let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
936
937 let detect1 = include_bytes!(concat!(
938 env!("CARGO_MANIFEST_DIR"),
939 "/../../testdata/modelpack_split_17x30x18.bin"
940 ));
941 let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
942
943 let quant0 = (0.08547406643629074, 174).into();
944 let quant1 = (0.09929127991199493, 183).into();
945 let anchors0 = vec![
946 [0.36666667461395264, 0.31481480598449707],
947 [0.38749998807907104, 0.4740740656852722],
948 [0.5333333611488342, 0.644444465637207],
949 ];
950 let anchors1 = vec![
951 [0.13750000298023224, 0.2074074000120163],
952 [0.2541666626930237, 0.21481481194496155],
953 [0.23125000298023224, 0.35185185074806213],
954 ];
955
956 let detect_config0 = configs::Detection {
957 decoder: DecoderType::ModelPack,
958 shape: vec![1, 9, 15, 18],
959 anchors: Some(anchors0.clone()),
960 quantization: Some(quant0),
961 dshape: vec![
962 (DimName::Batch, 1),
963 (DimName::Height, 9),
964 (DimName::Width, 15),
965 (DimName::NumAnchorsXFeatures, 18),
966 ],
967 normalized: Some(true),
968 };
969
970 let detect_config1 = configs::Detection {
971 decoder: DecoderType::ModelPack,
972 shape: vec![1, 17, 30, 18],
973 anchors: Some(anchors1.clone()),
974 quantization: Some(quant1),
975 dshape: vec![
976 (DimName::Batch, 1),
977 (DimName::Height, 17),
978 (DimName::Width, 30),
979 (DimName::NumAnchorsXFeatures, 18),
980 ],
981 normalized: Some(true),
982 };
983
984 let config0 = (&detect_config0).try_into().unwrap();
985 let config1 = (&detect_config1).try_into().unwrap();
986
987 let decoder = DecoderBuilder::default()
988 .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
989 .with_score_threshold(score_threshold)
990 .with_iou_threshold(iou_threshold)
991 .build()
992 .unwrap();
993
994 let quant0 = quant0.into();
995 let quant1 = quant1.into();
996
997 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
998 decode_modelpack_split_quant(
999 &[
1000 detect0.slice(s![0, .., .., ..]),
1001 detect1.slice(s![0, .., .., ..]),
1002 ],
1003 &[config0, config1],
1004 score_threshold,
1005 iou_threshold,
1006 &mut output_boxes,
1007 );
1008 assert!(output_boxes[0].equal_within_delta(
1009 &DetectBox {
1010 bbox: BoundingBox {
1011 xmin: 0.43171933,
1012 ymin: 0.68243736,
1013 xmax: 0.5626645,
1014 ymax: 0.808863,
1015 },
1016 score: 0.99240804,
1017 label: 0
1018 },
1019 1e-6
1020 ));
1021
1022 let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
1023 let mut output_masks1: Vec<_> = Vec::with_capacity(10);
1024 decoder
1025 .decode_quantized(
1026 &[detect0.view().into(), detect1.view().into()],
1027 &mut output_boxes1,
1028 &mut output_masks1,
1029 )
1030 .unwrap();
1031
1032 let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
1033 let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);
1034
1035 let detect0 = dequantize_ndarray(detect0.view(), quant0);
1036 let detect1 = dequantize_ndarray(detect1.view(), quant1);
1037 decoder
1038 .decode_float::<f32>(
1039 &[detect0.view().into_dyn(), detect1.view().into_dyn()],
1040 &mut output_boxes1_f32,
1041 &mut output_masks1_f32,
1042 )
1043 .unwrap();
1044
1045 compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1046 compare_outputs(
1047 (&output_boxes, &output_boxes1_f32),
1048 (&[], &output_masks1_f32),
1049 );
1050 }
1051
1052 #[test]
1053 fn test_decoder_parse_config_modelpack_split_u8() {
1054 let score_threshold = 0.45;
1055 let iou_threshold = 0.45;
1056 let detect0 = include_bytes!(concat!(
1057 env!("CARGO_MANIFEST_DIR"),
1058 "/../../testdata/modelpack_split_9x15x18.bin"
1059 ));
1060 let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1061
1062 let detect1 = include_bytes!(concat!(
1063 env!("CARGO_MANIFEST_DIR"),
1064 "/../../testdata/modelpack_split_17x30x18.bin"
1065 ));
1066 let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1067
1068 let decoder = DecoderBuilder::default()
1069 .with_config_yaml_str(
1070 include_str!(concat!(
1071 env!("CARGO_MANIFEST_DIR"),
1072 "/../../testdata/modelpack_split.yaml"
1073 ))
1074 .to_string(),
1075 )
1076 .with_score_threshold(score_threshold)
1077 .with_iou_threshold(iou_threshold)
1078 .build()
1079 .unwrap();
1080
1081 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1082 let mut output_masks: Vec<_> = Vec::with_capacity(10);
1083 decoder
1084 .decode_quantized(
1085 &[
1086 ArrayViewDQuantized::from(detect1.view()),
1087 ArrayViewDQuantized::from(detect0.view()),
1088 ],
1089 &mut output_boxes,
1090 &mut output_masks,
1091 )
1092 .unwrap();
1093 assert!(output_boxes[0].equal_within_delta(
1094 &DetectBox {
1095 bbox: BoundingBox {
1096 xmin: 0.43171933,
1097 ymin: 0.68243736,
1098 xmax: 0.5626645,
1099 ymax: 0.808863,
1100 },
1101 score: 0.99240804,
1102 label: 0
1103 },
1104 1e-6
1105 ));
1106 }
1107
1108 #[test]
1109 fn test_modelpack_seg() {
1110 let out = include_bytes!(concat!(
1111 env!("CARGO_MANIFEST_DIR"),
1112 "/../../testdata/modelpack_seg_2x160x160.bin"
1113 ));
1114 let out = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1115 let quant = (1.0 / 255.0, 0).into();
1116
1117 let decoder = DecoderBuilder::default()
1118 .with_config_modelpack_seg(configs::Segmentation {
1119 decoder: DecoderType::ModelPack,
1120 quantization: Some(quant),
1121 shape: vec![1, 2, 160, 160],
1122 dshape: vec![
1123 (DimName::Batch, 1),
1124 (DimName::NumClasses, 2),
1125 (DimName::Height, 160),
1126 (DimName::Width, 160),
1127 ],
1128 })
1129 .build()
1130 .unwrap();
1131 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1132 let mut output_masks: Vec<_> = Vec::with_capacity(10);
1133 decoder
1134 .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
1135 .unwrap();
1136
1137 let mut mask = out.slice(s![0, .., .., ..]);
1138 mask.swap_axes(0, 1);
1139 mask.swap_axes(1, 2);
1140 let mask = [Segmentation {
1141 xmin: 0.0,
1142 ymin: 0.0,
1143 xmax: 1.0,
1144 ymax: 1.0,
1145 segmentation: mask.into_owned(),
1146 }];
1147 compare_outputs((&[], &output_boxes), (&mask, &output_masks));
1148
1149 decoder
1150 .decode_float::<f32>(
1151 &[dequantize_ndarray(out.view(), quant.into())
1152 .view()
1153 .into_dyn()],
1154 &mut output_boxes,
1155 &mut output_masks,
1156 )
1157 .unwrap();
1158
1159 compare_outputs((&[], &output_boxes), (&[], &[]));
1165 let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1166 let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1167
1168 assert_eq!(mask0, mask1);
1169 }
1170 #[test]
1171 fn test_modelpack_seg_quant() {
1172 let out = include_bytes!(concat!(
1173 env!("CARGO_MANIFEST_DIR"),
1174 "/../../testdata/modelpack_seg_2x160x160.bin"
1175 ));
1176 let out_u8 = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1177 let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
1178 let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
1179 let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
1180 let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
1181 let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);
1182
1183 let quant = (1.0 / 255.0, 0).into();
1184
1185 let decoder = DecoderBuilder::default()
1186 .with_config_modelpack_seg(configs::Segmentation {
1187 decoder: DecoderType::ModelPack,
1188 quantization: Some(quant),
1189 shape: vec![1, 2, 160, 160],
1190 dshape: vec![
1191 (DimName::Batch, 1),
1192 (DimName::NumClasses, 2),
1193 (DimName::Height, 160),
1194 (DimName::Width, 160),
1195 ],
1196 })
1197 .build()
1198 .unwrap();
1199 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1200 let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
1201 decoder
1202 .decode_quantized(
1203 &[out_u8.view().into()],
1204 &mut output_boxes,
1205 &mut output_masks_u8,
1206 )
1207 .unwrap();
1208
1209 let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
1210 decoder
1211 .decode_quantized(
1212 &[out_i8.view().into()],
1213 &mut output_boxes,
1214 &mut output_masks_i8,
1215 )
1216 .unwrap();
1217
1218 let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
1219 decoder
1220 .decode_quantized(
1221 &[out_u16.view().into()],
1222 &mut output_boxes,
1223 &mut output_masks_u16,
1224 )
1225 .unwrap();
1226
1227 let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
1228 decoder
1229 .decode_quantized(
1230 &[out_i16.view().into()],
1231 &mut output_boxes,
1232 &mut output_masks_i16,
1233 )
1234 .unwrap();
1235
1236 let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
1237 decoder
1238 .decode_quantized(
1239 &[out_u32.view().into()],
1240 &mut output_boxes,
1241 &mut output_masks_u32,
1242 )
1243 .unwrap();
1244
1245 let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
1246 decoder
1247 .decode_quantized(
1248 &[out_i32.view().into()],
1249 &mut output_boxes,
1250 &mut output_masks_i32,
1251 )
1252 .unwrap();
1253
1254 compare_outputs((&[], &output_boxes), (&[], &[]));
1255 let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
1256 let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
1257 let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
1258 let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
1259 let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
1260 let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();
1261 assert_eq!(mask_u8, mask_i8);
1262 assert_eq!(mask_u8, mask_u16);
1263 assert_eq!(mask_u8, mask_i16);
1264 assert_eq!(mask_u8, mask_u32);
1265 assert_eq!(mask_u8, mask_i32);
1266 }
1267
1268 #[test]
1269 fn test_modelpack_segdet() {
1270 let score_threshold = 0.45;
1271 let iou_threshold = 0.45;
1272
1273 let boxes = include_bytes!(concat!(
1274 env!("CARGO_MANIFEST_DIR"),
1275 "/../../testdata/modelpack_boxes_1935x1x4.bin"
1276 ));
1277 let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
1278
1279 let scores = include_bytes!(concat!(
1280 env!("CARGO_MANIFEST_DIR"),
1281 "/../../testdata/modelpack_scores_1935x1.bin"
1282 ));
1283 let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
1284
1285 let seg = include_bytes!(concat!(
1286 env!("CARGO_MANIFEST_DIR"),
1287 "/../../testdata/modelpack_seg_2x160x160.bin"
1288 ));
1289 let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1290
1291 let quant_boxes = (0.004656755365431309, 21).into();
1292 let quant_scores = (0.0019603664986789227, 0).into();
1293 let quant_seg = (1.0 / 255.0, 0).into();
1294
1295 let decoder = DecoderBuilder::default()
1296 .with_config_modelpack_segdet(
1297 configs::Boxes {
1298 decoder: DecoderType::ModelPack,
1299 quantization: Some(quant_boxes),
1300 shape: vec![1, 1935, 1, 4],
1301 dshape: vec![
1302 (DimName::Batch, 1),
1303 (DimName::NumBoxes, 1935),
1304 (DimName::Padding, 1),
1305 (DimName::BoxCoords, 4),
1306 ],
1307 normalized: Some(true),
1308 },
1309 configs::Scores {
1310 decoder: DecoderType::ModelPack,
1311 quantization: Some(quant_scores),
1312 shape: vec![1, 1935, 1],
1313 dshape: vec![
1314 (DimName::Batch, 1),
1315 (DimName::NumBoxes, 1935),
1316 (DimName::NumClasses, 1),
1317 ],
1318 },
1319 configs::Segmentation {
1320 decoder: DecoderType::ModelPack,
1321 quantization: Some(quant_seg),
1322 shape: vec![1, 2, 160, 160],
1323 dshape: vec![
1324 (DimName::Batch, 1),
1325 (DimName::NumClasses, 2),
1326 (DimName::Height, 160),
1327 (DimName::Width, 160),
1328 ],
1329 },
1330 )
1331 .with_iou_threshold(iou_threshold)
1332 .with_score_threshold(score_threshold)
1333 .build()
1334 .unwrap();
1335 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1336 let mut output_masks: Vec<_> = Vec::with_capacity(10);
1337 decoder
1338 .decode_quantized(
1339 &[scores.view().into(), boxes.view().into(), seg.view().into()],
1340 &mut output_boxes,
1341 &mut output_masks,
1342 )
1343 .unwrap();
1344
1345 let mut mask = seg.slice(s![0, .., .., ..]);
1346 mask.swap_axes(0, 1);
1347 mask.swap_axes(1, 2);
1348 let mask = [Segmentation {
1349 xmin: 0.0,
1350 ymin: 0.0,
1351 xmax: 1.0,
1352 ymax: 1.0,
1353 segmentation: mask.into_owned(),
1354 }];
1355 let correct_boxes = [DetectBox {
1356 bbox: BoundingBox {
1357 xmin: 0.40513772,
1358 ymin: 0.6379755,
1359 xmax: 0.5122431,
1360 ymax: 0.7730214,
1361 },
1362 score: 0.4861709,
1363 label: 0,
1364 }];
1365 compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1366
1367 let scores = dequantize_ndarray(scores.view(), quant_scores.into());
1368 let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
1369 let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1370 decoder
1371 .decode_float::<f32>(
1372 &[
1373 scores.view().into_dyn(),
1374 boxes.view().into_dyn(),
1375 seg.view().into_dyn(),
1376 ],
1377 &mut output_boxes,
1378 &mut output_masks,
1379 )
1380 .unwrap();
1381
1382 compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1388 let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1389 let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1390
1391 assert_eq!(mask0, mask1);
1392 }
1393
    /// End-to-end check of the ModelPack split-head detection + segmentation
    /// decoder built with `with_config_modelpack_segdet_split`.
    ///
    /// The quantized path must reproduce a known reference box and the raw
    /// segmentation tensor. The float path, fed the same tensors dequantized
    /// with `dequantize_ndarray`, must reproduce the same box and the same
    /// mask after conversion with `segmentation_to_mask`.
    #[test]
    fn test_modelpack_segdet_split() {
        let score_threshold = 0.8;
        let iou_threshold = 0.5;

        // Segmentation output: raw bytes reshaped to (batch, classes, H, W),
        // matching the `Segmentation` dshape declared below.
        let seg = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

        // Two detection heads at different grid sizes: 9x15 and 17x30.
        let detect0 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_9x15x18.bin"
        ));
        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

        let detect1 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_17x30x18.bin"
        ));
        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

        // (scale, zero_point) pairs, one per input tensor.
        let quant0 = (0.08547406643629074, 174).into();
        let quant1 = (0.09929127991199493, 183).into();
        let quant_seg = (1.0 / 255.0, 0).into();

        // Anchor sets: anchors0 belongs to the 9x15 head, anchors1 to the 17x30 head.
        let anchors0 = vec![
            [0.36666667461395264, 0.31481480598449707],
            [0.38749998807907104, 0.4740740656852722],
            [0.5333333611488342, 0.644444465637207],
        ];
        let anchors1 = vec![
            [0.13750000298023224, 0.2074074000120163],
            [0.2541666626930237, 0.21481481194496155],
            [0.23125000298023224, 0.35185185074806213],
        ];

        // NOTE(review): the heads are configured as (17x30, 9x15), the reverse
        // of the order the inputs are passed to `decode_quantized` below;
        // presumably the decoder matches tensors to configs by shape — confirm.
        let decoder = DecoderBuilder::default()
            .with_config_modelpack_segdet_split(
                vec![
                    configs::Detection {
                        decoder: DecoderType::ModelPack,
                        shape: vec![1, 17, 30, 18],
                        anchors: Some(anchors1),
                        quantization: Some(quant1),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::Height, 17),
                            (DimName::Width, 30),
                            (DimName::NumAnchorsXFeatures, 18),
                        ],
                        normalized: Some(true),
                    },
                    configs::Detection {
                        decoder: DecoderType::ModelPack,
                        shape: vec![1, 9, 15, 18],
                        anchors: Some(anchors0),
                        quantization: Some(quant0),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::Height, 9),
                            (DimName::Width, 15),
                            (DimName::NumAnchorsXFeatures, 18),
                        ],
                        normalized: Some(true),
                    },
                ],
                configs::Segmentation {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_seg),
                    shape: vec![1, 2, 160, 160],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumClasses, 2),
                        (DimName::Height, 160),
                        (DimName::Width, 160),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[
                    detect0.view().into(),
                    detect1.view().into(),
                    seg.view().into(),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Reference mask: the whole segmentation tensor of batch 0, reordered
        // from (classes, H, W) to (H, W, classes) and spanning the full
        // [0, 1] x [0, 1] frame.
        let mut mask = seg.slice(s![0, .., .., ..]);
        mask.swap_axes(0, 1);
        mask.swap_axes(1, 2);
        let mask = [Segmentation {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.0,
            ymax: 1.0,
            segmentation: mask.into_owned(),
        }];
        let correct_boxes = [DetectBox {
            bbox: BoundingBox {
                xmin: 0.43171933,
                ymin: 0.68243736,
                xmax: 0.5626645,
                ymax: 0.808863,
            },
            score: 0.99240804,
            label: 0,
        }];
        println!("Output Boxes: {:?}", output_boxes);
        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));

        // Float path: dequantize each input with its own parameters and decode
        // again into the same (reused) output vectors.
        let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
        let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
        decoder
            .decode_float::<f32>(
                &[
                    detect0.view().into_dyn(),
                    detect1.view().into_dyn(),
                    seg.view().into_dyn(),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Only boxes go through `compare_outputs` here. The float mask holds
        // dequantized values, so both masks are converted with
        // `segmentation_to_mask` before being compared.
        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

        assert_eq!(mask0, mask1);
    }
1541
1542 #[test]
1543 fn test_dequant_chunked() {
1544 let mut out = load_yolov8s_det().into_raw_vec_and_offset().0;
1545 out.push(123); let mut out_dequant = vec![0.0; 84 * 8400 + 1];
1548 let mut out_dequant_simd = vec![0.0; 84 * 8400 + 1];
1549 let quant = Quantization::new(0.0040811873, -123);
1550 dequantize_cpu(&out, quant, &mut out_dequant);
1551
1552 dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1553 assert_eq!(out_dequant, out_dequant_simd);
1554
1555 let quant = Quantization::new(0.0040811873, 0);
1556 dequantize_cpu(&out, quant, &mut out_dequant);
1557
1558 dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1559 assert_eq!(out_dequant, out_dequant_simd);
1560 }
1561
1562 #[test]
1563 fn test_dequant_ground_truth() {
1564 let quant = Quantization::new(0.1, -128);
1569 let input: Vec<i8> = vec![0, 127, -128, 64];
1570 let mut output = vec![0.0f32; 4];
1571 let mut output_chunked = vec![0.0f32; 4];
1572 dequantize_cpu(&input, quant, &mut output);
1573 dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1574 let expected: Vec<f32> = vec![12.8, 25.5, 0.0, 19.2];
1579 for (i, (&out, &exp)) in output.iter().zip(expected.iter()).enumerate() {
1580 assert!((out - exp).abs() < 1e-5, "cpu[{i}]: {out} != {exp}");
1581 }
1582 for (i, (&out, &exp)) in output_chunked.iter().zip(expected.iter()).enumerate() {
1583 assert!((out - exp).abs() < 1e-5, "chunked[{i}]: {out} != {exp}");
1584 }
1585
1586 let quant = Quantization::new(1.0, 0);
1588 dequantize_cpu(&input, quant, &mut output);
1589 dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1590 let expected: Vec<f32> = vec![0.0, 127.0, -128.0, 64.0];
1591 assert_eq!(output, expected);
1592 assert_eq!(output_chunked, expected);
1593
1594 let quant = Quantization::new(0.5, 0);
1596 dequantize_cpu(&input, quant, &mut output);
1597 dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1598 let expected: Vec<f32> = vec![0.0, 63.5, -64.0, 32.0];
1599 assert_eq!(output, expected);
1600 assert_eq!(output_chunked, expected);
1601
1602 let quant = Quantization::new(0.021287762, 31);
1604 let input: Vec<i8> = vec![-128, -1, 0, 1, 31, 127];
1605 let mut output = vec![0.0f32; 6];
1606 let mut output_chunked = vec![0.0f32; 6];
1607 dequantize_cpu(&input, quant, &mut output);
1608 dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1609 for i in 0..6 {
1610 let expected = (input[i] as f32 - 31.0) * 0.021287762;
1611 assert!(
1612 (output[i] - expected).abs() < 1e-5,
1613 "cpu[{i}]: {} != {expected}",
1614 output[i]
1615 );
1616 assert!(
1617 (output_chunked[i] - expected).abs() < 1e-5,
1618 "chunked[{i}]: {} != {expected}",
1619 output_chunked[i]
1620 );
1621 }
1622 }
1623
    /// YOLO detection-only decoding. The free function `decode_yolo_det`
    /// produces reference boxes, which are pinned to known values. The
    /// builder-configured decoder must then match them on both the quantized
    /// path and the float path (input dequantized with `dequantize_ndarray`).
    #[test]
    fn test_decoder_yolo_det() {
        let score_threshold = 0.25;
        let iou_threshold = 0.7;
        let out = load_yolov8s_det();
        // (scale, zero_point) converted into the config's quantization type.
        let quant = (0.0040811873, -123).into();

        // NOTE(review): no NMS mode is set on the builder. The comparison below
        // against the ClassAgnostic reference relies on the builder's default
        // NMS being ClassAgnostic — confirm.
        let decoder = DecoderBuilder::default()
            .with_config_yolo_det(
                configs::Detection {
                    decoder: DecoderType::Ultralytics,
                    shape: vec![1, 84, 8400],
                    anchors: None,
                    quantization: Some(quant),
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumFeatures, 84),
                        (DimName::NumBoxes, 8400),
                    ],
                    normalized: Some(true),
                },
                Some(DecoderVersion::Yolo11),
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        // Reference: free function on batch 0 of the quantized tensor.
        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
        decode_yolo_det(
            (out.slice(s![0, .., ..]), quant.into()),
            score_threshold,
            iou_threshold,
            Some(configs::Nms::ClassAgnostic),
            &mut output_boxes,
        );
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.5285137,
                    ymin: 0.05305544,
                    xmax: 0.87541467,
                    ymax: 0.9998909,
                },
                score: 0.5591227,
                label: 0
            },
            1e-6
        ));

        assert!(output_boxes[1].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.130598,
                    ymin: 0.43260583,
                    xmax: 0.35098213,
                    ymax: 0.9958097,
                },
                score: 0.33057618,
                label: 75
            },
            1e-6
        ));

        let mut output_boxes1: Vec<_> = Vec::with_capacity(50);
        let mut output_masks1: Vec<_> = Vec::with_capacity(50);
        decoder
            .decode_quantized(&[out.view().into()], &mut output_boxes1, &mut output_masks1)
            .unwrap();

        let out = dequantize_ndarray(out.view(), quant.into());
        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(50);
        let mut output_masks_f32: Vec<_> = Vec::with_capacity(50);
        decoder
            .decode_float::<f32>(
                &[out.view().into_dyn()],
                &mut output_boxes_f32,
                &mut output_masks_f32,
            )
            .unwrap();

        // Detection-only config: the expected mask list passed in is empty.
        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
        compare_outputs((&output_boxes, &output_boxes_f32), (&[], &output_masks_f32));
    }
1708
    /// YOLOv8 segmentation via `decode_yolo_segdet_float` on dequantized
    /// inputs. Pins the two expected boxes and checks one mask against a
    /// stored 160x160 reference mask.
    #[test]
    fn test_decoder_masks() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let boxes = load_yolov8_boxes();
        let quant_boxes = Quantization::new(0.021287761628627777, 31);

        let protos = load_yolov8_protos();
        let quant_protos = Quantization::new(0.02491161972284317, -117);
        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);
        decode_yolo_segdet_float(
            seg.slice(s![0, .., ..]),
            protos.slice(s![0, .., .., ..]),
            score_threshold,
            iou_threshold,
            Some(configs::Nms::ClassAgnostic),
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();
        assert_eq!(output_boxes.len(), 2);
        assert_eq!(output_boxes.len(), output_masks.len());

        // Each mask region starts at or before its box's min corner and ends at
        // or before its box's max corner. Presumably the mask bounds are
        // snapped down to the proto grid — confirm.
        for (b, m) in output_boxes.iter().zip(&output_masks) {
            assert!(b.bbox.xmin >= m.xmin);
            assert!(b.bbox.ymin >= m.ymin);
            assert!(b.bbox.xmax >= m.xmax);
            assert!(b.bbox.ymax >= m.ymax);
        }
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.08515105,
                    ymin: 0.7131401,
                    xmax: 0.29802868,
                    ymax: 0.8195788,
                },
                score: 0.91537374,
                label: 23
            },
            // Tolerance of one 1/160 step, i.e. one cell of the 160x160 grid.
            1.0 / 160.0,
        ));

        assert!(output_boxes[1].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.59605736,
                    ymin: 0.25545314,
                    xmax: 0.93666154,
                    ymax: 0.72378385,
                },
                score: 0.91537374,
                label: 23
            },
            // Tolerance of one 1/160 step, i.e. one cell of the 160x160 grid.
            1.0 / 160.0,
        ));

        // Stored full-frame 160x160 mask. Crop it to the second mask's
        // normalized region (scaled by 160) and compare with the decoded mask.
        let full_mask = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/yolov8_mask_results.bin"
        ));
        let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();

        let cropped_mask = full_mask.slice(ndarray::s![
            (output_masks[1].ymin * 160.0) as usize..(output_masks[1].ymax * 160.0) as usize,
            (output_masks[1].xmin * 160.0) as usize..(output_masks[1].xmax * 160.0) as usize,
        ]);

        assert_eq!(
            cropped_mask,
            segmentation_to_mask(output_masks[1].segmentation.view()).unwrap()
        );
    }
1785
1786 #[test]
1791 fn test_decoder_masks_nchw_protos() {
1792 let score_threshold = 0.45;
1793 let iou_threshold = 0.45;
1794
1795 let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
1797 let quant_boxes = Quantization::new(0.021287761628627777, 31);
1798
1799 let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
1801 let quant_protos = Quantization::new(0.02491161972284317, -117);
1802 let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
1803
1804 let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
1806 let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
1807 let mut ref_masks: Vec<_> = Vec::with_capacity(10);
1808 decode_yolo_segdet_float(
1809 seg.view(),
1810 protos_f32_hwc.view(),
1811 score_threshold,
1812 iou_threshold,
1813 Some(configs::Nms::ClassAgnostic),
1814 &mut ref_boxes,
1815 &mut ref_masks,
1816 )
1817 .unwrap();
1818 assert_eq!(ref_boxes.len(), 2);
1819
1820 let protos_f32_chw = protos_f32_hwc.permuted_axes([2, 0, 1]); let protos_nchw = protos_f32_chw.insert_axis(ndarray::Axis(0)); let seg_3d = seg.insert_axis(ndarray::Axis(0)); let decoder = DecoderBuilder::default()
1830 .with_config_yolo_segdet(
1831 configs::Detection {
1832 decoder: configs::DecoderType::Ultralytics,
1833 quantization: None,
1834 shape: vec![1, 116, 8400],
1835 dshape: vec![],
1836 normalized: Some(true),
1837 anchors: None,
1838 },
1839 configs::Protos {
1840 decoder: configs::DecoderType::Ultralytics,
1841 quantization: None,
1842 shape: vec![1, 32, 160, 160],
1843 dshape: vec![], },
1845 None, )
1847 .with_score_threshold(score_threshold)
1848 .with_iou_threshold(iou_threshold)
1849 .build()
1850 .unwrap();
1851
1852 let mut cfg_boxes: Vec<_> = Vec::with_capacity(10);
1853 let mut cfg_masks: Vec<_> = Vec::with_capacity(10);
1854 decoder
1855 .decode_float(
1856 &[seg_3d.view().into_dyn(), protos_nchw.view().into_dyn()],
1857 &mut cfg_boxes,
1858 &mut cfg_masks,
1859 )
1860 .unwrap();
1861
1862 assert_eq!(
1864 cfg_boxes.len(),
1865 ref_boxes.len(),
1866 "config path produced {} boxes, reference produced {}",
1867 cfg_boxes.len(),
1868 ref_boxes.len()
1869 );
1870
1871 for (i, (cb, rb)) in cfg_boxes.iter().zip(&ref_boxes).enumerate() {
1873 assert!(
1874 cb.equal_within_delta(rb, 0.01),
1875 "box {i} mismatch: config={cb:?}, reference={rb:?}"
1876 );
1877 }
1878
1879 for (i, (cm, rm)) in cfg_masks.iter().zip(&ref_masks).enumerate() {
1881 let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
1882 let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
1883 assert_eq!(
1884 cm_arr, rm_arr,
1885 "mask {i} pixel mismatch between config-driven and reference paths"
1886 );
1887 }
1888 }
1889
    /// Cross-checks four YOLOv8 segmentation decode paths on the same i8 model
    /// outputs:
    /// 1. the free function `decode_yolo_segdet_quant`;
    /// 2. the builder decoder's `decode_quantized`;
    /// 3. the free function `decode_yolo_segdet_float` on dequantized inputs;
    /// 4. the builder decoder's `decode_float` on the same dequantized inputs.
    #[test]
    fn test_decoder_masks_i8() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let boxes = load_yolov8_boxes();
        let quant_boxes = (0.021287761628627777, 31).into();

        let protos = load_yolov8_protos();
        let quant_protos = (0.02491161972284317, -117).into();
        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
        let mut output_masks: Vec<_> = Vec::with_capacity(500);

        let decoder = DecoderBuilder::default()
            .with_config_yolo_segdet(
                configs::Detection {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: Some(quant_boxes),
                    shape: vec![1, 116, 8400],
                    anchors: None,
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumFeatures, 116),
                        (DimName::NumBoxes, 8400),
                    ],
                    normalized: Some(true),
                },
                Protos {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: Some(quant_protos),
                    shape: vec![1, 160, 160, 32],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::Height, 160),
                        (DimName::Width, 160),
                        (DimName::NumProtos, 32),
                    ],
                },
                Some(DecoderVersion::Yolo11),
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        // Convert the config-side quantization values again for the
        // free-function API (presumably into `Quantization` — confirm).
        let quant_boxes = quant_boxes.into();
        let quant_protos = quant_protos.into();

        // Path 1: free function, quantized inputs (batch 0).
        decode_yolo_segdet_quant(
            (boxes.slice(s![0, .., ..]), quant_boxes),
            (protos.slice(s![0, .., .., ..]), quant_protos),
            score_threshold,
            iou_threshold,
            Some(configs::Nms::ClassAgnostic),
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

        // Path 2: builder decoder, quantized inputs.
        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
        let mut output_masks1: Vec<_> = Vec::with_capacity(500);

        decoder
            .decode_quantized(
                &[boxes.view().into(), protos.view().into()],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();

        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);

        // Path 3: free function, dequantized inputs.
        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
        decode_yolo_segdet_float(
            seg.slice(s![0, .., ..]),
            protos.slice(s![0, .., .., ..]),
            score_threshold,
            iou_threshold,
            Some(configs::Nms::ClassAgnostic),
            &mut output_boxes_f32,
            &mut output_masks_f32,
        )
        .unwrap();

        // Path 4: builder decoder, dequantized inputs.
        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);

        decoder
            .decode_float(
                &[seg.view().into_dyn(), protos.view().into_dyn()],
                &mut output_boxes1_f32,
                &mut output_masks1_f32,
            )
            .unwrap();

        compare_outputs(
            (&output_boxes, &output_boxes1),
            (&output_masks, &output_masks1),
        );

        compare_outputs(
            (&output_boxes, &output_boxes_f32),
            (&output_masks, &output_masks_f32),
        );

        compare_outputs(
            (&output_boxes_f32, &output_boxes1_f32),
            (&output_masks_f32, &output_masks1_f32),
        );
    }
2001
    /// Split-output YOLO detection with i16 tensors. Box coordinates and class
    /// scores are fed as separate tensors, and the result must match
    /// single-tensor float decoding via `decode_yolo_det_float`.
    #[test]
    fn test_decoder_yolo_split() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        // Widen the i8 fixture to i16 and multiply by 256. Together with
        // scale / 256 and zero_point * 256 below, (256q - 256*31) * s / 256
        // dequantizes to the same values as the i8 original.
        let boxes = load_yolov8_boxes();
        let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();

        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);

        let decoder = DecoderBuilder::default()
            .with_config_yolo_split_det(
                configs::Boxes {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                    shape: vec![1, 4, 8400],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::BoxCoords, 4),
                        (DimName::NumBoxes, 8400),
                    ],
                    normalized: Some(true),
                },
                configs::Scores {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                    shape: vec![1, 80, 8400],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumClasses, 80),
                        (DimName::NumBoxes, 8400),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
        let mut output_masks: Vec<_> = Vec::with_capacity(500);

        // Rows 0..4 are box coordinates and rows 4..84 class scores. Rows
        // 84.. are not used by this detection-only config.
        decoder
            .decode_quantized(
                &[
                    boxes.slice(s![.., ..4, ..]).into(),
                    boxes.slice(s![.., 4..84, ..]).into(),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Reference: unsplit float decode over the first 84 rows of batch 0.
        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
        decode_yolo_det_float(
            seg.slice(s![0, ..84, ..]),
            score_threshold,
            iou_threshold,
            Some(configs::Nms::ClassAgnostic),
            &mut output_boxes_f32,
        );

        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
        let mut output_masks1: Vec<_> = Vec::with_capacity(500);

        decoder
            .decode_float(
                &[
                    seg.slice(s![.., ..4, ..]).into_dyn(),
                    seg.slice(s![.., 4..84, ..]).into_dyn(),
                ],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();
        compare_outputs((&output_boxes, &output_boxes_f32), (&output_masks, &[]));
        compare_outputs((&output_boxes_f32, &output_boxes1), (&[], &output_masks1));
    }
2081
    /// Split YOLO segmentation with mixed element types. Boxes, scores and
    /// mask coefficients are i16 (rescaled by 256), while protos keep the
    /// fixture's original element type. The quantized and float decoder paths
    /// must match the unsplit `decode_yolo_segdet_float` reference.
    #[test]
    fn test_decoder_masks_config_mixed() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        // Widen to i16 * 256. With scale / 256 and zero_point * 256 this
        // dequantizes to the same values as the original fixture.
        let boxes_raw = load_yolov8_boxes();
        let boxes: Vec<_> = boxes_raw.iter().map(|x| *x as i16 * 256).collect();
        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();

        let quant_boxes = (0.021287761628627777 / 256.0, 31 * 256);

        let protos = load_yolov8_protos();
        let quant_protos = (0.02491161972284317, -117);

        let decoder = build_yolo_split_segdet_decoder(
            score_threshold,
            iou_threshold,
            quant_boxes,
            quant_protos,
        );
        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
        let mut output_masks: Vec<_> = Vec::with_capacity(500);

        // Split inputs: box coords (0..4), class scores (4..84),
        // mask coefficients (84..), then protos.
        decoder
            .decode_quantized(
                &[
                    boxes.slice(s![.., ..4, ..]).into(),
                    boxes.slice(s![.., 4..84, ..]).into(),
                    boxes.slice(s![.., 84.., ..]).into(),
                    protos.view().into(),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Reference: unsplit float decode on batch 0.
        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos.into());
        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes.into());
        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
        decode_yolo_segdet_float(
            seg.slice(s![0, .., ..]),
            protos.slice(s![0, .., .., ..]),
            score_threshold,
            iou_threshold,
            Some(configs::Nms::ClassAgnostic),
            &mut output_boxes_f32,
            &mut output_masks_f32,
        )
        .unwrap();

        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
        let mut output_masks1: Vec<_> = Vec::with_capacity(500);

        decoder
            .decode_float(
                &[
                    seg.slice(s![.., ..4, ..]).into_dyn(),
                    seg.slice(s![.., 4..84, ..]).into_dyn(),
                    seg.slice(s![.., 84.., ..]).into_dyn(),
                    protos.view().into_dyn(),
                ],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();
        compare_outputs(
            (&output_boxes, &output_boxes_f32),
            (&output_masks, &output_masks_f32),
        );
        compare_outputs(
            (&output_boxes_f32, &output_boxes1),
            (&output_masks_f32, &output_masks1),
        );
    }
2156
    /// Builds a split-output YOLO segmentation decoder for the 116x8400
    /// YOLOv8 layout: 4 box rows, 80 class rows and 32 mask-coefficient rows,
    /// plus 160x160x32 protos.
    ///
    /// `quant_boxes` is used for the boxes, scores and mask-coefficient
    /// tensors; `quant_protos` for the protos. Both are (scale, zero_point)
    /// pairs. Panics if the builder rejects the configuration.
    fn build_yolo_split_segdet_decoder(
        score_threshold: f32,
        iou_threshold: f32,
        quant_boxes: (f32, i32),
        quant_protos: (f32, i32),
    ) -> crate::Decoder {
        DecoderBuilder::default()
            .with_config_yolo_split_segdet(
                configs::Boxes {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: Some(quant_boxes.into()),
                    shape: vec![1, 4, 8400],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::BoxCoords, 4),
                        (DimName::NumBoxes, 8400),
                    ],
                    normalized: Some(true),
                },
                configs::Scores {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: Some(quant_boxes.into()),
                    shape: vec![1, 80, 8400],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumClasses, 80),
                        (DimName::NumBoxes, 8400),
                    ],
                },
                configs::MaskCoefficients {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: Some(quant_boxes.into()),
                    shape: vec![1, 32, 8400],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumProtos, 32),
                        (DimName::NumBoxes, 8400),
                    ],
                },
                configs::Protos {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: Some(quant_protos.into()),
                    shape: vec![1, 160, 160, 32],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::Height, 160),
                        (DimName::Width, 160),
                        (DimName::NumProtos, 32),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap()
    }
2213
2214 fn build_yolov8_seg_decoder(score_threshold: f32, iou_threshold: f32) -> crate::Decoder {
2215 let config_yaml = include_str!(concat!(
2216 env!("CARGO_MANIFEST_DIR"),
2217 "/../../testdata/yolov8_seg.yaml"
2218 ));
2219 DecoderBuilder::default()
2220 .with_config_yaml_str(config_yaml.to_string())
2221 .with_score_threshold(score_threshold)
2222 .with_iou_threshold(iou_threshold)
2223 .build()
2224 .unwrap()
2225 }
    /// Split YOLO segmentation with i32 tensors. Every tensor is the 8-bit
    /// fixture scaled by 2^23, with scale divided and zero point multiplied by
    /// the same factor, so dequantization reproduces the original values. The
    /// quantized decoder output must match the float reference.
    #[test]
    fn test_decoder_masks_config_i32() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let boxes_raw = load_yolov8_boxes();
        // 2^23: i8 values times this stay within i32 range (|q| <= 2^30).
        let scale = 1 << 23;
        let boxes: Vec<_> = boxes_raw.iter().map(|x| *x as i32 * scale).collect();
        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();

        let quant_boxes = (0.021287761628627777 / scale as f32, 31 * scale);

        let protos_raw = load_yolov8_protos();
        let protos: Vec<_> = protos_raw.iter().map(|x| *x as i32 * scale).collect();
        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
        let quant_protos = (0.02491161972284317 / scale as f32, -117 * scale);

        let decoder = build_yolo_split_segdet_decoder(
            score_threshold,
            iou_threshold,
            quant_boxes,
            quant_protos,
        );

        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
        let mut output_masks: Vec<_> = Vec::with_capacity(500);

        // Split inputs: box coords (0..4), class scores (4..84),
        // mask coefficients (84..), then protos.
        decoder
            .decode_quantized(
                &[
                    boxes.slice(s![.., ..4, ..]).into(),
                    boxes.slice(s![.., 4..84, ..]).into(),
                    boxes.slice(s![.., 84.., ..]).into(),
                    protos.view().into(),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Reference: unsplit float decode on batch 0.
        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos.into());
        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes.into());
        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
        let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
        decode_yolo_segdet_float(
            seg.slice(s![0, .., ..]),
            protos.slice(s![0, .., .., ..]),
            score_threshold,
            iou_threshold,
            Some(configs::Nms::ClassAgnostic),
            &mut output_boxes_f32,
            &mut output_masks_f32,
        )
        .unwrap();

        assert_eq!(output_boxes.len(), output_boxes_f32.len());
        assert_eq!(output_masks.len(), output_masks_f32.len());

        compare_outputs(
            (&output_boxes, &output_boxes_f32),
            (&output_masks, &output_masks_f32),
        );
    }
2288
    /// Runs a YOLO detection decoder and a ModelPack split segdet decoder
    /// concurrently on eight threads, 100 decodes each. Every iteration must
    /// reproduce the known outputs. Presumably this guards against state
    /// leaking between decoder instances running in parallel — confirm intent.
    #[test]
    fn test_context_switch() {
        // YOLO detection workload: builds its own decoder and checks the two
        // expected boxes (and no masks) on every iteration.
        let yolo_det = || {
            let score_threshold = 0.25;
            let iou_threshold = 0.7;
            let out = load_yolov8s_det();
            let quant = (0.0040811873, -123).into();

            let decoder = DecoderBuilder::default()
                .with_config_yolo_det(
                    configs::Detection {
                        decoder: DecoderType::Ultralytics,
                        shape: vec![1, 84, 8400],
                        anchors: None,
                        quantization: Some(quant),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::NumFeatures, 84),
                            (DimName::NumBoxes, 8400),
                        ],
                        normalized: None,
                    },
                    None,
                )
                .with_score_threshold(score_threshold)
                .with_iou_threshold(iou_threshold)
                .build()
                .unwrap();

            let mut output_boxes: Vec<_> = Vec::with_capacity(50);
            let mut output_masks: Vec<_> = Vec::with_capacity(50);

            for _ in 0..100 {
                decoder
                    .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
                    .unwrap();

                assert!(output_boxes[0].equal_within_delta(
                    &DetectBox {
                        bbox: BoundingBox {
                            xmin: 0.5285137,
                            ymin: 0.05305544,
                            xmax: 0.87541467,
                            ymax: 0.9998909,
                        },
                        score: 0.5591227,
                        label: 0
                    },
                    1e-6
                ));

                assert!(output_boxes[1].equal_within_delta(
                    &DetectBox {
                        bbox: BoundingBox {
                            xmin: 0.130598,
                            ymin: 0.43260583,
                            xmax: 0.35098213,
                            ymax: 0.9958097,
                        },
                        score: 0.33057618,
                        label: 75
                    },
                    1e-6
                ));
                assert!(output_masks.is_empty());
            }
        };

        // ModelPack split segdet workload: same fixtures and expectations as
        // `test_modelpack_segdet_split` (quantized path only), repeated 100 times.
        let modelpack_det_split = || {
            let score_threshold = 0.8;
            let iou_threshold = 0.5;

            let seg = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/modelpack_seg_2x160x160.bin"
            ));
            let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

            let detect0 = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/modelpack_split_9x15x18.bin"
            ));
            let detect0 =
                ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

            let detect1 = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/modelpack_split_17x30x18.bin"
            ));
            let detect1 =
                ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

            // Reference mask: batch 0 reordered from (classes, H, W) to (H, W, classes).
            let mut mask = seg.slice(s![0, .., .., ..]);
            mask.swap_axes(0, 1);
            mask.swap_axes(1, 2);
            let mask = [Segmentation {
                xmin: 0.0,
                ymin: 0.0,
                xmax: 1.0,
                ymax: 1.0,
                segmentation: mask.into_owned(),
            }];
            let correct_boxes = [DetectBox {
                bbox: BoundingBox {
                    xmin: 0.43171933,
                    ymin: 0.68243736,
                    xmax: 0.5626645,
                    ymax: 0.808863,
                },
                score: 0.99240804,
                label: 0,
            }];

            let quant0 = (0.08547406643629074, 174).into();
            let quant1 = (0.09929127991199493, 183).into();
            let quant_seg = (1.0 / 255.0, 0).into();

            let anchors0 = vec![
                [0.36666667461395264, 0.31481480598449707],
                [0.38749998807907104, 0.4740740656852722],
                [0.5333333611488342, 0.644444465637207],
            ];
            let anchors1 = vec![
                [0.13750000298023224, 0.2074074000120163],
                [0.2541666626930237, 0.21481481194496155],
                [0.23125000298023224, 0.35185185074806213],
            ];

            let decoder = DecoderBuilder::default()
                .with_config_modelpack_segdet_split(
                    vec![
                        configs::Detection {
                            decoder: DecoderType::ModelPack,
                            shape: vec![1, 17, 30, 18],
                            anchors: Some(anchors1),
                            quantization: Some(quant1),
                            dshape: vec![
                                (DimName::Batch, 1),
                                (DimName::Height, 17),
                                (DimName::Width, 30),
                                (DimName::NumAnchorsXFeatures, 18),
                            ],
                            normalized: None,
                        },
                        configs::Detection {
                            decoder: DecoderType::ModelPack,
                            shape: vec![1, 9, 15, 18],
                            anchors: Some(anchors0),
                            quantization: Some(quant0),
                            dshape: vec![
                                (DimName::Batch, 1),
                                (DimName::Height, 9),
                                (DimName::Width, 15),
                                (DimName::NumAnchorsXFeatures, 18),
                            ],
                            normalized: None,
                        },
                    ],
                    configs::Segmentation {
                        decoder: DecoderType::ModelPack,
                        quantization: Some(quant_seg),
                        shape: vec![1, 2, 160, 160],
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::NumClasses, 2),
                            (DimName::Height, 160),
                            (DimName::Width, 160),
                        ],
                    },
                )
                .with_score_threshold(score_threshold)
                .with_iou_threshold(iou_threshold)
                .build()
                .unwrap();
            let mut output_boxes: Vec<_> = Vec::with_capacity(10);
            let mut output_masks: Vec<_> = Vec::with_capacity(10);

            for _ in 0..100 {
                decoder
                    .decode_quantized(
                        &[
                            detect0.view().into(),
                            detect1.view().into(),
                            seg.view().into(),
                        ],
                        &mut output_boxes,
                        &mut output_masks,
                    )
                    .unwrap();

                compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
            }
        };

        // Both closures capture nothing from this scope, so they are `Copy`
        // and can be handed to `thread::spawn` repeatedly. The two workloads
        // are interleaved across eight threads.
        let handles = vec![
            std::thread::spawn(yolo_det),
            std::thread::spawn(modelpack_det_split),
            std::thread::spawn(yolo_det),
            std::thread::spawn(modelpack_det_split),
            std::thread::spawn(yolo_det),
            std::thread::spawn(modelpack_det_split),
            std::thread::spawn(yolo_det),
            std::thread::spawn(modelpack_det_split),
        ];
        // `join().unwrap()` re-raises any assertion failure from a worker thread.
        for handle in handles {
            handle.join().unwrap();
        }
    }
2498
2499 #[test]
2500 fn test_ndarray_to_xyxy_float() {
2501 let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2502 let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
2503 assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
2504
2505 let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2506 let xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(arr.view());
2507 assert_eq!(xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
2508 }
2509
2510 #[test]
2511 fn test_class_aware_nms_float() {
2512 use crate::float::nms_class_aware_float;
2513
2514 let boxes = vec![
2516 DetectBox {
2517 bbox: BoundingBox {
2518 xmin: 0.0,
2519 ymin: 0.0,
2520 xmax: 0.5,
2521 ymax: 0.5,
2522 },
2523 score: 0.9,
2524 label: 0, },
2526 DetectBox {
2527 bbox: BoundingBox {
2528 xmin: 0.1,
2529 ymin: 0.1,
2530 xmax: 0.6,
2531 ymax: 0.6,
2532 },
2533 score: 0.8,
2534 label: 1, },
2536 ];
2537
2538 let result = nms_class_aware_float(0.3, boxes.clone());
2541 assert_eq!(
2542 result.len(),
2543 2,
2544 "Class-aware NMS should keep both boxes with different classes"
2545 );
2546
2547 let same_class_boxes = vec![
2549 DetectBox {
2550 bbox: BoundingBox {
2551 xmin: 0.0,
2552 ymin: 0.0,
2553 xmax: 0.5,
2554 ymax: 0.5,
2555 },
2556 score: 0.9,
2557 label: 0,
2558 },
2559 DetectBox {
2560 bbox: BoundingBox {
2561 xmin: 0.1,
2562 ymin: 0.1,
2563 xmax: 0.6,
2564 ymax: 0.6,
2565 },
2566 score: 0.8,
2567 label: 0, },
2569 ];
2570
2571 let result = nms_class_aware_float(0.3, same_class_boxes);
2572 assert_eq!(
2573 result.len(),
2574 1,
2575 "Class-aware NMS should suppress overlapping box with same class"
2576 );
2577 assert_eq!(result[0].label, 0);
2578 assert!((result[0].score - 0.9).abs() < 1e-6);
2579 }
2580
2581 #[test]
2582 fn test_class_agnostic_vs_aware_nms() {
2583 use crate::float::{nms_class_aware_float, nms_float};
2584
2585 let boxes = vec![
2587 DetectBox {
2588 bbox: BoundingBox {
2589 xmin: 0.0,
2590 ymin: 0.0,
2591 xmax: 0.5,
2592 ymax: 0.5,
2593 },
2594 score: 0.9,
2595 label: 0,
2596 },
2597 DetectBox {
2598 bbox: BoundingBox {
2599 xmin: 0.1,
2600 ymin: 0.1,
2601 xmax: 0.6,
2602 ymax: 0.6,
2603 },
2604 score: 0.8,
2605 label: 1,
2606 },
2607 ];
2608
2609 let agnostic_result = nms_float(0.3, boxes.clone());
2611 assert_eq!(
2612 agnostic_result.len(),
2613 1,
2614 "Class-agnostic NMS should suppress overlapping boxes"
2615 );
2616
2617 let aware_result = nms_class_aware_float(0.3, boxes);
2619 assert_eq!(
2620 aware_result.len(),
2621 2,
2622 "Class-aware NMS should keep boxes with different classes"
2623 );
2624 }
2625
2626 #[test]
2627 fn test_class_aware_nms_int() {
2628 use crate::byte::nms_class_aware_int;
2629
2630 let boxes = vec![
2632 DetectBoxQuantized {
2633 bbox: BoundingBox {
2634 xmin: 0.0,
2635 ymin: 0.0,
2636 xmax: 0.5,
2637 ymax: 0.5,
2638 },
2639 score: 200_u8,
2640 label: 0,
2641 },
2642 DetectBoxQuantized {
2643 bbox: BoundingBox {
2644 xmin: 0.1,
2645 ymin: 0.1,
2646 xmax: 0.6,
2647 ymax: 0.6,
2648 },
2649 score: 180_u8,
2650 label: 1, },
2652 ];
2653
2654 let result = nms_class_aware_int(0.5, boxes);
2656 assert_eq!(
2657 result.len(),
2658 2,
2659 "Class-aware NMS (int) should keep boxes with different classes"
2660 );
2661 }
2662
2663 #[test]
2664 fn test_nms_enum_default() {
2665 let default_nms: configs::Nms = Default::default();
2667 assert_eq!(default_nms, configs::Nms::ClassAgnostic);
2668 }
2669
2670 #[test]
2671 fn test_decoder_nms_mode() {
2672 let decoder = DecoderBuilder::default()
2674 .with_config_yolo_det(
2675 configs::Detection {
2676 anchors: None,
2677 decoder: DecoderType::Ultralytics,
2678 quantization: None,
2679 shape: vec![1, 84, 8400],
2680 dshape: Vec::new(),
2681 normalized: Some(true),
2682 },
2683 None,
2684 )
2685 .with_nms(Some(configs::Nms::ClassAware))
2686 .build()
2687 .unwrap();
2688
2689 assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
2690 }
2691
2692 #[test]
2693 fn test_decoder_nms_bypass() {
2694 let decoder = DecoderBuilder::default()
2696 .with_config_yolo_det(
2697 configs::Detection {
2698 anchors: None,
2699 decoder: DecoderType::Ultralytics,
2700 quantization: None,
2701 shape: vec![1, 84, 8400],
2702 dshape: Vec::new(),
2703 normalized: Some(true),
2704 },
2705 None,
2706 )
2707 .with_nms(None)
2708 .build()
2709 .unwrap();
2710
2711 assert_eq!(decoder.nms, None);
2712 }
2713
2714 #[test]
2715 fn test_decoder_normalized_boxes_true() {
2716 let decoder = DecoderBuilder::default()
2718 .with_config_yolo_det(
2719 configs::Detection {
2720 anchors: None,
2721 decoder: DecoderType::Ultralytics,
2722 quantization: None,
2723 shape: vec![1, 84, 8400],
2724 dshape: Vec::new(),
2725 normalized: Some(true),
2726 },
2727 None,
2728 )
2729 .build()
2730 .unwrap();
2731
2732 assert_eq!(decoder.normalized_boxes(), Some(true));
2733 }
2734
2735 #[test]
2736 fn test_decoder_normalized_boxes_false() {
2737 let decoder = DecoderBuilder::default()
2740 .with_config_yolo_det(
2741 configs::Detection {
2742 anchors: None,
2743 decoder: DecoderType::Ultralytics,
2744 quantization: None,
2745 shape: vec![1, 84, 8400],
2746 dshape: Vec::new(),
2747 normalized: Some(false),
2748 },
2749 None,
2750 )
2751 .build()
2752 .unwrap();
2753
2754 assert_eq!(decoder.normalized_boxes(), Some(false));
2755 }
2756
2757 #[test]
2758 fn test_decoder_normalized_boxes_unknown() {
2759 let decoder = DecoderBuilder::default()
2761 .with_config_yolo_det(
2762 configs::Detection {
2763 anchors: None,
2764 decoder: DecoderType::Ultralytics,
2765 quantization: None,
2766 shape: vec![1, 84, 8400],
2767 dshape: Vec::new(),
2768 normalized: None,
2769 },
2770 Some(DecoderVersion::Yolo11),
2771 )
2772 .build()
2773 .unwrap();
2774
2775 assert_eq!(decoder.normalized_boxes(), None);
2776 }
2777
2778 pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
2779 input: ArrayView<F, D>,
2780 quant: Quantization,
2781 ) -> Array<T, D>
2782 where
2783 i32: num_traits::AsPrimitive<F>,
2784 f32: num_traits::AsPrimitive<F>,
2785 {
2786 let zero_point = quant.zero_point.as_();
2787 let div_scale = F::one() / quant.scale.as_();
2788 if zero_point != F::zero() {
2789 input.mapv(|d| (d * div_scale + zero_point).round().as_())
2790 } else {
2791 input.mapv(|d| (d * div_scale).round().as_())
2792 }
2793 }
2794
2795 fn real_data_expected_boxes() -> [DetectBox; 2] {
2796 [
2797 DetectBox {
2798 bbox: BoundingBox {
2799 xmin: 0.08515105,
2800 ymin: 0.7131401,
2801 xmax: 0.29802868,
2802 ymax: 0.8195788,
2803 },
2804 score: 0.91537374,
2805 label: 23,
2806 },
2807 DetectBox {
2808 bbox: BoundingBox {
2809 xmin: 0.59605736,
2810 ymin: 0.25545314,
2811 xmax: 0.93666154,
2812 ymax: 0.72378385,
2813 },
2814 score: 0.91537374,
2815 label: 23,
2816 },
2817 ]
2818 }
2819
2820 fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
2821 [DetectBox {
2822 bbox: BoundingBox {
2823 xmin: 0.12549022,
2824 ymin: 0.12549022,
2825 xmax: 0.23529413,
2826 ymax: 0.23529413,
2827 },
2828 score: 0.98823535,
2829 label: 2,
2830 }]
2831 }
2832
2833 fn e2e_expected_boxes_float() -> [DetectBox; 1] {
2834 [DetectBox {
2835 bbox: BoundingBox {
2836 xmin: 0.1234,
2837 ymin: 0.1234,
2838 xmax: 0.2345,
2839 ymax: 0.2345,
2840 },
2841 score: 0.9876,
2842 label: 2,
2843 }]
2844 }
2845
    /// Generates a test that decodes the recorded YOLOv8 segmentation fixture
    /// (`yolov8_boxes_116x8400.bin` + `yolov8_protos_160x160x32.bin`) through the
    /// proto-output API and checks the two expected detections.
    ///
    /// * `quantized` arm: feeds the raw int8 tensors to `decode_quantized_proto`.
    /// * `float` arm: dequantizes both fixtures first, then calls `decode_float_proto`.
    /// * `$layout`: `split` slices the `[1, 116, 8400]` tensor into boxes (rows
    ///   0..4), scores (rows 4..84) and mask coefficients (rows 84..) and uses the
    ///   split-output decoder; any other token (`combined`) passes it whole.
    macro_rules! real_data_proto_test {
        ($name:ident, quantized, $layout:ident) => {
            #[test]
            fn $name() {
                // `$layout` is compared by its token text.
                let is_split = matches!(stringify!($layout), "split");

                let score_threshold = 0.45;
                let iou_threshold = 0.45;
                // (scale, zero_point) of the int8 fixtures; in this arm only the
                // split decoder builder consumes them.
                let quant_boxes = (0.021287762_f32, 31_i32);
                let quant_protos = (0.02491162_f32, -117_i32);

                let raw_boxes = include_bytes!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/../../testdata/yolov8_boxes_116x8400.bin"
                ));
                // SAFETY: `u8` and `i8` have the same size and alignment, and the
                // `include_bytes!` data is `'static`, so viewing it as `[i8]` is sound.
                let raw_boxes = unsafe {
                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
                };
                let boxes_i8 =
                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();

                let raw_protos = include_bytes!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/../../testdata/yolov8_protos_160x160x32.bin"
                ));
                // SAFETY: same reasoning as for `raw_boxes` above.
                let raw_protos = unsafe {
                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
                };
                let protos_i8 =
                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
                        .unwrap();

                // Split views of the combined tensor along axis 1; they are built
                // for both layouts but only used when `is_split`.
                let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
                let scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
                let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
                let boxes_combined = boxes_i8;

                let decoder = if is_split {
                    build_yolo_split_segdet_decoder(
                        score_threshold,
                        iou_threshold,
                        quant_boxes,
                        quant_protos,
                    )
                } else {
                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
                };

                let expected = real_data_expected_boxes();
                let mut output_boxes = Vec::with_capacity(50);

                // Input order: boxes, scores, mask coefficients, protos (split),
                // or combined detections, protos.
                let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
                    vec![
                        boxes_split.view().into(),
                        scores_split.view().into(),
                        mask_split.view().into(),
                        protos_i8.view().into(),
                    ]
                } else {
                    vec![boxes_combined.view().into(), protos_i8.view().into()]
                };
                decoder
                    .decode_quantized_proto(&inputs, &mut output_boxes)
                    .unwrap();

                // Tolerance of one proto-grid cell (1/160).
                assert_eq!(output_boxes.len(), 2);
                assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
                assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
            }
        };
        ($name:ident, float, $layout:ident) => {
            #[test]
            fn $name() {
                // `$layout` is compared by its token text.
                let is_split = matches!(stringify!($layout), "split");

                let score_threshold = 0.45;
                let iou_threshold = 0.45;
                // (scale, zero_point) used to dequantize the int8 fixtures to f32.
                let quant_boxes = (0.021287762_f32, 31_i32);
                let quant_protos = (0.02491162_f32, -117_i32);

                let raw_boxes = include_bytes!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/../../testdata/yolov8_boxes_116x8400.bin"
                ));
                // SAFETY: `u8` and `i8` have the same size and alignment, and the
                // `include_bytes!` data is `'static`, so viewing it as `[i8]` is sound.
                let raw_boxes = unsafe {
                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
                };
                let boxes_i8 =
                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
                let boxes_f32: Array3<f32> =
                    dequantize_ndarray(boxes_i8.view(), quant_boxes.into());

                let raw_protos = include_bytes!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/../../testdata/yolov8_protos_160x160x32.bin"
                ));
                // SAFETY: same reasoning as for `raw_boxes` above.
                let raw_protos = unsafe {
                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
                };
                let protos_i8 =
                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
                        .unwrap();
                let protos_f32: Array4<f32> =
                    dequantize_ndarray(protos_i8.view(), quant_protos.into());

                // Split views of the combined tensor along axis 1; they are built
                // for both layouts but only used when `is_split`.
                let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
                let scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
                let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
                let boxes_combined = boxes_f32;

                let decoder = if is_split {
                    build_yolo_split_segdet_decoder(
                        score_threshold,
                        iou_threshold,
                        quant_boxes,
                        quant_protos,
                    )
                } else {
                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
                };

                let expected = real_data_expected_boxes();
                let mut output_boxes = Vec::with_capacity(50);

                // Input order: boxes, scores, mask coefficients, protos (split),
                // or combined detections, protos.
                let inputs = if is_split {
                    vec![
                        boxes_split.view().into_dyn(),
                        scores_split.view().into_dyn(),
                        mask_split.view().into_dyn(),
                        protos_f32.view().into_dyn(),
                    ]
                } else {
                    vec![
                        boxes_combined.view().into_dyn(),
                        protos_f32.view().into_dyn(),
                    ]
                };
                decoder
                    .decode_float_proto(&inputs, &mut output_boxes)
                    .unwrap();

                // Tolerance of one proto-grid cell (1/160).
                assert_eq!(output_boxes.len(), 2);
                assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
                assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
            }
        };
    }

    real_data_proto_test!(test_decoder_segdet_proto, quantized, combined);
    real_data_proto_test!(test_decoder_segdet_proto_float, float, combined);
    real_data_proto_test!(test_decoder_segdet_split_proto, quantized, split);
    real_data_proto_test!(test_decoder_segdet_split_proto_float, float, split);
3000
    /// `yolo26` end-to-end detection config with one combined `[1, 10, 6]`
    /// output (4 box coords + score + class per row), quantized with scale 2/255.
    const E2E_COMBINED_DET_CONFIG: &str = "
decoder_version: yolo26
outputs:
  - type: detection
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 6]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_features, 6]
    normalized: true
";

    /// `yolo26` end-to-end segmentation config: a combined `[1, 10, 38]`
    /// detection output (4 box coords + score + class + 32 mask coefficients)
    /// plus a `[1, 160, 160, 32]` prototype output.
    const E2E_COMBINED_SEGDET_CONFIG: &str = "
decoder_version: yolo26
outputs:
  - type: detection
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 38]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_features, 38]
    normalized: true
  - type: protos
    decoder: ultralytics
    quantization: [0.0039215686274509803921568627451, 128]
    shape: [1, 160, 160, 32]
    dshape:
      - [batch, 1]
      - [height, 160]
      - [width, 160]
      - [num_protos, 32]
";

    /// `yolo26` end-to-end detection config with separate boxes `[1, 10, 4]`,
    /// scores `[1, 10, 1]` and classes `[1, 10, 1]` outputs.
    const E2E_SPLIT_DET_CONFIG: &str = "
decoder_version: yolo26
outputs:
  - type: boxes
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 4]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [box_coords, 4]
    normalized: true
  - type: scores
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 1]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_classes, 1]
  - type: classes
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 1]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_classes, 1]
";

    /// `yolo26` end-to-end segmentation config with separate boxes, scores,
    /// classes, mask-coefficient `[1, 10, 32]` and prototype `[1, 160, 160, 32]`
    /// outputs.
    const E2E_SPLIT_SEGDET_CONFIG: &str = "
decoder_version: yolo26
outputs:
  - type: boxes
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 4]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [box_coords, 4]
    normalized: true
  - type: scores
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 1]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_classes, 1]
  - type: classes
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 1]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_classes, 1]
  - type: mask_coefficients
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 32]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_protos, 32]
  - type: protos
    decoder: ultralytics
    quantization: [0.0039215686274509803921568627451, 128]
    shape: [1, 160, 160, 32]
    dshape:
      - [batch, 1]
      - [height, 160]
      - [width, 160]
      - [num_protos, 32]
";
3114
3115 macro_rules! e2e_segdet_test {
3116 ($name:ident, quantized, $layout:ident, $output:ident) => {
3117 #[test]
3118 fn $name() {
3119 let is_split = matches!(stringify!($layout), "split");
3120 let is_proto = matches!(stringify!($output), "proto");
3121
3122 let score_threshold = 0.45;
3123 let iou_threshold = 0.45;
3124
3125 let mut boxes = Array2::zeros((10, 4));
3126 let mut scores = Array2::zeros((10, 1));
3127 let mut classes = Array2::zeros((10, 1));
3128 let mask = Array2::zeros((10, 32));
3129 let protos = Array3::<f64>::zeros((160, 160, 32));
3130 let protos = protos.insert_axis(Axis(0));
3131 let protos_quant = (1.0 / 255.0, 0.0);
3132 let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
3133
3134 boxes
3135 .slice_mut(s![0, ..])
3136 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3137 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3138 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3139
3140 let detect_quant = (2.0 / 255.0, 0.0);
3141
3142 let decoder = if is_split {
3143 DecoderBuilder::default()
3144 .with_config_yaml_str(E2E_SPLIT_SEGDET_CONFIG.to_string())
3145 .with_score_threshold(score_threshold)
3146 .with_iou_threshold(iou_threshold)
3147 .build()
3148 .unwrap()
3149 } else {
3150 DecoderBuilder::default()
3151 .with_config_yaml_str(E2E_COMBINED_SEGDET_CONFIG.to_string())
3152 .with_score_threshold(score_threshold)
3153 .with_iou_threshold(iou_threshold)
3154 .build()
3155 .unwrap()
3156 };
3157
3158 let expected = e2e_expected_boxes_quant();
3159 let mut output_boxes = Vec::with_capacity(50);
3160
3161 if is_split {
3162 let boxes = boxes.insert_axis(Axis(0));
3163 let scores = scores.insert_axis(Axis(0));
3164 let classes = classes.insert_axis(Axis(0));
3165 let mask = mask.insert_axis(Axis(0));
3166
3167 let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
3168 let scores: Array3<u8> = quantize_ndarray(scores.view(), detect_quant.into());
3169 let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
3170 let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
3171
3172 if is_proto {
3173 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3174 boxes.view().into(),
3175 scores.view().into(),
3176 classes.view().into(),
3177 mask.view().into(),
3178 protos.view().into(),
3179 ];
3180 decoder
3181 .decode_quantized_proto(&inputs, &mut output_boxes)
3182 .unwrap();
3183
3184 assert_eq!(output_boxes.len(), 1);
3185 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3186 } else {
3187 let mut output_masks = Vec::with_capacity(50);
3188 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3189 boxes.view().into(),
3190 scores.view().into(),
3191 classes.view().into(),
3192 mask.view().into(),
3193 protos.view().into(),
3194 ];
3195 decoder
3196 .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
3197 .unwrap();
3198
3199 assert_eq!(output_boxes.len(), 1);
3200 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3201 }
3202 } else {
3203 let detect = ndarray::concatenate![
3205 Axis(1),
3206 boxes.view(),
3207 scores.view(),
3208 classes.view(),
3209 mask.view()
3210 ];
3211 let detect = detect.insert_axis(Axis(0));
3212 assert_eq!(detect.shape(), &[1, 10, 38]);
3213 let detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
3214
3215 if is_proto {
3216 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3217 vec![detect.view().into(), protos.view().into()];
3218 decoder
3219 .decode_quantized_proto(&inputs, &mut output_boxes)
3220 .unwrap();
3221
3222 assert_eq!(output_boxes.len(), 1);
3223 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3224 } else {
3225 let mut output_masks = Vec::with_capacity(50);
3226 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3227 vec![detect.view().into(), protos.view().into()];
3228 decoder
3229 .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
3230 .unwrap();
3231
3232 assert_eq!(output_boxes.len(), 1);
3233 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3234 }
3235 }
3236 }
3237 };
3238 ($name:ident, float, $layout:ident, $output:ident) => {
3239 #[test]
3240 fn $name() {
3241 let is_split = matches!(stringify!($layout), "split");
3242 let is_proto = matches!(stringify!($output), "proto");
3243
3244 let score_threshold = 0.45;
3245 let iou_threshold = 0.45;
3246
3247 let mut boxes = Array2::zeros((10, 4));
3248 let mut scores = Array2::zeros((10, 1));
3249 let mut classes = Array2::zeros((10, 1));
3250 let mask: Array2<f64> = Array2::zeros((10, 32));
3251 let protos = Array3::<f64>::zeros((160, 160, 32));
3252 let protos = protos.insert_axis(Axis(0));
3253
3254 boxes
3255 .slice_mut(s![0, ..])
3256 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3257 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3258 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3259
3260 let decoder = if is_split {
3261 DecoderBuilder::default()
3262 .with_config_yaml_str(E2E_SPLIT_SEGDET_CONFIG.to_string())
3263 .with_score_threshold(score_threshold)
3264 .with_iou_threshold(iou_threshold)
3265 .build()
3266 .unwrap()
3267 } else {
3268 DecoderBuilder::default()
3269 .with_config_yaml_str(E2E_COMBINED_SEGDET_CONFIG.to_string())
3270 .with_score_threshold(score_threshold)
3271 .with_iou_threshold(iou_threshold)
3272 .build()
3273 .unwrap()
3274 };
3275
3276 let expected = e2e_expected_boxes_float();
3277 let mut output_boxes = Vec::with_capacity(50);
3278
3279 if is_split {
3280 let boxes = boxes.insert_axis(Axis(0));
3281 let scores = scores.insert_axis(Axis(0));
3282 let classes = classes.insert_axis(Axis(0));
3283 let mask = mask.insert_axis(Axis(0));
3284
3285 if is_proto {
3286 let inputs = vec![
3287 boxes.view().into_dyn(),
3288 scores.view().into_dyn(),
3289 classes.view().into_dyn(),
3290 mask.view().into_dyn(),
3291 protos.view().into_dyn(),
3292 ];
3293 decoder
3294 .decode_float_proto(&inputs, &mut output_boxes)
3295 .unwrap();
3296
3297 assert_eq!(output_boxes.len(), 1);
3298 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3299 } else {
3300 let mut output_masks = Vec::with_capacity(50);
3301 let inputs = vec![
3302 boxes.view().into_dyn(),
3303 scores.view().into_dyn(),
3304 classes.view().into_dyn(),
3305 mask.view().into_dyn(),
3306 protos.view().into_dyn(),
3307 ];
3308 decoder
3309 .decode_float(&inputs, &mut output_boxes, &mut output_masks)
3310 .unwrap();
3311
3312 assert_eq!(output_boxes.len(), 1);
3313 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3314 }
3315 } else {
3316 let detect = ndarray::concatenate![
3318 Axis(1),
3319 boxes.view(),
3320 scores.view(),
3321 classes.view(),
3322 mask.view()
3323 ];
3324 let detect = detect.insert_axis(Axis(0));
3325 assert_eq!(detect.shape(), &[1, 10, 38]);
3326
3327 if is_proto {
3328 let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
3329 decoder
3330 .decode_float_proto(&inputs, &mut output_boxes)
3331 .unwrap();
3332
3333 assert_eq!(output_boxes.len(), 1);
3334 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3335 } else {
3336 let mut output_masks = Vec::with_capacity(50);
3337 let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
3338 decoder
3339 .decode_float(&inputs, &mut output_boxes, &mut output_masks)
3340 .unwrap();
3341
3342 assert_eq!(output_boxes.len(), 1);
3343 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3344 }
3345 }
3346 }
3347 };
3348 }
3349
3350 e2e_segdet_test!(test_decoder_end_to_end_segdet, quantized, combined, masks);
3351 e2e_segdet_test!(test_decoder_end_to_end_segdet_float, float, combined, masks);
3352 e2e_segdet_test!(
3353 test_decoder_end_to_end_segdet_proto,
3354 quantized,
3355 combined,
3356 proto
3357 );
3358 e2e_segdet_test!(
3359 test_decoder_end_to_end_segdet_proto_float,
3360 float,
3361 combined,
3362 proto
3363 );
3364 e2e_segdet_test!(
3365 test_decoder_end_to_end_segdet_split,
3366 quantized,
3367 split,
3368 masks
3369 );
3370 e2e_segdet_test!(
3371 test_decoder_end_to_end_segdet_split_float,
3372 float,
3373 split,
3374 masks
3375 );
3376 e2e_segdet_test!(
3377 test_decoder_end_to_end_segdet_split_proto,
3378 quantized,
3379 split,
3380 proto
3381 );
3382 e2e_segdet_test!(
3383 test_decoder_end_to_end_segdet_split_proto_float,
3384 float,
3385 split,
3386 proto
3387 );
3388
    /// Generates an end-to-end detection-only test driven by the
    /// `E2E_COMBINED_DET_CONFIG` / `E2E_SPLIT_DET_CONFIG` YAML configs.
    ///
    /// * `quantized` | `float` – inputs are quantized to `u8` with (2/255, 0) or
    ///   passed unquantized.
    /// * `$layout` – `split` feeds boxes `[1, 10, 4]`, scores `[1, 10, 1]` and
    ///   classes `[1, 10, 1]` separately; any other token (`combined`)
    ///   concatenates them along axis 2 into one `[1, 10, 6]` tensor.
    ///
    /// Only row 0 is populated, so exactly one detection is expected; it is
    /// compared with a tight 1e-6 tolerance.
    macro_rules! e2e_det_test {
        ($name:ident, quantized, $layout:ident) => {
            #[test]
            fn $name() {
                let is_split = matches!(stringify!($layout), "split");

                let score_threshold = 0.45;
                let iou_threshold = 0.45;

                let mut boxes = Array3::zeros((1, 10, 4));
                let mut scores = Array3::zeros((1, 10, 1));
                let mut classes = Array3::zeros((1, 10, 1));

                // Populate candidate row 0 only; every other row stays zero.
                boxes
                    .slice_mut(s![0, 0, ..])
                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
                scores.slice_mut(s![0, 0, ..]).assign(&array![0.9876]);
                classes.slice_mut(s![0, 0, ..]).assign(&array![2.0]);

                // Same (scale, zero_point) as the configs' detection quantization.
                let detect_quant = (2.0 / 255.0, 0_i32);

                let decoder = if is_split {
                    DecoderBuilder::default()
                        .with_config_yaml_str(E2E_SPLIT_DET_CONFIG.to_string())
                        .with_score_threshold(score_threshold)
                        .with_iou_threshold(iou_threshold)
                        .build()
                        .unwrap()
                } else {
                    DecoderBuilder::default()
                        .with_config_yaml_str(E2E_COMBINED_DET_CONFIG.to_string())
                        .with_score_threshold(score_threshold)
                        .with_iou_threshold(iou_threshold)
                        .build()
                        .unwrap()
                };

                let expected = e2e_expected_boxes_quant();
                let mut output_boxes = Vec::with_capacity(50);

                if is_split {
                    let boxes: Array<u8, _> = quantize_ndarray(boxes.view(), detect_quant.into());
                    let scores: Array<u8, _> = quantize_ndarray(scores.view(), detect_quant.into());
                    let classes: Array<u8, _> =
                        quantize_ndarray(classes.view(), detect_quant.into());
                    let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
                        boxes.view().into(),
                        scores.view().into(),
                        classes.view().into(),
                    ];
                    decoder
                        .decode_quantized(&inputs, &mut output_boxes, &mut Vec::new())
                        .unwrap();
                } else {
                    let detect =
                        ndarray::concatenate![Axis(2), boxes.view(), scores.view(), classes.view()];
                    assert_eq!(detect.shape(), &[1, 10, 6]);
                    let detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
                    let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
                        vec![detect.view().into()];
                    decoder
                        .decode_quantized(&inputs, &mut output_boxes, &mut Vec::new())
                        .unwrap();
                }

                assert_eq!(output_boxes.len(), 1);
                assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
            }
        };
        ($name:ident, float, $layout:ident) => {
            #[test]
            fn $name() {
                let is_split = matches!(stringify!($layout), "split");

                let score_threshold = 0.45;
                let iou_threshold = 0.45;

                let mut boxes = Array3::zeros((1, 10, 4));
                let mut scores = Array3::zeros((1, 10, 1));
                let mut classes = Array3::zeros((1, 10, 1));

                // Populate candidate row 0 only; every other row stays zero.
                boxes
                    .slice_mut(s![0, 0, ..])
                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
                scores.slice_mut(s![0, 0, ..]).assign(&array![0.9876]);
                classes.slice_mut(s![0, 0, ..]).assign(&array![2.0]);

                let decoder = if is_split {
                    DecoderBuilder::default()
                        .with_config_yaml_str(E2E_SPLIT_DET_CONFIG.to_string())
                        .with_score_threshold(score_threshold)
                        .with_iou_threshold(iou_threshold)
                        .build()
                        .unwrap()
                } else {
                    DecoderBuilder::default()
                        .with_config_yaml_str(E2E_COMBINED_DET_CONFIG.to_string())
                        .with_score_threshold(score_threshold)
                        .with_iou_threshold(iou_threshold)
                        .build()
                        .unwrap()
                };

                let expected = e2e_expected_boxes_float();
                let mut output_boxes = Vec::with_capacity(50);

                if is_split {
                    let inputs = vec![
                        boxes.view().into_dyn(),
                        scores.view().into_dyn(),
                        classes.view().into_dyn(),
                    ];
                    decoder
                        .decode_float(&inputs, &mut output_boxes, &mut Vec::new())
                        .unwrap();
                } else {
                    let detect =
                        ndarray::concatenate![Axis(2), boxes.view(), scores.view(), classes.view()];
                    assert_eq!(detect.shape(), &[1, 10, 6]);
                    let inputs = vec![detect.view().into_dyn()];
                    decoder
                        .decode_float(&inputs, &mut output_boxes, &mut Vec::new())
                        .unwrap();
                }

                assert_eq!(output_boxes.len(), 1);
                assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
            }
        };
    }

    e2e_det_test!(test_decoder_end_to_end_combined_det, quantized, combined);
    e2e_det_test!(test_decoder_end_to_end_combined_det_float, float, combined);
    e2e_det_test!(test_decoder_end_to_end_split_det, quantized, split);
    e2e_det_test!(test_decoder_end_to_end_split_det_float, float, split);
3524
3525 #[test]
3526 fn test_decode_tensor() {
3527 let score_threshold = 0.45;
3528 let iou_threshold = 0.45;
3529
3530 let raw_boxes = include_bytes!(concat!(
3531 env!("CARGO_MANIFEST_DIR"),
3532 "/../../testdata/yolov8_boxes_116x8400.bin"
3533 ));
3534 let raw_boxes =
3535 unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3536 let boxes_i8: Tensor<i8> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3537 boxes_i8
3538 .map()
3539 .unwrap()
3540 .as_mut_slice()
3541 .copy_from_slice(raw_boxes);
3542 let boxes_i8 = boxes_i8.into();
3543
3544 let raw_protos = include_bytes!(concat!(
3545 env!("CARGO_MANIFEST_DIR"),
3546 "/../../testdata/yolov8_protos_160x160x32.bin"
3547 ));
3548 let raw_protos = unsafe {
3549 std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3550 };
3551 let protos_i8: Tensor<i8> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3552 protos_i8
3553 .map()
3554 .unwrap()
3555 .as_mut_slice()
3556 .copy_from_slice(raw_protos);
3557 let protos_i8 = protos_i8.into();
3558
3559 let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3560 let expected = real_data_expected_boxes();
3561 let mut output_boxes = Vec::with_capacity(50);
3562
3563 decoder
3564 .decode(&[&boxes_i8, &protos_i8], &mut output_boxes, &mut Vec::new())
3565 .unwrap();
3566
3567 assert_eq!(output_boxes.len(), 2);
3568 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3569 assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3570 }
3571
3572 #[test]
3573 fn test_decode_tensor_f32() {
3574 let score_threshold = 0.45;
3575 let iou_threshold = 0.45;
3576
3577 let quant_boxes = (0.021287762_f32, 31_i32);
3578 let quant_protos = (0.02491162_f32, -117_i32);
3579 let raw_boxes = include_bytes!(concat!(
3580 env!("CARGO_MANIFEST_DIR"),
3581 "/../../testdata/yolov8_boxes_116x8400.bin"
3582 ));
3583 let raw_boxes =
3584 unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3585 let mut raw_boxes_f32 = vec![0f32; raw_boxes.len()];
3586 dequantize_cpu(raw_boxes, quant_boxes.into(), &mut raw_boxes_f32);
3587 let boxes_f32: Tensor<f32> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3588 boxes_f32
3589 .map()
3590 .unwrap()
3591 .as_mut_slice()
3592 .copy_from_slice(&raw_boxes_f32);
3593 let boxes_f32 = boxes_f32.into();
3594
3595 let raw_protos = include_bytes!(concat!(
3596 env!("CARGO_MANIFEST_DIR"),
3597 "/../../testdata/yolov8_protos_160x160x32.bin"
3598 ));
3599 let raw_protos = unsafe {
3600 std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3601 };
3602 let mut raw_protos_f32 = vec![0f32; raw_protos.len()];
3603 dequantize_cpu(raw_protos, quant_protos.into(), &mut raw_protos_f32);
3604 let protos_f32: Tensor<f32> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3605 protos_f32
3606 .map()
3607 .unwrap()
3608 .as_mut_slice()
3609 .copy_from_slice(&raw_protos_f32);
3610 let protos_f32 = protos_f32.into();
3611
3612 let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3613
3614 let expected = real_data_expected_boxes();
3615 let mut output_boxes = Vec::with_capacity(50);
3616
3617 decoder
3618 .decode(
3619 &[&boxes_f32, &protos_f32],
3620 &mut output_boxes,
3621 &mut Vec::new(),
3622 )
3623 .unwrap();
3624
3625 assert_eq!(output_boxes.len(), 2);
3626 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3627 assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3628 }
3629
3630 #[test]
3631 fn test_decode_tensor_f64() {
3632 let score_threshold = 0.45;
3633 let iou_threshold = 0.45;
3634
3635 let quant_boxes = (0.021287762_f32, 31_i32);
3636 let quant_protos = (0.02491162_f32, -117_i32);
3637 let raw_boxes = include_bytes!(concat!(
3638 env!("CARGO_MANIFEST_DIR"),
3639 "/../../testdata/yolov8_boxes_116x8400.bin"
3640 ));
3641 let raw_boxes =
3642 unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3643 let mut raw_boxes_f64 = vec![0f64; raw_boxes.len()];
3644 dequantize_cpu(raw_boxes, quant_boxes.into(), &mut raw_boxes_f64);
3645 let boxes_f64: Tensor<f64> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3646 boxes_f64
3647 .map()
3648 .unwrap()
3649 .as_mut_slice()
3650 .copy_from_slice(&raw_boxes_f64);
3651 let boxes_f64 = boxes_f64.into();
3652
3653 let raw_protos = include_bytes!(concat!(
3654 env!("CARGO_MANIFEST_DIR"),
3655 "/../../testdata/yolov8_protos_160x160x32.bin"
3656 ));
3657 let raw_protos = unsafe {
3658 std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3659 };
3660 let mut raw_protos_f64 = vec![0f64; raw_protos.len()];
3661 dequantize_cpu(raw_protos, quant_protos.into(), &mut raw_protos_f64);
3662 let protos_f64: Tensor<f64> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3663 protos_f64
3664 .map()
3665 .unwrap()
3666 .as_mut_slice()
3667 .copy_from_slice(&raw_protos_f64);
3668 let protos_f64 = protos_f64.into();
3669
3670 let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3671
3672 let expected = real_data_expected_boxes();
3673 let mut output_boxes = Vec::with_capacity(50);
3674
3675 decoder
3676 .decode(
3677 &[&boxes_f64, &protos_f64],
3678 &mut output_boxes,
3679 &mut Vec::new(),
3680 )
3681 .unwrap();
3682
3683 assert_eq!(output_boxes.len(), 2);
3684 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3685 assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3686 }
3687
3688 #[test]
3689 fn test_decode_tensor_proto() {
3690 let score_threshold = 0.45;
3691 let iou_threshold = 0.45;
3692
3693 let raw_boxes = include_bytes!(concat!(
3694 env!("CARGO_MANIFEST_DIR"),
3695 "/../../testdata/yolov8_boxes_116x8400.bin"
3696 ));
3697 let raw_boxes =
3698 unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3699 let boxes_i8: Tensor<i8> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3700 boxes_i8
3701 .map()
3702 .unwrap()
3703 .as_mut_slice()
3704 .copy_from_slice(raw_boxes);
3705 let boxes_i8 = boxes_i8.into();
3706
3707 let raw_protos = include_bytes!(concat!(
3708 env!("CARGO_MANIFEST_DIR"),
3709 "/../../testdata/yolov8_protos_160x160x32.bin"
3710 ));
3711 let raw_protos = unsafe {
3712 std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3713 };
3714 let protos_i8: Tensor<i8> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3715 protos_i8
3716 .map()
3717 .unwrap()
3718 .as_mut_slice()
3719 .copy_from_slice(raw_protos);
3720 let protos_i8 = protos_i8.into();
3721
3722 let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3723
3724 let expected = real_data_expected_boxes();
3725 let mut output_boxes = Vec::with_capacity(50);
3726
3727 let proto_data = decoder
3728 .decode_proto(&[&boxes_i8, &protos_i8], &mut output_boxes)
3729 .unwrap();
3730
3731 assert_eq!(output_boxes.len(), 2);
3732 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3733 assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3734
3735 let proto_data = proto_data.expect("segmentation model should return ProtoData");
3736 assert_eq!(
3737 proto_data.mask_coefficients.len(),
3738 output_boxes.len(),
3739 "mask_coefficients count must match detection count"
3740 );
3741 for coeff in &proto_data.mask_coefficients {
3742 assert_eq!(
3743 coeff.len(),
3744 32,
3745 "each detection should have 32 mask coefficients"
3746 );
3747 }
3748 }
3749}
3750
3751#[cfg(feature = "tracker")]
3752#[cfg(test)]
3753#[cfg_attr(coverage_nightly, coverage(off))]
3754mod decoder_tracked_tests {
3755
3756 use edgefirst_tracker::{ByteTrackBuilder, Tracker};
3757 use ndarray::{array, s, Array, Array2, Array3, Array4, ArrayView, Axis, Dimension};
3758 use num_traits::{AsPrimitive, Float, PrimInt};
3759 use rand::{RngExt, SeedableRng};
3760 use rand_distr::StandardNormal;
3761
3762 use crate::{
3763 configs::{self, DimName},
3764 dequantize_ndarray, BoundingBox, DecoderBuilder, DetectBox, Quantization,
3765 };
3766
3767 pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
3768 input: ArrayView<F, D>,
3769 quant: Quantization,
3770 ) -> Array<T, D>
3771 where
3772 i32: num_traits::AsPrimitive<F>,
3773 f32: num_traits::AsPrimitive<F>,
3774 {
3775 let zero_point = quant.zero_point.as_();
3776 let div_scale = F::one() / quant.scale.as_();
3777 if zero_point != F::zero() {
3778 input.mapv(|d| (d * div_scale + zero_point).round().as_())
3779 } else {
3780 input.mapv(|d| (d * div_scale).round().as_())
3781 }
3782 }
3783
3784 #[test]
3785 fn test_decoder_tracked_random_jitter() {
3786 use crate::configs::{DecoderType, Nms};
3787 use crate::DecoderBuilder;
3788
3789 let score_threshold = 0.25;
3790 let iou_threshold = 0.1;
3791 let out = include_bytes!(concat!(
3792 env!("CARGO_MANIFEST_DIR"),
3793 "/../../testdata/yolov8s_80_classes.bin"
3794 ));
3795 let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
3796 let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
3797 let quant = (0.0040811873, -123).into();
3798
3799 let decoder = DecoderBuilder::default()
3800 .with_config_yolo_det(
3801 crate::configs::Detection {
3802 decoder: DecoderType::Ultralytics,
3803 shape: vec![1, 84, 8400],
3804 anchors: None,
3805 quantization: Some(quant),
3806 dshape: vec![
3807 (crate::configs::DimName::Batch, 1),
3808 (crate::configs::DimName::NumFeatures, 84),
3809 (crate::configs::DimName::NumBoxes, 8400),
3810 ],
3811 normalized: Some(true),
3812 },
3813 None,
3814 )
3815 .with_score_threshold(score_threshold)
3816 .with_iou_threshold(iou_threshold)
3817 .with_nms(Some(Nms::ClassAgnostic))
3818 .build()
3819 .unwrap();
3820 let mut rng = rand::rngs::StdRng::seed_from_u64(0xAB_BEEF); let expected_boxes = [
3823 crate::DetectBox {
3824 bbox: crate::BoundingBox {
3825 xmin: 0.5285137,
3826 ymin: 0.05305544,
3827 xmax: 0.87541467,
3828 ymax: 0.9998909,
3829 },
3830 score: 0.5591227,
3831 label: 0,
3832 },
3833 crate::DetectBox {
3834 bbox: crate::BoundingBox {
3835 xmin: 0.130598,
3836 ymin: 0.43260583,
3837 xmax: 0.35098213,
3838 ymax: 0.9958097,
3839 },
3840 score: 0.33057618,
3841 label: 75,
3842 },
3843 ];
3844
3845 let mut tracker = ByteTrackBuilder::new()
3846 .track_update(0.1)
3847 .track_high_conf(0.3)
3848 .build();
3849
3850 let mut output_boxes = Vec::with_capacity(50);
3851 let mut output_masks = Vec::with_capacity(50);
3852 let mut output_tracks = Vec::with_capacity(50);
3853
3854 decoder
3855 .decode_tracked_quantized(
3856 &mut tracker,
3857 0,
3858 &[out.view().into()],
3859 &mut output_boxes,
3860 &mut output_masks,
3861 &mut output_tracks,
3862 )
3863 .unwrap();
3864
3865 assert_eq!(output_boxes.len(), 2);
3866 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3867 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3868
3869 let mut last_boxes = output_boxes.clone();
3870
3871 for i in 1..=100 {
3872 let mut out = out.clone();
3873 let mut x_values = out.slice_mut(s![0, 0, ..]);
3875 for x in x_values.iter_mut() {
3876 let r: f32 = rng.sample(StandardNormal);
3877 let r = r.clamp(-2.0, 2.0) / 2.0;
3878 *x = x.saturating_add((r * 1e-2 / quant.0) as i8);
3879 }
3880
3881 let mut y_values = out.slice_mut(s![0, 1, ..]);
3882 for y in y_values.iter_mut() {
3883 let r: f32 = rng.sample(StandardNormal);
3884 let r = r.clamp(-2.0, 2.0) / 2.0;
3885 *y = y.saturating_add((r * 1e-2 / quant.0) as i8);
3886 }
3887
3888 decoder
3889 .decode_tracked_quantized(
3890 &mut tracker,
3891 100_000_000 * i / 3, &[out.view().into()],
3893 &mut output_boxes,
3894 &mut output_masks,
3895 &mut output_tracks,
3896 )
3897 .unwrap();
3898
3899 assert_eq!(output_boxes.len(), 2);
3900 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 5e-3));
3901 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 5e-3));
3902
3903 assert!(output_boxes[0].equal_within_delta(&last_boxes[0], 1e-3));
3904 assert!(output_boxes[1].equal_within_delta(&last_boxes[1], 1e-3));
3905 last_boxes = output_boxes.clone();
3906 }
3907 }
3908
3909 fn real_data_expected_boxes() -> [DetectBox; 2] {
3912 [
3913 DetectBox {
3914 bbox: BoundingBox {
3915 xmin: 0.08515105,
3916 ymin: 0.7131401,
3917 xmax: 0.29802868,
3918 ymax: 0.8195788,
3919 },
3920 score: 0.91537374,
3921 label: 23,
3922 },
3923 DetectBox {
3924 bbox: BoundingBox {
3925 xmin: 0.59605736,
3926 ymin: 0.25545314,
3927 xmax: 0.93666154,
3928 ymax: 0.72378385,
3929 },
3930 score: 0.91537374,
3931 label: 23,
3932 },
3933 ]
3934 }
3935
3936 fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
3937 [DetectBox {
3938 bbox: BoundingBox {
3939 xmin: 0.12549022,
3940 ymin: 0.12549022,
3941 xmax: 0.23529413,
3942 ymax: 0.23529413,
3943 },
3944 score: 0.98823535,
3945 label: 2,
3946 }]
3947 }
3948
3949 fn e2e_expected_boxes_float() -> [DetectBox; 1] {
3950 [DetectBox {
3951 bbox: BoundingBox {
3952 xmin: 0.1234,
3953 ymin: 0.1234,
3954 xmax: 0.2345,
3955 ymax: 0.2345,
3956 },
3957 score: 0.9876,
3958 label: 2,
3959 }]
3960 }
3961
3962 fn build_yolo_split_segdet_decoder(
3963 score_threshold: f32,
3964 iou_threshold: f32,
3965 quant_boxes: (f32, i32),
3966 quant_protos: (f32, i32),
3967 ) -> crate::Decoder {
3968 DecoderBuilder::default()
3969 .with_config_yolo_split_segdet(
3970 configs::Boxes {
3971 decoder: configs::DecoderType::Ultralytics,
3972 quantization: Some(quant_boxes.into()),
3973 shape: vec![1, 4, 8400],
3974 dshape: vec![
3975 (DimName::Batch, 1),
3976 (DimName::BoxCoords, 4),
3977 (DimName::NumBoxes, 8400),
3978 ],
3979 normalized: Some(true),
3980 },
3981 configs::Scores {
3982 decoder: configs::DecoderType::Ultralytics,
3983 quantization: Some(quant_boxes.into()),
3984 shape: vec![1, 80, 8400],
3985 dshape: vec![
3986 (DimName::Batch, 1),
3987 (DimName::NumClasses, 80),
3988 (DimName::NumBoxes, 8400),
3989 ],
3990 },
3991 configs::MaskCoefficients {
3992 decoder: configs::DecoderType::Ultralytics,
3993 quantization: Some(quant_boxes.into()),
3994 shape: vec![1, 32, 8400],
3995 dshape: vec![
3996 (DimName::Batch, 1),
3997 (DimName::NumProtos, 32),
3998 (DimName::NumBoxes, 8400),
3999 ],
4000 },
4001 configs::Protos {
4002 decoder: configs::DecoderType::Ultralytics,
4003 quantization: Some(quant_protos.into()),
4004 shape: vec![1, 160, 160, 32],
4005 dshape: vec![
4006 (DimName::Batch, 1),
4007 (DimName::Height, 160),
4008 (DimName::Width, 160),
4009 (DimName::NumProtos, 32),
4010 ],
4011 },
4012 )
4013 .with_score_threshold(score_threshold)
4014 .with_iou_threshold(iou_threshold)
4015 .build()
4016 .unwrap()
4017 }
4018
4019 fn build_yolov8_seg_decoder(score_threshold: f32, iou_threshold: f32) -> crate::Decoder {
4020 let config_yaml = include_str!(concat!(
4021 env!("CARGO_MANIFEST_DIR"),
4022 "/../../testdata/yolov8_seg.yaml"
4023 ));
4024 DecoderBuilder::default()
4025 .with_config_yaml_str(config_yaml.to_string())
4026 .with_score_threshold(score_threshold)
4027 .with_iou_threshold(iou_threshold)
4028 .build()
4029 .unwrap()
4030 }
4031
    // Generates one `#[test]` fn that runs the real YOLOv8 segmentation fixtures
    // (`yolov8_boxes_116x8400.bin`, `yolov8_protos_160x160x32.bin`) through a
    // tracked decoder for two frames.
    //
    // Usage: `real_data_tracked_test!(name, quantized | float, layout, output)`
    // - `quantized` feeds the raw int8 tensors to `decode_tracked_quantized*`.
    //   `float` first dequantizes them with `dequantize_ndarray`, then calls
    //   `decode_tracked_float*`.
    // - `layout`: `split` slices the 116-feature tensor into boxes (0..4),
    //   scores (4..84) and mask coefficients (84..) and uses
    //   `build_yolo_split_segdet_decoder`. Any other ident feeds the combined
    //   tensor to `build_yolov8_seg_decoder`.
    // - `output`: `proto` exercises the `*_proto` entry points. Any other ident
    //   exercises the mask-producing entry points.
    //
    // Frame 1 (timestamp 0) must yield the two expected detections. Frame 2
    // first drives every class score to its minimum (`i8::MIN` / `0.0`). The
    // two boxes must still be reported (presumably carried by the tracker), and
    // there must be no masks / empty mask coefficients.
    macro_rules! real_data_tracked_test {
        // Quantized arm: int8 inputs straight from the fixture files.
        ($name:ident, quantized, $layout:ident, $output:ident) => {
            #[test]
            fn $name() {
                // The layout/output selector idents are compared as strings.
                let is_split = matches!(stringify!($layout), "split");
                let is_proto = matches!(stringify!($output), "proto");

                let score_threshold = 0.45;
                let iou_threshold = 0.45;
                // Only consumed by the split decoder. The combined decoder reads
                // its quantization from yolov8_seg.yaml.
                let quant_boxes = (0.021287762_f32, 31_i32);
                let quant_protos = (0.02491162_f32, -117_i32);

                let raw_boxes = include_bytes!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/../../testdata/yolov8_boxes_116x8400.bin"
                ));
                // SAFETY: pointer and length come from a valid `&[u8]`, and u8/i8
                // share size and alignment with every bit pattern valid for i8.
                let raw_boxes = unsafe {
                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
                };
                let boxes_i8 =
                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();

                let raw_protos = include_bytes!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/../../testdata/yolov8_protos_160x160x32.bin"
                ));
                // SAFETY: same reasoning as for `raw_boxes` above.
                let raw_protos = unsafe {
                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
                };
                let protos_i8 =
                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
                        .unwrap();

                // Owned slices for the split layout: features 84.. are mask
                // coefficients, 4..84 class scores, ..4 box coordinates.
                // `boxes_combined` keeps the full tensor for the combined layout.
                let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
                let mut scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
                let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
                let mut boxes_combined = boxes_i8;

                let decoder = if is_split {
                    build_yolo_split_segdet_decoder(
                        score_threshold,
                        iou_threshold,
                        quant_boxes,
                        quant_protos,
                    )
                } else {
                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
                };

                let expected = real_data_expected_boxes();
                let mut tracker = ByteTrackBuilder::new()
                    .track_update(0.1)
                    .track_high_conf(0.7)
                    .build();
                let mut output_boxes = Vec::with_capacity(50);
                let mut output_tracks = Vec::with_capacity(50);

                if is_proto {
                    // Frame 1: the views are scoped to this block so the scores
                    // can be mutated afterwards.
                    {
                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
                            vec![
                                boxes_split.view().into(),
                                scores_split.view().into(),
                                mask_split.view().into(),
                                protos_i8.view().into(),
                            ]
                        } else {
                            vec![boxes_combined.view().into(), protos_i8.view().into()]
                        };
                        decoder
                            .decode_tracked_quantized_proto(
                                &mut tracker,
                                0,
                                &inputs,
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap();
                    }
                    assert_eq!(output_boxes.len(), 2);
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));

                    if is_split {
                        // Suppress every class score so frame 2 has no raw
                        // detections.
                        for score in scores_split.iter_mut() {
                            *score = i8::MIN;
                        }
                    } else {
                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
                            *score = i8::MIN;
                        }
                    }

                    // Frame 2, one tick later (presumably ns, ~1/30 s).
                    let proto_result = {
                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
                            vec![
                                boxes_split.view().into(),
                                scores_split.view().into(),
                                mask_split.view().into(),
                                protos_i8.view().into(),
                            ]
                        } else {
                            vec![boxes_combined.view().into(), protos_i8.view().into()]
                        };
                        decoder
                            .decode_tracked_quantized_proto(
                                &mut tracker,
                                100_000_000 / 3,
                                &inputs,
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap()
                    };
                    // NOTE(review): indexes without re-checking the length; a
                    // regression shows up as an index panic, not an assert.
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
                    // Proto data is still returned, but with no coefficients.
                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
                } else {
                    let mut output_masks = Vec::with_capacity(50);
                    // Frame 1.
                    {
                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
                            vec![
                                boxes_split.view().into(),
                                scores_split.view().into(),
                                mask_split.view().into(),
                                protos_i8.view().into(),
                            ]
                        } else {
                            vec![boxes_combined.view().into(), protos_i8.view().into()]
                        };
                        decoder
                            .decode_tracked_quantized(
                                &mut tracker,
                                0,
                                &inputs,
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();
                    }
                    assert_eq!(output_boxes.len(), 2);
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));

                    // Suppress every class score before frame 2.
                    if is_split {
                        for score in scores_split.iter_mut() {
                            *score = i8::MIN;
                        }
                    } else {
                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
                            *score = i8::MIN;
                        }
                    }

                    // Frame 2.
                    {
                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
                            vec![
                                boxes_split.view().into(),
                                scores_split.view().into(),
                                mask_split.view().into(),
                                protos_i8.view().into(),
                            ]
                        } else {
                            vec![boxes_combined.view().into(), protos_i8.view().into()]
                        };
                        decoder
                            .decode_tracked_quantized(
                                &mut tracker,
                                100_000_000 / 3,
                                &inputs,
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();
                    }
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
                    // No masks are produced for the frame without raw detections.
                    assert!(output_masks.is_empty());
                }
            }
        };
        // Float arm: the same fixtures, dequantized to float before decoding.
        ($name:ident, float, $layout:ident, $output:ident) => {
            #[test]
            fn $name() {
                let is_split = matches!(stringify!($layout), "split");
                let is_proto = matches!(stringify!($output), "proto");

                let score_threshold = 0.45;
                let iou_threshold = 0.45;
                // Used both to dequantize the fixtures and, for the split layout,
                // to configure the decoder.
                let quant_boxes = (0.021287762_f32, 31_i32);
                let quant_protos = (0.02491162_f32, -117_i32);

                let raw_boxes = include_bytes!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/../../testdata/yolov8_boxes_116x8400.bin"
                ));
                // SAFETY: pointer and length come from a valid `&[u8]`, and u8/i8
                // share size and alignment with every bit pattern valid for i8.
                let raw_boxes = unsafe {
                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
                };
                let boxes_i8 =
                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
                let boxes_f32 = dequantize_ndarray(boxes_i8.view(), quant_boxes.into());

                let raw_protos = include_bytes!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/../../testdata/yolov8_protos_160x160x32.bin"
                ));
                // SAFETY: same reasoning as for `raw_boxes` above.
                let raw_protos = unsafe {
                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
                };
                let protos_i8 =
                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
                        .unwrap();
                let protos_f32 = dequantize_ndarray(protos_i8.view(), quant_protos.into());

                // Owned slices for the split layout (mask coefficients 84..,
                // scores 4..84, boxes ..4). `boxes_combined` keeps the full tensor.
                let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
                let mut scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
                let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
                let mut boxes_combined = boxes_f32;

                let decoder = if is_split {
                    build_yolo_split_segdet_decoder(
                        score_threshold,
                        iou_threshold,
                        quant_boxes,
                        quant_protos,
                    )
                } else {
                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
                };

                let expected = real_data_expected_boxes();
                let mut tracker = ByteTrackBuilder::new()
                    .track_update(0.1)
                    .track_high_conf(0.7)
                    .build();
                let mut output_boxes = Vec::with_capacity(50);
                let mut output_tracks = Vec::with_capacity(50);

                if is_proto {
                    // Frame 1.
                    {
                        let inputs = if is_split {
                            vec![
                                boxes_split.view().into_dyn(),
                                scores_split.view().into_dyn(),
                                mask_split.view().into_dyn(),
                                protos_f32.view().into_dyn(),
                            ]
                        } else {
                            vec![
                                boxes_combined.view().into_dyn(),
                                protos_f32.view().into_dyn(),
                            ]
                        };
                        decoder
                            .decode_tracked_float_proto(
                                &mut tracker,
                                0,
                                &inputs,
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap();
                    }
                    assert_eq!(output_boxes.len(), 2);
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));

                    // Zero every class score before frame 2.
                    if is_split {
                        for score in scores_split.iter_mut() {
                            *score = 0.0;
                        }
                    } else {
                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
                            *score = 0.0;
                        }
                    }

                    // Frame 2.
                    let proto_result = {
                        let inputs = if is_split {
                            vec![
                                boxes_split.view().into_dyn(),
                                scores_split.view().into_dyn(),
                                mask_split.view().into_dyn(),
                                protos_f32.view().into_dyn(),
                            ]
                        } else {
                            vec![
                                boxes_combined.view().into_dyn(),
                                protos_f32.view().into_dyn(),
                            ]
                        };
                        decoder
                            .decode_tracked_float_proto(
                                &mut tracker,
                                100_000_000 / 3,
                                &inputs,
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap()
                    };
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
                } else {
                    let mut output_masks = Vec::with_capacity(50);
                    // Frame 1.
                    {
                        let inputs = if is_split {
                            vec![
                                boxes_split.view().into_dyn(),
                                scores_split.view().into_dyn(),
                                mask_split.view().into_dyn(),
                                protos_f32.view().into_dyn(),
                            ]
                        } else {
                            vec![
                                boxes_combined.view().into_dyn(),
                                protos_f32.view().into_dyn(),
                            ]
                        };
                        decoder
                            .decode_tracked_float(
                                &mut tracker,
                                0,
                                &inputs,
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();
                    }
                    assert_eq!(output_boxes.len(), 2);
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));

                    // Zero every class score before frame 2.
                    if is_split {
                        for score in scores_split.iter_mut() {
                            *score = 0.0;
                        }
                    } else {
                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
                            *score = 0.0;
                        }
                    }

                    // Frame 2.
                    {
                        let inputs = if is_split {
                            vec![
                                boxes_split.view().into_dyn(),
                                scores_split.view().into_dyn(),
                                mask_split.view().into_dyn(),
                                protos_f32.view().into_dyn(),
                            ]
                        } else {
                            vec![
                                boxes_combined.view().into_dyn(),
                                protos_f32.view().into_dyn(),
                            ]
                        };
                        decoder
                            .decode_tracked_float(
                                &mut tracker,
                                100_000_000 / 3,
                                &inputs,
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();
                    }
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
                    assert!(output_masks.is_empty());
                }
            }
        };
    }
4422
4423 real_data_tracked_test!(test_decoder_tracked_segdet, quantized, combined, masks);
4424 real_data_tracked_test!(test_decoder_tracked_segdet_float, float, combined, masks);
4425 real_data_tracked_test!(
4426 test_decoder_tracked_segdet_proto,
4427 quantized,
4428 combined,
4429 proto
4430 );
4431 real_data_tracked_test!(
4432 test_decoder_tracked_segdet_proto_float,
4433 float,
4434 combined,
4435 proto
4436 );
4437 real_data_tracked_test!(test_decoder_tracked_segdet_split, quantized, split, masks);
4438 real_data_tracked_test!(test_decoder_tracked_segdet_split_float, float, split, masks);
4439 real_data_tracked_test!(
4440 test_decoder_tracked_segdet_split_proto,
4441 quantized,
4442 split,
4443 proto
4444 );
4445 real_data_tracked_test!(
4446 test_decoder_tracked_segdet_split_proto_float,
4447 float,
4448 split,
4449 proto
4450 );
4451
    /// YAML config for the `yolo26` end-to-end tests with a combined detection
    /// output.
    ///
    /// - `detection` `[1, 10, 38]`: 10 candidate boxes x 38 features, quantized
    ///   with scale 2/255 and zero point 0. The e2e tests build the 38 features as
    ///   4 box coords + 1 score + 1 class + 32 mask coefficients.
    /// - `protos` `[1, 160, 160, 32]`: scale 1/255, zero point 128.
    ///
    /// NOTE(review): the quantized e2e tests build their protos with zero point
    /// 0. The mismatch only matters if decoded mask values are asserted.
    const E2E_COMBINED_CONFIG: &str = "
decoder_version: yolo26
outputs:
  - type: detection
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 38]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_features, 38]
    normalized: true
  - type: protos
    decoder: ultralytics
    quantization: [0.0039215686274509803921568627451, 128]
    shape: [1, 160, 160, 32]
    dshape:
      - [batch, 1]
      - [height, 160]
      - [width, 160]
      - [num_protos, 32]
";
4479
    /// YAML config for the `yolo26` end-to-end tests with split outputs.
    ///
    /// Outputs are `boxes` `[1, 10, 4]`, `scores` `[1, 10, 1]`, `classes`
    /// `[1, 10, 1]`, `mask_coefficients` `[1, 10, 32]` and `protos`
    /// `[1, 160, 160, 32]`. All non-proto outputs share scale 2/255 with zero
    /// point 0. Protos use scale 1/255 with zero point 128.
    const E2E_SPLIT_CONFIG: &str = "
decoder_version: yolo26
outputs:
  - type: boxes
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 4]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [box_coords, 4]
    normalized: true
  - type: scores
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 1]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_classes, 1]
  - type: classes
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 1]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_classes, 1]
  - type: mask_coefficients
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 32]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_protos, 32]
  - type: protos
    decoder: ultralytics
    quantization: [0.0039215686274509803921568627451, 128]
    shape: [1, 160, 160, 32]
    dshape:
      - [batch, 1]
      - [height, 160]
      - [width, 160]
      - [num_protos, 32]
";
4526
4527 macro_rules! e2e_tracked_test {
4528 ($name:ident, quantized, $layout:ident, $output:ident) => {
4529 #[test]
4530 fn $name() {
4531 let is_split = matches!(stringify!($layout), "split");
4532 let is_proto = matches!(stringify!($output), "proto");
4533
4534 let score_threshold = 0.45;
4535 let iou_threshold = 0.45;
4536
4537 let mut boxes = Array2::zeros((10, 4));
4538 let mut scores = Array2::zeros((10, 1));
4539 let mut classes = Array2::zeros((10, 1));
4540 let mask = Array2::zeros((10, 32));
4541 let protos = Array3::<f64>::zeros((160, 160, 32));
4542 let protos = protos.insert_axis(Axis(0));
4543 let protos_quant = (1.0 / 255.0, 0.0);
4544 let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
4545
4546 boxes
4547 .slice_mut(s![0, ..])
4548 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4549 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4550 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4551
4552 let detect_quant = (2.0 / 255.0, 0.0);
4553
4554 let decoder = if is_split {
4555 DecoderBuilder::default()
4556 .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
4557 .with_score_threshold(score_threshold)
4558 .with_iou_threshold(iou_threshold)
4559 .build()
4560 .unwrap()
4561 } else {
4562 DecoderBuilder::default()
4563 .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
4564 .with_score_threshold(score_threshold)
4565 .with_iou_threshold(iou_threshold)
4566 .build()
4567 .unwrap()
4568 };
4569
4570 let expected = e2e_expected_boxes_quant();
4571 let mut tracker = ByteTrackBuilder::new()
4572 .track_update(0.1)
4573 .track_high_conf(0.7)
4574 .build();
4575 let mut output_boxes = Vec::with_capacity(50);
4576 let mut output_tracks = Vec::with_capacity(50);
4577
4578 if is_split {
4579 let boxes = boxes.insert_axis(Axis(0));
4580 let scores = scores.insert_axis(Axis(0));
4581 let classes = classes.insert_axis(Axis(0));
4582 let mask = mask.insert_axis(Axis(0));
4583
4584 let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
4585 let mut scores: Array3<u8> =
4586 quantize_ndarray(scores.view(), detect_quant.into());
4587 let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
4588 let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
4589
4590 if is_proto {
4591 {
4592 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
4593 boxes.view().into(),
4594 scores.view().into(),
4595 classes.view().into(),
4596 mask.view().into(),
4597 protos.view().into(),
4598 ];
4599 decoder
4600 .decode_tracked_quantized_proto(
4601 &mut tracker,
4602 0,
4603 &inputs,
4604 &mut output_boxes,
4605 &mut output_tracks,
4606 )
4607 .unwrap();
4608 }
4609 assert_eq!(output_boxes.len(), 1);
4610 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4611
4612 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4613 *score = u8::MIN;
4614 }
4615 let proto_result = {
4616 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
4617 boxes.view().into(),
4618 scores.view().into(),
4619 classes.view().into(),
4620 mask.view().into(),
4621 protos.view().into(),
4622 ];
4623 decoder
4624 .decode_tracked_quantized_proto(
4625 &mut tracker,
4626 100_000_000 / 3,
4627 &inputs,
4628 &mut output_boxes,
4629 &mut output_tracks,
4630 )
4631 .unwrap()
4632 };
4633 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4634 assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
4635 } else {
4636 let mut output_masks = Vec::with_capacity(50);
4637 {
4638 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
4639 boxes.view().into(),
4640 scores.view().into(),
4641 classes.view().into(),
4642 mask.view().into(),
4643 protos.view().into(),
4644 ];
4645 decoder
4646 .decode_tracked_quantized(
4647 &mut tracker,
4648 0,
4649 &inputs,
4650 &mut output_boxes,
4651 &mut output_masks,
4652 &mut output_tracks,
4653 )
4654 .unwrap();
4655 }
4656 assert_eq!(output_boxes.len(), 1);
4657 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4658
4659 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4660 *score = u8::MIN;
4661 }
4662 {
4663 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
4664 boxes.view().into(),
4665 scores.view().into(),
4666 classes.view().into(),
4667 mask.view().into(),
4668 protos.view().into(),
4669 ];
4670 decoder
4671 .decode_tracked_quantized(
4672 &mut tracker,
4673 100_000_000 / 3,
4674 &inputs,
4675 &mut output_boxes,
4676 &mut output_masks,
4677 &mut output_tracks,
4678 )
4679 .unwrap();
4680 }
4681 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4682 assert!(output_masks.is_empty());
4683 }
4684 } else {
4685 let detect = ndarray::concatenate![
4687 Axis(1),
4688 boxes.view(),
4689 scores.view(),
4690 classes.view(),
4691 mask.view()
4692 ];
4693 let detect = detect.insert_axis(Axis(0));
4694 assert_eq!(detect.shape(), &[1, 10, 38]);
4695 let mut detect: Array3<u8> =
4696 quantize_ndarray(detect.view(), detect_quant.into());
4697
4698 if is_proto {
4699 {
4700 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
4701 vec![detect.view().into(), protos.view().into()];
4702 decoder
4703 .decode_tracked_quantized_proto(
4704 &mut tracker,
4705 0,
4706 &inputs,
4707 &mut output_boxes,
4708 &mut output_tracks,
4709 )
4710 .unwrap();
4711 }
4712 assert_eq!(output_boxes.len(), 1);
4713 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4714
4715 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4716 *score = u8::MIN;
4717 }
4718 let proto_result = {
4719 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
4720 vec![detect.view().into(), protos.view().into()];
4721 decoder
4722 .decode_tracked_quantized_proto(
4723 &mut tracker,
4724 100_000_000 / 3,
4725 &inputs,
4726 &mut output_boxes,
4727 &mut output_tracks,
4728 )
4729 .unwrap()
4730 };
4731 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4732 assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
4733 } else {
4734 let mut output_masks = Vec::with_capacity(50);
4735 {
4736 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
4737 vec![detect.view().into(), protos.view().into()];
4738 decoder
4739 .decode_tracked_quantized(
4740 &mut tracker,
4741 0,
4742 &inputs,
4743 &mut output_boxes,
4744 &mut output_masks,
4745 &mut output_tracks,
4746 )
4747 .unwrap();
4748 }
4749 assert_eq!(output_boxes.len(), 1);
4750 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4751
4752 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4753 *score = u8::MIN;
4754 }
4755 {
4756 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
4757 vec![detect.view().into(), protos.view().into()];
4758 decoder
4759 .decode_tracked_quantized(
4760 &mut tracker,
4761 100_000_000 / 3,
4762 &inputs,
4763 &mut output_boxes,
4764 &mut output_masks,
4765 &mut output_tracks,
4766 )
4767 .unwrap();
4768 }
4769 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4770 assert!(output_masks.is_empty());
4771 }
4772 }
4773 }
4774 };
4775 ($name:ident, float, $layout:ident, $output:ident) => {
4776 #[test]
4777 fn $name() {
4778 let is_split = matches!(stringify!($layout), "split");
4779 let is_proto = matches!(stringify!($output), "proto");
4780
4781 let score_threshold = 0.45;
4782 let iou_threshold = 0.45;
4783
4784 let mut boxes = Array2::zeros((10, 4));
4785 let mut scores = Array2::zeros((10, 1));
4786 let mut classes = Array2::zeros((10, 1));
4787 let mask: Array2<f64> = Array2::zeros((10, 32));
4788 let protos = Array3::<f64>::zeros((160, 160, 32));
4789 let protos = protos.insert_axis(Axis(0));
4790
4791 boxes
4792 .slice_mut(s![0, ..])
4793 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4794 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4795 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4796
4797 let decoder = if is_split {
4798 DecoderBuilder::default()
4799 .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
4800 .with_score_threshold(score_threshold)
4801 .with_iou_threshold(iou_threshold)
4802 .build()
4803 .unwrap()
4804 } else {
4805 DecoderBuilder::default()
4806 .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
4807 .with_score_threshold(score_threshold)
4808 .with_iou_threshold(iou_threshold)
4809 .build()
4810 .unwrap()
4811 };
4812
4813 let expected = e2e_expected_boxes_float();
4814 let mut tracker = ByteTrackBuilder::new()
4815 .track_update(0.1)
4816 .track_high_conf(0.7)
4817 .build();
4818 let mut output_boxes = Vec::with_capacity(50);
4819 let mut output_tracks = Vec::with_capacity(50);
4820
4821 if is_split {
4822 let boxes = boxes.insert_axis(Axis(0));
4823 let mut scores = scores.insert_axis(Axis(0));
4824 let classes = classes.insert_axis(Axis(0));
4825 let mask = mask.insert_axis(Axis(0));
4826
4827 if is_proto {
4828 {
4829 let inputs = vec![
4830 boxes.view().into_dyn(),
4831 scores.view().into_dyn(),
4832 classes.view().into_dyn(),
4833 mask.view().into_dyn(),
4834 protos.view().into_dyn(),
4835 ];
4836 decoder
4837 .decode_tracked_float_proto(
4838 &mut tracker,
4839 0,
4840 &inputs,
4841 &mut output_boxes,
4842 &mut output_tracks,
4843 )
4844 .unwrap();
4845 }
4846 assert_eq!(output_boxes.len(), 1);
4847 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4848
4849 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4850 *score = 0.0;
4851 }
4852 let proto_result = {
4853 let inputs = vec![
4854 boxes.view().into_dyn(),
4855 scores.view().into_dyn(),
4856 classes.view().into_dyn(),
4857 mask.view().into_dyn(),
4858 protos.view().into_dyn(),
4859 ];
4860 decoder
4861 .decode_tracked_float_proto(
4862 &mut tracker,
4863 100_000_000 / 3,
4864 &inputs,
4865 &mut output_boxes,
4866 &mut output_tracks,
4867 )
4868 .unwrap()
4869 };
4870 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4871 assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
4872 } else {
4873 let mut output_masks = Vec::with_capacity(50);
4874 {
4875 let inputs = vec![
4876 boxes.view().into_dyn(),
4877 scores.view().into_dyn(),
4878 classes.view().into_dyn(),
4879 mask.view().into_dyn(),
4880 protos.view().into_dyn(),
4881 ];
4882 decoder
4883 .decode_tracked_float(
4884 &mut tracker,
4885 0,
4886 &inputs,
4887 &mut output_boxes,
4888 &mut output_masks,
4889 &mut output_tracks,
4890 )
4891 .unwrap();
4892 }
4893 assert_eq!(output_boxes.len(), 1);
4894 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4895
4896 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4897 *score = 0.0;
4898 }
4899 {
4900 let inputs = vec![
4901 boxes.view().into_dyn(),
4902 scores.view().into_dyn(),
4903 classes.view().into_dyn(),
4904 mask.view().into_dyn(),
4905 protos.view().into_dyn(),
4906 ];
4907 decoder
4908 .decode_tracked_float(
4909 &mut tracker,
4910 100_000_000 / 3,
4911 &inputs,
4912 &mut output_boxes,
4913 &mut output_masks,
4914 &mut output_tracks,
4915 )
4916 .unwrap();
4917 }
4918 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4919 assert!(output_masks.is_empty());
4920 }
4921 } else {
4922 let detect = ndarray::concatenate![
4924 Axis(1),
4925 boxes.view(),
4926 scores.view(),
4927 classes.view(),
4928 mask.view()
4929 ];
4930 let mut detect = detect.insert_axis(Axis(0));
4931 assert_eq!(detect.shape(), &[1, 10, 38]);
4932
4933 if is_proto {
4934 {
4935 let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4936 decoder
4937 .decode_tracked_float_proto(
4938 &mut tracker,
4939 0,
4940 &inputs,
4941 &mut output_boxes,
4942 &mut output_tracks,
4943 )
4944 .unwrap();
4945 }
4946 assert_eq!(output_boxes.len(), 1);
4947 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4948
4949 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4950 *score = 0.0;
4951 }
4952 let proto_result = {
4953 let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4954 decoder
4955 .decode_tracked_float_proto(
4956 &mut tracker,
4957 100_000_000 / 3,
4958 &inputs,
4959 &mut output_boxes,
4960 &mut output_tracks,
4961 )
4962 .unwrap()
4963 };
4964 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4965 assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
4966 } else {
4967 let mut output_masks = Vec::with_capacity(50);
4968 {
4969 let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4970 decoder
4971 .decode_tracked_float(
4972 &mut tracker,
4973 0,
4974 &inputs,
4975 &mut output_boxes,
4976 &mut output_masks,
4977 &mut output_tracks,
4978 )
4979 .unwrap();
4980 }
4981 assert_eq!(output_boxes.len(), 1);
4982 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4983
4984 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4985 *score = 0.0;
4986 }
4987 {
4988 let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4989 decoder
4990 .decode_tracked_float(
4991 &mut tracker,
4992 100_000_000 / 3,
4993 &inputs,
4994 &mut output_boxes,
4995 &mut output_masks,
4996 &mut output_tracks,
4997 )
4998 .unwrap();
4999 }
5000 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5001 assert!(output_masks.is_empty());
5002 }
5003 }
5004 }
5005 };
5006 }
5007
5008 e2e_tracked_test!(
5009 test_decoder_tracked_end_to_end_segdet,
5010 quantized,
5011 combined,
5012 masks
5013 );
5014 e2e_tracked_test!(
5015 test_decoder_tracked_end_to_end_segdet_float,
5016 float,
5017 combined,
5018 masks
5019 );
5020 e2e_tracked_test!(
5021 test_decoder_tracked_end_to_end_segdet_proto,
5022 quantized,
5023 combined,
5024 proto
5025 );
5026 e2e_tracked_test!(
5027 test_decoder_tracked_end_to_end_segdet_proto_float,
5028 float,
5029 combined,
5030 proto
5031 );
5032 e2e_tracked_test!(
5033 test_decoder_tracked_end_to_end_segdet_split,
5034 quantized,
5035 split,
5036 masks
5037 );
5038 e2e_tracked_test!(
5039 test_decoder_tracked_end_to_end_segdet_split_float,
5040 float,
5041 split,
5042 masks
5043 );
5044 e2e_tracked_test!(
5045 test_decoder_tracked_end_to_end_segdet_split_proto,
5046 quantized,
5047 split,
5048 proto
5049 );
5050 e2e_tracked_test!(
5051 test_decoder_tracked_end_to_end_segdet_split_proto_float,
5052 float,
5053 split,
5054 proto
5055 );
5056
    /// Generates an end-to-end tracked-decode test whose inputs are handed to the
    /// decoder as `edgefirst_tensor::TensorDyn` values (through `decode_tracked`
    /// and `decode_proto_tracked`) rather than as ndarray views.
    ///
    /// Invocation: `e2e_tracked_tensor_test!(name, quantized | float, layout, output)`.
    /// - `quantized` builds `u8` tensors: detection-side arrays are quantized with
    ///   `(2/255, 0)` and prototypes with `(1/255, 0)`. `float` builds `f64` tensors.
    /// - `layout` equal to `split` passes separate boxes / scores / classes / mask
    ///   tensors. Any other identifier passes one concatenated `[1, 10, 38]`
    ///   detection tensor.
    /// - `output` equal to `proto` exercises `decode_proto_tracked`. Any other
    ///   identifier exercises `decode_tracked` with a mask output vector.
    ///
    /// Every generated test first decodes a frame holding a single high-score box
    /// and checks it. It then zeroes all scores and decodes again at a later
    /// timestamp (`100_000_000 / 3`). Finally it checks that the first output box
    /// still matches while no masks / mask coefficients are returned.
    macro_rules! e2e_tracked_tensor_test {
        ($name:ident, quantized, $layout:ident, $output:ident) => {
            #[test]
            fn $name() {
                use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};

                let is_split = matches!(stringify!($layout), "split");
                let is_proto = matches!(stringify!($output), "proto");

                let score_threshold = 0.45;
                let iou_threshold = 0.45;

                // Ten candidate rows; only row 0 is populated below, the others
                // stay all-zero.
                let mut boxes = Array2::zeros((10, 4));
                let mut scores = Array2::zeros((10, 1));
                let mut classes = Array2::zeros((10, 1));
                let mask = Array2::zeros((10, 32));
                // All-zero prototypes with a leading batch axis, quantized to u8.
                let protos_f64 = Array3::<f64>::zeros((160, 160, 32));
                let protos_f64 = protos_f64.insert_axis(Axis(0));
                let protos_quant = (1.0 / 255.0, 0.0);
                let protos_u8: Array4<u8> =
                    quantize_ndarray(protos_f64.view(), protos_quant.into());

                boxes
                    .slice_mut(s![0, ..])
                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
                classes.slice_mut(s![0, ..]).assign(&array![2.0]);

                // Quantization parameters shared by every detection-side tensor.
                let detect_quant = (2.0 / 255.0, 0.0);

                let decoder = if is_split {
                    DecoderBuilder::default()
                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
                        .with_score_threshold(score_threshold)
                        .with_iou_threshold(iou_threshold)
                        .build()
                        .unwrap()
                } else {
                    DecoderBuilder::default()
                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
                        .with_score_threshold(score_threshold)
                        .with_iou_threshold(iou_threshold)
                        .build()
                        .unwrap()
                };

                // Creates a new u8 tensor of `shape`, copies `data` into the front
                // of its mapped buffer and converts it to `TensorDyn`. Because the
                // data is copied, later edits to the source array need a fresh
                // tensor.
                let make_u8_tensor =
                    |shape: &[usize], data: &[u8]| -> edgefirst_tensor::TensorDyn {
                        let t = Tensor::<u8>::new(shape, None, None).unwrap();
                        t.map().unwrap().as_mut_slice()[..data.len()].copy_from_slice(data);
                        t.into()
                    };

                let expected = e2e_expected_boxes_quant();
                let mut tracker = ByteTrackBuilder::new()
                    .track_update(0.1)
                    .track_high_conf(0.7)
                    .build();
                let mut output_boxes = Vec::with_capacity(50);
                let mut output_tracks = Vec::with_capacity(50);

                let protos_td = make_u8_tensor(protos_u8.shape(), protos_u8.as_slice().unwrap());

                if is_split {
                    // Split layout: give each component its own batch axis and
                    // quantize it independently.
                    let boxes = boxes.insert_axis(Axis(0));
                    let scores = scores.insert_axis(Axis(0));
                    let classes = classes.insert_axis(Axis(0));
                    let mask = mask.insert_axis(Axis(0));

                    let boxes_q: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
                    let mut scores_q: Array3<u8> =
                        quantize_ndarray(scores.view(), detect_quant.into());
                    let classes_q: Array3<u8> =
                        quantize_ndarray(classes.view(), detect_quant.into());
                    let mask_q: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());

                    let boxes_td = make_u8_tensor(boxes_q.shape(), boxes_q.as_slice().unwrap());
                    let classes_td =
                        make_u8_tensor(classes_q.shape(), classes_q.as_slice().unwrap());
                    let mask_td = make_u8_tensor(mask_q.shape(), mask_q.as_slice().unwrap());

                    if is_proto {
                        let scores_td =
                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
                        decoder
                            .decode_proto_tracked(
                                &mut tracker,
                                0,
                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap();

                        assert_eq!(output_boxes.len(), 1);
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));

                        // Second frame: every score at the lowest u8 value. The
                        // scores tensor is rebuilt because `make_u8_tensor` copied
                        // the old data.
                        for score in scores_q.slice_mut(s![.., .., ..]).iter_mut() {
                            *score = u8::MIN;
                        }
                        let scores_td =
                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
                        let proto_result = decoder
                            .decode_proto_tracked(
                                &mut tracker,
                                100_000_000 / 3,
                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap();
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
                    } else {
                        let scores_td =
                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
                        let mut output_masks = Vec::with_capacity(50);
                        decoder
                            .decode_tracked(
                                &mut tracker,
                                0,
                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();

                        assert_eq!(output_boxes.len(), 1);
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));

                        // Second frame: every score at the lowest u8 value (the
                        // scores tensor is rebuilt from the edited array).
                        for score in scores_q.slice_mut(s![.., .., ..]).iter_mut() {
                            *score = u8::MIN;
                        }
                        let scores_td =
                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
                        decoder
                            .decode_tracked(
                                &mut tracker,
                                100_000_000 / 3,
                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                        assert!(output_masks.is_empty());
                    }
                } else {
                    // Combined layout: columns are [box(4), score, class, mask(32)],
                    // i.e. 38 features, with the score in column 4.
                    let detect = ndarray::concatenate![
                        Axis(1),
                        boxes.view(),
                        scores.view(),
                        classes.view(),
                        mask.view()
                    ];
                    let detect = detect.insert_axis(Axis(0));
                    assert_eq!(detect.shape(), &[1, 10, 38]);
                    // Re-collect into a freshly built array, presumably to guarantee
                    // a contiguous standard-layout buffer for `as_slice()` below.
                    let detect =
                        Array3::from_shape_vec(detect.raw_dim(), detect.iter().copied().collect())
                            .unwrap();
                    let mut detect_q: Array3<u8> =
                        quantize_ndarray(detect.view(), detect_quant.into());

                    if is_proto {
                        let detect_td =
                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
                        decoder
                            .decode_proto_tracked(
                                &mut tracker,
                                0,
                                &[&detect_td, &protos_td],
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap();

                        assert_eq!(output_boxes.len(), 1);
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));

                        // Second frame: zero the score column and rebuild the tensor.
                        for score in detect_q.slice_mut(s![.., .., 4]).iter_mut() {
                            *score = u8::MIN;
                        }
                        let detect_td =
                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
                        let proto_result = decoder
                            .decode_proto_tracked(
                                &mut tracker,
                                100_000_000 / 3,
                                &[&detect_td, &protos_td],
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap();
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
                    } else {
                        let detect_td =
                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
                        let mut output_masks = Vec::with_capacity(50);
                        decoder
                            .decode_tracked(
                                &mut tracker,
                                0,
                                &[&detect_td, &protos_td],
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();

                        assert_eq!(output_boxes.len(), 1);
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));

                        // Second frame: zero the score column and rebuild the tensor.
                        for score in detect_q.slice_mut(s![.., .., 4]).iter_mut() {
                            *score = u8::MIN;
                        }
                        let detect_td =
                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
                        decoder
                            .decode_tracked(
                                &mut tracker,
                                100_000_000 / 3,
                                &[&detect_td, &protos_td],
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                        assert!(output_masks.is_empty());
                    }
                }
            }
        };
        ($name:ident, float, $layout:ident, $output:ident) => {
            #[test]
            fn $name() {
                use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};

                let is_split = matches!(stringify!($layout), "split");
                let is_proto = matches!(stringify!($output), "proto");

                let score_threshold = 0.45;
                let iou_threshold = 0.45;

                // Ten candidate rows; only row 0 is populated below.
                let mut boxes = Array2::zeros((10, 4));
                let mut scores = Array2::zeros((10, 1));
                let mut classes = Array2::zeros((10, 1));
                let mask: Array2<f64> = Array2::zeros((10, 32));
                // All-zero f64 prototypes with a leading batch axis.
                let protos = Array3::<f64>::zeros((160, 160, 32));
                let protos = protos.insert_axis(Axis(0));

                boxes
                    .slice_mut(s![0, ..])
                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
                classes.slice_mut(s![0, ..]).assign(&array![2.0]);

                let decoder = if is_split {
                    DecoderBuilder::default()
                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
                        .with_score_threshold(score_threshold)
                        .with_iou_threshold(iou_threshold)
                        .build()
                        .unwrap()
                } else {
                    DecoderBuilder::default()
                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
                        .with_score_threshold(score_threshold)
                        .with_iou_threshold(iou_threshold)
                        .build()
                        .unwrap()
                };

                // Creates a new f64 tensor of `shape`, copies `data` into the front
                // of its mapped buffer and converts it to `TensorDyn`. Because the
                // data is copied, later edits to the source array need a fresh
                // tensor.
                let make_f64_tensor =
                    |shape: &[usize], data: &[f64]| -> edgefirst_tensor::TensorDyn {
                        let t = Tensor::<f64>::new(shape, None, None).unwrap();
                        t.map().unwrap().as_mut_slice()[..data.len()].copy_from_slice(data);
                        t.into()
                    };

                let expected = e2e_expected_boxes_float();
                let mut tracker = ByteTrackBuilder::new()
                    .track_update(0.1)
                    .track_high_conf(0.7)
                    .build();
                let mut output_boxes = Vec::with_capacity(50);
                let mut output_tracks = Vec::with_capacity(50);

                let protos_td = make_f64_tensor(protos.shape(), protos.as_slice().unwrap());

                if is_split {
                    // Split layout: one batch-expanded tensor per component.
                    let boxes = boxes.insert_axis(Axis(0));
                    let mut scores = scores.insert_axis(Axis(0));
                    let classes = classes.insert_axis(Axis(0));
                    let mask = mask.insert_axis(Axis(0));

                    let boxes_td = make_f64_tensor(boxes.shape(), boxes.as_slice().unwrap());
                    let classes_td = make_f64_tensor(classes.shape(), classes.as_slice().unwrap());
                    let mask_td = make_f64_tensor(mask.shape(), mask.as_slice().unwrap());

                    if is_proto {
                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
                        decoder
                            .decode_proto_tracked(
                                &mut tracker,
                                0,
                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap();

                        assert_eq!(output_boxes.len(), 1);
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));

                        // Second frame: all scores zeroed; rebuild the scores tensor.
                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
                            *score = 0.0;
                        }
                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
                        let proto_result = decoder
                            .decode_proto_tracked(
                                &mut tracker,
                                100_000_000 / 3,
                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap();
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
                    } else {
                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
                        let mut output_masks = Vec::with_capacity(50);
                        decoder
                            .decode_tracked(
                                &mut tracker,
                                0,
                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();

                        assert_eq!(output_boxes.len(), 1);
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));

                        // Second frame: all scores zeroed; rebuild the scores tensor.
                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
                            *score = 0.0;
                        }
                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
                        decoder
                            .decode_tracked(
                                &mut tracker,
                                100_000_000 / 3,
                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                        assert!(output_masks.is_empty());
                    }
                } else {
                    // Combined layout: columns are [box(4), score, class, mask(32)],
                    // with the score in column 4.
                    let detect = ndarray::concatenate![
                        Axis(1),
                        boxes.view(),
                        scores.view(),
                        classes.view(),
                        mask.view()
                    ];
                    let detect = detect.insert_axis(Axis(0));
                    assert_eq!(detect.shape(), &[1, 10, 38]);
                    // Re-collect into a freshly built array, presumably to guarantee
                    // a contiguous standard-layout buffer for `as_slice()` below.
                    let mut detect =
                        Array3::from_shape_vec(detect.raw_dim(), detect.iter().copied().collect())
                            .unwrap();

                    if is_proto {
                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
                        decoder
                            .decode_proto_tracked(
                                &mut tracker,
                                0,
                                &[&detect_td, &protos_td],
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap();

                        assert_eq!(output_boxes.len(), 1);
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));

                        // Second frame: zero the score column and rebuild the tensor.
                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
                            *score = 0.0;
                        }
                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
                        let proto_result = decoder
                            .decode_proto_tracked(
                                &mut tracker,
                                100_000_000 / 3,
                                &[&detect_td, &protos_td],
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap();
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
                    } else {
                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
                        let mut output_masks = Vec::with_capacity(50);
                        decoder
                            .decode_tracked(
                                &mut tracker,
                                0,
                                &[&detect_td, &protos_td],
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();

                        assert_eq!(output_boxes.len(), 1);
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));

                        // Second frame: zero the score column and rebuild the tensor.
                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
                            *score = 0.0;
                        }
                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
                        decoder
                            .decode_tracked(
                                &mut tracker,
                                100_000_000 / 3,
                                &[&detect_td, &protos_td],
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                        assert!(output_masks.is_empty());
                    }
                }
            }
        };
    }
5517
5518 e2e_tracked_tensor_test!(
5519 test_decoder_tracked_tensor_end_to_end_segdet,
5520 quantized,
5521 combined,
5522 masks
5523 );
5524 e2e_tracked_tensor_test!(
5525 test_decoder_tracked_tensor_end_to_end_segdet_float,
5526 float,
5527 combined,
5528 masks
5529 );
5530 e2e_tracked_tensor_test!(
5531 test_decoder_tracked_tensor_end_to_end_segdet_proto,
5532 quantized,
5533 combined,
5534 proto
5535 );
5536 e2e_tracked_tensor_test!(
5537 test_decoder_tracked_tensor_end_to_end_segdet_proto_float,
5538 float,
5539 combined,
5540 proto
5541 );
5542 e2e_tracked_tensor_test!(
5543 test_decoder_tracked_tensor_end_to_end_segdet_split,
5544 quantized,
5545 split,
5546 masks
5547 );
5548 e2e_tracked_tensor_test!(
5549 test_decoder_tracked_tensor_end_to_end_segdet_split_float,
5550 float,
5551 split,
5552 masks
5553 );
5554 e2e_tracked_tensor_test!(
5555 test_decoder_tracked_tensor_end_to_end_segdet_split_proto,
5556 quantized,
5557 split,
5558 proto
5559 );
5560 e2e_tracked_tensor_test!(
5561 test_decoder_tracked_tensor_end_to_end_segdet_split_proto_float,
5562 float,
5563 split,
5564 proto
5565 );
5566
5567 #[test]
5568 fn test_decoder_tracked_linear_motion() {
5569 use crate::configs::{DecoderType, Nms};
5570 use crate::DecoderBuilder;
5571
5572 let score_threshold = 0.25;
5573 let iou_threshold = 0.1;
5574 let out = include_bytes!(concat!(
5575 env!("CARGO_MANIFEST_DIR"),
5576 "/../../testdata/yolov8s_80_classes.bin"
5577 ));
5578 let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
5579 let mut out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
5580 let quant = (0.0040811873, -123).into();
5581
5582 let decoder = DecoderBuilder::default()
5583 .with_config_yolo_det(
5584 crate::configs::Detection {
5585 decoder: DecoderType::Ultralytics,
5586 shape: vec![1, 84, 8400],
5587 anchors: None,
5588 quantization: Some(quant),
5589 dshape: vec![
5590 (crate::configs::DimName::Batch, 1),
5591 (crate::configs::DimName::NumFeatures, 84),
5592 (crate::configs::DimName::NumBoxes, 8400),
5593 ],
5594 normalized: Some(true),
5595 },
5596 None,
5597 )
5598 .with_score_threshold(score_threshold)
5599 .with_iou_threshold(iou_threshold)
5600 .with_nms(Some(Nms::ClassAgnostic))
5601 .build()
5602 .unwrap();
5603
5604 let mut expected_boxes = [
5605 DetectBox {
5606 bbox: BoundingBox {
5607 xmin: 0.5285137,
5608 ymin: 0.05305544,
5609 xmax: 0.87541467,
5610 ymax: 0.9998909,
5611 },
5612 score: 0.5591227,
5613 label: 0,
5614 },
5615 DetectBox {
5616 bbox: BoundingBox {
5617 xmin: 0.130598,
5618 ymin: 0.43260583,
5619 xmax: 0.35098213,
5620 ymax: 0.9958097,
5621 },
5622 score: 0.33057618,
5623 label: 75,
5624 },
5625 ];
5626
5627 let mut tracker = ByteTrackBuilder::new()
5628 .track_update(0.1)
5629 .track_high_conf(0.3)
5630 .build();
5631
5632 let mut output_boxes = Vec::with_capacity(50);
5633 let mut output_masks = Vec::with_capacity(50);
5634 let mut output_tracks = Vec::with_capacity(50);
5635
5636 decoder
5637 .decode_tracked_quantized(
5638 &mut tracker,
5639 0,
5640 &[out.view().into()],
5641 &mut output_boxes,
5642 &mut output_masks,
5643 &mut output_tracks,
5644 )
5645 .unwrap();
5646
5647 assert_eq!(output_boxes.len(), 2);
5648 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5649 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
5650
5651 for i in 1..=100 {
5652 let mut out = out.clone();
5653 let mut x_values = out.slice_mut(s![0, 0, ..]);
5655 for x in x_values.iter_mut() {
5656 *x = x.saturating_add((i as f32 * 1e-3 / quant.0).round() as i8);
5657 }
5658
5659 decoder
5660 .decode_tracked_quantized(
5661 &mut tracker,
5662 100_000_000 * i / 3, &[out.view().into()],
5664 &mut output_boxes,
5665 &mut output_masks,
5666 &mut output_tracks,
5667 )
5668 .unwrap();
5669
5670 assert_eq!(output_boxes.len(), 2);
5671 }
5672 let tracks = tracker.get_active_tracks();
5673 let predicted_boxes: Vec<_> = tracks
5674 .iter()
5675 .map(|track| {
5676 let mut l = track.last_box;
5677 l.bbox = track.info.tracked_location.into();
5678 l
5679 })
5680 .collect();
5681 expected_boxes[0].bbox.xmin += 0.1; expected_boxes[0].bbox.xmax += 0.1;
5683 expected_boxes[1].bbox.xmin += 0.1;
5684 expected_boxes[1].bbox.xmax += 0.1;
5685
5686 assert!(predicted_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
5687 assert!(predicted_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
5688
5689 let mut scores_values = out.slice_mut(s![0, 4.., ..]);
5691 for score in scores_values.iter_mut() {
5692 *score = i8::MIN; }
5694 decoder
5695 .decode_tracked_quantized(
5696 &mut tracker,
5697 100_000_000 * 101 / 3,
5698 &[out.view().into()],
5699 &mut output_boxes,
5700 &mut output_masks,
5701 &mut output_tracks,
5702 )
5703 .unwrap();
5704 expected_boxes[0].bbox.xmin += 0.001; expected_boxes[0].bbox.xmax += 0.001;
5706 expected_boxes[1].bbox.xmin += 0.001;
5707 expected_boxes[1].bbox.xmax += 0.001;
5708
5709 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
5710 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
5711 }
5712
5713 #[test]
5714 fn test_decoder_tracked_end_to_end_float() {
5715 let score_threshold = 0.45;
5716 let iou_threshold = 0.45;
5717
5718 let mut boxes = Array2::zeros((10, 4));
5719 let mut scores = Array2::zeros((10, 1));
5720 let mut classes = Array2::zeros((10, 1));
5721
5722 boxes
5723 .slice_mut(s![0, ..,])
5724 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5725 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5726 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5727
5728 let detect = ndarray::concatenate![Axis(1), boxes.view(), scores.view(), classes.view(),];
5729 let mut detect = detect.insert_axis(Axis(0));
5730 assert_eq!(detect.shape(), &[1, 10, 6]);
5731 let config = "
5732decoder_version: yolo26
5733outputs:
5734 - type: detection
5735 decoder: ultralytics
5736 quantization: [0.00784313725490196, 0]
5737 shape: [1, 10, 6]
5738 dshape:
5739 - [batch, 1]
5740 - [num_boxes, 10]
5741 - [num_features, 6]
5742 normalized: true
5743";
5744
5745 let decoder = DecoderBuilder::default()
5746 .with_config_yaml_str(config.to_string())
5747 .with_score_threshold(score_threshold)
5748 .with_iou_threshold(iou_threshold)
5749 .build()
5750 .unwrap();
5751
5752 let expected_boxes = [DetectBox {
5753 bbox: BoundingBox {
5754 xmin: 0.1234,
5755 ymin: 0.1234,
5756 xmax: 0.2345,
5757 ymax: 0.2345,
5758 },
5759 score: 0.9876,
5760 label: 2,
5761 }];
5762
5763 let mut tracker = ByteTrackBuilder::new()
5764 .track_update(0.1)
5765 .track_high_conf(0.7)
5766 .build();
5767
5768 let mut output_boxes = Vec::with_capacity(50);
5769 let mut output_masks = Vec::with_capacity(50);
5770 let mut output_tracks = Vec::with_capacity(50);
5771
5772 decoder
5773 .decode_tracked_float(
5774 &mut tracker,
5775 0,
5776 &[detect.view().into_dyn()],
5777 &mut output_boxes,
5778 &mut output_masks,
5779 &mut output_tracks,
5780 )
5781 .unwrap();
5782
5783 assert_eq!(output_boxes.len(), 1);
5784 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5785
5786 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5789 *score = 0.0; }
5791
5792 decoder
5793 .decode_tracked_float(
5794 &mut tracker,
5795 100_000_000 / 3,
5796 &[detect.view().into_dyn()],
5797 &mut output_boxes,
5798 &mut output_masks,
5799 &mut output_tracks,
5800 )
5801 .unwrap();
5802 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5803 }
5804}