1#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
68
69use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
70use num_traits::{AsPrimitive, Float, PrimInt};
71
72pub mod byte;
73pub mod error;
74pub mod float;
75pub mod modelpack;
76pub mod schema;
77pub mod yolo;
78
79mod decoder;
80pub use decoder::*;
81
82pub use configs::{DecoderVersion, Nms};
83pub use error::{DecoderError, DecoderResult};
84
85use crate::{
86 decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
87 yolo::yolo_segmentation_to_mask,
88};
89
90pub trait BBoxTypeTrait {
92 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
94
95 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
97 input: &[B; 4],
98 quant: Quantization,
99 ) -> [A; 4]
100 where
101 f32: AsPrimitive<A>,
102 i32: AsPrimitive<A>;
103
104 #[inline(always)]
115 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
116 input: ArrayView1<B>,
117 ) -> [A; 4] {
118 Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
119 }
120
121 #[inline(always)]
122 fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
124 input: ArrayView1<B>,
125 quant: Quantization,
126 ) -> [A; 4]
127 where
128 f32: AsPrimitive<A>,
129 i32: AsPrimitive<A>,
130 {
131 Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
132 }
133}
134
135#[derive(Debug, Clone, Copy, PartialEq, Eq)]
137pub struct XYXY {}
138
139impl BBoxTypeTrait for XYXY {
140 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
141 input.map(|b| b.as_())
142 }
143
144 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
145 input: &[B; 4],
146 quant: Quantization,
147 ) -> [A; 4]
148 where
149 f32: AsPrimitive<A>,
150 i32: AsPrimitive<A>,
151 {
152 let scale = quant.scale.as_();
153 let zp = quant.zero_point.as_();
154 input.map(|b| (b.as_() - zp) * scale)
155 }
156
157 #[inline(always)]
158 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
159 input: ArrayView1<B>,
160 ) -> [A; 4] {
161 [
162 input[0].as_(),
163 input[1].as_(),
164 input[2].as_(),
165 input[3].as_(),
166 ]
167 }
168}
169
170#[derive(Debug, Clone, Copy, PartialEq, Eq)]
173pub struct XYWH {}
174
175impl BBoxTypeTrait for XYWH {
176 #[inline(always)]
177 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
178 let half = A::one() / (A::one() + A::one());
179 [
180 (input[0].as_()) - (input[2].as_() * half),
181 (input[1].as_()) - (input[3].as_() * half),
182 (input[0].as_()) + (input[2].as_() * half),
183 (input[1].as_()) + (input[3].as_() * half),
184 ]
185 }
186
187 #[inline(always)]
188 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
189 input: &[B; 4],
190 quant: Quantization,
191 ) -> [A; 4]
192 where
193 f32: AsPrimitive<A>,
194 i32: AsPrimitive<A>,
195 {
196 let scale = quant.scale.as_();
197 let half_scale = (quant.scale * 0.5).as_();
198 let zp = quant.zero_point.as_();
199 let [x, y, w, h] = [
200 (input[0].as_() - zp) * scale,
201 (input[1].as_() - zp) * scale,
202 (input[2].as_() - zp) * half_scale,
203 (input[3].as_() - zp) * half_scale,
204 ];
205
206 [x - w, y - h, x + w, y + h]
207 }
208
209 #[inline(always)]
210 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
211 input: ArrayView1<B>,
212 ) -> [A; 4] {
213 let half = A::one() / (A::one() + A::one());
214 [
215 (input[0].as_()) - (input[2].as_() * half),
216 (input[1].as_()) - (input[3].as_() * half),
217 (input[0].as_()) + (input[2].as_() * half),
218 (input[1].as_()) + (input[3].as_() * half),
219 ]
220 }
221}
222
/// Affine quantization parameters: a real value is recovered as `(q - zero_point) * scale`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quantization {
    /// Multiplier applied after the zero point has been removed.
    pub scale: f32,
    /// Raw quantized value that corresponds to real zero.
    pub zero_point: i32,
}

impl Quantization {
    /// Creates quantization parameters from an explicit scale and zero point.
    pub fn new(scale: f32, zero_point: i32) -> Self {
        Quantization { zero_point, scale }
    }
}
243
244impl From<QuantTuple> for Quantization {
245 fn from(quant_tuple: QuantTuple) -> Quantization {
256 Quantization {
257 scale: quant_tuple.0,
258 zero_point: quant_tuple.1,
259 }
260 }
261}
262
263impl<S, Z> From<(S, Z)> for Quantization
264where
265 S: AsPrimitive<f32>,
266 Z: AsPrimitive<i32>,
267{
268 fn from((scale, zp): (S, Z)) -> Quantization {
277 Self {
278 scale: scale.as_(),
279 zero_point: zp.as_(),
280 }
281 }
282}
283
284impl Default for Quantization {
285 fn default() -> Self {
294 Self {
295 scale: 1.0,
296 zero_point: 0,
297 }
298 }
299}
300
/// A single detection: a bounding box, its score and its class label.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct DetectBox {
    /// Box coordinates of the detection.
    pub bbox: BoundingBox,
    /// Detection score.
    pub score: f32,
    /// Class label index.
    pub label: usize,
}

/// Axis-aligned box stored as corner coordinates.
///
/// Values are kept exactly as given; use [`BoundingBox::to_canonical`] to guarantee
/// `xmin <= xmax` and `ymin <= ymax`.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct BoundingBox {
    /// Left edge.
    pub xmin: f32,
    /// Top edge.
    pub ymin: f32,
    /// Right edge.
    pub xmax: f32,
    /// Bottom edge.
    pub ymax: f32,
}

impl BoundingBox {
    /// Builds a box from its corner coordinates without reordering them.
    pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
        Self {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }

    /// Returns a copy with the smaller coordinate of each axis in the `*min` field and the
    /// larger one in the `*max` field.
    ///
    /// Uses `f32::min` / `f32::max`, so a NaN coordinate is replaced by the other
    /// coordinate on the same axis.
    pub fn to_canonical(&self) -> Self {
        let (x_lo, x_hi) = (self.xmin.min(self.xmax), self.xmin.max(self.xmax));
        let (y_lo, y_hi) = (self.ymin.min(self.ymax), self.ymin.max(self.ymax));
        Self::new(x_lo, y_lo, x_hi, y_hi)
    }
}

impl From<BoundingBox> for [f32; 4] {
    /// Flattens the box into `[xmin, ymin, xmax, ymax]`.
    fn from(b: BoundingBox) -> Self {
        let BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        } = b;
        [xmin, ymin, xmax, ymax]
    }
}

impl From<[f32; 4]> for BoundingBox {
    /// Reads `[xmin, ymin, xmax, ymax]` into a box without reordering.
    fn from(arr: [f32; 4]) -> Self {
        let [xmin, ymin, xmax, ymax] = arr;
        Self::new(xmin, ymin, xmax, ymax)
    }
}

impl DetectBox {
    /// Returns `true` when the labels match exactly and the score and all four box
    /// coordinates differ by at most `eps`.
    ///
    /// Any NaN in the compared values makes the result `false`.
    pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
        if self.label != rhs.label {
            return false;
        }
        let close = |a: f32, b: f32| (a - b).abs() <= eps;
        let pairs = [
            (self.score, rhs.score),
            (self.bbox.xmin, rhs.bbox.xmin),
            (self.bbox.ymin, rhs.bbox.ymin),
            (self.bbox.xmax, rhs.bbox.xmax),
            (self.bbox.ymax, rhs.bbox.ymax),
        ];
        pairs.iter().all(|&(a, b)| close(a, b))
    }
}
429
/// A segmentation mask together with the rectangular region it covers.
///
/// NOTE(review): the tests in this file build full-frame masks with `xmin = ymin = 0.0`
/// and `xmax = ymax = 1.0`, which suggests normalized coordinates — confirm.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Segmentation {
    /// Left edge of the region covered by the mask.
    pub xmin: f32,
    /// Top edge of the region covered by the mask.
    pub ymin: f32,
    /// Right edge of the region covered by the mask.
    pub xmax: f32,
    /// Bottom edge of the region covered by the mask.
    pub ymax: f32,
    /// Mask data as a 3-D `u8` array.
    ///
    /// NOTE(review): the tests build it from a `(classes, H, W)` slice with axes swapped to
    /// `(H, W, classes)`, and `segmentation_to_mask` reads axis 2 as the channel count —
    /// confirm every producer uses this layout.
    pub segmentation: Array3<u8>,
}
453
454#[derive(Debug, Clone)]
460pub enum ProtoTensor {
461 Quantized {
465 protos: Array3<i8>,
466 quantization: Quantization,
467 },
468 Float(Array3<f32>),
470}
471
472impl ProtoTensor {
473 pub fn is_quantized(&self) -> bool {
475 matches!(self, ProtoTensor::Quantized { .. })
476 }
477
478 pub fn dim(&self) -> (usize, usize, usize) {
480 match self {
481 ProtoTensor::Quantized { protos, .. } => protos.dim(),
482 ProtoTensor::Float(arr) => arr.dim(),
483 }
484 }
485
486 pub fn as_f32(&self) -> std::borrow::Cow<'_, Array3<f32>> {
489 match self {
490 ProtoTensor::Float(arr) => std::borrow::Cow::Borrowed(arr),
491 ProtoTensor::Quantized {
492 protos,
493 quantization,
494 } => {
495 let scale = quantization.scale;
496 let zp = quantization.zero_point as f32;
497 std::borrow::Cow::Owned(protos.map(|&v| (v as f32 - zp) * scale))
498 }
499 }
500 }
501}
502
/// Mask coefficients bundled with a prototype tensor.
///
/// NOTE(review): presumably consumed later to assemble per-detection masks — confirm with
/// the code that reads this struct.
#[derive(Debug, Clone)]
pub struct ProtoData {
    /// Mask coefficient vectors.
    ///
    /// NOTE(review): presumably one inner `Vec` per detection, each as long as the number
    /// of prototype channels — confirm against the producer.
    pub mask_coefficients: Vec<Vec<f32>>,
    /// Prototype tensor, either quantized or float.
    pub protos: ProtoTensor,
}
515
516pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
534 detect: &DetectBoxQuantized<SCORE>,
535 quant_scores: Quantization,
536) -> DetectBox {
537 let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
538 DetectBox {
539 bbox: detect.bbox,
540 score: quant_scores.scale * detect.score.as_() + scaled_zp,
541 label: detect.label,
542 }
543}
/// A detection whose score is still stored as a quantized integer.
///
/// [`dequant_detect_box`] turns it into a [`DetectBox`] once the score quantization is known.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DetectBoxQuantized<
    // Integer type of the raw score; `AsPrimitive<f32>` allows casting it for dequantization.
    SCORE: PrimInt + AsPrimitive<f32>,
