1#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
71
72use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
73use num_traits::{AsPrimitive, Float, PrimInt};
74
75pub mod byte;
76pub mod error;
77pub mod float;
78pub mod modelpack;
79pub mod yolo;
80
81mod decoder;
82pub use decoder::*;
83
84pub use configs::{DecoderVersion, Nms};
85pub use error::{DecoderError, DecoderResult};
86
87use crate::{
88 decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
89 yolo::yolo_segmentation_to_mask,
90};
91
92pub trait BBoxTypeTrait {
94 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
96
97 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
99 input: &[B; 4],
100 quant: Quantization,
101 ) -> [A; 4]
102 where
103 f32: AsPrimitive<A>,
104 i32: AsPrimitive<A>;
105
106 #[inline(always)]
117 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
118 input: ArrayView1<B>,
119 ) -> [A; 4] {
120 Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
121 }
122
123 #[inline(always)]
124 fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
126 input: ArrayView1<B>,
127 quant: Quantization,
128 ) -> [A; 4]
129 where
130 f32: AsPrimitive<A>,
131 i32: AsPrimitive<A>,
132 {
133 Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
134 }
135}
136
137#[derive(Debug, Clone, Copy, PartialEq, Eq)]
139pub struct XYXY {}
140
141impl BBoxTypeTrait for XYXY {
142 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
143 input.map(|b| b.as_())
144 }
145
146 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
147 input: &[B; 4],
148 quant: Quantization,
149 ) -> [A; 4]
150 where
151 f32: AsPrimitive<A>,
152 i32: AsPrimitive<A>,
153 {
154 let scale = quant.scale.as_();
155 let zp = quant.zero_point.as_();
156 input.map(|b| (b.as_() - zp) * scale)
157 }
158
159 #[inline(always)]
160 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
161 input: ArrayView1<B>,
162 ) -> [A; 4] {
163 [
164 input[0].as_(),
165 input[1].as_(),
166 input[2].as_(),
167 input[3].as_(),
168 ]
169 }
170}
171
172#[derive(Debug, Clone, Copy, PartialEq, Eq)]
175pub struct XYWH {}
176
177impl BBoxTypeTrait for XYWH {
178 #[inline(always)]
179 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
180 let half = A::one() / (A::one() + A::one());
181 [
182 (input[0].as_()) - (input[2].as_() * half),
183 (input[1].as_()) - (input[3].as_() * half),
184 (input[0].as_()) + (input[2].as_() * half),
185 (input[1].as_()) + (input[3].as_() * half),
186 ]
187 }
188
189 #[inline(always)]
190 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
191 input: &[B; 4],
192 quant: Quantization,
193 ) -> [A; 4]
194 where
195 f32: AsPrimitive<A>,
196 i32: AsPrimitive<A>,
197 {
198 let scale = quant.scale.as_();
199 let half_scale = (quant.scale * 0.5).as_();
200 let zp = quant.zero_point.as_();
201 let [x, y, w, h] = [
202 (input[0].as_() - zp) * scale,
203 (input[1].as_() - zp) * scale,
204 (input[2].as_() - zp) * half_scale,
205 (input[3].as_() - zp) * half_scale,
206 ];
207
208 [x - w, y - h, x + w, y + h]
209 }
210
211 #[inline(always)]
212 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
213 input: ArrayView1<B>,
214 ) -> [A; 4] {
215 let half = A::one() / (A::one() + A::one());
216 [
217 (input[0].as_()) - (input[2].as_() * half),
218 (input[1].as_()) - (input[3].as_() * half),
219 (input[0].as_()) + (input[2].as_() * half),
220 (input[1].as_()) + (input[3].as_() * half),
221 ]
222 }
223}
224
/// Affine quantization parameters.
///
/// A quantized value `q` maps back to a real value as `(q - zero_point) * scale`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quantization {
    /// Multiplier applied after the zero point has been subtracted.
    pub scale: f32,
    /// Quantized value that maps to real zero.
    pub zero_point: i32,
}

impl Quantization {
    /// Creates quantization parameters from a `scale` and a `zero_point`.
    ///
    /// This is a `const fn`, so the parameters can also be declared in
    /// `const` / `static` items.
    #[must_use]
    pub const fn new(scale: f32, zero_point: i32) -> Self {
        Self { scale, zero_point }
    }
}
245
246impl From<QuantTuple> for Quantization {
247 fn from(quant_tuple: QuantTuple) -> Quantization {
258 Quantization {
259 scale: quant_tuple.0,
260 zero_point: quant_tuple.1,
261 }
262 }
263}
264
265impl<S, Z> From<(S, Z)> for Quantization
266where
267 S: AsPrimitive<f32>,
268 Z: AsPrimitive<i32>,
269{
270 fn from((scale, zp): (S, Z)) -> Quantization {
279 Self {
280 scale: scale.as_(),
281 zero_point: zp.as_(),
282 }
283 }
284}
285
286impl Default for Quantization {
287 fn default() -> Self {
296 Self {
297 scale: 1.0,
298 zero_point: 0,
299 }
300 }
301}
302
/// A single decoded detection: a bounding box, its score and its label.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct DetectBox {
    /// Bounding box of the detection in corner form.
    pub bbox: BoundingBox,
    /// Dequantized detection score (see `dequant_detect_box`).
    pub score: f32,
    // NOTE(review): presumably the class index of the detection — confirm against the decoders.
    pub label: usize,
}
312
/// Axis-aligned bounding box stored as two corners.
///
/// The constructor does not enforce ordering of the corners; use
/// [`BoundingBox::to_canonical`] to obtain a copy with `xmin <= xmax` and
/// `ymin <= ymax` (for non-NaN coordinates).
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct BoundingBox {
    /// Minimum x coordinate.
    pub xmin: f32,
    /// Minimum y coordinate.
    pub ymin: f32,
    /// Maximum x coordinate.
    pub xmax: f32,
    /// Maximum y coordinate.
    pub ymax: f32,
}

impl BoundingBox {
    /// Creates a box from its corner coordinates, stored exactly as given.
    pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
        [xmin, ymin, xmax, ymax].into()
    }

    /// Returns a copy whose corners are swapped where needed so that the
    /// minimum fields hold the smaller coordinate on each axis.
    pub fn to_canonical(&self) -> Self {
        let (lo_x, hi_x) = (self.xmin.min(self.xmax), self.xmin.max(self.xmax));
        let (lo_y, hi_y) = (self.ymin.min(self.ymax), self.ymin.max(self.ymax));
        Self::new(lo_x, lo_y, hi_x, hi_y)
    }
}

impl From<BoundingBox> for [f32; 4] {
    /// Flattens the box into `[xmin, ymin, xmax, ymax]`.
    fn from(bbox: BoundingBox) -> Self {
        let BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        } = bbox;
        [xmin, ymin, xmax, ymax]
    }
}

