Skip to main content

edgefirst_decoder/
yolo.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4//! Internal YOLO decoder kernels.
5//!
6//! All items in this module are `pub(crate)`. The public API surface for
7//! decoding is [`crate::Decoder`] + [`crate::DecoderBuilder`]; external
8//! callers must go through that entry point so we can evolve the kernels
9//! (split paths, dispatch tables, NEON tiers) without breaking semver.
10//! See `CHANGELOG.md` for the 0.20.0 narrowing.
11
12use std::fmt::Debug;
13
14use ndarray::{
15    parallel::prelude::{IntoParallelIterator, ParallelIterator},
16    s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
17};
18use num_traits::{AsPrimitive, Float, PrimInt, Signed};
19
20use crate::{
21    byte::{
22        nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
23        postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
24    },
25    configs::Nms,
26    dequant_detect_box,
27    float::{
28        nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
29        postprocess_boxes_float, postprocess_boxes_index_float,
30    },
31    BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoLayout,
32    Quantization, Segmentation, XYWH, XYXY,
33};
34
/// Maximum number of above-threshold candidates fed to NMS.
///
/// At very low score thresholds (e.g., t=0.01 on YOLOv8 with
/// 8400 anchors × 80 classes), the number of survivors approaches
/// the full 672 000-entry score grid. NMS is O(n²) and the
/// downstream mask matmul runs once per survivor, so an
/// unbounded set produces minutes-per-frame decode times.
///
/// `MAX_NMS_CANDIDATES` matches the Ultralytics `max_nms` default
/// and is applied as a top-K-by-score truncation immediately
/// before NMS. Candidates ranked below the top `MAX_NMS_CANDIDATES`
/// by score are silently dropped — at the score thresholds where the
/// cap activates, the bottom of the candidate list is dominated by
/// noise that NMS would discard anyway.
///
/// Production callers configure this via [`crate::DecoderBuilder::with_pre_nms_top_k`];
/// only the test-only seg-det wrappers reach for this constant directly,
/// which is why the constant itself is compiled only under `cfg(test)`.
#[cfg(test)]
pub(crate) const MAX_NMS_CANDIDATES: usize = 30_000;
54
/// Default post-NMS detection cap used by the crate-internal `decode_yolo_*`
/// convenience wrappers when no explicit cap is plumbed in. It is the value
/// `cap_or_default` returns when the caller's output `Vec` has capacity 0.
/// Mirrors the `Decoder::max_det` default set by `DecoderBuilder` (also 300,
/// matching the Ultralytics `max_det` default). Pre-EDGEAI-1302 these
/// wrappers used `output_boxes.capacity()` directly as the cap, which
/// silently dropped all detections when the caller passed `Vec::new()`.
pub(crate) const DEFAULT_MAX_DETECTIONS: usize = 300;
62
63/// Truncate `boxes` to the highest-scoring `top_k` entries in-place when the
64/// input exceeds the cap. Uses partial sort (O(N)) via `select_nth_unstable_by`
65/// to avoid full O(N log N) sort. No-op when `top_k` is 0 (unbounded) or
66/// when the input length ≤ `top_k`.
67fn truncate_to_top_k_by_score<E: Send>(boxes: &mut Vec<(DetectBox, E)>, top_k: usize) {
68    if top_k > 0 && boxes.len() > top_k {
69        boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.total_cmp(&a.0.score));
70        boxes.truncate(top_k);
71    }
72}
73
74/// Quantized counterpart of [`truncate_to_top_k_by_score`]. Sorts on
75/// the raw quantized score (which preserves order under monotonic
76/// dequantization). Uses partial sort (O(N)) via `select_nth_unstable_by`.
77/// No-op when `top_k` is 0 (unbounded) or when the input length ≤ `top_k`.
78fn truncate_to_top_k_by_score_quant<S: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send>(
79    boxes: &mut Vec<(DetectBoxQuantized<S>, E)>,
80    top_k: usize,
81) {
82    if top_k > 0 && boxes.len() > top_k {
83        boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.cmp(&a.0.score));
84        boxes.truncate(top_k);
85    }
86}
87
88/// Dispatches to the appropriate NMS function based on mode for float boxes.
89///
90/// `max_det` is the post-NMS detection cap; when present it lets the greedy
91/// inner loop break as soon as that many survivors are confirmed (the survivors
92/// are guaranteed to be the top-`max_det` by score because the input is sorted
93/// descending). Pass `None` to run the full O(N²) suppression.
94fn dispatch_nms_float(
95    nms: Option<Nms>,
96    iou: f32,
97    max_det: Option<usize>,
98    boxes: Vec<DetectBox>,
99) -> Vec<DetectBox> {
100    match nms {
101        Some(Nms::ClassAgnostic | Nms::Auto) => nms_float(iou, max_det, boxes),
102        Some(Nms::ClassAware) => nms_class_aware_float(iou, max_det, boxes),
103        None => boxes, // bypass NMS
104    }
105}
106
107/// Dispatches to the appropriate NMS function based on mode for float boxes
108/// with extra data.
109pub(super) fn dispatch_nms_extra_float<E: Send + Sync>(
110    nms: Option<Nms>,
111    iou: f32,
112    max_det: Option<usize>,
113    boxes: Vec<(DetectBox, E)>,
114) -> Vec<(DetectBox, E)> {
115    match nms {
116        Some(Nms::ClassAgnostic | Nms::Auto) => nms_extra_float(iou, max_det, boxes),
117        Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, max_det, boxes),
118        None => boxes, // bypass NMS
119    }
120}
121
122/// Dispatches to the appropriate NMS function based on mode for quantized
123/// boxes.
124fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
125    nms: Option<Nms>,
126    iou: f32,
127    max_det: Option<usize>,
128    boxes: Vec<DetectBoxQuantized<SCORE>>,
129) -> Vec<DetectBoxQuantized<SCORE>> {
130    match nms {
131        Some(Nms::ClassAgnostic | Nms::Auto) => nms_int(iou, max_det, boxes),
132        Some(Nms::ClassAware) => nms_class_aware_int(iou, max_det, boxes),
133        None => boxes, // bypass NMS
134    }
135}
136
137/// Dispatches to the appropriate NMS function based on mode for quantized boxes
138/// with extra data.
139fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
140    nms: Option<Nms>,
141    iou: f32,
142    max_det: Option<usize>,
143    boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
144) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
145    match nms {
146        Some(Nms::ClassAgnostic | Nms::Auto) => nms_extra_int(iou, max_det, boxes),
147        Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, max_det, boxes),
148        None => boxes, // bypass NMS
149    }
150}
151
152/// Detection cap helper for the public free `decode_yolo_*` wrappers.
153///
154/// Encodes the convention documented above: if the caller passed a non-empty
155/// `Vec`, that capacity acts as the per-call cap; otherwise fall back to
156/// [`DEFAULT_MAX_DETECTIONS`] so freshly-constructed `Vec::new()` outputs
157/// don't silently drop every detection.
158#[inline]
159fn cap_or_default<T>(v: &Vec<T>) -> usize {
160    if v.capacity() > 0 {
161        v.capacity()
162    } else {
163        DEFAULT_MAX_DETECTIONS
164    }
165}
166
// ─── Crate-internal free decode_yolo_* convenience wrappers ────────────────
168//
169// Detection cap convention (applies to every `decode_yolo_*` free function
170// below):
171//
172// These functions are the convenience layer for callers that don't go
173// through `Decoder::decode()` (benches, FFI shims, ad-hoc test harnesses).
174// They use **`output_boxes.capacity()` as a per-call detection cap**:
175//
176//   - When the caller passes `Vec::with_capacity(N)`, the post-NMS output
177//     is truncated to at most `N` detections.
178//   - When the caller passes `Vec::new()` (capacity 0), the implementation
179//     falls back to the [`DEFAULT_MAX_DETECTIONS`] constant (300) so a
180//     freshly-constructed `Vec` doesn't silently drop every detection.
181//
182// This is intentionally **different** from the `Decoder::decode()` /
183// `Decoder::decode_proto()` contract, which bounds output count solely
184// by [`Decoder::max_det`] (set via `DecoderBuilder::with_max_det`,
185// default 300) regardless of the caller's `Vec` capacity (EDGEAI-1302).
186//
187// Use the `Decoder` API when you need explicit control over `max_det`,
188// schema-driven decoding, or EDGEAI-1303 normalization. Use these free
189// functions when you have raw tensors in hand and want a one-shot decode.
190
191/// Decodes YOLO detection outputs from quantized tensors into detection boxes.
192///
193/// Boxes are expected to be in XYWH format.
194///
195/// Expected shapes of inputs:
196/// - output: (4 + num_classes, num_boxes)
197///
198/// See the "Detection cap convention" comment above for how
199/// `output_boxes.capacity()` bounds the result count.
200pub(crate) fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
201    output: (ArrayView2<BOX>, Quantization),
202    score_threshold: f32,
203    iou_threshold: f32,
204    nms: Option<Nms>,
205    output_boxes: &mut Vec<DetectBox>,
206) where
207    f32: AsPrimitive<BOX>,
208{
209    impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
210}
211
212/// Decodes YOLO detection outputs from float tensors into detection boxes.
213///
214/// Boxes are expected to be in XYWH format.
215///
216/// Expected shapes of inputs:
217/// - output: (4 + num_classes, num_boxes)
218pub(crate) fn decode_yolo_det_float<T>(
219    output: ArrayView2<T>,
220    score_threshold: f32,
221    iou_threshold: f32,
222    nms: Option<Nms>,
223    output_boxes: &mut Vec<DetectBox>,
224) where
225    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
226    f32: AsPrimitive<T>,
227{
228    impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
229}
230
/// Test-only shim around the quantized seg-det kernel.
///
/// Production callers go through [`crate::Decoder::decode`]. This wrapper
/// exists only so the parity tests in `lib.rs` and `yolo.rs` can compare the
/// kernel output against the Decoder output without repeating the
/// `impl_yolo_segdet_quant` argument plumbing.
///
/// Boxes are expected to be in XYWH format. Expected shapes:
/// - `boxes`: `(4 + num_classes + num_protos, num_boxes)`
/// - `protos`: `(proto_height, proto_width, num_protos)`
///
/// The pre-NMS candidate cap is [`MAX_NMS_CANDIDATES`]. The detection cap is
/// `output_boxes.capacity()`, or [`DEFAULT_MAX_DETECTIONS`] when that is 0.
///
/// # Errors
/// Propagates whatever `DecoderError` `impl_yolo_segdet_quant` returns.
/// NOTE(review): earlier docs name `DecoderError::InvalidShape` for
/// non-normalized boxes — confirm against the kernel.
#[cfg(test)]
pub(crate) fn decode_yolo_segdet_quant<
    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    protos: (ArrayView3<PROTO>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    f32: AsPrimitive<BOX>,
{
    let detection_cap = cap_or_default(output_boxes);
    impl_yolo_segdet_quant::<XYWH, _, _>(
        boxes,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        MAX_NMS_CANDIDATES,
        detection_cap,
        // This pre-Decoder wrapper has no schema-derived `normalized` or
        // `input_dims`, so the EDGEAI-1303 normalization is a no-op here.
        // Callers that need that path should go through `DecoderBuilder`
        // with a schema.
        None,
        None,
        output_boxes,
        output_masks,
    )
}
279
/// Test-only shim around the float seg-det kernel; the float sibling of
/// [`decode_yolo_segdet_quant`].
///
/// Fixes the box encoding to XYWH and forwards to `impl_yolo_segdet_float`
/// with [`MAX_NMS_CANDIDATES`] as the pre-NMS candidate cap. The detection
/// cap is `output_boxes.capacity()`, or [`DEFAULT_MAX_DETECTIONS`] when that
/// is 0.
///
/// # Errors
/// Propagates whatever `DecoderError` `impl_yolo_segdet_float` returns.
#[cfg(test)]
pub(crate) fn decode_yolo_segdet_float<T>(
    boxes: ArrayView2<T>,
    protos: ArrayView3<T>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    let detection_cap = cap_or_default(output_boxes);
    impl_yolo_segdet_float::<XYWH, _, _>(
        boxes,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        MAX_NMS_CANDIDATES,
        detection_cap,
        // Pre-Decoder convenience wrapper: schema-derived normalization
        // (EDGEAI-1303) is not available here, so both schema inputs are
        // `None`.
        None,
        None,
        output_boxes,
        output_masks,
    )
}
312
313/// Decodes YOLO split detection outputs from quantized tensors into detection
314/// boxes.
315///
316/// Boxes are expected to be in XYWH format.
317///
318/// Expected shapes of inputs:
319/// - boxes: (4, num_boxes)
320/// - scores: (num_classes, num_boxes)
321///
322/// # Panics
323/// Panics if shapes don't match the expected dimensions.
324pub(crate) fn decode_yolo_split_det_quant<
325    BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
326    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
327>(
328    boxes: (ArrayView2<BOX>, Quantization),
329    scores: (ArrayView2<SCORE>, Quantization),
330    score_threshold: f32,
331    iou_threshold: f32,
332    nms: Option<Nms>,
333    output_boxes: &mut Vec<DetectBox>,
334) where
335    f32: AsPrimitive<SCORE>,
336{
337    impl_yolo_split_quant::<XYWH, _, _>(
338        boxes,
339        scores,
340        score_threshold,
341        iou_threshold,
342        nms,
343        output_boxes,
344    );
345}
346
347/// Decodes YOLO split detection outputs from float tensors into detection
348/// boxes.
349///
350/// Boxes are expected to be in XYWH format.
351///
352/// Expected shapes of inputs:
353/// - boxes: (4, num_boxes)
354/// - scores: (num_classes, num_boxes)
355///
356/// # Panics
357/// Panics if shapes don't match the expected dimensions.
358pub(crate) fn decode_yolo_split_det_float<T>(
359    boxes: ArrayView2<T>,
360    scores: ArrayView2<T>,
361    score_threshold: f32,
362    iou_threshold: f32,
363    nms: Option<Nms>,
364    output_boxes: &mut Vec<DetectBox>,
365) where
366    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
367    f32: AsPrimitive<T>,
368{
369    impl_yolo_split_float::<XYWH, _, _>(
370        boxes,
371        scores,
372        score_threshold,
373        iou_threshold,
374        nms,
375        output_boxes,
376    );
377}
378
379/// Decodes end-to-end YOLO detection outputs (post-NMS from model).
380/// Expects an array of shape `(6, N)`, where the first dimension (rows)
381/// corresponds to the 6 per-detection features
382/// `[x1, y1, x2, y2, conf, class]` and the second dimension (columns)
383/// indexes the `N` detections.
384/// Boxes are output directly without NMS (the model already applied NMS).
385///
386/// Coordinates may be normalized `[0, 1]` or absolute pixel values depending
387/// on the model configuration. The caller should check
388/// `decoder.normalized_boxes()` to determine which.
389///
390/// # Errors
391///
392/// Returns `DecoderError::InvalidShape` if `output` has fewer than 6 rows.
393pub(crate) fn decode_yolo_end_to_end_det_float<T>(
394    output: ArrayView2<T>,
395    score_threshold: f32,
396    output_boxes: &mut Vec<DetectBox>,
397) -> Result<(), crate::DecoderError>
398where
399    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
400    f32: AsPrimitive<T>,
401{
402    // Validate input shape: need at least 6 rows (x1, y1, x2, y2, conf, class)
403    if output.shape()[0] < 6 {
404        return Err(crate::DecoderError::InvalidShape(format!(
405            "End-to-end detection output requires at least 6 rows, got {}",
406            output.shape()[0]
407        )));
408    }
409
410    // Input shape: (6, N) -> transpose to (N, 4) for boxes and (N, 1) for scores
411    let boxes = output.slice(s![0..4, ..]).reversed_axes();
412    let scores = output.slice(s![4..5, ..]).reversed_axes();
413    let classes = output.slice(s![5, ..]);
414    let mut boxes =
415        postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
416    boxes.truncate(cap_or_default(output_boxes));
417    output_boxes.clear();
418    for (mut b, i) in boxes.into_iter() {
419        b.label = classes[i].as_() as usize;
420        output_boxes.push(b);
421    }
422    // No NMS — model output is already post-NMS
423    Ok(())
424}
425
426/// Decodes end-to-end YOLO detection + segmentation outputs (post-NMS from
427/// model).
428///
429/// Input shapes:
430/// - detection: (6 + num_protos, N) where rows are [x1, y1, x2, y2, conf,
431///   class, mask_coeff_0, ..., mask_coeff_31]
432/// - protos: (proto_height, proto_width, num_protos)
433///
434/// Boxes are output directly without NMS (model already applied NMS).
435/// Coordinates may be normalized [0,1] or pixel values depending on model
436/// config.
437///
438/// # Errors
439///
440/// Returns `DecoderError::InvalidShape` if:
441/// - output has fewer than 7 rows (6 base + at least 1 mask coefficient)
442/// - protos shape doesn't match mask coefficients count
443pub(crate) fn decode_yolo_end_to_end_segdet_float<T>(
444    output: ArrayView2<T>,
445    protos: ArrayView3<T>,
446    score_threshold: f32,
447    output_boxes: &mut Vec<DetectBox>,
448    output_masks: &mut Vec<crate::Segmentation>,
449) -> Result<(), crate::DecoderError>
450where
451    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
452    f32: AsPrimitive<T>,
453{
454    let (boxes, scores, classes, mask_coeff) =
455        postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
456    let cap = cap_or_default(output_boxes);
457    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
458        boxes,
459        scores,
460        classes,
461        score_threshold,
462        cap,
463    );
464
465    // No NMS — model output is already post-NMS
466
467    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
468}
469
470/// Decodes split end-to-end YOLO detection outputs (post-NMS from model).
471///
472/// Input shapes (after batch dim removed):
473/// - boxes: (4, N) — xyxy pixel coordinates
474/// - scores: (1, N) — confidence of the top class
475/// - classes: (1, N) — class index of the top class
476///
477/// Boxes are output directly without NMS (model already applied NMS).
478pub(crate) fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
479    boxes: ArrayView2<T>,
480    scores: ArrayView2<T>,
481    classes: ArrayView2<T>,
482    score_threshold: f32,
483    output_boxes: &mut Vec<DetectBox>,
484) -> Result<(), crate::DecoderError> {
485    let n = boxes.shape()[1];
486
487    let cap = cap_or_default(output_boxes);
488    output_boxes.clear();
489
490    let (boxes, scores, classes) = postprocess_yolo_split_end_to_end_det(boxes, scores, &classes)?;
491
492    for i in 0..n {
493        let score: f32 = scores[[i, 0]].as_();
494        if score < score_threshold {
495            continue;
496        }
497        if output_boxes.len() >= cap {
498            break;
499        }
500        output_boxes.push(DetectBox {
501            bbox: BoundingBox {
502                xmin: boxes[[i, 0]].as_(),
503                ymin: boxes[[i, 1]].as_(),
504                xmax: boxes[[i, 2]].as_(),
505                ymax: boxes[[i, 3]].as_(),
506            },
507            score,
508            label: classes[i].as_() as usize,
509        });
510    }
511    Ok(())
512}
513
514/// Decodes split end-to-end YOLO detection + segmentation outputs.
515///
516/// Input shapes (after batch dim removed):
517/// - boxes: (4, N) — xyxy pixel coordinates
518/// - scores: (1, N) — confidence
519/// - classes: (1, N) — class index
520/// - mask_coeff: (num_protos, N) — mask coefficients per detection
521/// - protos: (proto_h, proto_w, num_protos) — prototype masks
522#[allow(clippy::too_many_arguments)]
523pub(crate) fn decode_yolo_split_end_to_end_segdet_float<T>(
524    boxes: ArrayView2<T>,
525    scores: ArrayView2<T>,
526    classes: ArrayView2<T>,
527    mask_coeff: ArrayView2<T>,
528    protos: ArrayView3<T>,
529    score_threshold: f32,
530    output_boxes: &mut Vec<DetectBox>,
531    output_masks: &mut Vec<crate::Segmentation>,
532) -> Result<(), crate::DecoderError>
533where
534    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
535    f32: AsPrimitive<T>,
536{
537    let (boxes, scores, classes, mask_coeff) =
538        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
539    let cap = cap_or_default(output_boxes);
540    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
541        boxes,
542        scores,
543        classes,
544        score_threshold,
545        cap,
546    );
547
548    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
549}
550
551#[allow(clippy::type_complexity)]
552pub(crate) fn postprocess_yolo_end_to_end_segdet<'a, T>(
553    output: &'a ArrayView2<'_, T>,
554    num_protos: usize,
555) -> Result<
556    (
557        ArrayView2<'a, T>,
558        ArrayView2<'a, T>,
559        ArrayView1<'a, T>,
560        ArrayView2<'a, T>,
561    ),
562    crate::DecoderError,
563> {
564    // Validate input shape: need at least 7 rows (6 base + at least 1 mask coeff)
565    if output.shape()[0] < 7 {
566        return Err(crate::DecoderError::InvalidShape(format!(
567            "End-to-end segdet output requires at least 7 rows, got {}",
568            output.shape()[0]
569        )));
570    }
571
572    let num_mask_coeffs = output.shape()[0] - 6;
573    if num_mask_coeffs != num_protos {
574        return Err(crate::DecoderError::InvalidShape(format!(
575            "Mask coefficients count ({}) doesn't match protos count ({})",
576            num_mask_coeffs, num_protos
577        )));
578    }
579
580    // Input shape: (6+num_protos, N) -> transpose for postprocessing
581    let boxes = output.slice(s![0..4, ..]).reversed_axes();
582    let scores = output.slice(s![4..5, ..]).reversed_axes();
583    let classes = output.slice(s![5, ..]);
584    let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
585    Ok((boxes, scores, classes, mask_coeff))
586}
587
588/// Postprocess yolo split end to end det by reversing axes of boxes,
589/// scores, and flattening the class tensor.
590/// Expected input shapes:
591/// - boxes: (4, N)
592/// - scores: (1, N)
593/// - classes: (1, N)
594#[allow(clippy::type_complexity)]
595pub(crate) fn postprocess_yolo_split_end_to_end_det<'a, 'b, 'c, BOXES, SCORES, CLASS>(
596    boxes: ArrayView2<'a, BOXES>,
597    scores: ArrayView2<'b, SCORES>,
598    classes: &'c ArrayView2<CLASS>,
599) -> Result<
600    (
601        ArrayView2<'a, BOXES>,
602        ArrayView2<'b, SCORES>,
603        ArrayView1<'c, CLASS>,
604    ),
605    crate::DecoderError,
606> {
607    let num_boxes = boxes.shape()[1];
608    if boxes.shape()[0] != 4 {
609        return Err(crate::DecoderError::InvalidShape(format!(
610            "Split end-to-end box_coords must be 4, got {}",
611            boxes.shape()[0]
612        )));
613    }
614
615    if scores.shape()[0] != 1 {
616        return Err(crate::DecoderError::InvalidShape(format!(
617            "Split end-to-end scores num_classes must be 1, got {}",
618            scores.shape()[0]
619        )));
620    }
621
622    if classes.shape()[0] != 1 {
623        return Err(crate::DecoderError::InvalidShape(format!(
624            "Split end-to-end classes num_classes must be 1, got {}",
625            classes.shape()[0]
626        )));
627    }
628
629    if scores.shape()[1] != num_boxes {
630        return Err(crate::DecoderError::InvalidShape(format!(
631            "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
632            num_boxes,
633            scores.shape()[1]
634        )));
635    }
636
637    if classes.shape()[1] != num_boxes {
638        return Err(crate::DecoderError::InvalidShape(format!(
639            "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
640            num_boxes,
641            classes.shape()[1]
642        )));
643    }
644
645    let boxes = boxes.reversed_axes();
646    let scores = scores.reversed_axes();
647    let classes = classes.slice(s![0, ..]);
648    Ok((boxes, scores, classes))
649}
650
651/// Postprocess yolo split end to end segdet by reversing axes of boxes,
652/// scores, mask tensors and flattening the class tensor.
653#[allow(clippy::type_complexity)]
654pub(crate) fn postprocess_yolo_split_end_to_end_segdet<
655    'a,
656    'b,
657    'c,
658    'd,
659    BOXES,
660    SCORES,
661    CLASS,
662    MASK,
663>(
664    boxes: ArrayView2<'a, BOXES>,
665    scores: ArrayView2<'b, SCORES>,
666    classes: &'c ArrayView2<CLASS>,
667    mask_coeff: ArrayView2<'d, MASK>,
668) -> Result<
669    (
670        ArrayView2<'a, BOXES>,
671        ArrayView2<'b, SCORES>,
672        ArrayView1<'c, CLASS>,
673        ArrayView2<'d, MASK>,
674    ),
675    crate::DecoderError,
676> {
677    let num_boxes = boxes.shape()[1];
678    if boxes.shape()[0] != 4 {
679        return Err(crate::DecoderError::InvalidShape(format!(
680            "Split end-to-end box_coords must be 4, got {}",
681            boxes.shape()[0]
682        )));
683    }
684
685    if scores.shape()[0] != 1 {
686        return Err(crate::DecoderError::InvalidShape(format!(
687            "Split end-to-end scores num_classes must be 1, got {}",
688            scores.shape()[0]
689        )));
690    }
691
692    if classes.shape()[0] != 1 {
693        return Err(crate::DecoderError::InvalidShape(format!(
694            "Split end-to-end classes num_classes must be 1, got {}",
695            classes.shape()[0]
696        )));
697    }
698
699    if scores.shape()[1] != num_boxes {
700        return Err(crate::DecoderError::InvalidShape(format!(
701            "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
702            num_boxes,
703            scores.shape()[1]
704        )));
705    }
706
707    if classes.shape()[1] != num_boxes {
708        return Err(crate::DecoderError::InvalidShape(format!(
709            "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
710            num_boxes,
711            classes.shape()[1]
712        )));
713    }
714
715    if mask_coeff.shape()[1] != num_boxes {
716        return Err(crate::DecoderError::InvalidShape(format!(
717            "Split end-to-end mask_coeff must have same num_boxes as boxes ({}), got {}",
718            num_boxes,
719            mask_coeff.shape()[1]
720        )));
721    }
722
723    let boxes = boxes.reversed_axes();
724    let scores = scores.reversed_axes();
725    let classes = classes.slice(s![0, ..]);
726    let mask_coeff = mask_coeff.reversed_axes();
727    Ok((boxes, scores, classes, mask_coeff))
728}
729/// Internal implementation of YOLO decoding for quantized tensors.
730///
731/// Expected shapes of inputs:
732/// - output: (4 + num_classes, num_boxes)
733pub(crate) fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
734    output: (ArrayView2<T>, Quantization),
735    score_threshold: f32,
736    iou_threshold: f32,
737    nms: Option<Nms>,
738    output_boxes: &mut Vec<DetectBox>,
739) where
740    f32: AsPrimitive<T>,
741{
742    let _span = tracing::trace_span!("decode", mode = "quant_det").entered();
743    let (boxes, quant_boxes) = output;
744    let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
745
746    let boxes = {
747        let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
748        postprocess_boxes_quant::<B, _, _>(
749            score_threshold,
750            boxes_tensor,
751            scores_tensor,
752            quant_boxes,
753        )
754    };
755
756    let cap = cap_or_default(output_boxes);
757    let boxes = dispatch_nms_int(nms, iou_threshold, Some(cap), boxes);
758    // Detection cap convention (see `cap_or_default`). NMS already capped to
759    // `cap`; the `min` here is a redundant guard for non-NMS bypass mode.
760    let len = cap.min(boxes.len());
761    output_boxes.clear();
762    for b in boxes.iter().take(len) {
763        output_boxes.push(dequant_detect_box(b, quant_boxes));
764    }
765}
766
767/// Internal implementation of YOLO decoding for float tensors.
768///
769/// Expected shapes of inputs:
770/// - output: (4 + num_classes, num_boxes)
771pub(crate) fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
772    output: ArrayView2<T>,
773    score_threshold: f32,
774    iou_threshold: f32,
775    nms: Option<Nms>,
776    output_boxes: &mut Vec<DetectBox>,
777) where
778    f32: AsPrimitive<T>,
779{
780    let _span = tracing::trace_span!("decode", mode = "float_det").entered();
781    let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
782    let boxes =
783        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
784    let cap = cap_or_default(output_boxes);
785    let boxes = dispatch_nms_float(nms, iou_threshold, Some(cap), boxes);
786    // Detection cap convention (see `cap_or_default`). NMS already capped to
787    // `cap`; the `min` here is a redundant guard for non-NMS bypass mode.
788    let len = cap.min(boxes.len());
789    output_boxes.clear();
790    for b in boxes.into_iter().take(len) {
791        output_boxes.push(b);
792    }
793}
794
795/// Internal implementation of YOLO split detection decoding for quantized
796/// tensors.
797///
798/// Expected shapes of inputs:
799/// - boxes: (4, num_boxes)
800/// - scores: (num_classes, num_boxes)
801///
802/// # Panics
803/// Panics if shapes don't match the expected dimensions.
804pub(crate) fn impl_yolo_split_quant<
805    B: BBoxTypeTrait,
806    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
807    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
808>(
809    boxes: (ArrayView2<BOX>, Quantization),
810    scores: (ArrayView2<SCORE>, Quantization),
811    score_threshold: f32,
812    iou_threshold: f32,
813    nms: Option<Nms>,
814    output_boxes: &mut Vec<DetectBox>,
815) where
816    f32: AsPrimitive<SCORE>,
817{
818    let _span = tracing::trace_span!("decode", mode = "split_quant_det").entered();
819    let (boxes_tensor, quant_boxes) = boxes;
820    let (scores_tensor, quant_scores) = scores;
821
822    let boxes_tensor = boxes_tensor.reversed_axes();
823    let scores_tensor = scores_tensor.reversed_axes();
824
825    let boxes = {
826        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
827        postprocess_boxes_quant::<B, _, _>(
828            score_threshold,
829            boxes_tensor,
830            scores_tensor,
831            quant_boxes,
832        )
833    };
834
835    let cap = cap_or_default(output_boxes);
836    let boxes = dispatch_nms_int(nms, iou_threshold, Some(cap), boxes);
837    // Detection cap convention (see `cap_or_default`). NMS already capped to
838    // `cap`; the `min` here is a redundant guard for non-NMS bypass mode.
839    let len = cap.min(boxes.len());
840    output_boxes.clear();
841    for b in boxes.iter().take(len) {
842        output_boxes.push(dequant_detect_box(b, quant_scores));
843    }
844}
845
846/// Internal implementation of YOLO split detection decoding for float tensors.
847///
848/// Expected shapes of inputs:
849/// - boxes: (4, num_boxes)
850/// - scores: (num_classes, num_boxes)
851///
852/// # Panics
853/// Panics if shapes don't match the expected dimensions.
854pub(crate) fn impl_yolo_split_float<
855    B: BBoxTypeTrait,
856    BOX: Float + AsPrimitive<f32> + Send + Sync,
857    SCORE: Float + AsPrimitive<f32> + Send + Sync,
858>(
859    boxes_tensor: ArrayView2<BOX>,
860    scores_tensor: ArrayView2<SCORE>,
861    score_threshold: f32,
862    iou_threshold: f32,
863    nms: Option<Nms>,
864    output_boxes: &mut Vec<DetectBox>,
865) where
866    f32: AsPrimitive<SCORE>,
867{
868    let _span = tracing::trace_span!("decode", mode = "split_float_det").entered();
869    let boxes_tensor = boxes_tensor.reversed_axes();
870    let scores_tensor = scores_tensor.reversed_axes();
871    let boxes =
872        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
873    let cap = cap_or_default(output_boxes);
874    let boxes = dispatch_nms_float(nms, iou_threshold, Some(cap), boxes);
875    // Detection cap convention (see `cap_or_default`). NMS already capped to
876    // `cap`; the `min` here is a redundant guard for non-NMS bypass mode.
877    let len = cap.min(boxes.len());
878    output_boxes.clear();
879    for b in boxes.into_iter().take(len) {
880        output_boxes.push(b);
881    }
882}
883
884/// Divide each survivor's bbox by `(input_w, input_h)` when the schema
885/// declares `normalized: false`. Pixel-space box coords from the model
886/// are pulled into the canonical `[0, 1]` range expected by `protobox`
887/// and downstream callers — see EDGEAI-1303.
888///
889/// No-op when `normalized` is `Some(true)` / `None`, when `input_dims`
890/// is `None`, or when there are no survivors.
891#[inline]
892pub(crate) fn maybe_normalize_boxes_in_place(
893    boxes: &mut [(DetectBox, usize)],
894    normalized: Option<bool>,
895    input_dims: Option<(usize, usize)>,
896) {
897    if normalized != Some(false) {
898        return;
899    }
900    let Some((w, h)) = input_dims else {
901        return;
902    };
903    if w == 0 || h == 0 {
904        return;
905    }
906    let inv_w = 1.0 / w as f32;
907    let inv_h = 1.0 / h as f32;
908    for (b, _) in boxes.iter_mut() {
909        b.bbox.xmin *= inv_w;
910        b.bbox.ymin *= inv_h;
911        b.bbox.xmax *= inv_w;
912        b.bbox.ymax *= inv_h;
913    }
914}
915
916/// Internal implementation of YOLO detection segmentation decoding for
917/// quantized tensors.
918///
919/// Expected shapes of inputs:
920/// - boxes: (4 + num_classes + num_protos, num_boxes)
921/// - protos: (proto_height, proto_width, num_protos)
922///
923/// # Errors
924/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
925#[allow(clippy::too_many_arguments)]
926pub(crate) fn impl_yolo_segdet_quant<
927    B: BBoxTypeTrait,
928    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
929    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
930>(
931    boxes: (ArrayView2<BOX>, Quantization),
932    protos: (ArrayView3<PROTO>, Quantization),
933    score_threshold: f32,
934    iou_threshold: f32,
935    nms: Option<Nms>,
936    pre_nms_top_k: usize,
937    max_det: usize,
938    normalized: Option<bool>,
939    input_dims: Option<(usize, usize)>,
940    output_boxes: &mut Vec<DetectBox>,
941    output_masks: &mut Vec<Segmentation>,
942) -> Result<(), crate::DecoderError>
943where
944    f32: AsPrimitive<BOX>,
945{
946    let (boxes, quant_boxes) = boxes;
947    let num_protos = protos.0.dim().2;
948
949    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
950    let mut boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
951        (boxes_tensor, quant_boxes),
952        (scores_tensor, quant_boxes),
953        score_threshold,
954        iou_threshold,
955        nms,
956        pre_nms_top_k,
957        max_det,
958    );
959    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
960
961    impl_yolo_split_segdet_quant_process_masks::<_, _>(
962        boxes,
963        (mask_tensor, quant_boxes),
964        protos,
965        output_boxes,
966        output_masks,
967    )
968}
969
970/// Internal implementation of YOLO detection segmentation decoding for
971/// float tensors.
972///
973/// Expected shapes of inputs:
974/// - boxes: (4 + num_classes + num_protos, num_boxes)
975/// - protos: (proto_height, proto_width, num_protos)
976///
977/// # Panics
978/// Panics if shapes don't match the expected dimensions.
979#[allow(clippy::too_many_arguments)]
980pub(crate) fn impl_yolo_segdet_float<
981    B: BBoxTypeTrait,
982    BOX: Float + AsPrimitive<f32> + Send + Sync,
983    PROTO: Float + AsPrimitive<f32> + Send + Sync,
984>(
985    boxes: ArrayView2<BOX>,
986    protos: ArrayView3<PROTO>,
987    score_threshold: f32,
988    iou_threshold: f32,
989    nms: Option<Nms>,
990    pre_nms_top_k: usize,
991    max_det: usize,
992    normalized: Option<bool>,
993    input_dims: Option<(usize, usize)>,
994    output_boxes: &mut Vec<DetectBox>,
995    output_masks: &mut Vec<Segmentation>,
996) -> Result<(), crate::DecoderError>
997where
998    f32: AsPrimitive<BOX>,
999{
1000    let num_protos = protos.dim().2;
1001    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1002    let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1003        boxes_tensor,
1004        scores_tensor,
1005        score_threshold,
1006        iou_threshold,
1007        nms,
1008        pre_nms_top_k,
1009        max_det,
1010    );
1011    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1012    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1013}
1014
1015pub(crate) fn impl_yolo_segdet_get_boxes<
1016    B: BBoxTypeTrait,
1017    BOX: Float + AsPrimitive<f32> + Send + Sync,
1018    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1019>(
1020    boxes_tensor: ArrayView2<BOX>,
1021    scores_tensor: ArrayView2<SCORE>,
1022    score_threshold: f32,
1023    iou_threshold: f32,
1024    nms: Option<Nms>,
1025    pre_nms_top_k: usize,
1026    max_det: usize,
1027) -> Vec<(DetectBox, usize)>
1028where
1029    f32: AsPrimitive<SCORE>,
1030{
1031    let span = tracing::trace_span!(
1032        "decode",
1033        n_candidates = tracing::field::Empty,
1034        n_after_topk = tracing::field::Empty,
1035        n_after_nms = tracing::field::Empty,
1036        n_detections = tracing::field::Empty,
1037    );
1038    let _guard = span.enter();
1039
1040    let mut boxes = {
1041        let _s = tracing::trace_span!("score_filter").entered();
1042        postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor)
1043    };
1044    span.record("n_candidates", boxes.len());
1045
1046    if nms.is_some() {
1047        let _s = tracing::trace_span!("top_k", k = pre_nms_top_k).entered();
1048        truncate_to_top_k_by_score(&mut boxes, pre_nms_top_k);
1049    }
1050    span.record("n_after_topk", boxes.len());
1051
1052    let mut boxes = {
1053        let _s = tracing::trace_span!("nms").entered();
1054        dispatch_nms_extra_float(nms, iou_threshold, Some(max_det), boxes)
1055    };
1056    span.record("n_after_nms", boxes.len());
1057
1058    // NMS already capped to `max_det`; the trailing sort+truncate is a
1059    // redundant guard for the bypass-NMS path (`nms = None`).
1060    boxes.sort_unstable_by(|a, b| b.0.score.total_cmp(&a.0.score));
1061    boxes.truncate(max_det);
1062    span.record("n_detections", boxes.len());
1063
1064    boxes
1065}
1066
1067pub(crate) fn impl_yolo_end_to_end_segdet_get_boxes<
1068    B: BBoxTypeTrait,
1069    BOX: Float + AsPrimitive<f32> + Send + Sync,
1070    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1071    CLASS: AsPrimitive<f32> + Send + Sync,
1072>(
1073    boxes: ArrayView2<BOX>,
1074    scores: ArrayView2<SCORE>,
1075    classes: ArrayView1<CLASS>,
1076    score_threshold: f32,
1077    max_boxes: usize,
1078) -> Vec<(DetectBox, usize)>
1079where
1080    f32: AsPrimitive<SCORE>,
1081{
1082    let mut boxes = postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes, scores);
1083    boxes.truncate(max_boxes);
1084    for (b, ind) in &mut boxes {
1085        b.label = classes[*ind].as_().round() as usize;
1086    }
1087    boxes
1088}
1089
1090pub(crate) fn impl_yolo_split_segdet_process_masks<
1091    MASK: Float + AsPrimitive<f32> + Send + Sync,
1092    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1093>(
1094    boxes: Vec<(DetectBox, usize)>,
1095    masks_tensor: ArrayView2<MASK>,
1096    protos_tensor: ArrayView3<PROTO>,
1097    output_boxes: &mut Vec<DetectBox>,
1098    output_masks: &mut Vec<Segmentation>,
1099) -> Result<(), crate::DecoderError> {
1100    let _span = tracing::trace_span!("process_masks", n = boxes.len(), mode = "float").entered();
1101    // Boxes are already bounded by the upstream `max_det` cap from
1102    // `_get_boxes`; no second cap is needed here (EDGEAI-1302).
1103
1104    let boxes = decode_segdet_f32(boxes, masks_tensor, protos_tensor)?;
1105    output_boxes.clear();
1106    output_masks.clear();
1107    for (b, roi, m) in boxes.into_iter() {
1108        output_boxes.push(b);
1109        output_masks.push(Segmentation {
1110            xmin: roi.xmin,
1111            ymin: roi.ymin,
1112            xmax: roi.xmax,
1113            ymax: roi.ymax,
1114            segmentation: m,
1115        });
1116    }
1117    Ok(())
1118}
1119/// Expected input shapes:
1120/// - boxes_tensor: (num_boxes, 4)
1121/// - scores_tensor: (num_boxes, num_classes)
1122pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
1123    B: BBoxTypeTrait,
1124    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1125    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1126>(
1127    boxes: (ArrayView2<BOX>, Quantization),
1128    scores: (ArrayView2<SCORE>, Quantization),
1129    score_threshold: f32,
1130    iou_threshold: f32,
1131    nms: Option<Nms>,
1132    pre_nms_top_k: usize,
1133    max_det: usize,
1134) -> Vec<(DetectBox, usize)>
1135where
1136    f32: AsPrimitive<SCORE>,
1137{
1138    let (boxes_tensor, quant_boxes) = boxes;
1139    let (scores_tensor, quant_scores) = scores;
1140
1141    let span = tracing::trace_span!(
1142        "decode",
1143        n_candidates = tracing::field::Empty,
1144        n_after_topk = tracing::field::Empty,
1145        n_after_nms = tracing::field::Empty,
1146        n_detections = tracing::field::Empty,
1147    );
1148    let _guard = span.enter();
1149
1150    let mut boxes = {
1151        let _s = tracing::trace_span!("score_filter").entered();
1152        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
1153        postprocess_boxes_index_quant::<B, _, _>(
1154            score_threshold,
1155            boxes_tensor,
1156            scores_tensor,
1157            quant_boxes,
1158        )
1159    };
1160    span.record("n_candidates", boxes.len());
1161
1162    if nms.is_some() {
1163        let _s = tracing::trace_span!("top_k", k = pre_nms_top_k).entered();
1164        truncate_to_top_k_by_score_quant(&mut boxes, pre_nms_top_k);
1165    }
1166    span.record("n_after_topk", boxes.len());
1167
1168    let mut boxes = {
1169        let _s = tracing::trace_span!("nms").entered();
1170        dispatch_nms_extra_int(nms, iou_threshold, Some(max_det), boxes)
1171    };
1172    span.record("n_after_nms", boxes.len());
1173
1174    // NMS already capped to `max_det`; the trailing sort+truncate is a
1175    // redundant guard for the bypass-NMS path (`nms = None`).
1176    boxes.sort_unstable_by_key(|b| std::cmp::Reverse(b.0.score));
1177    boxes.truncate(max_det);
1178    let result: Vec<_> = {
1179        let _s = tracing::trace_span!("box_dequant", n = boxes.len()).entered();
1180        boxes
1181            .into_iter()
1182            .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
1183            .collect()
1184    };
1185    span.record("n_detections", result.len());
1186
1187    result
1188}
1189
1190pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
1191    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1192    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1193>(
1194    boxes: Vec<(DetectBox, usize)>,
1195    mask_coeff: (ArrayView2<MASK>, Quantization),
1196    protos: (ArrayView3<PROTO>, Quantization),
1197    output_boxes: &mut Vec<DetectBox>,
1198    output_masks: &mut Vec<Segmentation>,
1199) -> Result<(), crate::DecoderError> {
1200    let _span = tracing::trace_span!("process_masks", n = boxes.len(), mode = "quant").entered();
1201    let (masks, quant_masks) = mask_coeff;
1202    let (protos, quant_protos) = protos;
1203
1204    // Boxes are already bounded by the upstream `max_det` cap from
1205    // `_get_boxes`; no second cap is needed here (EDGEAI-1302).
1206
1207    let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
1208    output_boxes.clear();
1209    output_masks.clear();
1210    for (b, roi, m) in boxes.into_iter() {
1211        output_boxes.push(b);
1212        output_masks.push(Segmentation {
1213            xmin: roi.xmin,
1214            ymin: roi.ymin,
1215            xmax: roi.xmax,
1216            ymax: roi.ymax,
1217            segmentation: m,
1218        });
1219    }
1220    Ok(())
1221}
1222
1223/// Internal implementation of YOLO split detection segmentation decoding for
1224/// float tensors.
1225///
1226/// Expected shapes of inputs:
1227/// - boxes_tensor: (4, num_boxes)
1228/// - scores_tensor: (num_classes, num_boxes)
1229/// - mask_tensor: (num_protos, num_boxes)
1230/// - protos: (proto_height, proto_width, num_protos)
1231///
1232/// # Errors
1233/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
1234#[allow(clippy::too_many_arguments)]
1235pub(crate) fn impl_yolo_split_segdet_float<
1236    B: BBoxTypeTrait,
1237    BOX: Float + AsPrimitive<f32> + Send + Sync,
1238    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1239    MASK: Float + AsPrimitive<f32> + Send + Sync,
1240    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1241>(
1242    boxes_tensor: ArrayView2<BOX>,
1243    scores_tensor: ArrayView2<SCORE>,
1244    mask_tensor: ArrayView2<MASK>,
1245    protos: ArrayView3<PROTO>,
1246    score_threshold: f32,
1247    iou_threshold: f32,
1248    nms: Option<Nms>,
1249    pre_nms_top_k: usize,
1250    max_det: usize,
1251    normalized: Option<bool>,
1252    input_dims: Option<(usize, usize)>,
1253    output_boxes: &mut Vec<DetectBox>,
1254    output_masks: &mut Vec<Segmentation>,
1255) -> Result<(), crate::DecoderError>
1256where
1257    f32: AsPrimitive<SCORE>,
1258{
1259    let (boxes_tensor, scores_tensor, mask_tensor) =
1260        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1261
1262    let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1263        boxes_tensor,
1264        scores_tensor,
1265        score_threshold,
1266        iou_threshold,
1267        nms,
1268        pre_nms_top_k,
1269        max_det,
1270    );
1271    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1272    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1273}
1274
1275// ---------------------------------------------------------------------------
1276// Proto-extraction variants: return ProtoData instead of materialized masks
1277// ---------------------------------------------------------------------------
1278
1279/// Proto-extraction variant of `impl_yolo_segdet_quant`.
1280/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1281#[allow(clippy::too_many_arguments)]
1282pub(crate) fn impl_yolo_segdet_quant_proto<
1283    B: BBoxTypeTrait,
1284    BOX: PrimInt
1285        + AsPrimitive<i64>
1286        + AsPrimitive<i128>
1287        + AsPrimitive<f32>
1288        + AsPrimitive<i8>
1289        + Send
1290        + Sync,
1291    PROTO: PrimInt
1292        + AsPrimitive<i64>
1293        + AsPrimitive<i128>
1294        + AsPrimitive<f32>
1295        + AsPrimitive<i8>
1296        + Send
1297        + Sync,
1298>(
1299    boxes: (ArrayView2<BOX>, Quantization),
1300    protos: (ArrayView3<PROTO>, Quantization),
1301    score_threshold: f32,
1302    iou_threshold: f32,
1303    nms: Option<Nms>,
1304    pre_nms_top_k: usize,
1305    max_det: usize,
1306    normalized: Option<bool>,
1307    input_dims: Option<(usize, usize)>,
1308    output_boxes: &mut Vec<DetectBox>,
1309) -> ProtoData
1310where
1311    f32: AsPrimitive<BOX>,
1312{
1313    let (boxes_arr, quant_boxes) = boxes;
1314    let (protos_arr, quant_protos) = protos;
1315    let num_protos = protos_arr.dim().2;
1316
1317    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1318
1319    let mut det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1320        (boxes_tensor, quant_boxes),
1321        (scores_tensor, quant_boxes),
1322        score_threshold,
1323        iou_threshold,
1324        nms,
1325        pre_nms_top_k,
1326        max_det,
1327    );
1328    maybe_normalize_boxes_in_place(&mut det_indices, normalized, input_dims);
1329
1330    extract_proto_data_quant(
1331        det_indices,
1332        mask_tensor,
1333        quant_boxes,
1334        protos_arr,
1335        quant_protos,
1336        output_boxes,
1337    )
1338}
1339
1340/// Proto-extraction variant of `impl_yolo_segdet_float`.
1341/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1342#[allow(clippy::too_many_arguments)]
1343pub(crate) fn impl_yolo_segdet_float_proto<
1344    B: BBoxTypeTrait,
1345    BOX: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1346    PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1347>(
1348    boxes: ArrayView2<BOX>,
1349    protos: ArrayView3<PROTO>,
1350    score_threshold: f32,
1351    iou_threshold: f32,
1352    nms: Option<Nms>,
1353    pre_nms_top_k: usize,
1354    max_det: usize,
1355    normalized: Option<bool>,
1356    input_dims: Option<(usize, usize)>,
1357    output_boxes: &mut Vec<DetectBox>,
1358) -> ProtoData
1359where
1360    f32: AsPrimitive<BOX>,
1361{
1362    let num_protos = protos.dim().2;
1363    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1364
1365    let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1366        boxes_tensor,
1367        scores_tensor,
1368        score_threshold,
1369        iou_threshold,
1370        nms,
1371        pre_nms_top_k,
1372        max_det,
1373    );
1374    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1375
1376    extract_proto_data_float(boxes, mask_tensor, protos, output_boxes)
1377}
1378
1379/// Proto-extraction variant of `impl_yolo_split_segdet_float`.
1380/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1381#[allow(clippy::too_many_arguments)]
1382pub(crate) fn impl_yolo_split_segdet_float_proto<
1383    B: BBoxTypeTrait,
1384    BOX: Float + AsPrimitive<f32> + Send + Sync,
1385    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1386    MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1387    PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1388>(
1389    boxes_tensor: ArrayView2<BOX>,
1390    scores_tensor: ArrayView2<SCORE>,
1391    mask_tensor: ArrayView2<MASK>,
1392    protos: ArrayView3<PROTO>,
1393    score_threshold: f32,
1394    iou_threshold: f32,
1395    nms: Option<Nms>,
1396    pre_nms_top_k: usize,
1397    max_det: usize,
1398    output_boxes: &mut Vec<DetectBox>,
1399) -> ProtoData
1400where
1401    f32: AsPrimitive<SCORE>,
1402{
1403    let (boxes_tensor, scores_tensor, mask_tensor) =
1404        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1405    let det_indices = impl_yolo_segdet_get_boxes::<B, _, _>(
1406        boxes_tensor,
1407        scores_tensor,
1408        score_threshold,
1409        iou_threshold,
1410        nms,
1411        pre_nms_top_k,
1412        max_det,
1413    );
1414
1415    extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1416}
1417
1418/// Proto-extraction variant of `decode_yolo_end_to_end_segdet_float`.
1419pub(crate) fn decode_yolo_end_to_end_segdet_float_proto<T>(
1420    output: ArrayView2<T>,
1421    protos: ArrayView3<T>,
1422    score_threshold: f32,
1423    output_boxes: &mut Vec<DetectBox>,
1424) -> Result<ProtoData, crate::DecoderError>
1425where
1426    T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1427    f32: AsPrimitive<T>,
1428{
1429    let (boxes, scores, classes, mask_coeff) =
1430        postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
1431    let cap = cap_or_default(output_boxes);
1432    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1433        boxes,
1434        scores,
1435        classes,
1436        score_threshold,
1437        cap,
1438    );
1439
1440    Ok(extract_proto_data_float(
1441        boxes,
1442        mask_coeff,
1443        protos,
1444        output_boxes,
1445    ))
1446}
1447
1448/// Proto-extraction variant of `decode_yolo_split_end_to_end_segdet_float`.
1449#[allow(clippy::too_many_arguments)]
1450pub(crate) fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1451    boxes: ArrayView2<T>,
1452    scores: ArrayView2<T>,
1453    classes: ArrayView2<T>,
1454    mask_coeff: ArrayView2<T>,
1455    protos: ArrayView3<T>,
1456    score_threshold: f32,
1457    output_boxes: &mut Vec<DetectBox>,
1458) -> Result<ProtoData, crate::DecoderError>
1459where
1460    T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1461    f32: AsPrimitive<T>,
1462{
1463    let (boxes, scores, classes, mask_coeff) =
1464        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
1465    let cap = cap_or_default(output_boxes);
1466    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1467        boxes,
1468        scores,
1469        classes,
1470        score_threshold,
1471        cap,
1472    );
1473
1474    Ok(extract_proto_data_float(
1475        boxes,
1476        mask_coeff,
1477        protos,
1478        output_boxes,
1479    ))
1480}
1481
1482/// Helper: extract ProtoData from float mask coefficients + protos.
1483///
1484/// Builds [`ProtoData`] with both `protos` and `mask_coefficients` as
1485/// [`edgefirst_tensor::TensorDyn`]. Preserves the native element type for
1486/// `f16` and `f32`; narrows `f64` to `f32` (there is no native f64 kernel
1487/// path). `mask_coefficients` shape is `[num_detections, num_protos]`.
1488pub(super) fn extract_proto_data_float<
1489    MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1490    PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1491>(
1492    det_indices: Vec<(DetectBox, usize)>,
1493    mask_tensor: ArrayView2<MASK>,
1494    protos: ArrayView3<PROTO>,
1495    output_boxes: &mut Vec<DetectBox>,
1496) -> ProtoData {
1497    let _span = tracing::trace_span!(
1498        "extract_proto",
1499        n = det_indices.len(),
1500        num_protos = mask_tensor.ncols(),
1501        layout = "nhwc",
1502    )
1503    .entered();
1504
1505    let num_protos = mask_tensor.ncols();
1506    let n = det_indices.len();
1507
1508    // Per-detection coefficients packed row-major into a contiguous buffer,
1509    // preserving the source dtype. Shape: [N, num_protos] — N=0 is permitted
1510    // (tracker path emits no detections this frame) since Mem-backed tensors
1511    // accept zero-element shapes as "empty collection" sentinels.
1512    let mut coeff_rows: Vec<MASK> = Vec::with_capacity(n * num_protos);
1513    output_boxes.clear();
1514    for (det, idx) in det_indices {
1515        output_boxes.push(det);
1516        let row = mask_tensor.row(idx);
1517        coeff_rows.extend(row.iter().copied());
1518    }
1519
1520    let mask_coefficients = MASK::slice_into_tensor_dyn(&coeff_rows, &[n, num_protos])
1521        .expect("allocating mask_coefficients TensorDyn");
1522    let protos_tensor =
1523        PROTO::arrayview3_into_tensor_dyn(protos).expect("allocating protos TensorDyn");
1524
1525    ProtoData {
1526        mask_coefficients,
1527        protos: protos_tensor,
1528        layout: ProtoLayout::Nhwc,
1529    }
1530}
1531
1532/// Helper: extract ProtoData from quantized mask coefficients + protos.
1533///
1534/// Dequantizes mask coefficients to f32 at extraction (one-time cost on a
1535/// `num_detections * num_protos` slice) and keeps protos in raw i8,
1536/// attaching the dequantization params as
1537/// [`edgefirst_tensor::Quantization::per_tensor`] metadata on the proto
1538/// tensor. The GPU shader / CPU kernel reads `protos.quantization()` and
1539/// dequantizes per-texel.
1540pub(crate) fn extract_proto_data_quant<
1541    MASK: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
1542    PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
1543>(
1544    det_indices: Vec<(DetectBox, usize)>,
1545    mask_tensor: ArrayView2<MASK>,
1546    quant_masks: Quantization,
1547    protos: ArrayView3<PROTO>,
1548    quant_protos: Quantization,
1549    output_boxes: &mut Vec<DetectBox>,
1550) -> ProtoData {
1551    use edgefirst_tensor::{Tensor, TensorDyn, TensorMapTrait, TensorMemory, TensorTrait};
1552
1553    let span = tracing::trace_span!(
1554        "extract_proto",
1555        n = det_indices.len(),
1556        num_protos = tracing::field::Empty,
1557        layout = tracing::field::Empty,
1558    );
1559    let _guard = span.enter();
1560
1561    let num_protos = mask_tensor.ncols();
1562    let n = det_indices.len();
1563    span.record("num_protos", num_protos);
1564
1565    // Fast path: when no detections survive NMS, skip the expensive proto
1566    // tensor copy (819KB for 160×160×32). Allocate with the correct shape
1567    // (preserving the documented ProtoData.protos shape contract) but skip
1568    // copying from the source tensor — the zeroed allocation is sufficient
1569    // since materialize_masks early-returns on empty detect slices.
1570    if n == 0 {
1571        output_boxes.clear();
1572        let (h, w, k) = protos.dim();
1573
1574        // Detect physical layout (same logic as the normal path).
1575        let (proto_shape, proto_layout) = if std::any::TypeId::of::<PROTO>()
1576            == std::any::TypeId::of::<i8>()
1577        {
1578            if protos.is_standard_layout() {
1579                (&[h, w, k][..], ProtoLayout::Nhwc)
1580            } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1581                (&[k, h, w][..], ProtoLayout::Nchw)
1582            } else {
1583                (&[h, w, k][..], ProtoLayout::Nhwc)
1584            }
1585        } else {
1586            (&[h, w, k][..], ProtoLayout::Nhwc)
1587        };
1588
1589        let coeff_tensor = Tensor::<i8>::new(&[0, num_protos], Some(TensorMemory::Mem), None)
1590            .expect("allocating empty mask_coefficients tensor");
1591        let coeff_quant =
1592            edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1593        let coeff_tensor = coeff_tensor
1594            .with_quantization(coeff_quant)
1595            .expect("per-tensor quantization on mask coefficients");
1596        let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
1597            .expect("allocating protos tensor");
1598        let tensor_quant =
1599            edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1600        let protos_tensor = protos_tensor
1601            .with_quantization(tensor_quant)
1602            .expect("per-tensor quantization on protos tensor");
1603        return ProtoData {
1604            mask_coefficients: TensorDyn::I8(coeff_tensor),
1605            protos: TensorDyn::I8(protos_tensor),
1606            layout: proto_layout,
1607        };
1608    }
1609
1610    // Mask coefficients: keep i8 losslessly when MASK == i8 (preserves
1611    // the fast i8×i8→i32 integer kernel in materialize_masks). Preserve
1612    // i16 natively so the downstream i16×i8 integer path can avoid lossy
1613    // truncation. Other wider types (u16, …) dequantize to f32 at
1614    // extraction because the downstream mask kernels accept F32 natively.
1615    let mask_coefficients: TensorDyn = if std::any::TypeId::of::<MASK>()
1616        == std::any::TypeId::of::<i8>()
1617    {
1618        let mut coeff_i8 = Vec::<i8>::with_capacity(n * num_protos);
1619        output_boxes.clear();
1620        for (det, idx) in det_indices {
1621            output_boxes.push(det);
1622            let row = mask_tensor.row(idx);
1623            coeff_i8.extend(row.iter().map(|v| {
1624                let v_i8: i8 = v.as_();
1625                v_i8
1626            }));
1627        }
1628        let coeff_tensor = Tensor::<i8>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
1629            .expect("allocating mask_coefficients tensor");
1630        if n > 0 {
1631            let mut m = coeff_tensor
1632                .map()
1633                .expect("mapping mask_coefficients tensor");
1634            m.as_mut_slice().copy_from_slice(&coeff_i8);
1635        }
1636        let coeff_quant =
1637            edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1638        let coeff_tensor = coeff_tensor
1639            .with_quantization(coeff_quant)
1640            .expect("per-tensor quantization on mask coefficients");
1641        TensorDyn::I8(coeff_tensor)
1642    } else if std::any::TypeId::of::<MASK>() == std::any::TypeId::of::<i16>() {
1643        // i16 path: preserve natively for the fast i16×i8→i32 integer kernel.
1644        // f32 has 24-bit mantissa, so all i16 values are exactly representable.
1645        let mut coeff_i16 = Vec::<i16>::with_capacity(n * num_protos);
1646        output_boxes.clear();
1647        for (det, idx) in det_indices {
1648            output_boxes.push(det);
1649            let row = mask_tensor.row(idx);
1650            coeff_i16.extend(row.iter().map(|v| {
1651                let v_f32: f32 = v.as_();
1652                v_f32 as i16
1653            }));
1654        }
1655        let coeff_tensor = Tensor::<i16>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
1656            .expect("allocating mask_coefficients tensor");
1657        if n > 0 {
1658            let mut m = coeff_tensor
1659                .map()
1660                .expect("mapping mask_coefficients tensor");
1661            m.as_mut_slice().copy_from_slice(&coeff_i16);
1662        }
1663        let coeff_quant =
1664            edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1665        let coeff_tensor = coeff_tensor
1666            .with_quantization(coeff_quant)
1667            .expect("per-tensor quantization on mask coefficients");
1668        TensorDyn::I16(coeff_tensor)
1669    } else {
1670        // Other types (u8, u16, etc.): dequantize to f32 to avoid lossy truncation.
1671        let scale = quant_masks.scale;
1672        let zp = quant_masks.zero_point as f32;
1673        let mut coeff_f32 = Vec::<f32>::with_capacity(n * num_protos);
1674        output_boxes.clear();
1675        for (det, idx) in det_indices {
1676            output_boxes.push(det);
1677            let row = mask_tensor.row(idx);
1678            coeff_f32.extend(row.iter().map(|v| {
1679                let v_f32: f32 = v.as_();
1680                (v_f32 - zp) * scale
1681            }));
1682        }
1683        let coeff_tensor = Tensor::<f32>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
1684            .expect("allocating mask_coefficients tensor");
1685        if n > 0 {
1686            let mut m = coeff_tensor
1687                .map()
1688                .expect("mapping mask_coefficients tensor");
1689            m.as_mut_slice().copy_from_slice(&coeff_f32);
1690        }
1691        TensorDyn::F32(coeff_tensor)
1692    };
1693
1694    // Keep protos in raw i8 — consumers dequantize via protos.quantization().
1695    // When PROTO is already i8, detect layout and copy efficiently without
1696    // transposing. The mask materialisation kernels dispatch on the layout.
1697    let (h, w, k) = protos.dim();
1698
1699    // Determine physical layout and copy strategy.
1700    let (proto_shape, proto_layout) =
1701        if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1702            if protos.is_standard_layout() {
1703                // Already NHWC [H, W, K] in contiguous memory.
1704                (&[h, w, k][..], ProtoLayout::Nhwc)
1705            } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1706                // NCHW reinterpreted as NHWC via stride swap. Physical storage
1707                // is [K, H, W] contiguous. Keep in NCHW — eliminates the costly
1708                // 3.1ms transpose entirely.
1709                (&[k, h, w][..], ProtoLayout::Nchw)
1710            } else {
1711                // Unknown layout — fall back to iter copy as NHWC.
1712                (&[h, w, k][..], ProtoLayout::Nhwc)
1713            }
1714        } else {
1715            (&[h, w, k][..], ProtoLayout::Nhwc)
1716        };
1717
1718    let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
1719        .expect("allocating protos tensor");
1720    {
1721        let mut m = protos_tensor.map().expect("mapping protos tensor");
1722        let dst = m.as_mut_slice();
1723        if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1724            // SAFETY: PROTO == i8 checked via TypeId; cast slice view is
1725            // size/alignment-compatible by construction.
1726            if protos.is_standard_layout() {
1727                let src: &[i8] = unsafe {
1728                    std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len())
1729                };
1730                dst.copy_from_slice(src);
1731            } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1732                // NCHW physical layout — sequential copy WITHOUT transpose.
1733                // This saves ~3.1ms on A53/A55 by avoiding the tiled
1734                // NCHW→NHWC transpose of the 819KB proto buffer.
1735                let total = h * w * k;
1736                // SAFETY: ArrayView was constructed from a contiguous slice of
1737                // `total` elements. as_ptr() points to the base of that slice.
1738                let src: &[i8] =
1739                    unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, total) };
1740                dst.copy_from_slice(src);
1741            } else {
1742                for (d, s) in dst.iter_mut().zip(protos.iter()) {
1743                    let v_i8: i8 = s.as_();
1744                    *d = v_i8;
1745                }
1746            }
1747        } else {
1748            for (d, s) in dst.iter_mut().zip(protos.iter()) {
1749                let v_i8: i8 = s.as_();
1750                *d = v_i8;
1751            }
1752        }
1753    }
1754    let tensor_quant =
1755        edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1756    let protos_tensor = protos_tensor
1757        .with_quantization(tensor_quant)
1758        .expect("per-tensor quantization on new Tensor<i8>");
1759
1760    span.record("layout", tracing::field::debug(&proto_layout));
1761
1762    ProtoData {
1763        mask_coefficients,
1764        protos: TensorDyn::I8(protos_tensor),
1765        layout: proto_layout,
1766    }
1767}
1768
/// Per-float-dtype construction of a [`TensorDyn`] from a flat slice / 3-D
/// `ArrayView`. Replaces the old `IntoProtoTensor` trait. Each implementor
/// either passes its element type straight to `Tensor::from_slice` /
/// `Tensor::from_arrayview3`, or narrows `f64` to `f32` (no native f64 kernel
/// path exists).
pub trait FloatProtoElem: Copy + 'static {
    /// Builds a dynamically-typed tensor from a flat slice of `values` and the
    /// requested `shape`.
    ///
    /// The implementations in this file forward the result (including any
    /// error) of `edgefirst_tensor::Tensor::from_slice`.
    fn slice_into_tensor_dyn(
        values: &[Self],
        shape: &[usize],
    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;

    /// Builds a dynamically-typed tensor from a 3-D array view.
    ///
    /// The implementations in this file forward the result (including any
    /// error) of `edgefirst_tensor::Tensor::from_arrayview3`.
    fn arrayview3_into_tensor_dyn(
        view: ArrayView3<'_, Self>,
    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;
}
1784
1785impl FloatProtoElem for f32 {
1786    fn slice_into_tensor_dyn(
1787        values: &[f32],
1788        shape: &[usize],
1789    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1790        edgefirst_tensor::Tensor::<f32>::from_slice(values, shape)
1791            .map(edgefirst_tensor::TensorDyn::F32)
1792    }
1793    fn arrayview3_into_tensor_dyn(
1794        view: ArrayView3<'_, f32>,
1795    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1796        edgefirst_tensor::Tensor::<f32>::from_arrayview3(view).map(edgefirst_tensor::TensorDyn::F32)
1797    }
1798}
1799
1800impl FloatProtoElem for half::f16 {
1801    fn slice_into_tensor_dyn(
1802        values: &[half::f16],
1803        shape: &[usize],
1804    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1805        edgefirst_tensor::Tensor::<half::f16>::from_slice(values, shape)
1806            .map(edgefirst_tensor::TensorDyn::F16)
1807    }
1808    fn arrayview3_into_tensor_dyn(
1809        view: ArrayView3<'_, half::f16>,
1810    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1811        edgefirst_tensor::Tensor::<half::f16>::from_arrayview3(view)
1812            .map(edgefirst_tensor::TensorDyn::F16)
1813    }
1814}
1815
1816impl FloatProtoElem for f64 {
1817    fn slice_into_tensor_dyn(
1818        values: &[f64],
1819        shape: &[usize],
1820    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1821        // Narrow to f32 — no native f64 kernel path.
1822        let narrowed: Vec<f32> = values.iter().map(|&v| v as f32).collect();
1823        edgefirst_tensor::Tensor::<f32>::from_slice(&narrowed, shape)
1824            .map(edgefirst_tensor::TensorDyn::F32)
1825    }
1826    fn arrayview3_into_tensor_dyn(
1827        view: ArrayView3<'_, f64>,
1828    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1829        let narrowed: ndarray::Array3<f32> = view.mapv(|v| v as f32);
1830        edgefirst_tensor::Tensor::<f32>::from_arrayview3(narrowed.view())
1831            .map(edgefirst_tensor::TensorDyn::F32)
1832    }
1833}
1834
1835fn postprocess_yolo<'a, T>(
1836    output: &'a ArrayView2<'_, T>,
1837) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1838    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1839    let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1840    (boxes_tensor, scores_tensor)
1841}
1842
1843pub(crate) fn postprocess_yolo_seg<'a, T>(
1844    output: &'a ArrayView2<'_, T>,
1845    num_protos: usize,
1846) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1847    assert!(
1848        output.shape()[0] > num_protos + 4,
1849        "Output shape is too short: {} <= {} + 4",
1850        output.shape()[0],
1851        num_protos
1852    );
1853    let num_classes = output.shape()[0] - 4 - num_protos;
1854    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1855    let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1856    let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1857    (boxes_tensor, scores_tensor, mask_tensor)
1858}
1859
1860pub(crate) fn postprocess_yolo_split_segdet<'a, 'b, 'c, BOX, SCORE, MASK>(
1861    boxes_tensor: ArrayView2<'a, BOX>,
1862    scores_tensor: ArrayView2<'b, SCORE>,
1863    mask_tensor: ArrayView2<'c, MASK>,
1864) -> (
1865    ArrayView2<'a, BOX>,
1866    ArrayView2<'b, SCORE>,
1867    ArrayView2<'c, MASK>,
1868) {
1869    let boxes_tensor = boxes_tensor.reversed_axes();
1870    let scores_tensor = scores_tensor.reversed_axes();
1871    let mask_tensor = mask_tensor.reversed_axes();
1872    (boxes_tensor, scores_tensor, mask_tensor)
1873}
1874
1875fn decode_segdet_f32<
1876    MASK: Float + AsPrimitive<f32> + Send + Sync,
1877    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1878>(
1879    boxes: Vec<(DetectBox, usize)>,
1880    masks: ArrayView2<MASK>,
1881    protos: ArrayView3<PROTO>,
1882) -> Result<Vec<(DetectBox, BoundingBox, Array3<u8>)>, crate::DecoderError> {
1883    if boxes.is_empty() {
1884        return Ok(Vec::new());
1885    }
1886    if masks.shape()[1] != protos.shape()[2] {
1887        return Err(crate::DecoderError::InvalidShape(format!(
1888            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1889            masks.shape()[1],
1890            protos.shape()[2],
1891        )));
1892    }
1893    boxes
1894        .into_par_iter()
1895        .map(|b| {
1896            let ind = b.1;
1897            // `protobox` returns the cropped proto slice for `make_segmentation`
1898            // and a `roi` snapped to the 1/proto-grid step. The detection bbox
1899            // stays untouched (EDGEAI-1304); the snapped roi is reported back
1900            // separately so callers can describe where the cropped mask lives.
1901            let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1902            Ok((b.0, roi, make_segmentation(masks.row(ind), protos.view())))
1903        })
1904        .collect()
1905}
1906
1907pub(crate) fn decode_segdet_quant<
1908    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1909    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1910>(
1911    boxes: Vec<(DetectBox, usize)>,
1912    masks: ArrayView2<MASK>,
1913    protos: ArrayView3<PROTO>,
1914    quant_masks: Quantization,
1915    quant_protos: Quantization,
1916) -> Result<Vec<(DetectBox, BoundingBox, Array3<u8>)>, crate::DecoderError> {
1917    if boxes.is_empty() {
1918        return Ok(Vec::new());
1919    }
1920    if masks.shape()[1] != protos.shape()[2] {
1921        return Err(crate::DecoderError::InvalidShape(format!(
1922            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1923            masks.shape()[1],
1924            protos.shape()[2],
1925        )));
1926    }
1927
1928    let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; // 32 protos is 2^5
1929    boxes
1930        .into_iter()
1931        .map(|b| {
1932            let i = b.1;
1933            // See EDGEAI-1304: the caller's bbox stays untouched; the
1934            // proto-grid-snapped `roi` is reported back so the Segmentation's
1935            // bounds can describe the actual cropped mask region.
1936            let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1937            let seg = match total_bits {
1938                0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1939                    masks.row(i),
1940                    protos.view(),
1941                    quant_masks,
1942                    quant_protos,
1943                ),
1944                65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1945                    masks.row(i),
1946                    protos.view(),
1947                    quant_masks,
1948                    quant_protos,
1949                ),
1950                _ => {
1951                    return Err(crate::DecoderError::NotSupported(format!(
1952                        "Unsupported bit width ({total_bits}) for segmentation computation"
1953                    )));
1954                }
1955            };
1956            Ok((b.0, roi, seg))
1957        })
1958        .collect()
1959}
1960
1961fn protobox<'a, T>(
1962    protos: &'a ArrayView3<T>,
1963    roi: &BoundingBox,
1964) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1965    let width = protos.dim().1 as f32;
1966    let height = protos.dim().0 as f32;
1967
1968    // Detect un-normalized bounding boxes (pixel-space coordinates).
1969    // protobox expects normalized coordinates in [0, 1]. The decoder will
1970    // normalize pixel-space coords automatically when the schema declares
1971    // `Detection::normalized = false` AND model input dimensions are known
1972    // (EDGEAI-1303); reaching this guard means at least one of those is
1973    // missing.
1974    //
1975    // The limit is set to 2.0 (not 1.01) because YOLO models legitimately
1976    // predict coordinates slightly > 1.0 for objects near frame edges.
1977    // Any value > 2.0 is clearly pixel-space (even the smallest practical
1978    // model input of 32×32 would produce coordinates >> 2.0).
1979    const NORM_LIMIT: f32 = 2.0;
1980    if roi.xmin > NORM_LIMIT
1981        || roi.ymin > NORM_LIMIT
1982        || roi.xmax > NORM_LIMIT
1983        || roi.ymax > NORM_LIMIT
1984    {
1985        return Err(crate::DecoderError::InvalidShape(format!(
1986            "Bounding box coordinates appear un-normalized (pixel-space). \
1987             Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
1988             Two ways to fix this: \
1989             (1) declare `Detection::normalized = false` in the model schema \
1990             AND make sure the schema's `input.shape` / `input.dshape` carries \
1991             the model input dims so the decoder can divide by (W, H) before NMS \
1992             (EDGEAI-1303 — verify with `Decoder::input_dims().is_some()`); or \
1993             (2) normalize the boxes in-graph before decode().",
1994            roi.xmin, roi.ymin, roi.xmax, roi.ymax,
1995        )));
1996    }
1997
1998    let roi = [
1999        (roi.xmin * width).clamp(0.0, width) as usize,
2000        (roi.ymin * height).clamp(0.0, height) as usize,
2001        (roi.xmax * width).clamp(0.0, width).ceil() as usize,
2002        (roi.ymax * height).clamp(0.0, height).ceil() as usize,
2003    ];
2004
2005    let roi_norm = [
2006        roi[0] as f32 / width,
2007        roi[1] as f32 / height,
2008        roi[2] as f32 / width,
2009        roi[3] as f32 / height,
2010    ]
2011    .into();
2012
2013    let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
2014
2015    Ok((cropped, roi_norm))
2016}
2017
2018/// Compute a single instance segmentation mask from mask coefficients and
2019/// proto maps (float path).
2020///
2021/// Computes `sigmoid(coefficients · protos)` and maps to `[0, 255]`.
2022/// Returns an `(H, W, 1)` u8 array.
2023fn make_segmentation<
2024    MASK: Float + AsPrimitive<f32> + Send + Sync,
2025    PROTO: Float + AsPrimitive<f32> + Send + Sync,
2026>(
2027    mask: ArrayView1<MASK>,
2028    protos: ArrayView3<PROTO>,
2029) -> Array3<u8> {
2030    let shape = protos.shape();
2031
2032    // Safe to unwrap since the shapes will always be compatible
2033    let mask = mask.to_shape((1, mask.len())).unwrap();
2034    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
2035    let protos = protos.reversed_axes();
2036    let mask = mask.map(|x| x.as_());
2037    let protos = protos.map(|x| x.as_());
2038
2039    // Safe to unwrap since the shapes will always be compatible
2040    let mask = mask
2041        .dot(&protos)
2042        .into_shape_with_order((shape[0], shape[1], 1))
2043        .unwrap();
2044
2045    mask.map(|x| {
2046        let sigmoid = 1.0 / (1.0 + (-*x).exp());
2047        (sigmoid * 255.0).round() as u8
2048    })
2049}
2050
2051/// Compute a single instance segmentation mask from quantized mask
2052/// coefficients and proto maps.
2053///
2054/// Dequantizes both inputs (subtracting zero-points), computes the dot
2055/// product, applies sigmoid, and maps to `[0, 255]`.
2056/// Returns an `(H, W, 1)` u8 array.
2057fn make_segmentation_quant<
2058    MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
2059    PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
2060    DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
2061>(
2062    mask: ArrayView1<MASK>,
2063    protos: ArrayView3<PROTO>,
2064    quant_masks: Quantization,
2065    quant_protos: Quantization,
2066) -> Array3<u8>
2067where
2068    i32: AsPrimitive<DEST>,
2069    f32: AsPrimitive<DEST>,
2070{
2071    let shape = protos.shape();
2072
2073    // Safe to unwrap since the shapes will always be compatible
2074    let mask = mask.to_shape((1, mask.len())).unwrap();
2075
2076    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
2077    let protos = protos.reversed_axes();
2078
2079    let zp = quant_masks.zero_point.as_();
2080
2081    let mask = mask.mapv(|x| x.as_() - zp);
2082
2083    let zp = quant_protos.zero_point.as_();
2084    let protos = protos.mapv(|x| x.as_() - zp);
2085
2086    // Safe to unwrap since the shapes will always be compatible
2087    let segmentation = mask
2088        .dot(&protos)
2089        .into_shape_with_order((shape[0], shape[1], 1))
2090        .unwrap();
2091
2092    let combined_scale = quant_masks.scale * quant_protos.scale;
2093    segmentation.map(|x| {
2094        let val: f32 = (*x).as_() * combined_scale;
2095        let sigmoid = 1.0 / (1.0 + (-val).exp());
2096        (sigmoid * 255.0).round() as u8
2097    })
2098}
2099
2100/// Converts Yolo Instance Segmentation into a 2D mask.
2101///
2102/// The input segmentation is expected to have shape (H, W, 1).
2103///
2104/// The output mask will have shape (H, W), with values 0 or 1 based on the
2105/// threshold.
2106///
2107/// # Errors
2108///
2109/// Returns `DecoderError::InvalidShape` if the input segmentation does not
2110/// have shape (H, W, 1).
2111pub(crate) fn yolo_segmentation_to_mask(
2112    segmentation: ArrayView3<u8>,
2113    threshold: u8,
2114) -> Result<Array2<u8>, crate::DecoderError> {
2115    if segmentation.shape()[2] != 1 {
2116        return Err(crate::DecoderError::InvalidShape(format!(
2117            "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
2118            segmentation.shape()[2]
2119        )));
2120    }
2121    Ok(segmentation
2122        .slice(s![.., .., 0])
2123        .map(|x| if *x >= threshold { 1 } else { 0 }))
2124}
2125
2126#[cfg(test)]
2127#[cfg_attr(coverage_nightly, coverage(off))]
2128mod tests {
2129    use super::*;
2130    use ndarray::Array2;
2131
2132    // ========================================================================
2133    // Tests for decode_yolo_end_to_end_det_float
2134    // ========================================================================
2135
    /// Three candidates, only the first scoring above the 0.5 threshold: the
    /// decoder is expected to return exactly that one with its box, score and
    /// label intact.
    #[test]
    fn test_end_to_end_det_basic_filtering() {
        // Create synthetic end-to-end detection output: (6, N) where rows are
        // [x1, y1, x2, y2, conf, class]
        // 3 detections: one above threshold, two below
        let data: Vec<f32> = vec![
            // Each row holds one field for all three detections; column 0 is
            // the high-score (0.9) detection.
            0.1, 0.2, 0.3, // x1 values
            0.1, 0.2, 0.3, // y1 values
            0.5, 0.6, 0.7, // x2 values
            0.5, 0.6, 0.7, // y2 values
            0.9, 0.1, 0.2, // confidence scores
            0.0, 1.0, 2.0, // class indices
        ];
        let output = Array2::from_shape_vec((6, 3), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        // Only 1 detection should pass threshold of 0.5
        assert_eq!(boxes.len(), 1);
        assert_eq!(boxes[0].label, 0);
        assert!((boxes[0].score - 0.9).abs() < 0.01);
        assert!((boxes[0].bbox.xmin - 0.1).abs() < 0.01);
        assert!((boxes[0].bbox.ymin - 0.1).abs() < 0.01);
        assert!((boxes[0].bbox.xmax - 0.5).abs() < 0.01);
        assert!((boxes[0].bbox.ymax - 0.5).abs() < 0.01);
    }
2164
    /// Both candidates score above the threshold, so both are expected back
    /// in input order with their class labels.
    #[test]
    fn test_end_to_end_det_all_pass_threshold() {
        // All detections above threshold; (6, 2) layout, one field per row.
        let data: Vec<f32> = vec![
            10.0, 20.0, // x1
            10.0, 20.0, // y1
            50.0, 60.0, // x2
            50.0, 60.0, // y2
            0.8, 0.7, // conf (both above 0.5)
            1.0, 2.0, // class
        ];
        let output = Array2::from_shape_vec((6, 2), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 2);
        assert_eq!(boxes[0].label, 1);
        assert_eq!(boxes[1].label, 2);
    }
2185
    /// Both candidates score below the threshold, so the output is expected
    /// to stay empty.
    #[test]
    fn test_end_to_end_det_none_pass_threshold() {
        // All detections below threshold; (6, 2) layout, one field per row.
        let data: Vec<f32> = vec![
            10.0, 20.0, // x1
            10.0, 20.0, // y1
            50.0, 60.0, // x2
            50.0, 60.0, // y2
            0.1, 0.2, // conf (both below 0.5)
            1.0, 2.0, // class
        ];
        let output = Array2::from_shape_vec((6, 2), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 0);
    }
2204
    /// Five candidates pass the threshold but the output `Vec` is created
    /// with capacity 2; the test expects only two boxes to be returned.
    #[test]
    fn test_end_to_end_det_capacity_limit() {
        // Test that output is truncated to capacity
        let data: Vec<f32> = vec![
            0.1, 0.2, 0.3, 0.4, 0.5, // x1
            0.1, 0.2, 0.3, 0.4, 0.5, // y1
            0.5, 0.6, 0.7, 0.8, 0.9, // x2
            0.5, 0.6, 0.7, 0.8, 0.9, // y2
            0.9, 0.9, 0.9, 0.9, 0.9, // conf (all pass)
            0.0, 1.0, 2.0, 3.0, 4.0, // class
        ];
        let output = Array2::from_shape_vec((6, 5), data).unwrap();

        let mut boxes = Vec::with_capacity(2); // Only allow 2 boxes
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 2);
    }
2223
    /// A (6, 0) tensor has the right row count but no candidates; decoding
    /// is expected to succeed and return nothing.
    #[test]
    fn test_end_to_end_det_empty_output() {
        // Test with zero detections
        let output = Array2::<f32>::zeros((6, 0));

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 0);
    }