> {
    /// Box coordinates, already in floating point.
    pub bbox: BoundingBox,
    /// Raw quantized score.
    pub score: SCORE,
    /// Class label index.
    pub label: usize,
}
558
559pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
572 input: ArrayView<T, D>,
573 quant: Quantization,
574) -> Array<F, D>
575where
576 i32: num_traits::AsPrimitive<F>,
577 f32: num_traits::AsPrimitive<F>,
578{
579 let zero_point = quant.zero_point.as_();
580 let scale = quant.scale.as_();
581 if zero_point != F::zero() {
582 let scaled_zero = -zero_point * scale;
583 input.mapv(|d| d.as_() * scale + scaled_zero)
584 } else {
585 input.mapv(|d| d.as_() * scale)
586 }
587}
588
589pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
602 input: &[T],
603 quant: Quantization,
604 output: &mut [F],
605) where
606 f32: num_traits::AsPrimitive<F>,
607 i32: num_traits::AsPrimitive<F>,
608{
609 assert!(input.len() == output.len());
610 let zero_point = quant.zero_point.as_();
611 let scale = quant.scale.as_();
612 if zero_point != F::zero() {
613 let scaled_zero = -zero_point * scale; input
615 .iter()
616 .zip(output)
617 .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
618 } else {
619 input
620 .iter()
621 .zip(output)
622 .for_each(|(d, deq)| *deq = d.as_() * scale);
623 }
624}
625
626pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
640 input: &[T],
641 quant: Quantization,
642 output: &mut [F],
643) where
644 f32: num_traits::AsPrimitive<F>,
645 i32: num_traits::AsPrimitive<F>,
646{
647 assert!(input.len() == output.len());
648 let zero_point = quant.zero_point.as_();
649 let scale = quant.scale.as_();
650
651 let input = input.as_chunks::<4>();
652 let output = output.as_chunks_mut::<4>();
653
654 if zero_point != F::zero() {
655 let scaled_zero = -zero_point * scale; input
658 .0
659 .iter()
660 .zip(output.0)
661 .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
662 input
663 .1
664 .iter()
665 .zip(output.1)
666 .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
667 } else {
668 input
669 .0
670 .iter()
671 .zip(output.0)
672 .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
673 input
674 .1
675 .iter()
676 .zip(output.1)
677 .for_each(|(d, deq)| *deq = d.as_() * scale);
678 }
679}
680
681pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
699 if segmentation.shape()[2] == 0 {
700 return Err(DecoderError::InvalidShape(
701 "Segmentation tensor must have non-zero depth".to_string(),
702 ));
703 }
704 if segmentation.shape()[2] == 1 {
705 yolo_segmentation_to_mask(segmentation, 128)
706 } else {
707 Ok(modelpack_segmentation_to_mask(segmentation))
708 }
709}
710
711fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
713 score
714 .iter()
715 .enumerate()
716 .fold((score[0], 0), |(max, arg_max), (ind, s)| {
717 if max > *s {
718 (max, arg_max)
719 } else {
720 (*s, ind)
721 }
722 })
723}
724#[cfg(test)]
725#[cfg_attr(coverage_nightly, coverage(off))]
726mod decoder_tests {
727 #![allow(clippy::excessive_precision)]
728 use crate::{
729 configs::{DecoderType, DimName, Protos},
730 modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
731 yolo::{
732 decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
733 decode_yolo_segdet_quant,
734 },
735 *,
736 };
737 use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
738 use ndarray::Dimension;
739 use ndarray::{array, s, Array2, Array3, Array4, Axis};
740 use ndarray_stats::DeviationExt;
741 use num_traits::{AsPrimitive, PrimInt};
742
743 fn compare_outputs(
744 boxes: (&[DetectBox], &[DetectBox]),
745 masks: (&[Segmentation], &[Segmentation]),
746 ) {
747 let (boxes0, boxes1) = boxes;
748 let (masks0, masks1) = masks;
749
750 assert_eq!(boxes0.len(), boxes1.len());
751 assert_eq!(masks0.len(), masks1.len());
752
753 for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
754 assert!(
755 b_i8.equal_within_delta(b_f32, 1e-6),
756 "{b_i8:?} is not equal to {b_f32:?}"
757 );
758 }
759
760 for (m_i8, m_f32) in masks0.iter().zip(masks1) {
761 assert_eq!(
762 [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
763 [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
764 );
765 assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
766 let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
767 let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
768 let diff = &mask_i8 - &mask_f32;
769 for x in 0..diff.shape()[0] {
770 for y in 0..diff.shape()[1] {
771 for z in 0..diff.shape()[2] {
772 let val = diff[[x, y, z]];
773 assert!(
774 val.abs() <= 1,
775 "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
776 x,
777 y,
778 z,
779 val
780 );
781 }
782 }
783 }
784 let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
785 assert!(
786 mean_sq_err < 1e-2,
787 "Mean Square Error between masks was greater than 1%: {:.2}%",
788 mean_sq_err * 100.0
789 );
790 }
791 }
792
793 fn load_yolov8_boxes() -> Array3<i8> {
796 let raw = include_bytes!(concat!(
797 env!("CARGO_MANIFEST_DIR"),
798 "/../../testdata/yolov8_boxes_116x8400.bin"
799 ));
800 let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
801 Array3::from_shape_vec((1, 116, 8400), raw.to_vec()).unwrap()
802 }
803
804 fn load_yolov8_protos() -> Array4<i8> {
805 let raw = include_bytes!(concat!(
806 env!("CARGO_MANIFEST_DIR"),
807 "/../../testdata/yolov8_protos_160x160x32.bin"
808 ));
809 let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
810 Array4::from_shape_vec((1, 160, 160, 32), raw.to_vec()).unwrap()
811 }
812
813 fn load_yolov8s_det() -> Array3<i8> {
814 let raw = include_bytes!(concat!(
815 env!("CARGO_MANIFEST_DIR"),
816 "/../../testdata/yolov8s_80_classes.bin"
817 ));
818 let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
819 Array3::from_shape_vec((1, 84, 8400), raw.to_vec()).unwrap()
820 }
821
822 #[test]
823 fn test_decoder_modelpack() {
824 let score_threshold = 0.45;
825 let iou_threshold = 0.45;
826 let boxes = include_bytes!(concat!(
827 env!("CARGO_MANIFEST_DIR"),
828 "/../../testdata/modelpack_boxes_1935x1x4.bin"
829 ));
830 let boxes = ndarray::Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
831
832 let scores = include_bytes!(concat!(
833 env!("CARGO_MANIFEST_DIR"),
834 "/../../testdata/modelpack_scores_1935x1.bin"
835 ));
836 let scores = ndarray::Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
837
838 let quant_boxes = (0.004656755365431309, 21).into();
839 let quant_scores = (0.0019603664986789227, 0).into();
840
841 let decoder = DecoderBuilder::default()
842 .with_config_modelpack_det(
843 configs::Boxes {
844 decoder: DecoderType::ModelPack,
845 quantization: Some(quant_boxes),
846 shape: vec![1, 1935, 1, 4],
847 dshape: vec![
848 (DimName::Batch, 1),
849 (DimName::NumBoxes, 1935),
850 (DimName::Padding, 1),
851 (DimName::BoxCoords, 4),
852 ],
853 normalized: Some(true),
854 },
855 configs::Scores {
856 decoder: DecoderType::ModelPack,
857 quantization: Some(quant_scores),
858 shape: vec![1, 1935, 1],
859 dshape: vec![
860 (DimName::Batch, 1),
861 (DimName::NumBoxes, 1935),
862 (DimName::NumClasses, 1),
863 ],
864 },
865 )
866 .with_score_threshold(score_threshold)
867 .with_iou_threshold(iou_threshold)
868 .build()
869 .unwrap();
870
871 let quant_boxes = quant_boxes.into();
872 let quant_scores = quant_scores.into();
873
874 let mut output_boxes: Vec<_> = Vec::with_capacity(50);
875 decode_modelpack_det(
876 (boxes.slice(s![0, .., 0, ..]), quant_boxes),
877 (scores.slice(s![0, .., ..]), quant_scores),
878 score_threshold,
879 iou_threshold,
880 &mut output_boxes,
881 );
882 assert!(output_boxes[0].equal_within_delta(
883 &DetectBox {
884 bbox: BoundingBox {
885 xmin: 0.40513772,
886 ymin: 0.6379755,
887 xmax: 0.5122431,
888 ymax: 0.7730214,
889 },
890 score: 0.4861709,
891 label: 0
892 },
893 1e-6
894 ));
895
896 let mut output_boxes1 = Vec::with_capacity(50);
897 let mut output_masks1 = Vec::with_capacity(50);
898
899 decoder
900 .decode_quantized(
901 &[boxes.view().into(), scores.view().into()],
902 &mut output_boxes1,
903 &mut output_masks1,
904 )
905 .unwrap();
906
907 let mut output_boxes_float = Vec::with_capacity(50);
908 let mut output_masks_float = Vec::with_capacity(50);
909
910 let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
911 let scores = dequantize_ndarray(scores.view(), quant_scores);
912
913 decoder
914 .decode_float::<f32>(
915 &[boxes.view().into_dyn(), scores.view().into_dyn()],
916 &mut output_boxes_float,
917 &mut output_masks_float,
918 )
919 .unwrap();
920
921 compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
922 compare_outputs(
923 (&output_boxes, &output_boxes_float),
924 (&[], &output_masks_float),
925 );
926 }
927
928 #[test]
929 fn test_decoder_modelpack_split_u8() {
930 let score_threshold = 0.45;
931 let iou_threshold = 0.45;
932 let detect0 = include_bytes!(concat!(
933 env!("CARGO_MANIFEST_DIR"),
934 "/../../testdata/modelpack_split_9x15x18.bin"
935 ));
936 let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
937
938 let detect1 = include_bytes!(concat!(
939 env!("CARGO_MANIFEST_DIR"),
940 "/../../testdata/modelpack_split_17x30x18.bin"
941 ));
942 let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
943
944 let quant0 = (0.08547406643629074, 174).into();
945 let quant1 = (0.09929127991199493, 183).into();
946 let anchors0 = vec![
947 [0.36666667461395264, 0.31481480598449707],
948 [0.38749998807907104, 0.4740740656852722],
949 [0.5333333611488342, 0.644444465637207],
950 ];
951 let anchors1 = vec![
952 [0.13750000298023224, 0.2074074000120163],
953 [0.2541666626930237, 0.21481481194496155],
954 [0.23125000298023224, 0.35185185074806213],
955 ];
956
957 let detect_config0 = configs::Detection {
958 decoder: DecoderType::ModelPack,
959 shape: vec![1, 9, 15, 18],
960 anchors: Some(anchors0.clone()),
961 quantization: Some(quant0),
962 dshape: vec![
963 (DimName::Batch, 1),
964 (DimName::Height, 9),
965 (DimName::Width, 15),
966 (DimName::NumAnchorsXFeatures, 18),
967 ],
968 normalized: Some(true),
969 };
970
971 let detect_config1 = configs::Detection {
972 decoder: DecoderType::ModelPack,
973 shape: vec![1, 17, 30, 18],
974 anchors: Some(anchors1.clone()),
975 quantization: Some(quant1),
976 dshape: vec![
977 (DimName::Batch, 1),
978 (DimName::Height, 17),
979 (DimName::Width, 30),
980 (DimName::NumAnchorsXFeatures, 18),
981 ],
982 normalized: Some(true),
983 };
984
985 let config0 = (&detect_config0).try_into().unwrap();
986 let config1 = (&detect_config1).try_into().unwrap();
987
988 let decoder = DecoderBuilder::default()
989 .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
990 .with_score_threshold(score_threshold)
991 .with_iou_threshold(iou_threshold)
992 .build()
993 .unwrap();
994
995 let quant0 = quant0.into();
996 let quant1 = quant1.into();
997
998 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
999 decode_modelpack_split_quant(
1000 &[
1001 detect0.slice(s![0, .., .., ..]),
1002 detect1.slice(s![0, .., .., ..]),
1003 ],
1004 &[config0, config1],
1005 score_threshold,
1006 iou_threshold,
1007 &mut output_boxes,
1008 );
1009 assert!(output_boxes[0].equal_within_delta(
1010 &DetectBox {
1011 bbox: BoundingBox {
1012 xmin: 0.43171933,
1013 ymin: 0.68243736,
1014 xmax: 0.5626645,
1015 ymax: 0.808863,
1016 },
1017 score: 0.99240804,
1018 label: 0
1019 },
1020 1e-6
1021 ));
1022
1023 let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
1024 let mut output_masks1: Vec<_> = Vec::with_capacity(10);
1025 decoder
1026 .decode_quantized(
1027 &[detect0.view().into(), detect1.view().into()],
1028 &mut output_boxes1,
1029 &mut output_masks1,
1030 )
1031 .unwrap();
1032
1033 let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
1034 let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);
1035
1036 let detect0 = dequantize_ndarray(detect0.view(), quant0);
1037 let detect1 = dequantize_ndarray(detect1.view(), quant1);
1038 decoder
1039 .decode_float::<f32>(
1040 &[detect0.view().into_dyn(), detect1.view().into_dyn()],
1041 &mut output_boxes1_f32,
1042 &mut output_masks1_f32,
1043 )
1044 .unwrap();
1045
1046 compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1047 compare_outputs(
1048 (&output_boxes, &output_boxes1_f32),
1049 (&[], &output_masks1_f32),
1050 );
1051 }
1052
1053 #[test]
1054 fn test_decoder_parse_config_modelpack_split_u8() {
1055 let score_threshold = 0.45;
1056 let iou_threshold = 0.45;
1057 let detect0 = include_bytes!(concat!(
1058 env!("CARGO_MANIFEST_DIR"),
1059 "/../../testdata/modelpack_split_9x15x18.bin"
1060 ));
1061 let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1062
1063 let detect1 = include_bytes!(concat!(
1064 env!("CARGO_MANIFEST_DIR"),
1065 "/../../testdata/modelpack_split_17x30x18.bin"
1066 ));
1067 let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1068
1069 let decoder = DecoderBuilder::default()
1070 .with_config_yaml_str(
1071 include_str!(concat!(
1072 env!("CARGO_MANIFEST_DIR"),
1073 "/../../testdata/modelpack_split.yaml"
1074 ))
1075 .to_string(),
1076 )
1077 .with_score_threshold(score_threshold)
1078 .with_iou_threshold(iou_threshold)
1079 .build()
1080 .unwrap();
1081
1082 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1083 let mut output_masks: Vec<_> = Vec::with_capacity(10);
1084 decoder
1085 .decode_quantized(
1086 &[
1087 ArrayViewDQuantized::from(detect1.view()),
1088 ArrayViewDQuantized::from(detect0.view()),
1089 ],
1090 &mut output_boxes,
1091 &mut output_masks,
1092 )
1093 .unwrap();
1094 assert!(output_boxes[0].equal_within_delta(
1095 &DetectBox {
1096 bbox: BoundingBox {
1097 xmin: 0.43171933,
1098 ymin: 0.68243736,
1099 xmax: 0.5626645,
1100 ymax: 0.808863,
1101 },
1102 score: 0.99240804,
1103 label: 0
1104 },
1105 1e-6
1106 ));
1107 }
1108
1109 #[test]
1110 fn test_modelpack_seg() {
1111 let out = include_bytes!(concat!(
1112 env!("CARGO_MANIFEST_DIR"),
1113 "/../../testdata/modelpack_seg_2x160x160.bin"
1114 ));
1115 let out = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1116 let quant = (1.0 / 255.0, 0).into();
1117
1118 let decoder = DecoderBuilder::default()
1119 .with_config_modelpack_seg(configs::Segmentation {
1120 decoder: DecoderType::ModelPack,
1121 quantization: Some(quant),
1122 shape: vec![1, 2, 160, 160],
1123 dshape: vec![
1124 (DimName::Batch, 1),
1125 (DimName::NumClasses, 2),
1126 (DimName::Height, 160),
1127 (DimName::Width, 160),
1128 ],
1129 })
1130 .build()
1131 .unwrap();
1132 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1133 let mut output_masks: Vec<_> = Vec::with_capacity(10);
1134 decoder
1135 .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
1136 .unwrap();
1137
1138 let mut mask = out.slice(s![0, .., .., ..]);
1139 mask.swap_axes(0, 1);
1140 mask.swap_axes(1, 2);
1141 let mask = [Segmentation {
1142 xmin: 0.0,
1143 ymin: 0.0,
1144 xmax: 1.0,
1145 ymax: 1.0,
1146 segmentation: mask.into_owned(),
1147 }];
1148 compare_outputs((&[], &output_boxes), (&mask, &output_masks));
1149
1150 decoder
1151 .decode_float::<f32>(
1152 &[dequantize_ndarray(out.view(), quant.into())
1153 .view()
1154 .into_dyn()],
1155 &mut output_boxes,
1156 &mut output_masks,
1157 )
1158 .unwrap();
1159
1160 compare_outputs((&[], &output_boxes), (&[], &[]));
1166 let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1167 let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1168
1169 assert_eq!(mask0, mask1);
1170 }
1171 #[test]
1172 fn test_modelpack_seg_quant() {
1173 let out = include_bytes!(concat!(
1174 env!("CARGO_MANIFEST_DIR"),
1175 "/../../testdata/modelpack_seg_2x160x160.bin"
1176 ));
1177 let out_u8 = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1178 let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
1179 let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
1180 let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
1181 let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
1182 let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);
1183
1184 let quant = (1.0 / 255.0, 0).into();
1185
1186 let decoder = DecoderBuilder::default()
1187 .with_config_modelpack_seg(configs::Segmentation {
1188 decoder: DecoderType::ModelPack,
1189 quantization: Some(quant),
1190 shape: vec![1, 2, 160, 160],
1191 dshape: vec![
1192 (DimName::Batch, 1),
1193 (DimName::NumClasses, 2),
1194 (DimName::Height, 160),
1195 (DimName::Width, 160),
1196 ],
1197 })
1198 .build()
1199 .unwrap();
1200 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1201 let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
1202 decoder
1203 .decode_quantized(
1204 &[out_u8.view().into()],
1205 &mut output_boxes,
1206 &mut output_masks_u8,
1207 )
1208 .unwrap();
1209
1210 let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
1211 decoder
1212 .decode_quantized(
1213 &[out_i8.view().into()],
1214 &mut output_boxes,
1215 &mut output_masks_i8,
1216 )
1217 .unwrap();
1218
1219 let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
1220 decoder
1221 .decode_quantized(
1222 &[out_u16.view().into()],
1223 &mut output_boxes,
1224 &mut output_masks_u16,
1225 )
1226 .unwrap();
1227
1228 let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
1229 decoder
1230 .decode_quantized(
1231 &[out_i16.view().into()],
1232 &mut output_boxes,
1233 &mut output_masks_i16,
1234 )
1235 .unwrap();
1236
1237 let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
1238 decoder
1239 .decode_quantized(
1240 &[out_u32.view().into()],
1241 &mut output_boxes,
1242 &mut output_masks_u32,
1243 )
1244 .unwrap();
1245
1246 let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
1247 decoder
1248 .decode_quantized(
1249 &[out_i32.view().into()],
1250 &mut output_boxes,
1251 &mut output_masks_i32,
1252 )
1253 .unwrap();
1254
1255 compare_outputs((&[], &output_boxes), (&[], &[]));
1256 let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
1257 let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
1258 let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
1259 let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
1260 let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
1261 let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();
1262 assert_eq!(mask_u8, mask_i8);
1263 assert_eq!(mask_u8, mask_u16);
1264 assert_eq!(mask_u8, mask_i16);
1265 assert_eq!(mask_u8, mask_u32);
1266 assert_eq!(mask_u8, mask_i32);
1267 }
1268
1269 #[test]
1270 fn test_modelpack_segdet() {
1271 let score_threshold = 0.45;
1272 let iou_threshold = 0.45;
1273
1274 let boxes = include_bytes!(concat!(
1275 env!("CARGO_MANIFEST_DIR"),
1276 "/../../testdata/modelpack_boxes_1935x1x4.bin"
1277 ));
1278 let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
1279
1280 let scores = include_bytes!(concat!(
1281 env!("CARGO_MANIFEST_DIR"),
1282 "/../../testdata/modelpack_scores_1935x1.bin"
1283 ));
1284 let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
1285
1286 let seg = include_bytes!(concat!(
1287 env!("CARGO_MANIFEST_DIR"),
1288 "/../../testdata/modelpack_seg_2x160x160.bin"
1289 ));
1290 let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1291
1292 let quant_boxes = (0.004656755365431309, 21).into();
1293 let quant_scores = (0.0019603664986789227, 0).into();
1294 let quant_seg = (1.0 / 255.0, 0).into();
1295
1296 let decoder = DecoderBuilder::default()
1297 .with_config_modelpack_segdet(
1298 configs::Boxes {
1299 decoder: DecoderType::ModelPack,
1300 quantization: Some(quant_boxes),
1301 shape: vec![1, 1935, 1, 4],
1302 dshape: vec![
1303 (DimName::Batch, 1),
1304 (DimName::NumBoxes, 1935),
1305 (DimName::Padding, 1),
1306 (DimName::BoxCoords, 4),
1307 ],
1308 normalized: Some(true),
1309 },
1310 configs::Scores {
1311 decoder: DecoderType::ModelPack,
1312 quantization: Some(quant_scores),
1313 shape: vec![1, 1935, 1],
1314 dshape: vec![
1315 (DimName::Batch, 1),
1316 (DimName::NumBoxes, 1935),
1317 (DimName::NumClasses, 1),
1318 ],
1319 },
1320 configs::Segmentation {
1321 decoder: DecoderType::ModelPack,
1322 quantization: Some(quant_seg),
1323 shape: vec![1, 2, 160, 160],
1324 dshape: vec![
1325 (DimName::Batch, 1),
1326 (DimName::NumClasses, 2),
1327 (DimName::Height, 160),
1328 (DimName::Width, 160),
1329 ],
1330 },
1331 )
1332 .with_iou_threshold(iou_threshold)
1333 .with_score_threshold(score_threshold)
1334 .build()
1335 .unwrap();
1336 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1337 let mut output_masks: Vec<_> = Vec::with_capacity(10);
1338 decoder
1339 .decode_quantized(
1340 &[scores.view().into(), boxes.view().into(), seg.view().into()],
1341 &mut output_boxes,
1342 &mut output_masks,
1343 )
1344 .unwrap();
1345
1346 let mut mask = seg.slice(s![0, .., .., ..]);
1347 mask.swap_axes(0, 1);
1348 mask.swap_axes(1, 2);
1349 let mask = [Segmentation {
1350 xmin: 0.0,
1351 ymin: 0.0,
1352 xmax: 1.0,
1353 ymax: 1.0,
1354 segmentation: mask.into_owned(),
1355 }];
1356 let correct_boxes = [DetectBox {
1357 bbox: BoundingBox {
1358 xmin: 0.40513772,
1359 ymin: 0.6379755,
1360 xmax: 0.5122431,
1361 ymax: 0.7730214,
1362 },
1363 score: 0.4861709,
1364 label: 0,
1365 }];
1366 compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1367
1368 let scores = dequantize_ndarray(scores.view(), quant_scores.into());
1369 let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
1370 let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1371 decoder
1372 .decode_float::<f32>(
1373 &[
1374 scores.view().into_dyn(),
1375 boxes.view().into_dyn(),
1376 seg.view().into_dyn(),
1377 ],
1378 &mut output_boxes,
1379 &mut output_masks,
1380 )
1381 .unwrap();
1382
1383 compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1389 let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1390 let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1391
1392 assert_eq!(mask0, mask1);
1393 }
1394
1395 #[test]
1396 fn test_modelpack_segdet_split() {
1397 let score_threshold = 0.8;
1398 let iou_threshold = 0.5;
1399
1400 let seg = include_bytes!(concat!(
1401 env!("CARGO_MANIFEST_DIR"),
1402 "/../../testdata/modelpack_seg_2x160x160.bin"
1403 ));
1404 let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1405
1406 let detect0 = include_bytes!(concat!(
1407 env!("CARGO_MANIFEST_DIR"),
1408 "/../../testdata/modelpack_split_9x15x18.bin"
1409 ));
1410 let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1411
1412 let detect1 = include_bytes!(concat!(
1413 env!("CARGO_MANIFEST_DIR"),
1414 "/../../testdata/modelpack_split_17x30x18.bin"
1415 ));
1416 let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1417
1418 let quant0 = (0.08547406643629074, 174).into();
1419 let quant1 = (0.09929127991199493, 183).into();
1420 let quant_seg = (1.0 / 255.0, 0).into();
1421
1422 let anchors0 = vec![
1423 [0.36666667461395264, 0.31481480598449707],
1424 [0.38749998807907104, 0.4740740656852722],
1425 [0.5333333611488342, 0.644444465637207],
1426 ];
1427 let anchors1 = vec![
1428 [0.13750000298023224, 0.2074074000120163],
1429 [0.2541666626930237, 0.21481481194496155],
1430 [0.23125000298023224, 0.35185185074806213],
1431 ];
1432
1433 let decoder = DecoderBuilder::default()
1434 .with_config_modelpack_segdet_split(
1435 vec![
1436 configs::Detection {
1437 decoder: DecoderType::ModelPack,
1438 shape: vec![1, 17, 30, 18],
1439 anchors: Some(anchors1),
1440 quantization: Some(quant1),
1441 dshape: vec![
1442 (DimName::Batch, 1),
1443 (DimName::Height, 17),
1444 (DimName::Width, 30),
1445 (DimName::NumAnchorsXFeatures, 18),
1446 ],
1447 normalized: Some(true),
1448 },
1449 configs::Detection {
1450 decoder: DecoderType::ModelPack,
1451 shape: vec![1, 9, 15, 18],
1452 anchors: Some(anchors0),
1453 quantization: Some(quant0),
1454 dshape: vec![
1455 (DimName::Batch, 1),
1456 (DimName::Height, 9),
1457 (DimName::Width, 15),
1458 (DimName::NumAnchorsXFeatures, 18),
1459 ],
1460 normalized: Some(true),
1461 },
1462 ],
1463 configs::Segmentation {
1464 decoder: DecoderType::ModelPack,
1465 quantization: Some(quant_seg),
1466 shape: vec![1, 2, 160, 160],
1467 dshape: vec![
1468 (DimName::Batch, 1),
1469 (DimName::NumClasses, 2),
1470 (DimName::Height, 160),
1471 (DimName::Width, 160),
1472 ],
1473 },
1474 )
1475 .with_score_threshold(score_threshold)
1476 .with_iou_threshold(iou_threshold)
1477 .build()
1478 .unwrap();
1479 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1480 let mut output_masks: Vec<_> = Vec::with_capacity(10);
1481 decoder
1482 .decode_quantized(
1483 &[
1484 detect0.view().into(),
1485 detect1.view().into(),
1486 seg.view().into(),
1487 ],
1488 &mut output_boxes,
1489 &mut output_masks,
1490 )
1491 .unwrap();
1492
1493 let mut mask = seg.slice(s![0, .., .., ..]);
1494 mask.swap_axes(0, 1);
1495 mask.swap_axes(1, 2);
1496 let mask = [Segmentation {
1497 xmin: 0.0,
1498 ymin: 0.0,
1499 xmax: 1.0,
1500 ymax: 1.0,
1501 segmentation: mask.into_owned(),
1502 }];
1503 let correct_boxes = [DetectBox {
1504 bbox: BoundingBox {
1505 xmin: 0.43171933,
1506 ymin: 0.68243736,
1507 xmax: 0.5626645,
1508 ymax: 0.808863,
1509 },
1510 score: 0.99240804,
1511 label: 0,
1512 }];
1513 println!("Output Boxes: {:?}", output_boxes);
1514 compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1515
1516 let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
1517 let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
1518 let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1519 decoder
1520 .decode_float::<f32>(
1521 &[
1522 detect0.view().into_dyn(),
1523 detect1.view().into_dyn(),
1524 seg.view().into_dyn(),
1525 ],
1526 &mut output_boxes,
1527 &mut output_masks,
1528 )
1529 .unwrap();
1530
1531 compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1537 let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1538 let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1539
1540 assert_eq!(mask0, mask1);
1541 }
1542
1543 #[test]
1544 fn test_dequant_chunked() {
1545 let mut out = load_yolov8s_det().into_raw_vec_and_offset().0;
1546 out.push(123); let mut out_dequant = vec![0.0; 84 * 8400 + 1];
1549 let mut out_dequant_simd = vec![0.0; 84 * 8400 + 1];
1550 let quant = Quantization::new(0.0040811873, -123);
1551 dequantize_cpu(&out, quant, &mut out_dequant);
1552
1553 dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1554 assert_eq!(out_dequant, out_dequant_simd);
1555
1556 let quant = Quantization::new(0.0040811873, 0);
1557 dequantize_cpu(&out, quant, &mut out_dequant);
1558
1559 dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1560 assert_eq!(out_dequant, out_dequant_simd);
1561 }
1562
1563 #[test]
1564 fn test_dequant_ground_truth() {
1565 let quant = Quantization::new(0.1, -128);
1570 let input: Vec<i8> = vec![0, 127, -128, 64];
1571 let mut output = vec![0.0f32; 4];
1572 let mut output_chunked = vec![0.0f32; 4];
1573 dequantize_cpu(&input, quant, &mut output);
1574 dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1575 let expected: Vec<f32> = vec![12.8, 25.5, 0.0, 19.2];
1580 for (i, (&out, &exp)) in output.iter().zip(expected.iter()).enumerate() {
1581 assert!((out - exp).abs() < 1e-5, "cpu[{i}]: {out} != {exp}");
1582 }
1583 for (i, (&out, &exp)) in output_chunked.iter().zip(expected.iter()).enumerate() {
1584 assert!((out - exp).abs() < 1e-5, "chunked[{i}]: {out} != {exp}");
1585 }
1586
1587 let quant = Quantization::new(1.0, 0);
1589 dequantize_cpu(&input, quant, &mut output);
1590 dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1591 let expected: Vec<f32> = vec![0.0, 127.0, -128.0, 64.0];
1592 assert_eq!(output, expected);
1593 assert_eq!(output_chunked, expected);
1594
1595 let quant = Quantization::new(0.5, 0);
1597 dequantize_cpu(&input, quant, &mut output);
1598 dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1599 let expected: Vec<f32> = vec![0.0, 63.5, -64.0, 32.0];
1600 assert_eq!(output, expected);
1601 assert_eq!(output_chunked, expected);
1602
1603 let quant = Quantization::new(0.021287762, 31);
1605 let input: Vec<i8> = vec![-128, -1, 0, 1, 31, 127];
1606 let mut output = vec![0.0f32; 6];
1607 let mut output_chunked = vec![0.0f32; 6];
1608 dequantize_cpu(&input, quant, &mut output);
1609 dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1610 for i in 0..6 {
1611 let expected = (input[i] as f32 - 31.0) * 0.021287762;
1612 assert!(
1613 (output[i] - expected).abs() < 1e-5,
1614 "cpu[{i}]: {} != {expected}",
1615 output[i]
1616 );
1617 assert!(
1618 (output_chunked[i] - expected).abs() < 1e-5,
1619 "chunked[{i}]: {} != {expected}",
1620 output_chunked[i]
1621 );
1622 }
1623 }
1624
1625 #[test]
1626 fn test_decoder_yolo_det() {
1627 let score_threshold = 0.25;
1628 let iou_threshold = 0.7;
1629 let out = load_yolov8s_det();
1630 let quant = (0.0040811873, -123).into();
1631
1632 let decoder = DecoderBuilder::default()
1633 .with_config_yolo_det(
1634 configs::Detection {
1635 decoder: DecoderType::Ultralytics,
1636 shape: vec![1, 84, 8400],
1637 anchors: None,
1638 quantization: Some(quant),
1639 dshape: vec![
1640 (DimName::Batch, 1),
1641 (DimName::NumFeatures, 84),
1642 (DimName::NumBoxes, 8400),
1643 ],
1644 normalized: Some(true),
1645 },
1646 Some(DecoderVersion::Yolo11),
1647 )
1648 .with_score_threshold(score_threshold)
1649 .with_iou_threshold(iou_threshold)
1650 .build()
1651 .unwrap();
1652
1653 let mut output_boxes: Vec<_> = Vec::with_capacity(50);
1654 decode_yolo_det(
1655 (out.slice(s![0, .., ..]), quant.into()),
1656 score_threshold,
1657 iou_threshold,
1658 Some(configs::Nms::ClassAgnostic),
1659 &mut output_boxes,
1660 );
1661 assert!(output_boxes[0].equal_within_delta(
1662 &DetectBox {
1663 bbox: BoundingBox {
1664 xmin: 0.5285137,
1665 ymin: 0.05305544,
1666 xmax: 0.87541467,
1667 ymax: 0.9998909,
1668 },
1669 score: 0.5591227,
1670 label: 0
1671 },
1672 1e-6
1673 ));
1674
1675 assert!(output_boxes[1].equal_within_delta(
1676 &DetectBox {
1677 bbox: BoundingBox {
1678 xmin: 0.130598,
1679 ymin: 0.43260583,
1680 xmax: 0.35098213,
1681 ymax: 0.9958097,
1682 },
1683 score: 0.33057618,
1684 label: 75
1685 },
1686 1e-6
1687 ));
1688
1689 let mut output_boxes1: Vec<_> = Vec::with_capacity(50);
1690 let mut output_masks1: Vec<_> = Vec::with_capacity(50);
1691 decoder
1692 .decode_quantized(&[out.view().into()], &mut output_boxes1, &mut output_masks1)
1693 .unwrap();
1694
1695 let out = dequantize_ndarray(out.view(), quant.into());
1696 let mut output_boxes_f32: Vec<_> = Vec::with_capacity(50);
1697 let mut output_masks_f32: Vec<_> = Vec::with_capacity(50);
1698 decoder
1699 .decode_float::<f32>(
1700 &[out.view().into_dyn()],
1701 &mut output_boxes_f32,
1702 &mut output_masks_f32,
1703 )
1704 .unwrap();
1705
1706 compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1707 compare_outputs((&output_boxes, &output_boxes_f32), (&[], &output_masks_f32));
1708 }
1709
1710 #[test]
1711 fn test_decoder_masks() {
1712 let score_threshold = 0.45;
1713 let iou_threshold = 0.45;
1714 let boxes = load_yolov8_boxes();
1715 let quant_boxes = Quantization::new(0.021287761628627777, 31);
1716
1717 let protos = load_yolov8_protos();
1718 let quant_protos = Quantization::new(0.02491161972284317, -117);
1719 let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1720 let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1721 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1722 let mut output_masks: Vec<_> = Vec::with_capacity(10);
1723 decode_yolo_segdet_float(
1724 seg.slice(s![0, .., ..]),
1725 protos.slice(s![0, .., .., ..]),
1726 score_threshold,
1727 iou_threshold,
1728 Some(configs::Nms::ClassAgnostic),
1729 &mut output_boxes,
1730 &mut output_masks,
1731 )
1732 .unwrap();
1733 assert_eq!(output_boxes.len(), 2);
1734 assert_eq!(output_boxes.len(), output_masks.len());
1735
1736 for (b, m) in output_boxes.iter().zip(&output_masks) {
1737 assert!(b.bbox.xmin >= m.xmin);
1738 assert!(b.bbox.ymin >= m.ymin);
1739 assert!(b.bbox.xmax >= m.xmax);
1740 assert!(b.bbox.ymax >= m.ymax);
1741 }
1742 assert!(output_boxes[0].equal_within_delta(
1743 &DetectBox {
1744 bbox: BoundingBox {
1745 xmin: 0.08515105,
1746 ymin: 0.7131401,
1747 xmax: 0.29802868,
1748 ymax: 0.8195788,
1749 },
1750 score: 0.91537374,
1751 label: 23
1752 },
1753 1.0 / 160.0, ));
1755
1756 assert!(output_boxes[1].equal_within_delta(
1757 &DetectBox {
1758 bbox: BoundingBox {
1759 xmin: 0.59605736,
1760 ymin: 0.25545314,
1761 xmax: 0.93666154,
1762 ymax: 0.72378385,
1763 },
1764 score: 0.91537374,
1765 label: 23
1766 },
1767 1.0 / 160.0, ));
1769
1770 let full_mask = include_bytes!(concat!(
1771 env!("CARGO_MANIFEST_DIR"),
1772 "/../../testdata/yolov8_mask_results.bin"
1773 ));
1774 let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();
1775
1776 let cropped_mask = full_mask.slice(ndarray::s![
1777 (output_masks[1].ymin * 160.0) as usize..(output_masks[1].ymax * 160.0) as usize,
1778 (output_masks[1].xmin * 160.0) as usize..(output_masks[1].xmax * 160.0) as usize,
1779 ]);
1780
1781 assert_eq!(
1782 cropped_mask,
1783 segmentation_to_mask(output_masks[1].segmentation.view()).unwrap()
1784 );
1785 }
1786
1787 #[test]
1792 fn test_decoder_masks_nchw_protos() {
1793 let score_threshold = 0.45;
1794 let iou_threshold = 0.45;
1795
1796 let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
1798 let quant_boxes = Quantization::new(0.021287761628627777, 31);
1799
1800 let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
1802 let quant_protos = Quantization::new(0.02491161972284317, -117);
1803 let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
1804
1805 let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
1807 let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
1808 let mut ref_masks: Vec<_> = Vec::with_capacity(10);
1809 decode_yolo_segdet_float(
1810 seg.view(),
1811 protos_f32_hwc.view(),
1812 score_threshold,
1813 iou_threshold,
1814 Some(configs::Nms::ClassAgnostic),
1815 &mut ref_boxes,
1816 &mut ref_masks,
1817 )
1818 .unwrap();
1819 assert_eq!(ref_boxes.len(), 2);
1820
1821 let protos_f32_chw = protos_f32_hwc.permuted_axes([2, 0, 1]); let protos_nchw = protos_f32_chw.insert_axis(ndarray::Axis(0)); let seg_3d = seg.insert_axis(ndarray::Axis(0)); let decoder = DecoderBuilder::default()
1831 .with_config_yolo_segdet(
1832 configs::Detection {
1833 decoder: configs::DecoderType::Ultralytics,
1834 quantization: None,
1835 shape: vec![1, 116, 8400],
1836 dshape: vec![],
1837 normalized: Some(true),
1838 anchors: None,
1839 },
1840 configs::Protos {
1841 decoder: configs::DecoderType::Ultralytics,
1842 quantization: None,
1843 shape: vec![1, 32, 160, 160],
1844 dshape: vec![], },
1846 None, )
1848 .with_score_threshold(score_threshold)
1849 .with_iou_threshold(iou_threshold)
1850 .build()
1851 .unwrap();
1852
1853 let mut cfg_boxes: Vec<_> = Vec::with_capacity(10);
1854 let mut cfg_masks: Vec<_> = Vec::with_capacity(10);
1855 decoder
1856 .decode_float(
1857 &[seg_3d.view().into_dyn(), protos_nchw.view().into_dyn()],
1858 &mut cfg_boxes,
1859 &mut cfg_masks,
1860 )
1861 .unwrap();
1862
1863 assert_eq!(
1865 cfg_boxes.len(),
1866 ref_boxes.len(),
1867 "config path produced {} boxes, reference produced {}",
1868 cfg_boxes.len(),
1869 ref_boxes.len()
1870 );
1871
1872 for (i, (cb, rb)) in cfg_boxes.iter().zip(&ref_boxes).enumerate() {
1874 assert!(
1875 cb.equal_within_delta(rb, 0.01),
1876 "box {i} mismatch: config={cb:?}, reference={rb:?}"
1877 );
1878 }
1879
1880 for (i, (cm, rm)) in cfg_masks.iter().zip(&ref_masks).enumerate() {
1882 let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
1883 let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
1884 assert_eq!(
1885 cm_arr, rm_arr,
1886 "mask {i} pixel mismatch between config-driven and reference paths"
1887 );
1888 }
1889 }
1890
1891 #[test]
1892 fn test_decoder_masks_i8() {
1893 let score_threshold = 0.45;
1894 let iou_threshold = 0.45;
1895 let boxes = load_yolov8_boxes();
1896 let quant_boxes = (0.021287761628627777, 31).into();
1897
1898 let protos = load_yolov8_protos();
1899 let quant_protos = (0.02491161972284317, -117).into();
1900 let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1901 let mut output_masks: Vec<_> = Vec::with_capacity(500);
1902
1903 let decoder = DecoderBuilder::default()
1904 .with_config_yolo_segdet(
1905 configs::Detection {
1906 decoder: configs::DecoderType::Ultralytics,
1907 quantization: Some(quant_boxes),
1908 shape: vec![1, 116, 8400],
1909 anchors: None,
1910 dshape: vec![
1911 (DimName::Batch, 1),
1912 (DimName::NumFeatures, 116),
1913 (DimName::NumBoxes, 8400),
1914 ],
1915 normalized: Some(true),
1916 },
1917 Protos {
1918 decoder: configs::DecoderType::Ultralytics,
1919 quantization: Some(quant_protos),
1920 shape: vec![1, 160, 160, 32],
1921 dshape: vec![
1922 (DimName::Batch, 1),
1923 (DimName::Height, 160),
1924 (DimName::Width, 160),
1925 (DimName::NumProtos, 32),
1926 ],
1927 },
1928 Some(DecoderVersion::Yolo11),
1929 )
1930 .with_score_threshold(score_threshold)
1931 .with_iou_threshold(iou_threshold)
1932 .build()
1933 .unwrap();
1934
1935 let quant_boxes = quant_boxes.into();
1936 let quant_protos = quant_protos.into();
1937
1938 decode_yolo_segdet_quant(
1939 (boxes.slice(s![0, .., ..]), quant_boxes),
1940 (protos.slice(s![0, .., .., ..]), quant_protos),
1941 score_threshold,
1942 iou_threshold,
1943 Some(configs::Nms::ClassAgnostic),
1944 &mut output_boxes,
1945 &mut output_masks,
1946 )
1947 .unwrap();
1948
1949 let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
1950 let mut output_masks1: Vec<_> = Vec::with_capacity(500);
1951
1952 decoder
1953 .decode_quantized(
1954 &[boxes.view().into(), protos.view().into()],
1955 &mut output_boxes1,
1956 &mut output_masks1,
1957 )
1958 .unwrap();
1959
1960 let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1961 let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1962
1963 let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
1964 let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
1965 decode_yolo_segdet_float(
1966 seg.slice(s![0, .., ..]),
1967 protos.slice(s![0, .., .., ..]),
1968 score_threshold,
1969 iou_threshold,
1970 Some(configs::Nms::ClassAgnostic),
1971 &mut output_boxes_f32,
1972 &mut output_masks_f32,
1973 )
1974 .unwrap();
1975
1976 let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
1977 let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);
1978
1979 decoder
1980 .decode_float(
1981 &[seg.view().into_dyn(), protos.view().into_dyn()],
1982 &mut output_boxes1_f32,
1983 &mut output_masks1_f32,
1984 )
1985 .unwrap();
1986
1987 compare_outputs(
1988 (&output_boxes, &output_boxes1),
1989 (&output_masks, &output_masks1),
1990 );
1991
1992 compare_outputs(
1993 (&output_boxes, &output_boxes_f32),
1994 (&output_masks, &output_masks_f32),
1995 );
1996
1997 compare_outputs(
1998 (&output_boxes_f32, &output_boxes1_f32),
1999 (&output_masks_f32, &output_masks1_f32),
2000 );
2001 }
2002
2003 #[test]
2004 fn test_decoder_yolo_split() {
2005 let score_threshold = 0.45;
2006 let iou_threshold = 0.45;
2007 let boxes = load_yolov8_boxes();
2008 let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
2009 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2010
2011 let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
2012
2013 let decoder = DecoderBuilder::default()
2014 .with_config_yolo_split_det(
2015 configs::Boxes {
2016 decoder: configs::DecoderType::Ultralytics,
2017 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2018 shape: vec![1, 4, 8400],
2019 dshape: vec![
2020 (DimName::Batch, 1),
2021 (DimName::BoxCoords, 4),
2022 (DimName::NumBoxes, 8400),
2023 ],
2024 normalized: Some(true),
2025 },
2026 configs::Scores {
2027 decoder: configs::DecoderType::Ultralytics,
2028 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2029 shape: vec![1, 80, 8400],
2030 dshape: vec![
2031 (DimName::Batch, 1),
2032 (DimName::NumClasses, 80),
2033 (DimName::NumBoxes, 8400),
2034 ],
2035 },
2036 )
2037 .with_score_threshold(score_threshold)
2038 .with_iou_threshold(iou_threshold)
2039 .build()
2040 .unwrap();
2041
2042 let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2043 let mut output_masks: Vec<_> = Vec::with_capacity(500);
2044
2045 decoder
2046 .decode_quantized(
2047 &[
2048 boxes.slice(s![.., ..4, ..]).into(),
2049 boxes.slice(s![.., 4..84, ..]).into(),
2050 ],
2051 &mut output_boxes,
2052 &mut output_masks,
2053 )
2054 .unwrap();
2055
2056 let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2057 let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2058 decode_yolo_det_float(
2059 seg.slice(s![0, ..84, ..]),
2060 score_threshold,
2061 iou_threshold,
2062 Some(configs::Nms::ClassAgnostic),
2063 &mut output_boxes_f32,
2064 );
2065
2066 let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2067 let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2068
2069 decoder
2070 .decode_float(
2071 &[
2072 seg.slice(s![.., ..4, ..]).into_dyn(),
2073 seg.slice(s![.., 4..84, ..]).into_dyn(),
2074 ],
2075 &mut output_boxes1,
2076 &mut output_masks1,
2077 )
2078 .unwrap();
2079 compare_outputs((&output_boxes, &output_boxes_f32), (&output_masks, &[]));
2080 compare_outputs((&output_boxes_f32, &output_boxes1), (&[], &output_masks1));
2081 }
2082
2083 #[test]
2084 fn test_decoder_masks_config_mixed() {
2085 let score_threshold = 0.45;
2086 let iou_threshold = 0.45;
2087 let boxes_raw = load_yolov8_boxes();
2088 let boxes: Vec<_> = boxes_raw.iter().map(|x| *x as i16 * 256).collect();
2089 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2090
2091 let quant_boxes = (0.021287761628627777 / 256.0, 31 * 256);
2092
2093 let protos = load_yolov8_protos();
2094 let quant_protos = (0.02491161972284317, -117);
2095
2096 let decoder = build_yolo_split_segdet_decoder(
2097 score_threshold,
2098 iou_threshold,
2099 quant_boxes,
2100 quant_protos,
2101 );
2102 let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2103 let mut output_masks: Vec<_> = Vec::with_capacity(500);
2104
2105 decoder
2106 .decode_quantized(
2107 &[
2108 boxes.slice(s![.., ..4, ..]).into(),
2109 boxes.slice(s![.., 4..84, ..]).into(),
2110 boxes.slice(s![.., 84.., ..]).into(),
2111 protos.view().into(),
2112 ],
2113 &mut output_boxes,
2114 &mut output_masks,
2115 )
2116 .unwrap();
2117
2118 let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos.into());
2119 let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes.into());
2120 let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2121 let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
2122 decode_yolo_segdet_float(
2123 seg.slice(s![0, .., ..]),
2124 protos.slice(s![0, .., .., ..]),
2125 score_threshold,
2126 iou_threshold,
2127 Some(configs::Nms::ClassAgnostic),
2128 &mut output_boxes_f32,
2129 &mut output_masks_f32,
2130 )
2131 .unwrap();
2132
2133 let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2134 let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2135
2136 decoder
2137 .decode_float(
2138 &[
2139 seg.slice(s![.., ..4, ..]).into_dyn(),
2140 seg.slice(s![.., 4..84, ..]).into_dyn(),
2141 seg.slice(s![.., 84.., ..]).into_dyn(),
2142 protos.view().into_dyn(),
2143 ],
2144 &mut output_boxes1,
2145 &mut output_masks1,
2146 )
2147 .unwrap();
2148 compare_outputs(
2149 (&output_boxes, &output_boxes_f32),
2150 (&output_masks, &output_masks_f32),
2151 );
2152 compare_outputs(
2153 (&output_boxes_f32, &output_boxes1),
2154 (&output_masks_f32, &output_masks1),
2155 );
2156 }
2157
2158 fn build_yolo_split_segdet_decoder(
2159 score_threshold: f32,
2160 iou_threshold: f32,
2161 quant_boxes: (f32, i32),
2162 quant_protos: (f32, i32),
2163 ) -> crate::Decoder {
2164 DecoderBuilder::default()
2165 .with_config_yolo_split_segdet(
2166 configs::Boxes {
2167 decoder: configs::DecoderType::Ultralytics,
2168 quantization: Some(quant_boxes.into()),
2169 shape: vec![1, 4, 8400],
2170 dshape: vec![
2171 (DimName::Batch, 1),
2172 (DimName::BoxCoords, 4),
2173 (DimName::NumBoxes, 8400),
2174 ],
2175 normalized: Some(true),
2176 },
2177 configs::Scores {
2178 decoder: configs::DecoderType::Ultralytics,
2179 quantization: Some(quant_boxes.into()),
2180 shape: vec![1, 80, 8400],
2181 dshape: vec![
2182 (DimName::Batch, 1),
2183 (DimName::NumClasses, 80),
2184 (DimName::NumBoxes, 8400),
2185 ],
2186 },
2187 configs::MaskCoefficients {
2188 decoder: configs::DecoderType::Ultralytics,
2189 quantization: Some(quant_boxes.into()),
2190 shape: vec![1, 32, 8400],
2191 dshape: vec![
2192 (DimName::Batch, 1),
2193 (DimName::NumProtos, 32),
2194 (DimName::NumBoxes, 8400),
2195 ],
2196 },
2197 configs::Protos {
2198 decoder: configs::DecoderType::Ultralytics,
2199 quantization: Some(quant_protos.into()),
2200 shape: vec![1, 160, 160, 32],
2201 dshape: vec![
2202 (DimName::Batch, 1),
2203 (DimName::Height, 160),
2204 (DimName::Width, 160),
2205 (DimName::NumProtos, 32),
2206 ],
2207 },
2208 )
2209 .with_score_threshold(score_threshold)
2210 .with_iou_threshold(iou_threshold)
2211 .build()
2212 .unwrap()
2213 }
2214
2215 fn build_yolov8_seg_decoder(score_threshold: f32, iou_threshold: f32) -> crate::Decoder {
2216 let config_yaml = include_str!(concat!(
2217 env!("CARGO_MANIFEST_DIR"),
2218 "/../../testdata/yolov8_seg.yaml"
2219 ));
2220 DecoderBuilder::default()
2221 .with_config_yaml_str(config_yaml.to_string())
2222 .with_score_threshold(score_threshold)
2223 .with_iou_threshold(iou_threshold)
2224 .build()
2225 .unwrap()
2226 }
2227 #[test]
2228 fn test_decoder_masks_config_i32() {
2229 let score_threshold = 0.45;
2230 let iou_threshold = 0.45;
2231 let boxes_raw = load_yolov8_boxes();
2232 let scale = 1 << 23;
2233 let boxes: Vec<_> = boxes_raw.iter().map(|x| *x as i32 * scale).collect();
2234 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2235
2236 let quant_boxes = (0.021287761628627777 / scale as f32, 31 * scale);
2237
2238 let protos_raw = load_yolov8_protos();
2239 let protos: Vec<_> = protos_raw.iter().map(|x| *x as i32 * scale).collect();
2240 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
2241 let quant_protos = (0.02491161972284317 / scale as f32, -117 * scale);
2242
2243 let decoder = build_yolo_split_segdet_decoder(
2244 score_threshold,
2245 iou_threshold,
2246 quant_boxes,
2247 quant_protos,
2248 );
2249
2250 let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2251 let mut output_masks: Vec<_> = Vec::with_capacity(500);
2252
2253 decoder
2254 .decode_quantized(
2255 &[
2256 boxes.slice(s![.., ..4, ..]).into(),
2257 boxes.slice(s![.., 4..84, ..]).into(),
2258 boxes.slice(s![.., 84.., ..]).into(),
2259 protos.view().into(),
2260 ],
2261 &mut output_boxes,
2262 &mut output_masks,
2263 )
2264 .unwrap();
2265
2266 let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos.into());
2267 let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes.into());
2268 let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2269 let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
2270 decode_yolo_segdet_float(
2271 seg.slice(s![0, .., ..]),
2272 protos.slice(s![0, .., .., ..]),
2273 score_threshold,
2274 iou_threshold,
2275 Some(configs::Nms::ClassAgnostic),
2276 &mut output_boxes_f32,
2277 &mut output_masks_f32,
2278 )
2279 .unwrap();
2280
2281 assert_eq!(output_boxes.len(), output_boxes_f32.len());
2282 assert_eq!(output_masks.len(), output_masks_f32.len());
2283
2284 compare_outputs(
2285 (&output_boxes, &output_boxes_f32),
2286 (&output_masks, &output_masks_f32),
2287 );
2288 }
2289
2290 #[test]
2292 fn test_context_switch() {
2293 let yolo_det = || {
2294 let score_threshold = 0.25;
2295 let iou_threshold = 0.7;
2296 let out = load_yolov8s_det();
2297 let quant = (0.0040811873, -123).into();
2298
2299 let decoder = DecoderBuilder::default()
2300 .with_config_yolo_det(
2301 configs::Detection {
2302 decoder: DecoderType::Ultralytics,
2303 shape: vec![1, 84, 8400],
2304 anchors: None,
2305 quantization: Some(quant),
2306 dshape: vec![
2307 (DimName::Batch, 1),
2308 (DimName::NumFeatures, 84),
2309 (DimName::NumBoxes, 8400),
2310 ],
2311 normalized: None,
2312 },
2313 None,
2314 )
2315 .with_score_threshold(score_threshold)
2316 .with_iou_threshold(iou_threshold)
2317 .build()
2318 .unwrap();
2319
2320 let mut output_boxes: Vec<_> = Vec::with_capacity(50);
2321 let mut output_masks: Vec<_> = Vec::with_capacity(50);
2322
2323 for _ in 0..100 {
2324 decoder
2325 .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
2326 .unwrap();
2327
2328 assert!(output_boxes[0].equal_within_delta(
2329 &DetectBox {
2330 bbox: BoundingBox {
2331 xmin: 0.5285137,
2332 ymin: 0.05305544,
2333 xmax: 0.87541467,
2334 ymax: 0.9998909,
2335 },
2336 score: 0.5591227,
2337 label: 0
2338 },
2339 1e-6
2340 ));
2341
2342 assert!(output_boxes[1].equal_within_delta(
2343 &DetectBox {
2344 bbox: BoundingBox {
2345 xmin: 0.130598,
2346 ymin: 0.43260583,
2347 xmax: 0.35098213,
2348 ymax: 0.9958097,
2349 },
2350 score: 0.33057618,
2351 label: 75
2352 },
2353 1e-6
2354 ));
2355 assert!(output_masks.is_empty());
2356 }
2357 };
2358
2359 let modelpack_det_split = || {
2360 let score_threshold = 0.8;
2361 let iou_threshold = 0.5;
2362
2363 let seg = include_bytes!(concat!(
2364 env!("CARGO_MANIFEST_DIR"),
2365 "/../../testdata/modelpack_seg_2x160x160.bin"
2366 ));
2367 let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
2368
2369 let detect0 = include_bytes!(concat!(
2370 env!("CARGO_MANIFEST_DIR"),
2371 "/../../testdata/modelpack_split_9x15x18.bin"
2372 ));
2373 let detect0 =
2374 ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
2375
2376 let detect1 = include_bytes!(concat!(
2377 env!("CARGO_MANIFEST_DIR"),
2378 "/../../testdata/modelpack_split_17x30x18.bin"
2379 ));
2380 let detect1 =
2381 ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
2382
2383 let mut mask = seg.slice(s![0, .., .., ..]);
2384 mask.swap_axes(0, 1);
2385 mask.swap_axes(1, 2);
2386 let mask = [Segmentation {
2387 xmin: 0.0,
2388 ymin: 0.0,
2389 xmax: 1.0,
2390 ymax: 1.0,
2391 segmentation: mask.into_owned(),
2392 }];
2393 let correct_boxes = [DetectBox {
2394 bbox: BoundingBox {
2395 xmin: 0.43171933,
2396 ymin: 0.68243736,
2397 xmax: 0.5626645,
2398 ymax: 0.808863,
2399 },
2400 score: 0.99240804,
2401 label: 0,
2402 }];
2403
2404 let quant0 = (0.08547406643629074, 174).into();
2405 let quant1 = (0.09929127991199493, 183).into();
2406 let quant_seg = (1.0 / 255.0, 0).into();
2407
2408 let anchors0 = vec![
2409 [0.36666667461395264, 0.31481480598449707],
2410 [0.38749998807907104, 0.4740740656852722],
2411 [0.5333333611488342, 0.644444465637207],
2412 ];
2413 let anchors1 = vec![
2414 [0.13750000298023224, 0.2074074000120163],
2415 [0.2541666626930237, 0.21481481194496155],
2416 [0.23125000298023224, 0.35185185074806213],
2417 ];
2418
2419 let decoder = DecoderBuilder::default()
2420 .with_config_modelpack_segdet_split(
2421 vec![
2422 configs::Detection {
2423 decoder: DecoderType::ModelPack,
2424 shape: vec![1, 17, 30, 18],
2425 anchors: Some(anchors1),
2426 quantization: Some(quant1),
2427 dshape: vec![
2428 (DimName::Batch, 1),
2429 (DimName::Height, 17),
2430 (DimName::Width, 30),
2431 (DimName::NumAnchorsXFeatures, 18),
2432 ],
2433 normalized: None,
2434 },
2435 configs::Detection {
2436 decoder: DecoderType::ModelPack,
2437 shape: vec![1, 9, 15, 18],
2438 anchors: Some(anchors0),
2439 quantization: Some(quant0),
2440 dshape: vec![
2441 (DimName::Batch, 1),
2442 (DimName::Height, 9),
2443 (DimName::Width, 15),
2444 (DimName::NumAnchorsXFeatures, 18),
2445 ],
2446 normalized: None,
2447 },
2448 ],
2449 configs::Segmentation {
2450 decoder: DecoderType::ModelPack,
2451 quantization: Some(quant_seg),
2452 shape: vec![1, 2, 160, 160],
2453 dshape: vec![
2454 (DimName::Batch, 1),
2455 (DimName::NumClasses, 2),
2456 (DimName::Height, 160),
2457 (DimName::Width, 160),
2458 ],
2459 },
2460 )
2461 .with_score_threshold(score_threshold)
2462 .with_iou_threshold(iou_threshold)
2463 .build()
2464 .unwrap();
2465 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
2466 let mut output_masks: Vec<_> = Vec::with_capacity(10);
2467
2468 for _ in 0..100 {
2469 decoder
2470 .decode_quantized(
2471 &[
2472 detect0.view().into(),
2473 detect1.view().into(),
2474 seg.view().into(),
2475 ],
2476 &mut output_boxes,
2477 &mut output_masks,
2478 )
2479 .unwrap();
2480
2481 compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
2482 }
2483 };
2484
2485 let handles = vec![
2486 std::thread::spawn(yolo_det),
2487 std::thread::spawn(modelpack_det_split),
2488 std::thread::spawn(yolo_det),
2489 std::thread::spawn(modelpack_det_split),
2490 std::thread::spawn(yolo_det),
2491 std::thread::spawn(modelpack_det_split),
2492 std::thread::spawn(yolo_det),
2493 std::thread::spawn(modelpack_det_split),
2494 ];
2495 for handle in handles {
2496 handle.join().unwrap();
2497 }
2498 }
2499
2500 #[test]
2501 fn test_ndarray_to_xyxy_float() {
2502 let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2503 let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
2504 assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
2505
2506 let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2507 let xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(arr.view());
2508 assert_eq!(xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
2509 }
2510
2511 #[test]
2512 fn test_class_aware_nms_float() {
2513 use crate::float::nms_class_aware_float;
2514
2515 let boxes = vec![
2517 DetectBox {
2518 bbox: BoundingBox {
2519 xmin: 0.0,
2520 ymin: 0.0,
2521 xmax: 0.5,
2522 ymax: 0.5,
2523 },
2524 score: 0.9,
2525 label: 0, },
2527 DetectBox {
2528 bbox: BoundingBox {
2529 xmin: 0.1,
2530 ymin: 0.1,
2531 xmax: 0.6,
2532 ymax: 0.6,
2533 },
2534 score: 0.8,
2535 label: 1, },
2537 ];
2538
2539 let result = nms_class_aware_float(0.3, boxes.clone());
2542 assert_eq!(
2543 result.len(),
2544 2,
2545 "Class-aware NMS should keep both boxes with different classes"
2546 );
2547
2548 let same_class_boxes = vec![
2550 DetectBox {
2551 bbox: BoundingBox {
2552 xmin: 0.0,
2553 ymin: 0.0,
2554 xmax: 0.5,
2555 ymax: 0.5,
2556 },
2557 score: 0.9,
2558 label: 0,
2559 },
2560 DetectBox {
2561 bbox: BoundingBox {
2562 xmin: 0.1,
2563 ymin: 0.1,
2564 xmax: 0.6,
2565 ymax: 0.6,
2566 },
2567 score: 0.8,
2568 label: 0, },
2570 ];
2571
2572 let result = nms_class_aware_float(0.3, same_class_boxes);
2573 assert_eq!(
2574 result.len(),
2575 1,
2576 "Class-aware NMS should suppress overlapping box with same class"
2577 );
2578 assert_eq!(result[0].label, 0);
2579 assert!((result[0].score - 0.9).abs() < 1e-6);
2580 }
2581
/// Contrasts class-agnostic and class-aware NMS on two overlapping boxes
/// (IoU ~= 0.47 > 0.3) that carry different labels.
#[test]
fn test_class_agnostic_vs_aware_nms() {
    use crate::float::{nms_class_aware_float, nms_float};

    let make_box = |xmin, ymin, xmax, ymax, score, label| DetectBox {
        bbox: BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        },
        score,
        label,
    };
    let overlapping = vec![
        make_box(0.0, 0.0, 0.5, 0.5, 0.9, 0),
        make_box(0.1, 0.1, 0.6, 0.6, 0.8, 1),
    ];

    // Labels are ignored here, so the weaker box is suppressed.
    let agnostic_kept = nms_float(0.3, overlapping.clone());
    assert_eq!(
        agnostic_kept.len(),
        1,
        "Class-agnostic NMS should suppress overlapping boxes"
    );

    // Labels differ, so neither box competes with the other.
    let aware_kept = nms_class_aware_float(0.3, overlapping);
    assert_eq!(
        aware_kept.len(),
        2,
        "Class-aware NMS should keep boxes with different classes"
    );
}
2626
/// Class-aware NMS over quantized (`u8` score) boxes must not suppress an
/// overlapping box that belongs to a different class.
#[test]
fn test_class_aware_nms_int() {
    use crate::byte::nms_class_aware_int;

    let quantized_box = |xmin, ymin, xmax, ymax, score, label| DetectBoxQuantized {
        bbox: BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        },
        score,
        label,
    };
    let candidates = vec![
        quantized_box(0.0, 0.0, 0.5, 0.5, 200_u8, 0),
        quantized_box(0.1, 0.1, 0.6, 0.6, 180_u8, 1),
    ];

    let kept = nms_class_aware_int(0.5, candidates);
    assert_eq!(
        kept.len(),
        2,
        "Class-aware NMS (int) should keep boxes with different classes"
    );
}
2663
/// `Nms::default()` selects class-agnostic suppression.
#[test]
fn test_nms_enum_default() {
    assert_eq!(configs::Nms::default(), configs::Nms::ClassAgnostic);
}
2670
/// An explicitly requested NMS mode is stored on the built decoder.
#[test]
fn test_decoder_nms_mode() {
    let detection = configs::Detection {
        anchors: None,
        decoder: DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 84, 8400],
        dshape: Vec::new(),
        normalized: Some(true),
    };

    let decoder = DecoderBuilder::default()
        .with_config_yolo_det(detection, None)
        .with_nms(Some(configs::Nms::ClassAware))
        .build()
        .unwrap();

    assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
}
2692
/// Passing `None` to `with_nms` leaves the built decoder with no NMS mode.
#[test]
fn test_decoder_nms_bypass() {
    let detection = configs::Detection {
        anchors: None,
        decoder: DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 84, 8400],
        dshape: Vec::new(),
        normalized: Some(true),
    };

    let decoder = DecoderBuilder::default()
        .with_config_yolo_det(detection, None)
        .with_nms(None)
        .build()
        .unwrap();

    assert!(decoder.nms.is_none());
}
2714
/// `normalized: Some(true)` in the detection config is reported by `normalized_boxes()`.
#[test]
fn test_decoder_normalized_boxes_true() {
    let detection = configs::Detection {
        anchors: None,
        decoder: DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 84, 8400],
        dshape: Vec::new(),
        normalized: Some(true),
    };

    let decoder = DecoderBuilder::default()
        .with_config_yolo_det(detection, None)
        .build()
        .unwrap();

    assert_eq!(decoder.normalized_boxes(), Some(true));
}
2735
/// `normalized: Some(false)` in the detection config is reported by `normalized_boxes()`.
#[test]
fn test_decoder_normalized_boxes_false() {
    let detection = configs::Detection {
        anchors: None,
        decoder: DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 84, 8400],
        dshape: Vec::new(),
        normalized: Some(false),
    };

    let decoder = DecoderBuilder::default()
        .with_config_yolo_det(detection, None)
        .build()
        .unwrap();

    assert_eq!(decoder.normalized_boxes(), Some(false));
}
2757
/// With `normalized: None` (and an explicit Yolo11 version) the decoder
/// reports the coordinate normalization as unknown.
#[test]
fn test_decoder_normalized_boxes_unknown() {
    let detection = configs::Detection {
        anchors: None,
        decoder: DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 84, 8400],
        dshape: Vec::new(),
        normalized: None,
    };
    let version = Some(DecoderVersion::Yolo11);

    let decoder = DecoderBuilder::default()
        .with_config_yolo_det(detection, version)
        .build()
        .unwrap();

    assert_eq!(decoder.normalized_boxes(), None);
}
2778
2779 pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
2780 input: ArrayView<F, D>,
2781 quant: Quantization,
2782 ) -> Array<T, D>
2783 where
2784 i32: num_traits::AsPrimitive<F>,
2785 f32: num_traits::AsPrimitive<F>,
2786 {
2787 let zero_point = quant.zero_point.as_();
2788 let div_scale = F::one() / quant.scale.as_();
2789 if zero_point != F::zero() {
2790 input.mapv(|d| (d * div_scale + zero_point).round().as_())
2791 } else {
2792 input.mapv(|d| (d * div_scale).round().as_())
2793 }
2794 }
2795
2796 fn real_data_expected_boxes() -> [DetectBox; 2] {
2797 [
2798 DetectBox {
2799 bbox: BoundingBox {
2800 xmin: 0.08515105,
2801 ymin: 0.7131401,
2802 xmax: 0.29802868,
2803 ymax: 0.8195788,
2804 },
2805 score: 0.91537374,
2806 label: 23,
2807 },
2808 DetectBox {
2809 bbox: BoundingBox {
2810 xmin: 0.59605736,
2811 ymin: 0.25545314,
2812 xmax: 0.93666154,
2813 ymax: 0.72378385,
2814 },
2815 score: 0.91537374,
2816 label: 23,
2817 },
2818 ]
2819 }
2820
2821 fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
2822 [DetectBox {
2823 bbox: BoundingBox {
2824 xmin: 0.12549022,
2825 ymin: 0.12549022,
2826 xmax: 0.23529413,
2827 ymax: 0.23529413,
2828 },
2829 score: 0.98823535,
2830 label: 2,
2831 }]
2832 }
2833
2834 fn e2e_expected_boxes_float() -> [DetectBox; 1] {
2835 [DetectBox {
2836 bbox: BoundingBox {
2837 xmin: 0.1234,
2838 ymin: 0.1234,
2839 xmax: 0.2345,
2840 ymax: 0.2345,
2841 },
2842 score: 0.9876,
2843 label: 2,
2844 }]
2845 }
2846
/// Generates a real-data YOLOv8 segmentation test that decodes boxes plus
/// prototype data through the `*_proto` decoder entry points.
///
/// Arms:
/// - `($name, quantized, combined|split)`: feeds the raw int8 fixtures to
///   `decode_quantized_proto`.
/// - `($name, float, combined|split)`: dequantizes the fixtures to `f32`
///   first and calls `decode_float_proto`.
///
/// With `split`, the `[1, 116, 8400]` tensor is sliced into boxes (rows `..4`),
/// scores (rows `4..84`) and mask coefficients (rows `84..`). With `combined`,
/// it is passed whole. Every variant must reproduce `real_data_expected_boxes()`
/// within a tolerance of 1/160.
macro_rules! real_data_proto_test {
    ($name:ident, quantized, $layout:ident) => {
        #[test]
        fn $name() {
            // `$layout` is an identifier token; compare its spelling.
            let is_split = matches!(stringify!($layout), "split");

            let score_threshold = 0.45;
            let iou_threshold = 0.45;
            // (scale, zero_point) of the int8 fixtures; in this arm only the
            // split decoder builder consumes them.
            let quant_boxes = (0.021287762_f32, 31_i32);
            let quant_protos = (0.02491162_f32, -117_i32);

            let raw_boxes = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/yolov8_boxes_116x8400.bin"
            ));
            // SAFETY: `u8` and `i8` have identical size and alignment, and the
            // pointer/length come from a `'static` byte array.
            let raw_boxes = unsafe {
                std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
            };
            let boxes_i8 =
                ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();

            let raw_protos = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/yolov8_protos_160x160x32.bin"
            ));
            // SAFETY: same reinterpretation of a `'static` byte array as above.
            let raw_protos = unsafe {
                std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
            };
            let protos_i8 =
                ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
                    .unwrap();

            // Split layout: slice the combined feature axis into its three parts.
            let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
            let scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
            let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
            let boxes_combined = boxes_i8;

            let decoder = if is_split {
                build_yolo_split_segdet_decoder(
                    score_threshold,
                    iou_threshold,
                    quant_boxes,
                    quant_protos,
                )
            } else {
                build_yolov8_seg_decoder(score_threshold, iou_threshold)
            };

            let expected = real_data_expected_boxes();
            let mut output_boxes = Vec::with_capacity(50);

            // Input order: boxes, scores, mask coefficients, protos (split)
            // or combined detections, protos.
            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
                vec![
                    boxes_split.view().into(),
                    scores_split.view().into(),
                    mask_split.view().into(),
                    protos_i8.view().into(),
                ]
            } else {
                vec![boxes_combined.view().into(), protos_i8.view().into()]
            };
            decoder
                .decode_quantized_proto(&inputs, &mut output_boxes)
                .unwrap();

            assert_eq!(output_boxes.len(), 2);
            assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
            assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
        }
    };
    ($name:ident, float, $layout:ident) => {
        #[test]
        fn $name() {
            // `$layout` is an identifier token; compare its spelling.
            let is_split = matches!(stringify!($layout), "split");

            let score_threshold = 0.45;
            let iou_threshold = 0.45;
            // (scale, zero_point) used to dequantize the int8 fixtures to f32.
            let quant_boxes = (0.021287762_f32, 31_i32);
            let quant_protos = (0.02491162_f32, -117_i32);

            let raw_boxes = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/yolov8_boxes_116x8400.bin"
            ));
            // SAFETY: `u8` and `i8` have identical size and alignment, and the
            // pointer/length come from a `'static` byte array.
            let raw_boxes = unsafe {
                std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
            };
            let boxes_i8 =
                ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
            let boxes_f32: Array3<f32> =
                dequantize_ndarray(boxes_i8.view(), quant_boxes.into());

            let raw_protos = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/yolov8_protos_160x160x32.bin"
            ));
            // SAFETY: same reinterpretation of a `'static` byte array as above.
            let raw_protos = unsafe {
                std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
            };
            let protos_i8 =
                ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
                    .unwrap();
            let protos_f32: Array4<f32> =
                dequantize_ndarray(protos_i8.view(), quant_protos.into());

            // Split layout: slice the combined feature axis into its three parts.
            let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
            let scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
            let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
            let boxes_combined = boxes_f32;

            let decoder = if is_split {
                build_yolo_split_segdet_decoder(
                    score_threshold,
                    iou_threshold,
                    quant_boxes,
                    quant_protos,
                )
            } else {
                build_yolov8_seg_decoder(score_threshold, iou_threshold)
            };

            let expected = real_data_expected_boxes();
            let mut output_boxes = Vec::with_capacity(50);

            let inputs = if is_split {
                vec![
                    boxes_split.view().into_dyn(),
                    scores_split.view().into_dyn(),
                    mask_split.view().into_dyn(),
                    protos_f32.view().into_dyn(),
                ]
            } else {
                vec![
                    boxes_combined.view().into_dyn(),
                    protos_f32.view().into_dyn(),
                ]
            };
            decoder
                .decode_float_proto(&inputs, &mut output_boxes)
                .unwrap();

            assert_eq!(output_boxes.len(), 2);
            assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
            assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
        }
    };
}
2996
2997 real_data_proto_test!(test_decoder_segdet_proto, quantized, combined);
2998 real_data_proto_test!(test_decoder_segdet_proto_float, float, combined);
2999 real_data_proto_test!(test_decoder_segdet_split_proto, quantized, split);
3000 real_data_proto_test!(test_decoder_segdet_split_proto_float, float, split);
3001
/// YAML model config for the combined end-to-end detection tests (decoder
/// version `yolo26`). It declares a single `[1, 10, 6]` detection output, i.e.
/// 4 box coordinates + score + class per box, quantized with scale 2/255 and
/// zero point 0.
const E2E_COMBINED_DET_CONFIG: &str = "
decoder_version: yolo26
outputs:
  - type: detection
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 6]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_features, 6]
    normalized: true
