1use std::fmt::Debug;
13
14use ndarray::{
15 parallel::prelude::{IntoParallelIterator, ParallelIterator},
16 s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
17};
18use num_traits::{AsPrimitive, Float, PrimInt, Signed};
19
20use crate::{
21 byte::{
22 nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
23 postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
24 },
25 configs::Nms,
26 dequant_detect_box,
27 float::{
28 nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
29 postprocess_boxes_float, postprocess_boxes_index_float,
30 },
31 BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoLayout,
32 Quantization, Segmentation, XYWH, XYXY,
33};
34
/// Candidate limit handed to the segmentation decoders as `pre_nms_top_k`
/// by the test-only wrappers in this file.
#[cfg(test)]
pub(crate) const MAX_NMS_CANDIDATES: usize = 30_000;

/// Detection cap used by `cap_or_default` when the caller's output vector
/// has no reserved capacity.
pub(crate) const DEFAULT_MAX_DETECTIONS: usize = 300;
62
63fn truncate_to_top_k_by_score<E: Send>(boxes: &mut Vec<(DetectBox, E)>, top_k: usize) {
67 if boxes.len() > top_k {
68 boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.total_cmp(&a.0.score));
69 boxes.truncate(top_k);
70 }
71}
72
73fn truncate_to_top_k_by_score_quant<S: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send>(
77 boxes: &mut Vec<(DetectBoxQuantized<S>, E)>,
78 top_k: usize,
79) {
80 if boxes.len() > top_k {
81 boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.cmp(&a.0.score));
82 boxes.truncate(top_k);
83 }
84}
85
86fn dispatch_nms_float(
93 nms: Option<Nms>,
94 iou: f32,
95 max_det: Option<usize>,
96 boxes: Vec<DetectBox>,
97) -> Vec<DetectBox> {
98 match nms {
99 Some(Nms::ClassAgnostic) => nms_float(iou, max_det, boxes),
100 Some(Nms::ClassAware) => nms_class_aware_float(iou, max_det, boxes),
101 None => boxes, }
103}
104
105pub(super) fn dispatch_nms_extra_float<E: Send + Sync>(
108 nms: Option<Nms>,
109 iou: f32,
110 max_det: Option<usize>,
111 boxes: Vec<(DetectBox, E)>,
112) -> Vec<(DetectBox, E)> {
113 match nms {
114 Some(Nms::ClassAgnostic) => nms_extra_float(iou, max_det, boxes),
115 Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, max_det, boxes),
116 None => boxes, }
118}
119
120fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
123 nms: Option<Nms>,
124 iou: f32,
125 max_det: Option<usize>,
126 boxes: Vec<DetectBoxQuantized<SCORE>>,
127) -> Vec<DetectBoxQuantized<SCORE>> {
128 match nms {
129 Some(Nms::ClassAgnostic) => nms_int(iou, max_det, boxes),
130 Some(Nms::ClassAware) => nms_class_aware_int(iou, max_det, boxes),
131 None => boxes, }
133}
134
135fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
138 nms: Option<Nms>,
139 iou: f32,
140 max_det: Option<usize>,
141 boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
142) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
143 match nms {
144 Some(Nms::ClassAgnostic) => nms_extra_int(iou, max_det, boxes),
145 Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, max_det, boxes),
146 None => boxes, }
148}
149
150#[inline]
157fn cap_or_default<T>(v: &Vec<T>) -> usize {
158 if v.capacity() > 0 {
159 v.capacity()
160 } else {
161 DEFAULT_MAX_DETECTIONS
162 }
163}
164
165pub(crate) fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
199 output: (ArrayView2<BOX>, Quantization),
200 score_threshold: f32,
201 iou_threshold: f32,
202 nms: Option<Nms>,
203 output_boxes: &mut Vec<DetectBox>,
204) where
205 f32: AsPrimitive<BOX>,
206{
207 impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
208}
209
210pub(crate) fn decode_yolo_det_float<T>(
217 output: ArrayView2<T>,
218 score_threshold: f32,
219 iou_threshold: f32,
220 nms: Option<Nms>,
221 output_boxes: &mut Vec<DetectBox>,
222) where
223 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
224 f32: AsPrimitive<T>,
225{
226 impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
227}
228
/// Test-only wrapper: decodes a quantized YOLO segmentation-detection output
/// with `XYWH` boxes.
///
/// The detection cap comes from `output_boxes` via `cap_or_default`, the
/// pre-NMS limit is `MAX_NMS_CANDIDATES`, and no normalization hints are
/// passed on.
///
/// # Errors
/// Propagates any error returned by `impl_yolo_segdet_quant`.
#[cfg(test)]
pub(crate) fn decode_yolo_segdet_quant<
    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    protos: (ArrayView3<PROTO>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    f32: AsPrimitive<BOX>,
{
    let max_det = cap_or_default(output_boxes);
    impl_yolo_segdet_quant::<XYWH, BOX, PROTO>(
        boxes,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        MAX_NMS_CANDIDATES,
        max_det,
        None, // normalized
        None, // input_dims
        output_boxes,
        output_masks,
    )
}
277
/// Test-only wrapper: decodes a floating-point YOLO segmentation-detection
/// output with `XYWH` boxes.
///
/// The detection cap comes from `output_boxes` via `cap_or_default`, the
/// pre-NMS limit is `MAX_NMS_CANDIDATES`, and no normalization hints are
/// passed on.
///
/// # Errors
/// Propagates any error returned by `impl_yolo_segdet_float`.
#[cfg(test)]
pub(crate) fn decode_yolo_segdet_float<T>(
    boxes: ArrayView2<T>,
    protos: ArrayView3<T>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    let max_det = cap_or_default(output_boxes);
    impl_yolo_segdet_float::<XYWH, T, T>(
        boxes,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        MAX_NMS_CANDIDATES,
        max_det,
        None, // normalized
        None, // input_dims
        output_boxes,
        output_masks,
    )
}
310
311pub(crate) fn decode_yolo_split_det_quant<
323 BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
324 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
325>(
326 boxes: (ArrayView2<BOX>, Quantization),
327 scores: (ArrayView2<SCORE>, Quantization),
328 score_threshold: f32,
329 iou_threshold: f32,
330 nms: Option<Nms>,
331 output_boxes: &mut Vec<DetectBox>,
332) where
333 f32: AsPrimitive<SCORE>,
334{
335 impl_yolo_split_quant::<XYWH, _, _>(
336 boxes,
337 scores,
338 score_threshold,
339 iou_threshold,
340 nms,
341 output_boxes,
342 );
343}
344
345pub(crate) fn decode_yolo_split_det_float<T>(
357 boxes: ArrayView2<T>,
358 scores: ArrayView2<T>,
359 score_threshold: f32,
360 iou_threshold: f32,
361 nms: Option<Nms>,
362 output_boxes: &mut Vec<DetectBox>,
363) where
364 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
365 f32: AsPrimitive<T>,
366{
367 impl_yolo_split_float::<XYWH, _, _>(
368 boxes,
369 scores,
370 score_threshold,
371 iou_threshold,
372 nms,
373 output_boxes,
374 );
375}
376
377pub(crate) fn decode_yolo_end_to_end_det_float<T>(
392 output: ArrayView2<T>,
393 score_threshold: f32,
394 output_boxes: &mut Vec<DetectBox>,
395) -> Result<(), crate::DecoderError>
396where
397 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
398 f32: AsPrimitive<T>,
399{
400 if output.shape()[0] < 6 {
402 return Err(crate::DecoderError::InvalidShape(format!(
403 "End-to-end detection output requires at least 6 rows, got {}",
404 output.shape()[0]
405 )));
406 }
407
408 let boxes = output.slice(s![0..4, ..]).reversed_axes();
410 let scores = output.slice(s![4..5, ..]).reversed_axes();
411 let classes = output.slice(s![5, ..]);
412 let mut boxes =
413 postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
414 boxes.truncate(cap_or_default(output_boxes));
415 output_boxes.clear();
416 for (mut b, i) in boxes.into_iter() {
417 b.label = classes[i].as_() as usize;
418 output_boxes.push(b);
419 }
420 Ok(())
422}
423
424pub(crate) fn decode_yolo_end_to_end_segdet_float<T>(
442 output: ArrayView2<T>,
443 protos: ArrayView3<T>,
444 score_threshold: f32,
445 output_boxes: &mut Vec<DetectBox>,
446 output_masks: &mut Vec<crate::Segmentation>,
447) -> Result<(), crate::DecoderError>
448where
449 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
450 f32: AsPrimitive<T>,
451{
452 let (boxes, scores, classes, mask_coeff) =
453 postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
454 let cap = cap_or_default(output_boxes);
455 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
456 boxes,
457 scores,
458 classes,
459 score_threshold,
460 cap,
461 );
462
463 impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
466}
467
468pub(crate) fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
477 boxes: ArrayView2<T>,
478 scores: ArrayView2<T>,
479 classes: ArrayView2<T>,
480 score_threshold: f32,
481 output_boxes: &mut Vec<DetectBox>,
482) -> Result<(), crate::DecoderError> {
483 let n = boxes.shape()[1];
484
485 let cap = cap_or_default(output_boxes);
486 output_boxes.clear();
487
488 let (boxes, scores, classes) = postprocess_yolo_split_end_to_end_det(boxes, scores, &classes)?;
489
490 for i in 0..n {
491 let score: f32 = scores[[i, 0]].as_();
492 if score < score_threshold {
493 continue;
494 }
495 if output_boxes.len() >= cap {
496 break;
497 }
498 output_boxes.push(DetectBox {
499 bbox: BoundingBox {
500 xmin: boxes[[i, 0]].as_(),
501 ymin: boxes[[i, 1]].as_(),
502 xmax: boxes[[i, 2]].as_(),
503 ymax: boxes[[i, 3]].as_(),
504 },
505 score,
506 label: classes[i].as_() as usize,
507 });
508 }
509 Ok(())
510}
511
512#[allow(clippy::too_many_arguments)]
521pub(crate) fn decode_yolo_split_end_to_end_segdet_float<T>(
522 boxes: ArrayView2<T>,
523 scores: ArrayView2<T>,
524 classes: ArrayView2<T>,
525 mask_coeff: ArrayView2<T>,
526 protos: ArrayView3<T>,
527 score_threshold: f32,
528 output_boxes: &mut Vec<DetectBox>,
529 output_masks: &mut Vec<crate::Segmentation>,
530) -> Result<(), crate::DecoderError>
531where
532 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
533 f32: AsPrimitive<T>,
534{
535 let (boxes, scores, classes, mask_coeff) =
536 postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
537 let cap = cap_or_default(output_boxes);
538 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
539 boxes,
540 scores,
541 classes,
542 score_threshold,
543 cap,
544 );
545
546 impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
547}
548
549#[allow(clippy::type_complexity)]
550pub(crate) fn postprocess_yolo_end_to_end_segdet<'a, T>(
551 output: &'a ArrayView2<'_, T>,
552 num_protos: usize,
553) -> Result<
554 (
555 ArrayView2<'a, T>,
556 ArrayView2<'a, T>,
557 ArrayView1<'a, T>,
558 ArrayView2<'a, T>,
559 ),
560 crate::DecoderError,
561> {
562 if output.shape()[0] < 7 {
564 return Err(crate::DecoderError::InvalidShape(format!(
565 "End-to-end segdet output requires at least 7 rows, got {}",
566 output.shape()[0]
567 )));
568 }
569
570 let num_mask_coeffs = output.shape()[0] - 6;
571 if num_mask_coeffs != num_protos {
572 return Err(crate::DecoderError::InvalidShape(format!(
573 "Mask coefficients count ({}) doesn't match protos count ({})",
574 num_mask_coeffs, num_protos
575 )));
576 }
577
578 let boxes = output.slice(s![0..4, ..]).reversed_axes();
580 let scores = output.slice(s![4..5, ..]).reversed_axes();
581 let classes = output.slice(s![5, ..]);
582 let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
583 Ok((boxes, scores, classes, mask_coeff))
584}
585
586#[allow(clippy::type_complexity)]
593pub(crate) fn postprocess_yolo_split_end_to_end_det<'a, 'b, 'c, BOXES, SCORES, CLASS>(
594 boxes: ArrayView2<'a, BOXES>,
595 scores: ArrayView2<'b, SCORES>,
596 classes: &'c ArrayView2<CLASS>,
597) -> Result<
598 (
599 ArrayView2<'a, BOXES>,
600 ArrayView2<'b, SCORES>,
601 ArrayView1<'c, CLASS>,
602 ),
603 crate::DecoderError,
604> {
605 let num_boxes = boxes.shape()[1];
606 if boxes.shape()[0] != 4 {
607 return Err(crate::DecoderError::InvalidShape(format!(
608 "Split end-to-end box_coords must be 4, got {}",
609 boxes.shape()[0]
610 )));
611 }
612
613 if scores.shape()[0] != 1 {
614 return Err(crate::DecoderError::InvalidShape(format!(
615 "Split end-to-end scores num_classes must be 1, got {}",
616 scores.shape()[0]
617 )));
618 }
619
620 if classes.shape()[0] != 1 {
621 return Err(crate::DecoderError::InvalidShape(format!(
622 "Split end-to-end classes num_classes must be 1, got {}",
623 classes.shape()[0]
624 )));
625 }
626
627 if scores.shape()[1] != num_boxes {
628 return Err(crate::DecoderError::InvalidShape(format!(
629 "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
630 num_boxes,
631 scores.shape()[1]
632 )));
633 }
634
635 if classes.shape()[1] != num_boxes {
636 return Err(crate::DecoderError::InvalidShape(format!(
637 "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
638 num_boxes,
639 classes.shape()[1]
640 )));
641 }
642
643 let boxes = boxes.reversed_axes();
644 let scores = scores.reversed_axes();
645 let classes = classes.slice(s![0, ..]);
646 Ok((boxes, scores, classes))
647}
648
649#[allow(clippy::type_complexity)]
652pub(crate) fn postprocess_yolo_split_end_to_end_segdet<
653 'a,
654 'b,
655 'c,
656 'd,
657 BOXES,
658 SCORES,
659 CLASS,
660 MASK,
661>(
662 boxes: ArrayView2<'a, BOXES>,
663 scores: ArrayView2<'b, SCORES>,
664 classes: &'c ArrayView2<CLASS>,
665 mask_coeff: ArrayView2<'d, MASK>,
666) -> Result<
667 (
668 ArrayView2<'a, BOXES>,
669 ArrayView2<'b, SCORES>,
670 ArrayView1<'c, CLASS>,
671 ArrayView2<'d, MASK>,
672 ),
673 crate::DecoderError,
674> {
675 let num_boxes = boxes.shape()[1];
676 if boxes.shape()[0] != 4 {
677 return Err(crate::DecoderError::InvalidShape(format!(
678 "Split end-to-end box_coords must be 4, got {}",
679 boxes.shape()[0]
680 )));
681 }
682
683 if scores.shape()[0] != 1 {
684 return Err(crate::DecoderError::InvalidShape(format!(
685 "Split end-to-end scores num_classes must be 1, got {}",
686 scores.shape()[0]
687 )));
688 }
689
690 if classes.shape()[0] != 1 {
691 return Err(crate::DecoderError::InvalidShape(format!(
692 "Split end-to-end classes num_classes must be 1, got {}",
693 classes.shape()[0]
694 )));
695 }
696
697 if scores.shape()[1] != num_boxes {
698 return Err(crate::DecoderError::InvalidShape(format!(
699 "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
700 num_boxes,
701 scores.shape()[1]
702 )));
703 }
704
705 if classes.shape()[1] != num_boxes {
706 return Err(crate::DecoderError::InvalidShape(format!(
707 "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
708 num_boxes,
709 classes.shape()[1]
710 )));
711 }
712
713 if mask_coeff.shape()[1] != num_boxes {
714 return Err(crate::DecoderError::InvalidShape(format!(
715 "Split end-to-end mask_coeff must have same num_boxes as boxes ({}), got {}",
716 num_boxes,
717 mask_coeff.shape()[1]
718 )));
719 }
720
721 let boxes = boxes.reversed_axes();
722 let scores = scores.reversed_axes();
723 let classes = classes.slice(s![0, ..]);
724 let mask_coeff = mask_coeff.reversed_axes();
725 Ok((boxes, scores, classes, mask_coeff))
726}
727pub(crate) fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
732 output: (ArrayView2<T>, Quantization),
733 score_threshold: f32,
734 iou_threshold: f32,
735 nms: Option<Nms>,
736 output_boxes: &mut Vec<DetectBox>,
737) where
738 f32: AsPrimitive<T>,
739{
740 let _span = tracing::trace_span!("decode", mode = "quant_det").entered();
741 let (boxes, quant_boxes) = output;
742 let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
743
744 let boxes = {
745 let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
746 postprocess_boxes_quant::<B, _, _>(
747 score_threshold,
748 boxes_tensor,
749 scores_tensor,
750 quant_boxes,
751 )
752 };
753
754 let cap = cap_or_default(output_boxes);
755 let boxes = dispatch_nms_int(nms, iou_threshold, Some(cap), boxes);
756 let len = cap.min(boxes.len());
759 output_boxes.clear();
760 for b in boxes.iter().take(len) {
761 output_boxes.push(dequant_detect_box(b, quant_boxes));
762 }
763}
764
765pub(crate) fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
770 output: ArrayView2<T>,
771 score_threshold: f32,
772 iou_threshold: f32,
773 nms: Option<Nms>,
774 output_boxes: &mut Vec<DetectBox>,
775) where
776 f32: AsPrimitive<T>,
777{
778 let _span = tracing::trace_span!("decode", mode = "float_det").entered();
779 let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
780 let boxes =
781 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
782 let cap = cap_or_default(output_boxes);
783 let boxes = dispatch_nms_float(nms, iou_threshold, Some(cap), boxes);
784 let len = cap.min(boxes.len());
787 output_boxes.clear();
788 for b in boxes.into_iter().take(len) {
789 output_boxes.push(b);
790 }
791}
792
793pub(crate) fn impl_yolo_split_quant<
803 B: BBoxTypeTrait,
804 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
805 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
806>(
807 boxes: (ArrayView2<BOX>, Quantization),
808 scores: (ArrayView2<SCORE>, Quantization),
809 score_threshold: f32,
810 iou_threshold: f32,
811 nms: Option<Nms>,
812 output_boxes: &mut Vec<DetectBox>,
813) where
814 f32: AsPrimitive<SCORE>,
815{
816 let _span = tracing::trace_span!("decode", mode = "split_quant_det").entered();
817 let (boxes_tensor, quant_boxes) = boxes;
818 let (scores_tensor, quant_scores) = scores;
819
820 let boxes_tensor = boxes_tensor.reversed_axes();
821 let scores_tensor = scores_tensor.reversed_axes();
822
823 let boxes = {
824 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
825 postprocess_boxes_quant::<B, _, _>(
826 score_threshold,
827 boxes_tensor,
828 scores_tensor,
829 quant_boxes,
830 )
831 };
832
833 let cap = cap_or_default(output_boxes);
834 let boxes = dispatch_nms_int(nms, iou_threshold, Some(cap), boxes);
835 let len = cap.min(boxes.len());
838 output_boxes.clear();
839 for b in boxes.iter().take(len) {
840 output_boxes.push(dequant_detect_box(b, quant_scores));
841 }
842}
843
844pub(crate) fn impl_yolo_split_float<
853 B: BBoxTypeTrait,
854 BOX: Float + AsPrimitive<f32> + Send + Sync,
855 SCORE: Float + AsPrimitive<f32> + Send + Sync,
856>(
857 boxes_tensor: ArrayView2<BOX>,
858 scores_tensor: ArrayView2<SCORE>,
859 score_threshold: f32,
860 iou_threshold: f32,
861 nms: Option<Nms>,
862 output_boxes: &mut Vec<DetectBox>,
863) where
864 f32: AsPrimitive<SCORE>,
865{
866 let _span = tracing::trace_span!("decode", mode = "split_float_det").entered();
867 let boxes_tensor = boxes_tensor.reversed_axes();
868 let scores_tensor = scores_tensor.reversed_axes();
869 let boxes =
870 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
871 let cap = cap_or_default(output_boxes);
872 let boxes = dispatch_nms_float(nms, iou_threshold, Some(cap), boxes);
873 let len = cap.min(boxes.len());
876 output_boxes.clear();
877 for b in boxes.into_iter().take(len) {
878 output_boxes.push(b);
879 }
880}
881
882#[inline]
890pub(crate) fn maybe_normalize_boxes_in_place(
891 boxes: &mut [(DetectBox, usize)],
892 normalized: Option<bool>,
893 input_dims: Option<(usize, usize)>,
894) {
895 if normalized != Some(false) {
896 return;
897 }
898 let Some((w, h)) = input_dims else {
899 return;
900 };
901 if w == 0 || h == 0 {
902 return;
903 }
904 let inv_w = 1.0 / w as f32;
905 let inv_h = 1.0 / h as f32;
906 for (b, _) in boxes.iter_mut() {
907 b.bbox.xmin *= inv_w;
908 b.bbox.ymin *= inv_h;
909 b.bbox.xmax *= inv_w;
910 b.bbox.ymax *= inv_h;
911 }
912}
913
914#[allow(clippy::too_many_arguments)]
924pub(crate) fn impl_yolo_segdet_quant<
925 B: BBoxTypeTrait,
926 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
927 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
928>(
929 boxes: (ArrayView2<BOX>, Quantization),
930 protos: (ArrayView3<PROTO>, Quantization),
931 score_threshold: f32,
932 iou_threshold: f32,
933 nms: Option<Nms>,
934 pre_nms_top_k: usize,
935 max_det: usize,
936 normalized: Option<bool>,
937 input_dims: Option<(usize, usize)>,
938 output_boxes: &mut Vec<DetectBox>,
939 output_masks: &mut Vec<Segmentation>,
940) -> Result<(), crate::DecoderError>
941where
942 f32: AsPrimitive<BOX>,
943{
944 let (boxes, quant_boxes) = boxes;
945 let num_protos = protos.0.dim().2;
946
947 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
948 let mut boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
949 (boxes_tensor, quant_boxes),
950 (scores_tensor, quant_boxes),
951 score_threshold,
952 iou_threshold,
953 nms,
954 pre_nms_top_k,
955 max_det,
956 );
957 maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
958
959 impl_yolo_split_segdet_quant_process_masks::<_, _>(
960 boxes,
961 (mask_tensor, quant_boxes),
962 protos,
963 output_boxes,
964 output_masks,
965 )
966}
967
968#[allow(clippy::too_many_arguments)]
978pub(crate) fn impl_yolo_segdet_float<
979 B: BBoxTypeTrait,
980 BOX: Float + AsPrimitive<f32> + Send + Sync,
981 PROTO: Float + AsPrimitive<f32> + Send + Sync,
982>(
983 boxes: ArrayView2<BOX>,
984 protos: ArrayView3<PROTO>,
985 score_threshold: f32,
986 iou_threshold: f32,
987 nms: Option<Nms>,
988 pre_nms_top_k: usize,
989 max_det: usize,
990 normalized: Option<bool>,
991 input_dims: Option<(usize, usize)>,
992 output_boxes: &mut Vec<DetectBox>,
993 output_masks: &mut Vec<Segmentation>,
994) -> Result<(), crate::DecoderError>
995where
996 f32: AsPrimitive<BOX>,
997{
998 let num_protos = protos.dim().2;
999 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1000 let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1001 boxes_tensor,
1002 scores_tensor,
1003 score_threshold,
1004 iou_threshold,
1005 nms,
1006 pre_nms_top_k,
1007 max_det,
1008 );
1009 maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1010 impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1011}
1012
1013pub(crate) fn impl_yolo_segdet_get_boxes<
1014 B: BBoxTypeTrait,
1015 BOX: Float + AsPrimitive<f32> + Send + Sync,
1016 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1017>(
1018 boxes_tensor: ArrayView2<BOX>,
1019 scores_tensor: ArrayView2<SCORE>,
1020 score_threshold: f32,
1021 iou_threshold: f32,
1022 nms: Option<Nms>,
1023 pre_nms_top_k: usize,
1024 max_det: usize,
1025) -> Vec<(DetectBox, usize)>
1026where
1027 f32: AsPrimitive<SCORE>,
1028{
1029 let span = tracing::trace_span!(
1030 "decode",
1031 n_candidates = tracing::field::Empty,
1032 n_after_topk = tracing::field::Empty,
1033 n_after_nms = tracing::field::Empty,
1034 n_detections = tracing::field::Empty,
1035 );
1036 let _guard = span.enter();
1037
1038 let mut boxes = {
1039 let _s = tracing::trace_span!("score_filter").entered();
1040 postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor)
1041 };
1042 span.record("n_candidates", boxes.len());
1043
1044 if nms.is_some() {
1045 let _s = tracing::trace_span!("top_k", k = pre_nms_top_k).entered();
1046 truncate_to_top_k_by_score(&mut boxes, pre_nms_top_k);
1047 }
1048 span.record("n_after_topk", boxes.len());
1049
1050 let mut boxes = {
1051 let _s = tracing::trace_span!("nms").entered();
1052 dispatch_nms_extra_float(nms, iou_threshold, Some(max_det), boxes)
1053 };
1054 span.record("n_after_nms", boxes.len());
1055
1056 boxes.sort_unstable_by(|a, b| b.0.score.total_cmp(&a.0.score));
1059 boxes.truncate(max_det);
1060 span.record("n_detections", boxes.len());
1061
1062 boxes
1063}
1064
1065pub(crate) fn impl_yolo_end_to_end_segdet_get_boxes<
1066 B: BBoxTypeTrait,
1067 BOX: Float + AsPrimitive<f32> + Send + Sync,
1068 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1069 CLASS: AsPrimitive<f32> + Send + Sync,
1070>(
1071 boxes: ArrayView2<BOX>,
1072 scores: ArrayView2<SCORE>,
1073 classes: ArrayView1<CLASS>,
1074 score_threshold: f32,
1075 max_boxes: usize,
1076) -> Vec<(DetectBox, usize)>
1077where
1078 f32: AsPrimitive<SCORE>,
1079{
1080 let mut boxes = postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes, scores);
1081 boxes.truncate(max_boxes);
1082 for (b, ind) in &mut boxes {
1083 b.label = classes[*ind].as_().round() as usize;
1084 }
1085 boxes
1086}
1087
1088pub(crate) fn impl_yolo_split_segdet_process_masks<
1089 MASK: Float + AsPrimitive<f32> + Send + Sync,
1090 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1091>(
1092 boxes: Vec<(DetectBox, usize)>,
1093 masks_tensor: ArrayView2<MASK>,
1094 protos_tensor: ArrayView3<PROTO>,
1095 output_boxes: &mut Vec<DetectBox>,
1096 output_masks: &mut Vec<Segmentation>,
1097) -> Result<(), crate::DecoderError> {
1098 let _span = tracing::trace_span!("process_masks", n = boxes.len(), mode = "float").entered();
1099 let boxes = decode_segdet_f32(boxes, masks_tensor, protos_tensor)?;
1103 output_boxes.clear();
1104 output_masks.clear();
1105 for (b, roi, m) in boxes.into_iter() {
1106 output_boxes.push(b);
1107 output_masks.push(Segmentation {
1108 xmin: roi.xmin,
1109 ymin: roi.ymin,
1110 xmax: roi.xmax,
1111 ymax: roi.ymax,
1112 segmentation: m,
1113 });
1114 }
1115 Ok(())
1116}
1117pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
1121 B: BBoxTypeTrait,
1122 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1123 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1124>(
1125 boxes: (ArrayView2<BOX>, Quantization),
1126 scores: (ArrayView2<SCORE>, Quantization),
1127 score_threshold: f32,
1128 iou_threshold: f32,
1129 nms: Option<Nms>,
1130 pre_nms_top_k: usize,
1131 max_det: usize,
1132) -> Vec<(DetectBox, usize)>
1133where
1134 f32: AsPrimitive<SCORE>,
1135{
1136 let (boxes_tensor, quant_boxes) = boxes;
1137 let (scores_tensor, quant_scores) = scores;
1138
1139 let span = tracing::trace_span!(
1140 "decode",
1141 n_candidates = tracing::field::Empty,
1142 n_after_topk = tracing::field::Empty,
1143 n_after_nms = tracing::field::Empty,
1144 n_detections = tracing::field::Empty,
1145 );
1146 let _guard = span.enter();
1147
1148 let mut boxes = {
1149 let _s = tracing::trace_span!("score_filter").entered();
1150 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
1151 postprocess_boxes_index_quant::<B, _, _>(
1152 score_threshold,
1153 boxes_tensor,
1154 scores_tensor,
1155 quant_boxes,
1156 )
1157 };
1158 span.record("n_candidates", boxes.len());
1159
1160 if nms.is_some() {
1161 let _s = tracing::trace_span!("top_k", k = pre_nms_top_k).entered();
1162 truncate_to_top_k_by_score_quant(&mut boxes, pre_nms_top_k);
1163 }
1164 span.record("n_after_topk", boxes.len());
1165
1166 let mut boxes = {
1167 let _s = tracing::trace_span!("nms").entered();
1168 dispatch_nms_extra_int(nms, iou_threshold, Some(max_det), boxes)
1169 };
1170 span.record("n_after_nms", boxes.len());
1171
1172 boxes.sort_unstable_by(|a, b| b.0.score.cmp(&a.0.score));
1175 boxes.truncate(max_det);
1176 let result: Vec<_> = {
1177 let _s = tracing::trace_span!("box_dequant", n = boxes.len()).entered();
1178 boxes
1179 .into_iter()
1180 .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
1181 .collect()
1182 };
1183 span.record("n_detections", result.len());
1184
1185 result
1186}
1187
1188pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
1189 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1190 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1191>(
1192 boxes: Vec<(DetectBox, usize)>,
1193 mask_coeff: (ArrayView2<MASK>, Quantization),
1194 protos: (ArrayView3<PROTO>, Quantization),
1195 output_boxes: &mut Vec<DetectBox>,
1196 output_masks: &mut Vec<Segmentation>,
1197) -> Result<(), crate::DecoderError> {
1198 let _span = tracing::trace_span!("process_masks", n = boxes.len(), mode = "quant").entered();
1199 let (masks, quant_masks) = mask_coeff;
1200 let (protos, quant_protos) = protos;
1201
1202 let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
1206 output_boxes.clear();
1207 output_masks.clear();
1208 for (b, roi, m) in boxes.into_iter() {
1209 output_boxes.push(b);
1210 output_masks.push(Segmentation {
1211 xmin: roi.xmin,
1212 ymin: roi.ymin,
1213 xmax: roi.xmax,
1214 ymax: roi.ymax,
1215 segmentation: m,
1216 });
1217 }
1218 Ok(())
1219}
1220
1221#[allow(clippy::too_many_arguments)]
1233pub(crate) fn impl_yolo_split_segdet_float<
1234 B: BBoxTypeTrait,
1235 BOX: Float + AsPrimitive<f32> + Send + Sync,
1236 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1237 MASK: Float + AsPrimitive<f32> + Send + Sync,
1238 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1239>(
1240 boxes_tensor: ArrayView2<BOX>,
1241 scores_tensor: ArrayView2<SCORE>,
1242 mask_tensor: ArrayView2<MASK>,
1243 protos: ArrayView3<PROTO>,
1244 score_threshold: f32,
1245 iou_threshold: f32,
1246 nms: Option<Nms>,
1247 pre_nms_top_k: usize,
1248 max_det: usize,
1249 normalized: Option<bool>,
1250 input_dims: Option<(usize, usize)>,
1251 output_boxes: &mut Vec<DetectBox>,
1252 output_masks: &mut Vec<Segmentation>,
1253) -> Result<(), crate::DecoderError>
1254where
1255 f32: AsPrimitive<SCORE>,
1256{
1257 let (boxes_tensor, scores_tensor, mask_tensor) =
1258 postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1259
1260 let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1261 boxes_tensor,
1262 scores_tensor,
1263 score_threshold,
1264 iou_threshold,
1265 nms,
1266 pre_nms_top_k,
1267 max_det,
1268 );
1269 maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1270 impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1271}
1272
1273#[allow(clippy::too_many_arguments)]
1280pub(crate) fn impl_yolo_segdet_quant_proto<
1281 B: BBoxTypeTrait,
1282 BOX: PrimInt
1283 + AsPrimitive<i64>
1284 + AsPrimitive<i128>
1285 + AsPrimitive<f32>
1286 + AsPrimitive<i8>
1287 + Send
1288 + Sync,
1289 PROTO: PrimInt
1290 + AsPrimitive<i64>
1291 + AsPrimitive<i128>
1292 + AsPrimitive<f32>
1293 + AsPrimitive<i8>
1294 + Send
1295 + Sync,
1296>(
1297 boxes: (ArrayView2<BOX>, Quantization),
1298 protos: (ArrayView3<PROTO>, Quantization),
1299 score_threshold: f32,
1300 iou_threshold: f32,
1301 nms: Option<Nms>,
1302 pre_nms_top_k: usize,
1303 max_det: usize,
1304 normalized: Option<bool>,
1305 input_dims: Option<(usize, usize)>,
1306 output_boxes: &mut Vec<DetectBox>,
1307) -> ProtoData
1308where
1309 f32: AsPrimitive<BOX>,
1310{
1311 let (boxes_arr, quant_boxes) = boxes;
1312 let (protos_arr, quant_protos) = protos;
1313 let num_protos = protos_arr.dim().2;
1314
1315 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1316
1317 let mut det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1318 (boxes_tensor, quant_boxes),
1319 (scores_tensor, quant_boxes),
1320 score_threshold,
1321 iou_threshold,
1322 nms,
1323 pre_nms_top_k,
1324 max_det,
1325 );
1326 maybe_normalize_boxes_in_place(&mut det_indices, normalized, input_dims);
1327
1328 extract_proto_data_quant(
1329 det_indices,
1330 mask_tensor,
1331 quant_boxes,
1332 protos_arr,
1333 quant_protos,
1334 output_boxes,
1335 )
1336}
1337
1338#[allow(clippy::too_many_arguments)]
1341pub(crate) fn impl_yolo_segdet_float_proto<
1342 B: BBoxTypeTrait,
1343 BOX: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1344 PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1345>(
1346 boxes: ArrayView2<BOX>,
1347 protos: ArrayView3<PROTO>,
1348 score_threshold: f32,
1349 iou_threshold: f32,
1350 nms: Option<Nms>,
1351 pre_nms_top_k: usize,
1352 max_det: usize,
1353 normalized: Option<bool>,
1354 input_dims: Option<(usize, usize)>,
1355 output_boxes: &mut Vec<DetectBox>,
1356) -> ProtoData
1357where
1358 f32: AsPrimitive<BOX>,
1359{
1360 let num_protos = protos.dim().2;
1361 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1362
1363 let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1364 boxes_tensor,
1365 scores_tensor,
1366 score_threshold,
1367 iou_threshold,
1368 nms,
1369 pre_nms_top_k,
1370 max_det,
1371 );
1372 maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1373
1374 extract_proto_data_float(boxes, mask_tensor, protos, output_boxes)
1375}
1376
1377#[allow(clippy::too_many_arguments)]
1380pub(crate) fn impl_yolo_split_segdet_float_proto<
1381 B: BBoxTypeTrait,
1382 BOX: Float + AsPrimitive<f32> + Send + Sync,
1383 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1384 MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1385 PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1386>(
1387 boxes_tensor: ArrayView2<BOX>,
1388 scores_tensor: ArrayView2<SCORE>,
1389 mask_tensor: ArrayView2<MASK>,
1390 protos: ArrayView3<PROTO>,
1391 score_threshold: f32,
1392 iou_threshold: f32,
1393 nms: Option<Nms>,
1394 pre_nms_top_k: usize,
1395 max_det: usize,
1396 output_boxes: &mut Vec<DetectBox>,
1397) -> ProtoData
1398where
1399 f32: AsPrimitive<SCORE>,
1400{
1401 let (boxes_tensor, scores_tensor, mask_tensor) =
1402 postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1403 let det_indices = impl_yolo_segdet_get_boxes::<B, _, _>(
1404 boxes_tensor,
1405 scores_tensor,
1406 score_threshold,
1407 iou_threshold,
1408 nms,
1409 pre_nms_top_k,
1410 max_det,
1411 );
1412
1413 extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1414}
1415
1416pub(crate) fn decode_yolo_end_to_end_segdet_float_proto<T>(
1418 output: ArrayView2<T>,
1419 protos: ArrayView3<T>,
1420 score_threshold: f32,
1421 output_boxes: &mut Vec<DetectBox>,
1422) -> Result<ProtoData, crate::DecoderError>
1423where
1424 T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1425 f32: AsPrimitive<T>,
1426{
1427 let (boxes, scores, classes, mask_coeff) =
1428 postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
1429 let cap = cap_or_default(output_boxes);
1430 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1431 boxes,
1432 scores,
1433 classes,
1434 score_threshold,
1435 cap,
1436 );
1437
1438 Ok(extract_proto_data_float(
1439 boxes,
1440 mask_coeff,
1441 protos,
1442 output_boxes,
1443 ))
1444}
1445
1446#[allow(clippy::too_many_arguments)]
1448pub(crate) fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1449 boxes: ArrayView2<T>,
1450 scores: ArrayView2<T>,
1451 classes: ArrayView2<T>,
1452 mask_coeff: ArrayView2<T>,
1453 protos: ArrayView3<T>,
1454 score_threshold: f32,
1455 output_boxes: &mut Vec<DetectBox>,
1456) -> Result<ProtoData, crate::DecoderError>
1457where
1458 T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1459 f32: AsPrimitive<T>,
1460{
1461 let (boxes, scores, classes, mask_coeff) =
1462 postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
1463 let cap = cap_or_default(output_boxes);
1464 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1465 boxes,
1466 scores,
1467 classes,
1468 score_threshold,
1469 cap,
1470 );
1471
1472 Ok(extract_proto_data_float(
1473 boxes,
1474 mask_coeff,
1475 protos,
1476 output_boxes,
1477 ))
1478}
1479
1480pub(super) fn extract_proto_data_float<
1487 MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1488 PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1489>(
1490 det_indices: Vec<(DetectBox, usize)>,
1491 mask_tensor: ArrayView2<MASK>,
1492 protos: ArrayView3<PROTO>,
1493 output_boxes: &mut Vec<DetectBox>,
1494) -> ProtoData {
1495 let _span = tracing::trace_span!(
1496 "extract_proto",
1497 n = det_indices.len(),
1498 num_protos = mask_tensor.ncols(),
1499 layout = "nhwc",
1500 )
1501 .entered();
1502
1503 let num_protos = mask_tensor.ncols();
1504 let n = det_indices.len();
1505
1506 let mut coeff_rows: Vec<MASK> = Vec::with_capacity(n * num_protos);
1511 output_boxes.clear();
1512 for (det, idx) in det_indices {
1513 output_boxes.push(det);
1514 let row = mask_tensor.row(idx);
1515 coeff_rows.extend(row.iter().copied());
1516 }
1517
1518 let mask_coefficients = MASK::slice_into_tensor_dyn(&coeff_rows, &[n, num_protos])
1519 .expect("allocating mask_coefficients TensorDyn");
1520 let protos_tensor =
1521 PROTO::arrayview3_into_tensor_dyn(protos).expect("allocating protos TensorDyn");
1522
1523 ProtoData {
1524 mask_coefficients,
1525 protos: protos_tensor,
1526 layout: ProtoLayout::Nhwc,
1527 }
1528}
1529
1530pub(crate) fn extract_proto_data_quant<
1539 MASK: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync,
1540 PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
1541>(
1542 det_indices: Vec<(DetectBox, usize)>,
1543 mask_tensor: ArrayView2<MASK>,
1544 quant_masks: Quantization,
1545 protos: ArrayView3<PROTO>,
1546 quant_protos: Quantization,
1547 output_boxes: &mut Vec<DetectBox>,
1548) -> ProtoData {
1549 use edgefirst_tensor::{Tensor, TensorDyn, TensorMapTrait, TensorMemory, TensorTrait};
1550
1551 let span = tracing::trace_span!(
1552 "extract_proto",
1553 n = det_indices.len(),
1554 num_protos = tracing::field::Empty,
1555 layout = tracing::field::Empty,
1556 );
1557 let _guard = span.enter();
1558
1559 let num_protos = mask_tensor.ncols();
1560 let n = det_indices.len();
1561 span.record("num_protos", num_protos);
1562
1563 if n == 0 {
1569 output_boxes.clear();
1570 let (h, w, k) = protos.dim();
1571
1572 let (proto_shape, proto_layout) = if std::any::TypeId::of::<PROTO>()
1574 == std::any::TypeId::of::<i8>()
1575 {
1576 if protos.is_standard_layout() {
1577 (&[h, w, k][..], ProtoLayout::Nhwc)
1578 } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1579 (&[k, h, w][..], ProtoLayout::Nchw)
1580 } else {
1581 (&[h, w, k][..], ProtoLayout::Nhwc)
1582 }
1583 } else {
1584 (&[h, w, k][..], ProtoLayout::Nhwc)
1585 };
1586
1587 let coeff_tensor = Tensor::<i8>::new(&[0, num_protos], Some(TensorMemory::Mem), None)
1588 .expect("allocating empty mask_coefficients tensor");
1589 let coeff_quant =
1590 edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1591 let coeff_tensor = coeff_tensor
1592 .with_quantization(coeff_quant)
1593 .expect("per-tensor quantization on mask coefficients");
1594 let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
1595 .expect("allocating protos tensor");
1596 let tensor_quant =
1597 edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1598 let protos_tensor = protos_tensor
1599 .with_quantization(tensor_quant)
1600 .expect("per-tensor quantization on protos tensor");
1601 return ProtoData {
1602 mask_coefficients: TensorDyn::I8(coeff_tensor),
1603 protos: TensorDyn::I8(protos_tensor),
1604 layout: proto_layout,
1605 };
1606 }
1607
1608 let mut coeff_i8 = Vec::<i8>::with_capacity(n * num_protos);
1612 output_boxes.clear();
1613 for (det, idx) in det_indices {
1614 output_boxes.push(det);
1615 let row = mask_tensor.row(idx);
1616 coeff_i8.extend(row.iter().map(|v| {
1617 let v_i8: i8 = v.as_();
1618 v_i8
1619 }));
1620 }
1621
1622 let coeff_tensor = Tensor::<i8>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
1625 .expect("allocating mask_coefficients tensor");
1626 if n > 0 {
1627 let mut m = coeff_tensor
1628 .map()
1629 .expect("mapping mask_coefficients tensor");
1630 m.as_mut_slice().copy_from_slice(&coeff_i8);
1631 }
1632 let coeff_quant =
1633 edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1634 let coeff_tensor = coeff_tensor
1635 .with_quantization(coeff_quant)
1636 .expect("per-tensor quantization on mask coefficients");
1637 let mask_coefficients = TensorDyn::I8(coeff_tensor);
1638
1639 let (h, w, k) = protos.dim();
1643
1644 let (proto_shape, proto_layout) =
1646 if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1647 if protos.is_standard_layout() {
1648 (&[h, w, k][..], ProtoLayout::Nhwc)
1650 } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1651 (&[k, h, w][..], ProtoLayout::Nchw)
1655 } else {
1656 (&[h, w, k][..], ProtoLayout::Nhwc)
1658 }
1659 } else {
1660 (&[h, w, k][..], ProtoLayout::Nhwc)
1661 };
1662
1663 let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
1664 .expect("allocating protos tensor");
1665 {
1666 let mut m = protos_tensor.map().expect("mapping protos tensor");
1667 let dst = m.as_mut_slice();
1668 if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1669 if protos.is_standard_layout() {
1672 let src: &[i8] = unsafe {
1673 std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len())
1674 };
1675 dst.copy_from_slice(src);
1676 } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1677 let total = h * w * k;
1681 let src: &[i8] =
1684 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, total) };
1685 dst.copy_from_slice(src);
1686 } else {
1687 for (d, s) in dst.iter_mut().zip(protos.iter()) {
1688 let v_i8: i8 = s.as_();
1689 *d = v_i8;
1690 }
1691 }
1692 } else {
1693 for (d, s) in dst.iter_mut().zip(protos.iter()) {
1694 let v_i8: i8 = s.as_();
1695 *d = v_i8;
1696 }
1697 }
1698 }
1699 let tensor_quant =
1700 edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1701 let protos_tensor = protos_tensor
1702 .with_quantization(tensor_quant)
1703 .expect("per-tensor quantization on new Tensor<i8>");
1704
1705 span.record("layout", tracing::field::debug(&proto_layout));
1706
1707 ProtoData {
1708 mask_coefficients,
1709 protos: TensorDyn::I8(protos_tensor),
1710 layout: proto_layout,
1711 }
1712}
1713
/// Conversion hooks that let an element type be packed into an
/// `edgefirst_tensor::TensorDyn`.
///
/// This file provides implementations for `f32`, `half::f16` and `f64`.
pub trait FloatProtoElem: Copy + 'static {
    /// Builds a `TensorDyn` from a flat slice of `values` and the requested
    /// `shape`.
    fn slice_into_tensor_dyn(
        values: &[Self],
        shape: &[usize],
    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;

    /// Builds a `TensorDyn` from a three-dimensional array view.
    fn arrayview3_into_tensor_dyn(
        view: ArrayView3<'_, Self>,
    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;
}
1729
1730impl FloatProtoElem for f32 {
1731 fn slice_into_tensor_dyn(
1732 values: &[f32],
1733 shape: &[usize],
1734 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1735 edgefirst_tensor::Tensor::<f32>::from_slice(values, shape)
1736 .map(edgefirst_tensor::TensorDyn::F32)
1737 }
1738 fn arrayview3_into_tensor_dyn(
1739 view: ArrayView3<'_, f32>,
1740 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1741 edgefirst_tensor::Tensor::<f32>::from_arrayview3(view).map(edgefirst_tensor::TensorDyn::F32)
1742 }
1743}
1744
1745impl FloatProtoElem for half::f16 {
1746 fn slice_into_tensor_dyn(
1747 values: &[half::f16],
1748 shape: &[usize],
1749 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1750 edgefirst_tensor::Tensor::<half::f16>::from_slice(values, shape)
1751 .map(edgefirst_tensor::TensorDyn::F16)
1752 }
1753 fn arrayview3_into_tensor_dyn(
1754 view: ArrayView3<'_, half::f16>,
1755 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1756 edgefirst_tensor::Tensor::<half::f16>::from_arrayview3(view)
1757 .map(edgefirst_tensor::TensorDyn::F16)
1758 }
1759}
1760
1761impl FloatProtoElem for f64 {
1762 fn slice_into_tensor_dyn(
1763 values: &[f64],
1764 shape: &[usize],
1765 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1766 let narrowed: Vec<f32> = values.iter().map(|&v| v as f32).collect();
1768 edgefirst_tensor::Tensor::<f32>::from_slice(&narrowed, shape)
1769 .map(edgefirst_tensor::TensorDyn::F32)
1770 }
1771 fn arrayview3_into_tensor_dyn(
1772 view: ArrayView3<'_, f64>,
1773 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1774 let narrowed: ndarray::Array3<f32> = view.mapv(|v| v as f32);
1775 edgefirst_tensor::Tensor::<f32>::from_arrayview3(narrowed.view())
1776 .map(edgefirst_tensor::TensorDyn::F32)
1777 }
1778}
1779
1780fn postprocess_yolo<'a, T>(
1781 output: &'a ArrayView2<'_, T>,
1782) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1783 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1784 let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1785 (boxes_tensor, scores_tensor)
1786}
1787
1788pub(crate) fn postprocess_yolo_seg<'a, T>(
1789 output: &'a ArrayView2<'_, T>,
1790 num_protos: usize,
1791) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1792 assert!(
1793 output.shape()[0] > num_protos + 4,
1794 "Output shape is too short: {} <= {} + 4",
1795 output.shape()[0],
1796 num_protos
1797 );
1798 let num_classes = output.shape()[0] - 4 - num_protos;
1799 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1800 let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1801 let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1802 (boxes_tensor, scores_tensor, mask_tensor)
1803}
1804
1805pub(crate) fn postprocess_yolo_split_segdet<'a, 'b, 'c, BOX, SCORE, MASK>(
1806 boxes_tensor: ArrayView2<'a, BOX>,
1807 scores_tensor: ArrayView2<'b, SCORE>,
1808 mask_tensor: ArrayView2<'c, MASK>,
1809) -> (
1810 ArrayView2<'a, BOX>,
1811 ArrayView2<'b, SCORE>,
1812 ArrayView2<'c, MASK>,
1813) {
1814 let boxes_tensor = boxes_tensor.reversed_axes();
1815 let scores_tensor = scores_tensor.reversed_axes();
1816 let mask_tensor = mask_tensor.reversed_axes();
1817 (boxes_tensor, scores_tensor, mask_tensor)
1818}
1819
1820fn decode_segdet_f32<
1821 MASK: Float + AsPrimitive<f32> + Send + Sync,
1822 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1823>(
1824 boxes: Vec<(DetectBox, usize)>,
1825 masks: ArrayView2<MASK>,
1826 protos: ArrayView3<PROTO>,
1827) -> Result<Vec<(DetectBox, BoundingBox, Array3<u8>)>, crate::DecoderError> {
1828 if boxes.is_empty() {
1829 return Ok(Vec::new());
1830 }
1831 if masks.shape()[1] != protos.shape()[2] {
1832 return Err(crate::DecoderError::InvalidShape(format!(
1833 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1834 masks.shape()[1],
1835 protos.shape()[2],
1836 )));
1837 }
1838 boxes
1839 .into_par_iter()
1840 .map(|b| {
1841 let ind = b.1;
1842 let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1847 Ok((b.0, roi, make_segmentation(masks.row(ind), protos.view())))
1848 })
1849 .collect()
1850}
1851
1852pub(crate) fn decode_segdet_quant<
1853 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1854 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1855>(
1856 boxes: Vec<(DetectBox, usize)>,
1857 masks: ArrayView2<MASK>,
1858 protos: ArrayView3<PROTO>,
1859 quant_masks: Quantization,
1860 quant_protos: Quantization,
1861) -> Result<Vec<(DetectBox, BoundingBox, Array3<u8>)>, crate::DecoderError> {
1862 if boxes.is_empty() {
1863 return Ok(Vec::new());
1864 }
1865 if masks.shape()[1] != protos.shape()[2] {
1866 return Err(crate::DecoderError::InvalidShape(format!(
1867 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1868 masks.shape()[1],
1869 protos.shape()[2],
1870 )));
1871 }
1872
1873 let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; boxes
1875 .into_iter()
1876 .map(|b| {
1877 let i = b.1;
1878 let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1882 let seg = match total_bits {
1883 0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1884 masks.row(i),
1885 protos.view(),
1886 quant_masks,
1887 quant_protos,
1888 ),
1889 65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1890 masks.row(i),
1891 protos.view(),
1892 quant_masks,
1893 quant_protos,
1894 ),
1895 _ => {
1896 return Err(crate::DecoderError::NotSupported(format!(
1897 "Unsupported bit width ({total_bits}) for segmentation computation"
1898 )));
1899 }
1900 };
1901 Ok((b.0, roi, seg))
1902 })
1903 .collect()
1904}
1905
1906fn protobox<'a, T>(
1907 protos: &'a ArrayView3<T>,
1908 roi: &BoundingBox,
1909) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1910 let width = protos.dim().1 as f32;
1911 let height = protos.dim().0 as f32;
1912
1913 const NORM_LIMIT: f32 = 2.0;
1925 if roi.xmin > NORM_LIMIT
1926 || roi.ymin > NORM_LIMIT
1927 || roi.xmax > NORM_LIMIT
1928 || roi.ymax > NORM_LIMIT
1929 {
1930 return Err(crate::DecoderError::InvalidShape(format!(
1931 "Bounding box coordinates appear un-normalized (pixel-space). \
1932 Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
1933 Two ways to fix this: \
1934 (1) declare `Detection::normalized = false` in the model schema \
1935 AND make sure the schema's `input.shape` / `input.dshape` carries \
1936 the model input dims so the decoder can divide by (W, H) before NMS \
1937 (EDGEAI-1303 — verify with `Decoder::input_dims().is_some()`); or \
1938 (2) normalize the boxes in-graph before decode().",
1939 roi.xmin, roi.ymin, roi.xmax, roi.ymax,
1940 )));
1941 }
1942
1943 let roi = [
1944 (roi.xmin * width).clamp(0.0, width) as usize,
1945 (roi.ymin * height).clamp(0.0, height) as usize,
1946 (roi.xmax * width).clamp(0.0, width).ceil() as usize,
1947 (roi.ymax * height).clamp(0.0, height).ceil() as usize,
1948 ];
1949
1950 let roi_norm = [
1951 roi[0] as f32 / width,
1952 roi[1] as f32 / height,
1953 roi[2] as f32 / width,
1954 roi[3] as f32 / height,
1955 ]
1956 .into();
1957
1958 let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
1959
1960 Ok((cropped, roi_norm))
1961}
1962
1963fn make_segmentation<
1969 MASK: Float + AsPrimitive<f32> + Send + Sync,
1970 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1971>(
1972 mask: ArrayView1<MASK>,
1973 protos: ArrayView3<PROTO>,
1974) -> Array3<u8> {
1975 let shape = protos.shape();
1976
1977 let mask = mask.to_shape((1, mask.len())).unwrap();
1979 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1980 let protos = protos.reversed_axes();
1981 let mask = mask.map(|x| x.as_());
1982 let protos = protos.map(|x| x.as_());
1983
1984 let mask = mask
1986 .dot(&protos)
1987 .into_shape_with_order((shape[0], shape[1], 1))
1988 .unwrap();
1989
1990 mask.map(|x| {
1991 let sigmoid = 1.0 / (1.0 + (-*x).exp());
1992 (sigmoid * 255.0).round() as u8
1993 })
1994}
1995
1996fn make_segmentation_quant<
2003 MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
2004 PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
2005 DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
2006>(
2007 mask: ArrayView1<MASK>,
2008 protos: ArrayView3<PROTO>,
2009 quant_masks: Quantization,
2010 quant_protos: Quantization,
2011) -> Array3<u8>
2012where
2013 i32: AsPrimitive<DEST>,
2014 f32: AsPrimitive<DEST>,
2015{
2016 let shape = protos.shape();
2017
2018 let mask = mask.to_shape((1, mask.len())).unwrap();
2020
2021 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
2022 let protos = protos.reversed_axes();
2023
2024 let zp = quant_masks.zero_point.as_();
2025
2026 let mask = mask.mapv(|x| x.as_() - zp);
2027
2028 let zp = quant_protos.zero_point.as_();
2029 let protos = protos.mapv(|x| x.as_() - zp);
2030
2031 let segmentation = mask
2033 .dot(&protos)
2034 .into_shape_with_order((shape[0], shape[1], 1))
2035 .unwrap();
2036
2037 let combined_scale = quant_masks.scale * quant_protos.scale;
2038 segmentation.map(|x| {
2039 let val: f32 = (*x).as_() * combined_scale;
2040 let sigmoid = 1.0 / (1.0 + (-val).exp());
2041 (sigmoid * 255.0).round() as u8
2042 })
2043}
2044
2045pub(crate) fn yolo_segmentation_to_mask(
2057 segmentation: ArrayView3<u8>,
2058 threshold: u8,
2059) -> Result<Array2<u8>, crate::DecoderError> {
2060 if segmentation.shape()[2] != 1 {
2061 return Err(crate::DecoderError::InvalidShape(format!(
2062 "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
2063 segmentation.shape()[2]
2064 )));
2065 }
2066 Ok(segmentation
2067 .slice(s![.., .., 0])
2068 .map(|x| if *x >= threshold { 1 } else { 0 }))
2069}
2070
2071#[cfg(test)]
2072#[cfg_attr(coverage_nightly, coverage(off))]
2073mod tests {
2074 use super::*;
2075 use ndarray::Array2;
2076
2077 #[test]
2082 fn test_end_to_end_det_basic_filtering() {
2083 let data: Vec<f32> = vec![
2087 0.1, 0.2, 0.3, 0.1, 0.2, 0.3, 0.5, 0.6, 0.7, 0.5, 0.6, 0.7, 0.9, 0.1, 0.2, 0.0, 1.0, 2.0, ];
2095 let output = Array2::from_shape_vec((6, 3), data).unwrap();
2096
2097 let mut boxes = Vec::with_capacity(10);
2098 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2099
2100 assert_eq!(boxes.len(), 1);
2102 assert_eq!(boxes[0].label, 0);
2103 assert!((boxes[0].score - 0.9).abs() < 0.01);
2104 assert!((boxes[0].bbox.xmin - 0.1).abs() < 0.01);
2105 assert!((boxes[0].bbox.ymin - 0.1).abs() < 0.01);
2106 assert!((boxes[0].bbox.xmax - 0.5).abs() < 0.01);
2107 assert!((boxes[0].bbox.ymax - 0.5).abs() < 0.01);
2108 }
2109
2110 #[test]
2111 fn test_end_to_end_det_all_pass_threshold() {
2112 let data: Vec<f32> = vec![
2114 10.0, 20.0, 10.0, 20.0, 50.0, 60.0, 50.0, 60.0, 0.8, 0.7, 1.0, 2.0, ];
2121 let output = Array2::from_shape_vec((6, 2), data).unwrap();
2122
2123 let mut boxes = Vec::with_capacity(10);
2124 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2125
2126 assert_eq!(boxes.len(), 2);
2127 assert_eq!(boxes[0].label, 1);
2128 assert_eq!(boxes[1].label, 2);
2129 }
2130
2131 #[test]
2132 fn test_end_to_end_det_none_pass_threshold() {
2133 let data: Vec<f32> = vec![
2135 10.0, 20.0, 10.0, 20.0, 50.0, 60.0, 50.0, 60.0, 0.1, 0.2, 1.0, 2.0, ];
2142 let output = Array2::from_shape_vec((6, 2), data).unwrap();
2143
2144 let mut boxes = Vec::with_capacity(10);
2145 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2146
2147 assert_eq!(boxes.len(), 0);
2148 }
2149
2150 #[test]
2151 fn test_end_to_end_det_capacity_limit() {
2152 let data: Vec<f32> = vec![
2154 0.1, 0.2, 0.3, 0.4, 0.5, 0.1, 0.2, 0.3, 0.4, 0.5, 0.5, 0.6, 0.7, 0.8, 0.9, 0.5, 0.6, 0.7, 0.8, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.0, 1.0, 2.0, 3.0, 4.0, ];
2161 let output = Array2::from_shape_vec((6, 5), data).unwrap();
2162
2163 let mut boxes = Vec::with_capacity(2); decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2165
2166 assert_eq!(boxes.len(), 2);
2167 }
2168
2169 #[test]
2170 fn test_end_to_end_det_empty_output() {
2171 let output = Array2::<f32>::zeros((6, 0));
2173
2174 let mut boxes = Vec::with_capacity(10);
2175 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2176
2177 assert_eq!(boxes.len(), 0);
2178 }
2179
2180 #[test]
2181 fn test_end_to_end_det_pixel_coordinates() {
2182 let data: Vec<f32> = vec![
2184 100.0, 200.0, 300.0, 400.0, 0.95, 5.0, ];
2191 let output = Array2::from_shape_vec((6, 1), data).unwrap();
2192
2193 let mut boxes = Vec::with_capacity(10);
2194 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2195
2196 assert_eq!(boxes.len(), 1);
2197 assert_eq!(boxes[0].label, 5);
2198 assert!((boxes[0].bbox.xmin - 100.0).abs() < 0.01);
2199 assert!((boxes[0].bbox.ymin - 200.0).abs() < 0.01);
2200 assert!((boxes[0].bbox.xmax - 300.0).abs() < 0.01);
2201 assert!((boxes[0].bbox.ymax - 400.0).abs() < 0.01);
2202 }
2203
2204 #[test]
2205 fn test_end_to_end_det_invalid_shape() {
2206 let output = Array2::<f32>::zeros((5, 3));
2208
2209 let mut boxes = Vec::with_capacity(10);
2210 let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);
2211
2212 assert!(result.is_err());
2213 assert!(matches!(
2214 result,
2215 Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
2216 ));
2217 }
2218
2219 #[test]
2224 fn test_end_to_end_segdet_basic() {
2225 let num_protos = 32;
2228 let num_detections = 2;
2229 let num_features = 6 + num_protos;
2230
2231 let mut data = vec![0.0f32; num_features * num_detections];
2233 data[0] = 0.1; data[1] = 0.5; data[num_detections] = 0.1; data[num_detections + 1] = 0.5; data[2 * num_detections] = 0.4; data[2 * num_detections + 1] = 0.9; data[3 * num_detections] = 0.4; data[3 * num_detections + 1] = 0.9; data[4 * num_detections] = 0.9; data[4 * num_detections + 1] = 0.3; data[5 * num_detections] = 1.0; data[5 * num_detections + 1] = 2.0; for i in 6..num_features {
2248 data[i * num_detections] = 0.1;
2249 data[i * num_detections + 1] = 0.1;
2250 }
2251
2252 let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2253
2254 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2256
2257 let mut boxes = Vec::with_capacity(10);
2258 let mut masks = Vec::with_capacity(10);
2259 decode_yolo_end_to_end_segdet_float(
2260 output.view(),
2261 protos.view(),
2262 0.5,
2263 &mut boxes,
2264 &mut masks,
2265 )
2266 .unwrap();
2267
2268 assert_eq!(boxes.len(), 1);
2270 assert_eq!(masks.len(), 1);
2271 assert_eq!(boxes[0].label, 1);
2272 assert!((boxes[0].score - 0.9).abs() < 0.01);
2273 }
2274
2275 #[test]
2276 fn test_end_to_end_segdet_mask_coordinates() {
2277 let num_protos = 32;
2279 let num_features = 6 + num_protos;
2280
2281 let mut data = vec![0.0f32; num_features];
2282 data[0] = 0.2; data[1] = 0.2; data[2] = 0.8; data[3] = 0.8; data[4] = 0.95; data[5] = 3.0; let output = Array2::from_shape_vec((num_features, 1), data).unwrap();
2290 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2291
2292 let mut boxes = Vec::with_capacity(10);
2293 let mut masks = Vec::with_capacity(10);
2294 decode_yolo_end_to_end_segdet_float(
2295 output.view(),
2296 protos.view(),
2297 0.5,
2298 &mut boxes,
2299 &mut masks,
2300 )
2301 .unwrap();
2302
2303 assert_eq!(boxes.len(), 1);
2304 assert_eq!(masks.len(), 1);
2305
2306 let step = 1.0 / 16.0;
2310 assert!(masks[0].xmin <= boxes[0].bbox.xmin);
2311 assert!(masks[0].ymin <= boxes[0].bbox.ymin);
2312 assert!(masks[0].xmax >= boxes[0].bbox.xmax);
2313 assert!(masks[0].ymax >= boxes[0].bbox.ymax);
2314 assert!((boxes[0].bbox.xmin - masks[0].xmin) < step);
2315 assert!((boxes[0].bbox.ymin - masks[0].ymin) < step);
2316 assert!((masks[0].xmax - boxes[0].bbox.xmax) < step);
2317 assert!((masks[0].ymax - boxes[0].bbox.ymax) < step);
2318 }
2319
2320 #[test]
2321 fn test_end_to_end_segdet_empty_output() {
2322 let num_protos = 32;
2323 let output = Array2::<f32>::zeros((6 + num_protos, 0));
2324 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2325
2326 let mut boxes = Vec::with_capacity(10);
2327 let mut masks = Vec::with_capacity(10);
2328 decode_yolo_end_to_end_segdet_float(
2329 output.view(),
2330 protos.view(),
2331 0.5,
2332 &mut boxes,
2333 &mut masks,
2334 )
2335 .unwrap();
2336
2337 assert_eq!(boxes.len(), 0);
2338 assert_eq!(masks.len(), 0);
2339 }
2340
2341 #[test]
2342 fn test_end_to_end_segdet_capacity_limit() {
2343 let num_protos = 32;
2344 let num_detections = 5;
2345 let num_features = 6 + num_protos;
2346
2347 let mut data = vec![0.0f32; num_features * num_detections];
2348 for i in 0..num_detections {
2350 data[i] = 0.1 * (i as f32); data[num_detections + i] = 0.1 * (i as f32); data[2 * num_detections + i] = 0.1 * (i as f32) + 0.2; data[3 * num_detections + i] = 0.1 * (i as f32) + 0.2; data[4 * num_detections + i] = 0.9; data[5 * num_detections + i] = i as f32; }
2357
2358 let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2359 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2360
2361 let mut boxes = Vec::with_capacity(2); let mut masks = Vec::with_capacity(2);
2363 decode_yolo_end_to_end_segdet_float(
2364 output.view(),
2365 protos.view(),
2366 0.5,
2367 &mut boxes,
2368 &mut masks,
2369 )
2370 .unwrap();
2371
2372 assert_eq!(boxes.len(), 2);
2373 assert_eq!(masks.len(), 2);
2374 }
2375
2376 #[test]
2377 fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
2378 let output = Array2::<f32>::zeros((6, 3));
2380 let protos = Array3::<f32>::zeros((16, 16, 32));
2381
2382 let mut boxes = Vec::with_capacity(10);
2383 let mut masks = Vec::with_capacity(10);
2384 let result = decode_yolo_end_to_end_segdet_float(
2385 output.view(),
2386 protos.view(),
2387 0.5,
2388 &mut boxes,
2389 &mut masks,
2390 );
2391
2392 assert!(result.is_err());
2393 assert!(matches!(
2394 result,
2395 Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
2396 ));
2397 }
2398
2399 #[test]
2400 fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
2401 let num_protos = 32;
2403 let output = Array2::<f32>::zeros((6 + 16, 3)); let protos = Array3::<f32>::zeros((16, 16, num_protos)); let mut boxes = Vec::with_capacity(10);
2407 let mut masks = Vec::with_capacity(10);
2408 let result = decode_yolo_end_to_end_segdet_float(
2409 output.view(),
2410 protos.view(),
2411 0.5,
2412 &mut boxes,
2413 &mut masks,
2414 );
2415
2416 assert!(result.is_err());
2417 assert!(matches!(
2418 result,
2419 Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
2420 ));
2421 }
2422
2423 #[test]
2428 fn test_split_end_to_end_segdet_basic() {
2429 let num_protos = 32;
2432 let num_detections = 2;
2433 let num_features = 6 + num_protos;
2434
2435 let mut data = vec![0.0f32; num_features * num_detections];
2437 data[0] = 0.1; data[1] = 0.5; data[num_detections] = 0.1; data[num_detections + 1] = 0.5; data[2 * num_detections] = 0.4; data[2 * num_detections + 1] = 0.9; data[3 * num_detections] = 0.4; data[3 * num_detections + 1] = 0.9; data[4 * num_detections] = 0.9; data[4 * num_detections + 1] = 0.3; data[5 * num_detections] = 1.0; data[5 * num_detections + 1] = 2.0; for i in 6..num_features {
2452 data[i * num_detections] = 0.1;
2453 data[i * num_detections + 1] = 0.1;
2454 }
2455
2456 let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2457 let box_coords = output.slice(s![..4, ..]);
2458 let scores = output.slice(s![4..5, ..]);
2459 let classes = output.slice(s![5..6, ..]);
2460 let mask_coeff = output.slice(s![6.., ..]);
2461 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2463
2464 let mut boxes = Vec::with_capacity(10);
2465 let mut masks = Vec::with_capacity(10);
2466 decode_yolo_split_end_to_end_segdet_float(
2467 box_coords,
2468 scores,
2469 classes,
2470 mask_coeff,
2471 protos.view(),
2472 0.5,
2473 &mut boxes,
2474 &mut masks,
2475 )
2476 .unwrap();
2477
2478 assert_eq!(boxes.len(), 1);
2480 assert_eq!(masks.len(), 1);
2481 assert_eq!(boxes[0].label, 1);
2482 assert!((boxes[0].score - 0.9).abs() < 0.01);
2483 }
2484
2485 #[test]
2490 fn test_segmentation_to_mask_basic() {
2491 let data: Vec<u8> = vec![
2493 100, 200, 50, 150, 10, 255, 128, 64, 0, 127, 128, 255, 64, 64, 192, 192, ];
2498 let segmentation = Array3::from_shape_vec((4, 4, 1), data).unwrap();
2499
2500 let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2501
2502 assert_eq!(mask[[0, 0]], 0); assert_eq!(mask[[0, 1]], 1); assert_eq!(mask[[0, 2]], 0); assert_eq!(mask[[0, 3]], 1); assert_eq!(mask[[1, 1]], 1); assert_eq!(mask[[1, 2]], 1); assert_eq!(mask[[2, 0]], 0); assert_eq!(mask[[2, 1]], 0); }
2512
2513 #[test]
2514 fn test_segmentation_to_mask_all_above() {
2515 let segmentation = Array3::from_elem((4, 4, 1), 255u8);
2516 let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2517 assert!(mask.iter().all(|&x| x == 1));
2518 }
2519
2520 #[test]
2521 fn test_segmentation_to_mask_all_below() {
2522 let segmentation = Array3::from_elem((4, 4, 1), 64u8);
2523 let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2524 assert!(mask.iter().all(|&x| x == 0));
2525 }
2526
2527 #[test]
2528 fn test_segmentation_to_mask_invalid_shape() {
2529 let segmentation = Array3::from_elem((4, 4, 3), 128u8);
2530 let result = yolo_segmentation_to_mask(segmentation.view(), 128);
2531
2532 assert!(result.is_err());
2533 assert!(matches!(
2534 result,
2535 Err(crate::DecoderError::InvalidShape(s)) if s.contains("(H, W, 1)")
2536 ));
2537 }
2538
2539 #[test]
2544 fn test_protobox_clamps_edge_coordinates() {
2545 let protos = Array3::<f32>::zeros((16, 16, 4));
2547 let view = protos.view();
2548 let roi = BoundingBox {
2549 xmin: 0.5,
2550 ymin: 0.5,
2551 xmax: 1.0,
2552 ymax: 1.0,
2553 };
2554 let result = protobox(&view, &roi);
2555 assert!(result.is_ok(), "protobox should accept xmax=1.0");
2556 let (cropped, _roi_norm) = result.unwrap();
2557 assert!(cropped.shape()[0] > 0);
2559 assert!(cropped.shape()[1] > 0);
2560 assert_eq!(cropped.shape()[2], 4);
2561 }
2562
2563 #[test]
2564 fn test_protobox_rejects_wildly_out_of_range() {
2565 let protos = Array3::<f32>::zeros((16, 16, 4));
2567 let view = protos.view();
2568 let roi = BoundingBox {
2569 xmin: 0.0,
2570 ymin: 0.0,
2571 xmax: 3.0,
2572 ymax: 3.0,
2573 };
2574 let result = protobox(&view, &roi);
2575 assert!(
2576 matches!(result, Err(crate::DecoderError::InvalidShape(s)) if s.contains("un-normalized")),
2577 "protobox should reject coords > NORM_LIMIT"
2578 );
2579 }
2580
2581 #[test]
2582 fn test_protobox_accepts_slightly_over_one() {
2583 let protos = Array3::<f32>::zeros((16, 16, 4));
2585 let view = protos.view();
2586 let roi = BoundingBox {
2587 xmin: 0.0,
2588 ymin: 0.0,
2589 xmax: 1.5,
2590 ymax: 1.5,
2591 };
2592 let result = protobox(&view, &roi);
2593 assert!(
2594 result.is_ok(),
2595 "protobox should accept coords <= NORM_LIMIT (2.0)"
2596 );
2597 let (cropped, _roi_norm) = result.unwrap();
2598 assert_eq!(cropped.shape()[0], 16);
2600 assert_eq!(cropped.shape()[1], 16);
2601 }
2602
2603 #[test]
2604 fn test_segdet_float_proto_no_panic() {
2605 let num_proposals = 100; let num_classes = 80;
2609 let num_mask_coeffs = 32;
2610 let rows = 4 + num_classes + num_mask_coeffs; let mut data = vec![0.0f32; rows * num_proposals];
2616 for i in 0..num_proposals {
2617 let row = |r: usize| r * num_proposals + i;
2618 data[row(0)] = 320.0; data[row(1)] = 320.0; data[row(2)] = 50.0; data[row(3)] = 50.0; data[row(4)] = 0.9; }
2624 let boxes = ndarray::Array2::from_shape_vec((rows, num_proposals), data).unwrap();
2625
2626 let protos = ndarray::Array3::<f32>::zeros((160, 160, num_mask_coeffs));
2631
2632 let mut output_boxes = Vec::with_capacity(300);
2633
2634 let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
2636 boxes.view(),
2637 protos.view(),
2638 0.5,
2639 0.7,
2640 Some(Nms::default()),
2641 MAX_NMS_CANDIDATES,
2642 300,
2643 None,
2644 None,
2645 &mut output_boxes,
2646 );
2647
2648 assert!(!output_boxes.is_empty());
2650 let coeffs_shape = proto_data.mask_coefficients.shape();
2651 assert_eq!(coeffs_shape[0], output_boxes.len());
2652 assert_eq!(coeffs_shape[1], num_mask_coeffs);
2654 }
2655
2656 #[test]
2671 fn test_pre_nms_cap_truncates_excess_candidates() {
2672 let n: usize = 50_000;
2673 let num_classes = 1;
2674
2675 let mut boxes_data = Vec::with_capacity(n * 4);
2679 let mut scores_data = Vec::with_capacity(n * num_classes);
2680 for i in 0..n {
2681 boxes_data.extend_from_slice(&[0.1f32, 0.1, 0.5, 0.5]);
2682 scores_data.push(0.99 - (i as f32) * 1e-7);
2685 }
2686 let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
2687 let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();
2688
2689 let result = impl_yolo_segdet_get_boxes::<XYXY, _, _>(
2690 boxes.view(),
2691 scores.view(),
2692 0.1,
2693 1.0, Some(Nms::ClassAgnostic), crate::yolo::MAX_NMS_CANDIDATES, usize::MAX, );
2698
2699 assert_eq!(
2700 result.len(),
2701 crate::yolo::MAX_NMS_CANDIDATES,
2702 "pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
2703 result.len()
2704 );
2705 let top_score = result[0].0.score;
2708 assert!(
2709 top_score > 0.98,
2710 "highest-ranked survivor should have the largest score, got {top_score}"
2711 );
2712 }
2713
2714 #[test]
2719 fn test_pre_nms_cap_truncates_excess_candidates_quant() {
2720 use crate::Quantization;
2721 let n: usize = 50_000;
2722 let num_classes = 1;
2723
2724 let boxes_data = (0..n).flat_map(|_| [10i8, 10, 50, 50]).collect::<Vec<_>>();
2727 let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
2728 let quant_boxes = Quantization {
2729 scale: 0.01,
2730 zero_point: 0,
2731 };
2732
2733 let scores_data: Vec<u8> = (0..n)
2738 .map(|i| 250u8.saturating_sub((i % 200) as u8))
2739 .collect();
2740 let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();
2741 let quant_scores = Quantization {
2742 scale: 0.00392,
2743 zero_point: 0,
2744 };
2745
2746 let result = impl_yolo_split_segdet_quant_get_boxes::<XYXY, _, _>(
2747 (boxes.view(), quant_boxes),
2748 (scores.view(), quant_scores),
2749 0.1,
2750 1.0, Some(Nms::ClassAgnostic), crate::yolo::MAX_NMS_CANDIDATES, usize::MAX, );
2755
2756 assert_eq!(
2757 result.len(),
2758 crate::yolo::MAX_NMS_CANDIDATES,
2759 "quant path pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
2760 result.len()
2761 );
2762 }
2763
2764 #[test]
2778 fn segdet_combined_tensor_pairs_detection_with_matching_mask_row() {
2779 let nc = 2; let nm = 2; let n = 3; let feat = 4 + nc + nm; let mut data = vec![0.0f32; feat * n];
2802 let set = |d: &mut [f32], r: usize, c: usize, v: f32| d[r * n + c] = v;
2803 set(&mut data, 0, 0, 0.2);
2804 set(&mut data, 1, 0, 0.2);
2805 set(&mut data, 2, 0, 0.1);
2806 set(&mut data, 3, 0, 0.1);
2807 set(&mut data, 0, 1, 0.5);
2808 set(&mut data, 1, 1, 0.5);
2809 set(&mut data, 2, 1, 0.1);
2810 set(&mut data, 3, 1, 0.1);
2811 set(&mut data, 0, 2, 0.8);
2812 set(&mut data, 1, 2, 0.8);
2813 set(&mut data, 2, 2, 0.1);
2814 set(&mut data, 3, 2, 0.1);
2815 set(&mut data, 4, 0, 0.9);
2816 set(&mut data, 4, 2, 0.8);
2817 set(&mut data, 6, 0, 3.0);
2818 set(&mut data, 7, 0, 3.0);
2819 set(&mut data, 6, 2, -3.0);
2820 set(&mut data, 7, 2, -3.0);
2821
2822 let output = Array2::from_shape_vec((feat, n), data).unwrap();
2823 let protos = Array3::<f32>::from_elem((8, 8, nm), 1.0);
2824
2825 let mut boxes: Vec<DetectBox> = Vec::with_capacity(10);
2826 let mut masks: Vec<Segmentation> = Vec::with_capacity(10);
2827 decode_yolo_segdet_float(
2828 output.view(),
2829 protos.view(),
2830 0.5,
2831 0.5,
2832 Some(Nms::ClassAgnostic),
2833 &mut boxes,
2834 &mut masks,
2835 )
2836 .unwrap();
2837
2838 assert_eq!(
2839 boxes.len(),
2840 2,
2841 "two anchors above threshold should survive (a0 score=0.9, a2 score=0.8); got {}",
2842 boxes.len()
2843 );
2844
2845 for (b, m) in boxes.iter().zip(masks.iter()) {
2851 let cx = (b.bbox.xmin + b.bbox.xmax) * 0.5;
2852 let mean = {
2853 let s = &m.segmentation;
2854 let total: u32 = s.iter().map(|&v| v as u32).sum();
2855 total as f32 / s.len() as f32
2856 };
2857 if cx < 0.3 {
2858 assert!(
2860 mean > 200.0,
2861 "anchor 0 detection (centre {cx:.2}) should have high-value mask; got mean {mean}"
2862 );
2863 } else if cx > 0.7 {
2864 assert!(
2866 mean < 50.0,
2867 "anchor 2 detection (centre {cx:.2}) should have low-value mask; got mean {mean}"
2868 );
2869 } else {
2870 panic!("unexpected detection centre {cx:.2}");
2871 }
2872 }
2873 }
2874}