Skip to main content

edgefirst_decoder/
yolo.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4//! Internal YOLO decoder kernels.
5//!
6//! All items in this module are `pub(crate)`. The public API surface for
7//! decoding is [`crate::Decoder`] + [`crate::DecoderBuilder`]; external
8//! callers must go through that entry point so we can evolve the kernels
9//! (split paths, dispatch tables, NEON tiers) without breaking semver.
10//! See `CHANGELOG.md` for the 0.20.0 narrowing.
11
12use std::fmt::Debug;
13
14use ndarray::{
15    parallel::prelude::{IntoParallelIterator, ParallelIterator},
16    s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
17};
18use num_traits::{AsPrimitive, Float, PrimInt, Signed};
19
20use crate::{
21    byte::{
22        nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
23        postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
24    },
25    configs::Nms,
26    dequant_detect_box,
27    float::{
28        nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
29        postprocess_boxes_float, postprocess_boxes_index_float,
30    },
31    BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoLayout,
32    Quantization, Segmentation, XYWH, XYXY,
33};
34
/// Maximum number of above-threshold candidates fed to NMS.
///
/// At very low score thresholds (e.g., t=0.01 on YOLOv8 with
/// 8400 anchors × 80 classes), the number of survivors approaches
/// the full 672 000-entry score grid. NMS is O(n²) and the
/// downstream mask matmul runs once per survivor, so an
/// unbounded set produces minutes-per-frame decode times.
///
/// `MAX_NMS_CANDIDATES` matches the Ultralytics `max_nms` default
/// and is applied as a top-K-by-score truncation immediately
/// before NMS. Candidates ranked beyond the cap are silently dropped —
/// at the score thresholds where the cap activates, the bottom of the
/// candidate list is dominated by noise that NMS would discard
/// anyway.
///
/// Production callers configure this via [`crate::DecoderBuilder::with_pre_nms_top_k`];
/// only the test-only seg-det wrappers reach for this constant directly.
#[cfg(test)]
pub(crate) const MAX_NMS_CANDIDATES: usize = 30_000;
54
/// Default post-NMS detection cap used by the public `decode_yolo_*`
/// convenience wrappers when no explicit cap is plumbed in, i.e. when the
/// caller's output `Vec` reports a capacity of zero (see `cap_or_default`).
/// Mirrors the `Decoder::max_det` default set by `DecoderBuilder` (also 300,
/// matching the Ultralytics `max_det` default). Pre-EDGEAI-1302 these wrappers
/// used `output_boxes.capacity()` as the cap, which silently dropped all
/// detections when the caller passed `Vec::new()`.
pub(crate) const DEFAULT_MAX_DETECTIONS: usize = 300;
62
63/// Truncate `boxes` to the highest-scoring `top_k` entries in-place when the
64/// input exceeds the cap. Uses partial sort (O(N)) via `select_nth_unstable_by`
65/// to avoid full O(N log N) sort. No-op when `top_k` is 0 (unbounded) or
66/// when the input length ≤ `top_k`.
67fn truncate_to_top_k_by_score<E: Send>(boxes: &mut Vec<(DetectBox, E)>, top_k: usize) {
68    if top_k > 0 && boxes.len() > top_k {
69        boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.total_cmp(&a.0.score));
70        boxes.truncate(top_k);
71    }
72}
73
74/// Quantized counterpart of [`truncate_to_top_k_by_score`]. Sorts on
75/// the raw quantized score (which preserves order under monotonic
76/// dequantization). Uses partial sort (O(N)) via `select_nth_unstable_by`.
77/// No-op when `top_k` is 0 (unbounded) or when the input length ≤ `top_k`.
78fn truncate_to_top_k_by_score_quant<S: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send>(
79    boxes: &mut Vec<(DetectBoxQuantized<S>, E)>,
80    top_k: usize,
81) {
82    if top_k > 0 && boxes.len() > top_k {
83        boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.cmp(&a.0.score));
84        boxes.truncate(top_k);
85    }
86}
87
88/// Dispatches to the appropriate NMS function based on mode for float boxes.
89///
90/// `max_det` is the post-NMS detection cap; when present it lets the greedy
91/// inner loop break as soon as that many survivors are confirmed (the survivors
92/// are guaranteed to be the top-`max_det` by score because the input is sorted
93/// descending). Pass `None` to run the full O(N²) suppression.
94fn dispatch_nms_float(
95    nms: Option<Nms>,
96    iou: f32,
97    max_det: Option<usize>,
98    boxes: Vec<DetectBox>,
99) -> Vec<DetectBox> {
100    match nms {
101        Some(Nms::ClassAgnostic | Nms::Auto) => nms_float(iou, max_det, boxes),
102        Some(Nms::ClassAware) => nms_class_aware_float(iou, max_det, boxes),
103        None => boxes, // bypass NMS
104    }
105}
106
107/// Dispatches to the appropriate NMS function based on mode for float boxes
108/// with extra data.
109pub(super) fn dispatch_nms_extra_float<E: Send + Sync>(
110    nms: Option<Nms>,
111    iou: f32,
112    max_det: Option<usize>,
113    boxes: Vec<(DetectBox, E)>,
114) -> Vec<(DetectBox, E)> {
115    match nms {
116        Some(Nms::ClassAgnostic | Nms::Auto) => nms_extra_float(iou, max_det, boxes),
117        Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, max_det, boxes),
118        None => boxes, // bypass NMS
119    }
120}
121
122/// Dispatches to the appropriate NMS function based on mode for quantized
123/// boxes.
124fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
125    nms: Option<Nms>,
126    iou: f32,
127    max_det: Option<usize>,
128    boxes: Vec<DetectBoxQuantized<SCORE>>,
129) -> Vec<DetectBoxQuantized<SCORE>> {
130    match nms {
131        Some(Nms::ClassAgnostic | Nms::Auto) => nms_int(iou, max_det, boxes),
132        Some(Nms::ClassAware) => nms_class_aware_int(iou, max_det, boxes),
133        None => boxes, // bypass NMS
134    }
135}
136
137/// Dispatches to the appropriate NMS function based on mode for quantized boxes
138/// with extra data.
139fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
140    nms: Option<Nms>,
141    iou: f32,
142    max_det: Option<usize>,
143    boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
144) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
145    match nms {
146        Some(Nms::ClassAgnostic | Nms::Auto) => nms_extra_int(iou, max_det, boxes),
147        Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, max_det, boxes),
148        None => boxes, // bypass NMS
149    }
150}
151
152/// Detection cap helper for the public free `decode_yolo_*` wrappers.
153///
154/// Encodes the convention documented above: if the caller passed a non-empty
155/// `Vec`, that capacity acts as the per-call cap; otherwise fall back to
156/// [`DEFAULT_MAX_DETECTIONS`] so freshly-constructed `Vec::new()` outputs
157/// don't silently drop every detection.
158#[inline]
159fn cap_or_default<T>(v: &Vec<T>) -> usize {
160    if v.capacity() > 0 {
161        v.capacity()
162    } else {
163        DEFAULT_MAX_DETECTIONS
164    }
165}
166
167// ─── Public free decode_yolo_* convenience wrappers ────────────────────────
168//
169// Detection cap convention (applies to every `decode_yolo_*` free function
170// below):
171//
172// These functions are the convenience layer for callers that don't go
173// through `Decoder::decode()` (benches, FFI shims, ad-hoc test harnesses).
174// They use **`output_boxes.capacity()` as a per-call detection cap**:
175//
176//   - When the caller passes `Vec::with_capacity(N)`, the post-NMS output
177//     is truncated to at most `N` detections.
178//   - When the caller passes `Vec::new()` (capacity 0), the implementation
179//     falls back to the [`DEFAULT_MAX_DETECTIONS`] constant (300) so a
180//     freshly-constructed `Vec` doesn't silently drop every detection.
181//
182// This is intentionally **different** from the `Decoder::decode()` /
183// `Decoder::decode_proto()` contract, which bounds output count solely
184// by [`Decoder::max_det`] (set via `DecoderBuilder::with_max_det`,
185// default 300) regardless of the caller's `Vec` capacity (EDGEAI-1302).
186//
187// Use the `Decoder` API when you need explicit control over `max_det`,
188// schema-driven decoding, or EDGEAI-1303 normalization. Use these free
189// functions when you have raw tensors in hand and want a one-shot decode.
190
191/// Decodes YOLO detection outputs from quantized tensors into detection boxes.
192///
193/// Boxes are expected to be in XYWH format.
194///
195/// Expected shapes of inputs:
196/// - output: (4 + num_classes, num_boxes)
197///
198/// See the "Detection cap convention" comment above for how
199/// `output_boxes.capacity()` bounds the result count.
200pub(crate) fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
201    output: (ArrayView2<BOX>, Quantization),
202    score_threshold: f32,
203    iou_threshold: f32,
204    nms: Option<Nms>,
205    output_boxes: &mut Vec<DetectBox>,
206) where
207    f32: AsPrimitive<BOX>,
208{
209    impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
210}
211
212/// Decodes YOLO detection outputs from float tensors into detection boxes.
213///
214/// Boxes are expected to be in XYWH format.
215///
216/// Expected shapes of inputs:
217/// - output: (4 + num_classes, num_boxes)
218pub(crate) fn decode_yolo_det_float<T>(
219    output: ArrayView2<T>,
220    score_threshold: f32,
221    iou_threshold: f32,
222    nms: Option<Nms>,
223    output_boxes: &mut Vec<DetectBox>,
224) where
225    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
226    f32: AsPrimitive<T>,
227{
228    impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
229}
230
/// Test-only shim around the quantized seg-det kernel.
///
/// Production code decodes through [`crate::Decoder::decode`]. This wrapper
/// exists so the parity tests in `lib.rs` and `yolo.rs` can run
/// `impl_yolo_segdet_quant` and compare it against the Decoder output
/// without repeating the kernel's argument plumbing.
///
/// Boxes are expected in XYWH format. Expected shapes:
/// - `boxes`: `(4 + num_classes + num_protos, num_boxes)`
/// - `protos`: `(proto_height, proto_width, num_protos)`
///
/// The pre-NMS candidate cap is [`MAX_NMS_CANDIDATES`]; the post-NMS cap
/// comes from `cap_or_default(output_boxes)`.
///
/// # Errors
/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
#[cfg(test)]
pub(crate) fn decode_yolo_segdet_quant<
    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    protos: (ArrayView3<PROTO>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    f32: AsPrimitive<BOX>,
{
    let max_det = cap_or_default(output_boxes);
    // This wrapper predates the Decoder and has no schema to consult, so the
    // two `None`s below leave the schema-derived `normalized` / `input_dims`
    // unset and the EDGEAI-1303 normalization becomes a no-op. Callers that
    // need it should build a `DecoderBuilder` with a schema instead.
    impl_yolo_segdet_quant::<XYWH, _, _>(
        boxes,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        MAX_NMS_CANDIDATES,
        max_det,
        None,
        None,
        output_boxes,
        output_masks,
    )
}
279
/// Test-only shim around the float seg-det kernel; the float counterpart of
/// [`decode_yolo_segdet_quant`].
///
/// Forwards to `impl_yolo_segdet_float` with XYWH boxes, using
/// [`MAX_NMS_CANDIDATES`] as the pre-NMS cap and
/// `cap_or_default(output_boxes)` as the post-NMS cap. Any error from the
/// kernel is returned as-is.
#[cfg(test)]
pub(crate) fn decode_yolo_segdet_float<T>(
    boxes: ArrayView2<T>,
    protos: ArrayView3<T>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    let max_det = cap_or_default(output_boxes);
    // No schema is available to this pre-Decoder wrapper, so the two `None`s
    // leave schema-derived normalization disabled (same EDGEAI-1303 caveat
    // as the quantized sibling).
    impl_yolo_segdet_float::<XYWH, _, _>(
        boxes,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        MAX_NMS_CANDIDATES,
        max_det,
        None,
        None,
        output_boxes,
        output_masks,
    )
}
312
313/// Decodes YOLO split detection outputs from quantized tensors into detection
314/// boxes.
315///
316/// Boxes are expected to be in XYWH format.
317///
318/// Expected shapes of inputs:
319/// - boxes: (4, num_boxes)
320/// - scores: (num_classes, num_boxes)
321///
322/// # Panics
323/// Panics if shapes don't match the expected dimensions.
324pub(crate) fn decode_yolo_split_det_quant<
325    BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
326    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
327>(
328    boxes: (ArrayView2<BOX>, Quantization),
329    scores: (ArrayView2<SCORE>, Quantization),
330    score_threshold: f32,
331    iou_threshold: f32,
332    nms: Option<Nms>,
333    output_boxes: &mut Vec<DetectBox>,
334) where
335    f32: AsPrimitive<SCORE>,
336{
337    impl_yolo_split_quant::<XYWH, _, _>(
338        boxes,
339        scores,
340        score_threshold,
341        iou_threshold,
342        nms,
343        output_boxes,
344    );
345}
346
347/// Decodes YOLO split detection outputs from float tensors into detection
348/// boxes.
349///
350/// Boxes are expected to be in XYWH format.
351///
352/// Expected shapes of inputs:
353/// - boxes: (4, num_boxes)
354/// - scores: (num_classes, num_boxes)
355///
356/// # Panics
357/// Panics if shapes don't match the expected dimensions.
358pub(crate) fn decode_yolo_split_det_float<T>(
359    boxes: ArrayView2<T>,
360    scores: ArrayView2<T>,
361    score_threshold: f32,
362    iou_threshold: f32,
363    nms: Option<Nms>,
364    output_boxes: &mut Vec<DetectBox>,
365) where
366    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
367    f32: AsPrimitive<T>,
368{
369    impl_yolo_split_float::<XYWH, _, _>(
370        boxes,
371        scores,
372        score_threshold,
373        iou_threshold,
374        nms,
375        output_boxes,
376    );
377}
378
379/// Decodes end-to-end YOLO detection outputs (post-NMS from model).
380/// Expects an array of shape `(6, N)`, where the first dimension (rows)
381/// corresponds to the 6 per-detection features
382/// `[x1, y1, x2, y2, conf, class]` and the second dimension (columns)
383/// indexes the `N` detections.
384/// Boxes are output directly without NMS (the model already applied NMS).
385///
386/// Coordinates may be normalized `[0, 1]` or absolute pixel values depending
387/// on the model configuration. The caller should check
388/// `decoder.normalized_boxes()` to determine which.
389///
390/// # Errors
391///
392/// Returns `DecoderError::InvalidShape` if `output` has fewer than 6 rows.
393pub(crate) fn decode_yolo_end_to_end_det_float<T>(
394    output: ArrayView2<T>,
395    score_threshold: f32,
396    output_boxes: &mut Vec<DetectBox>,
397) -> Result<(), crate::DecoderError>
398where
399    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
400    f32: AsPrimitive<T>,
401{
402    // Validate input shape: need at least 6 rows (x1, y1, x2, y2, conf, class)
403    if output.shape()[0] < 6 {
404        return Err(crate::DecoderError::InvalidShape(format!(
405            "End-to-end detection output requires at least 6 rows, got {}",
406            output.shape()[0]
407        )));
408    }
409
410    // Input shape: (6, N) -> transpose to (N, 4) for boxes and (N, 1) for scores
411    let boxes = output.slice(s![0..4, ..]).reversed_axes();
412    let scores = output.slice(s![4..5, ..]).reversed_axes();
413    let classes = output.slice(s![5, ..]);
414    let mut boxes =
415        postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
416    boxes.truncate(cap_or_default(output_boxes));
417    output_boxes.clear();
418    for (mut b, i) in boxes.into_iter() {
419        b.label = classes[i].as_() as usize;
420        output_boxes.push(b);
421    }
422    // No NMS — model output is already post-NMS
423    Ok(())
424}
425
426/// Decodes end-to-end YOLO detection + segmentation outputs (post-NMS from
427/// model).
428///
429/// Input shapes:
430/// - detection: (6 + num_protos, N) where rows are [x1, y1, x2, y2, conf,
431///   class, mask_coeff_0, ..., mask_coeff_31]
432/// - protos: (proto_height, proto_width, num_protos)
433///
434/// Boxes are output directly without NMS (model already applied NMS).
435/// Coordinates may be normalized [0,1] or pixel values depending on model
436/// config.
437///
438/// # Errors
439///
440/// Returns `DecoderError::InvalidShape` if:
441/// - output has fewer than 7 rows (6 base + at least 1 mask coefficient)
442/// - protos shape doesn't match mask coefficients count
443pub(crate) fn decode_yolo_end_to_end_segdet_float<T>(
444    output: ArrayView2<T>,
445    protos: ArrayView3<T>,
446    score_threshold: f32,
447    output_boxes: &mut Vec<DetectBox>,
448    output_masks: &mut Vec<crate::Segmentation>,
449) -> Result<(), crate::DecoderError>
450where
451    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
452    f32: AsPrimitive<T>,
453{
454    let (boxes, scores, classes, mask_coeff) =
455        postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
456    let cap = cap_or_default(output_boxes);
457    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
458        boxes,
459        scores,
460        classes,
461        score_threshold,
462        cap,
463    );
464
465    // No NMS — model output is already post-NMS
466
467    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
468}
469
470/// Decodes split end-to-end YOLO detection outputs (post-NMS from model).
471///
472/// Input shapes (after batch dim removed):
473/// - boxes: (4, N) — xyxy pixel coordinates
474/// - scores: (1, N) — confidence of the top class
475/// - classes: (1, N) — class index of the top class
476///
477/// Boxes are output directly without NMS (model already applied NMS).
478pub(crate) fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
479    boxes: ArrayView2<T>,
480    scores: ArrayView2<T>,
481    classes: ArrayView2<T>,
482    score_threshold: f32,
483    output_boxes: &mut Vec<DetectBox>,
484) -> Result<(), crate::DecoderError> {
485    let n = boxes.shape()[1];
486
487    let cap = cap_or_default(output_boxes);
488    output_boxes.clear();
489
490    let (boxes, scores, classes) = postprocess_yolo_split_end_to_end_det(boxes, scores, &classes)?;
491
492    for i in 0..n {
493        let score: f32 = scores[[i, 0]].as_();
494        if score < score_threshold {
495            continue;
496        }
497        if output_boxes.len() >= cap {
498            break;
499        }
500        output_boxes.push(DetectBox {
501            bbox: BoundingBox {
502                xmin: boxes[[i, 0]].as_(),
503                ymin: boxes[[i, 1]].as_(),
504                xmax: boxes[[i, 2]].as_(),
505                ymax: boxes[[i, 3]].as_(),
506            },
507            score,
508            label: classes[i].as_() as usize,
509        });
510    }
511    Ok(())
512}
513
514/// Decodes split end-to-end YOLO detection + segmentation outputs.
515///
516/// Input shapes (after batch dim removed):
517/// - boxes: (4, N) — xyxy pixel coordinates
518/// - scores: (1, N) — confidence
519/// - classes: (1, N) — class index
520/// - mask_coeff: (num_protos, N) — mask coefficients per detection
521/// - protos: (proto_h, proto_w, num_protos) — prototype masks
522#[allow(clippy::too_many_arguments)]
523pub(crate) fn decode_yolo_split_end_to_end_segdet_float<T>(
524    boxes: ArrayView2<T>,
525    scores: ArrayView2<T>,
526    classes: ArrayView2<T>,
527    mask_coeff: ArrayView2<T>,
528    protos: ArrayView3<T>,
529    score_threshold: f32,
530    output_boxes: &mut Vec<DetectBox>,
531    output_masks: &mut Vec<crate::Segmentation>,
532) -> Result<(), crate::DecoderError>
533where
534    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
535    f32: AsPrimitive<T>,
536{
537    let (boxes, scores, classes, mask_coeff) =
538        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
539    let cap = cap_or_default(output_boxes);
540    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
541        boxes,
542        scores,
543        classes,
544        score_threshold,
545        cap,
546    );
547
548    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
549}
550
551#[allow(clippy::type_complexity)]
552pub(crate) fn postprocess_yolo_end_to_end_segdet<'a, T>(
553    output: &'a ArrayView2<'_, T>,
554    num_protos: usize,
555) -> Result<
556    (
557        ArrayView2<'a, T>,
558        ArrayView2<'a, T>,
559        ArrayView1<'a, T>,
560        ArrayView2<'a, T>,
561    ),
562    crate::DecoderError,
563> {
564    // Validate input shape: need at least 7 rows (6 base + at least 1 mask coeff)
565    if output.shape()[0] < 7 {
566        return Err(crate::DecoderError::InvalidShape(format!(
567            "End-to-end segdet output requires at least 7 rows, got {}",
568            output.shape()[0]
569        )));
570    }
571
572    let num_mask_coeffs = output.shape()[0] - 6;
573    if num_mask_coeffs != num_protos {
574        return Err(crate::DecoderError::InvalidShape(format!(
575            "Mask coefficients count ({}) doesn't match protos count ({})",
576            num_mask_coeffs, num_protos
577        )));
578    }
579
580    // Input shape: (6+num_protos, N) -> transpose for postprocessing
581    let boxes = output.slice(s![0..4, ..]).reversed_axes();
582    let scores = output.slice(s![4..5, ..]).reversed_axes();
583    let classes = output.slice(s![5, ..]);
584    let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
585    Ok((boxes, scores, classes, mask_coeff))
586}
587
588/// Postprocess yolo split end to end det by reversing axes of boxes,
589/// scores, and flattening the class tensor.
590/// Expected input shapes:
591/// - boxes: (4, N)
592/// - scores: (1, N)
593/// - classes: (1, N)
594#[allow(clippy::type_complexity)]
595pub(crate) fn postprocess_yolo_split_end_to_end_det<'a, 'b, 'c, BOXES, SCORES, CLASS>(
596    boxes: ArrayView2<'a, BOXES>,
597    scores: ArrayView2<'b, SCORES>,
598    classes: &'c ArrayView2<CLASS>,
599) -> Result<
600    (
601        ArrayView2<'a, BOXES>,
602        ArrayView2<'b, SCORES>,
603        ArrayView1<'c, CLASS>,
604    ),
605    crate::DecoderError,
606> {
607    let num_boxes = boxes.shape()[1];
608    if boxes.shape()[0] != 4 {
609        return Err(crate::DecoderError::InvalidShape(format!(
610            "Split end-to-end box_coords must be 4, got {}",
611            boxes.shape()[0]
612        )));
613    }
614
615    if scores.shape()[0] != 1 {
616        return Err(crate::DecoderError::InvalidShape(format!(
617            "Split end-to-end scores num_classes must be 1, got {}",
618            scores.shape()[0]
619        )));
620    }
621
622    if classes.shape()[0] != 1 {
623        return Err(crate::DecoderError::InvalidShape(format!(
624            "Split end-to-end classes num_classes must be 1, got {}",
625            classes.shape()[0]
626        )));
627    }
628
629    if scores.shape()[1] != num_boxes {
630        return Err(crate::DecoderError::InvalidShape(format!(
631            "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
632            num_boxes,
633            scores.shape()[1]
634        )));
635    }
636
637    if classes.shape()[1] != num_boxes {
638        return Err(crate::DecoderError::InvalidShape(format!(
639            "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
640            num_boxes,
641            classes.shape()[1]
642        )));
643    }
644
645    let boxes = boxes.reversed_axes();
646    let scores = scores.reversed_axes();
647    let classes = classes.slice(s![0, ..]);
648    Ok((boxes, scores, classes))
649}
650
651/// Postprocess yolo split end to end segdet by reversing axes of boxes,
652/// scores, mask tensors and flattening the class tensor.
653#[allow(clippy::type_complexity)]
654pub(crate) fn postprocess_yolo_split_end_to_end_segdet<
655    'a,
656    'b,
657    'c,
658    'd,
659    BOXES,
660    SCORES,
661    CLASS,
662    MASK,
663>(
664    boxes: ArrayView2<'a, BOXES>,
665    scores: ArrayView2<'b, SCORES>,
666    classes: &'c ArrayView2<CLASS>,
667    mask_coeff: ArrayView2<'d, MASK>,
668) -> Result<
669    (
670        ArrayView2<'a, BOXES>,
671        ArrayView2<'b, SCORES>,
672        ArrayView1<'c, CLASS>,
673        ArrayView2<'d, MASK>,
674    ),
675    crate::DecoderError,
676> {
677    let num_boxes = boxes.shape()[1];
678    if boxes.shape()[0] != 4 {
679        return Err(crate::DecoderError::InvalidShape(format!(
680            "Split end-to-end box_coords must be 4, got {}",
681            boxes.shape()[0]
682        )));
683    }
684
685    if scores.shape()[0] != 1 {
686        return Err(crate::DecoderError::InvalidShape(format!(
687            "Split end-to-end scores num_classes must be 1, got {}",
688            scores.shape()[0]
689        )));
690    }
691
692    if classes.shape()[0] != 1 {
693        return Err(crate::DecoderError::InvalidShape(format!(
694            "Split end-to-end classes num_classes must be 1, got {}",
695            classes.shape()[0]
696        )));
697    }
698
699    if scores.shape()[1] != num_boxes {
700        return Err(crate::DecoderError::InvalidShape(format!(
701            "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
702            num_boxes,
703            scores.shape()[1]
704        )));
705    }
706
707    if classes.shape()[1] != num_boxes {
708        return Err(crate::DecoderError::InvalidShape(format!(
709            "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
710            num_boxes,
711            classes.shape()[1]
712        )));
713    }
714
715    if mask_coeff.shape()[1] != num_boxes {
716        return Err(crate::DecoderError::InvalidShape(format!(
717            "Split end-to-end mask_coeff must have same num_boxes as boxes ({}), got {}",
718            num_boxes,
719            mask_coeff.shape()[1]
720        )));
721    }
722
723    let boxes = boxes.reversed_axes();
724    let scores = scores.reversed_axes();
725    let classes = classes.slice(s![0, ..]);
726    let mask_coeff = mask_coeff.reversed_axes();
727    Ok((boxes, scores, classes, mask_coeff))
728}
729/// Internal implementation of YOLO decoding for quantized tensors.
730///
731/// Expected shapes of inputs:
732/// - output: (4 + num_classes, num_boxes)
733pub(crate) fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
734    output: (ArrayView2<T>, Quantization),
735    score_threshold: f32,
736    iou_threshold: f32,
737    nms: Option<Nms>,
738    output_boxes: &mut Vec<DetectBox>,
739) where
740    f32: AsPrimitive<T>,
741{
742    let _span = tracing::trace_span!("decoder.decode.yolo_quant_flat").entered();
743    let (boxes, quant_boxes) = output;
744    let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
745
746    let boxes = {
747        let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
748        postprocess_boxes_quant::<B, _, _>(
749            score_threshold,
750            boxes_tensor,
751            scores_tensor,
752            quant_boxes,
753        )
754    };
755
756    let cap = cap_or_default(output_boxes);
757    let boxes = dispatch_nms_int(nms, iou_threshold, Some(cap), boxes);
758    // Detection cap convention (see `cap_or_default`). NMS already capped to
759    // `cap`; the `min` here is a redundant guard for non-NMS bypass mode.
760    let len = cap.min(boxes.len());
761    output_boxes.clear();
762    for b in boxes.iter().take(len) {
763        output_boxes.push(dequant_detect_box(b, quant_boxes));
764    }
765}
766
767/// Internal implementation of YOLO decoding for float tensors.
768///
769/// Expected shapes of inputs:
770/// - output: (4 + num_classes, num_boxes)
771pub(crate) fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
772    output: ArrayView2<T>,
773    score_threshold: f32,
774    iou_threshold: f32,
775    nms: Option<Nms>,
776    output_boxes: &mut Vec<DetectBox>,
777) where
778    f32: AsPrimitive<T>,
779{
780    let _span = tracing::trace_span!("decoder.decode.yolo_float_flat").entered();
781    let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
782    let boxes =
783        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
784    let cap = cap_or_default(output_boxes);
785    let boxes = dispatch_nms_float(nms, iou_threshold, Some(cap), boxes);
786    // Detection cap convention (see `cap_or_default`). NMS already capped to
787    // `cap`; the `min` here is a redundant guard for non-NMS bypass mode.
788    let len = cap.min(boxes.len());
789    output_boxes.clear();
790    for b in boxes.into_iter().take(len) {
791        output_boxes.push(b);
792    }
793}
794
795/// Internal implementation of YOLO split detection decoding for quantized
796/// tensors.
797///
798/// Expected shapes of inputs:
799/// - boxes: (4, num_boxes)
800/// - scores: (num_classes, num_boxes)
801///
802/// # Panics
803/// Panics if shapes don't match the expected dimensions.
804pub(crate) fn impl_yolo_split_quant<
805    B: BBoxTypeTrait,
806    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
807    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
808>(
809    boxes: (ArrayView2<BOX>, Quantization),
810    scores: (ArrayView2<SCORE>, Quantization),
811    score_threshold: f32,
812    iou_threshold: f32,
813    nms: Option<Nms>,
814    output_boxes: &mut Vec<DetectBox>,
815) where
816    f32: AsPrimitive<SCORE>,
817{
818    let _span = tracing::trace_span!("decoder.decode.yolo_quant_split").entered();
819    let (boxes_tensor, quant_boxes) = boxes;
820    let (scores_tensor, quant_scores) = scores;
821
822    let boxes_tensor = boxes_tensor.reversed_axes();
823    let scores_tensor = scores_tensor.reversed_axes();
824
825    let boxes = {
826        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
827        postprocess_boxes_quant::<B, _, _>(
828            score_threshold,
829            boxes_tensor,
830            scores_tensor,
831            quant_boxes,
832        )
833    };
834
835    let cap = cap_or_default(output_boxes);
836    let boxes = dispatch_nms_int(nms, iou_threshold, Some(cap), boxes);
837    // Detection cap convention (see `cap_or_default`). NMS already capped to
838    // `cap`; the `min` here is a redundant guard for non-NMS bypass mode.
839    let len = cap.min(boxes.len());
840    output_boxes.clear();
841    for b in boxes.iter().take(len) {
842        output_boxes.push(dequant_detect_box(b, quant_scores));
843    }
844}
845
846/// Internal implementation of YOLO split detection decoding for float tensors.
847///
848/// Expected shapes of inputs:
849/// - boxes: (4, num_boxes)
850/// - scores: (num_classes, num_boxes)
851///
852/// # Panics
853/// Panics if shapes don't match the expected dimensions.
854pub(crate) fn impl_yolo_split_float<
855    B: BBoxTypeTrait,
856    BOX: Float + AsPrimitive<f32> + Send + Sync,
857    SCORE: Float + AsPrimitive<f32> + Send + Sync,
858>(
859    boxes_tensor: ArrayView2<BOX>,
860    scores_tensor: ArrayView2<SCORE>,
861    score_threshold: f32,
862    iou_threshold: f32,
863    nms: Option<Nms>,
864    output_boxes: &mut Vec<DetectBox>,
865) where
866    f32: AsPrimitive<SCORE>,
867{
868    let _span = tracing::trace_span!("decoder.decode.yolo_float_split").entered();
869    let boxes_tensor = boxes_tensor.reversed_axes();
870    let scores_tensor = scores_tensor.reversed_axes();
871    let boxes =
872        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
873    let cap = cap_or_default(output_boxes);
874    let boxes = dispatch_nms_float(nms, iou_threshold, Some(cap), boxes);
875    // Detection cap convention (see `cap_or_default`). NMS already capped to
876    // `cap`; the `min` here is a redundant guard for non-NMS bypass mode.
877    let len = cap.min(boxes.len());
878    output_boxes.clear();
879    for b in boxes.into_iter().take(len) {
880        output_boxes.push(b);
881    }
882}
883
884/// Divide each survivor's bbox by `(input_w, input_h)` when the schema
885/// declares `normalized: false`. Pixel-space box coords from the model
886/// are pulled into the canonical `[0, 1]` range expected by `protobox`
887/// and downstream callers — see EDGEAI-1303.
888///
889/// No-op when `normalized` is `Some(true)` / `None`, when `input_dims`
890/// is `None`, or when there are no survivors.
891#[inline]
892pub(crate) fn maybe_normalize_boxes_in_place(
893    boxes: &mut [(DetectBox, usize)],
894    normalized: Option<bool>,
895    input_dims: Option<(usize, usize)>,
896) {
897    if normalized != Some(false) {
898        return;
899    }
900    let Some((w, h)) = input_dims else {
901        return;
902    };
903    if w == 0 || h == 0 {
904        return;
905    }
906    let inv_w = 1.0 / w as f32;
907    let inv_h = 1.0 / h as f32;
908    for (b, _) in boxes.iter_mut() {
909        b.bbox.xmin *= inv_w;
910        b.bbox.ymin *= inv_h;
911        b.bbox.xmax *= inv_w;
912        b.bbox.ymax *= inv_h;
913    }
914}
915
916/// Internal implementation of YOLO detection segmentation decoding for
917/// quantized tensors.
918///
919/// Expected shapes of inputs:
920/// - boxes: (4 + num_classes + num_protos, num_boxes)
921/// - protos: (proto_height, proto_width, num_protos)
922///
923/// # Errors
924/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
925#[allow(clippy::too_many_arguments)]
926pub(crate) fn impl_yolo_segdet_quant<
927    B: BBoxTypeTrait,
928    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
929    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
930>(
931    boxes: (ArrayView2<BOX>, Quantization),
932    protos: (ArrayView3<PROTO>, Quantization),
933    score_threshold: f32,
934    iou_threshold: f32,
935    nms: Option<Nms>,
936    pre_nms_top_k: usize,
937    max_det: usize,
938    normalized: Option<bool>,
939    input_dims: Option<(usize, usize)>,
940    output_boxes: &mut Vec<DetectBox>,
941    output_masks: &mut Vec<Segmentation>,
942) -> Result<(), crate::DecoderError>
943where
944    f32: AsPrimitive<BOX>,
945{
946    let (boxes, quant_boxes) = boxes;
947    let num_protos = protos.0.dim().2;
948
949    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
950    let mut boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
951        (boxes_tensor, quant_boxes),
952        (scores_tensor, quant_boxes),
953        score_threshold,
954        iou_threshold,
955        nms,
956        pre_nms_top_k,
957        max_det,
958    );
959    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
960
961    impl_yolo_split_segdet_quant_process_masks::<_, _>(
962        boxes,
963        (mask_tensor, quant_boxes),
964        protos,
965        output_boxes,
966        output_masks,
967    )
968}
969
970/// Internal implementation of YOLO detection segmentation decoding for
971/// float tensors.
972///
973/// Expected shapes of inputs:
974/// - boxes: (4 + num_classes + num_protos, num_boxes)
975/// - protos: (proto_height, proto_width, num_protos)
976///
977/// # Panics
978/// Panics if shapes don't match the expected dimensions.
979#[allow(clippy::too_many_arguments)]
980pub(crate) fn impl_yolo_segdet_float<
981    B: BBoxTypeTrait,
982    BOX: Float + AsPrimitive<f32> + Send + Sync,
983    PROTO: Float + AsPrimitive<f32> + Send + Sync,
984>(
985    boxes: ArrayView2<BOX>,
986    protos: ArrayView3<PROTO>,
987    score_threshold: f32,
988    iou_threshold: f32,
989    nms: Option<Nms>,
990    pre_nms_top_k: usize,
991    max_det: usize,
992    normalized: Option<bool>,
993    input_dims: Option<(usize, usize)>,
994    output_boxes: &mut Vec<DetectBox>,
995    output_masks: &mut Vec<Segmentation>,
996) -> Result<(), crate::DecoderError>
997where
998    f32: AsPrimitive<BOX>,
999{
1000    let num_protos = protos.dim().2;
1001    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1002    let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1003        boxes_tensor,
1004        scores_tensor,
1005        score_threshold,
1006        iou_threshold,
1007        nms,
1008        pre_nms_top_k,
1009        max_det,
1010    );
1011    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1012    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1013}
1014
1015pub(crate) fn impl_yolo_segdet_get_boxes<
1016    B: BBoxTypeTrait,
1017    BOX: Float + AsPrimitive<f32> + Send + Sync,
1018    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1019>(
1020    boxes_tensor: ArrayView2<BOX>,
1021    scores_tensor: ArrayView2<SCORE>,
1022    score_threshold: f32,
1023    iou_threshold: f32,
1024    nms: Option<Nms>,
1025    pre_nms_top_k: usize,
1026    max_det: usize,
1027) -> Vec<(DetectBox, usize)>
1028where
1029    f32: AsPrimitive<SCORE>,
1030{
1031    let span = tracing::trace_span!(
1032        "decoder.nms_get_boxes",
1033        n_candidates = tracing::field::Empty,
1034        n_after_topk = tracing::field::Empty,
1035        n_after_nms = tracing::field::Empty,
1036        n_detections = tracing::field::Empty,
1037    );
1038    let _guard = span.enter();
1039
1040    let mut boxes = {
1041        let _s = tracing::trace_span!("decoder.nms_get_boxes.score_filter").entered();
1042        postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor)
1043    };
1044    span.record("n_candidates", boxes.len());
1045
1046    if nms.is_some() {
1047        let _s = tracing::trace_span!("decoder.nms_get_boxes.top_k", k = pre_nms_top_k).entered();
1048        truncate_to_top_k_by_score(&mut boxes, pre_nms_top_k);
1049    }
1050    span.record("n_after_topk", boxes.len());
1051
1052    let mut boxes = {
1053        let _s = tracing::trace_span!("decoder.nms_get_boxes.suppress").entered();
1054        dispatch_nms_extra_float(nms, iou_threshold, Some(max_det), boxes)
1055    };
1056    span.record("n_after_nms", boxes.len());
1057
1058    // NMS already capped to `max_det`; the trailing sort+truncate is a
1059    // redundant guard for the bypass-NMS path (`nms = None`).
1060    boxes.sort_unstable_by(|a, b| b.0.score.total_cmp(&a.0.score));
1061    boxes.truncate(max_det);
1062    span.record("n_detections", boxes.len());
1063
1064    boxes
1065}
1066
1067pub(crate) fn impl_yolo_end_to_end_segdet_get_boxes<
1068    B: BBoxTypeTrait,
1069    BOX: Float + AsPrimitive<f32> + Send + Sync,
1070    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1071    CLASS: AsPrimitive<f32> + Send + Sync,
1072>(
1073    boxes: ArrayView2<BOX>,
1074    scores: ArrayView2<SCORE>,
1075    classes: ArrayView1<CLASS>,
1076    score_threshold: f32,
1077    max_boxes: usize,
1078) -> Vec<(DetectBox, usize)>
1079where
1080    f32: AsPrimitive<SCORE>,
1081{
1082    let mut boxes = postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes, scores);
1083    boxes.truncate(max_boxes);
1084    for (b, ind) in &mut boxes {
1085        b.label = classes[*ind].as_().round() as usize;
1086    }
1087    boxes
1088}
1089
1090pub(crate) fn impl_yolo_split_segdet_process_masks<
1091    MASK: Float + AsPrimitive<f32> + Send + Sync,
1092    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1093>(
1094    boxes: Vec<(DetectBox, usize)>,
1095    masks_tensor: ArrayView2<MASK>,
1096    protos_tensor: ArrayView3<PROTO>,
1097    output_boxes: &mut Vec<DetectBox>,
1098    output_masks: &mut Vec<Segmentation>,
1099) -> Result<(), crate::DecoderError> {
1100    let _span = tracing::trace_span!(
1101        "decoder.decode.process_masks",
1102        n = boxes.len(),
1103        mode = "float"
1104    )
1105    .entered();
1106    // Boxes are already bounded by the upstream `max_det` cap from
1107    // `_get_boxes`; no second cap is needed here (EDGEAI-1302).
1108
1109    let boxes = decode_segdet_f32(boxes, masks_tensor, protos_tensor)?;
1110    output_boxes.clear();
1111    output_masks.clear();
1112    for (b, roi, m) in boxes.into_iter() {
1113        output_boxes.push(b);
1114        output_masks.push(Segmentation {
1115            xmin: roi.xmin,
1116            ymin: roi.ymin,
1117            xmax: roi.xmax,
1118            ymax: roi.ymax,
1119            segmentation: m,
1120        });
1121    }
1122    Ok(())
1123}
1124/// Expected input shapes:
1125/// - boxes_tensor: (num_boxes, 4)
1126/// - scores_tensor: (num_boxes, num_classes)
1127pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
1128    B: BBoxTypeTrait,
1129    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1130    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1131>(
1132    boxes: (ArrayView2<BOX>, Quantization),
1133    scores: (ArrayView2<SCORE>, Quantization),
1134    score_threshold: f32,
1135    iou_threshold: f32,
1136    nms: Option<Nms>,
1137    pre_nms_top_k: usize,
1138    max_det: usize,
1139) -> Vec<(DetectBox, usize)>
1140where
1141    f32: AsPrimitive<SCORE>,
1142{
1143    let (boxes_tensor, quant_boxes) = boxes;
1144    let (scores_tensor, quant_scores) = scores;
1145
1146    let span = tracing::trace_span!(
1147        "decoder.nms_get_boxes",
1148        n_candidates = tracing::field::Empty,
1149        n_after_topk = tracing::field::Empty,
1150        n_after_nms = tracing::field::Empty,
1151        n_detections = tracing::field::Empty,
1152    );
1153    let _guard = span.enter();
1154
1155    let mut boxes = {
1156        let _s = tracing::trace_span!("decoder.nms_get_boxes.score_filter").entered();
1157        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
1158        postprocess_boxes_index_quant::<B, _, _>(
1159            score_threshold,
1160            boxes_tensor,
1161            scores_tensor,
1162            quant_boxes,
1163        )
1164    };
1165    span.record("n_candidates", boxes.len());
1166
1167    if nms.is_some() {
1168        let _s = tracing::trace_span!("decoder.nms_get_boxes.top_k", k = pre_nms_top_k).entered();
1169        truncate_to_top_k_by_score_quant(&mut boxes, pre_nms_top_k);
1170    }
1171    span.record("n_after_topk", boxes.len());
1172
1173    let mut boxes = {
1174        let _s = tracing::trace_span!("decoder.nms_get_boxes.suppress").entered();
1175        dispatch_nms_extra_int(nms, iou_threshold, Some(max_det), boxes)
1176    };
1177    span.record("n_after_nms", boxes.len());
1178
1179    // NMS already capped to `max_det`; the trailing sort+truncate is a
1180    // redundant guard for the bypass-NMS path (`nms = None`).
1181    boxes.sort_unstable_by_key(|b| std::cmp::Reverse(b.0.score));
1182    boxes.truncate(max_det);
1183    let result: Vec<_> = {
1184        let _s =
1185            tracing::trace_span!("decoder.nms_get_boxes.dequant_boxes", n = boxes.len()).entered();
1186        boxes
1187            .into_iter()
1188            .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
1189            .collect()
1190    };
1191    span.record("n_detections", result.len());
1192
1193    result
1194}
1195
1196pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
1197    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1198    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1199>(
1200    boxes: Vec<(DetectBox, usize)>,
1201    mask_coeff: (ArrayView2<MASK>, Quantization),
1202    protos: (ArrayView3<PROTO>, Quantization),
1203    output_boxes: &mut Vec<DetectBox>,
1204    output_masks: &mut Vec<Segmentation>,
1205) -> Result<(), crate::DecoderError> {
1206    let _span = tracing::trace_span!(
1207        "decoder.decode.process_masks",
1208        n = boxes.len(),
1209        mode = "quant"
1210    )
1211    .entered();
1212    let (masks, quant_masks) = mask_coeff;
1213    let (protos, quant_protos) = protos;
1214
1215    // Boxes are already bounded by the upstream `max_det` cap from
1216    // `_get_boxes`; no second cap is needed here (EDGEAI-1302).
1217
1218    let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
1219    output_boxes.clear();
1220    output_masks.clear();
1221    for (b, roi, m) in boxes.into_iter() {
1222        output_boxes.push(b);
1223        output_masks.push(Segmentation {
1224            xmin: roi.xmin,
1225            ymin: roi.ymin,
1226            xmax: roi.xmax,
1227            ymax: roi.ymax,
1228            segmentation: m,
1229        });
1230    }
1231    Ok(())
1232}
1233
1234/// Internal implementation of YOLO split detection segmentation decoding for
1235/// float tensors.
1236///
1237/// Expected shapes of inputs:
1238/// - boxes_tensor: (4, num_boxes)
1239/// - scores_tensor: (num_classes, num_boxes)
1240/// - mask_tensor: (num_protos, num_boxes)
1241/// - protos: (proto_height, proto_width, num_protos)
1242///
1243/// # Errors
1244/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
1245#[allow(clippy::too_many_arguments)]
1246pub(crate) fn impl_yolo_split_segdet_float<
1247    B: BBoxTypeTrait,
1248    BOX: Float + AsPrimitive<f32> + Send + Sync,
1249    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1250    MASK: Float + AsPrimitive<f32> + Send + Sync,
1251    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1252>(
1253    boxes_tensor: ArrayView2<BOX>,
1254    scores_tensor: ArrayView2<SCORE>,
1255    mask_tensor: ArrayView2<MASK>,
1256    protos: ArrayView3<PROTO>,
1257    score_threshold: f32,
1258    iou_threshold: f32,
1259    nms: Option<Nms>,
1260    pre_nms_top_k: usize,
1261    max_det: usize,
1262    normalized: Option<bool>,
1263    input_dims: Option<(usize, usize)>,
1264    output_boxes: &mut Vec<DetectBox>,
1265    output_masks: &mut Vec<Segmentation>,
1266) -> Result<(), crate::DecoderError>
1267where
1268    f32: AsPrimitive<SCORE>,
1269{
1270    let (boxes_tensor, scores_tensor, mask_tensor) =
1271        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1272
1273    let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1274        boxes_tensor,
1275        scores_tensor,
1276        score_threshold,
1277        iou_threshold,
1278        nms,
1279        pre_nms_top_k,
1280        max_det,
1281    );
1282    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1283    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1284}
1285
1286// ---------------------------------------------------------------------------
1287// Proto-extraction variants: return ProtoData instead of materialized masks
1288// ---------------------------------------------------------------------------
1289
1290/// Proto-extraction variant of `impl_yolo_segdet_quant`.
1291/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1292#[allow(clippy::too_many_arguments)]
1293pub(crate) fn impl_yolo_segdet_quant_proto<
1294    B: BBoxTypeTrait,
1295    BOX: PrimInt
1296        + AsPrimitive<i64>
1297        + AsPrimitive<i128>
1298        + AsPrimitive<f32>
1299        + AsPrimitive<i8>
1300        + Send
1301        + Sync,
1302    PROTO: PrimInt
1303        + AsPrimitive<i64>
1304        + AsPrimitive<i128>
1305        + AsPrimitive<f32>
1306        + AsPrimitive<i8>
1307        + Send
1308        + Sync,
1309>(
1310    boxes: (ArrayView2<BOX>, Quantization),
1311    protos: (ArrayView3<PROTO>, Quantization),
1312    score_threshold: f32,
1313    iou_threshold: f32,
1314    nms: Option<Nms>,
1315    pre_nms_top_k: usize,
1316    max_det: usize,
1317    normalized: Option<bool>,
1318    input_dims: Option<(usize, usize)>,
1319    output_boxes: &mut Vec<DetectBox>,
1320) -> ProtoData
1321where
1322    f32: AsPrimitive<BOX>,
1323{
1324    let (boxes_arr, quant_boxes) = boxes;
1325    let (protos_arr, quant_protos) = protos;
1326    let num_protos = protos_arr.dim().2;
1327
1328    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1329
1330    let mut det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1331        (boxes_tensor, quant_boxes),
1332        (scores_tensor, quant_boxes),
1333        score_threshold,
1334        iou_threshold,
1335        nms,
1336        pre_nms_top_k,
1337        max_det,
1338    );
1339    maybe_normalize_boxes_in_place(&mut det_indices, normalized, input_dims);
1340
1341    extract_proto_data_quant(
1342        det_indices,
1343        mask_tensor,
1344        quant_boxes,
1345        protos_arr,
1346        quant_protos,
1347        output_boxes,
1348    )
1349}
1350
1351/// Proto-extraction variant of `impl_yolo_segdet_float`.
1352/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1353#[allow(clippy::too_many_arguments)]
1354pub(crate) fn impl_yolo_segdet_float_proto<
1355    B: BBoxTypeTrait,
1356    BOX: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1357    PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1358>(
1359    boxes: ArrayView2<BOX>,
1360    protos: ArrayView3<PROTO>,
1361    score_threshold: f32,
1362    iou_threshold: f32,
1363    nms: Option<Nms>,
1364    pre_nms_top_k: usize,
1365    max_det: usize,
1366    normalized: Option<bool>,
1367    input_dims: Option<(usize, usize)>,
1368    output_boxes: &mut Vec<DetectBox>,
1369) -> ProtoData
1370where
1371    f32: AsPrimitive<BOX>,
1372{
1373    let num_protos = protos.dim().2;
1374    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1375
1376    let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1377        boxes_tensor,
1378        scores_tensor,
1379        score_threshold,
1380        iou_threshold,
1381        nms,
1382        pre_nms_top_k,
1383        max_det,
1384    );
1385    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1386
1387    extract_proto_data_float(boxes, mask_tensor, protos, output_boxes)
1388}
1389
1390/// Proto-extraction variant of `impl_yolo_split_segdet_float`.
1391/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1392#[allow(clippy::too_many_arguments)]
1393pub(crate) fn impl_yolo_split_segdet_float_proto<
1394    B: BBoxTypeTrait,
1395    BOX: Float + AsPrimitive<f32> + Send + Sync,
1396    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1397    MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1398    PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1399>(
1400    boxes_tensor: ArrayView2<BOX>,
1401    scores_tensor: ArrayView2<SCORE>,
1402    mask_tensor: ArrayView2<MASK>,
1403    protos: ArrayView3<PROTO>,
1404    score_threshold: f32,
1405    iou_threshold: f32,
1406    nms: Option<Nms>,
1407    pre_nms_top_k: usize,
1408    max_det: usize,
1409    output_boxes: &mut Vec<DetectBox>,
1410) -> ProtoData
1411where
1412    f32: AsPrimitive<SCORE>,
1413{
1414    let (boxes_tensor, scores_tensor, mask_tensor) =
1415        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1416    let det_indices = impl_yolo_segdet_get_boxes::<B, _, _>(
1417        boxes_tensor,
1418        scores_tensor,
1419        score_threshold,
1420        iou_threshold,
1421        nms,
1422        pre_nms_top_k,
1423        max_det,
1424    );
1425
1426    extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1427}
1428
1429/// Proto-extraction variant of `decode_yolo_end_to_end_segdet_float`.
1430pub(crate) fn decode_yolo_end_to_end_segdet_float_proto<T>(
1431    output: ArrayView2<T>,
1432    protos: ArrayView3<T>,
1433    score_threshold: f32,
1434    output_boxes: &mut Vec<DetectBox>,
1435) -> Result<ProtoData, crate::DecoderError>
1436where
1437    T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1438    f32: AsPrimitive<T>,
1439{
1440    let (boxes, scores, classes, mask_coeff) =
1441        postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
1442    let cap = cap_or_default(output_boxes);
1443    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1444        boxes,
1445        scores,
1446        classes,
1447        score_threshold,
1448        cap,
1449    );
1450
1451    Ok(extract_proto_data_float(
1452        boxes,
1453        mask_coeff,
1454        protos,
1455        output_boxes,
1456    ))
1457}
1458
1459/// Proto-extraction variant of `decode_yolo_split_end_to_end_segdet_float`.
1460#[allow(clippy::too_many_arguments)]
1461pub(crate) fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1462    boxes: ArrayView2<T>,
1463    scores: ArrayView2<T>,
1464    classes: ArrayView2<T>,
1465    mask_coeff: ArrayView2<T>,
1466    protos: ArrayView3<T>,
1467    score_threshold: f32,
1468    output_boxes: &mut Vec<DetectBox>,
1469) -> Result<ProtoData, crate::DecoderError>
1470where
1471    T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1472    f32: AsPrimitive<T>,
1473{
1474    let (boxes, scores, classes, mask_coeff) =
1475        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
1476    let cap = cap_or_default(output_boxes);
1477    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1478        boxes,
1479        scores,
1480        classes,
1481        score_threshold,
1482        cap,
1483    );
1484
1485    Ok(extract_proto_data_float(
1486        boxes,
1487        mask_coeff,
1488        protos,
1489        output_boxes,
1490    ))
1491}
1492
1493/// Helper: extract ProtoData from float mask coefficients + protos.
1494///
1495/// Builds [`ProtoData`] with both `protos` and `mask_coefficients` as
1496/// [`edgefirst_tensor::TensorDyn`]. Preserves the native element type for
1497/// `f16` and `f32`; narrows `f64` to `f32` (there is no native f64 kernel
1498/// path). `mask_coefficients` shape is `[num_detections, num_protos]`.
1499pub(super) fn extract_proto_data_float<
1500    MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1501    PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1502>(
1503    det_indices: Vec<(DetectBox, usize)>,
1504    mask_tensor: ArrayView2<MASK>,
1505    protos: ArrayView3<PROTO>,
1506    output_boxes: &mut Vec<DetectBox>,
1507) -> ProtoData {
1508    let _span = tracing::trace_span!(
1509        "decoder.decode_proto.extract_proto_data",
1510        mode = "float",
1511        n = det_indices.len(),
1512        num_protos = mask_tensor.ncols(),
1513        layout = "nhwc",
1514    )
1515    .entered();
1516
1517    let num_protos = mask_tensor.ncols();
1518    let n = det_indices.len();
1519
1520    // Per-detection coefficients packed row-major into a contiguous buffer,
1521    // preserving the source dtype. Shape: [N, num_protos] — N=0 is permitted
1522    // (tracker path emits no detections this frame) since Mem-backed tensors
1523    // accept zero-element shapes as "empty collection" sentinels.
1524    let mut coeff_rows: Vec<MASK> = Vec::with_capacity(n * num_protos);
1525    output_boxes.clear();
1526    for (det, idx) in det_indices {
1527        output_boxes.push(det);
1528        let row = mask_tensor.row(idx);
1529        coeff_rows.extend(row.iter().copied());
1530    }
1531
1532    let mask_coefficients = MASK::slice_into_tensor_dyn(&coeff_rows, &[n, num_protos])
1533        .expect("allocating mask_coefficients TensorDyn");
1534    let protos_tensor =
1535        PROTO::arrayview3_into_tensor_dyn(protos).expect("allocating protos TensorDyn");
1536
1537    ProtoData {
1538        mask_coefficients,
1539        protos: protos_tensor,
1540        layout: ProtoLayout::Nhwc,
1541    }
1542}
1543
1544/// Helper: extract ProtoData from quantized mask coefficients + protos.
1545///
1546/// Dequantizes mask coefficients to f32 at extraction (one-time cost on a
1547/// `num_detections * num_protos` slice) and keeps protos in raw i8,
1548/// attaching the dequantization params as
1549/// [`edgefirst_tensor::Quantization::per_tensor`] metadata on the proto
1550/// tensor. The GPU shader / CPU kernel reads `protos.quantization()` and
1551/// dequantizes per-texel.
1552pub(crate) fn extract_proto_data_quant<
1553    MASK: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
1554    PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
1555>(
1556    det_indices: Vec<(DetectBox, usize)>,
1557    mask_tensor: ArrayView2<MASK>,
1558    quant_masks: Quantization,
1559    protos: ArrayView3<PROTO>,
1560    quant_protos: Quantization,
1561    output_boxes: &mut Vec<DetectBox>,
1562) -> ProtoData {
1563    use edgefirst_tensor::{Tensor, TensorDyn, TensorMapTrait, TensorMemory, TensorTrait};
1564
1565    let span = tracing::trace_span!(
1566        "decoder.decode_proto.extract_proto_data",
1567        mode = "quant",
1568        n = det_indices.len(),
1569        num_protos = tracing::field::Empty,
1570        layout = tracing::field::Empty,
1571    );
1572    let _guard = span.enter();
1573
1574    let num_protos = mask_tensor.ncols();
1575    let n = det_indices.len();
1576    span.record("num_protos", num_protos);
1577
1578    // Fast path: when no detections survive NMS, skip the expensive proto
1579    // tensor copy (819KB for 160×160×32). Allocate with the correct shape
1580    // (preserving the documented ProtoData.protos shape contract) but skip
1581    // copying from the source tensor — the zeroed allocation is sufficient
1582    // since materialize_masks early-returns on empty detect slices.
1583    if n == 0 {
1584        output_boxes.clear();
1585        let (h, w, k) = protos.dim();
1586
1587        // Detect physical layout (same logic as the normal path).
1588        let (proto_shape, proto_layout) = if std::any::TypeId::of::<PROTO>()
1589            == std::any::TypeId::of::<i8>()
1590        {
1591            if protos.is_standard_layout() {
1592                (&[h, w, k][..], ProtoLayout::Nhwc)
1593            } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1594                (&[k, h, w][..], ProtoLayout::Nchw)
1595            } else {
1596                (&[h, w, k][..], ProtoLayout::Nhwc)
1597            }
1598        } else {
1599            (&[h, w, k][..], ProtoLayout::Nhwc)
1600        };
1601
1602        let coeff_tensor = Tensor::<i8>::new(&[0, num_protos], Some(TensorMemory::Mem), None)
1603            .expect("allocating empty mask_coefficients tensor");
1604        let coeff_quant =
1605            edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1606        let coeff_tensor = coeff_tensor
1607            .with_quantization(coeff_quant)
1608            .expect("per-tensor quantization on mask coefficients");
1609        let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
1610            .expect("allocating protos tensor");
1611        let tensor_quant =
1612            edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1613        let protos_tensor = protos_tensor
1614            .with_quantization(tensor_quant)
1615            .expect("per-tensor quantization on protos tensor");
1616        return ProtoData {
1617            mask_coefficients: TensorDyn::I8(coeff_tensor),
1618            protos: TensorDyn::I8(protos_tensor),
1619            layout: proto_layout,
1620        };
1621    }
1622
1623    // Mask coefficients: keep i8 losslessly when MASK == i8 (preserves
1624    // the fast i8×i8→i32 integer kernel in materialize_masks). Preserve
1625    // i16 natively so the downstream i16×i8 integer path can avoid lossy
1626    // truncation. Other wider types (u16, …) dequantize to f32 at
1627    // extraction because the downstream mask kernels accept F32 natively.
1628    let mask_coefficients: TensorDyn = if std::any::TypeId::of::<MASK>()
1629        == std::any::TypeId::of::<i8>()
1630    {
1631        let mut coeff_i8 = Vec::<i8>::with_capacity(n * num_protos);
1632        output_boxes.clear();
1633        for (det, idx) in det_indices {
1634            output_boxes.push(det);
1635            let row = mask_tensor.row(idx);
1636            coeff_i8.extend(row.iter().map(|v| {
1637                let v_i8: i8 = v.as_();
1638                v_i8
1639            }));
1640        }
1641        let coeff_tensor = Tensor::<i8>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
1642            .expect("allocating mask_coefficients tensor");
1643        if n > 0 {
1644            let mut m = coeff_tensor
1645                .map()
1646                .expect("mapping mask_coefficients tensor");
1647            m.as_mut_slice().copy_from_slice(&coeff_i8);
1648        }
1649        let coeff_quant =
1650            edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1651        let coeff_tensor = coeff_tensor
1652            .with_quantization(coeff_quant)
1653            .expect("per-tensor quantization on mask coefficients");
1654        TensorDyn::I8(coeff_tensor)
1655    } else if std::any::TypeId::of::<MASK>() == std::any::TypeId::of::<i16>() {
1656        // i16 path: preserve natively for the fast i16×i8→i32 integer kernel.
1657        // f32 has 24-bit mantissa, so all i16 values are exactly representable.
1658        let mut coeff_i16 = Vec::<i16>::with_capacity(n * num_protos);
1659        output_boxes.clear();
1660        for (det, idx) in det_indices {
1661            output_boxes.push(det);
1662            let row = mask_tensor.row(idx);
1663            coeff_i16.extend(row.iter().map(|v| {
1664                let v_f32: f32 = v.as_();
1665                v_f32 as i16
1666            }));
1667        }
1668        let coeff_tensor = Tensor::<i16>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
1669            .expect("allocating mask_coefficients tensor");
1670        if n > 0 {
1671            let mut m = coeff_tensor
1672                .map()
1673                .expect("mapping mask_coefficients tensor");
1674            m.as_mut_slice().copy_from_slice(&coeff_i16);
1675        }
1676        let coeff_quant =
1677            edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1678        let coeff_tensor = coeff_tensor
1679            .with_quantization(coeff_quant)
1680            .expect("per-tensor quantization on mask coefficients");
1681        TensorDyn::I16(coeff_tensor)
1682    } else {
1683        // Other types (u8, u16, etc.): dequantize to f32 to avoid lossy truncation.
1684        let scale = quant_masks.scale;
1685        let zp = quant_masks.zero_point as f32;
1686        let mut coeff_f32 = Vec::<f32>::with_capacity(n * num_protos);
1687        output_boxes.clear();
1688        for (det, idx) in det_indices {
1689            output_boxes.push(det);
1690            let row = mask_tensor.row(idx);
1691            coeff_f32.extend(row.iter().map(|v| {
1692                let v_f32: f32 = v.as_();
1693                (v_f32 - zp) * scale
1694            }));
1695        }
1696        let coeff_tensor = Tensor::<f32>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
1697            .expect("allocating mask_coefficients tensor");
1698        if n > 0 {
1699            let mut m = coeff_tensor
1700                .map()
1701                .expect("mapping mask_coefficients tensor");
1702            m.as_mut_slice().copy_from_slice(&coeff_f32);
1703        }
1704        TensorDyn::F32(coeff_tensor)
1705    };
1706
1707    // Keep protos in raw i8 — consumers dequantize via protos.quantization().
1708    // When PROTO is already i8, detect layout and copy efficiently without
1709    // transposing. The mask materialisation kernels dispatch on the layout.
1710    let (h, w, k) = protos.dim();
1711
1712    // Determine physical layout and copy strategy.
1713    let (proto_shape, proto_layout) =
1714        if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1715            if protos.is_standard_layout() {
1716                // Already NHWC [H, W, K] in contiguous memory.
1717                (&[h, w, k][..], ProtoLayout::Nhwc)
1718            } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1719                // NCHW reinterpreted as NHWC via stride swap. Physical storage
1720                // is [K, H, W] contiguous. Keep in NCHW — eliminates the costly
1721                // 3.1ms transpose entirely.
1722                (&[k, h, w][..], ProtoLayout::Nchw)
1723            } else {
1724                // Unknown layout — fall back to iter copy as NHWC.
1725                (&[h, w, k][..], ProtoLayout::Nhwc)
1726            }
1727        } else {
1728            (&[h, w, k][..], ProtoLayout::Nhwc)
1729        };
1730
1731    let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
1732        .expect("allocating protos tensor");
1733    {
1734        let mut m = protos_tensor.map().expect("mapping protos tensor");
1735        let dst = m.as_mut_slice();
1736        if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1737            // SAFETY: PROTO == i8 checked via TypeId; cast slice view is
1738            // size/alignment-compatible by construction.
1739            if protos.is_standard_layout() {
1740                let src: &[i8] = unsafe {
1741                    std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len())
1742                };
1743                dst.copy_from_slice(src);
1744            } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1745                // NCHW physical layout — sequential copy WITHOUT transpose.
1746                // This saves ~3.1ms on A53/A55 by avoiding the tiled
1747                // NCHW→NHWC transpose of the 819KB proto buffer.
1748                let total = h * w * k;
1749                // SAFETY: ArrayView was constructed from a contiguous slice of
1750                // `total` elements. as_ptr() points to the base of that slice.
1751                let src: &[i8] =
1752                    unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, total) };
1753                dst.copy_from_slice(src);
1754            } else {
1755                for (d, s) in dst.iter_mut().zip(protos.iter()) {
1756                    let v_i8: i8 = s.as_();
1757                    *d = v_i8;
1758                }
1759            }
1760        } else {
1761            for (d, s) in dst.iter_mut().zip(protos.iter()) {
1762                let v_i8: i8 = s.as_();
1763                *d = v_i8;
1764            }
1765        }
1766    }
1767    let tensor_quant =
1768        edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1769    let protos_tensor = protos_tensor
1770        .with_quantization(tensor_quant)
1771        .expect("per-tensor quantization on new Tensor<i8>");
1772
1773    span.record("layout", tracing::field::debug(&proto_layout));
1774
1775    ProtoData {
1776        mask_coefficients,
1777        protos: TensorDyn::I8(protos_tensor),
1778        layout: proto_layout,
1779    }
1780}
1781
1782/// Per-float-dtype construction of a [`TensorDyn`] from a flat slice / 3-D
1783/// `ArrayView`. Replaces the old `IntoProtoTensor` trait. Each implementor
1784/// either passes its element type straight to `Tensor::from_slice` /
1785/// `Tensor::from_arrayview3`, or narrows `f64` to `f32` (no native f64 kernel
1786/// path exists).
1787pub trait FloatProtoElem: Copy + 'static {
1788    fn slice_into_tensor_dyn(
1789        values: &[Self],
1790        shape: &[usize],
1791    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;
1792
1793    fn arrayview3_into_tensor_dyn(
1794        view: ArrayView3<'_, Self>,
1795    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;
1796}
1797
1798impl FloatProtoElem for f32 {
1799    fn slice_into_tensor_dyn(
1800        values: &[f32],
1801        shape: &[usize],
1802    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1803        edgefirst_tensor::Tensor::<f32>::from_slice(values, shape)
1804            .map(edgefirst_tensor::TensorDyn::F32)
1805    }
1806    fn arrayview3_into_tensor_dyn(
1807        view: ArrayView3<'_, f32>,
1808    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1809        edgefirst_tensor::Tensor::<f32>::from_arrayview3(view).map(edgefirst_tensor::TensorDyn::F32)
1810    }
1811}
1812
1813impl FloatProtoElem for half::f16 {
1814    fn slice_into_tensor_dyn(
1815        values: &[half::f16],
1816        shape: &[usize],
1817    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1818        edgefirst_tensor::Tensor::<half::f16>::from_slice(values, shape)
1819            .map(edgefirst_tensor::TensorDyn::F16)
1820    }
1821    fn arrayview3_into_tensor_dyn(
1822        view: ArrayView3<'_, half::f16>,
1823    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1824        edgefirst_tensor::Tensor::<half::f16>::from_arrayview3(view)
1825            .map(edgefirst_tensor::TensorDyn::F16)
1826    }
1827}
1828
1829impl FloatProtoElem for f64 {
1830    fn slice_into_tensor_dyn(
1831        values: &[f64],
1832        shape: &[usize],
1833    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1834        // Narrow to f32 — no native f64 kernel path.
1835        let narrowed: Vec<f32> = values.iter().map(|&v| v as f32).collect();
1836        edgefirst_tensor::Tensor::<f32>::from_slice(&narrowed, shape)
1837            .map(edgefirst_tensor::TensorDyn::F32)
1838    }
1839    fn arrayview3_into_tensor_dyn(
1840        view: ArrayView3<'_, f64>,
1841    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1842        let narrowed: ndarray::Array3<f32> = view.mapv(|v| v as f32);
1843        edgefirst_tensor::Tensor::<f32>::from_arrayview3(narrowed.view())
1844            .map(edgefirst_tensor::TensorDyn::F32)
1845    }
1846}
1847
1848fn postprocess_yolo<'a, T>(
1849    output: &'a ArrayView2<'_, T>,
1850) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1851    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1852    let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1853    (boxes_tensor, scores_tensor)
1854}
1855
1856pub(crate) fn postprocess_yolo_seg<'a, T>(
1857    output: &'a ArrayView2<'_, T>,
1858    num_protos: usize,
1859) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1860    assert!(
1861        output.shape()[0] > num_protos + 4,
1862        "Output shape is too short: {} <= {} + 4",
1863        output.shape()[0],
1864        num_protos
1865    );
1866    let num_classes = output.shape()[0] - 4 - num_protos;
1867    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1868    let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1869    let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1870    (boxes_tensor, scores_tensor, mask_tensor)
1871}
1872
1873pub(crate) fn postprocess_yolo_split_segdet<'a, 'b, 'c, BOX, SCORE, MASK>(
1874    boxes_tensor: ArrayView2<'a, BOX>,
1875    scores_tensor: ArrayView2<'b, SCORE>,
1876    mask_tensor: ArrayView2<'c, MASK>,
1877) -> (
1878    ArrayView2<'a, BOX>,
1879    ArrayView2<'b, SCORE>,
1880    ArrayView2<'c, MASK>,
1881) {
1882    let boxes_tensor = boxes_tensor.reversed_axes();
1883    let scores_tensor = scores_tensor.reversed_axes();
1884    let mask_tensor = mask_tensor.reversed_axes();
1885    (boxes_tensor, scores_tensor, mask_tensor)
1886}
1887
1888fn decode_segdet_f32<
1889    MASK: Float + AsPrimitive<f32> + Send + Sync,
1890    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1891>(
1892    boxes: Vec<(DetectBox, usize)>,
1893    masks: ArrayView2<MASK>,
1894    protos: ArrayView3<PROTO>,
1895) -> Result<Vec<(DetectBox, BoundingBox, Array3<u8>)>, crate::DecoderError> {
1896    if boxes.is_empty() {
1897        return Ok(Vec::new());
1898    }
1899    if masks.shape()[1] != protos.shape()[2] {
1900        return Err(crate::DecoderError::InvalidShape(format!(
1901            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1902            masks.shape()[1],
1903            protos.shape()[2],
1904        )));
1905    }
1906    boxes
1907        .into_par_iter()
1908        .map(|b| {
1909            let ind = b.1;
1910            // `protobox` returns the cropped proto slice for `make_segmentation`
1911            // and a `roi` snapped to the 1/proto-grid step. The detection bbox
1912            // stays untouched (EDGEAI-1304); the snapped roi is reported back
1913            // separately so callers can describe where the cropped mask lives.
1914            let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1915            Ok((b.0, roi, make_segmentation(masks.row(ind), protos.view())))
1916        })
1917        .collect()
1918}
1919
1920pub(crate) fn decode_segdet_quant<
1921    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1922    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1923>(
1924    boxes: Vec<(DetectBox, usize)>,
1925    masks: ArrayView2<MASK>,
1926    protos: ArrayView3<PROTO>,
1927    quant_masks: Quantization,
1928    quant_protos: Quantization,
1929) -> Result<Vec<(DetectBox, BoundingBox, Array3<u8>)>, crate::DecoderError> {
1930    if boxes.is_empty() {
1931        return Ok(Vec::new());
1932    }
1933    if masks.shape()[1] != protos.shape()[2] {
1934        return Err(crate::DecoderError::InvalidShape(format!(
1935            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1936            masks.shape()[1],
1937            protos.shape()[2],
1938        )));
1939    }
1940
1941    let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; // 32 protos is 2^5
1942    boxes
1943        .into_iter()
1944        .map(|b| {
1945            let i = b.1;
1946            // See EDGEAI-1304: the caller's bbox stays untouched; the
1947            // proto-grid-snapped `roi` is reported back so the Segmentation's
1948            // bounds can describe the actual cropped mask region.
1949            let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1950            let seg = match total_bits {
1951                0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1952                    masks.row(i),
1953                    protos.view(),
1954                    quant_masks,
1955                    quant_protos,
1956                ),
1957                65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1958                    masks.row(i),
1959                    protos.view(),
1960                    quant_masks,
1961                    quant_protos,
1962                ),
1963                _ => {
1964                    return Err(crate::DecoderError::NotSupported(format!(
1965                        "Unsupported bit width ({total_bits}) for segmentation computation"
1966                    )));
1967                }
1968            };
1969            Ok((b.0, roi, seg))
1970        })
1971        .collect()
1972}
1973
1974fn protobox<'a, T>(
1975    protos: &'a ArrayView3<T>,
1976    roi: &BoundingBox,
1977) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1978    let width = protos.dim().1 as f32;
1979    let height = protos.dim().0 as f32;
1980
1981    // Detect un-normalized bounding boxes (pixel-space coordinates).
1982    // protobox expects normalized coordinates in [0, 1]. The decoder will
1983    // normalize pixel-space coords automatically when the schema declares
1984    // `Detection::normalized = false` AND model input dimensions are known
1985    // (EDGEAI-1303); reaching this guard means at least one of those is
1986    // missing.
1987    //
1988    // The limit is set to 2.0 (not 1.01) because YOLO models legitimately
1989    // predict coordinates slightly > 1.0 for objects near frame edges.
1990    // Any value > 2.0 is clearly pixel-space (even the smallest practical
1991    // model input of 32×32 would produce coordinates >> 2.0).
1992    const NORM_LIMIT: f32 = 2.0;
1993    if roi.xmin > NORM_LIMIT
1994        || roi.ymin > NORM_LIMIT
1995        || roi.xmax > NORM_LIMIT
1996        || roi.ymax > NORM_LIMIT
1997    {
1998        return Err(crate::DecoderError::InvalidShape(format!(
1999            "Bounding box coordinates appear un-normalized (pixel-space). \
2000             Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
2001             Two ways to fix this: \
2002             (1) declare `Detection::normalized = false` in the model schema \
2003             AND make sure the schema's `input.shape` / `input.dshape` carries \
2004             the model input dims so the decoder can divide by (W, H) before NMS \
2005             (EDGEAI-1303 — verify with `Decoder::input_dims().is_some()`); or \
2006             (2) normalize the boxes in-graph before decode().",
2007            roi.xmin, roi.ymin, roi.xmax, roi.ymax,
2008        )));
2009    }
2010
2011    let roi = [
2012        (roi.xmin * width).clamp(0.0, width) as usize,
2013        (roi.ymin * height).clamp(0.0, height) as usize,
2014        (roi.xmax * width).clamp(0.0, width).ceil() as usize,
2015        (roi.ymax * height).clamp(0.0, height).ceil() as usize,
2016    ];
2017
2018    let roi_norm = [
2019        roi[0] as f32 / width,
2020        roi[1] as f32 / height,
2021        roi[2] as f32 / width,
2022        roi[3] as f32 / height,
2023    ]
2024    .into();
2025
2026    let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
2027
2028    Ok((cropped, roi_norm))
2029}
2030
2031/// Compute a single instance segmentation mask from mask coefficients and
2032/// proto maps (float path).
2033///
2034/// Computes `sigmoid(coefficients · protos)` and maps to `[0, 255]`.
2035/// Returns an `(H, W, 1)` u8 array.
2036fn make_segmentation<
2037    MASK: Float + AsPrimitive<f32> + Send + Sync,
2038    PROTO: Float + AsPrimitive<f32> + Send + Sync,
2039>(
2040    mask: ArrayView1<MASK>,
2041    protos: ArrayView3<PROTO>,
2042) -> Array3<u8> {
2043    let shape = protos.shape();
2044
2045    // Safe to unwrap since the shapes will always be compatible
2046    let mask = mask.to_shape((1, mask.len())).unwrap();
2047    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
2048    let protos = protos.reversed_axes();
2049    let mask = mask.map(|x| x.as_());
2050    let protos = protos.map(|x| x.as_());
2051
2052    // Safe to unwrap since the shapes will always be compatible
2053    let mask = mask
2054        .dot(&protos)
2055        .into_shape_with_order((shape[0], shape[1], 1))
2056        .unwrap();
2057
2058    mask.map(|x| {
2059        let sigmoid = 1.0 / (1.0 + (-*x).exp());
2060        (sigmoid * 255.0).round() as u8
2061    })
2062}
2063
2064/// Compute a single instance segmentation mask from quantized mask
2065/// coefficients and proto maps.
2066///
2067/// Dequantizes both inputs (subtracting zero-points), computes the dot
2068/// product, applies sigmoid, and maps to `[0, 255]`.
2069/// Returns an `(H, W, 1)` u8 array.
2070fn make_segmentation_quant<
2071    MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
2072    PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
2073    DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
2074>(
2075    mask: ArrayView1<MASK>,
2076    protos: ArrayView3<PROTO>,
2077    quant_masks: Quantization,
2078    quant_protos: Quantization,
2079) -> Array3<u8>
2080where
2081    i32: AsPrimitive<DEST>,
2082    f32: AsPrimitive<DEST>,
2083{
2084    let shape = protos.shape();
2085
2086    // Safe to unwrap since the shapes will always be compatible
2087    let mask = mask.to_shape((1, mask.len())).unwrap();
2088
2089    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
2090    let protos = protos.reversed_axes();
2091
2092    let zp = quant_masks.zero_point.as_();
2093
2094    let mask = mask.mapv(|x| x.as_() - zp);
2095
2096    let zp = quant_protos.zero_point.as_();
2097    let protos = protos.mapv(|x| x.as_() - zp);
2098
2099    // Safe to unwrap since the shapes will always be compatible
2100    let segmentation = mask
2101        .dot(&protos)
2102        .into_shape_with_order((shape[0], shape[1], 1))
2103        .unwrap();
2104
2105    let combined_scale = quant_masks.scale * quant_protos.scale;
2106    segmentation.map(|x| {
2107        let val: f32 = (*x).as_() * combined_scale;
2108        let sigmoid = 1.0 / (1.0 + (-val).exp());
2109        (sigmoid * 255.0).round() as u8
2110    })
2111}
2112
2113/// Converts Yolo Instance Segmentation into a 2D mask.
2114///
2115/// The input segmentation is expected to have shape (H, W, 1).
2116///
2117/// The output mask will have shape (H, W), with values 0 or 1 based on the
2118/// threshold.
2119///
2120/// # Errors
2121///
2122/// Returns `DecoderError::InvalidShape` if the input segmentation does not
2123/// have shape (H, W, 1).
2124pub(crate) fn yolo_segmentation_to_mask(
2125    segmentation: ArrayView3<u8>,
2126    threshold: u8,
2127) -> Result<Array2<u8>, crate::DecoderError> {
2128    if segmentation.shape()[2] != 1 {
2129        return Err(crate::DecoderError::InvalidShape(format!(
2130            "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
2131            segmentation.shape()[2]
2132        )));
2133    }
2134    Ok(segmentation
2135        .slice(s![.., .., 0])
2136        .map(|x| if *x >= threshold { 1 } else { 0 }))
2137}
2138
2139#[cfg(test)]
2140#[cfg_attr(coverage_nightly, coverage(off))]
2141mod tests {
2142    use super::*;
2143    use ndarray::Array2;
2144
2145    // ========================================================================
2146    // Tests for decode_yolo_end_to_end_det_float
2147    // ========================================================================
2148
2149    #[test]
2150    fn test_end_to_end_det_basic_filtering() {
2151        // Create synthetic end-to-end detection output: (6, N) where rows are
2152        // [x1, y1, x2, y2, conf, class]
2153        // 3 detections: one above threshold, two below
2154        let data: Vec<f32> = vec![
2155            // Detection 0: high score (0.9)
2156            0.1, 0.2, 0.3, // x1 values
2157            0.1, 0.2, 0.3, // y1 values
2158            0.5, 0.6, 0.7, // x2 values
2159            0.5, 0.6, 0.7, // y2 values
2160            0.9, 0.1, 0.2, // confidence scores
2161            0.0, 1.0, 2.0, // class indices
2162        ];
2163        let output = Array2::from_shape_vec((6, 3), data).unwrap();
2164
2165        let mut boxes = Vec::with_capacity(10);
2166        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2167
2168        // Only 1 detection should pass threshold of 0.5
2169        assert_eq!(boxes.len(), 1);
2170        assert_eq!(boxes[0].label, 0);
2171        assert!((boxes[0].score - 0.9).abs() < 0.01);
2172        assert!((boxes[0].bbox.xmin - 0.1).abs() < 0.01);
2173        assert!((boxes[0].bbox.ymin - 0.1).abs() < 0.01);
2174        assert!((boxes[0].bbox.xmax - 0.5).abs() < 0.01);
2175        assert!((boxes[0].bbox.ymax - 0.5).abs() < 0.01);
2176    }
2177
2178    #[test]
2179    fn test_end_to_end_det_all_pass_threshold() {
2180        // All detections above threshold
2181        let data: Vec<f32> = vec![
2182            10.0, 20.0, // x1
2183            10.0, 20.0, // y1
2184            50.0, 60.0, // x2
2185            50.0, 60.0, // y2
2186            0.8, 0.7, // conf (both above 0.5)
2187            1.0, 2.0, // class
2188        ];
2189        let output = Array2::from_shape_vec((6, 2), data).unwrap();
2190
2191        let mut boxes = Vec::with_capacity(10);
2192        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2193
2194        assert_eq!(boxes.len(), 2);
2195        assert_eq!(boxes[0].label, 1);
2196        assert_eq!(boxes[1].label, 2);
2197    }
2198
2199    #[test]
2200    fn test_end_to_end_det_none_pass_threshold() {
2201        // All detections below threshold
2202        let data: Vec<f32> = vec![
2203            10.0, 20.0, // x1
2204            10.0, 20.0, // y1
2205            50.0, 60.0, // x2
2206            50.0, 60.0, // y2
2207            0.1, 0.2, // conf (both below 0.5)
2208            1.0, 2.0, // class
2209        ];
2210        let output = Array2::from_shape_vec((6, 2), data).unwrap();
2211
2212        let mut boxes = Vec::with_capacity(10);
2213        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2214
2215        assert_eq!(boxes.len(), 0);
2216    }
2217
2218    #[test]
2219    fn test_end_to_end_det_capacity_limit() {
2220        // Test that output is truncated to capacity
2221        let data: Vec<f32> = vec![
2222            0.1, 0.2, 0.3, 0.4, 0.5, // x1
2223            0.1, 0.2, 0.3, 0.4, 0.5, // y1
2224            0.5, 0.6, 0.7, 0.8, 0.9, // x2
2225            0.5, 0.6, 0.7, 0.8, 0.9, // y2
2226            0.9, 0.9, 0.9, 0.9, 0.9, // conf (all pass)
2227            0.0, 1.0, 2.0, 3.0, 4.0, // class
2228        ];
2229        let output = Array2::from_shape_vec((6, 5), data).unwrap();
2230
2231        let mut boxes = Vec::with_capacity(2); // Only allow 2 boxes
2232        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2233
2234        assert_eq!(boxes.len(), 2);
2235    }
2236
2237    #[test]
2238    fn test_end_to_end_det_empty_output() {
2239        // Test with zero detections
2240        let output = Array2::<f32>::zeros((6, 0));
2241
2242        let mut boxes = Vec::with_capacity(10);
2243        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2244
2245        assert_eq!(boxes.len(), 0);
2246    }
2247
2248    #[test]
2249    fn test_end_to_end_det_pixel_coordinates() {
2250        // Test with pixel coordinates (non-normalized)
2251        let data: Vec<f32> = vec![
2252            100.0, // x1
2253            200.0, // y1
2254            300.0, // x2
2255            400.0, // y2
2256            0.95,  // conf
2257            5.0,   // class
2258        ];
2259        let output = Array2::from_shape_vec((6, 1), data).unwrap();
2260
2261        let mut boxes = Vec::with_capacity(10);
2262        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2263
2264        assert_eq!(boxes.len(), 1);
2265        assert_eq!(boxes[0].label, 5);
2266        assert!((boxes[0].bbox.xmin - 100.0).abs() < 0.01);
2267        assert!((boxes[0].bbox.ymin - 200.0).abs() < 0.01);
2268        assert!((boxes[0].bbox.xmax - 300.0).abs() < 0.01);
2269        assert!((boxes[0].bbox.ymax - 400.0).abs() < 0.01);
2270    }
2271
2272    #[test]
2273    fn test_end_to_end_det_invalid_shape() {
2274        // Test with too few rows (needs at least 6)
2275        let output = Array2::<f32>::zeros((5, 3));
2276
2277        let mut boxes = Vec::with_capacity(10);
2278        let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);
2279
2280        assert!(result.is_err());
2281        assert!(matches!(
2282            result,
2283            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
2284        ));
2285    }
2286
2287    // ========================================================================
2288    // Tests for decode_yolo_end_to_end_segdet_float
2289    // ========================================================================
2290
2291    #[test]
2292    fn test_end_to_end_segdet_basic() {
2293        // Create synthetic segdet output: (6 + num_protos, N)
2294        // Detection format: [x1, y1, x2, y2, conf, class, mask_coeff_0..31]
2295        let num_protos = 32;
2296        let num_detections = 2;
2297        let num_features = 6 + num_protos;
2298
2299        // Build detection tensor
2300        let mut data = vec![0.0f32; num_features * num_detections];
2301        // Detection 0: passes threshold
2302        data[0] = 0.1; // x1[0]
2303        data[1] = 0.5; // x1[1]
2304        data[num_detections] = 0.1; // y1[0]
2305        data[num_detections + 1] = 0.5; // y1[1]
2306        data[2 * num_detections] = 0.4; // x2[0]
2307        data[2 * num_detections + 1] = 0.9; // x2[1]
2308        data[3 * num_detections] = 0.4; // y2[0]
2309        data[3 * num_detections + 1] = 0.9; // y2[1]
2310        data[4 * num_detections] = 0.9; // conf[0] - passes
2311        data[4 * num_detections + 1] = 0.3; // conf[1] - fails
2312        data[5 * num_detections] = 1.0; // class[0]
2313        data[5 * num_detections + 1] = 2.0; // class[1]
2314                                            // Fill mask coefficients with small values
2315        for i in 6..num_features {
2316            data[i * num_detections] = 0.1;
2317            data[i * num_detections + 1] = 0.1;
2318        }
2319
2320        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2321
2322        // Create protos tensor: (proto_height, proto_width, num_protos)
2323        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2324
2325        let mut boxes = Vec::with_capacity(10);
2326        let mut masks = Vec::with_capacity(10);
2327        decode_yolo_end_to_end_segdet_float(
2328            output.view(),
2329            protos.view(),
2330            0.5,
2331            &mut boxes,
2332            &mut masks,
2333        )
2334        .unwrap();
2335
2336        // Only detection 0 should pass
2337        assert_eq!(boxes.len(), 1);
2338        assert_eq!(masks.len(), 1);
2339        assert_eq!(boxes[0].label, 1);
2340        assert!((boxes[0].score - 0.9).abs() < 0.01);
2341    }
2342
2343    #[test]
2344    fn test_end_to_end_segdet_mask_coordinates() {
2345        // Test that mask coordinates match box coordinates
2346        let num_protos = 32;
2347        let num_features = 6 + num_protos;
2348
2349        let mut data = vec![0.0f32; num_features];
2350        data[0] = 0.2; // x1
2351        data[1] = 0.2; // y1
2352        data[2] = 0.8; // x2
2353        data[3] = 0.8; // y2
2354        data[4] = 0.95; // conf
2355        data[5] = 3.0; // class
2356
2357        let output = Array2::from_shape_vec((num_features, 1), data).unwrap();
2358        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2359
2360        let mut boxes = Vec::with_capacity(10);
2361        let mut masks = Vec::with_capacity(10);
2362        decode_yolo_end_to_end_segdet_float(
2363            output.view(),
2364            protos.view(),
2365            0.5,
2366            &mut boxes,
2367            &mut masks,
2368        )
2369        .unwrap();
2370
2371        assert_eq!(boxes.len(), 1);
2372        assert_eq!(masks.len(), 1);
2373
2374        // Mask region is the proto-grid-aligned crop and encloses the
2375        // post-NMS bbox (EDGEAI-1304); on a 16x16 grid each side may snap
2376        // by up to 1/16 = 0.0625.
2377        let step = 1.0 / 16.0;
2378        assert!(masks[0].xmin <= boxes[0].bbox.xmin);
2379        assert!(masks[0].ymin <= boxes[0].bbox.ymin);
2380        assert!(masks[0].xmax >= boxes[0].bbox.xmax);
2381        assert!(masks[0].ymax >= boxes[0].bbox.ymax);
2382        assert!((boxes[0].bbox.xmin - masks[0].xmin) < step);
2383        assert!((boxes[0].bbox.ymin - masks[0].ymin) < step);
2384        assert!((masks[0].xmax - boxes[0].bbox.xmax) < step);
2385        assert!((masks[0].ymax - boxes[0].bbox.ymax) < step);
2386    }
2387
2388    #[test]
2389    fn test_end_to_end_segdet_empty_output() {
2390        let num_protos = 32;
2391        let output = Array2::<f32>::zeros((6 + num_protos, 0));
2392        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2393
2394        let mut boxes = Vec::with_capacity(10);
2395        let mut masks = Vec::with_capacity(10);
2396        decode_yolo_end_to_end_segdet_float(
2397            output.view(),
2398            protos.view(),
2399            0.5,
2400            &mut boxes,
2401            &mut masks,
2402        )
2403        .unwrap();
2404
2405        assert_eq!(boxes.len(), 0);
2406        assert_eq!(masks.len(), 0);
2407    }
2408
2409    #[test]
2410    fn test_end_to_end_segdet_capacity_limit() {
2411        let num_protos = 32;
2412        let num_detections = 5;
2413        let num_features = 6 + num_protos;
2414
2415        let mut data = vec![0.0f32; num_features * num_detections];
2416        // All detections pass threshold
2417        for i in 0..num_detections {
2418            data[i] = 0.1 * (i as f32); // x1
2419            data[num_detections + i] = 0.1 * (i as f32); // y1
2420            data[2 * num_detections + i] = 0.1 * (i as f32) + 0.2; // x2
2421            data[3 * num_detections + i] = 0.1 * (i as f32) + 0.2; // y2
2422            data[4 * num_detections + i] = 0.9; // conf
2423            data[5 * num_detections + i] = i as f32; // class
2424        }
2425
2426        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2427        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2428
2429        let mut boxes = Vec::with_capacity(2); // Limit to 2
2430        let mut masks = Vec::with_capacity(2);
2431        decode_yolo_end_to_end_segdet_float(
2432            output.view(),
2433            protos.view(),
2434            0.5,
2435            &mut boxes,
2436            &mut masks,
2437        )
2438        .unwrap();
2439
2440        assert_eq!(boxes.len(), 2);
2441        assert_eq!(masks.len(), 2);
2442    }
2443
2444    #[test]
2445    fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
2446        // Test with too few rows (needs at least 7: 6 base + 1 mask coeff)
2447        let output = Array2::<f32>::zeros((6, 3));
2448        let protos = Array3::<f32>::zeros((16, 16, 32));
2449
2450        let mut boxes = Vec::with_capacity(10);
2451        let mut masks = Vec::with_capacity(10);
2452        let result = decode_yolo_end_to_end_segdet_float(
2453            output.view(),
2454            protos.view(),
2455            0.5,
2456            &mut boxes,
2457            &mut masks,
2458        );
2459
2460        assert!(result.is_err());
2461        assert!(matches!(
2462            result,
2463            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
2464        ));
2465    }
2466
/// The number of mask-coefficient rows in the detection tensor must agree
/// with the proto channel count; 16 coefficients against 32 protos has to
/// be rejected with `InvalidShape`.
#[test]
fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
    let num_protos = 32;
    let coeff_rows = 16; // deliberately different from `num_protos`
    let detections = Array2::<f32>::zeros((6 + coeff_rows, 3));
    let proto_maps = Array3::<f32>::zeros((16, 16, num_protos));
    let mut boxes = Vec::with_capacity(10);
    let mut masks = Vec::with_capacity(10);

    let outcome = decode_yolo_end_to_end_segdet_float(
        detections.view(),
        proto_maps.view(),
        0.5,
        &mut boxes,
        &mut masks,
    );

    assert!(outcome.is_err());
    // The error must be the shape error that reports the count mismatch.
    let rejected_for_mismatch = matches!(
        &outcome,
        Err(crate::DecoderError::InvalidShape(msg)) if msg.contains("doesn't match protos count")
    );
    assert!(rejected_for_mismatch);
}
2490
2491    // ========================================================================
2492    // Tests for decode_yolo_split_end_to_end_segdet_float
2493    // ========================================================================
2494
/// Split end-to-end segdet decode: two detections are fed in as separate
/// box / score / class / mask-coefficient views, and only the one whose
/// confidence clears the 0.5 threshold may come out.
#[test]
fn test_split_end_to_end_segdet_basic() {
    // Synthetic tensor of shape (6 + num_protos, N), one row per feature:
    // [x1, y1, x2, y2, conf, class, mask_coeff_0..31].
    let num_protos = 32;
    let num_detections = 2;
    let num_features = 6 + num_protos;

    let mut data = vec![0.0f32; num_features * num_detections];
    // One entry per base row; columns are detection 0 and detection 1.
    let base_rows: [[f32; 2]; 6] = [
        [0.1, 0.5], // x1
        [0.1, 0.5], // y1
        [0.4, 0.9], // x2
        [0.4, 0.9], // y2
        [0.9, 0.3], // conf: detection 0 passes, detection 1 fails
        [1.0, 2.0], // class
    ];
    for (row, values) in base_rows.iter().enumerate() {
        let start = row * num_detections;
        data[start..start + num_detections].copy_from_slice(values);
    }
    // Every mask coefficient (rows 6..) gets the same small value.
    data[6 * num_detections..].fill(0.1);

    let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
    // Proto tensor laid out as (proto_height, proto_width, num_protos).
    let protos = Array3::<f32>::zeros((16, 16, num_protos));

    let mut boxes = Vec::with_capacity(10);
    let mut masks = Vec::with_capacity(10);
    decode_yolo_split_end_to_end_segdet_float(
        output.slice(s![..4, ..]),
        output.slice(s![4..5, ..]),
        output.slice(s![5..6, ..]),
        output.slice(s![6.., ..]),
        protos.view(),
        0.5,
        &mut boxes,
        &mut masks,
    )
    .unwrap();

    // Detection 0 is the sole survivor.
    assert_eq!(boxes.len(), 1);
    assert_eq!(masks.len(), 1);
    let kept = &boxes[0];
    assert_eq!(kept.label, 1);
    assert!((kept.score - 0.9).abs() < 0.01);
}
2552
2553    // ========================================================================
2554    // Tests for yolo_segmentation_to_mask
2555    // ========================================================================
2556
/// Thresholding a 4x4x1 segmentation at 128: values at or above the
/// threshold become 1, everything else becomes 0.
#[test]
fn test_segmentation_to_mask_basic() {
    #[rustfmt::skip]
    let pixels: Vec<u8> = vec![
        100, 200,  50, 150, // row 0
         10, 255, 128,  64, // row 1
          0, 127, 128, 255, // row 2
         64,  64, 192, 192, // row 3
    ];
    let segmentation = Array3::from_shape_vec((4, 4, 1), pixels).unwrap();

    let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();

    // (position, expected mask value) for a sample of pixels, including
    // the boundary cases 127 (just below) and 128 (exactly on).
    let expectations: [([usize; 2], _); 8] = [
        ([0, 0], 0), // 100 < 128
        ([0, 1], 1), // 200 >= 128
        ([0, 2], 0), // 50 < 128
        ([0, 3], 1), // 150 >= 128
        ([1, 1], 1), // 255 >= 128
        ([1, 2], 1), // 128 >= 128
        ([2, 0], 0), // 0 < 128
        ([2, 1], 0), // 127 < 128
    ];
    for (position, expected) in expectations {
        assert_eq!(mask[position], expected);
    }
}
2580
/// A segmentation filled with 255 thresholded at 128 yields an all-ones mask.
#[test]
fn test_segmentation_to_mask_all_above() {
    let saturated = Array3::from_elem((4, 4, 1), 255u8);
    let mask = yolo_segmentation_to_mask(saturated.view(), 128).unwrap();
    // No pixel may be anything other than 1.
    assert!(!mask.iter().any(|&px| px != 1));
}
2587
/// A segmentation filled with 64 thresholded at 128 yields an all-zeros mask.
#[test]
fn test_segmentation_to_mask_all_below() {
    let dim = Array3::from_elem((4, 4, 1), 64u8);
    let mask = yolo_segmentation_to_mask(dim.view(), 128).unwrap();
    // No pixel may be anything other than 0.
    assert!(!mask.iter().any(|&px| px != 0));
}
2594
/// A segmentation whose last axis is 3 instead of 1 must be rejected with
/// an `InvalidShape` error that mentions the expected `(H, W, 1)` layout.
#[test]
fn test_segmentation_to_mask_invalid_shape() {
    let three_channel = Array3::from_elem((4, 4, 3), 128u8);
    let outcome = yolo_segmentation_to_mask(three_channel.view(), 128);

    assert!(outcome.is_err());
    let rejected_for_layout = matches!(
        &outcome,
        Err(crate::DecoderError::InvalidShape(msg)) if msg.contains("(H, W, 1)")
    );
    assert!(rejected_for_layout);
}
2606
2607    // ========================================================================
2608    // Tests for protobox / NORM_LIMIT regression
2609    // ========================================================================
2610
/// A box touching the right/bottom edge (xmax = ymax = 1.0) must be
/// accepted by `protobox` without panicking and produce a non-empty crop.
#[test]
fn test_protobox_clamps_edge_coordinates() {
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let view = protos.view();
    let edge_roi = BoundingBox {
        xmin: 0.5,
        ymin: 0.5,
        xmax: 1.0,
        ymax: 1.0,
    };

    let result = protobox(&view, &edge_roi);
    assert!(result.is_ok(), "protobox should accept xmax=1.0");

    let (cropped, _roi_norm) = result.unwrap();
    let dims = cropped.shape();
    // Both spatial extents must be non-zero; the proto channel count is kept.
    assert_ne!(dims[0], 0);
    assert_ne!(dims[1], 0);
    assert_eq!(dims[2], 4);
}
2630
/// Coordinates far beyond the normalized range (3.0 here, above
/// NORM_LIMIT) must make `protobox` return an `InvalidShape` error that
/// flags the box as un-normalized.
#[test]
fn test_protobox_rejects_wildly_out_of_range() {
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let view = protos.view();
    let oversized_roi = BoundingBox {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 3.0,
        ymax: 3.0,
    };

    let result = protobox(&view, &oversized_roi);
    let rejected = matches!(
        &result,
        Err(crate::DecoderError::InvalidShape(msg)) if msg.contains("un-normalized")
    );
    assert!(rejected, "protobox should reject coords > NORM_LIMIT");
}
2648
/// Coordinates of 1.5 are above 1.0 but still within NORM_LIMIT (2.0):
/// `protobox` must accept them and select the whole proto map.
#[test]
fn test_protobox_accepts_slightly_over_one() {
    let protos = Array3::<f32>::zeros((16, 16, 4));
    let view = protos.view();
    let overshooting_roi = BoundingBox {
        xmin: 0.0,
        ymin: 0.0,
        xmax: 1.5,
        ymax: 1.5,
    };

    let result = protobox(&view, &overshooting_roi);
    assert!(
        result.is_ok(),
        "protobox should accept coords <= NORM_LIMIT (2.0)"
    );

    let (cropped, _roi_norm) = result.unwrap();
    // The overshoot is clamped to the boundary, so the crop spans the
    // full 16x16 proto plane.
    let dims = cropped.shape();
    assert_eq!(dims[0], 16);
    assert_eq!(dims[1], 16);
}
2670
/// Regression test for a panic in the float proto path: with more
/// proposals than mask coefficients, looking up a mask row by proposal
/// index used to go out of range (`mask_tensor.row(idx)` with idx >= 32).
#[test]
fn test_segdet_float_proto_no_panic() {
    // Shaped like a YOLOv8n-seg head: 4 box rows + 80 class rows + 32 mask
    // coefficient rows = 116 rows, with enough proposals for idx >= 32.
    let num_proposals = 100;
    let num_classes = 80;
    let num_mask_coeffs = 32;
    let rows = 4 + num_classes + num_mask_coeffs; // 116

    // Row-major [rows, num_proposals]: row 0 = cx, 1 = cy, 2 = w, 3 = h,
    // rows 4..84 = class scores, rows 84..116 = mask coefficients.
    // Every proposal gets the same box and a class-0 score above threshold.
    let mut data = vec![0.0f32; rows * num_proposals];
    for (row, value) in [(0, 320.0), (1, 320.0), (2, 50.0), (3, 50.0), (4, 0.9)] {
        data[row * num_proposals..(row + 1) * num_proposals].fill(value);
    }
    let boxes = Array2::from_shape_vec((rows, num_proposals), data).unwrap();

    // Protos are supplied in HWC order: (height, width, num_protos).
    let protos = Array3::<f32>::zeros((160, 160, num_mask_coeffs));

    let mut output_boxes = Vec::with_capacity(300);
    let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
        boxes.view(),
        protos.view(),
        0.5,
        0.7,
        Some(Nms::default()),
        MAX_NMS_CANDIDATES,
        300,
        None,
        None,
        &mut output_boxes,
    );

    // At least one detection survives, and there is exactly one
    // 32-element coefficient vector per surviving box.
    assert!(!output_boxes.is_empty());
    let coeffs_shape = proto_data.mask_coefficients.shape();
    assert_eq!(coeffs_shape[0], output_boxes.len());
    assert_eq!(coeffs_shape[1], num_mask_coeffs);
}
2723
2724    // ========================================================================
2725    // Pre-NMS top-K cap (MAX_NMS_CANDIDATES)
2726    // ========================================================================
2727
/// At very low score thresholds (e.g., t=0.01 on YOLOv8 with 8400×80
/// candidates) almost every score passes the filter, feeding O(n²)
/// NMS and a per-survivor mask matmul. The decoder caps the
/// candidate set fed to NMS at `MAX_NMS_CANDIDATES` (Ultralytics
/// default 30 000) to bound worst-case decode time.
///
/// This regression test pumps 50 000 above-threshold candidates
/// into `impl_yolo_segdet_get_boxes` with class-agnostic NMS enabled
/// at an IoU threshold of 1.0 (so NMS itself is expected to suppress
/// nothing) and no post-NMS truncation. Before the fix, the function
/// returned all 50 000; after the fix, exactly `MAX_NMS_CANDIDATES`.
///
/// Besides the count, the test checks *which* candidates survived:
/// scores are strictly decreasing with the candidate index, so the
/// survivors must be exactly the first `MAX_NMS_CANDIDATES` candidates.
#[test]
fn test_pre_nms_cap_truncates_excess_candidates() {
    let n: usize = 50_000;
    let num_classes = 1;
    let cap = crate::yolo::MAX_NMS_CANDIDATES;

    // score_i = 0.99 - i * 1e-7: everything stays well above the 0.1
    // threshold, and the 1e-7 step is larger than one f32 ULP near 0.99,
    // so consecutive scores are distinct and strictly decreasing.
    let score_at = |i: usize| 0.99f32 - (i as f32) * 1e-7;

    // Identical valid boxes; only the scores differ.
    let boxes_data = [0.1f32, 0.1, 0.5, 0.5].repeat(n);
    let scores_data: Vec<f32> = (0..n).map(score_at).collect();
    let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
    let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();

    let result = impl_yolo_segdet_get_boxes::<XYXY, _, _>(
        boxes.view(),
        scores.view(),
        0.1,
        1.0,                      // IoU 1.0 → NMS suppresses nothing
        Some(Nms::ClassAgnostic), // NMS enabled so pre_nms_top_k applies
        cap,                      // pre_nms_top_k
        usize::MAX,               // no post-NMS truncation
    );

    assert_eq!(
        result.len(),
        cap,
        "pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
        result.len()
    );

    // The highest scores belong to the first `cap` candidates, so
    // survivor 0 must have score ~0.99.
    let top_score = result[0].0.score;
    assert!(
        top_score > 0.98,
        "highest-ranked survivor should have the largest score, got {top_score}"
    );

    // The bound above holds for every generated score (the minimum is
    // about 0.985), so on its own it cannot tell top-K from any other
    // subset. Pin the surviving set by its extremes: the weakest survivor
    // must be no lower than candidate `cap - 1`, and the best candidate
    // (index 0) must be among the survivors.
    let cutoff = score_at(cap - 1);
    let weakest = result
        .iter()
        .map(|entry| entry.0.score)
        .fold(f32::INFINITY, f32::min);
    assert!(
        weakest >= cutoff,
        "survivors must be the top-K by score: weakest survivor {weakest} is below cutoff {cutoff}"
    );
    let strongest = result
        .iter()
        .map(|entry| entry.0.score)
        .fold(f32::NEG_INFINITY, f32::max);
    assert!(
        strongest >= score_at(0),
        "best candidate must survive the cap: strongest survivor {strongest}"
    );
}
2781
/// Quantized split-path counterpart of the pre-NMS cap test: push more
/// than `MAX_NMS_CANDIDATES` above-threshold candidates through
/// `impl_yolo_split_segdet_quant_get_boxes` and confirm the result is
/// truncated to the cap.
#[test]
fn test_pre_nms_cap_truncates_excess_candidates_quant() {
    use crate::Quantization;
    let n: usize = 50_000;
    let num_classes = 1;

    // One flat i8 box repeated n times. With scale 0.01 and zero point 0,
    // the raw value 50 dequantizes to 0.5.
    let boxes_data = [10i8, 10, 50, 50].repeat(n);
    let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
    let quant_boxes = Quantization {
        scale: 0.01,
        zero_point: 0,
    };

    // u8 scores cycle through 250, 249, ..., 51 (period 200). With scale
    // 0.00392 and zero point 0 that is roughly 0.98 down to 0.20, so every
    // candidate is above the 0.1 score threshold used below.
    // `250 - i % 200` is always within 51..=250, so the cast cannot truncate.
    let scores_data: Vec<u8> = (0..n).map(|i| (250 - i % 200) as u8).collect();
    let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();
    let quant_scores = Quantization {
        scale: 0.00392,
        zero_point: 0,
    };

    let result = impl_yolo_split_segdet_quant_get_boxes::<XYXY, _, _>(
        (boxes.view(), quant_boxes),
        (scores.view(), quant_scores),
        0.1,
        1.0,                             // IoU 1.0 → NMS suppresses nothing
        Some(Nms::ClassAgnostic),        // NMS enabled so pre_nms_top_k applies
        crate::yolo::MAX_NMS_CANDIDATES, // pre_nms_top_k
        usize::MAX,                      // no post-NMS truncation
    );

    assert_eq!(
        result.len(),
        crate::yolo::MAX_NMS_CANDIDATES,
        "quant path pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
        result.len()
    );
}
2831
/// Regression test for HAILORT_BUG.md. On the YoloSegDet path (one
/// combined `(4 + nc + nm, N)` detection tensor plus separate protos),
/// every surviving detection must be paired with the mask-coefficient
/// row of the SAME anchor its box came from. The reported symptom under
/// schema-v2 Hailo inputs was an mAP collapse from 46.8 → 3.65 while
/// mask IoU stayed at 66.9, the fingerprint of mask-to-detection
/// misalignment.
///
/// Setup: three anchors; anchors 0 and 2 clear the score threshold and
/// carry opposite mask-coefficient signatures, so after
/// dot(coefs, protos) + sigmoid one mask is near-saturated and the other
/// near-zero. Reading the coefficients from the wrong anchor flips the
/// per-detection mean mask value across the asserted bounds.
#[test]
fn segdet_combined_tensor_pairs_detection_with_matching_mask_row() {
    let nc = 2; // num_classes
    let nm = 2; // num_protos
    let n = 3; // num_anchors
    let feat = 4 + nc + nm; // 8

    // Tensor layout: (8, 3), rows = features, columns = anchors.
    // Rows 0..4 = xywh, 4..6 = class scores, 6..8 = mask coefficients.
    // Row 5 (class 1) is left at zero so class 1 always loses.
    //
    // With both protos all-ones (8x8):
    //   mask(a0) = sigmoid(3 + 3) ≈ 0.9975 → 254
    //   mask(a2) = sigmoid(-3 + -3) ≈ 0.0025 → 1
    // so any misalignment shows up as a ~250-point swing.
    let feature_rows: [(usize, [f32; 3]); 7] = [
        //    anchor 0, anchor 1, anchor 2
        (0, [0.2, 0.5, 0.8]),  // xc
        (1, [0.2, 0.5, 0.8]),  // yc
        (2, [0.1, 0.1, 0.1]),  // w
        (3, [0.1, 0.1, 0.1]),  // h
        (4, [0.9, 0.0, 0.8]),  // s[0] (class 0)
        (6, [3.0, 0.0, -3.0]), // m[0]: high for a0, low for a2
        (7, [3.0, 0.0, -3.0]), // m[1]
    ];
    let mut data = vec![0.0f32; feat * n];
    for (row, values) in feature_rows {
        data[row * n..(row + 1) * n].copy_from_slice(&values);
    }

    let output = Array2::from_shape_vec((feat, n), data).unwrap();
    let protos = Array3::<f32>::from_elem((8, 8, nm), 1.0);

    let mut boxes: Vec<DetectBox> = Vec::with_capacity(10);
    let mut masks: Vec<Segmentation> = Vec::with_capacity(10);
    decode_yolo_segdet_float(
        output.view(),
        protos.view(),
        0.5,
        0.5,
        Some(Nms::ClassAgnostic),
        &mut boxes,
        &mut masks,
    )
    .unwrap();

    assert_eq!(
        boxes.len(),
        2,
        "two anchors above threshold should survive (a0 score=0.9, a2 score=0.8); got {}",
        boxes.len()
    );

    // Identify each detection by its box centre (0.2 for anchor 0, 0.8
    // for anchor 2) rather than by exact bbox, since cropping inside
    // protobox may shrink the box, then check the mean of its mask.
    for (detection, mask) in boxes.iter().zip(&masks) {
        let cx = (detection.bbox.xmin + detection.bbox.xmax) * 0.5;
        let pixels = &mask.segmentation;
        let pixel_sum: u32 = pixels.iter().map(|&v| v as u32).sum();
        let mean = pixel_sum as f32 / pixels.len() as f32;

        match cx {
            // anchor 0 — expect HIGH mask values ≈ 254
            c if c < 0.3 => assert!(
                mean > 200.0,
                "anchor 0 detection (centre {cx:.2}) should have high-value mask; got mean {mean}"
            ),
            // anchor 2 — expect LOW mask values ≈ 1
            c if c > 0.7 => assert!(
                mean < 50.0,
                "anchor 2 detection (centre {cx:.2}) should have low-value mask; got mean {mean}"
            ),
            _ => panic!("unexpected detection centre {cx:.2}"),
        }
    }
}
2942
2943    // ========================================================================
2944    // Tests for truncate_to_top_k_by_score / truncate_to_top_k_by_score_quant
2945    // ========================================================================
2946
2947    /// Helper: build a Vec of (DetectBox, ()) with the given scores.
2948    fn make_float_boxes(scores: &[f32]) -> Vec<(DetectBox, ())> {
2949        scores
2950            .iter()
2951            .enumerate()
2952            .map(|(i, &s)| {
2953                (
2954                    DetectBox {
2955                        bbox: BoundingBox {
2956                            xmin: 0.0,
2957                            ymin: 0.0,
2958                            xmax: 1.0,
2959                            ymax: 1.0,
2960                        },
2961                        score: s,
2962                        label: i,
2963                    },
2964                    (),
2965                )
2966            })
2967            .collect()
2968    }
2969
2970    /// Helper: build a Vec of (DetectBoxQuantized<i8>, ()) with the given scores.
2971    fn make_quant_boxes(scores: &[i8]) -> Vec<(DetectBoxQuantized<i8>, ())> {
2972        scores
2973            .iter()
2974            .enumerate()
2975            .map(|(i, &s)| {
2976                (
2977                    DetectBoxQuantized {
2978                        bbox: BoundingBox {
2979                            xmin: 0.0,
2980                            ymin: 0.0,
2981                            xmax: 1.0,
2982                            ymax: 1.0,
2983                        },
2984                        score: s,
2985                        label: i,
2986                    },
2987                    (),
2988                )
2989            })
2990            .collect()
2991    }
2992
/// `top_k == 0` means "no limit": the candidate list must be left at its
/// full length.
#[test]
fn truncate_float_top_k_zero_is_unbounded() {
    let scores: [f32; 5] = [0.9, 0.1, 0.5, 0.3, 0.7];
    let mut boxes = make_float_boxes(&scores);

    truncate_to_top_k_by_score(&mut boxes, 0);

    assert_eq!(
        boxes.len(),
        scores.len(),
        "top_k=0 should keep all candidates (no-limit semantics)"
    );
}
3004
/// Truncating five candidates to `top_k = 3` keeps the three highest
/// scores.
#[test]
fn truncate_float_top_k_normal() {
    let mut boxes = make_float_boxes(&[0.9, 0.1, 0.5, 0.3, 0.7]);

    truncate_to_top_k_by_score(&mut boxes, 3);
    assert_eq!(boxes.len(), 3);

    // Order inside the retained set is unspecified, so sort descending
    // before comparing against the expected 0.9, 0.7, 0.5.
    let mut retained: Vec<f32> = boxes.iter().map(|entry| entry.0.score).collect();
    retained.sort_by(|lhs, rhs| lhs.total_cmp(rhs).reverse());
    assert_eq!(retained, vec![0.9, 0.7, 0.5]);
}
3015
/// With fewer candidates than `top_k`, truncation must leave the list
/// length unchanged.
#[test]
fn truncate_float_top_k_noop_when_under_cap() {
    let scores: [f32; 2] = [0.9, 0.5];
    let mut boxes = make_float_boxes(&scores);

    truncate_to_top_k_by_score(&mut boxes, 10);

    assert_eq!(boxes.len(), 2, "should be no-op when len <= top_k");
}
3022
/// Quantized variant: `top_k == 0` means "no limit", so the candidate
/// list must be left at its full length.
#[test]
fn truncate_quant_top_k_zero_is_unbounded() {
    let scores: [i8; 5] = [120, -50, 30, -10, 80];
    let mut boxes = make_quant_boxes(&scores);

    truncate_to_top_k_by_score_quant(&mut boxes, 0);

    assert_eq!(
        boxes.len(),
        scores.len(),
        "top_k=0 should keep all candidates (no-limit semantics)"
    );
}
3034
/// Quantized variant: truncating five candidates to `top_k = 3` keeps
/// the three highest raw scores.
#[test]
fn truncate_quant_top_k_normal() {
    let mut boxes = make_quant_boxes(&[120, -50, 30, -10, 80]);

    truncate_to_top_k_by_score_quant(&mut boxes, 3);
    assert_eq!(boxes.len(), 3);

    // Sort descending before comparing; only the retained set is checked,
    // not its order.
    let mut retained: Vec<i8> = boxes.iter().map(|entry| entry.0.score).collect();
    retained.sort_by_key(|&score| std::cmp::Reverse(score));
    assert_eq!(retained, vec![120, 80, 30]);
}
3044
/// Quantized variant: with fewer candidates than `top_k`, truncation
/// must leave the list length unchanged.
#[test]
fn truncate_quant_top_k_noop_when_under_cap() {
    let scores: [i8; 2] = [120, 80];
    let mut boxes = make_quant_boxes(&scores);

    truncate_to_top_k_by_score_quant(&mut boxes, 10);

    assert_eq!(boxes.len(), 2, "should be no-op when len <= top_k");
}
3051}