1#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
68
69use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
70use num_traits::{AsPrimitive, Float, PrimInt};
71
72pub mod byte;
73pub mod error;
74pub mod float;
75pub mod modelpack;
76pub mod yolo;
77
78mod decoder;
79pub use decoder::*;
80
81pub use configs::{DecoderVersion, Nms};
82pub use error::{DecoderError, DecoderResult};
83
84use crate::{
85 decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
86 yolo::yolo_segmentation_to_mask,
87};
88
89pub trait BBoxTypeTrait {
91 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
93
94 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
96 input: &[B; 4],
97 quant: Quantization,
98 ) -> [A; 4]
99 where
100 f32: AsPrimitive<A>,
101 i32: AsPrimitive<A>;
102
103 #[inline(always)]
114 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
115 input: ArrayView1<B>,
116 ) -> [A; 4] {
117 Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
118 }
119
120 #[inline(always)]
121 fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
123 input: ArrayView1<B>,
124 quant: Quantization,
125 ) -> [A; 4]
126 where
127 f32: AsPrimitive<A>,
128 i32: AsPrimitive<A>,
129 {
130 Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
131 }
132}
133
134#[derive(Debug, Clone, Copy, PartialEq, Eq)]
136pub struct XYXY {}
137
138impl BBoxTypeTrait for XYXY {
139 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
140 input.map(|b| b.as_())
141 }
142
143 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
144 input: &[B; 4],
145 quant: Quantization,
146 ) -> [A; 4]
147 where
148 f32: AsPrimitive<A>,
149 i32: AsPrimitive<A>,
150 {
151 let scale = quant.scale.as_();
152 let zp = quant.zero_point.as_();
153 input.map(|b| (b.as_() - zp) * scale)
154 }
155
156 #[inline(always)]
157 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
158 input: ArrayView1<B>,
159 ) -> [A; 4] {
160 [
161 input[0].as_(),
162 input[1].as_(),
163 input[2].as_(),
164 input[3].as_(),
165 ]
166 }
167}
168
169#[derive(Debug, Clone, Copy, PartialEq, Eq)]
172pub struct XYWH {}
173
174impl BBoxTypeTrait for XYWH {
175 #[inline(always)]
176 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
177 let half = A::one() / (A::one() + A::one());
178 [
179 (input[0].as_()) - (input[2].as_() * half),
180 (input[1].as_()) - (input[3].as_() * half),
181 (input[0].as_()) + (input[2].as_() * half),
182 (input[1].as_()) + (input[3].as_() * half),
183 ]
184 }
185
186 #[inline(always)]
187 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
188 input: &[B; 4],
189 quant: Quantization,
190 ) -> [A; 4]
191 where
192 f32: AsPrimitive<A>,
193 i32: AsPrimitive<A>,
194 {
195 let scale = quant.scale.as_();
196 let half_scale = (quant.scale * 0.5).as_();
197 let zp = quant.zero_point.as_();
198 let [x, y, w, h] = [
199 (input[0].as_() - zp) * scale,
200 (input[1].as_() - zp) * scale,
201 (input[2].as_() - zp) * half_scale,
202 (input[3].as_() - zp) * half_scale,
203 ];
204
205 [x - w, y - h, x + w, y + h]
206 }
207
208 #[inline(always)]
209 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
210 input: ArrayView1<B>,
211 ) -> [A; 4] {
212 let half = A::one() / (A::one() + A::one());
213 [
214 (input[0].as_()) - (input[2].as_() * half),
215 (input[1].as_()) - (input[3].as_() * half),
216 (input[0].as_()) + (input[2].as_() * half),
217 (input[1].as_()) + (input[3].as_() * half),
218 ]
219 }
220}
221
/// Affine quantization parameters.
///
/// Code in this file maps a quantized value `q` to a real value as
/// `(q - zero_point) * scale`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quantization {
    /// Multiplicative step between adjacent quantized values.
    pub scale: f32,
    /// Quantized value that corresponds to real `0.0`.
    pub zero_point: i32,
}

impl Quantization {
    /// Builds a `Quantization` from its scale and zero point.
    pub fn new(scale: f32, zero_point: i32) -> Self {
        Quantization { scale, zero_point }
    }
}
242
243impl From<QuantTuple> for Quantization {
244 fn from(quant_tuple: QuantTuple) -> Quantization {
255 Quantization {
256 scale: quant_tuple.0,
257 zero_point: quant_tuple.1,
258 }
259 }
260}
261
262impl<S, Z> From<(S, Z)> for Quantization
263where
264 S: AsPrimitive<f32>,
265 Z: AsPrimitive<i32>,
266{
267 fn from((scale, zp): (S, Z)) -> Quantization {
276 Self {
277 scale: scale.as_(),
278 zero_point: zp.as_(),
279 }
280 }
281}
282
283impl Default for Quantization {
284 fn default() -> Self {
293 Self {
294 scale: 1.0,
295 zero_point: 0,
296 }
297 }
298}
299
/// A single detection: a box, its score and its class label.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct DetectBox {
    /// Location of the detection.
    pub bbox: BoundingBox,
    /// Score of the detection.
    pub score: f32,
    /// Class index of the detection.
    pub label: usize,
}

/// An axis-aligned box described by its two corners.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct BoundingBox {
    /// First x coordinate.
    pub xmin: f32,
    /// First y coordinate.
    pub ymin: f32,
    /// Second x coordinate.
    pub xmax: f32,
    /// Second y coordinate.
    pub ymax: f32,
}

impl BoundingBox {
    /// Creates a box from its four coordinates. No ordering is enforced; see
    /// [`BoundingBox::to_canonical`].
    pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }

    /// Returns a copy in which `xmin <= xmax` and `ymin <= ymax`, swapping the
    /// coordinates of an axis when they are out of order.
    pub fn to_canonical(&self) -> Self {
        let (x_a, x_b) = (self.xmin, self.xmax);
        let (y_a, y_b) = (self.ymin, self.ymax);
        Self::new(x_a.min(x_b), y_a.min(y_b), x_a.max(x_b), y_a.max(y_b))
    }
}

/// Flattens a box into `[xmin, ymin, xmax, ymax]`.
impl From<BoundingBox> for [f32; 4] {
    fn from(b: BoundingBox) -> Self {
        let BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        } = b;
        [xmin, ymin, xmax, ymax]
    }
}

/// Reads a box from `[xmin, ymin, xmax, ymax]`.
impl From<[f32; 4]> for BoundingBox {
    fn from(arr: [f32; 4]) -> Self {
        let [xmin, ymin, xmax, ymax] = arr;
        BoundingBox::new(xmin, ymin, xmax, ymax)
    }
}

impl DetectBox {
    /// Approximate equality: labels must match exactly, and the score and every
    /// box coordinate must differ by at most `eps` (inclusive).
    pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
        if self.label != rhs.label {
            return false;
        }
        let pairs = [
            (self.score, rhs.score),
            (self.bbox.xmin, rhs.bbox.xmin),
            (self.bbox.ymin, rhs.bbox.ymin),
            (self.bbox.xmax, rhs.bbox.xmax),
            (self.bbox.ymax, rhs.bbox.ymax),
        ];
        pairs.iter().all(|&(lhs, other)| (lhs - other).abs() <= eps)
    }
}
428
/// A segmentation result: a 3-D `u8` mask plus the box it applies to.
///
/// NOTE(review): tests in this file use `0.0..1.0` for the bounds of a
/// full-frame mask, so the coordinates look normalized — confirm with the
/// decoder implementations.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Segmentation {
    /// Lower x bound of the region the mask covers.
    pub xmin: f32,
    /// Lower y bound of the region the mask covers.
    pub ymin: f32,
    /// Upper x bound of the region the mask covers.
    pub xmax: f32,
    /// Upper y bound of the region the mask covers.
    pub ymax: f32,
    /// Mask data. `segmentation_to_mask` inspects the length of the third axis:
    /// `1` selects the YOLO conversion, larger values the ModelPack one.
    pub segmentation: Array3<u8>,
}
452
453#[derive(Debug, Clone)]
459pub enum ProtoTensor {
460 Quantized {
464 protos: Array3<i8>,
465 quantization: Quantization,
466 },
467 Float(Array3<f32>),
469}
470
471impl ProtoTensor {
472 pub fn is_quantized(&self) -> bool {
474 matches!(self, ProtoTensor::Quantized { .. })
475 }
476
477 pub fn dim(&self) -> (usize, usize, usize) {
479 match self {
480 ProtoTensor::Quantized { protos, .. } => protos.dim(),
481 ProtoTensor::Float(arr) => arr.dim(),
482 }
483 }
484
485 pub fn as_f32(&self) -> std::borrow::Cow<'_, Array3<f32>> {
488 match self {
489 ProtoTensor::Float(arr) => std::borrow::Cow::Borrowed(arr),
490 ProtoTensor::Quantized {
491 protos,
492 quantization,
493 } => {
494 let scale = quantization.scale;
495 let zp = quantization.zero_point as f32;
496 std::borrow::Cow::Owned(protos.map(|&v| (v as f32 - zp) * scale))
497 }
498 }
499 }
500}
501
/// Mask coefficients bundled with the prototype tensor they refer to.
///
/// NOTE(review): presumably one inner `Vec<f32>` per detection, to be combined
/// with `protos` when building instance masks — confirm against the decoder
/// that produces this type.
#[derive(Debug, Clone)]
pub struct ProtoData {
    /// Coefficient vectors, one `Vec<f32>` per entry.
    pub mask_coefficients: Vec<Vec<f32>>,
    /// Prototype tensor, quantized or float (see [`ProtoTensor`]).
    pub protos: ProtoTensor,
}
514
515pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
533 detect: &DetectBoxQuantized<SCORE>,
534 quant_scores: Quantization,
535) -> DetectBox {
536 let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
537 DetectBox {
538 bbox: detect.bbox,
539 score: quant_scores.scale * detect.score.as_() + scaled_zp,
540 label: detect.label,
541 }
542}
/// A detection whose score is still in its quantized integer form.
///
/// `dequant_detect_box` converts it into a [`DetectBox`]; the box and label are
/// carried over as-is.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DetectBoxQuantized<
    // Integer score type; it must be castable to `f32` for dequantization.
    SCORE: PrimInt + AsPrimitive<f32>,
