1use std::fmt::Debug;
13
14use ndarray::{
15 parallel::prelude::{IntoParallelIterator, ParallelIterator},
16 s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
17};
18use num_traits::{AsPrimitive, Float, PrimInt, Signed};
19
20use crate::{
21 byte::{
22 nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
23 postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
24 },
25 configs::Nms,
26 dequant_detect_box,
27 float::{
28 nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
29 postprocess_boxes_float, postprocess_boxes_index_float,
30 },
31 BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoLayout,
32 Quantization, Segmentation, XYWH, XYXY,
33};
34
/// Candidate limit that the test-only `decode_yolo_segdet_*` wrappers in this
/// file pass as the `pre_nms_top_k` argument of the segmentation decoders.
#[cfg(test)]
pub(crate) const MAX_NMS_CANDIDATES: usize = 30_000;
54
/// Fallback detection limit returned by `cap_or_default` when the caller's
/// output vector reports a capacity of zero.
pub(crate) const DEFAULT_MAX_DETECTIONS: usize = 300;
62
63fn truncate_to_top_k_by_score<E: Send>(boxes: &mut Vec<(DetectBox, E)>, top_k: usize) {
68 if top_k > 0 && boxes.len() > top_k {
69 boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.total_cmp(&a.0.score));
70 boxes.truncate(top_k);
71 }
72}
73
74fn truncate_to_top_k_by_score_quant<S: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send>(
79 boxes: &mut Vec<(DetectBoxQuantized<S>, E)>,
80 top_k: usize,
81) {
82 if top_k > 0 && boxes.len() > top_k {
83 boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.cmp(&a.0.score));
84 boxes.truncate(top_k);
85 }
86}
87
88fn dispatch_nms_float(
95 nms: Option<Nms>,
96 iou: f32,
97 max_det: Option<usize>,
98 boxes: Vec<DetectBox>,
99) -> Vec<DetectBox> {
100 match nms {
101 Some(Nms::ClassAgnostic | Nms::Auto) => nms_float(iou, max_det, boxes),
102 Some(Nms::ClassAware) => nms_class_aware_float(iou, max_det, boxes),
103 None => boxes, }
105}
106
107pub(super) fn dispatch_nms_extra_float<E: Send + Sync>(
110 nms: Option<Nms>,
111 iou: f32,
112 max_det: Option<usize>,
113 boxes: Vec<(DetectBox, E)>,
114) -> Vec<(DetectBox, E)> {
115 match nms {
116 Some(Nms::ClassAgnostic | Nms::Auto) => nms_extra_float(iou, max_det, boxes),
117 Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, max_det, boxes),
118 None => boxes, }
120}
121
122fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
125 nms: Option<Nms>,
126 iou: f32,
127 max_det: Option<usize>,
128 boxes: Vec<DetectBoxQuantized<SCORE>>,
129) -> Vec<DetectBoxQuantized<SCORE>> {
130 match nms {
131 Some(Nms::ClassAgnostic | Nms::Auto) => nms_int(iou, max_det, boxes),
132 Some(Nms::ClassAware) => nms_class_aware_int(iou, max_det, boxes),
133 None => boxes, }
135}
136
137fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
140 nms: Option<Nms>,
141 iou: f32,
142 max_det: Option<usize>,
143 boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
144) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
145 match nms {
146 Some(Nms::ClassAgnostic | Nms::Auto) => nms_extra_int(iou, max_det, boxes),
147 Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, max_det, boxes),
148 None => boxes, }
150}
151
152#[inline]
159fn cap_or_default<T>(v: &Vec<T>) -> usize {
160 if v.capacity() > 0 {
161 v.capacity()
162 } else {
163 DEFAULT_MAX_DETECTIONS
164 }
165}
166
/// Decodes a quantized YOLO detection tensor into `output_boxes`.
///
/// Thin wrapper that forwards every argument to `impl_yolo_quant` with the
/// box format fixed to `XYWH`.
///
/// * `output` - quantized tensor together with its `Quantization` parameters.
/// * `score_threshold`, `iou_threshold`, `nms` - forwarded unchanged.
/// * `output_boxes` - destination vector; `impl_yolo_quant` clears it and
///   uses its capacity (via `cap_or_default`) as the detection limit.
pub(crate) fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
    output: (ArrayView2<BOX>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
) where
    f32: AsPrimitive<BOX>,
{
    impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
}
211
/// Decodes a floating-point YOLO detection tensor into `output_boxes`.
///
/// Thin wrapper that forwards every argument to `impl_yolo_float` with the
/// box format fixed to `XYWH`.
///
/// * `output` - float detection tensor.
/// * `score_threshold`, `iou_threshold`, `nms` - forwarded unchanged.
/// * `output_boxes` - destination vector; `impl_yolo_float` clears it and
///   uses its capacity (via `cap_or_default`) as the detection limit.
pub(crate) fn decode_yolo_det_float<T>(
    output: ArrayView2<T>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
) where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
}
230
/// Test-only convenience wrapper around `impl_yolo_segdet_quant` using the
/// `XYWH` box format.
///
/// Fills in the arguments the wrapper does not expose: `MAX_NMS_CANDIDATES`
/// as the pre-NMS top-k, the capacity-derived limit from `cap_or_default` as
/// `max_det`, and `None` for both `normalized` and `input_dims`.
///
/// # Errors
/// Propagates any `DecoderError` returned by `impl_yolo_segdet_quant`.
#[cfg(test)]
pub(crate) fn decode_yolo_segdet_quant<
    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    protos: (ArrayView3<PROTO>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    f32: AsPrimitive<BOX>,
{
    // The detection limit is taken from the output vector's capacity.
    let cap = cap_or_default(output_boxes);
    impl_yolo_segdet_quant::<XYWH, _, _>(
        boxes,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        MAX_NMS_CANDIDATES,
        cap,
        None,
        None,
        output_boxes,
        output_masks,
    )
}
279
/// Test-only convenience wrapper around `impl_yolo_segdet_float` using the
/// `XYWH` box format.
///
/// Fills in the arguments the wrapper does not expose: `MAX_NMS_CANDIDATES`
/// as the pre-NMS top-k, the capacity-derived limit from `cap_or_default` as
/// `max_det`, and `None` for both `normalized` and `input_dims`.
///
/// # Errors
/// Propagates any `DecoderError` returned by `impl_yolo_segdet_float`.
#[cfg(test)]
pub(crate) fn decode_yolo_segdet_float<T>(
    boxes: ArrayView2<T>,
    protos: ArrayView3<T>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    // The detection limit is taken from the output vector's capacity.
    let cap = cap_or_default(output_boxes);
    impl_yolo_segdet_float::<XYWH, _, _>(
        boxes,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        MAX_NMS_CANDIDATES,
        cap,
        None,
        None,
        output_boxes,
        output_masks,
    )
}
312
/// Decodes quantized YOLO detections delivered as separate box and score
/// tensors into `output_boxes`.
///
/// Thin wrapper that forwards every argument to `impl_yolo_split_quant` with
/// the box format fixed to `XYWH`.
///
/// * `boxes` / `scores` - quantized tensors, each paired with its own
///   `Quantization` parameters.
/// * `score_threshold`, `iou_threshold`, `nms` - forwarded unchanged.
/// * `output_boxes` - destination vector; `impl_yolo_split_quant` clears it
///   and uses its capacity (via `cap_or_default`) as the detection limit.
pub(crate) fn decode_yolo_split_det_quant<
    BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    scores: (ArrayView2<SCORE>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
) where
    f32: AsPrimitive<SCORE>,
{
    impl_yolo_split_quant::<XYWH, _, _>(
        boxes,
        scores,
        score_threshold,
        iou_threshold,
        nms,
        output_boxes,
    );
}
346
/// Decodes floating-point YOLO detections delivered as separate box and
/// score tensors into `output_boxes`.
///
/// Thin wrapper that forwards every argument to `impl_yolo_split_float` with
/// the box format fixed to `XYWH`.
///
/// * `boxes` / `scores` - float tensors of the same element type `T`.
/// * `score_threshold`, `iou_threshold`, `nms` - forwarded unchanged.
/// * `output_boxes` - destination vector; `impl_yolo_split_float` clears it
///   and uses its capacity (via `cap_or_default`) as the detection limit.
pub(crate) fn decode_yolo_split_det_float<T>(
    boxes: ArrayView2<T>,
    scores: ArrayView2<T>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
) where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    impl_yolo_split_float::<XYWH, _, _>(
        boxes,
        scores,
        score_threshold,
        iou_threshold,
        nms,
        output_boxes,
    );
}
378
379pub(crate) fn decode_yolo_end_to_end_det_float<T>(
394 output: ArrayView2<T>,
395 score_threshold: f32,
396 output_boxes: &mut Vec<DetectBox>,
397) -> Result<(), crate::DecoderError>
398where
399 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
400 f32: AsPrimitive<T>,
401{
402 if output.shape()[0] < 6 {
404 return Err(crate::DecoderError::InvalidShape(format!(
405 "End-to-end detection output requires at least 6 rows, got {}",
406 output.shape()[0]
407 )));
408 }
409
410 let boxes = output.slice(s![0..4, ..]).reversed_axes();
412 let scores = output.slice(s![4..5, ..]).reversed_axes();
413 let classes = output.slice(s![5, ..]);
414 let mut boxes =
415 postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
416 boxes.truncate(cap_or_default(output_boxes));
417 output_boxes.clear();
418 for (mut b, i) in boxes.into_iter() {
419 b.label = classes[i].as_() as usize;
420 output_boxes.push(b);
421 }
422 Ok(())
424}
425
/// Decodes an end-to-end YOLO segmentation+detection tensor into boxes and
/// masks.
///
/// `postprocess_yolo_end_to_end_segdet` validates and splits `output` into
/// box, score, class and mask-coefficient views (using the last dimension of
/// `protos` as the expected coefficient count). Boxes are then selected with
/// `impl_yolo_end_to_end_segdet_get_boxes` (`XYXY` format, limited to
/// `cap_or_default(output_boxes)`), and masks are produced by
/// `impl_yolo_split_segdet_process_masks`, which overwrites both output
/// vectors.
///
/// # Errors
/// Propagates `DecoderError` from shape validation or mask decoding.
pub(crate) fn decode_yolo_end_to_end_segdet_float<T>(
    output: ArrayView2<T>,
    protos: ArrayView3<T>,
    score_threshold: f32,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<crate::Segmentation>,
) -> Result<(), crate::DecoderError>
where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    let (boxes, scores, classes, mask_coeff) =
        postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
    // The detection limit is taken from the output vector's capacity.
    let cap = cap_or_default(output_boxes);
    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
        boxes,
        scores,
        classes,
        score_threshold,
        cap,
    );

    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
}
469
470pub(crate) fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
479 boxes: ArrayView2<T>,
480 scores: ArrayView2<T>,
481 classes: ArrayView2<T>,
482 score_threshold: f32,
483 output_boxes: &mut Vec<DetectBox>,
484) -> Result<(), crate::DecoderError> {
485 let n = boxes.shape()[1];
486
487 let cap = cap_or_default(output_boxes);
488 output_boxes.clear();
489
490 let (boxes, scores, classes) = postprocess_yolo_split_end_to_end_det(boxes, scores, &classes)?;
491
492 for i in 0..n {
493 let score: f32 = scores[[i, 0]].as_();
494 if score < score_threshold {
495 continue;
496 }
497 if output_boxes.len() >= cap {
498 break;
499 }
500 output_boxes.push(DetectBox {
501 bbox: BoundingBox {
502 xmin: boxes[[i, 0]].as_(),
503 ymin: boxes[[i, 1]].as_(),
504 xmax: boxes[[i, 2]].as_(),
505 ymax: boxes[[i, 3]].as_(),
506 },
507 score,
508 label: classes[i].as_() as usize,
509 });
510 }
511 Ok(())
512}
513
/// Decodes end-to-end YOLO segmentation+detection output delivered as
/// separate box, score, class and mask-coefficient tensors.
///
/// `postprocess_yolo_split_end_to_end_segdet` validates the shapes and
/// transposes the views. Boxes are then selected with
/// `impl_yolo_end_to_end_segdet_get_boxes` (`XYXY` format, limited to
/// `cap_or_default(output_boxes)`), and masks are produced by
/// `impl_yolo_split_segdet_process_masks`, which overwrites both output
/// vectors.
///
/// # Errors
/// Propagates `DecoderError` from shape validation or mask decoding.
#[allow(clippy::too_many_arguments)]
pub(crate) fn decode_yolo_split_end_to_end_segdet_float<T>(
    boxes: ArrayView2<T>,
    scores: ArrayView2<T>,
    classes: ArrayView2<T>,
    mask_coeff: ArrayView2<T>,
    protos: ArrayView3<T>,
    score_threshold: f32,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<crate::Segmentation>,
) -> Result<(), crate::DecoderError>
where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    let (boxes, scores, classes, mask_coeff) =
        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
    // The detection limit is taken from the output vector's capacity.
    let cap = cap_or_default(output_boxes);
    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
        boxes,
        scores,
        classes,
        score_threshold,
        cap,
    );

    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
}
550
551#[allow(clippy::type_complexity)]
552pub(crate) fn postprocess_yolo_end_to_end_segdet<'a, T>(
553 output: &'a ArrayView2<'_, T>,
554 num_protos: usize,
555) -> Result<
556 (
557 ArrayView2<'a, T>,
558 ArrayView2<'a, T>,
559 ArrayView1<'a, T>,
560 ArrayView2<'a, T>,
561 ),
562 crate::DecoderError,
563> {
564 if output.shape()[0] < 7 {
566 return Err(crate::DecoderError::InvalidShape(format!(
567 "End-to-end segdet output requires at least 7 rows, got {}",
568 output.shape()[0]
569 )));
570 }
571
572 let num_mask_coeffs = output.shape()[0] - 6;
573 if num_mask_coeffs != num_protos {
574 return Err(crate::DecoderError::InvalidShape(format!(
575 "Mask coefficients count ({}) doesn't match protos count ({})",
576 num_mask_coeffs, num_protos
577 )));
578 }
579
580 let boxes = output.slice(s![0..4, ..]).reversed_axes();
582 let scores = output.slice(s![4..5, ..]).reversed_axes();
583 let classes = output.slice(s![5, ..]);
584 let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
585 Ok((boxes, scores, classes, mask_coeff))
586}
587
588#[allow(clippy::type_complexity)]
595pub(crate) fn postprocess_yolo_split_end_to_end_det<'a, 'b, 'c, BOXES, SCORES, CLASS>(
596 boxes: ArrayView2<'a, BOXES>,
597 scores: ArrayView2<'b, SCORES>,
598 classes: &'c ArrayView2<CLASS>,
599) -> Result<
600 (
601 ArrayView2<'a, BOXES>,
602 ArrayView2<'b, SCORES>,
603 ArrayView1<'c, CLASS>,
604 ),
605 crate::DecoderError,
606> {
607 let num_boxes = boxes.shape()[1];
608 if boxes.shape()[0] != 4 {
609 return Err(crate::DecoderError::InvalidShape(format!(
610 "Split end-to-end box_coords must be 4, got {}",
611 boxes.shape()[0]
612 )));
613 }
614
615 if scores.shape()[0] != 1 {
616 return Err(crate::DecoderError::InvalidShape(format!(
617 "Split end-to-end scores num_classes must be 1, got {}",
618 scores.shape()[0]
619 )));
620 }
621
622 if classes.shape()[0] != 1 {
623 return Err(crate::DecoderError::InvalidShape(format!(
624 "Split end-to-end classes num_classes must be 1, got {}",
625 classes.shape()[0]
626 )));
627 }
628
629 if scores.shape()[1] != num_boxes {
630 return Err(crate::DecoderError::InvalidShape(format!(
631 "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
632 num_boxes,
633 scores.shape()[1]
634 )));
635 }
636
637 if classes.shape()[1] != num_boxes {
638 return Err(crate::DecoderError::InvalidShape(format!(
639 "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
640 num_boxes,
641 classes.shape()[1]
642 )));
643 }
644
645 let boxes = boxes.reversed_axes();
646 let scores = scores.reversed_axes();
647 let classes = classes.slice(s![0, ..]);
648 Ok((boxes, scores, classes))
649}
650
651#[allow(clippy::type_complexity)]
654pub(crate) fn postprocess_yolo_split_end_to_end_segdet<
655 'a,
656 'b,
657 'c,
658 'd,
659 BOXES,
660 SCORES,
661 CLASS,
662 MASK,
663>(
664 boxes: ArrayView2<'a, BOXES>,
665 scores: ArrayView2<'b, SCORES>,
666 classes: &'c ArrayView2<CLASS>,
667 mask_coeff: ArrayView2<'d, MASK>,
668) -> Result<
669 (
670 ArrayView2<'a, BOXES>,
671 ArrayView2<'b, SCORES>,
672 ArrayView1<'c, CLASS>,
673 ArrayView2<'d, MASK>,
674 ),
675 crate::DecoderError,
676> {
677 let num_boxes = boxes.shape()[1];
678 if boxes.shape()[0] != 4 {
679 return Err(crate::DecoderError::InvalidShape(format!(
680 "Split end-to-end box_coords must be 4, got {}",
681 boxes.shape()[0]
682 )));
683 }
684
685 if scores.shape()[0] != 1 {
686 return Err(crate::DecoderError::InvalidShape(format!(
687 "Split end-to-end scores num_classes must be 1, got {}",
688 scores.shape()[0]
689 )));
690 }
691
692 if classes.shape()[0] != 1 {
693 return Err(crate::DecoderError::InvalidShape(format!(
694 "Split end-to-end classes num_classes must be 1, got {}",
695 classes.shape()[0]
696 )));
697 }
698
699 if scores.shape()[1] != num_boxes {
700 return Err(crate::DecoderError::InvalidShape(format!(
701 "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
702 num_boxes,
703 scores.shape()[1]
704 )));
705 }
706
707 if classes.shape()[1] != num_boxes {
708 return Err(crate::DecoderError::InvalidShape(format!(
709 "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
710 num_boxes,
711 classes.shape()[1]
712 )));
713 }
714
715 if mask_coeff.shape()[1] != num_boxes {
716 return Err(crate::DecoderError::InvalidShape(format!(
717 "Split end-to-end mask_coeff must have same num_boxes as boxes ({}), got {}",
718 num_boxes,
719 mask_coeff.shape()[1]
720 )));
721 }
722
723 let boxes = boxes.reversed_axes();
724 let scores = scores.reversed_axes();
725 let classes = classes.slice(s![0, ..]);
726 let mask_coeff = mask_coeff.reversed_axes();
727 Ok((boxes, scores, classes, mask_coeff))
728}
729pub(crate) fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
734 output: (ArrayView2<T>, Quantization),
735 score_threshold: f32,
736 iou_threshold: f32,
737 nms: Option<Nms>,
738 output_boxes: &mut Vec<DetectBox>,
739) where
740 f32: AsPrimitive<T>,
741{
742 let _span = tracing::trace_span!("decoder.decode.yolo_quant_flat").entered();
743 let (boxes, quant_boxes) = output;
744 let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
745
746 let boxes = {
747 let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
748 postprocess_boxes_quant::<B, _, _>(
749 score_threshold,
750 boxes_tensor,
751 scores_tensor,
752 quant_boxes,
753 )
754 };
755
756 let cap = cap_or_default(output_boxes);
757 let boxes = dispatch_nms_int(nms, iou_threshold, Some(cap), boxes);
758 let len = cap.min(boxes.len());
761 output_boxes.clear();
762 for b in boxes.iter().take(len) {
763 output_boxes.push(dequant_detect_box(b, quant_boxes));
764 }
765}
766
767pub(crate) fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
772 output: ArrayView2<T>,
773 score_threshold: f32,
774 iou_threshold: f32,
775 nms: Option<Nms>,
776 output_boxes: &mut Vec<DetectBox>,
777) where
778 f32: AsPrimitive<T>,
779{
780 let _span = tracing::trace_span!("decoder.decode.yolo_float_flat").entered();
781 let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
782 let boxes =
783 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
784 let cap = cap_or_default(output_boxes);
785 let boxes = dispatch_nms_float(nms, iou_threshold, Some(cap), boxes);
786 let len = cap.min(boxes.len());
789 output_boxes.clear();
790 for b in boxes.into_iter().take(len) {
791 output_boxes.push(b);
792 }
793}
794
795pub(crate) fn impl_yolo_split_quant<
805 B: BBoxTypeTrait,
806 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
807 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
808>(
809 boxes: (ArrayView2<BOX>, Quantization),
810 scores: (ArrayView2<SCORE>, Quantization),
811 score_threshold: f32,
812 iou_threshold: f32,
813 nms: Option<Nms>,
814 output_boxes: &mut Vec<DetectBox>,
815) where
816 f32: AsPrimitive<SCORE>,
817{
818 let _span = tracing::trace_span!("decoder.decode.yolo_quant_split").entered();
819 let (boxes_tensor, quant_boxes) = boxes;
820 let (scores_tensor, quant_scores) = scores;
821
822 let boxes_tensor = boxes_tensor.reversed_axes();
823 let scores_tensor = scores_tensor.reversed_axes();
824
825 let boxes = {
826 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
827 postprocess_boxes_quant::<B, _, _>(
828 score_threshold,
829 boxes_tensor,
830 scores_tensor,
831 quant_boxes,
832 )
833 };
834
835 let cap = cap_or_default(output_boxes);
836 let boxes = dispatch_nms_int(nms, iou_threshold, Some(cap), boxes);
837 let len = cap.min(boxes.len());
840 output_boxes.clear();
841 for b in boxes.iter().take(len) {
842 output_boxes.push(dequant_detect_box(b, quant_scores));
843 }
844}
845
846pub(crate) fn impl_yolo_split_float<
855 B: BBoxTypeTrait,
856 BOX: Float + AsPrimitive<f32> + Send + Sync,
857 SCORE: Float + AsPrimitive<f32> + Send + Sync,
858>(
859 boxes_tensor: ArrayView2<BOX>,
860 scores_tensor: ArrayView2<SCORE>,
861 score_threshold: f32,
862 iou_threshold: f32,
863 nms: Option<Nms>,
864 output_boxes: &mut Vec<DetectBox>,
865) where
866 f32: AsPrimitive<SCORE>,
867{
868 let _span = tracing::trace_span!("decoder.decode.yolo_float_split").entered();
869 let boxes_tensor = boxes_tensor.reversed_axes();
870 let scores_tensor = scores_tensor.reversed_axes();
871 let boxes =
872 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
873 let cap = cap_or_default(output_boxes);
874 let boxes = dispatch_nms_float(nms, iou_threshold, Some(cap), boxes);
875 let len = cap.min(boxes.len());
878 output_boxes.clear();
879 for b in boxes.into_iter().take(len) {
880 output_boxes.push(b);
881 }
882}
883
884#[inline]
892pub(crate) fn maybe_normalize_boxes_in_place(
893 boxes: &mut [(DetectBox, usize)],
894 normalized: Option<bool>,
895 input_dims: Option<(usize, usize)>,
896) {
897 if normalized != Some(false) {
898 return;
899 }
900 let Some((w, h)) = input_dims else {
901 return;
902 };
903 if w == 0 || h == 0 {
904 return;
905 }
906 let inv_w = 1.0 / w as f32;
907 let inv_h = 1.0 / h as f32;
908 for (b, _) in boxes.iter_mut() {
909 b.bbox.xmin *= inv_w;
910 b.bbox.ymin *= inv_h;
911 b.bbox.xmax *= inv_w;
912 b.bbox.ymax *= inv_h;
913 }
914}
915
/// Decodes a combined quantized YOLO segmentation+detection tensor into
/// boxes and masks.
///
/// `postprocess_yolo_seg` splits `boxes` into box, score and
/// mask-coefficient views, using the last dimension of the proto tensor as
/// the coefficient count. Detections are selected by
/// `impl_yolo_split_segdet_quant_get_boxes`, optionally rescaled by
/// `maybe_normalize_boxes_in_place`, and the masks are produced by
/// `impl_yolo_split_segdet_quant_process_masks`, which overwrites both
/// output vectors.
///
/// * `pre_nms_top_k` - candidate limit applied before NMS (only when `nms`
///   is `Some`; `0` disables it).
/// * `max_det` - maximum number of detections returned.
/// * `normalized` / `input_dims` - forwarded to
///   `maybe_normalize_boxes_in_place`.
///
/// # Errors
/// Propagates any `DecoderError` from mask decoding.
#[allow(clippy::too_many_arguments)]
pub(crate) fn impl_yolo_segdet_quant<
    B: BBoxTypeTrait,
    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    protos: (ArrayView3<PROTO>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    pre_nms_top_k: usize,
    max_det: usize,
    normalized: Option<bool>,
    input_dims: Option<(usize, usize)>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    f32: AsPrimitive<BOX>,
{
    let (boxes, quant_boxes) = boxes;
    let num_protos = protos.0.dim().2;

    // All three views come from the same tensor, so they all share
    // `quant_boxes` as their quantization.
    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
    let mut boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
        (boxes_tensor, quant_boxes),
        (scores_tensor, quant_boxes),
        score_threshold,
        iou_threshold,
        nms,
        pre_nms_top_k,
        max_det,
    );
    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);

    impl_yolo_split_segdet_quant_process_masks::<_, _>(
        boxes,
        (mask_tensor, quant_boxes),
        protos,
        output_boxes,
        output_masks,
    )
}
969
/// Decodes a combined floating-point YOLO segmentation+detection tensor into
/// boxes and masks.
///
/// `postprocess_yolo_seg` splits `boxes` into box, score and
/// mask-coefficient views, using the last dimension of `protos` as the
/// coefficient count. Detections are selected by
/// `impl_yolo_segdet_get_boxes`, optionally rescaled by
/// `maybe_normalize_boxes_in_place`, and the masks are produced by
/// `impl_yolo_split_segdet_process_masks`, which overwrites both output
/// vectors.
///
/// * `pre_nms_top_k` - candidate limit applied before NMS (only when `nms`
///   is `Some`; `0` disables it).
/// * `max_det` - maximum number of detections returned.
/// * `normalized` / `input_dims` - forwarded to
///   `maybe_normalize_boxes_in_place`.
///
/// # Errors
/// Propagates any `DecoderError` from mask decoding.
#[allow(clippy::too_many_arguments)]
pub(crate) fn impl_yolo_segdet_float<
    B: BBoxTypeTrait,
    BOX: Float + AsPrimitive<f32> + Send + Sync,
    PROTO: Float + AsPrimitive<f32> + Send + Sync,
