1use std::fmt::Debug;
5
6use ndarray::{
7 parallel::prelude::{IntoParallelIterator, ParallelIterator},
8 s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
9};
10use num_traits::{AsPrimitive, Float, PrimInt, Signed};
11
12use crate::{
13 byte::{
14 nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
15 postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
16 },
17 configs::Nms,
18 dequant_detect_box,
19 float::{
20 nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
21 postprocess_boxes_float, postprocess_boxes_index_float,
22 },
23 BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoTensor,
24 Quantization, Segmentation, XYWH, XYXY,
25};
26
27fn dispatch_nms_float(nms: Option<Nms>, iou: f32, boxes: Vec<DetectBox>) -> Vec<DetectBox> {
29 match nms {
30 Some(Nms::ClassAgnostic) => nms_float(iou, boxes),
31 Some(Nms::ClassAware) => nms_class_aware_float(iou, boxes),
32 None => boxes, }
34}
35
36fn dispatch_nms_extra_float<E: Send + Sync>(
39 nms: Option<Nms>,
40 iou: f32,
41 boxes: Vec<(DetectBox, E)>,
42) -> Vec<(DetectBox, E)> {
43 match nms {
44 Some(Nms::ClassAgnostic) => nms_extra_float(iou, boxes),
45 Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, boxes),
46 None => boxes, }
48}
49
50fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
53 nms: Option<Nms>,
54 iou: f32,
55 boxes: Vec<DetectBoxQuantized<SCORE>>,
56) -> Vec<DetectBoxQuantized<SCORE>> {
57 match nms {
58 Some(Nms::ClassAgnostic) => nms_int(iou, boxes),
59 Some(Nms::ClassAware) => nms_class_aware_int(iou, boxes),
60 None => boxes, }
62}
63
64fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
67 nms: Option<Nms>,
68 iou: f32,
69 boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
70) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
71 match nms {
72 Some(Nms::ClassAgnostic) => nms_extra_int(iou, boxes),
73 Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, boxes),
74 None => boxes, }
76}
77
78pub fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
85 output: (ArrayView2<BOX>, Quantization),
86 score_threshold: f32,
87 iou_threshold: f32,
88 nms: Option<Nms>,
89 output_boxes: &mut Vec<DetectBox>,
90) where
91 f32: AsPrimitive<BOX>,
92{
93 impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
94}
95
96pub fn decode_yolo_det_float<T>(
103 output: ArrayView2<T>,
104 score_threshold: f32,
105 iou_threshold: f32,
106 nms: Option<Nms>,
107 output_boxes: &mut Vec<DetectBox>,
108) where
109 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
110 f32: AsPrimitive<T>,
111{
112 impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
113}
114
115pub fn decode_yolo_segdet_quant<
127 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
128 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
129>(
130 boxes: (ArrayView2<BOX>, Quantization),
131 protos: (ArrayView3<PROTO>, Quantization),
132 score_threshold: f32,
133 iou_threshold: f32,
134 nms: Option<Nms>,
135 output_boxes: &mut Vec<DetectBox>,
136 output_masks: &mut Vec<Segmentation>,
137) -> Result<(), crate::DecoderError>
138where
139 f32: AsPrimitive<BOX>,
140{
141 impl_yolo_segdet_quant::<XYWH, _, _>(
142 boxes,
143 protos,
144 score_threshold,
145 iou_threshold,
146 nms,
147 output_boxes,
148 output_masks,
149 )
150}
151
152pub fn decode_yolo_segdet_float<T>(
164 boxes: ArrayView2<T>,
165 protos: ArrayView3<T>,
166 score_threshold: f32,
167 iou_threshold: f32,
168 nms: Option<Nms>,
169 output_boxes: &mut Vec<DetectBox>,
170 output_masks: &mut Vec<Segmentation>,
171) -> Result<(), crate::DecoderError>
172where
173 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
174 f32: AsPrimitive<T>,
175{
176 impl_yolo_segdet_float::<XYWH, _, _>(
177 boxes,
178 protos,
179 score_threshold,
180 iou_threshold,
181 nms,
182 output_boxes,
183 output_masks,
184 )
185}
186
187pub fn decode_yolo_split_det_quant<
199 BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
200 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
201>(
202 boxes: (ArrayView2<BOX>, Quantization),
203 scores: (ArrayView2<SCORE>, Quantization),
204 score_threshold: f32,
205 iou_threshold: f32,
206 nms: Option<Nms>,
207 output_boxes: &mut Vec<DetectBox>,
208) where
209 f32: AsPrimitive<SCORE>,
210{
211 impl_yolo_split_quant::<XYWH, _, _>(
212 boxes,
213 scores,
214 score_threshold,
215 iou_threshold,
216 nms,
217 output_boxes,
218 );
219}
220
221pub fn decode_yolo_split_det_float<T>(
233 boxes: ArrayView2<T>,
234 scores: ArrayView2<T>,
235 score_threshold: f32,
236 iou_threshold: f32,
237 nms: Option<Nms>,
238 output_boxes: &mut Vec<DetectBox>,
239) where
240 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
241 f32: AsPrimitive<T>,
242{
243 impl_yolo_split_float::<XYWH, _, _>(
244 boxes,
245 scores,
246 score_threshold,
247 iou_threshold,
248 nms,
249 output_boxes,
250 );
251}
252
253#[allow(clippy::too_many_arguments)]
267pub fn decode_yolo_split_segdet<
268 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
269 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
270 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
271 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
272>(
273 boxes: (ArrayView2<BOX>, Quantization),
274 scores: (ArrayView2<SCORE>, Quantization),
275 mask_coeff: (ArrayView2<MASK>, Quantization),
276 protos: (ArrayView3<PROTO>, Quantization),
277 score_threshold: f32,
278 iou_threshold: f32,
279 nms: Option<Nms>,
280 output_boxes: &mut Vec<DetectBox>,
281 output_masks: &mut Vec<Segmentation>,
282) -> Result<(), crate::DecoderError>
283where
284 f32: AsPrimitive<SCORE>,
285{
286 impl_yolo_split_segdet_quant::<XYWH, _, _, _, _>(
287 boxes,
288 scores,
289 mask_coeff,
290 protos,
291 score_threshold,
292 iou_threshold,
293 nms,
294 output_boxes,
295 output_masks,
296 )
297}
298
299#[allow(clippy::too_many_arguments)]
313pub fn decode_yolo_split_segdet_float<T>(
314 boxes: ArrayView2<T>,
315 scores: ArrayView2<T>,
316 mask_coeff: ArrayView2<T>,
317 protos: ArrayView3<T>,
318 score_threshold: f32,
319 iou_threshold: f32,
320 nms: Option<Nms>,
321 output_boxes: &mut Vec<DetectBox>,
322 output_masks: &mut Vec<Segmentation>,
323) -> Result<(), crate::DecoderError>
324where
325 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
326 f32: AsPrimitive<T>,
327{
328 impl_yolo_split_segdet_float::<XYWH, _, _, _, _>(
329 boxes,
330 scores,
331 mask_coeff,
332 protos,
333 score_threshold,
334 iou_threshold,
335 nms,
336 output_boxes,
337 output_masks,
338 )
339}
340
341pub fn decode_yolo_end_to_end_det_float<T>(
356 output: ArrayView2<T>,
357 score_threshold: f32,
358 output_boxes: &mut Vec<DetectBox>,
359) -> Result<(), crate::DecoderError>
360where
361 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
362 f32: AsPrimitive<T>,
363{
364 if output.shape()[0] < 6 {
366 return Err(crate::DecoderError::InvalidShape(format!(
367 "End-to-end detection output requires at least 6 rows, got {}",
368 output.shape()[0]
369 )));
370 }
371
372 let boxes = output.slice(s![0..4, ..]).reversed_axes();
374 let scores = output.slice(s![4..5, ..]).reversed_axes();
375 let classes = output.slice(s![5, ..]);
376 let mut boxes =
377 postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
378 boxes.truncate(output_boxes.capacity());
379 output_boxes.clear();
380 for (mut b, i) in boxes.into_iter() {
381 b.label = classes[i].as_() as usize;
382 output_boxes.push(b);
383 }
384 Ok(())
386}
387
388pub fn decode_yolo_end_to_end_segdet_float<T>(
406 output: ArrayView2<T>,
407 protos: ArrayView3<T>,
408 score_threshold: f32,
409 output_boxes: &mut Vec<DetectBox>,
410 output_masks: &mut Vec<crate::Segmentation>,
411) -> Result<(), crate::DecoderError>
412where
413 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
414 f32: AsPrimitive<T>,
415{
416 if output.shape()[0] < 7 {
418 return Err(crate::DecoderError::InvalidShape(format!(
419 "End-to-end segdet output requires at least 7 rows, got {}",
420 output.shape()[0]
421 )));
422 }
423
424 let num_mask_coeffs = output.shape()[0] - 6;
425 let num_protos = protos.shape()[2];
426 if num_mask_coeffs != num_protos {
427 return Err(crate::DecoderError::InvalidShape(format!(
428 "Mask coefficients count ({}) doesn't match protos count ({})",
429 num_mask_coeffs, num_protos
430 )));
431 }
432
433 let boxes = output.slice(s![0..4, ..]).reversed_axes();
435 let scores = output.slice(s![4..5, ..]).reversed_axes();
436 let classes = output.slice(s![5, ..]);
437 let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
438 let mut boxes =
439 postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
440 boxes.truncate(output_boxes.capacity());
441
442 for (b, ind) in &mut boxes {
443 b.label = classes[*ind].as_() as usize;
444 }
445
446 let boxes = decode_segdet_f32(boxes, mask_coeff, protos)?;
449
450 output_boxes.clear();
451 output_masks.clear();
452 for (b, m) in boxes.into_iter() {
453 output_boxes.push(b);
454 output_masks.push(Segmentation {
455 xmin: b.bbox.xmin,
456 ymin: b.bbox.ymin,
457 xmax: b.bbox.xmax,
458 ymax: b.bbox.ymax,
459 segmentation: m,
460 });
461 }
462 Ok(())
463}
464
465pub fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
474 boxes: ArrayView2<T>,
475 scores: ArrayView2<T>,
476 classes: ArrayView2<T>,
477 score_threshold: f32,
478 output_boxes: &mut Vec<DetectBox>,
479) -> Result<(), crate::DecoderError> {
480 let n = boxes.shape()[0];
481 if boxes.shape()[1] != 4 {
482 return Err(crate::DecoderError::InvalidShape(format!(
483 "Split end-to-end boxes must have 4 columns, got {}",
484 boxes.shape()[1]
485 )));
486 }
487 output_boxes.clear();
488 for i in 0..n {
489 let score: f32 = scores[[i, 0]].as_();
490 if score < score_threshold {
491 continue;
492 }
493 if output_boxes.len() >= output_boxes.capacity() {
494 break;
495 }
496 output_boxes.push(DetectBox {
497 bbox: BoundingBox {
498 xmin: boxes[[i, 0]].as_(),
499 ymin: boxes[[i, 1]].as_(),
500 xmax: boxes[[i, 2]].as_(),
501 ymax: boxes[[i, 3]].as_(),
502 },
503 score,
504 label: classes[[i, 0]].as_() as usize,
505 });
506 }
507 Ok(())
508}
509
510#[allow(clippy::too_many_arguments)]
519pub fn decode_yolo_split_end_to_end_segdet_float<T>(
520 boxes: ArrayView2<T>,
521 scores: ArrayView2<T>,
522 classes: ArrayView2<T>,
523 mask_coeff: ArrayView2<T>,
524 protos: ArrayView3<T>,
525 score_threshold: f32,
526 output_boxes: &mut Vec<DetectBox>,
527 output_masks: &mut Vec<crate::Segmentation>,
528) -> Result<(), crate::DecoderError>
529where
530 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
531 f32: AsPrimitive<T>,
532{
533 let n = boxes.shape()[0];
534 if boxes.shape()[1] != 4 {
535 return Err(crate::DecoderError::InvalidShape(format!(
536 "Split end-to-end boxes must have 4 columns, got {}",
537 boxes.shape()[1]
538 )));
539 }
540
541 let mut qualifying: Vec<(DetectBox, usize)> = Vec::with_capacity(output_boxes.capacity());
543 for i in 0..n {
544 let score: f32 = scores[[i, 0]].as_();
545 if score < score_threshold {
546 continue;
547 }
548 if qualifying.len() >= output_boxes.capacity() {
549 break;
550 }
551 qualifying.push((
552 DetectBox {
553 bbox: BoundingBox {
554 xmin: boxes[[i, 0]].as_(),
555 ymin: boxes[[i, 1]].as_(),
556 xmax: boxes[[i, 2]].as_(),
557 ymax: boxes[[i, 3]].as_(),
558 },
559 score,
560 label: classes[[i, 0]].as_() as usize,
561 },
562 i,
563 ));
564 }
565
566 let result = decode_segdet_f32(qualifying, mask_coeff, protos)?;
568
569 output_boxes.clear();
570 output_masks.clear();
571 for (b, m) in result.into_iter() {
572 output_masks.push(crate::Segmentation {
573 xmin: b.bbox.xmin,
574 ymin: b.bbox.ymin,
575 xmax: b.bbox.xmax,
576 ymax: b.bbox.ymax,
577 segmentation: m,
578 });
579 output_boxes.push(b);
580 }
581 Ok(())
582}
583
584pub fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
589 output: (ArrayView2<T>, Quantization),
590 score_threshold: f32,
591 iou_threshold: f32,
592 nms: Option<Nms>,
593 output_boxes: &mut Vec<DetectBox>,
594) where
595 f32: AsPrimitive<T>,
596{
597 let (boxes, quant_boxes) = output;
598 let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
599
600 let boxes = {
601 let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
602 postprocess_boxes_quant::<B, _, _>(
603 score_threshold,
604 boxes_tensor,
605 scores_tensor,
606 quant_boxes,
607 )
608 };
609
610 let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
611 let len = output_boxes.capacity().min(boxes.len());
612 output_boxes.clear();
613 for b in boxes.iter().take(len) {
614 output_boxes.push(dequant_detect_box(b, quant_boxes));
615 }
616}
617
618pub fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
623 output: ArrayView2<T>,
624 score_threshold: f32,
625 iou_threshold: f32,
626 nms: Option<Nms>,
627 output_boxes: &mut Vec<DetectBox>,
628) where
629 f32: AsPrimitive<T>,
630{
631 let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
632 let boxes =
633 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
634 let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
635 let len = output_boxes.capacity().min(boxes.len());
636 output_boxes.clear();
637 for b in boxes.into_iter().take(len) {
638 output_boxes.push(b);
639 }
640}
641
642pub fn impl_yolo_split_quant<
652 B: BBoxTypeTrait,
653 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
654 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
655>(
656 boxes: (ArrayView2<BOX>, Quantization),
657 scores: (ArrayView2<SCORE>, Quantization),
658 score_threshold: f32,
659 iou_threshold: f32,
660 nms: Option<Nms>,
661 output_boxes: &mut Vec<DetectBox>,
662) where
663 f32: AsPrimitive<SCORE>,
664{
665 let (boxes_tensor, quant_boxes) = boxes;
666 let (scores_tensor, quant_scores) = scores;
667
668 let boxes_tensor = boxes_tensor.reversed_axes();
669 let scores_tensor = scores_tensor.reversed_axes();
670
671 let boxes = {
672 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
673 postprocess_boxes_quant::<B, _, _>(
674 score_threshold,
675 boxes_tensor,
676 scores_tensor,
677 quant_boxes,
678 )
679 };
680
681 let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
682 let len = output_boxes.capacity().min(boxes.len());
683 output_boxes.clear();
684 for b in boxes.iter().take(len) {
685 output_boxes.push(dequant_detect_box(b, quant_scores));
686 }
687}
688
689pub fn impl_yolo_split_float<
698 B: BBoxTypeTrait,
699 BOX: Float + AsPrimitive<f32> + Send + Sync,
700 SCORE: Float + AsPrimitive<f32> + Send + Sync,
701>(
702 boxes_tensor: ArrayView2<BOX>,
703 scores_tensor: ArrayView2<SCORE>,
704 score_threshold: f32,
705 iou_threshold: f32,
706 nms: Option<Nms>,
707 output_boxes: &mut Vec<DetectBox>,
708) where
709 f32: AsPrimitive<SCORE>,
710{
711 let boxes_tensor = boxes_tensor.reversed_axes();
712 let scores_tensor = scores_tensor.reversed_axes();
713 let boxes =
714 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
715 let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
716 let len = output_boxes.capacity().min(boxes.len());
717 output_boxes.clear();
718 for b in boxes.into_iter().take(len) {
719 output_boxes.push(b);
720 }
721}
722
723pub fn impl_yolo_segdet_quant<
733 B: BBoxTypeTrait,
734 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
735 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
736>(
737 boxes: (ArrayView2<BOX>, Quantization),
738 protos: (ArrayView3<PROTO>, Quantization),
739 score_threshold: f32,
740 iou_threshold: f32,
741 nms: Option<Nms>,
742 output_boxes: &mut Vec<DetectBox>,
743 output_masks: &mut Vec<Segmentation>,
744) -> Result<(), crate::DecoderError>
745where
746 f32: AsPrimitive<BOX>,
747{
748 let (boxes, quant_boxes) = boxes;
749 let num_protos = protos.0.dim().2;
750 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
751
752 let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
753 (boxes_tensor.reversed_axes(), quant_boxes),
754 (scores_tensor.reversed_axes(), quant_boxes),
755 score_threshold,
756 iou_threshold,
757 nms,
758 output_boxes.capacity(),
759 );
760
761 impl_yolo_split_segdet_quant_process_masks::<_, _>(
762 boxes,
763 (mask_tensor.reversed_axes(), quant_boxes),
764 protos,
765 output_boxes,
766 output_masks,
767 )
768}
769
770pub fn impl_yolo_segdet_float<
780 B: BBoxTypeTrait,
781 BOX: Float + AsPrimitive<f32> + Send + Sync,
782 PROTO: Float + AsPrimitive<f32> + Send + Sync,
783>(
784 boxes: ArrayView2<BOX>,
785 protos: ArrayView3<PROTO>,
786 score_threshold: f32,
787 iou_threshold: f32,
788 nms: Option<Nms>,
789 output_boxes: &mut Vec<DetectBox>,
790 output_masks: &mut Vec<Segmentation>,
791) -> Result<(), crate::DecoderError>
792where
793 f32: AsPrimitive<BOX>,
794{
795 let num_protos = protos.dim().2;
796 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
797
798 let boxes = postprocess_boxes_index_float::<B, _, _>(
799 score_threshold.as_(),
800 boxes_tensor,
801 scores_tensor,
802 );
803 let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
804 boxes.truncate(output_boxes.capacity());
805 let boxes = decode_segdet_f32(boxes, mask_tensor, protos)?;
806 output_boxes.clear();
807 output_masks.clear();
808 for (b, m) in boxes.into_iter() {
809 output_boxes.push(b);
810 output_masks.push(Segmentation {
811 xmin: b.bbox.xmin,
812 ymin: b.bbox.ymin,
813 xmax: b.bbox.xmax,
814 ymax: b.bbox.ymax,
815 segmentation: m,
816 });
817 }
818 Ok(())
819}
820
821pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
822 B: BBoxTypeTrait,
823 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
824 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
825>(
826 boxes: (ArrayView2<BOX>, Quantization),
827 scores: (ArrayView2<SCORE>, Quantization),
828 score_threshold: f32,
829 iou_threshold: f32,
830 nms: Option<Nms>,
831 max_boxes: usize,
832) -> Vec<(DetectBox, usize)>
833where
834 f32: AsPrimitive<SCORE>,
835{
836 let (boxes_tensor, quant_boxes) = boxes;
837 let (scores_tensor, quant_scores) = scores;
838
839 let boxes_tensor = boxes_tensor.reversed_axes();
840 let scores_tensor = scores_tensor.reversed_axes();
841
842 let boxes = {
843 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
844 postprocess_boxes_index_quant::<B, _, _>(
845 score_threshold,
846 boxes_tensor,
847 scores_tensor,
848 quant_boxes,
849 )
850 };
851 let mut boxes = dispatch_nms_extra_int(nms, iou_threshold, boxes);
852 boxes.truncate(max_boxes);
853 boxes
854 .into_iter()
855 .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
856 .collect()
857}
858
859pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
860 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
861 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
862>(
863 boxes: Vec<(DetectBox, usize)>,
864 mask_coeff: (ArrayView2<MASK>, Quantization),
865 protos: (ArrayView3<PROTO>, Quantization),
866 output_boxes: &mut Vec<DetectBox>,
867 output_masks: &mut Vec<Segmentation>,
868) -> Result<(), crate::DecoderError> {
869 let (masks, quant_masks) = mask_coeff;
870 let (protos, quant_protos) = protos;
871
872 let masks = masks.reversed_axes();
873
874 let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
875 output_boxes.clear();
876 output_masks.clear();
877 for (b, m) in boxes.into_iter() {
878 output_boxes.push(b);
879 output_masks.push(Segmentation {
880 xmin: b.bbox.xmin,
881 ymin: b.bbox.ymin,
882 xmax: b.bbox.xmax,
883 ymax: b.bbox.ymax,
884 segmentation: m,
885 });
886 }
887 Ok(())
888}
889
890#[allow(clippy::too_many_arguments)]
891pub fn impl_yolo_split_segdet_quant<
903 B: BBoxTypeTrait,
904 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
905 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
906 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
907 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
908>(
909 boxes: (ArrayView2<BOX>, Quantization),
910 scores: (ArrayView2<SCORE>, Quantization),
911 mask_coeff: (ArrayView2<MASK>, Quantization),
912 protos: (ArrayView3<PROTO>, Quantization),
913 score_threshold: f32,
914 iou_threshold: f32,
915 nms: Option<Nms>,
916 output_boxes: &mut Vec<DetectBox>,
917 output_masks: &mut Vec<Segmentation>,
918) -> Result<(), crate::DecoderError>
919where
920 f32: AsPrimitive<SCORE>,
921{
922 let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
923 boxes,
924 scores,
925 score_threshold,
926 iou_threshold,
927 nms,
928 output_boxes.capacity(),
929 );
930
931 impl_yolo_split_segdet_quant_process_masks(
932 boxes,
933 mask_coeff,
934 protos,
935 output_boxes,
936 output_masks,
937 )
938}
939
940#[allow(clippy::too_many_arguments)]
941pub fn impl_yolo_split_segdet_float<
953 B: BBoxTypeTrait,
954 BOX: Float + AsPrimitive<f32> + Send + Sync,
955 SCORE: Float + AsPrimitive<f32> + Send + Sync,
956 MASK: Float + AsPrimitive<f32> + Send + Sync,
957 PROTO: Float + AsPrimitive<f32> + Send + Sync,
958>(
959 boxes_tensor: ArrayView2<BOX>,
960 scores_tensor: ArrayView2<SCORE>,
961 mask_tensor: ArrayView2<MASK>,
962 protos: ArrayView3<PROTO>,
963 score_threshold: f32,
964 iou_threshold: f32,
965 nms: Option<Nms>,
966 output_boxes: &mut Vec<DetectBox>,
967 output_masks: &mut Vec<Segmentation>,
968) -> Result<(), crate::DecoderError>
969where
970 f32: AsPrimitive<SCORE>,
971{
972 let boxes_tensor = boxes_tensor.reversed_axes();
973 let scores_tensor = scores_tensor.reversed_axes();
974 let mask_tensor = mask_tensor.reversed_axes();
975
976 let boxes = postprocess_boxes_index_float::<B, _, _>(
977 score_threshold.as_(),
978 boxes_tensor,
979 scores_tensor,
980 );
981 let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
982 boxes.truncate(output_boxes.capacity());
983 let boxes = decode_segdet_f32(boxes, mask_tensor, protos)?;
984 output_boxes.clear();
985 output_masks.clear();
986 for (b, m) in boxes.into_iter() {
987 output_boxes.push(b);
988 output_masks.push(Segmentation {
989 xmin: b.bbox.xmin,
990 ymin: b.bbox.ymin,
991 xmax: b.bbox.xmax,
992 ymax: b.bbox.ymax,
993 segmentation: m,
994 });
995 }
996 Ok(())
997}
998
999pub fn impl_yolo_segdet_quant_proto<
1006 B: BBoxTypeTrait,
1007 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1008 PROTO: PrimInt
1009 + AsPrimitive<i64>
1010 + AsPrimitive<i128>
1011 + AsPrimitive<f32>
1012 + AsPrimitive<i8>
1013 + Send
1014 + Sync,
1015>(
1016 boxes: (ArrayView2<BOX>, Quantization),
1017 protos: (ArrayView3<PROTO>, Quantization),
1018 score_threshold: f32,
1019 iou_threshold: f32,
1020 nms: Option<Nms>,
1021 output_boxes: &mut Vec<DetectBox>,
1022) -> ProtoData
1023where
1024 f32: AsPrimitive<BOX>,
1025{
1026 let (boxes_arr, quant_boxes) = boxes;
1027 let (protos_arr, quant_protos) = protos;
1028 let num_protos = protos_arr.dim().2;
1029
1030 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1031
1032 let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1033 (boxes_tensor.reversed_axes(), quant_boxes),
1034 (scores_tensor.reversed_axes(), quant_boxes),
1035 score_threshold,
1036 iou_threshold,
1037 nms,
1038 output_boxes.capacity(),
1039 );
1040
1041 extract_proto_data_quant(
1042 det_indices,
1043 mask_tensor,
1044 quant_boxes,
1045 protos_arr,
1046 quant_protos,
1047 output_boxes,
1048 )
1049}
1050
1051pub fn impl_yolo_segdet_float_proto<
1054 B: BBoxTypeTrait,
1055 BOX: Float + AsPrimitive<f32> + Send + Sync,
1056 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1057>(
1058 boxes: ArrayView2<BOX>,
1059 protos: ArrayView3<PROTO>,
1060 score_threshold: f32,
1061 iou_threshold: f32,
1062 nms: Option<Nms>,
1063 output_boxes: &mut Vec<DetectBox>,
1064) -> ProtoData
1065where
1066 f32: AsPrimitive<BOX>,
1067{
1068 let num_protos = protos.dim().2;
1069 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1070
1071 let det_indices = postprocess_boxes_index_float::<B, _, _>(
1072 score_threshold.as_(),
1073 boxes_tensor,
1074 scores_tensor,
1075 );
1076 let mut det_indices = dispatch_nms_extra_float(nms, iou_threshold, det_indices);
1077 det_indices.truncate(output_boxes.capacity());
1078
1079 extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1080}
1081
1082#[allow(clippy::too_many_arguments)]
1085pub fn impl_yolo_split_segdet_quant_proto<
1086 B: BBoxTypeTrait,
1087 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1088 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1089 MASK: PrimInt + AsPrimitive<f32> + Send + Sync,
1090 PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync,
1091>(
1092 boxes: (ArrayView2<BOX>, Quantization),
1093 scores: (ArrayView2<SCORE>, Quantization),
1094 mask_coeff: (ArrayView2<MASK>, Quantization),
1095 protos: (ArrayView3<PROTO>, Quantization),
1096 score_threshold: f32,
1097 iou_threshold: f32,
1098 nms: Option<Nms>,
1099 output_boxes: &mut Vec<DetectBox>,
1100) -> ProtoData
1101where
1102 f32: AsPrimitive<SCORE>,
1103{
1104 let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1105 boxes,
1106 scores,
1107 score_threshold,
1108 iou_threshold,
1109 nms,
1110 output_boxes.capacity(),
1111 );
1112
1113 let (masks, quant_masks) = mask_coeff;
1114 let masks = masks.reversed_axes();
1115 let (protos_arr, quant_protos) = protos;
1116
1117 extract_proto_data_quant(
1118 det_indices,
1119 masks,
1120 quant_masks,
1121 protos_arr,
1122 quant_protos,
1123 output_boxes,
1124 )
1125}
1126
1127#[allow(clippy::too_many_arguments)]
1130pub fn impl_yolo_split_segdet_float_proto<
1131 B: BBoxTypeTrait,
1132 BOX: Float + AsPrimitive<f32> + Send + Sync,
1133 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1134 MASK: Float + AsPrimitive<f32> + Send + Sync,
1135 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1136>(
1137 boxes_tensor: ArrayView2<BOX>,
1138 scores_tensor: ArrayView2<SCORE>,
1139 mask_tensor: ArrayView2<MASK>,
1140 protos: ArrayView3<PROTO>,
1141 score_threshold: f32,
1142 iou_threshold: f32,
1143 nms: Option<Nms>,
1144 output_boxes: &mut Vec<DetectBox>,
1145) -> ProtoData
1146where
1147 f32: AsPrimitive<SCORE>,
1148{
1149 let boxes_tensor = boxes_tensor.reversed_axes();
1150 let scores_tensor = scores_tensor.reversed_axes();
1151 let mask_tensor = mask_tensor.reversed_axes();
1152
1153 let det_indices = postprocess_boxes_index_float::<B, _, _>(
1154 score_threshold.as_(),
1155 boxes_tensor,
1156 scores_tensor,
1157 );
1158 let mut det_indices = dispatch_nms_extra_float(nms, iou_threshold, det_indices);
1159 det_indices.truncate(output_boxes.capacity());
1160
1161 extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1162}
1163
1164pub fn decode_yolo_end_to_end_segdet_float_proto<T>(
1166 output: ArrayView2<T>,
1167 protos: ArrayView3<T>,
1168 score_threshold: f32,
1169 output_boxes: &mut Vec<DetectBox>,
1170) -> Result<ProtoData, crate::DecoderError>
1171where
1172 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1173 f32: AsPrimitive<T>,
1174{
1175 if output.shape()[0] < 7 {
1176 return Err(crate::DecoderError::InvalidShape(format!(
1177 "End-to-end segdet output requires at least 7 rows, got {}",
1178 output.shape()[0]
1179 )));
1180 }
1181
1182 let num_mask_coeffs = output.shape()[0] - 6;
1183 let num_protos = protos.shape()[2];
1184 if num_mask_coeffs != num_protos {
1185 return Err(crate::DecoderError::InvalidShape(format!(
1186 "Mask coefficients count ({}) doesn't match protos count ({})",
1187 num_mask_coeffs, num_protos
1188 )));
1189 }
1190
1191 let boxes = output.slice(s![0..4, ..]).reversed_axes();
1192 let scores = output.slice(s![4..5, ..]).reversed_axes();
1193 let classes = output.slice(s![5, ..]);
1194 let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
1195 let mut det_indices =
1196 postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
1197 det_indices.truncate(output_boxes.capacity());
1198
1199 for (b, ind) in &mut det_indices {
1200 b.label = classes[*ind].as_() as usize;
1201 }
1202
1203 Ok(extract_proto_data_float(
1204 det_indices,
1205 mask_coeff,
1206 protos,
1207 output_boxes,
1208 ))
1209}
1210
1211#[allow(clippy::too_many_arguments)]
1213pub fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1214 boxes: ArrayView2<T>,
1215 scores: ArrayView2<T>,
1216 classes: ArrayView2<T>,
1217 mask_coeff: ArrayView2<T>,
1218 protos: ArrayView3<T>,
1219 score_threshold: f32,
1220 output_boxes: &mut Vec<DetectBox>,
1221) -> Result<ProtoData, crate::DecoderError>
1222where
1223 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1224 f32: AsPrimitive<T>,
1225{
1226 let n = boxes.shape()[0];
1227 if boxes.shape()[1] != 4 {
1228 return Err(crate::DecoderError::InvalidShape(format!(
1229 "Split end-to-end boxes must have 4 columns, got {}",
1230 boxes.shape()[1]
1231 )));
1232 }
1233
1234 let mut qualifying: Vec<(DetectBox, usize)> = Vec::with_capacity(output_boxes.capacity());
1235 for i in 0..n {
1236 let score: f32 = scores[[i, 0]].as_();
1237 if score < score_threshold {
1238 continue;
1239 }
1240 if qualifying.len() >= output_boxes.capacity() {
1241 break;
1242 }
1243 qualifying.push((
1244 DetectBox {
1245 bbox: BoundingBox {
1246 xmin: boxes[[i, 0]].as_(),
1247 ymin: boxes[[i, 1]].as_(),
1248 xmax: boxes[[i, 2]].as_(),
1249 ymax: boxes[[i, 3]].as_(),
1250 },
1251 score,
1252 label: classes[[i, 0]].as_() as usize,
1253 },
1254 i,
1255 ));
1256 }
1257
1258 Ok(extract_proto_data_float(
1259 qualifying,
1260 mask_coeff,
1261 protos,
1262 output_boxes,
1263 ))
1264}
1265
1266fn extract_proto_data_float<
1268 MASK: Float + AsPrimitive<f32> + Send + Sync,
1269 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1270>(
1271 det_indices: Vec<(DetectBox, usize)>,
1272 mask_tensor: ArrayView2<MASK>,
1273 protos: ArrayView3<PROTO>,
1274 output_boxes: &mut Vec<DetectBox>,
1275) -> ProtoData {
1276 let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1277 output_boxes.clear();
1278 for (det, idx) in det_indices {
1279 output_boxes.push(det);
1280 let row = mask_tensor.row(idx);
1281 mask_coefficients.push(row.iter().map(|v| v.as_()).collect());
1282 }
1283 let protos_f32 = protos.map(|v| v.as_());
1284 ProtoData {
1285 mask_coefficients,
1286 protos: ProtoTensor::Float(protos_f32),
1287 }
1288}
1289
1290pub(crate) fn extract_proto_data_quant<
1296 MASK: PrimInt + AsPrimitive<f32> + Send + Sync,
1297 PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync,
1298>(
1299 det_indices: Vec<(DetectBox, usize)>,
1300 mask_tensor: ArrayView2<MASK>,
1301 quant_masks: Quantization,
1302 protos: ArrayView3<PROTO>,
1303 quant_protos: Quantization,
1304 output_boxes: &mut Vec<DetectBox>,
1305) -> ProtoData {
1306 let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1307 output_boxes.clear();
1308 for (det, idx) in det_indices {
1309 output_boxes.push(det);
1310 let row = mask_tensor.row(idx);
1311 mask_coefficients.push(
1312 row.iter()
1313 .map(|v| (v.as_() - quant_masks.zero_point as f32) * quant_masks.scale)
1314 .collect(),
1315 );
1316 }
1317 let protos_i8 = protos.map(|v| {
1319 let v_i8: i8 = v.as_();
1320 v_i8
1321 });
1322 ProtoData {
1323 mask_coefficients,
1324 protos: ProtoTensor::Quantized {
1325 protos: protos_i8,
1326 quantization: quant_protos,
1327 },
1328 }
1329}
1330
1331fn postprocess_yolo<'a, T>(
1332 output: &'a ArrayView2<'_, T>,
1333) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1334 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1335 let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1336 (boxes_tensor, scores_tensor)
1337}
1338
1339fn postprocess_yolo_seg<'a, T>(
1340 output: &'a ArrayView2<'_, T>,
1341 num_protos: usize,
1342) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1343 assert!(
1344 output.shape()[0] > num_protos + 4,
1345 "Output shape is too short: {} <= {} + 4",
1346 output.shape()[0],
1347 num_protos
1348 );
1349 let num_classes = output.shape()[0] - 4 - num_protos;
1350 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1351 let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1352 let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1353 (boxes_tensor, scores_tensor, mask_tensor)
1354}
1355
1356fn decode_segdet_f32<
1357 MASK: Float + AsPrimitive<f32> + Send + Sync,
1358 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1359>(
1360 boxes: Vec<(DetectBox, usize)>,
1361 masks: ArrayView2<MASK>,
1362 protos: ArrayView3<PROTO>,
1363) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1364 if boxes.is_empty() {
1365 return Ok(Vec::new());
1366 }
1367 if masks.shape()[1] != protos.shape()[2] {
1368 return Err(crate::DecoderError::InvalidShape(format!(
1369 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1370 masks.shape()[1],
1371 protos.shape()[2],
1372 )));
1373 }
1374 boxes
1375 .into_par_iter()
1376 .map(|mut b| {
1377 let ind = b.1;
1378 let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1379 b.0.bbox = roi;
1380 Ok((b.0, make_segmentation(masks.row(ind), protos.view())))
1381 })
1382 .collect()
1383}
1384
1385pub(crate) fn decode_segdet_quant<
1386 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1387 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1388>(
1389 boxes: Vec<(DetectBox, usize)>,
1390 masks: ArrayView2<MASK>,
1391 protos: ArrayView3<PROTO>,
1392 quant_masks: Quantization,
1393 quant_protos: Quantization,
1394) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1395 if boxes.is_empty() {
1396 return Ok(Vec::new());
1397 }
1398 if masks.shape()[1] != protos.shape()[2] {
1399 return Err(crate::DecoderError::InvalidShape(format!(
1400 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1401 masks.shape()[1],
1402 protos.shape()[2],
1403 )));
1404 }
1405
1406 let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; boxes
1408 .into_iter()
1409 .map(|mut b| {
1410 let i = b.1;
1411 let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1412 b.0.bbox = roi;
1413 let seg = match total_bits {
1414 0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1415 masks.row(i),
1416 protos.view(),
1417 quant_masks,
1418 quant_protos,
1419 ),
1420 65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1421 masks.row(i),
1422 protos.view(),
1423 quant_masks,
1424 quant_protos,
1425 ),
1426 _ => {
1427 return Err(crate::DecoderError::NotSupported(format!(
1428 "Unsupported bit width ({total_bits}) for segmentation computation"
1429 )));
1430 }
1431 };
1432 Ok((b.0, seg))
1433 })
1434 .collect()
1435}
1436
1437fn protobox<'a, T>(
1438 protos: &'a ArrayView3<T>,
1439 roi: &BoundingBox,
1440) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1441 let width = protos.dim().1 as f32;
1442 let height = protos.dim().0 as f32;
1443
1444 const NORM_LIMIT: f32 = 2.0;
1455 if roi.xmin > NORM_LIMIT
1456 || roi.ymin > NORM_LIMIT
1457 || roi.xmax > NORM_LIMIT
1458 || roi.ymax > NORM_LIMIT
1459 {
1460 return Err(crate::DecoderError::InvalidShape(format!(
1461 "Bounding box coordinates appear un-normalized (pixel-space). \
1462 Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
1463 ONNX models output pixel-space boxes — normalize them by dividing by \
1464 the input dimensions before calling decode().",
1465 roi.xmin, roi.ymin, roi.xmax, roi.ymax,
1466 )));
1467 }
1468
1469 let roi = [
1470 (roi.xmin * width).clamp(0.0, width) as usize,
1471 (roi.ymin * height).clamp(0.0, height) as usize,
1472 (roi.xmax * width).clamp(0.0, width).ceil() as usize,
1473 (roi.ymax * height).clamp(0.0, height).ceil() as usize,
1474 ];
1475
1476 let roi_norm = [
1477 roi[0] as f32 / width,
1478 roi[1] as f32 / height,
1479 roi[2] as f32 / width,
1480 roi[3] as f32 / height,
1481 ]
1482 .into();
1483
1484 let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
1485
1486 Ok((cropped, roi_norm))
1487}
1488
1489fn make_segmentation<
1495 MASK: Float + AsPrimitive<f32> + Send + Sync,
1496 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1497>(
1498 mask: ArrayView1<MASK>,
1499 protos: ArrayView3<PROTO>,
1500) -> Array3<u8> {
1501 let shape = protos.shape();
1502
1503 let mask = mask.to_shape((1, mask.len())).unwrap();
1505 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1506 let protos = protos.reversed_axes();
1507 let mask = mask.map(|x| x.as_());
1508 let protos = protos.map(|x| x.as_());
1509
1510 let mask = mask
1512 .dot(&protos)
1513 .into_shape_with_order((shape[0], shape[1], 1))
1514 .unwrap();
1515
1516 mask.map(|x| {
1517 let sigmoid = 1.0 / (1.0 + (-*x).exp());
1518 (sigmoid * 255.0).round() as u8
1519 })
1520}
1521
1522fn make_segmentation_quant<
1529 MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
1530 PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
1531 DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
1532>(
1533 mask: ArrayView1<MASK>,
1534 protos: ArrayView3<PROTO>,
1535 quant_masks: Quantization,
1536 quant_protos: Quantization,
1537) -> Array3<u8>
1538where
1539 i32: AsPrimitive<DEST>,
1540 f32: AsPrimitive<DEST>,
1541{
1542 let shape = protos.shape();
1543
1544 let mask = mask.to_shape((1, mask.len())).unwrap();
1546
1547 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1548 let protos = protos.reversed_axes();
1549
1550 let zp = quant_masks.zero_point.as_();
1551
1552 let mask = mask.mapv(|x| x.as_() - zp);
1553
1554 let zp = quant_protos.zero_point.as_();
1555 let protos = protos.mapv(|x| x.as_() - zp);
1556
1557 let segmentation = mask
1559 .dot(&protos)
1560 .into_shape_with_order((shape[0], shape[1], 1))
1561 .unwrap();
1562
1563 let combined_scale = quant_masks.scale * quant_protos.scale;
1564 segmentation.map(|x| {
1565 let val: f32 = (*x).as_() * combined_scale;
1566 let sigmoid = 1.0 / (1.0 + (-val).exp());
1567 (sigmoid * 255.0).round() as u8
1568 })
1569}
1570
1571pub fn yolo_segmentation_to_mask(
1583 segmentation: ArrayView3<u8>,
1584 threshold: u8,
1585) -> Result<Array2<u8>, crate::DecoderError> {
1586 if segmentation.shape()[2] != 1 {
1587 return Err(crate::DecoderError::InvalidShape(format!(
1588 "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
1589 segmentation.shape()[2]
1590 )));
1591 }
1592 Ok(segmentation
1593 .slice(s![.., .., 0])
1594 .map(|x| if *x >= threshold { 1 } else { 0 }))
1595}
1596
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    //! Unit tests for the end-to-end detection / segmentation decoders, the
    //! segmentation-to-mask helper, `protobox`, and the proto-extraction path.

    use super::*;
    use ndarray::Array2;

    /// Only the candidate whose score exceeds the threshold is returned, and
    /// its box, score and label are read from the expected rows.
    #[test]
    fn test_end_to_end_det_basic_filtering() {
        // Row-major (6, 3) matrix: one column per candidate. The assertions
        // below read the rows, in order, as xmin, ymin, xmax, ymax, score and
        // label.
        let data: Vec<f32> = vec![
            0.1, 0.2, 0.3, 0.1, 0.2, 0.3, 0.5, 0.6, 0.7, 0.5, 0.6, 0.7, 0.9, 0.1, 0.2, 0.0, 1.0, 2.0,
        ];
        let output = Array2::from_shape_vec((6, 3), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        // Scores are 0.9, 0.1, 0.2: only the first candidate passes 0.5.
        assert_eq!(boxes.len(), 1);
        assert_eq!(boxes[0].label, 0);
        assert!((boxes[0].score - 0.9).abs() < 0.01);
        assert!((boxes[0].bbox.xmin - 0.1).abs() < 0.01);
        assert!((boxes[0].bbox.ymin - 0.1).abs() < 0.01);
        assert!((boxes[0].bbox.xmax - 0.5).abs() < 0.01);
        assert!((boxes[0].bbox.ymax - 0.5).abs() < 0.01);
    }

    /// Every candidate above the threshold is returned, in input order.
    #[test]
    fn test_end_to_end_det_all_pass_threshold() {
        // (6, 2) matrix; scores (row 4) are 0.8 and 0.7, labels (row 5) are
        // 1 and 2.
        let data: Vec<f32> = vec![
            10.0, 20.0, 10.0, 20.0, 50.0, 60.0, 50.0, 60.0, 0.8, 0.7, 1.0, 2.0,
        ];
        let output = Array2::from_shape_vec((6, 2), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 2);
        assert_eq!(boxes[0].label, 1);
        assert_eq!(boxes[1].label, 2);
    }

    /// No candidate is returned when all scores are below the threshold.
    #[test]
    fn test_end_to_end_det_none_pass_threshold() {
        // Same layout as above, but scores (row 4) are 0.1 and 0.2.
        let data: Vec<f32> = vec![
            10.0, 20.0, 10.0, 20.0, 50.0, 60.0, 50.0, 60.0, 0.1, 0.2, 1.0, 2.0,
        ];
        let output = Array2::from_shape_vec((6, 2), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 0);
    }

    /// The number of returned detections is limited by the capacity of the
    /// output vector.
    #[test]
    fn test_end_to_end_det_capacity_limit() {
        // (6, 5) matrix: five candidates, all with score 0.9 (row 4).
        let data: Vec<f32> = vec![
            0.1, 0.2, 0.3, 0.4, 0.5, 0.1, 0.2, 0.3, 0.4, 0.5, 0.5, 0.6, 0.7, 0.8, 0.9, 0.5, 0.6, 0.7,
            0.8, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.0, 1.0, 2.0, 3.0, 4.0,
        ];
        let output = Array2::from_shape_vec((6, 5), data).unwrap();

        // Capacity 2 even though five candidates pass the threshold.
        let mut boxes = Vec::with_capacity(2);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 2);
    }

    /// An output with zero candidates decodes to an empty result.
    #[test]
    fn test_end_to_end_det_empty_output() {
        let output = Array2::<f32>::zeros((6, 0));

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 0);
    }

    /// Box coordinates far outside [0, 1] are passed through unchanged by the
    /// detection-only decoder.
    #[test]
    fn test_end_to_end_det_pixel_coordinates() {
        // Single candidate: box (100, 200, 300, 400), score 0.95, label 5.
        let data: Vec<f32> = vec![
            100.0, 200.0, 300.0, 400.0, 0.95, 5.0,
        ];
        let output = Array2::from_shape_vec((6, 1), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 1);
        assert_eq!(boxes[0].label, 5);
        assert!((boxes[0].bbox.xmin - 100.0).abs() < 0.01);
        assert!((boxes[0].bbox.ymin - 200.0).abs() < 0.01);
        assert!((boxes[0].bbox.xmax - 300.0).abs() < 0.01);
        assert!((boxes[0].bbox.ymax - 400.0).abs() < 0.01);
    }

    /// An output with fewer than 6 rows is rejected with `InvalidShape`.
    #[test]
    fn test_end_to_end_det_invalid_shape() {
        let output = Array2::<f32>::zeros((5, 3));

        let mut boxes = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
        ));
    }

    /// The segmentation decoder filters by score and returns one mask per
    /// surviving detection.
    #[test]
    fn test_end_to_end_segdet_basic() {
        let num_protos = 32;
        let num_detections = 2;
        // Six leading rows (box, score, label as asserted below) followed by
        // `num_protos` further rows.
        let num_features = 6 + num_protos;

        // Flat row-major buffer: element (row r, detection d) lives at
        // `r * num_detections + d`.
        let mut data = vec![0.0f32; num_features * num_detections];
        // Rows 0-3: box coordinates for detection 0 and detection 1.
        data[0] = 0.1;
        data[1] = 0.5;
        data[num_detections] = 0.1;
        data[num_detections + 1] = 0.5;
        data[2 * num_detections] = 0.4;
        data[2 * num_detections + 1] = 0.9;
        data[3 * num_detections] = 0.4;
        data[3 * num_detections + 1] = 0.9;
        // Row 4: scores; only detection 0 exceeds the 0.5 threshold.
        data[4 * num_detections] = 0.9;
        data[4 * num_detections + 1] = 0.3;
        // Row 5: labels.
        data[5 * num_detections] = 1.0;
        data[5 * num_detections + 1] = 2.0;
        // Remaining rows (presumably the mask coefficients) are all 0.1.
        for i in 6..num_features {
            data[i * num_detections] = 0.1;
            data[i * num_detections + 1] = 0.1;
        }

        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();

        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 1);
        assert_eq!(masks.len(), 1);
        assert_eq!(boxes[0].label, 1);
        assert!((boxes[0].score - 0.9).abs() < 0.01);
    }

    /// The bounds stored on a returned mask agree with the bounds of the
    /// corresponding detection box.
    #[test]
    fn test_end_to_end_segdet_mask_coordinates() {
        let num_protos = 32;
        let num_features = 6 + num_protos;

        // Single detection, so row r is simply element r.
        let mut data = vec![0.0f32; num_features];
        data[0] = 0.2;
        data[1] = 0.2;
        data[2] = 0.8;
        data[3] = 0.8;
        data[4] = 0.95;
        data[5] = 3.0;
        let output = Array2::from_shape_vec((num_features, 1), data).unwrap();
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 1);
        assert_eq!(masks.len(), 1);

        assert!((masks[0].xmin - boxes[0].bbox.xmin).abs() < 0.01);
        assert!((masks[0].ymin - boxes[0].bbox.ymin).abs() < 0.01);
        assert!((masks[0].xmax - boxes[0].bbox.xmax).abs() < 0.01);
        assert!((masks[0].ymax - boxes[0].bbox.ymax).abs() < 0.01);
    }

    /// An output with zero candidates yields no boxes and no masks.
    #[test]
    fn test_end_to_end_segdet_empty_output() {
        let num_protos = 32;
        let output = Array2::<f32>::zeros((6 + num_protos, 0));
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 0);
        assert_eq!(masks.len(), 0);
    }

    /// The number of returned boxes and masks is limited by the capacity of
    /// the output vectors.
    #[test]
    fn test_end_to_end_segdet_capacity_limit() {
        let num_protos = 32;
        let num_detections = 5;
        let num_features = 6 + num_protos;

        let mut data = vec![0.0f32; num_features * num_detections];
        for i in 0..num_detections {
            // Rows 0-3: a 0.2-wide box whose origin shifts by 0.1 per
            // detection.
            data[i] = 0.1 * (i as f32);
            data[num_detections + i] = 0.1 * (i as f32);
            data[2 * num_detections + i] = 0.1 * (i as f32) + 0.2;
            data[3 * num_detections + i] = 0.1 * (i as f32) + 0.2;
            // Row 4: every detection passes the threshold.
            data[4 * num_detections + i] = 0.9;
            // Row 5: label equals the detection index.
            data[5 * num_detections + i] = i as f32;
        }

        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        // Capacity 2 even though five candidates pass the threshold.
        let mut boxes = Vec::with_capacity(2);
        let mut masks = Vec::with_capacity(2);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 2);
        assert_eq!(masks.len(), 2);
    }

    /// An output with only 6 rows (no room for mask coefficients) is rejected
    /// with `InvalidShape`.
    #[test]
    fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
        let output = Array2::<f32>::zeros((6, 3));
        let protos = Array3::<f32>::zeros((16, 16, 32));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        );

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
        ));
    }

    /// A mismatch between the rows following the first six and the proto
    /// channel count is rejected with `InvalidShape`.
    #[test]
    fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
        let num_protos = 32;
        // 16 extra rows in the output versus 32 proto channels.
        let output = Array2::<f32>::zeros((6 + 16, 3));
        let protos = Array3::<f32>::zeros((16, 16, num_protos));
        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        );

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
        ));
    }

    /// Thresholding is inclusive: values equal to the threshold map to 1.
    #[test]
    fn test_segmentation_to_mask_basic() {
        // Row-major 4x4 grid:
        //   row 0: 100, 200,  50, 150
        //   row 1:  10, 255, 128,  64
        //   row 2:   0, 127, 128, 255
        //   row 3:  64,  64, 192, 192
        let data: Vec<u8> = vec![
            100, 200, 50, 150, 10, 255, 128, 64, 0, 127, 128, 255, 64, 64, 192, 192,
        ];
        let segmentation = Array3::from_shape_vec((4, 4, 1), data).unwrap();

        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();

        assert_eq!(mask[[0, 0]], 0); // 100 < 128
        assert_eq!(mask[[0, 1]], 1); // 200 >= 128
        assert_eq!(mask[[0, 2]], 0); // 50 < 128
        assert_eq!(mask[[0, 3]], 1); // 150 >= 128
        assert_eq!(mask[[1, 1]], 1); // 255 >= 128
        assert_eq!(mask[[1, 2]], 1); // 128 == threshold
        assert_eq!(mask[[2, 0]], 0); // 0 < 128
        assert_eq!(mask[[2, 1]], 0); // 127 is just below the threshold
    }

    /// All values above the threshold produce an all-ones mask.
    #[test]
    fn test_segmentation_to_mask_all_above() {
        let segmentation = Array3::from_elem((4, 4, 1), 255u8);
        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
        assert!(mask.iter().all(|&x| x == 1));
    }

    /// All values below the threshold produce an all-zeros mask.
    #[test]
    fn test_segmentation_to_mask_all_below() {
        let segmentation = Array3::from_elem((4, 4, 1), 64u8);
        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
        assert!(mask.iter().all(|&x| x == 0));
    }

    /// A segmentation whose last axis is not 1 is rejected with
    /// `InvalidShape`.
    #[test]
    fn test_segmentation_to_mask_invalid_shape() {
        let segmentation = Array3::from_elem((4, 4, 3), 128u8);
        let result = yolo_segmentation_to_mask(segmentation.view(), 128);

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("(H, W, 1)")
        ));
    }

    /// A box touching the right/bottom edge (coordinate exactly 1.0) is
    /// accepted and yields a non-empty crop.
    #[test]
    fn test_protobox_clamps_edge_coordinates() {
        let protos = Array3::<f32>::zeros((16, 16, 4));
        let view = protos.view();
        let roi = BoundingBox {
            xmin: 0.5,
            ymin: 0.5,
            xmax: 1.0,
            ymax: 1.0,
        };
        let result = protobox(&view, &roi);
        assert!(result.is_ok(), "protobox should accept xmax=1.0");
        let (cropped, _roi_norm) = result.unwrap();
        assert!(cropped.shape()[0] > 0);
        assert!(cropped.shape()[1] > 0);
        // The channel axis is never cropped.
        assert_eq!(cropped.shape()[2], 4);
    }

    /// Coordinates above `NORM_LIMIT` are reported as un-normalized.
    #[test]
    fn test_protobox_rejects_wildly_out_of_range() {
        let protos = Array3::<f32>::zeros((16, 16, 4));
        let view = protos.view();
        let roi = BoundingBox {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 3.0,
            ymax: 3.0,
        };
        let result = protobox(&view, &roi);
        assert!(
            matches!(result, Err(crate::DecoderError::InvalidShape(s)) if s.contains("un-normalized")),
            "protobox should reject coords > NORM_LIMIT"
        );
    }

    /// Coordinates between 1.0 and `NORM_LIMIT` are accepted and clamped to
    /// the full proto grid.
    #[test]
    fn test_protobox_accepts_slightly_over_one() {
        let protos = Array3::<f32>::zeros((16, 16, 4));
        let view = protos.view();
        let roi = BoundingBox {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.5,
            ymax: 1.5,
        };
        let result = protobox(&view, &roi);
        assert!(
            result.is_ok(),
            "protobox should accept coords <= NORM_LIMIT (2.0)"
        );
        let (cropped, _roi_norm) = result.unwrap();
        // 1.5 * 16 is clamped to 16, so the crop covers the whole grid.
        assert_eq!(cropped.shape()[0], 16);
        assert_eq!(cropped.shape()[1], 16);
    }

    /// The proto-extraction path completes on box values far outside [0, 1]
    /// and returns one coefficient vector per output box.
    #[test]
    fn test_segdet_float_proto_no_panic() {
        let num_proposals = 100;
        let num_classes = 80;
        let num_mask_coeffs = 32;
        // Four box rows, then class scores, then mask coefficients.
        let rows = 4 + num_classes + num_mask_coeffs;
        let mut data = vec![0.0f32; rows * num_proposals];
        for i in 0..num_proposals {
            // Flat index of (row r, proposal i) in the row-major buffer.
            let row = |r: usize| r * num_proposals + i;
            // Rows 0-3: identical box values for every proposal (decoded as
            // XYWH by the call below).
            data[row(0)] = 320.0;
            data[row(1)] = 320.0;
            data[row(2)] = 50.0;
            data[row(3)] = 50.0;
            // Row 4: first score row, above the 0.5 threshold.
            data[row(4)] = 0.9;
        }
        let boxes = ndarray::Array2::from_shape_vec((rows, num_proposals), data).unwrap();

        let protos = ndarray::Array3::<f32>::zeros((160, 160, num_mask_coeffs));

        let mut output_boxes = Vec::with_capacity(300);

        let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
            boxes.view(),
            protos.view(),
            0.5,
            0.7,
            Some(Nms::default()),
            &mut output_boxes,
        );

        assert!(!output_boxes.is_empty());
        // One coefficient vector per surviving box, each of full length.
        assert_eq!(proto_data.mask_coefficients.len(), output_boxes.len());
        for coeffs in &proto_data.mask_coefficients {
            assert_eq!(coeffs.len(), num_mask_coeffs);
        }
    }
}