> {
    /// Location of the detection, already in floating point.
    pub bbox: BoundingBox,
    /// Raw quantized score.
    pub score: SCORE,
    /// Class index of the detection.
    pub label: usize,
}
557
558pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
571 input: ArrayView<T, D>,
572 quant: Quantization,
573) -> Array<F, D>
574where
575 i32: num_traits::AsPrimitive<F>,
576 f32: num_traits::AsPrimitive<F>,
577{
578 let zero_point = quant.zero_point.as_();
579 let scale = quant.scale.as_();
580 if zero_point != F::zero() {
581 let scaled_zero = -zero_point * scale;
582 input.mapv(|d| d.as_() * scale + scaled_zero)
583 } else {
584 input.mapv(|d| d.as_() * scale)
585 }
586}
587
588pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
601 input: &[T],
602 quant: Quantization,
603 output: &mut [F],
604) where
605 f32: num_traits::AsPrimitive<F>,
606 i32: num_traits::AsPrimitive<F>,
607{
608 assert!(input.len() == output.len());
609 let zero_point = quant.zero_point.as_();
610 let scale = quant.scale.as_();
611 if zero_point != F::zero() {
612 let scaled_zero = -zero_point * scale; input
614 .iter()
615 .zip(output)
616 .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
617 } else {
618 input
619 .iter()
620 .zip(output)
621 .for_each(|(d, deq)| *deq = d.as_() * scale);
622 }
623}
624
625pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
639 input: &[T],
640 quant: Quantization,
641 output: &mut [F],
642) where
643 f32: num_traits::AsPrimitive<F>,
644 i32: num_traits::AsPrimitive<F>,
645{
646 assert!(input.len() == output.len());
647 let zero_point = quant.zero_point.as_();
648 let scale = quant.scale.as_();
649
650 let input = input.as_chunks::<4>();
651 let output = output.as_chunks_mut::<4>();
652
653 if zero_point != F::zero() {
654 let scaled_zero = -zero_point * scale; input
657 .0
658 .iter()
659 .zip(output.0)
660 .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
661 input
662 .1
663 .iter()
664 .zip(output.1)
665 .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
666 } else {
667 input
668 .0
669 .iter()
670 .zip(output.0)
671 .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
672 input
673 .1
674 .iter()
675 .zip(output.1)
676 .for_each(|(d, deq)| *deq = d.as_() * scale);
677 }
678}
679
680pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
698 if segmentation.shape()[2] == 0 {
699 return Err(DecoderError::InvalidShape(
700 "Segmentation tensor must have non-zero depth".to_string(),
701 ));
702 }
703 if segmentation.shape()[2] == 1 {
704 yolo_segmentation_to_mask(segmentation, 128)
705 } else {
706 Ok(modelpack_segmentation_to_mask(segmentation))
707 }
708}
709
710fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
712 score
713 .iter()
714 .enumerate()
715 .fold((score[0], 0), |(max, arg_max), (ind, s)| {
716 if max > *s {
717 (max, arg_max)
718 } else {
719 (*s, ind)
720 }
721 })
722}
723#[cfg(test)]
724#[cfg_attr(coverage_nightly, coverage(off))]
725mod decoder_tests {
726 #![allow(clippy::excessive_precision)]
727 use crate::{
728 configs::{DecoderType, DimName, Protos},
729 modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
730 yolo::{
731 decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
732 decode_yolo_segdet_quant,
733 },
734 *,
735 };
736 use ndarray::{array, s, Array4};
737 use ndarray_stats::DeviationExt;
738
739 fn compare_outputs(
740 boxes: (&[DetectBox], &[DetectBox]),
741 masks: (&[Segmentation], &[Segmentation]),
742 ) {
743 let (boxes0, boxes1) = boxes;
744 let (masks0, masks1) = masks;
745
746 assert_eq!(boxes0.len(), boxes1.len());
747 assert_eq!(masks0.len(), masks1.len());
748
749 for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
750 assert!(
751 b_i8.equal_within_delta(b_f32, 1e-6),
752 "{b_i8:?} is not equal to {b_f32:?}"
753 );
754 }
755
756 for (m_i8, m_f32) in masks0.iter().zip(masks1) {
757 assert_eq!(
758 [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
759 [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
760 );
761 assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
762 let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
763 let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
764 let diff = &mask_i8 - &mask_f32;
765 for x in 0..diff.shape()[0] {
766 for y in 0..diff.shape()[1] {
767 for z in 0..diff.shape()[2] {
768 let val = diff[[x, y, z]];
769 assert!(
770 val.abs() <= 1,
771 "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
772 x,
773 y,
774 z,
775 val
776 );
777 }
778 }
779 }
780 let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
781 assert!(
782 mean_sq_err < 1e-2,
783 "Mean Square Error between masks was greater than 1%: {:.2}%",
784 mean_sq_err * 100.0
785 );
786 }
787 }
788
/// ModelPack detection with separate box and score tensors.
///
/// Runs the same quantized fixture through three paths and checks they agree:
/// the free function `decode_modelpack_det`, `Decoder::decode_quantized`, and
/// `Decoder::decode_float` on the dequantized tensors.
#[test]
fn test_decoder_modelpack() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;
    let boxes = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_boxes_1935x1x4.bin"
    ));
    let boxes = ndarray::Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();

    let scores = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_scores_1935x1.bin"
    ));
    let scores = ndarray::Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();

    // (scale, zero_point) pairs; the concrete type is inferred from the config
    // fields they are assigned to below.
    let quant_boxes = (0.004656755365431309, 21).into();
    let quant_scores = (0.0019603664986789227, 0).into();

    let decoder = DecoderBuilder::default()
        .with_config_modelpack_det(
            configs::Boxes {
                decoder: DecoderType::ModelPack,
                quantization: Some(quant_boxes),
                shape: vec![1, 1935, 1, 4],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumBoxes, 1935),
                    (DimName::Padding, 1),
                    (DimName::BoxCoords, 4),
                ],
                normalized: Some(true),
            },
            configs::Scores {
                decoder: DecoderType::ModelPack,
                quantization: Some(quant_scores),
                shape: vec![1, 1935, 1],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumBoxes, 1935),
                    (DimName::NumClasses, 1),
                ],
            },
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Convert the config-level quantization values into the type the free
    // decode function and `dequantize_ndarray` expect.
    let quant_boxes = quant_boxes.into();
    let quant_scores = quant_scores.into();

    // Path 1: free function on the batch-0 slices.
    let mut output_boxes: Vec<_> = Vec::with_capacity(50);
    decode_modelpack_det(
        (boxes.slice(s![0, .., 0, ..]), quant_boxes),
        (scores.slice(s![0, .., ..]), quant_scores),
        score_threshold,
        iou_threshold,
        &mut output_boxes,
    );
    assert!(output_boxes[0].equal_within_delta(
        &DetectBox {
            bbox: BoundingBox {
                xmin: 0.40513772,
                ymin: 0.6379755,
                xmax: 0.5122431,
                ymax: 0.7730214,
            },
            score: 0.4861709,
            label: 0
        },
        1e-6
    ));

    // Path 2: the configured decoder on the quantized tensors.
    let mut output_boxes1 = Vec::with_capacity(50);
    let mut output_masks1 = Vec::with_capacity(50);

    decoder
        .decode_quantized(
            &[boxes.view().into(), scores.view().into()],
            &mut output_boxes1,
            &mut output_masks1,
        )
        .unwrap();

    // Path 3: the configured decoder on the dequantized tensors.
    let mut output_boxes_float = Vec::with_capacity(50);
    let mut output_masks_float = Vec::with_capacity(50);

    let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
    let scores = dequantize_ndarray(scores.view(), quant_scores);

    decoder
        .decode_float::<f32>(
            &[boxes.view().into_dyn(), scores.view().into_dyn()],
            &mut output_boxes_float,
            &mut output_masks_float,
        )
        .unwrap();

    // A detection-only model is expected to produce no masks.
    compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
    compare_outputs(
        (&output_boxes, &output_boxes_float),
        (&[], &output_masks_float),
    );
}
894
/// ModelPack split (per-scale, anchor-based) detection on `u8` tensors.
///
/// Checks that `decode_modelpack_split_quant`, `Decoder::decode_quantized` and
/// `Decoder::decode_float` all produce the same boxes.
#[test]
fn test_decoder_modelpack_split_u8() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;
    let detect0 = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split_9x15x18.bin"
    ));
    let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

    let detect1 = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split_17x30x18.bin"
    ));
    let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

    let quant0 = (0.08547406643629074, 174).into();
    let quant1 = (0.09929127991199493, 183).into();
    // Three anchors per output scale.
    let anchors0 = vec![
        [0.36666667461395264, 0.31481480598449707],
        [0.38749998807907104, 0.4740740656852722],
        [0.5333333611488342, 0.644444465637207],
    ];
    let anchors1 = vec![
        [0.13750000298023224, 0.2074074000120163],
        [0.2541666626930237, 0.21481481194496155],
        [0.23125000298023224, 0.35185185074806213],
    ];

    let detect_config0 = configs::Detection {
        decoder: DecoderType::ModelPack,
        shape: vec![1, 9, 15, 18],
        anchors: Some(anchors0.clone()),
        quantization: Some(quant0),
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 9),
            (DimName::Width, 15),
            (DimName::NumAnchorsXFeatures, 18),
        ],
        normalized: Some(true),
    };

    let detect_config1 = configs::Detection {
        decoder: DecoderType::ModelPack,
        shape: vec![1, 17, 30, 18],
        anchors: Some(anchors1.clone()),
        quantization: Some(quant1),
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 17),
            (DimName::Width, 30),
            (DimName::NumAnchorsXFeatures, 18),
        ],
        normalized: Some(true),
    };

    // Per-output configs for the free function; the target type is inferred
    // from `decode_modelpack_split_quant`.
    let config0 = (&detect_config0).try_into().unwrap();
    let config1 = (&detect_config1).try_into().unwrap();

    // NOTE(review): the configs are registered as [1, 0] while the tensors are
    // later passed as [0, 1]; presumably the decoder matches inputs to configs
    // by shape — confirm.
    let decoder = DecoderBuilder::default()
        .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let quant0 = quant0.into();
    let quant1 = quant1.into();

    // Path 1: free function on the batch-0 slices.
    let mut output_boxes: Vec<_> = Vec::with_capacity(10);
    decode_modelpack_split_quant(
        &[
            detect0.slice(s![0, .., .., ..]),
            detect1.slice(s![0, .., .., ..]),
        ],
        &[config0, config1],
        score_threshold,
        iou_threshold,
        &mut output_boxes,
    );
    assert!(output_boxes[0].equal_within_delta(
        &DetectBox {
            bbox: BoundingBox {
                xmin: 0.43171933,
                ymin: 0.68243736,
                xmax: 0.5626645,
                ymax: 0.808863,
            },
            score: 0.99240804,
            label: 0
        },
        1e-6
    ));

    // Path 2: the configured decoder on the quantized tensors.
    let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
    let mut output_masks1: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(
            &[detect0.view().into(), detect1.view().into()],
            &mut output_boxes1,
            &mut output_masks1,
        )
        .unwrap();

    // Path 3: the configured decoder on the dequantized tensors.
    let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
    let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);

    let detect0 = dequantize_ndarray(detect0.view(), quant0);
    let detect1 = dequantize_ndarray(detect1.view(), quant1);
    decoder
        .decode_float::<f32>(
            &[detect0.view().into_dyn(), detect1.view().into_dyn()],
            &mut output_boxes1_f32,
            &mut output_masks1_f32,
        )
        .unwrap();

    // A detection-only model is expected to produce no masks.
    compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
    compare_outputs(
        (&output_boxes, &output_boxes1_f32),
        (&[], &output_masks1_f32),
    );
}
1019
/// Same fixture as the split-detection test, but the decoder is built from the
/// YAML file `testdata/modelpack_split.yaml` rather than from in-code configs.
#[test]
fn test_decoder_parse_config_modelpack_split_u8() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;
    let detect0 = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split_9x15x18.bin"
    ));
    let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

    let detect1 = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split_17x30x18.bin"
    ));
    let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

    let decoder = DecoderBuilder::default()
        .with_config_yaml_str(
            include_str!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/modelpack_split.yaml"
            ))
            .to_string(),
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let mut output_boxes: Vec<_> = Vec::with_capacity(10);
    let mut output_masks: Vec<_> = Vec::with_capacity(10);
    // The tensors are passed as [detect1, detect0] here.
    // NOTE(review): presumably this matches the output order in the YAML, or
    // the decoder matches by shape — confirm.
    decoder
        .decode_quantized(
            &[
                ArrayViewDQuantized::from(detect1.view()),
                ArrayViewDQuantized::from(detect0.view()),
            ],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();
    // Same expected first detection as the in-code split config test.
    assert!(output_boxes[0].equal_within_delta(
        &DetectBox {
            bbox: BoundingBox {
                xmin: 0.43171933,
                ymin: 0.68243736,
                xmax: 0.5626645,
                ymax: 0.808863,
            },
            score: 0.99240804,
            label: 0
        },
        1e-6
    ));
}
1075
/// ModelPack segmentation-only model: the quantized decode must reproduce the
/// input tensor rearranged to (height, width, classes), and the float decode
/// must reduce to the same class mask.
#[test]
fn test_modelpack_seg() {
    let out = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_seg_2x160x160.bin"
    ));
    let out = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
    let quant = (1.0 / 255.0, 0).into();

    let decoder = DecoderBuilder::default()
        .with_config_modelpack_seg(configs::Segmentation {
            decoder: DecoderType::ModelPack,
            quantization: Some(quant),
            shape: vec![1, 2, 160, 160],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumClasses, 2),
                (DimName::Height, 160),
                (DimName::Width, 160),
            ],
        })
        .build()
        .unwrap();
    let mut output_boxes: Vec<_> = Vec::with_capacity(10);
    let mut output_masks: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
        .unwrap();

    // Build the expected mask: drop the batch axis, then move the class axis
    // last: (classes, h, w) -> (h, classes, w) -> (h, w, classes).
    let mut mask = out.slice(s![0, .., .., ..]);
    mask.swap_axes(0, 1);
    mask.swap_axes(1, 2);
    let mask = [Segmentation {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 1.0,
        ymax: 1.0,
        segmentation: mask.into_owned(),
    }];
    compare_outputs((&[], &output_boxes), (&mask, &output_masks));

    // The same output buffers are reused for the float decode.
    decoder
        .decode_float::<f32>(
            &[dequantize_ndarray(out.view(), quant.into())
                .view()
                .into_dyn()],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    // No boxes are expected. The float masks are compared only after the
    // per-pixel class reduction below, not element by element.
    compare_outputs((&[], &output_boxes), (&[], &[]));
    let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
    let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

    assert_eq!(mask0, mask1);
}
/// ModelPack segmentation across integer input types: the same fixture is
/// re-encoded as i8/u16/i16/u32/i32 and every decode must reduce to the same
/// class mask as the original u8 data.
#[test]
fn test_modelpack_seg_quant() {
    let out = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_seg_2x160x160.bin"
    ));
    let out_u8 = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
    // Re-encode the u8 data: signed variants shift the range down by half,
    // wider variants move the value into the most significant byte.
    let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
    let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
    let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
    let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
    let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);

    let quant = (1.0 / 255.0, 0).into();

    let decoder = DecoderBuilder::default()
        .with_config_modelpack_seg(configs::Segmentation {
            decoder: DecoderType::ModelPack,
            quantization: Some(quant),
            shape: vec![1, 2, 160, 160],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumClasses, 2),
                (DimName::Height, 160),
                (DimName::Width, 160),
            ],
        })
        .build()
        .unwrap();
    // `output_boxes` is shared by all decode calls; each call gets its own
    // mask vector.
    let mut output_boxes: Vec<_> = Vec::with_capacity(10);
    let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(
            &[out_u8.view().into()],
            &mut output_boxes,
            &mut output_masks_u8,
        )
        .unwrap();

    let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(
            &[out_i8.view().into()],
            &mut output_boxes,
            &mut output_masks_i8,
        )
        .unwrap();

    let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(
            &[out_u16.view().into()],
            &mut output_boxes,
            &mut output_masks_u16,
        )
        .unwrap();

    let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(
            &[out_i16.view().into()],
            &mut output_boxes,
            &mut output_masks_i16,
        )
        .unwrap();

    let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(
            &[out_u32.view().into()],
            &mut output_boxes,
            &mut output_masks_u32,
        )
        .unwrap();

    let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(
            &[out_i32.view().into()],
            &mut output_boxes,
            &mut output_masks_i32,
        )
        .unwrap();

    // A segmentation-only model is expected to produce no boxes.
    compare_outputs((&[], &output_boxes), (&[], &[]));
    let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
    let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
    let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
    let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
    let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
    let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();
    assert_eq!(mask_u8, mask_i8);
    assert_eq!(mask_u8, mask_u16);
    assert_eq!(mask_u8, mask_i16);
    assert_eq!(mask_u8, mask_u32);
    assert_eq!(mask_u8, mask_i32);
}
1235
/// ModelPack combined detection + segmentation with separate box, score and
/// segmentation tensors, checked on both the quantized and the float path.
#[test]
fn test_modelpack_segdet() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let boxes = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_boxes_1935x1x4.bin"
    ));
    let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();

    let scores = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_scores_1935x1.bin"
    ));
    let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();

    let seg = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_seg_2x160x160.bin"
    ));
    let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

    let quant_boxes = (0.004656755365431309, 21).into();
    let quant_scores = (0.0019603664986789227, 0).into();
    let quant_seg = (1.0 / 255.0, 0).into();

    let decoder = DecoderBuilder::default()
        .with_config_modelpack_segdet(
            configs::Boxes {
                decoder: DecoderType::ModelPack,
                quantization: Some(quant_boxes),
                shape: vec![1, 1935, 1, 4],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumBoxes, 1935),
                    (DimName::Padding, 1),
                    (DimName::BoxCoords, 4),
                ],
                normalized: Some(true),
            },
            configs::Scores {
                decoder: DecoderType::ModelPack,
                quantization: Some(quant_scores),
                shape: vec![1, 1935, 1],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumBoxes, 1935),
                    (DimName::NumClasses, 1),
                ],
            },
            configs::Segmentation {
                decoder: DecoderType::ModelPack,
                quantization: Some(quant_seg),
                shape: vec![1, 2, 160, 160],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 2),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                ],
            },
        )
        .with_iou_threshold(iou_threshold)
        .with_score_threshold(score_threshold)
        .build()
        .unwrap();
    let mut output_boxes: Vec<_> = Vec::with_capacity(10);
    let mut output_masks: Vec<_> = Vec::with_capacity(10);
    // The tensors are passed as scores, boxes, seg — not in the order the
    // configs were given.
    // NOTE(review): presumably the decoder matches by shape — confirm.
    decoder
        .decode_quantized(
            &[scores.view().into(), boxes.view().into(), seg.view().into()],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    // Build the expected mask: drop the batch axis, then move the class axis
    // last: (classes, h, w) -> (h, classes, w) -> (h, w, classes).
    let mut mask = seg.slice(s![0, .., .., ..]);
    mask.swap_axes(0, 1);
    mask.swap_axes(1, 2);
    let mask = [Segmentation {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 1.0,
        ymax: 1.0,
        segmentation: mask.into_owned(),
    }];
    let correct_boxes = [DetectBox {
        bbox: BoundingBox {
            xmin: 0.40513772,
            ymin: 0.6379755,
            xmax: 0.5122431,
            ymax: 0.7730214,
        },
        score: 0.4861709,
        label: 0,
    }];
    compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));

    // Float path, reusing the same output buffers.
    let scores = dequantize_ndarray(scores.view(), quant_scores.into());
    let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
    let seg = dequantize_ndarray(seg.view(), quant_seg.into());
    decoder
        .decode_float::<f32>(
            &[
                scores.view().into_dyn(),
                boxes.view().into_dyn(),
                seg.view().into_dyn(),
            ],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    // Boxes must still match. The float masks are compared only after the
    // per-pixel class reduction below, not element by element.
    compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
    let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
    let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

    assert_eq!(mask0, mask1);
}
1361
/// ModelPack combined split detection (two anchor-based scales) plus
/// segmentation, checked on both the quantized and the float path.
#[test]
fn test_modelpack_segdet_split() {
    let score_threshold = 0.8;
    let iou_threshold = 0.5;

    let seg = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_seg_2x160x160.bin"
    ));
    let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

    let detect0 = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split_9x15x18.bin"
    ));
    let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

    let detect1 = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split_17x30x18.bin"
    ));
    let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

    let quant0 = (0.08547406643629074, 174).into();
    let quant1 = (0.09929127991199493, 183).into();
    let quant_seg = (1.0 / 255.0, 0).into();

    // Three anchors per output scale.
    let anchors0 = vec![
        [0.36666667461395264, 0.31481480598449707],
        [0.38749998807907104, 0.4740740656852722],
        [0.5333333611488342, 0.644444465637207],
    ];
    let anchors1 = vec![
        [0.13750000298023224, 0.2074074000120163],
        [0.2541666626930237, 0.21481481194496155],
        [0.23125000298023224, 0.35185185074806213],
    ];

    let decoder = DecoderBuilder::default()
        .with_config_modelpack_segdet_split(
            vec![
                configs::Detection {
                    decoder: DecoderType::ModelPack,
                    shape: vec![1, 17, 30, 18],
                    anchors: Some(anchors1),
                    quantization: Some(quant1),
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::Height, 17),
                        (DimName::Width, 30),
                        (DimName::NumAnchorsXFeatures, 18),
                    ],
                    normalized: Some(true),
                },
                configs::Detection {
                    decoder: DecoderType::ModelPack,
                    shape: vec![1, 9, 15, 18],
                    anchors: Some(anchors0),
                    quantization: Some(quant0),
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::Height, 9),
                        (DimName::Width, 15),
                        (DimName::NumAnchorsXFeatures, 18),
                    ],
                    normalized: Some(true),
                },
            ],
            configs::Segmentation {
                decoder: DecoderType::ModelPack,
                quantization: Some(quant_seg),
                shape: vec![1, 2, 160, 160],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 2),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                ],
            },
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();
    let mut output_boxes: Vec<_> = Vec::with_capacity(10);
    let mut output_masks: Vec<_> = Vec::with_capacity(10);
    // The tensors are passed as detect0, detect1, seg while the detection
    // configs were registered as [1, 0].
    // NOTE(review): presumably the decoder matches by shape — confirm.
    decoder
        .decode_quantized(
            &[
                detect0.view().into(),
                detect1.view().into(),
                seg.view().into(),
            ],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    // Build the expected mask: drop the batch axis, then move the class axis
    // last: (classes, h, w) -> (h, classes, w) -> (h, w, classes).
    let mut mask = seg.slice(s![0, .., .., ..]);
    mask.swap_axes(0, 1);
    mask.swap_axes(1, 2);
    let mask = [Segmentation {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 1.0,
        ymax: 1.0,
        segmentation: mask.into_owned(),
    }];
    let correct_boxes = [DetectBox {
        bbox: BoundingBox {
            xmin: 0.43171933,
            ymin: 0.68243736,
            xmax: 0.5626645,
            ymax: 0.808863,
        },
        score: 0.99240804,
        label: 0,
    }];
    // Diagnostic output; only visible with `--nocapture` or on failure.
    println!("Output Boxes: {:?}", output_boxes);
    compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));

    // Float path, reusing the same output buffers.
    let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
    let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
    let seg = dequantize_ndarray(seg.view(), quant_seg.into());
    decoder
        .decode_float::<f32>(
            &[
                detect0.view().into_dyn(),
                detect1.view().into_dyn(),
                seg.view().into_dyn(),
            ],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    // Boxes must still match. The float masks are compared only after the
    // per-pixel class reduction below, not element by element.
    compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
    let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
    let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

    assert_eq!(mask0, mask1);
}
1509
1510 #[test]
1511 fn test_dequant_chunked() {
1512 let out = include_bytes!(concat!(
1513 env!("CARGO_MANIFEST_DIR"),
1514 "/../../testdata/yolov8s_80_classes.bin"
1515 ));
1516 let mut out =
1517 unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) }.to_vec();
1518 out.push(123); let mut out_dequant = vec![0.0; 84 * 8400 + 1];
1521 let mut out_dequant_simd = vec![0.0; 84 * 8400 + 1];
1522 let quant = Quantization::new(0.0040811873, -123);
1523 dequantize_cpu(&out, quant, &mut out_dequant);
1524
1525 dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1526 assert_eq!(out_dequant, out_dequant_simd);
1527
1528 let quant = Quantization::new(0.0040811873, 0);
1529 dequantize_cpu(&out, quant, &mut out_dequant);
1530
1531 dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1532 assert_eq!(out_dequant, out_dequant_simd);
1533 }
1534
1535 #[test]
1536 fn test_dequant_ground_truth() {
1537 let quant = Quantization::new(0.1, -128);
1542 let input: Vec<i8> = vec![0, 127, -128, 64];
1543 let mut output = vec![0.0f32; 4];
1544 let mut output_chunked = vec![0.0f32; 4];
1545 dequantize_cpu(&input, quant, &mut output);
1546 dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1547 let expected: Vec<f32> = vec![12.8, 25.5, 0.0, 19.2];
1552 for (i, (&out, &exp)) in output.iter().zip(expected.iter()).enumerate() {
1553 assert!((out - exp).abs() < 1e-5, "cpu[{i}]: {out} != {exp}");
1554 }
1555 for (i, (&out, &exp)) in output_chunked.iter().zip(expected.iter()).enumerate() {
1556 assert!((out - exp).abs() < 1e-5, "chunked[{i}]: {out} != {exp}");
1557 }
1558
1559 let quant = Quantization::new(1.0, 0);
1561 dequantize_cpu(&input, quant, &mut output);
1562 dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1563 let expected: Vec<f32> = vec![0.0, 127.0, -128.0, 64.0];
1564 assert_eq!(output, expected);
1565 assert_eq!(output_chunked, expected);
1566
1567 let quant = Quantization::new(0.5, 0);
1569 dequantize_cpu(&input, quant, &mut output);
1570 dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1571 let expected: Vec<f32> = vec![0.0, 63.5, -64.0, 32.0];
1572 assert_eq!(output, expected);
1573 assert_eq!(output_chunked, expected);
1574
1575 let quant = Quantization::new(0.021287762, 31);
1577 let input: Vec<i8> = vec![-128, -1, 0, 1, 31, 127];
1578 let mut output = vec![0.0f32; 6];
1579 let mut output_chunked = vec![0.0f32; 6];
1580 dequantize_cpu(&input, quant, &mut output);
1581 dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1582 for i in 0..6 {
1583 let expected = (input[i] as f32 - 31.0) * 0.021287762;
1584 assert!(
1585 (output[i] - expected).abs() < 1e-5,
1586 "cpu[{i}]: {} != {expected}",
1587 output[i]
1588 );
1589 assert!(
1590 (output_chunked[i] - expected).abs() < 1e-5,
1591 "chunked[{i}]: {} != {expected}",
1592 output_chunked[i]
1593 );
1594 }
1595 }
1596
1597 #[test]
1598 fn test_decoder_yolo_det() {
1599 let score_threshold = 0.25;
1600 let iou_threshold = 0.7;
1601 let out = include_bytes!(concat!(
1602 env!("CARGO_MANIFEST_DIR"),
1603 "/../../testdata/yolov8s_80_classes.bin"
1604 ));
1605 let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
1606 let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
1607 let quant = (0.0040811873, -123).into();
1608
1609 let decoder = DecoderBuilder::default()
1610 .with_config_yolo_det(
1611 configs::Detection {
1612 decoder: DecoderType::Ultralytics,
1613 shape: vec![1, 84, 8400],
1614 anchors: None,
1615 quantization: Some(quant),
1616 dshape: vec![
1617 (DimName::Batch, 1),
1618 (DimName::NumFeatures, 84),
1619 (DimName::NumBoxes, 8400),
1620 ],
1621 normalized: Some(true),
1622 },
1623 Some(DecoderVersion::Yolo11),
1624 )
1625 .with_score_threshold(score_threshold)
1626 .with_iou_threshold(iou_threshold)
1627 .build()
1628 .unwrap();
1629
1630 let mut output_boxes: Vec<_> = Vec::with_capacity(50);
1631 decode_yolo_det(
1632 (out.slice(s![0, .., ..]), quant.into()),
1633 score_threshold,
1634 iou_threshold,
1635 Some(configs::Nms::ClassAgnostic),
1636 &mut output_boxes,
1637 );
1638 assert!(output_boxes[0].equal_within_delta(
1639 &DetectBox {
1640 bbox: BoundingBox {
1641 xmin: 0.5285137,
1642 ymin: 0.05305544,
1643 xmax: 0.87541467,
1644 ymax: 0.9998909,
1645 },
1646 score: 0.5591227,
1647 label: 0
1648 },
1649 1e-6
1650 ));
1651
1652 assert!(output_boxes[1].equal_within_delta(
1653 &DetectBox {
1654 bbox: BoundingBox {
1655 xmin: 0.130598,
1656 ymin: 0.43260583,
1657 xmax: 0.35098213,
1658 ymax: 0.9958097,
1659 },
1660 score: 0.33057618,
1661 label: 75
1662 },
1663 1e-6
1664 ));
1665
1666 let mut output_boxes1: Vec<_> = Vec::with_capacity(50);
1667 let mut output_masks1: Vec<_> = Vec::with_capacity(50);
1668 decoder
1669 .decode_quantized(&[out.view().into()], &mut output_boxes1, &mut output_masks1)
1670 .unwrap();
1671
1672 let out = dequantize_ndarray(out.view(), quant.into());
1673 let mut output_boxes_f32: Vec<_> = Vec::with_capacity(50);
1674 let mut output_masks_f32: Vec<_> = Vec::with_capacity(50);
1675 decoder
1676 .decode_float::<f32>(
1677 &[out.view().into_dyn()],
1678 &mut output_boxes_f32,
1679 &mut output_masks_f32,
1680 )
1681 .unwrap();
1682
1683 compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1684 compare_outputs((&output_boxes, &output_boxes_f32), (&[], &output_masks_f32));
1685 }
1686
1687 #[test]
1688 fn test_decoder_masks() {
1689 let score_threshold = 0.45;
1690 let iou_threshold = 0.45;
1691 let boxes = include_bytes!(concat!(
1692 env!("CARGO_MANIFEST_DIR"),
1693 "/../../testdata/yolov8_boxes_116x8400.bin"
1694 ));
1695 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1696 let boxes = ndarray::Array2::from_shape_vec((116, 8400), boxes.to_vec()).unwrap();
1697 let quant_boxes = Quantization::new(0.021287761628627777, 31);
1698
1699 let protos = include_bytes!(concat!(
1700 env!("CARGO_MANIFEST_DIR"),
1701 "/../../testdata/yolov8_protos_160x160x32.bin"
1702 ));
1703 let protos =
1704 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1705 let protos = ndarray::Array3::from_shape_vec((160, 160, 32), protos.to_vec()).unwrap();
1706 let quant_protos = Quantization::new(0.02491161972284317, -117);
1707 let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1708 let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1709 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1710 let mut output_masks: Vec<_> = Vec::with_capacity(10);
1711 decode_yolo_segdet_float(
1712 seg.view(),
1713 protos.view(),
1714 score_threshold,
1715 iou_threshold,
1716 Some(configs::Nms::ClassAgnostic),
1717 &mut output_boxes,
1718 &mut output_masks,
1719 )
1720 .unwrap();
1721 assert_eq!(output_boxes.len(), 2);
1722 assert_eq!(output_boxes.len(), output_masks.len());
1723
1724 for (b, m) in output_boxes.iter().zip(&output_masks) {
1725 assert!(b.bbox.xmin >= m.xmin);
1726 assert!(b.bbox.ymin >= m.ymin);
1727 assert!(b.bbox.xmax >= m.xmax);
1728 assert!(b.bbox.ymax >= m.ymax);
1729 }
1730 assert!(output_boxes[0].equal_within_delta(
1731 &DetectBox {
1732 bbox: BoundingBox {
1733 xmin: 0.08515105,
1734 ymin: 0.7131401,
1735 xmax: 0.29802868,
1736 ymax: 0.8195788,
1737 },
1738 score: 0.91537374,
1739 label: 23
1740 },
1741 1.0 / 160.0, ));
1743
1744 assert!(output_boxes[1].equal_within_delta(
1745 &DetectBox {
1746 bbox: BoundingBox {
1747 xmin: 0.59605736,
1748 ymin: 0.25545314,
1749 xmax: 0.93666154,
1750 ymax: 0.72378385,
1751 },
1752 score: 0.91537374,
1753 label: 23
1754 },
1755 1.0 / 160.0, ));
1757
1758 let full_mask = include_bytes!(concat!(
1759 env!("CARGO_MANIFEST_DIR"),
1760 "/../../testdata/yolov8_mask_results.bin"
1761 ));
1762 let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();
1763
1764 let cropped_mask = full_mask.slice(ndarray::s![
1765 (output_masks[1].ymin * 160.0) as usize..(output_masks[1].ymax * 160.0) as usize,
1766 (output_masks[1].xmin * 160.0) as usize..(output_masks[1].xmax * 160.0) as usize,
1767 ]);
1768
1769 assert_eq!(
1770 cropped_mask,
1771 segmentation_to_mask(output_masks[1].segmentation.view()).unwrap()
1772 );
1773 }
1774
1775 #[test]
1780 fn test_decoder_masks_nchw_protos() {
1781 let score_threshold = 0.45;
1782 let iou_threshold = 0.45;
1783
1784 let boxes_raw = include_bytes!(concat!(
1786 env!("CARGO_MANIFEST_DIR"),
1787 "/../../testdata/yolov8_boxes_116x8400.bin"
1788 ));
1789 let boxes_raw =
1790 unsafe { std::slice::from_raw_parts(boxes_raw.as_ptr() as *const i8, boxes_raw.len()) };
1791 let boxes_2d = ndarray::Array2::from_shape_vec((116, 8400), boxes_raw.to_vec()).unwrap();
1792 let quant_boxes = Quantization::new(0.021287761628627777, 31);
1793
1794 let protos_raw = include_bytes!(concat!(
1796 env!("CARGO_MANIFEST_DIR"),
1797 "/../../testdata/yolov8_protos_160x160x32.bin"
1798 ));
1799 let protos_raw = unsafe {
1800 std::slice::from_raw_parts(protos_raw.as_ptr() as *const i8, protos_raw.len())
1801 };
1802 let protos_hwc =
1803 ndarray::Array3::from_shape_vec((160, 160, 32), protos_raw.to_vec()).unwrap();
1804 let quant_protos = Quantization::new(0.02491161972284317, -117);
1805 let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
1806
1807 let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
1809 let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
1810 let mut ref_masks: Vec<_> = Vec::with_capacity(10);
1811 decode_yolo_segdet_float(
1812 seg.view(),
1813 protos_f32_hwc.view(),
1814 score_threshold,
1815 iou_threshold,
1816 Some(configs::Nms::ClassAgnostic),
1817 &mut ref_boxes,
1818 &mut ref_masks,
1819 )
1820 .unwrap();
1821 assert_eq!(ref_boxes.len(), 2);
1822
1823 let protos_f32_chw = protos_f32_hwc.permuted_axes([2, 0, 1]); let protos_nchw = protos_f32_chw.insert_axis(ndarray::Axis(0)); let seg_3d = seg.insert_axis(ndarray::Axis(0)); let decoder = DecoderBuilder::default()
1833 .with_config_yolo_segdet(
1834 configs::Detection {
1835 decoder: configs::DecoderType::Ultralytics,
1836 quantization: None,
1837 shape: vec![1, 116, 8400],
1838 dshape: vec![],
1839 normalized: Some(true),
1840 anchors: None,
1841 },
1842 configs::Protos {
1843 decoder: configs::DecoderType::Ultralytics,
1844 quantization: None,
1845 shape: vec![1, 32, 160, 160],
1846 dshape: vec![], },
1848 None, )
1850 .with_score_threshold(score_threshold)
1851 .with_iou_threshold(iou_threshold)
1852 .build()
1853 .unwrap();
1854
1855 let mut cfg_boxes: Vec<_> = Vec::with_capacity(10);
1856 let mut cfg_masks: Vec<_> = Vec::with_capacity(10);
1857 decoder
1858 .decode_float(
1859 &[seg_3d.view().into_dyn(), protos_nchw.view().into_dyn()],
1860 &mut cfg_boxes,
1861 &mut cfg_masks,
1862 )
1863 .unwrap();
1864
1865 assert_eq!(
1867 cfg_boxes.len(),
1868 ref_boxes.len(),
1869 "config path produced {} boxes, reference produced {}",
1870 cfg_boxes.len(),
1871 ref_boxes.len()
1872 );
1873
1874 for (i, (cb, rb)) in cfg_boxes.iter().zip(&ref_boxes).enumerate() {
1876 assert!(
1877 cb.equal_within_delta(rb, 0.01),
1878 "box {i} mismatch: config={cb:?}, reference={rb:?}"
1879 );
1880 }
1881
1882 for (i, (cm, rm)) in cfg_masks.iter().zip(&ref_masks).enumerate() {
1884 let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
1885 let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
1886 assert_eq!(
1887 cm_arr, rm_arr,
1888 "mask {i} pixel mismatch between config-driven and reference paths"
1889 );
1890 }
1891 }
1892
1893 #[test]
1894 fn test_decoder_masks_i8() {
1895 let score_threshold = 0.45;
1896 let iou_threshold = 0.45;
1897 let boxes = include_bytes!(concat!(
1898 env!("CARGO_MANIFEST_DIR"),
1899 "/../../testdata/yolov8_boxes_116x8400.bin"
1900 ));
1901 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1902 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
1903 let quant_boxes = (0.021287761628627777, 31).into();
1904
1905 let protos = include_bytes!(concat!(
1906 env!("CARGO_MANIFEST_DIR"),
1907 "/../../testdata/yolov8_protos_160x160x32.bin"
1908 ));
1909 let protos =
1910 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1911 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
1912 let quant_protos = (0.02491161972284317, -117).into();
1913 let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1914 let mut output_masks: Vec<_> = Vec::with_capacity(500);
1915
1916 let decoder = DecoderBuilder::default()
1917 .with_config_yolo_segdet(
1918 configs::Detection {
1919 decoder: configs::DecoderType::Ultralytics,
1920 quantization: Some(quant_boxes),
1921 shape: vec![1, 116, 8400],
1922 anchors: None,
1923 dshape: vec![
1924 (DimName::Batch, 1),
1925 (DimName::NumFeatures, 116),
1926 (DimName::NumBoxes, 8400),
1927 ],
1928 normalized: Some(true),
1929 },
1930 Protos {
1931 decoder: configs::DecoderType::Ultralytics,
1932 quantization: Some(quant_protos),
1933 shape: vec![1, 160, 160, 32],
1934 dshape: vec![
1935 (DimName::Batch, 1),
1936 (DimName::Height, 160),
1937 (DimName::Width, 160),
1938 (DimName::NumProtos, 32),
1939 ],
1940 },
1941 Some(DecoderVersion::Yolo11),
1942 )
1943 .with_score_threshold(score_threshold)
1944 .with_iou_threshold(iou_threshold)
1945 .build()
1946 .unwrap();
1947
1948 let quant_boxes = quant_boxes.into();
1949 let quant_protos = quant_protos.into();
1950
1951 decode_yolo_segdet_quant(
1952 (boxes.slice(s![0, .., ..]), quant_boxes),
1953 (protos.slice(s![0, .., .., ..]), quant_protos),
1954 score_threshold,
1955 iou_threshold,
1956 Some(configs::Nms::ClassAgnostic),
1957 &mut output_boxes,
1958 &mut output_masks,
1959 )
1960 .unwrap();
1961
1962 let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
1963 let mut output_masks1: Vec<_> = Vec::with_capacity(500);
1964
1965 decoder
1966 .decode_quantized(
1967 &[boxes.view().into(), protos.view().into()],
1968 &mut output_boxes1,
1969 &mut output_masks1,
1970 )
1971 .unwrap();
1972
1973 let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1974 let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1975
1976 let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
1977 let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
1978 decode_yolo_segdet_float(
1979 seg.slice(s![0, .., ..]),
1980 protos.slice(s![0, .., .., ..]),
1981 score_threshold,
1982 iou_threshold,
1983 Some(configs::Nms::ClassAgnostic),
1984 &mut output_boxes_f32,
1985 &mut output_masks_f32,
1986 )
1987 .unwrap();
1988
1989 let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
1990 let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);
1991
1992 decoder
1993 .decode_float(
1994 &[seg.view().into_dyn(), protos.view().into_dyn()],
1995 &mut output_boxes1_f32,
1996 &mut output_masks1_f32,
1997 )
1998 .unwrap();
1999
2000 compare_outputs(
2001 (&output_boxes, &output_boxes1),
2002 (&output_masks, &output_masks1),
2003 );
2004
2005 compare_outputs(
2006 (&output_boxes, &output_boxes_f32),
2007 (&output_masks, &output_masks_f32),
2008 );
2009
2010 compare_outputs(
2011 (&output_boxes_f32, &output_boxes1_f32),
2012 (&output_masks_f32, &output_masks1_f32),
2013 );
2014 }
2015
2016 #[test]
2017 fn test_decoder_yolo_split() {
2018 let score_threshold = 0.45;
2019 let iou_threshold = 0.45;
2020 let boxes = include_bytes!(concat!(
2021 env!("CARGO_MANIFEST_DIR"),
2022 "/../../testdata/yolov8_boxes_116x8400.bin"
2023 ));
2024 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
2025 let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
2026 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2027
2028 let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
2029
2030 let decoder = DecoderBuilder::default()
2031 .with_config_yolo_split_det(
2032 configs::Boxes {
2033 decoder: configs::DecoderType::Ultralytics,
2034 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2035 shape: vec![1, 4, 8400],
2036 dshape: vec![
2037 (DimName::Batch, 1),
2038 (DimName::BoxCoords, 4),
2039 (DimName::NumBoxes, 8400),
2040 ],
2041 normalized: Some(true),
2042 },
2043 configs::Scores {
2044 decoder: configs::DecoderType::Ultralytics,
2045 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2046 shape: vec![1, 80, 8400],
2047 dshape: vec![
2048 (DimName::Batch, 1),
2049 (DimName::NumClasses, 80),
2050 (DimName::NumBoxes, 8400),
2051 ],
2052 },
2053 )
2054 .with_score_threshold(score_threshold)
2055 .with_iou_threshold(iou_threshold)
2056 .build()
2057 .unwrap();
2058
2059 let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2060 let mut output_masks: Vec<_> = Vec::with_capacity(500);
2061
2062 decoder
2063 .decode_quantized(
2064 &[
2065 boxes.slice(s![.., ..4, ..]).into(),
2066 boxes.slice(s![.., 4..84, ..]).into(),
2067 ],
2068 &mut output_boxes,
2069 &mut output_masks,
2070 )
2071 .unwrap();
2072
2073 let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2074 let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2075 decode_yolo_det_float(
2076 seg.slice(s![0, ..84, ..]),
2077 score_threshold,
2078 iou_threshold,
2079 Some(configs::Nms::ClassAgnostic),
2080 &mut output_boxes_f32,
2081 );
2082
2083 let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2084 let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2085
2086 decoder
2087 .decode_float(
2088 &[
2089 seg.slice(s![.., ..4, ..]).into_dyn(),
2090 seg.slice(s![.., 4..84, ..]).into_dyn(),
2091 ],
2092 &mut output_boxes1,
2093 &mut output_masks1,
2094 )
2095 .unwrap();
2096 compare_outputs((&output_boxes, &output_boxes_f32), (&output_masks, &[]));
2097 compare_outputs((&output_boxes_f32, &output_boxes1), (&[], &output_masks1));
2098 }
2099
2100 #[test]
2101 fn test_decoder_masks_config_mixed() {
2102 let score_threshold = 0.45;
2103 let iou_threshold = 0.45;
2104 let boxes = include_bytes!(concat!(
2105 env!("CARGO_MANIFEST_DIR"),
2106 "/../../testdata/yolov8_boxes_116x8400.bin"
2107 ));
2108 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
2109 let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
2110 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2111
2112 let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
2113
2114 let protos = include_bytes!(concat!(
2115 env!("CARGO_MANIFEST_DIR"),
2116 "/../../testdata/yolov8_protos_160x160x32.bin"
2117 ));
2118 let protos =
2119 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
2120 let protos: Vec<_> = protos.to_vec();
2121 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
2122 let quant_protos = Quantization::new(0.02491161972284317, -117);
2123
2124 let decoder = DecoderBuilder::default()
2125 .with_config_yolo_split_segdet(
2126 configs::Boxes {
2127 decoder: configs::DecoderType::Ultralytics,
2128 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2129 shape: vec![1, 4, 8400],
2130 dshape: vec![
2131 (DimName::Batch, 1),
2132 (DimName::BoxCoords, 4),
2133 (DimName::NumBoxes, 8400),
2134 ],
2135 normalized: Some(true),
2136 },
2137 configs::Scores {
2138 decoder: configs::DecoderType::Ultralytics,
2139 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2140 shape: vec![1, 80, 8400],
2141 dshape: vec![
2142 (DimName::Batch, 1),
2143 (DimName::NumClasses, 80),
2144 (DimName::NumBoxes, 8400),
2145 ],
2146 },
2147 configs::MaskCoefficients {
2148 decoder: configs::DecoderType::Ultralytics,
2149 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2150 shape: vec![1, 32, 8400],
2151 dshape: vec![
2152 (DimName::Batch, 1),
2153 (DimName::NumProtos, 32),
2154 (DimName::NumBoxes, 8400),
2155 ],
2156 },
2157 configs::Protos {
2158 decoder: configs::DecoderType::Ultralytics,
2159 quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
2160 shape: vec![1, 160, 160, 32],
2161 dshape: vec![
2162 (DimName::Batch, 1),
2163 (DimName::Height, 160),
2164 (DimName::Width, 160),
2165 (DimName::NumProtos, 32),
2166 ],
2167 },
2168 )
2169 .with_score_threshold(score_threshold)
2170 .with_iou_threshold(iou_threshold)
2171 .build()
2172 .unwrap();
2173
2174 let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2175 let mut output_masks: Vec<_> = Vec::with_capacity(500);
2176
2177 decoder
2178 .decode_quantized(
2179 &[
2180 boxes.slice(s![.., ..4, ..]).into(),
2181 boxes.slice(s![.., 4..84, ..]).into(),
2182 boxes.slice(s![.., 84.., ..]).into(),
2183 protos.view().into(),
2184 ],
2185 &mut output_boxes,
2186 &mut output_masks,
2187 )
2188 .unwrap();
2189
2190 let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
2191 let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2192 let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2193 let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
2194 decode_yolo_segdet_float(
2195 seg.slice(s![0, .., ..]),
2196 protos.slice(s![0, .., .., ..]),
2197 score_threshold,
2198 iou_threshold,
2199 Some(configs::Nms::ClassAgnostic),
2200 &mut output_boxes_f32,
2201 &mut output_masks_f32,
2202 )
2203 .unwrap();
2204
2205 let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2206 let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2207
2208 decoder
2209 .decode_float(
2210 &[
2211 seg.slice(s![.., ..4, ..]).into_dyn(),
2212 seg.slice(s![.., 4..84, ..]).into_dyn(),
2213 seg.slice(s![.., 84.., ..]).into_dyn(),
2214 protos.view().into_dyn(),
2215 ],
2216 &mut output_boxes1,
2217 &mut output_masks1,
2218 )
2219 .unwrap();
2220 compare_outputs(
2221 (&output_boxes, &output_boxes_f32),
2222 (&output_masks, &output_masks_f32),
2223 );
2224 compare_outputs(
2225 (&output_boxes_f32, &output_boxes1),
2226 (&output_masks_f32, &output_masks1),
2227 );
2228 }
2229
2230 #[test]
2231 fn test_decoder_masks_config_i32() {
2232 let score_threshold = 0.45;
2233 let iou_threshold = 0.45;
2234 let boxes = include_bytes!(concat!(
2235 env!("CARGO_MANIFEST_DIR"),
2236 "/../../testdata/yolov8_boxes_116x8400.bin"
2237 ));
2238 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
2239 let scale = 1 << 23;
2240 let boxes: Vec<_> = boxes.iter().map(|x| *x as i32 * scale).collect();
2241 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2242
2243 let quant_boxes = Quantization::new(0.021287761628627777 / scale as f32, 31 * scale);
2244
2245 let protos = include_bytes!(concat!(
2246 env!("CARGO_MANIFEST_DIR"),
2247 "/../../testdata/yolov8_protos_160x160x32.bin"
2248 ));
2249 let protos =
2250 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
2251 let protos: Vec<_> = protos.iter().map(|x| *x as i32 * scale).collect();
2252 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
2253 let quant_protos = Quantization::new(0.02491161972284317 / scale as f32, -117 * scale);
2254
2255 let decoder = DecoderBuilder::default()
2256 .with_config_yolo_split_segdet(
2257 configs::Boxes {
2258 decoder: configs::DecoderType::Ultralytics,
2259 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2260 shape: vec![1, 4, 8400],
2261 dshape: vec![
2262 (DimName::Batch, 1),
2263 (DimName::BoxCoords, 4),
2264 (DimName::NumBoxes, 8400),
2265 ],
2266 normalized: Some(true),
2267 },
2268 configs::Scores {
2269 decoder: configs::DecoderType::Ultralytics,
2270 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2271 shape: vec![1, 80, 8400],
2272 dshape: vec![
2273 (DimName::Batch, 1),
2274 (DimName::NumClasses, 80),
2275 (DimName::NumBoxes, 8400),
2276 ],
2277 },
2278 configs::MaskCoefficients {
2279 decoder: configs::DecoderType::Ultralytics,
2280 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2281 shape: vec![1, 32, 8400],
2282 dshape: vec![
2283 (DimName::Batch, 1),
2284 (DimName::NumProtos, 32),
2285 (DimName::NumBoxes, 8400),
2286 ],
2287 },
2288 configs::Protos {
2289 decoder: configs::DecoderType::Ultralytics,
2290 quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
2291 shape: vec![1, 160, 160, 32],
2292 dshape: vec![
2293 (DimName::Batch, 1),
2294 (DimName::Height, 160),
2295 (DimName::Width, 160),
2296 (DimName::NumProtos, 32),
2297 ],
2298 },
2299 )
2300 .with_score_threshold(score_threshold)
2301 .with_iou_threshold(iou_threshold)
2302 .build()
2303 .unwrap();
2304
2305 let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2306 let mut output_masks: Vec<_> = Vec::with_capacity(500);
2307
2308 decoder
2309 .decode_quantized(
2310 &[
2311 boxes.slice(s![.., ..4, ..]).into(),
2312 boxes.slice(s![.., 4..84, ..]).into(),
2313 boxes.slice(s![.., 84.., ..]).into(),
2314 protos.view().into(),
2315 ],
2316 &mut output_boxes,
2317 &mut output_masks,
2318 )
2319 .unwrap();
2320
2321 let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
2322 let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2323 let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2324 let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
2325 decode_yolo_segdet_float(
2326 seg.slice(s![0, .., ..]),
2327 protos.slice(s![0, .., .., ..]),
2328 score_threshold,
2329 iou_threshold,
2330 Some(configs::Nms::ClassAgnostic),
2331 &mut output_boxes_f32,
2332 &mut output_masks_f32,
2333 )
2334 .unwrap();
2335
2336 assert_eq!(output_boxes.len(), output_boxes_f32.len());
2337 assert_eq!(output_masks.len(), output_masks_f32.len());
2338
2339 compare_outputs(
2340 (&output_boxes, &output_boxes_f32),
2341 (&output_masks, &output_masks_f32),
2342 );
2343 }
2344
    /// Runs a YOLO detection workload and a ModelPack split
    /// detection + segmentation workload concurrently on eight threads (four
    /// of each), each thread decoding 100 times with its own decoder and
    /// checking the result on every iteration.
    #[test]
    fn test_context_switch() {
        // Workload 1: build a YOLO detection decoder and repeatedly decode the
        // same quantized tensor. The closure captures nothing from the
        // enclosing function, which is what allows it to be passed to
        // `thread::spawn` several times below.
        let yolo_det = || {
            let score_threshold = 0.25;
            let iou_threshold = 0.7;
            let out = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/yolov8s_80_classes.bin"
            ));
            // SAFETY: `u8` and `i8` have the same size and alignment, and the
            // pointer/length pair comes from a live `'static` byte array.
            let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
            let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
            let quant = (0.0040811873, -123).into();

            // Here `normalized` and the decoder version are both left as
            // `None`.
            let decoder = DecoderBuilder::default()
                .with_config_yolo_det(
                    configs::Detection {
                        decoder: DecoderType::Ultralytics,
                        shape: vec![1, 84, 8400],
                        anchors: None,
                        quantization: Some(quant),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::NumFeatures, 84),
                            (DimName::NumBoxes, 8400),
                        ],
                        normalized: None,
                    },
                    None,
                )
                .with_score_threshold(score_threshold)
                .with_iou_threshold(iou_threshold)
                .build()
                .unwrap();

            // The output vectors are created once and reused by every
            // iteration.
            let mut output_boxes: Vec<_> = Vec::with_capacity(50);
            let mut output_masks: Vec<_> = Vec::with_capacity(50);

            for _ in 0..100 {
                decoder
                    .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
                    .unwrap();

                // The first two boxes must match the known values on every
                // iteration.
                assert!(output_boxes[0].equal_within_delta(
                    &DetectBox {
                        bbox: BoundingBox {
                            xmin: 0.5285137,
                            ymin: 0.05305544,
                            xmax: 0.87541467,
                            ymax: 0.9998909,
                        },
                        score: 0.5591227,
                        label: 0
                    },
                    1e-6
                ));

                assert!(output_boxes[1].equal_within_delta(
                    &DetectBox {
                        bbox: BoundingBox {
                            xmin: 0.130598,
                            ymin: 0.43260583,
                            xmax: 0.35098213,
                            ymax: 0.9958097,
                        },
                        score: 0.33057618,
                        label: 75
                    },
                    1e-6
                ));
                // A detection-only model must not produce any masks.
                assert!(output_masks.is_empty());
            }
        };

        // Workload 2: build a ModelPack split detection + segmentation decoder
        // and repeatedly decode the same three tensors. Like the first
        // closure, this one captures nothing from the enclosing function.
        let modelpack_det_split = || {
            let score_threshold = 0.8;
            let iou_threshold = 0.5;

            // The ModelPack fixtures are used as raw `u8` bytes, without any
            // reinterpretation.
            let seg = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/modelpack_seg_2x160x160.bin"
            ));
            let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

            let detect0 = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/modelpack_split_9x15x18.bin"
            ));
            let detect0 =
                ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

            let detect1 = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/modelpack_split_17x30x18.bin"
            ));
            let detect1 =
                ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

            // Expected mask: the segmentation tensor with its batch axis
            // dropped and its axes reordered from (2, 160, 160) to
            // (160, 160, 2), covering the whole image.
            let mut mask = seg.slice(s![0, .., .., ..]);
            mask.swap_axes(0, 1);
            mask.swap_axes(1, 2);
            let mask = [Segmentation {
                xmin: 0.0,
                ymin: 0.0,
                xmax: 1.0,
                ymax: 1.0,
                segmentation: mask.into_owned(),
            }];
            let correct_boxes = [DetectBox {
                bbox: BoundingBox {
                    xmin: 0.43171933,
                    ymin: 0.68243736,
                    xmax: 0.5626645,
                    ymax: 0.808863,
                },
                score: 0.99240804,
                label: 0,
            }];

            let quant0 = (0.08547406643629074, 174).into();
            let quant1 = (0.09929127991199493, 183).into();
            let quant_seg = (1.0 / 255.0, 0).into();

            let anchors0 = vec![
                [0.36666667461395264, 0.31481480598449707],
                [0.38749998807907104, 0.4740740656852722],
                [0.5333333611488342, 0.644444465637207],
            ];
            let anchors1 = vec![
                [0.13750000298023224, 0.2074074000120163],
                [0.2541666626930237, 0.21481481194496155],
                [0.23125000298023224, 0.35185185074806213],
            ];

            // NOTE(review): the detection configs are listed 17x30 first and
            // 9x15 second, while the inputs below are passed as detect0 (9x15)
            // then detect1 (17x30). Presumably the decoder matches inputs to
            // configs by shape rather than by position; confirm.
            let decoder = DecoderBuilder::default()
                .with_config_modelpack_segdet_split(
                    vec![
                        configs::Detection {
                            decoder: DecoderType::ModelPack,
                            shape: vec![1, 17, 30, 18],
                            anchors: Some(anchors1),
                            quantization: Some(quant1),
                            dshape: vec![
                                (DimName::Batch, 1),
                                (DimName::Height, 17),
                                (DimName::Width, 30),
                                (DimName::NumAnchorsXFeatures, 18),
                            ],
                            normalized: None,
                        },
                        configs::Detection {
                            decoder: DecoderType::ModelPack,
                            shape: vec![1, 9, 15, 18],
                            anchors: Some(anchors0),
                            quantization: Some(quant0),
                            dshape: vec![
                                (DimName::Batch, 1),
                                (DimName::Height, 9),
                                (DimName::Width, 15),
                                (DimName::NumAnchorsXFeatures, 18),
                            ],
                            normalized: None,
                        },
                    ],
                    configs::Segmentation {
                        decoder: DecoderType::ModelPack,
                        quantization: Some(quant_seg),
                        shape: vec![1, 2, 160, 160],
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::NumClasses, 2),
                            (DimName::Height, 160),
                            (DimName::Width, 160),
                        ],
                    },
                )
                .with_score_threshold(score_threshold)
                .with_iou_threshold(iou_threshold)
                .build()
                .unwrap();
            // The output vectors are created once and reused by every
            // iteration.
            let mut output_boxes: Vec<_> = Vec::with_capacity(10);
            let mut output_masks: Vec<_> = Vec::with_capacity(10);

            for _ in 0..100 {
                decoder
                    .decode_quantized(
                        &[
                            detect0.view().into(),
                            detect1.view().into(),
                            seg.view().into(),
                        ],
                        &mut output_boxes,
                        &mut output_masks,
                    )
                    .unwrap();

                compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
            }
        };

        // Start the two workloads alternately, four threads each, so both
        // decoder kinds are running at the same time.
        let handles = vec![
            std::thread::spawn(yolo_det),
            std::thread::spawn(modelpack_det_split),
            std::thread::spawn(yolo_det),
            std::thread::spawn(modelpack_det_split),
            std::thread::spawn(yolo_det),
            std::thread::spawn(modelpack_det_split),
            std::thread::spawn(yolo_det),
            std::thread::spawn(modelpack_det_split),
        ];
        // `join` returns `Err` if the thread panicked, so `unwrap` turns any
        // failed assertion inside a worker into a failure of this test.
        for handle in handles {
            handle.join().unwrap();
        }
    }
