1#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
68
69use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
70use num_traits::{AsPrimitive, Float, PrimInt};
71
72pub mod byte;
73pub mod error;
74pub mod float;
75pub mod modelpack;
76pub mod yolo;
77
78mod decoder;
79pub use decoder::*;
80
81pub use configs::{DecoderVersion, Nms};
82pub use error::{DecoderError, DecoderResult};
83
84use crate::{
85 decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
86 yolo::yolo_segmentation_to_mask,
87};
88
89pub trait BBoxTypeTrait {
91 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
93
94 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
96 input: &[B; 4],
97 quant: Quantization,
98 ) -> [A; 4]
99 where
100 f32: AsPrimitive<A>,
101 i32: AsPrimitive<A>;
102
103 #[inline(always)]
114 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
115 input: ArrayView1<B>,
116 ) -> [A; 4] {
117 Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
118 }
119
120 #[inline(always)]
121 fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
123 input: ArrayView1<B>,
124 quant: Quantization,
125 ) -> [A; 4]
126 where
127 f32: AsPrimitive<A>,
128 i32: AsPrimitive<A>,
129 {
130 Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
131 }
132}
133
134#[derive(Debug, Clone, Copy, PartialEq, Eq)]
136pub struct XYXY {}
137
138impl BBoxTypeTrait for XYXY {
139 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
140 input.map(|b| b.as_())
141 }
142
143 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
144 input: &[B; 4],
145 quant: Quantization,
146 ) -> [A; 4]
147 where
148 f32: AsPrimitive<A>,
149 i32: AsPrimitive<A>,
150 {
151 let scale = quant.scale.as_();
152 let zp = quant.zero_point.as_();
153 input.map(|b| (b.as_() - zp) * scale)
154 }
155
156 #[inline(always)]
157 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
158 input: ArrayView1<B>,
159 ) -> [A; 4] {
160 [
161 input[0].as_(),
162 input[1].as_(),
163 input[2].as_(),
164 input[3].as_(),
165 ]
166 }
167}
168
169#[derive(Debug, Clone, Copy, PartialEq, Eq)]
172pub struct XYWH {}
173
174impl BBoxTypeTrait for XYWH {
175 #[inline(always)]
176 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
177 let half = A::one() / (A::one() + A::one());
178 [
179 (input[0].as_()) - (input[2].as_() * half),
180 (input[1].as_()) - (input[3].as_() * half),
181 (input[0].as_()) + (input[2].as_() * half),
182 (input[1].as_()) + (input[3].as_() * half),
183 ]
184 }
185
186 #[inline(always)]
187 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
188 input: &[B; 4],
189 quant: Quantization,
190 ) -> [A; 4]
191 where
192 f32: AsPrimitive<A>,
193 i32: AsPrimitive<A>,
194 {
195 let scale = quant.scale.as_();
196 let half_scale = (quant.scale * 0.5).as_();
197 let zp = quant.zero_point.as_();
198 let [x, y, w, h] = [
199 (input[0].as_() - zp) * scale,
200 (input[1].as_() - zp) * scale,
201 (input[2].as_() - zp) * half_scale,
202 (input[3].as_() - zp) * half_scale,
203 ];
204
205 [x - w, y - h, x + w, y + h]
206 }
207
208 #[inline(always)]
209 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
210 input: ArrayView1<B>,
211 ) -> [A; 4] {
212 let half = A::one() / (A::one() + A::one());
213 [
214 (input[0].as_()) - (input[2].as_() * half),
215 (input[1].as_()) - (input[3].as_() * half),
216 (input[0].as_()) + (input[2].as_() * half),
217 (input[1].as_()) + (input[3].as_() * half),
218 ]
219 }
220}
221
/// Scale / zero-point pair describing a quantized tensor.
///
/// The dequantization code in this file maps a raw value `q` to
/// `(q - zero_point) * scale`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Quantization {
    /// Multiplicative factor applied after the zero point is removed.
    pub scale: f32,
    /// Raw value that corresponds to a real value of zero.
    pub zero_point: i32,
}

impl Quantization {
    /// Creates a quantization descriptor from its two parameters.
    pub fn new(scale: f32, zero_point: i32) -> Self {
        Quantization { scale, zero_point }
    }
}
242
243impl From<QuantTuple> for Quantization {
244 fn from(quant_tuple: QuantTuple) -> Quantization {
255 Quantization {
256 scale: quant_tuple.0,
257 zero_point: quant_tuple.1,
258 }
259 }
260}
261
262impl<S, Z> From<(S, Z)> for Quantization
263where
264 S: AsPrimitive<f32>,
265 Z: AsPrimitive<i32>,
266{
267 fn from((scale, zp): (S, Z)) -> Quantization {
276 Self {
277 scale: scale.as_(),
278 zero_point: zp.as_(),
279 }
280 }
281}
282
283impl Default for Quantization {
284 fn default() -> Self {
293 Self {
294 scale: 1.0,
295 zero_point: 0,
296 }
297 }
298}
299
300#[derive(Debug, Clone, Copy, PartialEq, Default)]
302pub struct DetectBox {
303 pub bbox: BoundingBox,
304 pub score: f32,
306 pub label: usize,
308}
309
/// Axis-aligned box stored as its minimum and maximum coordinates.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct BoundingBox {
    /// Minimum x coordinate.
    pub xmin: f32,
    /// Minimum y coordinate.
    pub ymin: f32,
    /// Maximum x coordinate.
    pub xmax: f32,
    /// Maximum y coordinate.
    pub ymax: f32,
}