impl From<[f32; 4]> for BoundingBox {
    /// Builds a box from `[xmin, ymin, xmax, ymax]`.
    fn from([xmin, ymin, xmax, ymax]: [f32; 4]) -> Self {
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }
}
391
392impl DetectBox {
393 pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
422 let eq_delta = |a: f32, b: f32| (a - b).abs() <= eps;
423 self.label == rhs.label
424 && eq_delta(self.score, rhs.score)
425 && eq_delta(self.bbox.xmin, rhs.bbox.xmin)
426 && eq_delta(self.bbox.ymin, rhs.bbox.ymin)
427 && eq_delta(self.bbox.xmax, rhs.bbox.xmax)
428 && eq_delta(self.bbox.ymax, rhs.bbox.ymax)
429 }
430}
431
/// A segmentation mask together with the region it covers.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Segmentation {
    // NOTE(review): full-frame masks in the tests use 0.0..1.0 for these bounds,
    // suggesting normalized coordinates — confirm.
    /// Minimum x coordinate of the region covered by the mask.
    pub xmin: f32,
    /// Minimum y coordinate of the region covered by the mask.
    pub ymin: f32,
    /// Maximum x coordinate of the region covered by the mask.
    pub xmax: f32,
    /// Maximum y coordinate of the region covered by the mask.
    pub ymax: f32,
    /// Mask data. The third axis is the depth/channel axis that
    /// `segmentation_to_mask` inspects (depth 1 vs. several channels).
    pub segmentation: Array3<u8>,
}
448
/// A 3-D prototype tensor, held either in quantized `i8` form or as `f32`.
///
/// NOTE(review): presumably the mask prototypes of a segmentation head — confirm.
#[derive(Debug, Clone)]
pub enum ProtoTensor {
    /// Quantized prototypes plus the parameters needed to dequantize them.
    Quantized {
        protos: Array3<i8>,
        quantization: Quantization,
    },
    /// Prototypes already in floating point.
    Float(Array3<f32>),
}
466
467impl ProtoTensor {
468 pub fn is_quantized(&self) -> bool {
470 matches!(self, ProtoTensor::Quantized { .. })
471 }
472
473 pub fn dim(&self) -> (usize, usize, usize) {
475 match self {
476 ProtoTensor::Quantized { protos, .. } => protos.dim(),
477 ProtoTensor::Float(arr) => arr.dim(),
478 }
479 }
480
481 pub fn as_f32(&self) -> std::borrow::Cow<'_, Array3<f32>> {
484 match self {
485 ProtoTensor::Float(arr) => std::borrow::Cow::Borrowed(arr),
486 ProtoTensor::Quantized {
487 protos,
488 quantization,
489 } => {
490 let scale = quantization.scale;
491 let zp = quantization.zero_point as f32;
492 std::borrow::Cow::Owned(protos.map(|&v| (v as f32 - zp) * scale))
493 }
494 }
495 }
496}
497
/// Prototype tensor together with the mask coefficients that accompany it.
#[derive(Debug, Clone)]
pub struct ProtoData {
    // NOTE(review): presumably one coefficient vector per detection, to be
    // combined with `protos` into masks — confirm against the producer.
    pub mask_coefficients: Vec<Vec<f32>>,
    /// The prototype tensor, quantized or float.
    pub protos: ProtoTensor,
}
510
511pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
529 detect: &DetectBoxQuantized<SCORE>,
530 quant_scores: Quantization,
531) -> DetectBox {
532 let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
533 DetectBox {
534 bbox: detect.bbox,
535 score: quant_scores.scale * detect.score.as_() + scaled_zp,
536 label: detect.label,
537 }
538}
/// A detection whose score is still in quantized integer form.
///
/// Convert it to a [`DetectBox`] with [`dequant_detect_box`], which copies the
/// box and label unchanged and dequantizes only the score.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DetectBoxQuantized<
    // Integer type of the raw (quantized) score.
    SCORE: PrimInt + AsPrimitive<f32>,
