1use std::fmt::Debug;
13
14use ndarray::{
15 parallel::prelude::{IntoParallelIterator, ParallelIterator},
16 s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
17};
18use num_traits::{AsPrimitive, Float, PrimInt, Signed};
19
20use crate::{
21 byte::{
22 nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
23 postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
24 },
25 configs::Nms,
26 dequant_detect_box,
27 float::{
28 nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
29 postprocess_boxes_float, postprocess_boxes_index_float,
30 },
31 BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoLayout,
32 Quantization, Segmentation, XYWH, XYXY,
33};
34
/// Pre-NMS candidate limit that the test-only `decode_yolo_segdet_*` wrappers
/// in this file pass as `pre_nms_top_k`. Compiled for test builds only.
#[cfg(test)]
pub(crate) const MAX_NMS_CANDIDATES: usize = 30_000;
54
/// Fallback detection limit returned by `cap_or_default` when the caller's
/// output vector has a capacity of zero.
pub(crate) const DEFAULT_MAX_DETECTIONS: usize = 300;
62
63fn truncate_to_top_k_by_score<E: Send>(boxes: &mut Vec<(DetectBox, E)>, top_k: usize) {
68 if top_k > 0 && boxes.len() > top_k {
69 boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.total_cmp(&a.0.score));
70 boxes.truncate(top_k);
71 }
72}
73
74fn truncate_to_top_k_by_score_quant<S: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send>(
79 boxes: &mut Vec<(DetectBoxQuantized<S>, E)>,
80 top_k: usize,
81) {
82 if top_k > 0 && boxes.len() > top_k {
83 boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.cmp(&a.0.score));
84 boxes.truncate(top_k);
85 }
86}
87
88fn dispatch_nms_float(
95 nms: Option<Nms>,
96 iou: f32,
97 max_det: Option<usize>,
98 boxes: Vec<DetectBox>,
99) -> Vec<DetectBox> {
100 match nms {
101 Some(Nms::ClassAgnostic | Nms::Auto) => nms_float(iou, max_det, boxes),
102 Some(Nms::ClassAware) => nms_class_aware_float(iou, max_det, boxes),
103 None => boxes, }
105}
106
107pub(super) fn dispatch_nms_extra_float<E: Send + Sync>(
110 nms: Option<Nms>,
111 iou: f32,
112 max_det: Option<usize>,
113 boxes: Vec<(DetectBox, E)>,
114) -> Vec<(DetectBox, E)> {
115 match nms {
116 Some(Nms::ClassAgnostic | Nms::Auto) => nms_extra_float(iou, max_det, boxes),
117 Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, max_det, boxes),
118 None => boxes, }
120}
121
122fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
125 nms: Option<Nms>,
126 iou: f32,
127 max_det: Option<usize>,
128 boxes: Vec<DetectBoxQuantized<SCORE>>,
129) -> Vec<DetectBoxQuantized<SCORE>> {
130 match nms {
131 Some(Nms::ClassAgnostic | Nms::Auto) => nms_int(iou, max_det, boxes),
132 Some(Nms::ClassAware) => nms_class_aware_int(iou, max_det, boxes),
133 None => boxes, }
135}
136
137fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
140 nms: Option<Nms>,
141 iou: f32,
142 max_det: Option<usize>,
143 boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
144) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
145 match nms {
146 Some(Nms::ClassAgnostic | Nms::Auto) => nms_extra_int(iou, max_det, boxes),
147 Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, max_det, boxes),
148 None => boxes, }
150}
151
152#[inline]
159fn cap_or_default<T>(v: &Vec<T>) -> usize {
160 if v.capacity() > 0 {
161 v.capacity()
162 } else {
163 DEFAULT_MAX_DETECTIONS
164 }
165}
166
167pub(crate) fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
201 output: (ArrayView2<BOX>, Quantization),
202 score_threshold: f32,
203 iou_threshold: f32,
204 nms: Option<Nms>,
205 output_boxes: &mut Vec<DetectBox>,
206) where
207 f32: AsPrimitive<BOX>,
208{
209 impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
210}
211
212pub(crate) fn decode_yolo_det_float<T>(
219 output: ArrayView2<T>,
220 score_threshold: f32,
221 iou_threshold: f32,
222 nms: Option<Nms>,
223 output_boxes: &mut Vec<DetectBox>,
224) where
225 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
226 f32: AsPrimitive<T>,
227{
228 impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
229}
230
/// Test-only convenience wrapper around `impl_yolo_segdet_quant`.
///
/// Uses the `XYWH` box type, `MAX_NMS_CANDIDATES` as the pre-NMS limit and
/// `cap_or_default(output_boxes)` as the detection limit. `normalized` and
/// `input_dims` are passed as `None`.
///
/// # Errors
/// Propagates any `DecoderError` returned by `impl_yolo_segdet_quant`.
#[cfg(test)]
pub(crate) fn decode_yolo_segdet_quant<
    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    protos: (ArrayView3<PROTO>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    f32: AsPrimitive<BOX>,
{
    // The output vector's capacity doubles as the detection limit.
    let max_det = cap_or_default(output_boxes);
    let (normalized, input_dims) = (None, None);
    impl_yolo_segdet_quant::<XYWH, BOX, PROTO>(
        boxes,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        MAX_NMS_CANDIDATES,
        max_det,
        normalized,
        input_dims,
        output_boxes,
        output_masks,
    )
}
279
/// Test-only convenience wrapper around `impl_yolo_segdet_float`.
///
/// Uses the `XYWH` box type, `MAX_NMS_CANDIDATES` as the pre-NMS limit and
/// `cap_or_default(output_boxes)` as the detection limit. `normalized` and
/// `input_dims` are passed as `None`.
///
/// # Errors
/// Propagates any `DecoderError` returned by `impl_yolo_segdet_float`.
#[cfg(test)]
pub(crate) fn decode_yolo_segdet_float<T>(
    boxes: ArrayView2<T>,
    protos: ArrayView3<T>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    // The output vector's capacity doubles as the detection limit.
    let max_det = cap_or_default(output_boxes);
    let (normalized, input_dims) = (None, None);
    impl_yolo_segdet_float::<XYWH, T, T>(
        boxes,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        MAX_NMS_CANDIDATES,
        max_det,
        normalized,
        input_dims,
        output_boxes,
        output_masks,
    )
}
312
313pub(crate) fn decode_yolo_split_det_quant<
325 BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
326 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
327>(
328 boxes: (ArrayView2<BOX>, Quantization),
329 scores: (ArrayView2<SCORE>, Quantization),
330 score_threshold: f32,
331 iou_threshold: f32,
332 nms: Option<Nms>,
333 output_boxes: &mut Vec<DetectBox>,
334) where
335 f32: AsPrimitive<SCORE>,
336{
337 impl_yolo_split_quant::<XYWH, _, _>(
338 boxes,
339 scores,
340 score_threshold,
341 iou_threshold,
342 nms,
343 output_boxes,
344 );
345}
346
347pub(crate) fn decode_yolo_split_det_float<T>(
359 boxes: ArrayView2<T>,
360 scores: ArrayView2<T>,
361 score_threshold: f32,
362 iou_threshold: f32,
363 nms: Option<Nms>,
364 output_boxes: &mut Vec<DetectBox>,
365) where
366 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
367 f32: AsPrimitive<T>,
368{
369 impl_yolo_split_float::<XYWH, _, _>(
370 boxes,
371 scores,
372 score_threshold,
373 iou_threshold,
374 nms,
375 output_boxes,
376 );
377}
378
379pub(crate) fn decode_yolo_end_to_end_det_float<T>(
394 output: ArrayView2<T>,
395 score_threshold: f32,
396 output_boxes: &mut Vec<DetectBox>,
397) -> Result<(), crate::DecoderError>
398where
399 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
400 f32: AsPrimitive<T>,
401{
402 if output.shape()[0] < 6 {
404 return Err(crate::DecoderError::InvalidShape(format!(
405 "End-to-end detection output requires at least 6 rows, got {}",
406 output.shape()[0]
407 )));
408 }
409
410 let boxes = output.slice(s![0..4, ..]).reversed_axes();
412 let scores = output.slice(s![4..5, ..]).reversed_axes();
413 let classes = output.slice(s![5, ..]);
414 let mut boxes =
415 postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
416 boxes.truncate(cap_or_default(output_boxes));
417 output_boxes.clear();
418 for (mut b, i) in boxes.into_iter() {
419 b.label = classes[i].as_() as usize;
420 output_boxes.push(b);
421 }
422 Ok(())
424}
425
426pub(crate) fn decode_yolo_end_to_end_segdet_float<T>(
444 output: ArrayView2<T>,
445 protos: ArrayView3<T>,
446 score_threshold: f32,
447 output_boxes: &mut Vec<DetectBox>,
448 output_masks: &mut Vec<crate::Segmentation>,
449) -> Result<(), crate::DecoderError>
450where
451 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
452 f32: AsPrimitive<T>,
453{
454 let (boxes, scores, classes, mask_coeff) =
455 postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
456 let cap = cap_or_default(output_boxes);
457 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
458 boxes,
459 scores,
460 classes,
461 score_threshold,
462 cap,
463 );
464
465 impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
468}
469
470pub(crate) fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
479 boxes: ArrayView2<T>,
480 scores: ArrayView2<T>,
481 classes: ArrayView2<T>,
482 score_threshold: f32,
483 output_boxes: &mut Vec<DetectBox>,
484) -> Result<(), crate::DecoderError> {
485 let n = boxes.shape()[1];
486
487 let cap = cap_or_default(output_boxes);
488 output_boxes.clear();
489
490 let (boxes, scores, classes) = postprocess_yolo_split_end_to_end_det(boxes, scores, &classes)?;
491
492 for i in 0..n {
493 let score: f32 = scores[[i, 0]].as_();
494 if score < score_threshold {
495 continue;
496 }
497 if output_boxes.len() >= cap {
498 break;
499 }
500 output_boxes.push(DetectBox {
501 bbox: BoundingBox {
502 xmin: boxes[[i, 0]].as_(),
503 ymin: boxes[[i, 1]].as_(),
504 xmax: boxes[[i, 2]].as_(),
505 ymax: boxes[[i, 3]].as_(),
506 },
507 score,
508 label: classes[i].as_() as usize,
509 });
510 }
511 Ok(())
512}
513
514#[allow(clippy::too_many_arguments)]
523pub(crate) fn decode_yolo_split_end_to_end_segdet_float<T>(
524 boxes: ArrayView2<T>,
525 scores: ArrayView2<T>,
526 classes: ArrayView2<T>,
527 mask_coeff: ArrayView2<T>,
528 protos: ArrayView3<T>,
529 score_threshold: f32,
530 output_boxes: &mut Vec<DetectBox>,
531 output_masks: &mut Vec<crate::Segmentation>,
532) -> Result<(), crate::DecoderError>
533where
534 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
535 f32: AsPrimitive<T>,
536{
537 let (boxes, scores, classes, mask_coeff) =
538 postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
539 let cap = cap_or_default(output_boxes);
540 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
541 boxes,
542 scores,
543 classes,
544 score_threshold,
545 cap,
546 );
547
548 impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
549}
550
551#[allow(clippy::type_complexity)]
552pub(crate) fn postprocess_yolo_end_to_end_segdet<'a, T>(
553 output: &'a ArrayView2<'_, T>,
554 num_protos: usize,
555) -> Result<
556 (
557 ArrayView2<'a, T>,
558 ArrayView2<'a, T>,
559 ArrayView1<'a, T>,
560 ArrayView2<'a, T>,
561 ),
562 crate::DecoderError,
563> {
564 if output.shape()[0] < 7 {
566 return Err(crate::DecoderError::InvalidShape(format!(
567 "End-to-end segdet output requires at least 7 rows, got {}",
568 output.shape()[0]
569 )));
570 }
571
572 let num_mask_coeffs = output.shape()[0] - 6;
573 if num_mask_coeffs != num_protos {
574 return Err(crate::DecoderError::InvalidShape(format!(
575 "Mask coefficients count ({}) doesn't match protos count ({})",
576 num_mask_coeffs, num_protos
577 )));
578 }
579
580 let boxes = output.slice(s![0..4, ..]).reversed_axes();
582 let scores = output.slice(s![4..5, ..]).reversed_axes();
583 let classes = output.slice(s![5, ..]);
584 let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
585 Ok((boxes, scores, classes, mask_coeff))
586}
587
588#[allow(clippy::type_complexity)]
595pub(crate) fn postprocess_yolo_split_end_to_end_det<'a, 'b, 'c, BOXES, SCORES, CLASS>(
596 boxes: ArrayView2<'a, BOXES>,
597 scores: ArrayView2<'b, SCORES>,
598 classes: &'c ArrayView2<CLASS>,
599) -> Result<
600 (
601 ArrayView2<'a, BOXES>,
602 ArrayView2<'b, SCORES>,
603 ArrayView1<'c, CLASS>,
604 ),
605 crate::DecoderError,
606> {
607 let num_boxes = boxes.shape()[1];
608 if boxes.shape()[0] != 4 {
609 return Err(crate::DecoderError::InvalidShape(format!(
610 "Split end-to-end box_coords must be 4, got {}",
611 boxes.shape()[0]
612 )));
613 }
614
615 if scores.shape()[0] != 1 {
616 return Err(crate::DecoderError::InvalidShape(format!(
617 "Split end-to-end scores num_classes must be 1, got {}",
618 scores.shape()[0]
619 )));
620 }
621
622 if classes.shape()[0] != 1 {
623 return Err(crate::DecoderError::InvalidShape(format!(
624 "Split end-to-end classes num_classes must be 1, got {}",
625 classes.shape()[0]
626 )));
627 }
628
629 if scores.shape()[1] != num_boxes {
630 return Err(crate::DecoderError::InvalidShape(format!(
631 "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
632 num_boxes,
633 scores.shape()[1]
634 )));
635 }
636
637 if classes.shape()[1] != num_boxes {
638 return Err(crate::DecoderError::InvalidShape(format!(
639 "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
640 num_boxes,
641 classes.shape()[1]
642 )));
643 }
644
645 let boxes = boxes.reversed_axes();
646 let scores = scores.reversed_axes();
647 let classes = classes.slice(s![0, ..]);
648 Ok((boxes, scores, classes))
649}
650
651#[allow(clippy::type_complexity)]
654pub(crate) fn postprocess_yolo_split_end_to_end_segdet<
655 'a,
656 'b,
657 'c,
658 'd,
659 BOXES,
660 SCORES,
661 CLASS,
662 MASK,
663>(
664 boxes: ArrayView2<'a, BOXES>,
665 scores: ArrayView2<'b, SCORES>,
666 classes: &'c ArrayView2<CLASS>,
667 mask_coeff: ArrayView2<'d, MASK>,
668) -> Result<
669 (
670 ArrayView2<'a, BOXES>,
671 ArrayView2<'b, SCORES>,
672 ArrayView1<'c, CLASS>,
673 ArrayView2<'d, MASK>,
674 ),
675 crate::DecoderError,
676> {
677 let num_boxes = boxes.shape()[1];
678 if boxes.shape()[0] != 4 {
679 return Err(crate::DecoderError::InvalidShape(format!(
680 "Split end-to-end box_coords must be 4, got {}",
681 boxes.shape()[0]
682 )));
683 }
684
685 if scores.shape()[0] != 1 {
686 return Err(crate::DecoderError::InvalidShape(format!(
687 "Split end-to-end scores num_classes must be 1, got {}",
688 scores.shape()[0]
689 )));
690 }
691
692 if classes.shape()[0] != 1 {
693 return Err(crate::DecoderError::InvalidShape(format!(
694 "Split end-to-end classes num_classes must be 1, got {}",
695 classes.shape()[0]
696 )));
697 }
698
699 if scores.shape()[1] != num_boxes {
700 return Err(crate::DecoderError::InvalidShape(format!(
701 "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
702 num_boxes,
703 scores.shape()[1]
704 )));
705 }
706
707 if classes.shape()[1] != num_boxes {
708 return Err(crate::DecoderError::InvalidShape(format!(
709 "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
710 num_boxes,
711 classes.shape()[1]
712 )));
713 }
714
715 if mask_coeff.shape()[1] != num_boxes {
716 return Err(crate::DecoderError::InvalidShape(format!(
717 "Split end-to-end mask_coeff must have same num_boxes as boxes ({}), got {}",
718 num_boxes,
719 mask_coeff.shape()[1]
720 )));
721 }
722
723 let boxes = boxes.reversed_axes();
724 let scores = scores.reversed_axes();
725 let classes = classes.slice(s![0, ..]);
726 let mask_coeff = mask_coeff.reversed_axes();
727 Ok((boxes, scores, classes, mask_coeff))
728}
729pub(crate) fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
734 output: (ArrayView2<T>, Quantization),
735 score_threshold: f32,
736 iou_threshold: f32,
737 nms: Option<Nms>,
738 output_boxes: &mut Vec<DetectBox>,
739) where
740 f32: AsPrimitive<T>,
741{
742 let _span = tracing::trace_span!("decode", mode = "quant_det").entered();
743 let (boxes, quant_boxes) = output;
744 let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
745
746 let boxes = {
747 let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
748 postprocess_boxes_quant::<B, _, _>(
749 score_threshold,
750 boxes_tensor,
751 scores_tensor,
752 quant_boxes,
753 )
754 };
755
756 let cap = cap_or_default(output_boxes);
757 let boxes = dispatch_nms_int(nms, iou_threshold, Some(cap), boxes);
758 let len = cap.min(boxes.len());
761 output_boxes.clear();
762 for b in boxes.iter().take(len) {
763 output_boxes.push(dequant_detect_box(b, quant_boxes));
764 }
765}
766
767pub(crate) fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
772 output: ArrayView2<T>,
773 score_threshold: f32,
774 iou_threshold: f32,
775 nms: Option<Nms>,
776 output_boxes: &mut Vec<DetectBox>,
777) where
778 f32: AsPrimitive<T>,
779{
780 let _span = tracing::trace_span!("decode", mode = "float_det").entered();
781 let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
782 let boxes =
783 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
784 let cap = cap_or_default(output_boxes);
785 let boxes = dispatch_nms_float(nms, iou_threshold, Some(cap), boxes);
786 let len = cap.min(boxes.len());
789 output_boxes.clear();
790 for b in boxes.into_iter().take(len) {
791 output_boxes.push(b);
792 }
793}
794
795pub(crate) fn impl_yolo_split_quant<
805 B: BBoxTypeTrait,
806 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
807 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
808>(
809 boxes: (ArrayView2<BOX>, Quantization),
810 scores: (ArrayView2<SCORE>, Quantization),
811 score_threshold: f32,
812 iou_threshold: f32,
813 nms: Option<Nms>,
814 output_boxes: &mut Vec<DetectBox>,
815) where
816 f32: AsPrimitive<SCORE>,
817{
818 let _span = tracing::trace_span!("decode", mode = "split_quant_det").entered();
819 let (boxes_tensor, quant_boxes) = boxes;
820 let (scores_tensor, quant_scores) = scores;
821
822 let boxes_tensor = boxes_tensor.reversed_axes();
823 let scores_tensor = scores_tensor.reversed_axes();
824
825 let boxes = {
826 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
827 postprocess_boxes_quant::<B, _, _>(
828 score_threshold,
829 boxes_tensor,
830 scores_tensor,
831 quant_boxes,
832 )
833 };
834
835 let cap = cap_or_default(output_boxes);
836 let boxes = dispatch_nms_int(nms, iou_threshold, Some(cap), boxes);
837 let len = cap.min(boxes.len());
840 output_boxes.clear();
841 for b in boxes.iter().take(len) {
842 output_boxes.push(dequant_detect_box(b, quant_scores));
843 }
844}
845
846pub(crate) fn impl_yolo_split_float<
855 B: BBoxTypeTrait,
856 BOX: Float + AsPrimitive<f32> + Send + Sync,
857 SCORE: Float + AsPrimitive<f32> + Send + Sync,
858>(
859 boxes_tensor: ArrayView2<BOX>,
860 scores_tensor: ArrayView2<SCORE>,
861 score_threshold: f32,
862 iou_threshold: f32,
863 nms: Option<Nms>,
864 output_boxes: &mut Vec<DetectBox>,
865) where
866 f32: AsPrimitive<SCORE>,
867{
868 let _span = tracing::trace_span!("decode", mode = "split_float_det").entered();
869 let boxes_tensor = boxes_tensor.reversed_axes();
870 let scores_tensor = scores_tensor.reversed_axes();
871 let boxes =
872 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
873 let cap = cap_or_default(output_boxes);
874 let boxes = dispatch_nms_float(nms, iou_threshold, Some(cap), boxes);
875 let len = cap.min(boxes.len());
878 output_boxes.clear();
879 for b in boxes.into_iter().take(len) {
880 output_boxes.push(b);
881 }
882}
883
884#[inline]
892pub(crate) fn maybe_normalize_boxes_in_place(
893 boxes: &mut [(DetectBox, usize)],
894 normalized: Option<bool>,
895 input_dims: Option<(usize, usize)>,
896) {
897 if normalized != Some(false) {
898 return;
899 }
900 let Some((w, h)) = input_dims else {
901 return;
902 };
903 if w == 0 || h == 0 {
904 return;
905 }
906 let inv_w = 1.0 / w as f32;
907 let inv_h = 1.0 / h as f32;
908 for (b, _) in boxes.iter_mut() {
909 b.bbox.xmin *= inv_w;
910 b.bbox.ymin *= inv_h;
911 b.bbox.xmax *= inv_w;
912 b.bbox.ymax *= inv_h;
913 }
914}
915
916#[allow(clippy::too_many_arguments)]
926pub(crate) fn impl_yolo_segdet_quant<
927 B: BBoxTypeTrait,
928 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
929 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
930>(
931 boxes: (ArrayView2<BOX>, Quantization),
932 protos: (ArrayView3<PROTO>, Quantization),
933 score_threshold: f32,
934 iou_threshold: f32,
935 nms: Option<Nms>,
936 pre_nms_top_k: usize,
937 max_det: usize,
938 normalized: Option<bool>,
939 input_dims: Option<(usize, usize)>,
940 output_boxes: &mut Vec<DetectBox>,
941 output_masks: &mut Vec<Segmentation>,
942) -> Result<(), crate::DecoderError>
943where
944 f32: AsPrimitive<BOX>,
945{
946 let (boxes, quant_boxes) = boxes;
947 let num_protos = protos.0.dim().2;
948
949 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
950 let mut boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
951 (boxes_tensor, quant_boxes),
952 (scores_tensor, quant_boxes),
953 score_threshold,
954 iou_threshold,
955 nms,
956 pre_nms_top_k,
957 max_det,
958 );
959 maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
960
961 impl_yolo_split_segdet_quant_process_masks::<_, _>(
962 boxes,
963 (mask_tensor, quant_boxes),
964 protos,
965 output_boxes,
966 output_masks,
967 )
968}
969
970#[allow(clippy::too_many_arguments)]
980pub(crate) fn impl_yolo_segdet_float<
981 B: BBoxTypeTrait,
982 BOX: Float + AsPrimitive<f32> + Send + Sync,
983 PROTO: Float + AsPrimitive<f32> + Send + Sync,
984>(
985 boxes: ArrayView2<BOX>,
986 protos: ArrayView3<PROTO>,
987 score_threshold: f32,
988 iou_threshold: f32,
989 nms: Option<Nms>,
990 pre_nms_top_k: usize,
991 max_det: usize,
992 normalized: Option<bool>,
993 input_dims: Option<(usize, usize)>,
994 output_boxes: &mut Vec<DetectBox>,
995 output_masks: &mut Vec<Segmentation>,
996) -> Result<(), crate::DecoderError>
997where
998 f32: AsPrimitive<BOX>,
999{
1000 let num_protos = protos.dim().2;
1001 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1002 let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1003 boxes_tensor,
1004 scores_tensor,
1005 score_threshold,
1006 iou_threshold,
1007 nms,
1008 pre_nms_top_k,
1009 max_det,
1010 );
1011 maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1012 impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1013}
1014
1015pub(crate) fn impl_yolo_segdet_get_boxes<
1016 B: BBoxTypeTrait,
1017 BOX: Float + AsPrimitive<f32> + Send + Sync,
1018 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1019>(
1020 boxes_tensor: ArrayView2<BOX>,
1021 scores_tensor: ArrayView2<SCORE>,
1022 score_threshold: f32,
1023 iou_threshold: f32,
1024 nms: Option<Nms>,
1025 pre_nms_top_k: usize,
1026 max_det: usize,
1027) -> Vec<(DetectBox, usize)>
1028where
1029 f32: AsPrimitive<SCORE>,
1030{
1031 let span = tracing::trace_span!(
1032 "decode",
1033 n_candidates = tracing::field::Empty,
1034 n_after_topk = tracing::field::Empty,
1035 n_after_nms = tracing::field::Empty,
1036 n_detections = tracing::field::Empty,
1037 );
1038 let _guard = span.enter();
1039
1040 let mut boxes = {
1041 let _s = tracing::trace_span!("score_filter").entered();
1042 postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor)
1043 };
1044 span.record("n_candidates", boxes.len());
1045
1046 if nms.is_some() {
1047 let _s = tracing::trace_span!("top_k", k = pre_nms_top_k).entered();
1048 truncate_to_top_k_by_score(&mut boxes, pre_nms_top_k);
1049 }
1050 span.record("n_after_topk", boxes.len());
1051
1052 let mut boxes = {
1053 let _s = tracing::trace_span!("nms").entered();
1054 dispatch_nms_extra_float(nms, iou_threshold, Some(max_det), boxes)
1055 };
1056 span.record("n_after_nms", boxes.len());
1057
1058 boxes.sort_unstable_by(|a, b| b.0.score.total_cmp(&a.0.score));
1061 boxes.truncate(max_det);
1062 span.record("n_detections", boxes.len());
1063
1064 boxes
1065}
1066
1067pub(crate) fn impl_yolo_end_to_end_segdet_get_boxes<
1068 B: BBoxTypeTrait,
1069 BOX: Float + AsPrimitive<f32> + Send + Sync,
1070 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1071 CLASS: AsPrimitive<f32> + Send + Sync,
1072>(
1073 boxes: ArrayView2<BOX>,
1074 scores: ArrayView2<SCORE>,
1075 classes: ArrayView1<CLASS>,
1076 score_threshold: f32,
1077 max_boxes: usize,
1078) -> Vec<(DetectBox, usize)>
1079where
1080 f32: AsPrimitive<SCORE>,
1081{
1082 let mut boxes = postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes, scores);
1083 boxes.truncate(max_boxes);
1084 for (b, ind) in &mut boxes {
1085 b.label = classes[*ind].as_().round() as usize;
1086 }
1087 boxes
1088}
1089
1090pub(crate) fn impl_yolo_split_segdet_process_masks<
1091 MASK: Float + AsPrimitive<f32> + Send + Sync,
1092 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1093>(
1094 boxes: Vec<(DetectBox, usize)>,
1095 masks_tensor: ArrayView2<MASK>,
1096 protos_tensor: ArrayView3<PROTO>,
1097 output_boxes: &mut Vec<DetectBox>,
1098 output_masks: &mut Vec<Segmentation>,
1099) -> Result<(), crate::DecoderError> {
1100 let _span = tracing::trace_span!("process_masks", n = boxes.len(), mode = "float").entered();
1101 let boxes = decode_segdet_f32(boxes, masks_tensor, protos_tensor)?;
1105 output_boxes.clear();
1106 output_masks.clear();
1107 for (b, roi, m) in boxes.into_iter() {
1108 output_boxes.push(b);
1109 output_masks.push(Segmentation {
1110 xmin: roi.xmin,
1111 ymin: roi.ymin,
1112 xmax: roi.xmax,
1113 ymax: roi.ymax,
1114 segmentation: m,
1115 });
1116 }
1117 Ok(())
1118}
1119pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
1123 B: BBoxTypeTrait,
1124 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1125 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1126>(
1127 boxes: (ArrayView2<BOX>, Quantization),
1128 scores: (ArrayView2<SCORE>, Quantization),
1129 score_threshold: f32,
1130 iou_threshold: f32,
1131 nms: Option<Nms>,
1132 pre_nms_top_k: usize,
1133 max_det: usize,
1134) -> Vec<(DetectBox, usize)>
1135where
1136 f32: AsPrimitive<SCORE>,
1137{
1138 let (boxes_tensor, quant_boxes) = boxes;
1139 let (scores_tensor, quant_scores) = scores;
1140
1141 let span = tracing::trace_span!(
1142 "decode",
1143 n_candidates = tracing::field::Empty,
1144 n_after_topk = tracing::field::Empty,
1145 n_after_nms = tracing::field::Empty,
1146 n_detections = tracing::field::Empty,
1147 );
1148 let _guard = span.enter();
1149
1150 let mut boxes = {
1151 let _s = tracing::trace_span!("score_filter").entered();
1152 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
1153 postprocess_boxes_index_quant::<B, _, _>(
1154 score_threshold,
1155 boxes_tensor,
1156 scores_tensor,
1157 quant_boxes,
1158 )
1159 };
1160 span.record("n_candidates", boxes.len());
1161
1162 if nms.is_some() {
1163 let _s = tracing::trace_span!("top_k", k = pre_nms_top_k).entered();
1164 truncate_to_top_k_by_score_quant(&mut boxes, pre_nms_top_k);
1165 }
1166 span.record("n_after_topk", boxes.len());
1167
1168 let mut boxes = {
1169 let _s = tracing::trace_span!("nms").entered();
1170 dispatch_nms_extra_int(nms, iou_threshold, Some(max_det), boxes)
1171 };
1172 span.record("n_after_nms", boxes.len());
1173
1174 boxes.sort_unstable_by_key(|b| std::cmp::Reverse(b.0.score));
1177 boxes.truncate(max_det);
1178 let result: Vec<_> = {
1179 let _s = tracing::trace_span!("box_dequant", n = boxes.len()).entered();
1180 boxes
1181 .into_iter()
1182 .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
1183 .collect()
1184 };
1185 span.record("n_detections", result.len());
1186
1187 result
1188}
1189
1190pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
1191 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1192 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1193>(
1194 boxes: Vec<(DetectBox, usize)>,
1195 mask_coeff: (ArrayView2<MASK>, Quantization),
1196 protos: (ArrayView3<PROTO>, Quantization),
1197 output_boxes: &mut Vec<DetectBox>,
1198 output_masks: &mut Vec<Segmentation>,
1199) -> Result<(), crate::DecoderError> {
1200 let _span = tracing::trace_span!("process_masks", n = boxes.len(), mode = "quant").entered();
1201 let (masks, quant_masks) = mask_coeff;
1202 let (protos, quant_protos) = protos;
1203
1204 let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
1208 output_boxes.clear();
1209 output_masks.clear();
1210 for (b, roi, m) in boxes.into_iter() {
1211 output_boxes.push(b);
1212 output_masks.push(Segmentation {
1213 xmin: roi.xmin,
1214 ymin: roi.ymin,
1215 xmax: roi.xmax,
1216 ymax: roi.ymax,
1217 segmentation: m,
1218 });
1219 }
1220 Ok(())
1221}
1222
1223#[allow(clippy::too_many_arguments)]
1235pub(crate) fn impl_yolo_split_segdet_float<
1236 B: BBoxTypeTrait,
1237 BOX: Float + AsPrimitive<f32> + Send + Sync,
1238 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1239 MASK: Float + AsPrimitive<f32> + Send + Sync,
1240 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1241>(
1242 boxes_tensor: ArrayView2<BOX>,
1243 scores_tensor: ArrayView2<SCORE>,
1244 mask_tensor: ArrayView2<MASK>,
1245 protos: ArrayView3<PROTO>,
1246 score_threshold: f32,
1247 iou_threshold: f32,
1248 nms: Option<Nms>,
1249 pre_nms_top_k: usize,
1250 max_det: usize,
1251 normalized: Option<bool>,
1252 input_dims: Option<(usize, usize)>,
1253 output_boxes: &mut Vec<DetectBox>,
1254 output_masks: &mut Vec<Segmentation>,
1255) -> Result<(), crate::DecoderError>
1256where
1257 f32: AsPrimitive<SCORE>,
1258{
1259 let (boxes_tensor, scores_tensor, mask_tensor) =
1260 postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1261
1262 let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1263 boxes_tensor,
1264 scores_tensor,
1265 score_threshold,
1266 iou_threshold,
1267 nms,
1268 pre_nms_top_k,
1269 max_det,
1270 );
1271 maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1272 impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1273}
1274
1275#[allow(clippy::too_many_arguments)]
1282pub(crate) fn impl_yolo_segdet_quant_proto<
1283 B: BBoxTypeTrait,
1284 BOX: PrimInt
1285 + AsPrimitive<i64>
1286 + AsPrimitive<i128>
1287 + AsPrimitive<f32>
1288 + AsPrimitive<i8>
1289 + Send
1290 + Sync,
1291 PROTO: PrimInt
1292 + AsPrimitive<i64>
1293 + AsPrimitive<i128>
1294 + AsPrimitive<f32>
1295 + AsPrimitive<i8>
1296 + Send
1297 + Sync,
1298>(
1299 boxes: (ArrayView2<BOX>, Quantization),
1300 protos: (ArrayView3<PROTO>, Quantization),
1301 score_threshold: f32,
1302 iou_threshold: f32,
1303 nms: Option<Nms>,
1304 pre_nms_top_k: usize,
1305 max_det: usize,
1306 normalized: Option<bool>,
1307 input_dims: Option<(usize, usize)>,
1308 output_boxes: &mut Vec<DetectBox>,
1309) -> ProtoData
1310where
1311 f32: AsPrimitive<BOX>,
1312{
1313 let (boxes_arr, quant_boxes) = boxes;
1314 let (protos_arr, quant_protos) = protos;
1315 let num_protos = protos_arr.dim().2;
1316
1317 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1318
1319 let mut det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1320 (boxes_tensor, quant_boxes),
1321 (scores_tensor, quant_boxes),
1322 score_threshold,
1323 iou_threshold,
1324 nms,
1325 pre_nms_top_k,
1326 max_det,
1327 );
1328 maybe_normalize_boxes_in_place(&mut det_indices, normalized, input_dims);
1329
1330 extract_proto_data_quant(
1331 det_indices,
1332 mask_tensor,
1333 quant_boxes,
1334 protos_arr,
1335 quant_protos,
1336 output_boxes,
1337 )
1338}
1339
1340#[allow(clippy::too_many_arguments)]
1343pub(crate) fn impl_yolo_segdet_float_proto<
1344 B: BBoxTypeTrait,
1345 BOX: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1346 PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1347>(
1348 boxes: ArrayView2<BOX>,
1349 protos: ArrayView3<PROTO>,
1350 score_threshold: f32,
1351 iou_threshold: f32,
1352 nms: Option<Nms>,
1353 pre_nms_top_k: usize,
1354 max_det: usize,
1355 normalized: Option<bool>,
1356 input_dims: Option<(usize, usize)>,
1357 output_boxes: &mut Vec<DetectBox>,
1358) -> ProtoData
1359where
1360 f32: AsPrimitive<BOX>,
1361{
1362 let num_protos = protos.dim().2;
1363 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1364
1365 let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1366 boxes_tensor,
1367 scores_tensor,
1368 score_threshold,
1369 iou_threshold,
1370 nms,
1371 pre_nms_top_k,
1372 max_det,
1373 );
1374 maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1375
1376 extract_proto_data_float(boxes, mask_tensor, protos, output_boxes)
1377}
1378
1379#[allow(clippy::too_many_arguments)]
1382pub(crate) fn impl_yolo_split_segdet_float_proto<
1383 B: BBoxTypeTrait,
1384 BOX: Float + AsPrimitive<f32> + Send + Sync,
1385 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1386 MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1387 PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1388>(
1389 boxes_tensor: ArrayView2<BOX>,
1390 scores_tensor: ArrayView2<SCORE>,
1391 mask_tensor: ArrayView2<MASK>,
1392 protos: ArrayView3<PROTO>,
1393 score_threshold: f32,
1394 iou_threshold: f32,
1395 nms: Option<Nms>,
1396 pre_nms_top_k: usize,
1397 max_det: usize,
1398 output_boxes: &mut Vec<DetectBox>,
1399) -> ProtoData
1400where
1401 f32: AsPrimitive<SCORE>,
1402{
1403 let (boxes_tensor, scores_tensor, mask_tensor) =
1404 postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1405 let det_indices = impl_yolo_segdet_get_boxes::<B, _, _>(
1406 boxes_tensor,
1407 scores_tensor,
1408 score_threshold,
1409 iou_threshold,
1410 nms,
1411 pre_nms_top_k,
1412 max_det,
1413 );
1414
1415 extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1416}
1417
1418pub(crate) fn decode_yolo_end_to_end_segdet_float_proto<T>(
1420 output: ArrayView2<T>,
1421 protos: ArrayView3<T>,
1422 score_threshold: f32,
1423 output_boxes: &mut Vec<DetectBox>,
1424) -> Result<ProtoData, crate::DecoderError>
1425where
1426 T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1427 f32: AsPrimitive<T>,
1428{
1429 let (boxes, scores, classes, mask_coeff) =
1430 postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
1431 let cap = cap_or_default(output_boxes);
1432 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1433 boxes,
1434 scores,
1435 classes,
1436 score_threshold,
1437 cap,
1438 );
1439
1440 Ok(extract_proto_data_float(
1441 boxes,
1442 mask_coeff,
1443 protos,
1444 output_boxes,
1445 ))
1446}
1447
1448#[allow(clippy::too_many_arguments)]
1450pub(crate) fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1451 boxes: ArrayView2<T>,
1452 scores: ArrayView2<T>,
1453 classes: ArrayView2<T>,
1454 mask_coeff: ArrayView2<T>,
1455 protos: ArrayView3<T>,
1456 score_threshold: f32,
1457 output_boxes: &mut Vec<DetectBox>,
1458) -> Result<ProtoData, crate::DecoderError>
1459where
1460 T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1461 f32: AsPrimitive<T>,
1462{
1463 let (boxes, scores, classes, mask_coeff) =
1464 postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
1465 let cap = cap_or_default(output_boxes);
1466 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1467 boxes,
1468 scores,
1469 classes,
1470 score_threshold,
1471 cap,
1472 );
1473
1474 Ok(extract_proto_data_float(
1475 boxes,
1476 mask_coeff,
1477 protos,
1478 output_boxes,
1479 ))
1480}
1481
1482pub(super) fn extract_proto_data_float<
1489 MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1490 PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1491>(
1492 det_indices: Vec<(DetectBox, usize)>,
1493 mask_tensor: ArrayView2<MASK>,
1494 protos: ArrayView3<PROTO>,
1495 output_boxes: &mut Vec<DetectBox>,
1496) -> ProtoData {
1497 let _span = tracing::trace_span!(
1498 "extract_proto",
1499 n = det_indices.len(),
1500 num_protos = mask_tensor.ncols(),
1501 layout = "nhwc",
1502 )
1503 .entered();
1504
1505 let num_protos = mask_tensor.ncols();
1506 let n = det_indices.len();
1507
1508 let mut coeff_rows: Vec<MASK> = Vec::with_capacity(n * num_protos);
1513 output_boxes.clear();
1514 for (det, idx) in det_indices {
1515 output_boxes.push(det);
1516 let row = mask_tensor.row(idx);
1517 coeff_rows.extend(row.iter().copied());
1518 }
1519
1520 let mask_coefficients = MASK::slice_into_tensor_dyn(&coeff_rows, &[n, num_protos])
1521 .expect("allocating mask_coefficients TensorDyn");
1522 let protos_tensor =
1523 PROTO::arrayview3_into_tensor_dyn(protos).expect("allocating protos TensorDyn");
1524
1525 ProtoData {
1526 mask_coefficients,
1527 protos: protos_tensor,
1528 layout: ProtoLayout::Nhwc,
1529 }
1530}
1531
/// Packages the selected detections, their mask coefficients and the
/// prototype tensor into a [`ProtoData`] for quantized (integer) models.
///
/// `det_indices` pairs every detection with the row of `mask_tensor` holding
/// its coefficients. `output_boxes` is cleared and refilled with the
/// detections in order.
///
/// Mask coefficients are emitted according to the concrete `MASK` type:
/// * `i8`  – kept as `i8`, tagged with `quant_masks` as per-tensor quantization;
/// * `i16` – kept as `i16`, tagged with `quant_masks`;
/// * anything else – dequantized to `f32` as `(v - zero_point) * scale`.
///
/// Prototypes are always emitted as an `i8` tensor tagged with
/// `quant_protos`. For `i8` input the data is copied verbatim and the layout
/// is reported as `Nhwc` or `Nchw` depending on the view's strides; for other
/// element types every value is cast to `i8` with `as_()` and the layout is
/// `Nhwc`.
///
/// # Panics
///
/// Panics if an index is out of range for `mask_tensor`, or if allocating,
/// mapping or tagging one of the tensors fails.
pub(crate) fn extract_proto_data_quant<
    MASK: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
    PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
