1#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
71
72use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
73use num_traits::{AsPrimitive, Float, PrimInt};
74
75pub mod byte;
76pub mod error;
77pub mod float;
78pub mod modelpack;
79pub mod yolo;
80
81mod decoder;
82pub use decoder::*;
83
84pub use configs::{DecoderVersion, Nms};
85pub use error::{DecoderError, DecoderResult};
86
87use crate::{
88 decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
89 yolo::yolo_segmentation_to_mask,
90};
91
92pub trait BBoxTypeTrait {
94 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
96
97 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
99 input: &[B; 4],
100 quant: Quantization,
101 ) -> [A; 4]
102 where
103 f32: AsPrimitive<A>,
104 i32: AsPrimitive<A>;
105
106 #[inline(always)]
117 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
118 input: ArrayView1<B>,
119 ) -> [A; 4] {
120 Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
121 }
122
123 #[inline(always)]
124 fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
126 input: ArrayView1<B>,
127 quant: Quantization,
128 ) -> [A; 4]
129 where
130 f32: AsPrimitive<A>,
131 i32: AsPrimitive<A>,
132 {
133 Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
134 }
135}
136
/// Marker type for boxes that are already encoded as corners
/// `[xmin, ymin, xmax, ymax]`; its [`BBoxTypeTrait`] impl only casts or
/// dequantizes the values and keeps their order.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct XYXY {}
140
141impl BBoxTypeTrait for XYXY {
142 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
143 input.map(|b| b.as_())
144 }
145
146 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
147 input: &[B; 4],
148 quant: Quantization,
149 ) -> [A; 4]
150 where
151 f32: AsPrimitive<A>,
152 i32: AsPrimitive<A>,
153 {
154 let scale = quant.scale.as_();
155 let zp = quant.zero_point.as_();
156 input.map(|b| (b.as_() - zp) * scale)
157 }
158
159 #[inline(always)]
160 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
161 input: ArrayView1<B>,
162 ) -> [A; 4] {
163 [
164 input[0].as_(),
165 input[1].as_(),
166 input[2].as_(),
167 input[3].as_(),
168 ]
169 }
170}
171
/// Marker type for boxes encoded as `[center_x, center_y, width, height]`;
/// its [`BBoxTypeTrait`] impl converts them to corner form by subtracting and
/// adding half of the width/height around the center.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct XYWH {}
176
177impl BBoxTypeTrait for XYWH {
178 #[inline(always)]
179 fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
180 let half = A::one() / (A::one() + A::one());
181 [
182 (input[0].as_()) - (input[2].as_() * half),
183 (input[1].as_()) - (input[3].as_() * half),
184 (input[0].as_()) + (input[2].as_() * half),
185 (input[1].as_()) + (input[3].as_() * half),
186 ]
187 }
188
189 #[inline(always)]
190 fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
191 input: &[B; 4],
192 quant: Quantization,
193 ) -> [A; 4]
194 where
195 f32: AsPrimitive<A>,
196 i32: AsPrimitive<A>,
197 {
198 let scale = quant.scale.as_();
199 let half_scale = (quant.scale * 0.5).as_();
200 let zp = quant.zero_point.as_();
201 let [x, y, w, h] = [
202 (input[0].as_() - zp) * scale,
203 (input[1].as_() - zp) * scale,
204 (input[2].as_() - zp) * half_scale,
205 (input[3].as_() - zp) * half_scale,
206 ];
207
208 [x - w, y - h, x + w, y + h]
209 }
210
211 #[inline(always)]
212 fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
213 input: ArrayView1<B>,
214 ) -> [A; 4] {
215 let half = A::one() / (A::one() + A::one());
216 [
217 (input[0].as_()) - (input[2].as_() * half),
218 (input[1].as_()) - (input[3].as_() * half),
219 (input[0].as_()) + (input[2].as_() * half),
220 (input[1].as_()) + (input[3].as_() * half),
221 ]
222 }
223}
224
/// Affine quantization parameters of a tensor.
///
/// A raw value `q` maps to the real value `(q - zero_point) * scale`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Quantization {
    /// Multiplier applied after removing the zero point.
    pub scale: f32,
    /// Raw value that represents real zero.
    pub zero_point: i32,
}
231
232impl Quantization {
233 pub fn new(scale: f32, zero_point: i32) -> Self {
242 Self { scale, zero_point }
243 }
244}
245
246impl From<QuantTuple> for Quantization {
247 fn from(quant_tuple: QuantTuple) -> Quantization {
258 Quantization {
259 scale: quant_tuple.0,
260 zero_point: quant_tuple.1,
261 }
262 }
263}
264
265impl<S, Z> From<(S, Z)> for Quantization
266where
267 S: AsPrimitive<f32>,
268 Z: AsPrimitive<i32>,
269{
270 fn from((scale, zp): (S, Z)) -> Quantization {
279 Self {
280 scale: scale.as_(),
281 zero_point: zp.as_(),
282 }
283 }
284}
285
286impl Default for Quantization {
287 fn default() -> Self {
296 Self {
297 scale: 1.0,
298 zero_point: 0,
299 }
300 }
301}
302
303#[derive(Debug, Clone, Copy, PartialEq, Default)]
305pub struct DetectBox {
306 pub bbox: BoundingBox,
307 pub score: f32,
309 pub label: usize,
311}
312
/// An axis-aligned box stored by its corner coordinates.
///
/// [`BoundingBox::to_canonical`] orders the corners so that
/// `xmin <= xmax` and `ymin <= ymax`.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct BoundingBox {
    /// Left edge.
    pub xmin: f32,
    /// Top edge.
    pub ymin: f32,
    /// Right edge.
    pub xmax: f32,
    /// Bottom edge.
    pub ymax: f32,
}
325
326impl BoundingBox {
327 pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
329 Self {
330 xmin,
331 ymin,
332 xmax,
333 ymax,
334 }
335 }
336
337 pub fn to_canonical(&self) -> Self {
346 let xmin = self.xmin.min(self.xmax);
347 let xmax = self.xmin.max(self.xmax);
348 let ymin = self.ymin.min(self.ymax);
349 let ymax = self.ymin.max(self.ymax);
350 BoundingBox {
351 xmin,
352 ymin,
353 xmax,
354 ymax,
355 }
356 }
357}
358
359impl From<BoundingBox> for [f32; 4] {
360 fn from(b: BoundingBox) -> Self {
375 [b.xmin, b.ymin, b.xmax, b.ymax]
376 }
377}
378
379impl From<[f32; 4]> for BoundingBox {
380 fn from(arr: [f32; 4]) -> Self {
383 BoundingBox {
384 xmin: arr[0],
385 ymin: arr[1],
386 xmax: arr[2],
387 ymax: arr[3],
388 }
389 }
390}
391
392impl DetectBox {
393 pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
422 let eq_delta = |a: f32, b: f32| (a - b).abs() <= eps;
423 self.label == rhs.label
424 && eq_delta(self.score, rhs.score)
425 && eq_delta(self.bbox.xmin, rhs.bbox.xmin)
426 && eq_delta(self.bbox.ymin, rhs.bbox.ymin)
427 && eq_delta(self.bbox.xmax, rhs.bbox.xmax)
428 && eq_delta(self.bbox.ymax, rhs.bbox.ymax)
429 }
430}
431
432#[derive(Debug, Clone, PartialEq, Default)]
435pub struct Segmentation {
436 pub xmin: f32,
438 pub ymin: f32,
440 pub xmax: f32,
442 pub ymax: f32,
444 pub segmentation: Array3<u8>,
447}
448
449pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
467 detect: &DetectBoxQuantized<SCORE>,
468 quant_scores: Quantization,
469) -> DetectBox {
470 let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
471 DetectBox {
472 bbox: detect.bbox,
473 score: quant_scores.scale * detect.score.as_() + scaled_zp,
474 label: detect.label,
475 }
476}
477#[derive(Debug, Clone, Copy, PartialEq)]
479pub struct DetectBoxQuantized<
480 SCORE: PrimInt + AsPrimitive<f32>,
482> {
483 pub bbox: BoundingBox,
485 pub score: SCORE,
488 pub label: usize,
490}
491
492pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
505 input: ArrayView<T, D>,
506 quant: Quantization,
507) -> Array<F, D>
508where
509 i32: num_traits::AsPrimitive<F>,
510 f32: num_traits::AsPrimitive<F>,
511{
512 let zero_point = quant.zero_point.as_();
513 let scale = quant.scale.as_();
514 if zero_point != F::zero() {
515 let scaled_zero = -zero_point * scale;
516 input.mapv(|d| d.as_() * scale + scaled_zero)
517 } else {
518 input.mapv(|d| d.as_() * scale)
519 }
520}
521
522pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
535 input: &[T],
536 quant: Quantization,
537 output: &mut [F],
538) where
539 f32: num_traits::AsPrimitive<F>,
540 i32: num_traits::AsPrimitive<F>,
541{
542 assert!(input.len() == output.len());
543 let zero_point = quant.zero_point.as_();
544 let scale = quant.scale.as_();
545 if zero_point != F::zero() {
546 let scaled_zero = -zero_point * scale; input
548 .iter()
549 .zip(output)
550 .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
551 } else {
552 input
553 .iter()
554 .zip(output)
555 .for_each(|(d, deq)| *deq = d.as_() * scale);
556 }
557}
558
559pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
573 input: &[T],
574 quant: Quantization,
575 output: &mut [F],
576) where
577 f32: num_traits::AsPrimitive<F>,
578 i32: num_traits::AsPrimitive<F>,
579{
580 assert!(input.len() == output.len());
581 let zero_point = quant.zero_point.as_();
582 let scale = quant.scale.as_();
583
584 let input = input.as_chunks::<4>();
585 let output = output.as_chunks_mut::<4>();
586
587 if zero_point != F::zero() {
588 let scaled_zero = -zero_point * scale; input
591 .0
592 .iter()
593 .zip(output.0)
594 .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
595 input
596 .1
597 .iter()
598 .zip(output.1)
599 .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
600 } else {
601 input
602 .0
603 .iter()
604 .zip(output.0)
605 .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
606 input
607 .1
608 .iter()
609 .zip(output.1)
610 .for_each(|(d, deq)| *deq = d.as_() * scale);
611 }
612}
613
614pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
632 if segmentation.shape()[2] == 0 {
633 return Err(DecoderError::InvalidShape(
634 "Segmentation tensor must have non-zero depth".to_string(),
635 ));
636 }
637 if segmentation.shape()[2] == 1 {
638 yolo_segmentation_to_mask(segmentation, 128)
639 } else {
640 Ok(modelpack_segmentation_to_mask(segmentation))
641 }
642}
643
644fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
646 score
647 .iter()
648 .enumerate()
649 .fold((score[0], 0), |(max, arg_max), (ind, s)| {
650 if max > *s {
651 (max, arg_max)
652 } else {
653 (*s, ind)
654 }
655 })
656}
657#[cfg(test)]
658#[cfg_attr(coverage_nightly, coverage(off))]
659mod decoder_tests {
660 #![allow(clippy::excessive_precision)]
661 use crate::{
662 configs::{DecoderType, DimName, Protos},
663 modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
664 yolo::{
665 decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
666 decode_yolo_segdet_quant,
667 },
668 *,
669 };
670 use ndarray::{array, s, Array4};
671 use ndarray_stats::DeviationExt;
672
673 fn compare_outputs(
674 boxes: (&[DetectBox], &[DetectBox]),
675 masks: (&[Segmentation], &[Segmentation]),
676 ) {
677 let (boxes0, boxes1) = boxes;
678 let (masks0, masks1) = masks;
679
680 assert_eq!(boxes0.len(), boxes1.len());
681 assert_eq!(masks0.len(), masks1.len());
682
683 for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
684 assert!(
685 b_i8.equal_within_delta(b_f32, 1e-6),
686 "{b_i8:?} is not equal to {b_f32:?}"
687 );
688 }
689
690 for (m_i8, m_f32) in masks0.iter().zip(masks1) {
691 assert_eq!(
692 [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
693 [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
694 );
695 assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
696 let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
697 let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
698 let diff = &mask_i8 - &mask_f32;
699 for x in 0..diff.shape()[0] {
700 for y in 0..diff.shape()[1] {
701 for z in 0..diff.shape()[2] {
702 let val = diff[[x, y, z]];
703 assert!(
704 val.abs() <= 1,
705 "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
706 x,
707 y,
708 z,
709 val
710 );
711 }
712 }
713 }
714 let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
715 assert!(
716 mean_sq_err < 1e-2,
717 "Mean Square Error between masks was greater than 1%: {:.2}%",
718 mean_sq_err * 100.0
719 );
720 }
721 }
722
723 #[test]
724 fn test_decoder_modelpack() {
725 let score_threshold = 0.45;
726 let iou_threshold = 0.45;
727 let boxes = include_bytes!("../../../testdata/modelpack_boxes_1935x1x4.bin");
728 let boxes = ndarray::Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
729
730 let scores = include_bytes!("../../../testdata/modelpack_scores_1935x1.bin");
731 let scores = ndarray::Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
732
733 let quant_boxes = (0.004656755365431309, 21).into();
734 let quant_scores = (0.0019603664986789227, 0).into();
735
736 let decoder = DecoderBuilder::default()
737 .with_config_modelpack_det(
738 configs::Boxes {
739 decoder: DecoderType::ModelPack,
740 quantization: Some(quant_boxes),
741 shape: vec![1, 1935, 1, 4],
742 dshape: vec![
743 (DimName::Batch, 1),
744 (DimName::NumBoxes, 1935),
745 (DimName::Padding, 1),
746 (DimName::BoxCoords, 4),
747 ],
748 normalized: Some(true),
749 },
750 configs::Scores {
751 decoder: DecoderType::ModelPack,
752 quantization: Some(quant_scores),
753 shape: vec![1, 1935, 1],
754 dshape: vec![
755 (DimName::Batch, 1),
756 (DimName::NumBoxes, 1935),
757 (DimName::NumClasses, 1),
758 ],
759 },
760 )
761 .with_score_threshold(score_threshold)
762 .with_iou_threshold(iou_threshold)
763 .build()
764 .unwrap();
765
766 let quant_boxes = quant_boxes.into();
767 let quant_scores = quant_scores.into();
768
769 let mut output_boxes: Vec<_> = Vec::with_capacity(50);
770 decode_modelpack_det(
771 (boxes.slice(s![0, .., 0, ..]), quant_boxes),
772 (scores.slice(s![0, .., ..]), quant_scores),
773 score_threshold,
774 iou_threshold,
775 &mut output_boxes,
776 );
777 assert!(output_boxes[0].equal_within_delta(
778 &DetectBox {
779 bbox: BoundingBox {
780 xmin: 0.40513772,
781 ymin: 0.6379755,
782 xmax: 0.5122431,
783 ymax: 0.7730214,
784 },
785 score: 0.4861709,
786 label: 0
787 },
788 1e-6
789 ));
790
791 let mut output_boxes1 = Vec::with_capacity(50);
792 let mut output_masks1 = Vec::with_capacity(50);
793
794 decoder
795 .decode_quantized(
796 &[boxes.view().into(), scores.view().into()],
797 &mut output_boxes1,
798 &mut output_masks1,
799 )
800 .unwrap();
801
802 let mut output_boxes_float = Vec::with_capacity(50);
803 let mut output_masks_float = Vec::with_capacity(50);
804
805 let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
806 let scores = dequantize_ndarray(scores.view(), quant_scores);
807
808 decoder
809 .decode_float::<f32>(
810 &[boxes.view().into_dyn(), scores.view().into_dyn()],
811 &mut output_boxes_float,
812 &mut output_masks_float,
813 )
814 .unwrap();
815
816 compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
817 compare_outputs(
818 (&output_boxes, &output_boxes_float),
819 (&[], &output_masks_float),
820 );
821 }
822
823 #[test]
824 fn test_decoder_modelpack_split_u8() {
825 let score_threshold = 0.45;
826 let iou_threshold = 0.45;
827 let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
828 let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
829
830 let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
831 let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
832
833 let quant0 = (0.08547406643629074, 174).into();
834 let quant1 = (0.09929127991199493, 183).into();
835 let anchors0 = vec![
836 [0.36666667461395264, 0.31481480598449707],
837 [0.38749998807907104, 0.4740740656852722],
838 [0.5333333611488342, 0.644444465637207],
839 ];
840 let anchors1 = vec![
841 [0.13750000298023224, 0.2074074000120163],
842 [0.2541666626930237, 0.21481481194496155],
843 [0.23125000298023224, 0.35185185074806213],
844 ];
845
846 let detect_config0 = configs::Detection {
847 decoder: DecoderType::ModelPack,
848 shape: vec![1, 9, 15, 18],
849 anchors: Some(anchors0.clone()),
850 quantization: Some(quant0),
851 dshape: vec![
852 (DimName::Batch, 1),
853 (DimName::Height, 9),
854 (DimName::Width, 15),
855 (DimName::NumAnchorsXFeatures, 18),
856 ],
857 normalized: Some(true),
858 };
859
860 let detect_config1 = configs::Detection {
861 decoder: DecoderType::ModelPack,
862 shape: vec![1, 17, 30, 18],
863 anchors: Some(anchors1.clone()),
864 quantization: Some(quant1),
865 dshape: vec![
866 (DimName::Batch, 1),
867 (DimName::Height, 17),
868 (DimName::Width, 30),
869 (DimName::NumAnchorsXFeatures, 18),
870 ],
871 normalized: Some(true),
872 };
873
874 let config0 = (&detect_config0).try_into().unwrap();
875 let config1 = (&detect_config1).try_into().unwrap();
876
877 let decoder = DecoderBuilder::default()
878 .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
879 .with_score_threshold(score_threshold)
880 .with_iou_threshold(iou_threshold)
881 .build()
882 .unwrap();
883
884 let quant0 = quant0.into();
885 let quant1 = quant1.into();
886
887 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
888 decode_modelpack_split_quant(
889 &[
890 detect0.slice(s![0, .., .., ..]),
891 detect1.slice(s![0, .., .., ..]),
892 ],
893 &[config0, config1],
894 score_threshold,
895 iou_threshold,
896 &mut output_boxes,
897 );
898 assert!(output_boxes[0].equal_within_delta(
899 &DetectBox {
900 bbox: BoundingBox {
901 xmin: 0.43171933,
902 ymin: 0.68243736,
903 xmax: 0.5626645,
904 ymax: 0.808863,
905 },
906 score: 0.99240804,
907 label: 0
908 },
909 1e-6
910 ));
911
912 let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
913 let mut output_masks1: Vec<_> = Vec::with_capacity(10);
914 decoder
915 .decode_quantized(
916 &[detect0.view().into(), detect1.view().into()],
917 &mut output_boxes1,
918 &mut output_masks1,
919 )
920 .unwrap();
921
922 let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
923 let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);
924
925 let detect0 = dequantize_ndarray(detect0.view(), quant0);
926 let detect1 = dequantize_ndarray(detect1.view(), quant1);
927 decoder
928 .decode_float::<f32>(
929 &[detect0.view().into_dyn(), detect1.view().into_dyn()],
930 &mut output_boxes1_f32,
931 &mut output_masks1_f32,
932 )
933 .unwrap();
934
935 compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
936 compare_outputs(
937 (&output_boxes, &output_boxes1_f32),
938 (&[], &output_masks1_f32),
939 );
940 }
941
942 #[test]
943 fn test_decoder_parse_config_modelpack_split_u8() {
944 let score_threshold = 0.45;
945 let iou_threshold = 0.45;
946 let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
947 let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
948
949 let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
950 let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
951
952 let decoder = DecoderBuilder::default()
953 .with_config_yaml_str(
954 include_str!("../../../testdata/modelpack_split.yaml").to_string(),
955 )
956 .with_score_threshold(score_threshold)
957 .with_iou_threshold(iou_threshold)
958 .build()
959 .unwrap();
960
961 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
962 let mut output_masks: Vec<_> = Vec::with_capacity(10);
963 decoder
964 .decode_quantized(
965 &[
966 ArrayViewDQuantized::from(detect1.view()),
967 ArrayViewDQuantized::from(detect0.view()),
968 ],
969 &mut output_boxes,
970 &mut output_masks,
971 )
972 .unwrap();
973 assert!(output_boxes[0].equal_within_delta(
974 &DetectBox {
975 bbox: BoundingBox {
976 xmin: 0.43171933,
977 ymin: 0.68243736,
978 xmax: 0.5626645,
979 ymax: 0.808863,
980 },
981 score: 0.99240804,
982 label: 0
983 },
984 1e-6
985 ));
986 }
987
988 #[test]
989 fn test_modelpack_seg() {
990 let out = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
991 let out = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
992 let quant = (1.0 / 255.0, 0).into();
993
994 let decoder = DecoderBuilder::default()
995 .with_config_modelpack_seg(configs::Segmentation {
996 decoder: DecoderType::ModelPack,
997 quantization: Some(quant),
998 shape: vec![1, 2, 160, 160],
999 dshape: vec![
1000 (DimName::Batch, 1),
1001 (DimName::NumClasses, 2),
1002 (DimName::Height, 160),
1003 (DimName::Width, 160),
1004 ],
1005 })
1006 .build()
1007 .unwrap();
1008 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1009 let mut output_masks: Vec<_> = Vec::with_capacity(10);
1010 decoder
1011 .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
1012 .unwrap();
1013
1014 let mut mask = out.slice(s![0, .., .., ..]);
1015 mask.swap_axes(0, 1);
1016 mask.swap_axes(1, 2);
1017 let mask = [Segmentation {
1018 xmin: 0.0,
1019 ymin: 0.0,
1020 xmax: 1.0,
1021 ymax: 1.0,
1022 segmentation: mask.into_owned(),
1023 }];
1024 compare_outputs((&[], &output_boxes), (&mask, &output_masks));
1025
1026 decoder
1027 .decode_float::<f32>(
1028 &[dequantize_ndarray(out.view(), quant.into())
1029 .view()
1030 .into_dyn()],
1031 &mut output_boxes,
1032 &mut output_masks,
1033 )
1034 .unwrap();
1035
1036 compare_outputs((&[], &output_boxes), (&[], &[]));
1042 let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1043 let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1044
1045 assert_eq!(mask0, mask1);
1046 }
1047 #[test]
1048 fn test_modelpack_seg_quant() {
1049 let out = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
1050 let out_u8 = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1051 let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
1052 let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
1053 let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
1054 let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
1055 let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);
1056
1057 let quant = (1.0 / 255.0, 0).into();
1058
1059 let decoder = DecoderBuilder::default()
1060 .with_config_modelpack_seg(configs::Segmentation {
1061 decoder: DecoderType::ModelPack,
1062 quantization: Some(quant),
1063 shape: vec![1, 2, 160, 160],
1064 dshape: vec![
1065 (DimName::Batch, 1),
1066 (DimName::NumClasses, 2),
1067 (DimName::Height, 160),
1068 (DimName::Width, 160),
1069 ],
1070 })
1071 .build()
1072 .unwrap();
1073 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1074 let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
1075 decoder
1076 .decode_quantized(
1077 &[out_u8.view().into()],
1078 &mut output_boxes,
1079 &mut output_masks_u8,
1080 )
1081 .unwrap();
1082
1083 let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
1084 decoder
1085 .decode_quantized(
1086 &[out_i8.view().into()],
1087 &mut output_boxes,
1088 &mut output_masks_i8,
1089 )
1090 .unwrap();
1091
1092 let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
1093 decoder
1094 .decode_quantized(
1095 &[out_u16.view().into()],
1096 &mut output_boxes,
1097 &mut output_masks_u16,
1098 )
1099 .unwrap();
1100
1101 let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
1102 decoder
1103 .decode_quantized(
1104 &[out_i16.view().into()],
1105 &mut output_boxes,
1106 &mut output_masks_i16,
1107 )
1108 .unwrap();
1109
1110 let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
1111 decoder
1112 .decode_quantized(
1113 &[out_u32.view().into()],
1114 &mut output_boxes,
1115 &mut output_masks_u32,
1116 )
1117 .unwrap();
1118
1119 let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
1120 decoder
1121 .decode_quantized(
1122 &[out_i32.view().into()],
1123 &mut output_boxes,
1124 &mut output_masks_i32,
1125 )
1126 .unwrap();
1127
1128 compare_outputs((&[], &output_boxes), (&[], &[]));
1129 let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
1130 let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
1131 let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
1132 let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
1133 let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
1134 let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();
1135 assert_eq!(mask_u8, mask_i8);
1136 assert_eq!(mask_u8, mask_u16);
1137 assert_eq!(mask_u8, mask_i16);
1138 assert_eq!(mask_u8, mask_u32);
1139 assert_eq!(mask_u8, mask_i32);
1140 }
1141
1142 #[test]
1143 fn test_modelpack_segdet() {
1144 let score_threshold = 0.45;
1145 let iou_threshold = 0.45;
1146
1147 let boxes = include_bytes!("../../../testdata/modelpack_boxes_1935x1x4.bin");
1148 let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
1149
1150 let scores = include_bytes!("../../../testdata/modelpack_scores_1935x1.bin");
1151 let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
1152
1153 let seg = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
1154 let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1155
1156 let quant_boxes = (0.004656755365431309, 21).into();
1157 let quant_scores = (0.0019603664986789227, 0).into();
1158 let quant_seg = (1.0 / 255.0, 0).into();
1159
1160 let decoder = DecoderBuilder::default()
1161 .with_config_modelpack_segdet(
1162 configs::Boxes {
1163 decoder: DecoderType::ModelPack,
1164 quantization: Some(quant_boxes),
1165 shape: vec![1, 1935, 1, 4],
1166 dshape: vec![
1167 (DimName::Batch, 1),
1168 (DimName::NumBoxes, 1935),
1169 (DimName::Padding, 1),
1170 (DimName::BoxCoords, 4),
1171 ],
1172 normalized: Some(true),
1173 },
1174 configs::Scores {
1175 decoder: DecoderType::ModelPack,
1176 quantization: Some(quant_scores),
1177 shape: vec![1, 1935, 1],
1178 dshape: vec![
1179 (DimName::Batch, 1),
1180 (DimName::NumBoxes, 1935),
1181 (DimName::NumClasses, 1),
1182 ],
1183 },
1184 configs::Segmentation {
1185 decoder: DecoderType::ModelPack,
1186 quantization: Some(quant_seg),
1187 shape: vec![1, 2, 160, 160],
1188 dshape: vec![
1189 (DimName::Batch, 1),
1190 (DimName::NumClasses, 2),
1191 (DimName::Height, 160),
1192 (DimName::Width, 160),
1193 ],
1194 },
1195 )
1196 .with_iou_threshold(iou_threshold)
1197 .with_score_threshold(score_threshold)
1198 .build()
1199 .unwrap();
1200 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1201 let mut output_masks: Vec<_> = Vec::with_capacity(10);
1202 decoder
1203 .decode_quantized(
1204 &[scores.view().into(), boxes.view().into(), seg.view().into()],
1205 &mut output_boxes,
1206 &mut output_masks,
1207 )
1208 .unwrap();
1209
1210 let mut mask = seg.slice(s![0, .., .., ..]);
1211 mask.swap_axes(0, 1);
1212 mask.swap_axes(1, 2);
1213 let mask = [Segmentation {
1214 xmin: 0.0,
1215 ymin: 0.0,
1216 xmax: 1.0,
1217 ymax: 1.0,
1218 segmentation: mask.into_owned(),
1219 }];
1220 let correct_boxes = [DetectBox {
1221 bbox: BoundingBox {
1222 xmin: 0.40513772,
1223 ymin: 0.6379755,
1224 xmax: 0.5122431,
1225 ymax: 0.7730214,
1226 },
1227 score: 0.4861709,
1228 label: 0,
1229 }];
1230 compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1231
1232 let scores = dequantize_ndarray(scores.view(), quant_scores.into());
1233 let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
1234 let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1235 decoder
1236 .decode_float::<f32>(
1237 &[
1238 scores.view().into_dyn(),
1239 boxes.view().into_dyn(),
1240 seg.view().into_dyn(),
1241 ],
1242 &mut output_boxes,
1243 &mut output_masks,
1244 )
1245 .unwrap();
1246
1247 compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1253 let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1254 let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1255
1256 assert_eq!(mask0, mask1);
1257 }
1258
1259 #[test]
1260 fn test_modelpack_segdet_split() {
1261 let score_threshold = 0.8;
1262 let iou_threshold = 0.5;
1263
1264 let seg = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
1265 let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1266
1267 let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
1268 let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1269
1270 let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
1271 let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1272
1273 let quant0 = (0.08547406643629074, 174).into();
1274 let quant1 = (0.09929127991199493, 183).into();
1275 let quant_seg = (1.0 / 255.0, 0).into();
1276
1277 let anchors0 = vec![
1278 [0.36666667461395264, 0.31481480598449707],
1279 [0.38749998807907104, 0.4740740656852722],
1280 [0.5333333611488342, 0.644444465637207],
1281 ];
1282 let anchors1 = vec![
1283 [0.13750000298023224, 0.2074074000120163],
1284 [0.2541666626930237, 0.21481481194496155],
1285 [0.23125000298023224, 0.35185185074806213],
1286 ];
1287
1288 let decoder = DecoderBuilder::default()
1289 .with_config_modelpack_segdet_split(
1290 vec![
1291 configs::Detection {
1292 decoder: DecoderType::ModelPack,
1293 shape: vec![1, 17, 30, 18],
1294 anchors: Some(anchors1),
1295 quantization: Some(quant1),
1296 dshape: vec![
1297 (DimName::Batch, 1),
1298 (DimName::Height, 17),
1299 (DimName::Width, 30),
1300 (DimName::NumAnchorsXFeatures, 18),
1301 ],
1302 normalized: Some(true),
1303 },
1304 configs::Detection {
1305 decoder: DecoderType::ModelPack,
1306 shape: vec![1, 9, 15, 18],
1307 anchors: Some(anchors0),
1308 quantization: Some(quant0),
1309 dshape: vec![
1310 (DimName::Batch, 1),
1311 (DimName::Height, 9),
1312 (DimName::Width, 15),
1313 (DimName::NumAnchorsXFeatures, 18),
1314 ],
1315 normalized: Some(true),
1316 },
1317 ],
1318 configs::Segmentation {
1319 decoder: DecoderType::ModelPack,
1320 quantization: Some(quant_seg),
1321 shape: vec![1, 2, 160, 160],
1322 dshape: vec![
1323 (DimName::Batch, 1),
1324 (DimName::NumClasses, 2),
1325 (DimName::Height, 160),
1326 (DimName::Width, 160),
1327 ],
1328 },
1329 )
1330 .with_score_threshold(score_threshold)
1331 .with_iou_threshold(iou_threshold)
1332 .build()
1333 .unwrap();
1334 let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1335 let mut output_masks: Vec<_> = Vec::with_capacity(10);
1336 decoder
1337 .decode_quantized(
1338 &[
1339 detect0.view().into(),
1340 detect1.view().into(),
1341 seg.view().into(),
1342 ],
1343 &mut output_boxes,
1344 &mut output_masks,
1345 )
1346 .unwrap();
1347
1348 let mut mask = seg.slice(s![0, .., .., ..]);
1349 mask.swap_axes(0, 1);
1350 mask.swap_axes(1, 2);
1351 let mask = [Segmentation {
1352 xmin: 0.0,
1353 ymin: 0.0,
1354 xmax: 1.0,
1355 ymax: 1.0,
1356 segmentation: mask.into_owned(),
1357 }];
1358 let correct_boxes = [DetectBox {
1359 bbox: BoundingBox {
1360 xmin: 0.43171933,
1361 ymin: 0.68243736,
1362 xmax: 0.5626645,
1363 ymax: 0.808863,
1364 },
1365 score: 0.99240804,
1366 label: 0,
1367 }];
1368 println!("Output Boxes: {:?}", output_boxes);
1369 compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1370
1371 let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
1372 let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
1373 let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1374 decoder
1375 .decode_float::<f32>(
1376 &[
1377 detect0.view().into_dyn(),
1378 detect1.view().into_dyn(),
1379 seg.view().into_dyn(),
1380 ],
1381 &mut output_boxes,
1382 &mut output_masks,
1383 )
1384 .unwrap();
1385
1386 compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1392 let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1393 let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1394
1395 assert_eq!(mask0, mask1);
1396 }
1397
1398 #[test]
1399 fn test_dequant_chunked() {
1400 let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
1401 let mut out =
1402 unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) }.to_vec();
1403 out.push(123); let mut out_dequant = vec![0.0; 84 * 8400 + 1];
1406 let mut out_dequant_simd = vec![0.0; 84 * 8400 + 1];
1407 let quant = Quantization::new(0.0040811873, -123);
1408 dequantize_cpu(&out, quant, &mut out_dequant);
1409
1410 dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1411 assert_eq!(out_dequant, out_dequant_simd);
1412
1413 let quant = Quantization::new(0.0040811873, 0);
1414 dequantize_cpu(&out, quant, &mut out_dequant);
1415
1416 dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1417 assert_eq!(out_dequant, out_dequant_simd);
1418 }
1419
    /// End-to-end check of the Ultralytics YOLO detection path on a stored int8
    /// YOLOv8s output tensor of shape (1, 84, 8400).
    ///
    /// The tensor is decoded three ways:
    /// 1. the free function `decode_yolo_det` on the quantized data, checked
    ///    against two hard-coded reference boxes;
    /// 2. a configured `Decoder` via `decode_quantized`;
    /// 3. the same `Decoder` via `decode_float` on a dequantized copy.
    ///
    /// Paths 2 and 3 are then compared against path 1 with `compare_outputs`.
    #[test]
    fn test_decoder_yolo_det() {
        let score_threshold = 0.25;
        let iou_threshold = 0.7;
        let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
        // SAFETY: `u8` and `i8` have identical size and alignment, and the
        // slice borrows the `'static` data produced by `include_bytes!`.
        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
        let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
        // (scale, zero_point); the target type is inferred from its use in
        // `configs::Detection::quantization` below.
        let quant = (0.0040811873, -123).into();

        let decoder = DecoderBuilder::default()
            .with_config_yolo_det(
                configs::Detection {
                    decoder: DecoderType::Ultralytics,
                    shape: vec![1, 84, 8400],
                    anchors: None,
                    quantization: Some(quant),
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumFeatures, 84),
                        (DimName::NumBoxes, 8400),
                    ],
                    normalized: Some(true),
                },
                Some(DecoderVersion::Yolo11),
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        // Path 1: free function on the single batch item.
        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
        decode_yolo_det(
            (out.slice(s![0, .., ..]), quant.into()),
            score_threshold,
            iou_threshold,
            Some(configs::Nms::ClassAgnostic),
            &mut output_boxes,
        );
        // Indexing assumes at least two detections survive the thresholds.
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.5285137,
                    ymin: 0.05305544,
                    xmax: 0.87541467,
                    ymax: 0.9998909,
                },
                score: 0.5591227,
                label: 0
            },
            1e-6
        ));

        assert!(output_boxes[1].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.130598,
                    ymin: 0.43260583,
                    xmax: 0.35098213,
                    ymax: 0.9958097,
                },
                score: 0.33057618,
                label: 75
            },
            1e-6
        ));

        // Path 2: configured decoder on the quantized tensor.
        let mut output_boxes1: Vec<_> = Vec::with_capacity(50);
        let mut output_masks1: Vec<_> = Vec::with_capacity(50);
        decoder
            .decode_quantized(&[out.view().into()], &mut output_boxes1, &mut output_masks1)
            .unwrap();

        // Path 3: configured decoder on the dequantized (f32) tensor.
        let out = dequantize_ndarray(out.view(), quant.into());
        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(50);
        let mut output_masks_f32: Vec<_> = Vec::with_capacity(50);
        decoder
            .decode_float::<f32>(
                &[out.view().into_dyn()],
                &mut output_boxes_f32,
                &mut output_masks_f32,
            )
            .unwrap();

        // Detection-only model: the reference mask list is empty.
        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
        compare_outputs((&output_boxes, &output_boxes_f32), (&[], &output_masks_f32));
    }