";
3015
/// YAML model config for the combined end-to-end segmentation-detection tests
/// (decoder version `yolo26`). It declares:
/// - a `[1, 10, 38]` detection output (4 box coords + score + class + 32 mask
///   coefficients), scale 2/255, zero point 0;
/// - a `[1, 160, 160, 32]` protos output, scale 1/255, zero point 128.
const E2E_COMBINED_SEGDET_CONFIG: &str = "
decoder_version: yolo26
outputs:
  - type: detection
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 38]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_features, 38]
    normalized: true
  - type: protos
    decoder: ultralytics
    quantization: [0.0039215686274509803921568627451, 128]
    shape: [1, 160, 160, 32]
    dshape:
      - [batch, 1]
      - [height, 160]
      - [width, 160]
      - [num_protos, 32]
";
3038
/// YAML model config for the split end-to-end detection tests (decoder version
/// `yolo26`). Boxes `[1, 10, 4]`, scores `[1, 10, 1]` and classes `[1, 10, 1]`
/// are separate outputs, all quantized with scale 2/255 and zero point 0.
const E2E_SPLIT_DET_CONFIG: &str = "
decoder_version: yolo26
outputs:
  - type: boxes
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 4]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [box_coords, 4]
    normalized: true
  - type: scores
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 1]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_classes, 1]
  - type: classes
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 1]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_classes, 1]
";
3068
/// YAML model config for the split end-to-end segmentation-detection tests
/// (decoder version `yolo26`). It declares:
/// - boxes `[1, 10, 4]`, scores `[1, 10, 1]`, classes `[1, 10, 1]` and mask
///   coefficients `[1, 10, 32]`, each with scale 2/255 and zero point 0;
/// - protos `[1, 160, 160, 32]`, scale 1/255, zero point 128.
const E2E_SPLIT_SEGDET_CONFIG: &str = "
decoder_version: yolo26
outputs:
  - type: boxes
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 4]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [box_coords, 4]
    normalized: true
  - type: scores
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 1]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_classes, 1]
  - type: classes
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 1]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_classes, 1]
  - type: mask_coefficients
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 32]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_protos, 32]
  - type: protos
    decoder: ultralytics
    quantization: [0.0039215686274509803921568627451, 128]
    shape: [1, 160, 160, 32]
    dshape:
      - [batch, 1]
      - [height, 160]
      - [width, 160]
      - [num_protos, 32]
