1use std::fmt::Debug;
5
6use ndarray::{
7 parallel::prelude::{IntoParallelIterator, ParallelIterator},
8 s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
9};
10use num_traits::{AsPrimitive, Float, PrimInt, Signed};
11
12use crate::{
13 byte::{
14 nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
15 postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
16 },
17 configs::Nms,
18 dequant_detect_box,
19 float::{
20 nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
21 postprocess_boxes_float, postprocess_boxes_index_float,
22 },
23 BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoLayout,
24 Quantization, Segmentation, XYWH, XYXY,
25};
26
/// Candidate limit that the segmentation decoders in this file pass as
/// their `pre_nms_top_k` argument, i.e. the most score-filtered boxes that
/// are kept before non-maximum suppression runs.
pub const MAX_NMS_CANDIDATES: usize = 30_000;
42
43fn truncate_to_top_k_by_score<E: Send>(boxes: &mut Vec<(DetectBox, E)>, top_k: usize) {
47 if boxes.len() > top_k {
48 boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.total_cmp(&a.0.score));
49 boxes.truncate(top_k);
50 }
51}
52
53fn truncate_to_top_k_by_score_quant<S: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send>(
57 boxes: &mut Vec<(DetectBoxQuantized<S>, E)>,
58 top_k: usize,
59) {
60 if boxes.len() > top_k {
61 boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.cmp(&a.0.score));
62 boxes.truncate(top_k);
63 }
64}
65
66fn dispatch_nms_float(nms: Option<Nms>, iou: f32, boxes: Vec<DetectBox>) -> Vec<DetectBox> {
68 match nms {
69 Some(Nms::ClassAgnostic) => nms_float(iou, boxes),
70 Some(Nms::ClassAware) => nms_class_aware_float(iou, boxes),
71 None => boxes, }
73}
74
75pub(super) fn dispatch_nms_extra_float<E: Send + Sync>(
78 nms: Option<Nms>,
79 iou: f32,
80 boxes: Vec<(DetectBox, E)>,
81) -> Vec<(DetectBox, E)> {
82 match nms {
83 Some(Nms::ClassAgnostic) => nms_extra_float(iou, boxes),
84 Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, boxes),
85 None => boxes, }
87}
88
89fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
92 nms: Option<Nms>,
93 iou: f32,
94 boxes: Vec<DetectBoxQuantized<SCORE>>,
95) -> Vec<DetectBoxQuantized<SCORE>> {
96 match nms {
97 Some(Nms::ClassAgnostic) => nms_int(iou, boxes),
98 Some(Nms::ClassAware) => nms_class_aware_int(iou, boxes),
99 None => boxes, }
101}
102
103fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
106 nms: Option<Nms>,
107 iou: f32,
108 boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
109) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
110 match nms {
111 Some(Nms::ClassAgnostic) => nms_extra_int(iou, boxes),
112 Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, boxes),
113 None => boxes, }
115}
116
117pub fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
124 output: (ArrayView2<BOX>, Quantization),
125 score_threshold: f32,
126 iou_threshold: f32,
127 nms: Option<Nms>,
128 output_boxes: &mut Vec<DetectBox>,
129) where
130 f32: AsPrimitive<BOX>,
131{
132 impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
133}
134
135pub fn decode_yolo_det_float<T>(
142 output: ArrayView2<T>,
143 score_threshold: f32,
144 iou_threshold: f32,
145 nms: Option<Nms>,
146 output_boxes: &mut Vec<DetectBox>,
147) where
148 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
149 f32: AsPrimitive<T>,
150{
151 impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
152}
153
154pub fn decode_yolo_segdet_quant<
166 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
167 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
168>(
169 boxes: (ArrayView2<BOX>, Quantization),
170 protos: (ArrayView3<PROTO>, Quantization),
171 score_threshold: f32,
172 iou_threshold: f32,
173 nms: Option<Nms>,
174 output_boxes: &mut Vec<DetectBox>,
175 output_masks: &mut Vec<Segmentation>,
176) -> Result<(), crate::DecoderError>
177where
178 f32: AsPrimitive<BOX>,
179{
180 impl_yolo_segdet_quant::<XYWH, _, _>(
181 boxes,
182 protos,
183 score_threshold,
184 iou_threshold,
185 nms,
186 MAX_NMS_CANDIDATES,
187 output_boxes.capacity(),
188 output_boxes,
189 output_masks,
190 )
191}
192
193pub fn decode_yolo_segdet_float<T>(
205 boxes: ArrayView2<T>,
206 protos: ArrayView3<T>,
207 score_threshold: f32,
208 iou_threshold: f32,
209 nms: Option<Nms>,
210 output_boxes: &mut Vec<DetectBox>,
211 output_masks: &mut Vec<Segmentation>,
212) -> Result<(), crate::DecoderError>
213where
214 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
215 f32: AsPrimitive<T>,
216{
217 impl_yolo_segdet_float::<XYWH, _, _>(
218 boxes,
219 protos,
220 score_threshold,
221 iou_threshold,
222 nms,
223 MAX_NMS_CANDIDATES,
224 output_boxes.capacity(),
225 output_boxes,
226 output_masks,
227 )
228}
229
230pub fn decode_yolo_split_det_quant<
242 BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
243 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
244>(
245 boxes: (ArrayView2<BOX>, Quantization),
246 scores: (ArrayView2<SCORE>, Quantization),
247 score_threshold: f32,
248 iou_threshold: f32,
249 nms: Option<Nms>,
250 output_boxes: &mut Vec<DetectBox>,
251) where
252 f32: AsPrimitive<SCORE>,
253{
254 impl_yolo_split_quant::<XYWH, _, _>(
255 boxes,
256 scores,
257 score_threshold,
258 iou_threshold,
259 nms,
260 output_boxes,
261 );
262}
263
264pub fn decode_yolo_split_det_float<T>(
276 boxes: ArrayView2<T>,
277 scores: ArrayView2<T>,
278 score_threshold: f32,
279 iou_threshold: f32,
280 nms: Option<Nms>,
281 output_boxes: &mut Vec<DetectBox>,
282) where
283 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
284 f32: AsPrimitive<T>,
285{
286 impl_yolo_split_float::<XYWH, _, _>(
287 boxes,
288 scores,
289 score_threshold,
290 iou_threshold,
291 nms,
292 output_boxes,
293 );
294}
295
296#[allow(clippy::too_many_arguments)]
310pub fn decode_yolo_split_segdet<
311 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
312 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
313 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
314 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
315>(
316 boxes: (ArrayView2<BOX>, Quantization),
317 scores: (ArrayView2<SCORE>, Quantization),
318 mask_coeff: (ArrayView2<MASK>, Quantization),
319 protos: (ArrayView3<PROTO>, Quantization),
320 score_threshold: f32,
321 iou_threshold: f32,
322 nms: Option<Nms>,
323 output_boxes: &mut Vec<DetectBox>,
324 output_masks: &mut Vec<Segmentation>,
325) -> Result<(), crate::DecoderError>
326where
327 f32: AsPrimitive<SCORE>,
328{
329 impl_yolo_split_segdet_quant::<XYWH, _, _, _, _>(
330 boxes,
331 scores,
332 mask_coeff,
333 protos,
334 score_threshold,
335 iou_threshold,
336 nms,
337 output_boxes,
338 output_masks,
339 )
340}
341
342#[allow(clippy::too_many_arguments)]
356pub fn decode_yolo_split_segdet_float<T>(
357 boxes: ArrayView2<T>,
358 scores: ArrayView2<T>,
359 mask_coeff: ArrayView2<T>,
360 protos: ArrayView3<T>,
361 score_threshold: f32,
362 iou_threshold: f32,
363 nms: Option<Nms>,
364 output_boxes: &mut Vec<DetectBox>,
365 output_masks: &mut Vec<Segmentation>,
366) -> Result<(), crate::DecoderError>
367where
368 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
369 f32: AsPrimitive<T>,
370{
371 impl_yolo_split_segdet_float::<XYWH, _, _, _, _>(
372 boxes,
373 scores,
374 mask_coeff,
375 protos,
376 score_threshold,
377 iou_threshold,
378 nms,
379 output_boxes,
380 output_masks,
381 )
382}
383
384pub fn decode_yolo_end_to_end_det_float<T>(
399 output: ArrayView2<T>,
400 score_threshold: f32,
401 output_boxes: &mut Vec<DetectBox>,
402) -> Result<(), crate::DecoderError>
403where
404 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
405 f32: AsPrimitive<T>,
406{
407 if output.shape()[0] < 6 {
409 return Err(crate::DecoderError::InvalidShape(format!(
410 "End-to-end detection output requires at least 6 rows, got {}",
411 output.shape()[0]
412 )));
413 }
414
415 let boxes = output.slice(s![0..4, ..]).reversed_axes();
417 let scores = output.slice(s![4..5, ..]).reversed_axes();
418 let classes = output.slice(s![5, ..]);
419 let mut boxes =
420 postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
421 boxes.truncate(output_boxes.capacity());
422 output_boxes.clear();
423 for (mut b, i) in boxes.into_iter() {
424 b.label = classes[i].as_() as usize;
425 output_boxes.push(b);
426 }
427 Ok(())
429}
430
431pub fn decode_yolo_end_to_end_segdet_float<T>(
449 output: ArrayView2<T>,
450 protos: ArrayView3<T>,
451 score_threshold: f32,
452 output_boxes: &mut Vec<DetectBox>,
453 output_masks: &mut Vec<crate::Segmentation>,
454) -> Result<(), crate::DecoderError>
455where
456 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
457 f32: AsPrimitive<T>,
458{
459 let (boxes, scores, classes, mask_coeff) =
460 postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
461 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
462 boxes,
463 scores,
464 classes,
465 score_threshold,
466 output_boxes.capacity(),
467 );
468
469 impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
472}
473
474pub fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
483 boxes: ArrayView2<T>,
484 scores: ArrayView2<T>,
485 classes: ArrayView2<T>,
486 score_threshold: f32,
487 output_boxes: &mut Vec<DetectBox>,
488) -> Result<(), crate::DecoderError> {
489 let n = boxes.shape()[1];
490
491 output_boxes.clear();
492
493 let (boxes, scores, classes) = postprocess_yolo_split_end_to_end_det(boxes, scores, &classes)?;
494
495 for i in 0..n {
496 let score: f32 = scores[[i, 0]].as_();
497 if score < score_threshold {
498 continue;
499 }
500 if output_boxes.len() >= output_boxes.capacity() {
501 break;
502 }
503 output_boxes.push(DetectBox {
504 bbox: BoundingBox {
505 xmin: boxes[[i, 0]].as_(),
506 ymin: boxes[[i, 1]].as_(),
507 xmax: boxes[[i, 2]].as_(),
508 ymax: boxes[[i, 3]].as_(),
509 },
510 score,
511 label: classes[i].as_() as usize,
512 });
513 }
514 Ok(())
515}
516
517#[allow(clippy::too_many_arguments)]
526pub fn decode_yolo_split_end_to_end_segdet_float<T>(
527 boxes: ArrayView2<T>,
528 scores: ArrayView2<T>,
529 classes: ArrayView2<T>,
530 mask_coeff: ArrayView2<T>,
531 protos: ArrayView3<T>,
532 score_threshold: f32,
533 output_boxes: &mut Vec<DetectBox>,
534 output_masks: &mut Vec<crate::Segmentation>,
535) -> Result<(), crate::DecoderError>
536where
537 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
538 f32: AsPrimitive<T>,
539{
540 let (boxes, scores, classes, mask_coeff) =
541 postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
542 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
543 boxes,
544 scores,
545 classes,
546 score_threshold,
547 output_boxes.capacity(),
548 );
549
550 impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
551}
552
553#[allow(clippy::type_complexity)]
554pub(crate) fn postprocess_yolo_end_to_end_segdet<'a, T>(
555 output: &'a ArrayView2<'_, T>,
556 num_protos: usize,
557) -> Result<
558 (
559 ArrayView2<'a, T>,
560 ArrayView2<'a, T>,
561 ArrayView1<'a, T>,
562 ArrayView2<'a, T>,
563 ),
564 crate::DecoderError,
565> {
566 if output.shape()[0] < 7 {
568 return Err(crate::DecoderError::InvalidShape(format!(
569 "End-to-end segdet output requires at least 7 rows, got {}",
570 output.shape()[0]
571 )));
572 }
573
574 let num_mask_coeffs = output.shape()[0] - 6;
575 if num_mask_coeffs != num_protos {
576 return Err(crate::DecoderError::InvalidShape(format!(
577 "Mask coefficients count ({}) doesn't match protos count ({})",
578 num_mask_coeffs, num_protos
579 )));
580 }
581
582 let boxes = output.slice(s![0..4, ..]).reversed_axes();
584 let scores = output.slice(s![4..5, ..]).reversed_axes();
585 let classes = output.slice(s![5, ..]);
586 let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
587 Ok((boxes, scores, classes, mask_coeff))
588}
589
590#[allow(clippy::type_complexity)]
597pub(crate) fn postprocess_yolo_split_end_to_end_det<'a, 'b, 'c, BOXES, SCORES, CLASS>(
598 boxes: ArrayView2<'a, BOXES>,
599 scores: ArrayView2<'b, SCORES>,
600 classes: &'c ArrayView2<CLASS>,
601) -> Result<
602 (
603 ArrayView2<'a, BOXES>,
604 ArrayView2<'b, SCORES>,
605 ArrayView1<'c, CLASS>,
606 ),
607 crate::DecoderError,
608> {
609 let num_boxes = boxes.shape()[1];
610 if boxes.shape()[0] != 4 {
611 return Err(crate::DecoderError::InvalidShape(format!(
612 "Split end-to-end box_coords must be 4, got {}",
613 boxes.shape()[0]
614 )));
615 }
616
617 if scores.shape()[0] != 1 {
618 return Err(crate::DecoderError::InvalidShape(format!(
619 "Split end-to-end scores num_classes must be 1, got {}",
620 scores.shape()[0]
621 )));
622 }
623
624 if classes.shape()[0] != 1 {
625 return Err(crate::DecoderError::InvalidShape(format!(
626 "Split end-to-end classes num_classes must be 1, got {}",
627 classes.shape()[0]
628 )));
629 }
630
631 if scores.shape()[1] != num_boxes {
632 return Err(crate::DecoderError::InvalidShape(format!(
633 "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
634 num_boxes,
635 scores.shape()[1]
636 )));
637 }
638
639 if classes.shape()[1] != num_boxes {
640 return Err(crate::DecoderError::InvalidShape(format!(
641 "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
642 num_boxes,
643 classes.shape()[1]
644 )));
645 }
646
647 let boxes = boxes.reversed_axes();
648 let scores = scores.reversed_axes();
649 let classes = classes.slice(s![0, ..]);
650 Ok((boxes, scores, classes))
651}
652
653#[allow(clippy::type_complexity)]
656pub(crate) fn postprocess_yolo_split_end_to_end_segdet<
657 'a,
658 'b,
659 'c,
660 'd,
661 BOXES,
662 SCORES,
663 CLASS,
664 MASK,
665>(
666 boxes: ArrayView2<'a, BOXES>,
667 scores: ArrayView2<'b, SCORES>,
668 classes: &'c ArrayView2<CLASS>,
669 mask_coeff: ArrayView2<'d, MASK>,
670) -> Result<
671 (
672 ArrayView2<'a, BOXES>,
673 ArrayView2<'b, SCORES>,
674 ArrayView1<'c, CLASS>,
675 ArrayView2<'d, MASK>,
676 ),
677 crate::DecoderError,
678> {
679 let num_boxes = boxes.shape()[1];
680 if boxes.shape()[0] != 4 {
681 return Err(crate::DecoderError::InvalidShape(format!(
682 "Split end-to-end box_coords must be 4, got {}",
683 boxes.shape()[0]
684 )));
685 }
686
687 if scores.shape()[0] != 1 {
688 return Err(crate::DecoderError::InvalidShape(format!(
689 "Split end-to-end scores num_classes must be 1, got {}",
690 scores.shape()[0]
691 )));
692 }
693
694 if classes.shape()[0] != 1 {
695 return Err(crate::DecoderError::InvalidShape(format!(
696 "Split end-to-end classes num_classes must be 1, got {}",
697 classes.shape()[0]
698 )));
699 }
700
701 if scores.shape()[1] != num_boxes {
702 return Err(crate::DecoderError::InvalidShape(format!(
703 "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
704 num_boxes,
705 scores.shape()[1]
706 )));
707 }
708
709 if classes.shape()[1] != num_boxes {
710 return Err(crate::DecoderError::InvalidShape(format!(
711 "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
712 num_boxes,
713 classes.shape()[1]
714 )));
715 }
716
717 if mask_coeff.shape()[1] != num_boxes {
718 return Err(crate::DecoderError::InvalidShape(format!(
719 "Split end-to-end mask_coeff must have same num_boxes as boxes ({}), got {}",
720 num_boxes,
721 mask_coeff.shape()[1]
722 )));
723 }
724
725 let boxes = boxes.reversed_axes();
726 let scores = scores.reversed_axes();
727 let classes = classes.slice(s![0, ..]);
728 let mask_coeff = mask_coeff.reversed_axes();
729 Ok((boxes, scores, classes, mask_coeff))
730}
731pub(crate) fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
736 output: (ArrayView2<T>, Quantization),
737 score_threshold: f32,
738 iou_threshold: f32,
739 nms: Option<Nms>,
740 output_boxes: &mut Vec<DetectBox>,
741) where
742 f32: AsPrimitive<T>,
743{
744 let _span = tracing::trace_span!("decode", mode = "quant_det").entered();
745 let (boxes, quant_boxes) = output;
746 let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
747
748 let boxes = {
749 let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
750 postprocess_boxes_quant::<B, _, _>(
751 score_threshold,
752 boxes_tensor,
753 scores_tensor,
754 quant_boxes,
755 )
756 };
757
758 let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
759 let len = output_boxes.capacity().min(boxes.len());
760 output_boxes.clear();
761 for b in boxes.iter().take(len) {
762 output_boxes.push(dequant_detect_box(b, quant_boxes));
763 }
764}
765
766pub(crate) fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
771 output: ArrayView2<T>,
772 score_threshold: f32,
773 iou_threshold: f32,
774 nms: Option<Nms>,
775 output_boxes: &mut Vec<DetectBox>,
776) where
777 f32: AsPrimitive<T>,
778{
779 let _span = tracing::trace_span!("decode", mode = "float_det").entered();
780 let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
781 let boxes =
782 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
783 let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
784 let len = output_boxes.capacity().min(boxes.len());
785 output_boxes.clear();
786 for b in boxes.into_iter().take(len) {
787 output_boxes.push(b);
788 }
789}
790
791pub(crate) fn impl_yolo_split_quant<
801 B: BBoxTypeTrait,
802 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
803 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
804>(
805 boxes: (ArrayView2<BOX>, Quantization),
806 scores: (ArrayView2<SCORE>, Quantization),
807 score_threshold: f32,
808 iou_threshold: f32,
809 nms: Option<Nms>,
810 output_boxes: &mut Vec<DetectBox>,
811) where
812 f32: AsPrimitive<SCORE>,
813{
814 let _span = tracing::trace_span!("decode", mode = "split_quant_det").entered();
815 let (boxes_tensor, quant_boxes) = boxes;
816 let (scores_tensor, quant_scores) = scores;
817
818 let boxes_tensor = boxes_tensor.reversed_axes();
819 let scores_tensor = scores_tensor.reversed_axes();
820
821 let boxes = {
822 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
823 postprocess_boxes_quant::<B, _, _>(
824 score_threshold,
825 boxes_tensor,
826 scores_tensor,
827 quant_boxes,
828 )
829 };
830
831 let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
832 let len = output_boxes.capacity().min(boxes.len());
833 output_boxes.clear();
834 for b in boxes.iter().take(len) {
835 output_boxes.push(dequant_detect_box(b, quant_scores));
836 }
837}
838
839pub(crate) fn impl_yolo_split_float<
848 B: BBoxTypeTrait,
849 BOX: Float + AsPrimitive<f32> + Send + Sync,
850 SCORE: Float + AsPrimitive<f32> + Send + Sync,
851>(
852 boxes_tensor: ArrayView2<BOX>,
853 scores_tensor: ArrayView2<SCORE>,
854 score_threshold: f32,
855 iou_threshold: f32,
856 nms: Option<Nms>,
857 output_boxes: &mut Vec<DetectBox>,
858) where
859 f32: AsPrimitive<SCORE>,
860{
861 let _span = tracing::trace_span!("decode", mode = "split_float_det").entered();
862 let boxes_tensor = boxes_tensor.reversed_axes();
863 let scores_tensor = scores_tensor.reversed_axes();
864 let boxes =
865 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
866 let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
867 let len = output_boxes.capacity().min(boxes.len());
868 output_boxes.clear();
869 for b in boxes.into_iter().take(len) {
870 output_boxes.push(b);
871 }
872}
873
874#[allow(clippy::too_many_arguments)]
884pub(crate) fn impl_yolo_segdet_quant<
885 B: BBoxTypeTrait,
886 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
887 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
888>(
889 boxes: (ArrayView2<BOX>, Quantization),
890 protos: (ArrayView3<PROTO>, Quantization),
891 score_threshold: f32,
892 iou_threshold: f32,
893 nms: Option<Nms>,
894 pre_nms_top_k: usize,
895 max_det: usize,
896 output_boxes: &mut Vec<DetectBox>,
897 output_masks: &mut Vec<Segmentation>,
898) -> Result<(), crate::DecoderError>
899where
900 f32: AsPrimitive<BOX>,
901{
902 let (boxes, quant_boxes) = boxes;
903 let num_protos = protos.0.dim().2;
904
905 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
906 let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
907 (boxes_tensor, quant_boxes),
908 (scores_tensor, quant_boxes),
909 score_threshold,
910 iou_threshold,
911 nms,
912 pre_nms_top_k,
913 max_det,
914 );
915
916 impl_yolo_split_segdet_quant_process_masks::<_, _>(
917 boxes,
918 (mask_tensor, quant_boxes),
919 protos,
920 output_boxes,
921 output_masks,
922 )
923}
924
925#[allow(clippy::too_many_arguments)]
935pub(crate) fn impl_yolo_segdet_float<
936 B: BBoxTypeTrait,
937 BOX: Float + AsPrimitive<f32> + Send + Sync,
938 PROTO: Float + AsPrimitive<f32> + Send + Sync,
939>(
940 boxes: ArrayView2<BOX>,
941 protos: ArrayView3<PROTO>,
942 score_threshold: f32,
943 iou_threshold: f32,
944 nms: Option<Nms>,
945 pre_nms_top_k: usize,
946 max_det: usize,
947 output_boxes: &mut Vec<DetectBox>,
948 output_masks: &mut Vec<Segmentation>,
949) -> Result<(), crate::DecoderError>
950where
951 f32: AsPrimitive<BOX>,
952{
953 let num_protos = protos.dim().2;
954 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
955 let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
956 boxes_tensor,
957 scores_tensor,
958 score_threshold,
959 iou_threshold,
960 nms,
961 pre_nms_top_k,
962 max_det,
963 );
964 impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
965}
966
967pub(crate) fn impl_yolo_segdet_get_boxes<
968 B: BBoxTypeTrait,
969 BOX: Float + AsPrimitive<f32> + Send + Sync,
970 SCORE: Float + AsPrimitive<f32> + Send + Sync,
971>(
972 boxes_tensor: ArrayView2<BOX>,
973 scores_tensor: ArrayView2<SCORE>,
974 score_threshold: f32,
975 iou_threshold: f32,
976 nms: Option<Nms>,
977 pre_nms_top_k: usize,
978 max_det: usize,
979) -> Vec<(DetectBox, usize)>
980where
981 f32: AsPrimitive<SCORE>,
982{
983 let span = tracing::trace_span!(
984 "decode",
985 n_candidates = tracing::field::Empty,
986 n_after_topk = tracing::field::Empty,
987 n_after_nms = tracing::field::Empty,
988 n_detections = tracing::field::Empty,
989 );
990 let _guard = span.enter();
991
992 let mut boxes = {
993 let _s = tracing::trace_span!("score_filter").entered();
994 postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor)
995 };
996 span.record("n_candidates", boxes.len());
997
998 if nms.is_some() {
999 let _s = tracing::trace_span!("top_k", k = pre_nms_top_k).entered();
1000 truncate_to_top_k_by_score(&mut boxes, pre_nms_top_k);
1001 }
1002 span.record("n_after_topk", boxes.len());
1003
1004 let mut boxes = {
1005 let _s = tracing::trace_span!("nms").entered();
1006 dispatch_nms_extra_float(nms, iou_threshold, boxes)
1007 };
1008 span.record("n_after_nms", boxes.len());
1009
1010 boxes.sort_unstable_by(|a, b| b.0.score.total_cmp(&a.0.score));
1011 boxes.truncate(max_det);
1012 span.record("n_detections", boxes.len());
1013
1014 boxes
1015}
1016
1017pub(crate) fn impl_yolo_end_to_end_segdet_get_boxes<
1018 B: BBoxTypeTrait,
1019 BOX: Float + AsPrimitive<f32> + Send + Sync,
1020 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1021 CLASS: AsPrimitive<f32> + Send + Sync,
1022>(
1023 boxes: ArrayView2<BOX>,
1024 scores: ArrayView2<SCORE>,
1025 classes: ArrayView1<CLASS>,
1026 score_threshold: f32,
1027 max_boxes: usize,
1028) -> Vec<(DetectBox, usize)>
1029where
1030 f32: AsPrimitive<SCORE>,
1031{
1032 let mut boxes = postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes, scores);
1033 boxes.truncate(max_boxes);
1034 for (b, ind) in &mut boxes {
1035 b.label = classes[*ind].as_().round() as usize;
1036 }
1037 boxes
1038}
1039
1040pub(crate) fn impl_yolo_split_segdet_process_masks<
1041 MASK: Float + AsPrimitive<f32> + Send + Sync,
1042 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1043>(
1044 mut boxes: Vec<(DetectBox, usize)>,
1045 masks_tensor: ArrayView2<MASK>,
1046 protos_tensor: ArrayView3<PROTO>,
1047 output_boxes: &mut Vec<DetectBox>,
1048 output_masks: &mut Vec<Segmentation>,
1049) -> Result<(), crate::DecoderError> {
1050 let _span = tracing::trace_span!("process_masks", n = boxes.len(), mode = "float").entered();
1051 let cap = output_boxes.capacity();
1053 if cap > 0 && boxes.len() > cap {
1054 boxes.truncate(cap);
1055 }
1056
1057 let boxes = decode_segdet_f32(boxes, masks_tensor, protos_tensor)?;
1058 output_boxes.clear();
1059 output_masks.clear();
1060 for (b, m) in boxes.into_iter() {
1061 output_boxes.push(b);
1062 output_masks.push(Segmentation {
1063 xmin: b.bbox.xmin,
1064 ymin: b.bbox.ymin,
1065 xmax: b.bbox.xmax,
1066 ymax: b.bbox.ymax,
1067 segmentation: m,
1068 });
1069 }
1070 Ok(())
1071}
1072pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
1076 B: BBoxTypeTrait,
1077 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1078 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1079>(
1080 boxes: (ArrayView2<BOX>, Quantization),
1081 scores: (ArrayView2<SCORE>, Quantization),
1082 score_threshold: f32,
1083 iou_threshold: f32,
1084 nms: Option<Nms>,
1085 pre_nms_top_k: usize,
1086 max_det: usize,
1087) -> Vec<(DetectBox, usize)>
1088where
1089 f32: AsPrimitive<SCORE>,
1090{
1091 let (boxes_tensor, quant_boxes) = boxes;
1092 let (scores_tensor, quant_scores) = scores;
1093
1094 let span = tracing::trace_span!(
1095 "decode",
1096 n_candidates = tracing::field::Empty,
1097 n_after_topk = tracing::field::Empty,
1098 n_after_nms = tracing::field::Empty,
1099 n_detections = tracing::field::Empty,
1100 );
1101 let _guard = span.enter();
1102
1103 let mut boxes = {
1104 let _s = tracing::trace_span!("score_filter").entered();
1105 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
1106 postprocess_boxes_index_quant::<B, _, _>(
1107 score_threshold,
1108 boxes_tensor,
1109 scores_tensor,
1110 quant_boxes,
1111 )
1112 };
1113 span.record("n_candidates", boxes.len());
1114
1115 if nms.is_some() {
1116 let _s = tracing::trace_span!("top_k", k = pre_nms_top_k).entered();
1117 truncate_to_top_k_by_score_quant(&mut boxes, pre_nms_top_k);
1118 }
1119 span.record("n_after_topk", boxes.len());
1120
1121 let mut boxes = {
1122 let _s = tracing::trace_span!("nms").entered();
1123 dispatch_nms_extra_int(nms, iou_threshold, boxes)
1124 };
1125 span.record("n_after_nms", boxes.len());
1126
1127 boxes.sort_unstable_by(|a, b| b.0.score.cmp(&a.0.score));
1129 boxes.truncate(max_det);
1130 let result: Vec<_> = {
1131 let _s = tracing::trace_span!("box_dequant", n = boxes.len()).entered();
1132 boxes
1133 .into_iter()
1134 .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
1135 .collect()
1136 };
1137 span.record("n_detections", result.len());
1138
1139 result
1140}
1141
1142pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
1143 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1144 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1145>(
1146 mut boxes: Vec<(DetectBox, usize)>,
1147 mask_coeff: (ArrayView2<MASK>, Quantization),
1148 protos: (ArrayView3<PROTO>, Quantization),
1149 output_boxes: &mut Vec<DetectBox>,
1150 output_masks: &mut Vec<Segmentation>,
1151) -> Result<(), crate::DecoderError> {
1152 let _span = tracing::trace_span!("process_masks", n = boxes.len(), mode = "quant").entered();
1153 let (masks, quant_masks) = mask_coeff;
1154 let (protos, quant_protos) = protos;
1155
1156 let cap = output_boxes.capacity();
1158 if cap > 0 && boxes.len() > cap {
1159 boxes.truncate(cap);
1160 }
1161
1162 let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
1163 output_boxes.clear();
1164 output_masks.clear();
1165 for (b, m) in boxes.into_iter() {
1166 output_boxes.push(b);
1167 output_masks.push(Segmentation {
1168 xmin: b.bbox.xmin,
1169 ymin: b.bbox.ymin,
1170 xmax: b.bbox.xmax,
1171 ymax: b.bbox.ymax,
1172 segmentation: m,
1173 });
1174 }
1175 Ok(())
1176}
1177
1178#[allow(clippy::too_many_arguments)]
1179pub(crate) fn impl_yolo_split_segdet_quant<
1191 B: BBoxTypeTrait,
1192 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1193 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1194 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1195 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1196>(
1197 boxes: (ArrayView2<BOX>, Quantization),
1198 scores: (ArrayView2<SCORE>, Quantization),
1199 mask_coeff: (ArrayView2<MASK>, Quantization),
1200 protos: (ArrayView3<PROTO>, Quantization),
1201 score_threshold: f32,
1202 iou_threshold: f32,
1203 nms: Option<Nms>,
1204 output_boxes: &mut Vec<DetectBox>,
1205 output_masks: &mut Vec<Segmentation>,
1206) -> Result<(), crate::DecoderError>
1207where
1208 f32: AsPrimitive<SCORE>,
1209{
1210 let (boxes_, scores_, mask_coeff_) =
1211 postprocess_yolo_split_segdet(boxes.0, scores.0, mask_coeff.0);
1212 let boxes = (boxes_, boxes.1);
1213 let scores = (scores_, scores.1);
1214 let mask_coeff = (mask_coeff_, mask_coeff.1);
1215
1216 let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1217 boxes,
1218 scores,
1219 score_threshold,
1220 iou_threshold,
1221 nms,
1222 MAX_NMS_CANDIDATES,
1223 output_boxes.capacity(),
1224 );
1225
1226 impl_yolo_split_segdet_quant_process_masks(
1227 boxes,
1228 mask_coeff,
1229 protos,
1230 output_boxes,
1231 output_masks,
1232 )
1233}
1234
1235#[allow(clippy::too_many_arguments)]
1236pub(crate) fn impl_yolo_split_segdet_float<
1248 B: BBoxTypeTrait,
1249 BOX: Float + AsPrimitive<f32> + Send + Sync,
1250 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1251 MASK: Float + AsPrimitive<f32> + Send + Sync,
1252 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1253>(
1254 boxes_tensor: ArrayView2<BOX>,
1255 scores_tensor: ArrayView2<SCORE>,
1256 mask_tensor: ArrayView2<MASK>,
1257 protos: ArrayView3<PROTO>,
1258 score_threshold: f32,
1259 iou_threshold: f32,
1260 nms: Option<Nms>,
1261 output_boxes: &mut Vec<DetectBox>,
1262 output_masks: &mut Vec<Segmentation>,
1263) -> Result<(), crate::DecoderError>
1264where
1265 f32: AsPrimitive<SCORE>,
1266{
1267 let (boxes_tensor, scores_tensor, mask_tensor) =
1268 postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1269
1270 let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1271 boxes_tensor,
1272 scores_tensor,
1273 score_threshold,
1274 iou_threshold,
1275 nms,
1276 MAX_NMS_CANDIDATES,
1277 output_boxes.capacity(),
1278 );
1279 impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1280}
1281
1282#[allow(clippy::too_many_arguments)]
1289pub fn impl_yolo_segdet_quant_proto<
1290 B: BBoxTypeTrait,
1291 BOX: PrimInt
1292 + AsPrimitive<i64>
1293 + AsPrimitive<i128>
1294 + AsPrimitive<f32>
1295 + AsPrimitive<i8>
1296 + Send
1297 + Sync,
1298 PROTO: PrimInt
1299 + AsPrimitive<i64>
1300 + AsPrimitive<i128>
1301 + AsPrimitive<f32>
1302 + AsPrimitive<i8>
1303 + Send
1304 + Sync,
1305>(
1306 boxes: (ArrayView2<BOX>, Quantization),
1307 protos: (ArrayView3<PROTO>, Quantization),
1308 score_threshold: f32,
1309 iou_threshold: f32,
1310 nms: Option<Nms>,
1311 pre_nms_top_k: usize,
1312 max_det: usize,
1313 output_boxes: &mut Vec<DetectBox>,
1314) -> ProtoData
1315where
1316 f32: AsPrimitive<BOX>,
1317{
1318 let (boxes_arr, quant_boxes) = boxes;
1319 let (protos_arr, quant_protos) = protos;
1320 let num_protos = protos_arr.dim().2;
1321
1322 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1323
1324 let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1325 (boxes_tensor, quant_boxes),
1326 (scores_tensor, quant_boxes),
1327 score_threshold,
1328 iou_threshold,
1329 nms,
1330 pre_nms_top_k,
1331 max_det,
1332 );
1333
1334 extract_proto_data_quant(
1335 det_indices,
1336 mask_tensor,
1337 quant_boxes,
1338 protos_arr,
1339 quant_protos,
1340 output_boxes,
1341 )
1342}
1343
1344#[allow(clippy::too_many_arguments)]
1347pub(crate) fn impl_yolo_segdet_float_proto<
1348 B: BBoxTypeTrait,
1349 BOX: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1350 PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1351>(
1352 boxes: ArrayView2<BOX>,
1353 protos: ArrayView3<PROTO>,
1354 score_threshold: f32,
1355 iou_threshold: f32,
1356 nms: Option<Nms>,
1357 pre_nms_top_k: usize,
1358 max_det: usize,
1359 output_boxes: &mut Vec<DetectBox>,
1360) -> ProtoData
1361where
1362 f32: AsPrimitive<BOX>,
1363{
1364 let num_protos = protos.dim().2;
1365 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1366
1367 let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1368 boxes_tensor,
1369 scores_tensor,
1370 score_threshold,
1371 iou_threshold,
1372 nms,
1373 pre_nms_top_k,
1374 max_det,
1375 );
1376
1377 extract_proto_data_float(boxes, mask_tensor, protos, output_boxes)
1378}
1379
1380#[allow(clippy::too_many_arguments)]
1383pub(crate) fn impl_yolo_split_segdet_float_proto<
1384 B: BBoxTypeTrait,
1385 BOX: Float + AsPrimitive<f32> + Send + Sync,
1386 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1387 MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1388 PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1389>(
1390 boxes_tensor: ArrayView2<BOX>,
1391 scores_tensor: ArrayView2<SCORE>,
1392 mask_tensor: ArrayView2<MASK>,
1393 protos: ArrayView3<PROTO>,
1394 score_threshold: f32,
1395 iou_threshold: f32,
1396 nms: Option<Nms>,
1397 pre_nms_top_k: usize,
1398 max_det: usize,
1399 output_boxes: &mut Vec<DetectBox>,
1400) -> ProtoData
1401where
1402 f32: AsPrimitive<SCORE>,
1403{
1404 let (boxes_tensor, scores_tensor, mask_tensor) =
1405 postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1406 let det_indices = impl_yolo_segdet_get_boxes::<B, _, _>(
1407 boxes_tensor,
1408 scores_tensor,
1409 score_threshold,
1410 iou_threshold,
1411 nms,
1412 pre_nms_top_k,
1413 max_det,
1414 );
1415
1416 extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1417}
1418
1419pub fn decode_yolo_end_to_end_segdet_float_proto<T>(
1421 output: ArrayView2<T>,
1422 protos: ArrayView3<T>,
1423 score_threshold: f32,
1424 output_boxes: &mut Vec<DetectBox>,
1425) -> Result<ProtoData, crate::DecoderError>
1426where
1427 T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1428 f32: AsPrimitive<T>,
1429{
1430 let (boxes, scores, classes, mask_coeff) =
1431 postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
1432 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1433 boxes,
1434 scores,
1435 classes,
1436 score_threshold,
1437 output_boxes.capacity(),
1438 );
1439
1440 Ok(extract_proto_data_float(
1441 boxes,
1442 mask_coeff,
1443 protos,
1444 output_boxes,
1445 ))
1446}
1447
1448#[allow(clippy::too_many_arguments)]
1450pub fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1451 boxes: ArrayView2<T>,
1452 scores: ArrayView2<T>,
1453 classes: ArrayView2<T>,
1454 mask_coeff: ArrayView2<T>,
1455 protos: ArrayView3<T>,
1456 score_threshold: f32,
1457 output_boxes: &mut Vec<DetectBox>,
1458) -> Result<ProtoData, crate::DecoderError>
1459where
1460 T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1461 f32: AsPrimitive<T>,
1462{
1463 let (boxes, scores, classes, mask_coeff) =
1464 postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
1465 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1466 boxes,
1467 scores,
1468 classes,
1469 score_threshold,
1470 output_boxes.capacity(),
1471 );
1472
1473 Ok(extract_proto_data_float(
1474 boxes,
1475 mask_coeff,
1476 protos,
1477 output_boxes,
1478 ))
1479}
1480
1481pub(super) fn extract_proto_data_float<
1488 MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1489 PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1490>(
1491 det_indices: Vec<(DetectBox, usize)>,
1492 mask_tensor: ArrayView2<MASK>,
1493 protos: ArrayView3<PROTO>,
1494 output_boxes: &mut Vec<DetectBox>,
1495) -> ProtoData {
1496 let _span = tracing::trace_span!(
1497 "extract_proto",
1498 n = det_indices.len(),
1499 num_protos = mask_tensor.ncols(),
1500 layout = "nhwc",
1501 )
1502 .entered();
1503
1504 let num_protos = mask_tensor.ncols();
1505 let n = det_indices.len();
1506
1507 let mut coeff_rows: Vec<MASK> = Vec::with_capacity(n * num_protos);
1512 output_boxes.clear();
1513 for (det, idx) in det_indices {
1514 output_boxes.push(det);
1515 let row = mask_tensor.row(idx);
1516 coeff_rows.extend(row.iter().copied());
1517 }
1518
1519 let mask_coefficients = MASK::slice_into_tensor_dyn(&coeff_rows, &[n, num_protos])
1520 .expect("allocating mask_coefficients TensorDyn");
1521 let protos_tensor =
1522 PROTO::arrayview3_into_tensor_dyn(protos).expect("allocating protos TensorDyn");
1523
1524 ProtoData {
1525 mask_coefficients,
1526 protos: protos_tensor,
1527 layout: ProtoLayout::Nhwc,
1528 }
1529}
1530
1531pub(crate) fn extract_proto_data_quant<
1540 MASK: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync,
1541 PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
1542>(
1543 det_indices: Vec<(DetectBox, usize)>,
1544 mask_tensor: ArrayView2<MASK>,
1545 quant_masks: Quantization,
1546 protos: ArrayView3<PROTO>,
1547 quant_protos: Quantization,
1548 output_boxes: &mut Vec<DetectBox>,
1549) -> ProtoData {
1550 use edgefirst_tensor::{Tensor, TensorDyn, TensorMapTrait, TensorMemory, TensorTrait};
1551
1552 let span = tracing::trace_span!(
1553 "extract_proto",
1554 n = det_indices.len(),
1555 num_protos = tracing::field::Empty,
1556 layout = tracing::field::Empty,
1557 );
1558 let _guard = span.enter();
1559
1560 let num_protos = mask_tensor.ncols();
1561 let n = det_indices.len();
1562 span.record("num_protos", num_protos);
1563
1564 if n == 0 {
1570 output_boxes.clear();
1571 let (h, w, k) = protos.dim();
1572
1573 let (proto_shape, proto_layout) = if std::any::TypeId::of::<PROTO>()
1575 == std::any::TypeId::of::<i8>()
1576 {
1577 if protos.is_standard_layout() {
1578 (&[h, w, k][..], ProtoLayout::Nhwc)
1579 } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1580 (&[k, h, w][..], ProtoLayout::Nchw)
1581 } else {
1582 (&[h, w, k][..], ProtoLayout::Nhwc)
1583 }
1584 } else {
1585 (&[h, w, k][..], ProtoLayout::Nhwc)
1586 };
1587
1588 let coeff_tensor = Tensor::<i8>::new(&[0, num_protos], Some(TensorMemory::Mem), None)
1589 .expect("allocating empty mask_coefficients tensor");
1590 let coeff_quant =
1591 edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1592 let coeff_tensor = coeff_tensor
1593 .with_quantization(coeff_quant)
1594 .expect("per-tensor quantization on mask coefficients");
1595 let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
1596 .expect("allocating protos tensor");
1597 let tensor_quant =
1598 edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1599 let protos_tensor = protos_tensor
1600 .with_quantization(tensor_quant)
1601 .expect("per-tensor quantization on protos tensor");
1602 return ProtoData {
1603 mask_coefficients: TensorDyn::I8(coeff_tensor),
1604 protos: TensorDyn::I8(protos_tensor),
1605 layout: proto_layout,
1606 };
1607 }
1608
1609 let mut coeff_i8 = Vec::<i8>::with_capacity(n * num_protos);
1613 output_boxes.clear();
1614 for (det, idx) in det_indices {
1615 output_boxes.push(det);
1616 let row = mask_tensor.row(idx);
1617 coeff_i8.extend(row.iter().map(|v| {
1618 let v_i8: i8 = v.as_();
1619 v_i8
1620 }));
1621 }
1622
1623 let coeff_tensor = Tensor::<i8>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
1626 .expect("allocating mask_coefficients tensor");
1627 if n > 0 {
1628 let mut m = coeff_tensor
1629 .map()
1630 .expect("mapping mask_coefficients tensor");
1631 m.as_mut_slice().copy_from_slice(&coeff_i8);
1632 }
1633 let coeff_quant =
1634 edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1635 let coeff_tensor = coeff_tensor
1636 .with_quantization(coeff_quant)
1637 .expect("per-tensor quantization on mask coefficients");
1638 let mask_coefficients = TensorDyn::I8(coeff_tensor);
1639
1640 let (h, w, k) = protos.dim();
1644
1645 let (proto_shape, proto_layout) =
1647 if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1648 if protos.is_standard_layout() {
1649 (&[h, w, k][..], ProtoLayout::Nhwc)
1651 } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1652 (&[k, h, w][..], ProtoLayout::Nchw)
1656 } else {
1657 (&[h, w, k][..], ProtoLayout::Nhwc)
1659 }
1660 } else {
1661 (&[h, w, k][..], ProtoLayout::Nhwc)
1662 };
1663
1664 let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
1665 .expect("allocating protos tensor");
1666 {
1667 let mut m = protos_tensor.map().expect("mapping protos tensor");
1668 let dst = m.as_mut_slice();
1669 if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1670 if protos.is_standard_layout() {
1673 let src: &[i8] = unsafe {
1674 std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len())
1675 };
1676 dst.copy_from_slice(src);
1677 } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1678 let total = h * w * k;
1682 let src: &[i8] =
1685 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, total) };
1686 dst.copy_from_slice(src);
1687 } else {
1688 for (d, s) in dst.iter_mut().zip(protos.iter()) {
1689 let v_i8: i8 = s.as_();
1690 *d = v_i8;
1691 }
1692 }
1693 } else {
1694 for (d, s) in dst.iter_mut().zip(protos.iter()) {
1695 let v_i8: i8 = s.as_();
1696 *d = v_i8;
1697 }
1698 }
1699 }
1700 let tensor_quant =
1701 edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1702 let protos_tensor = protos_tensor
1703 .with_quantization(tensor_quant)
1704 .expect("per-tensor quantization on new Tensor<i8>");
1705
1706 span.record("layout", tracing::field::debug(&proto_layout));
1707
1708 ProtoData {
1709 mask_coefficients,
1710 protos: TensorDyn::I8(protos_tensor),
1711 layout: proto_layout,
1712 }
1713}
1714
/// Element types that can be packaged into an `edgefirst_tensor::TensorDyn`
/// for the float prototype-data path.
///
/// Used as a bound by [`extract_proto_data_float`] so that mask coefficients
/// and prototypes can be turned into dynamic tensors without the caller
/// knowing the concrete element type.
pub trait FloatProtoElem: Copy + 'static {
    /// Builds a dynamic tensor with the given `shape` from a flat slice of
    /// values.
    fn slice_into_tensor_dyn(
        values: &[Self],
        shape: &[usize],
    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;

    /// Builds a dynamic tensor from a three-dimensional array view.
    fn arrayview3_into_tensor_dyn(
        view: ArrayView3<'_, Self>,
    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;
}
1730
1731impl FloatProtoElem for f32 {
1732 fn slice_into_tensor_dyn(
1733 values: &[f32],
1734 shape: &[usize],
1735 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1736 edgefirst_tensor::Tensor::<f32>::from_slice(values, shape)
1737 .map(edgefirst_tensor::TensorDyn::F32)
1738 }
1739 fn arrayview3_into_tensor_dyn(
1740 view: ArrayView3<'_, f32>,
1741 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1742 edgefirst_tensor::Tensor::<f32>::from_arrayview3(view).map(edgefirst_tensor::TensorDyn::F32)
1743 }
1744}
1745
1746impl FloatProtoElem for half::f16 {
1747 fn slice_into_tensor_dyn(
1748 values: &[half::f16],
1749 shape: &[usize],
1750 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1751 edgefirst_tensor::Tensor::<half::f16>::from_slice(values, shape)
1752 .map(edgefirst_tensor::TensorDyn::F16)
1753 }
1754 fn arrayview3_into_tensor_dyn(
1755 view: ArrayView3<'_, half::f16>,
1756 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1757 edgefirst_tensor::Tensor::<half::f16>::from_arrayview3(view)
1758 .map(edgefirst_tensor::TensorDyn::F16)
1759 }
1760}
1761
1762impl FloatProtoElem for f64 {
1763 fn slice_into_tensor_dyn(
1764 values: &[f64],
1765 shape: &[usize],
1766 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1767 let narrowed: Vec<f32> = values.iter().map(|&v| v as f32).collect();
1769 edgefirst_tensor::Tensor::<f32>::from_slice(&narrowed, shape)
1770 .map(edgefirst_tensor::TensorDyn::F32)
1771 }
1772 fn arrayview3_into_tensor_dyn(
1773 view: ArrayView3<'_, f64>,
1774 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1775 let narrowed: ndarray::Array3<f32> = view.mapv(|v| v as f32);
1776 edgefirst_tensor::Tensor::<f32>::from_arrayview3(narrowed.view())
1777 .map(edgefirst_tensor::TensorDyn::F32)
1778 }
1779}
1780
1781fn postprocess_yolo<'a, T>(
1782 output: &'a ArrayView2<'_, T>,
1783) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1784 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1785 let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1786 (boxes_tensor, scores_tensor)
1787}
1788
1789pub(crate) fn postprocess_yolo_seg<'a, T>(
1790 output: &'a ArrayView2<'_, T>,
1791 num_protos: usize,
1792) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1793 assert!(
1794 output.shape()[0] > num_protos + 4,
1795 "Output shape is too short: {} <= {} + 4",
1796 output.shape()[0],
1797 num_protos
1798 );
1799 let num_classes = output.shape()[0] - 4 - num_protos;
1800 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1801 let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1802 let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1803 (boxes_tensor, scores_tensor, mask_tensor)
1804}
1805
1806pub(crate) fn postprocess_yolo_split_segdet<'a, 'b, 'c, BOX, SCORE, MASK>(
1807 boxes_tensor: ArrayView2<'a, BOX>,
1808 scores_tensor: ArrayView2<'b, SCORE>,
1809 mask_tensor: ArrayView2<'c, MASK>,
1810) -> (
1811 ArrayView2<'a, BOX>,
1812 ArrayView2<'b, SCORE>,
1813 ArrayView2<'c, MASK>,
1814) {
1815 let boxes_tensor = boxes_tensor.reversed_axes();
1816 let scores_tensor = scores_tensor.reversed_axes();
1817 let mask_tensor = mask_tensor.reversed_axes();
1818 (boxes_tensor, scores_tensor, mask_tensor)
1819}
1820
1821fn decode_segdet_f32<
1822 MASK: Float + AsPrimitive<f32> + Send + Sync,
1823 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1824>(
1825 boxes: Vec<(DetectBox, usize)>,
1826 masks: ArrayView2<MASK>,
1827 protos: ArrayView3<PROTO>,
1828) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1829 if boxes.is_empty() {
1830 return Ok(Vec::new());
1831 }
1832 if masks.shape()[1] != protos.shape()[2] {
1833 return Err(crate::DecoderError::InvalidShape(format!(
1834 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1835 masks.shape()[1],
1836 protos.shape()[2],
1837 )));
1838 }
1839 boxes
1840 .into_par_iter()
1841 .map(|mut b| {
1842 let ind = b.1;
1843 let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1844 b.0.bbox = roi;
1845 Ok((b.0, make_segmentation(masks.row(ind), protos.view())))
1846 })
1847 .collect()
1848}
1849
1850pub(crate) fn decode_segdet_quant<
1851 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1852 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1853>(
1854 boxes: Vec<(DetectBox, usize)>,
1855 masks: ArrayView2<MASK>,
1856 protos: ArrayView3<PROTO>,
1857 quant_masks: Quantization,
1858 quant_protos: Quantization,
1859) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1860 if boxes.is_empty() {
1861 return Ok(Vec::new());
1862 }
1863 if masks.shape()[1] != protos.shape()[2] {
1864 return Err(crate::DecoderError::InvalidShape(format!(
1865 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1866 masks.shape()[1],
1867 protos.shape()[2],
1868 )));
1869 }
1870
1871 let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; boxes
1873 .into_iter()
1874 .map(|mut b| {
1875 let i = b.1;
1876 let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1877 b.0.bbox = roi;
1878 let seg = match total_bits {
1879 0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1880 masks.row(i),
1881 protos.view(),
1882 quant_masks,
1883 quant_protos,
1884 ),
1885 65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1886 masks.row(i),
1887 protos.view(),
1888 quant_masks,
1889 quant_protos,
1890 ),
1891 _ => {
1892 return Err(crate::DecoderError::NotSupported(format!(
1893 "Unsupported bit width ({total_bits}) for segmentation computation"
1894 )));
1895 }
1896 };
1897 Ok((b.0, seg))
1898 })
1899 .collect()
1900}
1901
1902fn protobox<'a, T>(
1903 protos: &'a ArrayView3<T>,
1904 roi: &BoundingBox,
1905) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1906 let width = protos.dim().1 as f32;
1907 let height = protos.dim().0 as f32;
1908
1909 const NORM_LIMIT: f32 = 2.0;
1920 if roi.xmin > NORM_LIMIT
1921 || roi.ymin > NORM_LIMIT
1922 || roi.xmax > NORM_LIMIT
1923 || roi.ymax > NORM_LIMIT
1924 {
1925 return Err(crate::DecoderError::InvalidShape(format!(
1926 "Bounding box coordinates appear un-normalized (pixel-space). \
1927 Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
1928 ONNX models output pixel-space boxes — normalize them by dividing by \
1929 the input dimensions before calling decode().",
1930 roi.xmin, roi.ymin, roi.xmax, roi.ymax,
1931 )));
1932 }
1933
1934 let roi = [
1935 (roi.xmin * width).clamp(0.0, width) as usize,
1936 (roi.ymin * height).clamp(0.0, height) as usize,
1937 (roi.xmax * width).clamp(0.0, width).ceil() as usize,
1938 (roi.ymax * height).clamp(0.0, height).ceil() as usize,
1939 ];
1940
1941 let roi_norm = [
1942 roi[0] as f32 / width,
1943 roi[1] as f32 / height,
1944 roi[2] as f32 / width,
1945 roi[3] as f32 / height,
1946 ]
1947 .into();
1948
1949 let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
1950
1951 Ok((cropped, roi_norm))
1952}
1953
1954fn make_segmentation<
1960 MASK: Float + AsPrimitive<f32> + Send + Sync,
1961 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1962>(
1963 mask: ArrayView1<MASK>,
1964 protos: ArrayView3<PROTO>,
1965) -> Array3<u8> {
1966 let shape = protos.shape();
1967
1968 let mask = mask.to_shape((1, mask.len())).unwrap();
1970 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1971 let protos = protos.reversed_axes();
1972 let mask = mask.map(|x| x.as_());
1973 let protos = protos.map(|x| x.as_());
1974
1975 let mask = mask
1977 .dot(&protos)
1978 .into_shape_with_order((shape[0], shape[1], 1))
1979 .unwrap();
1980
1981 mask.map(|x| {
1982 let sigmoid = 1.0 / (1.0 + (-*x).exp());
1983 (sigmoid * 255.0).round() as u8
1984 })
1985}
1986
1987fn make_segmentation_quant<
1994 MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
1995 PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
1996 DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
1997>(
1998 mask: ArrayView1<MASK>,
1999 protos: ArrayView3<PROTO>,
2000 quant_masks: Quantization,
2001 quant_protos: Quantization,
2002) -> Array3<u8>
2003where
2004 i32: AsPrimitive<DEST>,
2005 f32: AsPrimitive<DEST>,
2006{
2007 let shape = protos.shape();
2008
2009 let mask = mask.to_shape((1, mask.len())).unwrap();
2011
2012 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
2013 let protos = protos.reversed_axes();
2014
2015 let zp = quant_masks.zero_point.as_();
2016
2017 let mask = mask.mapv(|x| x.as_() - zp);
2018
2019 let zp = quant_protos.zero_point.as_();
2020 let protos = protos.mapv(|x| x.as_() - zp);
2021
2022 let segmentation = mask
2024 .dot(&protos)
2025 .into_shape_with_order((shape[0], shape[1], 1))
2026 .unwrap();
2027
2028 let combined_scale = quant_masks.scale * quant_protos.scale;
2029 segmentation.map(|x| {
2030 let val: f32 = (*x).as_() * combined_scale;
2031 let sigmoid = 1.0 / (1.0 + (-val).exp());
2032 (sigmoid * 255.0).round() as u8
2033 })
2034}
2035
2036pub fn yolo_segmentation_to_mask(
2048 segmentation: ArrayView3<u8>,
2049 threshold: u8,
2050) -> Result<Array2<u8>, crate::DecoderError> {
2051 if segmentation.shape()[2] != 1 {
2052 return Err(crate::DecoderError::InvalidShape(format!(
2053 "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
2054 segmentation.shape()[2]
2055 )));
2056 }
2057 Ok(segmentation
2058 .slice(s![.., .., 0])
2059 .map(|x| if *x >= threshold { 1 } else { 0 }))
2060}
2061
2062#[cfg(test)]
2063#[cfg_attr(coverage_nightly, coverage(off))]
2064mod tests {
2065 use super::*;
2066 use ndarray::Array2;
2067
2068 #[test]
2073 fn test_end_to_end_det_basic_filtering() {
2074 let data: Vec<f32> = vec![
2078 0.1, 0.2, 0.3, 0.1, 0.2, 0.3, 0.5, 0.6, 0.7, 0.5, 0.6, 0.7, 0.9, 0.1, 0.2, 0.0, 1.0, 2.0, ];
2086 let output = Array2::from_shape_vec((6, 3), data).unwrap();
2087
2088 let mut boxes = Vec::with_capacity(10);
2089 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2090
2091 assert_eq!(boxes.len(), 1);
2093 assert_eq!(boxes[0].label, 0);
2094 assert!((boxes[0].score - 0.9).abs() < 0.01);
2095 assert!((boxes[0].bbox.xmin - 0.1).abs() < 0.01);
2096 assert!((boxes[0].bbox.ymin - 0.1).abs() < 0.01);
2097 assert!((boxes[0].bbox.xmax - 0.5).abs() < 0.01);
2098 assert!((boxes[0].bbox.ymax - 0.5).abs() < 0.01);
2099 }
2100
2101 #[test]
2102 fn test_end_to_end_det_all_pass_threshold() {
2103 let data: Vec<f32> = vec![
2105 10.0, 20.0, 10.0, 20.0, 50.0, 60.0, 50.0, 60.0, 0.8, 0.7, 1.0, 2.0, ];
2112 let output = Array2::from_shape_vec((6, 2), data).unwrap();
2113
2114 let mut boxes = Vec::with_capacity(10);
2115 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2116
2117 assert_eq!(boxes.len(), 2);
2118 assert_eq!(boxes[0].label, 1);
2119 assert_eq!(boxes[1].label, 2);
2120 }
2121
2122 #[test]
2123 fn test_end_to_end_det_none_pass_threshold() {
2124 let data: Vec<f32> = vec![
2126 10.0, 20.0, 10.0, 20.0, 50.0, 60.0, 50.0, 60.0, 0.1, 0.2, 1.0, 2.0, ];
2133 let output = Array2::from_shape_vec((6, 2), data).unwrap();
2134
2135 let mut boxes = Vec::with_capacity(10);
2136 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2137
2138 assert_eq!(boxes.len(), 0);
2139 }
2140
2141 #[test]
2142 fn test_end_to_end_det_capacity_limit() {
2143 let data: Vec<f32> = vec![
2145 0.1, 0.2, 0.3, 0.4, 0.5, 0.1, 0.2, 0.3, 0.4, 0.5, 0.5, 0.6, 0.7, 0.8, 0.9, 0.5, 0.6, 0.7, 0.8, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.0, 1.0, 2.0, 3.0, 4.0, ];
2152 let output = Array2::from_shape_vec((6, 5), data).unwrap();
2153
2154 let mut boxes = Vec::with_capacity(2); decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2156
2157 assert_eq!(boxes.len(), 2);
2158 }
2159
2160 #[test]
2161 fn test_end_to_end_det_empty_output() {
2162 let output = Array2::<f32>::zeros((6, 0));
2164
2165 let mut boxes = Vec::with_capacity(10);
2166 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2167
2168 assert_eq!(boxes.len(), 0);
2169 }
2170
2171 #[test]
2172 fn test_end_to_end_det_pixel_coordinates() {
2173 let data: Vec<f32> = vec![
2175 100.0, 200.0, 300.0, 400.0, 0.95, 5.0, ];
2182 let output = Array2::from_shape_vec((6, 1), data).unwrap();
2183
2184 let mut boxes = Vec::with_capacity(10);
2185 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2186
2187 assert_eq!(boxes.len(), 1);
2188 assert_eq!(boxes[0].label, 5);
2189 assert!((boxes[0].bbox.xmin - 100.0).abs() < 0.01);
2190 assert!((boxes[0].bbox.ymin - 200.0).abs() < 0.01);
2191 assert!((boxes[0].bbox.xmax - 300.0).abs() < 0.01);
2192 assert!((boxes[0].bbox.ymax - 400.0).abs() < 0.01);
2193 }
2194
2195 #[test]
2196 fn test_end_to_end_det_invalid_shape() {
2197 let output = Array2::<f32>::zeros((5, 3));
2199
2200 let mut boxes = Vec::with_capacity(10);
2201 let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);
2202
2203 assert!(result.is_err());
2204 assert!(matches!(
2205 result,
2206 Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
2207 ));
2208 }
2209
2210 #[test]
2215 fn test_end_to_end_segdet_basic() {
2216 let num_protos = 32;
2219 let num_detections = 2;
2220 let num_features = 6 + num_protos;
2221
2222 let mut data = vec![0.0f32; num_features * num_detections];
2224 data[0] = 0.1; data[1] = 0.5; data[num_detections] = 0.1; data[num_detections + 1] = 0.5; data[2 * num_detections] = 0.4; data[2 * num_detections + 1] = 0.9; data[3 * num_detections] = 0.4; data[3 * num_detections + 1] = 0.9; data[4 * num_detections] = 0.9; data[4 * num_detections + 1] = 0.3; data[5 * num_detections] = 1.0; data[5 * num_detections + 1] = 2.0; for i in 6..num_features {
2239 data[i * num_detections] = 0.1;
2240 data[i * num_detections + 1] = 0.1;
2241 }
2242
2243 let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2244
2245 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2247
2248 let mut boxes = Vec::with_capacity(10);
2249 let mut masks = Vec::with_capacity(10);
2250 decode_yolo_end_to_end_segdet_float(
2251 output.view(),
2252 protos.view(),
2253 0.5,
2254 &mut boxes,
2255 &mut masks,
2256 )
2257 .unwrap();
2258
2259 assert_eq!(boxes.len(), 1);
2261 assert_eq!(masks.len(), 1);
2262 assert_eq!(boxes[0].label, 1);
2263 assert!((boxes[0].score - 0.9).abs() < 0.01);
2264 }
2265
2266 #[test]
2267 fn test_end_to_end_segdet_mask_coordinates() {
2268 let num_protos = 32;
2270 let num_features = 6 + num_protos;
2271
2272 let mut data = vec![0.0f32; num_features];
2273 data[0] = 0.2; data[1] = 0.2; data[2] = 0.8; data[3] = 0.8; data[4] = 0.95; data[5] = 3.0; let output = Array2::from_shape_vec((num_features, 1), data).unwrap();
2281 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2282
2283 let mut boxes = Vec::with_capacity(10);
2284 let mut masks = Vec::with_capacity(10);
2285 decode_yolo_end_to_end_segdet_float(
2286 output.view(),
2287 protos.view(),
2288 0.5,
2289 &mut boxes,
2290 &mut masks,
2291 )
2292 .unwrap();
2293
2294 assert_eq!(boxes.len(), 1);
2295 assert_eq!(masks.len(), 1);
2296
2297 assert!((masks[0].xmin - boxes[0].bbox.xmin).abs() < 0.01);
2299 assert!((masks[0].ymin - boxes[0].bbox.ymin).abs() < 0.01);
2300 assert!((masks[0].xmax - boxes[0].bbox.xmax).abs() < 0.01);
2301 assert!((masks[0].ymax - boxes[0].bbox.ymax).abs() < 0.01);
2302 }
2303
2304 #[test]
2305 fn test_end_to_end_segdet_empty_output() {
2306 let num_protos = 32;
2307 let output = Array2::<f32>::zeros((6 + num_protos, 0));
2308 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2309
2310 let mut boxes = Vec::with_capacity(10);
2311 let mut masks = Vec::with_capacity(10);
2312 decode_yolo_end_to_end_segdet_float(
2313 output.view(),
2314 protos.view(),
2315 0.5,
2316 &mut boxes,
2317 &mut masks,
2318 )
2319 .unwrap();
2320
2321 assert_eq!(boxes.len(), 0);
2322 assert_eq!(masks.len(), 0);
2323 }
2324
2325 #[test]
2326 fn test_end_to_end_segdet_capacity_limit() {
2327 let num_protos = 32;
2328 let num_detections = 5;
2329 let num_features = 6 + num_protos;
2330
2331 let mut data = vec![0.0f32; num_features * num_detections];
2332 for i in 0..num_detections {
2334 data[i] = 0.1 * (i as f32); data[num_detections + i] = 0.1 * (i as f32); data[2 * num_detections + i] = 0.1 * (i as f32) + 0.2; data[3 * num_detections + i] = 0.1 * (i as f32) + 0.2; data[4 * num_detections + i] = 0.9; data[5 * num_detections + i] = i as f32; }
2341
2342 let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2343 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2344
2345 let mut boxes = Vec::with_capacity(2); let mut masks = Vec::with_capacity(2);
2347 decode_yolo_end_to_end_segdet_float(
2348 output.view(),
2349 protos.view(),
2350 0.5,
2351 &mut boxes,
2352 &mut masks,
2353 )
2354 .unwrap();
2355
2356 assert_eq!(boxes.len(), 2);
2357 assert_eq!(masks.len(), 2);
2358 }
2359
2360 #[test]
2361 fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
2362 let output = Array2::<f32>::zeros((6, 3));
2364 let protos = Array3::<f32>::zeros((16, 16, 32));
2365
2366 let mut boxes = Vec::with_capacity(10);
2367 let mut masks = Vec::with_capacity(10);
2368 let result = decode_yolo_end_to_end_segdet_float(
2369 output.view(),
2370 protos.view(),
2371 0.5,
2372 &mut boxes,
2373 &mut masks,
2374 );
2375
2376 assert!(result.is_err());
2377 assert!(matches!(
2378 result,
2379 Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
2380 ));
2381 }
2382
2383 #[test]
2384 fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
2385 let num_protos = 32;
2387 let output = Array2::<f32>::zeros((6 + 16, 3)); let protos = Array3::<f32>::zeros((16, 16, num_protos)); let mut boxes = Vec::with_capacity(10);
2391 let mut masks = Vec::with_capacity(10);
2392 let result = decode_yolo_end_to_end_segdet_float(
2393 output.view(),
2394 protos.view(),
2395 0.5,
2396 &mut boxes,
2397 &mut masks,
2398 );
2399
2400 assert!(result.is_err());
2401 assert!(matches!(
2402 result,
2403 Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
2404 ));
2405 }
2406
2407 #[test]
2412 fn test_split_end_to_end_segdet_basic() {
2413 let num_protos = 32;
2416 let num_detections = 2;
2417 let num_features = 6 + num_protos;
2418
2419 let mut data = vec![0.0f32; num_features * num_detections];
2421 data[0] = 0.1; data[1] = 0.5; data[num_detections] = 0.1; data[num_detections + 1] = 0.5; data[2 * num_detections] = 0.4; data[2 * num_detections + 1] = 0.9; data[3 * num_detections] = 0.4; data[3 * num_detections + 1] = 0.9; data[4 * num_detections] = 0.9; data[4 * num_detections + 1] = 0.3; data[5 * num_detections] = 1.0; data[5 * num_detections + 1] = 2.0; for i in 6..num_features {
2436 data[i * num_detections] = 0.1;
2437 data[i * num_detections + 1] = 0.1;
2438 }
2439
2440 let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2441 let box_coords = output.slice(s![..4, ..]);
2442 let scores = output.slice(s![4..5, ..]);
2443 let classes = output.slice(s![5..6, ..]);
2444 let mask_coeff = output.slice(s![6.., ..]);
2445 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2447
2448 let mut boxes = Vec::with_capacity(10);
2449 let mut masks = Vec::with_capacity(10);
2450 decode_yolo_split_end_to_end_segdet_float(
2451 box_coords,
2452 scores,
2453 classes,
2454 mask_coeff,
2455 protos.view(),
2456 0.5,
2457 &mut boxes,
2458 &mut masks,
2459 )
2460 .unwrap();
2461
2462 assert_eq!(boxes.len(), 1);
2464 assert_eq!(masks.len(), 1);
2465 assert_eq!(boxes[0].label, 1);
2466 assert!((boxes[0].score - 0.9).abs() < 0.01);
2467 }
2468
2469 #[test]
2474 fn test_segmentation_to_mask_basic() {
2475 let data: Vec<u8> = vec![
2477 100, 200, 50, 150, 10, 255, 128, 64, 0, 127, 128, 255, 64, 64, 192, 192, ];
2482 let segmentation = Array3::from_shape_vec((4, 4, 1), data).unwrap();
2483
2484 let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2485
2486 assert_eq!(mask[[0, 0]], 0); assert_eq!(mask[[0, 1]], 1); assert_eq!(mask[[0, 2]], 0); assert_eq!(mask[[0, 3]], 1); assert_eq!(mask[[1, 1]], 1); assert_eq!(mask[[1, 2]], 1); assert_eq!(mask[[2, 0]], 0); assert_eq!(mask[[2, 1]], 0); }
2496
2497 #[test]
2498 fn test_segmentation_to_mask_all_above() {
2499 let segmentation = Array3::from_elem((4, 4, 1), 255u8);
2500 let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2501 assert!(mask.iter().all(|&x| x == 1));
2502 }
2503
2504 #[test]
2505 fn test_segmentation_to_mask_all_below() {
2506 let segmentation = Array3::from_elem((4, 4, 1), 64u8);
2507 let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2508 assert!(mask.iter().all(|&x| x == 0));
2509 }
2510
2511 #[test]
2512 fn test_segmentation_to_mask_invalid_shape() {
2513 let segmentation = Array3::from_elem((4, 4, 3), 128u8);
2514 let result = yolo_segmentation_to_mask(segmentation.view(), 128);
2515
2516 assert!(result.is_err());
2517 assert!(matches!(
2518 result,
2519 Err(crate::DecoderError::InvalidShape(s)) if s.contains("(H, W, 1)")
2520 ));
2521 }
2522
/// An ROI whose max corner sits exactly on 1.0 must be accepted by
/// `protobox`, and the crop must be non-empty in its first two axes while
/// keeping all 4 entries of the last axis.
#[test]
fn test_protobox_clamps_edge_coordinates() {
    let roi = BoundingBox {
        xmin: 0.5,
        ymin: 0.5,
        xmax: 1.0,
        ymax: 1.0,
    };
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let view = protos.view();

    let result = protobox(&view, &roi);
    assert!(result.is_ok(), "protobox should accept xmax=1.0");

    let (cropped, _roi_norm) = result.unwrap();
    let dims = cropped.shape();
    assert!(dims[0] > 0);
    assert!(dims[1] > 0);
    assert_eq!(dims[2], 4);
}
2546
/// An ROI reaching 3.0 on both axes must be rejected by `protobox` with an
/// `InvalidShape` error whose text mentions "un-normalized".
#[test]
fn test_protobox_rejects_wildly_out_of_range() {
    let roi = BoundingBox {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 3.0,
        ymax: 3.0,
    };
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let view = protos.view();

    let result = protobox(&view, &roi);

    // Any other outcome (success or a different error) counts as a failure.
    let rejected_as_unnormalized = match result {
        Err(crate::DecoderError::InvalidShape(s)) => s.contains("un-normalized"),
        _ => false,
    };
    assert!(
        rejected_as_unnormalized,
        "protobox should reject coords > NORM_LIMIT"
    );
}
2564
/// An ROI reaching 1.5 on both axes must still be accepted by `protobox`,
/// and the resulting crop is expected to span the full 16x16 extent of the
/// proto tensor's first two axes.
#[test]
fn test_protobox_accepts_slightly_over_one() {
    let roi = BoundingBox {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 1.5,
        ymax: 1.5,
    };
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let view = protos.view();

    let result = protobox(&view, &roi);
    assert!(
        result.is_ok(),
        "protobox should accept coords <= NORM_LIMIT (2.0)"
    );

    let (cropped, _roi_norm) = result.unwrap();
    let dims = cropped.shape();
    assert_eq!(dims[0], 16);
    assert_eq!(dims[1], 16);
}
2586
/// Runs `impl_yolo_segdet_float_proto` on a combined `(rows, num_proposals)`
/// tensor in which every proposal is identical, and checks that at least one
/// detection is produced and that the returned mask-coefficient matrix has
/// one row per output box and `num_mask_coeffs` columns.
#[test]
fn test_segdet_float_proto_no_panic() {
    let num_proposals = 100;
    let num_classes = 80;
    let num_mask_coeffs = 32;
    let rows = 4 + num_classes + num_mask_coeffs;

    // The tensor is row-major with shape (rows, num_proposals), so each of
    // the first five rows is one contiguous run holding the same value for
    // every proposal. Rows 0..4 are presumably the XYWH box fields and row 4
    // the first class score -- confirm against the decoder's layout.
    let mut data = vec![0.0f32; rows * num_proposals];
    for (row, &value) in [320.0, 320.0, 50.0, 50.0, 0.9].iter().enumerate() {
        let start = row * num_proposals;
        data[start..start + num_proposals].fill(value);
    }
    let boxes = Array2::from_shape_vec((rows, num_proposals), data).unwrap();

    let protos = Array3::<f32>::zeros((160, 160, num_mask_coeffs));

    let mut output_boxes = Vec::with_capacity(300);

    let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
        boxes.view(),
        protos.view(),
        0.5,
        0.7,
        Some(Nms::default()),
        MAX_NMS_CANDIDATES,
        300,
        &mut output_boxes,
    );

    assert!(!output_boxes.is_empty());

    // One coefficient row per surviving box, one column per mask coefficient.
    let coeff_dims = proto_data.mask_coefficients.shape();
    assert_eq!(coeff_dims[0], output_boxes.len());
    assert_eq!(coeff_dims[1], num_mask_coeffs);
}
2637
/// Feeds 50,000 identical boxes with slowly decreasing scores through
/// `impl_yolo_segdet_get_boxes` and checks that exactly
/// `MAX_NMS_CANDIDATES` of them come back, and that the ones kept are the
/// highest-scoring ones.
///
/// The two numeric arguments (0.1 and 1.0) are presumably the score and IoU
/// thresholds, and the final `usize::MAX` presumably the post-NMS detection
/// limit -- confirm against the callee's signature. The test relies on none
/// of the identical boxes being suppressed, which the length assertion
/// verifies.
#[test]
fn test_pre_nms_cap_truncates_excess_candidates() {
    let n: usize = 50_000;
    let num_classes = 1;
    let cap = crate::yolo::MAX_NMS_CANDIDATES;

    // Scores never increase with the candidate index, so the `cap` best
    // candidates are indices `0..cap`, and `score_at(cap - 1)` is the lowest
    // score a correctly truncated result may contain.
    let score_at = |i: usize| 0.99f32 - (i as f32) * 1e-7;

    let mut boxes_data = Vec::with_capacity(n * 4);
    let mut scores_data = Vec::with_capacity(n * num_classes);
    for i in 0..n {
        boxes_data.extend_from_slice(&[0.1f32, 0.1, 0.5, 0.5]);
        scores_data.push(score_at(i));
    }
    let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
    let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();

    let result = impl_yolo_segdet_get_boxes::<XYXY, _, _>(
        boxes.view(),
        scores.view(),
        0.1,
        1.0,
        Some(Nms::ClassAgnostic),
        cap,
        usize::MAX,
    );

    assert_eq!(
        result.len(),
        crate::yolo::MAX_NMS_CANDIDATES,
        "pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
        result.len()
    );

    let top_score = result[0].0.score;
    assert!(
        top_score > 0.98,
        "highest-ranked survivor should have the largest score, got {top_score}"
    );

    // Every generated score exceeds 0.98, so the check above cannot tell a
    // correct truncation from one that kept arbitrary candidates. Pin the
    // selection itself: nothing below the `cap`-th best score may survive.
    let cutoff = score_at(cap - 1);
    assert!(
        result.iter().all(|entry| entry.0.score >= cutoff),
        "pre-NMS cap should keep only the highest-scoring candidates (score >= {cutoff})"
    );
}
2695
/// Quantized counterpart of the pre-NMS cap test: 50,000 identical `i8`
/// boxes with `u8` scores cycling from 250 down to 51 go through
/// `impl_yolo_split_segdet_quant_get_boxes`, and exactly
/// `MAX_NMS_CANDIDATES` results are expected back.
#[test]
fn test_pre_nms_cap_truncates_excess_candidates_quant() {
    use crate::Quantization;
    let n: usize = 50_000;
    let num_classes = 1;

    // Every candidate uses the same raw box, repeated `n` times.
    let quant_boxes = Quantization {
        scale: 0.01,
        zero_point: 0,
    };
    let boxes = Array2::from_shape_vec((n, 4), [10i8, 10, 50, 50].repeat(n)).unwrap();

    // `i % 200` is at most 199, so the raw scores repeat over 250..=51.
    let quant_scores = Quantization {
        scale: 0.00392,
        zero_point: 0,
    };
    let mut scores_data: Vec<u8> = Vec::with_capacity(n);
    for i in 0..n {
        scores_data.push(250u8.saturating_sub((i % 200) as u8));
    }
    let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();

    let result = impl_yolo_split_segdet_quant_get_boxes::<XYXY, _, _>(
        (boxes.view(), quant_boxes),
        (scores.view(), quant_scores),
        0.1,
        1.0,
        Some(Nms::ClassAgnostic),
        crate::yolo::MAX_NMS_CANDIDATES,
        usize::MAX,
    );

    let kept = result.len();
    assert_eq!(
        kept,
        crate::yolo::MAX_NMS_CANDIDATES,
        "quant path pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
        result.len()
    );
}
2745
/// Checks that `decode_yolo_segdet_float` pairs each surviving detection
/// with the mask built from that same anchor's coefficients.
///
/// The combined tensor has shape `(feat, n)` with `feat = 4 + nc + nm`:
/// rows 0..4 hold the box fields, rows 4..6 the class scores and rows 6..8
/// the mask coefficients. Three anchors are laid out along the diagonal;
/// only anchors 0 and 2 get a score above the threshold, and they carry
/// opposite-sign mask coefficients so their masks are distinguishable.
#[test]
fn segdet_combined_tensor_pairs_detection_with_matching_mask_row() {
    let nc = 2;
    let nm = 2;
    let n = 3;
    let feat = 4 + nc + nm;
    let mut data = vec![0.0f32; feat * n];
    // Writes `v` at (row `r`, anchor column `c`) of the row-major tensor.
    let set = |d: &mut [f32], r: usize, c: usize, v: f32| d[r * n + c] = v;

    // Box fields: rows 0 and 1 place the anchors at 0.2, 0.5 and 0.8;
    // rows 2 and 3 are 0.1 for every anchor (presumably centre and size).
    set(&mut data, 0, 0, 0.2);
    set(&mut data, 1, 0, 0.2);
    set(&mut data, 2, 0, 0.1);
    set(&mut data, 3, 0, 0.1);
    set(&mut data, 0, 1, 0.5);
    set(&mut data, 1, 1, 0.5);
    set(&mut data, 2, 1, 0.1);
    set(&mut data, 3, 1, 0.1);
    set(&mut data, 0, 2, 0.8);
    set(&mut data, 1, 2, 0.8);
    set(&mut data, 2, 2, 0.1);
    set(&mut data, 3, 2, 0.1);

    // Class scores: anchor 1 stays at 0.0 and must be filtered out.
    set(&mut data, 4, 0, 0.9);
    set(&mut data, 4, 2, 0.8);

    // Mask coefficients: positive for anchor 0, negative for anchor 2.
    set(&mut data, 6, 0, 3.0);
    set(&mut data, 7, 0, 3.0);
    set(&mut data, 6, 2, -3.0);
    set(&mut data, 7, 2, -3.0);

    let output = Array2::from_shape_vec((feat, n), data).unwrap();
    let protos = Array3::<f32>::from_elem((8, 8, nm), 1.0);

    let mut boxes: Vec<DetectBox> = Vec::with_capacity(10);
    let mut masks: Vec<Segmentation> = Vec::with_capacity(10);
    decode_yolo_segdet_float(
        output.view(),
        protos.view(),
        0.5,
        0.5,
        Some(Nms::ClassAgnostic),
        &mut boxes,
        &mut masks,
    )
    .unwrap();

    assert_eq!(
        boxes.len(),
        2,
        "two anchors above threshold should survive (a0 score=0.9, a2 score=0.8); got {}",
        boxes.len()
    );
    // `zip` stops at the shorter input, so without this check a decoder that
    // returned too few masks would skip the pairing assertions below.
    assert_eq!(
        masks.len(),
        boxes.len(),
        "every detection should come with exactly one mask; got {} masks for {} boxes",
        masks.len(),
        boxes.len()
    );

    for (b, m) in boxes.iter().zip(masks.iter()) {
        // The horizontal centre tells which anchor produced this detection.
        let cx = (b.bbox.xmin + b.bbox.xmax) * 0.5;
        let mean = {
            let s = &m.segmentation;
            let total: u32 = s.iter().map(|&v| v as u32).sum();
            total as f32 / s.len() as f32
        };
        if cx < 0.3 {
            assert!(
                mean > 200.0,
                "anchor 0 detection (centre {cx:.2}) should have high-value mask; got mean {mean}"
            );
        } else if cx > 0.7 {
            assert!(
                mean < 50.0,
                "anchor 2 detection (centre {cx:.2}) should have low-value mask; got mean {mean}"
            );
        } else {
            panic!("unexpected detection centre {cx:.2}");
        }
    }
}
2856}