2559
2560 #[test]
2561 fn test_ndarray_to_xyxy_float() {
2562 let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2563 let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
2564 assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
2565
2566 let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2567 let xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(arr.view());
2568 assert_eq!(xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
2569 }
2570
2571 #[test]
2572 fn test_class_aware_nms_float() {
2573 use crate::float::nms_class_aware_float;
2574
2575 let boxes = vec![
2577 DetectBox {
2578 bbox: BoundingBox {
2579 xmin: 0.0,
2580 ymin: 0.0,
2581 xmax: 0.5,
2582 ymax: 0.5,
2583 },
2584 score: 0.9,
2585 label: 0, },
2587 DetectBox {
2588 bbox: BoundingBox {
2589 xmin: 0.1,
2590 ymin: 0.1,
2591 xmax: 0.6,
2592 ymax: 0.6,
2593 },
2594 score: 0.8,
2595 label: 1, },
2597 ];
2598
2599 let result = nms_class_aware_float(0.3, boxes.clone());
2602 assert_eq!(
2603 result.len(),
2604 2,
2605 "Class-aware NMS should keep both boxes with different classes"
2606 );
2607
2608 let same_class_boxes = vec![
2610 DetectBox {
2611 bbox: BoundingBox {
2612 xmin: 0.0,
2613 ymin: 0.0,
2614 xmax: 0.5,
2615 ymax: 0.5,
2616 },
2617 score: 0.9,
2618 label: 0,
2619 },
2620 DetectBox {
2621 bbox: BoundingBox {
2622 xmin: 0.1,
2623 ymin: 0.1,
2624 xmax: 0.6,
2625 ymax: 0.6,
2626 },
2627 score: 0.8,
2628 label: 0, },
2630 ];
2631
2632 let result = nms_class_aware_float(0.3, same_class_boxes);
2633 assert_eq!(
2634 result.len(),
2635 1,
2636 "Class-aware NMS should suppress overlapping box with same class"
2637 );
2638 assert_eq!(result[0].label, 0);
2639 assert!((result[0].score - 0.9).abs() < 1e-6);
2640 }
2641
2642 #[test]
2643 fn test_class_agnostic_vs_aware_nms() {
2644 use crate::float::{nms_class_aware_float, nms_float};
2645
2646 let boxes = vec![
2648 DetectBox {
2649 bbox: BoundingBox {
2650 xmin: 0.0,
2651 ymin: 0.0,
2652 xmax: 0.5,
2653 ymax: 0.5,
2654 },
2655 score: 0.9,
2656 label: 0,
2657 },
2658 DetectBox {
2659 bbox: BoundingBox {
2660 xmin: 0.1,
2661 ymin: 0.1,
2662 xmax: 0.6,
2663 ymax: 0.6,
2664 },
2665 score: 0.8,
2666 label: 1,
2667 },
2668 ];
2669
2670 let agnostic_result = nms_float(0.3, boxes.clone());
2672 assert_eq!(
2673 agnostic_result.len(),
2674 1,
2675 "Class-agnostic NMS should suppress overlapping boxes"
2676 );
2677
2678 let aware_result = nms_class_aware_float(0.3, boxes);
2680 assert_eq!(
2681 aware_result.len(),
2682 2,
2683 "Class-aware NMS should keep boxes with different classes"
2684 );
2685 }
2686
2687 #[test]
2688 fn test_class_aware_nms_int() {
2689 use crate::byte::nms_class_aware_int;
2690
2691 let boxes = vec![
2693 DetectBoxQuantized {
2694 bbox: BoundingBox {
2695 xmin: 0.0,
2696 ymin: 0.0,
2697 xmax: 0.5,
2698 ymax: 0.5,
2699 },
2700 score: 200_u8,
2701 label: 0,
2702 },
2703 DetectBoxQuantized {
2704 bbox: BoundingBox {
2705 xmin: 0.1,
2706 ymin: 0.1,
2707 xmax: 0.6,
2708 ymax: 0.6,
2709 },
2710 score: 180_u8,
2711 label: 1, },
2713 ];
2714
2715 let result = nms_class_aware_int(0.5, boxes);
2717 assert_eq!(
2718 result.len(),
2719 2,
2720 "Class-aware NMS (int) should keep boxes with different classes"
2721 );
2722 }
2723
2724 #[test]
2725 fn test_nms_enum_default() {
2726 let default_nms: configs::Nms = Default::default();
2728 assert_eq!(default_nms, configs::Nms::ClassAgnostic);
2729 }
2730
2731 #[test]
2732 fn test_decoder_nms_mode() {
2733 let decoder = DecoderBuilder::default()
2735 .with_config_yolo_det(
2736 configs::Detection {
2737 anchors: None,
2738 decoder: DecoderType::Ultralytics,
2739 quantization: None,
2740 shape: vec![1, 84, 8400],
2741 dshape: Vec::new(),
2742 normalized: Some(true),
2743 },
2744 None,
2745 )
2746 .with_nms(Some(configs::Nms::ClassAware))
2747 .build()
2748 .unwrap();
2749
2750 assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
2751 }
2752
2753 #[test]
2754 fn test_decoder_nms_bypass() {
2755 let decoder = DecoderBuilder::default()
2757 .with_config_yolo_det(
2758 configs::Detection {
2759 anchors: None,
2760 decoder: DecoderType::Ultralytics,
2761 quantization: None,
2762 shape: vec![1, 84, 8400],
2763 dshape: Vec::new(),
2764 normalized: Some(true),
2765 },
2766 None,
2767 )
2768 .with_nms(None)
2769 .build()
2770 .unwrap();
2771
2772 assert_eq!(decoder.nms, None);
2773 }
2774
2775 #[test]
2776 fn test_decoder_normalized_boxes_true() {
2777 let decoder = DecoderBuilder::default()
2779 .with_config_yolo_det(
2780 configs::Detection {
2781 anchors: None,
2782 decoder: DecoderType::Ultralytics,
2783 quantization: None,
2784 shape: vec![1, 84, 8400],
2785 dshape: Vec::new(),
2786 normalized: Some(true),
2787 },
2788 None,
2789 )
2790 .build()
2791 .unwrap();
2792
2793 assert_eq!(decoder.normalized_boxes(), Some(true));
2794 }
2795
2796 #[test]
2797 fn test_decoder_normalized_boxes_false() {
2798 let decoder = DecoderBuilder::default()
2801 .with_config_yolo_det(
2802 configs::Detection {
2803 anchors: None,
2804 decoder: DecoderType::Ultralytics,
2805 quantization: None,
2806 shape: vec![1, 84, 8400],
2807 dshape: Vec::new(),
2808 normalized: Some(false),
2809 },
2810 None,
2811 )
2812 .build()
2813 .unwrap();
2814
2815 assert_eq!(decoder.normalized_boxes(), Some(false));
2816 }
2817
2818 #[test]
2819 fn test_decoder_normalized_boxes_unknown() {
2820 let decoder = DecoderBuilder::default()
2822 .with_config_yolo_det(
2823 configs::Detection {
2824 anchors: None,
2825 decoder: DecoderType::Ultralytics,
2826 quantization: None,
2827 shape: vec![1, 84, 8400],
2828 dshape: Vec::new(),
2829 normalized: None,
2830 },
2831 Some(DecoderVersion::Yolo11),
2832 )
2833 .build()
2834 .unwrap();
2835
2836 assert_eq!(decoder.normalized_boxes(), None);
2837 }
2838}
2839
2840#[cfg(feature = "tracker")]
2841#[cfg(test)]
2842#[cfg_attr(coverage_nightly, coverage(off))]
2843mod decoder_tracked_tests {
2844
2845 use edgefirst_tracker::{ByteTrackBuilder, Tracker};
2846 use ndarray::{array, s, Array, Array2, Array3, Array4, ArrayView, Axis, Dimension};
2847 use num_traits::{AsPrimitive, Float, PrimInt};
2848 use rand::{RngExt, SeedableRng};
2849 use rand_distr::StandardNormal;
2850
2851 use crate::{
2852 configs::{self, DimName},
2853 dequantize_ndarray, BoundingBox, DecoderBuilder, DetectBox, Quantization,
2854 };
2855
2856 pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
2857 input: ArrayView<F, D>,
2858 quant: Quantization,
2859 ) -> Array<T, D>
2860 where
2861 i32: num_traits::AsPrimitive<F>,
2862 f32: num_traits::AsPrimitive<F>,
2863 {
2864 let zero_point = quant.zero_point.as_();
2865 let div_scale = F::one() / quant.scale.as_();
2866 if zero_point != F::zero() {
2867 input.mapv(|d| (d * div_scale + zero_point).round().as_())
2868 } else {
2869 input.mapv(|d| (d * div_scale).round().as_())
2870 }
2871 }
2872
2873 #[test]
2874 fn test_decoder_tracked_random_jitter() {
2875 use crate::configs::{DecoderType, Nms};
2876 use crate::DecoderBuilder;
2877
2878 let score_threshold = 0.25;
2879 let iou_threshold = 0.1;
2880 let out = include_bytes!(concat!(
2881 env!("CARGO_MANIFEST_DIR"),
2882 "/../../testdata/yolov8s_80_classes.bin"
2883 ));
2884 let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
2885 let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
2886 let quant = (0.0040811873, -123).into();
2887
2888 let decoder = DecoderBuilder::default()
2889 .with_config_yolo_det(
2890 crate::configs::Detection {
2891 decoder: DecoderType::Ultralytics,
2892 shape: vec![1, 84, 8400],
2893 anchors: None,
2894 quantization: Some(quant),
2895 dshape: vec![
2896 (crate::configs::DimName::Batch, 1),
2897 (crate::configs::DimName::NumFeatures, 84),
2898 (crate::configs::DimName::NumBoxes, 8400),
2899 ],
2900 normalized: Some(true),
2901 },
2902 None,
2903 )
2904 .with_score_threshold(score_threshold)
2905 .with_iou_threshold(iou_threshold)
2906 .with_nms(Some(Nms::ClassAgnostic))
2907 .build()
2908 .unwrap();
2909 let mut rng = rand::rngs::StdRng::seed_from_u64(0xAB_BEEF); let expected_boxes = [
2912 crate::DetectBox {
2913 bbox: crate::BoundingBox {
2914 xmin: 0.5285137,
2915 ymin: 0.05305544,
2916 xmax: 0.87541467,
2917 ymax: 0.9998909,
2918 },
2919 score: 0.5591227,
2920 label: 0,
2921 },
2922 crate::DetectBox {
2923 bbox: crate::BoundingBox {
2924 xmin: 0.130598,
2925 ymin: 0.43260583,
2926 xmax: 0.35098213,
2927 ymax: 0.9958097,
2928 },
2929 score: 0.33057618,
2930 label: 75,
2931 },
2932 ];
2933
2934 let mut tracker = ByteTrackBuilder::new()
2935 .track_update(0.1)
2936 .track_high_conf(0.3)
2937 .build();
2938
2939 let mut output_boxes = Vec::with_capacity(50);
2940 let mut output_masks = Vec::with_capacity(50);
2941 let mut output_tracks = Vec::with_capacity(50);
2942
2943 decoder
2944 .decode_tracked_quantized(
2945 &mut tracker,
2946 0,
2947 &[out.view().into()],
2948 &mut output_boxes,
2949 &mut output_masks,
2950 &mut output_tracks,
2951 )
2952 .unwrap();
2953
2954 assert_eq!(output_boxes.len(), 2);
2955 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
2956 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
2957
2958 let mut last_boxes = output_boxes.clone();
2959
2960 for i in 1..=100 {
2961 let mut out = out.clone();
2962 let mut x_values = out.slice_mut(s![0, 0, ..]);
2964 for x in x_values.iter_mut() {
2965 let r: f32 = rng.sample(StandardNormal);
2966 let r = r.clamp(-2.0, 2.0) / 2.0;
2967 *x = x.saturating_add((r * 1e-2 / quant.0) as i8);
2968 }
2969
2970 let mut y_values = out.slice_mut(s![0, 1, ..]);
2971 for y in y_values.iter_mut() {
2972 let r: f32 = rng.sample(StandardNormal);
2973 let r = r.clamp(-2.0, 2.0) / 2.0;
2974 *y = y.saturating_add((r * 1e-2 / quant.0) as i8);
2975 }
2976
2977 decoder
2978 .decode_tracked_quantized(
2979 &mut tracker,
2980 100_000_000 * i / 3, &[out.view().into()],
2982 &mut output_boxes,
2983 &mut output_masks,
2984 &mut output_tracks,
2985 )
2986 .unwrap();
2987
2988 assert_eq!(output_boxes.len(), 2);
2989 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 5e-3));
2990 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 5e-3));
2991
2992 assert!(output_boxes[0].equal_within_delta(&last_boxes[0], 1e-3));
2993 assert!(output_boxes[1].equal_within_delta(&last_boxes[1], 1e-3));
2994 last_boxes = output_boxes.clone();
2995 }
2996 }
2997
2998 fn real_data_expected_boxes() -> [DetectBox; 2] {
3001 [
3002 DetectBox {
3003 bbox: BoundingBox {
3004 xmin: 0.08515105,
3005 ymin: 0.7131401,
3006 xmax: 0.29802868,
3007 ymax: 0.8195788,
3008 },
3009 score: 0.91537374,
3010 label: 23,
3011 },
3012 DetectBox {
3013 bbox: BoundingBox {
3014 xmin: 0.59605736,
3015 ymin: 0.25545314,
3016 xmax: 0.93666154,
3017 ymax: 0.72378385,
3018 },
3019 score: 0.91537374,
3020 label: 23,
3021 },
3022 ]
3023 }
3024
3025 fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
3026 [DetectBox {
3027 bbox: BoundingBox {
3028 xmin: 0.12549022,
3029 ymin: 0.12549022,
3030 xmax: 0.23529413,
3031 ymax: 0.23529413,
3032 },
3033 score: 0.98823535,
3034 label: 2,
3035 }]
3036 }
3037
3038 fn e2e_expected_boxes_float() -> [DetectBox; 1] {
3039 [DetectBox {
3040 bbox: BoundingBox {
3041 xmin: 0.1234,
3042 ymin: 0.1234,
3043 xmax: 0.2345,
3044 ymax: 0.2345,
3045 },
3046 score: 0.9876,
3047 label: 2,
3048 }]
3049 }
3050
3051 fn build_split_decoder(
3052 score_threshold: f32,
3053 iou_threshold: f32,
3054 quant_boxes: (f32, i32),
3055 quant_protos: (f32, i32),
3056 ) -> crate::Decoder {
3057 DecoderBuilder::default()
3058 .with_config_yolo_split_segdet(
3059 configs::Boxes {
3060 decoder: configs::DecoderType::Ultralytics,
3061 quantization: Some(quant_boxes.into()),
3062 shape: vec![1, 4, 8400],
3063 dshape: vec![
3064 (DimName::Batch, 1),
3065 (DimName::BoxCoords, 4),
3066 (DimName::NumBoxes, 8400),
3067 ],
3068 normalized: Some(true),
3069 },
3070 configs::Scores {
3071 decoder: configs::DecoderType::Ultralytics,
3072 quantization: Some(quant_boxes.into()),
3073 shape: vec![1, 80, 8400],
3074 dshape: vec![
3075 (DimName::Batch, 1),
3076 (DimName::NumClasses, 80),
3077 (DimName::NumBoxes, 8400),
3078 ],
3079 },
3080 configs::MaskCoefficients {
3081 decoder: configs::DecoderType::Ultralytics,
3082 quantization: Some(quant_boxes.into()),
3083 shape: vec![1, 32, 8400],
3084 dshape: vec![
3085 (DimName::Batch, 1),
3086 (DimName::NumProtos, 32),
3087 (DimName::NumBoxes, 8400),
3088 ],
3089 },
3090 configs::Protos {
3091 decoder: configs::DecoderType::Ultralytics,
3092 quantization: Some(quant_protos.into()),
3093 shape: vec![1, 160, 160, 32],
3094 dshape: vec![
3095 (DimName::Batch, 1),
3096 (DimName::Height, 160),
3097 (DimName::Width, 160),
3098 (DimName::NumProtos, 32),
3099 ],
3100 },
3101 )
3102 .with_score_threshold(score_threshold)
3103 .with_iou_threshold(iou_threshold)
3104 .build()
3105 .unwrap()
3106 }
3107
3108 macro_rules! real_data_tracked_test {
3115 ($name:ident, quantized, $layout:ident, $output:ident) => {
3116 #[test]
3117 fn $name() {
3118 use crate::configs::Nms;
3119 let is_split = matches!(stringify!($layout), "split");
3120 let is_proto = matches!(stringify!($output), "proto");
3121
3122 let score_threshold = 0.45;
3123 let iou_threshold = 0.45;
3124 let quant_boxes = (0.021287762_f32, 31_i32);
3125 let quant_protos = (0.02491162_f32, -117_i32);
3126
3127 let raw_boxes = include_bytes!(concat!(
3128 env!("CARGO_MANIFEST_DIR"),
3129 "/../../testdata/yolov8_boxes_116x8400.bin"
3130 ));
3131 let raw_boxes = unsafe {
3132 std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
3133 };
3134 let boxes_i8 =
3135 ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
3136
3137 let raw_protos = include_bytes!(concat!(
3138 env!("CARGO_MANIFEST_DIR"),
3139 "/../../testdata/yolov8_protos_160x160x32.bin"
3140 ));
3141 let raw_protos = unsafe {
3142 std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3143 };
3144 let protos_i8 =
3145 ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
3146 .unwrap();
3147
3148 let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
3150 let mut scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
3151 let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
3152 let mut boxes_combined = boxes_i8;
3153
3154 let decoder = if is_split {
3155 build_split_decoder(score_threshold, iou_threshold, quant_boxes, quant_protos)
3156 } else {
3157 let config_yaml = include_str!(concat!(
3158 env!("CARGO_MANIFEST_DIR"),
3159 "/../../testdata/yolov8_seg.yaml"
3160 ));
3161 DecoderBuilder::default()
3162 .with_config_yaml_str(config_yaml.to_string())
3163 .with_score_threshold(score_threshold)
3164 .with_iou_threshold(iou_threshold)
3165 .with_nms(Some(Nms::ClassAgnostic))
3166 .build()
3167 .unwrap()
3168 };
3169
3170 let expected = real_data_expected_boxes();
3171 let mut tracker = ByteTrackBuilder::new()
3172 .track_update(0.1)
3173 .track_high_conf(0.7)
3174 .build();
3175 let mut output_boxes = Vec::with_capacity(50);
3176 let mut output_tracks = Vec::with_capacity(50);
3177
3178 if is_proto {
3180 {
3181 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
3182 vec![
3183 boxes_split.view().into(),
3184 scores_split.view().into(),
3185 mask_split.view().into(),
3186 protos_i8.view().into(),
3187 ]
3188 } else {
3189 vec![boxes_combined.view().into(), protos_i8.view().into()]
3190 };
3191 decoder
3192 .decode_tracked_quantized_proto(
3193 &mut tracker,
3194 0,
3195 &inputs,
3196 &mut output_boxes,
3197 &mut output_tracks,
3198 )
3199 .unwrap();
3200 }
3201 assert_eq!(output_boxes.len(), 2);
3202 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3203 assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3204
3205 if is_split {
3207 for score in scores_split.iter_mut() {
3208 *score = i8::MIN;
3209 }
3210 } else {
3211 for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
3212 *score = i8::MIN;
3213 }
3214 }
3215
3216 let proto_result = {
3217 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
3218 vec![
3219 boxes_split.view().into(),
3220 scores_split.view().into(),
3221 mask_split.view().into(),
3222 protos_i8.view().into(),
3223 ]
3224 } else {
3225 vec![boxes_combined.view().into(), protos_i8.view().into()]
3226 };
3227 decoder
3228 .decode_tracked_quantized_proto(
3229 &mut tracker,
3230 100_000_000 / 3,
3231 &inputs,
3232 &mut output_boxes,
3233 &mut output_tracks,
3234 )
3235 .unwrap()
3236 };
3237 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3238 assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
3239 assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
3240 } else {
3241 let mut output_masks = Vec::with_capacity(50);
3242 {
3243 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
3244 vec![
3245 boxes_split.view().into(),
3246 scores_split.view().into(),
3247 mask_split.view().into(),
3248 protos_i8.view().into(),
3249 ]
3250 } else {
3251 vec![boxes_combined.view().into(), protos_i8.view().into()]
3252 };
3253 decoder
3254 .decode_tracked_quantized(
3255 &mut tracker,
3256 0,
3257 &inputs,
3258 &mut output_boxes,
3259 &mut output_masks,
3260 &mut output_tracks,
3261 )
3262 .unwrap();
3263 }
3264 assert_eq!(output_boxes.len(), 2);
3265 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3266 assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3267
3268 if is_split {
3269 for score in scores_split.iter_mut() {
3270 *score = i8::MIN;
3271 }
3272 } else {
3273 for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
3274 *score = i8::MIN;
3275 }
3276 }
3277
3278 {
3279 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
3280 vec![
3281 boxes_split.view().into(),
3282 scores_split.view().into(),
3283 mask_split.view().into(),
3284 protos_i8.view().into(),
3285 ]
3286 } else {
3287 vec![boxes_combined.view().into(), protos_i8.view().into()]
3288 };
3289 decoder
3290 .decode_tracked_quantized(
3291 &mut tracker,
3292 100_000_000 / 3,
3293 &inputs,
3294 &mut output_boxes,
3295 &mut output_masks,
3296 &mut output_tracks,
3297 )
3298 .unwrap();
3299 }
3300 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3301 assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
3302 assert!(output_masks.is_empty());
3303 }
3304 }
3305 };
3306 ($name:ident, float, $layout:ident, $output:ident) => {
3307 #[test]
3308 fn $name() {
3309 use crate::configs::Nms;
3310 let is_split = matches!(stringify!($layout), "split");
3311 let is_proto = matches!(stringify!($output), "proto");
3312
3313 let score_threshold = 0.45;
3314 let iou_threshold = 0.45;
3315 let quant_boxes = (0.021287762_f32, 31_i32);
3316 let quant_protos = (0.02491162_f32, -117_i32);
3317
3318 let raw_boxes = include_bytes!(concat!(
3319 env!("CARGO_MANIFEST_DIR"),
3320 "/../../testdata/yolov8_boxes_116x8400.bin"
3321 ));
3322 let raw_boxes = unsafe {
3323 std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
3324 };
3325 let boxes_i8 =
3326 ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
3327 let boxes_f32 = dequantize_ndarray(boxes_i8.view(), quant_boxes.into());
3328
3329 let raw_protos = include_bytes!(concat!(
3330 env!("CARGO_MANIFEST_DIR"),
3331 "/../../testdata/yolov8_protos_160x160x32.bin"
3332 ));
3333 let raw_protos = unsafe {
3334 std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3335 };
3336 let protos_i8 =
3337 ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
3338 .unwrap();
3339 let protos_f32 = dequantize_ndarray(protos_i8.view(), quant_protos.into());
3340
3341 let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
3343 let mut scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
3344 let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
3345 let mut boxes_combined = boxes_f32;
3346
3347 let decoder = if is_split {
3348 build_split_decoder(score_threshold, iou_threshold, quant_boxes, quant_protos)
3349 } else {
3350 let config_yaml = include_str!(concat!(
3351 env!("CARGO_MANIFEST_DIR"),
3352 "/../../testdata/yolov8_seg.yaml"
3353 ));
3354 DecoderBuilder::default()
3355 .with_config_yaml_str(config_yaml.to_string())
3356 .with_score_threshold(score_threshold)
3357 .with_iou_threshold(iou_threshold)
3358 .with_nms(Some(Nms::ClassAgnostic))
3359 .build()
3360 .unwrap()
3361 };
3362
3363 let expected = real_data_expected_boxes();
3364 let mut tracker = ByteTrackBuilder::new()
3365 .track_update(0.1)
3366 .track_high_conf(0.7)
3367 .build();
3368 let mut output_boxes = Vec::with_capacity(50);
3369 let mut output_tracks = Vec::with_capacity(50);
3370
3371 if is_proto {
3372 {
3373 let inputs = if is_split {
3374 vec![
3375 boxes_split.view().into_dyn(),
3376 scores_split.view().into_dyn(),
3377 mask_split.view().into_dyn(),
3378 protos_f32.view().into_dyn(),
3379 ]
3380 } else {
3381 vec![
3382 boxes_combined.view().into_dyn(),
3383 protos_f32.view().into_dyn(),
3384 ]
3385 };
3386 decoder
3387 .decode_tracked_float_proto(
3388 &mut tracker,
3389 0,
3390 &inputs,
3391 &mut output_boxes,
3392 &mut output_tracks,
3393 )
3394 .unwrap();
3395 }
3396 assert_eq!(output_boxes.len(), 2);
3397 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3398 assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3399
3400 if is_split {
3401 for score in scores_split.iter_mut() {
3402 *score = 0.0;
3403 }
3404 } else {
3405 for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
3406 *score = 0.0;
3407 }
3408 }
3409
3410 let proto_result = {
3411 let inputs = if is_split {
3412 vec![
3413 boxes_split.view().into_dyn(),
3414 scores_split.view().into_dyn(),
3415 mask_split.view().into_dyn(),
3416 protos_f32.view().into_dyn(),
3417 ]
3418 } else {
3419 vec![
3420 boxes_combined.view().into_dyn(),
3421 protos_f32.view().into_dyn(),
3422 ]
3423 };
3424 decoder
3425 .decode_tracked_float_proto(
3426 &mut tracker,
3427 100_000_000 / 3,
3428 &inputs,
3429 &mut output_boxes,
3430 &mut output_tracks,
3431 )
3432 .unwrap()
3433 };
3434 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3435 assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
3436 assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
3437 } else {
3438 let mut output_masks = Vec::with_capacity(50);
3439 {
3440 let inputs = if is_split {
3441 vec![
3442 boxes_split.view().into_dyn(),
3443 scores_split.view().into_dyn(),
3444 mask_split.view().into_dyn(),
3445 protos_f32.view().into_dyn(),
3446 ]
3447 } else {
3448 vec![
3449 boxes_combined.view().into_dyn(),
3450 protos_f32.view().into_dyn(),
3451 ]
3452 };
3453 decoder
3454 .decode_tracked_float(
3455 &mut tracker,
3456 0,
3457 &inputs,
3458 &mut output_boxes,
3459 &mut output_masks,
3460 &mut output_tracks,
3461 )
3462 .unwrap();
3463 }
3464 assert_eq!(output_boxes.len(), 2);
3465 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3466 assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3467
3468 if is_split {
3469 for score in scores_split.iter_mut() {
3470 *score = 0.0;
3471 }
3472 } else {
3473 for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
3474 *score = 0.0;
3475 }
3476 }
3477
3478 {
3479 let inputs = if is_split {
3480 vec![
3481 boxes_split.view().into_dyn(),
3482 scores_split.view().into_dyn(),
3483 mask_split.view().into_dyn(),
3484 protos_f32.view().into_dyn(),
3485 ]
3486 } else {
3487 vec![
3488 boxes_combined.view().into_dyn(),
3489 protos_f32.view().into_dyn(),
3490 ]
3491 };
3492 decoder
3493 .decode_tracked_float(
3494 &mut tracker,
3495 100_000_000 / 3,
3496 &inputs,
3497 &mut output_boxes,
3498 &mut output_masks,
3499 &mut output_tracks,
3500 )
3501 .unwrap();
3502 }
3503 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3504 assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
3505 assert!(output_masks.is_empty());
3506 }
3507 }
3508 };
3509 }
3510
    // Test matrix over the recorded fixture:
    // {quantized, float} x {combined, split} x {masks, proto}.

    // Combined detection tensor, mask output.
    real_data_tracked_test!(test_decoder_tracked_segdet, quantized, combined, masks);
    real_data_tracked_test!(test_decoder_tracked_segdet_float, float, combined, masks);
    // Combined detection tensor, proto output.
    real_data_tracked_test!(
        test_decoder_tracked_segdet_proto,
        quantized,
        combined,
        proto
    );
    real_data_tracked_test!(
        test_decoder_tracked_segdet_proto_float,
        float,
        combined,
        proto
    );
    // Split boxes / scores / mask-coefficient tensors, mask output.
    real_data_tracked_test!(test_decoder_tracked_segdet_split, quantized, split, masks);
    real_data_tracked_test!(test_decoder_tracked_segdet_split_float, float, split, masks);
    // Split tensors, proto output.
    real_data_tracked_test!(
        test_decoder_tracked_segdet_split_proto,
        quantized,
        split,
        proto
    );
    real_data_tracked_test!(
        test_decoder_tracked_segdet_split_proto_float,
        float,
        split,
        proto
    );