2234
    /// A single candidate with pixel-space coordinates; the test expects the
    /// box values to come back unchanged.
    #[test]
    fn test_end_to_end_det_pixel_coordinates() {
        // Test with pixel coordinates (non-normalized); (6, 1) layout, so
        // each value below is one row.
        let data: Vec<f32> = vec![
            100.0, // x1
            200.0, // y1
            300.0, // x2
            400.0, // y2
            0.95,  // conf
            5.0,   // class
        ];
        let output = Array2::from_shape_vec((6, 1), data).unwrap();

        let mut boxes = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();

        assert_eq!(boxes.len(), 1);
        assert_eq!(boxes[0].label, 5);
        assert!((boxes[0].bbox.xmin - 100.0).abs() < 0.01);
        assert!((boxes[0].bbox.ymin - 200.0).abs() < 0.01);
        assert!((boxes[0].bbox.xmax - 300.0).abs() < 0.01);
        assert!((boxes[0].bbox.ymax - 400.0).abs() < 0.01);
    }
2258
    /// A tensor with only 5 rows is expected to be rejected with an
    /// `InvalidShape` error whose message mentions "at least 6 rows".
    #[test]
    fn test_end_to_end_det_invalid_shape() {
        // Test with too few rows (needs at least 6)
        let output = Array2::<f32>::zeros((5, 3));

        let mut boxes = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
        ));
    }