>(
    boxes: ArrayView2<BOX>,
    protos: ArrayView3<PROTO>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    pre_nms_top_k: usize,
    max_det: usize,
    normalized: Option<bool>,
    input_dims: Option<(usize, usize)>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    f32: AsPrimitive<BOX>,
{
    let num_protos = protos.dim().2;
    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
    let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
        boxes_tensor,
        scores_tensor,
        score_threshold,
        iou_threshold,
        nms,
        pre_nms_top_k,
        max_det,
    );
    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
}
1014
/// Selects the final float detections for the segmentation path.
///
/// Pipeline: score filter, optional pre-NMS top-k, the NMS selected by
/// `nms`, a descending sort by score, and truncation to `max_det`.
///
/// Each returned pair is a detection plus the `usize` produced alongside it
/// by `postprocess_boxes_index_float` (presumably the candidate's index in
/// the input tensors; confirm against that helper).
///
/// * `pre_nms_top_k` - applied only when `nms` is `Some`; `0` disables it.
/// * `max_det` - passed to NMS as its limit and enforced again afterwards.
///
/// The `decoder.nms_get_boxes` trace span records the candidate count after
/// each stage.
pub(crate) fn impl_yolo_segdet_get_boxes<
    B: BBoxTypeTrait,
    BOX: Float + AsPrimitive<f32> + Send + Sync,
    SCORE: Float + AsPrimitive<f32> + Send + Sync,