";
3115
/// Generates an end-to-end segmentation-detection test on synthetic data.
///
/// The data is one box `[0.1234, 0.1234, 0.2345, 0.2345]` with score 0.9876
/// and class 2, followed by nine all-zero rows, plus all-zero mask
/// coefficients and all-zero protos.
///
/// - `quantized` / `float`: quantize the inputs to `u8` (detections at scale
///   2/255) and call `decode_quantized*`, or pass `f64` views to `decode_float*`.
/// - `$layout`: `combined` concatenates boxes/scores/classes/mask into
///   `[1, 10, 38]`; `split` passes the four tensors separately.
/// - `$output`: `proto` uses the `*_proto` entry points; any other spelling
///   (invocations use `masks`) calls the mask-producing decode.
///
/// Only the decoded boxes are asserted; `output_masks` is not inspected.
macro_rules! e2e_segdet_test {
    ($name:ident, quantized, $layout:ident, $output:ident) => {
        #[test]
        fn $name() {
            let is_split = matches!(stringify!($layout), "split");
            let is_proto = matches!(stringify!($output), "proto");

            let score_threshold = 0.45;
            let iou_threshold = 0.45;

            let mut boxes = Array2::zeros((10, 4));
            let mut scores = Array2::zeros((10, 1));
            let mut classes = Array2::zeros((10, 1));
            let mask = Array2::zeros((10, 32));
            let protos = Array3::<f64>::zeros((160, 160, 32));
            let protos = protos.insert_axis(Axis(0));
            // NOTE(review): protos are quantized with zero point 0, while the
            // segdet configs declare zero point 128 for protos. The protos are
            // all zero and only boxes are asserted, so this does not affect the
            // result. Confirm whether 128 was intended here.
            let protos_quant = (1.0 / 255.0, 0.0);
            let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());

            // Only row 0 holds a real detection; the other rows stay zero.
            boxes
                .slice_mut(s![0, ..])
                .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
            scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
            classes.slice_mut(s![0, ..]).assign(&array![2.0]);

            // Matches the 0.00784313725490196 (2/255), zero-point-0 entries in the configs.
            let detect_quant = (2.0 / 255.0, 0.0);

            let decoder = if is_split {
                DecoderBuilder::default()
                    .with_config_yaml_str(E2E_SPLIT_SEGDET_CONFIG.to_string())
                    .with_score_threshold(score_threshold)
                    .with_iou_threshold(iou_threshold)
                    .build()
                    .unwrap()
            } else {
                DecoderBuilder::default()
                    .with_config_yaml_str(E2E_COMBINED_SEGDET_CONFIG.to_string())
                    .with_score_threshold(score_threshold)
                    .with_iou_threshold(iou_threshold)
                    .build()
                    .unwrap()
            };

            let expected = e2e_expected_boxes_quant();
            let mut output_boxes = Vec::with_capacity(50);

            if is_split {
                // Add the batch axis to each separate output.
                let boxes = boxes.insert_axis(Axis(0));
                let scores = scores.insert_axis(Axis(0));
                let classes = classes.insert_axis(Axis(0));
                let mask = mask.insert_axis(Axis(0));

                let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
                let scores: Array3<u8> = quantize_ndarray(scores.view(), detect_quant.into());
                let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
                let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());

                if is_proto {
                    let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
                        boxes.view().into(),
                        scores.view().into(),
                        classes.view().into(),
                        mask.view().into(),
                        protos.view().into(),
                    ];
                    decoder
                        .decode_quantized_proto(&inputs, &mut output_boxes)
                        .unwrap();

                    assert_eq!(output_boxes.len(), 1);
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
                } else {
                    let mut output_masks = Vec::with_capacity(50);
                    let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
                        boxes.view().into(),
                        scores.view().into(),
                        classes.view().into(),
                        mask.view().into(),
                        protos.view().into(),
                    ];
                    decoder
                        .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
                        .unwrap();

                    assert_eq!(output_boxes.len(), 1);
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
                }
            } else {
                // Combined layout: features concatenated per box as
                // [box(4) | score(1) | class(1) | mask coefficients(32)].
                let detect = ndarray::concatenate![
                    Axis(1),
                    boxes.view(),
                    scores.view(),
                    classes.view(),
                    mask.view()
                ];
                let detect = detect.insert_axis(Axis(0));
                assert_eq!(detect.shape(), &[1, 10, 38]);
                let detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());

                if is_proto {
                    let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
                        vec![detect.view().into(), protos.view().into()];
                    decoder
                        .decode_quantized_proto(&inputs, &mut output_boxes)
                        .unwrap();

                    assert_eq!(output_boxes.len(), 1);
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
                } else {
                    let mut output_masks = Vec::with_capacity(50);
                    let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
                        vec![detect.view().into(), protos.view().into()];
                    decoder
                        .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
                        .unwrap();

                    assert_eq!(output_boxes.len(), 1);
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
                }
            }
        }
    };
    ($name:ident, float, $layout:ident, $output:ident) => {
        #[test]
        fn $name() {
            let is_split = matches!(stringify!($layout), "split");
            let is_proto = matches!(stringify!($output), "proto");

            let score_threshold = 0.45;
            let iou_threshold = 0.45;

            // `mask` is annotated as f64, and concatenation / the shared input
            // Vec force the other arrays to the same element type.
            let mut boxes = Array2::zeros((10, 4));
            let mut scores = Array2::zeros((10, 1));
            let mut classes = Array2::zeros((10, 1));
            let mask: Array2<f64> = Array2::zeros((10, 32));
            let protos = Array3::<f64>::zeros((160, 160, 32));
            let protos = protos.insert_axis(Axis(0));

            // Only row 0 holds a real detection; the other rows stay zero.
            boxes
                .slice_mut(s![0, ..])
                .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
            scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
            classes.slice_mut(s![0, ..]).assign(&array![2.0]);

            let decoder = if is_split {
                DecoderBuilder::default()
                    .with_config_yaml_str(E2E_SPLIT_SEGDET_CONFIG.to_string())
                    .with_score_threshold(score_threshold)
                    .with_iou_threshold(iou_threshold)
                    .build()
                    .unwrap()
            } else {
                DecoderBuilder::default()
                    .with_config_yaml_str(E2E_COMBINED_SEGDET_CONFIG.to_string())
                    .with_score_threshold(score_threshold)
                    .with_iou_threshold(iou_threshold)
                    .build()
                    .unwrap()
            };

            let expected = e2e_expected_boxes_float();
            let mut output_boxes = Vec::with_capacity(50);

            if is_split {
                // Add the batch axis to each separate output.
                let boxes = boxes.insert_axis(Axis(0));
                let scores = scores.insert_axis(Axis(0));
                let classes = classes.insert_axis(Axis(0));
                let mask = mask.insert_axis(Axis(0));

                if is_proto {
                    let inputs = vec![
                        boxes.view().into_dyn(),
                        scores.view().into_dyn(),
                        classes.view().into_dyn(),
                        mask.view().into_dyn(),
                        protos.view().into_dyn(),
                    ];
                    decoder
                        .decode_float_proto(&inputs, &mut output_boxes)
                        .unwrap();

                    assert_eq!(output_boxes.len(), 1);
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
                } else {
                    let mut output_masks = Vec::with_capacity(50);
                    let inputs = vec![
                        boxes.view().into_dyn(),
                        scores.view().into_dyn(),
                        classes.view().into_dyn(),
                        mask.view().into_dyn(),
                        protos.view().into_dyn(),
                    ];
                    decoder
                        .decode_float(&inputs, &mut output_boxes, &mut output_masks)
                        .unwrap();

                    assert_eq!(output_boxes.len(), 1);
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
                }
            } else {
                // Combined layout: [box(4) | score(1) | class(1) | mask coefficients(32)].
                let detect = ndarray::concatenate![
                    Axis(1),
                    boxes.view(),
                    scores.view(),
                    classes.view(),
                    mask.view()
                ];
                let detect = detect.insert_axis(Axis(0));
                assert_eq!(detect.shape(), &[1, 10, 38]);

                if is_proto {
                    let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
                    decoder
                        .decode_float_proto(&inputs, &mut output_boxes)
                        .unwrap();

                    assert_eq!(output_boxes.len(), 1);
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
                } else {
                    let mut output_masks = Vec::with_capacity(50);
                    let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
                    decoder
                        .decode_float(&inputs, &mut output_boxes, &mut output_masks)
                        .unwrap();

                    assert_eq!(output_boxes.len(), 1);
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
                }
            }
        }
    };
}
3350
3351 e2e_segdet_test!(test_decoder_end_to_end_segdet, quantized, combined, masks);
3352 e2e_segdet_test!(test_decoder_end_to_end_segdet_float, float, combined, masks);
3353 e2e_segdet_test!(
3354 test_decoder_end_to_end_segdet_proto,
3355 quantized,
3356 combined,
3357 proto
3358 );
3359 e2e_segdet_test!(
3360 test_decoder_end_to_end_segdet_proto_float,
3361 float,
3362 combined,
3363 proto
3364 );
3365 e2e_segdet_test!(
3366 test_decoder_end_to_end_segdet_split,
3367 quantized,
3368 split,
3369 masks
3370 );
3371 e2e_segdet_test!(
3372 test_decoder_end_to_end_segdet_split_float,
3373 float,
3374 split,
3375 masks
3376 );
3377 e2e_segdet_test!(
3378 test_decoder_end_to_end_segdet_split_proto,
3379 quantized,
3380 split,
3381 proto
3382 );
3383 e2e_segdet_test!(
3384 test_decoder_end_to_end_segdet_split_proto_float,
3385 float,
3386 split,
3387 proto
3388 );
3389
/// Generates an end-to-end detection-only test on synthetic data.
///
/// The data is one box `[0.1234, 0.1234, 0.2345, 0.2345]` with score 0.9876
/// and class 2, followed by nine all-zero rows.
///
/// - `quantized`: inputs are quantized to `u8` at scale 2/255, zero point 0.
///   The result must match `e2e_expected_boxes_quant()` within 1e-6.
/// - `float`: float arrays go straight to `decode_float`. The result must match
///   `e2e_expected_boxes_float()` within 1e-6.
/// - `$layout`: `combined` concatenates into `[1, 10, 6]`; `split` passes
///   boxes, scores and classes separately.
macro_rules! e2e_det_test {
    ($name:ident, quantized, $layout:ident) => {
        #[test]
        fn $name() {
            let is_split = matches!(stringify!($layout), "split");

            let score_threshold = 0.45;
            let iou_threshold = 0.45;

            let mut boxes = Array3::zeros((1, 10, 4));
            let mut scores = Array3::zeros((1, 10, 1));
            let mut classes = Array3::zeros((1, 10, 1));

            // Only box 0 holds a real detection.
            boxes
                .slice_mut(s![0, 0, ..])
                .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
            scores.slice_mut(s![0, 0, ..]).assign(&array![0.9876]);
            classes.slice_mut(s![0, 0, ..]).assign(&array![2.0]);

            // Matches the 0.00784313725490196 (2/255), zero-point-0 entries in the configs.
            let detect_quant = (2.0 / 255.0, 0_i32);

            let decoder = if is_split {
                DecoderBuilder::default()
                    .with_config_yaml_str(E2E_SPLIT_DET_CONFIG.to_string())
                    .with_score_threshold(score_threshold)
                    .with_iou_threshold(iou_threshold)
                    .build()
                    .unwrap()
            } else {
                DecoderBuilder::default()
                    .with_config_yaml_str(E2E_COMBINED_DET_CONFIG.to_string())
                    .with_score_threshold(score_threshold)
                    .with_iou_threshold(iou_threshold)
                    .build()
                    .unwrap()
            };

            let expected = e2e_expected_boxes_quant();
            let mut output_boxes = Vec::with_capacity(50);

            if is_split {
                let boxes: Array<u8, _> = quantize_ndarray(boxes.view(), detect_quant.into());
                let scores: Array<u8, _> = quantize_ndarray(scores.view(), detect_quant.into());
                let classes: Array<u8, _> =
                    quantize_ndarray(classes.view(), detect_quant.into());
                let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
                    boxes.view().into(),
                    scores.view().into(),
                    classes.view().into(),
                ];
                decoder
                    .decode_quantized(&inputs, &mut output_boxes, &mut Vec::new())
                    .unwrap();
            } else {
                // Combined layout: [box(4) | score(1) | class(1)] per box.
                let detect =
                    ndarray::concatenate![Axis(2), boxes.view(), scores.view(), classes.view()];
                assert_eq!(detect.shape(), &[1, 10, 6]);
                let detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
                let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
                    vec![detect.view().into()];
                decoder
                    .decode_quantized(&inputs, &mut output_boxes, &mut Vec::new())
                    .unwrap();
            }

            assert_eq!(output_boxes.len(), 1);
            assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
        }
    };
    ($name:ident, float, $layout:ident) => {
        #[test]
        fn $name() {
            let is_split = matches!(stringify!($layout), "split");

            let score_threshold = 0.45;
            let iou_threshold = 0.45;

            let mut boxes = Array3::zeros((1, 10, 4));
            let mut scores = Array3::zeros((1, 10, 1));
            let mut classes = Array3::zeros((1, 10, 1));

            // Only box 0 holds a real detection.
            boxes
                .slice_mut(s![0, 0, ..])
                .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
            scores.slice_mut(s![0, 0, ..]).assign(&array![0.9876]);
            classes.slice_mut(s![0, 0, ..]).assign(&array![2.0]);

            let decoder = if is_split {
                DecoderBuilder::default()
                    .with_config_yaml_str(E2E_SPLIT_DET_CONFIG.to_string())
                    .with_score_threshold(score_threshold)
                    .with_iou_threshold(iou_threshold)
                    .build()
                    .unwrap()
            } else {
                DecoderBuilder::default()
                    .with_config_yaml_str(E2E_COMBINED_DET_CONFIG.to_string())
                    .with_score_threshold(score_threshold)
                    .with_iou_threshold(iou_threshold)
                    .build()
                    .unwrap()
            };

            let expected = e2e_expected_boxes_float();
            let mut output_boxes = Vec::with_capacity(50);

            if is_split {
                let inputs = vec![
                    boxes.view().into_dyn(),
                    scores.view().into_dyn(),
                    classes.view().into_dyn(),
                ];
                decoder
                    .decode_float(&inputs, &mut output_boxes, &mut Vec::new())
                    .unwrap();
            } else {
                // Combined layout: [box(4) | score(1) | class(1)] per box.
                let detect =
                    ndarray::concatenate![Axis(2), boxes.view(), scores.view(), classes.view()];
                assert_eq!(detect.shape(), &[1, 10, 6]);
                let inputs = vec![detect.view().into_dyn()];
                decoder
                    .decode_float(&inputs, &mut output_boxes, &mut Vec::new())
                    .unwrap();
            }

            assert_eq!(output_boxes.len(), 1);
            assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
        }
    };
}
3520
3521 e2e_det_test!(test_decoder_end_to_end_combined_det, quantized, combined);
3522 e2e_det_test!(test_decoder_end_to_end_combined_det_float, float, combined);
3523 e2e_det_test!(test_decoder_end_to_end_split_det, quantized, split);
3524 e2e_det_test!(test_decoder_end_to_end_split_det_float, float, split);
3525
/// Decodes the real int8 YOLOv8 segmentation fixtures through the
/// `Tensor`-based `decode` entry point and checks the two reference boxes.
#[test]
fn test_decode_tensor() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    // The fixtures are raw int8 dumps. Reinterpret each byte with a safe,
    // bit-preserving `as i8` cast instead of an unchecked pointer cast.
    let raw_boxes: Vec<i8> = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ))
    .iter()
    .map(|&byte| byte as i8)
    .collect();
    let boxes_i8: Tensor<i8> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
    // `copy_from_slice` panics on a length mismatch, which also validates the
    // fixture size against the tensor shape.
    boxes_i8
        .map()
        .unwrap()
        .as_mut_slice()
        .copy_from_slice(&raw_boxes);
    let boxes_i8 = boxes_i8.into();

    let raw_protos: Vec<i8> = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    ))
    .iter()
    .map(|&byte| byte as i8)
    .collect();
    let protos_i8: Tensor<i8> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
    protos_i8
        .map()
        .unwrap()
        .as_mut_slice()
        .copy_from_slice(&raw_protos);
    let protos_i8 = protos_i8.into();

    let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
    let expected = real_data_expected_boxes();
    let mut output_boxes = Vec::with_capacity(50);

    decoder
        .decode(&[&boxes_i8, &protos_i8], &mut output_boxes, &mut Vec::new())
        .unwrap();

    assert_eq!(output_boxes.len(), 2);
    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
}
3572
/// Decodes the real YOLOv8 segmentation fixtures after dequantizing them to
/// `f32` `Tensor`s, and checks the two reference boxes.
#[test]
fn test_decode_tensor_f32() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    // (scale, zero_point) of the int8 fixtures.
    let quant_boxes = (0.021287762_f32, 31_i32);
    let quant_protos = (0.02491162_f32, -117_i32);

    // Safe, bit-preserving byte -> i8 reinterpretation of the raw fixture.
    let raw_boxes: Vec<i8> = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ))
    .iter()
    .map(|&byte| byte as i8)
    .collect();
    let mut raw_boxes_f32 = vec![0f32; raw_boxes.len()];
    dequantize_cpu(&raw_boxes, quant_boxes.into(), &mut raw_boxes_f32);
    let boxes_f32: Tensor<f32> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
    boxes_f32
        .map()
        .unwrap()
        .as_mut_slice()
        .copy_from_slice(&raw_boxes_f32);
    let boxes_f32 = boxes_f32.into();

    let raw_protos: Vec<i8> = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    ))
    .iter()
    .map(|&byte| byte as i8)
    .collect();
    let mut raw_protos_f32 = vec![0f32; raw_protos.len()];
    dequantize_cpu(&raw_protos, quant_protos.into(), &mut raw_protos_f32);
    let protos_f32: Tensor<f32> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
    protos_f32
        .map()
        .unwrap()
        .as_mut_slice()
        .copy_from_slice(&raw_protos_f32);
    let protos_f32 = protos_f32.into();

    let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);

    let expected = real_data_expected_boxes();
    let mut output_boxes = Vec::with_capacity(50);

    decoder
        .decode(
            &[&boxes_f32, &protos_f32],
            &mut output_boxes,
            &mut Vec::new(),
        )
        .unwrap();

    assert_eq!(output_boxes.len(), 2);
    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
}
3630
/// Decodes the real YOLOv8 segmentation fixtures after dequantizing them to
/// `f64` `Tensor`s, and checks the two reference boxes.
#[test]
fn test_decode_tensor_f64() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    // (scale, zero_point) of the int8 fixtures.
    let quant_boxes = (0.021287762_f32, 31_i32);
    let quant_protos = (0.02491162_f32, -117_i32);

    // Safe, bit-preserving byte -> i8 reinterpretation of the raw fixture.
    let raw_boxes: Vec<i8> = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ))
    .iter()
    .map(|&byte| byte as i8)
    .collect();
    let mut raw_boxes_f64 = vec![0f64; raw_boxes.len()];
    dequantize_cpu(&raw_boxes, quant_boxes.into(), &mut raw_boxes_f64);
    let boxes_f64: Tensor<f64> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
    boxes_f64
        .map()
        .unwrap()
        .as_mut_slice()
        .copy_from_slice(&raw_boxes_f64);
    let boxes_f64 = boxes_f64.into();

    let raw_protos: Vec<i8> = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    ))
    .iter()
    .map(|&byte| byte as i8)
    .collect();
    let mut raw_protos_f64 = vec![0f64; raw_protos.len()];
    dequantize_cpu(&raw_protos, quant_protos.into(), &mut raw_protos_f64);
    let protos_f64: Tensor<f64> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
    protos_f64
        .map()
        .unwrap()
        .as_mut_slice()
        .copy_from_slice(&raw_protos_f64);
    let protos_f64 = protos_f64.into();

    let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);

    let expected = real_data_expected_boxes();
    let mut output_boxes = Vec::with_capacity(50);

    decoder
        .decode(
            &[&boxes_f64, &protos_f64],
            &mut output_boxes,
            &mut Vec::new(),
        )
        .unwrap();

    assert_eq!(output_boxes.len(), 2);
    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
}
3688
/// Decodes the real int8 YOLOv8 segmentation fixtures via `decode_proto`.
/// Checks the reference boxes and that one 32-element mask-coefficient vector
/// is returned per detection.
#[test]
fn test_decode_tensor_proto() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    // Safe, bit-preserving byte -> i8 reinterpretation of the raw fixture.
    let raw_boxes: Vec<i8> = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ))
    .iter()
    .map(|&byte| byte as i8)
    .collect();
    let boxes_i8: Tensor<i8> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
    // `copy_from_slice` panics on a length mismatch, which also validates the
    // fixture size against the tensor shape.
    boxes_i8
        .map()
        .unwrap()
        .as_mut_slice()
        .copy_from_slice(&raw_boxes);
    let boxes_i8 = boxes_i8.into();

    let raw_protos: Vec<i8> = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    ))
    .iter()
    .map(|&byte| byte as i8)
    .collect();
    let protos_i8: Tensor<i8> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
    protos_i8
        .map()
        .unwrap()
        .as_mut_slice()
        .copy_from_slice(&raw_protos);
    let protos_i8 = protos_i8.into();

    let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);

    let expected = real_data_expected_boxes();
    let mut output_boxes = Vec::with_capacity(50);

    let proto_data = decoder
        .decode_proto(&[&boxes_i8, &protos_i8], &mut output_boxes)
        .unwrap();

    assert_eq!(output_boxes.len(), 2);
    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));

    let proto_data = proto_data.expect("segmentation model should return ProtoData");
    assert_eq!(
        proto_data.mask_coefficients.len(),
        output_boxes.len(),
        "mask_coefficients count must match detection count"
    );
    for coeff in &proto_data.mask_coefficients {
        assert_eq!(
            coeff.len(),
            32,
            "each detection should have 32 mask coefficients"
        );
    }
}
3750}
3751
3752#[cfg(feature = "tracker")]
3753#[cfg(test)]
3754#[cfg_attr(coverage_nightly, coverage(off))]
3755mod decoder_tracked_tests {
3756
3757 use edgefirst_tracker::{ByteTrackBuilder, Tracker};
3758 use ndarray::{array, s, Array, Array2, Array3, Array4, ArrayView, Axis, Dimension};
3759 use num_traits::{AsPrimitive, Float, PrimInt};
3760 use rand::{RngExt, SeedableRng};
3761 use rand_distr::StandardNormal;
3762
3763 use crate::{
3764 configs::{self, DimName},
3765 dequantize_ndarray, BoundingBox, DecoderBuilder, DetectBox, Quantization,
3766 };
3767
3768 pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
3769 input: ArrayView<F, D>,
3770 quant: Quantization,
3771 ) -> Array<T, D>
3772 where
3773 i32: num_traits::AsPrimitive<F>,
3774 f32: num_traits::AsPrimitive<F>,
3775 {
3776 let zero_point = quant.zero_point.as_();
3777 let div_scale = F::one() / quant.scale.as_();
3778 if zero_point != F::zero() {
3779 input.mapv(|d| (d * div_scale + zero_point).round().as_())
3780 } else {
3781 input.mapv(|d| (d * div_scale).round().as_())
3782 }
3783 }
3784
3785 #[test]
3786 fn test_decoder_tracked_random_jitter() {
3787 use crate::configs::{DecoderType, Nms};
3788 use crate::DecoderBuilder;
3789
3790 let score_threshold = 0.25;
3791 let iou_threshold = 0.1;
3792 let out = include_bytes!(concat!(
3793 env!("CARGO_MANIFEST_DIR"),
3794 "/../../testdata/yolov8s_80_classes.bin"
3795 ));
3796 let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
3797 let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
3798 let quant = (0.0040811873, -123).into();
3799
3800 let decoder = DecoderBuilder::default()
3801 .with_config_yolo_det(
3802 crate::configs::Detection {
3803 decoder: DecoderType::Ultralytics,
3804 shape: vec![1, 84, 8400],
3805 anchors: None,
3806 quantization: Some(quant),
3807 dshape: vec![
3808 (crate::configs::DimName::Batch, 1),
3809 (crate::configs::DimName::NumFeatures, 84),
3810 (crate::configs::DimName::NumBoxes, 8400),
3811 ],
3812 normalized: Some(true),
3813 },
3814 None,
3815 )
3816 .with_score_threshold(score_threshold)
3817 .with_iou_threshold(iou_threshold)
3818 .with_nms(Some(Nms::ClassAgnostic))
3819 .build()
3820 .unwrap();
3821 let mut rng = rand::rngs::StdRng::seed_from_u64(0xAB_BEEF); let expected_boxes = [
3824 crate::DetectBox {
3825 bbox: crate::BoundingBox {
3826 xmin: 0.5285137,
3827 ymin: 0.05305544,
3828 xmax: 0.87541467,
3829 ymax: 0.9998909,
3830 },
3831 score: 0.5591227,
3832 label: 0,
3833 },
3834 crate::DetectBox {
3835 bbox: crate::BoundingBox {
3836 xmin: 0.130598,
3837 ymin: 0.43260583,
3838 xmax: 0.35098213,
3839 ymax: 0.9958097,
3840 },
3841 score: 0.33057618,
3842 label: 75,
3843 },
3844 ];
3845
3846 let mut tracker = ByteTrackBuilder::new()
3847 .track_update(0.1)
3848 .track_high_conf(0.3)
3849 .build();
3850
3851 let mut output_boxes = Vec::with_capacity(50);
3852 let mut output_masks = Vec::with_capacity(50);
3853 let mut output_tracks = Vec::with_capacity(50);
3854
3855 decoder
3856 .decode_tracked_quantized(
3857 &mut tracker,
3858 0,
3859 &[out.view().into()],
3860 &mut output_boxes,
3861 &mut output_masks,
3862 &mut output_tracks,
3863 )
3864 .unwrap();
3865
3866 assert_eq!(output_boxes.len(), 2);
3867 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3868 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3869
3870 let mut last_boxes = output_boxes.clone();
3871
3872 for i in 1..=100 {
3873 let mut out = out.clone();
3874 let mut x_values = out.slice_mut(s![0, 0, ..]);
3876 for x in x_values.iter_mut() {
3877 let r: f32 = rng.sample(StandardNormal);
3878 let r = r.clamp(-2.0, 2.0) / 2.0;
3879 *x = x.saturating_add((r * 1e-2 / quant.0) as i8);
3880 }
3881
3882 let mut y_values = out.slice_mut(s![0, 1, ..]);
3883 for y in y_values.iter_mut() {
3884 let r: f32 = rng.sample(StandardNormal);
3885 let r = r.clamp(-2.0, 2.0) / 2.0;
3886 *y = y.saturating_add((r * 1e-2 / quant.0) as i8);
3887 }
3888
3889 decoder
3890 .decode_tracked_quantized(
3891 &mut tracker,
3892 100_000_000 * i / 3, &[out.view().into()],
3894 &mut output_boxes,
3895 &mut output_masks,
3896 &mut output_tracks,
3897 )
3898 .unwrap();
3899
3900 assert_eq!(output_boxes.len(), 2);
3901 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 5e-3));
3902 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 5e-3));
3903
3904 assert!(output_boxes[0].equal_within_delta(&last_boxes[0], 1e-3));
3905 assert!(output_boxes[1].equal_within_delta(&last_boxes[1], 1e-3));
3906 last_boxes = output_boxes.clone();
3907 }
3908 }
3909
3910 fn real_data_expected_boxes() -> [DetectBox; 2] {
3913 [
3914 DetectBox {
3915 bbox: BoundingBox {
3916 xmin: 0.08515105,
3917 ymin: 0.7131401,
3918 xmax: 0.29802868,
3919 ymax: 0.8195788,
3920 },
3921 score: 0.91537374,
3922 label: 23,
3923 },
3924 DetectBox {
3925 bbox: BoundingBox {
3926 xmin: 0.59605736,
3927 ymin: 0.25545314,
3928 xmax: 0.93666154,
3929 ymax: 0.72378385,
3930 },
3931 score: 0.91537374,
3932 label: 23,
3933 },
3934 ]
3935 }
3936
3937 fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
3938 [DetectBox {
3939 bbox: BoundingBox {
3940 xmin: 0.12549022,
3941 ymin: 0.12549022,
3942 xmax: 0.23529413,
3943 ymax: 0.23529413,
3944 },
3945 score: 0.98823535,
3946 label: 2,
3947 }]
3948 }
3949
3950 fn e2e_expected_boxes_float() -> [DetectBox; 1] {
3951 [DetectBox {
3952 bbox: BoundingBox {
3953 xmin: 0.1234,
3954 ymin: 0.1234,
3955 xmax: 0.2345,
3956 ymax: 0.2345,
3957 },
3958 score: 0.9876,
3959 label: 2,
3960 }]
3961 }
3962
3963 fn build_yolo_split_segdet_decoder(
3964 score_threshold: f32,
3965 iou_threshold: f32,
3966 quant_boxes: (f32, i32),
3967 quant_protos: (f32, i32),
3968 ) -> crate::Decoder {
3969 DecoderBuilder::default()
3970 .with_config_yolo_split_segdet(
3971 configs::Boxes {
3972 decoder: configs::DecoderType::Ultralytics,
3973 quantization: Some(quant_boxes.into()),
3974 shape: vec![1, 4, 8400],
3975 dshape: vec![
3976 (DimName::Batch, 1),
3977 (DimName::BoxCoords, 4),
3978 (DimName::NumBoxes, 8400),
3979 ],
3980 normalized: Some(true),
3981 },
3982 configs::Scores {
3983 decoder: configs::DecoderType::Ultralytics,
3984 quantization: Some(quant_boxes.into()),
3985 shape: vec![1, 80, 8400],
3986 dshape: vec![
3987 (DimName::Batch, 1),
3988 (DimName::NumClasses, 80),
3989 (DimName::NumBoxes, 8400),
3990 ],
3991 },
3992 configs::MaskCoefficients {
3993 decoder: configs::DecoderType::Ultralytics,
3994 quantization: Some(quant_boxes.into()),
3995 shape: vec![1, 32, 8400],
3996 dshape: vec![
3997 (DimName::Batch, 1),
3998 (DimName::NumProtos, 32),
3999 (DimName::NumBoxes, 8400),
4000 ],
4001 },
4002 configs::Protos {
4003 decoder: configs::DecoderType::Ultralytics,
4004 quantization: Some(quant_protos.into()),
4005 shape: vec![1, 160, 160, 32],
4006 dshape: vec![
4007 (DimName::Batch, 1),
4008 (DimName::Height, 160),
4009 (DimName::Width, 160),
4010 (DimName::NumProtos, 32),
4011 ],
4012 },
4013 )
4014 .with_score_threshold(score_threshold)
4015 .with_iou_threshold(iou_threshold)
4016 .build()
4017 .unwrap()
4018 }
4019
4020 fn build_yolov8_seg_decoder(score_threshold: f32, iou_threshold: f32) -> crate::Decoder {
4021 let config_yaml = include_str!(concat!(
4022 env!("CARGO_MANIFEST_DIR"),
4023 "/../../testdata/yolov8_seg.yaml"
4024 ));
4025 DecoderBuilder::default()
4026 .with_config_yaml_str(config_yaml.to_string())
4027 .with_score_threshold(score_threshold)
4028 .with_iou_threshold(iou_threshold)
4029 .build()
4030 .unwrap()
4031 }
4032
    // Generates one real-data tracked segmentation test.
    //
    // Usage: `real_data_tracked_test!(name, quantized|float, $layout, $output)`.
    // - `quantized` feeds the raw i8 fixtures directly; `float` first dequantizes
    //   them with `dequantize_ndarray`.
    // - `$layout` is compared (via `stringify!`) against "split"; any other
    //   identifier uses the combined [1, 116, 8400] tensor plus protos.
    // - `$output` is compared against "proto"; "proto" exercises the
    //   `decode_tracked_*_proto` entry points, anything else the mask-producing ones.
    //
    // Each generated test runs two frames through one ByteTrack tracker:
    // 1. timestamp 0: the decoder must report the two reference boxes;
    // 2. every class score is then suppressed and the second frame's assertions
    //    expect `output_boxes` to still match the reference boxes, while no masks
    //    (or an empty `mask_coefficients`) are produced.
    macro_rules! real_data_tracked_test {
        ($name:ident, quantized, $layout:ident, $output:ident) => {
            #[test]
            fn $name() {
                let is_split = matches!(stringify!($layout), "split");
                let is_proto = matches!(stringify!($output), "proto");

                let score_threshold = 0.45;
                let iou_threshold = 0.45;
                // Only consumed by the split decoder config in this arm; the
                // combined config takes its quantization from the YAML file.
                let quant_boxes = (0.021287762_f32, 31_i32);
                let quant_protos = (0.02491162_f32, -117_i32);

                let raw_boxes = include_bytes!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/../../testdata/yolov8_boxes_116x8400.bin"
                ));
                // SAFETY: u8 and i8 share size and alignment.
                let raw_boxes = unsafe {
                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
                };
                let boxes_i8 =
                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();

                let raw_protos = include_bytes!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/../../testdata/yolov8_protos_160x160x32.bin"
                ));
                // SAFETY: u8 and i8 share size and alignment.
                let raw_protos = unsafe {
                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
                };
                let protos_i8 =
                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
                        .unwrap();

                // Split-layout copies of the combined tensor: channels 0..4 are
                // boxes, 4..84 class scores and 84.. mask coefficients.
                let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
                let mut scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
                let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
                let mut boxes_combined = boxes_i8;

                let decoder = if is_split {
                    build_yolo_split_segdet_decoder(
                        score_threshold,
                        iou_threshold,
                        quant_boxes,
                        quant_protos,
                    )
                } else {
                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
                };

                let expected = real_data_expected_boxes();
                let mut tracker = ByteTrackBuilder::new()
                    .track_update(0.1)
                    .track_high_conf(0.7)
                    .build();
                let mut output_boxes = Vec::with_capacity(50);
                let mut output_tracks = Vec::with_capacity(50);

                if is_proto {
                    // Frame 1 (timestamp 0). The input views live in a scope so
                    // the score tensors can be mutated afterwards.
                    {
                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
                            vec![
                                boxes_split.view().into(),
                                scores_split.view().into(),
                                mask_split.view().into(),
                                protos_i8.view().into(),
                            ]
                        } else {
                            vec![boxes_combined.view().into(), protos_i8.view().into()]
                        };
                        decoder
                            .decode_tracked_quantized_proto(
                                &mut tracker,
                                0,
                                &inputs,
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap();
                    }
                    assert_eq!(output_boxes.len(), 2);
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));

                    // Suppress every class score so frame 2 has no raw detections.
                    if is_split {
                        for score in scores_split.iter_mut() {
                            *score = i8::MIN;
                        }
                    } else {
                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
                            *score = i8::MIN;
                        }
                    }

                    // Frame 2 at timestamp 100_000_000 / 3 (presumably nanoseconds,
                    // i.e. ~1/30 s; TODO confirm against the tracker API).
                    let proto_result = {
                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
                            vec![
                                boxes_split.view().into(),
                                scores_split.view().into(),
                                mask_split.view().into(),
                                protos_i8.view().into(),
                            ]
                        } else {
                            vec![boxes_combined.view().into(), protos_i8.view().into()]
                        };
                        decoder
                            .decode_tracked_quantized_proto(
                                &mut tracker,
                                100_000_000 / 3,
                                &inputs,
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap()
                    };
                    // The reference boxes are still expected in `output_boxes`, but no
                    // mask coefficients since there were no fresh detections.
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
                } else {
                    let mut output_masks = Vec::with_capacity(50);
                    // Frame 1 (timestamp 0).
                    {
                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
                            vec![
                                boxes_split.view().into(),
                                scores_split.view().into(),
                                mask_split.view().into(),
                                protos_i8.view().into(),
                            ]
                        } else {
                            vec![boxes_combined.view().into(), protos_i8.view().into()]
                        };
                        decoder
                            .decode_tracked_quantized(
                                &mut tracker,
                                0,
                                &inputs,
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();
                    }
                    assert_eq!(output_boxes.len(), 2);
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));

                    // Suppress every class score so frame 2 has no raw detections.
                    if is_split {
                        for score in scores_split.iter_mut() {
                            *score = i8::MIN;
                        }
                    } else {
                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
                            *score = i8::MIN;
                        }
                    }

                    // Frame 2 (~1/30 s later, assuming nanosecond timestamps).
                    {
                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
                            vec![
                                boxes_split.view().into(),
                                scores_split.view().into(),
                                mask_split.view().into(),
                                protos_i8.view().into(),
                            ]
                        } else {
                            vec![boxes_combined.view().into(), protos_i8.view().into()]
                        };
                        decoder
                            .decode_tracked_quantized(
                                &mut tracker,
                                100_000_000 / 3,
                                &inputs,
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();
                    }
                    // Reference boxes still expected; no masks without fresh detections.
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
                    assert!(output_masks.is_empty());
                }
            }
        };
        ($name:ident, float, $layout:ident, $output:ident) => {
            #[test]
            fn $name() {
                let is_split = matches!(stringify!($layout), "split");
                let is_proto = matches!(stringify!($output), "proto");

                let score_threshold = 0.45;
                let iou_threshold = 0.45;
                // Used both to dequantize the fixtures and for the split config.
                let quant_boxes = (0.021287762_f32, 31_i32);
                let quant_protos = (0.02491162_f32, -117_i32);

                let raw_boxes = include_bytes!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/../../testdata/yolov8_boxes_116x8400.bin"
                ));
                // SAFETY: u8 and i8 share size and alignment.
                let raw_boxes = unsafe {
                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
                };
                let boxes_i8 =
                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
                let boxes_f32 = dequantize_ndarray(boxes_i8.view(), quant_boxes.into());

                let raw_protos = include_bytes!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/../../testdata/yolov8_protos_160x160x32.bin"
                ));
                // SAFETY: u8 and i8 share size and alignment.
                let raw_protos = unsafe {
                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
                };
                let protos_i8 =
                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
                        .unwrap();
                let protos_f32 = dequantize_ndarray(protos_i8.view(), quant_protos.into());

                // Split-layout copies of the combined tensor: channels 0..4 are
                // boxes, 4..84 class scores and 84.. mask coefficients.
                let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
                let mut scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
                let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
                let mut boxes_combined = boxes_f32;

                let decoder = if is_split {
                    build_yolo_split_segdet_decoder(
                        score_threshold,
                        iou_threshold,
                        quant_boxes,
                        quant_protos,
                    )
                } else {
                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
                };

                let expected = real_data_expected_boxes();
                let mut tracker = ByteTrackBuilder::new()
                    .track_update(0.1)
                    .track_high_conf(0.7)
                    .build();
                let mut output_boxes = Vec::with_capacity(50);
                let mut output_tracks = Vec::with_capacity(50);

                if is_proto {
                    // Frame 1 (timestamp 0).
                    {
                        let inputs = if is_split {
                            vec![
                                boxes_split.view().into_dyn(),
                                scores_split.view().into_dyn(),
                                mask_split.view().into_dyn(),
                                protos_f32.view().into_dyn(),
                            ]
                        } else {
                            vec![
                                boxes_combined.view().into_dyn(),
                                protos_f32.view().into_dyn(),
                            ]
                        };
                        decoder
                            .decode_tracked_float_proto(
                                &mut tracker,
                                0,
                                &inputs,
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap();
                    }
                    assert_eq!(output_boxes.len(), 2);
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));

                    // Zero every class score so frame 2 has no raw detections.
                    if is_split {
                        for score in scores_split.iter_mut() {
                            *score = 0.0;
                        }
                    } else {
                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
                            *score = 0.0;
                        }
                    }

                    // Frame 2 (~1/30 s later, assuming nanosecond timestamps).
                    let proto_result = {
                        let inputs = if is_split {
                            vec![
                                boxes_split.view().into_dyn(),
                                scores_split.view().into_dyn(),
                                mask_split.view().into_dyn(),
                                protos_f32.view().into_dyn(),
                            ]
                        } else {
                            vec![
                                boxes_combined.view().into_dyn(),
                                protos_f32.view().into_dyn(),
                            ]
                        };
                        decoder
                            .decode_tracked_float_proto(
                                &mut tracker,
                                100_000_000 / 3,
                                &inputs,
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap()
                    };
                    // Reference boxes still expected; coefficients empty.
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
                } else {
                    let mut output_masks = Vec::with_capacity(50);
                    // Frame 1 (timestamp 0).
                    {
                        let inputs = if is_split {
                            vec![
                                boxes_split.view().into_dyn(),
                                scores_split.view().into_dyn(),
                                mask_split.view().into_dyn(),
                                protos_f32.view().into_dyn(),
                            ]
                        } else {
                            vec![
                                boxes_combined.view().into_dyn(),
                                protos_f32.view().into_dyn(),
                            ]
                        };
                        decoder
                            .decode_tracked_float(
                                &mut tracker,
                                0,
                                &inputs,
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();
                    }
                    assert_eq!(output_boxes.len(), 2);
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));

                    // Zero every class score so frame 2 has no raw detections.
                    if is_split {
                        for score in scores_split.iter_mut() {
                            *score = 0.0;
                        }
                    } else {
                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
                            *score = 0.0;
                        }
                    }

                    // Frame 2 (~1/30 s later, assuming nanosecond timestamps).
                    {
                        let inputs = if is_split {
                            vec![
                                boxes_split.view().into_dyn(),
                                scores_split.view().into_dyn(),
                                mask_split.view().into_dyn(),
                                protos_f32.view().into_dyn(),
                            ]
                        } else {
                            vec![
                                boxes_combined.view().into_dyn(),
                                protos_f32.view().into_dyn(),
                            ]
                        };
                        decoder
                            .decode_tracked_float(
                                &mut tracker,
                                100_000_000 / 3,
                                &inputs,
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();
                    }
                    // Reference boxes still expected; no masks without fresh detections.
                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
                    assert!(output_masks.is_empty());
                }
            }
        };
    }