> {
    /// Bounding box of the detection, already in float form.
    pub bbox: BoundingBox,
    /// Raw quantized score.
    pub score: SCORE,
    /// Label index, copied unchanged by [`dequant_detect_box`].
    pub label: usize,
}
553
554pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
567 input: ArrayView<T, D>,
568 quant: Quantization,
569) -> Array<F, D>
570where
571 i32: num_traits::AsPrimitive<F>,
572 f32: num_traits::AsPrimitive<F>,
573{
574 let zero_point = quant.zero_point.as_();
575 let scale = quant.scale.as_();
576 if zero_point != F::zero() {
577 let scaled_zero = -zero_point * scale;
578 input.mapv(|d| d.as_() * scale + scaled_zero)
579 } else {
580 input.mapv(|d| d.as_() * scale)
581 }
582}
583
584pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
597 input: &[T],
598 quant: Quantization,
599 output: &mut [F],
600) where
601 f32: num_traits::AsPrimitive<F>,
602 i32: num_traits::AsPrimitive<F>,
603{
604 assert!(input.len() == output.len());
605 let zero_point = quant.zero_point.as_();
606 let scale = quant.scale.as_();
607 if zero_point != F::zero() {
608 let scaled_zero = -zero_point * scale; input
610 .iter()
611 .zip(output)
612 .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
613 } else {
614 input
615 .iter()
616 .zip(output)
617 .for_each(|(d, deq)| *deq = d.as_() * scale);
618 }
619}
620
621pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
635 input: &[T],
636 quant: Quantization,
637 output: &mut [F],
638) where
639 f32: num_traits::AsPrimitive<F>,
640 i32: num_traits::AsPrimitive<F>,
641{
642 assert!(input.len() == output.len());
643 let zero_point = quant.zero_point.as_();
644 let scale = quant.scale.as_();
645
646 let input = input.as_chunks::<4>();
647 let output = output.as_chunks_mut::<4>();
648
649 if zero_point != F::zero() {
650 let scaled_zero = -zero_point * scale; input
653 .0
654 .iter()
655 .zip(output.0)
656 .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
657 input
658 .1
659 .iter()
660 .zip(output.1)
661 .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
662 } else {
663 input
664 .0
665 .iter()
666 .zip(output.0)
667 .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
668 input
669 .1
670 .iter()
671 .zip(output.1)
672 .for_each(|(d, deq)| *deq = d.as_() * scale);
673 }
674}
675
676pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
694 if segmentation.shape()[2] == 0 {
695 return Err(DecoderError::InvalidShape(
696 "Segmentation tensor must have non-zero depth".to_string(),
697 ));
698 }
699 if segmentation.shape()[2] == 1 {
700 yolo_segmentation_to_mask(segmentation, 128)
701 } else {
702 Ok(modelpack_segmentation_to_mask(segmentation))
703 }
704}
705
706fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
708 score
709 .iter()
710 .enumerate()
711 .fold((score[0], 0), |(max, arg_max), (ind, s)| {
712 if max > *s {
713 (max, arg_max)
714 } else {
715 (*s, ind)
716 }
717 })
718}
719#[cfg(test)]
720#[cfg_attr(coverage_nightly, coverage(off))]
721mod decoder_tests {
722 #![allow(clippy::excessive_precision)]
723 use crate::{
724 configs::{DecoderType, DimName, Protos},
725 modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
726 yolo::{
727 decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
728 decode_yolo_segdet_quant,
729 },
730 *,
731 };
732 use ndarray::{array, s, Array4};
733 use ndarray_stats::DeviationExt;
734
735 fn compare_outputs(
736 boxes: (&[DetectBox], &[DetectBox]),
737 masks: (&[Segmentation], &[Segmentation]),
738 ) {
739 let (boxes0, boxes1) = boxes;
740 let (masks0, masks1) = masks;
741
742 assert_eq!(boxes0.len(), boxes1.len());
743 assert_eq!(masks0.len(), masks1.len());
744
745 for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
746 assert!(
747 b_i8.equal_within_delta(b_f32, 1e-6),
748 "{b_i8:?} is not equal to {b_f32:?}"
749 );
750 }
751
752 for (m_i8, m_f32) in masks0.iter().zip(masks1) {
753 assert_eq!(
754 [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
755 [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
756 );
757 assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
758 let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
759 let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
760 let diff = &mask_i8 - &mask_f32;
761 for x in 0..diff.shape()[0] {
762 for y in 0..diff.shape()[1] {
763 for z in 0..diff.shape()[2] {
764 let val = diff[[x, y, z]];
765 assert!(
766 val.abs() <= 1,
767 "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
768 x,
769 y,
770 z,
771 val
772 );
773 }
774 }
775 }
776 let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
777 assert!(
778 mean_sq_err < 1e-2,
779 "Mean Square Error between masks was greater than 1%: {:.2}%",
780 mean_sq_err * 100.0
781 );
782 }
783 }
784
785 #[test]
786 fn test_decoder_modelpack() {
787 let score_threshold = 0.45;
788 let iou_threshold = 0.45;
789 let boxes = include_bytes!("../../../testdata/modelpack_boxes_1935x1x4.bin");
790 let boxes = ndarray::Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
791
792 let scores = include_bytes!("../../../testdata/modelpack_scores_1935x1.bin");
793 let scores = ndarray::Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
794
795 let quant_boxes = (0.004656755365431309, 21).into();
796 let quant_scores = (0.0019603664986789227, 0).into();
797
798 let decoder = DecoderBuilder::default()
799 .with_config_modelpack_det(
800 configs::Boxes {
801 decoder: DecoderType::ModelPack,
802 quantization: Some(quant_boxes),
803 shape: vec![1, 1935, 1, 4],
804 dshape: vec![
805 (DimName::Batch, 1),
806 (DimName::NumBoxes, 1935),
807 (DimName::Padding, 1),
808 (DimName::BoxCoords, 4),
809 ],
810 normalized: Some(true),
811 },
812 configs::Scores {
813 decoder: DecoderType::ModelPack,
814 quantization: Some(quant_scores),
815 shape: vec![1, 1935, 1],
816 dshape: vec![
817 (DimName::Batch, 1),
818 (DimName::NumBoxes, 1935),
819 (DimName::NumClasses, 1),
820 ],
821 },
822 )
823 .with_score_threshold(score_threshold)
824 .with_iou_threshold(iou_threshold)
825 .build()
826 .unwrap();
827
828 let quant_boxes = quant_boxes.into();
829 let quant_scores = quant_scores.into();
830
831 let mut output_boxes: Vec<_> = Vec::with_capacity(50);
832 decode_modelpack_det(
833 (boxes.slice(s![0, .., 0, ..]), quant_boxes),
834 (scores.slice(s![0, .., ..]), quant_scores),
835 score_threshold,
836 iou_threshold,
837 &mut output_boxes,
838 );
839 assert!(output_boxes[0].equal_within_delta(
840 &DetectBox {
841 bbox: BoundingBox {
842 xmin: 0.40513772,
843 ymin: 0.6379755,
844 xmax: 0.5122431,
845 ymax: 0.7730214,
846 },
847 score: 0.4861709,
848 label: 0
849 },
850 1e-6
851 ));
852
853 let mut output_boxes1 = Vec::with_capacity(50);
854 let mut output_masks1 = Vec::with_capacity(50);
855
856 decoder
857 .decode_quantized(
858 &[boxes.view().into(), scores.view().into()],
859 &mut output_boxes1,
860 &mut output_masks1,
861 )
862 .unwrap();
863
864 let mut output_boxes_float = Vec::with_capacity(50);
865 let mut output_masks_float = Vec::with_capacity(50);
866
867 let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
868 let scores = dequantize_ndarray(scores.view(), quant_scores);
869
870 decoder
871 .decode_float::<f32>(
872 &[boxes.view().into_dyn(), scores.view().into_dyn()],
873 &mut output_boxes_float,
874 &mut output_masks_float,
875 )
876 .unwrap();
877
878 compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
879 compare_outputs(
880 (&output_boxes, &output_boxes_float),
881 (&[], &output_masks_float),
882 );
883 }
884
885 #[test]
886 fn test_decoder_modelpack_split_u8() {
887 let score_threshold = 0.45;
888 let iou_threshold = 0.45;
889 let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
890 let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
891
892 let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
893 let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
894
895 let quant0 = (0.08547406643629074, 174).into();
896 let quant1 = (0.09929127991199493, 183).into();
897 let anchors0 = vec![
898 [0.36666667461395264, 0.31481480598449707],
899 [0.38749998807907104, 0.4740740656852722],
900 [0.5333333611488342, 0.644444465637207],
901 ];
902 let anchors1 = vec![
903 [0.13750000298023224, 0.2074074000120163],
904 [0.2541666626930237, 0.21481481194496155],
905 [0.23125000298023224, 0.35185185074806213],
906 ];
907
908 let detect_config0 = configs::Detection {
909 decoder: DecoderType::ModelPack,
910 shape: vec![1, 9, 15, 18],
911 anchors: Some(anchors0.clone()),
912 quantization: Some(quant0),
913 dshape: vec![
914 (DimName::Batch, 1),
915 (DimName::Height, 9),
916 (DimName::Width, 15),
917 (DimName::NumAnchorsXFeatures, 18),
918 ],
919 normalized: Some(true),
920 };
921
922 let detect_config1 = configs::Detection {
923 decoder: DecoderType::ModelPack,
924 shape: vec![1, 17, 30, 18],
925 anchors: Some(anchors1.clone()),
926 quantization: Some(quant1),
927 dshape: vec![
928 (DimName::Batch, 1),
929 (DimName::Height, 17),
930 (DimName::Width, 30),
931 (DimName::NumAnchorsXFeatures, 18),
932 ],
933 normalized: Some(true),
934 };
935
936 let config0 = (&detect_config0).try_into().unwrap();
937 let config1 = (&detect_config1).try_into().unwrap();
938
939 let decoder = DecoderBuilder::default()
940 .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
941 .with_score_threshold(score_threshold)
942 .with_iou_threshold(iou_threshold)
943 .build()
944 .unwrap();
945
946 let quant0 = quant0.into();
947 let quant1 = quant1.into();
948
949 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
950 decode_modelpack_split_quant(
951 &[
952 detect0.slice(s![0, .., .., ..]),
953 detect1.slice(s![0, .., .., ..]),
954 ],
955 &[config0, config1],
956 score_threshold,
957 iou_threshold,
958 &mut output_boxes,
959 );
960 assert!(output_boxes[0].equal_within_delta(
961 &DetectBox {
962 bbox: BoundingBox {
963 xmin: 0.43171933,
964 ymin: 0.68243736,
965 xmax: 0.5626645,
966 ymax: 0.808863,
967 },
968 score: 0.99240804,
969 label: 0
970 },
971 1e-6
972 ));
973
974 let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
975 let mut output_masks1: Vec<_> = Vec::with_capacity(10);
976 decoder
977 .decode_quantized(
978 &[detect0.view().into(), detect1.view().into()],
979 &mut output_boxes1,
980 &mut output_masks1,
981 )
982 .unwrap();
983
984 let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
985 let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);
986
987 let detect0 = dequantize_ndarray(detect0.view(), quant0);
988 let detect1 = dequantize_ndarray(detect1.view(), quant1);
989 decoder
990 .decode_float::<f32>(
991 &[detect0.view().into_dyn(), detect1.view().into_dyn()],
992 &mut output_boxes1_f32,
993 &mut output_masks1_f32,
994 )
995 .unwrap();
996
997 compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
998 compare_outputs(
999 (&output_boxes, &output_boxes1_f32),
1000 (&[], &output_masks1_f32),
1001 );
1002 }
1003
1004 #[test]
1005 fn test_decoder_parse_config_modelpack_split_u8() {
1006 let score_threshold = 0.45;
1007 let iou_threshold = 0.45;
1008 let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
1009 let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1010
1011 let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
1012 let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1013
1014 let decoder = DecoderBuilder::default()
1015 .with_config_yaml_str(
1016 include_str!("../../../testdata/modelpack_split.yaml").to_string(),
1017 )
1018 .with_score_threshold(score_threshold)
1019 .with_iou_threshold(iou_threshold)
1020 .build()
1021 .unwrap();
1022
1023 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1024 let mut output_masks: Vec<_> = Vec::with_capacity(10);
1025 decoder
1026 .decode_quantized(
1027 &[
1028 ArrayViewDQuantized::from(detect1.view()),
1029 ArrayViewDQuantized::from(detect0.view()),
1030 ],
1031 &mut output_boxes,
1032 &mut output_masks,
1033 )
1034 .unwrap();
1035 assert!(output_boxes[0].equal_within_delta(
1036 &DetectBox {
1037 bbox: BoundingBox {
1038 xmin: 0.43171933,
1039 ymin: 0.68243736,
1040 xmax: 0.5626645,
1041 ymax: 0.808863,
1042 },
1043 score: 0.99240804,
1044 label: 0
1045 },
1046 1e-6
1047 ));
1048 }
1049
1050 #[test]
1051 fn test_modelpack_seg() {
1052 let out = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
1053 let out = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1054 let quant = (1.0 / 255.0, 0).into();
1055
1056 let decoder = DecoderBuilder::default()
1057 .with_config_modelpack_seg(configs::Segmentation {
1058 decoder: DecoderType::ModelPack,
1059 quantization: Some(quant),
1060 shape: vec![1, 2, 160, 160],
1061 dshape: vec![
1062 (DimName::Batch, 1),
1063 (DimName::NumClasses, 2),
1064 (DimName::Height, 160),
1065 (DimName::Width, 160),
1066 ],
1067 })
1068 .build()
1069 .unwrap();
1070 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1071 let mut output_masks: Vec<_> = Vec::with_capacity(10);
1072 decoder
1073 .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
1074 .unwrap();
1075
1076 let mut mask = out.slice(s![0, .., .., ..]);
1077 mask.swap_axes(0, 1);
1078 mask.swap_axes(1, 2);
1079 let mask = [Segmentation {
1080 xmin: 0.0,
1081 ymin: 0.0,
1082 xmax: 1.0,
1083 ymax: 1.0,
1084 segmentation: mask.into_owned(),
1085 }];
1086 compare_outputs((&[], &output_boxes), (&mask, &output_masks));
1087
1088 decoder
1089 .decode_float::<f32>(
1090 &[dequantize_ndarray(out.view(), quant.into())
1091 .view()
1092 .into_dyn()],
1093 &mut output_boxes,
1094 &mut output_masks,
1095 )
1096 .unwrap();
1097
1098 compare_outputs((&[], &output_boxes), (&[], &[]));
1104 let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1105 let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1106
1107 assert_eq!(mask0, mask1);
1108 }
1109 #[test]
1110 fn test_modelpack_seg_quant() {
1111 let out = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
1112 let out_u8 = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1113 let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
1114 let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
1115 let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
1116 let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
1117 let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);
1118
1119 let quant = (1.0 / 255.0, 0).into();
1120
1121 let decoder = DecoderBuilder::default()
1122 .with_config_modelpack_seg(configs::Segmentation {
1123 decoder: DecoderType::ModelPack,
1124 quantization: Some(quant),
1125 shape: vec![1, 2, 160, 160],
1126 dshape: vec![
1127 (DimName::Batch, 1),
1128 (DimName::NumClasses, 2),
1129 (DimName::Height, 160),
1130 (DimName::Width, 160),
1131 ],
1132 })
1133 .build()
1134 .unwrap();
1135 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1136 let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
1137 decoder
1138 .decode_quantized(
1139 &[out_u8.view().into()],
1140 &mut output_boxes,
1141 &mut output_masks_u8,
1142 )
1143 .unwrap();
1144
1145 let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
1146 decoder
1147 .decode_quantized(
1148 &[out_i8.view().into()],
1149 &mut output_boxes,
1150 &mut output_masks_i8,
1151 )
1152 .unwrap();
1153
1154 let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
1155 decoder
1156 .decode_quantized(
1157 &[out_u16.view().into()],
1158 &mut output_boxes,
1159 &mut output_masks_u16,
1160 )
1161 .unwrap();
1162
1163 let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
1164 decoder
1165 .decode_quantized(
1166 &[out_i16.view().into()],
1167 &mut output_boxes,
1168 &mut output_masks_i16,
1169 )
1170 .unwrap();
1171
1172 let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
1173 decoder
1174 .decode_quantized(
1175 &[out_u32.view().into()],
1176 &mut output_boxes,
1177 &mut output_masks_u32,
1178 )
1179 .unwrap();
1180
1181 let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
1182 decoder
1183 .decode_quantized(
1184 &[out_i32.view().into()],
1185 &mut output_boxes,
1186 &mut output_masks_i32,
1187 )
1188 .unwrap();
1189
1190 compare_outputs((&[], &output_boxes), (&[], &[]));
1191 let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
1192 let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
1193 let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
1194 let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
1195 let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
1196 let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();
1197 assert_eq!(mask_u8, mask_i8);
1198 assert_eq!(mask_u8, mask_u16);
1199 assert_eq!(mask_u8, mask_i16);
1200 assert_eq!(mask_u8, mask_u32);
1201 assert_eq!(mask_u8, mask_i32);
1202 }
1203
1204 #[test]
1205 fn test_modelpack_segdet() {
1206 let score_threshold = 0.45;
1207 let iou_threshold = 0.45;
1208
1209 let boxes = include_bytes!("../../../testdata/modelpack_boxes_1935x1x4.bin");
1210 let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
1211
1212 let scores = include_bytes!("../../../testdata/modelpack_scores_1935x1.bin");
1213 let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
1214
1215 let seg = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
1216 let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1217
1218 let quant_boxes = (0.004656755365431309, 21).into();
1219 let quant_scores = (0.0019603664986789227, 0).into();
1220 let quant_seg = (1.0 / 255.0, 0).into();
1221
1222 let decoder = DecoderBuilder::default()
1223 .with_config_modelpack_segdet(
1224 configs::Boxes {
1225 decoder: DecoderType::ModelPack,
1226 quantization: Some(quant_boxes),
1227 shape: vec![1, 1935, 1, 4],
1228 dshape: vec![
1229 (DimName::Batch, 1),
1230 (DimName::NumBoxes, 1935),
1231 (DimName::Padding, 1),
1232 (DimName::BoxCoords, 4),
1233 ],
1234 normalized: Some(true),
1235 },
1236 configs::Scores {
1237 decoder: DecoderType::ModelPack,
1238 quantization: Some(quant_scores),
1239 shape: vec![1, 1935, 1],
1240 dshape: vec![
1241 (DimName::Batch, 1),
1242 (DimName::NumBoxes, 1935),
1243 (DimName::NumClasses, 1),
1244 ],
1245 },
1246 configs::Segmentation {
1247 decoder: DecoderType::ModelPack,
1248 quantization: Some(quant_seg),
1249 shape: vec![1, 2, 160, 160],
1250 dshape: vec![
1251 (DimName::Batch, 1),
1252 (DimName::NumClasses, 2),
1253 (DimName::Height, 160),
1254 (DimName::Width, 160),
1255 ],
1256 },
1257 )
1258 .with_iou_threshold(iou_threshold)
1259 .with_score_threshold(score_threshold)
1260 .build()
1261 .unwrap();
1262 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1263 let mut output_masks: Vec<_> = Vec::with_capacity(10);
1264 decoder
1265 .decode_quantized(
1266 &[scores.view().into(), boxes.view().into(), seg.view().into()],
1267 &mut output_boxes,
1268 &mut output_masks,
1269 )
1270 .unwrap();
1271
1272 let mut mask = seg.slice(s![0, .., .., ..]);
1273 mask.swap_axes(0, 1);
1274 mask.swap_axes(1, 2);
1275 let mask = [Segmentation {
1276 xmin: 0.0,
1277 ymin: 0.0,
1278 xmax: 1.0,
1279 ymax: 1.0,
1280 segmentation: mask.into_owned(),
1281 }];
1282 let correct_boxes = [DetectBox {
1283 bbox: BoundingBox {
1284 xmin: 0.40513772,
1285 ymin: 0.6379755,
1286 xmax: 0.5122431,
1287 ymax: 0.7730214,
1288 },
1289 score: 0.4861709,
1290 label: 0,
1291 }];
1292 compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1293
1294 let scores = dequantize_ndarray(scores.view(), quant_scores.into());
1295 let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
1296 let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1297 decoder
1298 .decode_float::<f32>(
1299 &[
1300 scores.view().into_dyn(),
1301 boxes.view().into_dyn(),
1302 seg.view().into_dyn(),
1303 ],
1304 &mut output_boxes,
1305 &mut output_masks,
1306 )
1307 .unwrap();
1308
1309 compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1315 let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1316 let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1317
1318 assert_eq!(mask0, mask1);
1319 }
1320
1321 #[test]
1322 fn test_modelpack_segdet_split() {
1323 let score_threshold = 0.8;
1324 let iou_threshold = 0.5;
1325
1326 let seg = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
1327 let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1328
1329 let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
1330 let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1331
1332 let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
1333 let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1334
1335 let quant0 = (0.08547406643629074, 174).into();
1336 let quant1 = (0.09929127991199493, 183).into();
1337 let quant_seg = (1.0 / 255.0, 0).into();
1338
1339 let anchors0 = vec![
1340 [0.36666667461395264, 0.31481480598449707],
1341 [0.38749998807907104, 0.4740740656852722],
1342 [0.5333333611488342, 0.644444465637207],
1343 ];
1344 let anchors1 = vec![
1345 [0.13750000298023224, 0.2074074000120163],
1346 [0.2541666626930237, 0.21481481194496155],
1347 [0.23125000298023224, 0.35185185074806213],
1348 ];
1349
1350 let decoder = DecoderBuilder::default()
1351 .with_config_modelpack_segdet_split(
1352 vec![
1353 configs::Detection {
1354 decoder: DecoderType::ModelPack,
1355 shape: vec![1, 17, 30, 18],
1356 anchors: Some(anchors1),
1357 quantization: Some(quant1),
1358 dshape: vec![
1359 (DimName::Batch, 1),
1360 (DimName::Height, 17),
1361 (DimName::Width, 30),
1362 (DimName::NumAnchorsXFeatures, 18),
1363 ],
1364 normalized: Some(true),
1365 },
1366 configs::Detection {
1367 decoder: DecoderType::ModelPack,
1368 shape: vec![1, 9, 15, 18],
1369 anchors: Some(anchors0),
1370 quantization: Some(quant0),
1371 dshape: vec![
1372 (DimName::Batch, 1),
1373 (DimName::Height, 9),
1374 (DimName::Width, 15),
1375 (DimName::NumAnchorsXFeatures, 18),
1376 ],
1377 normalized: Some(true),
1378 },
1379 ],
1380 configs::Segmentation {
1381 decoder: DecoderType::ModelPack,
1382 quantization: Some(quant_seg),
1383 shape: vec![1, 2, 160, 160],
1384 dshape: vec![
1385 (DimName::Batch, 1),
1386 (DimName::NumClasses, 2),
1387 (DimName::Height, 160),
1388 (DimName::Width, 160),
1389 ],
1390 },
1391 )
1392 .with_score_threshold(score_threshold)
1393 .with_iou_threshold(iou_threshold)
1394 .build()
1395 .unwrap();
1396 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1397 let mut output_masks: Vec<_> = Vec::with_capacity(10);
1398 decoder
1399 .decode_quantized(
1400 &[
1401 detect0.view().into(),
1402 detect1.view().into(),
1403 seg.view().into(),
1404 ],
1405 &mut output_boxes,
1406 &mut output_masks,
1407 )
1408 .unwrap();
1409
1410 let mut mask = seg.slice(s![0, .., .., ..]);
1411 mask.swap_axes(0, 1);
1412 mask.swap_axes(1, 2);
1413 let mask = [Segmentation {
1414 xmin: 0.0,
1415 ymin: 0.0,
1416 xmax: 1.0,
1417 ymax: 1.0,
1418 segmentation: mask.into_owned(),
1419 }];
1420 let correct_boxes = [DetectBox {
1421 bbox: BoundingBox {
1422 xmin: 0.43171933,
1423 ymin: 0.68243736,
1424 xmax: 0.5626645,
1425 ymax: 0.808863,
1426 },
1427 score: 0.99240804,
1428 label: 0,
1429 }];
1430 println!("Output Boxes: {:?}", output_boxes);
1431 compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1432
1433 let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
1434 let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
1435 let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1436 decoder
1437 .decode_float::<f32>(
1438 &[
1439 detect0.view().into_dyn(),
1440 detect1.view().into_dyn(),
1441 seg.view().into_dyn(),
1442 ],
1443 &mut output_boxes,
1444 &mut output_masks,
1445 )
1446 .unwrap();
1447
1448 compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1454 let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1455 let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1456
1457 assert_eq!(mask0, mask1);
1458 }
1459
1460 #[test]
1461 fn test_dequant_chunked() {
1462 let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
1463 let mut out =
1464 unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) }.to_vec();
1465 out.push(123); let mut out_dequant = vec![0.0; 84 * 8400 + 1];
1468 let mut out_dequant_simd = vec![0.0; 84 * 8400 + 1];
1469 let quant = Quantization::new(0.0040811873, -123);
1470 dequantize_cpu(&out, quant, &mut out_dequant);
1471
1472 dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1473 assert_eq!(out_dequant, out_dequant_simd);
1474
1475 let quant = Quantization::new(0.0040811873, 0);
1476 dequantize_cpu(&out, quant, &mut out_dequant);
1477
1478 dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1479 assert_eq!(out_dequant, out_dequant_simd);
1480 }
1481
    /// Decodes a quantized YOLO detection output (1 x 84 x 8400, int8) three
    /// ways and checks that they agree:
    /// 1. the free function `decode_yolo_det` on the quantized tensor,
    /// 2. `Decoder::decode_quantized` built from an equivalent config,
    /// 3. `Decoder::decode_float` on the dequantized tensor.
    ///
    /// Path 1 is additionally pinned against two known reference boxes.
    #[test]
    fn test_decoder_yolo_det() {
        let score_threshold = 0.25;
        let iou_threshold = 0.7;
        let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
        // SAFETY: `u8` and `i8` have identical size and alignment, and the slice
        // borrows the `'static` byte array produced by `include_bytes!`.
        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
        let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
        // (scale, zero_point); the concrete type is inferred from its use in
        // `configs::Detection::quantization` below.
        let quant = (0.0040811873, -123).into();

        let decoder = DecoderBuilder::default()
            .with_config_yolo_det(
                configs::Detection {
                    decoder: DecoderType::Ultralytics,
                    shape: vec![1, 84, 8400],
                    anchors: None,
                    quantization: Some(quant),
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumFeatures, 84),
                        (DimName::NumBoxes, 8400),
                    ],
                    normalized: Some(true),
                },
                Some(DecoderVersion::Yolo11),
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        // Reference path: decode the batch-0 slice directly.
        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
        decode_yolo_det(
            (out.slice(s![0, .., ..]), quant.into()),
            score_threshold,
            iou_threshold,
            Some(configs::Nms::ClassAgnostic),
            &mut output_boxes,
        );
        // Expected detections, highest score first.
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.5285137,
                    ymin: 0.05305544,
                    xmax: 0.87541467,
                    ymax: 0.9998909,
                },
                score: 0.5591227,
                label: 0
            },
            1e-6
        ));

        assert!(output_boxes[1].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.130598,
                    ymin: 0.43260583,
                    xmax: 0.35098213,
                    ymax: 0.9958097,
                },
                score: 0.33057618,
                label: 75
            },
            1e-6
        ));

        // Config-driven quantized path on the full (batched) tensor.
        let mut output_boxes1: Vec<_> = Vec::with_capacity(50);
        let mut output_masks1: Vec<_> = Vec::with_capacity(50);
        decoder
            .decode_quantized(&[out.view().into()], &mut output_boxes1, &mut output_masks1)
            .unwrap();

        // Config-driven float path on the dequantized tensor.
        let out = dequantize_ndarray(out.view(), quant.into());
        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(50);
        let mut output_masks_f32: Vec<_> = Vec::with_capacity(50);
        decoder
            .decode_float::<f32>(
                &[out.view().into_dyn()],
                &mut output_boxes_f32,
                &mut output_masks_f32,
            )
            .unwrap();

        // A detection-only model must produce no masks on either path.
        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
        compare_outputs((&output_boxes, &output_boxes_f32), (&[], &output_masks_f32));
    }
