1#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
71
72use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
73use num_traits::{AsPrimitive, Float, PrimInt};
74
75pub mod byte;
76pub mod error;
77pub mod float;
78pub mod modelpack;
79pub mod yolo;
80
81mod decoder;
82pub use decoder::*;
83
84pub use configs::{DecoderVersion, Nms};
85pub use error::{DecoderError, DecoderResult};
86
87use crate::{
88 decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
89 yolo::yolo_segmentation_to_mask,
90};
91
92pub trait BBoxTypeTrait {
94 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
96
97 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
99 input: &[B; 4],
100 quant: Quantization,
101 ) -> [A; 4]
102 where
103 f32: AsPrimitive<A>,
104 i32: AsPrimitive<A>;
105
106 #[inline(always)]
117 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
118 input: ArrayView1<B>,
119 ) -> [A; 4] {
120 Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
121 }
122
123 #[inline(always)]
124 fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
126 input: ArrayView1<B>,
127 quant: Quantization,
128 ) -> [A; 4]
129 where
130 f32: AsPrimitive<A>,
131 i32: AsPrimitive<A>,
132 {
133 Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
134 }
135}
136
137#[derive(Debug, Clone, Copy, PartialEq, Eq)]
139pub struct XYXY {}
140
141impl BBoxTypeTrait for XYXY {
142 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
143 input.map(|b| b.as_())
144 }
145
146 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
147 input: &[B; 4],
148 quant: Quantization,
149 ) -> [A; 4]
150 where
151 f32: AsPrimitive<A>,
152 i32: AsPrimitive<A>,
153 {
154 let scale = quant.scale.as_();
155 let zp = quant.zero_point.as_();
156 input.map(|b| (b.as_() - zp) * scale)
157 }
158
159 #[inline(always)]
160 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
161 input: ArrayView1<B>,
162 ) -> [A; 4] {
163 [
164 input[0].as_(),
165 input[1].as_(),
166 input[2].as_(),
167 input[3].as_(),
168 ]
169 }
170}
171
172#[derive(Debug, Clone, Copy, PartialEq, Eq)]
175pub struct XYWH {}
176
177impl BBoxTypeTrait for XYWH {
178 #[inline(always)]
179 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
180 let half = A::one() / (A::one() + A::one());
181 [
182 (input[0].as_()) - (input[2].as_() * half),
183 (input[1].as_()) - (input[3].as_() * half),
184 (input[0].as_()) + (input[2].as_() * half),
185 (input[1].as_()) + (input[3].as_() * half),
186 ]
187 }
188
189 #[inline(always)]
190 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
191 input: &[B; 4],
192 quant: Quantization,
193 ) -> [A; 4]
194 where
195 f32: AsPrimitive<A>,
196 i32: AsPrimitive<A>,
197 {
198 let scale = quant.scale.as_();
199 let half_scale = (quant.scale * 0.5).as_();
200 let zp = quant.zero_point.as_();
201 let [x, y, w, h] = [
202 (input[0].as_() - zp) * scale,
203 (input[1].as_() - zp) * scale,
204 (input[2].as_() - zp) * half_scale,
205 (input[3].as_() - zp) * half_scale,
206 ];
207
208 [x - w, y - h, x + w, y + h]
209 }
210
211 #[inline(always)]
212 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
213 input: ArrayView1<B>,
214 ) -> [A; 4] {
215 let half = A::one() / (A::one() + A::one());
216 [
217 (input[0].as_()) - (input[2].as_() * half),
218 (input[1].as_()) - (input[3].as_() * half),
219 (input[0].as_()) + (input[2].as_() * half),
220 (input[1].as_()) + (input[3].as_() * half),
221 ]
222 }
223}
224
/// Affine quantization parameters.
///
/// The dequantization helpers in this file compute a real value as
/// `(quantized - zero_point) * scale`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Quantization {
    /// Multiplier applied after the zero point is removed.
    pub scale: f32,
    /// Quantized value that corresponds to a real value of zero.
    pub zero_point: i32,
}

impl Quantization {
    /// Builds a `Quantization` from its two parameters.
    pub fn new(scale: f32, zero_point: i32) -> Self {
        Quantization { scale, zero_point }
    }
}
245
246impl From<QuantTuple> for Quantization {
247 fn from(quant_tuple: QuantTuple) -> Quantization {
258 Quantization {
259 scale: quant_tuple.0,
260 zero_point: quant_tuple.1,
261 }
262 }
263}
264
265impl<S, Z> From<(S, Z)> for Quantization
266where
267 S: AsPrimitive<f32>,
268 Z: AsPrimitive<i32>,
269{
270 fn from((scale, zp): (S, Z)) -> Quantization {
279 Self {
280 scale: scale.as_(),
281 zero_point: zp.as_(),
282 }
283 }
284}
285
286impl Default for Quantization {
287 fn default() -> Self {
296 Self {
297 scale: 1.0,
298 zero_point: 0,
299 }
300 }
301}
302
303#[derive(Debug, Clone, Copy, PartialEq, Default)]
305pub struct DetectBox {
306 pub bbox: BoundingBox,
307 pub score: f32,
309 pub label: usize,
311}
312
/// Axis-aligned box described by its minimum and maximum x/y coordinates.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct BoundingBox {
    /// Minimum x coordinate.
    pub xmin: f32,
    /// Minimum y coordinate.
    pub ymin: f32,
    /// Maximum x coordinate.
    pub xmax: f32,
    /// Maximum y coordinate.
    pub ymax: f32,
}

