1use std::fmt::Debug;
5
6use ndarray::{
7 parallel::prelude::{IntoParallelIterator, ParallelIterator},
8 s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
9};
10use num_traits::{AsPrimitive, Float, PrimInt, Signed};
11use rayon::slice::ParallelSliceMut;
12
13use crate::{
14 byte::{
15 nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
16 postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
17 },
18 configs::Nms,
19 dequant_detect_box,
20 float::{
21 nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
22 postprocess_boxes_float, postprocess_boxes_index_float,
23 },
24 BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoTensor,
25 Quantization, Segmentation, XYWH, XYXY,
26};
27
/// Upper bound on the number of candidate boxes kept (highest scores first)
/// before NMS runs; see `truncate_to_top_k_by_score` and
/// `truncate_to_top_k_by_score_quant`.
pub(crate) const MAX_NMS_CANDIDATES: usize = 30_000;
43
44fn truncate_to_top_k_by_score<E: Send>(boxes: &mut Vec<(DetectBox, E)>) {
49 if boxes.len() > MAX_NMS_CANDIDATES {
50 boxes.par_sort_unstable_by(|a, b| b.0.score.total_cmp(&a.0.score));
51 boxes.truncate(MAX_NMS_CANDIDATES);
52 }
53}
54
55fn truncate_to_top_k_by_score_quant<S: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send>(
59 boxes: &mut Vec<(DetectBoxQuantized<S>, E)>,
60) {
61 if boxes.len() > MAX_NMS_CANDIDATES {
62 boxes.par_sort_unstable_by(|a, b| b.0.score.cmp(&a.0.score));
63 boxes.truncate(MAX_NMS_CANDIDATES);
64 }
65}
66
67fn dispatch_nms_float(nms: Option<Nms>, iou: f32, boxes: Vec<DetectBox>) -> Vec<DetectBox> {
69 match nms {
70 Some(Nms::ClassAgnostic) => nms_float(iou, boxes),
71 Some(Nms::ClassAware) => nms_class_aware_float(iou, boxes),
72 None => boxes, }
74}
75
76pub(super) fn dispatch_nms_extra_float<E: Send + Sync>(
79 nms: Option<Nms>,
80 iou: f32,
81 boxes: Vec<(DetectBox, E)>,
82) -> Vec<(DetectBox, E)> {
83 match nms {
84 Some(Nms::ClassAgnostic) => nms_extra_float(iou, boxes),
85 Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, boxes),
86 None => boxes, }
88}
89
90fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
93 nms: Option<Nms>,
94 iou: f32,
95 boxes: Vec<DetectBoxQuantized<SCORE>>,
96) -> Vec<DetectBoxQuantized<SCORE>> {
97 match nms {
98 Some(Nms::ClassAgnostic) => nms_int(iou, boxes),
99 Some(Nms::ClassAware) => nms_class_aware_int(iou, boxes),
100 None => boxes, }
102}
103
104fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
107 nms: Option<Nms>,
108 iou: f32,
109 boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
110) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
111 match nms {
112 Some(Nms::ClassAgnostic) => nms_extra_int(iou, boxes),
113 Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, boxes),
114 None => boxes, }
116}
117
118pub fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
125 output: (ArrayView2<BOX>, Quantization),
126 score_threshold: f32,
127 iou_threshold: f32,
128 nms: Option<Nms>,
129 output_boxes: &mut Vec<DetectBox>,
130) where
131 f32: AsPrimitive<BOX>,
132{
133 impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
134}
135
136pub fn decode_yolo_det_float<T>(
143 output: ArrayView2<T>,
144 score_threshold: f32,
145 iou_threshold: f32,
146 nms: Option<Nms>,
147 output_boxes: &mut Vec<DetectBox>,
148) where
149 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
150 f32: AsPrimitive<T>,
151{
152 impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
153}
154
155pub fn decode_yolo_segdet_quant<
167 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
168 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
169>(
170 boxes: (ArrayView2<BOX>, Quantization),
171 protos: (ArrayView3<PROTO>, Quantization),
172 score_threshold: f32,
173 iou_threshold: f32,
174 nms: Option<Nms>,
175 output_boxes: &mut Vec<DetectBox>,
176 output_masks: &mut Vec<Segmentation>,
177) -> Result<(), crate::DecoderError>
178where
179 f32: AsPrimitive<BOX>,
180{
181 impl_yolo_segdet_quant::<XYWH, _, _>(
182 boxes,
183 protos,
184 score_threshold,
185 iou_threshold,
186 nms,
187 output_boxes,
188 output_masks,
189 )
190}
191
192pub fn decode_yolo_segdet_float<T>(
204 boxes: ArrayView2<T>,
205 protos: ArrayView3<T>,
206 score_threshold: f32,
207 iou_threshold: f32,
208 nms: Option<Nms>,
209 output_boxes: &mut Vec<DetectBox>,
210 output_masks: &mut Vec<Segmentation>,
211) -> Result<(), crate::DecoderError>
212where
213 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
214 f32: AsPrimitive<T>,
215{
216 impl_yolo_segdet_float::<XYWH, _, _>(
217 boxes,
218 protos,
219 score_threshold,
220 iou_threshold,
221 nms,
222 output_boxes,
223 output_masks,
224 )
225}
226
227pub fn decode_yolo_split_det_quant<
239 BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
240 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
241>(
242 boxes: (ArrayView2<BOX>, Quantization),
243 scores: (ArrayView2<SCORE>, Quantization),
244 score_threshold: f32,
245 iou_threshold: f32,
246 nms: Option<Nms>,
247 output_boxes: &mut Vec<DetectBox>,
248) where
249 f32: AsPrimitive<SCORE>,
250{
251 impl_yolo_split_quant::<XYWH, _, _>(
252 boxes,
253 scores,
254 score_threshold,
255 iou_threshold,
256 nms,
257 output_boxes,
258 );
259}
260
261pub fn decode_yolo_split_det_float<T>(
273 boxes: ArrayView2<T>,
274 scores: ArrayView2<T>,
275 score_threshold: f32,
276 iou_threshold: f32,
277 nms: Option<Nms>,
278 output_boxes: &mut Vec<DetectBox>,
279) where
280 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
281 f32: AsPrimitive<T>,
282{
283 impl_yolo_split_float::<XYWH, _, _>(
284 boxes,
285 scores,
286 score_threshold,
287 iou_threshold,
288 nms,
289 output_boxes,
290 );
291}
292
293#[allow(clippy::too_many_arguments)]
307pub fn decode_yolo_split_segdet<
308 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
309 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
310 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
311 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
312>(
313 boxes: (ArrayView2<BOX>, Quantization),
314 scores: (ArrayView2<SCORE>, Quantization),
315 mask_coeff: (ArrayView2<MASK>, Quantization),
316 protos: (ArrayView3<PROTO>, Quantization),
317 score_threshold: f32,
318 iou_threshold: f32,
319 nms: Option<Nms>,
320 output_boxes: &mut Vec<DetectBox>,
321 output_masks: &mut Vec<Segmentation>,
322) -> Result<(), crate::DecoderError>
323where
324 f32: AsPrimitive<SCORE>,
325{
326 impl_yolo_split_segdet_quant::<XYWH, _, _, _, _>(
327 boxes,
328 scores,
329 mask_coeff,
330 protos,
331 score_threshold,
332 iou_threshold,
333 nms,
334 output_boxes,
335 output_masks,
336 )
337}
338
339#[allow(clippy::too_many_arguments)]
353pub fn decode_yolo_split_segdet_float<T>(
354 boxes: ArrayView2<T>,
355 scores: ArrayView2<T>,
356 mask_coeff: ArrayView2<T>,
357 protos: ArrayView3<T>,
358 score_threshold: f32,
359 iou_threshold: f32,
360 nms: Option<Nms>,
361 output_boxes: &mut Vec<DetectBox>,
362 output_masks: &mut Vec<Segmentation>,
363) -> Result<(), crate::DecoderError>
364where
365 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
366 f32: AsPrimitive<T>,
367{
368 impl_yolo_split_segdet_float::<XYWH, _, _, _, _>(
369 boxes,
370 scores,
371 mask_coeff,
372 protos,
373 score_threshold,
374 iou_threshold,
375 nms,
376 output_boxes,
377 output_masks,
378 )
379}
380
381pub fn decode_yolo_end_to_end_det_float<T>(
396 output: ArrayView2<T>,
397 score_threshold: f32,
398 output_boxes: &mut Vec<DetectBox>,
399) -> Result<(), crate::DecoderError>
400where
401 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
402 f32: AsPrimitive<T>,
403{
404 if output.shape()[0] < 6 {
406 return Err(crate::DecoderError::InvalidShape(format!(
407 "End-to-end detection output requires at least 6 rows, got {}",
408 output.shape()[0]
409 )));
410 }
411
412 let boxes = output.slice(s![0..4, ..]).reversed_axes();
414 let scores = output.slice(s![4..5, ..]).reversed_axes();
415 let classes = output.slice(s![5, ..]);
416 let mut boxes =
417 postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
418 boxes.truncate(output_boxes.capacity());
419 output_boxes.clear();
420 for (mut b, i) in boxes.into_iter() {
421 b.label = classes[i].as_() as usize;
422 output_boxes.push(b);
423 }
424 Ok(())
426}
427
428pub fn decode_yolo_end_to_end_segdet_float<T>(
446 output: ArrayView2<T>,
447 protos: ArrayView3<T>,
448 score_threshold: f32,
449 output_boxes: &mut Vec<DetectBox>,
450 output_masks: &mut Vec<crate::Segmentation>,
451) -> Result<(), crate::DecoderError>
452where
453 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
454 f32: AsPrimitive<T>,
455{
456 let (boxes, scores, classes, mask_coeff) =
457 postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
458 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
459 boxes,
460 scores,
461 classes,
462 score_threshold,
463 output_boxes.capacity(),
464 );
465
466 impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
469}
470
471pub fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
480 boxes: ArrayView2<T>,
481 scores: ArrayView2<T>,
482 classes: ArrayView2<T>,
483 score_threshold: f32,
484 output_boxes: &mut Vec<DetectBox>,
485) -> Result<(), crate::DecoderError> {
486 let n = boxes.shape()[1];
487
488 output_boxes.clear();
489
490 let (boxes, scores, classes) = postprocess_yolo_split_end_to_end_det(boxes, scores, &classes)?;
491
492 for i in 0..n {
493 let score: f32 = scores[[i, 0]].as_();
494 if score < score_threshold {
495 continue;
496 }
497 if output_boxes.len() >= output_boxes.capacity() {
498 break;
499 }
500 output_boxes.push(DetectBox {
501 bbox: BoundingBox {
502 xmin: boxes[[i, 0]].as_(),
503 ymin: boxes[[i, 1]].as_(),
504 xmax: boxes[[i, 2]].as_(),
505 ymax: boxes[[i, 3]].as_(),
506 },
507 score,
508 label: classes[i].as_() as usize,
509 });
510 }
511 Ok(())
512}
513
514#[allow(clippy::too_many_arguments)]
523pub fn decode_yolo_split_end_to_end_segdet_float<T>(
524 boxes: ArrayView2<T>,
525 scores: ArrayView2<T>,
526 classes: ArrayView2<T>,
527 mask_coeff: ArrayView2<T>,
528 protos: ArrayView3<T>,
529 score_threshold: f32,
530 output_boxes: &mut Vec<DetectBox>,
531 output_masks: &mut Vec<crate::Segmentation>,
532) -> Result<(), crate::DecoderError>
533where
534 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
535 f32: AsPrimitive<T>,
536{
537 let (boxes, scores, classes, mask_coeff) =
538 postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
539 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
540 boxes,
541 scores,
542 classes,
543 score_threshold,
544 output_boxes.capacity(),
545 );
546
547 impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
548}
549
550#[allow(clippy::type_complexity)]
551pub(crate) fn postprocess_yolo_end_to_end_segdet<'a, T>(
552 output: &'a ArrayView2<'_, T>,
553 num_protos: usize,
554) -> Result<
555 (
556 ArrayView2<'a, T>,
557 ArrayView2<'a, T>,
558 ArrayView1<'a, T>,
559 ArrayView2<'a, T>,
560 ),
561 crate::DecoderError,
562> {
563 if output.shape()[0] < 7 {
565 return Err(crate::DecoderError::InvalidShape(format!(
566 "End-to-end segdet output requires at least 7 rows, got {}",
567 output.shape()[0]
568 )));
569 }
570
571 let num_mask_coeffs = output.shape()[0] - 6;
572 if num_mask_coeffs != num_protos {
573 return Err(crate::DecoderError::InvalidShape(format!(
574 "Mask coefficients count ({}) doesn't match protos count ({})",
575 num_mask_coeffs, num_protos
576 )));
577 }
578
579 let boxes = output.slice(s![0..4, ..]).reversed_axes();
581 let scores = output.slice(s![4..5, ..]).reversed_axes();
582 let classes = output.slice(s![5, ..]);
583 let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
584 Ok((boxes, scores, classes, mask_coeff))
585}
586
587#[allow(clippy::type_complexity)]
594pub(crate) fn postprocess_yolo_split_end_to_end_det<'a, 'b, 'c, BOXES, SCORES, CLASS>(
595 boxes: ArrayView2<'a, BOXES>,
596 scores: ArrayView2<'b, SCORES>,
597 classes: &'c ArrayView2<CLASS>,
598) -> Result<
599 (
600 ArrayView2<'a, BOXES>,
601 ArrayView2<'b, SCORES>,
602 ArrayView1<'c, CLASS>,
603 ),
604 crate::DecoderError,
605> {
606 let num_boxes = boxes.shape()[1];
607 if boxes.shape()[0] != 4 {
608 return Err(crate::DecoderError::InvalidShape(format!(
609 "Split end-to-end box_coords must be 4, got {}",
610 boxes.shape()[0]
611 )));
612 }
613
614 if scores.shape()[0] != 1 {
615 return Err(crate::DecoderError::InvalidShape(format!(
616 "Split end-to-end scores num_classes must be 1, got {}",
617 scores.shape()[0]
618 )));
619 }
620
621 if classes.shape()[0] != 1 {
622 return Err(crate::DecoderError::InvalidShape(format!(
623 "Split end-to-end classes num_classes must be 1, got {}",
624 classes.shape()[0]
625 )));
626 }
627
628 if scores.shape()[1] != num_boxes {
629 return Err(crate::DecoderError::InvalidShape(format!(
630 "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
631 num_boxes,
632 scores.shape()[1]
633 )));
634 }
635
636 if classes.shape()[1] != num_boxes {
637 return Err(crate::DecoderError::InvalidShape(format!(
638 "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
639 num_boxes,
640 classes.shape()[1]
641 )));
642 }
643
644 let boxes = boxes.reversed_axes();
645 let scores = scores.reversed_axes();
646 let classes = classes.slice(s![0, ..]);
647 Ok((boxes, scores, classes))
648}
649
650#[allow(clippy::type_complexity)]
653pub(crate) fn postprocess_yolo_split_end_to_end_segdet<
654 'a,
655 'b,
656 'c,
657 'd,
658 BOXES,
659 SCORES,
660 CLASS,
661 MASK,
662>(
663 boxes: ArrayView2<'a, BOXES>,
664 scores: ArrayView2<'b, SCORES>,
665 classes: &'c ArrayView2<CLASS>,
666 mask_coeff: ArrayView2<'d, MASK>,
667) -> Result<
668 (
669 ArrayView2<'a, BOXES>,
670 ArrayView2<'b, SCORES>,
671 ArrayView1<'c, CLASS>,
672 ArrayView2<'d, MASK>,
673 ),
674 crate::DecoderError,
675> {
676 let num_boxes = boxes.shape()[1];
677 if boxes.shape()[0] != 4 {
678 return Err(crate::DecoderError::InvalidShape(format!(
679 "Split end-to-end box_coords must be 4, got {}",
680 boxes.shape()[0]
681 )));
682 }
683
684 if scores.shape()[0] != 1 {
685 return Err(crate::DecoderError::InvalidShape(format!(
686 "Split end-to-end scores num_classes must be 1, got {}",
687 scores.shape()[0]
688 )));
689 }
690
691 if classes.shape()[0] != 1 {
692 return Err(crate::DecoderError::InvalidShape(format!(
693 "Split end-to-end classes num_classes must be 1, got {}",
694 classes.shape()[0]
695 )));
696 }
697
698 if scores.shape()[1] != num_boxes {
699 return Err(crate::DecoderError::InvalidShape(format!(
700 "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
701 num_boxes,
702 scores.shape()[1]
703 )));
704 }
705
706 if classes.shape()[1] != num_boxes {
707 return Err(crate::DecoderError::InvalidShape(format!(
708 "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
709 num_boxes,
710 classes.shape()[1]
711 )));
712 }
713
714 if mask_coeff.shape()[1] != num_boxes {
715 return Err(crate::DecoderError::InvalidShape(format!(
716 "Split end-to-end mask_coeff must have same num_boxes as boxes ({}), got {}",
717 num_boxes,
718 mask_coeff.shape()[1]
719 )));
720 }
721
722 let boxes = boxes.reversed_axes();
723 let scores = scores.reversed_axes();
724 let classes = classes.slice(s![0, ..]);
725 let mask_coeff = mask_coeff.reversed_axes();
726 Ok((boxes, scores, classes, mask_coeff))
727}
728pub(crate) fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
733 output: (ArrayView2<T>, Quantization),
734 score_threshold: f32,
735 iou_threshold: f32,
736 nms: Option<Nms>,
737 output_boxes: &mut Vec<DetectBox>,
738) where
739 f32: AsPrimitive<T>,
740{
741 let (boxes, quant_boxes) = output;
742 let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
743
744 let boxes = {
745 let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
746 postprocess_boxes_quant::<B, _, _>(
747 score_threshold,
748 boxes_tensor,
749 scores_tensor,
750 quant_boxes,
751 )
752 };
753
754 let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
755 let len = output_boxes.capacity().min(boxes.len());
756 output_boxes.clear();
757 for b in boxes.iter().take(len) {
758 output_boxes.push(dequant_detect_box(b, quant_boxes));
759 }
760}
761
762pub(crate) fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
767 output: ArrayView2<T>,
768 score_threshold: f32,
769 iou_threshold: f32,
770 nms: Option<Nms>,
771 output_boxes: &mut Vec<DetectBox>,
772) where
773 f32: AsPrimitive<T>,
774{
775 let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
776 let boxes =
777 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
778 let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
779 let len = output_boxes.capacity().min(boxes.len());
780 output_boxes.clear();
781 for b in boxes.into_iter().take(len) {
782 output_boxes.push(b);
783 }
784}
785
786pub(crate) fn impl_yolo_split_quant<
796 B: BBoxTypeTrait,
797 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
798 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
799>(
800 boxes: (ArrayView2<BOX>, Quantization),
801 scores: (ArrayView2<SCORE>, Quantization),
802 score_threshold: f32,
803 iou_threshold: f32,
804 nms: Option<Nms>,
805 output_boxes: &mut Vec<DetectBox>,
806) where
807 f32: AsPrimitive<SCORE>,
808{
809 let (boxes_tensor, quant_boxes) = boxes;
810 let (scores_tensor, quant_scores) = scores;
811
812 let boxes_tensor = boxes_tensor.reversed_axes();
813 let scores_tensor = scores_tensor.reversed_axes();
814
815 let boxes = {
816 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
817 postprocess_boxes_quant::<B, _, _>(
818 score_threshold,
819 boxes_tensor,
820 scores_tensor,
821 quant_boxes,
822 )
823 };
824
825 let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
826 let len = output_boxes.capacity().min(boxes.len());
827 output_boxes.clear();
828 for b in boxes.iter().take(len) {
829 output_boxes.push(dequant_detect_box(b, quant_scores));
830 }
831}
832
833pub(crate) fn impl_yolo_split_float<
842 B: BBoxTypeTrait,
843 BOX: Float + AsPrimitive<f32> + Send + Sync,
844 SCORE: Float + AsPrimitive<f32> + Send + Sync,
845>(
846 boxes_tensor: ArrayView2<BOX>,
847 scores_tensor: ArrayView2<SCORE>,
848 score_threshold: f32,
849 iou_threshold: f32,
850 nms: Option<Nms>,
851 output_boxes: &mut Vec<DetectBox>,
852) where
853 f32: AsPrimitive<SCORE>,
854{
855 let boxes_tensor = boxes_tensor.reversed_axes();
856 let scores_tensor = scores_tensor.reversed_axes();
857 let boxes =
858 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
859 let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
860 let len = output_boxes.capacity().min(boxes.len());
861 output_boxes.clear();
862 for b in boxes.into_iter().take(len) {
863 output_boxes.push(b);
864 }
865}
866
867pub(crate) fn impl_yolo_segdet_quant<
877 B: BBoxTypeTrait,
878 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
879 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
880>(
881 boxes: (ArrayView2<BOX>, Quantization),
882 protos: (ArrayView3<PROTO>, Quantization),
883 score_threshold: f32,
884 iou_threshold: f32,
885 nms: Option<Nms>,
886 output_boxes: &mut Vec<DetectBox>,
887 output_masks: &mut Vec<Segmentation>,
888) -> Result<(), crate::DecoderError>
889where
890 f32: AsPrimitive<BOX>,
891{
892 let (boxes, quant_boxes) = boxes;
893 let num_protos = protos.0.dim().2;
894
895 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
896 let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
897 (boxes_tensor, quant_boxes),
898 (scores_tensor, quant_boxes),
899 score_threshold,
900 iou_threshold,
901 nms,
902 output_boxes.capacity(),
903 );
904
905 impl_yolo_split_segdet_quant_process_masks::<_, _>(
906 boxes,
907 (mask_tensor, quant_boxes),
908 protos,
909 output_boxes,
910 output_masks,
911 )
912}
913
914pub(crate) fn impl_yolo_segdet_float<
924 B: BBoxTypeTrait,
925 BOX: Float + AsPrimitive<f32> + Send + Sync,
926 PROTO: Float + AsPrimitive<f32> + Send + Sync,
927>(
928 boxes: ArrayView2<BOX>,
929 protos: ArrayView3<PROTO>,
930 score_threshold: f32,
931 iou_threshold: f32,
932 nms: Option<Nms>,
933 output_boxes: &mut Vec<DetectBox>,
934 output_masks: &mut Vec<Segmentation>,
935) -> Result<(), crate::DecoderError>
936where
937 f32: AsPrimitive<BOX>,
938{
939 let num_protos = protos.dim().2;
940 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
941 let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
942 boxes_tensor,
943 scores_tensor,
944 score_threshold,
945 iou_threshold,
946 nms,
947 output_boxes.capacity(),
948 );
949 impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
950}
951
952pub(crate) fn impl_yolo_segdet_get_boxes<
953 B: BBoxTypeTrait,
954 BOX: Float + AsPrimitive<f32> + Send + Sync,
955 SCORE: Float + AsPrimitive<f32> + Send + Sync,
956>(
957 boxes_tensor: ArrayView2<BOX>,
958 scores_tensor: ArrayView2<SCORE>,
959 score_threshold: f32,
960 iou_threshold: f32,
961 nms: Option<Nms>,
962 max_boxes: usize,
963) -> Vec<(DetectBox, usize)>
964where
965 f32: AsPrimitive<SCORE>,
966{
967 let mut boxes = postprocess_boxes_index_float::<B, _, _>(
968 score_threshold.as_(),
969 boxes_tensor,
970 scores_tensor,
971 );
972 truncate_to_top_k_by_score(&mut boxes);
973 let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
974 boxes.truncate(max_boxes);
975 boxes
976}
977
978pub(crate) fn impl_yolo_end_to_end_segdet_get_boxes<
979 B: BBoxTypeTrait,
980 BOX: Float + AsPrimitive<f32> + Send + Sync,
981 SCORE: Float + AsPrimitive<f32> + Send + Sync,
982 CLASS: AsPrimitive<f32> + Send + Sync,
983>(
984 boxes: ArrayView2<BOX>,
985 scores: ArrayView2<SCORE>,
986 classes: ArrayView1<CLASS>,
987 score_threshold: f32,
988 max_boxes: usize,
989) -> Vec<(DetectBox, usize)>
990where
991 f32: AsPrimitive<SCORE>,
992{
993 let mut boxes = postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes, scores);
994 boxes.truncate(max_boxes);
995 for (b, ind) in &mut boxes {
996 b.label = classes[*ind].as_().round() as usize;
997 }
998 boxes
999}
1000
1001pub(crate) fn impl_yolo_split_segdet_process_masks<
1002 MASK: Float + AsPrimitive<f32> + Send + Sync,
1003 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1004>(
1005 boxes: Vec<(DetectBox, usize)>,
1006 masks_tensor: ArrayView2<MASK>,
1007 protos_tensor: ArrayView3<PROTO>,
1008 output_boxes: &mut Vec<DetectBox>,
1009 output_masks: &mut Vec<Segmentation>,
1010) -> Result<(), crate::DecoderError> {
1011 let boxes = decode_segdet_f32(boxes, masks_tensor, protos_tensor)?;
1012 output_boxes.clear();
1013 output_masks.clear();
1014 for (b, m) in boxes.into_iter() {
1015 output_boxes.push(b);
1016 output_masks.push(Segmentation {
1017 xmin: b.bbox.xmin,
1018 ymin: b.bbox.ymin,
1019 xmax: b.bbox.xmax,
1020 ymax: b.bbox.ymax,
1021 segmentation: m,
1022 });
1023 }
1024 Ok(())
1025}
1026pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
1030 B: BBoxTypeTrait,
1031 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1032 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1033>(
1034 boxes: (ArrayView2<BOX>, Quantization),
1035 scores: (ArrayView2<SCORE>, Quantization),
1036 score_threshold: f32,
1037 iou_threshold: f32,
1038 nms: Option<Nms>,
1039 max_boxes: usize,
1040) -> Vec<(DetectBox, usize)>
1041where
1042 f32: AsPrimitive<SCORE>,
1043{
1044 let (boxes_tensor, quant_boxes) = boxes;
1045 let (scores_tensor, quant_scores) = scores;
1046
1047 let mut boxes = {
1048 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
1049 postprocess_boxes_index_quant::<B, _, _>(
1050 score_threshold,
1051 boxes_tensor,
1052 scores_tensor,
1053 quant_boxes,
1054 )
1055 };
1056 truncate_to_top_k_by_score_quant(&mut boxes);
1057 let mut boxes = dispatch_nms_extra_int(nms, iou_threshold, boxes);
1058 boxes.truncate(max_boxes);
1059 boxes
1060 .into_iter()
1061 .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
1062 .collect()
1063}
1064
1065pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
1066 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1067 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1068>(
1069 boxes: Vec<(DetectBox, usize)>,
1070 mask_coeff: (ArrayView2<MASK>, Quantization),
1071 protos: (ArrayView3<PROTO>, Quantization),
1072 output_boxes: &mut Vec<DetectBox>,
1073 output_masks: &mut Vec<Segmentation>,
1074) -> Result<(), crate::DecoderError> {
1075 let (masks, quant_masks) = mask_coeff;
1076 let (protos, quant_protos) = protos;
1077
1078 let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
1079 output_boxes.clear();
1080 output_masks.clear();
1081 for (b, m) in boxes.into_iter() {
1082 output_boxes.push(b);
1083 output_masks.push(Segmentation {
1084 xmin: b.bbox.xmin,
1085 ymin: b.bbox.ymin,
1086 xmax: b.bbox.xmax,
1087 ymax: b.bbox.ymax,
1088 segmentation: m,
1089 });
1090 }
1091 Ok(())
1092}
1093
1094#[allow(clippy::too_many_arguments)]
1095pub(crate) fn impl_yolo_split_segdet_quant<
1107 B: BBoxTypeTrait,
1108 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1109 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1110 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1111 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1112>(
1113 boxes: (ArrayView2<BOX>, Quantization),
1114 scores: (ArrayView2<SCORE>, Quantization),
1115 mask_coeff: (ArrayView2<MASK>, Quantization),
1116 protos: (ArrayView3<PROTO>, Quantization),
1117 score_threshold: f32,
1118 iou_threshold: f32,
1119 nms: Option<Nms>,
1120 output_boxes: &mut Vec<DetectBox>,
1121 output_masks: &mut Vec<Segmentation>,
1122) -> Result<(), crate::DecoderError>
1123where
1124 f32: AsPrimitive<SCORE>,
1125{
1126 let (boxes_, scores_, mask_coeff_) =
1127 postprocess_yolo_split_segdet(boxes.0, scores.0, mask_coeff.0);
1128 let boxes = (boxes_, boxes.1);
1129 let scores = (scores_, scores.1);
1130 let mask_coeff = (mask_coeff_, mask_coeff.1);
1131
1132 let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1133 boxes,
1134 scores,
1135 score_threshold,
1136 iou_threshold,
1137 nms,
1138 output_boxes.capacity(),
1139 );
1140
1141 impl_yolo_split_segdet_quant_process_masks(
1142 boxes,
1143 mask_coeff,
1144 protos,
1145 output_boxes,
1146 output_masks,
1147 )
1148}
1149
1150#[allow(clippy::too_many_arguments)]
1151pub(crate) fn impl_yolo_split_segdet_float<
1163 B: BBoxTypeTrait,
1164 BOX: Float + AsPrimitive<f32> + Send + Sync,
1165 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1166 MASK: Float + AsPrimitive<f32> + Send + Sync,
1167 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1168>(
1169 boxes_tensor: ArrayView2<BOX>,
1170 scores_tensor: ArrayView2<SCORE>,
1171 mask_tensor: ArrayView2<MASK>,
1172 protos: ArrayView3<PROTO>,
1173 score_threshold: f32,
1174 iou_threshold: f32,
1175 nms: Option<Nms>,
1176 output_boxes: &mut Vec<DetectBox>,
1177 output_masks: &mut Vec<Segmentation>,
1178) -> Result<(), crate::DecoderError>
1179where
1180 f32: AsPrimitive<SCORE>,
1181{
1182 let (boxes_tensor, scores_tensor, mask_tensor) =
1183 postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1184
1185 let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1186 boxes_tensor,
1187 scores_tensor,
1188 score_threshold,
1189 iou_threshold,
1190 nms,
1191 output_boxes.capacity(),
1192 );
1193 impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1194}
1195
1196pub fn impl_yolo_segdet_quant_proto<
1203 B: BBoxTypeTrait,
1204 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1205 PROTO: PrimInt
1206 + AsPrimitive<i64>
1207 + AsPrimitive<i128>
1208 + AsPrimitive<f32>
1209 + AsPrimitive<i8>
1210 + Send
1211 + Sync,
1212>(
1213 boxes: (ArrayView2<BOX>, Quantization),
1214 protos: (ArrayView3<PROTO>, Quantization),
1215 score_threshold: f32,
1216 iou_threshold: f32,
1217 nms: Option<Nms>,
1218 output_boxes: &mut Vec<DetectBox>,
1219) -> ProtoData
1220where
1221 f32: AsPrimitive<BOX>,
1222{
1223 let (boxes_arr, quant_boxes) = boxes;
1224 let (protos_arr, quant_protos) = protos;
1225 let num_protos = protos_arr.dim().2;
1226
1227 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1228
1229 let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1230 (boxes_tensor, quant_boxes),
1231 (scores_tensor, quant_boxes),
1232 score_threshold,
1233 iou_threshold,
1234 nms,
1235 output_boxes.capacity(),
1236 );
1237
1238 extract_proto_data_quant(
1239 det_indices,
1240 mask_tensor,
1241 quant_boxes,
1242 protos_arr,
1243 quant_protos,
1244 output_boxes,
1245 )
1246}
1247
1248pub(crate) fn impl_yolo_segdet_float_proto<
1251 B: BBoxTypeTrait,
1252 BOX: Float + AsPrimitive<f32> + Send + Sync,
1253 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1254>(
1255 boxes: ArrayView2<BOX>,
1256 protos: ArrayView3<PROTO>,
1257 score_threshold: f32,
1258 iou_threshold: f32,
1259 nms: Option<Nms>,
1260 output_boxes: &mut Vec<DetectBox>,
1261) -> ProtoData
1262where
1263 f32: AsPrimitive<BOX>,
1264{
1265 let num_protos = protos.dim().2;
1266 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1267
1268 let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1269 boxes_tensor,
1270 scores_tensor,
1271 score_threshold,
1272 iou_threshold,
1273 nms,
1274 output_boxes.capacity(),
1275 );
1276
1277 extract_proto_data_float(boxes, mask_tensor, protos, output_boxes)
1278}
1279
1280#[allow(clippy::too_many_arguments)]
1283pub(crate) fn impl_yolo_split_segdet_float_proto<
1284 B: BBoxTypeTrait,
1285 BOX: Float + AsPrimitive<f32> + Send + Sync,
1286 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1287 MASK: Float + AsPrimitive<f32> + Send + Sync,
1288 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1289>(
1290 boxes_tensor: ArrayView2<BOX>,
1291 scores_tensor: ArrayView2<SCORE>,
1292 mask_tensor: ArrayView2<MASK>,
1293 protos: ArrayView3<PROTO>,
1294 score_threshold: f32,
1295 iou_threshold: f32,
1296 nms: Option<Nms>,
1297 output_boxes: &mut Vec<DetectBox>,
1298) -> ProtoData
1299where
1300 f32: AsPrimitive<SCORE>,
1301{
1302 let (boxes_tensor, scores_tensor, mask_tensor) =
1303 postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1304 let det_indices = impl_yolo_segdet_get_boxes::<B, _, _>(
1305 boxes_tensor,
1306 scores_tensor,
1307 score_threshold,
1308 iou_threshold,
1309 nms,
1310 output_boxes.capacity(),
1311 );
1312
1313 extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1314}
1315
1316pub fn decode_yolo_end_to_end_segdet_float_proto<T>(
1318 output: ArrayView2<T>,
1319 protos: ArrayView3<T>,
1320 score_threshold: f32,
1321 output_boxes: &mut Vec<DetectBox>,
1322) -> Result<ProtoData, crate::DecoderError>
1323where
1324 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1325 f32: AsPrimitive<T>,
1326{
1327 let (boxes, scores, classes, mask_coeff) =
1328 postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
1329 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1330 boxes,
1331 scores,
1332 classes,
1333 score_threshold,
1334 output_boxes.capacity(),
1335 );
1336
1337 Ok(extract_proto_data_float(
1338 boxes,
1339 mask_coeff,
1340 protos,
1341 output_boxes,
1342 ))
1343}
1344
1345#[allow(clippy::too_many_arguments)]
1347pub fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1348 boxes: ArrayView2<T>,
1349 scores: ArrayView2<T>,
1350 classes: ArrayView2<T>,
1351 mask_coeff: ArrayView2<T>,
1352 protos: ArrayView3<T>,
1353 score_threshold: f32,
1354 output_boxes: &mut Vec<DetectBox>,
1355) -> Result<ProtoData, crate::DecoderError>
1356where
1357 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1358 f32: AsPrimitive<T>,
1359{
1360 let (boxes, scores, classes, mask_coeff) =
1361 postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
1362 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1363 boxes,
1364 scores,
1365 classes,
1366 score_threshold,
1367 output_boxes.capacity(),
1368 );
1369
1370 Ok(extract_proto_data_float(
1371 boxes,
1372 mask_coeff,
1373 protos,
1374 output_boxes,
1375 ))
1376}
1377
1378pub(super) fn extract_proto_data_float<
1380 MASK: Float + AsPrimitive<f32> + Send + Sync,
1381 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1382>(
1383 det_indices: Vec<(DetectBox, usize)>,
1384 mask_tensor: ArrayView2<MASK>,
1385 protos: ArrayView3<PROTO>,
1386 output_boxes: &mut Vec<DetectBox>,
1387) -> ProtoData {
1388 let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1389 output_boxes.clear();
1390 for (det, idx) in det_indices {
1391 output_boxes.push(det);
1392 let row = mask_tensor.row(idx);
1393 mask_coefficients.push(row.iter().map(|v| v.as_()).collect());
1394 }
1395 let protos_f32 = protos.map(|v| v.as_());
1396 ProtoData {
1397 mask_coefficients,
1398 protos: ProtoTensor::Float(protos_f32),
1399 }
1400}
1401
/// Packs quantized detections into [`ProtoData`] without materialising masks.
///
/// `output_boxes` is cleared and then filled with the boxes from
/// `det_indices`, in order. For each `(box, idx)` pair, row `idx` of
/// `mask_tensor` is dequantized to `f32` as
/// `(q - quant_masks.zero_point) * quant_masks.scale` and becomes that
/// detection's mask coefficients.
///
/// Prototypes stay quantized: they are converted to `i8` and returned in
/// [`ProtoTensor::Quantized`] together with `quant_protos`, which is not
/// modified.
///
/// # Panics
///
/// Panics if an `idx` is not a valid row of `mask_tensor`.
pub(crate) fn extract_proto_data_quant<
    MASK: PrimInt + AsPrimitive<f32> + Send + Sync,
    PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