1568
    /// Runs the float YOLO segmentation+detection decoder on dequantized test
    /// fixtures. Checks the number of detections, the relation between each box
    /// and its mask crop, the two expected boxes, and that the second mask
    /// matches the corresponding crop of a stored reference mask.
    #[test]
    fn test_decoder_masks() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let boxes = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
        // SAFETY: `u8` and `i8` have identical size and alignment, and the slice
        // borrows the `'static` byte array produced by `include_bytes!`.
        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
        let boxes = ndarray::Array2::from_shape_vec((116, 8400), boxes.to_vec()).unwrap();
        let quant_boxes = Quantization::new(0.021287761628627777, 31);

        let protos = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin");
        // SAFETY: same reinterpretation as above.
        let protos =
            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
        // Presumably (height, width, num_protos), matching the dshape used for
        // the same fixture in `test_decoder_masks_i8`.
        let protos = ndarray::Array3::from_shape_vec((160, 160, 32), protos.to_vec()).unwrap();
        let quant_protos = Quantization::new(0.02491161972284317, -117);
        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);
        decode_yolo_segdet_float(
            seg.view(),
            protos.view(),
            score_threshold,
            iou_threshold,
            Some(configs::Nms::ClassAgnostic),
            &mut output_boxes,
            &mut output_masks,
        );
        assert_eq!(output_boxes.len(), 2);
        // Exactly one mask per detected box.
        assert_eq!(output_boxes.len(), output_masks.len());

        // NOTE(review): all four checks use `>=`. The min corners assert that the
        // mask crop starts at or before the box. The max corners assert that the
        // crop ends at or before the box, which is consistent with the crop bounds
        // being truncated to the 160-pixel proto grid (see the `as usize` slicing
        // below). Confirm this is the intended direction.
        for (b, m) in output_boxes.iter().zip(&output_masks) {
            assert!(b.bbox.xmin >= m.xmin);
            assert!(b.bbox.ymin >= m.ymin);
            assert!(b.bbox.xmax >= m.xmax);
            assert!(b.bbox.ymax >= m.ymax);
        }
        // The tolerance is one cell of the 160x160 proto grid.
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.08515105,
                    ymin: 0.7131401,
                    xmax: 0.29802868,
                    ymax: 0.8195788,
                },
                score: 0.91537374,
                label: 23
            },
            1.0 / 160.0,
        ));

        assert!(output_boxes[1].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.59605736,
                    ymin: 0.25545314,
                    xmax: 0.93666154,
                    ymax: 0.72378385,
                },
                score: 0.91537374,
                label: 23
            },
            1.0 / 160.0,
        ));

        // Reference full-frame 160x160 mask; crop it to the second mask's
        // normalized bounds (scaled to grid cells) and compare.
        let full_mask = include_bytes!("../../../testdata/yolov8_mask_results.bin");
        let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();

        let cropped_mask = full_mask.slice(ndarray::s![
            (output_masks[1].ymin * 160.0) as usize..(output_masks[1].ymax * 160.0) as usize,
            (output_masks[1].xmin * 160.0) as usize..(output_masks[1].xmax * 160.0) as usize,
        ]);

        assert_eq!(
            cropped_mask,
            segmentation_to_mask(output_masks[1].segmentation.view()).unwrap()
        );
    }