impl BoundingBox {
    /// Creates a box from the four coordinates, stored exactly as given.
    pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }

    /// Returns a copy in which `xmin <= xmax` and `ymin <= ymax`, swapping
    /// the pair on an axis when they are stored in the opposite order.
    pub fn to_canonical(&self) -> Self {
        let (xmin, xmax) = (self.xmin.min(self.xmax), self.xmin.max(self.xmax));
        let (ymin, ymax) = (self.ymin.min(self.ymax), self.ymin.max(self.ymax));
        Self::new(xmin, ymin, xmax, ymax)
    }
}
355
356impl From<BoundingBox> for [f32; 4] {
357 fn from(b: BoundingBox) -> Self {
372 [b.xmin, b.ymin, b.xmax, b.ymax]
373 }
374}
375
376impl From<[f32; 4]> for BoundingBox {
377 fn from(arr: [f32; 4]) -> Self {
380 BoundingBox {
381 xmin: arr[0],
382 ymin: arr[1],
383 xmax: arr[2],
384 ymax: arr[3],
385 }
386 }
387}
388
389impl DetectBox {
390 pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
419 let eq_delta = |a: f32, b: f32| (a - b).abs() <= eps;
420 self.label == rhs.label
421 && eq_delta(self.score, rhs.score)
422 && eq_delta(self.bbox.xmin, rhs.bbox.xmin)
423 && eq_delta(self.bbox.ymin, rhs.bbox.ymin)
424 && eq_delta(self.bbox.xmax, rhs.bbox.xmax)
425 && eq_delta(self.bbox.ymax, rhs.bbox.ymax)
426 }
427}
428
429#[derive(Debug, Clone, PartialEq, Default)]
432pub struct Segmentation {
433 pub xmin: f32,
435 pub ymin: f32,
437 pub xmax: f32,
439 pub ymax: f32,
441 pub segmentation: Array3<u8>,
451}
452
453#[derive(Debug, Clone)]
459pub enum ProtoTensor {
460 Quantized {
464 protos: Array3<i8>,
465 quantization: Quantization,
466 },
467 Float(Array3<f32>),
469}
470
471impl ProtoTensor {
472 pub fn is_quantized(&self) -> bool {
474 matches!(self, ProtoTensor::Quantized { .. })
475 }
476
477 pub fn dim(&self) -> (usize, usize, usize) {
479 match self {
480 ProtoTensor::Quantized { protos, .. } => protos.dim(),
481 ProtoTensor::Float(arr) => arr.dim(),
482 }
483 }
484
485 pub fn as_f32(&self) -> std::borrow::Cow<'_, Array3<f32>> {
488 match self {
489 ProtoTensor::Float(arr) => std::borrow::Cow::Borrowed(arr),
490 ProtoTensor::Quantized {
491 protos,
492 quantization,
493 } => {
494 let scale = quantization.scale;
495 let zp = quantization.zero_point as f32;
496 std::borrow::Cow::Owned(protos.map(|&v| (v as f32 - zp) * scale))
497 }
498 }
499 }
500}
501
502#[derive(Debug, Clone)]
508pub struct ProtoData {
509 pub mask_coefficients: Vec<Vec<f32>>,
511 pub protos: ProtoTensor,
513}
514
515pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
533 detect: &DetectBoxQuantized<SCORE>,
534 quant_scores: Quantization,
535) -> DetectBox {
536 let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
537 DetectBox {
538 bbox: detect.bbox,
539 score: quant_scores.scale * detect.score.as_() + scaled_zp,
540 label: detect.label,
541 }
542}
543#[derive(Debug, Clone, Copy, PartialEq)]
545pub struct DetectBoxQuantized<
546 SCORE: PrimInt + AsPrimitive<f32>,
548> {
549 pub bbox: BoundingBox,
551 pub score: SCORE,
554 pub label: usize,
556}
557
558pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
571 input: ArrayView<T, D>,
572 quant: Quantization,
573) -> Array<F, D>
574where
575 i32: num_traits::AsPrimitive<F>,
576 f32: num_traits::AsPrimitive<F>,
577{
578 let zero_point = quant.zero_point.as_();
579 let scale = quant.scale.as_();
580 if zero_point != F::zero() {
581 let scaled_zero = -zero_point * scale;
582 input.mapv(|d| d.as_() * scale + scaled_zero)
583 } else {
584 input.mapv(|d| d.as_() * scale)
585 }
586}
587
588pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
601 input: &[T],
602 quant: Quantization,
603 output: &mut [F],
604) where
605 f32: num_traits::AsPrimitive<F>,
606 i32: num_traits::AsPrimitive<F>,
607{
608 assert!(input.len() == output.len());
609 let zero_point = quant.zero_point.as_();
610 let scale = quant.scale.as_();
611 if zero_point != F::zero() {
612 let scaled_zero = -zero_point * scale; input
614 .iter()
615 .zip(output)
616 .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
617 } else {
618 input
619 .iter()
620 .zip(output)
621 .for_each(|(d, deq)| *deq = d.as_() * scale);
622 }
623}
624
625pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
639 input: &[T],
640 quant: Quantization,
641 output: &mut [F],
642) where
643 f32: num_traits::AsPrimitive<F>,
644 i32: num_traits::AsPrimitive<F>,
645{
646 assert!(input.len() == output.len());
647 let zero_point = quant.zero_point.as_();
648 let scale = quant.scale.as_();
649
650 let input = input.as_chunks::<4>();
651 let output = output.as_chunks_mut::<4>();
652
653 if zero_point != F::zero() {
654 let scaled_zero = -zero_point * scale; input
657 .0
658 .iter()
659 .zip(output.0)
660 .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
661 input
662 .1
663 .iter()
664 .zip(output.1)
665 .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
666 } else {
667 input
668 .0
669 .iter()
670 .zip(output.0)
671 .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
672 input
673 .1
674 .iter()
675 .zip(output.1)
676 .for_each(|(d, deq)| *deq = d.as_() * scale);
677 }
678}
679
680pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
698 if segmentation.shape()[2] == 0 {
699 return Err(DecoderError::InvalidShape(
700 "Segmentation tensor must have non-zero depth".to_string(),
701 ));
702 }
703 if segmentation.shape()[2] == 1 {
704 yolo_segmentation_to_mask(segmentation, 128)
705 } else {
706 Ok(modelpack_segmentation_to_mask(segmentation))
707 }
708}
709
710fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
712 score
713 .iter()
714 .enumerate()
715 .fold((score[0], 0), |(max, arg_max), (ind, s)| {
716 if max > *s {
717 (max, arg_max)
718 } else {
719 (*s, ind)
720 }
721 })
722}
723#[cfg(test)]
724#[cfg_attr(coverage_nightly, coverage(off))]
725mod decoder_tests {
726 #![allow(clippy::excessive_precision)]
727 use crate::{
728 configs::{DecoderType, DimName, Protos},
729 modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
730 yolo::{
731 decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
732 decode_yolo_segdet_quant,
733 },
734 *,
735 };
736 use ndarray::{array, s, Array4};
737 use ndarray_stats::DeviationExt;
738
739 fn compare_outputs(
740 boxes: (&[DetectBox], &[DetectBox]),
741 masks: (&[Segmentation], &[Segmentation]),
742 ) {
743 let (boxes0, boxes1) = boxes;
744 let (masks0, masks1) = masks;
745
746 assert_eq!(boxes0.len(), boxes1.len());
747 assert_eq!(masks0.len(), masks1.len());
748
749 for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
750 assert!(
751 b_i8.equal_within_delta(b_f32, 1e-6),
752 "{b_i8:?} is not equal to {b_f32:?}"
753 );
754 }
755
756 for (m_i8, m_f32) in masks0.iter().zip(masks1) {
757 assert_eq!(
758 [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
759 [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
760 );
761 assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
762 let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
763 let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
764 let diff = &mask_i8 - &mask_f32;
765 for x in 0..diff.shape()[0] {
766 for y in 0..diff.shape()[1] {
767 for z in 0..diff.shape()[2] {
768 let val = diff[[x, y, z]];
769 assert!(
770 val.abs() <= 1,
771 "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
772 x,
773 y,
774 z,
775 val
776 );
777 }
778 }
779 }
780 let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
781 assert!(
782 mean_sq_err < 1e-2,
783 "Mean Square Error between masks was greater than 1%: {:.2}%",
784 mean_sq_err * 100.0
785 );
786 }
787 }
788
/// ModelPack detection: the direct `decode_modelpack_det` call, the
/// configured decoder on quantized input, and the configured decoder on
/// dequantized input must all produce the same boxes.
#[test]
fn test_decoder_modelpack() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let boxes_bytes = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_boxes_1935x1x4.bin"
    ));
    let boxes_q = Array4::from_shape_vec((1, 1935, 1, 4), boxes_bytes.to_vec()).unwrap();

    let scores_bytes = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_scores_1935x1.bin"
    ));
    let scores_q = Array3::from_shape_vec((1, 1935, 1), scores_bytes.to_vec()).unwrap();

    // Quantization as stored in the config; converted for direct calls below.
    let boxes_qt = (0.004656755365431309, 21).into();
    let scores_qt = (0.0019603664986789227, 0).into();

    let boxes_cfg = configs::Boxes {
        decoder: DecoderType::ModelPack,
        quantization: Some(boxes_qt),
        shape: vec![1, 1935, 1, 4],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumBoxes, 1935),
            (DimName::Padding, 1),
            (DimName::BoxCoords, 4),
        ],
        normalized: Some(true),
    };
    let scores_cfg = configs::Scores {
        decoder: DecoderType::ModelPack,
        quantization: Some(scores_qt),
        shape: vec![1, 1935, 1],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumBoxes, 1935),
            (DimName::NumClasses, 1),
        ],
    };

    let decoder = DecoderBuilder::default()
        .with_config_modelpack_det(boxes_cfg, scores_cfg)
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let boxes_quant = boxes_qt.into();
    let scores_quant = scores_qt.into();

    // Reference result from the low-level decode function.
    let mut direct_boxes: Vec<_> = Vec::with_capacity(50);
    decode_modelpack_det(
        (boxes_q.slice(s![0, .., 0, ..]), boxes_quant),
        (scores_q.slice(s![0, .., ..]), scores_quant),
        score_threshold,
        iou_threshold,
        &mut direct_boxes,
    );
    let expected = DetectBox {
        bbox: BoundingBox {
            xmin: 0.40513772,
            ymin: 0.6379755,
            xmax: 0.5122431,
            ymax: 0.7730214,
        },
        score: 0.4861709,
        label: 0,
    };
    assert!(direct_boxes[0].equal_within_delta(&expected, 1e-6));

    // Same input through the configured decoder, quantized path.
    let mut quant_boxes_out = Vec::with_capacity(50);
    let mut quant_masks_out = Vec::with_capacity(50);
    decoder
        .decode_quantized(
            &[boxes_q.view().into(), scores_q.view().into()],
            &mut quant_boxes_out,
            &mut quant_masks_out,
        )
        .unwrap();

    // Same input through the configured decoder, float path.
    let mut float_boxes_out = Vec::with_capacity(50);
    let mut float_masks_out = Vec::with_capacity(50);

    let boxes_f = dequantize_ndarray(boxes_q.view(), boxes_quant);
    let scores_f = dequantize_ndarray(scores_q.view(), scores_quant);

    decoder
        .decode_float::<f32>(
            &[boxes_f.view().into_dyn(), scores_f.view().into_dyn()],
            &mut float_boxes_out,
            &mut float_masks_out,
        )
        .unwrap();

    compare_outputs((&direct_boxes, &quant_boxes_out), (&[], &quant_masks_out));
    compare_outputs((&direct_boxes, &float_boxes_out), (&[], &float_masks_out));
}
894
/// ModelPack split detection on `u8` fixtures: the direct
/// `decode_modelpack_split_quant` call and the configured decoder (quantized
/// and float paths) must agree.
#[test]
fn test_decoder_modelpack_split_u8() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let small_bytes = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split_9x15x18.bin"
    ));
    let detect0 = Array4::from_shape_vec((1, 9, 15, 18), small_bytes.to_vec()).unwrap();

    let large_bytes = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split_17x30x18.bin"
    ));
    let detect1 = Array4::from_shape_vec((1, 17, 30, 18), large_bytes.to_vec()).unwrap();

    let qt0 = (0.08547406643629074, 174).into();
    let qt1 = (0.09929127991199493, 183).into();
    let anchors0 = vec![
        [0.36666667461395264, 0.31481480598449707],
        [0.38749998807907104, 0.4740740656852722],
        [0.5333333611488342, 0.644444465637207],
    ];
    let anchors1 = vec![
        [0.13750000298023224, 0.2074074000120163],
        [0.2541666626930237, 0.21481481194496155],
        [0.23125000298023224, 0.35185185074806213],
    ];

    let detect_config0 = configs::Detection {
        decoder: DecoderType::ModelPack,
        shape: vec![1, 9, 15, 18],
        anchors: Some(anchors0),
        quantization: Some(qt0),
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 9),
            (DimName::Width, 15),
            (DimName::NumAnchorsXFeatures, 18),
        ],
        normalized: Some(true),
    };

    let detect_config1 = configs::Detection {
        decoder: DecoderType::ModelPack,
        shape: vec![1, 17, 30, 18],
        anchors: Some(anchors1),
        quantization: Some(qt1),
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 17),
            (DimName::Width, 30),
            (DimName::NumAnchorsXFeatures, 18),
        ],
        normalized: Some(true),
    };

    // Borrowed conversions must happen before the configs move into the builder.
    let config0 = (&detect_config0).try_into().unwrap();
    let config1 = (&detect_config1).try_into().unwrap();

    // The builder receives the configs in the opposite order to the inputs.
    let decoder = DecoderBuilder::default()
        .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let quant0 = qt0.into();
    let quant1 = qt1.into();

    let mut direct_boxes: Vec<_> = Vec::with_capacity(10);
    decode_modelpack_split_quant(
        &[
            detect0.slice(s![0, .., .., ..]),
            detect1.slice(s![0, .., .., ..]),
        ],
        &[config0, config1],
        score_threshold,
        iou_threshold,
        &mut direct_boxes,
    );
    let expected = DetectBox {
        bbox: BoundingBox {
            xmin: 0.43171933,
            ymin: 0.68243736,
            xmax: 0.5626645,
            ymax: 0.808863,
        },
        score: 0.99240804,
        label: 0,
    };
    assert!(direct_boxes[0].equal_within_delta(&expected, 1e-6));

    let mut quant_boxes_out: Vec<_> = Vec::with_capacity(10);
    let mut quant_masks_out: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(
            &[detect0.view().into(), detect1.view().into()],
            &mut quant_boxes_out,
            &mut quant_masks_out,
        )
        .unwrap();

    let mut float_boxes_out: Vec<_> = Vec::with_capacity(10);
    let mut float_masks_out: Vec<_> = Vec::with_capacity(10);

    let detect0_f = dequantize_ndarray(detect0.view(), quant0);
    let detect1_f = dequantize_ndarray(detect1.view(), quant1);
    decoder
        .decode_float::<f32>(
            &[detect0_f.view().into_dyn(), detect1_f.view().into_dyn()],
            &mut float_boxes_out,
            &mut float_masks_out,
        )
        .unwrap();

    compare_outputs((&direct_boxes, &quant_boxes_out), (&[], &quant_masks_out));
    compare_outputs((&direct_boxes, &float_boxes_out), (&[], &float_masks_out));
}
1019
/// Builds the split ModelPack decoder from the YAML fixture and checks the
/// first decoded box against the known-good detection.
#[test]
fn test_decoder_parse_config_modelpack_split_u8() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let small_bytes = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split_9x15x18.bin"
    ));
    let detect0 = Array4::from_shape_vec((1, 9, 15, 18), small_bytes.to_vec()).unwrap();

    let large_bytes = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split_17x30x18.bin"
    ));
    let detect1 = Array4::from_shape_vec((1, 17, 30, 18), large_bytes.to_vec()).unwrap();

    let yaml = include_str!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split.yaml"
    ))
    .to_string();

    let decoder = DecoderBuilder::default()
        .with_config_yaml_str(yaml)
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let mut boxes_out: Vec<_> = Vec::with_capacity(10);
    let mut masks_out: Vec<_> = Vec::with_capacity(10);
    // Inputs are passed larger-grid first.
    decoder
        .decode_quantized(
            &[
                ArrayViewDQuantized::from(detect1.view()),
                ArrayViewDQuantized::from(detect0.view()),
            ],
            &mut boxes_out,
            &mut masks_out,
        )
        .unwrap();

    let expected = DetectBox {
        bbox: BoundingBox {
            xmin: 0.43171933,
            ymin: 0.68243736,
            xmax: 0.5626645,
            ymax: 0.808863,
        },
        score: 0.99240804,
        label: 0,
    };
    assert!(boxes_out[0].equal_within_delta(&expected, 1e-6));
}
1075
/// ModelPack segmentation only: the quantized path must reproduce the input
/// tensor (axes moved to height, width, classes), and the float path must
/// yield the same final mask.
#[test]
fn test_modelpack_seg() {
    let seg_bytes = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_seg_2x160x160.bin"
    ));
    let out = Array4::from_shape_vec((1, 2, 160, 160), seg_bytes.to_vec()).unwrap();
    let quant = (1.0 / 255.0, 0).into();

    let seg_cfg = configs::Segmentation {
        decoder: DecoderType::ModelPack,
        quantization: Some(quant),
        shape: vec![1, 2, 160, 160],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumClasses, 2),
            (DimName::Height, 160),
            (DimName::Width, 160),
        ],
    };
    let decoder = DecoderBuilder::default()
        .with_config_modelpack_seg(seg_cfg)
        .build()
        .unwrap();

    let mut boxes_out: Vec<_> = Vec::with_capacity(10);
    let mut masks_out: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(&[out.view().into()], &mut boxes_out, &mut masks_out)
        .unwrap();

    // Expected mask: batch 0 with axes reordered from (C, H, W) to (H, W, C).
    let mut hwc = out.slice(s![0, .., .., ..]);
    hwc.swap_axes(0, 1);
    hwc.swap_axes(1, 2);
    let expected_masks = [Segmentation {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 1.0,
        ymax: 1.0,
        segmentation: hwc.into_owned(),
    }];
    compare_outputs((&[], &boxes_out), (&expected_masks, &masks_out));

    let out_f = dequantize_ndarray(out.view(), quant.into());
    decoder
        .decode_float::<f32>(&[out_f.view().into_dyn()], &mut boxes_out, &mut masks_out)
        .unwrap();

    // Float-path masks are compared only after collapsing to class masks.
    compare_outputs((&[], &boxes_out), (&[], &[]));
    let mask0 = segmentation_to_mask(expected_masks[0].segmentation.view()).unwrap();
    let mask1 = segmentation_to_mask(masks_out[0].segmentation.view()).unwrap();

    assert_eq!(mask0, mask1);
}
/// ModelPack segmentation across integer widths: the same fixture re-encoded
/// as i8/u16/i16/u32/i32 must collapse to the same mask as the u8 original.
#[test]
fn test_modelpack_seg_quant() {
    let seg_bytes = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_seg_2x160x160.bin"
    ));
    let out_u8 = Array4::from_shape_vec((1, 2, 160, 160), seg_bytes.to_vec()).unwrap();
    // Signed variants are shifted down by half the range; wider variants
    // place the byte in the most significant position.
    let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
    let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
    let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
    let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
    let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);

    let quant = (1.0 / 255.0, 0).into();

    let seg_cfg = configs::Segmentation {
        decoder: DecoderType::ModelPack,
        quantization: Some(quant),
        shape: vec![1, 2, 160, 160],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumClasses, 2),
            (DimName::Height, 160),
            (DimName::Width, 160),
        ],
    };
    let decoder = DecoderBuilder::default()
        .with_config_modelpack_seg(seg_cfg)
        .build()
        .unwrap();

    // One shared box buffer; a separate mask buffer per element type.
    let mut boxes_out: Vec<_> = Vec::with_capacity(10);

    let mut masks_u8: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(&[out_u8.view().into()], &mut boxes_out, &mut masks_u8)
        .unwrap();

    let mut masks_i8: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(&[out_i8.view().into()], &mut boxes_out, &mut masks_i8)
        .unwrap();

    let mut masks_u16: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(&[out_u16.view().into()], &mut boxes_out, &mut masks_u16)
        .unwrap();

    let mut masks_i16: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(&[out_i16.view().into()], &mut boxes_out, &mut masks_i16)
        .unwrap();

    let mut masks_u32: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(&[out_u32.view().into()], &mut boxes_out, &mut masks_u32)
        .unwrap();

    let mut masks_i32: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(&[out_i32.view().into()], &mut boxes_out, &mut masks_i32)
        .unwrap();

    compare_outputs((&[], &boxes_out), (&[], &[]));

    let reference = segmentation_to_mask(masks_u8[0].segmentation.view()).unwrap();
    let from_i8 = segmentation_to_mask(masks_i8[0].segmentation.view()).unwrap();
    let from_u16 = segmentation_to_mask(masks_u16[0].segmentation.view()).unwrap();
    let from_i16 = segmentation_to_mask(masks_i16[0].segmentation.view()).unwrap();
    let from_u32 = segmentation_to_mask(masks_u32[0].segmentation.view()).unwrap();
    let from_i32 = segmentation_to_mask(masks_i32[0].segmentation.view()).unwrap();
    assert_eq!(reference, from_i8);
    assert_eq!(reference, from_u16);
    assert_eq!(reference, from_i16);
    assert_eq!(reference, from_u32);
    assert_eq!(reference, from_i32);
}
1235
/// ModelPack detection + segmentation with separate boxes/scores tensors,
/// checked on both the quantized and the float decode paths.
#[test]
fn test_modelpack_segdet() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let boxes_bytes = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_boxes_1935x1x4.bin"
    ));
    let boxes_q = Array4::from_shape_vec((1, 1935, 1, 4), boxes_bytes.to_vec()).unwrap();

    let scores_bytes = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_scores_1935x1.bin"
    ));
    let scores_q = Array3::from_shape_vec((1, 1935, 1), scores_bytes.to_vec()).unwrap();

    let seg_bytes = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_seg_2x160x160.bin"
    ));
    let seg_q = Array4::from_shape_vec((1, 2, 160, 160), seg_bytes.to_vec()).unwrap();

    let quant_boxes = (0.004656755365431309, 21).into();
    let quant_scores = (0.0019603664986789227, 0).into();
    let quant_seg = (1.0 / 255.0, 0).into();

    let boxes_cfg = configs::Boxes {
        decoder: DecoderType::ModelPack,
        quantization: Some(quant_boxes),
        shape: vec![1, 1935, 1, 4],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumBoxes, 1935),
            (DimName::Padding, 1),
            (DimName::BoxCoords, 4),
        ],
        normalized: Some(true),
    };
    let scores_cfg = configs::Scores {
        decoder: DecoderType::ModelPack,
        quantization: Some(quant_scores),
        shape: vec![1, 1935, 1],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumBoxes, 1935),
            (DimName::NumClasses, 1),
        ],
    };
    let seg_cfg = configs::Segmentation {
        decoder: DecoderType::ModelPack,
        quantization: Some(quant_seg),
        shape: vec![1, 2, 160, 160],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumClasses, 2),
            (DimName::Height, 160),
            (DimName::Width, 160),
        ],
    };

    let decoder = DecoderBuilder::default()
        .with_config_modelpack_segdet(boxes_cfg, scores_cfg, seg_cfg)
        .with_iou_threshold(iou_threshold)
        .with_score_threshold(score_threshold)
        .build()
        .unwrap();

    let mut boxes_out: Vec<_> = Vec::with_capacity(10);
    let mut masks_out: Vec<_> = Vec::with_capacity(10);
    // Input order here is scores, boxes, seg.
    decoder
        .decode_quantized(
            &[
                scores_q.view().into(),
                boxes_q.view().into(),
                seg_q.view().into(),
            ],
            &mut boxes_out,
            &mut masks_out,
        )
        .unwrap();

    // Expected mask: batch 0 with axes reordered from (C, H, W) to (H, W, C).
    let mut hwc = seg_q.slice(s![0, .., .., ..]);
    hwc.swap_axes(0, 1);
    hwc.swap_axes(1, 2);
    let expected_masks = [Segmentation {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 1.0,
        ymax: 1.0,
        segmentation: hwc.into_owned(),
    }];
    let correct_boxes = [DetectBox {
        bbox: BoundingBox {
            xmin: 0.40513772,
            ymin: 0.6379755,
            xmax: 0.5122431,
            ymax: 0.7730214,
        },
        score: 0.4861709,
        label: 0,
    }];
    compare_outputs((&correct_boxes, &boxes_out), (&expected_masks, &masks_out));

    let scores_f = dequantize_ndarray(scores_q.view(), quant_scores.into());
    let boxes_f = dequantize_ndarray(boxes_q.view(), quant_boxes.into());
    let seg_f = dequantize_ndarray(seg_q.view(), quant_seg.into());
    decoder
        .decode_float::<f32>(
            &[
                scores_f.view().into_dyn(),
                boxes_f.view().into_dyn(),
                seg_f.view().into_dyn(),
            ],
            &mut boxes_out,
            &mut masks_out,
        )
        .unwrap();

    // Float-path masks are compared only after collapsing to class masks.
    compare_outputs((&correct_boxes, &boxes_out), (&[], &[]));
    let mask0 = segmentation_to_mask(expected_masks[0].segmentation.view()).unwrap();
    let mask1 = segmentation_to_mask(masks_out[0].segmentation.view()).unwrap();

    assert_eq!(mask0, mask1);
}
1361
/// ModelPack split detection + segmentation, checked on both the quantized
/// and the float decode paths.
#[test]
fn test_modelpack_segdet_split() {
    let score_threshold = 0.8;
    let iou_threshold = 0.5;

    let seg_bytes = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_seg_2x160x160.bin"
    ));
    let seg_q = Array4::from_shape_vec((1, 2, 160, 160), seg_bytes.to_vec()).unwrap();

    let small_bytes = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split_9x15x18.bin"
    ));
    let detect0_q = Array4::from_shape_vec((1, 9, 15, 18), small_bytes.to_vec()).unwrap();

    let large_bytes = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split_17x30x18.bin"
    ));
    let detect1_q = Array4::from_shape_vec((1, 17, 30, 18), large_bytes.to_vec()).unwrap();

    let quant0 = (0.08547406643629074, 174).into();
    let quant1 = (0.09929127991199493, 183).into();
    let quant_seg = (1.0 / 255.0, 0).into();

    let anchors0 = vec![
        [0.36666667461395264, 0.31481480598449707],
        [0.38749998807907104, 0.4740740656852722],
        [0.5333333611488342, 0.644444465637207],
    ];
    let anchors1 = vec![
        [0.13750000298023224, 0.2074074000120163],
        [0.2541666626930237, 0.21481481194496155],
        [0.23125000298023224, 0.35185185074806213],
    ];

    let large_cfg = configs::Detection {
        decoder: DecoderType::ModelPack,
        shape: vec![1, 17, 30, 18],
        anchors: Some(anchors1),
        quantization: Some(quant1),
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 17),
            (DimName::Width, 30),
            (DimName::NumAnchorsXFeatures, 18),
        ],
        normalized: Some(true),
    };
    let small_cfg = configs::Detection {
        decoder: DecoderType::ModelPack,
        shape: vec![1, 9, 15, 18],
        anchors: Some(anchors0),
        quantization: Some(quant0),
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 9),
            (DimName::Width, 15),
            (DimName::NumAnchorsXFeatures, 18),
        ],
        normalized: Some(true),
    };
    let seg_cfg = configs::Segmentation {
        decoder: DecoderType::ModelPack,
        quantization: Some(quant_seg),
        shape: vec![1, 2, 160, 160],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumClasses, 2),
            (DimName::Height, 160),
            (DimName::Width, 160),
        ],
    };

    // Detection configs are registered larger-grid first.
    let decoder = DecoderBuilder::default()
        .with_config_modelpack_segdet_split(vec![large_cfg, small_cfg], seg_cfg)
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let mut boxes_out: Vec<_> = Vec::with_capacity(10);
    let mut masks_out: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(
            &[
                detect0_q.view().into(),
                detect1_q.view().into(),
                seg_q.view().into(),
            ],
            &mut boxes_out,
            &mut masks_out,
        )
        .unwrap();

    // Expected mask: batch 0 with axes reordered from (C, H, W) to (H, W, C).
    let mut hwc = seg_q.slice(s![0, .., .., ..]);
    hwc.swap_axes(0, 1);
    hwc.swap_axes(1, 2);
    let expected_masks = [Segmentation {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 1.0,
        ymax: 1.0,
        segmentation: hwc.into_owned(),
    }];
    let correct_boxes = [DetectBox {
        bbox: BoundingBox {
            xmin: 0.43171933,
            ymin: 0.68243736,
            xmax: 0.5626645,
            ymax: 0.808863,
        },
        score: 0.99240804,
        label: 0,
    }];
    println!("Output Boxes: {:?}", boxes_out);
    compare_outputs((&correct_boxes, &boxes_out), (&expected_masks, &masks_out));

    let detect0_f = dequantize_ndarray(detect0_q.view(), quant0.into());
    let detect1_f = dequantize_ndarray(detect1_q.view(), quant1.into());
    let seg_f = dequantize_ndarray(seg_q.view(), quant_seg.into());
    decoder
        .decode_float::<f32>(
            &[
                detect0_f.view().into_dyn(),
                detect1_f.view().into_dyn(),
                seg_f.view().into_dyn(),
            ],
            &mut boxes_out,
            &mut masks_out,
        )
        .unwrap();

    // Float-path masks are compared only after collapsing to class masks.
    compare_outputs((&correct_boxes, &boxes_out), (&[], &[]));
    let mask0 = segmentation_to_mask(expected_masks[0].segmentation.view()).unwrap();
    let mask1 = segmentation_to_mask(masks_out[0].segmentation.view()).unwrap();

    assert_eq!(mask0, mask1);
}
1509
/// Checks that `dequantize_cpu_chunked` writes exactly the same values as
/// `dequantize_cpu` for the same quantized input.
///
/// The comparison is performed twice with the same scale: once with a zero
/// point of `-123` and once with a zero point of `0`.
#[test]
fn test_dequant_chunked() {
    let raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8s_80_classes.bin"
    ));
    // `u8 as i8` keeps the bit pattern, so the bytes can be reinterpreted as
    // signed values without an `unsafe` pointer cast.
    let mut quantized: Vec<i8> = raw.iter().map(|&byte| byte as i8).collect();
    // One extra element is appended; the output buffers below are sized
    // 84 * 8400 + 1 (an odd count) -- presumably so the chunked implementation
    // also has to process a remainder.
    quantized.push(123);

    let mut reference = vec![0.0; 84 * 8400 + 1];
    let mut chunked = vec![0.0; 84 * 8400 + 1];

    for zero_point in [-123, 0] {
        let quant = Quantization::new(0.0040811873, zero_point);
        dequantize_cpu(&quantized, quant, &mut reference);
        dequantize_cpu_chunked(&quantized, quant, &mut chunked);
        assert_eq!(reference, chunked);
    }
}
1534
/// Decodes one quantized YOLO detection tensor through three paths and checks
/// that they agree.
///
/// The paths are the free function `decode_yolo_det`, the configured decoder's
/// `decode_quantized`, and the configured decoder's `decode_float` on the
/// dequantized tensor. The first two boxes from `decode_yolo_det` are also
/// checked against fixed expected values.
#[test]
fn test_decoder_yolo_det() {
    let score_threshold = 0.25;
    let iou_threshold = 0.7;

    let raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8s_80_classes.bin"
    ));
    // `u8 as i8` keeps the bit pattern, so no `unsafe` reinterpretation is needed.
    let quantized: Vec<i8> = raw.iter().map(|&byte| byte as i8).collect();
    let out = Array3::from_shape_vec((1, 84, 8400), quantized).unwrap();
    let quant = (0.0040811873, -123).into();

    let detection_config = configs::Detection {
        decoder: DecoderType::Ultralytics,
        shape: vec![1, 84, 8400],
        anchors: None,
        quantization: Some(quant),
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumFeatures, 84),
            (DimName::NumBoxes, 8400),
        ],
        normalized: Some(true),
    };
    let decoder = DecoderBuilder::default()
        .with_config_yolo_det(detection_config, Some(DecoderVersion::Yolo11))
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Path 1: free function on the 2-D slice of the quantized tensor.
    let mut direct_boxes: Vec<_> = Vec::with_capacity(50);
    decode_yolo_det(
        (out.slice(s![0, .., ..]), quant.into()),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut direct_boxes,
    );

    let expected = [
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.5285137,
                ymin: 0.05305544,
                xmax: 0.87541467,
                ymax: 0.9998909,
            },
            score: 0.5591227,
            label: 0,
        },
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.130598,
                ymin: 0.43260583,
                xmax: 0.35098213,
                ymax: 0.9958097,
            },
            score: 0.33057618,
            label: 75,
        },
    ];
    // Indexing (rather than zipping) keeps the test failing loudly when fewer
    // boxes than expected are produced.
    for (index, want) in expected.iter().enumerate() {
        assert!(direct_boxes[index].equal_within_delta(want, 1e-6));
    }

    // Path 2: configured decoder on the quantized tensor.
    let mut decoder_boxes: Vec<_> = Vec::with_capacity(50);
    let mut decoder_masks: Vec<_> = Vec::with_capacity(50);
    decoder
        .decode_quantized(&[out.view().into()], &mut decoder_boxes, &mut decoder_masks)
        .unwrap();

    // Path 3: configured decoder on the dequantized tensor.
    let out = dequantize_ndarray(out.view(), quant.into());
    let mut float_boxes: Vec<_> = Vec::with_capacity(50);
    let mut float_masks: Vec<_> = Vec::with_capacity(50);
    decoder
        .decode_float::<f32>(&[out.view().into_dyn()], &mut float_boxes, &mut float_masks)
        .unwrap();

    compare_outputs((&direct_boxes, &decoder_boxes), (&[], &decoder_masks));
    compare_outputs((&direct_boxes, &float_boxes), (&[], &float_masks));
}
1624
/// Decodes YOLO segmentation output with `decode_yolo_segdet_float` and checks
/// both the boxes and the second mask.
///
/// The test expects exactly two detections with one mask each, compares both
/// boxes against fixed values, and compares the second mask against the
/// matching crop of a stored 160x160 mask.
#[test]
fn test_decoder_masks() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let boxes_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ));
    // `u8 as i8` keeps the bit pattern, so no `unsafe` reinterpretation is needed.
    let boxes_i8: Vec<i8> = boxes_raw.iter().map(|&byte| byte as i8).collect();
    let boxes = ndarray::Array2::from_shape_vec((116, 8400), boxes_i8).unwrap();
    let quant_boxes = Quantization::new(0.021287761628627777, 31);

    let protos_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    ));
    let protos_i8: Vec<i8> = protos_raw.iter().map(|&byte| byte as i8).collect();
    let protos = ndarray::Array3::from_shape_vec((160, 160, 32), protos_i8).unwrap();
    let quant_protos = Quantization::new(0.02491161972284317, -117);

    let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
    let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);

    let mut output_boxes: Vec<_> = Vec::with_capacity(10);
    let mut output_masks: Vec<_> = Vec::with_capacity(10);
    decode_yolo_segdet_float(
        seg.view(),
        protos.view(),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes,
        &mut output_masks,
    )
    .unwrap();
    assert_eq!(output_boxes.len(), 2);
    assert_eq!(output_boxes.len(), output_masks.len());

    // Every box coordinate must be at least the matching coordinate of the
    // region reported for its mask.
    for (detection, mask) in output_boxes.iter().zip(&output_masks) {
        assert!(detection.bbox.xmin >= mask.xmin);
        assert!(detection.bbox.ymin >= mask.ymin);
        assert!(detection.bbox.xmax >= mask.xmax);
        assert!(detection.bbox.ymax >= mask.ymax);
    }

    let expected = [
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.08515105,
                ymin: 0.7131401,
                xmax: 0.29802868,
                ymax: 0.8195788,
            },
            score: 0.91537374,
            label: 23,
        },
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.59605736,
                ymin: 0.25545314,
                xmax: 0.93666154,
                ymax: 0.72378385,
            },
            score: 0.91537374,
            label: 23,
        },
    ];
    // The tolerance is 1/160, i.e. one cell of the 160-wide prototype grid.
    for (index, want) in expected.iter().enumerate() {
        assert!(output_boxes[index].equal_within_delta(want, 1.0 / 160.0));
    }

    let full_mask = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_mask_results.bin"
    ));
    let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();

    // Crop the stored mask to the region of the second detection's mask,
    // scaling the region coordinates by 160 to get pixel indices.
    let cropped_mask = full_mask.slice(ndarray::s![
        (output_masks[1].ymin * 160.0) as usize..(output_masks[1].ymax * 160.0) as usize,
        (output_masks[1].xmin * 160.0) as usize..(output_masks[1].xmax * 160.0) as usize,
    ]);

    assert_eq!(
        cropped_mask,
        segmentation_to_mask(output_masks[1].segmentation.view()).unwrap()
    );
}
1712
/// Checks that a config-driven decoder fed prototypes with the channel axis
/// first produces the same boxes and masks as the reference function fed the
/// original height/width/channel prototypes.
///
/// The reference result comes from `decode_yolo_segdet_float` on the
/// `(160, 160, 32)` prototypes. The decoder under test is configured with
/// shape `[1, 32, 160, 160]` and an empty `dshape`, and receives the same data
/// with its axes permuted to that shape.
#[test]
fn test_decoder_masks_nchw_protos() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let boxes_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ));
    // `u8 as i8` keeps the bit pattern, so no `unsafe` reinterpretation is needed.
    let boxes_i8: Vec<i8> = boxes_raw.iter().map(|&byte| byte as i8).collect();
    let boxes_2d = ndarray::Array2::from_shape_vec((116, 8400), boxes_i8).unwrap();
    let quant_boxes = Quantization::new(0.021287761628627777, 31);

    let protos_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    ));
    let protos_i8: Vec<i8> = protos_raw.iter().map(|&byte| byte as i8).collect();
    let protos_hwc = ndarray::Array3::from_shape_vec((160, 160, 32), protos_i8).unwrap();
    let quant_protos = Quantization::new(0.02491161972284317, -117);
    let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);

    let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);

    // Reference result: direct call with the (160, 160, 32) prototypes.
    let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
    let mut ref_masks: Vec<_> = Vec::with_capacity(10);
    decode_yolo_segdet_float(
        seg.view(),
        protos_f32_hwc.view(),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut ref_boxes,
        &mut ref_masks,
    )
    .unwrap();
    assert_eq!(ref_boxes.len(), 2);

    // Move the 32-element axis to the front: (160, 160, 32) -> (32, 160, 160),
    // then add a leading axis of length 1 to both tensors so they match the
    // configured shapes [1, 32, 160, 160] and [1, 116, 8400].
    let protos_f32_chw = protos_f32_hwc.permuted_axes([2, 0, 1]);
    let protos_nchw = protos_f32_chw.insert_axis(ndarray::Axis(0));
    let seg_3d = seg.insert_axis(ndarray::Axis(0));

    // `dshape` is intentionally left empty for both outputs, so the decoder is
    // given only the raw `shape` vectors; no decoder version is specified.
    let decoder = DecoderBuilder::default()
        .with_config_yolo_segdet(
            configs::Detection {
                decoder: configs::DecoderType::Ultralytics,
                quantization: None,
                shape: vec![1, 116, 8400],
                dshape: vec![],
                normalized: Some(true),
                anchors: None,
            },
            configs::Protos {
                decoder: configs::DecoderType::Ultralytics,
                quantization: None,
                shape: vec![1, 32, 160, 160],
                dshape: vec![],
            },
            None,
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let mut cfg_boxes: Vec<_> = Vec::with_capacity(10);
    let mut cfg_masks: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_float(
            &[seg_3d.view().into_dyn(), protos_nchw.view().into_dyn()],
            &mut cfg_boxes,
            &mut cfg_masks,
        )
        .unwrap();

    assert_eq!(
        cfg_boxes.len(),
        ref_boxes.len(),
        "config path produced {} boxes, reference produced {}",
        cfg_boxes.len(),
        ref_boxes.len()
    );
    // `zip` stops at the shorter sequence, so without this check a missing
    // mask on the config path would make the pixel comparison pass vacuously.
    assert_eq!(
        cfg_masks.len(),
        ref_masks.len(),
        "config path produced {} masks, reference produced {}",
        cfg_masks.len(),
        ref_masks.len()
    );

    for (i, (cb, rb)) in cfg_boxes.iter().zip(&ref_boxes).enumerate() {
        assert!(
            cb.equal_within_delta(rb, 0.01),
            "box {i} mismatch: config={cb:?}, reference={rb:?}"
        );
    }

    for (i, (cm, rm)) in cfg_masks.iter().zip(&ref_masks).enumerate() {
        let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
        let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
        assert_eq!(
            cm_arr, rm_arr,
            "mask {i} pixel mismatch between config-driven and reference paths"
        );
    }
}
1830
/// Decodes `i8` YOLO segmentation output through four paths and checks that
/// they agree.
///
/// The paths are `decode_yolo_segdet_quant`, the configured decoder's
/// `decode_quantized`, `decode_yolo_segdet_float` on the dequantized tensors,
/// and the configured decoder's `decode_float` on the dequantized tensors.
#[test]
fn test_decoder_masks_i8() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let boxes_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ));
    // `u8 as i8` keeps the bit pattern, so no `unsafe` reinterpretation is needed.
    let boxes_i8: Vec<i8> = boxes_raw.iter().map(|&byte| byte as i8).collect();
    let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes_i8).unwrap();
    let quant_boxes = (0.021287761628627777, 31).into();

    let protos_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    ));
    let protos_i8: Vec<i8> = protos_raw.iter().map(|&byte| byte as i8).collect();
    let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos_i8).unwrap();
    let quant_protos = (0.02491161972284317, -117).into();

    let decoder = DecoderBuilder::default()
        .with_config_yolo_segdet(
            configs::Detection {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(quant_boxes),
                shape: vec![1, 116, 8400],
                anchors: None,
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumFeatures, 116),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            Protos {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(quant_protos),
                shape: vec![1, 160, 160, 32],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                    (DimName::NumProtos, 32),
                ],
            },
            Some(DecoderVersion::Yolo11),
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Convert the config-side quantization values into the type taken by the
    // free decode/dequantize functions below.
    let quant_boxes = quant_boxes.into();
    let quant_protos = quant_protos.into();

    // Path 1: free function on the quantized tensors.
    let mut direct_q_boxes: Vec<_> = Vec::with_capacity(500);
    let mut direct_q_masks: Vec<_> = Vec::with_capacity(500);
    decode_yolo_segdet_quant(
        (boxes.slice(s![0, .., ..]), quant_boxes),
        (protos.slice(s![0, .., .., ..]), quant_protos),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut direct_q_boxes,
        &mut direct_q_masks,
    )
    .unwrap();

    // Path 2: configured decoder on the quantized tensors.
    let mut decoder_q_boxes: Vec<_> = Vec::with_capacity(500);
    let mut decoder_q_masks: Vec<_> = Vec::with_capacity(500);
    decoder
        .decode_quantized(
            &[boxes.view().into(), protos.view().into()],
            &mut decoder_q_boxes,
            &mut decoder_q_masks,
        )
        .unwrap();

    let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
    let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);

    // Path 3: free function on the dequantized tensors.
    let mut direct_f_boxes: Vec<_> = Vec::with_capacity(500);
    let mut direct_f_masks: Vec<_> = Vec::with_capacity(500);
    decode_yolo_segdet_float(
        seg.slice(s![0, .., ..]),
        protos.slice(s![0, .., .., ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut direct_f_boxes,
        &mut direct_f_masks,
    )
    .unwrap();

    // Path 4: configured decoder on the dequantized tensors.
    let mut decoder_f_boxes: Vec<_> = Vec::with_capacity(500);
    let mut decoder_f_masks: Vec<_> = Vec::with_capacity(500);
    decoder
        .decode_float(
            &[seg.view().into_dyn(), protos.view().into_dyn()],
            &mut decoder_f_boxes,
            &mut decoder_f_masks,
        )
        .unwrap();

    compare_outputs(
        (&direct_q_boxes, &decoder_q_boxes),
        (&direct_q_masks, &decoder_q_masks),
    );

    compare_outputs(
        (&direct_q_boxes, &direct_f_boxes),
        (&direct_q_masks, &direct_f_masks),
    );

    compare_outputs(
        (&direct_f_boxes, &decoder_f_boxes),
        (&direct_f_masks, &decoder_f_masks),
    );
}
1953
/// Checks a split boxes/scores YOLO decoder on `i16` input against the float
/// reference function.
///
/// The `i8` test data is widened to `i16` and multiplied by 256; the scale is
/// divided by 256 and the zero point multiplied by 256 to match. Rows `..4`
/// are passed as boxes and rows `4..84` as scores. Results of
/// `decode_quantized`, `decode_yolo_det_float` and `decode_float` are compared.
#[test]
fn test_decoder_yolo_split() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let boxes_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ));
    // `u8 as i8` keeps the bit pattern, so no `unsafe` reinterpretation is
    // needed; the signed value is then widened and scaled into `i16`.
    let boxes_i16: Vec<i16> = boxes_raw
        .iter()
        .map(|&byte| i16::from(byte as i8) * 256)
        .collect();
    let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes_i16).unwrap();

    let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
    // Boxes and scores come from the same tensor, so they share quantization.
    let shared_quant = QuantTuple(quant_boxes.scale, quant_boxes.zero_point);

    let decoder = DecoderBuilder::default()
        .with_config_yolo_split_det(
            configs::Boxes {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(shared_quant),
                shape: vec![1, 4, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::BoxCoords, 4),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            configs::Scores {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(shared_quant),
                shape: vec![1, 80, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 80),
                    (DimName::NumBoxes, 8400),
                ],
            },
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Configured decoder on the quantized slices.
    let mut quant_boxes_out: Vec<_> = Vec::with_capacity(500);
    let mut quant_masks_out: Vec<_> = Vec::with_capacity(500);
    decoder
        .decode_quantized(
            &[
                boxes.slice(s![.., ..4, ..]).into(),
                boxes.slice(s![.., 4..84, ..]).into(),
            ],
            &mut quant_boxes_out,
            &mut quant_masks_out,
        )
        .unwrap();

    // Float reference on the first 84 rows of the dequantized tensor.
    let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
    let mut reference_boxes: Vec<_> = Vec::with_capacity(500);
    decode_yolo_det_float(
        seg.slice(s![0, ..84, ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut reference_boxes,
    );

    // Configured decoder on the dequantized slices.
    let mut float_boxes_out: Vec<_> = Vec::with_capacity(500);
    let mut float_masks_out: Vec<_> = Vec::with_capacity(500);
    decoder
        .decode_float(
            &[
                seg.slice(s![.., ..4, ..]).into_dyn(),
                seg.slice(s![.., 4..84, ..]).into_dyn(),
            ],
            &mut float_boxes_out,
            &mut float_masks_out,
        )
        .unwrap();

    compare_outputs((&quant_boxes_out, &reference_boxes), (&quant_masks_out, &[]));
    compare_outputs((&reference_boxes, &float_boxes_out), (&[], &float_masks_out));
}
2037
/// Checks a split segmentation decoder whose inputs use mixed integer widths:
/// `i16` boxes, scores and mask coefficients together with `i8` prototypes.
///
/// The `i8` box data is widened to `i16` and multiplied by 256, with the
/// quantization scale divided by 256 and the zero point multiplied by 256 to
/// match. Results of `decode_quantized`, `decode_yolo_segdet_float` and
/// `decode_float` are compared.
#[test]
fn test_decoder_masks_config_mixed() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let boxes_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ));
    // `u8 as i8` keeps the bit pattern, so no `unsafe` reinterpretation is
    // needed; the signed value is then widened and scaled into `i16`.
    let boxes_i16: Vec<i16> = boxes_raw
        .iter()
        .map(|&byte| i16::from(byte as i8) * 256)
        .collect();
    let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes_i16).unwrap();
    let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);

    let protos_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    ));
    // Collected once and moved into the array; no intermediate copies.
    let protos_i8: Vec<i8> = protos_raw.iter().map(|&byte| byte as i8).collect();
    let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos_i8).unwrap();
    let quant_protos = Quantization::new(0.02491161972284317, -117);

    // Boxes, scores and mask coefficients are slices of one tensor and share
    // its quantization; the prototypes have their own.
    let boxes_qt = QuantTuple(quant_boxes.scale, quant_boxes.zero_point);
    let protos_qt = QuantTuple(quant_protos.scale, quant_protos.zero_point);

    let decoder = DecoderBuilder::default()
        .with_config_yolo_split_segdet(
            configs::Boxes {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(boxes_qt),
                shape: vec![1, 4, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::BoxCoords, 4),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            configs::Scores {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(boxes_qt),
                shape: vec![1, 80, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 80),
                    (DimName::NumBoxes, 8400),
                ],
            },
            configs::MaskCoefficients {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(boxes_qt),
                shape: vec![1, 32, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumProtos, 32),
                    (DimName::NumBoxes, 8400),
                ],
            },
            configs::Protos {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(protos_qt),
                shape: vec![1, 160, 160, 32],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                    (DimName::NumProtos, 32),
                ],
            },
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Configured decoder on the quantized slices.
    let mut quant_boxes_out: Vec<_> = Vec::with_capacity(500);
    let mut quant_masks_out: Vec<_> = Vec::with_capacity(500);
    decoder
        .decode_quantized(
            &[
                boxes.slice(s![.., ..4, ..]).into(),
                boxes.slice(s![.., 4..84, ..]).into(),
                boxes.slice(s![.., 84.., ..]).into(),
                protos.view().into(),
            ],
            &mut quant_boxes_out,
            &mut quant_masks_out,
        )
        .unwrap();

    // Float reference on the dequantized tensors.
    let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
    let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
    let mut reference_boxes: Vec<_> = Vec::with_capacity(500);
    let mut reference_masks: Vec<_> = Vec::with_capacity(500);
    decode_yolo_segdet_float(
        seg.slice(s![0, .., ..]),
        protos.slice(s![0, .., .., ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut reference_boxes,
        &mut reference_masks,
    )
    .unwrap();

    // Configured decoder on the dequantized slices.
    let mut float_boxes_out: Vec<_> = Vec::with_capacity(500);
    let mut float_masks_out: Vec<_> = Vec::with_capacity(500);
    decoder
        .decode_float(
            &[
                seg.slice(s![.., ..4, ..]).into_dyn(),
                seg.slice(s![.., 4..84, ..]).into_dyn(),
                seg.slice(s![.., 84.., ..]).into_dyn(),
                protos.view().into_dyn(),
            ],
            &mut float_boxes_out,
            &mut float_masks_out,
        )
        .unwrap();

    compare_outputs(
        (&quant_boxes_out, &reference_boxes),
        (&quant_masks_out, &reference_masks),
    );
    compare_outputs(
        (&reference_boxes, &float_boxes_out),
        (&reference_masks, &float_masks_out),
    );
}
2167
/// Checks a split segmentation decoder on `i32` inputs against the float
/// reference function.
///
/// Both `i8` test tensors are widened to `i32` and multiplied by `1 << 23`;
/// each quantization scale is divided by that factor and each zero point
/// multiplied by it to match. The output of `decode_quantized` is compared
/// with `decode_yolo_segdet_float` on the dequantized tensors.
#[test]
fn test_decoder_masks_config_i32() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;
    // An `i8` times 2^23 has magnitude at most 2^30, which fits in `i32`.
    let scale = 1 << 23;

    let boxes_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ));
    // `u8 as i8` keeps the bit pattern, so no `unsafe` reinterpretation is
    // needed; the signed value is then widened and scaled into `i32`.
    let boxes_i32: Vec<i32> = boxes_raw
        .iter()
        .map(|&byte| i32::from(byte as i8) * scale)
        .collect();
    let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes_i32).unwrap();
    let quant_boxes = Quantization::new(0.021287761628627777 / scale as f32, 31 * scale);

    let protos_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    ));
    // Collected once and moved into the array; no intermediate copy.
    let protos_i32: Vec<i32> = protos_raw
        .iter()
        .map(|&byte| i32::from(byte as i8) * scale)
        .collect();
    let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos_i32).unwrap();
    let quant_protos = Quantization::new(0.02491161972284317 / scale as f32, -117 * scale);

    // Boxes, scores and mask coefficients are slices of one tensor and share
    // its quantization; the prototypes have their own.
    let boxes_qt = QuantTuple(quant_boxes.scale, quant_boxes.zero_point);
    let protos_qt = QuantTuple(quant_protos.scale, quant_protos.zero_point);

    let decoder = DecoderBuilder::default()
        .with_config_yolo_split_segdet(
            configs::Boxes {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(boxes_qt),
                shape: vec![1, 4, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::BoxCoords, 4),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            configs::Scores {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(boxes_qt),
                shape: vec![1, 80, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 80),
                    (DimName::NumBoxes, 8400),
                ],
            },
            configs::MaskCoefficients {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(boxes_qt),
                shape: vec![1, 32, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumProtos, 32),
                    (DimName::NumBoxes, 8400),
                ],
            },
            configs::Protos {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(protos_qt),
                shape: vec![1, 160, 160, 32],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                    (DimName::NumProtos, 32),
                ],
            },
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Configured decoder on the quantized slices.
    let mut quant_boxes_out: Vec<_> = Vec::with_capacity(500);
    let mut quant_masks_out: Vec<_> = Vec::with_capacity(500);
    decoder
        .decode_quantized(
            &[
                boxes.slice(s![.., ..4, ..]).into(),
                boxes.slice(s![.., 4..84, ..]).into(),
                boxes.slice(s![.., 84.., ..]).into(),
                protos.view().into(),
            ],
            &mut quant_boxes_out,
            &mut quant_masks_out,
        )
        .unwrap();

    // Float reference on the dequantized tensors.
    let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
    let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
    let mut reference_boxes: Vec<_> = Vec::with_capacity(500);
    let mut reference_masks: Vec<Segmentation> = Vec::with_capacity(500);
    decode_yolo_segdet_float(
        seg.slice(s![0, .., ..]),
        protos.slice(s![0, .., .., ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut reference_boxes,
        &mut reference_masks,
    )
    .unwrap();

    assert_eq!(quant_boxes_out.len(), reference_boxes.len());
    assert_eq!(quant_masks_out.len(), reference_masks.len());

    compare_outputs(
        (&quant_boxes_out, &reference_boxes),
        (&quant_masks_out, &reference_masks),
    );
}
2282
/// Runs two different decoders concurrently and checks that each keeps
/// producing its expected output.
///
/// Eight threads are spawned, alternating between a YOLO detection job and a
/// ModelPack split segmentation + detection job. Each job builds its own
/// decoder and output vectors, then decodes the same input 100 times, checking
/// the results after every iteration.
#[test]
fn test_context_switch() {
    let yolo_det = || {
        let score_threshold = 0.25;
        let iou_threshold = 0.7;
        let raw = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/yolov8s_80_classes.bin"
        ));
        // `u8 as i8` keeps the bit pattern, so no `unsafe` reinterpretation is needed.
        let quantized: Vec<i8> = raw.iter().map(|&byte| byte as i8).collect();
        let out = Array3::from_shape_vec((1, 84, 8400), quantized).unwrap();
        let quant = (0.0040811873, -123).into();

        let decoder = DecoderBuilder::default()
            .with_config_yolo_det(
                configs::Detection {
                    decoder: DecoderType::Ultralytics,
                    shape: vec![1, 84, 8400],
                    anchors: None,
                    quantization: Some(quant),
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumFeatures, 84),
                        (DimName::NumBoxes, 8400),
                    ],
                    normalized: None,
                },
                None,
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        // The expected values do not change between iterations, so they are
        // built once outside the loop.
        let expected = [
            DetectBox {
                bbox: BoundingBox {
                    xmin: 0.5285137,
                    ymin: 0.05305544,
                    xmax: 0.87541467,
                    ymax: 0.9998909,
                },
                score: 0.5591227,
                label: 0,
            },
            DetectBox {
                bbox: BoundingBox {
                    xmin: 0.130598,
                    ymin: 0.43260583,
                    xmax: 0.35098213,
                    ymax: 0.9958097,
                },
                score: 0.33057618,
                label: 75,
            },
        ];

        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
        let mut output_masks: Vec<_> = Vec::with_capacity(50);

        for _ in 0..100 {
            decoder
                .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
                .unwrap();

            for (index, want) in expected.iter().enumerate() {
                assert!(output_boxes[index].equal_within_delta(want, 1e-6));
            }
            assert!(output_masks.is_empty());
        }
    };

    let modelpack_det_split = || {
        let score_threshold = 0.8;
        let iou_threshold = 0.5;

        let seg = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

        let detect0 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_9x15x18.bin"
        ));
        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

        let detect1 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_17x30x18.bin"
        ));
        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

        // Expected mask: the raw segmentation input with its (2, 160, 160)
        // axes reordered to (160, 160, 2) by two axis swaps.
        let mut mask = seg.slice(s![0, .., .., ..]);
        mask.swap_axes(0, 1);
        mask.swap_axes(1, 2);
        let mask = [Segmentation {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.0,
            ymax: 1.0,
            segmentation: mask.into_owned(),
        }];
        let correct_boxes = [DetectBox {
            bbox: BoundingBox {
                xmin: 0.43171933,
                ymin: 0.68243736,
                xmax: 0.5626645,
                ymax: 0.808863,
            },
            score: 0.99240804,
            label: 0,
        }];

        let quant0 = (0.08547406643629074, 174).into();
        let quant1 = (0.09929127991199493, 183).into();
        let quant_seg = (1.0 / 255.0, 0).into();

        let anchors0 = vec![
            [0.36666667461395264, 0.31481480598449707],
            [0.38749998807907104, 0.4740740656852722],
            [0.5333333611488342, 0.644444465637207],
        ];
        let anchors1 = vec![
            [0.13750000298023224, 0.2074074000120163],
            [0.2541666626930237, 0.21481481194496155],
            [0.23125000298023224, 0.35185185074806213],
        ];

        // The detection configs are listed 17x30 first, while the inputs are
        // passed to the decoder below as detect0 (9x15), detect1 (17x30), seg.
        let decoder = DecoderBuilder::default()
            .with_config_modelpack_segdet_split(
                vec![
                    configs::Detection {
                        decoder: DecoderType::ModelPack,
                        shape: vec![1, 17, 30, 18],
                        anchors: Some(anchors1),
                        quantization: Some(quant1),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::Height, 17),
                            (DimName::Width, 30),
                            (DimName::NumAnchorsXFeatures, 18),
                        ],
                        normalized: None,
                    },
                    configs::Detection {
                        decoder: DecoderType::ModelPack,
                        shape: vec![1, 9, 15, 18],
                        anchors: Some(anchors0),
                        quantization: Some(quant0),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::Height, 9),
                            (DimName::Width, 15),
                            (DimName::NumAnchorsXFeatures, 18),
                        ],
                        normalized: None,
                    },
                ],
                configs::Segmentation {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_seg),
                    shape: vec![1, 2, 160, 160],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumClasses, 2),
                        (DimName::Height, 160),
                        (DimName::Width, 160),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);

        for _ in 0..100 {
            decoder
                .decode_quantized(
                    &[
                        detect0.view().into(),
                        detect1.view().into(),
                        seg.view().into(),
                    ],
                    &mut output_boxes,
                    &mut output_masks,
                )
                .unwrap();

            compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
        }
    };

    // Both closures capture nothing, so each can be handed to `spawn`
    // repeatedly. All eight threads are started before any is joined.
    let mut handles = Vec::with_capacity(8);
    for _ in 0..4 {
        handles.push(std::thread::spawn(yolo_det));
        handles.push(std::thread::spawn(modelpack_det_split));
    }
    for handle in handles {
        handle.join().unwrap();
    }
}
2497
/// Checks `ndarray_to_xyxy_float` for both box encodings on the same input.
///
/// `XYWH` is expected to turn `[10, 20, 20, 20]` into `[0, 10, 20, 30]`,
/// while `XYXY` is expected to return the input values unchanged.
#[test]
fn test_ndarray_to_xyxy_float() {
    let input = array![10.0_f32, 20.0, 20.0, 20.0];

    let from_xywh: [f32; 4] = XYWH::ndarray_to_xyxy_float(input.view());
    assert_eq!(from_xywh, [0.0_f32, 10.0, 20.0, 30.0]);

    let from_xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(input.view());
    assert_eq!(from_xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
}
2508
/// Checks that `nms_class_aware_float` only suppresses overlapping boxes that
/// share a label.
///
/// The same two overlapping boxes are submitted twice with an IoU threshold of
/// 0.3: once with different labels (both must survive) and once with the same
/// label (only the higher-scoring box must survive).
#[test]
fn test_class_aware_nms_float() {
    use crate::float::nms_class_aware_float;

    // Two overlapping boxes that belong to different classes.
    let different_class_boxes = vec![
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.0,
                ymin: 0.0,
                xmax: 0.5,
                ymax: 0.5,
            },
            score: 0.9,
            label: 0,
        },
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.1,
                ymin: 0.1,
                xmax: 0.6,
                ymax: 0.6,
            },
            score: 0.8,
            label: 1,
        },
    ];

    // The vector is not needed afterwards, so it is moved instead of cloned.
    let kept = nms_class_aware_float(0.3, different_class_boxes);
    assert_eq!(
        kept.len(),
        2,
        "Class-aware NMS should keep both boxes with different classes"
    );

    // The same geometry, but both boxes now carry label 0.
    let same_class_boxes = vec![
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.0,
                ymin: 0.0,
                xmax: 0.5,
                ymax: 0.5,
            },
            score: 0.9,
            label: 0,
        },
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.1,
                ymin: 0.1,
                xmax: 0.6,
                ymax: 0.6,
            },
            score: 0.8,
            label: 0,
        },
    ];

    let kept = nms_class_aware_float(0.3, same_class_boxes);
    assert_eq!(
        kept.len(),
        1,
        "Class-aware NMS should suppress overlapping box with same class"
    );
    assert_eq!(kept[0].label, 0);
    assert!((kept[0].score - 0.9).abs() < 1e-6);
}
2579
2580 #[test]
2581 fn test_class_agnostic_vs_aware_nms() {
2582 use crate::float::{nms_class_aware_float, nms_float};
2583
2584 let boxes = vec![
2586 DetectBox {
2587 bbox: BoundingBox {
2588 xmin: 0.0,
2589 ymin: 0.0,
2590 xmax: 0.5,
2591 ymax: 0.5,
2592 },
2593 score: 0.9,
2594 label: 0,
2595 },
2596 DetectBox {
2597 bbox: BoundingBox {
2598 xmin: 0.1,
2599 ymin: 0.1,
2600 xmax: 0.6,
2601 ymax: 0.6,
2602 },
2603 score: 0.8,
2604 label: 1,
2605 },
2606 ];
2607
2608 let agnostic_result = nms_float(0.3, boxes.clone());
2610 assert_eq!(
2611 agnostic_result.len(),
2612 1,
2613 "Class-agnostic NMS should suppress overlapping boxes"
2614 );
2615
2616 let aware_result = nms_class_aware_float(0.3, boxes);
2618 assert_eq!(
2619 aware_result.len(),
2620 2,
2621 "Class-aware NMS should keep boxes with different classes"
2622 );
2623 }
2624
2625 #[test]
2626 fn test_class_aware_nms_int() {
2627 use crate::byte::nms_class_aware_int;
2628
2629 let boxes = vec![
2631 DetectBoxQuantized {
2632 bbox: BoundingBox {
2633 xmin: 0.0,
2634 ymin: 0.0,
2635 xmax: 0.5,
2636 ymax: 0.5,
2637 },
2638 score: 200_u8,
2639 label: 0,
2640 },
2641 DetectBoxQuantized {
2642 bbox: BoundingBox {
2643 xmin: 0.1,
2644 ymin: 0.1,
2645 xmax: 0.6,
2646 ymax: 0.6,
2647 },
2648 score: 180_u8,
2649 label: 1, },
2651 ];
2652
2653 let result = nms_class_aware_int(0.5, boxes);
2655 assert_eq!(
2656 result.len(),
2657 2,
2658 "Class-aware NMS (int) should keep boxes with different classes"
2659 );
2660 }
2661
2662 #[test]
2663 fn test_nms_enum_default() {
2664 let default_nms: configs::Nms = Default::default();
2666 assert_eq!(default_nms, configs::Nms::ClassAgnostic);
2667 }
2668
2669 #[test]
2670 fn test_decoder_nms_mode() {
2671 let decoder = DecoderBuilder::default()
2673 .with_config_yolo_det(
2674 configs::Detection {
2675 anchors: None,
2676 decoder: DecoderType::Ultralytics,
2677 quantization: None,
2678 shape: vec![1, 84, 8400],
2679 dshape: Vec::new(),
2680 normalized: Some(true),
2681 },
2682 None,
2683 )
2684 .with_nms(Some(configs::Nms::ClassAware))
2685 .build()
2686 .unwrap();
2687
2688 assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
2689 }
2690
2691 #[test]
2692 fn test_decoder_nms_bypass() {
2693 let decoder = DecoderBuilder::default()
2695 .with_config_yolo_det(
2696 configs::Detection {
2697 anchors: None,
2698 decoder: DecoderType::Ultralytics,
2699 quantization: None,
2700 shape: vec![1, 84, 8400],
2701 dshape: Vec::new(),
2702 normalized: Some(true),
2703 },
2704 None,
2705 )
2706 .with_nms(None)
2707 .build()
2708 .unwrap();
2709
2710 assert_eq!(decoder.nms, None);
2711 }
2712
2713 #[test]
2714 fn test_decoder_normalized_boxes_true() {
2715 let decoder = DecoderBuilder::default()
2717 .with_config_yolo_det(
2718 configs::Detection {
2719 anchors: None,
2720 decoder: DecoderType::Ultralytics,
2721 quantization: None,
2722 shape: vec![1, 84, 8400],
2723 dshape: Vec::new(),
2724 normalized: Some(true),
2725 },
2726 None,
2727 )
2728 .build()
2729 .unwrap();
2730
2731 assert_eq!(decoder.normalized_boxes(), Some(true));
2732 }
2733
2734 #[test]
2735 fn test_decoder_normalized_boxes_false() {
2736 let decoder = DecoderBuilder::default()
2739 .with_config_yolo_det(
2740 configs::Detection {
2741 anchors: None,
2742 decoder: DecoderType::Ultralytics,
2743 quantization: None,
2744 shape: vec![1, 84, 8400],
2745 dshape: Vec::new(),
2746 normalized: Some(false),
2747 },
2748 None,
2749 )
2750 .build()
2751 .unwrap();
2752
2753 assert_eq!(decoder.normalized_boxes(), Some(false));
2754 }
2755
2756 #[test]
2757 fn test_decoder_normalized_boxes_unknown() {
2758 let decoder = DecoderBuilder::default()
2760 .with_config_yolo_det(
2761 configs::Detection {
2762 anchors: None,
2763 decoder: DecoderType::Ultralytics,
2764 quantization: None,
2765 shape: vec![1, 84, 8400],
2766 dshape: Vec::new(),
2767 normalized: None,
2768 },
2769 Some(DecoderVersion::Yolo11),
2770 )
2771 .build()
2772 .unwrap();
2773
2774 assert_eq!(decoder.normalized_boxes(), None);
2775 }
2776}
2777
2778#[cfg(feature = "tracker")]
2779#[cfg(test)]
2780#[cfg_attr(coverage_nightly, coverage(off))]
2781mod decoder_tracked_tests {
2782
2783 use edgefirst_tracker::{ByteTrackBuilder, Tracker};
2784 use ndarray::{array, s, Array, Array2, Array3, Array4, ArrayView, Axis, Dimension};
2785 use num_traits::{AsPrimitive, Float, PrimInt};
2786 use rand::{RngExt, SeedableRng};
2787 use rand_distr::StandardNormal;
2788
2789 use crate::{
2790 configs::{self, DimName},
2791 dequantize_ndarray, BoundingBox, DecoderBuilder, DetectBox, Quantization,
2792 };
2793
2794 pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
2795 input: ArrayView<F, D>,
2796 quant: Quantization,
2797 ) -> Array<T, D>
2798 where
2799 i32: num_traits::AsPrimitive<F>,
2800 f32: num_traits::AsPrimitive<F>,
2801 {
2802 let zero_point = quant.zero_point.as_();
2803 let div_scale = F::one() / quant.scale.as_();
2804 if zero_point != F::zero() {
2805 input.mapv(|d| (d * div_scale + zero_point).round().as_())
2806 } else {
2807 input.mapv(|d| (d * div_scale).round().as_())
2808 }
2809 }
2810
2811 #[test]
2812 fn test_decoder_tracked_random_jitter() {
2813 use crate::configs::{DecoderType, Nms};
2814 use crate::DecoderBuilder;
2815
2816 let score_threshold = 0.25;
2817 let iou_threshold = 0.1;
2818 let out = include_bytes!(concat!(
2819 env!("CARGO_MANIFEST_DIR"),
2820 "/../../testdata/yolov8s_80_classes.bin"
2821 ));
2822 let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
2823 let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
2824 let quant = (0.0040811873, -123).into();
2825
2826 let decoder = DecoderBuilder::default()
2827 .with_config_yolo_det(
2828 crate::configs::Detection {
2829 decoder: DecoderType::Ultralytics,
2830 shape: vec![1, 84, 8400],
2831 anchors: None,
2832 quantization: Some(quant),
2833 dshape: vec![
2834 (crate::configs::DimName::Batch, 1),
2835 (crate::configs::DimName::NumFeatures, 84),
2836 (crate::configs::DimName::NumBoxes, 8400),
2837 ],
2838 normalized: Some(true),
2839 },
2840 None,
2841 )
2842 .with_score_threshold(score_threshold)
2843 .with_iou_threshold(iou_threshold)
2844 .with_nms(Some(Nms::ClassAgnostic))
2845 .build()
2846 .unwrap();
2847 let mut rng = rand::rngs::StdRng::seed_from_u64(0xAB_BEEF); let expected_boxes = [
2850 crate::DetectBox {
2851 bbox: crate::BoundingBox {
2852 xmin: 0.5285137,
2853 ymin: 0.05305544,
2854 xmax: 0.87541467,
2855 ymax: 0.9998909,
2856 },
2857 score: 0.5591227,
2858 label: 0,
2859 },
2860 crate::DetectBox {
2861 bbox: crate::BoundingBox {
2862 xmin: 0.130598,
2863 ymin: 0.43260583,
2864 xmax: 0.35098213,
2865 ymax: 0.9958097,
2866 },
2867 score: 0.33057618,
2868 label: 75,
2869 },
2870 ];
2871
2872 let mut tracker = ByteTrackBuilder::new()
2873 .track_update(0.1)
2874 .track_high_conf(0.3)
2875 .build();
2876
2877 let mut output_boxes = Vec::with_capacity(50);
2878 let mut output_masks = Vec::with_capacity(50);
2879 let mut output_tracks = Vec::with_capacity(50);
2880
2881 decoder
2882 .decode_tracked_quantized(
2883 &mut tracker,
2884 0,
2885 &[out.view().into()],
2886 &mut output_boxes,
2887 &mut output_masks,
2888 &mut output_tracks,
2889 )
2890 .unwrap();
2891
2892 assert_eq!(output_boxes.len(), 2);
2893 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
2894 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
2895
2896 let mut last_boxes = output_boxes.clone();
2897
2898 for i in 1..=100 {
2899 let mut out = out.clone();
2900 let mut x_values = out.slice_mut(s![0, 0, ..]);
2902 for x in x_values.iter_mut() {
2903 let r: f32 = rng.sample(StandardNormal);
2904 let r = r.clamp(-2.0, 2.0) / 2.0;
2905 *x = x.saturating_add((r * 1e-2 / quant.0) as i8);
2906 }
2907
2908 let mut y_values = out.slice_mut(s![0, 1, ..]);
2909 for y in y_values.iter_mut() {
2910 let r: f32 = rng.sample(StandardNormal);
2911 let r = r.clamp(-2.0, 2.0) / 2.0;
2912 *y = y.saturating_add((r * 1e-2 / quant.0) as i8);
2913 }
2914
2915 decoder
2916 .decode_tracked_quantized(
2917 &mut tracker,
2918 100_000_000 * i / 3, &[out.view().into()],
2920 &mut output_boxes,
2921 &mut output_masks,
2922 &mut output_tracks,
2923 )
2924 .unwrap();
2925
2926 assert_eq!(output_boxes.len(), 2);
2927 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 5e-3));
2928 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 5e-3));
2929
2930 assert!(output_boxes[0].equal_within_delta(&last_boxes[0], 1e-3));
2931 assert!(output_boxes[1].equal_within_delta(&last_boxes[1], 1e-3));
2932 last_boxes = output_boxes.clone();
2933 }
2934 }
2935
2936 #[test]
2937 fn test_decoder_tracked_segdet() {
2938 use crate::configs::Nms;
2939 use crate::DecoderBuilder;
2940
2941 let score_threshold = 0.45;
2942 let iou_threshold = 0.45;
2943 let boxes = include_bytes!(concat!(
2944 env!("CARGO_MANIFEST_DIR"),
2945 "/../../testdata/yolov8_boxes_116x8400.bin"
2946 ));
2947 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
2948 let mut boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
2949
2950 let protos = include_bytes!(concat!(
2951 env!("CARGO_MANIFEST_DIR"),
2952 "/../../testdata/yolov8_protos_160x160x32.bin"
2953 ));
2954 let protos =
2955 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
2956 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
2957
2958 let config = include_str!(concat!(
2959 env!("CARGO_MANIFEST_DIR"),
2960 "/../../testdata/yolov8_seg.yaml"
2961 ));
2962
2963 let decoder = DecoderBuilder::default()
2964 .with_config_yaml_str(config.to_string())
2965 .with_score_threshold(score_threshold)
2966 .with_iou_threshold(iou_threshold)
2967 .with_nms(Some(Nms::ClassAgnostic))
2968 .build()
2969 .unwrap();
2970
2971 let expected_boxes = [
2972 DetectBox {
2973 bbox: BoundingBox {
2974 xmin: 0.08515105,
2975 ymin: 0.7131401,
2976 xmax: 0.29802868,
2977 ymax: 0.8195788,
2978 },
2979 score: 0.91537374,
2980 label: 23,
2981 },
2982 DetectBox {
2983 bbox: BoundingBox {
2984 xmin: 0.59605736,
2985 ymin: 0.25545314,
2986 xmax: 0.93666154,
2987 ymax: 0.72378385,
2988 },
2989 score: 0.91537374,
2990 label: 23,
2991 },
2992 ];
2993
2994 let mut tracker = ByteTrackBuilder::new()
2995 .track_update(0.1)
2996 .track_high_conf(0.7)
2997 .build();
2998
2999 let mut output_boxes = Vec::with_capacity(50);
3000 let mut output_masks = Vec::with_capacity(50);
3001 let mut output_tracks = Vec::with_capacity(50);
3002
3003 decoder
3004 .decode_tracked_quantized(
3005 &mut tracker,
3006 0,
3007 &[boxes.view().into(), protos.view().into()],
3008 &mut output_boxes,
3009 &mut output_masks,
3010 &mut output_tracks,
3011 )
3012 .unwrap();
3013
3014 assert_eq!(output_boxes.len(), 2);
3015 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3016 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3017
3018 let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3020 for score in scores_values.iter_mut() {
3021 *score = i8::MIN; }
3023 decoder
3024 .decode_tracked_quantized(
3025 &mut tracker,
3026 100_000_000 / 3,
3027 &[boxes.view().into(), protos.view().into()],
3028 &mut output_boxes,
3029 &mut output_masks,
3030 &mut output_tracks,
3031 )
3032 .unwrap();
3033
3034 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3035 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3036
3037 assert!(output_masks.is_empty())
3039 }
3040
3041 #[test]
3042 fn test_decoder_tracked_segdet_float() {
3043 use crate::configs::Nms;
3044 use crate::DecoderBuilder;
3045
3046 let score_threshold = 0.45;
3047 let iou_threshold = 0.45;
3048 let boxes = include_bytes!(concat!(
3049 env!("CARGO_MANIFEST_DIR"),
3050 "/../../testdata/yolov8_boxes_116x8400.bin"
3051 ));
3052 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3053 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
3054 let quant_boxes = (0.021287762, 31);
3055 let mut boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
3056
3057 let protos = include_bytes!(concat!(
3058 env!("CARGO_MANIFEST_DIR"),
3059 "/../../testdata/yolov8_protos_160x160x32.bin"
3060 ));
3061 let protos =
3062 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3063 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
3064 let quant_protos = (0.02491162, -117);
3065 let protos = dequantize_ndarray(protos.view(), quant_protos.into());
3066
3067 let config = include_str!(concat!(
3068 env!("CARGO_MANIFEST_DIR"),
3069 "/../../testdata/yolov8_seg.yaml"
3070 ));
3071
3072 let decoder = DecoderBuilder::default()
3073 .with_config_yaml_str(config.to_string())
3074 .with_score_threshold(score_threshold)
3075 .with_iou_threshold(iou_threshold)
3076 .with_nms(Some(Nms::ClassAgnostic))
3077 .build()
3078 .unwrap();
3079
3080 let expected_boxes = [
3081 DetectBox {
3082 bbox: BoundingBox {
3083 xmin: 0.08515105,
3084 ymin: 0.7131401,
3085 xmax: 0.29802868,
3086 ymax: 0.8195788,
3087 },
3088 score: 0.91537374,
3089 label: 23,
3090 },
3091 DetectBox {
3092 bbox: BoundingBox {
3093 xmin: 0.59605736,
3094 ymin: 0.25545314,
3095 xmax: 0.93666154,
3096 ymax: 0.72378385,
3097 },
3098 score: 0.91537374,
3099 label: 23,
3100 },
3101 ];
3102
3103 let mut tracker = ByteTrackBuilder::new()
3104 .track_update(0.1)
3105 .track_high_conf(0.7)
3106 .build();
3107
3108 let mut output_boxes = Vec::with_capacity(50);
3109 let mut output_masks = Vec::with_capacity(50);
3110 let mut output_tracks = Vec::with_capacity(50);
3111
3112 decoder
3113 .decode_tracked_float(
3114 &mut tracker,
3115 0,
3116 &[boxes.view().into_dyn(), protos.view().into_dyn()],
3117 &mut output_boxes,
3118 &mut output_masks,
3119 &mut output_tracks,
3120 )
3121 .unwrap();
3122
3123 assert_eq!(output_boxes.len(), 2);
3124 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3125 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3126
3127 let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3129 for score in scores_values.iter_mut() {
3130 *score = 0.0; }
3132 decoder
3133 .decode_tracked_float(
3134 &mut tracker,
3135 100_000_000 / 3,
3136 &[boxes.view().into_dyn(), protos.view().into_dyn()],
3137 &mut output_boxes,
3138 &mut output_masks,
3139 &mut output_tracks,
3140 )
3141 .unwrap();
3142
3143 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3144 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3145
3146 assert!(output_masks.is_empty())
3148 }
3149
3150 #[test]
3151 fn test_decoder_tracked_segdet_proto() {
3152 use crate::configs::Nms;
3153 use crate::DecoderBuilder;
3154
3155 let score_threshold = 0.45;
3156 let iou_threshold = 0.45;
3157 let boxes = include_bytes!(concat!(
3158 env!("CARGO_MANIFEST_DIR"),
3159 "/../../testdata/yolov8_boxes_116x8400.bin"
3160 ));
3161 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3162 let mut boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
3163
3164 let protos = include_bytes!(concat!(
3165 env!("CARGO_MANIFEST_DIR"),
3166 "/../../testdata/yolov8_protos_160x160x32.bin"
3167 ));
3168 let protos =
3169 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3170 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
3171
3172 let config = include_str!(concat!(
3173 env!("CARGO_MANIFEST_DIR"),
3174 "/../../testdata/yolov8_seg.yaml"
3175 ));
3176
3177 let decoder = DecoderBuilder::default()
3178 .with_config_yaml_str(config.to_string())
3179 .with_score_threshold(score_threshold)
3180 .with_iou_threshold(iou_threshold)
3181 .with_nms(Some(Nms::ClassAgnostic))
3182 .build()
3183 .unwrap();
3184
3185 let expected_boxes = [
3186 DetectBox {
3187 bbox: BoundingBox {
3188 xmin: 0.08515105,
3189 ymin: 0.7131401,
3190 xmax: 0.29802868,
3191 ymax: 0.8195788,
3192 },
3193 score: 0.91537374,
3194 label: 23,
3195 },
3196 DetectBox {
3197 bbox: BoundingBox {
3198 xmin: 0.59605736,
3199 ymin: 0.25545314,
3200 xmax: 0.93666154,
3201 ymax: 0.72378385,
3202 },
3203 score: 0.91537374,
3204 label: 23,
3205 },
3206 ];
3207
3208 let mut tracker = ByteTrackBuilder::new()
3209 .track_update(0.1)
3210 .track_high_conf(0.7)
3211 .build();
3212
3213 let mut output_boxes = Vec::with_capacity(50);
3214 let mut output_tracks = Vec::with_capacity(50);
3215
3216 decoder
3217 .decode_tracked_quantized_proto(
3218 &mut tracker,
3219 0,
3220 &[boxes.view().into(), protos.view().into()],
3221 &mut output_boxes,
3222 &mut output_tracks,
3223 )
3224 .unwrap();
3225
3226 assert_eq!(output_boxes.len(), 2);
3227 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3228 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3229
3230 let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3232 for score in scores_values.iter_mut() {
3233 *score = i8::MIN; }
3235 let protos = decoder
3236 .decode_tracked_quantized_proto(
3237 &mut tracker,
3238 100_000_000 / 3,
3239 &[boxes.view().into(), protos.view().into()],
3240 &mut output_boxes,
3241 &mut output_tracks,
3242 )
3243 .unwrap();
3244
3245 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3246 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3247
3248 assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
3250 }
3251
3252 #[test]
3253 fn test_decoder_tracked_segdet_proto_float() {
3254 use crate::configs::Nms;
3255 use crate::DecoderBuilder;
3256
3257 let score_threshold = 0.45;
3258 let iou_threshold = 0.45;
3259 let boxes = include_bytes!(concat!(
3260 env!("CARGO_MANIFEST_DIR"),
3261 "/../../testdata/yolov8_boxes_116x8400.bin"
3262 ));
3263 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3264 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
3265 let quant_boxes = (0.021287762, 31);
3266 let mut boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
3267
3268 let protos = include_bytes!(concat!(
3269 env!("CARGO_MANIFEST_DIR"),
3270 "/../../testdata/yolov8_protos_160x160x32.bin"
3271 ));
3272 let protos =
3273 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3274 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
3275 let quant_protos = (0.02491162, -117);
3276 let protos = dequantize_ndarray(protos.view(), quant_protos.into());
3277
3278 let config = include_str!(concat!(
3279 env!("CARGO_MANIFEST_DIR"),
3280 "/../../testdata/yolov8_seg.yaml"
3281 ));
3282
3283 let decoder = DecoderBuilder::default()
3284 .with_config_yaml_str(config.to_string())
3285 .with_score_threshold(score_threshold)
3286 .with_iou_threshold(iou_threshold)
3287 .with_nms(Some(Nms::ClassAgnostic))
3288 .build()
3289 .unwrap();
3290
3291 let expected_boxes = [
3292 DetectBox {
3293 bbox: BoundingBox {
3294 xmin: 0.08515105,
3295 ymin: 0.7131401,
3296 xmax: 0.29802868,
3297 ymax: 0.8195788,
3298 },
3299 score: 0.91537374,
3300 label: 23,
3301 },
3302 DetectBox {
3303 bbox: BoundingBox {
3304 xmin: 0.59605736,
3305 ymin: 0.25545314,
3306 xmax: 0.93666154,
3307 ymax: 0.72378385,
3308 },
3309 score: 0.91537374,
3310 label: 23,
3311 },
3312 ];
3313
3314 let mut tracker = ByteTrackBuilder::new()
3315 .track_update(0.1)
3316 .track_high_conf(0.7)
3317 .build();
3318
3319 let mut output_boxes = Vec::with_capacity(50);
3320 let mut output_tracks = Vec::with_capacity(50);
3321
3322 decoder
3323 .decode_tracked_float_proto(
3324 &mut tracker,
3325 0,
3326 &[boxes.view().into_dyn(), protos.view().into_dyn()],
3327 &mut output_boxes,
3328 &mut output_tracks,
3329 )
3330 .unwrap();
3331
3332 assert_eq!(output_boxes.len(), 2);
3333 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3334 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3335
3336 let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3338 for score in scores_values.iter_mut() {
3339 *score = 0.0; }
3341 let protos = decoder
3342 .decode_tracked_float_proto(
3343 &mut tracker,
3344 100_000_000 / 3,
3345 &[boxes.view().into_dyn(), protos.view().into_dyn()],
3346 &mut output_boxes,
3347 &mut output_tracks,
3348 )
3349 .unwrap();
3350
3351 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3352 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3353
3354 assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
3356 }
3357
3358 #[test]
3359 fn test_decoder_tracked_segdet_split() {
3360 let score_threshold = 0.45;
3361 let iou_threshold = 0.45;
3362
3363 let boxes = include_bytes!(concat!(
3364 env!("CARGO_MANIFEST_DIR"),
3365 "/../../testdata/yolov8_boxes_116x8400.bin"
3366 ));
3367 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3368 let boxes = boxes.to_vec();
3369 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
3370
3371 let mask = boxes.slice(s![.., 84.., ..]).to_owned();
3372 let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
3373 let boxes = boxes.slice(s![.., ..4, ..]).to_owned();
3374
3375 let quant_boxes = (0.021287762, 31);
3376
3377 let protos = include_bytes!(concat!(
3378 env!("CARGO_MANIFEST_DIR"),
3379 "/../../testdata/yolov8_protos_160x160x32.bin"
3380 ));
3381 let protos =
3382 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3383 let protos = protos.to_vec();
3384 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
3385 let quant_protos = (0.02491162, -117);
3386 let decoder = DecoderBuilder::default()
3387 .with_config_yolo_split_segdet(
3388 configs::Boxes {
3389 decoder: configs::DecoderType::Ultralytics,
3390 quantization: Some(quant_boxes.into()),
3391 shape: vec![1, 4, 8400],
3392 dshape: vec![
3393 (DimName::Batch, 1),
3394 (DimName::BoxCoords, 4),
3395 (DimName::NumBoxes, 8400),
3396 ],
3397 normalized: Some(true),
3398 },
3399 configs::Scores {
3400 decoder: configs::DecoderType::Ultralytics,
3401 quantization: Some(quant_boxes.into()),
3402 shape: vec![1, 80, 8400],
3403 dshape: vec![
3404 (DimName::Batch, 1),
3405 (DimName::NumClasses, 80),
3406 (DimName::NumBoxes, 8400),
3407 ],
3408 },
3409 configs::MaskCoefficients {
3410 decoder: configs::DecoderType::Ultralytics,
3411 quantization: Some(quant_boxes.into()),
3412 shape: vec![1, 32, 8400],
3413 dshape: vec![
3414 (DimName::Batch, 1),
3415 (DimName::NumProtos, 32),
3416 (DimName::NumBoxes, 8400),
3417 ],
3418 },
3419 configs::Protos {
3420 decoder: configs::DecoderType::Ultralytics,
3421 quantization: Some(quant_protos.into()),
3422 shape: vec![1, 160, 160, 32],
3423 dshape: vec![
3424 (DimName::Batch, 1),
3425 (DimName::Height, 160),
3426 (DimName::Width, 160),
3427 (DimName::NumProtos, 32),
3428 ],
3429 },
3430 )
3431 .with_score_threshold(score_threshold)
3432 .with_iou_threshold(iou_threshold)
3433 .build()
3434 .unwrap();
3435
3436 let expected_boxes = [
3437 DetectBox {
3438 bbox: BoundingBox {
3439 xmin: 0.08515105,
3440 ymin: 0.7131401,
3441 xmax: 0.29802868,
3442 ymax: 0.8195788,
3443 },
3444 score: 0.91537374,
3445 label: 23,
3446 },
3447 DetectBox {
3448 bbox: BoundingBox {
3449 xmin: 0.59605736,
3450 ymin: 0.25545314,
3451 xmax: 0.93666154,
3452 ymax: 0.72378385,
3453 },
3454 score: 0.91537374,
3455 label: 23,
3456 },
3457 ];
3458
3459 let mut tracker = ByteTrackBuilder::new()
3460 .track_update(0.1)
3461 .track_high_conf(0.7)
3462 .build();
3463
3464 let mut output_boxes = Vec::with_capacity(50);
3465 let mut output_masks = Vec::with_capacity(50);
3466 let mut output_tracks = Vec::with_capacity(50);
3467
3468 decoder
3469 .decode_tracked_quantized(
3470 &mut tracker,
3471 0,
3472 &[
3473 boxes.view().into(),
3474 scores.view().into(),
3475 mask.view().into(),
3476 protos.view().into(),
3477 ],
3478 &mut output_boxes,
3479 &mut output_masks,
3480 &mut output_tracks,
3481 )
3482 .unwrap();
3483
3484 assert_eq!(output_boxes.len(), 2);
3485 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3486 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3487
3488 for score in scores.iter_mut() {
3491 *score = i8::MIN; }
3493 decoder
3494 .decode_tracked_quantized(
3495 &mut tracker,
3496 100_000_000 / 3,
3497 &[
3498 boxes.view().into(),
3499 scores.view().into(),
3500 mask.view().into(),
3501 protos.view().into(),
3502 ],
3503 &mut output_boxes,
3504 &mut output_masks,
3505 &mut output_tracks,
3506 )
3507 .unwrap();
3508
3509 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3510 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3511
3512 assert!(output_masks.is_empty())
3514 }
3515
3516 #[test]
3517 fn test_decoder_tracked_segdet_split_float() {
3518 let score_threshold = 0.45;
3519 let iou_threshold = 0.45;
3520
3521 let boxes = include_bytes!(concat!(
3522 env!("CARGO_MANIFEST_DIR"),
3523 "/../../testdata/yolov8_boxes_116x8400.bin"
3524 ));
3525 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3526 let boxes = boxes.to_vec();
3527 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
3528 let quant_boxes = (0.021287762, 31);
3529 let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
3530
3531 let mask = boxes.slice(s![.., 84.., ..]).to_owned();
3532 let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
3533 let boxes = boxes.slice(s![.., ..4, ..]).to_owned();
3534
3535 let protos = include_bytes!(concat!(
3536 env!("CARGO_MANIFEST_DIR"),
3537 "/../../testdata/yolov8_protos_160x160x32.bin"
3538 ));
3539 let protos =
3540 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3541 let protos = protos.to_vec();
3542 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
3543 let quant_protos = (0.02491162, -117);
3544 let protos = dequantize_ndarray(protos.view(), quant_protos.into());
3545
3546 let decoder = DecoderBuilder::default()
3547 .with_config_yolo_split_segdet(
3548 configs::Boxes {
3549 decoder: configs::DecoderType::Ultralytics,
3550 quantization: Some(quant_boxes.into()),
3551 shape: vec![1, 4, 8400],
3552 dshape: vec![
3553 (DimName::Batch, 1),
3554 (DimName::BoxCoords, 4),
3555 (DimName::NumBoxes, 8400),
3556 ],
3557 normalized: Some(true),
3558 },
3559 configs::Scores {
3560 decoder: configs::DecoderType::Ultralytics,
3561 quantization: Some(quant_boxes.into()),
3562 shape: vec![1, 80, 8400],
3563 dshape: vec![
3564 (DimName::Batch, 1),
3565 (DimName::NumClasses, 80),
3566 (DimName::NumBoxes, 8400),
3567 ],
3568 },
3569 configs::MaskCoefficients {
3570 decoder: configs::DecoderType::Ultralytics,
3571 quantization: Some(quant_boxes.into()),
3572 shape: vec![1, 32, 8400],
3573 dshape: vec![
3574 (DimName::Batch, 1),
3575 (DimName::NumProtos, 32),
3576 (DimName::NumBoxes, 8400),
3577 ],
3578 },
3579 configs::Protos {
3580 decoder: configs::DecoderType::Ultralytics,
3581 quantization: Some(quant_protos.into()),
3582 shape: vec![1, 160, 160, 32],
3583 dshape: vec![
3584 (DimName::Batch, 1),
3585 (DimName::Height, 160),
3586 (DimName::Width, 160),
3587 (DimName::NumProtos, 32),
3588 ],
3589 },
3590 )
3591 .with_score_threshold(score_threshold)
3592 .with_iou_threshold(iou_threshold)
3593 .build()
3594 .unwrap();
3595
3596 let expected_boxes = [
3597 DetectBox {
3598 bbox: BoundingBox {
3599 xmin: 0.08515105,
3600 ymin: 0.7131401,
3601 xmax: 0.29802868,
3602 ymax: 0.8195788,
3603 },
3604 score: 0.91537374,
3605 label: 23,
3606 },
3607 DetectBox {
3608 bbox: BoundingBox {
3609 xmin: 0.59605736,
3610 ymin: 0.25545314,
3611 xmax: 0.93666154,
3612 ymax: 0.72378385,
3613 },
3614 score: 0.91537374,
3615 label: 23,
3616 },
3617 ];
3618
3619 let mut tracker = ByteTrackBuilder::new()
3620 .track_update(0.1)
3621 .track_high_conf(0.7)
3622 .build();
3623
3624 let mut output_boxes = Vec::with_capacity(50);
3625 let mut output_masks = Vec::with_capacity(50);
3626 let mut output_tracks = Vec::with_capacity(50);
3627
3628 decoder
3629 .decode_tracked_float(
3630 &mut tracker,
3631 0,
3632 &[
3633 boxes.view().into_dyn(),
3634 scores.view().into_dyn(),
3635 mask.view().into_dyn(),
3636 protos.view().into_dyn(),
3637 ],
3638 &mut output_boxes,
3639 &mut output_masks,
3640 &mut output_tracks,
3641 )
3642 .unwrap();
3643
3644 assert_eq!(output_boxes.len(), 2);
3645 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3646 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3647
3648 for score in scores.iter_mut() {
3651 *score = 0.0; }
3653 decoder
3654 .decode_tracked_float(
3655 &mut tracker,
3656 100_000_000 / 3,
3657 &[
3658 boxes.view().into_dyn(),
3659 scores.view().into_dyn(),
3660 mask.view().into_dyn(),
3661 protos.view().into_dyn(),
3662 ],
3663 &mut output_boxes,
3664 &mut output_masks,
3665 &mut output_tracks,
3666 )
3667 .unwrap();
3668
3669 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3670 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3671
3672 assert!(output_masks.is_empty())
3674 }
3675
3676 #[test]
3677 fn test_decoder_tracked_segdet_split_proto() {
3678 let score_threshold = 0.45;
3679 let iou_threshold = 0.45;
3680
3681 let boxes = include_bytes!(concat!(
3682 env!("CARGO_MANIFEST_DIR"),
3683 "/../../testdata/yolov8_boxes_116x8400.bin"
3684 ));
3685 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3686 let boxes = boxes.to_vec();
3687 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
3688
3689 let mask = boxes.slice(s![.., 84.., ..]).to_owned();
3690 let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
3691 let boxes = boxes.slice(s![.., ..4, ..]).to_owned();
3692
3693 let quant_boxes = (0.021287762, 31);
3694
3695 let protos = include_bytes!(concat!(
3696 env!("CARGO_MANIFEST_DIR"),
3697 "/../../testdata/yolov8_protos_160x160x32.bin"
3698 ));
3699 let protos =
3700 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3701 let protos = protos.to_vec();
3702 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
3703 let quant_protos = (0.02491162, -117);
3704 let decoder = DecoderBuilder::default()
3705 .with_config_yolo_split_segdet(
3706 configs::Boxes {
3707 decoder: configs::DecoderType::Ultralytics,
3708 quantization: Some(quant_boxes.into()),
3709 shape: vec![1, 4, 8400],
3710 dshape: vec![
3711 (DimName::Batch, 1),
3712 (DimName::BoxCoords, 4),
3713 (DimName::NumBoxes, 8400),
3714 ],
3715 normalized: Some(true),
3716 },
3717 configs::Scores {
3718 decoder: configs::DecoderType::Ultralytics,
3719 quantization: Some(quant_boxes.into()),
3720 shape: vec![1, 80, 8400],
3721 dshape: vec![
3722 (DimName::Batch, 1),
3723 (DimName::NumClasses, 80),
3724 (DimName::NumBoxes, 8400),
3725 ],
3726 },
3727 configs::MaskCoefficients {
3728 decoder: configs::DecoderType::Ultralytics,
3729 quantization: Some(quant_boxes.into()),
3730 shape: vec![1, 32, 8400],
3731 dshape: vec![
3732 (DimName::Batch, 1),
3733 (DimName::NumProtos, 32),
3734 (DimName::NumBoxes, 8400),
3735 ],
3736 },
3737 configs::Protos {
3738 decoder: configs::DecoderType::Ultralytics,
3739 quantization: Some(quant_protos.into()),
3740 shape: vec![1, 160, 160, 32],
3741 dshape: vec![
3742 (DimName::Batch, 1),
3743 (DimName::Height, 160),
3744 (DimName::Width, 160),
3745 (DimName::NumProtos, 32),
3746 ],
3747 },
3748 )
3749 .with_score_threshold(score_threshold)
3750 .with_iou_threshold(iou_threshold)
3751 .build()
3752 .unwrap();
3753
3754 let expected_boxes = [
3755 DetectBox {
3756 bbox: BoundingBox {
3757 xmin: 0.08515105,
3758 ymin: 0.7131401,
3759 xmax: 0.29802868,
3760 ymax: 0.8195788,
3761 },
3762 score: 0.91537374,
3763 label: 23,
3764 },
3765 DetectBox {
3766 bbox: BoundingBox {
3767 xmin: 0.59605736,
3768 ymin: 0.25545314,
3769 xmax: 0.93666154,
3770 ymax: 0.72378385,
3771 },
3772 score: 0.91537374,
3773 label: 23,
3774 },
3775 ];
3776
3777 let mut tracker = ByteTrackBuilder::new()
3778 .track_update(0.1)
3779 .track_high_conf(0.7)
3780 .build();
3781
3782 let mut output_boxes = Vec::with_capacity(50);
3783 let mut output_tracks = Vec::with_capacity(50);
3784
3785 decoder
3786 .decode_tracked_quantized_proto(
3787 &mut tracker,
3788 0,
3789 &[
3790 boxes.view().into(),
3791 scores.view().into(),
3792 mask.view().into(),
3793 protos.view().into(),
3794 ],
3795 &mut output_boxes,
3796 &mut output_tracks,
3797 )
3798 .unwrap();
3799
3800 assert_eq!(output_boxes.len(), 2);
3801 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3802 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3803
3804 for score in scores.iter_mut() {
3807 *score = i8::MIN; }
3809 let protos = decoder
3810 .decode_tracked_quantized_proto(
3811 &mut tracker,
3812 100_000_000 / 3,
3813 &[
3814 boxes.view().into(),
3815 scores.view().into(),
3816 mask.view().into(),
3817 protos.view().into(),
3818 ],
3819 &mut output_boxes,
3820 &mut output_tracks,
3821 )
3822 .unwrap();
3823
3824 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3825 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3826
3827 assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
3829 }
3830
/// Tracked decode of split YOLOv8 segmentation/detection float tensors via
/// `decode_tracked_float_proto`.
///
/// The first call feeds the dequantized test dump at timestamp 0 and expects
/// exactly two detections (label 23). The second call feeds the same tensors
/// with every score overwritten by zero and checks that `output_boxes` still
/// holds both expected boxes and that the returned proto data has no mask
/// coefficients.
#[test]
fn test_decoder_tracked_segdet_split_proto_float() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    // The raw bytes are reinterpreted as signed 8-bit samples. `u8 as i8` keeps the
    // bit pattern, so a plain iterator replaces the raw-pointer reinterpretation.
    let boxes: Vec<i8> = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ))
    .iter()
    .map(|&byte| byte as i8)
    .collect();
    let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
    let quant_boxes = (0.021287762, 31);
    let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());

    // Split the 116 channels to match the configured outputs below:
    // 4 box coordinates, 80 class scores, 32 mask coefficients.
    let mask = boxes.slice(s![.., 84.., ..]).to_owned();
    let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
    let boxes = boxes.slice(s![.., ..4, ..]).to_owned();

    let protos: Vec<i8> = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    ))
    .iter()
    .map(|&byte| byte as i8)
    .collect();
    let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
    let quant_protos = (0.02491162, -117);
    let protos = dequantize_ndarray(protos.view(), quant_protos.into());

    let decoder = DecoderBuilder::default()
        .with_config_yolo_split_segdet(
            configs::Boxes {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(quant_boxes.into()),
                shape: vec![1, 4, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::BoxCoords, 4),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            configs::Scores {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(quant_boxes.into()),
                shape: vec![1, 80, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 80),
                    (DimName::NumBoxes, 8400),
                ],
            },
            configs::MaskCoefficients {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(quant_boxes.into()),
                shape: vec![1, 32, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumProtos, 32),
                    (DimName::NumBoxes, 8400),
                ],
            },
            configs::Protos {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(quant_protos.into()),
                shape: vec![1, 160, 160, 32],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                    (DimName::NumProtos, 32),
                ],
            },
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let expected_boxes = [
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.08515105,
                ymin: 0.7131401,
                xmax: 0.29802868,
                ymax: 0.8195788,
            },
            score: 0.91537374,
            label: 23,
        },
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.59605736,
                ymin: 0.25545314,
                xmax: 0.93666154,
                ymax: 0.72378385,
            },
            score: 0.91537374,
            label: 23,
        },
    ];

    let mut tracker = ByteTrackBuilder::new()
        .track_update(0.1)
        .track_high_conf(0.7)
        .build();

    let mut output_boxes = Vec::with_capacity(50);
    let mut output_tracks = Vec::with_capacity(50);

    // Frame 1 (timestamp 0): real scores, two detections expected.
    decoder
        .decode_tracked_float_proto(
            &mut tracker,
            0,
            &[
                boxes.view().into_dyn(),
                scores.view().into_dyn(),
                mask.view().into_dyn(),
                protos.view().into_dyn(),
            ],
            &mut output_boxes,
            &mut output_tracks,
        )
        .unwrap();

    assert_eq!(output_boxes.len(), 2);
    for (idx, expected) in expected_boxes.iter().enumerate() {
        assert!(output_boxes[idx].equal_within_delta(expected, 1.0 / 160.0));
    }

    // Frame 2: every score is zeroed before decoding again at a later timestamp.
    scores.fill(0.0);

    let protos = decoder
        .decode_tracked_float_proto(
            &mut tracker,
            100_000_000 / 3,
            &[
                boxes.view().into_dyn(),
                scores.view().into_dyn(),
                mask.view().into_dyn(),
                protos.view().into_dyn(),
            ],
            &mut output_boxes,
            &mut output_tracks,
        )
        .unwrap();

    // Both boxes must still be reported (presumably carried by the tracker; confirm
    // against the decoder implementation), now within a much tighter tolerance.
    for (idx, expected) in expected_boxes.iter().enumerate() {
        assert!(output_boxes[idx].equal_within_delta(expected, 1e-6));
    }

    // The returned proto data must exist but carry no mask coefficients.
    assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
}
3987
/// Tracked decode of a quantized end-to-end (`yolo26`) segmentation/detection
/// output through `decode_tracked_quantized`.
///
/// A single detection row is placed in an otherwise all-zero `[1, 10, 38]`
/// tensor. The first frame must yield exactly one box. Feature column 4 (the
/// score) is then zeroed and a second frame is decoded; the test requires
/// `output_boxes[0]` to still match the expected box and the mask output to be
/// empty.
#[test]
fn test_decoder_tracked_end_to_end_segdet() {
    let score_thr = 0.45;
    let iou_thr = 0.45;

    let yaml = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 38]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 38]
   normalized: true
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";

    let decoder = DecoderBuilder::default()
        .with_config_yaml_str(yaml.to_string())
        .with_score_threshold(score_thr)
        .with_iou_threshold(iou_thr)
        .build()
        .unwrap();

    // All-zero prototypes, quantized to u8.
    let proto_f = Array3::<f64>::zeros((160, 160, 32)).insert_axis(Axis(0));
    let proto_quant = (1.0 / 255.0, 0.0);
    let proto_q: Array4<u8> = quantize_ndarray(proto_f.view(), proto_quant.into());

    // Only row 0 carries a detection; the remaining nine rows stay zero.
    let mut box_rows = Array2::zeros((10, 4));
    box_rows
        .row_mut(0)
        .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
    let mut score_col = Array2::zeros((10, 1));
    score_col.row_mut(0).assign(&array![0.9876]);
    let mut class_col = Array2::zeros((10, 1));
    class_col.row_mut(0).assign(&array![2.0]);
    let coeff_rows = Array2::zeros((10, 32));

    // Feature layout per row: 4 box values, 1 score, 1 class, 32 coefficients.
    let stacked = ndarray::concatenate![
        Axis(1),
        box_rows.view(),
        score_col.view(),
        class_col.view(),
        coeff_rows.view()
    ]
    .insert_axis(Axis(0));
    assert_eq!(stacked.shape(), &[1, 10, 38]);
    let detect_quant = (2.0 / 255.0, 0.0);
    let mut detect_q: Array3<u8> = quantize_ndarray(stacked.view(), detect_quant.into());

    // Expected values are the inputs after a round trip through the u8 quantization.
    let expected = DetectBox {
        bbox: BoundingBox {
            xmin: 0.12549022,
            ymin: 0.12549022,
            xmax: 0.23529413,
            ymax: 0.23529413,
        },
        score: 0.98823535,
        label: 2,
    };

    let mut tracker = ByteTrackBuilder::new()
        .track_update(0.1)
        .track_high_conf(0.7)
        .build();

    let mut det_out = Vec::with_capacity(50);
    let mut mask_out = Vec::with_capacity(50);
    let mut track_out = Vec::with_capacity(50);

    // Frame 1 (timestamp 0).
    decoder
        .decode_tracked_quantized(
            &mut tracker,
            0,
            &[detect_q.view().into(), proto_q.view().into()],
            &mut det_out,
            &mut mask_out,
            &mut track_out,
        )
        .unwrap();

    assert_eq!(det_out.len(), 1);
    assert!(det_out[0].equal_within_delta(&expected, 1.0 / 160.0));

    // Zero the score column so the second frame has no scored detection.
    detect_q.slice_mut(s![.., .., 4]).fill(u8::MIN);

    // Frame 2 at a later timestamp.
    decoder
        .decode_tracked_quantized(
            &mut tracker,
            100_000_000 / 3,
            &[detect_q.view().into(), proto_q.view().into()],
            &mut det_out,
            &mut mask_out,
            &mut track_out,
        )
        .unwrap();
    assert!(det_out[0].equal_within_delta(&expected, 1e-6));
    assert!(mask_out.is_empty());
}
4105
/// Tracked decode of a float end-to-end (`yolo26`) segmentation/detection
/// output through `decode_tracked_float`.
///
/// One detection row lives in an otherwise all-zero `[1, 10, 38]` tensor. The
/// first frame must produce exactly one box. The score column is then set to
/// zero and a second frame is decoded; `output_boxes[0]` must still match the
/// expected box and no masks may be emitted.
#[test]
fn test_decoder_tracked_end_to_end_segdet_float() {
    let score_thr = 0.45;
    let iou_thr = 0.45;

    let yaml = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 38]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 38]
   normalized: true
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";

    let decoder = DecoderBuilder::default()
        .with_config_yaml_str(yaml.to_string())
        .with_score_threshold(score_thr)
        .with_iou_threshold(iou_thr)
        .build()
        .unwrap();

    // All-zero float prototypes with a leading batch axis.
    let proto_f = Array3::<f64>::zeros((160, 160, 32)).insert_axis(Axis(0));

    // Only row 0 carries a detection; the remaining nine rows stay zero.
    let mut box_rows = Array2::zeros((10, 4));
    box_rows
        .row_mut(0)
        .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
    let mut score_col = Array2::zeros((10, 1));
    score_col.row_mut(0).assign(&array![0.9876]);
    let mut class_col = Array2::zeros((10, 1));
    class_col.row_mut(0).assign(&array![2.0]);
    let coeff_rows = Array2::zeros((10, 32));

    // Feature layout per row: 4 box values, 1 score, 1 class, 32 coefficients.
    let mut detect_f = ndarray::concatenate![
        Axis(1),
        box_rows.view(),
        score_col.view(),
        class_col.view(),
        coeff_rows.view()
    ]
    .insert_axis(Axis(0));
    assert_eq!(detect_f.shape(), &[1, 10, 38]);

    // No quantization round trip here, so the expected values equal the inputs.
    let expected = DetectBox {
        bbox: BoundingBox {
            xmin: 0.1234,
            ymin: 0.1234,
            xmax: 0.2345,
            ymax: 0.2345,
        },
        score: 0.9876,
        label: 2,
    };

    let mut tracker = ByteTrackBuilder::new()
        .track_update(0.1)
        .track_high_conf(0.7)
        .build();

    let mut det_out = Vec::with_capacity(50);
    let mut mask_out = Vec::with_capacity(50);
    let mut track_out = Vec::with_capacity(50);

    // Frame 1 (timestamp 0).
    decoder
        .decode_tracked_float(
            &mut tracker,
            0,
            &[detect_f.view().into_dyn(), proto_f.view().into_dyn()],
            &mut det_out,
            &mut mask_out,
            &mut track_out,
        )
        .unwrap();

    assert_eq!(det_out.len(), 1);
    assert!(det_out[0].equal_within_delta(&expected, 1.0 / 160.0));

    // Zero the score column so the second frame has no scored detection.
    detect_f.slice_mut(s![.., .., 4]).fill(0.0);

    // Frame 2 at a later timestamp.
    decoder
        .decode_tracked_float(
            &mut tracker,
            100_000_000 / 3,
            &[detect_f.view().into_dyn(), proto_f.view().into_dyn()],
            &mut det_out,
            &mut mask_out,
            &mut track_out,
        )
        .unwrap();
    assert!(det_out[0].equal_within_delta(&expected, 1e-6));
    assert!(mask_out.is_empty());
}
4217
/// Tracked decode of a quantized end-to-end (`yolo26`) segmentation/detection
/// output through `decode_tracked_quantized_proto`.
///
/// One detection row lives in an otherwise all-zero `[1, 10, 38]` tensor. The
/// first frame must produce exactly one box. The score column is then zeroed
/// and a second frame is decoded; `output_boxes[0]` must still match the
/// expected box and the returned proto data must hold no mask coefficients.
#[test]
fn test_decoder_tracked_end_to_end_segdet_proto() {
    let score_thr = 0.45;
    let iou_thr = 0.45;

    let yaml = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 38]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 38]
   normalized: true
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";

    let decoder = DecoderBuilder::default()
        .with_config_yaml_str(yaml.to_string())
        .with_score_threshold(score_thr)
        .with_iou_threshold(iou_thr)
        .build()
        .unwrap();

    // All-zero prototypes, quantized to u8.
    let proto_f = Array3::<f64>::zeros((160, 160, 32)).insert_axis(Axis(0));
    let proto_quant = (1.0 / 255.0, 0.0);
    let proto_q: Array4<u8> = quantize_ndarray(proto_f.view(), proto_quant.into());

    // Only row 0 carries a detection; the remaining nine rows stay zero.
    let mut box_rows = Array2::zeros((10, 4));
    box_rows
        .row_mut(0)
        .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
    let mut score_col = Array2::zeros((10, 1));
    score_col.row_mut(0).assign(&array![0.9876]);
    let mut class_col = Array2::zeros((10, 1));
    class_col.row_mut(0).assign(&array![2.0]);
    let coeff_rows = Array2::zeros((10, 32));

    // Feature layout per row: 4 box values, 1 score, 1 class, 32 coefficients.
    let stacked = ndarray::concatenate![
        Axis(1),
        box_rows.view(),
        score_col.view(),
        class_col.view(),
        coeff_rows.view()
    ]
    .insert_axis(Axis(0));
    assert_eq!(stacked.shape(), &[1, 10, 38]);
    let detect_quant = (2.0 / 255.0, 0.0);
    let mut detect_q: Array3<u8> = quantize_ndarray(stacked.view(), detect_quant.into());

    // Expected values are the inputs after a round trip through the u8 quantization.
    let expected = DetectBox {
        bbox: BoundingBox {
            xmin: 0.12549022,
            ymin: 0.12549022,
            xmax: 0.23529413,
            ymax: 0.23529413,
        },
        score: 0.98823535,
        label: 2,
    };

    let mut tracker = ByteTrackBuilder::new()
        .track_update(0.1)
        .track_high_conf(0.7)
        .build();

    let mut det_out = Vec::with_capacity(50);
    let mut track_out = Vec::with_capacity(50);

    // Frame 1 (timestamp 0); the proto result of this call is not inspected.
    decoder
        .decode_tracked_quantized_proto(
            &mut tracker,
            0,
            &[detect_q.view().into(), proto_q.view().into()],
            &mut det_out,
            &mut track_out,
        )
        .unwrap();

    assert_eq!(det_out.len(), 1);
    assert!(det_out[0].equal_within_delta(&expected, 1.0 / 160.0));

    // Zero the score column so the second frame has no scored detection.
    detect_q.slice_mut(s![.., .., 4]).fill(u8::MIN);

    // Frame 2 at a later timestamp.
    let proto_data = decoder
        .decode_tracked_quantized_proto(
            &mut tracker,
            100_000_000 / 3,
            &[detect_q.view().into(), proto_q.view().into()],
            &mut det_out,
            &mut track_out,
        )
        .unwrap();
    assert!(det_out[0].equal_within_delta(&expected, 1e-6));
    assert!(proto_data.is_some_and(|p| p.mask_coefficients.is_empty()));
}
4332
/// Tracked decode of a float end-to-end (`yolo26`) segmentation/detection
/// output through `decode_tracked_float_proto`.
///
/// One detection row lives in an otherwise all-zero `[1, 10, 38]` tensor. The
/// first frame must produce exactly one box. The score column is then set to
/// zero and a second frame is decoded; `output_boxes[0]` must still match the
/// expected box and the returned proto data must hold no mask coefficients.
#[test]
fn test_decoder_tracked_end_to_end_segdet_proto_float() {
    let score_thr = 0.45;
    let iou_thr = 0.45;

    let yaml = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 38]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 38]
   normalized: true
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";

    let decoder = DecoderBuilder::default()
        .with_config_yaml_str(yaml.to_string())
        .with_score_threshold(score_thr)
        .with_iou_threshold(iou_thr)
        .build()
        .unwrap();

    // All-zero float prototypes with a leading batch axis.
    let proto_f = Array3::<f64>::zeros((160, 160, 32)).insert_axis(Axis(0));

    // Only row 0 carries a detection; the remaining nine rows stay zero.
    let mut box_rows = Array2::zeros((10, 4));
    box_rows
        .row_mut(0)
        .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
    let mut score_col = Array2::zeros((10, 1));
    score_col.row_mut(0).assign(&array![0.9876]);
    let mut class_col = Array2::zeros((10, 1));
    class_col.row_mut(0).assign(&array![2.0]);
    let coeff_rows = Array2::zeros((10, 32));

    // Feature layout per row: 4 box values, 1 score, 1 class, 32 coefficients.
    let mut detect_f = ndarray::concatenate![
        Axis(1),
        box_rows.view(),
        score_col.view(),
        class_col.view(),
        coeff_rows.view()
    ]
    .insert_axis(Axis(0));
    assert_eq!(detect_f.shape(), &[1, 10, 38]);

    // No quantization round trip here, so the expected values equal the inputs.
    let expected = DetectBox {
        bbox: BoundingBox {
            xmin: 0.1234,
            ymin: 0.1234,
            xmax: 0.2345,
            ymax: 0.2345,
        },
        score: 0.9876,
        label: 2,
    };

    let mut tracker = ByteTrackBuilder::new()
        .track_update(0.1)
        .track_high_conf(0.7)
        .build();

    let mut det_out = Vec::with_capacity(50);
    let mut track_out = Vec::with_capacity(50);

    // Frame 1 (timestamp 0); the proto result of this call is not inspected.
    decoder
        .decode_tracked_float_proto(
            &mut tracker,
            0,
            &[detect_f.view().into_dyn(), proto_f.view().into_dyn()],
            &mut det_out,
            &mut track_out,
        )
        .unwrap();

    assert_eq!(det_out.len(), 1);
    assert!(det_out[0].equal_within_delta(&expected, 1.0 / 160.0));

    // Zero the score column so the second frame has no scored detection.
    detect_f.slice_mut(s![.., .., 4]).fill(0.0);

    // Frame 2 at a later timestamp.
    let proto_data = decoder
        .decode_tracked_float_proto(
            &mut tracker,
            100_000_000 / 3,
            &[detect_f.view().into_dyn(), proto_f.view().into_dyn()],
            &mut det_out,
            &mut track_out,
        )
        .unwrap();
    assert!(det_out[0].equal_within_delta(&expected, 1e-6));
    assert!(proto_data.is_some_and(|p| p.mask_coefficients.is_empty()));
}
4441
/// Tracked decode of quantized, split end-to-end (`yolo26`) outputs through
/// `decode_tracked_quantized`.
///
/// Boxes, scores, classes, mask coefficients and protos are supplied as five
/// separate u8 tensors, in the same order as the outputs listed in the config.
/// The first frame must yield exactly one box. All scores are then zeroed and a
/// second frame is decoded; `output_boxes[0]` must still match the expected box
/// and the mask output must be empty.
#[test]
fn test_decoder_tracked_end_to_end_segdet_split() {
    let score_thr = 0.45;
    let iou_thr = 0.45;

    let yaml = "
decoder_version: yolo26
outputs:
 - type: boxes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 4]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [box_coords, 4]
   normalized: true
 - type: scores
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: classes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: mask_coefficients
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 32]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_protos, 32]
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";

    let decoder = DecoderBuilder::default()
        .with_config_yaml_str(yaml.to_string())
        .with_score_threshold(score_thr)
        .with_iou_threshold(iou_thr)
        .build()
        .unwrap();

    // All-zero prototypes, quantized to u8.
    let proto_f = Array3::<f64>::zeros((160, 160, 32)).insert_axis(Axis(0));
    let proto_quant = (1.0 / 255.0, 0.0);
    let proto_q: Array4<u8> = quantize_ndarray(proto_f.view(), proto_quant.into());

    // Only row 0 carries a detection; the remaining nine rows stay zero.
    let mut box_rows = Array2::zeros((10, 4));
    box_rows
        .row_mut(0)
        .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
    let mut score_col = Array2::zeros((10, 1));
    score_col.row_mut(0).assign(&array![0.9876]);
    let mut class_col = Array2::zeros((10, 1));
    class_col.row_mut(0).assign(&array![2.0]);
    let coeff_rows: Array2<f64> = Array2::zeros((10, 32));

    // Add the batch axis to every tensor before quantizing.
    let box_t = box_rows.insert_axis(Axis(0));
    let score_t = score_col.insert_axis(Axis(0));
    let class_t = class_col.insert_axis(Axis(0));
    let coeff_t = coeff_rows.insert_axis(Axis(0));

    let detect_quant = (2.0 / 255.0, 0.0);
    let box_q: Array3<u8> = quantize_ndarray(box_t.view(), detect_quant.into());
    let mut score_q: Array3<u8> = quantize_ndarray(score_t.view(), detect_quant.into());
    let class_q: Array3<u8> = quantize_ndarray(class_t.view(), detect_quant.into());
    let coeff_q: Array3<u8> = quantize_ndarray(coeff_t.view(), detect_quant.into());

    // Expected values are the inputs after a round trip through the u8 quantization.
    let expected = DetectBox {
        bbox: BoundingBox {
            xmin: 0.12549022,
            ymin: 0.12549022,
            xmax: 0.23529413,
            ymax: 0.23529413,
        },
        score: 0.98823535,
        label: 2,
    };

    let mut tracker = ByteTrackBuilder::new()
        .track_update(0.1)
        .track_high_conf(0.7)
        .build();

    let mut det_out = Vec::with_capacity(50);
    let mut mask_out = Vec::with_capacity(50);
    let mut track_out = Vec::with_capacity(50);

    // Frame 1 (timestamp 0).
    decoder
        .decode_tracked_quantized(
            &mut tracker,
            0,
            &[
                box_q.view().into(),
                score_q.view().into(),
                class_q.view().into(),
                coeff_q.view().into(),
                proto_q.view().into(),
            ],
            &mut det_out,
            &mut mask_out,
            &mut track_out,
        )
        .unwrap();

    assert_eq!(det_out.len(), 1);
    assert!(det_out[0].equal_within_delta(&expected, 1.0 / 160.0));

    // Zero every score so the second frame has no scored detection.
    score_q.fill(u8::MIN);

    // Frame 2 at a later timestamp.
    decoder
        .decode_tracked_quantized(
            &mut tracker,
            100_000_000 / 3,
            &[
                box_q.view().into(),
                score_q.view().into(),
                class_q.view().into(),
                coeff_q.view().into(),
                proto_q.view().into(),
            ],
            &mut det_out,
            &mut mask_out,
            &mut track_out,
        )
        .unwrap();
    assert!(det_out[0].equal_within_delta(&expected, 1e-6));
    assert!(mask_out.is_empty());
}
/// Tracked decode of float, split end-to-end (`yolo26`) outputs through
/// `decode_tracked_float`.
///
/// Boxes, scores, classes, mask coefficients and protos are supplied as five
/// separate float tensors, in the same order as the outputs listed in the
/// config. The first frame must yield exactly one box. All scores are then set
/// to zero and a second frame is decoded; `output_boxes[0]` must still match the
/// expected box and the mask output must be empty.
#[test]
fn test_decoder_tracked_end_to_end_segdet_split_float() {
    let score_thr = 0.45;
    let iou_thr = 0.45;

    let yaml = "
decoder_version: yolo26
outputs:
 - type: boxes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 4]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [box_coords, 4]
   normalized: true
 - type: scores
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: classes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: mask_coefficients
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 32]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_protos, 32]
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";

    let decoder = DecoderBuilder::default()
        .with_config_yaml_str(yaml.to_string())
        .with_score_threshold(score_thr)
        .with_iou_threshold(iou_thr)
        .build()
        .unwrap();

    // All-zero float prototypes with a leading batch axis.
    let proto_f = Array3::<f64>::zeros((160, 160, 32)).insert_axis(Axis(0));

    // Only row 0 carries a detection; the remaining nine rows stay zero.
    let mut box_rows = Array2::zeros((10, 4));
    box_rows
        .row_mut(0)
        .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
    let mut score_col = Array2::zeros((10, 1));
    score_col.row_mut(0).assign(&array![0.9876]);
    let mut class_col = Array2::zeros((10, 1));
    class_col.row_mut(0).assign(&array![2.0]);
    let coeff_rows: Array2<f64> = Array2::zeros((10, 32));

    // Add the batch axis to every tensor.
    let box_t = box_rows.insert_axis(Axis(0));
    let mut score_t = score_col.insert_axis(Axis(0));
    let class_t = class_col.insert_axis(Axis(0));
    let coeff_t = coeff_rows.insert_axis(Axis(0));

    // No quantization round trip here, so the expected values equal the inputs.
    let expected = DetectBox {
        bbox: BoundingBox {
            xmin: 0.1234,
            ymin: 0.1234,
            xmax: 0.2345,
            ymax: 0.2345,
        },
        score: 0.9876,
        label: 2,
    };

    let mut tracker = ByteTrackBuilder::new()
        .track_update(0.1)
        .track_high_conf(0.7)
        .build();

    let mut det_out = Vec::with_capacity(50);
    let mut mask_out = Vec::with_capacity(50);
    let mut track_out = Vec::with_capacity(50);

    // Frame 1 (timestamp 0).
    decoder
        .decode_tracked_float(
            &mut tracker,
            0,
            &[
                box_t.view().into_dyn(),
                score_t.view().into_dyn(),
                class_t.view().into_dyn(),
                coeff_t.view().into_dyn(),
                proto_f.view().into_dyn(),
            ],
            &mut det_out,
            &mut mask_out,
            &mut track_out,
        )
        .unwrap();

    assert_eq!(det_out.len(), 1);
    assert!(det_out[0].equal_within_delta(&expected, 1.0 / 160.0));

    // Zero every score so the second frame has no scored detection.
    score_t.fill(0.0);

    // Frame 2 at a later timestamp.
    decoder
        .decode_tracked_float(
            &mut tracker,
            100_000_000 / 3,
            &[
                box_t.view().into_dyn(),
                score_t.view().into_dyn(),
                class_t.view().into_dyn(),
                coeff_t.view().into_dyn(),
                proto_f.view().into_dyn(),
            ],
            &mut det_out,
            &mut mask_out,
            &mut track_out,
        )
        .unwrap();
    assert!(det_out[0].equal_within_delta(&expected, 1e-6));
    assert!(mask_out.is_empty());
}
4739
/// Tracked decode of quantized, split end-to-end (`yolo26`) outputs through
/// `decode_tracked_quantized_proto`.
///
/// Boxes, scores, classes, mask coefficients and protos are supplied as five
/// separate u8 tensors, in the same order as the outputs listed in the config.
/// The first frame must yield exactly one box. All scores are then zeroed and a
/// second frame is decoded; `output_boxes[0]` must still match the expected box
/// and the returned proto data must hold no mask coefficients.
#[test]
fn test_decoder_tracked_end_to_end_segdet_split_proto() {
    let score_thr = 0.45;
    let iou_thr = 0.45;

    let yaml = "
decoder_version: yolo26
outputs:
 - type: boxes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 4]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [box_coords, 4]
   normalized: true
 - type: scores
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: classes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: mask_coefficients
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 32]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_protos, 32]
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";

    let decoder = DecoderBuilder::default()
        .with_config_yaml_str(yaml.to_string())
        .with_score_threshold(score_thr)
        .with_iou_threshold(iou_thr)
        .build()
        .unwrap();

    // All-zero prototypes, quantized to u8.
    let proto_f = Array3::<f64>::zeros((160, 160, 32)).insert_axis(Axis(0));
    let proto_quant = (1.0 / 255.0, 0.0);
    let proto_q: Array4<u8> = quantize_ndarray(proto_f.view(), proto_quant.into());

    // Only row 0 carries a detection; the remaining nine rows stay zero.
    let mut box_rows = Array2::zeros((10, 4));
    box_rows
        .row_mut(0)
        .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
    let mut score_col = Array2::zeros((10, 1));
    score_col.row_mut(0).assign(&array![0.9876]);
    let mut class_col = Array2::zeros((10, 1));
    class_col.row_mut(0).assign(&array![2.0]);
    let coeff_rows: Array2<f64> = Array2::zeros((10, 32));

    // Add the batch axis to every tensor before quantizing.
    let box_t = box_rows.insert_axis(Axis(0));
    let score_t = score_col.insert_axis(Axis(0));
    let class_t = class_col.insert_axis(Axis(0));
    let coeff_t = coeff_rows.insert_axis(Axis(0));

    let detect_quant = (2.0 / 255.0, 0.0);
    let box_q: Array3<u8> = quantize_ndarray(box_t.view(), detect_quant.into());
    let mut score_q: Array3<u8> = quantize_ndarray(score_t.view(), detect_quant.into());
    let class_q: Array3<u8> = quantize_ndarray(class_t.view(), detect_quant.into());
    let coeff_q: Array3<u8> = quantize_ndarray(coeff_t.view(), detect_quant.into());

    // Expected values are the inputs after a round trip through the u8 quantization.
    let expected = DetectBox {
        bbox: BoundingBox {
            xmin: 0.12549022,
            ymin: 0.12549022,
            xmax: 0.23529413,
            ymax: 0.23529413,
        },
        score: 0.98823535,
        label: 2,
    };

    let mut tracker = ByteTrackBuilder::new()
        .track_update(0.1)
        .track_high_conf(0.7)
        .build();

    let mut det_out = Vec::with_capacity(50);
    let mut track_out = Vec::with_capacity(50);

    // Frame 1 (timestamp 0); the proto result of this call is not inspected.
    decoder
        .decode_tracked_quantized_proto(
            &mut tracker,
            0,
            &[
                box_q.view().into(),
                score_q.view().into(),
                class_q.view().into(),
                coeff_q.view().into(),
                proto_q.view().into(),
            ],
            &mut det_out,
            &mut track_out,
        )
        .unwrap();

    assert_eq!(det_out.len(), 1);
    assert!(det_out[0].equal_within_delta(&expected, 1.0 / 160.0));

    // Zero every score so the second frame has no scored detection.
    score_q.fill(u8::MIN);

    // Frame 2 at a later timestamp.
    let proto_data = decoder
        .decode_tracked_quantized_proto(
            &mut tracker,
            100_000_000 / 3,
            &[
                box_q.view().into(),
                score_q.view().into(),
                class_q.view().into(),
                coeff_q.view().into(),
                proto_q.view().into(),
            ],
            &mut det_out,
            &mut track_out,
        )
        .unwrap();
    assert!(det_out[0].equal_within_delta(&expected, 1e-6));
    assert!(proto_data.is_some_and(|p| p.mask_coefficients.is_empty()));
}
4890
4891 #[test]
4892 fn test_decoder_tracked_end_to_end_segdet_split_proto_float() {
4893 let score_threshold = 0.45;
4894 let iou_threshold = 0.45;
4895
4896 let mut boxes = Array2::zeros((10, 4));
4897 let mut scores = Array2::zeros((10, 1));
4898 let mut classes = Array2::zeros((10, 1));
4899 let mask: Array2<f64> = Array2::zeros((10, 32));
4900 let protos = Array3::<f64>::zeros((160, 160, 32));
4901 let protos = protos.insert_axis(Axis(0));
4902
4903 boxes
4904 .slice_mut(s![0, ..,])
4905 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4906 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4907 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4908
4909 let boxes = boxes.insert_axis(Axis(0));
4910 let mut scores = scores.insert_axis(Axis(0));
4911 let classes = classes.insert_axis(Axis(0));
4912 let mask = mask.insert_axis(Axis(0));
4913
4914 let config = "
4915decoder_version: yolo26
4916outputs:
4917 - type: boxes
4918 decoder: ultralytics
4919 quantization: [0.00784313725490196, 0]
4920 shape: [1, 10, 4]
4921 dshape:
4922 - [batch, 1]
4923 - [num_boxes, 10]
4924 - [box_coords, 4]
4925 normalized: true
4926 - type: scores
4927 decoder: ultralytics
4928 quantization: [0.00784313725490196, 0]
4929 shape: [1, 10, 1]
4930 dshape:
4931 - [batch, 1]
4932 - [num_boxes, 10]
4933 - [num_classes, 1]
4934 - type: classes
4935 decoder: ultralytics
4936 quantization: [0.00784313725490196, 0]
4937 shape: [1, 10, 1]
4938 dshape:
4939 - [batch, 1]
4940 - [num_boxes, 10]
4941 - [num_classes, 1]
4942 - type: mask_coefficients
4943 decoder: ultralytics
4944 quantization: [0.00784313725490196, 0]
4945 shape: [1, 10, 32]
4946 dshape:
4947 - [batch, 1]
4948 - [num_boxes, 10]
4949 - [num_protos, 32]
4950 - type: protos
4951 decoder: ultralytics
4952 quantization: [0.0039215686274509803921568627451, 128]
4953 shape: [1, 160, 160, 32]
4954 dshape:
4955 - [batch, 1]
4956 - [height, 160]
4957 - [width, 160]
4958 - [num_protos, 32]
4959";
4960
4961 let decoder = DecoderBuilder::default()
4962 .with_config_yaml_str(config.to_string())
4963 .with_score_threshold(score_threshold)
4964 .with_iou_threshold(iou_threshold)
4965 .build()
4966 .unwrap();
4967
4968 let expected_boxes = [DetectBox {
4970 bbox: BoundingBox {
4971 xmin: 0.1234,
4972 ymin: 0.1234,
4973 xmax: 0.2345,
4974 ymax: 0.2345,
4975 },
4976 score: 0.9876,
4977 label: 2,
4978 }];
4979
4980 let mut tracker = ByteTrackBuilder::new()
4981 .track_update(0.1)
4982 .track_high_conf(0.7)
4983 .build();
4984
4985 let mut output_boxes = Vec::with_capacity(50);
4986 let mut output_tracks = Vec::with_capacity(50);
4987
4988 decoder
4989 .decode_tracked_float_proto(
4990 &mut tracker,
4991 0,
4992 &[
4993 boxes.view().into_dyn(),
4994 scores.view().into_dyn(),
4995 classes.view().into_dyn(),
4996 mask.view().into_dyn(),
4997 protos.view().into_dyn(),
4998 ],
4999 &mut output_boxes,
5000 &mut output_tracks,
5001 )
5002 .unwrap();
5003
5004 assert_eq!(output_boxes.len(), 1);
5005 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
5006
5007 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5010 *score = 0.0; }
5012
5013 let protos = decoder
5014 .decode_tracked_float_proto(
5015 &mut tracker,
5016 100_000_000 / 3,
5017 &[
5018 boxes.view().into_dyn(),
5019 scores.view().into_dyn(),
5020 classes.view().into_dyn(),
5021 mask.view().into_dyn(),
5022 protos.view().into_dyn(),
5023 ],
5024 &mut output_boxes,
5025 &mut output_tracks,
5026 )
5027 .unwrap();
5028 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5029 assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()))
5031 }
5032
5033 #[test]
5034 fn test_decoder_tracked_linear_motion() {
5035 use crate::configs::{DecoderType, Nms};
5036 use crate::DecoderBuilder;
5037
5038 let score_threshold = 0.25;
5039 let iou_threshold = 0.1;
5040 let out = include_bytes!(concat!(
5041 env!("CARGO_MANIFEST_DIR"),
5042 "/../../testdata/yolov8s_80_classes.bin"
5043 ));
5044 let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
5045 let mut out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
5046 let quant = (0.0040811873, -123).into();
5047
5048 let decoder = DecoderBuilder::default()
5049 .with_config_yolo_det(
5050 crate::configs::Detection {
5051 decoder: DecoderType::Ultralytics,
5052 shape: vec![1, 84, 8400],
5053 anchors: None,
5054 quantization: Some(quant),
5055 dshape: vec![
5056 (crate::configs::DimName::Batch, 1),
5057 (crate::configs::DimName::NumFeatures, 84),
5058 (crate::configs::DimName::NumBoxes, 8400),
5059 ],
5060 normalized: Some(true),
5061 },
5062 None,
5063 )
5064 .with_score_threshold(score_threshold)
5065 .with_iou_threshold(iou_threshold)
5066 .with_nms(Some(Nms::ClassAgnostic))
5067 .build()
5068 .unwrap();
5069
5070 let mut expected_boxes = [
5071 DetectBox {
5072 bbox: BoundingBox {
5073 xmin: 0.5285137,
5074 ymin: 0.05305544,
5075 xmax: 0.87541467,
5076 ymax: 0.9998909,
5077 },
5078 score: 0.5591227,
5079 label: 0,
5080 },
5081 DetectBox {
5082 bbox: BoundingBox {
5083 xmin: 0.130598,
5084 ymin: 0.43260583,
5085 xmax: 0.35098213,
5086 ymax: 0.9958097,
5087 },
5088 score: 0.33057618,
5089 label: 75,
5090 },
5091 ];
5092
5093 let mut tracker = ByteTrackBuilder::new()
5094 .track_update(0.1)
5095 .track_high_conf(0.3)
5096 .build();
5097
5098 let mut output_boxes = Vec::with_capacity(50);
5099 let mut output_masks = Vec::with_capacity(50);
5100 let mut output_tracks = Vec::with_capacity(50);
5101
5102 decoder
5103 .decode_tracked_quantized(
5104 &mut tracker,
5105 0,
5106 &[out.view().into()],
5107 &mut output_boxes,
5108 &mut output_masks,
5109 &mut output_tracks,
5110 )
5111 .unwrap();
5112
5113 assert_eq!(output_boxes.len(), 2);
5114 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5115 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
5116
5117 for i in 1..=100 {
5118 let mut out = out.clone();
5119 let mut x_values = out.slice_mut(s![0, 0, ..]);
5121 for x in x_values.iter_mut() {
5122 *x = x.saturating_add((i as f32 * 1e-3 / quant.0).round() as i8);
5123 }
5124
5125 decoder
5126 .decode_tracked_quantized(
5127 &mut tracker,
5128 100_000_000 * i / 3, &[out.view().into()],
5130 &mut output_boxes,
5131 &mut output_masks,
5132 &mut output_tracks,
5133 )
5134 .unwrap();
5135
5136 assert_eq!(output_boxes.len(), 2);
5137 }
5138 let tracks = tracker.get_active_tracks();
5139 let predicted_boxes: Vec<_> = tracks
5140 .iter()
5141 .map(|track| {
5142 let mut l = track.last_box;
5143 l.bbox = track.info.tracked_location.into();
5144 l
5145 })
5146 .collect();
5147 expected_boxes[0].bbox.xmin += 0.1; expected_boxes[0].bbox.xmax += 0.1;
5149 expected_boxes[1].bbox.xmin += 0.1;
5150 expected_boxes[1].bbox.xmax += 0.1;
5151
5152 assert!(predicted_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
5153 assert!(predicted_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
5154
5155 let mut scores_values = out.slice_mut(s![0, 4.., ..]);
5157 for score in scores_values.iter_mut() {
5158 *score = i8::MIN; }
5160 decoder
5161 .decode_tracked_quantized(
5162 &mut tracker,
5163 100_000_000 * 101 / 3,
5164 &[out.view().into()],
5165 &mut output_boxes,
5166 &mut output_masks,
5167 &mut output_tracks,
5168 )
5169 .unwrap();
5170 expected_boxes[0].bbox.xmin += 0.001; expected_boxes[0].bbox.xmax += 0.001;
5172 expected_boxes[1].bbox.xmin += 0.001;
5173 expected_boxes[1].bbox.xmax += 0.001;
5174
5175 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
5176 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
5177 }
5178
5179 #[test]
5180 fn test_decoder_tracked_end_to_end_float() {
5181 let score_threshold = 0.45;
5182 let iou_threshold = 0.45;
5183
5184 let mut boxes = Array2::zeros((10, 4));
5185 let mut scores = Array2::zeros((10, 1));
5186 let mut classes = Array2::zeros((10, 1));
5187
5188 boxes
5189 .slice_mut(s![0, ..,])
5190 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5191 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5192 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5193
5194 let detect = ndarray::concatenate![Axis(1), boxes.view(), scores.view(), classes.view(),];
5195 let mut detect = detect.insert_axis(Axis(0));
5196 assert_eq!(detect.shape(), &[1, 10, 6]);
5197 let config = "
5198decoder_version: yolo26
5199outputs:
5200 - type: detection
5201 decoder: ultralytics
5202 quantization: [0.00784313725490196, 0]
5203 shape: [1, 10, 6]
5204 dshape:
5205 - [batch, 1]
5206 - [num_boxes, 10]
5207 - [num_features, 6]
5208 normalized: true
5209";
5210
5211 let decoder = DecoderBuilder::default()
5212 .with_config_yaml_str(config.to_string())
5213 .with_score_threshold(score_threshold)
5214 .with_iou_threshold(iou_threshold)
5215 .build()
5216 .unwrap();
5217
5218 let expected_boxes = [DetectBox {
5219 bbox: BoundingBox {
5220 xmin: 0.1234,
5221 ymin: 0.1234,
5222 xmax: 0.2345,
5223 ymax: 0.2345,
5224 },
5225 score: 0.9876,
5226 label: 2,
5227 }];
5228
5229 let mut tracker = ByteTrackBuilder::new()
5230 .track_update(0.1)
5231 .track_high_conf(0.7)
5232 .build();
5233
5234 let mut output_boxes = Vec::with_capacity(50);
5235 let mut output_masks = Vec::with_capacity(50);
5236 let mut output_tracks = Vec::with_capacity(50);
5237
5238 decoder
5239 .decode_tracked_float(
5240 &mut tracker,
5241 0,
5242 &[detect.view().into_dyn()],
5243 &mut output_boxes,
5244 &mut output_masks,
5245 &mut output_tracks,
5246 )
5247 .unwrap();
5248
5249 assert_eq!(output_boxes.len(), 1);
5250 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5251
5252 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5255 *score = 0.0; }
5257
5258 decoder
5259 .decode_tracked_float(
5260 &mut tracker,
5261 100_000_000 / 3,
5262 &[detect.view().into_dyn()],
5263 &mut output_boxes,
5264 &mut output_masks,
5265 &mut output_tracks,
5266 )
5267 .unwrap();
5268 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5269 }
5270}