2273
2274    // ========================================================================
2275    // Tests for decode_yolo_end_to_end_segdet_float
2276    // ========================================================================
2277
    /// Two candidates with mask coefficients, only the first scoring above
    /// the threshold: exactly one box and one mask are expected back.
    #[test]
    fn test_end_to_end_segdet_basic() {
        // Create synthetic segdet output: (6 + num_protos, N)
        // Detection format: [x1, y1, x2, y2, conf, class, mask_coeff_0..31]
        let num_protos = 32;
        let num_detections = 2;
        let num_features = 6 + num_protos;

        // Build detection tensor; row `r` for detection `d` lives at
        // `data[r * num_detections + d]`.
        let mut data = vec![0.0f32; num_features * num_detections];
        // Detection 0 passes the threshold, detection 1 does not.
        data[0] = 0.1; // x1[0]
        data[1] = 0.5; // x1[1]
        data[num_detections] = 0.1; // y1[0]
        data[num_detections + 1] = 0.5; // y1[1]
        data[2 * num_detections] = 0.4; // x2[0]
        data[2 * num_detections + 1] = 0.9; // x2[1]
        data[3 * num_detections] = 0.4; // y2[0]
        data[3 * num_detections + 1] = 0.9; // y2[1]
        data[4 * num_detections] = 0.9; // conf[0] - passes
        data[4 * num_detections + 1] = 0.3; // conf[1] - fails
        data[5 * num_detections] = 1.0; // class[0]
        data[5 * num_detections + 1] = 2.0; // class[1]

        // Fill mask coefficients with small values
        for i in 6..num_features {
            data[i * num_detections] = 0.1;
            data[i * num_detections + 1] = 0.1;
        }

        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();

        // Create protos tensor: (proto_height, proto_width, num_protos)
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        // Only detection 0 should pass
        assert_eq!(boxes.len(), 1);
        assert_eq!(masks.len(), 1);
        assert_eq!(boxes[0].label, 1);
        assert!((boxes[0].score - 0.9).abs() < 0.01);
    }
