Skip to main content

edgefirst_decoder/
yolo.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4//! Internal YOLO decoder kernels.
5//!
6//! All items in this module are `pub(crate)`. The public API surface for
7//! decoding is [`crate::Decoder`] + [`crate::DecoderBuilder`]; external
8//! callers must go through that entry point so we can evolve the kernels
9//! (split paths, dispatch tables, NEON tiers) without breaking semver.
10//! See `CHANGELOG.md` for the 0.20.0 narrowing.
11
12use std::fmt::Debug;
13
14use ndarray::{
15    parallel::prelude::{IntoParallelIterator, ParallelIterator},
16    s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
17};
18use num_traits::{AsPrimitive, Float, PrimInt, Signed};
19
20use crate::{
21    byte::{
22        nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
23        postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
24    },
25    configs::Nms,
26    dequant_detect_box,
27    float::{
28        nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
29        postprocess_boxes_float, postprocess_boxes_index_float,
30    },
31    BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoLayout,
32    Quantization, Segmentation, XYWH, XYXY,
33};
34
/// Maximum number of above-threshold candidates fed to NMS.
///
/// At very low score thresholds (e.g., t=0.01 on YOLOv8 with
/// 8400 anchors × 80 classes), the number of survivors approaches
/// the full 672 000-entry score grid. NMS is O(n²) and the
/// downstream mask matmul runs once per survivor, so an
/// unbounded set produces minutes-per-frame decode times.
///
/// `MAX_NMS_CANDIDATES` matches the Ultralytics `max_nms` default
/// and is applied as a top-K-by-score truncation immediately
/// before NMS. Candidates ranked beyond the cap are silently
/// dropped — at the score thresholds where the cap activates, the
/// bottom of the candidate list is dominated by noise that NMS
/// would discard anyway.
///
/// Production callers configure this via [`crate::DecoderBuilder::with_pre_nms_top_k`];
/// only the test-only seg-det wrappers reach for this constant directly,
/// which is why it is compiled under `cfg(test)` only.
#[cfg(test)]
pub(crate) const MAX_NMS_CANDIDATES: usize = 30_000;
54
/// Default post-NMS detection cap used by the free `decode_yolo_*`
/// convenience wrappers when no explicit cap is plumbed in, i.e. when the
/// caller's output `Vec` has zero capacity. Mirrors the `Decoder::max_det`
/// default set by `DecoderBuilder` (also 300, matching the Ultralytics
/// `max_det` default). Pre-EDGEAI-1302 these wrappers used
/// `output_boxes.capacity()` as the cap unconditionally, which silently
/// dropped all detections when the caller passed `Vec::new()`.
pub(crate) const DEFAULT_MAX_DETECTIONS: usize = 300;
62
63/// Truncate `boxes` to the highest-scoring `top_k` entries in-place when the
64/// input exceeds the cap. Uses partial sort (O(N)) via `select_nth_unstable_by`
65/// to avoid full O(N log N) sort. No-op when `top_k` is 0 (unbounded) or
66/// when the input length ≤ `top_k`.
67fn truncate_to_top_k_by_score<E: Send>(boxes: &mut Vec<(DetectBox, E)>, top_k: usize) {
68    if top_k > 0 && boxes.len() > top_k {
69        boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.total_cmp(&a.0.score));
70        boxes.truncate(top_k);
71    }
72}
73
74/// Quantized counterpart of [`truncate_to_top_k_by_score`]. Sorts on
75/// the raw quantized score (which preserves order under monotonic
76/// dequantization). Uses partial sort (O(N)) via `select_nth_unstable_by`.
77/// No-op when `top_k` is 0 (unbounded) or when the input length ≤ `top_k`.
78fn truncate_to_top_k_by_score_quant<S: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send>(
79    boxes: &mut Vec<(DetectBoxQuantized<S>, E)>,
80    top_k: usize,
81) {
82    if top_k > 0 && boxes.len() > top_k {
83        boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.cmp(&a.0.score));
84        boxes.truncate(top_k);
85    }
86}
87
88/// Dispatches to the appropriate NMS function based on mode for float boxes.
89///
90/// `max_det` is the post-NMS detection cap; when present it lets the greedy
91/// inner loop break as soon as that many survivors are confirmed (the survivors
92/// are guaranteed to be the top-`max_det` by score because the input is sorted
93/// descending). Pass `None` to run the full O(N²) suppression.
94fn dispatch_nms_float(
95    nms: Option<Nms>,
96    iou: f32,
97    max_det: Option<usize>,
98    boxes: Vec<DetectBox>,
99) -> Vec<DetectBox> {
100    match nms {
101        Some(Nms::ClassAgnostic | Nms::Auto) => nms_float(iou, max_det, boxes),
102        Some(Nms::ClassAware) => nms_class_aware_float(iou, max_det, boxes),
103        None => boxes, // bypass NMS
104    }
105}
106
107/// Dispatches to the appropriate NMS function based on mode for float boxes
108/// with extra data.
109pub(super) fn dispatch_nms_extra_float<E: Send + Sync>(
110    nms: Option<Nms>,
111    iou: f32,
112    max_det: Option<usize>,
113    boxes: Vec<(DetectBox, E)>,
114) -> Vec<(DetectBox, E)> {
115    match nms {
116        Some(Nms::ClassAgnostic | Nms::Auto) => nms_extra_float(iou, max_det, boxes),
117        Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, max_det, boxes),
118        None => boxes, // bypass NMS
119    }
120}
121
122/// Dispatches to the appropriate NMS function based on mode for quantized
123/// boxes.
124fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
125    nms: Option<Nms>,
126    iou: f32,
127    max_det: Option<usize>,
128    boxes: Vec<DetectBoxQuantized<SCORE>>,
129) -> Vec<DetectBoxQuantized<SCORE>> {
130    match nms {
131        Some(Nms::ClassAgnostic | Nms::Auto) => nms_int(iou, max_det, boxes),
132        Some(Nms::ClassAware) => nms_class_aware_int(iou, max_det, boxes),
133        None => boxes, // bypass NMS
134    }
135}
136
137/// Dispatches to the appropriate NMS function based on mode for quantized boxes
138/// with extra data.
139fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
140    nms: Option<Nms>,
141    iou: f32,
142    max_det: Option<usize>,
143    boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
144) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
145    match nms {
146        Some(Nms::ClassAgnostic | Nms::Auto) => nms_extra_int(iou, max_det, boxes),
147        Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, max_det, boxes),
148        None => boxes, // bypass NMS
149    }
150}
151
152/// Detection cap helper for the public free `decode_yolo_*` wrappers.
153///
154/// Encodes the convention documented above: if the caller passed a non-empty
155/// `Vec`, that capacity acts as the per-call cap; otherwise fall back to
156/// [`DEFAULT_MAX_DETECTIONS`] so freshly-constructed `Vec::new()` outputs
157/// don't silently drop every detection.
158#[inline]
159fn cap_or_default<T>(v: &Vec<T>) -> usize {
160    if v.capacity() > 0 {
161        v.capacity()
162    } else {
163        DEFAULT_MAX_DETECTIONS
164    }
165}
166
// ─── Free decode_yolo_* convenience wrappers (crate-internal) ──────────────
//
// Detection cap convention (applies to every `decode_yolo_*` free function
// below):
//
// These functions are the convenience layer for callers that don't go
// through `Decoder::decode()` (benches, FFI shims, ad-hoc test harnesses).
// They use **`output_boxes.capacity()` as a per-call detection cap**:
//
//   - When the caller passes `Vec::with_capacity(N)`, the post-NMS output
//     is truncated to at most `output_boxes.capacity()` detections. That
//     is `N` in practice, but `Vec::with_capacity` only guarantees a
//     capacity of *at least* `N`.
//   - When the caller passes `Vec::new()` (capacity 0), the implementation
//     falls back to the [`DEFAULT_MAX_DETECTIONS`] constant (300) so a
//     freshly-constructed `Vec` doesn't silently drop every detection.
//
// This is intentionally **different** from the `Decoder::decode()` /
// `Decoder::decode_proto()` contract, which bounds output count solely
// by [`Decoder::max_det`] (set via `DecoderBuilder::with_max_det`,
// default 300) regardless of the caller's `Vec` capacity (EDGEAI-1302).
//
// Use the `Decoder` API when you need explicit control over `max_det`,
// schema-driven decoding, or EDGEAI-1303 normalization. Use these free
// functions when you have raw tensors in hand and want a one-shot decode.
190
191/// Decodes YOLO detection outputs from quantized tensors into detection boxes.
192///
193/// Boxes are expected to be in XYWH format.
194///
195/// Expected shapes of inputs:
196/// - output: (4 + num_classes, num_boxes)
197///
198/// See the "Detection cap convention" comment above for how
199/// `output_boxes.capacity()` bounds the result count.
200pub(crate) fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
201    output: (ArrayView2<BOX>, Quantization),
202    score_threshold: f32,
203    iou_threshold: f32,
204    nms: Option<Nms>,
205    output_boxes: &mut Vec<DetectBox>,
206) where
207    f32: AsPrimitive<BOX>,
208{
209    impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
210}
211
212/// Decodes YOLO detection outputs from float tensors into detection boxes.
213///
214/// Boxes are expected to be in XYWH format.
215///
216/// Expected shapes of inputs:
217/// - output: (4 + num_classes, num_boxes)
218pub(crate) fn decode_yolo_det_float<T>(
219    output: ArrayView2<T>,
220    score_threshold: f32,
221    iou_threshold: f32,
222    nms: Option<Nms>,
223    output_boxes: &mut Vec<DetectBox>,
224) where
225    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
226    f32: AsPrimitive<T>,
227{
228    impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
229}
230
/// Test-only shim around the quantized seg-det kernel.
///
/// Production callers go through [`crate::Decoder::decode`]. This wrapper
/// only exists so the parity tests in `lib.rs` and `yolo.rs` can compare
/// kernel output against Decoder output without repeating the
/// `impl_yolo_segdet_quant` argument plumbing.
///
/// Boxes are expected to be in XYWH format. Expected shapes:
/// - `boxes`: `(4 + num_classes + num_protos, num_boxes)`
/// - `protos`: `(proto_height, proto_width, num_protos)`
///
/// The pre-NMS candidate cap is [`MAX_NMS_CANDIDATES`]; the post-NMS cap is
/// derived from the capacity of `output_boxes` via [`cap_or_default`].
///
/// # Errors
/// Propagates the error returned by `impl_yolo_segdet_quant` (documented as
/// `DecoderError::InvalidShape` if bounding boxes are not normalized).
#[cfg(test)]
pub(crate) fn decode_yolo_segdet_quant<
    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    protos: (ArrayView3<PROTO>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    f32: AsPrimitive<BOX>,
{
    // Read the cap before the kernel gets the `&mut` to the output vector.
    let detection_cap = cap_or_default(output_boxes);

    // This pre-Decoder shim has no schema, so there is no schema-derived
    // `normalized` or `input_dims` to pass (hence the two `None`s) and the
    // EDGEAI-1303 normalization is a no-op. Callers that need that path
    // should go through `DecoderBuilder` with a schema.
    impl_yolo_segdet_quant::<XYWH, _, _>(
        boxes,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        MAX_NMS_CANDIDATES,
        detection_cap,
        None,
        None,
        output_boxes,
        output_masks,
    )
}
279
/// Test-only shim around the float seg-det kernel; the float sibling of
/// [`decode_yolo_segdet_quant`].
///
/// Boxes are decoded as XYWH. The pre-NMS candidate cap is
/// [`MAX_NMS_CANDIDATES`]; the post-NMS cap is derived from the capacity of
/// `output_boxes` via [`cap_or_default`].
///
/// # Errors
/// Propagates whatever `impl_yolo_segdet_float` returns.
#[cfg(test)]
pub(crate) fn decode_yolo_segdet_float<T>(
    boxes: ArrayView2<T>,
    protos: ArrayView3<T>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    // Read the cap before the kernel gets the `&mut` to the output vector.
    let detection_cap = cap_or_default(output_boxes);

    // Pre-Decoder shim: schema-derived normalization (EDGEAI-1303) is not
    // available here, hence the two `None`s.
    impl_yolo_segdet_float::<XYWH, _, _>(
        boxes,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        MAX_NMS_CANDIDATES,
        detection_cap,
        None,
        None,
        output_boxes,
        output_masks,
    )
}
312
313/// Decodes YOLO split detection outputs from quantized tensors into detection
314/// boxes.
315///
316/// Boxes are expected to be in XYWH format.
317///
318/// Expected shapes of inputs:
319/// - boxes: (4, num_boxes)
320/// - scores: (num_classes, num_boxes)
321///
322/// # Panics
323/// Panics if shapes don't match the expected dimensions.
324pub(crate) fn decode_yolo_split_det_quant<
325    BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
326    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
327>(
328    boxes: (ArrayView2<BOX>, Quantization),
329    scores: (ArrayView2<SCORE>, Quantization),
330    score_threshold: f32,
331    iou_threshold: f32,
332    nms: Option<Nms>,
333    output_boxes: &mut Vec<DetectBox>,
334) where
335    f32: AsPrimitive<SCORE>,
336{
337    impl_yolo_split_quant::<XYWH, _, _>(
338        boxes,
339        scores,
340        score_threshold,
341        iou_threshold,
342        nms,
343        output_boxes,
344    );
345}
346
347/// Decodes YOLO split detection outputs from float tensors into detection
348/// boxes.
349///
350/// Boxes are expected to be in XYWH format.
351///
352/// Expected shapes of inputs:
353/// - boxes: (4, num_boxes)
354/// - scores: (num_classes, num_boxes)
355///
356/// # Panics
357/// Panics if shapes don't match the expected dimensions.
358pub(crate) fn decode_yolo_split_det_float<T>(
359    boxes: ArrayView2<T>,
360    scores: ArrayView2<T>,
361    score_threshold: f32,
362    iou_threshold: f32,
363    nms: Option<Nms>,
364    output_boxes: &mut Vec<DetectBox>,
365) where
366    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
367    f32: AsPrimitive<T>,
368{
369    impl_yolo_split_float::<XYWH, _, _>(
370        boxes,
371        scores,
372        score_threshold,
373        iou_threshold,
374        nms,
375        output_boxes,
376    );
377}
378
379/// Decodes end-to-end YOLO detection outputs (post-NMS from model).
380/// Expects an array of shape `(6, N)`, where the first dimension (rows)
381/// corresponds to the 6 per-detection features
382/// `[x1, y1, x2, y2, conf, class]` and the second dimension (columns)
383/// indexes the `N` detections.
384/// Boxes are output directly without NMS (the model already applied NMS).
385///
386/// Coordinates may be normalized `[0, 1]` or absolute pixel values depending
387/// on the model configuration. The caller should check
388/// `decoder.normalized_boxes()` to determine which.
389///
390/// # Errors
391///
392/// Returns `DecoderError::InvalidShape` if `output` has fewer than 6 rows.
393pub(crate) fn decode_yolo_end_to_end_det_float<T>(
394    output: ArrayView2<T>,
395    score_threshold: f32,
396    output_boxes: &mut Vec<DetectBox>,
397) -> Result<(), crate::DecoderError>
398where
399    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
400    f32: AsPrimitive<T>,
401{
402    // Validate input shape: need at least 6 rows (x1, y1, x2, y2, conf, class)
403    if output.shape()[0] < 6 {
404        return Err(crate::DecoderError::InvalidShape(format!(
405            "End-to-end detection output requires at least 6 rows, got {}",
406            output.shape()[0]
407        )));
408    }
409
410    // Input shape: (6, N) -> transpose to (N, 4) for boxes and (N, 1) for scores
411    let boxes = output.slice(s![0..4, ..]).reversed_axes();
412    let scores = output.slice(s![4..5, ..]).reversed_axes();
413    let classes = output.slice(s![5, ..]);
414    let mut boxes =
415        postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
416    boxes.truncate(cap_or_default(output_boxes));
417    output_boxes.clear();
418    for (mut b, i) in boxes.into_iter() {
419        b.label = classes[i].as_() as usize;
420        output_boxes.push(b);
421    }
422    // No NMS — model output is already post-NMS
423    Ok(())
424}
425
426/// Decodes end-to-end YOLO detection + segmentation outputs (post-NMS from
427/// model).
428///
429/// Input shapes:
430/// - detection: (6 + num_protos, N) where rows are [x1, y1, x2, y2, conf,
431///   class, mask_coeff_0, ..., mask_coeff_31]
432/// - protos: (proto_height, proto_width, num_protos)
433///
434/// Boxes are output directly without NMS (model already applied NMS).
435/// Coordinates may be normalized [0,1] or pixel values depending on model
436/// config.
437///
438/// # Errors
439///
440/// Returns `DecoderError::InvalidShape` if:
441/// - output has fewer than 7 rows (6 base + at least 1 mask coefficient)
442/// - protos shape doesn't match mask coefficients count
443pub(crate) fn decode_yolo_end_to_end_segdet_float<T>(
444    output: ArrayView2<T>,
445    protos: ArrayView3<T>,
446    score_threshold: f32,
447    output_boxes: &mut Vec<DetectBox>,
448    output_masks: &mut Vec<crate::Segmentation>,
449) -> Result<(), crate::DecoderError>
450where
451    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
452    f32: AsPrimitive<T>,
453{
454    let (boxes, scores, classes, mask_coeff) =
455        postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
456    let cap = cap_or_default(output_boxes);
457    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
458        boxes,
459        scores,
460        classes,
461        score_threshold,
462        cap,
463    );
464
465    // No NMS — model output is already post-NMS
466
467    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
468}
469
470/// Decodes split end-to-end YOLO detection outputs (post-NMS from model).
471///
472/// Input shapes (after batch dim removed):
473/// - boxes: (4, N) — xyxy pixel coordinates
474/// - scores: (1, N) — confidence of the top class
475/// - classes: (1, N) — class index of the top class
476///
477/// Boxes are output directly without NMS (model already applied NMS).
478pub(crate) fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
479    boxes: ArrayView2<T>,
480    scores: ArrayView2<T>,
481    classes: ArrayView2<T>,
482    score_threshold: f32,
483    output_boxes: &mut Vec<DetectBox>,
484) -> Result<(), crate::DecoderError> {
485    let n = boxes.shape()[1];
486
487    let cap = cap_or_default(output_boxes);
488    output_boxes.clear();
489
490    let (boxes, scores, classes) = postprocess_yolo_split_end_to_end_det(boxes, scores, &classes)?;
491
492    for i in 0..n {
493        let score: f32 = scores[[i, 0]].as_();
494        if score < score_threshold {
495            continue;
496        }
497        if output_boxes.len() >= cap {
498            break;
499        }
500        output_boxes.push(DetectBox {
501            bbox: BoundingBox {
502                xmin: boxes[[i, 0]].as_(),
503                ymin: boxes[[i, 1]].as_(),
504                xmax: boxes[[i, 2]].as_(),
505                ymax: boxes[[i, 3]].as_(),
506            },
507            score,
508            label: classes[i].as_() as usize,
509        });
510    }
511    Ok(())
512}
513
514/// Decodes split end-to-end YOLO detection + segmentation outputs.
515///
516/// Input shapes (after batch dim removed):
517/// - boxes: (4, N) — xyxy pixel coordinates
518/// - scores: (1, N) — confidence
519/// - classes: (1, N) — class index
520/// - mask_coeff: (num_protos, N) — mask coefficients per detection
521/// - protos: (proto_h, proto_w, num_protos) — prototype masks
522#[allow(clippy::too_many_arguments)]
523pub(crate) fn decode_yolo_split_end_to_end_segdet_float<T>(
524    boxes: ArrayView2<T>,
525    scores: ArrayView2<T>,
526    classes: ArrayView2<T>,
527    mask_coeff: ArrayView2<T>,
528    protos: ArrayView3<T>,
529    score_threshold: f32,
530    output_boxes: &mut Vec<DetectBox>,
531    output_masks: &mut Vec<crate::Segmentation>,
532) -> Result<(), crate::DecoderError>
533where
534    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
535    f32: AsPrimitive<T>,
536{
537    let (boxes, scores, classes, mask_coeff) =
538        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
539    let cap = cap_or_default(output_boxes);
540    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
541        boxes,
542        scores,
543        classes,
544        score_threshold,
545        cap,
546    );
547
548    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
549}
550
551#[allow(clippy::type_complexity)]
552pub(crate) fn postprocess_yolo_end_to_end_segdet<'a, T>(
553    output: &'a ArrayView2<'_, T>,
554    num_protos: usize,
555) -> Result<
556    (
557        ArrayView2<'a, T>,
558        ArrayView2<'a, T>,
559        ArrayView1<'a, T>,
560        ArrayView2<'a, T>,
561    ),
562    crate::DecoderError,
563> {
564    // Validate input shape: need at least 7 rows (6 base + at least 1 mask coeff)
565    if output.shape()[0] < 7 {
566        return Err(crate::DecoderError::InvalidShape(format!(
567            "End-to-end segdet output requires at least 7 rows, got {}",
568            output.shape()[0]
569        )));
570    }
571
572    let num_mask_coeffs = output.shape()[0] - 6;
573    if num_mask_coeffs != num_protos {
574        return Err(crate::DecoderError::InvalidShape(format!(
575            "Mask coefficients count ({}) doesn't match protos count ({})",
576            num_mask_coeffs, num_protos
577        )));
578    }
579
580    // Input shape: (6+num_protos, N) -> transpose for postprocessing
581    let boxes = output.slice(s![0..4, ..]).reversed_axes();
582    let scores = output.slice(s![4..5, ..]).reversed_axes();
583    let classes = output.slice(s![5, ..]);
584    let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
585    Ok((boxes, scores, classes, mask_coeff))
586}
587
588/// Postprocess yolo split end to end det by reversing axes of boxes,
589/// scores, and flattening the class tensor.
590/// Expected input shapes:
591/// - boxes: (4, N)
592/// - scores: (1, N)
593/// - classes: (1, N)
594#[allow(clippy::type_complexity)]
595pub(crate) fn postprocess_yolo_split_end_to_end_det<'a, 'b, 'c, BOXES, SCORES, CLASS>(
596    boxes: ArrayView2<'a, BOXES>,
597    scores: ArrayView2<'b, SCORES>,
598    classes: &'c ArrayView2<CLASS>,
599) -> Result<
600    (
601        ArrayView2<'a, BOXES>,
602        ArrayView2<'b, SCORES>,
603        ArrayView1<'c, CLASS>,
604    ),
605    crate::DecoderError,
606> {
607    let num_boxes = boxes.shape()[1];
608    if boxes.shape()[0] != 4 {
609        return Err(crate::DecoderError::InvalidShape(format!(
610            "Split end-to-end box_coords must be 4, got {}",
611            boxes.shape()[0]
612        )));
613    }
614
615    if scores.shape()[0] != 1 {
616        return Err(crate::DecoderError::InvalidShape(format!(
617            "Split end-to-end scores num_classes must be 1, got {}",
618            scores.shape()[0]
619        )));
620    }
621
622    if classes.shape()[0] != 1 {
623        return Err(crate::DecoderError::InvalidShape(format!(
624            "Split end-to-end classes num_classes must be 1, got {}",
625            classes.shape()[0]
626        )));
627    }
628
629    if scores.shape()[1] != num_boxes {
630        return Err(crate::DecoderError::InvalidShape(format!(
631            "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
632            num_boxes,
633            scores.shape()[1]
634        )));
635    }
636
637    if classes.shape()[1] != num_boxes {
638        return Err(crate::DecoderError::InvalidShape(format!(
639            "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
640            num_boxes,
641            classes.shape()[1]
642        )));
643    }
644
645    let boxes = boxes.reversed_axes();
646    let scores = scores.reversed_axes();
647    let classes = classes.slice(s![0, ..]);
648    Ok((boxes, scores, classes))
649}
650
651/// Postprocess yolo split end to end segdet by reversing axes of boxes,
652/// scores, mask tensors and flattening the class tensor.
653#[allow(clippy::type_complexity)]
654pub(crate) fn postprocess_yolo_split_end_to_end_segdet<
655    'a,
656    'b,
657    'c,
658    'd,
659    BOXES,
660    SCORES,
661    CLASS,
662    MASK,
663>(
664    boxes: ArrayView2<'a, BOXES>,
665    scores: ArrayView2<'b, SCORES>,
666    classes: &'c ArrayView2<CLASS>,
667    mask_coeff: ArrayView2<'d, MASK>,
668) -> Result<
669    (
670        ArrayView2<'a, BOXES>,
671        ArrayView2<'b, SCORES>,
672        ArrayView1<'c, CLASS>,
673        ArrayView2<'d, MASK>,
674    ),
675    crate::DecoderError,
676> {
677    let num_boxes = boxes.shape()[1];
678    if boxes.shape()[0] != 4 {
679        return Err(crate::DecoderError::InvalidShape(format!(
680            "Split end-to-end box_coords must be 4, got {}",
681            boxes.shape()[0]
682        )));
683    }
684
685    if scores.shape()[0] != 1 {
686        return Err(crate::DecoderError::InvalidShape(format!(
687            "Split end-to-end scores num_classes must be 1, got {}",
688            scores.shape()[0]
689        )));
690    }
691
692    if classes.shape()[0] != 1 {
693        return Err(crate::DecoderError::InvalidShape(format!(
694            "Split end-to-end classes num_classes must be 1, got {}",
695            classes.shape()[0]
696        )));
697    }
698
699    if scores.shape()[1] != num_boxes {
700        return Err(crate::DecoderError::InvalidShape(format!(
701            "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
702            num_boxes,
703            scores.shape()[1]
704        )));
705    }
706
707    if classes.shape()[1] != num_boxes {
708        return Err(crate::DecoderError::InvalidShape(format!(
709            "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
710            num_boxes,
711            classes.shape()[1]
712        )));
713    }
714
715    if mask_coeff.shape()[1] != num_boxes {
716        return Err(crate::DecoderError::InvalidShape(format!(
717            "Split end-to-end mask_coeff must have same num_boxes as boxes ({}), got {}",
718            num_boxes,
719            mask_coeff.shape()[1]
720        )));
721    }
722
723    let boxes = boxes.reversed_axes();
724    let scores = scores.reversed_axes();
725    let classes = classes.slice(s![0, ..]);
726    let mask_coeff = mask_coeff.reversed_axes();
727    Ok((boxes, scores, classes, mask_coeff))
728}
729/// Internal implementation of YOLO decoding for quantized tensors.
730///
731/// Expected shapes of inputs:
732/// - output: (4 + num_classes, num_boxes)
733pub(crate) fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
734    output: (ArrayView2<T>, Quantization),
735    score_threshold: f32,
736    iou_threshold: f32,
737    nms: Option<Nms>,
738    output_boxes: &mut Vec<DetectBox>,
739) where
740    f32: AsPrimitive<T>,
741{
742    let _span = tracing::trace_span!("decoder.decode.yolo_quant_flat").entered();
743    let (boxes, quant_boxes) = output;
744    let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
745
746    let boxes = {
747        let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
748        postprocess_boxes_quant::<B, _, _>(
749            score_threshold,
750            boxes_tensor,
751            scores_tensor,
752            quant_boxes,
753        )
754    };
755
756    let cap = cap_or_default(output_boxes);
757    let boxes = dispatch_nms_int(nms, iou_threshold, Some(cap), boxes);
758    // Detection cap convention (see `cap_or_default`). NMS already capped to
759    // `cap`; the `min` here is a redundant guard for non-NMS bypass mode.
760    let len = cap.min(boxes.len());
761    output_boxes.clear();
762    for b in boxes.iter().take(len) {
763        output_boxes.push(dequant_detect_box(b, quant_boxes));
764    }
765}
766
767/// Internal implementation of YOLO decoding for float tensors.
768///
769/// Expected shapes of inputs:
770/// - output: (4 + num_classes, num_boxes)
771pub(crate) fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
772    output: ArrayView2<T>,
773    score_threshold: f32,
774    iou_threshold: f32,
775    nms: Option<Nms>,
776    output_boxes: &mut Vec<DetectBox>,
777) where
778    f32: AsPrimitive<T>,
779{
780    let _span = tracing::trace_span!("decoder.decode.yolo_float_flat").entered();
781    let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
782    let boxes =
783        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
784    let cap = cap_or_default(output_boxes);
785    let boxes = dispatch_nms_float(nms, iou_threshold, Some(cap), boxes);
786    // Detection cap convention (see `cap_or_default`). NMS already capped to
787    // `cap`; the `min` here is a redundant guard for non-NMS bypass mode.
788    let len = cap.min(boxes.len());
789    output_boxes.clear();
790    for b in boxes.into_iter().take(len) {
791        output_boxes.push(b);
792    }
793}
794
795/// Internal implementation of YOLO split detection decoding for quantized
796/// tensors.
797///
798/// Expected shapes of inputs:
799/// - boxes: (4, num_boxes)
800/// - scores: (num_classes, num_boxes)
801///
802/// # Panics
803/// Panics if shapes don't match the expected dimensions.
804pub(crate) fn impl_yolo_split_quant<
805    B: BBoxTypeTrait,
806    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
807    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
808>(
809    boxes: (ArrayView2<BOX>, Quantization),
810    scores: (ArrayView2<SCORE>, Quantization),
811    score_threshold: f32,
812    iou_threshold: f32,
813    nms: Option<Nms>,
814    output_boxes: &mut Vec<DetectBox>,
815) where
816    f32: AsPrimitive<SCORE>,
817{
818    let _span = tracing::trace_span!("decoder.decode.yolo_quant_split").entered();
819    let (boxes_tensor, quant_boxes) = boxes;
820    let (scores_tensor, quant_scores) = scores;
821
822    let boxes_tensor = boxes_tensor.reversed_axes();
823    let scores_tensor = scores_tensor.reversed_axes();
824
825    let boxes = {
826        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
827        postprocess_boxes_quant::<B, _, _>(
828            score_threshold,
829            boxes_tensor,
830            scores_tensor,
831            quant_boxes,
832        )
833    };
834
835    let cap = cap_or_default(output_boxes);
836    let boxes = dispatch_nms_int(nms, iou_threshold, Some(cap), boxes);
837    // Detection cap convention (see `cap_or_default`). NMS already capped to
838    // `cap`; the `min` here is a redundant guard for non-NMS bypass mode.
839    let len = cap.min(boxes.len());
840    output_boxes.clear();
841    for b in boxes.iter().take(len) {
842        output_boxes.push(dequant_detect_box(b, quant_scores));
843    }
844}
845
846/// Internal implementation of YOLO split detection decoding for float tensors.
847///
848/// Expected shapes of inputs:
849/// - boxes: (4, num_boxes)
850/// - scores: (num_classes, num_boxes)
851///
852/// # Panics
853/// Panics if shapes don't match the expected dimensions.
854pub(crate) fn impl_yolo_split_float<
855    B: BBoxTypeTrait,
856    BOX: Float + AsPrimitive<f32> + Send + Sync,
857    SCORE: Float + AsPrimitive<f32> + Send + Sync,
858>(
859    boxes_tensor: ArrayView2<BOX>,
860    scores_tensor: ArrayView2<SCORE>,
861    score_threshold: f32,
862    iou_threshold: f32,
863    nms: Option<Nms>,
864    output_boxes: &mut Vec<DetectBox>,
865) where
866    f32: AsPrimitive<SCORE>,
867{
868    let _span = tracing::trace_span!("decoder.decode.yolo_float_split").entered();
869    let boxes_tensor = boxes_tensor.reversed_axes();
870    let scores_tensor = scores_tensor.reversed_axes();
871    let boxes =
872        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
873    let cap = cap_or_default(output_boxes);
874    let boxes = dispatch_nms_float(nms, iou_threshold, Some(cap), boxes);
875    // Detection cap convention (see `cap_or_default`). NMS already capped to
876    // `cap`; the `min` here is a redundant guard for non-NMS bypass mode.
877    let len = cap.min(boxes.len());
878    output_boxes.clear();
879    for b in boxes.into_iter().take(len) {
880        output_boxes.push(b);
881    }
882}
883
884/// Divide each survivor's bbox by `(input_w, input_h)` when the schema
885/// declares `normalized: false`. Pixel-space box coords from the model
886/// are pulled into the canonical `[0, 1]` range expected by `protobox`
887/// and downstream callers — see EDGEAI-1303.
888///
889/// No-op when `normalized` is `Some(true)` / `None`, when `input_dims`
890/// is `None`, or when there are no survivors.
891#[inline]
892pub(crate) fn maybe_normalize_boxes_in_place(
893    boxes: &mut [(DetectBox, usize)],
894    normalized: Option<bool>,
895    input_dims: Option<(usize, usize)>,
896) {
897    if normalized != Some(false) {
898        return;
899    }
900    let Some((w, h)) = input_dims else {
901        return;
902    };
903    if w == 0 || h == 0 {
904        return;
905    }
906    let inv_w = 1.0 / w as f32;
907    let inv_h = 1.0 / h as f32;
908    for (b, _) in boxes.iter_mut() {
909        b.bbox.xmin *= inv_w;
910        b.bbox.ymin *= inv_h;
911        b.bbox.xmax *= inv_w;
912        b.bbox.ymax *= inv_h;
913    }
914}
915
916/// Internal implementation of YOLO detection segmentation decoding for
917/// quantized tensors.
918///
919/// Expected shapes of inputs:
920/// - boxes: (4 + num_classes + num_protos, num_boxes)
921/// - protos: (proto_height, proto_width, num_protos)
922///
923/// # Errors
924/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
925#[allow(clippy::too_many_arguments)]
926pub(crate) fn impl_yolo_segdet_quant<
927    B: BBoxTypeTrait,
928    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
929    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
930>(
931    boxes: (ArrayView2<BOX>, Quantization),
932    protos: (ArrayView3<PROTO>, Quantization),
933    score_threshold: f32,
934    iou_threshold: f32,
935    nms: Option<Nms>,
936    pre_nms_top_k: usize,
937    max_det: usize,
938    normalized: Option<bool>,
939    input_dims: Option<(usize, usize)>,
940    output_boxes: &mut Vec<DetectBox>,
941    output_masks: &mut Vec<Segmentation>,
942) -> Result<(), crate::DecoderError>
943where
944    f32: AsPrimitive<BOX>,
945{
946    let (boxes, quant_boxes) = boxes;
947    let num_protos = protos.0.dim().2;
948
949    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
950    let mut boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
951        (boxes_tensor, quant_boxes),
952        (scores_tensor, quant_boxes),
953        score_threshold,
954        iou_threshold,
955        nms,
956        pre_nms_top_k,
957        max_det,
958    );
959    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
960
961    impl_yolo_split_segdet_quant_process_masks::<_, _>(
962        boxes,
963        (mask_tensor, quant_boxes),
964        protos,
965        output_boxes,
966        output_masks,
967    )
968}
969
970/// Internal implementation of YOLO detection segmentation decoding for
971/// float tensors.
972///
973/// Expected shapes of inputs:
974/// - boxes: (4 + num_classes + num_protos, num_boxes)
975/// - protos: (proto_height, proto_width, num_protos)
976///
977/// # Panics
978/// Panics if shapes don't match the expected dimensions.
979#[allow(clippy::too_many_arguments)]
980pub(crate) fn impl_yolo_segdet_float<
981    B: BBoxTypeTrait,
982    BOX: Float + AsPrimitive<f32> + Send + Sync,
983    PROTO: Float + AsPrimitive<f32> + Send + Sync,
984>(
985    boxes: ArrayView2<BOX>,
986    protos: ArrayView3<PROTO>,
987    score_threshold: f32,
988    iou_threshold: f32,
989    nms: Option<Nms>,
990    pre_nms_top_k: usize,
991    max_det: usize,
992    normalized: Option<bool>,
993    input_dims: Option<(usize, usize)>,
994    output_boxes: &mut Vec<DetectBox>,
995    output_masks: &mut Vec<Segmentation>,
996) -> Result<(), crate::DecoderError>
997where
998    f32: AsPrimitive<BOX>,
999{
1000    let num_protos = protos.dim().2;
1001    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1002    let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1003        boxes_tensor,
1004        scores_tensor,
1005        score_threshold,
1006        iou_threshold,
1007        nms,
1008        pre_nms_top_k,
1009        max_det,
1010    );
1011    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1012    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1013}
1014
1015pub(crate) fn impl_yolo_segdet_get_boxes<
1016    B: BBoxTypeTrait,
1017    BOX: Float + AsPrimitive<f32> + Send + Sync,
1018    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1019>(
1020    boxes_tensor: ArrayView2<BOX>,
1021    scores_tensor: ArrayView2<SCORE>,
1022    score_threshold: f32,
1023    iou_threshold: f32,
1024    nms: Option<Nms>,
1025    pre_nms_top_k: usize,
1026    max_det: usize,
1027) -> Vec<(DetectBox, usize)>
1028where
1029    f32: AsPrimitive<SCORE>,
1030{
1031    let span = tracing::trace_span!(
1032        "decoder.nms_get_boxes",
1033        n_candidates = tracing::field::Empty,
1034        n_after_topk = tracing::field::Empty,
1035        n_after_nms = tracing::field::Empty,
1036        n_detections = tracing::field::Empty,
1037    );
1038    let _guard = span.enter();
1039
1040    let mut boxes = {
1041        let _s = tracing::trace_span!("decoder.nms_get_boxes.score_filter").entered();
1042        postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor)
1043    };
1044    span.record("n_candidates", boxes.len());
1045
1046    if nms.is_some() {
1047        let _s = tracing::trace_span!("decoder.nms_get_boxes.top_k", k = pre_nms_top_k).entered();
1048        truncate_to_top_k_by_score(&mut boxes, pre_nms_top_k);
1049    }
1050    span.record("n_after_topk", boxes.len());
1051
1052    let mut boxes = {
1053        let _s = tracing::trace_span!("decoder.nms_get_boxes.suppress").entered();
1054        dispatch_nms_extra_float(nms, iou_threshold, Some(max_det), boxes)
1055    };
1056    span.record("n_after_nms", boxes.len());
1057
1058    // NMS already capped to `max_det`; the trailing sort+truncate is a
1059    // redundant guard for the bypass-NMS path (`nms = None`).
1060    boxes.sort_unstable_by(|a, b| b.0.score.total_cmp(&a.0.score));
1061    boxes.truncate(max_det);
1062    span.record("n_detections", boxes.len());
1063
1064    boxes
1065}
1066
1067pub(crate) fn impl_yolo_end_to_end_segdet_get_boxes<
1068    B: BBoxTypeTrait,
1069    BOX: Float + AsPrimitive<f32> + Send + Sync,
1070    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1071    CLASS: AsPrimitive<f32> + Send + Sync,
1072>(
1073    boxes: ArrayView2<BOX>,
1074    scores: ArrayView2<SCORE>,
1075    classes: ArrayView1<CLASS>,
1076    score_threshold: f32,
1077    max_boxes: usize,
1078) -> Vec<(DetectBox, usize)>
1079where
1080    f32: AsPrimitive<SCORE>,
1081{
1082    let mut boxes = postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes, scores);
1083    boxes.truncate(max_boxes);
1084    for (b, ind) in &mut boxes {
1085        b.label = classes[*ind].as_().round() as usize;
1086    }
1087    boxes
1088}
1089
1090pub(crate) fn impl_yolo_split_segdet_process_masks<
1091    MASK: Float + AsPrimitive<f32> + Send + Sync,
1092    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1093>(
1094    boxes: Vec<(DetectBox, usize)>,
1095    masks_tensor: ArrayView2<MASK>,
1096    protos_tensor: ArrayView3<PROTO>,
1097    output_boxes: &mut Vec<DetectBox>,
1098    output_masks: &mut Vec<Segmentation>,
1099) -> Result<(), crate::DecoderError> {
1100    let _span = tracing::trace_span!(
1101        "decoder.decode.process_masks",
1102        n = boxes.len(),
1103        mode = "float"
1104    )
1105    .entered();
1106    // Boxes are already bounded by the upstream `max_det` cap from
1107    // `_get_boxes`; no second cap is needed here (EDGEAI-1302).
1108
1109    let boxes = decode_segdet_f32(boxes, masks_tensor, protos_tensor)?;
1110    output_boxes.clear();
1111    output_masks.clear();
1112    for (b, roi, m) in boxes.into_iter() {
1113        output_boxes.push(b);
1114        output_masks.push(Segmentation {
1115            xmin: roi.xmin,
1116            ymin: roi.ymin,
1117            xmax: roi.xmax,
1118            ymax: roi.ymax,
1119            segmentation: m,
1120        });
1121    }
1122    Ok(())
1123}
1124/// Expected input shapes:
1125/// - boxes_tensor: (num_boxes, 4)
1126/// - scores_tensor: (num_boxes, num_classes)
1127pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
1128    B: BBoxTypeTrait,
1129    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1130    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1131>(
1132    boxes: (ArrayView2<BOX>, Quantization),
1133    scores: (ArrayView2<SCORE>, Quantization),
1134    score_threshold: f32,
1135    iou_threshold: f32,
1136    nms: Option<Nms>,
1137    pre_nms_top_k: usize,
1138    max_det: usize,
1139) -> Vec<(DetectBox, usize)>
1140where
1141    f32: AsPrimitive<SCORE>,
1142{
1143    let (boxes_tensor, quant_boxes) = boxes;
1144    let (scores_tensor, quant_scores) = scores;
1145
1146    let span = tracing::trace_span!(
1147        "decoder.nms_get_boxes",
1148        n_candidates = tracing::field::Empty,
1149        n_after_topk = tracing::field::Empty,
1150        n_after_nms = tracing::field::Empty,
1151        n_detections = tracing::field::Empty,
1152    );
1153    let _guard = span.enter();
1154
1155    let mut boxes = {
1156        let _s = tracing::trace_span!("decoder.nms_get_boxes.score_filter").entered();
1157        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
1158        postprocess_boxes_index_quant::<B, _, _>(
1159            score_threshold,
1160            boxes_tensor,
1161            scores_tensor,
1162            quant_boxes,
1163        )
1164    };
1165    span.record("n_candidates", boxes.len());
1166
1167    if nms.is_some() {
1168        let _s = tracing::trace_span!("decoder.nms_get_boxes.top_k", k = pre_nms_top_k).entered();
1169        truncate_to_top_k_by_score_quant(&mut boxes, pre_nms_top_k);
1170    }
1171    span.record("n_after_topk", boxes.len());
1172
1173    let mut boxes = {
1174        let _s = tracing::trace_span!("decoder.nms_get_boxes.suppress").entered();
1175        dispatch_nms_extra_int(nms, iou_threshold, Some(max_det), boxes)
1176    };
1177    span.record("n_after_nms", boxes.len());
1178
1179    // NMS already capped to `max_det`; the trailing sort+truncate is a
1180    // redundant guard for the bypass-NMS path (`nms = None`).
1181    boxes.sort_unstable_by_key(|b| std::cmp::Reverse(b.0.score));
1182    boxes.truncate(max_det);
1183    let result: Vec<_> = {
1184        let _s =
1185            tracing::trace_span!("decoder.nms_get_boxes.dequant_boxes", n = boxes.len()).entered();
1186        boxes
1187            .into_iter()
1188            .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
1189            .collect()
1190    };
1191    span.record("n_detections", result.len());
1192
1193    result
1194}
1195
1196pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
1197    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1198    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1199>(
1200    boxes: Vec<(DetectBox, usize)>,
1201    mask_coeff: (ArrayView2<MASK>, Quantization),
1202    protos: (ArrayView3<PROTO>, Quantization),
1203    output_boxes: &mut Vec<DetectBox>,
1204    output_masks: &mut Vec<Segmentation>,
1205) -> Result<(), crate::DecoderError> {
1206    let _span = tracing::trace_span!(
1207        "decoder.decode.process_masks",
1208        n = boxes.len(),
1209        mode = "quant"
1210    )
1211    .entered();
1212    let (masks, quant_masks) = mask_coeff;
1213    let (protos, quant_protos) = protos;
1214
1215    // Boxes are already bounded by the upstream `max_det` cap from
1216    // `_get_boxes`; no second cap is needed here (EDGEAI-1302).
1217
1218    let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
1219    output_boxes.clear();
1220    output_masks.clear();
1221    for (b, roi, m) in boxes.into_iter() {
1222        output_boxes.push(b);
1223        output_masks.push(Segmentation {
1224            xmin: roi.xmin,
1225            ymin: roi.ymin,
1226            xmax: roi.xmax,
1227            ymax: roi.ymax,
1228            segmentation: m,
1229        });
1230    }
1231    Ok(())
1232}
1233
1234/// Internal implementation of YOLO split detection segmentation decoding for
1235/// float tensors.
1236///
1237/// Expected shapes of inputs:
1238/// - boxes_tensor: (4, num_boxes)
1239/// - scores_tensor: (num_classes, num_boxes)
1240/// - mask_tensor: (num_protos, num_boxes)
1241/// - protos: (proto_height, proto_width, num_protos)
1242///
1243/// # Errors
1244/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
1245#[allow(clippy::too_many_arguments)]
1246pub(crate) fn impl_yolo_split_segdet_float<
1247    B: BBoxTypeTrait,
1248    BOX: Float + AsPrimitive<f32> + Send + Sync,
1249    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1250    MASK: Float + AsPrimitive<f32> + Send + Sync,
1251    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1252>(
1253    boxes_tensor: ArrayView2<BOX>,
1254    scores_tensor: ArrayView2<SCORE>,
1255    mask_tensor: ArrayView2<MASK>,
1256    protos: ArrayView3<PROTO>,
1257    score_threshold: f32,
1258    iou_threshold: f32,
1259    nms: Option<Nms>,
1260    pre_nms_top_k: usize,
1261    max_det: usize,
1262    normalized: Option<bool>,
1263    input_dims: Option<(usize, usize)>,
1264    output_boxes: &mut Vec<DetectBox>,
1265    output_masks: &mut Vec<Segmentation>,
1266) -> Result<(), crate::DecoderError>
1267where
1268    f32: AsPrimitive<SCORE>,
1269{
1270    let (boxes_tensor, scores_tensor, mask_tensor) =
1271        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1272
1273    let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1274        boxes_tensor,
1275        scores_tensor,
1276        score_threshold,
1277        iou_threshold,
1278        nms,
1279        pre_nms_top_k,
1280        max_det,
1281    );
1282    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1283    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1284}
1285
1286// ---------------------------------------------------------------------------
1287// Proto-extraction variants: return ProtoData instead of materialized masks
1288// ---------------------------------------------------------------------------
1289
1290/// Proto-extraction variant of `impl_yolo_segdet_quant`.
1291/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1292#[allow(clippy::too_many_arguments)]
1293pub(crate) fn impl_yolo_segdet_quant_proto<
1294    B: BBoxTypeTrait,
1295    BOX: PrimInt
1296        + AsPrimitive<i64>
1297        + AsPrimitive<i128>
1298        + AsPrimitive<f32>
1299        + AsPrimitive<i8>
1300        + Send
1301        + Sync,
1302    PROTO: PrimInt
1303        + AsPrimitive<i64>
1304        + AsPrimitive<i128>
1305        + AsPrimitive<f32>
1306        + AsPrimitive<i8>
1307        + Send
1308        + Sync,
1309>(
1310    boxes: (ArrayView2<BOX>, Quantization),
1311    protos: (ArrayView3<PROTO>, Quantization),
1312    score_threshold: f32,
1313    iou_threshold: f32,
1314    nms: Option<Nms>,
1315    pre_nms_top_k: usize,
1316    max_det: usize,
1317    normalized: Option<bool>,
1318    input_dims: Option<(usize, usize)>,
1319    output_boxes: &mut Vec<DetectBox>,
1320) -> ProtoData
1321where
1322    f32: AsPrimitive<BOX>,
1323{
1324    let (boxes_arr, quant_boxes) = boxes;
1325    let (protos_arr, quant_protos) = protos;
1326    let num_protos = protos_arr.dim().2;
1327
1328    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1329
1330    let mut det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1331        (boxes_tensor, quant_boxes),
1332        (scores_tensor, quant_boxes),
1333        score_threshold,
1334        iou_threshold,
1335        nms,
1336        pre_nms_top_k,
1337        max_det,
1338    );
1339    maybe_normalize_boxes_in_place(&mut det_indices, normalized, input_dims);
1340
1341    extract_proto_data_quant(
1342        det_indices,
1343        mask_tensor,
1344        quant_boxes,
1345        protos_arr,
1346        quant_protos,
1347        output_boxes,
1348    )
1349}
1350
1351/// Proto-extraction variant of `impl_yolo_segdet_float`.
1352/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1353#[allow(clippy::too_many_arguments)]
1354pub(crate) fn impl_yolo_segdet_float_proto<
1355    B: BBoxTypeTrait,
1356    BOX: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1357    PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1358>(
1359    boxes: ArrayView2<BOX>,
1360    protos: ArrayView3<PROTO>,
1361    score_threshold: f32,
1362    iou_threshold: f32,
1363    nms: Option<Nms>,
1364    pre_nms_top_k: usize,
1365    max_det: usize,
1366    normalized: Option<bool>,
1367    input_dims: Option<(usize, usize)>,
1368    output_boxes: &mut Vec<DetectBox>,
1369) -> ProtoData
1370where
1371    f32: AsPrimitive<BOX>,
1372{
1373    let num_protos = protos.dim().2;
1374    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1375
1376    let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1377        boxes_tensor,
1378        scores_tensor,
1379        score_threshold,
1380        iou_threshold,
1381        nms,
1382        pre_nms_top_k,
1383        max_det,
1384    );
1385    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1386
1387    extract_proto_data_float(boxes, mask_tensor, protos, output_boxes)
1388}
1389
1390/// Proto-extraction variant of `impl_yolo_split_segdet_float`.
1391/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1392#[allow(clippy::too_many_arguments)]
1393pub(crate) fn impl_yolo_split_segdet_float_proto<
1394    B: BBoxTypeTrait,
1395    BOX: Float + AsPrimitive<f32> + Send + Sync,
1396    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1397    MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1398    PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1399>(
1400    boxes_tensor: ArrayView2<BOX>,
1401    scores_tensor: ArrayView2<SCORE>,
1402    mask_tensor: ArrayView2<MASK>,
1403    protos: ArrayView3<PROTO>,
1404    score_threshold: f32,
1405    iou_threshold: f32,
1406    nms: Option<Nms>,
1407    pre_nms_top_k: usize,
1408    max_det: usize,
1409    normalized: Option<bool>,
1410    input_dims: Option<(usize, usize)>,
1411    output_boxes: &mut Vec<DetectBox>,
1412) -> ProtoData
1413where
1414    f32: AsPrimitive<SCORE>,
1415{
1416    let (boxes_tensor, scores_tensor, mask_tensor) =
1417        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1418    let mut det_indices = impl_yolo_segdet_get_boxes::<B, _, _>(
1419        boxes_tensor,
1420        scores_tensor,
1421        score_threshold,
1422        iou_threshold,
1423        nms,
1424        pre_nms_top_k,
1425        max_det,
1426    );
1427    maybe_normalize_boxes_in_place(&mut det_indices, normalized, input_dims);
1428
1429    extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1430}
1431
1432/// Proto-extraction variant of `decode_yolo_end_to_end_segdet_float`.
1433pub(crate) fn decode_yolo_end_to_end_segdet_float_proto<T>(
1434    output: ArrayView2<T>,
1435    protos: ArrayView3<T>,
1436    score_threshold: f32,
1437    output_boxes: &mut Vec<DetectBox>,
1438) -> Result<ProtoData, crate::DecoderError>
1439where
1440    T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1441    f32: AsPrimitive<T>,
1442{
1443    let (boxes, scores, classes, mask_coeff) =
1444        postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
1445    let cap = cap_or_default(output_boxes);
1446    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1447        boxes,
1448        scores,
1449        classes,
1450        score_threshold,
1451        cap,
1452    );
1453
1454    Ok(extract_proto_data_float(
1455        boxes,
1456        mask_coeff,
1457        protos,
1458        output_boxes,
1459    ))
1460}
1461
1462/// Proto-extraction variant of `decode_yolo_split_end_to_end_segdet_float`.
1463#[allow(clippy::too_many_arguments)]
1464pub(crate) fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1465    boxes: ArrayView2<T>,
1466    scores: ArrayView2<T>,
1467    classes: ArrayView2<T>,
1468    mask_coeff: ArrayView2<T>,
1469    protos: ArrayView3<T>,
1470    score_threshold: f32,
1471    output_boxes: &mut Vec<DetectBox>,
1472) -> Result<ProtoData, crate::DecoderError>
1473where
1474    T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1475    f32: AsPrimitive<T>,
1476{
1477    let (boxes, scores, classes, mask_coeff) =
1478        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
1479    let cap = cap_or_default(output_boxes);
1480    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1481        boxes,
1482        scores,
1483        classes,
1484        score_threshold,
1485        cap,
1486    );
1487
1488    Ok(extract_proto_data_float(
1489        boxes,
1490        mask_coeff,
1491        protos,
1492        output_boxes,
1493    ))
1494}
1495
1496/// Helper: extract ProtoData from float mask coefficients + protos.
1497///
1498/// Builds [`ProtoData`] with both `protos` and `mask_coefficients` as
1499/// [`edgefirst_tensor::TensorDyn`]. Preserves the native element type for
1500/// `f16` and `f32`; narrows `f64` to `f32` (there is no native f64 kernel
1501/// path). `mask_coefficients` shape is `[num_detections, num_protos]`.
1502pub(super) fn extract_proto_data_float<
1503    MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1504    PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1505>(
1506    det_indices: Vec<(DetectBox, usize)>,
1507    mask_tensor: ArrayView2<MASK>,
1508    protos: ArrayView3<PROTO>,
1509    output_boxes: &mut Vec<DetectBox>,
1510) -> ProtoData {
1511    let _span = tracing::trace_span!(
1512        "decoder.decode_proto.extract_proto_data",
1513        mode = "float",
1514        n = det_indices.len(),
1515        num_protos = mask_tensor.ncols(),
1516        layout = "nhwc",
1517    )
1518    .entered();
1519
1520    let num_protos = mask_tensor.ncols();
1521    let n = det_indices.len();
1522
1523    // Per-detection coefficients packed row-major into a contiguous buffer,
1524    // preserving the source dtype. Shape: [N, num_protos] — N=0 is permitted
1525    // (tracker path emits no detections this frame) since Mem-backed tensors
1526    // accept zero-element shapes as "empty collection" sentinels.
1527    let mut coeff_rows: Vec<MASK> = Vec::with_capacity(n * num_protos);
1528    output_boxes.clear();
1529    for (det, idx) in det_indices {
1530        output_boxes.push(det);
1531        let row = mask_tensor.row(idx);
1532        coeff_rows.extend(row.iter().copied());
1533    }
1534
1535    let mask_coefficients = MASK::slice_into_tensor_dyn(&coeff_rows, &[n, num_protos])
1536        .expect("allocating mask_coefficients TensorDyn");
1537    let protos_tensor =
1538        PROTO::arrayview3_into_tensor_dyn(protos).expect("allocating protos TensorDyn");
1539
1540    ProtoData {
1541        mask_coefficients,
1542        protos: protos_tensor,
1543        layout: ProtoLayout::Nhwc,
1544    }
1545}
1546
1547/// Helper: extract ProtoData from quantized mask coefficients + protos.
1548///
1549/// Dequantizes mask coefficients to f32 at extraction (one-time cost on a
1550/// `num_detections * num_protos` slice) and keeps protos in raw i8,
1551/// attaching the dequantization params as
1552/// [`edgefirst_tensor::Quantization::per_tensor`] metadata on the proto
1553/// tensor. The GPU shader / CPU kernel reads `protos.quantization()` and
1554/// dequantizes per-texel.
1555pub(crate) fn extract_proto_data_quant<
1556    MASK: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
1557    PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
1558>(
1559    det_indices: Vec<(DetectBox, usize)>,
1560    mask_tensor: ArrayView2<MASK>,
1561    quant_masks: Quantization,
1562    protos: ArrayView3<PROTO>,
1563    quant_protos: Quantization,
1564    output_boxes: &mut Vec<DetectBox>,
1565) -> ProtoData {
1566    use edgefirst_tensor::{Tensor, TensorDyn, TensorMapTrait, TensorMemory, TensorTrait};
1567
1568    let span = tracing::trace_span!(
1569        "decoder.decode_proto.extract_proto_data",
1570        mode = "quant",
1571        n = det_indices.len(),
1572        num_protos = tracing::field::Empty,
1573        layout = tracing::field::Empty,
1574    );
1575    let _guard = span.enter();
1576
1577    let num_protos = mask_tensor.ncols();
1578    let n = det_indices.len();
1579    span.record("num_protos", num_protos);
1580
1581    // Fast path: when no detections survive NMS, skip the expensive proto
1582    // tensor copy (819KB for 160×160×32). Allocate with the correct shape
1583    // (preserving the documented ProtoData.protos shape contract) but skip
1584    // copying from the source tensor — the zeroed allocation is sufficient
1585    // since materialize_masks early-returns on empty detect slices.
1586    if n == 0 {
1587        output_boxes.clear();
1588        let (h, w, k) = protos.dim();
1589
1590        // Detect physical layout (same logic as the normal path).
1591        let (proto_shape, proto_layout) = if std::any::TypeId::of::<PROTO>()
1592            == std::any::TypeId::of::<i8>()
1593        {
1594            if protos.is_standard_layout() {
1595                (&[h, w, k][..], ProtoLayout::Nhwc)
1596            } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1597                (&[k, h, w][..], ProtoLayout::Nchw)
1598            } else {
1599                (&[h, w, k][..], ProtoLayout::Nhwc)
1600            }
1601        } else {
1602            (&[h, w, k][..], ProtoLayout::Nhwc)
1603        };
1604
1605        let coeff_tensor = Tensor::<i8>::new(&[0, num_protos], Some(TensorMemory::Mem), None)
1606            .expect("allocating empty mask_coefficients tensor");
1607        let coeff_quant =
1608            edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1609        let coeff_tensor = coeff_tensor
1610            .with_quantization(coeff_quant)
1611            .expect("per-tensor quantization on mask coefficients");
1612        let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
1613            .expect("allocating protos tensor");
1614        let tensor_quant =
1615            edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1616        let protos_tensor = protos_tensor
1617            .with_quantization(tensor_quant)
1618            .expect("per-tensor quantization on protos tensor");
1619        return ProtoData {
1620            mask_coefficients: TensorDyn::I8(coeff_tensor),
1621            protos: TensorDyn::I8(protos_tensor),
1622            layout: proto_layout,
1623        };
1624    }
1625
1626    // Mask coefficients: keep i8 losslessly when MASK == i8 (preserves
1627    // the fast i8×i8→i32 integer kernel in materialize_masks). Preserve
1628    // i16 natively so the downstream i16×i8 integer path can avoid lossy
1629    // truncation. Other wider types (u16, …) dequantize to f32 at
1630    // extraction because the downstream mask kernels accept F32 natively.
1631    let mask_coefficients: TensorDyn = if std::any::TypeId::of::<MASK>()
1632        == std::any::TypeId::of::<i8>()
1633    {
1634        let mut coeff_i8 = Vec::<i8>::with_capacity(n * num_protos);
1635        output_boxes.clear();
1636        for (det, idx) in det_indices {
1637            output_boxes.push(det);
1638            let row = mask_tensor.row(idx);
1639            coeff_i8.extend(row.iter().map(|v| {
1640                let v_i8: i8 = v.as_();
1641                v_i8
1642            }));
1643        }
1644        let coeff_tensor = Tensor::<i8>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
1645            .expect("allocating mask_coefficients tensor");
1646        if n > 0 {
1647            let mut m = coeff_tensor
1648                .map()
1649                .expect("mapping mask_coefficients tensor");
1650            m.as_mut_slice().copy_from_slice(&coeff_i8);
1651        }
1652        let coeff_quant =
1653            edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1654        let coeff_tensor = coeff_tensor
1655            .with_quantization(coeff_quant)
1656            .expect("per-tensor quantization on mask coefficients");
1657        TensorDyn::I8(coeff_tensor)
1658    } else if std::any::TypeId::of::<MASK>() == std::any::TypeId::of::<i16>() {
1659        // i16 path: preserve natively for the fast i16×i8→i32 integer kernel.
1660        // f32 has 24-bit mantissa, so all i16 values are exactly representable.
1661        let mut coeff_i16 = Vec::<i16>::with_capacity(n * num_protos);
1662        output_boxes.clear();
1663        for (det, idx) in det_indices {
1664            output_boxes.push(det);
1665            let row = mask_tensor.row(idx);
1666            coeff_i16.extend(row.iter().map(|v| {
1667                let v_f32: f32 = v.as_();
1668                v_f32 as i16
1669            }));
1670        }
1671        let coeff_tensor = Tensor::<i16>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
1672            .expect("allocating mask_coefficients tensor");
1673        if n > 0 {
1674            let mut m = coeff_tensor
1675                .map()
1676                .expect("mapping mask_coefficients tensor");
1677            m.as_mut_slice().copy_from_slice(&coeff_i16);
1678        }
1679        let coeff_quant =
1680            edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1681        let coeff_tensor = coeff_tensor
1682            .with_quantization(coeff_quant)
1683            .expect("per-tensor quantization on mask coefficients");
1684        TensorDyn::I16(coeff_tensor)
1685    } else {
1686        // Other types (u8, u16, etc.): dequantize to f32 to avoid lossy truncation.
1687        let scale = quant_masks.scale;
1688        let zp = quant_masks.zero_point as f32;
1689        let mut coeff_f32 = Vec::<f32>::with_capacity(n * num_protos);
1690        output_boxes.clear();
1691        for (det, idx) in det_indices {
1692            output_boxes.push(det);
1693            let row = mask_tensor.row(idx);
1694            coeff_f32.extend(row.iter().map(|v| {
1695                let v_f32: f32 = v.as_();
1696                (v_f32 - zp) * scale
1697            }));
1698        }
1699        let coeff_tensor = Tensor::<f32>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
1700            .expect("allocating mask_coefficients tensor");
1701        if n > 0 {
1702            let mut m = coeff_tensor
1703                .map()
1704                .expect("mapping mask_coefficients tensor");
1705            m.as_mut_slice().copy_from_slice(&coeff_f32);
1706        }
1707        TensorDyn::F32(coeff_tensor)
1708    };
1709
1710    // Keep protos in raw i8 — consumers dequantize via protos.quantization().
1711    // When PROTO is already i8, detect layout and copy efficiently without
1712    // transposing. The mask materialisation kernels dispatch on the layout.
1713    let (h, w, k) = protos.dim();
1714
1715    // Determine physical layout and copy strategy.
1716    let (proto_shape, proto_layout) =
1717        if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1718            if protos.is_standard_layout() {
1719                // Already NHWC [H, W, K] in contiguous memory.
1720                (&[h, w, k][..], ProtoLayout::Nhwc)
1721            } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1722                // NCHW reinterpreted as NHWC via stride swap. Physical storage
1723                // is [K, H, W] contiguous. Keep in NCHW — eliminates the costly
1724                // 3.1ms transpose entirely.
1725                (&[k, h, w][..], ProtoLayout::Nchw)
1726            } else {
1727                // Unknown layout — fall back to iter copy as NHWC.
1728                (&[h, w, k][..], ProtoLayout::Nhwc)
1729            }
1730        } else {
1731            (&[h, w, k][..], ProtoLayout::Nhwc)
1732        };
1733
1734    let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
1735        .expect("allocating protos tensor");
1736    {
1737        let mut m = protos_tensor.map().expect("mapping protos tensor");
1738        let dst = m.as_mut_slice();
1739        if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1740            // SAFETY: PROTO == i8 checked via TypeId; cast slice view is
1741            // size/alignment-compatible by construction.
1742            if protos.is_standard_layout() {
1743                let src: &[i8] = unsafe {
1744                    std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len())
1745                };
1746                dst.copy_from_slice(src);
1747            } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1748                // NCHW physical layout — sequential copy WITHOUT transpose.
1749                // This saves ~3.1ms on A53/A55 by avoiding the tiled
1750                // NCHW→NHWC transpose of the 819KB proto buffer.
1751                let total = h * w * k;
1752                // SAFETY: ArrayView was constructed from a contiguous slice of
1753                // `total` elements. as_ptr() points to the base of that slice.
1754                let src: &[i8] =
1755                    unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, total) };
1756                dst.copy_from_slice(src);
1757            } else {
1758                for (d, s) in dst.iter_mut().zip(protos.iter()) {
1759                    let v_i8: i8 = s.as_();
1760                    *d = v_i8;
1761                }
1762            }
1763        } else {
1764            for (d, s) in dst.iter_mut().zip(protos.iter()) {
1765                let v_i8: i8 = s.as_();
1766                *d = v_i8;
1767            }
1768        }
1769    }
1770    let tensor_quant =
1771        edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1772    let protos_tensor = protos_tensor
1773        .with_quantization(tensor_quant)
1774        .expect("per-tensor quantization on new Tensor<i8>");
1775
1776    span.record("layout", tracing::field::debug(&proto_layout));
1777
1778    ProtoData {
1779        mask_coefficients,
1780        protos: TensorDyn::I8(protos_tensor),
1781        layout: proto_layout,
1782    }
1783}
1784
/// Per-float-dtype construction of a [`TensorDyn`] from a flat slice / 3-D
/// `ArrayView`. Replaces the old `IntoProtoTensor` trait. Each implementor
/// either passes its element type straight to `Tensor::from_slice` /
/// `Tensor::from_arrayview3`, or narrows `f64` to `f32` (no native f64 kernel
/// path exists).
pub trait FloatProtoElem: Copy + 'static {
    /// Builds a tensor from the flat `values` slice, interpreted with the
    /// given `shape`.
    ///
    /// # Errors
    ///
    /// The implementations in this file forward whatever error the
    /// underlying `Tensor::from_slice` call reports.
    fn slice_into_tensor_dyn(
        values: &[Self],
        shape: &[usize],
    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;

    /// Builds a tensor from a 3-D array view.
    ///
    /// # Errors
    ///
    /// The implementations in this file forward whatever error the
    /// underlying `Tensor::from_arrayview3` call reports.
    fn arrayview3_into_tensor_dyn(
        view: ArrayView3<'_, Self>,
    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;
}
1800
1801impl FloatProtoElem for f32 {
1802    fn slice_into_tensor_dyn(
1803        values: &[f32],
1804        shape: &[usize],
1805    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1806        edgefirst_tensor::Tensor::<f32>::from_slice(values, shape)
1807            .map(edgefirst_tensor::TensorDyn::F32)
1808    }
1809    fn arrayview3_into_tensor_dyn(
1810        view: ArrayView3<'_, f32>,
1811    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1812        edgefirst_tensor::Tensor::<f32>::from_arrayview3(view).map(edgefirst_tensor::TensorDyn::F32)
1813    }
1814}
1815
1816impl FloatProtoElem for half::f16 {
1817    fn slice_into_tensor_dyn(
1818        values: &[half::f16],
1819        shape: &[usize],
1820    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1821        edgefirst_tensor::Tensor::<half::f16>::from_slice(values, shape)
1822            .map(edgefirst_tensor::TensorDyn::F16)
1823    }
1824    fn arrayview3_into_tensor_dyn(
1825        view: ArrayView3<'_, half::f16>,
1826    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1827        edgefirst_tensor::Tensor::<half::f16>::from_arrayview3(view)
1828            .map(edgefirst_tensor::TensorDyn::F16)
1829    }
1830}
1831
1832impl FloatProtoElem for f64 {
1833    fn slice_into_tensor_dyn(
1834        values: &[f64],
1835        shape: &[usize],
1836    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1837        // Narrow to f32 — no native f64 kernel path.
1838        let narrowed: Vec<f32> = values.iter().map(|&v| v as f32).collect();
1839        edgefirst_tensor::Tensor::<f32>::from_slice(&narrowed, shape)
1840            .map(edgefirst_tensor::TensorDyn::F32)
1841    }
1842    fn arrayview3_into_tensor_dyn(
1843        view: ArrayView3<'_, f64>,
1844    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1845        let narrowed: ndarray::Array3<f32> = view.mapv(|v| v as f32);
1846        edgefirst_tensor::Tensor::<f32>::from_arrayview3(narrowed.view())
1847            .map(edgefirst_tensor::TensorDyn::F32)
1848    }
1849}
1850
1851fn postprocess_yolo<'a, T>(
1852    output: &'a ArrayView2<'_, T>,
1853) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1854    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1855    let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1856    (boxes_tensor, scores_tensor)
1857}
1858
1859pub(crate) fn postprocess_yolo_seg<'a, T>(
1860    output: &'a ArrayView2<'_, T>,
1861    num_protos: usize,
1862) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1863    assert!(
1864        output.shape()[0] > num_protos + 4,
1865        "Output shape is too short: {} <= {} + 4",
1866        output.shape()[0],
1867        num_protos
1868    );
1869    let num_classes = output.shape()[0] - 4 - num_protos;
1870    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1871    let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1872    let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1873    (boxes_tensor, scores_tensor, mask_tensor)
1874}
1875
1876pub(crate) fn postprocess_yolo_split_segdet<'a, 'b, 'c, BOX, SCORE, MASK>(
1877    boxes_tensor: ArrayView2<'a, BOX>,
1878    scores_tensor: ArrayView2<'b, SCORE>,
1879    mask_tensor: ArrayView2<'c, MASK>,
1880) -> (
1881    ArrayView2<'a, BOX>,
1882    ArrayView2<'b, SCORE>,
1883    ArrayView2<'c, MASK>,
1884) {
1885    let boxes_tensor = boxes_tensor.reversed_axes();
1886    let scores_tensor = scores_tensor.reversed_axes();
1887    let mask_tensor = mask_tensor.reversed_axes();
1888    (boxes_tensor, scores_tensor, mask_tensor)
1889}
1890
1891fn decode_segdet_f32<
1892    MASK: Float + AsPrimitive<f32> + Send + Sync,
1893    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1894>(
1895    boxes: Vec<(DetectBox, usize)>,
1896    masks: ArrayView2<MASK>,
1897    protos: ArrayView3<PROTO>,
1898) -> Result<Vec<(DetectBox, BoundingBox, Array3<u8>)>, crate::DecoderError> {
1899    if boxes.is_empty() {
1900        return Ok(Vec::new());
1901    }
1902    if masks.shape()[1] != protos.shape()[2] {
1903        return Err(crate::DecoderError::InvalidShape(format!(
1904            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1905            masks.shape()[1],
1906            protos.shape()[2],
1907        )));
1908    }
1909    boxes
1910        .into_par_iter()
1911        .map(|b| {
1912            let ind = b.1;
1913            // `protobox` returns the cropped proto slice for `make_segmentation`
1914            // and a `roi` snapped to the 1/proto-grid step. The detection bbox
1915            // stays untouched (EDGEAI-1304); the snapped roi is reported back
1916            // separately so callers can describe where the cropped mask lives.
1917            let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1918            Ok((b.0, roi, make_segmentation(masks.row(ind), protos.view())))
1919        })
1920        .collect()
1921}
1922
1923pub(crate) fn decode_segdet_quant<
1924    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1925    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1926>(
1927    boxes: Vec<(DetectBox, usize)>,
1928    masks: ArrayView2<MASK>,
1929    protos: ArrayView3<PROTO>,
1930    quant_masks: Quantization,
1931    quant_protos: Quantization,
1932) -> Result<Vec<(DetectBox, BoundingBox, Array3<u8>)>, crate::DecoderError> {
1933    if boxes.is_empty() {
1934        return Ok(Vec::new());
1935    }
1936    if masks.shape()[1] != protos.shape()[2] {
1937        return Err(crate::DecoderError::InvalidShape(format!(
1938            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1939            masks.shape()[1],
1940            protos.shape()[2],
1941        )));
1942    }
1943
1944    let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; // 32 protos is 2^5
1945    boxes
1946        .into_iter()
1947        .map(|b| {
1948            let i = b.1;
1949            // See EDGEAI-1304: the caller's bbox stays untouched; the
1950            // proto-grid-snapped `roi` is reported back so the Segmentation's
1951            // bounds can describe the actual cropped mask region.
1952            let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1953            let seg = match total_bits {
1954                0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1955                    masks.row(i),
1956                    protos.view(),
1957                    quant_masks,
1958                    quant_protos,
1959                ),
1960                65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1961                    masks.row(i),
1962                    protos.view(),
1963                    quant_masks,
1964                    quant_protos,
1965                ),
1966                _ => {
1967                    return Err(crate::DecoderError::NotSupported(format!(
1968                        "Unsupported bit width ({total_bits}) for segmentation computation"
1969                    )));
1970                }
1971            };
1972            Ok((b.0, roi, seg))
1973        })
1974        .collect()
1975}
1976
1977fn protobox<'a, T>(
1978    protos: &'a ArrayView3<T>,
1979    roi: &BoundingBox,
1980) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1981    let width = protos.dim().1 as f32;
1982    let height = protos.dim().0 as f32;
1983
1984    // Detect un-normalized bounding boxes (pixel-space coordinates).
1985    // protobox expects normalized coordinates in [0, 1]. The decoder will
1986    // normalize pixel-space coords automatically when the schema declares
1987    // `Detection::normalized = false` AND model input dimensions are known
1988    // (EDGEAI-1303); reaching this guard means at least one of those is
1989    // missing.
1990    //
1991    // The limit is set to 2.0 (not 1.01) because YOLO models legitimately
1992    // predict coordinates slightly > 1.0 for objects near frame edges.
1993    // Any value > 2.0 is clearly pixel-space (even the smallest practical
1994    // model input of 32×32 would produce coordinates >> 2.0).
1995    const NORM_LIMIT: f32 = 2.0;
1996    if roi.xmin > NORM_LIMIT
1997        || roi.ymin > NORM_LIMIT
1998        || roi.xmax > NORM_LIMIT
1999        || roi.ymax > NORM_LIMIT
2000    {
2001        return Err(crate::DecoderError::InvalidShape(format!(
2002            "Bounding box coordinates appear un-normalized (pixel-space). \
2003             Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
2004             Two ways to fix this: \
2005             (1) declare `Detection::normalized = false` in the model schema \
2006             AND make sure the schema's `input.shape` / `input.dshape` carries \
2007             the model input dims so the decoder can divide by (W, H) before NMS \
2008             (EDGEAI-1303 — verify with `Decoder::input_dims().is_some()`); or \
2009             (2) normalize the boxes in-graph before decode().",
2010            roi.xmin, roi.ymin, roi.xmax, roi.ymax,
2011        )));
2012    }
2013
2014    let roi = [
2015        (roi.xmin * width).clamp(0.0, width) as usize,
2016        (roi.ymin * height).clamp(0.0, height) as usize,
2017        (roi.xmax * width).clamp(0.0, width).ceil() as usize,
2018        (roi.ymax * height).clamp(0.0, height).ceil() as usize,
2019    ];
2020
2021    let roi_norm = [
2022        roi[0] as f32 / width,
2023        roi[1] as f32 / height,
2024        roi[2] as f32 / width,
2025        roi[3] as f32 / height,
2026    ]
2027    .into();
2028
2029    let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
2030
2031    Ok((cropped, roi_norm))
2032}
2033
2034/// Compute a single instance segmentation mask from mask coefficients and
2035/// proto maps (float path).
2036///
2037/// Computes `sigmoid(coefficients · protos)` and maps to `[0, 255]`.
2038/// Returns an `(H, W, 1)` u8 array.
2039fn make_segmentation<
2040    MASK: Float + AsPrimitive<f32> + Send + Sync,
2041    PROTO: Float + AsPrimitive<f32> + Send + Sync,
2042>(
2043    mask: ArrayView1<MASK>,
2044    protos: ArrayView3<PROTO>,
2045) -> Array3<u8> {
2046    let shape = protos.shape();
2047
2048    // Safe to unwrap since the shapes will always be compatible
2049    let mask = mask.to_shape((1, mask.len())).unwrap();
2050    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
2051    let protos = protos.reversed_axes();
2052    let mask = mask.map(|x| x.as_());
2053    let protos = protos.map(|x| x.as_());
2054
2055    // Safe to unwrap since the shapes will always be compatible
2056    let mask = mask
2057        .dot(&protos)
2058        .into_shape_with_order((shape[0], shape[1], 1))
2059        .unwrap();
2060
2061    mask.map(|x| {
2062        let sigmoid = 1.0 / (1.0 + (-*x).exp());
2063        (sigmoid * 255.0).round() as u8
2064    })
2065}
2066
2067/// Compute a single instance segmentation mask from quantized mask
2068/// coefficients and proto maps.
2069///
2070/// Dequantizes both inputs (subtracting zero-points), computes the dot
2071/// product, applies sigmoid, and maps to `[0, 255]`.
2072/// Returns an `(H, W, 1)` u8 array.
2073fn make_segmentation_quant<
2074    MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
2075    PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
2076    DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
2077>(
2078    mask: ArrayView1<MASK>,
2079    protos: ArrayView3<PROTO>,
2080    quant_masks: Quantization,
2081    quant_protos: Quantization,
2082) -> Array3<u8>
2083where
2084    i32: AsPrimitive<DEST>,
2085    f32: AsPrimitive<DEST>,
2086{
2087    let shape = protos.shape();
2088
2089    // Safe to unwrap since the shapes will always be compatible
2090    let mask = mask.to_shape((1, mask.len())).unwrap();
2091
2092    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
2093    let protos = protos.reversed_axes();
2094
2095    let zp = quant_masks.zero_point.as_();
2096
2097    let mask = mask.mapv(|x| x.as_() - zp);
2098
2099    let zp = quant_protos.zero_point.as_();
2100    let protos = protos.mapv(|x| x.as_() - zp);
2101
2102    // Safe to unwrap since the shapes will always be compatible
2103    let segmentation = mask
2104        .dot(&protos)
2105        .into_shape_with_order((shape[0], shape[1], 1))
2106        .unwrap();
2107
2108    let combined_scale = quant_masks.scale * quant_protos.scale;
2109    segmentation.map(|x| {
2110        let val: f32 = (*x).as_() * combined_scale;
2111        let sigmoid = 1.0 / (1.0 + (-val).exp());
2112        (sigmoid * 255.0).round() as u8
2113    })
2114}
2115
2116/// Converts Yolo Instance Segmentation into a 2D mask.
2117///
2118/// The input segmentation is expected to have shape (H, W, 1).
2119///
2120/// The output mask will have shape (H, W), with values 0 or 1 based on the
2121/// threshold.
2122///
2123/// # Errors
2124///
2125/// Returns `DecoderError::InvalidShape` if the input segmentation does not
2126/// have shape (H, W, 1).
2127pub(crate) fn yolo_segmentation_to_mask(
2128    segmentation: ArrayView3<u8>,
2129    threshold: u8,
2130) -> Result<Array2<u8>, crate::DecoderError> {
2131    if segmentation.shape()[2] != 1 {
2132        return Err(crate::DecoderError::InvalidShape(format!(
2133            "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
2134            segmentation.shape()[2]
2135        )));
2136    }
2137    Ok(segmentation
2138        .slice(s![.., .., 0])
2139        .map(|x| if *x >= threshold { 1 } else { 0 }))
2140}
2141
2142#[cfg(test)]
2143#[cfg_attr(coverage_nightly, coverage(off))]
2144mod tests {
2145    use super::*;
2146    use ndarray::Array2;
2147
2148    // ========================================================================
2149    // Tests for decode_yolo_end_to_end_det_float
2150    // ========================================================================
2151
2152    #[test]
2153    fn test_end_to_end_det_basic_filtering() {
2154        // Create synthetic end-to-end detection output: (6, N) where rows are
2155        // [x1, y1, x2, y2, conf, class]
2156        // 3 detections: one above threshold, two below
2157        let data: Vec<f32> = vec![
2158            // Detection 0: high score (0.9)
2159            0.1, 0.2, 0.3, // x1 values
2160            0.1, 0.2, 0.3, // y1 values
2161            0.5, 0.6, 0.7, // x2 values
2162            0.5, 0.6, 0.7, // y2 values
2163            0.9, 0.1, 0.2, // confidence scores
2164            0.0, 1.0, 2.0, // class indices
2165        ];
2166        let output = Array2::from_shape_vec((6, 3), data).unwrap();
2167
2168        let mut boxes = Vec::with_capacity(10);
2169        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2170
2171        // Only 1 detection should pass threshold of 0.5
2172        assert_eq!(boxes.len(), 1);
2173        assert_eq!(boxes[0].label, 0);
2174        assert!((boxes[0].score - 0.9).abs() < 0.01);
2175        assert!((boxes[0].bbox.xmin - 0.1).abs() < 0.01);
2176        assert!((boxes[0].bbox.ymin - 0.1).abs() < 0.01);
2177        assert!((boxes[0].bbox.xmax - 0.5).abs() < 0.01);
2178        assert!((boxes[0].bbox.ymax - 0.5).abs() < 0.01);
2179    }
2180
2181    #[test]
2182    fn test_end_to_end_det_all_pass_threshold() {
2183        // All detections above threshold
2184        let data: Vec<f32> = vec![
2185            10.0, 20.0, // x1
2186            10.0, 20.0, // y1
2187            50.0, 60.0, // x2
2188            50.0, 60.0, // y2
2189            0.8, 0.7, // conf (both above 0.5)
2190            1.0, 2.0, // class
2191        ];
2192        let output = Array2::from_shape_vec((6, 2), data).unwrap();
2193
2194        let mut boxes = Vec::with_capacity(10);
2195        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2196
2197        assert_eq!(boxes.len(), 2);
2198        assert_eq!(boxes[0].label, 1);
2199        assert_eq!(boxes[1].label, 2);
2200    }
2201
2202    #[test]
2203    fn test_end_to_end_det_none_pass_threshold() {
2204        // All detections below threshold
2205        let data: Vec<f32> = vec![
2206            10.0, 20.0, // x1
2207            10.0, 20.0, // y1
2208            50.0, 60.0, // x2
2209            50.0, 60.0, // y2
2210            0.1, 0.2, // conf (both below 0.5)
2211            1.0, 2.0, // class
2212        ];
2213        let output = Array2::from_shape_vec((6, 2), data).unwrap();
2214
2215        let mut boxes = Vec::with_capacity(10);
2216        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2217
2218        assert_eq!(boxes.len(), 0);
2219    }
2220
2221    #[test]
2222    fn test_end_to_end_det_capacity_limit() {
2223        // Test that output is truncated to capacity
2224        let data: Vec<f32> = vec![
2225            0.1, 0.2, 0.3, 0.4, 0.5, // x1
2226            0.1, 0.2, 0.3, 0.4, 0.5, // y1
2227            0.5, 0.6, 0.7, 0.8, 0.9, // x2
2228            0.5, 0.6, 0.7, 0.8, 0.9, // y2
2229            0.9, 0.9, 0.9, 0.9, 0.9, // conf (all pass)
2230            0.0, 1.0, 2.0, 3.0, 4.0, // class
2231        ];
2232        let output = Array2::from_shape_vec((6, 5), data).unwrap();
2233
2234        let mut boxes = Vec::with_capacity(2); // Only allow 2 boxes
2235        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2236
2237        assert_eq!(boxes.len(), 2);
2238    }
2239
2240    #[test]
2241    fn test_end_to_end_det_empty_output() {
2242        // Test with zero detections
2243        let output = Array2::<f32>::zeros((6, 0));
2244
2245        let mut boxes = Vec::with_capacity(10);
2246        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2247
2248        assert_eq!(boxes.len(), 0);
2249    }
2250
2251    #[test]
2252    fn test_end_to_end_det_pixel_coordinates() {
2253        // Test with pixel coordinates (non-normalized)
2254        let data: Vec<f32> = vec![
2255            100.0, // x1
2256            200.0, // y1
2257            300.0, // x2
2258            400.0, // y2
2259            0.95,  // conf
2260            5.0,   // class
2261        ];
2262        let output = Array2::from_shape_vec((6, 1), data).unwrap();
2263
2264        let mut boxes = Vec::with_capacity(10);
2265        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2266
2267        assert_eq!(boxes.len(), 1);
2268        assert_eq!(boxes[0].label, 5);
2269        assert!((boxes[0].bbox.xmin - 100.0).abs() < 0.01);
2270        assert!((boxes[0].bbox.ymin - 200.0).abs() < 0.01);
2271        assert!((boxes[0].bbox.xmax - 300.0).abs() < 0.01);
2272        assert!((boxes[0].bbox.ymax - 400.0).abs() < 0.01);
2273    }
2274
2275    #[test]
2276    fn test_end_to_end_det_invalid_shape() {
2277        // Test with too few rows (needs at least 6)
2278        let output = Array2::<f32>::zeros((5, 3));
2279
2280        let mut boxes = Vec::with_capacity(10);
2281        let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);
2282
2283        assert!(result.is_err());
2284        assert!(matches!(
2285            result,
2286            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
2287        ));
2288    }
2289
2290    // ========================================================================
2291    // Tests for decode_yolo_end_to_end_segdet_float
2292    // ========================================================================
2293
2294    #[test]
2295    fn test_end_to_end_segdet_basic() {
2296        // Create synthetic segdet output: (6 + num_protos, N)
2297        // Detection format: [x1, y1, x2, y2, conf, class, mask_coeff_0..31]
2298        let num_protos = 32;
2299        let num_detections = 2;
2300        let num_features = 6 + num_protos;
2301
2302        // Build detection tensor
2303        let mut data = vec![0.0f32; num_features * num_detections];
2304        // Detection 0: passes threshold
2305        data[0] = 0.1; // x1[0]
2306        data[1] = 0.5; // x1[1]
2307        data[num_detections] = 0.1; // y1[0]
2308        data[num_detections + 1] = 0.5; // y1[1]
2309        data[2 * num_detections] = 0.4; // x2[0]
2310        data[2 * num_detections + 1] = 0.9; // x2[1]
2311        data[3 * num_detections] = 0.4; // y2[0]
2312        data[3 * num_detections + 1] = 0.9; // y2[1]
2313        data[4 * num_detections] = 0.9; // conf[0] - passes
2314        data[4 * num_detections + 1] = 0.3; // conf[1] - fails
2315        data[5 * num_detections] = 1.0; // class[0]
2316        data[5 * num_detections + 1] = 2.0; // class[1]
2317                                            // Fill mask coefficients with small values
2318        for i in 6..num_features {
2319            data[i * num_detections] = 0.1;
2320            data[i * num_detections + 1] = 0.1;
2321        }
2322
2323        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2324
2325        // Create protos tensor: (proto_height, proto_width, num_protos)
2326        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2327
2328        let mut boxes = Vec::with_capacity(10);
2329        let mut masks = Vec::with_capacity(10);
2330        decode_yolo_end_to_end_segdet_float(
2331            output.view(),
2332            protos.view(),
2333            0.5,
2334            &mut boxes,
2335            &mut masks,
2336        )
2337        .unwrap();
2338
2339        // Only detection 0 should pass
2340        assert_eq!(boxes.len(), 1);
2341        assert_eq!(masks.len(), 1);
2342        assert_eq!(boxes[0].label, 1);
2343        assert!((boxes[0].score - 0.9).abs() < 0.01);
2344    }
2345
2346    #[test]
2347    fn test_end_to_end_segdet_mask_coordinates() {
2348        // Test that mask coordinates match box coordinates
2349        let num_protos = 32;
2350        let num_features = 6 + num_protos;
2351
2352        let mut data = vec![0.0f32; num_features];
2353        data[0] = 0.2; // x1
2354        data[1] = 0.2; // y1
2355        data[2] = 0.8; // x2
2356        data[3] = 0.8; // y2
2357        data[4] = 0.95; // conf
2358        data[5] = 3.0; // class
2359
2360        let output = Array2::from_shape_vec((num_features, 1), data).unwrap();
2361        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2362
2363        let mut boxes = Vec::with_capacity(10);
2364        let mut masks = Vec::with_capacity(10);
2365        decode_yolo_end_to_end_segdet_float(
2366            output.view(),
2367            protos.view(),
2368            0.5,
2369            &mut boxes,
2370            &mut masks,
2371        )
2372        .unwrap();
2373
2374        assert_eq!(boxes.len(), 1);
2375        assert_eq!(masks.len(), 1);
2376
2377        // Mask region is the proto-grid-aligned crop and encloses the
2378        // post-NMS bbox (EDGEAI-1304); on a 16x16 grid each side may snap
2379        // by up to 1/16 = 0.0625.
2380        let step = 1.0 / 16.0;
2381        assert!(masks[0].xmin <= boxes[0].bbox.xmin);
2382        assert!(masks[0].ymin <= boxes[0].bbox.ymin);
2383        assert!(masks[0].xmax >= boxes[0].bbox.xmax);
2384        assert!(masks[0].ymax >= boxes[0].bbox.ymax);
2385        assert!((boxes[0].bbox.xmin - masks[0].xmin) < step);
2386        assert!((boxes[0].bbox.ymin - masks[0].ymin) < step);
2387        assert!((masks[0].xmax - boxes[0].bbox.xmax) < step);
2388        assert!((masks[0].ymax - boxes[0].bbox.ymax) < step);
2389    }
2390
2391    #[test]
2392    fn test_end_to_end_segdet_empty_output() {
2393        let num_protos = 32;
2394        let output = Array2::<f32>::zeros((6 + num_protos, 0));
2395        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2396
2397        let mut boxes = Vec::with_capacity(10);
2398        let mut masks = Vec::with_capacity(10);
2399        decode_yolo_end_to_end_segdet_float(
2400            output.view(),
2401            protos.view(),
2402            0.5,
2403            &mut boxes,
2404            &mut masks,
2405        )
2406        .unwrap();
2407
2408        assert_eq!(boxes.len(), 0);
2409        assert_eq!(masks.len(), 0);
2410    }
2411
2412    #[test]
2413    fn test_end_to_end_segdet_capacity_limit() {
2414        let num_protos = 32;
2415        let num_detections = 5;
2416        let num_features = 6 + num_protos;
2417
2418        let mut data = vec![0.0f32; num_features * num_detections];
2419        // All detections pass threshold
2420        for i in 0..num_detections {
2421            data[i] = 0.1 * (i as f32); // x1
2422            data[num_detections + i] = 0.1 * (i as f32); // y1
2423            data[2 * num_detections + i] = 0.1 * (i as f32) + 0.2; // x2
2424            data[3 * num_detections + i] = 0.1 * (i as f32) + 0.2; // y2
2425            data[4 * num_detections + i] = 0.9; // conf
2426            data[5 * num_detections + i] = i as f32; // class
2427        }
2428
2429        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2430        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2431
2432        let mut boxes = Vec::with_capacity(2); // Limit to 2
2433        let mut masks = Vec::with_capacity(2);
2434        decode_yolo_end_to_end_segdet_float(
2435            output.view(),
2436            protos.view(),
2437            0.5,
2438            &mut boxes,
2439            &mut masks,
2440        )
2441        .unwrap();
2442
2443        assert_eq!(boxes.len(), 2);
2444        assert_eq!(masks.len(), 2);
2445    }
2446
2447    #[test]
2448    fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
2449        // Test with too few rows (needs at least 7: 6 base + 1 mask coeff)
2450        let output = Array2::<f32>::zeros((6, 3));
2451        let protos = Array3::<f32>::zeros((16, 16, 32));
2452
2453        let mut boxes = Vec::with_capacity(10);
2454        let mut masks = Vec::with_capacity(10);
2455        let result = decode_yolo_end_to_end_segdet_float(
2456            output.view(),
2457            protos.view(),
2458            0.5,
2459            &mut boxes,
2460            &mut masks,
2461        );
2462
2463        assert!(result.is_err());
2464        assert!(matches!(
2465            result,
2466            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
2467        ));
2468    }
2469
2470    #[test]
2471    fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
2472        // Test with mismatched mask coefficients and protos count
2473        let num_protos = 32;
2474        let output = Array2::<f32>::zeros((6 + 16, 3)); // 16 mask coeffs
2475        let protos = Array3::<f32>::zeros((16, 16, num_protos)); // 32 protos
2476
2477        let mut boxes = Vec::with_capacity(10);
2478        let mut masks = Vec::with_capacity(10);
2479        let result = decode_yolo_end_to_end_segdet_float(
2480            output.view(),
2481            protos.view(),
2482            0.5,
2483            &mut boxes,
2484            &mut masks,
2485        );
2486
2487        assert!(result.is_err());
2488        assert!(matches!(
2489            result,
2490            Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
2491        ));
2492    }
2493
2494    // ========================================================================
2495    // Tests for decode_yolo_split_end_to_end_segdet_float
2496    // ========================================================================
2497
2498    #[test]
2499    fn test_split_end_to_end_segdet_basic() {
2500        // Create synthetic segdet output: (6 + num_protos, N)
2501        // Detection format: [x1, y1, x2, y2, conf, class, mask_coeff_0..31]
2502        let num_protos = 32;
2503        let num_detections = 2;
2504        let num_features = 6 + num_protos;
2505
2506        // Build detection tensor
2507        let mut data = vec![0.0f32; num_features * num_detections];
2508        // Detection 0: passes threshold
2509        data[0] = 0.1; // x1[0]
2510        data[1] = 0.5; // x1[1]
2511        data[num_detections] = 0.1; // y1[0]
2512        data[num_detections + 1] = 0.5; // y1[1]
2513        data[2 * num_detections] = 0.4; // x2[0]
2514        data[2 * num_detections + 1] = 0.9; // x2[1]
2515        data[3 * num_detections] = 0.4; // y2[0]
2516        data[3 * num_detections + 1] = 0.9; // y2[1]
2517        data[4 * num_detections] = 0.9; // conf[0] - passes
2518        data[4 * num_detections + 1] = 0.3; // conf[1] - fails
2519        data[5 * num_detections] = 1.0; // class[0]
2520        data[5 * num_detections + 1] = 2.0; // class[1]
2521                                            // Fill mask coefficients with small values
2522        for i in 6..num_features {
2523            data[i * num_detections] = 0.1;
2524            data[i * num_detections + 1] = 0.1;
2525        }
2526
2527        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2528        let box_coords = output.slice(s![..4, ..]);
2529        let scores = output.slice(s![4..5, ..]);
2530        let classes = output.slice(s![5..6, ..]);
2531        let mask_coeff = output.slice(s![6.., ..]);
2532        // Create protos tensor: (proto_height, proto_width, num_protos)
2533        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2534
2535        let mut boxes = Vec::with_capacity(10);
2536        let mut masks = Vec::with_capacity(10);
2537        decode_yolo_split_end_to_end_segdet_float(
2538            box_coords,
2539            scores,
2540            classes,
2541            mask_coeff,
2542            protos.view(),
2543            0.5,
2544            &mut boxes,
2545            &mut masks,
2546        )
2547        .unwrap();
2548
2549        // Only detection 0 should pass
2550        assert_eq!(boxes.len(), 1);
2551        assert_eq!(masks.len(), 1);
2552        assert_eq!(boxes[0].label, 1);
2553        assert!((boxes[0].score - 0.9).abs() < 0.01);
2554    }
2555
2556    // ========================================================================
2557    // Tests for yolo_segmentation_to_mask
2558    // ========================================================================
2559
2560    #[test]
2561    fn test_segmentation_to_mask_basic() {
2562        // Create a 4x4x1 segmentation with values above and below threshold
2563        let data: Vec<u8> = vec![
2564            100, 200, 50, 150, // row 0
2565            10, 255, 128, 64, // row 1
2566            0, 127, 128, 255, // row 2
2567            64, 64, 192, 192, // row 3
2568        ];
2569        let segmentation = Array3::from_shape_vec((4, 4, 1), data).unwrap();
2570
2571        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2572
2573        // Values >= 128 should be 1, others 0
2574        assert_eq!(mask[[0, 0]], 0); // 100 < 128
2575        assert_eq!(mask[[0, 1]], 1); // 200 >= 128
2576        assert_eq!(mask[[0, 2]], 0); // 50 < 128
2577        assert_eq!(mask[[0, 3]], 1); // 150 >= 128
2578        assert_eq!(mask[[1, 1]], 1); // 255 >= 128
2579        assert_eq!(mask[[1, 2]], 1); // 128 >= 128
2580        assert_eq!(mask[[2, 0]], 0); // 0 < 128
2581        assert_eq!(mask[[2, 1]], 0); // 127 < 128
2582    }
2583
2584    #[test]
2585    fn test_segmentation_to_mask_all_above() {
2586        let segmentation = Array3::from_elem((4, 4, 1), 255u8);
2587        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2588        assert!(mask.iter().all(|&x| x == 1));
2589    }
2590
2591    #[test]
2592    fn test_segmentation_to_mask_all_below() {
2593        let segmentation = Array3::from_elem((4, 4, 1), 64u8);
2594        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2595        assert!(mask.iter().all(|&x| x == 0));
2596    }
2597
2598    #[test]
2599    fn test_segmentation_to_mask_invalid_shape() {
2600        let segmentation = Array3::from_elem((4, 4, 3), 128u8);
2601        let result = yolo_segmentation_to_mask(segmentation.view(), 128);
2602
2603        assert!(result.is_err());
2604        assert!(matches!(
2605            result,
2606            Err(crate::DecoderError::InvalidShape(s)) if s.contains("(H, W, 1)")
2607        ));
2608    }
2609
2610    // ========================================================================
2611    // Tests for protobox / NORM_LIMIT regression
2612    // ========================================================================
2613
2614    #[test]
2615    fn test_protobox_clamps_edge_coordinates() {
2616        // bbox with xmax=1.0 should not panic (OOB guard)
2617        let protos = Array3::<f32>::zeros((16, 16, 4));
2618        let view = protos.view();
2619        let roi = BoundingBox {
2620            xmin: 0.5,
2621            ymin: 0.5,
2622            xmax: 1.0,
2623            ymax: 1.0,
2624        };
2625        let result = protobox(&view, &roi);
2626        assert!(result.is_ok(), "protobox should accept xmax=1.0");
2627        let (cropped, _roi_norm) = result.unwrap();
2628        // Cropped region must have non-zero spatial dimensions
2629        assert!(cropped.shape()[0] > 0);
2630        assert!(cropped.shape()[1] > 0);
2631        assert_eq!(cropped.shape()[2], 4);
2632    }
2633
2634    #[test]
2635    fn test_protobox_rejects_wildly_out_of_range() {
2636        // bbox with coords > NORM_LIMIT (e.g. 3.0) returns error
2637        let protos = Array3::<f32>::zeros((16, 16, 4));
2638        let view = protos.view();
2639        let roi = BoundingBox {
2640            xmin: 0.0,
2641            ymin: 0.0,
2642            xmax: 3.0,
2643            ymax: 3.0,
2644        };
2645        let result = protobox(&view, &roi);
2646        assert!(
2647            matches!(result, Err(crate::DecoderError::InvalidShape(s)) if s.contains("un-normalized")),
2648            "protobox should reject coords > NORM_LIMIT"
2649        );
2650    }
2651
2652    #[test]
2653    fn test_protobox_accepts_slightly_over_one() {
2654        // bbox with coords at 1.5 (within NORM_LIMIT=2.0) succeeds
2655        let protos = Array3::<f32>::zeros((16, 16, 4));
2656        let view = protos.view();
2657        let roi = BoundingBox {
2658            xmin: 0.0,
2659            ymin: 0.0,
2660            xmax: 1.5,
2661            ymax: 1.5,
2662        };
2663        let result = protobox(&view, &roi);
2664        assert!(
2665            result.is_ok(),
2666            "protobox should accept coords <= NORM_LIMIT (2.0)"
2667        );
2668        let (cropped, _roi_norm) = result.unwrap();
2669        // Entire proto map should be selected when coords > 1.0 (clamped to boundary)
2670        assert_eq!(cropped.shape()[0], 16);
2671        assert_eq!(cropped.shape()[1], 16);
2672    }
2673
2674    #[test]
2675    fn test_segdet_float_proto_no_panic() {
2676        // Simulates YOLOv8n-seg: output0 = [116, 8400] (4 box + 80 class + 32 mask coeff)
2677        // output1 (protos) = [32, 160, 160]
2678        let num_proposals = 100; // enough to produce idx >= 32
2679        let num_classes = 80;
2680        let num_mask_coeffs = 32;
2681        let rows = 4 + num_classes + num_mask_coeffs; // 116
2682
2683        // Fill boxes with valid xywh data so some detections pass the threshold.
2684        // Layout is [116, num_proposals] row-major: row 0=cx, 1=cy, 2=w, 3=h,
2685        // rows 4..84=class scores, rows 84..116=mask coefficients.
2686        let mut data = vec![0.0f32; rows * num_proposals];
2687        for i in 0..num_proposals {
2688            let row = |r: usize| r * num_proposals + i;
2689            data[row(0)] = 320.0; // cx
2690            data[row(1)] = 320.0; // cy
2691            data[row(2)] = 50.0; // w
2692            data[row(3)] = 50.0; // h
2693            data[row(4)] = 0.9; // class-0 score
2694        }
2695        let boxes = ndarray::Array2::from_shape_vec((rows, num_proposals), data).unwrap();
2696
2697        // Protos must be in HWC order. Under the HAL physical-order
2698        // contract, callers declare shape+dshape matching producer memory
2699        // and swap_axes_if_needed permutes the stride tuple into canonical
2700        // [batch, height, width, num_protos] before this function sees it.
2701        let protos = ndarray::Array3::<f32>::zeros((160, 160, num_mask_coeffs));
2702
2703        let mut output_boxes = Vec::with_capacity(300);
2704
2705        // This panicked before fix: mask_tensor.row(idx) with idx >= 32
2706        let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
2707            boxes.view(),
2708            protos.view(),
2709            0.5,
2710            0.7,
2711            Some(Nms::default()),
2712            MAX_NMS_CANDIDATES,
2713            300,
2714            None,
2715            None,
2716            &mut output_boxes,
2717        );
2718
2719        // Should produce detections (NMS will collapse many overlapping boxes)
2720        assert!(!output_boxes.is_empty());
2721        let coeffs_shape = proto_data.mask_coefficients.shape();
2722        assert_eq!(coeffs_shape[0], output_boxes.len());
2723        // Each mask coefficient vector should have 32 elements
2724        assert_eq!(coeffs_shape[1], num_mask_coeffs);
2725    }
2726
2727    // ========================================================================
2728    // Pre-NMS top-K cap (MAX_NMS_CANDIDATES)
2729    // ========================================================================
2730
2731    /// At very low score thresholds (e.g., t=0.01 on YOLOv8 with 8400×80
2732    /// candidates) almost every score passes the filter, feeding O(n²)
2733    /// NMS and a per-survivor mask matmul. The decoder caps the
2734    /// candidate set fed to NMS at `MAX_NMS_CANDIDATES` (Ultralytics
2735    /// default 30 000) to bound worst-case decode time.
2736    ///
2737    /// This regression test pumps 50 000 above-threshold candidates
2738    /// into `impl_yolo_segdet_get_boxes` with NMS bypassed (Nms=None)
2739    /// and a generous post-NMS cap. Before the fix, the function
2740    /// returned all 50 000; after the fix, exactly 30 000.
2741    #[test]
2742    fn test_pre_nms_cap_truncates_excess_candidates() {
2743        let n: usize = 50_000;
2744        let num_classes = 1;
2745
2746        // Identical valid boxes. Distinct scores (descending) so the
2747        // top-K cap keeps the highest-scoring ones in deterministic
2748        // order — letting us assert *which* ones survived.
2749        let mut boxes_data = Vec::with_capacity(n * 4);
2750        let mut scores_data = Vec::with_capacity(n * num_classes);
2751        for i in 0..n {
2752            boxes_data.extend_from_slice(&[0.1f32, 0.1, 0.5, 0.5]);
2753            // score_i = 0.99 - i * 1e-7 keeps everything well above 0.1
2754            // threshold but strictly decreasing.
2755            scores_data.push(0.99 - (i as f32) * 1e-7);
2756        }
2757        let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
2758        let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();
2759
2760        let result = impl_yolo_segdet_get_boxes::<XYXY, _, _>(
2761            boxes.view(),
2762            scores.view(),
2763            0.1,
2764            1.0,                             // IoU 1.0 → NMS suppresses nothing
2765            Some(Nms::ClassAgnostic),        // NMS enabled so pre_nms_top_k applies
2766            crate::yolo::MAX_NMS_CANDIDATES, // pre_nms_top_k
2767            usize::MAX,                      // no post-NMS truncation
2768        );
2769
2770        assert_eq!(
2771            result.len(),
2772            crate::yolo::MAX_NMS_CANDIDATES,
2773            "pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
2774            result.len()
2775        );
2776        // Top-K survivors: highest scores were the first n indices,
2777        // so survivor 0 must have score ~0.99.
2778        let top_score = result[0].0.score;
2779        assert!(
2780            top_score > 0.98,
2781            "highest-ranked survivor should have the largest score, got {top_score}"
2782        );
2783    }
2784
2785    /// Counterpart for the quantized split path. Same contract: feed
2786    /// more than `MAX_NMS_CANDIDATES` survivors above the quantized
2787    /// threshold, confirm `impl_yolo_split_segdet_quant_get_boxes`
2788    /// truncates before NMS.
2789    #[test]
2790    fn test_pre_nms_cap_truncates_excess_candidates_quant() {
2791        use crate::Quantization;
2792        let n: usize = 50_000;
2793        let num_classes = 1;
2794
2795        // i8 boxes with simple scale/zp; the box value 50 dequantizes
2796        // to 0.5 with scale=0.01, zp=0 — fine for a flat box set.
2797        let boxes_data = (0..n).flat_map(|_| [10i8, 10, 50, 50]).collect::<Vec<_>>();
2798        let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
2799        let quant_boxes = Quantization {
2800            scale: 0.01,
2801            zero_point: 0,
2802        };
2803
2804        // u8 scores: distinct descending values, all well above threshold.
2805        // value 250 → 0.98 with scale 0.00392, zp 0.
2806        // value (250 - i % 200) keeps a wide spread above the dequant
2807        // threshold of 0.5.
2808        let scores_data: Vec<u8> = (0..n)
2809            .map(|i| 250u8.saturating_sub((i % 200) as u8))
2810            .collect();
2811        let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();
2812        let quant_scores = Quantization {
2813            scale: 0.00392,
2814            zero_point: 0,
2815        };
2816
2817        let result = impl_yolo_split_segdet_quant_get_boxes::<XYXY, _, _>(
2818            (boxes.view(), quant_boxes),
2819            (scores.view(), quant_scores),
2820            0.1,
2821            1.0,                             // IoU 1.0 → NMS suppresses nothing
2822            Some(Nms::ClassAgnostic),        // NMS enabled so pre_nms_top_k applies
2823            crate::yolo::MAX_NMS_CANDIDATES, // pre_nms_top_k
2824            usize::MAX,                      // no post-NMS truncation
2825        );
2826
2827        assert_eq!(
2828            result.len(),
2829            crate::yolo::MAX_NMS_CANDIDATES,
2830            "quant path pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
2831            result.len()
2832        );
2833    }
2834
2835    /// Regression test for HAILORT_BUG.md — the YoloSegDet path
2836    /// (combined `(4 + nc + nm, N)` detection tensor + separate protos)
2837    /// must pair each surviving detection with the mask coefficient
2838    /// row at the SAME anchor index the box came from. The validator
2839    /// sees this path miss the pairing under schema-v2 Hailo inputs
2840    /// (mAP collapse from 46.8 → 3.65 while mask IoU stays at 66.9,
2841    /// the fingerprint of mask-to-detection misalignment).
2842    ///
2843    /// Construction: three anchors with distinct mask-coef signatures
2844    /// that, after dot(coefs, protos) + sigmoid, produce HIGH vs LOW
2845    /// mask pixel values. Two anchors survive (one high, one low); if
2846    /// the mask row is looked up at the wrong index, the per-detection
2847    /// mean mask value would cross the threshold and we catch it.
2848    #[test]
2849    fn segdet_combined_tensor_pairs_detection_with_matching_mask_row() {
2850        let nc = 2; // num_classes
2851        let nm = 2; // num_protos
2852        let n = 3; // num_anchors
2853        let feat = 4 + nc + nm; // 8
2854
2855        // Tensor layout: (8, 3) rows=features, cols=anchors.
2856        // Row indices:  0..4 = xywh, 4..6 = scores, 6..8 = mask_coefs.
2857        //
2858        //         anchor 0 | anchor 1 | anchor 2
2859        // xc       0.2      | 0.5      | 0.8
2860        // yc       0.2      | 0.5      | 0.8
2861        // w        0.1      | 0.1      | 0.1
2862        // h        0.1      | 0.1      | 0.1
2863        // s[0]     0.9      | 0.0      | 0.8   (class 0)
2864        // s[1]     0.0      | 0.0      | 0.0   (class 1 — always loses)
2865        // m[0]     3.0      | 0.0      | -3.0  (high for a0, low for a2)
2866        // m[1]     3.0      | 0.0      | -3.0
2867        //
2868        // Proto[0] = Proto[1] = all-ones (8x8), so
2869        //   mask(a0) = sigmoid(3 + 3) ≈ 0.9975 → 254
2870        //   mask(a2) = sigmoid(-3 + -3) ≈ 0.0025 → 1
2871        // 250-point gap makes any misalignment trivially detectable.
2872        let mut data = vec![0.0f32; feat * n];
2873        let set = |d: &mut [f32], r: usize, c: usize, v: f32| d[r * n + c] = v;
2874        set(&mut data, 0, 0, 0.2);
2875        set(&mut data, 1, 0, 0.2);
2876        set(&mut data, 2, 0, 0.1);
2877        set(&mut data, 3, 0, 0.1);
2878        set(&mut data, 0, 1, 0.5);
2879        set(&mut data, 1, 1, 0.5);
2880        set(&mut data, 2, 1, 0.1);
2881        set(&mut data, 3, 1, 0.1);
2882        set(&mut data, 0, 2, 0.8);
2883        set(&mut data, 1, 2, 0.8);
2884        set(&mut data, 2, 2, 0.1);
2885        set(&mut data, 3, 2, 0.1);
2886        set(&mut data, 4, 0, 0.9);
2887        set(&mut data, 4, 2, 0.8);
2888        set(&mut data, 6, 0, 3.0);
2889        set(&mut data, 7, 0, 3.0);
2890        set(&mut data, 6, 2, -3.0);
2891        set(&mut data, 7, 2, -3.0);
2892
2893        let output = Array2::from_shape_vec((feat, n), data).unwrap();
2894        let protos = Array3::<f32>::from_elem((8, 8, nm), 1.0);
2895
2896        let mut boxes: Vec<DetectBox> = Vec::with_capacity(10);
2897        let mut masks: Vec<Segmentation> = Vec::with_capacity(10);
2898        decode_yolo_segdet_float(
2899            output.view(),
2900            protos.view(),
2901            0.5,
2902            0.5,
2903            Some(Nms::ClassAgnostic),
2904            &mut boxes,
2905            &mut masks,
2906        )
2907        .unwrap();
2908
2909        assert_eq!(
2910            boxes.len(),
2911            2,
2912            "two anchors above threshold should survive (a0 score=0.9, a2 score=0.8); got {}",
2913            boxes.len()
2914        );
2915
2916        // Build a (anchor_index → mask_mean) mapping from the results.
2917        // Anchor 0 has centre (0.2, 0.2), anchor 2 has centre (0.8,
2918        // 0.8). The DetectBox bbox is the post-XYWH-to-XYXY conversion
2919        // of the original xywh; cropping inside protobox may shrink it,
2920        // so match by centre (0.2 vs 0.8) rather than exact bbox.
2921        for (b, m) in boxes.iter().zip(masks.iter()) {
2922            let cx = (b.bbox.xmin + b.bbox.xmax) * 0.5;
2923            let mean = {
2924                let s = &m.segmentation;
2925                let total: u32 = s.iter().map(|&v| v as u32).sum();
2926                total as f32 / s.len() as f32
2927            };
2928            if cx < 0.3 {
2929                // anchor 0 — expect HIGH mask values ≈ 254
2930                assert!(
2931                    mean > 200.0,
2932                    "anchor 0 detection (centre {cx:.2}) should have high-value mask; got mean {mean}"
2933                );
2934            } else if cx > 0.7 {
2935                // anchor 2 — expect LOW mask values ≈ 1
2936                assert!(
2937                    mean < 50.0,
2938                    "anchor 2 detection (centre {cx:.2}) should have low-value mask; got mean {mean}"
2939                );
2940            } else {
2941                panic!("unexpected detection centre {cx:.2}");
2942            }
2943        }
2944    }
2945
2946    // ========================================================================
2947    // Tests for truncate_to_top_k_by_score / truncate_to_top_k_by_score_quant
2948    // ========================================================================
2949
2950    /// Helper: build a Vec of (DetectBox, ()) with the given scores.
2951    fn make_float_boxes(scores: &[f32]) -> Vec<(DetectBox, ())> {
2952        scores
2953            .iter()
2954            .enumerate()
2955            .map(|(i, &s)| {
2956                (
2957                    DetectBox {
2958                        bbox: BoundingBox {
2959                            xmin: 0.0,
2960                            ymin: 0.0,
2961                            xmax: 1.0,
2962                            ymax: 1.0,
2963                        },
2964                        score: s,
2965                        label: i,
2966                    },
2967                    (),
2968                )
2969            })
2970            .collect()
2971    }
2972
2973    /// Helper: build a Vec of (DetectBoxQuantized<i8>, ()) with the given scores.
2974    fn make_quant_boxes(scores: &[i8]) -> Vec<(DetectBoxQuantized<i8>, ())> {
2975        scores
2976            .iter()
2977            .enumerate()
2978            .map(|(i, &s)| {
2979                (
2980                    DetectBoxQuantized {
2981                        bbox: BoundingBox {
2982                            xmin: 0.0,
2983                            ymin: 0.0,
2984                            xmax: 1.0,
2985                            ymax: 1.0,
2986                        },
2987                        score: s,
2988                        label: i,
2989                    },
2990                    (),
2991                )
2992            })
2993            .collect()
2994    }
2995
2996    #[test]
2997    fn truncate_float_top_k_zero_is_unbounded() {
2998        let mut boxes = make_float_boxes(&[0.9, 0.1, 0.5, 0.3, 0.7]);
2999        let original_len = boxes.len();
3000        truncate_to_top_k_by_score(&mut boxes, 0);
3001        assert_eq!(
3002            boxes.len(),
3003            original_len,
3004            "top_k=0 should keep all candidates (no-limit semantics)"
3005        );
3006    }
3007
3008    #[test]
3009    fn truncate_float_top_k_normal() {
3010        let mut boxes = make_float_boxes(&[0.9, 0.1, 0.5, 0.3, 0.7]);
3011        truncate_to_top_k_by_score(&mut boxes, 3);
3012        assert_eq!(boxes.len(), 3);
3013        // The top-3 scores should be 0.9, 0.7, 0.5 (order within top-K is unspecified)
3014        let mut retained: Vec<f32> = boxes.iter().map(|(b, _)| b.score).collect();
3015        retained.sort_by(|a, b| b.total_cmp(a));
3016        assert_eq!(retained, vec![0.9, 0.7, 0.5]);
3017    }
3018
3019    #[test]
3020    fn truncate_float_top_k_noop_when_under_cap() {
3021        let mut boxes = make_float_boxes(&[0.9, 0.5]);
3022        truncate_to_top_k_by_score(&mut boxes, 10);
3023        assert_eq!(boxes.len(), 2, "should be no-op when len <= top_k");
3024    }
3025
3026    #[test]
3027    fn truncate_quant_top_k_zero_is_unbounded() {
3028        let mut boxes = make_quant_boxes(&[120, -50, 30, -10, 80]);
3029        let original_len = boxes.len();
3030        truncate_to_top_k_by_score_quant(&mut boxes, 0);
3031        assert_eq!(
3032            boxes.len(),
3033            original_len,
3034            "top_k=0 should keep all candidates (no-limit semantics)"
3035        );
3036    }
3037
3038    #[test]
3039    fn truncate_quant_top_k_normal() {
3040        let mut boxes = make_quant_boxes(&[120, -50, 30, -10, 80]);
3041        truncate_to_top_k_by_score_quant(&mut boxes, 3);
3042        assert_eq!(boxes.len(), 3);
3043        let mut retained: Vec<i8> = boxes.iter().map(|(b, _)| b.score).collect();
3044        retained.sort_by(|a, b| b.cmp(a));
3045        assert_eq!(retained, vec![120, 80, 30]);
3046    }
3047
3048    #[test]
3049    fn truncate_quant_top_k_noop_when_under_cap() {
3050        let mut boxes = make_quant_boxes(&[120, 80]);
3051        truncate_to_top_k_by_score_quant(&mut boxes, 10);
3052        assert_eq!(boxes.len(), 2, "should be no-op when len <= top_k");
3053    }
3054}