1646
    /// Cross-checks four ways of decoding the same int8 YOLO segmentation model
    /// outputs (boxes 1x116x8400, protos 1x160x160x32):
    /// - `decode_yolo_segdet_quant` on the quantized slices (reference),
    /// - `Decoder::decode_quantized` from an equivalent config,
    /// - `decode_yolo_segdet_float` on the dequantized slices,
    /// - `Decoder::decode_float` on the dequantized tensors.
    #[test]
    fn test_decoder_masks_i8() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let boxes = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
        // SAFETY: `u8` and `i8` have identical size and alignment, and the slice
        // borrows the `'static` byte array produced by `include_bytes!`.
        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
        // (scale, zero_point); the type is inferred from the config field below.
        let quant_boxes = (0.021287761628627777, 31).into();

        let protos = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin");
        // SAFETY: same reinterpretation as above.
        let protos =
            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
        let quant_protos = (0.02491161972284317, -117).into();
        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
        let mut output_masks: Vec<_> = Vec::with_capacity(500);

        let decoder = DecoderBuilder::default()
            .with_config_yolo_segdet(
                configs::Detection {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: Some(quant_boxes),
                    shape: vec![1, 116, 8400],
                    anchors: None,
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumFeatures, 116),
                        (DimName::NumBoxes, 8400),
                    ],
                    normalized: Some(true),
                },
                Protos {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: Some(quant_protos),
                    shape: vec![1, 160, 160, 32],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::Height, 160),
                        (DimName::Width, 160),
                        (DimName::NumProtos, 32),
                    ],
                },
                Some(DecoderVersion::Yolo11),
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        // Re-bind the config tuples as the quantization type the free
        // functions and `dequantize_ndarray` take.
        let quant_boxes = quant_boxes.into();
        let quant_protos = quant_protos.into();

        decode_yolo_segdet_quant(
            (boxes.slice(s![0, .., ..]), quant_boxes),
            (protos.slice(s![0, .., .., ..]), quant_protos),
            score_threshold,
            iou_threshold,
            Some(configs::Nms::ClassAgnostic),
            &mut output_boxes,
            &mut output_masks,
        );

        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
        let mut output_masks1: Vec<_> = Vec::with_capacity(500);

        decoder
            .decode_quantized(
                &[boxes.view().into(), protos.view().into()],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();

        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);

        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
        decode_yolo_segdet_float(
            seg.slice(s![0, .., ..]),
            protos.slice(s![0, .., .., ..]),
            score_threshold,
            iou_threshold,
            Some(configs::Nms::ClassAgnostic),
            &mut output_boxes_f32,
            &mut output_masks_f32,
        );

        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);

        decoder
            .decode_float(
                &[seg.view().into_dyn(), protos.view().into_dyn()],
                &mut output_boxes1_f32,
                &mut output_masks1_f32,
            )
            .unwrap();

        // quant reference vs config-driven quant
        compare_outputs(
            (&output_boxes, &output_boxes1),
            (&output_masks, &output_masks1),
        );

        // quant reference vs float reference
        compare_outputs(
            (&output_boxes, &output_boxes_f32),
            (&output_masks, &output_masks_f32),
        );

        // float reference vs config-driven float
        compare_outputs(
            (&output_boxes_f32, &output_boxes1_f32),
            (&output_masks_f32, &output_masks1_f32),
        );
    }