4423
4424 real_data_tracked_test!(test_decoder_tracked_segdet, quantized, combined, masks);
4425 real_data_tracked_test!(test_decoder_tracked_segdet_float, float, combined, masks);
4426 real_data_tracked_test!(
4427 test_decoder_tracked_segdet_proto,
4428 quantized,
4429 combined,
4430 proto
4431 );
4432 real_data_tracked_test!(
4433 test_decoder_tracked_segdet_proto_float,
4434 float,
4435 combined,
4436 proto
4437 );
4438 real_data_tracked_test!(test_decoder_tracked_segdet_split, quantized, split, masks);
4439 real_data_tracked_test!(test_decoder_tracked_segdet_split_float, float, split, masks);
4440 real_data_tracked_test!(
4441 test_decoder_tracked_segdet_split_proto,
4442 quantized,
4443 split,
4444 proto
4445 );
4446 real_data_tracked_test!(
4447 test_decoder_tracked_segdet_split_proto_float,
4448 float,
4449 split,
4450 proto
4451 );
4452
    /// YOLO26 end-to-end decoder config with a single combined detection output
    /// plus protos.
    ///
    /// The detection tensor is `[1, 10, 38]` (10 boxes × 38 features). The e2e
    /// tests build each row as 4 box coordinates, 1 score, 1 class and 32 mask
    /// coefficients. It is quantized with scale 2/255 and zero point 0.
    ///
    /// NOTE(review): the quantized e2e tests quantize their all-zero protos with
    /// zero point 0, whereas this config declares 128 for protos. This does not
    /// affect the box-only assertions but may be worth aligning.
    const E2E_COMBINED_CONFIG: &str = "
