1#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
71
72use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
73use num_traits::{AsPrimitive, Float, PrimInt};
74
75pub mod byte;
76pub mod error;
77pub mod float;
78pub mod modelpack;
79pub mod yolo;
80
81mod decoder;
82pub use decoder::*;
83
84pub use configs::{DecoderVersion, Nms};
85pub use error::{DecoderError, DecoderResult};
86
87use crate::{
88 decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
89 yolo::yolo_segmentation_to_mask,
90};
91
92pub trait BBoxTypeTrait {
94 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
96
97 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
99 input: &[B; 4],
100 quant: Quantization,
101 ) -> [A; 4]
102 where
103 f32: AsPrimitive<A>,
104 i32: AsPrimitive<A>;
105
106 #[inline(always)]
117 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
118 input: ArrayView1<B>,
119 ) -> [A; 4] {
120 Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
121 }
122
123 #[inline(always)]
124 fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
126 input: ArrayView1<B>,
127 quant: Quantization,
128 ) -> [A; 4]
129 where
130 f32: AsPrimitive<A>,
131 i32: AsPrimitive<A>,
132 {
133 Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
134 }
135}
136
137#[derive(Debug, Clone, Copy, PartialEq, Eq)]
139pub struct XYXY {}
140
141impl BBoxTypeTrait for XYXY {
142 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
143 input.map(|b| b.as_())
144 }
145
146 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
147 input: &[B; 4],
148 quant: Quantization,
149 ) -> [A; 4]
150 where
151 f32: AsPrimitive<A>,
152 i32: AsPrimitive<A>,
153 {
154 let scale = quant.scale.as_();
155 let zp = quant.zero_point.as_();
156 input.map(|b| (b.as_() - zp) * scale)
157 }
158
159 #[inline(always)]
160 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
161 input: ArrayView1<B>,
162 ) -> [A; 4] {
163 [
164 input[0].as_(),
165 input[1].as_(),
166 input[2].as_(),
167 input[3].as_(),
168 ]
169 }
170}
171
172#[derive(Debug, Clone, Copy, PartialEq, Eq)]
175pub struct XYWH {}
176
177impl BBoxTypeTrait for XYWH {
178 #[inline(always)]
179 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
180 let half = A::one() / (A::one() + A::one());
181 [
182 (input[0].as_()) - (input[2].as_() * half),
183 (input[1].as_()) - (input[3].as_() * half),
184 (input[0].as_()) + (input[2].as_() * half),
185 (input[1].as_()) + (input[3].as_() * half),
186 ]
187 }
188
189 #[inline(always)]
190 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
191 input: &[B; 4],
192 quant: Quantization,
193 ) -> [A; 4]
194 where
195 f32: AsPrimitive<A>,
196 i32: AsPrimitive<A>,
197 {
198 let scale = quant.scale.as_();
199 let half_scale = (quant.scale * 0.5).as_();
200 let zp = quant.zero_point.as_();
201 let [x, y, w, h] = [
202 (input[0].as_() - zp) * scale,
203 (input[1].as_() - zp) * scale,
204 (input[2].as_() - zp) * half_scale,
205 (input[3].as_() - zp) * half_scale,
206 ];
207
208 [x - w, y - h, x + w, y + h]
209 }
210
211 #[inline(always)]
212 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
213 input: ArrayView1<B>,
214 ) -> [A; 4] {
215 let half = A::one() / (A::one() + A::one());
216 [
217 (input[0].as_()) - (input[2].as_() * half),
218 (input[1].as_()) - (input[3].as_() * half),
219 (input[0].as_()) + (input[2].as_() * half),
220 (input[1].as_()) + (input[3].as_() * half),
221 ]
222 }
223}
224
/// Quantization parameters: a `scale` and a `zero_point`.
///
/// The dequantization helpers in this file compute
/// `(raw - zero_point) * scale`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quantization {
    /// Multiplier applied to the zero-point-corrected raw value.
    pub scale: f32,
    /// Offset subtracted from the raw value before scaling.
    pub zero_point: i32,
}

impl Quantization {
    /// Creates a `Quantization` from the given `scale` and `zero_point`.
    pub fn new(scale: f32, zero_point: i32) -> Self {
        Quantization { scale, zero_point }
    }
}
245
246impl From<QuantTuple> for Quantization {
247 fn from(quant_tuple: QuantTuple) -> Quantization {
258 Quantization {
259 scale: quant_tuple.0,
260 zero_point: quant_tuple.1,
261 }
262 }
263}
264
265impl<S, Z> From<(S, Z)> for Quantization
266where
267 S: AsPrimitive<f32>,
268 Z: AsPrimitive<i32>,
269{
270 fn from((scale, zp): (S, Z)) -> Quantization {
279 Self {
280 scale: scale.as_(),
281 zero_point: zp.as_(),
282 }
283 }
284}
285
286impl Default for Quantization {
287 fn default() -> Self {
296 Self {
297 scale: 1.0,
298 zero_point: 0,
299 }
300 }
301}
302
303#[derive(Debug, Clone, Copy, PartialEq, Default)]
305pub struct DetectBox {
306 pub bbox: BoundingBox,
307 pub score: f32,
309 pub label: usize,
311}
312
/// Axis-aligned box described by its minimum and maximum x/y coordinates.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct BoundingBox {
    /// Minimum x coordinate.
    pub xmin: f32,
    /// Minimum y coordinate.
    pub ymin: f32,
    /// Maximum x coordinate.
    pub xmax: f32,
    /// Maximum y coordinate.
    pub ymax: f32,
}

impl BoundingBox {
    /// Creates a box from the four coordinates; the values are stored as given.
    pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }

    /// Returns a copy in which `xmin <= xmax` and `ymin <= ymax`, swapping the
    /// coordinates of an axis when they are stored in reverse order.
    pub fn to_canonical(&self) -> Self {
        let (xmin, xmax) = (
            f32::min(self.xmin, self.xmax),
            f32::max(self.xmin, self.xmax),
        );
        let (ymin, ymax) = (
            f32::min(self.ymin, self.ymax),
            f32::max(self.ymin, self.ymax),
        );
        Self {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }
}

impl From<BoundingBox> for [f32; 4] {
    /// Flattens the box into `[xmin, ymin, xmax, ymax]`.
    fn from(b: BoundingBox) -> Self {
        let BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        } = b;
        [xmin, ymin, xmax, ymax]
    }
}