2329
    /// Checks that the mask region reported for a detection encloses the
    /// detection's bbox and differs from it by less than one proto-grid cell
    /// on every side.
    #[test]
    fn test_end_to_end_segdet_mask_coordinates() {
        // Test that mask coordinates match box coordinates
        let num_protos = 32;
        let num_features = 6 + num_protos;

        // Single detection, so each row of the (num_features, 1) tensor is
        // one element; mask coefficients stay at zero.
        let mut data = vec![0.0f32; num_features];
        data[0] = 0.2; // x1
        data[1] = 0.2; // y1
        data[2] = 0.8; // x2
        data[3] = 0.8; // y2
        data[4] = 0.95; // conf
        data[5] = 3.0; // class

        let output = Array2::from_shape_vec((num_features, 1), data).unwrap();
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 1);
        assert_eq!(masks.len(), 1);

        // Mask region is the proto-grid-aligned crop and encloses the
        // post-NMS bbox (EDGEAI-1304); on a 16x16 grid each side may snap
        // by up to 1/16 = 0.0625.
        let step = 1.0 / 16.0;
        assert!(masks[0].xmin <= boxes[0].bbox.xmin);
        assert!(masks[0].ymin <= boxes[0].bbox.ymin);
        assert!(masks[0].xmax >= boxes[0].bbox.xmax);
        assert!(masks[0].ymax >= boxes[0].bbox.ymax);
        assert!((boxes[0].bbox.xmin - masks[0].xmin) < step);
        assert!((boxes[0].bbox.ymin - masks[0].ymin) < step);
        assert!((masks[0].xmax - boxes[0].bbox.xmax) < step);
        assert!((masks[0].ymax - boxes[0].bbox.ymax) < step);
    }