1506
    /// Float YOLOv8 segmentation-detection decode on stored fixtures.
    ///
    /// Inputs are an int8 (116, 8400) box/score/mask-coefficient tensor and an
    /// int8 (160, 160, 32) prototype tensor. Both are dequantized and decoded
    /// with `decode_yolo_segdet_float`. The test checks:
    /// - exactly two detections;
    /// - each mask's crop bounds relative to its box;
    /// - both boxes against reference values;
    /// - the second mask against a stored full-resolution reference mask.
    #[test]
    fn test_decoder_masks() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let boxes = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
        // SAFETY: `u8` and `i8` have identical size and alignment, and the
        // slice borrows the `'static` data produced by `include_bytes!`.
        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
        let boxes = ndarray::Array2::from_shape_vec((116, 8400), boxes.to_vec()).unwrap();
        let quant_boxes = Quantization::new(0.021287761628627777, 31);

        let protos = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin");
        // SAFETY: same reinterpretation of `'static` bytes as above.
        let protos =
            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
        let protos = ndarray::Array3::from_shape_vec((160, 160, 32), protos.to_vec()).unwrap();
        let quant_protos = Quantization::new(0.02491161972284317, -117);
        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);
        decode_yolo_segdet_float(
            seg.view(),
            protos.view(),
            score_threshold,
            iou_threshold,
            Some(configs::Nms::ClassAgnostic),
            &mut output_boxes,
            &mut output_masks,
        );
        assert_eq!(output_boxes.len(), 2);
        assert_eq!(output_boxes.len(), output_masks.len());

        // Both the min and the max edges of each mask crop must lie at or before
        // the corresponding box edge. NOTE(review): this is consistent with both
        // edges being rounded down onto the proto grid; confirm that this is the
        // intended invariant.
        for (b, m) in output_boxes.iter().zip(&output_masks) {
            assert!(b.bbox.xmin >= m.xmin);
            assert!(b.bbox.ymin >= m.ymin);
            assert!(b.bbox.xmax >= m.xmax);
            assert!(b.bbox.ymax >= m.ymax);
        }
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.08515105,
                    ymin: 0.7131401,
                    xmax: 0.29802868,
                    ymax: 0.8195788,
                },
                score: 0.91537374,
                label: 23
            },
            1.0 / 160.0, // tolerance: one cell of the 160x160 proto grid
        ));

        assert!(output_boxes[1].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.59605736,
                    ymin: 0.25545314,
                    xmax: 0.93666154,
                    ymax: 0.72378385,
                },
                score: 0.91537374,
                label: 23
            },
            1.0 / 160.0, // tolerance: one cell of the 160x160 proto grid
        ));

        // Reference mask at proto resolution (160x160). Crop it to the second
        // mask's normalized bounds scaled to grid indices, then compare it with
        // the decoded mask.
        let full_mask = include_bytes!("../../../testdata/yolov8_mask_results.bin");
        let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();

        let cropped_mask = full_mask.slice(ndarray::s![
            (output_masks[1].ymin * 160.0) as usize..(output_masks[1].ymax * 160.0) as usize,
            (output_masks[1].xmin * 160.0) as usize..(output_masks[1].xmax * 160.0) as usize,
        ]);

        assert_eq!(
            cropped_mask,
            segmentation_to_mask(output_masks[1].segmentation.view()).unwrap()
        );
    }
