Skip to main content

edgefirst_decoder/
yolo.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::fmt::Debug;
5
6use ndarray::{
7    parallel::prelude::{IntoParallelIterator, ParallelIterator},
8    s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
9};
10use num_traits::{AsPrimitive, Float, PrimInt, Signed};
11
12use crate::{
13    byte::{
14        nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
15        postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
16    },
17    configs::Nms,
18    dequant_detect_box,
19    float::{
20        nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
21        postprocess_boxes_float, postprocess_boxes_index_float,
22    },
23    BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoTensor,
24    Quantization, Segmentation, XYWH, XYXY,
25};
26
27/// Dispatches to the appropriate NMS function based on mode for float boxes.
28fn dispatch_nms_float(nms: Option<Nms>, iou: f32, boxes: Vec<DetectBox>) -> Vec<DetectBox> {
29    match nms {
30        Some(Nms::ClassAgnostic) => nms_float(iou, boxes),
31        Some(Nms::ClassAware) => nms_class_aware_float(iou, boxes),
32        None => boxes, // bypass NMS
33    }
34}
35
36/// Dispatches to the appropriate NMS function based on mode for float boxes
37/// with extra data.
38pub(super) fn dispatch_nms_extra_float<E: Send + Sync>(
39    nms: Option<Nms>,
40    iou: f32,
41    boxes: Vec<(DetectBox, E)>,
42) -> Vec<(DetectBox, E)> {
43    match nms {
44        Some(Nms::ClassAgnostic) => nms_extra_float(iou, boxes),
45        Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, boxes),
46        None => boxes, // bypass NMS
47    }
48}
49
50/// Dispatches to the appropriate NMS function based on mode for quantized
51/// boxes.
52fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
53    nms: Option<Nms>,
54    iou: f32,
55    boxes: Vec<DetectBoxQuantized<SCORE>>,
56) -> Vec<DetectBoxQuantized<SCORE>> {
57    match nms {
58        Some(Nms::ClassAgnostic) => nms_int(iou, boxes),
59        Some(Nms::ClassAware) => nms_class_aware_int(iou, boxes),
60        None => boxes, // bypass NMS
61    }
62}
63
64/// Dispatches to the appropriate NMS function based on mode for quantized boxes
65/// with extra data.
66fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
67    nms: Option<Nms>,
68    iou: f32,
69    boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
70) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
71    match nms {
72        Some(Nms::ClassAgnostic) => nms_extra_int(iou, boxes),
73        Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, boxes),
74        None => boxes, // bypass NMS
75    }
76}
77
78/// Decodes YOLO detection outputs from quantized tensors into detection boxes.
79///
80/// Boxes are expected to be in XYWH format.
81///
82/// Expected shapes of inputs:
83/// - output: (4 + num_classes, num_boxes)
84pub fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
85    output: (ArrayView2<BOX>, Quantization),
86    score_threshold: f32,
87    iou_threshold: f32,
88    nms: Option<Nms>,
89    output_boxes: &mut Vec<DetectBox>,
90) where
91    f32: AsPrimitive<BOX>,
92{
93    impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
94}
95
96/// Decodes YOLO detection outputs from float tensors into detection boxes.
97///
98/// Boxes are expected to be in XYWH format.
99///
100/// Expected shapes of inputs:
101/// - output: (4 + num_classes, num_boxes)
102pub fn decode_yolo_det_float<T>(
103    output: ArrayView2<T>,
104    score_threshold: f32,
105    iou_threshold: f32,
106    nms: Option<Nms>,
107    output_boxes: &mut Vec<DetectBox>,
108) where
109    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
110    f32: AsPrimitive<T>,
111{
112    impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
113}
114
115/// Decodes YOLO detection and segmentation outputs from quantized tensors into
116/// detection boxes and segmentation masks.
117///
118/// Boxes are expected to be in XYWH format.
119///
120/// Expected shapes of inputs:
121/// - boxes: (4 + num_classes + num_protos, num_boxes)
122/// - protos: (proto_height, proto_width, num_protos)
123///
124/// # Errors
125/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
126pub fn decode_yolo_segdet_quant<
127    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
128    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
129>(
130    boxes: (ArrayView2<BOX>, Quantization),
131    protos: (ArrayView3<PROTO>, Quantization),
132    score_threshold: f32,
133    iou_threshold: f32,
134    nms: Option<Nms>,
135    output_boxes: &mut Vec<DetectBox>,
136    output_masks: &mut Vec<Segmentation>,
137) -> Result<(), crate::DecoderError>
138where
139    f32: AsPrimitive<BOX>,
140{
141    impl_yolo_segdet_quant::<XYWH, _, _>(
142        boxes,
143        protos,
144        score_threshold,
145        iou_threshold,
146        nms,
147        output_boxes,
148        output_masks,
149    )
150}
151
152/// Decodes YOLO detection and segmentation outputs from float tensors into
153/// detection boxes and segmentation masks.
154///
155/// Boxes are expected to be in XYWH format.
156///
157/// Expected shapes of inputs:
158/// - boxes: (4 + num_classes + num_protos, num_boxes)
159/// - protos: (proto_height, proto_width, num_protos)
160///
161/// # Errors
162/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
163pub fn decode_yolo_segdet_float<T>(
164    boxes: ArrayView2<T>,
165    protos: ArrayView3<T>,
166    score_threshold: f32,
167    iou_threshold: f32,
168    nms: Option<Nms>,
169    output_boxes: &mut Vec<DetectBox>,
170    output_masks: &mut Vec<Segmentation>,
171) -> Result<(), crate::DecoderError>
172where
173    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
174    f32: AsPrimitive<T>,
175{
176    impl_yolo_segdet_float::<XYWH, _, _>(
177        boxes,
178        protos,
179        score_threshold,
180        iou_threshold,
181        nms,
182        output_boxes,
183        output_masks,
184    )
185}
186
187/// Decodes YOLO split detection outputs from quantized tensors into detection
188/// boxes.
189///
190/// Boxes are expected to be in XYWH format.
191///
192/// Expected shapes of inputs:
193/// - boxes: (4, num_boxes)
194/// - scores: (num_classes, num_boxes)
195///
196/// # Panics
197/// Panics if shapes don't match the expected dimensions.
198pub fn decode_yolo_split_det_quant<
199    BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
200    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
201>(
202    boxes: (ArrayView2<BOX>, Quantization),
203    scores: (ArrayView2<SCORE>, Quantization),
204    score_threshold: f32,
205    iou_threshold: f32,
206    nms: Option<Nms>,
207    output_boxes: &mut Vec<DetectBox>,
208) where
209    f32: AsPrimitive<SCORE>,
210{
211    impl_yolo_split_quant::<XYWH, _, _>(
212        boxes,
213        scores,
214        score_threshold,
215        iou_threshold,
216        nms,
217        output_boxes,
218    );
219}
220
221/// Decodes YOLO split detection outputs from float tensors into detection
222/// boxes.
223///
224/// Boxes are expected to be in XYWH format.
225///
226/// Expected shapes of inputs:
227/// - boxes: (4, num_boxes)
228/// - scores: (num_classes, num_boxes)
229///
230/// # Panics
231/// Panics if shapes don't match the expected dimensions.
232pub fn decode_yolo_split_det_float<T>(
233    boxes: ArrayView2<T>,
234    scores: ArrayView2<T>,
235    score_threshold: f32,
236    iou_threshold: f32,
237    nms: Option<Nms>,
238    output_boxes: &mut Vec<DetectBox>,
239) where
240    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
241    f32: AsPrimitive<T>,
242{
243    impl_yolo_split_float::<XYWH, _, _>(
244        boxes,
245        scores,
246        score_threshold,
247        iou_threshold,
248        nms,
249        output_boxes,
250    );
251}
252
253/// Decodes YOLO split detection segmentation outputs from quantized tensors
254/// into detection boxes and segmentation masks.
255///
256/// Boxes are expected to be in XYWH format.
257///
258/// Expected shapes of inputs:
259/// - boxes_tensor: (4, num_boxes)
260/// - scores_tensor: (num_classes, num_boxes)
261/// - mask_tensor: (num_protos, num_boxes)
262/// - protos: (proto_height, proto_width, num_protos)
263///
264/// # Errors
265/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
266#[allow(clippy::too_many_arguments)]
267pub fn decode_yolo_split_segdet<
268    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
269    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
270    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
271    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
272>(
273    boxes: (ArrayView2<BOX>, Quantization),
274    scores: (ArrayView2<SCORE>, Quantization),
275    mask_coeff: (ArrayView2<MASK>, Quantization),
276    protos: (ArrayView3<PROTO>, Quantization),
277    score_threshold: f32,
278    iou_threshold: f32,
279    nms: Option<Nms>,
280    output_boxes: &mut Vec<DetectBox>,
281    output_masks: &mut Vec<Segmentation>,
282) -> Result<(), crate::DecoderError>
283where
284    f32: AsPrimitive<SCORE>,
285{
286    impl_yolo_split_segdet_quant::<XYWH, _, _, _, _>(
287        boxes,
288        scores,
289        mask_coeff,
290        protos,
291        score_threshold,
292        iou_threshold,
293        nms,
294        output_boxes,
295        output_masks,
296    )
297}
298
299/// Decodes YOLO split detection segmentation outputs from float tensors
300/// into detection boxes and segmentation masks.
301///
302/// Boxes are expected to be in XYWH format.
303///
304/// Expected shapes of inputs:
305/// - boxes_tensor: (4, num_boxes)
306/// - scores_tensor: (num_classes, num_boxes)
307/// - mask_tensor: (num_protos, num_boxes)
308/// - protos: (proto_height, proto_width, num_protos)
309///
310/// # Errors
311/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
312#[allow(clippy::too_many_arguments)]
313pub fn decode_yolo_split_segdet_float<T>(
314    boxes: ArrayView2<T>,
315    scores: ArrayView2<T>,
316    mask_coeff: ArrayView2<T>,
317    protos: ArrayView3<T>,
318    score_threshold: f32,
319    iou_threshold: f32,
320    nms: Option<Nms>,
321    output_boxes: &mut Vec<DetectBox>,
322    output_masks: &mut Vec<Segmentation>,
323) -> Result<(), crate::DecoderError>
324where
325    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
326    f32: AsPrimitive<T>,
327{
328    impl_yolo_split_segdet_float::<XYWH, _, _, _, _>(
329        boxes,
330        scores,
331        mask_coeff,
332        protos,
333        score_threshold,
334        iou_threshold,
335        nms,
336        output_boxes,
337        output_masks,
338    )
339}
340
341/// Decodes end-to-end YOLO detection outputs (post-NMS from model).
342/// Expects an array of shape `(6, N)`, where the first dimension (rows)
343/// corresponds to the 6 per-detection features
344/// `[x1, y1, x2, y2, conf, class]` and the second dimension (columns)
345/// indexes the `N` detections.
346/// Boxes are output directly without NMS (the model already applied NMS).
347///
348/// Coordinates may be normalized `[0, 1]` or absolute pixel values depending
349/// on the model configuration. The caller should check
350/// `decoder.normalized_boxes()` to determine which.
351///
352/// # Errors
353///
354/// Returns `DecoderError::InvalidShape` if `output` has fewer than 6 rows.
355pub fn decode_yolo_end_to_end_det_float<T>(
356    output: ArrayView2<T>,
357    score_threshold: f32,
358    output_boxes: &mut Vec<DetectBox>,
359) -> Result<(), crate::DecoderError>
360where
361    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
362    f32: AsPrimitive<T>,
363{
364    // Validate input shape: need at least 6 rows (x1, y1, x2, y2, conf, class)
365    if output.shape()[0] < 6 {
366        return Err(crate::DecoderError::InvalidShape(format!(
367            "End-to-end detection output requires at least 6 rows, got {}",
368            output.shape()[0]
369        )));
370    }
371
372    // Input shape: (6, N) -> transpose to (N, 4) for boxes and (N, 1) for scores
373    let boxes = output.slice(s![0..4, ..]).reversed_axes();
374    let scores = output.slice(s![4..5, ..]).reversed_axes();
375    let classes = output.slice(s![5, ..]);
376    let mut boxes =
377        postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
378    boxes.truncate(output_boxes.capacity());
379    output_boxes.clear();
380    for (mut b, i) in boxes.into_iter() {
381        b.label = classes[i].as_() as usize;
382        output_boxes.push(b);
383    }
384    // No NMS — model output is already post-NMS
385    Ok(())
386}
387
388/// Decodes end-to-end YOLO detection + segmentation outputs (post-NMS from
389/// model).
390///
391/// Input shapes:
392/// - detection: (6 + num_protos, N) where rows are [x1, y1, x2, y2, conf,
393///   class, mask_coeff_0, ..., mask_coeff_31]
394/// - protos: (proto_height, proto_width, num_protos)
395///
396/// Boxes are output directly without NMS (model already applied NMS).
397/// Coordinates may be normalized [0,1] or pixel values depending on model
398/// config.
399///
400/// # Errors
401///
402/// Returns `DecoderError::InvalidShape` if:
403/// - output has fewer than 7 rows (6 base + at least 1 mask coefficient)
404/// - protos shape doesn't match mask coefficients count
405pub fn decode_yolo_end_to_end_segdet_float<T>(
406    output: ArrayView2<T>,
407    protos: ArrayView3<T>,
408    score_threshold: f32,
409    output_boxes: &mut Vec<DetectBox>,
410    output_masks: &mut Vec<crate::Segmentation>,
411) -> Result<(), crate::DecoderError>
412where
413    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
414    f32: AsPrimitive<T>,
415{
416    let (boxes, scores, classes, mask_coeff) =
417        postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
418    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
419        boxes,
420        scores,
421        classes,
422        score_threshold,
423        output_boxes.capacity(),
424    );
425
426    // No NMS — model output is already post-NMS
427
428    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
429}
430
431/// Decodes split end-to-end YOLO detection outputs (post-NMS from model).
432///
433/// Input shapes (after batch dim removed):
434/// - boxes: (4, N) — xyxy pixel coordinates
435/// - scores: (1, N) — confidence of the top class
436/// - classes: (1, N) — class index of the top class
437///
438/// Boxes are output directly without NMS (model already applied NMS).
439pub fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
440    boxes: ArrayView2<T>,
441    scores: ArrayView2<T>,
442    classes: ArrayView2<T>,
443    score_threshold: f32,
444    output_boxes: &mut Vec<DetectBox>,
445) -> Result<(), crate::DecoderError> {
446    let n = boxes.shape()[1];
447
448    output_boxes.clear();
449
450    let (boxes, scores, classes) = postprocess_yolo_split_end_to_end_det(boxes, scores, &classes)?;
451
452    for i in 0..n {
453        let score: f32 = scores[[i, 0]].as_();
454        if score < score_threshold {
455            continue;
456        }
457        if output_boxes.len() >= output_boxes.capacity() {
458            break;
459        }
460        output_boxes.push(DetectBox {
461            bbox: BoundingBox {
462                xmin: boxes[[i, 0]].as_(),
463                ymin: boxes[[i, 1]].as_(),
464                xmax: boxes[[i, 2]].as_(),
465                ymax: boxes[[i, 3]].as_(),
466            },
467            score,
468            label: classes[i].as_() as usize,
469        });
470    }
471    Ok(())
472}
473
474/// Decodes split end-to-end YOLO detection + segmentation outputs.
475///
476/// Input shapes (after batch dim removed):
477/// - boxes: (4, N) — xyxy pixel coordinates
478/// - scores: (1, N) — confidence
479/// - classes: (1, N) — class index
480/// - mask_coeff: (num_protos, N) — mask coefficients per detection
481/// - protos: (proto_h, proto_w, num_protos) — prototype masks
482#[allow(clippy::too_many_arguments)]
483pub fn decode_yolo_split_end_to_end_segdet_float<T>(
484    boxes: ArrayView2<T>,
485    scores: ArrayView2<T>,
486    classes: ArrayView2<T>,
487    mask_coeff: ArrayView2<T>,
488    protos: ArrayView3<T>,
489    score_threshold: f32,
490    output_boxes: &mut Vec<DetectBox>,
491    output_masks: &mut Vec<crate::Segmentation>,
492) -> Result<(), crate::DecoderError>
493where
494    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
495    f32: AsPrimitive<T>,
496{
497    let (boxes, scores, classes, mask_coeff) =
498        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
499    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
500        boxes,
501        scores,
502        classes,
503        score_threshold,
504        output_boxes.capacity(),
505    );
506
507    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
508}
509
510#[allow(clippy::type_complexity)]
511pub(crate) fn postprocess_yolo_end_to_end_segdet<'a, T>(
512    output: &'a ArrayView2<'_, T>,
513    num_protos: usize,
514) -> Result<
515    (
516        ArrayView2<'a, T>,
517        ArrayView2<'a, T>,
518        ArrayView1<'a, T>,
519        ArrayView2<'a, T>,
520    ),
521    crate::DecoderError,
522> {
523    // Validate input shape: need at least 7 rows (6 base + at least 1 mask coeff)
524    if output.shape()[0] < 7 {
525        return Err(crate::DecoderError::InvalidShape(format!(
526            "End-to-end segdet output requires at least 7 rows, got {}",
527            output.shape()[0]
528        )));
529    }
530
531    let num_mask_coeffs = output.shape()[0] - 6;
532    if num_mask_coeffs != num_protos {
533        return Err(crate::DecoderError::InvalidShape(format!(
534            "Mask coefficients count ({}) doesn't match protos count ({})",
535            num_mask_coeffs, num_protos
536        )));
537    }
538
539    // Input shape: (6+num_protos, N) -> transpose for postprocessing
540    let boxes = output.slice(s![0..4, ..]).reversed_axes();
541    let scores = output.slice(s![4..5, ..]).reversed_axes();
542    let classes = output.slice(s![5, ..]);
543    let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
544    Ok((boxes, scores, classes, mask_coeff))
545}
546
547/// Postprocess yolo split end to end det by reversing axes of boxes,
548/// scores, and flattening the class tensor.
549/// Expected input shapes:
550/// - boxes: (4, N)
551/// - scores: (1, N)
552/// - classes: (1, N)
553#[allow(clippy::type_complexity)]
554pub(crate) fn postprocess_yolo_split_end_to_end_det<'a, 'b, 'c, BOXES, SCORES, CLASS>(
555    boxes: ArrayView2<'a, BOXES>,
556    scores: ArrayView2<'b, SCORES>,
557    classes: &'c ArrayView2<CLASS>,
558) -> Result<
559    (
560        ArrayView2<'a, BOXES>,
561        ArrayView2<'b, SCORES>,
562        ArrayView1<'c, CLASS>,
563    ),
564    crate::DecoderError,
565> {
566    let num_boxes = boxes.shape()[1];
567    if boxes.shape()[0] != 4 {
568        return Err(crate::DecoderError::InvalidShape(format!(
569            "Split end-to-end box_coords must be 4, got {}",
570            boxes.shape()[0]
571        )));
572    }
573
574    if scores.shape()[0] != 1 {
575        return Err(crate::DecoderError::InvalidShape(format!(
576            "Split end-to-end scores num_classes must be 1, got {}",
577            scores.shape()[0]
578        )));
579    }
580
581    if classes.shape()[0] != 1 {
582        return Err(crate::DecoderError::InvalidShape(format!(
583            "Split end-to-end classes num_classes must be 1, got {}",
584            classes.shape()[0]
585        )));
586    }
587
588    if scores.shape()[1] != num_boxes {
589        return Err(crate::DecoderError::InvalidShape(format!(
590            "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
591            num_boxes,
592            scores.shape()[1]
593        )));
594    }
595
596    if classes.shape()[1] != num_boxes {
597        return Err(crate::DecoderError::InvalidShape(format!(
598            "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
599            num_boxes,
600            classes.shape()[1]
601        )));
602    }
603
604    let boxes = boxes.reversed_axes();
605    let scores = scores.reversed_axes();
606    let classes = classes.slice(s![0, ..]);
607    Ok((boxes, scores, classes))
608}
609
610/// Postprocess yolo split end to end segdet by reversing axes of boxes,
611/// scores, mask tensors and flattening the class tensor.
612#[allow(clippy::type_complexity)]
613pub(crate) fn postprocess_yolo_split_end_to_end_segdet<
614    'a,
615    'b,
616    'c,
617    'd,
618    BOXES,
619    SCORES,
620    CLASS,
621    MASK,
622>(
623    boxes: ArrayView2<'a, BOXES>,
624    scores: ArrayView2<'b, SCORES>,
625    classes: &'c ArrayView2<CLASS>,
626    mask_coeff: ArrayView2<'d, MASK>,
627) -> Result<
628    (
629        ArrayView2<'a, BOXES>,
630        ArrayView2<'b, SCORES>,
631        ArrayView1<'c, CLASS>,
632        ArrayView2<'d, MASK>,
633    ),
634    crate::DecoderError,
635> {
636    let num_boxes = boxes.shape()[1];
637    if boxes.shape()[0] != 4 {
638        return Err(crate::DecoderError::InvalidShape(format!(
639            "Split end-to-end box_coords must be 4, got {}",
640            boxes.shape()[0]
641        )));
642    }
643
644    if scores.shape()[0] != 1 {
645        return Err(crate::DecoderError::InvalidShape(format!(
646            "Split end-to-end scores num_classes must be 1, got {}",
647            scores.shape()[0]
648        )));
649    }
650
651    if classes.shape()[0] != 1 {
652        return Err(crate::DecoderError::InvalidShape(format!(
653            "Split end-to-end classes num_classes must be 1, got {}",
654            classes.shape()[0]
655        )));
656    }
657
658    if scores.shape()[1] != num_boxes {
659        return Err(crate::DecoderError::InvalidShape(format!(
660            "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
661            num_boxes,
662            scores.shape()[1]
663        )));
664    }
665
666    if classes.shape()[1] != num_boxes {
667        return Err(crate::DecoderError::InvalidShape(format!(
668            "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
669            num_boxes,
670            classes.shape()[1]
671        )));
672    }
673
674    if mask_coeff.shape()[1] != num_boxes {
675        return Err(crate::DecoderError::InvalidShape(format!(
676            "Split end-to-end mask_coeff must have same num_boxes as boxes ({}), got {}",
677            num_boxes,
678            mask_coeff.shape()[1]
679        )));
680    }
681
682    let boxes = boxes.reversed_axes();
683    let scores = scores.reversed_axes();
684    let classes = classes.slice(s![0, ..]);
685    let mask_coeff = mask_coeff.reversed_axes();
686    Ok((boxes, scores, classes, mask_coeff))
687}
688/// Internal implementation of YOLO decoding for quantized tensors.
689///
690/// Expected shapes of inputs:
691/// - output: (4 + num_classes, num_boxes)
692pub(crate) fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
693    output: (ArrayView2<T>, Quantization),
694    score_threshold: f32,
695    iou_threshold: f32,
696    nms: Option<Nms>,
697    output_boxes: &mut Vec<DetectBox>,
698) where
699    f32: AsPrimitive<T>,
700{
701    let (boxes, quant_boxes) = output;
702    let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
703
704    let boxes = {
705        let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
706        postprocess_boxes_quant::<B, _, _>(
707            score_threshold,
708            boxes_tensor,
709            scores_tensor,
710            quant_boxes,
711        )
712    };
713
714    let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
715    let len = output_boxes.capacity().min(boxes.len());
716    output_boxes.clear();
717    for b in boxes.iter().take(len) {
718        output_boxes.push(dequant_detect_box(b, quant_boxes));
719    }
720}
721
722/// Internal implementation of YOLO decoding for float tensors.
723///
724/// Expected shapes of inputs:
725/// - output: (4 + num_classes, num_boxes)
726pub(crate) fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
727    output: ArrayView2<T>,
728    score_threshold: f32,
729    iou_threshold: f32,
730    nms: Option<Nms>,
731    output_boxes: &mut Vec<DetectBox>,
732) where
733    f32: AsPrimitive<T>,
734{
735    let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
736    let boxes =
737        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
738    let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
739    let len = output_boxes.capacity().min(boxes.len());
740    output_boxes.clear();
741    for b in boxes.into_iter().take(len) {
742        output_boxes.push(b);
743    }
744}
745
746/// Internal implementation of YOLO split detection decoding for quantized
747/// tensors.
748///
749/// Expected shapes of inputs:
750/// - boxes: (4, num_boxes)
751/// - scores: (num_classes, num_boxes)
752///
753/// # Panics
754/// Panics if shapes don't match the expected dimensions.
755pub(crate) fn impl_yolo_split_quant<
756    B: BBoxTypeTrait,
757    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
758    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
759>(
760    boxes: (ArrayView2<BOX>, Quantization),
761    scores: (ArrayView2<SCORE>, Quantization),
762    score_threshold: f32,
763    iou_threshold: f32,
764    nms: Option<Nms>,
765    output_boxes: &mut Vec<DetectBox>,
766) where
767    f32: AsPrimitive<SCORE>,
768{
769    let (boxes_tensor, quant_boxes) = boxes;
770    let (scores_tensor, quant_scores) = scores;
771
772    let boxes_tensor = boxes_tensor.reversed_axes();
773    let scores_tensor = scores_tensor.reversed_axes();
774
775    let boxes = {
776        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
777        postprocess_boxes_quant::<B, _, _>(
778            score_threshold,
779            boxes_tensor,
780            scores_tensor,
781            quant_boxes,
782        )
783    };
784
785    let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
786    let len = output_boxes.capacity().min(boxes.len());
787    output_boxes.clear();
788    for b in boxes.iter().take(len) {
789        output_boxes.push(dequant_detect_box(b, quant_scores));
790    }
791}
792
793/// Internal implementation of YOLO split detection decoding for float tensors.
794///
795/// Expected shapes of inputs:
796/// - boxes: (4, num_boxes)
797/// - scores: (num_classes, num_boxes)
798///
799/// # Panics
800/// Panics if shapes don't match the expected dimensions.
801pub(crate) fn impl_yolo_split_float<
802    B: BBoxTypeTrait,
803    BOX: Float + AsPrimitive<f32> + Send + Sync,
804    SCORE: Float + AsPrimitive<f32> + Send + Sync,
805>(
806    boxes_tensor: ArrayView2<BOX>,
807    scores_tensor: ArrayView2<SCORE>,
808    score_threshold: f32,
809    iou_threshold: f32,
810    nms: Option<Nms>,
811    output_boxes: &mut Vec<DetectBox>,
812) where
813    f32: AsPrimitive<SCORE>,
814{
815    let boxes_tensor = boxes_tensor.reversed_axes();
816    let scores_tensor = scores_tensor.reversed_axes();
817    let boxes =
818        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
819    let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
820    let len = output_boxes.capacity().min(boxes.len());
821    output_boxes.clear();
822    for b in boxes.into_iter().take(len) {
823        output_boxes.push(b);
824    }
825}
826
827/// Internal implementation of YOLO detection segmentation decoding for
828/// quantized tensors.
829///
830/// Expected shapes of inputs:
831/// - boxes: (4 + num_classes + num_protos, num_boxes)
832/// - protos: (proto_height, proto_width, num_protos)
833///
834/// # Errors
835/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
836pub(crate) fn impl_yolo_segdet_quant<
837    B: BBoxTypeTrait,
838    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
839    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
840>(
841    boxes: (ArrayView2<BOX>, Quantization),
842    protos: (ArrayView3<PROTO>, Quantization),
843    score_threshold: f32,
844    iou_threshold: f32,
845    nms: Option<Nms>,
846    output_boxes: &mut Vec<DetectBox>,
847    output_masks: &mut Vec<Segmentation>,
848) -> Result<(), crate::DecoderError>
849where
850    f32: AsPrimitive<BOX>,
851{
852    let (boxes, quant_boxes) = boxes;
853    let num_protos = protos.0.dim().2;
854
855    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
856    let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
857        (boxes_tensor, quant_boxes),
858        (scores_tensor, quant_boxes),
859        score_threshold,
860        iou_threshold,
861        nms,
862        output_boxes.capacity(),
863    );
864
865    impl_yolo_split_segdet_quant_process_masks::<_, _>(
866        boxes,
867        (mask_tensor, quant_boxes),
868        protos,
869        output_boxes,
870        output_masks,
871    )
872}
873
874/// Internal implementation of YOLO detection segmentation decoding for
875/// float tensors.
876///
877/// Expected shapes of inputs:
878/// - boxes: (4 + num_classes + num_protos, num_boxes)
879/// - protos: (proto_height, proto_width, num_protos)
880///
881/// # Panics
882/// Panics if shapes don't match the expected dimensions.
883pub(crate) fn impl_yolo_segdet_float<
884    B: BBoxTypeTrait,
885    BOX: Float + AsPrimitive<f32> + Send + Sync,
886    PROTO: Float + AsPrimitive<f32> + Send + Sync,
887>(
888    boxes: ArrayView2<BOX>,
889    protos: ArrayView3<PROTO>,
890    score_threshold: f32,
891    iou_threshold: f32,
892    nms: Option<Nms>,
893    output_boxes: &mut Vec<DetectBox>,
894    output_masks: &mut Vec<Segmentation>,
895) -> Result<(), crate::DecoderError>
896where
897    f32: AsPrimitive<BOX>,
898{
899    let num_protos = protos.dim().2;
900    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
901    let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
902        boxes_tensor,
903        scores_tensor,
904        score_threshold,
905        iou_threshold,
906        nms,
907        output_boxes.capacity(),
908    );
909    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
910}
911
912pub(crate) fn impl_yolo_segdet_get_boxes<
913    B: BBoxTypeTrait,
914    BOX: Float + AsPrimitive<f32> + Send + Sync,
915    SCORE: Float + AsPrimitive<f32> + Send + Sync,
916>(
917    boxes_tensor: ArrayView2<BOX>,
918    scores_tensor: ArrayView2<SCORE>,
919    score_threshold: f32,
920    iou_threshold: f32,
921    nms: Option<Nms>,
922    max_boxes: usize,
923) -> Vec<(DetectBox, usize)>
924where
925    f32: AsPrimitive<SCORE>,
926{
927    let boxes = postprocess_boxes_index_float::<B, _, _>(
928        score_threshold.as_(),
929        boxes_tensor,
930        scores_tensor,
931    );
932    let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
933    boxes.truncate(max_boxes);
934    boxes
935}
936
937pub(crate) fn impl_yolo_end_to_end_segdet_get_boxes<
938    B: BBoxTypeTrait,
939    BOX: Float + AsPrimitive<f32> + Send + Sync,
940    SCORE: Float + AsPrimitive<f32> + Send + Sync,
941    CLASS: AsPrimitive<f32> + Send + Sync,
942>(
943    boxes: ArrayView2<BOX>,
944    scores: ArrayView2<SCORE>,
945    classes: ArrayView1<CLASS>,
946    score_threshold: f32,
947    max_boxes: usize,
948) -> Vec<(DetectBox, usize)>
949where
950    f32: AsPrimitive<SCORE>,
951{
952    let mut boxes = postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes, scores);
953    boxes.truncate(max_boxes);
954    for (b, ind) in &mut boxes {
955        b.label = classes[*ind].as_().round() as usize;
956    }
957    boxes
958}
959
960pub(crate) fn impl_yolo_split_segdet_process_masks<
961    MASK: Float + AsPrimitive<f32> + Send + Sync,
962    PROTO: Float + AsPrimitive<f32> + Send + Sync,
963>(
964    boxes: Vec<(DetectBox, usize)>,
965    masks_tensor: ArrayView2<MASK>,
966    protos_tensor: ArrayView3<PROTO>,
967    output_boxes: &mut Vec<DetectBox>,
968    output_masks: &mut Vec<Segmentation>,
969) -> Result<(), crate::DecoderError> {
970    let boxes = decode_segdet_f32(boxes, masks_tensor, protos_tensor)?;
971    output_boxes.clear();
972    output_masks.clear();
973    for (b, m) in boxes.into_iter() {
974        output_boxes.push(b);
975        output_masks.push(Segmentation {
976            xmin: b.bbox.xmin,
977            ymin: b.bbox.ymin,
978            xmax: b.bbox.xmax,
979            ymax: b.bbox.ymax,
980            segmentation: m,
981        });
982    }
983    Ok(())
984}
985/// Expected input shapes:
986/// - boxes_tensor: (num_boxes, 4)
987/// - scores_tensor: (num_boxes, num_classes)
988pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
989    B: BBoxTypeTrait,
990    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
991    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
992>(
993    boxes: (ArrayView2<BOX>, Quantization),
994    scores: (ArrayView2<SCORE>, Quantization),
995    score_threshold: f32,
996    iou_threshold: f32,
997    nms: Option<Nms>,
998    max_boxes: usize,
999) -> Vec<(DetectBox, usize)>
1000where
1001    f32: AsPrimitive<SCORE>,
1002{
1003    let (boxes_tensor, quant_boxes) = boxes;
1004    let (scores_tensor, quant_scores) = scores;
1005
1006    let boxes = {
1007        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
1008        postprocess_boxes_index_quant::<B, _, _>(
1009            score_threshold,
1010            boxes_tensor,
1011            scores_tensor,
1012            quant_boxes,
1013        )
1014    };
1015    let mut boxes = dispatch_nms_extra_int(nms, iou_threshold, boxes);
1016    boxes.truncate(max_boxes);
1017    boxes
1018        .into_iter()
1019        .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
1020        .collect()
1021}
1022
1023pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
1024    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1025    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1026>(
1027    boxes: Vec<(DetectBox, usize)>,
1028    mask_coeff: (ArrayView2<MASK>, Quantization),
1029    protos: (ArrayView3<PROTO>, Quantization),
1030    output_boxes: &mut Vec<DetectBox>,
1031    output_masks: &mut Vec<Segmentation>,
1032) -> Result<(), crate::DecoderError> {
1033    let (masks, quant_masks) = mask_coeff;
1034    let (protos, quant_protos) = protos;
1035
1036    let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
1037    output_boxes.clear();
1038    output_masks.clear();
1039    for (b, m) in boxes.into_iter() {
1040        output_boxes.push(b);
1041        output_masks.push(Segmentation {
1042            xmin: b.bbox.xmin,
1043            ymin: b.bbox.ymin,
1044            xmax: b.bbox.xmax,
1045            ymax: b.bbox.ymax,
1046            segmentation: m,
1047        });
1048    }
1049    Ok(())
1050}
1051
1052#[allow(clippy::too_many_arguments)]
1053/// Internal implementation of YOLO split detection segmentation decoding for
1054/// quantized tensors.
1055///
1056/// Expected shapes of inputs:
1057/// - boxes_tensor: (4, num_boxes)
1058/// - scores_tensor: (num_classes, num_boxes)
1059/// - mask_tensor: (num_protos, num_boxes)
1060/// - protos: (proto_height, proto_width, num_protos)
1061///
1062/// # Errors
1063/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
1064pub(crate) fn impl_yolo_split_segdet_quant<
1065    B: BBoxTypeTrait,
1066    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1067    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1068    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1069    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1070>(
1071    boxes: (ArrayView2<BOX>, Quantization),
1072    scores: (ArrayView2<SCORE>, Quantization),
1073    mask_coeff: (ArrayView2<MASK>, Quantization),
1074    protos: (ArrayView3<PROTO>, Quantization),
1075    score_threshold: f32,
1076    iou_threshold: f32,
1077    nms: Option<Nms>,
1078    output_boxes: &mut Vec<DetectBox>,
1079    output_masks: &mut Vec<Segmentation>,
1080) -> Result<(), crate::DecoderError>
1081where
1082    f32: AsPrimitive<SCORE>,
1083{
1084    let (boxes_, scores_, mask_coeff_) =
1085        postprocess_yolo_split_segdet(boxes.0, scores.0, mask_coeff.0);
1086    let boxes = (boxes_, boxes.1);
1087    let scores = (scores_, scores.1);
1088    let mask_coeff = (mask_coeff_, mask_coeff.1);
1089
1090    let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1091        boxes,
1092        scores,
1093        score_threshold,
1094        iou_threshold,
1095        nms,
1096        output_boxes.capacity(),
1097    );
1098
1099    impl_yolo_split_segdet_quant_process_masks(
1100        boxes,
1101        mask_coeff,
1102        protos,
1103        output_boxes,
1104        output_masks,
1105    )
1106}
1107
1108#[allow(clippy::too_many_arguments)]
1109/// Internal implementation of YOLO split detection segmentation decoding for
1110/// float tensors.
1111///
1112/// Expected shapes of inputs:
1113/// - boxes_tensor: (4, num_boxes)
1114/// - scores_tensor: (num_classes, num_boxes)
1115/// - mask_tensor: (num_protos, num_boxes)
1116/// - protos: (proto_height, proto_width, num_protos)
1117///
1118/// # Errors
1119/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
1120pub(crate) fn impl_yolo_split_segdet_float<
1121    B: BBoxTypeTrait,
1122    BOX: Float + AsPrimitive<f32> + Send + Sync,
1123    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1124    MASK: Float + AsPrimitive<f32> + Send + Sync,
1125    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1126>(
1127    boxes_tensor: ArrayView2<BOX>,
1128    scores_tensor: ArrayView2<SCORE>,
1129    mask_tensor: ArrayView2<MASK>,
1130    protos: ArrayView3<PROTO>,
1131    score_threshold: f32,
1132    iou_threshold: f32,
1133    nms: Option<Nms>,
1134    output_boxes: &mut Vec<DetectBox>,
1135    output_masks: &mut Vec<Segmentation>,
1136) -> Result<(), crate::DecoderError>
1137where
1138    f32: AsPrimitive<SCORE>,
1139{
1140    let (boxes_tensor, scores_tensor, mask_tensor) =
1141        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1142
1143    let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1144        boxes_tensor,
1145        scores_tensor,
1146        score_threshold,
1147        iou_threshold,
1148        nms,
1149        output_boxes.capacity(),
1150    );
1151    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1152}
1153
1154// ---------------------------------------------------------------------------
1155// Proto-extraction variants: return ProtoData instead of materialized masks
1156// ---------------------------------------------------------------------------
1157
1158/// Proto-extraction variant of `impl_yolo_segdet_quant`.
1159/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1160pub fn impl_yolo_segdet_quant_proto<
1161    B: BBoxTypeTrait,
1162    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1163    PROTO: PrimInt
1164        + AsPrimitive<i64>
1165        + AsPrimitive<i128>
1166        + AsPrimitive<f32>
1167        + AsPrimitive<i8>
1168        + Send
1169        + Sync,
1170>(
1171    boxes: (ArrayView2<BOX>, Quantization),
1172    protos: (ArrayView3<PROTO>, Quantization),
1173    score_threshold: f32,
1174    iou_threshold: f32,
1175    nms: Option<Nms>,
1176    output_boxes: &mut Vec<DetectBox>,
1177) -> ProtoData
1178where
1179    f32: AsPrimitive<BOX>,
1180{
1181    let (boxes_arr, quant_boxes) = boxes;
1182    let (protos_arr, quant_protos) = protos;
1183    let num_protos = protos_arr.dim().2;
1184
1185    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1186
1187    let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1188        (boxes_tensor, quant_boxes),
1189        (scores_tensor, quant_boxes),
1190        score_threshold,
1191        iou_threshold,
1192        nms,
1193        output_boxes.capacity(),
1194    );
1195
1196    extract_proto_data_quant(
1197        det_indices,
1198        mask_tensor,
1199        quant_boxes,
1200        protos_arr,
1201        quant_protos,
1202        output_boxes,
1203    )
1204}
1205
1206/// Proto-extraction variant of `impl_yolo_segdet_float`.
1207/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1208pub(crate) fn impl_yolo_segdet_float_proto<
1209    B: BBoxTypeTrait,
1210    BOX: Float + AsPrimitive<f32> + Send + Sync,
1211    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1212>(
1213    boxes: ArrayView2<BOX>,
1214    protos: ArrayView3<PROTO>,
1215    score_threshold: f32,
1216    iou_threshold: f32,
1217    nms: Option<Nms>,
1218    output_boxes: &mut Vec<DetectBox>,
1219) -> ProtoData
1220where
1221    f32: AsPrimitive<BOX>,
1222{
1223    let num_protos = protos.dim().2;
1224    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1225
1226    let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1227        boxes_tensor,
1228        scores_tensor,
1229        score_threshold,
1230        iou_threshold,
1231        nms,
1232        output_boxes.capacity(),
1233    );
1234
1235    extract_proto_data_float(boxes, mask_tensor, protos, output_boxes)
1236}
1237
1238/// Proto-extraction variant of `impl_yolo_split_segdet_float`.
1239/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1240#[allow(clippy::too_many_arguments)]
1241pub(crate) fn impl_yolo_split_segdet_float_proto<
1242    B: BBoxTypeTrait,
1243    BOX: Float + AsPrimitive<f32> + Send + Sync,
1244    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1245    MASK: Float + AsPrimitive<f32> + Send + Sync,
1246    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1247>(
1248    boxes_tensor: ArrayView2<BOX>,
1249    scores_tensor: ArrayView2<SCORE>,
1250    mask_tensor: ArrayView2<MASK>,
1251    protos: ArrayView3<PROTO>,
1252    score_threshold: f32,
1253    iou_threshold: f32,
1254    nms: Option<Nms>,
1255    output_boxes: &mut Vec<DetectBox>,
1256) -> ProtoData
1257where
1258    f32: AsPrimitive<SCORE>,
1259{
1260    let (boxes_tensor, scores_tensor, mask_tensor) =
1261        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1262    let det_indices = impl_yolo_segdet_get_boxes::<B, _, _>(
1263        boxes_tensor,
1264        scores_tensor,
1265        score_threshold,
1266        iou_threshold,
1267        nms,
1268        output_boxes.capacity(),
1269    );
1270
1271    extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1272}
1273
1274/// Proto-extraction variant of `decode_yolo_end_to_end_segdet_float`.
1275pub fn decode_yolo_end_to_end_segdet_float_proto<T>(
1276    output: ArrayView2<T>,
1277    protos: ArrayView3<T>,
1278    score_threshold: f32,
1279    output_boxes: &mut Vec<DetectBox>,
1280) -> Result<ProtoData, crate::DecoderError>
1281where
1282    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1283    f32: AsPrimitive<T>,
1284{
1285    let (boxes, scores, classes, mask_coeff) =
1286        postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
1287    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1288        boxes,
1289        scores,
1290        classes,
1291        score_threshold,
1292        output_boxes.capacity(),
1293    );
1294
1295    Ok(extract_proto_data_float(
1296        boxes,
1297        mask_coeff,
1298        protos,
1299        output_boxes,
1300    ))
1301}
1302
1303/// Proto-extraction variant of `decode_yolo_split_end_to_end_segdet_float`.
1304#[allow(clippy::too_many_arguments)]
1305pub fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1306    boxes: ArrayView2<T>,
1307    scores: ArrayView2<T>,
1308    classes: ArrayView2<T>,
1309    mask_coeff: ArrayView2<T>,
1310    protos: ArrayView3<T>,
1311    score_threshold: f32,
1312    output_boxes: &mut Vec<DetectBox>,
1313) -> Result<ProtoData, crate::DecoderError>
1314where
1315    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1316    f32: AsPrimitive<T>,
1317{
1318    let (boxes, scores, classes, mask_coeff) =
1319        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
1320    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1321        boxes,
1322        scores,
1323        classes,
1324        score_threshold,
1325        output_boxes.capacity(),
1326    );
1327
1328    Ok(extract_proto_data_float(
1329        boxes,
1330        mask_coeff,
1331        protos,
1332        output_boxes,
1333    ))
1334}
1335
1336/// Helper: extract ProtoData from float mask coefficients + protos.
1337pub(super) fn extract_proto_data_float<
1338    MASK: Float + AsPrimitive<f32> + Send + Sync,
1339    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1340>(
1341    det_indices: Vec<(DetectBox, usize)>,
1342    mask_tensor: ArrayView2<MASK>,
1343    protos: ArrayView3<PROTO>,
1344    output_boxes: &mut Vec<DetectBox>,
1345) -> ProtoData {
1346    let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1347    output_boxes.clear();
1348    for (det, idx) in det_indices {
1349        output_boxes.push(det);
1350        let row = mask_tensor.row(idx);
1351        mask_coefficients.push(row.iter().map(|v| v.as_()).collect());
1352    }
1353    let protos_f32 = protos.map(|v| v.as_());
1354    ProtoData {
1355        mask_coefficients,
1356        protos: ProtoTensor::Float(protos_f32),
1357    }
1358}
1359
1360/// Helper: extract ProtoData from quantized mask coefficients + protos.
1361///
1362/// Dequantizes mask coefficients to f32 (small — per-detection) but keeps
1363/// protos in raw int8 form wrapped in `ProtoTensor::Quantized` so the GPU
1364/// shader can dequantize per-texel without CPU overhead.
1365pub(crate) fn extract_proto_data_quant<
1366    MASK: PrimInt + AsPrimitive<f32> + Send + Sync,
1367    PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
1368>(
1369    det_indices: Vec<(DetectBox, usize)>,
1370    mask_tensor: ArrayView2<MASK>,
1371    quant_masks: Quantization,
1372    protos: ArrayView3<PROTO>,
1373    quant_protos: Quantization,
1374    output_boxes: &mut Vec<DetectBox>,
1375) -> ProtoData {
1376    let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1377    output_boxes.clear();
1378    for (det, idx) in det_indices {
1379        output_boxes.push(det);
1380        let row = mask_tensor.row(idx);
1381        mask_coefficients.push(
1382            row.iter()
1383                .map(|v| (v.as_() - quant_masks.zero_point as f32) * quant_masks.scale)
1384                .collect(),
1385        );
1386    }
1387    // Keep protos in raw int8 — GPU shader will dequantize per-texel.
1388    // When PROTO is already i8, use to_owned() for a flat memcpy instead of
1389    // per-element as_() conversion.
1390    let protos_i8 = if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1391        // SAFETY: PROTO and i8 have identical size and layout when TypeId matches.
1392        let view_i8 =
1393            unsafe { &*(&protos as *const ArrayView3<'_, PROTO> as *const ArrayView3<'_, i8>) };
1394        view_i8.to_owned()
1395    } else {
1396        protos.map(|v| {
1397            let v_i8: i8 = v.as_();
1398            v_i8
1399        })
1400    };
1401    ProtoData {
1402        mask_coefficients,
1403        protos: ProtoTensor::Quantized {
1404            protos: protos_i8,
1405            quantization: quant_protos,
1406        },
1407    }
1408}
1409
1410fn postprocess_yolo<'a, T>(
1411    output: &'a ArrayView2<'_, T>,
1412) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1413    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1414    let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1415    (boxes_tensor, scores_tensor)
1416}
1417
1418pub(crate) fn postprocess_yolo_seg<'a, T>(
1419    output: &'a ArrayView2<'_, T>,
1420    num_protos: usize,
1421) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1422    assert!(
1423        output.shape()[0] > num_protos + 4,
1424        "Output shape is too short: {} <= {} + 4",
1425        output.shape()[0],
1426        num_protos
1427    );
1428    let num_classes = output.shape()[0] - 4 - num_protos;
1429    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1430    let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1431    let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1432    (boxes_tensor, scores_tensor, mask_tensor)
1433}
1434
1435pub(crate) fn postprocess_yolo_split_segdet<'a, 'b, 'c, BOX, SCORE, MASK>(
1436    boxes_tensor: ArrayView2<'a, BOX>,
1437    scores_tensor: ArrayView2<'b, SCORE>,
1438    mask_tensor: ArrayView2<'c, MASK>,
1439) -> (
1440    ArrayView2<'a, BOX>,
1441    ArrayView2<'b, SCORE>,
1442    ArrayView2<'c, MASK>,
1443) {
1444    let boxes_tensor = boxes_tensor.reversed_axes();
1445    let scores_tensor = scores_tensor.reversed_axes();
1446    let mask_tensor = mask_tensor.reversed_axes();
1447    (boxes_tensor, scores_tensor, mask_tensor)
1448}
1449
1450fn decode_segdet_f32<
1451    MASK: Float + AsPrimitive<f32> + Send + Sync,
1452    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1453>(
1454    boxes: Vec<(DetectBox, usize)>,
1455    masks: ArrayView2<MASK>,
1456    protos: ArrayView3<PROTO>,
1457) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1458    if boxes.is_empty() {
1459        return Ok(Vec::new());
1460    }
1461    if masks.shape()[1] != protos.shape()[2] {
1462        return Err(crate::DecoderError::InvalidShape(format!(
1463            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1464            masks.shape()[1],
1465            protos.shape()[2],
1466        )));
1467    }
1468    boxes
1469        .into_par_iter()
1470        .map(|mut b| {
1471            let ind = b.1;
1472            let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1473            b.0.bbox = roi;
1474            Ok((b.0, make_segmentation(masks.row(ind), protos.view())))
1475        })
1476        .collect()
1477}
1478
1479pub(crate) fn decode_segdet_quant<
1480    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1481    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1482>(
1483    boxes: Vec<(DetectBox, usize)>,
1484    masks: ArrayView2<MASK>,
1485    protos: ArrayView3<PROTO>,
1486    quant_masks: Quantization,
1487    quant_protos: Quantization,
1488) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1489    if boxes.is_empty() {
1490        return Ok(Vec::new());
1491    }
1492    if masks.shape()[1] != protos.shape()[2] {
1493        return Err(crate::DecoderError::InvalidShape(format!(
1494            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1495            masks.shape()[1],
1496            protos.shape()[2],
1497        )));
1498    }
1499
1500    let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; // 32 protos is 2^5
1501    boxes
1502        .into_iter()
1503        .map(|mut b| {
1504            let i = b.1;
1505            let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1506            b.0.bbox = roi;
1507            let seg = match total_bits {
1508                0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1509                    masks.row(i),
1510                    protos.view(),
1511                    quant_masks,
1512                    quant_protos,
1513                ),
1514                65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1515                    masks.row(i),
1516                    protos.view(),
1517                    quant_masks,
1518                    quant_protos,
1519                ),
1520                _ => {
1521                    return Err(crate::DecoderError::NotSupported(format!(
1522                        "Unsupported bit width ({total_bits}) for segmentation computation"
1523                    )));
1524                }
1525            };
1526            Ok((b.0, seg))
1527        })
1528        .collect()
1529}
1530
1531fn protobox<'a, T>(
1532    protos: &'a ArrayView3<T>,
1533    roi: &BoundingBox,
1534) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1535    let width = protos.dim().1 as f32;
1536    let height = protos.dim().0 as f32;
1537
1538    // Detect un-normalized bounding boxes (pixel-space coordinates).
1539    // protobox expects normalized coordinates in [0, 1]. ONNX models output
1540    // pixel-space boxes (e.g. 0-640) which must be normalized before calling
1541    // decode(). Without this check, pixel-space coordinates silently clamp to
1542    // the proto boundary, producing empty (0, 0, C) masks for every detection.
1543    //
1544    // The limit is set to 2.0 (not 1.01) because YOLO models legitimately
1545    // predict coordinates slightly > 1.0 for objects near frame edges.
1546    // Any value > 2.0 is clearly pixel-space (even the smallest practical
1547    // model input of 32×32 would produce coordinates >> 2.0).
1548    const NORM_LIMIT: f32 = 2.0;
1549    if roi.xmin > NORM_LIMIT
1550        || roi.ymin > NORM_LIMIT
1551        || roi.xmax > NORM_LIMIT
1552        || roi.ymax > NORM_LIMIT
1553    {
1554        return Err(crate::DecoderError::InvalidShape(format!(
1555            "Bounding box coordinates appear un-normalized (pixel-space). \
1556             Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
1557             ONNX models output pixel-space boxes — normalize them by dividing by \
1558             the input dimensions before calling decode().",
1559            roi.xmin, roi.ymin, roi.xmax, roi.ymax,
1560        )));
1561    }
1562
1563    let roi = [
1564        (roi.xmin * width).clamp(0.0, width) as usize,
1565        (roi.ymin * height).clamp(0.0, height) as usize,
1566        (roi.xmax * width).clamp(0.0, width).ceil() as usize,
1567        (roi.ymax * height).clamp(0.0, height).ceil() as usize,
1568    ];
1569
1570    let roi_norm = [
1571        roi[0] as f32 / width,
1572        roi[1] as f32 / height,
1573        roi[2] as f32 / width,
1574        roi[3] as f32 / height,
1575    ]
1576    .into();
1577
1578    let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
1579
1580    Ok((cropped, roi_norm))
1581}
1582
1583/// Compute a single instance segmentation mask from mask coefficients and
1584/// proto maps (float path).
1585///
1586/// Computes `sigmoid(coefficients · protos)` and maps to `[0, 255]`.
1587/// Returns an `(H, W, 1)` u8 array.
1588fn make_segmentation<
1589    MASK: Float + AsPrimitive<f32> + Send + Sync,
1590    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1591>(
1592    mask: ArrayView1<MASK>,
1593    protos: ArrayView3<PROTO>,
1594) -> Array3<u8> {
1595    let shape = protos.shape();
1596
1597    // Safe to unwrap since the shapes will always be compatible
1598    let mask = mask.to_shape((1, mask.len())).unwrap();
1599    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1600    let protos = protos.reversed_axes();
1601    let mask = mask.map(|x| x.as_());
1602    let protos = protos.map(|x| x.as_());
1603
1604    // Safe to unwrap since the shapes will always be compatible
1605    let mask = mask
1606        .dot(&protos)
1607        .into_shape_with_order((shape[0], shape[1], 1))
1608        .unwrap();
1609
1610    mask.map(|x| {
1611        let sigmoid = 1.0 / (1.0 + (-*x).exp());
1612        (sigmoid * 255.0).round() as u8
1613    })
1614}
1615
1616/// Compute a single instance segmentation mask from quantized mask
1617/// coefficients and proto maps.
1618///
1619/// Dequantizes both inputs (subtracting zero-points), computes the dot
1620/// product, applies sigmoid, and maps to `[0, 255]`.
1621/// Returns an `(H, W, 1)` u8 array.
1622fn make_segmentation_quant<
1623    MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
1624    PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
1625    DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
1626>(
1627    mask: ArrayView1<MASK>,
1628    protos: ArrayView3<PROTO>,
1629    quant_masks: Quantization,
1630    quant_protos: Quantization,
1631) -> Array3<u8>
1632where
1633    i32: AsPrimitive<DEST>,
1634    f32: AsPrimitive<DEST>,
1635{
1636    let shape = protos.shape();
1637
1638    // Safe to unwrap since the shapes will always be compatible
1639    let mask = mask.to_shape((1, mask.len())).unwrap();
1640
1641    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1642    let protos = protos.reversed_axes();
1643
1644    let zp = quant_masks.zero_point.as_();
1645
1646    let mask = mask.mapv(|x| x.as_() - zp);
1647
1648    let zp = quant_protos.zero_point.as_();
1649    let protos = protos.mapv(|x| x.as_() - zp);
1650
1651    // Safe to unwrap since the shapes will always be compatible
1652    let segmentation = mask
1653        .dot(&protos)
1654        .into_shape_with_order((shape[0], shape[1], 1))
1655        .unwrap();
1656
1657    let combined_scale = quant_masks.scale * quant_protos.scale;
1658    segmentation.map(|x| {
1659        let val: f32 = (*x).as_() * combined_scale;
1660        let sigmoid = 1.0 / (1.0 + (-val).exp());
1661        (sigmoid * 255.0).round() as u8
1662    })
1663}
1664
1665/// Converts Yolo Instance Segmentation into a 2D mask.
1666///
1667/// The input segmentation is expected to have shape (H, W, 1).
1668///
1669/// The output mask will have shape (H, W), with values 0 or 1 based on the
1670/// threshold.
1671///
1672/// # Errors
1673///
1674/// Returns `DecoderError::InvalidShape` if the input segmentation does not
1675/// have shape (H, W, 1).
1676pub fn yolo_segmentation_to_mask(
1677    segmentation: ArrayView3<u8>,
1678    threshold: u8,
1679) -> Result<Array2<u8>, crate::DecoderError> {
1680    if segmentation.shape()[2] != 1 {
1681        return Err(crate::DecoderError::InvalidShape(format!(
1682            "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
1683            segmentation.shape()[2]
1684        )));
1685    }
1686    Ok(segmentation
1687        .slice(s![.., .., 0])
1688        .map(|x| if *x >= threshold { 1 } else { 0 }))
1689}
1690
1691#[cfg(test)]
1692#[cfg_attr(coverage_nightly, coverage(off))]
1693mod tests {
1694    use super::*;
1695    use ndarray::Array2;
1696
1697    // ========================================================================
1698    // Tests for decode_yolo_end_to_end_det_float
1699    // ========================================================================
1700
1701    #[test]
1702    fn test_end_to_end_det_basic_filtering() {
1703        // Create synthetic end-to-end detection output: (6, N) where rows are
1704        // [x1, y1, x2, y2, conf, class]
1705        // 3 detections: one above threshold, two below
1706        let data: Vec<f32> = vec![
1707            // Detection 0: high score (0.9)
1708            0.1, 0.2, 0.3, // x1 values
1709            0.1, 0.2, 0.3, // y1 values
1710            0.5, 0.6, 0.7, // x2 values
1711            0.5, 0.6, 0.7, // y2 values
1712            0.9, 0.1, 0.2, // confidence scores
1713            0.0, 1.0, 2.0, // class indices
1714        ];
1715        let output = Array2::from_shape_vec((6, 3), data).unwrap();
1716
1717        let mut boxes = Vec::with_capacity(10);
1718        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1719
1720        // Only 1 detection should pass threshold of 0.5
1721        assert_eq!(boxes.len(), 1);
1722        assert_eq!(boxes[0].label, 0);
1723        assert!((boxes[0].score - 0.9).abs() < 0.01);
1724        assert!((boxes[0].bbox.xmin - 0.1).abs() < 0.01);
1725        assert!((boxes[0].bbox.ymin - 0.1).abs() < 0.01);
1726        assert!((boxes[0].bbox.xmax - 0.5).abs() < 0.01);
1727        assert!((boxes[0].bbox.ymax - 0.5).abs() < 0.01);
1728    }
1729
1730    #[test]
1731    fn test_end_to_end_det_all_pass_threshold() {
1732        // All detections above threshold
1733        let data: Vec<f32> = vec![
1734            10.0, 20.0, // x1
1735            10.0, 20.0, // y1
1736            50.0, 60.0, // x2
1737            50.0, 60.0, // y2
1738            0.8, 0.7, // conf (both above 0.5)
1739            1.0, 2.0, // class
1740        ];
1741        let output = Array2::from_shape_vec((6, 2), data).unwrap();
1742
1743        let mut boxes = Vec::with_capacity(10);
1744        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1745
1746        assert_eq!(boxes.len(), 2);
1747        assert_eq!(boxes[0].label, 1);
1748        assert_eq!(boxes[1].label, 2);
1749    }
1750
1751    #[test]
1752    fn test_end_to_end_det_none_pass_threshold() {
1753        // All detections below threshold
1754        let data: Vec<f32> = vec![
1755            10.0, 20.0, // x1
1756            10.0, 20.0, // y1
1757            50.0, 60.0, // x2
1758            50.0, 60.0, // y2
1759            0.1, 0.2, // conf (both below 0.5)
1760            1.0, 2.0, // class
1761        ];
1762        let output = Array2::from_shape_vec((6, 2), data).unwrap();
1763
1764        let mut boxes = Vec::with_capacity(10);
1765        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1766
1767        assert_eq!(boxes.len(), 0);
1768    }
1769
1770    #[test]
1771    fn test_end_to_end_det_capacity_limit() {
1772        // Test that output is truncated to capacity
1773        let data: Vec<f32> = vec![
1774            0.1, 0.2, 0.3, 0.4, 0.5, // x1
1775            0.1, 0.2, 0.3, 0.4, 0.5, // y1
1776            0.5, 0.6, 0.7, 0.8, 0.9, // x2
1777            0.5, 0.6, 0.7, 0.8, 0.9, // y2
1778            0.9, 0.9, 0.9, 0.9, 0.9, // conf (all pass)
1779            0.0, 1.0, 2.0, 3.0, 4.0, // class
1780        ];
1781        let output = Array2::from_shape_vec((6, 5), data).unwrap();
1782
1783        let mut boxes = Vec::with_capacity(2); // Only allow 2 boxes
1784        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1785
1786        assert_eq!(boxes.len(), 2);
1787    }
1788
1789    #[test]
1790    fn test_end_to_end_det_empty_output() {
1791        // Test with zero detections
1792        let output = Array2::<f32>::zeros((6, 0));
1793
1794        let mut boxes = Vec::with_capacity(10);
1795        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1796
1797        assert_eq!(boxes.len(), 0);
1798    }
1799
#[test]
fn test_end_to_end_det_pixel_coordinates() {
    // Single candidate whose corners are given in pixels rather than in
    // the 0..1 range; the decoded box is expected to carry them unchanged.
    let column: Vec<f32> = vec![
        100.0, 200.0, // x1, y1
        300.0, 400.0, // x2, y2
        0.95,  // conf
        5.0,   // class
    ];
    let output = Array2::from_shape_vec((6, 1), column).unwrap();

    let mut decoded = Vec::with_capacity(10);
    decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut decoded).unwrap();

    assert_eq!(decoded.len(), 1);

    let found = &decoded[0];
    assert_eq!(found.label, 5);

    let bbox = &found.bbox;
    assert!((bbox.xmin - 100.0).abs() < 0.01);
    assert!((bbox.ymin - 200.0).abs() < 0.01);
    assert!((bbox.xmax - 300.0).abs() < 0.01);
    assert!((bbox.ymax - 400.0).abs() < 0.01);
}
1823
#[test]
fn test_end_to_end_det_invalid_shape() {
    // Five rows is one short of the six the decoder is expected to require.
    let too_few_rows = Array2::<f32>::zeros((5, 3));

    let mut decoded = Vec::with_capacity(10);
    let outcome = decode_yolo_end_to_end_det_float(too_few_rows.view(), 0.5, &mut decoded);

    assert!(outcome.is_err());

    // The error must be the shape variant and mention the row requirement.
    let is_row_count_error = matches!(
        outcome,
        Err(crate::DecoderError::InvalidShape(msg)) if msg.contains("at least 6 rows")
    );
    assert!(is_row_count_error);
}
1838
1839    // ========================================================================
1840    // Tests for decode_yolo_end_to_end_segdet_float
1841    // ========================================================================
1842
#[test]
fn test_end_to_end_segdet_basic() {
    // Synthetic segdet output shaped (6 + num_protos, N); each column is
    // [x1, y1, x2, y2, conf, class, mask_coeff_0..31].
    let num_protos = 32;
    let num_detections = 2;
    let num_features = 6 + num_protos;

    // The six detection rows, one value per candidate.
    let head: [[f32; 2]; 6] = [
        [0.1, 0.5], // x1
        [0.1, 0.5], // y1
        [0.4, 0.9], // x2
        [0.4, 0.9], // y2
        [0.9, 0.3], // conf: candidate 0 passes, candidate 1 fails
        [1.0, 2.0], // class
    ];
    // Every mask coefficient of both candidates is a small constant.
    let mut data = head.concat();
    data.resize(num_features * num_detections, 0.1);

    let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();

    // Prototype tensor: (proto_height, proto_width, num_protos).
    let protos = Array3::<f32>::zeros((16, 16, num_protos));

    let mut boxes = Vec::with_capacity(10);
    let mut masks = Vec::with_capacity(10);
    decode_yolo_end_to_end_segdet_float(
        output.view(),
        protos.view(),
        0.5,
        &mut boxes,
        &mut masks,
    )
    .unwrap();

    // Only candidate 0 clears the threshold, so one box and one mask remain.
    assert_eq!(boxes.len(), 1);
    assert_eq!(masks.len(), 1);

    let survivor = &boxes[0];
    assert_eq!(survivor.label, 1);
    assert!((survivor.score - 0.9).abs() < 0.01);
}
1894
#[test]
fn test_end_to_end_segdet_mask_coordinates() {
    // The mask emitted for a detection should carry the same corner
    // coordinates as the detection's bounding box.
    let num_protos = 32;
    let num_features = 6 + num_protos;

    // One candidate: x1, y1, x2, y2, conf, class, then all-zero mask
    // coefficients padding the column out to `num_features`.
    let mut column = vec![0.2f32, 0.2, 0.8, 0.8, 0.95, 3.0];
    column.resize(num_features, 0.0);

    let output = Array2::from_shape_vec((num_features, 1), column).unwrap();
    let protos = Array3::<f32>::zeros((16, 16, num_protos));

    let mut boxes = Vec::with_capacity(10);
    let mut masks = Vec::with_capacity(10);
    decode_yolo_end_to_end_segdet_float(
        output.view(),
        protos.view(),
        0.5,
        &mut boxes,
        &mut masks,
    )
    .unwrap();

    assert_eq!(boxes.len(), 1);
    assert_eq!(masks.len(), 1);

    // Compare the mask extent with the box extent, edge by edge.
    let (mask, bbox) = (&masks[0], &boxes[0].bbox);
    assert!((mask.xmin - bbox.xmin).abs() < 0.01);
    assert!((mask.ymin - bbox.ymin).abs() < 0.01);
    assert!((mask.xmax - bbox.xmax).abs() < 0.01);
    assert!((mask.ymax - bbox.ymax).abs() < 0.01);
}
1932
#[test]
fn test_end_to_end_segdet_empty_output() {
    // Zero candidate columns: decoding must succeed and emit nothing.
    let num_protos = 32;
    let no_candidates = Array2::<f32>::zeros((6 + num_protos, 0));
    let protos = Array3::<f32>::zeros((16, 16, num_protos));

    let mut boxes = Vec::with_capacity(10);
    let mut masks = Vec::with_capacity(10);
    decode_yolo_end_to_end_segdet_float(
        no_candidates.view(),
        protos.view(),
        0.5,
        &mut boxes,
        &mut masks,
    )
    .unwrap();

    assert!(boxes.is_empty());
    assert!(masks.is_empty());
}
1953
#[test]
fn test_end_to_end_segdet_capacity_limit() {
    // Five passing candidates but room reserved for two: the test expects
    // both output vectors to stop at the reserved capacity.
    let num_protos = 32;
    let num_detections = 5;
    let num_features = 6 + num_protos;

    let mut data = vec![0.0f32; num_features * num_detections];
    for det in 0..num_detections {
        // Each candidate is a 0.2-wide square offset by 0.1 per index.
        let origin = 0.1 * (det as f32);
        let column = [
            origin,       // x1
            origin,       // y1
            origin + 0.2, // x2
            origin + 0.2, // y2
            0.9,          // conf (passes the 0.5 threshold)
            det as f32,   // class
        ];
        for (feature, value) in column.iter().enumerate() {
            data[feature * num_detections + det] = *value;
        }
    }

    let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
    let protos = Array3::<f32>::zeros((16, 16, num_protos));

    let limit = 2;
    let mut boxes = Vec::with_capacity(limit);
    let mut masks = Vec::with_capacity(limit);
    decode_yolo_end_to_end_segdet_float(
        output.view(),
        protos.view(),
        0.5,
        &mut boxes,
        &mut masks,
    )
    .unwrap();

    assert_eq!(boxes.len(), limit);
    assert_eq!(masks.len(), limit);
}
1988
#[test]
fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
    // Six rows leaves no room for a mask coefficient; the decoder is
    // expected to demand at least seven (6 base + 1 mask coeff).
    let no_mask_rows = Array2::<f32>::zeros((6, 3));
    let protos = Array3::<f32>::zeros((16, 16, 32));

    let mut boxes = Vec::with_capacity(10);
    let mut masks = Vec::with_capacity(10);
    let outcome = decode_yolo_end_to_end_segdet_float(
        no_mask_rows.view(),
        protos.view(),
        0.5,
        &mut boxes,
        &mut masks,
    );

    assert!(outcome.is_err());

    let is_row_count_error = matches!(
        outcome,
        Err(crate::DecoderError::InvalidShape(msg)) if msg.contains("at least 7 rows")
    );
    assert!(is_row_count_error);
}
2011
#[test]
fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
    // The detection tensor carries 16 mask coefficients per candidate
    // while the prototype tensor has 32 channels; decoding must refuse.
    let num_protos = 32;
    let num_mask_coeffs = 16;
    let output = Array2::<f32>::zeros((6 + num_mask_coeffs, 3));
    let protos = Array3::<f32>::zeros((16, 16, num_protos));

    let mut boxes = Vec::with_capacity(10);
    let mut masks = Vec::with_capacity(10);
    let outcome = decode_yolo_end_to_end_segdet_float(
        output.view(),
        protos.view(),
        0.5,
        &mut boxes,
        &mut masks,
    );

    assert!(outcome.is_err());

    let is_mismatch_error = matches!(
        outcome,
        Err(crate::DecoderError::InvalidShape(msg)) if msg.contains("doesn't match protos count")
    );
    assert!(is_mismatch_error);
}
2035
2036    // ========================================================================
2037    // Tests for decode_yolo_split_end_to_end_segdet_float
2038    // ========================================================================
2039
#[test]
fn test_split_end_to_end_segdet_basic() {
    // Build the combined (6 + num_protos, N) tensor first, then hand its
    // row groups to the split decoder as separate views.
    let num_protos = 32;
    let num_detections = 2;
    let num_features = 6 + num_protos;

    // The six detection rows, one value per candidate.
    let head: [[f32; 2]; 6] = [
        [0.1, 0.5], // x1
        [0.1, 0.5], // y1
        [0.4, 0.9], // x2
        [0.4, 0.9], // y2
        [0.9, 0.3], // conf: candidate 0 passes, candidate 1 fails
        [1.0, 2.0], // class
    ];
    // Every mask coefficient of both candidates is a small constant.
    let mut data = head.concat();
    data.resize(num_features * num_detections, 0.1);

    let combined = Array2::from_shape_vec((num_features, num_detections), data).unwrap();

    // Rows 0..4 are box corners, row 4 the score, row 5 the class and the
    // remainder the mask coefficients.
    let box_coords = combined.slice(s![..4, ..]);
    let scores = combined.slice(s![4..5, ..]);
    let classes = combined.slice(s![5..6, ..]);
    let mask_coeff = combined.slice(s![6.., ..]);

    // Prototype tensor: (proto_height, proto_width, num_protos).
    let protos = Array3::<f32>::zeros((16, 16, num_protos));

    let mut boxes = Vec::with_capacity(10);
    let mut masks = Vec::with_capacity(10);
    decode_yolo_split_end_to_end_segdet_float(
        box_coords,
        scores,
        classes,
        mask_coeff,
        protos.view(),
        0.5,
        &mut boxes,
        &mut masks,
    )
    .unwrap();

    // Only candidate 0 clears the threshold, so one box and one mask remain.
    assert_eq!(boxes.len(), 1);
    assert_eq!(masks.len(), 1);

    let survivor = &boxes[0];
    assert_eq!(survivor.label, 1);
    assert!((survivor.score - 0.9).abs() < 0.01);
}
2097
2098    // ========================================================================
2099    // Tests for yolo_segmentation_to_mask
2100    // ========================================================================
2101
#[test]
fn test_segmentation_to_mask_basic() {
    // 4x4 single-channel input with values on both sides of the threshold.
    let pixels: [[u8; 4]; 4] = [
        [100, 200, 50, 150],
        [10, 255, 128, 64],
        [0, 127, 128, 255],
        [64, 64, 192, 192],
    ];
    let segmentation = Array3::from_shape_vec((4, 4, 1), pixels.concat()).unwrap();

    let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();

    // (row, col) positions paired with the expected binary value: inputs
    // at or above 128 become 1, everything else 0.
    let expectations = [
        ([0usize, 0], 0), // 100 < 128
        ([0, 1], 1),      // 200 >= 128
        ([0, 2], 0),      // 50 < 128
        ([0, 3], 1),      // 150 >= 128
        ([1, 1], 1),      // 255 >= 128
        ([1, 2], 1),      // 128 >= 128
        ([2, 0], 0),      // 0 < 128
        ([2, 1], 0),      // 127 < 128
    ];
    for (position, expected) in expectations {
        assert_eq!(mask[position], expected);
    }
}
2125
#[test]
fn test_segmentation_to_mask_all_above() {
    // Every input pixel (255) exceeds the threshold, so no output pixel
    // may differ from 1.
    let saturated = Array3::<u8>::from_elem((4, 4, 1), 255);
    let mask = yolo_segmentation_to_mask(saturated.view(), 128).unwrap();
    assert!(!mask.iter().any(|&pixel| pixel != 1));
}
2132
#[test]
fn test_segmentation_to_mask_all_below() {
    // Every input pixel (64) is under the threshold, so no output pixel
    // may differ from 0.
    let dim = Array3::<u8>::from_elem((4, 4, 1), 64);
    let mask = yolo_segmentation_to_mask(dim.view(), 128).unwrap();
    assert!(!mask.iter().any(|&pixel| pixel != 0));
}
2139
#[test]
fn test_segmentation_to_mask_invalid_shape() {
    // Three channels instead of one: the conversion is expected to fail
    // with a shape error that names the (H, W, 1) layout.
    let three_channel = Array3::<u8>::from_elem((4, 4, 3), 128);
    let outcome = yolo_segmentation_to_mask(three_channel.view(), 128);

    assert!(outcome.is_err());

    let is_layout_error = matches!(
        outcome,
        Err(crate::DecoderError::InvalidShape(msg)) if msg.contains("(H, W, 1)")
    );
    assert!(is_layout_error);
}
2151
2152    // ========================================================================
2153    // Tests for protobox / NORM_LIMIT regression
2154    // ========================================================================
2155
#[test]
fn test_protobox_clamps_edge_coordinates() {
    // Regression guard: a box whose far edges sit exactly at 1.0 must be
    // accepted rather than indexing past the prototype map.
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let protos_view = protos.view();
    let edge_box = BoundingBox {
        xmin: 0.5,
        ymin: 0.5,
        xmax: 1.0,
        ymax: 1.0,
    };

    let outcome = protobox(&protos_view, &edge_box);
    assert!(outcome.is_ok(), "protobox should accept xmax=1.0");

    let (crop, _) = outcome.unwrap();
    let dims = crop.shape();
    // The crop must keep a non-empty spatial extent and all four channels.
    assert!(dims[0] > 0);
    assert!(dims[1] > 0);
    assert_eq!(dims[2], 4);
}
2175
#[test]
fn test_protobox_rejects_wildly_out_of_range() {
    // Far edges at 3.0 are treated as un-normalized input: the test
    // expects a shape error whose message says so.
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let protos_view = protos.view();
    let oversized_box = BoundingBox {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 3.0,
        ymax: 3.0,
    };

    let outcome = protobox(&protos_view, &oversized_box);
    let flagged_unnormalized = matches!(
        outcome,
        Err(crate::DecoderError::InvalidShape(msg)) if msg.contains("un-normalized")
    );
    assert!(
        flagged_unnormalized,
        "protobox should reject coords > NORM_LIMIT"
    );
}
2193
#[test]
fn test_protobox_accepts_slightly_over_one() {
    // Far edges at 1.5 overshoot the unit range only mildly; the test
    // expects them to be tolerated and clamped to the map boundary.
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let protos_view = protos.view();
    let overshooting_box = BoundingBox {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 1.5,
        ymax: 1.5,
    };

    let outcome = protobox(&protos_view, &overshooting_box);
    assert!(
        outcome.is_ok(),
        "protobox should accept coords <= NORM_LIMIT (2.0)"
    );

    // With the origin at 0 and the far edges clamped, the crop should
    // cover the full 16x16 prototype map.
    let (crop, _) = outcome.unwrap();
    let dims = crop.shape();
    assert_eq!(dims[0], 16);
    assert_eq!(dims[1], 16);
}
2215
#[test]
fn test_segdet_float_proto_no_panic() {
    // Mimics a YOLOv8n-seg style head: each proposal column holds 4 box
    // values, 80 class scores and 32 mask coefficients (116 rows). The
    // proposal count is deliberately larger than the coefficient count so
    // that detection indices can reach 32 and beyond.
    let num_proposals = 100;
    let num_classes = 80;
    let num_mask_coeffs = 32;
    let num_rows = 4 + num_classes + num_mask_coeffs;

    // Row-major [num_rows, num_proposals]: rows 0..4 are cx, cy, w, h and
    // row 4 is the class-0 score. Every proposal gets the same valid xywh
    // box and a passing score; all remaining rows stay zero.
    let leading_rows = [320.0f32, 320.0, 50.0, 50.0, 0.9];
    let mut data = vec![0.0f32; num_rows * num_proposals];
    for (row, value) in data.chunks_exact_mut(num_proposals).zip(leading_rows) {
        row.fill(value);
    }
    let predictions = Array2::from_shape_vec((num_rows, num_proposals), data).unwrap();

    // Prototypes are supplied channel-last (H, W, C); per the original
    // note, decoder.rs `protos_to_hwc` converts before reaching this code.
    let protos = Array3::<f32>::zeros((160, 160, num_mask_coeffs));

    let mut detections = Vec::with_capacity(300);

    // Previously panicked in mask_tensor.row(idx) once idx reached 32.
    let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
        predictions.view(),
        protos.view(),
        0.5,
        0.7,
        Some(Nms::default()),
        &mut detections,
    );

    // At least one detection must come out, each with its own coefficient
    // vector of the full length.
    assert!(!detections.is_empty());
    assert_eq!(proto_data.mask_coefficients.len(), detections.len());
    for coefficients in &proto_data.mask_coefficients {
        assert_eq!(coefficients.len(), num_mask_coeffs);
    }
}
2263}