1use std::fmt::Debug;
5
6use ndarray::{
7 parallel::prelude::{IntoParallelIterator, ParallelIterator},
8 s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
9};
10use num_traits::{AsPrimitive, Float, PrimInt, Signed};
11
12use crate::{
13 byte::{
14 nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
15 postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
16 },
17 configs::Nms,
18 dequant_detect_box,
19 float::{
20 nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
21 postprocess_boxes_float, postprocess_boxes_index_float,
22 },
23 BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoTensor,
24 Quantization, Segmentation, XYWH, XYXY,
25};
26
27fn dispatch_nms_float(nms: Option<Nms>, iou: f32, boxes: Vec<DetectBox>) -> Vec<DetectBox> {
29 match nms {
30 Some(Nms::ClassAgnostic) => nms_float(iou, boxes),
31 Some(Nms::ClassAware) => nms_class_aware_float(iou, boxes),
32 None => boxes, }
34}
35
36pub(super) fn dispatch_nms_extra_float<E: Send + Sync>(
39 nms: Option<Nms>,
40 iou: f32,
41 boxes: Vec<(DetectBox, E)>,
42) -> Vec<(DetectBox, E)> {
43 match nms {
44 Some(Nms::ClassAgnostic) => nms_extra_float(iou, boxes),
45 Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, boxes),
46 None => boxes, }
48}
49
50fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
53 nms: Option<Nms>,
54 iou: f32,
55 boxes: Vec<DetectBoxQuantized<SCORE>>,
56) -> Vec<DetectBoxQuantized<SCORE>> {
57 match nms {
58 Some(Nms::ClassAgnostic) => nms_int(iou, boxes),
59 Some(Nms::ClassAware) => nms_class_aware_int(iou, boxes),
60 None => boxes, }
62}
63
64fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
67 nms: Option<Nms>,
68 iou: f32,
69 boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
70) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
71 match nms {
72 Some(Nms::ClassAgnostic) => nms_extra_int(iou, boxes),
73 Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, boxes),
74 None => boxes, }
76}
77
78pub fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
85 output: (ArrayView2<BOX>, Quantization),
86 score_threshold: f32,
87 iou_threshold: f32,
88 nms: Option<Nms>,
89 output_boxes: &mut Vec<DetectBox>,
90) where
91 f32: AsPrimitive<BOX>,
92{
93 impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
94}
95
96pub fn decode_yolo_det_float<T>(
103 output: ArrayView2<T>,
104 score_threshold: f32,
105 iou_threshold: f32,
106 nms: Option<Nms>,
107 output_boxes: &mut Vec<DetectBox>,
108) where
109 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
110 f32: AsPrimitive<T>,
111{
112 impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
113}
114
115pub fn decode_yolo_segdet_quant<
127 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
128 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
129>(
130 boxes: (ArrayView2<BOX>, Quantization),
131 protos: (ArrayView3<PROTO>, Quantization),
132 score_threshold: f32,
133 iou_threshold: f32,
134 nms: Option<Nms>,
135 output_boxes: &mut Vec<DetectBox>,
136 output_masks: &mut Vec<Segmentation>,
137) -> Result<(), crate::DecoderError>
138where
139 f32: AsPrimitive<BOX>,
140{
141 impl_yolo_segdet_quant::<XYWH, _, _>(
142 boxes,
143 protos,
144 score_threshold,
145 iou_threshold,
146 nms,
147 output_boxes,
148 output_masks,
149 )
150}
151
152pub fn decode_yolo_segdet_float<T>(
164 boxes: ArrayView2<T>,
165 protos: ArrayView3<T>,
166 score_threshold: f32,
167 iou_threshold: f32,
168 nms: Option<Nms>,
169 output_boxes: &mut Vec<DetectBox>,
170 output_masks: &mut Vec<Segmentation>,
171) -> Result<(), crate::DecoderError>
172where
173 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
174 f32: AsPrimitive<T>,
175{
176 impl_yolo_segdet_float::<XYWH, _, _>(
177 boxes,
178 protos,
179 score_threshold,
180 iou_threshold,
181 nms,
182 output_boxes,
183 output_masks,
184 )
185}
186
187pub fn decode_yolo_split_det_quant<
199 BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
200 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
201>(
202 boxes: (ArrayView2<BOX>, Quantization),
203 scores: (ArrayView2<SCORE>, Quantization),
204 score_threshold: f32,
205 iou_threshold: f32,
206 nms: Option<Nms>,
207 output_boxes: &mut Vec<DetectBox>,
208) where
209 f32: AsPrimitive<SCORE>,
210{
211 impl_yolo_split_quant::<XYWH, _, _>(
212 boxes,
213 scores,
214 score_threshold,
215 iou_threshold,
216 nms,
217 output_boxes,
218 );
219}
220
221pub fn decode_yolo_split_det_float<T>(
233 boxes: ArrayView2<T>,
234 scores: ArrayView2<T>,
235 score_threshold: f32,
236 iou_threshold: f32,
237 nms: Option<Nms>,
238 output_boxes: &mut Vec<DetectBox>,
239) where
240 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
241 f32: AsPrimitive<T>,
242{
243 impl_yolo_split_float::<XYWH, _, _>(
244 boxes,
245 scores,
246 score_threshold,
247 iou_threshold,
248 nms,
249 output_boxes,
250 );
251}
252
253#[allow(clippy::too_many_arguments)]
267pub fn decode_yolo_split_segdet<
268 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
269 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
270 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
271 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
272>(
273 boxes: (ArrayView2<BOX>, Quantization),
274 scores: (ArrayView2<SCORE>, Quantization),
275 mask_coeff: (ArrayView2<MASK>, Quantization),
276 protos: (ArrayView3<PROTO>, Quantization),
277 score_threshold: f32,
278 iou_threshold: f32,
279 nms: Option<Nms>,
280 output_boxes: &mut Vec<DetectBox>,
281 output_masks: &mut Vec<Segmentation>,
282) -> Result<(), crate::DecoderError>
283where
284 f32: AsPrimitive<SCORE>,
285{
286 impl_yolo_split_segdet_quant::<XYWH, _, _, _, _>(
287 boxes,
288 scores,
289 mask_coeff,
290 protos,
291 score_threshold,
292 iou_threshold,
293 nms,
294 output_boxes,
295 output_masks,
296 )
297}
298
299#[allow(clippy::too_many_arguments)]
313pub fn decode_yolo_split_segdet_float<T>(
314 boxes: ArrayView2<T>,
315 scores: ArrayView2<T>,
316 mask_coeff: ArrayView2<T>,
317 protos: ArrayView3<T>,
318 score_threshold: f32,
319 iou_threshold: f32,
320 nms: Option<Nms>,
321 output_boxes: &mut Vec<DetectBox>,
322 output_masks: &mut Vec<Segmentation>,
323) -> Result<(), crate::DecoderError>
324where
325 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
326 f32: AsPrimitive<T>,
327{
328 impl_yolo_split_segdet_float::<XYWH, _, _, _, _>(
329 boxes,
330 scores,
331 mask_coeff,
332 protos,
333 score_threshold,
334 iou_threshold,
335 nms,
336 output_boxes,
337 output_masks,
338 )
339}
340
341pub fn decode_yolo_end_to_end_det_float<T>(
356 output: ArrayView2<T>,
357 score_threshold: f32,
358 output_boxes: &mut Vec<DetectBox>,
359) -> Result<(), crate::DecoderError>
360where
361 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
362 f32: AsPrimitive<T>,
363{
364 if output.shape()[0] < 6 {
366 return Err(crate::DecoderError::InvalidShape(format!(
367 "End-to-end detection output requires at least 6 rows, got {}",
368 output.shape()[0]
369 )));
370 }
371
372 let boxes = output.slice(s![0..4, ..]).reversed_axes();
374 let scores = output.slice(s![4..5, ..]).reversed_axes();
375 let classes = output.slice(s![5, ..]);
376 let mut boxes =
377 postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
378 boxes.truncate(output_boxes.capacity());
379 output_boxes.clear();
380 for (mut b, i) in boxes.into_iter() {
381 b.label = classes[i].as_() as usize;
382 output_boxes.push(b);
383 }
384 Ok(())
386}
387
388pub fn decode_yolo_end_to_end_segdet_float<T>(
406 output: ArrayView2<T>,
407 protos: ArrayView3<T>,
408 score_threshold: f32,
409 output_boxes: &mut Vec<DetectBox>,
410 output_masks: &mut Vec<crate::Segmentation>,
411) -> Result<(), crate::DecoderError>
412where
413 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
414 f32: AsPrimitive<T>,
415{
416 let (boxes, scores, classes, mask_coeff) =
417 postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
418 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
419 boxes,
420 scores,
421 classes,
422 score_threshold,
423 output_boxes.capacity(),
424 );
425
426 impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
429}
430
431pub fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
440 boxes: ArrayView2<T>,
441 scores: ArrayView2<T>,
442 classes: ArrayView2<T>,
443 score_threshold: f32,
444 output_boxes: &mut Vec<DetectBox>,
445) -> Result<(), crate::DecoderError> {
446 let n = boxes.shape()[0];
447 if boxes.shape()[1] != 4 {
448 return Err(crate::DecoderError::InvalidShape(format!(
449 "Split end-to-end boxes must have 4 columns, got {}",
450 boxes.shape()[1]
451 )));
452 }
453 output_boxes.clear();
454 for i in 0..n {
455 let score: f32 = scores[[i, 0]].as_();
456 if score < score_threshold {
457 continue;
458 }
459 if output_boxes.len() >= output_boxes.capacity() {
460 break;
461 }
462 output_boxes.push(DetectBox {
463 bbox: BoundingBox {
464 xmin: boxes[[i, 0]].as_(),
465 ymin: boxes[[i, 1]].as_(),
466 xmax: boxes[[i, 2]].as_(),
467 ymax: boxes[[i, 3]].as_(),
468 },
469 score,
470 label: classes[[i, 0]].as_() as usize,
471 });
472 }
473 Ok(())
474}
475
476#[allow(clippy::too_many_arguments)]
485pub fn decode_yolo_split_end_to_end_segdet_float<T>(
486 boxes: ArrayView2<T>,
487 scores: ArrayView2<T>,
488 classes: ArrayView2<T>,
489 mask_coeff: ArrayView2<T>,
490 protos: ArrayView3<T>,
491 score_threshold: f32,
492 output_boxes: &mut Vec<DetectBox>,
493 output_masks: &mut Vec<crate::Segmentation>,
494) -> Result<(), crate::DecoderError>
495where
496 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
497 f32: AsPrimitive<T>,
498{
499 let (boxes, scores, classes, mask_coeff) =
500 postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
501 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
502 boxes,
503 scores,
504 classes,
505 score_threshold,
506 output_boxes.capacity(),
507 );
508
509 impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
510}
511
512#[allow(clippy::type_complexity)]
513pub(crate) fn postprocess_yolo_end_to_end_segdet<'a, T>(
514 output: &'a ArrayView2<'_, T>,
515 num_protos: usize,
516) -> Result<
517 (
518 ArrayView2<'a, T>,
519 ArrayView2<'a, T>,
520 ArrayView1<'a, T>,
521 ArrayView2<'a, T>,
522 ),
523 crate::DecoderError,
524> {
525 if output.shape()[0] < 7 {
527 return Err(crate::DecoderError::InvalidShape(format!(
528 "End-to-end segdet output requires at least 7 rows, got {}",
529 output.shape()[0]
530 )));
531 }
532
533 let num_mask_coeffs = output.shape()[0] - 6;
534 if num_mask_coeffs != num_protos {
535 return Err(crate::DecoderError::InvalidShape(format!(
536 "Mask coefficients count ({}) doesn't match protos count ({})",
537 num_mask_coeffs, num_protos
538 )));
539 }
540
541 let boxes = output.slice(s![0..4, ..]).reversed_axes();
543 let scores = output.slice(s![4..5, ..]).reversed_axes();
544 let classes = output.slice(s![5, ..]);
545 let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
546 Ok((boxes, scores, classes, mask_coeff))
547}
548
549#[allow(clippy::type_complexity)]
550pub(crate) fn postprocess_yolo_split_end_to_end_segdet<
551 'a,
552 'b,
553 'c,
554 'd,
555 BOXES,
556 SCORES,
557 CLASS,
558 MASK,
559>(
560 boxes: ArrayView2<'a, BOXES>,
561 scores: ArrayView2<'b, SCORES>,
562 classes: &'c ArrayView2<CLASS>,
563 mask_coeff: ArrayView2<'d, MASK>,
564) -> Result<
565 (
566 ArrayView2<'a, BOXES>,
567 ArrayView2<'b, SCORES>,
568 ArrayView1<'c, CLASS>,
569 ArrayView2<'d, MASK>,
570 ),
571 crate::DecoderError,
572> {
573 if boxes.shape()[0] != 4 {
574 return Err(crate::DecoderError::InvalidShape(format!(
575 "Split end-to-end boxes must have 4 columns, got {}",
576 boxes.shape()[0]
577 )));
578 }
579 let boxes = boxes.reversed_axes();
580 let scores = scores.reversed_axes();
581 let classes = classes.slice(s![0, ..]);
582 let mask_coeff = mask_coeff.reversed_axes();
583 Ok((boxes, scores, classes, mask_coeff))
584}
585pub fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
590 output: (ArrayView2<T>, Quantization),
591 score_threshold: f32,
592 iou_threshold: f32,
593 nms: Option<Nms>,
594 output_boxes: &mut Vec<DetectBox>,
595) where
596 f32: AsPrimitive<T>,
597{
598 let (boxes, quant_boxes) = output;
599 let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
600
601 let boxes = {
602 let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
603 postprocess_boxes_quant::<B, _, _>(
604 score_threshold,
605 boxes_tensor,
606 scores_tensor,
607 quant_boxes,
608 )
609 };
610
611 let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
612 let len = output_boxes.capacity().min(boxes.len());
613 output_boxes.clear();
614 for b in boxes.iter().take(len) {
615 output_boxes.push(dequant_detect_box(b, quant_boxes));
616 }
617}
618
619pub fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
624 output: ArrayView2<T>,
625 score_threshold: f32,
626 iou_threshold: f32,
627 nms: Option<Nms>,
628 output_boxes: &mut Vec<DetectBox>,
629) where
630 f32: AsPrimitive<T>,
631{
632 let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
633 let boxes =
634 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
635 let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
636 let len = output_boxes.capacity().min(boxes.len());
637 output_boxes.clear();
638 for b in boxes.into_iter().take(len) {
639 output_boxes.push(b);
640 }
641}
642
643pub fn impl_yolo_split_quant<
653 B: BBoxTypeTrait,
654 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
655 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
656>(
657 boxes: (ArrayView2<BOX>, Quantization),
658 scores: (ArrayView2<SCORE>, Quantization),
659 score_threshold: f32,
660 iou_threshold: f32,
661 nms: Option<Nms>,
662 output_boxes: &mut Vec<DetectBox>,
663) where
664 f32: AsPrimitive<SCORE>,
665{
666 let (boxes_tensor, quant_boxes) = boxes;
667 let (scores_tensor, quant_scores) = scores;
668
669 let boxes_tensor = boxes_tensor.reversed_axes();
670 let scores_tensor = scores_tensor.reversed_axes();
671
672 let boxes = {
673 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
674 postprocess_boxes_quant::<B, _, _>(
675 score_threshold,
676 boxes_tensor,
677 scores_tensor,
678 quant_boxes,
679 )
680 };
681
682 let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
683 let len = output_boxes.capacity().min(boxes.len());
684 output_boxes.clear();
685 for b in boxes.iter().take(len) {
686 output_boxes.push(dequant_detect_box(b, quant_scores));
687 }
688}
689
690pub fn impl_yolo_split_float<
699 B: BBoxTypeTrait,
700 BOX: Float + AsPrimitive<f32> + Send + Sync,
701 SCORE: Float + AsPrimitive<f32> + Send + Sync,
702>(
703 boxes_tensor: ArrayView2<BOX>,
704 scores_tensor: ArrayView2<SCORE>,
705 score_threshold: f32,
706 iou_threshold: f32,
707 nms: Option<Nms>,
708 output_boxes: &mut Vec<DetectBox>,
709) where
710 f32: AsPrimitive<SCORE>,
711{
712 let boxes_tensor = boxes_tensor.reversed_axes();
713 let scores_tensor = scores_tensor.reversed_axes();
714 let boxes =
715 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
716 let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
717 let len = output_boxes.capacity().min(boxes.len());
718 output_boxes.clear();
719 for b in boxes.into_iter().take(len) {
720 output_boxes.push(b);
721 }
722}
723
724pub fn impl_yolo_segdet_quant<
734 B: BBoxTypeTrait,
735 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
736 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
737>(
738 boxes: (ArrayView2<BOX>, Quantization),
739 protos: (ArrayView3<PROTO>, Quantization),
740 score_threshold: f32,
741 iou_threshold: f32,
742 nms: Option<Nms>,
743 output_boxes: &mut Vec<DetectBox>,
744 output_masks: &mut Vec<Segmentation>,
745) -> Result<(), crate::DecoderError>
746where
747 f32: AsPrimitive<BOX>,
748{
749 let (boxes, quant_boxes) = boxes;
750 let num_protos = protos.0.dim().2;
751
752 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
753 let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
754 (boxes_tensor, quant_boxes),
755 (scores_tensor, quant_boxes),
756 score_threshold,
757 iou_threshold,
758 nms,
759 output_boxes.capacity(),
760 );
761
762 impl_yolo_split_segdet_quant_process_masks::<_, _>(
763 boxes,
764 (mask_tensor, quant_boxes),
765 protos,
766 output_boxes,
767 output_masks,
768 )
769}
770
771pub fn impl_yolo_segdet_float<
781 B: BBoxTypeTrait,
782 BOX: Float + AsPrimitive<f32> + Send + Sync,
783 PROTO: Float + AsPrimitive<f32> + Send + Sync,
784>(
785 boxes: ArrayView2<BOX>,
786 protos: ArrayView3<PROTO>,
787 score_threshold: f32,
788 iou_threshold: f32,
789 nms: Option<Nms>,
790 output_boxes: &mut Vec<DetectBox>,
791 output_masks: &mut Vec<Segmentation>,
792) -> Result<(), crate::DecoderError>
793where
794 f32: AsPrimitive<BOX>,
795{
796 let num_protos = protos.dim().2;
797 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
798 let boxes = impl_yolo_segdet_get_boxes::<XYWH, _, _>(
799 boxes_tensor,
800 scores_tensor,
801 score_threshold,
802 iou_threshold,
803 nms,
804 output_boxes.capacity(),
805 );
806 impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
807}
808
809pub(crate) fn impl_yolo_segdet_get_boxes<
810 B: BBoxTypeTrait,
811 BOX: Float + AsPrimitive<f32> + Send + Sync,
812 SCORE: Float + AsPrimitive<f32> + Send + Sync,
813>(
814 boxes_tensor: ArrayView2<BOX>,
815 scores_tensor: ArrayView2<SCORE>,
816 score_threshold: f32,
817 iou_threshold: f32,
818 nms: Option<Nms>,
819 max_boxes: usize,
820) -> Vec<(DetectBox, usize)>
821where
822 f32: AsPrimitive<SCORE>,
823{
824 let boxes = postprocess_boxes_index_float::<B, _, _>(
825 score_threshold.as_(),
826 boxes_tensor,
827 scores_tensor,
828 );
829 let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
830 boxes.truncate(max_boxes);
831 boxes
832}
833
834pub(crate) fn impl_yolo_end_to_end_segdet_get_boxes<
835 B: BBoxTypeTrait,
836 BOX: Float + AsPrimitive<f32> + Send + Sync,
837 SCORE: Float + AsPrimitive<f32> + Send + Sync,
838 CLASS: AsPrimitive<f32> + Send + Sync,
839>(
840 boxes: ArrayView2<BOX>,
841 scores: ArrayView2<SCORE>,
842 classes: ArrayView1<CLASS>,
843 score_threshold: f32,
844 max_boxes: usize,
845) -> Vec<(DetectBox, usize)>
846where
847 f32: AsPrimitive<SCORE>,
848{
849 let mut boxes = postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes, scores);
850 boxes.truncate(max_boxes);
851 for (b, ind) in &mut boxes {
852 b.label = classes[*ind].as_().round() as usize;
853 }
854 boxes
855}
856
857pub(crate) fn impl_yolo_split_segdet_process_masks<
858 MASK: Float + AsPrimitive<f32> + Send + Sync,
859 PROTO: Float + AsPrimitive<f32> + Send + Sync,
860>(
861 boxes: Vec<(DetectBox, usize)>,
862 masks_tensor: ArrayView2<MASK>,
863 protos_tensor: ArrayView3<PROTO>,
864 output_boxes: &mut Vec<DetectBox>,
865 output_masks: &mut Vec<Segmentation>,
866) -> Result<(), crate::DecoderError> {
867 let boxes = decode_segdet_f32(boxes, masks_tensor, protos_tensor)?;
868 output_boxes.clear();
869 output_masks.clear();
870 for (b, m) in boxes.into_iter() {
871 output_boxes.push(b);
872 output_masks.push(Segmentation {
873 xmin: b.bbox.xmin,
874 ymin: b.bbox.ymin,
875 xmax: b.bbox.xmax,
876 ymax: b.bbox.ymax,
877 segmentation: m,
878 });
879 }
880 Ok(())
881}
882pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
886 B: BBoxTypeTrait,
887 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
888 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
889>(
890 boxes: (ArrayView2<BOX>, Quantization),
891 scores: (ArrayView2<SCORE>, Quantization),
892 score_threshold: f32,
893 iou_threshold: f32,
894 nms: Option<Nms>,
895 max_boxes: usize,
896) -> Vec<(DetectBox, usize)>
897where
898 f32: AsPrimitive<SCORE>,
899{
900 let (boxes_tensor, quant_boxes) = boxes;
901 let (scores_tensor, quant_scores) = scores;
902
903 let boxes = {
904 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
905 postprocess_boxes_index_quant::<B, _, _>(
906 score_threshold,
907 boxes_tensor,
908 scores_tensor,
909 quant_boxes,
910 )
911 };
912 let mut boxes = dispatch_nms_extra_int(nms, iou_threshold, boxes);
913 boxes.truncate(max_boxes);
914 boxes
915 .into_iter()
916 .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
917 .collect()
918}
919
920pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
921 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
922 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
923>(
924 boxes: Vec<(DetectBox, usize)>,
925 mask_coeff: (ArrayView2<MASK>, Quantization),
926 protos: (ArrayView3<PROTO>, Quantization),
927 output_boxes: &mut Vec<DetectBox>,
928 output_masks: &mut Vec<Segmentation>,
929) -> Result<(), crate::DecoderError> {
930 let (masks, quant_masks) = mask_coeff;
931 let (protos, quant_protos) = protos;
932
933 let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
934 output_boxes.clear();
935 output_masks.clear();
936 for (b, m) in boxes.into_iter() {
937 output_boxes.push(b);
938 output_masks.push(Segmentation {
939 xmin: b.bbox.xmin,
940 ymin: b.bbox.ymin,
941 xmax: b.bbox.xmax,
942 ymax: b.bbox.ymax,
943 segmentation: m,
944 });
945 }
946 Ok(())
947}
948
949#[allow(clippy::too_many_arguments)]
950pub fn impl_yolo_split_segdet_quant<
962 B: BBoxTypeTrait,
963 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
964 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
965 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
966 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
967>(
968 boxes: (ArrayView2<BOX>, Quantization),
969 scores: (ArrayView2<SCORE>, Quantization),
970 mask_coeff: (ArrayView2<MASK>, Quantization),
971 protos: (ArrayView3<PROTO>, Quantization),
972 score_threshold: f32,
973 iou_threshold: f32,
974 nms: Option<Nms>,
975 output_boxes: &mut Vec<DetectBox>,
976 output_masks: &mut Vec<Segmentation>,
977) -> Result<(), crate::DecoderError>
978where
979 f32: AsPrimitive<SCORE>,
980{
981 let (boxes_, scores_, mask_coeff_) =
982 postprocess_yolo_split_segdet(boxes.0, scores.0, mask_coeff.0);
983 let boxes = (boxes_, boxes.1);
984 let scores = (scores_, scores.1);
985 let mask_coeff = (mask_coeff_, mask_coeff.1);
986
987 let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
988 boxes,
989 scores,
990 score_threshold,
991 iou_threshold,
992 nms,
993 output_boxes.capacity(),
994 );
995
996 impl_yolo_split_segdet_quant_process_masks(
997 boxes,
998 mask_coeff,
999 protos,
1000 output_boxes,
1001 output_masks,
1002 )
1003}
1004
1005#[allow(clippy::too_many_arguments)]
1006pub fn impl_yolo_split_segdet_float<
1018 B: BBoxTypeTrait,
1019 BOX: Float + AsPrimitive<f32> + Send + Sync,
1020 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1021 MASK: Float + AsPrimitive<f32> + Send + Sync,
1022 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1023>(
1024 boxes_tensor: ArrayView2<BOX>,
1025 scores_tensor: ArrayView2<SCORE>,
1026 mask_tensor: ArrayView2<MASK>,
1027 protos: ArrayView3<PROTO>,
1028 score_threshold: f32,
1029 iou_threshold: f32,
1030 nms: Option<Nms>,
1031 output_boxes: &mut Vec<DetectBox>,
1032 output_masks: &mut Vec<Segmentation>,
1033) -> Result<(), crate::DecoderError>
1034where
1035 f32: AsPrimitive<SCORE>,
1036{
1037 let (boxes_tensor, scores_tensor, mask_tensor) =
1038 postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1039
1040 let boxes = impl_yolo_segdet_get_boxes::<XYWH, _, _>(
1041 boxes_tensor,
1042 scores_tensor,
1043 score_threshold,
1044 iou_threshold,
1045 nms,
1046 output_boxes.capacity(),
1047 );
1048 impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1049}
1050
1051pub fn impl_yolo_segdet_quant_proto<
1058 B: BBoxTypeTrait,
1059 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1060 PROTO: PrimInt
1061 + AsPrimitive<i64>
1062 + AsPrimitive<i128>
1063 + AsPrimitive<f32>
1064 + AsPrimitive<i8>
1065 + Send
1066 + Sync,
1067>(
1068 boxes: (ArrayView2<BOX>, Quantization),
1069 protos: (ArrayView3<PROTO>, Quantization),
1070 score_threshold: f32,
1071 iou_threshold: f32,
1072 nms: Option<Nms>,
1073 output_boxes: &mut Vec<DetectBox>,
1074) -> ProtoData
1075where
1076 f32: AsPrimitive<BOX>,
1077{
1078 let (boxes_arr, quant_boxes) = boxes;
1079 let (protos_arr, quant_protos) = protos;
1080 let num_protos = protos_arr.dim().2;
1081
1082 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1083
1084 let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1085 (boxes_tensor, quant_boxes),
1086 (scores_tensor, quant_boxes),
1087 score_threshold,
1088 iou_threshold,
1089 nms,
1090 output_boxes.capacity(),
1091 );
1092
1093 extract_proto_data_quant(
1094 det_indices,
1095 mask_tensor,
1096 quant_boxes,
1097 protos_arr,
1098 quant_protos,
1099 output_boxes,
1100 )
1101}
1102
1103pub fn impl_yolo_segdet_float_proto<
1106 B: BBoxTypeTrait,
1107 BOX: Float + AsPrimitive<f32> + Send + Sync,
1108 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1109>(
1110 boxes: ArrayView2<BOX>,
1111 protos: ArrayView3<PROTO>,
1112 score_threshold: f32,
1113 iou_threshold: f32,
1114 nms: Option<Nms>,
1115 output_boxes: &mut Vec<DetectBox>,
1116) -> ProtoData
1117where
1118 f32: AsPrimitive<BOX>,
1119{
1120 let num_protos = protos.dim().2;
1121 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1122
1123 let boxes = impl_yolo_segdet_get_boxes::<XYWH, _, _>(
1124 boxes_tensor,
1125 scores_tensor,
1126 score_threshold,
1127 iou_threshold,
1128 nms,
1129 output_boxes.capacity(),
1130 );
1131
1132 extract_proto_data_float(boxes, mask_tensor, protos, output_boxes)
1133}
1134
1135#[allow(clippy::too_many_arguments)]
1138pub fn impl_yolo_split_segdet_quant_proto<
1139 B: BBoxTypeTrait,
1140 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1141 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1142 MASK: PrimInt + AsPrimitive<f32> + Send + Sync,
1143 PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync,
1144>(
1145 boxes: (ArrayView2<BOX>, Quantization),
1146 scores: (ArrayView2<SCORE>, Quantization),
1147 mask_coeff: (ArrayView2<MASK>, Quantization),
1148 protos: (ArrayView3<PROTO>, Quantization),
1149 score_threshold: f32,
1150 iou_threshold: f32,
1151 nms: Option<Nms>,
1152 output_boxes: &mut Vec<DetectBox>,
1153) -> ProtoData
1154where
1155 f32: AsPrimitive<SCORE>,
1156{
1157 let (boxes_, scores_, mask_coeff_) =
1158 postprocess_yolo_split_segdet(boxes.0, scores.0, mask_coeff.0);
1159 let boxes = (boxes_, boxes.1);
1160 let scores = (scores_, scores.1);
1161 let mask_coeff = (mask_coeff_, mask_coeff.1);
1162
1163 let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1164 boxes,
1165 scores,
1166 score_threshold,
1167 iou_threshold,
1168 nms,
1169 output_boxes.capacity(),
1170 );
1171
1172 let (masks, quant_masks) = mask_coeff;
1173 let masks = masks.reversed_axes();
1174 let (protos_arr, quant_protos) = protos;
1175
1176 extract_proto_data_quant(
1177 det_indices,
1178 masks,
1179 quant_masks,
1180 protos_arr,
1181 quant_protos,
1182 output_boxes,
1183 )
1184}
1185
1186#[allow(clippy::too_many_arguments)]
1189pub fn impl_yolo_split_segdet_float_proto<
1190 B: BBoxTypeTrait,
1191 BOX: Float + AsPrimitive<f32> + Send + Sync,
1192 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1193 MASK: Float + AsPrimitive<f32> + Send + Sync,
1194 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1195>(
1196 boxes_tensor: ArrayView2<BOX>,
1197 scores_tensor: ArrayView2<SCORE>,
1198 mask_tensor: ArrayView2<MASK>,
1199 protos: ArrayView3<PROTO>,
1200 score_threshold: f32,
1201 iou_threshold: f32,
1202 nms: Option<Nms>,
1203 output_boxes: &mut Vec<DetectBox>,
1204) -> ProtoData
1205where
1206 f32: AsPrimitive<SCORE>,
1207{
1208 let (boxes_tensor, scores_tensor, mask_tensor) =
1209 postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1210 let det_indices = impl_yolo_segdet_get_boxes::<B, _, _>(
1211 boxes_tensor,
1212 scores_tensor,
1213 score_threshold,
1214 iou_threshold,
1215 nms,
1216 output_boxes.capacity(),
1217 );
1218
1219 extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1220}
1221
1222pub fn decode_yolo_end_to_end_segdet_float_proto<T>(
1224 output: ArrayView2<T>,
1225 protos: ArrayView3<T>,
1226 score_threshold: f32,
1227 output_boxes: &mut Vec<DetectBox>,
1228) -> Result<ProtoData, crate::DecoderError>
1229where
1230 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1231 f32: AsPrimitive<T>,
1232{
1233 let (boxes, scores, classes, mask_coeff) =
1234 postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
1235 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1236 boxes,
1237 scores,
1238 classes,
1239 score_threshold,
1240 output_boxes.capacity(),
1241 );
1242
1243 Ok(extract_proto_data_float(
1244 boxes,
1245 mask_coeff,
1246 protos,
1247 output_boxes,
1248 ))
1249}
1250
1251#[allow(clippy::too_many_arguments)]
1253pub fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1254 boxes: ArrayView2<T>,
1255 scores: ArrayView2<T>,
1256 classes: ArrayView2<T>,
1257 mask_coeff: ArrayView2<T>,
1258 protos: ArrayView3<T>,
1259 score_threshold: f32,
1260 output_boxes: &mut Vec<DetectBox>,
1261) -> Result<ProtoData, crate::DecoderError>
1262where
1263 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1264 f32: AsPrimitive<T>,
1265{
1266 let (boxes, scores, classes, mask_coeff) =
1267 postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
1268 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1269 boxes,
1270 scores,
1271 classes,
1272 score_threshold,
1273 output_boxes.capacity(),
1274 );
1275
1276 Ok(extract_proto_data_float(
1277 boxes,
1278 mask_coeff,
1279 protos,
1280 output_boxes,
1281 ))
1282}
1283
1284pub(super) fn extract_proto_data_float<
1286 MASK: Float + AsPrimitive<f32> + Send + Sync,
1287 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1288>(
1289 det_indices: Vec<(DetectBox, usize)>,
1290 mask_tensor: ArrayView2<MASK>,
1291 protos: ArrayView3<PROTO>,
1292 output_boxes: &mut Vec<DetectBox>,
1293) -> ProtoData {
1294 let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1295 output_boxes.clear();
1296 for (det, idx) in det_indices {
1297 output_boxes.push(det);
1298 let row = mask_tensor.row(idx);
1299 mask_coefficients.push(row.iter().map(|v| v.as_()).collect());
1300 }
1301 let protos_f32 = protos.map(|v| v.as_());
1302 ProtoData {
1303 mask_coefficients,
1304 protos: ProtoTensor::Float(protos_f32),
1305 }
1306}
1307
1308pub(crate) fn extract_proto_data_quant<
1314 MASK: PrimInt + AsPrimitive<f32> + Send + Sync,
1315 PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync,
1316>(
1317 det_indices: Vec<(DetectBox, usize)>,
1318 mask_tensor: ArrayView2<MASK>,
1319 quant_masks: Quantization,
1320 protos: ArrayView3<PROTO>,
1321 quant_protos: Quantization,
1322 output_boxes: &mut Vec<DetectBox>,
1323) -> ProtoData {
1324 let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1325 output_boxes.clear();
1326 for (det, idx) in det_indices {
1327 output_boxes.push(det);
1328 let row = mask_tensor.row(idx);
1329 mask_coefficients.push(
1330 row.iter()
1331 .map(|v| (v.as_() - quant_masks.zero_point as f32) * quant_masks.scale)
1332 .collect(),
1333 );
1334 }
1335 let protos_i8 = protos.map(|v| {
1337 let v_i8: i8 = v.as_();
1338 v_i8
1339 });
1340 ProtoData {
1341 mask_coefficients,
1342 protos: ProtoTensor::Quantized {
1343 protos: protos_i8,
1344 quantization: quant_protos,
1345 },
1346 }
1347}
1348
1349fn postprocess_yolo<'a, T>(
1350 output: &'a ArrayView2<'_, T>,
1351) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1352 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1353 let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1354 (boxes_tensor, scores_tensor)
1355}
1356
1357pub(crate) fn postprocess_yolo_seg<'a, T>(
1358 output: &'a ArrayView2<'_, T>,
1359 num_protos: usize,
1360) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1361 assert!(
1362 output.shape()[0] > num_protos + 4,
1363 "Output shape is too short: {} <= {} + 4",
1364 output.shape()[0],
1365 num_protos
1366 );
1367 let num_classes = output.shape()[0] - 4 - num_protos;
1368 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1369 let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1370 let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1371 (boxes_tensor, scores_tensor, mask_tensor)
1372}
1373
1374pub(crate) fn postprocess_yolo_split_segdet<'a, 'b, 'c, BOX, SCORE, MASK>(
1375 boxes_tensor: ArrayView2<'a, BOX>,
1376 scores_tensor: ArrayView2<'b, SCORE>,
1377 mask_tensor: ArrayView2<'c, MASK>,
1378) -> (
1379 ArrayView2<'a, BOX>,
1380 ArrayView2<'b, SCORE>,
1381 ArrayView2<'c, MASK>,
1382) {
1383 let boxes_tensor = boxes_tensor.reversed_axes();
1384 let scores_tensor = scores_tensor.reversed_axes();
1385 let mask_tensor = mask_tensor.reversed_axes();
1386 (boxes_tensor, scores_tensor, mask_tensor)
1387}
1388
1389fn decode_segdet_f32<
1390 MASK: Float + AsPrimitive<f32> + Send + Sync,
1391 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1392>(
1393 boxes: Vec<(DetectBox, usize)>,
1394 masks: ArrayView2<MASK>,
1395 protos: ArrayView3<PROTO>,
1396) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1397 if boxes.is_empty() {
1398 return Ok(Vec::new());
1399 }
1400 if masks.shape()[1] != protos.shape()[2] {
1401 return Err(crate::DecoderError::InvalidShape(format!(
1402 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1403 masks.shape()[1],
1404 protos.shape()[2],
1405 )));
1406 }
1407 boxes
1408 .into_par_iter()
1409 .map(|mut b| {
1410 let ind = b.1;
1411 let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1412 b.0.bbox = roi;
1413 Ok((b.0, make_segmentation(masks.row(ind), protos.view())))
1414 })
1415 .collect()
1416}
1417
1418pub(crate) fn decode_segdet_quant<
1419 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1420 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1421>(
1422 boxes: Vec<(DetectBox, usize)>,
1423 masks: ArrayView2<MASK>,
1424 protos: ArrayView3<PROTO>,
1425 quant_masks: Quantization,
1426 quant_protos: Quantization,
1427) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1428 if boxes.is_empty() {
1429 return Ok(Vec::new());
1430 }
1431 if masks.shape()[1] != protos.shape()[2] {
1432 return Err(crate::DecoderError::InvalidShape(format!(
1433 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1434 masks.shape()[1],
1435 protos.shape()[2],
1436 )));
1437 }
1438
1439 let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; boxes
1441 .into_iter()
1442 .map(|mut b| {
1443 let i = b.1;
1444 let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1445 b.0.bbox = roi;
1446 let seg = match total_bits {
1447 0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1448 masks.row(i),
1449 protos.view(),
1450 quant_masks,
1451 quant_protos,
1452 ),
1453 65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1454 masks.row(i),
1455 protos.view(),
1456 quant_masks,
1457 quant_protos,
1458 ),
1459 _ => {
1460 return Err(crate::DecoderError::NotSupported(format!(
1461 "Unsupported bit width ({total_bits}) for segmentation computation"
1462 )));
1463 }
1464 };
1465 Ok((b.0, seg))
1466 })
1467 .collect()
1468}
1469
1470fn protobox<'a, T>(
1471 protos: &'a ArrayView3<T>,
1472 roi: &BoundingBox,
1473) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1474 let width = protos.dim().1 as f32;
1475 let height = protos.dim().0 as f32;
1476
1477 const NORM_LIMIT: f32 = 2.0;
1488 if roi.xmin > NORM_LIMIT
1489 || roi.ymin > NORM_LIMIT
1490 || roi.xmax > NORM_LIMIT
1491 || roi.ymax > NORM_LIMIT
1492 {
1493 return Err(crate::DecoderError::InvalidShape(format!(
1494 "Bounding box coordinates appear un-normalized (pixel-space). \
1495 Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
1496 ONNX models output pixel-space boxes — normalize them by dividing by \
1497 the input dimensions before calling decode().",
1498 roi.xmin, roi.ymin, roi.xmax, roi.ymax,
1499 )));
1500 }
1501
1502 let roi = [
1503 (roi.xmin * width).clamp(0.0, width) as usize,
1504 (roi.ymin * height).clamp(0.0, height) as usize,
1505 (roi.xmax * width).clamp(0.0, width).ceil() as usize,
1506 (roi.ymax * height).clamp(0.0, height).ceil() as usize,
1507 ];
1508
1509 let roi_norm = [
1510 roi[0] as f32 / width,
1511 roi[1] as f32 / height,
1512 roi[2] as f32 / width,
1513 roi[3] as f32 / height,
1514 ]
1515 .into();
1516
1517 let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
1518
1519 Ok((cropped, roi_norm))
1520}
1521
1522fn make_segmentation<
1528 MASK: Float + AsPrimitive<f32> + Send + Sync,
1529 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1530>(
1531 mask: ArrayView1<MASK>,
1532 protos: ArrayView3<PROTO>,
1533) -> Array3<u8> {
1534 let shape = protos.shape();
1535
1536 let mask = mask.to_shape((1, mask.len())).unwrap();
1538 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1539 let protos = protos.reversed_axes();
1540 let mask = mask.map(|x| x.as_());
1541 let protos = protos.map(|x| x.as_());
1542
1543 let mask = mask
1545 .dot(&protos)
1546 .into_shape_with_order((shape[0], shape[1], 1))
1547 .unwrap();
1548
1549 mask.map(|x| {
1550 let sigmoid = 1.0 / (1.0 + (-*x).exp());
1551 (sigmoid * 255.0).round() as u8
1552 })
1553}
1554
1555fn make_segmentation_quant<
1562 MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
1563 PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
1564 DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
1565>(
1566 mask: ArrayView1<MASK>,
1567 protos: ArrayView3<PROTO>,
1568 quant_masks: Quantization,
1569 quant_protos: Quantization,
1570) -> Array3<u8>
1571where
1572 i32: AsPrimitive<DEST>,
1573 f32: AsPrimitive<DEST>,
1574{
1575 let shape = protos.shape();
1576
1577 let mask = mask.to_shape((1, mask.len())).unwrap();
1579
1580 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1581 let protos = protos.reversed_axes();
1582
1583 let zp = quant_masks.zero_point.as_();
1584
1585 let mask = mask.mapv(|x| x.as_() - zp);
1586
1587 let zp = quant_protos.zero_point.as_();
1588 let protos = protos.mapv(|x| x.as_() - zp);
1589
1590 let segmentation = mask
1592 .dot(&protos)
1593 .into_shape_with_order((shape[0], shape[1], 1))
1594 .unwrap();
1595
1596 let combined_scale = quant_masks.scale * quant_protos.scale;
1597 segmentation.map(|x| {
1598 let val: f32 = (*x).as_() * combined_scale;
1599 let sigmoid = 1.0 / (1.0 + (-val).exp());
1600 (sigmoid * 255.0).round() as u8
1601 })
1602}
1603
1604pub fn yolo_segmentation_to_mask(
1616 segmentation: ArrayView3<u8>,
1617 threshold: u8,
1618) -> Result<Array2<u8>, crate::DecoderError> {
1619 if segmentation.shape()[2] != 1 {
1620 return Err(crate::DecoderError::InvalidShape(format!(
1621 "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
1622 segmentation.shape()[2]
1623 )));
1624 }
1625 Ok(segmentation
1626 .slice(s![.., .., 0])
1627 .map(|x| if *x >= threshold { 1 } else { 0 }))
1628}
1629
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Absolute tolerance used when comparing decoded floating-point values.
    const TOLERANCE: f32 = 0.01;

    /// Returns `true` when `actual` is within `TOLERANCE` of `expected`.
    fn approx_eq(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// Returns `true` when `result` is an `InvalidShape` error whose message
    /// contains `needle`.
    fn is_invalid_shape_with<T>(result: &Result<T, crate::DecoderError>, needle: &str) -> bool {
        matches!(result, Err(crate::DecoderError::InvalidShape(msg)) if msg.contains(needle))
    }

    /// All-zero 16x16 prototype tensor with `channels` channels.
    fn zero_protos(channels: usize) -> Array3<f32> {
        Array3::zeros((16, 16, channels))
    }

    /// End-to-end segmentation output with two detections (one per column).
    ///
    /// Rows 0..6 hold x1, y1, x2, y2, score and class; the remaining
    /// `num_protos` rows hold mask coefficients, all set to 0.1. Only the
    /// first detection has a score above 0.5.
    fn two_detection_segdet_output(num_protos: usize) -> Array2<f32> {
        let mut output = Array2::<f32>::zeros((6 + num_protos, 2));
        output.slice_mut(s![..6, ..]).assign(&ndarray::array![
            [0.1, 0.5],
            [0.1, 0.5],
            [0.4, 0.9],
            [0.4, 0.9],
            [0.9, 0.3],
            [1.0, 2.0],
        ]);
        output.slice_mut(s![6.., ..]).fill(0.1);
        output
    }

    // ---- End-to-end detection ------------------------------------------

    #[test]
    fn test_end_to_end_det_basic_filtering() {
        // Rows: x1, y1, x2, y2, score, class; one column per detection.
        let output: Array2<f32> = ndarray::array![
            [0.1, 0.2, 0.3],
            [0.1, 0.2, 0.3],
            [0.5, 0.6, 0.7],
            [0.5, 0.6, 0.7],
            [0.9, 0.1, 0.2],
            [0.0, 1.0, 2.0],
        ];

        let mut detections = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut detections).unwrap();

        // Only the first column scores above the 0.5 threshold.
        assert_eq!(detections.len(), 1);
        let first = &detections[0];
        assert_eq!(first.label, 0);
        assert!(approx_eq(first.score, 0.9));
        assert!(approx_eq(first.bbox.xmin, 0.1));
        assert!(approx_eq(first.bbox.ymin, 0.1));
        assert!(approx_eq(first.bbox.xmax, 0.5));
        assert!(approx_eq(first.bbox.ymax, 0.5));
    }

    #[test]
    fn test_end_to_end_det_all_pass_threshold() {
        let output: Array2<f32> = ndarray::array![
            [10.0, 20.0],
            [10.0, 20.0],
            [50.0, 60.0],
            [50.0, 60.0],
            [0.8, 0.7],
            [1.0, 2.0],
        ];

        let mut detections = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut detections).unwrap();

        assert_eq!(detections.len(), 2);
        assert_eq!(detections[0].label, 1);
        assert_eq!(detections[1].label, 2);
    }

    #[test]
    fn test_end_to_end_det_none_pass_threshold() {
        let output: Array2<f32> = ndarray::array![
            [10.0, 20.0],
            [10.0, 20.0],
            [50.0, 60.0],
            [50.0, 60.0],
            [0.1, 0.2],
            [1.0, 2.0],
        ];

        let mut detections = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut detections).unwrap();

        assert_eq!(detections.len(), 0);
    }

    #[test]
    fn test_end_to_end_det_capacity_limit() {
        // Five detections, all above the threshold.
        let output: Array2<f32> = ndarray::array![
            [0.1, 0.2, 0.3, 0.4, 0.5],
            [0.1, 0.2, 0.3, 0.4, 0.5],
            [0.5, 0.6, 0.7, 0.8, 0.9],
            [0.5, 0.6, 0.7, 0.8, 0.9],
            [0.9, 0.9, 0.9, 0.9, 0.9],
            [0.0, 1.0, 2.0, 3.0, 4.0],
        ];

        // The output vector only has room for two of them.
        let mut detections = Vec::with_capacity(2);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut detections).unwrap();

        assert_eq!(detections.len(), 2);
    }

    #[test]
    fn test_end_to_end_det_empty_output() {
        let output = Array2::<f32>::zeros((6, 0));

        let mut detections = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut detections).unwrap();

        assert_eq!(detections.len(), 0);
    }

    #[test]
    fn test_end_to_end_det_pixel_coordinates() {
        // A single detection whose box is given in pixel coordinates.
        let output: Array2<f32> =
            ndarray::array![[100.0], [200.0], [300.0], [400.0], [0.95], [5.0]];

        let mut detections = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut detections).unwrap();

        assert_eq!(detections.len(), 1);
        let first = &detections[0];
        assert_eq!(first.label, 5);
        assert!(approx_eq(first.bbox.xmin, 100.0));
        assert!(approx_eq(first.bbox.ymin, 200.0));
        assert!(approx_eq(first.bbox.xmax, 300.0));
        assert!(approx_eq(first.bbox.ymax, 400.0));
    }

    #[test]
    fn test_end_to_end_det_invalid_shape() {
        // One row short of the six rows a detection output needs.
        let output = Array2::<f32>::zeros((5, 3));

        let mut detections = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut detections);

        assert!(result.is_err());
        assert!(is_invalid_shape_with(&result, "at least 6 rows"));
    }

    // ---- End-to-end segmentation ----------------------------------------

    #[test]
    fn test_end_to_end_segdet_basic() {
        let num_protos = 32;
        let output = two_detection_segdet_output(num_protos);
        let protos = zero_protos(num_protos);

        let mut detections = Vec::with_capacity(10);
        let mut segmentations = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut detections,
            &mut segmentations,
        )
        .unwrap();

        // Only the first detection passes the score threshold.
        assert_eq!(detections.len(), 1);
        assert_eq!(segmentations.len(), 1);
        assert_eq!(detections[0].label, 1);
        assert!(approx_eq(detections[0].score, 0.9));
    }

    #[test]
    fn test_end_to_end_segdet_mask_coordinates() {
        let num_protos = 32;

        // One detection; every mask coefficient stays zero.
        let mut output = Array2::<f32>::zeros((6 + num_protos, 1));
        output
            .slice_mut(s![..6, 0])
            .assign(&ndarray::array![0.2, 0.2, 0.8, 0.8, 0.95, 3.0]);
        let protos = zero_protos(num_protos);

        let mut detections = Vec::with_capacity(10);
        let mut segmentations = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut detections,
            &mut segmentations,
        )
        .unwrap();

        assert_eq!(detections.len(), 1);
        assert_eq!(segmentations.len(), 1);

        // The mask extent must agree with the reported box.
        let (segmentation, bbox) = (&segmentations[0], &detections[0].bbox);
        assert!(approx_eq(segmentation.xmin, bbox.xmin));
        assert!(approx_eq(segmentation.ymin, bbox.ymin));
        assert!(approx_eq(segmentation.xmax, bbox.xmax));
        assert!(approx_eq(segmentation.ymax, bbox.ymax));
    }

    #[test]
    fn test_end_to_end_segdet_empty_output() {
        let num_protos = 32;
        let output = Array2::<f32>::zeros((6 + num_protos, 0));
        let protos = zero_protos(num_protos);

        let mut detections = Vec::with_capacity(10);
        let mut segmentations = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut detections,
            &mut segmentations,
        )
        .unwrap();

        assert_eq!(detections.len(), 0);
        assert_eq!(segmentations.len(), 0);
    }

    #[test]
    fn test_end_to_end_segdet_capacity_limit() {
        let num_protos = 32;
        let num_detections = 5;
        let num_features = 6 + num_protos;

        // Five small boxes stepping along the diagonal, all scoring 0.9;
        // mask coefficients are zero.
        let output =
            Array2::<f32>::from_shape_fn((num_features, num_detections), |(row, det)| match row {
                0 | 1 => 0.1 * (det as f32),
                2 | 3 => 0.1 * (det as f32) + 0.2,
                4 => 0.9,
                5 => det as f32,
                _ => 0.0,
            });
        let protos = zero_protos(num_protos);

        // Both output vectors only have room for two entries.
        let mut detections = Vec::with_capacity(2);
        let mut segmentations = Vec::with_capacity(2);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut detections,
            &mut segmentations,
        )
        .unwrap();

        assert_eq!(detections.len(), 2);
        assert_eq!(segmentations.len(), 2);
    }

    #[test]
    fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
        // Six rows leave no room for any mask coefficient row.
        let output = Array2::<f32>::zeros((6, 3));
        let protos = zero_protos(32);

        let mut detections = Vec::with_capacity(10);
        let mut segmentations = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut detections,
            &mut segmentations,
        );

        assert!(result.is_err());
        assert!(is_invalid_shape_with(&result, "at least 7 rows"));
    }

    #[test]
    fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
        let num_protos = 32;
        // 16 mask coefficient rows against 32 prototype channels.
        let output = Array2::<f32>::zeros((6 + 16, 3));
        let protos = zero_protos(num_protos);

        let mut detections = Vec::with_capacity(10);
        let mut segmentations = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut detections,
            &mut segmentations,
        );

        assert!(result.is_err());
        assert!(is_invalid_shape_with(&result, "doesn't match protos count"));
    }

    // ---- Split end-to-end segmentation ----------------------------------

    #[test]
    fn test_split_end_to_end_segdet_basic() {
        let num_protos = 32;
        let output = two_detection_segdet_output(num_protos);

        // Feed the same tensor as four separate views.
        let box_coords = output.slice(s![..4, ..]);
        let scores = output.slice(s![4..5, ..]);
        let classes = output.slice(s![5..6, ..]);
        let mask_coeff = output.slice(s![6.., ..]);
        let protos = zero_protos(num_protos);

        let mut detections = Vec::with_capacity(10);
        let mut segmentations = Vec::with_capacity(10);
        decode_yolo_split_end_to_end_segdet_float(
            box_coords,
            scores,
            classes,
            mask_coeff,
            protos.view(),
            0.5,
            &mut detections,
            &mut segmentations,
        )
        .unwrap();

        // Only the first detection passes the score threshold.
        assert_eq!(detections.len(), 1);
        assert_eq!(segmentations.len(), 1);
        assert_eq!(detections[0].label, 1);
        assert!(approx_eq(detections[0].score, 0.9));
    }

    // ---- yolo_segmentation_to_mask ---------------------------------------

    #[test]
    fn test_segmentation_to_mask_basic() {
        let segmentation: Array3<u8> = ndarray::array![
            [[100], [200], [50], [150]],
            [[10], [255], [128], [64]],
            [[0], [127], [128], [255]],
            [[64], [64], [192], [192]],
        ];

        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();

        // (row, column, expected mask value) for a threshold of 128.
        let expectations = [
            (0, 0, 0u8),
            (0, 1, 1),
            (0, 2, 0),
            (0, 3, 1),
            (1, 1, 1),
            (1, 2, 1),
            (2, 0, 0),
            (2, 1, 0),
        ];
        for &(row, col, expected) in &expectations {
            assert_eq!(mask[[row, col]], expected);
        }
    }

    #[test]
    fn test_segmentation_to_mask_all_above() {
        let segmentation = Array3::from_elem((4, 4, 1), 255u8);
        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
        assert!(mask.iter().all(|&value| value == 1));
    }

    #[test]
    fn test_segmentation_to_mask_all_below() {
        let segmentation = Array3::from_elem((4, 4, 1), 64u8);
        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
        assert!(mask.iter().all(|&value| value == 0));
    }

    #[test]
    fn test_segmentation_to_mask_invalid_shape() {
        let segmentation = Array3::from_elem((4, 4, 3), 128u8);
        let result = yolo_segmentation_to_mask(segmentation.view(), 128);

        assert!(result.is_err());
        assert!(is_invalid_shape_with(&result, "(H, W, 1)"));
    }

    // ---- protobox --------------------------------------------------------

    #[test]
    fn test_protobox_clamps_edge_coordinates() {
        let protos = zero_protos(4);
        let view = protos.view();
        let roi = BoundingBox {
            xmin: 0.5,
            ymin: 0.5,
            xmax: 1.0,
            ymax: 1.0,
        };

        let result = protobox(&view, &roi);
        assert!(result.is_ok(), "protobox should accept xmax=1.0");

        let (cropped, _snapped) = result.unwrap();
        let crop_shape = cropped.shape();
        assert!(crop_shape[0] > 0);
        assert!(crop_shape[1] > 0);
        assert_eq!(crop_shape[2], 4);
    }

    #[test]
    fn test_protobox_rejects_wildly_out_of_range() {
        let protos = zero_protos(4);
        let view = protos.view();
        let roi = BoundingBox {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 3.0,
            ymax: 3.0,
        };

        let result = protobox(&view, &roi);
        assert!(
            is_invalid_shape_with(&result, "un-normalized"),
            "protobox should reject coords > NORM_LIMIT"
        );
    }

    #[test]
    fn test_protobox_accepts_slightly_over_one() {
        let protos = zero_protos(4);
        let view = protos.view();
        let roi = BoundingBox {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.5,
            ymax: 1.5,
        };

        let result = protobox(&view, &roi);
        assert!(
            result.is_ok(),
            "protobox should accept coords <= NORM_LIMIT (2.0)"
        );

        // The oversized box is clamped to the full prototype extent.
        let (cropped, _snapped) = result.unwrap();
        assert_eq!(cropped.shape()[0], 16);
        assert_eq!(cropped.shape()[1], 16);
    }

    // ---- Proto extraction path -------------------------------------------

    #[test]
    fn test_segdet_float_proto_no_panic() {
        let num_proposals = 100;
        let num_classes = 80;
        let num_mask_coeffs = 32;
        let rows = 4 + num_classes + num_mask_coeffs;

        // Every proposal carries the same pixel-space XYWH box and a 0.9
        // score in the first class row; everything else is zero.
        let mut boxes = Array2::<f32>::zeros((rows, num_proposals));
        boxes.row_mut(0).fill(320.0);
        boxes.row_mut(1).fill(320.0);
        boxes.row_mut(2).fill(50.0);
        boxes.row_mut(3).fill(50.0);
        boxes.row_mut(4).fill(0.9);

        let protos = Array3::<f32>::zeros((160, 160, num_mask_coeffs));

        let mut output_boxes = Vec::with_capacity(300);

        let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
            boxes.view(),
            protos.view(),
            0.5,
            0.7,
            Some(Nms::default()),
            &mut output_boxes,
        );

        // One coefficient vector of the right length per surviving box.
        assert!(!output_boxes.is_empty());
        assert_eq!(proto_data.mask_coefficients.len(), output_boxes.len());
        for coeffs in &proto_data.mask_coefficients {
            assert_eq!(coeffs.len(), num_mask_coeffs);
        }
    }
}