Skip to main content

edgefirst_decoder/
yolo.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4//! Internal YOLO decoder kernels.
5//!
6//! All items in this module are `pub(crate)`. The public API surface for
7//! decoding is [`crate::Decoder`] + [`crate::DecoderBuilder`]; external
8//! callers must go through that entry point so we can evolve the kernels
9//! (split paths, dispatch tables, NEON tiers) without breaking semver.
10//! See `CHANGELOG.md` for the 0.20.0 narrowing.
11
12use std::fmt::Debug;
13
14use ndarray::{
15    parallel::prelude::{IntoParallelIterator, ParallelIterator},
16    s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
17};
18use num_traits::{AsPrimitive, Float, PrimInt, Signed};
19
20use crate::{
21    byte::{
22        nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
23        postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
24    },
25    configs::Nms,
26    dequant_detect_box,
27    float::{
28        nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
29        postprocess_boxes_float, postprocess_boxes_index_float,
30    },
31    BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoLayout,
32    Quantization, Segmentation, XYWH, XYXY,
33};
34
/// Upper bound on the number of above-threshold candidates handed to NMS.
///
/// With a very low score threshold (e.g., t=0.01 on YOLOv8 with
/// 8400 anchors × 80 classes) the survivor count approaches the full
/// 672 000-entry score grid. NMS is O(n²) and the downstream mask matmul
/// runs once per survivor, so an unbounded candidate set produces
/// minutes-per-frame decode times.
///
/// The value matches the Ultralytics `max_nms` default and is applied as a
/// top-K-by-score truncation immediately before NMS. Candidates ranked
/// below the cap are silently dropped — at the score thresholds where the
/// cap activates, the bottom of the candidate list is dominated by noise
/// that NMS would discard anyway.
///
/// Production callers configure this via [`crate::DecoderBuilder::with_pre_nms_top_k`];
/// only the test-only seg-det wrappers use this constant directly.
#[cfg(test)]
pub(crate) const MAX_NMS_CANDIDATES: usize = 30_000;

