Skip to main content

edgefirst_decoder/
yolo.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::fmt::Debug;
5
6use ndarray::{
7    parallel::prelude::{IntoParallelIterator, ParallelIterator},
8    s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
9};
10use num_traits::{AsPrimitive, Float, PrimInt, Signed};
11use rayon::slice::ParallelSliceMut;
12
13use crate::{
14    byte::{
15        nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
16        postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
17    },
18    configs::Nms,
19    dequant_detect_box,
20    float::{
21        nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
22        postprocess_boxes_float, postprocess_boxes_index_float,
23    },
24    BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoTensor,
25    Quantization, Segmentation, XYWH, XYXY,
26};
27
/// Maximum number of above-threshold candidates fed to NMS.
///
/// At very low score thresholds (e.g., t=0.01 on YOLOv8 with
/// 8400 anchors × 80 classes), the number of survivors approaches
/// the full 672 000-entry score grid. NMS is O(n²) and the
/// downstream mask matmul runs once per survivor, so an
/// unbounded set produces minutes-per-frame decode times.
///
/// `MAX_NMS_CANDIDATES` matches the Ultralytics `max_nms` default
/// and is applied as a top-K-by-score truncation immediately
/// before NMS. Candidates ranked below the cap are silently
/// dropped — at the score thresholds where the cap activates, the
/// bottom of the candidate list is dominated by noise that NMS
/// would discard anyway.
pub(crate) const MAX_NMS_CANDIDATES: usize = 30_000;
43
44/// Truncate `boxes` to the highest-scoring `MAX_NMS_CANDIDATES`
45/// entries in-place when the input exceeds the cap. No-op
46/// otherwise. The sort is unstable and parallel — order among
47/// equal-score boxes is not guaranteed.
48fn truncate_to_top_k_by_score<E: Send>(boxes: &mut Vec<(DetectBox, E)>) {
49    if boxes.len() > MAX_NMS_CANDIDATES {
50        boxes.par_sort_unstable_by(|a, b| b.0.score.total_cmp(&a.0.score));
51        boxes.truncate(MAX_NMS_CANDIDATES);
52    }
53}
54
55/// Quantized counterpart of [`truncate_to_top_k_by_score`]. Sorts on
56/// the raw quantized score (which preserves order under monotonic
57/// dequantization).
58fn truncate_to_top_k_by_score_quant<S: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send>(
59    boxes: &mut Vec<(DetectBoxQuantized<S>, E)>,
60) {
61    if boxes.len() > MAX_NMS_CANDIDATES {
62        boxes.par_sort_unstable_by(|a, b| b.0.score.cmp(&a.0.score));
63        boxes.truncate(MAX_NMS_CANDIDATES);
64    }
65}
66
67/// Dispatches to the appropriate NMS function based on mode for float boxes.
68fn dispatch_nms_float(nms: Option<Nms>, iou: f32, boxes: Vec<DetectBox>) -> Vec<DetectBox> {
69    match nms {
70        Some(Nms::ClassAgnostic) => nms_float(iou, boxes),
71        Some(Nms::ClassAware) => nms_class_aware_float(iou, boxes),
72        None => boxes, // bypass NMS
73    }
74}
75
76/// Dispatches to the appropriate NMS function based on mode for float boxes
77/// with extra data.
78pub(super) fn dispatch_nms_extra_float<E: Send + Sync>(
79    nms: Option<Nms>,
80    iou: f32,
81    boxes: Vec<(DetectBox, E)>,
82) -> Vec<(DetectBox, E)> {
83    match nms {
84        Some(Nms::ClassAgnostic) => nms_extra_float(iou, boxes),
85        Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, boxes),
86        None => boxes, // bypass NMS
87    }
88}
89
90/// Dispatches to the appropriate NMS function based on mode for quantized
91/// boxes.
92fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
93    nms: Option<Nms>,
94    iou: f32,
95    boxes: Vec<DetectBoxQuantized<SCORE>>,
96) -> Vec<DetectBoxQuantized<SCORE>> {
97    match nms {
98        Some(Nms::ClassAgnostic) => nms_int(iou, boxes),
99        Some(Nms::ClassAware) => nms_class_aware_int(iou, boxes),
100        None => boxes, // bypass NMS
101    }
102}
103
104/// Dispatches to the appropriate NMS function based on mode for quantized boxes
105/// with extra data.
106fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
107    nms: Option<Nms>,
108    iou: f32,
109    boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
110) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
111    match nms {
112        Some(Nms::ClassAgnostic) => nms_extra_int(iou, boxes),
113        Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, boxes),
114        None => boxes, // bypass NMS
115    }
116}
117
118/// Decodes YOLO detection outputs from quantized tensors into detection boxes.
119///
120/// Boxes are expected to be in XYWH format.
121///
122/// Expected shapes of inputs:
123/// - output: (4 + num_classes, num_boxes)
124pub fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
125    output: (ArrayView2<BOX>, Quantization),
126    score_threshold: f32,
127    iou_threshold: f32,
128    nms: Option<Nms>,
129    output_boxes: &mut Vec<DetectBox>,
130) where
131    f32: AsPrimitive<BOX>,
132{
133    impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
134}
135
136/// Decodes YOLO detection outputs from float tensors into detection boxes.
137///
138/// Boxes are expected to be in XYWH format.
139///
140/// Expected shapes of inputs:
141/// - output: (4 + num_classes, num_boxes)
142pub fn decode_yolo_det_float<T>(
143    output: ArrayView2<T>,
144    score_threshold: f32,
145    iou_threshold: f32,
146    nms: Option<Nms>,
147    output_boxes: &mut Vec<DetectBox>,
148) where
149    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
150    f32: AsPrimitive<T>,
151{
152    impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
153}
154
155/// Decodes YOLO detection and segmentation outputs from quantized tensors into
156/// detection boxes and segmentation masks.
157///
158/// Boxes are expected to be in XYWH format.
159///
160/// Expected shapes of inputs:
161/// - boxes: (4 + num_classes + num_protos, num_boxes)
162/// - protos: (proto_height, proto_width, num_protos)
163///
164/// # Errors
165/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
166pub fn decode_yolo_segdet_quant<
167    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
168    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
169>(
170    boxes: (ArrayView2<BOX>, Quantization),
171    protos: (ArrayView3<PROTO>, Quantization),
172    score_threshold: f32,
173    iou_threshold: f32,
174    nms: Option<Nms>,
175    output_boxes: &mut Vec<DetectBox>,
176    output_masks: &mut Vec<Segmentation>,
177) -> Result<(), crate::DecoderError>
178where
179    f32: AsPrimitive<BOX>,
180{
181    impl_yolo_segdet_quant::<XYWH, _, _>(
182        boxes,
183        protos,
184        score_threshold,
185        iou_threshold,
186        nms,
187        output_boxes,
188        output_masks,
189    )
190}
191
192/// Decodes YOLO detection and segmentation outputs from float tensors into
193/// detection boxes and segmentation masks.
194///
195/// Boxes are expected to be in XYWH format.
196///
197/// Expected shapes of inputs:
198/// - boxes: (4 + num_classes + num_protos, num_boxes)
199/// - protos: (proto_height, proto_width, num_protos)
200///
201/// # Errors
202/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
203pub fn decode_yolo_segdet_float<T>(
204    boxes: ArrayView2<T>,
205    protos: ArrayView3<T>,
206    score_threshold: f32,
207    iou_threshold: f32,
208    nms: Option<Nms>,
209    output_boxes: &mut Vec<DetectBox>,
210    output_masks: &mut Vec<Segmentation>,
211) -> Result<(), crate::DecoderError>
212where
213    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
214    f32: AsPrimitive<T>,
215{
216    impl_yolo_segdet_float::<XYWH, _, _>(
217        boxes,
218        protos,
219        score_threshold,
220        iou_threshold,
221        nms,
222        output_boxes,
223        output_masks,
224    )
225}
226
227/// Decodes YOLO split detection outputs from quantized tensors into detection
228/// boxes.
229///
230/// Boxes are expected to be in XYWH format.
231///
232/// Expected shapes of inputs:
233/// - boxes: (4, num_boxes)
234/// - scores: (num_classes, num_boxes)
235///
236/// # Panics
237/// Panics if shapes don't match the expected dimensions.
238pub fn decode_yolo_split_det_quant<
239    BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
240    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
241>(
242    boxes: (ArrayView2<BOX>, Quantization),
243    scores: (ArrayView2<SCORE>, Quantization),
244    score_threshold: f32,
245    iou_threshold: f32,
246    nms: Option<Nms>,
247    output_boxes: &mut Vec<DetectBox>,
248) where
249    f32: AsPrimitive<SCORE>,
250{
251    impl_yolo_split_quant::<XYWH, _, _>(
252        boxes,
253        scores,
254        score_threshold,
255        iou_threshold,
256        nms,
257        output_boxes,
258    );
259}
260
261/// Decodes YOLO split detection outputs from float tensors into detection
262/// boxes.
263///
264/// Boxes are expected to be in XYWH format.
265///
266/// Expected shapes of inputs:
267/// - boxes: (4, num_boxes)
268/// - scores: (num_classes, num_boxes)
269///
270/// # Panics
271/// Panics if shapes don't match the expected dimensions.
272pub fn decode_yolo_split_det_float<T>(
273    boxes: ArrayView2<T>,
274    scores: ArrayView2<T>,
275    score_threshold: f32,
276    iou_threshold: f32,
277    nms: Option<Nms>,
278    output_boxes: &mut Vec<DetectBox>,
279) where
280    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
281    f32: AsPrimitive<T>,
282{
283    impl_yolo_split_float::<XYWH, _, _>(
284        boxes,
285        scores,
286        score_threshold,
287        iou_threshold,
288        nms,
289        output_boxes,
290    );
291}
292
293/// Decodes YOLO split detection segmentation outputs from quantized tensors
294/// into detection boxes and segmentation masks.
295///
296/// Boxes are expected to be in XYWH format.
297///
298/// Expected shapes of inputs:
299/// - boxes_tensor: (4, num_boxes)
300/// - scores_tensor: (num_classes, num_boxes)
301/// - mask_tensor: (num_protos, num_boxes)
302/// - protos: (proto_height, proto_width, num_protos)
303///
304/// # Errors
305/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
306#[allow(clippy::too_many_arguments)]
307pub fn decode_yolo_split_segdet<
308    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
309    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
310    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
311    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
312>(
313    boxes: (ArrayView2<BOX>, Quantization),
314    scores: (ArrayView2<SCORE>, Quantization),
315    mask_coeff: (ArrayView2<MASK>, Quantization),
316    protos: (ArrayView3<PROTO>, Quantization),
317    score_threshold: f32,
318    iou_threshold: f32,
319    nms: Option<Nms>,
320    output_boxes: &mut Vec<DetectBox>,
321    output_masks: &mut Vec<Segmentation>,
322) -> Result<(), crate::DecoderError>
323where
324    f32: AsPrimitive<SCORE>,
325{
326    impl_yolo_split_segdet_quant::<XYWH, _, _, _, _>(
327        boxes,
328        scores,
329        mask_coeff,
330        protos,
331        score_threshold,
332        iou_threshold,
333        nms,
334        output_boxes,
335        output_masks,
336    )
337}
338
339/// Decodes YOLO split detection segmentation outputs from float tensors
340/// into detection boxes and segmentation masks.
341///
342/// Boxes are expected to be in XYWH format.
343///
344/// Expected shapes of inputs:
345/// - boxes_tensor: (4, num_boxes)
346/// - scores_tensor: (num_classes, num_boxes)
347/// - mask_tensor: (num_protos, num_boxes)
348/// - protos: (proto_height, proto_width, num_protos)
349///
350/// # Errors
351/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
352#[allow(clippy::too_many_arguments)]
353pub fn decode_yolo_split_segdet_float<T>(
354    boxes: ArrayView2<T>,
355    scores: ArrayView2<T>,
356    mask_coeff: ArrayView2<T>,
357    protos: ArrayView3<T>,
358    score_threshold: f32,
359    iou_threshold: f32,
360    nms: Option<Nms>,
361    output_boxes: &mut Vec<DetectBox>,
362    output_masks: &mut Vec<Segmentation>,
363) -> Result<(), crate::DecoderError>
364where
365    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
366    f32: AsPrimitive<T>,
367{
368    impl_yolo_split_segdet_float::<XYWH, _, _, _, _>(
369        boxes,
370        scores,
371        mask_coeff,
372        protos,
373        score_threshold,
374        iou_threshold,
375        nms,
376        output_boxes,
377        output_masks,
378    )
379}
380
381/// Decodes end-to-end YOLO detection outputs (post-NMS from model).
382/// Expects an array of shape `(6, N)`, where the first dimension (rows)
383/// corresponds to the 6 per-detection features
384/// `[x1, y1, x2, y2, conf, class]` and the second dimension (columns)
385/// indexes the `N` detections.
386/// Boxes are output directly without NMS (the model already applied NMS).
387///
388/// Coordinates may be normalized `[0, 1]` or absolute pixel values depending
389/// on the model configuration. The caller should check
390/// `decoder.normalized_boxes()` to determine which.
391///
392/// # Errors
393///
394/// Returns `DecoderError::InvalidShape` if `output` has fewer than 6 rows.
395pub fn decode_yolo_end_to_end_det_float<T>(
396    output: ArrayView2<T>,
397    score_threshold: f32,
398    output_boxes: &mut Vec<DetectBox>,
399) -> Result<(), crate::DecoderError>
400where
401    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
402    f32: AsPrimitive<T>,
403{
404    // Validate input shape: need at least 6 rows (x1, y1, x2, y2, conf, class)
405    if output.shape()[0] < 6 {
406        return Err(crate::DecoderError::InvalidShape(format!(
407            "End-to-end detection output requires at least 6 rows, got {}",
408            output.shape()[0]
409        )));
410    }
411
412    // Input shape: (6, N) -> transpose to (N, 4) for boxes and (N, 1) for scores
413    let boxes = output.slice(s![0..4, ..]).reversed_axes();
414    let scores = output.slice(s![4..5, ..]).reversed_axes();
415    let classes = output.slice(s![5, ..]);
416    let mut boxes =
417        postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
418    boxes.truncate(output_boxes.capacity());
419    output_boxes.clear();
420    for (mut b, i) in boxes.into_iter() {
421        b.label = classes[i].as_() as usize;
422        output_boxes.push(b);
423    }
424    // No NMS — model output is already post-NMS
425    Ok(())
426}
427
428/// Decodes end-to-end YOLO detection + segmentation outputs (post-NMS from
429/// model).
430///
431/// Input shapes:
432/// - detection: (6 + num_protos, N) where rows are [x1, y1, x2, y2, conf,
433///   class, mask_coeff_0, ..., mask_coeff_31]
434/// - protos: (proto_height, proto_width, num_protos)
435///
436/// Boxes are output directly without NMS (model already applied NMS).
437/// Coordinates may be normalized [0,1] or pixel values depending on model
438/// config.
439///
440/// # Errors
441///
442/// Returns `DecoderError::InvalidShape` if:
443/// - output has fewer than 7 rows (6 base + at least 1 mask coefficient)
444/// - protos shape doesn't match mask coefficients count
445pub fn decode_yolo_end_to_end_segdet_float<T>(
446    output: ArrayView2<T>,
447    protos: ArrayView3<T>,
448    score_threshold: f32,
449    output_boxes: &mut Vec<DetectBox>,
450    output_masks: &mut Vec<crate::Segmentation>,
451) -> Result<(), crate::DecoderError>
452where
453    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
454    f32: AsPrimitive<T>,
455{
456    let (boxes, scores, classes, mask_coeff) =
457        postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
458    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
459        boxes,
460        scores,
461        classes,
462        score_threshold,
463        output_boxes.capacity(),
464    );
465
466    // No NMS — model output is already post-NMS
467
468    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
469}
470
471/// Decodes split end-to-end YOLO detection outputs (post-NMS from model).
472///
473/// Input shapes (after batch dim removed):
474/// - boxes: (4, N) — xyxy pixel coordinates
475/// - scores: (1, N) — confidence of the top class
476/// - classes: (1, N) — class index of the top class
477///
478/// Boxes are output directly without NMS (model already applied NMS).
479pub fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
480    boxes: ArrayView2<T>,
481    scores: ArrayView2<T>,
482    classes: ArrayView2<T>,
483    score_threshold: f32,
484    output_boxes: &mut Vec<DetectBox>,
485) -> Result<(), crate::DecoderError> {
486    let n = boxes.shape()[1];
487
488    output_boxes.clear();
489
490    let (boxes, scores, classes) = postprocess_yolo_split_end_to_end_det(boxes, scores, &classes)?;
491
492    for i in 0..n {
493        let score: f32 = scores[[i, 0]].as_();
494        if score < score_threshold {
495            continue;
496        }
497        if output_boxes.len() >= output_boxes.capacity() {
498            break;
499        }
500        output_boxes.push(DetectBox {
501            bbox: BoundingBox {
502                xmin: boxes[[i, 0]].as_(),
503                ymin: boxes[[i, 1]].as_(),
504                xmax: boxes[[i, 2]].as_(),
505                ymax: boxes[[i, 3]].as_(),
506            },
507            score,
508            label: classes[i].as_() as usize,
509        });
510    }
511    Ok(())
512}
513
514/// Decodes split end-to-end YOLO detection + segmentation outputs.
515///
516/// Input shapes (after batch dim removed):
517/// - boxes: (4, N) — xyxy pixel coordinates
518/// - scores: (1, N) — confidence
519/// - classes: (1, N) — class index
520/// - mask_coeff: (num_protos, N) — mask coefficients per detection
521/// - protos: (proto_h, proto_w, num_protos) — prototype masks
522#[allow(clippy::too_many_arguments)]
523pub fn decode_yolo_split_end_to_end_segdet_float<T>(
524    boxes: ArrayView2<T>,
525    scores: ArrayView2<T>,
526    classes: ArrayView2<T>,
527    mask_coeff: ArrayView2<T>,
528    protos: ArrayView3<T>,
529    score_threshold: f32,
530    output_boxes: &mut Vec<DetectBox>,
531    output_masks: &mut Vec<crate::Segmentation>,
532) -> Result<(), crate::DecoderError>
533where
534    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
535    f32: AsPrimitive<T>,
536{
537    let (boxes, scores, classes, mask_coeff) =
538        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
539    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
540        boxes,
541        scores,
542        classes,
543        score_threshold,
544        output_boxes.capacity(),
545    );
546
547    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
548}
549
550#[allow(clippy::type_complexity)]
551pub(crate) fn postprocess_yolo_end_to_end_segdet<'a, T>(
552    output: &'a ArrayView2<'_, T>,
553    num_protos: usize,
554) -> Result<
555    (
556        ArrayView2<'a, T>,
557        ArrayView2<'a, T>,
558        ArrayView1<'a, T>,
559        ArrayView2<'a, T>,
560    ),
561    crate::DecoderError,
562> {
563    // Validate input shape: need at least 7 rows (6 base + at least 1 mask coeff)
564    if output.shape()[0] < 7 {
565        return Err(crate::DecoderError::InvalidShape(format!(
566            "End-to-end segdet output requires at least 7 rows, got {}",
567            output.shape()[0]
568        )));
569    }
570
571    let num_mask_coeffs = output.shape()[0] - 6;
572    if num_mask_coeffs != num_protos {
573        return Err(crate::DecoderError::InvalidShape(format!(
574            "Mask coefficients count ({}) doesn't match protos count ({})",
575            num_mask_coeffs, num_protos
576        )));
577    }
578
579    // Input shape: (6+num_protos, N) -> transpose for postprocessing
580    let boxes = output.slice(s![0..4, ..]).reversed_axes();
581    let scores = output.slice(s![4..5, ..]).reversed_axes();
582    let classes = output.slice(s![5, ..]);
583    let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
584    Ok((boxes, scores, classes, mask_coeff))
585}
586
587/// Postprocess yolo split end to end det by reversing axes of boxes,
588/// scores, and flattening the class tensor.
589/// Expected input shapes:
590/// - boxes: (4, N)
591/// - scores: (1, N)
592/// - classes: (1, N)
593#[allow(clippy::type_complexity)]
594pub(crate) fn postprocess_yolo_split_end_to_end_det<'a, 'b, 'c, BOXES, SCORES, CLASS>(
595    boxes: ArrayView2<'a, BOXES>,
596    scores: ArrayView2<'b, SCORES>,
597    classes: &'c ArrayView2<CLASS>,
598) -> Result<
599    (
600        ArrayView2<'a, BOXES>,
601        ArrayView2<'b, SCORES>,
602        ArrayView1<'c, CLASS>,
603    ),
604    crate::DecoderError,
605> {
606    let num_boxes = boxes.shape()[1];
607    if boxes.shape()[0] != 4 {
608        return Err(crate::DecoderError::InvalidShape(format!(
609            "Split end-to-end box_coords must be 4, got {}",
610            boxes.shape()[0]
611        )));
612    }
613
614    if scores.shape()[0] != 1 {
615        return Err(crate::DecoderError::InvalidShape(format!(
616            "Split end-to-end scores num_classes must be 1, got {}",
617            scores.shape()[0]
618        )));
619    }
620
621    if classes.shape()[0] != 1 {
622        return Err(crate::DecoderError::InvalidShape(format!(
623            "Split end-to-end classes num_classes must be 1, got {}",
624            classes.shape()[0]
625        )));
626    }
627
628    if scores.shape()[1] != num_boxes {
629        return Err(crate::DecoderError::InvalidShape(format!(
630            "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
631            num_boxes,
632            scores.shape()[1]
633        )));
634    }
635
636    if classes.shape()[1] != num_boxes {
637        return Err(crate::DecoderError::InvalidShape(format!(
638            "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
639            num_boxes,
640            classes.shape()[1]
641        )));
642    }
643
644    let boxes = boxes.reversed_axes();
645    let scores = scores.reversed_axes();
646    let classes = classes.slice(s![0, ..]);
647    Ok((boxes, scores, classes))
648}
649
650/// Postprocess yolo split end to end segdet by reversing axes of boxes,
651/// scores, mask tensors and flattening the class tensor.
652#[allow(clippy::type_complexity)]
653pub(crate) fn postprocess_yolo_split_end_to_end_segdet<
654    'a,
655    'b,
656    'c,
657    'd,
658    BOXES,
659    SCORES,
660    CLASS,
661    MASK,
662>(
663    boxes: ArrayView2<'a, BOXES>,
664    scores: ArrayView2<'b, SCORES>,
665    classes: &'c ArrayView2<CLASS>,
666    mask_coeff: ArrayView2<'d, MASK>,
667) -> Result<
668    (
669        ArrayView2<'a, BOXES>,
670        ArrayView2<'b, SCORES>,
671        ArrayView1<'c, CLASS>,
672        ArrayView2<'d, MASK>,
673    ),
674    crate::DecoderError,
675> {
676    let num_boxes = boxes.shape()[1];
677    if boxes.shape()[0] != 4 {
678        return Err(crate::DecoderError::InvalidShape(format!(
679            "Split end-to-end box_coords must be 4, got {}",
680            boxes.shape()[0]
681        )));
682    }
683
684    if scores.shape()[0] != 1 {
685        return Err(crate::DecoderError::InvalidShape(format!(
686            "Split end-to-end scores num_classes must be 1, got {}",
687            scores.shape()[0]
688        )));
689    }
690
691    if classes.shape()[0] != 1 {
692        return Err(crate::DecoderError::InvalidShape(format!(
693            "Split end-to-end classes num_classes must be 1, got {}",
694            classes.shape()[0]
695        )));
696    }
697
698    if scores.shape()[1] != num_boxes {
699        return Err(crate::DecoderError::InvalidShape(format!(
700            "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
701            num_boxes,
702            scores.shape()[1]
703        )));
704    }
705
706    if classes.shape()[1] != num_boxes {
707        return Err(crate::DecoderError::InvalidShape(format!(
708            "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
709            num_boxes,
710            classes.shape()[1]
711        )));
712    }
713
714    if mask_coeff.shape()[1] != num_boxes {
715        return Err(crate::DecoderError::InvalidShape(format!(
716            "Split end-to-end mask_coeff must have same num_boxes as boxes ({}), got {}",
717            num_boxes,
718            mask_coeff.shape()[1]
719        )));
720    }
721
722    let boxes = boxes.reversed_axes();
723    let scores = scores.reversed_axes();
724    let classes = classes.slice(s![0, ..]);
725    let mask_coeff = mask_coeff.reversed_axes();
726    Ok((boxes, scores, classes, mask_coeff))
727}
728/// Internal implementation of YOLO decoding for quantized tensors.
729///
730/// Expected shapes of inputs:
731/// - output: (4 + num_classes, num_boxes)
732pub(crate) fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
733    output: (ArrayView2<T>, Quantization),
734    score_threshold: f32,
735    iou_threshold: f32,
736    nms: Option<Nms>,
737    output_boxes: &mut Vec<DetectBox>,
738) where
739    f32: AsPrimitive<T>,
740{
741    let (boxes, quant_boxes) = output;
742    let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
743
744    let boxes = {
745        let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
746        postprocess_boxes_quant::<B, _, _>(
747            score_threshold,
748            boxes_tensor,
749            scores_tensor,
750            quant_boxes,
751        )
752    };
753
754    let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
755    let len = output_boxes.capacity().min(boxes.len());
756    output_boxes.clear();
757    for b in boxes.iter().take(len) {
758        output_boxes.push(dequant_detect_box(b, quant_boxes));
759    }
760}
761
762/// Internal implementation of YOLO decoding for float tensors.
763///
764/// Expected shapes of inputs:
765/// - output: (4 + num_classes, num_boxes)
766pub(crate) fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
767    output: ArrayView2<T>,
768    score_threshold: f32,
769    iou_threshold: f32,
770    nms: Option<Nms>,
771    output_boxes: &mut Vec<DetectBox>,
772) where
773    f32: AsPrimitive<T>,
774{
775    let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
776    let boxes =
777        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
778    let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
779    let len = output_boxes.capacity().min(boxes.len());
780    output_boxes.clear();
781    for b in boxes.into_iter().take(len) {
782        output_boxes.push(b);
783    }
784}
785
786/// Internal implementation of YOLO split detection decoding for quantized
787/// tensors.
788///
789/// Expected shapes of inputs:
790/// - boxes: (4, num_boxes)
791/// - scores: (num_classes, num_boxes)
792///
793/// # Panics
794/// Panics if shapes don't match the expected dimensions.
795pub(crate) fn impl_yolo_split_quant<
796    B: BBoxTypeTrait,
797    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
798    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
799>(
800    boxes: (ArrayView2<BOX>, Quantization),
801    scores: (ArrayView2<SCORE>, Quantization),
802    score_threshold: f32,
803    iou_threshold: f32,
804    nms: Option<Nms>,
805    output_boxes: &mut Vec<DetectBox>,
806) where
807    f32: AsPrimitive<SCORE>,
808{
809    let (boxes_tensor, quant_boxes) = boxes;
810    let (scores_tensor, quant_scores) = scores;
811
812    let boxes_tensor = boxes_tensor.reversed_axes();
813    let scores_tensor = scores_tensor.reversed_axes();
814
815    let boxes = {
816        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
817        postprocess_boxes_quant::<B, _, _>(
818            score_threshold,
819            boxes_tensor,
820            scores_tensor,
821            quant_boxes,
822        )
823    };
824
825    let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
826    let len = output_boxes.capacity().min(boxes.len());
827    output_boxes.clear();
828    for b in boxes.iter().take(len) {
829        output_boxes.push(dequant_detect_box(b, quant_scores));
830    }
831}
832
833/// Internal implementation of YOLO split detection decoding for float tensors.
834///
835/// Expected shapes of inputs:
836/// - boxes: (4, num_boxes)
837/// - scores: (num_classes, num_boxes)
838///
839/// # Panics
840/// Panics if shapes don't match the expected dimensions.
841pub(crate) fn impl_yolo_split_float<
842    B: BBoxTypeTrait,
843    BOX: Float + AsPrimitive<f32> + Send + Sync,
844    SCORE: Float + AsPrimitive<f32> + Send + Sync,
845>(
846    boxes_tensor: ArrayView2<BOX>,
847    scores_tensor: ArrayView2<SCORE>,
848    score_threshold: f32,
849    iou_threshold: f32,
850    nms: Option<Nms>,
851    output_boxes: &mut Vec<DetectBox>,
852) where
853    f32: AsPrimitive<SCORE>,
854{
855    let boxes_tensor = boxes_tensor.reversed_axes();
856    let scores_tensor = scores_tensor.reversed_axes();
857    let boxes =
858        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
859    let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
860    let len = output_boxes.capacity().min(boxes.len());
861    output_boxes.clear();
862    for b in boxes.into_iter().take(len) {
863        output_boxes.push(b);
864    }
865}
866
867/// Internal implementation of YOLO detection segmentation decoding for
868/// quantized tensors.
869///
870/// Expected shapes of inputs:
871/// - boxes: (4 + num_classes + num_protos, num_boxes)
872/// - protos: (proto_height, proto_width, num_protos)
873///
874/// # Errors
875/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
876pub(crate) fn impl_yolo_segdet_quant<
877    B: BBoxTypeTrait,
878    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
879    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
880>(
881    boxes: (ArrayView2<BOX>, Quantization),
882    protos: (ArrayView3<PROTO>, Quantization),
883    score_threshold: f32,
884    iou_threshold: f32,
885    nms: Option<Nms>,
886    output_boxes: &mut Vec<DetectBox>,
887    output_masks: &mut Vec<Segmentation>,
888) -> Result<(), crate::DecoderError>
889where
890    f32: AsPrimitive<BOX>,
891{
892    let (boxes, quant_boxes) = boxes;
893    let num_protos = protos.0.dim().2;
894
895    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
896    let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
897        (boxes_tensor, quant_boxes),
898        (scores_tensor, quant_boxes),
899        score_threshold,
900        iou_threshold,
901        nms,
902        output_boxes.capacity(),
903    );
904
905    impl_yolo_split_segdet_quant_process_masks::<_, _>(
906        boxes,
907        (mask_tensor, quant_boxes),
908        protos,
909        output_boxes,
910        output_masks,
911    )
912}
913
914/// Internal implementation of YOLO detection segmentation decoding for
915/// float tensors.
916///
917/// Expected shapes of inputs:
918/// - boxes: (4 + num_classes + num_protos, num_boxes)
919/// - protos: (proto_height, proto_width, num_protos)
920///
921/// # Panics
922/// Panics if shapes don't match the expected dimensions.
923pub(crate) fn impl_yolo_segdet_float<
924    B: BBoxTypeTrait,
925    BOX: Float + AsPrimitive<f32> + Send + Sync,
926    PROTO: Float + AsPrimitive<f32> + Send + Sync,
927>(
928    boxes: ArrayView2<BOX>,
929    protos: ArrayView3<PROTO>,
930    score_threshold: f32,
931    iou_threshold: f32,
932    nms: Option<Nms>,
933    output_boxes: &mut Vec<DetectBox>,
934    output_masks: &mut Vec<Segmentation>,
935) -> Result<(), crate::DecoderError>
936where
937    f32: AsPrimitive<BOX>,
938{
939    let num_protos = protos.dim().2;
940    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
941    let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
942        boxes_tensor,
943        scores_tensor,
944        score_threshold,
945        iou_threshold,
946        nms,
947        output_boxes.capacity(),
948    );
949    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
950}
951
952pub(crate) fn impl_yolo_segdet_get_boxes<
953    B: BBoxTypeTrait,
954    BOX: Float + AsPrimitive<f32> + Send + Sync,
955    SCORE: Float + AsPrimitive<f32> + Send + Sync,
956>(
957    boxes_tensor: ArrayView2<BOX>,
958    scores_tensor: ArrayView2<SCORE>,
959    score_threshold: f32,
960    iou_threshold: f32,
961    nms: Option<Nms>,
962    max_boxes: usize,
963) -> Vec<(DetectBox, usize)>
964where
965    f32: AsPrimitive<SCORE>,
966{
967    let mut boxes = postprocess_boxes_index_float::<B, _, _>(
968        score_threshold.as_(),
969        boxes_tensor,
970        scores_tensor,
971    );
972    truncate_to_top_k_by_score(&mut boxes);
973    let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
974    boxes.truncate(max_boxes);
975    boxes
976}
977
978pub(crate) fn impl_yolo_end_to_end_segdet_get_boxes<
979    B: BBoxTypeTrait,
980    BOX: Float + AsPrimitive<f32> + Send + Sync,
981    SCORE: Float + AsPrimitive<f32> + Send + Sync,
982    CLASS: AsPrimitive<f32> + Send + Sync,
983>(
984    boxes: ArrayView2<BOX>,
985    scores: ArrayView2<SCORE>,
986    classes: ArrayView1<CLASS>,
987    score_threshold: f32,
988    max_boxes: usize,
989) -> Vec<(DetectBox, usize)>
990where
991    f32: AsPrimitive<SCORE>,
992{
993    let mut boxes = postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes, scores);
994    boxes.truncate(max_boxes);
995    for (b, ind) in &mut boxes {
996        b.label = classes[*ind].as_().round() as usize;
997    }
998    boxes
999}
1000
1001pub(crate) fn impl_yolo_split_segdet_process_masks<
1002    MASK: Float + AsPrimitive<f32> + Send + Sync,
1003    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1004>(
1005    boxes: Vec<(DetectBox, usize)>,
1006    masks_tensor: ArrayView2<MASK>,
1007    protos_tensor: ArrayView3<PROTO>,
1008    output_boxes: &mut Vec<DetectBox>,
1009    output_masks: &mut Vec<Segmentation>,
1010) -> Result<(), crate::DecoderError> {
1011    let boxes = decode_segdet_f32(boxes, masks_tensor, protos_tensor)?;
1012    output_boxes.clear();
1013    output_masks.clear();
1014    for (b, m) in boxes.into_iter() {
1015        output_boxes.push(b);
1016        output_masks.push(Segmentation {
1017            xmin: b.bbox.xmin,
1018            ymin: b.bbox.ymin,
1019            xmax: b.bbox.xmax,
1020            ymax: b.bbox.ymax,
1021            segmentation: m,
1022        });
1023    }
1024    Ok(())
1025}
1026/// Expected input shapes:
1027/// - boxes_tensor: (num_boxes, 4)
1028/// - scores_tensor: (num_boxes, num_classes)
1029pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
1030    B: BBoxTypeTrait,
1031    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1032    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1033>(
1034    boxes: (ArrayView2<BOX>, Quantization),
1035    scores: (ArrayView2<SCORE>, Quantization),
1036    score_threshold: f32,
1037    iou_threshold: f32,
1038    nms: Option<Nms>,
1039    max_boxes: usize,
1040) -> Vec<(DetectBox, usize)>
1041where
1042    f32: AsPrimitive<SCORE>,
1043{
1044    let (boxes_tensor, quant_boxes) = boxes;
1045    let (scores_tensor, quant_scores) = scores;
1046
1047    let mut boxes = {
1048        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
1049        postprocess_boxes_index_quant::<B, _, _>(
1050            score_threshold,
1051            boxes_tensor,
1052            scores_tensor,
1053            quant_boxes,
1054        )
1055    };
1056    truncate_to_top_k_by_score_quant(&mut boxes);
1057    let mut boxes = dispatch_nms_extra_int(nms, iou_threshold, boxes);
1058    boxes.truncate(max_boxes);
1059    boxes
1060        .into_iter()
1061        .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
1062        .collect()
1063}
1064
1065pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
1066    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1067    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1068>(
1069    boxes: Vec<(DetectBox, usize)>,
1070    mask_coeff: (ArrayView2<MASK>, Quantization),
1071    protos: (ArrayView3<PROTO>, Quantization),
1072    output_boxes: &mut Vec<DetectBox>,
1073    output_masks: &mut Vec<Segmentation>,
1074) -> Result<(), crate::DecoderError> {
1075    let (masks, quant_masks) = mask_coeff;
1076    let (protos, quant_protos) = protos;
1077
1078    let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
1079    output_boxes.clear();
1080    output_masks.clear();
1081    for (b, m) in boxes.into_iter() {
1082        output_boxes.push(b);
1083        output_masks.push(Segmentation {
1084            xmin: b.bbox.xmin,
1085            ymin: b.bbox.ymin,
1086            xmax: b.bbox.xmax,
1087            ymax: b.bbox.ymax,
1088            segmentation: m,
1089        });
1090    }
1091    Ok(())
1092}
1093
1094#[allow(clippy::too_many_arguments)]
1095/// Internal implementation of YOLO split detection segmentation decoding for
1096/// quantized tensors.
1097///
1098/// Expected shapes of inputs:
1099/// - boxes_tensor: (4, num_boxes)
1100/// - scores_tensor: (num_classes, num_boxes)
1101/// - mask_tensor: (num_protos, num_boxes)
1102/// - protos: (proto_height, proto_width, num_protos)
1103///
1104/// # Errors
1105/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
1106pub(crate) fn impl_yolo_split_segdet_quant<
1107    B: BBoxTypeTrait,
1108    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1109    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1110    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1111    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1112>(
1113    boxes: (ArrayView2<BOX>, Quantization),
1114    scores: (ArrayView2<SCORE>, Quantization),
1115    mask_coeff: (ArrayView2<MASK>, Quantization),
1116    protos: (ArrayView3<PROTO>, Quantization),
1117    score_threshold: f32,
1118    iou_threshold: f32,
1119    nms: Option<Nms>,
1120    output_boxes: &mut Vec<DetectBox>,
1121    output_masks: &mut Vec<Segmentation>,
1122) -> Result<(), crate::DecoderError>
1123where
1124    f32: AsPrimitive<SCORE>,
1125{
1126    let (boxes_, scores_, mask_coeff_) =
1127        postprocess_yolo_split_segdet(boxes.0, scores.0, mask_coeff.0);
1128    let boxes = (boxes_, boxes.1);
1129    let scores = (scores_, scores.1);
1130    let mask_coeff = (mask_coeff_, mask_coeff.1);
1131
1132    let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1133        boxes,
1134        scores,
1135        score_threshold,
1136        iou_threshold,
1137        nms,
1138        output_boxes.capacity(),
1139    );
1140
1141    impl_yolo_split_segdet_quant_process_masks(
1142        boxes,
1143        mask_coeff,
1144        protos,
1145        output_boxes,
1146        output_masks,
1147    )
1148}
1149
1150#[allow(clippy::too_many_arguments)]
1151/// Internal implementation of YOLO split detection segmentation decoding for
1152/// float tensors.
1153///
1154/// Expected shapes of inputs:
1155/// - boxes_tensor: (4, num_boxes)
1156/// - scores_tensor: (num_classes, num_boxes)
1157/// - mask_tensor: (num_protos, num_boxes)
1158/// - protos: (proto_height, proto_width, num_protos)
1159///
1160/// # Errors
1161/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
1162pub(crate) fn impl_yolo_split_segdet_float<
1163    B: BBoxTypeTrait,
1164    BOX: Float + AsPrimitive<f32> + Send + Sync,
1165    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1166    MASK: Float + AsPrimitive<f32> + Send + Sync,
1167    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1168>(
1169    boxes_tensor: ArrayView2<BOX>,
1170    scores_tensor: ArrayView2<SCORE>,
1171    mask_tensor: ArrayView2<MASK>,
1172    protos: ArrayView3<PROTO>,
1173    score_threshold: f32,
1174    iou_threshold: f32,
1175    nms: Option<Nms>,
1176    output_boxes: &mut Vec<DetectBox>,
1177    output_masks: &mut Vec<Segmentation>,
1178) -> Result<(), crate::DecoderError>
1179where
1180    f32: AsPrimitive<SCORE>,
1181{
1182    let (boxes_tensor, scores_tensor, mask_tensor) =
1183        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1184
1185    let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1186        boxes_tensor,
1187        scores_tensor,
1188        score_threshold,
1189        iou_threshold,
1190        nms,
1191        output_boxes.capacity(),
1192    );
1193    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1194}
1195
1196// ---------------------------------------------------------------------------
1197// Proto-extraction variants: return ProtoData instead of materialized masks
1198// ---------------------------------------------------------------------------
1199
1200/// Proto-extraction variant of `impl_yolo_segdet_quant`.
1201/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1202pub fn impl_yolo_segdet_quant_proto<
1203    B: BBoxTypeTrait,
1204    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1205    PROTO: PrimInt
1206        + AsPrimitive<i64>
1207        + AsPrimitive<i128>
1208        + AsPrimitive<f32>
1209        + AsPrimitive<i8>
1210        + Send
1211        + Sync,
1212>(
1213    boxes: (ArrayView2<BOX>, Quantization),
1214    protos: (ArrayView3<PROTO>, Quantization),
1215    score_threshold: f32,
1216    iou_threshold: f32,
1217    nms: Option<Nms>,
1218    output_boxes: &mut Vec<DetectBox>,
1219) -> ProtoData
1220where
1221    f32: AsPrimitive<BOX>,
1222{
1223    let (boxes_arr, quant_boxes) = boxes;
1224    let (protos_arr, quant_protos) = protos;
1225    let num_protos = protos_arr.dim().2;
1226
1227    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1228
1229    let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1230        (boxes_tensor, quant_boxes),
1231        (scores_tensor, quant_boxes),
1232        score_threshold,
1233        iou_threshold,
1234        nms,
1235        output_boxes.capacity(),
1236    );
1237
1238    extract_proto_data_quant(
1239        det_indices,
1240        mask_tensor,
1241        quant_boxes,
1242        protos_arr,
1243        quant_protos,
1244        output_boxes,
1245    )
1246}
1247
1248/// Proto-extraction variant of `impl_yolo_segdet_float`.
1249/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1250pub(crate) fn impl_yolo_segdet_float_proto<
1251    B: BBoxTypeTrait,
1252    BOX: Float + AsPrimitive<f32> + Send + Sync,
1253    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1254>(
1255    boxes: ArrayView2<BOX>,
1256    protos: ArrayView3<PROTO>,
1257    score_threshold: f32,
1258    iou_threshold: f32,
1259    nms: Option<Nms>,
1260    output_boxes: &mut Vec<DetectBox>,
1261) -> ProtoData
1262where
1263    f32: AsPrimitive<BOX>,
1264{
1265    let num_protos = protos.dim().2;
1266    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1267
1268    let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1269        boxes_tensor,
1270        scores_tensor,
1271        score_threshold,
1272        iou_threshold,
1273        nms,
1274        output_boxes.capacity(),
1275    );
1276
1277    extract_proto_data_float(boxes, mask_tensor, protos, output_boxes)
1278}
1279
1280/// Proto-extraction variant of `impl_yolo_split_segdet_float`.
1281/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1282#[allow(clippy::too_many_arguments)]
1283pub(crate) fn impl_yolo_split_segdet_float_proto<
1284    B: BBoxTypeTrait,
1285    BOX: Float + AsPrimitive<f32> + Send + Sync,
1286    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1287    MASK: Float + AsPrimitive<f32> + Send + Sync,
1288    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1289>(
1290    boxes_tensor: ArrayView2<BOX>,
1291    scores_tensor: ArrayView2<SCORE>,
1292    mask_tensor: ArrayView2<MASK>,
1293    protos: ArrayView3<PROTO>,
1294    score_threshold: f32,
1295    iou_threshold: f32,
1296    nms: Option<Nms>,
1297    output_boxes: &mut Vec<DetectBox>,
1298) -> ProtoData
1299where
1300    f32: AsPrimitive<SCORE>,
1301{
1302    let (boxes_tensor, scores_tensor, mask_tensor) =
1303        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1304    let det_indices = impl_yolo_segdet_get_boxes::<B, _, _>(
1305        boxes_tensor,
1306        scores_tensor,
1307        score_threshold,
1308        iou_threshold,
1309        nms,
1310        output_boxes.capacity(),
1311    );
1312
1313    extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1314}
1315
1316/// Proto-extraction variant of `decode_yolo_end_to_end_segdet_float`.
1317pub fn decode_yolo_end_to_end_segdet_float_proto<T>(
1318    output: ArrayView2<T>,
1319    protos: ArrayView3<T>,
1320    score_threshold: f32,
1321    output_boxes: &mut Vec<DetectBox>,
1322) -> Result<ProtoData, crate::DecoderError>
1323where
1324    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1325    f32: AsPrimitive<T>,
1326{
1327    let (boxes, scores, classes, mask_coeff) =
1328        postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
1329    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1330        boxes,
1331        scores,
1332        classes,
1333        score_threshold,
1334        output_boxes.capacity(),
1335    );
1336
1337    Ok(extract_proto_data_float(
1338        boxes,
1339        mask_coeff,
1340        protos,
1341        output_boxes,
1342    ))
1343}
1344
1345/// Proto-extraction variant of `decode_yolo_split_end_to_end_segdet_float`.
1346#[allow(clippy::too_many_arguments)]
1347pub fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1348    boxes: ArrayView2<T>,
1349    scores: ArrayView2<T>,
1350    classes: ArrayView2<T>,
1351    mask_coeff: ArrayView2<T>,
1352    protos: ArrayView3<T>,
1353    score_threshold: f32,
1354    output_boxes: &mut Vec<DetectBox>,
1355) -> Result<ProtoData, crate::DecoderError>
1356where
1357    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1358    f32: AsPrimitive<T>,
1359{
1360    let (boxes, scores, classes, mask_coeff) =
1361        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
1362    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1363        boxes,
1364        scores,
1365        classes,
1366        score_threshold,
1367        output_boxes.capacity(),
1368    );
1369
1370    Ok(extract_proto_data_float(
1371        boxes,
1372        mask_coeff,
1373        protos,
1374        output_boxes,
1375    ))
1376}
1377
1378/// Helper: extract ProtoData from float mask coefficients + protos.
1379pub(super) fn extract_proto_data_float<
1380    MASK: Float + AsPrimitive<f32> + Send + Sync,
1381    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1382>(
1383    det_indices: Vec<(DetectBox, usize)>,
1384    mask_tensor: ArrayView2<MASK>,
1385    protos: ArrayView3<PROTO>,
1386    output_boxes: &mut Vec<DetectBox>,
1387) -> ProtoData {
1388    let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1389    output_boxes.clear();
1390    for (det, idx) in det_indices {
1391        output_boxes.push(det);
1392        let row = mask_tensor.row(idx);
1393        mask_coefficients.push(row.iter().map(|v| v.as_()).collect());
1394    }
1395    let protos_f32 = protos.map(|v| v.as_());
1396    ProtoData {
1397        mask_coefficients,
1398        protos: ProtoTensor::Float(protos_f32),
1399    }
1400}
1401
1402/// Helper: extract ProtoData from quantized mask coefficients + protos.
1403///
1404/// Dequantizes mask coefficients to f32 (small — per-detection) but keeps
1405/// protos in raw int8 form wrapped in `ProtoTensor::Quantized` so the GPU
1406/// shader can dequantize per-texel without CPU overhead.
1407pub(crate) fn extract_proto_data_quant<
1408    MASK: PrimInt + AsPrimitive<f32> + Send + Sync,
1409    PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
1410>(
1411    det_indices: Vec<(DetectBox, usize)>,
1412    mask_tensor: ArrayView2<MASK>,
1413    quant_masks: Quantization,
1414    protos: ArrayView3<PROTO>,
1415    quant_protos: Quantization,
1416    output_boxes: &mut Vec<DetectBox>,
1417) -> ProtoData {
1418    let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1419    output_boxes.clear();
1420    for (det, idx) in det_indices {
1421        output_boxes.push(det);
1422        let row = mask_tensor.row(idx);
1423        mask_coefficients.push(
1424            row.iter()
1425                .map(|v| (v.as_() - quant_masks.zero_point as f32) * quant_masks.scale)
1426                .collect(),
1427        );
1428    }
1429    // Keep protos in raw int8 — GPU shader will dequantize per-texel.
1430    // When PROTO is already i8, use to_owned() for a flat memcpy instead of
1431    // per-element as_() conversion.
1432    let protos_i8 = if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1433        // SAFETY: PROTO and i8 have identical size and layout when TypeId matches.
1434        let view_i8 =
1435            unsafe { &*(&protos as *const ArrayView3<'_, PROTO> as *const ArrayView3<'_, i8>) };
1436        view_i8.to_owned()
1437    } else {
1438        protos.map(|v| {
1439            let v_i8: i8 = v.as_();
1440            v_i8
1441        })
1442    };
1443    ProtoData {
1444        mask_coefficients,
1445        protos: ProtoTensor::Quantized {
1446            protos: protos_i8,
1447            quantization: quant_protos,
1448        },
1449    }
1450}
1451
1452fn postprocess_yolo<'a, T>(
1453    output: &'a ArrayView2<'_, T>,
1454) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1455    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1456    let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1457    (boxes_tensor, scores_tensor)
1458}
1459
1460pub(crate) fn postprocess_yolo_seg<'a, T>(
1461    output: &'a ArrayView2<'_, T>,
1462    num_protos: usize,
1463) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1464    assert!(
1465        output.shape()[0] > num_protos + 4,
1466        "Output shape is too short: {} <= {} + 4",
1467        output.shape()[0],
1468        num_protos
1469    );
1470    let num_classes = output.shape()[0] - 4 - num_protos;
1471    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1472    let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1473    let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1474    (boxes_tensor, scores_tensor, mask_tensor)
1475}
1476
1477pub(crate) fn postprocess_yolo_split_segdet<'a, 'b, 'c, BOX, SCORE, MASK>(
1478    boxes_tensor: ArrayView2<'a, BOX>,
1479    scores_tensor: ArrayView2<'b, SCORE>,
1480    mask_tensor: ArrayView2<'c, MASK>,
1481) -> (
1482    ArrayView2<'a, BOX>,
1483    ArrayView2<'b, SCORE>,
1484    ArrayView2<'c, MASK>,
1485) {
1486    let boxes_tensor = boxes_tensor.reversed_axes();
1487    let scores_tensor = scores_tensor.reversed_axes();
1488    let mask_tensor = mask_tensor.reversed_axes();
1489    (boxes_tensor, scores_tensor, mask_tensor)
1490}
1491
1492fn decode_segdet_f32<
1493    MASK: Float + AsPrimitive<f32> + Send + Sync,
1494    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1495>(
1496    boxes: Vec<(DetectBox, usize)>,
1497    masks: ArrayView2<MASK>,
1498    protos: ArrayView3<PROTO>,
1499) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1500    if boxes.is_empty() {
1501        return Ok(Vec::new());
1502    }
1503    if masks.shape()[1] != protos.shape()[2] {
1504        return Err(crate::DecoderError::InvalidShape(format!(
1505            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1506            masks.shape()[1],
1507            protos.shape()[2],
1508        )));
1509    }
1510    boxes
1511        .into_par_iter()
1512        .map(|mut b| {
1513            let ind = b.1;
1514            let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1515            b.0.bbox = roi;
1516            Ok((b.0, make_segmentation(masks.row(ind), protos.view())))
1517        })
1518        .collect()
1519}
1520
1521pub(crate) fn decode_segdet_quant<
1522    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1523    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1524>(
1525    boxes: Vec<(DetectBox, usize)>,
1526    masks: ArrayView2<MASK>,
1527    protos: ArrayView3<PROTO>,
1528    quant_masks: Quantization,
1529    quant_protos: Quantization,
1530) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1531    if boxes.is_empty() {
1532        return Ok(Vec::new());
1533    }
1534    if masks.shape()[1] != protos.shape()[2] {
1535        return Err(crate::DecoderError::InvalidShape(format!(
1536            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1537            masks.shape()[1],
1538            protos.shape()[2],
1539        )));
1540    }
1541
1542    let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; // 32 protos is 2^5
1543    boxes
1544        .into_iter()
1545        .map(|mut b| {
1546            let i = b.1;
1547            let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1548            b.0.bbox = roi;
1549            let seg = match total_bits {
1550                0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1551                    masks.row(i),
1552                    protos.view(),
1553                    quant_masks,
1554                    quant_protos,
1555                ),
1556                65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1557                    masks.row(i),
1558                    protos.view(),
1559                    quant_masks,
1560                    quant_protos,
1561                ),
1562                _ => {
1563                    return Err(crate::DecoderError::NotSupported(format!(
1564                        "Unsupported bit width ({total_bits}) for segmentation computation"
1565                    )));
1566                }
1567            };
1568            Ok((b.0, seg))
1569        })
1570        .collect()
1571}
1572
1573fn protobox<'a, T>(
1574    protos: &'a ArrayView3<T>,
1575    roi: &BoundingBox,
1576) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1577    let width = protos.dim().1 as f32;
1578    let height = protos.dim().0 as f32;
1579
1580    // Detect un-normalized bounding boxes (pixel-space coordinates).
1581    // protobox expects normalized coordinates in [0, 1]. ONNX models output
1582    // pixel-space boxes (e.g. 0-640) which must be normalized before calling
1583    // decode(). Without this check, pixel-space coordinates silently clamp to
1584    // the proto boundary, producing empty (0, 0, C) masks for every detection.
1585    //
1586    // The limit is set to 2.0 (not 1.01) because YOLO models legitimately
1587    // predict coordinates slightly > 1.0 for objects near frame edges.
1588    // Any value > 2.0 is clearly pixel-space (even the smallest practical
1589    // model input of 32×32 would produce coordinates >> 2.0).
1590    const NORM_LIMIT: f32 = 2.0;
1591    if roi.xmin > NORM_LIMIT
1592        || roi.ymin > NORM_LIMIT
1593        || roi.xmax > NORM_LIMIT
1594        || roi.ymax > NORM_LIMIT
1595    {
1596        return Err(crate::DecoderError::InvalidShape(format!(
1597            "Bounding box coordinates appear un-normalized (pixel-space). \
1598             Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
1599             ONNX models output pixel-space boxes — normalize them by dividing by \
1600             the input dimensions before calling decode().",
1601            roi.xmin, roi.ymin, roi.xmax, roi.ymax,
1602        )));
1603    }
1604
1605    let roi = [
1606        (roi.xmin * width).clamp(0.0, width) as usize,
1607        (roi.ymin * height).clamp(0.0, height) as usize,
1608        (roi.xmax * width).clamp(0.0, width).ceil() as usize,
1609        (roi.ymax * height).clamp(0.0, height).ceil() as usize,
1610    ];
1611
1612    let roi_norm = [
1613        roi[0] as f32 / width,
1614        roi[1] as f32 / height,
1615        roi[2] as f32 / width,
1616        roi[3] as f32 / height,
1617    ]
1618    .into();
1619
1620    let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
1621
1622    Ok((cropped, roi_norm))
1623}
1624
1625/// Compute a single instance segmentation mask from mask coefficients and
1626/// proto maps (float path).
1627///
1628/// Computes `sigmoid(coefficients · protos)` and maps to `[0, 255]`.
1629/// Returns an `(H, W, 1)` u8 array.
1630fn make_segmentation<
1631    MASK: Float + AsPrimitive<f32> + Send + Sync,
1632    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1633>(
1634    mask: ArrayView1<MASK>,
1635    protos: ArrayView3<PROTO>,
1636) -> Array3<u8> {
1637    let shape = protos.shape();
1638
1639    // Safe to unwrap since the shapes will always be compatible
1640    let mask = mask.to_shape((1, mask.len())).unwrap();
1641    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1642    let protos = protos.reversed_axes();
1643    let mask = mask.map(|x| x.as_());
1644    let protos = protos.map(|x| x.as_());
1645
1646    // Safe to unwrap since the shapes will always be compatible
1647    let mask = mask
1648        .dot(&protos)
1649        .into_shape_with_order((shape[0], shape[1], 1))
1650        .unwrap();
1651
1652    mask.map(|x| {
1653        let sigmoid = 1.0 / (1.0 + (-*x).exp());
1654        (sigmoid * 255.0).round() as u8
1655    })
1656}
1657
1658/// Compute a single instance segmentation mask from quantized mask
1659/// coefficients and proto maps.
1660///
1661/// Dequantizes both inputs (subtracting zero-points), computes the dot
1662/// product, applies sigmoid, and maps to `[0, 255]`.
1663/// Returns an `(H, W, 1)` u8 array.
1664fn make_segmentation_quant<
1665    MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
1666    PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
1667    DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
1668>(
1669    mask: ArrayView1<MASK>,
1670    protos: ArrayView3<PROTO>,
1671    quant_masks: Quantization,
1672    quant_protos: Quantization,
1673) -> Array3<u8>
1674where
1675    i32: AsPrimitive<DEST>,
1676    f32: AsPrimitive<DEST>,
1677{
1678    let shape = protos.shape();
1679
1680    // Safe to unwrap since the shapes will always be compatible
1681    let mask = mask.to_shape((1, mask.len())).unwrap();
1682
1683    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1684    let protos = protos.reversed_axes();
1685
1686    let zp = quant_masks.zero_point.as_();
1687
1688    let mask = mask.mapv(|x| x.as_() - zp);
1689
1690    let zp = quant_protos.zero_point.as_();
1691    let protos = protos.mapv(|x| x.as_() - zp);
1692
1693    // Safe to unwrap since the shapes will always be compatible
1694    let segmentation = mask
1695        .dot(&protos)
1696        .into_shape_with_order((shape[0], shape[1], 1))
1697        .unwrap();
1698
1699    let combined_scale = quant_masks.scale * quant_protos.scale;
1700    segmentation.map(|x| {
1701        let val: f32 = (*x).as_() * combined_scale;
1702        let sigmoid = 1.0 / (1.0 + (-val).exp());
1703        (sigmoid * 255.0).round() as u8
1704    })
1705}
1706
1707/// Converts Yolo Instance Segmentation into a 2D mask.
1708///
1709/// The input segmentation is expected to have shape (H, W, 1).
1710///
1711/// The output mask will have shape (H, W), with values 0 or 1 based on the
1712/// threshold.
1713///
1714/// # Errors
1715///
1716/// Returns `DecoderError::InvalidShape` if the input segmentation does not
1717/// have shape (H, W, 1).
1718pub fn yolo_segmentation_to_mask(
1719    segmentation: ArrayView3<u8>,
1720    threshold: u8,
1721) -> Result<Array2<u8>, crate::DecoderError> {
1722    if segmentation.shape()[2] != 1 {
1723        return Err(crate::DecoderError::InvalidShape(format!(
1724            "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
1725            segmentation.shape()[2]
1726        )));
1727    }
1728    Ok(segmentation
1729        .slice(s![.., .., 0])
1730        .map(|x| if *x >= threshold { 1 } else { 0 }))
1731}
1732
1733#[cfg(test)]
1734#[cfg_attr(coverage_nightly, coverage(off))]
1735mod tests {
1736    use super::*;
1737    use ndarray::Array2;
1738
1739    // ========================================================================
1740    // Tests for decode_yolo_end_to_end_det_float
1741    // ========================================================================
1742
1743    #[test]
1744    fn test_end_to_end_det_basic_filtering() {
1745        // Create synthetic end-to-end detection output: (6, N) where rows are
1746        // [x1, y1, x2, y2, conf, class]
1747        // 3 detections: one above threshold, two below
1748        let data: Vec<f32> = vec![
1749            // Detection 0: high score (0.9)
1750            0.1, 0.2, 0.3, // x1 values
1751            0.1, 0.2, 0.3, // y1 values
1752            0.5, 0.6, 0.7, // x2 values
1753            0.5, 0.6, 0.7, // y2 values
1754            0.9, 0.1, 0.2, // confidence scores
1755            0.0, 1.0, 2.0, // class indices
1756        ];
1757        let output = Array2::from_shape_vec((6, 3), data).unwrap();
1758
1759        let mut boxes = Vec::with_capacity(10);
1760        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1761
1762        // Only 1 detection should pass threshold of 0.5
1763        assert_eq!(boxes.len(), 1);
1764        assert_eq!(boxes[0].label, 0);
1765        assert!((boxes[0].score - 0.9).abs() < 0.01);
1766        assert!((boxes[0].bbox.xmin - 0.1).abs() < 0.01);
1767        assert!((boxes[0].bbox.ymin - 0.1).abs() < 0.01);
1768        assert!((boxes[0].bbox.xmax - 0.5).abs() < 0.01);
1769        assert!((boxes[0].bbox.ymax - 0.5).abs() < 0.01);
1770    }
1771
1772    #[test]
1773    fn test_end_to_end_det_all_pass_threshold() {
1774        // All detections above threshold
1775        let data: Vec<f32> = vec![
1776            10.0, 20.0, // x1
1777            10.0, 20.0, // y1
1778            50.0, 60.0, // x2
1779            50.0, 60.0, // y2
1780            0.8, 0.7, // conf (both above 0.5)
1781            1.0, 2.0, // class
1782        ];
1783        let output = Array2::from_shape_vec((6, 2), data).unwrap();
1784
1785        let mut boxes = Vec::with_capacity(10);
1786        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1787
1788        assert_eq!(boxes.len(), 2);
1789        assert_eq!(boxes[0].label, 1);
1790        assert_eq!(boxes[1].label, 2);
1791    }
1792
1793    #[test]
1794    fn test_end_to_end_det_none_pass_threshold() {
1795        // All detections below threshold
1796        let data: Vec<f32> = vec![
1797            10.0, 20.0, // x1
1798            10.0, 20.0, // y1
1799            50.0, 60.0, // x2
1800            50.0, 60.0, // y2
1801            0.1, 0.2, // conf (both below 0.5)
1802            1.0, 2.0, // class
1803        ];
1804        let output = Array2::from_shape_vec((6, 2), data).unwrap();
1805
1806        let mut boxes = Vec::with_capacity(10);
1807        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1808
1809        assert_eq!(boxes.len(), 0);
1810    }
1811
1812    #[test]
1813    fn test_end_to_end_det_capacity_limit() {
1814        // Test that output is truncated to capacity
1815        let data: Vec<f32> = vec![
1816            0.1, 0.2, 0.3, 0.4, 0.5, // x1
1817            0.1, 0.2, 0.3, 0.4, 0.5, // y1
1818            0.5, 0.6, 0.7, 0.8, 0.9, // x2
1819            0.5, 0.6, 0.7, 0.8, 0.9, // y2
1820            0.9, 0.9, 0.9, 0.9, 0.9, // conf (all pass)
1821            0.0, 1.0, 2.0, 3.0, 4.0, // class
1822        ];
1823        let output = Array2::from_shape_vec((6, 5), data).unwrap();
1824
1825        let mut boxes = Vec::with_capacity(2); // Only allow 2 boxes
1826        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1827
1828        assert_eq!(boxes.len(), 2);
1829    }
1830
1831    #[test]
1832    fn test_end_to_end_det_empty_output() {
1833        // Test with zero detections
1834        let output = Array2::<f32>::zeros((6, 0));
1835
1836        let mut boxes = Vec::with_capacity(10);
1837        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1838
1839        assert_eq!(boxes.len(), 0);
1840    }
1841
1842    #[test]
1843    fn test_end_to_end_det_pixel_coordinates() {
1844        // Test with pixel coordinates (non-normalized)
1845        let data: Vec<f32> = vec![
1846            100.0, // x1
1847            200.0, // y1
1848            300.0, // x2
1849            400.0, // y2
1850            0.95,  // conf
1851            5.0,   // class
1852        ];
1853        let output = Array2::from_shape_vec((6, 1), data).unwrap();
1854
1855        let mut boxes = Vec::with_capacity(10);
1856        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1857
1858        assert_eq!(boxes.len(), 1);
1859        assert_eq!(boxes[0].label, 5);
1860        assert!((boxes[0].bbox.xmin - 100.0).abs() < 0.01);
1861        assert!((boxes[0].bbox.ymin - 200.0).abs() < 0.01);
1862        assert!((boxes[0].bbox.xmax - 300.0).abs() < 0.01);
1863        assert!((boxes[0].bbox.ymax - 400.0).abs() < 0.01);
1864    }
1865
1866    #[test]
1867    fn test_end_to_end_det_invalid_shape() {
1868        // Test with too few rows (needs at least 6)
1869        let output = Array2::<f32>::zeros((5, 3));
1870
1871        let mut boxes = Vec::with_capacity(10);
1872        let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);
1873
1874        assert!(result.is_err());
1875        assert!(matches!(
1876            result,
1877            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
1878        ));
1879    }
1880
1881    // ========================================================================
1882    // Tests for decode_yolo_end_to_end_segdet_float
1883    // ========================================================================
1884
1885    #[test]
1886    fn test_end_to_end_segdet_basic() {
1887        // Create synthetic segdet output: (6 + num_protos, N)
1888        // Detection format: [x1, y1, x2, y2, conf, class, mask_coeff_0..31]
1889        let num_protos = 32;
1890        let num_detections = 2;
1891        let num_features = 6 + num_protos;
1892
1893        // Build detection tensor
1894        let mut data = vec![0.0f32; num_features * num_detections];
1895        // Detection 0: passes threshold
1896        data[0] = 0.1; // x1[0]
1897        data[1] = 0.5; // x1[1]
1898        data[num_detections] = 0.1; // y1[0]
1899        data[num_detections + 1] = 0.5; // y1[1]
1900        data[2 * num_detections] = 0.4; // x2[0]
1901        data[2 * num_detections + 1] = 0.9; // x2[1]
1902        data[3 * num_detections] = 0.4; // y2[0]
1903        data[3 * num_detections + 1] = 0.9; // y2[1]
1904        data[4 * num_detections] = 0.9; // conf[0] - passes
1905        data[4 * num_detections + 1] = 0.3; // conf[1] - fails
1906        data[5 * num_detections] = 1.0; // class[0]
1907        data[5 * num_detections + 1] = 2.0; // class[1]
1908                                            // Fill mask coefficients with small values
1909        for i in 6..num_features {
1910            data[i * num_detections] = 0.1;
1911            data[i * num_detections + 1] = 0.1;
1912        }
1913
1914        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
1915
1916        // Create protos tensor: (proto_height, proto_width, num_protos)
1917        let protos = Array3::<f32>::zeros((16, 16, num_protos));
1918
1919        let mut boxes = Vec::with_capacity(10);
1920        let mut masks = Vec::with_capacity(10);
1921        decode_yolo_end_to_end_segdet_float(
1922            output.view(),
1923            protos.view(),
1924            0.5,
1925            &mut boxes,
1926            &mut masks,
1927        )
1928        .unwrap();
1929
1930        // Only detection 0 should pass
1931        assert_eq!(boxes.len(), 1);
1932        assert_eq!(masks.len(), 1);
1933        assert_eq!(boxes[0].label, 1);
1934        assert!((boxes[0].score - 0.9).abs() < 0.01);
1935    }
1936
1937    #[test]
1938    fn test_end_to_end_segdet_mask_coordinates() {
1939        // Test that mask coordinates match box coordinates
1940        let num_protos = 32;
1941        let num_features = 6 + num_protos;
1942
1943        let mut data = vec![0.0f32; num_features];
1944        data[0] = 0.2; // x1
1945        data[1] = 0.2; // y1
1946        data[2] = 0.8; // x2
1947        data[3] = 0.8; // y2
1948        data[4] = 0.95; // conf
1949        data[5] = 3.0; // class
1950
1951        let output = Array2::from_shape_vec((num_features, 1), data).unwrap();
1952        let protos = Array3::<f32>::zeros((16, 16, num_protos));
1953
1954        let mut boxes = Vec::with_capacity(10);
1955        let mut masks = Vec::with_capacity(10);
1956        decode_yolo_end_to_end_segdet_float(
1957            output.view(),
1958            protos.view(),
1959            0.5,
1960            &mut boxes,
1961            &mut masks,
1962        )
1963        .unwrap();
1964
1965        assert_eq!(boxes.len(), 1);
1966        assert_eq!(masks.len(), 1);
1967
1968        // Verify mask coordinates match box coordinates
1969        assert!((masks[0].xmin - boxes[0].bbox.xmin).abs() < 0.01);
1970        assert!((masks[0].ymin - boxes[0].bbox.ymin).abs() < 0.01);
1971        assert!((masks[0].xmax - boxes[0].bbox.xmax).abs() < 0.01);
1972        assert!((masks[0].ymax - boxes[0].bbox.ymax).abs() < 0.01);
1973    }
1974
1975    #[test]
1976    fn test_end_to_end_segdet_empty_output() {
1977        let num_protos = 32;
1978        let output = Array2::<f32>::zeros((6 + num_protos, 0));
1979        let protos = Array3::<f32>::zeros((16, 16, num_protos));
1980
1981        let mut boxes = Vec::with_capacity(10);
1982        let mut masks = Vec::with_capacity(10);
1983        decode_yolo_end_to_end_segdet_float(
1984            output.view(),
1985            protos.view(),
1986            0.5,
1987            &mut boxes,
1988            &mut masks,
1989        )
1990        .unwrap();
1991
1992        assert_eq!(boxes.len(), 0);
1993        assert_eq!(masks.len(), 0);
1994    }
1995
1996    #[test]
1997    fn test_end_to_end_segdet_capacity_limit() {
1998        let num_protos = 32;
1999        let num_detections = 5;
2000        let num_features = 6 + num_protos;
2001
2002        let mut data = vec![0.0f32; num_features * num_detections];
2003        // All detections pass threshold
2004        for i in 0..num_detections {
2005            data[i] = 0.1 * (i as f32); // x1
2006            data[num_detections + i] = 0.1 * (i as f32); // y1
2007            data[2 * num_detections + i] = 0.1 * (i as f32) + 0.2; // x2
2008            data[3 * num_detections + i] = 0.1 * (i as f32) + 0.2; // y2
2009            data[4 * num_detections + i] = 0.9; // conf
2010            data[5 * num_detections + i] = i as f32; // class
2011        }
2012
2013        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2014        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2015
2016        let mut boxes = Vec::with_capacity(2); // Limit to 2
2017        let mut masks = Vec::with_capacity(2);
2018        decode_yolo_end_to_end_segdet_float(
2019            output.view(),
2020            protos.view(),
2021            0.5,
2022            &mut boxes,
2023            &mut masks,
2024        )
2025        .unwrap();
2026
2027        assert_eq!(boxes.len(), 2);
2028        assert_eq!(masks.len(), 2);
2029    }
2030
2031    #[test]
2032    fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
2033        // Test with too few rows (needs at least 7: 6 base + 1 mask coeff)
2034        let output = Array2::<f32>::zeros((6, 3));
2035        let protos = Array3::<f32>::zeros((16, 16, 32));
2036
2037        let mut boxes = Vec::with_capacity(10);
2038        let mut masks = Vec::with_capacity(10);
2039        let result = decode_yolo_end_to_end_segdet_float(
2040            output.view(),
2041            protos.view(),
2042            0.5,
2043            &mut boxes,
2044            &mut masks,
2045        );
2046
2047        assert!(result.is_err());
2048        assert!(matches!(
2049            result,
2050            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
2051        ));
2052    }
2053
2054    #[test]
2055    fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
2056        // Test with mismatched mask coefficients and protos count
2057        let num_protos = 32;
2058        let output = Array2::<f32>::zeros((6 + 16, 3)); // 16 mask coeffs
2059        let protos = Array3::<f32>::zeros((16, 16, num_protos)); // 32 protos
2060
2061        let mut boxes = Vec::with_capacity(10);
2062        let mut masks = Vec::with_capacity(10);
2063        let result = decode_yolo_end_to_end_segdet_float(
2064            output.view(),
2065            protos.view(),
2066            0.5,
2067            &mut boxes,
2068            &mut masks,
2069        );
2070
2071        assert!(result.is_err());
2072        assert!(matches!(
2073            result,
2074            Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
2075        ));
2076    }
2077
2078    // ========================================================================
2079    // Tests for decode_yolo_split_end_to_end_segdet_float
2080    // ========================================================================
2081
2082    #[test]
2083    fn test_split_end_to_end_segdet_basic() {
2084        // Create synthetic segdet output: (6 + num_protos, N)
2085        // Detection format: [x1, y1, x2, y2, conf, class, mask_coeff_0..31]
2086        let num_protos = 32;
2087        let num_detections = 2;
2088        let num_features = 6 + num_protos;
2089
2090        // Build detection tensor
2091        let mut data = vec![0.0f32; num_features * num_detections];
2092        // Detection 0: passes threshold
2093        data[0] = 0.1; // x1[0]
2094        data[1] = 0.5; // x1[1]
2095        data[num_detections] = 0.1; // y1[0]
2096        data[num_detections + 1] = 0.5; // y1[1]
2097        data[2 * num_detections] = 0.4; // x2[0]
2098        data[2 * num_detections + 1] = 0.9; // x2[1]
2099        data[3 * num_detections] = 0.4; // y2[0]
2100        data[3 * num_detections + 1] = 0.9; // y2[1]
2101        data[4 * num_detections] = 0.9; // conf[0] - passes
2102        data[4 * num_detections + 1] = 0.3; // conf[1] - fails
2103        data[5 * num_detections] = 1.0; // class[0]
2104        data[5 * num_detections + 1] = 2.0; // class[1]
2105                                            // Fill mask coefficients with small values
2106        for i in 6..num_features {
2107            data[i * num_detections] = 0.1;
2108            data[i * num_detections + 1] = 0.1;
2109        }
2110
2111        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2112        let box_coords = output.slice(s![..4, ..]);
2113        let scores = output.slice(s![4..5, ..]);
2114        let classes = output.slice(s![5..6, ..]);
2115        let mask_coeff = output.slice(s![6.., ..]);
2116        // Create protos tensor: (proto_height, proto_width, num_protos)
2117        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2118
2119        let mut boxes = Vec::with_capacity(10);
2120        let mut masks = Vec::with_capacity(10);
2121        decode_yolo_split_end_to_end_segdet_float(
2122            box_coords,
2123            scores,
2124            classes,
2125            mask_coeff,
2126            protos.view(),
2127            0.5,
2128            &mut boxes,
2129            &mut masks,
2130        )
2131        .unwrap();
2132
2133        // Only detection 0 should pass
2134        assert_eq!(boxes.len(), 1);
2135        assert_eq!(masks.len(), 1);
2136        assert_eq!(boxes[0].label, 1);
2137        assert!((boxes[0].score - 0.9).abs() < 0.01);
2138    }
2139
2140    // ========================================================================
2141    // Tests for yolo_segmentation_to_mask
2142    // ========================================================================
2143
2144    #[test]
2145    fn test_segmentation_to_mask_basic() {
2146        // Create a 4x4x1 segmentation with values above and below threshold
2147        let data: Vec<u8> = vec![
2148            100, 200, 50, 150, // row 0
2149            10, 255, 128, 64, // row 1
2150            0, 127, 128, 255, // row 2
2151            64, 64, 192, 192, // row 3
2152        ];
2153        let segmentation = Array3::from_shape_vec((4, 4, 1), data).unwrap();
2154
2155        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2156
2157        // Values >= 128 should be 1, others 0
2158        assert_eq!(mask[[0, 0]], 0); // 100 < 128
2159        assert_eq!(mask[[0, 1]], 1); // 200 >= 128
2160        assert_eq!(mask[[0, 2]], 0); // 50 < 128
2161        assert_eq!(mask[[0, 3]], 1); // 150 >= 128
2162        assert_eq!(mask[[1, 1]], 1); // 255 >= 128
2163        assert_eq!(mask[[1, 2]], 1); // 128 >= 128
2164        assert_eq!(mask[[2, 0]], 0); // 0 < 128
2165        assert_eq!(mask[[2, 1]], 0); // 127 < 128
2166    }
2167
2168    #[test]
2169    fn test_segmentation_to_mask_all_above() {
2170        let segmentation = Array3::from_elem((4, 4, 1), 255u8);
2171        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2172        assert!(mask.iter().all(|&x| x == 1));
2173    }
2174
2175    #[test]
2176    fn test_segmentation_to_mask_all_below() {
2177        let segmentation = Array3::from_elem((4, 4, 1), 64u8);
2178        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2179        assert!(mask.iter().all(|&x| x == 0));
2180    }
2181
2182    #[test]
2183    fn test_segmentation_to_mask_invalid_shape() {
2184        let segmentation = Array3::from_elem((4, 4, 3), 128u8);
2185        let result = yolo_segmentation_to_mask(segmentation.view(), 128);
2186
2187        assert!(result.is_err());
2188        assert!(matches!(
2189            result,
2190            Err(crate::DecoderError::InvalidShape(s)) if s.contains("(H, W, 1)")
2191        ));
2192    }
2193
2194    // ========================================================================
2195    // Tests for protobox / NORM_LIMIT regression
2196    // ========================================================================
2197
2198    #[test]
2199    fn test_protobox_clamps_edge_coordinates() {
2200        // bbox with xmax=1.0 should not panic (OOB guard)
2201        let protos = Array3::<f32>::zeros((16, 16, 4));
2202        let view = protos.view();
2203        let roi = BoundingBox {
2204            xmin: 0.5,
2205            ymin: 0.5,
2206            xmax: 1.0,
2207            ymax: 1.0,
2208        };
2209        let result = protobox(&view, &roi);
2210        assert!(result.is_ok(), "protobox should accept xmax=1.0");
2211        let (cropped, _roi_norm) = result.unwrap();
2212        // Cropped region must have non-zero spatial dimensions
2213        assert!(cropped.shape()[0] > 0);
2214        assert!(cropped.shape()[1] > 0);
2215        assert_eq!(cropped.shape()[2], 4);
2216    }
2217
2218    #[test]
2219    fn test_protobox_rejects_wildly_out_of_range() {
2220        // bbox with coords > NORM_LIMIT (e.g. 3.0) returns error
2221        let protos = Array3::<f32>::zeros((16, 16, 4));
2222        let view = protos.view();
2223        let roi = BoundingBox {
2224            xmin: 0.0,
2225            ymin: 0.0,
2226            xmax: 3.0,
2227            ymax: 3.0,
2228        };
2229        let result = protobox(&view, &roi);
2230        assert!(
2231            matches!(result, Err(crate::DecoderError::InvalidShape(s)) if s.contains("un-normalized")),
2232            "protobox should reject coords > NORM_LIMIT"
2233        );
2234    }
2235
2236    #[test]
2237    fn test_protobox_accepts_slightly_over_one() {
2238        // bbox with coords at 1.5 (within NORM_LIMIT=2.0) succeeds
2239        let protos = Array3::<f32>::zeros((16, 16, 4));
2240        let view = protos.view();
2241        let roi = BoundingBox {
2242            xmin: 0.0,
2243            ymin: 0.0,
2244            xmax: 1.5,
2245            ymax: 1.5,
2246        };
2247        let result = protobox(&view, &roi);
2248        assert!(
2249            result.is_ok(),
2250            "protobox should accept coords <= NORM_LIMIT (2.0)"
2251        );
2252        let (cropped, _roi_norm) = result.unwrap();
2253        // Entire proto map should be selected when coords > 1.0 (clamped to boundary)
2254        assert_eq!(cropped.shape()[0], 16);
2255        assert_eq!(cropped.shape()[1], 16);
2256    }
2257
2258    #[test]
2259    fn test_segdet_float_proto_no_panic() {
2260        // Simulates YOLOv8n-seg: output0 = [116, 8400] (4 box + 80 class + 32 mask coeff)
2261        // output1 (protos) = [32, 160, 160]
2262        let num_proposals = 100; // enough to produce idx >= 32
2263        let num_classes = 80;
2264        let num_mask_coeffs = 32;
2265        let rows = 4 + num_classes + num_mask_coeffs; // 116
2266
2267        // Fill boxes with valid xywh data so some detections pass the threshold.
2268        // Layout is [116, num_proposals] row-major: row 0=cx, 1=cy, 2=w, 3=h,
2269        // rows 4..84=class scores, rows 84..116=mask coefficients.
2270        let mut data = vec![0.0f32; rows * num_proposals];
2271        for i in 0..num_proposals {
2272            let row = |r: usize| r * num_proposals + i;
2273            data[row(0)] = 320.0; // cx
2274            data[row(1)] = 320.0; // cy
2275            data[row(2)] = 50.0; // w
2276            data[row(3)] = 50.0; // h
2277            data[row(4)] = 0.9; // class-0 score
2278        }
2279        let boxes = ndarray::Array2::from_shape_vec((rows, num_proposals), data).unwrap();
2280
2281        // Protos must be in HWC order (decoder.rs protos_to_hwc converts
2282        // before calling into these functions).
2283        let protos = ndarray::Array3::<f32>::zeros((160, 160, num_mask_coeffs));
2284
2285        let mut output_boxes = Vec::with_capacity(300);
2286
2287        // This panicked before fix: mask_tensor.row(idx) with idx >= 32
2288        let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
2289            boxes.view(),
2290            protos.view(),
2291            0.5,
2292            0.7,
2293            Some(Nms::default()),
2294            &mut output_boxes,
2295        );
2296
2297        // Should produce detections (NMS will collapse many overlapping boxes)
2298        assert!(!output_boxes.is_empty());
2299        assert_eq!(proto_data.mask_coefficients.len(), output_boxes.len());
2300        // Each mask coefficient vector should have 32 elements
2301        for coeffs in &proto_data.mask_coefficients {
2302            assert_eq!(coeffs.len(), num_mask_coeffs);
2303        }
2304    }
2305
2306    // ========================================================================
2307    // Pre-NMS top-K cap (MAX_NMS_CANDIDATES)
2308    // ========================================================================
2309
2310    /// At very low score thresholds (e.g., t=0.01 on YOLOv8 with 8400×80
2311    /// candidates) almost every score passes the filter, feeding O(n²)
2312    /// NMS and a per-survivor mask matmul. The decoder caps the
2313    /// candidate set fed to NMS at `MAX_NMS_CANDIDATES` (Ultralytics
2314    /// default 30 000) to bound worst-case decode time.
2315    ///
2316    /// This regression test pumps 50 000 above-threshold candidates
2317    /// into `impl_yolo_segdet_get_boxes` with NMS bypassed (Nms=None)
2318    /// and a generous post-NMS cap. Before the fix, the function
2319    /// returned all 50 000; after the fix, exactly 30 000.
2320    #[test]
2321    fn test_pre_nms_cap_truncates_excess_candidates() {
2322        let n: usize = 50_000;
2323        let num_classes = 1;
2324
2325        // Identical valid boxes. Distinct scores (descending) so the
2326        // top-K cap keeps the highest-scoring ones in deterministic
2327        // order — letting us assert *which* ones survived.
2328        let mut boxes_data = Vec::with_capacity(n * 4);
2329        let mut scores_data = Vec::with_capacity(n * num_classes);
2330        for i in 0..n {
2331            boxes_data.extend_from_slice(&[0.1f32, 0.1, 0.5, 0.5]);
2332            // score_i = 0.99 - i * 1e-7 keeps everything well above 0.1
2333            // threshold but strictly decreasing.
2334            scores_data.push(0.99 - (i as f32) * 1e-7);
2335        }
2336        let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
2337        let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();
2338
2339        let result = impl_yolo_segdet_get_boxes::<XYXY, _, _>(
2340            boxes.view(),
2341            scores.view(),
2342            0.1,
2343            1.0,
2344            None,       // bypass NMS so we measure the cap, not suppression
2345            usize::MAX, // no post-NMS truncation
2346        );
2347
2348        assert_eq!(
2349            result.len(),
2350            crate::yolo::MAX_NMS_CANDIDATES,
2351            "pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
2352            result.len()
2353        );
2354        // Top-K survivors: highest scores were the first n indices,
2355        // so survivor 0 must have score ~0.99.
2356        let top_score = result[0].0.score;
2357        assert!(
2358            top_score > 0.98,
2359            "highest-ranked survivor should have the largest score, got {top_score}"
2360        );
2361    }
2362
2363    /// Counterpart for the quantized split path. Same contract: feed
2364    /// more than `MAX_NMS_CANDIDATES` survivors above the quantized
2365    /// threshold, confirm `impl_yolo_split_segdet_quant_get_boxes`
2366    /// truncates before NMS.
2367    #[test]
2368    fn test_pre_nms_cap_truncates_excess_candidates_quant() {
2369        use crate::Quantization;
2370        let n: usize = 50_000;
2371        let num_classes = 1;
2372
2373        // i8 boxes with simple scale/zp; the box value 50 dequantizes
2374        // to 0.5 with scale=0.01, zp=0 — fine for a flat box set.
2375        let boxes_data = (0..n).flat_map(|_| [10i8, 10, 50, 50]).collect::<Vec<_>>();
2376        let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
2377        let quant_boxes = Quantization {
2378            scale: 0.01,
2379            zero_point: 0,
2380        };
2381
2382        // u8 scores: distinct descending values, all well above threshold.
2383        // value 250 → 0.98 with scale 0.00392, zp 0.
2384        // value (250 - i % 200) keeps a wide spread above the dequant
2385        // threshold of 0.5.
2386        let scores_data: Vec<u8> = (0..n)
2387            .map(|i| 250u8.saturating_sub((i % 200) as u8))
2388            .collect();
2389        let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();
2390        let quant_scores = Quantization {
2391            scale: 0.00392,
2392            zero_point: 0,
2393        };
2394
2395        let result = impl_yolo_split_segdet_quant_get_boxes::<XYXY, _, _>(
2396            (boxes.view(), quant_boxes),
2397            (scores.view(), quant_scores),
2398            0.1,
2399            1.0,
2400            None,
2401            usize::MAX,
2402        );
2403
2404        assert_eq!(
2405            result.len(),
2406            crate::yolo::MAX_NMS_CANDIDATES,
2407            "quant path pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
2408            result.len()
2409        );
2410    }
2411
2412    /// Regression test for HAILORT_BUG.md — the YoloSegDet path
2413    /// (combined `(4 + nc + nm, N)` detection tensor + separate protos)
2414    /// must pair each surviving detection with the mask coefficient
2415    /// row at the SAME anchor index the box came from. The validator
2416    /// sees this path miss the pairing under schema-v2 Hailo inputs
2417    /// (mAP collapse from 46.8 → 3.65 while mask IoU stays at 66.9,
2418    /// the fingerprint of mask-to-detection misalignment).
2419    ///
2420    /// Construction: three anchors with distinct mask-coef signatures
2421    /// that, after dot(coefs, protos) + sigmoid, produce HIGH vs LOW
2422    /// mask pixel values. Two anchors survive (one high, one low); if
2423    /// the mask row is looked up at the wrong index, the per-detection
2424    /// mean mask value would cross the threshold and we catch it.
2425    #[test]
2426    fn segdet_combined_tensor_pairs_detection_with_matching_mask_row() {
2427        let nc = 2; // num_classes
2428        let nm = 2; // num_protos
2429        let n = 3; // num_anchors
2430        let feat = 4 + nc + nm; // 8
2431
2432        // Tensor layout: (8, 3) rows=features, cols=anchors.
2433        // Row indices:  0..4 = xywh, 4..6 = scores, 6..8 = mask_coefs.
2434        //
2435        //         anchor 0 | anchor 1 | anchor 2
2436        // xc       0.2      | 0.5      | 0.8
2437        // yc       0.2      | 0.5      | 0.8
2438        // w        0.1      | 0.1      | 0.1
2439        // h        0.1      | 0.1      | 0.1
2440        // s[0]     0.9      | 0.0      | 0.8   (class 0)
2441        // s[1]     0.0      | 0.0      | 0.0   (class 1 — always loses)
2442        // m[0]     3.0      | 0.0      | -3.0  (high for a0, low for a2)
2443        // m[1]     3.0      | 0.0      | -3.0
2444        //
2445        // Proto[0] = Proto[1] = all-ones (8x8), so
2446        //   mask(a0) = sigmoid(3 + 3) ≈ 0.9975 → 254
2447        //   mask(a2) = sigmoid(-3 + -3) ≈ 0.0025 → 1
2448        // 250-point gap makes any misalignment trivially detectable.
2449        let mut data = vec![0.0f32; feat * n];
2450        let set = |d: &mut [f32], r: usize, c: usize, v: f32| d[r * n + c] = v;
2451        set(&mut data, 0, 0, 0.2);
2452        set(&mut data, 1, 0, 0.2);
2453        set(&mut data, 2, 0, 0.1);
2454        set(&mut data, 3, 0, 0.1);
2455        set(&mut data, 0, 1, 0.5);
2456        set(&mut data, 1, 1, 0.5);
2457        set(&mut data, 2, 1, 0.1);
2458        set(&mut data, 3, 1, 0.1);
2459        set(&mut data, 0, 2, 0.8);
2460        set(&mut data, 1, 2, 0.8);
2461        set(&mut data, 2, 2, 0.1);
2462        set(&mut data, 3, 2, 0.1);
2463        set(&mut data, 4, 0, 0.9);
2464        set(&mut data, 4, 2, 0.8);
2465        set(&mut data, 6, 0, 3.0);
2466        set(&mut data, 7, 0, 3.0);
2467        set(&mut data, 6, 2, -3.0);
2468        set(&mut data, 7, 2, -3.0);
2469
2470        let output = Array2::from_shape_vec((feat, n), data).unwrap();
2471        let protos = Array3::<f32>::from_elem((8, 8, nm), 1.0);
2472
2473        let mut boxes: Vec<DetectBox> = Vec::with_capacity(10);
2474        let mut masks: Vec<Segmentation> = Vec::with_capacity(10);
2475        decode_yolo_segdet_float(
2476            output.view(),
2477            protos.view(),
2478            0.5,
2479            0.5,
2480            Some(Nms::ClassAgnostic),
2481            &mut boxes,
2482            &mut masks,
2483        )
2484        .unwrap();
2485
2486        assert_eq!(
2487            boxes.len(),
2488            2,
2489            "two anchors above threshold should survive (a0 score=0.9, a2 score=0.8); got {}",
2490            boxes.len()
2491        );
2492
2493        // Build a (anchor_index → mask_mean) mapping from the results.
2494        // Anchor 0 has centre (0.2, 0.2), anchor 2 has centre (0.8,
2495        // 0.8). The DetectBox bbox is the post-XYWH-to-XYXY conversion
2496        // of the original xywh; cropping inside protobox may shrink it,
2497        // so match by centre (0.2 vs 0.8) rather than exact bbox.
2498        for (b, m) in boxes.iter().zip(masks.iter()) {
2499            let cx = (b.bbox.xmin + b.bbox.xmax) * 0.5;
2500            let mean = {
2501                let s = &m.segmentation;
2502                let total: u32 = s.iter().map(|&v| v as u32).sum();
2503                total as f32 / s.len() as f32
2504            };
2505            if cx < 0.3 {
2506                // anchor 0 — expect HIGH mask values ≈ 254
2507                assert!(
2508                    mean > 200.0,
2509                    "anchor 0 detection (centre {cx:.2}) should have high-value mask; got mean {mean}"
2510                );
2511            } else if cx > 0.7 {
2512                // anchor 2 — expect LOW mask values ≈ 1
2513                assert!(
2514                    mean < 50.0,
2515                    "anchor 2 detection (centre {cx:.2}) should have low-value mask; got mean {mean}"
2516                );
2517            } else {
2518                panic!("unexpected detection centre {cx:.2}");
2519            }
2520        }
2521    }
2522}