decoder_version: yolo26
outputs:
  - type: detection
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 38]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_features, 38]
    normalized: true
  - type: protos
    decoder: ultralytics
    quantization: [0.0039215686274509803921568627451, 128]
    shape: [1, 160, 160, 32]
    dshape:
      - [batch, 1]
      - [height, 160]
      - [width, 160]
      - [num_protos, 32]
";
4480
    /// YOLO26 end-to-end decoder config with split outputs: boxes `[1, 10, 4]`,
    /// scores `[1, 10, 1]`, classes `[1, 10, 1]`, mask coefficients `[1, 10, 32]`
    /// and protos `[1, 160, 160, 32]`.
    ///
    /// All non-proto outputs use scale 2/255 and zero point 0, which matches the
    /// `detect_quant` used by the quantized e2e tests.
    const E2E_SPLIT_CONFIG: &str = "
decoder_version: yolo26
outputs:
  - type: boxes
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 4]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [box_coords, 4]
    normalized: true
  - type: scores
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 1]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_classes, 1]
  - type: classes
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 1]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_classes, 1]
  - type: mask_coefficients
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 32]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_protos, 32]
  - type: protos
    decoder: ultralytics
    quantization: [0.0039215686274509803921568627451, 128]
    shape: [1, 160, 160, 32]
    dshape:
      - [batch, 1]
      - [height, 160]
      - [width, 160]
      - [num_protos, 32]
