1use std::fmt::Debug;
13
14use ndarray::{
15 parallel::prelude::{IntoParallelIterator, ParallelIterator},
16 s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
17};
18use num_traits::{AsPrimitive, Float, PrimInt, Signed};
19
20use crate::{
21 byte::{
22 nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
23 postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
24 },
25 configs::Nms,
26 dequant_detect_box,
27 float::{
28 nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
29 postprocess_boxes_float, postprocess_boxes_index_float,
30 },
31 BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoLayout,
32 Quantization, Segmentation, XYWH, XYXY,
33};
34
#[cfg(test)]
/// Candidate limit that the test-only `decode_yolo_segdet_quant` and
/// `decode_yolo_segdet_float` wrappers pass as `pre_nms_top_k`.
///
/// Only compiled for test builds.
pub(crate) const MAX_NMS_CANDIDATES: usize = 30_000;
54
/// Fallback detection limit returned by `cap_or_default` when the caller's
/// output vector has a capacity of zero.
pub(crate) const DEFAULT_MAX_DETECTIONS: usize = 300;
62
63fn truncate_to_top_k_by_score<E: Send>(boxes: &mut Vec<(DetectBox, E)>, top_k: usize) {
68 if top_k > 0 && boxes.len() > top_k {
69 boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.total_cmp(&a.0.score));
70 boxes.truncate(top_k);
71 }
72}
73
74fn truncate_to_top_k_by_score_quant<S: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send>(
79 boxes: &mut Vec<(DetectBoxQuantized<S>, E)>,
80 top_k: usize,
81) {
82 if top_k > 0 && boxes.len() > top_k {
83 boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.cmp(&a.0.score));
84 boxes.truncate(top_k);
85 }
86}
87
88fn dispatch_nms_float(
95 nms: Option<Nms>,
96 iou: f32,
97 max_det: Option<usize>,
98 boxes: Vec<DetectBox>,
99) -> Vec<DetectBox> {
100 match nms {
101 Some(Nms::ClassAgnostic | Nms::Auto) => nms_float(iou, max_det, boxes),
102 Some(Nms::ClassAware) => nms_class_aware_float(iou, max_det, boxes),
103 None => boxes, }
105}
106
107pub(super) fn dispatch_nms_extra_float<E: Send + Sync>(
110 nms: Option<Nms>,
111 iou: f32,
112 max_det: Option<usize>,
113 boxes: Vec<(DetectBox, E)>,
114) -> Vec<(DetectBox, E)> {
115 match nms {
116 Some(Nms::ClassAgnostic | Nms::Auto) => nms_extra_float(iou, max_det, boxes),
117 Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, max_det, boxes),
118 None => boxes, }
120}
121
122fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
125 nms: Option<Nms>,
126 iou: f32,
127 max_det: Option<usize>,
128 boxes: Vec<DetectBoxQuantized<SCORE>>,
129) -> Vec<DetectBoxQuantized<SCORE>> {
130 match nms {
131 Some(Nms::ClassAgnostic | Nms::Auto) => nms_int(iou, max_det, boxes),
132 Some(Nms::ClassAware) => nms_class_aware_int(iou, max_det, boxes),
133 None => boxes, }
135}
136
137fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
140 nms: Option<Nms>,
141 iou: f32,
142 max_det: Option<usize>,
143 boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
144) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
145 match nms {
146 Some(Nms::ClassAgnostic | Nms::Auto) => nms_extra_int(iou, max_det, boxes),
147 Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, max_det, boxes),
148 None => boxes, }
150}
151
152#[inline]
159fn cap_or_default<T>(v: &Vec<T>) -> usize {
160 if v.capacity() > 0 {
161 v.capacity()
162 } else {
163 DEFAULT_MAX_DETECTIONS
164 }
165}
166
167pub(crate) fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
201 output: (ArrayView2<BOX>, Quantization),
202 score_threshold: f32,
203 iou_threshold: f32,
204 nms: Option<Nms>,
205 output_boxes: &mut Vec<DetectBox>,
206) where
207 f32: AsPrimitive<BOX>,
208{
209 impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
210}
211
212pub(crate) fn decode_yolo_det_float<T>(
219 output: ArrayView2<T>,
220 score_threshold: f32,
221 iou_threshold: f32,
222 nms: Option<Nms>,
223 output_boxes: &mut Vec<DetectBox>,
224) where
225 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
226 f32: AsPrimitive<T>,
227{
228 impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
229}
230
/// Test-only helper: decodes a quantized YOLO segmentation output into boxes
/// and masks.
///
/// Forwards to `impl_yolo_segdet_quant` with the `XYWH` box encoding,
/// `MAX_NMS_CANDIDATES` as the pre-NMS top-k, the limit derived from
/// `output_boxes` (see `cap_or_default`) as `max_det`, and no normalization
/// hints (`normalized` and `input_dims` are both `None`).
///
/// # Errors
/// Propagates any `DecoderError` returned by `impl_yolo_segdet_quant`.
#[cfg(test)]
pub(crate) fn decode_yolo_segdet_quant<
    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    protos: (ArrayView3<PROTO>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    f32: AsPrimitive<BOX>,
{
    // Read the limit before `output_boxes` is handed over mutably.
    let max_det = cap_or_default(output_boxes);
    // Name the two optional hints so the positional `None`s are readable.
    let (normalized, input_dims) = (None, None);

    impl_yolo_segdet_quant::<XYWH, BOX, PROTO>(
        boxes,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        MAX_NMS_CANDIDATES,
        max_det,
        normalized,
        input_dims,
        output_boxes,
        output_masks,
    )
}
279
/// Test-only helper: decodes a floating-point YOLO segmentation output into
/// boxes and masks.
///
/// Forwards to `impl_yolo_segdet_float` with the `XYWH` box encoding,
/// `MAX_NMS_CANDIDATES` as the pre-NMS top-k, the limit derived from
/// `output_boxes` (see `cap_or_default`) as `max_det`, and no normalization
/// hints (`normalized` and `input_dims` are both `None`).
///
/// # Errors
/// Propagates any `DecoderError` returned by `impl_yolo_segdet_float`.
#[cfg(test)]
pub(crate) fn decode_yolo_segdet_float<T>(
    boxes: ArrayView2<T>,
    protos: ArrayView3<T>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    // Read the limit before `output_boxes` is handed over mutably.
    let max_det = cap_or_default(output_boxes);
    // Name the two optional hints so the positional `None`s are readable.
    let (normalized, input_dims) = (None, None);

    impl_yolo_segdet_float::<XYWH, T, T>(
        boxes,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        MAX_NMS_CANDIDATES,
        max_det,
        normalized,
        input_dims,
        output_boxes,
        output_masks,
    )
}
312
313pub(crate) fn decode_yolo_split_det_quant<
325 BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
326 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
327>(
328 boxes: (ArrayView2<BOX>, Quantization),
329 scores: (ArrayView2<SCORE>, Quantization),
330 score_threshold: f32,
331 iou_threshold: f32,
332 nms: Option<Nms>,
333 output_boxes: &mut Vec<DetectBox>,
334) where
335 f32: AsPrimitive<SCORE>,
336{
337 impl_yolo_split_quant::<XYWH, _, _>(
338 boxes,
339 scores,
340 score_threshold,
341 iou_threshold,
342 nms,
343 output_boxes,
344 );
345}
346
347pub(crate) fn decode_yolo_split_det_float<T>(
359 boxes: ArrayView2<T>,
360 scores: ArrayView2<T>,
361 score_threshold: f32,
362 iou_threshold: f32,
363 nms: Option<Nms>,
364 output_boxes: &mut Vec<DetectBox>,
365) where
366 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
367 f32: AsPrimitive<T>,
368{
369 impl_yolo_split_float::<XYWH, _, _>(
370 boxes,
371 scores,
372 score_threshold,
373 iou_threshold,
374 nms,
375 output_boxes,
376 );
377}
378
379pub(crate) fn decode_yolo_end_to_end_det_float<T>(
394 output: ArrayView2<T>,
395 score_threshold: f32,
396 output_boxes: &mut Vec<DetectBox>,
397) -> Result<(), crate::DecoderError>
398where
399 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
400 f32: AsPrimitive<T>,
401{
402 if output.shape()[0] < 6 {
404 return Err(crate::DecoderError::InvalidShape(format!(
405 "End-to-end detection output requires at least 6 rows, got {}",
406 output.shape()[0]
407 )));
408 }
409
410 let boxes = output.slice(s![0..4, ..]).reversed_axes();
412 let scores = output.slice(s![4..5, ..]).reversed_axes();
413 let classes = output.slice(s![5, ..]);
414 let mut boxes =
415 postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
416 boxes.truncate(cap_or_default(output_boxes));
417 output_boxes.clear();
418 for (mut b, i) in boxes.into_iter() {
419 b.label = classes[i].as_() as usize;
420 output_boxes.push(b);
421 }
422 Ok(())
424}
425
426pub(crate) fn decode_yolo_end_to_end_segdet_float<T>(
444 output: ArrayView2<T>,
445 protos: ArrayView3<T>,
446 score_threshold: f32,
447 output_boxes: &mut Vec<DetectBox>,
448 output_masks: &mut Vec<crate::Segmentation>,
449) -> Result<(), crate::DecoderError>
450where
451 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
452 f32: AsPrimitive<T>,
453{
454 let (boxes, scores, classes, mask_coeff) =
455 postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
456 let cap = cap_or_default(output_boxes);
457 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
458 boxes,
459 scores,
460 classes,
461 score_threshold,
462 cap,
463 );
464
465 impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
468}
469
470pub(crate) fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
479 boxes: ArrayView2<T>,
480 scores: ArrayView2<T>,
481 classes: ArrayView2<T>,
482 score_threshold: f32,
483 output_boxes: &mut Vec<DetectBox>,
484) -> Result<(), crate::DecoderError> {
485 let n = boxes.shape()[1];
486
487 let cap = cap_or_default(output_boxes);
488 output_boxes.clear();
489
490 let (boxes, scores, classes) = postprocess_yolo_split_end_to_end_det(boxes, scores, &classes)?;
491
492 for i in 0..n {
493 let score: f32 = scores[[i, 0]].as_();
494 if score < score_threshold {
495 continue;
496 }
497 if output_boxes.len() >= cap {
498 break;
499 }
500 output_boxes.push(DetectBox {
501 bbox: BoundingBox {
502 xmin: boxes[[i, 0]].as_(),
503 ymin: boxes[[i, 1]].as_(),
504 xmax: boxes[[i, 2]].as_(),
505 ymax: boxes[[i, 3]].as_(),
506 },
507 score,
508 label: classes[i].as_() as usize,
509 });
510 }
511 Ok(())
512}
513
514#[allow(clippy::too_many_arguments)]
523pub(crate) fn decode_yolo_split_end_to_end_segdet_float<T>(
524 boxes: ArrayView2<T>,
525 scores: ArrayView2<T>,
526 classes: ArrayView2<T>,
527 mask_coeff: ArrayView2<T>,
528 protos: ArrayView3<T>,
529 score_threshold: f32,
530 output_boxes: &mut Vec<DetectBox>,
531 output_masks: &mut Vec<crate::Segmentation>,
532) -> Result<(), crate::DecoderError>
533where
534 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
535 f32: AsPrimitive<T>,
536{
537 let (boxes, scores, classes, mask_coeff) =
538 postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
539 let cap = cap_or_default(output_boxes);
540 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
541 boxes,
542 scores,
543 classes,
544 score_threshold,
545 cap,
546 );
547
548 impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
549}
550
551#[allow(clippy::type_complexity)]
552pub(crate) fn postprocess_yolo_end_to_end_segdet<'a, T>(
553 output: &'a ArrayView2<'_, T>,
554 num_protos: usize,
555) -> Result<
556 (
557 ArrayView2<'a, T>,
558 ArrayView2<'a, T>,
559 ArrayView1<'a, T>,
560 ArrayView2<'a, T>,
561 ),
562 crate::DecoderError,
563> {
564 if output.shape()[0] < 7 {
566 return Err(crate::DecoderError::InvalidShape(format!(
567 "End-to-end segdet output requires at least 7 rows, got {}",
568 output.shape()[0]
569 )));
570 }
571
572 let num_mask_coeffs = output.shape()[0] - 6;
573 if num_mask_coeffs != num_protos {
574 return Err(crate::DecoderError::InvalidShape(format!(
575 "Mask coefficients count ({}) doesn't match protos count ({})",
576 num_mask_coeffs, num_protos
577 )));
578 }
579
580 let boxes = output.slice(s![0..4, ..]).reversed_axes();
582 let scores = output.slice(s![4..5, ..]).reversed_axes();
583 let classes = output.slice(s![5, ..]);
584 let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
585 Ok((boxes, scores, classes, mask_coeff))
586}
587
588#[allow(clippy::type_complexity)]
595pub(crate) fn postprocess_yolo_split_end_to_end_det<'a, 'b, 'c, BOXES, SCORES, CLASS>(
596 boxes: ArrayView2<'a, BOXES>,
597 scores: ArrayView2<'b, SCORES>,
598 classes: &'c ArrayView2<CLASS>,
599) -> Result<
600 (
601 ArrayView2<'a, BOXES>,
602 ArrayView2<'b, SCORES>,
603 ArrayView1<'c, CLASS>,
604 ),
605 crate::DecoderError,
606> {
607 let num_boxes = boxes.shape()[1];
608 if boxes.shape()[0] != 4 {
609 return Err(crate::DecoderError::InvalidShape(format!(
610 "Split end-to-end box_coords must be 4, got {}",
611 boxes.shape()[0]
612 )));
613 }
614
615 if scores.shape()[0] != 1 {
616 return Err(crate::DecoderError::InvalidShape(format!(
617 "Split end-to-end scores num_classes must be 1, got {}",
618 scores.shape()[0]
619 )));
620 }
621
622 if classes.shape()[0] != 1 {
623 return Err(crate::DecoderError::InvalidShape(format!(
624 "Split end-to-end classes num_classes must be 1, got {}",
625 classes.shape()[0]
626 )));
627 }
628
629 if scores.shape()[1] != num_boxes {
630 return Err(crate::DecoderError::InvalidShape(format!(
631 "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
632 num_boxes,
633 scores.shape()[1]
634 )));
635 }
636
637 if classes.shape()[1] != num_boxes {
638 return Err(crate::DecoderError::InvalidShape(format!(
639 "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
640 num_boxes,
641 classes.shape()[1]
642 )));
643 }
644
645 let boxes = boxes.reversed_axes();
646 let scores = scores.reversed_axes();
647 let classes = classes.slice(s![0, ..]);
648 Ok((boxes, scores, classes))
649}
650
651#[allow(clippy::type_complexity)]
654pub(crate) fn postprocess_yolo_split_end_to_end_segdet<
655 'a,
656 'b,
657 'c,
658 'd,
659 BOXES,
660 SCORES,
661 CLASS,
662 MASK,
663>(
664 boxes: ArrayView2<'a, BOXES>,
665 scores: ArrayView2<'b, SCORES>,
666 classes: &'c ArrayView2<CLASS>,
667 mask_coeff: ArrayView2<'d, MASK>,
668) -> Result<
669 (
670 ArrayView2<'a, BOXES>,
671 ArrayView2<'b, SCORES>,
672 ArrayView1<'c, CLASS>,
673 ArrayView2<'d, MASK>,
674 ),
675 crate::DecoderError,
676> {
677 let num_boxes = boxes.shape()[1];
678 if boxes.shape()[0] != 4 {
679 return Err(crate::DecoderError::InvalidShape(format!(
680 "Split end-to-end box_coords must be 4, got {}",
681 boxes.shape()[0]
682 )));
683 }
684
685 if scores.shape()[0] != 1 {
686 return Err(crate::DecoderError::InvalidShape(format!(
687 "Split end-to-end scores num_classes must be 1, got {}",
688 scores.shape()[0]
689 )));
690 }
691
692 if classes.shape()[0] != 1 {
693 return Err(crate::DecoderError::InvalidShape(format!(
694 "Split end-to-end classes num_classes must be 1, got {}",
695 classes.shape()[0]
696 )));
697 }
698
699 if scores.shape()[1] != num_boxes {
700 return Err(crate::DecoderError::InvalidShape(format!(
701 "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
702 num_boxes,
703 scores.shape()[1]
704 )));
705 }
706
707 if classes.shape()[1] != num_boxes {
708 return Err(crate::DecoderError::InvalidShape(format!(
709 "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
710 num_boxes,
711 classes.shape()[1]
712 )));
713 }
714
715 if mask_coeff.shape()[1] != num_boxes {
716 return Err(crate::DecoderError::InvalidShape(format!(
717 "Split end-to-end mask_coeff must have same num_boxes as boxes ({}), got {}",
718 num_boxes,
719 mask_coeff.shape()[1]
720 )));
721 }
722
723 let boxes = boxes.reversed_axes();
724 let scores = scores.reversed_axes();
725 let classes = classes.slice(s![0, ..]);
726 let mask_coeff = mask_coeff.reversed_axes();
727 Ok((boxes, scores, classes, mask_coeff))
728}
729pub(crate) fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
734 output: (ArrayView2<T>, Quantization),
735 score_threshold: f32,
736 iou_threshold: f32,
737 nms: Option<Nms>,
738 output_boxes: &mut Vec<DetectBox>,
739) where
740 f32: AsPrimitive<T>,
741{
742 let _span = tracing::trace_span!("decoder.decode.yolo_quant_flat").entered();
743 let (boxes, quant_boxes) = output;
744 let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
745
746 let boxes = {
747 let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
748 postprocess_boxes_quant::<B, _, _>(
749 score_threshold,
750 boxes_tensor,
751 scores_tensor,
752 quant_boxes,
753 )
754 };
755
756 let cap = cap_or_default(output_boxes);
757 let boxes = dispatch_nms_int(nms, iou_threshold, Some(cap), boxes);
758 let len = cap.min(boxes.len());
761 output_boxes.clear();
762 for b in boxes.iter().take(len) {
763 output_boxes.push(dequant_detect_box(b, quant_boxes));
764 }
765}
766
767pub(crate) fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
772 output: ArrayView2<T>,
773 score_threshold: f32,
774 iou_threshold: f32,
775 nms: Option<Nms>,
776 output_boxes: &mut Vec<DetectBox>,
777) where
778 f32: AsPrimitive<T>,
779{
780 let _span = tracing::trace_span!("decoder.decode.yolo_float_flat").entered();
781 let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
782 let boxes =
783 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
784 let cap = cap_or_default(output_boxes);
785 let boxes = dispatch_nms_float(nms, iou_threshold, Some(cap), boxes);
786 let len = cap.min(boxes.len());
789 output_boxes.clear();
790 for b in boxes.into_iter().take(len) {
791 output_boxes.push(b);
792 }
793}
794
795pub(crate) fn impl_yolo_split_quant<
805 B: BBoxTypeTrait,
806 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
807 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
808>(
809 boxes: (ArrayView2<BOX>, Quantization),
810 scores: (ArrayView2<SCORE>, Quantization),
811 score_threshold: f32,
812 iou_threshold: f32,
813 nms: Option<Nms>,
814 output_boxes: &mut Vec<DetectBox>,
815) where
816 f32: AsPrimitive<SCORE>,
817{
818 let _span = tracing::trace_span!("decoder.decode.yolo_quant_split").entered();
819 let (boxes_tensor, quant_boxes) = boxes;
820 let (scores_tensor, quant_scores) = scores;
821
822 let boxes_tensor = boxes_tensor.reversed_axes();
823 let scores_tensor = scores_tensor.reversed_axes();
824
825 let boxes = {
826 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
827 postprocess_boxes_quant::<B, _, _>(
828 score_threshold,
829 boxes_tensor,
830 scores_tensor,
831 quant_boxes,
832 )
833 };
834
835 let cap = cap_or_default(output_boxes);
836 let boxes = dispatch_nms_int(nms, iou_threshold, Some(cap), boxes);
837 let len = cap.min(boxes.len());
840 output_boxes.clear();
841 for b in boxes.iter().take(len) {
842 output_boxes.push(dequant_detect_box(b, quant_scores));
843 }
844}
845
846pub(crate) fn impl_yolo_split_float<
855 B: BBoxTypeTrait,
856 BOX: Float + AsPrimitive<f32> + Send + Sync,
857 SCORE: Float + AsPrimitive<f32> + Send + Sync,
858>(
859 boxes_tensor: ArrayView2<BOX>,
860 scores_tensor: ArrayView2<SCORE>,
861 score_threshold: f32,
862 iou_threshold: f32,
863 nms: Option<Nms>,
864 output_boxes: &mut Vec<DetectBox>,
865) where
866 f32: AsPrimitive<SCORE>,
867{
868 let _span = tracing::trace_span!("decoder.decode.yolo_float_split").entered();
869 let boxes_tensor = boxes_tensor.reversed_axes();
870 let scores_tensor = scores_tensor.reversed_axes();
871 let boxes =
872 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
873 let cap = cap_or_default(output_boxes);
874 let boxes = dispatch_nms_float(nms, iou_threshold, Some(cap), boxes);
875 let len = cap.min(boxes.len());
878 output_boxes.clear();
879 for b in boxes.into_iter().take(len) {
880 output_boxes.push(b);
881 }
882}
883
884#[inline]
892pub(crate) fn maybe_normalize_boxes_in_place(
893 boxes: &mut [(DetectBox, usize)],
894 normalized: Option<bool>,
895 input_dims: Option<(usize, usize)>,
896) {
897 if normalized != Some(false) {
898 return;
899 }
900 let Some((w, h)) = input_dims else {
901 return;
902 };
903 if w == 0 || h == 0 {
904 return;
905 }
906 let inv_w = 1.0 / w as f32;
907 let inv_h = 1.0 / h as f32;
908 for (b, _) in boxes.iter_mut() {
909 b.bbox.xmin *= inv_w;
910 b.bbox.ymin *= inv_h;
911 b.bbox.xmax *= inv_w;
912 b.bbox.ymax *= inv_h;
913 }
914}
915
916#[allow(clippy::too_many_arguments)]
926pub(crate) fn impl_yolo_segdet_quant<
927 B: BBoxTypeTrait,
928 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
929 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
930>(
931 boxes: (ArrayView2<BOX>, Quantization),
932 protos: (ArrayView3<PROTO>, Quantization),
933 score_threshold: f32,
934 iou_threshold: f32,
935 nms: Option<Nms>,
936 pre_nms_top_k: usize,
937 max_det: usize,
938 normalized: Option<bool>,
939 input_dims: Option<(usize, usize)>,
940 output_boxes: &mut Vec<DetectBox>,
941 output_masks: &mut Vec<Segmentation>,
942) -> Result<(), crate::DecoderError>
943where
944 f32: AsPrimitive<BOX>,
945{
946 let (boxes, quant_boxes) = boxes;
947 let num_protos = protos.0.dim().2;
948
949 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
950 let mut boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
951 (boxes_tensor, quant_boxes),
952 (scores_tensor, quant_boxes),
953 score_threshold,
954 iou_threshold,
955 nms,
956 pre_nms_top_k,
957 max_det,
958 );
959 maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
960
961 impl_yolo_split_segdet_quant_process_masks::<_, _>(
962 boxes,
963 (mask_tensor, quant_boxes),
964 protos,
965 output_boxes,
966 output_masks,
967 )
968}
969
970#[allow(clippy::too_many_arguments)]
980pub(crate) fn impl_yolo_segdet_float<
981 B: BBoxTypeTrait,
982 BOX: Float + AsPrimitive<f32> + Send + Sync,
983 PROTO: Float + AsPrimitive<f32> + Send + Sync,
984>(
985 boxes: ArrayView2<BOX>,
986 protos: ArrayView3<PROTO>,
987 score_threshold: f32,
988 iou_threshold: f32,
989 nms: Option<Nms>,
990 pre_nms_top_k: usize,
991 max_det: usize,
992 normalized: Option<bool>,
993 input_dims: Option<(usize, usize)>,
994 output_boxes: &mut Vec<DetectBox>,
995 output_masks: &mut Vec<Segmentation>,
996) -> Result<(), crate::DecoderError>
997where
998 f32: AsPrimitive<BOX>,
999{
1000 let num_protos = protos.dim().2;
1001 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1002 let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1003 boxes_tensor,
1004 scores_tensor,
1005 score_threshold,
1006 iou_threshold,
1007 nms,
1008 pre_nms_top_k,
1009 max_det,
1010 );
1011 maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1012 impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1013}
1014
/// Selects the final detections from float box and score tensors.
///
/// Stages, each wrapped in its own trace span:
/// 1. score filter via `postprocess_boxes_index_float`,
/// 2. pre-NMS truncation to the `pre_nms_top_k` best scores (only when `nms`
///    is `Some`),
/// 3. suppression via `dispatch_nms_extra_float`,
/// 4. descending sort by score and truncation to `max_det`.
///
/// The candidate counts after each stage are recorded on the outer
/// `decoder.nms_get_boxes` span.
///
/// # Returns
/// Detections sorted by descending score, at most `max_det` of them, each
/// paired with the `usize` produced by `postprocess_boxes_index_float`
/// (presumably the candidate's index in the input tensors - confirm).
pub(crate) fn impl_yolo_segdet_get_boxes<
    B: BBoxTypeTrait,
    BOX: Float + AsPrimitive<f32> + Send + Sync,
    SCORE: Float + AsPrimitive<f32> + Send + Sync,
