1use std::fmt::Debug;
13
14use ndarray::{
15 parallel::prelude::{IntoParallelIterator, ParallelIterator},
16 s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
17};
18use num_traits::{AsPrimitive, Float, PrimInt, Signed};
19
20use crate::{
21 byte::{
22 nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
23 postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
24 },
25 configs::Nms,
26 dequant_detect_box,
27 float::{
28 nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
29 postprocess_boxes_float, postprocess_boxes_index_float,
30 },
31 BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoLayout,
32 Quantization, Segmentation, XYWH, XYXY,
33};
34
35#[cfg(test)]
53pub(crate) const MAX_NMS_CANDIDATES: usize = 30_000;
54
55pub(crate) const DEFAULT_MAX_DETECTIONS: usize = 300;
62
63fn truncate_to_top_k_by_score<E: Send>(boxes: &mut Vec<(DetectBox, E)>, top_k: usize) {
67 if boxes.len() > top_k {
68 boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.total_cmp(&a.0.score));
69 boxes.truncate(top_k);
70 }
71}
72
73fn truncate_to_top_k_by_score_quant<S: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send>(
77 boxes: &mut Vec<(DetectBoxQuantized<S>, E)>,
78 top_k: usize,
79) {
80 if boxes.len() > top_k {
81 boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.cmp(&a.0.score));
82 boxes.truncate(top_k);
83 }
84}
85
86fn dispatch_nms_float(nms: Option<Nms>, iou: f32, boxes: Vec<DetectBox>) -> Vec<DetectBox> {
88 match nms {
89 Some(Nms::ClassAgnostic) => nms_float(iou, boxes),
90 Some(Nms::ClassAware) => nms_class_aware_float(iou, boxes),
91 None => boxes, }
93}
94
95pub(super) fn dispatch_nms_extra_float<E: Send + Sync>(
98 nms: Option<Nms>,
99 iou: f32,
100 boxes: Vec<(DetectBox, E)>,
101) -> Vec<(DetectBox, E)> {
102 match nms {
103 Some(Nms::ClassAgnostic) => nms_extra_float(iou, boxes),
104 Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, boxes),
105 None => boxes, }
107}
108
109fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
112 nms: Option<Nms>,
113 iou: f32,
114 boxes: Vec<DetectBoxQuantized<SCORE>>,
115) -> Vec<DetectBoxQuantized<SCORE>> {
116 match nms {
117 Some(Nms::ClassAgnostic) => nms_int(iou, boxes),
118 Some(Nms::ClassAware) => nms_class_aware_int(iou, boxes),
119 None => boxes, }
121}
122
123fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
126 nms: Option<Nms>,
127 iou: f32,
128 boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
129) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
130 match nms {
131 Some(Nms::ClassAgnostic) => nms_extra_int(iou, boxes),
132 Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, boxes),
133 None => boxes, }
135}
136
137#[inline]
144fn cap_or_default<T>(v: &Vec<T>) -> usize {
145 if v.capacity() > 0 {
146 v.capacity()
147 } else {
148 DEFAULT_MAX_DETECTIONS
149 }
150}
151
152pub(crate) fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
186 output: (ArrayView2<BOX>, Quantization),
187 score_threshold: f32,
188 iou_threshold: f32,
189 nms: Option<Nms>,
190 output_boxes: &mut Vec<DetectBox>,
191) where
192 f32: AsPrimitive<BOX>,
193{
194 impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
195}
196
197pub(crate) fn decode_yolo_det_float<T>(
204 output: ArrayView2<T>,
205 score_threshold: f32,
206 iou_threshold: f32,
207 nms: Option<Nms>,
208 output_boxes: &mut Vec<DetectBox>,
209) where
210 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
211 f32: AsPrimitive<T>,
212{
213 impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
214}
215
216#[cfg(test)]
230pub(crate) fn decode_yolo_segdet_quant<
231 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
232 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
233>(
234 boxes: (ArrayView2<BOX>, Quantization),
235 protos: (ArrayView3<PROTO>, Quantization),
236 score_threshold: f32,
237 iou_threshold: f32,
238 nms: Option<Nms>,
239 output_boxes: &mut Vec<DetectBox>,
240 output_masks: &mut Vec<Segmentation>,
241) -> Result<(), crate::DecoderError>
242where
243 f32: AsPrimitive<BOX>,
244{
245 let cap = cap_or_default(output_boxes);
250 impl_yolo_segdet_quant::<XYWH, _, _>(
251 boxes,
252 protos,
253 score_threshold,
254 iou_threshold,
255 nms,
256 MAX_NMS_CANDIDATES,
257 cap,
258 None,
259 None,
260 output_boxes,
261 output_masks,
262 )
263}
264
265#[cfg(test)]
267pub(crate) fn decode_yolo_segdet_float<T>(
268 boxes: ArrayView2<T>,
269 protos: ArrayView3<T>,
270 score_threshold: f32,
271 iou_threshold: f32,
272 nms: Option<Nms>,
273 output_boxes: &mut Vec<DetectBox>,
274 output_masks: &mut Vec<Segmentation>,
275) -> Result<(), crate::DecoderError>
276where
277 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
278 f32: AsPrimitive<T>,
279{
280 let cap = cap_or_default(output_boxes);
283 impl_yolo_segdet_float::<XYWH, _, _>(
284 boxes,
285 protos,
286 score_threshold,
287 iou_threshold,
288 nms,
289 MAX_NMS_CANDIDATES,
290 cap,
291 None,
292 None,
293 output_boxes,
294 output_masks,
295 )
296}
297
298pub(crate) fn decode_yolo_split_det_quant<
310 BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
311 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
312>(
313 boxes: (ArrayView2<BOX>, Quantization),
314 scores: (ArrayView2<SCORE>, Quantization),
315 score_threshold: f32,
316 iou_threshold: f32,
317 nms: Option<Nms>,
318 output_boxes: &mut Vec<DetectBox>,
319) where
320 f32: AsPrimitive<SCORE>,
321{
322 impl_yolo_split_quant::<XYWH, _, _>(
323 boxes,
324 scores,
325 score_threshold,
326 iou_threshold,
327 nms,
328 output_boxes,
329 );
330}
331
332pub(crate) fn decode_yolo_split_det_float<T>(
344 boxes: ArrayView2<T>,
345 scores: ArrayView2<T>,
346 score_threshold: f32,
347 iou_threshold: f32,
348 nms: Option<Nms>,
349 output_boxes: &mut Vec<DetectBox>,
350) where
351 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
352 f32: AsPrimitive<T>,
353{
354 impl_yolo_split_float::<XYWH, _, _>(
355 boxes,
356 scores,
357 score_threshold,
358 iou_threshold,
359 nms,
360 output_boxes,
361 );
362}
363
364pub(crate) fn decode_yolo_end_to_end_det_float<T>(
379 output: ArrayView2<T>,
380 score_threshold: f32,
381 output_boxes: &mut Vec<DetectBox>,
382) -> Result<(), crate::DecoderError>
383where
384 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
385 f32: AsPrimitive<T>,
386{
387 if output.shape()[0] < 6 {
389 return Err(crate::DecoderError::InvalidShape(format!(
390 "End-to-end detection output requires at least 6 rows, got {}",
391 output.shape()[0]
392 )));
393 }
394
395 let boxes = output.slice(s![0..4, ..]).reversed_axes();
397 let scores = output.slice(s![4..5, ..]).reversed_axes();
398 let classes = output.slice(s![5, ..]);
399 let mut boxes =
400 postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
401 boxes.truncate(cap_or_default(output_boxes));
402 output_boxes.clear();
403 for (mut b, i) in boxes.into_iter() {
404 b.label = classes[i].as_() as usize;
405 output_boxes.push(b);
406 }
407 Ok(())
409}
410
411pub(crate) fn decode_yolo_end_to_end_segdet_float<T>(
429 output: ArrayView2<T>,
430 protos: ArrayView3<T>,
431 score_threshold: f32,
432 output_boxes: &mut Vec<DetectBox>,
433 output_masks: &mut Vec<crate::Segmentation>,
434) -> Result<(), crate::DecoderError>
435where
436 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
437 f32: AsPrimitive<T>,
438{
439 let (boxes, scores, classes, mask_coeff) =
440 postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
441 let cap = cap_or_default(output_boxes);
442 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
443 boxes,
444 scores,
445 classes,
446 score_threshold,
447 cap,
448 );
449
450 impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
453}
454
455pub(crate) fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
464 boxes: ArrayView2<T>,
465 scores: ArrayView2<T>,
466 classes: ArrayView2<T>,
467 score_threshold: f32,
468 output_boxes: &mut Vec<DetectBox>,
469) -> Result<(), crate::DecoderError> {
470 let n = boxes.shape()[1];
471
472 let cap = cap_or_default(output_boxes);
473 output_boxes.clear();
474
475 let (boxes, scores, classes) = postprocess_yolo_split_end_to_end_det(boxes, scores, &classes)?;
476
477 for i in 0..n {
478 let score: f32 = scores[[i, 0]].as_();
479 if score < score_threshold {
480 continue;
481 }
482 if output_boxes.len() >= cap {
483 break;
484 }
485 output_boxes.push(DetectBox {
486 bbox: BoundingBox {
487 xmin: boxes[[i, 0]].as_(),
488 ymin: boxes[[i, 1]].as_(),
489 xmax: boxes[[i, 2]].as_(),
490 ymax: boxes[[i, 3]].as_(),
491 },
492 score,
493 label: classes[i].as_() as usize,
494 });
495 }
496 Ok(())
497}
498
499#[allow(clippy::too_many_arguments)]
508pub(crate) fn decode_yolo_split_end_to_end_segdet_float<T>(
509 boxes: ArrayView2<T>,
510 scores: ArrayView2<T>,
511 classes: ArrayView2<T>,
512 mask_coeff: ArrayView2<T>,
513 protos: ArrayView3<T>,
514 score_threshold: f32,
515 output_boxes: &mut Vec<DetectBox>,
516 output_masks: &mut Vec<crate::Segmentation>,
517) -> Result<(), crate::DecoderError>
518where
519 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
520 f32: AsPrimitive<T>,
521{
522 let (boxes, scores, classes, mask_coeff) =
523 postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
524 let cap = cap_or_default(output_boxes);
525 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
526 boxes,
527 scores,
528 classes,
529 score_threshold,
530 cap,
531 );
532
533 impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
534}
535
536#[allow(clippy::type_complexity)]
537pub(crate) fn postprocess_yolo_end_to_end_segdet<'a, T>(
538 output: &'a ArrayView2<'_, T>,
539 num_protos: usize,
540) -> Result<
541 (
542 ArrayView2<'a, T>,
543 ArrayView2<'a, T>,
544 ArrayView1<'a, T>,
545 ArrayView2<'a, T>,
546 ),
547 crate::DecoderError,
548> {
549 if output.shape()[0] < 7 {
551 return Err(crate::DecoderError::InvalidShape(format!(
552 "End-to-end segdet output requires at least 7 rows, got {}",
553 output.shape()[0]
554 )));
555 }
556
557 let num_mask_coeffs = output.shape()[0] - 6;
558 if num_mask_coeffs != num_protos {
559 return Err(crate::DecoderError::InvalidShape(format!(
560 "Mask coefficients count ({}) doesn't match protos count ({})",
561 num_mask_coeffs, num_protos
562 )));
563 }
564
565 let boxes = output.slice(s![0..4, ..]).reversed_axes();
567 let scores = output.slice(s![4..5, ..]).reversed_axes();
568 let classes = output.slice(s![5, ..]);
569 let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
570 Ok((boxes, scores, classes, mask_coeff))
571}
572
573#[allow(clippy::type_complexity)]
580pub(crate) fn postprocess_yolo_split_end_to_end_det<'a, 'b, 'c, BOXES, SCORES, CLASS>(
581 boxes: ArrayView2<'a, BOXES>,
582 scores: ArrayView2<'b, SCORES>,
583 classes: &'c ArrayView2<CLASS>,
584) -> Result<
585 (
586 ArrayView2<'a, BOXES>,
587 ArrayView2<'b, SCORES>,
588 ArrayView1<'c, CLASS>,
589 ),
590 crate::DecoderError,
591> {
592 let num_boxes = boxes.shape()[1];
593 if boxes.shape()[0] != 4 {
594 return Err(crate::DecoderError::InvalidShape(format!(
595 "Split end-to-end box_coords must be 4, got {}",
596 boxes.shape()[0]
597 )));
598 }
599
600 if scores.shape()[0] != 1 {
601 return Err(crate::DecoderError::InvalidShape(format!(
602 "Split end-to-end scores num_classes must be 1, got {}",
603 scores.shape()[0]
604 )));
605 }
606
607 if classes.shape()[0] != 1 {
608 return Err(crate::DecoderError::InvalidShape(format!(
609 "Split end-to-end classes num_classes must be 1, got {}",
610 classes.shape()[0]
611 )));
612 }
613
614 if scores.shape()[1] != num_boxes {
615 return Err(crate::DecoderError::InvalidShape(format!(
616 "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
617 num_boxes,
618 scores.shape()[1]
619 )));
620 }
621
622 if classes.shape()[1] != num_boxes {
623 return Err(crate::DecoderError::InvalidShape(format!(
624 "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
625 num_boxes,
626 classes.shape()[1]
627 )));
628 }
629
630 let boxes = boxes.reversed_axes();
631 let scores = scores.reversed_axes();
632 let classes = classes.slice(s![0, ..]);
633 Ok((boxes, scores, classes))
634}
635
636#[allow(clippy::type_complexity)]
639pub(crate) fn postprocess_yolo_split_end_to_end_segdet<
640 'a,
641 'b,
642 'c,
643 'd,
644 BOXES,
645 SCORES,
646 CLASS,
647 MASK,
648>(
649 boxes: ArrayView2<'a, BOXES>,
650 scores: ArrayView2<'b, SCORES>,
651 classes: &'c ArrayView2<CLASS>,
652 mask_coeff: ArrayView2<'d, MASK>,
653) -> Result<
654 (
655 ArrayView2<'a, BOXES>,
656 ArrayView2<'b, SCORES>,
657 ArrayView1<'c, CLASS>,
658 ArrayView2<'d, MASK>,
659 ),
660 crate::DecoderError,
661> {
662 let num_boxes = boxes.shape()[1];
663 if boxes.shape()[0] != 4 {
664 return Err(crate::DecoderError::InvalidShape(format!(
665 "Split end-to-end box_coords must be 4, got {}",
666 boxes.shape()[0]
667 )));
668 }
669
670 if scores.shape()[0] != 1 {
671 return Err(crate::DecoderError::InvalidShape(format!(
672 "Split end-to-end scores num_classes must be 1, got {}",
673 scores.shape()[0]
674 )));
675 }
676
677 if classes.shape()[0] != 1 {
678 return Err(crate::DecoderError::InvalidShape(format!(
679 "Split end-to-end classes num_classes must be 1, got {}",
680 classes.shape()[0]
681 )));
682 }
683
684 if scores.shape()[1] != num_boxes {
685 return Err(crate::DecoderError::InvalidShape(format!(
686 "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
687 num_boxes,
688 scores.shape()[1]
689 )));
690 }
691
692 if classes.shape()[1] != num_boxes {
693 return Err(crate::DecoderError::InvalidShape(format!(
694 "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
695 num_boxes,
696 classes.shape()[1]
697 )));
698 }
699
700 if mask_coeff.shape()[1] != num_boxes {
701 return Err(crate::DecoderError::InvalidShape(format!(
702 "Split end-to-end mask_coeff must have same num_boxes as boxes ({}), got {}",
703 num_boxes,
704 mask_coeff.shape()[1]
705 )));
706 }
707
708 let boxes = boxes.reversed_axes();
709 let scores = scores.reversed_axes();
710 let classes = classes.slice(s![0, ..]);
711 let mask_coeff = mask_coeff.reversed_axes();
712 Ok((boxes, scores, classes, mask_coeff))
713}
714pub(crate) fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
719 output: (ArrayView2<T>, Quantization),
720 score_threshold: f32,
721 iou_threshold: f32,
722 nms: Option<Nms>,
723 output_boxes: &mut Vec<DetectBox>,
724) where
725 f32: AsPrimitive<T>,
726{
727 let _span = tracing::trace_span!("decode", mode = "quant_det").entered();
728 let (boxes, quant_boxes) = output;
729 let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
730
731 let boxes = {
732 let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
733 postprocess_boxes_quant::<B, _, _>(
734 score_threshold,
735 boxes_tensor,
736 scores_tensor,
737 quant_boxes,
738 )
739 };
740
741 let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
742 let len = cap_or_default(output_boxes).min(boxes.len());
744 output_boxes.clear();
745 for b in boxes.iter().take(len) {
746 output_boxes.push(dequant_detect_box(b, quant_boxes));
747 }
748}
749
750pub(crate) fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
755 output: ArrayView2<T>,
756 score_threshold: f32,
757 iou_threshold: f32,
758 nms: Option<Nms>,
759 output_boxes: &mut Vec<DetectBox>,
760) where
761 f32: AsPrimitive<T>,
762{
763 let _span = tracing::trace_span!("decode", mode = "float_det").entered();
764 let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
765 let boxes =
766 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
767 let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
768 let len = cap_or_default(output_boxes).min(boxes.len());
770 output_boxes.clear();
771 for b in boxes.into_iter().take(len) {
772 output_boxes.push(b);
773 }
774}
775
776pub(crate) fn impl_yolo_split_quant<
786 B: BBoxTypeTrait,
787 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
788 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
789>(
790 boxes: (ArrayView2<BOX>, Quantization),
791 scores: (ArrayView2<SCORE>, Quantization),
792 score_threshold: f32,
793 iou_threshold: f32,
794 nms: Option<Nms>,
795 output_boxes: &mut Vec<DetectBox>,
796) where
797 f32: AsPrimitive<SCORE>,
798{
799 let _span = tracing::trace_span!("decode", mode = "split_quant_det").entered();
800 let (boxes_tensor, quant_boxes) = boxes;
801 let (scores_tensor, quant_scores) = scores;
802
803 let boxes_tensor = boxes_tensor.reversed_axes();
804 let scores_tensor = scores_tensor.reversed_axes();
805
806 let boxes = {
807 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
808 postprocess_boxes_quant::<B, _, _>(
809 score_threshold,
810 boxes_tensor,
811 scores_tensor,
812 quant_boxes,
813 )
814 };
815
816 let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
817 let len = cap_or_default(output_boxes).min(boxes.len());
819 output_boxes.clear();
820 for b in boxes.iter().take(len) {
821 output_boxes.push(dequant_detect_box(b, quant_scores));
822 }
823}
824
825pub(crate) fn impl_yolo_split_float<
834 B: BBoxTypeTrait,
835 BOX: Float + AsPrimitive<f32> + Send + Sync,
836 SCORE: Float + AsPrimitive<f32> + Send + Sync,
837>(
838 boxes_tensor: ArrayView2<BOX>,
839 scores_tensor: ArrayView2<SCORE>,
840 score_threshold: f32,
841 iou_threshold: f32,
842 nms: Option<Nms>,
843 output_boxes: &mut Vec<DetectBox>,
844) where
845 f32: AsPrimitive<SCORE>,
846{
847 let _span = tracing::trace_span!("decode", mode = "split_float_det").entered();
848 let boxes_tensor = boxes_tensor.reversed_axes();
849 let scores_tensor = scores_tensor.reversed_axes();
850 let boxes =
851 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
852 let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
853 let len = cap_or_default(output_boxes).min(boxes.len());
855 output_boxes.clear();
856 for b in boxes.into_iter().take(len) {
857 output_boxes.push(b);
858 }
859}
860
861#[inline]
869pub(crate) fn maybe_normalize_boxes_in_place(
870 boxes: &mut [(DetectBox, usize)],
871 normalized: Option<bool>,
872 input_dims: Option<(usize, usize)>,
873) {
874 if normalized != Some(false) {
875 return;
876 }
877 let Some((w, h)) = input_dims else {
878 return;
879 };
880 if w == 0 || h == 0 {
881 return;
882 }
883 let inv_w = 1.0 / w as f32;
884 let inv_h = 1.0 / h as f32;
885 for (b, _) in boxes.iter_mut() {
886 b.bbox.xmin *= inv_w;
887 b.bbox.ymin *= inv_h;
888 b.bbox.xmax *= inv_w;
889 b.bbox.ymax *= inv_h;
890 }
891}
892
893#[allow(clippy::too_many_arguments)]
903pub(crate) fn impl_yolo_segdet_quant<
904 B: BBoxTypeTrait,
905 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
906 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
907>(
908 boxes: (ArrayView2<BOX>, Quantization),
909 protos: (ArrayView3<PROTO>, Quantization),
910 score_threshold: f32,
911 iou_threshold: f32,
912 nms: Option<Nms>,
913 pre_nms_top_k: usize,
914 max_det: usize,
915 normalized: Option<bool>,
916 input_dims: Option<(usize, usize)>,
917 output_boxes: &mut Vec<DetectBox>,
918 output_masks: &mut Vec<Segmentation>,
919) -> Result<(), crate::DecoderError>
920where
921 f32: AsPrimitive<BOX>,
922{
923 let (boxes, quant_boxes) = boxes;
924 let num_protos = protos.0.dim().2;
925
926 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
927 let mut boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
928 (boxes_tensor, quant_boxes),
929 (scores_tensor, quant_boxes),
930 score_threshold,
931 iou_threshold,
932 nms,
933 pre_nms_top_k,
934 max_det,
935 );
936 maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
937
938 impl_yolo_split_segdet_quant_process_masks::<_, _>(
939 boxes,
940 (mask_tensor, quant_boxes),
941 protos,
942 output_boxes,
943 output_masks,
944 )
945}
946
947#[allow(clippy::too_many_arguments)]
957pub(crate) fn impl_yolo_segdet_float<
958 B: BBoxTypeTrait,
959 BOX: Float + AsPrimitive<f32> + Send + Sync,
960 PROTO: Float + AsPrimitive<f32> + Send + Sync,
961>(
962 boxes: ArrayView2<BOX>,
963 protos: ArrayView3<PROTO>,
964 score_threshold: f32,
965 iou_threshold: f32,
966 nms: Option<Nms>,
967 pre_nms_top_k: usize,
968 max_det: usize,
969 normalized: Option<bool>,
970 input_dims: Option<(usize, usize)>,
971 output_boxes: &mut Vec<DetectBox>,
972 output_masks: &mut Vec<Segmentation>,
973) -> Result<(), crate::DecoderError>
974where
975 f32: AsPrimitive<BOX>,
976{
977 let num_protos = protos.dim().2;
978 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
979 let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
980 boxes_tensor,
981 scores_tensor,
982 score_threshold,
983 iou_threshold,
984 nms,
985 pre_nms_top_k,
986 max_det,
987 );
988 maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
989 impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
990}
991
992pub(crate) fn impl_yolo_segdet_get_boxes<
993 B: BBoxTypeTrait,
994 BOX: Float + AsPrimitive<f32> + Send + Sync,
995 SCORE: Float + AsPrimitive<f32> + Send + Sync,
996>(
997 boxes_tensor: ArrayView2<BOX>,
998 scores_tensor: ArrayView2<SCORE>,
999 score_threshold: f32,
1000 iou_threshold: f32,
1001 nms: Option<Nms>,
1002 pre_nms_top_k: usize,
1003 max_det: usize,
1004) -> Vec<(DetectBox, usize)>
1005where
1006 f32: AsPrimitive<SCORE>,
1007{
1008 let span = tracing::trace_span!(
1009 "decode",
1010 n_candidates = tracing::field::Empty,
1011 n_after_topk = tracing::field::Empty,
1012 n_after_nms = tracing::field::Empty,
1013 n_detections = tracing::field::Empty,
1014 );
1015 let _guard = span.enter();
1016
1017 let mut boxes = {
1018 let _s = tracing::trace_span!("score_filter").entered();
1019 postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor)
1020 };
1021 span.record("n_candidates", boxes.len());
1022
1023 if nms.is_some() {
1024 let _s = tracing::trace_span!("top_k", k = pre_nms_top_k).entered();
1025 truncate_to_top_k_by_score(&mut boxes, pre_nms_top_k);
1026 }
1027 span.record("n_after_topk", boxes.len());
1028
1029 let mut boxes = {
1030 let _s = tracing::trace_span!("nms").entered();
1031 dispatch_nms_extra_float(nms, iou_threshold, boxes)
1032 };
1033 span.record("n_after_nms", boxes.len());
1034
1035 boxes.sort_unstable_by(|a, b| b.0.score.total_cmp(&a.0.score));
1036 boxes.truncate(max_det);
1037 span.record("n_detections", boxes.len());
1038
1039 boxes
1040}
1041
1042pub(crate) fn impl_yolo_end_to_end_segdet_get_boxes<
1043 B: BBoxTypeTrait,
1044 BOX: Float + AsPrimitive<f32> + Send + Sync,
1045 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1046 CLASS: AsPrimitive<f32> + Send + Sync,
1047>(
1048 boxes: ArrayView2<BOX>,
1049 scores: ArrayView2<SCORE>,
1050 classes: ArrayView1<CLASS>,
1051 score_threshold: f32,
1052 max_boxes: usize,
1053) -> Vec<(DetectBox, usize)>
1054where
1055 f32: AsPrimitive<SCORE>,
1056{
1057 let mut boxes = postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes, scores);
1058 boxes.truncate(max_boxes);
1059 for (b, ind) in &mut boxes {
1060 b.label = classes[*ind].as_().round() as usize;
1061 }
1062 boxes
1063}
1064
1065pub(crate) fn impl_yolo_split_segdet_process_masks<
1066 MASK: Float + AsPrimitive<f32> + Send + Sync,
1067 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1068>(
1069 boxes: Vec<(DetectBox, usize)>,
1070 masks_tensor: ArrayView2<MASK>,
1071 protos_tensor: ArrayView3<PROTO>,
1072 output_boxes: &mut Vec<DetectBox>,
1073 output_masks: &mut Vec<Segmentation>,
1074) -> Result<(), crate::DecoderError> {
1075 let _span = tracing::trace_span!("process_masks", n = boxes.len(), mode = "float").entered();
1076 let boxes = decode_segdet_f32(boxes, masks_tensor, protos_tensor)?;
1080 output_boxes.clear();
1081 output_masks.clear();
1082 for (b, roi, m) in boxes.into_iter() {
1083 output_boxes.push(b);
1084 output_masks.push(Segmentation {
1085 xmin: roi.xmin,
1086 ymin: roi.ymin,
1087 xmax: roi.xmax,
1088 ymax: roi.ymax,
1089 segmentation: m,
1090 });
1091 }
1092 Ok(())
1093}
1094pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
1098 B: BBoxTypeTrait,
1099 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1100 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1101>(
1102 boxes: (ArrayView2<BOX>, Quantization),
1103 scores: (ArrayView2<SCORE>, Quantization),
1104 score_threshold: f32,
1105 iou_threshold: f32,
1106 nms: Option<Nms>,
1107 pre_nms_top_k: usize,
1108 max_det: usize,
1109) -> Vec<(DetectBox, usize)>
1110where
1111 f32: AsPrimitive<SCORE>,
1112{
1113 let (boxes_tensor, quant_boxes) = boxes;
1114 let (scores_tensor, quant_scores) = scores;
1115
1116 let span = tracing::trace_span!(
1117 "decode",
1118 n_candidates = tracing::field::Empty,
1119 n_after_topk = tracing::field::Empty,
1120 n_after_nms = tracing::field::Empty,
1121 n_detections = tracing::field::Empty,
1122 );
1123 let _guard = span.enter();
1124
1125 let mut boxes = {
1126 let _s = tracing::trace_span!("score_filter").entered();
1127 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
1128 postprocess_boxes_index_quant::<B, _, _>(
1129 score_threshold,
1130 boxes_tensor,
1131 scores_tensor,
1132 quant_boxes,
1133 )
1134 };
1135 span.record("n_candidates", boxes.len());
1136
1137 if nms.is_some() {
1138 let _s = tracing::trace_span!("top_k", k = pre_nms_top_k).entered();
1139 truncate_to_top_k_by_score_quant(&mut boxes, pre_nms_top_k);
1140 }
1141 span.record("n_after_topk", boxes.len());
1142
1143 let mut boxes = {
1144 let _s = tracing::trace_span!("nms").entered();
1145 dispatch_nms_extra_int(nms, iou_threshold, boxes)
1146 };
1147 span.record("n_after_nms", boxes.len());
1148
1149 boxes.sort_unstable_by(|a, b| b.0.score.cmp(&a.0.score));
1151 boxes.truncate(max_det);
1152 let result: Vec<_> = {
1153 let _s = tracing::trace_span!("box_dequant", n = boxes.len()).entered();
1154 boxes
1155 .into_iter()
1156 .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
1157 .collect()
1158 };
1159 span.record("n_detections", result.len());
1160
1161 result
1162}
1163
1164pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
1165 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1166 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1167>(
1168 boxes: Vec<(DetectBox, usize)>,
1169 mask_coeff: (ArrayView2<MASK>, Quantization),
1170 protos: (ArrayView3<PROTO>, Quantization),
1171 output_boxes: &mut Vec<DetectBox>,
1172 output_masks: &mut Vec<Segmentation>,
1173) -> Result<(), crate::DecoderError> {
1174 let _span = tracing::trace_span!("process_masks", n = boxes.len(), mode = "quant").entered();
1175 let (masks, quant_masks) = mask_coeff;
1176 let (protos, quant_protos) = protos;
1177
1178 let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
1182 output_boxes.clear();
1183 output_masks.clear();
1184 for (b, roi, m) in boxes.into_iter() {
1185 output_boxes.push(b);
1186 output_masks.push(Segmentation {
1187 xmin: roi.xmin,
1188 ymin: roi.ymin,
1189 xmax: roi.xmax,
1190 ymax: roi.ymax,
1191 segmentation: m,
1192 });
1193 }
1194 Ok(())
1195}
1196
1197#[allow(clippy::too_many_arguments)]
1209pub(crate) fn impl_yolo_split_segdet_float<
1210 B: BBoxTypeTrait,
1211 BOX: Float + AsPrimitive<f32> + Send + Sync,
1212 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1213 MASK: Float + AsPrimitive<f32> + Send + Sync,
1214 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1215>(
1216 boxes_tensor: ArrayView2<BOX>,
1217 scores_tensor: ArrayView2<SCORE>,
1218 mask_tensor: ArrayView2<MASK>,
1219 protos: ArrayView3<PROTO>,
1220 score_threshold: f32,
1221 iou_threshold: f32,
1222 nms: Option<Nms>,
1223 pre_nms_top_k: usize,
1224 max_det: usize,
1225 normalized: Option<bool>,
1226 input_dims: Option<(usize, usize)>,
1227 output_boxes: &mut Vec<DetectBox>,
1228 output_masks: &mut Vec<Segmentation>,
1229) -> Result<(), crate::DecoderError>
1230where
1231 f32: AsPrimitive<SCORE>,
1232{
1233 let (boxes_tensor, scores_tensor, mask_tensor) =
1234 postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1235
1236 let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1237 boxes_tensor,
1238 scores_tensor,
1239 score_threshold,
1240 iou_threshold,
1241 nms,
1242 pre_nms_top_k,
1243 max_det,
1244 );
1245 maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1246 impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1247}
1248
1249#[allow(clippy::too_many_arguments)]
1256pub(crate) fn impl_yolo_segdet_quant_proto<
1257 B: BBoxTypeTrait,
1258 BOX: PrimInt
1259 + AsPrimitive<i64>
1260 + AsPrimitive<i128>
1261 + AsPrimitive<f32>
1262 + AsPrimitive<i8>
1263 + Send
1264 + Sync,
1265 PROTO: PrimInt
1266 + AsPrimitive<i64>
1267 + AsPrimitive<i128>
1268 + AsPrimitive<f32>
1269 + AsPrimitive<i8>
1270 + Send
1271 + Sync,
1272>(
1273 boxes: (ArrayView2<BOX>, Quantization),
1274 protos: (ArrayView3<PROTO>, Quantization),
1275 score_threshold: f32,
1276 iou_threshold: f32,
1277 nms: Option<Nms>,
1278 pre_nms_top_k: usize,
1279 max_det: usize,
1280 normalized: Option<bool>,
1281 input_dims: Option<(usize, usize)>,
1282 output_boxes: &mut Vec<DetectBox>,
1283) -> ProtoData
1284where
1285 f32: AsPrimitive<BOX>,
1286{
1287 let (boxes_arr, quant_boxes) = boxes;
1288 let (protos_arr, quant_protos) = protos;
1289 let num_protos = protos_arr.dim().2;
1290
1291 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1292
1293 let mut det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1294 (boxes_tensor, quant_boxes),
1295 (scores_tensor, quant_boxes),
1296 score_threshold,
1297 iou_threshold,
1298 nms,
1299 pre_nms_top_k,
1300 max_det,
1301 );
1302 maybe_normalize_boxes_in_place(&mut det_indices, normalized, input_dims);
1303
1304 extract_proto_data_quant(
1305 det_indices,
1306 mask_tensor,
1307 quant_boxes,
1308 protos_arr,
1309 quant_protos,
1310 output_boxes,
1311 )
1312}
1313
1314#[allow(clippy::too_many_arguments)]
1317pub(crate) fn impl_yolo_segdet_float_proto<
1318 B: BBoxTypeTrait,
1319 BOX: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1320 PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1321>(
1322 boxes: ArrayView2<BOX>,
1323 protos: ArrayView3<PROTO>,
1324 score_threshold: f32,
1325 iou_threshold: f32,
1326 nms: Option<Nms>,
1327 pre_nms_top_k: usize,
1328 max_det: usize,
1329 normalized: Option<bool>,
1330 input_dims: Option<(usize, usize)>,
1331 output_boxes: &mut Vec<DetectBox>,
1332) -> ProtoData
1333where
1334 f32: AsPrimitive<BOX>,
1335{
1336 let num_protos = protos.dim().2;
1337 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1338
1339 let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1340 boxes_tensor,
1341 scores_tensor,
1342 score_threshold,
1343 iou_threshold,
1344 nms,
1345 pre_nms_top_k,
1346 max_det,
1347 );
1348 maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1349
1350 extract_proto_data_float(boxes, mask_tensor, protos, output_boxes)
1351}
1352
1353#[allow(clippy::too_many_arguments)]
1356pub(crate) fn impl_yolo_split_segdet_float_proto<
1357 B: BBoxTypeTrait,
1358 BOX: Float + AsPrimitive<f32> + Send + Sync,
1359 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1360 MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1361 PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1362>(
1363 boxes_tensor: ArrayView2<BOX>,
1364 scores_tensor: ArrayView2<SCORE>,
1365 mask_tensor: ArrayView2<MASK>,
1366 protos: ArrayView3<PROTO>,
1367 score_threshold: f32,
1368 iou_threshold: f32,
1369 nms: Option<Nms>,
1370 pre_nms_top_k: usize,
1371 max_det: usize,
1372 output_boxes: &mut Vec<DetectBox>,
1373) -> ProtoData
1374where
1375 f32: AsPrimitive<SCORE>,
1376{
1377 let (boxes_tensor, scores_tensor, mask_tensor) =
1378 postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1379 let det_indices = impl_yolo_segdet_get_boxes::<B, _, _>(
1380 boxes_tensor,
1381 scores_tensor,
1382 score_threshold,
1383 iou_threshold,
1384 nms,
1385 pre_nms_top_k,
1386 max_det,
1387 );
1388
1389 extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1390}
1391
1392pub(crate) fn decode_yolo_end_to_end_segdet_float_proto<T>(
1394 output: ArrayView2<T>,
1395 protos: ArrayView3<T>,
1396 score_threshold: f32,
1397 output_boxes: &mut Vec<DetectBox>,
1398) -> Result<ProtoData, crate::DecoderError>
1399where
1400 T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1401 f32: AsPrimitive<T>,
1402{
1403 let (boxes, scores, classes, mask_coeff) =
1404 postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
1405 let cap = cap_or_default(output_boxes);
1406 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1407 boxes,
1408 scores,
1409 classes,
1410 score_threshold,
1411 cap,
1412 );
1413
1414 Ok(extract_proto_data_float(
1415 boxes,
1416 mask_coeff,
1417 protos,
1418 output_boxes,
1419 ))
1420}
1421
1422#[allow(clippy::too_many_arguments)]
1424pub(crate) fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1425 boxes: ArrayView2<T>,
1426 scores: ArrayView2<T>,
1427 classes: ArrayView2<T>,
1428 mask_coeff: ArrayView2<T>,
1429 protos: ArrayView3<T>,
1430 score_threshold: f32,
1431 output_boxes: &mut Vec<DetectBox>,
1432) -> Result<ProtoData, crate::DecoderError>
1433where
1434 T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1435 f32: AsPrimitive<T>,
1436{
1437 let (boxes, scores, classes, mask_coeff) =
1438 postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
1439 let cap = cap_or_default(output_boxes);
1440 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1441 boxes,
1442 scores,
1443 classes,
1444 score_threshold,
1445 cap,
1446 );
1447
1448 Ok(extract_proto_data_float(
1449 boxes,
1450 mask_coeff,
1451 protos,
1452 output_boxes,
1453 ))
1454}
1455
1456pub(super) fn extract_proto_data_float<
1463 MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1464 PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1465>(
1466 det_indices: Vec<(DetectBox, usize)>,
1467 mask_tensor: ArrayView2<MASK>,
1468 protos: ArrayView3<PROTO>,
1469 output_boxes: &mut Vec<DetectBox>,
1470) -> ProtoData {
1471 let _span = tracing::trace_span!(
1472 "extract_proto",
1473 n = det_indices.len(),
1474 num_protos = mask_tensor.ncols(),
1475 layout = "nhwc",
1476 )
1477 .entered();
1478
1479 let num_protos = mask_tensor.ncols();
1480 let n = det_indices.len();
1481
1482 let mut coeff_rows: Vec<MASK> = Vec::with_capacity(n * num_protos);
1487 output_boxes.clear();
1488 for (det, idx) in det_indices {
1489 output_boxes.push(det);
1490 let row = mask_tensor.row(idx);
1491 coeff_rows.extend(row.iter().copied());
1492 }
1493
1494 let mask_coefficients = MASK::slice_into_tensor_dyn(&coeff_rows, &[n, num_protos])
1495 .expect("allocating mask_coefficients TensorDyn");
1496 let protos_tensor =
1497 PROTO::arrayview3_into_tensor_dyn(protos).expect("allocating protos TensorDyn");
1498
1499 ProtoData {
1500 mask_coefficients,
1501 protos: protos_tensor,
1502 layout: ProtoLayout::Nhwc,
1503 }
1504}
1505
1506pub(crate) fn extract_proto_data_quant<
1515 MASK: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync,
1516 PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
1517>(
1518 det_indices: Vec<(DetectBox, usize)>,
1519 mask_tensor: ArrayView2<MASK>,
1520 quant_masks: Quantization,
1521 protos: ArrayView3<PROTO>,
1522 quant_protos: Quantization,
1523 output_boxes: &mut Vec<DetectBox>,
1524) -> ProtoData {
1525 use edgefirst_tensor::{Tensor, TensorDyn, TensorMapTrait, TensorMemory, TensorTrait};
1526
1527 let span = tracing::trace_span!(
1528 "extract_proto",
1529 n = det_indices.len(),
1530 num_protos = tracing::field::Empty,
1531 layout = tracing::field::Empty,
1532 );
1533 let _guard = span.enter();
1534
1535 let num_protos = mask_tensor.ncols();
1536 let n = det_indices.len();
1537 span.record("num_protos", num_protos);
1538
1539 if n == 0 {
1545 output_boxes.clear();
1546 let (h, w, k) = protos.dim();
1547
1548 let (proto_shape, proto_layout) = if std::any::TypeId::of::<PROTO>()
1550 == std::any::TypeId::of::<i8>()
1551 {
1552 if protos.is_standard_layout() {
1553 (&[h, w, k][..], ProtoLayout::Nhwc)
1554 } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1555 (&[k, h, w][..], ProtoLayout::Nchw)
1556 } else {
1557 (&[h, w, k][..], ProtoLayout::Nhwc)
1558 }
1559 } else {
1560 (&[h, w, k][..], ProtoLayout::Nhwc)
1561 };
1562
1563 let coeff_tensor = Tensor::<i8>::new(&[0, num_protos], Some(TensorMemory::Mem), None)
1564 .expect("allocating empty mask_coefficients tensor");
1565 let coeff_quant =
1566 edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1567 let coeff_tensor = coeff_tensor
1568 .with_quantization(coeff_quant)
1569 .expect("per-tensor quantization on mask coefficients");
1570 let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
1571 .expect("allocating protos tensor");
1572 let tensor_quant =
1573 edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1574 let protos_tensor = protos_tensor
1575 .with_quantization(tensor_quant)
1576 .expect("per-tensor quantization on protos tensor");
1577 return ProtoData {
1578 mask_coefficients: TensorDyn::I8(coeff_tensor),
1579 protos: TensorDyn::I8(protos_tensor),
1580 layout: proto_layout,
1581 };
1582 }
1583
1584 let mut coeff_i8 = Vec::<i8>::with_capacity(n * num_protos);
1588 output_boxes.clear();
1589 for (det, idx) in det_indices {
1590 output_boxes.push(det);
1591 let row = mask_tensor.row(idx);
1592 coeff_i8.extend(row.iter().map(|v| {
1593 let v_i8: i8 = v.as_();
1594 v_i8
1595 }));
1596 }
1597
1598 let coeff_tensor = Tensor::<i8>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
1601 .expect("allocating mask_coefficients tensor");
1602 if n > 0 {
1603 let mut m = coeff_tensor
1604 .map()
1605 .expect("mapping mask_coefficients tensor");
1606 m.as_mut_slice().copy_from_slice(&coeff_i8);
1607 }
1608 let coeff_quant =
1609 edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1610 let coeff_tensor = coeff_tensor
1611 .with_quantization(coeff_quant)
1612 .expect("per-tensor quantization on mask coefficients");
1613 let mask_coefficients = TensorDyn::I8(coeff_tensor);
1614
1615 let (h, w, k) = protos.dim();
1619
1620 let (proto_shape, proto_layout) =
1622 if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1623 if protos.is_standard_layout() {
1624 (&[h, w, k][..], ProtoLayout::Nhwc)
1626 } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1627 (&[k, h, w][..], ProtoLayout::Nchw)
1631 } else {
1632 (&[h, w, k][..], ProtoLayout::Nhwc)
1634 }
1635 } else {
1636 (&[h, w, k][..], ProtoLayout::Nhwc)
1637 };
1638
1639 let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
1640 .expect("allocating protos tensor");
1641 {
1642 let mut m = protos_tensor.map().expect("mapping protos tensor");
1643 let dst = m.as_mut_slice();
1644 if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1645 if protos.is_standard_layout() {
1648 let src: &[i8] = unsafe {
1649 std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len())
1650 };
1651 dst.copy_from_slice(src);
1652 } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1653 let total = h * w * k;
1657 let src: &[i8] =
1660 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, total) };
1661 dst.copy_from_slice(src);
1662 } else {
1663 for (d, s) in dst.iter_mut().zip(protos.iter()) {
1664 let v_i8: i8 = s.as_();
1665 *d = v_i8;
1666 }
1667 }
1668 } else {
1669 for (d, s) in dst.iter_mut().zip(protos.iter()) {
1670 let v_i8: i8 = s.as_();
1671 *d = v_i8;
1672 }
1673 }
1674 }
1675 let tensor_quant =
1676 edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1677 let protos_tensor = protos_tensor
1678 .with_quantization(tensor_quant)
1679 .expect("per-tensor quantization on new Tensor<i8>");
1680
1681 span.record("layout", tracing::field::debug(&proto_layout));
1682
1683 ProtoData {
1684 mask_coefficients,
1685 protos: TensorDyn::I8(protos_tensor),
1686 layout: proto_layout,
1687 }
1688}
1689
1690pub trait FloatProtoElem: Copy + 'static {
1696 fn slice_into_tensor_dyn(
1697 values: &[Self],
1698 shape: &[usize],
1699 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;
1700
1701 fn arrayview3_into_tensor_dyn(
1702 view: ArrayView3<'_, Self>,
1703 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;
1704}
1705
1706impl FloatProtoElem for f32 {
1707 fn slice_into_tensor_dyn(
1708 values: &[f32],
1709 shape: &[usize],
1710 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1711 edgefirst_tensor::Tensor::<f32>::from_slice(values, shape)
1712 .map(edgefirst_tensor::TensorDyn::F32)
1713 }
1714 fn arrayview3_into_tensor_dyn(
1715 view: ArrayView3<'_, f32>,
1716 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1717 edgefirst_tensor::Tensor::<f32>::from_arrayview3(view).map(edgefirst_tensor::TensorDyn::F32)
1718 }
1719}
1720
1721impl FloatProtoElem for half::f16 {
1722 fn slice_into_tensor_dyn(
1723 values: &[half::f16],
1724 shape: &[usize],
1725 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1726 edgefirst_tensor::Tensor::<half::f16>::from_slice(values, shape)
1727 .map(edgefirst_tensor::TensorDyn::F16)
1728 }
1729 fn arrayview3_into_tensor_dyn(
1730 view: ArrayView3<'_, half::f16>,
1731 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1732 edgefirst_tensor::Tensor::<half::f16>::from_arrayview3(view)
1733 .map(edgefirst_tensor::TensorDyn::F16)
1734 }
1735}
1736
1737impl FloatProtoElem for f64 {
1738 fn slice_into_tensor_dyn(
1739 values: &[f64],
1740 shape: &[usize],
1741 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1742 let narrowed: Vec<f32> = values.iter().map(|&v| v as f32).collect();
1744 edgefirst_tensor::Tensor::<f32>::from_slice(&narrowed, shape)
1745 .map(edgefirst_tensor::TensorDyn::F32)
1746 }
1747 fn arrayview3_into_tensor_dyn(
1748 view: ArrayView3<'_, f64>,
1749 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1750 let narrowed: ndarray::Array3<f32> = view.mapv(|v| v as f32);
1751 edgefirst_tensor::Tensor::<f32>::from_arrayview3(narrowed.view())
1752 .map(edgefirst_tensor::TensorDyn::F32)
1753 }
1754}
1755
1756fn postprocess_yolo<'a, T>(
1757 output: &'a ArrayView2<'_, T>,
1758) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1759 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1760 let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1761 (boxes_tensor, scores_tensor)
1762}
1763
1764pub(crate) fn postprocess_yolo_seg<'a, T>(
1765 output: &'a ArrayView2<'_, T>,
1766 num_protos: usize,
1767) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1768 assert!(
1769 output.shape()[0] > num_protos + 4,
1770 "Output shape is too short: {} <= {} + 4",
1771 output.shape()[0],
1772 num_protos
1773 );
1774 let num_classes = output.shape()[0] - 4 - num_protos;
1775 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1776 let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1777 let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1778 (boxes_tensor, scores_tensor, mask_tensor)
1779}
1780
1781pub(crate) fn postprocess_yolo_split_segdet<'a, 'b, 'c, BOX, SCORE, MASK>(
1782 boxes_tensor: ArrayView2<'a, BOX>,
1783 scores_tensor: ArrayView2<'b, SCORE>,
1784 mask_tensor: ArrayView2<'c, MASK>,
1785) -> (
1786 ArrayView2<'a, BOX>,
1787 ArrayView2<'b, SCORE>,
1788 ArrayView2<'c, MASK>,
1789) {
1790 let boxes_tensor = boxes_tensor.reversed_axes();
1791 let scores_tensor = scores_tensor.reversed_axes();
1792 let mask_tensor = mask_tensor.reversed_axes();
1793 (boxes_tensor, scores_tensor, mask_tensor)
1794}
1795
1796fn decode_segdet_f32<
1797 MASK: Float + AsPrimitive<f32> + Send + Sync,
1798 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1799>(
1800 boxes: Vec<(DetectBox, usize)>,
1801 masks: ArrayView2<MASK>,
1802 protos: ArrayView3<PROTO>,
1803) -> Result<Vec<(DetectBox, BoundingBox, Array3<u8>)>, crate::DecoderError> {
1804 if boxes.is_empty() {
1805 return Ok(Vec::new());
1806 }
1807 if masks.shape()[1] != protos.shape()[2] {
1808 return Err(crate::DecoderError::InvalidShape(format!(
1809 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1810 masks.shape()[1],
1811 protos.shape()[2],
1812 )));
1813 }
1814 boxes
1815 .into_par_iter()
1816 .map(|b| {
1817 let ind = b.1;
1818 let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1823 Ok((b.0, roi, make_segmentation(masks.row(ind), protos.view())))
1824 })
1825 .collect()
1826}
1827
1828pub(crate) fn decode_segdet_quant<
1829 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1830 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1831>(
1832 boxes: Vec<(DetectBox, usize)>,
1833 masks: ArrayView2<MASK>,
1834 protos: ArrayView3<PROTO>,
1835 quant_masks: Quantization,
1836 quant_protos: Quantization,
1837) -> Result<Vec<(DetectBox, BoundingBox, Array3<u8>)>, crate::DecoderError> {
1838 if boxes.is_empty() {
1839 return Ok(Vec::new());
1840 }
1841 if masks.shape()[1] != protos.shape()[2] {
1842 return Err(crate::DecoderError::InvalidShape(format!(
1843 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1844 masks.shape()[1],
1845 protos.shape()[2],
1846 )));
1847 }
1848
1849 let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; boxes
1851 .into_iter()
1852 .map(|b| {
1853 let i = b.1;
1854 let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1858 let seg = match total_bits {
1859 0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1860 masks.row(i),
1861 protos.view(),
1862 quant_masks,
1863 quant_protos,
1864 ),
1865 65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1866 masks.row(i),
1867 protos.view(),
1868 quant_masks,
1869 quant_protos,
1870 ),
1871 _ => {
1872 return Err(crate::DecoderError::NotSupported(format!(
1873 "Unsupported bit width ({total_bits}) for segmentation computation"
1874 )));
1875 }
1876 };
1877 Ok((b.0, roi, seg))
1878 })
1879 .collect()
1880}
1881
1882fn protobox<'a, T>(
1883 protos: &'a ArrayView3<T>,
1884 roi: &BoundingBox,
1885) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1886 let width = protos.dim().1 as f32;
1887 let height = protos.dim().0 as f32;
1888
1889 const NORM_LIMIT: f32 = 2.0;
1901 if roi.xmin > NORM_LIMIT
1902 || roi.ymin > NORM_LIMIT
1903 || roi.xmax > NORM_LIMIT
1904 || roi.ymax > NORM_LIMIT
1905 {
1906 return Err(crate::DecoderError::InvalidShape(format!(
1907 "Bounding box coordinates appear un-normalized (pixel-space). \
1908 Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
1909 Two ways to fix this: \
1910 (1) declare `Detection::normalized = false` in the model schema \
1911 AND make sure the schema's `input.shape` / `input.dshape` carries \
1912 the model input dims so the decoder can divide by (W, H) before NMS \
1913 (EDGEAI-1303 — verify with `Decoder::input_dims().is_some()`); or \
1914 (2) normalize the boxes in-graph before decode().",
1915 roi.xmin, roi.ymin, roi.xmax, roi.ymax,
1916 )));
1917 }
1918
1919 let roi = [
1920 (roi.xmin * width).clamp(0.0, width) as usize,
1921 (roi.ymin * height).clamp(0.0, height) as usize,
1922 (roi.xmax * width).clamp(0.0, width).ceil() as usize,
1923 (roi.ymax * height).clamp(0.0, height).ceil() as usize,
1924 ];
1925
1926 let roi_norm = [
1927 roi[0] as f32 / width,
1928 roi[1] as f32 / height,
1929 roi[2] as f32 / width,
1930 roi[3] as f32 / height,
1931 ]
1932 .into();
1933
1934 let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
1935
1936 Ok((cropped, roi_norm))
1937}
1938
1939fn make_segmentation<
1945 MASK: Float + AsPrimitive<f32> + Send + Sync,
1946 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1947>(
1948 mask: ArrayView1<MASK>,
1949 protos: ArrayView3<PROTO>,
1950) -> Array3<u8> {
1951 let shape = protos.shape();
1952
1953 let mask = mask.to_shape((1, mask.len())).unwrap();
1955 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1956 let protos = protos.reversed_axes();
1957 let mask = mask.map(|x| x.as_());
1958 let protos = protos.map(|x| x.as_());
1959
1960 let mask = mask
1962 .dot(&protos)
1963 .into_shape_with_order((shape[0], shape[1], 1))
1964 .unwrap();
1965
1966 mask.map(|x| {
1967 let sigmoid = 1.0 / (1.0 + (-*x).exp());
1968 (sigmoid * 255.0).round() as u8
1969 })
1970}
1971
1972fn make_segmentation_quant<
1979 MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
1980 PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
1981 DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
1982>(
1983 mask: ArrayView1<MASK>,
1984 protos: ArrayView3<PROTO>,
1985 quant_masks: Quantization,
1986 quant_protos: Quantization,
1987) -> Array3<u8>
1988where
1989 i32: AsPrimitive<DEST>,
1990 f32: AsPrimitive<DEST>,
1991{
1992 let shape = protos.shape();
1993
1994 let mask = mask.to_shape((1, mask.len())).unwrap();
1996
1997 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1998 let protos = protos.reversed_axes();
1999
2000 let zp = quant_masks.zero_point.as_();
2001
2002 let mask = mask.mapv(|x| x.as_() - zp);
2003
2004 let zp = quant_protos.zero_point.as_();
2005 let protos = protos.mapv(|x| x.as_() - zp);
2006
2007 let segmentation = mask
2009 .dot(&protos)
2010 .into_shape_with_order((shape[0], shape[1], 1))
2011 .unwrap();
2012
2013 let combined_scale = quant_masks.scale * quant_protos.scale;
2014 segmentation.map(|x| {
2015 let val: f32 = (*x).as_() * combined_scale;
2016 let sigmoid = 1.0 / (1.0 + (-val).exp());
2017 (sigmoid * 255.0).round() as u8
2018 })
2019}
2020
2021pub(crate) fn yolo_segmentation_to_mask(
2033 segmentation: ArrayView3<u8>,
2034 threshold: u8,
2035) -> Result<Array2<u8>, crate::DecoderError> {
2036 if segmentation.shape()[2] != 1 {
2037 return Err(crate::DecoderError::InvalidShape(format!(
2038 "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
2039 segmentation.shape()[2]
2040 )));
2041 }
2042 Ok(segmentation
2043 .slice(s![.., .., 0])
2044 .map(|x| if *x >= threshold { 1 } else { 0 }))
2045}
2046
2047#[cfg(test)]
2048#[cfg_attr(coverage_nightly, coverage(off))]
2049mod tests {
2050 use super::*;
2051 use ndarray::Array2;
2052
2053 #[test]
2058 fn test_end_to_end_det_basic_filtering() {
2059 let data: Vec<f32> = vec![
2063 0.1, 0.2, 0.3, 0.1, 0.2, 0.3, 0.5, 0.6, 0.7, 0.5, 0.6, 0.7, 0.9, 0.1, 0.2, 0.0, 1.0, 2.0, ];
2071 let output = Array2::from_shape_vec((6, 3), data).unwrap();
2072
2073 let mut boxes = Vec::with_capacity(10);
2074 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2075
2076 assert_eq!(boxes.len(), 1);
2078 assert_eq!(boxes[0].label, 0);
2079 assert!((boxes[0].score - 0.9).abs() < 0.01);
2080 assert!((boxes[0].bbox.xmin - 0.1).abs() < 0.01);
2081 assert!((boxes[0].bbox.ymin - 0.1).abs() < 0.01);
2082 assert!((boxes[0].bbox.xmax - 0.5).abs() < 0.01);
2083 assert!((boxes[0].bbox.ymax - 0.5).abs() < 0.01);
2084 }
2085
2086 #[test]
2087 fn test_end_to_end_det_all_pass_threshold() {
2088 let data: Vec<f32> = vec![
2090 10.0, 20.0, 10.0, 20.0, 50.0, 60.0, 50.0, 60.0, 0.8, 0.7, 1.0, 2.0, ];
2097 let output = Array2::from_shape_vec((6, 2), data).unwrap();
2098
2099 let mut boxes = Vec::with_capacity(10);
2100 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2101
2102 assert_eq!(boxes.len(), 2);
2103 assert_eq!(boxes[0].label, 1);
2104 assert_eq!(boxes[1].label, 2);
2105 }
2106
2107 #[test]
2108 fn test_end_to_end_det_none_pass_threshold() {
2109 let data: Vec<f32> = vec![
2111 10.0, 20.0, 10.0, 20.0, 50.0, 60.0, 50.0, 60.0, 0.1, 0.2, 1.0, 2.0, ];
2118 let output = Array2::from_shape_vec((6, 2), data).unwrap();
2119
2120 let mut boxes = Vec::with_capacity(10);
2121 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2122
2123 assert_eq!(boxes.len(), 0);
2124 }
2125
2126 #[test]
2127 fn test_end_to_end_det_capacity_limit() {
2128 let data: Vec<f32> = vec![
2130 0.1, 0.2, 0.3, 0.4, 0.5, 0.1, 0.2, 0.3, 0.4, 0.5, 0.5, 0.6, 0.7, 0.8, 0.9, 0.5, 0.6, 0.7, 0.8, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.0, 1.0, 2.0, 3.0, 4.0, ];
2137 let output = Array2::from_shape_vec((6, 5), data).unwrap();
2138
2139 let mut boxes = Vec::with_capacity(2); decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2141
2142 assert_eq!(boxes.len(), 2);
2143 }
2144
2145 #[test]
2146 fn test_end_to_end_det_empty_output() {
2147 let output = Array2::<f32>::zeros((6, 0));
2149
2150 let mut boxes = Vec::with_capacity(10);
2151 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2152
2153 assert_eq!(boxes.len(), 0);
2154 }
2155
2156 #[test]
2157 fn test_end_to_end_det_pixel_coordinates() {
2158 let data: Vec<f32> = vec![
2160 100.0, 200.0, 300.0, 400.0, 0.95, 5.0, ];
2167 let output = Array2::from_shape_vec((6, 1), data).unwrap();
2168
2169 let mut boxes = Vec::with_capacity(10);
2170 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2171
2172 assert_eq!(boxes.len(), 1);
2173 assert_eq!(boxes[0].label, 5);
2174 assert!((boxes[0].bbox.xmin - 100.0).abs() < 0.01);
2175 assert!((boxes[0].bbox.ymin - 200.0).abs() < 0.01);
2176 assert!((boxes[0].bbox.xmax - 300.0).abs() < 0.01);
2177 assert!((boxes[0].bbox.ymax - 400.0).abs() < 0.01);
2178 }
2179
2180 #[test]
2181 fn test_end_to_end_det_invalid_shape() {
2182 let output = Array2::<f32>::zeros((5, 3));
2184
2185 let mut boxes = Vec::with_capacity(10);
2186 let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);
2187
2188 assert!(result.is_err());
2189 assert!(matches!(
2190 result,
2191 Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
2192 ));
2193 }
2194
2195 #[test]
2200 fn test_end_to_end_segdet_basic() {
2201 let num_protos = 32;
2204 let num_detections = 2;
2205 let num_features = 6 + num_protos;
2206
2207 let mut data = vec![0.0f32; num_features * num_detections];
2209 data[0] = 0.1; data[1] = 0.5; data[num_detections] = 0.1; data[num_detections + 1] = 0.5; data[2 * num_detections] = 0.4; data[2 * num_detections + 1] = 0.9; data[3 * num_detections] = 0.4; data[3 * num_detections + 1] = 0.9; data[4 * num_detections] = 0.9; data[4 * num_detections + 1] = 0.3; data[5 * num_detections] = 1.0; data[5 * num_detections + 1] = 2.0; for i in 6..num_features {
2224 data[i * num_detections] = 0.1;
2225 data[i * num_detections + 1] = 0.1;
2226 }
2227
2228 let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2229
2230 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2232
2233 let mut boxes = Vec::with_capacity(10);
2234 let mut masks = Vec::with_capacity(10);
2235 decode_yolo_end_to_end_segdet_float(
2236 output.view(),
2237 protos.view(),
2238 0.5,
2239 &mut boxes,
2240 &mut masks,
2241 )
2242 .unwrap();
2243
2244 assert_eq!(boxes.len(), 1);
2246 assert_eq!(masks.len(), 1);
2247 assert_eq!(boxes[0].label, 1);
2248 assert!((boxes[0].score - 0.9).abs() < 0.01);
2249 }
2250
2251 #[test]
2252 fn test_end_to_end_segdet_mask_coordinates() {
2253 let num_protos = 32;
2255 let num_features = 6 + num_protos;
2256
2257 let mut data = vec![0.0f32; num_features];
2258 data[0] = 0.2; data[1] = 0.2; data[2] = 0.8; data[3] = 0.8; data[4] = 0.95; data[5] = 3.0; let output = Array2::from_shape_vec((num_features, 1), data).unwrap();
2266 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2267
2268 let mut boxes = Vec::with_capacity(10);
2269 let mut masks = Vec::with_capacity(10);
2270 decode_yolo_end_to_end_segdet_float(
2271 output.view(),
2272 protos.view(),
2273 0.5,
2274 &mut boxes,
2275 &mut masks,
2276 )
2277 .unwrap();
2278
2279 assert_eq!(boxes.len(), 1);
2280 assert_eq!(masks.len(), 1);
2281
2282 let step = 1.0 / 16.0;
2286 assert!(masks[0].xmin <= boxes[0].bbox.xmin);
2287 assert!(masks[0].ymin <= boxes[0].bbox.ymin);
2288 assert!(masks[0].xmax >= boxes[0].bbox.xmax);
2289 assert!(masks[0].ymax >= boxes[0].bbox.ymax);
2290 assert!((boxes[0].bbox.xmin - masks[0].xmin) < step);
2291 assert!((boxes[0].bbox.ymin - masks[0].ymin) < step);
2292 assert!((masks[0].xmax - boxes[0].bbox.xmax) < step);
2293 assert!((masks[0].ymax - boxes[0].bbox.ymax) < step);
2294 }
2295
2296 #[test]
2297 fn test_end_to_end_segdet_empty_output() {
2298 let num_protos = 32;
2299 let output = Array2::<f32>::zeros((6 + num_protos, 0));
2300 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2301
2302 let mut boxes = Vec::with_capacity(10);
2303 let mut masks = Vec::with_capacity(10);
2304 decode_yolo_end_to_end_segdet_float(
2305 output.view(),
2306 protos.view(),
2307 0.5,
2308 &mut boxes,
2309 &mut masks,
2310 )
2311 .unwrap();
2312
2313 assert_eq!(boxes.len(), 0);
2314 assert_eq!(masks.len(), 0);
2315 }
2316
2317 #[test]
2318 fn test_end_to_end_segdet_capacity_limit() {
2319 let num_protos = 32;
2320 let num_detections = 5;
2321 let num_features = 6 + num_protos;
2322
2323 let mut data = vec![0.0f32; num_features * num_detections];
2324 for i in 0..num_detections {
2326 data[i] = 0.1 * (i as f32); data[num_detections + i] = 0.1 * (i as f32); data[2 * num_detections + i] = 0.1 * (i as f32) + 0.2; data[3 * num_detections + i] = 0.1 * (i as f32) + 0.2; data[4 * num_detections + i] = 0.9; data[5 * num_detections + i] = i as f32; }
2333
2334 let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2335 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2336
2337 let mut boxes = Vec::with_capacity(2); let mut masks = Vec::with_capacity(2);
2339 decode_yolo_end_to_end_segdet_float(
2340 output.view(),
2341 protos.view(),
2342 0.5,
2343 &mut boxes,
2344 &mut masks,
2345 )
2346 .unwrap();
2347
2348 assert_eq!(boxes.len(), 2);
2349 assert_eq!(masks.len(), 2);
2350 }
2351
2352 #[test]
2353 fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
2354 let output = Array2::<f32>::zeros((6, 3));
2356 let protos = Array3::<f32>::zeros((16, 16, 32));
2357
2358 let mut boxes = Vec::with_capacity(10);
2359 let mut masks = Vec::with_capacity(10);
2360 let result = decode_yolo_end_to_end_segdet_float(
2361 output.view(),
2362 protos.view(),
2363 0.5,
2364 &mut boxes,
2365 &mut masks,
2366 );
2367
2368 assert!(result.is_err());
2369 assert!(matches!(
2370 result,
2371 Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
2372 ));
2373 }
2374
2375 #[test]
2376 fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
2377 let num_protos = 32;
2379 let output = Array2::<f32>::zeros((6 + 16, 3)); let protos = Array3::<f32>::zeros((16, 16, num_protos)); let mut boxes = Vec::with_capacity(10);
2383 let mut masks = Vec::with_capacity(10);
2384 let result = decode_yolo_end_to_end_segdet_float(
2385 output.view(),
2386 protos.view(),
2387 0.5,
2388 &mut boxes,
2389 &mut masks,
2390 );
2391
2392 assert!(result.is_err());
2393 assert!(matches!(
2394 result,
2395 Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
2396 ));
2397 }
2398
2399 #[test]
2404 fn test_split_end_to_end_segdet_basic() {
2405 let num_protos = 32;
2408 let num_detections = 2;
2409 let num_features = 6 + num_protos;
2410
2411 let mut data = vec![0.0f32; num_features * num_detections];
2413 data[0] = 0.1; data[1] = 0.5; data[num_detections] = 0.1; data[num_detections + 1] = 0.5; data[2 * num_detections] = 0.4; data[2 * num_detections + 1] = 0.9; data[3 * num_detections] = 0.4; data[3 * num_detections + 1] = 0.9; data[4 * num_detections] = 0.9; data[4 * num_detections + 1] = 0.3; data[5 * num_detections] = 1.0; data[5 * num_detections + 1] = 2.0; for i in 6..num_features {
2428 data[i * num_detections] = 0.1;
2429 data[i * num_detections + 1] = 0.1;
2430 }
2431
2432 let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2433 let box_coords = output.slice(s![..4, ..]);
2434 let scores = output.slice(s![4..5, ..]);
2435 let classes = output.slice(s![5..6, ..]);
2436 let mask_coeff = output.slice(s![6.., ..]);
2437 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2439
2440 let mut boxes = Vec::with_capacity(10);
2441 let mut masks = Vec::with_capacity(10);
2442 decode_yolo_split_end_to_end_segdet_float(
2443 box_coords,
2444 scores,
2445 classes,
2446 mask_coeff,
2447 protos.view(),
2448 0.5,
2449 &mut boxes,
2450 &mut masks,
2451 )
2452 .unwrap();
2453
2454 assert_eq!(boxes.len(), 1);
2456 assert_eq!(masks.len(), 1);
2457 assert_eq!(boxes[0].label, 1);
2458 assert!((boxes[0].score - 0.9).abs() < 0.01);
2459 }
2460
2461 #[test]
2466 fn test_segmentation_to_mask_basic() {
2467 let data: Vec<u8> = vec![
2469 100, 200, 50, 150, 10, 255, 128, 64, 0, 127, 128, 255, 64, 64, 192, 192, ];
2474 let segmentation = Array3::from_shape_vec((4, 4, 1), data).unwrap();
2475
2476 let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2477
2478 assert_eq!(mask[[0, 0]], 0); assert_eq!(mask[[0, 1]], 1); assert_eq!(mask[[0, 2]], 0); assert_eq!(mask[[0, 3]], 1); assert_eq!(mask[[1, 1]], 1); assert_eq!(mask[[1, 2]], 1); assert_eq!(mask[[2, 0]], 0); assert_eq!(mask[[2, 1]], 0); }
2488
2489 #[test]
2490 fn test_segmentation_to_mask_all_above() {
2491 let segmentation = Array3::from_elem((4, 4, 1), 255u8);
2492 let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2493 assert!(mask.iter().all(|&x| x == 1));
2494 }
2495
2496 #[test]
2497 fn test_segmentation_to_mask_all_below() {
2498 let segmentation = Array3::from_elem((4, 4, 1), 64u8);
2499 let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2500 assert!(mask.iter().all(|&x| x == 0));
2501 }
2502
2503 #[test]
2504 fn test_segmentation_to_mask_invalid_shape() {
2505 let segmentation = Array3::from_elem((4, 4, 3), 128u8);
2506 let result = yolo_segmentation_to_mask(segmentation.view(), 128);
2507
2508 assert!(result.is_err());
2509 assert!(matches!(
2510 result,
2511 Err(crate::DecoderError::InvalidShape(s)) if s.contains("(H, W, 1)")
2512 ));
2513 }
2514
2515 #[test]
2520 fn test_protobox_clamps_edge_coordinates() {
2521 let protos = Array3::<f32>::zeros((16, 16, 4));
2523 let view = protos.view();
2524 let roi = BoundingBox {
2525 xmin: 0.5,
2526 ymin: 0.5,
2527 xmax: 1.0,
2528 ymax: 1.0,
2529 };
2530 let result = protobox(&view, &roi);
2531 assert!(result.is_ok(), "protobox should accept xmax=1.0");
2532 let (cropped, _roi_norm) = result.unwrap();
2533 assert!(cropped.shape()[0] > 0);
2535 assert!(cropped.shape()[1] > 0);
2536 assert_eq!(cropped.shape()[2], 4);
2537 }
2538
2539 #[test]
2540 fn test_protobox_rejects_wildly_out_of_range() {
2541 let protos = Array3::<f32>::zeros((16, 16, 4));
2543 let view = protos.view();
2544 let roi = BoundingBox {
2545 xmin: 0.0,
2546 ymin: 0.0,
2547 xmax: 3.0,
2548 ymax: 3.0,
2549 };
2550 let result = protobox(&view, &roi);
2551 assert!(
2552 matches!(result, Err(crate::DecoderError::InvalidShape(s)) if s.contains("un-normalized")),
2553 "protobox should reject coords > NORM_LIMIT"
2554 );
2555 }
2556
#[test]
fn test_protobox_accepts_slightly_over_one() {
    // 1.5 exceeds 1.0 but stays within NORM_LIMIT, so it is tolerated; the crop
    // is expected to span the whole 16x16 prototype grid.
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let proto_view = protos.view();
    let roi = BoundingBox {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 1.5,
        ymax: 1.5,
    };

    let outcome = protobox(&proto_view, &roi);
    assert!(
        outcome.is_ok(),
        "protobox should accept coords <= NORM_LIMIT (2.0)"
    );

    let (crop, _) = outcome.unwrap();
    let (crop_rows, crop_cols) = (crop.shape()[0], crop.shape()[1]);
    assert_eq!(crop_rows, 16);
    assert_eq!(crop_cols, 16);
}
2578
#[test]
fn test_segdet_float_proto_no_panic() {
    // 100 identical proposals laid out as a (4 + classes + mask coeffs, proposals)
    // tensor: box rows 0..4 are 320, 320, 50, 50 (decoded as XYWH) and class 0
    // scores 0.9. Everything else, including all mask coefficients, is zero.
    let num_proposals = 100;
    let num_classes = 80;
    let num_mask_coeffs = 32;
    let rows = 4 + num_classes + num_mask_coeffs;

    let mut boxes = Array2::<f32>::zeros((rows, num_proposals));
    let row_values: [(usize, f32); 5] = [(0, 320.0), (1, 320.0), (2, 50.0), (3, 50.0), (4, 0.9)];
    for (row_idx, value) in row_values {
        boxes.row_mut(row_idx).fill(value);
    }

    let protos = Array3::<f32>::zeros((160, 160, num_mask_coeffs));
    let mut output_boxes = Vec::with_capacity(300);

    let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
        boxes.view(),
        protos.view(),
        0.5,
        0.7,
        Some(Nms::default()),
        MAX_NMS_CANDIDATES,
        300,
        None,
        None,
        &mut output_boxes,
    );

    // One coefficient row per surviving detection, one column per prototype.
    assert!(!output_boxes.is_empty());
    let coeff_dims = proto_data.mask_coefficients.shape();
    assert_eq!(coeff_dims[0], output_boxes.len());
    assert_eq!(coeff_dims[1], num_mask_coeffs);
}
2631
#[test]
fn test_pre_nms_cap_truncates_excess_candidates() {
    // 50_000 identical single-class boxes whose scores step down from 0.99 to ~0.985.
    // With iou = 1.0 the identical boxes are expected not to suppress one another,
    // so the output length should be governed by the pre-NMS cap alone.
    let n: usize = 50_000;
    let num_classes = 1;
    // Evaluated in f32 exactly as pushed below. It is non-increasing in `i`, so the
    // k-th highest input score equals `score_of(k - 1)`.
    let score_of = |i: usize| 0.99f32 - (i as f32) * 1e-7;

    let mut boxes_data = Vec::with_capacity(n * 4);
    let mut scores_data = Vec::with_capacity(n * num_classes);
    for i in 0..n {
        boxes_data.extend_from_slice(&[0.1f32, 0.1, 0.5, 0.5]);
        scores_data.push(score_of(i));
    }
    let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
    let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();

    let result = impl_yolo_segdet_get_boxes::<XYXY, _, _>(
        boxes.view(),
        scores.view(),
        0.1,
        1.0,
        Some(Nms::ClassAgnostic),
        crate::yolo::MAX_NMS_CANDIDATES,
        usize::MAX,
    );

    assert_eq!(
        result.len(),
        crate::yolo::MAX_NMS_CANDIDATES,
        "pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
        result.len()
    );
    let top_score = result[0].0.score;
    assert!(
        top_score > 0.98,
        "highest-ranked survivor should have the largest score, got {top_score}"
    );

    // Every input score is >= ~0.985, so the check above cannot tell whether the cap
    // kept the best candidates or the worst. Pin it down: no survivor may score below
    // the MAX_NMS_CANDIDATES-th highest input score.
    let cutoff = score_of(crate::yolo::MAX_NMS_CANDIDATES - 1);
    let lowest_kept = result
        .iter()
        .map(|r| r.0.score)
        .fold(f32::INFINITY, f32::min);
    assert!(
        lowest_kept >= cutoff,
        "pre-NMS cap should keep the top-scoring candidates; lowest survivor {lowest_kept} < cutoff {cutoff}"
    );
}
2689
#[test]
fn test_pre_nms_cap_truncates_excess_candidates_quant() {
    use crate::Quantization;

    // Quantized twin of the float pre-NMS cap test: far more candidates than the
    // cap, all with the same int8 box, so the output length should equal the cap.
    let candidate_count: usize = 50_000;
    let num_classes = 1;

    let box_template = [10i8, 10, 50, 50];
    let boxes = Array2::from_shape_fn((candidate_count, 4), |(_, col)| box_template[col]);
    let quant_boxes = Quantization {
        scale: 0.01,
        zero_point: 0,
    };

    // Raw u8 scores cycle 250, 249, ..., 51, so every value is repeated many times.
    let scores = Array2::from_shape_fn((candidate_count, num_classes), |(row, _)| {
        250u8.saturating_sub((row % 200) as u8)
    });
    let quant_scores = Quantization {
        scale: 0.00392,
        zero_point: 0,
    };

    let survivors = impl_yolo_split_segdet_quant_get_boxes::<XYXY, _, _>(
        (boxes.view(), quant_boxes),
        (scores.view(), quant_scores),
        0.1,
        1.0,
        Some(Nms::ClassAgnostic),
        crate::yolo::MAX_NMS_CANDIDATES,
        usize::MAX,
    );

    let kept = survivors.len();
    assert_eq!(
        kept,
        crate::yolo::MAX_NMS_CANDIDATES,
        "quant path pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
        kept
    );
}
2739
#[test]
fn segdet_combined_tensor_pairs_detection_with_matching_mask_row() {
    // Combined tensor of shape (4 + nc + nm, n), row-major, so that (row r, anchor c)
    // lives at index r * n + c.
    // - Rows 0..4 hold the box, presumably (cx, cy, w, h). The centre checks below
    //   rely on anchor 0 landing left of 0.3 and anchor 2 right of 0.7.
    // - Row 4 + k holds the class-k score.
    // - Rows 4 + nc.. hold the mask coefficients.
    // Anchor 1 gets no class score and should be dropped. With all-ones prototypes,
    // anchor 0's coefficients (+3, +3) should give a bright mask and anchor 2's
    // (-3, -3) a dark one.
    let nc = 2;
    let nm = 2;
    let n = 3;
    let feat = 4 + nc + nm;
    let mut data = vec![0.0f32; feat * n];
    let set = |d: &mut [f32], r: usize, c: usize, v: f32| d[r * n + c] = v;
    set(&mut data, 0, 0, 0.2);
    set(&mut data, 1, 0, 0.2);
    set(&mut data, 2, 0, 0.1);
    set(&mut data, 3, 0, 0.1);
    set(&mut data, 0, 1, 0.5);
    set(&mut data, 1, 1, 0.5);
    set(&mut data, 2, 1, 0.1);
    set(&mut data, 3, 1, 0.1);
    set(&mut data, 0, 2, 0.8);
    set(&mut data, 1, 2, 0.8);
    set(&mut data, 2, 2, 0.1);
    set(&mut data, 3, 2, 0.1);
    set(&mut data, 4, 0, 0.9);
    set(&mut data, 4, 2, 0.8);
    set(&mut data, 6, 0, 3.0);
    set(&mut data, 7, 0, 3.0);
    set(&mut data, 6, 2, -3.0);
    set(&mut data, 7, 2, -3.0);

    let output = Array2::from_shape_vec((feat, n), data).unwrap();
    let protos = Array3::<f32>::from_elem((8, 8, nm), 1.0);

    let mut boxes: Vec<DetectBox> = Vec::with_capacity(10);
    let mut masks: Vec<Segmentation> = Vec::with_capacity(10);
    decode_yolo_segdet_float(
        output.view(),
        protos.view(),
        0.5,
        0.5,
        Some(Nms::ClassAgnostic),
        &mut boxes,
        &mut masks,
    )
    .unwrap();

    assert_eq!(
        boxes.len(),
        2,
        "two anchors above threshold should survive (a0 score=0.9, a2 score=0.8); got {}",
        boxes.len()
    );
    // `zip` below stops at the shorter vector. Without this check, a missing or short
    // mask list would silently skip the pairing assertions.
    assert_eq!(
        masks.len(),
        boxes.len(),
        "each surviving detection needs exactly one mask; got {} masks for {} boxes",
        masks.len(),
        boxes.len()
    );

    let mut saw_anchor0 = false;
    let mut saw_anchor2 = false;
    for (b, m) in boxes.iter().zip(masks.iter()) {
        let cx = (b.bbox.xmin + b.bbox.xmax) * 0.5;
        let mean = {
            let s = &m.segmentation;
            let total: u32 = s.iter().map(|&v| v as u32).sum();
            total as f32 / s.len() as f32
        };
        if cx < 0.3 {
            assert!(
                mean > 200.0,
                "anchor 0 detection (centre {cx:.2}) should have high-value mask; got mean {mean}"
            );
            saw_anchor0 = true;
        } else if cx > 0.7 {
            assert!(
                mean < 50.0,
                "anchor 2 detection (centre {cx:.2}) should have low-value mask; got mean {mean}"
            );
            saw_anchor2 = true;
        } else {
            panic!("unexpected detection centre {cx:.2}");
        }
    }
    // Two survivors could otherwise both come from the same side and still pass.
    assert!(
        saw_anchor0 && saw_anchor2,
        "expected one detection from anchor 0 and one from anchor 2"
    );
}
2850}