Skip to main content

edgefirst_decoder/
yolo.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::fmt::Debug;
5
6use ndarray::{
7    parallel::prelude::{IntoParallelIterator, ParallelIterator},
8    s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
9};
10use num_traits::{AsPrimitive, Float, PrimInt, Signed};
11
12use crate::{
13    byte::{
14        nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
15        postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
16    },
17    configs::Nms,
18    dequant_detect_box,
19    float::{
20        nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
21        postprocess_boxes_float, postprocess_boxes_index_float,
22    },
23    BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoTensor,
24    Quantization, Segmentation, XYWH, XYXY,
25};
26
27/// Dispatches to the appropriate NMS function based on mode for float boxes.
28fn dispatch_nms_float(nms: Option<Nms>, iou: f32, boxes: Vec<DetectBox>) -> Vec<DetectBox> {
29    match nms {
30        Some(Nms::ClassAgnostic) => nms_float(iou, boxes),
31        Some(Nms::ClassAware) => nms_class_aware_float(iou, boxes),
32        None => boxes, // bypass NMS
33    }
34}
35
36/// Dispatches to the appropriate NMS function based on mode for float boxes
37/// with extra data.
38pub(super) fn dispatch_nms_extra_float<E: Send + Sync>(
39    nms: Option<Nms>,
40    iou: f32,
41    boxes: Vec<(DetectBox, E)>,
42) -> Vec<(DetectBox, E)> {
43    match nms {
44        Some(Nms::ClassAgnostic) => nms_extra_float(iou, boxes),
45        Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, boxes),
46        None => boxes, // bypass NMS
47    }
48}
49
50/// Dispatches to the appropriate NMS function based on mode for quantized
51/// boxes.
52fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
53    nms: Option<Nms>,
54    iou: f32,
55    boxes: Vec<DetectBoxQuantized<SCORE>>,
56) -> Vec<DetectBoxQuantized<SCORE>> {
57    match nms {
58        Some(Nms::ClassAgnostic) => nms_int(iou, boxes),
59        Some(Nms::ClassAware) => nms_class_aware_int(iou, boxes),
60        None => boxes, // bypass NMS
61    }
62}
63
64/// Dispatches to the appropriate NMS function based on mode for quantized boxes
65/// with extra data.
66fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
67    nms: Option<Nms>,
68    iou: f32,
69    boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
70) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
71    match nms {
72        Some(Nms::ClassAgnostic) => nms_extra_int(iou, boxes),
73        Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, boxes),
74        None => boxes, // bypass NMS
75    }
76}
77
78/// Decodes YOLO detection outputs from quantized tensors into detection boxes.
79///
80/// Boxes are expected to be in XYWH format.
81///
82/// Expected shapes of inputs:
83/// - output: (4 + num_classes, num_boxes)
84pub fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
85    output: (ArrayView2<BOX>, Quantization),
86    score_threshold: f32,
87    iou_threshold: f32,
88    nms: Option<Nms>,
89    output_boxes: &mut Vec<DetectBox>,
90) where
91    f32: AsPrimitive<BOX>,
92{
93    impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
94}
95
96/// Decodes YOLO detection outputs from float tensors into detection boxes.
97///
98/// Boxes are expected to be in XYWH format.
99///
100/// Expected shapes of inputs:
101/// - output: (4 + num_classes, num_boxes)
102pub fn decode_yolo_det_float<T>(
103    output: ArrayView2<T>,
104    score_threshold: f32,
105    iou_threshold: f32,
106    nms: Option<Nms>,
107    output_boxes: &mut Vec<DetectBox>,
108) where
109    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
110    f32: AsPrimitive<T>,
111{
112    impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
113}
114
115/// Decodes YOLO detection and segmentation outputs from quantized tensors into
116/// detection boxes and segmentation masks.
117///
118/// Boxes are expected to be in XYWH format.
119///
120/// Expected shapes of inputs:
121/// - boxes: (4 + num_classes + num_protos, num_boxes)
122/// - protos: (proto_height, proto_width, num_protos)
123///
124/// # Errors
125/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
126pub fn decode_yolo_segdet_quant<
127    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
128    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
129>(
130    boxes: (ArrayView2<BOX>, Quantization),
131    protos: (ArrayView3<PROTO>, Quantization),
132    score_threshold: f32,
133    iou_threshold: f32,
134    nms: Option<Nms>,
135    output_boxes: &mut Vec<DetectBox>,
136    output_masks: &mut Vec<Segmentation>,
137) -> Result<(), crate::DecoderError>
138where
139    f32: AsPrimitive<BOX>,
140{
141    impl_yolo_segdet_quant::<XYWH, _, _>(
142        boxes,
143        protos,
144        score_threshold,
145        iou_threshold,
146        nms,
147        output_boxes,
148        output_masks,
149    )
150}
151
152/// Decodes YOLO detection and segmentation outputs from float tensors into
153/// detection boxes and segmentation masks.
154///
155/// Boxes are expected to be in XYWH format.
156///
157/// Expected shapes of inputs:
158/// - boxes: (4 + num_classes + num_protos, num_boxes)
159/// - protos: (proto_height, proto_width, num_protos)
160///
161/// # Errors
162/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
163pub fn decode_yolo_segdet_float<T>(
164    boxes: ArrayView2<T>,
165    protos: ArrayView3<T>,
166    score_threshold: f32,
167    iou_threshold: f32,
168    nms: Option<Nms>,
169    output_boxes: &mut Vec<DetectBox>,
170    output_masks: &mut Vec<Segmentation>,
171) -> Result<(), crate::DecoderError>
172where
173    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
174    f32: AsPrimitive<T>,
175{
176    impl_yolo_segdet_float::<XYWH, _, _>(
177        boxes,
178        protos,
179        score_threshold,
180        iou_threshold,
181        nms,
182        output_boxes,
183        output_masks,
184    )
185}
186
187/// Decodes YOLO split detection outputs from quantized tensors into detection
188/// boxes.
189///
190/// Boxes are expected to be in XYWH format.
191///
192/// Expected shapes of inputs:
193/// - boxes: (4, num_boxes)
194/// - scores: (num_classes, num_boxes)
195///
196/// # Panics
197/// Panics if shapes don't match the expected dimensions.
198pub fn decode_yolo_split_det_quant<
199    BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
200    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
201>(
202    boxes: (ArrayView2<BOX>, Quantization),
203    scores: (ArrayView2<SCORE>, Quantization),
204    score_threshold: f32,
205    iou_threshold: f32,
206    nms: Option<Nms>,
207    output_boxes: &mut Vec<DetectBox>,
208) where
209    f32: AsPrimitive<SCORE>,
210{
211    impl_yolo_split_quant::<XYWH, _, _>(
212        boxes,
213        scores,
214        score_threshold,
215        iou_threshold,
216        nms,
217        output_boxes,
218    );
219}
220
221/// Decodes YOLO split detection outputs from float tensors into detection
222/// boxes.
223///
224/// Boxes are expected to be in XYWH format.
225///
226/// Expected shapes of inputs:
227/// - boxes: (4, num_boxes)
228/// - scores: (num_classes, num_boxes)
229///
230/// # Panics
231/// Panics if shapes don't match the expected dimensions.
232pub fn decode_yolo_split_det_float<T>(
233    boxes: ArrayView2<T>,
234    scores: ArrayView2<T>,
235    score_threshold: f32,
236    iou_threshold: f32,
237    nms: Option<Nms>,
238    output_boxes: &mut Vec<DetectBox>,
239) where
240    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
241    f32: AsPrimitive<T>,
242{
243    impl_yolo_split_float::<XYWH, _, _>(
244        boxes,
245        scores,
246        score_threshold,
247        iou_threshold,
248        nms,
249        output_boxes,
250    );
251}
252
253/// Decodes YOLO split detection segmentation outputs from quantized tensors
254/// into detection boxes and segmentation masks.
255///
256/// Boxes are expected to be in XYWH format.
257///
258/// Expected shapes of inputs:
259/// - boxes_tensor: (4, num_boxes)
260/// - scores_tensor: (num_classes, num_boxes)
261/// - mask_tensor: (num_protos, num_boxes)
262/// - protos: (proto_height, proto_width, num_protos)
263///
264/// # Errors
265/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
266#[allow(clippy::too_many_arguments)]
267pub fn decode_yolo_split_segdet<
268    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
269    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
270    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
271    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
272>(
273    boxes: (ArrayView2<BOX>, Quantization),
274    scores: (ArrayView2<SCORE>, Quantization),
275    mask_coeff: (ArrayView2<MASK>, Quantization),
276    protos: (ArrayView3<PROTO>, Quantization),
277    score_threshold: f32,
278    iou_threshold: f32,
279    nms: Option<Nms>,
280    output_boxes: &mut Vec<DetectBox>,
281    output_masks: &mut Vec<Segmentation>,
282) -> Result<(), crate::DecoderError>
283where
284    f32: AsPrimitive<SCORE>,
285{
286    impl_yolo_split_segdet_quant::<XYWH, _, _, _, _>(
287        boxes,
288        scores,
289        mask_coeff,
290        protos,
291        score_threshold,
292        iou_threshold,
293        nms,
294        output_boxes,
295        output_masks,
296    )
297}
298
/// Decodes YOLO split detection segmentation outputs from float tensors
/// into detection boxes and segmentation masks.
///
/// Thin wrapper: every argument is forwarded unchanged to
/// `impl_yolo_split_segdet_float` with the box layout fixed to XYWH.
///
/// Expected shapes of inputs:
/// - boxes_tensor: (4, num_boxes)
/// - scores_tensor: (num_classes, num_boxes)
/// - mask_tensor: (num_protos, num_boxes)
/// - protos: (proto_height, proto_width, num_protos)
///
/// # Errors
/// Propagates the error returned by `impl_yolo_split_segdet_float`.
/// NOTE(review): previously documented as `DecoderError::InvalidShape` if
/// bounding boxes are not normalized — confirm against the callee.
#[allow(clippy::too_many_arguments)]
pub fn decode_yolo_split_segdet_float<T>(
    boxes: ArrayView2<T>,
    scores: ArrayView2<T>,
    mask_coeff: ArrayView2<T>,
    protos: ArrayView3<T>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    impl_yolo_split_segdet_float::<XYWH, _, _, _, _>(
        boxes,
        scores,
        mask_coeff,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        output_boxes,
        output_masks,
    )
}
340
341/// Decodes end-to-end YOLO detection outputs (post-NMS from model).
342/// Expects an array of shape `(6, N)`, where the first dimension (rows)
343/// corresponds to the 6 per-detection features
344/// `[x1, y1, x2, y2, conf, class]` and the second dimension (columns)
345/// indexes the `N` detections.
346/// Boxes are output directly without NMS (the model already applied NMS).
347///
348/// Coordinates may be normalized `[0, 1]` or absolute pixel values depending
349/// on the model configuration. The caller should check
350/// `decoder.normalized_boxes()` to determine which.
351///
352/// # Errors
353///
354/// Returns `DecoderError::InvalidShape` if `output` has fewer than 6 rows.
355pub fn decode_yolo_end_to_end_det_float<T>(
356    output: ArrayView2<T>,
357    score_threshold: f32,
358    output_boxes: &mut Vec<DetectBox>,
359) -> Result<(), crate::DecoderError>
360where
361    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
362    f32: AsPrimitive<T>,
363{
364    // Validate input shape: need at least 6 rows (x1, y1, x2, y2, conf, class)
365    if output.shape()[0] < 6 {
366        return Err(crate::DecoderError::InvalidShape(format!(
367            "End-to-end detection output requires at least 6 rows, got {}",
368            output.shape()[0]
369        )));
370    }
371
372    // Input shape: (6, N) -> transpose to (N, 4) for boxes and (N, 1) for scores
373    let boxes = output.slice(s![0..4, ..]).reversed_axes();
374    let scores = output.slice(s![4..5, ..]).reversed_axes();
375    let classes = output.slice(s![5, ..]);
376    let mut boxes =
377        postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
378    boxes.truncate(output_boxes.capacity());
379    output_boxes.clear();
380    for (mut b, i) in boxes.into_iter() {
381        b.label = classes[i].as_() as usize;
382        output_boxes.push(b);
383    }
384    // No NMS — model output is already post-NMS
385    Ok(())
386}
387
388/// Decodes end-to-end YOLO detection + segmentation outputs (post-NMS from
389/// model).
390///
391/// Input shapes:
392/// - detection: (6 + num_protos, N) where rows are [x1, y1, x2, y2, conf,
393///   class, mask_coeff_0, ..., mask_coeff_31]
394/// - protos: (proto_height, proto_width, num_protos)
395///
396/// Boxes are output directly without NMS (model already applied NMS).
397/// Coordinates may be normalized [0,1] or pixel values depending on model
398/// config.
399///
400/// # Errors
401///
402/// Returns `DecoderError::InvalidShape` if:
403/// - output has fewer than 7 rows (6 base + at least 1 mask coefficient)
404/// - protos shape doesn't match mask coefficients count
405pub fn decode_yolo_end_to_end_segdet_float<T>(
406    output: ArrayView2<T>,
407    protos: ArrayView3<T>,
408    score_threshold: f32,
409    output_boxes: &mut Vec<DetectBox>,
410    output_masks: &mut Vec<crate::Segmentation>,
411) -> Result<(), crate::DecoderError>
412where
413    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
414    f32: AsPrimitive<T>,
415{
416    let (boxes, scores, classes, mask_coeff) =
417        postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
418    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
419        boxes,
420        scores,
421        classes,
422        score_threshold,
423        output_boxes.capacity(),
424    );
425
426    // No NMS — model output is already post-NMS
427
428    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
429}
430
431/// Decodes split end-to-end YOLO detection outputs (post-NMS from model).
432///
433/// Input shapes (after batch dim removed):
434/// - boxes: (N, 4) — xyxy pixel coordinates
435/// - scores: (N, 1) — confidence of the top class
436/// - classes: (N, 1) — class index of the top class
437///
438/// Boxes are output directly without NMS (model already applied NMS).
439pub fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
440    boxes: ArrayView2<T>,
441    scores: ArrayView2<T>,
442    classes: ArrayView2<T>,
443    score_threshold: f32,
444    output_boxes: &mut Vec<DetectBox>,
445) -> Result<(), crate::DecoderError> {
446    let n = boxes.shape()[0];
447    if boxes.shape()[1] != 4 {
448        return Err(crate::DecoderError::InvalidShape(format!(
449            "Split end-to-end boxes must have 4 columns, got {}",
450            boxes.shape()[1]
451        )));
452    }
453    output_boxes.clear();
454    for i in 0..n {
455        let score: f32 = scores[[i, 0]].as_();
456        if score < score_threshold {
457            continue;
458        }
459        if output_boxes.len() >= output_boxes.capacity() {
460            break;
461        }
462        output_boxes.push(DetectBox {
463            bbox: BoundingBox {
464                xmin: boxes[[i, 0]].as_(),
465                ymin: boxes[[i, 1]].as_(),
466                xmax: boxes[[i, 2]].as_(),
467                ymax: boxes[[i, 3]].as_(),
468            },
469            score,
470            label: classes[[i, 0]].as_() as usize,
471        });
472    }
473    Ok(())
474}
475
476/// Decodes split end-to-end YOLO detection + segmentation outputs.
477///
478/// Input shapes (after batch dim removed):
479/// - boxes: (4, N) — xyxy pixel coordinates
480/// - scores: (1, N) — confidence
481/// - classes: (1, N) — class index
482/// - mask_coeff: (num_protos, N) — mask coefficients per detection
483/// - protos: (proto_h, proto_w, num_protos) — prototype masks
484#[allow(clippy::too_many_arguments)]
485pub fn decode_yolo_split_end_to_end_segdet_float<T>(
486    boxes: ArrayView2<T>,
487    scores: ArrayView2<T>,
488    classes: ArrayView2<T>,
489    mask_coeff: ArrayView2<T>,
490    protos: ArrayView3<T>,
491    score_threshold: f32,
492    output_boxes: &mut Vec<DetectBox>,
493    output_masks: &mut Vec<crate::Segmentation>,
494) -> Result<(), crate::DecoderError>
495where
496    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
497    f32: AsPrimitive<T>,
498{
499    let (boxes, scores, classes, mask_coeff) =
500        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
501    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
502        boxes,
503        scores,
504        classes,
505        score_threshold,
506        output_boxes.capacity(),
507    );
508
509    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
510}
511
512#[allow(clippy::type_complexity)]
513pub(crate) fn postprocess_yolo_end_to_end_segdet<'a, T>(
514    output: &'a ArrayView2<'_, T>,
515    num_protos: usize,
516) -> Result<
517    (
518        ArrayView2<'a, T>,
519        ArrayView2<'a, T>,
520        ArrayView1<'a, T>,
521        ArrayView2<'a, T>,
522    ),
523    crate::DecoderError,
524> {
525    // Validate input shape: need at least 7 rows (6 base + at least 1 mask coeff)
526    if output.shape()[0] < 7 {
527        return Err(crate::DecoderError::InvalidShape(format!(
528            "End-to-end segdet output requires at least 7 rows, got {}",
529            output.shape()[0]
530        )));
531    }
532
533    let num_mask_coeffs = output.shape()[0] - 6;
534    if num_mask_coeffs != num_protos {
535        return Err(crate::DecoderError::InvalidShape(format!(
536            "Mask coefficients count ({}) doesn't match protos count ({})",
537            num_mask_coeffs, num_protos
538        )));
539    }
540
541    // Input shape: (6+num_protos, N) -> transpose for postprocessing
542    let boxes = output.slice(s![0..4, ..]).reversed_axes();
543    let scores = output.slice(s![4..5, ..]).reversed_axes();
544    let classes = output.slice(s![5, ..]);
545    let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
546    Ok((boxes, scores, classes, mask_coeff))
547}
548
549#[allow(clippy::type_complexity)]
550pub(crate) fn postprocess_yolo_split_end_to_end_segdet<
551    'a,
552    'b,
553    'c,
554    'd,
555    BOXES,
556    SCORES,
557    CLASS,
558    MASK,
559>(
560    boxes: ArrayView2<'a, BOXES>,
561    scores: ArrayView2<'b, SCORES>,
562    classes: &'c ArrayView2<CLASS>,
563    mask_coeff: ArrayView2<'d, MASK>,
564) -> Result<
565    (
566        ArrayView2<'a, BOXES>,
567        ArrayView2<'b, SCORES>,
568        ArrayView1<'c, CLASS>,
569        ArrayView2<'d, MASK>,
570    ),
571    crate::DecoderError,
572> {
573    if boxes.shape()[0] != 4 {
574        return Err(crate::DecoderError::InvalidShape(format!(
575            "Split end-to-end boxes must have 4 columns, got {}",
576            boxes.shape()[0]
577        )));
578    }
579    let boxes = boxes.reversed_axes();
580    let scores = scores.reversed_axes();
581    let classes = classes.slice(s![0, ..]);
582    let mask_coeff = mask_coeff.reversed_axes();
583    Ok((boxes, scores, classes, mask_coeff))
584}
585/// Internal implementation of YOLO decoding for quantized tensors.
586///
587/// Expected shapes of inputs:
588/// - output: (4 + num_classes, num_boxes)
589pub fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
590    output: (ArrayView2<T>, Quantization),
591    score_threshold: f32,
592    iou_threshold: f32,
593    nms: Option<Nms>,
594    output_boxes: &mut Vec<DetectBox>,
595) where
596    f32: AsPrimitive<T>,
597{
598    let (boxes, quant_boxes) = output;
599    let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
600
601    let boxes = {
602        let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
603        postprocess_boxes_quant::<B, _, _>(
604            score_threshold,
605            boxes_tensor,
606            scores_tensor,
607            quant_boxes,
608        )
609    };
610
611    let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
612    let len = output_boxes.capacity().min(boxes.len());
613    output_boxes.clear();
614    for b in boxes.iter().take(len) {
615        output_boxes.push(dequant_detect_box(b, quant_boxes));
616    }
617}
618
619/// Internal implementation of YOLO decoding for float tensors.
620///
621/// Expected shapes of inputs:
622/// - output: (4 + num_classes, num_boxes)
623pub fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
624    output: ArrayView2<T>,
625    score_threshold: f32,
626    iou_threshold: f32,
627    nms: Option<Nms>,
628    output_boxes: &mut Vec<DetectBox>,
629) where
630    f32: AsPrimitive<T>,
631{
632    let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
633    let boxes =
634        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
635    let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
636    let len = output_boxes.capacity().min(boxes.len());
637    output_boxes.clear();
638    for b in boxes.into_iter().take(len) {
639        output_boxes.push(b);
640    }
641}
642
643/// Internal implementation of YOLO split detection decoding for quantized
644/// tensors.
645///
646/// Expected shapes of inputs:
647/// - boxes: (4, num_boxes)
648/// - scores: (num_classes, num_boxes)
649///
650/// # Panics
651/// Panics if shapes don't match the expected dimensions.
652pub fn impl_yolo_split_quant<
653    B: BBoxTypeTrait,
654    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
655    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
656>(
657    boxes: (ArrayView2<BOX>, Quantization),
658    scores: (ArrayView2<SCORE>, Quantization),
659    score_threshold: f32,
660    iou_threshold: f32,
661    nms: Option<Nms>,
662    output_boxes: &mut Vec<DetectBox>,
663) where
664    f32: AsPrimitive<SCORE>,
665{
666    let (boxes_tensor, quant_boxes) = boxes;
667    let (scores_tensor, quant_scores) = scores;
668
669    let boxes_tensor = boxes_tensor.reversed_axes();
670    let scores_tensor = scores_tensor.reversed_axes();
671
672    let boxes = {
673        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
674        postprocess_boxes_quant::<B, _, _>(
675            score_threshold,
676            boxes_tensor,
677            scores_tensor,
678            quant_boxes,
679        )
680    };
681
682    let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
683    let len = output_boxes.capacity().min(boxes.len());
684    output_boxes.clear();
685    for b in boxes.iter().take(len) {
686        output_boxes.push(dequant_detect_box(b, quant_scores));
687    }
688}
689
690/// Internal implementation of YOLO split detection decoding for float tensors.
691///
692/// Expected shapes of inputs:
693/// - boxes: (4, num_boxes)
694/// - scores: (num_classes, num_boxes)
695///
696/// # Panics
697/// Panics if shapes don't match the expected dimensions.
698pub fn impl_yolo_split_float<
699    B: BBoxTypeTrait,
700    BOX: Float + AsPrimitive<f32> + Send + Sync,
701    SCORE: Float + AsPrimitive<f32> + Send + Sync,
702>(
703    boxes_tensor: ArrayView2<BOX>,
704    scores_tensor: ArrayView2<SCORE>,
705    score_threshold: f32,
706    iou_threshold: f32,
707    nms: Option<Nms>,
708    output_boxes: &mut Vec<DetectBox>,
709) where
710    f32: AsPrimitive<SCORE>,
711{
712    let boxes_tensor = boxes_tensor.reversed_axes();
713    let scores_tensor = scores_tensor.reversed_axes();
714    let boxes =
715        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
716    let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
717    let len = output_boxes.capacity().min(boxes.len());
718    output_boxes.clear();
719    for b in boxes.into_iter().take(len) {
720        output_boxes.push(b);
721    }
722}
723
724/// Internal implementation of YOLO detection segmentation decoding for
725/// quantized tensors.
726///
727/// Expected shapes of inputs:
728/// - boxes: (4 + num_classes + num_protos, num_boxes)
729/// - protos: (proto_height, proto_width, num_protos)
730///
731/// # Errors
732/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
733pub fn impl_yolo_segdet_quant<
734    B: BBoxTypeTrait,
735    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
736    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
737>(
738    boxes: (ArrayView2<BOX>, Quantization),
739    protos: (ArrayView3<PROTO>, Quantization),
740    score_threshold: f32,
741    iou_threshold: f32,
742    nms: Option<Nms>,
743    output_boxes: &mut Vec<DetectBox>,
744    output_masks: &mut Vec<Segmentation>,
745) -> Result<(), crate::DecoderError>
746where
747    f32: AsPrimitive<BOX>,
748{
749    let (boxes, quant_boxes) = boxes;
750    let num_protos = protos.0.dim().2;
751
752    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
753    let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
754        (boxes_tensor, quant_boxes),
755        (scores_tensor, quant_boxes),
756        score_threshold,
757        iou_threshold,
758        nms,
759        output_boxes.capacity(),
760    );
761
762    impl_yolo_split_segdet_quant_process_masks::<_, _>(
763        boxes,
764        (mask_tensor, quant_boxes),
765        protos,
766        output_boxes,
767        output_masks,
768    )
769}
770
771/// Internal implementation of YOLO detection segmentation decoding for
772/// float tensors.
773///
774/// Expected shapes of inputs:
775/// - boxes: (4 + num_classes + num_protos, num_boxes)
776/// - protos: (proto_height, proto_width, num_protos)
777///
778/// # Panics
779/// Panics if shapes don't match the expected dimensions.
780pub fn impl_yolo_segdet_float<
781    B: BBoxTypeTrait,
782    BOX: Float + AsPrimitive<f32> + Send + Sync,
783    PROTO: Float + AsPrimitive<f32> + Send + Sync,
784>(
785    boxes: ArrayView2<BOX>,
786    protos: ArrayView3<PROTO>,
787    score_threshold: f32,
788    iou_threshold: f32,
789    nms: Option<Nms>,
790    output_boxes: &mut Vec<DetectBox>,
791    output_masks: &mut Vec<Segmentation>,
792) -> Result<(), crate::DecoderError>
793where
794    f32: AsPrimitive<BOX>,
795{
796    let num_protos = protos.dim().2;
797    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
798    let boxes = impl_yolo_segdet_get_boxes::<XYWH, _, _>(
799        boxes_tensor,
800        scores_tensor,
801        score_threshold,
802        iou_threshold,
803        nms,
804        output_boxes.capacity(),
805    );
806    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
807}
808
809pub(crate) fn impl_yolo_segdet_get_boxes<
810    B: BBoxTypeTrait,
811    BOX: Float + AsPrimitive<f32> + Send + Sync,
812    SCORE: Float + AsPrimitive<f32> + Send + Sync,
813>(
814    boxes_tensor: ArrayView2<BOX>,
815    scores_tensor: ArrayView2<SCORE>,
816    score_threshold: f32,
817    iou_threshold: f32,
818    nms: Option<Nms>,
819    max_boxes: usize,
820) -> Vec<(DetectBox, usize)>
821where
822    f32: AsPrimitive<SCORE>,
823{
824    let boxes = postprocess_boxes_index_float::<B, _, _>(
825        score_threshold.as_(),
826        boxes_tensor,
827        scores_tensor,
828    );
829    let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
830    boxes.truncate(max_boxes);
831    boxes
832}
833
834pub(crate) fn impl_yolo_end_to_end_segdet_get_boxes<
835    B: BBoxTypeTrait,
836    BOX: Float + AsPrimitive<f32> + Send + Sync,
837    SCORE: Float + AsPrimitive<f32> + Send + Sync,
838    CLASS: AsPrimitive<f32> + Send + Sync,
839>(
840    boxes: ArrayView2<BOX>,
841    scores: ArrayView2<SCORE>,
842    classes: ArrayView1<CLASS>,
843    score_threshold: f32,
844    max_boxes: usize,
845) -> Vec<(DetectBox, usize)>
846where
847    f32: AsPrimitive<SCORE>,
848{
849    let mut boxes = postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes, scores);
850    boxes.truncate(max_boxes);
851    for (b, ind) in &mut boxes {
852        b.label = classes[*ind].as_().round() as usize;
853    }
854    boxes
855}
856
857pub(crate) fn impl_yolo_split_segdet_process_masks<
858    MASK: Float + AsPrimitive<f32> + Send + Sync,
859    PROTO: Float + AsPrimitive<f32> + Send + Sync,
860>(
861    boxes: Vec<(DetectBox, usize)>,
862    masks_tensor: ArrayView2<MASK>,
863    protos_tensor: ArrayView3<PROTO>,
864    output_boxes: &mut Vec<DetectBox>,
865    output_masks: &mut Vec<Segmentation>,
866) -> Result<(), crate::DecoderError> {
867    let boxes = decode_segdet_f32(boxes, masks_tensor, protos_tensor)?;
868    output_boxes.clear();
869    output_masks.clear();
870    for (b, m) in boxes.into_iter() {
871        output_boxes.push(b);
872        output_masks.push(Segmentation {
873            xmin: b.bbox.xmin,
874            ymin: b.bbox.ymin,
875            xmax: b.bbox.xmax,
876            ymax: b.bbox.ymax,
877            segmentation: m,
878        });
879    }
880    Ok(())
881}
882/// Expected input shapes:
883/// - boxes_tensor: (num_boxes, 4)
884/// - scores_tensor: (num_boxes, num_classes)
885pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
886    B: BBoxTypeTrait,
887    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
888    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
889>(
890    boxes: (ArrayView2<BOX>, Quantization),
891    scores: (ArrayView2<SCORE>, Quantization),
892    score_threshold: f32,
893    iou_threshold: f32,
894    nms: Option<Nms>,
895    max_boxes: usize,
896) -> Vec<(DetectBox, usize)>
897where
898    f32: AsPrimitive<SCORE>,
899{
900    let (boxes_tensor, quant_boxes) = boxes;
901    let (scores_tensor, quant_scores) = scores;
902
903    let boxes = {
904        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
905        postprocess_boxes_index_quant::<B, _, _>(
906            score_threshold,
907            boxes_tensor,
908            scores_tensor,
909            quant_boxes,
910        )
911    };
912    let mut boxes = dispatch_nms_extra_int(nms, iou_threshold, boxes);
913    boxes.truncate(max_boxes);
914    boxes
915        .into_iter()
916        .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
917        .collect()
918}
919
920pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
921    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
922    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
923>(
924    boxes: Vec<(DetectBox, usize)>,
925    mask_coeff: (ArrayView2<MASK>, Quantization),
926    protos: (ArrayView3<PROTO>, Quantization),
927    output_boxes: &mut Vec<DetectBox>,
928    output_masks: &mut Vec<Segmentation>,
929) -> Result<(), crate::DecoderError> {
930    let (masks, quant_masks) = mask_coeff;
931    let (protos, quant_protos) = protos;
932
933    let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
934    output_boxes.clear();
935    output_masks.clear();
936    for (b, m) in boxes.into_iter() {
937        output_boxes.push(b);
938        output_masks.push(Segmentation {
939            xmin: b.bbox.xmin,
940            ymin: b.bbox.ymin,
941            xmax: b.bbox.xmax,
942            ymax: b.bbox.ymax,
943            segmentation: m,
944        });
945    }
946    Ok(())
947}
948
949#[allow(clippy::too_many_arguments)]
950/// Internal implementation of YOLO split detection segmentation decoding for
951/// quantized tensors.
952///
953/// Expected shapes of inputs:
954/// - boxes_tensor: (4, num_boxes)
955/// - scores_tensor: (num_classes, num_boxes)
956/// - mask_tensor: (num_protos, num_boxes)
957/// - protos: (proto_height, proto_width, num_protos)
958///
959/// # Errors
960/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
961pub fn impl_yolo_split_segdet_quant<
962    B: BBoxTypeTrait,
963    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
964    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
965    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
966    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
967>(
968    boxes: (ArrayView2<BOX>, Quantization),
969    scores: (ArrayView2<SCORE>, Quantization),
970    mask_coeff: (ArrayView2<MASK>, Quantization),
971    protos: (ArrayView3<PROTO>, Quantization),
972    score_threshold: f32,
973    iou_threshold: f32,
974    nms: Option<Nms>,
975    output_boxes: &mut Vec<DetectBox>,
976    output_masks: &mut Vec<Segmentation>,
977) -> Result<(), crate::DecoderError>
978where
979    f32: AsPrimitive<SCORE>,
980{
981    let (boxes_, scores_, mask_coeff_) =
982        postprocess_yolo_split_segdet(boxes.0, scores.0, mask_coeff.0);
983    let boxes = (boxes_, boxes.1);
984    let scores = (scores_, scores.1);
985    let mask_coeff = (mask_coeff_, mask_coeff.1);
986
987    let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
988        boxes,
989        scores,
990        score_threshold,
991        iou_threshold,
992        nms,
993        output_boxes.capacity(),
994    );
995
996    impl_yolo_split_segdet_quant_process_masks(
997        boxes,
998        mask_coeff,
999        protos,
1000        output_boxes,
1001        output_masks,
1002    )
1003}
1004
1005#[allow(clippy::too_many_arguments)]
1006/// Internal implementation of YOLO split detection segmentation decoding for
1007/// float tensors.
1008///
1009/// Expected shapes of inputs:
1010/// - boxes_tensor: (4, num_boxes)
1011/// - scores_tensor: (num_classes, num_boxes)
1012/// - mask_tensor: (num_protos, num_boxes)
1013/// - protos: (proto_height, proto_width, num_protos)
1014///
1015/// # Errors
1016/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
1017pub fn impl_yolo_split_segdet_float<
1018    B: BBoxTypeTrait,
1019    BOX: Float + AsPrimitive<f32> + Send + Sync,
1020    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1021    MASK: Float + AsPrimitive<f32> + Send + Sync,
1022    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1023>(
1024    boxes_tensor: ArrayView2<BOX>,
1025    scores_tensor: ArrayView2<SCORE>,
1026    mask_tensor: ArrayView2<MASK>,
1027    protos: ArrayView3<PROTO>,
1028    score_threshold: f32,
1029    iou_threshold: f32,
1030    nms: Option<Nms>,
1031    output_boxes: &mut Vec<DetectBox>,
1032    output_masks: &mut Vec<Segmentation>,
1033) -> Result<(), crate::DecoderError>
1034where
1035    f32: AsPrimitive<SCORE>,
1036{
1037    let (boxes_tensor, scores_tensor, mask_tensor) =
1038        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1039
1040    let boxes = impl_yolo_segdet_get_boxes::<XYWH, _, _>(
1041        boxes_tensor,
1042        scores_tensor,
1043        score_threshold,
1044        iou_threshold,
1045        nms,
1046        output_boxes.capacity(),
1047    );
1048    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1049}
1050
1051// ---------------------------------------------------------------------------
1052// Proto-extraction variants: return ProtoData instead of materialized masks
1053// ---------------------------------------------------------------------------
1054
1055/// Proto-extraction variant of `impl_yolo_segdet_quant`.
1056/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1057pub fn impl_yolo_segdet_quant_proto<
1058    B: BBoxTypeTrait,
1059    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1060    PROTO: PrimInt
1061        + AsPrimitive<i64>
1062        + AsPrimitive<i128>
1063        + AsPrimitive<f32>
1064        + AsPrimitive<i8>
1065        + Send
1066        + Sync,
1067>(
1068    boxes: (ArrayView2<BOX>, Quantization),
1069    protos: (ArrayView3<PROTO>, Quantization),
1070    score_threshold: f32,
1071    iou_threshold: f32,
1072    nms: Option<Nms>,
1073    output_boxes: &mut Vec<DetectBox>,
1074) -> ProtoData
1075where
1076    f32: AsPrimitive<BOX>,
1077{
1078    let (boxes_arr, quant_boxes) = boxes;
1079    let (protos_arr, quant_protos) = protos;
1080    let num_protos = protos_arr.dim().2;
1081
1082    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1083
1084    let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1085        (boxes_tensor, quant_boxes),
1086        (scores_tensor, quant_boxes),
1087        score_threshold,
1088        iou_threshold,
1089        nms,
1090        output_boxes.capacity(),
1091    );
1092
1093    extract_proto_data_quant(
1094        det_indices,
1095        mask_tensor,
1096        quant_boxes,
1097        protos_arr,
1098        quant_protos,
1099        output_boxes,
1100    )
1101}
1102
1103/// Proto-extraction variant of `impl_yolo_segdet_float`.
1104/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1105pub fn impl_yolo_segdet_float_proto<
1106    B: BBoxTypeTrait,
1107    BOX: Float + AsPrimitive<f32> + Send + Sync,
1108    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1109>(
1110    boxes: ArrayView2<BOX>,
1111    protos: ArrayView3<PROTO>,
1112    score_threshold: f32,
1113    iou_threshold: f32,
1114    nms: Option<Nms>,
1115    output_boxes: &mut Vec<DetectBox>,
1116) -> ProtoData
1117where
1118    f32: AsPrimitive<BOX>,
1119{
1120    let num_protos = protos.dim().2;
1121    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1122
1123    let boxes = impl_yolo_segdet_get_boxes::<XYWH, _, _>(
1124        boxes_tensor,
1125        scores_tensor,
1126        score_threshold,
1127        iou_threshold,
1128        nms,
1129        output_boxes.capacity(),
1130    );
1131
1132    extract_proto_data_float(boxes, mask_tensor, protos, output_boxes)
1133}
1134
1135/// Proto-extraction variant of `impl_yolo_split_segdet_quant`.
1136/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1137#[allow(clippy::too_many_arguments)]
1138pub fn impl_yolo_split_segdet_quant_proto<
1139    B: BBoxTypeTrait,
1140    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1141    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1142    MASK: PrimInt + AsPrimitive<f32> + Send + Sync,
1143    PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync,
1144>(
1145    boxes: (ArrayView2<BOX>, Quantization),
1146    scores: (ArrayView2<SCORE>, Quantization),
1147    mask_coeff: (ArrayView2<MASK>, Quantization),
1148    protos: (ArrayView3<PROTO>, Quantization),
1149    score_threshold: f32,
1150    iou_threshold: f32,
1151    nms: Option<Nms>,
1152    output_boxes: &mut Vec<DetectBox>,
1153) -> ProtoData
1154where
1155    f32: AsPrimitive<SCORE>,
1156{
1157    let (boxes_, scores_, mask_coeff_) =
1158        postprocess_yolo_split_segdet(boxes.0, scores.0, mask_coeff.0);
1159    let boxes = (boxes_, boxes.1);
1160    let scores = (scores_, scores.1);
1161    let mask_coeff = (mask_coeff_, mask_coeff.1);
1162
1163    let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1164        boxes,
1165        scores,
1166        score_threshold,
1167        iou_threshold,
1168        nms,
1169        output_boxes.capacity(),
1170    );
1171
1172    let (masks, quant_masks) = mask_coeff;
1173    let masks = masks.reversed_axes();
1174    let (protos_arr, quant_protos) = protos;
1175
1176    extract_proto_data_quant(
1177        det_indices,
1178        masks,
1179        quant_masks,
1180        protos_arr,
1181        quant_protos,
1182        output_boxes,
1183    )
1184}
1185
1186/// Proto-extraction variant of `impl_yolo_split_segdet_float`.
1187/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1188#[allow(clippy::too_many_arguments)]
1189pub fn impl_yolo_split_segdet_float_proto<
1190    B: BBoxTypeTrait,
1191    BOX: Float + AsPrimitive<f32> + Send + Sync,
1192    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1193    MASK: Float + AsPrimitive<f32> + Send + Sync,
1194    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1195>(
1196    boxes_tensor: ArrayView2<BOX>,
1197    scores_tensor: ArrayView2<SCORE>,
1198    mask_tensor: ArrayView2<MASK>,
1199    protos: ArrayView3<PROTO>,
1200    score_threshold: f32,
1201    iou_threshold: f32,
1202    nms: Option<Nms>,
1203    output_boxes: &mut Vec<DetectBox>,
1204) -> ProtoData
1205where
1206    f32: AsPrimitive<SCORE>,
1207{
1208    let (boxes_tensor, scores_tensor, mask_tensor) =
1209        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1210    let det_indices = impl_yolo_segdet_get_boxes::<B, _, _>(
1211        boxes_tensor,
1212        scores_tensor,
1213        score_threshold,
1214        iou_threshold,
1215        nms,
1216        output_boxes.capacity(),
1217    );
1218
1219    extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1220}
1221
1222/// Proto-extraction variant of `decode_yolo_end_to_end_segdet_float`.
1223pub fn decode_yolo_end_to_end_segdet_float_proto<T>(
1224    output: ArrayView2<T>,
1225    protos: ArrayView3<T>,
1226    score_threshold: f32,
1227    output_boxes: &mut Vec<DetectBox>,
1228) -> Result<ProtoData, crate::DecoderError>
1229where
1230    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1231    f32: AsPrimitive<T>,
1232{
1233    let (boxes, scores, classes, mask_coeff) =
1234        postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
1235    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1236        boxes,
1237        scores,
1238        classes,
1239        score_threshold,
1240        output_boxes.capacity(),
1241    );
1242
1243    Ok(extract_proto_data_float(
1244        boxes,
1245        mask_coeff,
1246        protos,
1247        output_boxes,
1248    ))
1249}
1250
1251/// Proto-extraction variant of `decode_yolo_split_end_to_end_segdet_float`.
1252#[allow(clippy::too_many_arguments)]
1253pub fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1254    boxes: ArrayView2<T>,
1255    scores: ArrayView2<T>,
1256    classes: ArrayView2<T>,
1257    mask_coeff: ArrayView2<T>,
1258    protos: ArrayView3<T>,
1259    score_threshold: f32,
1260    output_boxes: &mut Vec<DetectBox>,
1261) -> Result<ProtoData, crate::DecoderError>
1262where
1263    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1264    f32: AsPrimitive<T>,
1265{
1266    let (boxes, scores, classes, mask_coeff) =
1267        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
1268    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1269        boxes,
1270        scores,
1271        classes,
1272        score_threshold,
1273        output_boxes.capacity(),
1274    );
1275
1276    Ok(extract_proto_data_float(
1277        boxes,
1278        mask_coeff,
1279        protos,
1280        output_boxes,
1281    ))
1282}
1283
1284/// Helper: extract ProtoData from float mask coefficients + protos.
1285pub(super) fn extract_proto_data_float<
1286    MASK: Float + AsPrimitive<f32> + Send + Sync,
1287    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1288>(
1289    det_indices: Vec<(DetectBox, usize)>,
1290    mask_tensor: ArrayView2<MASK>,
1291    protos: ArrayView3<PROTO>,
1292    output_boxes: &mut Vec<DetectBox>,
1293) -> ProtoData {
1294    let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1295    output_boxes.clear();
1296    for (det, idx) in det_indices {
1297        output_boxes.push(det);
1298        let row = mask_tensor.row(idx);
1299        mask_coefficients.push(row.iter().map(|v| v.as_()).collect());
1300    }
1301    let protos_f32 = protos.map(|v| v.as_());
1302    ProtoData {
1303        mask_coefficients,
1304        protos: ProtoTensor::Float(protos_f32),
1305    }
1306}
1307
1308/// Helper: extract ProtoData from quantized mask coefficients + protos.
1309///
1310/// Dequantizes mask coefficients to f32 (small — per-detection) but keeps
1311/// protos in raw int8 form wrapped in `ProtoTensor::Quantized` so the GPU
1312/// shader can dequantize per-texel without CPU overhead.
1313pub(crate) fn extract_proto_data_quant<
1314    MASK: PrimInt + AsPrimitive<f32> + Send + Sync,
1315    PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync,
1316>(
1317    det_indices: Vec<(DetectBox, usize)>,
1318    mask_tensor: ArrayView2<MASK>,
1319    quant_masks: Quantization,
1320    protos: ArrayView3<PROTO>,
1321    quant_protos: Quantization,
1322    output_boxes: &mut Vec<DetectBox>,
1323) -> ProtoData {
1324    let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1325    output_boxes.clear();
1326    for (det, idx) in det_indices {
1327        output_boxes.push(det);
1328        let row = mask_tensor.row(idx);
1329        mask_coefficients.push(
1330            row.iter()
1331                .map(|v| (v.as_() - quant_masks.zero_point as f32) * quant_masks.scale)
1332                .collect(),
1333        );
1334    }
1335    // Keep protos in raw int8 — GPU shader will dequantize per-texel.
1336    let protos_i8 = protos.map(|v| {
1337        let v_i8: i8 = v.as_();
1338        v_i8
1339    });
1340    ProtoData {
1341        mask_coefficients,
1342        protos: ProtoTensor::Quantized {
1343            protos: protos_i8,
1344            quantization: quant_protos,
1345        },
1346    }
1347}
1348
1349fn postprocess_yolo<'a, T>(
1350    output: &'a ArrayView2<'_, T>,
1351) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1352    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1353    let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1354    (boxes_tensor, scores_tensor)
1355}
1356
1357pub(crate) fn postprocess_yolo_seg<'a, T>(
1358    output: &'a ArrayView2<'_, T>,
1359    num_protos: usize,
1360) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1361    assert!(
1362        output.shape()[0] > num_protos + 4,
1363        "Output shape is too short: {} <= {} + 4",
1364        output.shape()[0],
1365        num_protos
1366    );
1367    let num_classes = output.shape()[0] - 4 - num_protos;
1368    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1369    let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1370    let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1371    (boxes_tensor, scores_tensor, mask_tensor)
1372}
1373
1374pub(crate) fn postprocess_yolo_split_segdet<'a, 'b, 'c, BOX, SCORE, MASK>(
1375    boxes_tensor: ArrayView2<'a, BOX>,
1376    scores_tensor: ArrayView2<'b, SCORE>,
1377    mask_tensor: ArrayView2<'c, MASK>,
1378) -> (
1379    ArrayView2<'a, BOX>,
1380    ArrayView2<'b, SCORE>,
1381    ArrayView2<'c, MASK>,
1382) {
1383    let boxes_tensor = boxes_tensor.reversed_axes();
1384    let scores_tensor = scores_tensor.reversed_axes();
1385    let mask_tensor = mask_tensor.reversed_axes();
1386    (boxes_tensor, scores_tensor, mask_tensor)
1387}
1388
1389fn decode_segdet_f32<
1390    MASK: Float + AsPrimitive<f32> + Send + Sync,
1391    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1392>(
1393    boxes: Vec<(DetectBox, usize)>,
1394    masks: ArrayView2<MASK>,
1395    protos: ArrayView3<PROTO>,
1396) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1397    if boxes.is_empty() {
1398        return Ok(Vec::new());
1399    }
1400    if masks.shape()[1] != protos.shape()[2] {
1401        return Err(crate::DecoderError::InvalidShape(format!(
1402            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1403            masks.shape()[1],
1404            protos.shape()[2],
1405        )));
1406    }
1407    boxes
1408        .into_par_iter()
1409        .map(|mut b| {
1410            let ind = b.1;
1411            let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1412            b.0.bbox = roi;
1413            Ok((b.0, make_segmentation(masks.row(ind), protos.view())))
1414        })
1415        .collect()
1416}
1417
1418pub(crate) fn decode_segdet_quant<
1419    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1420    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1421>(
1422    boxes: Vec<(DetectBox, usize)>,
1423    masks: ArrayView2<MASK>,
1424    protos: ArrayView3<PROTO>,
1425    quant_masks: Quantization,
1426    quant_protos: Quantization,
1427) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1428    if boxes.is_empty() {
1429        return Ok(Vec::new());
1430    }
1431    if masks.shape()[1] != protos.shape()[2] {
1432        return Err(crate::DecoderError::InvalidShape(format!(
1433            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1434            masks.shape()[1],
1435            protos.shape()[2],
1436        )));
1437    }
1438
1439    let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; // 32 protos is 2^5
1440    boxes
1441        .into_iter()
1442        .map(|mut b| {
1443            let i = b.1;
1444            let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1445            b.0.bbox = roi;
1446            let seg = match total_bits {
1447                0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1448                    masks.row(i),
1449                    protos.view(),
1450                    quant_masks,
1451                    quant_protos,
1452                ),
1453                65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1454                    masks.row(i),
1455                    protos.view(),
1456                    quant_masks,
1457                    quant_protos,
1458                ),
1459                _ => {
1460                    return Err(crate::DecoderError::NotSupported(format!(
1461                        "Unsupported bit width ({total_bits}) for segmentation computation"
1462                    )));
1463                }
1464            };
1465            Ok((b.0, seg))
1466        })
1467        .collect()
1468}
1469
1470fn protobox<'a, T>(
1471    protos: &'a ArrayView3<T>,
1472    roi: &BoundingBox,
1473) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1474    let width = protos.dim().1 as f32;
1475    let height = protos.dim().0 as f32;
1476
1477    // Detect un-normalized bounding boxes (pixel-space coordinates).
1478    // protobox expects normalized coordinates in [0, 1]. ONNX models output
1479    // pixel-space boxes (e.g. 0-640) which must be normalized before calling
1480    // decode(). Without this check, pixel-space coordinates silently clamp to
1481    // the proto boundary, producing empty (0, 0, C) masks for every detection.
1482    //
1483    // The limit is set to 2.0 (not 1.01) because YOLO models legitimately
1484    // predict coordinates slightly > 1.0 for objects near frame edges.
1485    // Any value > 2.0 is clearly pixel-space (even the smallest practical
1486    // model input of 32×32 would produce coordinates >> 2.0).
1487    const NORM_LIMIT: f32 = 2.0;
1488    if roi.xmin > NORM_LIMIT
1489        || roi.ymin > NORM_LIMIT
1490        || roi.xmax > NORM_LIMIT
1491        || roi.ymax > NORM_LIMIT
1492    {
1493        return Err(crate::DecoderError::InvalidShape(format!(
1494            "Bounding box coordinates appear un-normalized (pixel-space). \
1495             Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
1496             ONNX models output pixel-space boxes — normalize them by dividing by \
1497             the input dimensions before calling decode().",
1498            roi.xmin, roi.ymin, roi.xmax, roi.ymax,
1499        )));
1500    }
1501
1502    let roi = [
1503        (roi.xmin * width).clamp(0.0, width) as usize,
1504        (roi.ymin * height).clamp(0.0, height) as usize,
1505        (roi.xmax * width).clamp(0.0, width).ceil() as usize,
1506        (roi.ymax * height).clamp(0.0, height).ceil() as usize,
1507    ];
1508
1509    let roi_norm = [
1510        roi[0] as f32 / width,
1511        roi[1] as f32 / height,
1512        roi[2] as f32 / width,
1513        roi[3] as f32 / height,
1514    ]
1515    .into();
1516
1517    let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
1518
1519    Ok((cropped, roi_norm))
1520}
1521
1522/// Compute a single instance segmentation mask from mask coefficients and
1523/// proto maps (float path).
1524///
1525/// Computes `sigmoid(coefficients · protos)` and maps to `[0, 255]`.
1526/// Returns an `(H, W, 1)` u8 array.
1527fn make_segmentation<
1528    MASK: Float + AsPrimitive<f32> + Send + Sync,
1529    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1530>(
1531    mask: ArrayView1<MASK>,
1532    protos: ArrayView3<PROTO>,
1533) -> Array3<u8> {
1534    let shape = protos.shape();
1535
1536    // Safe to unwrap since the shapes will always be compatible
1537    let mask = mask.to_shape((1, mask.len())).unwrap();
1538    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1539    let protos = protos.reversed_axes();
1540    let mask = mask.map(|x| x.as_());
1541    let protos = protos.map(|x| x.as_());
1542
1543    // Safe to unwrap since the shapes will always be compatible
1544    let mask = mask
1545        .dot(&protos)
1546        .into_shape_with_order((shape[0], shape[1], 1))
1547        .unwrap();
1548
1549    mask.map(|x| {
1550        let sigmoid = 1.0 / (1.0 + (-*x).exp());
1551        (sigmoid * 255.0).round() as u8
1552    })
1553}
1554
1555/// Compute a single instance segmentation mask from quantized mask
1556/// coefficients and proto maps.
1557///
1558/// Dequantizes both inputs (subtracting zero-points), computes the dot
1559/// product, applies sigmoid, and maps to `[0, 255]`.
1560/// Returns an `(H, W, 1)` u8 array.
1561fn make_segmentation_quant<
1562    MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
1563    PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
1564    DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
1565>(
1566    mask: ArrayView1<MASK>,
1567    protos: ArrayView3<PROTO>,
1568    quant_masks: Quantization,
1569    quant_protos: Quantization,
1570) -> Array3<u8>
1571where
1572    i32: AsPrimitive<DEST>,
1573    f32: AsPrimitive<DEST>,
1574{
1575    let shape = protos.shape();
1576
1577    // Safe to unwrap since the shapes will always be compatible
1578    let mask = mask.to_shape((1, mask.len())).unwrap();
1579
1580    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1581    let protos = protos.reversed_axes();
1582
1583    let zp = quant_masks.zero_point.as_();
1584
1585    let mask = mask.mapv(|x| x.as_() - zp);
1586
1587    let zp = quant_protos.zero_point.as_();
1588    let protos = protos.mapv(|x| x.as_() - zp);
1589
1590    // Safe to unwrap since the shapes will always be compatible
1591    let segmentation = mask
1592        .dot(&protos)
1593        .into_shape_with_order((shape[0], shape[1], 1))
1594        .unwrap();
1595
1596    let combined_scale = quant_masks.scale * quant_protos.scale;
1597    segmentation.map(|x| {
1598        let val: f32 = (*x).as_() * combined_scale;
1599        let sigmoid = 1.0 / (1.0 + (-val).exp());
1600        (sigmoid * 255.0).round() as u8
1601    })
1602}
1603
1604/// Converts Yolo Instance Segmentation into a 2D mask.
1605///
1606/// The input segmentation is expected to have shape (H, W, 1).
1607///
1608/// The output mask will have shape (H, W), with values 0 or 1 based on the
1609/// threshold.
1610///
1611/// # Errors
1612///
1613/// Returns `DecoderError::InvalidShape` if the input segmentation does not
1614/// have shape (H, W, 1).
1615pub fn yolo_segmentation_to_mask(
1616    segmentation: ArrayView3<u8>,
1617    threshold: u8,
1618) -> Result<Array2<u8>, crate::DecoderError> {
1619    if segmentation.shape()[2] != 1 {
1620        return Err(crate::DecoderError::InvalidShape(format!(
1621            "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
1622            segmentation.shape()[2]
1623        )));
1624    }
1625    Ok(segmentation
1626        .slice(s![.., .., 0])
1627        .map(|x| if *x >= threshold { 1 } else { 0 }))
1628}
1629
1630#[cfg(test)]
1631#[cfg_attr(coverage_nightly, coverage(off))]
1632mod tests {
1633    use super::*;
1634    use ndarray::Array2;
1635
1636    // ========================================================================
1637    // Tests for decode_yolo_end_to_end_det_float
1638    // ========================================================================
1639
1640    #[test]
1641    fn test_end_to_end_det_basic_filtering() {
1642        // Create synthetic end-to-end detection output: (6, N) where rows are
1643        // [x1, y1, x2, y2, conf, class]
1644        // 3 detections: one above threshold, two below
1645        let data: Vec<f32> = vec![
1646            // Detection 0: high score (0.9)
1647            0.1, 0.2, 0.3, // x1 values
1648            0.1, 0.2, 0.3, // y1 values
1649            0.5, 0.6, 0.7, // x2 values
1650            0.5, 0.6, 0.7, // y2 values
1651            0.9, 0.1, 0.2, // confidence scores
1652            0.0, 1.0, 2.0, // class indices
1653        ];
1654        let output = Array2::from_shape_vec((6, 3), data).unwrap();
1655
1656        let mut boxes = Vec::with_capacity(10);
1657        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1658
1659        // Only 1 detection should pass threshold of 0.5
1660        assert_eq!(boxes.len(), 1);
1661        assert_eq!(boxes[0].label, 0);
1662        assert!((boxes[0].score - 0.9).abs() < 0.01);
1663        assert!((boxes[0].bbox.xmin - 0.1).abs() < 0.01);
1664        assert!((boxes[0].bbox.ymin - 0.1).abs() < 0.01);
1665        assert!((boxes[0].bbox.xmax - 0.5).abs() < 0.01);
1666        assert!((boxes[0].bbox.ymax - 0.5).abs() < 0.01);
1667    }
1668
1669    #[test]
1670    fn test_end_to_end_det_all_pass_threshold() {
1671        // All detections above threshold
1672        let data: Vec<f32> = vec![
1673            10.0, 20.0, // x1
1674            10.0, 20.0, // y1
1675            50.0, 60.0, // x2
1676            50.0, 60.0, // y2
1677            0.8, 0.7, // conf (both above 0.5)
1678            1.0, 2.0, // class
1679        ];
1680        let output = Array2::from_shape_vec((6, 2), data).unwrap();
1681
1682        let mut boxes = Vec::with_capacity(10);
1683        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1684
1685        assert_eq!(boxes.len(), 2);
1686        assert_eq!(boxes[0].label, 1);
1687        assert_eq!(boxes[1].label, 2);
1688    }
1689
1690    #[test]
1691    fn test_end_to_end_det_none_pass_threshold() {
1692        // All detections below threshold
1693        let data: Vec<f32> = vec![
1694            10.0, 20.0, // x1
1695            10.0, 20.0, // y1
1696            50.0, 60.0, // x2
1697            50.0, 60.0, // y2
1698            0.1, 0.2, // conf (both below 0.5)
1699            1.0, 2.0, // class
1700        ];
1701        let output = Array2::from_shape_vec((6, 2), data).unwrap();
1702
1703        let mut boxes = Vec::with_capacity(10);
1704        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1705
1706        assert_eq!(boxes.len(), 0);
1707    }
1708
1709    #[test]
1710    fn test_end_to_end_det_capacity_limit() {
1711        // Test that output is truncated to capacity
1712        let data: Vec<f32> = vec![
1713            0.1, 0.2, 0.3, 0.4, 0.5, // x1
1714            0.1, 0.2, 0.3, 0.4, 0.5, // y1
1715            0.5, 0.6, 0.7, 0.8, 0.9, // x2
1716            0.5, 0.6, 0.7, 0.8, 0.9, // y2
1717            0.9, 0.9, 0.9, 0.9, 0.9, // conf (all pass)
1718            0.0, 1.0, 2.0, 3.0, 4.0, // class
1719        ];
1720        let output = Array2::from_shape_vec((6, 5), data).unwrap();
1721
1722        let mut boxes = Vec::with_capacity(2); // Only allow 2 boxes
1723        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1724
1725        assert_eq!(boxes.len(), 2);
1726    }
1727
1728    #[test]
1729    fn test_end_to_end_det_empty_output() {
1730        // Test with zero detections
1731        let output = Array2::<f32>::zeros((6, 0));
1732
1733        let mut boxes = Vec::with_capacity(10);
1734        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1735
1736        assert_eq!(boxes.len(), 0);
1737    }
1738
1739    #[test]
1740    fn test_end_to_end_det_pixel_coordinates() {
1741        // Test with pixel coordinates (non-normalized)
1742        let data: Vec<f32> = vec![
1743            100.0, // x1
1744            200.0, // y1
1745            300.0, // x2
1746            400.0, // y2
1747            0.95,  // conf
1748            5.0,   // class
1749        ];
1750        let output = Array2::from_shape_vec((6, 1), data).unwrap();
1751
1752        let mut boxes = Vec::with_capacity(10);
1753        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1754
1755        assert_eq!(boxes.len(), 1);
1756        assert_eq!(boxes[0].label, 5);
1757        assert!((boxes[0].bbox.xmin - 100.0).abs() < 0.01);
1758        assert!((boxes[0].bbox.ymin - 200.0).abs() < 0.01);
1759        assert!((boxes[0].bbox.xmax - 300.0).abs() < 0.01);
1760        assert!((boxes[0].bbox.ymax - 400.0).abs() < 0.01);
1761    }
1762
1763    #[test]
1764    fn test_end_to_end_det_invalid_shape() {
1765        // Test with too few rows (needs at least 6)
1766        let output = Array2::<f32>::zeros((5, 3));
1767
1768        let mut boxes = Vec::with_capacity(10);
1769        let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);
1770
1771        assert!(result.is_err());
1772        assert!(matches!(
1773            result,
1774            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
1775        ));
1776    }
1777
1778    // ========================================================================
1779    // Tests for decode_yolo_end_to_end_segdet_float
1780    // ========================================================================
1781
1782    #[test]
1783    fn test_end_to_end_segdet_basic() {
1784        // Create synthetic segdet output: (6 + num_protos, N)
1785        // Detection format: [x1, y1, x2, y2, conf, class, mask_coeff_0..31]
1786        let num_protos = 32;
1787        let num_detections = 2;
1788        let num_features = 6 + num_protos;
1789
1790        // Build detection tensor
1791        let mut data = vec![0.0f32; num_features * num_detections];
1792        // Detection 0: passes threshold
1793        data[0] = 0.1; // x1[0]
1794        data[1] = 0.5; // x1[1]
1795        data[num_detections] = 0.1; // y1[0]
1796        data[num_detections + 1] = 0.5; // y1[1]
1797        data[2 * num_detections] = 0.4; // x2[0]
1798        data[2 * num_detections + 1] = 0.9; // x2[1]
1799        data[3 * num_detections] = 0.4; // y2[0]
1800        data[3 * num_detections + 1] = 0.9; // y2[1]
1801        data[4 * num_detections] = 0.9; // conf[0] - passes
1802        data[4 * num_detections + 1] = 0.3; // conf[1] - fails
1803        data[5 * num_detections] = 1.0; // class[0]
1804        data[5 * num_detections + 1] = 2.0; // class[1]
1805                                            // Fill mask coefficients with small values
1806        for i in 6..num_features {
1807            data[i * num_detections] = 0.1;
1808            data[i * num_detections + 1] = 0.1;
1809        }
1810
1811        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
1812
1813        // Create protos tensor: (proto_height, proto_width, num_protos)
1814        let protos = Array3::<f32>::zeros((16, 16, num_protos));
1815
1816        let mut boxes = Vec::with_capacity(10);
1817        let mut masks = Vec::with_capacity(10);
1818        decode_yolo_end_to_end_segdet_float(
1819            output.view(),
1820            protos.view(),
1821            0.5,
1822            &mut boxes,
1823            &mut masks,
1824        )
1825        .unwrap();
1826
1827        // Only detection 0 should pass
1828        assert_eq!(boxes.len(), 1);
1829        assert_eq!(masks.len(), 1);
1830        assert_eq!(boxes[0].label, 1);
1831        assert!((boxes[0].score - 0.9).abs() < 0.01);
1832    }
1833
1834    #[test]
1835    fn test_end_to_end_segdet_mask_coordinates() {
1836        // Test that mask coordinates match box coordinates
1837        let num_protos = 32;
1838        let num_features = 6 + num_protos;
1839
1840        let mut data = vec![0.0f32; num_features];
1841        data[0] = 0.2; // x1
1842        data[1] = 0.2; // y1
1843        data[2] = 0.8; // x2
1844        data[3] = 0.8; // y2
1845        data[4] = 0.95; // conf
1846        data[5] = 3.0; // class
1847
1848        let output = Array2::from_shape_vec((num_features, 1), data).unwrap();
1849        let protos = Array3::<f32>::zeros((16, 16, num_protos));
1850
1851        let mut boxes = Vec::with_capacity(10);
1852        let mut masks = Vec::with_capacity(10);
1853        decode_yolo_end_to_end_segdet_float(
1854            output.view(),
1855            protos.view(),
1856            0.5,
1857            &mut boxes,
1858            &mut masks,
1859        )
1860        .unwrap();
1861
1862        assert_eq!(boxes.len(), 1);
1863        assert_eq!(masks.len(), 1);
1864
1865        // Verify mask coordinates match box coordinates
1866        assert!((masks[0].xmin - boxes[0].bbox.xmin).abs() < 0.01);
1867        assert!((masks[0].ymin - boxes[0].bbox.ymin).abs() < 0.01);
1868        assert!((masks[0].xmax - boxes[0].bbox.xmax).abs() < 0.01);
1869        assert!((masks[0].ymax - boxes[0].bbox.ymax).abs() < 0.01);
1870    }
1871
1872    #[test]
1873    fn test_end_to_end_segdet_empty_output() {
1874        let num_protos = 32;
1875        let output = Array2::<f32>::zeros((6 + num_protos, 0));
1876        let protos = Array3::<f32>::zeros((16, 16, num_protos));
1877
1878        let mut boxes = Vec::with_capacity(10);
1879        let mut masks = Vec::with_capacity(10);
1880        decode_yolo_end_to_end_segdet_float(
1881            output.view(),
1882            protos.view(),
1883            0.5,
1884            &mut boxes,
1885            &mut masks,
1886        )
1887        .unwrap();
1888
1889        assert_eq!(boxes.len(), 0);
1890        assert_eq!(masks.len(), 0);
1891    }
1892
1893    #[test]
1894    fn test_end_to_end_segdet_capacity_limit() {
1895        let num_protos = 32;
1896        let num_detections = 5;
1897        let num_features = 6 + num_protos;
1898
1899        let mut data = vec![0.0f32; num_features * num_detections];
1900        // All detections pass threshold
1901        for i in 0..num_detections {
1902            data[i] = 0.1 * (i as f32); // x1
1903            data[num_detections + i] = 0.1 * (i as f32); // y1
1904            data[2 * num_detections + i] = 0.1 * (i as f32) + 0.2; // x2
1905            data[3 * num_detections + i] = 0.1 * (i as f32) + 0.2; // y2
1906            data[4 * num_detections + i] = 0.9; // conf
1907            data[5 * num_detections + i] = i as f32; // class
1908        }
1909
1910        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
1911        let protos = Array3::<f32>::zeros((16, 16, num_protos));
1912
1913        let mut boxes = Vec::with_capacity(2); // Limit to 2
1914        let mut masks = Vec::with_capacity(2);
1915        decode_yolo_end_to_end_segdet_float(
1916            output.view(),
1917            protos.view(),
1918            0.5,
1919            &mut boxes,
1920            &mut masks,
1921        )
1922        .unwrap();
1923
1924        assert_eq!(boxes.len(), 2);
1925        assert_eq!(masks.len(), 2);
1926    }
1927
1928    #[test]
1929    fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
1930        // Test with too few rows (needs at least 7: 6 base + 1 mask coeff)
1931        let output = Array2::<f32>::zeros((6, 3));
1932        let protos = Array3::<f32>::zeros((16, 16, 32));
1933
1934        let mut boxes = Vec::with_capacity(10);
1935        let mut masks = Vec::with_capacity(10);
1936        let result = decode_yolo_end_to_end_segdet_float(
1937            output.view(),
1938            protos.view(),
1939            0.5,
1940            &mut boxes,
1941            &mut masks,
1942        );
1943
1944        assert!(result.is_err());
1945        assert!(matches!(
1946            result,
1947            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
1948        ));
1949    }
1950
1951    #[test]
1952    fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
1953        // Test with mismatched mask coefficients and protos count
1954        let num_protos = 32;
1955        let output = Array2::<f32>::zeros((6 + 16, 3)); // 16 mask coeffs
1956        let protos = Array3::<f32>::zeros((16, 16, num_protos)); // 32 protos
1957
1958        let mut boxes = Vec::with_capacity(10);
1959        let mut masks = Vec::with_capacity(10);
1960        let result = decode_yolo_end_to_end_segdet_float(
1961            output.view(),
1962            protos.view(),
1963            0.5,
1964            &mut boxes,
1965            &mut masks,
1966        );
1967
1968        assert!(result.is_err());
1969        assert!(matches!(
1970            result,
1971            Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
1972        ));
1973    }
1974
1975    // ========================================================================
1976    // Tests for decode_yolo_split_end_to_end_segdet_float
1977    // ========================================================================
1978
1979    #[test]
1980    fn test_split_end_to_end_segdet_basic() {
1981        // Create synthetic segdet output: (6 + num_protos, N)
1982        // Detection format: [x1, y1, x2, y2, conf, class, mask_coeff_0..31]
1983        let num_protos = 32;
1984        let num_detections = 2;
1985        let num_features = 6 + num_protos;
1986
1987        // Build detection tensor
1988        let mut data = vec![0.0f32; num_features * num_detections];
1989        // Detection 0: passes threshold
1990        data[0] = 0.1; // x1[0]
1991        data[1] = 0.5; // x1[1]
1992        data[num_detections] = 0.1; // y1[0]
1993        data[num_detections + 1] = 0.5; // y1[1]
1994        data[2 * num_detections] = 0.4; // x2[0]
1995        data[2 * num_detections + 1] = 0.9; // x2[1]
1996        data[3 * num_detections] = 0.4; // y2[0]
1997        data[3 * num_detections + 1] = 0.9; // y2[1]
1998        data[4 * num_detections] = 0.9; // conf[0] - passes
1999        data[4 * num_detections + 1] = 0.3; // conf[1] - fails
2000        data[5 * num_detections] = 1.0; // class[0]
2001        data[5 * num_detections + 1] = 2.0; // class[1]
2002                                            // Fill mask coefficients with small values
2003        for i in 6..num_features {
2004            data[i * num_detections] = 0.1;
2005            data[i * num_detections + 1] = 0.1;
2006        }
2007
2008        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2009        let box_coords = output.slice(s![..4, ..]);
2010        let scores = output.slice(s![4..5, ..]);
2011        let classes = output.slice(s![5..6, ..]);
2012        let mask_coeff = output.slice(s![6.., ..]);
2013        // Create protos tensor: (proto_height, proto_width, num_protos)
2014        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2015
2016        let mut boxes = Vec::with_capacity(10);
2017        let mut masks = Vec::with_capacity(10);
2018        decode_yolo_split_end_to_end_segdet_float(
2019            box_coords,
2020            scores,
2021            classes,
2022            mask_coeff,
2023            protos.view(),
2024            0.5,
2025            &mut boxes,
2026            &mut masks,
2027        )
2028        .unwrap();
2029
2030        // Only detection 0 should pass
2031        assert_eq!(boxes.len(), 1);
2032        assert_eq!(masks.len(), 1);
2033        assert_eq!(boxes[0].label, 1);
2034        assert!((boxes[0].score - 0.9).abs() < 0.01);
2035    }
2036
2037    // ========================================================================
2038    // Tests for yolo_segmentation_to_mask
2039    // ========================================================================
2040
2041    #[test]
2042    fn test_segmentation_to_mask_basic() {
2043        // Create a 4x4x1 segmentation with values above and below threshold
2044        let data: Vec<u8> = vec![
2045            100, 200, 50, 150, // row 0
2046            10, 255, 128, 64, // row 1
2047            0, 127, 128, 255, // row 2
2048            64, 64, 192, 192, // row 3
2049        ];
2050        let segmentation = Array3::from_shape_vec((4, 4, 1), data).unwrap();
2051
2052        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2053
2054        // Values >= 128 should be 1, others 0
2055        assert_eq!(mask[[0, 0]], 0); // 100 < 128
2056        assert_eq!(mask[[0, 1]], 1); // 200 >= 128
2057        assert_eq!(mask[[0, 2]], 0); // 50 < 128
2058        assert_eq!(mask[[0, 3]], 1); // 150 >= 128
2059        assert_eq!(mask[[1, 1]], 1); // 255 >= 128
2060        assert_eq!(mask[[1, 2]], 1); // 128 >= 128
2061        assert_eq!(mask[[2, 0]], 0); // 0 < 128
2062        assert_eq!(mask[[2, 1]], 0); // 127 < 128
2063    }
2064
2065    #[test]
2066    fn test_segmentation_to_mask_all_above() {
2067        let segmentation = Array3::from_elem((4, 4, 1), 255u8);
2068        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2069        assert!(mask.iter().all(|&x| x == 1));
2070    }
2071
2072    #[test]
2073    fn test_segmentation_to_mask_all_below() {
2074        let segmentation = Array3::from_elem((4, 4, 1), 64u8);
2075        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2076        assert!(mask.iter().all(|&x| x == 0));
2077    }
2078
2079    #[test]
2080    fn test_segmentation_to_mask_invalid_shape() {
2081        let segmentation = Array3::from_elem((4, 4, 3), 128u8);
2082        let result = yolo_segmentation_to_mask(segmentation.view(), 128);
2083
2084        assert!(result.is_err());
2085        assert!(matches!(
2086            result,
2087            Err(crate::DecoderError::InvalidShape(s)) if s.contains("(H, W, 1)")
2088        ));
2089    }
2090
2091    // ========================================================================
2092    // Tests for protobox / NORM_LIMIT regression
2093    // ========================================================================
2094
2095    #[test]
2096    fn test_protobox_clamps_edge_coordinates() {
2097        // bbox with xmax=1.0 should not panic (OOB guard)
2098        let protos = Array3::<f32>::zeros((16, 16, 4));
2099        let view = protos.view();
2100        let roi = BoundingBox {
2101            xmin: 0.5,
2102            ymin: 0.5,
2103            xmax: 1.0,
2104            ymax: 1.0,
2105        };
2106        let result = protobox(&view, &roi);
2107        assert!(result.is_ok(), "protobox should accept xmax=1.0");
2108        let (cropped, _roi_norm) = result.unwrap();
2109        // Cropped region must have non-zero spatial dimensions
2110        assert!(cropped.shape()[0] > 0);
2111        assert!(cropped.shape()[1] > 0);
2112        assert_eq!(cropped.shape()[2], 4);
2113    }
2114
2115    #[test]
2116    fn test_protobox_rejects_wildly_out_of_range() {
2117        // bbox with coords > NORM_LIMIT (e.g. 3.0) returns error
2118        let protos = Array3::<f32>::zeros((16, 16, 4));
2119        let view = protos.view();
2120        let roi = BoundingBox {
2121            xmin: 0.0,
2122            ymin: 0.0,
2123            xmax: 3.0,
2124            ymax: 3.0,
2125        };
2126        let result = protobox(&view, &roi);
2127        assert!(
2128            matches!(result, Err(crate::DecoderError::InvalidShape(s)) if s.contains("un-normalized")),
2129            "protobox should reject coords > NORM_LIMIT"
2130        );
2131    }
2132
2133    #[test]
2134    fn test_protobox_accepts_slightly_over_one() {
2135        // bbox with coords at 1.5 (within NORM_LIMIT=2.0) succeeds
2136        let protos = Array3::<f32>::zeros((16, 16, 4));
2137        let view = protos.view();
2138        let roi = BoundingBox {
2139            xmin: 0.0,
2140            ymin: 0.0,
2141            xmax: 1.5,
2142            ymax: 1.5,
2143        };
2144        let result = protobox(&view, &roi);
2145        assert!(
2146            result.is_ok(),
2147            "protobox should accept coords <= NORM_LIMIT (2.0)"
2148        );
2149        let (cropped, _roi_norm) = result.unwrap();
2150        // Entire proto map should be selected when coords > 1.0 (clamped to boundary)
2151        assert_eq!(cropped.shape()[0], 16);
2152        assert_eq!(cropped.shape()[1], 16);
2153    }
2154
2155    #[test]
2156    fn test_segdet_float_proto_no_panic() {
2157        // Simulates YOLOv8n-seg: output0 = [116, 8400] (4 box + 80 class + 32 mask coeff)
2158        // output1 (protos) = [32, 160, 160]
2159        let num_proposals = 100; // enough to produce idx >= 32
2160        let num_classes = 80;
2161        let num_mask_coeffs = 32;
2162        let rows = 4 + num_classes + num_mask_coeffs; // 116
2163
2164        // Fill boxes with valid xywh data so some detections pass the threshold.
2165        // Layout is [116, num_proposals] row-major: row 0=cx, 1=cy, 2=w, 3=h,
2166        // rows 4..84=class scores, rows 84..116=mask coefficients.
2167        let mut data = vec![0.0f32; rows * num_proposals];
2168        for i in 0..num_proposals {
2169            let row = |r: usize| r * num_proposals + i;
2170            data[row(0)] = 320.0; // cx
2171            data[row(1)] = 320.0; // cy
2172            data[row(2)] = 50.0; // w
2173            data[row(3)] = 50.0; // h
2174            data[row(4)] = 0.9; // class-0 score
2175        }
2176        let boxes = ndarray::Array2::from_shape_vec((rows, num_proposals), data).unwrap();
2177
2178        // Protos must be in HWC order (decoder.rs protos_to_hwc converts
2179        // before calling into these functions).
2180        let protos = ndarray::Array3::<f32>::zeros((160, 160, num_mask_coeffs));
2181
2182        let mut output_boxes = Vec::with_capacity(300);
2183
2184        // This panicked before fix: mask_tensor.row(idx) with idx >= 32
2185        let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
2186            boxes.view(),
2187            protos.view(),
2188            0.5,
2189            0.7,
2190            Some(Nms::default()),
2191            &mut output_boxes,
2192        );
2193
2194        // Should produce detections (NMS will collapse many overlapping boxes)
2195        assert!(!output_boxes.is_empty());
2196        assert_eq!(proto_data.mask_coefficients.len(), output_boxes.len());
2197        // Each mask coefficient vector should have 32 elements
2198        for coeffs in &proto_data.mask_coefficients {
2199            assert_eq!(coeffs.len(), num_mask_coeffs);
2200        }
2201    }
2202}