2374
    /// A correctly-shaped tensor with zero candidates is expected to decode
    /// successfully and leave both output vectors empty.
    #[test]
    fn test_end_to_end_segdet_empty_output() {
        let num_protos = 32;
        // (6 + num_protos, 0): valid row count, no detections.
        let output = Array2::<f32>::zeros((6 + num_protos, 0));
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 0);
        assert_eq!(masks.len(), 0);
    }
2395
    /// Five candidates pass the threshold but both output vectors are
    /// created with capacity 2; the test expects two boxes and two masks.
    #[test]
    fn test_end_to_end_segdet_capacity_limit() {
        let num_protos = 32;
        let num_detections = 5;
        let num_features = 6 + num_protos;

        // Row `r` for detection `i` lives at `data[r * num_detections + i]`;
        // mask coefficients stay at zero.
        let mut data = vec![0.0f32; num_features * num_detections];
        // All detections pass threshold
        for i in 0..num_detections {
            data[i] = 0.1 * (i as f32); // x1
            data[num_detections + i] = 0.1 * (i as f32); // y1
            data[2 * num_detections + i] = 0.1 * (i as f32) + 0.2; // x2
            data[3 * num_detections + i] = 0.1 * (i as f32) + 0.2; // y2
            data[4 * num_detections + i] = 0.9; // conf
            data[5 * num_detections + i] = i as f32; // class
        }

        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut boxes = Vec::with_capacity(2); // Limit to 2
        let mut masks = Vec::with_capacity(2);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        )
        .unwrap();

        assert_eq!(boxes.len(), 2);
        assert_eq!(masks.len(), 2);
    }