impl BoundingBox {
    /// Creates a box from its four coordinates; the values are stored as given.
    pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }

    /// Returns a copy in which `xmin <= xmax` and `ymin <= ymax`, swapping
    /// the coordinates of an axis when they are stored in reverse order.
    pub fn to_canonical(&self) -> Self {
        let (x_lo, x_hi) = (self.xmin.min(self.xmax), self.xmin.max(self.xmax));
        let (y_lo, y_hi) = (self.ymin.min(self.ymax), self.ymin.max(self.ymax));
        Self::new(x_lo, y_lo, x_hi, y_hi)
    }
}
358
359impl From<BoundingBox> for [f32; 4] {
360 fn from(b: BoundingBox) -> Self {
375 [b.xmin, b.ymin, b.xmax, b.ymax]
376 }
377}
378
379impl From<[f32; 4]> for BoundingBox {
380 fn from(arr: [f32; 4]) -> Self {
383 BoundingBox {
384 xmin: arr[0],
385 ymin: arr[1],
386 xmax: arr[2],
387 ymax: arr[3],
388 }
389 }
390}
391
392impl DetectBox {
393 pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
422 let eq_delta = |a: f32, b: f32| (a - b).abs() <= eps;
423 self.label == rhs.label
424 && eq_delta(self.score, rhs.score)
425 && eq_delta(self.bbox.xmin, rhs.bbox.xmin)
426 && eq_delta(self.bbox.ymin, rhs.bbox.ymin)
427 && eq_delta(self.bbox.xmax, rhs.bbox.xmax)
428 && eq_delta(self.bbox.ymax, rhs.bbox.ymax)
429 }
430}
431
432#[derive(Debug, Clone, PartialEq, Default)]
435pub struct Segmentation {
436 pub xmin: f32,
438 pub ymin: f32,
440 pub xmax: f32,
442 pub ymax: f32,
444 pub segmentation: Array3<u8>,
452}
453
454#[derive(Debug, Clone)]
460pub enum ProtoTensor {
461 Quantized {
465 protos: Array3<i8>,
466 quantization: Quantization,
467 },
468 Float(Array3<f32>),
470}
471
472impl ProtoTensor {
473 pub fn is_quantized(&self) -> bool {
475 matches!(self, ProtoTensor::Quantized { .. })
476 }
477
478 pub fn dim(&self) -> (usize, usize, usize) {
480 match self {
481 ProtoTensor::Quantized { protos, .. } => protos.dim(),
482 ProtoTensor::Float(arr) => arr.dim(),
483 }
484 }
485
486 pub fn as_f32(&self) -> std::borrow::Cow<'_, Array3<f32>> {
489 match self {
490 ProtoTensor::Float(arr) => std::borrow::Cow::Borrowed(arr),
491 ProtoTensor::Quantized {
492 protos,
493 quantization,
494 } => {
495 let scale = quantization.scale;
496 let zp = quantization.zero_point as f32;
497 std::borrow::Cow::Owned(protos.map(|&v| (v as f32 - zp) * scale))
498 }
499 }
500 }
501}
502
503#[derive(Debug, Clone)]
509pub struct ProtoData {
510 pub mask_coefficients: Vec<Vec<f32>>,
512 pub protos: ProtoTensor,
514}
515
516pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
534 detect: &DetectBoxQuantized<SCORE>,
535 quant_scores: Quantization,
536) -> DetectBox {
537 let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
538 DetectBox {
539 bbox: detect.bbox,
540 score: quant_scores.scale * detect.score.as_() + scaled_zp,
541 label: detect.label,
542 }
543}
544#[derive(Debug, Clone, Copy, PartialEq)]
546pub struct DetectBoxQuantized<
547 SCORE: PrimInt + AsPrimitive<f32>,
549> {
550 pub bbox: BoundingBox,
552 pub score: SCORE,
555 pub label: usize,
557}
558
559pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
572 input: ArrayView<T, D>,
573 quant: Quantization,
574) -> Array<F, D>
575where
576 i32: num_traits::AsPrimitive<F>,
577 f32: num_traits::AsPrimitive<F>,
578{
579 let zero_point = quant.zero_point.as_();
580 let scale = quant.scale.as_();
581 if zero_point != F::zero() {
582 let scaled_zero = -zero_point * scale;
583 input.mapv(|d| d.as_() * scale + scaled_zero)
584 } else {
585 input.mapv(|d| d.as_() * scale)
586 }
587}
588
589pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
602 input: &[T],
603 quant: Quantization,
604 output: &mut [F],
605) where
606 f32: num_traits::AsPrimitive<F>,
607 i32: num_traits::AsPrimitive<F>,
608{
609 assert!(input.len() == output.len());
610 let zero_point = quant.zero_point.as_();
611 let scale = quant.scale.as_();
612 if zero_point != F::zero() {
613 let scaled_zero = -zero_point * scale; input
615 .iter()
616 .zip(output)
617 .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
618 } else {
619 input
620 .iter()
621 .zip(output)
622 .for_each(|(d, deq)| *deq = d.as_() * scale);
623 }
624}
625
626pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
640 input: &[T],
641 quant: Quantization,
642 output: &mut [F],
643) where
644 f32: num_traits::AsPrimitive<F>,
645 i32: num_traits::AsPrimitive<F>,
646{
647 assert!(input.len() == output.len());
648 let zero_point = quant.zero_point.as_();
649 let scale = quant.scale.as_();
650
651 let input = input.as_chunks::<4>();
652 let output = output.as_chunks_mut::<4>();
653
654 if zero_point != F::zero() {
655 let scaled_zero = -zero_point * scale; input
658 .0
659 .iter()
660 .zip(output.0)
661 .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
662 input
663 .1
664 .iter()
665 .zip(output.1)
666 .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
667 } else {
668 input
669 .0
670 .iter()
671 .zip(output.0)
672 .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
673 input
674 .1
675 .iter()
676 .zip(output.1)
677 .for_each(|(d, deq)| *deq = d.as_() * scale);
678 }
679}
680
681pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
699 if segmentation.shape()[2] == 0 {
700 return Err(DecoderError::InvalidShape(
701 "Segmentation tensor must have non-zero depth".to_string(),
702 ));
703 }
704 if segmentation.shape()[2] == 1 {
705 yolo_segmentation_to_mask(segmentation, 128)
706 } else {
707 Ok(modelpack_segmentation_to_mask(segmentation))
708 }
709}
710
711fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
713 score
714 .iter()
715 .enumerate()
716 .fold((score[0], 0), |(max, arg_max), (ind, s)| {
717 if max > *s {
718 (max, arg_max)
719 } else {
720 (*s, ind)
721 }
722 })
723}
724#[cfg(test)]
725#[cfg_attr(coverage_nightly, coverage(off))]
726mod decoder_tests {
727 #![allow(clippy::excessive_precision)]
728 use crate::{
729 configs::{DecoderType, DimName, Protos},
730 modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
731 yolo::{
732 decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
733 decode_yolo_segdet_quant,
734 },
735 *,
736 };
737 use ndarray::{array, s, Array4};
738 use ndarray_stats::DeviationExt;
739
740 fn compare_outputs(
741 boxes: (&[DetectBox], &[DetectBox]),
742 masks: (&[Segmentation], &[Segmentation]),
743 ) {
744 let (boxes0, boxes1) = boxes;
745 let (masks0, masks1) = masks;
746
747 assert_eq!(boxes0.len(), boxes1.len());
748 assert_eq!(masks0.len(), masks1.len());
749
750 for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
751 assert!(
752 b_i8.equal_within_delta(b_f32, 1e-6),
753 "{b_i8:?} is not equal to {b_f32:?}"
754 );
755 }
756
757 for (m_i8, m_f32) in masks0.iter().zip(masks1) {
758 assert_eq!(
759 [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
760 [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
761 );
762 assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
763 let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
764 let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
765 let diff = &mask_i8 - &mask_f32;
766 for x in 0..diff.shape()[0] {
767 for y in 0..diff.shape()[1] {
768 for z in 0..diff.shape()[2] {
769 let val = diff[[x, y, z]];
770 assert!(
771 val.abs() <= 1,
772 "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
773 x,
774 y,
775 z,
776 val
777 );
778 }
779 }
780 }
781 let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
782 assert!(
783 mean_sq_err < 1e-2,
784 "Mean Square Error between masks was greater than 1%: {:.2}%",
785 mean_sq_err * 100.0
786 );
787 }
788 }
789
/// ModelPack detection decoded three ways must agree: the free
/// `decode_modelpack_det` function, the built decoder on quantized input,
/// and the built decoder on dequantized input.
#[test]
fn test_decoder_modelpack() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let boxes = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_boxes_1935x1x4.bin"
    ));
    let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();

    let scores = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_scores_1935x1.bin"
    ));
    let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();

    let quant_boxes = (0.004656755365431309, 21).into();
    let quant_scores = (0.0019603664986789227, 0).into();

    let boxes_config = configs::Boxes {
        decoder: DecoderType::ModelPack,
        quantization: Some(quant_boxes),
        shape: vec![1, 1935, 1, 4],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumBoxes, 1935),
            (DimName::Padding, 1),
            (DimName::BoxCoords, 4),
        ],
        normalized: Some(true),
    };
    let scores_config = configs::Scores {
        decoder: DecoderType::ModelPack,
        quantization: Some(quant_scores),
        shape: vec![1, 1935, 1],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumBoxes, 1935),
            (DimName::NumClasses, 1),
        ],
    };

    let decoder = DecoderBuilder::default()
        .with_config_modelpack_det(boxes_config, scores_config)
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Re-type the quantization parameters for the free-function API.
    let quant_boxes = quant_boxes.into();
    let quant_scores = quant_scores.into();

    // Reference result from the free function.
    let mut output_boxes: Vec<_> = Vec::with_capacity(50);
    decode_modelpack_det(
        (boxes.slice(s![0, .., 0, ..]), quant_boxes),
        (scores.slice(s![0, .., ..]), quant_scores),
        score_threshold,
        iou_threshold,
        &mut output_boxes,
    );
    assert!(output_boxes[0].equal_within_delta(
        &DetectBox {
            bbox: BoundingBox {
                xmin: 0.40513772,
                ymin: 0.6379755,
                xmax: 0.5122431,
                ymax: 0.7730214,
            },
            score: 0.4861709,
            label: 0
        },
        1e-6
    ));

    // Built decoder, quantized input.
    let mut output_boxes1 = Vec::with_capacity(50);
    let mut output_masks1 = Vec::with_capacity(50);
    decoder
        .decode_quantized(
            &[boxes.view().into(), scores.view().into()],
            &mut output_boxes1,
            &mut output_masks1,
        )
        .unwrap();

    // Built decoder, dequantized input.
    let mut output_boxes_float = Vec::with_capacity(50);
    let mut output_masks_float = Vec::with_capacity(50);
    let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
    let scores = dequantize_ndarray(scores.view(), quant_scores);
    decoder
        .decode_float::<f32>(
            &[boxes.view().into_dyn(), scores.view().into_dyn()],
            &mut output_boxes_float,
            &mut output_masks_float,
        )
        .unwrap();

    compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
    compare_outputs(
        (&output_boxes, &output_boxes_float),
        (&[], &output_masks_float),
    );
}
895
/// Split ModelPack detection decoded three ways must agree: the free
/// `decode_modelpack_split_quant` function, the built decoder on quantized
/// input, and the built decoder on dequantized input.
#[test]
fn test_decoder_modelpack_split_u8() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let detect0 = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split_9x15x18.bin"
    ));
    let detect0 = Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

    let detect1 = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split_17x30x18.bin"
    ));
    let detect1 = Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

    let quant0 = (0.08547406643629074, 174).into();
    let quant1 = (0.09929127991199493, 183).into();
    let anchors0 = vec![
        [0.36666667461395264, 0.31481480598449707],
        [0.38749998807907104, 0.4740740656852722],
        [0.5333333611488342, 0.644444465637207],
    ];
    let anchors1 = vec![
        [0.13750000298023224, 0.2074074000120163],
        [0.2541666626930237, 0.21481481194496155],
        [0.23125000298023224, 0.35185185074806213],
    ];

    // The anchor vectors are not used again, so they are moved into the
    // configs instead of being cloned.
    let detect_config0 = configs::Detection {
        decoder: DecoderType::ModelPack,
        shape: vec![1, 9, 15, 18],
        anchors: Some(anchors0),
        quantization: Some(quant0),
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 9),
            (DimName::Width, 15),
            (DimName::NumAnchorsXFeatures, 18),
        ],
        normalized: Some(true),
    };

    let detect_config1 = configs::Detection {
        decoder: DecoderType::ModelPack,
        shape: vec![1, 17, 30, 18],
        anchors: Some(anchors1),
        quantization: Some(quant1),
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 17),
            (DimName::Width, 30),
            (DimName::NumAnchorsXFeatures, 18),
        ],
        normalized: Some(true),
    };

    // Per-output configs for the free function, derived before the
    // `Detection` values are moved into the builder.
    let config0 = (&detect_config0).try_into().unwrap();
    let config1 = (&detect_config1).try_into().unwrap();

    // The builder receives the configs in the opposite order to the tensors
    // passed to `decode_quantized` below.
    let decoder = DecoderBuilder::default()
        .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let quant0 = quant0.into();
    let quant1 = quant1.into();

    // Reference result from the free function.
    let mut output_boxes: Vec<_> = Vec::with_capacity(10);
    decode_modelpack_split_quant(
        &[
            detect0.slice(s![0, .., .., ..]),
            detect1.slice(s![0, .., .., ..]),
        ],
        &[config0, config1],
        score_threshold,
        iou_threshold,
        &mut output_boxes,
    );
    assert!(output_boxes[0].equal_within_delta(
        &DetectBox {
            bbox: BoundingBox {
                xmin: 0.43171933,
                ymin: 0.68243736,
                xmax: 0.5626645,
                ymax: 0.808863,
            },
            score: 0.99240804,
            label: 0
        },
        1e-6
    ));

    // Built decoder, quantized input.
    let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
    let mut output_masks1: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(
            &[detect0.view().into(), detect1.view().into()],
            &mut output_boxes1,
            &mut output_masks1,
        )
        .unwrap();

    // Built decoder, dequantized input.
    let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
    let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);

    let detect0 = dequantize_ndarray(detect0.view(), quant0);
    let detect1 = dequantize_ndarray(detect1.view(), quant1);
    decoder
        .decode_float::<f32>(
            &[detect0.view().into_dyn(), detect1.view().into_dyn()],
            &mut output_boxes1_f32,
            &mut output_masks1_f32,
        )
        .unwrap();

    compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
    compare_outputs(
        (&output_boxes, &output_boxes1_f32),
        (&[], &output_masks1_f32),
    );
}
1020
/// A decoder built from the YAML config must produce the expected first
/// box for the split ModelPack outputs.
#[test]
fn test_decoder_parse_config_modelpack_split_u8() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let detect0 = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split_9x15x18.bin"
    ));
    let detect0 = Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

    let detect1 = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split_17x30x18.bin"
    ));
    let detect1 = Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

    let yaml_config = include_str!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split.yaml"
    ))
    .to_string();

    let decoder = DecoderBuilder::default()
        .with_config_yaml_str(yaml_config)
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let mut output_boxes: Vec<_> = Vec::with_capacity(10);
    let mut output_masks: Vec<_> = Vec::with_capacity(10);
    // The 17x30 tensor is passed first here.
    let inputs = [
        ArrayViewDQuantized::from(detect1.view()),
        ArrayViewDQuantized::from(detect0.view()),
    ];
    decoder
        .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
        .unwrap();

    assert!(output_boxes[0].equal_within_delta(
        &DetectBox {
            bbox: BoundingBox {
                xmin: 0.43171933,
                ymin: 0.68243736,
                xmax: 0.5626645,
                ymax: 0.808863,
            },
            score: 0.99240804,
            label: 0
        },
        1e-6
    ));
}
1076
/// ModelPack segmentation: the quantized path must return the input tensor
/// with its axes rotated, and the float path must collapse to the same mask.
#[test]
fn test_modelpack_seg() {
    let out = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_seg_2x160x160.bin"
    ));
    let out = Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
    let quant = (1.0 / 255.0, 0).into();

    let seg_config = configs::Segmentation {
        decoder: DecoderType::ModelPack,
        quantization: Some(quant),
        shape: vec![1, 2, 160, 160],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumClasses, 2),
            (DimName::Height, 160),
            (DimName::Width, 160),
        ],
    };
    let decoder = DecoderBuilder::default()
        .with_config_modelpack_seg(seg_config)
        .build()
        .unwrap();

    let mut output_boxes: Vec<_> = Vec::with_capacity(10);
    let mut output_masks: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
        .unwrap();

    // Expected mask: drop the batch axis, then rotate the first axis to the end.
    let mut mask = out.slice(s![0, .., .., ..]);
    mask.swap_axes(0, 1);
    mask.swap_axes(1, 2);
    let mask = [Segmentation {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 1.0,
        ymax: 1.0,
        segmentation: mask.into_owned(),
    }];
    compare_outputs((&[], &output_boxes), (&mask, &output_masks));

    // Float path on the dequantized tensor, reusing the same output vectors.
    let out_float = dequantize_ndarray(out.view(), quant.into());
    decoder
        .decode_float::<f32>(
            &[out_float.view().into_dyn()],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    compare_outputs((&[], &output_boxes), (&[], &[]));
    let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
    let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

    assert_eq!(mask0, mask1);
}
/// The same segmentation data re-encoded as i8/u16/i16/u32/i32 must decode
/// to the same final mask as the original u8 tensor.
#[test]
fn test_modelpack_seg_quant() {
    let out = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_seg_2x160x160.bin"
    ));
    let out_u8 = Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();

    // Unsigned variants shift the byte into the top bits of the wider type;
    // signed variants additionally subtract half of the unsigned range.
    let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
    let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
    let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
    let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
    let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);

    let quant = (1.0 / 255.0, 0).into();

    let seg_config = configs::Segmentation {
        decoder: DecoderType::ModelPack,
        quantization: Some(quant),
        shape: vec![1, 2, 160, 160],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumClasses, 2),
            (DimName::Height, 160),
            (DimName::Width, 160),
        ],
    };
    let decoder = DecoderBuilder::default()
        .with_config_modelpack_seg(seg_config)
        .build()
        .unwrap();

    // One shared box vector; one mask vector per element type.
    let mut output_boxes: Vec<_> = Vec::with_capacity(10);

    let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(
            &[out_u8.view().into()],
            &mut output_boxes,
            &mut output_masks_u8,
        )
        .unwrap();

    let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(
            &[out_i8.view().into()],
            &mut output_boxes,
            &mut output_masks_i8,
        )
        .unwrap();

    let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(
            &[out_u16.view().into()],
            &mut output_boxes,
            &mut output_masks_u16,
        )
        .unwrap();

    let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(
            &[out_i16.view().into()],
            &mut output_boxes,
            &mut output_masks_i16,
        )
        .unwrap();

    let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(
            &[out_u32.view().into()],
            &mut output_boxes,
            &mut output_masks_u32,
        )
        .unwrap();

    let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(
            &[out_i32.view().into()],
            &mut output_boxes,
            &mut output_masks_i32,
        )
        .unwrap();

    compare_outputs((&[], &output_boxes), (&[], &[]));

    let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
    let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
    let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
    let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
    let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
    let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();

    // Every encoding must give the same mask as the u8 baseline.
    for other in [&mask_i8, &mask_u16, &mask_i16, &mask_u32, &mask_i32] {
        assert_eq!(&mask_u8, other);
    }
}
1236
/// Combined ModelPack detection + segmentation, checked on both the
/// quantized and the float decode paths.
#[test]
fn test_modelpack_segdet() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let boxes = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_boxes_1935x1x4.bin"
    ));
    let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();

    let scores = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_scores_1935x1.bin"
    ));
    let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();

    let seg = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_seg_2x160x160.bin"
    ));
    let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

    let quant_boxes = (0.004656755365431309, 21).into();
    let quant_scores = (0.0019603664986789227, 0).into();
    let quant_seg = (1.0 / 255.0, 0).into();

    let boxes_config = configs::Boxes {
        decoder: DecoderType::ModelPack,
        quantization: Some(quant_boxes),
        shape: vec![1, 1935, 1, 4],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumBoxes, 1935),
            (DimName::Padding, 1),
            (DimName::BoxCoords, 4),
        ],
        normalized: Some(true),
    };
    let scores_config = configs::Scores {
        decoder: DecoderType::ModelPack,
        quantization: Some(quant_scores),
        shape: vec![1, 1935, 1],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumBoxes, 1935),
            (DimName::NumClasses, 1),
        ],
    };
    let seg_config = configs::Segmentation {
        decoder: DecoderType::ModelPack,
        quantization: Some(quant_seg),
        shape: vec![1, 2, 160, 160],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumClasses, 2),
            (DimName::Height, 160),
            (DimName::Width, 160),
        ],
    };

    let decoder = DecoderBuilder::default()
        .with_config_modelpack_segdet(boxes_config, scores_config, seg_config)
        .with_iou_threshold(iou_threshold)
        .with_score_threshold(score_threshold)
        .build()
        .unwrap();

    let mut output_boxes: Vec<_> = Vec::with_capacity(10);
    let mut output_masks: Vec<_> = Vec::with_capacity(10);
    // Tensors are passed as scores, boxes, segmentation.
    decoder
        .decode_quantized(
            &[scores.view().into(), boxes.view().into(), seg.view().into()],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    // Expected mask: drop the batch axis, then rotate the first axis to the end.
    let mut mask = seg.slice(s![0, .., .., ..]);
    mask.swap_axes(0, 1);
    mask.swap_axes(1, 2);
    let mask = [Segmentation {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 1.0,
        ymax: 1.0,
        segmentation: mask.into_owned(),
    }];
    let correct_boxes = [DetectBox {
        bbox: BoundingBox {
            xmin: 0.40513772,
            ymin: 0.6379755,
            xmax: 0.5122431,
            ymax: 0.7730214,
        },
        score: 0.4861709,
        label: 0,
    }];
    compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));

    // Float path on the dequantized tensors, reusing the same output vectors.
    let scores = dequantize_ndarray(scores.view(), quant_scores.into());
    let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
    let seg = dequantize_ndarray(seg.view(), quant_seg.into());
    decoder
        .decode_float::<f32>(
            &[
                scores.view().into_dyn(),
                boxes.view().into_dyn(),
                seg.view().into_dyn(),
            ],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
    let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
    let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

    assert_eq!(mask0, mask1);
}
1362
/// Split ModelPack detection combined with segmentation, checked on both
/// the quantized and the float decode paths.
#[test]
fn test_modelpack_segdet_split() {
    let score_threshold = 0.8;
    let iou_threshold = 0.5;

    let seg = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_seg_2x160x160.bin"
    ));
    let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

    let detect0 = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split_9x15x18.bin"
    ));
    let detect0 = Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

    let detect1 = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split_17x30x18.bin"
    ));
    let detect1 = Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

    let quant0 = (0.08547406643629074, 174).into();
    let quant1 = (0.09929127991199493, 183).into();
    let quant_seg = (1.0 / 255.0, 0).into();

    let anchors0 = vec![
        [0.36666667461395264, 0.31481480598449707],
        [0.38749998807907104, 0.4740740656852722],
        [0.5333333611488342, 0.644444465637207],
    ];
    let anchors1 = vec![
        [0.13750000298023224, 0.2074074000120163],
        [0.2541666626930237, 0.21481481194496155],
        [0.23125000298023224, 0.35185185074806213],
    ];

    let detect_config1 = configs::Detection {
        decoder: DecoderType::ModelPack,
        shape: vec![1, 17, 30, 18],
        anchors: Some(anchors1),
        quantization: Some(quant1),
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 17),
            (DimName::Width, 30),
            (DimName::NumAnchorsXFeatures, 18),
        ],
        normalized: Some(true),
    };
    let detect_config0 = configs::Detection {
        decoder: DecoderType::ModelPack,
        shape: vec![1, 9, 15, 18],
        anchors: Some(anchors0),
        quantization: Some(quant0),
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 9),
            (DimName::Width, 15),
            (DimName::NumAnchorsXFeatures, 18),
        ],
        normalized: Some(true),
    };
    let seg_config = configs::Segmentation {
        decoder: DecoderType::ModelPack,
        quantization: Some(quant_seg),
        shape: vec![1, 2, 160, 160],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumClasses, 2),
            (DimName::Height, 160),
            (DimName::Width, 160),
        ],
    };

    // The 17x30 config is listed first; the tensors below are passed with
    // the 9x15 output first.
    let decoder = DecoderBuilder::default()
        .with_config_modelpack_segdet_split(vec![detect_config1, detect_config0], seg_config)
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let mut output_boxes: Vec<_> = Vec::with_capacity(10);
    let mut output_masks: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(
            &[
                detect0.view().into(),
                detect1.view().into(),
                seg.view().into(),
            ],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    // Expected mask: drop the batch axis, then rotate the first axis to the end.
    let mut mask = seg.slice(s![0, .., .., ..]);
    mask.swap_axes(0, 1);
    mask.swap_axes(1, 2);
    let mask = [Segmentation {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 1.0,
        ymax: 1.0,
        segmentation: mask.into_owned(),
    }];
    let correct_boxes = [DetectBox {
        bbox: BoundingBox {
            xmin: 0.43171933,
            ymin: 0.68243736,
            xmax: 0.5626645,
            ymax: 0.808863,
        },
        score: 0.99240804,
        label: 0,
    }];
    println!("Output Boxes: {:?}", output_boxes);
    compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));

    // Float path on the dequantized tensors, reusing the same output vectors.
    let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
    let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
    let seg = dequantize_ndarray(seg.view(), quant_seg.into());
    decoder
        .decode_float::<f32>(
            &[
                detect0.view().into_dyn(),
                detect1.view().into_dyn(),
                seg.view().into_dyn(),
            ],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
    let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
    let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

    assert_eq!(mask0, mask1);
}
1510
/// Checks that `dequantize_cpu_chunked` yields exactly the same values as
/// `dequantize_cpu` on a real model output, once with a non-zero zero-point
/// and once with a zero-point of 0.
#[test]
fn test_dequant_chunked() {
    let raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8s_80_classes.bin"
    ));
    // Interpret every byte as a signed 8-bit sample. `as i8` keeps the bit
    // pattern, so no `unsafe` pointer reinterpretation is needed.
    let mut quantized: Vec<i8> = raw.iter().map(|&byte| byte as i8).collect();
    // One extra sample on top of the 84 * 8400 tensor elements; the output
    // buffers below are sized to match.
    quantized.push(123);

    let mut expected = vec![0.0; 84 * 8400 + 1];
    let mut chunked = vec![0.0; 84 * 8400 + 1];

    // The same buffers are reused for both zero-points.
    for zero_point in [-123, 0] {
        let quant = Quantization::new(0.0040811873, zero_point);
        dequantize_cpu(&quantized, quant, &mut expected);
        dequantize_cpu_chunked(&quantized, quant, &mut chunked);
        assert_eq!(expected, chunked);
    }
}
1535
/// Decodes one quantized YOLO detection tensor through three paths — the
/// free `decode_yolo_det` function, the built decoder's `decode_quantized`,
/// and its `decode_float` on the dequantized tensor — and checks that they
/// agree with each other and with the expected first two boxes.
#[test]
fn test_decoder_yolo_det() {
    let score_threshold = 0.25;
    let iou_threshold = 0.7;
    let raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8s_80_classes.bin"
    ));
    // Interpret every byte as a signed 8-bit sample with a safe cast rather
    // than reinterpreting the slice through a raw pointer.
    let samples: Vec<i8> = raw.iter().map(|&byte| byte as i8).collect();
    let out = Array3::from_shape_vec((1, 84, 8400), samples).unwrap();
    let quant = (0.0040811873, -123).into();

    let decoder = DecoderBuilder::default()
        .with_config_yolo_det(
            configs::Detection {
                decoder: DecoderType::Ultralytics,
                shape: vec![1, 84, 8400],
                anchors: None,
                quantization: Some(quant),
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumFeatures, 84),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            Some(DecoderVersion::Yolo11),
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Reference result produced by the free function on the batch-0 slice.
    let mut output_boxes: Vec<_> = Vec::with_capacity(50);
    decode_yolo_det(
        (out.slice(s![0, .., ..]), quant.into()),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes,
    );

    let expected = [
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.5285137,
                ymin: 0.05305544,
                xmax: 0.87541467,
                ymax: 0.9998909,
            },
            score: 0.5591227,
            label: 0,
        },
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.130598,
                ymin: 0.43260583,
                xmax: 0.35098213,
                ymax: 0.9958097,
            },
            score: 0.33057618,
            label: 75,
        },
    ];
    // Fail with a readable message instead of an index-out-of-bounds panic.
    assert!(
        output_boxes.len() >= expected.len(),
        "expected at least {} boxes, got {}",
        expected.len(),
        output_boxes.len()
    );
    for (i, (found, wanted)) in output_boxes.iter().zip(&expected).enumerate() {
        assert!(
            found.equal_within_delta(wanted, 1e-6),
            "box {i} mismatch: got {found:?}, expected {wanted:?}"
        );
    }

    // Same tensor through the configured decoder, quantized path.
    let mut output_boxes1: Vec<_> = Vec::with_capacity(50);
    let mut output_masks1: Vec<_> = Vec::with_capacity(50);
    decoder
        .decode_quantized(&[out.view().into()], &mut output_boxes1, &mut output_masks1)
        .unwrap();

    // Same tensor through the configured decoder, float path.
    let out = dequantize_ndarray(out.view(), quant.into());
    let mut output_boxes_f32: Vec<_> = Vec::with_capacity(50);
    let mut output_masks_f32: Vec<_> = Vec::with_capacity(50);
    decoder
        .decode_float::<f32>(
            &[out.view().into_dyn()],
            &mut output_boxes_f32,
            &mut output_masks_f32,
        )
        .unwrap();

    // A detection-only model is compared against an empty mask list.
    compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
    compare_outputs((&output_boxes, &output_boxes_f32), (&[], &output_masks_f32));
}
1625
/// Runs `decode_yolo_segdet_float` on dequantized YOLOv8 box and proto
/// tensors and checks the two resulting detections, the relation between
/// box and mask extents, and the pixels of the second mask against a stored
/// reference mask.
#[test]
fn test_decoder_masks() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let boxes_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ));
    // Safe byte-to-i8 conversion; the bit pattern of every sample is kept.
    let boxes_i8: Vec<i8> = boxes_raw.iter().map(|&byte| byte as i8).collect();
    let boxes_i8 = ndarray::Array2::from_shape_vec((116, 8400), boxes_i8).unwrap();
    let quant_boxes = Quantization::new(0.021287761628627777, 31);

    let protos_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    ));
    let protos_i8: Vec<i8> = protos_raw.iter().map(|&byte| byte as i8).collect();
    let protos_i8 = ndarray::Array3::from_shape_vec((160, 160, 32), protos_i8).unwrap();
    let quant_protos = Quantization::new(0.02491161972284317, -117);

    let protos = dequantize_ndarray::<_, _, f32>(protos_i8.view(), quant_protos);
    let seg = dequantize_ndarray::<_, _, f32>(boxes_i8.view(), quant_boxes);

    let mut output_boxes: Vec<_> = Vec::with_capacity(10);
    let mut output_masks: Vec<_> = Vec::with_capacity(10);
    decode_yolo_segdet_float(
        seg.view(),
        protos.view(),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes,
        &mut output_masks,
    )
    .unwrap();
    assert_eq!(output_boxes.len(), 2);
    assert_eq!(output_boxes.len(), output_masks.len());

    // Every mask extent must be no greater than the matching box coordinate.
    for (b, m) in output_boxes.iter().zip(&output_masks) {
        assert!(b.bbox.xmin >= m.xmin, "box xmin is below mask xmin");
        assert!(b.bbox.ymin >= m.ymin, "box ymin is below mask ymin");
        assert!(b.bbox.xmax >= m.xmax, "box xmax is below mask xmax");
        assert!(b.bbox.ymax >= m.ymax, "box ymax is below mask ymax");
    }

    let expected = [
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.08515105,
                ymin: 0.7131401,
                xmax: 0.29802868,
                ymax: 0.8195788,
            },
            score: 0.91537374,
            label: 23,
        },
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.59605736,
                ymin: 0.25545314,
                xmax: 0.93666154,
                ymax: 0.72378385,
            },
            score: 0.91537374,
            label: 23,
        },
    ];
    // Tolerance of 1/160, i.e. one cell of the 160-wide proto grid.
    for (i, (found, wanted)) in output_boxes.iter().zip(&expected).enumerate() {
        assert!(
            found.equal_within_delta(wanted, 1.0 / 160.0),
            "box {i} mismatch: got {found:?}, expected {wanted:?}"
        );
    }

    let full_mask = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_mask_results.bin"
    ));
    let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();

    // Crop the stored 160x160 reference to the extents of the second mask.
    let m = &output_masks[1];
    let cropped_mask = full_mask.slice(ndarray::s![
        (m.ymin * 160.0) as usize..(m.ymax * 160.0) as usize,
        (m.xmin * 160.0) as usize..(m.xmax * 160.0) as usize,
    ]);

    assert_eq!(
        cropped_mask,
        segmentation_to_mask(m.segmentation.view()).unwrap()
    );
}
1713
/// Feeds the decoder built by `with_config_yolo_segdet` a proto tensor
/// permuted to `[1, 32, 160, 160]` (with empty `dshape` lists) and checks
/// that boxes and masks match the reference produced by
/// `decode_yolo_segdet_float` on the `160x160x32` layout.
#[test]
fn test_decoder_masks_nchw_protos() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let boxes_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ));
    // Safe byte-to-i8 conversion instead of a raw-pointer reinterpretation.
    let boxes_i8: Vec<i8> = boxes_raw.iter().map(|&byte| byte as i8).collect();
    let boxes_2d = ndarray::Array2::from_shape_vec((116, 8400), boxes_i8).unwrap();
    let quant_boxes = Quantization::new(0.021287761628627777, 31);

    let protos_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    ));
    let protos_i8: Vec<i8> = protos_raw.iter().map(|&byte| byte as i8).collect();
    let protos_hwc = ndarray::Array3::from_shape_vec((160, 160, 32), protos_i8).unwrap();
    let quant_protos = Quantization::new(0.02491161972284317, -117);
    let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);

    let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);

    // Reference: direct call with the protos in their stored layout.
    let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
    let mut ref_masks: Vec<_> = Vec::with_capacity(10);
    decode_yolo_segdet_float(
        seg.view(),
        protos_f32_hwc.view(),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut ref_boxes,
        &mut ref_masks,
    )
    .unwrap();
    assert_eq!(ref_boxes.len(), 2);

    // Move the proto axis to the front and add a leading batch axis to both
    // tensors, giving [1, 32, 160, 160] and [1, 116, 8400].
    let protos_f32_chw = protos_f32_hwc.permuted_axes([2, 0, 1]);
    let protos_nchw = protos_f32_chw.insert_axis(ndarray::Axis(0));
    let seg_3d = seg.insert_axis(ndarray::Axis(0));

    // Both `dshape` lists are left empty, so only `shape` describes the
    // tensors to the builder.
    let decoder = DecoderBuilder::default()
        .with_config_yolo_segdet(
            configs::Detection {
                decoder: configs::DecoderType::Ultralytics,
                quantization: None,
                shape: vec![1, 116, 8400],
                dshape: vec![],
                normalized: Some(true),
                anchors: None,
            },
            configs::Protos {
                decoder: configs::DecoderType::Ultralytics,
                quantization: None,
                shape: vec![1, 32, 160, 160],
                dshape: vec![],
            },
            None,
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let mut cfg_boxes: Vec<_> = Vec::with_capacity(10);
    let mut cfg_masks: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_float(
            &[seg_3d.view().into_dyn(), protos_nchw.view().into_dyn()],
            &mut cfg_boxes,
            &mut cfg_masks,
        )
        .unwrap();

    assert_eq!(
        cfg_boxes.len(),
        ref_boxes.len(),
        "config path produced {} boxes, reference produced {}",
        cfg_boxes.len(),
        ref_boxes.len()
    );

    for (i, (cb, rb)) in cfg_boxes.iter().zip(&ref_boxes).enumerate() {
        assert!(
            cb.equal_within_delta(rb, 0.01),
            "box {i} mismatch: config={cb:?}, reference={rb:?}"
        );
    }

    // Masks must agree pixel for pixel once converted by `segmentation_to_mask`.
    for (i, (cm, rm)) in cfg_masks.iter().zip(&ref_masks).enumerate() {
        let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
        let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
        assert_eq!(
            cm_arr, rm_arr,
            "mask {i} pixel mismatch between config-driven and reference paths"
        );
    }
}
1831
/// Decodes i8 YOLOv8 segmentation outputs four ways — the free quantized and
/// float functions, and the configured decoder's quantized and float entry
/// points — and checks that the results agree.
#[test]
fn test_decoder_masks_i8() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let boxes_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ));
    // Safe byte-to-i8 conversion instead of a raw-pointer reinterpretation.
    let boxes_i8: Vec<i8> = boxes_raw.iter().map(|&byte| byte as i8).collect();
    let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes_i8).unwrap();
    let quant_boxes = (0.021287761628627777, 31).into();

    let protos_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    ));
    let protos_i8: Vec<i8> = protos_raw.iter().map(|&byte| byte as i8).collect();
    let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos_i8).unwrap();
    let quant_protos = (0.02491161972284317, -117).into();

    let decoder = DecoderBuilder::default()
        .with_config_yolo_segdet(
            configs::Detection {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(quant_boxes),
                shape: vec![1, 116, 8400],
                anchors: None,
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumFeatures, 116),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            Protos {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(quant_protos),
                shape: vec![1, 160, 160, 32],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                    (DimName::NumProtos, 32),
                ],
            },
            Some(DecoderVersion::Yolo11),
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Convert the config-side quantization values into the form taken by
    // the free functions below.
    let quant_boxes = quant_boxes.into();
    let quant_protos = quant_protos.into();

    // 1. Free function, quantized input.
    let mut output_boxes: Vec<_> = Vec::with_capacity(500);
    let mut output_masks: Vec<_> = Vec::with_capacity(500);
    decode_yolo_segdet_quant(
        (boxes.slice(s![0, .., ..]), quant_boxes),
        (protos.slice(s![0, .., .., ..]), quant_protos),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes,
        &mut output_masks,
    )
    .unwrap();

    // 2. Configured decoder, quantized input.
    let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
    let mut output_masks1: Vec<_> = Vec::with_capacity(500);
    decoder
        .decode_quantized(
            &[boxes.view().into(), protos.view().into()],
            &mut output_boxes1,
            &mut output_masks1,
        )
        .unwrap();

    let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
    let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);

    // 3. Free function, float input.
    let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
    let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
    decode_yolo_segdet_float(
        seg.slice(s![0, .., ..]),
        protos.slice(s![0, .., .., ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes_f32,
        &mut output_masks_f32,
    )
    .unwrap();

    // 4. Configured decoder, float input.
    let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
    let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);
    decoder
        .decode_float(
            &[seg.view().into_dyn(), protos.view().into_dyn()],
            &mut output_boxes1_f32,
            &mut output_masks1_f32,
        )
        .unwrap();

    compare_outputs(
        (&output_boxes, &output_boxes1),
        (&output_masks, &output_masks1),
    );

    compare_outputs(
        (&output_boxes, &output_boxes_f32),
        (&output_masks, &output_masks_f32),
    );

    compare_outputs(
        (&output_boxes_f32, &output_boxes1_f32),
        (&output_masks_f32, &output_masks1_f32),
    );
}
1954
/// Exercises the split-output YOLO detection configuration (separate box
/// and score tensors) with i16 data, and checks the quantized decoder, the
/// float decoder and the free `decode_yolo_det_float` function against each
/// other.
#[test]
fn test_decoder_yolo_split() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;
    let raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ));
    // Read each byte as i8 (safe cast), then widen to i16 and scale by 256.
    // The quantization below divides the scale and multiplies the zero
    // point by the same factor.
    let boxes: Vec<i16> = raw
        .iter()
        .map(|&byte| i16::from(byte as i8) * 256)
        .collect();
    let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();

    let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);

    let decoder = DecoderBuilder::default()
        .with_config_yolo_split_det(
            configs::Boxes {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 4, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::BoxCoords, 4),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            configs::Scores {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 80, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 80),
                    (DimName::NumBoxes, 8400),
                ],
            },
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Quantized path: rows 0..4 go in as boxes, rows 4..84 as scores.
    let mut output_boxes: Vec<_> = Vec::with_capacity(500);
    let mut output_masks: Vec<_> = Vec::with_capacity(500);
    decoder
        .decode_quantized(
            &[
                boxes.slice(s![.., ..4, ..]).into(),
                boxes.slice(s![.., 4..84, ..]).into(),
            ],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    // Reference: free float function on the first 84 rows.
    let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
    let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
    decode_yolo_det_float(
        seg.slice(s![0, ..84, ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes_f32,
    );

    // Float path through the configured decoder.
    let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
    let mut output_masks1: Vec<_> = Vec::with_capacity(500);
    decoder
        .decode_float(
            &[
                seg.slice(s![.., ..4, ..]).into_dyn(),
                seg.slice(s![.., 4..84, ..]).into_dyn(),
            ],
            &mut output_boxes1,
            &mut output_masks1,
        )
        .unwrap();

    // Detection-only: the mask lists are compared against empty slices.
    compare_outputs((&output_boxes, &output_boxes_f32), (&output_masks, &[]));
    compare_outputs((&output_boxes_f32, &output_boxes1), (&[], &output_masks1));
}
2038
/// Exercises the split YOLO segmentation configuration with mixed input
/// types: i16 box/score/mask-coefficient tensors and an i8 proto tensor.
/// The quantized decoder, the float decoder and the free
/// `decode_yolo_segdet_float` function must agree.
#[test]
fn test_decoder_masks_config_mixed() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;
    let boxes_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ));
    // Read each byte as i8 (safe cast), then widen to i16 and scale by 256;
    // the quantization below is adjusted by the same factor.
    let boxes: Vec<i16> = boxes_raw
        .iter()
        .map(|&byte| i16::from(byte as i8) * 256)
        .collect();
    let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();

    let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);

    let protos_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    ));
    // The protos stay i8; the vector is built once and moved into the
    // array, with no intermediate copies.
    let protos: Vec<i8> = protos_raw.iter().map(|&byte| byte as i8).collect();
    let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
    let quant_protos = Quantization::new(0.02491161972284317, -117);

    let decoder = DecoderBuilder::default()
        .with_config_yolo_split_segdet(
            configs::Boxes {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 4, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::BoxCoords, 4),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            configs::Scores {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 80, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 80),
                    (DimName::NumBoxes, 8400),
                ],
            },
            configs::MaskCoefficients {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 32, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumProtos, 32),
                    (DimName::NumBoxes, 8400),
                ],
            },
            configs::Protos {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
                shape: vec![1, 160, 160, 32],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                    (DimName::NumProtos, 32),
                ],
            },
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Quantized path: rows 0..4 boxes, 4..84 scores, 84.. mask coefficients.
    let mut output_boxes: Vec<_> = Vec::with_capacity(500);
    let mut output_masks: Vec<_> = Vec::with_capacity(500);
    decoder
        .decode_quantized(
            &[
                boxes.slice(s![.., ..4, ..]).into(),
                boxes.slice(s![.., 4..84, ..]).into(),
                boxes.slice(s![.., 84.., ..]).into(),
                protos.view().into(),
            ],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    // Reference: free float function on the dequantized combined tensor.
    let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
    let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
    let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
    let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
    decode_yolo_segdet_float(
        seg.slice(s![0, .., ..]),
        protos.slice(s![0, .., .., ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes_f32,
        &mut output_masks_f32,
    )
    .unwrap();

    // Float path through the configured decoder.
    let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
    let mut output_masks1: Vec<_> = Vec::with_capacity(500);
    decoder
        .decode_float(
            &[
                seg.slice(s![.., ..4, ..]).into_dyn(),
                seg.slice(s![.., 4..84, ..]).into_dyn(),
                seg.slice(s![.., 84.., ..]).into_dyn(),
                protos.view().into_dyn(),
            ],
            &mut output_boxes1,
            &mut output_masks1,
        )
        .unwrap();

    compare_outputs(
        (&output_boxes, &output_boxes_f32),
        (&output_masks, &output_masks_f32),
    );
    compare_outputs(
        (&output_boxes_f32, &output_boxes1),
        (&output_masks_f32, &output_masks1),
    );
}
2168
/// Exercises the split YOLO segmentation configuration with i32 tensors:
/// every i8 sample is multiplied by 2^23 and the quantization parameters
/// are adjusted by the same factor. The quantized decoder must match the
/// free `decode_yolo_segdet_float` function on the dequantized data.
#[test]
fn test_decoder_masks_config_i32() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;
    // Widening factor applied to both the samples and the zero points.
    let scale: i32 = 1 << 23;

    let boxes_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ));
    // Read each byte as i8 (safe cast), then widen to i32 and scale.
    let boxes: Vec<i32> = boxes_raw
        .iter()
        .map(|&byte| i32::from(byte as i8) * scale)
        .collect();
    let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();

    let quant_boxes = Quantization::new(0.021287761628627777 / scale as f32, 31 * scale);

    let protos_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    ));
    let protos: Vec<i32> = protos_raw
        .iter()
        .map(|&byte| i32::from(byte as i8) * scale)
        .collect();
    // The vector is moved into the array; no extra copy is made.
    let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
    let quant_protos = Quantization::new(0.02491161972284317 / scale as f32, -117 * scale);

    let decoder = DecoderBuilder::default()
        .with_config_yolo_split_segdet(
            configs::Boxes {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 4, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::BoxCoords, 4),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            configs::Scores {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 80, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 80),
                    (DimName::NumBoxes, 8400),
                ],
            },
            configs::MaskCoefficients {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 32, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumProtos, 32),
                    (DimName::NumBoxes, 8400),
                ],
            },
            configs::Protos {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
                shape: vec![1, 160, 160, 32],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                    (DimName::NumProtos, 32),
                ],
            },
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Quantized path: rows 0..4 boxes, 4..84 scores, 84.. mask coefficients.
    let mut output_boxes: Vec<_> = Vec::with_capacity(500);
    let mut output_masks: Vec<_> = Vec::with_capacity(500);
    decoder
        .decode_quantized(
            &[
                boxes.slice(s![.., ..4, ..]).into(),
                boxes.slice(s![.., 4..84, ..]).into(),
                boxes.slice(s![.., 84.., ..]).into(),
                protos.view().into(),
            ],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    // Reference: free float function on the dequantized combined tensor.
    let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
    let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
    let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
    let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
    decode_yolo_segdet_float(
        seg.slice(s![0, .., ..]),
        protos.slice(s![0, .., .., ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes_f32,
        &mut output_masks_f32,
    )
    .unwrap();

    // Check the counts first so a size mismatch is reported on its own.
    assert_eq!(output_boxes.len(), output_boxes_f32.len());
    assert_eq!(output_masks.len(), output_masks_f32.len());

    compare_outputs(
        (&output_boxes, &output_boxes_f32),
        (&output_masks, &output_masks_f32),
    );
}
2283
/// Runs two different decoder configurations concurrently: four threads
/// decode a YOLO detection output and four threads decode a ModelPack split
/// detection + segmentation output, each 100 times, re-checking the results
/// on every iteration.
#[test]
fn test_context_switch() {
    // YOLO detection workload. Captures nothing, so the closure is `Copy`
    // and can be handed to several threads.
    let yolo_det = || {
        let score_threshold = 0.25;
        let iou_threshold = 0.7;
        let raw = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/yolov8s_80_classes.bin"
        ));
        // Safe byte-to-i8 conversion instead of a raw-pointer reinterpretation.
        let samples: Vec<i8> = raw.iter().map(|&byte| byte as i8).collect();
        let out = Array3::from_shape_vec((1, 84, 8400), samples).unwrap();
        let quant = (0.0040811873, -123).into();

        let decoder = DecoderBuilder::default()
            .with_config_yolo_det(
                configs::Detection {
                    decoder: DecoderType::Ultralytics,
                    shape: vec![1, 84, 8400],
                    anchors: None,
                    quantization: Some(quant),
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumFeatures, 84),
                        (DimName::NumBoxes, 8400),
                    ],
                    normalized: None,
                },
                None,
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        let expected = [
            DetectBox {
                bbox: BoundingBox {
                    xmin: 0.5285137,
                    ymin: 0.05305544,
                    xmax: 0.87541467,
                    ymax: 0.9998909,
                },
                score: 0.5591227,
                label: 0,
            },
            DetectBox {
                bbox: BoundingBox {
                    xmin: 0.130598,
                    ymin: 0.43260583,
                    xmax: 0.35098213,
                    ymax: 0.9958097,
                },
                score: 0.33057618,
                label: 75,
            },
        ];

        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
        let mut output_masks: Vec<_> = Vec::with_capacity(50);

        for _ in 0..100 {
            decoder
                .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
                .unwrap();

            // Fail with a readable message instead of an index panic.
            assert!(
                output_boxes.len() >= expected.len(),
                "expected at least {} boxes, got {}",
                expected.len(),
                output_boxes.len()
            );
            for (i, (found, wanted)) in output_boxes.iter().zip(&expected).enumerate() {
                assert!(
                    found.equal_within_delta(wanted, 1e-6),
                    "box {i} mismatch: got {found:?}, expected {wanted:?}"
                );
            }
            assert!(output_masks.is_empty());
        }
    };

    // ModelPack split detection + segmentation workload; also capture-free.
    let modelpack_det_split = || {
        let score_threshold = 0.8;
        let iou_threshold = 0.5;

        let seg = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

        let detect0 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_9x15x18.bin"
        ));
        let detect0 =
            ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

        let detect1 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_17x30x18.bin"
        ));
        let detect1 =
            ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

        // The expected mask is the raw segmentation tensor with the class
        // axis moved last: (2, 160, 160) -> (160, 2, 160) -> (160, 160, 2).
        let mut mask = seg.slice(s![0, .., .., ..]);
        mask.swap_axes(0, 1);
        mask.swap_axes(1, 2);
        let mask = [Segmentation {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.0,
            ymax: 1.0,
            segmentation: mask.into_owned(),
        }];
        let correct_boxes = [DetectBox {
            bbox: BoundingBox {
                xmin: 0.43171933,
                ymin: 0.68243736,
                xmax: 0.5626645,
                ymax: 0.808863,
            },
            score: 0.99240804,
            label: 0,
        }];

        let quant0 = (0.08547406643629074, 174).into();
        let quant1 = (0.09929127991199493, 183).into();
        let quant_seg = (1.0 / 255.0, 0).into();

        let anchors0 = vec![
            [0.36666667461395264, 0.31481480598449707],
            [0.38749998807907104, 0.4740740656852722],
            [0.5333333611488342, 0.644444465637207],
        ];
        let anchors1 = vec![
            [0.13750000298023224, 0.2074074000120163],
            [0.2541666626930237, 0.21481481194496155],
            [0.23125000298023224, 0.35185185074806213],
        ];

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_segdet_split(
                vec![
                    configs::Detection {
                        decoder: DecoderType::ModelPack,
                        shape: vec![1, 17, 30, 18],
                        anchors: Some(anchors1),
                        quantization: Some(quant1),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::Height, 17),
                            (DimName::Width, 30),
                            (DimName::NumAnchorsXFeatures, 18),
                        ],
                        normalized: None,
                    },
                    configs::Detection {
                        decoder: DecoderType::ModelPack,
                        shape: vec![1, 9, 15, 18],
                        anchors: Some(anchors0),
                        quantization: Some(quant0),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::Height, 9),
                            (DimName::Width, 15),
                            (DimName::NumAnchorsXFeatures, 18),
                        ],
                        normalized: None,
                    },
                ],
                configs::Segmentation {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_seg),
                    shape: vec![1, 2, 160, 160],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumClasses, 2),
                        (DimName::Height, 160),
                        (DimName::Width, 160),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);

        for _ in 0..100 {
            decoder
                .decode_quantized(
                    &[
                        detect0.view().into(),
                        detect1.view().into(),
                        seg.view().into(),
                    ],
                    &mut output_boxes,
                    &mut output_masks,
                )
                .unwrap();

            compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
        }
    };

    // Spawn the two workloads alternately, four threads of each.
    let mut handles = Vec::with_capacity(8);
    for _ in 0..4 {
        handles.push(std::thread::spawn(yolo_det));
        handles.push(std::thread::spawn(modelpack_det_split));
    }
    // A panic in any worker surfaces here through `unwrap`.
    for handle in handles {
        handle.join().unwrap();
    }
}
2498
/// Checks `ndarray_to_xyxy_float` for both box encodings on the same input:
/// `XYWH` turns `(10, 20, 20, 20)` into the corners `(0, 10, 20, 30)`,
/// while `XYXY` passes the four values through unchanged.
#[test]
fn test_ndarray_to_xyxy_float() {
    let input = array![10.0_f32, 20.0, 20.0, 20.0];

    let from_xywh: [f32; 4] = XYWH::ndarray_to_xyxy_float(input.view());
    assert_eq!(from_xywh, [0.0_f32, 10.0, 20.0, 30.0]);

    let from_xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(input.view());
    assert_eq!(from_xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
}
2509
/// Class-aware float NMS: two overlapping boxes survive when their labels
/// differ, and only the higher-scoring one survives when the labels match.
#[test]
fn test_class_aware_nms_float() {
    use crate::float::nms_class_aware_float;

    // Two boxes with IoU = 0.16 / 0.34 (about 0.47), above the 0.3 threshold
    // used below. Only the label of the second box varies between cases.
    let overlapping_pair = |second_label| {
        vec![
            DetectBox {
                bbox: BoundingBox {
                    xmin: 0.0,
                    ymin: 0.0,
                    xmax: 0.5,
                    ymax: 0.5,
                },
                score: 0.9,
                label: 0,
            },
            DetectBox {
                bbox: BoundingBox {
                    xmin: 0.1,
                    ymin: 0.1,
                    xmax: 0.6,
                    ymax: 0.6,
                },
                score: 0.8,
                label: second_label,
            },
        ]
    };

    // Different classes: neither box suppresses the other.
    let kept = nms_class_aware_float(0.3, overlapping_pair(1));
    assert_eq!(
        kept.len(),
        2,
        "Class-aware NMS should keep both boxes with different classes"
    );

    // Same class: the lower-scoring box is removed.
    let kept = nms_class_aware_float(0.3, overlapping_pair(0));
    assert_eq!(
        kept.len(),
        1,
        "Class-aware NMS should suppress overlapping box with same class"
    );
    assert_eq!(kept[0].label, 0);
    assert!((kept[0].score - 0.9).abs() < 1e-6);
}
2580
2581 #[test]
2582 fn test_class_agnostic_vs_aware_nms() {
2583 use crate::float::{nms_class_aware_float, nms_float};
2584
2585 let boxes = vec![
2587 DetectBox {
2588 bbox: BoundingBox {
2589 xmin: 0.0,
2590 ymin: 0.0,
2591 xmax: 0.5,
2592 ymax: 0.5,
2593 },
2594 score: 0.9,
2595 label: 0,
2596 },
2597 DetectBox {
2598 bbox: BoundingBox {
2599 xmin: 0.1,
2600 ymin: 0.1,
2601 xmax: 0.6,
2602 ymax: 0.6,
2603 },
2604 score: 0.8,
2605 label: 1,
2606 },
2607 ];
2608
2609 let agnostic_result = nms_float(0.3, boxes.clone());
2611 assert_eq!(
2612 agnostic_result.len(),
2613 1,
2614 "Class-agnostic NMS should suppress overlapping boxes"
2615 );
2616
2617 let aware_result = nms_class_aware_float(0.3, boxes);
2619 assert_eq!(
2620 aware_result.len(),
2621 2,
2622 "Class-aware NMS should keep boxes with different classes"
2623 );
2624 }
2625
2626 #[test]
2627 fn test_class_aware_nms_int() {
2628 use crate::byte::nms_class_aware_int;
2629
2630 let boxes = vec![
2632 DetectBoxQuantized {
2633 bbox: BoundingBox {
2634 xmin: 0.0,
2635 ymin: 0.0,
2636 xmax: 0.5,
2637 ymax: 0.5,
2638 },
2639 score: 200_u8,
2640 label: 0,
2641 },
2642 DetectBoxQuantized {
2643 bbox: BoundingBox {
2644 xmin: 0.1,
2645 ymin: 0.1,
2646 xmax: 0.6,
2647 ymax: 0.6,
2648 },
2649 score: 180_u8,
2650 label: 1, },
2652 ];
2653
2654 let result = nms_class_aware_int(0.5, boxes);
2656 assert_eq!(
2657 result.len(),
2658 2,
2659 "Class-aware NMS (int) should keep boxes with different classes"
2660 );
2661 }
2662
2663 #[test]
2664 fn test_nms_enum_default() {
2665 let default_nms: configs::Nms = Default::default();
2667 assert_eq!(default_nms, configs::Nms::ClassAgnostic);
2668 }
2669
2670 #[test]
2671 fn test_decoder_nms_mode() {
2672 let decoder = DecoderBuilder::default()
2674 .with_config_yolo_det(
2675 configs::Detection {
2676 anchors: None,
2677 decoder: DecoderType::Ultralytics,
2678 quantization: None,
2679 shape: vec![1, 84, 8400],
2680 dshape: Vec::new(),
2681 normalized: Some(true),
2682 },
2683 None,
2684 )
2685 .with_nms(Some(configs::Nms::ClassAware))
2686 .build()
2687 .unwrap();
2688
2689 assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
2690 }
2691
2692 #[test]
2693 fn test_decoder_nms_bypass() {
2694 let decoder = DecoderBuilder::default()
2696 .with_config_yolo_det(
2697 configs::Detection {
2698 anchors: None,
2699 decoder: DecoderType::Ultralytics,
2700 quantization: None,
2701 shape: vec![1, 84, 8400],
2702 dshape: Vec::new(),
2703 normalized: Some(true),
2704 },
2705 None,
2706 )
2707 .with_nms(None)
2708 .build()
2709 .unwrap();
2710
2711 assert_eq!(decoder.nms, None);
2712 }
2713
2714 #[test]
2715 fn test_decoder_normalized_boxes_true() {
2716 let decoder = DecoderBuilder::default()
2718 .with_config_yolo_det(
2719 configs::Detection {
2720 anchors: None,
2721 decoder: DecoderType::Ultralytics,
2722 quantization: None,
2723 shape: vec![1, 84, 8400],
2724 dshape: Vec::new(),
2725 normalized: Some(true),
2726 },
2727 None,
2728 )
2729 .build()
2730 .unwrap();
2731
2732 assert_eq!(decoder.normalized_boxes(), Some(true));
2733 }
2734
2735 #[test]
2736 fn test_decoder_normalized_boxes_false() {
2737 let decoder = DecoderBuilder::default()
2740 .with_config_yolo_det(
2741 configs::Detection {
2742 anchors: None,
2743 decoder: DecoderType::Ultralytics,
2744 quantization: None,
2745 shape: vec![1, 84, 8400],
2746 dshape: Vec::new(),
2747 normalized: Some(false),
2748 },
2749 None,
2750 )
2751 .build()
2752 .unwrap();
2753
2754 assert_eq!(decoder.normalized_boxes(), Some(false));
2755 }
2756
2757 #[test]
2758 fn test_decoder_normalized_boxes_unknown() {
2759 let decoder = DecoderBuilder::default()
2761 .with_config_yolo_det(
2762 configs::Detection {
2763 anchors: None,
2764 decoder: DecoderType::Ultralytics,
2765 quantization: None,
2766 shape: vec![1, 84, 8400],
2767 dshape: Vec::new(),
2768 normalized: None,
2769 },
2770 Some(DecoderVersion::Yolo11),
2771 )
2772 .build()
2773 .unwrap();
2774
2775 assert_eq!(decoder.normalized_boxes(), None);
2776 }
2777}
2778
2779#[cfg(feature = "tracker")]
2780#[cfg(test)]
2781#[cfg_attr(coverage_nightly, coverage(off))]
2782mod decoder_tracked_tests {
2783
2784 use edgefirst_tracker::{ByteTrackBuilder, Tracker};
2785 use ndarray::{array, s, Array, Array2, Array3, Array4, ArrayView, Axis, Dimension};
2786 use num_traits::{AsPrimitive, Float, PrimInt};
2787 use rand::{RngExt, SeedableRng};
2788 use rand_distr::StandardNormal;
2789
2790 use crate::{
2791 configs::{self, DimName},
2792 dequantize_ndarray, BoundingBox, DecoderBuilder, DetectBox, Quantization,
2793 };
2794
2795 pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
2796 input: ArrayView<F, D>,
2797 quant: Quantization,
2798 ) -> Array<T, D>
2799 where
2800 i32: num_traits::AsPrimitive<F>,
2801 f32: num_traits::AsPrimitive<F>,
2802 {
2803 let zero_point = quant.zero_point.as_();
2804 let div_scale = F::one() / quant.scale.as_();
2805 if zero_point != F::zero() {
2806 input.mapv(|d| (d * div_scale + zero_point).round().as_())
2807 } else {
2808 input.mapv(|d| (d * div_scale).round().as_())
2809 }
2810 }
2811
2812 #[test]
2813 fn test_decoder_tracked_random_jitter() {
2814 use crate::configs::{DecoderType, Nms};
2815 use crate::DecoderBuilder;
2816
2817 let score_threshold = 0.25;
2818 let iou_threshold = 0.1;
2819 let out = include_bytes!(concat!(
2820 env!("CARGO_MANIFEST_DIR"),
2821 "/../../testdata/yolov8s_80_classes.bin"
2822 ));
2823 let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
2824 let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
2825 let quant = (0.0040811873, -123).into();
2826
2827 let decoder = DecoderBuilder::default()
2828 .with_config_yolo_det(
2829 crate::configs::Detection {
2830 decoder: DecoderType::Ultralytics,
2831 shape: vec![1, 84, 8400],
2832 anchors: None,
2833 quantization: Some(quant),
2834 dshape: vec![
2835 (crate::configs::DimName::Batch, 1),
2836 (crate::configs::DimName::NumFeatures, 84),
2837 (crate::configs::DimName::NumBoxes, 8400),
2838 ],
2839 normalized: Some(true),
2840 },
2841 None,
2842 )
2843 .with_score_threshold(score_threshold)
2844 .with_iou_threshold(iou_threshold)
2845 .with_nms(Some(Nms::ClassAgnostic))
2846 .build()
2847 .unwrap();
2848 let mut rng = rand::rngs::StdRng::seed_from_u64(0xAB_BEEF); let expected_boxes = [
2851 crate::DetectBox {
2852 bbox: crate::BoundingBox {
2853 xmin: 0.5285137,
2854 ymin: 0.05305544,
2855 xmax: 0.87541467,
2856 ymax: 0.9998909,
2857 },
2858 score: 0.5591227,
2859 label: 0,
2860 },
2861 crate::DetectBox {
2862 bbox: crate::BoundingBox {
2863 xmin: 0.130598,
2864 ymin: 0.43260583,
2865 xmax: 0.35098213,
2866 ymax: 0.9958097,
2867 },
2868 score: 0.33057618,
2869 label: 75,
2870 },
2871 ];
2872
2873 let mut tracker = ByteTrackBuilder::new()
2874 .track_update(0.1)
2875 .track_high_conf(0.3)
2876 .build();
2877
2878 let mut output_boxes = Vec::with_capacity(50);
2879 let mut output_masks = Vec::with_capacity(50);
2880 let mut output_tracks = Vec::with_capacity(50);
2881
2882 decoder
2883 .decode_tracked_quantized(
2884 &mut tracker,
2885 0,
2886 &[out.view().into()],
2887 &mut output_boxes,
2888 &mut output_masks,
2889 &mut output_tracks,
2890 )
2891 .unwrap();
2892
2893 assert_eq!(output_boxes.len(), 2);
2894 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
2895 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
2896
2897 let mut last_boxes = output_boxes.clone();
2898
2899 for i in 1..=100 {
2900 let mut out = out.clone();
2901 let mut x_values = out.slice_mut(s![0, 0, ..]);
2903 for x in x_values.iter_mut() {
2904 let r: f32 = rng.sample(StandardNormal);
2905 let r = r.clamp(-2.0, 2.0) / 2.0;
2906 *x = x.saturating_add((r * 1e-2 / quant.0) as i8);
2907 }
2908
2909 let mut y_values = out.slice_mut(s![0, 1, ..]);
2910 for y in y_values.iter_mut() {
2911 let r: f32 = rng.sample(StandardNormal);
2912 let r = r.clamp(-2.0, 2.0) / 2.0;
2913 *y = y.saturating_add((r * 1e-2 / quant.0) as i8);
2914 }
2915
2916 decoder
2917 .decode_tracked_quantized(
2918 &mut tracker,
2919 100_000_000 * i / 3, &[out.view().into()],
2921 &mut output_boxes,
2922 &mut output_masks,
2923 &mut output_tracks,
2924 )
2925 .unwrap();
2926
2927 assert_eq!(output_boxes.len(), 2);
2928 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 5e-3));
2929 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 5e-3));
2930
2931 assert!(output_boxes[0].equal_within_delta(&last_boxes[0], 1e-3));
2932 assert!(output_boxes[1].equal_within_delta(&last_boxes[1], 1e-3));
2933 last_boxes = output_boxes.clone();
2934 }
2935 }
2936
2937 #[test]
2938 fn test_decoder_tracked_segdet() {
2939 use crate::configs::Nms;
2940 use crate::DecoderBuilder;
2941
2942 let score_threshold = 0.45;
2943 let iou_threshold = 0.45;
2944 let boxes = include_bytes!(concat!(
2945 env!("CARGO_MANIFEST_DIR"),
2946 "/../../testdata/yolov8_boxes_116x8400.bin"
2947 ));
2948 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
2949 let mut boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
2950
2951 let protos = include_bytes!(concat!(
2952 env!("CARGO_MANIFEST_DIR"),
2953 "/../../testdata/yolov8_protos_160x160x32.bin"
2954 ));
2955 let protos =
2956 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
2957 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
2958
2959 let config = include_str!(concat!(
2960 env!("CARGO_MANIFEST_DIR"),
2961 "/../../testdata/yolov8_seg.yaml"
2962 ));
2963
2964 let decoder = DecoderBuilder::default()
2965 .with_config_yaml_str(config.to_string())
2966 .with_score_threshold(score_threshold)
2967 .with_iou_threshold(iou_threshold)
2968 .with_nms(Some(Nms::ClassAgnostic))
2969 .build()
2970 .unwrap();
2971
2972 let expected_boxes = [
2973 DetectBox {
2974 bbox: BoundingBox {
2975 xmin: 0.08515105,
2976 ymin: 0.7131401,
2977 xmax: 0.29802868,
2978 ymax: 0.8195788,
2979 },
2980 score: 0.91537374,
2981 label: 23,
2982 },
2983 DetectBox {
2984 bbox: BoundingBox {
2985 xmin: 0.59605736,
2986 ymin: 0.25545314,
2987 xmax: 0.93666154,
2988 ymax: 0.72378385,
2989 },
2990 score: 0.91537374,
2991 label: 23,
2992 },
2993 ];
2994
2995 let mut tracker = ByteTrackBuilder::new()
2996 .track_update(0.1)
2997 .track_high_conf(0.7)
2998 .build();
2999
3000 let mut output_boxes = Vec::with_capacity(50);
3001 let mut output_masks = Vec::with_capacity(50);
3002 let mut output_tracks = Vec::with_capacity(50);
3003
3004 decoder
3005 .decode_tracked_quantized(
3006 &mut tracker,
3007 0,
3008 &[boxes.view().into(), protos.view().into()],
3009 &mut output_boxes,
3010 &mut output_masks,
3011 &mut output_tracks,
3012 )
3013 .unwrap();
3014
3015 assert_eq!(output_boxes.len(), 2);
3016 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3017 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3018
3019 let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3021 for score in scores_values.iter_mut() {
3022 *score = i8::MIN; }
3024 decoder
3025 .decode_tracked_quantized(
3026 &mut tracker,
3027 100_000_000 / 3,
3028 &[boxes.view().into(), protos.view().into()],
3029 &mut output_boxes,
3030 &mut output_masks,
3031 &mut output_tracks,
3032 )
3033 .unwrap();
3034
3035 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3036 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3037
3038 assert!(output_masks.is_empty())
3040 }
3041
3042 #[test]
3043 fn test_decoder_tracked_segdet_float() {
3044 use crate::configs::Nms;
3045 use crate::DecoderBuilder;
3046
3047 let score_threshold = 0.45;
3048 let iou_threshold = 0.45;
3049 let boxes = include_bytes!(concat!(
3050 env!("CARGO_MANIFEST_DIR"),
3051 "/../../testdata/yolov8_boxes_116x8400.bin"
3052 ));
3053 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3054 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
3055 let quant_boxes = (0.021287762, 31);
3056 let mut boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
3057
3058 let protos = include_bytes!(concat!(
3059 env!("CARGO_MANIFEST_DIR"),
3060 "/../../testdata/yolov8_protos_160x160x32.bin"
3061 ));
3062 let protos =
3063 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3064 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
3065 let quant_protos = (0.02491162, -117);
3066 let protos = dequantize_ndarray(protos.view(), quant_protos.into());
3067
3068 let config = include_str!(concat!(
3069 env!("CARGO_MANIFEST_DIR"),
3070 "/../../testdata/yolov8_seg.yaml"
3071 ));
3072
3073 let decoder = DecoderBuilder::default()
3074 .with_config_yaml_str(config.to_string())
3075 .with_score_threshold(score_threshold)
3076 .with_iou_threshold(iou_threshold)
3077 .with_nms(Some(Nms::ClassAgnostic))
3078 .build()
3079 .unwrap();
3080
3081 let expected_boxes = [
3082 DetectBox {
3083 bbox: BoundingBox {
3084 xmin: 0.08515105,
3085 ymin: 0.7131401,
3086 xmax: 0.29802868,
3087 ymax: 0.8195788,
3088 },
3089 score: 0.91537374,
3090 label: 23,
3091 },
3092 DetectBox {
3093 bbox: BoundingBox {
3094 xmin: 0.59605736,
3095 ymin: 0.25545314,
3096 xmax: 0.93666154,
3097 ymax: 0.72378385,
3098 },
3099 score: 0.91537374,
3100 label: 23,
3101 },
3102 ];
3103
3104 let mut tracker = ByteTrackBuilder::new()
3105 .track_update(0.1)
3106 .track_high_conf(0.7)
3107 .build();
3108
3109 let mut output_boxes = Vec::with_capacity(50);
3110 let mut output_masks = Vec::with_capacity(50);
3111 let mut output_tracks = Vec::with_capacity(50);
3112
3113 decoder
3114 .decode_tracked_float(
3115 &mut tracker,
3116 0,
3117 &[boxes.view().into_dyn(), protos.view().into_dyn()],
3118 &mut output_boxes,
3119 &mut output_masks,
3120 &mut output_tracks,
3121 )
3122 .unwrap();
3123
3124 assert_eq!(output_boxes.len(), 2);
3125 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3126 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3127
3128 let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3130 for score in scores_values.iter_mut() {
3131 *score = 0.0; }
3133 decoder
3134 .decode_tracked_float(
3135 &mut tracker,
3136 100_000_000 / 3,
3137 &[boxes.view().into_dyn(), protos.view().into_dyn()],
3138 &mut output_boxes,
3139 &mut output_masks,
3140 &mut output_tracks,
3141 )
3142 .unwrap();
3143
3144 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3145 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3146
3147 assert!(output_masks.is_empty())
3149 }
3150
3151 #[test]
3152 fn test_decoder_tracked_segdet_proto() {
3153 use crate::configs::Nms;
3154 use crate::DecoderBuilder;
3155
3156 let score_threshold = 0.45;
3157 let iou_threshold = 0.45;
3158 let boxes = include_bytes!(concat!(
3159 env!("CARGO_MANIFEST_DIR"),
3160 "/../../testdata/yolov8_boxes_116x8400.bin"
3161 ));
3162 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3163 let mut boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
3164
3165 let protos = include_bytes!(concat!(
3166 env!("CARGO_MANIFEST_DIR"),
3167 "/../../testdata/yolov8_protos_160x160x32.bin"
3168 ));
3169 let protos =
3170 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3171 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
3172
3173 let config = include_str!(concat!(
3174 env!("CARGO_MANIFEST_DIR"),
3175 "/../../testdata/yolov8_seg.yaml"
3176 ));
3177
3178 let decoder = DecoderBuilder::default()
3179 .with_config_yaml_str(config.to_string())
3180 .with_score_threshold(score_threshold)
3181 .with_iou_threshold(iou_threshold)
3182 .with_nms(Some(Nms::ClassAgnostic))
3183 .build()
3184 .unwrap();
3185
3186 let expected_boxes = [
3187 DetectBox {
3188 bbox: BoundingBox {
3189 xmin: 0.08515105,
3190 ymin: 0.7131401,
3191 xmax: 0.29802868,
3192 ymax: 0.8195788,
3193 },
3194 score: 0.91537374,
3195 label: 23,
3196 },
3197 DetectBox {
3198 bbox: BoundingBox {
3199 xmin: 0.59605736,
3200 ymin: 0.25545314,
3201 xmax: 0.93666154,
3202 ymax: 0.72378385,
3203 },
3204 score: 0.91537374,
3205 label: 23,
3206 },
3207 ];
3208
3209 let mut tracker = ByteTrackBuilder::new()
3210 .track_update(0.1)
3211 .track_high_conf(0.7)
3212 .build();
3213
3214 let mut output_boxes = Vec::with_capacity(50);
3215 let mut output_tracks = Vec::with_capacity(50);
3216
3217 decoder
3218 .decode_tracked_quantized_proto(
3219 &mut tracker,
3220 0,
3221 &[boxes.view().into(), protos.view().into()],
3222 &mut output_boxes,
3223 &mut output_tracks,
3224 )
3225 .unwrap();
3226
3227 assert_eq!(output_boxes.len(), 2);
3228 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3229 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3230
3231 let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3233 for score in scores_values.iter_mut() {
3234 *score = i8::MIN; }
3236 let protos = decoder
3237 .decode_tracked_quantized_proto(
3238 &mut tracker,
3239 100_000_000 / 3,
3240 &[boxes.view().into(), protos.view().into()],
3241 &mut output_boxes,
3242 &mut output_tracks,
3243 )
3244 .unwrap();
3245
3246 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3247 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3248
3249 assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
3251 }
3252
3253 #[test]
3254 fn test_decoder_tracked_segdet_proto_float() {
3255 use crate::configs::Nms;
3256 use crate::DecoderBuilder;
3257
3258 let score_threshold = 0.45;
3259 let iou_threshold = 0.45;
3260 let boxes = include_bytes!(concat!(
3261 env!("CARGO_MANIFEST_DIR"),
3262 "/../../testdata/yolov8_boxes_116x8400.bin"
3263 ));
3264 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3265 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
3266 let quant_boxes = (0.021287762, 31);
3267 let mut boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
3268
3269 let protos = include_bytes!(concat!(
3270 env!("CARGO_MANIFEST_DIR"),
3271 "/../../testdata/yolov8_protos_160x160x32.bin"
3272 ));
3273 let protos =
3274 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3275 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
3276 let quant_protos = (0.02491162, -117);
3277 let protos = dequantize_ndarray(protos.view(), quant_protos.into());
3278
3279 let config = include_str!(concat!(
3280 env!("CARGO_MANIFEST_DIR"),
3281 "/../../testdata/yolov8_seg.yaml"
3282 ));
3283
3284 let decoder = DecoderBuilder::default()
3285 .with_config_yaml_str(config.to_string())
3286 .with_score_threshold(score_threshold)
3287 .with_iou_threshold(iou_threshold)
3288 .with_nms(Some(Nms::ClassAgnostic))
3289 .build()
3290 .unwrap();
3291
3292 let expected_boxes = [
3293 DetectBox {
3294 bbox: BoundingBox {
3295 xmin: 0.08515105,
3296 ymin: 0.7131401,
3297 xmax: 0.29802868,
3298 ymax: 0.8195788,
3299 },
3300 score: 0.91537374,
3301 label: 23,
3302 },
3303 DetectBox {
3304 bbox: BoundingBox {
3305 xmin: 0.59605736,
3306 ymin: 0.25545314,
3307 xmax: 0.93666154,
3308 ymax: 0.72378385,
3309 },
3310 score: 0.91537374,
3311 label: 23,
3312 },
3313 ];
3314
3315 let mut tracker = ByteTrackBuilder::new()
3316 .track_update(0.1)
3317 .track_high_conf(0.7)
3318 .build();
3319
3320 let mut output_boxes = Vec::with_capacity(50);
3321 let mut output_tracks = Vec::with_capacity(50);
3322
3323 decoder
3324 .decode_tracked_float_proto(
3325 &mut tracker,
3326 0,
3327 &[boxes.view().into_dyn(), protos.view().into_dyn()],
3328 &mut output_boxes,
3329 &mut output_tracks,
3330 )
3331 .unwrap();
3332
3333 assert_eq!(output_boxes.len(), 2);
3334 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3335 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3336
3337 let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3339 for score in scores_values.iter_mut() {
3340 *score = 0.0; }
3342 let protos = decoder
3343 .decode_tracked_float_proto(
3344 &mut tracker,
3345 100_000_000 / 3,
3346 &[boxes.view().into_dyn(), protos.view().into_dyn()],
3347 &mut output_boxes,
3348 &mut output_tracks,
3349 )
3350 .unwrap();
3351
3352 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3353 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3354
3355 assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
3357 }
3358
3359 #[test]
3360 fn test_decoder_tracked_segdet_split() {
3361 let score_threshold = 0.45;
3362 let iou_threshold = 0.45;
3363
3364 let boxes = include_bytes!(concat!(
3365 env!("CARGO_MANIFEST_DIR"),
3366 "/../../testdata/yolov8_boxes_116x8400.bin"
3367 ));
3368 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3369 let boxes = boxes.to_vec();
3370 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
3371
3372 let mask = boxes.slice(s![.., 84.., ..]).to_owned();
3373 let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
3374 let boxes = boxes.slice(s![.., ..4, ..]).to_owned();
3375
3376 let quant_boxes = (0.021287762, 31);
3377
3378 let protos = include_bytes!(concat!(
3379 env!("CARGO_MANIFEST_DIR"),
3380 "/../../testdata/yolov8_protos_160x160x32.bin"
3381 ));
3382 let protos =
3383 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3384 let protos = protos.to_vec();
3385 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
3386 let quant_protos = (0.02491162, -117);
3387 let decoder = DecoderBuilder::default()
3388 .with_config_yolo_split_segdet(
3389 configs::Boxes {
3390 decoder: configs::DecoderType::Ultralytics,
3391 quantization: Some(quant_boxes.into()),
3392 shape: vec![1, 4, 8400],
3393 dshape: vec![
3394 (DimName::Batch, 1),
3395 (DimName::BoxCoords, 4),
3396 (DimName::NumBoxes, 8400),
3397 ],
3398 normalized: Some(true),
3399 },
3400 configs::Scores {
3401 decoder: configs::DecoderType::Ultralytics,
3402 quantization: Some(quant_boxes.into()),
3403 shape: vec![1, 80, 8400],
3404 dshape: vec![
3405 (DimName::Batch, 1),
3406 (DimName::NumClasses, 80),
3407 (DimName::NumBoxes, 8400),
3408 ],
3409 },
3410 configs::MaskCoefficients {
3411 decoder: configs::DecoderType::Ultralytics,
3412 quantization: Some(quant_boxes.into()),
3413 shape: vec![1, 32, 8400],
3414 dshape: vec![
3415 (DimName::Batch, 1),
3416 (DimName::NumProtos, 32),
3417 (DimName::NumBoxes, 8400),
3418 ],
3419 },
3420 configs::Protos {
3421 decoder: configs::DecoderType::Ultralytics,
3422 quantization: Some(quant_protos.into()),
3423 shape: vec![1, 160, 160, 32],
3424 dshape: vec![
3425 (DimName::Batch, 1),
3426 (DimName::Height, 160),
3427 (DimName::Width, 160),
3428 (DimName::NumProtos, 32),
3429 ],
3430 },
3431 )
3432 .with_score_threshold(score_threshold)
3433 .with_iou_threshold(iou_threshold)
3434 .build()
3435 .unwrap();
3436
3437 let expected_boxes = [
3438 DetectBox {
3439 bbox: BoundingBox {
3440 xmin: 0.08515105,
3441 ymin: 0.7131401,
3442 xmax: 0.29802868,
3443 ymax: 0.8195788,
3444 },
3445 score: 0.91537374,
3446 label: 23,
3447 },
3448 DetectBox {
3449 bbox: BoundingBox {
3450 xmin: 0.59605736,
3451 ymin: 0.25545314,
3452 xmax: 0.93666154,
3453 ymax: 0.72378385,
3454 },
3455 score: 0.91537374,
3456 label: 23,
3457 },
3458 ];
3459
3460 let mut tracker = ByteTrackBuilder::new()
3461 .track_update(0.1)
3462 .track_high_conf(0.7)
3463 .build();
3464
3465 let mut output_boxes = Vec::with_capacity(50);
3466 let mut output_masks = Vec::with_capacity(50);
3467 let mut output_tracks = Vec::with_capacity(50);
3468
3469 decoder
3470 .decode_tracked_quantized(
3471 &mut tracker,
3472 0,
3473 &[
3474 boxes.view().into(),
3475 scores.view().into(),
3476 mask.view().into(),
3477 protos.view().into(),
3478 ],
3479 &mut output_boxes,
3480 &mut output_masks,
3481 &mut output_tracks,
3482 )
3483 .unwrap();
3484
3485 assert_eq!(output_boxes.len(), 2);
3486 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3487 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3488
3489 for score in scores.iter_mut() {
3492 *score = i8::MIN; }
3494 decoder
3495 .decode_tracked_quantized(
3496 &mut tracker,
3497 100_000_000 / 3,
3498 &[
3499 boxes.view().into(),
3500 scores.view().into(),
3501 mask.view().into(),
3502 protos.view().into(),
3503 ],
3504 &mut output_boxes,
3505 &mut output_masks,
3506 &mut output_tracks,
3507 )
3508 .unwrap();
3509
3510 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3511 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3512
3513 assert!(output_masks.is_empty())
3515 }
3516
3517 #[test]
3518 fn test_decoder_tracked_segdet_split_float() {
3519 let score_threshold = 0.45;
3520 let iou_threshold = 0.45;
3521
3522 let boxes = include_bytes!(concat!(
3523 env!("CARGO_MANIFEST_DIR"),
3524 "/../../testdata/yolov8_boxes_116x8400.bin"
3525 ));
3526 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3527 let boxes = boxes.to_vec();
3528 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
3529 let quant_boxes = (0.021287762, 31);
3530 let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
3531
3532 let mask = boxes.slice(s![.., 84.., ..]).to_owned();
3533 let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
3534 let boxes = boxes.slice(s![.., ..4, ..]).to_owned();
3535
3536 let protos = include_bytes!(concat!(
3537 env!("CARGO_MANIFEST_DIR"),
3538 "/../../testdata/yolov8_protos_160x160x32.bin"
3539 ));
3540 let protos =
3541 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3542 let protos = protos.to_vec();
3543 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
3544 let quant_protos = (0.02491162, -117);
3545 let protos = dequantize_ndarray(protos.view(), quant_protos.into());
3546
3547 let decoder = DecoderBuilder::default()
3548 .with_config_yolo_split_segdet(
3549 configs::Boxes {
3550 decoder: configs::DecoderType::Ultralytics,
3551 quantization: Some(quant_boxes.into()),
3552 shape: vec![1, 4, 8400],
3553 dshape: vec![
3554 (DimName::Batch, 1),
3555 (DimName::BoxCoords, 4),
3556 (DimName::NumBoxes, 8400),
3557 ],
3558 normalized: Some(true),
3559 },
3560 configs::Scores {
3561 decoder: configs::DecoderType::Ultralytics,
3562 quantization: Some(quant_boxes.into()),
3563 shape: vec![1, 80, 8400],
3564 dshape: vec![
3565 (DimName::Batch, 1),
3566 (DimName::NumClasses, 80),
3567 (DimName::NumBoxes, 8400),
3568 ],
3569 },
3570 configs::MaskCoefficients {
3571 decoder: configs::DecoderType::Ultralytics,
3572 quantization: Some(quant_boxes.into()),
3573 shape: vec![1, 32, 8400],
3574 dshape: vec![
3575 (DimName::Batch, 1),
3576 (DimName::NumProtos, 32),
3577 (DimName::NumBoxes, 8400),
3578 ],
3579 },
3580 configs::Protos {
3581 decoder: configs::DecoderType::Ultralytics,
3582 quantization: Some(quant_protos.into()),
3583 shape: vec![1, 160, 160, 32],
3584 dshape: vec![
3585 (DimName::Batch, 1),
3586 (DimName::Height, 160),
3587 (DimName::Width, 160),
3588 (DimName::NumProtos, 32),
3589 ],
3590 },
3591 )
3592 .with_score_threshold(score_threshold)
3593 .with_iou_threshold(iou_threshold)
3594 .build()
3595 .unwrap();
3596
3597 let expected_boxes = [
3598 DetectBox {
3599 bbox: BoundingBox {
3600 xmin: 0.08515105,
3601 ymin: 0.7131401,
3602 xmax: 0.29802868,
3603 ymax: 0.8195788,
3604 },
3605 score: 0.91537374,
3606 label: 23,
3607 },
3608 DetectBox {
3609 bbox: BoundingBox {
3610 xmin: 0.59605736,
3611 ymin: 0.25545314,
3612 xmax: 0.93666154,
3613 ymax: 0.72378385,
3614 },
3615 score: 0.91537374,
3616 label: 23,
3617 },
3618 ];
3619
3620 let mut tracker = ByteTrackBuilder::new()
3621 .track_update(0.1)
3622 .track_high_conf(0.7)
3623 .build();
3624
3625 let mut output_boxes = Vec::with_capacity(50);
3626 let mut output_masks = Vec::with_capacity(50);
3627 let mut output_tracks = Vec::with_capacity(50);
3628
3629 decoder
3630 .decode_tracked_float(
3631 &mut tracker,
3632 0,
3633 &[
3634 boxes.view().into_dyn(),
3635 scores.view().into_dyn(),
3636 mask.view().into_dyn(),
3637 protos.view().into_dyn(),
3638 ],
3639 &mut output_boxes,
3640 &mut output_masks,
3641 &mut output_tracks,
3642 )
3643 .unwrap();
3644
3645 assert_eq!(output_boxes.len(), 2);
3646 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3647 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3648
3649 for score in scores.iter_mut() {
3652 *score = 0.0; }
3654 decoder
3655 .decode_tracked_float(
3656 &mut tracker,
3657 100_000_000 / 3,
3658 &[
3659 boxes.view().into_dyn(),
3660 scores.view().into_dyn(),
3661 mask.view().into_dyn(),
3662 protos.view().into_dyn(),
3663 ],
3664 &mut output_boxes,
3665 &mut output_masks,
3666 &mut output_tracks,
3667 )
3668 .unwrap();
3669
3670 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3671 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3672
3673 assert!(output_masks.is_empty())
3675 }
3676
3677 #[test]
3678 fn test_decoder_tracked_segdet_split_proto() {
3679 let score_threshold = 0.45;
3680 let iou_threshold = 0.45;
3681
3682 let boxes = include_bytes!(concat!(
3683 env!("CARGO_MANIFEST_DIR"),
3684 "/../../testdata/yolov8_boxes_116x8400.bin"
3685 ));
3686 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3687 let boxes = boxes.to_vec();
3688 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
3689
3690 let mask = boxes.slice(s![.., 84.., ..]).to_owned();
3691 let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
3692 let boxes = boxes.slice(s![.., ..4, ..]).to_owned();
3693
3694 let quant_boxes = (0.021287762, 31);
3695
3696 let protos = include_bytes!(concat!(
3697 env!("CARGO_MANIFEST_DIR"),
3698 "/../../testdata/yolov8_protos_160x160x32.bin"
3699 ));
3700 let protos =
3701 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3702 let protos = protos.to_vec();
3703 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
3704 let quant_protos = (0.02491162, -117);
3705 let decoder = DecoderBuilder::default()
3706 .with_config_yolo_split_segdet(
3707 configs::Boxes {
3708 decoder: configs::DecoderType::Ultralytics,
3709 quantization: Some(quant_boxes.into()),
3710 shape: vec![1, 4, 8400],
3711 dshape: vec![
3712 (DimName::Batch, 1),
3713 (DimName::BoxCoords, 4),
3714 (DimName::NumBoxes, 8400),
3715 ],
3716 normalized: Some(true),
3717 },
3718 configs::Scores {
3719 decoder: configs::DecoderType::Ultralytics,
3720 quantization: Some(quant_boxes.into()),
3721 shape: vec![1, 80, 8400],
3722 dshape: vec![
3723 (DimName::Batch, 1),
3724 (DimName::NumClasses, 80),
3725 (DimName::NumBoxes, 8400),
3726 ],
3727 },
3728 configs::MaskCoefficients {
3729 decoder: configs::DecoderType::Ultralytics,
3730 quantization: Some(quant_boxes.into()),
3731 shape: vec![1, 32, 8400],
3732 dshape: vec![
3733 (DimName::Batch, 1),
3734 (DimName::NumProtos, 32),
3735 (DimName::NumBoxes, 8400),
3736 ],
3737 },
3738 configs::Protos {
3739 decoder: configs::DecoderType::Ultralytics,
3740 quantization: Some(quant_protos.into()),
3741 shape: vec![1, 160, 160, 32],
3742 dshape: vec![
3743 (DimName::Batch, 1),
3744 (DimName::Height, 160),
3745 (DimName::Width, 160),
3746 (DimName::NumProtos, 32),
3747 ],
3748 },
3749 )
3750 .with_score_threshold(score_threshold)
3751 .with_iou_threshold(iou_threshold)
3752 .build()
3753 .unwrap();
3754
3755 let expected_boxes = [
3756 DetectBox {
3757 bbox: BoundingBox {
3758 xmin: 0.08515105,
3759 ymin: 0.7131401,
3760 xmax: 0.29802868,
3761 ymax: 0.8195788,
3762 },
3763 score: 0.91537374,
3764 label: 23,
3765 },
3766 DetectBox {
3767 bbox: BoundingBox {
3768 xmin: 0.59605736,
3769 ymin: 0.25545314,
3770 xmax: 0.93666154,
3771 ymax: 0.72378385,
3772 },
3773 score: 0.91537374,
3774 label: 23,
3775 },
3776 ];
3777
3778 let mut tracker = ByteTrackBuilder::new()
3779 .track_update(0.1)
3780 .track_high_conf(0.7)
3781 .build();
3782
3783 let mut output_boxes = Vec::with_capacity(50);
3784 let mut output_tracks = Vec::with_capacity(50);
3785
3786 decoder
3787 .decode_tracked_quantized_proto(
3788 &mut tracker,
3789 0,
3790 &[
3791 boxes.view().into(),
3792 scores.view().into(),
3793 mask.view().into(),
3794 protos.view().into(),
3795 ],
3796 &mut output_boxes,
3797 &mut output_tracks,
3798 )
3799 .unwrap();
3800
3801 assert_eq!(output_boxes.len(), 2);
3802 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3803 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3804
3805 for score in scores.iter_mut() {
3808 *score = i8::MIN; }
3810 let protos = decoder
3811 .decode_tracked_quantized_proto(
3812 &mut tracker,
3813 100_000_000 / 3,
3814 &[
3815 boxes.view().into(),
3816 scores.view().into(),
3817 mask.view().into(),
3818 protos.view().into(),
3819 ],
3820 &mut output_boxes,
3821 &mut output_tracks,
3822 )
3823 .unwrap();
3824
3825 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3826 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3827
3828 assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
3830 }
3831
3832 #[test]
3833 fn test_decoder_tracked_segdet_split_proto_float() {
3834 let score_threshold = 0.45;
3835 let iou_threshold = 0.45;
3836
3837 let boxes = include_bytes!(concat!(
3838 env!("CARGO_MANIFEST_DIR"),
3839 "/../../testdata/yolov8_boxes_116x8400.bin"
3840 ));
3841 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3842 let boxes = boxes.to_vec();
3843 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
3844 let quant_boxes = (0.021287762, 31);
3845 let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
3846
3847 let mask = boxes.slice(s![.., 84.., ..]).to_owned();
3848 let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
3849 let boxes = boxes.slice(s![.., ..4, ..]).to_owned();
3850
3851 let protos = include_bytes!(concat!(
3852 env!("CARGO_MANIFEST_DIR"),
3853 "/../../testdata/yolov8_protos_160x160x32.bin"
3854 ));
3855 let protos =
3856 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3857 let protos = protos.to_vec();
3858 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
3859 let quant_protos = (0.02491162, -117);
3860 let protos = dequantize_ndarray(protos.view(), quant_protos.into());
3861
3862 let decoder = DecoderBuilder::default()
3863 .with_config_yolo_split_segdet(
3864 configs::Boxes {
3865 decoder: configs::DecoderType::Ultralytics,
3866 quantization: Some(quant_boxes.into()),
3867 shape: vec![1, 4, 8400],
3868 dshape: vec![
3869 (DimName::Batch, 1),
3870 (DimName::BoxCoords, 4),
3871 (DimName::NumBoxes, 8400),
3872 ],
3873 normalized: Some(true),
3874 },
3875 configs::Scores {
3876 decoder: configs::DecoderType::Ultralytics,
3877 quantization: Some(quant_boxes.into()),
3878 shape: vec![1, 80, 8400],
3879 dshape: vec![
3880 (DimName::Batch, 1),
3881 (DimName::NumClasses, 80),
3882 (DimName::NumBoxes, 8400),
3883 ],
3884 },
3885 configs::MaskCoefficients {
3886 decoder: configs::DecoderType::Ultralytics,
3887 quantization: Some(quant_boxes.into()),
3888 shape: vec![1, 32, 8400],
3889 dshape: vec![
3890 (DimName::Batch, 1),
3891 (DimName::NumProtos, 32),
3892 (DimName::NumBoxes, 8400),
3893 ],
3894 },
3895 configs::Protos {
3896 decoder: configs::DecoderType::Ultralytics,
3897 quantization: Some(quant_protos.into()),
3898 shape: vec![1, 160, 160, 32],
3899 dshape: vec![
3900 (DimName::Batch, 1),
3901 (DimName::Height, 160),
3902 (DimName::Width, 160),
3903 (DimName::NumProtos, 32),
3904 ],
3905 },
3906 )
3907 .with_score_threshold(score_threshold)
3908 .with_iou_threshold(iou_threshold)
3909 .build()
3910 .unwrap();
3911
3912 let expected_boxes = [
3913 DetectBox {
3914 bbox: BoundingBox {
3915 xmin: 0.08515105,
3916 ymin: 0.7131401,
3917 xmax: 0.29802868,
3918 ymax: 0.8195788,
3919 },
3920 score: 0.91537374,
3921 label: 23,
3922 },
3923 DetectBox {
3924 bbox: BoundingBox {
3925 xmin: 0.59605736,
3926 ymin: 0.25545314,
3927 xmax: 0.93666154,
3928 ymax: 0.72378385,
3929 },
3930 score: 0.91537374,
3931 label: 23,
3932 },
3933 ];
3934
3935 let mut tracker = ByteTrackBuilder::new()
3936 .track_update(0.1)
3937 .track_high_conf(0.7)
3938 .build();
3939
3940 let mut output_boxes = Vec::with_capacity(50);
3941 let mut output_tracks = Vec::with_capacity(50);
3942
3943 decoder
3944 .decode_tracked_float_proto(
3945 &mut tracker,
3946 0,
3947 &[
3948 boxes.view().into_dyn(),
3949 scores.view().into_dyn(),
3950 mask.view().into_dyn(),
3951 protos.view().into_dyn(),
3952 ],
3953 &mut output_boxes,
3954 &mut output_tracks,
3955 )
3956 .unwrap();
3957
3958 assert_eq!(output_boxes.len(), 2);
3959 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3960 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3961
3962 for score in scores.iter_mut() {
3965 *score = 0.0; }
3967 let protos = decoder
3968 .decode_tracked_float_proto(
3969 &mut tracker,
3970 100_000_000 / 3,
3971 &[
3972 boxes.view().into_dyn(),
3973 scores.view().into_dyn(),
3974 mask.view().into_dyn(),
3975 protos.view().into_dyn(),
3976 ],
3977 &mut output_boxes,
3978 &mut output_tracks,
3979 )
3980 .unwrap();
3981
3982 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3983 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3984
3985 assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
3987 }
3988
3989 #[test]
3990 fn test_decoder_tracked_end_to_end_segdet() {
3991 let score_threshold = 0.45;
3992 let iou_threshold = 0.45;
3993
3994 let mut boxes = Array2::zeros((10, 4));
3995 let mut scores = Array2::zeros((10, 1));
3996 let mut classes = Array2::zeros((10, 1));
3997 let mask = Array2::zeros((10, 32));
3998 let protos = Array3::<f64>::zeros((160, 160, 32));
3999 let protos = protos.insert_axis(Axis(0));
4000
4001 let protos_quant = (1.0 / 255.0, 0.0);
4002 let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
4003
4004 boxes
4005 .slice_mut(s![0, ..,])
4006 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4007 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4008 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4009
4010 let detect = ndarray::concatenate![
4011 Axis(1),
4012 boxes.view(),
4013 scores.view(),
4014 classes.view(),
4015 mask.view()
4016 ];
4017 let detect = detect.insert_axis(Axis(0));
4018 assert_eq!(detect.shape(), &[1, 10, 38]);
4019 let detect_quant = (2.0 / 255.0, 0.0);
4020 let mut detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
4021 let config = "
4022decoder_version: yolo26
4023outputs:
4024 - type: detection
4025 decoder: ultralytics
4026 quantization: [0.00784313725490196, 0]
4027 shape: [1, 10, 38]
4028 dshape:
4029 - [batch, 1]
4030 - [num_boxes, 10]
4031 - [num_features, 38]
4032 normalized: true
4033 - type: protos
4034 decoder: ultralytics
4035 quantization: [0.0039215686274509803921568627451, 128]
4036 shape: [1, 160, 160, 32]
4037 dshape:
4038 - [batch, 1]
4039 - [height, 160]
4040 - [width, 160]
4041 - [num_protos, 32]
4042";
4043
4044 let decoder = DecoderBuilder::default()
4045 .with_config_yaml_str(config.to_string())
4046 .with_score_threshold(score_threshold)
4047 .with_iou_threshold(iou_threshold)
4048 .build()
4049 .unwrap();
4050
4051 let expected_boxes = [DetectBox {
4053 bbox: BoundingBox {
4054 xmin: 0.12549022,
4055 ymin: 0.12549022,
4056 xmax: 0.23529413,
4057 ymax: 0.23529413,
4058 },
4059 score: 0.98823535,
4060 label: 2,
4061 }];
4062
4063 let mut tracker = ByteTrackBuilder::new()
4064 .track_update(0.1)
4065 .track_high_conf(0.7)
4066 .build();
4067
4068 let mut output_boxes = Vec::with_capacity(50);
4069 let mut output_masks = Vec::with_capacity(50);
4070 let mut output_tracks = Vec::with_capacity(50);
4071
4072 decoder
4073 .decode_tracked_quantized(
4074 &mut tracker,
4075 0,
4076 &[detect.view().into(), protos.view().into()],
4077 &mut output_boxes,
4078 &mut output_masks,
4079 &mut output_tracks,
4080 )
4081 .unwrap();
4082
4083 assert_eq!(output_boxes.len(), 1);
4084 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4085
4086 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4089 *score = u8::MIN; }
4091
4092 decoder
4093 .decode_tracked_quantized(
4094 &mut tracker,
4095 100_000_000 / 3,
4096 &[detect.view().into(), protos.view().into()],
4097 &mut output_boxes,
4098 &mut output_masks,
4099 &mut output_tracks,
4100 )
4101 .unwrap();
4102 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4103 assert!(output_masks.is_empty())
4105 }
4106
4107 #[test]
4108 fn test_decoder_tracked_end_to_end_segdet_float() {
4109 let score_threshold = 0.45;
4110 let iou_threshold = 0.45;
4111
4112 let mut boxes = Array2::zeros((10, 4));
4113 let mut scores = Array2::zeros((10, 1));
4114 let mut classes = Array2::zeros((10, 1));
4115 let mask = Array2::zeros((10, 32));
4116 let protos = Array3::<f64>::zeros((160, 160, 32));
4117 let protos = protos.insert_axis(Axis(0));
4118
4119 boxes
4120 .slice_mut(s![0, ..,])
4121 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4122 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4123 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4124
4125 let detect = ndarray::concatenate![
4126 Axis(1),
4127 boxes.view(),
4128 scores.view(),
4129 classes.view(),
4130 mask.view()
4131 ];
4132 let mut detect = detect.insert_axis(Axis(0));
4133 assert_eq!(detect.shape(), &[1, 10, 38]);
4134 let config = "
4135decoder_version: yolo26
4136outputs:
4137 - type: detection
4138 decoder: ultralytics
4139 quantization: [0.00784313725490196, 0]
4140 shape: [1, 10, 38]
4141 dshape:
4142 - [batch, 1]
4143 - [num_boxes, 10]
4144 - [num_features, 38]
4145 normalized: true
4146 - type: protos
4147 decoder: ultralytics
4148 quantization: [0.0039215686274509803921568627451, 128]
4149 shape: [1, 160, 160, 32]
4150 dshape:
4151 - [batch, 1]
4152 - [height, 160]
4153 - [width, 160]
4154 - [num_protos, 32]
4155";
4156
4157 let decoder = DecoderBuilder::default()
4158 .with_config_yaml_str(config.to_string())
4159 .with_score_threshold(score_threshold)
4160 .with_iou_threshold(iou_threshold)
4161 .build()
4162 .unwrap();
4163
4164 let expected_boxes = [DetectBox {
4165 bbox: BoundingBox {
4166 xmin: 0.1234,
4167 ymin: 0.1234,
4168 xmax: 0.2345,
4169 ymax: 0.2345,
4170 },
4171 score: 0.9876,
4172 label: 2,
4173 }];
4174
4175 let mut tracker = ByteTrackBuilder::new()
4176 .track_update(0.1)
4177 .track_high_conf(0.7)
4178 .build();
4179
4180 let mut output_boxes = Vec::with_capacity(50);
4181 let mut output_masks = Vec::with_capacity(50);
4182 let mut output_tracks = Vec::with_capacity(50);
4183
4184 decoder
4185 .decode_tracked_float(
4186 &mut tracker,
4187 0,
4188 &[detect.view().into_dyn(), protos.view().into_dyn()],
4189 &mut output_boxes,
4190 &mut output_masks,
4191 &mut output_tracks,
4192 )
4193 .unwrap();
4194
4195 assert_eq!(output_boxes.len(), 1);
4196 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4197
4198 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4201 *score = 0.0; }
4203
4204 decoder
4205 .decode_tracked_float(
4206 &mut tracker,
4207 100_000_000 / 3,
4208 &[detect.view().into_dyn(), protos.view().into_dyn()],
4209 &mut output_boxes,
4210 &mut output_masks,
4211 &mut output_tracks,
4212 )
4213 .unwrap();
4214 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4215 assert!(output_masks.is_empty())
4217 }
4218
4219 #[test]
4220 fn test_decoder_tracked_end_to_end_segdet_proto() {
4221 let score_threshold = 0.45;
4222 let iou_threshold = 0.45;
4223
4224 let mut boxes = Array2::zeros((10, 4));
4225 let mut scores = Array2::zeros((10, 1));
4226 let mut classes = Array2::zeros((10, 1));
4227 let mask = Array2::zeros((10, 32));
4228 let protos = Array3::<f64>::zeros((160, 160, 32));
4229 let protos = protos.insert_axis(Axis(0));
4230
4231 let protos_quant = (1.0 / 255.0, 0.0);
4232 let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
4233
4234 boxes
4235 .slice_mut(s![0, ..,])
4236 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4237 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4238 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4239
4240 let detect = ndarray::concatenate![
4241 Axis(1),
4242 boxes.view(),
4243 scores.view(),
4244 classes.view(),
4245 mask.view()
4246 ];
4247 let detect = detect.insert_axis(Axis(0));
4248 assert_eq!(detect.shape(), &[1, 10, 38]);
4249 let detect_quant = (2.0 / 255.0, 0.0);
4250 let mut detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
4251 let config = "
4252decoder_version: yolo26
4253outputs:
4254 - type: detection
4255 decoder: ultralytics
4256 quantization: [0.00784313725490196, 0]
4257 shape: [1, 10, 38]
4258 dshape:
4259 - [batch, 1]
4260 - [num_boxes, 10]
4261 - [num_features, 38]
4262 normalized: true
4263 - type: protos
4264 decoder: ultralytics
4265 quantization: [0.0039215686274509803921568627451, 128]
4266 shape: [1, 160, 160, 32]
4267 dshape:
4268 - [batch, 1]
4269 - [height, 160]
4270 - [width, 160]
4271 - [num_protos, 32]
4272";
4273
4274 let decoder = DecoderBuilder::default()
4275 .with_config_yaml_str(config.to_string())
4276 .with_score_threshold(score_threshold)
4277 .with_iou_threshold(iou_threshold)
4278 .build()
4279 .unwrap();
4280
4281 let expected_boxes = [DetectBox {
4283 bbox: BoundingBox {
4284 xmin: 0.12549022,
4285 ymin: 0.12549022,
4286 xmax: 0.23529413,
4287 ymax: 0.23529413,
4288 },
4289 score: 0.98823535,
4290 label: 2,
4291 }];
4292
4293 let mut tracker = ByteTrackBuilder::new()
4294 .track_update(0.1)
4295 .track_high_conf(0.7)
4296 .build();
4297
4298 let mut output_boxes = Vec::with_capacity(50);
4299 let mut output_tracks = Vec::with_capacity(50);
4300
4301 decoder
4302 .decode_tracked_quantized_proto(
4303 &mut tracker,
4304 0,
4305 &[detect.view().into(), protos.view().into()],
4306 &mut output_boxes,
4307 &mut output_tracks,
4308 )
4309 .unwrap();
4310
4311 assert_eq!(output_boxes.len(), 1);
4312 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4313
4314 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4317 *score = u8::MIN; }
4319
4320 let protos = decoder
4321 .decode_tracked_quantized_proto(
4322 &mut tracker,
4323 100_000_000 / 3,
4324 &[detect.view().into(), protos.view().into()],
4325 &mut output_boxes,
4326 &mut output_tracks,
4327 )
4328 .unwrap();
4329 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4330 assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()))
4332 }
4333
4334 #[test]
4335 fn test_decoder_tracked_end_to_end_segdet_proto_float() {
4336 let score_threshold = 0.45;
4337 let iou_threshold = 0.45;
4338
4339 let mut boxes = Array2::zeros((10, 4));
4340 let mut scores = Array2::zeros((10, 1));
4341 let mut classes = Array2::zeros((10, 1));
4342 let mask = Array2::zeros((10, 32));
4343 let protos = Array3::<f64>::zeros((160, 160, 32));
4344 let protos = protos.insert_axis(Axis(0));
4345
4346 boxes
4347 .slice_mut(s![0, ..,])
4348 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4349 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4350 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4351
4352 let detect = ndarray::concatenate![
4353 Axis(1),
4354 boxes.view(),
4355 scores.view(),
4356 classes.view(),
4357 mask.view()
4358 ];
4359 let mut detect = detect.insert_axis(Axis(0));
4360 assert_eq!(detect.shape(), &[1, 10, 38]);
4361 let config = "
4362decoder_version: yolo26
4363outputs:
4364 - type: detection
4365 decoder: ultralytics
4366 quantization: [0.00784313725490196, 0]
4367 shape: [1, 10, 38]
4368 dshape:
4369 - [batch, 1]
4370 - [num_boxes, 10]
4371 - [num_features, 38]
4372 normalized: true
4373 - type: protos
4374 decoder: ultralytics
4375 quantization: [0.0039215686274509803921568627451, 128]
4376 shape: [1, 160, 160, 32]
4377 dshape:
4378 - [batch, 1]
4379 - [height, 160]
4380 - [width, 160]
4381 - [num_protos, 32]
4382";
4383
4384 let decoder = DecoderBuilder::default()
4385 .with_config_yaml_str(config.to_string())
4386 .with_score_threshold(score_threshold)
4387 .with_iou_threshold(iou_threshold)
4388 .build()
4389 .unwrap();
4390
4391 let expected_boxes = [DetectBox {
4392 bbox: BoundingBox {
4393 xmin: 0.1234,
4394 ymin: 0.1234,
4395 xmax: 0.2345,
4396 ymax: 0.2345,
4397 },
4398 score: 0.9876,
4399 label: 2,
4400 }];
4401
4402 let mut tracker = ByteTrackBuilder::new()
4403 .track_update(0.1)
4404 .track_high_conf(0.7)
4405 .build();
4406
4407 let mut output_boxes = Vec::with_capacity(50);
4408 let mut output_tracks = Vec::with_capacity(50);
4409
4410 decoder
4411 .decode_tracked_float_proto(
4412 &mut tracker,
4413 0,
4414 &[detect.view().into_dyn(), protos.view().into_dyn()],
4415 &mut output_boxes,
4416 &mut output_tracks,
4417 )
4418 .unwrap();
4419
4420 assert_eq!(output_boxes.len(), 1);
4421 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4422
4423 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4426 *score = 0.0; }
4428
4429 let protos = decoder
4430 .decode_tracked_float_proto(
4431 &mut tracker,
4432 100_000_000 / 3,
4433 &[detect.view().into_dyn(), protos.view().into_dyn()],
4434 &mut output_boxes,
4435 &mut output_tracks,
4436 )
4437 .unwrap();
4438 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4439 assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()))
4441 }
4442
4443 #[test]
4444 fn test_decoder_tracked_end_to_end_segdet_split() {
4445 let score_threshold = 0.45;
4446 let iou_threshold = 0.45;
4447
4448 let mut boxes = Array2::zeros((10, 4));
4449 let mut scores = Array2::zeros((10, 1));
4450 let mut classes = Array2::zeros((10, 1));
4451 let mask: Array2<f64> = Array2::zeros((10, 32));
4452 let protos = Array3::<f64>::zeros((160, 160, 32));
4453 let protos = protos.insert_axis(Axis(0));
4454
4455 let protos_quant = (1.0 / 255.0, 0.0);
4456 let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
4457
4458 boxes
4459 .slice_mut(s![0, ..,])
4460 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4461 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4462 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4463
4464 let boxes = boxes.insert_axis(Axis(0));
4465 let scores = scores.insert_axis(Axis(0));
4466 let classes = classes.insert_axis(Axis(0));
4467 let mask = mask.insert_axis(Axis(0));
4468
4469 let detect_quant = (2.0 / 255.0, 0.0);
4470 let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
4471 let mut scores: Array3<u8> = quantize_ndarray(scores.view(), detect_quant.into());
4472 let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
4473 let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
4474
4475 let config = "
4476decoder_version: yolo26
4477outputs:
4478 - type: boxes
4479 decoder: ultralytics
4480 quantization: [0.00784313725490196, 0]
4481 shape: [1, 10, 4]
4482 dshape:
4483 - [batch, 1]
4484 - [num_boxes, 10]
4485 - [box_coords, 4]
4486 normalized: true
4487 - type: scores
4488 decoder: ultralytics
4489 quantization: [0.00784313725490196, 0]
4490 shape: [1, 10, 1]
4491 dshape:
4492 - [batch, 1]
4493 - [num_boxes, 10]
4494 - [num_classes, 1]
4495 - type: classes
4496 decoder: ultralytics
4497 quantization: [0.00784313725490196, 0]
4498 shape: [1, 10, 1]
4499 dshape:
4500 - [batch, 1]
4501 - [num_boxes, 10]
4502 - [num_classes, 1]
4503 - type: mask_coefficients
4504 decoder: ultralytics
4505 quantization: [0.00784313725490196, 0]
4506 shape: [1, 10, 32]
4507 dshape:
4508 - [batch, 1]
4509 - [num_boxes, 10]
4510 - [num_protos, 32]
4511 - type: protos
4512 decoder: ultralytics
4513 quantization: [0.0039215686274509803921568627451, 128]
4514 shape: [1, 160, 160, 32]
4515 dshape:
4516 - [batch, 1]
4517 - [height, 160]
4518 - [width, 160]
4519 - [num_protos, 32]
4520";
4521
4522 let decoder = DecoderBuilder::default()
4523 .with_config_yaml_str(config.to_string())
4524 .with_score_threshold(score_threshold)
4525 .with_iou_threshold(iou_threshold)
4526 .build()
4527 .unwrap();
4528
4529 let expected_boxes = [DetectBox {
4531 bbox: BoundingBox {
4532 xmin: 0.12549022,
4533 ymin: 0.12549022,
4534 xmax: 0.23529413,
4535 ymax: 0.23529413,
4536 },
4537 score: 0.98823535,
4538 label: 2,
4539 }];
4540
4541 let mut tracker = ByteTrackBuilder::new()
4542 .track_update(0.1)
4543 .track_high_conf(0.7)
4544 .build();
4545
4546 let mut output_boxes = Vec::with_capacity(50);
4547 let mut output_masks = Vec::with_capacity(50);
4548 let mut output_tracks = Vec::with_capacity(50);
4549
4550 decoder
4551 .decode_tracked_quantized(
4552 &mut tracker,
4553 0,
4554 &[
4555 boxes.view().into(),
4556 scores.view().into(),
4557 classes.view().into(),
4558 mask.view().into(),
4559 protos.view().into(),
4560 ],
4561 &mut output_boxes,
4562 &mut output_masks,
4563 &mut output_tracks,
4564 )
4565 .unwrap();
4566
4567 assert_eq!(output_boxes.len(), 1);
4568 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4569
4570 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4573 *score = u8::MIN; }
4575
4576 decoder
4577 .decode_tracked_quantized(
4578 &mut tracker,
4579 100_000_000 / 3,
4580 &[
4581 boxes.view().into(),
4582 scores.view().into(),
4583 classes.view().into(),
4584 mask.view().into(),
4585 protos.view().into(),
4586 ],
4587 &mut output_boxes,
4588 &mut output_masks,
4589 &mut output_tracks,
4590 )
4591 .unwrap();
4592 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4593 assert!(output_masks.is_empty())
4595 }
4596 #[test]
4597 fn test_decoder_tracked_end_to_end_segdet_split_float() {
4598 let score_threshold = 0.45;
4599 let iou_threshold = 0.45;
4600
4601 let mut boxes = Array2::zeros((10, 4));
4602 let mut scores = Array2::zeros((10, 1));
4603 let mut classes = Array2::zeros((10, 1));
4604 let mask: Array2<f64> = Array2::zeros((10, 32));
4605 let protos = Array3::<f64>::zeros((160, 160, 32));
4606 let protos = protos.insert_axis(Axis(0));
4607
4608 boxes
4609 .slice_mut(s![0, ..,])
4610 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4611 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4612 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4613
4614 let boxes = boxes.insert_axis(Axis(0));
4615 let mut scores = scores.insert_axis(Axis(0));
4616 let classes = classes.insert_axis(Axis(0));
4617 let mask = mask.insert_axis(Axis(0));
4618
4619 let config = "
4620decoder_version: yolo26
4621outputs:
4622 - type: boxes
4623 decoder: ultralytics
4624 quantization: [0.00784313725490196, 0]
4625 shape: [1, 10, 4]
4626 dshape:
4627 - [batch, 1]
4628 - [num_boxes, 10]
4629 - [box_coords, 4]
4630 normalized: true
4631 - type: scores
4632 decoder: ultralytics
4633 quantization: [0.00784313725490196, 0]
4634 shape: [1, 10, 1]
4635 dshape:
4636 - [batch, 1]
4637 - [num_boxes, 10]
4638 - [num_classes, 1]
4639 - type: classes
4640 decoder: ultralytics
4641 quantization: [0.00784313725490196, 0]
4642 shape: [1, 10, 1]
4643 dshape:
4644 - [batch, 1]
4645 - [num_boxes, 10]
4646 - [num_classes, 1]
4647 - type: mask_coefficients
4648 decoder: ultralytics
4649 quantization: [0.00784313725490196, 0]
4650 shape: [1, 10, 32]
4651 dshape:
4652 - [batch, 1]
4653 - [num_boxes, 10]
4654 - [num_protos, 32]
4655 - type: protos
4656 decoder: ultralytics
4657 quantization: [0.0039215686274509803921568627451, 128]
4658 shape: [1, 160, 160, 32]
4659 dshape:
4660 - [batch, 1]
4661 - [height, 160]
4662 - [width, 160]
4663 - [num_protos, 32]
4664";
4665
4666 let decoder = DecoderBuilder::default()
4667 .with_config_yaml_str(config.to_string())
4668 .with_score_threshold(score_threshold)
4669 .with_iou_threshold(iou_threshold)
4670 .build()
4671 .unwrap();
4672
4673 let expected_boxes = [DetectBox {
4675 bbox: BoundingBox {
4676 xmin: 0.1234,
4677 ymin: 0.1234,
4678 xmax: 0.2345,
4679 ymax: 0.2345,
4680 },
4681 score: 0.9876,
4682 label: 2,
4683 }];
4684
4685 let mut tracker = ByteTrackBuilder::new()
4686 .track_update(0.1)
4687 .track_high_conf(0.7)
4688 .build();
4689
4690 let mut output_boxes = Vec::with_capacity(50);
4691 let mut output_masks = Vec::with_capacity(50);
4692 let mut output_tracks = Vec::with_capacity(50);
4693
4694 decoder
4695 .decode_tracked_float(
4696 &mut tracker,
4697 0,
4698 &[
4699 boxes.view().into_dyn(),
4700 scores.view().into_dyn(),
4701 classes.view().into_dyn(),
4702 mask.view().into_dyn(),
4703 protos.view().into_dyn(),
4704 ],
4705 &mut output_boxes,
4706 &mut output_masks,
4707 &mut output_tracks,
4708 )
4709 .unwrap();
4710
4711 assert_eq!(output_boxes.len(), 1);
4712 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4713
4714 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4717 *score = 0.0; }
4719
4720 decoder
4721 .decode_tracked_float(
4722 &mut tracker,
4723 100_000_000 / 3,
4724 &[
4725 boxes.view().into_dyn(),
4726 scores.view().into_dyn(),
4727 classes.view().into_dyn(),
4728 mask.view().into_dyn(),
4729 protos.view().into_dyn(),
4730 ],
4731 &mut output_boxes,
4732 &mut output_masks,
4733 &mut output_tracks,
4734 )
4735 .unwrap();
4736 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4737 assert!(output_masks.is_empty())
4739 }
4740
4741 #[test]
4742 fn test_decoder_tracked_end_to_end_segdet_split_proto() {
4743 let score_threshold = 0.45;
4744 let iou_threshold = 0.45;
4745
4746 let mut boxes = Array2::zeros((10, 4));
4747 let mut scores = Array2::zeros((10, 1));
4748 let mut classes = Array2::zeros((10, 1));
4749 let mask: Array2<f64> = Array2::zeros((10, 32));
4750 let protos = Array3::<f64>::zeros((160, 160, 32));
4751 let protos = protos.insert_axis(Axis(0));
4752
4753 let protos_quant = (1.0 / 255.0, 0.0);
4754 let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
4755
4756 boxes
4757 .slice_mut(s![0, ..,])
4758 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4759 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4760 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4761
4762 let boxes = boxes.insert_axis(Axis(0));
4763 let scores = scores.insert_axis(Axis(0));
4764 let classes = classes.insert_axis(Axis(0));
4765 let mask = mask.insert_axis(Axis(0));
4766
4767 let detect_quant = (2.0 / 255.0, 0.0);
4768 let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
4769 let mut scores: Array3<u8> = quantize_ndarray(scores.view(), detect_quant.into());
4770 let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
4771 let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
4772
4773 let config = "
4774decoder_version: yolo26
4775outputs:
4776 - type: boxes
4777 decoder: ultralytics
4778 quantization: [0.00784313725490196, 0]
4779 shape: [1, 10, 4]
4780 dshape:
4781 - [batch, 1]
4782 - [num_boxes, 10]
4783 - [box_coords, 4]
4784 normalized: true
4785 - type: scores
4786 decoder: ultralytics
4787 quantization: [0.00784313725490196, 0]
4788 shape: [1, 10, 1]
4789 dshape:
4790 - [batch, 1]
4791 - [num_boxes, 10]
4792 - [num_classes, 1]
4793 - type: classes
4794 decoder: ultralytics
4795 quantization: [0.00784313725490196, 0]
4796 shape: [1, 10, 1]
4797 dshape:
4798 - [batch, 1]
4799 - [num_boxes, 10]
4800 - [num_classes, 1]
4801 - type: mask_coefficients
4802 decoder: ultralytics
4803 quantization: [0.00784313725490196, 0]
4804 shape: [1, 10, 32]
4805 dshape:
4806 - [batch, 1]
4807 - [num_boxes, 10]
4808 - [num_protos, 32]
4809 - type: protos
4810 decoder: ultralytics
4811 quantization: [0.0039215686274509803921568627451, 128]
4812 shape: [1, 160, 160, 32]
4813 dshape:
4814 - [batch, 1]
4815 - [height, 160]
4816 - [width, 160]
4817 - [num_protos, 32]
4818";
4819
4820 let decoder = DecoderBuilder::default()
4821 .with_config_yaml_str(config.to_string())
4822 .with_score_threshold(score_threshold)
4823 .with_iou_threshold(iou_threshold)
4824 .build()
4825 .unwrap();
4826
4827 let expected_boxes = [DetectBox {
4829 bbox: BoundingBox {
4830 xmin: 0.12549022,
4831 ymin: 0.12549022,
4832 xmax: 0.23529413,
4833 ymax: 0.23529413,
4834 },
4835 score: 0.98823535,
4836 label: 2,
4837 }];
4838
4839 let mut tracker = ByteTrackBuilder::new()
4840 .track_update(0.1)
4841 .track_high_conf(0.7)
4842 .build();
4843
4844 let mut output_boxes = Vec::with_capacity(50);
4845 let mut output_tracks = Vec::with_capacity(50);
4846
4847 decoder
4848 .decode_tracked_quantized_proto(
4849 &mut tracker,
4850 0,
4851 &[
4852 boxes.view().into(),
4853 scores.view().into(),
4854 classes.view().into(),
4855 mask.view().into(),
4856 protos.view().into(),
4857 ],
4858 &mut output_boxes,
4859 &mut output_tracks,
4860 )
4861 .unwrap();
4862
4863 assert_eq!(output_boxes.len(), 1);
4864 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4865
4866 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4869 *score = u8::MIN; }
4871
4872 let protos = decoder
4873 .decode_tracked_quantized_proto(
4874 &mut tracker,
4875 100_000_000 / 3,
4876 &[
4877 boxes.view().into(),
4878 scores.view().into(),
4879 classes.view().into(),
4880 mask.view().into(),
4881 protos.view().into(),
4882 ],
4883 &mut output_boxes,
4884 &mut output_tracks,
4885 )
4886 .unwrap();
4887 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4888 assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()))
4890 }
4891
4892 #[test]
4893 fn test_decoder_tracked_end_to_end_segdet_split_proto_float() {
4894 let score_threshold = 0.45;
4895 let iou_threshold = 0.45;
4896
4897 let mut boxes = Array2::zeros((10, 4));
4898 let mut scores = Array2::zeros((10, 1));
4899 let mut classes = Array2::zeros((10, 1));
4900 let mask: Array2<f64> = Array2::zeros((10, 32));
4901 let protos = Array3::<f64>::zeros((160, 160, 32));
4902 let protos = protos.insert_axis(Axis(0));
4903
4904 boxes
4905 .slice_mut(s![0, ..,])
4906 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4907 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4908 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4909
4910 let boxes = boxes.insert_axis(Axis(0));
4911 let mut scores = scores.insert_axis(Axis(0));
4912 let classes = classes.insert_axis(Axis(0));
4913 let mask = mask.insert_axis(Axis(0));
4914
4915 let config = "
4916decoder_version: yolo26
4917outputs:
4918 - type: boxes
4919 decoder: ultralytics
4920 quantization: [0.00784313725490196, 0]
4921 shape: [1, 10, 4]
4922 dshape:
4923 - [batch, 1]
4924 - [num_boxes, 10]
4925 - [box_coords, 4]
4926 normalized: true
4927 - type: scores
4928 decoder: ultralytics
4929 quantization: [0.00784313725490196, 0]
4930 shape: [1, 10, 1]
4931 dshape:
4932 - [batch, 1]
4933 - [num_boxes, 10]
4934 - [num_classes, 1]
4935 - type: classes
4936 decoder: ultralytics
4937 quantization: [0.00784313725490196, 0]
4938 shape: [1, 10, 1]
4939 dshape:
4940 - [batch, 1]
4941 - [num_boxes, 10]
4942 - [num_classes, 1]
4943 - type: mask_coefficients
4944 decoder: ultralytics
4945 quantization: [0.00784313725490196, 0]
4946 shape: [1, 10, 32]
4947 dshape:
4948 - [batch, 1]
4949 - [num_boxes, 10]
4950 - [num_protos, 32]
4951 - type: protos
4952 decoder: ultralytics
4953 quantization: [0.0039215686274509803921568627451, 128]
4954 shape: [1, 160, 160, 32]
4955 dshape:
4956 - [batch, 1]
4957 - [height, 160]
4958 - [width, 160]
4959 - [num_protos, 32]
4960";
4961
4962 let decoder = DecoderBuilder::default()
4963 .with_config_yaml_str(config.to_string())
4964 .with_score_threshold(score_threshold)
4965 .with_iou_threshold(iou_threshold)
4966 .build()
4967 .unwrap();
4968
4969 let expected_boxes = [DetectBox {
4971 bbox: BoundingBox {
4972 xmin: 0.1234,
4973 ymin: 0.1234,
4974 xmax: 0.2345,
4975 ymax: 0.2345,
4976 },
4977 score: 0.9876,
4978 label: 2,
4979 }];
4980
4981 let mut tracker = ByteTrackBuilder::new()
4982 .track_update(0.1)
4983 .track_high_conf(0.7)
4984 .build();
4985
4986 let mut output_boxes = Vec::with_capacity(50);
4987 let mut output_tracks = Vec::with_capacity(50);
4988
4989 decoder
4990 .decode_tracked_float_proto(
4991 &mut tracker,
4992 0,
4993 &[
4994 boxes.view().into_dyn(),
4995 scores.view().into_dyn(),
4996 classes.view().into_dyn(),
4997 mask.view().into_dyn(),
4998 protos.view().into_dyn(),
4999 ],
5000 &mut output_boxes,
5001 &mut output_tracks,
5002 )
5003 .unwrap();
5004
5005 assert_eq!(output_boxes.len(), 1);
5006 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
5007
5008 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5011 *score = 0.0; }
5013
5014 let protos = decoder
5015 .decode_tracked_float_proto(
5016 &mut tracker,
5017 100_000_000 / 3,
5018 &[
5019 boxes.view().into_dyn(),
5020 scores.view().into_dyn(),
5021 classes.view().into_dyn(),
5022 mask.view().into_dyn(),
5023 protos.view().into_dyn(),
5024 ],
5025 &mut output_boxes,
5026 &mut output_tracks,
5027 )
5028 .unwrap();
5029 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5030 assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()))
5032 }
5033
5034 #[test]
5035 fn test_decoder_tracked_linear_motion() {
5036 use crate::configs::{DecoderType, Nms};
5037 use crate::DecoderBuilder;
5038
5039 let score_threshold = 0.25;
5040 let iou_threshold = 0.1;
5041 let out = include_bytes!(concat!(
5042 env!("CARGO_MANIFEST_DIR"),
5043 "/../../testdata/yolov8s_80_classes.bin"
5044 ));
5045 let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
5046 let mut out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
5047 let quant = (0.0040811873, -123).into();
5048
5049 let decoder = DecoderBuilder::default()
5050 .with_config_yolo_det(
5051 crate::configs::Detection {
5052 decoder: DecoderType::Ultralytics,
5053 shape: vec![1, 84, 8400],
5054 anchors: None,
5055 quantization: Some(quant),
5056 dshape: vec![
5057 (crate::configs::DimName::Batch, 1),
5058 (crate::configs::DimName::NumFeatures, 84),
5059 (crate::configs::DimName::NumBoxes, 8400),
5060 ],
5061 normalized: Some(true),
5062 },
5063 None,
5064 )
5065 .with_score_threshold(score_threshold)
5066 .with_iou_threshold(iou_threshold)
5067 .with_nms(Some(Nms::ClassAgnostic))
5068 .build()
5069 .unwrap();
5070
5071 let mut expected_boxes = [
5072 DetectBox {
5073 bbox: BoundingBox {
5074 xmin: 0.5285137,
5075 ymin: 0.05305544,
5076 xmax: 0.87541467,
5077 ymax: 0.9998909,
5078 },
5079 score: 0.5591227,
5080 label: 0,
5081 },
5082 DetectBox {
5083 bbox: BoundingBox {
5084 xmin: 0.130598,
5085 ymin: 0.43260583,
5086 xmax: 0.35098213,
5087 ymax: 0.9958097,
5088 },
5089 score: 0.33057618,
5090 label: 75,
5091 },
5092 ];
5093
5094 let mut tracker = ByteTrackBuilder::new()
5095 .track_update(0.1)
5096 .track_high_conf(0.3)
5097 .build();
5098
5099 let mut output_boxes = Vec::with_capacity(50);
5100 let mut output_masks = Vec::with_capacity(50);
5101 let mut output_tracks = Vec::with_capacity(50);
5102
5103 decoder
5104 .decode_tracked_quantized(
5105 &mut tracker,
5106 0,
5107 &[out.view().into()],
5108 &mut output_boxes,
5109 &mut output_masks,
5110 &mut output_tracks,
5111 )
5112 .unwrap();
5113
5114 assert_eq!(output_boxes.len(), 2);
5115 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5116 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
5117
5118 for i in 1..=100 {
5119 let mut out = out.clone();
5120 let mut x_values = out.slice_mut(s![0, 0, ..]);
5122 for x in x_values.iter_mut() {
5123 *x = x.saturating_add((i as f32 * 1e-3 / quant.0).round() as i8);
5124 }
5125
5126 decoder
5127 .decode_tracked_quantized(
5128 &mut tracker,
5129 100_000_000 * i / 3, &[out.view().into()],
5131 &mut output_boxes,
5132 &mut output_masks,
5133 &mut output_tracks,
5134 )
5135 .unwrap();
5136
5137 assert_eq!(output_boxes.len(), 2);
5138 }
5139 let tracks = tracker.get_active_tracks();
5140 let predicted_boxes: Vec<_> = tracks
5141 .iter()
5142 .map(|track| {
5143 let mut l = track.last_box;
5144 l.bbox = track.info.tracked_location.into();
5145 l
5146 })
5147 .collect();
5148 expected_boxes[0].bbox.xmin += 0.1; expected_boxes[0].bbox.xmax += 0.1;
5150 expected_boxes[1].bbox.xmin += 0.1;
5151 expected_boxes[1].bbox.xmax += 0.1;
5152
5153 assert!(predicted_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
5154 assert!(predicted_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
5155
5156 let mut scores_values = out.slice_mut(s![0, 4.., ..]);
5158 for score in scores_values.iter_mut() {
5159 *score = i8::MIN; }
5161 decoder
5162 .decode_tracked_quantized(
5163 &mut tracker,
5164 100_000_000 * 101 / 3,
5165 &[out.view().into()],
5166 &mut output_boxes,
5167 &mut output_masks,
5168 &mut output_tracks,
5169 )
5170 .unwrap();
5171 expected_boxes[0].bbox.xmin += 0.001; expected_boxes[0].bbox.xmax += 0.001;
5173 expected_boxes[1].bbox.xmin += 0.001;
5174 expected_boxes[1].bbox.xmax += 0.001;
5175
5176 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
5177 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
5178 }
5179
5180 #[test]
5181 fn test_decoder_tracked_end_to_end_float() {
5182 let score_threshold = 0.45;
5183 let iou_threshold = 0.45;
5184
5185 let mut boxes = Array2::zeros((10, 4));
5186 let mut scores = Array2::zeros((10, 1));
5187 let mut classes = Array2::zeros((10, 1));
5188
5189 boxes
5190 .slice_mut(s![0, ..,])
5191 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5192 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5193 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5194
5195 let detect = ndarray::concatenate![Axis(1), boxes.view(), scores.view(), classes.view(),];
5196 let mut detect = detect.insert_axis(Axis(0));
5197 assert_eq!(detect.shape(), &[1, 10, 6]);
5198 let config = "
5199decoder_version: yolo26
5200outputs:
5201 - type: detection
5202 decoder: ultralytics
5203 quantization: [0.00784313725490196, 0]
5204 shape: [1, 10, 6]
5205 dshape:
5206 - [batch, 1]
5207 - [num_boxes, 10]
5208 - [num_features, 6]
5209 normalized: true
5210";
5211
5212 let decoder = DecoderBuilder::default()
5213 .with_config_yaml_str(config.to_string())
5214 .with_score_threshold(score_threshold)
5215 .with_iou_threshold(iou_threshold)
5216 .build()
5217 .unwrap();
5218
5219 let expected_boxes = [DetectBox {
5220 bbox: BoundingBox {
5221 xmin: 0.1234,
5222 ymin: 0.1234,
5223 xmax: 0.2345,
5224 ymax: 0.2345,
5225 },
5226 score: 0.9876,
5227 label: 2,
5228 }];
5229
5230 let mut tracker = ByteTrackBuilder::new()
5231 .track_update(0.1)
5232 .track_high_conf(0.7)
5233 .build();
5234
5235 let mut output_boxes = Vec::with_capacity(50);
5236 let mut output_masks = Vec::with_capacity(50);
5237 let mut output_tracks = Vec::with_capacity(50);
5238
5239 decoder
5240 .decode_tracked_float(
5241 &mut tracker,
5242 0,
5243 &[detect.view().into_dyn()],
5244 &mut output_boxes,
5245 &mut output_masks,
5246 &mut output_tracks,
5247 )
5248 .unwrap();
5249
5250 assert_eq!(output_boxes.len(), 1);
5251 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5252
5253 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5256 *score = 0.0; }
5258
5259 decoder
5260 .decode_tracked_float(
5261 &mut tracker,
5262 100_000_000 / 3,
5263 &[detect.view().into_dyn()],
5264 &mut output_boxes,
5265 &mut output_masks,
5266 &mut output_tracks,
5267 )
5268 .unwrap();
5269 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5270 }
5271}