3539
3540 const E2E_COMBINED_CONFIG: &str = "
3546decoder_version: yolo26
3547outputs:
3548 - type: detection
3549 decoder: ultralytics
3550 quantization: [0.00784313725490196, 0]
3551 shape: [1, 10, 38]
3552 dshape:
3553 - [batch, 1]
3554 - [num_boxes, 10]
3555 - [num_features, 38]
3556 normalized: true
3557 - type: protos
3558 decoder: ultralytics
3559 quantization: [0.0039215686274509803921568627451, 128]
3560 shape: [1, 160, 160, 32]
3561 dshape:
3562 - [batch, 1]
3563 - [height, 160]
3564 - [width, 160]
3565 - [num_protos, 32]
3566";
3567
3568 const E2E_SPLIT_CONFIG: &str = "
3569decoder_version: yolo26
3570outputs:
3571 - type: boxes
3572 decoder: ultralytics
3573 quantization: [0.00784313725490196, 0]
3574 shape: [1, 10, 4]
3575 dshape:
3576 - [batch, 1]
3577 - [num_boxes, 10]
3578 - [box_coords, 4]
3579 normalized: true
3580 - type: scores
3581 decoder: ultralytics
3582 quantization: [0.00784313725490196, 0]
3583 shape: [1, 10, 1]
3584 dshape:
3585 - [batch, 1]
3586 - [num_boxes, 10]
3587 - [num_classes, 1]
3588 - type: classes
3589 decoder: ultralytics
3590 quantization: [0.00784313725490196, 0]
3591 shape: [1, 10, 1]
3592 dshape:
3593 - [batch, 1]
3594 - [num_boxes, 10]
3595 - [num_classes, 1]
3596 - type: mask_coefficients
3597 decoder: ultralytics
3598 quantization: [0.00784313725490196, 0]
3599 shape: [1, 10, 32]
3600 dshape:
3601 - [batch, 1]
3602 - [num_boxes, 10]
3603 - [num_protos, 32]
3604 - type: protos
3605 decoder: ultralytics
3606 quantization: [0.0039215686274509803921568627451, 128]
3607 shape: [1, 160, 160, 32]
3608 dshape:
3609 - [batch, 1]
3610 - [height, 160]
3611 - [width, 160]
3612 - [num_protos, 32]
3613";
3614
3615 macro_rules! e2e_tracked_test {
3616 ($name:ident, quantized, $layout:ident, $output:ident) => {
3617 #[test]
3618 fn $name() {
3619 let is_split = matches!(stringify!($layout), "split");
3620 let is_proto = matches!(stringify!($output), "proto");
3621
3622 let score_threshold = 0.45;
3623 let iou_threshold = 0.45;
3624
3625 let mut boxes = Array2::zeros((10, 4));
3626 let mut scores = Array2::zeros((10, 1));
3627 let mut classes = Array2::zeros((10, 1));
3628 let mask = Array2::zeros((10, 32));
3629 let protos = Array3::<f64>::zeros((160, 160, 32));
3630 let protos = protos.insert_axis(Axis(0));
3631 let protos_quant = (1.0 / 255.0, 0.0);
3632 let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
3633
3634 boxes
3635 .slice_mut(s![0, ..])
3636 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3637 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3638 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3639
3640 let detect_quant = (2.0 / 255.0, 0.0);
3641
3642 let decoder = if is_split {
3643 DecoderBuilder::default()
3644 .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
3645 .with_score_threshold(score_threshold)
3646 .with_iou_threshold(iou_threshold)
3647 .build()
3648 .unwrap()
3649 } else {
3650 DecoderBuilder::default()
3651 .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
3652 .with_score_threshold(score_threshold)
3653 .with_iou_threshold(iou_threshold)
3654 .build()
3655 .unwrap()
3656 };
3657
3658 let expected = e2e_expected_boxes_quant();
3659 let mut tracker = ByteTrackBuilder::new()
3660 .track_update(0.1)
3661 .track_high_conf(0.7)
3662 .build();
3663 let mut output_boxes = Vec::with_capacity(50);
3664 let mut output_tracks = Vec::with_capacity(50);
3665
3666 if is_split {
3667 let boxes = boxes.insert_axis(Axis(0));
3668 let scores = scores.insert_axis(Axis(0));
3669 let classes = classes.insert_axis(Axis(0));
3670 let mask = mask.insert_axis(Axis(0));
3671
3672 let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
3673 let mut scores: Array3<u8> =
3674 quantize_ndarray(scores.view(), detect_quant.into());
3675 let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
3676 let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
3677
3678 if is_proto {
3679 {
3680 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3681 boxes.view().into(),
3682 scores.view().into(),
3683 classes.view().into(),
3684 mask.view().into(),
3685 protos.view().into(),
3686 ];
3687 decoder
3688 .decode_tracked_quantized_proto(
3689 &mut tracker,
3690 0,
3691 &inputs,
3692 &mut output_boxes,
3693 &mut output_tracks,
3694 )
3695 .unwrap();
3696 }
3697 assert_eq!(output_boxes.len(), 1);
3698 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3699
3700 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
3701 *score = u8::MIN;
3702 }
3703 let proto_result = {
3704 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3705 boxes.view().into(),
3706 scores.view().into(),
3707 classes.view().into(),
3708 mask.view().into(),
3709 protos.view().into(),
3710 ];
3711 decoder
3712 .decode_tracked_quantized_proto(
3713 &mut tracker,
3714 100_000_000 / 3,
3715 &inputs,
3716 &mut output_boxes,
3717 &mut output_tracks,
3718 )
3719 .unwrap()
3720 };
3721 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3722 assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
3723 } else {
3724 let mut output_masks = Vec::with_capacity(50);
3725 {
3726 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3727 boxes.view().into(),
3728 scores.view().into(),
3729 classes.view().into(),
3730 mask.view().into(),
3731 protos.view().into(),
3732 ];
3733 decoder
3734 .decode_tracked_quantized(
3735 &mut tracker,
3736 0,
3737 &inputs,
3738 &mut output_boxes,
3739 &mut output_masks,
3740 &mut output_tracks,
3741 )
3742 .unwrap();
3743 }
3744 assert_eq!(output_boxes.len(), 1);
3745 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3746
3747 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
3748 *score = u8::MIN;
3749 }
3750 {
3751 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3752 boxes.view().into(),
3753 scores.view().into(),
3754 classes.view().into(),
3755 mask.view().into(),
3756 protos.view().into(),
3757 ];
3758 decoder
3759 .decode_tracked_quantized(
3760 &mut tracker,
3761 100_000_000 / 3,
3762 &inputs,
3763 &mut output_boxes,
3764 &mut output_masks,
3765 &mut output_tracks,
3766 )
3767 .unwrap();
3768 }
3769 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3770 assert!(output_masks.is_empty());
3771 }
3772 } else {
3773 let detect = ndarray::concatenate![
3775 Axis(1),
3776 boxes.view(),
3777 scores.view(),
3778 classes.view(),
3779 mask.view()
3780 ];
3781 let detect = detect.insert_axis(Axis(0));
3782 assert_eq!(detect.shape(), &[1, 10, 38]);
3783 let mut detect: Array3<u8> =
3784 quantize_ndarray(detect.view(), detect_quant.into());
3785
3786 if is_proto {
3787 {
3788 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3789 vec![detect.view().into(), protos.view().into()];
3790 decoder
3791 .decode_tracked_quantized_proto(
3792 &mut tracker,
3793 0,
3794 &inputs,
3795 &mut output_boxes,
3796 &mut output_tracks,
3797 )
3798 .unwrap();
3799 }
3800 assert_eq!(output_boxes.len(), 1);
3801 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3802
3803 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
3804 *score = u8::MIN;
3805 }
3806 let proto_result = {
3807 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3808 vec![detect.view().into(), protos.view().into()];
3809 decoder
3810 .decode_tracked_quantized_proto(
3811 &mut tracker,
3812 100_000_000 / 3,
3813 &inputs,
3814 &mut output_boxes,
3815 &mut output_tracks,
3816 )
3817 .unwrap()
3818 };
3819 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3820 assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
3821 } else {
3822 let mut output_masks = Vec::with_capacity(50);
3823 {
3824 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3825 vec![detect.view().into(), protos.view().into()];
3826 decoder
3827 .decode_tracked_quantized(
3828 &mut tracker,
3829 0,
3830 &inputs,
3831 &mut output_boxes,
3832 &mut output_masks,
3833 &mut output_tracks,
3834 )
3835 .unwrap();
3836 }
3837 assert_eq!(output_boxes.len(), 1);
3838 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3839
3840 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
3841 *score = u8::MIN;
3842 }
3843 {
3844 let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3845 vec![detect.view().into(), protos.view().into()];
3846 decoder
3847 .decode_tracked_quantized(
3848 &mut tracker,
3849 100_000_000 / 3,
3850 &inputs,
3851 &mut output_boxes,
3852 &mut output_masks,
3853 &mut output_tracks,
3854 )
3855 .unwrap();
3856 }
3857 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3858 assert!(output_masks.is_empty());
3859 }
3860 }
3861 }
3862 };
3863 ($name:ident, float, $layout:ident, $output:ident) => {
3864 #[test]
3865 fn $name() {
3866 let is_split = matches!(stringify!($layout), "split");
3867 let is_proto = matches!(stringify!($output), "proto");
3868
3869 let score_threshold = 0.45;
3870 let iou_threshold = 0.45;
3871
3872 let mut boxes = Array2::zeros((10, 4));
3873 let mut scores = Array2::zeros((10, 1));
3874 let mut classes = Array2::zeros((10, 1));
3875 let mask: Array2<f64> = Array2::zeros((10, 32));
3876 let protos = Array3::<f64>::zeros((160, 160, 32));
3877 let protos = protos.insert_axis(Axis(0));
3878
3879 boxes
3880 .slice_mut(s![0, ..])
3881 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3882 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3883 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3884
3885 let decoder = if is_split {
3886 DecoderBuilder::default()
3887 .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
3888 .with_score_threshold(score_threshold)
3889 .with_iou_threshold(iou_threshold)
3890 .build()
3891 .unwrap()
3892 } else {
3893 DecoderBuilder::default()
3894 .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
3895 .with_score_threshold(score_threshold)
3896 .with_iou_threshold(iou_threshold)
3897 .build()
3898 .unwrap()
3899 };
3900
3901 let expected = e2e_expected_boxes_float();
3902 let mut tracker = ByteTrackBuilder::new()
3903 .track_update(0.1)
3904 .track_high_conf(0.7)
3905 .build();
3906 let mut output_boxes = Vec::with_capacity(50);
3907 let mut output_tracks = Vec::with_capacity(50);
3908
3909 if is_split {
3910 let boxes = boxes.insert_axis(Axis(0));
3911 let mut scores = scores.insert_axis(Axis(0));
3912 let classes = classes.insert_axis(Axis(0));
3913 let mask = mask.insert_axis(Axis(0));
3914
3915 if is_proto {
3916 {
3917 let inputs = vec![
3918 boxes.view().into_dyn(),
3919 scores.view().into_dyn(),
3920 classes.view().into_dyn(),
3921 mask.view().into_dyn(),
3922 protos.view().into_dyn(),
3923 ];
3924 decoder
3925 .decode_tracked_float_proto(
3926 &mut tracker,
3927 0,
3928 &inputs,
3929 &mut output_boxes,
3930 &mut output_tracks,
3931 )
3932 .unwrap();
3933 }
3934 assert_eq!(output_boxes.len(), 1);
3935 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3936
3937 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
3938 *score = 0.0;
3939 }
3940 let proto_result = {
3941 let inputs = vec![
3942 boxes.view().into_dyn(),
3943 scores.view().into_dyn(),
3944 classes.view().into_dyn(),
3945 mask.view().into_dyn(),
3946 protos.view().into_dyn(),
3947 ];
3948 decoder
3949 .decode_tracked_float_proto(
3950 &mut tracker,
3951 100_000_000 / 3,
3952 &inputs,
3953 &mut output_boxes,
3954 &mut output_tracks,
3955 )
3956 .unwrap()
3957 };
3958 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3959 assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
3960 } else {
3961 let mut output_masks = Vec::with_capacity(50);
3962 {
3963 let inputs = vec![
3964 boxes.view().into_dyn(),
3965 scores.view().into_dyn(),
3966 classes.view().into_dyn(),
3967 mask.view().into_dyn(),
3968 protos.view().into_dyn(),
3969 ];
3970 decoder
3971 .decode_tracked_float(
3972 &mut tracker,
3973 0,
3974 &inputs,
3975 &mut output_boxes,
3976 &mut output_masks,
3977 &mut output_tracks,
3978 )
3979 .unwrap();
3980 }
3981 assert_eq!(output_boxes.len(), 1);
3982 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3983
3984 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
3985 *score = 0.0;
3986 }
3987 {
3988 let inputs = vec![
3989 boxes.view().into_dyn(),
3990 scores.view().into_dyn(),
3991 classes.view().into_dyn(),
3992 mask.view().into_dyn(),
3993 protos.view().into_dyn(),
3994 ];
3995 decoder
3996 .decode_tracked_float(
3997 &mut tracker,
3998 100_000_000 / 3,
3999 &inputs,
4000 &mut output_boxes,
4001 &mut output_masks,
4002 &mut output_tracks,
4003 )
4004 .unwrap();
4005 }
4006 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4007 assert!(output_masks.is_empty());
4008 }
4009 } else {
4010 let detect = ndarray::concatenate![
4012 Axis(1),
4013 boxes.view(),
4014 scores.view(),
4015 classes.view(),
4016 mask.view()
4017 ];
4018 let mut detect = detect.insert_axis(Axis(0));
4019 assert_eq!(detect.shape(), &[1, 10, 38]);
4020
4021 if is_proto {
4022 {
4023 let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4024 decoder
4025 .decode_tracked_float_proto(
4026 &mut tracker,
4027 0,
4028 &inputs,
4029 &mut output_boxes,
4030 &mut output_tracks,
4031 )
4032 .unwrap();
4033 }
4034 assert_eq!(output_boxes.len(), 1);
4035 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4036
4037 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4038 *score = 0.0;
4039 }
4040 let proto_result = {
4041 let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4042 decoder
4043 .decode_tracked_float_proto(
4044 &mut tracker,
4045 100_000_000 / 3,
4046 &inputs,
4047 &mut output_boxes,
4048 &mut output_tracks,
4049 )
4050 .unwrap()
4051 };
4052 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4053 assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
4054 } else {
4055 let mut output_masks = Vec::with_capacity(50);
4056 {
4057 let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4058 decoder
4059 .decode_tracked_float(
4060 &mut tracker,
4061 0,
4062 &inputs,
4063 &mut output_boxes,
4064 &mut output_masks,
4065 &mut output_tracks,
4066 )
4067 .unwrap();
4068 }
4069 assert_eq!(output_boxes.len(), 1);
4070 assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4071
4072 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4073 *score = 0.0;
4074 }
4075 {
4076 let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4077 decoder
4078 .decode_tracked_float(
4079 &mut tracker,
4080 100_000_000 / 3,
4081 &inputs,
4082 &mut output_boxes,
4083 &mut output_masks,
4084 &mut output_tracks,
4085 )
4086 .unwrap();
4087 }
4088 assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4089 assert!(output_masks.is_empty());
4090 }
4091 }
4092 }
4093 };
4094 }
4095
4096 e2e_tracked_test!(
4097 test_decoder_tracked_end_to_end_segdet,
4098 quantized,
4099 combined,
4100 masks
4101 );
4102 e2e_tracked_test!(
4103 test_decoder_tracked_end_to_end_segdet_float,
4104 float,
4105 combined,
4106 masks
4107 );
4108 e2e_tracked_test!(
4109 test_decoder_tracked_end_to_end_segdet_proto,
4110 quantized,
4111 combined,
4112 proto
4113 );
4114 e2e_tracked_test!(
4115 test_decoder_tracked_end_to_end_segdet_proto_float,
4116 float,
4117 combined,
4118 proto
4119 );
4120 e2e_tracked_test!(
4121 test_decoder_tracked_end_to_end_segdet_split,
4122 quantized,
4123 split,
4124 masks
4125 );
4126 e2e_tracked_test!(
4127 test_decoder_tracked_end_to_end_segdet_split_float,
4128 float,
4129 split,
4130 masks
4131 );
4132 e2e_tracked_test!(
4133 test_decoder_tracked_end_to_end_segdet_split_proto,
4134 quantized,
4135 split,
4136 proto
4137 );
4138 e2e_tracked_test!(
4139 test_decoder_tracked_end_to_end_segdet_split_proto_float,
4140 float,
4141 split,
4142 proto
4143 );
4144
/// Exercises the tracker on two detections that drift along +x at a constant rate.
///
/// Stages:
/// 1. Decode the unmodified reference tensor and check the two known boxes.
/// 2. Replay the tensor 100 times, each time with feature row 0 pushed a further
///    `1e-3` (in dequantized units) to the right, then check that the tracker's
///    predicted locations have moved by `0.1` in total.
/// 3. Decode one more frame in which every score is forced to the minimum, and
///    check that the returned boxes sit one further `1e-3` step along.
///
/// Frame `i` is stamped `100_000_000 * i / 3`.
/// NOTE(review): that looks like a 30 fps frame period if the timestamp unit is
/// nanoseconds - confirm against the tracker API.
#[test]
fn test_decoder_tracked_linear_motion() {
    use crate::configs::{DecoderType, Nms};
    use crate::DecoderBuilder;

    let score_threshold = 0.25;
    let iou_threshold = 0.1;

    let raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8s_80_classes.bin"
    ));
    // The blob holds signed 8-bit samples. `u8 as i8` keeps the bit pattern, so a
    // plain per-byte cast yields the same values as reinterpreting the buffer
    // through a raw pointer, without needing `unsafe`.
    let samples: Vec<i8> = raw.iter().map(|&byte| byte as i8).collect();
    let mut out = Array3::from_shape_vec((1, 84, 8400), samples).unwrap();
    let quant = (0.0040811873, -123).into();

    let decoder = DecoderBuilder::default()
        .with_config_yolo_det(
            crate::configs::Detection {
                decoder: DecoderType::Ultralytics,
                shape: vec![1, 84, 8400],
                anchors: None,
                quantization: Some(quant),
                dshape: vec![
                    (crate::configs::DimName::Batch, 1),
                    (crate::configs::DimName::NumFeatures, 84),
                    (crate::configs::DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            None,
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .with_nms(Some(Nms::ClassAgnostic))
        .build()
        .unwrap();

    let mut expected_boxes = [
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.5285137,
                ymin: 0.05305544,
                xmax: 0.87541467,
                ymax: 0.9998909,
            },
            score: 0.5591227,
            label: 0,
        },
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.130598,
                ymin: 0.43260583,
                xmax: 0.35098213,
                ymax: 0.9958097,
            },
            score: 0.33057618,
            label: 75,
        },
    ];

    let mut tracker = ByteTrackBuilder::new()
        .track_update(0.1)
        .track_high_conf(0.3)
        .build();

    let mut output_boxes = Vec::with_capacity(50);
    let mut output_masks = Vec::with_capacity(50);
    let mut output_tracks = Vec::with_capacity(50);

    // Stage 1: the untouched frame must yield exactly the two reference boxes.
    decoder
        .decode_tracked_quantized(
            &mut tracker,
            0,
            &[out.view().into()],
            &mut output_boxes,
            &mut output_masks,
            &mut output_tracks,
        )
        .unwrap();

    assert_eq!(output_boxes.len(), 2);
    assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
    assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));

    // Stage 2: 100 frames, each displaced from the *original* tensor (not from the
    // previous frame), so frame `i` is offset by `i * 1e-3`.
    for i in 1..=100 {
        // The offset is the same for every element of the frame, so convert it to
        // quantized steps once instead of once per element.
        let step = (i as f32 * 1e-3 / quant.0).round() as i8;

        let mut shifted = out.clone();
        // Feature row 0 is what this test treats as the box x coordinate.
        for x in shifted.slice_mut(s![0, 0, ..]).iter_mut() {
            *x = x.saturating_add(step);
        }

        decoder
            .decode_tracked_quantized(
                &mut tracker,
                100_000_000 * i / 3,
                &[shifted.view().into()],
                &mut output_boxes,
                &mut output_masks,
                &mut output_tracks,
            )
            .unwrap();

        assert_eq!(output_boxes.len(), 2);
    }

    // Rebuild boxes from each active track's last detection, swapping in the
    // tracker's own estimate of where the object currently is.
    let tracks = tracker.get_active_tracks();
    let predicted_boxes: Vec<_> = tracks
        .iter()
        .map(|track| {
            let mut predicted = track.last_box;
            predicted.bbox = track.info.tracked_location.into();
            predicted
        })
        .collect();

    // 100 frames of 1e-3 each.
    for expected in expected_boxes.iter_mut() {
        expected.bbox.xmin += 0.1;
        expected.bbox.xmax += 0.1;
    }

    assert!(
        predicted_boxes.len() >= 2,
        "expected at least two active tracks, found {}",
        predicted_boxes.len()
    );
    assert!(predicted_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
    assert!(predicted_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));

    // Stage 3: push every value in feature rows 4.. (treated here as the class
    // scores) to the lowest representable quantized value, then decode again.
    for score in out.slice_mut(s![0, 4.., ..]).iter_mut() {
        *score = i8::MIN;
    }

    decoder
        .decode_tracked_quantized(
            &mut tracker,
            100_000_000 * 101 / 3,
            &[out.view().into()],
            &mut output_boxes,
            &mut output_masks,
            &mut output_tracks,
        )
        .unwrap();

    // One more frame at the same per-frame displacement.
    for expected in expected_boxes.iter_mut() {
        expected.bbox.xmin += 0.001;
        expected.bbox.xmax += 0.001;
    }

    assert!(
        output_boxes.len() >= 2,
        "expected both tracked boxes to still be reported, found {}",
        output_boxes.len()
    );
    assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
    assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
}
4290
/// Runs the tracked float decode path on a detection-only `[1, 10, 6]` tensor.
///
/// The tensor holds a single non-zero row: box `[0.1234, 0.1234, 0.2345, 0.2345]`,
/// score `0.9876`, class `2`. The first decode must return exactly that box. The
/// score column is then zeroed and a second decode is issued at a later timestamp;
/// the first returned box must still equal the original one.
#[test]
fn test_decoder_tracked_end_to_end_float() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let mut boxes = Array2::zeros((10, 4));
    let mut scores = Array2::zeros((10, 1));
    let mut classes = Array2::zeros((10, 1));

    // Only row 0 carries a detection; the other nine rows stay all-zero.
    boxes
        .slice_mut(s![0, ..])
        .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
    scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
    classes.slice_mut(s![0, ..]).assign(&array![2.0]);

    // Column layout after concatenation: 0..4 box, 4 score, 5 class.
    let detect = ndarray::concatenate![Axis(1), boxes.view(), scores.view(), classes.view()];
    let mut detect = detect.insert_axis(Axis(0));
    assert_eq!(detect.shape(), &[1, 10, 6]);

    // The YAML nesting is significant: the keys from `decoder` to `normalized`
    // belong to the single list item under `outputs`, and the three pairs belong
    // to `dshape`.
    let config = "
