Skip to main content

edgefirst_decoder/
yolo.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::fmt::Debug;
5
6use ndarray::{
7    parallel::prelude::{IntoParallelIterator, ParallelIterator},
8    s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
9};
10use num_traits::{AsPrimitive, Float, PrimInt, Signed};
11
12use crate::{
13    byte::{
14        nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
15        postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
16    },
17    configs::Nms,
18    dequant_detect_box,
19    float::{
20        nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
21        postprocess_boxes_float, postprocess_boxes_index_float,
22    },
23    BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoTensor,
24    Quantization, Segmentation, XYWH, XYXY,
25};
26
27/// Dispatches to the appropriate NMS function based on mode for float boxes.
28fn dispatch_nms_float(nms: Option<Nms>, iou: f32, boxes: Vec<DetectBox>) -> Vec<DetectBox> {
29    match nms {
30        Some(Nms::ClassAgnostic) => nms_float(iou, boxes),
31        Some(Nms::ClassAware) => nms_class_aware_float(iou, boxes),
32        None => boxes, // bypass NMS
33    }
34}
35
36/// Dispatches to the appropriate NMS function based on mode for float boxes
37/// with extra data.
38pub(super) fn dispatch_nms_extra_float<E: Send + Sync>(
39    nms: Option<Nms>,
40    iou: f32,
41    boxes: Vec<(DetectBox, E)>,
42) -> Vec<(DetectBox, E)> {
43    match nms {
44        Some(Nms::ClassAgnostic) => nms_extra_float(iou, boxes),
45        Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, boxes),
46        None => boxes, // bypass NMS
47    }
48}
49
50/// Dispatches to the appropriate NMS function based on mode for quantized
51/// boxes.
52fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
53    nms: Option<Nms>,
54    iou: f32,
55    boxes: Vec<DetectBoxQuantized<SCORE>>,
56) -> Vec<DetectBoxQuantized<SCORE>> {
57    match nms {
58        Some(Nms::ClassAgnostic) => nms_int(iou, boxes),
59        Some(Nms::ClassAware) => nms_class_aware_int(iou, boxes),
60        None => boxes, // bypass NMS
61    }
62}
63
64/// Dispatches to the appropriate NMS function based on mode for quantized boxes
65/// with extra data.
66fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
67    nms: Option<Nms>,
68    iou: f32,
69    boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
70) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
71    match nms {
72        Some(Nms::ClassAgnostic) => nms_extra_int(iou, boxes),
73        Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, boxes),
74        None => boxes, // bypass NMS
75    }
76}
77
78/// Decodes YOLO detection outputs from quantized tensors into detection boxes.
79///
80/// Boxes are expected to be in XYWH format.
81///
82/// Expected shapes of inputs:
83/// - output: (4 + num_classes, num_boxes)
84pub fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
85    output: (ArrayView2<BOX>, Quantization),
86    score_threshold: f32,
87    iou_threshold: f32,
88    nms: Option<Nms>,
89    output_boxes: &mut Vec<DetectBox>,
90) where
91    f32: AsPrimitive<BOX>,
92{
93    impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
94}
95
96/// Decodes YOLO detection outputs from float tensors into detection boxes.
97///
98/// Boxes are expected to be in XYWH format.
99///
100/// Expected shapes of inputs:
101/// - output: (4 + num_classes, num_boxes)
102pub fn decode_yolo_det_float<T>(
103    output: ArrayView2<T>,
104    score_threshold: f32,
105    iou_threshold: f32,
106    nms: Option<Nms>,
107    output_boxes: &mut Vec<DetectBox>,
108) where
109    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
110    f32: AsPrimitive<T>,
111{
112    impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
113}
114
115/// Decodes YOLO detection and segmentation outputs from quantized tensors into
116/// detection boxes and segmentation masks.
117///
118/// Boxes are expected to be in XYWH format.
119///
120/// Expected shapes of inputs:
121/// - boxes: (4 + num_classes + num_protos, num_boxes)
122/// - protos: (proto_height, proto_width, num_protos)
123///
124/// # Errors
125/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
126pub fn decode_yolo_segdet_quant<
127    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
128    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
129>(
130    boxes: (ArrayView2<BOX>, Quantization),
131    protos: (ArrayView3<PROTO>, Quantization),
132    score_threshold: f32,
133    iou_threshold: f32,
134    nms: Option<Nms>,
135    output_boxes: &mut Vec<DetectBox>,
136    output_masks: &mut Vec<Segmentation>,
137) -> Result<(), crate::DecoderError>
138where
139    f32: AsPrimitive<BOX>,
140{
141    impl_yolo_segdet_quant::<XYWH, _, _>(
142        boxes,
143        protos,
144        score_threshold,
145        iou_threshold,
146        nms,
147        output_boxes,
148        output_masks,
149    )
150}
151
152/// Decodes YOLO detection and segmentation outputs from float tensors into
153/// detection boxes and segmentation masks.
154///
155/// Boxes are expected to be in XYWH format.
156///
157/// Expected shapes of inputs:
158/// - boxes: (4 + num_classes + num_protos, num_boxes)
159/// - protos: (proto_height, proto_width, num_protos)
160///
161/// # Errors
162/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
163pub fn decode_yolo_segdet_float<T>(
164    boxes: ArrayView2<T>,
165    protos: ArrayView3<T>,
166    score_threshold: f32,
167    iou_threshold: f32,
168    nms: Option<Nms>,
169    output_boxes: &mut Vec<DetectBox>,
170    output_masks: &mut Vec<Segmentation>,
171) -> Result<(), crate::DecoderError>
172where
173    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
174    f32: AsPrimitive<T>,
175{
176    impl_yolo_segdet_float::<XYWH, _, _>(
177        boxes,
178        protos,
179        score_threshold,
180        iou_threshold,
181        nms,
182        output_boxes,
183        output_masks,
184    )
185}
186
187/// Decodes YOLO split detection outputs from quantized tensors into detection
188/// boxes.
189///
190/// Boxes are expected to be in XYWH format.
191///
192/// Expected shapes of inputs:
193/// - boxes: (4, num_boxes)
194/// - scores: (num_classes, num_boxes)
195///
196/// # Panics
197/// Panics if shapes don't match the expected dimensions.
198pub fn decode_yolo_split_det_quant<
199    BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
200    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
201>(
202    boxes: (ArrayView2<BOX>, Quantization),
203    scores: (ArrayView2<SCORE>, Quantization),
204    score_threshold: f32,
205    iou_threshold: f32,
206    nms: Option<Nms>,
207    output_boxes: &mut Vec<DetectBox>,
208) where
209    f32: AsPrimitive<SCORE>,
210{
211    impl_yolo_split_quant::<XYWH, _, _>(
212        boxes,
213        scores,
214        score_threshold,
215        iou_threshold,
216        nms,
217        output_boxes,
218    );
219}
220
221/// Decodes YOLO split detection outputs from float tensors into detection
222/// boxes.
223///
224/// Boxes are expected to be in XYWH format.
225///
226/// Expected shapes of inputs:
227/// - boxes: (4, num_boxes)
228/// - scores: (num_classes, num_boxes)
229///
230/// # Panics
231/// Panics if shapes don't match the expected dimensions.
232pub fn decode_yolo_split_det_float<T>(
233    boxes: ArrayView2<T>,
234    scores: ArrayView2<T>,
235    score_threshold: f32,
236    iou_threshold: f32,
237    nms: Option<Nms>,
238    output_boxes: &mut Vec<DetectBox>,
239) where
240    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
241    f32: AsPrimitive<T>,
242{
243    impl_yolo_split_float::<XYWH, _, _>(
244        boxes,
245        scores,
246        score_threshold,
247        iou_threshold,
248        nms,
249        output_boxes,
250    );
251}
252
253/// Decodes YOLO split detection segmentation outputs from quantized tensors
254/// into detection boxes and segmentation masks.
255///
256/// Boxes are expected to be in XYWH format.
257///
258/// Expected shapes of inputs:
259/// - boxes_tensor: (4, num_boxes)
260/// - scores_tensor: (num_classes, num_boxes)
261/// - mask_tensor: (num_protos, num_boxes)
262/// - protos: (proto_height, proto_width, num_protos)
263///
264/// # Errors
265/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
266#[allow(clippy::too_many_arguments)]
267pub fn decode_yolo_split_segdet<
268    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
269    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
270    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
271    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
272>(
273    boxes: (ArrayView2<BOX>, Quantization),
274    scores: (ArrayView2<SCORE>, Quantization),
275    mask_coeff: (ArrayView2<MASK>, Quantization),
276    protos: (ArrayView3<PROTO>, Quantization),
277    score_threshold: f32,
278    iou_threshold: f32,
279    nms: Option<Nms>,
280    output_boxes: &mut Vec<DetectBox>,
281    output_masks: &mut Vec<Segmentation>,
282) -> Result<(), crate::DecoderError>
283where
284    f32: AsPrimitive<SCORE>,
285{
286    impl_yolo_split_segdet_quant::<XYWH, _, _, _, _>(
287        boxes,
288        scores,
289        mask_coeff,
290        protos,
291        score_threshold,
292        iou_threshold,
293        nms,
294        output_boxes,
295        output_masks,
296    )
297}
298
299/// Decodes YOLO split detection segmentation outputs from float tensors
300/// into detection boxes and segmentation masks.
301///
302/// Boxes are expected to be in XYWH format.
303///
304/// Expected shapes of inputs:
305/// - boxes_tensor: (4, num_boxes)
306/// - scores_tensor: (num_classes, num_boxes)
307/// - mask_tensor: (num_protos, num_boxes)
308/// - protos: (proto_height, proto_width, num_protos)
309///
310/// # Errors
311/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
312#[allow(clippy::too_many_arguments)]
313pub fn decode_yolo_split_segdet_float<T>(
314    boxes: ArrayView2<T>,
315    scores: ArrayView2<T>,
316    mask_coeff: ArrayView2<T>,
317    protos: ArrayView3<T>,
318    score_threshold: f32,
319    iou_threshold: f32,
320    nms: Option<Nms>,
321    output_boxes: &mut Vec<DetectBox>,
322    output_masks: &mut Vec<Segmentation>,
323) -> Result<(), crate::DecoderError>
324where
325    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
326    f32: AsPrimitive<T>,
327{
328    impl_yolo_split_segdet_float::<XYWH, _, _, _, _>(
329        boxes,
330        scores,
331        mask_coeff,
332        protos,
333        score_threshold,
334        iou_threshold,
335        nms,
336        output_boxes,
337        output_masks,
338    )
339}
340
341/// Decodes end-to-end YOLO detection outputs (post-NMS from model).
342/// Expects an array of shape `(6, N)`, where the first dimension (rows)
343/// corresponds to the 6 per-detection features
344/// `[x1, y1, x2, y2, conf, class]` and the second dimension (columns)
345/// indexes the `N` detections.
346/// Boxes are output directly without NMS (the model already applied NMS).
347///
348/// Coordinates may be normalized `[0, 1]` or absolute pixel values depending
349/// on the model configuration. The caller should check
350/// `decoder.normalized_boxes()` to determine which.
351///
352/// # Errors
353///
354/// Returns `DecoderError::InvalidShape` if `output` has fewer than 6 rows.
355pub fn decode_yolo_end_to_end_det_float<T>(
356    output: ArrayView2<T>,
357    score_threshold: f32,
358    output_boxes: &mut Vec<DetectBox>,
359) -> Result<(), crate::DecoderError>
360where
361    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
362    f32: AsPrimitive<T>,
363{
364    // Validate input shape: need at least 6 rows (x1, y1, x2, y2, conf, class)
365    if output.shape()[0] < 6 {
366        return Err(crate::DecoderError::InvalidShape(format!(
367            "End-to-end detection output requires at least 6 rows, got {}",
368            output.shape()[0]
369        )));
370    }
371
372    // Input shape: (6, N) -> transpose to (N, 4) for boxes and (N, 1) for scores
373    let boxes = output.slice(s![0..4, ..]).reversed_axes();
374    let scores = output.slice(s![4..5, ..]).reversed_axes();
375    let classes = output.slice(s![5, ..]);
376    let mut boxes =
377        postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
378    boxes.truncate(output_boxes.capacity());
379    output_boxes.clear();
380    for (mut b, i) in boxes.into_iter() {
381        b.label = classes[i].as_() as usize;
382        output_boxes.push(b);
383    }
384    // No NMS — model output is already post-NMS
385    Ok(())
386}
387
388/// Decodes end-to-end YOLO detection + segmentation outputs (post-NMS from
389/// model).
390///
391/// Input shapes:
392/// - detection: (6 + num_protos, N) where rows are [x1, y1, x2, y2, conf,
393///   class, mask_coeff_0, ..., mask_coeff_31]
394/// - protos: (proto_height, proto_width, num_protos)
395///
396/// Boxes are output directly without NMS (model already applied NMS).
397/// Coordinates may be normalized [0,1] or pixel values depending on model
398/// config.
399///
400/// # Errors
401///
402/// Returns `DecoderError::InvalidShape` if:
403/// - output has fewer than 7 rows (6 base + at least 1 mask coefficient)
404/// - protos shape doesn't match mask coefficients count
405pub fn decode_yolo_end_to_end_segdet_float<T>(
406    output: ArrayView2<T>,
407    protos: ArrayView3<T>,
408    score_threshold: f32,
409    output_boxes: &mut Vec<DetectBox>,
410    output_masks: &mut Vec<crate::Segmentation>,
411) -> Result<(), crate::DecoderError>
412where
413    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
414    f32: AsPrimitive<T>,
415{
416    let (boxes, scores, classes, mask_coeff) =
417        postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
418    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
419        boxes,
420        scores,
421        classes,
422        score_threshold,
423        output_boxes.capacity(),
424    );
425
426    // No NMS — model output is already post-NMS
427
428    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
429}
430
431/// Decodes split end-to-end YOLO detection outputs (post-NMS from model).
432///
433/// Input shapes (after batch dim removed):
434/// - boxes: (N, 4) — xyxy pixel coordinates
435/// - scores: (N, 1) — confidence of the top class
436/// - classes: (N, 1) — class index of the top class
437///
438/// Boxes are output directly without NMS (model already applied NMS).
439pub fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
440    boxes: ArrayView2<T>,
441    scores: ArrayView2<T>,
442    classes: ArrayView2<T>,
443    score_threshold: f32,
444    output_boxes: &mut Vec<DetectBox>,
445) -> Result<(), crate::DecoderError> {
446    let n = boxes.shape()[0];
447    if boxes.shape()[1] != 4 {
448        return Err(crate::DecoderError::InvalidShape(format!(
449            "Split end-to-end boxes must have 4 columns, got {}",
450            boxes.shape()[1]
451        )));
452    }
453    output_boxes.clear();
454    for i in 0..n {
455        let score: f32 = scores[[i, 0]].as_();
456        if score < score_threshold {
457            continue;
458        }
459        if output_boxes.len() >= output_boxes.capacity() {
460            break;
461        }
462        output_boxes.push(DetectBox {
463            bbox: BoundingBox {
464                xmin: boxes[[i, 0]].as_(),
465                ymin: boxes[[i, 1]].as_(),
466                xmax: boxes[[i, 2]].as_(),
467                ymax: boxes[[i, 3]].as_(),
468            },
469            score,
470            label: classes[[i, 0]].as_() as usize,
471        });
472    }
473    Ok(())
474}
475
476/// Decodes split end-to-end YOLO detection + segmentation outputs.
477///
478/// Input shapes (after batch dim removed):
479/// - boxes: (4, N) — xyxy pixel coordinates
480/// - scores: (1, N) — confidence
481/// - classes: (1, N) — class index
482/// - mask_coeff: (num_protos, N) — mask coefficients per detection
483/// - protos: (proto_h, proto_w, num_protos) — prototype masks
484#[allow(clippy::too_many_arguments)]
485pub fn decode_yolo_split_end_to_end_segdet_float<T>(
486    boxes: ArrayView2<T>,
487    scores: ArrayView2<T>,
488    classes: ArrayView2<T>,
489    mask_coeff: ArrayView2<T>,
490    protos: ArrayView3<T>,
491    score_threshold: f32,
492    output_boxes: &mut Vec<DetectBox>,
493    output_masks: &mut Vec<crate::Segmentation>,
494) -> Result<(), crate::DecoderError>
495where
496    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
497    f32: AsPrimitive<T>,
498{
499    let (boxes, scores, classes, mask_coeff) =
500        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
501    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
502        boxes,
503        scores,
504        classes,
505        score_threshold,
506        output_boxes.capacity(),
507    );
508
509    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
510}
511
512#[allow(clippy::type_complexity)]
513pub(crate) fn postprocess_yolo_end_to_end_segdet<'a, T>(
514    output: &'a ArrayView2<'_, T>,
515    num_protos: usize,
516) -> Result<
517    (
518        ArrayView2<'a, T>,
519        ArrayView2<'a, T>,
520        ArrayView1<'a, T>,
521        ArrayView2<'a, T>,
522    ),
523    crate::DecoderError,
524> {
525    // Validate input shape: need at least 7 rows (6 base + at least 1 mask coeff)
526    if output.shape()[0] < 7 {
527        return Err(crate::DecoderError::InvalidShape(format!(
528            "End-to-end segdet output requires at least 7 rows, got {}",
529            output.shape()[0]
530        )));
531    }
532
533    let num_mask_coeffs = output.shape()[0] - 6;
534    if num_mask_coeffs != num_protos {
535        return Err(crate::DecoderError::InvalidShape(format!(
536            "Mask coefficients count ({}) doesn't match protos count ({})",
537            num_mask_coeffs, num_protos
538        )));
539    }
540
541    // Input shape: (6+num_protos, N) -> transpose for postprocessing
542    let boxes = output.slice(s![0..4, ..]).reversed_axes();
543    let scores = output.slice(s![4..5, ..]).reversed_axes();
544    let classes = output.slice(s![5, ..]);
545    let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
546    Ok((boxes, scores, classes, mask_coeff))
547}
548
549#[allow(clippy::type_complexity)]
550pub(crate) fn postprocess_yolo_split_end_to_end_segdet<
551    'a,
552    'b,
553    'c,
554    'd,
555    BOXES,
556    SCORES,
557    CLASS,
558    MASK,
559>(
560    boxes: ArrayView2<'a, BOXES>,
561    scores: ArrayView2<'b, SCORES>,
562    classes: &'c ArrayView2<CLASS>,
563    mask_coeff: ArrayView2<'d, MASK>,
564) -> Result<
565    (
566        ArrayView2<'a, BOXES>,
567        ArrayView2<'b, SCORES>,
568        ArrayView1<'c, CLASS>,
569        ArrayView2<'d, MASK>,
570    ),
571    crate::DecoderError,
572> {
573    if boxes.shape()[0] != 4 {
574        return Err(crate::DecoderError::InvalidShape(format!(
575            "Split end-to-end boxes must have 4 columns, got {}",
576            boxes.shape()[0]
577        )));
578    }
579    let boxes = boxes.reversed_axes();
580    let scores = scores.reversed_axes();
581    let classes = classes.slice(s![0, ..]);
582    let mask_coeff = mask_coeff.reversed_axes();
583    Ok((boxes, scores, classes, mask_coeff))
584}
585/// Internal implementation of YOLO decoding for quantized tensors.
586///
587/// Expected shapes of inputs:
588/// - output: (4 + num_classes, num_boxes)
589pub(crate) fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
590    output: (ArrayView2<T>, Quantization),
591    score_threshold: f32,
592    iou_threshold: f32,
593    nms: Option<Nms>,
594    output_boxes: &mut Vec<DetectBox>,
595) where
596    f32: AsPrimitive<T>,
597{
598    let (boxes, quant_boxes) = output;
599    let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
600
601    let boxes = {
602        let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
603        postprocess_boxes_quant::<B, _, _>(
604            score_threshold,
605            boxes_tensor,
606            scores_tensor,
607            quant_boxes,
608        )
609    };
610
611    let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
612    let len = output_boxes.capacity().min(boxes.len());
613    output_boxes.clear();
614    for b in boxes.iter().take(len) {
615        output_boxes.push(dequant_detect_box(b, quant_boxes));
616    }
617}
618
619/// Internal implementation of YOLO decoding for float tensors.
620///
621/// Expected shapes of inputs:
622/// - output: (4 + num_classes, num_boxes)
623pub(crate) fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
624    output: ArrayView2<T>,
625    score_threshold: f32,
626    iou_threshold: f32,
627    nms: Option<Nms>,
628    output_boxes: &mut Vec<DetectBox>,
629) where
630    f32: AsPrimitive<T>,
631{
632    let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
633    let boxes =
634        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
635    let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
636    let len = output_boxes.capacity().min(boxes.len());
637    output_boxes.clear();
638    for b in boxes.into_iter().take(len) {
639        output_boxes.push(b);
640    }
641}
642
643/// Internal implementation of YOLO split detection decoding for quantized
644/// tensors.
645///
646/// Expected shapes of inputs:
647/// - boxes: (4, num_boxes)
648/// - scores: (num_classes, num_boxes)
649///
650/// # Panics
651/// Panics if shapes don't match the expected dimensions.
652pub(crate) fn impl_yolo_split_quant<
653    B: BBoxTypeTrait,
654    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
655    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
656>(
657    boxes: (ArrayView2<BOX>, Quantization),
658    scores: (ArrayView2<SCORE>, Quantization),
659    score_threshold: f32,
660    iou_threshold: f32,
661    nms: Option<Nms>,
662    output_boxes: &mut Vec<DetectBox>,
663) where
664    f32: AsPrimitive<SCORE>,
665{
666    let (boxes_tensor, quant_boxes) = boxes;
667    let (scores_tensor, quant_scores) = scores;
668
669    let boxes_tensor = boxes_tensor.reversed_axes();
670    let scores_tensor = scores_tensor.reversed_axes();
671
672    let boxes = {
673        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
674        postprocess_boxes_quant::<B, _, _>(
675            score_threshold,
676            boxes_tensor,
677            scores_tensor,
678            quant_boxes,
679        )
680    };
681
682    let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
683    let len = output_boxes.capacity().min(boxes.len());
684    output_boxes.clear();
685    for b in boxes.iter().take(len) {
686        output_boxes.push(dequant_detect_box(b, quant_scores));
687    }
688}
689
690/// Internal implementation of YOLO split detection decoding for float tensors.
691///
692/// Expected shapes of inputs:
693/// - boxes: (4, num_boxes)
694/// - scores: (num_classes, num_boxes)
695///
696/// # Panics
697/// Panics if shapes don't match the expected dimensions.
698pub(crate) fn impl_yolo_split_float<
699    B: BBoxTypeTrait,
700    BOX: Float + AsPrimitive<f32> + Send + Sync,
701    SCORE: Float + AsPrimitive<f32> + Send + Sync,
702>(
703    boxes_tensor: ArrayView2<BOX>,
704    scores_tensor: ArrayView2<SCORE>,
705    score_threshold: f32,
706    iou_threshold: f32,
707    nms: Option<Nms>,
708    output_boxes: &mut Vec<DetectBox>,
709) where
710    f32: AsPrimitive<SCORE>,
711{
712    let boxes_tensor = boxes_tensor.reversed_axes();
713    let scores_tensor = scores_tensor.reversed_axes();
714    let boxes =
715        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
716    let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
717    let len = output_boxes.capacity().min(boxes.len());
718    output_boxes.clear();
719    for b in boxes.into_iter().take(len) {
720        output_boxes.push(b);
721    }
722}
723
724/// Internal implementation of YOLO detection segmentation decoding for
725/// quantized tensors.
726///
727/// Expected shapes of inputs:
728/// - boxes: (4 + num_classes + num_protos, num_boxes)
729/// - protos: (proto_height, proto_width, num_protos)
730///
731/// # Errors
732/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
733pub(crate) fn impl_yolo_segdet_quant<
734    B: BBoxTypeTrait,
735    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
736    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
737>(
738    boxes: (ArrayView2<BOX>, Quantization),
739    protos: (ArrayView3<PROTO>, Quantization),
740    score_threshold: f32,
741    iou_threshold: f32,
742    nms: Option<Nms>,
743    output_boxes: &mut Vec<DetectBox>,
744    output_masks: &mut Vec<Segmentation>,
745) -> Result<(), crate::DecoderError>
746where
747    f32: AsPrimitive<BOX>,
748{
749    let (boxes, quant_boxes) = boxes;
750    let num_protos = protos.0.dim().2;
751
752    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
753    let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
754        (boxes_tensor, quant_boxes),
755        (scores_tensor, quant_boxes),
756        score_threshold,
757        iou_threshold,
758        nms,
759        output_boxes.capacity(),
760    );
761
762    impl_yolo_split_segdet_quant_process_masks::<_, _>(
763        boxes,
764        (mask_tensor, quant_boxes),
765        protos,
766        output_boxes,
767        output_masks,
768    )
769}
770
771/// Internal implementation of YOLO detection segmentation decoding for
772/// float tensors.
773///
774/// Expected shapes of inputs:
775/// - boxes: (4 + num_classes + num_protos, num_boxes)
776/// - protos: (proto_height, proto_width, num_protos)
777///
778/// # Panics
779/// Panics if shapes don't match the expected dimensions.
780pub(crate) fn impl_yolo_segdet_float<
781    B: BBoxTypeTrait,
782    BOX: Float + AsPrimitive<f32> + Send + Sync,
783    PROTO: Float + AsPrimitive<f32> + Send + Sync,
784>(
785    boxes: ArrayView2<BOX>,
786    protos: ArrayView3<PROTO>,
787    score_threshold: f32,
788    iou_threshold: f32,
789    nms: Option<Nms>,
790    output_boxes: &mut Vec<DetectBox>,
791    output_masks: &mut Vec<Segmentation>,
792) -> Result<(), crate::DecoderError>
793where
794    f32: AsPrimitive<BOX>,
795{
796    let num_protos = protos.dim().2;
797    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
798    let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
799        boxes_tensor,
800        scores_tensor,
801        score_threshold,
802        iou_threshold,
803        nms,
804        output_boxes.capacity(),
805    );
806    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
807}
808
809pub(crate) fn impl_yolo_segdet_get_boxes<
810    B: BBoxTypeTrait,
811    BOX: Float + AsPrimitive<f32> + Send + Sync,
812    SCORE: Float + AsPrimitive<f32> + Send + Sync,
813>(
814    boxes_tensor: ArrayView2<BOX>,
815    scores_tensor: ArrayView2<SCORE>,
816    score_threshold: f32,
817    iou_threshold: f32,
818    nms: Option<Nms>,
819    max_boxes: usize,
820) -> Vec<(DetectBox, usize)>
821where
822    f32: AsPrimitive<SCORE>,
823{
824    let boxes = postprocess_boxes_index_float::<B, _, _>(
825        score_threshold.as_(),
826        boxes_tensor,
827        scores_tensor,
828    );
829    let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
830    boxes.truncate(max_boxes);
831    boxes
832}
833
834pub(crate) fn impl_yolo_end_to_end_segdet_get_boxes<
835    B: BBoxTypeTrait,
836    BOX: Float + AsPrimitive<f32> + Send + Sync,
837    SCORE: Float + AsPrimitive<f32> + Send + Sync,
838    CLASS: AsPrimitive<f32> + Send + Sync,
839>(
840    boxes: ArrayView2<BOX>,
841    scores: ArrayView2<SCORE>,
842    classes: ArrayView1<CLASS>,
843    score_threshold: f32,
844    max_boxes: usize,
845) -> Vec<(DetectBox, usize)>
846where
847    f32: AsPrimitive<SCORE>,
848{
849    let mut boxes = postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes, scores);
850    boxes.truncate(max_boxes);
851    for (b, ind) in &mut boxes {
852        b.label = classes[*ind].as_().round() as usize;
853    }
854    boxes
855}
856
857pub(crate) fn impl_yolo_split_segdet_process_masks<
858    MASK: Float + AsPrimitive<f32> + Send + Sync,
859    PROTO: Float + AsPrimitive<f32> + Send + Sync,
860>(
861    boxes: Vec<(DetectBox, usize)>,
862    masks_tensor: ArrayView2<MASK>,
863    protos_tensor: ArrayView3<PROTO>,
864    output_boxes: &mut Vec<DetectBox>,
865    output_masks: &mut Vec<Segmentation>,
866) -> Result<(), crate::DecoderError> {
867    let boxes = decode_segdet_f32(boxes, masks_tensor, protos_tensor)?;
868    output_boxes.clear();
869    output_masks.clear();
870    for (b, m) in boxes.into_iter() {
871        output_boxes.push(b);
872        output_masks.push(Segmentation {
873            xmin: b.bbox.xmin,
874            ymin: b.bbox.ymin,
875            xmax: b.bbox.xmax,
876            ymax: b.bbox.ymax,
877            segmentation: m,
878        });
879    }
880    Ok(())
881}
882/// Expected input shapes:
883/// - boxes_tensor: (num_boxes, 4)
884/// - scores_tensor: (num_boxes, num_classes)
885pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
886    B: BBoxTypeTrait,
887    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
888    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
889>(
890    boxes: (ArrayView2<BOX>, Quantization),
891    scores: (ArrayView2<SCORE>, Quantization),
892    score_threshold: f32,
893    iou_threshold: f32,
894    nms: Option<Nms>,
895    max_boxes: usize,
896) -> Vec<(DetectBox, usize)>
897where
898    f32: AsPrimitive<SCORE>,
899{
900    let (boxes_tensor, quant_boxes) = boxes;
901    let (scores_tensor, quant_scores) = scores;
902
903    let boxes = {
904        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
905        postprocess_boxes_index_quant::<B, _, _>(
906            score_threshold,
907            boxes_tensor,
908            scores_tensor,
909            quant_boxes,
910        )
911    };
912    let mut boxes = dispatch_nms_extra_int(nms, iou_threshold, boxes);
913    boxes.truncate(max_boxes);
914    boxes
915        .into_iter()
916        .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
917        .collect()
918}
919
920pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
921    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
922    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
923>(
924    boxes: Vec<(DetectBox, usize)>,
925    mask_coeff: (ArrayView2<MASK>, Quantization),
926    protos: (ArrayView3<PROTO>, Quantization),
927    output_boxes: &mut Vec<DetectBox>,
928    output_masks: &mut Vec<Segmentation>,
929) -> Result<(), crate::DecoderError> {
930    let (masks, quant_masks) = mask_coeff;
931    let (protos, quant_protos) = protos;
932
933    let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
934    output_boxes.clear();
935    output_masks.clear();
936    for (b, m) in boxes.into_iter() {
937        output_boxes.push(b);
938        output_masks.push(Segmentation {
939            xmin: b.bbox.xmin,
940            ymin: b.bbox.ymin,
941            xmax: b.bbox.xmax,
942            ymax: b.bbox.ymax,
943            segmentation: m,
944        });
945    }
946    Ok(())
947}
948
949#[allow(clippy::too_many_arguments)]
950/// Internal implementation of YOLO split detection segmentation decoding for
951/// quantized tensors.
952///
953/// Expected shapes of inputs:
954/// - boxes_tensor: (4, num_boxes)
955/// - scores_tensor: (num_classes, num_boxes)
956/// - mask_tensor: (num_protos, num_boxes)
957/// - protos: (proto_height, proto_width, num_protos)
958///
959/// # Errors
960/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
961pub(crate) fn impl_yolo_split_segdet_quant<
962    B: BBoxTypeTrait,
963    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
964    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
965    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
966    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
967>(
968    boxes: (ArrayView2<BOX>, Quantization),
969    scores: (ArrayView2<SCORE>, Quantization),
970    mask_coeff: (ArrayView2<MASK>, Quantization),
971    protos: (ArrayView3<PROTO>, Quantization),
972    score_threshold: f32,
973    iou_threshold: f32,
974    nms: Option<Nms>,
975    output_boxes: &mut Vec<DetectBox>,
976    output_masks: &mut Vec<Segmentation>,
977) -> Result<(), crate::DecoderError>
978where
979    f32: AsPrimitive<SCORE>,
980{
981    let (boxes_, scores_, mask_coeff_) =
982        postprocess_yolo_split_segdet(boxes.0, scores.0, mask_coeff.0);
983    let boxes = (boxes_, boxes.1);
984    let scores = (scores_, scores.1);
985    let mask_coeff = (mask_coeff_, mask_coeff.1);
986
987    let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
988        boxes,
989        scores,
990        score_threshold,
991        iou_threshold,
992        nms,
993        output_boxes.capacity(),
994    );
995
996    impl_yolo_split_segdet_quant_process_masks(
997        boxes,
998        mask_coeff,
999        protos,
1000        output_boxes,
1001        output_masks,
1002    )
1003}
1004
1005#[allow(clippy::too_many_arguments)]
1006/// Internal implementation of YOLO split detection segmentation decoding for
1007/// float tensors.
1008///
1009/// Expected shapes of inputs:
1010/// - boxes_tensor: (4, num_boxes)
1011/// - scores_tensor: (num_classes, num_boxes)
1012/// - mask_tensor: (num_protos, num_boxes)
1013/// - protos: (proto_height, proto_width, num_protos)
1014///
1015/// # Errors
1016/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
1017pub(crate) fn impl_yolo_split_segdet_float<
1018    B: BBoxTypeTrait,
1019    BOX: Float + AsPrimitive<f32> + Send + Sync,
1020    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1021    MASK: Float + AsPrimitive<f32> + Send + Sync,
1022    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1023>(
1024    boxes_tensor: ArrayView2<BOX>,
1025    scores_tensor: ArrayView2<SCORE>,
1026    mask_tensor: ArrayView2<MASK>,
1027    protos: ArrayView3<PROTO>,
1028    score_threshold: f32,
1029    iou_threshold: f32,
1030    nms: Option<Nms>,
1031    output_boxes: &mut Vec<DetectBox>,
1032    output_masks: &mut Vec<Segmentation>,
1033) -> Result<(), crate::DecoderError>
1034where
1035    f32: AsPrimitive<SCORE>,
1036{
1037    let (boxes_tensor, scores_tensor, mask_tensor) =
1038        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1039
1040    let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1041        boxes_tensor,
1042        scores_tensor,
1043        score_threshold,
1044        iou_threshold,
1045        nms,
1046        output_boxes.capacity(),
1047    );
1048    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1049}
1050
1051// ---------------------------------------------------------------------------
1052// Proto-extraction variants: return ProtoData instead of materialized masks
1053// ---------------------------------------------------------------------------
1054
1055/// Proto-extraction variant of `impl_yolo_segdet_quant`.
1056/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1057pub fn impl_yolo_segdet_quant_proto<
1058    B: BBoxTypeTrait,
1059    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1060    PROTO: PrimInt
1061        + AsPrimitive<i64>
1062        + AsPrimitive<i128>
1063        + AsPrimitive<f32>
1064        + AsPrimitive<i8>
1065        + Send
1066        + Sync,
1067>(
1068    boxes: (ArrayView2<BOX>, Quantization),
1069    protos: (ArrayView3<PROTO>, Quantization),
1070    score_threshold: f32,
1071    iou_threshold: f32,
1072    nms: Option<Nms>,
1073    output_boxes: &mut Vec<DetectBox>,
1074) -> ProtoData
1075where
1076    f32: AsPrimitive<BOX>,
1077{
1078    let (boxes_arr, quant_boxes) = boxes;
1079    let (protos_arr, quant_protos) = protos;
1080    let num_protos = protos_arr.dim().2;
1081
1082    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1083
1084    let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1085        (boxes_tensor, quant_boxes),
1086        (scores_tensor, quant_boxes),
1087        score_threshold,
1088        iou_threshold,
1089        nms,
1090        output_boxes.capacity(),
1091    );
1092
1093    extract_proto_data_quant(
1094        det_indices,
1095        mask_tensor,
1096        quant_boxes,
1097        protos_arr,
1098        quant_protos,
1099        output_boxes,
1100    )
1101}
1102
1103/// Proto-extraction variant of `impl_yolo_segdet_float`.
1104/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1105pub(crate) fn impl_yolo_segdet_float_proto<
1106    B: BBoxTypeTrait,
1107    BOX: Float + AsPrimitive<f32> + Send + Sync,
1108    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1109>(
1110    boxes: ArrayView2<BOX>,
1111    protos: ArrayView3<PROTO>,
1112    score_threshold: f32,
1113    iou_threshold: f32,
1114    nms: Option<Nms>,
1115    output_boxes: &mut Vec<DetectBox>,
1116) -> ProtoData
1117where
1118    f32: AsPrimitive<BOX>,
1119{
1120    let num_protos = protos.dim().2;
1121    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1122
1123    let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1124        boxes_tensor,
1125        scores_tensor,
1126        score_threshold,
1127        iou_threshold,
1128        nms,
1129        output_boxes.capacity(),
1130    );
1131
1132    extract_proto_data_float(boxes, mask_tensor, protos, output_boxes)
1133}
1134
1135/// Proto-extraction variant of `impl_yolo_split_segdet_float`.
1136/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1137#[allow(clippy::too_many_arguments)]
1138pub(crate) fn impl_yolo_split_segdet_float_proto<
1139    B: BBoxTypeTrait,
1140    BOX: Float + AsPrimitive<f32> + Send + Sync,
1141    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1142    MASK: Float + AsPrimitive<f32> + Send + Sync,
1143    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1144>(
1145    boxes_tensor: ArrayView2<BOX>,
1146    scores_tensor: ArrayView2<SCORE>,
1147    mask_tensor: ArrayView2<MASK>,
1148    protos: ArrayView3<PROTO>,
1149    score_threshold: f32,
1150    iou_threshold: f32,
1151    nms: Option<Nms>,
1152    output_boxes: &mut Vec<DetectBox>,
1153) -> ProtoData
1154where
1155    f32: AsPrimitive<SCORE>,
1156{
1157    let (boxes_tensor, scores_tensor, mask_tensor) =
1158        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1159    let det_indices = impl_yolo_segdet_get_boxes::<B, _, _>(
1160        boxes_tensor,
1161        scores_tensor,
1162        score_threshold,
1163        iou_threshold,
1164        nms,
1165        output_boxes.capacity(),
1166    );
1167
1168    extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1169}
1170
1171/// Proto-extraction variant of `decode_yolo_end_to_end_segdet_float`.
1172pub fn decode_yolo_end_to_end_segdet_float_proto<T>(
1173    output: ArrayView2<T>,
1174    protos: ArrayView3<T>,
1175    score_threshold: f32,
1176    output_boxes: &mut Vec<DetectBox>,
1177) -> Result<ProtoData, crate::DecoderError>
1178where
1179    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1180    f32: AsPrimitive<T>,
1181{
1182    let (boxes, scores, classes, mask_coeff) =
1183        postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
1184    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1185        boxes,
1186        scores,
1187        classes,
1188        score_threshold,
1189        output_boxes.capacity(),
1190    );
1191
1192    Ok(extract_proto_data_float(
1193        boxes,
1194        mask_coeff,
1195        protos,
1196        output_boxes,
1197    ))
1198}
1199
1200/// Proto-extraction variant of `decode_yolo_split_end_to_end_segdet_float`.
1201#[allow(clippy::too_many_arguments)]
1202pub fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1203    boxes: ArrayView2<T>,
1204    scores: ArrayView2<T>,
1205    classes: ArrayView2<T>,
1206    mask_coeff: ArrayView2<T>,
1207    protos: ArrayView3<T>,
1208    score_threshold: f32,
1209    output_boxes: &mut Vec<DetectBox>,
1210) -> Result<ProtoData, crate::DecoderError>
1211where
1212    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1213    f32: AsPrimitive<T>,
1214{
1215    let (boxes, scores, classes, mask_coeff) =
1216        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
1217    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1218        boxes,
1219        scores,
1220        classes,
1221        score_threshold,
1222        output_boxes.capacity(),
1223    );
1224
1225    Ok(extract_proto_data_float(
1226        boxes,
1227        mask_coeff,
1228        protos,
1229        output_boxes,
1230    ))
1231}
1232
1233/// Helper: extract ProtoData from float mask coefficients + protos.
1234pub(super) fn extract_proto_data_float<
1235    MASK: Float + AsPrimitive<f32> + Send + Sync,
1236    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1237>(
1238    det_indices: Vec<(DetectBox, usize)>,
1239    mask_tensor: ArrayView2<MASK>,
1240    protos: ArrayView3<PROTO>,
1241    output_boxes: &mut Vec<DetectBox>,
1242) -> ProtoData {
1243    let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1244    output_boxes.clear();
1245    for (det, idx) in det_indices {
1246        output_boxes.push(det);
1247        let row = mask_tensor.row(idx);
1248        mask_coefficients.push(row.iter().map(|v| v.as_()).collect());
1249    }
1250    let protos_f32 = protos.map(|v| v.as_());
1251    ProtoData {
1252        mask_coefficients,
1253        protos: ProtoTensor::Float(protos_f32),
1254    }
1255}
1256
1257/// Helper: extract ProtoData from quantized mask coefficients + protos.
1258///
1259/// Dequantizes mask coefficients to f32 (small — per-detection) but keeps
1260/// protos in raw int8 form wrapped in `ProtoTensor::Quantized` so the GPU
1261/// shader can dequantize per-texel without CPU overhead.
1262pub(crate) fn extract_proto_data_quant<
1263    MASK: PrimInt + AsPrimitive<f32> + Send + Sync,
1264    PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
1265>(
1266    det_indices: Vec<(DetectBox, usize)>,
1267    mask_tensor: ArrayView2<MASK>,
1268    quant_masks: Quantization,
1269    protos: ArrayView3<PROTO>,
1270    quant_protos: Quantization,
1271    output_boxes: &mut Vec<DetectBox>,
1272) -> ProtoData {
1273    let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1274    output_boxes.clear();
1275    for (det, idx) in det_indices {
1276        output_boxes.push(det);
1277        let row = mask_tensor.row(idx);
1278        mask_coefficients.push(
1279            row.iter()
1280                .map(|v| (v.as_() - quant_masks.zero_point as f32) * quant_masks.scale)
1281                .collect(),
1282        );
1283    }
1284    // Keep protos in raw int8 — GPU shader will dequantize per-texel.
1285    // When PROTO is already i8, use to_owned() for a flat memcpy instead of
1286    // per-element as_() conversion.
1287    let protos_i8 = if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1288        // SAFETY: PROTO and i8 have identical size and layout when TypeId matches.
1289        let view_i8 =
1290            unsafe { &*(&protos as *const ArrayView3<'_, PROTO> as *const ArrayView3<'_, i8>) };
1291        view_i8.to_owned()
1292    } else {
1293        protos.map(|v| {
1294            let v_i8: i8 = v.as_();
1295            v_i8
1296        })
1297    };
1298    ProtoData {
1299        mask_coefficients,
1300        protos: ProtoTensor::Quantized {
1301            protos: protos_i8,
1302            quantization: quant_protos,
1303        },
1304    }
1305}
1306
1307fn postprocess_yolo<'a, T>(
1308    output: &'a ArrayView2<'_, T>,
1309) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1310    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1311    let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1312    (boxes_tensor, scores_tensor)
1313}
1314
1315pub(crate) fn postprocess_yolo_seg<'a, T>(
1316    output: &'a ArrayView2<'_, T>,
1317    num_protos: usize,
1318) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1319    assert!(
1320        output.shape()[0] > num_protos + 4,
1321        "Output shape is too short: {} <= {} + 4",
1322        output.shape()[0],
1323        num_protos
1324    );
1325    let num_classes = output.shape()[0] - 4 - num_protos;
1326    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1327    let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1328    let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1329    (boxes_tensor, scores_tensor, mask_tensor)
1330}
1331
1332pub(crate) fn postprocess_yolo_split_segdet<'a, 'b, 'c, BOX, SCORE, MASK>(
1333    boxes_tensor: ArrayView2<'a, BOX>,
1334    scores_tensor: ArrayView2<'b, SCORE>,
1335    mask_tensor: ArrayView2<'c, MASK>,
1336) -> (
1337    ArrayView2<'a, BOX>,
1338    ArrayView2<'b, SCORE>,
1339    ArrayView2<'c, MASK>,
1340) {
1341    let boxes_tensor = boxes_tensor.reversed_axes();
1342    let scores_tensor = scores_tensor.reversed_axes();
1343    let mask_tensor = mask_tensor.reversed_axes();
1344    (boxes_tensor, scores_tensor, mask_tensor)
1345}
1346
1347fn decode_segdet_f32<
1348    MASK: Float + AsPrimitive<f32> + Send + Sync,
1349    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1350>(
1351    boxes: Vec<(DetectBox, usize)>,
1352    masks: ArrayView2<MASK>,
1353    protos: ArrayView3<PROTO>,
1354) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1355    if boxes.is_empty() {
1356        return Ok(Vec::new());
1357    }
1358    if masks.shape()[1] != protos.shape()[2] {
1359        return Err(crate::DecoderError::InvalidShape(format!(
1360            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1361            masks.shape()[1],
1362            protos.shape()[2],
1363        )));
1364    }
1365    boxes
1366        .into_par_iter()
1367        .map(|mut b| {
1368            let ind = b.1;
1369            let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1370            b.0.bbox = roi;
1371            Ok((b.0, make_segmentation(masks.row(ind), protos.view())))
1372        })
1373        .collect()
1374}
1375
1376pub(crate) fn decode_segdet_quant<
1377    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1378    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1379>(
1380    boxes: Vec<(DetectBox, usize)>,
1381    masks: ArrayView2<MASK>,
1382    protos: ArrayView3<PROTO>,
1383    quant_masks: Quantization,
1384    quant_protos: Quantization,
1385) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1386    if boxes.is_empty() {
1387        return Ok(Vec::new());
1388    }
1389    if masks.shape()[1] != protos.shape()[2] {
1390        return Err(crate::DecoderError::InvalidShape(format!(
1391            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1392            masks.shape()[1],
1393            protos.shape()[2],
1394        )));
1395    }
1396
1397    let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; // 32 protos is 2^5
1398    boxes
1399        .into_iter()
1400        .map(|mut b| {
1401            let i = b.1;
1402            let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1403            b.0.bbox = roi;
1404            let seg = match total_bits {
1405                0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1406                    masks.row(i),
1407                    protos.view(),
1408                    quant_masks,
1409                    quant_protos,
1410                ),
1411                65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1412                    masks.row(i),
1413                    protos.view(),
1414                    quant_masks,
1415                    quant_protos,
1416                ),
1417                _ => {
1418                    return Err(crate::DecoderError::NotSupported(format!(
1419                        "Unsupported bit width ({total_bits}) for segmentation computation"
1420                    )));
1421                }
1422            };
1423            Ok((b.0, seg))
1424        })
1425        .collect()
1426}
1427
1428fn protobox<'a, T>(
1429    protos: &'a ArrayView3<T>,
1430    roi: &BoundingBox,
1431) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1432    let width = protos.dim().1 as f32;
1433    let height = protos.dim().0 as f32;
1434
1435    // Detect un-normalized bounding boxes (pixel-space coordinates).
1436    // protobox expects normalized coordinates in [0, 1]. ONNX models output
1437    // pixel-space boxes (e.g. 0-640) which must be normalized before calling
1438    // decode(). Without this check, pixel-space coordinates silently clamp to
1439    // the proto boundary, producing empty (0, 0, C) masks for every detection.
1440    //
1441    // The limit is set to 2.0 (not 1.01) because YOLO models legitimately
1442    // predict coordinates slightly > 1.0 for objects near frame edges.
1443    // Any value > 2.0 is clearly pixel-space (even the smallest practical
1444    // model input of 32×32 would produce coordinates >> 2.0).
1445    const NORM_LIMIT: f32 = 2.0;
1446    if roi.xmin > NORM_LIMIT
1447        || roi.ymin > NORM_LIMIT
1448        || roi.xmax > NORM_LIMIT
1449        || roi.ymax > NORM_LIMIT
1450    {
1451        return Err(crate::DecoderError::InvalidShape(format!(
1452            "Bounding box coordinates appear un-normalized (pixel-space). \
1453             Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
1454             ONNX models output pixel-space boxes — normalize them by dividing by \
1455             the input dimensions before calling decode().",
1456            roi.xmin, roi.ymin, roi.xmax, roi.ymax,
1457        )));
1458    }
1459
1460    let roi = [
1461        (roi.xmin * width).clamp(0.0, width) as usize,
1462        (roi.ymin * height).clamp(0.0, height) as usize,
1463        (roi.xmax * width).clamp(0.0, width).ceil() as usize,
1464        (roi.ymax * height).clamp(0.0, height).ceil() as usize,
1465    ];
1466
1467    let roi_norm = [
1468        roi[0] as f32 / width,
1469        roi[1] as f32 / height,
1470        roi[2] as f32 / width,
1471        roi[3] as f32 / height,
1472    ]
1473    .into();
1474
1475    let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
1476
1477    Ok((cropped, roi_norm))
1478}
1479
1480/// Compute a single instance segmentation mask from mask coefficients and
1481/// proto maps (float path).
1482///
1483/// Computes `sigmoid(coefficients · protos)` and maps to `[0, 255]`.
1484/// Returns an `(H, W, 1)` u8 array.
1485fn make_segmentation<
1486    MASK: Float + AsPrimitive<f32> + Send + Sync,
1487    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1488>(
1489    mask: ArrayView1<MASK>,
1490    protos: ArrayView3<PROTO>,
1491) -> Array3<u8> {
1492    let shape = protos.shape();
1493
1494    // Safe to unwrap since the shapes will always be compatible
1495    let mask = mask.to_shape((1, mask.len())).unwrap();
1496    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1497    let protos = protos.reversed_axes();
1498    let mask = mask.map(|x| x.as_());
1499    let protos = protos.map(|x| x.as_());
1500
1501    // Safe to unwrap since the shapes will always be compatible
1502    let mask = mask
1503        .dot(&protos)
1504        .into_shape_with_order((shape[0], shape[1], 1))
1505        .unwrap();
1506
1507    mask.map(|x| {
1508        let sigmoid = 1.0 / (1.0 + (-*x).exp());
1509        (sigmoid * 255.0).round() as u8
1510    })
1511}
1512
1513/// Compute a single instance segmentation mask from quantized mask
1514/// coefficients and proto maps.
1515///
1516/// Dequantizes both inputs (subtracting zero-points), computes the dot
1517/// product, applies sigmoid, and maps to `[0, 255]`.
1518/// Returns an `(H, W, 1)` u8 array.
1519fn make_segmentation_quant<
1520    MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
1521    PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
1522    DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
1523>(
1524    mask: ArrayView1<MASK>,
1525    protos: ArrayView3<PROTO>,
1526    quant_masks: Quantization,
1527    quant_protos: Quantization,
1528) -> Array3<u8>
1529where
1530    i32: AsPrimitive<DEST>,
1531    f32: AsPrimitive<DEST>,
1532{
1533    let shape = protos.shape();
1534
1535    // Safe to unwrap since the shapes will always be compatible
1536    let mask = mask.to_shape((1, mask.len())).unwrap();
1537
1538    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1539    let protos = protos.reversed_axes();
1540
1541    let zp = quant_masks.zero_point.as_();
1542
1543    let mask = mask.mapv(|x| x.as_() - zp);
1544
1545    let zp = quant_protos.zero_point.as_();
1546    let protos = protos.mapv(|x| x.as_() - zp);
1547
1548    // Safe to unwrap since the shapes will always be compatible
1549    let segmentation = mask
1550        .dot(&protos)
1551        .into_shape_with_order((shape[0], shape[1], 1))
1552        .unwrap();
1553
1554    let combined_scale = quant_masks.scale * quant_protos.scale;
1555    segmentation.map(|x| {
1556        let val: f32 = (*x).as_() * combined_scale;
1557        let sigmoid = 1.0 / (1.0 + (-val).exp());
1558        (sigmoid * 255.0).round() as u8
1559    })
1560}
1561
1562/// Converts Yolo Instance Segmentation into a 2D mask.
1563///
1564/// The input segmentation is expected to have shape (H, W, 1).
1565///
1566/// The output mask will have shape (H, W), with values 0 or 1 based on the
1567/// threshold.
1568///
1569/// # Errors
1570///
1571/// Returns `DecoderError::InvalidShape` if the input segmentation does not
1572/// have shape (H, W, 1).
1573pub fn yolo_segmentation_to_mask(
1574    segmentation: ArrayView3<u8>,
1575    threshold: u8,
1576) -> Result<Array2<u8>, crate::DecoderError> {
1577    if segmentation.shape()[2] != 1 {
1578        return Err(crate::DecoderError::InvalidShape(format!(
1579            "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
1580            segmentation.shape()[2]
1581        )));
1582    }
1583    Ok(segmentation
1584        .slice(s![.., .., 0])
1585        .map(|x| if *x >= threshold { 1 } else { 0 }))
1586}
1587
1588#[cfg(test)]
1589#[cfg_attr(coverage_nightly, coverage(off))]
1590mod tests {
1591    use super::*;
1592    use ndarray::Array2;
1593
1594    // ========================================================================
1595    // Tests for decode_yolo_end_to_end_det_float
1596    // ========================================================================
1597
1598    #[test]
1599    fn test_end_to_end_det_basic_filtering() {
1600        // Create synthetic end-to-end detection output: (6, N) where rows are
1601        // [x1, y1, x2, y2, conf, class]
1602        // 3 detections: one above threshold, two below
1603        let data: Vec<f32> = vec![
1604            // Detection 0: high score (0.9)
1605            0.1, 0.2, 0.3, // x1 values
1606            0.1, 0.2, 0.3, // y1 values
1607            0.5, 0.6, 0.7, // x2 values
1608            0.5, 0.6, 0.7, // y2 values
1609            0.9, 0.1, 0.2, // confidence scores
1610            0.0, 1.0, 2.0, // class indices
1611        ];
1612        let output = Array2::from_shape_vec((6, 3), data).unwrap();
1613
1614        let mut boxes = Vec::with_capacity(10);
1615        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1616
1617        // Only 1 detection should pass threshold of 0.5
1618        assert_eq!(boxes.len(), 1);
1619        assert_eq!(boxes[0].label, 0);
1620        assert!((boxes[0].score - 0.9).abs() < 0.01);
1621        assert!((boxes[0].bbox.xmin - 0.1).abs() < 0.01);
1622        assert!((boxes[0].bbox.ymin - 0.1).abs() < 0.01);
1623        assert!((boxes[0].bbox.xmax - 0.5).abs() < 0.01);
1624        assert!((boxes[0].bbox.ymax - 0.5).abs() < 0.01);
1625    }
1626
1627    #[test]
1628    fn test_end_to_end_det_all_pass_threshold() {
1629        // All detections above threshold
1630        let data: Vec<f32> = vec![
1631            10.0, 20.0, // x1
1632            10.0, 20.0, // y1
1633            50.0, 60.0, // x2
1634            50.0, 60.0, // y2
1635            0.8, 0.7, // conf (both above 0.5)
1636            1.0, 2.0, // class
1637        ];
1638        let output = Array2::from_shape_vec((6, 2), data).unwrap();
1639
1640        let mut boxes = Vec::with_capacity(10);
1641        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1642
1643        assert_eq!(boxes.len(), 2);
1644        assert_eq!(boxes[0].label, 1);
1645        assert_eq!(boxes[1].label, 2);
1646    }
1647
1648    #[test]
1649    fn test_end_to_end_det_none_pass_threshold() {
1650        // All detections below threshold
1651        let data: Vec<f32> = vec![
1652            10.0, 20.0, // x1
1653            10.0, 20.0, // y1
1654            50.0, 60.0, // x2
1655            50.0, 60.0, // y2
1656            0.1, 0.2, // conf (both below 0.5)
1657            1.0, 2.0, // class
1658        ];
1659        let output = Array2::from_shape_vec((6, 2), data).unwrap();
1660
1661        let mut boxes = Vec::with_capacity(10);
1662        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1663
1664        assert_eq!(boxes.len(), 0);
1665    }
1666
1667    #[test]
1668    fn test_end_to_end_det_capacity_limit() {
1669        // Test that output is truncated to capacity
1670        let data: Vec<f32> = vec![
1671            0.1, 0.2, 0.3, 0.4, 0.5, // x1
1672            0.1, 0.2, 0.3, 0.4, 0.5, // y1
1673            0.5, 0.6, 0.7, 0.8, 0.9, // x2
1674            0.5, 0.6, 0.7, 0.8, 0.9, // y2
1675            0.9, 0.9, 0.9, 0.9, 0.9, // conf (all pass)
1676            0.0, 1.0, 2.0, 3.0, 4.0, // class
1677        ];
1678        let output = Array2::from_shape_vec((6, 5), data).unwrap();
1679
1680        let mut boxes = Vec::with_capacity(2); // Only allow 2 boxes
1681        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1682
1683        assert_eq!(boxes.len(), 2);
1684    }
1685
1686    #[test]
1687    fn test_end_to_end_det_empty_output() {
1688        // Test with zero detections
1689        let output = Array2::<f32>::zeros((6, 0));
1690
1691        let mut boxes = Vec::with_capacity(10);
1692        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1693
1694        assert_eq!(boxes.len(), 0);
1695    }
1696
1697    #[test]
1698    fn test_end_to_end_det_pixel_coordinates() {
1699        // Test with pixel coordinates (non-normalized)
1700        let data: Vec<f32> = vec![
1701            100.0, // x1
1702            200.0, // y1
1703            300.0, // x2
1704            400.0, // y2
1705            0.95,  // conf
1706            5.0,   // class
1707        ];
1708        let output = Array2::from_shape_vec((6, 1), data).unwrap();
1709
1710        let mut boxes = Vec::with_capacity(10);
1711        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1712
1713        assert_eq!(boxes.len(), 1);
1714        assert_eq!(boxes[0].label, 5);
1715        assert!((boxes[0].bbox.xmin - 100.0).abs() < 0.01);
1716        assert!((boxes[0].bbox.ymin - 200.0).abs() < 0.01);
1717        assert!((boxes[0].bbox.xmax - 300.0).abs() < 0.01);
1718        assert!((boxes[0].bbox.ymax - 400.0).abs() < 0.01);
1719    }
1720
1721    #[test]
1722    fn test_end_to_end_det_invalid_shape() {
1723        // Test with too few rows (needs at least 6)
1724        let output = Array2::<f32>::zeros((5, 3));
1725
1726        let mut boxes = Vec::with_capacity(10);
1727        let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);
1728
1729        assert!(result.is_err());
1730        assert!(matches!(
1731            result,
1732            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
1733        ));
1734    }
1735
1736    // ========================================================================
1737    // Tests for decode_yolo_end_to_end_segdet_float
1738    // ========================================================================
1739
1740    #[test]
1741    fn test_end_to_end_segdet_basic() {
1742        // Create synthetic segdet output: (6 + num_protos, N)
1743        // Detection format: [x1, y1, x2, y2, conf, class, mask_coeff_0..31]
1744        let num_protos = 32;
1745        let num_detections = 2;
1746        let num_features = 6 + num_protos;
1747
1748        // Build detection tensor
1749        let mut data = vec![0.0f32; num_features * num_detections];
1750        // Detection 0: passes threshold
1751        data[0] = 0.1; // x1[0]
1752        data[1] = 0.5; // x1[1]
1753        data[num_detections] = 0.1; // y1[0]
1754        data[num_detections + 1] = 0.5; // y1[1]
1755        data[2 * num_detections] = 0.4; // x2[0]
1756        data[2 * num_detections + 1] = 0.9; // x2[1]
1757        data[3 * num_detections] = 0.4; // y2[0]
1758        data[3 * num_detections + 1] = 0.9; // y2[1]
1759        data[4 * num_detections] = 0.9; // conf[0] - passes
1760        data[4 * num_detections + 1] = 0.3; // conf[1] - fails
1761        data[5 * num_detections] = 1.0; // class[0]
1762        data[5 * num_detections + 1] = 2.0; // class[1]
1763                                            // Fill mask coefficients with small values
1764        for i in 6..num_features {
1765            data[i * num_detections] = 0.1;
1766            data[i * num_detections + 1] = 0.1;
1767        }
1768
1769        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
1770
1771        // Create protos tensor: (proto_height, proto_width, num_protos)
1772        let protos = Array3::<f32>::zeros((16, 16, num_protos));
1773
1774        let mut boxes = Vec::with_capacity(10);
1775        let mut masks = Vec::with_capacity(10);
1776        decode_yolo_end_to_end_segdet_float(
1777            output.view(),
1778            protos.view(),
1779            0.5,
1780            &mut boxes,
1781            &mut masks,
1782        )
1783        .unwrap();
1784
1785        // Only detection 0 should pass
1786        assert_eq!(boxes.len(), 1);
1787        assert_eq!(masks.len(), 1);
1788        assert_eq!(boxes[0].label, 1);
1789        assert!((boxes[0].score - 0.9).abs() < 0.01);
1790    }
1791
#[test]
fn test_end_to_end_segdet_mask_coordinates() {
    // A decoded mask must carry the same bounding coordinates as its box.
    let num_protos = 32;
    let num_features = 6 + num_protos;

    // One detection column laid out as [x1, y1, x2, y2, conf, class]; the
    // trailing mask coefficients are left at zero.
    let mut column = vec![0.0f32; num_features];
    column[..6].copy_from_slice(&[0.2, 0.2, 0.8, 0.8, 0.95, 3.0]);

    let output = Array2::from_shape_vec((num_features, 1), column).unwrap();
    let protos = Array3::<f32>::zeros((16, 16, num_protos));

    let mut boxes = Vec::with_capacity(10);
    let mut masks = Vec::with_capacity(10);
    let decoded = decode_yolo_end_to_end_segdet_float(
        output.view(),
        protos.view(),
        0.5,
        &mut boxes,
        &mut masks,
    );
    decoded.unwrap();

    assert_eq!(boxes.len(), 1);
    assert_eq!(masks.len(), 1);

    // Compare each edge of the mask against the matching edge of the box.
    let (mask, bbox) = (&masks[0], &boxes[0].bbox);
    let edges = [
        ("xmin", mask.xmin, bbox.xmin),
        ("ymin", mask.ymin, bbox.ymin),
        ("xmax", mask.xmax, bbox.xmax),
        ("ymax", mask.ymax, bbox.ymax),
    ];
    for &(name, mask_edge, box_edge) in edges.iter() {
        assert!(
            (mask_edge - box_edge).abs() < 0.01,
            "mask {} differs from box {}",
            name,
            name
        );
    }
}
1829
#[test]
fn test_end_to_end_segdet_empty_output() {
    // A detection tensor with zero columns must decode to no boxes and no masks.
    let num_protos = 32;
    let protos = Array3::<f32>::zeros((16, 16, num_protos));
    let output = Array2::<f32>::zeros((6 + num_protos, 0));

    let (mut boxes, mut masks) = (Vec::with_capacity(10), Vec::with_capacity(10));
    let decoded = decode_yolo_end_to_end_segdet_float(
        output.view(),
        protos.view(),
        0.5,
        &mut boxes,
        &mut masks,
    );
    decoded.unwrap();

    assert!(boxes.is_empty());
    assert!(masks.is_empty());
}
1850
#[test]
fn test_end_to_end_segdet_capacity_limit() {
    let num_protos = 32;
    let num_detections = 5;
    let num_features = 6 + num_protos;

    // Every detection gets a confidence of 0.9, so all five clear the 0.5
    // threshold used below.
    let mut data = vec![0.0f32; num_features * num_detections];
    for det in 0..num_detections {
        let origin = 0.1 * (det as f32);
        // Per-detection values for rows x1, y1, x2, y2, conf, class.
        let rows = [origin, origin, origin + 0.2, origin + 0.2, 0.9, det as f32];
        for (row, &value) in rows.iter().enumerate() {
            data[row * num_detections + det] = value;
        }
    }

    let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
    let protos = Array3::<f32>::zeros((16, 16, num_protos));

    // Output vectors are created with capacity 2; the test expects exactly
    // two results even though five detections pass.
    let mut boxes = Vec::with_capacity(2);
    let mut masks = Vec::with_capacity(2);
    let decoded = decode_yolo_end_to_end_segdet_float(
        output.view(),
        protos.view(),
        0.5,
        &mut boxes,
        &mut masks,
    );
    decoded.unwrap();

    assert_eq!(boxes.len(), 2);
    assert_eq!(masks.len(), 2);
}
1885
#[test]
fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
    // Six rows hold only the base fields; at least 7 are needed
    // (6 base + 1 mask coefficient), so decoding must be rejected.
    let protos = Array3::<f32>::zeros((16, 16, 32));
    let output = Array2::<f32>::zeros((6, 3));

    let (mut boxes, mut masks) = (Vec::with_capacity(10), Vec::with_capacity(10));
    let result = decode_yolo_end_to_end_segdet_float(
        output.view(),
        protos.view(),
        0.5,
        &mut boxes,
        &mut masks,
    );

    assert!(result.is_err());

    // The error must be an InvalidShape whose message names the row requirement.
    let reports_row_count = match result {
        Err(crate::DecoderError::InvalidShape(msg)) => msg.contains("at least 7 rows"),
        _ => false,
    };
    assert!(reports_row_count);
}
1908
#[test]
fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
    // The detection tensor carries 16 mask coefficients while the proto
    // tensor has 32 channels; the mismatch must be reported.
    let num_protos = 32;
    let num_mask_coeffs = 16;
    let output = Array2::<f32>::zeros((6 + num_mask_coeffs, 3));
    let protos = Array3::<f32>::zeros((16, 16, num_protos));

    let (mut boxes, mut masks) = (Vec::with_capacity(10), Vec::with_capacity(10));
    let result = decode_yolo_end_to_end_segdet_float(
        output.view(),
        protos.view(),
        0.5,
        &mut boxes,
        &mut masks,
    );

    assert!(result.is_err());

    // The error must be an InvalidShape whose message mentions the protos count.
    let reports_mismatch = match result {
        Err(crate::DecoderError::InvalidShape(msg)) => msg.contains("doesn't match protos count"),
        _ => false,
    };
    assert!(reports_mismatch);
}
1932
1933    // ========================================================================
1934    // Tests for decode_yolo_split_end_to_end_segdet_float
1935    // ========================================================================
1936
#[test]
fn test_split_end_to_end_segdet_basic() {
    // Synthetic segdet tensor of shape (6 + num_protos, N), later split into
    // box / score / class / mask-coefficient views.
    let num_protos = 32;
    let num_detections = 2;
    let num_features = 6 + num_protos;

    // First six rows, one value per detection. Detection 0 scores 0.9 and
    // passes the 0.5 threshold; detection 1 scores 0.3 and fails it.
    let header: [[f32; 2]; 6] = [
        [0.1, 0.5], // x1
        [0.1, 0.5], // y1
        [0.4, 0.9], // x2
        [0.4, 0.9], // y2
        [0.9, 0.3], // conf
        [1.0, 2.0], // class
    ];

    let mut data = vec![0.0f32; num_features * num_detections];
    for (row, values) in header.iter().enumerate() {
        let start = row * num_detections;
        data[start..start + num_detections].copy_from_slice(values);
    }
    // All remaining rows are mask coefficients; give them a small value.
    for coeff in &mut data[6 * num_detections..] {
        *coeff = 0.1;
    }

    let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
    let box_coords = output.slice(s![..4, ..]);
    let scores = output.slice(s![4..5, ..]);
    let classes = output.slice(s![5..6, ..]);
    let mask_coeff = output.slice(s![6.., ..]);

    // Proto tensor: (proto_height, proto_width, num_protos)
    let protos = Array3::<f32>::zeros((16, 16, num_protos));

    let (mut boxes, mut masks) = (Vec::with_capacity(10), Vec::with_capacity(10));
    let decoded = decode_yolo_split_end_to_end_segdet_float(
        box_coords,
        scores,
        classes,
        mask_coeff,
        protos.view(),
        0.5,
        &mut boxes,
        &mut masks,
    );
    decoded.unwrap();

    // Only detection 0 survives the threshold.
    assert_eq!(boxes.len(), 1);
    assert_eq!(masks.len(), 1);

    let kept = &boxes[0];
    assert_eq!(kept.label, 1);
    assert!((kept.score - 0.9).abs() < 0.01);
}
1994
1995    // ========================================================================
1996    // Tests for yolo_segmentation_to_mask
1997    // ========================================================================
1998
#[test]
fn test_segmentation_to_mask_basic() {
    // 4x4x1 segmentation mixing values below, at, and above the threshold.
    let data: Vec<u8> = vec![
        100, 200, 50, 150, // row 0
        10, 255, 128, 64, // row 1
        0, 127, 128, 255, // row 2
        64, 64, 192, 192, // row 3
    ];
    let segmentation = Array3::from_shape_vec((4, 4, 1), data).unwrap();

    let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();

    // The mask keeps the spatial dimensions of the input.
    assert_eq!(mask.shape(), &[4, 4]);

    // Values >= 128 map to 1, everything else to 0. Every cell is checked so
    // that both exact-threshold cells (128 at [1,2] and [2,2]) and the whole
    // of rows 2-3 are covered, not just a sample.
    let expected = [
        [0, 1, 0, 1], // 100, 200, 50, 150
        [0, 1, 1, 0], // 10, 255, 128, 64
        [0, 0, 1, 1], // 0, 127, 128, 255
        [0, 0, 1, 1], // 64, 64, 192, 192
    ];
    for (row, expected_row) in expected.iter().enumerate() {
        for (col, &want) in expected_row.iter().enumerate() {
            assert_eq!(mask[[row, col]], want, "mask[[{}, {}]]", row, col);
        }
    }
}
2022
#[test]
fn test_segmentation_to_mask_all_above() {
    // Every input value (255) is above the threshold, so no cell may differ from 1.
    let segmentation = Array3::from_elem((4, 4, 1), 255u8);
    let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();

    let has_unset_cell = mask.iter().any(|&cell| cell != 1);
    assert!(!has_unset_cell);
}
2029
#[test]
fn test_segmentation_to_mask_all_below() {
    // Every input value (64) is below the threshold, so no cell may differ from 0.
    let segmentation = Array3::from_elem((4, 4, 1), 64u8);
    let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();

    let has_set_cell = mask.iter().any(|&cell| cell != 0);
    assert!(!has_set_cell);
}
2036
#[test]
fn test_segmentation_to_mask_invalid_shape() {
    // A three-channel input is not of the form (H, W, 1) and must be rejected.
    let segmentation = Array3::from_elem((4, 4, 3), 128u8);
    let result = yolo_segmentation_to_mask(segmentation.view(), 128);

    assert!(result.is_err());

    // The error must be an InvalidShape whose message names the expected layout.
    let names_expected_layout = match result {
        Err(crate::DecoderError::InvalidShape(msg)) => msg.contains("(H, W, 1)"),
        _ => false,
    };
    assert!(names_expected_layout);
}
2048
2049    // ========================================================================
2050    // Tests for protobox / NORM_LIMIT regression
2051    // ========================================================================
2052
#[test]
fn test_protobox_clamps_edge_coordinates() {
    // A box whose far edges sit exactly at 1.0 must be accepted (OOB guard)
    // rather than panicking or erroring.
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let view = protos.view();
    let roi = BoundingBox {
        xmin: 0.5,
        ymin: 0.5,
        xmax: 1.0,
        ymax: 1.0,
    };

    let (cropped, _roi_norm) = match protobox(&view, &roi) {
        Ok(parts) => parts,
        Err(_) => panic!("protobox should accept xmax=1.0"),
    };

    // The crop must be non-empty in both spatial axes and keep all 4 channels.
    let dims = cropped.shape();
    assert!(dims[0] > 0);
    assert!(dims[1] > 0);
    assert_eq!(dims[2], 4);
}
2072
#[test]
fn test_protobox_rejects_wildly_out_of_range() {
    // Coordinates of 3.0 exceed NORM_LIMIT and must produce an error.
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let view = protos.view();
    let roi = BoundingBox {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 3.0,
        ymax: 3.0,
    };

    // Only an InvalidShape error mentioning "un-normalized" counts as a rejection.
    let rejected = match protobox(&view, &roi) {
        Err(crate::DecoderError::InvalidShape(msg)) => msg.contains("un-normalized"),
        _ => false,
    };
    assert!(rejected, "protobox should reject coords > NORM_LIMIT");
}
2090
#[test]
fn test_protobox_accepts_slightly_over_one() {
    // Coordinates of 1.5 are above 1.0 but within NORM_LIMIT (2.0), so the
    // call must succeed.
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let view = protos.view();
    let roi = BoundingBox {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 1.5,
        ymax: 1.5,
    };

    let (cropped, _roi_norm) = match protobox(&view, &roi) {
        Ok(parts) => parts,
        Err(_) => panic!("protobox should accept coords <= NORM_LIMIT (2.0)"),
    };

    // With the far edges beyond 1.0 the crop is expected to span the whole
    // 16x16 proto map (clamped to the boundary).
    let dims = cropped.shape();
    assert_eq!(dims[0], 16);
    assert_eq!(dims[1], 16);
}
2112
#[test]
fn test_segdet_float_proto_no_panic() {
    // Mirrors the YOLOv8n-seg layout: output0 = [116, 8400]
    // (4 box + 80 class + 32 mask coeff), output1 (protos) = [32, 160, 160].
    let num_proposals = 100; // enough to produce idx >= 32
    let num_classes = 80;
    let num_mask_coeffs = 32;
    let rows = 4 + num_classes + num_mask_coeffs; // 116

    // The tensor is [116, num_proposals] row-major: row 0=cx, 1=cy, 2=w, 3=h,
    // rows 4..84=class scores, rows 84..116=mask coefficients. Each of the
    // first five rows is filled with one constant so that every proposal is a
    // valid xywh box with a class-0 score above the threshold.
    let row_values = [320.0f32, 320.0, 50.0, 50.0, 0.9];
    let mut data = vec![0.0f32; rows * num_proposals];
    for (row, &value) in row_values.iter().enumerate() {
        let start = row * num_proposals;
        for cell in &mut data[start..start + num_proposals] {
            *cell = value;
        }
    }
    let boxes = ndarray::Array2::from_shape_vec((rows, num_proposals), data).unwrap();

    // Protos are supplied in HWC order (decoder.rs protos_to_hwc converts
    // before calling into these functions).
    let protos = ndarray::Array3::<f32>::zeros((160, 160, num_mask_coeffs));

    let mut output_boxes = Vec::with_capacity(300);

    // Regression: this call used to panic in mask_tensor.row(idx) with idx >= 32.
    let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
        boxes.view(),
        protos.view(),
        0.5,
        0.7,
        Some(Nms::default()),
        &mut output_boxes,
    );

    // Detections must exist (NMS collapses the many overlapping boxes), and
    // there must be one coefficient vector per surviving box.
    assert!(!output_boxes.is_empty());
    assert_eq!(proto_data.mask_coefficients.len(), output_boxes.len());

    // Each coefficient vector has one entry per proto channel.
    for coeffs in &proto_data.mask_coefficients {
        assert_eq!(coeffs.len(), num_mask_coeffs);
    }
}
2160}