1#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
68
69use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
70use num_traits::{AsPrimitive, Float, PrimInt};
71
72pub mod byte;
73pub mod error;
74pub mod float;
75pub mod modelpack;
76pub mod yolo;
77
78mod decoder;
79pub use decoder::*;
80
81pub use configs::{DecoderVersion, Nms};
82pub use error::{DecoderError, DecoderResult};
83
84use crate::{
85 decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
86 yolo::yolo_segmentation_to_mask,
87};
88
89pub trait BBoxTypeTrait {
91 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
93
94 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
96 input: &[B; 4],
97 quant: Quantization,
98 ) -> [A; 4]
99 where
100 f32: AsPrimitive<A>,
101 i32: AsPrimitive<A>;
102
103 #[inline(always)]
114 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
115 input: ArrayView1<B>,
116 ) -> [A; 4] {
117 Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
118 }
119
120 #[inline(always)]
121 fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
123 input: ArrayView1<B>,
124 quant: Quantization,
125 ) -> [A; 4]
126 where
127 f32: AsPrimitive<A>,
128 i32: AsPrimitive<A>,
129 {
130 Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
131 }
132}
133
134#[derive(Debug, Clone, Copy, PartialEq, Eq)]
136pub struct XYXY {}
137
138impl BBoxTypeTrait for XYXY {
139 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
140 input.map(|b| b.as_())
141 }
142
143 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
144 input: &[B; 4],
145 quant: Quantization,
146 ) -> [A; 4]
147 where
148 f32: AsPrimitive<A>,
149 i32: AsPrimitive<A>,
150 {
151 let scale = quant.scale.as_();
152 let zp = quant.zero_point.as_();
153 input.map(|b| (b.as_() - zp) * scale)
154 }
155
156 #[inline(always)]
157 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
158 input: ArrayView1<B>,
159 ) -> [A; 4] {
160 [
161 input[0].as_(),
162 input[1].as_(),
163 input[2].as_(),
164 input[3].as_(),
165 ]
166 }
167}
168
169#[derive(Debug, Clone, Copy, PartialEq, Eq)]
172pub struct XYWH {}
173
174impl BBoxTypeTrait for XYWH {
175 #[inline(always)]
176 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
177 let half = A::one() / (A::one() + A::one());
178 [
179 (input[0].as_()) - (input[2].as_() * half),
180 (input[1].as_()) - (input[3].as_() * half),
181 (input[0].as_()) + (input[2].as_() * half),
182 (input[1].as_()) + (input[3].as_() * half),
183 ]
184 }
185
186 #[inline(always)]
187 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
188 input: &[B; 4],
189 quant: Quantization,
190 ) -> [A; 4]
191 where
192 f32: AsPrimitive<A>,
193 i32: AsPrimitive<A>,
194 {
195 let scale = quant.scale.as_();
196 let half_scale = (quant.scale * 0.5).as_();
197 let zp = quant.zero_point.as_();
198 let [x, y, w, h] = [
199 (input[0].as_() - zp) * scale,
200 (input[1].as_() - zp) * scale,
201 (input[2].as_() - zp) * half_scale,
202 (input[3].as_() - zp) * half_scale,
203 ];
204
205 [x - w, y - h, x + w, y + h]
206 }
207
208 #[inline(always)]
209 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
210 input: ArrayView1<B>,
211 ) -> [A; 4] {
212 let half = A::one() / (A::one() + A::one());
213 [
214 (input[0].as_()) - (input[2].as_() * half),
215 (input[1].as_()) - (input[3].as_() * half),
216 (input[0].as_()) + (input[2].as_() * half),
217 (input[1].as_()) + (input[3].as_() * half),
218 ]
219 }
220}
221
/// Affine quantization parameters: `real = (quantized - zero_point) * scale`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quantization {
    /// Multiplier applied after the zero point has been removed.
    pub scale: f32,
    /// Quantized value that represents real zero.
    pub zero_point: i32,
}

impl Quantization {
    /// Creates quantization parameters from an explicit scale and zero point.
    pub fn new(scale: f32, zero_point: i32) -> Self {
        Quantization { zero_point, scale }
    }
}
242
243impl From<QuantTuple> for Quantization {
244 fn from(quant_tuple: QuantTuple) -> Quantization {
255 Quantization {
256 scale: quant_tuple.0,
257 zero_point: quant_tuple.1,
258 }
259 }
260}
261
262impl<S, Z> From<(S, Z)> for Quantization
263where
264 S: AsPrimitive<f32>,
265 Z: AsPrimitive<i32>,
266{
267 fn from((scale, zp): (S, Z)) -> Quantization {
276 Self {
277 scale: scale.as_(),
278 zero_point: zp.as_(),
279 }
280 }
281}
282
283impl Default for Quantization {
284 fn default() -> Self {
293 Self {
294 scale: 1.0,
295 zero_point: 0,
296 }
297 }
298}
299
/// A decoded detection: a bounding box with a float score and a label.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct DetectBox {
    /// Box coordinates.
    pub bbox: BoundingBox,
    /// Float score; for quantized models it is produced by `dequant_detect_box`.
    pub score: f32,
    /// Label index. NOTE(review): presumably the class index chosen by the
    /// decoder — confirm against the decoders.
    pub label: usize,
}
309
/// Axis-aligned bounding box stored as `xmin, ymin, xmax, ymax`.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct BoundingBox {
    /// Minimum x coordinate (after [`BoundingBox::to_canonical`]).
    pub xmin: f32,
    /// Minimum y coordinate (after [`BoundingBox::to_canonical`]).
    pub ymin: f32,
    /// Maximum x coordinate (after [`BoundingBox::to_canonical`]).
    pub xmax: f32,
    /// Maximum y coordinate (after [`BoundingBox::to_canonical`]).
    pub ymax: f32,
}

impl BoundingBox {
    /// Creates a box from its four edges, stored exactly as given.
    pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }

    /// Returns a copy whose `min` fields are not greater than its `max` fields.
    ///
    /// The two x values and the two y values are each sorted with
    /// `f32::min`/`f32::max`.
    pub fn to_canonical(&self) -> Self {
        let (lo_x, hi_x) = (self.xmin.min(self.xmax), self.xmin.max(self.xmax));
        let (lo_y, hi_y) = (self.ymin.min(self.ymax), self.ymin.max(self.ymax));
        Self::new(lo_x, lo_y, hi_x, hi_y)
    }
}
355
356impl From<BoundingBox> for [f32; 4] {
357 fn from(b: BoundingBox) -> Self {
372 [b.xmin, b.ymin, b.xmax, b.ymax]
373 }
374}
375
376impl From<[f32; 4]> for BoundingBox {
377 fn from(arr: [f32; 4]) -> Self {
380 BoundingBox {
381 xmin: arr[0],
382 ymin: arr[1],
383 xmax: arr[2],
384 ymax: arr[3],
385 }
386 }
387}
388
389impl DetectBox {
390 pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
419 let eq_delta = |a: f32, b: f32| (a - b).abs() <= eps;
420 self.label == rhs.label
421 && eq_delta(self.score, rhs.score)
422 && eq_delta(self.bbox.xmin, rhs.bbox.xmin)
423 && eq_delta(self.bbox.ymin, rhs.bbox.ymin)
424 && eq_delta(self.bbox.xmax, rhs.bbox.xmax)
425 && eq_delta(self.bbox.ymax, rhs.bbox.ymax)
426 }
427}
428
/// A segmentation output together with the rectangle it covers.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Segmentation {
    /// Left edge of the covered region. NOTE(review): the tests in this file use
    /// 0.0..1.0 for a full-frame mask, suggesting normalized coordinates — confirm.
    pub xmin: f32,
    /// Top edge of the covered region (same coordinate system as `xmin`).
    pub ymin: f32,
    /// Right edge of the covered region.
    pub xmax: f32,
    /// Bottom edge of the covered region.
    pub ymax: f32,
    /// Mask data. `segmentation_to_mask` treats axis 2 as the channel axis: a
    /// depth of 1 goes to the YOLO mask path, larger depths to the ModelPack path.
    /// The tests build it as (height, width, classes).
    pub segmentation: Array3<u8>,
}
450
/// Prototype tensor, stored either quantized or as `f32`.
///
/// NOTE(review): presumably the mask prototypes combined with the per-detection
/// coefficients in `ProtoData` — confirm against the decoder.
#[derive(Debug, Clone)]
pub enum ProtoTensor {
    /// Raw `i8` prototypes plus the parameters needed to dequantize them
    /// (see `ProtoTensor::as_f32`).
    Quantized {
        protos: Array3<i8>,
        quantization: Quantization,
    },
    /// Prototypes already in `f32`.
    Float(Array3<f32>),
}
468
469impl ProtoTensor {
470 pub fn is_quantized(&self) -> bool {
472 matches!(self, ProtoTensor::Quantized { .. })
473 }
474
475 pub fn dim(&self) -> (usize, usize, usize) {
477 match self {
478 ProtoTensor::Quantized { protos, .. } => protos.dim(),
479 ProtoTensor::Float(arr) => arr.dim(),
480 }
481 }
482
483 pub fn as_f32(&self) -> std::borrow::Cow<'_, Array3<f32>> {
486 match self {
487 ProtoTensor::Float(arr) => std::borrow::Cow::Borrowed(arr),
488 ProtoTensor::Quantized {
489 protos,
490 quantization,
491 } => {
492 let scale = quantization.scale;
493 let zp = quantization.zero_point as f32;
494 std::borrow::Cow::Owned(protos.map(|&v| (v as f32 - zp) * scale))
495 }
496 }
497 }
498}
499
/// Mask coefficients paired with a prototype tensor.
#[derive(Debug, Clone)]
pub struct ProtoData {
    /// Coefficient vectors. NOTE(review): presumably one vector per detection —
    /// confirm against the decoder that fills this.
    pub mask_coefficients: Vec<Vec<f32>>,
    /// Prototype tensor, quantized or float.
    pub protos: ProtoTensor,
}
512
513pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
531 detect: &DetectBoxQuantized<SCORE>,
532 quant_scores: Quantization,
533) -> DetectBox {
534 let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
535 DetectBox {
536 bbox: detect.bbox,
537 score: quant_scores.scale * detect.score.as_() + scaled_zp,
538 label: detect.label,
539 }
540}
/// A detection whose score is still the raw quantized integer.
///
/// Convert with [`dequant_detect_box`] using the score tensor's [`Quantization`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DetectBoxQuantized<
    // Raw score type; must be castable to `f32` for dequantization.
    SCORE: PrimInt + AsPrimitive<f32>,
> {
    /// Box coordinates (already float; only the score is quantized).
    pub bbox: BoundingBox,
    /// Raw quantized score.
    pub score: SCORE,
    /// Label index, copied unchanged into the dequantized [`DetectBox`].
    pub label: usize,
}
555
556pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
569 input: ArrayView<T, D>,
570 quant: Quantization,
571) -> Array<F, D>
572where
573 i32: num_traits::AsPrimitive<F>,
574 f32: num_traits::AsPrimitive<F>,
575{
576 let zero_point = quant.zero_point.as_();
577 let scale = quant.scale.as_();
578 if zero_point != F::zero() {
579 let scaled_zero = -zero_point * scale;
580 input.mapv(|d| d.as_() * scale + scaled_zero)
581 } else {
582 input.mapv(|d| d.as_() * scale)
583 }
584}
585
586pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
599 input: &[T],
600 quant: Quantization,
601 output: &mut [F],
602) where
603 f32: num_traits::AsPrimitive<F>,
604 i32: num_traits::AsPrimitive<F>,
605{
606 assert!(input.len() == output.len());
607 let zero_point = quant.zero_point.as_();
608 let scale = quant.scale.as_();
609 if zero_point != F::zero() {
610 let scaled_zero = -zero_point * scale; input
612 .iter()
613 .zip(output)
614 .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
615 } else {
616 input
617 .iter()
618 .zip(output)
619 .for_each(|(d, deq)| *deq = d.as_() * scale);
620 }
621}
622
623pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
637 input: &[T],
638 quant: Quantization,
639 output: &mut [F],
640) where
641 f32: num_traits::AsPrimitive<F>,
642 i32: num_traits::AsPrimitive<F>,
643{
644 assert!(input.len() == output.len());
645 let zero_point = quant.zero_point.as_();
646 let scale = quant.scale.as_();
647
648 let input = input.as_chunks::<4>();
649 let output = output.as_chunks_mut::<4>();
650
651 if zero_point != F::zero() {
652 let scaled_zero = -zero_point * scale; input
655 .0
656 .iter()
657 .zip(output.0)
658 .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
659 input
660 .1
661 .iter()
662 .zip(output.1)
663 .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
664 } else {
665 input
666 .0
667 .iter()
668 .zip(output.0)
669 .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
670 input
671 .1
672 .iter()
673 .zip(output.1)
674 .for_each(|(d, deq)| *deq = d.as_() * scale);
675 }
676}
677
678pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
696 if segmentation.shape()[2] == 0 {
697 return Err(DecoderError::InvalidShape(
698 "Segmentation tensor must have non-zero depth".to_string(),
699 ));
700 }
701 if segmentation.shape()[2] == 1 {
702 yolo_segmentation_to_mask(segmentation, 128)
703 } else {
704 Ok(modelpack_segmentation_to_mask(segmentation))
705 }
706}
707
708fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
710 score
711 .iter()
712 .enumerate()
713 .fold((score[0], 0), |(max, arg_max), (ind, s)| {
714 if max > *s {
715 (max, arg_max)
716 } else {
717 (*s, ind)
718 }
719 })
720}
721#[cfg(test)]
722#[cfg_attr(coverage_nightly, coverage(off))]
723mod decoder_tests {
724 #![allow(clippy::excessive_precision)]
725 use crate::{
726 configs::{DecoderType, DimName, Protos},
727 modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
728 yolo::{
729 decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
730 decode_yolo_segdet_quant,
731 },
732 *,
733 };
734 use ndarray::{array, s, Array4};
735 use ndarray_stats::DeviationExt;
736
737 fn compare_outputs(
738 boxes: (&[DetectBox], &[DetectBox]),
739 masks: (&[Segmentation], &[Segmentation]),
740 ) {
741 let (boxes0, boxes1) = boxes;
742 let (masks0, masks1) = masks;
743
744 assert_eq!(boxes0.len(), boxes1.len());
745 assert_eq!(masks0.len(), masks1.len());
746
747 for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
748 assert!(
749 b_i8.equal_within_delta(b_f32, 1e-6),
750 "{b_i8:?} is not equal to {b_f32:?}"
751 );
752 }
753
754 for (m_i8, m_f32) in masks0.iter().zip(masks1) {
755 assert_eq!(
756 [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
757 [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
758 );
759 assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
760 let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
761 let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
762 let diff = &mask_i8 - &mask_f32;
763 for x in 0..diff.shape()[0] {
764 for y in 0..diff.shape()[1] {
765 for z in 0..diff.shape()[2] {
766 let val = diff[[x, y, z]];
767 assert!(
768 val.abs() <= 1,
769 "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
770 x,
771 y,
772 z,
773 val
774 );
775 }
776 }
777 }
778 let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
779 assert!(
780 mean_sq_err < 1e-2,
781 "Mean Square Error between masks was greater than 1%: {:.2}%",
782 mean_sq_err * 100.0
783 );
784 }
785 }
786
    /// End-to-end check of the ModelPack detection path on the quantized
    /// `modelpack_boxes` / `modelpack_scores` fixtures.
    ///
    /// The low-level `decode_modelpack_det` result is first checked against a
    /// known first box. It then serves as the reference for
    /// `Decoder::decode_quantized`, and for `Decoder::decode_float` on
    /// dequantized copies of the same inputs.
    #[test]
    fn test_decoder_modelpack() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let boxes = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_boxes_1935x1x4.bin"
        ));
        let boxes = ndarray::Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();

        let scores = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_scores_1935x1.bin"
        ));
        let scores = ndarray::Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();

        // (scale, zero_point) pairs; the target type is inferred from the config
        // `quantization` fields they are stored in below.
        let quant_boxes = (0.004656755365431309, 21).into();
        let quant_scores = (0.0019603664986789227, 0).into();

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_det(
                configs::Boxes {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_boxes),
                    shape: vec![1, 1935, 1, 4],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::Padding, 1),
                        (DimName::BoxCoords, 4),
                    ],
                    normalized: Some(true),
                },
                configs::Scores {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_scores),
                    shape: vec![1, 1935, 1],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::NumClasses, 1),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        // Re-bind as `Quantization` for the free-function decode/dequantize calls.
        let quant_boxes = quant_boxes.into();
        let quant_scores = quant_scores.into();

        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
        decode_modelpack_det(
            (boxes.slice(s![0, .., 0, ..]), quant_boxes),
            (scores.slice(s![0, .., ..]), quant_scores),
            score_threshold,
            iou_threshold,
            &mut output_boxes,
        );
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.40513772,
                    ymin: 0.6379755,
                    xmax: 0.5122431,
                    ymax: 0.7730214,
                },
                score: 0.4861709,
                label: 0
            },
            1e-6
        ));

        let mut output_boxes1 = Vec::with_capacity(50);
        let mut output_masks1 = Vec::with_capacity(50);

        decoder
            .decode_quantized(
                &[boxes.view().into(), scores.view().into()],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();

        let mut output_boxes_float = Vec::with_capacity(50);
        let mut output_masks_float = Vec::with_capacity(50);

        let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
        let scores = dequantize_ndarray(scores.view(), quant_scores);

        decoder
            .decode_float::<f32>(
                &[boxes.view().into_dyn(), scores.view().into_dyn()],
                &mut output_boxes_float,
                &mut output_masks_float,
            )
            .unwrap();

        // Detection-only config: both decoder paths must match the reference boxes
        // and produce no masks.
        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
        compare_outputs(
            (&output_boxes, &output_boxes_float),
            (&[], &output_masks_float),
        );
    }
