Skip to main content

edgefirst_decoder/
yolo.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::fmt::Debug;
5
6use ndarray::{
7    parallel::prelude::{IntoParallelIterator, ParallelIterator},
8    s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
9};
10use num_traits::{AsPrimitive, Float, PrimInt, Signed};
11
12use crate::{
13    byte::{
14        nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
15        postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
16    },
17    configs::Nms,
18    dequant_detect_box,
19    float::{
20        nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
21        postprocess_boxes_float, postprocess_boxes_index_float,
22    },
23    BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoLayout,
24    Quantization, Segmentation, XYWH, XYXY,
25};
26
/// Maximum number of above-threshold candidates fed to NMS.
///
/// At very low score thresholds (e.g., t=0.01 on YOLOv8 with
/// 8400 anchors × 80 classes), the number of survivors approaches
/// the full 672 000-entry score grid. NMS is O(n²) and the
/// downstream mask matmul runs once per survivor, so an
/// unbounded set produces minutes-per-frame decode times.
///
/// `MAX_NMS_CANDIDATES` matches the Ultralytics `max_nms` default
/// and is applied as a top-K-by-score truncation immediately
/// before NMS. Candidates ranked below the top `MAX_NMS_CANDIDATES`
/// by score are silently dropped — at the score thresholds where the
/// cap activates, the bottom of the candidate list is dominated by
/// noise that NMS would discard anyway.
pub const MAX_NMS_CANDIDATES: usize = 30_000;
42
43/// Truncate `boxes` to the highest-scoring `top_k` entries in-place when the
44/// input exceeds the cap. Uses partial sort (O(N)) via `select_nth_unstable_by`
45/// to avoid full O(N log N) sort. No-op when input length ≤ `top_k`.
46fn truncate_to_top_k_by_score<E: Send>(boxes: &mut Vec<(DetectBox, E)>, top_k: usize) {
47    if boxes.len() > top_k {
48        boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.total_cmp(&a.0.score));
49        boxes.truncate(top_k);
50    }
51}
52
53/// Quantized counterpart of [`truncate_to_top_k_by_score`]. Sorts on
54/// the raw quantized score (which preserves order under monotonic
55/// dequantization). Uses partial sort (O(N)) via `select_nth_unstable_by`.
56fn truncate_to_top_k_by_score_quant<S: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send>(
57    boxes: &mut Vec<(DetectBoxQuantized<S>, E)>,
58    top_k: usize,
59) {
60    if boxes.len() > top_k {
61        boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.cmp(&a.0.score));
62        boxes.truncate(top_k);
63    }
64}
65
66/// Dispatches to the appropriate NMS function based on mode for float boxes.
67fn dispatch_nms_float(nms: Option<Nms>, iou: f32, boxes: Vec<DetectBox>) -> Vec<DetectBox> {
68    match nms {
69        Some(Nms::ClassAgnostic) => nms_float(iou, boxes),
70        Some(Nms::ClassAware) => nms_class_aware_float(iou, boxes),
71        None => boxes, // bypass NMS
72    }
73}
74
75/// Dispatches to the appropriate NMS function based on mode for float boxes
76/// with extra data.
77pub(super) fn dispatch_nms_extra_float<E: Send + Sync>(
78    nms: Option<Nms>,
79    iou: f32,
80    boxes: Vec<(DetectBox, E)>,
81) -> Vec<(DetectBox, E)> {
82    match nms {
83        Some(Nms::ClassAgnostic) => nms_extra_float(iou, boxes),
84        Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, boxes),
85        None => boxes, // bypass NMS
86    }
87}
88
89/// Dispatches to the appropriate NMS function based on mode for quantized
90/// boxes.
91fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
92    nms: Option<Nms>,
93    iou: f32,
94    boxes: Vec<DetectBoxQuantized<SCORE>>,
95) -> Vec<DetectBoxQuantized<SCORE>> {
96    match nms {
97        Some(Nms::ClassAgnostic) => nms_int(iou, boxes),
98        Some(Nms::ClassAware) => nms_class_aware_int(iou, boxes),
99        None => boxes, // bypass NMS
100    }
101}
102
103/// Dispatches to the appropriate NMS function based on mode for quantized boxes
104/// with extra data.
105fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
106    nms: Option<Nms>,
107    iou: f32,
108    boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
109) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
110    match nms {
111        Some(Nms::ClassAgnostic) => nms_extra_int(iou, boxes),
112        Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, boxes),
113        None => boxes, // bypass NMS
114    }
115}
116
117/// Decodes YOLO detection outputs from quantized tensors into detection boxes.
118///
119/// Boxes are expected to be in XYWH format.
120///
121/// Expected shapes of inputs:
122/// - output: (4 + num_classes, num_boxes)
123pub fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
124    output: (ArrayView2<BOX>, Quantization),
125    score_threshold: f32,
126    iou_threshold: f32,
127    nms: Option<Nms>,
128    output_boxes: &mut Vec<DetectBox>,
129) where
130    f32: AsPrimitive<BOX>,
131{
132    impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
133}
134
135/// Decodes YOLO detection outputs from float tensors into detection boxes.
136///
137/// Boxes are expected to be in XYWH format.
138///
139/// Expected shapes of inputs:
140/// - output: (4 + num_classes, num_boxes)
141pub fn decode_yolo_det_float<T>(
142    output: ArrayView2<T>,
143    score_threshold: f32,
144    iou_threshold: f32,
145    nms: Option<Nms>,
146    output_boxes: &mut Vec<DetectBox>,
147) where
148    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
149    f32: AsPrimitive<T>,
150{
151    impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
152}
153
154/// Decodes YOLO detection and segmentation outputs from quantized tensors into
155/// detection boxes and segmentation masks.
156///
157/// Boxes are expected to be in XYWH format.
158///
159/// Expected shapes of inputs:
160/// - boxes: (4 + num_classes + num_protos, num_boxes)
161/// - protos: (proto_height, proto_width, num_protos)
162///
163/// # Errors
164/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
165pub fn decode_yolo_segdet_quant<
166    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
167    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
168>(
169    boxes: (ArrayView2<BOX>, Quantization),
170    protos: (ArrayView3<PROTO>, Quantization),
171    score_threshold: f32,
172    iou_threshold: f32,
173    nms: Option<Nms>,
174    output_boxes: &mut Vec<DetectBox>,
175    output_masks: &mut Vec<Segmentation>,
176) -> Result<(), crate::DecoderError>
177where
178    f32: AsPrimitive<BOX>,
179{
180    impl_yolo_segdet_quant::<XYWH, _, _>(
181        boxes,
182        protos,
183        score_threshold,
184        iou_threshold,
185        nms,
186        MAX_NMS_CANDIDATES,
187        output_boxes.capacity(),
188        output_boxes,
189        output_masks,
190    )
191}
192
193/// Decodes YOLO detection and segmentation outputs from float tensors into
194/// detection boxes and segmentation masks.
195///
196/// Boxes are expected to be in XYWH format.
197///
198/// Expected shapes of inputs:
199/// - boxes: (4 + num_classes + num_protos, num_boxes)
200/// - protos: (proto_height, proto_width, num_protos)
201///
202/// # Errors
203/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
204pub fn decode_yolo_segdet_float<T>(
205    boxes: ArrayView2<T>,
206    protos: ArrayView3<T>,
207    score_threshold: f32,
208    iou_threshold: f32,
209    nms: Option<Nms>,
210    output_boxes: &mut Vec<DetectBox>,
211    output_masks: &mut Vec<Segmentation>,
212) -> Result<(), crate::DecoderError>
213where
214    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
215    f32: AsPrimitive<T>,
216{
217    impl_yolo_segdet_float::<XYWH, _, _>(
218        boxes,
219        protos,
220        score_threshold,
221        iou_threshold,
222        nms,
223        MAX_NMS_CANDIDATES,
224        output_boxes.capacity(),
225        output_boxes,
226        output_masks,
227    )
228}
229
230/// Decodes YOLO split detection outputs from quantized tensors into detection
231/// boxes.
232///
233/// Boxes are expected to be in XYWH format.
234///
235/// Expected shapes of inputs:
236/// - boxes: (4, num_boxes)
237/// - scores: (num_classes, num_boxes)
238///
239/// # Panics
240/// Panics if shapes don't match the expected dimensions.
241pub fn decode_yolo_split_det_quant<
242    BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
243    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
244>(
245    boxes: (ArrayView2<BOX>, Quantization),
246    scores: (ArrayView2<SCORE>, Quantization),
247    score_threshold: f32,
248    iou_threshold: f32,
249    nms: Option<Nms>,
250    output_boxes: &mut Vec<DetectBox>,
251) where
252    f32: AsPrimitive<SCORE>,
253{
254    impl_yolo_split_quant::<XYWH, _, _>(
255        boxes,
256        scores,
257        score_threshold,
258        iou_threshold,
259        nms,
260        output_boxes,
261    );
262}
263
264/// Decodes YOLO split detection outputs from float tensors into detection
265/// boxes.
266///
267/// Boxes are expected to be in XYWH format.
268///
269/// Expected shapes of inputs:
270/// - boxes: (4, num_boxes)
271/// - scores: (num_classes, num_boxes)
272///
273/// # Panics
274/// Panics if shapes don't match the expected dimensions.
275pub fn decode_yolo_split_det_float<T>(
276    boxes: ArrayView2<T>,
277    scores: ArrayView2<T>,
278    score_threshold: f32,
279    iou_threshold: f32,
280    nms: Option<Nms>,
281    output_boxes: &mut Vec<DetectBox>,
282) where
283    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
284    f32: AsPrimitive<T>,
285{
286    impl_yolo_split_float::<XYWH, _, _>(
287        boxes,
288        scores,
289        score_threshold,
290        iou_threshold,
291        nms,
292        output_boxes,
293    );
294}
295
296/// Decodes YOLO split detection segmentation outputs from quantized tensors
297/// into detection boxes and segmentation masks.
298///
299/// Boxes are expected to be in XYWH format.
300///
301/// Expected shapes of inputs:
302/// - boxes_tensor: (4, num_boxes)
303/// - scores_tensor: (num_classes, num_boxes)
304/// - mask_tensor: (num_protos, num_boxes)
305/// - protos: (proto_height, proto_width, num_protos)
306///
307/// # Errors
308/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
309#[allow(clippy::too_many_arguments)]
310pub fn decode_yolo_split_segdet<
311    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
312    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
313    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
314    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
315>(
316    boxes: (ArrayView2<BOX>, Quantization),
317    scores: (ArrayView2<SCORE>, Quantization),
318    mask_coeff: (ArrayView2<MASK>, Quantization),
319    protos: (ArrayView3<PROTO>, Quantization),
320    score_threshold: f32,
321    iou_threshold: f32,
322    nms: Option<Nms>,
323    output_boxes: &mut Vec<DetectBox>,
324    output_masks: &mut Vec<Segmentation>,
325) -> Result<(), crate::DecoderError>
326where
327    f32: AsPrimitive<SCORE>,
328{
329    impl_yolo_split_segdet_quant::<XYWH, _, _, _, _>(
330        boxes,
331        scores,
332        mask_coeff,
333        protos,
334        score_threshold,
335        iou_threshold,
336        nms,
337        output_boxes,
338        output_masks,
339    )
340}
341
342/// Decodes YOLO split detection segmentation outputs from float tensors
343/// into detection boxes and segmentation masks.
344///
345/// Boxes are expected to be in XYWH format.
346///
347/// Expected shapes of inputs:
348/// - boxes_tensor: (4, num_boxes)
349/// - scores_tensor: (num_classes, num_boxes)
350/// - mask_tensor: (num_protos, num_boxes)
351/// - protos: (proto_height, proto_width, num_protos)
352///
353/// # Errors
354/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
355#[allow(clippy::too_many_arguments)]
356pub fn decode_yolo_split_segdet_float<T>(
357    boxes: ArrayView2<T>,
358    scores: ArrayView2<T>,
359    mask_coeff: ArrayView2<T>,
360    protos: ArrayView3<T>,
361    score_threshold: f32,
362    iou_threshold: f32,
363    nms: Option<Nms>,
364    output_boxes: &mut Vec<DetectBox>,
365    output_masks: &mut Vec<Segmentation>,
366) -> Result<(), crate::DecoderError>
367where
368    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
369    f32: AsPrimitive<T>,
370{
371    impl_yolo_split_segdet_float::<XYWH, _, _, _, _>(
372        boxes,
373        scores,
374        mask_coeff,
375        protos,
376        score_threshold,
377        iou_threshold,
378        nms,
379        output_boxes,
380        output_masks,
381    )
382}
383
384/// Decodes end-to-end YOLO detection outputs (post-NMS from model).
385/// Expects an array of shape `(6, N)`, where the first dimension (rows)
386/// corresponds to the 6 per-detection features
387/// `[x1, y1, x2, y2, conf, class]` and the second dimension (columns)
388/// indexes the `N` detections.
389/// Boxes are output directly without NMS (the model already applied NMS).
390///
391/// Coordinates may be normalized `[0, 1]` or absolute pixel values depending
392/// on the model configuration. The caller should check
393/// `decoder.normalized_boxes()` to determine which.
394///
395/// # Errors
396///
397/// Returns `DecoderError::InvalidShape` if `output` has fewer than 6 rows.
398pub fn decode_yolo_end_to_end_det_float<T>(
399    output: ArrayView2<T>,
400    score_threshold: f32,
401    output_boxes: &mut Vec<DetectBox>,
402) -> Result<(), crate::DecoderError>
403where
404    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
405    f32: AsPrimitive<T>,
406{
407    // Validate input shape: need at least 6 rows (x1, y1, x2, y2, conf, class)
408    if output.shape()[0] < 6 {
409        return Err(crate::DecoderError::InvalidShape(format!(
410            "End-to-end detection output requires at least 6 rows, got {}",
411            output.shape()[0]
412        )));
413    }
414
415    // Input shape: (6, N) -> transpose to (N, 4) for boxes and (N, 1) for scores
416    let boxes = output.slice(s![0..4, ..]).reversed_axes();
417    let scores = output.slice(s![4..5, ..]).reversed_axes();
418    let classes = output.slice(s![5, ..]);
419    let mut boxes =
420        postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
421    boxes.truncate(output_boxes.capacity());
422    output_boxes.clear();
423    for (mut b, i) in boxes.into_iter() {
424        b.label = classes[i].as_() as usize;
425        output_boxes.push(b);
426    }
427    // No NMS — model output is already post-NMS
428    Ok(())
429}
430
431/// Decodes end-to-end YOLO detection + segmentation outputs (post-NMS from
432/// model).
433///
434/// Input shapes:
435/// - detection: (6 + num_protos, N) where rows are [x1, y1, x2, y2, conf,
436///   class, mask_coeff_0, ..., mask_coeff_31]
437/// - protos: (proto_height, proto_width, num_protos)
438///
439/// Boxes are output directly without NMS (model already applied NMS).
440/// Coordinates may be normalized [0,1] or pixel values depending on model
441/// config.
442///
443/// # Errors
444///
445/// Returns `DecoderError::InvalidShape` if:
446/// - output has fewer than 7 rows (6 base + at least 1 mask coefficient)
447/// - protos shape doesn't match mask coefficients count
448pub fn decode_yolo_end_to_end_segdet_float<T>(
449    output: ArrayView2<T>,
450    protos: ArrayView3<T>,
451    score_threshold: f32,
452    output_boxes: &mut Vec<DetectBox>,
453    output_masks: &mut Vec<crate::Segmentation>,
454) -> Result<(), crate::DecoderError>
455where
456    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
457    f32: AsPrimitive<T>,
458{
459    let (boxes, scores, classes, mask_coeff) =
460        postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
461    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
462        boxes,
463        scores,
464        classes,
465        score_threshold,
466        output_boxes.capacity(),
467    );
468
469    // No NMS — model output is already post-NMS
470
471    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
472}
473
474/// Decodes split end-to-end YOLO detection outputs (post-NMS from model).
475///
476/// Input shapes (after batch dim removed):
477/// - boxes: (4, N) — xyxy pixel coordinates
478/// - scores: (1, N) — confidence of the top class
479/// - classes: (1, N) — class index of the top class
480///
481/// Boxes are output directly without NMS (model already applied NMS).
482pub fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
483    boxes: ArrayView2<T>,
484    scores: ArrayView2<T>,
485    classes: ArrayView2<T>,
486    score_threshold: f32,
487    output_boxes: &mut Vec<DetectBox>,
488) -> Result<(), crate::DecoderError> {
489    let n = boxes.shape()[1];
490
491    output_boxes.clear();
492
493    let (boxes, scores, classes) = postprocess_yolo_split_end_to_end_det(boxes, scores, &classes)?;
494
495    for i in 0..n {
496        let score: f32 = scores[[i, 0]].as_();
497        if score < score_threshold {
498            continue;
499        }
500        if output_boxes.len() >= output_boxes.capacity() {
501            break;
502        }
503        output_boxes.push(DetectBox {
504            bbox: BoundingBox {
505                xmin: boxes[[i, 0]].as_(),
506                ymin: boxes[[i, 1]].as_(),
507                xmax: boxes[[i, 2]].as_(),
508                ymax: boxes[[i, 3]].as_(),
509            },
510            score,
511            label: classes[i].as_() as usize,
512        });
513    }
514    Ok(())
515}
516
517/// Decodes split end-to-end YOLO detection + segmentation outputs.
518///
519/// Input shapes (after batch dim removed):
520/// - boxes: (4, N) — xyxy pixel coordinates
521/// - scores: (1, N) — confidence
522/// - classes: (1, N) — class index
523/// - mask_coeff: (num_protos, N) — mask coefficients per detection
524/// - protos: (proto_h, proto_w, num_protos) — prototype masks
525#[allow(clippy::too_many_arguments)]
526pub fn decode_yolo_split_end_to_end_segdet_float<T>(
527    boxes: ArrayView2<T>,
528    scores: ArrayView2<T>,
529    classes: ArrayView2<T>,
530    mask_coeff: ArrayView2<T>,
531    protos: ArrayView3<T>,
532    score_threshold: f32,
533    output_boxes: &mut Vec<DetectBox>,
534    output_masks: &mut Vec<crate::Segmentation>,
535) -> Result<(), crate::DecoderError>
536where
537    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
538    f32: AsPrimitive<T>,
539{
540    let (boxes, scores, classes, mask_coeff) =
541        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
542    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
543        boxes,
544        scores,
545        classes,
546        score_threshold,
547        output_boxes.capacity(),
548    );
549
550    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
551}
552
553#[allow(clippy::type_complexity)]
554pub(crate) fn postprocess_yolo_end_to_end_segdet<'a, T>(
555    output: &'a ArrayView2<'_, T>,
556    num_protos: usize,
557) -> Result<
558    (
559        ArrayView2<'a, T>,
560        ArrayView2<'a, T>,
561        ArrayView1<'a, T>,
562        ArrayView2<'a, T>,
563    ),
564    crate::DecoderError,
565> {
566    // Validate input shape: need at least 7 rows (6 base + at least 1 mask coeff)
567    if output.shape()[0] < 7 {
568        return Err(crate::DecoderError::InvalidShape(format!(
569            "End-to-end segdet output requires at least 7 rows, got {}",
570            output.shape()[0]
571        )));
572    }
573
574    let num_mask_coeffs = output.shape()[0] - 6;
575    if num_mask_coeffs != num_protos {
576        return Err(crate::DecoderError::InvalidShape(format!(
577            "Mask coefficients count ({}) doesn't match protos count ({})",
578            num_mask_coeffs, num_protos
579        )));
580    }
581
582    // Input shape: (6+num_protos, N) -> transpose for postprocessing
583    let boxes = output.slice(s![0..4, ..]).reversed_axes();
584    let scores = output.slice(s![4..5, ..]).reversed_axes();
585    let classes = output.slice(s![5, ..]);
586    let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
587    Ok((boxes, scores, classes, mask_coeff))
588}
589
590/// Postprocess yolo split end to end det by reversing axes of boxes,
591/// scores, and flattening the class tensor.
592/// Expected input shapes:
593/// - boxes: (4, N)
594/// - scores: (1, N)
595/// - classes: (1, N)
596#[allow(clippy::type_complexity)]
597pub(crate) fn postprocess_yolo_split_end_to_end_det<'a, 'b, 'c, BOXES, SCORES, CLASS>(
598    boxes: ArrayView2<'a, BOXES>,
599    scores: ArrayView2<'b, SCORES>,
600    classes: &'c ArrayView2<CLASS>,
601) -> Result<
602    (
603        ArrayView2<'a, BOXES>,
604        ArrayView2<'b, SCORES>,
605        ArrayView1<'c, CLASS>,
606    ),
607    crate::DecoderError,
608> {
609    let num_boxes = boxes.shape()[1];
610    if boxes.shape()[0] != 4 {
611        return Err(crate::DecoderError::InvalidShape(format!(
612            "Split end-to-end box_coords must be 4, got {}",
613            boxes.shape()[0]
614        )));
615    }
616
617    if scores.shape()[0] != 1 {
618        return Err(crate::DecoderError::InvalidShape(format!(
619            "Split end-to-end scores num_classes must be 1, got {}",
620            scores.shape()[0]
621        )));
622    }
623
624    if classes.shape()[0] != 1 {
625        return Err(crate::DecoderError::InvalidShape(format!(
626            "Split end-to-end classes num_classes must be 1, got {}",
627            classes.shape()[0]
628        )));
629    }
630
631    if scores.shape()[1] != num_boxes {
632        return Err(crate::DecoderError::InvalidShape(format!(
633            "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
634            num_boxes,
635            scores.shape()[1]
636        )));
637    }
638
639    if classes.shape()[1] != num_boxes {
640        return Err(crate::DecoderError::InvalidShape(format!(
641            "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
642            num_boxes,
643            classes.shape()[1]
644        )));
645    }
646
647    let boxes = boxes.reversed_axes();
648    let scores = scores.reversed_axes();
649    let classes = classes.slice(s![0, ..]);
650    Ok((boxes, scores, classes))
651}
652
653/// Postprocess yolo split end to end segdet by reversing axes of boxes,
654/// scores, mask tensors and flattening the class tensor.
655#[allow(clippy::type_complexity)]
656pub(crate) fn postprocess_yolo_split_end_to_end_segdet<
657    'a,
658    'b,
659    'c,
660    'd,
661    BOXES,
662    SCORES,
663    CLASS,
664    MASK,
665>(
666    boxes: ArrayView2<'a, BOXES>,
667    scores: ArrayView2<'b, SCORES>,
668    classes: &'c ArrayView2<CLASS>,
669    mask_coeff: ArrayView2<'d, MASK>,
670) -> Result<
671    (
672        ArrayView2<'a, BOXES>,
673        ArrayView2<'b, SCORES>,
674        ArrayView1<'c, CLASS>,
675        ArrayView2<'d, MASK>,
676    ),
677    crate::DecoderError,
678> {
679    let num_boxes = boxes.shape()[1];
680    if boxes.shape()[0] != 4 {
681        return Err(crate::DecoderError::InvalidShape(format!(
682            "Split end-to-end box_coords must be 4, got {}",
683            boxes.shape()[0]
684        )));
685    }
686
687    if scores.shape()[0] != 1 {
688        return Err(crate::DecoderError::InvalidShape(format!(
689            "Split end-to-end scores num_classes must be 1, got {}",
690            scores.shape()[0]
691        )));
692    }
693
694    if classes.shape()[0] != 1 {
695        return Err(crate::DecoderError::InvalidShape(format!(
696            "Split end-to-end classes num_classes must be 1, got {}",
697            classes.shape()[0]
698        )));
699    }
700
701    if scores.shape()[1] != num_boxes {
702        return Err(crate::DecoderError::InvalidShape(format!(
703            "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
704            num_boxes,
705            scores.shape()[1]
706        )));
707    }
708
709    if classes.shape()[1] != num_boxes {
710        return Err(crate::DecoderError::InvalidShape(format!(
711            "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
712            num_boxes,
713            classes.shape()[1]
714        )));
715    }
716
717    if mask_coeff.shape()[1] != num_boxes {
718        return Err(crate::DecoderError::InvalidShape(format!(
719            "Split end-to-end mask_coeff must have same num_boxes as boxes ({}), got {}",
720            num_boxes,
721            mask_coeff.shape()[1]
722        )));
723    }
724
725    let boxes = boxes.reversed_axes();
726    let scores = scores.reversed_axes();
727    let classes = classes.slice(s![0, ..]);
728    let mask_coeff = mask_coeff.reversed_axes();
729    Ok((boxes, scores, classes, mask_coeff))
730}
731/// Internal implementation of YOLO decoding for quantized tensors.
732///
733/// Expected shapes of inputs:
734/// - output: (4 + num_classes, num_boxes)
735pub(crate) fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
736    output: (ArrayView2<T>, Quantization),
737    score_threshold: f32,
738    iou_threshold: f32,
739    nms: Option<Nms>,
740    output_boxes: &mut Vec<DetectBox>,
741) where
742    f32: AsPrimitive<T>,
743{
744    let _span = tracing::trace_span!("decode", mode = "quant_det").entered();
745    let (boxes, quant_boxes) = output;
746    let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
747
748    let boxes = {
749        let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
750        postprocess_boxes_quant::<B, _, _>(
751            score_threshold,
752            boxes_tensor,
753            scores_tensor,
754            quant_boxes,
755        )
756    };
757
758    let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
759    let len = output_boxes.capacity().min(boxes.len());
760    output_boxes.clear();
761    for b in boxes.iter().take(len) {
762        output_boxes.push(dequant_detect_box(b, quant_boxes));
763    }
764}
765
766/// Internal implementation of YOLO decoding for float tensors.
767///
768/// Expected shapes of inputs:
769/// - output: (4 + num_classes, num_boxes)
770pub(crate) fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
771    output: ArrayView2<T>,
772    score_threshold: f32,
773    iou_threshold: f32,
774    nms: Option<Nms>,
775    output_boxes: &mut Vec<DetectBox>,
776) where
777    f32: AsPrimitive<T>,
778{
779    let _span = tracing::trace_span!("decode", mode = "float_det").entered();
780    let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
781    let boxes =
782        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
783    let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
784    let len = output_boxes.capacity().min(boxes.len());
785    output_boxes.clear();
786    for b in boxes.into_iter().take(len) {
787        output_boxes.push(b);
788    }
789}
790
791/// Internal implementation of YOLO split detection decoding for quantized
792/// tensors.
793///
794/// Expected shapes of inputs:
795/// - boxes: (4, num_boxes)
796/// - scores: (num_classes, num_boxes)
797///
798/// # Panics
799/// Panics if shapes don't match the expected dimensions.
800pub(crate) fn impl_yolo_split_quant<
801    B: BBoxTypeTrait,
802    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
803    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
804>(
805    boxes: (ArrayView2<BOX>, Quantization),
806    scores: (ArrayView2<SCORE>, Quantization),
807    score_threshold: f32,
808    iou_threshold: f32,
809    nms: Option<Nms>,
810    output_boxes: &mut Vec<DetectBox>,
811) where
812    f32: AsPrimitive<SCORE>,
813{
814    let _span = tracing::trace_span!("decode", mode = "split_quant_det").entered();
815    let (boxes_tensor, quant_boxes) = boxes;
816    let (scores_tensor, quant_scores) = scores;
817
818    let boxes_tensor = boxes_tensor.reversed_axes();
819    let scores_tensor = scores_tensor.reversed_axes();
820
821    let boxes = {
822        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
823        postprocess_boxes_quant::<B, _, _>(
824            score_threshold,
825            boxes_tensor,
826            scores_tensor,
827            quant_boxes,
828        )
829    };
830
831    let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
832    let len = output_boxes.capacity().min(boxes.len());
833    output_boxes.clear();
834    for b in boxes.iter().take(len) {
835        output_boxes.push(dequant_detect_box(b, quant_scores));
836    }
837}
838
839/// Internal implementation of YOLO split detection decoding for float tensors.
840///
841/// Expected shapes of inputs:
842/// - boxes: (4, num_boxes)
843/// - scores: (num_classes, num_boxes)
844///
845/// # Panics
846/// Panics if shapes don't match the expected dimensions.
847pub(crate) fn impl_yolo_split_float<
848    B: BBoxTypeTrait,
849    BOX: Float + AsPrimitive<f32> + Send + Sync,
850    SCORE: Float + AsPrimitive<f32> + Send + Sync,
851>(
852    boxes_tensor: ArrayView2<BOX>,
853    scores_tensor: ArrayView2<SCORE>,
854    score_threshold: f32,
855    iou_threshold: f32,
856    nms: Option<Nms>,
857    output_boxes: &mut Vec<DetectBox>,
858) where
859    f32: AsPrimitive<SCORE>,
860{
861    let _span = tracing::trace_span!("decode", mode = "split_float_det").entered();
862    let boxes_tensor = boxes_tensor.reversed_axes();
863    let scores_tensor = scores_tensor.reversed_axes();
864    let boxes =
865        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
866    let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
867    let len = output_boxes.capacity().min(boxes.len());
868    output_boxes.clear();
869    for b in boxes.into_iter().take(len) {
870        output_boxes.push(b);
871    }
872}
873
874/// Internal implementation of YOLO detection segmentation decoding for
875/// quantized tensors.
876///
877/// Expected shapes of inputs:
878/// - boxes: (4 + num_classes + num_protos, num_boxes)
879/// - protos: (proto_height, proto_width, num_protos)
880///
881/// # Errors
882/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
883#[allow(clippy::too_many_arguments)]
884pub(crate) fn impl_yolo_segdet_quant<
885    B: BBoxTypeTrait,
886    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
887    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
888>(
889    boxes: (ArrayView2<BOX>, Quantization),
890    protos: (ArrayView3<PROTO>, Quantization),
891    score_threshold: f32,
892    iou_threshold: f32,
893    nms: Option<Nms>,
894    pre_nms_top_k: usize,
895    max_det: usize,
896    output_boxes: &mut Vec<DetectBox>,
897    output_masks: &mut Vec<Segmentation>,
898) -> Result<(), crate::DecoderError>
899where
900    f32: AsPrimitive<BOX>,
901{
902    let (boxes, quant_boxes) = boxes;
903    let num_protos = protos.0.dim().2;
904
905    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
906    let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
907        (boxes_tensor, quant_boxes),
908        (scores_tensor, quant_boxes),
909        score_threshold,
910        iou_threshold,
911        nms,
912        pre_nms_top_k,
913        max_det,
914    );
915
916    impl_yolo_split_segdet_quant_process_masks::<_, _>(
917        boxes,
918        (mask_tensor, quant_boxes),
919        protos,
920        output_boxes,
921        output_masks,
922    )
923}
924
925/// Internal implementation of YOLO detection segmentation decoding for
926/// float tensors.
927///
928/// Expected shapes of inputs:
929/// - boxes: (4 + num_classes + num_protos, num_boxes)
930/// - protos: (proto_height, proto_width, num_protos)
931///
932/// # Panics
933/// Panics if shapes don't match the expected dimensions.
934#[allow(clippy::too_many_arguments)]
935pub(crate) fn impl_yolo_segdet_float<
936    B: BBoxTypeTrait,
937    BOX: Float + AsPrimitive<f32> + Send + Sync,
938    PROTO: Float + AsPrimitive<f32> + Send + Sync,
939>(
940    boxes: ArrayView2<BOX>,
941    protos: ArrayView3<PROTO>,
942    score_threshold: f32,
943    iou_threshold: f32,
944    nms: Option<Nms>,
945    pre_nms_top_k: usize,
946    max_det: usize,
947    output_boxes: &mut Vec<DetectBox>,
948    output_masks: &mut Vec<Segmentation>,
949) -> Result<(), crate::DecoderError>
950where
951    f32: AsPrimitive<BOX>,
952{
953    let num_protos = protos.dim().2;
954    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
955    let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
956        boxes_tensor,
957        scores_tensor,
958        score_threshold,
959        iou_threshold,
960        nms,
961        pre_nms_top_k,
962        max_det,
963    );
964    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
965}
966
967pub(crate) fn impl_yolo_segdet_get_boxes<
968    B: BBoxTypeTrait,
969    BOX: Float + AsPrimitive<f32> + Send + Sync,
970    SCORE: Float + AsPrimitive<f32> + Send + Sync,
971>(
972    boxes_tensor: ArrayView2<BOX>,
973    scores_tensor: ArrayView2<SCORE>,
974    score_threshold: f32,
975    iou_threshold: f32,
976    nms: Option<Nms>,
977    pre_nms_top_k: usize,
978    max_det: usize,
979) -> Vec<(DetectBox, usize)>
980where
981    f32: AsPrimitive<SCORE>,
982{
983    let span = tracing::trace_span!(
984        "decode",
985        n_candidates = tracing::field::Empty,
986        n_after_topk = tracing::field::Empty,
987        n_after_nms = tracing::field::Empty,
988        n_detections = tracing::field::Empty,
989    );
990    let _guard = span.enter();
991
992    let mut boxes = {
993        let _s = tracing::trace_span!("score_filter").entered();
994        postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor)
995    };
996    span.record("n_candidates", boxes.len());
997
998    if nms.is_some() {
999        let _s = tracing::trace_span!("top_k", k = pre_nms_top_k).entered();
1000        truncate_to_top_k_by_score(&mut boxes, pre_nms_top_k);
1001    }
1002    span.record("n_after_topk", boxes.len());
1003
1004    let mut boxes = {
1005        let _s = tracing::trace_span!("nms").entered();
1006        dispatch_nms_extra_float(nms, iou_threshold, boxes)
1007    };
1008    span.record("n_after_nms", boxes.len());
1009
1010    boxes.sort_unstable_by(|a, b| b.0.score.total_cmp(&a.0.score));
1011    boxes.truncate(max_det);
1012    span.record("n_detections", boxes.len());
1013
1014    boxes
1015}
1016
1017pub(crate) fn impl_yolo_end_to_end_segdet_get_boxes<
1018    B: BBoxTypeTrait,
1019    BOX: Float + AsPrimitive<f32> + Send + Sync,
1020    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1021    CLASS: AsPrimitive<f32> + Send + Sync,
1022>(
1023    boxes: ArrayView2<BOX>,
1024    scores: ArrayView2<SCORE>,
1025    classes: ArrayView1<CLASS>,
1026    score_threshold: f32,
1027    max_boxes: usize,
1028) -> Vec<(DetectBox, usize)>
1029where
1030    f32: AsPrimitive<SCORE>,
1031{
1032    let mut boxes = postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes, scores);
1033    boxes.truncate(max_boxes);
1034    for (b, ind) in &mut boxes {
1035        b.label = classes[*ind].as_().round() as usize;
1036    }
1037    boxes
1038}
1039
1040pub(crate) fn impl_yolo_split_segdet_process_masks<
1041    MASK: Float + AsPrimitive<f32> + Send + Sync,
1042    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1043>(
1044    mut boxes: Vec<(DetectBox, usize)>,
1045    masks_tensor: ArrayView2<MASK>,
1046    protos_tensor: ArrayView3<PROTO>,
1047    output_boxes: &mut Vec<DetectBox>,
1048    output_masks: &mut Vec<Segmentation>,
1049) -> Result<(), crate::DecoderError> {
1050    let _span = tracing::trace_span!("process_masks", n = boxes.len(), mode = "float").entered();
1051    // Respect output capacity as a hard limit to avoid computing excess masks.
1052    let cap = output_boxes.capacity();
1053    if cap > 0 && boxes.len() > cap {
1054        boxes.truncate(cap);
1055    }
1056
1057    let boxes = decode_segdet_f32(boxes, masks_tensor, protos_tensor)?;
1058    output_boxes.clear();
1059    output_masks.clear();
1060    for (b, m) in boxes.into_iter() {
1061        output_boxes.push(b);
1062        output_masks.push(Segmentation {
1063            xmin: b.bbox.xmin,
1064            ymin: b.bbox.ymin,
1065            xmax: b.bbox.xmax,
1066            ymax: b.bbox.ymax,
1067            segmentation: m,
1068        });
1069    }
1070    Ok(())
1071}
1072/// Expected input shapes:
1073/// - boxes_tensor: (num_boxes, 4)
1074/// - scores_tensor: (num_boxes, num_classes)
1075pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
1076    B: BBoxTypeTrait,
1077    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1078    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1079>(
1080    boxes: (ArrayView2<BOX>, Quantization),
1081    scores: (ArrayView2<SCORE>, Quantization),
1082    score_threshold: f32,
1083    iou_threshold: f32,
1084    nms: Option<Nms>,
1085    pre_nms_top_k: usize,
1086    max_det: usize,
1087) -> Vec<(DetectBox, usize)>
1088where
1089    f32: AsPrimitive<SCORE>,
1090{
1091    let (boxes_tensor, quant_boxes) = boxes;
1092    let (scores_tensor, quant_scores) = scores;
1093
1094    let span = tracing::trace_span!(
1095        "decode",
1096        n_candidates = tracing::field::Empty,
1097        n_after_topk = tracing::field::Empty,
1098        n_after_nms = tracing::field::Empty,
1099        n_detections = tracing::field::Empty,
1100    );
1101    let _guard = span.enter();
1102
1103    let mut boxes = {
1104        let _s = tracing::trace_span!("score_filter").entered();
1105        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
1106        postprocess_boxes_index_quant::<B, _, _>(
1107            score_threshold,
1108            boxes_tensor,
1109            scores_tensor,
1110            quant_boxes,
1111        )
1112    };
1113    span.record("n_candidates", boxes.len());
1114
1115    if nms.is_some() {
1116        let _s = tracing::trace_span!("top_k", k = pre_nms_top_k).entered();
1117        truncate_to_top_k_by_score_quant(&mut boxes, pre_nms_top_k);
1118    }
1119    span.record("n_after_topk", boxes.len());
1120
1121    let mut boxes = {
1122        let _s = tracing::trace_span!("nms").entered();
1123        dispatch_nms_extra_int(nms, iou_threshold, boxes)
1124    };
1125    span.record("n_after_nms", boxes.len());
1126
1127    // Sort by score descending before truncation to keep top-scoring detections.
1128    boxes.sort_unstable_by(|a, b| b.0.score.cmp(&a.0.score));
1129    boxes.truncate(max_det);
1130    let result: Vec<_> = {
1131        let _s = tracing::trace_span!("box_dequant", n = boxes.len()).entered();
1132        boxes
1133            .into_iter()
1134            .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
1135            .collect()
1136    };
1137    span.record("n_detections", result.len());
1138
1139    result
1140}
1141
1142pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
1143    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1144    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1145>(
1146    mut boxes: Vec<(DetectBox, usize)>,
1147    mask_coeff: (ArrayView2<MASK>, Quantization),
1148    protos: (ArrayView3<PROTO>, Quantization),
1149    output_boxes: &mut Vec<DetectBox>,
1150    output_masks: &mut Vec<Segmentation>,
1151) -> Result<(), crate::DecoderError> {
1152    let _span = tracing::trace_span!("process_masks", n = boxes.len(), mode = "quant").entered();
1153    let (masks, quant_masks) = mask_coeff;
1154    let (protos, quant_protos) = protos;
1155
1156    // Respect output capacity as a hard limit to avoid computing excess masks.
1157    let cap = output_boxes.capacity();
1158    if cap > 0 && boxes.len() > cap {
1159        boxes.truncate(cap);
1160    }
1161
1162    let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
1163    output_boxes.clear();
1164    output_masks.clear();
1165    for (b, m) in boxes.into_iter() {
1166        output_boxes.push(b);
1167        output_masks.push(Segmentation {
1168            xmin: b.bbox.xmin,
1169            ymin: b.bbox.ymin,
1170            xmax: b.bbox.xmax,
1171            ymax: b.bbox.ymax,
1172            segmentation: m,
1173        });
1174    }
1175    Ok(())
1176}
1177
1178#[allow(clippy::too_many_arguments)]
1179/// Internal implementation of YOLO split detection segmentation decoding for
1180/// quantized tensors.
1181///
1182/// Expected shapes of inputs:
1183/// - boxes_tensor: (4, num_boxes)
1184/// - scores_tensor: (num_classes, num_boxes)
1185/// - mask_tensor: (num_protos, num_boxes)
1186/// - protos: (proto_height, proto_width, num_protos)
1187///
1188/// # Errors
1189/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
1190pub(crate) fn impl_yolo_split_segdet_quant<
1191    B: BBoxTypeTrait,
1192    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1193    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1194    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1195    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1196>(
1197    boxes: (ArrayView2<BOX>, Quantization),
1198    scores: (ArrayView2<SCORE>, Quantization),
1199    mask_coeff: (ArrayView2<MASK>, Quantization),
1200    protos: (ArrayView3<PROTO>, Quantization),
1201    score_threshold: f32,
1202    iou_threshold: f32,
1203    nms: Option<Nms>,
1204    output_boxes: &mut Vec<DetectBox>,
1205    output_masks: &mut Vec<Segmentation>,
1206) -> Result<(), crate::DecoderError>
1207where
1208    f32: AsPrimitive<SCORE>,
1209{
1210    let (boxes_, scores_, mask_coeff_) =
1211        postprocess_yolo_split_segdet(boxes.0, scores.0, mask_coeff.0);
1212    let boxes = (boxes_, boxes.1);
1213    let scores = (scores_, scores.1);
1214    let mask_coeff = (mask_coeff_, mask_coeff.1);
1215
1216    let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1217        boxes,
1218        scores,
1219        score_threshold,
1220        iou_threshold,
1221        nms,
1222        MAX_NMS_CANDIDATES,
1223        output_boxes.capacity(),
1224    );
1225
1226    impl_yolo_split_segdet_quant_process_masks(
1227        boxes,
1228        mask_coeff,
1229        protos,
1230        output_boxes,
1231        output_masks,
1232    )
1233}
1234
1235#[allow(clippy::too_many_arguments)]
1236/// Internal implementation of YOLO split detection segmentation decoding for
1237/// float tensors.
1238///
1239/// Expected shapes of inputs:
1240/// - boxes_tensor: (4, num_boxes)
1241/// - scores_tensor: (num_classes, num_boxes)
1242/// - mask_tensor: (num_protos, num_boxes)
1243/// - protos: (proto_height, proto_width, num_protos)
1244///
1245/// # Errors
1246/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
1247pub(crate) fn impl_yolo_split_segdet_float<
1248    B: BBoxTypeTrait,
1249    BOX: Float + AsPrimitive<f32> + Send + Sync,
1250    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1251    MASK: Float + AsPrimitive<f32> + Send + Sync,
1252    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1253>(
1254    boxes_tensor: ArrayView2<BOX>,
1255    scores_tensor: ArrayView2<SCORE>,
1256    mask_tensor: ArrayView2<MASK>,
1257    protos: ArrayView3<PROTO>,
1258    score_threshold: f32,
1259    iou_threshold: f32,
1260    nms: Option<Nms>,
1261    output_boxes: &mut Vec<DetectBox>,
1262    output_masks: &mut Vec<Segmentation>,
1263) -> Result<(), crate::DecoderError>
1264where
1265    f32: AsPrimitive<SCORE>,
1266{
1267    let (boxes_tensor, scores_tensor, mask_tensor) =
1268        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1269
1270    let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1271        boxes_tensor,
1272        scores_tensor,
1273        score_threshold,
1274        iou_threshold,
1275        nms,
1276        MAX_NMS_CANDIDATES,
1277        output_boxes.capacity(),
1278    );
1279    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1280}
1281
1282// ---------------------------------------------------------------------------
1283// Proto-extraction variants: return ProtoData instead of materialized masks
1284// ---------------------------------------------------------------------------
1285
1286/// Proto-extraction variant of `impl_yolo_segdet_quant`.
1287/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1288#[allow(clippy::too_many_arguments)]
1289pub fn impl_yolo_segdet_quant_proto<
1290    B: BBoxTypeTrait,
1291    BOX: PrimInt
1292        + AsPrimitive<i64>
1293        + AsPrimitive<i128>
1294        + AsPrimitive<f32>
1295        + AsPrimitive<i8>
1296        + Send
1297        + Sync,
1298    PROTO: PrimInt
1299        + AsPrimitive<i64>
1300        + AsPrimitive<i128>
1301        + AsPrimitive<f32>
1302        + AsPrimitive<i8>
1303        + Send
1304        + Sync,
1305>(
1306    boxes: (ArrayView2<BOX>, Quantization),
1307    protos: (ArrayView3<PROTO>, Quantization),
1308    score_threshold: f32,
1309    iou_threshold: f32,
1310    nms: Option<Nms>,
1311    pre_nms_top_k: usize,
1312    max_det: usize,
1313    output_boxes: &mut Vec<DetectBox>,
1314) -> ProtoData
1315where
1316    f32: AsPrimitive<BOX>,
1317{
1318    let (boxes_arr, quant_boxes) = boxes;
1319    let (protos_arr, quant_protos) = protos;
1320    let num_protos = protos_arr.dim().2;
1321
1322    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1323
1324    let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1325        (boxes_tensor, quant_boxes),
1326        (scores_tensor, quant_boxes),
1327        score_threshold,
1328        iou_threshold,
1329        nms,
1330        pre_nms_top_k,
1331        max_det,
1332    );
1333
1334    extract_proto_data_quant(
1335        det_indices,
1336        mask_tensor,
1337        quant_boxes,
1338        protos_arr,
1339        quant_protos,
1340        output_boxes,
1341    )
1342}
1343
1344/// Proto-extraction variant of `impl_yolo_segdet_float`.
1345/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1346#[allow(clippy::too_many_arguments)]
1347pub(crate) fn impl_yolo_segdet_float_proto<
1348    B: BBoxTypeTrait,
1349    BOX: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1350    PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1351>(
1352    boxes: ArrayView2<BOX>,
1353    protos: ArrayView3<PROTO>,
1354    score_threshold: f32,
1355    iou_threshold: f32,
1356    nms: Option<Nms>,
1357    pre_nms_top_k: usize,
1358    max_det: usize,
1359    output_boxes: &mut Vec<DetectBox>,
1360) -> ProtoData
1361where
1362    f32: AsPrimitive<BOX>,
1363{
1364    let num_protos = protos.dim().2;
1365    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1366
1367    let boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1368        boxes_tensor,
1369        scores_tensor,
1370        score_threshold,
1371        iou_threshold,
1372        nms,
1373        pre_nms_top_k,
1374        max_det,
1375    );
1376
1377    extract_proto_data_float(boxes, mask_tensor, protos, output_boxes)
1378}
1379
1380/// Proto-extraction variant of `impl_yolo_split_segdet_float`.
1381/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1382#[allow(clippy::too_many_arguments)]
1383pub(crate) fn impl_yolo_split_segdet_float_proto<
1384    B: BBoxTypeTrait,
1385    BOX: Float + AsPrimitive<f32> + Send + Sync,
1386    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1387    MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1388    PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1389>(
1390    boxes_tensor: ArrayView2<BOX>,
1391    scores_tensor: ArrayView2<SCORE>,
1392    mask_tensor: ArrayView2<MASK>,
1393    protos: ArrayView3<PROTO>,
1394    score_threshold: f32,
1395    iou_threshold: f32,
1396    nms: Option<Nms>,
1397    pre_nms_top_k: usize,
1398    max_det: usize,
1399    output_boxes: &mut Vec<DetectBox>,
1400) -> ProtoData
1401where
1402    f32: AsPrimitive<SCORE>,
1403{
1404    let (boxes_tensor, scores_tensor, mask_tensor) =
1405        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1406    let det_indices = impl_yolo_segdet_get_boxes::<B, _, _>(
1407        boxes_tensor,
1408        scores_tensor,
1409        score_threshold,
1410        iou_threshold,
1411        nms,
1412        pre_nms_top_k,
1413        max_det,
1414    );
1415
1416    extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1417}
1418
1419/// Proto-extraction variant of `decode_yolo_end_to_end_segdet_float`.
1420pub fn decode_yolo_end_to_end_segdet_float_proto<T>(
1421    output: ArrayView2<T>,
1422    protos: ArrayView3<T>,
1423    score_threshold: f32,
1424    output_boxes: &mut Vec<DetectBox>,
1425) -> Result<ProtoData, crate::DecoderError>
1426where
1427    T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1428    f32: AsPrimitive<T>,
1429{
1430    let (boxes, scores, classes, mask_coeff) =
1431        postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
1432    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1433        boxes,
1434        scores,
1435        classes,
1436        score_threshold,
1437        output_boxes.capacity(),
1438    );
1439
1440    Ok(extract_proto_data_float(
1441        boxes,
1442        mask_coeff,
1443        protos,
1444        output_boxes,
1445    ))
1446}
1447
1448/// Proto-extraction variant of `decode_yolo_split_end_to_end_segdet_float`.
1449#[allow(clippy::too_many_arguments)]
1450pub fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1451    boxes: ArrayView2<T>,
1452    scores: ArrayView2<T>,
1453    classes: ArrayView2<T>,
1454    mask_coeff: ArrayView2<T>,
1455    protos: ArrayView3<T>,
1456    score_threshold: f32,
1457    output_boxes: &mut Vec<DetectBox>,
1458) -> Result<ProtoData, crate::DecoderError>
1459where
1460    T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1461    f32: AsPrimitive<T>,
1462{
1463    let (boxes, scores, classes, mask_coeff) =
1464        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
1465    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1466        boxes,
1467        scores,
1468        classes,
1469        score_threshold,
1470        output_boxes.capacity(),
1471    );
1472
1473    Ok(extract_proto_data_float(
1474        boxes,
1475        mask_coeff,
1476        protos,
1477        output_boxes,
1478    ))
1479}
1480
1481/// Helper: extract ProtoData from float mask coefficients + protos.
1482///
1483/// Builds [`ProtoData`] with both `protos` and `mask_coefficients` as
1484/// [`edgefirst_tensor::TensorDyn`]. Preserves the native element type for
1485/// `f16` and `f32`; narrows `f64` to `f32` (there is no native f64 kernel
1486/// path). `mask_coefficients` shape is `[num_detections, num_protos]`.
1487pub(super) fn extract_proto_data_float<
1488    MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1489    PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1490>(
1491    det_indices: Vec<(DetectBox, usize)>,
1492    mask_tensor: ArrayView2<MASK>,
1493    protos: ArrayView3<PROTO>,
1494    output_boxes: &mut Vec<DetectBox>,
1495) -> ProtoData {
1496    let _span = tracing::trace_span!(
1497        "extract_proto",
1498        n = det_indices.len(),
1499        num_protos = mask_tensor.ncols(),
1500        layout = "nhwc",
1501    )
1502    .entered();
1503
1504    let num_protos = mask_tensor.ncols();
1505    let n = det_indices.len();
1506
1507    // Per-detection coefficients packed row-major into a contiguous buffer,
1508    // preserving the source dtype. Shape: [N, num_protos] — N=0 is permitted
1509    // (tracker path emits no detections this frame) since Mem-backed tensors
1510    // accept zero-element shapes as "empty collection" sentinels.
1511    let mut coeff_rows: Vec<MASK> = Vec::with_capacity(n * num_protos);
1512    output_boxes.clear();
1513    for (det, idx) in det_indices {
1514        output_boxes.push(det);
1515        let row = mask_tensor.row(idx);
1516        coeff_rows.extend(row.iter().copied());
1517    }
1518
1519    let mask_coefficients = MASK::slice_into_tensor_dyn(&coeff_rows, &[n, num_protos])
1520        .expect("allocating mask_coefficients TensorDyn");
1521    let protos_tensor =
1522        PROTO::arrayview3_into_tensor_dyn(protos).expect("allocating protos TensorDyn");
1523
1524    ProtoData {
1525        mask_coefficients,
1526        protos: protos_tensor,
1527        layout: ProtoLayout::Nhwc,
1528    }
1529}
1530
1531/// Helper: extract ProtoData from quantized mask coefficients + protos.
1532///
1533/// Dequantizes mask coefficients to f32 at extraction (one-time cost on a
1534/// `num_detections * num_protos` slice) and keeps protos in raw i8,
1535/// attaching the dequantization params as
1536/// [`edgefirst_tensor::Quantization::per_tensor`] metadata on the proto
1537/// tensor. The GPU shader / CPU kernel reads `protos.quantization()` and
1538/// dequantizes per-texel.
1539pub(crate) fn extract_proto_data_quant<
1540    MASK: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync,
1541    PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
1542>(
1543    det_indices: Vec<(DetectBox, usize)>,
1544    mask_tensor: ArrayView2<MASK>,
1545    quant_masks: Quantization,
1546    protos: ArrayView3<PROTO>,
1547    quant_protos: Quantization,
1548    output_boxes: &mut Vec<DetectBox>,
1549) -> ProtoData {
1550    use edgefirst_tensor::{Tensor, TensorDyn, TensorMapTrait, TensorMemory, TensorTrait};
1551
1552    let span = tracing::trace_span!(
1553        "extract_proto",
1554        n = det_indices.len(),
1555        num_protos = tracing::field::Empty,
1556        layout = tracing::field::Empty,
1557    );
1558    let _guard = span.enter();
1559
1560    let num_protos = mask_tensor.ncols();
1561    let n = det_indices.len();
1562    span.record("num_protos", num_protos);
1563
1564    // Fast path: when no detections survive NMS, skip the expensive proto
1565    // tensor copy (819KB for 160×160×32). Allocate with the correct shape
1566    // (preserving the documented ProtoData.protos shape contract) but skip
1567    // copying from the source tensor — the zeroed allocation is sufficient
1568    // since materialize_masks early-returns on empty detect slices.
1569    if n == 0 {
1570        output_boxes.clear();
1571        let (h, w, k) = protos.dim();
1572
1573        // Detect physical layout (same logic as the normal path).
1574        let (proto_shape, proto_layout) = if std::any::TypeId::of::<PROTO>()
1575            == std::any::TypeId::of::<i8>()
1576        {
1577            if protos.is_standard_layout() {
1578                (&[h, w, k][..], ProtoLayout::Nhwc)
1579            } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1580                (&[k, h, w][..], ProtoLayout::Nchw)
1581            } else {
1582                (&[h, w, k][..], ProtoLayout::Nhwc)
1583            }
1584        } else {
1585            (&[h, w, k][..], ProtoLayout::Nhwc)
1586        };
1587
1588        let coeff_tensor = Tensor::<i8>::new(&[0, num_protos], Some(TensorMemory::Mem), None)
1589            .expect("allocating empty mask_coefficients tensor");
1590        let coeff_quant =
1591            edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1592        let coeff_tensor = coeff_tensor
1593            .with_quantization(coeff_quant)
1594            .expect("per-tensor quantization on mask coefficients");
1595        let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
1596            .expect("allocating protos tensor");
1597        let tensor_quant =
1598            edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1599        let protos_tensor = protos_tensor
1600            .with_quantization(tensor_quant)
1601            .expect("per-tensor quantization on protos tensor");
1602        return ProtoData {
1603            mask_coefficients: TensorDyn::I8(coeff_tensor),
1604            protos: TensorDyn::I8(protos_tensor),
1605            layout: proto_layout,
1606        };
1607    }
1608
1609    // Keep mask coefficients in raw i8 with quantization metadata attached.
1610    // Consumers that need f32 can dequantize on the fly; the scaled-path
1611    // integer kernel uses raw i8 directly for i8×i8→i32 dot products.
1612    let mut coeff_i8 = Vec::<i8>::with_capacity(n * num_protos);
1613    output_boxes.clear();
1614    for (det, idx) in det_indices {
1615        output_boxes.push(det);
1616        let row = mask_tensor.row(idx);
1617        coeff_i8.extend(row.iter().map(|v| {
1618            let v_i8: i8 = v.as_();
1619            v_i8
1620        }));
1621    }
1622
1623    // Shape `[n, num_protos]` with n=0 is permitted (tracker path emits no
1624    // fresh detections this frame) via the Mem-backed zero-size allowance.
1625    let coeff_tensor = Tensor::<i8>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
1626        .expect("allocating mask_coefficients tensor");
1627    if n > 0 {
1628        let mut m = coeff_tensor
1629            .map()
1630            .expect("mapping mask_coefficients tensor");
1631        m.as_mut_slice().copy_from_slice(&coeff_i8);
1632    }
1633    let coeff_quant =
1634        edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1635    let coeff_tensor = coeff_tensor
1636        .with_quantization(coeff_quant)
1637        .expect("per-tensor quantization on mask coefficients");
1638    let mask_coefficients = TensorDyn::I8(coeff_tensor);
1639
1640    // Keep protos in raw i8 — consumers dequantize via protos.quantization().
1641    // When PROTO is already i8, detect layout and copy efficiently without
1642    // transposing. The mask materialisation kernels dispatch on the layout.
1643    let (h, w, k) = protos.dim();
1644
1645    // Determine physical layout and copy strategy.
1646    let (proto_shape, proto_layout) =
1647        if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1648            if protos.is_standard_layout() {
1649                // Already NHWC [H, W, K] in contiguous memory.
1650                (&[h, w, k][..], ProtoLayout::Nhwc)
1651            } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1652                // NCHW reinterpreted as NHWC via stride swap. Physical storage
1653                // is [K, H, W] contiguous. Keep in NCHW — eliminates the costly
1654                // 3.1ms transpose entirely.
1655                (&[k, h, w][..], ProtoLayout::Nchw)
1656            } else {
1657                // Unknown layout — fall back to iter copy as NHWC.
1658                (&[h, w, k][..], ProtoLayout::Nhwc)
1659            }
1660        } else {
1661            (&[h, w, k][..], ProtoLayout::Nhwc)
1662        };
1663
1664    let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
1665        .expect("allocating protos tensor");
1666    {
1667        let mut m = protos_tensor.map().expect("mapping protos tensor");
1668        let dst = m.as_mut_slice();
1669        if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1670            // SAFETY: PROTO == i8 checked via TypeId; cast slice view is
1671            // size/alignment-compatible by construction.
1672            if protos.is_standard_layout() {
1673                let src: &[i8] = unsafe {
1674                    std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len())
1675                };
1676                dst.copy_from_slice(src);
1677            } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1678                // NCHW physical layout — sequential copy WITHOUT transpose.
1679                // This saves ~3.1ms on A53/A55 by avoiding the tiled
1680                // NCHW→NHWC transpose of the 819KB proto buffer.
1681                let total = h * w * k;
1682                // SAFETY: ArrayView was constructed from a contiguous slice of
1683                // `total` elements. as_ptr() points to the base of that slice.
1684                let src: &[i8] =
1685                    unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, total) };
1686                dst.copy_from_slice(src);
1687            } else {
1688                for (d, s) in dst.iter_mut().zip(protos.iter()) {
1689                    let v_i8: i8 = s.as_();
1690                    *d = v_i8;
1691                }
1692            }
1693        } else {
1694            for (d, s) in dst.iter_mut().zip(protos.iter()) {
1695                let v_i8: i8 = s.as_();
1696                *d = v_i8;
1697            }
1698        }
1699    }
1700    let tensor_quant =
1701        edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1702    let protos_tensor = protos_tensor
1703        .with_quantization(tensor_quant)
1704        .expect("per-tensor quantization on new Tensor<i8>");
1705
1706    span.record("layout", tracing::field::debug(&proto_layout));
1707
1708    ProtoData {
1709        mask_coefficients,
1710        protos: TensorDyn::I8(protos_tensor),
1711        layout: proto_layout,
1712    }
1713}
1714
1715/// Per-float-dtype construction of a [`TensorDyn`] from a flat slice / 3-D
1716/// `ArrayView`. Replaces the old `IntoProtoTensor` trait. Each implementor
1717/// either passes its element type straight to `Tensor::from_slice` /
1718/// `Tensor::from_arrayview3`, or narrows `f64` to `f32` (no native f64 kernel
1719/// path exists).
1720pub trait FloatProtoElem: Copy + 'static {
1721    fn slice_into_tensor_dyn(
1722        values: &[Self],
1723        shape: &[usize],
1724    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;
1725
1726    fn arrayview3_into_tensor_dyn(
1727        view: ArrayView3<'_, Self>,
1728    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;
1729}
1730
1731impl FloatProtoElem for f32 {
1732    fn slice_into_tensor_dyn(
1733        values: &[f32],
1734        shape: &[usize],
1735    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1736        edgefirst_tensor::Tensor::<f32>::from_slice(values, shape)
1737            .map(edgefirst_tensor::TensorDyn::F32)
1738    }
1739    fn arrayview3_into_tensor_dyn(
1740        view: ArrayView3<'_, f32>,
1741    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1742        edgefirst_tensor::Tensor::<f32>::from_arrayview3(view).map(edgefirst_tensor::TensorDyn::F32)
1743    }
1744}
1745
1746impl FloatProtoElem for half::f16 {
1747    fn slice_into_tensor_dyn(
1748        values: &[half::f16],
1749        shape: &[usize],
1750    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1751        edgefirst_tensor::Tensor::<half::f16>::from_slice(values, shape)
1752            .map(edgefirst_tensor::TensorDyn::F16)
1753    }
1754    fn arrayview3_into_tensor_dyn(
1755        view: ArrayView3<'_, half::f16>,
1756    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1757        edgefirst_tensor::Tensor::<half::f16>::from_arrayview3(view)
1758            .map(edgefirst_tensor::TensorDyn::F16)
1759    }
1760}
1761
1762impl FloatProtoElem for f64 {
1763    fn slice_into_tensor_dyn(
1764        values: &[f64],
1765        shape: &[usize],
1766    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1767        // Narrow to f32 — no native f64 kernel path.
1768        let narrowed: Vec<f32> = values.iter().map(|&v| v as f32).collect();
1769        edgefirst_tensor::Tensor::<f32>::from_slice(&narrowed, shape)
1770            .map(edgefirst_tensor::TensorDyn::F32)
1771    }
1772    fn arrayview3_into_tensor_dyn(
1773        view: ArrayView3<'_, f64>,
1774    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1775        let narrowed: ndarray::Array3<f32> = view.mapv(|v| v as f32);
1776        edgefirst_tensor::Tensor::<f32>::from_arrayview3(narrowed.view())
1777            .map(edgefirst_tensor::TensorDyn::F32)
1778    }
1779}
1780
1781fn postprocess_yolo<'a, T>(
1782    output: &'a ArrayView2<'_, T>,
1783) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1784    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1785    let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1786    (boxes_tensor, scores_tensor)
1787}
1788
1789pub(crate) fn postprocess_yolo_seg<'a, T>(
1790    output: &'a ArrayView2<'_, T>,
1791    num_protos: usize,
1792) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1793    assert!(
1794        output.shape()[0] > num_protos + 4,
1795        "Output shape is too short: {} <= {} + 4",
1796        output.shape()[0],
1797        num_protos
1798    );
1799    let num_classes = output.shape()[0] - 4 - num_protos;
1800    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1801    let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1802    let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1803    (boxes_tensor, scores_tensor, mask_tensor)
1804}
1805
1806pub(crate) fn postprocess_yolo_split_segdet<'a, 'b, 'c, BOX, SCORE, MASK>(
1807    boxes_tensor: ArrayView2<'a, BOX>,
1808    scores_tensor: ArrayView2<'b, SCORE>,
1809    mask_tensor: ArrayView2<'c, MASK>,
1810) -> (
1811    ArrayView2<'a, BOX>,
1812    ArrayView2<'b, SCORE>,
1813    ArrayView2<'c, MASK>,
1814) {
1815    let boxes_tensor = boxes_tensor.reversed_axes();
1816    let scores_tensor = scores_tensor.reversed_axes();
1817    let mask_tensor = mask_tensor.reversed_axes();
1818    (boxes_tensor, scores_tensor, mask_tensor)
1819}
1820
1821fn decode_segdet_f32<
1822    MASK: Float + AsPrimitive<f32> + Send + Sync,
1823    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1824>(
1825    boxes: Vec<(DetectBox, usize)>,
1826    masks: ArrayView2<MASK>,
1827    protos: ArrayView3<PROTO>,
1828) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1829    if boxes.is_empty() {
1830        return Ok(Vec::new());
1831    }
1832    if masks.shape()[1] != protos.shape()[2] {
1833        return Err(crate::DecoderError::InvalidShape(format!(
1834            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1835            masks.shape()[1],
1836            protos.shape()[2],
1837        )));
1838    }
1839    boxes
1840        .into_par_iter()
1841        .map(|mut b| {
1842            let ind = b.1;
1843            let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1844            b.0.bbox = roi;
1845            Ok((b.0, make_segmentation(masks.row(ind), protos.view())))
1846        })
1847        .collect()
1848}
1849
1850pub(crate) fn decode_segdet_quant<
1851    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1852    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1853>(
1854    boxes: Vec<(DetectBox, usize)>,
1855    masks: ArrayView2<MASK>,
1856    protos: ArrayView3<PROTO>,
1857    quant_masks: Quantization,
1858    quant_protos: Quantization,
1859) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
1860    if boxes.is_empty() {
1861        return Ok(Vec::new());
1862    }
1863    if masks.shape()[1] != protos.shape()[2] {
1864        return Err(crate::DecoderError::InvalidShape(format!(
1865            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1866            masks.shape()[1],
1867            protos.shape()[2],
1868        )));
1869    }
1870
1871    let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; // 32 protos is 2^5
1872    boxes
1873        .into_iter()
1874        .map(|mut b| {
1875            let i = b.1;
1876            let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1877            b.0.bbox = roi;
1878            let seg = match total_bits {
1879                0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1880                    masks.row(i),
1881                    protos.view(),
1882                    quant_masks,
1883                    quant_protos,
1884                ),
1885                65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1886                    masks.row(i),
1887                    protos.view(),
1888                    quant_masks,
1889                    quant_protos,
1890                ),
1891                _ => {
1892                    return Err(crate::DecoderError::NotSupported(format!(
1893                        "Unsupported bit width ({total_bits}) for segmentation computation"
1894                    )));
1895                }
1896            };
1897            Ok((b.0, seg))
1898        })
1899        .collect()
1900}
1901
1902fn protobox<'a, T>(
1903    protos: &'a ArrayView3<T>,
1904    roi: &BoundingBox,
1905) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1906    let width = protos.dim().1 as f32;
1907    let height = protos.dim().0 as f32;
1908
1909    // Detect un-normalized bounding boxes (pixel-space coordinates).
1910    // protobox expects normalized coordinates in [0, 1]. ONNX models output
1911    // pixel-space boxes (e.g. 0-640) which must be normalized before calling
1912    // decode(). Without this check, pixel-space coordinates silently clamp to
1913    // the proto boundary, producing empty (0, 0, C) masks for every detection.
1914    //
1915    // The limit is set to 2.0 (not 1.01) because YOLO models legitimately
1916    // predict coordinates slightly > 1.0 for objects near frame edges.
1917    // Any value > 2.0 is clearly pixel-space (even the smallest practical
1918    // model input of 32×32 would produce coordinates >> 2.0).
1919    const NORM_LIMIT: f32 = 2.0;
1920    if roi.xmin > NORM_LIMIT
1921        || roi.ymin > NORM_LIMIT
1922        || roi.xmax > NORM_LIMIT
1923        || roi.ymax > NORM_LIMIT
1924    {
1925        return Err(crate::DecoderError::InvalidShape(format!(
1926            "Bounding box coordinates appear un-normalized (pixel-space). \
1927             Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
1928             ONNX models output pixel-space boxes — normalize them by dividing by \
1929             the input dimensions before calling decode().",
1930            roi.xmin, roi.ymin, roi.xmax, roi.ymax,
1931        )));
1932    }
1933
1934    let roi = [
1935        (roi.xmin * width).clamp(0.0, width) as usize,
1936        (roi.ymin * height).clamp(0.0, height) as usize,
1937        (roi.xmax * width).clamp(0.0, width).ceil() as usize,
1938        (roi.ymax * height).clamp(0.0, height).ceil() as usize,
1939    ];
1940
1941    let roi_norm = [
1942        roi[0] as f32 / width,
1943        roi[1] as f32 / height,
1944        roi[2] as f32 / width,
1945        roi[3] as f32 / height,
1946    ]
1947    .into();
1948
1949    let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
1950
1951    Ok((cropped, roi_norm))
1952}
1953
1954/// Compute a single instance segmentation mask from mask coefficients and
1955/// proto maps (float path).
1956///
1957/// Computes `sigmoid(coefficients · protos)` and maps to `[0, 255]`.
1958/// Returns an `(H, W, 1)` u8 array.
1959fn make_segmentation<
1960    MASK: Float + AsPrimitive<f32> + Send + Sync,
1961    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1962>(
1963    mask: ArrayView1<MASK>,
1964    protos: ArrayView3<PROTO>,
1965) -> Array3<u8> {
1966    let shape = protos.shape();
1967
1968    // Safe to unwrap since the shapes will always be compatible
1969    let mask = mask.to_shape((1, mask.len())).unwrap();
1970    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1971    let protos = protos.reversed_axes();
1972    let mask = mask.map(|x| x.as_());
1973    let protos = protos.map(|x| x.as_());
1974
1975    // Safe to unwrap since the shapes will always be compatible
1976    let mask = mask
1977        .dot(&protos)
1978        .into_shape_with_order((shape[0], shape[1], 1))
1979        .unwrap();
1980
1981    mask.map(|x| {
1982        let sigmoid = 1.0 / (1.0 + (-*x).exp());
1983        (sigmoid * 255.0).round() as u8
1984    })
1985}
1986
1987/// Compute a single instance segmentation mask from quantized mask
1988/// coefficients and proto maps.
1989///
1990/// Dequantizes both inputs (subtracting zero-points), computes the dot
1991/// product, applies sigmoid, and maps to `[0, 255]`.
1992/// Returns an `(H, W, 1)` u8 array.
1993fn make_segmentation_quant<
1994    MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
1995    PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
1996    DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
1997>(
1998    mask: ArrayView1<MASK>,
1999    protos: ArrayView3<PROTO>,
2000    quant_masks: Quantization,
2001    quant_protos: Quantization,
2002) -> Array3<u8>
2003where
2004    i32: AsPrimitive<DEST>,
2005    f32: AsPrimitive<DEST>,
2006{
2007    let shape = protos.shape();
2008
2009    // Safe to unwrap since the shapes will always be compatible
2010    let mask = mask.to_shape((1, mask.len())).unwrap();
2011
2012    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
2013    let protos = protos.reversed_axes();
2014
2015    let zp = quant_masks.zero_point.as_();
2016
2017    let mask = mask.mapv(|x| x.as_() - zp);
2018
2019    let zp = quant_protos.zero_point.as_();
2020    let protos = protos.mapv(|x| x.as_() - zp);
2021
2022    // Safe to unwrap since the shapes will always be compatible
2023    let segmentation = mask
2024        .dot(&protos)
2025        .into_shape_with_order((shape[0], shape[1], 1))
2026        .unwrap();
2027
2028    let combined_scale = quant_masks.scale * quant_protos.scale;
2029    segmentation.map(|x| {
2030        let val: f32 = (*x).as_() * combined_scale;
2031        let sigmoid = 1.0 / (1.0 + (-val).exp());
2032        (sigmoid * 255.0).round() as u8
2033    })
2034}
2035
2036/// Converts Yolo Instance Segmentation into a 2D mask.
2037///
2038/// The input segmentation is expected to have shape (H, W, 1).
2039///
2040/// The output mask will have shape (H, W), with values 0 or 1 based on the
2041/// threshold.
2042///
2043/// # Errors
2044///
2045/// Returns `DecoderError::InvalidShape` if the input segmentation does not
2046/// have shape (H, W, 1).
2047pub fn yolo_segmentation_to_mask(
2048    segmentation: ArrayView3<u8>,
2049    threshold: u8,
2050) -> Result<Array2<u8>, crate::DecoderError> {
2051    if segmentation.shape()[2] != 1 {
2052        return Err(crate::DecoderError::InvalidShape(format!(
2053            "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
2054            segmentation.shape()[2]
2055        )));
2056    }
2057    Ok(segmentation
2058        .slice(s![.., .., 0])
2059        .map(|x| if *x >= threshold { 1 } else { 0 }))
2060}
2061
2062#[cfg(test)]
2063#[cfg_attr(coverage_nightly, coverage(off))]
2064mod tests {
2065    use super::*;
2066    use ndarray::Array2;
2067
2068    // ========================================================================
2069    // Tests for decode_yolo_end_to_end_det_float
2070    // ========================================================================
2071
2072    #[test]
2073    fn test_end_to_end_det_basic_filtering() {
2074        // Create synthetic end-to-end detection output: (6, N) where rows are
2075        // [x1, y1, x2, y2, conf, class]
2076        // 3 detections: one above threshold, two below
2077        let data: Vec<f32> = vec![
2078            // Detection 0: high score (0.9)
2079            0.1, 0.2, 0.3, // x1 values
2080            0.1, 0.2, 0.3, // y1 values
2081            0.5, 0.6, 0.7, // x2 values
2082            0.5, 0.6, 0.7, // y2 values
2083            0.9, 0.1, 0.2, // confidence scores
2084            0.0, 1.0, 2.0, // class indices
2085        ];
2086        let output = Array2::from_shape_vec((6, 3), data).unwrap();
2087
2088        let mut boxes = Vec::with_capacity(10);
2089        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2090
2091        // Only 1 detection should pass threshold of 0.5
2092        assert_eq!(boxes.len(), 1);
2093        assert_eq!(boxes[0].label, 0);
2094        assert!((boxes[0].score - 0.9).abs() < 0.01);
2095        assert!((boxes[0].bbox.xmin - 0.1).abs() < 0.01);
2096        assert!((boxes[0].bbox.ymin - 0.1).abs() < 0.01);
2097        assert!((boxes[0].bbox.xmax - 0.5).abs() < 0.01);
2098        assert!((boxes[0].bbox.ymax - 0.5).abs() < 0.01);
2099    }
2100
2101    #[test]
2102    fn test_end_to_end_det_all_pass_threshold() {
2103        // All detections above threshold
2104        let data: Vec<f32> = vec![
2105            10.0, 20.0, // x1
2106            10.0, 20.0, // y1
2107            50.0, 60.0, // x2
2108            50.0, 60.0, // y2
2109            0.8, 0.7, // conf (both above 0.5)
2110            1.0, 2.0, // class
2111        ];
2112        let output = Array2::from_shape_vec((6, 2), data).unwrap();
2113
2114        let mut boxes = Vec::with_capacity(10);
2115        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2116
2117        assert_eq!(boxes.len(), 2);
2118        assert_eq!(boxes[0].label, 1);
2119        assert_eq!(boxes[1].label, 2);
2120    }
2121
2122    #[test]
2123    fn test_end_to_end_det_none_pass_threshold() {
2124        // All detections below threshold
2125        let data: Vec<f32> = vec![
2126            10.0, 20.0, // x1
2127            10.0, 20.0, // y1
2128            50.0, 60.0, // x2
2129            50.0, 60.0, // y2
2130            0.1, 0.2, // conf (both below 0.5)
2131            1.0, 2.0, // class
2132        ];
2133        let output = Array2::from_shape_vec((6, 2), data).unwrap();
2134
2135        let mut boxes = Vec::with_capacity(10);
2136        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2137
2138        assert_eq!(boxes.len(), 0);
2139    }
2140
2141    #[test]
2142    fn test_end_to_end_det_capacity_limit() {
2143        // Test that output is truncated to capacity
2144        let data: Vec<f32> = vec![
2145            0.1, 0.2, 0.3, 0.4, 0.5, // x1
2146            0.1, 0.2, 0.3, 0.4, 0.5, // y1
2147            0.5, 0.6, 0.7, 0.8, 0.9, // x2
2148            0.5, 0.6, 0.7, 0.8, 0.9, // y2
2149            0.9, 0.9, 0.9, 0.9, 0.9, // conf (all pass)
2150            0.0, 1.0, 2.0, 3.0, 4.0, // class
2151        ];
2152        let output = Array2::from_shape_vec((6, 5), data).unwrap();
2153
2154        let mut boxes = Vec::with_capacity(2); // Only allow 2 boxes
2155        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2156
2157        assert_eq!(boxes.len(), 2);
2158    }
2159
2160    #[test]
2161    fn test_end_to_end_det_empty_output() {
2162        // Test with zero detections
2163        let output = Array2::<f32>::zeros((6, 0));
2164
2165        let mut boxes = Vec::with_capacity(10);
2166        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2167
2168        assert_eq!(boxes.len(), 0);
2169    }
2170
2171    #[test]
2172    fn test_end_to_end_det_pixel_coordinates() {
2173        // Test with pixel coordinates (non-normalized)
2174        let data: Vec<f32> = vec![
2175            100.0, // x1
2176            200.0, // y1
2177            300.0, // x2
2178            400.0, // y2
2179            0.95,  // conf
2180            5.0,   // class
2181        ];
2182        let output = Array2::from_shape_vec((6, 1), data).unwrap();
2183
2184        let mut boxes = Vec::with_capacity(10);
2185        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2186
2187        assert_eq!(boxes.len(), 1);
2188        assert_eq!(boxes[0].label, 5);
2189        assert!((boxes[0].bbox.xmin - 100.0).abs() < 0.01);
2190        assert!((boxes[0].bbox.ymin - 200.0).abs() < 0.01);
2191        assert!((boxes[0].bbox.xmax - 300.0).abs() < 0.01);
2192        assert!((boxes[0].bbox.ymax - 400.0).abs() < 0.01);
2193    }
2194
2195    #[test]
2196    fn test_end_to_end_det_invalid_shape() {
2197        // Test with too few rows (needs at least 6)
2198        let output = Array2::<f32>::zeros((5, 3));
2199
2200        let mut boxes = Vec::with_capacity(10);
2201        let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);
2202
2203        assert!(result.is_err());
2204        assert!(matches!(
2205            result,
2206            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
2207        ));
2208    }
2209
2210    // ========================================================================
2211    // Tests for decode_yolo_end_to_end_segdet_float
2212    // ========================================================================
2213
2214    #[test]
2215    fn test_end_to_end_segdet_basic() {
2216        // Create synthetic segdet output: (6 + num_protos, N)
2217        // Detection format: [x1, y1, x2, y2, conf, class, mask_coeff_0..31]
2218        let num_protos = 32;
2219        let num_detections = 2;
2220        let num_features = 6 + num_protos;
2221
2222        // Build detection tensor
2223        let mut data = vec![0.0f32; num_features * num_detections];
2224        // Detection 0: passes threshold
2225        data[0] = 0.1; // x1[0]
2226        data[1] = 0.5; // x1[1]
2227        data[num_detections] = 0.1; // y1[0]
2228        data[num_detections + 1] = 0.5; // y1[1]
2229        data[2 * num_detections] = 0.4; // x2[0]
2230        data[2 * num_detections + 1] = 0.9; // x2[1]
2231        data[3 * num_detections] = 0.4; // y2[0]
2232        data[3 * num_detections + 1] = 0.9; // y2[1]
2233        data[4 * num_detections] = 0.9; // conf[0] - passes
2234        data[4 * num_detections + 1] = 0.3; // conf[1] - fails
2235        data[5 * num_detections] = 1.0; // class[0]
2236        data[5 * num_detections + 1] = 2.0; // class[1]
2237                                            // Fill mask coefficients with small values
2238        for i in 6..num_features {
2239            data[i * num_detections] = 0.1;
2240            data[i * num_detections + 1] = 0.1;
2241        }
2242
2243        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2244
2245        // Create protos tensor: (proto_height, proto_width, num_protos)
2246        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2247
2248        let mut boxes = Vec::with_capacity(10);
2249        let mut masks = Vec::with_capacity(10);
2250        decode_yolo_end_to_end_segdet_float(
2251            output.view(),
2252            protos.view(),
2253            0.5,
2254            &mut boxes,
2255            &mut masks,
2256        )
2257        .unwrap();
2258
2259        // Only detection 0 should pass
2260        assert_eq!(boxes.len(), 1);
2261        assert_eq!(masks.len(), 1);
2262        assert_eq!(boxes[0].label, 1);
2263        assert!((boxes[0].score - 0.9).abs() < 0.01);
2264    }
2265
2266    #[test]
2267    fn test_end_to_end_segdet_mask_coordinates() {
2268        // Test that mask coordinates match box coordinates
2269        let num_protos = 32;
2270        let num_features = 6 + num_protos;
2271
2272        let mut data = vec![0.0f32; num_features];
2273        data[0] = 0.2; // x1
2274        data[1] = 0.2; // y1
2275        data[2] = 0.8; // x2
2276        data[3] = 0.8; // y2
2277        data[4] = 0.95; // conf
2278        data[5] = 3.0; // class
2279
2280        let output = Array2::from_shape_vec((num_features, 1), data).unwrap();
2281        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2282
2283        let mut boxes = Vec::with_capacity(10);
2284        let mut masks = Vec::with_capacity(10);
2285        decode_yolo_end_to_end_segdet_float(
2286            output.view(),
2287            protos.view(),
2288            0.5,
2289            &mut boxes,
2290            &mut masks,
2291        )
2292        .unwrap();
2293
2294        assert_eq!(boxes.len(), 1);
2295        assert_eq!(masks.len(), 1);
2296
2297        // Verify mask coordinates match box coordinates
2298        assert!((masks[0].xmin - boxes[0].bbox.xmin).abs() < 0.01);
2299        assert!((masks[0].ymin - boxes[0].bbox.ymin).abs() < 0.01);
2300        assert!((masks[0].xmax - boxes[0].bbox.xmax).abs() < 0.01);
2301        assert!((masks[0].ymax - boxes[0].bbox.ymax).abs() < 0.01);
2302    }
2303
2304    #[test]
2305    fn test_end_to_end_segdet_empty_output() {
2306        let num_protos = 32;
2307        let output = Array2::<f32>::zeros((6 + num_protos, 0));
2308        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2309
2310        let mut boxes = Vec::with_capacity(10);
2311        let mut masks = Vec::with_capacity(10);
2312        decode_yolo_end_to_end_segdet_float(
2313            output.view(),
2314            protos.view(),
2315            0.5,
2316            &mut boxes,
2317            &mut masks,
2318        )
2319        .unwrap();
2320
2321        assert_eq!(boxes.len(), 0);
2322        assert_eq!(masks.len(), 0);
2323    }
2324
2325    #[test]
2326    fn test_end_to_end_segdet_capacity_limit() {
2327        let num_protos = 32;
2328        let num_detections = 5;
2329        let num_features = 6 + num_protos;
2330
2331        let mut data = vec![0.0f32; num_features * num_detections];
2332        // All detections pass threshold
2333        for i in 0..num_detections {
2334            data[i] = 0.1 * (i as f32); // x1
2335            data[num_detections + i] = 0.1 * (i as f32); // y1
2336            data[2 * num_detections + i] = 0.1 * (i as f32) + 0.2; // x2
2337            data[3 * num_detections + i] = 0.1 * (i as f32) + 0.2; // y2
2338            data[4 * num_detections + i] = 0.9; // conf
2339            data[5 * num_detections + i] = i as f32; // class
2340        }
2341
2342        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2343        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2344
2345        let mut boxes = Vec::with_capacity(2); // Limit to 2
2346        let mut masks = Vec::with_capacity(2);
2347        decode_yolo_end_to_end_segdet_float(
2348            output.view(),
2349            protos.view(),
2350            0.5,
2351            &mut boxes,
2352            &mut masks,
2353        )
2354        .unwrap();
2355
2356        assert_eq!(boxes.len(), 2);
2357        assert_eq!(masks.len(), 2);
2358    }
2359
2360    #[test]
2361    fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
2362        // Test with too few rows (needs at least 7: 6 base + 1 mask coeff)
2363        let output = Array2::<f32>::zeros((6, 3));
2364        let protos = Array3::<f32>::zeros((16, 16, 32));
2365
2366        let mut boxes = Vec::with_capacity(10);
2367        let mut masks = Vec::with_capacity(10);
2368        let result = decode_yolo_end_to_end_segdet_float(
2369            output.view(),
2370            protos.view(),
2371            0.5,
2372            &mut boxes,
2373            &mut masks,
2374        );
2375
2376        assert!(result.is_err());
2377        assert!(matches!(
2378            result,
2379            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
2380        ));
2381    }
2382
2383    #[test]
2384    fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
2385        // Test with mismatched mask coefficients and protos count
2386        let num_protos = 32;
2387        let output = Array2::<f32>::zeros((6 + 16, 3)); // 16 mask coeffs
2388        let protos = Array3::<f32>::zeros((16, 16, num_protos)); // 32 protos
2389
2390        let mut boxes = Vec::with_capacity(10);
2391        let mut masks = Vec::with_capacity(10);
2392        let result = decode_yolo_end_to_end_segdet_float(
2393            output.view(),
2394            protos.view(),
2395            0.5,
2396            &mut boxes,
2397            &mut masks,
2398        );
2399
2400        assert!(result.is_err());
2401        assert!(matches!(
2402            result,
2403            Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
2404        ));
2405    }
2406
2407    // ========================================================================
2408    // Tests for decode_yolo_split_end_to_end_segdet_float
2409    // ========================================================================
2410
2411    #[test]
2412    fn test_split_end_to_end_segdet_basic() {
2413        // Create synthetic segdet output: (6 + num_protos, N)
2414        // Detection format: [x1, y1, x2, y2, conf, class, mask_coeff_0..31]
2415        let num_protos = 32;
2416        let num_detections = 2;
2417        let num_features = 6 + num_protos;
2418
2419        // Build detection tensor
2420        let mut data = vec![0.0f32; num_features * num_detections];
2421        // Detection 0: passes threshold
2422        data[0] = 0.1; // x1[0]
2423        data[1] = 0.5; // x1[1]
2424        data[num_detections] = 0.1; // y1[0]
2425        data[num_detections + 1] = 0.5; // y1[1]
2426        data[2 * num_detections] = 0.4; // x2[0]
2427        data[2 * num_detections + 1] = 0.9; // x2[1]
2428        data[3 * num_detections] = 0.4; // y2[0]
2429        data[3 * num_detections + 1] = 0.9; // y2[1]
2430        data[4 * num_detections] = 0.9; // conf[0] - passes
2431        data[4 * num_detections + 1] = 0.3; // conf[1] - fails
2432        data[5 * num_detections] = 1.0; // class[0]
2433        data[5 * num_detections + 1] = 2.0; // class[1]
2434                                            // Fill mask coefficients with small values
2435        for i in 6..num_features {
2436            data[i * num_detections] = 0.1;
2437            data[i * num_detections + 1] = 0.1;
2438        }
2439
2440        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2441        let box_coords = output.slice(s![..4, ..]);
2442        let scores = output.slice(s![4..5, ..]);
2443        let classes = output.slice(s![5..6, ..]);
2444        let mask_coeff = output.slice(s![6.., ..]);
2445        // Create protos tensor: (proto_height, proto_width, num_protos)
2446        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2447
2448        let mut boxes = Vec::with_capacity(10);
2449        let mut masks = Vec::with_capacity(10);
2450        decode_yolo_split_end_to_end_segdet_float(
2451            box_coords,
2452            scores,
2453            classes,
2454            mask_coeff,
2455            protos.view(),
2456            0.5,
2457            &mut boxes,
2458            &mut masks,
2459        )
2460        .unwrap();
2461
2462        // Only detection 0 should pass
2463        assert_eq!(boxes.len(), 1);
2464        assert_eq!(masks.len(), 1);
2465        assert_eq!(boxes[0].label, 1);
2466        assert!((boxes[0].score - 0.9).abs() < 0.01);
2467    }
2468
2469    // ========================================================================
2470    // Tests for yolo_segmentation_to_mask
2471    // ========================================================================
2472
2473    #[test]
2474    fn test_segmentation_to_mask_basic() {
2475        // Create a 4x4x1 segmentation with values above and below threshold
2476        let data: Vec<u8> = vec![
2477            100, 200, 50, 150, // row 0
2478            10, 255, 128, 64, // row 1
2479            0, 127, 128, 255, // row 2
2480            64, 64, 192, 192, // row 3
2481        ];
2482        let segmentation = Array3::from_shape_vec((4, 4, 1), data).unwrap();
2483
2484        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2485
2486        // Values >= 128 should be 1, others 0
2487        assert_eq!(mask[[0, 0]], 0); // 100 < 128
2488        assert_eq!(mask[[0, 1]], 1); // 200 >= 128
2489        assert_eq!(mask[[0, 2]], 0); // 50 < 128
2490        assert_eq!(mask[[0, 3]], 1); // 150 >= 128
2491        assert_eq!(mask[[1, 1]], 1); // 255 >= 128
2492        assert_eq!(mask[[1, 2]], 1); // 128 >= 128
2493        assert_eq!(mask[[2, 0]], 0); // 0 < 128
2494        assert_eq!(mask[[2, 1]], 0); // 127 < 128
2495    }
2496
2497    #[test]
2498    fn test_segmentation_to_mask_all_above() {
2499        let segmentation = Array3::from_elem((4, 4, 1), 255u8);
2500        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2501        assert!(mask.iter().all(|&x| x == 1));
2502    }
2503
2504    #[test]
2505    fn test_segmentation_to_mask_all_below() {
2506        let segmentation = Array3::from_elem((4, 4, 1), 64u8);
2507        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2508        assert!(mask.iter().all(|&x| x == 0));
2509    }
2510
2511    #[test]
2512    fn test_segmentation_to_mask_invalid_shape() {
2513        let segmentation = Array3::from_elem((4, 4, 3), 128u8);
2514        let result = yolo_segmentation_to_mask(segmentation.view(), 128);
2515
2516        assert!(result.is_err());
2517        assert!(matches!(
2518            result,
2519            Err(crate::DecoderError::InvalidShape(s)) if s.contains("(H, W, 1)")
2520        ));
2521    }
2522
2523    // ========================================================================
2524    // Tests for protobox / NORM_LIMIT regression
2525    // ========================================================================
2526
2527    #[test]
2528    fn test_protobox_clamps_edge_coordinates() {
2529        // bbox with xmax=1.0 should not panic (OOB guard)
2530        let protos = Array3::<f32>::zeros((16, 16, 4));
2531        let view = protos.view();
2532        let roi = BoundingBox {
2533            xmin: 0.5,
2534            ymin: 0.5,
2535            xmax: 1.0,
2536            ymax: 1.0,
2537        };
2538        let result = protobox(&view, &roi);
2539        assert!(result.is_ok(), "protobox should accept xmax=1.0");
2540        let (cropped, _roi_norm) = result.unwrap();
2541        // Cropped region must have non-zero spatial dimensions
2542        assert!(cropped.shape()[0] > 0);
2543        assert!(cropped.shape()[1] > 0);
2544        assert_eq!(cropped.shape()[2], 4);
2545    }
2546
2547    #[test]
2548    fn test_protobox_rejects_wildly_out_of_range() {
2549        // bbox with coords > NORM_LIMIT (e.g. 3.0) returns error
2550        let protos = Array3::<f32>::zeros((16, 16, 4));
2551        let view = protos.view();
2552        let roi = BoundingBox {
2553            xmin: 0.0,
2554            ymin: 0.0,
2555            xmax: 3.0,
2556            ymax: 3.0,
2557        };
2558        let result = protobox(&view, &roi);
2559        assert!(
2560            matches!(result, Err(crate::DecoderError::InvalidShape(s)) if s.contains("un-normalized")),
2561            "protobox should reject coords > NORM_LIMIT"
2562        );
2563    }
2564
2565    #[test]
2566    fn test_protobox_accepts_slightly_over_one() {
2567        // bbox with coords at 1.5 (within NORM_LIMIT=2.0) succeeds
2568        let protos = Array3::<f32>::zeros((16, 16, 4));
2569        let view = protos.view();
2570        let roi = BoundingBox {
2571            xmin: 0.0,
2572            ymin: 0.0,
2573            xmax: 1.5,
2574            ymax: 1.5,
2575        };
2576        let result = protobox(&view, &roi);
2577        assert!(
2578            result.is_ok(),
2579            "protobox should accept coords <= NORM_LIMIT (2.0)"
2580        );
2581        let (cropped, _roi_norm) = result.unwrap();
2582        // Entire proto map should be selected when coords > 1.0 (clamped to boundary)
2583        assert_eq!(cropped.shape()[0], 16);
2584        assert_eq!(cropped.shape()[1], 16);
2585    }
2586
2587    #[test]
2588    fn test_segdet_float_proto_no_panic() {
2589        // Simulates YOLOv8n-seg: output0 = [116, 8400] (4 box + 80 class + 32 mask coeff)
2590        // output1 (protos) = [32, 160, 160]
2591        let num_proposals = 100; // enough to produce idx >= 32
2592        let num_classes = 80;
2593        let num_mask_coeffs = 32;
2594        let rows = 4 + num_classes + num_mask_coeffs; // 116
2595
2596        // Fill boxes with valid xywh data so some detections pass the threshold.
2597        // Layout is [116, num_proposals] row-major: row 0=cx, 1=cy, 2=w, 3=h,
2598        // rows 4..84=class scores, rows 84..116=mask coefficients.
2599        let mut data = vec![0.0f32; rows * num_proposals];
2600        for i in 0..num_proposals {
2601            let row = |r: usize| r * num_proposals + i;
2602            data[row(0)] = 320.0; // cx
2603            data[row(1)] = 320.0; // cy
2604            data[row(2)] = 50.0; // w
2605            data[row(3)] = 50.0; // h
2606            data[row(4)] = 0.9; // class-0 score
2607        }
2608        let boxes = ndarray::Array2::from_shape_vec((rows, num_proposals), data).unwrap();
2609
2610        // Protos must be in HWC order. Under the HAL physical-order
2611        // contract, callers declare shape+dshape matching producer memory
2612        // and swap_axes_if_needed permutes the stride tuple into canonical
2613        // [batch, height, width, num_protos] before this function sees it.
2614        let protos = ndarray::Array3::<f32>::zeros((160, 160, num_mask_coeffs));
2615
2616        let mut output_boxes = Vec::with_capacity(300);
2617
2618        // This panicked before fix: mask_tensor.row(idx) with idx >= 32
2619        let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
2620            boxes.view(),
2621            protos.view(),
2622            0.5,
2623            0.7,
2624            Some(Nms::default()),
2625            MAX_NMS_CANDIDATES,
2626            300,
2627            &mut output_boxes,
2628        );
2629
2630        // Should produce detections (NMS will collapse many overlapping boxes)
2631        assert!(!output_boxes.is_empty());
2632        let coeffs_shape = proto_data.mask_coefficients.shape();
2633        assert_eq!(coeffs_shape[0], output_boxes.len());
2634        // Each mask coefficient vector should have 32 elements
2635        assert_eq!(coeffs_shape[1], num_mask_coeffs);
2636    }
2637
2638    // ========================================================================
2639    // Pre-NMS top-K cap (MAX_NMS_CANDIDATES)
2640    // ========================================================================
2641
2642    /// At very low score thresholds (e.g., t=0.01 on YOLOv8 with 8400×80
2643    /// candidates) almost every score passes the filter, feeding O(n²)
2644    /// NMS and a per-survivor mask matmul. The decoder caps the
2645    /// candidate set fed to NMS at `MAX_NMS_CANDIDATES` (Ultralytics
2646    /// default 30 000) to bound worst-case decode time.
2647    ///
2648    /// This regression test pumps 50 000 above-threshold candidates
2649    /// into `impl_yolo_segdet_get_boxes` with NMS bypassed (Nms=None)
2650    /// and a generous post-NMS cap. Before the fix, the function
2651    /// returned all 50 000; after the fix, exactly 30 000.
2652    #[test]
2653    fn test_pre_nms_cap_truncates_excess_candidates() {
2654        let n: usize = 50_000;
2655        let num_classes = 1;
2656
2657        // Identical valid boxes. Distinct scores (descending) so the
2658        // top-K cap keeps the highest-scoring ones in deterministic
2659        // order — letting us assert *which* ones survived.
2660        let mut boxes_data = Vec::with_capacity(n * 4);
2661        let mut scores_data = Vec::with_capacity(n * num_classes);
2662        for i in 0..n {
2663            boxes_data.extend_from_slice(&[0.1f32, 0.1, 0.5, 0.5]);
2664            // score_i = 0.99 - i * 1e-7 keeps everything well above 0.1
2665            // threshold but strictly decreasing.
2666            scores_data.push(0.99 - (i as f32) * 1e-7);
2667        }
2668        let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
2669        let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();
2670
2671        let result = impl_yolo_segdet_get_boxes::<XYXY, _, _>(
2672            boxes.view(),
2673            scores.view(),
2674            0.1,
2675            1.0,                             // IoU 1.0 → NMS suppresses nothing
2676            Some(Nms::ClassAgnostic),        // NMS enabled so pre_nms_top_k applies
2677            crate::yolo::MAX_NMS_CANDIDATES, // pre_nms_top_k
2678            usize::MAX,                      // no post-NMS truncation
2679        );
2680
2681        assert_eq!(
2682            result.len(),
2683            crate::yolo::MAX_NMS_CANDIDATES,
2684            "pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
2685            result.len()
2686        );
2687        // Top-K survivors: highest scores were the first n indices,
2688        // so survivor 0 must have score ~0.99.
2689        let top_score = result[0].0.score;
2690        assert!(
2691            top_score > 0.98,
2692            "highest-ranked survivor should have the largest score, got {top_score}"
2693        );
2694    }
2695
2696    /// Counterpart for the quantized split path. Same contract: feed
2697    /// more than `MAX_NMS_CANDIDATES` survivors above the quantized
2698    /// threshold, confirm `impl_yolo_split_segdet_quant_get_boxes`
2699    /// truncates before NMS.
2700    #[test]
2701    fn test_pre_nms_cap_truncates_excess_candidates_quant() {
2702        use crate::Quantization;
2703        let n: usize = 50_000;
2704        let num_classes = 1;
2705
2706        // i8 boxes with simple scale/zp; the box value 50 dequantizes
2707        // to 0.5 with scale=0.01, zp=0 — fine for a flat box set.
2708        let boxes_data = (0..n).flat_map(|_| [10i8, 10, 50, 50]).collect::<Vec<_>>();
2709        let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
2710        let quant_boxes = Quantization {
2711            scale: 0.01,
2712            zero_point: 0,
2713        };
2714
2715        // u8 scores: distinct descending values, all well above threshold.
2716        // value 250 → 0.98 with scale 0.00392, zp 0.
2717        // value (250 - i % 200) keeps a wide spread above the dequant
2718        // threshold of 0.5.
2719        let scores_data: Vec<u8> = (0..n)
2720            .map(|i| 250u8.saturating_sub((i % 200) as u8))
2721            .collect();
2722        let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();
2723        let quant_scores = Quantization {
2724            scale: 0.00392,
2725            zero_point: 0,
2726        };
2727
2728        let result = impl_yolo_split_segdet_quant_get_boxes::<XYXY, _, _>(
2729            (boxes.view(), quant_boxes),
2730            (scores.view(), quant_scores),
2731            0.1,
2732            1.0,                             // IoU 1.0 → NMS suppresses nothing
2733            Some(Nms::ClassAgnostic),        // NMS enabled so pre_nms_top_k applies
2734            crate::yolo::MAX_NMS_CANDIDATES, // pre_nms_top_k
2735            usize::MAX,                      // no post-NMS truncation
2736        );
2737
2738        assert_eq!(
2739            result.len(),
2740            crate::yolo::MAX_NMS_CANDIDATES,
2741            "quant path pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
2742            result.len()
2743        );
2744    }
2745
2746    /// Regression test for HAILORT_BUG.md — the YoloSegDet path
2747    /// (combined `(4 + nc + nm, N)` detection tensor + separate protos)
2748    /// must pair each surviving detection with the mask coefficient
2749    /// row at the SAME anchor index the box came from. The validator
2750    /// sees this path miss the pairing under schema-v2 Hailo inputs
2751    /// (mAP collapse from 46.8 → 3.65 while mask IoU stays at 66.9,
2752    /// the fingerprint of mask-to-detection misalignment).
2753    ///
2754    /// Construction: three anchors with distinct mask-coef signatures
2755    /// that, after dot(coefs, protos) + sigmoid, produce HIGH vs LOW
2756    /// mask pixel values. Two anchors survive (one high, one low); if
2757    /// the mask row is looked up at the wrong index, the per-detection
2758    /// mean mask value would cross the threshold and we catch it.
2759    #[test]
2760    fn segdet_combined_tensor_pairs_detection_with_matching_mask_row() {
2761        let nc = 2; // num_classes
2762        let nm = 2; // num_protos
2763        let n = 3; // num_anchors
2764        let feat = 4 + nc + nm; // 8
2765
2766        // Tensor layout: (8, 3) rows=features, cols=anchors.
2767        // Row indices:  0..4 = xywh, 4..6 = scores, 6..8 = mask_coefs.
2768        //
2769        //         anchor 0 | anchor 1 | anchor 2
2770        // xc       0.2      | 0.5      | 0.8
2771        // yc       0.2      | 0.5      | 0.8
2772        // w        0.1      | 0.1      | 0.1
2773        // h        0.1      | 0.1      | 0.1
2774        // s[0]     0.9      | 0.0      | 0.8   (class 0)
2775        // s[1]     0.0      | 0.0      | 0.0   (class 1 — always loses)
2776        // m[0]     3.0      | 0.0      | -3.0  (high for a0, low for a2)
2777        // m[1]     3.0      | 0.0      | -3.0
2778        //
2779        // Proto[0] = Proto[1] = all-ones (8x8), so
2780        //   mask(a0) = sigmoid(3 + 3) ≈ 0.9975 → 254
2781        //   mask(a2) = sigmoid(-3 + -3) ≈ 0.0025 → 1
2782        // 250-point gap makes any misalignment trivially detectable.
2783        let mut data = vec![0.0f32; feat * n];
2784        let set = |d: &mut [f32], r: usize, c: usize, v: f32| d[r * n + c] = v;
2785        set(&mut data, 0, 0, 0.2);
2786        set(&mut data, 1, 0, 0.2);
2787        set(&mut data, 2, 0, 0.1);
2788        set(&mut data, 3, 0, 0.1);
2789        set(&mut data, 0, 1, 0.5);
2790        set(&mut data, 1, 1, 0.5);
2791        set(&mut data, 2, 1, 0.1);
2792        set(&mut data, 3, 1, 0.1);
2793        set(&mut data, 0, 2, 0.8);
2794        set(&mut data, 1, 2, 0.8);
2795        set(&mut data, 2, 2, 0.1);
2796        set(&mut data, 3, 2, 0.1);
2797        set(&mut data, 4, 0, 0.9);
2798        set(&mut data, 4, 2, 0.8);
2799        set(&mut data, 6, 0, 3.0);
2800        set(&mut data, 7, 0, 3.0);
2801        set(&mut data, 6, 2, -3.0);
2802        set(&mut data, 7, 2, -3.0);
2803
2804        let output = Array2::from_shape_vec((feat, n), data).unwrap();
2805        let protos = Array3::<f32>::from_elem((8, 8, nm), 1.0);
2806
2807        let mut boxes: Vec<DetectBox> = Vec::with_capacity(10);
2808        let mut masks: Vec<Segmentation> = Vec::with_capacity(10);
2809        decode_yolo_segdet_float(
2810            output.view(),
2811            protos.view(),
2812            0.5,
2813            0.5,
2814            Some(Nms::ClassAgnostic),
2815            &mut boxes,
2816            &mut masks,
2817        )
2818        .unwrap();
2819
2820        assert_eq!(
2821            boxes.len(),
2822            2,
2823            "two anchors above threshold should survive (a0 score=0.9, a2 score=0.8); got {}",
2824            boxes.len()
2825        );
2826
2827        // Build a (anchor_index → mask_mean) mapping from the results.
2828        // Anchor 0 has centre (0.2, 0.2), anchor 2 has centre (0.8,
2829        // 0.8). The DetectBox bbox is the post-XYWH-to-XYXY conversion
2830        // of the original xywh; cropping inside protobox may shrink it,
2831        // so match by centre (0.2 vs 0.8) rather than exact bbox.
2832        for (b, m) in boxes.iter().zip(masks.iter()) {
2833            let cx = (b.bbox.xmin + b.bbox.xmax) * 0.5;
2834            let mean = {
2835                let s = &m.segmentation;
2836                let total: u32 = s.iter().map(|&v| v as u32).sum();
2837                total as f32 / s.len() as f32
2838            };
2839            if cx < 0.3 {
2840                // anchor 0 — expect HIGH mask values ≈ 254
2841                assert!(
2842                    mean > 200.0,
2843                    "anchor 0 detection (centre {cx:.2}) should have high-value mask; got mean {mean}"
2844                );
2845            } else if cx > 0.7 {
2846                // anchor 2 — expect LOW mask values ≈ 1
2847                assert!(
2848                    mean < 50.0,
2849                    "anchor 2 detection (centre {cx:.2}) should have low-value mask; got mean {mean}"
2850                );
2851            } else {
2852                panic!("unexpected detection centre {cx:.2}");
2853            }
2854        }
2855    }
2856}