1use std::fmt::Debug;
5
6use ndarray::{
7 parallel::prelude::{IntoParallelIterator, ParallelIterator},
8 s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
9};
10use num_traits::{AsPrimitive, Float, PrimInt, Signed};
11
12use crate::{
13 byte::{
14 nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
15 postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
16 },
17 configs::Nms,
18 dequant_detect_box,
19 float::{
20 nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
21 postprocess_boxes_float, postprocess_boxes_index_float,
22 },
23 BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoTensor,
24 Quantization, Segmentation, XYWH, XYXY,
25};
26
27fn dispatch_nms_float(nms: Option<Nms>, iou: f32, boxes: Vec<DetectBox>) -> Vec<DetectBox> {
29 match nms {
30 Some(Nms::ClassAgnostic) => nms_float(iou, boxes),
31 Some(Nms::ClassAware) => nms_class_aware_float(iou, boxes),
32 None => boxes, }
34}
35
36pub(super) fn dispatch_nms_extra_float<E: Send + Sync>(
39 nms: Option<Nms>,
40 iou: f32,
41 boxes: Vec<(DetectBox, E)>,
42) -> Vec<(DetectBox, E)> {
43 match nms {
44 Some(Nms::ClassAgnostic) => nms_extra_float(iou, boxes),
45 Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, boxes),
46 None => boxes, }
48}
49
50fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
53 nms: Option<Nms>,
54 iou: f32,
55 boxes: Vec<DetectBoxQuantized<SCORE>>,
56) -> Vec<DetectBoxQuantized<SCORE>> {
57 match nms {
58 Some(Nms::ClassAgnostic) => nms_int(iou, boxes),
59 Some(Nms::ClassAware) => nms_class_aware_int(iou, boxes),
60 None => boxes, }
62}
63
64fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
67 nms: Option<Nms>,
68 iou: f32,
69 boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
70) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
71 match nms {
72 Some(Nms::ClassAgnostic) => nms_extra_int(iou, boxes),
73 Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, boxes),
74 None => boxes, }
76}
77
78pub fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
85 output: (ArrayView2<BOX>, Quantization),
86 score_threshold: f32,
87 iou_threshold: f32,
88 nms: Option<Nms>,
89 output_boxes: &mut Vec<DetectBox>,
90) where
91 f32: AsPrimitive<BOX>,
92{
93 impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
94}
95
96pub fn decode_yolo_det_float<T>(
103 output: ArrayView2<T>,
104 score_threshold: f32,
105 iou_threshold: f32,
106 nms: Option<Nms>,
107 output_boxes: &mut Vec<DetectBox>,
108) where
109 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
110 f32: AsPrimitive<T>,
111{
112 impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
113}
114
115pub fn decode_yolo_segdet_quant<
127 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
128 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
129>(
130 boxes: (ArrayView2<BOX>, Quantization),
131 protos: (ArrayView3<PROTO>, Quantization),
132 score_threshold: f32,
133 iou_threshold: f32,
134 nms: Option<Nms>,
135 output_boxes: &mut Vec<DetectBox>,
136 output_masks: &mut Vec<Segmentation>,
137) -> Result<(), crate::DecoderError>
138where
139 f32: AsPrimitive<BOX>,
140{
141 impl_yolo_segdet_quant::<XYWH, _, _>(
142 boxes,
143 protos,
144 score_threshold,
145 iou_threshold,
146 nms,
147 output_boxes,
148 output_masks,
149 )
150}
151
152pub fn decode_yolo_segdet_float<T>(
164 boxes: ArrayView2<T>,
165 protos: ArrayView3<T>,
166 score_threshold: f32,
167 iou_threshold: f32,
168 nms: Option<Nms>,
169 output_boxes: &mut Vec<DetectBox>,
170 output_masks: &mut Vec<Segmentation>,
171) -> Result<(), crate::DecoderError>
172where
173 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
174 f32: AsPrimitive<T>,
175{
176 impl_yolo_segdet_float::<XYWH, _, _>(
177 boxes,
178 protos,
179 score_threshold,
180 iou_threshold,
181 nms,
182 output_boxes,
183 output_masks,
184 )
185}
186
187pub fn decode_yolo_split_det_quant<
199 BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
200 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
201>(
202 boxes: (ArrayView2<BOX>, Quantization),
203 scores: (ArrayView2<SCORE>, Quantization),
204 score_threshold: f32,
205 iou_threshold: f32,
206 nms: Option<Nms>,
207 output_boxes: &mut Vec<DetectBox>,
208) where
209 f32: AsPrimitive<SCORE>,
210{
211 impl_yolo_split_quant::<XYWH, _, _>(
212 boxes,
213 scores,
214 score_threshold,
215 iou_threshold,
216 nms,
217 output_boxes,
218 );
219}
220
221pub fn decode_yolo_split_det_float<T>(
233 boxes: ArrayView2<T>,
234 scores: ArrayView2<T>,
235 score_threshold: f32,
236 iou_threshold: f32,
237 nms: Option<Nms>,
238 output_boxes: &mut Vec<DetectBox>,
239) where
240 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
241 f32: AsPrimitive<T>,
242{
243 impl_yolo_split_float::<XYWH, _, _>(
244 boxes,
245 scores,
246 score_threshold,
247 iou_threshold,
248 nms,
249 output_boxes,
250 );
251}
252
253#[allow(clippy::too_many_arguments)]
267pub fn decode_yolo_split_segdet<
268 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
269 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
270 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
271 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
272>(
273 boxes: (ArrayView2<BOX>, Quantization),
274 scores: (ArrayView2<SCORE>, Quantization),
275 mask_coeff: (ArrayView2<MASK>, Quantization),
276 protos: (ArrayView3<PROTO>, Quantization),
277 score_threshold: f32,
278 iou_threshold: f32,
279 nms: Option<Nms>,
280 output_boxes: &mut Vec<DetectBox>,
281 output_masks: &mut Vec<Segmentation>,
282) -> Result<(), crate::DecoderError>
283where
284 f32: AsPrimitive<SCORE>,
285{
286 impl_yolo_split_segdet_quant::<XYWH, _, _, _, _>(
287 boxes,
288 scores,
289 mask_coeff,
290 protos,
291 score_threshold,
292 iou_threshold,
293 nms,
294 output_boxes,
295 output_masks,
296 )
297}
298
299#[allow(clippy::too_many_arguments)]
313pub fn decode_yolo_split_segdet_float<T>(
314 boxes: ArrayView2<T>,
315 scores: ArrayView2<T>,
316 mask_coeff: ArrayView2<T>,
317 protos: ArrayView3<T>,
318 score_threshold: f32,
319 iou_threshold: f32,
320 nms: Option<Nms>,
321 output_boxes: &mut Vec<DetectBox>,
322 output_masks: &mut Vec<Segmentation>,
323) -> Result<(), crate::DecoderError>
324where
325 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
326 f32: AsPrimitive<T>,
327{
328 impl_yolo_split_segdet_float::<XYWH, _, _, _, _>(
329 boxes,
330 scores,
331 mask_coeff,
332 protos,
333 score_threshold,
334 iou_threshold,
335 nms,
336 output_boxes,
337 output_masks,
338 )
339}
340
341pub fn decode_yolo_end_to_end_det_float<T>(
356 output: ArrayView2<T>,
357 score_threshold: f32,
358 output_boxes: &mut Vec<DetectBox>,
359) -> Result<(), crate::DecoderError>
360where
361 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
362 f32: AsPrimitive<T>,
363{
364 if output.shape()[0] < 6 {
366 return Err(crate::DecoderError::InvalidShape(format!(
367 "End-to-end detection output requires at least 6 rows, got {}",
368 output.shape()[0]
369 )));
370 }
371
372 let boxes = output.slice(s![0..4, ..]).reversed_axes();
374 let scores = output.slice(s![4..5, ..]).reversed_axes();
375 let classes = output.slice(s![5, ..]);
376 let mut boxes =
377 postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
378 boxes.truncate(output_boxes.capacity());
379 output_boxes.clear();
380 for (mut b, i) in boxes.into_iter() {
381 b.label = classes[i].as_() as usize;
382 output_boxes.push(b);
383 }
384 Ok(())
386}
387
388pub fn decode_yolo_end_to_end_segdet_float<T>(
406 output: ArrayView2<T>,
407 protos: ArrayView3<T>,
408 score_threshold: f32,
409 output_boxes: &mut Vec<DetectBox>,
410 output_masks: &mut Vec<crate::Segmentation>,
411) -> Result<(), crate::DecoderError>
412where
413 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
414 f32: AsPrimitive<T>,
415{
416 let (boxes, scores, classes, mask_coeff) =
417 postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
418 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
419 boxes,
420 scores,
421 classes,
422 score_threshold,
423 output_boxes.capacity(),
424 );
425
426 impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
429}
430
431pub fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
440 boxes: ArrayView2<T>,
441 scores: ArrayView2<T>,
442 classes: ArrayView2<T>,
443 score_threshold: f32,
444 output_boxes: &mut Vec<DetectBox>,
445) -> Result<(), crate::DecoderError> {
446 let n = boxes.shape()[0];
447 if boxes.shape()[1] != 4 {
448 return Err(crate::DecoderError::InvalidShape(format!(
449 "Split end-to-end boxes must have 4 columns, got {}",
450 boxes.shape()[1]
451 )));
452 }
453 output_boxes.clear();
454 for i in 0..n {
455 let score: f32 = scores[[i, 0]].as_();
456 if score < score_threshold {
457 continue;
458 }
459 if output_boxes.len() >= output_boxes.capacity() {
460 break;
461 }
462 output_boxes.push(DetectBox {
463 bbox: BoundingBox {
464 xmin: boxes[[i, 0]].as_(),
465 ymin: boxes[[i, 1]].as_(),
466 xmax: boxes[[i, 2]].as_(),
467 ymax: boxes[[i, 3]].as_(),
468 },
469 score,
470 label: classes[[i, 0]].as_() as usize,
471 });
472 }
473 Ok(())
474}
475
476#[allow(clippy::too_many_arguments)]
485pub fn decode_yolo_split_end_to_end_segdet_float<T>(
486 boxes: ArrayView2<T>,
487 scores: ArrayView2<T>,
488 classes: ArrayView2<T>,
489 mask_coeff: ArrayView2<T>,
490 protos: ArrayView3<T>,
491 score_threshold: f32,
492 output_boxes: &mut Vec<DetectBox>,
493 output_masks: &mut Vec<crate::Segmentation>,
494) -> Result<(), crate::DecoderError>
495where
496 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
497 f32: AsPrimitive<T>,
498{
499 let (boxes, scores, classes, mask_coeff) =
500 postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
501 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
502 boxes,
503 scores,
504 classes,
505 score_threshold,
506 output_boxes.capacity(),
507 );
508
509 impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
510}
511
512#[allow(clippy::type_complexity)]
513pub(crate) fn postprocess_yolo_end_to_end_segdet<'a, T>(
514 output: &'a ArrayView2<'_, T>,
515 num_protos: usize,
516) -> Result<
517 (
518 ArrayView2<'a, T>,
519 ArrayView2<'a, T>,
520 ArrayView1<'a, T>,
521 ArrayView2<'a, T>,
522 ),
523 crate::DecoderError,
524> {
525 if output.shape()[0] < 7 {
527 return Err(crate::DecoderError::InvalidShape(format!(
528 "End-to-end segdet output requires at least 7 rows, got {}",
529 output.shape()[0]
530 )));
531 }
532
533 let num_mask_coeffs = output.shape()[0] - 6;
534 if num_mask_coeffs != num_protos {
535 return Err(crate::DecoderError::InvalidShape(format!(
536 "Mask coefficients count ({}) doesn't match protos count ({})",
537 num_mask_coeffs, num_protos
538 )));
539 }
540
541 let boxes = output.slice(s![0..4, ..]).reversed_axes();
543 let scores = output.slice(s![4..5, ..]).reversed_axes();
544 let classes = output.slice(s![5, ..]);
545 let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
546 Ok((boxes, scores, classes, mask_coeff))
547}
548
549#[allow(clippy::type_complexity)]
550pub(crate) fn postprocess_yolo_split_end_to_end_segdet<
551 'a,
552 'b,
553 'c,
554 'd,
555 BOXES,
556 SCORES,
557 CLASS,
558 MASK,
559>(
560 boxes: ArrayView2<'a, BOXES>,
561 scores: ArrayView2<'b, SCORES>,
562 classes: &'c ArrayView2<CLASS>,
563 mask_coeff: ArrayView2<'d, MASK>,
564) -> Result<
565 (
566 ArrayView2<'a, BOXES>,
567 ArrayView2<'b, SCORES>,
568 ArrayView1<'c, CLASS>,
569 ArrayView2<'d, MASK>,
570 ),
571 crate::DecoderError,
572> {
573 if boxes.shape()[0] != 4 {
574 return Err(crate::DecoderError::InvalidShape(format!(
575 "Split end-to-end boxes must have 4 columns, got {}",
576 boxes.shape()[0]
577 )));
578 }
579 let boxes = boxes.reversed_axes();
580 let scores = scores.reversed_axes();
581 let classes = classes.slice(s![0, ..]);
582 let mask_coeff = mask_coeff.reversed_axes();
583 Ok((boxes, scores, classes, mask_coeff))
584}
585pub(crate) fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
590 output: (ArrayView2<T>, Quantization),
591 score_threshold: f32,
592 iou_threshold: f32,
593 nms: Option<Nms>,
594 output_boxes: &mut Vec<DetectBox>,
595) where
596 f32: AsPrimitive<T>,
597{
598 let (boxes, quant_boxes) = output;
599 let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
600
601 let boxes = {
602 let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
603 postprocess_boxes_quant::<B, _, _>(
604 score_threshold,
605 boxes_tensor,
606 scores_tensor,
607 quant_boxes,
608 )
609 };
610
611 let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
612 let len = output_boxes.capacity().min(boxes.len());
613 output_boxes.clear();
614 for b in boxes.iter().take(len) {
615 output_boxes.push(dequant_detect_box(b, quant_boxes));
616 }
617}
618
619pub(crate) fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
624 output: ArrayView2<T>,
625 score_threshold: f32,
626 iou_threshold: f32,
627 nms: Option<Nms>,
628 output_boxes: &mut Vec<DetectBox>,
629) where
630 f32: AsPrimitive<T>,
631{
632 let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
633 let boxes =
634 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
635 let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
636 let len = output_boxes.capacity().min(boxes.len());
637 output_boxes.clear();
638 for b in boxes.into_iter().take(len) {
639 output_boxes.push(b);
640 }
641}
642
643pub(crate) fn impl_yolo_split_quant<
653 B: BBoxTypeTrait,
654 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
655 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
656>(
657 boxes: (ArrayView2<BOX>, Quantization),
658 scores: (ArrayView2<SCORE>, Quantization),
659 score_threshold: f32,
660 iou_threshold: f32,
661 nms: Option<Nms>,
662 output_boxes: &mut Vec<DetectBox>,
663) where
664 f32: AsPrimitive<SCORE>,
665{
666 let (boxes_tensor, quant_boxes) = boxes;
667 let (scores_tensor, quant_scores) = scores;
668
669 let boxes_tensor = boxes_tensor.reversed_axes();
670 let scores_tensor = scores_tensor.reversed_axes();
671
672 let boxes = {
673 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
674 postprocess_boxes_quant::<B, _, _>(
675 score_threshold,
676 boxes_tensor,
677 scores_tensor,
678 quant_boxes,
679 )
680 };
681
682 let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
683 let len = output_boxes.capacity().min(boxes.len());
684 output_boxes.clear();
685 for b in boxes.iter().take(len) {
686 output_boxes.push(dequant_detect_box(b, quant_scores));
687 }
688}
689
690pub(crate) fn impl_yolo_split_float<
699 B: BBoxTypeTrait,
700 BOX: Float + AsPrimitive<f32> + Send + Sync,
701 SCORE: Float + AsPrimitive<f32> + Send + Sync,
702>(
703 boxes_tensor: ArrayView2<BOX>,
704 scores_tensor: ArrayView2<SCORE>,
705 score_threshold: f32,
706 iou_threshold: f32,
707 nms: Option<Nms>,
708 output_boxes: &mut Vec<DetectBox>,
709) where
710 f32: AsPrimitive<SCORE>,
711{
712 let boxes_tensor = boxes_tensor.reversed_axes();
713 let scores_tensor = scores_tensor.reversed_axes();
714 let boxes =
715 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
716 let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
717 let len = output_boxes.capacity().min(boxes.len());
718 output_boxes.clear();
719 for b in boxes.into_iter().take(len) {
720 output_boxes.push(b);
721 }
722}
723
724pub(crate) fn impl_yolo_segdet_quant<
734 B: BBoxTypeTrait,
735 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
736 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
737>(
738 boxes: (ArrayView2<BOX>, Quantization),
739 protos: (ArrayView3<PROTO>, Quantization),
740 score_threshold: f32,
741 iou_threshold: f32,
742 nms: Option<Nms>,
743 output_boxes: &mut Vec<DetectBox>,
744 output_masks: &mut Vec<Segmentation>,
745) -> Result<(), crate::DecoderError>
746where
747 f32: AsPrimitive<BOX>,
748{
749 let (boxes, quant_boxes) = boxes;
750 let num_protos = protos.0.dim().2;
751
752 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
753 let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
754 (boxes_tensor, quant_boxes),
755 (scores_tensor, quant_boxes),
756 score_threshold,
757 iou_threshold,
758 nms,
759 output_boxes.capacity(),
760 );
761
762 impl_yolo_split_segdet_quant_process_masks::<_, _>(
763 boxes,
764 (mask_tensor, quant_boxes),
765 protos,
766 output_boxes,
767 output_masks,
768 )
769}
770
771pub(crate) fn impl_yolo_segdet_float<
781 B: BBoxTypeTrait,
782 BOX: Float + AsPrimitive<f32> + Send + Sync,
783 PROTO: Float + AsPrimitive<f32> + Send + Sync,
784>(
785 boxes: ArrayView2<BOX>,
786 protos: ArrayView3<PROTO>,
787 score_threshold: f32,
788 iou_threshold: f32,
789 nms: Option<Nms>,
790 output_boxes: &mut Vec<DetectBox>,
791 output_masks: &mut Vec<Segmentation>,
792) -> Result<(), crate::DecoderError>
793where
794 f32: AsPrimitive<BOX>,
795{
796 let num_protos = protos.dim().2;
797 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
798 let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
799 boxes_tensor,
800 scores_tensor,
801 score_threshold,
802 iou_threshold,
803 nms,
804 output_boxes.capacity(),
805 );
806 impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
807}
808
809pub(crate) fn impl_yolo_segdet_get_boxes<
810 B: BBoxTypeTrait,
811 BOX: Float + AsPrimitive<f32> + Send + Sync,
812 SCORE: Float + AsPrimitive<f32> + Send + Sync,
813>(
814 boxes_tensor: ArrayView2<BOX>,
815 scores_tensor: ArrayView2<SCORE>,
816 score_threshold: f32,
817 iou_threshold: f32,
818 nms: Option<Nms>,
819 max_boxes: usize,
820) -> Vec<(DetectBox, usize)>
821where
822 f32: AsPrimitive<SCORE>,
823{
824 let boxes = postprocess_boxes_index_float::<B, _, _>(
825 score_threshold.as_(),
826 boxes_tensor,
827 scores_tensor,
828 );
829 let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
830 boxes.truncate(max_boxes);
831 boxes
832}
833
834pub(crate) fn impl_yolo_end_to_end_segdet_get_boxes<
835 B: BBoxTypeTrait,
836 BOX: Float + AsPrimitive<f32> + Send + Sync,
837 SCORE: Float + AsPrimitive<f32> + Send + Sync,
838 CLASS: AsPrimitive<f32> + Send + Sync,
839>(
840 boxes: ArrayView2<BOX>,
841 scores: ArrayView2<SCORE>,
842 classes: ArrayView1<CLASS>,
843 score_threshold: f32,
844 max_boxes: usize,
845) -> Vec<(DetectBox, usize)>
846where
847 f32: AsPrimitive<SCORE>,
848{
849 let mut boxes = postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes, scores);
850 boxes.truncate(max_boxes);
851 for (b, ind) in &mut boxes {
852 b.label = classes[*ind].as_().round() as usize;
853 }
854 boxes
855}
856
857pub(crate) fn impl_yolo_split_segdet_process_masks<
858 MASK: Float + AsPrimitive<f32> + Send + Sync,
859 PROTO: Float + AsPrimitive<f32> + Send + Sync,
860>(
861 boxes: Vec<(DetectBox, usize)>,
862 masks_tensor: ArrayView2<MASK>,
863 protos_tensor: ArrayView3<PROTO>,
864 output_boxes: &mut Vec<DetectBox>,
865 output_masks: &mut Vec<Segmentation>,
866) -> Result<(), crate::DecoderError> {
867 let boxes = decode_segdet_f32(boxes, masks_tensor, protos_tensor)?;
868 output_boxes.clear();
869 output_masks.clear();
870 for (b, m) in boxes.into_iter() {
871 output_boxes.push(b);
872 output_masks.push(Segmentation {
873 xmin: b.bbox.xmin,
874 ymin: b.bbox.ymin,
875 xmax: b.bbox.xmax,
876 ymax: b.bbox.ymax,
877 segmentation: m,
878 });
879 }
880 Ok(())
881}
882pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
886 B: BBoxTypeTrait,
887 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
888 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
889>(
890 boxes: (ArrayView2<BOX>, Quantization),
891 scores: (ArrayView2<SCORE>, Quantization),
892 score_threshold: f32,
893 iou_threshold: f32,
894 nms: Option<Nms>,
895 max_boxes: usize,
896) -> Vec<(DetectBox, usize)>
897where
898 f32: AsPrimitive<SCORE>,
899{
900 let (boxes_tensor, quant_boxes) = boxes;
901 let (scores_tensor, quant_scores) = scores;
902
903 let boxes = {
904 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
905 postprocess_boxes_index_quant::<B, _, _>(
906 score_threshold,
907 boxes_tensor,
908 scores_tensor,
909 quant_boxes,
910 )
911 };
912 let mut boxes = dispatch_nms_extra_int(nms, iou_threshold, boxes);
913 boxes.truncate(max_boxes);
914 boxes
915 .into_iter()
916 .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
917 .collect()
918}
919
920pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
921 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
922 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
923>(
924 boxes: Vec<(DetectBox, usize)>,
925 mask_coeff: (ArrayView2<MASK>, Quantization),
926 protos: (ArrayView3<PROTO>, Quantization),
927 output_boxes: &mut Vec<DetectBox>,
928 output_masks: &mut Vec<Segmentation>,
929) -> Result<(), crate::DecoderError> {
930 let (masks, quant_masks) = mask_coeff;
931 let (protos, quant_protos) = protos;
932
933 let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
934 output_boxes.clear();
935 output_masks.clear();
936 for (b, m) in boxes.into_iter() {
937 output_boxes.push(b);
938 output_masks.push(Segmentation {
939 xmin: b.bbox.xmin,
940 ymin: b.bbox.ymin,
941 xmax: b.bbox.xmax,
942 ymax: b.bbox.ymax,
943 segmentation: m,
944 });
945 }
946 Ok(())
947}
948
949#[allow(clippy::too_many_arguments)]
950pub(crate) fn impl_yolo_split_segdet_quant<
962 B: BBoxTypeTrait,
963 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
964 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
965 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
966 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
967>(
968 boxes: (ArrayView2<BOX>, Quantization),
969 scores: (ArrayView2<SCORE>, Quantization),
970 mask_coeff: (ArrayView2<MASK>, Quantization),
971 protos: (ArrayView3<PROTO>, Quantization),
972 score_threshold: f32,
973 iou_threshold: f32,
974 nms: Option<Nms>,
975 output_boxes: &mut Vec<DetectBox>,
976 output_masks: &mut Vec<Segmentation>,
977) -> Result<(), crate::DecoderError>
978where
979 f32: AsPrimitive<SCORE>,
980{
981 let (boxes_, scores_, mask_coeff_) =
982 postprocess_yolo_split_segdet(boxes.0, scores.0, mask_coeff.0);
983 let boxes = (boxes_, boxes.1);
984 let scores = (scores_, scores.1);
985 let mask_coeff = (mask_coeff_, mask_coeff.1);
986
987 let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
988 boxes,
989 scores,
990 score_threshold,
991 iou_threshold,
992 nms,
993 output_boxes.capacity(),
994 );
995
996 impl_yolo_split_segdet_quant_process_masks(
997 boxes,
998 mask_coeff,
999 protos,
1000 output_boxes,
1001 output_masks,
1002 )
1003}
1004
1005#[allow(clippy::too_many_arguments)]
1006pub(crate) fn impl_yolo_split_segdet_float<
1018 B: BBoxTypeTrait,
1019 BOX: Float + AsPrimitive<f32> + Send + Sync,
1020 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1021 MASK: Float + AsPrimitive<f32> + Send + Sync,
1022 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1023>(
1024 boxes_tensor: ArrayView2<BOX>,
1025 scores_tensor: ArrayView2<SCORE>,
1026 mask_tensor: ArrayView2<MASK>,
1027 protos: ArrayView3<PROTO>,
1028 score_threshold: f32,
1029 iou_threshold: f32,
1030 nms: Option<Nms>,
1031 output_boxes: &mut Vec<DetectBox>,
1032 output_masks: &mut Vec<Segmentation>,
1033) -> Result<(), crate::DecoderError>
1034where
1035 f32: AsPrimitive<SCORE>,
1036{
1037 let (boxes_tensor, scores_tensor, mask_tensor) =
1038 postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1039
1040 let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1041 boxes_tensor,
1042 scores_tensor,
1043 score_threshold,
1044 iou_threshold,
1045 nms,
1046 output_boxes.capacity(),
1047 );
1048 impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1049}
1050
1051pub fn impl_yolo_segdet_quant_proto<
1058 B: BBoxTypeTrait,
1059 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1060 PROTO: PrimInt
1061 + AsPrimitive<i64>
1062 + AsPrimitive<i128>
1063 + AsPrimitive<f32>
1064 + AsPrimitive<i8>
1065 + Send
1066 + Sync,
1067>(
1068 boxes: (ArrayView2<BOX>, Quantization),
1069 protos: (ArrayView3<PROTO>, Quantization),
1070 score_threshold: f32,
1071 iou_threshold: f32,
1072 nms: Option<Nms>,
1073 output_boxes: &mut Vec<DetectBox>,
1074) -> ProtoData
1075where
1076 f32: AsPrimitive<BOX>,
1077{
1078 let (boxes_arr, quant_boxes) = boxes;
1079 let (protos_arr, quant_protos) = protos;
1080 let num_protos = protos_arr.dim().2;
1081
1082 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1083
1084 let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1085 (boxes_tensor, quant_boxes),
1086 (scores_tensor, quant_boxes),
1087 score_threshold,
1088 iou_threshold,
1089 nms,
1090 output_boxes.capacity(),
1091 );
1092
1093 extract_proto_data_quant(
1094 det_indices,
1095 mask_tensor,
1096 quant_boxes,
1097 protos_arr,
1098 quant_protos,
1099 output_boxes,
1100 )
1101}
1102
1103pub(crate) fn impl_yolo_segdet_float_proto<
1106 B: BBoxTypeTrait,
1107 BOX: Float + AsPrimitive<f32> + Send + Sync,
1108 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1109>(
1110 boxes: ArrayView2<BOX>,
1111 protos: ArrayView3<PROTO>,
1112 score_threshold: f32,
1113 iou_threshold: f32,
1114 nms: Option<Nms>,
1115 output_boxes: &mut Vec<DetectBox>,
1116) -> ProtoData
1117where
1118 f32: AsPrimitive<BOX>,
1119{
1120 let num_protos = protos.dim().2;
1121 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1122
1123 let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1124 boxes_tensor,
1125 scores_tensor,
1126 score_threshold,
1127 iou_threshold,
1128 nms,
1129 output_boxes.capacity(),
1130 );
1131
1132 extract_proto_data_float(boxes, mask_tensor, protos, output_boxes)
1133}
1134
1135#[allow(clippy::too_many_arguments)]
1138pub(crate) fn impl_yolo_split_segdet_float_proto<
1139 B: BBoxTypeTrait,
1140 BOX: Float + AsPrimitive<f32> + Send + Sync,
1141 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1142 MASK: Float + AsPrimitive<f32> + Send + Sync,
1143 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1144>(
1145 boxes_tensor: ArrayView2<BOX>,
1146 scores_tensor: ArrayView2<SCORE>,
1147 mask_tensor: ArrayView2<MASK>,
1148 protos: ArrayView3<PROTO>,
1149 score_threshold: f32,
1150 iou_threshold: f32,
1151 nms: Option<Nms>,
1152 output_boxes: &mut Vec<DetectBox>,
1153) -> ProtoData
1154where
1155 f32: AsPrimitive<SCORE>,
1156{
1157 let (boxes_tensor, scores_tensor, mask_tensor) =
1158 postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1159 let det_indices = impl_yolo_segdet_get_boxes::<B, _, _>(
1160 boxes_tensor,
1161 scores_tensor,
1162 score_threshold,
1163 iou_threshold,
1164 nms,
1165 output_boxes.capacity(),
1166 );
1167
1168 extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1169}
1170
1171pub fn decode_yolo_end_to_end_segdet_float_proto<T>(
1173 output: ArrayView2<T>,
1174 protos: ArrayView3<T>,
1175 score_threshold: f32,
1176 output_boxes: &mut Vec<DetectBox>,
1177) -> Result<ProtoData, crate::DecoderError>
1178where
1179 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1180 f32: AsPrimitive<T>,
1181{
1182 let (boxes, scores, classes, mask_coeff) =
1183 postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
1184 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1185 boxes,
1186 scores,
1187 classes,
1188 score_threshold,
1189 output_boxes.capacity(),
1190 );
1191
1192 Ok(extract_proto_data_float(
1193 boxes,
1194 mask_coeff,
1195 protos,
1196 output_boxes,
1197 ))
1198}
1199
1200#[allow(clippy::too_many_arguments)]
1202pub fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1203 boxes: ArrayView2<T>,
1204 scores: ArrayView2<T>,
1205 classes: ArrayView2<T>,
1206 mask_coeff: ArrayView2<T>,
1207 protos: ArrayView3<T>,
1208 score_threshold: f32,
1209 output_boxes: &mut Vec<DetectBox>,
1210) -> Result<ProtoData, crate::DecoderError>
1211where
1212 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1213 f32: AsPrimitive<T>,
1214{
1215 let (boxes, scores, classes, mask_coeff) =
1216 postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
1217 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1218 boxes,
1219 scores,
1220 classes,
1221 score_threshold,
1222 output_boxes.capacity(),
1223 );
1224
1225 Ok(extract_proto_data_float(
1226 boxes,
1227 mask_coeff,
1228 protos,
1229 output_boxes,
1230 ))
1231}
1232
1233pub(super) fn extract_proto_data_float<
1235 MASK: Float + AsPrimitive<f32> + Send + Sync,
1236 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1237>(
1238 det_indices: Vec<(DetectBox, usize)>,
1239 mask_tensor: ArrayView2<MASK>,
1240 protos: ArrayView3<PROTO>,
1241 output_boxes: &mut Vec<DetectBox>,
1242) -> ProtoData {
1243 let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1244 output_boxes.clear();
1245 for (det, idx) in det_indices {
1246 output_boxes.push(det);
1247 let row = mask_tensor.row(idx);
1248 mask_coefficients.push(row.iter().map(|v| v.as_()).collect());
1249 }
1250 let protos_f32 = protos.map(|v| v.as_());
1251 ProtoData {
1252 mask_coefficients,
1253 protos: ProtoTensor::Float(protos_f32),
1254 }
1255}
1256
1257pub(crate) fn extract_proto_data_quant<
1263 MASK: PrimInt + AsPrimitive<f32> + Send + Sync,
1264 PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
1265>(
1266 det_indices: Vec<(DetectBox, usize)>,
1267 mask_tensor: ArrayView2<MASK>,
1268 quant_masks: Quantization,
1269 protos: ArrayView3<PROTO>,
1270 quant_protos: Quantization,
1271 output_boxes: &mut Vec<DetectBox>,
1272) -> ProtoData {
1273 let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1274 output_boxes.clear();
1275 for (det, idx) in det_indices {
1276 output_boxes.push(det);
1277 let row = mask_tensor.row(idx);
1278 mask_coefficients.push(
1279 row.iter()
1280 .map(|v| (v.as_() - quant_masks.zero_point as f32) * quant_masks.scale)
1281 .collect(),
1282 );
1283 }
1284 let protos_i8 = if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1288 let view_i8 =
1290 unsafe { &*(&protos as *const ArrayView3<'_, PROTO> as *const ArrayView3<'_, i8>) };
1291 view_i8.to_owned()
1292 } else {
1293 protos.map(|v| {
1294 let v_i8: i8 = v.as_();
1295 v_i8
1296 })
1297 };
1298 ProtoData {
1299 mask_coefficients,
1300 protos: ProtoTensor::Quantized {
1301 protos: protos_i8,
1302 quantization: quant_protos,
1303 },
1304 }
1305}
1306
1307fn postprocess_yolo<'a, T>(
1308 output: &'a ArrayView2<'_, T>,
1309) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1310 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1311 let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1312 (boxes_tensor, scores_tensor)
1313}
1314
1315pub(crate) fn postprocess_yolo_seg<'a, T>(
1316 output: &'a ArrayView2<'_, T>,
1317 num_protos: usize,
1318) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1319 assert!(
1320 output.shape()[0] > num_protos + 4,
1321 "Output shape is too short: {} <= {} + 4",
1322 output.shape()[0],
1323 num_protos
1324 );
1325 let num_classes = output.shape()[0] - 4 - num_protos;
1326 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1327 let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1328 let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1329 (boxes_tensor, scores_tensor, mask_tensor)
1330}
1331
1332pub(crate) fn postprocess_yolo_split_segdet<'a, 'b, 'c, BOX, SCORE, MASK>(
1333 boxes_tensor: ArrayView2<'a, BOX>,
1334 scores_tensor: ArrayView2<'b, SCORE>,
1335 mask_tensor: ArrayView2<'c, MASK>,
1336) -> (
1337 ArrayView2<'a, BOX>,
1338 ArrayView2<'b, SCORE>,
1339 ArrayView2<'c, MASK>,
1340) {
1341 let boxes_tensor = boxes_tensor.reversed_axes();
1342 let scores_tensor = scores_tensor.reversed_axes();
1343 let mask_tensor = mask_tensor.reversed_axes();
1344 (boxes_tensor, scores_tensor, mask_tensor)
1345}
1346
1347fn decode_segdet_f32<
1348 MASK: Float + AsPrimitive<f32> + Send + Sync,
1349 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1350>(
1351 boxes: Vec<(DetectBox, usize)>,
1352 masks: ArrayView2<MASK>,
1353 protos: ArrayView3<PROTO>,
1354) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1355 if boxes.is_empty() {
1356 return Ok(Vec::new());
1357 }
1358 if masks.shape()[1] != protos.shape()[2] {
1359 return Err(crate::DecoderError::InvalidShape(format!(
1360 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1361 masks.shape()[1],
1362 protos.shape()[2],
1363 )));
1364 }
1365 boxes
1366 .into_par_iter()
1367 .map(|mut b| {
1368 let ind = b.1;
1369 let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1370 b.0.bbox = roi;
1371 Ok((b.0, make_segmentation(masks.row(ind), protos.view())))
1372 })
1373 .collect()
1374}
1375
1376pub(crate) fn decode_segdet_quant<
1377 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1378 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1379>(
1380 boxes: Vec<(DetectBox, usize)>,
1381 masks: ArrayView2<MASK>,
1382 protos: ArrayView3<PROTO>,
1383 quant_masks: Quantization,
1384 quant_protos: Quantization,
1385) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1386 if boxes.is_empty() {
1387 return Ok(Vec::new());
1388 }
1389 if masks.shape()[1] != protos.shape()[2] {
1390 return Err(crate::DecoderError::InvalidShape(format!(
1391 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1392 masks.shape()[1],
1393 protos.shape()[2],
1394 )));
1395 }
1396
1397 let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; boxes
1399 .into_iter()
1400 .map(|mut b| {
1401 let i = b.1;
1402 let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1403 b.0.bbox = roi;
1404 let seg = match total_bits {
1405 0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1406 masks.row(i),
1407 protos.view(),
1408 quant_masks,
1409 quant_protos,
1410 ),
1411 65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1412 masks.row(i),
1413 protos.view(),
1414 quant_masks,
1415 quant_protos,
1416 ),
1417 _ => {
1418 return Err(crate::DecoderError::NotSupported(format!(
1419 "Unsupported bit width ({total_bits}) for segmentation computation"
1420 )));
1421 }
1422 };
1423 Ok((b.0, seg))
1424 })
1425 .collect()
1426}
1427
1428fn protobox<'a, T>(
1429 protos: &'a ArrayView3<T>,
1430 roi: &BoundingBox,
1431) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1432 let width = protos.dim().1 as f32;
1433 let height = protos.dim().0 as f32;
1434
1435 const NORM_LIMIT: f32 = 2.0;
1446 if roi.xmin > NORM_LIMIT
1447 || roi.ymin > NORM_LIMIT
1448 || roi.xmax > NORM_LIMIT
1449 || roi.ymax > NORM_LIMIT
1450 {
1451 return Err(crate::DecoderError::InvalidShape(format!(
1452 "Bounding box coordinates appear un-normalized (pixel-space). \
1453 Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
1454 ONNX models output pixel-space boxes — normalize them by dividing by \
1455 the input dimensions before calling decode().",
1456 roi.xmin, roi.ymin, roi.xmax, roi.ymax,
1457 )));
1458 }
1459
1460 let roi = [
1461 (roi.xmin * width).clamp(0.0, width) as usize,
1462 (roi.ymin * height).clamp(0.0, height) as usize,
1463 (roi.xmax * width).clamp(0.0, width).ceil() as usize,
1464 (roi.ymax * height).clamp(0.0, height).ceil() as usize,
1465 ];
1466
1467 let roi_norm = [
1468 roi[0] as f32 / width,
1469 roi[1] as f32 / height,
1470 roi[2] as f32 / width,
1471 roi[3] as f32 / height,
1472 ]
1473 .into();
1474
1475 let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
1476
1477 Ok((cropped, roi_norm))
1478}
1479
1480fn make_segmentation<
1486 MASK: Float + AsPrimitive<f32> + Send + Sync,
1487 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1488>(
1489 mask: ArrayView1<MASK>,
1490 protos: ArrayView3<PROTO>,
1491) -> Array3<u8> {
1492 let shape = protos.shape();
1493
1494 let mask = mask.to_shape((1, mask.len())).unwrap();
1496 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1497 let protos = protos.reversed_axes();
1498 let mask = mask.map(|x| x.as_());
1499 let protos = protos.map(|x| x.as_());
1500
1501 let mask = mask
1503 .dot(&protos)
1504 .into_shape_with_order((shape[0], shape[1], 1))
1505 .unwrap();
1506
1507 mask.map(|x| {
1508 let sigmoid = 1.0 / (1.0 + (-*x).exp());
1509 (sigmoid * 255.0).round() as u8
1510 })
1511}
1512
1513fn make_segmentation_quant<
1520 MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
1521 PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
1522 DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
1523>(
1524 mask: ArrayView1<MASK>,
1525 protos: ArrayView3<PROTO>,
1526 quant_masks: Quantization,
1527 quant_protos: Quantization,
1528) -> Array3<u8>
1529where
1530 i32: AsPrimitive<DEST>,
1531 f32: AsPrimitive<DEST>,
1532{
1533 let shape = protos.shape();
1534
1535 let mask = mask.to_shape((1, mask.len())).unwrap();
1537
1538 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1539 let protos = protos.reversed_axes();
1540
1541 let zp = quant_masks.zero_point.as_();
1542
1543 let mask = mask.mapv(|x| x.as_() - zp);
1544
1545 let zp = quant_protos.zero_point.as_();
1546 let protos = protos.mapv(|x| x.as_() - zp);
1547
1548 let segmentation = mask
1550 .dot(&protos)
1551 .into_shape_with_order((shape[0], shape[1], 1))
1552 .unwrap();
1553
1554 let combined_scale = quant_masks.scale * quant_protos.scale;
1555 segmentation.map(|x| {
1556 let val: f32 = (*x).as_() * combined_scale;
1557 let sigmoid = 1.0 / (1.0 + (-val).exp());
1558 (sigmoid * 255.0).round() as u8
1559 })
1560}
1561
1562pub fn yolo_segmentation_to_mask(
1574 segmentation: ArrayView3<u8>,
1575 threshold: u8,
1576) -> Result<Array2<u8>, crate::DecoderError> {
1577 if segmentation.shape()[2] != 1 {
1578 return Err(crate::DecoderError::InvalidShape(format!(
1579 "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
1580 segmentation.shape()[2]
1581 )));
1582 }
1583 Ok(segmentation
1584 .slice(s![.., .., 0])
1585 .map(|x| if *x >= threshold { 1 } else { 0 }))
1586}
1587
1588#[cfg(test)]
1589#[cfg_attr(coverage_nightly, coverage(off))]
1590mod tests {
1591 use super::*;
1592 use ndarray::Array2;
1593
1594 #[test]
1599 fn test_end_to_end_det_basic_filtering() {
1600 let data: Vec<f32> = vec![
1604 0.1, 0.2, 0.3, 0.1, 0.2, 0.3, 0.5, 0.6, 0.7, 0.5, 0.6, 0.7, 0.9, 0.1, 0.2, 0.0, 1.0, 2.0, ];
1612 let output = Array2::from_shape_vec((6, 3), data).unwrap();
1613
1614 let mut boxes = Vec::with_capacity(10);
1615 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1616
1617 assert_eq!(boxes.len(), 1);
1619 assert_eq!(boxes[0].label, 0);
1620 assert!((boxes[0].score - 0.9).abs() < 0.01);
1621 assert!((boxes[0].bbox.xmin - 0.1).abs() < 0.01);
1622 assert!((boxes[0].bbox.ymin - 0.1).abs() < 0.01);
1623 assert!((boxes[0].bbox.xmax - 0.5).abs() < 0.01);
1624 assert!((boxes[0].bbox.ymax - 0.5).abs() < 0.01);
1625 }
1626
1627 #[test]
1628 fn test_end_to_end_det_all_pass_threshold() {
1629 let data: Vec<f32> = vec![
1631 10.0, 20.0, 10.0, 20.0, 50.0, 60.0, 50.0, 60.0, 0.8, 0.7, 1.0, 2.0, ];
1638 let output = Array2::from_shape_vec((6, 2), data).unwrap();
1639
1640 let mut boxes = Vec::with_capacity(10);
1641 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1642
1643 assert_eq!(boxes.len(), 2);
1644 assert_eq!(boxes[0].label, 1);
1645 assert_eq!(boxes[1].label, 2);
1646 }
1647
1648 #[test]
1649 fn test_end_to_end_det_none_pass_threshold() {
1650 let data: Vec<f32> = vec![
1652 10.0, 20.0, 10.0, 20.0, 50.0, 60.0, 50.0, 60.0, 0.1, 0.2, 1.0, 2.0, ];
1659 let output = Array2::from_shape_vec((6, 2), data).unwrap();
1660
1661 let mut boxes = Vec::with_capacity(10);
1662 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1663
1664 assert_eq!(boxes.len(), 0);
1665 }
1666
1667 #[test]
1668 fn test_end_to_end_det_capacity_limit() {
1669 let data: Vec<f32> = vec![
1671 0.1, 0.2, 0.3, 0.4, 0.5, 0.1, 0.2, 0.3, 0.4, 0.5, 0.5, 0.6, 0.7, 0.8, 0.9, 0.5, 0.6, 0.7, 0.8, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.0, 1.0, 2.0, 3.0, 4.0, ];
1678 let output = Array2::from_shape_vec((6, 5), data).unwrap();
1679
1680 let mut boxes = Vec::with_capacity(2); decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1682
1683 assert_eq!(boxes.len(), 2);
1684 }
1685
1686 #[test]
1687 fn test_end_to_end_det_empty_output() {
1688 let output = Array2::<f32>::zeros((6, 0));
1690
1691 let mut boxes = Vec::with_capacity(10);
1692 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1693
1694 assert_eq!(boxes.len(), 0);
1695 }
1696
1697 #[test]
1698 fn test_end_to_end_det_pixel_coordinates() {
1699 let data: Vec<f32> = vec![
1701 100.0, 200.0, 300.0, 400.0, 0.95, 5.0, ];
1708 let output = Array2::from_shape_vec((6, 1), data).unwrap();
1709
1710 let mut boxes = Vec::with_capacity(10);
1711 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1712
1713 assert_eq!(boxes.len(), 1);
1714 assert_eq!(boxes[0].label, 5);
1715 assert!((boxes[0].bbox.xmin - 100.0).abs() < 0.01);
1716 assert!((boxes[0].bbox.ymin - 200.0).abs() < 0.01);
1717 assert!((boxes[0].bbox.xmax - 300.0).abs() < 0.01);
1718 assert!((boxes[0].bbox.ymax - 400.0).abs() < 0.01);
1719 }
1720
1721 #[test]
1722 fn test_end_to_end_det_invalid_shape() {
1723 let output = Array2::<f32>::zeros((5, 3));
1725
1726 let mut boxes = Vec::with_capacity(10);
1727 let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);
1728
1729 assert!(result.is_err());
1730 assert!(matches!(
1731 result,
1732 Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
1733 ));
1734 }
1735
1736 #[test]
1741 fn test_end_to_end_segdet_basic() {
1742 let num_protos = 32;
1745 let num_detections = 2;
1746 let num_features = 6 + num_protos;
1747
1748 let mut data = vec![0.0f32; num_features * num_detections];
1750 data[0] = 0.1; data[1] = 0.5; data[num_detections] = 0.1; data[num_detections + 1] = 0.5; data[2 * num_detections] = 0.4; data[2 * num_detections + 1] = 0.9; data[3 * num_detections] = 0.4; data[3 * num_detections + 1] = 0.9; data[4 * num_detections] = 0.9; data[4 * num_detections + 1] = 0.3; data[5 * num_detections] = 1.0; data[5 * num_detections + 1] = 2.0; for i in 6..num_features {
1765 data[i * num_detections] = 0.1;
1766 data[i * num_detections + 1] = 0.1;
1767 }
1768
1769 let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
1770
1771 let protos = Array3::<f32>::zeros((16, 16, num_protos));
1773
1774 let mut boxes = Vec::with_capacity(10);
1775 let mut masks = Vec::with_capacity(10);
1776 decode_yolo_end_to_end_segdet_float(
1777 output.view(),
1778 protos.view(),
1779 0.5,
1780 &mut boxes,
1781 &mut masks,
1782 )
1783 .unwrap();
1784
1785 assert_eq!(boxes.len(), 1);
1787 assert_eq!(masks.len(), 1);
1788 assert_eq!(boxes[0].label, 1);
1789 assert!((boxes[0].score - 0.9).abs() < 0.01);
1790 }
1791
1792 #[test]
1793 fn test_end_to_end_segdet_mask_coordinates() {
1794 let num_protos = 32;
1796 let num_features = 6 + num_protos;
1797
1798 let mut data = vec![0.0f32; num_features];
1799 data[0] = 0.2; data[1] = 0.2; data[2] = 0.8; data[3] = 0.8; data[4] = 0.95; data[5] = 3.0; let output = Array2::from_shape_vec((num_features, 1), data).unwrap();
1807 let protos = Array3::<f32>::zeros((16, 16, num_protos));
1808
1809 let mut boxes = Vec::with_capacity(10);
1810 let mut masks = Vec::with_capacity(10);
1811 decode_yolo_end_to_end_segdet_float(
1812 output.view(),
1813 protos.view(),
1814 0.5,
1815 &mut boxes,
1816 &mut masks,
1817 )
1818 .unwrap();
1819
1820 assert_eq!(boxes.len(), 1);
1821 assert_eq!(masks.len(), 1);
1822
1823 assert!((masks[0].xmin - boxes[0].bbox.xmin).abs() < 0.01);
1825 assert!((masks[0].ymin - boxes[0].bbox.ymin).abs() < 0.01);
1826 assert!((masks[0].xmax - boxes[0].bbox.xmax).abs() < 0.01);
1827 assert!((masks[0].ymax - boxes[0].bbox.ymax).abs() < 0.01);
1828 }
1829
1830 #[test]
1831 fn test_end_to_end_segdet_empty_output() {
1832 let num_protos = 32;
1833 let output = Array2::<f32>::zeros((6 + num_protos, 0));
1834 let protos = Array3::<f32>::zeros((16, 16, num_protos));
1835
1836 let mut boxes = Vec::with_capacity(10);
1837 let mut masks = Vec::with_capacity(10);
1838 decode_yolo_end_to_end_segdet_float(
1839 output.view(),
1840 protos.view(),
1841 0.5,
1842 &mut boxes,
1843 &mut masks,
1844 )
1845 .unwrap();
1846
1847 assert_eq!(boxes.len(), 0);
1848 assert_eq!(masks.len(), 0);
1849 }
1850
1851 #[test]
1852 fn test_end_to_end_segdet_capacity_limit() {
1853 let num_protos = 32;
1854 let num_detections = 5;
1855 let num_features = 6 + num_protos;
1856
1857 let mut data = vec![0.0f32; num_features * num_detections];
1858 for i in 0..num_detections {
1860 data[i] = 0.1 * (i as f32); data[num_detections + i] = 0.1 * (i as f32); data[2 * num_detections + i] = 0.1 * (i as f32) + 0.2; data[3 * num_detections + i] = 0.1 * (i as f32) + 0.2; data[4 * num_detections + i] = 0.9; data[5 * num_detections + i] = i as f32; }
1867
1868 let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
1869 let protos = Array3::<f32>::zeros((16, 16, num_protos));
1870
1871 let mut boxes = Vec::with_capacity(2); let mut masks = Vec::with_capacity(2);
1873 decode_yolo_end_to_end_segdet_float(
1874 output.view(),
1875 protos.view(),
1876 0.5,
1877 &mut boxes,
1878 &mut masks,
1879 )
1880 .unwrap();
1881
1882 assert_eq!(boxes.len(), 2);
1883 assert_eq!(masks.len(), 2);
1884 }
1885
1886 #[test]
1887 fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
1888 let output = Array2::<f32>::zeros((6, 3));
1890 let protos = Array3::<f32>::zeros((16, 16, 32));
1891
1892 let mut boxes = Vec::with_capacity(10);
1893 let mut masks = Vec::with_capacity(10);
1894 let result = decode_yolo_end_to_end_segdet_float(
1895 output.view(),
1896 protos.view(),
1897 0.5,
1898 &mut boxes,
1899 &mut masks,
1900 );
1901
1902 assert!(result.is_err());
1903 assert!(matches!(
1904 result,
1905 Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
1906 ));
1907 }
1908
1909 #[test]
1910 fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
1911 let num_protos = 32;
1913 let output = Array2::<f32>::zeros((6 + 16, 3)); let protos = Array3::<f32>::zeros((16, 16, num_protos)); let mut boxes = Vec::with_capacity(10);
1917 let mut masks = Vec::with_capacity(10);
1918 let result = decode_yolo_end_to_end_segdet_float(
1919 output.view(),
1920 protos.view(),
1921 0.5,
1922 &mut boxes,
1923 &mut masks,
1924 );
1925
1926 assert!(result.is_err());
1927 assert!(matches!(
1928 result,
1929 Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
1930 ));
1931 }
1932
1933 #[test]
1938 fn test_split_end_to_end_segdet_basic() {
1939 let num_protos = 32;
1942 let num_detections = 2;
1943 let num_features = 6 + num_protos;
1944
1945 let mut data = vec![0.0f32; num_features * num_detections];
1947 data[0] = 0.1; data[1] = 0.5; data[num_detections] = 0.1; data[num_detections + 1] = 0.5; data[2 * num_detections] = 0.4; data[2 * num_detections + 1] = 0.9; data[3 * num_detections] = 0.4; data[3 * num_detections + 1] = 0.9; data[4 * num_detections] = 0.9; data[4 * num_detections + 1] = 0.3; data[5 * num_detections] = 1.0; data[5 * num_detections + 1] = 2.0; for i in 6..num_features {
1962 data[i * num_detections] = 0.1;
1963 data[i * num_detections + 1] = 0.1;
1964 }
1965
1966 let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
1967 let box_coords = output.slice(s![..4, ..]);
1968 let scores = output.slice(s![4..5, ..]);
1969 let classes = output.slice(s![5..6, ..]);
1970 let mask_coeff = output.slice(s![6.., ..]);
1971 let protos = Array3::<f32>::zeros((16, 16, num_protos));
1973
1974 let mut boxes = Vec::with_capacity(10);
1975 let mut masks = Vec::with_capacity(10);
1976 decode_yolo_split_end_to_end_segdet_float(
1977 box_coords,
1978 scores,
1979 classes,
1980 mask_coeff,
1981 protos.view(),
1982 0.5,
1983 &mut boxes,
1984 &mut masks,
1985 )
1986 .unwrap();
1987
1988 assert_eq!(boxes.len(), 1);
1990 assert_eq!(masks.len(), 1);
1991 assert_eq!(boxes[0].label, 1);
1992 assert!((boxes[0].score - 0.9).abs() < 0.01);
1993 }
1994
1995 #[test]
2000 fn test_segmentation_to_mask_basic() {
2001 let data: Vec<u8> = vec![
2003 100, 200, 50, 150, 10, 255, 128, 64, 0, 127, 128, 255, 64, 64, 192, 192, ];
2008 let segmentation = Array3::from_shape_vec((4, 4, 1), data).unwrap();
2009
2010 let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2011
2012 assert_eq!(mask[[0, 0]], 0); assert_eq!(mask[[0, 1]], 1); assert_eq!(mask[[0, 2]], 0); assert_eq!(mask[[0, 3]], 1); assert_eq!(mask[[1, 1]], 1); assert_eq!(mask[[1, 2]], 1); assert_eq!(mask[[2, 0]], 0); assert_eq!(mask[[2, 1]], 0); }
2022
2023 #[test]
2024 fn test_segmentation_to_mask_all_above() {
2025 let segmentation = Array3::from_elem((4, 4, 1), 255u8);
2026 let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2027 assert!(mask.iter().all(|&x| x == 1));
2028 }
2029
2030 #[test]
2031 fn test_segmentation_to_mask_all_below() {
2032 let segmentation = Array3::from_elem((4, 4, 1), 64u8);
2033 let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2034 assert!(mask.iter().all(|&x| x == 0));
2035 }
2036
2037 #[test]
2038 fn test_segmentation_to_mask_invalid_shape() {
2039 let segmentation = Array3::from_elem((4, 4, 3), 128u8);
2040 let result = yolo_segmentation_to_mask(segmentation.view(), 128);
2041
2042 assert!(result.is_err());
2043 assert!(matches!(
2044 result,
2045 Err(crate::DecoderError::InvalidShape(s)) if s.contains("(H, W, 1)")
2046 ));
2047 }
2048
2049 #[test]
2054 fn test_protobox_clamps_edge_coordinates() {
2055 let protos = Array3::<f32>::zeros((16, 16, 4));
2057 let view = protos.view();
2058 let roi = BoundingBox {
2059 xmin: 0.5,
2060 ymin: 0.5,
2061 xmax: 1.0,
2062 ymax: 1.0,
2063 };
2064 let result = protobox(&view, &roi);
2065 assert!(result.is_ok(), "protobox should accept xmax=1.0");
2066 let (cropped, _roi_norm) = result.unwrap();
2067 assert!(cropped.shape()[0] > 0);
2069 assert!(cropped.shape()[1] > 0);
2070 assert_eq!(cropped.shape()[2], 4);
2071 }
2072
2073 #[test]
2074 fn test_protobox_rejects_wildly_out_of_range() {
2075 let protos = Array3::<f32>::zeros((16, 16, 4));
2077 let view = protos.view();
2078 let roi = BoundingBox {
2079 xmin: 0.0,
2080 ymin: 0.0,
2081 xmax: 3.0,
2082 ymax: 3.0,
2083 };
2084 let result = protobox(&view, &roi);
2085 assert!(
2086 matches!(result, Err(crate::DecoderError::InvalidShape(s)) if s.contains("un-normalized")),
2087 "protobox should reject coords > NORM_LIMIT"
2088 );
2089 }
2090
2091 #[test]
2092 fn test_protobox_accepts_slightly_over_one() {
2093 let protos = Array3::<f32>::zeros((16, 16, 4));
2095 let view = protos.view();
2096 let roi = BoundingBox {
2097 xmin: 0.0,
2098 ymin: 0.0,
2099 xmax: 1.5,
2100 ymax: 1.5,
2101 };
2102 let result = protobox(&view, &roi);
2103 assert!(
2104 result.is_ok(),
2105 "protobox should accept coords <= NORM_LIMIT (2.0)"
2106 );
2107 let (cropped, _roi_norm) = result.unwrap();
2108 assert_eq!(cropped.shape()[0], 16);
2110 assert_eq!(cropped.shape()[1], 16);
2111 }
2112
2113 #[test]
2114 fn test_segdet_float_proto_no_panic() {
2115 let num_proposals = 100; let num_classes = 80;
2119 let num_mask_coeffs = 32;
2120 let rows = 4 + num_classes + num_mask_coeffs; let mut data = vec![0.0f32; rows * num_proposals];
2126 for i in 0..num_proposals {
2127 let row = |r: usize| r * num_proposals + i;
2128 data[row(0)] = 320.0; data[row(1)] = 320.0; data[row(2)] = 50.0; data[row(3)] = 50.0; data[row(4)] = 0.9; }
2134 let boxes = ndarray::Array2::from_shape_vec((rows, num_proposals), data).unwrap();
2135
2136 let protos = ndarray::Array3::<f32>::zeros((160, 160, num_mask_coeffs));
2139
2140 let mut output_boxes = Vec::with_capacity(300);
2141
2142 let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
2144 boxes.view(),
2145 protos.view(),
2146 0.5,
2147 0.7,
2148 Some(Nms::default()),
2149 &mut output_boxes,
2150 );
2151
2152 assert!(!output_boxes.is_empty());
2154 assert_eq!(proto_data.mask_coefficients.len(), output_boxes.len());
2155 for coeffs in &proto_data.mask_coefficients {
2157 assert_eq!(coeffs.len(), num_mask_coeffs);
2158 }
2159 }
2160}