1761
    /// Split-output YOLO detection: boxes and class scores arrive as separate
    /// `i16` tensors. Compares `Decoder::decode_quantized` and
    /// `Decoder::decode_float` against the single-tensor float reference
    /// `decode_yolo_det_float`.
    #[test]
    fn test_decoder_yolo_split() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let boxes = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
        // SAFETY: `u8` and `i8` have identical size and alignment, and the slice
        // borrows the `'static` byte array produced by `include_bytes!`.
        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
        // Widen the int8 fixture to int16 by multiplying by 256. Dividing the
        // scale by 256 and multiplying the zero point by 256 (below) yields the
        // same dequantized values as the original int8 data.
        let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();

        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);

        let decoder = DecoderBuilder::default()
            .with_config_yolo_split_det(
                configs::Boxes {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                    shape: vec![1, 4, 8400],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::BoxCoords, 4),
                        (DimName::NumBoxes, 8400),
                    ],
                    normalized: Some(true),
                },
                configs::Scores {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                    shape: vec![1, 80, 8400],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumClasses, 80),
                        (DimName::NumBoxes, 8400),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
        let mut output_masks: Vec<_> = Vec::with_capacity(500);

        // Feature rows 0..4 are box coordinates and rows 4..84 are class
        // scores; the remaining rows of the 116-row fixture are ignored here.
        decoder
            .decode_quantized(
                &[
                    boxes.slice(s![.., ..4, ..]).into(),
                    boxes.slice(s![.., 4..84, ..]).into(),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Float reference on the combined first 84 rows of batch 0.
        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
        decode_yolo_det_float(
            seg.slice(s![0, ..84, ..]),
            score_threshold,
            iou_threshold,
            Some(configs::Nms::ClassAgnostic),
            &mut output_boxes_f32,
        );

        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
        let mut output_masks1: Vec<_> = Vec::with_capacity(500);

        decoder
            .decode_float(
                &[
                    seg.slice(s![.., ..4, ..]).into_dyn(),
                    seg.slice(s![.., 4..84, ..]).into_dyn(),
                ],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();
        // Detection-only configs must not emit masks.
        compare_outputs((&output_boxes, &output_boxes_f32), (&output_masks, &[]));
        compare_outputs((&output_boxes_f32, &output_boxes1), (&[], &output_masks1));
    }
1842
1843 #[test]
1844 fn test_decoder_masks_config_mixed() {
1845 let score_threshold = 0.45;
1846 let iou_threshold = 0.45;
1847 let boxes = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
1848 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1849 let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
1850 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
1851
1852 let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
1853
1854 let protos = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin");
1855 let protos =
1856 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1857 let protos: Vec<_> = protos.to_vec();
1858 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
1859 let quant_protos = Quantization::new(0.02491161972284317, -117);
1860
1861 let decoder = DecoderBuilder::default()
1862 .with_config_yolo_split_segdet(
1863 configs::Boxes {
1864 decoder: configs::DecoderType::Ultralytics,
1865 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1866 shape: vec![1, 4, 8400],
1867 dshape: vec![
1868 (DimName::Batch, 1),
1869 (DimName::BoxCoords, 4),
1870 (DimName::NumBoxes, 8400),
1871 ],
1872 normalized: Some(true),
1873 },
1874 configs::Scores {
1875 decoder: configs::DecoderType::Ultralytics,
1876 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1877 shape: vec![1, 80, 8400],
1878 dshape: vec![
1879 (DimName::Batch, 1),
1880 (DimName::NumClasses, 80),
1881 (DimName::NumBoxes, 8400),
1882 ],
1883 },
1884 configs::MaskCoefficients {
1885 decoder: configs::DecoderType::Ultralytics,
1886 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1887 shape: vec![1, 32, 8400],
1888 dshape: vec![
1889 (DimName::Batch, 1),
1890 (DimName::NumProtos, 32),
1891 (DimName::NumBoxes, 8400),
1892 ],
1893 },
1894 configs::Protos {
1895 decoder: configs::DecoderType::Ultralytics,
1896 quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
1897 shape: vec![1, 160, 160, 32],
1898 dshape: vec![
1899 (DimName::Batch, 1),
1900 (DimName::Height, 160),
1901 (DimName::Width, 160),
1902 (DimName::NumProtos, 32),
1903 ],
1904 },
1905 )
1906 .with_score_threshold(score_threshold)
1907 .with_iou_threshold(iou_threshold)
1908 .build()
1909 .unwrap();
1910
1911 let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1912 let mut output_masks: Vec<_> = Vec::with_capacity(500);
1913
1914 decoder
1915 .decode_quantized(
1916 &[
1917 boxes.slice(s![.., ..4, ..]).into(),
1918 boxes.slice(s![.., 4..84, ..]).into(),
1919 boxes.slice(s![.., 84.., ..]).into(),
1920 protos.view().into(),
1921 ],
1922 &mut output_boxes,
1923 &mut output_masks,
1924 )
1925 .unwrap();
1926
1927 let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1928 let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1929 let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
1930 let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
1931 decode_yolo_segdet_float(
1932 seg.slice(s![0, .., ..]),
1933 protos.slice(s![0, .., .., ..]),
1934 score_threshold,
1935 iou_threshold,
1936 Some(configs::Nms::ClassAgnostic),
1937 &mut output_boxes_f32,
1938 &mut output_masks_f32,
1939 );
1940
1941 let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
1942 let mut output_masks1: Vec<_> = Vec::with_capacity(500);
1943
1944 decoder
1945 .decode_float(
1946 &[
1947 seg.slice(s![.., ..4, ..]).into_dyn(),
1948 seg.slice(s![.., 4..84, ..]).into_dyn(),
1949 seg.slice(s![.., 84.., ..]).into_dyn(),
1950 protos.view().into_dyn(),
1951 ],
1952 &mut output_boxes1,
1953 &mut output_masks1,
1954 )
1955 .unwrap();
1956 compare_outputs(
1957 (&output_boxes, &output_boxes_f32),
1958 (&output_masks, &output_masks_f32),
1959 );
1960 compare_outputs(
1961 (&output_boxes_f32, &output_boxes1),
1962 (&output_masks_f32, &output_masks1),
1963 );
1964 }
1965
1966 #[test]
1967 fn test_decoder_masks_config_i32() {
1968 let score_threshold = 0.45;
1969 let iou_threshold = 0.45;
1970 let boxes = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
1971 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1972 let scale = 1 << 23;
1973 let boxes: Vec<_> = boxes.iter().map(|x| *x as i32 * scale).collect();
1974 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
1975
1976 let quant_boxes = Quantization::new(0.021287761628627777 / scale as f32, 31 * scale);
1977
1978 let protos = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin");
1979 let protos =
1980 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1981 let protos: Vec<_> = protos.iter().map(|x| *x as i32 * scale).collect();
1982 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
1983 let quant_protos = Quantization::new(0.02491161972284317 / scale as f32, -117 * scale);
1984
1985 let decoder = DecoderBuilder::default()
1986 .with_config_yolo_split_segdet(
1987 configs::Boxes {
1988 decoder: configs::DecoderType::Ultralytics,
1989 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1990 shape: vec![1, 4, 8400],
1991 dshape: vec![
1992 (DimName::Batch, 1),
1993 (DimName::BoxCoords, 4),
1994 (DimName::NumBoxes, 8400),
1995 ],
1996 normalized: Some(true),
1997 },
1998 configs::Scores {
1999 decoder: configs::DecoderType::Ultralytics,
2000 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2001 shape: vec![1, 80, 8400],
2002 dshape: vec![
2003 (DimName::Batch, 1),
2004 (DimName::NumClasses, 80),
2005 (DimName::NumBoxes, 8400),
2006 ],
2007 },
2008 configs::MaskCoefficients {
2009 decoder: configs::DecoderType::Ultralytics,
2010 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2011 shape: vec![1, 32, 8400],
2012 dshape: vec![
2013 (DimName::Batch, 1),
2014 (DimName::NumProtos, 32),
2015 (DimName::NumBoxes, 8400),
2016 ],
2017 },
2018 configs::Protos {
2019 decoder: configs::DecoderType::Ultralytics,
2020 quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
2021 shape: vec![1, 160, 160, 32],
2022 dshape: vec![
2023 (DimName::Batch, 1),
2024 (DimName::Height, 160),
2025 (DimName::Width, 160),
2026 (DimName::NumProtos, 32),
2027 ],
2028 },
2029 )
2030 .with_score_threshold(score_threshold)
2031 .with_iou_threshold(iou_threshold)
2032 .build()
2033 .unwrap();
2034
2035 let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2036 let mut output_masks: Vec<_> = Vec::with_capacity(500);
2037
2038 decoder
2039 .decode_quantized(
2040 &[
2041 boxes.slice(s![.., ..4, ..]).into(),
2042 boxes.slice(s![.., 4..84, ..]).into(),
2043 boxes.slice(s![.., 84.., ..]).into(),
2044 protos.view().into(),
2045 ],
2046 &mut output_boxes,
2047 &mut output_masks,
2048 )
2049 .unwrap();
2050
2051 let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
2052 let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2053 let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2054 let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
2055 decode_yolo_segdet_float(
2056 seg.slice(s![0, .., ..]),
2057 protos.slice(s![0, .., .., ..]),
2058 score_threshold,
2059 iou_threshold,
2060 Some(configs::Nms::ClassAgnostic),
2061 &mut output_boxes_f32,
2062 &mut output_masks_f32,
2063 );
2064
2065 assert_eq!(output_boxes.len(), output_boxes_f32.len());
2066 assert_eq!(output_masks.len(), output_masks_f32.len());
2067
2068 compare_outputs(
2069 (&output_boxes, &output_boxes_f32),
2070 (&output_masks, &output_masks_f32),
2071 );
2072 }
2073
    /// Runs two independent decoders concurrently: a YOLO detector and a
    /// ModelPack split segdet decoder. There are eight threads, alternating
    /// between the two, and each repeats its decode 100 times and checks the
    /// expected outputs every time.
    ///
    /// Presumably this guards against decoders sharing hidden mutable state that
    /// would cause cross-talk between threads. NOTE(review): confirm intent.
    #[test]
    fn test_context_switch() {
        // The closure captures nothing, so it is `Copy` and can be passed to
        // `std::thread::spawn` several times below.
        let yolo_det = || {
            let score_threshold = 0.25;
            let iou_threshold = 0.7;
            let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
            // SAFETY: `u8` and `i8` have identical size and alignment, and the
            // slice borrows the `'static` byte array from `include_bytes!`.
            let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
            let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
            let quant = (0.0040811873, -123).into();

            let decoder = DecoderBuilder::default()
                .with_config_yolo_det(
                    configs::Detection {
                        decoder: DecoderType::Ultralytics,
                        shape: vec![1, 84, 8400],
                        anchors: None,
                        quantization: Some(quant),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::NumFeatures, 84),
                            (DimName::NumBoxes, 8400),
                        ],
                        normalized: None,
                    },
                    None,
                )
                .with_score_threshold(score_threshold)
                .with_iou_threshold(iou_threshold)
                .build()
                .unwrap();

            let mut output_boxes: Vec<_> = Vec::with_capacity(50);
            let mut output_masks: Vec<_> = Vec::with_capacity(50);

            // The output buffers are reused across iterations; the checks below
            // still expect the same two boxes and no masks each time.
            for _ in 0..100 {
                decoder
                    .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
                    .unwrap();

                assert!(output_boxes[0].equal_within_delta(
                    &DetectBox {
                        bbox: BoundingBox {
                            xmin: 0.5285137,
                            ymin: 0.05305544,
                            xmax: 0.87541467,
                            ymax: 0.9998909,
                        },
                        score: 0.5591227,
                        label: 0
                    },
                    1e-6
                ));

                assert!(output_boxes[1].equal_within_delta(
                    &DetectBox {
                        bbox: BoundingBox {
                            xmin: 0.130598,
                            ymin: 0.43260583,
                            xmax: 0.35098213,
                            ymax: 0.9958097,
                        },
                        score: 0.33057618,
                        label: 75
                    },
                    1e-6
                ));
                assert!(output_masks.is_empty());
            }
        };

        // Also capture-free and therefore `Copy`.
        let modelpack_det_split = || {
            let score_threshold = 0.8;
            let iou_threshold = 0.5;

            let seg = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
            let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

            let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
            let detect0 =
                ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

            let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
            let detect1 =
                ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

            // Expected mask: the whole batch-0 segmentation map, reordered from
            // (classes, height, width) to (height, width, classes), covering the
            // full normalized frame.
            let mut mask = seg.slice(s![0, .., .., ..]);
            mask.swap_axes(0, 1);
            mask.swap_axes(1, 2);
            let mask = [Segmentation {
                xmin: 0.0,
                ymin: 0.0,
                xmax: 1.0,
                ymax: 1.0,
                segmentation: mask.into_owned(),
            }];
            let correct_boxes = [DetectBox {
                bbox: BoundingBox {
                    xmin: 0.43171933,
                    ymin: 0.68243736,
                    xmax: 0.5626645,
                    ymax: 0.808863,
                },
                score: 0.99240804,
                label: 0,
            }];

            // (scale, zero_point) per output tensor.
            let quant0 = (0.08547406643629074, 174).into();
            let quant1 = (0.09929127991199493, 183).into();
            let quant_seg = (1.0 / 255.0, 0).into();

            let anchors0 = vec![
                [0.36666667461395264, 0.31481480598449707],
                [0.38749998807907104, 0.4740740656852722],
                [0.5333333611488342, 0.644444465637207],
            ];
            let anchors1 = vec![
                [0.13750000298023224, 0.2074074000120163],
                [0.2541666626930237, 0.21481481194496155],
                [0.23125000298023224, 0.35185185074806213],
            ];

            // The configs list the 17x30 head first, while the inputs below pass
            // detect0 (9x15) first; the two orders differ.
            let decoder = DecoderBuilder::default()
                .with_config_modelpack_segdet_split(
                    vec![
                        configs::Detection {
                            decoder: DecoderType::ModelPack,
                            shape: vec![1, 17, 30, 18],
                            anchors: Some(anchors1),
                            quantization: Some(quant1),
                            dshape: vec![
                                (DimName::Batch, 1),
                                (DimName::Height, 17),
                                (DimName::Width, 30),
                                (DimName::NumAnchorsXFeatures, 18),
                            ],
                            normalized: None,
                        },
                        configs::Detection {
                            decoder: DecoderType::ModelPack,
                            shape: vec![1, 9, 15, 18],
                            anchors: Some(anchors0),
                            quantization: Some(quant0),
                            dshape: vec![
                                (DimName::Batch, 1),
                                (DimName::Height, 9),
                                (DimName::Width, 15),
                                (DimName::NumAnchorsXFeatures, 18),
                            ],
                            normalized: None,
                        },
                    ],
                    configs::Segmentation {
                        decoder: DecoderType::ModelPack,
                        quantization: Some(quant_seg),
                        shape: vec![1, 2, 160, 160],
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::NumClasses, 2),
                            (DimName::Height, 160),
                            (DimName::Width, 160),
                        ],
                    },
                )
                .with_score_threshold(score_threshold)
                .with_iou_threshold(iou_threshold)
                .build()
                .unwrap();
            let mut output_boxes: Vec<_> = Vec::with_capacity(10);
            let mut output_masks: Vec<_> = Vec::with_capacity(10);

            for _ in 0..100 {
                decoder
                    .decode_quantized(
                        &[
                            detect0.view().into(),
                            detect1.view().into(),
                            seg.view().into(),
                        ],
                        &mut output_boxes,
                        &mut output_masks,
                    )
                    .unwrap();

                compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
            }
        };

        // Interleave the two workloads across eight threads.
        let handles = vec![
            std::thread::spawn(yolo_det),
            std::thread::spawn(modelpack_det_split),
            std::thread::spawn(yolo_det),
            std::thread::spawn(modelpack_det_split),
            std::thread::spawn(yolo_det),
            std::thread::spawn(modelpack_det_split),
            std::thread::spawn(yolo_det),
            std::thread::spawn(modelpack_det_split),
        ];
        // `join().unwrap()` re-raises any assertion panic from a worker thread.
        for handle in handles {
            handle.join().unwrap();
        }
    }
