1use std::fmt::Debug;
5
6use ndarray::{
7 parallel::prelude::{IntoParallelIterator, ParallelIterator},
8 s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
9};
10use num_traits::{AsPrimitive, Float, PrimInt, Signed};
11
12use crate::{
13 byte::{
14 nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
15 postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
16 },
17 configs::Nms,
18 dequant_detect_box,
19 float::{
20 nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
21 postprocess_boxes_float, postprocess_boxes_index_float,
22 },
23 BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoTensor,
24 Quantization, Segmentation, XYWH, XYXY,
25};
26
27fn dispatch_nms_float(nms: Option<Nms>, iou: f32, boxes: Vec<DetectBox>) -> Vec<DetectBox> {
29 match nms {
30 Some(Nms::ClassAgnostic) => nms_float(iou, boxes),
31 Some(Nms::ClassAware) => nms_class_aware_float(iou, boxes),
32 None => boxes, }
34}
35
36pub(super) fn dispatch_nms_extra_float<E: Send + Sync>(
39 nms: Option<Nms>,
40 iou: f32,
41 boxes: Vec<(DetectBox, E)>,
42) -> Vec<(DetectBox, E)> {
43 match nms {
44 Some(Nms::ClassAgnostic) => nms_extra_float(iou, boxes),
45 Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, boxes),
46 None => boxes, }
48}
49
50fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
53 nms: Option<Nms>,
54 iou: f32,
55 boxes: Vec<DetectBoxQuantized<SCORE>>,
56) -> Vec<DetectBoxQuantized<SCORE>> {
57 match nms {
58 Some(Nms::ClassAgnostic) => nms_int(iou, boxes),
59 Some(Nms::ClassAware) => nms_class_aware_int(iou, boxes),
60 None => boxes, }
62}
63
64fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
67 nms: Option<Nms>,
68 iou: f32,
69 boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
70) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
71 match nms {
72 Some(Nms::ClassAgnostic) => nms_extra_int(iou, boxes),
73 Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, boxes),
74 None => boxes, }
76}
77
78pub fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
85 output: (ArrayView2<BOX>, Quantization),
86 score_threshold: f32,
87 iou_threshold: f32,
88 nms: Option<Nms>,
89 output_boxes: &mut Vec<DetectBox>,
90) where
91 f32: AsPrimitive<BOX>,
92{
93 impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
94}
95
96pub fn decode_yolo_det_float<T>(
103 output: ArrayView2<T>,
104 score_threshold: f32,
105 iou_threshold: f32,
106 nms: Option<Nms>,
107 output_boxes: &mut Vec<DetectBox>,
108) where
109 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
110 f32: AsPrimitive<T>,
111{
112 impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
113}
114
115pub fn decode_yolo_segdet_quant<
127 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
128 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
129>(
130 boxes: (ArrayView2<BOX>, Quantization),
131 protos: (ArrayView3<PROTO>, Quantization),
132 score_threshold: f32,
133 iou_threshold: f32,
134 nms: Option<Nms>,
135 output_boxes: &mut Vec<DetectBox>,
136 output_masks: &mut Vec<Segmentation>,
137) -> Result<(), crate::DecoderError>
138where
139 f32: AsPrimitive<BOX>,
140{
141 impl_yolo_segdet_quant::<XYWH, _, _>(
142 boxes,
143 protos,
144 score_threshold,
145 iou_threshold,
146 nms,
147 output_boxes,
148 output_masks,
149 )
150}
151
152pub fn decode_yolo_segdet_float<T>(
164 boxes: ArrayView2<T>,
165 protos: ArrayView3<T>,
166 score_threshold: f32,
167 iou_threshold: f32,
168 nms: Option<Nms>,
169 output_boxes: &mut Vec<DetectBox>,
170 output_masks: &mut Vec<Segmentation>,
171) -> Result<(), crate::DecoderError>
172where
173 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
174 f32: AsPrimitive<T>,
175{
176 impl_yolo_segdet_float::<XYWH, _, _>(
177 boxes,
178 protos,
179 score_threshold,
180 iou_threshold,
181 nms,
182 output_boxes,
183 output_masks,
184 )
185}
186
187pub fn decode_yolo_split_det_quant<
199 BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
200 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
201>(
202 boxes: (ArrayView2<BOX>, Quantization),
203 scores: (ArrayView2<SCORE>, Quantization),
204 score_threshold: f32,
205 iou_threshold: f32,
206 nms: Option<Nms>,
207 output_boxes: &mut Vec<DetectBox>,
208) where
209 f32: AsPrimitive<SCORE>,
210{
211 impl_yolo_split_quant::<XYWH, _, _>(
212 boxes,
213 scores,
214 score_threshold,
215 iou_threshold,
216 nms,
217 output_boxes,
218 );
219}
220
221pub fn decode_yolo_split_det_float<T>(
233 boxes: ArrayView2<T>,
234 scores: ArrayView2<T>,
235 score_threshold: f32,
236 iou_threshold: f32,
237 nms: Option<Nms>,
238 output_boxes: &mut Vec<DetectBox>,
239) where
240 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
241 f32: AsPrimitive<T>,
242{
243 impl_yolo_split_float::<XYWH, _, _>(
244 boxes,
245 scores,
246 score_threshold,
247 iou_threshold,
248 nms,
249 output_boxes,
250 );
251}
252
253#[allow(clippy::too_many_arguments)]
267pub fn decode_yolo_split_segdet<
268 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
269 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
270 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
271 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
272>(
273 boxes: (ArrayView2<BOX>, Quantization),
274 scores: (ArrayView2<SCORE>, Quantization),
275 mask_coeff: (ArrayView2<MASK>, Quantization),
276 protos: (ArrayView3<PROTO>, Quantization),
277 score_threshold: f32,
278 iou_threshold: f32,
279 nms: Option<Nms>,
280 output_boxes: &mut Vec<DetectBox>,
281 output_masks: &mut Vec<Segmentation>,
282) -> Result<(), crate::DecoderError>
283where
284 f32: AsPrimitive<SCORE>,
285{
286 impl_yolo_split_segdet_quant::<XYWH, _, _, _, _>(
287 boxes,
288 scores,
289 mask_coeff,
290 protos,
291 score_threshold,
292 iou_threshold,
293 nms,
294 output_boxes,
295 output_masks,
296 )
297}
298
299#[allow(clippy::too_many_arguments)]
313pub fn decode_yolo_split_segdet_float<T>(
314 boxes: ArrayView2<T>,
315 scores: ArrayView2<T>,
316 mask_coeff: ArrayView2<T>,
317 protos: ArrayView3<T>,
318 score_threshold: f32,
319 iou_threshold: f32,
320 nms: Option<Nms>,
321 output_boxes: &mut Vec<DetectBox>,
322 output_masks: &mut Vec<Segmentation>,
323) -> Result<(), crate::DecoderError>
324where
325 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
326 f32: AsPrimitive<T>,
327{
328 impl_yolo_split_segdet_float::<XYWH, _, _, _, _>(
329 boxes,
330 scores,
331 mask_coeff,
332 protos,
333 score_threshold,
334 iou_threshold,
335 nms,
336 output_boxes,
337 output_masks,
338 )
339}
340
341pub fn decode_yolo_end_to_end_det_float<T>(
356 output: ArrayView2<T>,
357 score_threshold: f32,
358 output_boxes: &mut Vec<DetectBox>,
359) -> Result<(), crate::DecoderError>
360where
361 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
362 f32: AsPrimitive<T>,
363{
364 if output.shape()[0] < 6 {
366 return Err(crate::DecoderError::InvalidShape(format!(
367 "End-to-end detection output requires at least 6 rows, got {}",
368 output.shape()[0]
369 )));
370 }
371
372 let boxes = output.slice(s![0..4, ..]).reversed_axes();
374 let scores = output.slice(s![4..5, ..]).reversed_axes();
375 let classes = output.slice(s![5, ..]);
376 let mut boxes =
377 postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
378 boxes.truncate(output_boxes.capacity());
379 output_boxes.clear();
380 for (mut b, i) in boxes.into_iter() {
381 b.label = classes[i].as_() as usize;
382 output_boxes.push(b);
383 }
384 Ok(())
386}
387
388pub fn decode_yolo_end_to_end_segdet_float<T>(
406 output: ArrayView2<T>,
407 protos: ArrayView3<T>,
408 score_threshold: f32,
409 output_boxes: &mut Vec<DetectBox>,
410 output_masks: &mut Vec<crate::Segmentation>,
411) -> Result<(), crate::DecoderError>
412where
413 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
414 f32: AsPrimitive<T>,
415{
416 let (boxes, scores, classes, mask_coeff) =
417 postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
418 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
419 boxes,
420 scores,
421 classes,
422 score_threshold,
423 output_boxes.capacity(),
424 );
425
426 impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
429}
430
431pub fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
440 boxes: ArrayView2<T>,
441 scores: ArrayView2<T>,
442 classes: ArrayView2<T>,
443 score_threshold: f32,
444 output_boxes: &mut Vec<DetectBox>,
445) -> Result<(), crate::DecoderError> {
446 let n = boxes.shape()[1];
447
448 output_boxes.clear();
449
450 let (boxes, scores, classes) = postprocess_yolo_split_end_to_end_det(boxes, scores, &classes)?;
451
452 for i in 0..n {
453 let score: f32 = scores[[i, 0]].as_();
454 if score < score_threshold {
455 continue;
456 }
457 if output_boxes.len() >= output_boxes.capacity() {
458 break;
459 }
460 output_boxes.push(DetectBox {
461 bbox: BoundingBox {
462 xmin: boxes[[i, 0]].as_(),
463 ymin: boxes[[i, 1]].as_(),
464 xmax: boxes[[i, 2]].as_(),
465 ymax: boxes[[i, 3]].as_(),
466 },
467 score,
468 label: classes[i].as_() as usize,
469 });
470 }
471 Ok(())
472}
473
474#[allow(clippy::too_many_arguments)]
483pub fn decode_yolo_split_end_to_end_segdet_float<T>(
484 boxes: ArrayView2<T>,
485 scores: ArrayView2<T>,
486 classes: ArrayView2<T>,
487 mask_coeff: ArrayView2<T>,
488 protos: ArrayView3<T>,
489 score_threshold: f32,
490 output_boxes: &mut Vec<DetectBox>,
491 output_masks: &mut Vec<crate::Segmentation>,
492) -> Result<(), crate::DecoderError>
493where
494 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
495 f32: AsPrimitive<T>,
496{
497 let (boxes, scores, classes, mask_coeff) =
498 postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
499 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
500 boxes,
501 scores,
502 classes,
503 score_threshold,
504 output_boxes.capacity(),
505 );
506
507 impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
508}
509
510#[allow(clippy::type_complexity)]
511pub(crate) fn postprocess_yolo_end_to_end_segdet<'a, T>(
512 output: &'a ArrayView2<'_, T>,
513 num_protos: usize,
514) -> Result<
515 (
516 ArrayView2<'a, T>,
517 ArrayView2<'a, T>,
518 ArrayView1<'a, T>,
519 ArrayView2<'a, T>,
520 ),
521 crate::DecoderError,
522> {
523 if output.shape()[0] < 7 {
525 return Err(crate::DecoderError::InvalidShape(format!(
526 "End-to-end segdet output requires at least 7 rows, got {}",
527 output.shape()[0]
528 )));
529 }
530
531 let num_mask_coeffs = output.shape()[0] - 6;
532 if num_mask_coeffs != num_protos {
533 return Err(crate::DecoderError::InvalidShape(format!(
534 "Mask coefficients count ({}) doesn't match protos count ({})",
535 num_mask_coeffs, num_protos
536 )));
537 }
538
539 let boxes = output.slice(s![0..4, ..]).reversed_axes();
541 let scores = output.slice(s![4..5, ..]).reversed_axes();
542 let classes = output.slice(s![5, ..]);
543 let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
544 Ok((boxes, scores, classes, mask_coeff))
545}
546
547#[allow(clippy::type_complexity)]
554pub(crate) fn postprocess_yolo_split_end_to_end_det<'a, 'b, 'c, BOXES, SCORES, CLASS>(
555 boxes: ArrayView2<'a, BOXES>,
556 scores: ArrayView2<'b, SCORES>,
557 classes: &'c ArrayView2<CLASS>,
558) -> Result<
559 (
560 ArrayView2<'a, BOXES>,
561 ArrayView2<'b, SCORES>,
562 ArrayView1<'c, CLASS>,
563 ),
564 crate::DecoderError,
565> {
566 let num_boxes = boxes.shape()[1];
567 if boxes.shape()[0] != 4 {
568 return Err(crate::DecoderError::InvalidShape(format!(
569 "Split end-to-end box_coords must be 4, got {}",
570 boxes.shape()[0]
571 )));
572 }
573
574 if scores.shape()[0] != 1 {
575 return Err(crate::DecoderError::InvalidShape(format!(
576 "Split end-to-end scores num_classes must be 1, got {}",
577 scores.shape()[0]
578 )));
579 }
580
581 if classes.shape()[0] != 1 {
582 return Err(crate::DecoderError::InvalidShape(format!(
583 "Split end-to-end classes num_classes must be 1, got {}",
584 classes.shape()[0]
585 )));
586 }
587
588 if scores.shape()[1] != num_boxes {
589 return Err(crate::DecoderError::InvalidShape(format!(
590 "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
591 num_boxes,
592 scores.shape()[1]
593 )));
594 }
595
596 if classes.shape()[1] != num_boxes {
597 return Err(crate::DecoderError::InvalidShape(format!(
598 "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
599 num_boxes,
600 classes.shape()[1]
601 )));
602 }
603
604 let boxes = boxes.reversed_axes();
605 let scores = scores.reversed_axes();
606 let classes = classes.slice(s![0, ..]);
607 Ok((boxes, scores, classes))
608}
609
610#[allow(clippy::type_complexity)]
613pub(crate) fn postprocess_yolo_split_end_to_end_segdet<
614 'a,
615 'b,
616 'c,
617 'd,
618 BOXES,
619 SCORES,
620 CLASS,
621 MASK,
622>(
623 boxes: ArrayView2<'a, BOXES>,
624 scores: ArrayView2<'b, SCORES>,
625 classes: &'c ArrayView2<CLASS>,
626 mask_coeff: ArrayView2<'d, MASK>,
627) -> Result<
628 (
629 ArrayView2<'a, BOXES>,
630 ArrayView2<'b, SCORES>,
631 ArrayView1<'c, CLASS>,
632 ArrayView2<'d, MASK>,
633 ),
634 crate::DecoderError,
635> {
636 let num_boxes = boxes.shape()[1];
637 if boxes.shape()[0] != 4 {
638 return Err(crate::DecoderError::InvalidShape(format!(
639 "Split end-to-end box_coords must be 4, got {}",
640 boxes.shape()[0]
641 )));
642 }
643
644 if scores.shape()[0] != 1 {
645 return Err(crate::DecoderError::InvalidShape(format!(
646 "Split end-to-end scores num_classes must be 1, got {}",
647 scores.shape()[0]
648 )));
649 }
650
651 if classes.shape()[0] != 1 {
652 return Err(crate::DecoderError::InvalidShape(format!(
653 "Split end-to-end classes num_classes must be 1, got {}",
654 classes.shape()[0]
655 )));
656 }
657
658 if scores.shape()[1] != num_boxes {
659 return Err(crate::DecoderError::InvalidShape(format!(
660 "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
661 num_boxes,
662 scores.shape()[1]
663 )));
664 }
665
666 if classes.shape()[1] != num_boxes {
667 return Err(crate::DecoderError::InvalidShape(format!(
668 "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
669 num_boxes,
670 classes.shape()[1]
671 )));
672 }
673
674 if mask_coeff.shape()[1] != num_boxes {
675 return Err(crate::DecoderError::InvalidShape(format!(
676 "Split end-to-end mask_coeff must have same num_boxes as boxes ({}), got {}",
677 num_boxes,
678 mask_coeff.shape()[1]
679 )));
680 }
681
682 let boxes = boxes.reversed_axes();
683 let scores = scores.reversed_axes();
684 let classes = classes.slice(s![0, ..]);
685 let mask_coeff = mask_coeff.reversed_axes();
686 Ok((boxes, scores, classes, mask_coeff))
687}
688pub(crate) fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
693 output: (ArrayView2<T>, Quantization),
694 score_threshold: f32,
695 iou_threshold: f32,
696 nms: Option<Nms>,
697 output_boxes: &mut Vec<DetectBox>,
698) where
699 f32: AsPrimitive<T>,
700{
701 let (boxes, quant_boxes) = output;
702 let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
703
704 let boxes = {
705 let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
706 postprocess_boxes_quant::<B, _, _>(
707 score_threshold,
708 boxes_tensor,
709 scores_tensor,
710 quant_boxes,
711 )
712 };
713
714 let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
715 let len = output_boxes.capacity().min(boxes.len());
716 output_boxes.clear();
717 for b in boxes.iter().take(len) {
718 output_boxes.push(dequant_detect_box(b, quant_boxes));
719 }
720}
721
722pub(crate) fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
727 output: ArrayView2<T>,
728 score_threshold: f32,
729 iou_threshold: f32,
730 nms: Option<Nms>,
731 output_boxes: &mut Vec<DetectBox>,
732) where
733 f32: AsPrimitive<T>,
734{
735 let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
736 let boxes =
737 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
738 let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
739 let len = output_boxes.capacity().min(boxes.len());
740 output_boxes.clear();
741 for b in boxes.into_iter().take(len) {
742 output_boxes.push(b);
743 }
744}
745
746pub(crate) fn impl_yolo_split_quant<
756 B: BBoxTypeTrait,
757 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
758 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
759>(
760 boxes: (ArrayView2<BOX>, Quantization),
761 scores: (ArrayView2<SCORE>, Quantization),
762 score_threshold: f32,
763 iou_threshold: f32,
764 nms: Option<Nms>,
765 output_boxes: &mut Vec<DetectBox>,
766) where
767 f32: AsPrimitive<SCORE>,
768{
769 let (boxes_tensor, quant_boxes) = boxes;
770 let (scores_tensor, quant_scores) = scores;
771
772 let boxes_tensor = boxes_tensor.reversed_axes();
773 let scores_tensor = scores_tensor.reversed_axes();
774
775 let boxes = {
776 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
777 postprocess_boxes_quant::<B, _, _>(
778 score_threshold,
779 boxes_tensor,
780 scores_tensor,
781 quant_boxes,
782 )
783 };
784
785 let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
786 let len = output_boxes.capacity().min(boxes.len());
787 output_boxes.clear();
788 for b in boxes.iter().take(len) {
789 output_boxes.push(dequant_detect_box(b, quant_scores));
790 }
791}
792
793pub(crate) fn impl_yolo_split_float<
802 B: BBoxTypeTrait,
803 BOX: Float + AsPrimitive<f32> + Send + Sync,
804 SCORE: Float + AsPrimitive<f32> + Send + Sync,
805>(
806 boxes_tensor: ArrayView2<BOX>,
807 scores_tensor: ArrayView2<SCORE>,
808 score_threshold: f32,
809 iou_threshold: f32,
810 nms: Option<Nms>,
811 output_boxes: &mut Vec<DetectBox>,
812) where
813 f32: AsPrimitive<SCORE>,
814{
815 let boxes_tensor = boxes_tensor.reversed_axes();
816 let scores_tensor = scores_tensor.reversed_axes();
817 let boxes =
818 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
819 let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
820 let len = output_boxes.capacity().min(boxes.len());
821 output_boxes.clear();
822 for b in boxes.into_iter().take(len) {
823 output_boxes.push(b);
824 }
825}
826
827pub(crate) fn impl_yolo_segdet_quant<
837 B: BBoxTypeTrait,
838 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
839 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
840>(
841 boxes: (ArrayView2<BOX>, Quantization),
842 protos: (ArrayView3<PROTO>, Quantization),
843 score_threshold: f32,
844 iou_threshold: f32,
845 nms: Option<Nms>,
846 output_boxes: &mut Vec<DetectBox>,
847 output_masks: &mut Vec<Segmentation>,
848) -> Result<(), crate::DecoderError>
849where
850 f32: AsPrimitive<BOX>,
851{
852 let (boxes, quant_boxes) = boxes;
853 let num_protos = protos.0.dim().2;
854
855 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
856 let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
857 (boxes_tensor, quant_boxes),
858 (scores_tensor, quant_boxes),
859 score_threshold,
860 iou_threshold,
861 nms,
862 output_boxes.capacity(),
863 );
864
865 impl_yolo_split_segdet_quant_process_masks::<_, _>(
866 boxes,
867 (mask_tensor, quant_boxes),
868 protos,
869 output_boxes,
870 output_masks,
871 )
872}
873
874pub(crate) fn impl_yolo_segdet_float<
884 B: BBoxTypeTrait,
885 BOX: Float + AsPrimitive<f32> + Send + Sync,
886 PROTO: Float + AsPrimitive<f32> + Send + Sync,
887>(
888 boxes: ArrayView2<BOX>,
889 protos: ArrayView3<PROTO>,
890 score_threshold: f32,
891 iou_threshold: f32,
892 nms: Option<Nms>,
893 output_boxes: &mut Vec<DetectBox>,
894 output_masks: &mut Vec<Segmentation>,
895) -> Result<(), crate::DecoderError>
896where
897 f32: AsPrimitive<BOX>,
898{
899 let num_protos = protos.dim().2;
900 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
901 let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
902 boxes_tensor,
903 scores_tensor,
904 score_threshold,
905 iou_threshold,
906 nms,
907 output_boxes.capacity(),
908 );
909 impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
910}
911
912pub(crate) fn impl_yolo_segdet_get_boxes<
913 B: BBoxTypeTrait,
914 BOX: Float + AsPrimitive<f32> + Send + Sync,
915 SCORE: Float + AsPrimitive<f32> + Send + Sync,
916>(
917 boxes_tensor: ArrayView2<BOX>,
918 scores_tensor: ArrayView2<SCORE>,
919 score_threshold: f32,
920 iou_threshold: f32,
921 nms: Option<Nms>,
922 max_boxes: usize,
923) -> Vec<(DetectBox, usize)>
924where
925 f32: AsPrimitive<SCORE>,
926{
927 let boxes = postprocess_boxes_index_float::<B, _, _>(
928 score_threshold.as_(),
929 boxes_tensor,
930 scores_tensor,
931 );
932 let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
933 boxes.truncate(max_boxes);
934 boxes
935}
936
937pub(crate) fn impl_yolo_end_to_end_segdet_get_boxes<
938 B: BBoxTypeTrait,
939 BOX: Float + AsPrimitive<f32> + Send + Sync,
940 SCORE: Float + AsPrimitive<f32> + Send + Sync,
941 CLASS: AsPrimitive<f32> + Send + Sync,
942>(
943 boxes: ArrayView2<BOX>,
944 scores: ArrayView2<SCORE>,
945 classes: ArrayView1<CLASS>,
946 score_threshold: f32,
947 max_boxes: usize,
948) -> Vec<(DetectBox, usize)>
949where
950 f32: AsPrimitive<SCORE>,
951{
952 let mut boxes = postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes, scores);
953 boxes.truncate(max_boxes);
954 for (b, ind) in &mut boxes {
955 b.label = classes[*ind].as_().round() as usize;
956 }
957 boxes
958}
959
960pub(crate) fn impl_yolo_split_segdet_process_masks<
961 MASK: Float + AsPrimitive<f32> + Send + Sync,
962 PROTO: Float + AsPrimitive<f32> + Send + Sync,
963>(
964 boxes: Vec<(DetectBox, usize)>,
965 masks_tensor: ArrayView2<MASK>,
966 protos_tensor: ArrayView3<PROTO>,
967 output_boxes: &mut Vec<DetectBox>,
968 output_masks: &mut Vec<Segmentation>,
969) -> Result<(), crate::DecoderError> {
970 let boxes = decode_segdet_f32(boxes, masks_tensor, protos_tensor)?;
971 output_boxes.clear();
972 output_masks.clear();
973 for (b, m) in boxes.into_iter() {
974 output_boxes.push(b);
975 output_masks.push(Segmentation {
976 xmin: b.bbox.xmin,
977 ymin: b.bbox.ymin,
978 xmax: b.bbox.xmax,
979 ymax: b.bbox.ymax,
980 segmentation: m,
981 });
982 }
983 Ok(())
984}
985pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
989 B: BBoxTypeTrait,
990 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
991 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
992>(
993 boxes: (ArrayView2<BOX>, Quantization),
994 scores: (ArrayView2<SCORE>, Quantization),
995 score_threshold: f32,
996 iou_threshold: f32,
997 nms: Option<Nms>,
998 max_boxes: usize,
999) -> Vec<(DetectBox, usize)>
1000where
1001 f32: AsPrimitive<SCORE>,
1002{
1003 let (boxes_tensor, quant_boxes) = boxes;
1004 let (scores_tensor, quant_scores) = scores;
1005
1006 let boxes = {
1007 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
1008 postprocess_boxes_index_quant::<B, _, _>(
1009 score_threshold,
1010 boxes_tensor,
1011 scores_tensor,
1012 quant_boxes,
1013 )
1014 };
1015 let mut boxes = dispatch_nms_extra_int(nms, iou_threshold, boxes);
1016 boxes.truncate(max_boxes);
1017 boxes
1018 .into_iter()
1019 .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
1020 .collect()
1021}
1022
1023pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
1024 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1025 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1026>(
1027 boxes: Vec<(DetectBox, usize)>,
1028 mask_coeff: (ArrayView2<MASK>, Quantization),
1029 protos: (ArrayView3<PROTO>, Quantization),
1030 output_boxes: &mut Vec<DetectBox>,
1031 output_masks: &mut Vec<Segmentation>,
1032) -> Result<(), crate::DecoderError> {
1033 let (masks, quant_masks) = mask_coeff;
1034 let (protos, quant_protos) = protos;
1035
1036 let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
1037 output_boxes.clear();
1038 output_masks.clear();
1039 for (b, m) in boxes.into_iter() {
1040 output_boxes.push(b);
1041 output_masks.push(Segmentation {
1042 xmin: b.bbox.xmin,
1043 ymin: b.bbox.ymin,
1044 xmax: b.bbox.xmax,
1045 ymax: b.bbox.ymax,
1046 segmentation: m,
1047 });
1048 }
1049 Ok(())
1050}
1051
1052#[allow(clippy::too_many_arguments)]
1053pub(crate) fn impl_yolo_split_segdet_quant<
1065 B: BBoxTypeTrait,
1066 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1067 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1068 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1069 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1070>(
1071 boxes: (ArrayView2<BOX>, Quantization),
1072 scores: (ArrayView2<SCORE>, Quantization),
1073 mask_coeff: (ArrayView2<MASK>, Quantization),
1074 protos: (ArrayView3<PROTO>, Quantization),
1075 score_threshold: f32,
1076 iou_threshold: f32,
1077 nms: Option<Nms>,
1078 output_boxes: &mut Vec<DetectBox>,
1079 output_masks: &mut Vec<Segmentation>,
1080) -> Result<(), crate::DecoderError>
1081where
1082 f32: AsPrimitive<SCORE>,
1083{
1084 let (boxes_, scores_, mask_coeff_) =
1085 postprocess_yolo_split_segdet(boxes.0, scores.0, mask_coeff.0);
1086 let boxes = (boxes_, boxes.1);
1087 let scores = (scores_, scores.1);
1088 let mask_coeff = (mask_coeff_, mask_coeff.1);
1089
1090 let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1091 boxes,
1092 scores,
1093 score_threshold,
1094 iou_threshold,
1095 nms,
1096 output_boxes.capacity(),
1097 );
1098
1099 impl_yolo_split_segdet_quant_process_masks(
1100 boxes,
1101 mask_coeff,
1102 protos,
1103 output_boxes,
1104 output_masks,
1105 )
1106}
1107
1108#[allow(clippy::too_many_arguments)]
1109pub(crate) fn impl_yolo_split_segdet_float<
1121 B: BBoxTypeTrait,
1122 BOX: Float + AsPrimitive<f32> + Send + Sync,
1123 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1124 MASK: Float + AsPrimitive<f32> + Send + Sync,
1125 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1126>(
1127 boxes_tensor: ArrayView2<BOX>,
1128 scores_tensor: ArrayView2<SCORE>,
1129 mask_tensor: ArrayView2<MASK>,
1130 protos: ArrayView3<PROTO>,
1131 score_threshold: f32,
1132 iou_threshold: f32,
1133 nms: Option<Nms>,
1134 output_boxes: &mut Vec<DetectBox>,
1135 output_masks: &mut Vec<Segmentation>,
1136) -> Result<(), crate::DecoderError>
1137where
1138 f32: AsPrimitive<SCORE>,
1139{
1140 let (boxes_tensor, scores_tensor, mask_tensor) =
1141 postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1142
1143 let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1144 boxes_tensor,
1145 scores_tensor,
1146 score_threshold,
1147 iou_threshold,
1148 nms,
1149 output_boxes.capacity(),
1150 );
1151 impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1152}
1153
1154pub fn impl_yolo_segdet_quant_proto<
1161 B: BBoxTypeTrait,
1162 BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1163 PROTO: PrimInt
1164 + AsPrimitive<i64>
1165 + AsPrimitive<i128>
1166 + AsPrimitive<f32>
1167 + AsPrimitive<i8>
1168 + Send
1169 + Sync,
1170>(
1171 boxes: (ArrayView2<BOX>, Quantization),
1172 protos: (ArrayView3<PROTO>, Quantization),
1173 score_threshold: f32,
1174 iou_threshold: f32,
1175 nms: Option<Nms>,
1176 output_boxes: &mut Vec<DetectBox>,
1177) -> ProtoData
1178where
1179 f32: AsPrimitive<BOX>,
1180{
1181 let (boxes_arr, quant_boxes) = boxes;
1182 let (protos_arr, quant_protos) = protos;
1183 let num_protos = protos_arr.dim().2;
1184
1185 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1186
1187 let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1188 (boxes_tensor, quant_boxes),
1189 (scores_tensor, quant_boxes),
1190 score_threshold,
1191 iou_threshold,
1192 nms,
1193 output_boxes.capacity(),
1194 );
1195
1196 extract_proto_data_quant(
1197 det_indices,
1198 mask_tensor,
1199 quant_boxes,
1200 protos_arr,
1201 quant_protos,
1202 output_boxes,
1203 )
1204}
1205
1206pub(crate) fn impl_yolo_segdet_float_proto<
1209 B: BBoxTypeTrait,
1210 BOX: Float + AsPrimitive<f32> + Send + Sync,
1211 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1212>(
1213 boxes: ArrayView2<BOX>,
1214 protos: ArrayView3<PROTO>,
1215 score_threshold: f32,
1216 iou_threshold: f32,
1217 nms: Option<Nms>,
1218 output_boxes: &mut Vec<DetectBox>,
1219) -> ProtoData
1220where
1221 f32: AsPrimitive<BOX>,
1222{
1223 let num_protos = protos.dim().2;
1224 let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1225
1226 let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1227 boxes_tensor,
1228 scores_tensor,
1229 score_threshold,
1230 iou_threshold,
1231 nms,
1232 output_boxes.capacity(),
1233 );
1234
1235 extract_proto_data_float(boxes, mask_tensor, protos, output_boxes)
1236}
1237
1238#[allow(clippy::too_many_arguments)]
1241pub(crate) fn impl_yolo_split_segdet_float_proto<
1242 B: BBoxTypeTrait,
1243 BOX: Float + AsPrimitive<f32> + Send + Sync,
1244 SCORE: Float + AsPrimitive<f32> + Send + Sync,
1245 MASK: Float + AsPrimitive<f32> + Send + Sync,
1246 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1247>(
1248 boxes_tensor: ArrayView2<BOX>,
1249 scores_tensor: ArrayView2<SCORE>,
1250 mask_tensor: ArrayView2<MASK>,
1251 protos: ArrayView3<PROTO>,
1252 score_threshold: f32,
1253 iou_threshold: f32,
1254 nms: Option<Nms>,
1255 output_boxes: &mut Vec<DetectBox>,
1256) -> ProtoData
1257where
1258 f32: AsPrimitive<SCORE>,
1259{
1260 let (boxes_tensor, scores_tensor, mask_tensor) =
1261 postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1262 let det_indices = impl_yolo_segdet_get_boxes::<B, _, _>(
1263 boxes_tensor,
1264 scores_tensor,
1265 score_threshold,
1266 iou_threshold,
1267 nms,
1268 output_boxes.capacity(),
1269 );
1270
1271 extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1272}
1273
1274pub fn decode_yolo_end_to_end_segdet_float_proto<T>(
1276 output: ArrayView2<T>,
1277 protos: ArrayView3<T>,
1278 score_threshold: f32,
1279 output_boxes: &mut Vec<DetectBox>,
1280) -> Result<ProtoData, crate::DecoderError>
1281where
1282 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1283 f32: AsPrimitive<T>,
1284{
1285 let (boxes, scores, classes, mask_coeff) =
1286 postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
1287 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1288 boxes,
1289 scores,
1290 classes,
1291 score_threshold,
1292 output_boxes.capacity(),
1293 );
1294
1295 Ok(extract_proto_data_float(
1296 boxes,
1297 mask_coeff,
1298 protos,
1299 output_boxes,
1300 ))
1301}
1302
1303#[allow(clippy::too_many_arguments)]
1305pub fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1306 boxes: ArrayView2<T>,
1307 scores: ArrayView2<T>,
1308 classes: ArrayView2<T>,
1309 mask_coeff: ArrayView2<T>,
1310 protos: ArrayView3<T>,
1311 score_threshold: f32,
1312 output_boxes: &mut Vec<DetectBox>,
1313) -> Result<ProtoData, crate::DecoderError>
1314where
1315 T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1316 f32: AsPrimitive<T>,
1317{
1318 let (boxes, scores, classes, mask_coeff) =
1319 postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
1320 let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1321 boxes,
1322 scores,
1323 classes,
1324 score_threshold,
1325 output_boxes.capacity(),
1326 );
1327
1328 Ok(extract_proto_data_float(
1329 boxes,
1330 mask_coeff,
1331 protos,
1332 output_boxes,
1333 ))
1334}
1335
1336pub(super) fn extract_proto_data_float<
1338 MASK: Float + AsPrimitive<f32> + Send + Sync,
1339 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1340>(
1341 det_indices: Vec<(DetectBox, usize)>,
1342 mask_tensor: ArrayView2<MASK>,
1343 protos: ArrayView3<PROTO>,
1344 output_boxes: &mut Vec<DetectBox>,
1345) -> ProtoData {
1346 let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1347 output_boxes.clear();
1348 for (det, idx) in det_indices {
1349 output_boxes.push(det);
1350 let row = mask_tensor.row(idx);
1351 mask_coefficients.push(row.iter().map(|v| v.as_()).collect());
1352 }
1353 let protos_f32 = protos.map(|v| v.as_());
1354 ProtoData {
1355 mask_coefficients,
1356 protos: ProtoTensor::Float(protos_f32),
1357 }
1358}
1359
1360pub(crate) fn extract_proto_data_quant<
1366 MASK: PrimInt + AsPrimitive<f32> + Send + Sync,
1367 PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
1368>(
1369 det_indices: Vec<(DetectBox, usize)>,
1370 mask_tensor: ArrayView2<MASK>,
1371 quant_masks: Quantization,
1372 protos: ArrayView3<PROTO>,
1373 quant_protos: Quantization,
1374 output_boxes: &mut Vec<DetectBox>,
1375) -> ProtoData {
1376 let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1377 output_boxes.clear();
1378 for (det, idx) in det_indices {
1379 output_boxes.push(det);
1380 let row = mask_tensor.row(idx);
1381 mask_coefficients.push(
1382 row.iter()
1383 .map(|v| (v.as_() - quant_masks.zero_point as f32) * quant_masks.scale)
1384 .collect(),
1385 );
1386 }
1387 let protos_i8 = if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1391 let view_i8 =
1393 unsafe { &*(&protos as *const ArrayView3<'_, PROTO> as *const ArrayView3<'_, i8>) };
1394 view_i8.to_owned()
1395 } else {
1396 protos.map(|v| {
1397 let v_i8: i8 = v.as_();
1398 v_i8
1399 })
1400 };
1401 ProtoData {
1402 mask_coefficients,
1403 protos: ProtoTensor::Quantized {
1404 protos: protos_i8,
1405 quantization: quant_protos,
1406 },
1407 }
1408}
1409
1410fn postprocess_yolo<'a, T>(
1411 output: &'a ArrayView2<'_, T>,
1412) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1413 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1414 let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1415 (boxes_tensor, scores_tensor)
1416}
1417
1418pub(crate) fn postprocess_yolo_seg<'a, T>(
1419 output: &'a ArrayView2<'_, T>,
1420 num_protos: usize,
1421) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1422 assert!(
1423 output.shape()[0] > num_protos + 4,
1424 "Output shape is too short: {} <= {} + 4",
1425 output.shape()[0],
1426 num_protos
1427 );
1428 let num_classes = output.shape()[0] - 4 - num_protos;
1429 let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1430 let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1431 let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1432 (boxes_tensor, scores_tensor, mask_tensor)
1433}
1434
1435pub(crate) fn postprocess_yolo_split_segdet<'a, 'b, 'c, BOX, SCORE, MASK>(
1436 boxes_tensor: ArrayView2<'a, BOX>,
1437 scores_tensor: ArrayView2<'b, SCORE>,
1438 mask_tensor: ArrayView2<'c, MASK>,
1439) -> (
1440 ArrayView2<'a, BOX>,
1441 ArrayView2<'b, SCORE>,
1442 ArrayView2<'c, MASK>,
1443) {
1444 let boxes_tensor = boxes_tensor.reversed_axes();
1445 let scores_tensor = scores_tensor.reversed_axes();
1446 let mask_tensor = mask_tensor.reversed_axes();
1447 (boxes_tensor, scores_tensor, mask_tensor)
1448}
1449
1450fn decode_segdet_f32<
1451 MASK: Float + AsPrimitive<f32> + Send + Sync,
1452 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1453>(
1454 boxes: Vec<(DetectBox, usize)>,
1455 masks: ArrayView2<MASK>,
1456 protos: ArrayView3<PROTO>,
1457) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1458 if boxes.is_empty() {
1459 return Ok(Vec::new());
1460 }
1461 if masks.shape()[1] != protos.shape()[2] {
1462 return Err(crate::DecoderError::InvalidShape(format!(
1463 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1464 masks.shape()[1],
1465 protos.shape()[2],
1466 )));
1467 }
1468 boxes
1469 .into_par_iter()
1470 .map(|mut b| {
1471 let ind = b.1;
1472 let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1473 b.0.bbox = roi;
1474 Ok((b.0, make_segmentation(masks.row(ind), protos.view())))
1475 })
1476 .collect()
1477}
1478
1479pub(crate) fn decode_segdet_quant<
1480 MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1481 PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1482>(
1483 boxes: Vec<(DetectBox, usize)>,
1484 masks: ArrayView2<MASK>,
1485 protos: ArrayView3<PROTO>,
1486 quant_masks: Quantization,
1487 quant_protos: Quantization,
1488) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1489 if boxes.is_empty() {
1490 return Ok(Vec::new());
1491 }
1492 if masks.shape()[1] != protos.shape()[2] {
1493 return Err(crate::DecoderError::InvalidShape(format!(
1494 "Mask coefficients count ({}) doesn't match protos channel count ({})",
1495 masks.shape()[1],
1496 protos.shape()[2],
1497 )));
1498 }
1499
1500 let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; boxes
1502 .into_iter()
1503 .map(|mut b| {
1504 let i = b.1;
1505 let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1506 b.0.bbox = roi;
1507 let seg = match total_bits {
1508 0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1509 masks.row(i),
1510 protos.view(),
1511 quant_masks,
1512 quant_protos,
1513 ),
1514 65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1515 masks.row(i),
1516 protos.view(),
1517 quant_masks,
1518 quant_protos,
1519 ),
1520 _ => {
1521 return Err(crate::DecoderError::NotSupported(format!(
1522 "Unsupported bit width ({total_bits}) for segmentation computation"
1523 )));
1524 }
1525 };
1526 Ok((b.0, seg))
1527 })
1528 .collect()
1529}
1530
1531fn protobox<'a, T>(
1532 protos: &'a ArrayView3<T>,
1533 roi: &BoundingBox,
1534) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1535 let width = protos.dim().1 as f32;
1536 let height = protos.dim().0 as f32;
1537
1538 const NORM_LIMIT: f32 = 2.0;
1549 if roi.xmin > NORM_LIMIT
1550 || roi.ymin > NORM_LIMIT
1551 || roi.xmax > NORM_LIMIT
1552 || roi.ymax > NORM_LIMIT
1553 {
1554 return Err(crate::DecoderError::InvalidShape(format!(
1555 "Bounding box coordinates appear un-normalized (pixel-space). \
1556 Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
1557 ONNX models output pixel-space boxes — normalize them by dividing by \
1558 the input dimensions before calling decode().",
1559 roi.xmin, roi.ymin, roi.xmax, roi.ymax,
1560 )));
1561 }
1562
1563 let roi = [
1564 (roi.xmin * width).clamp(0.0, width) as usize,
1565 (roi.ymin * height).clamp(0.0, height) as usize,
1566 (roi.xmax * width).clamp(0.0, width).ceil() as usize,
1567 (roi.ymax * height).clamp(0.0, height).ceil() as usize,
1568 ];
1569
1570 let roi_norm = [
1571 roi[0] as f32 / width,
1572 roi[1] as f32 / height,
1573 roi[2] as f32 / width,
1574 roi[3] as f32 / height,
1575 ]
1576 .into();
1577
1578 let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
1579
1580 Ok((cropped, roi_norm))
1581}
1582
1583fn make_segmentation<
1589 MASK: Float + AsPrimitive<f32> + Send + Sync,
1590 PROTO: Float + AsPrimitive<f32> + Send + Sync,
1591>(
1592 mask: ArrayView1<MASK>,
1593 protos: ArrayView3<PROTO>,
1594) -> Array3<u8> {
1595 let shape = protos.shape();
1596
1597 let mask = mask.to_shape((1, mask.len())).unwrap();
1599 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1600 let protos = protos.reversed_axes();
1601 let mask = mask.map(|x| x.as_());
1602 let protos = protos.map(|x| x.as_());
1603
1604 let mask = mask
1606 .dot(&protos)
1607 .into_shape_with_order((shape[0], shape[1], 1))
1608 .unwrap();
1609
1610 mask.map(|x| {
1611 let sigmoid = 1.0 / (1.0 + (-*x).exp());
1612 (sigmoid * 255.0).round() as u8
1613 })
1614}
1615
1616fn make_segmentation_quant<
1623 MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
1624 PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
1625 DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
1626>(
1627 mask: ArrayView1<MASK>,
1628 protos: ArrayView3<PROTO>,
1629 quant_masks: Quantization,
1630 quant_protos: Quantization,
1631) -> Array3<u8>
1632where
1633 i32: AsPrimitive<DEST>,
1634 f32: AsPrimitive<DEST>,
1635{
1636 let shape = protos.shape();
1637
1638 let mask = mask.to_shape((1, mask.len())).unwrap();
1640
1641 let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1642 let protos = protos.reversed_axes();
1643
1644 let zp = quant_masks.zero_point.as_();
1645
1646 let mask = mask.mapv(|x| x.as_() - zp);
1647
1648 let zp = quant_protos.zero_point.as_();
1649 let protos = protos.mapv(|x| x.as_() - zp);
1650
1651 let segmentation = mask
1653 .dot(&protos)
1654 .into_shape_with_order((shape[0], shape[1], 1))
1655 .unwrap();
1656
1657 let combined_scale = quant_masks.scale * quant_protos.scale;
1658 segmentation.map(|x| {
1659 let val: f32 = (*x).as_() * combined_scale;
1660 let sigmoid = 1.0 / (1.0 + (-val).exp());
1661 (sigmoid * 255.0).round() as u8
1662 })
1663}
1664
1665pub fn yolo_segmentation_to_mask(
1677 segmentation: ArrayView3<u8>,
1678 threshold: u8,
1679) -> Result<Array2<u8>, crate::DecoderError> {
1680 if segmentation.shape()[2] != 1 {
1681 return Err(crate::DecoderError::InvalidShape(format!(
1682 "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
1683 segmentation.shape()[2]
1684 )));
1685 }
1686 Ok(segmentation
1687 .slice(s![.., .., 0])
1688 .map(|x| if *x >= threshold { 1 } else { 0 }))
1689}
1690
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use ndarray::Array2;

    // The end-to-end outputs below are feature-major: shape
    // (features, detections). Row layout is x1, y1, x2, y2, score, class,
    // followed by one row per mask coefficient for the segmentation variants.

    /// Two-detection end-to-end segdet output shared by the fused and split
    /// decoder tests. Detection 0 has score 0.9 and class 1; detection 1 has
    /// score 0.3 and class 2. Every mask coefficient is 0.1.
    fn two_detection_segdet_output(num_protos: usize) -> Array2<f32> {
        let num_detections = 2;
        let num_features = 6 + num_protos;
        let at = |row: usize, col: usize| row * num_detections + col;

        let mut data = vec![0.0f32; num_features * num_detections];
        // (row, [detection 0, detection 1])
        let head: [(usize, [f32; 2]); 6] = [
            (0, [0.1, 0.5]), // x1
            (1, [0.1, 0.5]), // y1
            (2, [0.4, 0.9]), // x2
            (3, [0.4, 0.9]), // y2
            (4, [0.9, 0.3]), // score
            (5, [1.0, 2.0]), // class
        ];
        for &(row, values) in head.iter() {
            for (col, &value) in values.iter().enumerate() {
                data[at(row, col)] = value;
            }
        }
        for row in 6..num_features {
            data[at(row, 0)] = 0.1;
            data[at(row, 1)] = 0.1;
        }

        Array2::from_shape_vec((num_features, num_detections), data).unwrap()
    }

    #[test]
    fn test_end_to_end_det_basic_filtering() {
        // Three candidates; only the first (score 0.9) clears 0.5.
        let data: Vec<f32> = vec![
            0.1, 0.2, 0.3, // x1
            0.1, 0.2, 0.3, // y1
            0.5, 0.6, 0.7, // x2
            0.5, 0.6, 0.7, // y2
            0.9, 0.1, 0.2, // score
            0.0, 1.0, 2.0, // class
        ];
        let output = Array2::from_shape_vec((6, 3), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 1);
        let kept = &boxes[0];
        assert_eq!(kept.label, 0);
        assert!((kept.score - 0.9).abs() < 0.01);
        assert!((kept.bbox.xmin - 0.1).abs() < 0.01);
        assert!((kept.bbox.ymin - 0.1).abs() < 0.01);
        assert!((kept.bbox.xmax - 0.5).abs() < 0.01);
        assert!((kept.bbox.ymax - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_end_to_end_det_all_pass_threshold() {
        let data: Vec<f32> = vec![
            10.0, 20.0, // x1
            10.0, 20.0, // y1
            50.0, 60.0, // x2
            50.0, 60.0, // y2
            0.8, 0.7, // score
            1.0, 2.0, // class
        ];
        let output = Array2::from_shape_vec((6, 2), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 2);
        assert_eq!(boxes[0].label, 1);
        assert_eq!(boxes[1].label, 2);
    }

    #[test]
    fn test_end_to_end_det_none_pass_threshold() {
        let data: Vec<f32> = vec![
            10.0, 20.0, // x1
            10.0, 20.0, // y1
            50.0, 60.0, // x2
            50.0, 60.0, // y2
            0.1, 0.2, // score
            1.0, 2.0, // class
        ];
        let output = Array2::from_shape_vec((6, 2), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 0);
    }

    #[test]
    fn test_end_to_end_det_capacity_limit() {
        // Five passing candidates, but the output vector only has room for two.
        let data: Vec<f32> = vec![
            0.1, 0.2, 0.3, 0.4, 0.5, // x1
            0.1, 0.2, 0.3, 0.4, 0.5, // y1
            0.5, 0.6, 0.7, 0.8, 0.9, // x2
            0.5, 0.6, 0.7, 0.8, 0.9, // y2
            0.9, 0.9, 0.9, 0.9, 0.9, // score
            0.0, 1.0, 2.0, 3.0, 4.0, // class
        ];
        let output = Array2::from_shape_vec((6, 5), data).unwrap();

        let mut boxes = Vec::with_capacity(2);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 2);
    }

    #[test]
    fn test_end_to_end_det_empty_output() {
        let output = Array2::<f32>::zeros((6, 0));

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 0);
    }

    #[test]
    fn test_end_to_end_det_pixel_coordinates() {
        // Pixel-space coordinates must pass through unchanged.
        let data: Vec<f32> = vec![100.0, 200.0, 300.0, 400.0, 0.95, 5.0];
        let output = Array2::from_shape_vec((6, 1), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 1);
        let kept = &boxes[0];
        assert_eq!(kept.label, 5);
        assert!((kept.bbox.xmin - 100.0).abs() < 0.01);
        assert!((kept.bbox.ymin - 200.0).abs() < 0.01);
        assert!((kept.bbox.xmax - 300.0).abs() < 0.01);
        assert!((kept.bbox.ymax - 400.0).abs() < 0.01);
    }

    #[test]
    fn test_end_to_end_det_invalid_shape() {
        // Five rows: one short of the required box + score + class layout.
        let output = Array2::<f32>::zeros((5, 3));

        let mut boxes = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
        ));
    }

    #[test]
    fn test_end_to_end_segdet_basic() {
        let num_protos = 32;
        let output = two_detection_segdet_output(num_protos);
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        // Only detection 0 clears the threshold.
        assert_eq!(boxes.len(), 1);
        assert_eq!(masks.len(), 1);
        assert_eq!(boxes[0].label, 1);
        assert!((boxes[0].score - 0.9).abs() < 0.01);
    }

    #[test]
    fn test_end_to_end_segdet_mask_coordinates() {
        let num_protos = 32;
        let num_features = 6 + num_protos;

        let mut data = vec![0.0f32; num_features];
        // x1, y1, x2, y2, score, class; mask coefficients stay zero.
        data[..6].copy_from_slice(&[0.2, 0.2, 0.8, 0.8, 0.95, 3.0]);
        let output = Array2::from_shape_vec((num_features, 1), data).unwrap();
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 1);
        assert_eq!(masks.len(), 1);

        // The mask region must line up with its detection box.
        let (mask, det) = (&masks[0], &boxes[0]);
        assert!((mask.xmin - det.bbox.xmin).abs() < 0.01);
        assert!((mask.ymin - det.bbox.ymin).abs() < 0.01);
        assert!((mask.xmax - det.bbox.xmax).abs() < 0.01);
        assert!((mask.ymax - det.bbox.ymax).abs() < 0.01);
    }

    #[test]
    fn test_end_to_end_segdet_empty_output() {
        let num_protos = 32;
        let output = Array2::<f32>::zeros((6 + num_protos, 0));
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 0);
        assert_eq!(masks.len(), 0);
    }

    #[test]
    fn test_end_to_end_segdet_capacity_limit() {
        let num_protos = 32;
        let num_detections = 5;
        let num_features = 6 + num_protos;

        let mut data = vec![0.0f32; num_features * num_detections];
        for det in 0..num_detections {
            let offset = 0.1 * (det as f32);
            data[det] = offset; // x1
            data[num_detections + det] = offset; // y1
            data[2 * num_detections + det] = offset + 0.2; // x2
            data[3 * num_detections + det] = offset + 0.2; // y2
            data[4 * num_detections + det] = 0.9; // score
            data[5 * num_detections + det] = det as f32; // class
        }

        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        // All five pass the threshold, but only two slots are available.
        let mut boxes = Vec::with_capacity(2);
        let mut masks = Vec::with_capacity(2);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 2);
        assert_eq!(masks.len(), 2);
    }

    #[test]
    fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
        // Six rows leave no room for any mask coefficient.
        let output = Array2::<f32>::zeros((6, 3));
        let protos = Array3::<f32>::zeros((16, 16, 32));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        );

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
        ));
    }

    #[test]
    fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
        let num_protos = 32;
        // 16 coefficient rows versus 32 proto channels.
        let output = Array2::<f32>::zeros((6 + 16, 3));
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        );

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
        ));
    }

    #[test]
    fn test_split_end_to_end_segdet_basic() {
        let num_protos = 32;
        let output = two_detection_segdet_output(num_protos);

        // Slice the fused layout into the four split tensors.
        let box_coords = output.slice(s![..4, ..]);
        let scores = output.slice(s![4..5, ..]);
        let classes = output.slice(s![5..6, ..]);
        let mask_coeff = output.slice(s![6.., ..]);
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_split_end_to_end_segdet_float(
            box_coords,
            scores,
            classes,
            mask_coeff,
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 1);
        assert_eq!(masks.len(), 1);
        assert_eq!(boxes[0].label, 1);
        assert!((boxes[0].score - 0.9).abs() < 0.01);
    }

    #[test]
    fn test_segmentation_to_mask_basic() {
        let data: Vec<u8> = vec![
            100, 200, 50, 150, // row 0
            10, 255, 128, 64, // row 1
            0, 127, 128, 255, // row 2
            64, 64, 192, 192, // row 3
        ];
        let segmentation = Array3::from_shape_vec((4, 4, 1), data).unwrap();

        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();

        let expected: [((usize, usize), u8); 8] = [
            ((0, 0), 0), // 100
            ((0, 1), 1), // 200
            ((0, 2), 0), // 50
            ((0, 3), 1), // 150
            ((1, 1), 1), // 255
            ((1, 2), 1), // 128: the threshold itself is inclusive
            ((2, 0), 0), // 0
            ((2, 1), 0), // 127
        ];
        for &((r, c), want) in expected.iter() {
            assert_eq!(mask[[r, c]], want, "pixel ({r}, {c})");
        }
    }

    #[test]
    fn test_segmentation_to_mask_all_above() {
        let segmentation = Array3::from_elem((4, 4, 1), 255u8);
        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
        assert!(mask.iter().all(|&x| x == 1));
    }

    #[test]
    fn test_segmentation_to_mask_all_below() {
        let segmentation = Array3::from_elem((4, 4, 1), 64u8);
        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
        assert!(mask.iter().all(|&x| x == 0));
    }

    #[test]
    fn test_segmentation_to_mask_invalid_shape() {
        let segmentation = Array3::from_elem((4, 4, 3), 128u8);
        let result = yolo_segmentation_to_mask(segmentation.view(), 128);

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("(H, W, 1)")
        ));
    }

    #[test]
    fn test_protobox_clamps_edge_coordinates() {
        let protos = Array3::<f32>::zeros((16, 16, 4));
        let view = protos.view();
        // A box touching the far edge (xmax = ymax = 1.0) is valid.
        let roi = BoundingBox {
            xmin: 0.5,
            ymin: 0.5,
            xmax: 1.0,
            ymax: 1.0,
        };

        let result = protobox(&view, &roi);
        assert!(result.is_ok(), "protobox should accept xmax=1.0");

        let (cropped, _roi_norm) = result.unwrap();
        assert!(cropped.shape()[0] > 0);
        assert!(cropped.shape()[1] > 0);
        assert_eq!(cropped.shape()[2], 4);
    }

    #[test]
    fn test_protobox_rejects_wildly_out_of_range() {
        let protos = Array3::<f32>::zeros((16, 16, 4));
        let view = protos.view();
        let roi = BoundingBox {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 3.0,
            ymax: 3.0,
        };

        let result = protobox(&view, &roi);
        assert!(
            matches!(result, Err(crate::DecoderError::InvalidShape(s)) if s.contains("un-normalized")),
            "protobox should reject coords > NORM_LIMIT"
        );
    }

    #[test]
    fn test_protobox_accepts_slightly_over_one() {
        let protos = Array3::<f32>::zeros((16, 16, 4));
        let view = protos.view();
        let roi = BoundingBox {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.5,
            ymax: 1.5,
        };

        let result = protobox(&view, &roi);
        assert!(
            result.is_ok(),
            "protobox should accept coords <= NORM_LIMIT (2.0)"
        );

        // Overshoot is clamped to the full 16x16 grid.
        let (cropped, _roi_norm) = result.unwrap();
        assert_eq!(cropped.shape()[0], 16);
        assert_eq!(cropped.shape()[1], 16);
    }

    #[test]
    fn test_segdet_float_proto_no_panic() {
        let num_proposals = 100;
        let num_classes = 80;
        let num_mask_coeffs = 32;
        let rows = 4 + num_classes + num_mask_coeffs;

        let mut data = vec![0.0f32; rows * num_proposals];
        for col in 0..num_proposals {
            // XYWH box (320, 320, 50, 50) with a class-0 score of 0.9.
            for (row, value) in [320.0, 320.0, 50.0, 50.0, 0.9].iter().enumerate() {
                data[row * num_proposals + col] = *value;
            }
        }
        let boxes = Array2::from_shape_vec((rows, num_proposals), data).unwrap();
        let protos = Array3::<f32>::zeros((160, 160, num_mask_coeffs));

        let mut output_boxes = Vec::with_capacity(300);
        let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
            boxes.view(),
            protos.view(),
            0.5,
            0.7,
            Some(Nms::default()),
            &mut output_boxes,
        );

        assert!(!output_boxes.is_empty());
        assert_eq!(proto_data.mask_coefficients.len(), output_boxes.len());
        for coeffs in &proto_data.mask_coefficients {
            assert_eq!(coeffs.len(), num_mask_coeffs);
        }
    }
}