1use std::fmt::Debug;
5
6use ndarray::{
7 parallel::prelude::{IntoParallelIterator, ParallelIterator},
8 s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
9};
10use num_traits::{AsPrimitive, Float, PrimInt, Signed};
11
12use crate::{
13 byte::{
14 nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
15 postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
16 },
17 configs::Nms,
18 dequant_detect_box,
19 float::{
20 nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
21 postprocess_boxes_float, postprocess_boxes_index_float,
22 },
23 BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoTensor,
24 Quantization, Segmentation, XYWH, XYXY,
25};
26
27fn dispatch_nms_float(nms: Option<Nms>, iou: f32, boxes: Vec<DetectBox>) -> Vec<DetectBox> {
29 match nms {
30 Some(Nms::ClassAgnostic) => nms_float(iou, boxes),
31 Some(Nms::ClassAware) => nms_class_aware_float(iou, boxes),
32 None => boxes, }
34}
35
36fn dispatch_nms_extra_float<E: Send + Sync>(
39 nms: Option<Nms>,
40 iou: f32,
41 boxes: Vec<(DetectBox, E)>,
42) -> Vec<(DetectBox, E)> {
43 match nms {
44 Some(Nms::ClassAgnostic) => nms_extra_float(iou, boxes),
45 Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, boxes),
46 None => boxes, }
48}
49
50fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
53 nms: Option<Nms>,
54 iou: f32,
55 boxes: Vec<DetectBoxQuantized<SCORE>>,
56) -> Vec<DetectBoxQuantized<SCORE>> {
57 match nms {
58 Some(Nms::ClassAgnostic) => nms_int(iou, boxes),
59 Some(Nms::ClassAware) => nms_class_aware_int(iou, boxes),
60 None => boxes, }
62}
63
64fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
67 nms: Option<Nms>,
68 iou: f32,
69 boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
70) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
71 match nms {
72 Some(Nms::ClassAgnostic) => nms_extra_int(iou, boxes),
73 Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, boxes),
74 None => boxes, }
76}
77
78pub fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
85 output: (ArrayView2<BOX>, Quantization),
86 score_threshold: f32,
87 iou_threshold: f32,
88 nms: Option<Nms>,
89 output_boxes: &mut Vec<DetectBox>,
90) where
91 f32: AsPrimitive<BOX>,
92{
93 impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
94}
95
96pub fn decode_yolo_det_float<T>(
103 output: ArrayView2<T>,
104 score_threshold: f32,
105 iou_threshold: f32,
106 nms: Option<Nms>,
107 output_boxes: &mut Vec<DetectBox>,
108) where
109 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
110 f32: AsPrimitive<T>,
111{
112 impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
113}
114
115pub fn decode_yolo_segdet_quant<
127 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
128 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
129>(
130 boxes: (ArrayView2<BOX>, Quantization),
131 protos: (ArrayView3<PROTO>, Quantization),
132 score_threshold: f32,
133 iou_threshold: f32,
134 nms: Option<Nms>,
135 output_boxes: &mut Vec<DetectBox>,
136 output_masks: &mut Vec<Segmentation>,
137) -> Result<(), crate::DecoderError>
138where
139 f32: AsPrimitive<BOX>,
140{
141 impl_yolo_segdet_quant::<XYWH, _, _>(
142 boxes,
143 protos,
144 score_threshold,
145 iou_threshold,
146 nms,
147 output_boxes,
148 output_masks,
149 )
150}
151
152pub fn decode_yolo_segdet_float<T>(
164 boxes: ArrayView2<T>,
165 protos: ArrayView3<T>,
166 score_threshold: f32,
167 iou_threshold: f32,
168 nms: Option<Nms>,
169 output_boxes: &mut Vec<DetectBox>,
170 output_masks: &mut Vec<Segmentation>,
171) -> Result<(), crate::DecoderError>
172where
173 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
174 f32: AsPrimitive<T>,
175{
176 impl_yolo_segdet_float::<XYWH, _, _>(
177 boxes,
178 protos,
179 score_threshold,
180 iou_threshold,
181 nms,
182 output_boxes,
183 output_masks,
184 )
185}
186
187pub fn decode_yolo_split_det_quant<
199 BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
200 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
201>(
202 boxes: (ArrayView2<BOX>, Quantization),
203 scores: (ArrayView2<SCORE>, Quantization),
204 score_threshold: f32,
205 iou_threshold: f32,
206 nms: Option<Nms>,
207 output_boxes: &mut Vec<DetectBox>,
208) where
209 f32: AsPrimitive<SCORE>,
210{
211 impl_yolo_split_quant::<XYWH, _, _>(
212 boxes,
213 scores,
214 score_threshold,
215 iou_threshold,
216 nms,
217 output_boxes,
218 );
219}
220
221pub fn decode_yolo_split_det_float<T>(
233 boxes: ArrayView2<T>,
234 scores: ArrayView2<T>,
235 score_threshold: f32,
236 iou_threshold: f32,
237 nms: Option<Nms>,
238 output_boxes: &mut Vec<DetectBox>,
239) where
240 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
241 f32: AsPrimitive<T>,
242{
243 impl_yolo_split_float::<XYWH, _, _>(
244 boxes,
245 scores,
246 score_threshold,
247 iou_threshold,
248 nms,
249 output_boxes,
250 );
251}
252
253#[allow(clippy::too_many_arguments)]
267pub fn decode_yolo_split_segdet<
268 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
269 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
270 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
271 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
272>(
273 boxes: (ArrayView2<BOX>, Quantization),
274 scores: (ArrayView2<SCORE>, Quantization),
275 mask_coeff: (ArrayView2<MASK>, Quantization),
276 protos: (ArrayView3<PROTO>, Quantization),
277 score_threshold: f32,
278 iou_threshold: f32,
279 nms: Option<Nms>,
280 output_boxes: &mut Vec<DetectBox>,
281 output_masks: &mut Vec<Segmentation>,
282) -> Result<(), crate::DecoderError>
283where
284 f32: AsPrimitive<SCORE>,
285{
286 impl_yolo_split_segdet_quant::<XYWH, _, _, _, _>(
287 boxes,
288 scores,
289 mask_coeff,
290 protos,
291 score_threshold,
292 iou_threshold,
293 nms,
294 output_boxes,
295 output_masks,
296 )
297}
298
299#[allow(clippy::too_many_arguments)]
313pub fn decode_yolo_split_segdet_float<T>(
314 boxes: ArrayView2<T>,
315 scores: ArrayView2<T>,
316 mask_coeff: ArrayView2<T>,
317 protos: ArrayView3<T>,
318 score_threshold: f32,
319 iou_threshold: f32,
320 nms: Option<Nms>,
321 output_boxes: &mut Vec<DetectBox>,
322 output_masks: &mut Vec<Segmentation>,
323) -> Result<(), crate::DecoderError>
324where
325 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
326 f32: AsPrimitive<T>,
327{
328 impl_yolo_split_segdet_float::<XYWH, _, _, _, _>(
329 boxes,
330 scores,
331 mask_coeff,
332 protos,
333 score_threshold,
334 iou_threshold,
335 nms,
336 output_boxes,
337 output_masks,
338 )
339}
340
341pub fn decode_yolo_end_to_end_det_float<T>(
356 output: ArrayView2<T>,
357 score_threshold: f32,
358 output_boxes: &mut Vec<DetectBox>,
359) -> Result<(), crate::DecoderError>
360where
361 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
362 f32: AsPrimitive<T>,
363{
364 if output.shape()[0] < 6 {
366 return Err(crate::DecoderError::InvalidShape(format!(
367 "End-to-end detection output requires at least 6 rows, got {}",
368 output.shape()[0]
369 )));
370 }
371
372 let boxes = output.slice(s![0..4, ..]).reversed_axes();
374 let scores = output.slice(s![4..5, ..]).reversed_axes();
375 let classes = output.slice(s![5, ..]);
376 let mut boxes =
377 postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
378 boxes.truncate(output_boxes.capacity());
379 output_boxes.clear();
380 for (mut b, i) in boxes.into_iter() {
381 b.label = classes[i].as_() as usize;
382 output_boxes.push(b);
383 }
384 Ok(())
386}
387
388pub fn decode_yolo_end_to_end_segdet_float<T>(
406 output: ArrayView2<T>,
407 protos: ArrayView3<T>,
408 score_threshold: f32,
409 output_boxes: &mut Vec<DetectBox>,
410 output_masks: &mut Vec<crate::Segmentation>,
411) -> Result<(), crate::DecoderError>
412where
413 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
414 f32: AsPrimitive<T>,
415{
416 if output.shape()[0] < 7 {
418 return Err(crate::DecoderError::InvalidShape(format!(
419 "End-to-end segdet output requires at least 7 rows, got {}",
420 output.shape()[0]
421 )));
422 }
423
424 let num_mask_coeffs = output.shape()[0] - 6;
425 let num_protos = protos.shape()[2];
426 if num_mask_coeffs != num_protos {
427 return Err(crate::DecoderError::InvalidShape(format!(
428 "Mask coefficients count ({}) doesn't match protos count ({})",
429 num_mask_coeffs, num_protos
430 )));
431 }
432
433 let boxes = output.slice(s![0..4, ..]).reversed_axes();
435 let scores = output.slice(s![4..5, ..]).reversed_axes();
436 let classes = output.slice(s![5, ..]);
437 let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
438 let mut boxes =
439 postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
440 boxes.truncate(output_boxes.capacity());
441
442 for (b, ind) in &mut boxes {
443 b.label = classes[*ind].as_() as usize;
444 }
445
446 let boxes = decode_segdet_f32(boxes, mask_coeff, protos)?;
449
450 output_boxes.clear();
451 output_masks.clear();
452 for (b, m) in boxes.into_iter() {
453 output_boxes.push(b);
454 output_masks.push(Segmentation {
455 xmin: b.bbox.xmin,
456 ymin: b.bbox.ymin,
457 xmax: b.bbox.xmax,
458 ymax: b.bbox.ymax,
459 segmentation: m,
460 });
461 }
462 Ok(())
463}
464
465pub fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
474 boxes: ArrayView2<T>,
475 scores: ArrayView2<T>,
476 classes: ArrayView2<T>,
477 score_threshold: f32,
478 output_boxes: &mut Vec<DetectBox>,
479) -> Result<(), crate::DecoderError> {
480 let n = boxes.shape()[0];
481 if boxes.shape()[1] != 4 {
482 return Err(crate::DecoderError::InvalidShape(format!(
483 "Split end-to-end boxes must have 4 columns, got {}",
484 boxes.shape()[1]
485 )));
486 }
487 output_boxes.clear();
488 for i in 0..n {
489 let score: f32 = scores[[i, 0]].as_();
490 if score < score_threshold {
491 continue;
492 }
493 if output_boxes.len() >= output_boxes.capacity() {
494 break;
495 }
496 output_boxes.push(DetectBox {
497 bbox: BoundingBox {
498 xmin: boxes[[i, 0]].as_(),
499 ymin: boxes[[i, 1]].as_(),
500 xmax: boxes[[i, 2]].as_(),
501 ymax: boxes[[i, 3]].as_(),
502 },
503 score,
504 label: classes[[i, 0]].as_() as usize,
505 });
506 }
507 Ok(())
508}
509
510#[allow(clippy::too_many_arguments)]
519pub fn decode_yolo_split_end_to_end_segdet_float<T>(
520 boxes: ArrayView2<T>,
521 scores: ArrayView2<T>,
522 classes: ArrayView2<T>,
523 mask_coeff: ArrayView2<T>,
524 protos: ArrayView3<T>,
525 score_threshold: f32,
526 output_boxes: &mut Vec<DetectBox>,
527 output_masks: &mut Vec<crate::Segmentation>,
528) -> Result<(), crate::DecoderError>
529where
530 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
531 f32: AsPrimitive<T>,
532{
533 let n = boxes.shape()[0];
534 if boxes.shape()[1] != 4 {
535 return Err(crate::DecoderError::InvalidShape(format!(
536 "Split end-to-end boxes must have 4 columns, got {}",
537 boxes.shape()[1]
538 )));
539 }
540
541 let mut qualifying: Vec<(DetectBox, usize)> = Vec::with_capacity(output_boxes.capacity());
543 for i in 0..n {
544 let score: f32 = scores[[i, 0]].as_();
545 if score < score_threshold {
546 continue;
547 }
548 if qualifying.len() >= output_boxes.capacity() {
549 break;
550 }
551 qualifying.push((
552 DetectBox {
553 bbox: BoundingBox {
554 xmin: boxes[[i, 0]].as_(),
555 ymin: boxes[[i, 1]].as_(),
556 xmax: boxes[[i, 2]].as_(),
557 ymax: boxes[[i, 3]].as_(),
558 },
559 score,
560 label: classes[[i, 0]].as_() as usize,
561 },
562 i,
563 ));
564 }
565
566 let result = decode_segdet_f32(qualifying, mask_coeff, protos)?;
568
569 output_boxes.clear();
570 output_masks.clear();
571 for (b, m) in result.into_iter() {
572 output_masks.push(crate::Segmentation {
573 xmin: b.bbox.xmin,
574 ymin: b.bbox.ymin,
575 xmax: b.bbox.xmax,
576 ymax: b.bbox.ymax,
577 segmentation: m,
578 });
579 output_boxes.push(b);
580 }
581 Ok(())
582}
583
584pub fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
589 output: (ArrayView2<T>, Quantization),
590 score_threshold: f32,
591 iou_threshold: f32,
592 nms: Option<Nms>,
593 output_boxes: &mut Vec<DetectBox>,
594) where
595 f32: AsPrimitive<T>,
596{
597 let (boxes, quant_boxes) = output;
598 let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
599
600 let boxes = {
601 let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
602 postprocess_boxes_quant::<B, _, _>(
603 score_threshold,
604 boxes_tensor,
605 scores_tensor,
606 quant_boxes,
607 )
608 };
609
610 let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
611 let len = output_boxes.capacity().min(boxes.len());
612 output_boxes.clear();
613 for b in boxes.iter().take(len) {
614 output_boxes.push(dequant_detect_box(b, quant_boxes));
615 }
616}
617
618pub fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
623 output: ArrayView2<T>,
624 score_threshold: f32,
625 iou_threshold: f32,
626 nms: Option<Nms>,
627 output_boxes: &mut Vec<DetectBox>,
628) where
629 f32: AsPrimitive<T>,
630{
631 let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
632 let boxes =
633 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
634 let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
635 let len = output_boxes.capacity().min(boxes.len());
636 output_boxes.clear();
637 for b in boxes.into_iter().take(len) {
638 output_boxes.push(b);
639 }
640}
641
642pub fn impl_yolo_split_quant<
652 B: BBoxTypeTrait,
653 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
654 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
655>(
656 boxes: (ArrayView2<BOX>, Quantization),
657 scores: (ArrayView2<SCORE>, Quantization),
658 score_threshold: f32,
659 iou_threshold: f32,
660 nms: Option<Nms>,
661 output_boxes: &mut Vec<DetectBox>,
662) where
663 f32: AsPrimitive<SCORE>,
664{
665 let (boxes_tensor, quant_boxes) = boxes;
666 let (scores_tensor, quant_scores) = scores;
667
668 let boxes_tensor = boxes_tensor.reversed_axes();
669 let scores_tensor = scores_tensor.reversed_axes();
670
671 let boxes = {
672 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
673 postprocess_boxes_quant::<B, _, _>(
674 score_threshold,
675 boxes_tensor,
676 scores_tensor,
677 quant_boxes,
678 )
679 };
680
681 let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
682 let len = output_boxes.capacity().min(boxes.len());
683 output_boxes.clear();
684 for b in boxes.iter().take(len) {
685 output_boxes.push(dequant_detect_box(b, quant_scores));
686 }
687}
688
689pub fn impl_yolo_split_float<
698 B: BBoxTypeTrait,
699 BOX: Float + AsPrimitive<f32> + Send + Sync,
700 SCORE: Float + AsPrimitive<f32> + Send + Sync,
701>(
702 boxes_tensor: ArrayView2<BOX>,
703 scores_tensor: ArrayView2<SCORE>,
704 score_threshold: f32,
705 iou_threshold: f32,
706 nms: Option<Nms>,
707 output_boxes: &mut Vec<DetectBox>,
708) where
709 f32: AsPrimitive<SCORE>,
710{
711 let boxes_tensor = boxes_tensor.reversed_axes();
712 let scores_tensor = scores_tensor.reversed_axes();
713 let boxes =
714 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
715 let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
716 let len = output_boxes.capacity().min(boxes.len());
717 output_boxes.clear();
718 for b in boxes.into_iter().take(len) {
719 output_boxes.push(b);
720 }
721}
722
723pub fn impl_yolo_segdet_quant<
733 B: BBoxTypeTrait,
734 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
735 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
736>(
737 boxes: (ArrayView2<BOX>, Quantization),
738 protos: (ArrayView3<PROTO>, Quantization),
739 score_threshold: f32,
740 iou_threshold: f32,
741 nms: Option<Nms>,
742 output_boxes: &mut Vec<DetectBox>,
743 output_masks: &mut Vec<Segmentation>,
744) -> Result<(), crate::DecoderError>
745where
746 f32: AsPrimitive<BOX>,
747{
748 let (boxes, quant_boxes) = boxes;
749 let num_protos = protos.0.dim().2;
750 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
751
752 let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
753 (boxes_tensor.reversed_axes(), quant_boxes),
754 (scores_tensor.reversed_axes(), quant_boxes),
755 score_threshold,
756 iou_threshold,
757 nms,
758 output_boxes.capacity(),
759 );
760
761 impl_yolo_split_segdet_quant_process_masks::<_, _>(
762 boxes,
763 (mask_tensor.reversed_axes(), quant_boxes),
764 protos,
765 output_boxes,
766 output_masks,
767 )
768}
769
770pub fn impl_yolo_segdet_float<
780 B: BBoxTypeTrait,
781 BOX: Float + AsPrimitive<f32> + Send + Sync,
782 PROTO: Float + AsPrimitive<f32> + Send + Sync,
783>(
784 boxes: ArrayView2<BOX>,
785 protos: ArrayView3<PROTO>,
786 score_threshold: f32,
787 iou_threshold: f32,
788 nms: Option<Nms>,
789 output_boxes: &mut Vec<DetectBox>,
790 output_masks: &mut Vec<Segmentation>,
791) -> Result<(), crate::DecoderError>
792where
793 f32: AsPrimitive<BOX>,
794{
795 let num_protos = protos.dim().2;
796 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
797
798 let boxes = postprocess_boxes_index_float::<B, _, _>(
799 score_threshold.as_(),
800 boxes_tensor,
801 scores_tensor,
802 );
803 let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
804 boxes.truncate(output_boxes.capacity());
805 let boxes = decode_segdet_f32(boxes, mask_tensor, protos)?;
806 output_boxes.clear();
807 output_masks.clear();
808 for (b, m) in boxes.into_iter() {
809 output_boxes.push(b);
810 output_masks.push(Segmentation {
811 xmin: b.bbox.xmin,
812 ymin: b.bbox.ymin,
813 xmax: b.bbox.xmax,
814 ymax: b.bbox.ymax,
815 segmentation: m,
816 });
817 }
818 Ok(())
819}
820
821pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
822 B: BBoxTypeTrait,
823 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
824 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
825>(
826 boxes: (ArrayView2<BOX>, Quantization),
827 scores: (ArrayView2<SCORE>, Quantization),
828 score_threshold: f32,
829 iou_threshold: f32,
830 nms: Option<Nms>,
831 max_boxes: usize,
832) -> Vec<(DetectBox, usize)>
833where
834 f32: AsPrimitive<SCORE>,
835{
836 let (boxes_tensor, quant_boxes) = boxes;
837 let (scores_tensor, quant_scores) = scores;
838
839 let boxes_tensor = boxes_tensor.reversed_axes();
840 let scores_tensor = scores_tensor.reversed_axes();
841
842 let boxes = {
843 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
844 postprocess_boxes_index_quant::<B, _, _>(
845 score_threshold,
846 boxes_tensor,
847 scores_tensor,
848 quant_boxes,
849 )
850 };
851 let mut boxes = dispatch_nms_extra_int(nms, iou_threshold, boxes);
852 boxes.truncate(max_boxes);
853 boxes
854 .into_iter()
855 .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
856 .collect()
857}
858
859pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
860 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
861 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
862>(
863 boxes: Vec<(DetectBox, usize)>,
864 mask_coeff: (ArrayView2<MASK>, Quantization),
865 protos: (ArrayView3<PROTO>, Quantization),
866 output_boxes: &mut Vec<DetectBox>,
867 output_masks: &mut Vec<Segmentation>,
868) -> Result<(), crate::DecoderError> {
869 let (masks, quant_masks) = mask_coeff;
870 let (protos, quant_protos) = protos;
871
872 let masks = masks.reversed_axes();
873
874 let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
875 output_boxes.clear();
876 output_masks.clear();
877 for (b, m) in boxes.into_iter() {
878 output_boxes.push(b);
879 output_masks.push(Segmentation {
880 xmin: b.bbox.xmin,
881 ymin: b.bbox.ymin,
882 xmax: b.bbox.xmax,
883 ymax: b.bbox.ymax,
884 segmentation: m,
885 });
886 }
887 Ok(())
888}
889
890#[allow(clippy::too_many_arguments)]
891pub fn impl_yolo_split_segdet_quant<
903 B: BBoxTypeTrait,
904 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
905 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
906 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
907 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
908>(
909 boxes: (ArrayView2<BOX>, Quantization),
910 scores: (ArrayView2<SCORE>, Quantization),
911 mask_coeff: (ArrayView2<MASK>, Quantization),
912 protos: (ArrayView3<PROTO>, Quantization),
913 score_threshold: f32,
914 iou_threshold: f32,
915 nms: Option<Nms>,
916 output_boxes: &mut Vec<DetectBox>,
917 output_masks: &mut Vec<Segmentation>,
918) -> Result<(), crate::DecoderError>
919where
920 f32: AsPrimitive<SCORE>,
921{
922 let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
923 boxes,
924 scores,
925 score_threshold,
926 iou_threshold,
927 nms,
928 output_boxes.capacity(),
929 );
930
931 impl_yolo_split_segdet_quant_process_masks(
932 boxes,
933 mask_coeff,
934 protos,
935 output_boxes,
936 output_masks,
937 )
938}
939
940#[allow(clippy::too_many_arguments)]
941pub fn impl_yolo_split_segdet_float<
953 B: BBoxTypeTrait,
954 BOX: Float + AsPrimitive<f32> + Send + Sync,
955 SCORE: Float + AsPrimitive<f32> + Send + Sync,
956 MASK: Float + AsPrimitive<f32> + Send + Sync,
957 PROTO: Float + AsPrimitive<f32> + Send + Sync,
958>(
959 boxes_tensor: ArrayView2<BOX>,
960 scores_tensor: ArrayView2<SCORE>,
961 mask_tensor: ArrayView2<MASK>,
962 protos: ArrayView3<PROTO>,
963 score_threshold: f32,
964 iou_threshold: f32,
965 nms: Option<Nms>,
966 output_boxes: &mut Vec<DetectBox>,
967 output_masks: &mut Vec<Segmentation>,
968) -> Result<(), crate::DecoderError>
969where
970 f32: AsPrimitive<SCORE>,
971{
972 let boxes_tensor = boxes_tensor.reversed_axes();
973 let scores_tensor = scores_tensor.reversed_axes();
974 let mask_tensor = mask_tensor.reversed_axes();
975
976 let boxes = postprocess_boxes_index_float::<B, _, _>(
977 score_threshold.as_(),
978 boxes_tensor,
979 scores_tensor,
980 );
981 let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
982 boxes.truncate(output_boxes.capacity());
983 let boxes = decode_segdet_f32(boxes, mask_tensor, protos)?;
984 output_boxes.clear();
985 output_masks.clear();
986 for (b, m) in boxes.into_iter() {
987 output_boxes.push(b);
988 output_masks.push(Segmentation {
989 xmin: b.bbox.xmin,
990 ymin: b.bbox.ymin,
991 xmax: b.bbox.xmax,
992 ymax: b.bbox.ymax,
993 segmentation: m,
994 });
995 }
996 Ok(())
997}
998
999pub fn impl_yolo_segdet_quant_proto<
1006 B: BBoxTypeTrait,
1007 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1008 PROTO: PrimInt
1009 + AsPrimitive<i64>
1010 + AsPrimitive<i128>
1011 + AsPrimitive<f32>
1012 + AsPrimitive<i8>
1013 + Send
1014 + Sync,
1015>(
1016 boxes: (ArrayView2<BOX>, Quantization),
1017 protos: (ArrayView3<PROTO>, Quantization),
1018 score_threshold: f32,
1019 iou_threshold: f32,
1020 nms: Option<Nms>,
1021 output_boxes: &mut Vec<DetectBox>,
1022) -> ProtoData
1023where
1024 f32: AsPrimitive<BOX>,
1025{
1026 let (boxes_arr, quant_boxes) = boxes;
1027 let (protos_arr, quant_protos) = protos;
1028 let num_protos = protos_arr.dim().2;
1029
1030 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1031
1032 let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1033 (boxes_tensor.reversed_axes(), quant_boxes),
1034 (scores_tensor.reversed_axes(), quant_boxes),
1035 score_threshold,
1036 iou_threshold,
1037 nms,
1038 output_boxes.capacity(),
1039 );
1040
1041 extract_proto_data_quant(
1042 det_indices,
1043 mask_tensor,
1044 quant_boxes,
1045 protos_arr,
1046 quant_protos,
1047 output_boxes,
1048 )
1049}
1050
1051pub fn impl_yolo_segdet_float_proto<
1054 B: BBoxTypeTrait,
1055 BOX: Float + AsPrimitive<f32> + Send + Sync,
1056 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1057>(
1058 boxes: ArrayView2<BOX>,
1059 protos: ArrayView3<PROTO>,
1060 score_threshold: f32,
1061 iou_threshold: f32,
1062 nms: Option<Nms>,
1063 output_boxes: &mut Vec<DetectBox>,
1064) -> ProtoData
1065where
1066 f32: AsPrimitive<BOX>,
1067{
1068 let num_protos = protos.dim().2;
1069 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1070
1071 let det_indices = postprocess_boxes_index_float::<B, _, _>(
1072 score_threshold.as_(),
1073 boxes_tensor,
1074 scores_tensor,
1075 );
1076 let mut det_indices = dispatch_nms_extra_float(nms, iou_threshold, det_indices);
1077 det_indices.truncate(output_boxes.capacity());
1078
1079 extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1080}
1081
1082#[allow(clippy::too_many_arguments)]
1085pub fn impl_yolo_split_segdet_quant_proto<
1086 B: BBoxTypeTrait,
1087 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1088 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1089 MASK: PrimInt + AsPrimitive<f32> + Send + Sync,
1090 PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync,
1091>(
1092 boxes: (ArrayView2<BOX>, Quantization),
1093 scores: (ArrayView2<SCORE>, Quantization),
1094 mask_coeff: (ArrayView2<MASK>, Quantization),
1095 protos: (ArrayView3<PROTO>, Quantization),
1096 score_threshold: f32,
1097 iou_threshold: f32,
1098 nms: Option<Nms>,
1099 output_boxes: &mut Vec<DetectBox>,
1100) -> ProtoData
1101where
1102 f32: AsPrimitive<SCORE>,
1103{
1104 let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1105 boxes,
1106 scores,
1107 score_threshold,
1108 iou_threshold,
1109 nms,
1110 output_boxes.capacity(),
1111 );
1112
1113 let (masks, quant_masks) = mask_coeff;
1114 let masks = masks.reversed_axes();
1115 let (protos_arr, quant_protos) = protos;
1116
1117 extract_proto_data_quant(
1118 det_indices,
1119 masks,
1120 quant_masks,
1121 protos_arr,
1122 quant_protos,
1123 output_boxes,
1124 )
1125}
1126
1127#[allow(clippy::too_many_arguments)]
1130pub fn impl_yolo_split_segdet_float_proto<
1131 B: BBoxTypeTrait,
1132 BOX: Float + AsPrimitive<f32> + Send + Sync,
1133 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1134 MASK: Float + AsPrimitive<f32> + Send + Sync,
1135 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1136>(
1137 boxes_tensor: ArrayView2<BOX>,
1138 scores_tensor: ArrayView2<SCORE>,
1139 mask_tensor: ArrayView2<MASK>,
1140 protos: ArrayView3<PROTO>,
1141 score_threshold: f32,
1142 iou_threshold: f32,
1143 nms: Option<Nms>,
1144 output_boxes: &mut Vec<DetectBox>,
1145) -> ProtoData
1146where
1147 f32: AsPrimitive<SCORE>,
1148{
1149 let boxes_tensor = boxes_tensor.reversed_axes();
1150 let scores_tensor = scores_tensor.reversed_axes();
1151 let mask_tensor = mask_tensor.reversed_axes();
1152
1153 let det_indices = postprocess_boxes_index_float::<B, _, _>(
1154 score_threshold.as_(),
1155 boxes_tensor,
1156 scores_tensor,
1157 );
1158 let mut det_indices = dispatch_nms_extra_float(nms, iou_threshold, det_indices);
1159 det_indices.truncate(output_boxes.capacity());
1160
1161 extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1162}
1163
1164pub fn decode_yolo_end_to_end_segdet_float_proto<T>(
1166 output: ArrayView2<T>,
1167 protos: ArrayView3<T>,
1168 score_threshold: f32,
1169 output_boxes: &mut Vec<DetectBox>,
1170) -> Result<ProtoData, crate::DecoderError>
1171where
1172 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1173 f32: AsPrimitive<T>,
1174{
1175 if output.shape()[0] < 7 {
1176 return Err(crate::DecoderError::InvalidShape(format!(
1177 "End-to-end segdet output requires at least 7 rows, got {}",
1178 output.shape()[0]
1179 )));
1180 }
1181
1182 let num_mask_coeffs = output.shape()[0] - 6;
1183 let num_protos = protos.shape()[2];
1184 if num_mask_coeffs != num_protos {
1185 return Err(crate::DecoderError::InvalidShape(format!(
1186 "Mask coefficients count ({}) doesn't match protos count ({})",
1187 num_mask_coeffs, num_protos
1188 )));
1189 }
1190
1191 let boxes = output.slice(s![0..4, ..]).reversed_axes();
1192 let scores = output.slice(s![4..5, ..]).reversed_axes();
1193 let classes = output.slice(s![5, ..]);
1194 let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
1195 let mut det_indices =
1196 postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
1197 det_indices.truncate(output_boxes.capacity());
1198
1199 for (b, ind) in &mut det_indices {
1200 b.label = classes[*ind].as_() as usize;
1201 }
1202
1203 Ok(extract_proto_data_float(
1204 det_indices,
1205 mask_coeff,
1206 protos,
1207 output_boxes,
1208 ))
1209}
1210
1211#[allow(clippy::too_many_arguments)]
1213pub fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1214 boxes: ArrayView2<T>,
1215 scores: ArrayView2<T>,
1216 classes: ArrayView2<T>,
1217 mask_coeff: ArrayView2<T>,
1218 protos: ArrayView3<T>,
1219 score_threshold: f32,
1220 output_boxes: &mut Vec<DetectBox>,
1221) -> Result<ProtoData, crate::DecoderError>
1222where
1223 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1224 f32: AsPrimitive<T>,
1225{
1226 let n = boxes.shape()[0];
1227 if boxes.shape()[1] != 4 {
1228 return Err(crate::DecoderError::InvalidShape(format!(
1229 "Split end-to-end boxes must have 4 columns, got {}",
1230 boxes.shape()[1]
1231 )));
1232 }
1233
1234 let mut qualifying: Vec<(DetectBox, usize)> = Vec::with_capacity(output_boxes.capacity());
1235 for i in 0..n {
1236 let score: f32 = scores[[i, 0]].as_();
1237 if score < score_threshold {
1238 continue;
1239 }
1240 if qualifying.len() >= output_boxes.capacity() {
1241 break;
1242 }
1243 qualifying.push((
1244 DetectBox {
1245 bbox: BoundingBox {
1246 xmin: boxes[[i, 0]].as_(),
1247 ymin: boxes[[i, 1]].as_(),
1248 xmax: boxes[[i, 2]].as_(),
1249 ymax: boxes[[i, 3]].as_(),
1250 },
1251 score,
1252 label: classes[[i, 0]].as_() as usize,
1253 },
1254 i,
1255 ));
1256 }
1257
1258 Ok(extract_proto_data_float(
1259 qualifying,
1260 mask_coeff,
1261 protos,
1262 output_boxes,
1263 ))
1264}
1265
1266fn extract_proto_data_float<
1268 MASK: Float + AsPrimitive<f32> + Send + Sync,
1269 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1270>(
1271 det_indices: Vec<(DetectBox, usize)>,
1272 mask_tensor: ArrayView2<MASK>,
1273 protos: ArrayView3<PROTO>,
1274 output_boxes: &mut Vec<DetectBox>,
1275) -> ProtoData {
1276 let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1277 output_boxes.clear();
1278 for (det, idx) in det_indices {
1279 output_boxes.push(det);
1280 let row = mask_tensor.row(idx);
1281 mask_coefficients.push(row.iter().map(|v| v.as_()).collect());
1282 }
1283 let protos_f32 = protos.map(|v| v.as_());
1284 ProtoData {
1285 mask_coefficients,
1286 protos: ProtoTensor::Float(protos_f32),
1287 }
1288}
1289
1290pub(crate) fn extract_proto_data_quant<
1296 MASK: PrimInt + AsPrimitive<f32> + Send + Sync,
1297 PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync,
1298>(
1299 det_indices: Vec<(DetectBox, usize)>,
1300 mask_tensor: ArrayView2<MASK>,
1301 quant_masks: Quantization,
1302 protos: ArrayView3<PROTO>,
1303 quant_protos: Quantization,
1304 output_boxes: &mut Vec<DetectBox>,
1305) -> ProtoData {
1306 let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1307 output_boxes.clear();
1308 for (det, idx) in det_indices {
1309 output_boxes.push(det);
1310 let row = mask_tensor.row(idx);
1311 mask_coefficients.push(
1312 row.iter()
1313 .map(|v| (v.as_() - quant_masks.zero_point as f32) * quant_masks.scale)
1314 .collect(),
1315 );
1316 }
1317 let protos_i8 = protos.map(|v| {
1319 let v_i8: i8 = v.as_();
1320 v_i8
1321 });
1322 ProtoData {
1323 mask_coefficients,
1324 protos: ProtoTensor::Quantized {
1325 protos: protos_i8,
1326 quantization: quant_protos,
1327 },
1328 }
1329}
1330
1331fn postprocess_yolo<'a, T>(
1332 output: &'a ArrayView2<'_, T>,
1333) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1334 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1335 let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1336 (boxes_tensor, scores_tensor)
1337}
1338
1339fn postprocess_yolo_seg<'a, T>(
1340 output: &'a ArrayView2<'_, T>,
1341 num_protos: usize,
1342) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1343 assert!(
1344 output.shape()[0] > num_protos + 4,
1345 "Output shape is too short: {} <= {} + 4",
1346 output.shape()[0],
1347 num_protos
1348 );
1349 let num_classes = output.shape()[0] - 4 - num_protos;
1350 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1351 let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1352 let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1353 (boxes_tensor, scores_tensor, mask_tensor)
1354}
1355
1356fn decode_segdet_f32<
1357 MASK: Float + AsPrimitive<f32> + Send + Sync,
1358 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1359>(
1360 boxes: Vec<(DetectBox, usize)>,
1361 masks: ArrayView2<MASK>,
1362 protos: ArrayView3<PROTO>,
1363) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1364 if boxes.is_empty() {
1365 return Ok(Vec::new());
1366 }
1367 if masks.shape()[1] != protos.shape()[2] {
1368 return Err(crate::DecoderError::InvalidShape(format!(
1369 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1370 masks.shape()[1],
1371 protos.shape()[2],
1372 )));
1373 }
1374 boxes
1375 .into_par_iter()
1376 .map(|mut b| {
1377 let ind = b.1;
1378 let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1379 b.0.bbox = roi;
1380 Ok((b.0, make_segmentation(masks.row(ind), protos.view())))
1381 })
1382 .collect()
1383}
1384
1385pub(crate) fn decode_segdet_quant<
1386 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1387 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1388>(
1389 boxes: Vec<(DetectBox, usize)>,
1390 masks: ArrayView2<MASK>,
1391 protos: ArrayView3<PROTO>,
1392 quant_masks: Quantization,
1393 quant_protos: Quantization,
1394) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1395 if boxes.is_empty() {
1396 return Ok(Vec::new());
1397 }
1398 if masks.shape()[1] != protos.shape()[2] {
1399 return Err(crate::DecoderError::InvalidShape(format!(
1400 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1401 masks.shape()[1],
1402 protos.shape()[2],
1403 )));
1404 }
1405
1406 let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; boxes
1408 .into_iter()
1409 .map(|mut b| {
1410 let i = b.1;
1411 let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1412 b.0.bbox = roi;
1413 let seg = match total_bits {
1414 0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1415 masks.row(i),
1416 protos.view(),
1417 quant_masks,
1418 quant_protos,
1419 ),
1420 65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1421 masks.row(i),
1422 protos.view(),
1423 quant_masks,
1424 quant_protos,
1425 ),
1426 _ => {
1427 return Err(crate::DecoderError::NotSupported(format!(
1428 "Unsupported bit width ({total_bits}) for segmentation computation"
1429 )));
1430 }
1431 };
1432 Ok((b.0, seg))
1433 })
1434 .collect()
1435}
1436
1437fn protobox<'a, T>(
1438 protos: &'a ArrayView3<T>,
1439 roi: &BoundingBox,
1440) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1441 let width = protos.dim().1 as f32;
1442 let height = protos.dim().0 as f32;
1443
1444 const NORM_LIMIT: f32 = 1.01;
1450 if roi.xmin > NORM_LIMIT
1451 || roi.ymin > NORM_LIMIT
1452 || roi.xmax > NORM_LIMIT
1453 || roi.ymax > NORM_LIMIT
1454 {
1455 return Err(crate::DecoderError::InvalidShape(format!(
1456 "Bounding box coordinates appear un-normalized (pixel-space). \
1457 Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
1458 ONNX models output pixel-space boxes — normalize them by dividing by \
1459 the input dimensions before calling decode().",
1460 roi.xmin, roi.ymin, roi.xmax, roi.ymax,
1461 )));
1462 }
1463
1464 let roi = [
1465 (roi.xmin * width).clamp(0.0, width) as usize,
1466 (roi.ymin * height).clamp(0.0, height) as usize,
1467 (roi.xmax * width).clamp(0.0, width).ceil() as usize,
1468 (roi.ymax * height).clamp(0.0, height).ceil() as usize,
1469 ];
1470
1471 let roi_norm = [
1472 roi[0] as f32 / width,
1473 roi[1] as f32 / height,
1474 roi[2] as f32 / width,
1475 roi[3] as f32 / height,
1476 ]
1477 .into();
1478
1479 let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
1480
1481 Ok((cropped, roi_norm))
1482}
1483
1484fn make_segmentation<
1490 MASK: Float + AsPrimitive<f32> + Send + Sync,
1491 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1492>(
1493 mask: ArrayView1<MASK>,
1494 protos: ArrayView3<PROTO>,
1495) -> Array3<u8> {
1496 let shape = protos.shape();
1497
1498 let mask = mask.to_shape((1, mask.len())).unwrap();
1500 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1501 let protos = protos.reversed_axes();
1502 let mask = mask.map(|x| x.as_());
1503 let protos = protos.map(|x| x.as_());
1504
1505 let mask = mask
1507 .dot(&protos)
1508 .into_shape_with_order((shape[0], shape[1], 1))
1509 .unwrap();
1510
1511 mask.map(|x| {
1512 let sigmoid = 1.0 / (1.0 + (-*x).exp());
1513 (sigmoid * 255.0).round() as u8
1514 })
1515}
1516
1517fn make_segmentation_quant<
1524 MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
1525 PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
1526 DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
1527>(
1528 mask: ArrayView1<MASK>,
1529 protos: ArrayView3<PROTO>,
1530 quant_masks: Quantization,
1531 quant_protos: Quantization,
1532) -> Array3<u8>
1533where
1534 i32: AsPrimitive<DEST>,
1535 f32: AsPrimitive<DEST>,
1536{
1537 let shape = protos.shape();
1538
1539 let mask = mask.to_shape((1, mask.len())).unwrap();
1541
1542 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1543 let protos = protos.reversed_axes();
1544
1545 let zp = quant_masks.zero_point.as_();
1546
1547 let mask = mask.mapv(|x| x.as_() - zp);
1548
1549 let zp = quant_protos.zero_point.as_();
1550 let protos = protos.mapv(|x| x.as_() - zp);
1551
1552 let segmentation = mask
1554 .dot(&protos)
1555 .into_shape_with_order((shape[0], shape[1], 1))
1556 .unwrap();
1557
1558 let combined_scale = quant_masks.scale * quant_protos.scale;
1559 segmentation.map(|x| {
1560 let val: f32 = (*x).as_() * combined_scale;
1561 let sigmoid = 1.0 / (1.0 + (-val).exp());
1562 (sigmoid * 255.0).round() as u8
1563 })
1564}
1565
1566pub fn yolo_segmentation_to_mask(
1578 segmentation: ArrayView3<u8>,
1579 threshold: u8,
1580) -> Result<Array2<u8>, crate::DecoderError> {
1581 if segmentation.shape()[2] != 1 {
1582 return Err(crate::DecoderError::InvalidShape(format!(
1583 "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
1584 segmentation.shape()[2]
1585 )));
1586 }
1587 Ok(segmentation
1588 .slice(s![.., .., 0])
1589 .map(|x| if *x >= threshold { 1 } else { 0 }))
1590}
1591
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use ndarray::Array2;

    // End-to-end detection tests.
    //
    // The model output is a row-major (6, N) array with one column per
    // detection. The rows are [xmin, ymin, xmax, ymax, score, class], as
    // pinned down by the assertions in `test_end_to_end_det_basic_filtering`.

    /// Only detections whose score clears the threshold argument (0.5) are
    /// kept, and their box, score and label are copied through.
    #[test]
    fn test_end_to_end_det_basic_filtering() {
        let data: Vec<f32> = vec![
            0.1, 0.2, 0.3, // xmin
            0.1, 0.2, 0.3, // ymin
            0.5, 0.6, 0.7, // xmax
            0.5, 0.6, 0.7, // ymax
            0.9, 0.1, 0.2, // score: only detection 0 clears 0.5
            0.0, 1.0, 2.0, // class
        ];
        let output = Array2::from_shape_vec((6, 3), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 1);
        assert_eq!(boxes[0].label, 0);
        assert!((boxes[0].score - 0.9).abs() < 0.01);
        assert!((boxes[0].bbox.xmin - 0.1).abs() < 0.01);
        assert!((boxes[0].bbox.ymin - 0.1).abs() < 0.01);
        assert!((boxes[0].bbox.xmax - 0.5).abs() < 0.01);
        assert!((boxes[0].bbox.ymax - 0.5).abs() < 0.01);
    }

    /// When every score clears the threshold, all detections are returned in
    /// input order.
    #[test]
    fn test_end_to_end_det_all_pass_threshold() {
        let data: Vec<f32> = vec![
            10.0, 20.0, // xmin
            10.0, 20.0, // ymin
            50.0, 60.0, // xmax
            50.0, 60.0, // ymax
            0.8, 0.7, // score
            1.0, 2.0, // class
        ];
        let output = Array2::from_shape_vec((6, 2), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 2);
        assert_eq!(boxes[0].label, 1);
        assert_eq!(boxes[1].label, 2);
    }

    /// When no score clears the threshold, the output is empty.
    #[test]
    fn test_end_to_end_det_none_pass_threshold() {
        let data: Vec<f32> = vec![
            10.0, 20.0, // xmin
            10.0, 20.0, // ymin
            50.0, 60.0, // xmax
            50.0, 60.0, // ymax
            0.1, 0.2, // score: both below 0.5
            1.0, 2.0, // class
        ];
        let output = Array2::from_shape_vec((6, 2), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 0);
    }

    /// All five detections clear the threshold, but the output is expected
    /// to be capped at the capacity of the output Vec (2).
    #[test]
    fn test_end_to_end_det_capacity_limit() {
        let data: Vec<f32> = vec![
            0.1, 0.2, 0.3, 0.4, 0.5, // xmin
            0.1, 0.2, 0.3, 0.4, 0.5, // ymin
            0.5, 0.6, 0.7, 0.8, 0.9, // xmax
            0.5, 0.6, 0.7, 0.8, 0.9, // ymax
            0.9, 0.9, 0.9, 0.9, 0.9, // score
            0.0, 1.0, 2.0, 3.0, 4.0, // class
        ];
        let output = Array2::from_shape_vec((6, 5), data).unwrap();

        // Capacity 2 is the limit under test.
        let mut boxes = Vec::with_capacity(2);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 2);
    }

    /// An output with zero detection columns decodes to no boxes.
    #[test]
    fn test_end_to_end_det_empty_output() {
        let output = Array2::<f32>::zeros((6, 0));

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 0);
    }

    /// Box coordinates far above 1.0 are accepted by the detection-only path
    /// and passed through unchanged.
    #[test]
    fn test_end_to_end_det_pixel_coordinates() {
        let data: Vec<f32> = vec![
            100.0, // xmin
            200.0, // ymin
            300.0, // xmax
            400.0, // ymax
            0.95,  // score
            5.0,   // class
        ];
        let output = Array2::from_shape_vec((6, 1), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 1);
        assert_eq!(boxes[0].label, 5);
        assert!((boxes[0].bbox.xmin - 100.0).abs() < 0.01);
        assert!((boxes[0].bbox.ymin - 200.0).abs() < 0.01);
        assert!((boxes[0].bbox.xmax - 300.0).abs() < 0.01);
        assert!((boxes[0].bbox.ymax - 400.0).abs() < 0.01);
    }

    /// Fewer than 6 rows is rejected with an `InvalidShape` error.
    #[test]
    fn test_end_to_end_det_invalid_shape() {
        let output = Array2::<f32>::zeros((5, 3));

        let mut boxes = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
        ));
    }

    // End-to-end segmentation tests.
    //
    // The output is a row-major (6 + num_protos, N) array. Rows 0..6 follow
    // the detection layout above, and rows 6.. hold the mask coefficients,
    // whose count must match the last axis of `protos`. Element (row, det)
    // lives at index `row * num_detections + det` in `data`.

    /// One of two detections clears the threshold; one box and one mask are
    /// produced.
    #[test]
    fn test_end_to_end_segdet_basic() {
        let num_protos = 32;
        let num_detections = 2;
        let num_features = 6 + num_protos;

        let mut data = vec![0.0f32; num_features * num_detections];
        data[0] = 0.1; // xmin, det 0
        data[1] = 0.5; // xmin, det 1
        data[num_detections] = 0.1; // ymin, det 0
        data[num_detections + 1] = 0.5; // ymin, det 1
        data[2 * num_detections] = 0.4; // xmax, det 0
        data[2 * num_detections + 1] = 0.9; // xmax, det 1
        data[3 * num_detections] = 0.4; // ymax, det 0
        data[3 * num_detections + 1] = 0.9; // ymax, det 1
        data[4 * num_detections] = 0.9; // score, det 0 (kept)
        data[4 * num_detections + 1] = 0.3; // score, det 1 (dropped)
        data[5 * num_detections] = 1.0; // class, det 0
        data[5 * num_detections + 1] = 2.0; // class, det 1
        // Mask coefficients (rows 6..) for both detections.
        for i in 6..num_features {
            data[i * num_detections] = 0.1;
            data[i * num_detections + 1] = 0.1;
        }

        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();

        // (height, width, num_protos) prototypes; all zeros.
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 1);
        assert_eq!(masks.len(), 1);
        assert_eq!(boxes[0].label, 1);
        assert!((boxes[0].score - 0.9).abs() < 0.01);
    }

    /// The region reported for each mask matches its detection's box.
    #[test]
    fn test_end_to_end_segdet_mask_coordinates() {
        let num_protos = 32;
        let num_features = 6 + num_protos;

        // Single detection, so row r is simply data[r].
        let mut data = vec![0.0f32; num_features];
        data[0] = 0.2; // xmin
        data[1] = 0.2; // ymin
        data[2] = 0.8; // xmax
        data[3] = 0.8; // ymax
        data[4] = 0.95; // score
        data[5] = 3.0; // class
        let output = Array2::from_shape_vec((num_features, 1), data).unwrap();
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 1);
        assert_eq!(masks.len(), 1);

        assert!((masks[0].xmin - boxes[0].bbox.xmin).abs() < 0.01);
        assert!((masks[0].ymin - boxes[0].bbox.ymin).abs() < 0.01);
        assert!((masks[0].xmax - boxes[0].bbox.xmax).abs() < 0.01);
        assert!((masks[0].ymax - boxes[0].bbox.ymax).abs() < 0.01);
    }

    /// Zero detection columns produce no boxes and no masks.
    #[test]
    fn test_end_to_end_segdet_empty_output() {
        let num_protos = 32;
        let output = Array2::<f32>::zeros((6 + num_protos, 0));
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 0);
        assert_eq!(masks.len(), 0);
    }

    /// All five detections clear the threshold, but both outputs are
    /// expected to be capped at their Vec capacity (2).
    #[test]
    fn test_end_to_end_segdet_capacity_limit() {
        let num_protos = 32;
        let num_detections = 5;
        let num_features = 6 + num_protos;

        let mut data = vec![0.0f32; num_features * num_detections];
        for i in 0..num_detections {
            data[i] = 0.1 * (i as f32); // xmin
            data[num_detections + i] = 0.1 * (i as f32); // ymin
            data[2 * num_detections + i] = 0.1 * (i as f32) + 0.2; // xmax
            data[3 * num_detections + i] = 0.1 * (i as f32) + 0.2; // ymax
            data[4 * num_detections + i] = 0.9; // score
            data[5 * num_detections + i] = i as f32; // class
        }

        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        // Capacity 2 is the limit under test.
        let mut boxes = Vec::with_capacity(2);
        let mut masks = Vec::with_capacity(2);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 2);
        assert_eq!(masks.len(), 2);
    }

    /// Six rows leave no room for mask coefficients and are rejected.
    #[test]
    fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
        let output = Array2::<f32>::zeros((6, 3));
        let protos = Array3::<f32>::zeros((16, 16, 32));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        );

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
        ));
    }

    /// 16 coefficient rows against 32 proto channels is rejected.
    #[test]
    fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
        let num_protos = 32;
        // 16 mask-coefficient rows ...
        let output = Array2::<f32>::zeros((6 + 16, 3));
        // ... but 32 proto channels.
        let protos = Array3::<f32>::zeros((16, 16, num_protos));
        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        );

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
        ));
    }

    // `yolo_segmentation_to_mask` tests.

    /// Values >= threshold map to 1 and values below map to 0; the boundary
    /// value 128 is included.
    #[test]
    fn test_segmentation_to_mask_basic() {
        // Row-major (4, 4, 1) mask, one row of the image per line.
        let data: Vec<u8> = vec![
            100, 200, 50, 150, // row 0
            10, 255, 128, 64, // row 1
            0, 127, 128, 255, // row 2
            64, 64, 192, 192, // row 3
        ];
        let segmentation = Array3::from_shape_vec((4, 4, 1), data).unwrap();

        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();

        assert_eq!(mask[[0, 0]], 0); // 100 < 128
        assert_eq!(mask[[0, 1]], 1); // 200
        assert_eq!(mask[[0, 2]], 0); // 50
        assert_eq!(mask[[0, 3]], 1); // 150
        assert_eq!(mask[[1, 1]], 1); // 255
        assert_eq!(mask[[1, 2]], 1); // 128 == threshold -> included
        assert_eq!(mask[[2, 0]], 0); // 0
        assert_eq!(mask[[2, 1]], 0); // 127, just below threshold
    }

    /// A mask entirely above the threshold binarizes to all ones.
    #[test]
    fn test_segmentation_to_mask_all_above() {
        let segmentation = Array3::from_elem((4, 4, 1), 255u8);
        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
        assert!(mask.iter().all(|&x| x == 1));
    }

    /// A mask entirely below the threshold binarizes to all zeros.
    #[test]
    fn test_segmentation_to_mask_all_below() {
        let segmentation = Array3::from_elem((4, 4, 64u8));
        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
        assert!(mask.iter().all(|&x| x == 0));
    }

    /// More than one channel in the last axis is rejected.
    #[test]
    fn test_segmentation_to_mask_invalid_shape() {
        let segmentation = Array3::from_elem((4, 4, 3), 128u8);
        let result = yolo_segmentation_to_mask(segmentation.view(), 128);

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("(H, W, 1)")
        ));
    }

    /// The proto-only float segmentation path produces output (and one
    /// coefficient vector of the right length per kept box) for box values
    /// far above 1.0.
    #[test]
    fn test_segdet_float_proto_no_panic() {
        let num_proposals = 100;
        let num_classes = 80;
        let num_mask_coeffs = 32;
        // Layout (4 + num_classes + num_mask_coeffs, num_proposals):
        // box rows, then class scores, then mask coefficients.
        let rows = 4 + num_classes + num_mask_coeffs;
        let mut data = vec![0.0f32; rows * num_proposals];
        for i in 0..num_proposals {
            let row = |r: usize| r * num_proposals + i;
            // Box rows; with the `XYWH` format below these are presumably
            // center-x, center-y, width, height (not normalized).
            data[row(0)] = 320.0;
            data[row(1)] = 320.0;
            data[row(2)] = 50.0;
            data[row(3)] = 50.0;
            // Score of the first class.
            data[row(4)] = 0.9;
        }
        let boxes = ndarray::Array2::from_shape_vec((rows, num_proposals), data).unwrap();

        let protos = ndarray::Array3::<f32>::zeros((160, 160, num_mask_coeffs));

        let mut output_boxes = Vec::with_capacity(300);

        // 0.5 and 0.7 are presumably the score and IoU thresholds; confirm
        // against the signature of `impl_yolo_segdet_float_proto`.
        let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
            boxes.view(),
            protos.view(),
            0.5,
            0.7,
            Some(Nms::default()),
            &mut output_boxes,
        );

        assert!(!output_boxes.is_empty());
        assert_eq!(proto_data.mask_coefficients.len(), output_boxes.len());
        for coeffs in &proto_data.mask_coefficients {
            assert_eq!(coeffs.len(), num_mask_coeffs);
        }
    }
}