1use std::fmt::Debug;
5
6use ndarray::{
7 parallel::prelude::{IntoParallelIterator, ParallelIterator},
8 s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
9};
10use ndarray_stats::QuantileExt;
11use num_traits::{AsPrimitive, Float, PrimInt, Signed};
12
13use crate::{
14 byte::{
15 nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
16 postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
17 },
18 configs::Nms,
19 dequant_detect_box,
20 float::{
21 nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
22 postprocess_boxes_float, postprocess_boxes_index_float,
23 },
24 BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoTensor,
25 Quantization, Segmentation, XYWH, XYXY,
26};
27
28fn dispatch_nms_float(nms: Option<Nms>, iou: f32, boxes: Vec<DetectBox>) -> Vec<DetectBox> {
30 match nms {
31 Some(Nms::ClassAgnostic) => nms_float(iou, boxes),
32 Some(Nms::ClassAware) => nms_class_aware_float(iou, boxes),
33 None => boxes, }
35}
36
37fn dispatch_nms_extra_float<E: Send + Sync>(
40 nms: Option<Nms>,
41 iou: f32,
42 boxes: Vec<(DetectBox, E)>,
43) -> Vec<(DetectBox, E)> {
44 match nms {
45 Some(Nms::ClassAgnostic) => nms_extra_float(iou, boxes),
46 Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, boxes),
47 None => boxes, }
49}
50
51fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
54 nms: Option<Nms>,
55 iou: f32,
56 boxes: Vec<DetectBoxQuantized<SCORE>>,
57) -> Vec<DetectBoxQuantized<SCORE>> {
58 match nms {
59 Some(Nms::ClassAgnostic) => nms_int(iou, boxes),
60 Some(Nms::ClassAware) => nms_class_aware_int(iou, boxes),
61 None => boxes, }
63}
64
65fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
68 nms: Option<Nms>,
69 iou: f32,
70 boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
71) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
72 match nms {
73 Some(Nms::ClassAgnostic) => nms_extra_int(iou, boxes),
74 Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, boxes),
75 None => boxes, }
77}
78
79pub fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
86 output: (ArrayView2<BOX>, Quantization),
87 score_threshold: f32,
88 iou_threshold: f32,
89 nms: Option<Nms>,
90 output_boxes: &mut Vec<DetectBox>,
91) where
92 f32: AsPrimitive<BOX>,
93{
94 impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
95}
96
97pub fn decode_yolo_det_float<T>(
104 output: ArrayView2<T>,
105 score_threshold: f32,
106 iou_threshold: f32,
107 nms: Option<Nms>,
108 output_boxes: &mut Vec<DetectBox>,
109) where
110 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
111 f32: AsPrimitive<T>,
112{
113 impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
114}
115
116pub fn decode_yolo_segdet_quant<
128 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
129 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
130>(
131 boxes: (ArrayView2<BOX>, Quantization),
132 protos: (ArrayView3<PROTO>, Quantization),
133 score_threshold: f32,
134 iou_threshold: f32,
135 nms: Option<Nms>,
136 output_boxes: &mut Vec<DetectBox>,
137 output_masks: &mut Vec<Segmentation>,
138) where
139 f32: AsPrimitive<BOX>,
140{
141 impl_yolo_segdet_quant::<XYWH, _, _>(
142 boxes,
143 protos,
144 score_threshold,
145 iou_threshold,
146 nms,
147 output_boxes,
148 output_masks,
149 );
150}
151
152pub fn decode_yolo_segdet_float<T>(
164 boxes: ArrayView2<T>,
165 protos: ArrayView3<T>,
166 score_threshold: f32,
167 iou_threshold: f32,
168 nms: Option<Nms>,
169 output_boxes: &mut Vec<DetectBox>,
170 output_masks: &mut Vec<Segmentation>,
171) where
172 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
173 f32: AsPrimitive<T>,
174{
175 impl_yolo_segdet_float::<XYWH, _, _>(
176 boxes,
177 protos,
178 score_threshold,
179 iou_threshold,
180 nms,
181 output_boxes,
182 output_masks,
183 );
184}
185
186pub fn decode_yolo_split_det_quant<
198 BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
199 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
200>(
201 boxes: (ArrayView2<BOX>, Quantization),
202 scores: (ArrayView2<SCORE>, Quantization),
203 score_threshold: f32,
204 iou_threshold: f32,
205 nms: Option<Nms>,
206 output_boxes: &mut Vec<DetectBox>,
207) where
208 f32: AsPrimitive<SCORE>,
209{
210 impl_yolo_split_quant::<XYWH, _, _>(
211 boxes,
212 scores,
213 score_threshold,
214 iou_threshold,
215 nms,
216 output_boxes,
217 );
218}
219
220pub fn decode_yolo_split_det_float<T>(
232 boxes: ArrayView2<T>,
233 scores: ArrayView2<T>,
234 score_threshold: f32,
235 iou_threshold: f32,
236 nms: Option<Nms>,
237 output_boxes: &mut Vec<DetectBox>,
238) where
239 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
240 f32: AsPrimitive<T>,
241{
242 impl_yolo_split_float::<XYWH, _, _>(
243 boxes,
244 scores,
245 score_threshold,
246 iou_threshold,
247 nms,
248 output_boxes,
249 );
250}
251
252#[allow(clippy::too_many_arguments)]
266pub fn decode_yolo_split_segdet<
267 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
268 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
269 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
270 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
271>(
272 boxes: (ArrayView2<BOX>, Quantization),
273 scores: (ArrayView2<SCORE>, Quantization),
274 mask_coeff: (ArrayView2<MASK>, Quantization),
275 protos: (ArrayView3<PROTO>, Quantization),
276 score_threshold: f32,
277 iou_threshold: f32,
278 nms: Option<Nms>,
279 output_boxes: &mut Vec<DetectBox>,
280 output_masks: &mut Vec<Segmentation>,
281) where
282 f32: AsPrimitive<SCORE>,
283{
284 impl_yolo_split_segdet_quant::<XYWH, _, _, _, _>(
285 boxes,
286 scores,
287 mask_coeff,
288 protos,
289 score_threshold,
290 iou_threshold,
291 nms,
292 output_boxes,
293 output_masks,
294 );
295}
296
297#[allow(clippy::too_many_arguments)]
311pub fn decode_yolo_split_segdet_float<T>(
312 boxes: ArrayView2<T>,
313 scores: ArrayView2<T>,
314 mask_coeff: ArrayView2<T>,
315 protos: ArrayView3<T>,
316 score_threshold: f32,
317 iou_threshold: f32,
318 nms: Option<Nms>,
319 output_boxes: &mut Vec<DetectBox>,
320 output_masks: &mut Vec<Segmentation>,
321) where
322 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
323 f32: AsPrimitive<T>,
324{
325 impl_yolo_split_segdet_float::<XYWH, _, _, _, _>(
326 boxes,
327 scores,
328 mask_coeff,
329 protos,
330 score_threshold,
331 iou_threshold,
332 nms,
333 output_boxes,
334 output_masks,
335 );
336}
337
338pub fn decode_yolo_end_to_end_det_float<T>(
353 output: ArrayView2<T>,
354 score_threshold: f32,
355 output_boxes: &mut Vec<DetectBox>,
356) -> Result<(), crate::DecoderError>
357where
358 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
359 f32: AsPrimitive<T>,
360{
361 if output.shape()[0] < 6 {
363 return Err(crate::DecoderError::InvalidShape(format!(
364 "End-to-end detection output requires at least 6 rows, got {}",
365 output.shape()[0]
366 )));
367 }
368
369 let boxes = output.slice(s![0..4, ..]).reversed_axes();
371 let scores = output.slice(s![4..5, ..]).reversed_axes();
372 let classes = output.slice(s![5, ..]);
373 let mut boxes =
374 postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
375 boxes.truncate(output_boxes.capacity());
376 output_boxes.clear();
377 for (mut b, i) in boxes.into_iter() {
378 b.label = classes[i].as_() as usize;
379 output_boxes.push(b);
380 }
381 Ok(())
383}
384
385pub fn decode_yolo_end_to_end_segdet_float<T>(
403 output: ArrayView2<T>,
404 protos: ArrayView3<T>,
405 score_threshold: f32,
406 output_boxes: &mut Vec<DetectBox>,
407 output_masks: &mut Vec<crate::Segmentation>,
408) -> Result<(), crate::DecoderError>
409where
410 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
411 f32: AsPrimitive<T>,
412{
413 if output.shape()[0] < 7 {
415 return Err(crate::DecoderError::InvalidShape(format!(
416 "End-to-end segdet output requires at least 7 rows, got {}",
417 output.shape()[0]
418 )));
419 }
420
421 let num_mask_coeffs = output.shape()[0] - 6;
422 let num_protos = protos.shape()[2];
423 if num_mask_coeffs != num_protos {
424 return Err(crate::DecoderError::InvalidShape(format!(
425 "Mask coefficients count ({}) doesn't match protos count ({})",
426 num_mask_coeffs, num_protos
427 )));
428 }
429
430 let boxes = output.slice(s![0..4, ..]).reversed_axes();
432 let scores = output.slice(s![4..5, ..]).reversed_axes();
433 let classes = output.slice(s![5, ..]);
434 let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
435 let mut boxes =
436 postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
437 boxes.truncate(output_boxes.capacity());
438
439 for (b, ind) in &mut boxes {
440 b.label = classes[*ind].as_() as usize;
441 }
442
443 let boxes = decode_segdet_f32(boxes, mask_coeff, protos);
446
447 output_boxes.clear();
448 output_masks.clear();
449 for (b, m) in boxes.into_iter() {
450 output_boxes.push(b);
451 output_masks.push(Segmentation {
452 xmin: b.bbox.xmin,
453 ymin: b.bbox.ymin,
454 xmax: b.bbox.xmax,
455 ymax: b.bbox.ymax,
456 segmentation: m,
457 });
458 }
459 Ok(())
460}
461
462pub fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
471 boxes: ArrayView2<T>,
472 scores: ArrayView2<T>,
473 classes: ArrayView2<T>,
474 score_threshold: f32,
475 output_boxes: &mut Vec<DetectBox>,
476) -> Result<(), crate::DecoderError> {
477 let n = boxes.shape()[0];
478 if boxes.shape()[1] != 4 {
479 return Err(crate::DecoderError::InvalidShape(format!(
480 "Split end-to-end boxes must have 4 columns, got {}",
481 boxes.shape()[1]
482 )));
483 }
484 output_boxes.clear();
485 for i in 0..n {
486 let score: f32 = scores[[i, 0]].as_();
487 if score < score_threshold {
488 continue;
489 }
490 if output_boxes.len() >= output_boxes.capacity() {
491 break;
492 }
493 output_boxes.push(DetectBox {
494 bbox: BoundingBox {
495 xmin: boxes[[i, 0]].as_(),
496 ymin: boxes[[i, 1]].as_(),
497 xmax: boxes[[i, 2]].as_(),
498 ymax: boxes[[i, 3]].as_(),
499 },
500 score,
501 label: classes[[i, 0]].as_() as usize,
502 });
503 }
504 Ok(())
505}
506
507#[allow(clippy::too_many_arguments)]
516pub fn decode_yolo_split_end_to_end_segdet_float<T>(
517 boxes: ArrayView2<T>,
518 scores: ArrayView2<T>,
519 classes: ArrayView2<T>,
520 mask_coeff: ArrayView2<T>,
521 protos: ArrayView3<T>,
522 score_threshold: f32,
523 output_boxes: &mut Vec<DetectBox>,
524 output_masks: &mut Vec<crate::Segmentation>,
525) -> Result<(), crate::DecoderError>
526where
527 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
528 f32: AsPrimitive<T>,
529{
530 let n = boxes.shape()[0];
531 if boxes.shape()[1] != 4 {
532 return Err(crate::DecoderError::InvalidShape(format!(
533 "Split end-to-end boxes must have 4 columns, got {}",
534 boxes.shape()[1]
535 )));
536 }
537
538 let mut qualifying: Vec<(DetectBox, usize)> = Vec::with_capacity(output_boxes.capacity());
540 for i in 0..n {
541 let score: f32 = scores[[i, 0]].as_();
542 if score < score_threshold {
543 continue;
544 }
545 if qualifying.len() >= output_boxes.capacity() {
546 break;
547 }
548 qualifying.push((
549 DetectBox {
550 bbox: BoundingBox {
551 xmin: boxes[[i, 0]].as_(),
552 ymin: boxes[[i, 1]].as_(),
553 xmax: boxes[[i, 2]].as_(),
554 ymax: boxes[[i, 3]].as_(),
555 },
556 score,
557 label: classes[[i, 0]].as_() as usize,
558 },
559 i,
560 ));
561 }
562
563 let result = decode_segdet_f32(qualifying, mask_coeff, protos);
565
566 output_boxes.clear();
567 output_masks.clear();
568 for (b, m) in result.into_iter() {
569 output_masks.push(crate::Segmentation {
570 xmin: b.bbox.xmin,
571 ymin: b.bbox.ymin,
572 xmax: b.bbox.xmax,
573 ymax: b.bbox.ymax,
574 segmentation: m,
575 });
576 output_boxes.push(b);
577 }
578 Ok(())
579}
580
581pub fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
586 output: (ArrayView2<T>, Quantization),
587 score_threshold: f32,
588 iou_threshold: f32,
589 nms: Option<Nms>,
590 output_boxes: &mut Vec<DetectBox>,
591) where
592 f32: AsPrimitive<T>,
593{
594 let (boxes, quant_boxes) = output;
595 let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
596
597 let boxes = {
598 let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
599 postprocess_boxes_quant::<B, _, _>(
600 score_threshold,
601 boxes_tensor,
602 scores_tensor,
603 quant_boxes,
604 )
605 };
606
607 let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
608 let len = output_boxes.capacity().min(boxes.len());
609 output_boxes.clear();
610 for b in boxes.iter().take(len) {
611 output_boxes.push(dequant_detect_box(b, quant_boxes));
612 }
613}
614
615pub fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
620 output: ArrayView2<T>,
621 score_threshold: f32,
622 iou_threshold: f32,
623 nms: Option<Nms>,
624 output_boxes: &mut Vec<DetectBox>,
625) where
626 f32: AsPrimitive<T>,
627{
628 let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
629 let boxes =
630 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
631 let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
632 let len = output_boxes.capacity().min(boxes.len());
633 output_boxes.clear();
634 for b in boxes.into_iter().take(len) {
635 output_boxes.push(b);
636 }
637}
638
639pub fn impl_yolo_split_quant<
649 B: BBoxTypeTrait,
650 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
651 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
652>(
653 boxes: (ArrayView2<BOX>, Quantization),
654 scores: (ArrayView2<SCORE>, Quantization),
655 score_threshold: f32,
656 iou_threshold: f32,
657 nms: Option<Nms>,
658 output_boxes: &mut Vec<DetectBox>,
659) where
660 f32: AsPrimitive<SCORE>,
661{
662 let (boxes_tensor, quant_boxes) = boxes;
663 let (scores_tensor, quant_scores) = scores;
664
665 let boxes_tensor = boxes_tensor.reversed_axes();
666 let scores_tensor = scores_tensor.reversed_axes();
667
668 let boxes = {
669 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
670 postprocess_boxes_quant::<B, _, _>(
671 score_threshold,
672 boxes_tensor,
673 scores_tensor,
674 quant_boxes,
675 )
676 };
677
678 let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
679 let len = output_boxes.capacity().min(boxes.len());
680 output_boxes.clear();
681 for b in boxes.iter().take(len) {
682 output_boxes.push(dequant_detect_box(b, quant_scores));
683 }
684}
685
686pub fn impl_yolo_split_float<
695 B: BBoxTypeTrait,
696 BOX: Float + AsPrimitive<f32> + Send + Sync,
697 SCORE: Float + AsPrimitive<f32> + Send + Sync,
698>(
699 boxes_tensor: ArrayView2<BOX>,
700 scores_tensor: ArrayView2<SCORE>,
701 score_threshold: f32,
702 iou_threshold: f32,
703 nms: Option<Nms>,
704 output_boxes: &mut Vec<DetectBox>,
705) where
706 f32: AsPrimitive<SCORE>,
707{
708 let boxes_tensor = boxes_tensor.reversed_axes();
709 let scores_tensor = scores_tensor.reversed_axes();
710 let boxes =
711 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
712 let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
713 let len = output_boxes.capacity().min(boxes.len());
714 output_boxes.clear();
715 for b in boxes.into_iter().take(len) {
716 output_boxes.push(b);
717 }
718}
719
720pub fn impl_yolo_segdet_quant<
730 B: BBoxTypeTrait,
731 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
732 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
733>(
734 boxes: (ArrayView2<BOX>, Quantization),
735 protos: (ArrayView3<PROTO>, Quantization),
736 score_threshold: f32,
737 iou_threshold: f32,
738 nms: Option<Nms>,
739 output_boxes: &mut Vec<DetectBox>,
740 output_masks: &mut Vec<Segmentation>,
741) where
742 f32: AsPrimitive<BOX>,
743{
744 let (boxes, quant_boxes) = boxes;
745 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes);
746
747 let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
748 (boxes_tensor.reversed_axes(), quant_boxes),
749 (scores_tensor.reversed_axes(), quant_boxes),
750 score_threshold,
751 iou_threshold,
752 nms,
753 output_boxes.capacity(),
754 );
755
756 impl_yolo_split_segdet_quant_process_masks::<_, _>(
757 boxes,
758 (mask_tensor.reversed_axes(), quant_boxes),
759 protos,
760 output_boxes,
761 output_masks,
762 );
763}
764
765pub fn impl_yolo_segdet_float<
775 B: BBoxTypeTrait,
776 BOX: Float + AsPrimitive<f32> + Send + Sync,
777 PROTO: Float + AsPrimitive<f32> + Send + Sync,
778>(
779 boxes: ArrayView2<BOX>,
780 protos: ArrayView3<PROTO>,
781 score_threshold: f32,
782 iou_threshold: f32,
783 nms: Option<Nms>,
784 output_boxes: &mut Vec<DetectBox>,
785 output_masks: &mut Vec<Segmentation>,
786) where
787 f32: AsPrimitive<BOX>,
788{
789 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes);
790
791 let boxes = postprocess_boxes_index_float::<B, _, _>(
792 score_threshold.as_(),
793 boxes_tensor,
794 scores_tensor,
795 );
796 let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
797 boxes.truncate(output_boxes.capacity());
798 let boxes = decode_segdet_f32(boxes, mask_tensor, protos);
799 output_boxes.clear();
800 output_masks.clear();
801 for (b, m) in boxes.into_iter() {
802 output_boxes.push(b);
803 output_masks.push(Segmentation {
804 xmin: b.bbox.xmin,
805 ymin: b.bbox.ymin,
806 xmax: b.bbox.xmax,
807 ymax: b.bbox.ymax,
808 segmentation: m,
809 });
810 }
811}
812
813pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
814 B: BBoxTypeTrait,
815 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
816 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
817>(
818 boxes: (ArrayView2<BOX>, Quantization),
819 scores: (ArrayView2<SCORE>, Quantization),
820 score_threshold: f32,
821 iou_threshold: f32,
822 nms: Option<Nms>,
823 max_boxes: usize,
824) -> Vec<(DetectBox, usize)>
825where
826 f32: AsPrimitive<SCORE>,
827{
828 let (boxes_tensor, quant_boxes) = boxes;
829 let (scores_tensor, quant_scores) = scores;
830
831 let boxes_tensor = boxes_tensor.reversed_axes();
832 let scores_tensor = scores_tensor.reversed_axes();
833
834 let boxes = {
835 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
836 postprocess_boxes_index_quant::<B, _, _>(
837 score_threshold,
838 boxes_tensor,
839 scores_tensor,
840 quant_boxes,
841 )
842 };
843 let mut boxes = dispatch_nms_extra_int(nms, iou_threshold, boxes);
844 boxes.truncate(max_boxes);
845 boxes
846 .into_iter()
847 .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
848 .collect()
849}
850
851pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
852 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
853 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
854>(
855 boxes: Vec<(DetectBox, usize)>,
856 mask_coeff: (ArrayView2<MASK>, Quantization),
857 protos: (ArrayView3<PROTO>, Quantization),
858 output_boxes: &mut Vec<DetectBox>,
859 output_masks: &mut Vec<Segmentation>,
860) {
861 let (masks, quant_masks) = mask_coeff;
862 let (protos, quant_protos) = protos;
863
864 let masks = masks.reversed_axes();
865
866 let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos);
867 output_boxes.clear();
868 output_masks.clear();
869 for (b, m) in boxes.into_iter() {
870 output_boxes.push(b);
871 output_masks.push(Segmentation {
872 xmin: b.bbox.xmin,
873 ymin: b.bbox.ymin,
874 xmax: b.bbox.xmax,
875 ymax: b.bbox.ymax,
876 segmentation: m,
877 });
878 }
879}
880
881#[allow(clippy::too_many_arguments)]
882pub fn impl_yolo_split_segdet_quant<
894 B: BBoxTypeTrait,
895 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
896 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
897 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
898 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
899>(
900 boxes: (ArrayView2<BOX>, Quantization),
901 scores: (ArrayView2<SCORE>, Quantization),
902 mask_coeff: (ArrayView2<MASK>, Quantization),
903 protos: (ArrayView3<PROTO>, Quantization),
904 score_threshold: f32,
905 iou_threshold: f32,
906 nms: Option<Nms>,
907 output_boxes: &mut Vec<DetectBox>,
908 output_masks: &mut Vec<Segmentation>,
909) where
910 f32: AsPrimitive<SCORE>,
911{
912 let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
913 boxes,
914 scores,
915 score_threshold,
916 iou_threshold,
917 nms,
918 output_boxes.capacity(),
919 );
920
921 impl_yolo_split_segdet_quant_process_masks(
922 boxes,
923 mask_coeff,
924 protos,
925 output_boxes,
926 output_masks,
927 );
928}
929
930#[allow(clippy::too_many_arguments)]
931pub fn impl_yolo_split_segdet_float<
943 B: BBoxTypeTrait,
944 BOX: Float + AsPrimitive<f32> + Send + Sync,
945 SCORE: Float + AsPrimitive<f32> + Send + Sync,
946 MASK: Float + AsPrimitive<f32> + Send + Sync,
947 PROTO: Float + AsPrimitive<f32> + Send + Sync,
948>(
949 boxes_tensor: ArrayView2<BOX>,
950 scores_tensor: ArrayView2<SCORE>,
951 mask_tensor: ArrayView2<MASK>,
952 protos: ArrayView3<PROTO>,
953 score_threshold: f32,
954 iou_threshold: f32,
955 nms: Option<Nms>,
956 output_boxes: &mut Vec<DetectBox>,
957 output_masks: &mut Vec<Segmentation>,
958) where
959 f32: AsPrimitive<SCORE>,
960{
961 let boxes_tensor = boxes_tensor.reversed_axes();
962 let scores_tensor = scores_tensor.reversed_axes();
963 let mask_tensor = mask_tensor.reversed_axes();
964
965 let boxes = postprocess_boxes_index_float::<B, _, _>(
966 score_threshold.as_(),
967 boxes_tensor,
968 scores_tensor,
969 );
970 let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
971 boxes.truncate(output_boxes.capacity());
972 let boxes = decode_segdet_f32(boxes, mask_tensor, protos);
973 output_boxes.clear();
974 output_masks.clear();
975 for (b, m) in boxes.into_iter() {
976 output_boxes.push(b);
977 output_masks.push(Segmentation {
978 xmin: b.bbox.xmin,
979 ymin: b.bbox.ymin,
980 xmax: b.bbox.xmax,
981 ymax: b.bbox.ymax,
982 segmentation: m,
983 });
984 }
985}
986
987pub fn impl_yolo_segdet_quant_proto<
994 B: BBoxTypeTrait,
995 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
996 PROTO: PrimInt
997 + AsPrimitive<i64>
998 + AsPrimitive<i128>
999 + AsPrimitive<f32>
1000 + AsPrimitive<i8>
1001 + Send
1002 + Sync,
1003>(
1004 boxes: (ArrayView2<BOX>, Quantization),
1005 protos: (ArrayView3<PROTO>, Quantization),
1006 score_threshold: f32,
1007 iou_threshold: f32,
1008 nms: Option<Nms>,
1009 output_boxes: &mut Vec<DetectBox>,
1010) -> ProtoData
1011where
1012 f32: AsPrimitive<BOX>,
1013{
1014 let (boxes_arr, quant_boxes) = boxes;
1015 let (protos_arr, quant_protos) = protos;
1016
1017 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr);
1018
1019 let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1020 (boxes_tensor.reversed_axes(), quant_boxes),
1021 (scores_tensor.reversed_axes(), quant_boxes),
1022 score_threshold,
1023 iou_threshold,
1024 nms,
1025 output_boxes.capacity(),
1026 );
1027
1028 let mask_tensor = mask_tensor.reversed_axes();
1029 extract_proto_data_quant(
1030 det_indices,
1031 mask_tensor,
1032 quant_boxes,
1033 protos_arr,
1034 quant_protos,
1035 output_boxes,
1036 )
1037}
1038
1039pub fn impl_yolo_segdet_float_proto<
1042 B: BBoxTypeTrait,
1043 BOX: Float + AsPrimitive<f32> + Send + Sync,
1044 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1045>(
1046 boxes: ArrayView2<BOX>,
1047 protos: ArrayView3<PROTO>,
1048 score_threshold: f32,
1049 iou_threshold: f32,
1050 nms: Option<Nms>,
1051 output_boxes: &mut Vec<DetectBox>,
1052) -> ProtoData
1053where
1054 f32: AsPrimitive<BOX>,
1055{
1056 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes);
1057
1058 let det_indices = postprocess_boxes_index_float::<B, _, _>(
1059 score_threshold.as_(),
1060 boxes_tensor,
1061 scores_tensor,
1062 );
1063 let mut det_indices = dispatch_nms_extra_float(nms, iou_threshold, det_indices);
1064 det_indices.truncate(output_boxes.capacity());
1065
1066 let mask_tensor = mask_tensor.reversed_axes();
1067 extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1068}
1069
1070#[allow(clippy::too_many_arguments)]
1073pub fn impl_yolo_split_segdet_quant_proto<
1074 B: BBoxTypeTrait,
1075 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1076 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1077 MASK: PrimInt + AsPrimitive<f32> + Send + Sync,
1078 PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync,
1079>(
1080 boxes: (ArrayView2<BOX>, Quantization),
1081 scores: (ArrayView2<SCORE>, Quantization),
1082 mask_coeff: (ArrayView2<MASK>, Quantization),
1083 protos: (ArrayView3<PROTO>, Quantization),
1084 score_threshold: f32,
1085 iou_threshold: f32,
1086 nms: Option<Nms>,
1087 output_boxes: &mut Vec<DetectBox>,
1088) -> ProtoData
1089where
1090 f32: AsPrimitive<SCORE>,
1091{
1092 let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1093 boxes,
1094 scores,
1095 score_threshold,
1096 iou_threshold,
1097 nms,
1098 output_boxes.capacity(),
1099 );
1100
1101 let (masks, quant_masks) = mask_coeff;
1102 let masks = masks.reversed_axes();
1103 let (protos_arr, quant_protos) = protos;
1104
1105 extract_proto_data_quant(
1106 det_indices,
1107 masks,
1108 quant_masks,
1109 protos_arr,
1110 quant_protos,
1111 output_boxes,
1112 )
1113}
1114
1115#[allow(clippy::too_many_arguments)]
1118pub fn impl_yolo_split_segdet_float_proto<
1119 B: BBoxTypeTrait,
1120 BOX: Float + AsPrimitive<f32> + Send + Sync,
1121 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1122 MASK: Float + AsPrimitive<f32> + Send + Sync,
1123 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1124>(
1125 boxes_tensor: ArrayView2<BOX>,
1126 scores_tensor: ArrayView2<SCORE>,
1127 mask_tensor: ArrayView2<MASK>,
1128 protos: ArrayView3<PROTO>,
1129 score_threshold: f32,
1130 iou_threshold: f32,
1131 nms: Option<Nms>,
1132 output_boxes: &mut Vec<DetectBox>,
1133) -> ProtoData
1134where
1135 f32: AsPrimitive<SCORE>,
1136{
1137 let boxes_tensor = boxes_tensor.reversed_axes();
1138 let scores_tensor = scores_tensor.reversed_axes();
1139 let mask_tensor = mask_tensor.reversed_axes();
1140
1141 let det_indices = postprocess_boxes_index_float::<B, _, _>(
1142 score_threshold.as_(),
1143 boxes_tensor,
1144 scores_tensor,
1145 );
1146 let mut det_indices = dispatch_nms_extra_float(nms, iou_threshold, det_indices);
1147 det_indices.truncate(output_boxes.capacity());
1148
1149 extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1150}
1151
1152pub fn decode_yolo_end_to_end_segdet_float_proto<T>(
1154 output: ArrayView2<T>,
1155 protos: ArrayView3<T>,
1156 score_threshold: f32,
1157 output_boxes: &mut Vec<DetectBox>,
1158) -> Result<ProtoData, crate::DecoderError>
1159where
1160 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1161 f32: AsPrimitive<T>,
1162{
1163 if output.shape()[0] < 7 {
1164 return Err(crate::DecoderError::InvalidShape(format!(
1165 "End-to-end segdet output requires at least 7 rows, got {}",
1166 output.shape()[0]
1167 )));
1168 }
1169
1170 let num_mask_coeffs = output.shape()[0] - 6;
1171 let num_protos = protos.shape()[2];
1172 if num_mask_coeffs != num_protos {
1173 return Err(crate::DecoderError::InvalidShape(format!(
1174 "Mask coefficients count ({}) doesn't match protos count ({})",
1175 num_mask_coeffs, num_protos
1176 )));
1177 }
1178
1179 let boxes = output.slice(s![0..4, ..]).reversed_axes();
1180 let scores = output.slice(s![4..5, ..]).reversed_axes();
1181 let classes = output.slice(s![5, ..]);
1182 let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
1183 let mut det_indices =
1184 postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
1185 det_indices.truncate(output_boxes.capacity());
1186
1187 for (b, ind) in &mut det_indices {
1188 b.label = classes[*ind].as_() as usize;
1189 }
1190
1191 Ok(extract_proto_data_float(
1192 det_indices,
1193 mask_coeff,
1194 protos,
1195 output_boxes,
1196 ))
1197}
1198
1199#[allow(clippy::too_many_arguments)]
1201pub fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1202 boxes: ArrayView2<T>,
1203 scores: ArrayView2<T>,
1204 classes: ArrayView2<T>,
1205 mask_coeff: ArrayView2<T>,
1206 protos: ArrayView3<T>,
1207 score_threshold: f32,
1208 output_boxes: &mut Vec<DetectBox>,
1209) -> Result<ProtoData, crate::DecoderError>
1210where
1211 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1212 f32: AsPrimitive<T>,
1213{
1214 let n = boxes.shape()[0];
1215 if boxes.shape()[1] != 4 {
1216 return Err(crate::DecoderError::InvalidShape(format!(
1217 "Split end-to-end boxes must have 4 columns, got {}",
1218 boxes.shape()[1]
1219 )));
1220 }
1221
1222 let mut qualifying: Vec<(DetectBox, usize)> = Vec::with_capacity(output_boxes.capacity());
1223 for i in 0..n {
1224 let score: f32 = scores[[i, 0]].as_();
1225 if score < score_threshold {
1226 continue;
1227 }
1228 if qualifying.len() >= output_boxes.capacity() {
1229 break;
1230 }
1231 qualifying.push((
1232 DetectBox {
1233 bbox: BoundingBox {
1234 xmin: boxes[[i, 0]].as_(),
1235 ymin: boxes[[i, 1]].as_(),
1236 xmax: boxes[[i, 2]].as_(),
1237 ymax: boxes[[i, 3]].as_(),
1238 },
1239 score,
1240 label: classes[[i, 0]].as_() as usize,
1241 },
1242 i,
1243 ));
1244 }
1245
1246 Ok(extract_proto_data_float(
1247 qualifying,
1248 mask_coeff,
1249 protos,
1250 output_boxes,
1251 ))
1252}
1253
1254fn extract_proto_data_float<
1256 MASK: Float + AsPrimitive<f32> + Send + Sync,
1257 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1258>(
1259 det_indices: Vec<(DetectBox, usize)>,
1260 mask_tensor: ArrayView2<MASK>,
1261 protos: ArrayView3<PROTO>,
1262 output_boxes: &mut Vec<DetectBox>,
1263) -> ProtoData {
1264 let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1265 output_boxes.clear();
1266 for (det, idx) in det_indices {
1267 output_boxes.push(det);
1268 let row = mask_tensor.row(idx);
1269 mask_coefficients.push(row.iter().map(|v| v.as_()).collect());
1270 }
1271 let protos_f32 = protos.map(|v| v.as_());
1272 ProtoData {
1273 mask_coefficients,
1274 protos: ProtoTensor::Float(protos_f32),
1275 }
1276}
1277
1278pub(crate) fn extract_proto_data_quant<
1284 MASK: PrimInt + AsPrimitive<f32> + Send + Sync,
1285 PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync,
1286>(
1287 det_indices: Vec<(DetectBox, usize)>,
1288 mask_tensor: ArrayView2<MASK>,
1289 quant_masks: Quantization,
1290 protos: ArrayView3<PROTO>,
1291 quant_protos: Quantization,
1292 output_boxes: &mut Vec<DetectBox>,
1293) -> ProtoData {
1294 let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1295 output_boxes.clear();
1296 for (det, idx) in det_indices {
1297 output_boxes.push(det);
1298 let row = mask_tensor.row(idx);
1299 mask_coefficients.push(
1300 row.iter()
1301 .map(|v| (v.as_() - quant_masks.zero_point as f32) * quant_masks.scale)
1302 .collect(),
1303 );
1304 }
1305 let protos_i8 = protos.map(|v| {
1307 let v_i8: i8 = v.as_();
1308 v_i8
1309 });
1310 ProtoData {
1311 mask_coefficients,
1312 protos: ProtoTensor::Quantized {
1313 protos: protos_i8,
1314 quantization: quant_protos,
1315 },
1316 }
1317}
1318
1319fn postprocess_yolo<'a, T>(
1320 output: &'a ArrayView2<'_, T>,
1321) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1322 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1323 let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1324 (boxes_tensor, scores_tensor)
1325}
1326
1327fn postprocess_yolo_seg<'a, T>(
1328 output: &'a ArrayView2<'_, T>,
1329) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1330 assert!(output.shape()[0] > 32 + 4, "Output shape is too short");
1331 let num_classes = output.shape()[0] - 4 - 32;
1332 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1333 let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1334 let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1335 (boxes_tensor, scores_tensor, mask_tensor)
1336}
1337
1338fn decode_segdet_f32<
1339 MASK: Float + AsPrimitive<f32> + Send + Sync,
1340 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1341>(
1342 boxes: Vec<(DetectBox, usize)>,
1343 masks: ArrayView2<MASK>,
1344 protos: ArrayView3<PROTO>,
1345) -> Vec<(DetectBox, Array3<u8>)> {
1346 if boxes.is_empty() {
1347 return Vec::new();
1348 }
1349 assert!(masks.shape()[1] == protos.shape()[2]);
1350 boxes
1351 .into_par_iter()
1352 .map(|mut b| {
1353 let ind = b.1;
1354 let (protos, roi) = protobox(&protos, &b.0.bbox);
1355 b.0.bbox = roi;
1356 (b.0, make_segmentation(masks.row(ind), protos.view()))
1357 })
1358 .collect()
1359}
1360
1361pub(crate) fn decode_segdet_quant<
1362 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1363 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1364>(
1365 boxes: Vec<(DetectBox, usize)>,
1366 masks: ArrayView2<MASK>,
1367 protos: ArrayView3<PROTO>,
1368 quant_masks: Quantization,
1369 quant_protos: Quantization,
1370) -> Vec<(DetectBox, Array3<u8>)> {
1371 if boxes.is_empty() {
1372 return Vec::new();
1373 }
1374 assert!(masks.shape()[1] == protos.shape()[2]);
1375
1376 let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; boxes
1378 .into_iter()
1379 .map(|mut b| {
1380 let i = b.1;
1381 let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical());
1382 b.0.bbox = roi;
1383 let seg = match total_bits {
1384 0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1385 masks.row(i),
1386 protos.view(),
1387 quant_masks,
1388 quant_protos,
1389 ),
1390 65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1391 masks.row(i),
1392 protos.view(),
1393 quant_masks,
1394 quant_protos,
1395 ),
1396 _ => panic!("Unsupported bit width for segmentation computation"),
1397 };
1398 (b.0, seg)
1399 })
1400 .collect()
1401}
1402
1403fn protobox<'a, T>(
1404 protos: &'a ArrayView3<T>,
1405 roi: &BoundingBox,
1406) -> (ArrayView3<'a, T>, BoundingBox) {
1407 let width = protos.dim().1 as f32;
1408 let height = protos.dim().0 as f32;
1409
1410 let roi = [
1411 (roi.xmin * width).clamp(0.0, width) as usize,
1412 (roi.ymin * height).clamp(0.0, height) as usize,
1413 (roi.xmax * width).clamp(0.0, width).ceil() as usize,
1414 (roi.ymax * height).clamp(0.0, height).ceil() as usize,
1415 ];
1416
1417 let roi_norm = [
1418 roi[0] as f32 / width,
1419 roi[1] as f32 / height,
1420 roi[2] as f32 / width,
1421 roi[3] as f32 / height,
1422 ]
1423 .into();
1424
1425 let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
1426
1427 (cropped, roi_norm)
1428}
1429
1430fn make_segmentation<
1431 MASK: Float + AsPrimitive<f32> + Send + Sync,
1432 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1433>(
1434 mask: ArrayView1<MASK>,
1435 protos: ArrayView3<PROTO>,
1436) -> Array3<u8> {
1437 let shape = protos.shape();
1438
1439 let mask = mask.to_shape((1, mask.len())).unwrap();
1441 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1442 let protos = protos.reversed_axes();
1443 let mask = mask.map(|x| x.as_());
1444 let protos = protos.map(|x| x.as_());
1445
1446 let mask = mask
1448 .dot(&protos)
1449 .into_shape_with_order((shape[0], shape[1], 1))
1450 .unwrap();
1451
1452 let min = *mask.min().unwrap_or(&0.0);
1453 let max = *mask.max().unwrap_or(&1.0);
1454 let max = max.max(-min);
1455 let min = -max;
1456 let u8_max = 256.0;
1457 mask.map(|x| ((*x - min) / (max - min) * u8_max) as u8)
1458}
1459
1460fn make_segmentation_quant<
1461 MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
1462 PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
1463 DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
1464>(
1465 mask: ArrayView1<MASK>,
1466 protos: ArrayView3<PROTO>,
1467 quant_masks: Quantization,
1468 quant_protos: Quantization,
1469) -> Array3<u8>
1470where
1471 i32: AsPrimitive<DEST>,
1472 f32: AsPrimitive<DEST>,
1473{
1474 let shape = protos.shape();
1475
1476 let mask = mask.to_shape((1, mask.len())).unwrap();
1478
1479 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1480 let protos = protos.reversed_axes();
1481
1482 let zp = quant_masks.zero_point.as_();
1483
1484 let mask = mask.mapv(|x| x.as_() - zp);
1485
1486 let zp = quant_protos.zero_point.as_();
1487 let protos = protos.mapv(|x| x.as_() - zp);
1488
1489 let segmentation = mask
1491 .dot(&protos)
1492 .into_shape_with_order((shape[0], shape[1], 1))
1493 .unwrap();
1494
1495 let min = *segmentation.min().unwrap_or(&DEST::zero());
1496 let max = *segmentation.max().unwrap_or(&DEST::one());
1497 let max = max.max(-min);
1498 let min = -max;
1499 segmentation.map(|x| ((*x - min).as_() / (max - min).as_() * 256.0) as u8)
1500}
1501
1502pub fn yolo_segmentation_to_mask(
1514 segmentation: ArrayView3<u8>,
1515 threshold: u8,
1516) -> Result<Array2<u8>, crate::DecoderError> {
1517 if segmentation.shape()[2] != 1 {
1518 return Err(crate::DecoderError::InvalidShape(format!(
1519 "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
1520 segmentation.shape()[2]
1521 )));
1522 }
1523 Ok(segmentation
1524 .slice(s![.., .., 0])
1525 .map(|x| if *x >= threshold { 1 } else { 0 }))
1526}
1527
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Absolute tolerance used when comparing decoded floating-point values.
    const TOLERANCE: f32 = 0.01;

    /// Asserts that `actual` is within [`TOLERANCE`] of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {} +/- {}, got {}",
            expected,
            TOLERANCE,
            actual
        );
    }

    /// Builds a `(rows.len(), COLS)` array from row-major literal rows.
    fn array_from_rows<const COLS: usize>(rows: &[[f32; COLS]]) -> Array2<f32> {
        let flat: Vec<f32> = rows.iter().flatten().copied().collect();
        Array2::from_shape_vec((rows.len(), COLS), flat).unwrap()
    }

    // ---------------------------------------------------------------------
    // End-to-end detection. Each fixture is feature-major: one row per
    // feature (x1, y1, x2, y2, score, class) and one column per detection.
    // ---------------------------------------------------------------------

    #[test]
    fn test_end_to_end_det_basic_filtering() {
        let output = array_from_rows(&[
            [0.1, 0.2, 0.3], // x1
            [0.1, 0.2, 0.3], // y1
            [0.5, 0.6, 0.7], // x2
            [0.5, 0.6, 0.7], // y2
            [0.9, 0.1, 0.2], // score
            [0.0, 1.0, 2.0], // class
        ]);

        let mut detections = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut detections).unwrap();

        // Only the first column scores above the 0.5 threshold.
        assert_eq!(detections.len(), 1);
        let kept = &detections[0];
        assert_eq!(kept.label, 0);
        assert_close(kept.score, 0.9);
        assert_close(kept.bbox.xmin, 0.1);
        assert_close(kept.bbox.ymin, 0.1);
        assert_close(kept.bbox.xmax, 0.5);
        assert_close(kept.bbox.ymax, 0.5);
    }

    #[test]
    fn test_end_to_end_det_all_pass_threshold() {
        let output = array_from_rows(&[
            [10.0, 20.0], // x1
            [10.0, 20.0], // y1
            [50.0, 60.0], // x2
            [50.0, 60.0], // y2
            [0.8, 0.7],   // score
            [1.0, 2.0],   // class
        ]);

        let mut detections = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut detections).unwrap();

        assert_eq!(detections.len(), 2);
        assert_eq!(detections[0].label, 1);
        assert_eq!(detections[1].label, 2);
    }

    #[test]
    fn test_end_to_end_det_none_pass_threshold() {
        let output = array_from_rows(&[
            [10.0, 20.0], // x1
            [10.0, 20.0], // y1
            [50.0, 60.0], // x2
            [50.0, 60.0], // y2
            [0.1, 0.2],   // score
            [1.0, 2.0],   // class
        ]);

        let mut detections = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut detections).unwrap();

        assert!(detections.is_empty());
    }

    #[test]
    fn test_end_to_end_det_capacity_limit() {
        let output = array_from_rows(&[
            [0.1, 0.2, 0.3, 0.4, 0.5], // x1
            [0.1, 0.2, 0.3, 0.4, 0.5], // y1
            [0.5, 0.6, 0.7, 0.8, 0.9], // x2
            [0.5, 0.6, 0.7, 0.8, 0.9], // y2
            [0.9, 0.9, 0.9, 0.9, 0.9], // score
            [0.0, 1.0, 2.0, 3.0, 4.0], // class
        ]);

        // All five columns pass the threshold, but the vector only has room
        // for two results; the decoder is expected to stop at its capacity.
        let mut detections = Vec::with_capacity(2);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut detections).unwrap();

        assert_eq!(detections.len(), 2);
    }

    #[test]
    fn test_end_to_end_det_empty_output() {
        // Six feature rows, zero detection columns.
        let output = Array2::<f32>::zeros((6, 0));

        let mut detections = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut detections).unwrap();

        assert!(detections.is_empty());
    }

    #[test]
    fn test_end_to_end_det_pixel_coordinates() {
        let output = array_from_rows(&[
            [100.0], // x1
            [200.0], // y1
            [300.0], // x2
            [400.0], // y2
            [0.95],  // score
            [5.0],   // class
        ]);

        let mut detections = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut detections).unwrap();

        // Coordinates outside [0, 1] are passed through unchanged.
        assert_eq!(detections.len(), 1);
        let kept = &detections[0];
        assert_eq!(kept.label, 5);
        assert_close(kept.bbox.xmin, 100.0);
        assert_close(kept.bbox.ymin, 200.0);
        assert_close(kept.bbox.xmax, 300.0);
        assert_close(kept.bbox.ymax, 400.0);
    }

    #[test]
    fn test_end_to_end_det_invalid_shape() {
        // Only five feature rows.
        let output = Array2::<f32>::zeros((5, 3));

        let mut detections = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut detections);

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
        ));
    }

    // ---------------------------------------------------------------------
    // End-to-end segmentation. Fixtures use the six detection rows above
    // followed by one row per mask coefficient.
    // ---------------------------------------------------------------------

    #[test]
    fn test_end_to_end_segdet_basic() {
        let num_protos = 32;
        let num_detections = 2;
        let num_features = 6 + num_protos;

        let detection_rows: [[f32; 2]; 6] = [
            [0.1, 0.5], // x1
            [0.1, 0.5], // y1
            [0.4, 0.9], // x2
            [0.4, 0.9], // y2
            [0.9, 0.3], // score
            [1.0, 2.0], // class
        ];

        let mut output = Array2::<f32>::zeros((num_features, num_detections));
        for (row, values) in detection_rows.iter().enumerate() {
            for (col, &value) in values.iter().enumerate() {
                output[[row, col]] = value;
            }
        }
        // Every mask coefficient is 0.1.
        output.slice_mut(s![6.., ..]).fill(0.1);

        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut detections = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut detections,
            &mut masks,
        )
        .unwrap();

        // Only the first column scores above the 0.5 threshold.
        assert_eq!(detections.len(), 1);
        assert_eq!(masks.len(), 1);
        assert_eq!(detections[0].label, 1);
        assert_close(detections[0].score, 0.9);
    }

    #[test]
    fn test_end_to_end_segdet_mask_coordinates() {
        let num_protos = 32;
        let num_features = 6 + num_protos;

        // x1, y1, x2, y2, score, class; mask coefficients stay zero.
        let detection_column: [f32; 6] = [0.2, 0.2, 0.8, 0.8, 0.95, 3.0];

        let mut output = Array2::<f32>::zeros((num_features, 1));
        for (row, &value) in detection_column.iter().enumerate() {
            output[[row, 0]] = value;
        }
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut detections = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut detections,
            &mut masks,
        )
        .unwrap();

        assert_eq!(detections.len(), 1);
        assert_eq!(masks.len(), 1);

        // The mask must report the same region as its detection.
        let (mask, bbox) = (&masks[0], &detections[0].bbox);
        assert_close(mask.xmin, bbox.xmin);
        assert_close(mask.ymin, bbox.ymin);
        assert_close(mask.xmax, bbox.xmax);
        assert_close(mask.ymax, bbox.ymax);
    }

    #[test]
    fn test_end_to_end_segdet_empty_output() {
        let num_protos = 32;
        let output = Array2::<f32>::zeros((6 + num_protos, 0));
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut detections = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut detections,
            &mut masks,
        )
        .unwrap();

        assert!(detections.is_empty());
        assert!(masks.is_empty());
    }

    #[test]
    fn test_end_to_end_segdet_capacity_limit() {
        let num_protos = 32;
        let num_detections = 5;
        let num_features = 6 + num_protos;

        let mut output = Array2::<f32>::zeros((num_features, num_detections));
        for det in 0..num_detections {
            let origin = 0.1 * (det as f32);
            output[[0, det]] = origin; // x1
            output[[1, det]] = origin; // y1
            output[[2, det]] = origin + 0.2; // x2
            output[[3, det]] = origin + 0.2; // y2
            output[[4, det]] = 0.9; // score
            output[[5, det]] = det as f32; // class
        }
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        // All five columns pass the threshold, but the vectors only have
        // room for two results; the decoder is expected to stop there.
        let mut detections = Vec::with_capacity(2);
        let mut masks = Vec::with_capacity(2);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut detections,
            &mut masks,
        )
        .unwrap();

        assert_eq!(detections.len(), 2);
        assert_eq!(masks.len(), 2);
    }

    #[test]
    fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
        // Six rows: detection features only, no mask coefficients.
        let output = Array2::<f32>::zeros((6, 3));
        let protos = Array3::<f32>::zeros((16, 16, 32));

        let mut detections = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut detections,
            &mut masks,
        );

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
        ));
    }

    #[test]
    fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
        let num_protos = 32;
        // 16 mask coefficient rows versus 32 prototype channels.
        let output = Array2::<f32>::zeros((6 + 16, 3));
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut detections = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut detections,
            &mut masks,
        );

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
        ));
    }

    // ---------------------------------------------------------------------
    // Segmentation map to binary mask.
    // ---------------------------------------------------------------------

    #[test]
    fn test_segmentation_to_mask_basic() {
        #[rustfmt::skip]
        let values: Vec<u8> = vec![
            100, 200,  50, 150,
             10, 255, 128,  64,
              0, 127, 128, 255,
             64,  64, 192, 192,
        ];
        let segmentation = Array3::from_shape_vec((4, 4, 1), values).unwrap();

        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();

        // ((row, col), expected); values equal to the threshold count as set.
        let expectations: [((usize, usize), u8); 8] = [
            ((0, 0), 0), // 100
            ((0, 1), 1), // 200
            ((0, 2), 0), // 50
            ((0, 3), 1), // 150
            ((1, 1), 1), // 255
            ((1, 2), 1), // 128
            ((2, 0), 0), // 0
            ((2, 1), 0), // 127
        ];
        for &((row, col), expected) in expectations.iter() {
            assert_eq!(mask[[row, col]], expected);
        }
    }

    #[test]
    fn test_segmentation_to_mask_all_above() {
        let segmentation = Array3::from_elem((4, 4, 1), 255u8);
        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
        assert!(mask.iter().all(|&cell| cell == 1));
    }

    #[test]
    fn test_segmentation_to_mask_all_below() {
        let segmentation = Array3::from_elem((4, 4, 1), 64u8);
        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
        assert!(mask.iter().all(|&cell| cell == 0));
    }

    #[test]
    fn test_segmentation_to_mask_invalid_shape() {
        // Three channels instead of one.
        let segmentation = Array3::from_elem((4, 4, 3), 128u8);
        let result = yolo_segmentation_to_mask(segmentation.view(), 128);

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("(H, W, 1)")
        ));
    }
}