1use std::fmt::Debug;
5
6use ndarray::{
7 parallel::prelude::{IntoParallelIterator, ParallelIterator},
8 s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
9};
10use ndarray_stats::QuantileExt;
11use num_traits::{AsPrimitive, Float, PrimInt, Signed};
12
13use crate::{
14 byte::{
15 nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
16 postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
17 },
18 configs::Nms,
19 dequant_detect_box,
20 float::{
21 nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
22 postprocess_boxes_float, postprocess_boxes_index_float,
23 },
24 BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, Quantization, Segmentation, XYWH,
25 XYXY,
26};
27
28fn dispatch_nms_float(nms: Option<Nms>, iou: f32, boxes: Vec<DetectBox>) -> Vec<DetectBox> {
30 match nms {
31 Some(Nms::ClassAgnostic) => nms_float(iou, boxes),
32 Some(Nms::ClassAware) => nms_class_aware_float(iou, boxes),
33 None => boxes, }
35}
36
37fn dispatch_nms_extra_float<E: Send + Sync>(
40 nms: Option<Nms>,
41 iou: f32,
42 boxes: Vec<(DetectBox, E)>,
43) -> Vec<(DetectBox, E)> {
44 match nms {
45 Some(Nms::ClassAgnostic) => nms_extra_float(iou, boxes),
46 Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, boxes),
47 None => boxes, }
49}
50
51fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
54 nms: Option<Nms>,
55 iou: f32,
56 boxes: Vec<DetectBoxQuantized<SCORE>>,
57) -> Vec<DetectBoxQuantized<SCORE>> {
58 match nms {
59 Some(Nms::ClassAgnostic) => nms_int(iou, boxes),
60 Some(Nms::ClassAware) => nms_class_aware_int(iou, boxes),
61 None => boxes, }
63}
64
65fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
68 nms: Option<Nms>,
69 iou: f32,
70 boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
71) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
72 match nms {
73 Some(Nms::ClassAgnostic) => nms_extra_int(iou, boxes),
74 Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, boxes),
75 None => boxes, }
77}
78
79pub fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
86 output: (ArrayView2<BOX>, Quantization),
87 score_threshold: f32,
88 iou_threshold: f32,
89 nms: Option<Nms>,
90 output_boxes: &mut Vec<DetectBox>,
91) where
92 f32: AsPrimitive<BOX>,
93{
94 impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
95}
96
97pub fn decode_yolo_det_float<T>(
104 output: ArrayView2<T>,
105 score_threshold: f32,
106 iou_threshold: f32,
107 nms: Option<Nms>,
108 output_boxes: &mut Vec<DetectBox>,
109) where
110 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
111 f32: AsPrimitive<T>,
112{
113 impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
114}
115
116pub fn decode_yolo_segdet_quant<
128 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
129 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
130>(
131 boxes: (ArrayView2<BOX>, Quantization),
132 protos: (ArrayView3<PROTO>, Quantization),
133 score_threshold: f32,
134 iou_threshold: f32,
135 nms: Option<Nms>,
136 output_boxes: &mut Vec<DetectBox>,
137 output_masks: &mut Vec<Segmentation>,
138) where
139 f32: AsPrimitive<BOX>,
140{
141 impl_yolo_segdet_quant::<XYWH, _, _>(
142 boxes,
143 protos,
144 score_threshold,
145 iou_threshold,
146 nms,
147 output_boxes,
148 output_masks,
149 );
150}
151
152pub fn decode_yolo_segdet_float<T>(
164 boxes: ArrayView2<T>,
165 protos: ArrayView3<T>,
166 score_threshold: f32,
167 iou_threshold: f32,
168 nms: Option<Nms>,
169 output_boxes: &mut Vec<DetectBox>,
170 output_masks: &mut Vec<Segmentation>,
171) where
172 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
173 f32: AsPrimitive<T>,
174{
175 impl_yolo_segdet_float::<XYWH, _, _>(
176 boxes,
177 protos,
178 score_threshold,
179 iou_threshold,
180 nms,
181 output_boxes,
182 output_masks,
183 );
184}
185
186pub fn decode_yolo_split_det_quant<
198 BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
199 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
200>(
201 boxes: (ArrayView2<BOX>, Quantization),
202 scores: (ArrayView2<SCORE>, Quantization),
203 score_threshold: f32,
204 iou_threshold: f32,
205 nms: Option<Nms>,
206 output_boxes: &mut Vec<DetectBox>,
207) where
208 f32: AsPrimitive<SCORE>,
209{
210 impl_yolo_split_quant::<XYWH, _, _>(
211 boxes,
212 scores,
213 score_threshold,
214 iou_threshold,
215 nms,
216 output_boxes,
217 );
218}
219
220pub fn decode_yolo_split_det_float<T>(
232 boxes: ArrayView2<T>,
233 scores: ArrayView2<T>,
234 score_threshold: f32,
235 iou_threshold: f32,
236 nms: Option<Nms>,
237 output_boxes: &mut Vec<DetectBox>,
238) where
239 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
240 f32: AsPrimitive<T>,
241{
242 impl_yolo_split_float::<XYWH, _, _>(
243 boxes,
244 scores,
245 score_threshold,
246 iou_threshold,
247 nms,
248 output_boxes,
249 );
250}
251
252#[allow(clippy::too_many_arguments)]
266pub fn decode_yolo_split_segdet<
267 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
268 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
269 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
270 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
271>(
272 boxes: (ArrayView2<BOX>, Quantization),
273 scores: (ArrayView2<SCORE>, Quantization),
274 mask_coeff: (ArrayView2<MASK>, Quantization),
275 protos: (ArrayView3<PROTO>, Quantization),
276 score_threshold: f32,
277 iou_threshold: f32,
278 nms: Option<Nms>,
279 output_boxes: &mut Vec<DetectBox>,
280 output_masks: &mut Vec<Segmentation>,
281) where
282 f32: AsPrimitive<SCORE>,
283{
284 impl_yolo_split_segdet_quant::<XYWH, _, _, _, _>(
285 boxes,
286 scores,
287 mask_coeff,
288 protos,
289 score_threshold,
290 iou_threshold,
291 nms,
292 output_boxes,
293 output_masks,
294 );
295}
296
297#[allow(clippy::too_many_arguments)]
311pub fn decode_yolo_split_segdet_float<T>(
312 boxes: ArrayView2<T>,
313 scores: ArrayView2<T>,
314 mask_coeff: ArrayView2<T>,
315 protos: ArrayView3<T>,
316 score_threshold: f32,
317 iou_threshold: f32,
318 nms: Option<Nms>,
319 output_boxes: &mut Vec<DetectBox>,
320 output_masks: &mut Vec<Segmentation>,
321) where
322 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
323 f32: AsPrimitive<T>,
324{
325 impl_yolo_split_segdet_float::<XYWH, _, _, _, _>(
326 boxes,
327 scores,
328 mask_coeff,
329 protos,
330 score_threshold,
331 iou_threshold,
332 nms,
333 output_boxes,
334 output_masks,
335 );
336}
337
338pub fn decode_yolo_end_to_end_det_float<T>(
353 output: ArrayView2<T>,
354 score_threshold: f32,
355 output_boxes: &mut Vec<DetectBox>,
356) -> Result<(), crate::DecoderError>
357where
358 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
359 f32: AsPrimitive<T>,
360{
361 if output.shape()[0] < 6 {
363 return Err(crate::DecoderError::InvalidShape(format!(
364 "End-to-end detection output requires at least 6 rows, got {}",
365 output.shape()[0]
366 )));
367 }
368
369 let boxes = output.slice(s![0..4, ..]).reversed_axes();
371 let scores = output.slice(s![4..5, ..]).reversed_axes();
372 let classes = output.slice(s![5, ..]);
373 let mut boxes =
374 postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
375 boxes.truncate(output_boxes.capacity());
376 output_boxes.clear();
377 for (mut b, i) in boxes.into_iter() {
378 b.label = classes[i].as_() as usize;
379 output_boxes.push(b);
380 }
381 Ok(())
383}
384
385pub fn decode_yolo_end_to_end_segdet_float<T>(
403 output: ArrayView2<T>,
404 protos: ArrayView3<T>,
405 score_threshold: f32,
406 output_boxes: &mut Vec<DetectBox>,
407 output_masks: &mut Vec<crate::Segmentation>,
408) -> Result<(), crate::DecoderError>
409where
410 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
411 f32: AsPrimitive<T>,
412{
413 if output.shape()[0] < 7 {
415 return Err(crate::DecoderError::InvalidShape(format!(
416 "End-to-end segdet output requires at least 7 rows, got {}",
417 output.shape()[0]
418 )));
419 }
420
421 let num_mask_coeffs = output.shape()[0] - 6;
422 let num_protos = protos.shape()[2];
423 if num_mask_coeffs != num_protos {
424 return Err(crate::DecoderError::InvalidShape(format!(
425 "Mask coefficients count ({}) doesn't match protos count ({})",
426 num_mask_coeffs, num_protos
427 )));
428 }
429
430 let boxes = output.slice(s![0..4, ..]).reversed_axes();
432 let scores = output.slice(s![4..5, ..]).reversed_axes();
433 let classes = output.slice(s![5, ..]);
434 let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
435 let mut boxes =
436 postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
437 boxes.truncate(output_boxes.capacity());
438
439 for (b, ind) in &mut boxes {
440 b.label = classes[*ind].as_() as usize;
441 }
442
443 let boxes = decode_segdet_f32(boxes, mask_coeff, protos);
446
447 output_boxes.clear();
448 output_masks.clear();
449 for (b, m) in boxes.into_iter() {
450 output_boxes.push(b);
451 output_masks.push(Segmentation {
452 xmin: b.bbox.xmin,
453 ymin: b.bbox.ymin,
454 xmax: b.bbox.xmax,
455 ymax: b.bbox.ymax,
456 segmentation: m,
457 });
458 }
459 Ok(())
460}
461pub fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
466 output: (ArrayView2<T>, Quantization),
467 score_threshold: f32,
468 iou_threshold: f32,
469 nms: Option<Nms>,
470 output_boxes: &mut Vec<DetectBox>,
471) where
472 f32: AsPrimitive<T>,
473{
474 let (boxes, quant_boxes) = output;
475 let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
476
477 let boxes = {
478 let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
479 postprocess_boxes_quant::<B, _, _>(
480 score_threshold,
481 boxes_tensor,
482 scores_tensor,
483 quant_boxes,
484 )
485 };
486
487 let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
488 let len = output_boxes.capacity().min(boxes.len());
489 output_boxes.clear();
490 for b in boxes.iter().take(len) {
491 output_boxes.push(dequant_detect_box(b, quant_boxes));
492 }
493}
494
495pub fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
500 output: ArrayView2<T>,
501 score_threshold: f32,
502 iou_threshold: f32,
503 nms: Option<Nms>,
504 output_boxes: &mut Vec<DetectBox>,
505) where
506 f32: AsPrimitive<T>,
507{
508 let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
509 let boxes =
510 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
511 let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
512 let len = output_boxes.capacity().min(boxes.len());
513 output_boxes.clear();
514 for b in boxes.into_iter().take(len) {
515 output_boxes.push(b);
516 }
517}
518
519pub fn impl_yolo_split_quant<
529 B: BBoxTypeTrait,
530 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
531 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
532>(
533 boxes: (ArrayView2<BOX>, Quantization),
534 scores: (ArrayView2<SCORE>, Quantization),
535 score_threshold: f32,
536 iou_threshold: f32,
537 nms: Option<Nms>,
538 output_boxes: &mut Vec<DetectBox>,
539) where
540 f32: AsPrimitive<SCORE>,
541{
542 let (boxes_tensor, quant_boxes) = boxes;
543 let (scores_tensor, quant_scores) = scores;
544
545 let boxes_tensor = boxes_tensor.reversed_axes();
546 let scores_tensor = scores_tensor.reversed_axes();
547
548 let boxes = {
549 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
550 postprocess_boxes_quant::<B, _, _>(
551 score_threshold,
552 boxes_tensor,
553 scores_tensor,
554 quant_boxes,
555 )
556 };
557
558 let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
559 let len = output_boxes.capacity().min(boxes.len());
560 output_boxes.clear();
561 for b in boxes.iter().take(len) {
562 output_boxes.push(dequant_detect_box(b, quant_scores));
563 }
564}
565
566pub fn impl_yolo_split_float<
575 B: BBoxTypeTrait,
576 BOX: Float + AsPrimitive<f32> + Send + Sync,
577 SCORE: Float + AsPrimitive<f32> + Send + Sync,
578>(
579 boxes_tensor: ArrayView2<BOX>,
580 scores_tensor: ArrayView2<SCORE>,
581 score_threshold: f32,
582 iou_threshold: f32,
583 nms: Option<Nms>,
584 output_boxes: &mut Vec<DetectBox>,
585) where
586 f32: AsPrimitive<SCORE>,
587{
588 let boxes_tensor = boxes_tensor.reversed_axes();
589 let scores_tensor = scores_tensor.reversed_axes();
590 let boxes =
591 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
592 let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
593 let len = output_boxes.capacity().min(boxes.len());
594 output_boxes.clear();
595 for b in boxes.into_iter().take(len) {
596 output_boxes.push(b);
597 }
598}
599
600pub fn impl_yolo_segdet_quant<
610 B: BBoxTypeTrait,
611 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
612 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
613>(
614 boxes: (ArrayView2<BOX>, Quantization),
615 protos: (ArrayView3<PROTO>, Quantization),
616 score_threshold: f32,
617 iou_threshold: f32,
618 nms: Option<Nms>,
619 output_boxes: &mut Vec<DetectBox>,
620 output_masks: &mut Vec<Segmentation>,
621) where
622 f32: AsPrimitive<BOX>,
623{
624 let (boxes, quant_boxes) = boxes;
625 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes);
626
627 let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
628 (boxes_tensor.reversed_axes(), quant_boxes),
629 (scores_tensor.reversed_axes(), quant_boxes),
630 score_threshold,
631 iou_threshold,
632 nms,
633 output_boxes.capacity(),
634 );
635
636 impl_yolo_split_segdet_quant_process_masks::<_, _>(
637 boxes,
638 (mask_tensor.reversed_axes(), quant_boxes),
639 protos,
640 output_boxes,
641 output_masks,
642 );
643}
644
645pub fn impl_yolo_segdet_float<
655 B: BBoxTypeTrait,
656 BOX: Float + AsPrimitive<f32> + Send + Sync,
657 PROTO: Float + AsPrimitive<f32> + Send + Sync,
658>(
659 boxes: ArrayView2<BOX>,
660 protos: ArrayView3<PROTO>,
661 score_threshold: f32,
662 iou_threshold: f32,
663 nms: Option<Nms>,
664 output_boxes: &mut Vec<DetectBox>,
665 output_masks: &mut Vec<Segmentation>,
666) where
667 f32: AsPrimitive<BOX>,
668{
669 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes);
670
671 let boxes = postprocess_boxes_index_float::<B, _, _>(
672 score_threshold.as_(),
673 boxes_tensor,
674 scores_tensor,
675 );
676 let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
677 boxes.truncate(output_boxes.capacity());
678 let boxes = decode_segdet_f32(boxes, mask_tensor, protos);
679 output_boxes.clear();
680 output_masks.clear();
681 for (b, m) in boxes.into_iter() {
682 output_boxes.push(b);
683 output_masks.push(Segmentation {
684 xmin: b.bbox.xmin,
685 ymin: b.bbox.ymin,
686 xmax: b.bbox.xmax,
687 ymax: b.bbox.ymax,
688 segmentation: m,
689 });
690 }
691}
692
693pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
694 B: BBoxTypeTrait,
695 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
696 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
697>(
698 boxes: (ArrayView2<BOX>, Quantization),
699 scores: (ArrayView2<SCORE>, Quantization),
700 score_threshold: f32,
701 iou_threshold: f32,
702 nms: Option<Nms>,
703 max_boxes: usize,
704) -> Vec<(DetectBox, usize)>
705where
706 f32: AsPrimitive<SCORE>,
707{
708 let (boxes_tensor, quant_boxes) = boxes;
709 let (scores_tensor, quant_scores) = scores;
710
711 let boxes_tensor = boxes_tensor.reversed_axes();
712 let scores_tensor = scores_tensor.reversed_axes();
713
714 let boxes = {
715 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
716 postprocess_boxes_index_quant::<B, _, _>(
717 score_threshold,
718 boxes_tensor,
719 scores_tensor,
720 quant_boxes,
721 )
722 };
723 let mut boxes = dispatch_nms_extra_int(nms, iou_threshold, boxes);
724 boxes.truncate(max_boxes);
725 boxes
726 .into_iter()
727 .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
728 .collect()
729}
730
731pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
732 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
733 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
734>(
735 boxes: Vec<(DetectBox, usize)>,
736 mask_coeff: (ArrayView2<MASK>, Quantization),
737 protos: (ArrayView3<PROTO>, Quantization),
738 output_boxes: &mut Vec<DetectBox>,
739 output_masks: &mut Vec<Segmentation>,
740) {
741 let (masks, quant_masks) = mask_coeff;
742 let (protos, quant_protos) = protos;
743
744 let masks = masks.reversed_axes();
745
746 let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos);
747 output_boxes.clear();
748 output_masks.clear();
749 for (b, m) in boxes.into_iter() {
750 output_boxes.push(b);
751 output_masks.push(Segmentation {
752 xmin: b.bbox.xmin,
753 ymin: b.bbox.ymin,
754 xmax: b.bbox.xmax,
755 ymax: b.bbox.ymax,
756 segmentation: m,
757 });
758 }
759}
760
761#[allow(clippy::too_many_arguments)]
762pub fn impl_yolo_split_segdet_quant<
774 B: BBoxTypeTrait,
775 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
776 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
777 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
778 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
779>(
780 boxes: (ArrayView2<BOX>, Quantization),
781 scores: (ArrayView2<SCORE>, Quantization),
782 mask_coeff: (ArrayView2<MASK>, Quantization),
783 protos: (ArrayView3<PROTO>, Quantization),
784 score_threshold: f32,
785 iou_threshold: f32,
786 nms: Option<Nms>,
787 output_boxes: &mut Vec<DetectBox>,
788 output_masks: &mut Vec<Segmentation>,
789) where
790 f32: AsPrimitive<SCORE>,
791{
792 let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
793 boxes,
794 scores,
795 score_threshold,
796 iou_threshold,
797 nms,
798 output_boxes.capacity(),
799 );
800
801 impl_yolo_split_segdet_quant_process_masks(
802 boxes,
803 mask_coeff,
804 protos,
805 output_boxes,
806 output_masks,
807 );
808}
809
810#[allow(clippy::too_many_arguments)]
811pub fn impl_yolo_split_segdet_float<
823 B: BBoxTypeTrait,
824 BOX: Float + AsPrimitive<f32> + Send + Sync,
825 SCORE: Float + AsPrimitive<f32> + Send + Sync,
826 MASK: Float + AsPrimitive<f32> + Send + Sync,
827 PROTO: Float + AsPrimitive<f32> + Send + Sync,
828>(
829 boxes_tensor: ArrayView2<BOX>,
830 scores_tensor: ArrayView2<SCORE>,
831 mask_tensor: ArrayView2<MASK>,
832 protos: ArrayView3<PROTO>,
833 score_threshold: f32,
834 iou_threshold: f32,
835 nms: Option<Nms>,
836 output_boxes: &mut Vec<DetectBox>,
837 output_masks: &mut Vec<Segmentation>,
838) where
839 f32: AsPrimitive<SCORE>,
840{
841 let boxes_tensor = boxes_tensor.reversed_axes();
842 let scores_tensor = scores_tensor.reversed_axes();
843 let mask_tensor = mask_tensor.reversed_axes();
844
845 let boxes = postprocess_boxes_index_float::<B, _, _>(
846 score_threshold.as_(),
847 boxes_tensor,
848 scores_tensor,
849 );
850 let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
851 boxes.truncate(output_boxes.capacity());
852 let boxes = decode_segdet_f32(boxes, mask_tensor, protos);
853 output_boxes.clear();
854 output_masks.clear();
855 for (b, m) in boxes.into_iter() {
856 output_boxes.push(b);
857 output_masks.push(Segmentation {
858 xmin: b.bbox.xmin,
859 ymin: b.bbox.ymin,
860 xmax: b.bbox.xmax,
861 ymax: b.bbox.ymax,
862 segmentation: m,
863 });
864 }
865}
866
867fn postprocess_yolo<'a, T>(
868 output: &'a ArrayView2<'_, T>,
869) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
870 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
871 let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
872 (boxes_tensor, scores_tensor)
873}
874
875fn postprocess_yolo_seg<'a, T>(
876 output: &'a ArrayView2<'_, T>,
877) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
878 assert!(output.shape()[0] > 32 + 4, "Output shape is too short");
879 let num_classes = output.shape()[0] - 4 - 32;
880 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
881 let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
882 let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
883 (boxes_tensor, scores_tensor, mask_tensor)
884}
885
886fn decode_segdet_f32<
887 MASK: Float + AsPrimitive<f32> + Send + Sync,
888 PROTO: Float + AsPrimitive<f32> + Send + Sync,
889>(
890 boxes: Vec<(DetectBox, usize)>,
891 masks: ArrayView2<MASK>,
892 protos: ArrayView3<PROTO>,
893) -> Vec<(DetectBox, Array3<u8>)> {
894 if boxes.is_empty() {
895 return Vec::new();
896 }
897 assert!(masks.shape()[1] == protos.shape()[2]);
898 boxes
899 .into_par_iter()
900 .map(|mut b| {
901 let ind = b.1;
902 let (protos, roi) = protobox(&protos, &b.0.bbox);
903 b.0.bbox = roi;
904 (b.0, make_segmentation(masks.row(ind), protos.view()))
905 })
906 .collect()
907}
908
909pub(crate) fn decode_segdet_quant<
910 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
911 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
912>(
913 boxes: Vec<(DetectBox, usize)>,
914 masks: ArrayView2<MASK>,
915 protos: ArrayView3<PROTO>,
916 quant_masks: Quantization,
917 quant_protos: Quantization,
918) -> Vec<(DetectBox, Array3<u8>)> {
919 if boxes.is_empty() {
920 return Vec::new();
921 }
922 assert!(masks.shape()[1] == protos.shape()[2]);
923
924 let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; boxes
926 .into_iter()
927 .map(|mut b| {
928 let i = b.1;
929 let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical());
930 b.0.bbox = roi;
931 let seg = match total_bits {
932 0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
933 masks.row(i),
934 protos.view(),
935 quant_masks,
936 quant_protos,
937 ),
938 65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
939 masks.row(i),
940 protos.view(),
941 quant_masks,
942 quant_protos,
943 ),
944 _ => panic!("Unsupported bit width for segmentation computation"),
945 };
946 (b.0, seg)
947 })
948 .collect()
949}
950
951fn protobox<'a, T>(
952 protos: &'a ArrayView3<T>,
953 roi: &BoundingBox,
954) -> (ArrayView3<'a, T>, BoundingBox) {
955 let width = protos.dim().1 as f32;
956 let height = protos.dim().0 as f32;
957
958 let roi = [
959 (roi.xmin * width).clamp(0.0, width) as usize,
960 (roi.ymin * height).clamp(0.0, height) as usize,
961 (roi.xmax * width).clamp(0.0, width).ceil() as usize,
962 (roi.ymax * height).clamp(0.0, height).ceil() as usize,
963 ];
964
965 let roi_norm = [
966 roi[0] as f32 / width,
967 roi[1] as f32 / height,
968 roi[2] as f32 / width,
969 roi[3] as f32 / height,
970 ]
971 .into();
972
973 let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
974
975 (cropped, roi_norm)
976}
977
978fn make_segmentation<
979 MASK: Float + AsPrimitive<f32> + Send + Sync,
980 PROTO: Float + AsPrimitive<f32> + Send + Sync,
981>(
982 mask: ArrayView1<MASK>,
983 protos: ArrayView3<PROTO>,
984) -> Array3<u8> {
985 let shape = protos.shape();
986
987 let mask = mask.to_shape((1, mask.len())).unwrap();
989 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
990 let protos = protos.reversed_axes();
991 let mask = mask.map(|x| x.as_());
992 let protos = protos.map(|x| x.as_());
993
994 let mask = mask
996 .dot(&protos)
997 .into_shape_with_order((shape[0], shape[1], 1))
998 .unwrap();
999
1000 let min = *mask.min().unwrap_or(&0.0);
1001 let max = *mask.max().unwrap_or(&1.0);
1002 let max = max.max(-min);
1003 let min = -max;
1004 let u8_max = 256.0;
1005 mask.map(|x| ((*x - min) / (max - min) * u8_max) as u8)
1006}
1007
1008fn make_segmentation_quant<
1009 MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
1010 PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
1011 DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
1012>(
1013 mask: ArrayView1<MASK>,
1014 protos: ArrayView3<PROTO>,
1015 quant_masks: Quantization,
1016 quant_protos: Quantization,
1017) -> Array3<u8>
1018where
1019 i32: AsPrimitive<DEST>,
1020 f32: AsPrimitive<DEST>,
1021{
1022 let shape = protos.shape();
1023
1024 let mask = mask.to_shape((1, mask.len())).unwrap();
1026
1027 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1028 let protos = protos.reversed_axes();
1029
1030 let zp = quant_masks.zero_point.as_();
1031
1032 let mask = mask.mapv(|x| x.as_() - zp);
1033
1034 let zp = quant_protos.zero_point.as_();
1035 let protos = protos.mapv(|x| x.as_() - zp);
1036
1037 let segmentation = mask
1039 .dot(&protos)
1040 .into_shape_with_order((shape[0], shape[1], 1))
1041 .unwrap();
1042
1043 let min = *segmentation.min().unwrap_or(&DEST::zero());
1044 let max = *segmentation.max().unwrap_or(&DEST::one());
1045 let max = max.max(-min);
1046 let min = -max;
1047 segmentation.map(|x| ((*x - min).as_() / (max - min).as_() * 256.0) as u8)
1048}
1049
1050pub fn yolo_segmentation_to_mask(
1062 segmentation: ArrayView3<u8>,
1063 threshold: u8,
1064) -> Result<Array2<u8>, crate::DecoderError> {
1065 if segmentation.shape()[2] != 1 {
1066 return Err(crate::DecoderError::InvalidShape(format!(
1067 "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
1068 segmentation.shape()[2]
1069 )));
1070 }
1071 Ok(segmentation
1072 .slice(s![.., .., 0])
1073 .map(|x| if *x >= threshold { 1 } else { 0 }))
1074}
1075
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use ndarray::Array2;

    // End-to-end detection tensors are laid out as (rows, detections): rows
    // 0..4 are XYXY box coordinates, row 4 is the score and row 5 the class id.

    #[test]
    fn test_end_to_end_det_basic_filtering() {
        // Three detections; only the first (score 0.9) clears the 0.5 threshold.
        let data: Vec<f32> = vec![
            0.1, 0.2, 0.3, // xmin
            0.1, 0.2, 0.3, // ymin
            0.5, 0.6, 0.7, // xmax
            0.5, 0.6, 0.7, // ymax
            0.9, 0.1, 0.2, // score
            0.0, 1.0, 2.0, // class
        ];
        let output = Array2::from_shape_vec((6, 3), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 1);
        assert_eq!(boxes[0].label, 0);
        assert!((boxes[0].score - 0.9).abs() < 0.01);
        assert!((boxes[0].bbox.xmin - 0.1).abs() < 0.01);
        assert!((boxes[0].bbox.ymin - 0.1).abs() < 0.01);
        assert!((boxes[0].bbox.xmax - 0.5).abs() < 0.01);
        assert!((boxes[0].bbox.ymax - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_end_to_end_det_all_pass_threshold() {
        // Two detections, both above threshold; labels come from row 5 in column order.
        let data: Vec<f32> = vec![
            10.0, 20.0, // xmin
            10.0, 20.0, // ymin
            50.0, 60.0, // xmax
            50.0, 60.0, // ymax
            0.8, 0.7, // score
            1.0, 2.0, // class
        ];
        let output = Array2::from_shape_vec((6, 2), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 2);
        assert_eq!(boxes[0].label, 1);
        assert_eq!(boxes[1].label, 2);
    }

    #[test]
    fn test_end_to_end_det_none_pass_threshold() {
        // Both scores are below 0.5, so nothing survives.
        let data: Vec<f32> = vec![
            10.0, 20.0, // xmin
            10.0, 20.0, // ymin
            50.0, 60.0, // xmax
            50.0, 60.0, // ymax
            0.1, 0.2, // score
            1.0, 2.0, // class
        ];
        let output = Array2::from_shape_vec((6, 2), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 0);
    }

    #[test]
    fn test_end_to_end_det_capacity_limit() {
        // Five passing detections, but the output vector's capacity caps the result.
        let data: Vec<f32> = vec![
            0.1, 0.2, 0.3, 0.4, 0.5, // xmin
            0.1, 0.2, 0.3, 0.4, 0.5, // ymin
            0.5, 0.6, 0.7, 0.8, 0.9, // xmax
            0.5, 0.6, 0.7, 0.8, 0.9, // ymax
            0.9, 0.9, 0.9, 0.9, 0.9, // score
            0.0, 1.0, 2.0, 3.0, 4.0, // class
        ];
        let output = Array2::from_shape_vec((6, 5), data).unwrap();

        // NOTE(review): relies on `Vec::with_capacity(2)` reporting a capacity of
        // exactly 2, which std guarantees only as a lower bound.
        let mut boxes = Vec::with_capacity(2);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 2);
    }

    #[test]
    fn test_end_to_end_det_empty_output() {
        // Correct row count but zero detections.
        let output = Array2::<f32>::zeros((6, 0));

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 0);
    }

    #[test]
    fn test_end_to_end_det_pixel_coordinates() {
        // Coordinates are passed through unscaled, whatever their units.
        let data: Vec<f32> = vec![
            100.0, // xmin
            200.0, // ymin
            300.0, // xmax
            400.0, // ymax
            0.95,  // score
            5.0,   // class
        ];
        let output = Array2::from_shape_vec((6, 1), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 1);
        assert_eq!(boxes[0].label, 5);
        assert!((boxes[0].bbox.xmin - 100.0).abs() < 0.01);
        assert!((boxes[0].bbox.ymin - 200.0).abs() < 0.01);
        assert!((boxes[0].bbox.xmax - 300.0).abs() < 0.01);
        assert!((boxes[0].bbox.ymax - 400.0).abs() < 0.01);
    }

    #[test]
    fn test_end_to_end_det_invalid_shape() {
        // Only 5 rows: the class row is missing.
        let output = Array2::<f32>::zeros((5, 3));

        let mut boxes = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
        ));
    }

    // End-to-end segmentation adds mask-coefficient rows 6.. (one per prototype).

    #[test]
    fn test_end_to_end_segdet_basic() {
        let num_protos = 32;
        let num_detections = 2;
        let num_features = 6 + num_protos;

        // Row-major (num_features, num_detections): element (r, c) is data[r * num_detections + c].
        let mut data = vec![0.0f32; num_features * num_detections];
        data[0] = 0.1; // xmin, det 0
        data[1] = 0.5; // xmin, det 1
        data[num_detections] = 0.1; // ymin, det 0
        data[num_detections + 1] = 0.5; // ymin, det 1
        data[2 * num_detections] = 0.4; // xmax, det 0
        data[2 * num_detections + 1] = 0.9; // xmax, det 1
        data[3 * num_detections] = 0.4; // ymax, det 0
        data[3 * num_detections + 1] = 0.9; // ymax, det 1
        data[4 * num_detections] = 0.9; // score, det 0 (passes)
        data[4 * num_detections + 1] = 0.3; // score, det 1 (filtered)
        data[5 * num_detections] = 1.0; // class, det 0
        data[5 * num_detections + 1] = 2.0; // class, det 1
        // Uniform mask coefficients for both detections.
        for i in 6..num_features {
            data[i * num_detections] = 0.1;
            data[i * num_detections + 1] = 0.1;
        }

        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();

        // (H, W, num_protos) prototypes.
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        // Boxes and masks are produced in lockstep.
        assert_eq!(boxes.len(), 1);
        assert_eq!(masks.len(), 1);
        assert_eq!(boxes[0].label, 1);
        assert!((boxes[0].score - 0.9).abs() < 0.01);
    }

    #[test]
    fn test_end_to_end_segdet_mask_coordinates() {
        let num_protos = 32;
        let num_features = 6 + num_protos;

        // Single detection; coefficient rows stay zero.
        let mut data = vec![0.0f32; num_features];
        data[0] = 0.2; // xmin
        data[1] = 0.2; // ymin
        data[2] = 0.8; // xmax
        data[3] = 0.8; // ymax
        data[4] = 0.95; // score
        data[5] = 3.0; // class
        let output = Array2::from_shape_vec((num_features, 1), data).unwrap();
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 1);
        assert_eq!(masks.len(), 1);

        // Mask bounds must mirror the (prototype-aligned) box bounds.
        assert!((masks[0].xmin - boxes[0].bbox.xmin).abs() < 0.01);
        assert!((masks[0].ymin - boxes[0].bbox.ymin).abs() < 0.01);
        assert!((masks[0].xmax - boxes[0].bbox.xmax).abs() < 0.01);
        assert!((masks[0].ymax - boxes[0].bbox.ymax).abs() < 0.01);
    }

    #[test]
    fn test_end_to_end_segdet_empty_output() {
        // Valid row count, zero detections.
        let num_protos = 32;
        let output = Array2::<f32>::zeros((6 + num_protos, 0));
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 0);
        assert_eq!(masks.len(), 0);
    }

    #[test]
    fn test_end_to_end_segdet_capacity_limit() {
        let num_protos = 32;
        let num_detections = 5;
        let num_features = 6 + num_protos;

        // Five passing detections with increasing box positions and class ids.
        let mut data = vec![0.0f32; num_features * num_detections];
        for i in 0..num_detections {
            data[i] = 0.1 * (i as f32); // xmin
            data[num_detections + i] = 0.1 * (i as f32); // ymin
            data[2 * num_detections + i] = 0.1 * (i as f32) + 0.2; // xmax
            data[3 * num_detections + i] = 0.1 * (i as f32) + 0.2; // ymax
            data[4 * num_detections + i] = 0.9; // score
            data[5 * num_detections + i] = i as f32; // class
        }

        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        // The capacity of `boxes` bounds how many detections are returned.
        let mut boxes = Vec::with_capacity(2);
        let mut masks = Vec::with_capacity(2);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 2);
        assert_eq!(masks.len(), 2);
    }

    #[test]
    fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
        // 6 rows leaves no room for mask coefficients.
        let output = Array2::<f32>::zeros((6, 3));
        let protos = Array3::<f32>::zeros((16, 16, 32));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        );

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
        ));
    }

    #[test]
    fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
        let num_protos = 32;
        // 16 coefficient rows versus 32 prototypes.
        let output = Array2::<f32>::zeros((6 + 16, 3));
        let protos = Array3::<f32>::zeros((16, 16, num_protos));
        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        );

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
        ));
    }

    #[test]
    fn test_segmentation_to_mask_basic() {
        // 4x4 single-channel segmentation, row-major.
        let data: Vec<u8> = vec![
            100, 200, 50, 150, // row 0
            10, 255, 128, 64, // row 1
            0, 127, 128, 255, // row 2
            64, 64, 192, 192, // row 3
        ];
        let segmentation = Array3::from_shape_vec((4, 4, 1), data).unwrap();

        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();

        // Threshold is inclusive: 128 maps to 1, 127 maps to 0.
        assert_eq!(mask[[0, 0]], 0); // 100
        assert_eq!(mask[[0, 1]], 1); // 200
        assert_eq!(mask[[0, 2]], 0); // 50
        assert_eq!(mask[[0, 3]], 1); // 150
        assert_eq!(mask[[1, 1]], 1); // 255
        assert_eq!(mask[[1, 2]], 1); // 128
        assert_eq!(mask[[2, 0]], 0); // 0
        assert_eq!(mask[[2, 1]], 0); // 127
    }

    #[test]
    fn test_segmentation_to_mask_all_above() {
        let segmentation = Array3::from_elem((4, 4, 1), 255u8);
        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
        assert!(mask.iter().all(|&x| x == 1));
    }

    #[test]
    fn test_segmentation_to_mask_all_below() {
        let segmentation = Array3::from_elem((4, 4, 1), 64u8);
        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
        assert!(mask.iter().all(|&x| x == 0));
    }

    #[test]
    fn test_segmentation_to_mask_invalid_shape() {
        // Three channels instead of one.
        let segmentation = Array3::from_elem((4, 4, 3), 128u8);
        let result = yolo_segmentation_to_mask(segmentation.view(), 128);

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("(H, W, 1)")
        ));
    }
}