1584
    /// YOLO segmentation-detection with int8 inputs, decoded four ways and
    /// cross-checked:
    /// - quantized free function `decode_yolo_segdet_quant`;
    /// - configured `Decoder::decode_quantized`;
    /// - float free function `decode_yolo_segdet_float` on dequantized copies;
    /// - configured `Decoder::decode_float` on the same dequantized copies.
    #[test]
    fn test_decoder_masks_i8() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let boxes = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
        // SAFETY: `u8` and `i8` have identical size and alignment, and the
        // slice borrows the `'static` data produced by `include_bytes!`.
        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
        // (scale, zero_point); the type is inferred from the config field below.
        let quant_boxes = (0.021287761628627777, 31).into();

        let protos = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin");
        // SAFETY: same reinterpretation of `'static` bytes as above.
        let protos =
            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
        let quant_protos = (0.02491161972284317, -117).into();
        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
        let mut output_masks: Vec<_> = Vec::with_capacity(500);

        let decoder = DecoderBuilder::default()
            .with_config_yolo_segdet(
                configs::Detection {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: Some(quant_boxes),
                    shape: vec![1, 116, 8400],
                    anchors: None,
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumFeatures, 116),
                        (DimName::NumBoxes, 8400),
                    ],
                    normalized: Some(true),
                },
                Protos {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: Some(quant_protos),
                    shape: vec![1, 160, 160, 32],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::Height, 160),
                        (DimName::Width, 160),
                        (DimName::NumProtos, 32),
                    ],
                },
                Some(DecoderVersion::Yolo11),
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        // Re-bind the config-level quantization tuples in the form the free
        // functions take. The target type is inferred from the calls below
        // (presumably `Quantization`).
        let quant_boxes = quant_boxes.into();
        let quant_protos = quant_protos.into();

        decode_yolo_segdet_quant(
            (boxes.slice(s![0, .., ..]), quant_boxes),
            (protos.slice(s![0, .., .., ..]), quant_protos),
            score_threshold,
            iou_threshold,
            Some(configs::Nms::ClassAgnostic),
            &mut output_boxes,
            &mut output_masks,
        );

        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
        let mut output_masks1: Vec<_> = Vec::with_capacity(500);

        decoder
            .decode_quantized(
                &[boxes.view().into(), protos.view().into()],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();

        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);

        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
        decode_yolo_segdet_float(
            seg.slice(s![0, .., ..]),
            protos.slice(s![0, .., .., ..]),
            score_threshold,
            iou_threshold,
            Some(configs::Nms::ClassAgnostic),
            &mut output_boxes_f32,
            &mut output_masks_f32,
        );

        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);

        decoder
            .decode_float(
                &[seg.view().into_dyn(), protos.view().into_dyn()],
                &mut output_boxes1_f32,
                &mut output_masks1_f32,
            )
            .unwrap();

        // quantized free fn vs quantized decoder
        compare_outputs(
            (&output_boxes, &output_boxes1),
            (&output_masks, &output_masks1),
        );

        // quantized free fn vs float free fn
        compare_outputs(
            (&output_boxes, &output_boxes_f32),
            (&output_masks, &output_masks_f32),
        );

        // float free fn vs float decoder
        compare_outputs(
            (&output_boxes_f32, &output_boxes1_f32),
            (&output_masks_f32, &output_masks1_f32),
        );
    }
