Skip to main content

edgefirst_decoder/
yolo.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::fmt::Debug;
5
6use ndarray::{
7    parallel::prelude::{IntoParallelIterator, ParallelIterator},
8    s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
9};
10use num_traits::{AsPrimitive, Float, PrimInt, Signed};
11use rayon::slice::ParallelSliceMut;
12
13use crate::{
14    byte::{
15        nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
16        postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
17    },
18    configs::Nms,
19    dequant_detect_box,
20    float::{
21        nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
22        postprocess_boxes_float, postprocess_boxes_index_float,
23    },
24    BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, Quantization,
25    Segmentation, XYWH, XYXY,
26};
27
/// Maximum number of above-threshold candidates fed to NMS.
///
/// At very low score thresholds (e.g., t=0.01 on YOLOv8 with
/// 8400 anchors × 80 classes), the number of survivors approaches
/// the full 672 000-entry score grid. NMS is O(n²) and the
/// downstream mask matmul runs once per survivor, so an
/// unbounded set produces minutes-per-frame decode times.
///
/// `MAX_NMS_CANDIDATES` matches the Ultralytics `max_nms` default
/// and is applied as a top-K-by-score truncation immediately
/// before NMS. Candidates ranked beyond the cap (i.e. the
/// lowest-scoring ones) are silently dropped — at the score
/// thresholds where the cap activates, the bottom of the candidate
/// list is dominated by noise that NMS would discard anyway.
pub(crate) const MAX_NMS_CANDIDATES: usize = 30_000;
43
44/// Truncate `boxes` to the highest-scoring `MAX_NMS_CANDIDATES`
45/// entries in-place when the input exceeds the cap. No-op
46/// otherwise. The sort is unstable and parallel — order among
47/// equal-score boxes is not guaranteed.
48fn truncate_to_top_k_by_score<E: Send>(boxes: &mut Vec<(DetectBox, E)>) {
49    if boxes.len() > MAX_NMS_CANDIDATES {
50        boxes.par_sort_unstable_by(|a, b| b.0.score.total_cmp(&a.0.score));
51        boxes.truncate(MAX_NMS_CANDIDATES);
52    }
53}
54
55/// Quantized counterpart of [`truncate_to_top_k_by_score`]. Sorts on
56/// the raw quantized score (which preserves order under monotonic
57/// dequantization).
58fn truncate_to_top_k_by_score_quant<S: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send>(
59    boxes: &mut Vec<(DetectBoxQuantized<S>, E)>,
60) {
61    if boxes.len() > MAX_NMS_CANDIDATES {
62        boxes.par_sort_unstable_by(|a, b| b.0.score.cmp(&a.0.score));
63        boxes.truncate(MAX_NMS_CANDIDATES);
64    }
65}
66
67/// Dispatches to the appropriate NMS function based on mode for float boxes.
68fn dispatch_nms_float(nms: Option<Nms>, iou: f32, boxes: Vec<DetectBox>) -> Vec<DetectBox> {
69    match nms {
70        Some(Nms::ClassAgnostic) => nms_float(iou, boxes),
71        Some(Nms::ClassAware) => nms_class_aware_float(iou, boxes),
72        None => boxes, // bypass NMS
73    }
74}
75
76/// Dispatches to the appropriate NMS function based on mode for float boxes
77/// with extra data.
78pub(super) fn dispatch_nms_extra_float<E: Send + Sync>(
79    nms: Option<Nms>,
80    iou: f32,
81    boxes: Vec<(DetectBox, E)>,
82) -> Vec<(DetectBox, E)> {
83    match nms {
84        Some(Nms::ClassAgnostic) => nms_extra_float(iou, boxes),
85        Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, boxes),
86        None => boxes, // bypass NMS
87    }
88}
89
90/// Dispatches to the appropriate NMS function based on mode for quantized
91/// boxes.
92fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
93    nms: Option<Nms>,
94    iou: f32,
95    boxes: Vec<DetectBoxQuantized<SCORE>>,
96) -> Vec<DetectBoxQuantized<SCORE>> {
97    match nms {
98        Some(Nms::ClassAgnostic) => nms_int(iou, boxes),
99        Some(Nms::ClassAware) => nms_class_aware_int(iou, boxes),
100        None => boxes, // bypass NMS
101    }
102}
103
104/// Dispatches to the appropriate NMS function based on mode for quantized boxes
105/// with extra data.
106fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
107    nms: Option<Nms>,
108    iou: f32,
109    boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
110) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
111    match nms {
112        Some(Nms::ClassAgnostic) => nms_extra_int(iou, boxes),
113        Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, boxes),
114        None => boxes, // bypass NMS
115    }
116}
117
118/// Decodes YOLO detection outputs from quantized tensors into detection boxes.
119///
120/// Boxes are expected to be in XYWH format.
121///
122/// Expected shapes of inputs:
123/// - output: (4 + num_classes, num_boxes)
124pub fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
125    output: (ArrayView2<BOX>, Quantization),
126    score_threshold: f32,
127    iou_threshold: f32,
128    nms: Option<Nms>,
129    output_boxes: &mut Vec<DetectBox>,
130) where
131    f32: AsPrimitive<BOX>,
132{
133    impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
134}
135
136/// Decodes YOLO detection outputs from float tensors into detection boxes.
137///
138/// Boxes are expected to be in XYWH format.
139///
140/// Expected shapes of inputs:
141/// - output: (4 + num_classes, num_boxes)
142pub fn decode_yolo_det_float<T>(
143    output: ArrayView2<T>,
144    score_threshold: f32,
145    iou_threshold: f32,
146    nms: Option<Nms>,
147    output_boxes: &mut Vec<DetectBox>,
148) where
149    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
150    f32: AsPrimitive<T>,
151{
152    impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
153}
154
155/// Decodes YOLO detection and segmentation outputs from quantized tensors into
156/// detection boxes and segmentation masks.
157///
158/// Boxes are expected to be in XYWH format.
159///
160/// Expected shapes of inputs:
161/// - boxes: (4 + num_classes + num_protos, num_boxes)
162/// - protos: (proto_height, proto_width, num_protos)
163///
164/// # Errors
165/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
166pub fn decode_yolo_segdet_quant<
167    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
168    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
169>(
170    boxes: (ArrayView2<BOX>, Quantization),
171    protos: (ArrayView3<PROTO>, Quantization),
172    score_threshold: f32,
173    iou_threshold: f32,
174    nms: Option<Nms>,
175    output_boxes: &mut Vec<DetectBox>,
176    output_masks: &mut Vec<Segmentation>,
177) -> Result<(), crate::DecoderError>
178where
179    f32: AsPrimitive<BOX>,
180{
181    impl_yolo_segdet_quant::<XYWH, _, _>(
182        boxes,
183        protos,
184        score_threshold,
185        iou_threshold,
186        nms,
187        output_boxes,
188        output_masks,
189    )
190}
191
192/// Decodes YOLO detection and segmentation outputs from float tensors into
193/// detection boxes and segmentation masks.
194///
195/// Boxes are expected to be in XYWH format.
196///
197/// Expected shapes of inputs:
198/// - boxes: (4 + num_classes + num_protos, num_boxes)
199/// - protos: (proto_height, proto_width, num_protos)
200///
201/// # Errors
202/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
203pub fn decode_yolo_segdet_float<T>(
204    boxes: ArrayView2<T>,
205    protos: ArrayView3<T>,
206    score_threshold: f32,
207    iou_threshold: f32,
208    nms: Option<Nms>,
209    output_boxes: &mut Vec<DetectBox>,
210    output_masks: &mut Vec<Segmentation>,
211) -> Result<(), crate::DecoderError>
212where
213    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
214    f32: AsPrimitive<T>,
215{
216    impl_yolo_segdet_float::<XYWH, _, _>(
217        boxes,
218        protos,
219        score_threshold,
220        iou_threshold,
221        nms,
222        output_boxes,
223        output_masks,
224    )
225}
226
227/// Decodes YOLO split detection outputs from quantized tensors into detection
228/// boxes.
229///
230/// Boxes are expected to be in XYWH format.
231///
232/// Expected shapes of inputs:
233/// - boxes: (4, num_boxes)
234/// - scores: (num_classes, num_boxes)
235///
236/// # Panics
237/// Panics if shapes don't match the expected dimensions.
238pub fn decode_yolo_split_det_quant<
239    BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
240    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
241>(
242    boxes: (ArrayView2<BOX>, Quantization),
243    scores: (ArrayView2<SCORE>, Quantization),
244    score_threshold: f32,
245    iou_threshold: f32,
246    nms: Option<Nms>,
247    output_boxes: &mut Vec<DetectBox>,
248) where
249    f32: AsPrimitive<SCORE>,
250{
251    impl_yolo_split_quant::<XYWH, _, _>(
252        boxes,
253        scores,
254        score_threshold,
255        iou_threshold,
256        nms,
257        output_boxes,
258    );
259}
260
261/// Decodes YOLO split detection outputs from float tensors into detection
262/// boxes.
263///
264/// Boxes are expected to be in XYWH format.
265///
266/// Expected shapes of inputs:
267/// - boxes: (4, num_boxes)
268/// - scores: (num_classes, num_boxes)
269///
270/// # Panics
271/// Panics if shapes don't match the expected dimensions.
272pub fn decode_yolo_split_det_float<T>(
273    boxes: ArrayView2<T>,
274    scores: ArrayView2<T>,
275    score_threshold: f32,
276    iou_threshold: f32,
277    nms: Option<Nms>,
278    output_boxes: &mut Vec<DetectBox>,
279) where
280    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
281    f32: AsPrimitive<T>,
282{
283    impl_yolo_split_float::<XYWH, _, _>(
284        boxes,
285        scores,
286        score_threshold,
287        iou_threshold,
288        nms,
289        output_boxes,
290    );
291}
292
293/// Decodes YOLO split detection segmentation outputs from quantized tensors
294/// into detection boxes and segmentation masks.
295///
296/// Boxes are expected to be in XYWH format.
297///
298/// Expected shapes of inputs:
299/// - boxes_tensor: (4, num_boxes)
300/// - scores_tensor: (num_classes, num_boxes)
301/// - mask_tensor: (num_protos, num_boxes)
302/// - protos: (proto_height, proto_width, num_protos)
303///
304/// # Errors
305/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
306#[allow(clippy::too_many_arguments)]
307pub fn decode_yolo_split_segdet<
308    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
309    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
310    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
311    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
312>(
313    boxes: (ArrayView2<BOX>, Quantization),
314    scores: (ArrayView2<SCORE>, Quantization),
315    mask_coeff: (ArrayView2<MASK>, Quantization),
316    protos: (ArrayView3<PROTO>, Quantization),
317    score_threshold: f32,
318    iou_threshold: f32,
319    nms: Option<Nms>,
320    output_boxes: &mut Vec<DetectBox>,
321    output_masks: &mut Vec<Segmentation>,
322) -> Result<(), crate::DecoderError>
323where
324    f32: AsPrimitive<SCORE>,
325{
326    impl_yolo_split_segdet_quant::<XYWH, _, _, _, _>(
327        boxes,
328        scores,
329        mask_coeff,
330        protos,
331        score_threshold,
332        iou_threshold,
333        nms,
334        output_boxes,
335        output_masks,
336    )
337}
338
339/// Decodes YOLO split detection segmentation outputs from float tensors
340/// into detection boxes and segmentation masks.
341///
342/// Boxes are expected to be in XYWH format.
343///
344/// Expected shapes of inputs:
345/// - boxes_tensor: (4, num_boxes)
346/// - scores_tensor: (num_classes, num_boxes)
347/// - mask_tensor: (num_protos, num_boxes)
348/// - protos: (proto_height, proto_width, num_protos)
349///
350/// # Errors
351/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
352#[allow(clippy::too_many_arguments)]
353pub fn decode_yolo_split_segdet_float<T>(
354    boxes: ArrayView2<T>,
355    scores: ArrayView2<T>,
356    mask_coeff: ArrayView2<T>,
357    protos: ArrayView3<T>,
358    score_threshold: f32,
359    iou_threshold: f32,
360    nms: Option<Nms>,
361    output_boxes: &mut Vec<DetectBox>,
362    output_masks: &mut Vec<Segmentation>,
363) -> Result<(), crate::DecoderError>
364where
365    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
366    f32: AsPrimitive<T>,
367{
368    impl_yolo_split_segdet_float::<XYWH, _, _, _, _>(
369        boxes,
370        scores,
371        mask_coeff,
372        protos,
373        score_threshold,
374        iou_threshold,
375        nms,
376        output_boxes,
377        output_masks,
378    )
379}
380
381/// Decodes end-to-end YOLO detection outputs (post-NMS from model).
382/// Expects an array of shape `(6, N)`, where the first dimension (rows)
383/// corresponds to the 6 per-detection features
384/// `[x1, y1, x2, y2, conf, class]` and the second dimension (columns)
385/// indexes the `N` detections.
386/// Boxes are output directly without NMS (the model already applied NMS).
387///
388/// Coordinates may be normalized `[0, 1]` or absolute pixel values depending
389/// on the model configuration. The caller should check
390/// `decoder.normalized_boxes()` to determine which.
391///
392/// # Errors
393///
394/// Returns `DecoderError::InvalidShape` if `output` has fewer than 6 rows.
395pub fn decode_yolo_end_to_end_det_float<T>(
396    output: ArrayView2<T>,
397    score_threshold: f32,
398    output_boxes: &mut Vec<DetectBox>,
399) -> Result<(), crate::DecoderError>
400where
401    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
402    f32: AsPrimitive<T>,
403{
404    // Validate input shape: need at least 6 rows (x1, y1, x2, y2, conf, class)
405    if output.shape()[0] < 6 {
406        return Err(crate::DecoderError::InvalidShape(format!(
407            "End-to-end detection output requires at least 6 rows, got {}",
408            output.shape()[0]
409        )));
410    }
411
412    // Input shape: (6, N) -> transpose to (N, 4) for boxes and (N, 1) for scores
413    let boxes = output.slice(s![0..4, ..]).reversed_axes();
414    let scores = output.slice(s![4..5, ..]).reversed_axes();
415    let classes = output.slice(s![5, ..]);
416    let mut boxes =
417        postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
418    boxes.truncate(output_boxes.capacity());
419    output_boxes.clear();
420    for (mut b, i) in boxes.into_iter() {
421        b.label = classes[i].as_() as usize;
422        output_boxes.push(b);
423    }
424    // No NMS — model output is already post-NMS
425    Ok(())
426}
427
428/// Decodes end-to-end YOLO detection + segmentation outputs (post-NMS from
429/// model).
430///
431/// Input shapes:
432/// - detection: (6 + num_protos, N) where rows are [x1, y1, x2, y2, conf,
433///   class, mask_coeff_0, ..., mask_coeff_31]
434/// - protos: (proto_height, proto_width, num_protos)
435///
436/// Boxes are output directly without NMS (model already applied NMS).
437/// Coordinates may be normalized [0,1] or pixel values depending on model
438/// config.
439///
440/// # Errors
441///
442/// Returns `DecoderError::InvalidShape` if:
443/// - output has fewer than 7 rows (6 base + at least 1 mask coefficient)
444/// - protos shape doesn't match mask coefficients count
445pub fn decode_yolo_end_to_end_segdet_float<T>(
446    output: ArrayView2<T>,
447    protos: ArrayView3<T>,
448    score_threshold: f32,
449    output_boxes: &mut Vec<DetectBox>,
450    output_masks: &mut Vec<crate::Segmentation>,
451) -> Result<(), crate::DecoderError>
452where
453    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
454    f32: AsPrimitive<T>,
455{
456    let (boxes, scores, classes, mask_coeff) =
457        postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
458    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
459        boxes,
460        scores,
461        classes,
462        score_threshold,
463        output_boxes.capacity(),
464    );
465
466    // No NMS — model output is already post-NMS
467
468    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
469}
470
471/// Decodes split end-to-end YOLO detection outputs (post-NMS from model).
472///
473/// Input shapes (after batch dim removed):
474/// - boxes: (4, N) — xyxy pixel coordinates
475/// - scores: (1, N) — confidence of the top class
476/// - classes: (1, N) — class index of the top class
477///
478/// Boxes are output directly without NMS (model already applied NMS).
479pub fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
480    boxes: ArrayView2<T>,
481    scores: ArrayView2<T>,
482    classes: ArrayView2<T>,
483    score_threshold: f32,
484    output_boxes: &mut Vec<DetectBox>,
485) -> Result<(), crate::DecoderError> {
486    let n = boxes.shape()[1];
487
488    output_boxes.clear();
489
490    let (boxes, scores, classes) = postprocess_yolo_split_end_to_end_det(boxes, scores, &classes)?;
491
492    for i in 0..n {
493        let score: f32 = scores[[i, 0]].as_();
494        if score < score_threshold {
495            continue;
496        }
497        if output_boxes.len() >= output_boxes.capacity() {
498            break;
499        }
500        output_boxes.push(DetectBox {
501            bbox: BoundingBox {
502                xmin: boxes[[i, 0]].as_(),
503                ymin: boxes[[i, 1]].as_(),
504                xmax: boxes[[i, 2]].as_(),
505                ymax: boxes[[i, 3]].as_(),
506            },
507            score,
508            label: classes[i].as_() as usize,
509        });
510    }
511    Ok(())
512}
513
514/// Decodes split end-to-end YOLO detection + segmentation outputs.
515///
516/// Input shapes (after batch dim removed):
517/// - boxes: (4, N) — xyxy pixel coordinates
518/// - scores: (1, N) — confidence
519/// - classes: (1, N) — class index
520/// - mask_coeff: (num_protos, N) — mask coefficients per detection
521/// - protos: (proto_h, proto_w, num_protos) — prototype masks
522#[allow(clippy::too_many_arguments)]
523pub fn decode_yolo_split_end_to_end_segdet_float<T>(
524    boxes: ArrayView2<T>,
525    scores: ArrayView2<T>,
526    classes: ArrayView2<T>,
527    mask_coeff: ArrayView2<T>,
528    protos: ArrayView3<T>,
529    score_threshold: f32,
530    output_boxes: &mut Vec<DetectBox>,
531    output_masks: &mut Vec<crate::Segmentation>,
532) -> Result<(), crate::DecoderError>
533where
534    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
535    f32: AsPrimitive<T>,
536{
537    let (boxes, scores, classes, mask_coeff) =
538        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
539    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
540        boxes,
541        scores,
542        classes,
543        score_threshold,
544        output_boxes.capacity(),
545    );
546
547    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
548}
549
550#[allow(clippy::type_complexity)]
551pub(crate) fn postprocess_yolo_end_to_end_segdet<'a, T>(
552    output: &'a ArrayView2<'_, T>,
553    num_protos: usize,
554) -> Result<
555    (
556        ArrayView2<'a, T>,
557        ArrayView2<'a, T>,
558        ArrayView1<'a, T>,
559        ArrayView2<'a, T>,
560    ),
561    crate::DecoderError,
562> {
563    // Validate input shape: need at least 7 rows (6 base + at least 1 mask coeff)
564    if output.shape()[0] < 7 {
565        return Err(crate::DecoderError::InvalidShape(format!(
566            "End-to-end segdet output requires at least 7 rows, got {}",
567            output.shape()[0]
568        )));
569    }
570
571    let num_mask_coeffs = output.shape()[0] - 6;
572    if num_mask_coeffs != num_protos {
573        return Err(crate::DecoderError::InvalidShape(format!(
574            "Mask coefficients count ({}) doesn't match protos count ({})",
575            num_mask_coeffs, num_protos
576        )));
577    }
578
579    // Input shape: (6+num_protos, N) -> transpose for postprocessing
580    let boxes = output.slice(s![0..4, ..]).reversed_axes();
581    let scores = output.slice(s![4..5, ..]).reversed_axes();
582    let classes = output.slice(s![5, ..]);
583    let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
584    Ok((boxes, scores, classes, mask_coeff))
585}
586
587/// Postprocess yolo split end to end det by reversing axes of boxes,
588/// scores, and flattening the class tensor.
589/// Expected input shapes:
590/// - boxes: (4, N)
591/// - scores: (1, N)
592/// - classes: (1, N)
593#[allow(clippy::type_complexity)]
594pub(crate) fn postprocess_yolo_split_end_to_end_det<'a, 'b, 'c, BOXES, SCORES, CLASS>(
595    boxes: ArrayView2<'a, BOXES>,
596    scores: ArrayView2<'b, SCORES>,
597    classes: &'c ArrayView2<CLASS>,
598) -> Result<
599    (
600        ArrayView2<'a, BOXES>,
601        ArrayView2<'b, SCORES>,
602        ArrayView1<'c, CLASS>,
603    ),
604    crate::DecoderError,
605> {
606    let num_boxes = boxes.shape()[1];
607    if boxes.shape()[0] != 4 {
608        return Err(crate::DecoderError::InvalidShape(format!(
609            "Split end-to-end box_coords must be 4, got {}",
610            boxes.shape()[0]
611        )));
612    }
613
614    if scores.shape()[0] != 1 {
615        return Err(crate::DecoderError::InvalidShape(format!(
616            "Split end-to-end scores num_classes must be 1, got {}",
617            scores.shape()[0]
618        )));
619    }
620
621    if classes.shape()[0] != 1 {
622        return Err(crate::DecoderError::InvalidShape(format!(
623            "Split end-to-end classes num_classes must be 1, got {}",
624            classes.shape()[0]
625        )));
626    }
627
628    if scores.shape()[1] != num_boxes {
629        return Err(crate::DecoderError::InvalidShape(format!(
630            "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
631            num_boxes,
632            scores.shape()[1]
633        )));
634    }
635
636    if classes.shape()[1] != num_boxes {
637        return Err(crate::DecoderError::InvalidShape(format!(
638            "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
639            num_boxes,
640            classes.shape()[1]
641        )));
642    }
643
644    let boxes = boxes.reversed_axes();
645    let scores = scores.reversed_axes();
646    let classes = classes.slice(s![0, ..]);
647    Ok((boxes, scores, classes))
648}
649
650/// Postprocess yolo split end to end segdet by reversing axes of boxes,
651/// scores, mask tensors and flattening the class tensor.
652#[allow(clippy::type_complexity)]
653pub(crate) fn postprocess_yolo_split_end_to_end_segdet<
654    'a,
655    'b,
656    'c,
657    'd,
658    BOXES,
659    SCORES,
660    CLASS,
661    MASK,
662>(
663    boxes: ArrayView2<'a, BOXES>,
664    scores: ArrayView2<'b, SCORES>,
665    classes: &'c ArrayView2<CLASS>,
666    mask_coeff: ArrayView2<'d, MASK>,
667) -> Result<
668    (
669        ArrayView2<'a, BOXES>,
670        ArrayView2<'b, SCORES>,
671        ArrayView1<'c, CLASS>,
672        ArrayView2<'d, MASK>,
673    ),
674    crate::DecoderError,
675> {
676    let num_boxes = boxes.shape()[1];
677    if boxes.shape()[0] != 4 {
678        return Err(crate::DecoderError::InvalidShape(format!(
679            "Split end-to-end box_coords must be 4, got {}",
680            boxes.shape()[0]
681        )));
682    }
683
684    if scores.shape()[0] != 1 {
685        return Err(crate::DecoderError::InvalidShape(format!(
686            "Split end-to-end scores num_classes must be 1, got {}",
687            scores.shape()[0]
688        )));
689    }
690
691    if classes.shape()[0] != 1 {
692        return Err(crate::DecoderError::InvalidShape(format!(
693            "Split end-to-end classes num_classes must be 1, got {}",
694            classes.shape()[0]
695        )));
696    }
697
698    if scores.shape()[1] != num_boxes {
699        return Err(crate::DecoderError::InvalidShape(format!(
700            "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
701            num_boxes,
702            scores.shape()[1]
703        )));
704    }
705
706    if classes.shape()[1] != num_boxes {
707        return Err(crate::DecoderError::InvalidShape(format!(
708            "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
709            num_boxes,
710            classes.shape()[1]
711        )));
712    }
713
714    if mask_coeff.shape()[1] != num_boxes {
715        return Err(crate::DecoderError::InvalidShape(format!(
716            "Split end-to-end mask_coeff must have same num_boxes as boxes ({}), got {}",
717            num_boxes,
718            mask_coeff.shape()[1]
719        )));
720    }
721
722    let boxes = boxes.reversed_axes();
723    let scores = scores.reversed_axes();
724    let classes = classes.slice(s![0, ..]);
725    let mask_coeff = mask_coeff.reversed_axes();
726    Ok((boxes, scores, classes, mask_coeff))
727}
728/// Internal implementation of YOLO decoding for quantized tensors.
729///
730/// Expected shapes of inputs:
731/// - output: (4 + num_classes, num_boxes)
732pub(crate) fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
733    output: (ArrayView2<T>, Quantization),
734    score_threshold: f32,
735    iou_threshold: f32,
736    nms: Option<Nms>,
737    output_boxes: &mut Vec<DetectBox>,
738) where
739    f32: AsPrimitive<T>,
740{
741    let (boxes, quant_boxes) = output;
742    let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
743
744    let boxes = {
745        let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
746        postprocess_boxes_quant::<B, _, _>(
747            score_threshold,
748            boxes_tensor,
749            scores_tensor,
750            quant_boxes,
751        )
752    };
753
754    let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
755    let len = output_boxes.capacity().min(boxes.len());
756    output_boxes.clear();
757    for b in boxes.iter().take(len) {
758        output_boxes.push(dequant_detect_box(b, quant_boxes));
759    }
760}
761
762/// Internal implementation of YOLO decoding for float tensors.
763///
764/// Expected shapes of inputs:
765/// - output: (4 + num_classes, num_boxes)
766pub(crate) fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
767    output: ArrayView2<T>,
768    score_threshold: f32,
769    iou_threshold: f32,
770    nms: Option<Nms>,
771    output_boxes: &mut Vec<DetectBox>,
772) where
773    f32: AsPrimitive<T>,
774{
775    let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
776    let boxes =
777        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
778    let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
779    let len = output_boxes.capacity().min(boxes.len());
780    output_boxes.clear();
781    for b in boxes.into_iter().take(len) {
782        output_boxes.push(b);
783    }
784}
785
786/// Internal implementation of YOLO split detection decoding for quantized
787/// tensors.
788///
789/// Expected shapes of inputs:
790/// - boxes: (4, num_boxes)
791/// - scores: (num_classes, num_boxes)
792///
793/// # Panics
794/// Panics if shapes don't match the expected dimensions.
795pub(crate) fn impl_yolo_split_quant<
796    B: BBoxTypeTrait,
797    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
798    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
799>(
800    boxes: (ArrayView2<BOX>, Quantization),
801    scores: (ArrayView2<SCORE>, Quantization),
802    score_threshold: f32,
803    iou_threshold: f32,
804    nms: Option<Nms>,
805    output_boxes: &mut Vec<DetectBox>,
806) where
807    f32: AsPrimitive<SCORE>,
808{
809    let (boxes_tensor, quant_boxes) = boxes;
810    let (scores_tensor, quant_scores) = scores;
811
812    let boxes_tensor = boxes_tensor.reversed_axes();
813    let scores_tensor = scores_tensor.reversed_axes();
814
815    let boxes = {
816        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
817        postprocess_boxes_quant::<B, _, _>(
818            score_threshold,
819            boxes_tensor,
820            scores_tensor,
821            quant_boxes,
822        )
823    };
824
825    let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
826    let len = output_boxes.capacity().min(boxes.len());
827    output_boxes.clear();
828    for b in boxes.iter().take(len) {
829        output_boxes.push(dequant_detect_box(b, quant_scores));
830    }
831}
832
833/// Internal implementation of YOLO split detection decoding for float tensors.
834///
835/// Expected shapes of inputs:
836/// - boxes: (4, num_boxes)
837/// - scores: (num_classes, num_boxes)
838///
839/// # Panics
840/// Panics if shapes don't match the expected dimensions.
841pub(crate) fn impl_yolo_split_float<
842    B: BBoxTypeTrait,
843    BOX: Float + AsPrimitive<f32> + Send + Sync,
844    SCORE: Float + AsPrimitive<f32> + Send + Sync,
845>(
846    boxes_tensor: ArrayView2<BOX>,
847    scores_tensor: ArrayView2<SCORE>,
848    score_threshold: f32,
849    iou_threshold: f32,
850    nms: Option<Nms>,
851    output_boxes: &mut Vec<DetectBox>,
852) where
853    f32: AsPrimitive<SCORE>,
854{
855    let boxes_tensor = boxes_tensor.reversed_axes();
856    let scores_tensor = scores_tensor.reversed_axes();
857    let boxes =
858        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
859    let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
860    let len = output_boxes.capacity().min(boxes.len());
861    output_boxes.clear();
862    for b in boxes.into_iter().take(len) {
863        output_boxes.push(b);
864    }
865}
866
867/// Internal implementation of YOLO detection segmentation decoding for
868/// quantized tensors.
869///
870/// Expected shapes of inputs:
871/// - boxes: (4 + num_classes + num_protos, num_boxes)
872/// - protos: (proto_height, proto_width, num_protos)
873///
874/// # Errors
875/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
876pub(crate) fn impl_yolo_segdet_quant<
877    B: BBoxTypeTrait,
878    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
879    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
880>(
881    boxes: (ArrayView2<BOX>, Quantization),
882    protos: (ArrayView3<PROTO>, Quantization),
883    score_threshold: f32,
884    iou_threshold: f32,
885    nms: Option<Nms>,
886    output_boxes: &mut Vec<DetectBox>,
887    output_masks: &mut Vec<Segmentation>,
888) -> Result<(), crate::DecoderError>
889where
890    f32: AsPrimitive<BOX>,
891{
892    let (boxes, quant_boxes) = boxes;
893    let num_protos = protos.0.dim().2;
894
895    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
896    let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
897        (boxes_tensor, quant_boxes),
898        (scores_tensor, quant_boxes),
899        score_threshold,
900        iou_threshold,
901        nms,
902        output_boxes.capacity(),
903    );
904
905    impl_yolo_split_segdet_quant_process_masks::<_, _>(
906        boxes,
907        (mask_tensor, quant_boxes),
908        protos,
909        output_boxes,
910        output_masks,
911    )
912}
913
914/// Internal implementation of YOLO detection segmentation decoding for
915/// float tensors.
916///
917/// Expected shapes of inputs:
918/// - boxes: (4 + num_classes + num_protos, num_boxes)
919/// - protos: (proto_height, proto_width, num_protos)
920///
921/// # Panics
922/// Panics if shapes don't match the expected dimensions.
923pub(crate) fn impl_yolo_segdet_float<
924    B: BBoxTypeTrait,
925    BOX: Float + AsPrimitive<f32> + Send + Sync,
926    PROTO: Float + AsPrimitive<f32> + Send + Sync,
927>(
928    boxes: ArrayView2<BOX>,
929    protos: ArrayView3<PROTO>,
930    score_threshold: f32,
931    iou_threshold: f32,
932    nms: Option<Nms>,
933    output_boxes: &mut Vec<DetectBox>,
934    output_masks: &mut Vec<Segmentation>,
935) -> Result<(), crate::DecoderError>
936where
937    f32: AsPrimitive<BOX>,
938{
939    let num_protos = protos.dim().2;
940    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
941    let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
942        boxes_tensor,
943        scores_tensor,
944        score_threshold,
945        iou_threshold,
946        nms,
947        output_boxes.capacity(),
948    );
949    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
950}
951
952pub(crate) fn impl_yolo_segdet_get_boxes<
953    B: BBoxTypeTrait,
954    BOX: Float + AsPrimitive<f32> + Send + Sync,
955    SCORE: Float + AsPrimitive<f32> + Send + Sync,
956>(
957    boxes_tensor: ArrayView2<BOX>,
958    scores_tensor: ArrayView2<SCORE>,
959    score_threshold: f32,
960    iou_threshold: f32,
961    nms: Option<Nms>,
962    max_boxes: usize,
963) -> Vec<(DetectBox, usize)>
964where
965    f32: AsPrimitive<SCORE>,
966{
967    let mut boxes = postprocess_boxes_index_float::<B, _, _>(
968        score_threshold.as_(),
969        boxes_tensor,
970        scores_tensor,
971    );
972    truncate_to_top_k_by_score(&mut boxes);
973    let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
974    boxes.truncate(max_boxes);
975    boxes
976}
977
978pub(crate) fn impl_yolo_end_to_end_segdet_get_boxes<
979    B: BBoxTypeTrait,
980    BOX: Float + AsPrimitive<f32> + Send + Sync,
981    SCORE: Float + AsPrimitive<f32> + Send + Sync,
982    CLASS: AsPrimitive<f32> + Send + Sync,
983>(
984    boxes: ArrayView2<BOX>,
985    scores: ArrayView2<SCORE>,
986    classes: ArrayView1<CLASS>,
987    score_threshold: f32,
988    max_boxes: usize,
989) -> Vec<(DetectBox, usize)>
990where
991    f32: AsPrimitive<SCORE>,
992{
993    let mut boxes = postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes, scores);
994    boxes.truncate(max_boxes);
995    for (b, ind) in &mut boxes {
996        b.label = classes[*ind].as_().round() as usize;
997    }
998    boxes
999}
1000
1001pub(crate) fn impl_yolo_split_segdet_process_masks<
1002    MASK: Float + AsPrimitive<f32> + Send + Sync,
1003    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1004>(
1005    boxes: Vec<(DetectBox, usize)>,
1006    masks_tensor: ArrayView2<MASK>,
1007    protos_tensor: ArrayView3<PROTO>,
1008    output_boxes: &mut Vec<DetectBox>,
1009    output_masks: &mut Vec<Segmentation>,
1010) -> Result<(), crate::DecoderError> {
1011    let boxes = decode_segdet_f32(boxes, masks_tensor, protos_tensor)?;
1012    output_boxes.clear();
1013    output_masks.clear();
1014    for (b, m) in boxes.into_iter() {
1015        output_boxes.push(b);
1016        output_masks.push(Segmentation {
1017            xmin: b.bbox.xmin,
1018            ymin: b.bbox.ymin,
1019            xmax: b.bbox.xmax,
1020            ymax: b.bbox.ymax,
1021            segmentation: m,
1022        });
1023    }
1024    Ok(())
1025}
1026/// Expected input shapes:
1027/// - boxes_tensor: (num_boxes, 4)
1028/// - scores_tensor: (num_boxes, num_classes)
1029pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
1030    B: BBoxTypeTrait,
1031    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1032    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1033>(
1034    boxes: (ArrayView2<BOX>, Quantization),
1035    scores: (ArrayView2<SCORE>, Quantization),
1036    score_threshold: f32,
1037    iou_threshold: f32,
1038    nms: Option<Nms>,
1039    max_boxes: usize,
1040) -> Vec<(DetectBox, usize)>
1041where
1042    f32: AsPrimitive<SCORE>,
1043{
1044    let (boxes_tensor, quant_boxes) = boxes;
1045    let (scores_tensor, quant_scores) = scores;
1046
1047    let mut boxes = {
1048        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
1049        postprocess_boxes_index_quant::<B, _, _>(
1050            score_threshold,
1051            boxes_tensor,
1052            scores_tensor,
1053            quant_boxes,
1054        )
1055    };
1056    truncate_to_top_k_by_score_quant(&mut boxes);
1057    let mut boxes = dispatch_nms_extra_int(nms, iou_threshold, boxes);
1058    boxes.truncate(max_boxes);
1059    boxes
1060        .into_iter()
1061        .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
1062        .collect()
1063}
1064
1065pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
1066    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1067    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1068>(
1069    boxes: Vec<(DetectBox, usize)>,
1070    mask_coeff: (ArrayView2<MASK>, Quantization),
1071    protos: (ArrayView3<PROTO>, Quantization),
1072    output_boxes: &mut Vec<DetectBox>,
1073    output_masks: &mut Vec<Segmentation>,
1074) -> Result<(), crate::DecoderError> {
1075    let (masks, quant_masks) = mask_coeff;
1076    let (protos, quant_protos) = protos;
1077
1078    let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
1079    output_boxes.clear();
1080    output_masks.clear();
1081    for (b, m) in boxes.into_iter() {
1082        output_boxes.push(b);
1083        output_masks.push(Segmentation {
1084            xmin: b.bbox.xmin,
1085            ymin: b.bbox.ymin,
1086            xmax: b.bbox.xmax,
1087            ymax: b.bbox.ymax,
1088            segmentation: m,
1089        });
1090    }
1091    Ok(())
1092}
1093
1094#[allow(clippy::too_many_arguments)]
1095/// Internal implementation of YOLO split detection segmentation decoding for
1096/// quantized tensors.
1097///
1098/// Expected shapes of inputs:
1099/// - boxes_tensor: (4, num_boxes)
1100/// - scores_tensor: (num_classes, num_boxes)
1101/// - mask_tensor: (num_protos, num_boxes)
1102/// - protos: (proto_height, proto_width, num_protos)
1103///
1104/// # Errors
1105/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
1106pub(crate) fn impl_yolo_split_segdet_quant<
1107    B: BBoxTypeTrait,
1108    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1109    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1110    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1111    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1112>(
1113    boxes: (ArrayView2<BOX>, Quantization),
1114    scores: (ArrayView2<SCORE>, Quantization),
1115    mask_coeff: (ArrayView2<MASK>, Quantization),
1116    protos: (ArrayView3<PROTO>, Quantization),
1117    score_threshold: f32,
1118    iou_threshold: f32,
1119    nms: Option<Nms>,
1120    output_boxes: &mut Vec<DetectBox>,
1121    output_masks: &mut Vec<Segmentation>,
1122) -> Result<(), crate::DecoderError>
1123where
1124    f32: AsPrimitive<SCORE>,
1125{
1126    let (boxes_, scores_, mask_coeff_) =
1127        postprocess_yolo_split_segdet(boxes.0, scores.0, mask_coeff.0);
1128    let boxes = (boxes_, boxes.1);
1129    let scores = (scores_, scores.1);
1130    let mask_coeff = (mask_coeff_, mask_coeff.1);
1131
1132    let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1133        boxes,
1134        scores,
1135        score_threshold,
1136        iou_threshold,
1137        nms,
1138        output_boxes.capacity(),
1139    );
1140
1141    impl_yolo_split_segdet_quant_process_masks(
1142        boxes,
1143        mask_coeff,
1144        protos,
1145        output_boxes,
1146        output_masks,
1147    )
1148}
1149
1150#[allow(clippy::too_many_arguments)]
1151/// Internal implementation of YOLO split detection segmentation decoding for
1152/// float tensors.
1153///
1154/// Expected shapes of inputs:
1155/// - boxes_tensor: (4, num_boxes)
1156/// - scores_tensor: (num_classes, num_boxes)
1157/// - mask_tensor: (num_protos, num_boxes)
1158/// - protos: (proto_height, proto_width, num_protos)
1159///
1160/// # Errors
1161/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
1162pub(crate) fn impl_yolo_split_segdet_float<
1163    B: BBoxTypeTrait,
1164    BOX: Float + AsPrimitive<f32> + Send + Sync,
1165    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1166    MASK: Float + AsPrimitive<f32> + Send + Sync,
1167    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1168>(
1169    boxes_tensor: ArrayView2<BOX>,
1170    scores_tensor: ArrayView2<SCORE>,
1171    mask_tensor: ArrayView2<MASK>,
1172    protos: ArrayView3<PROTO>,
1173    score_threshold: f32,
1174    iou_threshold: f32,
1175    nms: Option<Nms>,
1176    output_boxes: &mut Vec<DetectBox>,
1177    output_masks: &mut Vec<Segmentation>,
1178) -> Result<(), crate::DecoderError>
1179where
1180    f32: AsPrimitive<SCORE>,
1181{
1182    let (boxes_tensor, scores_tensor, mask_tensor) =
1183        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1184
1185    let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1186        boxes_tensor,
1187        scores_tensor,
1188        score_threshold,
1189        iou_threshold,
1190        nms,
1191        output_boxes.capacity(),
1192    );
1193    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1194}
1195
1196// ---------------------------------------------------------------------------
1197// Proto-extraction variants: return ProtoData instead of materialized masks
1198// ---------------------------------------------------------------------------
1199
1200/// Proto-extraction variant of `impl_yolo_segdet_quant`.
1201/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1202pub fn impl_yolo_segdet_quant_proto<
1203    B: BBoxTypeTrait,
1204    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1205    PROTO: PrimInt
1206        + AsPrimitive<i64>
1207        + AsPrimitive<i128>
1208        + AsPrimitive<f32>
1209        + AsPrimitive<i8>
1210        + Send
1211        + Sync,
1212>(
1213    boxes: (ArrayView2<BOX>, Quantization),
1214    protos: (ArrayView3<PROTO>, Quantization),
1215    score_threshold: f32,
1216    iou_threshold: f32,
1217    nms: Option<Nms>,
1218    output_boxes: &mut Vec<DetectBox>,
1219) -> ProtoData
1220where
1221    f32: AsPrimitive<BOX>,
1222{
1223    let (boxes_arr, quant_boxes) = boxes;
1224    let (protos_arr, quant_protos) = protos;
1225    let num_protos = protos_arr.dim().2;
1226
1227    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1228
1229    let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1230        (boxes_tensor, quant_boxes),
1231        (scores_tensor, quant_boxes),
1232        score_threshold,
1233        iou_threshold,
1234        nms,
1235        output_boxes.capacity(),
1236    );
1237
1238    extract_proto_data_quant(
1239        det_indices,
1240        mask_tensor,
1241        quant_boxes,
1242        protos_arr,
1243        quant_protos,
1244        output_boxes,
1245    )
1246}
1247
1248/// Proto-extraction variant of `impl_yolo_segdet_float`.
1249/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1250pub(crate) fn impl_yolo_segdet_float_proto<
1251    B: BBoxTypeTrait,
1252    BOX: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1253    PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1254>(
1255    boxes: ArrayView2<BOX>,
1256    protos: ArrayView3<PROTO>,
1257    score_threshold: f32,
1258    iou_threshold: f32,
1259    nms: Option<Nms>,
1260    output_boxes: &mut Vec<DetectBox>,
1261) -> ProtoData
1262where
1263    f32: AsPrimitive<BOX>,
1264{
1265    let num_protos = protos.dim().2;
1266    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1267
1268    let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1269        boxes_tensor,
1270        scores_tensor,
1271        score_threshold,
1272        iou_threshold,
1273        nms,
1274        output_boxes.capacity(),
1275    );
1276
1277    extract_proto_data_float(boxes, mask_tensor, protos, output_boxes)
1278}
1279
1280/// Proto-extraction variant of `impl_yolo_split_segdet_float`.
1281/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1282#[allow(clippy::too_many_arguments)]
1283pub(crate) fn impl_yolo_split_segdet_float_proto<
1284    B: BBoxTypeTrait,
1285    BOX: Float + AsPrimitive<f32> + Send + Sync,
1286    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1287    MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1288    PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1289>(
1290    boxes_tensor: ArrayView2<BOX>,
1291    scores_tensor: ArrayView2<SCORE>,
1292    mask_tensor: ArrayView2<MASK>,
1293    protos: ArrayView3<PROTO>,
1294    score_threshold: f32,
1295    iou_threshold: f32,
1296    nms: Option<Nms>,
1297    output_boxes: &mut Vec<DetectBox>,
1298) -> ProtoData
1299where
1300    f32: AsPrimitive<SCORE>,
1301{
1302    let (boxes_tensor, scores_tensor, mask_tensor) =
1303        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1304    let det_indices = impl_yolo_segdet_get_boxes::<B, _, _>(
1305        boxes_tensor,
1306        scores_tensor,
1307        score_threshold,
1308        iou_threshold,
1309        nms,
1310        output_boxes.capacity(),
1311    );
1312
1313    extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1314}
1315
1316/// Proto-extraction variant of `decode_yolo_end_to_end_segdet_float`.
1317pub fn decode_yolo_end_to_end_segdet_float_proto<T>(
1318    output: ArrayView2<T>,
1319    protos: ArrayView3<T>,
1320    score_threshold: f32,
1321    output_boxes: &mut Vec<DetectBox>,
1322) -> Result<ProtoData, crate::DecoderError>
1323where
1324    T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1325    f32: AsPrimitive<T>,
1326{
1327    let (boxes, scores, classes, mask_coeff) =
1328        postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
1329    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1330        boxes,
1331        scores,
1332        classes,
1333        score_threshold,
1334        output_boxes.capacity(),
1335    );
1336
1337    Ok(extract_proto_data_float(
1338        boxes,
1339        mask_coeff,
1340        protos,
1341        output_boxes,
1342    ))
1343}
1344
1345/// Proto-extraction variant of `decode_yolo_split_end_to_end_segdet_float`.
1346#[allow(clippy::too_many_arguments)]
1347pub fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1348    boxes: ArrayView2<T>,
1349    scores: ArrayView2<T>,
1350    classes: ArrayView2<T>,
1351    mask_coeff: ArrayView2<T>,
1352    protos: ArrayView3<T>,
1353    score_threshold: f32,
1354    output_boxes: &mut Vec<DetectBox>,
1355) -> Result<ProtoData, crate::DecoderError>
1356where
1357    T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1358    f32: AsPrimitive<T>,
1359{
1360    let (boxes, scores, classes, mask_coeff) =
1361        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
1362    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1363        boxes,
1364        scores,
1365        classes,
1366        score_threshold,
1367        output_boxes.capacity(),
1368    );
1369
1370    Ok(extract_proto_data_float(
1371        boxes,
1372        mask_coeff,
1373        protos,
1374        output_boxes,
1375    ))
1376}
1377
1378/// Helper: extract ProtoData from float mask coefficients + protos.
1379///
1380/// Builds [`ProtoData`] with both `protos` and `mask_coefficients` as
1381/// [`edgefirst_tensor::TensorDyn`]. Preserves the native element type for
1382/// `f16` and `f32`; narrows `f64` to `f32` (there is no native f64 kernel
1383/// path). `mask_coefficients` shape is `[num_detections, num_protos]`.
1384pub(super) fn extract_proto_data_float<
1385    MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1386    PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1387>(
1388    det_indices: Vec<(DetectBox, usize)>,
1389    mask_tensor: ArrayView2<MASK>,
1390    protos: ArrayView3<PROTO>,
1391    output_boxes: &mut Vec<DetectBox>,
1392) -> ProtoData {
1393    let num_protos = mask_tensor.ncols();
1394    let n = det_indices.len();
1395
1396    // Per-detection coefficients packed row-major into a contiguous buffer,
1397    // preserving the source dtype. Shape: [N, num_protos] — N=0 is permitted
1398    // (tracker path emits no detections this frame) since Mem-backed tensors
1399    // accept zero-element shapes as "empty collection" sentinels.
1400    let mut coeff_rows: Vec<MASK> = Vec::with_capacity(n * num_protos);
1401    output_boxes.clear();
1402    for (det, idx) in det_indices {
1403        output_boxes.push(det);
1404        let row = mask_tensor.row(idx);
1405        coeff_rows.extend(row.iter().copied());
1406    }
1407
1408    let mask_coefficients = MASK::slice_into_tensor_dyn(&coeff_rows, &[n, num_protos])
1409        .expect("allocating mask_coefficients TensorDyn");
1410    let protos_tensor =
1411        PROTO::arrayview3_into_tensor_dyn(protos).expect("allocating protos TensorDyn");
1412
1413    ProtoData {
1414        mask_coefficients,
1415        protos: protos_tensor,
1416    }
1417}
1418
1419/// Helper: extract ProtoData from quantized mask coefficients + protos.
1420///
1421/// Dequantizes mask coefficients to f32 at extraction (one-time cost on a
1422/// `num_detections * num_protos` slice) and keeps protos in raw i8,
1423/// attaching the dequantization params as
1424/// [`edgefirst_tensor::Quantization::per_tensor`] metadata on the proto
1425/// tensor. The GPU shader / CPU kernel reads `protos.quantization()` and
1426/// dequantizes per-texel.
1427pub(crate) fn extract_proto_data_quant<
1428    MASK: PrimInt + AsPrimitive<f32> + Send + Sync,
1429    PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
1430>(
1431    det_indices: Vec<(DetectBox, usize)>,
1432    mask_tensor: ArrayView2<MASK>,
1433    quant_masks: Quantization,
1434    protos: ArrayView3<PROTO>,
1435    quant_protos: Quantization,
1436    output_boxes: &mut Vec<DetectBox>,
1437) -> ProtoData {
1438    use edgefirst_tensor::{Tensor, TensorDyn, TensorMapTrait, TensorMemory, TensorTrait};
1439
1440    let num_protos = mask_tensor.ncols();
1441    let n = det_indices.len();
1442    let mut coeff_f32 = Vec::<f32>::with_capacity(n * num_protos);
1443    output_boxes.clear();
1444    for (det, idx) in det_indices {
1445        output_boxes.push(det);
1446        let row = mask_tensor.row(idx);
1447        coeff_f32.extend(
1448            row.iter()
1449                .map(|v| (v.as_() - quant_masks.zero_point as f32) * quant_masks.scale),
1450        );
1451    }
1452
1453    // Shape `[n, num_protos]` with n=0 is permitted (tracker path emits no
1454    // fresh detections this frame) via the Mem-backed zero-size allowance.
1455    let coeff_tensor = Tensor::<f32>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
1456        .expect("allocating mask_coefficients tensor");
1457    if n > 0 {
1458        let mut m = coeff_tensor
1459            .map()
1460            .expect("mapping mask_coefficients tensor");
1461        m.as_mut_slice().copy_from_slice(&coeff_f32);
1462    }
1463    let mask_coefficients = TensorDyn::F32(coeff_tensor);
1464
1465    // Keep protos in raw i8 — consumers dequantize via protos.quantization().
1466    // When PROTO is already i8, memcpy via to_owned(); else per-element as_().
1467    let (h, w, k) = protos.dim();
1468    let protos_tensor = Tensor::<i8>::new(&[h, w, k], Some(TensorMemory::Mem), None)
1469        .expect("allocating protos tensor");
1470    {
1471        let mut m = protos_tensor.map().expect("mapping protos tensor");
1472        let dst = m.as_mut_slice();
1473        if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1474            // SAFETY: PROTO == i8 checked via TypeId; cast slice view is
1475            // size/alignment-compatible by construction.
1476            let src: &[i8] =
1477                unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1478            if protos.is_standard_layout() {
1479                dst.copy_from_slice(src);
1480            } else {
1481                for (d, s) in dst.iter_mut().zip(protos.iter()) {
1482                    let v_i8: i8 = s.as_();
1483                    *d = v_i8;
1484                }
1485            }
1486        } else {
1487            for (d, s) in dst.iter_mut().zip(protos.iter()) {
1488                let v_i8: i8 = s.as_();
1489                *d = v_i8;
1490            }
1491        }
1492    }
1493    let tensor_quant =
1494        edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1495    let protos_tensor = protos_tensor
1496        .with_quantization(tensor_quant)
1497        .expect("per-tensor quantization on new Tensor<i8>");
1498
1499    ProtoData {
1500        mask_coefficients,
1501        protos: TensorDyn::I8(protos_tensor),
1502    }
1503}
1504
/// Per-float-dtype construction of a [`TensorDyn`] from a flat slice / 3-D
/// `ArrayView`. Replaces the old `IntoProtoTensor` trait. Each implementor
/// either passes its element type straight to `Tensor::from_slice` /
/// `Tensor::from_arrayview3`, or narrows `f64` to `f32` (no native f64 kernel
/// path exists).
pub trait FloatProtoElem: Copy + 'static {
    /// Builds a tensor of the given `shape` from a flat slice of values.
    ///
    /// # Errors
    /// Returns the underlying tensor-construction error, if any.
    fn slice_into_tensor_dyn(
        values: &[Self],
        shape: &[usize],
    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;

    /// Builds a tensor from a 3-D array view.
    ///
    /// # Errors
    /// Returns the underlying tensor-construction error, if any.
    fn arrayview3_into_tensor_dyn(
        view: ArrayView3<'_, Self>,
    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;
}
1520
1521impl FloatProtoElem for f32 {
1522    fn slice_into_tensor_dyn(
1523        values: &[f32],
1524        shape: &[usize],
1525    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1526        edgefirst_tensor::Tensor::<f32>::from_slice(values, shape)
1527            .map(edgefirst_tensor::TensorDyn::F32)
1528    }
1529    fn arrayview3_into_tensor_dyn(
1530        view: ArrayView3<'_, f32>,
1531    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1532        edgefirst_tensor::Tensor::<f32>::from_arrayview3(view).map(edgefirst_tensor::TensorDyn::F32)
1533    }
1534}
1535
1536impl FloatProtoElem for half::f16 {
1537    fn slice_into_tensor_dyn(
1538        values: &[half::f16],
1539        shape: &[usize],
1540    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1541        edgefirst_tensor::Tensor::<half::f16>::from_slice(values, shape)
1542            .map(edgefirst_tensor::TensorDyn::F16)
1543    }
1544    fn arrayview3_into_tensor_dyn(
1545        view: ArrayView3<'_, half::f16>,
1546    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1547        edgefirst_tensor::Tensor::<half::f16>::from_arrayview3(view)
1548            .map(edgefirst_tensor::TensorDyn::F16)
1549    }
1550}
1551
1552impl FloatProtoElem for f64 {
1553    fn slice_into_tensor_dyn(
1554        values: &[f64],
1555        shape: &[usize],
1556    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1557        // Narrow to f32 — no native f64 kernel path.
1558        let narrowed: Vec<f32> = values.iter().map(|&v| v as f32).collect();
1559        edgefirst_tensor::Tensor::<f32>::from_slice(&narrowed, shape)
1560            .map(edgefirst_tensor::TensorDyn::F32)
1561    }
1562    fn arrayview3_into_tensor_dyn(
1563        view: ArrayView3<'_, f64>,
1564    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1565        let narrowed: ndarray::Array3<f32> = view.mapv(|v| v as f32);
1566        edgefirst_tensor::Tensor::<f32>::from_arrayview3(narrowed.view())
1567            .map(edgefirst_tensor::TensorDyn::F32)
1568    }
1569}
1570
1571fn postprocess_yolo<'a, T>(
1572    output: &'a ArrayView2<'_, T>,
1573) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1574    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1575    let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1576    (boxes_tensor, scores_tensor)
1577}
1578
1579pub(crate) fn postprocess_yolo_seg<'a, T>(
1580    output: &'a ArrayView2<'_, T>,
1581    num_protos: usize,
1582) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1583    assert!(
1584        output.shape()[0] > num_protos + 4,
1585        "Output shape is too short: {} <= {} + 4",
1586        output.shape()[0],
1587        num_protos
1588    );
1589    let num_classes = output.shape()[0] - 4 - num_protos;
1590    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1591    let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1592    let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1593    (boxes_tensor, scores_tensor, mask_tensor)
1594}
1595
1596pub(crate) fn postprocess_yolo_split_segdet<'a, 'b, 'c, BOX, SCORE, MASK>(
1597    boxes_tensor: ArrayView2<'a, BOX>,
1598    scores_tensor: ArrayView2<'b, SCORE>,
1599    mask_tensor: ArrayView2<'c, MASK>,
1600) -> (
1601    ArrayView2<'a, BOX>,
1602    ArrayView2<'b, SCORE>,
1603    ArrayView2<'c, MASK>,
1604) {
1605    let boxes_tensor = boxes_tensor.reversed_axes();
1606    let scores_tensor = scores_tensor.reversed_axes();
1607    let mask_tensor = mask_tensor.reversed_axes();
1608    (boxes_tensor, scores_tensor, mask_tensor)
1609}
1610
1611fn decode_segdet_f32<
1612    MASK: Float + AsPrimitive<f32> + Send + Sync,
1613    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1614>(
1615    boxes: Vec<(DetectBox, usize)>,
1616    masks: ArrayView2<MASK>,
1617    protos: ArrayView3<PROTO>,
1618) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1619    if boxes.is_empty() {
1620        return Ok(Vec::new());
1621    }
1622    if masks.shape()[1] != protos.shape()[2] {
1623        return Err(crate::DecoderError::InvalidShape(format!(
1624            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1625            masks.shape()[1],
1626            protos.shape()[2],
1627        )));
1628    }
1629    boxes
1630        .into_par_iter()
1631        .map(|mut b| {
1632            let ind = b.1;
1633            let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1634            b.0.bbox = roi;
1635            Ok((b.0, make_segmentation(masks.row(ind), protos.view())))
1636        })
1637        .collect()
1638}
1639
1640pub(crate) fn decode_segdet_quant<
1641    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1642    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1643>(
1644    boxes: Vec<(DetectBox, usize)>,
1645    masks: ArrayView2<MASK>,
1646    protos: ArrayView3<PROTO>,
1647    quant_masks: Quantization,
1648    quant_protos: Quantization,
1649) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1650    if boxes.is_empty() {
1651        return Ok(Vec::new());
1652    }
1653    if masks.shape()[1] != protos.shape()[2] {
1654        return Err(crate::DecoderError::InvalidShape(format!(
1655            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1656            masks.shape()[1],
1657            protos.shape()[2],
1658        )));
1659    }
1660
1661    let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; // 32 protos is 2^5
1662    boxes
1663        .into_iter()
1664        .map(|mut b| {
1665            let i = b.1;
1666            let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1667            b.0.bbox = roi;
1668            let seg = match total_bits {
1669                0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1670                    masks.row(i),
1671                    protos.view(),
1672                    quant_masks,
1673                    quant_protos,
1674                ),
1675                65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1676                    masks.row(i),
1677                    protos.view(),
1678                    quant_masks,
1679                    quant_protos,
1680                ),
1681                _ => {
1682                    return Err(crate::DecoderError::NotSupported(format!(
1683                        "Unsupported bit width ({total_bits}) for segmentation computation"
1684                    )));
1685                }
1686            };
1687            Ok((b.0, seg))
1688        })
1689        .collect()
1690}
1691
1692fn protobox<'a, T>(
1693    protos: &'a ArrayView3<T>,
1694    roi: &BoundingBox,
1695) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1696    let width = protos.dim().1 as f32;
1697    let height = protos.dim().0 as f32;
1698
1699    // Detect un-normalized bounding boxes (pixel-space coordinates).
1700    // protobox expects normalized coordinates in [0, 1]. ONNX models output
1701    // pixel-space boxes (e.g. 0-640) which must be normalized before calling
1702    // decode(). Without this check, pixel-space coordinates silently clamp to
1703    // the proto boundary, producing empty (0, 0, C) masks for every detection.
1704    //
1705    // The limit is set to 2.0 (not 1.01) because YOLO models legitimately
1706    // predict coordinates slightly > 1.0 for objects near frame edges.
1707    // Any value > 2.0 is clearly pixel-space (even the smallest practical
1708    // model input of 32×32 would produce coordinates >> 2.0).
1709    const NORM_LIMIT: f32 = 2.0;
1710    if roi.xmin > NORM_LIMIT
1711        || roi.ymin > NORM_LIMIT
1712        || roi.xmax > NORM_LIMIT
1713        || roi.ymax > NORM_LIMIT
1714    {
1715        return Err(crate::DecoderError::InvalidShape(format!(
1716            "Bounding box coordinates appear un-normalized (pixel-space). \
1717             Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
1718             ONNX models output pixel-space boxes — normalize them by dividing by \
1719             the input dimensions before calling decode().",
1720            roi.xmin, roi.ymin, roi.xmax, roi.ymax,
1721        )));
1722    }
1723
1724    let roi = [
1725        (roi.xmin * width).clamp(0.0, width) as usize,
1726        (roi.ymin * height).clamp(0.0, height) as usize,
1727        (roi.xmax * width).clamp(0.0, width).ceil() as usize,
1728        (roi.ymax * height).clamp(0.0, height).ceil() as usize,
1729    ];
1730
1731    let roi_norm = [
1732        roi[0] as f32 / width,
1733        roi[1] as f32 / height,
1734        roi[2] as f32 / width,
1735        roi[3] as f32 / height,
1736    ]
1737    .into();
1738
1739    let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
1740
1741    Ok((cropped, roi_norm))
1742}
1743
1744/// Compute a single instance segmentation mask from mask coefficients and
1745/// proto maps (float path).
1746///
1747/// Computes `sigmoid(coefficients · protos)` and maps to `[0, 255]`.
1748/// Returns an `(H, W, 1)` u8 array.
1749fn make_segmentation<
1750    MASK: Float + AsPrimitive<f32> + Send + Sync,
1751    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1752>(
1753    mask: ArrayView1<MASK>,
1754    protos: ArrayView3<PROTO>,
1755) -> Array3<u8> {
1756    let shape = protos.shape();
1757
1758    // Safe to unwrap since the shapes will always be compatible
1759    let mask = mask.to_shape((1, mask.len())).unwrap();
1760    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1761    let protos = protos.reversed_axes();
1762    let mask = mask.map(|x| x.as_());
1763    let protos = protos.map(|x| x.as_());
1764
1765    // Safe to unwrap since the shapes will always be compatible
1766    let mask = mask
1767        .dot(&protos)
1768        .into_shape_with_order((shape[0], shape[1], 1))
1769        .unwrap();
1770
1771    mask.map(|x| {
1772        let sigmoid = 1.0 / (1.0 + (-*x).exp());
1773        (sigmoid * 255.0).round() as u8
1774    })
1775}
1776
1777/// Compute a single instance segmentation mask from quantized mask
1778/// coefficients and proto maps.
1779///
1780/// Dequantizes both inputs (subtracting zero-points), computes the dot
1781/// product, applies sigmoid, and maps to `[0, 255]`.
1782/// Returns an `(H, W, 1)` u8 array.
1783fn make_segmentation_quant<
1784    MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
1785    PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
1786    DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
1787>(
1788    mask: ArrayView1<MASK>,
1789    protos: ArrayView3<PROTO>,
1790    quant_masks: Quantization,
1791    quant_protos: Quantization,
1792) -> Array3<u8>
1793where
1794    i32: AsPrimitive<DEST>,
1795    f32: AsPrimitive<DEST>,
1796{
1797    let shape = protos.shape();
1798
1799    // Safe to unwrap since the shapes will always be compatible
1800    let mask = mask.to_shape((1, mask.len())).unwrap();
1801
1802    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1803    let protos = protos.reversed_axes();
1804
1805    let zp = quant_masks.zero_point.as_();
1806
1807    let mask = mask.mapv(|x| x.as_() - zp);
1808
1809    let zp = quant_protos.zero_point.as_();
1810    let protos = protos.mapv(|x| x.as_() - zp);
1811
1812    // Safe to unwrap since the shapes will always be compatible
1813    let segmentation = mask
1814        .dot(&protos)
1815        .into_shape_with_order((shape[0], shape[1], 1))
1816        .unwrap();
1817
1818    let combined_scale = quant_masks.scale * quant_protos.scale;
1819    segmentation.map(|x| {
1820        let val: f32 = (*x).as_() * combined_scale;
1821        let sigmoid = 1.0 / (1.0 + (-val).exp());
1822        (sigmoid * 255.0).round() as u8
1823    })
1824}
1825
1826/// Converts Yolo Instance Segmentation into a 2D mask.
1827///
1828/// The input segmentation is expected to have shape (H, W, 1).
1829///
1830/// The output mask will have shape (H, W), with values 0 or 1 based on the
1831/// threshold.
1832///
1833/// # Errors
1834///
1835/// Returns `DecoderError::InvalidShape` if the input segmentation does not
1836/// have shape (H, W, 1).
1837pub fn yolo_segmentation_to_mask(
1838    segmentation: ArrayView3<u8>,
1839    threshold: u8,
1840) -> Result<Array2<u8>, crate::DecoderError> {
1841    if segmentation.shape()[2] != 1 {
1842        return Err(crate::DecoderError::InvalidShape(format!(
1843            "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
1844            segmentation.shape()[2]
1845        )));
1846    }
1847    Ok(segmentation
1848        .slice(s![.., .., 0])
1849        .map(|x| if *x >= threshold { 1 } else { 0 }))
1850}
1851
1852#[cfg(test)]
1853#[cfg_attr(coverage_nightly, coverage(off))]
1854mod tests {
1855    use super::*;
1856    use ndarray::Array2;
1857
1858    // ========================================================================
1859    // Tests for decode_yolo_end_to_end_det_float
1860    // ========================================================================
1861
#[test]
fn test_end_to_end_det_basic_filtering() {
    // Feature-major (6, N) tensor: each row carries one attribute
    // ([x1, y1, x2, y2, conf, class]) for all three candidates. Only
    // candidate 0 (conf 0.9) clears the 0.5 threshold.
    let rows: [[f32; 3]; 6] = [
        [0.1, 0.2, 0.3], // x1
        [0.1, 0.2, 0.3], // y1
        [0.5, 0.6, 0.7], // x2
        [0.5, 0.6, 0.7], // y2
        [0.9, 0.1, 0.2], // confidence
        [0.0, 1.0, 2.0], // class index
    ];
    let output = Array2::from_shape_vec((6, 3), rows.concat()).unwrap();

    let mut boxes = Vec::with_capacity(10);
    decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

    // Exactly one survivor, carrying candidate 0's attributes.
    assert_eq!(boxes.len(), 1);
    let det = &boxes[0];
    assert_eq!(det.label, 0);
    assert!((det.score - 0.9).abs() < 0.01);
    assert!((det.bbox.xmin - 0.1).abs() < 0.01);
    assert!((det.bbox.ymin - 0.1).abs() < 0.01);
    assert!((det.bbox.xmax - 0.5).abs() < 0.01);
    assert!((det.bbox.ymax - 0.5).abs() < 0.01);
}
1890
#[test]
fn test_end_to_end_det_all_pass_threshold() {
    // Two candidates whose confidences (0.8, 0.7) both exceed 0.5.
    let rows: [[f32; 2]; 6] = [
        [10.0, 20.0], // x1
        [10.0, 20.0], // y1
        [50.0, 60.0], // x2
        [50.0, 60.0], // y2
        [0.8, 0.7],   // confidence
        [1.0, 2.0],   // class index
    ];
    let output = Array2::from_shape_vec((6, 2), rows.concat()).unwrap();

    let mut boxes = Vec::with_capacity(10);
    decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

    // Both survive, in input order.
    assert_eq!(boxes.len(), 2);
    let first = &boxes[0];
    let second = &boxes[1];
    assert_eq!(first.label, 1);
    assert_eq!(second.label, 2);
}
1911
#[test]
fn test_end_to_end_det_none_pass_threshold() {
    // Two candidates whose confidences (0.1, 0.2) both fall short of 0.5.
    let rows: [[f32; 2]; 6] = [
        [10.0, 20.0], // x1
        [10.0, 20.0], // y1
        [50.0, 60.0], // x2
        [50.0, 60.0], // y2
        [0.1, 0.2],   // confidence
        [1.0, 2.0],   // class index
    ];
    let output = Array2::from_shape_vec((6, 2), rows.concat()).unwrap();

    let mut boxes = Vec::with_capacity(10);
    decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

    assert!(boxes.is_empty());
}
1930
#[test]
fn test_end_to_end_det_capacity_limit() {
    // Five candidates all pass the threshold, but the output vector is
    // created with room for only two, so only two must be returned.
    let rows: [[f32; 5]; 6] = [
        [0.1, 0.2, 0.3, 0.4, 0.5], // x1
        [0.1, 0.2, 0.3, 0.4, 0.5], // y1
        [0.5, 0.6, 0.7, 0.8, 0.9], // x2
        [0.5, 0.6, 0.7, 0.8, 0.9], // y2
        [0.9, 0.9, 0.9, 0.9, 0.9], // confidence
        [0.0, 1.0, 2.0, 3.0, 4.0], // class index
    ];
    let output = Array2::from_shape_vec((6, 5), rows.concat()).unwrap();

    let capacity = 2;
    let mut boxes = Vec::with_capacity(capacity);
    decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

    assert_eq!(boxes.len(), 2);
}
1949
#[test]
fn test_end_to_end_det_empty_output() {
    // A tensor with six feature rows but no candidate columns.
    let no_candidates = Array2::<f32>::zeros((6, 0));

    let mut boxes = Vec::with_capacity(10);
    decode_yolo_end_to_end_det_float(no_candidates.view(), 0.5, &mut boxes).unwrap();

    assert!(boxes.is_empty());
}
1960
#[test]
fn test_end_to_end_det_pixel_coordinates() {
    // A single candidate expressed in pixel (non-normalized) coordinates;
    // the decoder is expected to pass the values through untouched.
    let column: [f32; 6] = [
        100.0, // x1
        200.0, // y1
        300.0, // x2
        400.0, // y2
        0.95,  // confidence
        5.0,   // class index
    ];
    let output = Array2::from_shape_vec((6, 1), column.to_vec()).unwrap();

    let mut boxes = Vec::with_capacity(10);
    decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

    assert_eq!(boxes.len(), 1);
    let det = &boxes[0];
    assert_eq!(det.label, 5);
    assert!((det.bbox.xmin - 100.0).abs() < 0.01);
    assert!((det.bbox.ymin - 200.0).abs() < 0.01);
    assert!((det.bbox.xmax - 300.0).abs() < 0.01);
    assert!((det.bbox.ymax - 400.0).abs() < 0.01);
}
1984
#[test]
fn test_end_to_end_det_invalid_shape() {
    // Five feature rows is one short of the required six.
    let output = Array2::<f32>::zeros((5, 3));
    let mut boxes = Vec::with_capacity(10);

    let rejected = match decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes) {
        Err(crate::DecoderError::InvalidShape(msg)) => msg.contains("at least 6 rows"),
        _ => false,
    };

    assert!(rejected);
}
1999
2000    // ========================================================================
2001    // Tests for decode_yolo_end_to_end_segdet_float
2002    // ========================================================================
2003
#[test]
fn test_end_to_end_segdet_basic() {
    // Synthetic segdet tensor of shape (6 + num_protos, N), rows being
    // [x1, y1, x2, y2, conf, class, mask_coeff_0..31].
    let num_protos = 32;
    let num_detections = 2;
    let num_features = 6 + num_protos;

    // The six leading rows; every mask-coefficient row stays at 0.1.
    let header: [[f32; 2]; 6] = [
        [0.1, 0.5], // x1
        [0.1, 0.5], // y1
        [0.4, 0.9], // x2
        [0.4, 0.9], // y2
        [0.9, 0.3], // confidence: candidate 0 passes, candidate 1 fails
        [1.0, 2.0], // class index
    ];
    let mut output = Array2::<f32>::from_elem((num_features, num_detections), 0.1);
    for (r, row) in header.iter().enumerate() {
        for (c, &value) in row.iter().enumerate() {
            output[[r, c]] = value;
        }
    }

    // Proto tensor laid out as (proto_height, proto_width, num_protos).
    let protos = Array3::<f32>::zeros((16, 16, num_protos));

    let mut boxes = Vec::with_capacity(10);
    let mut masks = Vec::with_capacity(10);
    decode_yolo_end_to_end_segdet_float(
        output.view(),
        protos.view(),
        0.5,
        &mut boxes,
        &mut masks,
    )
    .unwrap();

    // Candidate 0 is the sole survivor.
    assert_eq!(boxes.len(), 1);
    assert_eq!(masks.len(), 1);
    let det = &boxes[0];
    assert_eq!(det.label, 1);
    assert!((det.score - 0.9).abs() < 0.01);
}
2055
#[test]
fn test_end_to_end_segdet_mask_coordinates() {
    // The mask returned for a detection must report the same extent as
    // the detection's bounding box.
    let num_protos = 32;
    let num_features = 6 + num_protos;

    // One candidate: [x1, y1, x2, y2, conf, class], then zeroed mask
    // coefficients.
    let mut data: Vec<f32> = vec![0.2, 0.2, 0.8, 0.8, 0.95, 3.0];
    data.resize(num_features, 0.0);

    let output = Array2::from_shape_vec((num_features, 1), data).unwrap();
    let protos = Array3::<f32>::zeros((16, 16, num_protos));

    let mut boxes = Vec::with_capacity(10);
    let mut masks = Vec::with_capacity(10);
    decode_yolo_end_to_end_segdet_float(
        output.view(),
        protos.view(),
        0.5,
        &mut boxes,
        &mut masks,
    )
    .unwrap();

    assert_eq!(boxes.len(), 1);
    assert_eq!(masks.len(), 1);

    let mask = &masks[0];
    let bbox = &boxes[0].bbox;
    assert!((mask.xmin - bbox.xmin).abs() < 0.01);
    assert!((mask.ymin - bbox.ymin).abs() < 0.01);
    assert!((mask.xmax - bbox.xmax).abs() < 0.01);
    assert!((mask.ymax - bbox.ymax).abs() < 0.01);
}
2093
#[test]
fn test_end_to_end_segdet_empty_output() {
    // A detection tensor with zero candidate columns yields nothing.
    let num_protos = 32;
    let no_candidates = Array2::<f32>::zeros((6 + num_protos, 0));
    let protos = Array3::<f32>::zeros((16, 16, num_protos));

    let mut boxes = Vec::with_capacity(10);
    let mut masks = Vec::with_capacity(10);
    decode_yolo_end_to_end_segdet_float(
        no_candidates.view(),
        protos.view(),
        0.5,
        &mut boxes,
        &mut masks,
    )
    .unwrap();

    assert!(boxes.is_empty());
    assert!(masks.is_empty());
}
2114
#[test]
fn test_end_to_end_segdet_capacity_limit() {
    // Five passing candidates, but the output vectors only have room for
    // two entries each.
    let num_protos = 32;
    let num_detections = 5;
    let num_features = 6 + num_protos;

    let mut output = Array2::<f32>::zeros((num_features, num_detections));
    for i in 0..num_detections {
        output[[0, i]] = 0.1 * (i as f32); // x1
        output[[1, i]] = 0.1 * (i as f32); // y1
        output[[2, i]] = 0.1 * (i as f32) + 0.2; // x2
        output[[3, i]] = 0.1 * (i as f32) + 0.2; // y2
        output[[4, i]] = 0.9; // confidence
        output[[5, i]] = i as f32; // class index
    }
    let protos = Array3::<f32>::zeros((16, 16, num_protos));

    let capacity = 2;
    let mut boxes = Vec::with_capacity(capacity);
    let mut masks = Vec::with_capacity(capacity);
    decode_yolo_end_to_end_segdet_float(
        output.view(),
        protos.view(),
        0.5,
        &mut boxes,
        &mut masks,
    )
    .unwrap();

    assert_eq!(boxes.len(), 2);
    assert_eq!(masks.len(), 2);
}
2149
#[test]
fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
    // Six rows leave no room for a mask coefficient; at least seven
    // (6 base + 1 coefficient) are required.
    let output = Array2::<f32>::zeros((6, 3));
    let protos = Array3::<f32>::zeros((16, 16, 32));

    let mut boxes = Vec::with_capacity(10);
    let mut masks = Vec::with_capacity(10);
    let outcome = decode_yolo_end_to_end_segdet_float(
        output.view(),
        protos.view(),
        0.5,
        &mut boxes,
        &mut masks,
    );

    let rejected = match outcome {
        Err(crate::DecoderError::InvalidShape(msg)) => msg.contains("at least 7 rows"),
        _ => false,
    };
    assert!(rejected);
}
2172
#[test]
fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
    // 16 mask coefficients per candidate against 32 proto channels.
    let num_protos = 32;
    let output = Array2::<f32>::zeros((6 + 16, 3));
    let protos = Array3::<f32>::zeros((16, 16, num_protos));

    let mut boxes = Vec::with_capacity(10);
    let mut masks = Vec::with_capacity(10);
    let outcome = decode_yolo_end_to_end_segdet_float(
        output.view(),
        protos.view(),
        0.5,
        &mut boxes,
        &mut masks,
    );

    let rejected = match outcome {
        Err(crate::DecoderError::InvalidShape(msg)) => {
            msg.contains("doesn't match protos count")
        }
        _ => false,
    };
    assert!(rejected);
}
2196
2197    // ========================================================================
2198    // Tests for decode_yolo_split_end_to_end_segdet_float
2199    // ========================================================================
2200
#[test]
fn test_split_end_to_end_segdet_basic() {
    // Build the combined (6 + num_protos, N) tensor first, rows being
    // [x1, y1, x2, y2, conf, class, mask_coeff_0..31], then hand its
    // row groups to the split decoder as separate views.
    let num_protos = 32;
    let num_detections = 2;
    let num_features = 6 + num_protos;

    // The six leading rows; every mask-coefficient row stays at 0.1.
    let header: [[f32; 2]; 6] = [
        [0.1, 0.5], // x1
        [0.1, 0.5], // y1
        [0.4, 0.9], // x2
        [0.4, 0.9], // y2
        [0.9, 0.3], // confidence: candidate 0 passes, candidate 1 fails
        [1.0, 2.0], // class index
    ];
    let mut output = Array2::<f32>::from_elem((num_features, num_detections), 0.1);
    for (r, row) in header.iter().enumerate() {
        for (c, &value) in row.iter().enumerate() {
            output[[r, c]] = value;
        }
    }

    let box_coords = output.slice(s![..4, ..]);
    let scores = output.slice(s![4..5, ..]);
    let classes = output.slice(s![5..6, ..]);
    let mask_coeff = output.slice(s![6.., ..]);

    // Proto tensor laid out as (proto_height, proto_width, num_protos).
    let protos = Array3::<f32>::zeros((16, 16, num_protos));

    let mut boxes = Vec::with_capacity(10);
    let mut masks = Vec::with_capacity(10);
    decode_yolo_split_end_to_end_segdet_float(
        box_coords,
        scores,
        classes,
        mask_coeff,
        protos.view(),
        0.5,
        &mut boxes,
        &mut masks,
    )
    .unwrap();

    // Candidate 0 is the sole survivor.
    assert_eq!(boxes.len(), 1);
    assert_eq!(masks.len(), 1);
    let det = &boxes[0];
    assert_eq!(det.label, 1);
    assert!((det.score - 0.9).abs() < 0.01);
}
2258
2259    // ========================================================================
2260    // Tests for yolo_segmentation_to_mask
2261    // ========================================================================
2262
#[test]
fn test_segmentation_to_mask_basic() {
    // 4x4x1 input mixing values on both sides of the 128 threshold.
    let grid: [[u8; 4]; 4] = [
        [100, 200, 50, 150],
        [10, 255, 128, 64],
        [0, 127, 128, 255],
        [64, 64, 192, 192],
    ];
    let segmentation = Array3::from_shape_vec((4, 4, 1), grid.concat()).unwrap();

    let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();

    // (position, expected): values >= 128 map to 1, everything else to 0.
    let expectations: [([usize; 2], u8); 8] = [
        ([0, 0], 0), // 100 < 128
        ([0, 1], 1), // 200 >= 128
        ([0, 2], 0), // 50 < 128
        ([0, 3], 1), // 150 >= 128
        ([1, 1], 1), // 255 >= 128
        ([1, 2], 1), // 128 >= 128
        ([2, 0], 0), // 0 < 128
        ([2, 1], 0), // 127 < 128
    ];
    for &(position, expected) in expectations.iter() {
        assert_eq!(mask[position], expected);
    }
}
2286
#[test]
fn test_segmentation_to_mask_all_above() {
    // Every input value (255) is at or above the threshold.
    let saturated = Array3::from_elem((4, 4, 1), 255u8);
    let mask = yolo_segmentation_to_mask(saturated.view(), 128).unwrap();
    assert!(!mask.iter().any(|&value| value != 1));
}
2293
#[test]
fn test_segmentation_to_mask_all_below() {
    // Every input value (64) is below the threshold.
    let dim = Array3::from_elem((4, 4, 1), 64u8);
    let mask = yolo_segmentation_to_mask(dim.view(), 128).unwrap();
    assert!(!mask.iter().any(|&value| value != 0));
}
2300
#[test]
fn test_segmentation_to_mask_invalid_shape() {
    // Three channels on the last axis instead of the required one.
    let three_channel = Array3::from_elem((4, 4, 3), 128u8);

    let rejected = match yolo_segmentation_to_mask(three_channel.view(), 128) {
        Err(crate::DecoderError::InvalidShape(msg)) => msg.contains("(H, W, 1)"),
        _ => false,
    };

    assert!(rejected);
}
2312
2313    // ========================================================================
2314    // Tests for protobox / NORM_LIMIT regression
2315    // ========================================================================
2316
#[test]
fn test_protobox_clamps_edge_coordinates() {
    // A box touching the far edge (xmax = ymax = 1.0) must be accepted
    // without an out-of-bounds panic.
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let view = protos.view();
    let roi = BoundingBox {
        xmin: 0.5,
        ymin: 0.5,
        xmax: 1.0,
        ymax: 1.0,
    };

    let result = protobox(&view, &roi);
    assert!(result.is_ok(), "protobox should accept xmax=1.0");

    let (cropped, _roi_norm) = result.unwrap();
    let dims = cropped.shape();
    // Both spatial axes must be non-empty; the channel axis is untouched.
    assert!(dims[0] > 0);
    assert!(dims[1] > 0);
    assert_eq!(dims[2], 4);
}
2336
#[test]
fn test_protobox_rejects_wildly_out_of_range() {
    // Coordinates of 3.0 exceed NORM_LIMIT and must be reported as
    // un-normalized input.
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let view = protos.view();
    let roi = BoundingBox {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 3.0,
        ymax: 3.0,
    };

    let rejected = match protobox(&view, &roi) {
        Err(crate::DecoderError::InvalidShape(msg)) => msg.contains("un-normalized"),
        _ => false,
    };

    assert!(rejected, "protobox should reject coords > NORM_LIMIT");
}
2354
#[test]
fn test_protobox_accepts_slightly_over_one() {
    // Coordinates of 1.5 stay within NORM_LIMIT (2.0) and are accepted.
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let view = protos.view();
    let roi = BoundingBox {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 1.5,
        ymax: 1.5,
    };

    let result = protobox(&view, &roi);
    assert!(
        result.is_ok(),
        "protobox should accept coords <= NORM_LIMIT (2.0)"
    );

    let (cropped, _roi_norm) = result.unwrap();
    let dims = cropped.shape();
    // Values above 1.0 are clamped to the boundary, so the crop spans the
    // whole proto map.
    assert_eq!(dims[0], 16);
    assert_eq!(dims[1], 16);
}
2376
#[test]
fn test_segdet_float_proto_no_panic() {
    // Mirrors the YOLOv8n-seg layout: output0 = [116, N]
    // (4 box + 80 class + 32 mask coefficients) plus a proto tensor.
    let num_proposals = 100; // enough to produce idx >= 32
    let num_classes = 80;
    let num_mask_coeffs = 32;
    let rows = 4 + num_classes + num_mask_coeffs; // 116

    // Row 0 = cx, 1 = cy, 2 = w, 3 = h, rows 4..84 = class scores,
    // rows 84..116 = mask coefficients. Every proposal gets the same valid
    // xywh box and a class-0 score of 0.9 so that detections survive the
    // threshold; everything else stays zero.
    let mut detections = ndarray::Array2::<f32>::zeros((rows, num_proposals));
    for (row, value) in [(0, 320.0), (1, 320.0), (2, 50.0), (3, 50.0), (4, 0.9)] {
        for col in 0..num_proposals {
            detections[[row, col]] = value;
        }
    }

    // Protos must be in HWC order. Under the HAL physical-order contract,
    // callers declare shape+dshape matching producer memory and
    // swap_axes_if_needed permutes the stride tuple into canonical
    // [batch, height, width, num_protos] before this function sees it.
    let protos = ndarray::Array3::<f32>::zeros((160, 160, num_mask_coeffs));

    let mut output_boxes = Vec::with_capacity(300);

    // Regression: this used to panic in mask_tensor.row(idx) once
    // idx >= 32.
    let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
        detections.view(),
        protos.view(),
        0.5,
        0.7,
        Some(Nms::default()),
        &mut output_boxes,
    );

    // NMS collapses most of the overlapping boxes, but something remains.
    assert!(!output_boxes.is_empty());

    // One coefficient row per surviving box, each 32 elements wide.
    let coeffs_shape = proto_data.mask_coefficients.shape();
    assert_eq!(coeffs_shape[0], output_boxes.len());
    assert_eq!(coeffs_shape[1], num_mask_coeffs);
}
2425
2426    // ========================================================================
2427    // Pre-NMS top-K cap (MAX_NMS_CANDIDATES)
2428    // ========================================================================
2429
/// Regression test for the pre-NMS candidate cap.
///
/// With a very low score threshold (e.g., t=0.01 on YOLOv8 with 8400×80
/// candidates) nearly every score survives filtering, which feeds O(n²)
/// NMS and one mask matmul per survivor. To bound worst-case decode time
/// the decoder limits the set handed to NMS to `MAX_NMS_CANDIDATES`
/// (Ultralytics default 30 000).
///
/// Here 50 000 above-threshold candidates go through
/// `impl_yolo_segdet_get_boxes` with NMS disabled (Nms=None) and no
/// post-NMS limit. Before the fix all 50 000 came back; after the fix
/// exactly 30 000 do.
#[test]
fn test_pre_nms_cap_truncates_excess_candidates() {
    let n: usize = 50_000;
    let num_classes = 1;

    // Every candidate shares the same valid box.
    let boxes_data: Vec<f32> = (0..n).flat_map(|_| [0.1f32, 0.1, 0.5, 0.5]).collect();
    // score_i = 0.99 - i * 1e-7: strictly decreasing yet always far above
    // the 0.1 threshold, so we can tell *which* candidates survived.
    let scores_data: Vec<f32> = (0..n).map(|i| 0.99 - (i as f32) * 1e-7).collect();

    let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
    let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();

    let result = impl_yolo_segdet_get_boxes::<XYXY, _, _>(
        boxes.view(),
        scores.view(),
        0.1,
        1.0,
        None,       // NMS off: measure the cap, not suppression
        usize::MAX, // post-NMS limit effectively disabled
    );

    assert_eq!(
        result.len(),
        crate::yolo::MAX_NMS_CANDIDATES,
        "pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
        result.len()
    );

    // The largest scores sit at the lowest input indices, so the first
    // survivor must score roughly 0.99.
    let top_score = result[0].0.score;
    assert!(
        top_score > 0.98,
        "highest-ranked survivor should have the largest score, got {top_score}"
    );
}
2482
/// Quantized split-path counterpart of the pre-NMS cap test: more than
/// `MAX_NMS_CANDIDATES` candidates clear the threshold, and
/// `impl_yolo_split_segdet_quant_get_boxes` must truncate the set before
/// NMS.
#[test]
fn test_pre_nms_cap_truncates_excess_candidates_quant() {
    use crate::Quantization;
    let n: usize = 50_000;
    let num_classes = 1;

    // Identical i8 boxes; with scale=0.01 and zp=0 the raw values 10 and
    // 50 dequantize to 0.1 and 0.5.
    let boxes_data: Vec<i8> = (0..n).flat_map(|_| [10i8, 10, 50, 50]).collect();
    let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
    let quant_boxes = Quantization {
        scale: 0.01,
        zero_point: 0,
    };

    // u8 scores cycling through 250 down to 51. With scale 0.00392 and
    // zp 0 that spans roughly 0.98 down to 0.2, so every candidate clears
    // the 0.1 threshold passed below.
    let scores_data: Vec<u8> = (0..n)
        .map(|i| {
            let offset = (i % 200) as u8;
            250u8.saturating_sub(offset)
        })
        .collect();
    let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();
    let quant_scores = Quantization {
        scale: 0.00392,
        zero_point: 0,
    };

    let result = impl_yolo_split_segdet_quant_get_boxes::<XYXY, _, _>(
        (boxes.view(), quant_boxes),
        (scores.view(), quant_scores),
        0.1,
        1.0,
        None,
        usize::MAX,
    );

    assert_eq!(
        result.len(),
        crate::yolo::MAX_NMS_CANDIDATES,
        "quant path pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
        result.len()
    );
}
2531
2532    /// Regression test for HAILORT_BUG.md — the YoloSegDet path
2533    /// (combined `(4 + nc + nm, N)` detection tensor + separate protos)
2534    /// must pair each surviving detection with the mask coefficient
2535    /// row at the SAME anchor index the box came from. The validator
2536    /// sees this path miss the pairing under schema-v2 Hailo inputs
2537    /// (mAP collapse from 46.8 → 3.65 while mask IoU stays at 66.9,
2538    /// the fingerprint of mask-to-detection misalignment).
2539    ///
2540    /// Construction: three anchors with distinct mask-coef signatures
2541    /// that, after dot(coefs, protos) + sigmoid, produce HIGH vs LOW
2542    /// mask pixel values. Two anchors survive (one high, one low); if
2543    /// the mask row is looked up at the wrong index, the per-detection
2544    /// mean mask value would cross the threshold and we catch it.
2545    #[test]
2546    fn segdet_combined_tensor_pairs_detection_with_matching_mask_row() {
2547        let nc = 2; // num_classes
2548        let nm = 2; // num_protos
2549        let n = 3; // num_anchors
2550        let feat = 4 + nc + nm; // 8
2551
2552        // Tensor layout: (8, 3) rows=features, cols=anchors.
2553        // Row indices:  0..4 = xywh, 4..6 = scores, 6..8 = mask_coefs.
2554        //
2555        //         anchor 0 | anchor 1 | anchor 2
2556        // xc       0.2      | 0.5      | 0.8
2557        // yc       0.2      | 0.5      | 0.8
2558        // w        0.1      | 0.1      | 0.1
2559        // h        0.1      | 0.1      | 0.1
2560        // s[0]     0.9      | 0.0      | 0.8   (class 0)
2561        // s[1]     0.0      | 0.0      | 0.0   (class 1 — always loses)
2562        // m[0]     3.0      | 0.0      | -3.0  (high for a0, low for a2)
2563        // m[1]     3.0      | 0.0      | -3.0
2564        //
2565        // Proto[0] = Proto[1] = all-ones (8x8), so
2566        //   mask(a0) = sigmoid(3 + 3) ≈ 0.9975 → 254
2567        //   mask(a2) = sigmoid(-3 + -3) ≈ 0.0025 → 1
2568        // 250-point gap makes any misalignment trivially detectable.
2569        let mut data = vec![0.0f32; feat * n];
2570        let set = |d: &mut [f32], r: usize, c: usize, v: f32| d[r * n + c] = v;
2571        set(&mut data, 0, 0, 0.2);
2572        set(&mut data, 1, 0, 0.2);
2573        set(&mut data, 2, 0, 0.1);
2574        set(&mut data, 3, 0, 0.1);
2575        set(&mut data, 0, 1, 0.5);
2576        set(&mut data, 1, 1, 0.5);
2577        set(&mut data, 2, 1, 0.1);
2578        set(&mut data, 3, 1, 0.1);
2579        set(&mut data, 0, 2, 0.8);
2580        set(&mut data, 1, 2, 0.8);
2581        set(&mut data, 2, 2, 0.1);
2582        set(&mut data, 3, 2, 0.1);
2583        set(&mut data, 4, 0, 0.9);
2584        set(&mut data, 4, 2, 0.8);
2585        set(&mut data, 6, 0, 3.0);
2586        set(&mut data, 7, 0, 3.0);
2587        set(&mut data, 6, 2, -3.0);
2588        set(&mut data, 7, 2, -3.0);
2589
2590        let output = Array2::from_shape_vec((feat, n), data).unwrap();
2591        let protos = Array3::<f32>::from_elem((8, 8, nm), 1.0);
2592
2593        let mut boxes: Vec<DetectBox> = Vec::with_capacity(10);
2594        let mut masks: Vec<Segmentation> = Vec::with_capacity(10);
2595        decode_yolo_segdet_float(
2596            output.view(),
2597            protos.view(),
2598            0.5,
2599            0.5,
2600            Some(Nms::ClassAgnostic),
2601            &mut boxes,
2602            &mut masks,
2603        )
2604        .unwrap();
2605
2606        assert_eq!(
2607            boxes.len(),
2608            2,
2609            "two anchors above threshold should survive (a0 score=0.9, a2 score=0.8); got {}",
2610            boxes.len()
2611        );
2612
2613        // Build a (anchor_index → mask_mean) mapping from the results.
2614        // Anchor 0 has centre (0.2, 0.2), anchor 2 has centre (0.8,
2615        // 0.8). The DetectBox bbox is the post-XYWH-to-XYXY conversion
2616        // of the original xywh; cropping inside protobox may shrink it,
2617        // so match by centre (0.2 vs 0.8) rather than exact bbox.
2618        for (b, m) in boxes.iter().zip(masks.iter()) {
2619            let cx = (b.bbox.xmin + b.bbox.xmax) * 0.5;
2620            let mean = {
2621                let s = &m.segmentation;
2622                let total: u32 = s.iter().map(|&v| v as u32).sum();
2623                total as f32 / s.len() as f32
2624            };
2625            if cx < 0.3 {
2626                // anchor 0 — expect HIGH mask values ≈ 254
2627                assert!(
2628                    mean > 200.0,
2629                    "anchor 0 detection (centre {cx:.2}) should have high-value mask; got mean {mean}"
2630                );
2631            } else if cx > 0.7 {
2632                // anchor 2 — expect LOW mask values ≈ 1
2633                assert!(
2634                    mean < 50.0,
2635                    "anchor 2 detection (centre {cx:.2}) should have low-value mask; got mean {mean}"
2636                );
2637            } else {
2638                panic!("unexpected detection centre {cx:.2}");
2639            }
2640        }
2641    }
2642}