";
4527
4528 macro_rules! e2e_tracked_test {
4529 ($name:ident, quantized, $layout:ident, $output:ident) => {
4530 #[test]
4531 fn $name() {
4532 let is_split = matches!(stringify!($layout), "split");
4533 let is_proto = matches!(stringify!($output), "proto");
4534
4535 let score_threshold = 0.45;
4536 let iou_threshold = 0.45;
4537
4538 let mut boxes = Array2::zeros((10, 4));
4539 let mut scores = Array2::zeros((10, 1));
4540 let mut classes = Array2::zeros((10, 1));
4541 let mask = Array2::zeros((10, 32));
4542 let protos = Array3::<f64>::zeros((160, 160, 32));
4543 let protos = protos.insert_axis(Axis(0));
4544 let protos_quant = (1.0 / 255.0, 0.0);
4545 let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
4546
4547 boxes
4548 .slice_mut(s![0, ..])
4549 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4550 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4551 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4552
4553 let detect_quant = (2.0 / 255.0, 0.0);
4554
4555 let decoder = if is_split {
4556 DecoderBuilder::default()
4557 .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
4558 .with_score_threshold(score_threshold)
4559 .with_iou_threshold(iou_threshold)
4560 .build()
4561 .unwrap()
4562 } else {
4563 DecoderBuilder::default()
4564 .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
4565 .with_score_threshold(score_threshold)
4566 .with_iou_threshold(iou_threshold)
4567 .build()
4568 .unwrap()
4569 };
4570
4571 let expected = e2e_expected_boxes_quant();
4572 let mut tracker = ByteTrackBuilder::new()
4573 .track_update(0.1)
4574 .track_high_conf(0.7)
4575 .build();
4576 let mut output_boxes = Vec::with_capacity(50);
4577 let mut output_tracks = Vec::with_capacity(50);
4578
4579 if is_split {
4580 let boxes = boxes.insert_axis(Axis(0));
4581 let scores = scores.insert_axis(Axis(0));
4582 let classes = classes.insert_axis(Axis(0));
4583 let mask = mask.insert_axis(Axis(0));
4584
4585 let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
4586 let mut scores: Array3<u8> =
4587 quantize_ndarray(scores.view(), detect_quant.into());
4588 let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
4589 let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
4590
4591 if is_proto {
4592 {
4593 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
4594 boxes.view().into(),
4595 scores.view().into(),
4596 classes.view().into(),
4597 mask.view().into(),
4598 protos.view().into(),
4599 ];
4600 decoder
4601 .decode_tracked_quantized_proto(
4602 &mut tracker,
4603 0,
4604 &inputs,
4605 &mut output_boxes,
4606 &mut output_tracks,
4607 )
4608 .unwrap();
4609 }
4610 assert_eq!(output_boxes.len(), 1);
4611 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4612
4613 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4614 *score = u8::MIN;
4615 }
4616 let proto_result = {
4617 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
4618 boxes.view().into(),
4619 scores.view().into(),
4620 classes.view().into(),
4621 mask.view().into(),
4622 protos.view().into(),
4623 ];
4624 decoder
4625 .decode_tracked_quantized_proto(
4626 &mut tracker,
4627 100_000_000 / 3,
4628 &inputs,
4629 &mut output_boxes,
4630 &mut output_tracks,
4631 )
4632 .unwrap()
4633 };
4634 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4635 assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
4636 } else {
4637 let mut output_masks = Vec::with_capacity(50);
4638 {
4639 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
4640 boxes.view().into(),
4641 scores.view().into(),
4642 classes.view().into(),
4643 mask.view().into(),
4644 protos.view().into(),
4645 ];
4646 decoder
4647 .decode_tracked_quantized(
4648 &mut tracker,
4649 0,
4650 &inputs,
4651 &mut output_boxes,
4652 &mut output_masks,
4653 &mut output_tracks,
4654 )
4655 .unwrap();
4656 }
4657 assert_eq!(output_boxes.len(), 1);
4658 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4659
4660 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4661 *score = u8::MIN;
4662 }
4663 {
4664 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
4665 boxes.view().into(),
4666 scores.view().into(),
4667 classes.view().into(),
4668 mask.view().into(),
4669 protos.view().into(),
4670 ];
4671 decoder
4672 .decode_tracked_quantized(
4673 &mut tracker,
4674 100_000_000 / 3,
4675 &inputs,
4676 &mut output_boxes,
4677 &mut output_masks,
4678 &mut output_tracks,
4679 )
4680 .unwrap();
4681 }
4682 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4683 assert!(output_masks.is_empty());
4684 }
4685 } else {
4686 let detect = ndarray::concatenate![
4688 Axis(1),
4689 boxes.view(),
4690 scores.view(),
4691 classes.view(),
4692 mask.view()
4693 ];
4694 let detect = detect.insert_axis(Axis(0));
4695 assert_eq!(detect.shape(), &[1, 10, 38]);
4696 let mut detect: Array3<u8> =
4697 quantize_ndarray(detect.view(), detect_quant.into());
4698
4699 if is_proto {
4700 {
4701 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
4702 vec![detect.view().into(), protos.view().into()];
4703 decoder
4704 .decode_tracked_quantized_proto(
4705 &mut tracker,
4706 0,
4707 &inputs,
4708 &mut output_boxes,
4709 &mut output_tracks,
4710 )
4711 .unwrap();
4712 }
4713 assert_eq!(output_boxes.len(), 1);
4714 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4715
4716 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4717 *score = u8::MIN;
4718 }
4719 let proto_result = {
4720 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
4721 vec![detect.view().into(), protos.view().into()];
4722 decoder
4723 .decode_tracked_quantized_proto(
4724 &mut tracker,
4725 100_000_000 / 3,
4726 &inputs,
4727 &mut output_boxes,
4728 &mut output_tracks,
4729 )
4730 .unwrap()
4731 };
4732 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4733 assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
4734 } else {
4735 let mut output_masks = Vec::with_capacity(50);
4736 {
4737 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
4738 vec![detect.view().into(), protos.view().into()];
4739 decoder
4740 .decode_tracked_quantized(
4741 &mut tracker,
4742 0,
4743 &inputs,
4744 &mut output_boxes,
4745 &mut output_masks,
4746 &mut output_tracks,
4747 )
4748 .unwrap();
4749 }
4750 assert_eq!(output_boxes.len(), 1);
4751 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4752
4753 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4754 *score = u8::MIN;
4755 }
4756 {
4757 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
4758 vec![detect.view().into(), protos.view().into()];
4759 decoder
4760 .decode_tracked_quantized(
4761 &mut tracker,
4762 100_000_000 / 3,
4763 &inputs,
4764 &mut output_boxes,
4765 &mut output_masks,
4766 &mut output_tracks,
4767 )
4768 .unwrap();
4769 }
4770 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4771 assert!(output_masks.is_empty());
4772 }
4773 }
4774 }
4775 };
4776 ($name:ident, float, $layout:ident, $output:ident) => {
4777 #[test]
4778 fn $name() {
4779 let is_split = matches!(stringify!($layout), "split");
4780 let is_proto = matches!(stringify!($output), "proto");
4781
4782 let score_threshold = 0.45;
4783 let iou_threshold = 0.45;
4784
4785 let mut boxes = Array2::zeros((10, 4));
4786 let mut scores = Array2::zeros((10, 1));
4787 let mut classes = Array2::zeros((10, 1));
4788 let mask: Array2<f64> = Array2::zeros((10, 32));
4789 let protos = Array3::<f64>::zeros((160, 160, 32));
4790 let protos = protos.insert_axis(Axis(0));
4791
4792 boxes
4793 .slice_mut(s![0, ..])
4794 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4795 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4796 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4797
4798 let decoder = if is_split {
4799 DecoderBuilder::default()
4800 .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
4801 .with_score_threshold(score_threshold)
4802 .with_iou_threshold(iou_threshold)
4803 .build()
4804 .unwrap()
4805 } else {
4806 DecoderBuilder::default()
4807 .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
4808 .with_score_threshold(score_threshold)
4809 .with_iou_threshold(iou_threshold)
4810 .build()
4811 .unwrap()
4812 };
4813
4814 let expected = e2e_expected_boxes_float();
4815 let mut tracker = ByteTrackBuilder::new()
4816 .track_update(0.1)
4817 .track_high_conf(0.7)
4818 .build();
4819 let mut output_boxes = Vec::with_capacity(50);
4820 let mut output_tracks = Vec::with_capacity(50);
4821
4822 if is_split {
4823 let boxes = boxes.insert_axis(Axis(0));
4824 let mut scores = scores.insert_axis(Axis(0));
4825 let classes = classes.insert_axis(Axis(0));
4826 let mask = mask.insert_axis(Axis(0));
4827
4828 if is_proto {
4829 {
4830 let inputs = vec![
4831 boxes.view().into_dyn(),
4832 scores.view().into_dyn(),
4833 classes.view().into_dyn(),
4834 mask.view().into_dyn(),
4835 protos.view().into_dyn(),
4836 ];
4837 decoder
4838 .decode_tracked_float_proto(
4839 &mut tracker,
4840 0,
4841 &inputs,
4842 &mut output_boxes,
4843 &mut output_tracks,
4844 )
4845 .unwrap();
4846 }
4847 assert_eq!(output_boxes.len(), 1);
4848 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4849
4850 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4851 *score = 0.0;
4852 }
4853 let proto_result = {
4854 let inputs = vec![
4855 boxes.view().into_dyn(),
4856 scores.view().into_dyn(),
4857 classes.view().into_dyn(),
4858 mask.view().into_dyn(),
4859 protos.view().into_dyn(),
4860 ];
4861 decoder
4862 .decode_tracked_float_proto(
4863 &mut tracker,
4864 100_000_000 / 3,
4865 &inputs,
4866 &mut output_boxes,
4867 &mut output_tracks,
4868 )
4869 .unwrap()
4870 };
4871 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4872 assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
4873 } else {
4874 let mut output_masks = Vec::with_capacity(50);
4875 {
4876 let inputs = vec![
4877 boxes.view().into_dyn(),
4878 scores.view().into_dyn(),
4879 classes.view().into_dyn(),
4880 mask.view().into_dyn(),
4881 protos.view().into_dyn(),
4882 ];
4883 decoder
4884 .decode_tracked_float(
4885 &mut tracker,
4886 0,
4887 &inputs,
4888 &mut output_boxes,
4889 &mut output_masks,
4890 &mut output_tracks,
4891 )
4892 .unwrap();
4893 }
4894 assert_eq!(output_boxes.len(), 1);
4895 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4896
4897 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4898 *score = 0.0;
4899 }
4900 {
4901 let inputs = vec![
4902 boxes.view().into_dyn(),
4903 scores.view().into_dyn(),
4904 classes.view().into_dyn(),
4905 mask.view().into_dyn(),
4906 protos.view().into_dyn(),
4907 ];
4908 decoder
4909 .decode_tracked_float(
4910 &mut tracker,
4911 100_000_000 / 3,
4912 &inputs,
4913 &mut output_boxes,
4914 &mut output_masks,
4915 &mut output_tracks,
4916 )
4917 .unwrap();
4918 }
4919 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4920 assert!(output_masks.is_empty());
4921 }
4922 } else {
4923 let detect = ndarray::concatenate![
4925 Axis(1),
4926 boxes.view(),
4927 scores.view(),
4928 classes.view(),
4929 mask.view()
4930 ];
4931 let mut detect = detect.insert_axis(Axis(0));
4932 assert_eq!(detect.shape(), &[1, 10, 38]);
4933
4934 if is_proto {
4935 {
4936 let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4937 decoder
4938 .decode_tracked_float_proto(
4939 &mut tracker,
4940 0,
4941 &inputs,
4942 &mut output_boxes,
4943 &mut output_tracks,
4944 )
4945 .unwrap();
4946 }
4947 assert_eq!(output_boxes.len(), 1);
4948 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4949
4950 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4951 *score = 0.0;
4952 }
4953 let proto_result = {
4954 let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4955 decoder
4956 .decode_tracked_float_proto(
4957 &mut tracker,
4958 100_000_000 / 3,
4959 &inputs,
4960 &mut output_boxes,
4961 &mut output_tracks,
4962 )
4963 .unwrap()
4964 };
4965 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4966 assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
4967 } else {
4968 let mut output_masks = Vec::with_capacity(50);
4969 {
4970 let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4971 decoder
4972 .decode_tracked_float(
4973 &mut tracker,
4974 0,
4975 &inputs,
4976 &mut output_boxes,
4977 &mut output_masks,
4978 &mut output_tracks,
4979 )
4980 .unwrap();
4981 }
4982 assert_eq!(output_boxes.len(), 1);
4983 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4984
4985 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4986 *score = 0.0;
4987 }
4988 {
4989 let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4990 decoder
4991 .decode_tracked_float(
4992 &mut tracker,
4993 100_000_000 / 3,
4994 &inputs,
4995 &mut output_boxes,
4996 &mut output_masks,
4997 &mut output_tracks,
4998 )
4999 .unwrap();
5000 }
5001 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5002 assert!(output_masks.is_empty());
5003 }
5004 }
5005 }
5006 };
5007 }
5008
    // Instantiate the tracked end-to-end test macro once for every combination of its
    // three selector arguments: `quantized` / `float`, `combined` / `split` and
    // `masks` / `proto`. The first argument is the name of the generated test.
    e2e_tracked_test!(
        test_decoder_tracked_end_to_end_segdet,
        quantized,
        combined,
        masks
    );
    e2e_tracked_test!(
        test_decoder_tracked_end_to_end_segdet_float,
        float,
        combined,
        masks
    );
    e2e_tracked_test!(
        test_decoder_tracked_end_to_end_segdet_proto,
        quantized,
        combined,
        proto
    );
    e2e_tracked_test!(
        test_decoder_tracked_end_to_end_segdet_proto_float,
        float,
        combined,
        proto
    );
    e2e_tracked_test!(
        test_decoder_tracked_end_to_end_segdet_split,
        quantized,
        split,
        masks
    );
    e2e_tracked_test!(
        test_decoder_tracked_end_to_end_segdet_split_float,
        float,
        split,
        masks
    );
    e2e_tracked_test!(
        test_decoder_tracked_end_to_end_segdet_split_proto,
        quantized,
        split,
        proto
    );
    e2e_tracked_test!(
        test_decoder_tracked_end_to_end_segdet_split_proto_float,
        float,
        split,
        proto
    );