1699
    /// Split-output YOLO detection with int16 quantized tensors.
    ///
    /// The int8 fixture is widened to i16 and multiplied by 256. The
    /// quantization is rescaled to match (scale / 256, zero point * 256), so
    /// the data is intended to dequantize to the same values as the int8
    /// original. The 116-feature tensor is sliced into separate box outputs
    /// (features 0..4) and score outputs (features 4..84). The remaining
    /// features are not used by this detection-only test.
    #[test]
    fn test_decoder_yolo_split() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let boxes = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
        // SAFETY: `u8` and `i8` have identical size and alignment, and the
        // slice borrows the `'static` data produced by `include_bytes!`.
        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
        // i8 * 256 spans -32768..=32512, so this cannot overflow i16.
        let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();

        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);

        // Boxes and scores come from the same tensor, so they share one quantization.
        let decoder = DecoderBuilder::default()
            .with_config_yolo_split_det(
                configs::Boxes {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                    shape: vec![1, 4, 8400],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::BoxCoords, 4),
                        (DimName::NumBoxes, 8400),
                    ],
                    normalized: Some(true),
                },
                configs::Scores {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                    shape: vec![1, 80, 8400],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumClasses, 80),
                        (DimName::NumBoxes, 8400),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
        let mut output_masks: Vec<_> = Vec::with_capacity(500);

        decoder
            .decode_quantized(
                &[
                    boxes.slice(s![.., ..4, ..]).into(),
                    boxes.slice(s![.., 4..84, ..]).into(),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Reference: combined float decode over the first 84 features.
        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
        decode_yolo_det_float(
            seg.slice(s![0, ..84, ..]),
            score_threshold,
            iou_threshold,
            Some(configs::Nms::ClassAgnostic),
            &mut output_boxes_f32,
        );

        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
        let mut output_masks1: Vec<_> = Vec::with_capacity(500);

        decoder
            .decode_float(
                &[
                    seg.slice(s![.., ..4, ..]).into_dyn(),
                    seg.slice(s![.., 4..84, ..]).into_dyn(),
                ],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();
        // No masks are expected from a detection-only configuration.
        compare_outputs((&output_boxes, &output_boxes_f32), (&output_masks, &[]));
        compare_outputs((&output_boxes_f32, &output_boxes1), (&[], &output_masks1));
    }
1780
1781 #[test]
1782 fn test_decoder_masks_config_mixed() {
1783 let score_threshold = 0.45;
1784 let iou_threshold = 0.45;
1785 let boxes = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
1786 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1787 let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
1788 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
1789
1790 let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
1791
1792 let protos = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin");
1793 let protos =
1794 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1795 let protos: Vec<_> = protos.to_vec();
1796 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
1797 let quant_protos = Quantization::new(0.02491161972284317, -117);
1798
1799 let decoder = DecoderBuilder::default()
1800 .with_config_yolo_split_segdet(
1801 configs::Boxes {
1802 decoder: configs::DecoderType::Ultralytics,
1803 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1804 shape: vec![1, 4, 8400],
1805 dshape: vec![
1806 (DimName::Batch, 1),
1807 (DimName::BoxCoords, 4),
1808 (DimName::NumBoxes, 8400),
1809 ],
1810 normalized: Some(true),
1811 },
1812 configs::Scores {
1813 decoder: configs::DecoderType::Ultralytics,
1814 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1815 shape: vec![1, 80, 8400],
1816 dshape: vec![
1817 (DimName::Batch, 1),
1818 (DimName::NumClasses, 80),
1819 (DimName::NumBoxes, 8400),
1820 ],
1821 },
1822 configs::MaskCoefficients {
1823 decoder: configs::DecoderType::Ultralytics,
1824 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1825 shape: vec![1, 32, 8400],
1826 dshape: vec![
1827 (DimName::Batch, 1),
1828 (DimName::NumProtos, 32),
1829 (DimName::NumBoxes, 8400),
1830 ],
1831 },
1832 configs::Protos {
1833 decoder: configs::DecoderType::Ultralytics,
1834 quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
1835 shape: vec![1, 160, 160, 32],
1836 dshape: vec![
1837 (DimName::Batch, 1),
1838 (DimName::Height, 160),
1839 (DimName::Width, 160),
1840 (DimName::NumProtos, 32),
1841 ],
1842 },
1843 )
1844 .with_score_threshold(score_threshold)
1845 .with_iou_threshold(iou_threshold)
1846 .build()
1847 .unwrap();
1848
1849 let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1850 let mut output_masks: Vec<_> = Vec::with_capacity(500);
1851
1852 decoder
1853 .decode_quantized(
1854 &[
1855 boxes.slice(s![.., ..4, ..]).into(),
1856 boxes.slice(s![.., 4..84, ..]).into(),
1857 boxes.slice(s![.., 84.., ..]).into(),
1858 protos.view().into(),
1859 ],
1860 &mut output_boxes,
1861 &mut output_masks,
1862 )
1863 .unwrap();
1864
1865 let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1866 let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1867 let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
1868 let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
1869 decode_yolo_segdet_float(
1870 seg.slice(s![0, .., ..]),
1871 protos.slice(s![0, .., .., ..]),
1872 score_threshold,
1873 iou_threshold,
1874 Some(configs::Nms::ClassAgnostic),
1875 &mut output_boxes_f32,
1876 &mut output_masks_f32,
1877 );
1878
1879 let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
1880 let mut output_masks1: Vec<_> = Vec::with_capacity(500);
1881
1882 decoder
1883 .decode_float(
1884 &[
1885 seg.slice(s![.., ..4, ..]).into_dyn(),
1886 seg.slice(s![.., 4..84, ..]).into_dyn(),
1887 seg.slice(s![.., 84.., ..]).into_dyn(),
1888 protos.view().into_dyn(),
1889 ],
1890 &mut output_boxes1,
1891 &mut output_masks1,
1892 )
1893 .unwrap();
1894 compare_outputs(
1895 (&output_boxes, &output_boxes_f32),
1896 (&output_masks, &output_masks_f32),
1897 );
1898 compare_outputs(
1899 (&output_boxes_f32, &output_boxes1),
1900 (&output_masks_f32, &output_masks1),
1901 );
1902 }
1903
1904 #[test]
1905 fn test_decoder_masks_config_i32() {
1906 let score_threshold = 0.45;
1907 let iou_threshold = 0.45;
1908 let boxes = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
1909 let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1910 let scale = 1 << 23;
1911 let boxes: Vec<_> = boxes.iter().map(|x| *x as i32 * scale).collect();
1912 let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
1913
1914 let quant_boxes = Quantization::new(0.021287761628627777 / scale as f32, 31 * scale);
1915
1916 let protos = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin");
1917 let protos =
1918 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1919 let protos: Vec<_> = protos.iter().map(|x| *x as i32 * scale).collect();
1920 let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
1921 let quant_protos = Quantization::new(0.02491161972284317 / scale as f32, -117 * scale);
1922
1923 let decoder = DecoderBuilder::default()
1924 .with_config_yolo_split_segdet(
1925 configs::Boxes {
1926 decoder: configs::DecoderType::Ultralytics,
1927 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1928 shape: vec![1, 4, 8400],
1929 dshape: vec![
1930 (DimName::Batch, 1),
1931 (DimName::BoxCoords, 4),
1932 (DimName::NumBoxes, 8400),
1933 ],
1934 normalized: Some(true),
1935 },
1936 configs::Scores {
1937 decoder: configs::DecoderType::Ultralytics,
1938 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1939 shape: vec![1, 80, 8400],
1940 dshape: vec![
1941 (DimName::Batch, 1),
1942 (DimName::NumClasses, 80),
1943 (DimName::NumBoxes, 8400),
1944 ],
1945 },
1946 configs::MaskCoefficients {
1947 decoder: configs::DecoderType::Ultralytics,
1948 quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1949 shape: vec![1, 32, 8400],
1950 dshape: vec![
1951 (DimName::Batch, 1),
1952 (DimName::NumProtos, 32),
1953 (DimName::NumBoxes, 8400),
1954 ],
1955 },
1956 configs::Protos {
1957 decoder: configs::DecoderType::Ultralytics,
1958 quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
1959 shape: vec![1, 160, 160, 32],
1960 dshape: vec![
1961 (DimName::Batch, 1),
1962 (DimName::Height, 160),
1963 (DimName::Width, 160),
1964 (DimName::NumProtos, 32),
1965 ],
1966 },
1967 )
1968 .with_score_threshold(score_threshold)
1969 .with_iou_threshold(iou_threshold)
1970 .build()
1971 .unwrap();
1972
1973 let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1974 let mut output_masks: Vec<_> = Vec::with_capacity(500);
1975
1976 decoder
1977 .decode_quantized(
1978 &[
1979 boxes.slice(s![.., ..4, ..]).into(),
1980 boxes.slice(s![.., 4..84, ..]).into(),
1981 boxes.slice(s![.., 84.., ..]).into(),
1982 protos.view().into(),
1983 ],
1984 &mut output_boxes,
1985 &mut output_masks,
1986 )
1987 .unwrap();
1988
1989 let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1990 let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1991 let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
1992 let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
1993 decode_yolo_segdet_float(
1994 seg.slice(s![0, .., ..]),
1995 protos.slice(s![0, .., .., ..]),
1996 score_threshold,
1997 iou_threshold,
1998 Some(configs::Nms::ClassAgnostic),
1999 &mut output_boxes_f32,
2000 &mut output_masks_f32,
2001 );
2002
2003 assert_eq!(output_boxes.len(), output_boxes_f32.len());
2004 assert_eq!(output_masks.len(), output_masks_f32.len());
2005
2006 compare_outputs(
2007 (&output_boxes, &output_boxes_f32),
2008 (&output_masks, &output_masks_f32),
2009 );
2010 }
2011
    /// Stress test: runs a YOLO detection decoder and a ModelPack split
    /// segmentation-detection decoder concurrently on eight OS threads. Each
    /// thread decodes the same fixtures 100 times and checks the results on
    /// every iteration.
    ///
    /// Every thread builds its own `Decoder`. This presumably guards against
    /// hidden shared or global state leaking between decoders running on
    /// different threads. NOTE(review): confirm intent.
    #[test]
    fn test_context_switch() {
        // Non-capturing closure (all data is created inside), so it is `Copy`
        // and can be handed to `std::thread::spawn` several times below.
        let yolo_det = || {
            let score_threshold = 0.25;
            let iou_threshold = 0.7;
            let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
            // SAFETY: `u8` and `i8` have identical size and alignment, and the
            // slice borrows the `'static` data produced by `include_bytes!`.
            let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
            let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
            // (scale, zero_point)
            let quant = (0.0040811873, -123).into();

            let decoder = DecoderBuilder::default()
                .with_config_yolo_det(
                    configs::Detection {
                        decoder: DecoderType::Ultralytics,
                        shape: vec![1, 84, 8400],
                        anchors: None,
                        quantization: Some(quant),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::NumFeatures, 84),
                            (DimName::NumBoxes, 8400),
                        ],
                        normalized: None,
                    },
                    None,
                )
                .with_score_threshold(score_threshold)
                .with_iou_threshold(iou_threshold)
                .build()
                .unwrap();

            let mut output_boxes: Vec<_> = Vec::with_capacity(50);
            let mut output_masks: Vec<_> = Vec::with_capacity(50);

            // Output vectors are reused across iterations; each iteration must
            // still produce the same two boxes and no masks.
            for _ in 0..100 {
                decoder
                    .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
                    .unwrap();

                assert!(output_boxes[0].equal_within_delta(
                    &DetectBox {
                        bbox: BoundingBox {
                            xmin: 0.5285137,
                            ymin: 0.05305544,
                            xmax: 0.87541467,
                            ymax: 0.9998909,
                        },
                        score: 0.5591227,
                        label: 0
                    },
                    1e-6
                ));

                assert!(output_boxes[1].equal_within_delta(
                    &DetectBox {
                        bbox: BoundingBox {
                            xmin: 0.130598,
                            ymin: 0.43260583,
                            xmax: 0.35098213,
                            ymax: 0.9958097,
                        },
                        score: 0.33057618,
                        label: 75
                    },
                    1e-6
                ));
                assert!(output_masks.is_empty());
            }
        };

        // Also non-capturing, hence `Copy` like `yolo_det`.
        let modelpack_det_split = || {
            let score_threshold = 0.8;
            let iou_threshold = 0.5;

            let seg = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
            let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

            let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
            let detect0 =
                ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

            let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
            let detect1 =
                ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

            // Reference segmentation: the (2, 160, 160) map permuted to
            // (160, 160, 2) by the two axis swaps, covering the whole image.
            let mut mask = seg.slice(s![0, .., .., ..]);
            mask.swap_axes(0, 1);
            mask.swap_axes(1, 2);
            let mask = [Segmentation {
                xmin: 0.0,
                ymin: 0.0,
                xmax: 1.0,
                ymax: 1.0,
                segmentation: mask.into_owned(),
            }];
            let correct_boxes = [DetectBox {
                bbox: BoundingBox {
                    xmin: 0.43171933,
                    ymin: 0.68243736,
                    xmax: 0.5626645,
                    ymax: 0.808863,
                },
                score: 0.99240804,
                label: 0,
            }];

            // (scale, zero_point) for each output.
            let quant0 = (0.08547406643629074, 174).into();
            let quant1 = (0.09929127991199493, 183).into();
            let quant_seg = (1.0 / 255.0, 0).into();

            let anchors0 = vec![
                [0.36666667461395264, 0.31481480598449707],
                [0.38749998807907104, 0.4740740656852722],
                [0.5333333611488342, 0.644444465637207],
            ];
            let anchors1 = vec![
                [0.13750000298023224, 0.2074074000120163],
                [0.2541666626930237, 0.21481481194496155],
                [0.23125000298023224, 0.35185185074806213],
            ];

            // NOTE(review): the configs list the 17x30 head first, while the
            // inputs below pass the 9x15 head first. Presumably outputs are
            // matched by shape rather than by position; confirm.
            let decoder = DecoderBuilder::default()
                .with_config_modelpack_segdet_split(
                    vec![
                        configs::Detection {
                            decoder: DecoderType::ModelPack,
                            shape: vec![1, 17, 30, 18],
                            anchors: Some(anchors1),
                            quantization: Some(quant1),
                            dshape: vec![
                                (DimName::Batch, 1),
                                (DimName::Height, 17),
                                (DimName::Width, 30),
                                (DimName::NumAnchorsXFeatures, 18),
                            ],
                            normalized: None,
                        },
                        configs::Detection {
                            decoder: DecoderType::ModelPack,
                            shape: vec![1, 9, 15, 18],
                            anchors: Some(anchors0),
                            quantization: Some(quant0),
                            dshape: vec![
                                (DimName::Batch, 1),
                                (DimName::Height, 9),
                                (DimName::Width, 15),
                                (DimName::NumAnchorsXFeatures, 18),
                            ],
                            normalized: None,
                        },
                    ],
                    configs::Segmentation {
                        decoder: DecoderType::ModelPack,
                        quantization: Some(quant_seg),
                        shape: vec![1, 2, 160, 160],
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::NumClasses, 2),
                            (DimName::Height, 160),
                            (DimName::Width, 160),
                        ],
                    },
                )
                .with_score_threshold(score_threshold)
                .with_iou_threshold(iou_threshold)
                .build()
                .unwrap();
            let mut output_boxes: Vec<_> = Vec::with_capacity(10);
            let mut output_masks: Vec<_> = Vec::with_capacity(10);

            for _ in 0..100 {
                decoder
                    .decode_quantized(
                        &[
                            detect0.view().into(),
                            detect1.view().into(),
                            seg.view().into(),
                        ],
                        &mut output_boxes,
                        &mut output_masks,
                    )
                    .unwrap();

                compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
            }
        };

        // Interleave the two workloads across eight threads, then propagate any
        // panic (failed assertion) from a worker via `join().unwrap()`.
        let handles = vec![
            std::thread::spawn(yolo_det),
            std::thread::spawn(modelpack_det_split),
            std::thread::spawn(yolo_det),
            std::thread::spawn(modelpack_det_split),
            std::thread::spawn(yolo_det),
            std::thread::spawn(modelpack_det_split),
            std::thread::spawn(yolo_det),
            std::thread::spawn(modelpack_det_split),
        ];
        for handle in handles {
            handle.join().unwrap();
        }
    }