decoder_version: yolo26
outputs:
  - type: detection
    decoder: ultralytics
    quantization: [0.00784313725490196, 0]
    shape: [1, 10, 6]
    dshape:
      - [batch, 1]
      - [num_boxes, 10]
      - [num_features, 6]
    normalized: true
";

    let decoder = DecoderBuilder::default()
        .with_config_yaml_str(config.to_string())
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let expected_boxes = [DetectBox {
        bbox: BoundingBox {
            xmin: 0.1234,
            ymin: 0.1234,
            xmax: 0.2345,
            ymax: 0.2345,
        },
        score: 0.9876,
        label: 2,
    }];

    let mut tracker = ByteTrackBuilder::new()
        .track_update(0.1)
        .track_high_conf(0.7)
        .build();

    let mut output_boxes = Vec::with_capacity(50);
    let mut output_masks = Vec::with_capacity(50);
    let mut output_tracks = Vec::with_capacity(50);

    decoder
        .decode_tracked_float(
            &mut tracker,
            0,
            &[detect.view().into_dyn()],
            &mut output_boxes,
            &mut output_masks,
            &mut output_tracks,
        )
        .unwrap();

    assert_eq!(output_boxes.len(), 1);
    assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));

    // Zero the score column so the second frame has nothing above the threshold.
    for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
        *score = 0.0;
    }

    decoder
        .decode_tracked_float(
            &mut tracker,
            100_000_000 / 3,
            &[detect.view().into_dyn()],
            &mut output_boxes,
            &mut output_masks,
            &mut output_tracks,
        )
        .unwrap();

    // Guard the index so a missing box fails with a message, not a bounds panic.
    assert!(
        !output_boxes.is_empty(),
        "expected the tracked box to still be reported after its score was zeroed"
    );
    assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
}
4382}