Skip to main content

edgefirst_decoder/
yolo.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4//! Internal YOLO decoder kernels.
5//!
6//! All items in this module are `pub(crate)`. The public API surface for
7//! decoding is [`crate::Decoder`] + [`crate::DecoderBuilder`]; external
8//! callers must go through that entry point so we can evolve the kernels
9//! (split paths, dispatch tables, NEON tiers) without breaking semver.
10//! See `CHANGELOG.md` for the 0.20.0 narrowing.
11
12use std::fmt::Debug;
13
14use ndarray::{
15    parallel::prelude::{IntoParallelIterator, ParallelIterator},
16    s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
17};
18use num_traits::{AsPrimitive, Float, PrimInt, Signed};
19
20use crate::{
21    byte::{
22        nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
23        postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
24    },
25    configs::Nms,
26    dequant_detect_box,
27    float::{
28        nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
29        postprocess_boxes_float, postprocess_boxes_index_float,
30    },
31    BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoLayout,
32    Quantization, Segmentation, XYWH, XYXY,
33};
34
/// Maximum number of above-threshold candidates fed to NMS.
///
/// At very low score thresholds (e.g., t=0.01 on YOLOv8 with
/// 8400 anchors × 80 classes), the number of survivors approaches
/// the full 672 000-entry score grid. NMS is O(n²) and the
/// downstream mask matmul runs once per survivor, so an
/// unbounded set produces minutes-per-frame decode times.
///
/// `MAX_NMS_CANDIDATES` matches the Ultralytics `max_nms` default
/// and is applied as a top-K-by-score truncation immediately
/// before NMS. Candidates ranked below the cap are silently dropped —
/// at the score thresholds where the cap activates, the bottom of the
/// candidate list is dominated by noise that NMS would discard
/// anyway.
///
/// Production callers configure this via [`crate::DecoderBuilder::with_pre_nms_top_k`];
/// only the test-only seg-det wrappers reach for this constant directly,
/// which is why the constant itself is compiled only under `cfg(test)`.
#[cfg(test)]
pub(crate) const MAX_NMS_CANDIDATES: usize = 30_000;
54
/// Default post-NMS detection cap used by the crate-internal `decode_yolo_*`
/// convenience wrappers when no explicit cap is plumbed in, i.e. when the
/// caller's output `Vec` has zero capacity (see [`cap_or_default`]). Mirrors
/// the `Decoder::max_det` default set by `DecoderBuilder` (also 300, matching
/// the Ultralytics `max_det` default). Pre-EDGEAI-1302 these wrappers
/// used `output_boxes.capacity()` as the cap unconditionally, which silently
/// dropped all detections when the caller passed `Vec::new()`.
pub(crate) const DEFAULT_MAX_DETECTIONS: usize = 300;
62
63/// Truncate `boxes` to the highest-scoring `top_k` entries in-place when the
64/// input exceeds the cap. Uses partial sort (O(N)) via `select_nth_unstable_by`
65/// to avoid full O(N log N) sort. No-op when input length ≤ `top_k`.
66fn truncate_to_top_k_by_score<E: Send>(boxes: &mut Vec<(DetectBox, E)>, top_k: usize) {
67    if boxes.len() > top_k {
68        boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.total_cmp(&a.0.score));
69        boxes.truncate(top_k);
70    }
71}
72
73/// Quantized counterpart of [`truncate_to_top_k_by_score`]. Sorts on
74/// the raw quantized score (which preserves order under monotonic
75/// dequantization). Uses partial sort (O(N)) via `select_nth_unstable_by`.
76fn truncate_to_top_k_by_score_quant<S: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send>(
77    boxes: &mut Vec<(DetectBoxQuantized<S>, E)>,
78    top_k: usize,
79) {
80    if boxes.len() > top_k {
81        boxes.select_nth_unstable_by(top_k, |a, b| b.0.score.cmp(&a.0.score));
82        boxes.truncate(top_k);
83    }
84}
85
86/// Dispatches to the appropriate NMS function based on mode for float boxes.
87///
88/// `max_det` is the post-NMS detection cap; when present it lets the greedy
89/// inner loop break as soon as that many survivors are confirmed (the survivors
90/// are guaranteed to be the top-`max_det` by score because the input is sorted
91/// descending). Pass `None` to run the full O(N²) suppression.
92fn dispatch_nms_float(
93    nms: Option<Nms>,
94    iou: f32,
95    max_det: Option<usize>,
96    boxes: Vec<DetectBox>,
97) -> Vec<DetectBox> {
98    match nms {
99        Some(Nms::ClassAgnostic) => nms_float(iou, max_det, boxes),
100        Some(Nms::ClassAware) => nms_class_aware_float(iou, max_det, boxes),
101        None => boxes, // bypass NMS
102    }
103}
104
105/// Dispatches to the appropriate NMS function based on mode for float boxes
106/// with extra data.
107pub(super) fn dispatch_nms_extra_float<E: Send + Sync>(
108    nms: Option<Nms>,
109    iou: f32,
110    max_det: Option<usize>,
111    boxes: Vec<(DetectBox, E)>,
112) -> Vec<(DetectBox, E)> {
113    match nms {
114        Some(Nms::ClassAgnostic) => nms_extra_float(iou, max_det, boxes),
115        Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, max_det, boxes),
116        None => boxes, // bypass NMS
117    }
118}
119
120/// Dispatches to the appropriate NMS function based on mode for quantized
121/// boxes.
122fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
123    nms: Option<Nms>,
124    iou: f32,
125    max_det: Option<usize>,
126    boxes: Vec<DetectBoxQuantized<SCORE>>,
127) -> Vec<DetectBoxQuantized<SCORE>> {
128    match nms {
129        Some(Nms::ClassAgnostic) => nms_int(iou, max_det, boxes),
130        Some(Nms::ClassAware) => nms_class_aware_int(iou, max_det, boxes),
131        None => boxes, // bypass NMS
132    }
133}
134
135/// Dispatches to the appropriate NMS function based on mode for quantized boxes
136/// with extra data.
137fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
138    nms: Option<Nms>,
139    iou: f32,
140    max_det: Option<usize>,
141    boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
142) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
143    match nms {
144        Some(Nms::ClassAgnostic) => nms_extra_int(iou, max_det, boxes),
145        Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, max_det, boxes),
146        None => boxes, // bypass NMS
147    }
148}
149
150/// Detection cap helper for the public free `decode_yolo_*` wrappers.
151///
152/// Encodes the convention documented above: if the caller passed a non-empty
153/// `Vec`, that capacity acts as the per-call cap; otherwise fall back to
154/// [`DEFAULT_MAX_DETECTIONS`] so freshly-constructed `Vec::new()` outputs
155/// don't silently drop every detection.
156#[inline]
157fn cap_or_default<T>(v: &Vec<T>) -> usize {
158    if v.capacity() > 0 {
159        v.capacity()
160    } else {
161        DEFAULT_MAX_DETECTIONS
162    }
163}
164
165// ─── Public free decode_yolo_* convenience wrappers ────────────────────────
166//
167// Detection cap convention (applies to every `decode_yolo_*` free function
168// below):
169//
170// These functions are the convenience layer for callers that don't go
171// through `Decoder::decode()` (benches, FFI shims, ad-hoc test harnesses).
172// They use **`output_boxes.capacity()` as a per-call detection cap**:
173//
//   - When the caller passes `Vec::with_capacity(N)`, the post-NMS output
//     is truncated to at most `output_boxes.capacity()` detections (at
//     least `N`; `Vec::with_capacity` may reserve more than requested).
176//   - When the caller passes `Vec::new()` (capacity 0), the implementation
177//     falls back to the [`DEFAULT_MAX_DETECTIONS`] constant (300) so a
178//     freshly-constructed `Vec` doesn't silently drop every detection.
179//
180// This is intentionally **different** from the `Decoder::decode()` /
181// `Decoder::decode_proto()` contract, which bounds output count solely
182// by [`Decoder::max_det`] (set via `DecoderBuilder::with_max_det`,
183// default 300) regardless of the caller's `Vec` capacity (EDGEAI-1302).
184//
185// Use the `Decoder` API when you need explicit control over `max_det`,
186// schema-driven decoding, or EDGEAI-1303 normalization. Use these free
187// functions when you have raw tensors in hand and want a one-shot decode.
188
/// Decodes YOLO detection outputs from quantized tensors into detection boxes.
///
/// Thin wrapper that forwards every argument to [`impl_yolo_quant`] with the
/// box format fixed to `XYWH`.
///
/// Boxes are expected to be in XYWH format.
///
/// Expected shapes of inputs:
/// - output: (4 + num_classes, num_boxes)
///
/// `output` pairs the raw tensor with its [`Quantization`] parameters.
/// `nms` selects the suppression mode (`None` bypasses NMS).
///
/// `output_boxes` is cleared and refilled by the callee. Its capacity on
/// entry bounds the result count: a non-zero capacity is the cap, a zero
/// capacity falls back to [`DEFAULT_MAX_DETECTIONS`] (see
/// [`cap_or_default`]).
pub(crate) fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
    output: (ArrayView2<BOX>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
) where
    f32: AsPrimitive<BOX>,
{
    impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
}
209
/// Decodes YOLO detection outputs from float tensors into detection boxes.
///
/// Thin wrapper that forwards every argument to [`impl_yolo_float`] with the
/// box format fixed to `XYWH`.
///
/// Boxes are expected to be in XYWH format.
///
/// Expected shapes of inputs:
/// - output: (4 + num_classes, num_boxes)
///
/// `nms` selects the suppression mode (`None` bypasses NMS).
///
/// `output_boxes` is cleared and refilled by the callee. Its capacity on
/// entry bounds the result count: a non-zero capacity is the cap, a zero
/// capacity falls back to [`DEFAULT_MAX_DETECTIONS`] (see
/// [`cap_or_default`]).
pub(crate) fn decode_yolo_det_float<T>(
    output: ArrayView2<T>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
) where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
}
228
/// Test-only seg-det quantized decode shim.
///
/// Production callers go through [`crate::Decoder::decode`]; this wrapper
/// exists only so the parity tests in `lib.rs` and `yolo.rs` can compare the
/// kernel output against the Decoder output without duplicating the
/// `impl_yolo_segdet_quant` argument plumbing.
///
/// Boxes are expected to be in XYWH format. Expected shapes:
/// - `boxes`: `(4 + num_classes + num_protos, num_boxes)`
/// - `protos`: `(proto_height, proto_width, num_protos)`
///
/// Both tensors are paired with their [`Quantization`] parameters. The
/// pre-NMS candidate limit is fixed to [`MAX_NMS_CANDIDATES`] and the
/// post-NMS cap comes from [`cap_or_default`] applied to `output_boxes`.
///
/// # Errors
/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
/// Any error produced by `impl_yolo_segdet_quant` is returned unchanged.
#[cfg(test)]
pub(crate) fn decode_yolo_segdet_quant<
    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    protos: (ArrayView3<PROTO>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    f32: AsPrimitive<BOX>,
{
    // Pre-Decoder convenience wrapper: no schema-derived `normalized`
    // or `input_dims`, so the EDGEAI-1303 normalization is a no-op.
    // Callers that need that path should go through `DecoderBuilder`
    // with a schema.
    // The cap is read from the capacity before the callee touches the Vec.
    let cap = cap_or_default(output_boxes);
    impl_yolo_segdet_quant::<XYWH, _, _>(
        boxes,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        MAX_NMS_CANDIDATES,
        cap,
        // NOTE(review): the two `None`s are presumably the `normalized` and
        // `input_dims` arguments mentioned above — confirm against the kernel.
        None,
        None,
        output_boxes,
        output_masks,
    )
}
277
/// Test-only seg-det float decode shim. See [`decode_yolo_segdet_quant`].
///
/// Forwards to `impl_yolo_segdet_float` with the box format fixed to `XYWH`,
/// the pre-NMS candidate limit fixed to [`MAX_NMS_CANDIDATES`], and the
/// post-NMS cap taken from [`cap_or_default`] applied to `output_boxes`.
///
/// # Errors
/// Any error produced by `impl_yolo_segdet_float` is returned unchanged.
#[cfg(test)]
pub(crate) fn decode_yolo_segdet_float<T>(
    boxes: ArrayView2<T>,
    protos: ArrayView3<T>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    // Pre-Decoder convenience wrapper: schema-derived normalization is
    // not available here (see EDGEAI-1303 note in the quantized sibling).
    // The cap is read from the capacity before the callee touches the Vec.
    let cap = cap_or_default(output_boxes);
    impl_yolo_segdet_float::<XYWH, _, _>(
        boxes,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        MAX_NMS_CANDIDATES,
        cap,
        // NOTE(review): the two `None`s are presumably the schema-derived
        // `normalized` / `input_dims` arguments — confirm against the kernel.
        None,
        None,
        output_boxes,
        output_masks,
    )
}
310
/// Decodes YOLO split detection outputs from quantized tensors into detection
/// boxes.
///
/// Thin wrapper that forwards every argument to [`impl_yolo_split_quant`]
/// with the box format fixed to `XYWH`.
///
/// Boxes are expected to be in XYWH format.
///
/// Expected shapes of inputs:
/// - boxes: (4, num_boxes)
/// - scores: (num_classes, num_boxes)
///
/// Each tensor is paired with its own [`Quantization`] parameters.
/// `nms` selects the suppression mode (`None` bypasses NMS).
///
/// `output_boxes` is cleared and refilled by the callee. Its capacity on
/// entry bounds the result count: a non-zero capacity is the cap, a zero
/// capacity falls back to [`DEFAULT_MAX_DETECTIONS`] (see
/// [`cap_or_default`]).
///
/// # Panics
/// Panics if shapes don't match the expected dimensions.
pub(crate) fn decode_yolo_split_det_quant<
    BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    scores: (ArrayView2<SCORE>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
) where
    f32: AsPrimitive<SCORE>,
{
    impl_yolo_split_quant::<XYWH, _, _>(
        boxes,
        scores,
        score_threshold,
        iou_threshold,
        nms,
        output_boxes,
    );
}
344
/// Decodes YOLO split detection outputs from float tensors into detection
/// boxes.
///
/// Thin wrapper that forwards every argument to [`impl_yolo_split_float`]
/// with the box format fixed to `XYWH`.
///
/// Boxes are expected to be in XYWH format.
///
/// Expected shapes of inputs:
/// - boxes: (4, num_boxes)
/// - scores: (num_classes, num_boxes)
///
/// `nms` selects the suppression mode (`None` bypasses NMS).
///
/// `output_boxes` is cleared and refilled by the callee. Its capacity on
/// entry bounds the result count: a non-zero capacity is the cap, a zero
/// capacity falls back to [`DEFAULT_MAX_DETECTIONS`] (see
/// [`cap_or_default`]).
///
/// # Panics
/// Panics if shapes don't match the expected dimensions.
pub(crate) fn decode_yolo_split_det_float<T>(
    boxes: ArrayView2<T>,
    scores: ArrayView2<T>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
) where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    impl_yolo_split_float::<XYWH, _, _>(
        boxes,
        scores,
        score_threshold,
        iou_threshold,
        nms,
        output_boxes,
    );
}
376
377/// Decodes end-to-end YOLO detection outputs (post-NMS from model).
378/// Expects an array of shape `(6, N)`, where the first dimension (rows)
379/// corresponds to the 6 per-detection features
380/// `[x1, y1, x2, y2, conf, class]` and the second dimension (columns)
381/// indexes the `N` detections.
382/// Boxes are output directly without NMS (the model already applied NMS).
383///
384/// Coordinates may be normalized `[0, 1]` or absolute pixel values depending
385/// on the model configuration. The caller should check
386/// `decoder.normalized_boxes()` to determine which.
387///
388/// # Errors
389///
390/// Returns `DecoderError::InvalidShape` if `output` has fewer than 6 rows.
391pub(crate) fn decode_yolo_end_to_end_det_float<T>(
392    output: ArrayView2<T>,
393    score_threshold: f32,
394    output_boxes: &mut Vec<DetectBox>,
395) -> Result<(), crate::DecoderError>
396where
397    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
398    f32: AsPrimitive<T>,
399{
400    // Validate input shape: need at least 6 rows (x1, y1, x2, y2, conf, class)
401    if output.shape()[0] < 6 {
402        return Err(crate::DecoderError::InvalidShape(format!(
403            "End-to-end detection output requires at least 6 rows, got {}",
404            output.shape()[0]
405        )));
406    }
407
408    // Input shape: (6, N) -> transpose to (N, 4) for boxes and (N, 1) for scores
409    let boxes = output.slice(s![0..4, ..]).reversed_axes();
410    let scores = output.slice(s![4..5, ..]).reversed_axes();
411    let classes = output.slice(s![5, ..]);
412    let mut boxes =
413        postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
414    boxes.truncate(cap_or_default(output_boxes));
415    output_boxes.clear();
416    for (mut b, i) in boxes.into_iter() {
417        b.label = classes[i].as_() as usize;
418        output_boxes.push(b);
419    }
420    // No NMS — model output is already post-NMS
421    Ok(())
422}
423
424/// Decodes end-to-end YOLO detection + segmentation outputs (post-NMS from
425/// model).
426///
427/// Input shapes:
428/// - detection: (6 + num_protos, N) where rows are [x1, y1, x2, y2, conf,
429///   class, mask_coeff_0, ..., mask_coeff_31]
430/// - protos: (proto_height, proto_width, num_protos)
431///
432/// Boxes are output directly without NMS (model already applied NMS).
433/// Coordinates may be normalized [0,1] or pixel values depending on model
434/// config.
435///
436/// # Errors
437///
438/// Returns `DecoderError::InvalidShape` if:
439/// - output has fewer than 7 rows (6 base + at least 1 mask coefficient)
440/// - protos shape doesn't match mask coefficients count
441pub(crate) fn decode_yolo_end_to_end_segdet_float<T>(
442    output: ArrayView2<T>,
443    protos: ArrayView3<T>,
444    score_threshold: f32,
445    output_boxes: &mut Vec<DetectBox>,
446    output_masks: &mut Vec<crate::Segmentation>,
447) -> Result<(), crate::DecoderError>
448where
449    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
450    f32: AsPrimitive<T>,
451{
452    let (boxes, scores, classes, mask_coeff) =
453        postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
454    let cap = cap_or_default(output_boxes);
455    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
456        boxes,
457        scores,
458        classes,
459        score_threshold,
460        cap,
461    );
462
463    // No NMS — model output is already post-NMS
464
465    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
466}
467
468/// Decodes split end-to-end YOLO detection outputs (post-NMS from model).
469///
470/// Input shapes (after batch dim removed):
471/// - boxes: (4, N) — xyxy pixel coordinates
472/// - scores: (1, N) — confidence of the top class
473/// - classes: (1, N) — class index of the top class
474///
475/// Boxes are output directly without NMS (model already applied NMS).
476pub(crate) fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
477    boxes: ArrayView2<T>,
478    scores: ArrayView2<T>,
479    classes: ArrayView2<T>,
480    score_threshold: f32,
481    output_boxes: &mut Vec<DetectBox>,
482) -> Result<(), crate::DecoderError> {
483    let n = boxes.shape()[1];
484
485    let cap = cap_or_default(output_boxes);
486    output_boxes.clear();
487
488    let (boxes, scores, classes) = postprocess_yolo_split_end_to_end_det(boxes, scores, &classes)?;
489
490    for i in 0..n {
491        let score: f32 = scores[[i, 0]].as_();
492        if score < score_threshold {
493            continue;
494        }
495        if output_boxes.len() >= cap {
496            break;
497        }
498        output_boxes.push(DetectBox {
499            bbox: BoundingBox {
500                xmin: boxes[[i, 0]].as_(),
501                ymin: boxes[[i, 1]].as_(),
502                xmax: boxes[[i, 2]].as_(),
503                ymax: boxes[[i, 3]].as_(),
504            },
505            score,
506            label: classes[i].as_() as usize,
507        });
508    }
509    Ok(())
510}
511
512/// Decodes split end-to-end YOLO detection + segmentation outputs.
513///
514/// Input shapes (after batch dim removed):
515/// - boxes: (4, N) — xyxy pixel coordinates
516/// - scores: (1, N) — confidence
517/// - classes: (1, N) — class index
518/// - mask_coeff: (num_protos, N) — mask coefficients per detection
519/// - protos: (proto_h, proto_w, num_protos) — prototype masks
520#[allow(clippy::too_many_arguments)]
521pub(crate) fn decode_yolo_split_end_to_end_segdet_float<T>(
522    boxes: ArrayView2<T>,
523    scores: ArrayView2<T>,
524    classes: ArrayView2<T>,
525    mask_coeff: ArrayView2<T>,
526    protos: ArrayView3<T>,
527    score_threshold: f32,
528    output_boxes: &mut Vec<DetectBox>,
529    output_masks: &mut Vec<crate::Segmentation>,
530) -> Result<(), crate::DecoderError>
531where
532    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
533    f32: AsPrimitive<T>,
534{
535    let (boxes, scores, classes, mask_coeff) =
536        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
537    let cap = cap_or_default(output_boxes);
538    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
539        boxes,
540        scores,
541        classes,
542        score_threshold,
543        cap,
544    );
545
546    impl_yolo_split_segdet_process_masks(boxes, mask_coeff, protos, output_boxes, output_masks)
547}
548
549#[allow(clippy::type_complexity)]
550pub(crate) fn postprocess_yolo_end_to_end_segdet<'a, T>(
551    output: &'a ArrayView2<'_, T>,
552    num_protos: usize,
553) -> Result<
554    (
555        ArrayView2<'a, T>,
556        ArrayView2<'a, T>,
557        ArrayView1<'a, T>,
558        ArrayView2<'a, T>,
559    ),
560    crate::DecoderError,
561> {
562    // Validate input shape: need at least 7 rows (6 base + at least 1 mask coeff)
563    if output.shape()[0] < 7 {
564        return Err(crate::DecoderError::InvalidShape(format!(
565            "End-to-end segdet output requires at least 7 rows, got {}",
566            output.shape()[0]
567        )));
568    }
569
570    let num_mask_coeffs = output.shape()[0] - 6;
571    if num_mask_coeffs != num_protos {
572        return Err(crate::DecoderError::InvalidShape(format!(
573            "Mask coefficients count ({}) doesn't match protos count ({})",
574            num_mask_coeffs, num_protos
575        )));
576    }
577
578    // Input shape: (6+num_protos, N) -> transpose for postprocessing
579    let boxes = output.slice(s![0..4, ..]).reversed_axes();
580    let scores = output.slice(s![4..5, ..]).reversed_axes();
581    let classes = output.slice(s![5, ..]);
582    let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
583    Ok((boxes, scores, classes, mask_coeff))
584}
585
586/// Postprocess yolo split end to end det by reversing axes of boxes,
587/// scores, and flattening the class tensor.
588/// Expected input shapes:
589/// - boxes: (4, N)
590/// - scores: (1, N)
591/// - classes: (1, N)
592#[allow(clippy::type_complexity)]
593pub(crate) fn postprocess_yolo_split_end_to_end_det<'a, 'b, 'c, BOXES, SCORES, CLASS>(
594    boxes: ArrayView2<'a, BOXES>,
595    scores: ArrayView2<'b, SCORES>,
596    classes: &'c ArrayView2<CLASS>,
597) -> Result<
598    (
599        ArrayView2<'a, BOXES>,
600        ArrayView2<'b, SCORES>,
601        ArrayView1<'c, CLASS>,
602    ),
603    crate::DecoderError,
604> {
605    let num_boxes = boxes.shape()[1];
606    if boxes.shape()[0] != 4 {
607        return Err(crate::DecoderError::InvalidShape(format!(
608            "Split end-to-end box_coords must be 4, got {}",
609            boxes.shape()[0]
610        )));
611    }
612
613    if scores.shape()[0] != 1 {
614        return Err(crate::DecoderError::InvalidShape(format!(
615            "Split end-to-end scores num_classes must be 1, got {}",
616            scores.shape()[0]
617        )));
618    }
619
620    if classes.shape()[0] != 1 {
621        return Err(crate::DecoderError::InvalidShape(format!(
622            "Split end-to-end classes num_classes must be 1, got {}",
623            classes.shape()[0]
624        )));
625    }
626
627    if scores.shape()[1] != num_boxes {
628        return Err(crate::DecoderError::InvalidShape(format!(
629            "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
630            num_boxes,
631            scores.shape()[1]
632        )));
633    }
634
635    if classes.shape()[1] != num_boxes {
636        return Err(crate::DecoderError::InvalidShape(format!(
637            "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
638            num_boxes,
639            classes.shape()[1]
640        )));
641    }
642
643    let boxes = boxes.reversed_axes();
644    let scores = scores.reversed_axes();
645    let classes = classes.slice(s![0, ..]);
646    Ok((boxes, scores, classes))
647}
648
649/// Postprocess yolo split end to end segdet by reversing axes of boxes,
650/// scores, mask tensors and flattening the class tensor.
651#[allow(clippy::type_complexity)]
652pub(crate) fn postprocess_yolo_split_end_to_end_segdet<
653    'a,
654    'b,
655    'c,
656    'd,
657    BOXES,
658    SCORES,
659    CLASS,
660    MASK,
661>(
662    boxes: ArrayView2<'a, BOXES>,
663    scores: ArrayView2<'b, SCORES>,
664    classes: &'c ArrayView2<CLASS>,
665    mask_coeff: ArrayView2<'d, MASK>,
666) -> Result<
667    (
668        ArrayView2<'a, BOXES>,
669        ArrayView2<'b, SCORES>,
670        ArrayView1<'c, CLASS>,
671        ArrayView2<'d, MASK>,
672    ),
673    crate::DecoderError,
674> {
675    let num_boxes = boxes.shape()[1];
676    if boxes.shape()[0] != 4 {
677        return Err(crate::DecoderError::InvalidShape(format!(
678            "Split end-to-end box_coords must be 4, got {}",
679            boxes.shape()[0]
680        )));
681    }
682
683    if scores.shape()[0] != 1 {
684        return Err(crate::DecoderError::InvalidShape(format!(
685            "Split end-to-end scores num_classes must be 1, got {}",
686            scores.shape()[0]
687        )));
688    }
689
690    if classes.shape()[0] != 1 {
691        return Err(crate::DecoderError::InvalidShape(format!(
692            "Split end-to-end classes num_classes must be 1, got {}",
693            classes.shape()[0]
694        )));
695    }
696
697    if scores.shape()[1] != num_boxes {
698        return Err(crate::DecoderError::InvalidShape(format!(
699            "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
700            num_boxes,
701            scores.shape()[1]
702        )));
703    }
704
705    if classes.shape()[1] != num_boxes {
706        return Err(crate::DecoderError::InvalidShape(format!(
707            "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
708            num_boxes,
709            classes.shape()[1]
710        )));
711    }
712
713    if mask_coeff.shape()[1] != num_boxes {
714        return Err(crate::DecoderError::InvalidShape(format!(
715            "Split end-to-end mask_coeff must have same num_boxes as boxes ({}), got {}",
716            num_boxes,
717            mask_coeff.shape()[1]
718        )));
719    }
720
721    let boxes = boxes.reversed_axes();
722    let scores = scores.reversed_axes();
723    let classes = classes.slice(s![0, ..]);
724    let mask_coeff = mask_coeff.reversed_axes();
725    Ok((boxes, scores, classes, mask_coeff))
726}
727/// Internal implementation of YOLO decoding for quantized tensors.
728///
729/// Expected shapes of inputs:
730/// - output: (4 + num_classes, num_boxes)
731pub(crate) fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
732    output: (ArrayView2<T>, Quantization),
733    score_threshold: f32,
734    iou_threshold: f32,
735    nms: Option<Nms>,
736    output_boxes: &mut Vec<DetectBox>,
737) where
738    f32: AsPrimitive<T>,
739{
740    let _span = tracing::trace_span!("decode", mode = "quant_det").entered();
741    let (boxes, quant_boxes) = output;
742    let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
743
744    let boxes = {
745        let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
746        postprocess_boxes_quant::<B, _, _>(
747            score_threshold,
748            boxes_tensor,
749            scores_tensor,
750            quant_boxes,
751        )
752    };
753
754    let cap = cap_or_default(output_boxes);
755    let boxes = dispatch_nms_int(nms, iou_threshold, Some(cap), boxes);
756    // Detection cap convention (see `cap_or_default`). NMS already capped to
757    // `cap`; the `min` here is a redundant guard for non-NMS bypass mode.
758    let len = cap.min(boxes.len());
759    output_boxes.clear();
760    for b in boxes.iter().take(len) {
761        output_boxes.push(dequant_detect_box(b, quant_boxes));
762    }
763}
764
765/// Internal implementation of YOLO decoding for float tensors.
766///
767/// Expected shapes of inputs:
768/// - output: (4 + num_classes, num_boxes)
769pub(crate) fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
770    output: ArrayView2<T>,
771    score_threshold: f32,
772    iou_threshold: f32,
773    nms: Option<Nms>,
774    output_boxes: &mut Vec<DetectBox>,
775) where
776    f32: AsPrimitive<T>,
777{
778    let _span = tracing::trace_span!("decode", mode = "float_det").entered();
779    let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
780    let boxes =
781        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
782    let cap = cap_or_default(output_boxes);
783    let boxes = dispatch_nms_float(nms, iou_threshold, Some(cap), boxes);
784    // Detection cap convention (see `cap_or_default`). NMS already capped to
785    // `cap`; the `min` here is a redundant guard for non-NMS bypass mode.
786    let len = cap.min(boxes.len());
787    output_boxes.clear();
788    for b in boxes.into_iter().take(len) {
789        output_boxes.push(b);
790    }
791}
792
793/// Internal implementation of YOLO split detection decoding for quantized
794/// tensors.
795///
796/// Expected shapes of inputs:
797/// - boxes: (4, num_boxes)
798/// - scores: (num_classes, num_boxes)
799///
800/// # Panics
801/// Panics if shapes don't match the expected dimensions.
802pub(crate) fn impl_yolo_split_quant<
803    B: BBoxTypeTrait,
804    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
805    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
806>(
807    boxes: (ArrayView2<BOX>, Quantization),
808    scores: (ArrayView2<SCORE>, Quantization),
809    score_threshold: f32,
810    iou_threshold: f32,
811    nms: Option<Nms>,
812    output_boxes: &mut Vec<DetectBox>,
813) where
814    f32: AsPrimitive<SCORE>,
815{
816    let _span = tracing::trace_span!("decode", mode = "split_quant_det").entered();
817    let (boxes_tensor, quant_boxes) = boxes;
818    let (scores_tensor, quant_scores) = scores;
819
820    let boxes_tensor = boxes_tensor.reversed_axes();
821    let scores_tensor = scores_tensor.reversed_axes();
822
823    let boxes = {
824        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
825        postprocess_boxes_quant::<B, _, _>(
826            score_threshold,
827            boxes_tensor,
828            scores_tensor,
829            quant_boxes,
830        )
831    };
832
833    let cap = cap_or_default(output_boxes);
834    let boxes = dispatch_nms_int(nms, iou_threshold, Some(cap), boxes);
835    // Detection cap convention (see `cap_or_default`). NMS already capped to
836    // `cap`; the `min` here is a redundant guard for non-NMS bypass mode.
837    let len = cap.min(boxes.len());
838    output_boxes.clear();
839    for b in boxes.iter().take(len) {
840        output_boxes.push(dequant_detect_box(b, quant_scores));
841    }
842}
843
844/// Internal implementation of YOLO split detection decoding for float tensors.
845///
846/// Expected shapes of inputs:
847/// - boxes: (4, num_boxes)
848/// - scores: (num_classes, num_boxes)
849///
850/// # Panics
851/// Panics if shapes don't match the expected dimensions.
852pub(crate) fn impl_yolo_split_float<
853    B: BBoxTypeTrait,
854    BOX: Float + AsPrimitive<f32> + Send + Sync,
855    SCORE: Float + AsPrimitive<f32> + Send + Sync,
856>(
857    boxes_tensor: ArrayView2<BOX>,
858    scores_tensor: ArrayView2<SCORE>,
859    score_threshold: f32,
860    iou_threshold: f32,
861    nms: Option<Nms>,
862    output_boxes: &mut Vec<DetectBox>,
863) where
864    f32: AsPrimitive<SCORE>,
865{
866    let _span = tracing::trace_span!("decode", mode = "split_float_det").entered();
867    let boxes_tensor = boxes_tensor.reversed_axes();
868    let scores_tensor = scores_tensor.reversed_axes();
869    let boxes =
870        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
871    let cap = cap_or_default(output_boxes);
872    let boxes = dispatch_nms_float(nms, iou_threshold, Some(cap), boxes);
873    // Detection cap convention (see `cap_or_default`). NMS already capped to
874    // `cap`; the `min` here is a redundant guard for non-NMS bypass mode.
875    let len = cap.min(boxes.len());
876    output_boxes.clear();
877    for b in boxes.into_iter().take(len) {
878        output_boxes.push(b);
879    }
880}
881
882/// Divide each survivor's bbox by `(input_w, input_h)` when the schema
883/// declares `normalized: false`. Pixel-space box coords from the model
884/// are pulled into the canonical `[0, 1]` range expected by `protobox`
885/// and downstream callers — see EDGEAI-1303.
886///
887/// No-op when `normalized` is `Some(true)` / `None`, when `input_dims`
888/// is `None`, or when there are no survivors.
889#[inline]
890pub(crate) fn maybe_normalize_boxes_in_place(
891    boxes: &mut [(DetectBox, usize)],
892    normalized: Option<bool>,
893    input_dims: Option<(usize, usize)>,
894) {
895    if normalized != Some(false) {
896        return;
897    }
898    let Some((w, h)) = input_dims else {
899        return;
900    };
901    if w == 0 || h == 0 {
902        return;
903    }
904    let inv_w = 1.0 / w as f32;
905    let inv_h = 1.0 / h as f32;
906    for (b, _) in boxes.iter_mut() {
907        b.bbox.xmin *= inv_w;
908        b.bbox.ymin *= inv_h;
909        b.bbox.xmax *= inv_w;
910        b.bbox.ymax *= inv_h;
911    }
912}
913
914/// Internal implementation of YOLO detection segmentation decoding for
915/// quantized tensors.
916///
917/// Expected shapes of inputs:
918/// - boxes: (4 + num_classes + num_protos, num_boxes)
919/// - protos: (proto_height, proto_width, num_protos)
920///
921/// # Errors
922/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
923#[allow(clippy::too_many_arguments)]
924pub(crate) fn impl_yolo_segdet_quant<
925    B: BBoxTypeTrait,
926    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
927    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
928>(
929    boxes: (ArrayView2<BOX>, Quantization),
930    protos: (ArrayView3<PROTO>, Quantization),
931    score_threshold: f32,
932    iou_threshold: f32,
933    nms: Option<Nms>,
934    pre_nms_top_k: usize,
935    max_det: usize,
936    normalized: Option<bool>,
937    input_dims: Option<(usize, usize)>,
938    output_boxes: &mut Vec<DetectBox>,
939    output_masks: &mut Vec<Segmentation>,
940) -> Result<(), crate::DecoderError>
941where
942    f32: AsPrimitive<BOX>,
943{
944    let (boxes, quant_boxes) = boxes;
945    let num_protos = protos.0.dim().2;
946
947    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
948    let mut boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
949        (boxes_tensor, quant_boxes),
950        (scores_tensor, quant_boxes),
951        score_threshold,
952        iou_threshold,
953        nms,
954        pre_nms_top_k,
955        max_det,
956    );
957    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
958
959    impl_yolo_split_segdet_quant_process_masks::<_, _>(
960        boxes,
961        (mask_tensor, quant_boxes),
962        protos,
963        output_boxes,
964        output_masks,
965    )
966}
967
968/// Internal implementation of YOLO detection segmentation decoding for
969/// float tensors.
970///
971/// Expected shapes of inputs:
972/// - boxes: (4 + num_classes + num_protos, num_boxes)
973/// - protos: (proto_height, proto_width, num_protos)
974///
975/// # Panics
976/// Panics if shapes don't match the expected dimensions.
977#[allow(clippy::too_many_arguments)]
978pub(crate) fn impl_yolo_segdet_float<
979    B: BBoxTypeTrait,
980    BOX: Float + AsPrimitive<f32> + Send + Sync,
981    PROTO: Float + AsPrimitive<f32> + Send + Sync,
982>(
983    boxes: ArrayView2<BOX>,
984    protos: ArrayView3<PROTO>,
985    score_threshold: f32,
986    iou_threshold: f32,
987    nms: Option<Nms>,
988    pre_nms_top_k: usize,
989    max_det: usize,
990    normalized: Option<bool>,
991    input_dims: Option<(usize, usize)>,
992    output_boxes: &mut Vec<DetectBox>,
993    output_masks: &mut Vec<Segmentation>,
994) -> Result<(), crate::DecoderError>
995where
996    f32: AsPrimitive<BOX>,
997{
998    let num_protos = protos.dim().2;
999    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1000    let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1001        boxes_tensor,
1002        scores_tensor,
1003        score_threshold,
1004        iou_threshold,
1005        nms,
1006        pre_nms_top_k,
1007        max_det,
1008    );
1009    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1010    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1011}
1012
1013pub(crate) fn impl_yolo_segdet_get_boxes<
1014    B: BBoxTypeTrait,
1015    BOX: Float + AsPrimitive<f32> + Send + Sync,
1016    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1017>(
1018    boxes_tensor: ArrayView2<BOX>,
1019    scores_tensor: ArrayView2<SCORE>,
1020    score_threshold: f32,
1021    iou_threshold: f32,
1022    nms: Option<Nms>,
1023    pre_nms_top_k: usize,
1024    max_det: usize,
1025) -> Vec<(DetectBox, usize)>
1026where
1027    f32: AsPrimitive<SCORE>,
1028{
1029    let span = tracing::trace_span!(
1030        "decode",
1031        n_candidates = tracing::field::Empty,
1032        n_after_topk = tracing::field::Empty,
1033        n_after_nms = tracing::field::Empty,
1034        n_detections = tracing::field::Empty,
1035    );
1036    let _guard = span.enter();
1037
1038    let mut boxes = {
1039        let _s = tracing::trace_span!("score_filter").entered();
1040        postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor)
1041    };
1042    span.record("n_candidates", boxes.len());
1043
1044    if nms.is_some() {
1045        let _s = tracing::trace_span!("top_k", k = pre_nms_top_k).entered();
1046        truncate_to_top_k_by_score(&mut boxes, pre_nms_top_k);
1047    }
1048    span.record("n_after_topk", boxes.len());
1049
1050    let mut boxes = {
1051        let _s = tracing::trace_span!("nms").entered();
1052        dispatch_nms_extra_float(nms, iou_threshold, Some(max_det), boxes)
1053    };
1054    span.record("n_after_nms", boxes.len());
1055
1056    // NMS already capped to `max_det`; the trailing sort+truncate is a
1057    // redundant guard for the bypass-NMS path (`nms = None`).
1058    boxes.sort_unstable_by(|a, b| b.0.score.total_cmp(&a.0.score));
1059    boxes.truncate(max_det);
1060    span.record("n_detections", boxes.len());
1061
1062    boxes
1063}
1064
1065pub(crate) fn impl_yolo_end_to_end_segdet_get_boxes<
1066    B: BBoxTypeTrait,
1067    BOX: Float + AsPrimitive<f32> + Send + Sync,
1068    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1069    CLASS: AsPrimitive<f32> + Send + Sync,
1070>(
1071    boxes: ArrayView2<BOX>,
1072    scores: ArrayView2<SCORE>,
1073    classes: ArrayView1<CLASS>,
1074    score_threshold: f32,
1075    max_boxes: usize,
1076) -> Vec<(DetectBox, usize)>
1077where
1078    f32: AsPrimitive<SCORE>,
1079{
1080    let mut boxes = postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes, scores);
1081    boxes.truncate(max_boxes);
1082    for (b, ind) in &mut boxes {
1083        b.label = classes[*ind].as_().round() as usize;
1084    }
1085    boxes
1086}
1087
1088pub(crate) fn impl_yolo_split_segdet_process_masks<
1089    MASK: Float + AsPrimitive<f32> + Send + Sync,
1090    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1091>(
1092    boxes: Vec<(DetectBox, usize)>,
1093    masks_tensor: ArrayView2<MASK>,
1094    protos_tensor: ArrayView3<PROTO>,
1095    output_boxes: &mut Vec<DetectBox>,
1096    output_masks: &mut Vec<Segmentation>,
1097) -> Result<(), crate::DecoderError> {
1098    let _span = tracing::trace_span!("process_masks", n = boxes.len(), mode = "float").entered();
1099    // Boxes are already bounded by the upstream `max_det` cap from
1100    // `_get_boxes`; no second cap is needed here (EDGEAI-1302).
1101
1102    let boxes = decode_segdet_f32(boxes, masks_tensor, protos_tensor)?;
1103    output_boxes.clear();
1104    output_masks.clear();
1105    for (b, roi, m) in boxes.into_iter() {
1106        output_boxes.push(b);
1107        output_masks.push(Segmentation {
1108            xmin: roi.xmin,
1109            ymin: roi.ymin,
1110            xmax: roi.xmax,
1111            ymax: roi.ymax,
1112            segmentation: m,
1113        });
1114    }
1115    Ok(())
1116}
1117/// Expected input shapes:
1118/// - boxes_tensor: (num_boxes, 4)
1119/// - scores_tensor: (num_boxes, num_classes)
1120pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
1121    B: BBoxTypeTrait,
1122    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1123    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1124>(
1125    boxes: (ArrayView2<BOX>, Quantization),
1126    scores: (ArrayView2<SCORE>, Quantization),
1127    score_threshold: f32,
1128    iou_threshold: f32,
1129    nms: Option<Nms>,
1130    pre_nms_top_k: usize,
1131    max_det: usize,
1132) -> Vec<(DetectBox, usize)>
1133where
1134    f32: AsPrimitive<SCORE>,
1135{
1136    let (boxes_tensor, quant_boxes) = boxes;
1137    let (scores_tensor, quant_scores) = scores;
1138
1139    let span = tracing::trace_span!(
1140        "decode",
1141        n_candidates = tracing::field::Empty,
1142        n_after_topk = tracing::field::Empty,
1143        n_after_nms = tracing::field::Empty,
1144        n_detections = tracing::field::Empty,
1145    );
1146    let _guard = span.enter();
1147
1148    let mut boxes = {
1149        let _s = tracing::trace_span!("score_filter").entered();
1150        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
1151        postprocess_boxes_index_quant::<B, _, _>(
1152            score_threshold,
1153            boxes_tensor,
1154            scores_tensor,
1155            quant_boxes,
1156        )
1157    };
1158    span.record("n_candidates", boxes.len());
1159
1160    if nms.is_some() {
1161        let _s = tracing::trace_span!("top_k", k = pre_nms_top_k).entered();
1162        truncate_to_top_k_by_score_quant(&mut boxes, pre_nms_top_k);
1163    }
1164    span.record("n_after_topk", boxes.len());
1165
1166    let mut boxes = {
1167        let _s = tracing::trace_span!("nms").entered();
1168        dispatch_nms_extra_int(nms, iou_threshold, Some(max_det), boxes)
1169    };
1170    span.record("n_after_nms", boxes.len());
1171
1172    // NMS already capped to `max_det`; the trailing sort+truncate is a
1173    // redundant guard for the bypass-NMS path (`nms = None`).
1174    boxes.sort_unstable_by(|a, b| b.0.score.cmp(&a.0.score));
1175    boxes.truncate(max_det);
1176    let result: Vec<_> = {
1177        let _s = tracing::trace_span!("box_dequant", n = boxes.len()).entered();
1178        boxes
1179            .into_iter()
1180            .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
1181            .collect()
1182    };
1183    span.record("n_detections", result.len());
1184
1185    result
1186}
1187
1188pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
1189    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1190    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
1191>(
1192    boxes: Vec<(DetectBox, usize)>,
1193    mask_coeff: (ArrayView2<MASK>, Quantization),
1194    protos: (ArrayView3<PROTO>, Quantization),
1195    output_boxes: &mut Vec<DetectBox>,
1196    output_masks: &mut Vec<Segmentation>,
1197) -> Result<(), crate::DecoderError> {
1198    let _span = tracing::trace_span!("process_masks", n = boxes.len(), mode = "quant").entered();
1199    let (masks, quant_masks) = mask_coeff;
1200    let (protos, quant_protos) = protos;
1201
1202    // Boxes are already bounded by the upstream `max_det` cap from
1203    // `_get_boxes`; no second cap is needed here (EDGEAI-1302).
1204
1205    let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
1206    output_boxes.clear();
1207    output_masks.clear();
1208    for (b, roi, m) in boxes.into_iter() {
1209        output_boxes.push(b);
1210        output_masks.push(Segmentation {
1211            xmin: roi.xmin,
1212            ymin: roi.ymin,
1213            xmax: roi.xmax,
1214            ymax: roi.ymax,
1215            segmentation: m,
1216        });
1217    }
1218    Ok(())
1219}
1220
1221/// Internal implementation of YOLO split detection segmentation decoding for
1222/// float tensors.
1223///
1224/// Expected shapes of inputs:
1225/// - boxes_tensor: (4, num_boxes)
1226/// - scores_tensor: (num_classes, num_boxes)
1227/// - mask_tensor: (num_protos, num_boxes)
1228/// - protos: (proto_height, proto_width, num_protos)
1229///
1230/// # Errors
1231/// Returns `DecoderError::InvalidShape` if bounding boxes are not normalized.
1232#[allow(clippy::too_many_arguments)]
1233pub(crate) fn impl_yolo_split_segdet_float<
1234    B: BBoxTypeTrait,
1235    BOX: Float + AsPrimitive<f32> + Send + Sync,
1236    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1237    MASK: Float + AsPrimitive<f32> + Send + Sync,
1238    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1239>(
1240    boxes_tensor: ArrayView2<BOX>,
1241    scores_tensor: ArrayView2<SCORE>,
1242    mask_tensor: ArrayView2<MASK>,
1243    protos: ArrayView3<PROTO>,
1244    score_threshold: f32,
1245    iou_threshold: f32,
1246    nms: Option<Nms>,
1247    pre_nms_top_k: usize,
1248    max_det: usize,
1249    normalized: Option<bool>,
1250    input_dims: Option<(usize, usize)>,
1251    output_boxes: &mut Vec<DetectBox>,
1252    output_masks: &mut Vec<Segmentation>,
1253) -> Result<(), crate::DecoderError>
1254where
1255    f32: AsPrimitive<SCORE>,
1256{
1257    let (boxes_tensor, scores_tensor, mask_tensor) =
1258        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1259
1260    let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1261        boxes_tensor,
1262        scores_tensor,
1263        score_threshold,
1264        iou_threshold,
1265        nms,
1266        pre_nms_top_k,
1267        max_det,
1268    );
1269    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1270    impl_yolo_split_segdet_process_masks(boxes, mask_tensor, protos, output_boxes, output_masks)
1271}
1272
1273// ---------------------------------------------------------------------------
1274// Proto-extraction variants: return ProtoData instead of materialized masks
1275// ---------------------------------------------------------------------------
1276
1277/// Proto-extraction variant of `impl_yolo_segdet_quant`.
1278/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1279#[allow(clippy::too_many_arguments)]
1280pub(crate) fn impl_yolo_segdet_quant_proto<
1281    B: BBoxTypeTrait,
1282    BOX: PrimInt
1283        + AsPrimitive<i64>
1284        + AsPrimitive<i128>
1285        + AsPrimitive<f32>
1286        + AsPrimitive<i8>
1287        + Send
1288        + Sync,
1289    PROTO: PrimInt
1290        + AsPrimitive<i64>
1291        + AsPrimitive<i128>
1292        + AsPrimitive<f32>
1293        + AsPrimitive<i8>
1294        + Send
1295        + Sync,
1296>(
1297    boxes: (ArrayView2<BOX>, Quantization),
1298    protos: (ArrayView3<PROTO>, Quantization),
1299    score_threshold: f32,
1300    iou_threshold: f32,
1301    nms: Option<Nms>,
1302    pre_nms_top_k: usize,
1303    max_det: usize,
1304    normalized: Option<bool>,
1305    input_dims: Option<(usize, usize)>,
1306    output_boxes: &mut Vec<DetectBox>,
1307) -> ProtoData
1308where
1309    f32: AsPrimitive<BOX>,
1310{
1311    let (boxes_arr, quant_boxes) = boxes;
1312    let (protos_arr, quant_protos) = protos;
1313    let num_protos = protos_arr.dim().2;
1314
1315    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr, num_protos);
1316
1317    let mut det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1318        (boxes_tensor, quant_boxes),
1319        (scores_tensor, quant_boxes),
1320        score_threshold,
1321        iou_threshold,
1322        nms,
1323        pre_nms_top_k,
1324        max_det,
1325    );
1326    maybe_normalize_boxes_in_place(&mut det_indices, normalized, input_dims);
1327
1328    extract_proto_data_quant(
1329        det_indices,
1330        mask_tensor,
1331        quant_boxes,
1332        protos_arr,
1333        quant_protos,
1334        output_boxes,
1335    )
1336}
1337
1338/// Proto-extraction variant of `impl_yolo_segdet_float`.
1339/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1340#[allow(clippy::too_many_arguments)]
1341pub(crate) fn impl_yolo_segdet_float_proto<
1342    B: BBoxTypeTrait,
1343    BOX: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1344    PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1345>(
1346    boxes: ArrayView2<BOX>,
1347    protos: ArrayView3<PROTO>,
1348    score_threshold: f32,
1349    iou_threshold: f32,
1350    nms: Option<Nms>,
1351    pre_nms_top_k: usize,
1352    max_det: usize,
1353    normalized: Option<bool>,
1354    input_dims: Option<(usize, usize)>,
1355    output_boxes: &mut Vec<DetectBox>,
1356) -> ProtoData
1357where
1358    f32: AsPrimitive<BOX>,
1359{
1360    let num_protos = protos.dim().2;
1361    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes, num_protos);
1362
1363    let mut boxes = impl_yolo_segdet_get_boxes::<B, _, _>(
1364        boxes_tensor,
1365        scores_tensor,
1366        score_threshold,
1367        iou_threshold,
1368        nms,
1369        pre_nms_top_k,
1370        max_det,
1371    );
1372    maybe_normalize_boxes_in_place(&mut boxes, normalized, input_dims);
1373
1374    extract_proto_data_float(boxes, mask_tensor, protos, output_boxes)
1375}
1376
1377/// Proto-extraction variant of `impl_yolo_split_segdet_float`.
1378/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1379#[allow(clippy::too_many_arguments)]
1380pub(crate) fn impl_yolo_split_segdet_float_proto<
1381    B: BBoxTypeTrait,
1382    BOX: Float + AsPrimitive<f32> + Send + Sync,
1383    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1384    MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1385    PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1386>(
1387    boxes_tensor: ArrayView2<BOX>,
1388    scores_tensor: ArrayView2<SCORE>,
1389    mask_tensor: ArrayView2<MASK>,
1390    protos: ArrayView3<PROTO>,
1391    score_threshold: f32,
1392    iou_threshold: f32,
1393    nms: Option<Nms>,
1394    pre_nms_top_k: usize,
1395    max_det: usize,
1396    output_boxes: &mut Vec<DetectBox>,
1397) -> ProtoData
1398where
1399    f32: AsPrimitive<SCORE>,
1400{
1401    let (boxes_tensor, scores_tensor, mask_tensor) =
1402        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);
1403    let det_indices = impl_yolo_segdet_get_boxes::<B, _, _>(
1404        boxes_tensor,
1405        scores_tensor,
1406        score_threshold,
1407        iou_threshold,
1408        nms,
1409        pre_nms_top_k,
1410        max_det,
1411    );
1412
1413    extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1414}
1415
1416/// Proto-extraction variant of `decode_yolo_end_to_end_segdet_float`.
1417pub(crate) fn decode_yolo_end_to_end_segdet_float_proto<T>(
1418    output: ArrayView2<T>,
1419    protos: ArrayView3<T>,
1420    score_threshold: f32,
1421    output_boxes: &mut Vec<DetectBox>,
1422) -> Result<ProtoData, crate::DecoderError>
1423where
1424    T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1425    f32: AsPrimitive<T>,
1426{
1427    let (boxes, scores, classes, mask_coeff) =
1428        postprocess_yolo_end_to_end_segdet(&output, protos.dim().2)?;
1429    let cap = cap_or_default(output_boxes);
1430    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1431        boxes,
1432        scores,
1433        classes,
1434        score_threshold,
1435        cap,
1436    );
1437
1438    Ok(extract_proto_data_float(
1439        boxes,
1440        mask_coeff,
1441        protos,
1442        output_boxes,
1443    ))
1444}
1445
1446/// Proto-extraction variant of `decode_yolo_split_end_to_end_segdet_float`.
1447#[allow(clippy::too_many_arguments)]
1448pub(crate) fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1449    boxes: ArrayView2<T>,
1450    scores: ArrayView2<T>,
1451    classes: ArrayView2<T>,
1452    mask_coeff: ArrayView2<T>,
1453    protos: ArrayView3<T>,
1454    score_threshold: f32,
1455    output_boxes: &mut Vec<DetectBox>,
1456) -> Result<ProtoData, crate::DecoderError>
1457where
1458    T: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1459    f32: AsPrimitive<T>,
1460{
1461    let (boxes, scores, classes, mask_coeff) =
1462        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
1463    let cap = cap_or_default(output_boxes);
1464    let boxes = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
1465        boxes,
1466        scores,
1467        classes,
1468        score_threshold,
1469        cap,
1470    );
1471
1472    Ok(extract_proto_data_float(
1473        boxes,
1474        mask_coeff,
1475        protos,
1476        output_boxes,
1477    ))
1478}
1479
1480/// Helper: extract ProtoData from float mask coefficients + protos.
1481///
1482/// Builds [`ProtoData`] with both `protos` and `mask_coefficients` as
1483/// [`edgefirst_tensor::TensorDyn`]. Preserves the native element type for
1484/// `f16` and `f32`; narrows `f64` to `f32` (there is no native f64 kernel
1485/// path). `mask_coefficients` shape is `[num_detections, num_protos]`.
1486pub(super) fn extract_proto_data_float<
1487    MASK: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1488    PROTO: Float + AsPrimitive<f32> + Copy + Send + Sync + FloatProtoElem,
1489>(
1490    det_indices: Vec<(DetectBox, usize)>,
1491    mask_tensor: ArrayView2<MASK>,
1492    protos: ArrayView3<PROTO>,
1493    output_boxes: &mut Vec<DetectBox>,
1494) -> ProtoData {
1495    let _span = tracing::trace_span!(
1496        "extract_proto",
1497        n = det_indices.len(),
1498        num_protos = mask_tensor.ncols(),
1499        layout = "nhwc",
1500    )
1501    .entered();
1502
1503    let num_protos = mask_tensor.ncols();
1504    let n = det_indices.len();
1505
1506    // Per-detection coefficients packed row-major into a contiguous buffer,
1507    // preserving the source dtype. Shape: [N, num_protos] — N=0 is permitted
1508    // (tracker path emits no detections this frame) since Mem-backed tensors
1509    // accept zero-element shapes as "empty collection" sentinels.
1510    let mut coeff_rows: Vec<MASK> = Vec::with_capacity(n * num_protos);
1511    output_boxes.clear();
1512    for (det, idx) in det_indices {
1513        output_boxes.push(det);
1514        let row = mask_tensor.row(idx);
1515        coeff_rows.extend(row.iter().copied());
1516    }
1517
1518    let mask_coefficients = MASK::slice_into_tensor_dyn(&coeff_rows, &[n, num_protos])
1519        .expect("allocating mask_coefficients TensorDyn");
1520    let protos_tensor =
1521        PROTO::arrayview3_into_tensor_dyn(protos).expect("allocating protos TensorDyn");
1522
1523    ProtoData {
1524        mask_coefficients,
1525        protos: protos_tensor,
1526        layout: ProtoLayout::Nhwc,
1527    }
1528}
1529
1530/// Helper: extract ProtoData from quantized mask coefficients + protos.
1531///
1532/// Dequantizes mask coefficients to f32 at extraction (one-time cost on a
1533/// `num_detections * num_protos` slice) and keeps protos in raw i8,
1534/// attaching the dequantization params as
1535/// [`edgefirst_tensor::Quantization::per_tensor`] metadata on the proto
1536/// tensor. The GPU shader / CPU kernel reads `protos.quantization()` and
1537/// dequantizes per-texel.
1538pub(crate) fn extract_proto_data_quant<
1539    MASK: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync,
1540    PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
1541>(
1542    det_indices: Vec<(DetectBox, usize)>,
1543    mask_tensor: ArrayView2<MASK>,
1544    quant_masks: Quantization,
1545    protos: ArrayView3<PROTO>,
1546    quant_protos: Quantization,
1547    output_boxes: &mut Vec<DetectBox>,
1548) -> ProtoData {
1549    use edgefirst_tensor::{Tensor, TensorDyn, TensorMapTrait, TensorMemory, TensorTrait};
1550
1551    let span = tracing::trace_span!(
1552        "extract_proto",
1553        n = det_indices.len(),
1554        num_protos = tracing::field::Empty,
1555        layout = tracing::field::Empty,
1556    );
1557    let _guard = span.enter();
1558
1559    let num_protos = mask_tensor.ncols();
1560    let n = det_indices.len();
1561    span.record("num_protos", num_protos);
1562
1563    // Fast path: when no detections survive NMS, skip the expensive proto
1564    // tensor copy (819KB for 160×160×32). Allocate with the correct shape
1565    // (preserving the documented ProtoData.protos shape contract) but skip
1566    // copying from the source tensor — the zeroed allocation is sufficient
1567    // since materialize_masks early-returns on empty detect slices.
1568    if n == 0 {
1569        output_boxes.clear();
1570        let (h, w, k) = protos.dim();
1571
1572        // Detect physical layout (same logic as the normal path).
1573        let (proto_shape, proto_layout) = if std::any::TypeId::of::<PROTO>()
1574            == std::any::TypeId::of::<i8>()
1575        {
1576            if protos.is_standard_layout() {
1577                (&[h, w, k][..], ProtoLayout::Nhwc)
1578            } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1579                (&[k, h, w][..], ProtoLayout::Nchw)
1580            } else {
1581                (&[h, w, k][..], ProtoLayout::Nhwc)
1582            }
1583        } else {
1584            (&[h, w, k][..], ProtoLayout::Nhwc)
1585        };
1586
1587        let coeff_tensor = Tensor::<i8>::new(&[0, num_protos], Some(TensorMemory::Mem), None)
1588            .expect("allocating empty mask_coefficients tensor");
1589        let coeff_quant =
1590            edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1591        let coeff_tensor = coeff_tensor
1592            .with_quantization(coeff_quant)
1593            .expect("per-tensor quantization on mask coefficients");
1594        let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
1595            .expect("allocating protos tensor");
1596        let tensor_quant =
1597            edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1598        let protos_tensor = protos_tensor
1599            .with_quantization(tensor_quant)
1600            .expect("per-tensor quantization on protos tensor");
1601        return ProtoData {
1602            mask_coefficients: TensorDyn::I8(coeff_tensor),
1603            protos: TensorDyn::I8(protos_tensor),
1604            layout: proto_layout,
1605        };
1606    }
1607
1608    // Keep mask coefficients in raw i8 with quantization metadata attached.
1609    // Consumers that need f32 can dequantize on the fly; the scaled-path
1610    // integer kernel uses raw i8 directly for i8×i8→i32 dot products.
1611    let mut coeff_i8 = Vec::<i8>::with_capacity(n * num_protos);
1612    output_boxes.clear();
1613    for (det, idx) in det_indices {
1614        output_boxes.push(det);
1615        let row = mask_tensor.row(idx);
1616        coeff_i8.extend(row.iter().map(|v| {
1617            let v_i8: i8 = v.as_();
1618            v_i8
1619        }));
1620    }
1621
1622    // Shape `[n, num_protos]` with n=0 is permitted (tracker path emits no
1623    // fresh detections this frame) via the Mem-backed zero-size allowance.
1624    let coeff_tensor = Tensor::<i8>::new(&[n, num_protos], Some(TensorMemory::Mem), None)
1625        .expect("allocating mask_coefficients tensor");
1626    if n > 0 {
1627        let mut m = coeff_tensor
1628            .map()
1629            .expect("mapping mask_coefficients tensor");
1630        m.as_mut_slice().copy_from_slice(&coeff_i8);
1631    }
1632    let coeff_quant =
1633        edgefirst_tensor::Quantization::per_tensor(quant_masks.scale, quant_masks.zero_point);
1634    let coeff_tensor = coeff_tensor
1635        .with_quantization(coeff_quant)
1636        .expect("per-tensor quantization on mask coefficients");
1637    let mask_coefficients = TensorDyn::I8(coeff_tensor);
1638
1639    // Keep protos in raw i8 — consumers dequantize via protos.quantization().
1640    // When PROTO is already i8, detect layout and copy efficiently without
1641    // transposing. The mask materialisation kernels dispatch on the layout.
1642    let (h, w, k) = protos.dim();
1643
1644    // Determine physical layout and copy strategy.
1645    let (proto_shape, proto_layout) =
1646        if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1647            if protos.is_standard_layout() {
1648                // Already NHWC [H, W, K] in contiguous memory.
1649                (&[h, w, k][..], ProtoLayout::Nhwc)
1650            } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1651                // NCHW reinterpreted as NHWC via stride swap. Physical storage
1652                // is [K, H, W] contiguous. Keep in NCHW — eliminates the costly
1653                // 3.1ms transpose entirely.
1654                (&[k, h, w][..], ProtoLayout::Nchw)
1655            } else {
1656                // Unknown layout — fall back to iter copy as NHWC.
1657                (&[h, w, k][..], ProtoLayout::Nhwc)
1658            }
1659        } else {
1660            (&[h, w, k][..], ProtoLayout::Nhwc)
1661        };
1662
1663    let protos_tensor = Tensor::<i8>::new(proto_shape, Some(TensorMemory::Mem), None)
1664        .expect("allocating protos tensor");
1665    {
1666        let mut m = protos_tensor.map().expect("mapping protos tensor");
1667        let dst = m.as_mut_slice();
1668        if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
1669            // SAFETY: PROTO == i8 checked via TypeId; cast slice view is
1670            // size/alignment-compatible by construction.
1671            if protos.is_standard_layout() {
1672                let src: &[i8] = unsafe {
1673                    std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len())
1674                };
1675                dst.copy_from_slice(src);
1676            } else if protos.ndim() == 3 && protos.strides() == [w as isize, 1, (h * w) as isize] {
1677                // NCHW physical layout — sequential copy WITHOUT transpose.
1678                // This saves ~3.1ms on A53/A55 by avoiding the tiled
1679                // NCHW→NHWC transpose of the 819KB proto buffer.
1680                let total = h * w * k;
1681                // SAFETY: ArrayView was constructed from a contiguous slice of
1682                // `total` elements. as_ptr() points to the base of that slice.
1683                let src: &[i8] =
1684                    unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, total) };
1685                dst.copy_from_slice(src);
1686            } else {
1687                for (d, s) in dst.iter_mut().zip(protos.iter()) {
1688                    let v_i8: i8 = s.as_();
1689                    *d = v_i8;
1690                }
1691            }
1692        } else {
1693            for (d, s) in dst.iter_mut().zip(protos.iter()) {
1694                let v_i8: i8 = s.as_();
1695                *d = v_i8;
1696            }
1697        }
1698    }
1699    let tensor_quant =
1700        edgefirst_tensor::Quantization::per_tensor(quant_protos.scale, quant_protos.zero_point);
1701    let protos_tensor = protos_tensor
1702        .with_quantization(tensor_quant)
1703        .expect("per-tensor quantization on new Tensor<i8>");
1704
1705    span.record("layout", tracing::field::debug(&proto_layout));
1706
1707    ProtoData {
1708        mask_coefficients,
1709        protos: TensorDyn::I8(protos_tensor),
1710        layout: proto_layout,
1711    }
1712}
1713
/// Per-float-dtype construction of a [`TensorDyn`] from a flat slice / 3-D
/// `ArrayView`. Replaces the old `IntoProtoTensor` trait. Each implementor
/// either passes its element type straight to `Tensor::from_slice` /
/// `Tensor::from_arrayview3`, or narrows `f64` to `f32` (no native f64 kernel
/// path exists).
pub trait FloatProtoElem: Copy + 'static {
    /// Builds a dynamically-typed tensor from the flat `values` slice, using
    /// `shape` as the tensor dimensions.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the underlying `edgefirst_tensor`
    /// constructor reports.
    fn slice_into_tensor_dyn(
        values: &[Self],
        shape: &[usize],
    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;

    /// Builds a dynamically-typed tensor from a 3-D array view.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the underlying `edgefirst_tensor`
    /// constructor reports.
    fn arrayview3_into_tensor_dyn(
        view: ArrayView3<'_, Self>,
    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn>;
}
1729
1730impl FloatProtoElem for f32 {
1731    fn slice_into_tensor_dyn(
1732        values: &[f32],
1733        shape: &[usize],
1734    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1735        edgefirst_tensor::Tensor::<f32>::from_slice(values, shape)
1736            .map(edgefirst_tensor::TensorDyn::F32)
1737    }
1738    fn arrayview3_into_tensor_dyn(
1739        view: ArrayView3<'_, f32>,
1740    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1741        edgefirst_tensor::Tensor::<f32>::from_arrayview3(view).map(edgefirst_tensor::TensorDyn::F32)
1742    }
1743}
1744
1745impl FloatProtoElem for half::f16 {
1746    fn slice_into_tensor_dyn(
1747        values: &[half::f16],
1748        shape: &[usize],
1749    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1750        edgefirst_tensor::Tensor::<half::f16>::from_slice(values, shape)
1751            .map(edgefirst_tensor::TensorDyn::F16)
1752    }
1753    fn arrayview3_into_tensor_dyn(
1754        view: ArrayView3<'_, half::f16>,
1755    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1756        edgefirst_tensor::Tensor::<half::f16>::from_arrayview3(view)
1757            .map(edgefirst_tensor::TensorDyn::F16)
1758    }
1759}
1760
1761impl FloatProtoElem for f64 {
1762    fn slice_into_tensor_dyn(
1763        values: &[f64],
1764        shape: &[usize],
1765    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1766        // Narrow to f32 — no native f64 kernel path.
1767        let narrowed: Vec<f32> = values.iter().map(|&v| v as f32).collect();
1768        edgefirst_tensor::Tensor::<f32>::from_slice(&narrowed, shape)
1769            .map(edgefirst_tensor::TensorDyn::F32)
1770    }
1771    fn arrayview3_into_tensor_dyn(
1772        view: ArrayView3<'_, f64>,
1773    ) -> edgefirst_tensor::Result<edgefirst_tensor::TensorDyn> {
1774        let narrowed: ndarray::Array3<f32> = view.mapv(|v| v as f32);
1775        edgefirst_tensor::Tensor::<f32>::from_arrayview3(narrowed.view())
1776            .map(edgefirst_tensor::TensorDyn::F32)
1777    }
1778}
1779
1780fn postprocess_yolo<'a, T>(
1781    output: &'a ArrayView2<'_, T>,
1782) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1783    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1784    let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1785    (boxes_tensor, scores_tensor)
1786}
1787
1788pub(crate) fn postprocess_yolo_seg<'a, T>(
1789    output: &'a ArrayView2<'_, T>,
1790    num_protos: usize,
1791) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1792    assert!(
1793        output.shape()[0] > num_protos + 4,
1794        "Output shape is too short: {} <= {} + 4",
1795        output.shape()[0],
1796        num_protos
1797    );
1798    let num_classes = output.shape()[0] - 4 - num_protos;
1799    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1800    let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1801    let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1802    (boxes_tensor, scores_tensor, mask_tensor)
1803}
1804
1805pub(crate) fn postprocess_yolo_split_segdet<'a, 'b, 'c, BOX, SCORE, MASK>(
1806    boxes_tensor: ArrayView2<'a, BOX>,
1807    scores_tensor: ArrayView2<'b, SCORE>,
1808    mask_tensor: ArrayView2<'c, MASK>,
1809) -> (
1810    ArrayView2<'a, BOX>,
1811    ArrayView2<'b, SCORE>,
1812    ArrayView2<'c, MASK>,
1813) {
1814    let boxes_tensor = boxes_tensor.reversed_axes();
1815    let scores_tensor = scores_tensor.reversed_axes();
1816    let mask_tensor = mask_tensor.reversed_axes();
1817    (boxes_tensor, scores_tensor, mask_tensor)
1818}
1819
1820fn decode_segdet_f32<
1821    MASK: Float + AsPrimitive<f32> + Send + Sync,
1822    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1823>(
1824    boxes: Vec<(DetectBox, usize)>,
1825    masks: ArrayView2<MASK>,
1826    protos: ArrayView3<PROTO>,
1827) -> Result<Vec<(DetectBox, BoundingBox, Array3<u8>)>, crate::DecoderError> {
1828    if boxes.is_empty() {
1829        return Ok(Vec::new());
1830    }
1831    if masks.shape()[1] != protos.shape()[2] {
1832        return Err(crate::DecoderError::InvalidShape(format!(
1833            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1834            masks.shape()[1],
1835            protos.shape()[2],
1836        )));
1837    }
1838    boxes
1839        .into_par_iter()
1840        .map(|b| {
1841            let ind = b.1;
1842            // `protobox` returns the cropped proto slice for `make_segmentation`
1843            // and a `roi` snapped to the 1/proto-grid step. The detection bbox
1844            // stays untouched (EDGEAI-1304); the snapped roi is reported back
1845            // separately so callers can describe where the cropped mask lives.
1846            let (protos, roi) = protobox(&protos, &b.0.bbox)?;
1847            Ok((b.0, roi, make_segmentation(masks.row(ind), protos.view())))
1848        })
1849        .collect()
1850}
1851
1852pub(crate) fn decode_segdet_quant<
1853    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1854    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1855>(
1856    boxes: Vec<(DetectBox, usize)>,
1857    masks: ArrayView2<MASK>,
1858    protos: ArrayView3<PROTO>,
1859    quant_masks: Quantization,
1860    quant_protos: Quantization,
1861) -> Result<Vec<(DetectBox, BoundingBox, Array3<u8>)>, crate::DecoderError> {
1862    if boxes.is_empty() {
1863        return Ok(Vec::new());
1864    }
1865    if masks.shape()[1] != protos.shape()[2] {
1866        return Err(crate::DecoderError::InvalidShape(format!(
1867            "Mask coefficients count ({}) doesn't match protos channel count ({})",
1868            masks.shape()[1],
1869            protos.shape()[2],
1870        )));
1871    }
1872
1873    let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; // 32 protos is 2^5
1874    boxes
1875        .into_iter()
1876        .map(|b| {
1877            let i = b.1;
1878            // See EDGEAI-1304: the caller's bbox stays untouched; the
1879            // proto-grid-snapped `roi` is reported back so the Segmentation's
1880            // bounds can describe the actual cropped mask region.
1881            let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical())?;
1882            let seg = match total_bits {
1883                0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1884                    masks.row(i),
1885                    protos.view(),
1886                    quant_masks,
1887                    quant_protos,
1888                ),
1889                65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1890                    masks.row(i),
1891                    protos.view(),
1892                    quant_masks,
1893                    quant_protos,
1894                ),
1895                _ => {
1896                    return Err(crate::DecoderError::NotSupported(format!(
1897                        "Unsupported bit width ({total_bits}) for segmentation computation"
1898                    )));
1899                }
1900            };
1901            Ok((b.0, roi, seg))
1902        })
1903        .collect()
1904}
1905
1906fn protobox<'a, T>(
1907    protos: &'a ArrayView3<T>,
1908    roi: &BoundingBox,
1909) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
1910    let width = protos.dim().1 as f32;
1911    let height = protos.dim().0 as f32;
1912
1913    // Detect un-normalized bounding boxes (pixel-space coordinates).
1914    // protobox expects normalized coordinates in [0, 1]. The decoder will
1915    // normalize pixel-space coords automatically when the schema declares
1916    // `Detection::normalized = false` AND model input dimensions are known
1917    // (EDGEAI-1303); reaching this guard means at least one of those is
1918    // missing.
1919    //
1920    // The limit is set to 2.0 (not 1.01) because YOLO models legitimately
1921    // predict coordinates slightly > 1.0 for objects near frame edges.
1922    // Any value > 2.0 is clearly pixel-space (even the smallest practical
1923    // model input of 32×32 would produce coordinates >> 2.0).
1924    const NORM_LIMIT: f32 = 2.0;
1925    if roi.xmin > NORM_LIMIT
1926        || roi.ymin > NORM_LIMIT
1927        || roi.xmax > NORM_LIMIT
1928        || roi.ymax > NORM_LIMIT
1929    {
1930        return Err(crate::DecoderError::InvalidShape(format!(
1931            "Bounding box coordinates appear un-normalized (pixel-space). \
1932             Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
1933             Two ways to fix this: \
1934             (1) declare `Detection::normalized = false` in the model schema \
1935             AND make sure the schema's `input.shape` / `input.dshape` carries \
1936             the model input dims so the decoder can divide by (W, H) before NMS \
1937             (EDGEAI-1303 — verify with `Decoder::input_dims().is_some()`); or \
1938             (2) normalize the boxes in-graph before decode().",
1939            roi.xmin, roi.ymin, roi.xmax, roi.ymax,
1940        )));
1941    }
1942
1943    let roi = [
1944        (roi.xmin * width).clamp(0.0, width) as usize,
1945        (roi.ymin * height).clamp(0.0, height) as usize,
1946        (roi.xmax * width).clamp(0.0, width).ceil() as usize,
1947        (roi.ymax * height).clamp(0.0, height).ceil() as usize,
1948    ];
1949
1950    let roi_norm = [
1951        roi[0] as f32 / width,
1952        roi[1] as f32 / height,
1953        roi[2] as f32 / width,
1954        roi[3] as f32 / height,
1955    ]
1956    .into();
1957
1958    let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
1959
1960    Ok((cropped, roi_norm))
1961}
1962
1963/// Compute a single instance segmentation mask from mask coefficients and
1964/// proto maps (float path).
1965///
1966/// Computes `sigmoid(coefficients · protos)` and maps to `[0, 255]`.
1967/// Returns an `(H, W, 1)` u8 array.
1968fn make_segmentation<
1969    MASK: Float + AsPrimitive<f32> + Send + Sync,
1970    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1971>(
1972    mask: ArrayView1<MASK>,
1973    protos: ArrayView3<PROTO>,
1974) -> Array3<u8> {
1975    let shape = protos.shape();
1976
1977    // Safe to unwrap since the shapes will always be compatible
1978    let mask = mask.to_shape((1, mask.len())).unwrap();
1979    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1980    let protos = protos.reversed_axes();
1981    let mask = mask.map(|x| x.as_());
1982    let protos = protos.map(|x| x.as_());
1983
1984    // Safe to unwrap since the shapes will always be compatible
1985    let mask = mask
1986        .dot(&protos)
1987        .into_shape_with_order((shape[0], shape[1], 1))
1988        .unwrap();
1989
1990    mask.map(|x| {
1991        let sigmoid = 1.0 / (1.0 + (-*x).exp());
1992        (sigmoid * 255.0).round() as u8
1993    })
1994}
1995
1996/// Compute a single instance segmentation mask from quantized mask
1997/// coefficients and proto maps.
1998///
1999/// Dequantizes both inputs (subtracting zero-points), computes the dot
2000/// product, applies sigmoid, and maps to `[0, 255]`.
2001/// Returns an `(H, W, 1)` u8 array.
2002fn make_segmentation_quant<
2003    MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
2004    PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
2005    DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
2006>(
2007    mask: ArrayView1<MASK>,
2008    protos: ArrayView3<PROTO>,
2009    quant_masks: Quantization,
2010    quant_protos: Quantization,
2011) -> Array3<u8>
2012where
2013    i32: AsPrimitive<DEST>,
2014    f32: AsPrimitive<DEST>,
2015{
2016    let shape = protos.shape();
2017
2018    // Safe to unwrap since the shapes will always be compatible
2019    let mask = mask.to_shape((1, mask.len())).unwrap();
2020
2021    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
2022    let protos = protos.reversed_axes();
2023
2024    let zp = quant_masks.zero_point.as_();
2025
2026    let mask = mask.mapv(|x| x.as_() - zp);
2027
2028    let zp = quant_protos.zero_point.as_();
2029    let protos = protos.mapv(|x| x.as_() - zp);
2030
2031    // Safe to unwrap since the shapes will always be compatible
2032    let segmentation = mask
2033        .dot(&protos)
2034        .into_shape_with_order((shape[0], shape[1], 1))
2035        .unwrap();
2036
2037    let combined_scale = quant_masks.scale * quant_protos.scale;
2038    segmentation.map(|x| {
2039        let val: f32 = (*x).as_() * combined_scale;
2040        let sigmoid = 1.0 / (1.0 + (-val).exp());
2041        (sigmoid * 255.0).round() as u8
2042    })
2043}
2044
2045/// Converts Yolo Instance Segmentation into a 2D mask.
2046///
2047/// The input segmentation is expected to have shape (H, W, 1).
2048///
2049/// The output mask will have shape (H, W), with values 0 or 1 based on the
2050/// threshold.
2051///
2052/// # Errors
2053///
2054/// Returns `DecoderError::InvalidShape` if the input segmentation does not
2055/// have shape (H, W, 1).
2056pub(crate) fn yolo_segmentation_to_mask(
2057    segmentation: ArrayView3<u8>,
2058    threshold: u8,
2059) -> Result<Array2<u8>, crate::DecoderError> {
2060    if segmentation.shape()[2] != 1 {
2061        return Err(crate::DecoderError::InvalidShape(format!(
2062            "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
2063            segmentation.shape()[2]
2064        )));
2065    }
2066    Ok(segmentation
2067        .slice(s![.., .., 0])
2068        .map(|x| if *x >= threshold { 1 } else { 0 }))
2069}
2070
2071#[cfg(test)]
2072#[cfg_attr(coverage_nightly, coverage(off))]
2073mod tests {
2074    use super::*;
2075    use ndarray::Array2;
2076
2077    // ========================================================================
2078    // Tests for decode_yolo_end_to_end_det_float
2079    // ========================================================================
2080
2081    #[test]
2082    fn test_end_to_end_det_basic_filtering() {
2083        // Create synthetic end-to-end detection output: (6, N) where rows are
2084        // [x1, y1, x2, y2, conf, class]
2085        // 3 detections: one above threshold, two below
2086        let data: Vec<f32> = vec![
2087            // Detection 0: high score (0.9)
2088            0.1, 0.2, 0.3, // x1 values
2089            0.1, 0.2, 0.3, // y1 values
2090            0.5, 0.6, 0.7, // x2 values
2091            0.5, 0.6, 0.7, // y2 values
2092            0.9, 0.1, 0.2, // confidence scores
2093            0.0, 1.0, 2.0, // class indices
2094        ];
2095        let output = Array2::from_shape_vec((6, 3), data).unwrap();
2096
2097        let mut boxes = Vec::with_capacity(10);
2098        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2099
2100        // Only 1 detection should pass threshold of 0.5
2101        assert_eq!(boxes.len(), 1);
2102        assert_eq!(boxes[0].label, 0);
2103        assert!((boxes[0].score - 0.9).abs() < 0.01);
2104        assert!((boxes[0].bbox.xmin - 0.1).abs() < 0.01);
2105        assert!((boxes[0].bbox.ymin - 0.1).abs() < 0.01);
2106        assert!((boxes[0].bbox.xmax - 0.5).abs() < 0.01);
2107        assert!((boxes[0].bbox.ymax - 0.5).abs() < 0.01);
2108    }
2109
2110    #[test]
2111    fn test_end_to_end_det_all_pass_threshold() {
2112        // All detections above threshold
2113        let data: Vec<f32> = vec![
2114            10.0, 20.0, // x1
2115            10.0, 20.0, // y1
2116            50.0, 60.0, // x2
2117            50.0, 60.0, // y2
2118            0.8, 0.7, // conf (both above 0.5)
2119            1.0, 2.0, // class
2120        ];
2121        let output = Array2::from_shape_vec((6, 2), data).unwrap();
2122
2123        let mut boxes = Vec::with_capacity(10);
2124        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2125
2126        assert_eq!(boxes.len(), 2);
2127        assert_eq!(boxes[0].label, 1);
2128        assert_eq!(boxes[1].label, 2);
2129    }
2130
2131    #[test]
2132    fn test_end_to_end_det_none_pass_threshold() {
2133        // All detections below threshold
2134        let data: Vec<f32> = vec![
2135            10.0, 20.0, // x1
2136            10.0, 20.0, // y1
2137            50.0, 60.0, // x2
2138            50.0, 60.0, // y2
2139            0.1, 0.2, // conf (both below 0.5)
2140            1.0, 2.0, // class
2141        ];
2142        let output = Array2::from_shape_vec((6, 2), data).unwrap();
2143
2144        let mut boxes = Vec::with_capacity(10);
2145        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2146
2147        assert_eq!(boxes.len(), 0);
2148    }
2149
2150    #[test]
2151    fn test_end_to_end_det_capacity_limit() {
2152        // Test that output is truncated to capacity
2153        let data: Vec<f32> = vec![
2154            0.1, 0.2, 0.3, 0.4, 0.5, // x1
2155            0.1, 0.2, 0.3, 0.4, 0.5, // y1
2156            0.5, 0.6, 0.7, 0.8, 0.9, // x2
2157            0.5, 0.6, 0.7, 0.8, 0.9, // y2
2158            0.9, 0.9, 0.9, 0.9, 0.9, // conf (all pass)
2159            0.0, 1.0, 2.0, 3.0, 4.0, // class
2160        ];
2161        let output = Array2::from_shape_vec((6, 5), data).unwrap();
2162
2163        let mut boxes = Vec::with_capacity(2); // Only allow 2 boxes
2164        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2165
2166        assert_eq!(boxes.len(), 2);
2167    }
2168
2169    #[test]
2170    fn test_end_to_end_det_empty_output() {
2171        // Test with zero detections
2172        let output = Array2::<f32>::zeros((6, 0));
2173
2174        let mut boxes = Vec::with_capacity(10);
2175        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2176
2177        assert_eq!(boxes.len(), 0);
2178    }
2179
2180    #[test]
2181    fn test_end_to_end_det_pixel_coordinates() {
2182        // Test with pixel coordinates (non-normalized)
2183        let data: Vec<f32> = vec![
2184            100.0, // x1
2185            200.0, // y1
2186            300.0, // x2
2187            400.0, // y2
2188            0.95,  // conf
2189            5.0,   // class
2190        ];
2191        let output = Array2::from_shape_vec((6, 1), data).unwrap();
2192
2193        let mut boxes = Vec::with_capacity(10);
2194        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
2195
2196        assert_eq!(boxes.len(), 1);
2197        assert_eq!(boxes[0].label, 5);
2198        assert!((boxes[0].bbox.xmin - 100.0).abs() < 0.01);
2199        assert!((boxes[0].bbox.ymin - 200.0).abs() < 0.01);
2200        assert!((boxes[0].bbox.xmax - 300.0).abs() < 0.01);
2201        assert!((boxes[0].bbox.ymax - 400.0).abs() < 0.01);
2202    }
2203
2204    #[test]
2205    fn test_end_to_end_det_invalid_shape() {
2206        // Test with too few rows (needs at least 6)
2207        let output = Array2::<f32>::zeros((5, 3));
2208
2209        let mut boxes = Vec::with_capacity(10);
2210        let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);
2211
2212        assert!(result.is_err());
2213        assert!(matches!(
2214            result,
2215            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
2216        ));
2217    }
2218
2219    // ========================================================================
2220    // Tests for decode_yolo_end_to_end_segdet_float
2221    // ========================================================================
2222
2223    #[test]
2224    fn test_end_to_end_segdet_basic() {
2225        // Create synthetic segdet output: (6 + num_protos, N)
2226        // Detection format: [x1, y1, x2, y2, conf, class, mask_coeff_0..31]
2227        let num_protos = 32;
2228        let num_detections = 2;
2229        let num_features = 6 + num_protos;
2230
2231        // Build detection tensor
2232        let mut data = vec![0.0f32; num_features * num_detections];
2233        // Detection 0: passes threshold
2234        data[0] = 0.1; // x1[0]
2235        data[1] = 0.5; // x1[1]
2236        data[num_detections] = 0.1; // y1[0]
2237        data[num_detections + 1] = 0.5; // y1[1]
2238        data[2 * num_detections] = 0.4; // x2[0]
2239        data[2 * num_detections + 1] = 0.9; // x2[1]
2240        data[3 * num_detections] = 0.4; // y2[0]
2241        data[3 * num_detections + 1] = 0.9; // y2[1]
2242        data[4 * num_detections] = 0.9; // conf[0] - passes
2243        data[4 * num_detections + 1] = 0.3; // conf[1] - fails
2244        data[5 * num_detections] = 1.0; // class[0]
2245        data[5 * num_detections + 1] = 2.0; // class[1]
2246                                            // Fill mask coefficients with small values
2247        for i in 6..num_features {
2248            data[i * num_detections] = 0.1;
2249            data[i * num_detections + 1] = 0.1;
2250        }
2251
2252        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2253
2254        // Create protos tensor: (proto_height, proto_width, num_protos)
2255        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2256
2257        let mut boxes = Vec::with_capacity(10);
2258        let mut masks = Vec::with_capacity(10);
2259        decode_yolo_end_to_end_segdet_float(
2260            output.view(),
2261            protos.view(),
2262            0.5,
2263            &mut boxes,
2264            &mut masks,
2265        )
2266        .unwrap();
2267
2268        // Only detection 0 should pass
2269        assert_eq!(boxes.len(), 1);
2270        assert_eq!(masks.len(), 1);
2271        assert_eq!(boxes[0].label, 1);
2272        assert!((boxes[0].score - 0.9).abs() < 0.01);
2273    }
2274
2275    #[test]
2276    fn test_end_to_end_segdet_mask_coordinates() {
2277        // Test that mask coordinates match box coordinates
2278        let num_protos = 32;
2279        let num_features = 6 + num_protos;
2280
2281        let mut data = vec![0.0f32; num_features];
2282        data[0] = 0.2; // x1
2283        data[1] = 0.2; // y1
2284        data[2] = 0.8; // x2
2285        data[3] = 0.8; // y2
2286        data[4] = 0.95; // conf
2287        data[5] = 3.0; // class
2288
2289        let output = Array2::from_shape_vec((num_features, 1), data).unwrap();
2290        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2291
2292        let mut boxes = Vec::with_capacity(10);
2293        let mut masks = Vec::with_capacity(10);
2294        decode_yolo_end_to_end_segdet_float(
2295            output.view(),
2296            protos.view(),
2297            0.5,
2298            &mut boxes,
2299            &mut masks,
2300        )
2301        .unwrap();
2302
2303        assert_eq!(boxes.len(), 1);
2304        assert_eq!(masks.len(), 1);
2305
2306        // Mask region is the proto-grid-aligned crop and encloses the
2307        // post-NMS bbox (EDGEAI-1304); on a 16x16 grid each side may snap
2308        // by up to 1/16 = 0.0625.
2309        let step = 1.0 / 16.0;
2310        assert!(masks[0].xmin <= boxes[0].bbox.xmin);
2311        assert!(masks[0].ymin <= boxes[0].bbox.ymin);
2312        assert!(masks[0].xmax >= boxes[0].bbox.xmax);
2313        assert!(masks[0].ymax >= boxes[0].bbox.ymax);
2314        assert!((boxes[0].bbox.xmin - masks[0].xmin) < step);
2315        assert!((boxes[0].bbox.ymin - masks[0].ymin) < step);
2316        assert!((masks[0].xmax - boxes[0].bbox.xmax) < step);
2317        assert!((masks[0].ymax - boxes[0].bbox.ymax) < step);
2318    }
2319
2320    #[test]
2321    fn test_end_to_end_segdet_empty_output() {
2322        let num_protos = 32;
2323        let output = Array2::<f32>::zeros((6 + num_protos, 0));
2324        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2325
2326        let mut boxes = Vec::with_capacity(10);
2327        let mut masks = Vec::with_capacity(10);
2328        decode_yolo_end_to_end_segdet_float(
2329            output.view(),
2330            protos.view(),
2331            0.5,
2332            &mut boxes,
2333            &mut masks,
2334        )
2335        .unwrap();
2336
2337        assert_eq!(boxes.len(), 0);
2338        assert_eq!(masks.len(), 0);
2339    }
2340
2341    #[test]
2342    fn test_end_to_end_segdet_capacity_limit() {
2343        let num_protos = 32;
2344        let num_detections = 5;
2345        let num_features = 6 + num_protos;
2346
2347        let mut data = vec![0.0f32; num_features * num_detections];
2348        // All detections pass threshold
2349        for i in 0..num_detections {
2350            data[i] = 0.1 * (i as f32); // x1
2351            data[num_detections + i] = 0.1 * (i as f32); // y1
2352            data[2 * num_detections + i] = 0.1 * (i as f32) + 0.2; // x2
2353            data[3 * num_detections + i] = 0.1 * (i as f32) + 0.2; // y2
2354            data[4 * num_detections + i] = 0.9; // conf
2355            data[5 * num_detections + i] = i as f32; // class
2356        }
2357
2358        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2359        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2360
2361        let mut boxes = Vec::with_capacity(2); // Limit to 2
2362        let mut masks = Vec::with_capacity(2);
2363        decode_yolo_end_to_end_segdet_float(
2364            output.view(),
2365            protos.view(),
2366            0.5,
2367            &mut boxes,
2368            &mut masks,
2369        )
2370        .unwrap();
2371
2372        assert_eq!(boxes.len(), 2);
2373        assert_eq!(masks.len(), 2);
2374    }
2375
2376    #[test]
2377    fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
2378        // Test with too few rows (needs at least 7: 6 base + 1 mask coeff)
2379        let output = Array2::<f32>::zeros((6, 3));
2380        let protos = Array3::<f32>::zeros((16, 16, 32));
2381
2382        let mut boxes = Vec::with_capacity(10);
2383        let mut masks = Vec::with_capacity(10);
2384        let result = decode_yolo_end_to_end_segdet_float(
2385            output.view(),
2386            protos.view(),
2387            0.5,
2388            &mut boxes,
2389            &mut masks,
2390        );
2391
2392        assert!(result.is_err());
2393        assert!(matches!(
2394            result,
2395            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
2396        ));
2397    }
2398
2399    #[test]
2400    fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
2401        // Test with mismatched mask coefficients and protos count
2402        let num_protos = 32;
2403        let output = Array2::<f32>::zeros((6 + 16, 3)); // 16 mask coeffs
2404        let protos = Array3::<f32>::zeros((16, 16, num_protos)); // 32 protos
2405
2406        let mut boxes = Vec::with_capacity(10);
2407        let mut masks = Vec::with_capacity(10);
2408        let result = decode_yolo_end_to_end_segdet_float(
2409            output.view(),
2410            protos.view(),
2411            0.5,
2412            &mut boxes,
2413            &mut masks,
2414        );
2415
2416        assert!(result.is_err());
2417        assert!(matches!(
2418            result,
2419            Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
2420        ));
2421    }
2422
2423    // ========================================================================
2424    // Tests for decode_yolo_split_end_to_end_segdet_float
2425    // ========================================================================
2426
2427    #[test]
2428    fn test_split_end_to_end_segdet_basic() {
2429        // Create synthetic segdet output: (6 + num_protos, N)
2430        // Detection format: [x1, y1, x2, y2, conf, class, mask_coeff_0..31]
2431        let num_protos = 32;
2432        let num_detections = 2;
2433        let num_features = 6 + num_protos;
2434
2435        // Build detection tensor
2436        let mut data = vec![0.0f32; num_features * num_detections];
2437        // Detection 0: passes threshold
2438        data[0] = 0.1; // x1[0]
2439        data[1] = 0.5; // x1[1]
2440        data[num_detections] = 0.1; // y1[0]
2441        data[num_detections + 1] = 0.5; // y1[1]
2442        data[2 * num_detections] = 0.4; // x2[0]
2443        data[2 * num_detections + 1] = 0.9; // x2[1]
2444        data[3 * num_detections] = 0.4; // y2[0]
2445        data[3 * num_detections + 1] = 0.9; // y2[1]
2446        data[4 * num_detections] = 0.9; // conf[0] - passes
2447        data[4 * num_detections + 1] = 0.3; // conf[1] - fails
2448        data[5 * num_detections] = 1.0; // class[0]
2449        data[5 * num_detections + 1] = 2.0; // class[1]
2450                                            // Fill mask coefficients with small values
2451        for i in 6..num_features {
2452            data[i * num_detections] = 0.1;
2453            data[i * num_detections + 1] = 0.1;
2454        }
2455
2456        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
2457        let box_coords = output.slice(s![..4, ..]);
2458        let scores = output.slice(s![4..5, ..]);
2459        let classes = output.slice(s![5..6, ..]);
2460        let mask_coeff = output.slice(s![6.., ..]);
2461        // Create protos tensor: (proto_height, proto_width, num_protos)
2462        let protos = Array3::<f32>::zeros((16, 16, num_protos));
2463
2464        let mut boxes = Vec::with_capacity(10);
2465        let mut masks = Vec::with_capacity(10);
2466        decode_yolo_split_end_to_end_segdet_float(
2467            box_coords,
2468            scores,
2469            classes,
2470            mask_coeff,
2471            protos.view(),
2472            0.5,
2473            &mut boxes,
2474            &mut masks,
2475        )
2476        .unwrap();
2477
2478        // Only detection 0 should pass
2479        assert_eq!(boxes.len(), 1);
2480        assert_eq!(masks.len(), 1);
2481        assert_eq!(boxes[0].label, 1);
2482        assert!((boxes[0].score - 0.9).abs() < 0.01);
2483    }
2484
2485    // ========================================================================
2486    // Tests for yolo_segmentation_to_mask
2487    // ========================================================================
2488
2489    #[test]
2490    fn test_segmentation_to_mask_basic() {
2491        // Create a 4x4x1 segmentation with values above and below threshold
2492        let data: Vec<u8> = vec![
2493            100, 200, 50, 150, // row 0
2494            10, 255, 128, 64, // row 1
2495            0, 127, 128, 255, // row 2
2496            64, 64, 192, 192, // row 3
2497        ];
2498        let segmentation = Array3::from_shape_vec((4, 4, 1), data).unwrap();
2499
2500        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2501
2502        // Values >= 128 should be 1, others 0
2503        assert_eq!(mask[[0, 0]], 0); // 100 < 128
2504        assert_eq!(mask[[0, 1]], 1); // 200 >= 128
2505        assert_eq!(mask[[0, 2]], 0); // 50 < 128
2506        assert_eq!(mask[[0, 3]], 1); // 150 >= 128
2507        assert_eq!(mask[[1, 1]], 1); // 255 >= 128
2508        assert_eq!(mask[[1, 2]], 1); // 128 >= 128
2509        assert_eq!(mask[[2, 0]], 0); // 0 < 128
2510        assert_eq!(mask[[2, 1]], 0); // 127 < 128
2511    }
2512
2513    #[test]
2514    fn test_segmentation_to_mask_all_above() {
2515        let segmentation = Array3::from_elem((4, 4, 1), 255u8);
2516        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2517        assert!(mask.iter().all(|&x| x == 1));
2518    }
2519
2520    #[test]
2521    fn test_segmentation_to_mask_all_below() {
2522        let segmentation = Array3::from_elem((4, 4, 1), 64u8);
2523        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
2524        assert!(mask.iter().all(|&x| x == 0));
2525    }
2526
2527    #[test]
2528    fn test_segmentation_to_mask_invalid_shape() {
2529        let segmentation = Array3::from_elem((4, 4, 3), 128u8);
2530        let result = yolo_segmentation_to_mask(segmentation.view(), 128);
2531
2532        assert!(result.is_err());
2533        assert!(matches!(
2534            result,
2535            Err(crate::DecoderError::InvalidShape(s)) if s.contains("(H, W, 1)")
2536        ));
2537    }
2538
2539    // ========================================================================
2540    // Tests for protobox / NORM_LIMIT regression
2541    // ========================================================================
2542
2543    #[test]
2544    fn test_protobox_clamps_edge_coordinates() {
2545        // bbox with xmax=1.0 should not panic (OOB guard)
2546        let protos = Array3::<f32>::zeros((16, 16, 4));
2547        let view = protos.view();
2548        let roi = BoundingBox {
2549            xmin: 0.5,
2550            ymin: 0.5,
2551            xmax: 1.0,
2552            ymax: 1.0,
2553        };
2554        let result = protobox(&view, &roi);
2555        assert!(result.is_ok(), "protobox should accept xmax=1.0");
2556        let (cropped, _roi_norm) = result.unwrap();
2557        // Cropped region must have non-zero spatial dimensions
2558        assert!(cropped.shape()[0] > 0);
2559        assert!(cropped.shape()[1] > 0);
2560        assert_eq!(cropped.shape()[2], 4);
2561    }
2562
2563    #[test]
2564    fn test_protobox_rejects_wildly_out_of_range() {
2565        // bbox with coords > NORM_LIMIT (e.g. 3.0) returns error
2566        let protos = Array3::<f32>::zeros((16, 16, 4));
2567        let view = protos.view();
2568        let roi = BoundingBox {
2569            xmin: 0.0,
2570            ymin: 0.0,
2571            xmax: 3.0,
2572            ymax: 3.0,
2573        };
2574        let result = protobox(&view, &roi);
2575        assert!(
2576            matches!(result, Err(crate::DecoderError::InvalidShape(s)) if s.contains("un-normalized")),
2577            "protobox should reject coords > NORM_LIMIT"
2578        );
2579    }
2580
2581    #[test]
2582    fn test_protobox_accepts_slightly_over_one() {
2583        // bbox with coords at 1.5 (within NORM_LIMIT=2.0) succeeds
2584        let protos = Array3::<f32>::zeros((16, 16, 4));
2585        let view = protos.view();
2586        let roi = BoundingBox {
2587            xmin: 0.0,
2588            ymin: 0.0,
2589            xmax: 1.5,
2590            ymax: 1.5,
2591        };
2592        let result = protobox(&view, &roi);
2593        assert!(
2594            result.is_ok(),
2595            "protobox should accept coords <= NORM_LIMIT (2.0)"
2596        );
2597        let (cropped, _roi_norm) = result.unwrap();
2598        // Entire proto map should be selected when coords > 1.0 (clamped to boundary)
2599        assert_eq!(cropped.shape()[0], 16);
2600        assert_eq!(cropped.shape()[1], 16);
2601    }
2602
2603    #[test]
2604    fn test_segdet_float_proto_no_panic() {
2605        // Simulates YOLOv8n-seg: output0 = [116, 8400] (4 box + 80 class + 32 mask coeff)
2606        // output1 (protos) = [32, 160, 160]
2607        let num_proposals = 100; // enough to produce idx >= 32
2608        let num_classes = 80;
2609        let num_mask_coeffs = 32;
2610        let rows = 4 + num_classes + num_mask_coeffs; // 116
2611
2612        // Fill boxes with valid xywh data so some detections pass the threshold.
2613        // Layout is [116, num_proposals] row-major: row 0=cx, 1=cy, 2=w, 3=h,
2614        // rows 4..84=class scores, rows 84..116=mask coefficients.
2615        let mut data = vec![0.0f32; rows * num_proposals];
2616        for i in 0..num_proposals {
2617            let row = |r: usize| r * num_proposals + i;
2618            data[row(0)] = 320.0; // cx
2619            data[row(1)] = 320.0; // cy
2620            data[row(2)] = 50.0; // w
2621            data[row(3)] = 50.0; // h
2622            data[row(4)] = 0.9; // class-0 score
2623        }
2624        let boxes = ndarray::Array2::from_shape_vec((rows, num_proposals), data).unwrap();
2625
2626        // Protos must be in HWC order. Under the HAL physical-order
2627        // contract, callers declare shape+dshape matching producer memory
2628        // and swap_axes_if_needed permutes the stride tuple into canonical
2629        // [batch, height, width, num_protos] before this function sees it.
2630        let protos = ndarray::Array3::<f32>::zeros((160, 160, num_mask_coeffs));
2631
2632        let mut output_boxes = Vec::with_capacity(300);
2633
2634        // This panicked before fix: mask_tensor.row(idx) with idx >= 32
2635        let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
2636            boxes.view(),
2637            protos.view(),
2638            0.5,
2639            0.7,
2640            Some(Nms::default()),
2641            MAX_NMS_CANDIDATES,
2642            300,
2643            None,
2644            None,
2645            &mut output_boxes,
2646        );
2647
2648        // Should produce detections (NMS will collapse many overlapping boxes)
2649        assert!(!output_boxes.is_empty());
2650        let coeffs_shape = proto_data.mask_coefficients.shape();
2651        assert_eq!(coeffs_shape[0], output_boxes.len());
2652        // Each mask coefficient vector should have 32 elements
2653        assert_eq!(coeffs_shape[1], num_mask_coeffs);
2654    }
2655
2656    // ========================================================================
2657    // Pre-NMS top-K cap (MAX_NMS_CANDIDATES)
2658    // ========================================================================
2659
2660    /// At very low score thresholds (e.g., t=0.01 on YOLOv8 with 8400×80
2661    /// candidates) almost every score passes the filter, feeding O(n²)
2662    /// NMS and a per-survivor mask matmul. The decoder caps the
2663    /// candidate set fed to NMS at `MAX_NMS_CANDIDATES` (Ultralytics
2664    /// default 30 000) to bound worst-case decode time.
2665    ///
2666    /// This regression test pumps 50 000 above-threshold candidates
2667    /// into `impl_yolo_segdet_get_boxes` with NMS bypassed (Nms=None)
2668    /// and a generous post-NMS cap. Before the fix, the function
2669    /// returned all 50 000; after the fix, exactly 30 000.
2670    #[test]
2671    fn test_pre_nms_cap_truncates_excess_candidates() {
2672        let n: usize = 50_000;
2673        let num_classes = 1;
2674
2675        // Identical valid boxes. Distinct scores (descending) so the
2676        // top-K cap keeps the highest-scoring ones in deterministic
2677        // order — letting us assert *which* ones survived.
2678        let mut boxes_data = Vec::with_capacity(n * 4);
2679        let mut scores_data = Vec::with_capacity(n * num_classes);
2680        for i in 0..n {
2681            boxes_data.extend_from_slice(&[0.1f32, 0.1, 0.5, 0.5]);
2682            // score_i = 0.99 - i * 1e-7 keeps everything well above 0.1
2683            // threshold but strictly decreasing.
2684            scores_data.push(0.99 - (i as f32) * 1e-7);
2685        }
2686        let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
2687        let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();
2688
2689        let result = impl_yolo_segdet_get_boxes::<XYXY, _, _>(
2690            boxes.view(),
2691            scores.view(),
2692            0.1,
2693            1.0,                             // IoU 1.0 → NMS suppresses nothing
2694            Some(Nms::ClassAgnostic),        // NMS enabled so pre_nms_top_k applies
2695            crate::yolo::MAX_NMS_CANDIDATES, // pre_nms_top_k
2696            usize::MAX,                      // no post-NMS truncation
2697        );
2698
2699        assert_eq!(
2700            result.len(),
2701            crate::yolo::MAX_NMS_CANDIDATES,
2702            "pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
2703            result.len()
2704        );
2705        // Top-K survivors: highest scores were the first n indices,
2706        // so survivor 0 must have score ~0.99.
2707        let top_score = result[0].0.score;
2708        assert!(
2709            top_score > 0.98,
2710            "highest-ranked survivor should have the largest score, got {top_score}"
2711        );
2712    }
2713
2714    /// Counterpart for the quantized split path. Same contract: feed
2715    /// more than `MAX_NMS_CANDIDATES` survivors above the quantized
2716    /// threshold, confirm `impl_yolo_split_segdet_quant_get_boxes`
2717    /// truncates before NMS.
2718    #[test]
2719    fn test_pre_nms_cap_truncates_excess_candidates_quant() {
2720        use crate::Quantization;
2721        let n: usize = 50_000;
2722        let num_classes = 1;
2723
2724        // i8 boxes with simple scale/zp; the box value 50 dequantizes
2725        // to 0.5 with scale=0.01, zp=0 — fine for a flat box set.
2726        let boxes_data = (0..n).flat_map(|_| [10i8, 10, 50, 50]).collect::<Vec<_>>();
2727        let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
2728        let quant_boxes = Quantization {
2729            scale: 0.01,
2730            zero_point: 0,
2731        };
2732
2733        // u8 scores: distinct descending values, all well above threshold.
2734        // value 250 → 0.98 with scale 0.00392, zp 0.
2735        // value (250 - i % 200) keeps a wide spread above the dequant
2736        // threshold of 0.5.
2737        let scores_data: Vec<u8> = (0..n)
2738            .map(|i| 250u8.saturating_sub((i % 200) as u8))
2739            .collect();
2740        let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();
2741        let quant_scores = Quantization {
2742            scale: 0.00392,
2743            zero_point: 0,
2744        };
2745
2746        let result = impl_yolo_split_segdet_quant_get_boxes::<XYXY, _, _>(
2747            (boxes.view(), quant_boxes),
2748            (scores.view(), quant_scores),
2749            0.1,
2750            1.0,                             // IoU 1.0 → NMS suppresses nothing
2751            Some(Nms::ClassAgnostic),        // NMS enabled so pre_nms_top_k applies
2752            crate::yolo::MAX_NMS_CANDIDATES, // pre_nms_top_k
2753            usize::MAX,                      // no post-NMS truncation
2754        );
2755
2756        assert_eq!(
2757            result.len(),
2758            crate::yolo::MAX_NMS_CANDIDATES,
2759            "quant path pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
2760            result.len()
2761        );
2762    }
2763
2764    /// Regression test for HAILORT_BUG.md — the YoloSegDet path
2765    /// (combined `(4 + nc + nm, N)` detection tensor + separate protos)
2766    /// must pair each surviving detection with the mask coefficient
2767    /// row at the SAME anchor index the box came from. The validator
2768    /// sees this path miss the pairing under schema-v2 Hailo inputs
2769    /// (mAP collapse from 46.8 → 3.65 while mask IoU stays at 66.9,
2770    /// the fingerprint of mask-to-detection misalignment).
2771    ///
2772    /// Construction: three anchors with distinct mask-coef signatures
2773    /// that, after dot(coefs, protos) + sigmoid, produce HIGH vs LOW
2774    /// mask pixel values. Two anchors survive (one high, one low); if
2775    /// the mask row is looked up at the wrong index, the per-detection
2776    /// mean mask value would cross the threshold and we catch it.
2777    #[test]
2778    fn segdet_combined_tensor_pairs_detection_with_matching_mask_row() {
2779        let nc = 2; // num_classes
2780        let nm = 2; // num_protos
2781        let n = 3; // num_anchors
2782        let feat = 4 + nc + nm; // 8
2783
2784        // Tensor layout: (8, 3) rows=features, cols=anchors.
2785        // Row indices:  0..4 = xywh, 4..6 = scores, 6..8 = mask_coefs.
2786        //
2787        //         anchor 0 | anchor 1 | anchor 2
2788        // xc       0.2      | 0.5      | 0.8
2789        // yc       0.2      | 0.5      | 0.8
2790        // w        0.1      | 0.1      | 0.1
2791        // h        0.1      | 0.1      | 0.1
2792        // s[0]     0.9      | 0.0      | 0.8   (class 0)
2793        // s[1]     0.0      | 0.0      | 0.0   (class 1 — always loses)
2794        // m[0]     3.0      | 0.0      | -3.0  (high for a0, low for a2)
2795        // m[1]     3.0      | 0.0      | -3.0
2796        //
2797        // Proto[0] = Proto[1] = all-ones (8x8), so
2798        //   mask(a0) = sigmoid(3 + 3) ≈ 0.9975 → 254
2799        //   mask(a2) = sigmoid(-3 + -3) ≈ 0.0025 → 1
2800        // 250-point gap makes any misalignment trivially detectable.
2801        let mut data = vec![0.0f32; feat * n];
2802        let set = |d: &mut [f32], r: usize, c: usize, v: f32| d[r * n + c] = v;
2803        set(&mut data, 0, 0, 0.2);
2804        set(&mut data, 1, 0, 0.2);
2805        set(&mut data, 2, 0, 0.1);
2806        set(&mut data, 3, 0, 0.1);
2807        set(&mut data, 0, 1, 0.5);
2808        set(&mut data, 1, 1, 0.5);
2809        set(&mut data, 2, 1, 0.1);
2810        set(&mut data, 3, 1, 0.1);
2811        set(&mut data, 0, 2, 0.8);
2812        set(&mut data, 1, 2, 0.8);
2813        set(&mut data, 2, 2, 0.1);
2814        set(&mut data, 3, 2, 0.1);
2815        set(&mut data, 4, 0, 0.9);
2816        set(&mut data, 4, 2, 0.8);
2817        set(&mut data, 6, 0, 3.0);
2818        set(&mut data, 7, 0, 3.0);
2819        set(&mut data, 6, 2, -3.0);
2820        set(&mut data, 7, 2, -3.0);
2821
2822        let output = Array2::from_shape_vec((feat, n), data).unwrap();
2823        let protos = Array3::<f32>::from_elem((8, 8, nm), 1.0);
2824
2825        let mut boxes: Vec<DetectBox> = Vec::with_capacity(10);
2826        let mut masks: Vec<Segmentation> = Vec::with_capacity(10);
2827        decode_yolo_segdet_float(
2828            output.view(),
2829            protos.view(),
2830            0.5,
2831            0.5,
2832            Some(Nms::ClassAgnostic),
2833            &mut boxes,
2834            &mut masks,
2835        )
2836        .unwrap();
2837
2838        assert_eq!(
2839            boxes.len(),
2840            2,
2841            "two anchors above threshold should survive (a0 score=0.9, a2 score=0.8); got {}",
2842            boxes.len()
2843        );
2844
2845        // Build a (anchor_index → mask_mean) mapping from the results.
2846        // Anchor 0 has centre (0.2, 0.2), anchor 2 has centre (0.8,
2847        // 0.8). The DetectBox bbox is the post-XYWH-to-XYXY conversion
2848        // of the original xywh; cropping inside protobox may shrink it,
2849        // so match by centre (0.2 vs 0.8) rather than exact bbox.
2850        for (b, m) in boxes.iter().zip(masks.iter()) {
2851            let cx = (b.bbox.xmin + b.bbox.xmax) * 0.5;
2852            let mean = {
2853                let s = &m.segmentation;
2854                let total: u32 = s.iter().map(|&v| v as u32).sum();
2855                total as f32 / s.len() as f32
2856            };
2857            if cx < 0.3 {
2858                // anchor 0 — expect HIGH mask values ≈ 254
2859                assert!(
2860                    mean > 200.0,
2861                    "anchor 0 detection (centre {cx:.2}) should have high-value mask; got mean {mean}"
2862                );
2863            } else if cx > 0.7 {
2864                // anchor 2 — expect LOW mask values ≈ 1
2865                assert!(
2866                    mean < 50.0,
2867                    "anchor 2 detection (centre {cx:.2}) should have low-value mask; got mean {mean}"
2868                );
2869            } else {
2870                panic!("unexpected detection centre {cx:.2}");
2871            }
2872        }
2873    }
2874}