impl From<[f32; 4]> for BoundingBox {
    /// Interprets the array as `[xmin, ymin, xmax, ymax]`.
    fn from(arr: [f32; 4]) -> Self {
        let [xmin, ymin, xmax, ymax] = arr;
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }
}
391
392impl DetectBox {
393 pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
422 let eq_delta = |a: f32, b: f32| (a - b).abs() <= eps;
423 self.label == rhs.label
424 && eq_delta(self.score, rhs.score)
425 && eq_delta(self.bbox.xmin, rhs.bbox.xmin)
426 && eq_delta(self.bbox.ymin, rhs.bbox.ymin)
427 && eq_delta(self.bbox.xmax, rhs.bbox.xmax)
428 && eq_delta(self.bbox.ymax, rhs.bbox.ymax)
429 }
430}
431
432#[derive(Debug, Clone, PartialEq, Default)]
435pub struct Segmentation {
436 pub xmin: f32,
438 pub ymin: f32,
440 pub xmax: f32,
442 pub ymax: f32,
444 pub segmentation: Array3<u8>,
452}
453
454#[derive(Debug, Clone)]
460pub enum ProtoTensor {
461 Quantized {
465 protos: Array3<i8>,
466 quantization: Quantization,
467 },
468 Float(Array3<f32>),
470}
471
472impl ProtoTensor {
473 pub fn is_quantized(&self) -> bool {
475 matches!(self, ProtoTensor::Quantized { .. })
476 }
477
478 pub fn dim(&self) -> (usize, usize, usize) {
480 match self {
481 ProtoTensor::Quantized { protos, .. } => protos.dim(),
482 ProtoTensor::Float(arr) => arr.dim(),
483 }
484 }
485
486 pub fn as_f32(&self) -> std::borrow::Cow<'_, Array3<f32>> {
489 match self {
490 ProtoTensor::Float(arr) => std::borrow::Cow::Borrowed(arr),
491 ProtoTensor::Quantized {
492 protos,
493 quantization,
494 } => {
495 let scale = quantization.scale;
496 let zp = quantization.zero_point as f32;
497 std::borrow::Cow::Owned(protos.map(|&v| (v as f32 - zp) * scale))
498 }
499 }
500 }
501}
502
503#[derive(Debug, Clone)]
509pub struct ProtoData {
510 pub mask_coefficients: Vec<Vec<f32>>,
512 pub protos: ProtoTensor,
514}
515
516pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
534 detect: &DetectBoxQuantized<SCORE>,
535 quant_scores: Quantization,
536) -> DetectBox {
537 let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
538 DetectBox {
539 bbox: detect.bbox,
540 score: quant_scores.scale * detect.score.as_() + scaled_zp,
541 label: detect.label,
542 }
543}
544#[derive(Debug, Clone, Copy, PartialEq)]
546pub struct DetectBoxQuantized<
547 SCORE: PrimInt + AsPrimitive<f32>,
549> {
550 pub bbox: BoundingBox,
552 pub score: SCORE,
555 pub label: usize,
557}
558
559pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
572 input: ArrayView<T, D>,
573 quant: Quantization,
574) -> Array<F, D>
575where
576 i32: num_traits::AsPrimitive<F>,
577 f32: num_traits::AsPrimitive<F>,
578{
579 let zero_point = quant.zero_point.as_();
580 let scale = quant.scale.as_();
581 if zero_point != F::zero() {
582 let scaled_zero = -zero_point * scale;
583 input.mapv(|d| d.as_() * scale + scaled_zero)
584 } else {
585 input.mapv(|d| d.as_() * scale)
586 }
587}
588
589pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
602 input: &[T],
603 quant: Quantization,
604 output: &mut [F],
605) where
606 f32: num_traits::AsPrimitive<F>,
607 i32: num_traits::AsPrimitive<F>,
608{
609 assert!(input.len() == output.len());
610 let zero_point = quant.zero_point.as_();
611 let scale = quant.scale.as_();
612 if zero_point != F::zero() {
613 let scaled_zero = -zero_point * scale; input
615 .iter()
616 .zip(output)
617 .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
618 } else {
619 input
620 .iter()
621 .zip(output)
622 .for_each(|(d, deq)| *deq = d.as_() * scale);
623 }
624}
625
626pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
640 input: &[T],
641 quant: Quantization,
642 output: &mut [F],
643) where
644 f32: num_traits::AsPrimitive<F>,
645 i32: num_traits::AsPrimitive<F>,
646{
647 assert!(input.len() == output.len());
648 let zero_point = quant.zero_point.as_();
649 let scale = quant.scale.as_();
650
651 let input = input.as_chunks::<4>();
652 let output = output.as_chunks_mut::<4>();
653
654 if zero_point != F::zero() {
655 let scaled_zero = -zero_point * scale; input
658 .0
659 .iter()
660 .zip(output.0)
661 .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
662 input
663 .1
664 .iter()
665 .zip(output.1)
666 .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
667 } else {
668 input
669 .0
670 .iter()
671 .zip(output.0)
672 .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
673 input
674 .1
675 .iter()
676 .zip(output.1)
677 .for_each(|(d, deq)| *deq = d.as_() * scale);
678 }
679}
680
681pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
699 if segmentation.shape()[2] == 0 {
700 return Err(DecoderError::InvalidShape(
701 "Segmentation tensor must have non-zero depth".to_string(),
702 ));
703 }
704 if segmentation.shape()[2] == 1 {
705 yolo_segmentation_to_mask(segmentation, 128)
706 } else {
707 Ok(modelpack_segmentation_to_mask(segmentation))
708 }
709}
710
711fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
713 score
714 .iter()
715 .enumerate()
716 .fold((score[0], 0), |(max, arg_max), (ind, s)| {
717 if max > *s {
718 (max, arg_max)
719 } else {
720 (*s, ind)
721 }
722 })
723}
724#[cfg(test)]
725#[cfg_attr(coverage_nightly, coverage(off))]
726mod decoder_tests {
727 #![allow(clippy::excessive_precision)]
728 use crate::{
729 configs::{DecoderType, DimName, Protos},
730 modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
731 yolo::{
732 decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
733 decode_yolo_segdet_quant,
734 },
735 *,
736 };
737 use ndarray::{array, s, Array4};
738 use ndarray_stats::DeviationExt;
739
740 fn compare_outputs(
741 boxes: (&[DetectBox], &[DetectBox]),
742 masks: (&[Segmentation], &[Segmentation]),
743 ) {
744 let (boxes0, boxes1) = boxes;
745 let (masks0, masks1) = masks;
746
747 assert_eq!(boxes0.len(), boxes1.len());
748 assert_eq!(masks0.len(), masks1.len());
749
750 for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
751 assert!(
752 b_i8.equal_within_delta(b_f32, 1e-6),
753 "{b_i8:?} is not equal to {b_f32:?}"
754 );
755 }
756
757 for (m_i8, m_f32) in masks0.iter().zip(masks1) {
758 assert_eq!(
759 [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
760 [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
761 );
762 assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
763 let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
764 let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
765 let diff = &mask_i8 - &mask_f32;
766 for x in 0..diff.shape()[0] {
767 for y in 0..diff.shape()[1] {
768 for z in 0..diff.shape()[2] {
769 let val = diff[[x, y, z]];
770 assert!(
771 val.abs() <= 1,
772 "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
773 x,
774 y,
775 z,
776 val
777 );
778 }
779 }
780 }
781 let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
782 assert!(
783 mean_sq_err < 1e-2,
784 "Mean Square Error between masks was greater than 1%: {:.2}%",
785 mean_sq_err * 100.0
786 );
787 }
788 }
789
/// Decodes the ModelPack box/score fixtures three ways — direct call to
/// `decode_modelpack_det`, the built decoder on quantized input, and the
/// built decoder on dequantized input — and checks that all results agree.
#[test]
fn test_decoder_modelpack() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let box_bytes = include_bytes!("../../../testdata/modelpack_boxes_1935x1x4.bin");
    let boxes = Array4::from_shape_vec((1, 1935, 1, 4), box_bytes.to_vec()).unwrap();

    let score_bytes = include_bytes!("../../../testdata/modelpack_scores_1935x1.bin");
    let scores = Array3::from_shape_vec((1, 1935, 1), score_bytes.to_vec()).unwrap();

    let quant_boxes_cfg = (0.004656755365431309, 21).into();
    let quant_scores_cfg = (0.0019603664986789227, 0).into();

    let boxes_config = configs::Boxes {
        decoder: DecoderType::ModelPack,
        quantization: Some(quant_boxes_cfg),
        shape: vec![1, 1935, 1, 4],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumBoxes, 1935),
            (DimName::Padding, 1),
            (DimName::BoxCoords, 4),
        ],
        normalized: Some(true),
    };
    let scores_config = configs::Scores {
        decoder: DecoderType::ModelPack,
        quantization: Some(quant_scores_cfg),
        shape: vec![1, 1935, 1],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumBoxes, 1935),
            (DimName::NumClasses, 1),
        ],
    };

    let decoder = DecoderBuilder::default()
        .with_config_modelpack_det(boxes_config, scores_config)
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let quant_boxes = quant_boxes_cfg.into();
    let quant_scores = quant_scores_cfg.into();

    // Reference result from the standalone decode function.
    let mut output_boxes: Vec<_> = Vec::with_capacity(50);
    decode_modelpack_det(
        (boxes.slice(s![0, .., 0, ..]), quant_boxes),
        (scores.slice(s![0, .., ..]), quant_scores),
        score_threshold,
        iou_threshold,
        &mut output_boxes,
    );
    assert!(output_boxes[0].equal_within_delta(
        &DetectBox {
            bbox: BoundingBox {
                xmin: 0.40513772,
                ymin: 0.6379755,
                xmax: 0.5122431,
                ymax: 0.7730214,
            },
            score: 0.4861709,
            label: 0
        },
        1e-6
    ));

    // Same input through the configured decoder, still quantized.
    let mut quantized_boxes = Vec::with_capacity(50);
    let mut quantized_masks = Vec::with_capacity(50);
    decoder
        .decode_quantized(
            &[boxes.view().into(), scores.view().into()],
            &mut quantized_boxes,
            &mut quantized_masks,
        )
        .unwrap();

    // Same input through the configured decoder after dequantization.
    let mut float_boxes = Vec::with_capacity(50);
    let mut float_masks = Vec::with_capacity(50);

    let boxes_f32 = dequantize_ndarray(boxes.view(), quant_boxes);
    let scores_f32 = dequantize_ndarray(scores.view(), quant_scores);

    decoder
        .decode_float::<f32>(
            &[boxes_f32.view().into_dyn(), scores_f32.view().into_dyn()],
            &mut float_boxes,
            &mut float_masks,
        )
        .unwrap();

    compare_outputs((&output_boxes, &quantized_boxes), (&[], &quantized_masks));
    compare_outputs((&output_boxes, &float_boxes), (&[], &float_masks));
}
889
/// Decodes the two split ModelPack detection fixtures with
/// `decode_modelpack_split_quant` and with the built decoder (quantized and
/// dequantized input), then checks that all results agree.
#[test]
fn test_decoder_modelpack_split_u8() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let bytes0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
    let detect0 = Array4::from_shape_vec((1, 9, 15, 18), bytes0.to_vec()).unwrap();

    let bytes1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
    let detect1 = Array4::from_shape_vec((1, 17, 30, 18), bytes1.to_vec()).unwrap();

    let quant0_cfg = (0.08547406643629074, 174).into();
    let quant1_cfg = (0.09929127991199493, 183).into();
    let anchors0 = vec![
        [0.36666667461395264, 0.31481480598449707],
        [0.38749998807907104, 0.4740740656852722],
        [0.5333333611488342, 0.644444465637207],
    ];
    let anchors1 = vec![
        [0.13750000298023224, 0.2074074000120163],
        [0.2541666626930237, 0.21481481194496155],
        [0.23125000298023224, 0.35185185074806213],
    ];

    let detect_config0 = configs::Detection {
        decoder: DecoderType::ModelPack,
        shape: vec![1, 9, 15, 18],
        anchors: Some(anchors0.clone()),
        quantization: Some(quant0_cfg),
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 9),
            (DimName::Width, 15),
            (DimName::NumAnchorsXFeatures, 18),
        ],
        normalized: Some(true),
    };

    let detect_config1 = configs::Detection {
        decoder: DecoderType::ModelPack,
        shape: vec![1, 17, 30, 18],
        anchors: Some(anchors1.clone()),
        quantization: Some(quant1_cfg),
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 17),
            (DimName::Width, 30),
            (DimName::NumAnchorsXFeatures, 18),
        ],
        normalized: Some(true),
    };

    let split_config0 = (&detect_config0).try_into().unwrap();
    let split_config1 = (&detect_config1).try_into().unwrap();

    let decoder = DecoderBuilder::default()
        .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let quant0 = quant0_cfg.into();
    let quant1 = quant1_cfg.into();

    // Reference result from the standalone decode function.
    let mut output_boxes: Vec<_> = Vec::with_capacity(10);
    decode_modelpack_split_quant(
        &[
            detect0.slice(s![0, .., .., ..]),
            detect1.slice(s![0, .., .., ..]),
        ],
        &[split_config0, split_config1],
        score_threshold,
        iou_threshold,
        &mut output_boxes,
    );
    assert!(output_boxes[0].equal_within_delta(
        &DetectBox {
            bbox: BoundingBox {
                xmin: 0.43171933,
                ymin: 0.68243736,
                xmax: 0.5626645,
                ymax: 0.808863,
            },
            score: 0.99240804,
            label: 0
        },
        1e-6
    ));

    // Configured decoder on the quantized tensors.
    let mut quantized_boxes: Vec<_> = Vec::with_capacity(10);
    let mut quantized_masks: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(
            &[detect0.view().into(), detect1.view().into()],
            &mut quantized_boxes,
            &mut quantized_masks,
        )
        .unwrap();

    // Configured decoder on the dequantized tensors.
    let mut float_boxes: Vec<_> = Vec::with_capacity(10);
    let mut float_masks: Vec<_> = Vec::with_capacity(10);

    let detect0_f32 = dequantize_ndarray(detect0.view(), quant0);
    let detect1_f32 = dequantize_ndarray(detect1.view(), quant1);
    decoder
        .decode_float::<f32>(
            &[detect0_f32.view().into_dyn(), detect1_f32.view().into_dyn()],
            &mut float_boxes,
            &mut float_masks,
        )
        .unwrap();

    compare_outputs((&output_boxes, &quantized_boxes), (&[], &quantized_masks));
    compare_outputs((&output_boxes, &float_boxes), (&[], &float_masks));
}
1008
/// Builds the decoder from the `modelpack_split.yaml` fixture and checks the
/// first decoded box for the two split detection tensors.
#[test]
fn test_decoder_parse_config_modelpack_split_u8() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let bytes0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
    let detect0 = Array4::from_shape_vec((1, 9, 15, 18), bytes0.to_vec()).unwrap();

    let bytes1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
    let detect1 = Array4::from_shape_vec((1, 17, 30, 18), bytes1.to_vec()).unwrap();

    let yaml_config = include_str!("../../../testdata/modelpack_split.yaml").to_string();
    let decoder = DecoderBuilder::default()
        .with_config_yaml_str(yaml_config)
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let mut output_boxes: Vec<_> = Vec::with_capacity(10);
    let mut output_masks: Vec<_> = Vec::with_capacity(10);
    let inputs = [
        ArrayViewDQuantized::from(detect1.view()),
        ArrayViewDQuantized::from(detect0.view()),
    ];
    decoder
        .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
        .unwrap();
    assert!(output_boxes[0].equal_within_delta(
        &DetectBox {
            bbox: BoundingBox {
                xmin: 0.43171933,
                ymin: 0.68243736,
                xmax: 0.5626645,
                ymax: 0.808863,
            },
            score: 0.99240804,
            label: 0
        },
        1e-6
    ));
}
1054
/// Runs the ModelPack segmentation fixture through the decoder in quantized
/// and float form and checks the masks against the axis-reordered input.
#[test]
fn test_modelpack_seg() {
    let seg_bytes = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
    let seg = Array4::from_shape_vec((1, 2, 160, 160), seg_bytes.to_vec()).unwrap();
    let quant = (1.0 / 255.0, 0).into();

    let seg_config = configs::Segmentation {
        decoder: DecoderType::ModelPack,
        quantization: Some(quant),
        shape: vec![1, 2, 160, 160],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumClasses, 2),
            (DimName::Height, 160),
            (DimName::Width, 160),
        ],
    };
    let decoder = DecoderBuilder::default()
        .with_config_modelpack_seg(seg_config)
        .build()
        .unwrap();

    let mut output_boxes: Vec<_> = Vec::with_capacity(10);
    let mut output_masks: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(&[seg.view().into()], &mut output_boxes, &mut output_masks)
        .unwrap();

    // Move the first remaining axis to the end: (2, 160, 160) -> (160, 160, 2).
    let mut reordered = seg.slice(s![0, .., .., ..]);
    reordered.swap_axes(0, 1);
    reordered.swap_axes(1, 2);
    let expected_masks = [Segmentation {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 1.0,
        ymax: 1.0,
        segmentation: reordered.into_owned(),
    }];
    compare_outputs((&[], &output_boxes), (&expected_masks, &output_masks));

    let seg_f32 = dequantize_ndarray(seg.view(), quant.into());
    decoder
        .decode_float::<f32>(
            &[seg_f32.view().into_dyn()],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    compare_outputs((&[], &output_boxes), (&[], &[]));
    let expected_mask = segmentation_to_mask(expected_masks[0].segmentation.view()).unwrap();
    let decoded_mask = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

    assert_eq!(expected_mask, decoded_mask);
}
/// Re-encodes the `u8` segmentation fixture as `i8`, `u16`, `i16`, `u32` and
/// `i32`, decodes each variant, and checks that every resulting mask equals
/// the `u8` one.
#[test]
fn test_modelpack_seg_quant() {
    let seg_bytes = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
    let out_u8 = Array4::from_shape_vec((1, 2, 160, 160), seg_bytes.to_vec()).unwrap();
    // Signed variants are shifted down by half the range; wider variants are
    // scaled up by shifting the byte into the most significant position.
    let out_i8 = out_u8.mapv(|x| (i16::from(x) - 128) as i8);
    let out_u16 = out_u8.mapv(|x| u16::from(x) << 8);
    let out_i16 = out_u8.mapv(|x| ((i32::from(x) << 8) - 32768) as i16);
    let out_u32 = out_u8.mapv(|x| u32::from(x) << 24);
    let out_i32 = out_u8.mapv(|x| ((i64::from(x) << 24) - 2147483648) as i32);

    let quant = (1.0 / 255.0, 0).into();

    let seg_config = configs::Segmentation {
        decoder: DecoderType::ModelPack,
        quantization: Some(quant),
        shape: vec![1, 2, 160, 160],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumClasses, 2),
            (DimName::Height, 160),
            (DimName::Width, 160),
        ],
    };
    let decoder = DecoderBuilder::default()
        .with_config_modelpack_seg(seg_config)
        .build()
        .unwrap();

    let mut output_boxes: Vec<_> = Vec::with_capacity(10);

    let mut masks_u8: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(&[out_u8.view().into()], &mut output_boxes, &mut masks_u8)
        .unwrap();

    let mut masks_i8: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(&[out_i8.view().into()], &mut output_boxes, &mut masks_i8)
        .unwrap();

    let mut masks_u16: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(&[out_u16.view().into()], &mut output_boxes, &mut masks_u16)
        .unwrap();

    let mut masks_i16: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(&[out_i16.view().into()], &mut output_boxes, &mut masks_i16)
        .unwrap();

    let mut masks_u32: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(&[out_u32.view().into()], &mut output_boxes, &mut masks_u32)
        .unwrap();

    let mut masks_i32: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(&[out_i32.view().into()], &mut output_boxes, &mut masks_i32)
        .unwrap();

    compare_outputs((&[], &output_boxes), (&[], &[]));
    let reference = segmentation_to_mask(masks_u8[0].segmentation.view()).unwrap();
    let from_i8 = segmentation_to_mask(masks_i8[0].segmentation.view()).unwrap();
    let from_u16 = segmentation_to_mask(masks_u16[0].segmentation.view()).unwrap();
    let from_i16 = segmentation_to_mask(masks_i16[0].segmentation.view()).unwrap();
    let from_u32 = segmentation_to_mask(masks_u32[0].segmentation.view()).unwrap();
    let from_i32 = segmentation_to_mask(masks_i32[0].segmentation.view()).unwrap();
    assert_eq!(reference, from_i8);
    assert_eq!(reference, from_u16);
    assert_eq!(reference, from_i16);
    assert_eq!(reference, from_u32);
    assert_eq!(reference, from_i32);
}
1208
/// Decodes the combined ModelPack detection + segmentation fixtures in
/// quantized and float form and checks boxes and masks against expected values.
#[test]
fn test_modelpack_segdet() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let box_bytes = include_bytes!("../../../testdata/modelpack_boxes_1935x1x4.bin");
    let boxes = Array4::from_shape_vec((1, 1935, 1, 4), box_bytes.to_vec()).unwrap();

    let score_bytes = include_bytes!("../../../testdata/modelpack_scores_1935x1.bin");
    let scores = Array3::from_shape_vec((1, 1935, 1), score_bytes.to_vec()).unwrap();

    let seg_bytes = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
    let seg = Array4::from_shape_vec((1, 2, 160, 160), seg_bytes.to_vec()).unwrap();

    let quant_boxes = (0.004656755365431309, 21).into();
    let quant_scores = (0.0019603664986789227, 0).into();
    let quant_seg = (1.0 / 255.0, 0).into();

    let boxes_config = configs::Boxes {
        decoder: DecoderType::ModelPack,
        quantization: Some(quant_boxes),
        shape: vec![1, 1935, 1, 4],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumBoxes, 1935),
            (DimName::Padding, 1),
            (DimName::BoxCoords, 4),
        ],
        normalized: Some(true),
    };
    let scores_config = configs::Scores {
        decoder: DecoderType::ModelPack,
        quantization: Some(quant_scores),
        shape: vec![1, 1935, 1],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumBoxes, 1935),
            (DimName::NumClasses, 1),
        ],
    };
    let seg_config = configs::Segmentation {
        decoder: DecoderType::ModelPack,
        quantization: Some(quant_seg),
        shape: vec![1, 2, 160, 160],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumClasses, 2),
            (DimName::Height, 160),
            (DimName::Width, 160),
        ],
    };

    let decoder = DecoderBuilder::default()
        .with_config_modelpack_segdet(boxes_config, scores_config, seg_config)
        .with_iou_threshold(iou_threshold)
        .with_score_threshold(score_threshold)
        .build()
        .unwrap();

    let mut output_boxes: Vec<_> = Vec::with_capacity(10);
    let mut output_masks: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(
            &[scores.view().into(), boxes.view().into(), seg.view().into()],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    // Move the first remaining axis to the end: (2, 160, 160) -> (160, 160, 2).
    let mut reordered = seg.slice(s![0, .., .., ..]);
    reordered.swap_axes(0, 1);
    reordered.swap_axes(1, 2);
    let expected_masks = [Segmentation {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 1.0,
        ymax: 1.0,
        segmentation: reordered.into_owned(),
    }];
    let expected_boxes = [DetectBox {
        bbox: BoundingBox {
            xmin: 0.40513772,
            ymin: 0.6379755,
            xmax: 0.5122431,
            ymax: 0.7730214,
        },
        score: 0.4861709,
        label: 0,
    }];
    compare_outputs(
        (&expected_boxes, &output_boxes),
        (&expected_masks, &output_masks),
    );

    let scores_f32 = dequantize_ndarray(scores.view(), quant_scores.into());
    let boxes_f32 = dequantize_ndarray(boxes.view(), quant_boxes.into());
    let seg_f32 = dequantize_ndarray(seg.view(), quant_seg.into());
    decoder
        .decode_float::<f32>(
            &[
                scores_f32.view().into_dyn(),
                boxes_f32.view().into_dyn(),
                seg_f32.view().into_dyn(),
            ],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    compare_outputs((&expected_boxes, &output_boxes), (&[], &[]));
    let expected_mask = segmentation_to_mask(expected_masks[0].segmentation.view()).unwrap();
    let decoded_mask = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

    assert_eq!(expected_mask, decoded_mask);
}
1325
/// Decodes the split ModelPack detection fixtures together with the
/// segmentation fixture, in quantized and float form, and checks boxes and
/// masks against expected values.
#[test]
fn test_modelpack_segdet_split() {
    let score_threshold = 0.8;
    let iou_threshold = 0.5;

    let seg_bytes = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
    let seg = Array4::from_shape_vec((1, 2, 160, 160), seg_bytes.to_vec()).unwrap();

    let bytes0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
    let detect0 = Array4::from_shape_vec((1, 9, 15, 18), bytes0.to_vec()).unwrap();

    let bytes1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
    let detect1 = Array4::from_shape_vec((1, 17, 30, 18), bytes1.to_vec()).unwrap();

    let quant0 = (0.08547406643629074, 174).into();
    let quant1 = (0.09929127991199493, 183).into();
    let quant_seg = (1.0 / 255.0, 0).into();

    let anchors0 = vec![
        [0.36666667461395264, 0.31481480598449707],
        [0.38749998807907104, 0.4740740656852722],
        [0.5333333611488342, 0.644444465637207],
    ];
    let anchors1 = vec![
        [0.13750000298023224, 0.2074074000120163],
        [0.2541666626930237, 0.21481481194496155],
        [0.23125000298023224, 0.35185185074806213],
    ];

    let detect_config1 = configs::Detection {
        decoder: DecoderType::ModelPack,
        shape: vec![1, 17, 30, 18],
        anchors: Some(anchors1),
        quantization: Some(quant1),
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 17),
            (DimName::Width, 30),
            (DimName::NumAnchorsXFeatures, 18),
        ],
        normalized: Some(true),
    };
    let detect_config0 = configs::Detection {
        decoder: DecoderType::ModelPack,
        shape: vec![1, 9, 15, 18],
        anchors: Some(anchors0),
        quantization: Some(quant0),
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 9),
            (DimName::Width, 15),
            (DimName::NumAnchorsXFeatures, 18),
        ],
        normalized: Some(true),
    };
    let seg_config = configs::Segmentation {
        decoder: DecoderType::ModelPack,
        quantization: Some(quant_seg),
        shape: vec![1, 2, 160, 160],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumClasses, 2),
            (DimName::Height, 160),
            (DimName::Width, 160),
        ],
    };

    let decoder = DecoderBuilder::default()
        .with_config_modelpack_segdet_split(vec![detect_config1, detect_config0], seg_config)
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let mut output_boxes: Vec<_> = Vec::with_capacity(10);
    let mut output_masks: Vec<_> = Vec::with_capacity(10);
    decoder
        .decode_quantized(
            &[
                detect0.view().into(),
                detect1.view().into(),
                seg.view().into(),
            ],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    // Move the first remaining axis to the end: (2, 160, 160) -> (160, 160, 2).
    let mut reordered = seg.slice(s![0, .., .., ..]);
    reordered.swap_axes(0, 1);
    reordered.swap_axes(1, 2);
    let expected_masks = [Segmentation {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 1.0,
        ymax: 1.0,
        segmentation: reordered.into_owned(),
    }];
    let expected_boxes = [DetectBox {
        bbox: BoundingBox {
            xmin: 0.43171933,
            ymin: 0.68243736,
            xmax: 0.5626645,
            ymax: 0.808863,
        },
        score: 0.99240804,
        label: 0,
    }];
    println!("Output Boxes: {output_boxes:?}");
    compare_outputs(
        (&expected_boxes, &output_boxes),
        (&expected_masks, &output_masks),
    );

    let detect0_f32 = dequantize_ndarray(detect0.view(), quant0.into());
    let detect1_f32 = dequantize_ndarray(detect1.view(), quant1.into());
    let seg_f32 = dequantize_ndarray(seg.view(), quant_seg.into());
    decoder
        .decode_float::<f32>(
            &[
                detect0_f32.view().into_dyn(),
                detect1_f32.view().into_dyn(),
                seg_f32.view().into_dyn(),
            ],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    compare_outputs((&expected_boxes, &output_boxes), (&[], &[]));
    let expected_mask = segmentation_to_mask(expected_masks[0].segmentation.view()).unwrap();
    let decoded_mask = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

    assert_eq!(expected_mask, decoded_mask);
}
1464
1465 #[test]
1466 fn test_dequant_chunked() {
1467 let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
1468 let mut out =
1469 unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) }.to_vec();
1470 out.push(123); let mut out_dequant = vec![0.0; 84 * 8400 + 1];
1473 let mut out_dequant_simd = vec![0.0; 84 * 8400 + 1];
1474 let quant = Quantization::new(0.0040811873, -123);
1475 dequantize_cpu(&out, quant, &mut out_dequant);
1476
1477 dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1478 assert_eq!(out_dequant, out_dequant_simd);
1479
1480 let quant = Quantization::new(0.0040811873, 0);
1481 dequantize_cpu(&out, quant, &mut out_dequant);
1482
1483 dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1484 assert_eq!(out_dequant, out_dequant_simd);
1485 }
1486
1487 #[test]
1488 fn test_decoder_yolo_det() {
1489 let score_threshold = 0.25;
1490 let iou_threshold = 0.7;
1491 let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
1492 let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
1493 let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
1494 let quant = (0.0040811873, -123).into();
1495
1496 let decoder = DecoderBuilder::default()
1497 .with_config_yolo_det(
1498 configs::Detection {
1499 decoder: DecoderType::Ultralytics,
1500 shape: vec![1, 84, 8400],
1501 anchors: None,
1502 quantization: Some(quant),
1503 dshape: vec![
1504 (DimName::Batch, 1),
1505 (DimName::NumFeatures, 84),
1506 (DimName::NumBoxes, 8400),
1507 ],
1508 normalized: Some(true),
1509 },
1510 Some(DecoderVersion::Yolo11),
1511 )
1512 .with_score_threshold(score_threshold)
1513 .with_iou_threshold(iou_threshold)
1514 .build()
1515 .unwrap();
1516
1517 let mut output_boxes: Vec<_> = Vec::with_capacity(50);
1518 decode_yolo_det(
1519 (out.slice(s![0, .., ..]), quant.into()),
1520 score_threshold,
1521 iou_threshold,
1522 Some(configs::Nms::ClassAgnostic),
1523 &mut output_boxes,
1524 );
1525 assert!(output_boxes[0].equal_within_delta(
1526 &DetectBox {
1527 bbox: BoundingBox {
1528 xmin: 0.5285137,
1529 ymin: 0.05305544,
1530 xmax: 0.87541467,
1531 ymax: 0.9998909,
1532 },
1533 score: 0.5591227,
1534 label: 0
1535 },
1536 1e-6
1537 ));
1538
1539 assert!(output_boxes[1].equal_within_delta(
1540 &DetectBox {
1541 bbox: BoundingBox {
1542 xmin: 0.130598,
1543 ymin: 0.43260583,
1544 xmax: 0.35098213,
1545 ymax: 0.9958097,
1546 },
1547 score: 0.33057618,
1548 label: 75
1549 },
1550 1e-6
1551 ));
1552
1553 let mut output_boxes1: Vec<_> = Vec::with_capacity(50);
1554 let mut output_masks1: Vec<_> = Vec::with_capacity(50);
1555 decoder
1556 .decode_quantized(&[out.view().into()], &mut output_boxes1, &mut output_masks1)
1557 .unwrap();
1558
1559 let out = dequantize_ndarray(out.view(), quant.into());
1560 let mut output_boxes_f32: Vec<_> = Vec::with_capacity(50);
1561 let mut output_masks_f32: Vec<_> = Vec::with_capacity(50);
1562 decoder
1563 .decode_float::<f32>(
1564 &[out.view().into_dyn()],
1565 &mut output_boxes_f32,
1566 &mut output_masks_f32,
1567 )
1568 .unwrap();
1569
1570 compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1571 compare_outputs((&output_boxes, &output_boxes_f32), (&[], &output_masks_f32));
1572 }
1573
1574 #[test]
1575 fn test_decoder_masks() {
1576 let score_threshold = 0.45;
1577 let iou_threshold = 0.45;
1578 let boxes = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
1579 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1580 let boxes = ndarray::Array2::from_shape_vec((116, 8400), boxes.to_vec()).unwrap();
1581 let quant_boxes = Quantization::new(0.021287761628627777, 31);
1582
1583 let protos = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin");
1584 let protos =
1585 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1586 let protos = ndarray::Array3::from_shape_vec((160, 160, 32), protos.to_vec()).unwrap();
1587 let quant_protos = Quantization::new(0.02491161972284317, -117);
1588 let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1589 let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1590 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1591 let mut output_masks: Vec<_> = Vec::with_capacity(10);
1592 decode_yolo_segdet_float(
1593 seg.view(),
1594 protos.view(),
1595 score_threshold,
1596 iou_threshold,
1597 Some(configs::Nms::ClassAgnostic),
1598 &mut output_boxes,
1599 &mut output_masks,
1600 )
1601 .unwrap();
1602 assert_eq!(output_boxes.len(), 2);
1603 assert_eq!(output_boxes.len(), output_masks.len());
1604
1605 for (b, m) in output_boxes.iter().zip(&output_masks) {
1606 assert!(b.bbox.xmin >= m.xmin);
1607 assert!(b.bbox.ymin >= m.ymin);
1608 assert!(b.bbox.xmax >= m.xmax);
1609 assert!(b.bbox.ymax >= m.ymax);
1610 }
1611 assert!(output_boxes[0].equal_within_delta(
1612 &DetectBox {
1613 bbox: BoundingBox {
1614 xmin: 0.08515105,
1615 ymin: 0.7131401,
1616 xmax: 0.29802868,
1617 ymax: 0.8195788,
1618 },
1619 score: 0.91537374,
1620 label: 23
1621 },
1622 1.0 / 160.0, ));
1624
1625 assert!(output_boxes[1].equal_within_delta(
1626 &DetectBox {
1627 bbox: BoundingBox {
1628 xmin: 0.59605736,
1629 ymin: 0.25545314,
1630 xmax: 0.93666154,
1631 ymax: 0.72378385,
1632 },
1633 score: 0.91537374,
1634 label: 23
1635 },
1636 1.0 / 160.0, ));
1638
1639 let full_mask = include_bytes!("../../../testdata/yolov8_mask_results.bin");
1640 let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();
1641
1642 let cropped_mask = full_mask.slice(ndarray::s![
1643 (output_masks[1].ymin * 160.0) as usize..(output_masks[1].ymax * 160.0) as usize,
1644 (output_masks[1].xmin * 160.0) as usize..(output_masks[1].xmax * 160.0) as usize,
1645 ]);
1646
1647 assert_eq!(
1648 cropped_mask,
1649 segmentation_to_mask(output_masks[1].segmentation.view()).unwrap()
1650 );
1651 }
1652
1653 #[test]
1658 fn test_decoder_masks_nchw_protos() {
1659 let score_threshold = 0.45;
1660 let iou_threshold = 0.45;
1661
1662 let boxes_raw = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
1664 let boxes_raw =
1665 unsafe { std::slice::from_raw_parts(boxes_raw.as_ptr() as *const i8, boxes_raw.len()) };
1666 let boxes_2d = ndarray::Array2::from_shape_vec((116, 8400), boxes_raw.to_vec()).unwrap();
1667 let quant_boxes = Quantization::new(0.021287761628627777, 31);
1668
1669 let protos_raw = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin");
1671 let protos_raw = unsafe {
1672 std::slice::from_raw_parts(protos_raw.as_ptr() as *const i8, protos_raw.len())
1673 };
1674 let protos_hwc =
1675 ndarray::Array3::from_shape_vec((160, 160, 32), protos_raw.to_vec()).unwrap();
1676 let quant_protos = Quantization::new(0.02491161972284317, -117);
1677 let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
1678
1679 let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
1681 let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
1682 let mut ref_masks: Vec<_> = Vec::with_capacity(10);
1683 decode_yolo_segdet_float(
1684 seg.view(),
1685 protos_f32_hwc.view(),
1686 score_threshold,
1687 iou_threshold,
1688 Some(configs::Nms::ClassAgnostic),
1689 &mut ref_boxes,
1690 &mut ref_masks,
1691 )
1692 .unwrap();
1693 assert_eq!(ref_boxes.len(), 2);
1694
1695 let protos_f32_chw = protos_f32_hwc.permuted_axes([2, 0, 1]); let protos_nchw = protos_f32_chw.insert_axis(ndarray::Axis(0)); let seg_3d = seg.insert_axis(ndarray::Axis(0)); let decoder = DecoderBuilder::default()
1705 .with_config_yolo_segdet(
1706 configs::Detection {
1707 decoder: configs::DecoderType::Ultralytics,
1708 quantization: None,
1709 shape: vec![1, 116, 8400],
1710 dshape: vec![],
1711 normalized: Some(true),
1712 anchors: None,
1713 },
1714 configs::Protos {
1715 decoder: configs::DecoderType::Ultralytics,
1716 quantization: None,
1717 shape: vec![1, 32, 160, 160],
1718 dshape: vec![], },
1720 None, )
1722 .with_score_threshold(score_threshold)
1723 .with_iou_threshold(iou_threshold)
1724 .build()
1725 .unwrap();
1726
1727 let mut cfg_boxes: Vec<_> = Vec::with_capacity(10);
1728 let mut cfg_masks: Vec<_> = Vec::with_capacity(10);
1729 decoder
1730 .decode_float(
1731 &[seg_3d.view().into_dyn(), protos_nchw.view().into_dyn()],
1732 &mut cfg_boxes,
1733 &mut cfg_masks,
1734 )
1735 .unwrap();
1736
1737 assert_eq!(
1739 cfg_boxes.len(),
1740 ref_boxes.len(),
1741 "config path produced {} boxes, reference produced {}",
1742 cfg_boxes.len(),
1743 ref_boxes.len()
1744 );
1745
1746 for (i, (cb, rb)) in cfg_boxes.iter().zip(&ref_boxes).enumerate() {
1748 assert!(
1749 cb.equal_within_delta(rb, 0.01),
1750 "box {i} mismatch: config={cb:?}, reference={rb:?}"
1751 );
1752 }
1753
1754 for (i, (cm, rm)) in cfg_masks.iter().zip(&ref_masks).enumerate() {
1756 let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
1757 let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
1758 assert_eq!(
1759 cm_arr, rm_arr,
1760 "mask {i} pixel mismatch between config-driven and reference paths"
1761 );
1762 }
1763 }
1764
1765 #[test]
1766 fn test_decoder_masks_i8() {
1767 let score_threshold = 0.45;
1768 let iou_threshold = 0.45;
1769 let boxes = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
1770 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1771 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
1772 let quant_boxes = (0.021287761628627777, 31).into();
1773
1774 let protos = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin");
1775 let protos =
1776 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1777 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
1778 let quant_protos = (0.02491161972284317, -117).into();
1779 let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1780 let mut output_masks: Vec<_> = Vec::with_capacity(500);
1781
1782 let decoder = DecoderBuilder::default()
1783 .with_config_yolo_segdet(
1784 configs::Detection {
1785 decoder: configs::DecoderType::Ultralytics,
1786 quantization: Some(quant_boxes),
1787 shape: vec![1, 116, 8400],
1788 anchors: None,
1789 dshape: vec![
1790 (DimName::Batch, 1),
1791 (DimName::NumFeatures, 116),
1792 (DimName::NumBoxes, 8400),
1793 ],
1794 normalized: Some(true),
1795 },
1796 Protos {
1797 decoder: configs::DecoderType::Ultralytics,
1798 quantization: Some(quant_protos),
1799 shape: vec![1, 160, 160, 32],
1800 dshape: vec![
1801 (DimName::Batch, 1),
1802 (DimName::Height, 160),
1803 (DimName::Width, 160),
1804 (DimName::NumProtos, 32),
1805 ],
1806 },
1807 Some(DecoderVersion::Yolo11),
1808 )
1809 .with_score_threshold(score_threshold)
1810 .with_iou_threshold(iou_threshold)
1811 .build()
1812 .unwrap();
1813
1814 let quant_boxes = quant_boxes.into();
1815 let quant_protos = quant_protos.into();
1816
1817 decode_yolo_segdet_quant(
1818 (boxes.slice(s![0, .., ..]), quant_boxes),
1819 (protos.slice(s![0, .., .., ..]), quant_protos),
1820 score_threshold,
1821 iou_threshold,
1822 Some(configs::Nms::ClassAgnostic),
1823 &mut output_boxes,
1824 &mut output_masks,
1825 )
1826 .unwrap();
1827
1828 let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
1829 let mut output_masks1: Vec<_> = Vec::with_capacity(500);
1830
1831 decoder
1832 .decode_quantized(
1833 &[boxes.view().into(), protos.view().into()],
1834 &mut output_boxes1,
1835 &mut output_masks1,
1836 )
1837 .unwrap();
1838
1839 let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1840 let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1841
1842 let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
1843 let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
1844 decode_yolo_segdet_float(
1845 seg.slice(s![0, .., ..]),
1846 protos.slice(s![0, .., .., ..]),
1847 score_threshold,
1848 iou_threshold,
1849 Some(configs::Nms::ClassAgnostic),
1850 &mut output_boxes_f32,
1851 &mut output_masks_f32,
1852 )
1853 .unwrap();
1854
1855 let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
1856 let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);
1857
1858 decoder
1859 .decode_float(
1860 &[seg.view().into_dyn(), protos.view().into_dyn()],
1861 &mut output_boxes1_f32,
1862 &mut output_masks1_f32,
1863 )
1864 .unwrap();
1865
1866 compare_outputs(
1867 (&output_boxes, &output_boxes1),
1868 (&output_masks, &output_masks1),
1869 );
1870
1871 compare_outputs(
1872 (&output_boxes, &output_boxes_f32),
1873 (&output_masks, &output_masks_f32),
1874 );
1875
1876 compare_outputs(
1877 (&output_boxes_f32, &output_boxes1_f32),
1878 (&output_masks_f32, &output_masks1_f32),
1879 );
1880 }
1881
1882 #[test]
1883 fn test_decoder_yolo_split() {
1884 let score_threshold = 0.45;
1885 let iou_threshold = 0.45;
1886 let boxes = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
1887 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1888 let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
1889 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
1890
1891 let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
1892
1893 let decoder = DecoderBuilder::default()
1894 .with_config_yolo_split_det(
1895 configs::Boxes {
1896 decoder: configs::DecoderType::Ultralytics,
1897 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1898 shape: vec![1, 4, 8400],
1899 dshape: vec![
1900 (DimName::Batch, 1),
1901 (DimName::BoxCoords, 4),
1902 (DimName::NumBoxes, 8400),
1903 ],
1904 normalized: Some(true),
1905 },
1906 configs::Scores {
1907 decoder: configs::DecoderType::Ultralytics,
1908 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1909 shape: vec![1, 80, 8400],
1910 dshape: vec![
1911 (DimName::Batch, 1),
1912 (DimName::NumClasses, 80),
1913 (DimName::NumBoxes, 8400),
1914 ],
1915 },
1916 )
1917 .with_score_threshold(score_threshold)
1918 .with_iou_threshold(iou_threshold)
1919 .build()
1920 .unwrap();
1921
1922 let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1923 let mut output_masks: Vec<_> = Vec::with_capacity(500);
1924
1925 decoder
1926 .decode_quantized(
1927 &[
1928 boxes.slice(s![.., ..4, ..]).into(),
1929 boxes.slice(s![.., 4..84, ..]).into(),
1930 ],
1931 &mut output_boxes,
1932 &mut output_masks,
1933 )
1934 .unwrap();
1935
1936 let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1937 let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
1938 decode_yolo_det_float(
1939 seg.slice(s![0, ..84, ..]),
1940 score_threshold,
1941 iou_threshold,
1942 Some(configs::Nms::ClassAgnostic),
1943 &mut output_boxes_f32,
1944 );
1945
1946 let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
1947 let mut output_masks1: Vec<_> = Vec::with_capacity(500);
1948
1949 decoder
1950 .decode_float(
1951 &[
1952 seg.slice(s![.., ..4, ..]).into_dyn(),
1953 seg.slice(s![.., 4..84, ..]).into_dyn(),
1954 ],
1955 &mut output_boxes1,
1956 &mut output_masks1,
1957 )
1958 .unwrap();
1959 compare_outputs((&output_boxes, &output_boxes_f32), (&output_masks, &[]));
1960 compare_outputs((&output_boxes_f32, &output_boxes1), (&[], &output_masks1));
1961 }
1962
1963 #[test]
1964 fn test_decoder_masks_config_mixed() {
1965 let score_threshold = 0.45;
1966 let iou_threshold = 0.45;
1967 let boxes = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
1968 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1969 let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
1970 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
1971
1972 let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
1973
1974 let protos = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin");
1975 let protos =
1976 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1977 let protos: Vec<_> = protos.to_vec();
1978 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
1979 let quant_protos = Quantization::new(0.02491161972284317, -117);
1980
1981 let decoder = DecoderBuilder::default()
1982 .with_config_yolo_split_segdet(
1983 configs::Boxes {
1984 decoder: configs::DecoderType::Ultralytics,
1985 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1986 shape: vec![1, 4, 8400],
1987 dshape: vec![
1988 (DimName::Batch, 1),
1989 (DimName::BoxCoords, 4),
1990 (DimName::NumBoxes, 8400),
1991 ],
1992 normalized: Some(true),
1993 },
1994 configs::Scores {
1995 decoder: configs::DecoderType::Ultralytics,
1996 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1997 shape: vec![1, 80, 8400],
1998 dshape: vec![
1999 (DimName::Batch, 1),
2000 (DimName::NumClasses, 80),
2001 (DimName::NumBoxes, 8400),
2002 ],
2003 },
2004 configs::MaskCoefficients {
2005 decoder: configs::DecoderType::Ultralytics,
2006 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2007 shape: vec![1, 32, 8400],
2008 dshape: vec![
2009 (DimName::Batch, 1),
2010 (DimName::NumProtos, 32),
2011 (DimName::NumBoxes, 8400),
2012 ],
2013 },
2014 configs::Protos {
2015 decoder: configs::DecoderType::Ultralytics,
2016 quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
2017 shape: vec![1, 160, 160, 32],
2018 dshape: vec![
2019 (DimName::Batch, 1),
2020 (DimName::Height, 160),
2021 (DimName::Width, 160),
2022 (DimName::NumProtos, 32),
2023 ],
2024 },
2025 )
2026 .with_score_threshold(score_threshold)
2027 .with_iou_threshold(iou_threshold)
2028 .build()
2029 .unwrap();
2030
2031 let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2032 let mut output_masks: Vec<_> = Vec::with_capacity(500);
2033
2034 decoder
2035 .decode_quantized(
2036 &[
2037 boxes.slice(s![.., ..4, ..]).into(),
2038 boxes.slice(s![.., 4..84, ..]).into(),
2039 boxes.slice(s![.., 84.., ..]).into(),
2040 protos.view().into(),
2041 ],
2042 &mut output_boxes,
2043 &mut output_masks,
2044 )
2045 .unwrap();
2046
2047 let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
2048 let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2049 let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2050 let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
2051 decode_yolo_segdet_float(
2052 seg.slice(s![0, .., ..]),
2053 protos.slice(s![0, .., .., ..]),
2054 score_threshold,
2055 iou_threshold,
2056 Some(configs::Nms::ClassAgnostic),
2057 &mut output_boxes_f32,
2058 &mut output_masks_f32,
2059 )
2060 .unwrap();
2061
2062 let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2063 let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2064
2065 decoder
2066 .decode_float(
2067 &[
2068 seg.slice(s![.., ..4, ..]).into_dyn(),
2069 seg.slice(s![.., 4..84, ..]).into_dyn(),
2070 seg.slice(s![.., 84.., ..]).into_dyn(),
2071 protos.view().into_dyn(),
2072 ],
2073 &mut output_boxes1,
2074 &mut output_masks1,
2075 )
2076 .unwrap();
2077 compare_outputs(
2078 (&output_boxes, &output_boxes_f32),
2079 (&output_masks, &output_masks_f32),
2080 );
2081 compare_outputs(
2082 (&output_boxes_f32, &output_boxes1),
2083 (&output_masks_f32, &output_masks1),
2084 );
2085 }
2086
2087 #[test]
2088 fn test_decoder_masks_config_i32() {
2089 let score_threshold = 0.45;
2090 let iou_threshold = 0.45;
2091 let boxes = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
2092 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
2093 let scale = 1 << 23;
2094 let boxes: Vec<_> = boxes.iter().map(|x| *x as i32 * scale).collect();
2095 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2096
2097 let quant_boxes = Quantization::new(0.021287761628627777 / scale as f32, 31 * scale);
2098
2099 let protos = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin");
2100 let protos =
2101 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
2102 let protos: Vec<_> = protos.iter().map(|x| *x as i32 * scale).collect();
2103 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
2104 let quant_protos = Quantization::new(0.02491161972284317 / scale as f32, -117 * scale);
2105
2106 let decoder = DecoderBuilder::default()
2107 .with_config_yolo_split_segdet(
2108 configs::Boxes {
2109 decoder: configs::DecoderType::Ultralytics,
2110 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2111 shape: vec![1, 4, 8400],
2112 dshape: vec![
2113 (DimName::Batch, 1),
2114 (DimName::BoxCoords, 4),
2115 (DimName::NumBoxes, 8400),
2116 ],
2117 normalized: Some(true),
2118 },
2119 configs::Scores {
2120 decoder: configs::DecoderType::Ultralytics,
2121 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2122 shape: vec![1, 80, 8400],
2123 dshape: vec![
2124 (DimName::Batch, 1),
2125 (DimName::NumClasses, 80),
2126 (DimName::NumBoxes, 8400),
2127 ],
2128 },
2129 configs::MaskCoefficients {
2130 decoder: configs::DecoderType::Ultralytics,
2131 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2132 shape: vec![1, 32, 8400],
2133 dshape: vec![
2134 (DimName::Batch, 1),
2135 (DimName::NumProtos, 32),
2136 (DimName::NumBoxes, 8400),
2137 ],
2138 },
2139 configs::Protos {
2140 decoder: configs::DecoderType::Ultralytics,
2141 quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
2142 shape: vec![1, 160, 160, 32],
2143 dshape: vec![
2144 (DimName::Batch, 1),
2145 (DimName::Height, 160),
2146 (DimName::Width, 160),
2147 (DimName::NumProtos, 32),
2148 ],
2149 },
2150 )
2151 .with_score_threshold(score_threshold)
2152 .with_iou_threshold(iou_threshold)
2153 .build()
2154 .unwrap();
2155
2156 let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2157 let mut output_masks: Vec<_> = Vec::with_capacity(500);
2158
2159 decoder
2160 .decode_quantized(
2161 &[
2162 boxes.slice(s![.., ..4, ..]).into(),
2163 boxes.slice(s![.., 4..84, ..]).into(),
2164 boxes.slice(s![.., 84.., ..]).into(),
2165 protos.view().into(),
2166 ],
2167 &mut output_boxes,
2168 &mut output_masks,
2169 )
2170 .unwrap();
2171
2172 let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
2173 let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2174 let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2175 let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
2176 decode_yolo_segdet_float(
2177 seg.slice(s![0, .., ..]),
2178 protos.slice(s![0, .., .., ..]),
2179 score_threshold,
2180 iou_threshold,
2181 Some(configs::Nms::ClassAgnostic),
2182 &mut output_boxes_f32,
2183 &mut output_masks_f32,
2184 )
2185 .unwrap();
2186
2187 assert_eq!(output_boxes.len(), output_boxes_f32.len());
2188 assert_eq!(output_masks.len(), output_masks_f32.len());
2189
2190 compare_outputs(
2191 (&output_boxes, &output_boxes_f32),
2192 (&output_masks, &output_masks_f32),
2193 );
2194 }
2195
2196 #[test]
2198 fn test_context_switch() {
2199 let yolo_det = || {
2200 let score_threshold = 0.25;
2201 let iou_threshold = 0.7;
2202 let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
2203 let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
2204 let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
2205 let quant = (0.0040811873, -123).into();
2206
2207 let decoder = DecoderBuilder::default()
2208 .with_config_yolo_det(
2209 configs::Detection {
2210 decoder: DecoderType::Ultralytics,
2211 shape: vec![1, 84, 8400],
2212 anchors: None,
2213 quantization: Some(quant),
2214 dshape: vec![
2215 (DimName::Batch, 1),
2216 (DimName::NumFeatures, 84),
2217 (DimName::NumBoxes, 8400),
2218 ],
2219 normalized: None,
2220 },
2221 None,
2222 )
2223 .with_score_threshold(score_threshold)
2224 .with_iou_threshold(iou_threshold)
2225 .build()
2226 .unwrap();
2227
2228 let mut output_boxes: Vec<_> = Vec::with_capacity(50);
2229 let mut output_masks: Vec<_> = Vec::with_capacity(50);
2230
2231 for _ in 0..100 {
2232 decoder
2233 .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
2234 .unwrap();
2235
2236 assert!(output_boxes[0].equal_within_delta(
2237 &DetectBox {
2238 bbox: BoundingBox {
2239 xmin: 0.5285137,
2240 ymin: 0.05305544,
2241 xmax: 0.87541467,
2242 ymax: 0.9998909,
2243 },
2244 score: 0.5591227,
2245 label: 0
2246 },
2247 1e-6
2248 ));
2249
2250 assert!(output_boxes[1].equal_within_delta(
2251 &DetectBox {
2252 bbox: BoundingBox {
2253 xmin: 0.130598,
2254 ymin: 0.43260583,
2255 xmax: 0.35098213,
2256 ymax: 0.9958097,
2257 },
2258 score: 0.33057618,
2259 label: 75
2260 },
2261 1e-6
2262 ));
2263 assert!(output_masks.is_empty());
2264 }
2265 };
2266
2267 let modelpack_det_split = || {
2268 let score_threshold = 0.8;
2269 let iou_threshold = 0.5;
2270
2271 let seg = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
2272 let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
2273
2274 let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
2275 let detect0 =
2276 ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
2277
2278 let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
2279 let detect1 =
2280 ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
2281
2282 let mut mask = seg.slice(s![0, .., .., ..]);
2283 mask.swap_axes(0, 1);
2284 mask.swap_axes(1, 2);
2285 let mask = [Segmentation {
2286 xmin: 0.0,
2287 ymin: 0.0,
2288 xmax: 1.0,
2289 ymax: 1.0,
2290 segmentation: mask.into_owned(),
2291 }];
2292 let correct_boxes = [DetectBox {
2293 bbox: BoundingBox {
2294 xmin: 0.43171933,
2295 ymin: 0.68243736,
2296 xmax: 0.5626645,
2297 ymax: 0.808863,
2298 },
2299 score: 0.99240804,
2300 label: 0,
2301 }];
2302
2303 let quant0 = (0.08547406643629074, 174).into();
2304 let quant1 = (0.09929127991199493, 183).into();
2305 let quant_seg = (1.0 / 255.0, 0).into();
2306
2307 let anchors0 = vec![
2308 [0.36666667461395264, 0.31481480598449707],
2309 [0.38749998807907104, 0.4740740656852722],
2310 [0.5333333611488342, 0.644444465637207],
2311 ];
2312 let anchors1 = vec![
2313 [0.13750000298023224, 0.2074074000120163],
2314 [0.2541666626930237, 0.21481481194496155],
2315 [0.23125000298023224, 0.35185185074806213],
2316 ];
2317
2318 let decoder = DecoderBuilder::default()
2319 .with_config_modelpack_segdet_split(
2320 vec![
2321 configs::Detection {
2322 decoder: DecoderType::ModelPack,
2323 shape: vec![1, 17, 30, 18],
2324 anchors: Some(anchors1),
2325 quantization: Some(quant1),
2326 dshape: vec![
2327 (DimName::Batch, 1),
2328 (DimName::Height, 17),
2329 (DimName::Width, 30),
2330 (DimName::NumAnchorsXFeatures, 18),
2331 ],
2332 normalized: None,
2333 },
2334 configs::Detection {
2335 decoder: DecoderType::ModelPack,
2336 shape: vec![1, 9, 15, 18],
2337 anchors: Some(anchors0),
2338 quantization: Some(quant0),
2339 dshape: vec![
2340 (DimName::Batch, 1),
2341 (DimName::Height, 9),
2342 (DimName::Width, 15),
2343 (DimName::NumAnchorsXFeatures, 18),
2344 ],
2345 normalized: None,
2346 },
2347 ],
2348 configs::Segmentation {
2349 decoder: DecoderType::ModelPack,
2350 quantization: Some(quant_seg),
2351 shape: vec![1, 2, 160, 160],
2352 dshape: vec![
2353 (DimName::Batch, 1),
2354 (DimName::NumClasses, 2),
2355 (DimName::Height, 160),
2356 (DimName::Width, 160),
2357 ],
2358 },
2359 )
2360 .with_score_threshold(score_threshold)
2361 .with_iou_threshold(iou_threshold)
2362 .build()
2363 .unwrap();
2364 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
2365 let mut output_masks: Vec<_> = Vec::with_capacity(10);
2366
2367 for _ in 0..100 {
2368 decoder
2369 .decode_quantized(
2370 &[
2371 detect0.view().into(),
2372 detect1.view().into(),
2373 seg.view().into(),
2374 ],
2375 &mut output_boxes,
2376 &mut output_masks,
2377 )
2378 .unwrap();
2379
2380 compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
2381 }
2382 };
2383
2384 let handles = vec![
2385 std::thread::spawn(yolo_det),
2386 std::thread::spawn(modelpack_det_split),
2387 std::thread::spawn(yolo_det),
2388 std::thread::spawn(modelpack_det_split),
2389 std::thread::spawn(yolo_det),
2390 std::thread::spawn(modelpack_det_split),
2391 std::thread::spawn(yolo_det),
2392 std::thread::spawn(modelpack_det_split),
2393 ];
2394 for handle in handles {
2395 handle.join().unwrap();
2396 }
2397 }
2398
2399 #[test]
2400 fn test_ndarray_to_xyxy_float() {
2401 let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2402 let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
2403 assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
2404
2405 let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2406 let xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(arr.view());
2407 assert_eq!(xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
2408 }
2409
2410 #[test]
2411 fn test_class_aware_nms_float() {
2412 use crate::float::nms_class_aware_float;
2413
2414 let boxes = vec![
2416 DetectBox {
2417 bbox: BoundingBox {
2418 xmin: 0.0,
2419 ymin: 0.0,
2420 xmax: 0.5,
2421 ymax: 0.5,
2422 },
2423 score: 0.9,
2424 label: 0, },
2426 DetectBox {
2427 bbox: BoundingBox {
2428 xmin: 0.1,
2429 ymin: 0.1,
2430 xmax: 0.6,
2431 ymax: 0.6,
2432 },
2433 score: 0.8,
2434 label: 1, },
2436 ];
2437
2438 let result = nms_class_aware_float(0.3, boxes.clone());
2441 assert_eq!(
2442 result.len(),
2443 2,
2444 "Class-aware NMS should keep both boxes with different classes"
2445 );
2446
2447 let same_class_boxes = vec![
2449 DetectBox {
2450 bbox: BoundingBox {
2451 xmin: 0.0,
2452 ymin: 0.0,
2453 xmax: 0.5,
2454 ymax: 0.5,
2455 },
2456 score: 0.9,
2457 label: 0,
2458 },
2459 DetectBox {
2460 bbox: BoundingBox {
2461 xmin: 0.1,
2462 ymin: 0.1,
2463 xmax: 0.6,
2464 ymax: 0.6,
2465 },
2466 score: 0.8,
2467 label: 0, },
2469 ];
2470
2471 let result = nms_class_aware_float(0.3, same_class_boxes);
2472 assert_eq!(
2473 result.len(),
2474 1,
2475 "Class-aware NMS should suppress overlapping box with same class"
2476 );
2477 assert_eq!(result[0].label, 0);
2478 assert!((result[0].score - 0.9).abs() < 1e-6);
2479 }
2480
2481 #[test]
2482 fn test_class_agnostic_vs_aware_nms() {
2483 use crate::float::{nms_class_aware_float, nms_float};
2484
2485 let boxes = vec![
2487 DetectBox {
2488 bbox: BoundingBox {
2489 xmin: 0.0,
2490 ymin: 0.0,
2491 xmax: 0.5,
2492 ymax: 0.5,
2493 },
2494 score: 0.9,
2495 label: 0,
2496 },
2497 DetectBox {
2498 bbox: BoundingBox {
2499 xmin: 0.1,
2500 ymin: 0.1,
2501 xmax: 0.6,
2502 ymax: 0.6,
2503 },
2504 score: 0.8,
2505 label: 1,
2506 },
2507 ];
2508
2509 let agnostic_result = nms_float(0.3, boxes.clone());
2511 assert_eq!(
2512 agnostic_result.len(),
2513 1,
2514 "Class-agnostic NMS should suppress overlapping boxes"
2515 );
2516
2517 let aware_result = nms_class_aware_float(0.3, boxes);
2519 assert_eq!(
2520 aware_result.len(),
2521 2,
2522 "Class-aware NMS should keep boxes with different classes"
2523 );
2524 }
2525
/// Checks that the quantized class-aware NMS keeps two overlapping boxes
/// whose labels differ.
#[test]
fn test_class_aware_nms_int() {
    use crate::byte::nms_class_aware_int;

    let boxes = vec![
        // Higher-scoring detection, labelled 0.
        DetectBoxQuantized {
            bbox: BoundingBox {
                xmin: 0.0,
                ymin: 0.0,
                xmax: 0.5,
                ymax: 0.5,
            },
            score: 200_u8,
            label: 0,
        },
        // Lower-scoring detection shifted by 0.1 on both axes, labelled 1.
        DetectBoxQuantized {
            bbox: BoundingBox {
                xmin: 0.1,
                ymin: 0.1,
                xmax: 0.6,
                ymax: 0.6,
            },
            score: 180_u8,
            label: 1,
        },
    ];

    // The two boxes have IoU = 0.16 / 0.34 ~= 0.47. The threshold must sit
    // below that value: at 0.5 the overlap would never reach the threshold,
    // so both boxes would survive even if labels were ignored and the test
    // could not tell class-aware from class-agnostic behaviour. 0.3 matches
    // the threshold the float NMS tests use on this same geometry.
    let result = nms_class_aware_int(0.3, boxes);
    assert_eq!(
        result.len(),
        2,
        "Class-aware NMS (int) should keep boxes with different classes"
    );
}
2562
/// Checks that the `Default` value of `configs::Nms` is the class-agnostic
/// variant.
#[test]
fn test_nms_enum_default() {
    let fallback = <configs::Nms as Default>::default();
    assert_eq!(fallback, configs::Nms::ClassAgnostic);
}
2569
/// Checks that an NMS mode handed to the builder via `with_nms` ends up in
/// the built decoder's `nms` field.
#[test]
fn test_decoder_nms_mode() {
    // Ultralytics detection config: shape [1, 84, 8400], no anchors and no
    // quantization.
    let detection = configs::Detection {
        anchors: None,
        decoder: DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 84, 8400],
        dshape: vec![],
        normalized: Some(true),
    };

    let built = DecoderBuilder::default()
        .with_config_yolo_det(detection, None)
        .with_nms(Some(configs::Nms::ClassAware))
        .build()
        .unwrap();

    let expected = Some(configs::Nms::ClassAware);
    assert_eq!(built.nms, expected);
}
2591
/// Checks that passing `None` to `with_nms` leaves the built decoder's `nms`
/// field as `None`.
#[test]
fn test_decoder_nms_bypass() {
    // Ultralytics detection config: shape [1, 84, 8400], no anchors and no
    // quantization.
    let detection = configs::Detection {
        anchors: None,
        decoder: DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 84, 8400],
        dshape: vec![],
        normalized: Some(true),
    };

    let built = DecoderBuilder::default()
        .with_config_yolo_det(detection, None)
        .with_nms(None)
        .build()
        .unwrap();

    assert_eq!(built.nms, None);
}
2613
/// Checks that `normalized_boxes()` returns `Some(true)` when the detection
/// config sets `normalized: Some(true)`.
#[test]
fn test_decoder_normalized_boxes_true() {
    // Ultralytics detection config flagged as normalized.
    let detection = configs::Detection {
        anchors: None,
        decoder: DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 84, 8400],
        dshape: vec![],
        normalized: Some(true),
    };

    let built = DecoderBuilder::default()
        .with_config_yolo_det(detection, None)
        .build()
        .unwrap();

    let reported = built.normalized_boxes();
    assert_eq!(reported, Some(true));
}
2634
/// Checks that `normalized_boxes()` returns `Some(false)` when the detection
/// config sets `normalized: Some(false)`.
#[test]
fn test_decoder_normalized_boxes_false() {
    // Ultralytics detection config flagged as not normalized.
    let detection = configs::Detection {
        anchors: None,
        decoder: DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 84, 8400],
        dshape: vec![],
        normalized: Some(false),
    };

    let built = DecoderBuilder::default()
        .with_config_yolo_det(detection, None)
        .build()
        .unwrap();

    let reported = built.normalized_boxes();
    assert_eq!(reported, Some(false));
}
2656
/// Checks that `normalized_boxes()` returns `None` when the detection config
/// leaves `normalized` unset.
#[test]
fn test_decoder_normalized_boxes_unknown() {
    // Ultralytics detection config with no normalization information.
    let detection = configs::Detection {
        anchors: None,
        decoder: DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 84, 8400],
        dshape: vec![],
        normalized: None,
    };

    // Unlike the other normalized-boxes tests, an explicit decoder version is
    // supplied here.
    // NOTE(review): presumably this shows that the version alone does not
    // imply a normalization setting; confirm against the builder.
    let version = Some(DecoderVersion::Yolo11);

    let built = DecoderBuilder::default()
        .with_config_yolo_det(detection, version)
        .build()
        .unwrap();

    let reported = built.normalized_boxes();
    assert_eq!(reported, None);
}
2677}