892
    /// Checks the split ModelPack detection path (two anchor-based output
    /// tensors, 9x15 and 17x30) on quantized `u8` fixtures.
    ///
    /// `decode_modelpack_split_quant` is first checked against a known first box.
    /// It then serves as the reference for `Decoder::decode_quantized`, and for
    /// `Decoder::decode_float` on dequantized inputs.
    #[test]
    fn test_decoder_modelpack_split_u8() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let detect0 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_9x15x18.bin"
        ));
        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

        let detect1 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_17x30x18.bin"
        ));
        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

        // (scale, zero_point) pairs for each output tensor.
        let quant0 = (0.08547406643629074, 174).into();
        let quant1 = (0.09929127991199493, 183).into();
        // Three [w, h]-style anchor pairs per output tensor.
        // NOTE(review): anchor component order assumed — confirm.
        let anchors0 = vec![
            [0.36666667461395264, 0.31481480598449707],
            [0.38749998807907104, 0.4740740656852722],
            [0.5333333611488342, 0.644444465637207],
        ];
        let anchors1 = vec![
            [0.13750000298023224, 0.2074074000120163],
            [0.2541666626930237, 0.21481481194496155],
            [0.23125000298023224, 0.35185185074806213],
        ];

        let detect_config0 = configs::Detection {
            decoder: DecoderType::ModelPack,
            shape: vec![1, 9, 15, 18],
            anchors: Some(anchors0.clone()),
            quantization: Some(quant0),
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::Height, 9),
                (DimName::Width, 15),
                (DimName::NumAnchorsXFeatures, 18),
            ],
            normalized: Some(true),
        };

        let detect_config1 = configs::Detection {
            decoder: DecoderType::ModelPack,
            shape: vec![1, 17, 30, 18],
            anchors: Some(anchors1.clone()),
            quantization: Some(quant1),
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::Height, 17),
                (DimName::Width, 30),
                (DimName::NumAnchorsXFeatures, 18),
            ],
            normalized: Some(true),
        };

        // Low-level configs for `decode_modelpack_split_quant`.
        let config0 = (&detect_config0).try_into().unwrap();
        let config1 = (&detect_config1).try_into().unwrap();

        // NOTE(review): the builder receives the configs as [1, 0] while the tensors
        // below are passed as [0, 1]; presumably the decoder matches tensors to
        // configs by shape — confirm.
        let decoder = DecoderBuilder::default()
            .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        let quant0 = quant0.into();
        let quant1 = quant1.into();

        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        decode_modelpack_split_quant(
            &[
                detect0.slice(s![0, .., .., ..]),
                detect1.slice(s![0, .., .., ..]),
            ],
            &[config0, config1],
            score_threshold,
            iou_threshold,
            &mut output_boxes,
        );
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.43171933,
                    ymin: 0.68243736,
                    xmax: 0.5626645,
                    ymax: 0.808863,
                },
                score: 0.99240804,
                label: 0
            },
            1e-6
        ));

        let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
        let mut output_masks1: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[detect0.view().into(), detect1.view().into()],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();

        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);

        let detect0 = dequantize_ndarray(detect0.view(), quant0);
        let detect1 = dequantize_ndarray(detect1.view(), quant1);
        decoder
            .decode_float::<f32>(
                &[detect0.view().into_dyn(), detect1.view().into_dyn()],
                &mut output_boxes1_f32,
                &mut output_masks1_f32,
            )
            .unwrap();

        // Detection-only config: both paths must match the reference with no masks.
        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
        compare_outputs(
            (&output_boxes, &output_boxes1_f32),
            (&[], &output_masks1_f32),
        );
    }
1017
    /// Builds the split ModelPack decoder from `testdata/modelpack_split.yaml`
    /// and checks that decoding the quantized fixtures reproduces the same first
    /// box as the programmatically configured test.
    #[test]
    fn test_decoder_parse_config_modelpack_split_u8() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let detect0 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_9x15x18.bin"
        ));
        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

        let detect1 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_17x30x18.bin"
        ));
        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

        let decoder = DecoderBuilder::default()
            .with_config_yaml_str(
                include_str!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/../../testdata/modelpack_split.yaml"
                ))
                .to_string(),
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);
        // NOTE(review): tensors are passed as [1, 0]; presumably the decoder
        // matches them to the YAML outputs by shape rather than position — confirm.
        decoder
            .decode_quantized(
                &[
                    ArrayViewDQuantized::from(detect1.view()),
                    ArrayViewDQuantized::from(detect0.view()),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.43171933,
                    ymin: 0.68243736,
                    xmax: 0.5626645,
                    ymax: 0.808863,
                },
                score: 0.99240804,
                label: 0
            },
            1e-6
        ));
    }
1073
    /// Checks the ModelPack segmentation-only path.
    ///
    /// - The quantized decode must return the input tensor, transposed to
    ///   (height, width, classes), as a single full-frame mask and no boxes.
    /// - The float decode of the dequantized input must yield the same mask
    ///   after `segmentation_to_mask`.
    #[test]
    fn test_modelpack_seg() {
        let out = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let out = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
        let quant = (1.0 / 255.0, 0).into();

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_seg(configs::Segmentation {
                decoder: DecoderType::ModelPack,
                quantization: Some(quant),
                shape: vec![1, 2, 160, 160],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 2),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                ],
            })
            .build()
            .unwrap();
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
            .unwrap();

        // Expected mask: (classes, height, width) -> (height, width, classes).
        let mut mask = out.slice(s![0, .., .., ..]);
        mask.swap_axes(0, 1);
        mask.swap_axes(1, 2);
        let mask = [Segmentation {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.0,
            ymax: 1.0,
            segmentation: mask.into_owned(),
        }];
        compare_outputs((&[], &output_boxes), (&mask, &output_masks));

        decoder
            .decode_float::<f32>(
                &[dequantize_ndarray(out.view(), quant.into())
                    .view()
                    .into_dyn()],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Raw float-path mask values are not compared element-wise (empty mask
        // slices are passed). Instead, the masks derived by `segmentation_to_mask`
        // must match exactly.
        compare_outputs((&[], &output_boxes), (&[], &[]));
        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

        assert_eq!(mask0, mask1);
    }
    /// Decodes the same segmentation fixture re-encoded as i8, u16, i16, u32 and
    /// i32, and checks that every integer type yields the same mask as u8.
    ///
    /// Each re-encoding is a monotonic shift/scale of the u8 values, and all of
    /// them share the same `(1/255, 0)` quantization. NOTE(review): the masks are
    /// expected to match because `segmentation_to_mask` presumably depends only on
    /// the relative order of class values — confirm.
    #[test]
    fn test_modelpack_seg_quant() {
        let out = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let out_u8 = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
        // Shift to signed ranges and/or move into the high bits of wider types.
        let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
        let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
        let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
        let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
        let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);

        let quant = (1.0 / 255.0, 0).into();

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_seg(configs::Segmentation {
                decoder: DecoderType::ModelPack,
                quantization: Some(quant),
                shape: vec![1, 2, 160, 160],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 2),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                ],
            })
            .build()
            .unwrap();
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_u8.view().into()],
                &mut output_boxes,
                &mut output_masks_u8,
            )
            .unwrap();

        let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_i8.view().into()],
                &mut output_boxes,
                &mut output_masks_i8,
            )
            .unwrap();

        let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_u16.view().into()],
                &mut output_boxes,
                &mut output_masks_u16,
            )
            .unwrap();

        let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_i16.view().into()],
                &mut output_boxes,
                &mut output_masks_i16,
            )
            .unwrap();

        let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_u32.view().into()],
                &mut output_boxes,
                &mut output_masks_u32,
            )
            .unwrap();

        let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_i32.view().into()],
                &mut output_boxes,
                &mut output_masks_i32,
            )
            .unwrap();

        // Segmentation-only config: no boxes expected after the last decode.
        compare_outputs((&[], &output_boxes), (&[], &[]));
        let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
        let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
        let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
        let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
        let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
        let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();
        assert_eq!(mask_u8, mask_i8);
        assert_eq!(mask_u8, mask_u16);
        assert_eq!(mask_u8, mask_i16);
        assert_eq!(mask_u8, mask_u32);
        assert_eq!(mask_u8, mask_i32);
    }