2214
2215 #[test]
2216 fn test_ndarray_to_xyxy_float() {
2217 let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2218 let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
2219 assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
2220
2221 let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2222 let xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(arr.view());
2223 assert_eq!(xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
2224 }
2225
2226 #[test]
2227 fn test_class_aware_nms_float() {
2228 use crate::float::nms_class_aware_float;
2229
2230 let boxes = vec![
2232 DetectBox {
2233 bbox: BoundingBox {
2234 xmin: 0.0,
2235 ymin: 0.0,
2236 xmax: 0.5,
2237 ymax: 0.5,
2238 },
2239 score: 0.9,
2240 label: 0, },
2242 DetectBox {
2243 bbox: BoundingBox {
2244 xmin: 0.1,
2245 ymin: 0.1,
2246 xmax: 0.6,
2247 ymax: 0.6,
2248 },
2249 score: 0.8,
2250 label: 1, },
2252 ];
2253
2254 let result = nms_class_aware_float(0.3, boxes.clone());
2257 assert_eq!(
2258 result.len(),
2259 2,
2260 "Class-aware NMS should keep both boxes with different classes"
2261 );
2262
2263 let same_class_boxes = vec![
2265 DetectBox {
2266 bbox: BoundingBox {
2267 xmin: 0.0,
2268 ymin: 0.0,
2269 xmax: 0.5,
2270 ymax: 0.5,
2271 },
2272 score: 0.9,
2273 label: 0,
2274 },
2275 DetectBox {
2276 bbox: BoundingBox {
2277 xmin: 0.1,
2278 ymin: 0.1,
2279 xmax: 0.6,
2280 ymax: 0.6,
2281 },
2282 score: 0.8,
2283 label: 0, },
2285 ];
2286
2287 let result = nms_class_aware_float(0.3, same_class_boxes);
2288 assert_eq!(
2289 result.len(),
2290 1,
2291 "Class-aware NMS should suppress overlapping box with same class"
2292 );
2293 assert_eq!(result[0].label, 0);
2294 assert!((result[0].score - 0.9).abs() < 1e-6);
2295 }
2296
2297 #[test]
2298 fn test_class_agnostic_vs_aware_nms() {
2299 use crate::float::{nms_class_aware_float, nms_float};
2300
2301 let boxes = vec![
2303 DetectBox {
2304 bbox: BoundingBox {
2305 xmin: 0.0,
2306 ymin: 0.0,
2307 xmax: 0.5,
2308 ymax: 0.5,
2309 },
2310 score: 0.9,
2311 label: 0,
2312 },
2313 DetectBox {
2314 bbox: BoundingBox {
2315 xmin: 0.1,
2316 ymin: 0.1,
2317 xmax: 0.6,
2318 ymax: 0.6,
2319 },
2320 score: 0.8,
2321 label: 1,
2322 },
2323 ];
2324
2325 let agnostic_result = nms_float(0.3, boxes.clone());
2327 assert_eq!(
2328 agnostic_result.len(),
2329 1,
2330 "Class-agnostic NMS should suppress overlapping boxes"
2331 );
2332
2333 let aware_result = nms_class_aware_float(0.3, boxes);
2335 assert_eq!(
2336 aware_result.len(),
2337 2,
2338 "Class-aware NMS should keep boxes with different classes"
2339 );
2340 }
2341
2342 #[test]
2343 fn test_class_aware_nms_int() {
2344 use crate::byte::nms_class_aware_int;
2345
2346 let boxes = vec![
2348 DetectBoxQuantized {
2349 bbox: BoundingBox {
2350 xmin: 0.0,
2351 ymin: 0.0,
2352 xmax: 0.5,
2353 ymax: 0.5,
2354 },
2355 score: 200_u8,
2356 label: 0,
2357 },
2358 DetectBoxQuantized {
2359 bbox: BoundingBox {
2360 xmin: 0.1,
2361 ymin: 0.1,
2362 xmax: 0.6,
2363 ymax: 0.6,
2364 },
2365 score: 180_u8,
2366 label: 1, },
2368 ];
2369
2370 let result = nms_class_aware_int(0.5, boxes);
2372 assert_eq!(
2373 result.len(),
2374 2,
2375 "Class-aware NMS (int) should keep boxes with different classes"
2376 );
2377 }
2378
2379 #[test]
2380 fn test_nms_enum_default() {
2381 let default_nms: configs::Nms = Default::default();
2383 assert_eq!(default_nms, configs::Nms::ClassAgnostic);
2384 }
2385
2386 #[test]
2387 fn test_decoder_nms_mode() {
2388 let decoder = DecoderBuilder::default()
2390 .with_config_yolo_det(
2391 configs::Detection {
2392 anchors: None,
2393 decoder: DecoderType::Ultralytics,
2394 quantization: None,
2395 shape: vec![1, 84, 8400],
2396 dshape: Vec::new(),
2397 normalized: Some(true),
2398 },
2399 None,
2400 )
2401 .with_nms(Some(configs::Nms::ClassAware))
2402 .build()
2403 .unwrap();
2404
2405 assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
2406 }
2407
2408 #[test]
2409 fn test_decoder_nms_bypass() {
2410 let decoder = DecoderBuilder::default()
2412 .with_config_yolo_det(
2413 configs::Detection {
2414 anchors: None,
2415 decoder: DecoderType::Ultralytics,
2416 quantization: None,
2417 shape: vec![1, 84, 8400],
2418 dshape: Vec::new(),
2419 normalized: Some(true),
2420 },
2421 None,
2422 )
2423 .with_nms(None)
2424 .build()
2425 .unwrap();
2426
2427 assert_eq!(decoder.nms, None);
2428 }
2429
2430 #[test]
2431 fn test_decoder_normalized_boxes_true() {
2432 let decoder = DecoderBuilder::default()
2434 .with_config_yolo_det(
2435 configs::Detection {
2436 anchors: None,
2437 decoder: DecoderType::Ultralytics,
2438 quantization: None,
2439 shape: vec![1, 84, 8400],
2440 dshape: Vec::new(),
2441 normalized: Some(true),
2442 },
2443 None,
2444 )
2445 .build()
2446 .unwrap();
2447
2448 assert_eq!(decoder.normalized_boxes(), Some(true));
2449 }
2450
2451 #[test]
2452 fn test_decoder_normalized_boxes_false() {
2453 let decoder = DecoderBuilder::default()
2456 .with_config_yolo_det(
2457 configs::Detection {
2458 anchors: None,
2459 decoder: DecoderType::Ultralytics,
2460 quantization: None,
2461 shape: vec![1, 84, 8400],
2462 dshape: Vec::new(),
2463 normalized: Some(false),
2464 },
2465 None,
2466 )
2467 .build()
2468 .unwrap();
2469
2470 assert_eq!(decoder.normalized_boxes(), Some(false));
2471 }
2472
2473 #[test]
2474 fn test_decoder_normalized_boxes_unknown() {
2475 let decoder = DecoderBuilder::default()
2477 .with_config_yolo_det(
2478 configs::Detection {
2479 anchors: None,
2480 decoder: DecoderType::Ultralytics,
2481 quantization: None,
2482 shape: vec![1, 84, 8400],
2483 dshape: Vec::new(),
2484 normalized: None,
2485 },
2486 Some(DecoderVersion::Yolo11),
2487 )
2488 .build()
2489 .unwrap();
2490
2491 assert_eq!(decoder.normalized_boxes(), None);
2492 }
2493}