>(
    boxes_tensor: ArrayView2<BOX>,
    scores_tensor: ArrayView2<SCORE>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    pre_nms_top_k: usize,
    max_det: usize,
) -> Vec<(DetectBox, usize)>
where
    f32: AsPrimitive<SCORE>,
{
    // Fields start empty and are filled in as each stage completes.
    let span = tracing::trace_span!(
        "decoder.nms_get_boxes",
        n_candidates = tracing::field::Empty,
        n_after_topk = tracing::field::Empty,
        n_after_nms = tracing::field::Empty,
        n_detections = tracing::field::Empty,
    );
    let _guard = span.enter();

    let mut boxes = {
        let _s = tracing::trace_span!("decoder.nms_get_boxes.score_filter").entered();
        postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor)
    };
    span.record("n_candidates", boxes.len());

    // The pre-NMS limit is skipped entirely when NMS is disabled.
    if nms.is_some() {
        let _s = tracing::trace_span!("decoder.nms_get_boxes.top_k", k = pre_nms_top_k).entered();
        truncate_to_top_k_by_score(&mut boxes, pre_nms_top_k);
    }
    span.record("n_after_topk", boxes.len());

    let mut boxes = {
        let _s = tracing::trace_span!("decoder.nms_get_boxes.suppress").entered();
        dispatch_nms_extra_float(nms, iou_threshold, Some(max_det), boxes)
    };
    span.record("n_after_nms", boxes.len());

    // Highest score first, then enforce the limit (also covers `nms == None`,
    // where no suppression step applied it).
    boxes.sort_unstable_by(|a, b| b.0.score.total_cmp(&a.0.score));
    boxes.truncate(max_det);
    span.record("n_detections", boxes.len());

    boxes
}
1066
1067pub(crate) fn impl_yolo_end_to_end_segdet_get_boxes<
1068 B: BBoxTypeTrait,
1069 BOX: Float + AsPrimitive<f32> + Send + Sync,
1070 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1071 CLASS: AsPrimitive<f32> + Send + Sync,
1072>(
1073 boxes: ArrayView2<BOX>,
1074 scores: ArrayView2<SCORE>,
1075 classes: ArrayView1<CLASS>,
1076 score_threshold: f32,
1077 max_boxes: usize,
1078) -> Vec<(DetectBox, usize)>
1079where
1080 f32: AsPrimitive<SCORE>,
1081{
1082 let mut boxes = postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes, scores);
1083 boxes.truncate(max_boxes);
1084 for (b, ind) in &mut boxes {
1085 b.label = classes[*ind].as_().round() as usize;
1086 }
1087 boxes
1088}
1089
1090pub(crate) fn impl_yolo_split_segdet_process_masks<
1091 MASK: Float + AsPrimitive<f32> + Send + Sync,
1092 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1093>(
1094 boxes: Vec<(DetectBox, usize)>,
1095 masks_tensor: ArrayView2<MASK>,
1096 protos_tensor: ArrayView3<PROTO>,
1097 output_boxes: &mut Vec<DetectBox>,
1098 output_masks: &mut Vec<Segmentation>,
1099) -> Result<(), crate::DecoderError> {
1100 let _span = tracing::trace_span!(
1101 "decoder.decode.process_masks",
1102 n = boxes.len(),
1103 mode = "float"
1104 )
1105 .entered();
1106 let boxes = decode_segdet_f32(boxes, masks_tensor, protos_tensor)?;
1110 output_boxes.clear();
1111 output_masks.clear();
1112 for (b, roi, m) in boxes.into_iter() {
1113 output_boxes.push(b);
1114 output_masks.push(Segmentation {
1115 xmin: roi.xmin,
1116 ymin: roi.ymin,
1117 xmax: roi.xmax,
1118 ymax: roi.ymax,
1119 segmentation: m,
1120 });
1121 }
1122 Ok(())
1123}
/// Selects the final detections for the quantized segmentation path.
///
/// Pipeline: score filter in the quantized domain, optional pre-NMS top-k,
/// the NMS selected by `nms`, a descending sort by raw score, truncation to
/// `max_det`, and finally `dequant_detect_box` on each survivor.
///
/// Each returned pair is a detection plus the `usize` produced alongside it
/// by `postprocess_boxes_index_quant` (presumably the candidate's index in
/// the input tensors; confirm against that helper).
///
/// * `boxes` / `scores` - quantized tensors with their `Quantization`. The
///   box quantization is used when post-processing boxes; the score
///   quantization is used for the threshold and for `dequant_detect_box`.
/// * `pre_nms_top_k` - applied only when `nms` is `Some`; `0` disables it.
/// * `max_det` - passed to NMS as its limit and enforced again afterwards.
///
/// The `decoder.nms_get_boxes` trace span records the candidate count after
/// each stage.
pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
    B: BBoxTypeTrait,
    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    scores: (ArrayView2<SCORE>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    pre_nms_top_k: usize,
    max_det: usize,
) -> Vec<(DetectBox, usize)>
where
    f32: AsPrimitive<SCORE>,
{
    let (boxes_tensor, quant_boxes) = boxes;
    let (scores_tensor, quant_scores) = scores;

    // Fields start empty and are filled in as each stage completes.
    let span = tracing::trace_span!(
        "decoder.nms_get_boxes",
        n_candidates = tracing::field::Empty,
        n_after_topk = tracing::field::Empty,
        n_after_nms = tracing::field::Empty,
        n_detections = tracing::field::Empty,
    );
    let _guard = span.enter();

    let mut boxes = {
        let _s = tracing::trace_span!("decoder.nms_get_boxes.score_filter").entered();
        // Convert the threshold once so filtering compares raw scores.
        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
        postprocess_boxes_index_quant::<B, _, _>(
            score_threshold,
            boxes_tensor,
            scores_tensor,
            quant_boxes,
        )
    };
    span.record("n_candidates", boxes.len());

    // The pre-NMS limit is skipped entirely when NMS is disabled.
    if nms.is_some() {
        let _s = tracing::trace_span!("decoder.nms_get_boxes.top_k", k = pre_nms_top_k).entered();
        truncate_to_top_k_by_score_quant(&mut boxes, pre_nms_top_k);
    }
    span.record("n_after_topk", boxes.len());

    let mut boxes = {
        let _s = tracing::trace_span!("decoder.nms_get_boxes.suppress").entered();
        dispatch_nms_extra_int(nms, iou_threshold, Some(max_det), boxes)
    };
    span.record("n_after_nms", boxes.len());

    // Highest raw score first, then enforce the limit before dequantizing so
    // only the surviving boxes are converted.
    boxes.sort_unstable_by_key(|b| std::cmp::Reverse(b.0.score));
    boxes.truncate(max_det);
    let result: Vec<_> = {
        let _s =
            tracing::trace_span!("decoder.nms_get_boxes.dequant_boxes", n = boxes.len()).entered();
        boxes
            .into_iter()
            .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
            .collect()
    };
    span.record("n_detections", result.len());

    result
}
1195
1196pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
1197 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1198 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1199>(
1200 boxes: Vec<(DetectBox, usize)>,
1201 mask_coeff: (ArrayView2<MASK>, Quantization),
1202 protos: (ArrayView3<PROTO>, Quantization),
1203 output_boxes: &mut Vec<DetectBox>,
1204 output_masks: &mut Vec<Segmentation>,
1205) -> Result<(), crate::DecoderError> {
1206 let _span = tracing::trace_span!(
1207 "decoder.decode.process_masks",
1208 n = boxes.len(),
1209 mode = "quant"
1210 )
1211 .entered();
1212 let (masks, quant_masks) = mask_coeff;
1213 let (protos, quant_protos) = protos;
1214
1215 let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
1219 output_boxes.clear();
1220 output_masks.clear();
1221 for (b, roi, m) in boxes.into_iter() {
1222 output_boxes.push(b);
1223 output_masks.push(Segmentation {
1224 xmin: roi.xmin,
1225 ymin: roi.ymin,
1226 xmax: roi.xmax,
1227 ymax: roi.ymax,
1228 segmentation: m,
1229 });
1230 }
1231 Ok(())
1232}
1233
/// Decodes floating-point YOLO segmentation+detection output delivered as
/// separate box, score and mask-coefficient tensors.
///
/// The three tensors are first passed through
/// `postprocess_yolo_split_segdet`. Detections are then selected by
/// `impl_yolo_segdet_get_boxes`, optionally rescaled by
/// `maybe_normalize_boxes_in_place`, and the masks are produced by
/// `impl_yolo_split_segdet_process_masks`, which overwrites both output
/// vectors.
///
/// * `pre_nms_top_k` - candidate limit applied before NMS (only when `nms`
///   is `Some`; `0` disables it).
/// * `max_det` - maximum number of detections returned.
/// * `normalized` / `input_dims` - forwarded to
///   `maybe_normalize_boxes_in_place`.
///
/// # Errors
/// Propagates any `DecoderError` from mask decoding.
#[allow(clippy::too_many_arguments)]
pub(crate) fn impl_yolo_split_segdet_float<
    B: BBoxTypeTrait,
    BOX: Float + AsPrimitive<f32> + Send + Sync,
    SCORE: Float + AsPrimitive<f32> + Send + Sync,
    MASK: Float + AsPrimitive<f32> + Send + Sync,
    PROTO: Float + AsPrimitive<f32> + Send + Sync,