>(
    det_indices: Vec<(DetectBox, usize)>,
    mask_tensor: ArrayView2<MASK>,
    quant_masks: Quantization,
    protos: ArrayView3<PROTO>,
    quant_protos: Quantization,
    output_boxes: &mut Vec<DetectBox>,
) -> ProtoData {
    use edgefirst_tensor::{Tensor, TensorDyn, TensorMapTrait, TensorMemory, TensorTrait};

    // `num_protos` and `layout` are declared empty and recorded once known.
    let span = tracing::trace_span!(
        "extract_proto",
        n = det_indices.len(),
        num_protos = tracing::field::Empty,
        layout = tracing::field::Empty,
    );
    let _guard = span.enter();

    let num_protos = mask_tensor.ncols();
    let n = det_indices.len();
    span.record("num_protos", num_protos);

    // No detections: return correctly shaped tensors without gathering
    // coefficients and without copying any prototype data into the freshly
    // allocated protos tensor. The coefficient tensor is always `i8` with zero
    // rows here, regardless of `MASK`, and the span's `layout` field is not
    // recorded on this path.
    if n == 0 {
        output_boxes.clear();
        let (h, w, k) = protos.dim();

        // Same shape/layout decision as the main path below.
        let (proto_shape, proto_layout) = if std::any::TypeId::of::<PROTO>()
            == std::any::TypeId::of::<i8>()
        {
            if protos.is_standard_layout() {
                (&[h, w, k][..], ProtoLayout::Nhwc)
            } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
                (&[k, h, w][..], ProtoLayout::Nchw)
            } else {
                (&[h, w, k][..], ProtoLayout::Nhwc)
            }
        } else {
            (&[h, w, k][..], ProtoLayout::Nhwc)
        };

        let coeff_tensor = Tensor::<i8>::new(&[0, num_protos], Some(TensorMemory::Mem), None)
            .expect("allocating empty mask_coefficients tensor");
        let coeff_quant =
            edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
        let coeff_tensor = coeff_tensor
            .with_quantization(coeff_quant)
            .expect("per-tensor quantization on mask coefficients");
        let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
            .expect("allocating protos tensor");
        let tensor_quant =
            edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
        let protos_tensor = protos_tensor
            .with_quantization(tensor_quant)
            .expect("per-tensor quantization on protos tensor");
        return ProtoData {
            mask_coefficients: TensorDyn::I8(coeff_tensor),
            protos: TensorDyn::I8(protos_tensor),
            layout: proto_layout,
        };
    }

    // From here on `n > 0`, so the `if n > 0` guards below are always taken.
    // Each branch stages the gathered rows in a Vec and then copies them into
    // a freshly allocated tensor of the matching element type.
    let mask_coefficients: TensorDyn = if std::any::TypeId::of::<MASK>()
        == std::any::TypeId::of::<i8>()
    {
        // MASK is i8: `as_()` to i8 is a same-type conversion.
        let mut coeff_i8 = Vec::<i8>::with_capacity(n * num_protos);
        output_boxes.clear();
        for (det, idx) in det_indices {
            output_boxes.push(det);
            let row = mask_tensor.row(idx);
            coeff_i8.extend(row.iter().map(|v| {
                let v_i8: i8 = v.as_();
                v_i8
            }));
        }
        let coeff_tensor = Tensor::<i8>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
            .expect("allocating mask_coefficients tensor");
        if n > 0 {
            let mut m = coeff_tensor
                .map()
                .expect("mapping mask_coefficients tensor");
            m.as_mut_slice().copy_from_slice(&coeff_i8);
        }
        let coeff_quant =
            edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
        let coeff_tensor = coeff_tensor
            .with_quantization(coeff_quant)
            .expect("per-tensor quantization on mask coefficients");
        TensorDyn::I8(coeff_tensor)
    } else if std::any::TypeId::of::<MASK>() == std::any::TypeId::of::<i16>() {
        // MASK is i16: there is no `AsPrimitive<i16>` bound, so values take a
        // detour through f32 (which represents every i16 exactly).
        let mut coeff_i16 = Vec::<i16>::with_capacity(n * num_protos);
        output_boxes.clear();
        for (det, idx) in det_indices {
            output_boxes.push(det);
            let row = mask_tensor.row(idx);
            coeff_i16.extend(row.iter().map(|v| {
                let v_f32: f32 = v.as_();
                v_f32 as i16
            }));
        }
        let coeff_tensor = Tensor::<i16>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
            .expect("allocating mask_coefficients tensor");
        if n > 0 {
            let mut m = coeff_tensor
                .map()
                .expect("mapping mask_coefficients tensor");
            m.as_mut_slice().copy_from_slice(&coeff_i16);
        }
        let coeff_quant =
            edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
        let coeff_tensor = coeff_tensor
            .with_quantization(coeff_quant)
            .expect("per-tensor quantization on mask coefficients");
        TensorDyn::I16(coeff_tensor)
    } else {
        // Any other integer type: dequantize to f32; the resulting tensor
        // carries no quantization metadata.
        let scale = quant_masks.scale;
        let zp = quant_masks.zero_point as f32;
        let mut coeff_f32 = Vec::<f32>::with_capacity(n * num_protos);
        output_boxes.clear();
        for (det, idx) in det_indices {
            output_boxes.push(det);
            let row = mask_tensor.row(idx);
            coeff_f32.extend(row.iter().map(|v| {
                let v_f32: f32 = v.as_();
                (v_f32 - zp) * scale
            }));
        }
        let coeff_tensor = Tensor::<f32>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
            .expect("allocating mask_coefficients tensor");
        if n > 0 {
            let mut m = coeff_tensor
                .map()
                .expect("mapping mask_coefficients tensor");
            m.as_mut_slice().copy_from_slice(&coeff_f32);
        }
        TensorDyn::F32(coeff_tensor)
    };

    let (h, w, k) = protos.dim();

    // Decide the output shape and layout tag. Only i8 prototypes can be
    // reported as Nchw: that happens when the (h, w, k) view has strides
    // (w, 1, h*w), i.e. the underlying buffer is ordered (k, h, w).
    // `protos.ndim() == 3` always holds for a 3-D view.
    let (proto_shape, proto_layout) =
        if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
            if protos.is_standard_layout() {
                (&[h, w, k][..], ProtoLayout::Nhwc)
            } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
                (&[k, h, w][..], ProtoLayout::Nchw)
            } else {
                (&[h, w, k][..], ProtoLayout::Nhwc)
            }
        } else {
            (&[h, w, k][..], ProtoLayout::Nhwc)
        };

    let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
        .expect("allocating protos tensor");
    {
        let mut m = protos_tensor.map().expect("mapping protos tensor");
        let dst = m.as_mut_slice();
        if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
            if protos.is_standard_layout() {
                // SAFETY: the TypeId check above proves PROTO is i8, so the
                // pointer cast does not change the element type, and a
                // standard-layout view is contiguous, so `protos.len()`
                // elements are readable starting at `as_ptr()`.
                let src: &[i8] = unsafe {
                    std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len())
                };
                dst.copy_from_slice(src);
            } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
                let total = h * w * k;
                // SAFETY: PROTO is i8 (TypeId check above). With dims
                // (h, w, k) and strides (w, 1, h*w) the view addresses every
                // offset in 0..h*w*k exactly once, starting at `as_ptr()`, so
                // `total` contiguous elements are readable. The bytes are
                // copied in buffer order, matching the [k, h, w] shape chosen
                // above.
                let src: &[i8] =
                    unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, total) };
                dst.copy_from_slice(src);
            } else {
                // Arbitrary strides: copy element by element in logical
                // (h, w, k) order.
                for (d, s) in dst.iter_mut().zip(protos.iter()) {
                    let v_i8: i8 = s.as_();
                    *d = v_i8;
                }
            }
        } else {
            // Non-i8 prototypes are cast to i8 element by element.
            // NOTE(review): `as_()` is a primitive cast, so values outside the
            // i8 range would not be preserved — confirm callers only pass
            // types/values for which this is intended.
            for (d, s) in dst.iter_mut().zip(protos.iter()) {
                let v_i8: i8 = s.as_();
                *d = v_i8;
            }
        }
    }
    let tensor_quant =
        edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
    let protos_tensor = protos_tensor
        .with_quantization(tensor_quant)
        .expect("per-tensor quantization on new Tensor<i8>");

    span.record("layout", tracing::field::debug(&proto_layout));

    ProtoData {
        mask_coefficients,
        protos: TensorDyn::I8(protos_tensor),
        layout: proto_layout,
    }
}
1768
/// Floating-point element types that can be packed into an
/// `edgefirst_tensor::TensorDyn`.
///
/// Used by `extract_proto_data_float` to build the mask-coefficient and
/// prototype tensors without knowing the concrete element type. The
/// implementations in this file cover `f32`, `half::f16` and `f64` (the
/// latter is narrowed to `f32`).
pub trait FloatProtoElem: Copy + 'static {
    /// Builds a tensor of the given `shape` from the flat slice `values`.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the underlying tensor constructor.
    fn slice_into_tensor_dyn(
        values: &[Self],
        shape: &[usize],
    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;

    /// Builds a tensor from a three-dimensional array view.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the underlying tensor constructor.
    fn arrayview3_into_tensor_dyn(
        view: ArrayView3<'_, Self>,
    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;
}
1784
1785impl FloatProtoElem for f32 {
1786 fn slice_into_tensor_dyn(
1787 values: &[f32],
1788 shape: &[usize],
1789 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1790 edgefirst_tensor::Tensor::<f32>::from_slice(values, shape)
1791 .map(edgefirst_tensor::TensorDyn::F32)
1792 }
1793 fn arrayview3_into_tensor_dyn(
1794 view: ArrayView3<'_, f32>,
1795 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1796 edgefirst_tensor::Tensor::<f32>::from_arrayview3(view).map(edgefirst_tensor::TensorDyn::F32)
1797 }
1798}
1799
1800impl FloatProtoElem for half::f16 {
1801 fn slice_into_tensor_dyn(
1802 values: &[half::f16],
1803 shape: &[usize],
1804 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1805 edgefirst_tensor::Tensor::<half::f16>::from_slice(values, shape)
1806 .map(edgefirst_tensor::TensorDyn::F16)
1807 }
1808 fn arrayview3_into_tensor_dyn(
1809 view: ArrayView3<'_, half::f16>,
1810 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1811 edgefirst_tensor::Tensor::<half::f16>::from_arrayview3(view)
1812 .map(edgefirst_tensor::TensorDyn::F16)
1813 }
1814}
1815
1816impl FloatProtoElem for f64 {
1817 fn slice_into_tensor_dyn(
1818 values: &[f64],
1819 shape: &[usize],
1820 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1821 let narrowed: Vec<f32> = values.iter().map(|&v| v as f32).collect();
1823 edgefirst_tensor::Tensor::<f32>::from_slice(&narrowed, shape)
1824 .map(edgefirst_tensor::TensorDyn::F32)
1825 }
1826 fn arrayview3_into_tensor_dyn(
1827 view: ArrayView3<'_, f64>,
1828 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1829 let narrowed: ndarray::Array3<f32> = view.mapv(|v| v as f32);
1830 edgefirst_tensor::Tensor::<f32>::from_arrayview3(narrowed.view())
1831 .map(edgefirst_tensor::TensorDyn::F32)
1832 }
1833}
1834
1835fn postprocess_yolo<'a, T>(
1836 output: &'a ArrayView2<'_, T>,
1837) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1838 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1839 let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1840 (boxes_tensor, scores_tensor)
1841}
1842
1843pub(crate) fn postprocess_yolo_seg<'a, T>(
1844 output: &'a ArrayView2<'_, T>,
1845 num_protos: usize,
1846) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1847 assert!(
1848 output.shape()[0] > num_protos + 4,
1849 "Output shape is too short: {} <= {} + 4",
1850 output.shape()[0],
1851 num_protos
1852 );
1853 let num_classes = output.shape()[0] - 4 - num_protos;
1854 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1855 let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1856 let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1857 (boxes_tensor, scores_tensor, mask_tensor)
1858}
1859
1860pub(crate) fn postprocess_yolo_split_segdet<'a, 'b, 'c, BOX, SCORE, MASK>(
1861 boxes_tensor: ArrayView2<'a, BOX>,
1862 scores_tensor: ArrayView2<'b, SCORE>,
1863 mask_tensor: ArrayView2<'c, MASK>,
1864) -> (
1865 ArrayView2<'a, BOX>,
1866 ArrayView2<'b, SCORE>,
1867 ArrayView2<'c, MASK>,
1868) {
1869 let boxes_tensor = boxes_tensor.reversed_axes();
1870 let scores_tensor = scores_tensor.reversed_axes();
1871 let mask_tensor = mask_tensor.reversed_axes();
1872 (boxes_tensor, scores_tensor, mask_tensor)
1873}
1874
1875fn decode_segdet_f32<
1876 MASK: Float + AsPrimitive<f32> + Send + Sync,
1877 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1878>(
1879 boxes: Vec<(DetectBox, usize)>,
1880 masks: ArrayView2<MASK>,
1881 protos: ArrayView3<PROTO>,
1882) -> Result<Vec<(DetectBox, BoundingBox, Array3<u8>)>, crate::DecoderError> {
1883 if boxes.is_empty() {
1884 return Ok(Vec::new());
1885 }
1886 if masks.shape()[1] != protos.shape()[2] {
1887 return Err(crate::DecoderError::InvalidShape(format!(
1888 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1889 masks.shape()[1],
1890 protos.shape()[2],
1891 )));
1892 }
1893 boxes
1894 .into_par_iter()
1895 .map(|b| {
1896 let ind = b.1;
1897 let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1902 Ok((b.0, roi, make_segmentation(masks.row(ind), protos.view())))
1903 })
1904 .collect()
1905}
1906
1907pub(crate) fn decode_segdet_quant<
1908 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1909 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1910>(
1911 boxes: Vec<(DetectBox, usize)>,
1912 masks: ArrayView2<MASK>,
1913 protos: ArrayView3<PROTO>,
1914 quant_masks: Quantization,
1915 quant_protos: Quantization,
1916) -> Result<Vec<(DetectBox, BoundingBox, Array3<u8>)>, crate::DecoderError> {
1917 if boxes.is_empty() {
1918 return Ok(Vec::new());
1919 }
1920 if masks.shape()[1] != protos.shape()[2] {
1921 return Err(crate::DecoderError::InvalidShape(format!(
1922 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1923 masks.shape()[1],
1924 protos.shape()[2],
1925 )));
1926 }
1927
1928 let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; boxes
1930 .into_iter()
1931 .map(|b| {
1932 let i = b.1;
1933 let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1937 let seg = match total_bits {
1938 0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1939 masks.row(i),
1940 protos.view(),
1941 quant_masks,
1942 quant_protos,
1943 ),
1944 65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1945 masks.row(i),
1946 protos.view(),
1947 quant_masks,
1948 quant_protos,
1949 ),
1950 _ => {
1951 return Err(crate::DecoderError::NotSupported(format!(
1952 "Unsupported bit width ({total_bits}) for segmentation computation"
1953 )));
1954 }
1955 };
1956 Ok((b.0, roi, seg))
1957 })
1958 .collect()
1959}
1960
1961fn protobox<'a, T>(
1962 protos: &'a ArrayView3<T>,
1963 roi: &BoundingBox,
1964) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1965 let width = protos.dim().1 as f32;
1966 let height = protos.dim().0 as f32;
1967
1968 const NORM_LIMIT: f32 = 2.0;
1980 if roi.xmin > NORM_LIMIT
1981 || roi.ymin > NORM_LIMIT
1982 || roi.xmax > NORM_LIMIT
1983 || roi.ymax > NORM_LIMIT
1984 {
1985 return Err(crate::DecoderError::InvalidShape(format!(
1986 "Bounding box coordinates appear un-normalized (pixel-space). \
1987 Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
1988 Two ways to fix this: \
1989 (1) declare `Detection::normalized = false` in the model schema \
1990 AND make sure the schema's `input.shape` / `input.dshape` carries \
1991 the model input dims so the decoder can divide by (W, H) before NMS \
1992 (EDGEAI-1303 — verify with `Decoder::input_dims().is_some()`); or \
1993 (2) normalize the boxes in-graph before decode().",
1994 roi.xmin, roi.ymin, roi.xmax, roi.ymax,
1995 )));
1996 }
1997
1998 let roi = [
1999 (roi.xmin * width).clamp(0.0, width) as usize,
2000 (roi.ymin * height).clamp(0.0, height) as usize,
2001 (roi.xmax * width).clamp(0.0, width).ceil() as usize,
2002 (roi.ymax * height).clamp(0.0, height).ceil() as usize,
2003 ];
2004
2005 let roi_norm = [
2006 roi[0] as f32 / width,
2007 roi[1] as f32 / height,
2008 roi[2] as f32 / width,
2009 roi[3] as f32 / height,
2010 ]
2011 .into();
2012
2013 let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
2014
2015 Ok((cropped, roi_norm))
2016}
2017
2018fn make_segmentation<
2024 MASK: Float + AsPrimitive<f32> + Send + Sync,
2025 PROTO: Float + AsPrimitive<f32> + Send + Sync,
2026>(
2027 mask: ArrayView1<MASK>,
2028 protos: ArrayView3<PROTO>,
2029) -> Array3<u8> {
2030 let shape = protos.shape();
2031
2032 let mask = mask.to_shape((1, mask.len())).unwrap();
2034 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
2035 let protos = protos.reversed_axes();
2036 let mask = mask.map(|x| x.as_());
2037 let protos = protos.map(|x| x.as_());
2038
2039 let mask = mask
2041 .dot(&protos)
2042 .into_shape_with_order((shape[0], shape[1], 1))
2043 .unwrap();
2044
2045 mask.map(|x| {
2046 let sigmoid = 1.0 / (1.0 + (-*x).exp());
2047 (sigmoid * 255.0).round() as u8
2048 })
2049}
2050
2051fn make_segmentation_quant<
2058 MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
2059 PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
2060 DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
2061>(
2062 mask: ArrayView1<MASK>,
2063 protos: ArrayView3<PROTO>,
2064 quant_masks: Quantization,
2065 quant_protos: Quantization,
2066) -> Array3<u8>
2067where
2068 i32: AsPrimitive<DEST>,
2069 f32: AsPrimitive<DEST>,
2070{
2071 let shape = protos.shape();
2072
2073 let mask = mask.to_shape((1, mask.len())).unwrap();
2075
2076 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
2077 let protos = protos.reversed_axes();
2078
2079 let zp = quant_masks.zero_point.as_();
2080
2081 let mask = mask.mapv(|x| x.as_() - zp);
2082
2083 let zp = quant_protos.zero_point.as_();
2084 let protos = protos.mapv(|x| x.as_() - zp);
2085
2086 let segmentation = mask
2088 .dot(&protos)
2089 .into_shape_with_order((shape[0], shape[1], 1))
2090 .unwrap();
2091
2092 let combined_scale = quant_masks.scale * quant_protos.scale;
2093 segmentation.map(|x| {
2094 let val: f32 = (*x).as_() * combined_scale;
2095 let sigmoid = 1.0 / (1.0 + (-val).exp());
2096 (sigmoid * 255.0).round() as u8
2097 })
2098}
2099
2100pub(crate) fn yolo_segmentation_to_mask(
2112 segmentation: ArrayView3<u8>,
2113 threshold: u8,
2114) -> Result<Array2<u8>, crate::DecoderError> {
2115 if segmentation.shape()[2] != 1 {
2116 return Err(crate::DecoderError::InvalidShape(format!(
2117 "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
2118 segmentation.shape()[2]
2119 )));
2120 }
2121 Ok(segmentation
2122 .slice(s![.., .., 0])
2123 .map(|x| if *x >= threshold { 1 } else { 0 }))
2124}
2125
2126#[cfg(test)]
2127#[cfg_attr(coverage_nightly, coverage(off))]
2128mod tests {
2129 use super::*;
2130 use ndarray::Array2;
2131
2132 #[test]
2137 fn test_end_to_end_det_basic_filtering() {
2138 let data: Vec<f32> = vec![
2142 0.1, 0.2, 0.3, 0.1, 0.2, 0.3, 0.5, 0.6, 0.7, 0.5, 0.6, 0.7, 0.9, 0.1, 0.2, 0.0, 1.0, 2.0, ];
2150 let output = Array2::from_shape_vec((6, 3), data).unwrap();
2151
2152 let mut boxes = Vec::with_capacity(10);
2153 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2154
2155 assert_eq!(boxes.len(), 1);
2157 assert_eq!(boxes[0].label, 0);
2158 assert!((boxes[0].score - 0.9).abs() < 0.01);
2159 assert!((boxes[0].bbox.xmin - 0.1).abs() < 0.01);
2160 assert!((boxes[0].bbox.ymin - 0.1).abs() < 0.01);
2161 assert!((boxes[0].bbox.xmax - 0.5).abs() < 0.01);
2162 assert!((boxes[0].bbox.ymax - 0.5).abs() < 0.01);
2163 }
2164
2165 #[test]
2166 fn test_end_to_end_det_all_pass_threshold() {
2167 let data: Vec<f32> = vec![
2169 10.0, 20.0, 10.0, 20.0, 50.0, 60.0, 50.0, 60.0, 0.8, 0.7, 1.0, 2.0, ];
2176 let output = Array2::from_shape_vec((6, 2), data).unwrap();
2177
2178 let mut boxes = Vec::with_capacity(10);
2179 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2180
2181 assert_eq!(boxes.len(), 2);
2182 assert_eq!(boxes[0].label, 1);
2183 assert_eq!(boxes[1].label, 2);
2184 }
2185
2186 #[test]
2187 fn test_end_to_end_det_none_pass_threshold() {
2188 let data: Vec<f32> = vec![
2190 10.0, 20.0, 10.0, 20.0, 50.0, 60.0, 50.0, 60.0, 0.1, 0.2, 1.0, 2.0, ];
2197 let output = Array2::from_shape_vec((6, 2), data).unwrap();
2198
2199 let mut boxes = Vec::with_capacity(10);
2200 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2201
2202 assert_eq!(boxes.len(), 0);
2203 }
2204
2205 #[test]
2206 fn test_end_to_end_det_capacity_limit() {
2207 let data: Vec<f32> = vec![
2209 0.1, 0.2, 0.3, 0.4, 0.5, 0.1, 0.2, 0.3, 0.4, 0.5, 0.5, 0.6, 0.7, 0.8, 0.9, 0.5, 0.6, 0.7, 0.8, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.0, 1.0, 2.0, 3.0, 4.0, ];
2216 let output = Array2::from_shape_vec((6, 5), data).unwrap();
2217
2218 let mut boxes = Vec::with_capacity(2); decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2220
2221 assert_eq!(boxes.len(), 2);
2222 }
2223
2224 #[test]
2225 fn test_end_to_end_det_empty_output() {
2226 let output = Array2::<f32>::zeros((6, 0));
2228
2229 let mut boxes = Vec::with_capacity(10);
2230 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2231
2232 assert_eq!(boxes.len(), 0);
2233 }
2234
2235 #[test]
2236 fn test_end_to_end_det_pixel_coordinates() {
2237 let data: Vec<f32> = vec![
2239 100.0, 200.0, 300.0, 400.0, 0.95, 5.0, ];
2246 let output = Array2::from_shape_vec((6, 1), data).unwrap();
2247
2248 let mut boxes = Vec::with_capacity(10);
2249 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2250
2251 assert_eq!(boxes.len(), 1);
2252 assert_eq!(boxes[0].label, 5);
2253 assert!((boxes[0].bbox.xmin - 100.0).abs() < 0.01);
2254 assert!((boxes[0].bbox.ymin - 200.0).abs() < 0.01);
2255 assert!((boxes[0].bbox.xmax - 300.0).abs() < 0.01);
2256 assert!((boxes[0].bbox.ymax - 400.0).abs() < 0.01);
2257 }
2258
2259 #[test]
2260 fn test_end_to_end_det_invalid_shape() {
2261 let output = Array2::<f32>::zeros((5, 3));
2263
2264 let mut boxes = Vec::with_capacity(10);
2265 let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);
2266
2267 assert!(result.is_err());
2268 assert!(matches!(
2269 result,
2270 Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
2271 ));
2272 }
2273
2274 #[test]
2279 fn test_end_to_end_segdet_basic() {
2280 let num_protos = 32;
2283 let num_detections = 2;
2284 let num_features = 6 + num_protos;
2285
2286 let mut data = vec![0.0f32; num_features * num_detections];
2288 data[0] = 0.1; data[1] = 0.5; data[num_detections] = 0.1; data[num_detections + 1] = 0.5; data[2 * num_detections] = 0.4; data[2 * num_detections + 1] = 0.9; data[3 * num_detections] = 0.4; data[3 * num_detections + 1] = 0.9; data[4 * num_detections] = 0.9; data[4 * num_detections + 1] = 0.3; data[5 * num_detections] = 1.0; data[5 * num_detections + 1] = 2.0; for i in 6..num_features {
2303 data[i * num_detections] = 0.1;
2304 data[i * num_detections + 1] = 0.1;
2305 }
2306
2307 let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2308
2309 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2311
2312 let mut boxes = Vec::with_capacity(10);
2313 let mut masks = Vec::with_capacity(10);
2314 decode_yolo_end_to_end_segdet_float(
2315 output.view(),
2316 protos.view(),
2317 0.5,
2318 &mut boxes,
2319 &mut masks,
2320 )
2321 .unwrap();
2322
2323 assert_eq!(boxes.len(), 1);
2325 assert_eq!(masks.len(), 1);
2326 assert_eq!(boxes[0].label, 1);
2327 assert!((boxes[0].score - 0.9).abs() < 0.01);
2328 }
2329
2330 #[test]
2331 fn test_end_to_end_segdet_mask_coordinates() {
2332 let num_protos = 32;
2334 let num_features = 6 + num_protos;
2335
2336 let mut data = vec![0.0f32; num_features];
2337 data[0] = 0.2; data[1] = 0.2; data[2] = 0.8; data[3] = 0.8; data[4] = 0.95; data[5] = 3.0; let output = Array2::from_shape_vec((num_features, 1), data).unwrap();
2345 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2346
2347 let mut boxes = Vec::with_capacity(10);
2348 let mut masks = Vec::with_capacity(10);
2349 decode_yolo_end_to_end_segdet_float(
2350 output.view(),
2351 protos.view(),
2352 0.5,
2353 &mut boxes,
2354 &mut masks,
2355 )
2356 .unwrap();
2357
2358 assert_eq!(boxes.len(), 1);
2359 assert_eq!(masks.len(), 1);
2360
2361 let step = 1.0 / 16.0;
2365 assert!(masks[0].xmin <= boxes[0].bbox.xmin);
2366 assert!(masks[0].ymin <= boxes[0].bbox.ymin);
2367 assert!(masks[0].xmax >= boxes[0].bbox.xmax);
2368 assert!(masks[0].ymax >= boxes[0].bbox.ymax);
2369 assert!((boxes[0].bbox.xmin - masks[0].xmin) < step);
2370 assert!((boxes[0].bbox.ymin - masks[0].ymin) < step);
2371 assert!((masks[0].xmax - boxes[0].bbox.xmax) < step);
2372 assert!((masks[0].ymax - boxes[0].bbox.ymax) < step);
2373 }
2374
2375 #[test]
2376 fn test_end_to_end_segdet_empty_output() {
2377 let num_protos = 32;
2378 let output = Array2::<f32>::zeros((6 + num_protos, 0));
2379 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2380
2381 let mut boxes = Vec::with_capacity(10);
2382 let mut masks = Vec::with_capacity(10);
2383 decode_yolo_end_to_end_segdet_float(
2384 output.view(),
2385 protos.view(),
2386 0.5,
2387 &mut boxes,
2388 &mut masks,
2389 )
2390 .unwrap();
2391
2392 assert_eq!(boxes.len(), 0);
2393 assert_eq!(masks.len(), 0);
2394 }
2395
2396 #[test]
2397 fn test_end_to_end_segdet_capacity_limit() {
2398 let num_protos = 32;
2399 let num_detections = 5;
2400 let num_features = 6 + num_protos;
2401
2402 let mut data = vec![0.0f32; num_features * num_detections];
2403 for i in 0..num_detections {
2405 data[i] = 0.1 * (i as f32); data[num_detections + i] = 0.1 * (i as f32); data[2 * num_detections + i] = 0.1 * (i as f32) + 0.2; data[3 * num_detections + i] = 0.1 * (i as f32) + 0.2; data[4 * num_detections + i] = 0.9; data[5 * num_detections + i] = i as f32; }
2412
2413 let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2414 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2415
2416 let mut boxes = Vec::with_capacity(2); let mut masks = Vec::with_capacity(2);
2418 decode_yolo_end_to_end_segdet_float(
2419 output.view(),
2420 protos.view(),
2421 0.5,
2422 &mut boxes,
2423 &mut masks,
2424 )
2425 .unwrap();
2426
2427 assert_eq!(boxes.len(), 2);
2428 assert_eq!(masks.len(), 2);
2429 }
2430
2431 #[test]
2432 fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
2433 let output = Array2::<f32>::zeros((6, 3));
2435 let protos = Array3::<f32>::zeros((16, 16, 32));
2436
2437 let mut boxes = Vec::with_capacity(10);
2438 let mut masks = Vec::with_capacity(10);
2439 let result = decode_yolo_end_to_end_segdet_float(
2440 output.view(),
2441 protos.view(),
2442 0.5,
2443 &mut boxes,
2444 &mut masks,
2445 );
2446
2447 assert!(result.is_err());
2448 assert!(matches!(
2449 result,
2450 Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
2451 ));
2452 }
2453
2454 #[test]
2455 fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
2456 let num_protos = 32;
2458 let output = Array2::<f32>::zeros((6 + 16, 3)); let protos = Array3::<f32>::zeros((16, 16, num_protos)); let mut boxes = Vec::with_capacity(10);
2462 let mut masks = Vec::with_capacity(10);
2463 let result = decode_yolo_end_to_end_segdet_float(
2464 output.view(),
2465 protos.view(),
2466 0.5,
2467 &mut boxes,
2468 &mut masks,
2469 );
2470
2471 assert!(result.is_err());
2472 assert!(matches!(
2473 result,
2474 Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
2475 ));
2476 }
2477
2478 #[test]
2483 fn test_split_end_to_end_segdet_basic() {
2484 let num_protos = 32;
2487 let num_detections = 2;
2488 let num_features = 6 + num_protos;
2489
2490 let mut data = vec![0.0f32; num_features * num_detections];
2492 data[0] = 0.1; data[1] = 0.5; data[num_detections] = 0.1; data[num_detections + 1] = 0.5; data[2 * num_detections] = 0.4; data[2 * num_detections + 1] = 0.9; data[3 * num_detections] = 0.4; data[3 * num_detections + 1] = 0.9; data[4 * num_detections] = 0.9; data[4 * num_detections + 1] = 0.3; data[5 * num_detections] = 1.0; data[5 * num_detections + 1] = 2.0; for i in 6..num_features {
2507 data[i * num_detections] = 0.1;
2508 data[i * num_detections + 1] = 0.1;
2509 }
2510
2511 let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2512 let box_coords = output.slice(s![..4, ..]);
2513 let scores = output.slice(s![4..5, ..]);
2514 let classes = output.slice(s![5..6, ..]);
2515 let mask_coeff = output.slice(s![6.., ..]);
2516 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2518
2519 let mut boxes = Vec::with_capacity(10);
2520 let mut masks = Vec::with_capacity(10);
2521 decode_yolo_split_end_to_end_segdet_float(
2522 box_coords,
2523 scores,
2524 classes,
2525 mask_coeff,
2526 protos.view(),
2527 0.5,
2528 &mut boxes,
2529 &mut masks,
2530 )
2531 .unwrap();
2532
2533 assert_eq!(boxes.len(), 1);
2535 assert_eq!(masks.len(), 1);
2536 assert_eq!(boxes[0].label, 1);
2537 assert!((boxes[0].score - 0.9).abs() < 0.01);
2538 }
2539
/// Thresholds a 4x4 single-channel segmentation at 128 and spot-checks eight
/// pixels, including the boundary values 127 (expected 0) and 128 (expected 1).
#[test]
fn test_segmentation_to_mask_basic() {
    #[rustfmt::skip]
    let data: Vec<u8> = vec![
        100, 200,  50, 150,
         10, 255, 128,  64,
          0, 127, 128, 255,
         64,  64, 192, 192,
    ];
    let segmentation = Array3::from_shape_vec((4, 4, 1), data).unwrap();

    let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();

    // (row, column, expected mask value)
    let expectations = [
        (0usize, 0usize, 0),
        (0, 1, 1),
        (0, 2, 0),
        (0, 3, 1),
        (1, 1, 1),
        (1, 2, 1),
        (2, 0, 0),
        (2, 1, 0),
    ];
    for &(row, col, want) in expectations.iter() {
        assert_eq!(mask[[row, col]], want, "mask[[{}, {}]]", row, col);
    }
}
2567
/// A segmentation filled with 255 and thresholded at 128 must yield a mask
/// made only of ones.
#[test]
fn test_segmentation_to_mask_all_above() {
    let saturated = Array3::from_elem((4, 4, 1), 255u8);
    let mask = yolo_segmentation_to_mask(saturated.view(), 128).unwrap();
    let has_non_one = mask.iter().any(|&px| px != 1);
    assert!(!has_non_one);
}
2574
/// A segmentation filled with 64 and thresholded at 128 must yield a mask
/// made only of zeros.
#[test]
fn test_segmentation_to_mask_all_below() {
    let dim = Array3::from_elem((4, 4, 1), 64u8);
    let mask = yolo_segmentation_to_mask(dim.view(), 128).unwrap();
    let has_non_zero = mask.iter().any(|&px| px != 0);
    assert!(!has_non_zero);
}
2581
/// A segmentation with three channels must be rejected with an
/// `InvalidShape` error whose message mentions the expected `(H, W, 1)` shape.
#[test]
fn test_segmentation_to_mask_invalid_shape() {
    let three_channel = Array3::from_elem((4, 4, 3), 128u8);
    let outcome = yolo_segmentation_to_mask(three_channel.view(), 128);

    assert!(outcome.is_err());

    let is_shape_error = matches!(
        outcome,
        Err(crate::DecoderError::InvalidShape(msg)) if msg.contains("(H, W, 1)")
    );
    assert!(is_shape_error);
}
2593
/// An ROI whose max edges sit exactly at 1.0 must be accepted by `protobox`
/// and produce a non-empty crop that keeps all four proto channels.
#[test]
fn test_protobox_clamps_edge_coordinates() {
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let view = protos.view();
    let (lo, hi) = (0.5, 1.0);
    let roi = BoundingBox {
        xmin: lo,
        ymin: lo,
        xmax: hi,
        ymax: hi,
    };

    let result = protobox(&view, &roi);
    assert!(result.is_ok(), "protobox should accept xmax=1.0");

    let (cropped, _roi_norm) = result.unwrap();
    let dims = cropped.shape();
    assert!(dims[0] > 0);
    assert!(dims[1] > 0);
    assert_eq!(dims[2], 4);
}
2617
/// An ROI reaching 3.0 must be rejected by `protobox` with an `InvalidShape`
/// error whose message contains "un-normalized".
#[test]
fn test_protobox_rejects_wildly_out_of_range() {
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let view = protos.view();
    let far_edge = 3.0;
    let roi = BoundingBox {
        xmin: 0.0,
        ymin: 0.0,
        xmax: far_edge,
        ymax: far_edge,
    };

    let result = protobox(&view, &roi);
    let rejected_as_unnormalized = matches!(
        result,
        Err(crate::DecoderError::InvalidShape(s)) if s.contains("un-normalized")
    );
    assert!(
        rejected_as_unnormalized,
        "protobox should reject coords > NORM_LIMIT"
    );
}
2635
/// An ROI reaching 1.5 must still be accepted by `protobox`, and the crop
/// must span the full 16x16 proto plane.
#[test]
fn test_protobox_accepts_slightly_over_one() {
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let view = protos.view();
    let overshoot = 1.5;
    let roi = BoundingBox {
        xmin: 0.0,
        ymin: 0.0,
        xmax: overshoot,
        ymax: overshoot,
    };

    let result = protobox(&view, &roi);
    assert!(
        result.is_ok(),
        "protobox should accept coords <= NORM_LIMIT (2.0)"
    );

    let (cropped, _roi_norm) = result.unwrap();
    let dims = cropped.shape();
    assert_eq!(dims[0], 16);
    assert_eq!(dims[1], 16);
}
2657
/// Runs `impl_yolo_segdet_float_proto` on 100 identical proposals and checks
/// that it returns at least one box and that the mask-coefficient matrix has
/// one row per output box and `num_mask_coeffs` columns.
#[test]
fn test_segdet_float_proto_no_panic() {
    let num_proposals = 100;
    let num_classes = 80;
    let num_mask_coeffs = 32;
    let rows = 4 + num_classes + num_mask_coeffs;

    // Tensor layout is (rows, num_proposals). The first four rows hold the
    // box values and row 4 holds 0.9 (presumably the first class score);
    // every other row stays at zero.
    let mut data = vec![0.0f32; rows * num_proposals];
    let filled_rows = [320.0f32, 320.0, 50.0, 50.0, 0.9];
    for (r, &value) in filled_rows.iter().enumerate() {
        data[r * num_proposals..(r + 1) * num_proposals].fill(value);
    }
    let boxes = ndarray::Array2::from_shape_vec((rows, num_proposals), data).unwrap();

    let protos = ndarray::Array3::<f32>::zeros((160, 160, num_mask_coeffs));

    let mut output_boxes = Vec::with_capacity(300);

    let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
        boxes.view(),
        protos.view(),
        0.5,
        0.7,
        Some(Nms::default()),
        MAX_NMS_CANDIDATES,
        300,
        None,
        None,
        &mut output_boxes,
    );

    assert!(!output_boxes.is_empty());

    let coeffs_shape = proto_data.mask_coefficients.shape();
    let (coeff_rows, coeff_cols) = (coeffs_shape[0], coeffs_shape[1]);
    assert_eq!(coeff_rows, output_boxes.len());
    assert_eq!(coeff_cols, num_mask_coeffs);
}
2710
/// Feeds 50 000 identical boxes with strictly decreasing scores through
/// `impl_yolo_segdet_get_boxes` and checks that the pre-NMS cap keeps exactly
/// `MAX_NMS_CANDIDATES` of them, and that the survivors are the
/// highest-scoring ones.
///
/// The IoU threshold is 1.0 and the detection limit is `usize::MAX`, so the
/// expected length can only come from the pre-NMS cap.
#[test]
fn test_pre_nms_cap_truncates_excess_candidates() {
    let n: usize = 50_000;
    let num_classes = 1;

    let mut boxes_data = Vec::with_capacity(n * 4);
    let mut scores_data = Vec::with_capacity(n * num_classes);
    for i in 0..n {
        boxes_data.extend_from_slice(&[0.1f32, 0.1, 0.5, 0.5]);
        // Scores fall from 0.99 in steps of 1e-7, so candidate `i` always
        // outranks candidate `i + 1`.
        scores_data.push(0.99 - (i as f32) * 1e-7);
    }
    let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
    let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();

    let result = impl_yolo_segdet_get_boxes::<XYXY, _, _>(
        boxes.view(),
        scores.view(),
        0.1,
        1.0,
        Some(Nms::ClassAgnostic),
        crate::yolo::MAX_NMS_CANDIDATES,
        usize::MAX,
    );

    assert_eq!(
        result.len(),
        crate::yolo::MAX_NMS_CANDIDATES,
        "pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
        result.len()
    );

    let top_score = result[0].0.score;
    assert!(
        top_score > 0.98,
        "highest-ranked survivor should have the largest score, got {top_score}"
    );

    // Every input score exceeds 0.98, so the check above cannot tell which
    // candidates were kept. The weakest survivor must be no worse than the
    // score of the last candidate inside the cap; the small tolerance absorbs
    // f32 rounding.
    let cutoff = 0.99 - ((crate::yolo::MAX_NMS_CANDIDATES - 1) as f32) * 1e-7;
    let weakest = result
        .iter()
        .map(|entry| entry.0.score)
        .fold(f32::INFINITY, f32::min);
    assert!(
        weakest >= cutoff - 1e-6,
        "pre-NMS cap should keep the highest-scoring candidates; weakest survivor {weakest} is below cutoff {cutoff}"
    );
}
2768
/// Quantized counterpart of the pre-NMS cap test: 50 000 identical `i8`
/// boxes with `u8` scores cycling from 250 down to 51 must be truncated to
/// exactly `MAX_NMS_CANDIDATES` by `impl_yolo_split_segdet_quant_get_boxes`.
#[test]
fn test_pre_nms_cap_truncates_excess_candidates_quant() {
    use crate::Quantization;
    let n: usize = 50_000;
    let num_classes = 1;

    // The same four-value box repeated for every candidate.
    let boxes_data: Vec<i8> = [10i8, 10, 50, 50].repeat(n);
    let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
    let quant_boxes = Quantization {
        scale: 0.01,
        zero_point: 0,
    };

    // Scores run 250, 249, ..., 51 and then repeat (period 200).
    let scores_data: Vec<u8> = (51..=250u8).rev().cycle().take(n).collect();
    let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();
    let quant_scores = Quantization {
        scale: 0.00392,
        zero_point: 0,
    };

    let result = impl_yolo_split_segdet_quant_get_boxes::<XYXY, _, _>(
        (boxes.view(), quant_boxes),
        (scores.view(), quant_scores),
        0.1,
        1.0,
        Some(Nms::ClassAgnostic),
        crate::yolo::MAX_NMS_CANDIDATES,
        usize::MAX,
    );

    let kept = result.len();
    assert_eq!(
        kept,
        crate::yolo::MAX_NMS_CANDIDATES,
        "quant path pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
        kept
    );
}
2818
/// Checks that `decode_yolo_segdet_float` pairs each surviving detection with
/// the mask built from that same anchor's coefficients.
///
/// Three anchors are encoded in a `(feat, n)` tensor. Anchors 0 and 2 score
/// above the 0.5 threshold (0.9 and 0.8), while anchor 1 has no score set.
/// Anchor 0 carries mask coefficients of +3.0 and anchor 2 of -3.0 against
/// all-ones prototypes, so a correctly paired mask is bright for the
/// detection on the left and dark for the one on the right.
#[test]
fn segdet_combined_tensor_pairs_detection_with_matching_mask_row() {
    let nc = 2;
    let nm = 2;
    let n = 3;
    let feat = 4 + nc + nm;
    let mut data = vec![0.0f32; feat * n];
    // Writes value `v` at feature row `r`, anchor column `c`.
    let set = |d: &mut [f32], r: usize, c: usize, v: f32| d[r * n + c] = v;

    // Rows 0..4: box values per anchor (presumably centre-x, centre-y,
    // width, height; the centres used below are 0.2, 0.5 and 0.8).
    set(&mut data, 0, 0, 0.2);
    set(&mut data, 1, 0, 0.2);
    set(&mut data, 2, 0, 0.1);
    set(&mut data, 3, 0, 0.1);
    set(&mut data, 0, 1, 0.5);
    set(&mut data, 1, 1, 0.5);
    set(&mut data, 2, 1, 0.1);
    set(&mut data, 3, 1, 0.1);
    set(&mut data, 0, 2, 0.8);
    set(&mut data, 1, 2, 0.8);
    set(&mut data, 2, 2, 0.1);
    set(&mut data, 3, 2, 0.1);

    // Row 4: score of the first class; only anchors 0 and 2 are set.
    set(&mut data, 4, 0, 0.9);
    set(&mut data, 4, 2, 0.8);

    // Rows 6..8: mask coefficients with opposite signs for anchors 0 and 2.
    set(&mut data, 6, 0, 3.0);
    set(&mut data, 7, 0, 3.0);
    set(&mut data, 6, 2, -3.0);
    set(&mut data, 7, 2, -3.0);

    let output = Array2::from_shape_vec((feat, n), data).unwrap();
    let protos = Array3::<f32>::from_elem((8, 8, nm), 1.0);

    let mut boxes: Vec<DetectBox> = Vec::with_capacity(10);
    let mut masks: Vec<Segmentation> = Vec::with_capacity(10);
    decode_yolo_segdet_float(
        output.view(),
        protos.view(),
        0.5,
        0.5,
        Some(Nms::ClassAgnostic),
        &mut boxes,
        &mut masks,
    )
    .unwrap();

    assert_eq!(
        boxes.len(),
        2,
        "two anchors above threshold should survive (a0 score=0.9, a2 score=0.8); got {}",
        boxes.len()
    );
    // `zip` below stops at the shorter side, so without this check a missing
    // mask would silently skip the pairing assertions.
    assert_eq!(
        masks.len(),
        boxes.len(),
        "every surviving detection should have a paired mask; got {} masks for {} boxes",
        masks.len(),
        boxes.len()
    );

    for (b, m) in boxes.iter().zip(masks.iter()) {
        let cx = (b.bbox.xmin + b.bbox.xmax) * 0.5;
        let mean = {
            let s = &m.segmentation;
            let total: u32 = s.iter().map(|&v| v as u32).sum();
            total as f32 / s.len() as f32
        };
        if cx < 0.3 {
            assert!(
                mean > 200.0,
                "anchor 0 detection (centre {cx:.2}) should have high-value mask; got mean {mean}"
            );
        } else if cx > 0.7 {
            assert!(
                mean < 50.0,
                "anchor 2 detection (centre {cx:.2}) should have low-value mask; got mean {mean}"
            );
        } else {
            panic!("unexpected detection centre {cx:.2}");
        }
    }
}
2929
2930 fn make_float_boxes(scores: &[f32]) -> Vec<(DetectBox, ())> {
2936 scores
2937 .iter()
2938 .enumerate()
2939 .map(|(i, &s)| {
2940 (
2941 DetectBox {
2942 bbox: BoundingBox {
2943 xmin: 0.0,
2944 ymin: 0.0,
2945 xmax: 1.0,
2946 ymax: 1.0,
2947 },
2948 score: s,
2949 label: i,
2950 },
2951 (),
2952 )
2953 })
2954 .collect()
2955 }
2956
2957 fn make_quant_boxes(scores: &[i8]) -> Vec<(DetectBoxQuantized<i8>, ())> {
2959 scores
2960 .iter()
2961 .enumerate()
2962 .map(|(i, &s)| {
2963 (
2964 DetectBoxQuantized {
2965 bbox: BoundingBox {
2966 xmin: 0.0,
2967 ymin: 0.0,
2968 xmax: 1.0,
2969 ymax: 1.0,
2970 },
2971 score: s,
2972 label: i,
2973 },
2974 (),
2975 )
2976 })
2977 .collect()
2978 }
2979
/// `top_k == 0` means "no limit": the float truncation helper must leave the
/// candidate list at its original length.
#[test]
fn truncate_float_top_k_zero_is_unbounded() {
    let mut candidates = make_float_boxes(&[0.9, 0.1, 0.5, 0.3, 0.7]);
    let len_before = candidates.len();

    truncate_to_top_k_by_score(&mut candidates, 0);

    let len_after = candidates.len();
    assert_eq!(
        len_after, len_before,
        "top_k=0 should keep all candidates (no-limit semantics)"
    );
}
2991
/// Truncating five float candidates to `top_k = 3` must keep the three
/// highest scores. The survivors are sorted before comparing, so their order
/// after truncation does not matter.
#[test]
fn truncate_float_top_k_normal() {
    let mut candidates = make_float_boxes(&[0.9, 0.1, 0.5, 0.3, 0.7]);
    truncate_to_top_k_by_score(&mut candidates, 3);
    assert_eq!(candidates.len(), 3);

    let mut retained = Vec::with_capacity(candidates.len());
    for (candidate, _) in candidates.iter() {
        let score: f32 = candidate.score;
        retained.push(score);
    }
    // Ascending sort followed by a reversal gives descending order.
    retained.sort_by(f32::total_cmp);
    retained.reverse();
    assert_eq!(retained, vec![0.9, 0.7, 0.5]);
}
3002
/// When the list is already shorter than `top_k`, the float truncation
/// helper must not drop anything.
#[test]
fn truncate_float_top_k_noop_when_under_cap() {
    let mut candidates = make_float_boxes(&[0.9, 0.5]);
    truncate_to_top_k_by_score(&mut candidates, 10);
    let remaining = candidates.len();
    assert_eq!(remaining, 2, "should be no-op when len <= top_k");
}
3009
/// `top_k == 0` means "no limit": the quantized truncation helper must leave
/// the candidate list at its original length.
#[test]
fn truncate_quant_top_k_zero_is_unbounded() {
    let mut candidates = make_quant_boxes(&[120, -50, 30, -10, 80]);
    let len_before = candidates.len();

    truncate_to_top_k_by_score_quant(&mut candidates, 0);

    let len_after = candidates.len();
    assert_eq!(
        len_after, len_before,
        "top_k=0 should keep all candidates (no-limit semantics)"
    );
}
3021
/// Truncating five quantized candidates to `top_k = 3` must keep the three
/// highest scores. The survivors are sorted before comparing, so their order
/// after truncation does not matter.
#[test]
fn truncate_quant_top_k_normal() {
    let mut candidates = make_quant_boxes(&[120, -50, 30, -10, 80]);
    truncate_to_top_k_by_score_quant(&mut candidates, 3);
    assert_eq!(candidates.len(), 3);

    let mut retained = Vec::with_capacity(candidates.len());
    for (candidate, _) in candidates.iter() {
        let score: i8 = candidate.score;
        retained.push(score);
    }
    // Descending order via a reversed sort key.
    retained.sort_unstable_by_key(|&score| std::cmp::Reverse(score));
    assert_eq!(retained, vec![120, 80, 30]);
}
3031
/// When the list is already shorter than `top_k`, the quantized truncation
/// helper must not drop anything.
#[test]
fn truncate_quant_top_k_noop_when_under_cap() {
    let mut candidates = make_quant_boxes(&[120, 80]);
    truncate_to_top_k_by_score_quant(&mut candidates, 10);
    let remaining = candidates.len();
    assert_eq!(remaining, 2, "should be no-op when len <= top_k");
}
3038}