5057
    /// Generates a `#[test]` that runs the tracked end-to-end decode through
    /// `edgefirst_tensor` tensors (`decode_tracked` / `decode_proto_tracked`) instead of
    /// raw ndarray views.
    ///
    /// Arguments:
    /// - `$name`: name of the generated test function.
    /// - second token: `quantized` builds `u8` tensors (inputs quantized with
    ///   `quantize_ndarray`); `float` builds `f64` tensors.
    /// - `$layout`: `split` feeds separate boxes / scores / classes / mask tensors (plus
    ///   protos); any other identifier concatenates them into one `[1, 10, 38]` detection
    ///   tensor (plus protos).
    /// - `$output`: `proto` calls `decode_proto_tracked`; any other identifier calls
    ///   `decode_tracked` with a mask output vector.
    ///
    /// Every generated test decodes two frames with the same ByteTrack tracker:
    /// 1. Time `0`: row 0 holds one confident detection. Exactly one box must be
    ///    returned, matching the expected box within `1.0 / 160.0`.
    /// 2. Time `100_000_000 / 3`: every score is zeroed. The reported box must still
    ///    match within `1e-6` (presumably it comes from the tracker, since no detection
    ///    passes the threshold), and no masks / mask coefficients may be produced.
    macro_rules! e2e_tracked_tensor_test {
        ($name:ident, quantized, $layout:ident, $output:ident) => {
            #[test]
            fn $name() {
                use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};

                let is_split = matches!(stringify!($layout), "split");
                let is_proto = matches!(stringify!($output), "proto");

                let score_threshold = 0.45;
                let iou_threshold = 0.45;

                // Ten candidate rows; only row 0 is populated below, the rest stay zero.
                let mut boxes = Array2::zeros((10, 4));
                let mut scores = Array2::zeros((10, 1));
                let mut classes = Array2::zeros((10, 1));
                let mask = Array2::zeros((10, 32));
                // All-zero prototype tensor, given a leading batch axis and then quantized.
                let protos_f64 = Array3::<f64>::zeros((160, 160, 32));
                let protos_f64 = protos_f64.insert_axis(Axis(0));
                let protos_quant = (1.0 / 255.0, 0.0);
                let protos_u8: Array4<u8> =
                    quantize_ndarray(protos_f64.view(), protos_quant.into());

                boxes
                    .slice_mut(s![0, ..])
                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
                classes.slice_mut(s![0, ..]).assign(&array![2.0]);

                // Quantization (scale, zero point) shared by all detection-side outputs.
                let detect_quant = (2.0 / 255.0, 0.0);

                let decoder = if is_split {
                    DecoderBuilder::default()
                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
                        .with_score_threshold(score_threshold)
                        .with_iou_threshold(iou_threshold)
                        .build()
                        .unwrap()
                } else {
                    DecoderBuilder::default()
                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
                        .with_score_threshold(score_threshold)
                        .with_iou_threshold(iou_threshold)
                        .build()
                        .unwrap()
                };

                // Allocates a `u8` tensor of `shape` and copies `data` into the start of its
                // mapped buffer. Only `data.len()` elements are written (NOTE(review):
                // presumably the mapping may be longer than the data; confirm).
                let make_u8_tensor =
                    |shape: &[usize], data: &[u8]| -> edgefirst_tensor::TensorDyn {
                        let t = Tensor::<u8>::new(shape, None, None).unwrap();
                        t.map().unwrap().as_mut_slice()[..data.len()].copy_from_slice(data);
                        t.into()
                    };

                let expected = e2e_expected_boxes_quant();
                let mut tracker = ByteTrackBuilder::new()
                    .track_update(0.1)
                    .track_high_conf(0.7)
                    .build();
                let mut output_boxes = Vec::with_capacity(50);
                let mut output_tracks = Vec::with_capacity(50);

                let protos_td = make_u8_tensor(protos_u8.shape(), protos_u8.as_slice().unwrap());

                if is_split {
                    // Give every split output a leading batch axis of size 1.
                    let boxes = boxes.insert_axis(Axis(0));
                    let scores = scores.insert_axis(Axis(0));
                    let classes = classes.insert_axis(Axis(0));
                    let mask = mask.insert_axis(Axis(0));

                    let boxes_q: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
                    // Mutable: the scores are zeroed for the second frame.
                    let mut scores_q: Array3<u8> =
                        quantize_ndarray(scores.view(), detect_quant.into());
                    let classes_q: Array3<u8> =
                        quantize_ndarray(classes.view(), detect_quant.into());
                    let mask_q: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());

                    // Boxes, classes and mask tensors are reused unchanged for both frames.
                    let boxes_td = make_u8_tensor(boxes_q.shape(), boxes_q.as_slice().unwrap());
                    let classes_td =
                        make_u8_tensor(classes_q.shape(), classes_q.as_slice().unwrap());
                    let mask_td = make_u8_tensor(mask_q.shape(), mask_q.as_slice().unwrap());

                    if is_proto {
                        // Frame 1: the confident detection is decoded directly.
                        let scores_td =
                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
                        decoder
                            .decode_proto_tracked(
                                &mut tracker,
                                0,
                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap();

                        assert_eq!(output_boxes.len(), 1);
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));

                        // Frame 2: zero every score (quantized 0 with zero point 0.0) and
                        // rebuild only the scores tensor.
                        for score in scores_q.slice_mut(s![.., .., ..]).iter_mut() {
                            *score = u8::MIN;
                        }
                        let scores_td =
                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
                        let proto_result = decoder
                            .decode_proto_tracked(
                                &mut tracker,
                                100_000_000 / 3,
                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap();
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                        // With no live detections there must be no mask coefficients.
                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
                    } else {
                        // Frame 1: the confident detection is decoded directly.
                        let scores_td =
                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
                        let mut output_masks = Vec::with_capacity(50);
                        decoder
                            .decode_tracked(
                                &mut tracker,
                                0,
                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();

                        assert_eq!(output_boxes.len(), 1);
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));

                        // Frame 2: zero every score and rebuild only the scores tensor.
                        for score in scores_q.slice_mut(s![.., .., ..]).iter_mut() {
                            *score = u8::MIN;
                        }
                        let scores_td =
                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
                        decoder
                            .decode_tracked(
                                &mut tracker,
                                100_000_000 / 3,
                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                        // With no live detections no masks are expected.
                        assert!(output_masks.is_empty());
                    }
                } else {
                    // Combined layout: columns are [box(4) | score(1) | class(1) | mask(32)],
                    // so column 4 is the score.
                    let detect = ndarray::concatenate![
                        Axis(1),
                        boxes.view(),
                        scores.view(),
                        classes.view(),
                        mask.view()
                    ];
                    let detect = detect.insert_axis(Axis(0));
                    assert_eq!(detect.shape(), &[1, 10, 38]);
                    // Copy into a freshly built array in logical element order. NOTE(review):
                    // presumably this guarantees a standard layout for the `as_slice()`
                    // calls below; confirm.
                    let detect =
                        Array3::from_shape_vec(detect.raw_dim(), detect.iter().copied().collect())
                            .unwrap();
                    let mut detect_q: Array3<u8> =
                        quantize_ndarray(detect.view(), detect_quant.into());

                    if is_proto {
                        // Frame 1: the confident detection is decoded directly.
                        let detect_td =
                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
                        decoder
                            .decode_proto_tracked(
                                &mut tracker,
                                0,
                                &[&detect_td, &protos_td],
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap();

                        assert_eq!(output_boxes.len(), 1);
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));

                        // Frame 2: zero the score column and rebuild the detection tensor.
                        for score in detect_q.slice_mut(s![.., .., 4]).iter_mut() {
                            *score = u8::MIN;
                        }
                        let detect_td =
                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
                        let proto_result = decoder
                            .decode_proto_tracked(
                                &mut tracker,
                                100_000_000 / 3,
                                &[&detect_td, &protos_td],
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap();
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
                    } else {
                        // Frame 1: the confident detection is decoded directly.
                        let detect_td =
                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
                        let mut output_masks = Vec::with_capacity(50);
                        decoder
                            .decode_tracked(
                                &mut tracker,
                                0,
                                &[&detect_td, &protos_td],
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();

                        assert_eq!(output_boxes.len(), 1);
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));

                        // Frame 2: zero the score column and rebuild the detection tensor.
                        for score in detect_q.slice_mut(s![.., .., 4]).iter_mut() {
                            *score = u8::MIN;
                        }
                        let detect_td =
                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
                        decoder
                            .decode_tracked(
                                &mut tracker,
                                100_000_000 / 3,
                                &[&detect_td, &protos_td],
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                        assert!(output_masks.is_empty());
                    }
                }
            }
        };
        ($name:ident, float, $layout:ident, $output:ident) => {
            #[test]
            fn $name() {
                use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};

                let is_split = matches!(stringify!($layout), "split");
                let is_proto = matches!(stringify!($output), "proto");

                let score_threshold = 0.45;
                let iou_threshold = 0.45;

                // Ten candidate rows; only row 0 is populated below, the rest stay zero.
                let mut boxes = Array2::zeros((10, 4));
                let mut scores = Array2::zeros((10, 1));
                let mut classes = Array2::zeros((10, 1));
                let mask: Array2<f64> = Array2::zeros((10, 32));
                // All-zero prototype tensor with a leading batch axis.
                let protos = Array3::<f64>::zeros((160, 160, 32));
                let protos = protos.insert_axis(Axis(0));

                boxes
                    .slice_mut(s![0, ..])
                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
                classes.slice_mut(s![0, ..]).assign(&array![2.0]);

                let decoder = if is_split {
                    DecoderBuilder::default()
                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
                        .with_score_threshold(score_threshold)
                        .with_iou_threshold(iou_threshold)
                        .build()
                        .unwrap()
                } else {
                    DecoderBuilder::default()
                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
                        .with_score_threshold(score_threshold)
                        .with_iou_threshold(iou_threshold)
                        .build()
                        .unwrap()
                };

                // Allocates an `f64` tensor of `shape` and copies `data` into the start of
                // its mapped buffer. Only `data.len()` elements are written.
                let make_f64_tensor =
                    |shape: &[usize], data: &[f64]| -> edgefirst_tensor::TensorDyn {
                        let t = Tensor::<f64>::new(shape, None, None).unwrap();
                        t.map().unwrap().as_mut_slice()[..data.len()].copy_from_slice(data);
                        t.into()
                    };

                let expected = e2e_expected_boxes_float();
                let mut tracker = ByteTrackBuilder::new()
                    .track_update(0.1)
                    .track_high_conf(0.7)
                    .build();
                let mut output_boxes = Vec::with_capacity(50);
                let mut output_tracks = Vec::with_capacity(50);

                let protos_td = make_f64_tensor(protos.shape(), protos.as_slice().unwrap());

                if is_split {
                    // Give every split output a leading batch axis of size 1; the scores
                    // stay mutable because they are zeroed for the second frame.
                    let boxes = boxes.insert_axis(Axis(0));
                    let mut scores = scores.insert_axis(Axis(0));
                    let classes = classes.insert_axis(Axis(0));
                    let mask = mask.insert_axis(Axis(0));

                    // Boxes, classes and mask tensors are reused unchanged for both frames.
                    let boxes_td = make_f64_tensor(boxes.shape(), boxes.as_slice().unwrap());
                    let classes_td = make_f64_tensor(classes.shape(), classes.as_slice().unwrap());
                    let mask_td = make_f64_tensor(mask.shape(), mask.as_slice().unwrap());

                    if is_proto {
                        // Frame 1: the confident detection is decoded directly.
                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
                        decoder
                            .decode_proto_tracked(
                                &mut tracker,
                                0,
                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap();

                        assert_eq!(output_boxes.len(), 1);
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));

                        // Frame 2: zero every score and rebuild only the scores tensor.
                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
                            *score = 0.0;
                        }
                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
                        let proto_result = decoder
                            .decode_proto_tracked(
                                &mut tracker,
                                100_000_000 / 3,
                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap();
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
                    } else {
                        // Frame 1: the confident detection is decoded directly.
                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
                        let mut output_masks = Vec::with_capacity(50);
                        decoder
                            .decode_tracked(
                                &mut tracker,
                                0,
                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();

                        assert_eq!(output_boxes.len(), 1);
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));

                        // Frame 2: zero every score and rebuild only the scores tensor.
                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
                            *score = 0.0;
                        }
                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
                        decoder
                            .decode_tracked(
                                &mut tracker,
                                100_000_000 / 3,
                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                        assert!(output_masks.is_empty());
                    }
                } else {
                    // Combined layout: columns are [box(4) | score(1) | class(1) | mask(32)],
                    // so column 4 is the score.
                    let detect = ndarray::concatenate![
                        Axis(1),
                        boxes.view(),
                        scores.view(),
                        classes.view(),
                        mask.view()
                    ];
                    let detect = detect.insert_axis(Axis(0));
                    assert_eq!(detect.shape(), &[1, 10, 38]);
                    // Copy into a freshly built array in logical element order. NOTE(review):
                    // presumably this guarantees a standard layout for `as_slice()`; confirm.
                    let mut detect =
                        Array3::from_shape_vec(detect.raw_dim(), detect.iter().copied().collect())
                            .unwrap();

                    if is_proto {
                        // Frame 1: the confident detection is decoded directly.
                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
                        decoder
                            .decode_proto_tracked(
                                &mut tracker,
                                0,
                                &[&detect_td, &protos_td],
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap();

                        assert_eq!(output_boxes.len(), 1);
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));

                        // Frame 2: zero the score column and rebuild the detection tensor.
                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
                            *score = 0.0;
                        }
                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
                        let proto_result = decoder
                            .decode_proto_tracked(
                                &mut tracker,
                                100_000_000 / 3,
                                &[&detect_td, &protos_td],
                                &mut output_boxes,
                                &mut output_tracks,
                            )
                            .unwrap();
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
                    } else {
                        // Frame 1: the confident detection is decoded directly.
                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
                        let mut output_masks = Vec::with_capacity(50);
                        decoder
                            .decode_tracked(
                                &mut tracker,
                                0,
                                &[&detect_td, &protos_td],
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();

                        assert_eq!(output_boxes.len(), 1);
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));

                        // Frame 2: zero the score column and rebuild the detection tensor.
                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
                            *score = 0.0;
                        }
                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
                        decoder
                            .decode_tracked(
                                &mut tracker,
                                100_000_000 / 3,
                                &[&detect_td, &protos_td],
                                &mut output_boxes,
                                &mut output_masks,
                                &mut output_tracks,
                            )
                            .unwrap();
                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                        assert!(output_masks.is_empty());
                    }
                }
            }
        };
    }
5518
    // Instantiate `e2e_tracked_tensor_test!` once for every combination of element type
    // (`quantized` -> u8 tensors, `float` -> f64 tensors), output layout (`combined` /
    // `split`) and mask output (`masks` / `proto`). The first argument names the test.
    e2e_tracked_tensor_test!(
        test_decoder_tracked_tensor_end_to_end_segdet,
        quantized,
        combined,
        masks
    );
    e2e_tracked_tensor_test!(
        test_decoder_tracked_tensor_end_to_end_segdet_float,
        float,
        combined,
        masks
    );
    e2e_tracked_tensor_test!(
        test_decoder_tracked_tensor_end_to_end_segdet_proto,
        quantized,
        combined,
        proto
    );
    e2e_tracked_tensor_test!(
        test_decoder_tracked_tensor_end_to_end_segdet_proto_float,
        float,
        combined,
        proto
    );
    e2e_tracked_tensor_test!(
        test_decoder_tracked_tensor_end_to_end_segdet_split,
        quantized,
        split,
        masks
    );
    e2e_tracked_tensor_test!(
        test_decoder_tracked_tensor_end_to_end_segdet_split_float,
        float,
        split,
        masks
    );
    e2e_tracked_tensor_test!(
        test_decoder_tracked_tensor_end_to_end_segdet_split_proto,
        quantized,
        split,
        proto
    );
    e2e_tracked_tensor_test!(
        test_decoder_tracked_tensor_end_to_end_segdet_split_proto_float,
        float,
        split,
        proto
    );
5567
5568 #[test]
5569 fn test_decoder_tracked_linear_motion() {
5570 use crate::configs::{DecoderType, Nms};
5571 use crate::DecoderBuilder;
5572
5573 let score_threshold = 0.25;
5574 let iou_threshold = 0.1;
5575 let out = include_bytes!(concat!(
5576 env!("CARGO_MANIFEST_DIR"),
5577 "/../../testdata/yolov8s_80_classes.bin"
5578 ));
5579 let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
5580 let mut out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
5581 let quant = (0.0040811873, -123).into();
5582
5583 let decoder = DecoderBuilder::default()
5584 .with_config_yolo_det(
5585 crate::configs::Detection {
5586 decoder: DecoderType::Ultralytics,
5587 shape: vec![1, 84, 8400],
5588 anchors: None,
5589 quantization: Some(quant),
5590 dshape: vec![
5591 (crate::configs::DimName::Batch, 1),
5592 (crate::configs::DimName::NumFeatures, 84),
5593 (crate::configs::DimName::NumBoxes, 8400),
5594 ],
5595 normalized: Some(true),
5596 },
5597 None,
5598 )
5599 .with_score_threshold(score_threshold)
5600 .with_iou_threshold(iou_threshold)
5601 .with_nms(Some(Nms::ClassAgnostic))
5602 .build()
5603 .unwrap();
5604
5605 let mut expected_boxes = [
5606 DetectBox {
5607 bbox: BoundingBox {
5608 xmin: 0.5285137,
5609 ymin: 0.05305544,
5610 xmax: 0.87541467,
5611 ymax: 0.9998909,
5612 },
5613 score: 0.5591227,
5614 label: 0,
5615 },
5616 DetectBox {
5617 bbox: BoundingBox {
5618 xmin: 0.130598,
5619 ymin: 0.43260583,
5620 xmax: 0.35098213,
5621 ymax: 0.9958097,
5622 },
5623 score: 0.33057618,
5624 label: 75,
5625 },
5626 ];
5627
5628 let mut tracker = ByteTrackBuilder::new()
5629 .track_update(0.1)
5630 .track_high_conf(0.3)
5631 .build();
5632
5633 let mut output_boxes = Vec::with_capacity(50);
5634 let mut output_masks = Vec::with_capacity(50);
5635 let mut output_tracks = Vec::with_capacity(50);
5636
5637 decoder
5638 .decode_tracked_quantized(
5639 &mut tracker,
5640 0,
5641 &[out.view().into()],
5642 &mut output_boxes,
5643 &mut output_masks,
5644 &mut output_tracks,
5645 )
5646 .unwrap();
5647
5648 assert_eq!(output_boxes.len(), 2);
5649 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5650 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
5651
5652 for i in 1..=100 {
5653 let mut out = out.clone();
5654 let mut x_values = out.slice_mut(s![0, 0, ..]);
5656 for x in x_values.iter_mut() {
5657 *x = x.saturating_add((i as f32 * 1e-3 / quant.0).round() as i8);
5658 }
5659
5660 decoder
5661 .decode_tracked_quantized(
5662 &mut tracker,
5663 100_000_000 * i / 3, &[out.view().into()],
5665 &mut output_boxes,
5666 &mut output_masks,
5667 &mut output_tracks,
5668 )
5669 .unwrap();
5670
5671 assert_eq!(output_boxes.len(), 2);
5672 }
5673 let tracks = tracker.get_active_tracks();
5674 let predicted_boxes: Vec<_> = tracks
5675 .iter()
5676 .map(|track| {
5677 let mut l = track.last_box;
5678 l.bbox = track.info.tracked_location.into();
5679 l
5680 })
5681 .collect();
5682 expected_boxes[0].bbox.xmin += 0.1; expected_boxes[0].bbox.xmax += 0.1;
5684 expected_boxes[1].bbox.xmin += 0.1;
5685 expected_boxes[1].bbox.xmax += 0.1;
5686
5687 assert!(predicted_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
5688 assert!(predicted_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
5689
5690 let mut scores_values = out.slice_mut(s![0, 4.., ..]);
5692 for score in scores_values.iter_mut() {
5693 *score = i8::MIN; }
5695 decoder
5696 .decode_tracked_quantized(
5697 &mut tracker,
5698 100_000_000 * 101 / 3,
5699 &[out.view().into()],
5700 &mut output_boxes,
5701 &mut output_masks,
5702 &mut output_tracks,
5703 )
5704 .unwrap();
5705 expected_boxes[0].bbox.xmin += 0.001; expected_boxes[0].bbox.xmax += 0.001;
5707 expected_boxes[1].bbox.xmin += 0.001;
5708 expected_boxes[1].bbox.xmax += 0.001;
5709
5710 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
5711 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
5712 }
5713
5714 #[test]
5715 fn test_decoder_tracked_end_to_end_float() {
5716 let score_threshold = 0.45;
5717 let iou_threshold = 0.45;
5718
5719 let mut boxes = Array2::zeros((10, 4));
5720 let mut scores = Array2::zeros((10, 1));
5721 let mut classes = Array2::zeros((10, 1));
5722
5723 boxes
5724 .slice_mut(s![0, ..,])
5725 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5726 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5727 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5728
5729 let detect = ndarray::concatenate![Axis(1), boxes.view(), scores.view(), classes.view(),];
5730 let mut detect = detect.insert_axis(Axis(0));
5731 assert_eq!(detect.shape(), &[1, 10, 6]);
5732 let config = "
5733decoder_version: yolo26
5734outputs:
5735 - type: detection
5736 decoder: ultralytics
5737 quantization: [0.00784313725490196, 0]
5738 shape: [1, 10, 6]
5739 dshape:
5740 - [batch, 1]
5741 - [num_boxes, 10]
5742 - [num_features, 6]
5743 normalized: true
5744";
5745
5746 let decoder = DecoderBuilder::default()
5747 .with_config_yaml_str(config.to_string())
5748 .with_score_threshold(score_threshold)
5749 .with_iou_threshold(iou_threshold)
5750 .build()
5751 .unwrap();
5752
5753 let expected_boxes = [DetectBox {
5754 bbox: BoundingBox {
5755 xmin: 0.1234,
5756 ymin: 0.1234,
5757 xmax: 0.2345,
5758 ymax: 0.2345,
5759 },
5760 score: 0.9876,
5761 label: 2,
5762 }];
5763
5764 let mut tracker = ByteTrackBuilder::new()
5765 .track_update(0.1)
5766 .track_high_conf(0.7)
5767 .build();
5768
5769 let mut output_boxes = Vec::with_capacity(50);
5770 let mut output_masks = Vec::with_capacity(50);
5771 let mut output_tracks = Vec::with_capacity(50);
5772
5773 decoder
5774 .decode_tracked_float(
5775 &mut tracker,
5776 0,
5777 &[detect.view().into_dyn()],
5778 &mut output_boxes,
5779 &mut output_masks,
5780 &mut output_tracks,
5781 )
5782 .unwrap();
5783
5784 assert_eq!(output_boxes.len(), 1);
5785 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5786
5787 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5790 *score = 0.0; }
5792
5793 decoder
5794 .decode_tracked_float(
5795 &mut tracker,
5796 100_000_000 / 3,
5797 &[detect.view().into_dyn()],
5798 &mut output_boxes,
5799 &mut output_masks,
5800 &mut output_tracks,
5801 )
5802 .unwrap();
5803 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5804 }
5805}