2430
    /// A tensor with only the 6 base rows and no mask-coefficient row is
    /// expected to be rejected with an `InvalidShape` error whose message
    /// mentions "at least 7 rows".
    #[test]
    fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
        // Test with too few rows (needs at least 7: 6 base + 1 mask coeff)
        let output = Array2::<f32>::zeros((6, 3));
        let protos = Array3::<f32>::zeros((16, 16, 32));

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        );

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
        ));
    }
2453
    /// 16 mask-coefficient rows against 32 proto channels is expected to be
    /// rejected with an `InvalidShape` error whose message mentions
    /// "doesn't match protos count".
    #[test]
    fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
        // Test with mismatched mask coefficients and protos count
        let num_protos = 32;
        let output = Array2::<f32>::zeros((6 + 16, 3)); // 16 mask coeffs
        let protos = Array3::<f32>::zeros((16, 16, num_protos)); // 32 protos

        let mut boxes = Vec::with_capacity(10);
        let mut masks = Vec::with_capacity(10);
        let result = decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut boxes,
            &mut masks,
        );

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
        ));
    }
2477
2478    // ========================================================================
2479    // Tests for decode_yolo_split_end_to_end_segdet_float
2480    // ========================================================================
2481
2482    #[test]
2483    fn test_split_end_to_end_segdet_basic() {
2484        // Create synthetic segdet output: (6 + num_protos, N)
2485        // Detection format: [x1, y1, x2, y2, conf, class, mask_coeff_0..31]
2486        let num_protos = 32;
2487        let num_detections = 2;
2488        let num_features = 6 + num_protos;
2489
2490        // Build detection tensor
2491        let mut data = vec![0.0f32; num_features * num_detections];
2492        // Detection 0: passes threshold
2493        data[0] = 0.1; // x1[0]
2494        data[1] = 0.5; // x1[1]
2495        data[num_detections] = 0.1; // y1[0]
2496        data[num_detections + 1] = 0.5; // y1[1]
2497        data[2 * num_detections] = 0.4; // x2[0]
2498        data[2 * num_detections + 1] = 0.9; // x2[1]
2499        data[3 * num_detections] = 0.4; // y2[0]
2500        data[3 * num_detections + 1] = 0.9; // y2[1]
2501        data[4 * num_detections] = 0.9; // conf[0] - passes
2502        data[4 * num_detections + 1] = 0.3; // conf[1] - fails
2503        data[5 * num_detections] = 1.0; // class[0]
2504        data[5 * num_detections + 1] = 2.0; // class[1]
2505                                            // Fill mask coefficients with small values
2506        for i in 6..num_features {
2507            data[i * num_detections] = 0.1;
2508            data[i * num_detections + 1] = 0.1;
2509        }
2510
2511        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2512        let box_coords = output.slice(s![..4, ..]);
2513        let scores = output.slice(s![4..5, ..]);
2514        let classes = output.slice(s![5..6, ..]);
2515        let mask_coeff = output.slice(s![6.., ..]);
2516        // Create protos tensor: (proto_height, proto_width, num_protos)
2517        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2518
2519        let mut boxes = Vec::with_capacity(10);
2520        let mut masks = Vec::with_capacity(10);
2521        decode_yolo_split_end_to_end_segdet_float(
2522            box_coords,
2523            scores,
2524            classes,
2525            mask_coeff,
2526            protos.view(),
2527            0.5,
2528            &mut boxes,
2529            &mut masks,
2530        )
2531        .unwrap();
2532
2533        // Only detection 0 should pass
2534        assert_eq!(boxes.len(), 1);
2535        assert_eq!(masks.len(), 1);
2536        assert_eq!(boxes[0].label, 1);
2537        assert!((boxes[0].score - 0.9).abs() < 0.01);
2538    }
2539
2540    // ========================================================================
2541    // Tests for yolo_segmentation_to_mask
2542    // ========================================================================
2543
2544    #[test]
2545    fn test_segmentation_to_mask_basic() {
2546        // Create a 4x4x1 segmentation with values above and below threshold
2547        let data: Vec<u8> = vec![
2548            100, 200, 50, 150, // row 0
2549            10, 255, 128, 64, // row 1
2550            0, 127, 128, 255, // row 2
2551            64, 64, 192, 192, // row 3
2552        ];
2553        let segmentation = Array3::from_shape_vec((4, 4, 1), data).unwrap();
2554
2555        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2556
2557        // Values >= 128 should be 1, others 0
2558        assert_eq!(mask[[0, 0]], 0); // 100 < 128
2559        assert_eq!(mask[[0, 1]], 1); // 200 >= 128
2560        assert_eq!(mask[[0, 2]], 0); // 50 < 128
2561        assert_eq!(mask[[0, 3]], 1); // 150 >= 128
2562        assert_eq!(mask[[1, 1]], 1); // 255 >= 128
2563        assert_eq!(mask[[1, 2]], 1); // 128 >= 128
2564        assert_eq!(mask[[2, 0]], 0); // 0 < 128
2565        assert_eq!(mask[[2, 1]], 0); // 127 < 128
2566    }
2567
2568    #[test]
2569    fn test_segmentation_to_mask_all_above() {
2570        let segmentation = Array3::from_elem((4, 4, 1), 255u8);
2571        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2572        assert!(mask.iter().all(|&x| x == 1));
2573    }
2574
2575    #[test]
2576    fn test_segmentation_to_mask_all_below() {
2577        let segmentation = Array3::from_elem((4, 4, 1), 64u8);
2578        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2579        assert!(mask.iter().all(|&x| x == 0));
2580    }
2581
2582    #[test]
2583    fn test_segmentation_to_mask_invalid_shape() {
2584        let segmentation = Array3::from_elem((4, 4, 3), 128u8);
2585        let result = yolo_segmentation_to_mask(segmentation.view(), 128);
2586
2587        assert!(result.is_err());
2588        assert!(matches!(
2589            result,
2590            Err(crate::DecoderError::InvalidShape(s)) if s.contains("(H, W, 1)")
2591        ));
2592    }
2593
2594    // ========================================================================
2595    // Tests for protobox / NORM_LIMIT regression
2596    // ========================================================================
2597
2598    #[test]
2599    fn test_protobox_clamps_edge_coordinates() {
2600        // bbox with xmax=1.0 should not panic (OOB guard)
2601        let protos = Array3::<f32>::zeros((16, 16, 4));
2602        let view = protos.view();
2603        let roi = BoundingBox {
2604            xmin: 0.5,
2605            ymin: 0.5,
2606            xmax: 1.0,
2607            ymax: 1.0,
2608        };
2609        let result = protobox(&view, &roi);
2610        assert!(result.is_ok(), "protobox should accept xmax=1.0");
2611        let (cropped, _roi_norm) = result.unwrap();
2612        // Cropped region must have non-zero spatial dimensions
2613        assert!(cropped.shape()[0] > 0);
2614        assert!(cropped.shape()[1] > 0);
2615        assert_eq!(cropped.shape()[2], 4);
2616    }
2617
2618    #[test]
2619    fn test_protobox_rejects_wildly_out_of_range() {
2620        // bbox with coords > NORM_LIMIT (e.g. 3.0) returns error
2621        let protos = Array3::<f32>::zeros((16, 16, 4));
2622        let view = protos.view();
2623        let roi = BoundingBox {
2624            xmin: 0.0,
2625            ymin: 0.0,
2626            xmax: 3.0,
2627            ymax: 3.0,
2628        };
2629        let result = protobox(&view, &roi);
2630        assert!(
2631            matches!(result, Err(crate::DecoderError::InvalidShape(s)) if s.contains("un-normalized")),
2632            "protobox should reject coords > NORM_LIMIT"
2633        );
2634    }
2635
2636    #[test]
2637    fn test_protobox_accepts_slightly_over_one() {
2638        // bbox with coords at 1.5 (within NORM_LIMIT=2.0) succeeds
2639        let protos = Array3::<f32>::zeros((16, 16, 4));
2640        let view = protos.view();
2641        let roi = BoundingBox {
2642            xmin: 0.0,
2643            ymin: 0.0,
2644            xmax: 1.5,
2645            ymax: 1.5,
2646        };
2647        let result = protobox(&view, &roi);
2648        assert!(
2649            result.is_ok(),
2650            "protobox should accept coords <= NORM_LIMIT (2.0)"
2651        );
2652        let (cropped, _roi_norm) = result.unwrap();
2653        // Entire proto map should be selected when coords > 1.0 (clamped to boundary)
2654        assert_eq!(cropped.shape()[0], 16);
2655        assert_eq!(cropped.shape()[1], 16);
2656    }
2657
2658    #[test]
2659    fn test_segdet_float_proto_no_panic() {
2660        // Simulates YOLOv8n-seg: output0 = [116, 8400] (4 box + 80 class + 32 mask coeff)
2661        // output1 (protos) = [32, 160, 160]
2662        let num_proposals = 100; // enough to produce idx >= 32
2663        let num_classes = 80;
2664        let num_mask_coeffs = 32;
2665        let rows = 4 + num_classes + num_mask_coeffs; // 116
2666
2667        // Fill boxes with valid xywh data so some detections pass the threshold.
2668        // Layout is [116, num_proposals] row-major: row 0=cx, 1=cy, 2=w, 3=h,
2669        // rows 4..84=class scores, rows 84..116=mask coefficients.
2670        let mut data = vec![0.0f32; rows * num_proposals];
2671        for i in 0..num_proposals {
2672            let row = |r: usize| r * num_proposals + i;
2673            data[row(0)] = 320.0; // cx
2674            data[row(1)] = 320.0; // cy
2675            data[row(2)] = 50.0; // w
2676            data[row(3)] = 50.0; // h
2677            data[row(4)] = 0.9; // class-0 score
2678        }
2679        let boxes = ndarray::Array2::from_shape_vec((rows, num_proposals), data).unwrap();
2680
2681        // Protos must be in HWC order. Under the HAL physical-order
2682        // contract, callers declare shape+dshape matching producer memory
2683        // and swap_axes_if_needed permutes the stride tuple into canonical
2684        // [batch, height, width, num_protos] before this function sees it.
2685        let protos = ndarray::Array3::<f32>::zeros((160, 160, num_mask_coeffs));
2686
2687        let mut output_boxes = Vec::with_capacity(300);
2688
2689        // This panicked before fix: mask_tensor.row(idx) with idx >= 32
2690        let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
2691            boxes.view(),
2692            protos.view(),
2693            0.5,
2694            0.7,
2695            Some(Nms::default()),
2696            MAX_NMS_CANDIDATES,
2697            300,
2698            None,
2699            None,
2700            &mut output_boxes,
2701        );
2702
2703        // Should produce detections (NMS will collapse many overlapping boxes)
2704        assert!(!output_boxes.is_empty());
2705        let coeffs_shape = proto_data.mask_coefficients.shape();
2706        assert_eq!(coeffs_shape[0], output_boxes.len());
2707        // Each mask coefficient vector should have 32 elements
2708        assert_eq!(coeffs_shape[1], num_mask_coeffs);
2709    }
2710
2711    // ========================================================================
2712    // Pre-NMS top-K cap (MAX_NMS_CANDIDATES)
2713    // ========================================================================
2714
2715    /// At very low score thresholds (e.g., t=0.01 on YOLOv8 with 8400×80
2716    /// candidates) almost every score passes the filter, feeding O(n²)
2717    /// NMS and a per-survivor mask matmul. The decoder caps the
2718    /// candidate set fed to NMS at `MAX_NMS_CANDIDATES` (Ultralytics
2719    /// default 30 000) to bound worst-case decode time.
2720    ///
2721    /// This regression test pumps 50 000 above-threshold candidates
2722    /// into `impl_yolo_segdet_get_boxes` with NMS bypassed (Nms=None)
2723    /// and a generous post-NMS cap. Before the fix, the function
2724    /// returned all 50 000; after the fix, exactly 30 000.
2725    #[test]
2726    fn test_pre_nms_cap_truncates_excess_candidates() {
2727        let n: usize = 50_000;
2728        let num_classes = 1;
2729
2730        // Identical valid boxes. Distinct scores (descending) so the
2731        // top-K cap keeps the highest-scoring ones in deterministic
2732        // order — letting us assert *which* ones survived.
2733        let mut boxes_data = Vec::with_capacity(n * 4);
2734        let mut scores_data = Vec::with_capacity(n * num_classes);
2735        for i in 0..n {
2736            boxes_data.extend_from_slice(&[0.1f32, 0.1, 0.5, 0.5]);
2737            // score_i = 0.99 - i * 1e-7 keeps everything well above 0.1
2738            // threshold but strictly decreasing.
2739            scores_data.push(0.99 - (i as f32) * 1e-7);
2740        }
2741        let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
2742        let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();
2743
2744        let result = impl_yolo_segdet_get_boxes::<XYXY, _, _>(
2745            boxes.view(),
2746            scores.view(),
2747            0.1,
2748            1.0,                             // IoU 1.0 → NMS suppresses nothing
2749            Some(Nms::ClassAgnostic),        // NMS enabled so pre_nms_top_k applies
2750            crate::yolo::MAX_NMS_CANDIDATES, // pre_nms_top_k
2751            usize::MAX,                      // no post-NMS truncation
2752        );
2753
2754        assert_eq!(
2755            result.len(),
2756            crate::yolo::MAX_NMS_CANDIDATES,
2757            "pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
2758            result.len()
2759        );
2760        // Top-K survivors: highest scores were the first n indices,
2761        // so survivor 0 must have score ~0.99.
2762        let top_score = result[0].0.score;
2763        assert!(
2764            top_score > 0.98,
2765            "highest-ranked survivor should have the largest score, got {top_score}"
2766        );
2767    }
2768
2769    /// Counterpart for the quantized split path. Same contract: feed
2770    /// more than `MAX_NMS_CANDIDATES` survivors above the quantized
2771    /// threshold, confirm `impl_yolo_split_segdet_quant_get_boxes`
2772    /// truncates before NMS.
2773    #[test]
2774    fn test_pre_nms_cap_truncates_excess_candidates_quant() {
2775        use crate::Quantization;
2776        let n: usize = 50_000;
2777        let num_classes = 1;
2778
2779        // i8 boxes with simple scale/zp; the box value 50 dequantizes
2780        // to 0.5 with scale=0.01, zp=0 — fine for a flat box set.
2781        let boxes_data = (0..n).flat_map(|_| [10i8, 10, 50, 50]).collect::<Vec<_>>();
2782        let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
2783        let quant_boxes = Quantization {
2784            scale: 0.01,
2785            zero_point: 0,
2786        };
2787
2788        // u8 scores: distinct descending values, all well above threshold.
2789        // value 250 → 0.98 with scale 0.00392, zp 0.
2790        // value (250 - i % 200) keeps a wide spread above the dequant
2791        // threshold of 0.5.
2792        let scores_data: Vec<u8> = (0..n)
2793            .map(|i| 250u8.saturating_sub((i % 200) as u8))
2794            .collect();
2795        let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();
2796        let quant_scores = Quantization {
2797            scale: 0.00392,
2798            zero_point: 0,
2799        };
2800
2801        let result = impl_yolo_split_segdet_quant_get_boxes::<XYXY, _, _>(
2802            (boxes.view(), quant_boxes),
2803            (scores.view(), quant_scores),
2804            0.1,
2805            1.0,                             // IoU 1.0 → NMS suppresses nothing
2806            Some(Nms::ClassAgnostic),        // NMS enabled so pre_nms_top_k applies
2807            crate::yolo::MAX_NMS_CANDIDATES, // pre_nms_top_k
2808            usize::MAX,                      // no post-NMS truncation
2809        );
2810
2811        assert_eq!(
2812            result.len(),
2813            crate::yolo::MAX_NMS_CANDIDATES,
2814            "quant path pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
2815            result.len()
2816        );
2817    }
2818
2819    /// Regression test for HAILORT_BUG.md — the YoloSegDet path
2820    /// (combined `(4 + nc + nm, N)` detection tensor + separate protos)
2821    /// must pair each surviving detection with the mask coefficient
2822    /// row at the SAME anchor index the box came from. The validator
2823    /// sees this path miss the pairing under schema-v2 Hailo inputs
2824    /// (mAP collapse from 46.8 → 3.65 while mask IoU stays at 66.9,
2825    /// the fingerprint of mask-to-detection misalignment).
2826    ///
2827    /// Construction: three anchors with distinct mask-coef signatures
2828    /// that, after dot(coefs, protos) + sigmoid, produce HIGH vs LOW
2829    /// mask pixel values. Two anchors survive (one high, one low); if
2830    /// the mask row is looked up at the wrong index, the per-detection
2831    /// mean mask value would cross the threshold and we catch it.
2832    #[test]
2833    fn segdet_combined_tensor_pairs_detection_with_matching_mask_row() {
2834        let nc = 2; // num_classes
2835        let nm = 2; // num_protos
2836        let n = 3; // num_anchors
2837        let feat = 4 + nc + nm; // 8
2838
2839        // Tensor layout: (8, 3) rows=features, cols=anchors.
2840        // Row indices:  0..4 = xywh, 4..6 = scores, 6..8 = mask_coefs.
2841        //
2842        //         anchor 0 | anchor 1 | anchor 2
2843        // xc       0.2      | 0.5      | 0.8
2844        // yc       0.2      | 0.5      | 0.8
2845        // w        0.1      | 0.1      | 0.1
2846        // h        0.1      | 0.1      | 0.1
2847        // s[0]     0.9      | 0.0      | 0.8   (class 0)
2848        // s[1]     0.0      | 0.0      | 0.0   (class 1 — always loses)
2849        // m[0]     3.0      | 0.0      | -3.0  (high for a0, low for a2)
2850        // m[1]     3.0      | 0.0      | -3.0
2851        //
2852        // Proto[0] = Proto[1] = all-ones (8x8), so
2853        //   mask(a0) = sigmoid(3 + 3) ≈ 0.9975 → 254
2854        //   mask(a2) = sigmoid(-3 + -3) ≈ 0.0025 → 1
2855        // 250-point gap makes any misalignment trivially detectable.
2856        let mut data = vec![0.0f32; feat * n];
2857        let set = |d: &mut [f32], r: usize, c: usize, v: f32| d[r * n + c] = v;
2858        set(&mut data, 0, 0, 0.2);
2859        set(&mut data, 1, 0, 0.2);
2860        set(&mut data, 2, 0, 0.1);
2861        set(&mut data, 3, 0, 0.1);
2862        set(&mut data, 0, 1, 0.5);
2863        set(&mut data, 1, 1, 0.5);
2864        set(&mut data, 2, 1, 0.1);
2865        set(&mut data, 3, 1, 0.1);
2866        set(&mut data, 0, 2, 0.8);
2867        set(&mut data, 1, 2, 0.8);
2868        set(&mut data, 2, 2, 0.1);
2869        set(&mut data, 3, 2, 0.1);
2870        set(&mut data, 4, 0, 0.9);
2871        set(&mut data, 4, 2, 0.8);
2872        set(&mut data, 6, 0, 3.0);
2873        set(&mut data, 7, 0, 3.0);
2874        set(&mut data, 6, 2, -3.0);
2875        set(&mut data, 7, 2, -3.0);
2876
2877        let output = Array2::from_shape_vec((feat, n), data).unwrap();
2878        let protos = Array3::<f32>::from_elem((8, 8, nm), 1.0);
2879
2880        let mut boxes: Vec<DetectBox> = Vec::with_capacity(10);
2881        let mut masks: Vec<Segmentation> = Vec::with_capacity(10);
2882        decode_yolo_segdet_float(
2883            output.view(),
2884            protos.view(),
2885            0.5,
2886            0.5,
2887            Some(Nms::ClassAgnostic),
2888            &mut boxes,
2889            &mut masks,
2890        )
2891        .unwrap();
2892
2893        assert_eq!(
2894            boxes.len(),
2895            2,
2896            "two anchors above threshold should survive (a0 score=0.9, a2 score=0.8); got {}",
2897            boxes.len()
2898        );
2899
2900        // Build a (anchor_index → mask_mean) mapping from the results.
2901        // Anchor 0 has centre (0.2, 0.2), anchor 2 has centre (0.8,
2902        // 0.8). The DetectBox bbox is the post-XYWH-to-XYXY conversion
2903        // of the original xywh; cropping inside protobox may shrink it,
2904        // so match by centre (0.2 vs 0.8) rather than exact bbox.
2905        for (b, m) in boxes.iter().zip(masks.iter()) {
2906            let cx = (b.bbox.xmin + b.bbox.xmax) * 0.5;
2907            let mean = {
2908                let s = &m.segmentation;
2909                let total: u32 = s.iter().map(|&v| v as u32).sum();
2910                total as f32 / s.len() as f32
2911            };
2912            if cx < 0.3 {
2913                // anchor 0 — expect HIGH mask values ≈ 254
2914                assert!(
2915                    mean > 200.0,
2916                    "anchor 0 detection (centre {cx:.2}) should have high-value mask; got mean {mean}"
2917                );
2918            } else if cx > 0.7 {
2919                // anchor 2 — expect LOW mask values ≈ 1
2920                assert!(
2921                    mean < 50.0,
2922                    "anchor 2 detection (centre {cx:.2}) should have low-value mask; got mean {mean}"
2923                );
2924            } else {
2925                panic!("unexpected detection centre {cx:.2}");
2926            }
2927        }
2928    }
2929
2930    // ========================================================================
2931    // Tests for truncate_to_top_k_by_score / truncate_to_top_k_by_score_quant
2932    // ========================================================================
2933
2934    /// Helper: build a Vec of (DetectBox, ()) with the given scores.
2935    fn make_float_boxes(scores: &[f32]) -> Vec<(DetectBox, ())> {
2936        scores
2937            .iter()
2938            .enumerate()
2939            .map(|(i, &s)| {
2940                (
2941                    DetectBox {
2942                        bbox: BoundingBox {
2943                            xmin: 0.0,
2944                            ymin: 0.0,
2945                            xmax: 1.0,
2946                            ymax: 1.0,
2947                        },
2948                        score: s,
2949                        label: i,
2950                    },
2951                    (),
2952                )
2953            })
2954            .collect()
2955    }
2956
2957    /// Helper: build a Vec of (DetectBoxQuantized<i8>, ()) with the given scores.
2958    fn make_quant_boxes(scores: &[i8]) -> Vec<(DetectBoxQuantized<i8>, ())> {
2959        scores
2960            .iter()
2961            .enumerate()
2962            .map(|(i, &s)| {
2963                (
2964                    DetectBoxQuantized {
2965                        bbox: BoundingBox {
2966                            xmin: 0.0,
2967                            ymin: 0.0,
2968                            xmax: 1.0,
2969                            ymax: 1.0,
2970                        },
2971                        score: s,
2972                        label: i,
2973                    },
2974                    (),
2975                )
2976            })
2977            .collect()
2978    }
2979
2980    #[test]
2981    fn truncate_float_top_k_zero_is_unbounded() {
2982        let mut boxes = make_float_boxes(&[0.9, 0.1, 0.5, 0.3, 0.7]);
2983        let original_len = boxes.len();
2984        truncate_to_top_k_by_score(&mut boxes, 0);
2985        assert_eq!(
2986            boxes.len(),
2987            original_len,
2988            "top_k=0 should keep all candidates (no-limit semantics)"
2989        );
2990    }
2991
2992    #[test]
2993    fn truncate_float_top_k_normal() {
2994        let mut boxes = make_float_boxes(&[0.9, 0.1, 0.5, 0.3, 0.7]);
2995        truncate_to_top_k_by_score(&mut boxes, 3);
2996        assert_eq!(boxes.len(), 3);
2997        // The top-3 scores should be 0.9, 0.7, 0.5 (order within top-K is unspecified)
2998        let mut retained: Vec<f32> = boxes.iter().map(|(b, _)| b.score).collect();
2999        retained.sort_by(|a, b| b.total_cmp(a));
3000        assert_eq!(retained, vec![0.9, 0.7, 0.5]);
3001    }
3002
3003    #[test]
3004    fn truncate_float_top_k_noop_when_under_cap() {
3005        let mut boxes = make_float_boxes(&[0.9, 0.5]);
3006        truncate_to_top_k_by_score(&mut boxes, 10);
3007        assert_eq!(boxes.len(), 2, "should be no-op when len <= top_k");
3008    }
3009
3010    #[test]
3011    fn truncate_quant_top_k_zero_is_unbounded() {
3012        let mut boxes = make_quant_boxes(&[120, -50, 30, -10, 80]);
3013        let original_len = boxes.len();
3014        truncate_to_top_k_by_score_quant(&mut boxes, 0);
3015        assert_eq!(
3016            boxes.len(),
3017            original_len,
3018            "top_k=0 should keep all candidates (no-limit semantics)"
3019        );
3020    }
3021
3022    #[test]
3023    fn truncate_quant_top_k_normal() {
3024        let mut boxes = make_quant_boxes(&[120, -50, 30, -10, 80]);
3025        truncate_to_top_k_by_score_quant(&mut boxes, 3);
3026        assert_eq!(boxes.len(), 3);
3027        let mut retained: Vec<i8> = boxes.iter().map(|(b, _)| b.score).collect();
3028        retained.sort_by(|a, b| b.cmp(a));
3029        assert_eq!(retained, vec![120, 80, 30]);
3030    }
3031
3032    #[test]
3033    fn truncate_quant_top_k_noop_when_under_cap() {
3034        let mut boxes = make_quant_boxes(&[120, 80]);
3035        truncate_to_top_k_by_score_quant(&mut boxes, 10);
3036        assert_eq!(boxes.len(), 2, "should be no-op when len <= top_k");
3037    }
3038}