1233
    /// Checks the combined ModelPack detection + segmentation path.
    ///
    /// The quantized decode must yield the known box and the transposed input
    /// segmentation. The float decode of dequantized inputs must yield the same
    /// box and the same `segmentation_to_mask` result.
    #[test]
    fn test_modelpack_segdet() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;

        let boxes = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_boxes_1935x1x4.bin"
        ));
        let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();

        let scores = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_scores_1935x1.bin"
        ));
        let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();

        let seg = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

        let quant_boxes = (0.004656755365431309, 21).into();
        let quant_scores = (0.0019603664986789227, 0).into();
        let quant_seg = (1.0 / 255.0, 0).into();

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_segdet(
                configs::Boxes {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_boxes),
                    shape: vec![1, 1935, 1, 4],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::Padding, 1),
                        (DimName::BoxCoords, 4),
                    ],
                    normalized: Some(true),
                },
                configs::Scores {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_scores),
                    shape: vec![1, 1935, 1],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::NumClasses, 1),
                    ],
                },
                configs::Segmentation {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_seg),
                    shape: vec![1, 2, 160, 160],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumClasses, 2),
                        (DimName::Height, 160),
                        (DimName::Width, 160),
                    ],
                },
            )
            .with_iou_threshold(iou_threshold)
            .with_score_threshold(score_threshold)
            .build()
            .unwrap();
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);
        // NOTE(review): inputs are passed as (scores, boxes, seg), not in config
        // order; presumably the decoder matches them by shape — confirm.
        decoder
            .decode_quantized(
                &[scores.view().into(), boxes.view().into(), seg.view().into()],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Expected mask: (classes, height, width) -> (height, width, classes).
        let mut mask = seg.slice(s![0, .., .., ..]);
        mask.swap_axes(0, 1);
        mask.swap_axes(1, 2);
        let mask = [Segmentation {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.0,
            ymax: 1.0,
            segmentation: mask.into_owned(),
        }];
        let correct_boxes = [DetectBox {
            bbox: BoundingBox {
                xmin: 0.40513772,
                ymin: 0.6379755,
                xmax: 0.5122431,
                ymax: 0.7730214,
            },
            score: 0.4861709,
            label: 0,
        }];
        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));

        let scores = dequantize_ndarray(scores.view(), quant_scores.into());
        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
        decoder
            .decode_float::<f32>(
                &[
                    scores.view().into_dyn(),
                    boxes.view().into_dyn(),
                    seg.view().into_dyn(),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Float-path masks are compared via `segmentation_to_mask` rather than
        // element-wise (empty mask slices are passed here).
        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

        assert_eq!(mask0, mask1);
    }
1359
    /// Checks the split ModelPack detection + segmentation path (two
    /// anchor-based detection tensors plus one segmentation tensor).
    ///
    /// The quantized decode must yield the known box and the transposed input
    /// segmentation. The float decode of dequantized inputs must yield the same
    /// box and the same `segmentation_to_mask` result.
    #[test]
    fn test_modelpack_segdet_split() {
        let score_threshold = 0.8;
        let iou_threshold = 0.5;

        let seg = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

        let detect0 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_9x15x18.bin"
        ));
        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

        let detect1 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_17x30x18.bin"
        ));
        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

        let quant0 = (0.08547406643629074, 174).into();
        let quant1 = (0.09929127991199493, 183).into();
        let quant_seg = (1.0 / 255.0, 0).into();

        let anchors0 = vec![
            [0.36666667461395264, 0.31481480598449707],
            [0.38749998807907104, 0.4740740656852722],
            [0.5333333611488342, 0.644444465637207],
        ];
        let anchors1 = vec![
            [0.13750000298023224, 0.2074074000120163],
            [0.2541666626930237, 0.21481481194496155],
            [0.23125000298023224, 0.35185185074806213],
        ];

        // NOTE(review): the detection configs are listed as [1, 0] while tensors are
        // passed as [0, 1] below; presumably matched by shape — confirm.
        let decoder = DecoderBuilder::default()
            .with_config_modelpack_segdet_split(
                vec![
                    configs::Detection {
                        decoder: DecoderType::ModelPack,
                        shape: vec![1, 17, 30, 18],
                        anchors: Some(anchors1),
                        quantization: Some(quant1),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::Height, 17),
                            (DimName::Width, 30),
                            (DimName::NumAnchorsXFeatures, 18),
                        ],
                        normalized: Some(true),
                    },
                    configs::Detection {
                        decoder: DecoderType::ModelPack,
                        shape: vec![1, 9, 15, 18],
                        anchors: Some(anchors0),
                        quantization: Some(quant0),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::Height, 9),
                            (DimName::Width, 15),
                            (DimName::NumAnchorsXFeatures, 18),
                        ],
                        normalized: Some(true),
                    },
                ],
                configs::Segmentation {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_seg),
                    shape: vec![1, 2, 160, 160],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumClasses, 2),
                        (DimName::Height, 160),
                        (DimName::Width, 160),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[
                    detect0.view().into(),
                    detect1.view().into(),
                    seg.view().into(),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Expected mask: (classes, height, width) -> (height, width, classes).
        let mut mask = seg.slice(s![0, .., .., ..]);
        mask.swap_axes(0, 1);
        mask.swap_axes(1, 2);
        let mask = [Segmentation {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.0,
            ymax: 1.0,
            segmentation: mask.into_owned(),
        }];
        let correct_boxes = [DetectBox {
            bbox: BoundingBox {
                xmin: 0.43171933,
                ymin: 0.68243736,
                xmax: 0.5626645,
                ymax: 0.808863,
            },
            score: 0.99240804,
            label: 0,
        }];
        println!("Output Boxes: {:?}", output_boxes);
        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));

        let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
        let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
        decoder
            .decode_float::<f32>(
                &[
                    detect0.view().into_dyn(),
                    detect1.view().into_dyn(),
                    seg.view().into_dyn(),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Float-path masks are compared via `segmentation_to_mask` rather than
        // element-wise (empty mask slices are passed here).
        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

        assert_eq!(mask0, mask1);
    }
1507
/// The chunked dequantizer must produce exactly the same output as the scalar one.
#[test]
fn test_dequant_chunked() {
    let raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8s_80_classes.bin"
    ));
    // Reinterpret each byte as int8 (two's complement) without `unsafe`.
    let mut out: Vec<i8> = raw.iter().map(|&b| b as i8).collect();
    // One extra element so the input length is not a round multiple of the
    // tensor size. NOTE(review): presumably this exercises the chunked tail path; confirm.
    out.push(123);

    // Exercise both a non-zero and a zero zero-point.
    for zero_point in [-123, 0] {
        let quant = Quantization::new(0.0040811873, zero_point);
        // Fresh buffers per case so a value left over from an earlier case can
        // never hide an element the chunked path failed to write.
        let mut out_dequant = vec![0.0; out.len()];
        let mut out_dequant_simd = vec![0.0; out.len()];

        dequantize_cpu(&out, quant, &mut out_dequant);
        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
        assert_eq!(
            out_dequant, out_dequant_simd,
            "scalar and chunked dequantization differ for zero_point {zero_point}"
        );
    }
}
1532
/// YOLO detection: the config-driven decoder (quantized and float paths) must
/// agree with the low-level `decode_yolo_det` reference.
#[test]
fn test_decoder_yolo_det() {
    let score_threshold = 0.25;
    let iou_threshold = 0.7;
    let raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8s_80_classes.bin"
    ));
    // Reinterpret each byte as int8 (two's complement) without `unsafe`.
    let out: Vec<i8> = raw.iter().map(|&b| b as i8).collect();
    let out = Array3::from_shape_vec((1, 84, 8400), out).unwrap();
    let quant = (0.0040811873, -123).into();

    let decoder = DecoderBuilder::default()
        .with_config_yolo_det(
            configs::Detection {
                decoder: DecoderType::Ultralytics,
                shape: vec![1, 84, 8400],
                anchors: None,
                quantization: Some(quant),
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumFeatures, 84),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            Some(DecoderVersion::Yolo11),
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Reference result from the low-level quantized decoder.
    let mut output_boxes: Vec<_> = Vec::with_capacity(50);
    decode_yolo_det(
        (out.slice(s![0, .., ..]), quant.into()),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes,
    );
    // Fail with a clear message instead of an index-out-of-bounds panic below.
    assert!(
        output_boxes.len() >= 2,
        "expected at least 2 detections, got {}",
        output_boxes.len()
    );
    assert!(output_boxes[0].equal_within_delta(
        &DetectBox {
            bbox: BoundingBox {
                xmin: 0.5285137,
                ymin: 0.05305544,
                xmax: 0.87541467,
                ymax: 0.9998909,
            },
            score: 0.5591227,
            label: 0
        },
        1e-6
    ));

    assert!(output_boxes[1].equal_within_delta(
        &DetectBox {
            bbox: BoundingBox {
                xmin: 0.130598,
                ymin: 0.43260583,
                xmax: 0.35098213,
                ymax: 0.9958097,
            },
            score: 0.33057618,
            label: 75
        },
        1e-6
    ));

    // Config-driven decoder on the quantized tensor.
    let mut output_boxes1: Vec<_> = Vec::with_capacity(50);
    let mut output_masks1: Vec<_> = Vec::with_capacity(50);
    decoder
        .decode_quantized(&[out.view().into()], &mut output_boxes1, &mut output_masks1)
        .unwrap();

    // Config-driven decoder on the dequantized (f32) tensor.
    let out = dequantize_ndarray(out.view(), quant.into());
    let mut output_boxes_f32: Vec<_> = Vec::with_capacity(50);
    let mut output_masks_f32: Vec<_> = Vec::with_capacity(50);
    decoder
        .decode_float::<f32>(
            &[out.view().into_dyn()],
            &mut output_boxes_f32,
            &mut output_masks_f32,
        )
        .unwrap();

    // Detection-only model: no masks are expected from either path.
    compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
    compare_outputs((&output_boxes, &output_boxes_f32), (&[], &output_masks_f32));
}
1622
/// YOLO segmentation on dequantized data: checks the boxes, the mask crop
/// coordinates, and one mask's pixels against a stored reference.
#[test]
fn test_decoder_masks() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;
    let boxes_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ));
    // Reinterpret each byte as int8 (two's complement) without `unsafe`.
    let boxes: Vec<i8> = boxes_raw.iter().map(|&b| b as i8).collect();
    let boxes = ndarray::Array2::from_shape_vec((116, 8400), boxes).unwrap();
    let quant_boxes = Quantization::new(0.021287761628627777, 31);

    let protos_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    ));
    let protos: Vec<i8> = protos_raw.iter().map(|&b| b as i8).collect();
    let protos = ndarray::Array3::from_shape_vec((160, 160, 32), protos).unwrap();
    let quant_protos = Quantization::new(0.02491161972284317, -117);
    let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
    let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
    let mut output_boxes: Vec<_> = Vec::with_capacity(10);
    let mut output_masks: Vec<_> = Vec::with_capacity(10);
    decode_yolo_segdet_float(
        seg.view(),
        protos.view(),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes,
        &mut output_masks,
    )
    .unwrap();
    assert_eq!(output_boxes.len(), 2);
    assert_eq!(output_boxes.len(), output_masks.len());

    // Every coordinate of the mask crop must be <= the matching box coordinate.
    // NOTE(review): this is consistent with box edges being truncated onto the
    // 160x160 proto grid; confirm that is the intended rounding.
    for (i, (b, m)) in output_boxes.iter().zip(&output_masks).enumerate() {
        assert!(b.bbox.xmin >= m.xmin, "detection {i}: mask xmin exceeds box xmin");
        assert!(b.bbox.ymin >= m.ymin, "detection {i}: mask ymin exceeds box ymin");
        assert!(b.bbox.xmax >= m.xmax, "detection {i}: mask xmax exceeds box xmax");
        assert!(b.bbox.ymax >= m.ymax, "detection {i}: mask ymax exceeds box ymax");
    }
    assert!(output_boxes[0].equal_within_delta(
        &DetectBox {
            bbox: BoundingBox {
                xmin: 0.08515105,
                ymin: 0.7131401,
                xmax: 0.29802868,
                ymax: 0.8195788,
            },
            score: 0.91537374,
            label: 23
        },
        1.0 / 160.0,
    ));

    assert!(output_boxes[1].equal_within_delta(
        &DetectBox {
            bbox: BoundingBox {
                xmin: 0.59605736,
                ymin: 0.25545314,
                xmax: 0.93666154,
                ymax: 0.72378385,
            },
            score: 0.91537374,
            label: 23
        },
        1.0 / 160.0,
    ));

    // Reference full-frame mask; crop it to the second detection's mask region.
    let full_mask = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_mask_results.bin"
    ));
    let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();

    let cropped_mask = full_mask.slice(ndarray::s![
        (output_masks[1].ymin * 160.0) as usize..(output_masks[1].ymax * 160.0) as usize,
        (output_masks[1].xmin * 160.0) as usize..(output_masks[1].xmax * 160.0) as usize,
    ]);

    assert_eq!(
        cropped_mask,
        segmentation_to_mask(output_masks[1].segmentation.view()).unwrap()
    );
}
1710
/// Protos supplied in NCHW layout through the config-driven decoder must give
/// the same boxes and masks as the HWC reference path.
#[test]
fn test_decoder_masks_nchw_protos() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let boxes_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ));
    // Reinterpret each byte as int8 (two's complement) without `unsafe`.
    let boxes_i8: Vec<i8> = boxes_raw.iter().map(|&b| b as i8).collect();
    let boxes_2d = ndarray::Array2::from_shape_vec((116, 8400), boxes_i8).unwrap();
    let quant_boxes = Quantization::new(0.021287761628627777, 31);

    let protos_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    ));
    let protos_i8: Vec<i8> = protos_raw.iter().map(|&b| b as i8).collect();
    let protos_hwc = ndarray::Array3::from_shape_vec((160, 160, 32), protos_i8).unwrap();
    let quant_protos = Quantization::new(0.02491161972284317, -117);
    let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);

    // Reference path: HWC protos decoded directly.
    let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
    let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
    let mut ref_masks: Vec<_> = Vec::with_capacity(10);
    decode_yolo_segdet_float(
        seg.view(),
        protos_f32_hwc.view(),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut ref_boxes,
        &mut ref_masks,
    )
    .unwrap();
    assert_eq!(ref_boxes.len(), 2);
    assert_eq!(ref_masks.len(), ref_boxes.len());

    // Config-driven path: protos permuted HWC -> CHW, then given a batch axis
    // (1, 32, 160, 160); boxes given a batch axis (1, 116, 8400).
    let protos_nchw = protos_f32_hwc
        .permuted_axes([2, 0, 1])
        .insert_axis(ndarray::Axis(0));
    let seg_3d = seg.insert_axis(ndarray::Axis(0));
    let decoder = DecoderBuilder::default()
        .with_config_yolo_segdet(
            configs::Detection {
                decoder: configs::DecoderType::Ultralytics,
                quantization: None,
                shape: vec![1, 116, 8400],
                dshape: vec![],
                normalized: Some(true),
                anchors: None,
            },
            configs::Protos {
                decoder: configs::DecoderType::Ultralytics,
                quantization: None,
                shape: vec![1, 32, 160, 160],
                dshape: vec![],
            },
            None,
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let mut cfg_boxes: Vec<_> = Vec::with_capacity(10);
    let mut cfg_masks: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_float(
            &[seg_3d.view().into_dyn(), protos_nchw.view().into_dyn()],
            &mut cfg_boxes,
            &mut cfg_masks,
        )
        .unwrap();

    assert_eq!(
        cfg_boxes.len(),
        ref_boxes.len(),
        "config path produced {} boxes, reference produced {}",
        cfg_boxes.len(),
        ref_boxes.len()
    );
    // `zip` stops at the shorter input, so the mask counts must be checked
    // explicitly or a missing mask would go unnoticed below.
    assert_eq!(
        cfg_masks.len(),
        ref_masks.len(),
        "config path produced {} masks, reference produced {}",
        cfg_masks.len(),
        ref_masks.len()
    );

    for (i, (cb, rb)) in cfg_boxes.iter().zip(&ref_boxes).enumerate() {
        assert!(
            cb.equal_within_delta(rb, 0.01),
            "box {i} mismatch: config={cb:?}, reference={rb:?}"
        );
    }

    for (i, (cm, rm)) in cfg_masks.iter().zip(&ref_masks).enumerate() {
        let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
        let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
        assert_eq!(
            cm_arr, rm_arr,
            "mask {i} pixel mismatch between config-driven and reference paths"
        );
    }
}
1828
/// int8 YOLO segmentation: low-level quantized, low-level float, and
/// config-driven (quantized and float) decoding must all agree.
#[test]
fn test_decoder_masks_i8() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;
    let boxes_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ));
    // Reinterpret each byte as int8 (two's complement) without `unsafe`.
    let boxes: Vec<i8> = boxes_raw.iter().map(|&b| b as i8).collect();
    let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
    let quant_boxes = (0.021287761628627777, 31).into();

    let protos_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    ));
    let protos: Vec<i8> = protos_raw.iter().map(|&b| b as i8).collect();
    let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
    let quant_protos = (0.02491161972284317, -117).into();
    let mut output_boxes: Vec<_> = Vec::with_capacity(500);
    let mut output_masks: Vec<_> = Vec::with_capacity(500);

    let decoder = DecoderBuilder::default()
        .with_config_yolo_segdet(
            configs::Detection {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(quant_boxes),
                shape: vec![1, 116, 8400],
                anchors: None,
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumFeatures, 116),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            Protos {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(quant_protos),
                shape: vec![1, 160, 160, 32],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                    (DimName::NumProtos, 32),
                ],
            },
            Some(DecoderVersion::Yolo11),
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Convert the config tuples into the quantization type used by the low-level API.
    let quant_boxes = quant_boxes.into();
    let quant_protos = quant_protos.into();

    decode_yolo_segdet_quant(
        (boxes.slice(s![0, .., ..]), quant_boxes),
        (protos.slice(s![0, .., .., ..]), quant_protos),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes,
        &mut output_masks,
    )
    .unwrap();

    let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
    let mut output_masks1: Vec<_> = Vec::with_capacity(500);

    decoder
        .decode_quantized(
            &[boxes.view().into(), protos.view().into()],
            &mut output_boxes1,
            &mut output_masks1,
        )
        .unwrap();

    let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
    let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);

    let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
    let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
    decode_yolo_segdet_float(
        seg.slice(s![0, .., ..]),
        protos.slice(s![0, .., .., ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes_f32,
        &mut output_masks_f32,
    )
    .unwrap();

    let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
    let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);

    decoder
        .decode_float(
            &[seg.view().into_dyn(), protos.view().into_dyn()],
            &mut output_boxes1_f32,
            &mut output_masks1_f32,
        )
        .unwrap();

    // Low-level quantized vs config-driven quantized.
    compare_outputs(
        (&output_boxes, &output_boxes1),
        (&output_masks, &output_masks1),
    );

    // Low-level quantized vs low-level float.
    compare_outputs(
        (&output_boxes, &output_boxes_f32),
        (&output_masks, &output_masks_f32),
    );

    // Low-level float vs config-driven float.
    compare_outputs(
        (&output_boxes_f32, &output_boxes1_f32),
        (&output_masks_f32, &output_masks1_f32),
    );
}
1951
/// YOLO detection with separate box and score tensors (i16): the config-driven
/// decoder must agree with the low-level float decoder.
#[test]
fn test_decoder_yolo_split() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;
    let boxes_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ));
    // Widen each int8 value (bytes read as two's complement, no `unsafe`) to
    // int16 and scale by 256. The scale and zero point below are adjusted by the
    // same factor, so dequantized values equal the int8 ones up to float rounding.
    let boxes: Vec<i16> = boxes_raw
        .iter()
        .map(|&b| i16::from(b as i8) * 256)
        .collect();
    let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();

    let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);

    let decoder = DecoderBuilder::default()
        .with_config_yolo_split_det(
            configs::Boxes {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 4, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::BoxCoords, 4),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            configs::Scores {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 80, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 80),
                    (DimName::NumBoxes, 8400),
                ],
            },
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let mut output_boxes: Vec<_> = Vec::with_capacity(500);
    let mut output_masks: Vec<_> = Vec::with_capacity(500);

    // Rows 0..4 are box coordinates, rows 4..84 are the 80 class scores.
    decoder
        .decode_quantized(
            &[
                boxes.slice(s![.., ..4, ..]).into(),
                boxes.slice(s![.., 4..84, ..]).into(),
            ],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
    let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
    decode_yolo_det_float(
        seg.slice(s![0, ..84, ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes_f32,
    );

    let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
    let mut output_masks1: Vec<_> = Vec::with_capacity(500);

    decoder
        .decode_float(
            &[
                seg.slice(s![.., ..4, ..]).into_dyn(),
                seg.slice(s![.., 4..84, ..]).into_dyn(),
            ],
            &mut output_boxes1,
            &mut output_masks1,
        )
        .unwrap();
    compare_outputs((&output_boxes, &output_boxes_f32), (&output_masks, &[]));
    compare_outputs((&output_boxes_f32, &output_boxes1), (&[], &output_masks1));
}
2035
/// Split-output segmentation with mixed integer types: i16 boxes, scores and
/// mask coefficients alongside i8 protos. The config-driven decoders must agree
/// with the low-level float decoder.
#[test]
fn test_decoder_masks_config_mixed() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;
    let boxes_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ));
    // Widen each int8 value (bytes read as two's complement, no `unsafe`) to
    // int16 and scale by 256. The scale and zero point below are adjusted by the
    // same factor, so dequantized values equal the int8 ones up to float rounding.
    let boxes: Vec<i16> = boxes_raw
        .iter()
        .map(|&b| i16::from(b as i8) * 256)
        .collect();
    let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();

    let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);

    let protos_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    ));
    // Single owned copy of the int8 protos, moved straight into the array.
    let protos: Vec<i8> = protos_raw.iter().map(|&b| b as i8).collect();
    let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
    let quant_protos = Quantization::new(0.02491161972284317, -117);

    let decoder = DecoderBuilder::default()
        .with_config_yolo_split_segdet(
            configs::Boxes {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 4, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::BoxCoords, 4),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            configs::Scores {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 80, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 80),
                    (DimName::NumBoxes, 8400),
                ],
            },
            configs::MaskCoefficients {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 32, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumProtos, 32),
                    (DimName::NumBoxes, 8400),
                ],
            },
            configs::Protos {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
                shape: vec![1, 160, 160, 32],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                    (DimName::NumProtos, 32),
                ],
            },
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let mut output_boxes: Vec<_> = Vec::with_capacity(500);
    let mut output_masks: Vec<_> = Vec::with_capacity(500);

    // Rows 0..4: boxes, 4..84: class scores, 84..: mask coefficients.
    decoder
        .decode_quantized(
            &[
                boxes.slice(s![.., ..4, ..]).into(),
                boxes.slice(s![.., 4..84, ..]).into(),
                boxes.slice(s![.., 84.., ..]).into(),
                protos.view().into(),
            ],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
    let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
    let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
    let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
    decode_yolo_segdet_float(
        seg.slice(s![0, .., ..]),
        protos.slice(s![0, .., .., ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes_f32,
        &mut output_masks_f32,
    )
    .unwrap();

    let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
    let mut output_masks1: Vec<_> = Vec::with_capacity(500);

    decoder
        .decode_float(
            &[
                seg.slice(s![.., ..4, ..]).into_dyn(),
                seg.slice(s![.., 4..84, ..]).into_dyn(),
                seg.slice(s![.., 84.., ..]).into_dyn(),
                protos.view().into_dyn(),
            ],
            &mut output_boxes1,
            &mut output_masks1,
        )
        .unwrap();
    compare_outputs(
        (&output_boxes, &output_boxes_f32),
        (&output_masks, &output_masks_f32),
    );
    compare_outputs(
        (&output_boxes_f32, &output_boxes1),
        (&output_masks_f32, &output_masks1),
    );
}
2165
/// Split-output segmentation with i32 tensors: the config-driven quantized
/// decoder must agree with the low-level float decoder.
#[test]
fn test_decoder_masks_config_i32() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;
    let boxes_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    ));
    // Widen each int8 value (bytes read as two's complement, no `unsafe`) to
    // int32 and multiply by 2^23. The scale and zero point are adjusted by the
    // same factor. |i8| * 2^23 <= 2^30, so no i32 overflow.
    let scale: i32 = 1 << 23;
    let boxes: Vec<i32> = boxes_raw
        .iter()
        .map(|&b| i32::from(b as i8) * scale)
        .collect();
    let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();

    let quant_boxes = Quantization::new(0.021287761628627777 / scale as f32, 31 * scale);

    let protos_raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    ));
    // Built once and moved into the array (no extra copy).
    let protos: Vec<i32> = protos_raw
        .iter()
        .map(|&b| i32::from(b as i8) * scale)
        .collect();
    let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
    let quant_protos = Quantization::new(0.02491161972284317 / scale as f32, -117 * scale);

    let decoder = DecoderBuilder::default()
        .with_config_yolo_split_segdet(
            configs::Boxes {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 4, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::BoxCoords, 4),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            configs::Scores {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 80, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 80),
                    (DimName::NumBoxes, 8400),
                ],
            },
            configs::MaskCoefficients {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 32, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumProtos, 32),
                    (DimName::NumBoxes, 8400),
                ],
            },
            configs::Protos {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
                shape: vec![1, 160, 160, 32],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                    (DimName::NumProtos, 32),
                ],
            },
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let mut output_boxes: Vec<_> = Vec::with_capacity(500);
    let mut output_masks: Vec<_> = Vec::with_capacity(500);

    // Rows 0..4: boxes, 4..84: class scores, 84..: mask coefficients.
    decoder
        .decode_quantized(
            &[
                boxes.slice(s![.., ..4, ..]).into(),
                boxes.slice(s![.., 4..84, ..]).into(),
                boxes.slice(s![.., 84.., ..]).into(),
                protos.view().into(),
            ],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
    let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
    let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
    let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
    decode_yolo_segdet_float(
        seg.slice(s![0, .., ..]),
        protos.slice(s![0, .., .., ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes_f32,
        &mut output_masks_f32,
    )
    .unwrap();

    assert_eq!(output_boxes.len(), output_boxes_f32.len());
    assert_eq!(output_masks.len(), output_masks_f32.len());

    compare_outputs(
        (&output_boxes, &output_boxes_f32),
        (&output_masks, &output_masks_f32),
    );
}
2280
/// Several threads, each building its own decoder, decode repeatedly and
/// concurrently. Every iteration must keep producing the expected results.
#[test]
fn test_context_switch() {
    // YOLO detection workload. The closure captures nothing, so it is `Copy` and
    // can be handed to `thread::spawn` several times.
    let yolo_det = || {
        let score_threshold = 0.25;
        let iou_threshold = 0.7;
        let raw = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/yolov8s_80_classes.bin"
        ));
        // Reinterpret each byte as int8 (two's complement) without `unsafe`.
        let out: Vec<i8> = raw.iter().map(|&b| b as i8).collect();
        let out = Array3::from_shape_vec((1, 84, 8400), out).unwrap();
        let quant = (0.0040811873, -123).into();

        let decoder = DecoderBuilder::default()
            .with_config_yolo_det(
                configs::Detection {
                    decoder: DecoderType::Ultralytics,
                    shape: vec![1, 84, 8400],
                    anchors: None,
                    quantization: Some(quant),
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumFeatures, 84),
                        (DimName::NumBoxes, 8400),
                    ],
                    normalized: None,
                },
                None,
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
        let mut output_masks: Vec<_> = Vec::with_capacity(50);

        for _ in 0..100 {
            decoder
                .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
                .unwrap();

            // Fail with a clear message instead of an index-out-of-bounds panic.
            assert!(
                output_boxes.len() >= 2,
                "expected at least 2 detections, got {}",
                output_boxes.len()
            );
            assert!(output_boxes[0].equal_within_delta(
                &DetectBox {
                    bbox: BoundingBox {
                        xmin: 0.5285137,
                        ymin: 0.05305544,
                        xmax: 0.87541467,
                        ymax: 0.9998909,
                    },
                    score: 0.5591227,
                    label: 0
                },
                1e-6
            ));

            assert!(output_boxes[1].equal_within_delta(
                &DetectBox {
                    bbox: BoundingBox {
                        xmin: 0.130598,
                        ymin: 0.43260583,
                        xmax: 0.35098213,
                        ymax: 0.9958097,
                    },
                    score: 0.33057618,
                    label: 75
                },
                1e-6
            ));
            assert!(output_masks.is_empty());
        }
    };

    // ModelPack split detection + segmentation workload (uint8 data).
    let modelpack_det_split = || {
        let score_threshold = 0.8;
        let iou_threshold = 0.5;

        let seg = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

        let detect0 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_9x15x18.bin"
        ));
        let detect0 =
            ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

        let detect1 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_17x30x18.bin"
        ));
        let detect1 =
            ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

        // Expected mask: the raw segmentation tensor with axes moved from
        // (classes, H, W) to (H, W, classes), covering the whole frame.
        let mut mask = seg.slice(s![0, .., .., ..]);
        mask.swap_axes(0, 1);
        mask.swap_axes(1, 2);
        let mask = [Segmentation {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.0,
            ymax: 1.0,
            segmentation: mask.into_owned(),
        }];
        let correct_boxes = [DetectBox {
            bbox: BoundingBox {
                xmin: 0.43171933,
                ymin: 0.68243736,
                xmax: 0.5626645,
                ymax: 0.808863,
            },
            score: 0.99240804,
            label: 0,
        }];

        let quant0 = (0.08547406643629074, 174).into();
        let quant1 = (0.09929127991199493, 183).into();
        let quant_seg = (1.0 / 255.0, 0).into();

        let anchors0 = vec![
            [0.36666667461395264, 0.31481480598449707],
            [0.38749998807907104, 0.4740740656852722],
            [0.5333333611488342, 0.644444465637207],
        ];
        let anchors1 = vec![
            [0.13750000298023224, 0.2074074000120163],
            [0.2541666626930237, 0.21481481194496155],
            [0.23125000298023224, 0.35185185074806213],
        ];

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_segdet_split(
                vec![
                    configs::Detection {
                        decoder: DecoderType::ModelPack,
                        shape: vec![1, 17, 30, 18],
                        anchors: Some(anchors1),
                        quantization: Some(quant1),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::Height, 17),
                            (DimName::Width, 30),
                            (DimName::NumAnchorsXFeatures, 18),
                        ],
                        normalized: None,
                    },
                    configs::Detection {
                        decoder: DecoderType::ModelPack,
                        shape: vec![1, 9, 15, 18],
                        anchors: Some(anchors0),
                        quantization: Some(quant0),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::Height, 9),
                            (DimName::Width, 15),
                            (DimName::NumAnchorsXFeatures, 18),
                        ],
                        normalized: None,
                    },
                ],
                configs::Segmentation {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_seg),
                    shape: vec![1, 2, 160, 160],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumClasses, 2),
                        (DimName::Height, 160),
                        (DimName::Width, 160),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);

        for _ in 0..100 {
            decoder
                .decode_quantized(
                    &[
                        detect0.view().into(),
                        detect1.view().into(),
                        seg.view().into(),
                    ],
                    &mut output_boxes,
                    &mut output_masks,
                )
                .unwrap();

            compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
        }
    };

    // Interleave the two workloads across eight threads; `join().unwrap()`
    // propagates any assertion failure from a worker thread.
    let handles = vec![
        std::thread::spawn(yolo_det),
        std::thread::spawn(modelpack_det_split),
        std::thread::spawn(yolo_det),
        std::thread::spawn(modelpack_det_split),
        std::thread::spawn(yolo_det),
        std::thread::spawn(modelpack_det_split),
        std::thread::spawn(yolo_det),
        std::thread::spawn(modelpack_det_split),
    ];
    for handle in handles {
        handle.join().unwrap();
    }
}
2495
/// `ndarray_to_xyxy_float` on the same 4-element row, read as XYWH and as XYXY.
#[test]
fn test_ndarray_to_xyxy_float() {
    let row = array![10.0_f32, 20.0, 20.0, 20.0];

    // Interpreted as centre/size (cx=10, cy=20, w=20, h=20), the corners are (0, 10) and (20, 30).
    let from_xywh: [f32; 4] = XYWH::ndarray_to_xyxy_float(row.view());
    assert_eq!(from_xywh, [0.0_f32, 10.0, 20.0, 30.0]);

    // Interpreted as corners already, the values pass through unchanged.
    let from_xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(row.view());
    assert_eq!(from_xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
}
2506
/// Class-aware float NMS must keep overlapping boxes of different classes and
/// suppress the lower-scoring one when the classes match.
#[test]
fn test_class_aware_nms_float() {
    use crate::float::nms_class_aware_float;

    // The two boxes overlap with IoU = 0.16 / 0.34 ≈ 0.47, above the 0.3 threshold.
    let boxes = vec![
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.0,
                ymin: 0.0,
                xmax: 0.5,
                ymax: 0.5,
            },
            score: 0.9,
            label: 0,
        },
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.1,
                ymin: 0.1,
                xmax: 0.6,
                ymax: 0.6,
            },
            score: 0.8,
            label: 1,
        },
    ];

    // `boxes` is not used again, so it is moved rather than cloned.
    let result = nms_class_aware_float(0.3, boxes);
    assert_eq!(
        result.len(),
        2,
        "Class-aware NMS should keep both boxes with different classes"
    );
    assert!(
        result.iter().any(|b| b.label == 0) && result.iter().any(|b| b.label == 1),
        "Class-aware NMS should keep one box per class"
    );

    let same_class_boxes = vec![
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.0,
                ymin: 0.0,
                xmax: 0.5,
                ymax: 0.5,
            },
            score: 0.9,
            label: 0,
        },
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.1,
                ymin: 0.1,
                xmax: 0.6,
                ymax: 0.6,
            },
            score: 0.8,
            label: 0,
        },
    ];

    let result = nms_class_aware_float(0.3, same_class_boxes);
    assert_eq!(
        result.len(),
        1,
        "Class-aware NMS should suppress overlapping box with same class"
    );
    assert_eq!(result[0].label, 0);
    assert!((result[0].score - 0.9).abs() < 1e-6);
}
2577
2578 #[test]
2579 fn test_class_agnostic_vs_aware_nms() {
2580 use crate::float::{nms_class_aware_float, nms_float};
2581
2582 let boxes = vec![
2584 DetectBox {
2585 bbox: BoundingBox {
2586 xmin: 0.0,
2587 ymin: 0.0,
2588 xmax: 0.5,
2589 ymax: 0.5,
2590 },
2591 score: 0.9,
2592 label: 0,
2593 },
2594 DetectBox {
2595 bbox: BoundingBox {
2596 xmin: 0.1,
2597 ymin: 0.1,
2598 xmax: 0.6,
2599 ymax: 0.6,
2600 },
2601 score: 0.8,
2602 label: 1,
2603 },
2604 ];
2605
2606 let agnostic_result = nms_float(0.3, boxes.clone());
2608 assert_eq!(
2609 agnostic_result.len(),
2610 1,
2611 "Class-agnostic NMS should suppress overlapping boxes"
2612 );
2613
2614 let aware_result = nms_class_aware_float(0.3, boxes);
2616 assert_eq!(
2617 aware_result.len(),
2618 2,
2619 "Class-aware NMS should keep boxes with different classes"
2620 );
2621 }
2622
2623 #[test]
2624 fn test_class_aware_nms_int() {
2625 use crate::byte::nms_class_aware_int;
2626
2627 let boxes = vec![
2629 DetectBoxQuantized {
2630 bbox: BoundingBox {
2631 xmin: 0.0,
2632 ymin: 0.0,
2633 xmax: 0.5,
2634 ymax: 0.5,
2635 },
2636 score: 200_u8,
2637 label: 0,
2638 },
2639 DetectBoxQuantized {
2640 bbox: BoundingBox {
2641 xmin: 0.1,
2642 ymin: 0.1,
2643 xmax: 0.6,
2644 ymax: 0.6,
2645 },
2646 score: 180_u8,
2647 label: 1, },
2649 ];
2650
2651 let result = nms_class_aware_int(0.5, boxes);
2653 assert_eq!(
2654 result.len(),
2655 2,
2656 "Class-aware NMS (int) should keep boxes with different classes"
2657 );
2658 }
2659
2660 #[test]
2661 fn test_nms_enum_default() {
2662 let default_nms: configs::Nms = Default::default();
2664 assert_eq!(default_nms, configs::Nms::ClassAgnostic);
2665 }
2666
2667 #[test]
2668 fn test_decoder_nms_mode() {
2669 let decoder = DecoderBuilder::default()
2671 .with_config_yolo_det(
2672 configs::Detection {
2673 anchors: None,
2674 decoder: DecoderType::Ultralytics,
2675 quantization: None,
2676 shape: vec![1, 84, 8400],
2677 dshape: Vec::new(),
2678 normalized: Some(true),
2679 },
2680 None,
2681 )
2682 .with_nms(Some(configs::Nms::ClassAware))
2683 .build()
2684 .unwrap();
2685
2686 assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
2687 }
2688
2689 #[test]
2690 fn test_decoder_nms_bypass() {
2691 let decoder = DecoderBuilder::default()
2693 .with_config_yolo_det(
2694 configs::Detection {
2695 anchors: None,
2696 decoder: DecoderType::Ultralytics,
2697 quantization: None,
2698 shape: vec![1, 84, 8400],
2699 dshape: Vec::new(),
2700 normalized: Some(true),
2701 },
2702 None,
2703 )
2704 .with_nms(None)
2705 .build()
2706 .unwrap();
2707
2708 assert_eq!(decoder.nms, None);
2709 }
2710
2711 #[test]
2712 fn test_decoder_normalized_boxes_true() {
2713 let decoder = DecoderBuilder::default()
2715 .with_config_yolo_det(
2716 configs::Detection {
2717 anchors: None,
2718 decoder: DecoderType::Ultralytics,
2719 quantization: None,
2720 shape: vec![1, 84, 8400],
2721 dshape: Vec::new(),
2722 normalized: Some(true),
2723 },
2724 None,
2725 )
2726 .build()
2727 .unwrap();
2728
2729 assert_eq!(decoder.normalized_boxes(), Some(true));
2730 }
2731
2732 #[test]
2733 fn test_decoder_normalized_boxes_false() {
2734 let decoder = DecoderBuilder::default()
2737 .with_config_yolo_det(
2738 configs::Detection {
2739 anchors: None,
2740 decoder: DecoderType::Ultralytics,
2741 quantization: None,
2742 shape: vec![1, 84, 8400],
2743 dshape: Vec::new(),
2744 normalized: Some(false),
2745 },
2746 None,
2747 )
2748 .build()
2749 .unwrap();
2750
2751 assert_eq!(decoder.normalized_boxes(), Some(false));
2752 }
2753
2754 #[test]
2755 fn test_decoder_normalized_boxes_unknown() {
2756 let decoder = DecoderBuilder::default()
2758 .with_config_yolo_det(
2759 configs::Detection {
2760 anchors: None,
2761 decoder: DecoderType::Ultralytics,
2762 quantization: None,
2763 shape: vec![1, 84, 8400],
2764 dshape: Vec::new(),
2765 normalized: None,
2766 },
2767 Some(DecoderVersion::Yolo11),
2768 )
2769 .build()
2770 .unwrap();
2771
2772 assert_eq!(decoder.normalized_boxes(), None);
2773 }
2774}
2775
2776#[cfg(feature = "tracker")]
2777#[cfg(test)]
2778#[cfg_attr(coverage_nightly, coverage(off))]
2779mod decoder_tracked_tests {
2780
2781 use edgefirst_tracker::{ByteTrackBuilder, Tracker};
2782 use ndarray::{array, s, Array, Array2, Array3, Array4, ArrayView, Axis, Dimension};
2783 use num_traits::{AsPrimitive, Float, PrimInt};
2784 use rand::{RngExt, SeedableRng};
2785 use rand_distr::StandardNormal;
2786
2787 use crate::{
2788 configs::{self, DimName},
2789 dequantize_ndarray, BoundingBox, DecoderBuilder, DetectBox, Quantization,
2790 };
2791
2792 pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
2793 input: ArrayView<F, D>,
2794 quant: Quantization,
2795 ) -> Array<T, D>
2796 where
2797 i32: num_traits::AsPrimitive<F>,
2798 f32: num_traits::AsPrimitive<F>,
2799 {
2800 let zero_point = quant.zero_point.as_();
2801 let div_scale = F::one() / quant.scale.as_();
2802 if zero_point != F::zero() {
2803 input.mapv(|d| (d * div_scale + zero_point).round().as_())
2804 } else {
2805 input.mapv(|d| (d * div_scale).round().as_())
2806 }
2807 }
2808
2809 #[test]
2810 fn test_decoder_tracked_random_jitter() {
2811 use crate::configs::{DecoderType, Nms};
2812 use crate::DecoderBuilder;
2813
2814 let score_threshold = 0.25;
2815 let iou_threshold = 0.1;
2816 let out = include_bytes!(concat!(
2817 env!("CARGO_MANIFEST_DIR"),
2818 "/../../testdata/yolov8s_80_classes.bin"
2819 ));
2820 let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
2821 let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
2822 let quant = (0.0040811873, -123).into();
2823
2824 let decoder = DecoderBuilder::default()
2825 .with_config_yolo_det(
2826 crate::configs::Detection {
2827 decoder: DecoderType::Ultralytics,
2828 shape: vec![1, 84, 8400],
2829 anchors: None,
2830 quantization: Some(quant),
2831 dshape: vec![
2832 (crate::configs::DimName::Batch, 1),
2833 (crate::configs::DimName::NumFeatures, 84),
2834 (crate::configs::DimName::NumBoxes, 8400),
2835 ],
2836 normalized: Some(true),
2837 },
2838 None,
2839 )
2840 .with_score_threshold(score_threshold)
2841 .with_iou_threshold(iou_threshold)
2842 .with_nms(Some(Nms::ClassAgnostic))
2843 .build()
2844 .unwrap();
2845 let mut rng = rand::rngs::StdRng::seed_from_u64(0xAB_BEEF); let expected_boxes = [
2848 crate::DetectBox {
2849 bbox: crate::BoundingBox {
2850 xmin: 0.5285137,
2851 ymin: 0.05305544,
2852 xmax: 0.87541467,
2853 ymax: 0.9998909,
2854 },
2855 score: 0.5591227,
2856 label: 0,
2857 },
2858 crate::DetectBox {
2859 bbox: crate::BoundingBox {
2860 xmin: 0.130598,
2861 ymin: 0.43260583,
2862 xmax: 0.35098213,
2863 ymax: 0.9958097,
2864 },
2865 score: 0.33057618,
2866 label: 75,
2867 },
2868 ];
2869
2870 let mut tracker = ByteTrackBuilder::new()
2871 .track_update(0.1)
2872 .track_high_conf(0.3)
2873 .build();
2874
2875 let mut output_boxes = Vec::with_capacity(50);
2876 let mut output_masks = Vec::with_capacity(50);
2877 let mut output_tracks = Vec::with_capacity(50);
2878
2879 decoder
2880 .decode_tracked_quantized(
2881 &mut tracker,
2882 0,
2883 &[out.view().into()],
2884 &mut output_boxes,
2885 &mut output_masks,
2886 &mut output_tracks,
2887 )
2888 .unwrap();
2889
2890 assert_eq!(output_boxes.len(), 2);
2891 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
2892 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
2893
2894 let mut last_boxes = output_boxes.clone();
2895
2896 for i in 1..=100 {
2897 let mut out = out.clone();
2898 let mut x_values = out.slice_mut(s![0, 0, ..]);
2900 for x in x_values.iter_mut() {
2901 let r: f32 = rng.sample(StandardNormal);
2902 let r = r.clamp(-2.0, 2.0) / 2.0;
2903 *x = x.saturating_add((r * 1e-2 / quant.0) as i8);
2904 }
2905
2906 let mut y_values = out.slice_mut(s![0, 1, ..]);
2907 for y in y_values.iter_mut() {
2908 let r: f32 = rng.sample(StandardNormal);
2909 let r = r.clamp(-2.0, 2.0) / 2.0;
2910 *y = y.saturating_add((r * 1e-2 / quant.0) as i8);
2911 }
2912
2913 decoder
2914 .decode_tracked_quantized(
2915 &mut tracker,
2916 100_000_000 * i / 3, &[out.view().into()],
2918 &mut output_boxes,
2919 &mut output_masks,
2920 &mut output_tracks,
2921 )
2922 .unwrap();
2923
2924 assert_eq!(output_boxes.len(), 2);
2925 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 5e-3));
2926 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 5e-3));
2927
2928 assert!(output_boxes[0].equal_within_delta(&last_boxes[0], 1e-3));
2929 assert!(output_boxes[1].equal_within_delta(&last_boxes[1], 1e-3));
2930 last_boxes = output_boxes.clone();
2931 }
2932 }
2933
2934 #[test]
2935 fn test_decoder_tracked_segdet() {
2936 use crate::configs::Nms;
2937 use crate::DecoderBuilder;
2938
2939 let score_threshold = 0.45;
2940 let iou_threshold = 0.45;
2941 let boxes = include_bytes!(concat!(
2942 env!("CARGO_MANIFEST_DIR"),
2943 "/../../testdata/yolov8_boxes_116x8400.bin"
2944 ));
2945 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
2946 let mut boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
2947
2948 let protos = include_bytes!(concat!(
2949 env!("CARGO_MANIFEST_DIR"),
2950 "/../../testdata/yolov8_protos_160x160x32.bin"
2951 ));
2952 let protos =
2953 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
2954 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
2955
2956 let config = include_str!(concat!(
2957 env!("CARGO_MANIFEST_DIR"),
2958 "/../../testdata/yolov8_seg.yaml"
2959 ));
2960
2961 let decoder = DecoderBuilder::default()
2962 .with_config_yaml_str(config.to_string())
2963 .with_score_threshold(score_threshold)
2964 .with_iou_threshold(iou_threshold)
2965 .with_nms(Some(Nms::ClassAgnostic))
2966 .build()
2967 .unwrap();
2968
2969 let expected_boxes = [
2970 DetectBox {
2971 bbox: BoundingBox {
2972 xmin: 0.08515105,
2973 ymin: 0.7131401,
2974 xmax: 0.29802868,
2975 ymax: 0.8195788,
2976 },
2977 score: 0.91537374,
2978 label: 23,
2979 },
2980 DetectBox {
2981 bbox: BoundingBox {
2982 xmin: 0.59605736,
2983 ymin: 0.25545314,
2984 xmax: 0.93666154,
2985 ymax: 0.72378385,
2986 },
2987 score: 0.91537374,
2988 label: 23,
2989 },
2990 ];
2991
2992 let mut tracker = ByteTrackBuilder::new()
2993 .track_update(0.1)
2994 .track_high_conf(0.7)
2995 .build();
2996
2997 let mut output_boxes = Vec::with_capacity(50);
2998 let mut output_masks = Vec::with_capacity(50);
2999 let mut output_tracks = Vec::with_capacity(50);
3000
3001 decoder
3002 .decode_tracked_quantized(
3003 &mut tracker,
3004 0,
3005 &[boxes.view().into(), protos.view().into()],
3006 &mut output_boxes,
3007 &mut output_masks,
3008 &mut output_tracks,
3009 )
3010 .unwrap();
3011
3012 assert_eq!(output_boxes.len(), 2);
3013 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3014 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3015
3016 let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3018 for score in scores_values.iter_mut() {
3019 *score = i8::MIN; }
3021 decoder
3022 .decode_tracked_quantized(
3023 &mut tracker,
3024 100_000_000 / 3,
3025 &[boxes.view().into(), protos.view().into()],
3026 &mut output_boxes,
3027 &mut output_masks,
3028 &mut output_tracks,
3029 )
3030 .unwrap();
3031
3032 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3033 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3034
3035 assert!(output_masks.is_empty())
3037 }
3038
3039 #[test]
3040 fn test_decoder_tracked_segdet_float() {
3041 use crate::configs::Nms;
3042 use crate::DecoderBuilder;
3043
3044 let score_threshold = 0.45;
3045 let iou_threshold = 0.45;
3046 let boxes = include_bytes!(concat!(
3047 env!("CARGO_MANIFEST_DIR"),
3048 "/../../testdata/yolov8_boxes_116x8400.bin"
3049 ));
3050 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3051 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
3052 let quant_boxes = (0.021287762, 31);
3053 let mut boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
3054
3055 let protos = include_bytes!(concat!(
3056 env!("CARGO_MANIFEST_DIR"),
3057 "/../../testdata/yolov8_protos_160x160x32.bin"
3058 ));
3059 let protos =
3060 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3061 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
3062 let quant_protos = (0.02491162, -117);
3063 let protos = dequantize_ndarray(protos.view(), quant_protos.into());
3064
3065 let config = include_str!(concat!(
3066 env!("CARGO_MANIFEST_DIR"),
3067 "/../../testdata/yolov8_seg.yaml"
3068 ));
3069
3070 let decoder = DecoderBuilder::default()
3071 .with_config_yaml_str(config.to_string())
3072 .with_score_threshold(score_threshold)
3073 .with_iou_threshold(iou_threshold)
3074 .with_nms(Some(Nms::ClassAgnostic))
3075 .build()
3076 .unwrap();
3077
3078 let expected_boxes = [
3079 DetectBox {
3080 bbox: BoundingBox {
3081 xmin: 0.08515105,
3082 ymin: 0.7131401,
3083 xmax: 0.29802868,
3084 ymax: 0.8195788,
3085 },
3086 score: 0.91537374,
3087 label: 23,
3088 },
3089 DetectBox {
3090 bbox: BoundingBox {
3091 xmin: 0.59605736,
3092 ymin: 0.25545314,
3093 xmax: 0.93666154,
3094 ymax: 0.72378385,
3095 },
3096 score: 0.91537374,
3097 label: 23,
3098 },
3099 ];
3100
3101 let mut tracker = ByteTrackBuilder::new()
3102 .track_update(0.1)
3103 .track_high_conf(0.7)
3104 .build();
3105
3106 let mut output_boxes = Vec::with_capacity(50);
3107 let mut output_masks = Vec::with_capacity(50);
3108 let mut output_tracks = Vec::with_capacity(50);
3109
3110 decoder
3111 .decode_tracked_float(
3112 &mut tracker,
3113 0,
3114 &[boxes.view().into_dyn(), protos.view().into_dyn()],
3115 &mut output_boxes,
3116 &mut output_masks,
3117 &mut output_tracks,
3118 )
3119 .unwrap();
3120
3121 assert_eq!(output_boxes.len(), 2);
3122 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3123 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3124
3125 let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3127 for score in scores_values.iter_mut() {
3128 *score = 0.0; }
3130 decoder
3131 .decode_tracked_float(
3132 &mut tracker,
3133 100_000_000 / 3,
3134 &[boxes.view().into_dyn(), protos.view().into_dyn()],
3135 &mut output_boxes,
3136 &mut output_masks,
3137 &mut output_tracks,
3138 )
3139 .unwrap();
3140
3141 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3142 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3143
3144 assert!(output_masks.is_empty())
3146 }
3147
3148 #[test]
3149 fn test_decoder_tracked_segdet_proto() {
3150 use crate::configs::Nms;
3151 use crate::DecoderBuilder;
3152
3153 let score_threshold = 0.45;
3154 let iou_threshold = 0.45;
3155 let boxes = include_bytes!(concat!(
3156 env!("CARGO_MANIFEST_DIR"),
3157 "/../../testdata/yolov8_boxes_116x8400.bin"
3158 ));
3159 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3160 let mut boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
3161
3162 let protos = include_bytes!(concat!(
3163 env!("CARGO_MANIFEST_DIR"),
3164 "/../../testdata/yolov8_protos_160x160x32.bin"
3165 ));
3166 let protos =
3167 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3168 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
3169
3170 let config = include_str!(concat!(
3171 env!("CARGO_MANIFEST_DIR"),
3172 "/../../testdata/yolov8_seg.yaml"
3173 ));
3174
3175 let decoder = DecoderBuilder::default()
3176 .with_config_yaml_str(config.to_string())
3177 .with_score_threshold(score_threshold)
3178 .with_iou_threshold(iou_threshold)
3179 .with_nms(Some(Nms::ClassAgnostic))
3180 .build()
3181 .unwrap();
3182
3183 let expected_boxes = [
3184 DetectBox {
3185 bbox: BoundingBox {
3186 xmin: 0.08515105,
3187 ymin: 0.7131401,
3188 xmax: 0.29802868,
3189 ymax: 0.8195788,
3190 },
3191 score: 0.91537374,
3192 label: 23,
3193 },
3194 DetectBox {
3195 bbox: BoundingBox {
3196 xmin: 0.59605736,
3197 ymin: 0.25545314,
3198 xmax: 0.93666154,
3199 ymax: 0.72378385,
3200 },
3201 score: 0.91537374,
3202 label: 23,
3203 },
3204 ];
3205
3206 let mut tracker = ByteTrackBuilder::new()
3207 .track_update(0.1)
3208 .track_high_conf(0.7)
3209 .build();
3210
3211 let mut output_boxes = Vec::with_capacity(50);
3212 let mut output_tracks = Vec::with_capacity(50);
3213
3214 decoder
3215 .decode_tracked_quantized_proto(
3216 &mut tracker,
3217 0,
3218 &[boxes.view().into(), protos.view().into()],
3219 &mut output_boxes,
3220 &mut output_tracks,
3221 )
3222 .unwrap();
3223
3224 assert_eq!(output_boxes.len(), 2);
3225 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3226 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3227
3228 let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3230 for score in scores_values.iter_mut() {
3231 *score = i8::MIN; }
3233 let protos = decoder
3234 .decode_tracked_quantized_proto(
3235 &mut tracker,
3236 100_000_000 / 3,
3237 &[boxes.view().into(), protos.view().into()],
3238 &mut output_boxes,
3239 &mut output_tracks,
3240 )
3241 .unwrap();
3242
3243 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3244 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3245
3246 assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
3248 }
3249
3250 #[test]
3251 fn test_decoder_tracked_segdet_proto_float() {
3252 use crate::configs::Nms;
3253 use crate::DecoderBuilder;
3254
3255 let score_threshold = 0.45;
3256 let iou_threshold = 0.45;
3257 let boxes = include_bytes!(concat!(
3258 env!("CARGO_MANIFEST_DIR"),
3259 "/../../testdata/yolov8_boxes_116x8400.bin"
3260 ));
3261 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3262 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
3263 let quant_boxes = (0.021287762, 31);
3264 let mut boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
3265
3266 let protos = include_bytes!(concat!(
3267 env!("CARGO_MANIFEST_DIR"),
3268 "/../../testdata/yolov8_protos_160x160x32.bin"
3269 ));
3270 let protos =
3271 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3272 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
3273 let quant_protos = (0.02491162, -117);
3274 let protos = dequantize_ndarray(protos.view(), quant_protos.into());
3275
3276 let config = include_str!(concat!(
3277 env!("CARGO_MANIFEST_DIR"),
3278 "/../../testdata/yolov8_seg.yaml"
3279 ));
3280
3281 let decoder = DecoderBuilder::default()
3282 .with_config_yaml_str(config.to_string())
3283 .with_score_threshold(score_threshold)
3284 .with_iou_threshold(iou_threshold)
3285 .with_nms(Some(Nms::ClassAgnostic))
3286 .build()
3287 .unwrap();
3288
3289 let expected_boxes = [
3290 DetectBox {
3291 bbox: BoundingBox {
3292 xmin: 0.08515105,
3293 ymin: 0.7131401,
3294 xmax: 0.29802868,
3295 ymax: 0.8195788,
3296 },
3297 score: 0.91537374,
3298 label: 23,
3299 },
3300 DetectBox {
3301 bbox: BoundingBox {
3302 xmin: 0.59605736,
3303 ymin: 0.25545314,
3304 xmax: 0.93666154,
3305 ymax: 0.72378385,
3306 },
3307 score: 0.91537374,
3308 label: 23,
3309 },
3310 ];
3311
3312 let mut tracker = ByteTrackBuilder::new()
3313 .track_update(0.1)
3314 .track_high_conf(0.7)
3315 .build();
3316
3317 let mut output_boxes = Vec::with_capacity(50);
3318 let mut output_tracks = Vec::with_capacity(50);
3319
3320 decoder
3321 .decode_tracked_float_proto(
3322 &mut tracker,
3323 0,
3324 &[boxes.view().into_dyn(), protos.view().into_dyn()],
3325 &mut output_boxes,
3326 &mut output_tracks,
3327 )
3328 .unwrap();
3329
3330 assert_eq!(output_boxes.len(), 2);
3331 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3332 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3333
3334 let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3336 for score in scores_values.iter_mut() {
3337 *score = 0.0; }
3339 let protos = decoder
3340 .decode_tracked_float_proto(
3341 &mut tracker,
3342 100_000_000 / 3,
3343 &[boxes.view().into_dyn(), protos.view().into_dyn()],
3344 &mut output_boxes,
3345 &mut output_tracks,
3346 )
3347 .unwrap();
3348
3349 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3350 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3351
3352 assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
3354 }
3355
3356 #[test]
3357 fn test_decoder_tracked_segdet_split() {
3358 let score_threshold = 0.45;
3359 let iou_threshold = 0.45;
3360
3361 let boxes = include_bytes!(concat!(
3362 env!("CARGO_MANIFEST_DIR"),
3363 "/../../testdata/yolov8_boxes_116x8400.bin"
3364 ));
3365 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3366 let boxes = boxes.to_vec();
3367 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
3368
3369 let mask = boxes.slice(s![.., 84.., ..]).to_owned();
3370 let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
3371 let boxes = boxes.slice(s![.., ..4, ..]).to_owned();
3372
3373 let quant_boxes = (0.021287762, 31);
3374
3375 let protos = include_bytes!(concat!(
3376 env!("CARGO_MANIFEST_DIR"),
3377 "/../../testdata/yolov8_protos_160x160x32.bin"
3378 ));
3379 let protos =
3380 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3381 let protos = protos.to_vec();
3382 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
3383 let quant_protos = (0.02491162, -117);
3384 let decoder = DecoderBuilder::default()
3385 .with_config_yolo_split_segdet(
3386 configs::Boxes {
3387 decoder: configs::DecoderType::Ultralytics,
3388 quantization: Some(quant_boxes.into()),
3389 shape: vec![1, 4, 8400],
3390 dshape: vec![
3391 (DimName::Batch, 1),
3392 (DimName::BoxCoords, 4),
3393 (DimName::NumBoxes, 8400),
3394 ],
3395 normalized: Some(true),
3396 },
3397 configs::Scores {
3398 decoder: configs::DecoderType::Ultralytics,
3399 quantization: Some(quant_boxes.into()),
3400 shape: vec![1, 80, 8400],
3401 dshape: vec![
3402 (DimName::Batch, 1),
3403 (DimName::NumClasses, 80),
3404 (DimName::NumBoxes, 8400),
3405 ],
3406 },
3407 configs::MaskCoefficients {
3408 decoder: configs::DecoderType::Ultralytics,
3409 quantization: Some(quant_boxes.into()),
3410 shape: vec![1, 32, 8400],
3411 dshape: vec![
3412 (DimName::Batch, 1),
3413 (DimName::NumProtos, 32),
3414 (DimName::NumBoxes, 8400),
3415 ],
3416 },
3417 configs::Protos {
3418 decoder: configs::DecoderType::Ultralytics,
3419 quantization: Some(quant_protos.into()),
3420 shape: vec![1, 160, 160, 32],
3421 dshape: vec![
3422 (DimName::Batch, 1),
3423 (DimName::Height, 160),
3424 (DimName::Width, 160),
3425 (DimName::NumProtos, 32),
3426 ],
3427 },
3428 )
3429 .with_score_threshold(score_threshold)
3430 .with_iou_threshold(iou_threshold)
3431 .build()
3432 .unwrap();
3433
3434 let expected_boxes = [
3435 DetectBox {
3436 bbox: BoundingBox {
3437 xmin: 0.08515105,
3438 ymin: 0.7131401,
3439 xmax: 0.29802868,
3440 ymax: 0.8195788,
3441 },
3442 score: 0.91537374,
3443 label: 23,
3444 },
3445 DetectBox {
3446 bbox: BoundingBox {
3447 xmin: 0.59605736,
3448 ymin: 0.25545314,
3449 xmax: 0.93666154,
3450 ymax: 0.72378385,
3451 },
3452 score: 0.91537374,
3453 label: 23,
3454 },
3455 ];
3456
3457 let mut tracker = ByteTrackBuilder::new()
3458 .track_update(0.1)
3459 .track_high_conf(0.7)
3460 .build();
3461
3462 let mut output_boxes = Vec::with_capacity(50);
3463 let mut output_masks = Vec::with_capacity(50);
3464 let mut output_tracks = Vec::with_capacity(50);
3465
3466 decoder
3467 .decode_tracked_quantized(
3468 &mut tracker,
3469 0,
3470 &[
3471 boxes.view().into(),
3472 scores.view().into(),
3473 mask.view().into(),
3474 protos.view().into(),
3475 ],
3476 &mut output_boxes,
3477 &mut output_masks,
3478 &mut output_tracks,
3479 )
3480 .unwrap();
3481
3482 assert_eq!(output_boxes.len(), 2);
3483 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3484 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3485
3486 for score in scores.iter_mut() {
3489 *score = i8::MIN; }
3491 decoder
3492 .decode_tracked_quantized(
3493 &mut tracker,
3494 100_000_000 / 3,
3495 &[
3496 boxes.view().into(),
3497 scores.view().into(),
3498 mask.view().into(),
3499 protos.view().into(),
3500 ],
3501 &mut output_boxes,
3502 &mut output_masks,
3503 &mut output_tracks,
3504 )
3505 .unwrap();
3506
3507 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3508 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3509
3510 assert!(output_masks.is_empty())
3512 }
3513
3514 #[test]
3515 fn test_decoder_tracked_segdet_split_float() {
3516 let score_threshold = 0.45;
3517 let iou_threshold = 0.45;
3518
3519 let boxes = include_bytes!(concat!(
3520 env!("CARGO_MANIFEST_DIR"),
3521 "/../../testdata/yolov8_boxes_116x8400.bin"
3522 ));
3523 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3524 let boxes = boxes.to_vec();
3525 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
3526 let quant_boxes = (0.021287762, 31);
3527 let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
3528
3529 let mask = boxes.slice(s![.., 84.., ..]).to_owned();
3530 let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
3531 let boxes = boxes.slice(s![.., ..4, ..]).to_owned();
3532
3533 let protos = include_bytes!(concat!(
3534 env!("CARGO_MANIFEST_DIR"),
3535 "/../../testdata/yolov8_protos_160x160x32.bin"
3536 ));
3537 let protos =
3538 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3539 let protos = protos.to_vec();
3540 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
3541 let quant_protos = (0.02491162, -117);
3542 let protos = dequantize_ndarray(protos.view(), quant_protos.into());
3543
3544 let decoder = DecoderBuilder::default()
3545 .with_config_yolo_split_segdet(
3546 configs::Boxes {
3547 decoder: configs::DecoderType::Ultralytics,
3548 quantization: Some(quant_boxes.into()),
3549 shape: vec![1, 4, 8400],
3550 dshape: vec![
3551 (DimName::Batch, 1),
3552 (DimName::BoxCoords, 4),
3553 (DimName::NumBoxes, 8400),
3554 ],
3555 normalized: Some(true),
3556 },
3557 configs::Scores {
3558 decoder: configs::DecoderType::Ultralytics,
3559 quantization: Some(quant_boxes.into()),
3560 shape: vec![1, 80, 8400],
3561 dshape: vec![
3562 (DimName::Batch, 1),
3563 (DimName::NumClasses, 80),
3564 (DimName::NumBoxes, 8400),
3565 ],
3566 },
3567 configs::MaskCoefficients {
3568 decoder: configs::DecoderType::Ultralytics,
3569 quantization: Some(quant_boxes.into()),
3570 shape: vec![1, 32, 8400],
3571 dshape: vec![
3572 (DimName::Batch, 1),
3573 (DimName::NumProtos, 32),
3574 (DimName::NumBoxes, 8400),
3575 ],
3576 },
3577 configs::Protos {
3578 decoder: configs::DecoderType::Ultralytics,
3579 quantization: Some(quant_protos.into()),
3580 shape: vec![1, 160, 160, 32],
3581 dshape: vec![
3582 (DimName::Batch, 1),
3583 (DimName::Height, 160),
3584 (DimName::Width, 160),
3585 (DimName::NumProtos, 32),
3586 ],
3587 },
3588 )
3589 .with_score_threshold(score_threshold)
3590 .with_iou_threshold(iou_threshold)
3591 .build()
3592 .unwrap();
3593
3594 let expected_boxes = [
3595 DetectBox {
3596 bbox: BoundingBox {
3597 xmin: 0.08515105,
3598 ymin: 0.7131401,
3599 xmax: 0.29802868,
3600 ymax: 0.8195788,
3601 },
3602 score: 0.91537374,
3603 label: 23,
3604 },
3605 DetectBox {
3606 bbox: BoundingBox {
3607 xmin: 0.59605736,
3608 ymin: 0.25545314,
3609 xmax: 0.93666154,
3610 ymax: 0.72378385,
3611 },
3612 score: 0.91537374,
3613 label: 23,
3614 },
3615 ];
3616
3617 let mut tracker = ByteTrackBuilder::new()
3618 .track_update(0.1)
3619 .track_high_conf(0.7)
3620 .build();
3621
3622 let mut output_boxes = Vec::with_capacity(50);
3623 let mut output_masks = Vec::with_capacity(50);
3624 let mut output_tracks = Vec::with_capacity(50);
3625
3626 decoder
3627 .decode_tracked_float(
3628 &mut tracker,
3629 0,
3630 &[
3631 boxes.view().into_dyn(),
3632 scores.view().into_dyn(),
3633 mask.view().into_dyn(),
3634 protos.view().into_dyn(),
3635 ],
3636 &mut output_boxes,
3637 &mut output_masks,
3638 &mut output_tracks,
3639 )
3640 .unwrap();
3641
3642 assert_eq!(output_boxes.len(), 2);
3643 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3644 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3645
3646 for score in scores.iter_mut() {
3649 *score = 0.0; }
3651 decoder
3652 .decode_tracked_float(
3653 &mut tracker,
3654 100_000_000 / 3,
3655 &[
3656 boxes.view().into_dyn(),
3657 scores.view().into_dyn(),
3658 mask.view().into_dyn(),
3659 protos.view().into_dyn(),
3660 ],
3661 &mut output_boxes,
3662 &mut output_masks,
3663 &mut output_tracks,
3664 )
3665 .unwrap();
3666
3667 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3668 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3669
3670 assert!(output_masks.is_empty())
3672 }
3673
3674 #[test]
3675 fn test_decoder_tracked_segdet_split_proto() {
3676 let score_threshold = 0.45;
3677 let iou_threshold = 0.45;
3678
3679 let boxes = include_bytes!(concat!(
3680 env!("CARGO_MANIFEST_DIR"),
3681 "/../../testdata/yolov8_boxes_116x8400.bin"
3682 ));
3683 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3684 let boxes = boxes.to_vec();
3685 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
3686
3687 let mask = boxes.slice(s![.., 84.., ..]).to_owned();
3688 let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
3689 let boxes = boxes.slice(s![.., ..4, ..]).to_owned();
3690
3691 let quant_boxes = (0.021287762, 31);
3692
3693 let protos = include_bytes!(concat!(
3694 env!("CARGO_MANIFEST_DIR"),
3695 "/../../testdata/yolov8_protos_160x160x32.bin"
3696 ));
3697 let protos =
3698 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3699 let protos = protos.to_vec();
3700 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
3701 let quant_protos = (0.02491162, -117);
3702 let decoder = DecoderBuilder::default()
3703 .with_config_yolo_split_segdet(
3704 configs::Boxes {
3705 decoder: configs::DecoderType::Ultralytics,
3706 quantization: Some(quant_boxes.into()),
3707 shape: vec![1, 4, 8400],
3708 dshape: vec![
3709 (DimName::Batch, 1),
3710 (DimName::BoxCoords, 4),
3711 (DimName::NumBoxes, 8400),
3712 ],
3713 normalized: Some(true),
3714 },
3715 configs::Scores {
3716 decoder: configs::DecoderType::Ultralytics,
3717 quantization: Some(quant_boxes.into()),
3718 shape: vec![1, 80, 8400],
3719 dshape: vec![
3720 (DimName::Batch, 1),
3721 (DimName::NumClasses, 80),
3722 (DimName::NumBoxes, 8400),
3723 ],
3724 },
3725 configs::MaskCoefficients {
3726 decoder: configs::DecoderType::Ultralytics,
3727 quantization: Some(quant_boxes.into()),
3728 shape: vec![1, 32, 8400],
3729 dshape: vec![
3730 (DimName::Batch, 1),
3731 (DimName::NumProtos, 32),
3732 (DimName::NumBoxes, 8400),
3733 ],
3734 },
3735 configs::Protos {
3736 decoder: configs::DecoderType::Ultralytics,
3737 quantization: Some(quant_protos.into()),
3738 shape: vec![1, 160, 160, 32],
3739 dshape: vec![
3740 (DimName::Batch, 1),
3741 (DimName::Height, 160),
3742 (DimName::Width, 160),
3743 (DimName::NumProtos, 32),
3744 ],
3745 },
3746 )
3747 .with_score_threshold(score_threshold)
3748 .with_iou_threshold(iou_threshold)
3749 .build()
3750 .unwrap();
3751
3752 let expected_boxes = [
3753 DetectBox {
3754 bbox: BoundingBox {
3755 xmin: 0.08515105,
3756 ymin: 0.7131401,
3757 xmax: 0.29802868,
3758 ymax: 0.8195788,
3759 },
3760 score: 0.91537374,
3761 label: 23,
3762 },
3763 DetectBox {
3764 bbox: BoundingBox {
3765 xmin: 0.59605736,
3766 ymin: 0.25545314,
3767 xmax: 0.93666154,
3768 ymax: 0.72378385,
3769 },
3770 score: 0.91537374,
3771 label: 23,
3772 },
3773 ];
3774
3775 let mut tracker = ByteTrackBuilder::new()
3776 .track_update(0.1)
3777 .track_high_conf(0.7)
3778 .build();
3779
3780 let mut output_boxes = Vec::with_capacity(50);
3781 let mut output_tracks = Vec::with_capacity(50);
3782
3783 decoder
3784 .decode_tracked_quantized_proto(
3785 &mut tracker,
3786 0,
3787 &[
3788 boxes.view().into(),
3789 scores.view().into(),
3790 mask.view().into(),
3791 protos.view().into(),
3792 ],
3793 &mut output_boxes,
3794 &mut output_tracks,
3795 )
3796 .unwrap();
3797
3798 assert_eq!(output_boxes.len(), 2);
3799 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3800 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3801
3802 for score in scores.iter_mut() {
3805 *score = i8::MIN; }
3807 let protos = decoder
3808 .decode_tracked_quantized_proto(
3809 &mut tracker,
3810 100_000_000 / 3,
3811 &[
3812 boxes.view().into(),
3813 scores.view().into(),
3814 mask.view().into(),
3815 protos.view().into(),
3816 ],
3817 &mut output_boxes,
3818 &mut output_tracks,
3819 )
3820 .unwrap();
3821
3822 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3823 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3824
3825 assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
3827 }
3828
3829 #[test]
3830 fn test_decoder_tracked_segdet_split_proto_float() {
3831 let score_threshold = 0.45;
3832 let iou_threshold = 0.45;
3833
3834 let boxes = include_bytes!(concat!(
3835 env!("CARGO_MANIFEST_DIR"),
3836 "/../../testdata/yolov8_boxes_116x8400.bin"
3837 ));
3838 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3839 let boxes = boxes.to_vec();
3840 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
3841 let quant_boxes = (0.021287762, 31);
3842 let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
3843
3844 let mask = boxes.slice(s![.., 84.., ..]).to_owned();
3845 let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
3846 let boxes = boxes.slice(s![.., ..4, ..]).to_owned();
3847
3848 let protos = include_bytes!(concat!(
3849 env!("CARGO_MANIFEST_DIR"),
3850 "/../../testdata/yolov8_protos_160x160x32.bin"
3851 ));
3852 let protos =
3853 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3854 let protos = protos.to_vec();
3855 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
3856 let quant_protos = (0.02491162, -117);
3857 let protos = dequantize_ndarray(protos.view(), quant_protos.into());
3858
3859 let decoder = DecoderBuilder::default()
3860 .with_config_yolo_split_segdet(
3861 configs::Boxes {
3862 decoder: configs::DecoderType::Ultralytics,
3863 quantization: Some(quant_boxes.into()),
3864 shape: vec![1, 4, 8400],
3865 dshape: vec![
3866 (DimName::Batch, 1),
3867 (DimName::BoxCoords, 4),
3868 (DimName::NumBoxes, 8400),
3869 ],
3870 normalized: Some(true),
3871 },
3872 configs::Scores {
3873 decoder: configs::DecoderType::Ultralytics,
3874 quantization: Some(quant_boxes.into()),
3875 shape: vec![1, 80, 8400],
3876 dshape: vec![
3877 (DimName::Batch, 1),
3878 (DimName::NumClasses, 80),
3879 (DimName::NumBoxes, 8400),
3880 ],
3881 },
3882 configs::MaskCoefficients {
3883 decoder: configs::DecoderType::Ultralytics,
3884 quantization: Some(quant_boxes.into()),
3885 shape: vec![1, 32, 8400],
3886 dshape: vec![
3887 (DimName::Batch, 1),
3888 (DimName::NumProtos, 32),
3889 (DimName::NumBoxes, 8400),
3890 ],
3891 },
3892 configs::Protos {
3893 decoder: configs::DecoderType::Ultralytics,
3894 quantization: Some(quant_protos.into()),
3895 shape: vec![1, 160, 160, 32],
3896 dshape: vec![
3897 (DimName::Batch, 1),
3898 (DimName::Height, 160),
3899 (DimName::Width, 160),
3900 (DimName::NumProtos, 32),
3901 ],
3902 },
3903 )
3904 .with_score_threshold(score_threshold)
3905 .with_iou_threshold(iou_threshold)
3906 .build()
3907 .unwrap();
3908
3909 let expected_boxes = [
3910 DetectBox {
3911 bbox: BoundingBox {
3912 xmin: 0.08515105,
3913 ymin: 0.7131401,
3914 xmax: 0.29802868,
3915 ymax: 0.8195788,
3916 },
3917 score: 0.91537374,
3918 label: 23,
3919 },
3920 DetectBox {
3921 bbox: BoundingBox {
3922 xmin: 0.59605736,
3923 ymin: 0.25545314,
3924 xmax: 0.93666154,
3925 ymax: 0.72378385,
3926 },
3927 score: 0.91537374,
3928 label: 23,
3929 },
3930 ];
3931
3932 let mut tracker = ByteTrackBuilder::new()
3933 .track_update(0.1)
3934 .track_high_conf(0.7)
3935 .build();
3936
3937 let mut output_boxes = Vec::with_capacity(50);
3938 let mut output_tracks = Vec::with_capacity(50);
3939
3940 decoder
3941 .decode_tracked_float_proto(
3942 &mut tracker,
3943 0,
3944 &[
3945 boxes.view().into_dyn(),
3946 scores.view().into_dyn(),
3947 mask.view().into_dyn(),
3948 protos.view().into_dyn(),
3949 ],
3950 &mut output_boxes,
3951 &mut output_tracks,
3952 )
3953 .unwrap();
3954
3955 assert_eq!(output_boxes.len(), 2);
3956 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3957 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3958
3959 for score in scores.iter_mut() {
3962 *score = 0.0; }
3964 let protos = decoder
3965 .decode_tracked_float_proto(
3966 &mut tracker,
3967 100_000_000 / 3,
3968 &[
3969 boxes.view().into_dyn(),
3970 scores.view().into_dyn(),
3971 mask.view().into_dyn(),
3972 protos.view().into_dyn(),
3973 ],
3974 &mut output_boxes,
3975 &mut output_tracks,
3976 )
3977 .unwrap();
3978
3979 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3980 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3981
3982 assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
3984 }
3985
3986 #[test]
3987 fn test_decoder_tracked_end_to_end_segdet() {
3988 let score_threshold = 0.45;
3989 let iou_threshold = 0.45;
3990
3991 let mut boxes = Array2::zeros((10, 4));
3992 let mut scores = Array2::zeros((10, 1));
3993 let mut classes = Array2::zeros((10, 1));
3994 let mask = Array2::zeros((10, 32));
3995 let protos = Array3::<f64>::zeros((160, 160, 32));
3996 let protos = protos.insert_axis(Axis(0));
3997
3998 let protos_quant = (1.0 / 255.0, 0.0);
3999 let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
4000
4001 boxes
4002 .slice_mut(s![0, ..,])
4003 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4004 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4005 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4006
4007 let detect = ndarray::concatenate![
4008 Axis(1),
4009 boxes.view(),
4010 scores.view(),
4011 classes.view(),
4012 mask.view()
4013 ];
4014 let detect = detect.insert_axis(Axis(0));
4015 assert_eq!(detect.shape(), &[1, 10, 38]);
4016 let detect_quant = (2.0 / 255.0, 0.0);
4017 let mut detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
4018 let config = "
4019decoder_version: yolo26
4020outputs:
4021 - type: detection
4022 decoder: ultralytics
4023 quantization: [0.00784313725490196, 0]
4024 shape: [1, 10, 38]
4025 dshape:
4026 - [batch, 1]
4027 - [num_boxes, 10]
4028 - [num_features, 38]
4029 normalized: true
4030 - type: protos
4031 decoder: ultralytics
4032 quantization: [0.0039215686274509803921568627451, 128]
4033 shape: [1, 160, 160, 32]
4034 dshape:
4035 - [batch, 1]
4036 - [height, 160]
4037 - [width, 160]
4038 - [num_protos, 32]
4039";
4040
4041 let decoder = DecoderBuilder::default()
4042 .with_config_yaml_str(config.to_string())
4043 .with_score_threshold(score_threshold)
4044 .with_iou_threshold(iou_threshold)
4045 .build()
4046 .unwrap();
4047
4048 let expected_boxes = [DetectBox {
4050 bbox: BoundingBox {
4051 xmin: 0.12549022,
4052 ymin: 0.12549022,
4053 xmax: 0.23529413,
4054 ymax: 0.23529413,
4055 },
4056 score: 0.98823535,
4057 label: 2,
4058 }];
4059
4060 let mut tracker = ByteTrackBuilder::new()
4061 .track_update(0.1)
4062 .track_high_conf(0.7)
4063 .build();
4064
4065 let mut output_boxes = Vec::with_capacity(50);
4066 let mut output_masks = Vec::with_capacity(50);
4067 let mut output_tracks = Vec::with_capacity(50);
4068
4069 decoder
4070 .decode_tracked_quantized(
4071 &mut tracker,
4072 0,
4073 &[detect.view().into(), protos.view().into()],
4074 &mut output_boxes,
4075 &mut output_masks,
4076 &mut output_tracks,
4077 )
4078 .unwrap();
4079
4080 assert_eq!(output_boxes.len(), 1);
4081 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4082
4083 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4086 *score = u8::MIN; }
4088
4089 decoder
4090 .decode_tracked_quantized(
4091 &mut tracker,
4092 100_000_000 / 3,
4093 &[detect.view().into(), protos.view().into()],
4094 &mut output_boxes,
4095 &mut output_masks,
4096 &mut output_tracks,
4097 )
4098 .unwrap();
4099 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4100 assert!(output_masks.is_empty())
4102 }
4103
4104 #[test]
4105 fn test_decoder_tracked_end_to_end_segdet_float() {
4106 let score_threshold = 0.45;
4107 let iou_threshold = 0.45;
4108
4109 let mut boxes = Array2::zeros((10, 4));
4110 let mut scores = Array2::zeros((10, 1));
4111 let mut classes = Array2::zeros((10, 1));
4112 let mask = Array2::zeros((10, 32));
4113 let protos = Array3::<f64>::zeros((160, 160, 32));
4114 let protos = protos.insert_axis(Axis(0));
4115
4116 boxes
4117 .slice_mut(s![0, ..,])
4118 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4119 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4120 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4121
4122 let detect = ndarray::concatenate![
4123 Axis(1),
4124 boxes.view(),
4125 scores.view(),
4126 classes.view(),
4127 mask.view()
4128 ];
4129 let mut detect = detect.insert_axis(Axis(0));
4130 assert_eq!(detect.shape(), &[1, 10, 38]);
4131 let config = "
4132decoder_version: yolo26
4133outputs:
4134 - type: detection
4135 decoder: ultralytics
4136 quantization: [0.00784313725490196, 0]
4137 shape: [1, 10, 38]
4138 dshape:
4139 - [batch, 1]
4140 - [num_boxes, 10]
4141 - [num_features, 38]
4142 normalized: true
4143 - type: protos
4144 decoder: ultralytics
4145 quantization: [0.0039215686274509803921568627451, 128]
4146 shape: [1, 160, 160, 32]
4147 dshape:
4148 - [batch, 1]
4149 - [height, 160]
4150 - [width, 160]
4151 - [num_protos, 32]
4152";
4153
4154 let decoder = DecoderBuilder::default()
4155 .with_config_yaml_str(config.to_string())
4156 .with_score_threshold(score_threshold)
4157 .with_iou_threshold(iou_threshold)
4158 .build()
4159 .unwrap();
4160
4161 let expected_boxes = [DetectBox {
4162 bbox: BoundingBox {
4163 xmin: 0.1234,
4164 ymin: 0.1234,
4165 xmax: 0.2345,
4166 ymax: 0.2345,
4167 },
4168 score: 0.9876,
4169 label: 2,
4170 }];
4171
4172 let mut tracker = ByteTrackBuilder::new()
4173 .track_update(0.1)
4174 .track_high_conf(0.7)
4175 .build();
4176
4177 let mut output_boxes = Vec::with_capacity(50);
4178 let mut output_masks = Vec::with_capacity(50);
4179 let mut output_tracks = Vec::with_capacity(50);
4180
4181 decoder
4182 .decode_tracked_float(
4183 &mut tracker,
4184 0,
4185 &[detect.view().into_dyn(), protos.view().into_dyn()],
4186 &mut output_boxes,
4187 &mut output_masks,
4188 &mut output_tracks,
4189 )
4190 .unwrap();
4191
4192 assert_eq!(output_boxes.len(), 1);
4193 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4194
4195 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4198 *score = 0.0; }
4200
4201 decoder
4202 .decode_tracked_float(
4203 &mut tracker,
4204 100_000_000 / 3,
4205 &[detect.view().into_dyn(), protos.view().into_dyn()],
4206 &mut output_boxes,
4207 &mut output_masks,
4208 &mut output_tracks,
4209 )
4210 .unwrap();
4211 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4212 assert!(output_masks.is_empty())
4214 }
4215
4216 #[test]
4217 fn test_decoder_tracked_end_to_end_segdet_proto() {
4218 let score_threshold = 0.45;
4219 let iou_threshold = 0.45;
4220
4221 let mut boxes = Array2::zeros((10, 4));
4222 let mut scores = Array2::zeros((10, 1));
4223 let mut classes = Array2::zeros((10, 1));
4224 let mask = Array2::zeros((10, 32));
4225 let protos = Array3::<f64>::zeros((160, 160, 32));
4226 let protos = protos.insert_axis(Axis(0));
4227
4228 let protos_quant = (1.0 / 255.0, 0.0);
4229 let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
4230
4231 boxes
4232 .slice_mut(s![0, ..,])
4233 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4234 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4235 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4236
4237 let detect = ndarray::concatenate![
4238 Axis(1),
4239 boxes.view(),
4240 scores.view(),
4241 classes.view(),
4242 mask.view()
4243 ];
4244 let detect = detect.insert_axis(Axis(0));
4245 assert_eq!(detect.shape(), &[1, 10, 38]);
4246 let detect_quant = (2.0 / 255.0, 0.0);
4247 let mut detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
4248 let config = "
4249decoder_version: yolo26
4250outputs:
4251 - type: detection
4252 decoder: ultralytics
4253 quantization: [0.00784313725490196, 0]
4254 shape: [1, 10, 38]
4255 dshape:
4256 - [batch, 1]
4257 - [num_boxes, 10]
4258 - [num_features, 38]
4259 normalized: true
4260 - type: protos
4261 decoder: ultralytics
4262 quantization: [0.0039215686274509803921568627451, 128]
4263 shape: [1, 160, 160, 32]
4264 dshape:
4265 - [batch, 1]
4266 - [height, 160]
4267 - [width, 160]
4268 - [num_protos, 32]
4269";
4270
4271 let decoder = DecoderBuilder::default()
4272 .with_config_yaml_str(config.to_string())
4273 .with_score_threshold(score_threshold)
4274 .with_iou_threshold(iou_threshold)
4275 .build()
4276 .unwrap();
4277
4278 let expected_boxes = [DetectBox {
4280 bbox: BoundingBox {
4281 xmin: 0.12549022,
4282 ymin: 0.12549022,
4283 xmax: 0.23529413,
4284 ymax: 0.23529413,
4285 },
4286 score: 0.98823535,
4287 label: 2,
4288 }];
4289
4290 let mut tracker = ByteTrackBuilder::new()
4291 .track_update(0.1)
4292 .track_high_conf(0.7)
4293 .build();
4294
4295 let mut output_boxes = Vec::with_capacity(50);
4296 let mut output_tracks = Vec::with_capacity(50);
4297
4298 decoder
4299 .decode_tracked_quantized_proto(
4300 &mut tracker,
4301 0,
4302 &[detect.view().into(), protos.view().into()],
4303 &mut output_boxes,
4304 &mut output_tracks,
4305 )
4306 .unwrap();
4307
4308 assert_eq!(output_boxes.len(), 1);
4309 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4310
4311 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4314 *score = u8::MIN; }
4316
4317 let protos = decoder
4318 .decode_tracked_quantized_proto(
4319 &mut tracker,
4320 100_000_000 / 3,
4321 &[detect.view().into(), protos.view().into()],
4322 &mut output_boxes,
4323 &mut output_tracks,
4324 )
4325 .unwrap();
4326 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4327 assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()))
4329 }
4330
4331 #[test]
4332 fn test_decoder_tracked_end_to_end_segdet_proto_float() {
4333 let score_threshold = 0.45;
4334 let iou_threshold = 0.45;
4335
4336 let mut boxes = Array2::zeros((10, 4));
4337 let mut scores = Array2::zeros((10, 1));
4338 let mut classes = Array2::zeros((10, 1));
4339 let mask = Array2::zeros((10, 32));
4340 let protos = Array3::<f64>::zeros((160, 160, 32));
4341 let protos = protos.insert_axis(Axis(0));
4342
4343 boxes
4344 .slice_mut(s![0, ..,])
4345 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4346 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4347 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4348
4349 let detect = ndarray::concatenate![
4350 Axis(1),
4351 boxes.view(),
4352 scores.view(),
4353 classes.view(),
4354 mask.view()
4355 ];
4356 let mut detect = detect.insert_axis(Axis(0));
4357 assert_eq!(detect.shape(), &[1, 10, 38]);
4358 let config = "
4359decoder_version: yolo26
4360outputs:
4361 - type: detection
4362 decoder: ultralytics
4363 quantization: [0.00784313725490196, 0]
4364 shape: [1, 10, 38]
4365 dshape:
4366 - [batch, 1]
4367 - [num_boxes, 10]
4368 - [num_features, 38]
4369 normalized: true
4370 - type: protos
4371 decoder: ultralytics
4372 quantization: [0.0039215686274509803921568627451, 128]
4373 shape: [1, 160, 160, 32]
4374 dshape:
4375 - [batch, 1]
4376 - [height, 160]
4377 - [width, 160]
4378 - [num_protos, 32]
4379";
4380
4381 let decoder = DecoderBuilder::default()
4382 .with_config_yaml_str(config.to_string())
4383 .with_score_threshold(score_threshold)
4384 .with_iou_threshold(iou_threshold)
4385 .build()
4386 .unwrap();
4387
4388 let expected_boxes = [DetectBox {
4389 bbox: BoundingBox {
4390 xmin: 0.1234,
4391 ymin: 0.1234,
4392 xmax: 0.2345,
4393 ymax: 0.2345,
4394 },
4395 score: 0.9876,
4396 label: 2,
4397 }];
4398
4399 let mut tracker = ByteTrackBuilder::new()
4400 .track_update(0.1)
4401 .track_high_conf(0.7)
4402 .build();
4403
4404 let mut output_boxes = Vec::with_capacity(50);
4405 let mut output_tracks = Vec::with_capacity(50);
4406
4407 decoder
4408 .decode_tracked_float_proto(
4409 &mut tracker,
4410 0,
4411 &[detect.view().into_dyn(), protos.view().into_dyn()],
4412 &mut output_boxes,
4413 &mut output_tracks,
4414 )
4415 .unwrap();
4416
4417 assert_eq!(output_boxes.len(), 1);
4418 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4419
4420 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4423 *score = 0.0; }
4425
4426 let protos = decoder
4427 .decode_tracked_float_proto(
4428 &mut tracker,
4429 100_000_000 / 3,
4430 &[detect.view().into_dyn(), protos.view().into_dyn()],
4431 &mut output_boxes,
4432 &mut output_tracks,
4433 )
4434 .unwrap();
4435 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4436 assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()))
4438 }
4439
4440 #[test]
4441 fn test_decoder_tracked_end_to_end_segdet_split() {
4442 let score_threshold = 0.45;
4443 let iou_threshold = 0.45;
4444
4445 let mut boxes = Array2::zeros((10, 4));
4446 let mut scores = Array2::zeros((10, 1));
4447 let mut classes = Array2::zeros((10, 1));
4448 let mask: Array2<f64> = Array2::zeros((10, 32));
4449 let protos = Array3::<f64>::zeros((160, 160, 32));
4450 let protos = protos.insert_axis(Axis(0));
4451
4452 let protos_quant = (1.0 / 255.0, 0.0);
4453 let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
4454
4455 boxes
4456 .slice_mut(s![0, ..,])
4457 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4458 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4459 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4460
4461 let boxes = boxes.insert_axis(Axis(0));
4462 let scores = scores.insert_axis(Axis(0));
4463 let classes = classes.insert_axis(Axis(0));
4464 let mask = mask.insert_axis(Axis(0));
4465
4466 let detect_quant = (2.0 / 255.0, 0.0);
4467 let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
4468 let mut scores: Array3<u8> = quantize_ndarray(scores.view(), detect_quant.into());
4469 let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
4470 let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
4471
4472 let config = "
4473decoder_version: yolo26
4474outputs:
4475 - type: boxes
4476 decoder: ultralytics
4477 quantization: [0.00784313725490196, 0]
4478 shape: [1, 10, 4]
4479 dshape:
4480 - [batch, 1]
4481 - [num_boxes, 10]
4482 - [box_coords, 4]
4483 normalized: true
4484 - type: scores
4485 decoder: ultralytics
4486 quantization: [0.00784313725490196, 0]
4487 shape: [1, 10, 1]
4488 dshape:
4489 - [batch, 1]
4490 - [num_boxes, 10]
4491 - [num_classes, 1]
4492 - type: classes
4493 decoder: ultralytics
4494 quantization: [0.00784313725490196, 0]
4495 shape: [1, 10, 1]
4496 dshape:
4497 - [batch, 1]
4498 - [num_boxes, 10]
4499 - [num_classes, 1]
4500 - type: mask_coefficients
4501 decoder: ultralytics
4502 quantization: [0.00784313725490196, 0]
4503 shape: [1, 10, 32]
4504 dshape:
4505 - [batch, 1]
4506 - [num_boxes, 10]
4507 - [num_protos, 32]
4508 - type: protos
4509 decoder: ultralytics
4510 quantization: [0.0039215686274509803921568627451, 128]
4511 shape: [1, 160, 160, 32]
4512 dshape:
4513 - [batch, 1]
4514 - [height, 160]
4515 - [width, 160]
4516 - [num_protos, 32]
4517";
4518
4519 let decoder = DecoderBuilder::default()
4520 .with_config_yaml_str(config.to_string())
4521 .with_score_threshold(score_threshold)
4522 .with_iou_threshold(iou_threshold)
4523 .build()
4524 .unwrap();
4525
4526 let expected_boxes = [DetectBox {
4528 bbox: BoundingBox {
4529 xmin: 0.12549022,
4530 ymin: 0.12549022,
4531 xmax: 0.23529413,
4532 ymax: 0.23529413,
4533 },
4534 score: 0.98823535,
4535 label: 2,
4536 }];
4537
4538 let mut tracker = ByteTrackBuilder::new()
4539 .track_update(0.1)
4540 .track_high_conf(0.7)
4541 .build();
4542
4543 let mut output_boxes = Vec::with_capacity(50);
4544 let mut output_masks = Vec::with_capacity(50);
4545 let mut output_tracks = Vec::with_capacity(50);
4546
4547 decoder
4548 .decode_tracked_quantized(
4549 &mut tracker,
4550 0,
4551 &[
4552 boxes.view().into(),
4553 scores.view().into(),
4554 classes.view().into(),
4555 mask.view().into(),
4556 protos.view().into(),
4557 ],
4558 &mut output_boxes,
4559 &mut output_masks,
4560 &mut output_tracks,
4561 )
4562 .unwrap();
4563
4564 assert_eq!(output_boxes.len(), 1);
4565 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4566
4567 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4570 *score = u8::MIN; }
4572
4573 decoder
4574 .decode_tracked_quantized(
4575 &mut tracker,
4576 100_000_000 / 3,
4577 &[
4578 boxes.view().into(),
4579 scores.view().into(),
4580 classes.view().into(),
4581 mask.view().into(),
4582 protos.view().into(),
4583 ],
4584 &mut output_boxes,
4585 &mut output_masks,
4586 &mut output_tracks,
4587 )
4588 .unwrap();
4589 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4590 assert!(output_masks.is_empty())
4592 }
4593 #[test]
4594 fn test_decoder_tracked_end_to_end_segdet_split_float() {
4595 let score_threshold = 0.45;
4596 let iou_threshold = 0.45;
4597
4598 let mut boxes = Array2::zeros((10, 4));
4599 let mut scores = Array2::zeros((10, 1));
4600 let mut classes = Array2::zeros((10, 1));
4601 let mask: Array2<f64> = Array2::zeros((10, 32));
4602 let protos = Array3::<f64>::zeros((160, 160, 32));
4603 let protos = protos.insert_axis(Axis(0));
4604
4605 boxes
4606 .slice_mut(s![0, ..,])
4607 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4608 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4609 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4610
4611 let boxes = boxes.insert_axis(Axis(0));
4612 let mut scores = scores.insert_axis(Axis(0));
4613 let classes = classes.insert_axis(Axis(0));
4614 let mask = mask.insert_axis(Axis(0));
4615
4616 let config = "
4617decoder_version: yolo26
4618outputs:
4619 - type: boxes
4620 decoder: ultralytics
4621 quantization: [0.00784313725490196, 0]
4622 shape: [1, 10, 4]
4623 dshape:
4624 - [batch, 1]
4625 - [num_boxes, 10]
4626 - [box_coords, 4]
4627 normalized: true
4628 - type: scores
4629 decoder: ultralytics
4630 quantization: [0.00784313725490196, 0]
4631 shape: [1, 10, 1]
4632 dshape:
4633 - [batch, 1]
4634 - [num_boxes, 10]
4635 - [num_classes, 1]
4636 - type: classes
4637 decoder: ultralytics
4638 quantization: [0.00784313725490196, 0]
4639 shape: [1, 10, 1]
4640 dshape:
4641 - [batch, 1]
4642 - [num_boxes, 10]
4643 - [num_classes, 1]
4644 - type: mask_coefficients
4645 decoder: ultralytics
4646 quantization: [0.00784313725490196, 0]
4647 shape: [1, 10, 32]
4648 dshape:
4649 - [batch, 1]
4650 - [num_boxes, 10]
4651 - [num_protos, 32]
4652 - type: protos
4653 decoder: ultralytics
4654 quantization: [0.0039215686274509803921568627451, 128]
4655 shape: [1, 160, 160, 32]
4656 dshape:
4657 - [batch, 1]
4658 - [height, 160]
4659 - [width, 160]
4660 - [num_protos, 32]
4661";
4662
4663 let decoder = DecoderBuilder::default()
4664 .with_config_yaml_str(config.to_string())
4665 .with_score_threshold(score_threshold)
4666 .with_iou_threshold(iou_threshold)
4667 .build()
4668 .unwrap();
4669
4670 let expected_boxes = [DetectBox {
4672 bbox: BoundingBox {
4673 xmin: 0.1234,
4674 ymin: 0.1234,
4675 xmax: 0.2345,
4676 ymax: 0.2345,
4677 },
4678 score: 0.9876,
4679 label: 2,
4680 }];
4681
4682 let mut tracker = ByteTrackBuilder::new()
4683 .track_update(0.1)
4684 .track_high_conf(0.7)
4685 .build();
4686
4687 let mut output_boxes = Vec::with_capacity(50);
4688 let mut output_masks = Vec::with_capacity(50);
4689 let mut output_tracks = Vec::with_capacity(50);
4690
4691 decoder
4692 .decode_tracked_float(
4693 &mut tracker,
4694 0,
4695 &[
4696 boxes.view().into_dyn(),
4697 scores.view().into_dyn(),
4698 classes.view().into_dyn(),
4699 mask.view().into_dyn(),
4700 protos.view().into_dyn(),
4701 ],
4702 &mut output_boxes,
4703 &mut output_masks,
4704 &mut output_tracks,
4705 )
4706 .unwrap();
4707
4708 assert_eq!(output_boxes.len(), 1);
4709 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4710
4711 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4714 *score = 0.0; }
4716
4717 decoder
4718 .decode_tracked_float(
4719 &mut tracker,
4720 100_000_000 / 3,
4721 &[
4722 boxes.view().into_dyn(),
4723 scores.view().into_dyn(),
4724 classes.view().into_dyn(),
4725 mask.view().into_dyn(),
4726 protos.view().into_dyn(),
4727 ],
4728 &mut output_boxes,
4729 &mut output_masks,
4730 &mut output_tracks,
4731 )
4732 .unwrap();
4733 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4734 assert!(output_masks.is_empty())
4736 }
4737
4738 #[test]
4739 fn test_decoder_tracked_end_to_end_segdet_split_proto() {
4740 let score_threshold = 0.45;
4741 let iou_threshold = 0.45;
4742
4743 let mut boxes = Array2::zeros((10, 4));
4744 let mut scores = Array2::zeros((10, 1));
4745 let mut classes = Array2::zeros((10, 1));
4746 let mask: Array2<f64> = Array2::zeros((10, 32));
4747 let protos = Array3::<f64>::zeros((160, 160, 32));
4748 let protos = protos.insert_axis(Axis(0));
4749
4750 let protos_quant = (1.0 / 255.0, 0.0);
4751 let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
4752
4753 boxes
4754 .slice_mut(s![0, ..,])
4755 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4756 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4757 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4758
4759 let boxes = boxes.insert_axis(Axis(0));
4760 let scores = scores.insert_axis(Axis(0));
4761 let classes = classes.insert_axis(Axis(0));
4762 let mask = mask.insert_axis(Axis(0));
4763
4764 let detect_quant = (2.0 / 255.0, 0.0);
4765 let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
4766 let mut scores: Array3<u8> = quantize_ndarray(scores.view(), detect_quant.into());
4767 let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
4768 let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
4769
4770 let config = "
4771decoder_version: yolo26
4772outputs:
4773 - type: boxes
4774 decoder: ultralytics
4775 quantization: [0.00784313725490196, 0]
4776 shape: [1, 10, 4]
4777 dshape:
4778 - [batch, 1]
4779 - [num_boxes, 10]
4780 - [box_coords, 4]
4781 normalized: true
4782 - type: scores
4783 decoder: ultralytics
4784 quantization: [0.00784313725490196, 0]
4785 shape: [1, 10, 1]
4786 dshape:
4787 - [batch, 1]
4788 - [num_boxes, 10]
4789 - [num_classes, 1]
4790 - type: classes
4791 decoder: ultralytics
4792 quantization: [0.00784313725490196, 0]
4793 shape: [1, 10, 1]
4794 dshape:
4795 - [batch, 1]
4796 - [num_boxes, 10]
4797 - [num_classes, 1]
4798 - type: mask_coefficients
4799 decoder: ultralytics
4800 quantization: [0.00784313725490196, 0]
4801 shape: [1, 10, 32]
4802 dshape:
4803 - [batch, 1]
4804 - [num_boxes, 10]
4805 - [num_protos, 32]
4806 - type: protos
4807 decoder: ultralytics
4808 quantization: [0.0039215686274509803921568627451, 128]
4809 shape: [1, 160, 160, 32]
4810 dshape:
4811 - [batch, 1]
4812 - [height, 160]
4813 - [width, 160]
4814 - [num_protos, 32]
4815";
4816
4817 let decoder = DecoderBuilder::default()
4818 .with_config_yaml_str(config.to_string())
4819 .with_score_threshold(score_threshold)
4820 .with_iou_threshold(iou_threshold)
4821 .build()
4822 .unwrap();
4823
4824 let expected_boxes = [DetectBox {
4826 bbox: BoundingBox {
4827 xmin: 0.12549022,
4828 ymin: 0.12549022,
4829 xmax: 0.23529413,
4830 ymax: 0.23529413,
4831 },
4832 score: 0.98823535,
4833 label: 2,
4834 }];
4835
4836 let mut tracker = ByteTrackBuilder::new()
4837 .track_update(0.1)
4838 .track_high_conf(0.7)
4839 .build();
4840
4841 let mut output_boxes = Vec::with_capacity(50);
4842 let mut output_tracks = Vec::with_capacity(50);
4843
4844 decoder
4845 .decode_tracked_quantized_proto(
4846 &mut tracker,
4847 0,
4848 &[
4849 boxes.view().into(),
4850 scores.view().into(),
4851 classes.view().into(),
4852 mask.view().into(),
4853 protos.view().into(),
4854 ],
4855 &mut output_boxes,
4856 &mut output_tracks,
4857 )
4858 .unwrap();
4859
4860 assert_eq!(output_boxes.len(), 1);
4861 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4862
4863 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4866 *score = u8::MIN; }
4868
4869 let protos = decoder
4870 .decode_tracked_quantized_proto(
4871 &mut tracker,
4872 100_000_000 / 3,
4873 &[
4874 boxes.view().into(),
4875 scores.view().into(),
4876 classes.view().into(),
4877 mask.view().into(),
4878 protos.view().into(),
4879 ],
4880 &mut output_boxes,
4881 &mut output_tracks,
4882 )
4883 .unwrap();
4884 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4885 assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()))
4887 }
4888
4889 #[test]
4890 fn test_decoder_tracked_end_to_end_segdet_split_proto_float() {
4891 let score_threshold = 0.45;
4892 let iou_threshold = 0.45;
4893
4894 let mut boxes = Array2::zeros((10, 4));
4895 let mut scores = Array2::zeros((10, 1));
4896 let mut classes = Array2::zeros((10, 1));
4897 let mask: Array2<f64> = Array2::zeros((10, 32));
4898 let protos = Array3::<f64>::zeros((160, 160, 32));
4899 let protos = protos.insert_axis(Axis(0));
4900
4901 boxes
4902 .slice_mut(s![0, ..,])
4903 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4904 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4905 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4906
4907 let boxes = boxes.insert_axis(Axis(0));
4908 let mut scores = scores.insert_axis(Axis(0));
4909 let classes = classes.insert_axis(Axis(0));
4910 let mask = mask.insert_axis(Axis(0));
4911
4912 let config = "
4913decoder_version: yolo26
4914outputs:
4915 - type: boxes
4916 decoder: ultralytics
4917 quantization: [0.00784313725490196, 0]
4918 shape: [1, 10, 4]
4919 dshape:
4920 - [batch, 1]
4921 - [num_boxes, 10]
4922 - [box_coords, 4]
4923 normalized: true
4924 - type: scores
4925 decoder: ultralytics
4926 quantization: [0.00784313725490196, 0]
4927 shape: [1, 10, 1]
4928 dshape:
4929 - [batch, 1]
4930 - [num_boxes, 10]
4931 - [num_classes, 1]
4932 - type: classes
4933 decoder: ultralytics
4934 quantization: [0.00784313725490196, 0]
4935 shape: [1, 10, 1]
4936 dshape:
4937 - [batch, 1]
4938 - [num_boxes, 10]
4939 - [num_classes, 1]
4940 - type: mask_coefficients
4941 decoder: ultralytics
4942 quantization: [0.00784313725490196, 0]
4943 shape: [1, 10, 32]
4944 dshape:
4945 - [batch, 1]
4946 - [num_boxes, 10]
4947 - [num_protos, 32]
4948 - type: protos
4949 decoder: ultralytics
4950 quantization: [0.0039215686274509803921568627451, 128]
4951 shape: [1, 160, 160, 32]
4952 dshape:
4953 - [batch, 1]
4954 - [height, 160]
4955 - [width, 160]
4956 - [num_protos, 32]
4957";
4958
4959 let decoder = DecoderBuilder::default()
4960 .with_config_yaml_str(config.to_string())
4961 .with_score_threshold(score_threshold)
4962 .with_iou_threshold(iou_threshold)
4963 .build()
4964 .unwrap();
4965
4966 let expected_boxes = [DetectBox {
4968 bbox: BoundingBox {
4969 xmin: 0.1234,
4970 ymin: 0.1234,
4971 xmax: 0.2345,
4972 ymax: 0.2345,
4973 },
4974 score: 0.9876,
4975 label: 2,
4976 }];
4977
4978 let mut tracker = ByteTrackBuilder::new()
4979 .track_update(0.1)
4980 .track_high_conf(0.7)
4981 .build();
4982
4983 let mut output_boxes = Vec::with_capacity(50);
4984 let mut output_tracks = Vec::with_capacity(50);
4985
4986 decoder
4987 .decode_tracked_float_proto(
4988 &mut tracker,
4989 0,
4990 &[
4991 boxes.view().into_dyn(),
4992 scores.view().into_dyn(),
4993 classes.view().into_dyn(),
4994 mask.view().into_dyn(),
4995 protos.view().into_dyn(),
4996 ],
4997 &mut output_boxes,
4998 &mut output_tracks,
4999 )
5000 .unwrap();
5001
5002 assert_eq!(output_boxes.len(), 1);
5003 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
5004
5005 for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5008 *score = 0.0; }
5010
5011 let protos = decoder
5012 .decode_tracked_float_proto(
5013 &mut tracker,
5014 100_000_000 / 3,
5015 &[
5016 boxes.view().into_dyn(),
5017 scores.view().into_dyn(),
5018 classes.view().into_dyn(),
5019 mask.view().into_dyn(),
5020 protos.view().into_dyn(),
5021 ],
5022 &mut output_boxes,
5023 &mut output_tracks,
5024 )
5025 .unwrap();
5026 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5027 assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()))
5029 }
5030
5031 #[test]
5032 fn test_decoder_tracked_linear_motion() {
5033 use crate::configs::{DecoderType, Nms};
5034 use crate::DecoderBuilder;
5035
5036 let score_threshold = 0.25;
5037 let iou_threshold = 0.1;
5038 let out = include_bytes!(concat!(
5039 env!("CARGO_MANIFEST_DIR"),
5040 "/../../testdata/yolov8s_80_classes.bin"
5041 ));
5042 let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
5043 let mut out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
5044 let quant = (0.0040811873, -123).into();
5045
5046 let decoder = DecoderBuilder::default()
5047 .with_config_yolo_det(
5048 crate::configs::Detection {
5049 decoder: DecoderType::Ultralytics,
5050 shape: vec![1, 84, 8400],
5051 anchors: None,
5052 quantization: Some(quant),
5053 dshape: vec![
5054 (crate::configs::DimName::Batch, 1),
5055 (crate::configs::DimName::NumFeatures, 84),
5056 (crate::configs::DimName::NumBoxes, 8400),
5057 ],
5058 normalized: Some(true),
5059 },
5060 None,
5061 )
5062 .with_score_threshold(score_threshold)
5063 .with_iou_threshold(iou_threshold)
5064 .with_nms(Some(Nms::ClassAgnostic))
5065 .build()
5066 .unwrap();
5067
5068 let mut expected_boxes = [
5069 DetectBox {
5070 bbox: BoundingBox {
5071 xmin: 0.5285137,
5072 ymin: 0.05305544,
5073 xmax: 0.87541467,
5074 ymax: 0.9998909,
5075 },
5076 score: 0.5591227,
5077 label: 0,
5078 },
5079 DetectBox {
5080 bbox: BoundingBox {
5081 xmin: 0.130598,
5082 ymin: 0.43260583,
5083 xmax: 0.35098213,
5084 ymax: 0.9958097,
5085 },
5086 score: 0.33057618,
5087 label: 75,
5088 },
5089 ];
5090
5091 let mut tracker = ByteTrackBuilder::new()
5092 .track_update(0.1)
5093 .track_high_conf(0.3)
5094 .build();
5095
5096 let mut output_boxes = Vec::with_capacity(50);
5097 let mut output_masks = Vec::with_capacity(50);
5098 let mut output_tracks = Vec::with_capacity(50);
5099
5100 decoder
5101 .decode_tracked_quantized(
5102 &mut tracker,
5103 0,
5104 &[out.view().into()],
5105 &mut output_boxes,
5106 &mut output_masks,
5107 &mut output_tracks,
5108 )
5109 .unwrap();
5110
5111 assert_eq!(output_boxes.len(), 2);
5112 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5113 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
5114
5115 for i in 1..=100 {
5116 let mut out = out.clone();
5117 let mut x_values = out.slice_mut(s![0, 0, ..]);
5119 for x in x_values.iter_mut() {
5120 *x = x.saturating_add((i as f32 * 1e-3 / quant.0).round() as i8);
5121 }
5122
5123 decoder
5124 .decode_tracked_quantized(
5125 &mut tracker,
5126 100_000_000 * i / 3, &[out.view().into()],
5128 &mut output_boxes,
5129 &mut output_masks,
5130 &mut output_tracks,
5131 )
5132 .unwrap();
5133
5134 assert_eq!(output_boxes.len(), 2);
5135 }
5136 let tracks = tracker.get_active_tracks();
5137 let predicted_boxes: Vec<_> = tracks
5138 .iter()
5139 .map(|track| {
5140 let mut l = track.last_box;
5141 l.bbox = track.info.tracked_location.into();
5142 l
5143 })
5144 .collect();
5145 expected_boxes[0].bbox.xmin += 0.1; expected_boxes[0].bbox.xmax += 0.1;
5147 expected_boxes[1].bbox.xmin += 0.1;
5148 expected_boxes[1].bbox.xmax += 0.1;
5149
5150 assert!(predicted_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
5151 assert!(predicted_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
5152
5153 let mut scores_values = out.slice_mut(s![0, 4.., ..]);
5155 for score in scores_values.iter_mut() {
5156 *score = i8::MIN; }
5158 decoder
5159 .decode_tracked_quantized(
5160 &mut tracker,
5161 100_000_000 * 101 / 3,
5162 &[out.view().into()],
5163 &mut output_boxes,
5164 &mut output_masks,
5165 &mut output_tracks,
5166 )
5167 .unwrap();
5168 expected_boxes[0].bbox.xmin += 0.001; expected_boxes[0].bbox.xmax += 0.001;
5170 expected_boxes[1].bbox.xmin += 0.001;
5171 expected_boxes[1].bbox.xmax += 0.001;
5172
5173 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
5174 assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
5175 }
5176
5177 #[test]
5178 fn test_decoder_tracked_end_to_end_float() {
5179 let score_threshold = 0.45;
5180 let iou_threshold = 0.45;
5181
5182 let mut boxes = Array2::zeros((10, 4));
5183 let mut scores = Array2::zeros((10, 1));
5184 let mut classes = Array2::zeros((10, 1));
5185
5186 boxes
5187 .slice_mut(s![0, ..,])
5188 .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5189 scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5190 classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5191
5192 let detect = ndarray::concatenate![Axis(1), boxes.view(), scores.view(), classes.view(),];
5193 let mut detect = detect.insert_axis(Axis(0));
5194 assert_eq!(detect.shape(), &[1, 10, 6]);
5195 let config = "
5196decoder_version: yolo26
5197outputs:
5198 - type: detection
5199 decoder: ultralytics
5200 quantization: [0.00784313725490196, 0]
5201 shape: [1, 10, 6]
5202 dshape:
5203 - [batch, 1]
5204 - [num_boxes, 10]
5205 - [num_features, 6]
5206 normalized: true
5207";
5208
5209 let decoder = DecoderBuilder::default()
5210 .with_config_yaml_str(config.to_string())
5211 .with_score_threshold(score_threshold)
5212 .with_iou_threshold(iou_threshold)
5213 .build()
5214 .unwrap();
5215
5216 let expected_boxes = [DetectBox {
5217 bbox: BoundingBox {
5218 xmin: 0.1234,
5219 ymin: 0.1234,
5220 xmax: 0.2345,
5221 ymax: 0.2345,
5222 },
5223 score: 0.9876,
5224 label: 2,
5225 }];
5226
5227 let mut tracker = ByteTrackBuilder::new()
5228 .track_update(0.1)
5229 .track_high_conf(0.7)
5230 .build();
5231
5232 let mut output_boxes = Vec::with_capacity(50);
5233 let mut output_masks = Vec::with_capacity(50);
5234 let mut output_tracks = Vec::with_capacity(50);
5235
5236 decoder
5237 .decode_tracked_float(
5238 &mut tracker,
5239 0,
5240 &[detect.view().into_dyn()],
5241 &mut output_boxes,
5242 &mut output_masks,
5243 &mut output_tracks,
5244 )
5245 .unwrap();
5246
5247 assert_eq!(output_boxes.len(), 1);
5248 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5249
5250 for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5253 *score = 0.0; }
5255
5256 decoder
5257 .decode_tracked_float(
5258 &mut tracker,
5259 100_000_000 / 3,
5260 &[detect.view().into_dyn()],
5261 &mut output_boxes,
5262 &mut output_masks,
5263 &mut output_tracks,
5264 )
5265 .unwrap();
5266 assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5267 }
5268}