>(
    boxes_tensor: ArrayView2<BOX>,
    scores_tensor: ArrayView2<SCORE>,
    mask_tensor: ArrayView2<MASK>,
    protos: ArrayView3<PROTO>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    pre_nms_top_k: usize,
    max_det: usize,
    normalized: Option<bool>,
    input_dims: Option<(usize, usize)>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    f32: AsPrimitive<SCORE>,
{
    let (boxes_tensor, scores_tensor, mask_tensor) =
        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);

    let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
        boxes_tensor,
        scores_tensor,
        score_threshold,
        iou_threshold,
        nms,
        pre_nms_top_k,
        max_det,
    );
    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
}
1285
/// Decodes a combined quantized YOLO segmentation+detection tensor into
/// boxes and returns the `ProtoData` built by `extract_proto_data_quant`.
///
/// `postprocess_yolo_seg` splits `boxes` into box, score and
/// mask-coefficient views, using the last dimension of the proto tensor as
/// the coefficient count. Detections are selected by
/// `impl_yolo_split_segdet_quant_get_boxes`, optionally rescaled by
/// `maybe_normalize_boxes_in_place`, and then handed to
/// `extract_proto_data_quant` together with the mask coefficients, the
/// protos and `output_boxes`.
///
/// * `pre_nms_top_k` - candidate limit applied before NMS (only when `nms`
///   is `Some`; `0` disables it).
/// * `max_det` - maximum number of detections selected.
/// * `normalized` / `input_dims` - forwarded to
///   `maybe_normalize_boxes_in_place`.
#[allow(clippy::too_many_arguments)]
pub(crate) fn impl_yolo_segdet_quant_proto<
    B: BBoxTypeTrait,
    BOX: PrimInt
        + AsPrimitive<i64>
        + AsPrimitive<i128>
        + AsPrimitive<f32>
        + AsPrimitive<i8>
        + Send
        + Sync,
    PROTO: PrimInt
        + AsPrimitive<i64>
        + AsPrimitive<i128>
        + AsPrimitive<f32>
        + AsPrimitive<i8>
        + Send
        + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    protos: (ArrayView3<PROTO>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    pre_nms_top_k: usize,
    max_det: usize,
    normalized: Option<bool>,
    input_dims: Option<(usize, usize)>,
    output_boxes: &mut Vec<DetectBox>,
) -> ProtoData
where
    f32: AsPrimitive<BOX>,
{
    let (boxes_arr, quant_boxes) = boxes;
    let (protos_arr, quant_protos) = protos;
    let num_protos = protos_arr.dim().2;

    // All three views come from the same tensor, so they all share
    // `quant_boxes` as their quantization.
    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);

    let mut det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
        (boxes_tensor, quant_boxes),
        (scores_tensor, quant_boxes),
        score_threshold,
        iou_threshold,
        nms,
        pre_nms_top_k,
        max_det,
    );
    maybe_normalize_boxes_in_place(&mut det_indices, normalized, input_dims);

    extract_proto_data_quant(
        det_indices,
        mask_tensor,
        quant_boxes,
        protos_arr,
        quant_protos,
        output_boxes,
    )
}
1350
/// Decodes a combined floating-point YOLO segmentation+detection tensor into
/// boxes and returns the `ProtoData` built by `extract_proto_data_float`.
///
/// `postprocess_yolo_seg` splits `boxes` into box, score and
/// mask-coefficient views, using the last dimension of `protos` as the
/// coefficient count. Detections are selected by
/// `impl_yolo_segdet_get_boxes`, optionally rescaled by
/// `maybe_normalize_boxes_in_place`, and then handed to
/// `extract_proto_data_float` together with the mask coefficients, the
/// protos and `output_boxes`.
///
/// * `pre_nms_top_k` - candidate limit applied before NMS (only when `nms`
///   is `Some`; `0` disables it).
/// * `max_det` - maximum number of detections selected.
/// * `normalized` / `input_dims` - forwarded to
///   `maybe_normalize_boxes_in_place`.
#[allow(clippy::too_many_arguments)]
pub(crate) fn impl_yolo_segdet_float_proto<
    B: BBoxTypeTrait,
    BOX: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
    PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