/// Fallback post-NMS detection cap for the free `decode_yolo_*`
/// convenience wrappers when the caller's output `Vec` has zero capacity
/// and therefore carries no explicit cap.
///
/// Mirrors the `Decoder::max_det` default set by `DecoderBuilder` (also 300,
/// matching the Ultralytics `max_det` default). Pre-EDGEAI-1302 these
/// wrappers used `output_boxes.capacity()` directly as the cap, which
/// silently dropped all detections when the caller passed `Vec::new()`.
pub(crate) const DEFAULT_MAX_DETECTIONS: usize = 300;
62
63/// Truncate `boxes` to the highest-scoring `top_k` entries in-place when the
64/// input exceeds the cap. Uses partial sort (O(N)) via `select_nth_unstable_by`
65/// to avoid full O(N log N) sort. No-op when input length ≤ `top_k`.
66fn truncate_to_top_k_by_score<E: Send>(boxes: &mut Vec<(DetectBox, E)>, top_k: usize) {
67    if boxes.len() > top_k {
68        boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.total_cmp(&a.0.score));
69        boxes.truncate(top_k);
70    }
71}
72
73/// Quantized counterpart of [`truncate_to_top_k_by_score`]. Sorts on
74/// the raw quantized score (which preserves order under monotonic
75/// dequantization). Uses partial sort (O(N)) via `select_nth_unstable_by`.
76fn truncate_to_top_k_by_score_quant<S: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send>(
77    boxes: &mut Vec<(DetectBoxQuantized<S>, E)>,
78    top_k: usize,
79) {
80    if boxes.len() > top_k {
81        boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.cmp(&a.0.score));
82        boxes.truncate(top_k);
83    }
84}
85
86/// Dispatches to the appropriate NMS function based on mode for float boxes.
87fn dispatch_nms_float(nms: Option<Nms>, iou: f32, boxes: Vec<DetectBox>) -> Vec<DetectBox> {
88    match nms {
89        Some(Nms::ClassAgnostic) => nms_float(iou, boxes),
90        Some(Nms::ClassAware) => nms_class_aware_float(iou, boxes),
91        None => boxes, // bypass NMS
92    }
93}
94
95/// Dispatches to the appropriate NMS function based on mode for float boxes
96/// with extra data.
97pub(super) fn dispatch_nms_extra_float<E: Send + Sync>(
98    nms: Option<Nms>,
99    iou: f32,
100    boxes: Vec<(DetectBox, E)>,
101) -> Vec<(DetectBox, E)> {
102    match nms {
103        Some(Nms::ClassAgnostic) => nms_extra_float(iou, boxes),
104        Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, boxes),
105        None => boxes, // bypass NMS
106    }
107}
108
109/// Dispatches to the appropriate NMS function based on mode for quantized
110/// boxes.
111fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
112    nms: Option<Nms>,
113    iou: f32,
114    boxes: Vec<DetectBoxQuantized<SCORE>>,
115) -> Vec<DetectBoxQuantized<SCORE>> {
116    match nms {
117        Some(Nms::ClassAgnostic) => nms_int(iou, boxes),
118        Some(Nms::ClassAware) => nms_class_aware_int(iou, boxes),
119        None => boxes, // bypass NMS
120    }
121}
122
123/// Dispatches to the appropriate NMS function based on mode for quantized boxes
124/// with extra data.
125fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
126    nms: Option<Nms>,
127    iou: f32,
128    boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
129) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
130    match nms {
131        Some(Nms::ClassAgnostic) => nms_extra_int(iou, boxes),
132        Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, boxes),
133        None => boxes, // bypass NMS
134    }
135}
136
137/// Detection cap helper for the public free `decode_yolo_*` wrappers.
138///
139/// Encodes the convention documented above: if the caller passed a non-empty
140/// `Vec`, that capacity acts as the per-call cap; otherwise fall back to
141/// [`DEFAULT_MAX_DETECTIONS`] so freshly-constructed `Vec::new()` outputs
142/// don't silently drop every detection.
143#[inline]
144fn cap_or_default<T>(v: &Vec<T>) -> usize {
145    if v.capacity() > 0 {
146        v.capacity()
147    } else {
148        DEFAULT_MAX_DETECTIONS
149    }
150}
151
// ─── Crate-internal free decode_yolo_* convenience wrappers ────────────────
//
// Detection cap convention (applies to every `decode_yolo_*` free function
// below):
156//
157// These functions are the convenience layer for callers that don't go
158// through `Decoder::decode()` (benches, FFI shims, ad-hoc test harnesses).
159// They use **`output_boxes.capacity()` as a per-call detection cap**:
160//
161//   - When the caller passes `Vec::with_capacity(N)`, the post-NMS output
162//     is truncated to at most `N` detections.
163//   - When the caller passes `Vec::new()` (capacity 0), the implementation
164//     falls back to the [`DEFAULT_MAX_DETECTIONS`] constant (300) so a
165//     freshly-constructed `Vec` doesn't silently drop every detection.
166//
167// This is intentionally **different** from the `Decoder::decode()` /
168// `Decoder::decode_proto()` contract, which bounds output count solely
169// by [`Decoder::max_det`] (set via `DecoderBuilder::with_max_det`,
170// default 300) regardless of the caller's `Vec` capacity (EDGEAI-1302).
171//
172// Use the `Decoder` API when you need explicit control over `max_det`,
173// schema-driven decoding, or EDGEAI-1303 normalization. Use these free
174// functions when you have raw tensors in hand and want a one-shot decode.
175
176/// Decodes YOLO detection outputs from quantized tensors into detection boxes.
177///
178/// Boxes are expected to be in XYWH format.
179///
180/// Expected shapes of inputs:
181/// - output: (4 + num_classes, num_boxes)
182///
183/// See the "Detection cap convention" comment above for how
184/// `output_boxes.capacity()` bounds the result count.
185pub(crate) fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
186    output: (ArrayView2<BOX>, Quantization),
187    score_threshold: f32,
188    iou_threshold: f32,
189    nms: Option<Nms>,
190    output_boxes: &mut Vec<DetectBox>,
191) where
192    f32: AsPrimitive<BOX>,
193{
194    impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
195}
196
197/// Decodes YOLO detection outputs from float tensors into detection boxes.
198///
199/// Boxes are expected to be in XYWH format.
200///
201/// Expected shapes of inputs:
202/// - output: (4 + num_classes, num_boxes)
203pub(crate) fn decode_yolo_det_float<T>(
204    output: ArrayView2<T>,
205    score_threshold: f32,
206    iou_threshold: f32,
207    nms: Option<Nms>,
208    output_boxes: &mut Vec<DetectBox>,
209) where
210    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
211    f32: AsPrimitive<T>,
212{
213    impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
214}
215
/// Test-only shim around [`impl_yolo_segdet_quant`] for quantized
/// detection + segmentation decoding.
///
/// Production callers go through [`crate::Decoder::decode`]; this wrapper
/// exists only so the parity tests in `lib.rs` and `yolo.rs` can compare the
/// kernel output against the Decoder output without repeating the kernel's
/// argument plumbing. It fixes the box layout to XYWH, uses
/// [`MAX_NMS_CANDIDATES`] as the pre-NMS top-K, and derives the detection cap
/// from `output_boxes` via [`cap_or_default`].
///
/// Expected shapes:
/// - `boxes`: `(4 + num_classes + num_protos, num_boxes)`
/// - `protos`: `(proto_height, proto_width, num_protos)`
///
/// # Errors
/// Propagates any error returned by [`impl_yolo_segdet_quant`].
#[cfg(test)]
pub(crate) fn decode_yolo_segdet_quant<
    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    protos: (ArrayView3<PROTO>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    f32: AsPrimitive<BOX>,
{
    // This wrapper has no schema, so there is no `normalized` flag and no
    // `input_dims`: the EDGEAI-1303 normalization is a no-op here. Callers
    // that need that path should go through `DecoderBuilder` with a schema.
    let normalized: Option<bool> = None;
    let input_dims: Option<(usize, usize)> = None;
    let max_det = cap_or_default(output_boxes);

    impl_yolo_segdet_quant::<XYWH, BOX, PROTO>(
        boxes,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        MAX_NMS_CANDIDATES,
        max_det,
        normalized,
        input_dims,
        output_boxes,
        output_masks,
    )
}
264
/// Test-only shim around `impl_yolo_segdet_float` for float detection +
/// segmentation decoding; the float sibling of [`decode_yolo_segdet_quant`].
///
/// Fixes the box layout to XYWH, uses [`MAX_NMS_CANDIDATES`] as the pre-NMS
/// top-K, and derives the detection cap from `output_boxes` via
/// [`cap_or_default`].
///
/// # Errors
/// Propagates any error returned by the underlying kernel.
#[cfg(test)]
pub(crate) fn decode_yolo_segdet_float<T>(
    boxes: ArrayView2<T>,
    protos: ArrayView3<T>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    // No schema is available in this wrapper, so the two schema-derived
    // normalization arguments (EDGEAI-1303) are passed as `None`.
    let max_det = cap_or_default(output_boxes);

    impl_yolo_segdet_float::<XYWH, _, _>(
        boxes,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        MAX_NMS_CANDIDATES,
        max_det,
        None,
        None,
        output_boxes,
        output_masks,
    )
}
297
298/// Decodes YOLO split detection outputs from quantized tensors into detection
299/// boxes.
300///
301/// Boxes are expected to be in XYWH format.
302///
303/// Expected shapes of inputs:
304/// - boxes: (4, num_boxes)
305/// - scores: (num_classes, num_boxes)
306///
307/// # Panics
308/// Panics if shapes don't match the expected dimensions.
309pub(crate) fn decode_yolo_split_det_quant<
310    BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
311    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
312>(
313    boxes: (ArrayView2<BOX>, Quantization),
314    scores: (ArrayView2<SCORE>, Quantization),
315    score_threshold: f32,
316    iou_threshold: f32,
317    nms: Option<Nms>,
318    output_boxes: &mut Vec<DetectBox>,
319) where
320    f32: AsPrimitive<SCORE>,
321{
322    impl_yolo_split_quant::<XYWH, _, _>(
323        boxes,
324        scores,
325        score_threshold,
326        iou_threshold,
327        nms,
328        output_boxes,
329    );
330}
331
332/// Decodes YOLO split detection outputs from float tensors into detection
333/// boxes.
334///
335/// Boxes are expected to be in XYWH format.
336///
337/// Expected shapes of inputs:
338/// - boxes: (4, num_boxes)
339/// - scores: (num_classes, num_boxes)
340///
341/// # Panics
342/// Panics if shapes don't match the expected dimensions.
343pub(crate) fn decode_yolo_split_det_float<T>(
344    boxes: ArrayView2<T>,
345    scores: ArrayView2<T>,
346    score_threshold: f32,
347    iou_threshold: f32,
348    nms: Option<Nms>,
349    output_boxes: &mut Vec<DetectBox>,
350) where
351    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
352    f32: AsPrimitive<T>,
353{
354    impl_yolo_split_float::<XYWH, _, _>(
355        boxes,
356        scores,
357        score_threshold,
358        iou_threshold,
359        nms,
360        output_boxes,
361    );
362}
363
364/// Decodes end-to-end YOLO detection outputs (post-NMS from model).
365/// Expects an array of shape `(6, N)`, where the first dimension (rows)
366/// corresponds to the 6 per-detection features
367/// `[x1, y1, x2, y2, conf, class]` and the second dimension (columns)
368/// indexes the `N` detections.
369/// Boxes are output directly without NMS (the model already applied NMS).
370///
371/// Coordinates may be normalized `[0, 1]` or absolute pixel values depending
372/// on the model configuration. The caller should check
373/// `decoder.normalized_boxes()` to determine which.
374///
375/// # Errors
376///
377/// Returns `DecoderError::InvalidShape` if `output` has fewer than 6 rows.
378pub(crate) fn decode_yolo_end_to_end_det_float<T>(
379    output: ArrayView2<T>,
380    score_threshold: f32,
381    output_boxes: &mut Vec<DetectBox>,
382) -> Result<(), crate::DecoderError>
383where
384    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
385    f32: AsPrimitive<T>,
386{
387    // Validate input shape: need at least 6 rows (x1, y1, x2, y2, conf, class)
388    if output.shape()[0] < 6 {
389        return Err(crate::DecoderError::InvalidShape(format!(
390            "End-to-end detection output requires at least 6 rows, got {}",
391            output.shape()[0]
392        )));
393    }
394
395    // Input shape: (6, N) -> transpose to (N, 4) for boxes and (N, 1) for scores
396    let boxes = output.slice(s![0..4, ..]).reversed_axes();
397    let scores = output.slice(s![4..5, ..]).reversed_axes();
398    let classes = output.slice(s![5, ..]);
399    let mut boxes =
400        postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
401    boxes.truncate(cap_or_default(output_boxes));
402    output_boxes.clear();
403    for (mut b, i) in boxes.into_iter() {
404        b.label = classes[i].as_() as usize;
405        output_boxes.push(b);
406    }
407    // No NMS — model output is already post-NMS
408    Ok(())
409}
410
411/// Decodes end-to-end YOLO detection + segmentation outputs (post-NMS from
412/// model).
413///
414/// Input shapes:
415/// - detection: (6 + num_protos, N) where rows are [x1, y1, x2, y2, conf,
416///   class, mask_coeff_0, ..., mask_coeff_31]
417/// - protos: (proto_height, proto_width, num_protos)
418///
419/// Boxes are output directly without NMS (model already applied NMS).
420/// Coordinates may be normalized [0,1] or pixel values depending on model
421/// config.
422///
423/// # Errors
424///
425/// Returns `DecoderError::InvalidShape` if:
426/// - output has fewer than 7 rows (6 base + at least 1 mask coefficient)
427/// - protos shape doesn't match mask coefficients count
428pub(crate) fn decode_yolo_end_to_end_segdet_float<T>(
429    output: ArrayView2<T>,
430    protos: ArrayView3<T>,
431    score_threshold: f32,
432    output_boxes: &mut Vec<DetectBox>,
433    output_masks: &mut Vec<crate::Segmentation>,
434) -> Result<(), crate::DecoderError>
435where
436    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
437    f32: AsPrimitive<T>,
438{
439    let (boxes, scores, classes, mask_coeff) =
440        postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
441    let cap = cap_or_default(output_boxes);
442    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
443        boxes,
444        scores,
445        classes,
446        score_threshold,
447        cap,
448    );
449
450    // No NMS — model output is already post-NMS
451
452    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
453}
454
455/// Decodes split end-to-end YOLO detection outputs (post-NMS from model).
456///
457/// Input shapes (after batch dim removed):
458/// - boxes: (4, N) — xyxy pixel coordinates
459/// - scores: (1, N) — confidence of the top class
460/// - classes: (1, N) — class index of the top class
461///
462/// Boxes are output directly without NMS (model already applied NMS).
463pub(crate) fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
464    boxes: ArrayView2<T>,
465    scores: ArrayView2<T>,
466    classes: ArrayView2<T>,
467    score_threshold: f32,
468    output_boxes: &mut Vec<DetectBox>,
469) -> Result<(), crate::DecoderError> {
470    let n = boxes.shape()[1];
471
472    let cap = cap_or_default(output_boxes);
473    output_boxes.clear();
474
475    let (boxes, scores, classes) = postprocess_yolo_split_end_to_end_det(boxes, scores, &classes)?;
476
477    for i in 0..n {
478        let score: f32 = scores[[i, 0]].as_();
479        if score < score_threshold {
480            continue;
481        }
482        if output_boxes.len() >= cap {
483            break;
484        }
485        output_boxes.push(DetectBox {
486            bbox: BoundingBox {
487                xmin: boxes[[i, 0]].as_(),
488                ymin: boxes[[i, 1]].as_(),
489                xmax: boxes[[i, 2]].as_(),
490                ymax: boxes[[i, 3]].as_(),
491            },
492            score,
493            label: classes[i].as_() as usize,
494        });
495    }
496    Ok(())
497}
498
499/// Decodes split end-to-end YOLO detection + segmentation outputs.
500///
501/// Input shapes (after batch dim removed):
502/// - boxes: (4, N) — xyxy pixel coordinates
503/// - scores: (1, N) — confidence
504/// - classes: (1, N) — class index
505/// - mask_coeff: (num_protos, N) — mask coefficients per detection
506/// - protos: (proto_h, proto_w, num_protos) — prototype masks
507#[allow(clippy::too_many_arguments)]
508pub(crate) fn decode_yolo_split_end_to_end_segdet_float<T>(
509    boxes: ArrayView2<T>,
510    scores: ArrayView2<T>,
511    classes: ArrayView2<T>,
512    mask_coeff: ArrayView2<T>,
513    protos: ArrayView3<T>,
514    score_threshold: f32,
515    output_boxes: &mut Vec<DetectBox>,
516    output_masks: &mut Vec<crate::Segmentation>,
517) -> Result<(), crate::DecoderError>
518where
519    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
520    f32: AsPrimitive<T>,
521{
522    let (boxes, scores, classes, mask_coeff) =
523        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
524    let cap = cap_or_default(output_boxes);
525    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
526        boxes,
527        scores,
528        classes,
529        score_threshold,
530        cap,
531    );
532
533    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
534}
535
536#[allow(clippy::type_complexity)]
537pub(crate) fn postprocess_yolo_end_to_end_segdet<'a, T>(
538    output: &'a ArrayView2<'_, T>,
539    num_protos: usize,
540) -> Result<
541    (
542        ArrayView2<'a, T>,
543        ArrayView2<'a, T>,
544        ArrayView1<'a, T>,
545        ArrayView2<'a, T>,
546    ),
547    crate::DecoderError,
548> {
549    // Validate input shape: need at least 7 rows (6 base + at least 1 mask coeff)
550    if output.shape()[0] < 7 {
551        return Err(crate::DecoderError::InvalidShape(format!(
552            "End-to-end segdet output requires at least 7 rows, got {}",
553            output.shape()[0]
554        )));
555    }
556
557    let num_mask_coeffs = output.shape()[0] - 6;
558    if num_mask_coeffs != num_protos {
559        return Err(crate::DecoderError::InvalidShape(format!(
560            "Mask coefficients count ({}) doesn't match protos count ({})",
561            num_mask_coeffs, num_protos
562        )));
563    }
564
565    // Input shape: (6+num_protos, N) -> transpose for postprocessing
566    let boxes = output.slice(s![0..4, ..]).reversed_axes();
567    let scores = output.slice(s![4..5, ..]).reversed_axes();
568    let classes = output.slice(s![5, ..]);
569    let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
570    Ok((boxes, scores, classes, mask_coeff))
571}
572
573/// Postprocess yolo split end to end det by reversing axes of boxes,
574/// scores, and flattening the class tensor.
575/// Expected input shapes:
576/// - boxes: (4, N)
577/// - scores: (1, N)
578/// - classes: (1, N)
579#[allow(clippy::type_complexity)]
580pub(crate) fn postprocess_yolo_split_end_to_end_det<'a, 'b, 'c, BOXES, SCORES, CLASS>(
581    boxes: ArrayView2<'a, BOXES>,
582    scores: ArrayView2<'b, SCORES>,
583    classes: &'c ArrayView2<CLASS>,
584) -> Result<
585    (
586        ArrayView2<'a, BOXES>,
587        ArrayView2<'b, SCORES>,
588        ArrayView1<'c, CLASS>,
589    ),
590    crate::DecoderError,
591> {
592    let num_boxes = boxes.shape()[1];
593    if boxes.shape()[0] != 4 {
594        return Err(crate::DecoderError::InvalidShape(format!(
595            "Split end-to-end box_coords must be 4, got {}",
596            boxes.shape()[0]
597        )));
598    }
599
600    if scores.shape()[0] != 1 {
601        return Err(crate::DecoderError::InvalidShape(format!(
602            "Split end-to-end scores num_classes must be 1, got {}",
603            scores.shape()[0]
604        )));
605    }
606
607    if classes.shape()[0] != 1 {
608        return Err(crate::DecoderError::InvalidShape(format!(
609            "Split end-to-end classes num_classes must be 1, got {}",
610            classes.shape()[0]
611        )));
612    }
613
614    if scores.shape()[1] != num_boxes {
615        return Err(crate::DecoderError::InvalidShape(format!(
616            "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
617            num_boxes,
618            scores.shape()[1]
619        )));
620    }
621
622    if classes.shape()[1] != num_boxes {
623        return Err(crate::DecoderError::InvalidShape(format!(
624            "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
625            num_boxes,
626            classes.shape()[1]
627        )));
628    }
629
630    let boxes = boxes.reversed_axes();
631    let scores = scores.reversed_axes();
632    let classes = classes.slice(s![0, ..]);
633    Ok((boxes, scores, classes))
634}
635
636/// Postprocess yolo split end to end segdet by reversing axes of boxes,
637/// scores, mask tensors and flattening the class tensor.
638#[allow(clippy::type_complexity)]
639pub(crate) fn postprocess_yolo_split_end_to_end_segdet<
640    'a,
641    'b,
642    'c,
643    'd,
644    BOXES,
645    SCORES,
646    CLASS,
647    MASK,
648>(
649    boxes: ArrayView2<'a, BOXES>,
650    scores: ArrayView2<'b, SCORES>,
651    classes: &'c ArrayView2<CLASS>,
652    mask_coeff: ArrayView2<'d, MASK>,
653) -> Result<
654    (
655        ArrayView2<'a, BOXES>,
656        ArrayView2<'b, SCORES>,
657        ArrayView1<'c, CLASS>,
658        ArrayView2<'d, MASK>,
659    ),
660    crate::DecoderError,
661> {
662    let num_boxes = boxes.shape()[1];
663    if boxes.shape()[0] != 4 {
664        return Err(crate::DecoderError::InvalidShape(format!(
665            "Split end-to-end box_coords must be 4, got {}",
666            boxes.shape()[0]
667        )));
668    }
669
670    if scores.shape()[0] != 1 {
671        return Err(crate::DecoderError::InvalidShape(format!(
672            "Split end-to-end scores num_classes must be 1, got {}",
673            scores.shape()[0]
674        )));
675    }
676
677    if classes.shape()[0] != 1 {
678        return Err(crate::DecoderError::InvalidShape(format!(
679            "Split end-to-end classes num_classes must be 1, got {}",
680            classes.shape()[0]
681        )));
682    }
683
684    if scores.shape()[1] != num_boxes {
685        return Err(crate::DecoderError::InvalidShape(format!(
686            "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
687            num_boxes,
688            scores.shape()[1]
689        )));
690    }
691
692    if classes.shape()[1] != num_boxes {
693        return Err(crate::DecoderError::InvalidShape(format!(
694            "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
695            num_boxes,
696            classes.shape()[1]
697        )));
698    }
699
700    if mask_coeff.shape()[1] != num_boxes {
701        return Err(crate::DecoderError::InvalidShape(format!(
702            "Split end-to-end mask_coeff must have same num_boxes as boxes ({}), got {}",
703            num_boxes,
704            mask_coeff.shape()[1]
705        )));
706    }
707
708    let boxes = boxes.reversed_axes();
709    let scores = scores.reversed_axes();
710    let classes = classes.slice(s![0, ..]);
711    let mask_coeff = mask_coeff.reversed_axes();
712    Ok((boxes, scores, classes, mask_coeff))
713}
714/// Internal implementation of YOLO decoding for quantized tensors.
715///
716/// Expected shapes of inputs:
717/// - output: (4 + num_classes, num_boxes)
718pub(crate) fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
719    output: (ArrayView2<T>, Quantization),
720    score_threshold: f32,
721    iou_threshold: f32,
722    nms: Option<Nms>,
723    output_boxes: &mut Vec<DetectBox>,
724) where
725    f32: AsPrimitive<T>,
726{
727    let _span = tracing::trace_span!("decode", mode = "quant_det").entered();
728    let (boxes, quant_boxes) = output;
729    let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
730
731    let boxes = {
732        let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
733        postprocess_boxes_quant::<B, _, _>(
734            score_threshold,
735            boxes_tensor,
736            scores_tensor,
737            quant_boxes,
738        )
739    };
740
741    let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
742    // Detection cap convention (see `cap_or_default`).
743    let len = cap_or_default(output_boxes).min(boxes.len());
744    output_boxes.clear();
745    for b in boxes.iter().take(len) {
746        output_boxes.push(dequant_detect_box(b, quant_boxes));
747    }
748}
749
750/// Internal implementation of YOLO decoding for float tensors.
751///
752/// Expected shapes of inputs:
753/// - output: (4 + num_classes, num_boxes)
754pub(crate) fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
755    output: ArrayView2<T>,
756    score_threshold: f32,
757    iou_threshold: f32,
758    nms: Option<Nms>,
759    output_boxes: &mut Vec<DetectBox>,
760) where
761    f32: AsPrimitive<T>,
762{
763    let _span = tracing::trace_span!("decode", mode = "float_det").entered();
764    let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
765    let boxes =
766        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
767    let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
768    // Detection cap convention (see `cap_or_default`).
769    let len = cap_or_default(output_boxes).min(boxes.len());
770    output_boxes.clear();
771    for b in boxes.into_iter().take(len) {
772        output_boxes.push(b);
773    }
774}
775
776/// Internal implementation of YOLO split detection decoding for quantized
777/// tensors.
778///
779/// Expected shapes of inputs:
780/// - boxes: (4, num_boxes)
781/// - scores: (num_classes, num_boxes)
782///
783/// # Panics
784/// Panics if shapes don't match the expected dimensions.
785pub(crate) fn impl_yolo_split_quant<
786    B: BBoxTypeTrait,
787    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
788    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
789>(
790    boxes: (ArrayView2<BOX>, Quantization),
791    scores: (ArrayView2<SCORE>, Quantization),
792    score_threshold: f32,
793    iou_threshold: f32,
794    nms: Option<Nms>,
795    output_boxes: &mut Vec<DetectBox>,
796) where
797    f32: AsPrimitive<SCORE>,
798{
799    let _span = tracing::trace_span!("decode", mode = "split_quant_det").entered();
800    let (boxes_tensor, quant_boxes) = boxes;
801    let (scores_tensor, quant_scores) = scores;
802
803    let boxes_tensor = boxes_tensor.reversed_axes();
804    let scores_tensor = scores_tensor.reversed_axes();
805
806    let boxes = {
807        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
808        postprocess_boxes_quant::<B, _, _>(
809            score_threshold,
810            boxes_tensor,
811            scores_tensor,
812            quant_boxes,
813        )
814    };
815
816    let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
817    // Detection cap convention (see `cap_or_default`).
818    let len = cap_or_default(output_boxes).min(boxes.len());
819    output_boxes.clear();
820    for b in boxes.iter().take(len) {
821        output_boxes.push(dequant_detect_box(b, quant_scores));
822    }
823}
824
825/// Internal implementation of YOLO split detection decoding for float tensors.
826///
827/// Expected shapes of inputs:
828/// - boxes: (4, num_boxes)
829/// - scores: (num_classes, num_boxes)
830///
831/// # Panics
832/// Panics if shapes don't match the expected dimensions.
833pub(crate) fn impl_yolo_split_float<
834    B: BBoxTypeTrait,
835    BOX: Float + AsPrimitive<f32> + Send + Sync,
836    SCORE: Float + AsPrimitive<f32> + Send + Sync,
837>(
838    boxes_tensor: ArrayView2<BOX>,
839    scores_tensor: ArrayView2<SCORE>,
840    score_threshold: f32,
841    iou_threshold: f32,
842    nms: Option<Nms>,
843    output_boxes: &mut Vec<DetectBox>,
844) where
845    f32: AsPrimitive<SCORE>,
846{
847    let _span = tracing::trace_span!("decode", mode = "split_float_det").entered();
848    let boxes_tensor = boxes_tensor.reversed_axes();
849    let scores_tensor = scores_tensor.reversed_axes();
850    let boxes =
851        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
852    let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
853    // Detection cap convention (see `cap_or_default`).
854    let len = cap_or_default(output_boxes).min(boxes.len());
855    output_boxes.clear();
856    for b in boxes.into_iter().take(len) {
857        output_boxes.push(b);
858    }
859}
860
861/// Divide each survivor's bbox by `(input_w, input_h)` when the schema
862/// declares `normalized: false`. Pixel-space box coords from the model
863/// are pulled into the canonical `[0, 1]` range expected by `protobox`
864/// and downstream callers — see EDGEAI-1303.
865///
866/// No-op when `normalized` is `Some(true)` / `None`, when `input_dims`
867/// is `None`, or when there are no survivors.
868#[inline]
869pub(crate) fn maybe_normalize_boxes_in_place(
870    boxes: &mut [(DetectBox, usize)],
871    normalized: Option<bool>,
872    input_dims: Option<(usize, usize)>,
873) {
874    if normalized != Some(false) {
875        return;
876    }
877    let Some((w, h)) = input_dims else {
878        return;
879    };
880    if w == 0 || h == 0 {
881        return;
882    }
883    let inv_w = 1.0 / w as f32;
884    let inv_h = 1.0 / h as f32;
885    for (b, _) in boxes.iter_mut() {
886        b.bbox.xmin *= inv_w;
887        b.bbox.ymin *= inv_h;
888        b.bbox.xmax *= inv_w;
889        b.bbox.ymax *= inv_h;
890    }
891}
892
893/// Internal implementation of YOLO detection segmentation decoding for
894/// quantized tensors.
895///
896/// Expected shapes of inputs:
897/// - boxes: (4 + num_classes + num_protos, num_boxes)
898/// - protos: (proto_height, proto_width, num_protos)
899///
900/// # Errors
901/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
902#[allow(clippy::too_many_arguments)]
903pub(crate) fn impl_yolo_segdet_quant<
904    B: BBoxTypeTrait,
905    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
906    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
907>(
908    boxes: (ArrayView2<BOX>, Quantization),
909    protos: (ArrayView3<PROTO>, Quantization),
910    score_threshold: f32,
911    iou_threshold: f32,
912    nms: Option<Nms>,
913    pre_nms_top_k: usize,
914    max_det: usize,
915    normalized: Option<bool>,
916    input_dims: Option<(usize, usize)>,
917    output_boxes: &mut Vec<DetectBox>,
918    output_masks: &mut Vec<Segmentation>,
919) -> Result<(), crate::DecoderError>
920where
921    f32: AsPrimitive<BOX>,
922{
923    let (boxes, quant_boxes) = boxes;
924    let num_protos = protos.0.dim().2;
925
926    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
927    let mut boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
928        (boxes_tensor, quant_boxes),
929        (scores_tensor, quant_boxes),
930        score_threshold,
931        iou_threshold,
932        nms,
933        pre_nms_top_k,
934        max_det,
935    );
936    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
937
938    impl_yolo_split_segdet_quant_process_masks::<_, _>(
939        boxes,
940        (mask_tensor, quant_boxes),
941        protos,
942        output_boxes,
943        output_masks,
944    )
945}
946
947/// Internal implementation of YOLO detection segmentation decoding for
948/// float tensors.
949///
950/// Expected shapes of inputs:
951/// - boxes: (4 + num_classes + num_protos, num_boxes)
952/// - protos: (proto_height, proto_width, num_protos)
953///
954/// # Panics
955/// Panics if shapes don't match the expected dimensions.
956#[allow(clippy::too_many_arguments)]
957pub(crate) fn impl_yolo_segdet_float<
958    B: BBoxTypeTrait,
959    BOX: Float + AsPrimitive<f32> + Send + Sync,
960    PROTO: Float + AsPrimitive<f32> + Send + Sync,
961>(
962    boxes: ArrayView2<BOX>,
963    protos: ArrayView3<PROTO>,
964    score_threshold: f32,
965    iou_threshold: f32,
966    nms: Option<Nms>,
967    pre_nms_top_k: usize,
968    max_det: usize,
969    normalized: Option<bool>,
970    input_dims: Option<(usize, usize)>,
971    output_boxes: &mut Vec<DetectBox>,
972    output_masks: &mut Vec<Segmentation>,
973) -> Result<(), crate::DecoderError>
974where
975    f32: AsPrimitive<BOX>,
976{
977    let num_protos = protos.dim().2;
978    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
979    let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
980        boxes_tensor,
981        scores_tensor,
982        score_threshold,
983        iou_threshold,
984        nms,
985        pre_nms_top_k,
986        max_det,
987    );
988    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
989    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
990}
991
992pub(crate) fn impl_yolo_segdet_get_boxes<
993    B: BBoxTypeTrait,
994    BOX: Float + AsPrimitive<f32> + Send + Sync,
995    SCORE: Float + AsPrimitive<f32> + Send + Sync,
996>(
997    boxes_tensor: ArrayView2<BOX>,
998    scores_tensor: ArrayView2<SCORE>,
999    score_threshold: f32,
1000    iou_threshold: f32,
1001    nms: Option<Nms>,
1002    pre_nms_top_k: usize,
1003    max_det: usize,
1004) -> Vec<(DetectBox, usize)>
1005where
1006    f32: AsPrimitive<SCORE>,
1007{
1008    let span = tracing::trace_span!(
1009        "decode",
1010        n_candidates = tracing::field::Empty,
1011        n_after_topk = tracing::field::Empty,
1012        n_after_nms = tracing::field::Empty,
1013        n_detections = tracing::field::Empty,
1014    );
1015    let _guard = span.enter();
1016
1017    let mut boxes = {
1018        let _s = tracing::trace_span!("score_filter").entered();
1019        postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor)
1020    };
1021    span.record("n_candidates", boxes.len());
1022
1023    if nms.is_some() {
1024        let _s = tracing::trace_span!("top_k", k = pre_nms_top_k).entered();
1025        truncate_to_top_k_by_score(&mut boxes, pre_nms_top_k);
1026    }
1027    span.record("n_after_topk", boxes.len());
1028
1029    let mut boxes = {
1030        let _s = tracing::trace_span!("nms").entered();
1031        dispatch_nms_extra_float(nms, iou_threshold, boxes)
1032    };
1033    span.record("n_after_nms", boxes.len());
1034
1035    boxes.sort_unstable_by(|a, b| b.0.score.total_cmp(&a.0.score));
1036    boxes.truncate(max_det);
1037    span.record("n_detections", boxes.len());
1038
1039    boxes
1040}
1041
1042pub(crate) fn impl_yolo_end_to_end_segdet_get_boxes<
1043    B: BBoxTypeTrait,
1044    BOX: Float + AsPrimitive<f32> + Send + Sync,
1045    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1046    CLASS: AsPrimitive<f32> + Send + Sync,
1047>(
1048    boxes: ArrayView2<BOX>,
1049    scores: ArrayView2<SCORE>,
1050    classes: ArrayView1<CLASS>,
1051    score_threshold: f32,
1052    max_boxes: usize,
1053) -> Vec<(DetectBox, usize)>
1054where
1055    f32: AsPrimitive<SCORE>,
1056{
1057    let mut boxes = postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes, scores);
1058    boxes.truncate(max_boxes);
1059    for (b, ind) in &mut boxes {
1060        b.label = classes[*ind].as_().round() as usize;
1061    }
1062    boxes
1063}
1064
1065pub(crate) fn impl_yolo_split_segdet_process_masks<
1066    MASK: Float + AsPrimitive<f32> + Send + Sync,
1067    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1068>(
1069    boxes: Vec<(DetectBox, usize)>,
1070    masks_tensor: ArrayView2<MASK>,
1071    protos_tensor: ArrayView3<PROTO>,
1072    output_boxes: &mut Vec<DetectBox>,
1073    output_masks: &mut Vec<Segmentation>,
1074) -> Result<(), crate::DecoderError> {
1075    let _span = tracing::trace_span!("process_masks", n = boxes.len(), mode = "float").entered();
1076    // Boxes are already bounded by the upstream `max_det` cap from
1077    // `_get_boxes`; no second cap is needed here (EDGEAI-1302).
1078
1079    let boxes = decode_segdet_f32(boxes, masks_tensor, protos_tensor)?;
1080    output_boxes.clear();
1081    output_masks.clear();
1082    for (b, roi, m) in boxes.into_iter() {
1083        output_boxes.push(b);
1084        output_masks.push(Segmentation {
1085            xmin: roi.xmin,
1086            ymin: roi.ymin,
1087            xmax: roi.xmax,
1088            ymax: roi.ymax,
1089            segmentation: m,
1090        });
1091    }
1092    Ok(())
1093}
1094/// Expected input shapes:
1095/// - boxes_tensor: (num_boxes, 4)
1096/// - scores_tensor: (num_boxes, num_classes)
1097pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
1098    B: BBoxTypeTrait,
1099    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1100    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1101>(
1102    boxes: (ArrayView2<BOX>, Quantization),
1103    scores: (ArrayView2<SCORE>, Quantization),
1104    score_threshold: f32,
1105    iou_threshold: f32,
1106    nms: Option<Nms>,
1107    pre_nms_top_k: usize,
1108    max_det: usize,
1109) -> Vec<(DetectBox, usize)>
1110where
1111    f32: AsPrimitive<SCORE>,
1112{
1113    let (boxes_tensor, quant_boxes) = boxes;
1114    let (scores_tensor, quant_scores) = scores;
1115
1116    let span = tracing::trace_span!(
1117        "decode",
1118        n_candidates = tracing::field::Empty,
1119        n_after_topk = tracing::field::Empty,
1120        n_after_nms = tracing::field::Empty,
1121        n_detections = tracing::field::Empty,
1122    );
1123    let _guard = span.enter();
1124
1125    let mut boxes = {
1126        let _s = tracing::trace_span!("score_filter").entered();
1127        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
1128        postprocess_boxes_index_quant::<B, _, _>(
1129            score_threshold,
1130            boxes_tensor,
1131            scores_tensor,
1132            quant_boxes,
1133        )
1134    };
1135    span.record("n_candidates", boxes.len());
1136
1137    if nms.is_some() {
1138        let _s = tracing::trace_span!("top_k", k = pre_nms_top_k).entered();
1139        truncate_to_top_k_by_score_quant(&mut boxes, pre_nms_top_k);
1140    }
1141    span.record("n_after_topk", boxes.len());
1142
1143    let mut boxes = {
1144        let _s = tracing::trace_span!("nms").entered();
1145        dispatch_nms_extra_int(nms, iou_threshold, boxes)
1146    };
1147    span.record("n_after_nms", boxes.len());
1148
1149    // Sort by score descending before truncation to keep top-scoring detections.
1150    boxes.sort_unstable_by(|a, b| b.0.score.cmp(&a.0.score));
1151    boxes.truncate(max_det);
1152    let result: Vec<_> = {
1153        let _s = tracing::trace_span!("box_dequant", n = boxes.len()).entered();
1154        boxes
1155            .into_iter()
1156            .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
1157            .collect()
1158    };
1159    span.record("n_detections", result.len());
1160
1161    result
1162}
1163
1164pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
1165    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1166    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1167>(
1168    boxes: Vec<(DetectBox, usize)>,
1169    mask_coeff: (ArrayView2<MASK>, Quantization),
1170    protos: (ArrayView3<PROTO>, Quantization),
1171    output_boxes: &mut Vec<DetectBox>,
1172    output_masks: &mut Vec<Segmentation>,
1173) -> Result<(), crate::DecoderError> {
1174    let _span = tracing::trace_span!("process_masks", n = boxes.len(), mode = "quant").entered();
1175    let (masks, quant_masks) = mask_coeff;
1176    let (protos, quant_protos) = protos;
1177
1178    // Boxes are already bounded by the upstream `max_det` cap from
1179    // `_get_boxes`; no second cap is needed here (EDGEAI-1302).
1180
1181    let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
1182    output_boxes.clear();
1183    output_masks.clear();
1184    for (b, roi, m) in boxes.into_iter() {
1185        output_boxes.push(b);
1186        output_masks.push(Segmentation {
1187            xmin: roi.xmin,
1188            ymin: roi.ymin,
1189            xmax: roi.xmax,
1190            ymax: roi.ymax,
1191            segmentation: m,
1192        });
1193    }
1194    Ok(())
1195}
1196
1197/// Internal implementation of YOLO split detection segmentation decoding for
1198/// float tensors.
1199///
1200/// Expected shapes of inputs:
1201/// - boxes_tensor: (4, num_boxes)
1202/// - scores_tensor: (num_classes, num_boxes)
1203/// - mask_tensor: (num_protos, num_boxes)
1204/// - protos: (proto_height, proto_width, num_protos)
1205///
1206/// # Errors
1207/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
1208#[allow(clippy::too_many_arguments)]
1209pub(crate) fn impl_yolo_split_segdet_float<
1210    B: BBoxTypeTrait,
1211    BOX: Float + AsPrimitive<f32> + Send + Sync,
1212    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1213    MASK: Float + AsPrimitive<f32> + Send + Sync,
1214    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1215>(
1216    boxes_tensor: ArrayView2<BOX>,
1217    scores_tensor: ArrayView2<SCORE>,
1218    mask_tensor: ArrayView2<MASK>,
1219    protos: ArrayView3<PROTO>,
1220    score_threshold: f32,
1221    iou_threshold: f32,
1222    nms: Option<Nms>,
1223    pre_nms_top_k: usize,
1224    max_det: usize,
1225    normalized: Option<bool>,
1226    input_dims: Option<(usize, usize)>,
1227    output_boxes: &mut Vec<DetectBox>,
1228    output_masks: &mut Vec<Segmentation>,
1229) -> Result<(), crate::DecoderError>
1230where
1231    f32: AsPrimitive<SCORE>,
1232{
1233    let (boxes_tensor, scores_tensor, mask_tensor) =
1234        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1235
1236    let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1237        boxes_tensor,
1238        scores_tensor,
1239        score_threshold,
1240        iou_threshold,
1241        nms,
1242        pre_nms_top_k,
1243        max_det,
1244    );
1245    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1246    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1247}
1248
1249// ---------------------------------------------------------------------------
1250// Proto-extraction variants: return ProtoData instead of materialized masks
1251// ---------------------------------------------------------------------------
1252
1253/// Proto-extraction variant of `impl_yolo_segdet_quant`.
1254/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1255#[allow(clippy::too_many_arguments)]
1256pub(crate) fn impl_yolo_segdet_quant_proto<
1257    B: BBoxTypeTrait,
1258    BOX: PrimInt
1259        + AsPrimitive<i64>
1260        + AsPrimitive<i128>
1261        + AsPrimitive<f32>
1262        + AsPrimitive<i8>
1263        + Send
1264        + Sync,
1265    PROTO: PrimInt
1266        + AsPrimitive<i64>
1267        + AsPrimitive<i128>
1268        + AsPrimitive<f32>
1269        + AsPrimitive<i8>
1270        + Send
1271        + Sync,
1272>(
1273    boxes: (ArrayView2<BOX>, Quantization),
1274    protos: (ArrayView3<PROTO>, Quantization),
1275    score_threshold: f32,
1276    iou_threshold: f32,
1277    nms: Option<Nms>,
1278    pre_nms_top_k: usize,
1279    max_det: usize,
1280    normalized: Option<bool>,
1281    input_dims: Option<(usize, usize)>,
1282    output_boxes: &mut Vec<DetectBox>,
1283) -> ProtoData
1284where
1285    f32: AsPrimitive<BOX>,
1286{
1287    let (boxes_arr, quant_boxes) = boxes;
1288    let (protos_arr, quant_protos) = protos;
1289    let num_protos = protos_arr.dim().2;
1290
1291    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1292
1293    let mut det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1294        (boxes_tensor, quant_boxes),
1295        (scores_tensor, quant_boxes),
1296        score_threshold,
1297        iou_threshold,
1298        nms,
1299        pre_nms_top_k,
1300        max_det,
1301    );
1302    maybe_normalize_boxes_in_place(&mut det_indices, normalized, input_dims);
1303
1304    extract_proto_data_quant(
1305        det_indices,
1306        mask_tensor,
1307        quant_boxes,
1308        protos_arr,
1309        quant_protos,
1310        output_boxes,
1311    )
1312}
1313
1314/// Proto-extraction variant of `impl_yolo_segdet_float`.
1315/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1316#[allow(clippy::too_many_arguments)]
1317pub(crate) fn impl_yolo_segdet_float_proto<
1318    B: BBoxTypeTrait,
1319    BOX: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1320    PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1321>(
1322    boxes: ArrayView2<BOX>,
1323    protos: ArrayView3<PROTO>,
1324    score_threshold: f32,
1325    iou_threshold: f32,
1326    nms: Option<Nms>,
1327    pre_nms_top_k: usize,
1328    max_det: usize,
1329    normalized: Option<bool>,
1330    input_dims: Option<(usize, usize)>,
1331    output_boxes: &mut Vec<DetectBox>,
1332) -> ProtoData
1333where
1334    f32: AsPrimitive<BOX>,
1335{
1336    let num_protos = protos.dim().2;
1337    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1338
1339    let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1340        boxes_tensor,
1341        scores_tensor,
1342        score_threshold,
1343        iou_threshold,
1344        nms,
1345        pre_nms_top_k,
1346        max_det,
1347    );
1348    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1349
1350    extract_proto_data_float(boxes, mask_tensor, protos, output_boxes)
1351}
1352
1353/// Proto-extraction variant of `impl_yolo_split_segdet_float`.
1354/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1355#[allow(clippy::too_many_arguments)]
1356pub(crate) fn impl_yolo_split_segdet_float_proto<
1357    B: BBoxTypeTrait,
1358    BOX: Float + AsPrimitive<f32> + Send + Sync,
1359    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1360    MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1361    PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1362>(
1363    boxes_tensor: ArrayView2<BOX>,
1364    scores_tensor: ArrayView2<SCORE>,
1365    mask_tensor: ArrayView2<MASK>,
1366    protos: ArrayView3<PROTO>,
1367    score_threshold: f32,
1368    iou_threshold: f32,
1369    nms: Option<Nms>,
1370    pre_nms_top_k: usize,
1371    max_det: usize,
1372    output_boxes: &mut Vec<DetectBox>,
1373) -> ProtoData
1374where
1375    f32: AsPrimitive<SCORE>,
1376{
1377    let (boxes_tensor, scores_tensor, mask_tensor) =
1378        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1379    let det_indices = impl_yolo_segdet_get_boxes::<B, _, _>(
1380        boxes_tensor,
1381        scores_tensor,
1382        score_threshold,
1383        iou_threshold,
1384        nms,
1385        pre_nms_top_k,
1386        max_det,
1387    );
1388
1389    extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1390}
1391
1392/// Proto-extraction variant of `decode_yolo_end_to_end_segdet_float`.
1393pub(crate) fn decode_yolo_end_to_end_segdet_float_proto<T>(
1394    output: ArrayView2<T>,
1395    protos: ArrayView3<T>,
1396    score_threshold: f32,
1397    output_boxes: &mut Vec<DetectBox>,
1398) -> Result<ProtoData, crate::DecoderError>
1399where
1400    T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1401    f32: AsPrimitive<T>,
1402{
1403    let (boxes, scores, classes, mask_coeff) =
1404        postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
1405    let cap = cap_or_default(output_boxes);
1406    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1407        boxes,
1408        scores,
1409        classes,
1410        score_threshold,
1411        cap,
1412    );
1413
1414    Ok(extract_proto_data_float(
1415        boxes,
1416        mask_coeff,
1417        protos,
1418        output_boxes,
1419    ))
1420}
1421
1422/// Proto-extraction variant of `decode_yolo_split_end_to_end_segdet_float`.
1423#[allow(clippy::too_many_arguments)]
1424pub(crate) fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1425    boxes: ArrayView2<T>,
1426    scores: ArrayView2<T>,
1427    classes: ArrayView2<T>,
1428    mask_coeff: ArrayView2<T>,
1429    protos: ArrayView3<T>,
1430    score_threshold: f32,
1431    output_boxes: &mut Vec<DetectBox>,
1432) -> Result<ProtoData, crate::DecoderError>
1433where
1434    T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1435    f32: AsPrimitive<T>,
1436{
1437    let (boxes, scores, classes, mask_coeff) =
1438        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
1439    let cap = cap_or_default(output_boxes);
1440    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1441        boxes,
1442        scores,
1443        classes,
1444        score_threshold,
1445        cap,
1446    );
1447
1448    Ok(extract_proto_data_float(
1449        boxes,
1450        mask_coeff,
1451        protos,
1452        output_boxes,
1453    ))
1454}
1455
1456/// Helper: extract ProtoData from float mask coefficients + protos.
1457///
1458/// Builds [`ProtoData`] with both `protos` and `mask_coefficients` as
1459/// [`edgefirst_tensor::TensorDyn`]. Preserves the native element type for
1460/// `f16` and `f32`; narrows `f64` to `f32` (there is no native f64 kernel
1461/// path). `mask_coefficients` shape is `[num_detections, num_protos]`.
1462pub(super) fn extract_proto_data_float<
1463    MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1464    PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1465>(
1466    det_indices: Vec<(DetectBox, usize)>,
1467    mask_tensor: ArrayView2<MASK>,
1468    protos: ArrayView3<PROTO>,
1469    output_boxes: &mut Vec<DetectBox>,
1470) -> ProtoData {
1471    let _span = tracing::trace_span!(
1472        "extract_proto",
1473        n = det_indices.len(),
1474        num_protos = mask_tensor.ncols(),
1475        layout = "nhwc",
1476    )
1477    .entered();
1478
1479    let num_protos = mask_tensor.ncols();
1480    let n = det_indices.len();
1481
1482    // Per-detection coefficients packed row-major into a contiguous buffer,
1483    // preserving the source dtype. Shape: [N, num_protos] — N=0 is permitted
1484    // (tracker path emits no detections this frame) since Mem-backed tensors
1485    // accept zero-element shapes as "empty collection" sentinels.
1486    let mut coeff_rows: Vec<MASK> = Vec::with_capacity(n * num_protos);
1487    output_boxes.clear();
1488    for (det, idx) in det_indices {
1489        output_boxes.push(det);
1490        let row = mask_tensor.row(idx);
1491        coeff_rows.extend(row.iter().copied());
1492    }
1493
1494    let mask_coefficients = MASK::slice_into_tensor_dyn(&coeff_rows, &[n, num_protos])
1495        .expect("allocating mask_coefficients TensorDyn");
1496    let protos_tensor =
1497        PROTO::arrayview3_into_tensor_dyn(protos).expect("allocating protos TensorDyn");
1498
1499    ProtoData {
1500        mask_coefficients,
1501        protos: protos_tensor,
1502        layout: ProtoLayout::Nhwc,
1503    }
1504}
1505
1506/// Helper: extract ProtoData from quantized mask coefficients + protos.
1507///
1508/// Dequantizes mask coefficients to f32 at extraction (one-time cost on a
1509/// `num_detections * num_protos` slice) and keeps protos in raw i8,
1510/// attaching the dequantization params as
1511/// [`edgefirst_tensor::Quantization::per_tensor`] metadata on the proto
1512/// tensor. The GPU shader / CPU kernel reads `protos.quantization()` and
1513/// dequantizes per-texel.
1514pub(crate) fn extract_proto_data_quant<
1515    MASK: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync,
1516    PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
1517>(
1518    det_indices: Vec<(DetectBox, usize)>,
1519    mask_tensor: ArrayView2<MASK>,
1520    quant_masks: Quantization,
1521    protos: ArrayView3<PROTO>,
1522    quant_protos: Quantization,
1523    output_boxes: &mut Vec<DetectBox>,
1524) -> ProtoData {
1525    use edgefirst_tensor::{Tensor, TensorDyn, TensorMapTrait, TensorMemory, TensorTrait};
1526
1527    let span = tracing::trace_span!(
1528        "extract_proto",
1529        n = det_indices.len(),
1530        num_protos = tracing::field::Empty,
1531        layout = tracing::field::Empty,
1532    );
1533    let _guard = span.enter();
1534
1535    let num_protos = mask_tensor.ncols();
1536    let n = det_indices.len();
1537    span.record("num_protos", num_protos);
1538
1539    // Fast path: when no detections survive NMS, skip the expensive proto
1540    // tensor copy (819KB for 160×160×32). Allocate with the correct shape
1541    // (preserving the documented ProtoData.protos shape contract) but skip
1542    // copying from the source tensor — the zeroed allocation is sufficient
1543    // since materialize_masks early-returns on empty detect slices.
1544    if n == 0 {
1545        output_boxes.clear();
1546        let (h, w, k) = protos.dim();
1547
1548        // Detect physical layout (same logic as the normal path).
1549        let (proto_shape, proto_layout) = if std::any::TypeId::of::<PROTO>()
1550            == std::any::TypeId::of::<i8>()
1551        {
1552            if protos.is_standard_layout() {
1553                (&[h, w, k][..], ProtoLayout::Nhwc)
1554            } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1555                (&[k, h, w][..], ProtoLayout::Nchw)
1556            } else {
1557                (&[h, w, k][..], ProtoLayout::Nhwc)
1558            }
1559        } else {
1560            (&[h, w, k][..], ProtoLayout::Nhwc)
1561        };
1562
1563        let coeff_tensor = Tensor::<i8>::new(&[0, num_protos], Some(TensorMemory::Mem), None)
1564            .expect("allocating empty mask_coefficients tensor");
1565        let coeff_quant =
1566            edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1567        let coeff_tensor = coeff_tensor
1568            .with_quantization(coeff_quant)
1569            .expect("per-tensor quantization on mask coefficients");
1570        let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
1571            .expect("allocating protos tensor");
1572        let tensor_quant =
1573            edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1574        let protos_tensor = protos_tensor
1575            .with_quantization(tensor_quant)
1576            .expect("per-tensor quantization on protos tensor");
1577        return ProtoData {
1578            mask_coefficients: TensorDyn::I8(coeff_tensor),
1579            protos: TensorDyn::I8(protos_tensor),
1580            layout: proto_layout,
1581        };
1582    }
1583
1584    // Keep mask coefficients in raw i8 with quantization metadata attached.
1585    // Consumers that need f32 can dequantize on the fly; the scaled-path
1586    // integer kernel uses raw i8 directly for i8×i8→i32 dot products.
1587    let mut coeff_i8 = Vec::<i8>::with_capacity(n * num_protos);
1588    output_boxes.clear();
1589    for (det, idx) in det_indices {
1590        output_boxes.push(det);
1591        let row = mask_tensor.row(idx);
1592        coeff_i8.extend(row.iter().map(|v| {
1593            let v_i8: i8 = v.as_();
1594            v_i8
1595        }));
1596    }
1597
1598    // Shape `[n, num_protos]` with n=0 is permitted (tracker path emits no
1599    // fresh detections this frame) via the Mem-backed zero-size allowance.
1600    let coeff_tensor = Tensor::<i8>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
1601        .expect("allocating mask_coefficients tensor");
1602    if n > 0 {
1603        let mut m = coeff_tensor
1604            .map()
1605            .expect("mapping mask_coefficients tensor");
1606        m.as_mut_slice().copy_from_slice(&coeff_i8);
1607    }
1608    let coeff_quant =
1609        edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1610    let coeff_tensor = coeff_tensor
1611        .with_quantization(coeff_quant)
1612        .expect("per-tensor quantization on mask coefficients");
1613    let mask_coefficients = TensorDyn::I8(coeff_tensor);
1614
1615    // Keep protos in raw i8 — consumers dequantize via protos.quantization().
1616    // When PROTO is already i8, detect layout and copy efficiently without
1617    // transposing. The mask materialisation kernels dispatch on the layout.
1618    let (h, w, k) = protos.dim();
1619
1620    // Determine physical layout and copy strategy.
1621    let (proto_shape, proto_layout) =
1622        if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1623            if protos.is_standard_layout() {
1624                // Already NHWC [H, W, K] in contiguous memory.
1625                (&[h, w, k][..], ProtoLayout::Nhwc)
1626            } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1627                // NCHW reinterpreted as NHWC via stride swap. Physical storage
1628                // is [K, H, W] contiguous. Keep in NCHW — eliminates the costly
1629                // 3.1ms transpose entirely.
1630                (&[k, h, w][..], ProtoLayout::Nchw)
1631            } else {
1632                // Unknown layout — fall back to iter copy as NHWC.
1633                (&[h, w, k][..], ProtoLayout::Nhwc)
1634            }
1635        } else {
1636            (&[h, w, k][..], ProtoLayout::Nhwc)
1637        };
1638
1639    let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
1640        .expect("allocating protos tensor");
1641    {
1642        let mut m = protos_tensor.map().expect("mapping protos tensor");
1643        let dst = m.as_mut_slice();
1644        if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1645            // SAFETY: PROTO == i8 checked via TypeId; cast slice view is
1646            // size/alignment-compatible by construction.
1647            if protos.is_standard_layout() {
1648                let src: &[i8] = unsafe {
1649                    std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len())
1650                };
1651                dst.copy_from_slice(src);
1652            } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1653                // NCHW physical layout — sequential copy WITHOUT transpose.
1654                // This saves ~3.1ms on A53/A55 by avoiding the tiled
1655                // NCHW→NHWC transpose of the 819KB proto buffer.
1656                let total = h * w * k;
1657                // SAFETY: ArrayView was constructed from a contiguous slice of
1658                // `total` elements. as_ptr() points to the base of that slice.
1659                let src: &[i8] =
1660                    unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, total) };
1661                dst.copy_from_slice(src);
1662            } else {
1663                for (d, s) in dst.iter_mut().zip(protos.iter()) {
1664                    let v_i8: i8 = s.as_();
1665                    *d = v_i8;
1666                }
1667            }
1668        } else {
1669            for (d, s) in dst.iter_mut().zip(protos.iter()) {
1670                let v_i8: i8 = s.as_();
1671                *d = v_i8;
1672            }
1673        }
1674    }
1675    let tensor_quant =
1676        edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1677    let protos_tensor = protos_tensor
1678        .with_quantization(tensor_quant)
1679        .expect("per-tensor quantization on new Tensor<i8>");
1680
1681    span.record("layout", tracing::field::debug(&proto_layout));
1682
1683    ProtoData {
1684        mask_coefficients,
1685        protos: TensorDyn::I8(protos_tensor),
1686        layout: proto_layout,
1687    }
1688}
1689
/// Per-float-dtype construction of a [`TensorDyn`] from a flat slice / 3-D
/// `ArrayView`. Replaces the old `IntoProtoTensor` trait. Each implementor
/// either passes its element type straight to `Tensor::from_slice` /
/// `Tensor::from_arrayview3`, or narrows `f64` to `f32` (no native f64 kernel
/// path exists).
///
/// [`TensorDyn`]: edgefirst_tensor::TensorDyn
pub trait FloatProtoElem: Copy + 'static {
    /// Builds a tensor of the given `shape` from the flat `values` slice and
    /// wraps it in the `TensorDyn` variant chosen by the implementor.
    ///
    /// # Errors
    /// Returns the `edgefirst_tensor` error produced while constructing the
    /// tensor.
    fn slice_into_tensor_dyn(
        values: &[Self],
        shape: &[usize],
    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;

    /// Builds a tensor from a 3-D array view and wraps it in the `TensorDyn`
    /// variant chosen by the implementor.
    ///
    /// # Errors
    /// Returns the `edgefirst_tensor` error produced while constructing the
    /// tensor.
    fn arrayview3_into_tensor_dyn(
        view: ArrayView3<'_, Self>,
    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;
}
1705
1706impl FloatProtoElem for f32 {
1707    fn slice_into_tensor_dyn(
1708        values: &[f32],
1709        shape: &[usize],
1710    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1711        edgefirst_tensor::Tensor::<f32>::from_slice(values, shape)
1712            .map(edgefirst_tensor::TensorDyn::F32)
1713    }
1714    fn arrayview3_into_tensor_dyn(
1715        view: ArrayView3<'_, f32>,
1716    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1717        edgefirst_tensor::Tensor::<f32>::from_arrayview3(view).map(edgefirst_tensor::TensorDyn::F32)
1718    }
1719}
1720
1721impl FloatProtoElem for half::f16 {
1722    fn slice_into_tensor_dyn(
1723        values: &[half::f16],
1724        shape: &[usize],
1725    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1726        edgefirst_tensor::Tensor::<half::f16>::from_slice(values, shape)
1727            .map(edgefirst_tensor::TensorDyn::F16)
1728    }
1729    fn arrayview3_into_tensor_dyn(
1730        view: ArrayView3<'_, half::f16>,
1731    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1732        edgefirst_tensor::Tensor::<half::f16>::from_arrayview3(view)
1733            .map(edgefirst_tensor::TensorDyn::F16)
1734    }
1735}
1736
1737impl FloatProtoElem for f64 {
1738    fn slice_into_tensor_dyn(
1739        values: &[f64],
1740        shape: &[usize],
1741    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1742        // Narrow to f32 — no native f64 kernel path.
1743        let narrowed: Vec<f32> = values.iter().map(|&v| v as f32).collect();
1744        edgefirst_tensor::Tensor::<f32>::from_slice(&narrowed, shape)
1745            .map(edgefirst_tensor::TensorDyn::F32)
1746    }
1747    fn arrayview3_into_tensor_dyn(
1748        view: ArrayView3<'_, f64>,
1749    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1750        let narrowed: ndarray::Array3<f32> = view.mapv(|v| v as f32);
1751        edgefirst_tensor::Tensor::<f32>::from_arrayview3(narrowed.view())
1752            .map(edgefirst_tensor::TensorDyn::F32)
1753    }
1754}
1755
1756fn postprocess_yolo<'a, T>(
1757    output: &'a ArrayView2<'_, T>,
1758) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1759    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1760    let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1761    (boxes_tensor, scores_tensor)
1762}
1763
1764pub(crate) fn postprocess_yolo_seg<'a, T>(
1765    output: &'a ArrayView2<'_, T>,
1766    num_protos: usize,
1767) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1768    assert!(
1769        output.shape()[0] > num_protos + 4,
1770        "Output shape is too short: {} <= {} + 4",
1771        output.shape()[0],
1772        num_protos
1773    );
1774    let num_classes = output.shape()[0] - 4 - num_protos;
1775    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1776    let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1777    let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1778    (boxes_tensor, scores_tensor, mask_tensor)
1779}
1780
1781pub(crate) fn postprocess_yolo_split_segdet<'a, 'b, 'c, BOX, SCORE, MASK>(
1782    boxes_tensor: ArrayView2<'a, BOX>,
1783    scores_tensor: ArrayView2<'b, SCORE>,
1784    mask_tensor: ArrayView2<'c, MASK>,
1785) -> (
1786    ArrayView2<'a, BOX>,
1787    ArrayView2<'b, SCORE>,
1788    ArrayView2<'c, MASK>,
1789) {
1790    let boxes_tensor = boxes_tensor.reversed_axes();
1791    let scores_tensor = scores_tensor.reversed_axes();
1792    let mask_tensor = mask_tensor.reversed_axes();
1793    (boxes_tensor, scores_tensor, mask_tensor)
1794}
1795
1796fn decode_segdet_f32<
1797    MASK: Float + AsPrimitive<f32> + Send + Sync,
1798    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1799>(
1800    boxes: Vec<(DetectBox, usize)>,
1801    masks: ArrayView2<MASK>,
1802    protos: ArrayView3<PROTO>,
1803) -> Result<Vec<(DetectBox, BoundingBox, Array3<u8>)>, crate::DecoderError> {
1804    if boxes.is_empty() {
1805        return Ok(Vec::new());
1806    }
1807    if masks.shape()[1] != protos.shape()[2] {
1808        return Err(crate::DecoderError::InvalidShape(format!(
1809            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1810            masks.shape()[1],
1811            protos.shape()[2],
1812        )));
1813    }
1814    boxes
1815        .into_par_iter()
1816        .map(|b| {
1817            let ind = b.1;
1818            // `protobox` returns the cropped proto slice for `make_segmentation`
1819            // and a `roi` snapped to the 1/proto-grid step. The detection bbox
1820            // stays untouched (EDGEAI-1304); the snapped roi is reported back
1821            // separately so callers can describe where the cropped mask lives.
1822            let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1823            Ok((b.0, roi, make_segmentation(masks.row(ind), protos.view())))
1824        })
1825        .collect()
1826}
1827
1828pub(crate) fn decode_segdet_quant<
1829    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1830    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1831>(
1832    boxes: Vec<(DetectBox, usize)>,
1833    masks: ArrayView2<MASK>,
1834    protos: ArrayView3<PROTO>,
1835    quant_masks: Quantization,
1836    quant_protos: Quantization,
1837) -> Result<Vec<(DetectBox, BoundingBox, Array3<u8>)>, crate::DecoderError> {
1838    if boxes.is_empty() {
1839        return Ok(Vec::new());
1840    }
1841    if masks.shape()[1] != protos.shape()[2] {
1842        return Err(crate::DecoderError::InvalidShape(format!(
1843            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1844            masks.shape()[1],
1845            protos.shape()[2],
1846        )));
1847    }
1848
1849    let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; // 32 protos is 2^5
1850    boxes
1851        .into_iter()
1852        .map(|b| {
1853            let i = b.1;
1854            // See EDGEAI-1304: the caller's bbox stays untouched; the
1855            // proto-grid-snapped `roi` is reported back so the Segmentation's
1856            // bounds can describe the actual cropped mask region.
1857            let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1858            let seg = match total_bits {
1859                0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1860                    masks.row(i),
1861                    protos.view(),
1862                    quant_masks,
1863                    quant_protos,
1864                ),
1865                65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1866                    masks.row(i),
1867                    protos.view(),
1868                    quant_masks,
1869                    quant_protos,
1870                ),
1871                _ => {
1872                    return Err(crate::DecoderError::NotSupported(format!(
1873                        "Unsupported bit width ({total_bits}) for segmentation computation"
1874                    )));
1875                }
1876            };
1877            Ok((b.0, roi, seg))
1878        })
1879        .collect()
1880}
1881
1882fn protobox<'a, T>(
1883    protos: &'a ArrayView3<T>,
1884    roi: &BoundingBox,
1885) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1886    let width = protos.dim().1 as f32;
1887    let height = protos.dim().0 as f32;
1888
1889    // Detect un-normalized bounding boxes (pixel-space coordinates).
1890    // protobox expects normalized coordinates in [0, 1]. The decoder will
1891    // normalize pixel-space coords automatically when the schema declares
1892    // `Detection::normalized = false` AND model input dimensions are known
1893    // (EDGEAI-1303); reaching this guard means at least one of those is
1894    // missing.
1895    //
1896    // The limit is set to 2.0 (not 1.01) because YOLO models legitimately
1897    // predict coordinates slightly > 1.0 for objects near frame edges.
1898    // Any value > 2.0 is clearly pixel-space (even the smallest practical
1899    // model input of 32×32 would produce coordinates >> 2.0).
1900    const NORM_LIMIT: f32 = 2.0;
1901    if roi.xmin > NORM_LIMIT
1902        || roi.ymin > NORM_LIMIT
1903        || roi.xmax > NORM_LIMIT
1904        || roi.ymax > NORM_LIMIT
1905    {
1906        return Err(crate::DecoderError::InvalidShape(format!(
1907            "Bounding box coordinates appear un-normalized (pixel-space). \
1908             Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
1909             Two ways to fix this: \
1910             (1) declare `Detection::normalized = false` in the model schema \
1911             AND make sure the schema's `input.shape` / `input.dshape` carries \
1912             the model input dims so the decoder can divide by (W, H) before NMS \
1913             (EDGEAI-1303 — verify with `Decoder::input_dims().is_some()`); or \
1914             (2) normalize the boxes in-graph before decode().",
1915            roi.xmin, roi.ymin, roi.xmax, roi.ymax,
1916        )));
1917    }
1918
1919    let roi = [
1920        (roi.xmin * width).clamp(0.0, width) as usize,
1921        (roi.ymin * height).clamp(0.0, height) as usize,
1922        (roi.xmax * width).clamp(0.0, width).ceil() as usize,
1923        (roi.ymax * height).clamp(0.0, height).ceil() as usize,
1924    ];
1925
1926    let roi_norm = [
1927        roi[0] as f32 / width,
1928        roi[1] as f32 / height,
1929        roi[2] as f32 / width,
1930        roi[3] as f32 / height,
1931    ]
1932    .into();
1933
1934    let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
1935
1936    Ok((cropped, roi_norm))
1937}
1938
1939/// Compute a single instance segmentation mask from mask coefficients and
1940/// proto maps (float path).
1941///
1942/// Computes `sigmoid(coefficients · protos)` and maps to `[0, 255]`.
1943/// Returns an `(H, W, 1)` u8 array.
1944fn make_segmentation<
1945    MASK: Float + AsPrimitive<f32> + Send + Sync,
1946    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1947>(
1948    mask: ArrayView1<MASK>,
1949    protos: ArrayView3<PROTO>,
1950) -> Array3<u8> {
1951    let shape = protos.shape();
1952
1953    // Safe to unwrap since the shapes will always be compatible
1954    let mask = mask.to_shape((1, mask.len())).unwrap();
1955    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1956    let protos = protos.reversed_axes();
1957    let mask = mask.map(|x| x.as_());
1958    let protos = protos.map(|x| x.as_());
1959
1960    // Safe to unwrap since the shapes will always be compatible
1961    let mask = mask
1962        .dot(&protos)
1963        .into_shape_with_order((shape[0], shape[1], 1))
1964        .unwrap();
1965
1966    mask.map(|x| {
1967        let sigmoid = 1.0 / (1.0 + (-*x).exp());
1968        (sigmoid * 255.0).round() as u8
1969    })
1970}
1971
1972/// Compute a single instance segmentation mask from quantized mask
1973/// coefficients and proto maps.
1974///
1975/// Dequantizes both inputs (subtracting zero-points), computes the dot
1976/// product, applies sigmoid, and maps to `[0, 255]`.
1977/// Returns an `(H, W, 1)` u8 array.
1978fn make_segmentation_quant<
1979    MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
1980    PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
1981    DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
1982>(
1983    mask: ArrayView1<MASK>,
1984    protos: ArrayView3<PROTO>,
1985    quant_masks: Quantization,
1986    quant_protos: Quantization,
1987) -> Array3<u8>
1988where
1989    i32: AsPrimitive<DEST>,
1990    f32: AsPrimitive<DEST>,
1991{
1992    let shape = protos.shape();
1993
1994    // Safe to unwrap since the shapes will always be compatible
1995    let mask = mask.to_shape((1, mask.len())).unwrap();
1996
1997    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1998    let protos = protos.reversed_axes();
1999
2000    let zp = quant_masks.zero_point.as_();
2001
2002    let mask = mask.mapv(|x| x.as_() - zp);
2003
2004    let zp = quant_protos.zero_point.as_();
2005    let protos = protos.mapv(|x| x.as_() - zp);
2006
2007    // Safe to unwrap since the shapes will always be compatible
2008    let segmentation = mask
2009        .dot(&protos)
2010        .into_shape_with_order((shape[0], shape[1], 1))
2011        .unwrap();
2012
2013    let combined_scale = quant_masks.scale * quant_protos.scale;
2014    segmentation.map(|x| {
2015        let val: f32 = (*x).as_() * combined_scale;
2016        let sigmoid = 1.0 / (1.0 + (-val).exp());
2017        (sigmoid * 255.0).round() as u8
2018    })
2019}
2020
2021/// Converts Yolo Instance Segmentation into a 2D mask.
2022///
2023/// The input segmentation is expected to have shape (H, W, 1).
2024///
2025/// The output mask will have shape (H, W), with values 0 or 1 based on the
2026/// threshold.
2027///
2028/// # Errors
2029///
2030/// Returns `DecoderError::InvalidShape` if the input segmentation does not
2031/// have shape (H, W, 1).
2032pub(crate) fn yolo_segmentation_to_mask(
2033    segmentation: ArrayView3<u8>,
2034    threshold: u8,
2035) -> Result<Array2<u8>, crate::DecoderError> {
2036    if segmentation.shape()[2] != 1 {
2037        return Err(crate::DecoderError::InvalidShape(format!(
2038            "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
2039            segmentation.shape()[2]
2040        )));
2041    }
2042    Ok(segmentation
2043        .slice(s![.., .., 0])
2044        .map(|x| if *x >= threshold { 1 } else { 0 }))
2045}
2046
2047#[cfg(test)]
2048#[cfg_attr(coverage_nightly, coverage(off))]
2049mod tests {
2050    use super::*;
2051    use ndarray::Array2;
2052
2053    // ========================================================================
2054    // Tests for decode_yolo_end_to_end_det_float
2055    // ========================================================================
2056
2057    #[test]
2058    fn test_end_to_end_det_basic_filtering() {
2059        // Create synthetic end-to-end detection output: (6, N) where rows are
2060        // [x1, y1, x2, y2, conf, class]
2061        // 3 detections: one above threshold, two below
2062        let data: Vec<f32> = vec![
2063            // Detection 0: high score (0.9)
2064            0.1, 0.2, 0.3, // x1 values
2065            0.1, 0.2, 0.3, // y1 values
2066            0.5, 0.6, 0.7, // x2 values
2067            0.5, 0.6, 0.7, // y2 values
2068            0.9, 0.1, 0.2, // confidence scores
2069            0.0, 1.0, 2.0, // class indices
2070        ];
2071        let output = Array2::from_shape_vec((6, 3), data).unwrap();
2072
2073        let mut boxes = Vec::with_capacity(10);
2074        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2075
2076        // Only 1 detection should pass threshold of 0.5
2077        assert_eq!(boxes.len(), 1);
2078        assert_eq!(boxes[0].label, 0);
2079        assert!((boxes[0].score - 0.9).abs() < 0.01);
2080        assert!((boxes[0].bbox.xmin - 0.1).abs() < 0.01);
2081        assert!((boxes[0].bbox.ymin - 0.1).abs() < 0.01);
2082        assert!((boxes[0].bbox.xmax - 0.5).abs() < 0.01);
2083        assert!((boxes[0].bbox.ymax - 0.5).abs() < 0.01);
2084    }
2085
2086    #[test]
2087    fn test_end_to_end_det_all_pass_threshold() {
2088        // All detections above threshold
2089        let data: Vec<f32> = vec![
2090            10.0, 20.0, // x1
2091            10.0, 20.0, // y1
2092            50.0, 60.0, // x2
2093            50.0, 60.0, // y2
2094            0.8, 0.7, // conf (both above 0.5)
2095            1.0, 2.0, // class
2096        ];
2097        let output = Array2::from_shape_vec((6, 2), data).unwrap();
2098
2099        let mut boxes = Vec::with_capacity(10);
2100        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2101
2102        assert_eq!(boxes.len(), 2);
2103        assert_eq!(boxes[0].label, 1);
2104        assert_eq!(boxes[1].label, 2);
2105    }
2106
2107    #[test]
2108    fn test_end_to_end_det_none_pass_threshold() {
2109        // All detections below threshold
2110        let data: Vec<f32> = vec![
2111            10.0, 20.0, // x1
2112            10.0, 20.0, // y1
2113            50.0, 60.0, // x2
2114            50.0, 60.0, // y2
2115            0.1, 0.2, // conf (both below 0.5)
2116            1.0, 2.0, // class
2117        ];
2118        let output = Array2::from_shape_vec((6, 2), data).unwrap();
2119
2120        let mut boxes = Vec::with_capacity(10);
2121        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2122
2123        assert_eq!(boxes.len(), 0);
2124    }
2125
2126    #[test]
2127    fn test_end_to_end_det_capacity_limit() {
2128        // Test that output is truncated to capacity
2129        let data: Vec<f32> = vec![
2130            0.1, 0.2, 0.3, 0.4, 0.5, // x1
2131            0.1, 0.2, 0.3, 0.4, 0.5, // y1
2132            0.5, 0.6, 0.7, 0.8, 0.9, // x2
2133            0.5, 0.6, 0.7, 0.8, 0.9, // y2
2134            0.9, 0.9, 0.9, 0.9, 0.9, // conf (all pass)
2135            0.0, 1.0, 2.0, 3.0, 4.0, // class
2136        ];
2137        let output = Array2::from_shape_vec((6, 5), data).unwrap();
2138
2139        let mut boxes = Vec::with_capacity(2); // Only allow 2 boxes
2140        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2141
2142        assert_eq!(boxes.len(), 2);
2143    }
2144
2145    #[test]
2146    fn test_end_to_end_det_empty_output() {
2147        // Test with zero detections
2148        let output = Array2::<f32>::zeros((6, 0));
2149
2150        let mut boxes = Vec::with_capacity(10);
2151        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2152
2153        assert_eq!(boxes.len(), 0);
2154    }
2155
2156    #[test]
2157    fn test_end_to_end_det_pixel_coordinates() {
2158        // Test with pixel coordinates (non-normalized)
2159        let data: Vec<f32> = vec![
2160            100.0, // x1
2161            200.0, // y1
2162            300.0, // x2
2163            400.0, // y2
2164            0.95,  // conf
2165            5.0,   // class
2166        ];
2167        let output = Array2::from_shape_vec((6, 1), data).unwrap();
2168
2169        let mut boxes = Vec::with_capacity(10);
2170        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2171
2172        assert_eq!(boxes.len(), 1);
2173        assert_eq!(boxes[0].label, 5);
2174        assert!((boxes[0].bbox.xmin - 100.0).abs() < 0.01);
2175        assert!((boxes[0].bbox.ymin - 200.0).abs() < 0.01);
2176        assert!((boxes[0].bbox.xmax - 300.0).abs() < 0.01);
2177        assert!((boxes[0].bbox.ymax - 400.0).abs() < 0.01);
2178    }
2179
2180    #[test]
2181    fn test_end_to_end_det_invalid_shape() {
2182        // Test with too few rows (needs at least 6)
2183        let output = Array2::<f32>::zeros((5, 3));
2184
2185        let mut boxes = Vec::with_capacity(10);
2186        let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);
2187
2188        assert!(result.is_err());
2189        assert!(matches!(
2190            result,
2191            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
2192        ));
2193    }
2194
2195    // ========================================================================
2196    // Tests for decode_yolo_end_to_end_segdet_float
2197    // ========================================================================
2198
2199    #[test]
2200    fn test_end_to_end_segdet_basic() {
2201        // Create synthetic segdet output: (6 + num_protos, N)
2202        // Detection format: [x1, y1, x2, y2, conf, class, mask_coeff_0..31]
2203        let num_protos = 32;
2204        let num_detections = 2;
2205        let num_features = 6 + num_protos;
2206
2207        // Build detection tensor
2208        let mut data = vec![0.0f32; num_features * num_detections];
2209        // Detection 0: passes threshold
2210        data[0] = 0.1; // x1[0]
2211        data[1] = 0.5; // x1[1]
2212        data[num_detections] = 0.1; // y1[0]
2213        data[num_detections + 1] = 0.5; // y1[1]
2214        data[2 * num_detections] = 0.4; // x2[0]
2215        data[2 * num_detections + 1] = 0.9; // x2[1]
2216        data[3 * num_detections] = 0.4; // y2[0]
2217        data[3 * num_detections + 1] = 0.9; // y2[1]
2218        data[4 * num_detections] = 0.9; // conf[0] - passes
2219        data[4 * num_detections + 1] = 0.3; // conf[1] - fails
2220        data[5 * num_detections] = 1.0; // class[0]
2221        data[5 * num_detections + 1] = 2.0; // class[1]
2222                                            // Fill mask coefficients with small values
2223        for i in 6..num_features {
2224            data[i * num_detections] = 0.1;
2225            data[i * num_detections + 1] = 0.1;
2226        }
2227
2228        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2229
2230        // Create protos tensor: (proto_height, proto_width, num_protos)
2231        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2232
2233        let mut boxes = Vec::with_capacity(10);
2234        let mut masks = Vec::with_capacity(10);
2235        decode_yolo_end_to_end_segdet_float(
2236            output.view(),
2237            protos.view(),
2238            0.5,
2239            &mut boxes,
2240            &mut masks,
2241        )
2242        .unwrap();
2243
2244        // Only detection 0 should pass
2245        assert_eq!(boxes.len(), 1);
2246        assert_eq!(masks.len(), 1);
2247        assert_eq!(boxes[0].label, 1);
2248        assert!((boxes[0].score - 0.9).abs() < 0.01);
2249    }
2250
2251    #[test]
2252    fn test_end_to_end_segdet_mask_coordinates() {
2253        // Test that mask coordinates match box coordinates
2254        let num_protos = 32;
2255        let num_features = 6 + num_protos;
2256
2257        let mut data = vec![0.0f32; num_features];
2258        data[0] = 0.2; // x1
2259        data[1] = 0.2; // y1
2260        data[2] = 0.8; // x2
2261        data[3] = 0.8; // y2
2262        data[4] = 0.95; // conf
2263        data[5] = 3.0; // class
2264
2265        let output = Array2::from_shape_vec((num_features, 1), data).unwrap();
2266        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2267
2268        let mut boxes = Vec::with_capacity(10);
2269        let mut masks = Vec::with_capacity(10);
2270        decode_yolo_end_to_end_segdet_float(
2271            output.view(),
2272            protos.view(),
2273            0.5,
2274            &mut boxes,
2275            &mut masks,
2276        )
2277        .unwrap();
2278
2279        assert_eq!(boxes.len(), 1);
2280        assert_eq!(masks.len(), 1);
2281
2282        // Mask region is the proto-grid-aligned crop and encloses the
2283        // post-NMS bbox (EDGEAI-1304); on a 16x16 grid each side may snap
2284        // by up to 1/16 = 0.0625.
2285        let step = 1.0 / 16.0;
2286        assert!(masks[0].xmin <= boxes[0].bbox.xmin);
2287        assert!(masks[0].ymin <= boxes[0].bbox.ymin);
2288        assert!(masks[0].xmax >= boxes[0].bbox.xmax);
2289        assert!(masks[0].ymax >= boxes[0].bbox.ymax);
2290        assert!((boxes[0].bbox.xmin - masks[0].xmin) < step);
2291        assert!((boxes[0].bbox.ymin - masks[0].ymin) < step);
2292        assert!((masks[0].xmax - boxes[0].bbox.xmax) < step);
2293        assert!((masks[0].ymax - boxes[0].bbox.ymax) < step);
2294    }
2295
2296    #[test]
2297    fn test_end_to_end_segdet_empty_output() {
2298        let num_protos = 32;
2299        let output = Array2::<f32>::zeros((6 + num_protos, 0));
2300        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2301
2302        let mut boxes = Vec::with_capacity(10);
2303        let mut masks = Vec::with_capacity(10);
2304        decode_yolo_end_to_end_segdet_float(
2305            output.view(),
2306            protos.view(),
2307            0.5,
2308            &mut boxes,
2309            &mut masks,
2310        )
2311        .unwrap();
2312
2313        assert_eq!(boxes.len(), 0);
2314        assert_eq!(masks.len(), 0);
2315    }
2316
2317    #[test]
2318    fn test_end_to_end_segdet_capacity_limit() {
2319        let num_protos = 32;
2320        let num_detections = 5;
2321        let num_features = 6 + num_protos;
2322
2323        let mut data = vec![0.0f32; num_features * num_detections];
2324        // All detections pass threshold
2325        for i in 0..num_detections {
2326            data[i] = 0.1 * (i as f32); // x1
2327            data[num_detections + i] = 0.1 * (i as f32); // y1
2328            data[2 * num_detections + i] = 0.1 * (i as f32) + 0.2; // x2
2329            data[3 * num_detections + i] = 0.1 * (i as f32) + 0.2; // y2
2330            data[4 * num_detections + i] = 0.9; // conf
2331            data[5 * num_detections + i] = i as f32; // class
2332        }
2333
2334        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2335        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2336
2337        let mut boxes = Vec::with_capacity(2); // Limit to 2
2338        let mut masks = Vec::with_capacity(2);
2339        decode_yolo_end_to_end_segdet_float(
2340            output.view(),
2341            protos.view(),
2342            0.5,
2343            &mut boxes,
2344            &mut masks,
2345        )
2346        .unwrap();
2347
2348        assert_eq!(boxes.len(), 2);
2349        assert_eq!(masks.len(), 2);
2350    }
2351
2352    #[test]
2353    fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
2354        // Test with too few rows (needs at least 7: 6 base + 1 mask coeff)
2355        let output = Array2::<f32>::zeros((6, 3));
2356        let protos = Array3::<f32>::zeros((16, 16, 32));
2357
2358        let mut boxes = Vec::with_capacity(10);
2359        let mut masks = Vec::with_capacity(10);
2360        let result = decode_yolo_end_to_end_segdet_float(
2361            output.view(),
2362            protos.view(),
2363            0.5,
2364            &mut boxes,
2365            &mut masks,
2366        );
2367
2368        assert!(result.is_err());
2369        assert!(matches!(
2370            result,
2371            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
2372        ));
2373    }
2374
2375    #[test]
2376    fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
2377        // Test with mismatched mask coefficients and protos count
2378        let num_protos = 32;
2379        let output = Array2::<f32>::zeros((6 + 16, 3)); // 16 mask coeffs
2380        let protos = Array3::<f32>::zeros((16, 16, num_protos)); // 32 protos
2381
2382        let mut boxes = Vec::with_capacity(10);
2383        let mut masks = Vec::with_capacity(10);
2384        let result = decode_yolo_end_to_end_segdet_float(
2385            output.view(),
2386            protos.view(),
2387            0.5,
2388            &mut boxes,
2389            &mut masks,
2390        );
2391
2392        assert!(result.is_err());
2393        assert!(matches!(
2394            result,
2395            Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
2396        ));
2397    }
2398
2399    // ========================================================================
2400    // Tests for decode_yolo_split_end_to_end_segdet_float
2401    // ========================================================================
2402
2403    #[test]
2404    fn test_split_end_to_end_segdet_basic() {
2405        // Create synthetic segdet output: (6 + num_protos, N)
2406        // Detection format: [x1, y1, x2, y2, conf, class, mask_coeff_0..31]
2407        let num_protos = 32;
2408        let num_detections = 2;
2409        let num_features = 6 + num_protos;
2410
2411        // Build detection tensor
2412        let mut data = vec![0.0f32; num_features * num_detections];
2413        // Detection 0: passes threshold
2414        data[0] = 0.1; // x1[0]
2415        data[1] = 0.5; // x1[1]
2416        data[num_detections] = 0.1; // y1[0]
2417        data[num_detections + 1] = 0.5; // y1[1]
2418        data[2 * num_detections] = 0.4; // x2[0]
2419        data[2 * num_detections + 1] = 0.9; // x2[1]
2420        data[3 * num_detections] = 0.4; // y2[0]
2421        data[3 * num_detections + 1] = 0.9; // y2[1]
2422        data[4 * num_detections] = 0.9; // conf[0] - passes
2423        data[4 * num_detections + 1] = 0.3; // conf[1] - fails
2424        data[5 * num_detections] = 1.0; // class[0]
2425        data[5 * num_detections + 1] = 2.0; // class[1]
2426                                            // Fill mask coefficients with small values
2427        for i in 6..num_features {
2428            data[i * num_detections] = 0.1;
2429            data[i * num_detections + 1] = 0.1;
2430        }
2431
2432        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2433        let box_coords = output.slice(s![..4, ..]);
2434        let scores = output.slice(s![4..5, ..]);
2435        let classes = output.slice(s![5..6, ..]);
2436        let mask_coeff = output.slice(s![6.., ..]);
2437        // Create protos tensor: (proto_height, proto_width, num_protos)
2438        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2439
2440        let mut boxes = Vec::with_capacity(10);
2441        let mut masks = Vec::with_capacity(10);
2442        decode_yolo_split_end_to_end_segdet_float(
2443            box_coords,
2444            scores,
2445            classes,
2446            mask_coeff,
2447            protos.view(),
2448            0.5,
2449            &mut boxes,
2450            &mut masks,
2451        )
2452        .unwrap();
2453
2454        // Only detection 0 should pass
2455        assert_eq!(boxes.len(), 1);
2456        assert_eq!(masks.len(), 1);
2457        assert_eq!(boxes[0].label, 1);
2458        assert!((boxes[0].score - 0.9).abs() < 0.01);
2459    }
2460
2461    // ========================================================================
2462    // Tests for yolo_segmentation_to_mask
2463    // ========================================================================
2464
2465    #[test]
2466    fn test_segmentation_to_mask_basic() {
2467        // Create a 4x4x1 segmentation with values above and below threshold
2468        let data: Vec<u8> = vec![
2469            100, 200, 50, 150, // row 0
2470            10, 255, 128, 64, // row 1
2471            0, 127, 128, 255, // row 2
2472            64, 64, 192, 192, // row 3
2473        ];
2474        let segmentation = Array3::from_shape_vec((4, 4, 1), data).unwrap();
2475
2476        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2477
2478        // Values >= 128 should be 1, others 0
2479        assert_eq!(mask[[0, 0]], 0); // 100 < 128
2480        assert_eq!(mask[[0, 1]], 1); // 200 >= 128
2481        assert_eq!(mask[[0, 2]], 0); // 50 < 128
2482        assert_eq!(mask[[0, 3]], 1); // 150 >= 128
2483        assert_eq!(mask[[1, 1]], 1); // 255 >= 128
2484        assert_eq!(mask[[1, 2]], 1); // 128 >= 128
2485        assert_eq!(mask[[2, 0]], 0); // 0 < 128
2486        assert_eq!(mask[[2, 1]], 0); // 127 < 128
2487    }
2488
2489    #[test]
2490    fn test_segmentation_to_mask_all_above() {
2491        let segmentation = Array3::from_elem((4, 4, 1), 255u8);
2492        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2493        assert!(mask.iter().all(|&x| x == 1));
2494    }
2495
/// A segmentation filled with 64 and thresholded at 128 must produce a mask
/// made up entirely of zeros.
#[test]
fn test_segmentation_to_mask_all_below() {
    let dim = Array3::from_elem((4, 4, 1), 64u8);
    let mask = yolo_segmentation_to_mask(dim.view(), 128).unwrap();
    // No pixel may hold anything other than 0.
    assert!(!mask.iter().any(|&px| px != 0));
}
2502
/// A three-channel input is not a valid segmentation: the call must fail with
/// `DecoderError::InvalidShape` and the message must mention "(H, W, 1)".
#[test]
fn test_segmentation_to_mask_invalid_shape() {
    let three_channel = Array3::from_elem((4, 4, 3), 128u8);
    let outcome = yolo_segmentation_to_mask(three_channel.view(), 128);

    assert!(outcome.is_err());

    // Only the InvalidShape variant carrying the expected layout hint counts.
    let mentions_expected_layout = match outcome {
        Err(crate::DecoderError::InvalidShape(msg)) => msg.contains("(H, W, 1)"),
        _ => false,
    };
    assert!(mentions_expected_layout);
}
2514
2515    // ========================================================================
2516    // Tests for protobox / NORM_LIMIT regression
2517    // ========================================================================
2518
/// An ROI whose far edges sit exactly at 1.0 must be accepted (regression
/// guard against an out-of-bounds crop) and yield a non-empty crop that keeps
/// all four proto channels.
#[test]
fn test_protobox_clamps_edge_coordinates() {
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let view = protos.view();
    let roi = BoundingBox {
        xmin: 0.5,
        ymin: 0.5,
        xmax: 1.0,
        ymax: 1.0,
    };

    let (cropped, _roi_norm) = protobox(&view, &roi).expect("protobox should accept xmax=1.0");

    let dims = cropped.shape();
    // Both spatial axes must be non-empty.
    assert!(dims[0] > 0);
    assert!(dims[1] > 0);
    // The channel axis is untouched by the crop.
    assert_eq!(dims[2], 4);
}
2538
/// An ROI with coordinates above NORM_LIMIT (here 3.0) must be rejected with
/// an `InvalidShape` error whose message mentions "un-normalized".
#[test]
fn test_protobox_rejects_wildly_out_of_range() {
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let view = protos.view();
    let roi = BoundingBox {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 3.0,
        ymax: 3.0,
    };

    // Any other outcome (success or a different error) counts as a failure.
    let rejected_as_unnormalized = match protobox(&view, &roi) {
        Err(crate::DecoderError::InvalidShape(msg)) => msg.contains("un-normalized"),
        _ => false,
    };
    assert!(
        rejected_as_unnormalized,
        "protobox should reject coords > NORM_LIMIT"
    );
}
2556
/// An ROI reaching 1.5 (inside NORM_LIMIT = 2.0) must be accepted, and the
/// crop must then cover the whole 16x16 proto map.
#[test]
fn test_protobox_accepts_slightly_over_one() {
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let view = protos.view();
    let roi = BoundingBox {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 1.5,
        ymax: 1.5,
    };

    let (cropped, _roi_norm) =
        protobox(&view, &roi).expect("protobox should accept coords <= NORM_LIMIT (2.0)");

    // Coordinates above 1.0 are clamped to the boundary, so both spatial
    // axes keep their full extent.
    for axis in 0..2 {
        assert_eq!(cropped.shape()[axis], 16);
    }
}
2578
/// Regression test: decoding a YOLOv8n-seg style tensor with more proposals
/// than mask coefficients must not panic, and must return one coefficient
/// vector of the right length per surviving detection.
#[test]
fn test_segdet_float_proto_no_panic() {
    // Mirrors YOLOv8n-seg: output0 = [116, N] (4 box + 80 class + 32 mask
    // coefficients), output1 (protos) = [32, 160, 160].
    let num_proposals = 100; // more proposals than mask coefficients (idx >= 32)
    let num_classes = 80;
    let num_mask_coeffs = 32;
    let rows = 4 + num_classes + num_mask_coeffs; // 116

    // The tensor is [rows, num_proposals] row-major and every proposal is
    // identical, so each row holds one repeated value:
    //   rows 0/1 = cx/cy, rows 2/3 = w/h, row 4 = class-0 score,
    //   remaining class rows and all mask-coefficient rows stay zero.
    let row_value = |r: usize| -> f32 {
        match r {
            0 | 1 => 320.0,
            2 | 3 => 50.0,
            4 => 0.9,
            _ => 0.0,
        }
    };
    let data: Vec<f32> = (0..rows)
        .flat_map(|r| vec![row_value(r); num_proposals])
        .collect();
    let boxes = ndarray::Array2::from_shape_vec((rows, num_proposals), data).unwrap();

    // Protos are supplied in HWC order. Under the HAL physical-order
    // contract, callers declare shape+dshape matching producer memory and
    // swap_axes_if_needed permutes the strides into canonical
    // [batch, height, width, num_protos] before this function sees them.
    let protos = ndarray::Array3::<f32>::zeros((160, 160, num_mask_coeffs));

    let mut output_boxes = Vec::with_capacity(300);

    // Before the fix this call panicked on mask_tensor.row(idx) with idx >= 32.
    let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
        boxes.view(),
        protos.view(),
        0.5,
        0.7,
        Some(Nms::default()),
        MAX_NMS_CANDIDATES,
        300,
        None,
        None,
        &mut output_boxes,
    );

    // NMS collapses the overlapping boxes but at least one must remain.
    assert!(!output_boxes.is_empty());

    // One coefficient row per detection, each with `num_mask_coeffs` entries.
    let coeffs_shape = proto_data.mask_coefficients.shape();
    assert_eq!(coeffs_shape[0], output_boxes.len());
    assert_eq!(coeffs_shape[1], num_mask_coeffs);
}
2631
2632    // ========================================================================
2633    // Pre-NMS top-K cap (MAX_NMS_CANDIDATES)
2634    // ========================================================================
2635
/// At very low score thresholds (e.g., t=0.01 on YOLOv8 with 8400×80
/// candidates) almost every score passes the filter, feeding O(n²)
/// NMS and a per-survivor mask matmul. The decoder caps the
/// candidate set fed to NMS at `MAX_NMS_CANDIDATES` (Ultralytics
/// default 30 000) to bound worst-case decode time.
///
/// This regression test pumps 50 000 above-threshold candidates
/// into `impl_yolo_segdet_get_boxes` with class-agnostic NMS enabled
/// at an IoU threshold of 1.0 (so NMS itself suppresses nothing and
/// the pre-NMS cap still applies) and no post-NMS truncation. Before
/// the fix, the function returned all 50 000; after the fix, exactly
/// `MAX_NMS_CANDIDATES`, and they must be the highest-scoring ones.
#[test]
fn test_pre_nms_cap_truncates_excess_candidates() {
    let n: usize = 50_000;
    let num_classes = 1;
    let cap = crate::yolo::MAX_NMS_CANDIDATES;

    // Identical valid boxes. Scores are strictly decreasing with the
    // candidate index (score_i = 0.99 - i * 1e-7, all far above the 0.1
    // threshold), so a top-K-by-score cut keeps exactly the first `cap`
    // candidates and we can assert *which* ones survived.
    let boxes_data = [0.1f32, 0.1, 0.5, 0.5].repeat(n);
    let scores_data: Vec<f32> = (0..n).map(|i| 0.99 - (i as f32) * 1e-7).collect();
    let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
    let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();

    let result = impl_yolo_segdet_get_boxes::<XYXY, _, _>(
        boxes.view(),
        scores.view(),
        0.1,
        1.0,                      // IoU 1.0 → NMS suppresses nothing
        Some(Nms::ClassAgnostic), // NMS enabled so pre_nms_top_k applies
        cap,                      // pre_nms_top_k
        usize::MAX,               // no post-NMS truncation
    );

    assert_eq!(
        result.len(),
        cap,
        "pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
        result.len()
    );

    // The highest scores belong to the first `cap` candidate indices, so
    // survivor 0 must have score ~0.99.
    let top_score = result[0].0.score;
    assert!(
        top_score > 0.98,
        "highest-ranked survivor should have the largest score, got {top_score}"
    );

    // Every input score exceeds 0.98, so the check above cannot tell the
    // kept candidates from the dropped ones. Pin the selection itself: no
    // survivor may score below the candidate at rank `cap - 1`. This is
    // independent of the order in which survivors are returned.
    let cutoff = scores[[cap - 1, 0]];
    let lowest = result
        .iter()
        .map(|survivor| survivor.0.score)
        .fold(f32::INFINITY, f32::min);
    assert!(
        lowest >= cutoff,
        "pre-NMS cap must keep the top-scoring candidates; lowest survivor {lowest} < cutoff {cutoff}"
    );
}
2689
/// Quantized split-path counterpart of the pre-NMS cap test: feed more than
/// `MAX_NMS_CANDIDATES` candidates whose scores all clear the threshold and
/// confirm `impl_yolo_split_segdet_quant_get_boxes` returns exactly
/// `MAX_NMS_CANDIDATES` of them.
#[test]
fn test_pre_nms_cap_truncates_excess_candidates_quant() {
    use crate::Quantization;
    let n: usize = 50_000;
    let num_classes = 1;

    // Flat set of identical i8 boxes. With scale 0.01 and zero point 0 the
    // raw values 10 / 50 dequantize to 0.1 / 0.5.
    let boxes = Array2::from_shape_vec((n, 4), [10i8, 10, 50, 50].repeat(n)).unwrap();
    let quant_boxes = Quantization {
        scale: 0.01,
        zero_point: 0,
    };

    // u8 scores cycling through 250, 249, ..., 51 (period 200). With scale
    // 0.00392 and zero point 0 that spans roughly 0.98 down to 0.20, so
    // every candidate is above the 0.1 score threshold passed below.
    // `i % 200` is at most 199, so the subtraction never drops below 51.
    let scores_data: Vec<u8> = (0..n).map(|i| (250 - i % 200) as u8).collect();
    let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();
    let quant_scores = Quantization {
        scale: 0.00392,
        zero_point: 0,
    };

    let result = impl_yolo_split_segdet_quant_get_boxes::<XYXY, _, _>(
        (boxes.view(), quant_boxes),
        (scores.view(), quant_scores),
        0.1,
        1.0,                             // IoU 1.0 → NMS suppresses nothing
        Some(Nms::ClassAgnostic),        // NMS enabled so pre_nms_top_k applies
        crate::yolo::MAX_NMS_CANDIDATES, // pre_nms_top_k
        usize::MAX,                      // no post-NMS truncation
    );

    let survivors = result.len();
    assert_eq!(
        survivors,
        crate::yolo::MAX_NMS_CANDIDATES,
        "quant path pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
        survivors
    );
}
2739
/// Regression test for HAILORT_BUG.md — the YoloSegDet path
/// (combined `(4 + nc + nm, N)` detection tensor + separate protos)
/// must pair each surviving detection with the mask coefficient
/// row at the SAME anchor index the box came from. The validator
/// sees this path miss the pairing under schema-v2 Hailo inputs
/// (mAP collapse from 46.8 → 3.65 while mask IoU stays at 66.9,
/// the fingerprint of mask-to-detection misalignment).
///
/// Construction: three anchors with distinct mask-coef signatures
/// that, after dot(coefs, protos) + sigmoid, produce HIGH vs LOW
/// mask pixel values. Two anchors survive (one high, one low); if
/// the mask row is looked up at the wrong index, the per-detection
/// mean mask value would cross the threshold and we catch it.
///
/// The test also requires one mask per detection and that both
/// surviving anchors are actually observed, so a short or empty mask
/// list cannot make the pairing loop pass vacuously.
#[test]
fn segdet_combined_tensor_pairs_detection_with_matching_mask_row() {
    let nc = 2; // num_classes
    let nm = 2; // num_protos
    let n = 3; // num_anchors
    let feat = 4 + nc + nm; // 8

    // Tensor layout: (8, 3) rows=features, cols=anchors.
    // Row indices:  0..4 = xywh, 4..6 = scores, 6..8 = mask_coefs.
    //
    //         anchor 0 | anchor 1 | anchor 2
    // xc       0.2      | 0.5      | 0.8
    // yc       0.2      | 0.5      | 0.8
    // w        0.1      | 0.1      | 0.1
    // h        0.1      | 0.1      | 0.1
    // s[0]     0.9      | 0.0      | 0.8   (class 0)
    // s[1]     0.0      | 0.0      | 0.0   (class 1 — always loses)
    // m[0]     3.0      | 0.0      | -3.0  (high for a0, low for a2)
    // m[1]     3.0      | 0.0      | -3.0
    //
    // Proto[0] = Proto[1] = all-ones (8x8), so
    //   mask(a0) = sigmoid(3 + 3) ≈ 0.9975 → 254
    //   mask(a2) = sigmoid(-3 + -3) ≈ 0.0025 → 1
    // 250-point gap makes any misalignment trivially detectable.
    //
    // Each inner array below is one anchor's column, top to bottom.
    let anchors: [[f32; 8]; 3] = [
        [0.2, 0.2, 0.1, 0.1, 0.9, 0.0, 3.0, 3.0],
        [0.5, 0.5, 0.1, 0.1, 0.0, 0.0, 0.0, 0.0],
        [0.8, 0.8, 0.1, 0.1, 0.8, 0.0, -3.0, -3.0],
    ];
    let mut data = vec![0.0f32; feat * n];
    for (col, features) in anchors.iter().enumerate() {
        for (row, &value) in features.iter().enumerate() {
            // Row-major (feat, n): element (row, col) lives at row * n + col.
            data[row * n + col] = value;
        }
    }

    let output = Array2::from_shape_vec((feat, n), data).unwrap();
    let protos = Array3::<f32>::from_elem((8, 8, nm), 1.0);

    let mut boxes: Vec<DetectBox> = Vec::with_capacity(10);
    let mut masks: Vec<Segmentation> = Vec::with_capacity(10);
    decode_yolo_segdet_float(
        output.view(),
        protos.view(),
        0.5,
        0.5,
        Some(Nms::ClassAgnostic),
        &mut boxes,
        &mut masks,
    )
    .unwrap();

    assert_eq!(
        boxes.len(),
        2,
        "two anchors above threshold should survive (a0 score=0.9, a2 score=0.8); got {}",
        boxes.len()
    );
    // `zip` below stops at the shorter sequence, so without this check a
    // missing mask would silently skip the pairing assertions.
    assert_eq!(
        masks.len(),
        boxes.len(),
        "every detection must come with exactly one mask; got {} masks for {} boxes",
        masks.len(),
        boxes.len()
    );

    // Identify each detection by its centre: anchor 0 sits at (0.2, 0.2),
    // anchor 2 at (0.8, 0.8). The DetectBox bbox is the post-XYWH-to-XYXY
    // conversion of the original xywh; cropping inside protobox may shrink
    // it, so match by centre (0.2 vs 0.8) rather than exact bbox.
    let mut saw_anchor_0 = false;
    let mut saw_anchor_2 = false;
    for (b, m) in boxes.iter().zip(masks.iter()) {
        let cx = (b.bbox.xmin + b.bbox.xmax) * 0.5;
        let mean = {
            let s = &m.segmentation;
            let total: u32 = s.iter().map(|&v| v as u32).sum();
            total as f32 / s.len() as f32
        };
        if cx < 0.3 {
            // anchor 0 — expect HIGH mask values ≈ 254
            saw_anchor_0 = true;
            assert!(
                mean > 200.0,
                "anchor 0 detection (centre {cx:.2}) should have high-value mask; got mean {mean}"
            );
        } else if cx > 0.7 {
            // anchor 2 — expect LOW mask values ≈ 1
            saw_anchor_2 = true;
            assert!(
                mean < 50.0,
                "anchor 2 detection (centre {cx:.2}) should have low-value mask; got mean {mean}"
            );
        } else {
            panic!("unexpected detection centre {cx:.2}");
        }
    }
    assert!(
        saw_anchor_0 && saw_anchor_2,
        "both surviving anchors must be checked (anchor 0 seen: {saw_anchor_0}, anchor 2 seen: {saw_anchor_2})"
    );
}
2850}