Skip to main content

edgefirst_decoder/
yolo.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::fmt::Debug;
5
6use ndarray::{
7    parallel::prelude::{IntoParallelIterator, ParallelIterator},
8    s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
9};
10use num_traits::{AsPrimitive, Float, PrimInt, Signed};
11
12use crate::{
13    byte::{
14        nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
15        postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
16    },
17    configs::Nms,
18    dequant_detect_box,
19    float::{
20        nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
21        postprocess_boxes_float, postprocess_boxes_index_float,
22    },
23    BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoTensor,
24    Quantization, Segmentation, XYWH, XYXY,
25};
26
27/// Dispatches to the appropriate NMS function based on mode for float boxes.
28fn dispatch_nms_float(nms: Option<Nms>, iou: f32, boxes: Vec<DetectBox>) -> Vec<DetectBox> {
29    match nms {
30        Some(Nms::ClassAgnostic) => nms_float(iou, boxes),
31        Some(Nms::ClassAware) => nms_class_aware_float(iou, boxes),
32        None => boxes, // bypass NMS
33    }
34}
35
36/// Dispatches to the appropriate NMS function based on mode for float boxes
37/// with extra data.
38fn dispatch_nms_extra_float<E: Send + Sync>(
39    nms: Option<Nms>,
40    iou: f32,
41    boxes: Vec<(DetectBox, E)>,
42) -> Vec<(DetectBox, E)> {
43    match nms {
44        Some(Nms::ClassAgnostic) => nms_extra_float(iou, boxes),
45        Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, boxes),
46        None => boxes, // bypass NMS
47    }
48}
49
50/// Dispatches to the appropriate NMS function based on mode for quantized
51/// boxes.
52fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
53    nms: Option<Nms>,
54    iou: f32,
55    boxes: Vec<DetectBoxQuantized<SCORE>>,
56) -> Vec<DetectBoxQuantized<SCORE>> {
57    match nms {
58        Some(Nms::ClassAgnostic) => nms_int(iou, boxes),
59        Some(Nms::ClassAware) => nms_class_aware_int(iou, boxes),
60        None => boxes, // bypass NMS
61    }
62}
63
64/// Dispatches to the appropriate NMS function based on mode for quantized boxes
65/// with extra data.
66fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
67    nms: Option<Nms>,
68    iou: f32,
69    boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
70) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
71    match nms {
72        Some(Nms::ClassAgnostic) => nms_extra_int(iou, boxes),
73        Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, boxes),
74        None => boxes, // bypass NMS
75    }
76}
77
78/// Decodes YOLO detection outputs from quantized tensors into detection boxes.
79///
80/// Boxes are expected to be in XYWH format.
81///
82/// Expected shapes of inputs:
83/// - output: (4 + num_classes, num_boxes)
84pub fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
85    output: (ArrayView2<BOX>, Quantization),
86    score_threshold: f32,
87    iou_threshold: f32,
88    nms: Option<Nms>,
89    output_boxes: &mut Vec<DetectBox>,
90) where
91    f32: AsPrimitive<BOX>,
92{
93    impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
94}
95
96/// Decodes YOLO detection outputs from float tensors into detection boxes.
97///
98/// Boxes are expected to be in XYWH format.
99///
100/// Expected shapes of inputs:
101/// - output: (4 + num_classes, num_boxes)
102pub fn decode_yolo_det_float<T>(
103    output: ArrayView2<T>,
104    score_threshold: f32,
105    iou_threshold: f32,
106    nms: Option<Nms>,
107    output_boxes: &mut Vec<DetectBox>,
108) where
109    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
110    f32: AsPrimitive<T>,
111{
112    impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
113}
114
115/// Decodes YOLO detection and segmentation outputs from quantized tensors into
116/// detection boxes and segmentation masks.
117///
118/// Boxes are expected to be in XYWH format.
119///
120/// Expected shapes of inputs:
121/// - boxes: (4 + num_classes + num_protos, num_boxes)
122/// - protos: (proto_height, proto_width, num_protos)
123///
124/// # Errors
125/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
126pub fn decode_yolo_segdet_quant<
127    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
128    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
129>(
130    boxes: (ArrayView2<BOX>, Quantization),
131    protos: (ArrayView3<PROTO>, Quantization),
132    score_threshold: f32,
133    iou_threshold: f32,
134    nms: Option<Nms>,
135    output_boxes: &mut Vec<DetectBox>,
136    output_masks: &mut Vec<Segmentation>,
137) -> Result<(), crate::DecoderError>
138where
139    f32: AsPrimitive<BOX>,
140{
141    impl_yolo_segdet_quant::<XYWH, _, _>(
142        boxes,
143        protos,
144        score_threshold,
145        iou_threshold,
146        nms,
147        output_boxes,
148        output_masks,
149    )
150}
151
152/// Decodes YOLO detection and segmentation outputs from float tensors into
153/// detection boxes and segmentation masks.
154///
155/// Boxes are expected to be in XYWH format.
156///
157/// Expected shapes of inputs:
158/// - boxes: (4 + num_classes + num_protos, num_boxes)
159/// - protos: (proto_height, proto_width, num_protos)
160///
161/// # Errors
162/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
163pub fn decode_yolo_segdet_float<T>(
164    boxes: ArrayView2<T>,
165    protos: ArrayView3<T>,
166    score_threshold: f32,
167    iou_threshold: f32,
168    nms: Option<Nms>,
169    output_boxes: &mut Vec<DetectBox>,
170    output_masks: &mut Vec<Segmentation>,
171) -> Result<(), crate::DecoderError>
172where
173    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
174    f32: AsPrimitive<T>,
175{
176    impl_yolo_segdet_float::<XYWH, _, _>(
177        boxes,
178        protos,
179        score_threshold,
180        iou_threshold,
181        nms,
182        output_boxes,
183        output_masks,
184    )
185}
186
187/// Decodes YOLO split detection outputs from quantized tensors into detection
188/// boxes.
189///
190/// Boxes are expected to be in XYWH format.
191///
192/// Expected shapes of inputs:
193/// - boxes: (4, num_boxes)
194/// - scores: (num_classes, num_boxes)
195///
196/// # Panics
197/// Panics if shapes don't match the expected dimensions.
198pub fn decode_yolo_split_det_quant<
199    BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
200    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
201>(
202    boxes: (ArrayView2<BOX>, Quantization),
203    scores: (ArrayView2<SCORE>, Quantization),
204    score_threshold: f32,
205    iou_threshold: f32,
206    nms: Option<Nms>,
207    output_boxes: &mut Vec<DetectBox>,
208) where
209    f32: AsPrimitive<SCORE>,
210{
211    impl_yolo_split_quant::<XYWH, _, _>(
212        boxes,
213        scores,
214        score_threshold,
215        iou_threshold,
216        nms,
217        output_boxes,
218    );
219}
220
221/// Decodes YOLO split detection outputs from float tensors into detection
222/// boxes.
223///
224/// Boxes are expected to be in XYWH format.
225///
226/// Expected shapes of inputs:
227/// - boxes: (4, num_boxes)
228/// - scores: (num_classes, num_boxes)
229///
230/// # Panics
231/// Panics if shapes don't match the expected dimensions.
232pub fn decode_yolo_split_det_float<T>(
233    boxes: ArrayView2<T>,
234    scores: ArrayView2<T>,
235    score_threshold: f32,
236    iou_threshold: f32,
237    nms: Option<Nms>,
238    output_boxes: &mut Vec<DetectBox>,
239) where
240    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
241    f32: AsPrimitive<T>,
242{
243    impl_yolo_split_float::<XYWH, _, _>(
244        boxes,
245        scores,
246        score_threshold,
247        iou_threshold,
248        nms,
249        output_boxes,
250    );
251}
252
253/// Decodes YOLO split detection segmentation outputs from quantized tensors
254/// into detection boxes and segmentation masks.
255///
256/// Boxes are expected to be in XYWH format.
257///
258/// Expected shapes of inputs:
259/// - boxes_tensor: (4, num_boxes)
260/// - scores_tensor: (num_classes, num_boxes)
261/// - mask_tensor: (num_protos, num_boxes)
262/// - protos: (proto_height, proto_width, num_protos)
263///
264/// # Errors
265/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
266#[allow(clippy::too_many_arguments)]
267pub fn decode_yolo_split_segdet<
268    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
269    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
270    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
271    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
272>(
273    boxes: (ArrayView2<BOX>, Quantization),
274    scores: (ArrayView2<SCORE>, Quantization),
275    mask_coeff: (ArrayView2<MASK>, Quantization),
276    protos: (ArrayView3<PROTO>, Quantization),
277    score_threshold: f32,
278    iou_threshold: f32,
279    nms: Option<Nms>,
280    output_boxes: &mut Vec<DetectBox>,
281    output_masks: &mut Vec<Segmentation>,
282) -> Result<(), crate::DecoderError>
283where
284    f32: AsPrimitive<SCORE>,
285{
286    impl_yolo_split_segdet_quant::<XYWH, _, _, _, _>(
287        boxes,
288        scores,
289        mask_coeff,
290        protos,
291        score_threshold,
292        iou_threshold,
293        nms,
294        output_boxes,
295        output_masks,
296    )
297}
298
299/// Decodes YOLO split detection segmentation outputs from float tensors
300/// into detection boxes and segmentation masks.
301///
302/// Boxes are expected to be in XYWH format.
303///
304/// Expected shapes of inputs:
305/// - boxes_tensor: (4, num_boxes)
306/// - scores_tensor: (num_classes, num_boxes)
307/// - mask_tensor: (num_protos, num_boxes)
308/// - protos: (proto_height, proto_width, num_protos)
309///
310/// # Errors
311/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
312#[allow(clippy::too_many_arguments)]
313pub fn decode_yolo_split_segdet_float<T>(
314    boxes: ArrayView2<T>,
315    scores: ArrayView2<T>,
316    mask_coeff: ArrayView2<T>,
317    protos: ArrayView3<T>,
318    score_threshold: f32,
319    iou_threshold: f32,
320    nms: Option<Nms>,
321    output_boxes: &mut Vec<DetectBox>,
322    output_masks: &mut Vec<Segmentation>,
323) -> Result<(), crate::DecoderError>
324where
325    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
326    f32: AsPrimitive<T>,
327{
328    impl_yolo_split_segdet_float::<XYWH, _, _, _, _>(
329        boxes,
330        scores,
331        mask_coeff,
332        protos,
333        score_threshold,
334        iou_threshold,
335        nms,
336        output_boxes,
337        output_masks,
338    )
339}
340
341/// Decodes end-to-end YOLO detection outputs (post-NMS from model).
342/// Expects an array of shape `(6, N)`, where the first dimension (rows)
343/// corresponds to the 6 per-detection features
344/// `[x1, y1, x2, y2, conf, class]` and the second dimension (columns)
345/// indexes the `N` detections.
346/// Boxes are output directly without NMS (the model already applied NMS).
347///
348/// Coordinates may be normalized `[0, 1]` or absolute pixel values depending
349/// on the model configuration. The caller should check
350/// `decoder.normalized_boxes()` to determine which.
351///
352/// # Errors
353///
354/// Returns `DecoderError::InvalidShape` if `output` has fewer than 6 rows.
355pub fn decode_yolo_end_to_end_det_float<T>(
356    output: ArrayView2<T>,
357    score_threshold: f32,
358    output_boxes: &mut Vec<DetectBox>,
359) -> Result<(), crate::DecoderError>
360where
361    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
362    f32: AsPrimitive<T>,
363{
364    // Validate input shape: need at least 6 rows (x1, y1, x2, y2, conf, class)
365    if output.shape()[0] < 6 {
366        return Err(crate::DecoderError::InvalidShape(format!(
367            "End-to-end detection output requires at least 6 rows, got {}",
368            output.shape()[0]
369        )));
370    }
371
372    // Input shape: (6, N) -> transpose to (N, 4) for boxes and (N, 1) for scores
373    let boxes = output.slice(s![0..4, ..]).reversed_axes();
374    let scores = output.slice(s![4..5, ..]).reversed_axes();
375    let classes = output.slice(s![5, ..]);
376    let mut boxes =
377        postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
378    boxes.truncate(output_boxes.capacity());
379    output_boxes.clear();
380    for (mut b, i) in boxes.into_iter() {
381        b.label = classes[i].as_() as usize;
382        output_boxes.push(b);
383    }
384    // No NMS — model output is already post-NMS
385    Ok(())
386}
387
388/// Decodes end-to-end YOLO detection + segmentation outputs (post-NMS from
389/// model).
390///
391/// Input shapes:
392/// - detection: (6 + num_protos, N) where rows are [x1, y1, x2, y2, conf,
393///   class, mask_coeff_0, ..., mask_coeff_31]
394/// - protos: (proto_height, proto_width, num_protos)
395///
396/// Boxes are output directly without NMS (model already applied NMS).
397/// Coordinates may be normalized [0,1] or pixel values depending on model
398/// config.
399///
400/// # Errors
401///
402/// Returns `DecoderError::InvalidShape` if:
403/// - output has fewer than 7 rows (6 base + at least 1 mask coefficient)
404/// - protos shape doesn't match mask coefficients count
405pub fn decode_yolo_end_to_end_segdet_float<T>(
406    output: ArrayView2<T>,
407    protos: ArrayView3<T>,
408    score_threshold: f32,
409    output_boxes: &mut Vec<DetectBox>,
410    output_masks: &mut Vec<crate::Segmentation>,
411) -> Result<(), crate::DecoderError>
412where
413    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
414    f32: AsPrimitive<T>,
415{
416    // Validate input shape: need at least 7 rows (6 base + at least 1 mask coeff)
417    if output.shape()[0] < 7 {
418        return Err(crate::DecoderError::InvalidShape(format!(
419            "End-to-end segdet output requires at least 7 rows, got {}",
420            output.shape()[0]
421        )));
422    }
423
424    let num_mask_coeffs = output.shape()[0] - 6;
425    let num_protos = protos.shape()[2];
426    if num_mask_coeffs != num_protos {
427        return Err(crate::DecoderError::InvalidShape(format!(
428            "Mask coefficients count ({}) doesn't match protos count ({})",
429            num_mask_coeffs, num_protos
430        )));
431    }
432
433    // Input shape: (6+num_protos, N) -> transpose for postprocessing
434    let boxes = output.slice(s![0..4, ..]).reversed_axes();
435    let scores = output.slice(s![4..5, ..]).reversed_axes();
436    let classes = output.slice(s![5, ..]);
437    let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
438    let mut boxes =
439        postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
440    boxes.truncate(output_boxes.capacity());
441
442    for (b, ind) in &mut boxes {
443        b.label = classes[*ind].as_() as usize;
444    }
445
446    // No NMS — model output is already post-NMS
447
448    let boxes = decode_segdet_f32(boxes, mask_coeff, protos)?;
449
450    output_boxes.clear();
451    output_masks.clear();
452    for (b, m) in boxes.into_iter() {
453        output_boxes.push(b);
454        output_masks.push(Segmentation {
455            xmin: b.bbox.xmin,
456            ymin: b.bbox.ymin,
457            xmax: b.bbox.xmax,
458            ymax: b.bbox.ymax,
459            segmentation: m,
460        });
461    }
462    Ok(())
463}
464
465/// Decodes split end-to-end YOLO detection outputs (post-NMS from model).
466///
467/// Input shapes (after batch dim removed):
468/// - boxes: (N, 4) — xyxy pixel coordinates
469/// - scores: (N, 1) — confidence of the top class
470/// - classes: (N, 1) — class index of the top class
471///
472/// Boxes are output directly without NMS (model already applied NMS).
473pub fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
474    boxes: ArrayView2<T>,
475    scores: ArrayView2<T>,
476    classes: ArrayView2<T>,
477    score_threshold: f32,
478    output_boxes: &mut Vec<DetectBox>,
479) -> Result<(), crate::DecoderError> {
480    let n = boxes.shape()[0];
481    if boxes.shape()[1] != 4 {
482        return Err(crate::DecoderError::InvalidShape(format!(
483            "Split end-to-end boxes must have 4 columns, got {}",
484            boxes.shape()[1]
485        )));
486    }
487    output_boxes.clear();
488    for i in 0..n {
489        let score: f32 = scores[[i, 0]].as_();
490        if score < score_threshold {
491            continue;
492        }
493        if output_boxes.len() >= output_boxes.capacity() {
494            break;
495        }
496        output_boxes.push(DetectBox {
497            bbox: BoundingBox {
498                xmin: boxes[[i, 0]].as_(),
499                ymin: boxes[[i, 1]].as_(),
500                xmax: boxes[[i, 2]].as_(),
501                ymax: boxes[[i, 3]].as_(),
502            },
503            score,
504            label: classes[[i, 0]].as_() as usize,
505        });
506    }
507    Ok(())
508}
509
510/// Decodes split end-to-end YOLO detection + segmentation outputs.
511///
512/// Input shapes (after batch dim removed):
513/// - boxes: (N, 4) — xyxy pixel coordinates
514/// - scores: (N, 1) — confidence
515/// - classes: (N, 1) — class index
516/// - mask_coeff: (N, num_protos) — mask coefficients per detection
517/// - protos: (proto_h, proto_w, num_protos) — prototype masks
518#[allow(clippy::too_many_arguments)]
519pub fn decode_yolo_split_end_to_end_segdet_float<T>(
520    boxes: ArrayView2<T>,
521    scores: ArrayView2<T>,
522    classes: ArrayView2<T>,
523    mask_coeff: ArrayView2<T>,
524    protos: ArrayView3<T>,
525    score_threshold: f32,
526    output_boxes: &mut Vec<DetectBox>,
527    output_masks: &mut Vec<crate::Segmentation>,
528) -> Result<(), crate::DecoderError>
529where
530    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
531    f32: AsPrimitive<T>,
532{
533    let n = boxes.shape()[0];
534    if boxes.shape()[1] != 4 {
535        return Err(crate::DecoderError::InvalidShape(format!(
536            "Split end-to-end boxes must have 4 columns, got {}",
537            boxes.shape()[1]
538        )));
539    }
540
541    // Collect qualifying detections with their indices
542    let mut qualifying: Vec<(DetectBox, usize)> = Vec::with_capacity(output_boxes.capacity());
543    for i in 0..n {
544        let score: f32 = scores[[i, 0]].as_();
545        if score < score_threshold {
546            continue;
547        }
548        if qualifying.len() >= output_boxes.capacity() {
549            break;
550        }
551        qualifying.push((
552            DetectBox {
553                bbox: BoundingBox {
554                    xmin: boxes[[i, 0]].as_(),
555                    ymin: boxes[[i, 1]].as_(),
556                    xmax: boxes[[i, 2]].as_(),
557                    ymax: boxes[[i, 3]].as_(),
558                },
559                score,
560                label: classes[[i, 0]].as_() as usize,
561            },
562            i,
563        ));
564    }
565
566    // Process masks using existing infrastructure
567    let result = decode_segdet_f32(qualifying, mask_coeff, protos)?;
568
569    output_boxes.clear();
570    output_masks.clear();
571    for (b, m) in result.into_iter() {
572        output_masks.push(crate::Segmentation {
573            xmin: b.bbox.xmin,
574            ymin: b.bbox.ymin,
575            xmax: b.bbox.xmax,
576            ymax: b.bbox.ymax,
577            segmentation: m,
578        });
579        output_boxes.push(b);
580    }
581    Ok(())
582}
583
584/// Internal implementation of YOLO decoding for quantized tensors.
585///
586/// Expected shapes of inputs:
587/// - output: (4 + num_classes, num_boxes)
588pub fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
589    output: (ArrayView2<T>, Quantization),
590    score_threshold: f32,
591    iou_threshold: f32,
592    nms: Option<Nms>,
593    output_boxes: &mut Vec<DetectBox>,
594) where
595    f32: AsPrimitive<T>,
596{
597    let (boxes, quant_boxes) = output;
598    let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
599
600    let boxes = {
601        let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
602        postprocess_boxes_quant::<B, _, _>(
603            score_threshold,
604            boxes_tensor,
605            scores_tensor,
606            quant_boxes,
607        )
608    };
609
610    let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
611    let len = output_boxes.capacity().min(boxes.len());
612    output_boxes.clear();
613    for b in boxes.iter().take(len) {
614        output_boxes.push(dequant_detect_box(b, quant_boxes));
615    }
616}
617
618/// Internal implementation of YOLO decoding for float tensors.
619///
620/// Expected shapes of inputs:
621/// - output: (4 + num_classes, num_boxes)
622pub fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
623    output: ArrayView2<T>,
624    score_threshold: f32,
625    iou_threshold: f32,
626    nms: Option<Nms>,
627    output_boxes: &mut Vec<DetectBox>,
628) where
629    f32: AsPrimitive<T>,
630{
631    let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
632    let boxes =
633        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
634    let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
635    let len = output_boxes.capacity().min(boxes.len());
636    output_boxes.clear();
637    for b in boxes.into_iter().take(len) {
638        output_boxes.push(b);
639    }
640}
641
642/// Internal implementation of YOLO split detection decoding for quantized
643/// tensors.
644///
645/// Expected shapes of inputs:
646/// - boxes: (4, num_boxes)
647/// - scores: (num_classes, num_boxes)
648///
649/// # Panics
650/// Panics if shapes don't match the expected dimensions.
651pub fn impl_yolo_split_quant<
652    B: BBoxTypeTrait,
653    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
654    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
655>(
656    boxes: (ArrayView2<BOX>, Quantization),
657    scores: (ArrayView2<SCORE>, Quantization),
658    score_threshold: f32,
659    iou_threshold: f32,
660    nms: Option<Nms>,
661    output_boxes: &mut Vec<DetectBox>,
662) where
663    f32: AsPrimitive<SCORE>,
664{
665    let (boxes_tensor, quant_boxes) = boxes;
666    let (scores_tensor, quant_scores) = scores;
667
668    let boxes_tensor = boxes_tensor.reversed_axes();
669    let scores_tensor = scores_tensor.reversed_axes();
670
671    let boxes = {
672        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
673        postprocess_boxes_quant::<B, _, _>(
674            score_threshold,
675            boxes_tensor,
676            scores_tensor,
677            quant_boxes,
678        )
679    };
680
681    let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
682    let len = output_boxes.capacity().min(boxes.len());
683    output_boxes.clear();
684    for b in boxes.iter().take(len) {
685        output_boxes.push(dequant_detect_box(b, quant_scores));
686    }
687}
688
689/// Internal implementation of YOLO split detection decoding for float tensors.
690///
691/// Expected shapes of inputs:
692/// - boxes: (4, num_boxes)
693/// - scores: (num_classes, num_boxes)
694///
695/// # Panics
696/// Panics if shapes don't match the expected dimensions.
697pub fn impl_yolo_split_float<
698    B: BBoxTypeTrait,
699    BOX: Float + AsPrimitive<f32> + Send + Sync,
700    SCORE: Float + AsPrimitive<f32> + Send + Sync,
701>(
702    boxes_tensor: ArrayView2<BOX>,
703    scores_tensor: ArrayView2<SCORE>,
704    score_threshold: f32,
705    iou_threshold: f32,
706    nms: Option<Nms>,
707    output_boxes: &mut Vec<DetectBox>,
708) where
709    f32: AsPrimitive<SCORE>,
710{
711    let boxes_tensor = boxes_tensor.reversed_axes();
712    let scores_tensor = scores_tensor.reversed_axes();
713    let boxes =
714        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
715    let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
716    let len = output_boxes.capacity().min(boxes.len());
717    output_boxes.clear();
718    for b in boxes.into_iter().take(len) {
719        output_boxes.push(b);
720    }
721}
722
723/// Internal implementation of YOLO detection segmentation decoding for
724/// quantized tensors.
725///
726/// Expected shapes of inputs:
727/// - boxes: (4 + num_classes + num_protos, num_boxes)
728/// - protos: (proto_height, proto_width, num_protos)
729///
730/// # Errors
731/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
732pub fn impl_yolo_segdet_quant<
733    B: BBoxTypeTrait,
734    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
735    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
736>(
737    boxes: (ArrayView2<BOX>, Quantization),
738    protos: (ArrayView3<PROTO>, Quantization),
739    score_threshold: f32,
740    iou_threshold: f32,
741    nms: Option<Nms>,
742    output_boxes: &mut Vec<DetectBox>,
743    output_masks: &mut Vec<Segmentation>,
744) -> Result<(), crate::DecoderError>
745where
746    f32: AsPrimitive<BOX>,
747{
748    let (boxes, quant_boxes) = boxes;
749    let num_protos = protos.0.dim().2;
750    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
751
752    let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
753        (boxes_tensor.reversed_axes(), quant_boxes),
754        (scores_tensor.reversed_axes(), quant_boxes),
755        score_threshold,
756        iou_threshold,
757        nms,
758        output_boxes.capacity(),
759    );
760
761    impl_yolo_split_segdet_quant_process_masks::<_, _>(
762        boxes,
763        (mask_tensor.reversed_axes(), quant_boxes),
764        protos,
765        output_boxes,
766        output_masks,
767    )
768}
769
770/// Internal implementation of YOLO detection segmentation decoding for
771/// float tensors.
772///
773/// Expected shapes of inputs:
774/// - boxes: (4 + num_classes + num_protos, num_boxes)
775/// - protos: (proto_height, proto_width, num_protos)
776///
777/// # Panics
778/// Panics if shapes don't match the expected dimensions.
779pub fn impl_yolo_segdet_float<
780    B: BBoxTypeTrait,
781    BOX: Float + AsPrimitive<f32> + Send + Sync,
782    PROTO: Float + AsPrimitive<f32> + Send + Sync,
783>(
784    boxes: ArrayView2<BOX>,
785    protos: ArrayView3<PROTO>,
786    score_threshold: f32,
787    iou_threshold: f32,
788    nms: Option<Nms>,
789    output_boxes: &mut Vec<DetectBox>,
790    output_masks: &mut Vec<Segmentation>,
791) -> Result<(), crate::DecoderError>
792where
793    f32: AsPrimitive<BOX>,
794{
795    let num_protos = protos.dim().2;
796    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
797
798    let boxes = postprocess_boxes_index_float::<B, _, _>(
799        score_threshold.as_(),
800        boxes_tensor,
801        scores_tensor,
802    );
803    let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
804    boxes.truncate(output_boxes.capacity());
805    let boxes = decode_segdet_f32(boxes, mask_tensor, protos)?;
806    output_boxes.clear();
807    output_masks.clear();
808    for (b, m) in boxes.into_iter() {
809        output_boxes.push(b);
810        output_masks.push(Segmentation {
811            xmin: b.bbox.xmin,
812            ymin: b.bbox.ymin,
813            xmax: b.bbox.xmax,
814            ymax: b.bbox.ymax,
815            segmentation: m,
816        });
817    }
818    Ok(())
819}
820
821pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
822    B: BBoxTypeTrait,
823    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
824    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
825>(
826    boxes: (ArrayView2<BOX>, Quantization),
827    scores: (ArrayView2<SCORE>, Quantization),
828    score_threshold: f32,
829    iou_threshold: f32,
830    nms: Option<Nms>,
831    max_boxes: usize,
832) -> Vec<(DetectBox, usize)>
833where
834    f32: AsPrimitive<SCORE>,
835{
836    let (boxes_tensor, quant_boxes) = boxes;
837    let (scores_tensor, quant_scores) = scores;
838
839    let boxes_tensor = boxes_tensor.reversed_axes();
840    let scores_tensor = scores_tensor.reversed_axes();
841
842    let boxes = {
843        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
844        postprocess_boxes_index_quant::<B, _, _>(
845            score_threshold,
846            boxes_tensor,
847            scores_tensor,
848            quant_boxes,
849        )
850    };
851    let mut boxes = dispatch_nms_extra_int(nms, iou_threshold, boxes);
852    boxes.truncate(max_boxes);
853    boxes
854        .into_iter()
855        .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
856        .collect()
857}
858
859pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
860    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
861    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
862>(
863    boxes: Vec<(DetectBox, usize)>,
864    mask_coeff: (ArrayView2<MASK>, Quantization),
865    protos: (ArrayView3<PROTO>, Quantization),
866    output_boxes: &mut Vec<DetectBox>,
867    output_masks: &mut Vec<Segmentation>,
868) -> Result<(), crate::DecoderError> {
869    let (masks, quant_masks) = mask_coeff;
870    let (protos, quant_protos) = protos;
871
872    let masks = masks.reversed_axes();
873
874    let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
875    output_boxes.clear();
876    output_masks.clear();
877    for (b, m) in boxes.into_iter() {
878        output_boxes.push(b);
879        output_masks.push(Segmentation {
880            xmin: b.bbox.xmin,
881            ymin: b.bbox.ymin,
882            xmax: b.bbox.xmax,
883            ymax: b.bbox.ymax,
884            segmentation: m,
885        });
886    }
887    Ok(())
888}
889
890#[allow(clippy::too_many_arguments)]
891/// Internal implementation of YOLO split detection segmentation decoding for
892/// quantized tensors.
893///
894/// Expected shapes of inputs:
895/// - boxes_tensor: (4, num_boxes)
896/// - scores_tensor: (num_classes, num_boxes)
897/// - mask_tensor: (num_protos, num_boxes)
898/// - protos: (proto_height, proto_width, num_protos)
899///
900/// # Errors
901/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
902pub fn impl_yolo_split_segdet_quant<
903    B: BBoxTypeTrait,
904    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
905    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
906    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
907    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
908>(
909    boxes: (ArrayView2<BOX>, Quantization),
910    scores: (ArrayView2<SCORE>, Quantization),
911    mask_coeff: (ArrayView2<MASK>, Quantization),
912    protos: (ArrayView3<PROTO>, Quantization),
913    score_threshold: f32,
914    iou_threshold: f32,
915    nms: Option<Nms>,
916    output_boxes: &mut Vec<DetectBox>,
917    output_masks: &mut Vec<Segmentation>,
918) -> Result<(), crate::DecoderError>
919where
920    f32: AsPrimitive<SCORE>,
921{
922    let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
923        boxes,
924        scores,
925        score_threshold,
926        iou_threshold,
927        nms,
928        output_boxes.capacity(),
929    );
930
931    impl_yolo_split_segdet_quant_process_masks(
932        boxes,
933        mask_coeff,
934        protos,
935        output_boxes,
936        output_masks,
937    )
938}
939
940#[allow(clippy::too_many_arguments)]
941/// Internal implementation of YOLO split detection segmentation decoding for
942/// float tensors.
943///
944/// Expected shapes of inputs:
945/// - boxes_tensor: (4, num_boxes)
946/// - scores_tensor: (num_classes, num_boxes)
947/// - mask_tensor: (num_protos, num_boxes)
948/// - protos: (proto_height, proto_width, num_protos)
949///
950/// # Errors
951/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
952pub fn impl_yolo_split_segdet_float<
953    B: BBoxTypeTrait,
954    BOX: Float + AsPrimitive<f32> + Send + Sync,
955    SCORE: Float + AsPrimitive<f32> + Send + Sync,
956    MASK: Float + AsPrimitive<f32> + Send + Sync,
957    PROTO: Float + AsPrimitive<f32> + Send + Sync,
958>(
959    boxes_tensor: ArrayView2<BOX>,
960    scores_tensor: ArrayView2<SCORE>,
961    mask_tensor: ArrayView2<MASK>,
962    protos: ArrayView3<PROTO>,
963    score_threshold: f32,
964    iou_threshold: f32,
965    nms: Option<Nms>,
966    output_boxes: &mut Vec<DetectBox>,
967    output_masks: &mut Vec<Segmentation>,
968) -> Result<(), crate::DecoderError>
969where
970    f32: AsPrimitive<SCORE>,
971{
972    let boxes_tensor = boxes_tensor.reversed_axes();
973    let scores_tensor = scores_tensor.reversed_axes();
974    let mask_tensor = mask_tensor.reversed_axes();
975
976    let boxes = postprocess_boxes_index_float::<B, _, _>(
977        score_threshold.as_(),
978        boxes_tensor,
979        scores_tensor,
980    );
981    let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
982    boxes.truncate(output_boxes.capacity());
983    let boxes = decode_segdet_f32(boxes, mask_tensor, protos)?;
984    output_boxes.clear();
985    output_masks.clear();
986    for (b, m) in boxes.into_iter() {
987        output_boxes.push(b);
988        output_masks.push(Segmentation {
989            xmin: b.bbox.xmin,
990            ymin: b.bbox.ymin,
991            xmax: b.bbox.xmax,
992            ymax: b.bbox.ymax,
993            segmentation: m,
994        });
995    }
996    Ok(())
997}
998
999// ---------------------------------------------------------------------------
1000// Proto-extraction variants: return ProtoData instead of materialized masks
1001// ---------------------------------------------------------------------------
1002
1003/// Proto-extraction variant of `impl_yolo_segdet_quant`.
1004/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1005pub fn impl_yolo_segdet_quant_proto<
1006    B: BBoxTypeTrait,
1007    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1008    PROTO: PrimInt
1009        + AsPrimitive<i64>
1010        + AsPrimitive<i128>
1011        + AsPrimitive<f32>
1012        + AsPrimitive<i8>
1013        + Send
1014        + Sync,
1015>(
1016    boxes: (ArrayView2<BOX>, Quantization),
1017    protos: (ArrayView3<PROTO>, Quantization),
1018    score_threshold: f32,
1019    iou_threshold: f32,
1020    nms: Option<Nms>,
1021    output_boxes: &mut Vec<DetectBox>,
1022) -> ProtoData
1023where
1024    f32: AsPrimitive<BOX>,
1025{
1026    let (boxes_arr, quant_boxes) = boxes;
1027    let (protos_arr, quant_protos) = protos;
1028    let num_protos = protos_arr.dim().2;
1029
1030    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1031
1032    let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1033        (boxes_tensor.reversed_axes(), quant_boxes),
1034        (scores_tensor.reversed_axes(), quant_boxes),
1035        score_threshold,
1036        iou_threshold,
1037        nms,
1038        output_boxes.capacity(),
1039    );
1040
1041    extract_proto_data_quant(
1042        det_indices,
1043        mask_tensor,
1044        quant_boxes,
1045        protos_arr,
1046        quant_protos,
1047        output_boxes,
1048    )
1049}
1050
1051/// Proto-extraction variant of `impl_yolo_segdet_float`.
1052/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1053pub fn impl_yolo_segdet_float_proto<
1054    B: BBoxTypeTrait,
1055    BOX: Float + AsPrimitive<f32> + Send + Sync,
1056    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1057>(
1058    boxes: ArrayView2<BOX>,
1059    protos: ArrayView3<PROTO>,
1060    score_threshold: f32,
1061    iou_threshold: f32,
1062    nms: Option<Nms>,
1063    output_boxes: &mut Vec<DetectBox>,
1064) -> ProtoData
1065where
1066    f32: AsPrimitive<BOX>,
1067{
1068    let num_protos = protos.dim().2;
1069    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1070
1071    let det_indices = postprocess_boxes_index_float::<B, _, _>(
1072        score_threshold.as_(),
1073        boxes_tensor,
1074        scores_tensor,
1075    );
1076    let mut det_indices = dispatch_nms_extra_float(nms, iou_threshold, det_indices);
1077    det_indices.truncate(output_boxes.capacity());
1078
1079    extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1080}
1081
1082/// Proto-extraction variant of `impl_yolo_split_segdet_quant`.
1083/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1084#[allow(clippy::too_many_arguments)]
1085pub fn impl_yolo_split_segdet_quant_proto<
1086    B: BBoxTypeTrait,
1087    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1088    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1089    MASK: PrimInt + AsPrimitive<f32> + Send + Sync,
1090    PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync,
1091>(
1092    boxes: (ArrayView2<BOX>, Quantization),
1093    scores: (ArrayView2<SCORE>, Quantization),
1094    mask_coeff: (ArrayView2<MASK>, Quantization),
1095    protos: (ArrayView3<PROTO>, Quantization),
1096    score_threshold: f32,
1097    iou_threshold: f32,
1098    nms: Option<Nms>,
1099    output_boxes: &mut Vec<DetectBox>,
1100) -> ProtoData
1101where
1102    f32: AsPrimitive<SCORE>,
1103{
1104    let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1105        boxes,
1106        scores,
1107        score_threshold,
1108        iou_threshold,
1109        nms,
1110        output_boxes.capacity(),
1111    );
1112
1113    let (masks, quant_masks) = mask_coeff;
1114    let masks = masks.reversed_axes();
1115    let (protos_arr, quant_protos) = protos;
1116
1117    extract_proto_data_quant(
1118        det_indices,
1119        masks,
1120        quant_masks,
1121        protos_arr,
1122        quant_protos,
1123        output_boxes,
1124    )
1125}
1126
1127/// Proto-extraction variant of `impl_yolo_split_segdet_float`.
1128/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1129#[allow(clippy::too_many_arguments)]
1130pub fn impl_yolo_split_segdet_float_proto<
1131    B: BBoxTypeTrait,
1132    BOX: Float + AsPrimitive<f32> + Send + Sync,
1133    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1134    MASK: Float + AsPrimitive<f32> + Send + Sync,
1135    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1136>(
1137    boxes_tensor: ArrayView2<BOX>,
1138    scores_tensor: ArrayView2<SCORE>,
1139    mask_tensor: ArrayView2<MASK>,
1140    protos: ArrayView3<PROTO>,
1141    score_threshold: f32,
1142    iou_threshold: f32,
1143    nms: Option<Nms>,
1144    output_boxes: &mut Vec<DetectBox>,
1145) -> ProtoData
1146where
1147    f32: AsPrimitive<SCORE>,
1148{
1149    let boxes_tensor = boxes_tensor.reversed_axes();
1150    let scores_tensor = scores_tensor.reversed_axes();
1151    let mask_tensor = mask_tensor.reversed_axes();
1152
1153    let det_indices = postprocess_boxes_index_float::<B, _, _>(
1154        score_threshold.as_(),
1155        boxes_tensor,
1156        scores_tensor,
1157    );
1158    let mut det_indices = dispatch_nms_extra_float(nms, iou_threshold, det_indices);
1159    det_indices.truncate(output_boxes.capacity());
1160
1161    extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1162}
1163
1164/// Proto-extraction variant of `decode_yolo_end_to_end_segdet_float`.
1165pub fn decode_yolo_end_to_end_segdet_float_proto<T>(
1166    output: ArrayView2<T>,
1167    protos: ArrayView3<T>,
1168    score_threshold: f32,
1169    output_boxes: &mut Vec<DetectBox>,
1170) -> Result<ProtoData, crate::DecoderError>
1171where
1172    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1173    f32: AsPrimitive<T>,
1174{
1175    if output.shape()[0] < 7 {
1176        return Err(crate::DecoderError::InvalidShape(format!(
1177            "End-to-end segdet output requires at least 7 rows, got {}",
1178            output.shape()[0]
1179        )));
1180    }
1181
1182    let num_mask_coeffs = output.shape()[0] - 6;
1183    let num_protos = protos.shape()[2];
1184    if num_mask_coeffs != num_protos {
1185        return Err(crate::DecoderError::InvalidShape(format!(
1186            "Mask coefficients count ({}) doesn't match protos count ({})",
1187            num_mask_coeffs, num_protos
1188        )));
1189    }
1190
1191    let boxes = output.slice(s![0..4, ..]).reversed_axes();
1192    let scores = output.slice(s![4..5, ..]).reversed_axes();
1193    let classes = output.slice(s![5, ..]);
1194    let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
1195    let mut det_indices =
1196        postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
1197    det_indices.truncate(output_boxes.capacity());
1198
1199    for (b, ind) in &mut det_indices {
1200        b.label = classes[*ind].as_() as usize;
1201    }
1202
1203    Ok(extract_proto_data_float(
1204        det_indices,
1205        mask_coeff,
1206        protos,
1207        output_boxes,
1208    ))
1209}
1210
1211/// Proto-extraction variant of `decode_yolo_split_end_to_end_segdet_float`.
1212#[allow(clippy::too_many_arguments)]
1213pub fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1214    boxes: ArrayView2<T>,
1215    scores: ArrayView2<T>,
1216    classes: ArrayView2<T>,
1217    mask_coeff: ArrayView2<T>,
1218    protos: ArrayView3<T>,
1219    score_threshold: f32,
1220    output_boxes: &mut Vec<DetectBox>,
1221) -> Result<ProtoData, crate::DecoderError>
1222where
1223    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1224    f32: AsPrimitive<T>,
1225{
1226    let n = boxes.shape()[0];
1227    if boxes.shape()[1] != 4 {
1228        return Err(crate::DecoderError::InvalidShape(format!(
1229            "Split end-to-end boxes must have 4 columns, got {}",
1230            boxes.shape()[1]
1231        )));
1232    }
1233
1234    let mut qualifying: Vec<(DetectBox, usize)> = Vec::with_capacity(output_boxes.capacity());
1235    for i in 0..n {
1236        let score: f32 = scores[[i, 0]].as_();
1237        if score < score_threshold {
1238            continue;
1239        }
1240        if qualifying.len() >= output_boxes.capacity() {
1241            break;
1242        }
1243        qualifying.push((
1244            DetectBox {
1245                bbox: BoundingBox {
1246                    xmin: boxes[[i, 0]].as_(),
1247                    ymin: boxes[[i, 1]].as_(),
1248                    xmax: boxes[[i, 2]].as_(),
1249                    ymax: boxes[[i, 3]].as_(),
1250                },
1251                score,
1252                label: classes[[i, 0]].as_() as usize,
1253            },
1254            i,
1255        ));
1256    }
1257
1258    Ok(extract_proto_data_float(
1259        qualifying,
1260        mask_coeff,
1261        protos,
1262        output_boxes,
1263    ))
1264}
1265
1266/// Helper: extract ProtoData from float mask coefficients + protos.
1267fn extract_proto_data_float<
1268    MASK: Float + AsPrimitive<f32> + Send + Sync,
1269    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1270>(
1271    det_indices: Vec<(DetectBox, usize)>,
1272    mask_tensor: ArrayView2<MASK>,
1273    protos: ArrayView3<PROTO>,
1274    output_boxes: &mut Vec<DetectBox>,
1275) -> ProtoData {
1276    let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1277    output_boxes.clear();
1278    for (det, idx) in det_indices {
1279        output_boxes.push(det);
1280        let row = mask_tensor.row(idx);
1281        mask_coefficients.push(row.iter().map(|v| v.as_()).collect());
1282    }
1283    let protos_f32 = protos.map(|v| v.as_());
1284    ProtoData {
1285        mask_coefficients,
1286        protos: ProtoTensor::Float(protos_f32),
1287    }
1288}
1289
1290/// Helper: extract ProtoData from quantized mask coefficients + protos.
1291///
1292/// Dequantizes mask coefficients to f32 (small — per-detection) but keeps
1293/// protos in raw int8 form wrapped in `ProtoTensor::Quantized` so the GPU
1294/// shader can dequantize per-texel without CPU overhead.
1295pub(crate) fn extract_proto_data_quant<
1296    MASK: PrimInt + AsPrimitive<f32> + Send + Sync,
1297    PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync,
1298>(
1299    det_indices: Vec<(DetectBox, usize)>,
1300    mask_tensor: ArrayView2<MASK>,
1301    quant_masks: Quantization,
1302    protos: ArrayView3<PROTO>,
1303    quant_protos: Quantization,
1304    output_boxes: &mut Vec<DetectBox>,
1305) -> ProtoData {
1306    let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1307    output_boxes.clear();
1308    for (det, idx) in det_indices {
1309        output_boxes.push(det);
1310        let row = mask_tensor.row(idx);
1311        mask_coefficients.push(
1312            row.iter()
1313                .map(|v| (v.as_() - quant_masks.zero_point as f32) * quant_masks.scale)
1314                .collect(),
1315        );
1316    }
1317    // Keep protos in raw int8 — GPU shader will dequantize per-texel.
1318    let protos_i8 = protos.map(|v| {
1319        let v_i8: i8 = v.as_();
1320        v_i8
1321    });
1322    ProtoData {
1323        mask_coefficients,
1324        protos: ProtoTensor::Quantized {
1325            protos: protos_i8,
1326            quantization: quant_protos,
1327        },
1328    }
1329}
1330
1331fn postprocess_yolo<'a, T>(
1332    output: &'a ArrayView2<'_, T>,
1333) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1334    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1335    let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1336    (boxes_tensor, scores_tensor)
1337}
1338
1339fn postprocess_yolo_seg<'a, T>(
1340    output: &'a ArrayView2<'_, T>,
1341    num_protos: usize,
1342) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1343    assert!(
1344        output.shape()[0] > num_protos + 4,
1345        "Output shape is too short: {} <= {} + 4",
1346        output.shape()[0],
1347        num_protos
1348    );
1349    let num_classes = output.shape()[0] - 4 - num_protos;
1350    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1351    let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1352    let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1353    (boxes_tensor, scores_tensor, mask_tensor)
1354}
1355
1356fn decode_segdet_f32<
1357    MASK: Float + AsPrimitive<f32> + Send + Sync,
1358    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1359>(
1360    boxes: Vec<(DetectBox, usize)>,
1361    masks: ArrayView2<MASK>,
1362    protos: ArrayView3<PROTO>,
1363) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1364    if boxes.is_empty() {
1365        return Ok(Vec::new());
1366    }
1367    if masks.shape()[1] != protos.shape()[2] {
1368        return Err(crate::DecoderError::InvalidShape(format!(
1369            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1370            masks.shape()[1],
1371            protos.shape()[2],
1372        )));
1373    }
1374    boxes
1375        .into_par_iter()
1376        .map(|mut b| {
1377            let ind = b.1;
1378            let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1379            b.0.bbox = roi;
1380            Ok((b.0, make_segmentation(masks.row(ind), protos.view())))
1381        })
1382        .collect()
1383}
1384
1385pub(crate) fn decode_segdet_quant<
1386    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1387    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1388>(
1389    boxes: Vec<(DetectBox, usize)>,
1390    masks: ArrayView2<MASK>,
1391    protos: ArrayView3<PROTO>,
1392    quant_masks: Quantization,
1393    quant_protos: Quantization,
1394) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1395    if boxes.is_empty() {
1396        return Ok(Vec::new());
1397    }
1398    if masks.shape()[1] != protos.shape()[2] {
1399        return Err(crate::DecoderError::InvalidShape(format!(
1400            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1401            masks.shape()[1],
1402            protos.shape()[2],
1403        )));
1404    }
1405
1406    let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; // 32 protos is 2^5
1407    boxes
1408        .into_iter()
1409        .map(|mut b| {
1410            let i = b.1;
1411            let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1412            b.0.bbox = roi;
1413            let seg = match total_bits {
1414                0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1415                    masks.row(i),
1416                    protos.view(),
1417                    quant_masks,
1418                    quant_protos,
1419                ),
1420                65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1421                    masks.row(i),
1422                    protos.view(),
1423                    quant_masks,
1424                    quant_protos,
1425                ),
1426                _ => {
1427                    return Err(crate::DecoderError::NotSupported(format!(
1428                        "Unsupported bit width ({total_bits}) for segmentation computation"
1429                    )));
1430                }
1431            };
1432            Ok((b.0, seg))
1433        })
1434        .collect()
1435}
1436
1437fn protobox<'a, T>(
1438    protos: &'a ArrayView3<T>,
1439    roi: &BoundingBox,
1440) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1441    let width = protos.dim().1 as f32;
1442    let height = protos.dim().0 as f32;
1443
1444    // Detect un-normalized bounding boxes (pixel-space coordinates).
1445    // protobox expects normalized coordinates in [0, 1]. ONNX models output
1446    // pixel-space boxes (e.g. 0-640) which must be normalized before calling
1447    // decode(). Without this check, pixel-space coordinates silently clamp to
1448    // the proto boundary, producing empty (0, 0, C) masks for every detection.
1449    //
1450    // The limit is set to 2.0 (not 1.01) because YOLO models legitimately
1451    // predict coordinates slightly > 1.0 for objects near frame edges.
1452    // Any value > 2.0 is clearly pixel-space (even the smallest practical
1453    // model input of 32×32 would produce coordinates >> 2.0).
1454    const NORM_LIMIT: f32 = 2.0;
1455    if roi.xmin > NORM_LIMIT
1456        || roi.ymin > NORM_LIMIT
1457        || roi.xmax > NORM_LIMIT
1458        || roi.ymax > NORM_LIMIT
1459    {
1460        return Err(crate::DecoderError::InvalidShape(format!(
1461            "Bounding box coordinates appear un-normalized (pixel-space). \
1462             Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
1463             ONNX models output pixel-space boxes — normalize them by dividing by \
1464             the input dimensions before calling decode().",
1465            roi.xmin, roi.ymin, roi.xmax, roi.ymax,
1466        )));
1467    }
1468
1469    let roi = [
1470        (roi.xmin * width).clamp(0.0, width) as usize,
1471        (roi.ymin * height).clamp(0.0, height) as usize,
1472        (roi.xmax * width).clamp(0.0, width).ceil() as usize,
1473        (roi.ymax * height).clamp(0.0, height).ceil() as usize,
1474    ];
1475
1476    let roi_norm = [
1477        roi[0] as f32 / width,
1478        roi[1] as f32 / height,
1479        roi[2] as f32 / width,
1480        roi[3] as f32 / height,
1481    ]
1482    .into();
1483
1484    let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
1485
1486    Ok((cropped, roi_norm))
1487}
1488
1489/// Compute a single instance segmentation mask from mask coefficients and
1490/// proto maps (float path).
1491///
1492/// Computes `sigmoid(coefficients · protos)` and maps to `[0, 255]`.
1493/// Returns an `(H, W, 1)` u8 array.
1494fn make_segmentation<
1495    MASK: Float + AsPrimitive<f32> + Send + Sync,
1496    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1497>(
1498    mask: ArrayView1<MASK>,
1499    protos: ArrayView3<PROTO>,
1500) -> Array3<u8> {
1501    let shape = protos.shape();
1502
1503    // Safe to unwrap since the shapes will always be compatible
1504    let mask = mask.to_shape((1, mask.len())).unwrap();
1505    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1506    let protos = protos.reversed_axes();
1507    let mask = mask.map(|x| x.as_());
1508    let protos = protos.map(|x| x.as_());
1509
1510    // Safe to unwrap since the shapes will always be compatible
1511    let mask = mask
1512        .dot(&protos)
1513        .into_shape_with_order((shape[0], shape[1], 1))
1514        .unwrap();
1515
1516    mask.map(|x| {
1517        let sigmoid = 1.0 / (1.0 + (-*x).exp());
1518        (sigmoid * 255.0).round() as u8
1519    })
1520}
1521
1522/// Compute a single instance segmentation mask from quantized mask
1523/// coefficients and proto maps.
1524///
1525/// Dequantizes both inputs (subtracting zero-points), computes the dot
1526/// product, applies sigmoid, and maps to `[0, 255]`.
1527/// Returns an `(H, W, 1)` u8 array.
1528fn make_segmentation_quant<
1529    MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
1530    PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
1531    DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
1532>(
1533    mask: ArrayView1<MASK>,
1534    protos: ArrayView3<PROTO>,
1535    quant_masks: Quantization,
1536    quant_protos: Quantization,
1537) -> Array3<u8>
1538where
1539    i32: AsPrimitive<DEST>,
1540    f32: AsPrimitive<DEST>,
1541{
1542    let shape = protos.shape();
1543
1544    // Safe to unwrap since the shapes will always be compatible
1545    let mask = mask.to_shape((1, mask.len())).unwrap();
1546
1547    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1548    let protos = protos.reversed_axes();
1549
1550    let zp = quant_masks.zero_point.as_();
1551
1552    let mask = mask.mapv(|x| x.as_() - zp);
1553
1554    let zp = quant_protos.zero_point.as_();
1555    let protos = protos.mapv(|x| x.as_() - zp);
1556
1557    // Safe to unwrap since the shapes will always be compatible
1558    let segmentation = mask
1559        .dot(&protos)
1560        .into_shape_with_order((shape[0], shape[1], 1))
1561        .unwrap();
1562
1563    let combined_scale = quant_masks.scale * quant_protos.scale;
1564    segmentation.map(|x| {
1565        let val: f32 = (*x).as_() * combined_scale;
1566        let sigmoid = 1.0 / (1.0 + (-val).exp());
1567        (sigmoid * 255.0).round() as u8
1568    })
1569}
1570
1571/// Converts Yolo Instance Segmentation into a 2D mask.
1572///
1573/// The input segmentation is expected to have shape (H, W, 1).
1574///
1575/// The output mask will have shape (H, W), with values 0 or 1 based on the
1576/// threshold.
1577///
1578/// # Errors
1579///
1580/// Returns `DecoderError::InvalidShape` if the input segmentation does not
1581/// have shape (H, W, 1).
1582pub fn yolo_segmentation_to_mask(
1583    segmentation: ArrayView3<u8>,
1584    threshold: u8,
1585) -> Result<Array2<u8>, crate::DecoderError> {
1586    if segmentation.shape()[2] != 1 {
1587        return Err(crate::DecoderError::InvalidShape(format!(
1588            "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
1589            segmentation.shape()[2]
1590        )));
1591    }
1592    Ok(segmentation
1593        .slice(s![.., .., 0])
1594        .map(|x| if *x >= threshold { 1 } else { 0 }))
1595}
1596
1597#[cfg(test)]
1598#[cfg_attr(coverage_nightly, coverage(off))]
1599mod tests {
1600    use super::*;
1601    use ndarray::Array2;
1602
1603    // ========================================================================
1604    // Tests for decode_yolo_end_to_end_det_float
1605    // ========================================================================
1606
1607    #[test]
1608    fn test_end_to_end_det_basic_filtering() {
1609        // Create synthetic end-to-end detection output: (6, N) where rows are
1610        // [x1, y1, x2, y2, conf, class]
1611        // 3 detections: one above threshold, two below
1612        let data: Vec<f32> = vec![
1613            // Detection 0: high score (0.9)
1614            0.1, 0.2, 0.3, // x1 values
1615            0.1, 0.2, 0.3, // y1 values
1616            0.5, 0.6, 0.7, // x2 values
1617            0.5, 0.6, 0.7, // y2 values
1618            0.9, 0.1, 0.2, // confidence scores
1619            0.0, 1.0, 2.0, // class indices
1620        ];
1621        let output = Array2::from_shape_vec((6, 3), data).unwrap();
1622
1623        let mut boxes = Vec::with_capacity(10);
1624        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1625
1626        // Only 1 detection should pass threshold of 0.5
1627        assert_eq!(boxes.len(), 1);
1628        assert_eq!(boxes[0].label, 0);
1629        assert!((boxes[0].score - 0.9).abs() < 0.01);
1630        assert!((boxes[0].bbox.xmin - 0.1).abs() < 0.01);
1631        assert!((boxes[0].bbox.ymin - 0.1).abs() < 0.01);
1632        assert!((boxes[0].bbox.xmax - 0.5).abs() < 0.01);
1633        assert!((boxes[0].bbox.ymax - 0.5).abs() < 0.01);
1634    }
1635
1636    #[test]
1637    fn test_end_to_end_det_all_pass_threshold() {
1638        // All detections above threshold
1639        let data: Vec<f32> = vec![
1640            10.0, 20.0, // x1
1641            10.0, 20.0, // y1
1642            50.0, 60.0, // x2
1643            50.0, 60.0, // y2
1644            0.8, 0.7, // conf (both above 0.5)
1645            1.0, 2.0, // class
1646        ];
1647        let output = Array2::from_shape_vec((6, 2), data).unwrap();
1648
1649        let mut boxes = Vec::with_capacity(10);
1650        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1651
1652        assert_eq!(boxes.len(), 2);
1653        assert_eq!(boxes[0].label, 1);
1654        assert_eq!(boxes[1].label, 2);
1655    }
1656
1657    #[test]
1658    fn test_end_to_end_det_none_pass_threshold() {
1659        // All detections below threshold
1660        let data: Vec<f32> = vec![
1661            10.0, 20.0, // x1
1662            10.0, 20.0, // y1
1663            50.0, 60.0, // x2
1664            50.0, 60.0, // y2
1665            0.1, 0.2, // conf (both below 0.5)
1666            1.0, 2.0, // class
1667        ];
1668        let output = Array2::from_shape_vec((6, 2), data).unwrap();
1669
1670        let mut boxes = Vec::with_capacity(10);
1671        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1672
1673        assert_eq!(boxes.len(), 0);
1674    }
1675
1676    #[test]
1677    fn test_end_to_end_det_capacity_limit() {
1678        // Test that output is truncated to capacity
1679        let data: Vec<f32> = vec![
1680            0.1, 0.2, 0.3, 0.4, 0.5, // x1
1681            0.1, 0.2, 0.3, 0.4, 0.5, // y1
1682            0.5, 0.6, 0.7, 0.8, 0.9, // x2
1683            0.5, 0.6, 0.7, 0.8, 0.9, // y2
1684            0.9, 0.9, 0.9, 0.9, 0.9, // conf (all pass)
1685            0.0, 1.0, 2.0, 3.0, 4.0, // class
1686        ];
1687        let output = Array2::from_shape_vec((6, 5), data).unwrap();
1688
1689        let mut boxes = Vec::with_capacity(2); // Only allow 2 boxes
1690        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1691
1692        assert_eq!(boxes.len(), 2);
1693    }
1694
1695    #[test]
1696    fn test_end_to_end_det_empty_output() {
1697        // Test with zero detections
1698        let output = Array2::<f32>::zeros((6, 0));
1699
1700        let mut boxes = Vec::with_capacity(10);
1701        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1702
1703        assert_eq!(boxes.len(), 0);
1704    }
1705
1706    #[test]
1707    fn test_end_to_end_det_pixel_coordinates() {
1708        // Test with pixel coordinates (non-normalized)
1709        let data: Vec<f32> = vec![
1710            100.0, // x1
1711            200.0, // y1
1712            300.0, // x2
1713            400.0, // y2
1714            0.95,  // conf
1715            5.0,   // class
1716        ];
1717        let output = Array2::from_shape_vec((6, 1), data).unwrap();
1718
1719        let mut boxes = Vec::with_capacity(10);
1720        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1721
1722        assert_eq!(boxes.len(), 1);
1723        assert_eq!(boxes[0].label, 5);
1724        assert!((boxes[0].bbox.xmin - 100.0).abs() < 0.01);
1725        assert!((boxes[0].bbox.ymin - 200.0).abs() < 0.01);
1726        assert!((boxes[0].bbox.xmax - 300.0).abs() < 0.01);
1727        assert!((boxes[0].bbox.ymax - 400.0).abs() < 0.01);
1728    }
1729
1730    #[test]
1731    fn test_end_to_end_det_invalid_shape() {
1732        // Test with too few rows (needs at least 6)
1733        let output = Array2::<f32>::zeros((5, 3));
1734
1735        let mut boxes = Vec::with_capacity(10);
1736        let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);
1737
1738        assert!(result.is_err());
1739        assert!(matches!(
1740            result,
1741            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
1742        ));
1743    }
1744
1745    // ========================================================================
1746    // Tests for decode_yolo_end_to_end_segdet_float
1747    // ========================================================================
1748
1749    #[test]
1750    fn test_end_to_end_segdet_basic() {
1751        // Create synthetic segdet output: (6 + num_protos, N)
1752        // Detection format: [x1, y1, x2, y2, conf, class, mask_coeff_0..31]
1753        let num_protos = 32;
1754        let num_detections = 2;
1755        let num_features = 6 + num_protos;
1756
1757        // Build detection tensor
1758        let mut data = vec![0.0f32; num_features * num_detections];
1759        // Detection 0: passes threshold
1760        data[0] = 0.1; // x1[0]
1761        data[1] = 0.5; // x1[1]
1762        data[num_detections] = 0.1; // y1[0]
1763        data[num_detections + 1] = 0.5; // y1[1]
1764        data[2 * num_detections] = 0.4; // x2[0]
1765        data[2 * num_detections + 1] = 0.9; // x2[1]
1766        data[3 * num_detections] = 0.4; // y2[0]
1767        data[3 * num_detections + 1] = 0.9; // y2[1]
1768        data[4 * num_detections] = 0.9; // conf[0] - passes
1769        data[4 * num_detections + 1] = 0.3; // conf[1] - fails
1770        data[5 * num_detections] = 1.0; // class[0]
1771        data[5 * num_detections + 1] = 2.0; // class[1]
1772                                            // Fill mask coefficients with small values
1773        for i in 6..num_features {
1774            data[i * num_detections] = 0.1;
1775            data[i * num_detections + 1] = 0.1;
1776        }
1777
1778        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
1779
1780        // Create protos tensor: (proto_height, proto_width, num_protos)
1781        let protos = Array3::<f32>::zeros((16, 16, num_protos));
1782
1783        let mut boxes = Vec::with_capacity(10);
1784        let mut masks = Vec::with_capacity(10);
1785        decode_yolo_end_to_end_segdet_float(
1786            output.view(),
1787            protos.view(),
1788            0.5,
1789            &mut boxes,
1790            &mut masks,
1791        )
1792        .unwrap();
1793
1794        // Only detection 0 should pass
1795        assert_eq!(boxes.len(), 1);
1796        assert_eq!(masks.len(), 1);
1797        assert_eq!(boxes[0].label, 1);
1798        assert!((boxes[0].score - 0.9).abs() < 0.01);
1799    }
1800
1801    #[test]
1802    fn test_end_to_end_segdet_mask_coordinates() {
1803        // Test that mask coordinates match box coordinates
1804        let num_protos = 32;
1805        let num_features = 6 + num_protos;
1806
1807        let mut data = vec![0.0f32; num_features];
1808        data[0] = 0.2; // x1
1809        data[1] = 0.2; // y1
1810        data[2] = 0.8; // x2
1811        data[3] = 0.8; // y2
1812        data[4] = 0.95; // conf
1813        data[5] = 3.0; // class
1814
1815        let output = Array2::from_shape_vec((num_features, 1), data).unwrap();
1816        let protos = Array3::<f32>::zeros((16, 16, num_protos));
1817
1818        let mut boxes = Vec::with_capacity(10);
1819        let mut masks = Vec::with_capacity(10);
1820        decode_yolo_end_to_end_segdet_float(
1821            output.view(),
1822            protos.view(),
1823            0.5,
1824            &mut boxes,
1825            &mut masks,
1826        )
1827        .unwrap();
1828
1829        assert_eq!(boxes.len(), 1);
1830        assert_eq!(masks.len(), 1);
1831
1832        // Verify mask coordinates match box coordinates
1833        assert!((masks[0].xmin - boxes[0].bbox.xmin).abs() < 0.01);
1834        assert!((masks[0].ymin - boxes[0].bbox.ymin).abs() < 0.01);
1835        assert!((masks[0].xmax - boxes[0].bbox.xmax).abs() < 0.01);
1836        assert!((masks[0].ymax - boxes[0].bbox.ymax).abs() < 0.01);
1837    }
1838
1839    #[test]
1840    fn test_end_to_end_segdet_empty_output() {
1841        let num_protos = 32;
1842        let output = Array2::<f32>::zeros((6 + num_protos, 0));
1843        let protos = Array3::<f32>::zeros((16, 16, num_protos));
1844
1845        let mut boxes = Vec::with_capacity(10);
1846        let mut masks = Vec::with_capacity(10);
1847        decode_yolo_end_to_end_segdet_float(
1848            output.view(),
1849            protos.view(),
1850            0.5,
1851            &mut boxes,
1852            &mut masks,
1853        )
1854        .unwrap();
1855
1856        assert_eq!(boxes.len(), 0);
1857        assert_eq!(masks.len(), 0);
1858    }
1859
1860    #[test]
1861    fn test_end_to_end_segdet_capacity_limit() {
1862        let num_protos = 32;
1863        let num_detections = 5;
1864        let num_features = 6 + num_protos;
1865
1866        let mut data = vec![0.0f32; num_features * num_detections];
1867        // All detections pass threshold
1868        for i in 0..num_detections {
1869            data[i] = 0.1 * (i as f32); // x1
1870            data[num_detections + i] = 0.1 * (i as f32); // y1
1871            data[2 * num_detections + i] = 0.1 * (i as f32) + 0.2; // x2
1872            data[3 * num_detections + i] = 0.1 * (i as f32) + 0.2; // y2
1873            data[4 * num_detections + i] = 0.9; // conf
1874            data[5 * num_detections + i] = i as f32; // class
1875        }
1876
1877        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
1878        let protos = Array3::<f32>::zeros((16, 16, num_protos));
1879
1880        let mut boxes = Vec::with_capacity(2); // Limit to 2
1881        let mut masks = Vec::with_capacity(2);
1882        decode_yolo_end_to_end_segdet_float(
1883            output.view(),
1884            protos.view(),
1885            0.5,
1886            &mut boxes,
1887            &mut masks,
1888        )
1889        .unwrap();
1890
1891        assert_eq!(boxes.len(), 2);
1892        assert_eq!(masks.len(), 2);
1893    }
1894
1895    #[test]
1896    fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
1897        // Test with too few rows (needs at least 7: 6 base + 1 mask coeff)
1898        let output = Array2::<f32>::zeros((6, 3));
1899        let protos = Array3::<f32>::zeros((16, 16, 32));
1900
1901        let mut boxes = Vec::with_capacity(10);
1902        let mut masks = Vec::with_capacity(10);
1903        let result = decode_yolo_end_to_end_segdet_float(
1904            output.view(),
1905            protos.view(),
1906            0.5,
1907            &mut boxes,
1908            &mut masks,
1909        );
1910
1911        assert!(result.is_err());
1912        assert!(matches!(
1913            result,
1914            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
1915        ));
1916    }
1917
1918    #[test]
1919    fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
1920        // Test with mismatched mask coefficients and protos count
1921        let num_protos = 32;
1922        let output = Array2::<f32>::zeros((6 + 16, 3)); // 16 mask coeffs
1923        let protos = Array3::<f32>::zeros((16, 16, num_protos)); // 32 protos
1924
1925        let mut boxes = Vec::with_capacity(10);
1926        let mut masks = Vec::with_capacity(10);
1927        let result = decode_yolo_end_to_end_segdet_float(
1928            output.view(),
1929            protos.view(),
1930            0.5,
1931            &mut boxes,
1932            &mut masks,
1933        );
1934
1935        assert!(result.is_err());
1936        assert!(matches!(
1937            result,
1938            Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
1939        ));
1940    }
1941
1942    // ========================================================================
1943    // Tests for yolo_segmentation_to_mask
1944    // ========================================================================
1945
1946    #[test]
1947    fn test_segmentation_to_mask_basic() {
1948        // Create a 4x4x1 segmentation with values above and below threshold
1949        let data: Vec<u8> = vec![
1950            100, 200, 50, 150, // row 0
1951            10, 255, 128, 64, // row 1
1952            0, 127, 128, 255, // row 2
1953            64, 64, 192, 192, // row 3
1954        ];
1955        let segmentation = Array3::from_shape_vec((4, 4, 1), data).unwrap();
1956
1957        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
1958
1959        // Values >= 128 should be 1, others 0
1960        assert_eq!(mask[[0, 0]], 0); // 100 < 128
1961        assert_eq!(mask[[0, 1]], 1); // 200 >= 128
1962        assert_eq!(mask[[0, 2]], 0); // 50 < 128
1963        assert_eq!(mask[[0, 3]], 1); // 150 >= 128
1964        assert_eq!(mask[[1, 1]], 1); // 255 >= 128
1965        assert_eq!(mask[[1, 2]], 1); // 128 >= 128
1966        assert_eq!(mask[[2, 0]], 0); // 0 < 128
1967        assert_eq!(mask[[2, 1]], 0); // 127 < 128
1968    }
1969
1970    #[test]
1971    fn test_segmentation_to_mask_all_above() {
1972        let segmentation = Array3::from_elem((4, 4, 1), 255u8);
1973        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
1974        assert!(mask.iter().all(|&x| x == 1));
1975    }
1976
1977    #[test]
1978    fn test_segmentation_to_mask_all_below() {
1979        let segmentation = Array3::from_elem((4, 4, 1), 64u8);
1980        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
1981        assert!(mask.iter().all(|&x| x == 0));
1982    }
1983
1984    #[test]
1985    fn test_segmentation_to_mask_invalid_shape() {
1986        let segmentation = Array3::from_elem((4, 4, 3), 128u8);
1987        let result = yolo_segmentation_to_mask(segmentation.view(), 128);
1988
1989        assert!(result.is_err());
1990        assert!(matches!(
1991            result,
1992            Err(crate::DecoderError::InvalidShape(s)) if s.contains("(H, W, 1)")
1993        ));
1994    }
1995
1996    // ========================================================================
1997    // Tests for protobox / NORM_LIMIT regression
1998    // ========================================================================
1999
2000    #[test]
2001    fn test_protobox_clamps_edge_coordinates() {
2002        // bbox with xmax=1.0 should not panic (OOB guard)
2003        let protos = Array3::<f32>::zeros((16, 16, 4));
2004        let view = protos.view();
2005        let roi = BoundingBox {
2006            xmin: 0.5,
2007            ymin: 0.5,
2008            xmax: 1.0,
2009            ymax: 1.0,
2010        };
2011        let result = protobox(&view, &roi);
2012        assert!(result.is_ok(), "protobox should accept xmax=1.0");
2013        let (cropped, _roi_norm) = result.unwrap();
2014        // Cropped region must have non-zero spatial dimensions
2015        assert!(cropped.shape()[0] > 0);
2016        assert!(cropped.shape()[1] > 0);
2017        assert_eq!(cropped.shape()[2], 4);
2018    }
2019
2020    #[test]
2021    fn test_protobox_rejects_wildly_out_of_range() {
2022        // bbox with coords > NORM_LIMIT (e.g. 3.0) returns error
2023        let protos = Array3::<f32>::zeros((16, 16, 4));
2024        let view = protos.view();
2025        let roi = BoundingBox {
2026            xmin: 0.0,
2027            ymin: 0.0,
2028            xmax: 3.0,
2029            ymax: 3.0,
2030        };
2031        let result = protobox(&view, &roi);
2032        assert!(
2033            matches!(result, Err(crate::DecoderError::InvalidShape(s)) if s.contains("un-normalized")),
2034            "protobox should reject coords > NORM_LIMIT"
2035        );
2036    }
2037
2038    #[test]
2039    fn test_protobox_accepts_slightly_over_one() {
2040        // bbox with coords at 1.5 (within NORM_LIMIT=2.0) succeeds
2041        let protos = Array3::<f32>::zeros((16, 16, 4));
2042        let view = protos.view();
2043        let roi = BoundingBox {
2044            xmin: 0.0,
2045            ymin: 0.0,
2046            xmax: 1.5,
2047            ymax: 1.5,
2048        };
2049        let result = protobox(&view, &roi);
2050        assert!(
2051            result.is_ok(),
2052            "protobox should accept coords <= NORM_LIMIT (2.0)"
2053        );
2054        let (cropped, _roi_norm) = result.unwrap();
2055        // Entire proto map should be selected when coords > 1.0 (clamped to boundary)
2056        assert_eq!(cropped.shape()[0], 16);
2057        assert_eq!(cropped.shape()[1], 16);
2058    }
2059
2060    #[test]
2061    fn test_segdet_float_proto_no_panic() {
2062        // Simulates YOLOv8n-seg: output0 = [116, 8400] (4 box + 80 class + 32 mask coeff)
2063        // output1 (protos) = [32, 160, 160]
2064        let num_proposals = 100; // enough to produce idx >= 32
2065        let num_classes = 80;
2066        let num_mask_coeffs = 32;
2067        let rows = 4 + num_classes + num_mask_coeffs; // 116
2068
2069        // Fill boxes with valid xywh data so some detections pass the threshold.
2070        // Layout is [116, num_proposals] row-major: row 0=cx, 1=cy, 2=w, 3=h,
2071        // rows 4..84=class scores, rows 84..116=mask coefficients.
2072        let mut data = vec![0.0f32; rows * num_proposals];
2073        for i in 0..num_proposals {
2074            let row = |r: usize| r * num_proposals + i;
2075            data[row(0)] = 320.0; // cx
2076            data[row(1)] = 320.0; // cy
2077            data[row(2)] = 50.0; // w
2078            data[row(3)] = 50.0; // h
2079            data[row(4)] = 0.9; // class-0 score
2080        }
2081        let boxes = ndarray::Array2::from_shape_vec((rows, num_proposals), data).unwrap();
2082
2083        // Protos must be in HWC order (decoder.rs protos_to_hwc converts
2084        // before calling into these functions).
2085        let protos = ndarray::Array3::<f32>::zeros((160, 160, num_mask_coeffs));
2086
2087        let mut output_boxes = Vec::with_capacity(300);
2088
2089        // This panicked before fix: mask_tensor.row(idx) with idx >= 32
2090        let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
2091            boxes.view(),
2092            protos.view(),
2093            0.5,
2094            0.7,
2095            Some(Nms::default()),
2096            &mut output_boxes,
2097        );
2098
2099        // Should produce detections (NMS will collapse many overlapping boxes)
2100        assert!(!output_boxes.is_empty());
2101        assert_eq!(proto_data.mask_coefficients.len(), output_boxes.len());
2102        // Each mask coefficient vector should have 32 elements
2103        for coeffs in &proto_data.mask_coefficients {
2104            assert_eq!(coeffs.len(), num_mask_coeffs);
2105        }
2106    }
2107}