>(
    det_indices: Vec<(DetectBox, usize)>,
    mask_tensor: ArrayView2<MASK>,
    quant_masks: Quantization,
    protos: ArrayView3<PROTO>,
    quant_protos: Quantization,
    output_boxes: &mut Vec<DetectBox>,
) -> ProtoData {
    let mut mask_coefficients = Vec::with_capacity(det_indices.len());
    output_boxes.clear();
    for (det, idx) in det_indices {
        output_boxes.push(det);
        let row = mask_tensor.row(idx);
        // Dequantize this detection's coefficient row to f32.
        mask_coefficients.push(
            row.iter()
                .map(|v| (v.as_() - quant_masks.zero_point as f32) * quant_masks.scale)
                .collect(),
        );
    }
    let protos_i8 = if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
        // Fast path: PROTO is already i8, so the view is copied verbatim
        // instead of being converted element by element.
        //
        // SAFETY: the TypeId comparison above proves PROTO and i8 are the
        // same type (the `'static` bound on PROTO is what makes TypeId
        // available). ArrayView3<'_, PROTO> and ArrayView3<'_, i8> are
        // therefore the same type with the same lifetime, so this pointer
        // cast is an identity reinterpretation of a live, properly aligned
        // reference.
        let view_i8 =
            unsafe { &*(&protos as *const ArrayView3<'_, PROTO> as *const ArrayView3<'_, i8>) };
        view_i8.to_owned()
    } else {
        // NOTE(review): `AsPrimitive::as_` follows `as`-cast semantics.
        // PROTO values that do not fit in i8 are wrapped or truncated here
        // (e.g. u8 200 -> i8 -56), yet `quant_protos` is returned unchanged.
        // Confirm that callers only reach this branch with values that fit
        // in i8, or that the consumer of ProtoTensor::Quantized compensates.
        protos.map(|v| {
            let v_i8: i8 = v.as_();
            v_i8
        })
    };
    ProtoData {
        mask_coefficients,
        protos: ProtoTensor::Quantized {
            protos: protos_i8,
            quantization: quant_protos,
        },
    }
}
1451
1452fn postprocess_yolo<'a, T>(
1453 output: &'a ArrayView2<'_, T>,
1454) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1455 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1456 let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1457 (boxes_tensor, scores_tensor)
1458}
1459
1460pub(crate) fn postprocess_yolo_seg<'a, T>(
1461 output: &'a ArrayView2<'_, T>,
1462 num_protos: usize,
1463) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1464 assert!(
1465 output.shape()[0] > num_protos + 4,
1466 "Output shape is too short: {} <= {} + 4",
1467 output.shape()[0],
1468 num_protos
1469 );
1470 let num_classes = output.shape()[0] - 4 - num_protos;
1471 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1472 let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1473 let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1474 (boxes_tensor, scores_tensor, mask_tensor)
1475}
1476
1477pub(crate) fn postprocess_yolo_split_segdet<'a, 'b, 'c, BOX, SCORE, MASK>(
1478 boxes_tensor: ArrayView2<'a, BOX>,
1479 scores_tensor: ArrayView2<'b, SCORE>,
1480 mask_tensor: ArrayView2<'c, MASK>,
1481) -> (
1482 ArrayView2<'a, BOX>,
1483 ArrayView2<'b, SCORE>,
1484 ArrayView2<'c, MASK>,
1485) {
1486 let boxes_tensor = boxes_tensor.reversed_axes();
1487 let scores_tensor = scores_tensor.reversed_axes();
1488 let mask_tensor = mask_tensor.reversed_axes();
1489 (boxes_tensor, scores_tensor, mask_tensor)
1490}
1491
1492fn decode_segdet_f32<
1493 MASK: Float + AsPrimitive<f32> + Send + Sync,
1494 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1495>(
1496 boxes: Vec<(DetectBox, usize)>,
1497 masks: ArrayView2<MASK>,
1498 protos: ArrayView3<PROTO>,
1499) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1500 if boxes.is_empty() {
1501 return Ok(Vec::new());
1502 }
1503 if masks.shape()[1] != protos.shape()[2] {
1504 return Err(crate::DecoderError::InvalidShape(format!(
1505 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1506 masks.shape()[1],
1507 protos.shape()[2],
1508 )));
1509 }
1510 boxes
1511 .into_par_iter()
1512 .map(|mut b| {
1513 let ind = b.1;
1514 let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1515 b.0.bbox = roi;
1516 Ok((b.0, make_segmentation(masks.row(ind), protos.view())))
1517 })
1518 .collect()
1519}
1520
1521pub(crate) fn decode_segdet_quant<
1522 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1523 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1524>(
1525 boxes: Vec<(DetectBox, usize)>,
1526 masks: ArrayView2<MASK>,
1527 protos: ArrayView3<PROTO>,
1528 quant_masks: Quantization,
1529 quant_protos: Quantization,
1530) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1531 if boxes.is_empty() {
1532 return Ok(Vec::new());
1533 }
1534 if masks.shape()[1] != protos.shape()[2] {
1535 return Err(crate::DecoderError::InvalidShape(format!(
1536 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1537 masks.shape()[1],
1538 protos.shape()[2],
1539 )));
1540 }
1541
1542 let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; boxes
1544 .into_iter()
1545 .map(|mut b| {
1546 let i = b.1;
1547 let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1548 b.0.bbox = roi;
1549 let seg = match total_bits {
1550 0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1551 masks.row(i),
1552 protos.view(),
1553 quant_masks,
1554 quant_protos,
1555 ),
1556 65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1557 masks.row(i),
1558 protos.view(),
1559 quant_masks,
1560 quant_protos,
1561 ),
1562 _ => {
1563 return Err(crate::DecoderError::NotSupported(format!(
1564 "Unsupported bit width ({total_bits}) for segmentation computation"
1565 )));
1566 }
1567 };
1568 Ok((b.0, seg))
1569 })
1570 .collect()
1571}
1572
1573fn protobox<'a, T>(
1574 protos: &'a ArrayView3<T>,
1575 roi: &BoundingBox,
1576) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1577 let width = protos.dim().1 as f32;
1578 let height = protos.dim().0 as f32;
1579
1580 const NORM_LIMIT: f32 = 2.0;
1591 if roi.xmin > NORM_LIMIT
1592 || roi.ymin > NORM_LIMIT
1593 || roi.xmax > NORM_LIMIT
1594 || roi.ymax > NORM_LIMIT
1595 {
1596 return Err(crate::DecoderError::InvalidShape(format!(
1597 "Bounding box coordinates appear un-normalized (pixel-space). \
1598 Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
1599 ONNX models output pixel-space boxes — normalize them by dividing by \
1600 the input dimensions before calling decode().",
1601 roi.xmin, roi.ymin, roi.xmax, roi.ymax,
1602 )));
1603 }
1604
1605 let roi = [
1606 (roi.xmin * width).clamp(0.0, width) as usize,
1607 (roi.ymin * height).clamp(0.0, height) as usize,
1608 (roi.xmax * width).clamp(0.0, width).ceil() as usize,
1609 (roi.ymax * height).clamp(0.0, height).ceil() as usize,
1610 ];
1611
1612 let roi_norm = [
1613 roi[0] as f32 / width,
1614 roi[1] as f32 / height,
1615 roi[2] as f32 / width,
1616 roi[3] as f32 / height,
1617 ]
1618 .into();
1619
1620 let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
1621
1622 Ok((cropped, roi_norm))
1623}
1624
1625fn make_segmentation<
1631 MASK: Float + AsPrimitive<f32> + Send + Sync,
1632 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1633>(
1634 mask: ArrayView1<MASK>,
1635 protos: ArrayView3<PROTO>,
1636) -> Array3<u8> {
1637 let shape = protos.shape();
1638
1639 let mask = mask.to_shape((1, mask.len())).unwrap();
1641 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1642 let protos = protos.reversed_axes();
1643 let mask = mask.map(|x| x.as_());
1644 let protos = protos.map(|x| x.as_());
1645
1646 let mask = mask
1648 .dot(&protos)
1649 .into_shape_with_order((shape[0], shape[1], 1))
1650 .unwrap();
1651
1652 mask.map(|x| {
1653 let sigmoid = 1.0 / (1.0 + (-*x).exp());
1654 (sigmoid * 255.0).round() as u8
1655 })
1656}
1657
1658fn make_segmentation_quant<
1665 MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
1666 PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
1667 DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
1668>(
1669 mask: ArrayView1<MASK>,
1670 protos: ArrayView3<PROTO>,
1671 quant_masks: Quantization,
1672 quant_protos: Quantization,
1673) -> Array3<u8>
1674where
1675 i32: AsPrimitive<DEST>,
1676 f32: AsPrimitive<DEST>,
1677{
1678 let shape = protos.shape();
1679
1680 let mask = mask.to_shape((1, mask.len())).unwrap();
1682
1683 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1684 let protos = protos.reversed_axes();
1685
1686 let zp = quant_masks.zero_point.as_();
1687
1688 let mask = mask.mapv(|x| x.as_() - zp);
1689
1690 let zp = quant_protos.zero_point.as_();
1691 let protos = protos.mapv(|x| x.as_() - zp);
1692
1693 let segmentation = mask
1695 .dot(&protos)
1696 .into_shape_with_order((shape[0], shape[1], 1))
1697 .unwrap();
1698
1699 let combined_scale = quant_masks.scale * quant_protos.scale;
1700 segmentation.map(|x| {
1701 let val: f32 = (*x).as_() * combined_scale;
1702 let sigmoid = 1.0 / (1.0 + (-val).exp());
1703 (sigmoid * 255.0).round() as u8
1704 })
1705}
1706
1707pub fn yolo_segmentation_to_mask(
1719 segmentation: ArrayView3<u8>,
1720 threshold: u8,
1721) -> Result<Array2<u8>, crate::DecoderError> {
1722 if segmentation.shape()[2] != 1 {
1723 return Err(crate::DecoderError::InvalidShape(format!(
1724 "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
1725 segmentation.shape()[2]
1726 )));
1727 }
1728 Ok(segmentation
1729 .slice(s![.., .., 0])
1730 .map(|x| if *x >= threshold { 1 } else { 0 }))
1731}
1732
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    // Unit tests covering end-to-end detection/segmentation decoding,
    // segmentation-to-mask thresholding, prototype cropping (`protobox`),
    // the proto-data path, and the pre-NMS candidate cap.
    use super::*;
    use ndarray::Array2;

    #[test]
    fn test_end_to_end_det_basic_filtering() {
        // One column per detection; rows are x1, y1, x2, y2, score, class.
        // Only detection 0 (score 0.9) clears the 0.5 threshold.
        let data: Vec<f32> = vec![
            0.1, 0.2, 0.3, // x1
            0.1, 0.2, 0.3, // y1
            0.5, 0.6, 0.7, // x2
            0.5, 0.6, 0.7, // y2
            0.9, 0.1, 0.2, // score
            0.0, 1.0, 2.0, // class
        ];
        let output = Array2::from_shape_vec((6, 3), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 1);
        assert_eq!(boxes[0].label, 0);
        assert!((boxes[0].score - 0.9).abs() < 0.01);
        assert!((boxes[0].bbox.xmin - 0.1).abs() < 0.01);
        assert!((boxes[0].bbox.ymin - 0.1).abs() < 0.01);
        assert!((boxes[0].bbox.xmax - 0.5).abs() < 0.01);
        assert!((boxes[0].bbox.ymax - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_end_to_end_det_all_pass_threshold() {
        // Both detections clear the threshold; labels come from the class row.
        let data: Vec<f32> = vec![
            10.0, 20.0, // x1
            10.0, 20.0, // y1
            50.0, 60.0, // x2
            50.0, 60.0, // y2
            0.8, 0.7, // score
            1.0, 2.0, // class
        ];
        let output = Array2::from_shape_vec((6, 2), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 2);
        assert_eq!(boxes[0].label, 1);
        assert_eq!(boxes[1].label, 2);
    }

    #[test]
    fn test_end_to_end_det_none_pass_threshold() {
        // Both scores (0.1, 0.2) are below the 0.5 threshold.
        let data: Vec<f32> = vec![
            10.0, 20.0, // x1
            10.0, 20.0, // y1
            50.0, 60.0, // x2
            50.0, 60.0, // y2
            0.1, 0.2, // score
            1.0, 2.0, // class
        ];
        let output = Array2::from_shape_vec((6, 2), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 0);
    }

    #[test]
    fn test_end_to_end_det_capacity_limit() {
        // Five detections all scoring 0.9.
        let data: Vec<f32> = vec![
            0.1, 0.2, 0.3, 0.4, 0.5, // x1
            0.1, 0.2, 0.3, 0.4, 0.5, // y1
            0.5, 0.6, 0.7, 0.8, 0.9, // x2
            0.5, 0.6, 0.7, 0.8, 0.9, // y2
            0.9, 0.9, 0.9, 0.9, 0.9, // score
            0.0, 1.0, 2.0, 3.0, 4.0, // class
        ];
        let output = Array2::from_shape_vec((6, 5), data).unwrap();

        // The output vector's capacity (2) limits how many boxes are returned.
        let mut boxes = Vec::with_capacity(2);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 2);
    }

    #[test]
    fn test_end_to_end_det_empty_output() {
        // Six feature rows but zero detection columns.
        let output = Array2::<f32>::zeros((6, 0));

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 0);
    }

    #[test]
    fn test_end_to_end_det_pixel_coordinates() {
        // Pixel-space coordinates pass through the detection-only path unchanged.
        let data: Vec<f32> = vec![
            100.0, // x1
            200.0, // y1
            300.0, // x2
            400.0, // y2
            0.95,  // score
            5.0,   // class
        ];
        let output = Array2::from_shape_vec((6, 1), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 1);
        assert_eq!(boxes[0].label, 5);
        assert!((boxes[0].bbox.xmin - 100.0).abs() < 0.01);
        assert!((boxes[0].bbox.ymin - 200.0).abs() < 0.01);
        assert!((boxes[0].bbox.xmax - 300.0).abs() < 0.01);
        assert!((boxes[0].bbox.ymax - 400.0).abs() < 0.01);
    }

    #[test]
    fn test_end_to_end_det_invalid_shape() {
        // Five rows is one short of x1, y1, x2, y2, score, class.
        let output = Array2::<f32>::zeros((5, 3));

        let mut boxes = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
        ));
    }

    #[test]
    fn test_end_to_end_segdet_basic() {
        let num_protos = 32;
        let num_detections = 2;
        let num_features = 6 + num_protos;

        let mut data = vec![0.0f32; num_features * num_detections];
        // Detection 0: box (0.1, 0.1)-(0.4, 0.4), score 0.9, class 1.
        // Detection 1: box (0.5, 0.5)-(0.9, 0.9), score 0.3, class 2.
        data[0] = 0.1;
        data[1] = 0.5;
        data[num_detections] = 0.1;
        data[num_detections + 1] = 0.5;
        data[2 * num_detections] = 0.4;
        data[2 * num_detections + 1] = 0.9;
        data[3 * num_detections] = 0.4;
        data[3 * num_detections + 1] = 0.9;
        data[4 * num_detections] = 0.9;
        data[4 * num_detections + 1] = 0.3;
        data[5 * num_detections] = 1.0;
        data[5 * num_detections + 1] = 2.0;
        // Mask coefficients (rows 6..) are 0.1 for both detections.
        for i in 6..num_features {
            data[i * num_detections] = 0.1;
            data[i * num_detections + 1] = 0.1;
        }

        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();

        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        // Only detection 0 clears the threshold, and it yields one mask.
        assert_eq!(boxes.len(), 1);
        assert_eq!(masks.len(), 1);
        assert_eq!(boxes[0].label, 1);
        assert!((boxes[0].score - 0.9).abs() < 0.01);
    }

    #[test]
    fn test_end_to_end_segdet_mask_coordinates() {
        let num_protos = 32;
        let num_features = 6 + num_protos;

        let mut data = vec![0.0f32; num_features];
        // Single detection: box (0.2, 0.2)-(0.8, 0.8), score 0.95, class 3.
        data[0] = 0.2;
        data[1] = 0.2;
        data[2] = 0.8;
        data[3] = 0.8;
        data[4] = 0.95;
        data[5] = 3.0;
        let output = Array2::from_shape_vec((num_features, 1), data).unwrap();
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 1);
        assert_eq!(masks.len(), 1);

        // The mask's region must coincide with the returned detection box.
        assert!((masks[0].xmin - boxes[0].bbox.xmin).abs() < 0.01);
        assert!((masks[0].ymin - boxes[0].bbox.ymin).abs() < 0.01);
        assert!((masks[0].xmax - boxes[0].bbox.xmax).abs() < 0.01);
        assert!((masks[0].ymax - boxes[0].bbox.ymax).abs() < 0.01);
    }

    #[test]
    fn test_end_to_end_segdet_empty_output() {
        // Correct feature count but zero detection columns.
        let num_protos = 32;
        let output = Array2::<f32>::zeros((6 + num_protos, 0));
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 0);
        assert_eq!(masks.len(), 0);
    }

    #[test]
    fn test_end_to_end_segdet_capacity_limit() {
        let num_protos = 32;
        let num_detections = 5;
        let num_features = 6 + num_protos;

        let mut data = vec![0.0f32; num_features * num_detections];
        for i in 0..num_detections {
            // Detection i: box (0.1*i, 0.1*i)-(0.1*i + 0.2, 0.1*i + 0.2),
            // score 0.9, class i.
            data[i] = 0.1 * (i as f32);
            data[num_detections + i] = 0.1 * (i as f32);
            data[2 * num_detections + i] = 0.1 * (i as f32) + 0.2;
            data[3 * num_detections + i] = 0.1 * (i as f32) + 0.2;
            data[4 * num_detections + i] = 0.9;
            data[5 * num_detections + i] = i as f32;
        }

        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        // All five pass the threshold; a capacity of 2 limits the output.
        let mut boxes = Vec::with_capacity(2);
        let mut masks = Vec::with_capacity(2);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 2);
        assert_eq!(masks.len(), 2);
    }

    #[test]
    fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
        // Six rows leaves no room for any mask coefficient.
        let output = Array2::<f32>::zeros((6, 3));
        let protos = Array3::<f32>::zeros((16, 16, 32));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        );

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
        ));
    }

    #[test]
    fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
        let num_protos = 32;
        // 16 mask-coefficient rows against 32 prototype channels.
        let output = Array2::<f32>::zeros((6 + 16, 3));
        let protos = Array3::<f32>::zeros((16, 16, num_protos));
        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        );

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
        ));
    }

    #[test]
    fn test_split_end_to_end_segdet_basic() {
        let num_protos = 32;
        let num_detections = 2;
        let num_features = 6 + num_protos;

        // Same fixture as test_end_to_end_segdet_basic, later sliced into
        // separate box / score / class / coefficient tensors.
        let mut data = vec![0.0f32; num_features * num_detections];
        data[0] = 0.1;
        data[1] = 0.5;
        data[num_detections] = 0.1;
        data[num_detections + 1] = 0.5;
        data[2 * num_detections] = 0.4;
        data[2 * num_detections + 1] = 0.9;
        data[3 * num_detections] = 0.4;
        data[3 * num_detections + 1] = 0.9;
        data[4 * num_detections] = 0.9;
        data[4 * num_detections + 1] = 0.3;
        data[5 * num_detections] = 1.0;
        data[5 * num_detections + 1] = 2.0;
        for i in 6..num_features {
            data[i * num_detections] = 0.1;
            data[i * num_detections + 1] = 0.1;
        }

        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
        let box_coords = output.slice(s![..4, ..]);
        let scores = output.slice(s![4..5, ..]);
        let classes = output.slice(s![5..6, ..]);
        let mask_coeff = output.slice(s![6.., ..]);
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_split_end_to_end_segdet_float(
            box_coords,
            scores,
            classes,
            mask_coeff,
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 1);
        assert_eq!(masks.len(), 1);
        assert_eq!(boxes[0].label, 1);
        assert!((boxes[0].score - 0.9).abs() < 0.01);
    }

    #[test]
    fn test_segmentation_to_mask_basic() {
        // 4x4 single-channel segmentation, thresholded at 128 (inclusive).
        let data: Vec<u8> = vec![
            100, 200, 50, 150, // row 0
            10, 255, 128, 64, // row 1
            0, 127, 128, 255, // row 2
            64, 64, 192, 192, // row 3
        ];
        let segmentation = Array3::from_shape_vec((4, 4, 1), data).unwrap();

        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();

        assert_eq!(mask[[0, 0]], 0); // 100 < 128
        assert_eq!(mask[[0, 1]], 1); // 200
        assert_eq!(mask[[0, 2]], 0); // 50
        assert_eq!(mask[[0, 3]], 1); // 150
        assert_eq!(mask[[1, 1]], 1); // 255
        assert_eq!(mask[[1, 2]], 1); // 128 == threshold counts as on
        assert_eq!(mask[[2, 0]], 0); // 0
        assert_eq!(mask[[2, 1]], 0); // 127, just below threshold
    }

    #[test]
    fn test_segmentation_to_mask_all_above() {
        let segmentation = Array3::from_elem((4, 4, 1), 255u8);
        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
        assert!(mask.iter().all(|&x| x == 1));
    }

    #[test]
    fn test_segmentation_to_mask_all_below() {
        let segmentation = Array3::from_elem((4, 4, 1), 64u8);
        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
        assert!(mask.iter().all(|&x| x == 0));
    }

    #[test]
    fn test_segmentation_to_mask_invalid_shape() {
        // Three channels instead of one must be rejected.
        let segmentation = Array3::from_elem((4, 4, 3), 128u8);
        let result = yolo_segmentation_to_mask(segmentation.view(), 128);

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("(H, W, 1)")
        ));
    }

    #[test]
    fn test_protobox_clamps_edge_coordinates() {
        // A box touching the right/bottom edge (xmax = ymax = 1.0).
        let protos = Array3::<f32>::zeros((16, 16, 4));
        let view = protos.view();
        let roi = BoundingBox {
            xmin: 0.5,
            ymin: 0.5,
            xmax: 1.0,
            ymax: 1.0,
        };
        let result = protobox(&view, &roi);
        assert!(result.is_ok(), "protobox should accept xmax=1.0");
        let (cropped, _roi_norm) = result.unwrap();
        assert!(cropped.shape()[0] > 0);
        assert!(cropped.shape()[1] > 0);
        assert_eq!(cropped.shape()[2], 4);
    }

    #[test]
    fn test_protobox_rejects_wildly_out_of_range() {
        // 3.0 exceeds NORM_LIMIT (2.0), so it is treated as pixel-space.
        let protos = Array3::<f32>::zeros((16, 16, 4));
        let view = protos.view();
        let roi = BoundingBox {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 3.0,
            ymax: 3.0,
        };
        let result = protobox(&view, &roi);
        assert!(
            matches!(result, Err(crate::DecoderError::InvalidShape(s)) if s.contains("un-normalized")),
            "protobox should reject coords > NORM_LIMIT"
        );
    }

    #[test]
    fn test_protobox_accepts_slightly_over_one() {
        // 1.5 is within NORM_LIMIT and gets clamped to the full grid.
        let protos = Array3::<f32>::zeros((16, 16, 4));
        let view = protos.view();
        let roi = BoundingBox {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.5,
            ymax: 1.5,
        };
        let result = protobox(&view, &roi);
        assert!(
            result.is_ok(),
            "protobox should accept coords <= NORM_LIMIT (2.0)"
        );
        let (cropped, _roi_norm) = result.unwrap();
        assert_eq!(cropped.shape()[0], 16);
        assert_eq!(cropped.shape()[1], 16);
    }

    #[test]
    fn test_segdet_float_proto_no_panic() {
        // Every proposal is the same pixel-space XYWH box (320, 320, 50, 50)
        // with a class-0 score of 0.9. NOTE(review): presumably this checks
        // that the proto path, which does not crop prototypes, tolerates
        // coordinates that protobox would reject — confirm intent.
        let num_proposals = 100;
        let num_classes = 80;
        let num_mask_coeffs = 32;
        let rows = 4 + num_classes + num_mask_coeffs;
        let mut data = vec![0.0f32; rows * num_proposals];
        for i in 0..num_proposals {
            // Index of feature row `r` for proposal `i` (column-per-proposal layout).
            let row = |r: usize| r * num_proposals + i;
            data[row(0)] = 320.0;
            data[row(1)] = 320.0;
            data[row(2)] = 50.0;
            data[row(3)] = 50.0;
            data[row(4)] = 0.9;
        }
        let boxes = ndarray::Array2::from_shape_vec((rows, num_proposals), data).unwrap();

        let protos = ndarray::Array3::<f32>::zeros((160, 160, num_mask_coeffs));

        let mut output_boxes = Vec::with_capacity(300);

        let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
            boxes.view(),
            protos.view(),
            0.5,
            0.7,
            Some(Nms::default()),
            &mut output_boxes,
        );

        // One coefficient row per surviving box, each of the full width.
        assert!(!output_boxes.is_empty());
        assert_eq!(proto_data.mask_coefficients.len(), output_boxes.len());
        for coeffs in &proto_data.mask_coefficients {
            assert_eq!(coeffs.len(), num_mask_coeffs);
        }
    }

    #[test]
    fn test_pre_nms_cap_truncates_excess_candidates() {
        // 50k identical boxes with strictly decreasing scores, NMS disabled:
        // only the pre-NMS cap can reduce the count.
        let n: usize = 50_000;
        let num_classes = 1;

        let mut boxes_data = Vec::with_capacity(n * 4);
        let mut scores_data = Vec::with_capacity(n * num_classes);
        for i in 0..n {
            boxes_data.extend_from_slice(&[0.1f32, 0.1, 0.5, 0.5]);
            scores_data.push(0.99 - (i as f32) * 1e-7);
        }
        let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
        let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();

        let result = impl_yolo_segdet_get_boxes::<XYXY, _, _>(
            boxes.view(),
            scores.view(),
            0.1,
            1.0,
            None,
            usize::MAX,
        );

        assert_eq!(
            result.len(),
            crate::yolo::MAX_NMS_CANDIDATES,
            "pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
            result.len()
        );
        // The first survivor should be one of the highest-scoring inputs.
        let top_score = result[0].0.score;
        assert!(
            top_score > 0.98,
            "highest-ranked survivor should have the largest score, got {top_score}"
        );
    }

    #[test]
    fn test_pre_nms_cap_truncates_excess_candidates_quant() {
        use crate::Quantization;
        let n: usize = 50_000;
        let num_classes = 1;

        // Quantized copy of the float test: identical i8 boxes.
        let boxes_data = (0..n).flat_map(|_| [10i8, 10, 50, 50]).collect::<Vec<_>>();
        let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
        let quant_boxes = Quantization {
            scale: 0.01,
            zero_point: 0,
        };

        // u8 scores cycling through 250 down to 51, all above the 0.1 threshold
        // after dequantization with scale 0.00392.
        let scores_data: Vec<u8> = (0..n)
            .map(|i| 250u8.saturating_sub((i % 200) as u8))
            .collect();
        let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();
        let quant_scores = Quantization {
            scale: 0.00392,
            zero_point: 0,
        };

        let result = impl_yolo_split_segdet_quant_get_boxes::<XYXY, _, _>(
            (boxes.view(), quant_boxes),
            (scores.view(), quant_scores),
            0.1,
            1.0,
            None,
            usize::MAX,
        );

        assert_eq!(
            result.len(),
            crate::yolo::MAX_NMS_CANDIDATES,
            "quant path pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
            result.len()
        );
    }

    #[test]
    fn segdet_combined_tensor_pairs_detection_with_matching_mask_row() {
        // Checks that each detection is paired with its own mask-coefficient
        // row. Layout: 2 classes, 2 mask coefficients, 3 anchors.
        let nc = 2;
        let nm = 2;
        let n = 3;
        let feat = 4 + nc + nm;
        let mut data = vec![0.0f32; feat * n];
        // Writes feature row `r` of anchor column `c`.
        let set = |d: &mut [f32], r: usize, c: usize, v: f32| d[r * n + c] = v;
        // Anchor 0: centre (0.2, 0.2), size 0.1.
        set(&mut data, 0, 0, 0.2);
        set(&mut data, 1, 0, 0.2);
        set(&mut data, 2, 0, 0.1);
        set(&mut data, 3, 0, 0.1);
        // Anchor 1: centre (0.5, 0.5); it has no class score, so it is filtered out.
        set(&mut data, 0, 1, 0.5);
        set(&mut data, 1, 1, 0.5);
        set(&mut data, 2, 1, 0.1);
        set(&mut data, 3, 1, 0.1);
        // Anchor 2: centre (0.8, 0.8), size 0.1.
        set(&mut data, 0, 2, 0.8);
        set(&mut data, 1, 2, 0.8);
        set(&mut data, 2, 2, 0.1);
        set(&mut data, 3, 2, 0.1);
        // Class-0 scores: anchor 0 = 0.9, anchor 2 = 0.8.
        set(&mut data, 4, 0, 0.9);
        set(&mut data, 4, 2, 0.8);
        // Mask coefficients: anchor 0 strongly positive, anchor 2 strongly negative.
        set(&mut data, 6, 0, 3.0);
        set(&mut data, 7, 0, 3.0);
        set(&mut data, 6, 2, -3.0);
        set(&mut data, 7, 2, -3.0);

        let output = Array2::from_shape_vec((feat, n), data).unwrap();
        // All-ones prototypes, so each mask is sigmoid(sum of that row's coefficients).
        let protos = Array3::<f32>::from_elem((8, 8, nm), 1.0);

        let mut boxes: Vec<DetectBox> = Vec::with_capacity(10);
        let mut masks: Vec<Segmentation> = Vec::with_capacity(10);
        decode_yolo_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            0.5,
            Some(Nms::ClassAgnostic),
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(
            boxes.len(),
            2,
            "two anchors above threshold should survive (a0 score=0.9, a2 score=0.8); got {}",
            boxes.len()
        );

        for (b, m) in boxes.iter().zip(masks.iter()) {
            // Identify the anchor by box centre, then check that its mask
            // brightness matches the sign of its coefficients.
            let cx = (b.bbox.xmin + b.bbox.xmax) * 0.5;
            let mean = {
                let s = &m.segmentation;
                let total: u32 = s.iter().map(|&v| v as u32).sum();
                total as f32 / s.len() as f32
            };
            if cx < 0.3 {
                assert!(
                    mean > 200.0,
                    "anchor 0 detection (centre {cx:.2}) should have high-value mask; got mean {mean}"
                );
            } else if cx > 0.7 {
                assert!(
                    mean < 50.0,
                    "anchor 2 detection (centre {cx:.2}) should have low-value mask; got mean {mean}"
                );
            } else {
                panic!("unexpected detection centre {cx:.2}");
            }
        }
    }
}