>(
    boxes_tensor: ArrayView2<BOX>,
    scores_tensor: ArrayView2<SCORE>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    pre_nms_top_k: usize,
    max_det: usize,
) -> Vec<(DetectBox, usize)>
where
    f32: AsPrimitive<SCORE>,
{
    // Outer span; the count fields start empty and are filled in as each
    // stage completes.
    let span = tracing::trace_span!(
        "decoder.nms_get_boxes",
        n_candidates = tracing::field::Empty,
        n_after_topk = tracing::field::Empty,
        n_after_nms = tracing::field::Empty,
        n_detections = tracing::field::Empty,
    );
    let _guard = span.enter();

    // Stage 1: keep candidates whose score reaches the threshold. The inner
    // span guard `_s` is dropped at the end of the block.
    let mut boxes = {
        let _s = tracing::trace_span!("decoder.nms_get_boxes.score_filter").entered();
        postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor)
    };
    span.record("n_candidates", boxes.len());

    // Stage 2: bound the NMS input size. Skipped when NMS is disabled.
    if nms.is_some() {
        let _s = tracing::trace_span!("decoder.nms_get_boxes.top_k", k = pre_nms_top_k).entered();
        truncate_to_top_k_by_score(&mut boxes, pre_nms_top_k);
    }
    span.record("n_after_topk", boxes.len());

    // Stage 3: suppression (a pass-through when `nms` is `None`).
    let mut boxes = {
        let _s = tracing::trace_span!("decoder.nms_get_boxes.suppress").entered();
        dispatch_nms_extra_float(nms, iou_threshold, Some(max_det), boxes)
    };
    span.record("n_after_nms", boxes.len());

    // Stage 4: order by descending score, then enforce the detection limit so
    // the highest-scoring entries are the ones kept.
    boxes.sort_unstable_by(|a, b| b.0.score.total_cmp(&a.0.score));
    boxes.truncate(max_det);
    span.record("n_detections", boxes.len());

    boxes
}
1066
1067pub(crate) fn impl_yolo_end_to_end_segdet_get_boxes<
1068 B: BBoxTypeTrait,
1069 BOX: Float + AsPrimitive<f32> + Send + Sync,
1070 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1071 CLASS: AsPrimitive<f32> + Send + Sync,
1072>(
1073 boxes: ArrayView2<BOX>,
1074 scores: ArrayView2<SCORE>,
1075 classes: ArrayView1<CLASS>,
1076 score_threshold: f32,
1077 max_boxes: usize,
1078) -> Vec<(DetectBox, usize)>
1079where
1080 f32: AsPrimitive<SCORE>,
1081{
1082 let mut boxes = postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes, scores);
1083 boxes.truncate(max_boxes);
1084 for (b, ind) in &mut boxes {
1085 b.label = classes[*ind].as_().round() as usize;
1086 }
1087 boxes
1088}
1089
1090pub(crate) fn impl_yolo_split_segdet_process_masks<
1091 MASK: Float + AsPrimitive<f32> + Send + Sync,
1092 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1093>(
1094 boxes: Vec<(DetectBox, usize)>,
1095 masks_tensor: ArrayView2<MASK>,
1096 protos_tensor: ArrayView3<PROTO>,
1097 output_boxes: &mut Vec<DetectBox>,
1098 output_masks: &mut Vec<Segmentation>,
1099) -> Result<(), crate::DecoderError> {
1100 let _span = tracing::trace_span!(
1101 "decoder.decode.process_masks",
1102 n = boxes.len(),
1103 mode = "float"
1104 )
1105 .entered();
1106 let boxes = decode_segdet_f32(boxes, masks_tensor, protos_tensor)?;
1110 output_boxes.clear();
1111 output_masks.clear();
1112 for (b, roi, m) in boxes.into_iter() {
1113 output_boxes.push(b);
1114 output_masks.push(Segmentation {
1115 xmin: roi.xmin,
1116 ymin: roi.ymin,
1117 xmax: roi.xmax,
1118 ymax: roi.ymax,
1119 segmentation: m,
1120 });
1121 }
1122 Ok(())
1123}
/// Selects the final detections from quantized box and score tensors.
///
/// Stages, each wrapped in its own trace span:
/// 1. score filter via `postprocess_boxes_index_quant`, with the threshold
///    converted using the score quantization,
/// 2. pre-NMS truncation to the `pre_nms_top_k` best scores (only when `nms`
///    is `Some`),
/// 3. suppression via `dispatch_nms_extra_int`,
/// 4. descending sort by quantized score and truncation to `max_det`,
/// 5. dequantization of the survivors via `dequant_detect_box`, using the
///    score quantization.
///
/// The candidate counts after each stage are recorded on the outer
/// `decoder.nms_get_boxes` span.
///
/// # Returns
/// Dequantized detections sorted by descending score, at most `max_det` of
/// them, each paired with the `usize` produced by
/// `postprocess_boxes_index_quant` (presumably the candidate's index in the
/// input tensors - confirm).
pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
    B: BBoxTypeTrait,
    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    scores: (ArrayView2<SCORE>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    pre_nms_top_k: usize,
    max_det: usize,
) -> Vec<(DetectBox, usize)>
where
    f32: AsPrimitive<SCORE>,
{
    let (boxes_tensor, quant_boxes) = boxes;
    let (scores_tensor, quant_scores) = scores;

    // Outer span; the count fields start empty and are filled in as each
    // stage completes.
    let span = tracing::trace_span!(
        "decoder.nms_get_boxes",
        n_candidates = tracing::field::Empty,
        n_after_topk = tracing::field::Empty,
        n_after_nms = tracing::field::Empty,
        n_detections = tracing::field::Empty,
    );
    let _guard = span.enter();

    // Stage 1: filter in the quantized domain. The float threshold is
    // converted once instead of dequantizing every score.
    let mut boxes = {
        let _s = tracing::trace_span!("decoder.nms_get_boxes.score_filter").entered();
        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
        postprocess_boxes_index_quant::<B, _, _>(
            score_threshold,
            boxes_tensor,
            scores_tensor,
            quant_boxes,
        )
    };
    span.record("n_candidates", boxes.len());

    // Stage 2: bound the NMS input size. Skipped when NMS is disabled.
    if nms.is_some() {
        let _s = tracing::trace_span!("decoder.nms_get_boxes.top_k", k = pre_nms_top_k).entered();
        truncate_to_top_k_by_score_quant(&mut boxes, pre_nms_top_k);
    }
    span.record("n_after_topk", boxes.len());

    // Stage 3: suppression (a pass-through when `nms` is `None`).
    let mut boxes = {
        let _s = tracing::trace_span!("decoder.nms_get_boxes.suppress").entered();
        dispatch_nms_extra_int(nms, iou_threshold, Some(max_det), boxes)
    };
    span.record("n_after_nms", boxes.len());

    // Stage 4: order by descending quantized score, then enforce the limit so
    // the highest-scoring entries are the ones kept.
    boxes.sort_unstable_by_key(|b| std::cmp::Reverse(b.0.score));
    boxes.truncate(max_det);
    // Stage 5: dequantize only the survivors.
    let result: Vec<_> = {
        let _s =
            tracing::trace_span!("decoder.nms_get_boxes.dequant_boxes", n = boxes.len()).entered();
        boxes
            .into_iter()
            .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
            .collect()
    };
    span.record("n_detections", result.len());

    result
}
1195
1196pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
1197 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1198 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1199>(
1200 boxes: Vec<(DetectBox, usize)>,
1201 mask_coeff: (ArrayView2<MASK>, Quantization),
1202 protos: (ArrayView3<PROTO>, Quantization),
1203 output_boxes: &mut Vec<DetectBox>,
1204 output_masks: &mut Vec<Segmentation>,
1205) -> Result<(), crate::DecoderError> {
1206 let _span = tracing::trace_span!(
1207 "decoder.decode.process_masks",
1208 n = boxes.len(),
1209 mode = "quant"
1210 )
1211 .entered();
1212 let (masks, quant_masks) = mask_coeff;
1213 let (protos, quant_protos) = protos;
1214
1215 let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
1219 output_boxes.clear();
1220 output_masks.clear();
1221 for (b, roi, m) in boxes.into_iter() {
1222 output_boxes.push(b);
1223 output_masks.push(Segmentation {
1224 xmin: roi.xmin,
1225 ymin: roi.ymin,
1226 xmax: roi.xmax,
1227 ymax: roi.ymax,
1228 segmentation: m,
1229 });
1230 }
1231 Ok(())
1232}
1233
1234#[allow(clippy::too_many_arguments)]
1246pub(crate) fn impl_yolo_split_segdet_float<
1247 B: BBoxTypeTrait,
1248 BOX: Float + AsPrimitive<f32> + Send + Sync,
1249 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1250 MASK: Float + AsPrimitive<f32> + Send + Sync,
1251 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1252>(
1253 boxes_tensor: ArrayView2<BOX>,
1254 scores_tensor: ArrayView2<SCORE>,
1255 mask_tensor: ArrayView2<MASK>,
1256 protos: ArrayView3<PROTO>,
1257 score_threshold: f32,
1258 iou_threshold: f32,
1259 nms: Option<Nms>,
1260 pre_nms_top_k: usize,
1261 max_det: usize,
1262 normalized: Option<bool>,
1263 input_dims: Option<(usize, usize)>,
1264 output_boxes: &mut Vec<DetectBox>,
1265 output_masks: &mut Vec<Segmentation>,
1266) -> Result<(), crate::DecoderError>
1267where
1268 f32: AsPrimitive<SCORE>,
1269{
1270 let (boxes_tensor, scores_tensor, mask_tensor) =
1271 postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1272
1273 let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1274 boxes_tensor,
1275 scores_tensor,
1276 score_threshold,
1277 iou_threshold,
1278 nms,
1279 pre_nms_top_k,
1280 max_det,
1281 );
1282 maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1283 impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1284}
1285
1286#[allow(clippy::too_many_arguments)]
1293pub(crate) fn impl_yolo_segdet_quant_proto<
1294 B: BBoxTypeTrait,
1295 BOX: PrimInt
1296 + AsPrimitive<i64>
1297 + AsPrimitive<i128>
1298 + AsPrimitive<f32>
1299 + AsPrimitive<i8>
1300 + Send
1301 + Sync,
1302 PROTO: PrimInt
1303 + AsPrimitive<i64>
1304 + AsPrimitive<i128>
1305 + AsPrimitive<f32>
1306 + AsPrimitive<i8>
1307 + Send
1308 + Sync,
1309>(
1310 boxes: (ArrayView2<BOX>, Quantization),
1311 protos: (ArrayView3<PROTO>, Quantization),
1312 score_threshold: f32,
1313 iou_threshold: f32,
1314 nms: Option<Nms>,
1315 pre_nms_top_k: usize,
1316 max_det: usize,
1317 normalized: Option<bool>,
1318 input_dims: Option<(usize, usize)>,
1319 output_boxes: &mut Vec<DetectBox>,
1320) -> ProtoData
1321where
1322 f32: AsPrimitive<BOX>,
1323{
1324 let (boxes_arr, quant_boxes) = boxes;
1325 let (protos_arr, quant_protos) = protos;
1326 let num_protos = protos_arr.dim().2;
1327
1328 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1329
1330 let mut det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1331 (boxes_tensor, quant_boxes),
1332 (scores_tensor, quant_boxes),
1333 score_threshold,
1334 iou_threshold,
1335 nms,
1336 pre_nms_top_k,
1337 max_det,
1338 );
1339 maybe_normalize_boxes_in_place(&mut det_indices, normalized, input_dims);
1340
1341 extract_proto_data_quant(
1342 det_indices,
1343 mask_tensor,
1344 quant_boxes,
1345 protos_arr,
1346 quant_protos,
1347 output_boxes,
1348 )
1349}
1350
1351#[allow(clippy::too_many_arguments)]
1354pub(crate) fn impl_yolo_segdet_float_proto<
1355 B: BBoxTypeTrait,
1356 BOX: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1357 PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1358>(
1359 boxes: ArrayView2<BOX>,
1360 protos: ArrayView3<PROTO>,
1361 score_threshold: f32,
1362 iou_threshold: f32,
1363 nms: Option<Nms>,
1364 pre_nms_top_k: usize,
1365 max_det: usize,
1366 normalized: Option<bool>,
1367 input_dims: Option<(usize, usize)>,
1368 output_boxes: &mut Vec<DetectBox>,
1369) -> ProtoData
1370where
1371 f32: AsPrimitive<BOX>,
1372{
1373 let num_protos = protos.dim().2;
1374 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1375
1376 let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1377 boxes_tensor,
1378 scores_tensor,
1379 score_threshold,
1380 iou_threshold,
1381 nms,
1382 pre_nms_top_k,
1383 max_det,
1384 );
1385 maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1386
1387 extract_proto_data_float(boxes, mask_tensor, protos, output_boxes)
1388}
1389
1390#[allow(clippy::too_many_arguments)]
1393pub(crate) fn impl_yolo_split_segdet_float_proto<
1394 B: BBoxTypeTrait,
1395 BOX: Float + AsPrimitive<f32> + Send + Sync,
1396 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1397 MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1398 PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1399>(
1400 boxes_tensor: ArrayView2<BOX>,
1401 scores_tensor: ArrayView2<SCORE>,
1402 mask_tensor: ArrayView2<MASK>,
1403 protos: ArrayView3<PROTO>,
1404 score_threshold: f32,
1405 iou_threshold: f32,
1406 nms: Option<Nms>,
1407 pre_nms_top_k: usize,
1408 max_det: usize,
1409 normalized: Option<bool>,
1410 input_dims: Option<(usize, usize)>,
1411 output_boxes: &mut Vec<DetectBox>,
1412) -> ProtoData
1413where
1414 f32: AsPrimitive<SCORE>,
1415{
1416 let (boxes_tensor, scores_tensor, mask_tensor) =
1417 postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1418 let mut det_indices = impl_yolo_segdet_get_boxes::<B, _, _>(
1419 boxes_tensor,
1420 scores_tensor,
1421 score_threshold,
1422 iou_threshold,
1423 nms,
1424 pre_nms_top_k,
1425 max_det,
1426 );
1427 maybe_normalize_boxes_in_place(&mut det_indices, normalized, input_dims);
1428
1429 extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1430}
1431
1432pub(crate) fn decode_yolo_end_to_end_segdet_float_proto<T>(
1434 output: ArrayView2<T>,
1435 protos: ArrayView3<T>,
1436 score_threshold: f32,
1437 output_boxes: &mut Vec<DetectBox>,
1438) -> Result<ProtoData, crate::DecoderError>
1439where
1440 T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1441 f32: AsPrimitive<T>,
1442{
1443 let (boxes, scores, classes, mask_coeff) =
1444 postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
1445 let cap = cap_or_default(output_boxes);
1446 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1447 boxes,
1448 scores,
1449 classes,
1450 score_threshold,
1451 cap,
1452 );
1453
1454 Ok(extract_proto_data_float(
1455 boxes,
1456 mask_coeff,
1457 protos,
1458 output_boxes,
1459 ))
1460}
1461
1462#[allow(clippy::too_many_arguments)]
1464pub(crate) fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1465 boxes: ArrayView2<T>,
1466 scores: ArrayView2<T>,
1467 classes: ArrayView2<T>,
1468 mask_coeff: ArrayView2<T>,
1469 protos: ArrayView3<T>,
1470 score_threshold: f32,
1471 output_boxes: &mut Vec<DetectBox>,
1472) -> Result<ProtoData, crate::DecoderError>
1473where
1474 T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1475 f32: AsPrimitive<T>,
1476{
1477 let (boxes, scores, classes, mask_coeff) =
1478 postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
1479 let cap = cap_or_default(output_boxes);
1480 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1481 boxes,
1482 scores,
1483 classes,
1484 score_threshold,
1485 cap,
1486 );
1487
1488 Ok(extract_proto_data_float(
1489 boxes,
1490 mask_coeff,
1491 protos,
1492 output_boxes,
1493 ))
1494}
1495
1496pub(super) fn extract_proto_data_float<
1503 MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1504 PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1505>(
1506 det_indices: Vec<(DetectBox, usize)>,
1507 mask_tensor: ArrayView2<MASK>,
1508 protos: ArrayView3<PROTO>,
1509 output_boxes: &mut Vec<DetectBox>,
1510) -> ProtoData {
1511 let _span = tracing::trace_span!(
1512 "decoder.decode_proto.extract_proto_data",
1513 mode = "float",
1514 n = det_indices.len(),
1515 num_protos = mask_tensor.ncols(),
1516 layout = "nhwc",
1517 )
1518 .entered();
1519
1520 let num_protos = mask_tensor.ncols();
1521 let n = det_indices.len();
1522
1523 let mut coeff_rows: Vec<MASK> = Vec::with_capacity(n * num_protos);
1528 output_boxes.clear();
1529 for (det, idx) in det_indices {
1530 output_boxes.push(det);
1531 let row = mask_tensor.row(idx);
1532 coeff_rows.extend(row.iter().copied());
1533 }
1534
1535 let mask_coefficients = MASK::slice_into_tensor_dyn(&coeff_rows, &[n, num_protos])
1536 .expect("allocating mask_coefficients TensorDyn");
1537 let protos_tensor =
1538 PROTO::arrayview3_into_tensor_dyn(protos).expect("allocating protos TensorDyn");
1539
1540 ProtoData {
1541 mask_coefficients,
1542 protos: protos_tensor,
1543 layout: ProtoLayout::Nhwc,
1544 }
1545}
1546
1547pub(crate) fn extract_proto_data_quant<
1556 MASK: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
1557 PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
1558>(
1559 det_indices: Vec<(DetectBox, usize)>,
1560 mask_tensor: ArrayView2<MASK>,
1561 quant_masks: Quantization,
1562 protos: ArrayView3<PROTO>,
1563 quant_protos: Quantization,
1564 output_boxes: &mut Vec<DetectBox>,
1565) -> ProtoData {
1566 use edgefirst_tensor::{Tensor, TensorDyn, TensorMapTrait, TensorMemory, TensorTrait};
1567
1568 let span = tracing::trace_span!(
1569 "decoder.decode_proto.extract_proto_data",
1570 mode = "quant",
1571 n = det_indices.len(),
1572 num_protos = tracing::field::Empty,
1573 layout = tracing::field::Empty,
1574 );
1575 let _guard = span.enter();
1576
1577 let num_protos = mask_tensor.ncols();
1578 let n = det_indices.len();
1579 span.record("num_protos", num_protos);
1580
1581 if n == 0 {
1587 output_boxes.clear();
1588 let (h, w, k) = protos.dim();
1589
1590 let (proto_shape, proto_layout) = if std::any::TypeId::of::<PROTO>()
1592 == std::any::TypeId::of::<i8>()
1593 {
1594 if protos.is_standard_layout() {
1595 (&[h, w, k][..], ProtoLayout::Nhwc)
1596 } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1597 (&[k, h, w][..], ProtoLayout::Nchw)
1598 } else {
1599 (&[h, w, k][..], ProtoLayout::Nhwc)
1600 }
1601 } else {
1602 (&[h, w, k][..], ProtoLayout::Nhwc)
1603 };
1604
1605 let coeff_tensor = Tensor::<i8>::new(&[0, num_protos], Some(TensorMemory::Mem), None)
1606 .expect("allocating empty mask_coefficients tensor");
1607 let coeff_quant =
1608 edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1609 let coeff_tensor = coeff_tensor
1610 .with_quantization(coeff_quant)
1611 .expect("per-tensor quantization on mask coefficients");
1612 let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
1613 .expect("allocating protos tensor");
1614 let tensor_quant =
1615 edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1616 let protos_tensor = protos_tensor
1617 .with_quantization(tensor_quant)
1618 .expect("per-tensor quantization on protos tensor");
1619 return ProtoData {
1620 mask_coefficients: TensorDyn::I8(coeff_tensor),
1621 protos: TensorDyn::I8(protos_tensor),
1622 layout: proto_layout,
1623 };
1624 }
1625
1626 let mask_coefficients: TensorDyn = if std::any::TypeId::of::<MASK>()
1632 == std::any::TypeId::of::<i8>()
1633 {
1634 let mut coeff_i8 = Vec::<i8>::with_capacity(n * num_protos);
1635 output_boxes.clear();
1636 for (det, idx) in det_indices {
1637 output_boxes.push(det);
1638 let row = mask_tensor.row(idx);
1639 coeff_i8.extend(row.iter().map(|v| {
1640 let v_i8: i8 = v.as_();
1641 v_i8
1642 }));
1643 }
1644 let coeff_tensor = Tensor::<i8>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
1645 .expect("allocating mask_coefficients tensor");
1646 if n > 0 {
1647 let mut m = coeff_tensor
1648 .map()
1649 .expect("mapping mask_coefficients tensor");
1650 m.as_mut_slice().copy_from_slice(&coeff_i8);
1651 }
1652 let coeff_quant =
1653 edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1654 let coeff_tensor = coeff_tensor
1655 .with_quantization(coeff_quant)
1656 .expect("per-tensor quantization on mask coefficients");
1657 TensorDyn::I8(coeff_tensor)
1658 } else if std::any::TypeId::of::<MASK>() == std::any::TypeId::of::<i16>() {
1659 let mut coeff_i16 = Vec::<i16>::with_capacity(n * num_protos);
1662 output_boxes.clear();
1663 for (det, idx) in det_indices {
1664 output_boxes.push(det);
1665 let row = mask_tensor.row(idx);
1666 coeff_i16.extend(row.iter().map(|v| {
1667 let v_f32: f32 = v.as_();
1668 v_f32 as i16
1669 }));
1670 }
1671 let coeff_tensor = Tensor::<i16>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
1672 .expect("allocating mask_coefficients tensor");
1673 if n > 0 {
1674 let mut m = coeff_tensor
1675 .map()
1676 .expect("mapping mask_coefficients tensor");
1677 m.as_mut_slice().copy_from_slice(&coeff_i16);
1678 }
1679 let coeff_quant =
1680 edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1681 let coeff_tensor = coeff_tensor
1682 .with_quantization(coeff_quant)
1683 .expect("per-tensor quantization on mask coefficients");
1684 TensorDyn::I16(coeff_tensor)
1685 } else {
1686 let scale = quant_masks.scale;
1688 let zp = quant_masks.zero_point as f32;
1689 let mut coeff_f32 = Vec::<f32>::with_capacity(n * num_protos);
1690 output_boxes.clear();
1691 for (det, idx) in det_indices {
1692 output_boxes.push(det);
1693 let row = mask_tensor.row(idx);
1694 coeff_f32.extend(row.iter().map(|v| {
1695 let v_f32: f32 = v.as_();
1696 (v_f32 - zp) * scale
1697 }));
1698 }
1699 let coeff_tensor = Tensor::<f32>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
1700 .expect("allocating mask_coefficients tensor");
1701 if n > 0 {
1702 let mut m = coeff_tensor
1703 .map()
1704 .expect("mapping mask_coefficients tensor");
1705 m.as_mut_slice().copy_from_slice(&coeff_f32);
1706 }
1707 TensorDyn::F32(coeff_tensor)
1708 };
1709
1710 let (h, w, k) = protos.dim();
1714
1715 let (proto_shape, proto_layout) =
1717 if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1718 if protos.is_standard_layout() {
1719 (&[h, w, k][..], ProtoLayout::Nhwc)
1721 } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1722 (&[k, h, w][..], ProtoLayout::Nchw)
1726 } else {
1727 (&[h, w, k][..], ProtoLayout::Nhwc)
1729 }
1730 } else {
1731 (&[h, w, k][..], ProtoLayout::Nhwc)
1732 };
1733
1734 let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
1735 .expect("allocating protos tensor");
1736 {
1737 let mut m = protos_tensor.map().expect("mapping protos tensor");
1738 let dst = m.as_mut_slice();
1739 if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1740 if protos.is_standard_layout() {
1743 let src: &[i8] = unsafe {
1744 std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len())
1745 };
1746 dst.copy_from_slice(src);
1747 } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1748 let total = h * w * k;
1752 let src: &[i8] =
1755 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, total) };
1756 dst.copy_from_slice(src);
1757 } else {
1758 for (d, s) in dst.iter_mut().zip(protos.iter()) {
1759 let v_i8: i8 = s.as_();
1760 *d = v_i8;
1761 }
1762 }
1763 } else {
1764 for (d, s) in dst.iter_mut().zip(protos.iter()) {
1765 let v_i8: i8 = s.as_();
1766 *d = v_i8;
1767 }
1768 }
1769 }
1770 let tensor_quant =
1771 edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1772 let protos_tensor = protos_tensor
1773 .with_quantization(tensor_quant)
1774 .expect("per-tensor quantization on new Tensor<i8>");
1775
1776 span.record("layout", tracing::field::debug(&proto_layout));
1777
1778 ProtoData {
1779 mask_coefficients,
1780 protos: TensorDyn::I8(protos_tensor),
1781 layout: proto_layout,
1782 }
1783}
1784
/// Float element types that can be packed into an
/// `edgefirst_tensor::TensorDyn`.
///
/// Implemented in this file for `f32`, `half::f16` and `f64`. The `f64`
/// implementation narrows its values to `f32` before building the tensor.
pub trait FloatProtoElem: Copy + 'static {
    /// Builds a dynamic tensor of the given `shape` from the flat `values`
    /// slice, returning the tensor library's error on failure.
    fn slice_into_tensor_dyn(
        values: &[Self],
        shape: &[usize],
    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;

    /// Builds a dynamic tensor from a three-dimensional array view,
    /// returning the tensor library's error on failure.
    fn arrayview3_into_tensor_dyn(
        view: ArrayView3<'_, Self>,
    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;
}
1800
1801impl FloatProtoElem for f32 {
1802 fn slice_into_tensor_dyn(
1803 values: &[f32],
1804 shape: &[usize],
1805 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1806 edgefirst_tensor::Tensor::<f32>::from_slice(values, shape)
1807 .map(edgefirst_tensor::TensorDyn::F32)
1808 }
1809 fn arrayview3_into_tensor_dyn(
1810 view: ArrayView3<'_, f32>,
1811 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1812 edgefirst_tensor::Tensor::<f32>::from_arrayview3(view).map(edgefirst_tensor::TensorDyn::F32)
1813 }
1814}
1815
1816impl FloatProtoElem for half::f16 {
1817 fn slice_into_tensor_dyn(
1818 values: &[half::f16],
1819 shape: &[usize],
1820 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1821 edgefirst_tensor::Tensor::<half::f16>::from_slice(values, shape)
1822 .map(edgefirst_tensor::TensorDyn::F16)
1823 }
1824 fn arrayview3_into_tensor_dyn(
1825 view: ArrayView3<'_, half::f16>,
1826 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1827 edgefirst_tensor::Tensor::<half::f16>::from_arrayview3(view)
1828 .map(edgefirst_tensor::TensorDyn::F16)
1829 }
1830}
1831
1832impl FloatProtoElem for f64 {
1833 fn slice_into_tensor_dyn(
1834 values: &[f64],
1835 shape: &[usize],
1836 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1837 let narrowed: Vec<f32> = values.iter().map(|&v| v as f32).collect();
1839 edgefirst_tensor::Tensor::<f32>::from_slice(&narrowed, shape)
1840 .map(edgefirst_tensor::TensorDyn::F32)
1841 }
1842 fn arrayview3_into_tensor_dyn(
1843 view: ArrayView3<'_, f64>,
1844 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1845 let narrowed: ndarray::Array3<f32> = view.mapv(|v| v as f32);
1846 edgefirst_tensor::Tensor::<f32>::from_arrayview3(narrowed.view())
1847 .map(edgefirst_tensor::TensorDyn::F32)
1848 }
1849}
1850
1851fn postprocess_yolo<'a, T>(
1852 output: &'a ArrayView2<'_, T>,
1853) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1854 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1855 let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1856 (boxes_tensor, scores_tensor)
1857}
1858
1859pub(crate) fn postprocess_yolo_seg<'a, T>(
1860 output: &'a ArrayView2<'_, T>,
1861 num_protos: usize,
1862) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1863 assert!(
1864 output.shape()[0] > num_protos + 4,
1865 "Output shape is too short: {} <= {} + 4",
1866 output.shape()[0],
1867 num_protos
1868 );
1869 let num_classes = output.shape()[0] - 4 - num_protos;
1870 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1871 let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1872 let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1873 (boxes_tensor, scores_tensor, mask_tensor)
1874}
1875
1876pub(crate) fn postprocess_yolo_split_segdet<'a, 'b, 'c, BOX, SCORE, MASK>(
1877 boxes_tensor: ArrayView2<'a, BOX>,
1878 scores_tensor: ArrayView2<'b, SCORE>,
1879 mask_tensor: ArrayView2<'c, MASK>,
1880) -> (
1881 ArrayView2<'a, BOX>,
1882 ArrayView2<'b, SCORE>,
1883 ArrayView2<'c, MASK>,
1884) {
1885 let boxes_tensor = boxes_tensor.reversed_axes();
1886 let scores_tensor = scores_tensor.reversed_axes();
1887 let mask_tensor = mask_tensor.reversed_axes();
1888 (boxes_tensor, scores_tensor, mask_tensor)
1889}
1890
1891fn decode_segdet_f32<
1892 MASK: Float + AsPrimitive<f32> + Send + Sync,
1893 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1894>(
1895 boxes: Vec<(DetectBox, usize)>,
1896 masks: ArrayView2<MASK>,
1897 protos: ArrayView3<PROTO>,
1898) -> Result<Vec<(DetectBox, BoundingBox, Array3<u8>)>, crate::DecoderError> {
1899 if boxes.is_empty() {
1900 return Ok(Vec::new());
1901 }
1902 if masks.shape()[1] != protos.shape()[2] {
1903 return Err(crate::DecoderError::InvalidShape(format!(
1904 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1905 masks.shape()[1],
1906 protos.shape()[2],
1907 )));
1908 }
1909 boxes
1910 .into_par_iter()
1911 .map(|b| {
1912 let ind = b.1;
1913 let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1918 Ok((b.0, roi, make_segmentation(masks.row(ind), protos.view())))
1919 })
1920 .collect()
1921}
1922
1923pub(crate) fn decode_segdet_quant<
1924 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1925 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1926>(
1927 boxes: Vec<(DetectBox, usize)>,
1928 masks: ArrayView2<MASK>,
1929 protos: ArrayView3<PROTO>,
1930 quant_masks: Quantization,
1931 quant_protos: Quantization,
1932) -> Result<Vec<(DetectBox, BoundingBox, Array3<u8>)>, crate::DecoderError> {
1933 if boxes.is_empty() {
1934 return Ok(Vec::new());
1935 }
1936 if masks.shape()[1] != protos.shape()[2] {
1937 return Err(crate::DecoderError::InvalidShape(format!(
1938 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1939 masks.shape()[1],
1940 protos.shape()[2],
1941 )));
1942 }
1943
1944 let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; boxes
1946 .into_iter()
1947 .map(|b| {
1948 let i = b.1;
1949 let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1953 let seg = match total_bits {
1954 0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1955 masks.row(i),
1956 protos.view(),
1957 quant_masks,
1958 quant_protos,
1959 ),
1960 65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1961 masks.row(i),
1962 protos.view(),
1963 quant_masks,
1964 quant_protos,
1965 ),
1966 _ => {
1967 return Err(crate::DecoderError::NotSupported(format!(
1968 "Unsupported bit width ({total_bits}) for segmentation computation"
1969 )));
1970 }
1971 };
1972 Ok((b.0, roi, seg))
1973 })
1974 .collect()
1975}
1976
1977fn protobox<'a, T>(
1978 protos: &'a ArrayView3<T>,
1979 roi: &BoundingBox,
1980) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1981 let width = protos.dim().1 as f32;
1982 let height = protos.dim().0 as f32;
1983
1984 const NORM_LIMIT: f32 = 2.0;
1996 if roi.xmin > NORM_LIMIT
1997 || roi.ymin > NORM_LIMIT
1998 || roi.xmax > NORM_LIMIT
1999 || roi.ymax > NORM_LIMIT
2000 {
2001 return Err(crate::DecoderError::InvalidShape(format!(
2002 "Bounding box coordinates appear un-normalized (pixel-space). \
2003 Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
2004 Two ways to fix this: \
2005 (1) declare `Detection::normalized = false` in the model schema \
2006 AND make sure the schema's `input.shape` / `input.dshape` carries \
2007 the model input dims so the decoder can divide by (W, H) before NMS \
2008 (EDGEAI-1303 — verify with `Decoder::input_dims().is_some()`); or \
2009 (2) normalize the boxes in-graph before decode().",
2010 roi.xmin, roi.ymin, roi.xmax, roi.ymax,
2011 )));
2012 }
2013
2014 let roi = [
2015 (roi.xmin * width).clamp(0.0, width) as usize,
2016 (roi.ymin * height).clamp(0.0, height) as usize,
2017 (roi.xmax * width).clamp(0.0, width).ceil() as usize,
2018 (roi.ymax * height).clamp(0.0, height).ceil() as usize,
2019 ];
2020
2021 let roi_norm = [
2022 roi[0] as f32 / width,
2023 roi[1] as f32 / height,
2024 roi[2] as f32 / width,
2025 roi[3] as f32 / height,
2026 ]
2027 .into();
2028
2029 let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
2030
2031 Ok((cropped, roi_norm))
2032}
2033
2034fn make_segmentation<
2040 MASK: Float + AsPrimitive<f32> + Send + Sync,
2041 PROTO: Float + AsPrimitive<f32> + Send + Sync,
2042>(
2043 mask: ArrayView1<MASK>,
2044 protos: ArrayView3<PROTO>,
2045) -> Array3<u8> {
2046 let shape = protos.shape();
2047
2048 let mask = mask.to_shape((1, mask.len())).unwrap();
2050 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
2051 let protos = protos.reversed_axes();
2052 let mask = mask.map(|x| x.as_());
2053 let protos = protos.map(|x| x.as_());
2054
2055 let mask = mask
2057 .dot(&protos)
2058 .into_shape_with_order((shape[0], shape[1], 1))
2059 .unwrap();
2060
2061 mask.map(|x| {
2062 let sigmoid = 1.0 / (1.0 + (-*x).exp());
2063 (sigmoid * 255.0).round() as u8
2064 })
2065}
2066
2067fn make_segmentation_quant<
2074 MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
2075 PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
2076 DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
2077>(
2078 mask: ArrayView1<MASK>,
2079 protos: ArrayView3<PROTO>,
2080 quant_masks: Quantization,
2081 quant_protos: Quantization,
2082) -> Array3<u8>
2083where
2084 i32: AsPrimitive<DEST>,
2085 f32: AsPrimitive<DEST>,
2086{
2087 let shape = protos.shape();
2088
2089 let mask = mask.to_shape((1, mask.len())).unwrap();
2091
2092 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
2093 let protos = protos.reversed_axes();
2094
2095 let zp = quant_masks.zero_point.as_();
2096
2097 let mask = mask.mapv(|x| x.as_() - zp);
2098
2099 let zp = quant_protos.zero_point.as_();
2100 let protos = protos.mapv(|x| x.as_() - zp);
2101
2102 let segmentation = mask
2104 .dot(&protos)
2105 .into_shape_with_order((shape[0], shape[1], 1))
2106 .unwrap();
2107
2108 let combined_scale = quant_masks.scale * quant_protos.scale;
2109 segmentation.map(|x| {
2110 let val: f32 = (*x).as_() * combined_scale;
2111 let sigmoid = 1.0 / (1.0 + (-val).exp());
2112 (sigmoid * 255.0).round() as u8
2113 })
2114}
2115
2116pub(crate) fn yolo_segmentation_to_mask(
2128 segmentation: ArrayView3<u8>,
2129 threshold: u8,
2130) -> Result<Array2<u8>, crate::DecoderError> {
2131 if segmentation.shape()[2] != 1 {
2132 return Err(crate::DecoderError::InvalidShape(format!(
2133 "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
2134 segmentation.shape()[2]
2135 )));
2136 }
2137 Ok(segmentation
2138 .slice(s![.., .., 0])
2139 .map(|x| if *x >= threshold { 1 } else { 0 }))
2140}
2141
2142#[cfg(test)]
2143#[cfg_attr(coverage_nightly, coverage(off))]
2144mod tests {
2145 use super::*;
2146 use ndarray::Array2;
2147
#[test]
fn test_end_to_end_det_basic_filtering() {
    // Row-major (6, 3) tensor, one column per candidate. Only the first
    // column has a score above the 0.5 threshold.
    let raw: Vec<f32> = vec![
        0.1, 0.2, 0.3, // xmin
        0.1, 0.2, 0.3, // ymin
        0.5, 0.6, 0.7, // xmax
        0.5, 0.6, 0.7, // ymax
        0.9, 0.1, 0.2, // score
        0.0, 1.0, 2.0, // label
    ];
    let tensor = Array2::from_shape_vec((6, 3), raw).unwrap();

    let mut detections = Vec::with_capacity(10);
    decode_yolo_end_to_end_det_float(tensor.view(), 0.5, &mut detections).unwrap();

    assert_eq!(detections.len(), 1);
    let kept = &detections[0];
    assert_eq!(kept.label, 0);
    assert!((kept.score - 0.9).abs() < 0.01);
    assert!((kept.bbox.xmin - 0.1).abs() < 0.01);
    assert!((kept.bbox.ymin - 0.1).abs() < 0.01);
    assert!((kept.bbox.xmax - 0.5).abs() < 0.01);
    assert!((kept.bbox.ymax - 0.5).abs() < 0.01);
}
2180
#[test]
fn test_end_to_end_det_all_pass_threshold() {
    // Row-major (6, 2) tensor; both scores (0.8 and 0.7) exceed 0.5.
    let raw: Vec<f32> = vec![
        10.0, 20.0, // row 0
        10.0, 20.0, // row 1
        50.0, 60.0, // row 2
        50.0, 60.0, // row 3
        0.8, 0.7, // score
        1.0, 2.0, // label
    ];
    let tensor = Array2::from_shape_vec((6, 2), raw).unwrap();

    let mut detections = Vec::with_capacity(10);
    decode_yolo_end_to_end_det_float(tensor.view(), 0.5, &mut detections).unwrap();

    assert_eq!(detections.len(), 2);
    assert_eq!(detections[0].label, 1);
    assert_eq!(detections[1].label, 2);
}
2201
#[test]
fn test_end_to_end_det_none_pass_threshold() {
    // Row-major (6, 2) tensor; both scores (0.1 and 0.2) are below 0.5.
    let raw: Vec<f32> = vec![
        10.0, 20.0, // row 0
        10.0, 20.0, // row 1
        50.0, 60.0, // row 2
        50.0, 60.0, // row 3
        0.1, 0.2, // score
        1.0, 2.0, // label
    ];
    let tensor = Array2::from_shape_vec((6, 2), raw).unwrap();

    let mut detections = Vec::with_capacity(10);
    decode_yolo_end_to_end_det_float(tensor.view(), 0.5, &mut detections).unwrap();

    assert_eq!(detections.len(), 0);
}
2220
#[test]
fn test_end_to_end_det_capacity_limit() {
    // Row-major (6, 5) tensor; all five candidates score 0.9.
    let raw: Vec<f32> = vec![
        0.1, 0.2, 0.3, 0.4, 0.5, // row 0
        0.1, 0.2, 0.3, 0.4, 0.5, // row 1
        0.5, 0.6, 0.7, 0.8, 0.9, // row 2
        0.5, 0.6, 0.7, 0.8, 0.9, // row 3
        0.9, 0.9, 0.9, 0.9, 0.9, // score
        0.0, 1.0, 2.0, 3.0, 4.0, // label
    ];
    let tensor = Array2::from_shape_vec((6, 5), raw).unwrap();

    // The output vector is created with room for two boxes, and only two
    // of the five passing candidates are expected back.
    let mut detections = Vec::with_capacity(2);
    decode_yolo_end_to_end_det_float(tensor.view(), 0.5, &mut detections).unwrap();

    assert_eq!(detections.len(), 2);
}
2239
#[test]
fn test_end_to_end_det_empty_output() {
    // Six rows but zero columns: a well-formed tensor with no candidates.
    let tensor = Array2::<f32>::zeros((6, 0));

    let mut detections = Vec::with_capacity(10);
    decode_yolo_end_to_end_det_float(tensor.view(), 0.5, &mut detections).unwrap();

    assert_eq!(detections.len(), 0);
}
2250
#[test]
fn test_end_to_end_det_pixel_coordinates() {
    // One candidate with pixel-space coordinates, expected back unchanged.
    let raw: Vec<f32> = vec![
        100.0, // xmin
        200.0, // ymin
        300.0, // xmax
        400.0, // ymax
        0.95, // score
        5.0, // label
    ];
    let tensor = Array2::from_shape_vec((6, 1), raw).unwrap();

    let mut detections = Vec::with_capacity(10);
    decode_yolo_end_to_end_det_float(tensor.view(), 0.5, &mut detections).unwrap();

    assert_eq!(detections.len(), 1);
    let kept = &detections[0];
    assert_eq!(kept.label, 5);
    assert!((kept.bbox.xmin - 100.0).abs() < 0.01);
    assert!((kept.bbox.ymin - 200.0).abs() < 0.01);
    assert!((kept.bbox.xmax - 300.0).abs() < 0.01);
    assert!((kept.bbox.ymax - 400.0).abs() < 0.01);
}
2274
#[test]
fn test_end_to_end_det_invalid_shape() {
    // Five rows: one fewer than the six the error message asks for.
    let too_short = Array2::<f32>::zeros((5, 3));

    let mut detections = Vec::with_capacity(10);
    let outcome = decode_yolo_end_to_end_det_float(too_short.view(), 0.5, &mut detections);

    assert!(outcome.is_err());
    assert!(matches!(
        outcome,
        Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
    ));
}
2289
#[test]
fn test_end_to_end_segdet_basic() {
    let num_protos = 32;
    let num_detections = 2;
    let num_features = 6 + num_protos;

    // Row-major (num_features, 2) buffer. The first six rows hold the box
    // corners, score and label of each candidate; every remaining
    // (coefficient) row is 0.1.
    let leading_rows: [[f32; 2]; 6] = [
        [0.1, 0.5], // xmin
        [0.1, 0.5], // ymin
        [0.4, 0.9], // xmax
        [0.4, 0.9], // ymax
        [0.9, 0.3], // score: only the first candidate passes 0.5
        [1.0, 2.0], // label
    ];
    let mut raw = vec![0.0f32; num_features * num_detections];
    for (row, values) in leading_rows.iter().enumerate() {
        raw[row * num_detections..(row + 1) * num_detections].copy_from_slice(values);
    }
    raw[6 * num_detections..].fill(0.1);

    let tensor = Array2::from_shape_vec((num_features, num_detections), raw).unwrap();
    let proto_grid = Array3::<f32>::zeros((16, 16, num_protos));

    let mut detections = Vec::with_capacity(10);
    let mut segments = Vec::with_capacity(10);
    decode_yolo_end_to_end_segdet_float(
        tensor.view(),
        proto_grid.view(),
        0.5,
        &mut detections,
        &mut segments,
    )
    .unwrap();

    assert_eq!(detections.len(), 1);
    assert_eq!(segments.len(), 1);
    assert_eq!(detections[0].label, 1);
    assert!((detections[0].score - 0.9).abs() < 0.01);
}
2345
#[test]
fn test_end_to_end_segdet_mask_coordinates() {
    let num_protos = 32;
    let num_features = 6 + num_protos;

    // Single candidate: box (0.2, 0.2)-(0.8, 0.8), score 0.95, label 3.
    // All coefficient rows stay zero.
    let mut raw = vec![0.0f32; num_features];
    raw[..6].copy_from_slice(&[0.2, 0.2, 0.8, 0.8, 0.95, 3.0]);
    let tensor = Array2::from_shape_vec((num_features, 1), raw).unwrap();
    let proto_grid = Array3::<f32>::zeros((16, 16, num_protos));

    let mut detections = Vec::with_capacity(10);
    let mut segments = Vec::with_capacity(10);
    decode_yolo_end_to_end_segdet_float(
        tensor.view(),
        proto_grid.view(),
        0.5,
        &mut detections,
        &mut segments,
    )
    .unwrap();

    assert_eq!(detections.len(), 1);
    assert_eq!(segments.len(), 1);

    let bbox = &detections[0].bbox;
    let seg = &segments[0];

    // The mask rectangle must enclose the box...
    assert!(seg.xmin <= bbox.xmin);
    assert!(seg.ymin <= bbox.ymin);
    assert!(seg.xmax >= bbox.xmax);
    assert!(seg.ymax >= bbox.ymax);

    // ...and overshoot it by less than one cell of the 16x16 proto grid.
    let cell = 1.0 / 16.0;
    assert!((bbox.xmin - seg.xmin) < cell);
    assert!((bbox.ymin - seg.ymin) < cell);
    assert!((seg.xmax - bbox.xmax) < cell);
    assert!((seg.ymax - bbox.ymax) < cell);
}
2390
#[test]
fn test_end_to_end_segdet_empty_output() {
    let num_protos = 32;
    // Correct row count, zero candidate columns.
    let tensor = Array2::<f32>::zeros((6 + num_protos, 0));
    let proto_grid = Array3::<f32>::zeros((16, 16, num_protos));

    let mut detections = Vec::with_capacity(10);
    let mut segments = Vec::with_capacity(10);
    decode_yolo_end_to_end_segdet_float(
        tensor.view(),
        proto_grid.view(),
        0.5,
        &mut detections,
        &mut segments,
    )
    .unwrap();

    assert_eq!(detections.len(), 0);
    assert_eq!(segments.len(), 0);
}
2411
#[test]
fn test_end_to_end_segdet_capacity_limit() {
    let num_protos = 32;
    let num_detections = 5;
    let num_features = 6 + num_protos;

    // Five candidates, each scoring 0.9, with a 0.2-wide box that moves
    // along the diagonal. Candidate `i` carries label `i`.
    let mut raw = vec![0.0f32; num_features * num_detections];
    for i in 0..num_detections {
        let lo = 0.1 * (i as f32);
        let hi = 0.1 * (i as f32) + 0.2;
        raw[i] = lo;
        raw[num_detections + i] = lo;
        raw[2 * num_detections + i] = hi;
        raw[3 * num_detections + i] = hi;
        raw[4 * num_detections + i] = 0.9;
        raw[5 * num_detections + i] = i as f32;
    }

    let tensor = Array2::from_shape_vec((num_features, num_detections), raw).unwrap();
    let proto_grid = Array3::<f32>::zeros((16, 16, num_protos));

    // The output vectors are created with room for two entries, and only
    // two of the five passing candidates are expected back.
    let mut detections = Vec::with_capacity(2);
    let mut segments = Vec::with_capacity(2);
    decode_yolo_end_to_end_segdet_float(
        tensor.view(),
        proto_grid.view(),
        0.5,
        &mut detections,
        &mut segments,
    )
    .unwrap();

    assert_eq!(detections.len(), 2);
    assert_eq!(segments.len(), 2);
}
2446
#[test]
fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
    // Six rows leave no room for mask coefficients; the error asks for seven.
    let too_short = Array2::<f32>::zeros((6, 3));
    let proto_grid = Array3::<f32>::zeros((16, 16, 32));

    let mut detections = Vec::with_capacity(10);
    let mut segments = Vec::with_capacity(10);
    let outcome = decode_yolo_end_to_end_segdet_float(
        too_short.view(),
        proto_grid.view(),
        0.5,
        &mut detections,
        &mut segments,
    );

    assert!(outcome.is_err());
    assert!(matches!(
        outcome,
        Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
    ));
}
2469
#[test]
fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
    let num_protos = 32;
    // The output carries 16 coefficient rows while the protos have 32 channels.
    let mismatched = Array2::<f32>::zeros((6 + 16, 3));
    let proto_grid = Array3::<f32>::zeros((16, 16, num_protos));

    let mut detections = Vec::with_capacity(10);
    let mut segments = Vec::with_capacity(10);
    let outcome = decode_yolo_end_to_end_segdet_float(
        mismatched.view(),
        proto_grid.view(),
        0.5,
        &mut detections,
        &mut segments,
    );

    assert!(outcome.is_err());
    assert!(matches!(
        outcome,
        Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
    ));
}
2493
#[test]
fn test_split_end_to_end_segdet_basic() {
    let num_protos = 32;
    let num_detections = 2;
    let num_features = 6 + num_protos;

    // Build the same fused (num_features, 2) buffer as the non-split test,
    // then slice it into the four tensors the split decoder expects.
    let leading_rows: [[f32; 2]; 6] = [
        [0.1, 0.5], // xmin
        [0.1, 0.5], // ymin
        [0.4, 0.9], // xmax
        [0.4, 0.9], // ymax
        [0.9, 0.3], // score: only the first candidate passes 0.5
        [1.0, 2.0], // label
    ];
    let mut raw = vec![0.0f32; num_features * num_detections];
    for (row, values) in leading_rows.iter().enumerate() {
        raw[row * num_detections..(row + 1) * num_detections].copy_from_slice(values);
    }
    raw[6 * num_detections..].fill(0.1);

    let fused = Array2::from_shape_vec((num_features, num_detections), raw).unwrap();
    let box_part = fused.slice(s![..4, ..]);
    let score_part = fused.slice(s![4..5, ..]);
    let class_part = fused.slice(s![5..6, ..]);
    let coeff_part = fused.slice(s![6.., ..]);
    let proto_grid = Array3::<f32>::zeros((16, 16, num_protos));

    let mut detections = Vec::with_capacity(10);
    let mut segments = Vec::with_capacity(10);
    decode_yolo_split_end_to_end_segdet_float(
        box_part,
        score_part,
        class_part,
        coeff_part,
        proto_grid.view(),
        0.5,
        &mut detections,
        &mut segments,
    )
    .unwrap();

    assert_eq!(detections.len(), 1);
    assert_eq!(segments.len(), 1);
    assert_eq!(detections[0].label, 1);
    assert!((detections[0].score - 0.9).abs() < 0.01);
}
2555
#[test]
fn test_segmentation_to_mask_basic() {
    // 4x4 single-channel input, listed row by row.
    #[rustfmt::skip]
    let data: Vec<u8> = vec![
        100, 200,  50, 150,
         10, 255, 128,  64,
          0, 127, 128, 255,
         64,  64, 192, 192,
    ];
    let segmentation = Array3::from_shape_vec((4, 4, 1), data).unwrap();

    let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();

    // ((row, col), expected mask value). The table includes a pixel equal to
    // the threshold (128 -> 1) and one just below it (127 -> 0).
    let expectations = [
        ((0, 0), 0),
        ((0, 1), 1),
        ((0, 2), 0),
        ((0, 3), 1),
        ((1, 1), 1),
        ((1, 2), 1),
        ((2, 0), 0),
        ((2, 1), 0),
    ];
    for ((row, col), expected) in expectations {
        assert_eq!(mask[[row, col]], expected);
    }
}
2583
#[test]
fn test_segmentation_to_mask_all_above() {
    // Every pixel (255) is above the 128 threshold, so no mask element may
    // differ from 1.
    let saturated = Array3::from_elem((4, 4, 1), 255u8);
    let mask = yolo_segmentation_to_mask(saturated.view(), 128).unwrap();
    assert!(!mask.iter().any(|&px| px != 1));
}
2590
#[test]
fn test_segmentation_to_mask_all_below() {
    // Every pixel (64) is below the 128 threshold, so no mask element may
    // differ from 0.
    let dim = Array3::from_elem((4, 4, 1), 64u8);
    let mask = yolo_segmentation_to_mask(dim.view(), 128).unwrap();
    assert!(!mask.iter().any(|&px| px != 0));
}
2597
#[test]
fn test_segmentation_to_mask_invalid_shape() {
    // A trailing dimension of 3 is not the single-channel layout the error
    // message refers to as "(H, W, 1)", so the call must fail.
    let three_channel = Array3::from_elem((4, 4, 3), 128u8);
    let result = yolo_segmentation_to_mask(three_channel.view(), 128);

    assert!(result.is_err());
    assert!(matches!(
        result,
        Err(crate::DecoderError::InvalidShape(s)) if s.contains("(H, W, 1)")
    ));
}
2609
#[test]
fn test_protobox_clamps_edge_coordinates() {
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let view = protos.view();

    // ROI whose max corner is exactly 1.0 on both axes.
    let (lo, hi) = (0.5, 1.0);
    let roi = BoundingBox {
        xmin: lo,
        ymin: lo,
        xmax: hi,
        ymax: hi,
    };

    let result = protobox(&view, &roi);
    assert!(result.is_ok(), "protobox should accept xmax=1.0");

    // The crop must be non-empty along the first two axes and keep all 4
    // channels on the last axis.
    let (cropped, _roi_norm) = result.unwrap();
    let dims = cropped.shape();
    assert!(dims[0] > 0);
    assert!(dims[1] > 0);
    assert_eq!(dims[2], 4);
}
2633
#[test]
fn test_protobox_rejects_wildly_out_of_range() {
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let view = protos.view();

    // A max corner of 3.0 must be reported as an "un-normalized" shape error.
    let (lo, hi) = (0.0, 3.0);
    let roi = BoundingBox {
        xmin: lo,
        ymin: lo,
        xmax: hi,
        ymax: hi,
    };

    let result = protobox(&view, &roi);
    let rejected = matches!(
        result,
        Err(crate::DecoderError::InvalidShape(s)) if s.contains("un-normalized")
    );
    assert!(rejected, "protobox should reject coords > NORM_LIMIT");
}
2651
#[test]
fn test_protobox_accepts_slightly_over_one() {
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let view = protos.view();

    // A max corner of 1.5 is above 1.0 but must still be accepted.
    let (lo, hi) = (0.0, 1.5);
    let roi = BoundingBox {
        xmin: lo,
        ymin: lo,
        xmax: hi,
        ymax: hi,
    };

    let result = protobox(&view, &roi);
    assert!(
        result.is_ok(),
        "protobox should accept coords <= NORM_LIMIT (2.0)"
    );

    // The crop is expected to span the full 16x16 proto grid.
    let (cropped, _roi_norm) = result.unwrap();
    let dims = cropped.shape();
    assert_eq!(dims[0], 16);
    assert_eq!(dims[1], 16);
}
2673
#[test]
fn test_segdet_float_proto_no_panic() {
    let num_proposals = 100;
    let num_classes = 80;
    let num_mask_coeffs = 32;
    let rows = 4 + num_classes + num_mask_coeffs;

    // Row-major (rows, num_proposals) buffer. The first five rows are filled
    // with one constant each, so every proposal column is identical; all
    // other rows stay zero.
    let mut data = vec![0.0f32; rows * num_proposals];
    let leading_row_values = [320.0f32, 320.0, 50.0, 50.0, 0.9];
    for (r, &value) in leading_row_values.iter().enumerate() {
        data[r * num_proposals..(r + 1) * num_proposals].fill(value);
    }
    let boxes = ndarray::Array2::from_shape_vec((rows, num_proposals), data).unwrap();

    let protos = ndarray::Array3::<f32>::zeros((160, 160, num_mask_coeffs));

    let mut output_boxes = Vec::with_capacity(300);

    let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
        boxes.view(),
        protos.view(),
        0.5,
        0.7,
        Some(Nms::default()),
        MAX_NMS_CANDIDATES,
        300,
        None,
        None,
        &mut output_boxes,
    );

    // At least one detection must come out, and the coefficient matrix must
    // have one row per detection and one column per mask coefficient.
    assert!(!output_boxes.is_empty());
    let coeffs_shape = proto_data.mask_coefficients.shape();
    assert_eq!(coeffs_shape[0], output_boxes.len());
    assert_eq!(coeffs_shape[1], num_mask_coeffs);
}
2726
#[test]
fn test_pre_nms_cap_truncates_excess_candidates() {
    let n: usize = 50_000;
    let num_classes = 1;

    // n copies of the same box, with scores that decrease slowly from 0.99.
    let boxes_data = [0.1f32, 0.1, 0.5, 0.5].repeat(n);
    let scores_data: Vec<f32> = (0..n).map(|i| 0.99 - (i as f32) * 1e-7).collect();

    let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
    let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();

    // NOTE(review): the 0.1 / 1.0 arguments are presumably the score and IoU
    // thresholds, and `usize::MAX` the post-NMS detection limit — confirm
    // against the callee's signature.
    let result = impl_yolo_segdet_get_boxes::<XYXY, _, _>(
        boxes.view(),
        scores.view(),
        0.1,
        1.0,
        Some(Nms::ClassAgnostic),
        crate::yolo::MAX_NMS_CANDIDATES,
        usize::MAX,
    );

    // More candidates go in than the cap allows; exactly the cap must remain.
    assert_eq!(
        result.len(),
        crate::yolo::MAX_NMS_CANDIDATES,
        "pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
        result.len()
    );

    // The first survivor must be one of the top-scoring inputs.
    let top_score = result[0].0.score;
    assert!(
        top_score > 0.98,
        "highest-ranked survivor should have the largest score, got {top_score}"
    );
}
2784
#[test]
fn test_pre_nms_cap_truncates_excess_candidates_quant() {
    use crate::Quantization;
    let n: usize = 50_000;
    let num_classes = 1;

    // n copies of the same quantized box.
    let boxes = Array2::from_shape_vec((n, 4), [10i8, 10, 50, 50].repeat(n)).unwrap();
    let quant_boxes = Quantization {
        scale: 0.01,
        zero_point: 0,
    };

    // Scores cycle through 250, 249, ..., 51 with a period of 200.
    let mut scores_data: Vec<u8> = Vec::with_capacity(n);
    for i in 0..n {
        scores_data.push(250u8.saturating_sub((i % 200) as u8));
    }
    let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();
    let quant_scores = Quantization {
        scale: 0.00392,
        zero_point: 0,
    };

    // NOTE(review): the 0.1 / 1.0 arguments are presumably the score and IoU
    // thresholds, and `usize::MAX` the post-NMS detection limit — confirm
    // against the callee's signature.
    let result = impl_yolo_split_segdet_quant_get_boxes::<XYXY, _, _>(
        (boxes.view(), quant_boxes),
        (scores.view(), quant_scores),
        0.1,
        1.0,
        Some(Nms::ClassAgnostic),
        crate::yolo::MAX_NMS_CANDIDATES,
        usize::MAX,
    );

    // More candidates go in than the cap allows; exactly the cap must remain.
    assert_eq!(
        result.len(),
        crate::yolo::MAX_NMS_CANDIDATES,
        "quant path pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
        result.len()
    );
}
2834
#[test]
fn segdet_combined_tensor_pairs_detection_with_matching_mask_row() {
    // Combined tensor of shape (feat, n), stored row-major, with
    // feat = 4 + nc + nm and three anchors (columns).
    let nc = 2;
    let nm = 2;
    let n = 3;
    let feat = 4 + nc + nm;
    let mut data = vec![0.0f32; feat * n];
    // Writes value `v` at row `r`, column `c` of the flat row-major buffer.
    let set = |d: &mut [f32], r: usize, c: usize, v: f32| d[r * n + c] = v;

    // Rows 0..4: box values. Anchor 0 sits near 0.2, anchor 1 near 0.5 and
    // anchor 2 near 0.8 (presumably centre x/y followed by width/height; the
    // centre checks below rely on anchor 0 being left of 0.3 and anchor 2
    // right of 0.7).
    set(&mut data, 0, 0, 0.2);
    set(&mut data, 1, 0, 0.2);
    set(&mut data, 2, 0, 0.1);
    set(&mut data, 3, 0, 0.1);
    set(&mut data, 0, 1, 0.5);
    set(&mut data, 1, 1, 0.5);
    set(&mut data, 2, 1, 0.1);
    set(&mut data, 3, 1, 0.1);
    set(&mut data, 0, 2, 0.8);
    set(&mut data, 1, 2, 0.8);
    set(&mut data, 2, 2, 0.1);
    set(&mut data, 3, 2, 0.1);
    // Row 4: anchors 0 and 2 score above the 0.5 threshold; anchor 1 stays 0.
    set(&mut data, 4, 0, 0.9);
    set(&mut data, 4, 2, 0.8);
    // Rows 6..8: anchor 0 gets strongly positive values and anchor 2 strongly
    // negative ones, so their masks are distinguishable by mean value.
    set(&mut data, 6, 0, 3.0);
    set(&mut data, 7, 0, 3.0);
    set(&mut data, 6, 2, -3.0);
    set(&mut data, 7, 2, -3.0);

    let output = Array2::from_shape_vec((feat, n), data).unwrap();
    let protos = Array3::<f32>::from_elem((8, 8, nm), 1.0);

    let mut boxes: Vec<DetectBox> = Vec::with_capacity(10);
    let mut masks: Vec<Segmentation> = Vec::with_capacity(10);
    decode_yolo_segdet_float(
        output.view(),
        protos.view(),
        0.5,
        0.5,
        Some(Nms::ClassAgnostic),
        &mut boxes,
        &mut masks,
    )
    .unwrap();

    assert_eq!(
        boxes.len(),
        2,
        "two anchors above threshold should survive (a0 score=0.9, a2 score=0.8); got {}",
        boxes.len()
    );
    // `zip` stops at the shorter input, so without this check a missing mask
    // would make the pairing loop below skip detections and pass vacuously.
    assert_eq!(
        masks.len(),
        boxes.len(),
        "every surviving detection should come with exactly one mask; got {} masks for {} boxes",
        masks.len(),
        boxes.len()
    );

    for (b, m) in boxes.iter().zip(masks.iter()) {
        // Identify the originating anchor by the box's horizontal centre,
        // then check that the paired mask has the matching polarity.
        let cx = (b.bbox.xmin + b.bbox.xmax) * 0.5;
        let mean = {
            let s = &m.segmentation;
            let total: u32 = s.iter().map(|&v| v as u32).sum();
            total as f32 / s.len() as f32
        };
        if cx < 0.3 {
            assert!(
                mean > 200.0,
                "anchor 0 detection (centre {cx:.2}) should have high-value mask; got mean {mean}"
            );
        } else if cx > 0.7 {
            assert!(
                mean < 50.0,
                "anchor 2 detection (centre {cx:.2}) should have low-value mask; got mean {mean}"
            );
        } else {
            panic!("unexpected detection centre {cx:.2}");
        }
    }
}
2945
2946 fn make_float_boxes(scores: &[f32]) -> Vec<(DetectBox, ())> {
2952 scores
2953 .iter()
2954 .enumerate()
2955 .map(|(i, &s)| {
2956 (
2957 DetectBox {
2958 bbox: BoundingBox {
2959 xmin: 0.0,
2960 ymin: 0.0,
2961 xmax: 1.0,
2962 ymax: 1.0,
2963 },
2964 score: s,
2965 label: i,
2966 },
2967 (),
2968 )
2969 })
2970 .collect()
2971 }
2972
2973 fn make_quant_boxes(scores: &[i8]) -> Vec<(DetectBoxQuantized<i8>, ())> {
2975 scores
2976 .iter()
2977 .enumerate()
2978 .map(|(i, &s)| {
2979 (
2980 DetectBoxQuantized {
2981 bbox: BoundingBox {
2982 xmin: 0.0,
2983 ymin: 0.0,
2984 xmax: 1.0,
2985 ymax: 1.0,
2986 },
2987 score: s,
2988 label: i,
2989 },
2990 (),
2991 )
2992 })
2993 .collect()
2994 }
2995
#[test]
fn truncate_float_top_k_zero_is_unbounded() {
    // `make_float_boxes` yields one box per score, so the score count is the
    // length before truncation.
    let scores = [0.9f32, 0.1, 0.5, 0.3, 0.7];
    let mut boxes = make_float_boxes(&scores);

    truncate_to_top_k_by_score(&mut boxes, 0);

    assert_eq!(
        boxes.len(),
        scores.len(),
        "top_k=0 should keep all candidates (no-limit semantics)"
    );
}
3007
#[test]
fn truncate_float_top_k_normal() {
    let mut boxes = make_float_boxes(&[0.9, 0.1, 0.5, 0.3, 0.7]);
    truncate_to_top_k_by_score(&mut boxes, 3);
    assert_eq!(boxes.len(), 3);

    // The test does not rely on the order of the survivors, so sort the
    // retained scores into descending order before comparing.
    let mut kept_scores: Vec<f32> = Vec::with_capacity(boxes.len());
    for (det, _) in &boxes {
        kept_scores.push(det.score);
    }
    kept_scores.sort_by(f32::total_cmp);
    kept_scores.reverse();
    assert_eq!(kept_scores, vec![0.9, 0.7, 0.5]);
}
3018
#[test]
fn truncate_float_top_k_noop_when_under_cap() {
    // Two candidates with a cap of ten: nothing may be dropped.
    let cap = 10;
    let mut boxes = make_float_boxes(&[0.9, 0.5]);
    truncate_to_top_k_by_score(&mut boxes, cap);
    let remaining = boxes.len();
    assert_eq!(remaining, 2, "should be no-op when len <= top_k");
}
3025
#[test]
fn truncate_quant_top_k_zero_is_unbounded() {
    // `make_quant_boxes` yields one box per score, so the score count is the
    // length before truncation.
    let scores = [120i8, -50, 30, -10, 80];
    let mut boxes = make_quant_boxes(&scores);

    truncate_to_top_k_by_score_quant(&mut boxes, 0);

    assert_eq!(
        boxes.len(),
        scores.len(),
        "top_k=0 should keep all candidates (no-limit semantics)"
    );
}
3037
#[test]
fn truncate_quant_top_k_normal() {
    let mut boxes = make_quant_boxes(&[120, -50, 30, -10, 80]);
    truncate_to_top_k_by_score_quant(&mut boxes, 3);
    assert_eq!(boxes.len(), 3);

    // The test does not rely on the order of the survivors, so sort the
    // retained scores into descending order before comparing.
    let mut kept_scores: Vec<i8> = Vec::with_capacity(boxes.len());
    for (det, _) in &boxes {
        kept_scores.push(det.score);
    }
    kept_scores.sort();
    kept_scores.reverse();
    assert_eq!(kept_scores, vec![120, 80, 30]);
}
3047
#[test]
fn truncate_quant_top_k_noop_when_under_cap() {
    // Two candidates with a cap of ten: nothing may be dropped.
    let cap = 10;
    let mut boxes = make_quant_boxes(&[120, 80]);
    truncate_to_top_k_by_score_quant(&mut boxes, cap);
    let remaining = boxes.len();
    assert_eq!(remaining, 2, "should be no-op when len <= top_k");
}
3054}