1use std::fmt::Debug;
5
6use ndarray::{
7 parallel::prelude::{IntoParallelIterator, ParallelIterator},
8 s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
9};
10use num_traits::{AsPrimitive, Float, PrimInt, Signed};
11use rayon::slice::ParallelSliceMut;
12
13use crate::{
14 byte::{
15 nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
16 postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
17 },
18 configs::Nms,
19 dequant_detect_box,
20 float::{
21 nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
22 postprocess_boxes_float, postprocess_boxes_index_float,
23 },
24 BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, Quantization,
25 Segmentation, XYWH, XYXY,
26};
27
/// Largest candidate list the segmentation box-selection helpers in this file
/// pass on to NMS; the `truncate_to_top_k_by_score*` helpers sort longer lists
/// by descending score and cut them to this length.
pub(crate) const MAX_NMS_CANDIDATES: usize = 30_000;
43
44fn truncate_to_top_k_by_score<E: Send>(boxes: &mut Vec<(DetectBox, E)>) {
49 if boxes.len() > MAX_NMS_CANDIDATES {
50 boxes.par_sort_unstable_by(|a, b| b.0.score.total_cmp(&a.0.score));
51 boxes.truncate(MAX_NMS_CANDIDATES);
52 }
53}
54
55fn truncate_to_top_k_by_score_quant<S: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send>(
59 boxes: &mut Vec<(DetectBoxQuantized<S>, E)>,
60) {
61 if boxes.len() > MAX_NMS_CANDIDATES {
62 boxes.par_sort_unstable_by(|a, b| b.0.score.cmp(&a.0.score));
63 boxes.truncate(MAX_NMS_CANDIDATES);
64 }
65}
66
67fn dispatch_nms_float(nms: Option<Nms>, iou: f32, boxes: Vec<DetectBox>) -> Vec<DetectBox> {
69 match nms {
70 Some(Nms::ClassAgnostic) => nms_float(iou, boxes),
71 Some(Nms::ClassAware) => nms_class_aware_float(iou, boxes),
72 None => boxes, }
74}
75
76pub(super) fn dispatch_nms_extra_float<E: Send + Sync>(
79 nms: Option<Nms>,
80 iou: f32,
81 boxes: Vec<(DetectBox, E)>,
82) -> Vec<(DetectBox, E)> {
83 match nms {
84 Some(Nms::ClassAgnostic) => nms_extra_float(iou, boxes),
85 Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, boxes),
86 None => boxes, }
88}
89
90fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
93 nms: Option<Nms>,
94 iou: f32,
95 boxes: Vec<DetectBoxQuantized<SCORE>>,
96) -> Vec<DetectBoxQuantized<SCORE>> {
97 match nms {
98 Some(Nms::ClassAgnostic) => nms_int(iou, boxes),
99 Some(Nms::ClassAware) => nms_class_aware_int(iou, boxes),
100 None => boxes, }
102}
103
104fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
107 nms: Option<Nms>,
108 iou: f32,
109 boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
110) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
111 match nms {
112 Some(Nms::ClassAgnostic) => nms_extra_int(iou, boxes),
113 Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, boxes),
114 None => boxes, }
116}
117
/// Decodes a quantized YOLO detection tensor into `output_boxes`.
///
/// Thin wrapper: forwards every argument unchanged to [`impl_yolo_quant`] with
/// the box format fixed to [`XYWH`].
///
/// * `output` - quantized tensor paired with its [`Quantization`] parameters.
/// * `score_threshold` - float threshold; the implementation converts it with
///   `quantize_score_threshold` before filtering.
/// * `iou_threshold`, `nms` - NMS settings; `None` skips suppression.
/// * `output_boxes` - cleared and refilled by the implementation with at most
///   `output_boxes.capacity()` dequantized boxes.
pub fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
    output: (ArrayView2<BOX>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
) where
    f32: AsPrimitive<BOX>,
{
    impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
}
135
/// Decodes a floating-point YOLO detection tensor into `output_boxes`.
///
/// Thin wrapper: forwards every argument unchanged to [`impl_yolo_float`] with
/// the box format fixed to [`XYWH`].
///
/// * `output` - combined detection tensor; the implementation splits it with
///   `postprocess_yolo`.
/// * `score_threshold` - minimum score passed to the box filter.
/// * `iou_threshold`, `nms` - NMS settings; `None` skips suppression.
/// * `output_boxes` - cleared and refilled by the implementation with at most
///   `output_boxes.capacity()` boxes.
pub fn decode_yolo_det_float<T>(
    output: ArrayView2<T>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
) where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
}
154
/// Decodes a quantized YOLO segmentation-detection output into boxes and masks.
///
/// Thin wrapper: forwards every argument unchanged to
/// [`impl_yolo_segdet_quant`] with the box format fixed to [`XYWH`].
///
/// * `boxes` - combined quantized detection tensor and its quantization; the
///   implementation splits it into boxes, scores and mask coefficients.
/// * `protos` - quantized prototype tensor and its quantization.
/// * `score_threshold`, `iou_threshold`, `nms` - filtering and NMS settings.
/// * `output_boxes`, `output_masks` - filled by the implementation;
///   `output_boxes.capacity()` is passed on as the maximum number of boxes.
///
/// # Errors
///
/// Returns whatever [`crate::DecoderError`] the implementation reports.
pub fn decode_yolo_segdet_quant<
    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    protos: (ArrayView3<PROTO>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    f32: AsPrimitive<BOX>,
{
    impl_yolo_segdet_quant::<XYWH, _, _>(
        boxes,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        output_boxes,
        output_masks,
    )
}
191
/// Decodes a floating-point YOLO segmentation-detection output into boxes and
/// masks.
///
/// Thin wrapper: forwards every argument unchanged to
/// [`impl_yolo_segdet_float`] with the box format fixed to [`XYWH`].
///
/// * `boxes` - combined detection tensor; the implementation splits it into
///   boxes, scores and mask coefficients.
/// * `protos` - prototype tensor.
/// * `score_threshold`, `iou_threshold`, `nms` - filtering and NMS settings.
/// * `output_boxes`, `output_masks` - filled by the implementation;
///   `output_boxes.capacity()` is passed on as the maximum number of boxes.
///
/// # Errors
///
/// Returns whatever [`crate::DecoderError`] the implementation reports.
pub fn decode_yolo_segdet_float<T>(
    boxes: ArrayView2<T>,
    protos: ArrayView3<T>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    impl_yolo_segdet_float::<XYWH, _, _>(
        boxes,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        output_boxes,
        output_masks,
    )
}
226
/// Decodes quantized YOLO detections supplied as separate box and score
/// tensors.
///
/// Thin wrapper: forwards every argument unchanged to
/// [`impl_yolo_split_quant`] with the box format fixed to [`XYWH`].
///
/// * `boxes` - quantized box tensor and its quantization.
/// * `scores` - quantized score tensor and its quantization.
/// * `score_threshold` - float threshold; the implementation quantizes it with
///   the score quantization.
/// * `iou_threshold`, `nms` - NMS settings; `None` skips suppression.
/// * `output_boxes` - cleared and refilled by the implementation with at most
///   `output_boxes.capacity()` dequantized boxes.
pub fn decode_yolo_split_det_quant<
    BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    scores: (ArrayView2<SCORE>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
) where
    f32: AsPrimitive<SCORE>,
{
    impl_yolo_split_quant::<XYWH, _, _>(
        boxes,
        scores,
        score_threshold,
        iou_threshold,
        nms,
        output_boxes,
    );
}
260
/// Decodes floating-point YOLO detections supplied as separate box and score
/// tensors.
///
/// Thin wrapper: forwards every argument unchanged to
/// [`impl_yolo_split_float`] with the box format fixed to [`XYWH`].
///
/// * `boxes`, `scores` - box and score tensors; the implementation reverses
///   the axes of both before filtering.
/// * `score_threshold` - minimum score passed to the box filter.
/// * `iou_threshold`, `nms` - NMS settings; `None` skips suppression.
/// * `output_boxes` - cleared and refilled by the implementation with at most
///   `output_boxes.capacity()` boxes.
pub fn decode_yolo_split_det_float<T>(
    boxes: ArrayView2<T>,
    scores: ArrayView2<T>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
) where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    impl_yolo_split_float::<XYWH, _, _>(
        boxes,
        scores,
        score_threshold,
        iou_threshold,
        nms,
        output_boxes,
    );
}
292
/// Decodes quantized YOLO segmentation-detection output supplied as separate
/// box, score, mask-coefficient and prototype tensors.
///
/// Thin wrapper: forwards every argument unchanged to
/// [`impl_yolo_split_segdet_quant`] with the box format fixed to [`XYWH`].
///
/// * `boxes`, `scores`, `mask_coeff`, `protos` - each a quantized tensor paired
///   with its own [`Quantization`] parameters.
/// * `score_threshold`, `iou_threshold`, `nms` - filtering and NMS settings.
/// * `output_boxes`, `output_masks` - filled by the implementation;
///   `output_boxes.capacity()` is passed on as the maximum number of boxes.
///
/// # Errors
///
/// Returns whatever [`crate::DecoderError`] the implementation reports.
#[allow(clippy::too_many_arguments)]
pub fn decode_yolo_split_segdet<
    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    scores: (ArrayView2<SCORE>, Quantization),
    mask_coeff: (ArrayView2<MASK>, Quantization),
    protos: (ArrayView3<PROTO>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    f32: AsPrimitive<SCORE>,
{
    impl_yolo_split_segdet_quant::<XYWH, _, _, _, _>(
        boxes,
        scores,
        mask_coeff,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        output_boxes,
        output_masks,
    )
}
338
/// Decodes floating-point YOLO segmentation-detection output supplied as
/// separate box, score, mask-coefficient and prototype tensors.
///
/// Thin wrapper: forwards every argument unchanged to
/// [`impl_yolo_split_segdet_float`] with the box format fixed to [`XYWH`].
///
/// * `boxes`, `scores`, `mask_coeff`, `protos` - the four model output tensors.
/// * `score_threshold`, `iou_threshold`, `nms` - filtering and NMS settings.
/// * `output_boxes`, `output_masks` - filled by the implementation;
///   `output_boxes.capacity()` is passed on as the maximum number of boxes.
///
/// # Errors
///
/// Returns whatever [`crate::DecoderError`] the implementation reports.
#[allow(clippy::too_many_arguments)]
pub fn decode_yolo_split_segdet_float<T>(
    boxes: ArrayView2<T>,
    scores: ArrayView2<T>,
    mask_coeff: ArrayView2<T>,
    protos: ArrayView3<T>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    impl_yolo_split_segdet_float::<XYWH, _, _, _, _>(
        boxes,
        scores,
        mask_coeff,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        output_boxes,
        output_masks,
    )
}
380
381pub fn decode_yolo_end_to_end_det_float<T>(
396 output: ArrayView2<T>,
397 score_threshold: f32,
398 output_boxes: &mut Vec<DetectBox>,
399) -> Result<(), crate::DecoderError>
400where
401 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
402 f32: AsPrimitive<T>,
403{
404 if output.shape()[0] < 6 {
406 return Err(crate::DecoderError::InvalidShape(format!(
407 "End-to-end detection output requires at least 6 rows, got {}",
408 output.shape()[0]
409 )));
410 }
411
412 let boxes = output.slice(s![0..4, ..]).reversed_axes();
414 let scores = output.slice(s![4..5, ..]).reversed_axes();
415 let classes = output.slice(s![5, ..]);
416 let mut boxes =
417 postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
418 boxes.truncate(output_boxes.capacity());
419 output_boxes.clear();
420 for (mut b, i) in boxes.into_iter() {
421 b.label = classes[i].as_() as usize;
422 output_boxes.push(b);
423 }
424 Ok(())
426}
427
428pub fn decode_yolo_end_to_end_segdet_float<T>(
446 output: ArrayView2<T>,
447 protos: ArrayView3<T>,
448 score_threshold: f32,
449 output_boxes: &mut Vec<DetectBox>,
450 output_masks: &mut Vec<crate::Segmentation>,
451) -> Result<(), crate::DecoderError>
452where
453 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
454 f32: AsPrimitive<T>,
455{
456 let (boxes, scores, classes, mask_coeff) =
457 postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
458 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
459 boxes,
460 scores,
461 classes,
462 score_threshold,
463 output_boxes.capacity(),
464 );
465
466 impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
469}
470
471pub fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
480 boxes: ArrayView2<T>,
481 scores: ArrayView2<T>,
482 classes: ArrayView2<T>,
483 score_threshold: f32,
484 output_boxes: &mut Vec<DetectBox>,
485) -> Result<(), crate::DecoderError> {
486 let n = boxes.shape()[1];
487
488 output_boxes.clear();
489
490 let (boxes, scores, classes) = postprocess_yolo_split_end_to_end_det(boxes, scores, &classes)?;
491
492 for i in 0..n {
493 let score: f32 = scores[[i, 0]].as_();
494 if score < score_threshold {
495 continue;
496 }
497 if output_boxes.len() >= output_boxes.capacity() {
498 break;
499 }
500 output_boxes.push(DetectBox {
501 bbox: BoundingBox {
502 xmin: boxes[[i, 0]].as_(),
503 ymin: boxes[[i, 1]].as_(),
504 xmax: boxes[[i, 2]].as_(),
505 ymax: boxes[[i, 3]].as_(),
506 },
507 score,
508 label: classes[i].as_() as usize,
509 });
510 }
511 Ok(())
512}
513
514#[allow(clippy::too_many_arguments)]
523pub fn decode_yolo_split_end_to_end_segdet_float<T>(
524 boxes: ArrayView2<T>,
525 scores: ArrayView2<T>,
526 classes: ArrayView2<T>,
527 mask_coeff: ArrayView2<T>,
528 protos: ArrayView3<T>,
529 score_threshold: f32,
530 output_boxes: &mut Vec<DetectBox>,
531 output_masks: &mut Vec<crate::Segmentation>,
532) -> Result<(), crate::DecoderError>
533where
534 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
535 f32: AsPrimitive<T>,
536{
537 let (boxes, scores, classes, mask_coeff) =
538 postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
539 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
540 boxes,
541 scores,
542 classes,
543 score_threshold,
544 output_boxes.capacity(),
545 );
546
547 impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
548}
549
550#[allow(clippy::type_complexity)]
551pub(crate) fn postprocess_yolo_end_to_end_segdet<'a, T>(
552 output: &'a ArrayView2<'_, T>,
553 num_protos: usize,
554) -> Result<
555 (
556 ArrayView2<'a, T>,
557 ArrayView2<'a, T>,
558 ArrayView1<'a, T>,
559 ArrayView2<'a, T>,
560 ),
561 crate::DecoderError,
562> {
563 if output.shape()[0] < 7 {
565 return Err(crate::DecoderError::InvalidShape(format!(
566 "End-to-end segdet output requires at least 7 rows, got {}",
567 output.shape()[0]
568 )));
569 }
570
571 let num_mask_coeffs = output.shape()[0] - 6;
572 if num_mask_coeffs != num_protos {
573 return Err(crate::DecoderError::InvalidShape(format!(
574 "Mask coefficients count ({}) doesn't match protos count ({})",
575 num_mask_coeffs, num_protos
576 )));
577 }
578
579 let boxes = output.slice(s![0..4, ..]).reversed_axes();
581 let scores = output.slice(s![4..5, ..]).reversed_axes();
582 let classes = output.slice(s![5, ..]);
583 let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
584 Ok((boxes, scores, classes, mask_coeff))
585}
586
587#[allow(clippy::type_complexity)]
594pub(crate) fn postprocess_yolo_split_end_to_end_det<'a, 'b, 'c, BOXES, SCORES, CLASS>(
595 boxes: ArrayView2<'a, BOXES>,
596 scores: ArrayView2<'b, SCORES>,
597 classes: &'c ArrayView2<CLASS>,
598) -> Result<
599 (
600 ArrayView2<'a, BOXES>,
601 ArrayView2<'b, SCORES>,
602 ArrayView1<'c, CLASS>,
603 ),
604 crate::DecoderError,
605> {
606 let num_boxes = boxes.shape()[1];
607 if boxes.shape()[0] != 4 {
608 return Err(crate::DecoderError::InvalidShape(format!(
609 "Split end-to-end box_coords must be 4, got {}",
610 boxes.shape()[0]
611 )));
612 }
613
614 if scores.shape()[0] != 1 {
615 return Err(crate::DecoderError::InvalidShape(format!(
616 "Split end-to-end scores num_classes must be 1, got {}",
617 scores.shape()[0]
618 )));
619 }
620
621 if classes.shape()[0] != 1 {
622 return Err(crate::DecoderError::InvalidShape(format!(
623 "Split end-to-end classes num_classes must be 1, got {}",
624 classes.shape()[0]
625 )));
626 }
627
628 if scores.shape()[1] != num_boxes {
629 return Err(crate::DecoderError::InvalidShape(format!(
630 "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
631 num_boxes,
632 scores.shape()[1]
633 )));
634 }
635
636 if classes.shape()[1] != num_boxes {
637 return Err(crate::DecoderError::InvalidShape(format!(
638 "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
639 num_boxes,
640 classes.shape()[1]
641 )));
642 }
643
644 let boxes = boxes.reversed_axes();
645 let scores = scores.reversed_axes();
646 let classes = classes.slice(s![0, ..]);
647 Ok((boxes, scores, classes))
648}
649
650#[allow(clippy::type_complexity)]
653pub(crate) fn postprocess_yolo_split_end_to_end_segdet<
654 'a,
655 'b,
656 'c,
657 'd,
658 BOXES,
659 SCORES,
660 CLASS,
661 MASK,
662>(
663 boxes: ArrayView2<'a, BOXES>,
664 scores: ArrayView2<'b, SCORES>,
665 classes: &'c ArrayView2<CLASS>,
666 mask_coeff: ArrayView2<'d, MASK>,
667) -> Result<
668 (
669 ArrayView2<'a, BOXES>,
670 ArrayView2<'b, SCORES>,
671 ArrayView1<'c, CLASS>,
672 ArrayView2<'d, MASK>,
673 ),
674 crate::DecoderError,
675> {
676 let num_boxes = boxes.shape()[1];
677 if boxes.shape()[0] != 4 {
678 return Err(crate::DecoderError::InvalidShape(format!(
679 "Split end-to-end box_coords must be 4, got {}",
680 boxes.shape()[0]
681 )));
682 }
683
684 if scores.shape()[0] != 1 {
685 return Err(crate::DecoderError::InvalidShape(format!(
686 "Split end-to-end scores num_classes must be 1, got {}",
687 scores.shape()[0]
688 )));
689 }
690
691 if classes.shape()[0] != 1 {
692 return Err(crate::DecoderError::InvalidShape(format!(
693 "Split end-to-end classes num_classes must be 1, got {}",
694 classes.shape()[0]
695 )));
696 }
697
698 if scores.shape()[1] != num_boxes {
699 return Err(crate::DecoderError::InvalidShape(format!(
700 "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
701 num_boxes,
702 scores.shape()[1]
703 )));
704 }
705
706 if classes.shape()[1] != num_boxes {
707 return Err(crate::DecoderError::InvalidShape(format!(
708 "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
709 num_boxes,
710 classes.shape()[1]
711 )));
712 }
713
714 if mask_coeff.shape()[1] != num_boxes {
715 return Err(crate::DecoderError::InvalidShape(format!(
716 "Split end-to-end mask_coeff must have same num_boxes as boxes ({}), got {}",
717 num_boxes,
718 mask_coeff.shape()[1]
719 )));
720 }
721
722 let boxes = boxes.reversed_axes();
723 let scores = scores.reversed_axes();
724 let classes = classes.slice(s![0, ..]);
725 let mask_coeff = mask_coeff.reversed_axes();
726 Ok((boxes, scores, classes, mask_coeff))
727}
728pub(crate) fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
733 output: (ArrayView2<T>, Quantization),
734 score_threshold: f32,
735 iou_threshold: f32,
736 nms: Option<Nms>,
737 output_boxes: &mut Vec<DetectBox>,
738) where
739 f32: AsPrimitive<T>,
740{
741 let (boxes, quant_boxes) = output;
742 let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
743
744 let boxes = {
745 let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
746 postprocess_boxes_quant::<B, _, _>(
747 score_threshold,
748 boxes_tensor,
749 scores_tensor,
750 quant_boxes,
751 )
752 };
753
754 let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
755 let len = output_boxes.capacity().min(boxes.len());
756 output_boxes.clear();
757 for b in boxes.iter().take(len) {
758 output_boxes.push(dequant_detect_box(b, quant_boxes));
759 }
760}
761
762pub(crate) fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
767 output: ArrayView2<T>,
768 score_threshold: f32,
769 iou_threshold: f32,
770 nms: Option<Nms>,
771 output_boxes: &mut Vec<DetectBox>,
772) where
773 f32: AsPrimitive<T>,
774{
775 let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
776 let boxes =
777 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
778 let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
779 let len = output_boxes.capacity().min(boxes.len());
780 output_boxes.clear();
781 for b in boxes.into_iter().take(len) {
782 output_boxes.push(b);
783 }
784}
785
786pub(crate) fn impl_yolo_split_quant<
796 B: BBoxTypeTrait,
797 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
798 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
799>(
800 boxes: (ArrayView2<BOX>, Quantization),
801 scores: (ArrayView2<SCORE>, Quantization),
802 score_threshold: f32,
803 iou_threshold: f32,
804 nms: Option<Nms>,
805 output_boxes: &mut Vec<DetectBox>,
806) where
807 f32: AsPrimitive<SCORE>,
808{
809 let (boxes_tensor, quant_boxes) = boxes;
810 let (scores_tensor, quant_scores) = scores;
811
812 let boxes_tensor = boxes_tensor.reversed_axes();
813 let scores_tensor = scores_tensor.reversed_axes();
814
815 let boxes = {
816 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
817 postprocess_boxes_quant::<B, _, _>(
818 score_threshold,
819 boxes_tensor,
820 scores_tensor,
821 quant_boxes,
822 )
823 };
824
825 let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
826 let len = output_boxes.capacity().min(boxes.len());
827 output_boxes.clear();
828 for b in boxes.iter().take(len) {
829 output_boxes.push(dequant_detect_box(b, quant_scores));
830 }
831}
832
833pub(crate) fn impl_yolo_split_float<
842 B: BBoxTypeTrait,
843 BOX: Float + AsPrimitive<f32> + Send + Sync,
844 SCORE: Float + AsPrimitive<f32> + Send + Sync,
845>(
846 boxes_tensor: ArrayView2<BOX>,
847 scores_tensor: ArrayView2<SCORE>,
848 score_threshold: f32,
849 iou_threshold: f32,
850 nms: Option<Nms>,
851 output_boxes: &mut Vec<DetectBox>,
852) where
853 f32: AsPrimitive<SCORE>,
854{
855 let boxes_tensor = boxes_tensor.reversed_axes();
856 let scores_tensor = scores_tensor.reversed_axes();
857 let boxes =
858 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
859 let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
860 let len = output_boxes.capacity().min(boxes.len());
861 output_boxes.clear();
862 for b in boxes.into_iter().take(len) {
863 output_boxes.push(b);
864 }
865}
866
867pub(crate) fn impl_yolo_segdet_quant<
877 B: BBoxTypeTrait,
878 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
879 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
880>(
881 boxes: (ArrayView2<BOX>, Quantization),
882 protos: (ArrayView3<PROTO>, Quantization),
883 score_threshold: f32,
884 iou_threshold: f32,
885 nms: Option<Nms>,
886 output_boxes: &mut Vec<DetectBox>,
887 output_masks: &mut Vec<Segmentation>,
888) -> Result<(), crate::DecoderError>
889where
890 f32: AsPrimitive<BOX>,
891{
892 let (boxes, quant_boxes) = boxes;
893 let num_protos = protos.0.dim().2;
894
895 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
896 let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
897 (boxes_tensor, quant_boxes),
898 (scores_tensor, quant_boxes),
899 score_threshold,
900 iou_threshold,
901 nms,
902 output_boxes.capacity(),
903 );
904
905 impl_yolo_split_segdet_quant_process_masks::<_, _>(
906 boxes,
907 (mask_tensor, quant_boxes),
908 protos,
909 output_boxes,
910 output_masks,
911 )
912}
913
914pub(crate) fn impl_yolo_segdet_float<
924 B: BBoxTypeTrait,
925 BOX: Float + AsPrimitive<f32> + Send + Sync,
926 PROTO: Float + AsPrimitive<f32> + Send + Sync,
927>(
928 boxes: ArrayView2<BOX>,
929 protos: ArrayView3<PROTO>,
930 score_threshold: f32,
931 iou_threshold: f32,
932 nms: Option<Nms>,
933 output_boxes: &mut Vec<DetectBox>,
934 output_masks: &mut Vec<Segmentation>,
935) -> Result<(), crate::DecoderError>
936where
937 f32: AsPrimitive<BOX>,
938{
939 let num_protos = protos.dim().2;
940 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
941 let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
942 boxes_tensor,
943 scores_tensor,
944 score_threshold,
945 iou_threshold,
946 nms,
947 output_boxes.capacity(),
948 );
949 impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
950}
951
952pub(crate) fn impl_yolo_segdet_get_boxes<
953 B: BBoxTypeTrait,
954 BOX: Float + AsPrimitive<f32> + Send + Sync,
955 SCORE: Float + AsPrimitive<f32> + Send + Sync,
956>(
957 boxes_tensor: ArrayView2<BOX>,
958 scores_tensor: ArrayView2<SCORE>,
959 score_threshold: f32,
960 iou_threshold: f32,
961 nms: Option<Nms>,
962 max_boxes: usize,
963) -> Vec<(DetectBox, usize)>
964where
965 f32: AsPrimitive<SCORE>,
966{
967 let mut boxes = postprocess_boxes_index_float::<B, _, _>(
968 score_threshold.as_(),
969 boxes_tensor,
970 scores_tensor,
971 );
972 truncate_to_top_k_by_score(&mut boxes);
973 let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
974 boxes.truncate(max_boxes);
975 boxes
976}
977
978pub(crate) fn impl_yolo_end_to_end_segdet_get_boxes<
979 B: BBoxTypeTrait,
980 BOX: Float + AsPrimitive<f32> + Send + Sync,
981 SCORE: Float + AsPrimitive<f32> + Send + Sync,
982 CLASS: AsPrimitive<f32> + Send + Sync,
983>(
984 boxes: ArrayView2<BOX>,
985 scores: ArrayView2<SCORE>,
986 classes: ArrayView1<CLASS>,
987 score_threshold: f32,
988 max_boxes: usize,
989) -> Vec<(DetectBox, usize)>
990where
991 f32: AsPrimitive<SCORE>,
992{
993 let mut boxes = postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes, scores);
994 boxes.truncate(max_boxes);
995 for (b, ind) in &mut boxes {
996 b.label = classes[*ind].as_().round() as usize;
997 }
998 boxes
999}
1000
1001pub(crate) fn impl_yolo_split_segdet_process_masks<
1002 MASK: Float + AsPrimitive<f32> + Send + Sync,
1003 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1004>(
1005 boxes: Vec<(DetectBox, usize)>,
1006 masks_tensor: ArrayView2<MASK>,
1007 protos_tensor: ArrayView3<PROTO>,
1008 output_boxes: &mut Vec<DetectBox>,
1009 output_masks: &mut Vec<Segmentation>,
1010) -> Result<(), crate::DecoderError> {
1011 let boxes = decode_segdet_f32(boxes, masks_tensor, protos_tensor)?;
1012 output_boxes.clear();
1013 output_masks.clear();
1014 for (b, m) in boxes.into_iter() {
1015 output_boxes.push(b);
1016 output_masks.push(Segmentation {
1017 xmin: b.bbox.xmin,
1018 ymin: b.bbox.ymin,
1019 xmax: b.bbox.xmax,
1020 ymax: b.bbox.ymax,
1021 segmentation: m,
1022 });
1023 }
1024 Ok(())
1025}
1026pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
1030 B: BBoxTypeTrait,
1031 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1032 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1033>(
1034 boxes: (ArrayView2<BOX>, Quantization),
1035 scores: (ArrayView2<SCORE>, Quantization),
1036 score_threshold: f32,
1037 iou_threshold: f32,
1038 nms: Option<Nms>,
1039 max_boxes: usize,
1040) -> Vec<(DetectBox, usize)>
1041where
1042 f32: AsPrimitive<SCORE>,
1043{
1044 let (boxes_tensor, quant_boxes) = boxes;
1045 let (scores_tensor, quant_scores) = scores;
1046
1047 let mut boxes = {
1048 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
1049 postprocess_boxes_index_quant::<B, _, _>(
1050 score_threshold,
1051 boxes_tensor,
1052 scores_tensor,
1053 quant_boxes,
1054 )
1055 };
1056 truncate_to_top_k_by_score_quant(&mut boxes);
1057 let mut boxes = dispatch_nms_extra_int(nms, iou_threshold, boxes);
1058 boxes.truncate(max_boxes);
1059 boxes
1060 .into_iter()
1061 .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
1062 .collect()
1063}
1064
1065pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
1066 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1067 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1068>(
1069 boxes: Vec<(DetectBox, usize)>,
1070 mask_coeff: (ArrayView2<MASK>, Quantization),
1071 protos: (ArrayView3<PROTO>, Quantization),
1072 output_boxes: &mut Vec<DetectBox>,
1073 output_masks: &mut Vec<Segmentation>,
1074) -> Result<(), crate::DecoderError> {
1075 let (masks, quant_masks) = mask_coeff;
1076 let (protos, quant_protos) = protos;
1077
1078 let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
1079 output_boxes.clear();
1080 output_masks.clear();
1081 for (b, m) in boxes.into_iter() {
1082 output_boxes.push(b);
1083 output_masks.push(Segmentation {
1084 xmin: b.bbox.xmin,
1085 ymin: b.bbox.ymin,
1086 xmax: b.bbox.xmax,
1087 ymax: b.bbox.ymax,
1088 segmentation: m,
1089 });
1090 }
1091 Ok(())
1092}
1093
1094#[allow(clippy::too_many_arguments)]
1095pub(crate) fn impl_yolo_split_segdet_quant<
1107 B: BBoxTypeTrait,
1108 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1109 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1110 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1111 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1112>(
1113 boxes: (ArrayView2<BOX>, Quantization),
1114 scores: (ArrayView2<SCORE>, Quantization),
1115 mask_coeff: (ArrayView2<MASK>, Quantization),
1116 protos: (ArrayView3<PROTO>, Quantization),
1117 score_threshold: f32,
1118 iou_threshold: f32,
1119 nms: Option<Nms>,
1120 output_boxes: &mut Vec<DetectBox>,
1121 output_masks: &mut Vec<Segmentation>,
1122) -> Result<(), crate::DecoderError>
1123where
1124 f32: AsPrimitive<SCORE>,
1125{
1126 let (boxes_, scores_, mask_coeff_) =
1127 postprocess_yolo_split_segdet(boxes.0, scores.0, mask_coeff.0);
1128 let boxes = (boxes_, boxes.1);
1129 let scores = (scores_, scores.1);
1130 let mask_coeff = (mask_coeff_, mask_coeff.1);
1131
1132 let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1133 boxes,
1134 scores,
1135 score_threshold,
1136 iou_threshold,
1137 nms,
1138 output_boxes.capacity(),
1139 );
1140
1141 impl_yolo_split_segdet_quant_process_masks(
1142 boxes,
1143 mask_coeff,
1144 protos,
1145 output_boxes,
1146 output_masks,
1147 )
1148}
1149
1150#[allow(clippy::too_many_arguments)]
1151pub(crate) fn impl_yolo_split_segdet_float<
1163 B: BBoxTypeTrait,
1164 BOX: Float + AsPrimitive<f32> + Send + Sync,
1165 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1166 MASK: Float + AsPrimitive<f32> + Send + Sync,
1167 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1168>(
1169 boxes_tensor: ArrayView2<BOX>,
1170 scores_tensor: ArrayView2<SCORE>,
1171 mask_tensor: ArrayView2<MASK>,
1172 protos: ArrayView3<PROTO>,
1173 score_threshold: f32,
1174 iou_threshold: f32,
1175 nms: Option<Nms>,
1176 output_boxes: &mut Vec<DetectBox>,
1177 output_masks: &mut Vec<Segmentation>,
1178) -> Result<(), crate::DecoderError>
1179where
1180 f32: AsPrimitive<SCORE>,
1181{
1182 let (boxes_tensor, scores_tensor, mask_tensor) =
1183 postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1184
1185 let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1186 boxes_tensor,
1187 scores_tensor,
1188 score_threshold,
1189 iou_threshold,
1190 nms,
1191 output_boxes.capacity(),
1192 );
1193 impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1194}
1195
1196pub fn impl_yolo_segdet_quant_proto<
1203 B: BBoxTypeTrait,
1204 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1205 PROTO: PrimInt
1206 + AsPrimitive<i64>
1207 + AsPrimitive<i128>
1208 + AsPrimitive<f32>
1209 + AsPrimitive<i8>
1210 + Send
1211 + Sync,
1212>(
1213 boxes: (ArrayView2<BOX>, Quantization),
1214 protos: (ArrayView3<PROTO>, Quantization),
1215 score_threshold: f32,
1216 iou_threshold: f32,
1217 nms: Option<Nms>,
1218 output_boxes: &mut Vec<DetectBox>,
1219) -> ProtoData
1220where
1221 f32: AsPrimitive<BOX>,
1222{
1223 let (boxes_arr, quant_boxes) = boxes;
1224 let (protos_arr, quant_protos) = protos;
1225 let num_protos = protos_arr.dim().2;
1226
1227 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1228
1229 let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1230 (boxes_tensor, quant_boxes),
1231 (scores_tensor, quant_boxes),
1232 score_threshold,
1233 iou_threshold,
1234 nms,
1235 output_boxes.capacity(),
1236 );
1237
1238 extract_proto_data_quant(
1239 det_indices,
1240 mask_tensor,
1241 quant_boxes,
1242 protos_arr,
1243 quant_protos,
1244 output_boxes,
1245 )
1246}
1247
1248pub(crate) fn impl_yolo_segdet_float_proto<
1251 B: BBoxTypeTrait,
1252 BOX: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1253 PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1254>(
1255 boxes: ArrayView2<BOX>,
1256 protos: ArrayView3<PROTO>,
1257 score_threshold: f32,
1258 iou_threshold: f32,
1259 nms: Option<Nms>,
1260 output_boxes: &mut Vec<DetectBox>,
1261) -> ProtoData
1262where
1263 f32: AsPrimitive<BOX>,
1264{
1265 let num_protos = protos.dim().2;
1266 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1267
1268 let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1269 boxes_tensor,
1270 scores_tensor,
1271 score_threshold,
1272 iou_threshold,
1273 nms,
1274 output_boxes.capacity(),
1275 );
1276
1277 extract_proto_data_float(boxes, mask_tensor, protos, output_boxes)
1278}
1279
1280#[allow(clippy::too_many_arguments)]
1283pub(crate) fn impl_yolo_split_segdet_float_proto<
1284 B: BBoxTypeTrait,
1285 BOX: Float + AsPrimitive<f32> + Send + Sync,
1286 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1287 MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1288 PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1289>(
1290 boxes_tensor: ArrayView2<BOX>,
1291 scores_tensor: ArrayView2<SCORE>,
1292 mask_tensor: ArrayView2<MASK>,
1293 protos: ArrayView3<PROTO>,
1294 score_threshold: f32,
1295 iou_threshold: f32,
1296 nms: Option<Nms>,
1297 output_boxes: &mut Vec<DetectBox>,
1298) -> ProtoData
1299where
1300 f32: AsPrimitive<SCORE>,
1301{
1302 let (boxes_tensor, scores_tensor, mask_tensor) =
1303 postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1304 let det_indices = impl_yolo_segdet_get_boxes::<B, _, _>(
1305 boxes_tensor,
1306 scores_tensor,
1307 score_threshold,
1308 iou_threshold,
1309 nms,
1310 output_boxes.capacity(),
1311 );
1312
1313 extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1314}
1315
1316pub fn decode_yolo_end_to_end_segdet_float_proto<T>(
1318 output: ArrayView2<T>,
1319 protos: ArrayView3<T>,
1320 score_threshold: f32,
1321 output_boxes: &mut Vec<DetectBox>,
1322) -> Result<ProtoData, crate::DecoderError>
1323where
1324 T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1325 f32: AsPrimitive<T>,
1326{
1327 let (boxes, scores, classes, mask_coeff) =
1328 postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
1329 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1330 boxes,
1331 scores,
1332 classes,
1333 score_threshold,
1334 output_boxes.capacity(),
1335 );
1336
1337 Ok(extract_proto_data_float(
1338 boxes,
1339 mask_coeff,
1340 protos,
1341 output_boxes,
1342 ))
1343}
1344
1345#[allow(clippy::too_many_arguments)]
1347pub fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1348 boxes: ArrayView2<T>,
1349 scores: ArrayView2<T>,
1350 classes: ArrayView2<T>,
1351 mask_coeff: ArrayView2<T>,
1352 protos: ArrayView3<T>,
1353 score_threshold: f32,
1354 output_boxes: &mut Vec<DetectBox>,
1355) -> Result<ProtoData, crate::DecoderError>
1356where
1357 T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1358 f32: AsPrimitive<T>,
1359{
1360 let (boxes, scores, classes, mask_coeff) =
1361 postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
1362 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1363 boxes,
1364 scores,
1365 classes,
1366 score_threshold,
1367 output_boxes.capacity(),
1368 );
1369
1370 Ok(extract_proto_data_float(
1371 boxes,
1372 mask_coeff,
1373 protos,
1374 output_boxes,
1375 ))
1376}
1377
1378pub(super) fn extract_proto_data_float<
1385 MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1386 PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1387>(
1388 det_indices: Vec<(DetectBox, usize)>,
1389 mask_tensor: ArrayView2<MASK>,
1390 protos: ArrayView3<PROTO>,
1391 output_boxes: &mut Vec<DetectBox>,
1392) -> ProtoData {
1393 let num_protos = mask_tensor.ncols();
1394 let n = det_indices.len();
1395
1396 let mut coeff_rows: Vec<MASK> = Vec::with_capacity(n * num_protos);
1401 output_boxes.clear();
1402 for (det, idx) in det_indices {
1403 output_boxes.push(det);
1404 let row = mask_tensor.row(idx);
1405 coeff_rows.extend(row.iter().copied());
1406 }
1407
1408 let mask_coefficients = MASK::slice_into_tensor_dyn(&coeff_rows, &[n, num_protos])
1409 .expect("allocating mask_coefficients TensorDyn");
1410 let protos_tensor =
1411 PROTO::arrayview3_into_tensor_dyn(protos).expect("allocating protos TensorDyn");
1412
1413 ProtoData {
1414 mask_coefficients,
1415 protos: protos_tensor,
1416 }
1417}
1418
1419pub(crate) fn extract_proto_data_quant<
1428 MASK: PrimInt + AsPrimitive<f32> + Send + Sync,
1429 PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
1430>(
1431 det_indices: Vec<(DetectBox, usize)>,
1432 mask_tensor: ArrayView2<MASK>,
1433 quant_masks: Quantization,
1434 protos: ArrayView3<PROTO>,
1435 quant_protos: Quantization,
1436 output_boxes: &mut Vec<DetectBox>,
1437) -> ProtoData {
1438 use edgefirst_tensor::{Tensor, TensorDyn, TensorMapTrait, TensorMemory, TensorTrait};
1439
1440 let num_protos = mask_tensor.ncols();
1441 let n = det_indices.len();
1442 let mut coeff_f32 = Vec::<f32>::with_capacity(n * num_protos);
1443 output_boxes.clear();
1444 for (det, idx) in det_indices {
1445 output_boxes.push(det);
1446 let row = mask_tensor.row(idx);
1447 coeff_f32.extend(
1448 row.iter()
1449 .map(|v| (v.as_() - quant_masks.zero_point as f32) * quant_masks.scale),
1450 );
1451 }
1452
1453 let coeff_tensor = Tensor::<f32>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
1456 .expect("allocating mask_coefficients tensor");
1457 if n > 0 {
1458 let mut m = coeff_tensor
1459 .map()
1460 .expect("mapping mask_coefficients tensor");
1461 m.as_mut_slice().copy_from_slice(&coeff_f32);
1462 }
1463 let mask_coefficients = TensorDyn::F32(coeff_tensor);
1464
1465 let (h, w, k) = protos.dim();
1468 let protos_tensor = Tensor::<i8>::new(&[h, w, k], Some(TensorMemory::Mem), None)
1469 .expect("allocating protos tensor");
1470 {
1471 let mut m = protos_tensor.map().expect("mapping protos tensor");
1472 let dst = m.as_mut_slice();
1473 if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1474 let src: &[i8] =
1477 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1478 if protos.is_standard_layout() {
1479 dst.copy_from_slice(src);
1480 } else {
1481 for (d, s) in dst.iter_mut().zip(protos.iter()) {
1482 let v_i8: i8 = s.as_();
1483 *d = v_i8;
1484 }
1485 }
1486 } else {
1487 for (d, s) in dst.iter_mut().zip(protos.iter()) {
1488 let v_i8: i8 = s.as_();
1489 *d = v_i8;
1490 }
1491 }
1492 }
1493 let tensor_quant =
1494 edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1495 let protos_tensor = protos_tensor
1496 .with_quantization(tensor_quant)
1497 .expect("per-tensor quantization on new Tensor<i8>");
1498
1499 ProtoData {
1500 mask_coefficients,
1501 protos: TensorDyn::I8(protos_tensor),
1502 }
1503}
1504
/// Float element types that can be packed into an
/// `edgefirst_tensor::TensorDyn`.
///
/// Used as a bound on the `MASK` / `PROTO` parameters of the float proto
/// extraction path, so that code can build tensors without knowing the
/// concrete element type.
pub trait FloatProtoElem: Copy + 'static {
    /// Builds a tensor of the given `shape` from a flat slice of values.
    fn slice_into_tensor_dyn(
        values: &[Self],
        shape: &[usize],
    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;

    /// Builds a tensor from a three-dimensional array view.
    fn arrayview3_into_tensor_dyn(
        view: ArrayView3<'_, Self>,
    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;
}
1520
1521impl FloatProtoElem for f32 {
1522 fn slice_into_tensor_dyn(
1523 values: &[f32],
1524 shape: &[usize],
1525 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1526 edgefirst_tensor::Tensor::<f32>::from_slice(values, shape)
1527 .map(edgefirst_tensor::TensorDyn::F32)
1528 }
1529 fn arrayview3_into_tensor_dyn(
1530 view: ArrayView3<'_, f32>,
1531 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1532 edgefirst_tensor::Tensor::<f32>::from_arrayview3(view).map(edgefirst_tensor::TensorDyn::F32)
1533 }
1534}
1535
1536impl FloatProtoElem for half::f16 {
1537 fn slice_into_tensor_dyn(
1538 values: &[half::f16],
1539 shape: &[usize],
1540 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1541 edgefirst_tensor::Tensor::<half::f16>::from_slice(values, shape)
1542 .map(edgefirst_tensor::TensorDyn::F16)
1543 }
1544 fn arrayview3_into_tensor_dyn(
1545 view: ArrayView3<'_, half::f16>,
1546 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1547 edgefirst_tensor::Tensor::<half::f16>::from_arrayview3(view)
1548 .map(edgefirst_tensor::TensorDyn::F16)
1549 }
1550}
1551
1552impl FloatProtoElem for f64 {
1553 fn slice_into_tensor_dyn(
1554 values: &[f64],
1555 shape: &[usize],
1556 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1557 let narrowed: Vec<f32> = values.iter().map(|&v| v as f32).collect();
1559 edgefirst_tensor::Tensor::<f32>::from_slice(&narrowed, shape)
1560 .map(edgefirst_tensor::TensorDyn::F32)
1561 }
1562 fn arrayview3_into_tensor_dyn(
1563 view: ArrayView3<'_, f64>,
1564 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1565 let narrowed: ndarray::Array3<f32> = view.mapv(|v| v as f32);
1566 edgefirst_tensor::Tensor::<f32>::from_arrayview3(narrowed.view())
1567 .map(edgefirst_tensor::TensorDyn::F32)
1568 }
1569}
1570
1571fn postprocess_yolo<'a, T>(
1572 output: &'a ArrayView2<'_, T>,
1573) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1574 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1575 let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1576 (boxes_tensor, scores_tensor)
1577}
1578
1579pub(crate) fn postprocess_yolo_seg<'a, T>(
1580 output: &'a ArrayView2<'_, T>,
1581 num_protos: usize,
1582) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1583 assert!(
1584 output.shape()[0] > num_protos + 4,
1585 "Output shape is too short: {} <= {} + 4",
1586 output.shape()[0],
1587 num_protos
1588 );
1589 let num_classes = output.shape()[0] - 4 - num_protos;
1590 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1591 let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1592 let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1593 (boxes_tensor, scores_tensor, mask_tensor)
1594}
1595
1596pub(crate) fn postprocess_yolo_split_segdet<'a, 'b, 'c, BOX, SCORE, MASK>(
1597 boxes_tensor: ArrayView2<'a, BOX>,
1598 scores_tensor: ArrayView2<'b, SCORE>,
1599 mask_tensor: ArrayView2<'c, MASK>,
1600) -> (
1601 ArrayView2<'a, BOX>,
1602 ArrayView2<'b, SCORE>,
1603 ArrayView2<'c, MASK>,
1604) {
1605 let boxes_tensor = boxes_tensor.reversed_axes();
1606 let scores_tensor = scores_tensor.reversed_axes();
1607 let mask_tensor = mask_tensor.reversed_axes();
1608 (boxes_tensor, scores_tensor, mask_tensor)
1609}
1610
1611fn decode_segdet_f32<
1612 MASK: Float + AsPrimitive<f32> + Send + Sync,
1613 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1614>(
1615 boxes: Vec<(DetectBox, usize)>,
1616 masks: ArrayView2<MASK>,
1617 protos: ArrayView3<PROTO>,
1618) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1619 if boxes.is_empty() {
1620 return Ok(Vec::new());
1621 }
1622 if masks.shape()[1] != protos.shape()[2] {
1623 return Err(crate::DecoderError::InvalidShape(format!(
1624 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1625 masks.shape()[1],
1626 protos.shape()[2],
1627 )));
1628 }
1629 boxes
1630 .into_par_iter()
1631 .map(|mut b| {
1632 let ind = b.1;
1633 let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1634 b.0.bbox = roi;
1635 Ok((b.0, make_segmentation(masks.row(ind), protos.view())))
1636 })
1637 .collect()
1638}
1639
1640pub(crate) fn decode_segdet_quant<
1641 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1642 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1643>(
1644 boxes: Vec<(DetectBox, usize)>,
1645 masks: ArrayView2<MASK>,
1646 protos: ArrayView3<PROTO>,
1647 quant_masks: Quantization,
1648 quant_protos: Quantization,
1649) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1650 if boxes.is_empty() {
1651 return Ok(Vec::new());
1652 }
1653 if masks.shape()[1] != protos.shape()[2] {
1654 return Err(crate::DecoderError::InvalidShape(format!(
1655 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1656 masks.shape()[1],
1657 protos.shape()[2],
1658 )));
1659 }
1660
1661 let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; boxes
1663 .into_iter()
1664 .map(|mut b| {
1665 let i = b.1;
1666 let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1667 b.0.bbox = roi;
1668 let seg = match total_bits {
1669 0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1670 masks.row(i),
1671 protos.view(),
1672 quant_masks,
1673 quant_protos,
1674 ),
1675 65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1676 masks.row(i),
1677 protos.view(),
1678 quant_masks,
1679 quant_protos,
1680 ),
1681 _ => {
1682 return Err(crate::DecoderError::NotSupported(format!(
1683 "Unsupported bit width ({total_bits}) for segmentation computation"
1684 )));
1685 }
1686 };
1687 Ok((b.0, seg))
1688 })
1689 .collect()
1690}
1691
1692fn protobox<'a, T>(
1693 protos: &'a ArrayView3<T>,
1694 roi: &BoundingBox,
1695) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1696 let width = protos.dim().1 as f32;
1697 let height = protos.dim().0 as f32;
1698
1699 const NORM_LIMIT: f32 = 2.0;
1710 if roi.xmin > NORM_LIMIT
1711 || roi.ymin > NORM_LIMIT
1712 || roi.xmax > NORM_LIMIT
1713 || roi.ymax > NORM_LIMIT
1714 {
1715 return Err(crate::DecoderError::InvalidShape(format!(
1716 "Bounding box coordinates appear un-normalized (pixel-space). \
1717 Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
1718 ONNX models output pixel-space boxes — normalize them by dividing by \
1719 the input dimensions before calling decode().",
1720 roi.xmin, roi.ymin, roi.xmax, roi.ymax,
1721 )));
1722 }
1723
1724 let roi = [
1725 (roi.xmin * width).clamp(0.0, width) as usize,
1726 (roi.ymin * height).clamp(0.0, height) as usize,
1727 (roi.xmax * width).clamp(0.0, width).ceil() as usize,
1728 (roi.ymax * height).clamp(0.0, height).ceil() as usize,
1729 ];
1730
1731 let roi_norm = [
1732 roi[0] as f32 / width,
1733 roi[1] as f32 / height,
1734 roi[2] as f32 / width,
1735 roi[3] as f32 / height,
1736 ]
1737 .into();
1738
1739 let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
1740
1741 Ok((cropped, roi_norm))
1742}
1743
1744fn make_segmentation<
1750 MASK: Float + AsPrimitive<f32> + Send + Sync,
1751 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1752>(
1753 mask: ArrayView1<MASK>,
1754 protos: ArrayView3<PROTO>,
1755) -> Array3<u8> {
1756 let shape = protos.shape();
1757
1758 let mask = mask.to_shape((1, mask.len())).unwrap();
1760 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1761 let protos = protos.reversed_axes();
1762 let mask = mask.map(|x| x.as_());
1763 let protos = protos.map(|x| x.as_());
1764
1765 let mask = mask
1767 .dot(&protos)
1768 .into_shape_with_order((shape[0], shape[1], 1))
1769 .unwrap();
1770
1771 mask.map(|x| {
1772 let sigmoid = 1.0 / (1.0 + (-*x).exp());
1773 (sigmoid * 255.0).round() as u8
1774 })
1775}
1776
1777fn make_segmentation_quant<
1784 MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
1785 PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
1786 DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
1787>(
1788 mask: ArrayView1<MASK>,
1789 protos: ArrayView3<PROTO>,
1790 quant_masks: Quantization,
1791 quant_protos: Quantization,
1792) -> Array3<u8>
1793where
1794 i32: AsPrimitive<DEST>,
1795 f32: AsPrimitive<DEST>,
1796{
1797 let shape = protos.shape();
1798
1799 let mask = mask.to_shape((1, mask.len())).unwrap();
1801
1802 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1803 let protos = protos.reversed_axes();
1804
1805 let zp = quant_masks.zero_point.as_();
1806
1807 let mask = mask.mapv(|x| x.as_() - zp);
1808
1809 let zp = quant_protos.zero_point.as_();
1810 let protos = protos.mapv(|x| x.as_() - zp);
1811
1812 let segmentation = mask
1814 .dot(&protos)
1815 .into_shape_with_order((shape[0], shape[1], 1))
1816 .unwrap();
1817
1818 let combined_scale = quant_masks.scale * quant_protos.scale;
1819 segmentation.map(|x| {
1820 let val: f32 = (*x).as_() * combined_scale;
1821 let sigmoid = 1.0 / (1.0 + (-val).exp());
1822 (sigmoid * 255.0).round() as u8
1823 })
1824}
1825
1826pub fn yolo_segmentation_to_mask(
1838 segmentation: ArrayView3<u8>,
1839 threshold: u8,
1840) -> Result<Array2<u8>, crate::DecoderError> {
1841 if segmentation.shape()[2] != 1 {
1842 return Err(crate::DecoderError::InvalidShape(format!(
1843 "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
1844 segmentation.shape()[2]
1845 )));
1846 }
1847 Ok(segmentation
1848 .slice(s![.., .., 0])
1849 .map(|x| if *x >= threshold { 1 } else { 0 }))
1850}
1851
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Builds a feature-major `(6 + num_protos, dets.len())` end-to-end
    /// segmentation tensor. Each entry of `dets` is
    /// `[x1, y1, x2, y2, score, class]`; every mask-coefficient cell is set
    /// to `coeff`.
    fn e2e_segdet_output(dets: &[[f32; 6]], num_protos: usize, coeff: f32) -> Array2<f32> {
        let mut out = Array2::from_elem((6 + num_protos, dets.len()), coeff);
        for (col, det) in dets.iter().enumerate() {
            for (row, &value) in det.iter().enumerate() {
                out[[row, col]] = value;
            }
        }
        out
    }

    /// All-zero `(16, 16, num_protos)` prototype tensor.
    fn zero_protos(num_protos: usize) -> Array3<f32> {
        Array3::zeros((16, 16, num_protos))
    }

    // ------------------------------------------------------------------
    // End-to-end detection
    // ------------------------------------------------------------------

    #[test]
    fn test_end_to_end_det_basic_filtering() {
        // Feature-major (6, 3): one row per feature, one column per detection.
        let rows: [[f32; 3]; 6] = [
            [0.1, 0.2, 0.3], // x1
            [0.1, 0.2, 0.3], // y1
            [0.5, 0.6, 0.7], // x2
            [0.5, 0.6, 0.7], // y2
            [0.9, 0.1, 0.2], // score
            [0.0, 1.0, 2.0], // class
        ];
        let output = Array2::from_shape_vec((6, 3), rows.concat()).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        // Only the first detection clears the 0.5 threshold.
        assert_eq!(boxes.len(), 1);
        assert_eq!(boxes[0].label, 0);
        assert!((boxes[0].score - 0.9).abs() < 0.01);
        let bbox = &boxes[0].bbox;
        for (got, want) in [
            (bbox.xmin, 0.1),
            (bbox.ymin, 0.1),
            (bbox.xmax, 0.5),
            (bbox.ymax, 0.5),
        ] {
            assert!((got - want).abs() < 0.01);
        }
    }

    #[test]
    fn test_end_to_end_det_all_pass_threshold() {
        let rows: [[f32; 2]; 6] = [
            [10.0, 20.0], // x1
            [10.0, 20.0], // y1
            [50.0, 60.0], // x2
            [50.0, 60.0], // y2
            [0.8, 0.7],   // score
            [1.0, 2.0],   // class
        ];
        let output = Array2::from_shape_vec((6, 2), rows.concat()).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 2);
        assert_eq!(boxes[0].label, 1);
        assert_eq!(boxes[1].label, 2);
    }

    #[test]
    fn test_end_to_end_det_none_pass_threshold() {
        let rows: [[f32; 2]; 6] = [
            [10.0, 20.0], // x1
            [10.0, 20.0], // y1
            [50.0, 60.0], // x2
            [50.0, 60.0], // y2
            [0.1, 0.2],   // score: both below 0.5
            [1.0, 2.0],   // class
        ];
        let output = Array2::from_shape_vec((6, 2), rows.concat()).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 0);
    }

    #[test]
    fn test_end_to_end_det_capacity_limit() {
        let rows: [[f32; 5]; 6] = [
            [0.1, 0.2, 0.3, 0.4, 0.5], // x1
            [0.1, 0.2, 0.3, 0.4, 0.5], // y1
            [0.5, 0.6, 0.7, 0.8, 0.9], // x2
            [0.5, 0.6, 0.7, 0.8, 0.9], // y2
            [0.9, 0.9, 0.9, 0.9, 0.9], // score: all pass
            [0.0, 1.0, 2.0, 3.0, 4.0], // class
        ];
        let output = Array2::from_shape_vec((6, 5), rows.concat()).unwrap();

        // Five detections pass the threshold but the output vector only has
        // room for two.
        let mut boxes = Vec::with_capacity(2);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 2);
    }

    #[test]
    fn test_end_to_end_det_empty_output() {
        // Six feature rows, zero detections.
        let output = Array2::<f32>::zeros((6, 0));

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 0);
    }

    #[test]
    fn test_end_to_end_det_pixel_coordinates() {
        // Single detection with pixel-space coordinates; the assertions
        // below expect them to come back unchanged.
        let column: Vec<f32> = vec![100.0, 200.0, 300.0, 400.0, 0.95, 5.0];
        let output = Array2::from_shape_vec((6, 1), column).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 1);
        assert_eq!(boxes[0].label, 5);
        let bbox = &boxes[0].bbox;
        for (got, want) in [
            (bbox.xmin, 100.0),
            (bbox.ymin, 200.0),
            (bbox.xmax, 300.0),
            (bbox.ymax, 400.0),
        ] {
            assert!((got - want).abs() < 0.01);
        }
    }

    #[test]
    fn test_end_to_end_det_invalid_shape() {
        // Only five rows: one short of the six required features.
        let output = Array2::<f32>::zeros((5, 3));

        let mut boxes = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
        ));
    }

    // ------------------------------------------------------------------
    // End-to-end segmentation
    // ------------------------------------------------------------------

    #[test]
    fn test_end_to_end_segdet_basic() {
        let num_protos = 32;

        // Detection 0 passes the threshold (score 0.9); detection 1 does not
        // (score 0.3). All mask coefficients are 0.1.
        let output = e2e_segdet_output(
            &[
                [0.1, 0.1, 0.4, 0.4, 0.9, 1.0],
                [0.5, 0.5, 0.9, 0.9, 0.3, 2.0],
            ],
            num_protos,
            0.1,
        );
        let protos = zero_protos(num_protos);

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 1);
        assert_eq!(masks.len(), 1);
        assert_eq!(boxes[0].label, 1);
        assert!((boxes[0].score - 0.9).abs() < 0.01);
    }

    #[test]
    fn test_end_to_end_segdet_mask_coordinates() {
        let num_protos = 32;

        let output = e2e_segdet_output(&[[0.2, 0.2, 0.8, 0.8, 0.95, 3.0]], num_protos, 0.0);
        let protos = zero_protos(num_protos);

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 1);
        assert_eq!(masks.len(), 1);

        // The mask extents must agree with the reported box.
        let (mask, bbox) = (&masks[0], &boxes[0].bbox);
        assert!((mask.xmin - bbox.xmin).abs() < 0.01);
        assert!((mask.ymin - bbox.ymin).abs() < 0.01);
        assert!((mask.xmax - bbox.xmax).abs() < 0.01);
        assert!((mask.ymax - bbox.ymax).abs() < 0.01);
    }

    #[test]
    fn test_end_to_end_segdet_empty_output() {
        let num_protos = 32;
        let output = Array2::<f32>::zeros((6 + num_protos, 0));
        let protos = zero_protos(num_protos);

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 0);
        assert_eq!(masks.len(), 0);
    }

    #[test]
    fn test_end_to_end_segdet_capacity_limit() {
        let num_protos = 32;
        let num_detections = 5;

        // Five small boxes stepping along the diagonal, all scoring 0.9.
        let dets: Vec<[f32; 6]> = (0..num_detections)
            .map(|i| {
                let origin = 0.1 * (i as f32);
                [
                    origin,
                    origin,
                    0.1 * (i as f32) + 0.2,
                    0.1 * (i as f32) + 0.2,
                    0.9,
                    i as f32,
                ]
            })
            .collect();
        let output = e2e_segdet_output(&dets, num_protos, 0.0);
        let protos = zero_protos(num_protos);

        // Room for two results only.
        let mut boxes = Vec::with_capacity(2);
        let mut masks = Vec::with_capacity(2);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 2);
        assert_eq!(masks.len(), 2);
    }

    #[test]
    fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
        // Six rows leave no room for any mask coefficient.
        let output = Array2::<f32>::zeros((6, 3));
        let protos = zero_protos(32);

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        );

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
        ));
    }

    #[test]
    fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
        let num_protos = 32;
        // The output carries 16 mask coefficients but the protos have 32
        // channels.
        let output = Array2::<f32>::zeros((6 + 16, 3));
        let protos = zero_protos(num_protos);

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        );

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
        ));
    }

    // ------------------------------------------------------------------
    // Split end-to-end segmentation
    // ------------------------------------------------------------------

    #[test]
    fn test_split_end_to_end_segdet_basic() {
        let num_protos = 32;

        // Same fixture as the combined-tensor test, carved into four views.
        let output = e2e_segdet_output(
            &[
                [0.1, 0.1, 0.4, 0.4, 0.9, 1.0],
                [0.5, 0.5, 0.9, 0.9, 0.3, 2.0],
            ],
            num_protos,
            0.1,
        );
        let box_coords = output.slice(s![..4, ..]);
        let scores = output.slice(s![4..5, ..]);
        let classes = output.slice(s![5..6, ..]);
        let mask_coeff = output.slice(s![6.., ..]);
        let protos = zero_protos(num_protos);

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_split_end_to_end_segdet_float(
            box_coords,
            scores,
            classes,
            mask_coeff,
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 1);
        assert_eq!(masks.len(), 1);
        assert_eq!(boxes[0].label, 1);
        assert!((boxes[0].score - 0.9).abs() < 0.01);
    }

    // ------------------------------------------------------------------
    // yolo_segmentation_to_mask
    // ------------------------------------------------------------------

    #[test]
    fn test_segmentation_to_mask_basic() {
        let grid: [[u8; 4]; 4] = [
            [100, 200, 50, 150],
            [10, 255, 128, 64],
            [0, 127, 128, 255],
            [64, 64, 192, 192],
        ];
        let segmentation = Array3::from_shape_vec((4, 4, 1), grid.concat()).unwrap();

        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();

        // ((row, col), expected): a value equal to the threshold counts as set.
        for ((row, col), expected) in [
            ((0, 0), 0), // 100
            ((0, 1), 1), // 200
            ((0, 2), 0), // 50
            ((0, 3), 1), // 150
            ((1, 1), 1), // 255
            ((1, 2), 1), // 128
            ((2, 0), 0), // 0
            ((2, 1), 0), // 127
        ] {
            assert_eq!(mask[[row, col]], expected);
        }
    }

    #[test]
    fn test_segmentation_to_mask_all_above() {
        let segmentation = Array3::from_elem((4, 4, 1), 255u8);
        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
        assert!(mask.iter().all(|&x| x == 1));
    }

    #[test]
    fn test_segmentation_to_mask_all_below() {
        let segmentation = Array3::from_elem((4, 4, 1), 64u8);
        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
        assert!(mask.iter().all(|&x| x == 0));
    }

    #[test]
    fn test_segmentation_to_mask_invalid_shape() {
        // Three channels instead of one.
        let segmentation = Array3::from_elem((4, 4, 3), 128u8);
        let result = yolo_segmentation_to_mask(segmentation.view(), 128);

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("(H, W, 1)")
        ));
    }

    // ------------------------------------------------------------------
    // protobox
    // ------------------------------------------------------------------

    #[test]
    fn test_protobox_clamps_edge_coordinates() {
        let protos = Array3::<f32>::zeros((16, 16, 4));
        let view = protos.view();
        // A box touching the right and bottom edges exactly.
        let roi = BoundingBox {
            xmin: 0.5,
            ymin: 0.5,
            xmax: 1.0,
            ymax: 1.0,
        };
        let result = protobox(&view, &roi);
        assert!(result.is_ok(), "protobox should accept xmax=1.0");
        let (cropped, _roi_norm) = result.unwrap();
        assert!(cropped.shape()[0] > 0);
        assert!(cropped.shape()[1] > 0);
        assert_eq!(cropped.shape()[2], 4);
    }

    #[test]
    fn test_protobox_rejects_wildly_out_of_range() {
        let protos = Array3::<f32>::zeros((16, 16, 4));
        let view = protos.view();
        // 3.0 is above NORM_LIMIT, so this is treated as pixel space.
        let roi = BoundingBox {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 3.0,
            ymax: 3.0,
        };
        let result = protobox(&view, &roi);
        assert!(
            matches!(result, Err(crate::DecoderError::InvalidShape(s)) if s.contains("un-normalized")),
            "protobox should reject coords > NORM_LIMIT"
        );
    }

    #[test]
    fn test_protobox_accepts_slightly_over_one() {
        let protos = Array3::<f32>::zeros((16, 16, 4));
        let view = protos.view();
        // 1.5 is over 1.0 but under NORM_LIMIT: accepted and clamped.
        let roi = BoundingBox {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.5,
            ymax: 1.5,
        };
        let result = protobox(&view, &roi);
        assert!(
            result.is_ok(),
            "protobox should accept coords <= NORM_LIMIT (2.0)"
        );
        let (cropped, _roi_norm) = result.unwrap();
        // Clamping yields the full grid.
        assert_eq!(cropped.shape()[0], 16);
        assert_eq!(cropped.shape()[1], 16);
    }

    // ------------------------------------------------------------------
    // Proto extraction path
    // ------------------------------------------------------------------

    #[test]
    fn test_segdet_float_proto_no_panic() {
        let num_proposals = 100;
        let num_classes = 80;
        let num_mask_coeffs = 32;
        let rows = 4 + num_classes + num_mask_coeffs;

        // Every proposal is the same pixel-space XYWH box, scoring 0.9 for
        // class 0.
        let mut boxes = Array2::<f32>::zeros((rows, num_proposals));
        for (row, value) in [(0, 320.0), (1, 320.0), (2, 50.0), (3, 50.0), (4, 0.9)] {
            boxes.row_mut(row).fill(value);
        }

        let protos = Array3::<f32>::zeros((160, 160, num_mask_coeffs));

        let mut output_boxes = Vec::with_capacity(300);

        let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
            boxes.view(),
            protos.view(),
            0.5,
            0.7,
            Some(Nms::default()),
            &mut output_boxes,
        );

        assert!(!output_boxes.is_empty());
        // One coefficient row per surviving detection.
        let coeffs_shape = proto_data.mask_coefficients.shape();
        assert_eq!(coeffs_shape[0], output_boxes.len());
        assert_eq!(coeffs_shape[1], num_mask_coeffs);
    }

    // ------------------------------------------------------------------
    // Pre-NMS candidate cap
    // ------------------------------------------------------------------

    #[test]
    fn test_pre_nms_cap_truncates_excess_candidates() {
        let n: usize = 50_000;
        let num_classes = 1;

        // Identical boxes with strictly decreasing scores, all above the
        // threshold.
        let boxes = Array2::from_shape_fn((n, 4), |(_, col)| [0.1f32, 0.1, 0.5, 0.5][col]);
        let scores =
            Array2::from_shape_fn((n, num_classes), |(i, _)| 0.99 - (i as f32) * 1e-7);

        // NMS disabled and no output limit, so only the pre-NMS cap applies.
        let result = impl_yolo_segdet_get_boxes::<XYXY, _, _>(
            boxes.view(),
            scores.view(),
            0.1,
            1.0,
            None,
            usize::MAX,
        );

        assert_eq!(
            result.len(),
            crate::yolo::MAX_NMS_CANDIDATES,
            "pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
            result.len()
        );
        let top_score = result[0].0.score;
        assert!(
            top_score > 0.98,
            "highest-ranked survivor should have the largest score, got {top_score}"
        );
    }

    #[test]
    fn test_pre_nms_cap_truncates_excess_candidates_quant() {
        use crate::Quantization;
        let n: usize = 50_000;
        let num_classes = 1;

        let boxes = Array2::from_shape_fn((n, 4), |(_, col)| [10i8, 10, 50, 50][col]);
        let quant_boxes = Quantization {
            scale: 0.01,
            zero_point: 0,
        };

        // Scores cycle through 250 down to 51, all above the quantized
        // threshold.
        let scores = Array2::from_shape_fn((n, num_classes), |(i, _)| {
            250u8.saturating_sub((i % 200) as u8)
        });
        let quant_scores = Quantization {
            scale: 0.00392,
            zero_point: 0,
        };

        let result = impl_yolo_split_segdet_quant_get_boxes::<XYXY, _, _>(
            (boxes.view(), quant_boxes),
            (scores.view(), quant_scores),
            0.1,
            1.0,
            None,
            usize::MAX,
        );

        assert_eq!(
            result.len(),
            crate::yolo::MAX_NMS_CANDIDATES,
            "quant path pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
            result.len()
        );
    }

    // ------------------------------------------------------------------
    // Detection / mask-row pairing
    // ------------------------------------------------------------------

    #[test]
    fn segdet_combined_tensor_pairs_detection_with_matching_mask_row() {
        let nc = 2; // classes
        let nm = 2; // mask coefficients
        let n = 3; // anchors
        let feat = 4 + nc + nm;

        let mut output = Array2::<f32>::zeros((feat, n));
        // (row, anchor, value)
        let cells: [(usize, usize, f32); 18] = [
            // Anchor 0: box centred at (0.2, 0.2).
            (0, 0, 0.2),
            (1, 0, 0.2),
            (2, 0, 0.1),
            (3, 0, 0.1),
            // Anchor 1: box centred at (0.5, 0.5); its scores stay zero.
            (0, 1, 0.5),
            (1, 1, 0.5),
            (2, 1, 0.1),
            (3, 1, 0.1),
            // Anchor 2: box centred at (0.8, 0.8).
            (0, 2, 0.8),
            (1, 2, 0.8),
            (2, 2, 0.1),
            (3, 2, 0.1),
            // Class-0 scores for anchors 0 and 2.
            (4, 0, 0.9),
            (4, 2, 0.8),
            // Positive coefficients for anchor 0, negative for anchor 2.
            (6, 0, 3.0),
            (7, 0, 3.0),
            (6, 2, -3.0),
            (7, 2, -3.0),
        ];
        for (row, anchor, value) in cells {
            output[[row, anchor]] = value;
        }

        let protos = Array3::<f32>::from_elem((8, 8, nm), 1.0);

        let mut boxes: Vec<DetectBox> = Vec::with_capacity(10);
        let mut masks: Vec<Segmentation> = Vec::with_capacity(10);
        decode_yolo_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            0.5,
            Some(Nms::ClassAgnostic),
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(
            boxes.len(),
            2,
            "two anchors above threshold should survive (a0 score=0.9, a2 score=0.8); got {}",
            boxes.len()
        );

        // With all-ones protos the mask level depends only on the sign of
        // the coefficients, so it identifies which row was paired with
        // which detection.
        for (b, m) in boxes.iter().zip(masks.iter()) {
            let cx = (b.bbox.xmin + b.bbox.xmax) * 0.5;
            let mean = {
                let s = &m.segmentation;
                let total: u32 = s.iter().map(|&v| v as u32).sum();
                total as f32 / s.len() as f32
            };
            if cx < 0.3 {
                assert!(
                    mean > 200.0,
                    "anchor 0 detection (centre {cx:.2}) should have high-value mask; got mean {mean}"
                );
            } else if cx > 0.7 {
                assert!(
                    mean < 50.0,
                    "anchor 2 detection (centre {cx:.2}) should have low-value mask; got mean {mean}"
                );
            } else {
                panic!("unexpected detection centre {cx:.2}");
            }
        }
    }
}