2276
2277 #[test]
2278 fn test_ndarray_to_xyxy_float() {
2279 let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2280 let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
2281 assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
2282
2283 let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2284 let xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(arr.view());
2285 assert_eq!(xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
2286 }
2287
2288 #[test]
2289 fn test_class_aware_nms_float() {
2290 use crate::float::nms_class_aware_float;
2291
2292 let boxes = vec![
2294 DetectBox {
2295 bbox: BoundingBox {
2296 xmin: 0.0,
2297 ymin: 0.0,
2298 xmax: 0.5,
2299 ymax: 0.5,
2300 },
2301 score: 0.9,
2302 label: 0, },
2304 DetectBox {
2305 bbox: BoundingBox {
2306 xmin: 0.1,
2307 ymin: 0.1,
2308 xmax: 0.6,
2309 ymax: 0.6,
2310 },
2311 score: 0.8,
2312 label: 1, },
2314 ];
2315
2316 let result = nms_class_aware_float(0.3, boxes.clone());
2319 assert_eq!(
2320 result.len(),
2321 2,
2322 "Class-aware NMS should keep both boxes with different classes"
2323 );
2324
2325 let same_class_boxes = vec![
2327 DetectBox {
2328 bbox: BoundingBox {
2329 xmin: 0.0,
2330 ymin: 0.0,
2331 xmax: 0.5,
2332 ymax: 0.5,
2333 },
2334 score: 0.9,
2335 label: 0,
2336 },
2337 DetectBox {
2338 bbox: BoundingBox {
2339 xmin: 0.1,
2340 ymin: 0.1,
2341 xmax: 0.6,
2342 ymax: 0.6,
2343 },
2344 score: 0.8,
2345 label: 0, },
2347 ];
2348
2349 let result = nms_class_aware_float(0.3, same_class_boxes);
2350 assert_eq!(
2351 result.len(),
2352 1,
2353 "Class-aware NMS should suppress overlapping box with same class"
2354 );
2355 assert_eq!(result[0].label, 0);
2356 assert!((result[0].score - 0.9).abs() < 1e-6);
2357 }
2358
2359 #[test]
2360 fn test_class_agnostic_vs_aware_nms() {
2361 use crate::float::{nms_class_aware_float, nms_float};
2362
2363 let boxes = vec![
2365 DetectBox {
2366 bbox: BoundingBox {
2367 xmin: 0.0,
2368 ymin: 0.0,
2369 xmax: 0.5,
2370 ymax: 0.5,
2371 },
2372 score: 0.9,
2373 label: 0,
2374 },
2375 DetectBox {
2376 bbox: BoundingBox {
2377 xmin: 0.1,
2378 ymin: 0.1,
2379 xmax: 0.6,
2380 ymax: 0.6,
2381 },
2382 score: 0.8,
2383 label: 1,
2384 },
2385 ];
2386
2387 let agnostic_result = nms_float(0.3, boxes.clone());
2389 assert_eq!(
2390 agnostic_result.len(),
2391 1,
2392 "Class-agnostic NMS should suppress overlapping boxes"
2393 );
2394
2395 let aware_result = nms_class_aware_float(0.3, boxes);
2397 assert_eq!(
2398 aware_result.len(),
2399 2,
2400 "Class-aware NMS should keep boxes with different classes"
2401 );
2402 }
2403
2404 #[test]
2405 fn test_class_aware_nms_int() {
2406 use crate::byte::nms_class_aware_int;
2407
2408 let boxes = vec![
2410 DetectBoxQuantized {
2411 bbox: BoundingBox {
2412 xmin: 0.0,
2413 ymin: 0.0,
2414 xmax: 0.5,
2415 ymax: 0.5,
2416 },
2417 score: 200_u8,
2418 label: 0,
2419 },
2420 DetectBoxQuantized {
2421 bbox: BoundingBox {
2422 xmin: 0.1,
2423 ymin: 0.1,
2424 xmax: 0.6,
2425 ymax: 0.6,
2426 },
2427 score: 180_u8,
2428 label: 1, },
2430 ];
2431
2432 let result = nms_class_aware_int(0.5, boxes);
2434 assert_eq!(
2435 result.len(),
2436 2,
2437 "Class-aware NMS (int) should keep boxes with different classes"
2438 );
2439 }
2440
2441 #[test]
2442 fn test_nms_enum_default() {
2443 let default_nms: configs::Nms = Default::default();
2445 assert_eq!(default_nms, configs::Nms::ClassAgnostic);
2446 }
2447
2448 #[test]
2449 fn test_decoder_nms_mode() {
2450 let decoder = DecoderBuilder::default()
2452 .with_config_yolo_det(
2453 configs::Detection {
2454 anchors: None,
2455 decoder: DecoderType::Ultralytics,
2456 quantization: None,
2457 shape: vec![1, 84, 8400],
2458 dshape: Vec::new(),
2459 normalized: Some(true),
2460 },
2461 None,
2462 )
2463 .with_nms(Some(configs::Nms::ClassAware))
2464 .build()
2465 .unwrap();
2466
2467 assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
2468 }
2469
2470 #[test]
2471 fn test_decoder_nms_bypass() {
2472 let decoder = DecoderBuilder::default()
2474 .with_config_yolo_det(
2475 configs::Detection {
2476 anchors: None,
2477 decoder: DecoderType::Ultralytics,
2478 quantization: None,
2479 shape: vec![1, 84, 8400],
2480 dshape: Vec::new(),
2481 normalized: Some(true),
2482 },
2483 None,
2484 )
2485 .with_nms(None)
2486 .build()
2487 .unwrap();
2488
2489 assert_eq!(decoder.nms, None);
2490 }
2491
2492 #[test]
2493 fn test_decoder_normalized_boxes_true() {
2494 let decoder = DecoderBuilder::default()
2496 .with_config_yolo_det(
2497 configs::Detection {
2498 anchors: None,
2499 decoder: DecoderType::Ultralytics,
2500 quantization: None,
2501 shape: vec![1, 84, 8400],
2502 dshape: Vec::new(),
2503 normalized: Some(true),
2504 },
2505 None,
2506 )
2507 .build()
2508 .unwrap();
2509
2510 assert_eq!(decoder.normalized_boxes(), Some(true));
2511 }
2512
2513 #[test]
2514 fn test_decoder_normalized_boxes_false() {
2515 let decoder = DecoderBuilder::default()
2518 .with_config_yolo_det(
2519 configs::Detection {
2520 anchors: None,
2521 decoder: DecoderType::Ultralytics,
2522 quantization: None,
2523 shape: vec![1, 84, 8400],
2524 dshape: Vec::new(),
2525 normalized: Some(false),
2526 },
2527 None,
2528 )
2529 .build()
2530 .unwrap();
2531
2532 assert_eq!(decoder.normalized_boxes(), Some(false));
2533 }
2534
2535 #[test]
2536 fn test_decoder_normalized_boxes_unknown() {
2537 let decoder = DecoderBuilder::default()
2539 .with_config_yolo_det(
2540 configs::Detection {
2541 anchors: None,
2542 decoder: DecoderType::Ultralytics,
2543 quantization: None,
2544 shape: vec![1, 84, 8400],
2545 dshape: Vec::new(),
2546 normalized: None,
2547 },
2548 Some(DecoderVersion::Yolo11),
2549 )
2550 .build()
2551 .unwrap();
2552
2553 assert_eq!(decoder.normalized_boxes(), None);
2554 }
2555}