>(
    boxes: ArrayView2<BOX>,
    protos: ArrayView3<PROTO>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    pre_nms_top_k: usize,
    max_det: usize,
    normalized: Option<bool>,
    input_dims: Option<(usize, usize)>,
    output_boxes: &mut Vec<DetectBox>,
) -> ProtoData
where
    f32: AsPrimitive<BOX>,
{
    let num_protos = protos.dim().2;
    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);

    let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
        boxes_tensor,
        scores_tensor,
        score_threshold,
        iou_threshold,
        nms,
        pre_nms_top_k,
        max_det,
    );
    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);

    extract_proto_data_float(boxes, mask_tensor, protos, output_boxes)
}
1389
1390#[allow(clippy::too_many_arguments)]
1393pub(crate) fn impl_yolo_split_segdet_float_proto<
1394 B: BBoxTypeTrait,
1395 BOX: Float + AsPrimitive<f32> + Send + Sync,
1396 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1397 MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1398 PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1399>(
1400 boxes_tensor: ArrayView2<BOX>,
1401 scores_tensor: ArrayView2<SCORE>,
1402 mask_tensor: ArrayView2<MASK>,
1403 protos: ArrayView3<PROTO>,
1404 score_threshold: f32,
1405 iou_threshold: f32,
1406 nms: Option<Nms>,
1407 pre_nms_top_k: usize,
1408 max_det: usize,
1409 output_boxes: &mut Vec<DetectBox>,
1410) -> ProtoData
1411where
1412 f32: AsPrimitive<SCORE>,
1413{
1414 let (boxes_tensor, scores_tensor, mask_tensor) =
1415 postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1416 let det_indices = impl_yolo_segdet_get_boxes::<B, _, _>(
1417 boxes_tensor,
1418 scores_tensor,
1419 score_threshold,
1420 iou_threshold,
1421 nms,
1422 pre_nms_top_k,
1423 max_det,
1424 );
1425
1426 extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1427}
1428
1429pub(crate) fn decode_yolo_end_to_end_segdet_float_proto<T>(
1431 output: ArrayView2<T>,
1432 protos: ArrayView3<T>,
1433 score_threshold: f32,
1434 output_boxes: &mut Vec<DetectBox>,
1435) -> Result<ProtoData, crate::DecoderError>
1436where
1437 T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1438 f32: AsPrimitive<T>,
1439{
1440 let (boxes, scores, classes, mask_coeff) =
1441 postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
1442 let cap = cap_or_default(output_boxes);
1443 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1444 boxes,
1445 scores,
1446 classes,
1447 score_threshold,
1448 cap,
1449 );
1450
1451 Ok(extract_proto_data_float(
1452 boxes,
1453 mask_coeff,
1454 protos,
1455 output_boxes,
1456 ))
1457}
1458
1459#[allow(clippy::too_many_arguments)]
1461pub(crate) fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1462 boxes: ArrayView2<T>,
1463 scores: ArrayView2<T>,
1464 classes: ArrayView2<T>,
1465 mask_coeff: ArrayView2<T>,
1466 protos: ArrayView3<T>,
1467 score_threshold: f32,
1468 output_boxes: &mut Vec<DetectBox>,
1469) -> Result<ProtoData, crate::DecoderError>
1470where
1471 T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1472 f32: AsPrimitive<T>,
1473{
1474 let (boxes, scores, classes, mask_coeff) =
1475 postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
1476 let cap = cap_or_default(output_boxes);
1477 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1478 boxes,
1479 scores,
1480 classes,
1481 score_threshold,
1482 cap,
1483 );
1484
1485 Ok(extract_proto_data_float(
1486 boxes,
1487 mask_coeff,
1488 protos,
1489 output_boxes,
1490 ))
1491}
1492
1493pub(super) fn extract_proto_data_float<
1500 MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1501 PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1502>(
1503 det_indices: Vec<(DetectBox, usize)>,
1504 mask_tensor: ArrayView2<MASK>,
1505 protos: ArrayView3<PROTO>,
1506 output_boxes: &mut Vec<DetectBox>,
1507) -> ProtoData {
1508 let _span = tracing::trace_span!(
1509 "decoder.decode_proto.extract_proto_data",
1510 mode = "float",
1511 n = det_indices.len(),
1512 num_protos = mask_tensor.ncols(),
1513 layout = "nhwc",
1514 )
1515 .entered();
1516
1517 let num_protos = mask_tensor.ncols();
1518 let n = det_indices.len();
1519
1520 let mut coeff_rows: Vec<MASK> = Vec::with_capacity(n * num_protos);
1525 output_boxes.clear();
1526 for (det, idx) in det_indices {
1527 output_boxes.push(det);
1528 let row = mask_tensor.row(idx);
1529 coeff_rows.extend(row.iter().copied());
1530 }
1531
1532 let mask_coefficients = MASK::slice_into_tensor_dyn(&coeff_rows, &[n, num_protos])
1533 .expect("allocating mask_coefficients TensorDyn");
1534 let protos_tensor =
1535 PROTO::arrayview3_into_tensor_dyn(protos).expect("allocating protos TensorDyn");
1536
1537 ProtoData {
1538 mask_coefficients,
1539 protos: protos_tensor,
1540 layout: ProtoLayout::Nhwc,
1541 }
1542}
1543
1544pub(crate) fn extract_proto_data_quant<
1553 MASK: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
1554 PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
1555>(
1556 det_indices: Vec<(DetectBox, usize)>,
1557 mask_tensor: ArrayView2<MASK>,
1558 quant_masks: Quantization,
1559 protos: ArrayView3<PROTO>,
1560 quant_protos: Quantization,
1561 output_boxes: &mut Vec<DetectBox>,
1562) -> ProtoData {
1563 use edgefirst_tensor::{Tensor, TensorDyn, TensorMapTrait, TensorMemory, TensorTrait};
1564
1565 let span = tracing::trace_span!(
1566 "decoder.decode_proto.extract_proto_data",
1567 mode = "quant",
1568 n = det_indices.len(),
1569 num_protos = tracing::field::Empty,
1570 layout = tracing::field::Empty,
1571 );
1572 let _guard = span.enter();
1573
1574 let num_protos = mask_tensor.ncols();
1575 let n = det_indices.len();
1576 span.record("num_protos", num_protos);
1577
1578 if n == 0 {
1584 output_boxes.clear();
1585 let (h, w, k) = protos.dim();
1586
1587 let (proto_shape, proto_layout) = if std::any::TypeId::of::<PROTO>()
1589 == std::any::TypeId::of::<i8>()
1590 {
1591 if protos.is_standard_layout() {
1592 (&[h, w, k][..], ProtoLayout::Nhwc)
1593 } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1594 (&[k, h, w][..], ProtoLayout::Nchw)
1595 } else {
1596 (&[h, w, k][..], ProtoLayout::Nhwc)
1597 }
1598 } else {
1599 (&[h, w, k][..], ProtoLayout::Nhwc)
1600 };
1601
1602 let coeff_tensor = Tensor::<i8>::new(&[0, num_protos], Some(TensorMemory::Mem), None)
1603 .expect("allocating empty mask_coefficients tensor");
1604 let coeff_quant =
1605 edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1606 let coeff_tensor = coeff_tensor
1607 .with_quantization(coeff_quant)
1608 .expect("per-tensor quantization on mask coefficients");
1609 let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
1610 .expect("allocating protos tensor");
1611 let tensor_quant =
1612 edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1613 let protos_tensor = protos_tensor
1614 .with_quantization(tensor_quant)
1615 .expect("per-tensor quantization on protos tensor");
1616 return ProtoData {
1617 mask_coefficients: TensorDyn::I8(coeff_tensor),
1618 protos: TensorDyn::I8(protos_tensor),
1619 layout: proto_layout,
1620 };
1621 }
1622
1623 let mask_coefficients: TensorDyn = if std::any::TypeId::of::<MASK>()
1629 == std::any::TypeId::of::<i8>()
1630 {
1631 let mut coeff_i8 = Vec::<i8>::with_capacity(n * num_protos);
1632 output_boxes.clear();
1633 for (det, idx) in det_indices {
1634 output_boxes.push(det);
1635 let row = mask_tensor.row(idx);
1636 coeff_i8.extend(row.iter().map(|v| {
1637 let v_i8: i8 = v.as_();
1638 v_i8
1639 }));
1640 }
1641 let coeff_tensor = Tensor::<i8>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
1642 .expect("allocating mask_coefficients tensor");
1643 if n > 0 {
1644 let mut m = coeff_tensor
1645 .map()
1646 .expect("mapping mask_coefficients tensor");
1647 m.as_mut_slice().copy_from_slice(&coeff_i8);
1648 }
1649 let coeff_quant =
1650 edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1651 let coeff_tensor = coeff_tensor
1652 .with_quantization(coeff_quant)
1653 .expect("per-tensor quantization on mask coefficients");
1654 TensorDyn::I8(coeff_tensor)
1655 } else if std::any::TypeId::of::<MASK>() == std::any::TypeId::of::<i16>() {
1656 let mut coeff_i16 = Vec::<i16>::with_capacity(n * num_protos);
1659 output_boxes.clear();
1660 for (det, idx) in det_indices {
1661 output_boxes.push(det);
1662 let row = mask_tensor.row(idx);
1663 coeff_i16.extend(row.iter().map(|v| {
1664 let v_f32: f32 = v.as_();
1665 v_f32 as i16
1666 }));
1667 }
1668 let coeff_tensor = Tensor::<i16>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
1669 .expect("allocating mask_coefficients tensor");
1670 if n > 0 {
1671 let mut m = coeff_tensor
1672 .map()
1673 .expect("mapping mask_coefficients tensor");
1674 m.as_mut_slice().copy_from_slice(&coeff_i16);
1675 }
1676 let coeff_quant =
1677 edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1678 let coeff_tensor = coeff_tensor
1679 .with_quantization(coeff_quant)
1680 .expect("per-tensor quantization on mask coefficients");
1681 TensorDyn::I16(coeff_tensor)
1682 } else {
1683 let scale = quant_masks.scale;
1685 let zp = quant_masks.zero_point as f32;
1686 let mut coeff_f32 = Vec::<f32>::with_capacity(n * num_protos);
1687 output_boxes.clear();
1688 for (det, idx) in det_indices {
1689 output_boxes.push(det);
1690 let row = mask_tensor.row(idx);
1691 coeff_f32.extend(row.iter().map(|v| {
1692 let v_f32: f32 = v.as_();
1693 (v_f32 - zp) * scale
1694 }));
1695 }
1696 let coeff_tensor = Tensor::<f32>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
1697 .expect("allocating mask_coefficients tensor");
1698 if n > 0 {
1699 let mut m = coeff_tensor
1700 .map()
1701 .expect("mapping mask_coefficients tensor");
1702 m.as_mut_slice().copy_from_slice(&coeff_f32);
1703 }
1704 TensorDyn::F32(coeff_tensor)
1705 };
1706
1707 let (h, w, k) = protos.dim();
1711
1712 let (proto_shape, proto_layout) =
1714 if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1715 if protos.is_standard_layout() {
1716 (&[h, w, k][..], ProtoLayout::Nhwc)
1718 } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1719 (&[k, h, w][..], ProtoLayout::Nchw)
1723 } else {
1724 (&[h, w, k][..], ProtoLayout::Nhwc)
1726 }
1727 } else {
1728 (&[h, w, k][..], ProtoLayout::Nhwc)
1729 };
1730
1731 let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
1732 .expect("allocating protos tensor");
1733 {
1734 let mut m = protos_tensor.map().expect("mapping protos tensor");
1735 let dst = m.as_mut_slice();
1736 if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1737 if protos.is_standard_layout() {
1740 let src: &[i8] = unsafe {
1741 std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len())
1742 };
1743 dst.copy_from_slice(src);
1744 } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1745 let total = h * w * k;
1749 let src: &[i8] =
1752 unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, total) };
1753 dst.copy_from_slice(src);
1754 } else {
1755 for (d, s) in dst.iter_mut().zip(protos.iter()) {
1756 let v_i8: i8 = s.as_();
1757 *d = v_i8;
1758 }
1759 }
1760 } else {
1761 for (d, s) in dst.iter_mut().zip(protos.iter()) {
1762 let v_i8: i8 = s.as_();
1763 *d = v_i8;
1764 }
1765 }
1766 }
1767 let tensor_quant =
1768 edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1769 let protos_tensor = protos_tensor
1770 .with_quantization(tensor_quant)
1771 .expect("per-tensor quantization on new Tensor<i8>");
1772
1773 span.record("layout", tracing::field::debug(&proto_layout));
1774
1775 ProtoData {
1776 mask_coefficients,
1777 protos: TensorDyn::I8(protos_tensor),
1778 layout: proto_layout,
1779 }
1780}
1781
1782pub trait FloatProtoElem: Copy + 'static {
1788 fn slice_into_tensor_dyn(
1789 values: &[Self],
1790 shape: &[usize],
1791 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;
1792
1793 fn arrayview3_into_tensor_dyn(
1794 view: ArrayView3<'_, Self>,
1795 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;
1796}
1797
1798impl FloatProtoElem for f32 {
1799 fn slice_into_tensor_dyn(
1800 values: &[f32],
1801 shape: &[usize],
1802 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1803 edgefirst_tensor::Tensor::<f32>::from_slice(values, shape)
1804 .map(edgefirst_tensor::TensorDyn::F32)
1805 }
1806 fn arrayview3_into_tensor_dyn(
1807 view: ArrayView3<'_, f32>,
1808 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1809 edgefirst_tensor::Tensor::<f32>::from_arrayview3(view).map(edgefirst_tensor::TensorDyn::F32)
1810 }
1811}
1812
1813impl FloatProtoElem for half::f16 {
1814 fn slice_into_tensor_dyn(
1815 values: &[half::f16],
1816 shape: &[usize],
1817 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1818 edgefirst_tensor::Tensor::<half::f16>::from_slice(values, shape)
1819 .map(edgefirst_tensor::TensorDyn::F16)
1820 }
1821 fn arrayview3_into_tensor_dyn(
1822 view: ArrayView3<'_, half::f16>,
1823 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1824 edgefirst_tensor::Tensor::<half::f16>::from_arrayview3(view)
1825 .map(edgefirst_tensor::TensorDyn::F16)
1826 }
1827}
1828
1829impl FloatProtoElem for f64 {
1830 fn slice_into_tensor_dyn(
1831 values: &[f64],
1832 shape: &[usize],
1833 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1834 let narrowed: Vec<f32> = values.iter().map(|&v| v as f32).collect();
1836 edgefirst_tensor::Tensor::<f32>::from_slice(&narrowed, shape)
1837 .map(edgefirst_tensor::TensorDyn::F32)
1838 }
1839 fn arrayview3_into_tensor_dyn(
1840 view: ArrayView3<'_, f64>,
1841 ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1842 let narrowed: ndarray::Array3<f32> = view.mapv(|v| v as f32);
1843 edgefirst_tensor::Tensor::<f32>::from_arrayview3(narrowed.view())
1844 .map(edgefirst_tensor::TensorDyn::F32)
1845 }
1846}
1847
1848fn postprocess_yolo<'a, T>(
1849 output: &'a ArrayView2<'_, T>,
1850) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1851 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1852 let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1853 (boxes_tensor, scores_tensor)
1854}
1855
1856pub(crate) fn postprocess_yolo_seg<'a, T>(
1857 output: &'a ArrayView2<'_, T>,
1858 num_protos: usize,
1859) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1860 assert!(
1861 output.shape()[0] > num_protos + 4,
1862 "Output shape is too short: {} <= {} + 4",
1863 output.shape()[0],
1864 num_protos
1865 );
1866 let num_classes = output.shape()[0] - 4 - num_protos;
1867 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1868 let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1869 let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1870 (boxes_tensor, scores_tensor, mask_tensor)
1871}
1872
1873pub(crate) fn postprocess_yolo_split_segdet<'a, 'b, 'c, BOX, SCORE, MASK>(
1874 boxes_tensor: ArrayView2<'a, BOX>,
1875 scores_tensor: ArrayView2<'b, SCORE>,
1876 mask_tensor: ArrayView2<'c, MASK>,
1877) -> (
1878 ArrayView2<'a, BOX>,
1879 ArrayView2<'b, SCORE>,
1880 ArrayView2<'c, MASK>,
1881) {
1882 let boxes_tensor = boxes_tensor.reversed_axes();
1883 let scores_tensor = scores_tensor.reversed_axes();
1884 let mask_tensor = mask_tensor.reversed_axes();
1885 (boxes_tensor, scores_tensor, mask_tensor)
1886}
1887
1888fn decode_segdet_f32<
1889 MASK: Float + AsPrimitive<f32> + Send + Sync,
1890 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1891>(
1892 boxes: Vec<(DetectBox, usize)>,
1893 masks: ArrayView2<MASK>,
1894 protos: ArrayView3<PROTO>,
1895) -> Result<Vec<(DetectBox, BoundingBox, Array3<u8>)>, crate::DecoderError> {
1896 if boxes.is_empty() {
1897 return Ok(Vec::new());
1898 }
1899 if masks.shape()[1] != protos.shape()[2] {
1900 return Err(crate::DecoderError::InvalidShape(format!(
1901 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1902 masks.shape()[1],
1903 protos.shape()[2],
1904 )));
1905 }
1906 boxes
1907 .into_par_iter()
1908 .map(|b| {
1909 let ind = b.1;
1910 let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1915 Ok((b.0, roi, make_segmentation(masks.row(ind), protos.view())))
1916 })
1917 .collect()
1918}
1919
1920pub(crate) fn decode_segdet_quant<
1921 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1922 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1923>(
1924 boxes: Vec<(DetectBox, usize)>,
1925 masks: ArrayView2<MASK>,
1926 protos: ArrayView3<PROTO>,
1927 quant_masks: Quantization,
1928 quant_protos: Quantization,
1929) -> Result<Vec<(DetectBox, BoundingBox, Array3<u8>)>, crate::DecoderError> {
1930 if boxes.is_empty() {
1931 return Ok(Vec::new());
1932 }
1933 if masks.shape()[1] != protos.shape()[2] {
1934 return Err(crate::DecoderError::InvalidShape(format!(
1935 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1936 masks.shape()[1],
1937 protos.shape()[2],
1938 )));
1939 }
1940
1941 let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; boxes
1943 .into_iter()
1944 .map(|b| {
1945 let i = b.1;
1946 let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1950 let seg = match total_bits {
1951 0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1952 masks.row(i),
1953 protos.view(),
1954 quant_masks,
1955 quant_protos,
1956 ),
1957 65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1958 masks.row(i),
1959 protos.view(),
1960 quant_masks,
1961 quant_protos,
1962 ),
1963 _ => {
1964 return Err(crate::DecoderError::NotSupported(format!(
1965 "Unsupported bit width ({total_bits}) for segmentation computation"
1966 )));
1967 }
1968 };
1969 Ok((b.0, roi, seg))
1970 })
1971 .collect()
1972}
1973
1974fn protobox<'a, T>(
1975 protos: &'a ArrayView3<T>,
1976 roi: &BoundingBox,
1977) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1978 let width = protos.dim().1 as f32;
1979 let height = protos.dim().0 as f32;
1980
1981 const NORM_LIMIT: f32 = 2.0;
1993 if roi.xmin > NORM_LIMIT
1994 || roi.ymin > NORM_LIMIT
1995 || roi.xmax > NORM_LIMIT
1996 || roi.ymax > NORM_LIMIT
1997 {
1998 return Err(crate::DecoderError::InvalidShape(format!(
1999 "Bounding box coordinates appear un-normalized (pixel-space). \
2000 Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
2001 Two ways to fix this: \
2002 (1) declare `Detection::normalized = false` in the model schema \
2003 AND make sure the schema's `input.shape` / `input.dshape` carries \
2004 the model input dims so the decoder can divide by (W, H) before NMS \
2005 (EDGEAI-1303 — verify with `Decoder::input_dims().is_some()`); or \
2006 (2) normalize the boxes in-graph before decode().",
2007 roi.xmin, roi.ymin, roi.xmax, roi.ymax,
2008 )));
2009 }
2010
2011 let roi = [
2012 (roi.xmin * width).clamp(0.0, width) as usize,
2013 (roi.ymin * height).clamp(0.0, height) as usize,
2014 (roi.xmax * width).clamp(0.0, width).ceil() as usize,
2015 (roi.ymax * height).clamp(0.0, height).ceil() as usize,
2016 ];
2017
2018 let roi_norm = [
2019 roi[0] as f32 / width,
2020 roi[1] as f32 / height,
2021 roi[2] as f32 / width,
2022 roi[3] as f32 / height,
2023 ]
2024 .into();
2025
2026 let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
2027
2028 Ok((cropped, roi_norm))
2029}
2030
2031fn make_segmentation<
2037 MASK: Float + AsPrimitive<f32> + Send + Sync,
2038 PROTO: Float + AsPrimitive<f32> + Send + Sync,
2039>(
2040 mask: ArrayView1<MASK>,
2041 protos: ArrayView3<PROTO>,
2042) -> Array3<u8> {
2043 let shape = protos.shape();
2044
2045 let mask = mask.to_shape((1, mask.len())).unwrap();
2047 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
2048 let protos = protos.reversed_axes();
2049 let mask = mask.map(|x| x.as_());
2050 let protos = protos.map(|x| x.as_());
2051
2052 let mask = mask
2054 .dot(&protos)
2055 .into_shape_with_order((shape[0], shape[1], 1))
2056 .unwrap();
2057
2058 mask.map(|x| {
2059 let sigmoid = 1.0 / (1.0 + (-*x).exp());
2060 (sigmoid * 255.0).round() as u8
2061 })
2062}
2063
2064fn make_segmentation_quant<
2071 MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
2072 PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
2073 DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
2074>(
2075 mask: ArrayView1<MASK>,
2076 protos: ArrayView3<PROTO>,
2077 quant_masks: Quantization,
2078 quant_protos: Quantization,
2079) -> Array3<u8>
2080where
2081 i32: AsPrimitive<DEST>,
2082 f32: AsPrimitive<DEST>,
2083{
2084 let shape = protos.shape();
2085
2086 let mask = mask.to_shape((1, mask.len())).unwrap();
2088
2089 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
2090 let protos = protos.reversed_axes();
2091
2092 let zp = quant_masks.zero_point.as_();
2093
2094 let mask = mask.mapv(|x| x.as_() - zp);
2095
2096 let zp = quant_protos.zero_point.as_();
2097 let protos = protos.mapv(|x| x.as_() - zp);
2098
2099 let segmentation = mask
2101 .dot(&protos)
2102 .into_shape_with_order((shape[0], shape[1], 1))
2103 .unwrap();
2104
2105 let combined_scale = quant_masks.scale * quant_protos.scale;
2106 segmentation.map(|x| {
2107 let val: f32 = (*x).as_() * combined_scale;
2108 let sigmoid = 1.0 / (1.0 + (-val).exp());
2109 (sigmoid * 255.0).round() as u8
2110 })
2111}
2112
2113pub(crate) fn yolo_segmentation_to_mask(
2125 segmentation: ArrayView3<u8>,
2126 threshold: u8,
2127) -> Result<Array2<u8>, crate::DecoderError> {
2128 if segmentation.shape()[2] != 1 {
2129 return Err(crate::DecoderError::InvalidShape(format!(
2130 "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
2131 segmentation.shape()[2]
2132 )));
2133 }
2134 Ok(segmentation
2135 .slice(s![.., .., 0])
2136 .map(|x| if *x >= threshold { 1 } else { 0 }))
2137}
2138
2139#[cfg(test)]
2140#[cfg_attr(coverage_nightly, coverage(off))]
2141mod tests {
2142 use super::*;
2143 use ndarray::Array2;
2144
2145 #[test]
2150 fn test_end_to_end_det_basic_filtering() {
2151 let data: Vec<f32> = vec![
2155 0.1, 0.2, 0.3, 0.1, 0.2, 0.3, 0.5, 0.6, 0.7, 0.5, 0.6, 0.7, 0.9, 0.1, 0.2, 0.0, 1.0, 2.0, ];
2163 let output = Array2::from_shape_vec((6, 3), data).unwrap();
2164
2165 let mut boxes = Vec::with_capacity(10);
2166 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2167
2168 assert_eq!(boxes.len(), 1);
2170 assert_eq!(boxes[0].label, 0);
2171 assert!((boxes[0].score - 0.9).abs() < 0.01);
2172 assert!((boxes[0].bbox.xmin - 0.1).abs() < 0.01);
2173 assert!((boxes[0].bbox.ymin - 0.1).abs() < 0.01);
2174 assert!((boxes[0].bbox.xmax - 0.5).abs() < 0.01);
2175 assert!((boxes[0].bbox.ymax - 0.5).abs() < 0.01);
2176 }
2177
2178 #[test]
2179 fn test_end_to_end_det_all_pass_threshold() {
2180 let data: Vec<f32> = vec![
2182 10.0, 20.0, 10.0, 20.0, 50.0, 60.0, 50.0, 60.0, 0.8, 0.7, 1.0, 2.0, ];
2189 let output = Array2::from_shape_vec((6, 2), data).unwrap();
2190
2191 let mut boxes = Vec::with_capacity(10);
2192 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2193
2194 assert_eq!(boxes.len(), 2);
2195 assert_eq!(boxes[0].label, 1);
2196 assert_eq!(boxes[1].label, 2);
2197 }
2198
2199 #[test]
2200 fn test_end_to_end_det_none_pass_threshold() {
2201 let data: Vec<f32> = vec![
2203 10.0, 20.0, 10.0, 20.0, 50.0, 60.0, 50.0, 60.0, 0.1, 0.2, 1.0, 2.0, ];
2210 let output = Array2::from_shape_vec((6, 2), data).unwrap();
2211
2212 let mut boxes = Vec::with_capacity(10);
2213 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2214
2215 assert_eq!(boxes.len(), 0);
2216 }
2217
2218 #[test]
2219 fn test_end_to_end_det_capacity_limit() {
2220 let data: Vec<f32> = vec![
2222 0.1, 0.2, 0.3, 0.4, 0.5, 0.1, 0.2, 0.3, 0.4, 0.5, 0.5, 0.6, 0.7, 0.8, 0.9, 0.5, 0.6, 0.7, 0.8, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.0, 1.0, 2.0, 3.0, 4.0, ];
2229 let output = Array2::from_shape_vec((6, 5), data).unwrap();
2230
2231 let mut boxes = Vec::with_capacity(2); decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2233
2234 assert_eq!(boxes.len(), 2);
2235 }
2236
2237 #[test]
2238 fn test_end_to_end_det_empty_output() {
2239 let output = Array2::<f32>::zeros((6, 0));
2241
2242 let mut boxes = Vec::with_capacity(10);
2243 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2244
2245 assert_eq!(boxes.len(), 0);
2246 }
2247
2248 #[test]
2249 fn test_end_to_end_det_pixel_coordinates() {
2250 let data: Vec<f32> = vec![
2252 100.0, 200.0, 300.0, 400.0, 0.95, 5.0, ];
2259 let output = Array2::from_shape_vec((6, 1), data).unwrap();
2260
2261 let mut boxes = Vec::with_capacity(10);
2262 decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2263
2264 assert_eq!(boxes.len(), 1);
2265 assert_eq!(boxes[0].label, 5);
2266 assert!((boxes[0].bbox.xmin - 100.0).abs() < 0.01);
2267 assert!((boxes[0].bbox.ymin - 200.0).abs() < 0.01);
2268 assert!((boxes[0].bbox.xmax - 300.0).abs() < 0.01);
2269 assert!((boxes[0].bbox.ymax - 400.0).abs() < 0.01);
2270 }
2271
2272 #[test]
2273 fn test_end_to_end_det_invalid_shape() {
2274 let output = Array2::<f32>::zeros((5, 3));
2276
2277 let mut boxes = Vec::with_capacity(10);
2278 let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);
2279
2280 assert!(result.is_err());
2281 assert!(matches!(
2282 result,
2283 Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
2284 ));
2285 }
2286
2287 #[test]
2292 fn test_end_to_end_segdet_basic() {
2293 let num_protos = 32;
2296 let num_detections = 2;
2297 let num_features = 6 + num_protos;
2298
2299 let mut data = vec![0.0f32; num_features * num_detections];
2301 data[0] = 0.1; data[1] = 0.5; data[num_detections] = 0.1; data[num_detections + 1] = 0.5; data[2 * num_detections] = 0.4; data[2 * num_detections + 1] = 0.9; data[3 * num_detections] = 0.4; data[3 * num_detections + 1] = 0.9; data[4 * num_detections] = 0.9; data[4 * num_detections + 1] = 0.3; data[5 * num_detections] = 1.0; data[5 * num_detections + 1] = 2.0; for i in 6..num_features {
2316 data[i * num_detections] = 0.1;
2317 data[i * num_detections + 1] = 0.1;
2318 }
2319
2320 let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2321
2322 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2324
2325 let mut boxes = Vec::with_capacity(10);
2326 let mut masks = Vec::with_capacity(10);
2327 decode_yolo_end_to_end_segdet_float(
2328 output.view(),
2329 protos.view(),
2330 0.5,
2331 &mut boxes,
2332 &mut masks,
2333 )
2334 .unwrap();
2335
2336 assert_eq!(boxes.len(), 1);
2338 assert_eq!(masks.len(), 1);
2339 assert_eq!(boxes[0].label, 1);
2340 assert!((boxes[0].score - 0.9).abs() < 0.01);
2341 }
2342
2343 #[test]
2344 fn test_end_to_end_segdet_mask_coordinates() {
2345 let num_protos = 32;
2347 let num_features = 6 + num_protos;
2348
2349 let mut data = vec![0.0f32; num_features];
2350 data[0] = 0.2; data[1] = 0.2; data[2] = 0.8; data[3] = 0.8; data[4] = 0.95; data[5] = 3.0; let output = Array2::from_shape_vec((num_features, 1), data).unwrap();
2358 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2359
2360 let mut boxes = Vec::with_capacity(10);
2361 let mut masks = Vec::with_capacity(10);
2362 decode_yolo_end_to_end_segdet_float(
2363 output.view(),
2364 protos.view(),
2365 0.5,
2366 &mut boxes,
2367 &mut masks,
2368 )
2369 .unwrap();
2370
2371 assert_eq!(boxes.len(), 1);
2372 assert_eq!(masks.len(), 1);
2373
2374 let step = 1.0 / 16.0;
2378 assert!(masks[0].xmin <= boxes[0].bbox.xmin);
2379 assert!(masks[0].ymin <= boxes[0].bbox.ymin);
2380 assert!(masks[0].xmax >= boxes[0].bbox.xmax);
2381 assert!(masks[0].ymax >= boxes[0].bbox.ymax);
2382 assert!((boxes[0].bbox.xmin - masks[0].xmin) < step);
2383 assert!((boxes[0].bbox.ymin - masks[0].ymin) < step);
2384 assert!((masks[0].xmax - boxes[0].bbox.xmax) < step);
2385 assert!((masks[0].ymax - boxes[0].bbox.ymax) < step);
2386 }
2387
2388 #[test]
2389 fn test_end_to_end_segdet_empty_output() {
2390 let num_protos = 32;
2391 let output = Array2::<f32>::zeros((6 + num_protos, 0));
2392 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2393
2394 let mut boxes = Vec::with_capacity(10);
2395 let mut masks = Vec::with_capacity(10);
2396 decode_yolo_end_to_end_segdet_float(
2397 output.view(),
2398 protos.view(),
2399 0.5,
2400 &mut boxes,
2401 &mut masks,
2402 )
2403 .unwrap();
2404
2405 assert_eq!(boxes.len(), 0);
2406 assert_eq!(masks.len(), 0);
2407 }
2408
2409 #[test]
2410 fn test_end_to_end_segdet_capacity_limit() {
2411 let num_protos = 32;
2412 let num_detections = 5;
2413 let num_features = 6 + num_protos;
2414
2415 let mut data = vec![0.0f32; num_features * num_detections];
2416 for i in 0..num_detections {
2418 data[i] = 0.1 * (i as f32); data[num_detections + i] = 0.1 * (i as f32); data[2 * num_detections + i] = 0.1 * (i as f32) + 0.2; data[3 * num_detections + i] = 0.1 * (i as f32) + 0.2; data[4 * num_detections + i] = 0.9; data[5 * num_detections + i] = i as f32; }
2425
2426 let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2427 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2428
2429 let mut boxes = Vec::with_capacity(2); let mut masks = Vec::with_capacity(2);
2431 decode_yolo_end_to_end_segdet_float(
2432 output.view(),
2433 protos.view(),
2434 0.5,
2435 &mut boxes,
2436 &mut masks,
2437 )
2438 .unwrap();
2439
2440 assert_eq!(boxes.len(), 2);
2441 assert_eq!(masks.len(), 2);
2442 }
2443
2444 #[test]
2445 fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
2446 let output = Array2::<f32>::zeros((6, 3));
2448 let protos = Array3::<f32>::zeros((16, 16, 32));
2449
2450 let mut boxes = Vec::with_capacity(10);
2451 let mut masks = Vec::with_capacity(10);
2452 let result = decode_yolo_end_to_end_segdet_float(
2453 output.view(),
2454 protos.view(),
2455 0.5,
2456 &mut boxes,
2457 &mut masks,
2458 );
2459
2460 assert!(result.is_err());
2461 assert!(matches!(
2462 result,
2463 Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
2464 ));
2465 }
2466
2467 #[test]
2468 fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
2469 let num_protos = 32;
2471 let output = Array2::<f32>::zeros((6 + 16, 3)); let protos = Array3::<f32>::zeros((16, 16, num_protos)); let mut boxes = Vec::with_capacity(10);
2475 let mut masks = Vec::with_capacity(10);
2476 let result = decode_yolo_end_to_end_segdet_float(
2477 output.view(),
2478 protos.view(),
2479 0.5,
2480 &mut boxes,
2481 &mut masks,
2482 );
2483
2484 assert!(result.is_err());
2485 assert!(matches!(
2486 result,
2487 Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
2488 ));
2489 }
2490
2491 #[test]
2496 fn test_split_end_to_end_segdet_basic() {
2497 let num_protos = 32;
2500 let num_detections = 2;
2501 let num_features = 6 + num_protos;
2502
2503 let mut data = vec![0.0f32; num_features * num_detections];
2505 data[0] = 0.1; data[1] = 0.5; data[num_detections] = 0.1; data[num_detections + 1] = 0.5; data[2 * num_detections] = 0.4; data[2 * num_detections + 1] = 0.9; data[3 * num_detections] = 0.4; data[3 * num_detections + 1] = 0.9; data[4 * num_detections] = 0.9; data[4 * num_detections + 1] = 0.3; data[5 * num_detections] = 1.0; data[5 * num_detections + 1] = 2.0; for i in 6..num_features {
2520 data[i * num_detections] = 0.1;
2521 data[i * num_detections + 1] = 0.1;
2522 }
2523
2524 let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2525 let box_coords = output.slice(s![..4, ..]);
2526 let scores = output.slice(s![4..5, ..]);
2527 let classes = output.slice(s![5..6, ..]);
2528 let mask_coeff = output.slice(s![6.., ..]);
2529 let protos = Array3::<f32>::zeros((16, 16, num_protos));
2531
2532 let mut boxes = Vec::with_capacity(10);
2533 let mut masks = Vec::with_capacity(10);
2534 decode_yolo_split_end_to_end_segdet_float(
2535 box_coords,
2536 scores,
2537 classes,
2538 mask_coeff,
2539 protos.view(),
2540 0.5,
2541 &mut boxes,
2542 &mut masks,
2543 )
2544 .unwrap();
2545
2546 assert_eq!(boxes.len(), 1);
2548 assert_eq!(masks.len(), 1);
2549 assert_eq!(boxes[0].label, 1);
2550 assert!((boxes[0].score - 0.9).abs() < 0.01);
2551 }
2552
2553 #[test]
2558 fn test_segmentation_to_mask_basic() {
2559 let data: Vec<u8> = vec![
2561 100, 200, 50, 150, 10, 255, 128, 64, 0, 127, 128, 255, 64, 64, 192, 192, ];
2566 let segmentation = Array3::from_shape_vec((4, 4, 1), data).unwrap();
2567
2568 let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2569
2570 assert_eq!(mask[[0, 0]], 0); assert_eq!(mask[[0, 1]], 1); assert_eq!(mask[[0, 2]], 0); assert_eq!(mask[[0, 3]], 1); assert_eq!(mask[[1, 1]], 1); assert_eq!(mask[[1, 2]], 1); assert_eq!(mask[[2, 0]], 0); assert_eq!(mask[[2, 1]], 0); }
2580
2581 #[test]
2582 fn test_segmentation_to_mask_all_above() {
2583 let segmentation = Array3::from_elem((4, 4, 1), 255u8);
2584 let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2585 assert!(mask.iter().all(|&x| x == 1));
2586 }
2587
2588 #[test]
2589 fn test_segmentation_to_mask_all_below() {
2590 let segmentation = Array3::from_elem((4, 4, 1), 64u8);
2591 let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2592 assert!(mask.iter().all(|&x| x == 0));
2593 }
2594
2595 #[test]
2596 fn test_segmentation_to_mask_invalid_shape() {
2597 let segmentation = Array3::from_elem((4, 4, 3), 128u8);
2598 let result = yolo_segmentation_to_mask(segmentation.view(), 128);
2599
2600 assert!(result.is_err());
2601 assert!(matches!(
2602 result,
2603 Err(crate::DecoderError::InvalidShape(s)) if s.contains("(H, W, 1)")
2604 ));
2605 }
2606
2607 #[test]
2612 fn test_protobox_clamps_edge_coordinates() {
2613 let protos = Array3::<f32>::zeros((16, 16, 4));
2615 let view = protos.view();
2616 let roi = BoundingBox {
2617 xmin: 0.5,
2618 ymin: 0.5,
2619 xmax: 1.0,
2620 ymax: 1.0,
2621 };
2622 let result = protobox(&view, &roi);
2623 assert!(result.is_ok(), "protobox should accept xmax=1.0");
2624 let (cropped, _roi_norm) = result.unwrap();
2625 assert!(cropped.shape()[0] > 0);
2627 assert!(cropped.shape()[1] > 0);
2628 assert_eq!(cropped.shape()[2], 4);
2629 }
2630
2631 #[test]
2632 fn test_protobox_rejects_wildly_out_of_range() {
2633 let protos = Array3::<f32>::zeros((16, 16, 4));
2635 let view = protos.view();
2636 let roi = BoundingBox {
2637 xmin: 0.0,
2638 ymin: 0.0,
2639 xmax: 3.0,
2640 ymax: 3.0,
2641 };
2642 let result = protobox(&view, &roi);
2643 assert!(
2644 matches!(result, Err(crate::DecoderError::InvalidShape(s)) if s.contains("un-normalized")),
2645 "protobox should reject coords > NORM_LIMIT"
2646 );
2647 }
2648
2649 #[test]
2650 fn test_protobox_accepts_slightly_over_one() {
2651 let protos = Array3::<f32>::zeros((16, 16, 4));
2653 let view = protos.view();
2654 let roi = BoundingBox {
2655 xmin: 0.0,
2656 ymin: 0.0,
2657 xmax: 1.5,
2658 ymax: 1.5,
2659 };
2660 let result = protobox(&view, &roi);
2661 assert!(
2662 result.is_ok(),
2663 "protobox should accept coords <= NORM_LIMIT (2.0)"
2664 );
2665 let (cropped, _roi_norm) = result.unwrap();
2666 assert_eq!(cropped.shape()[0], 16);
2668 assert_eq!(cropped.shape()[1], 16);
2669 }
2670
2671 #[test]
2672 fn test_segdet_float_proto_no_panic() {
2673 let num_proposals = 100; let num_classes = 80;
2677 let num_mask_coeffs = 32;
2678 let rows = 4 + num_classes + num_mask_coeffs; let mut data = vec![0.0f32; rows * num_proposals];
2684 for i in 0..num_proposals {
2685 let row = |r: usize| r * num_proposals + i;
2686 data[row(0)] = 320.0; data[row(1)] = 320.0; data[row(2)] = 50.0; data[row(3)] = 50.0; data[row(4)] = 0.9; }
2692 let boxes = ndarray::Array2::from_shape_vec((rows, num_proposals), data).unwrap();
2693
2694 let protos = ndarray::Array3::<f32>::zeros((160, 160, num_mask_coeffs));
2699
2700 let mut output_boxes = Vec::with_capacity(300);
2701
2702 let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
2704 boxes.view(),
2705 protos.view(),
2706 0.5,
2707 0.7,
2708 Some(Nms::default()),
2709 MAX_NMS_CANDIDATES,
2710 300,
2711 None,
2712 None,
2713 &mut output_boxes,
2714 );
2715
2716 assert!(!output_boxes.is_empty());
2718 let coeffs_shape = proto_data.mask_coefficients.shape();
2719 assert_eq!(coeffs_shape[0], output_boxes.len());
2720 assert_eq!(coeffs_shape[1], num_mask_coeffs);
2722 }
2723
2724 #[test]
2739 fn test_pre_nms_cap_truncates_excess_candidates() {
2740 let n: usize = 50_000;
2741 let num_classes = 1;
2742
2743 let mut boxes_data = Vec::with_capacity(n * 4);
2747 let mut scores_data = Vec::with_capacity(n * num_classes);
2748 for i in 0..n {
2749 boxes_data.extend_from_slice(&[0.1f32, 0.1, 0.5, 0.5]);
2750 scores_data.push(0.99 - (i as f32) * 1e-7);
2753 }
2754 let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
2755 let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();
2756
2757 let result = impl_yolo_segdet_get_boxes::<XYXY, _, _>(
2758 boxes.view(),
2759 scores.view(),
2760 0.1,
2761 1.0, Some(Nms::ClassAgnostic), crate::yolo::MAX_NMS_CANDIDATES, usize::MAX, );
2766
2767 assert_eq!(
2768 result.len(),
2769 crate::yolo::MAX_NMS_CANDIDATES,
2770 "pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
2771 result.len()
2772 );
2773 let top_score = result[0].0.score;
2776 assert!(
2777 top_score > 0.98,
2778 "highest-ranked survivor should have the largest score, got {top_score}"
2779 );
2780 }
2781
2782 #[test]
2787 fn test_pre_nms_cap_truncates_excess_candidates_quant() {
2788 use crate::Quantization;
2789 let n: usize = 50_000;
2790 let num_classes = 1;
2791
2792 let boxes_data = (0..n).flat_map(|_| [10i8, 10, 50, 50]).collect::<Vec<_>>();
2795 let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
2796 let quant_boxes = Quantization {
2797 scale: 0.01,
2798 zero_point: 0,
2799 };
2800
2801 let scores_data: Vec<u8> = (0..n)
2806 .map(|i| 250u8.saturating_sub((i % 200) as u8))
2807 .collect();
2808 let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();
2809 let quant_scores = Quantization {
2810 scale: 0.00392,
2811 zero_point: 0,
2812 };
2813
2814 let result = impl_yolo_split_segdet_quant_get_boxes::<XYXY, _, _>(
2815 (boxes.view(), quant_boxes),
2816 (scores.view(), quant_scores),
2817 0.1,
2818 1.0, Some(Nms::ClassAgnostic), crate::yolo::MAX_NMS_CANDIDATES, usize::MAX, );
2823
2824 assert_eq!(
2825 result.len(),
2826 crate::yolo::MAX_NMS_CANDIDATES,
2827 "quant path pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
2828 result.len()
2829 );
2830 }
2831
2832 #[test]
2846 fn segdet_combined_tensor_pairs_detection_with_matching_mask_row() {
2847 let nc = 2; let nm = 2; let n = 3; let feat = 4 + nc + nm; let mut data = vec![0.0f32; feat * n];
2870 let set = |d: &mut [f32], r: usize, c: usize, v: f32| d[r * n + c] = v;
2871 set(&mut data, 0, 0, 0.2);
2872 set(&mut data, 1, 0, 0.2);
2873 set(&mut data, 2, 0, 0.1);
2874 set(&mut data, 3, 0, 0.1);
2875 set(&mut data, 0, 1, 0.5);
2876 set(&mut data, 1, 1, 0.5);
2877 set(&mut data, 2, 1, 0.1);
2878 set(&mut data, 3, 1, 0.1);
2879 set(&mut data, 0, 2, 0.8);
2880 set(&mut data, 1, 2, 0.8);
2881 set(&mut data, 2, 2, 0.1);
2882 set(&mut data, 3, 2, 0.1);
2883 set(&mut data, 4, 0, 0.9);
2884 set(&mut data, 4, 2, 0.8);
2885 set(&mut data, 6, 0, 3.0);
2886 set(&mut data, 7, 0, 3.0);
2887 set(&mut data, 6, 2, -3.0);
2888 set(&mut data, 7, 2, -3.0);
2889
2890 let output = Array2::from_shape_vec((feat, n), data).unwrap();
2891 let protos = Array3::<f32>::from_elem((8, 8, nm), 1.0);
2892
2893 let mut boxes: Vec<DetectBox> = Vec::with_capacity(10);
2894 let mut masks: Vec<Segmentation> = Vec::with_capacity(10);
2895 decode_yolo_segdet_float(
2896 output.view(),
2897 protos.view(),
2898 0.5,
2899 0.5,
2900 Some(Nms::ClassAgnostic),
2901 &mut boxes,
2902 &mut masks,
2903 )
2904 .unwrap();
2905
2906 assert_eq!(
2907 boxes.len(),
2908 2,
2909 "two anchors above threshold should survive (a0 score=0.9, a2 score=0.8); got {}",
2910 boxes.len()
2911 );
2912
2913 for (b, m) in boxes.iter().zip(masks.iter()) {
2919 let cx = (b.bbox.xmin + b.bbox.xmax) * 0.5;
2920 let mean = {
2921 let s = &m.segmentation;
2922 let total: u32 = s.iter().map(|&v| v as u32).sum();
2923 total as f32 / s.len() as f32
2924 };
2925 if cx < 0.3 {
2926 assert!(
2928 mean > 200.0,
2929 "anchor 0 detection (centre {cx:.2}) should have high-value mask; got mean {mean}"
2930 );
2931 } else if cx > 0.7 {
2932 assert!(
2934 mean < 50.0,
2935 "anchor 2 detection (centre {cx:.2}) should have low-value mask; got mean {mean}"
2936 );
2937 } else {
2938 panic!("unexpected detection centre {cx:.2}");
2939 }
2940 }
2941 }
2942
2943 fn make_float_boxes(scores: &[f32]) -> Vec<(DetectBox, ())> {
2949 scores
2950 .iter()
2951 .enumerate()
2952 .map(|(i, &s)| {
2953 (
2954 DetectBox {
2955 bbox: BoundingBox {
2956 xmin: 0.0,
2957 ymin: 0.0,
2958 xmax: 1.0,
2959 ymax: 1.0,
2960 },
2961 score: s,
2962 label: i,
2963 },
2964 (),
2965 )
2966 })
2967 .collect()
2968 }
2969
2970 fn make_quant_boxes(scores: &[i8]) -> Vec<(DetectBoxQuantized<i8>, ())> {
2972 scores
2973 .iter()
2974 .enumerate()
2975 .map(|(i, &s)| {
2976 (
2977 DetectBoxQuantized {
2978 bbox: BoundingBox {
2979 xmin: 0.0,
2980 ymin: 0.0,
2981 xmax: 1.0,
2982 ymax: 1.0,
2983 },
2984 score: s,
2985 label: i,
2986 },
2987 (),
2988 )
2989 })
2990 .collect()
2991 }
2992
2993 #[test]
2994 fn truncate_float_top_k_zero_is_unbounded() {
2995 let mut boxes = make_float_boxes(&[0.9, 0.1, 0.5, 0.3, 0.7]);
2996 let original_len = boxes.len();
2997 truncate_to_top_k_by_score(&mut boxes, 0);
2998 assert_eq!(
2999 boxes.len(),
3000 original_len,
3001 "top_k=0 should keep all candidates (no-limit semantics)"
3002 );
3003 }
3004
3005 #[test]
3006 fn truncate_float_top_k_normal() {
3007 let mut boxes = make_float_boxes(&[0.9, 0.1, 0.5, 0.3, 0.7]);
3008 truncate_to_top_k_by_score(&mut boxes, 3);
3009 assert_eq!(boxes.len(), 3);
3010 let mut retained: Vec<f32> = boxes.iter().map(|(b, _)| b.score).collect();
3012 retained.sort_by(|a, b| b.total_cmp(a));
3013 assert_eq!(retained, vec![0.9, 0.7, 0.5]);
3014 }
3015
3016 #[test]
3017 fn truncate_float_top_k_noop_when_under_cap() {
3018 let mut boxes = make_float_boxes(&[0.9, 0.5]);
3019 truncate_to_top_k_by_score(&mut boxes, 10);
3020 assert_eq!(boxes.len(), 2, "should be no-op when len <= top_k");
3021 }
3022
3023 #[test]
3024 fn truncate_quant_top_k_zero_is_unbounded() {
3025 let mut boxes = make_quant_boxes(&[120, -50, 30, -10, 80]);
3026 let original_len = boxes.len();
3027 truncate_to_top_k_by_score_quant(&mut boxes, 0);
3028 assert_eq!(
3029 boxes.len(),
3030 original_len,
3031 "top_k=0 should keep all candidates (no-limit semantics)"
3032 );
3033 }
3034
3035 #[test]
3036 fn truncate_quant_top_k_normal() {
3037 let mut boxes = make_quant_boxes(&[120, -50, 30, -10, 80]);
3038 truncate_to_top_k_by_score_quant(&mut boxes, 3);
3039 assert_eq!(boxes.len(), 3);
3040 let mut retained: Vec<i8> = boxes.iter().map(|(b, _)| b.score).collect();
3041 retained.sort_by(|a, b| b.cmp(a));
3042 assert_eq!(retained, vec![120, 80, 30]);
3043 }
3044
3045 #[test]
3046 fn truncate_quant_top_k_noop_when_under_cap() {
3047 let mut boxes = make_quant_boxes(&[120, 80]);
3048 truncate_to_top_k_by_score_quant(&mut boxes, 10);
3049 assert_eq!(boxes.len(), 2, "should be no-op when len <= top_k");
3050 }
3051}