Skip to main content

edgefirst_decoder/
yolo.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::fmt::Debug;
5
6use ndarray::{
7    parallel::prelude::{IntoParallelIterator, ParallelIterator},
8    s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
9};
10use ndarray_stats::QuantileExt;
11use num_traits::{AsPrimitive, Float, PrimInt, Signed};
12
13use crate::{
14    byte::{
15        nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
16        postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
17    },
18    configs::Nms,
19    dequant_detect_box,
20    float::{
21        nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
22        postprocess_boxes_float, postprocess_boxes_index_float,
23    },
24    BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoTensor,
25    Quantization, Segmentation, XYWH, XYXY,
26};
27
28/// Dispatches to the appropriate NMS function based on mode for float boxes.
29fn dispatch_nms_float(nms: Option<Nms>, iou: f32, boxes: Vec<DetectBox>) -> Vec<DetectBox> {
30    match nms {
31        Some(Nms::ClassAgnostic) => nms_float(iou, boxes),
32        Some(Nms::ClassAware) => nms_class_aware_float(iou, boxes),
33        None => boxes, // bypass NMS
34    }
35}
36
37/// Dispatches to the appropriate NMS function based on mode for float boxes
38/// with extra data.
39fn dispatch_nms_extra_float<E: Send + Sync>(
40    nms: Option<Nms>,
41    iou: f32,
42    boxes: Vec<(DetectBox, E)>,
43) -> Vec<(DetectBox, E)> {
44    match nms {
45        Some(Nms::ClassAgnostic) => nms_extra_float(iou, boxes),
46        Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, boxes),
47        None => boxes, // bypass NMS
48    }
49}
50
51/// Dispatches to the appropriate NMS function based on mode for quantized
52/// boxes.
53fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
54    nms: Option<Nms>,
55    iou: f32,
56    boxes: Vec<DetectBoxQuantized<SCORE>>,
57) -> Vec<DetectBoxQuantized<SCORE>> {
58    match nms {
59        Some(Nms::ClassAgnostic) => nms_int(iou, boxes),
60        Some(Nms::ClassAware) => nms_class_aware_int(iou, boxes),
61        None => boxes, // bypass NMS
62    }
63}
64
65/// Dispatches to the appropriate NMS function based on mode for quantized boxes
66/// with extra data.
67fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
68    nms: Option<Nms>,
69    iou: f32,
70    boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
71) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
72    match nms {
73        Some(Nms::ClassAgnostic) => nms_extra_int(iou, boxes),
74        Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, boxes),
75        None => boxes, // bypass NMS
76    }
77}
78
79/// Decodes YOLO detection outputs from quantized tensors into detection boxes.
80///
81/// Boxes are expected to be in XYWH format.
82///
83/// Expected shapes of inputs:
84/// - output: (4 + num_classes, num_boxes)
85pub fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
86    output: (ArrayView2<BOX>, Quantization),
87    score_threshold: f32,
88    iou_threshold: f32,
89    nms: Option<Nms>,
90    output_boxes: &mut Vec<DetectBox>,
91) where
92    f32: AsPrimitive<BOX>,
93{
94    impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
95}
96
97/// Decodes YOLO detection outputs from float tensors into detection boxes.
98///
99/// Boxes are expected to be in XYWH format.
100///
101/// Expected shapes of inputs:
102/// - output: (4 + num_classes, num_boxes)
103pub fn decode_yolo_det_float<T>(
104    output: ArrayView2<T>,
105    score_threshold: f32,
106    iou_threshold: f32,
107    nms: Option<Nms>,
108    output_boxes: &mut Vec<DetectBox>,
109) where
110    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
111    f32: AsPrimitive<T>,
112{
113    impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
114}
115
116/// Decodes YOLO detection and segmentation outputs from quantized tensors into
117/// detection boxes and segmentation masks.
118///
119/// Boxes are expected to be in XYWH format.
120///
121/// Expected shapes of inputs:
122/// - boxes: (4 + num_classes + num_protos, num_boxes)
123/// - protos: (proto_height, proto_width, num_protos)
124///
125/// # Panics
126/// Panics if shapes don't match the expected dimensions.
127pub fn decode_yolo_segdet_quant<
128    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
129    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
130>(
131    boxes: (ArrayView2<BOX>, Quantization),
132    protos: (ArrayView3<PROTO>, Quantization),
133    score_threshold: f32,
134    iou_threshold: f32,
135    nms: Option<Nms>,
136    output_boxes: &mut Vec<DetectBox>,
137    output_masks: &mut Vec<Segmentation>,
138) where
139    f32: AsPrimitive<BOX>,
140{
141    impl_yolo_segdet_quant::<XYWH, _, _>(
142        boxes,
143        protos,
144        score_threshold,
145        iou_threshold,
146        nms,
147        output_boxes,
148        output_masks,
149    );
150}
151
152/// Decodes YOLO detection and segmentation outputs from float tensors into
153/// detection boxes and segmentation masks.
154///
155/// Boxes are expected to be in XYWH format.
156///
157/// Expected shapes of inputs:
158/// - boxes: (4 + num_classes + num_protos, num_boxes)
159/// - protos: (proto_height, proto_width, num_protos)
160///
161/// # Panics
162/// Panics if shapes don't match the expected dimensions.
163pub fn decode_yolo_segdet_float<T>(
164    boxes: ArrayView2<T>,
165    protos: ArrayView3<T>,
166    score_threshold: f32,
167    iou_threshold: f32,
168    nms: Option<Nms>,
169    output_boxes: &mut Vec<DetectBox>,
170    output_masks: &mut Vec<Segmentation>,
171) where
172    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
173    f32: AsPrimitive<T>,
174{
175    impl_yolo_segdet_float::<XYWH, _, _>(
176        boxes,
177        protos,
178        score_threshold,
179        iou_threshold,
180        nms,
181        output_boxes,
182        output_masks,
183    );
184}
185
186/// Decodes YOLO split detection outputs from quantized tensors into detection
187/// boxes.
188///
189/// Boxes are expected to be in XYWH format.
190///
191/// Expected shapes of inputs:
192/// - boxes: (4, num_boxes)
193/// - scores: (num_classes, num_boxes)
194///
195/// # Panics
196/// Panics if shapes don't match the expected dimensions.
197pub fn decode_yolo_split_det_quant<
198    BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
199    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
200>(
201    boxes: (ArrayView2<BOX>, Quantization),
202    scores: (ArrayView2<SCORE>, Quantization),
203    score_threshold: f32,
204    iou_threshold: f32,
205    nms: Option<Nms>,
206    output_boxes: &mut Vec<DetectBox>,
207) where
208    f32: AsPrimitive<SCORE>,
209{
210    impl_yolo_split_quant::<XYWH, _, _>(
211        boxes,
212        scores,
213        score_threshold,
214        iou_threshold,
215        nms,
216        output_boxes,
217    );
218}
219
220/// Decodes YOLO split detection outputs from float tensors into detection
221/// boxes.
222///
223/// Boxes are expected to be in XYWH format.
224///
225/// Expected shapes of inputs:
226/// - boxes: (4, num_boxes)
227/// - scores: (num_classes, num_boxes)
228///
229/// # Panics
230/// Panics if shapes don't match the expected dimensions.
231pub fn decode_yolo_split_det_float<T>(
232    boxes: ArrayView2<T>,
233    scores: ArrayView2<T>,
234    score_threshold: f32,
235    iou_threshold: f32,
236    nms: Option<Nms>,
237    output_boxes: &mut Vec<DetectBox>,
238) where
239    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
240    f32: AsPrimitive<T>,
241{
242    impl_yolo_split_float::<XYWH, _, _>(
243        boxes,
244        scores,
245        score_threshold,
246        iou_threshold,
247        nms,
248        output_boxes,
249    );
250}
251
252/// Decodes YOLO split detection segmentation outputs from quantized tensors
253/// into detection boxes and segmentation masks.
254///
255/// Boxes are expected to be in XYWH format.
256///
257/// Expected shapes of inputs:
258/// - boxes_tensor: (4, num_boxes)
259/// - scores_tensor: (num_classes, num_boxes)
260/// - mask_tensor: (num_protos, num_boxes)
261/// - protos: (proto_height, proto_width, num_protos)
262///
263/// # Panics
264/// Panics if shapes don't match the expected dimensions.
265#[allow(clippy::too_many_arguments)]
266pub fn decode_yolo_split_segdet<
267    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
268    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
269    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
270    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
271>(
272    boxes: (ArrayView2<BOX>, Quantization),
273    scores: (ArrayView2<SCORE>, Quantization),
274    mask_coeff: (ArrayView2<MASK>, Quantization),
275    protos: (ArrayView3<PROTO>, Quantization),
276    score_threshold: f32,
277    iou_threshold: f32,
278    nms: Option<Nms>,
279    output_boxes: &mut Vec<DetectBox>,
280    output_masks: &mut Vec<Segmentation>,
281) where
282    f32: AsPrimitive<SCORE>,
283{
284    impl_yolo_split_segdet_quant::<XYWH, _, _, _, _>(
285        boxes,
286        scores,
287        mask_coeff,
288        protos,
289        score_threshold,
290        iou_threshold,
291        nms,
292        output_boxes,
293        output_masks,
294    );
295}
296
297/// Decodes YOLO split detection segmentation outputs from float tensors
298/// into detection boxes and segmentation masks.
299///
300/// Boxes are expected to be in XYWH format.
301///
302/// Expected shapes of inputs:
303/// - boxes_tensor: (4, num_boxes)
304/// - scores_tensor: (num_classes, num_boxes)
305/// - mask_tensor: (num_protos, num_boxes)
306/// - protos: (proto_height, proto_width, num_protos)
307///
308/// # Panics
309/// Panics if shapes don't match the expected dimensions.
310#[allow(clippy::too_many_arguments)]
311pub fn decode_yolo_split_segdet_float<T>(
312    boxes: ArrayView2<T>,
313    scores: ArrayView2<T>,
314    mask_coeff: ArrayView2<T>,
315    protos: ArrayView3<T>,
316    score_threshold: f32,
317    iou_threshold: f32,
318    nms: Option<Nms>,
319    output_boxes: &mut Vec<DetectBox>,
320    output_masks: &mut Vec<Segmentation>,
321) where
322    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
323    f32: AsPrimitive<T>,
324{
325    impl_yolo_split_segdet_float::<XYWH, _, _, _, _>(
326        boxes,
327        scores,
328        mask_coeff,
329        protos,
330        score_threshold,
331        iou_threshold,
332        nms,
333        output_boxes,
334        output_masks,
335    );
336}
337
338/// Decodes end-to-end YOLO detection outputs (post-NMS from model).
339/// Expects an array of shape `(6, N)`, where the first dimension (rows)
340/// corresponds to the 6 per-detection features
341/// `[x1, y1, x2, y2, conf, class]` and the second dimension (columns)
342/// indexes the `N` detections.
343/// Boxes are output directly without NMS (the model already applied NMS).
344///
345/// Coordinates may be normalized `[0, 1]` or absolute pixel values depending
346/// on the model configuration. The caller should check
347/// `decoder.normalized_boxes()` to determine which.
348///
349/// # Errors
350///
351/// Returns `DecoderError::InvalidShape` if `output` has fewer than 6 rows.
352pub fn decode_yolo_end_to_end_det_float<T>(
353    output: ArrayView2<T>,
354    score_threshold: f32,
355    output_boxes: &mut Vec<DetectBox>,
356) -> Result<(), crate::DecoderError>
357where
358    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
359    f32: AsPrimitive<T>,
360{
361    // Validate input shape: need at least 6 rows (x1, y1, x2, y2, conf, class)
362    if output.shape()[0] < 6 {
363        return Err(crate::DecoderError::InvalidShape(format!(
364            "End-to-end detection output requires at least 6 rows, got {}",
365            output.shape()[0]
366        )));
367    }
368
369    // Input shape: (6, N) -> transpose to (N, 4) for boxes and (N, 1) for scores
370    let boxes = output.slice(s![0..4, ..]).reversed_axes();
371    let scores = output.slice(s![4..5, ..]).reversed_axes();
372    let classes = output.slice(s![5, ..]);
373    let mut boxes =
374        postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
375    boxes.truncate(output_boxes.capacity());
376    output_boxes.clear();
377    for (mut b, i) in boxes.into_iter() {
378        b.label = classes[i].as_() as usize;
379        output_boxes.push(b);
380    }
381    // No NMS — model output is already post-NMS
382    Ok(())
383}
384
385/// Decodes end-to-end YOLO detection + segmentation outputs (post-NMS from
386/// model).
387///
388/// Input shapes:
389/// - detection: (6 + num_protos, N) where rows are [x1, y1, x2, y2, conf,
390///   class, mask_coeff_0, ..., mask_coeff_31]
391/// - protos: (proto_height, proto_width, num_protos)
392///
393/// Boxes are output directly without NMS (model already applied NMS).
394/// Coordinates may be normalized [0,1] or pixel values depending on model
395/// config.
396///
397/// # Errors
398///
399/// Returns `DecoderError::InvalidShape` if:
400/// - output has fewer than 7 rows (6 base + at least 1 mask coefficient)
401/// - protos shape doesn't match mask coefficients count
402pub fn decode_yolo_end_to_end_segdet_float<T>(
403    output: ArrayView2<T>,
404    protos: ArrayView3<T>,
405    score_threshold: f32,
406    output_boxes: &mut Vec<DetectBox>,
407    output_masks: &mut Vec<crate::Segmentation>,
408) -> Result<(), crate::DecoderError>
409where
410    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
411    f32: AsPrimitive<T>,
412{
413    // Validate input shape: need at least 7 rows (6 base + at least 1 mask coeff)
414    if output.shape()[0] < 7 {
415        return Err(crate::DecoderError::InvalidShape(format!(
416            "End-to-end segdet output requires at least 7 rows, got {}",
417            output.shape()[0]
418        )));
419    }
420
421    let num_mask_coeffs = output.shape()[0] - 6;
422    let num_protos = protos.shape()[2];
423    if num_mask_coeffs != num_protos {
424        return Err(crate::DecoderError::InvalidShape(format!(
425            "Mask coefficients count ({}) doesn't match protos count ({})",
426            num_mask_coeffs, num_protos
427        )));
428    }
429
430    // Input shape: (6+num_protos, N) -> transpose for postprocessing
431    let boxes = output.slice(s![0..4, ..]).reversed_axes();
432    let scores = output.slice(s![4..5, ..]).reversed_axes();
433    let classes = output.slice(s![5, ..]);
434    let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
435    let mut boxes =
436        postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
437    boxes.truncate(output_boxes.capacity());
438
439    for (b, ind) in &mut boxes {
440        b.label = classes[*ind].as_() as usize;
441    }
442
443    // No NMS — model output is already post-NMS
444
445    let boxes = decode_segdet_f32(boxes, mask_coeff, protos);
446
447    output_boxes.clear();
448    output_masks.clear();
449    for (b, m) in boxes.into_iter() {
450        output_boxes.push(b);
451        output_masks.push(Segmentation {
452            xmin: b.bbox.xmin,
453            ymin: b.bbox.ymin,
454            xmax: b.bbox.xmax,
455            ymax: b.bbox.ymax,
456            segmentation: m,
457        });
458    }
459    Ok(())
460}
461
462/// Decodes split end-to-end YOLO detection outputs (post-NMS from model).
463///
464/// Input shapes (after batch dim removed):
465/// - boxes: (N, 4) — xyxy pixel coordinates
466/// - scores: (N, 1) — confidence of the top class
467/// - classes: (N, 1) — class index of the top class
468///
469/// Boxes are output directly without NMS (model already applied NMS).
470pub fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
471    boxes: ArrayView2<T>,
472    scores: ArrayView2<T>,
473    classes: ArrayView2<T>,
474    score_threshold: f32,
475    output_boxes: &mut Vec<DetectBox>,
476) -> Result<(), crate::DecoderError> {
477    let n = boxes.shape()[0];
478    if boxes.shape()[1] != 4 {
479        return Err(crate::DecoderError::InvalidShape(format!(
480            "Split end-to-end boxes must have 4 columns, got {}",
481            boxes.shape()[1]
482        )));
483    }
484    output_boxes.clear();
485    for i in 0..n {
486        let score: f32 = scores[[i, 0]].as_();
487        if score < score_threshold {
488            continue;
489        }
490        if output_boxes.len() >= output_boxes.capacity() {
491            break;
492        }
493        output_boxes.push(DetectBox {
494            bbox: BoundingBox {
495                xmin: boxes[[i, 0]].as_(),
496                ymin: boxes[[i, 1]].as_(),
497                xmax: boxes[[i, 2]].as_(),
498                ymax: boxes[[i, 3]].as_(),
499            },
500            score,
501            label: classes[[i, 0]].as_() as usize,
502        });
503    }
504    Ok(())
505}
506
507/// Decodes split end-to-end YOLO detection + segmentation outputs.
508///
509/// Input shapes (after batch dim removed):
510/// - boxes: (N, 4) — xyxy pixel coordinates
511/// - scores: (N, 1) — confidence
512/// - classes: (N, 1) — class index
513/// - mask_coeff: (N, num_protos) — mask coefficients per detection
514/// - protos: (proto_h, proto_w, num_protos) — prototype masks
515#[allow(clippy::too_many_arguments)]
516pub fn decode_yolo_split_end_to_end_segdet_float<T>(
517    boxes: ArrayView2<T>,
518    scores: ArrayView2<T>,
519    classes: ArrayView2<T>,
520    mask_coeff: ArrayView2<T>,
521    protos: ArrayView3<T>,
522    score_threshold: f32,
523    output_boxes: &mut Vec<DetectBox>,
524    output_masks: &mut Vec<crate::Segmentation>,
525) -> Result<(), crate::DecoderError>
526where
527    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
528    f32: AsPrimitive<T>,
529{
530    let n = boxes.shape()[0];
531    if boxes.shape()[1] != 4 {
532        return Err(crate::DecoderError::InvalidShape(format!(
533            "Split end-to-end boxes must have 4 columns, got {}",
534            boxes.shape()[1]
535        )));
536    }
537
538    // Collect qualifying detections with their indices
539    let mut qualifying: Vec<(DetectBox, usize)> = Vec::with_capacity(output_boxes.capacity());
540    for i in 0..n {
541        let score: f32 = scores[[i, 0]].as_();
542        if score < score_threshold {
543            continue;
544        }
545        if qualifying.len() >= output_boxes.capacity() {
546            break;
547        }
548        qualifying.push((
549            DetectBox {
550                bbox: BoundingBox {
551                    xmin: boxes[[i, 0]].as_(),
552                    ymin: boxes[[i, 1]].as_(),
553                    xmax: boxes[[i, 2]].as_(),
554                    ymax: boxes[[i, 3]].as_(),
555                },
556                score,
557                label: classes[[i, 0]].as_() as usize,
558            },
559            i,
560        ));
561    }
562
563    // Process masks using existing infrastructure
564    let result = decode_segdet_f32(qualifying, mask_coeff, protos);
565
566    output_boxes.clear();
567    output_masks.clear();
568    for (b, m) in result.into_iter() {
569        output_masks.push(crate::Segmentation {
570            xmin: b.bbox.xmin,
571            ymin: b.bbox.ymin,
572            xmax: b.bbox.xmax,
573            ymax: b.bbox.ymax,
574            segmentation: m,
575        });
576        output_boxes.push(b);
577    }
578    Ok(())
579}
580
581/// Internal implementation of YOLO decoding for quantized tensors.
582///
583/// Expected shapes of inputs:
584/// - output: (4 + num_classes, num_boxes)
585pub fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
586    output: (ArrayView2<T>, Quantization),
587    score_threshold: f32,
588    iou_threshold: f32,
589    nms: Option<Nms>,
590    output_boxes: &mut Vec<DetectBox>,
591) where
592    f32: AsPrimitive<T>,
593{
594    let (boxes, quant_boxes) = output;
595    let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
596
597    let boxes = {
598        let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
599        postprocess_boxes_quant::<B, _, _>(
600            score_threshold,
601            boxes_tensor,
602            scores_tensor,
603            quant_boxes,
604        )
605    };
606
607    let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
608    let len = output_boxes.capacity().min(boxes.len());
609    output_boxes.clear();
610    for b in boxes.iter().take(len) {
611        output_boxes.push(dequant_detect_box(b, quant_boxes));
612    }
613}
614
615/// Internal implementation of YOLO decoding for float tensors.
616///
617/// Expected shapes of inputs:
618/// - output: (4 + num_classes, num_boxes)
619pub fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
620    output: ArrayView2<T>,
621    score_threshold: f32,
622    iou_threshold: f32,
623    nms: Option<Nms>,
624    output_boxes: &mut Vec<DetectBox>,
625) where
626    f32: AsPrimitive<T>,
627{
628    let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
629    let boxes =
630        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
631    let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
632    let len = output_boxes.capacity().min(boxes.len());
633    output_boxes.clear();
634    for b in boxes.into_iter().take(len) {
635        output_boxes.push(b);
636    }
637}
638
639/// Internal implementation of YOLO split detection decoding for quantized
640/// tensors.
641///
642/// Expected shapes of inputs:
643/// - boxes: (4, num_boxes)
644/// - scores: (num_classes, num_boxes)
645///
646/// # Panics
647/// Panics if shapes don't match the expected dimensions.
648pub fn impl_yolo_split_quant<
649    B: BBoxTypeTrait,
650    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
651    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
652>(
653    boxes: (ArrayView2<BOX>, Quantization),
654    scores: (ArrayView2<SCORE>, Quantization),
655    score_threshold: f32,
656    iou_threshold: f32,
657    nms: Option<Nms>,
658    output_boxes: &mut Vec<DetectBox>,
659) where
660    f32: AsPrimitive<SCORE>,
661{
662    let (boxes_tensor, quant_boxes) = boxes;
663    let (scores_tensor, quant_scores) = scores;
664
665    let boxes_tensor = boxes_tensor.reversed_axes();
666    let scores_tensor = scores_tensor.reversed_axes();
667
668    let boxes = {
669        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
670        postprocess_boxes_quant::<B, _, _>(
671            score_threshold,
672            boxes_tensor,
673            scores_tensor,
674            quant_boxes,
675        )
676    };
677
678    let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
679    let len = output_boxes.capacity().min(boxes.len());
680    output_boxes.clear();
681    for b in boxes.iter().take(len) {
682        output_boxes.push(dequant_detect_box(b, quant_scores));
683    }
684}
685
686/// Internal implementation of YOLO split detection decoding for float tensors.
687///
688/// Expected shapes of inputs:
689/// - boxes: (4, num_boxes)
690/// - scores: (num_classes, num_boxes)
691///
692/// # Panics
693/// Panics if shapes don't match the expected dimensions.
694pub fn impl_yolo_split_float<
695    B: BBoxTypeTrait,
696    BOX: Float + AsPrimitive<f32> + Send + Sync,
697    SCORE: Float + AsPrimitive<f32> + Send + Sync,
698>(
699    boxes_tensor: ArrayView2<BOX>,
700    scores_tensor: ArrayView2<SCORE>,
701    score_threshold: f32,
702    iou_threshold: f32,
703    nms: Option<Nms>,
704    output_boxes: &mut Vec<DetectBox>,
705) where
706    f32: AsPrimitive<SCORE>,
707{
708    let boxes_tensor = boxes_tensor.reversed_axes();
709    let scores_tensor = scores_tensor.reversed_axes();
710    let boxes =
711        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
712    let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
713    let len = output_boxes.capacity().min(boxes.len());
714    output_boxes.clear();
715    for b in boxes.into_iter().take(len) {
716        output_boxes.push(b);
717    }
718}
719
720/// Internal implementation of YOLO detection segmentation decoding for
721/// quantized tensors.
722///
723/// Expected shapes of inputs:
724/// - boxes: (4 + num_classes + num_protos, num_boxes)
725/// - protos: (proto_height, proto_width, num_protos)
726///
727/// # Panics
728/// Panics if shapes don't match the expected dimensions.
729pub fn impl_yolo_segdet_quant<
730    B: BBoxTypeTrait,
731    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
732    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
733>(
734    boxes: (ArrayView2<BOX>, Quantization),
735    protos: (ArrayView3<PROTO>, Quantization),
736    score_threshold: f32,
737    iou_threshold: f32,
738    nms: Option<Nms>,
739    output_boxes: &mut Vec<DetectBox>,
740    output_masks: &mut Vec<Segmentation>,
741) where
742    f32: AsPrimitive<BOX>,
743{
744    let (boxes, quant_boxes) = boxes;
745    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes);
746
747    let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
748        (boxes_tensor.reversed_axes(), quant_boxes),
749        (scores_tensor.reversed_axes(), quant_boxes),
750        score_threshold,
751        iou_threshold,
752        nms,
753        output_boxes.capacity(),
754    );
755
756    impl_yolo_split_segdet_quant_process_masks::<_, _>(
757        boxes,
758        (mask_tensor.reversed_axes(), quant_boxes),
759        protos,
760        output_boxes,
761        output_masks,
762    );
763}
764
765/// Internal implementation of YOLO detection segmentation decoding for
766/// float tensors.
767///
768/// Expected shapes of inputs:
769/// - boxes: (4 + num_classes + num_protos, num_boxes)
770/// - protos: (proto_height, proto_width, num_protos)
771///
772/// # Panics
773/// Panics if shapes don't match the expected dimensions.
774pub fn impl_yolo_segdet_float<
775    B: BBoxTypeTrait,
776    BOX: Float + AsPrimitive<f32> + Send + Sync,
777    PROTO: Float + AsPrimitive<f32> + Send + Sync,
778>(
779    boxes: ArrayView2<BOX>,
780    protos: ArrayView3<PROTO>,
781    score_threshold: f32,
782    iou_threshold: f32,
783    nms: Option<Nms>,
784    output_boxes: &mut Vec<DetectBox>,
785    output_masks: &mut Vec<Segmentation>,
786) where
787    f32: AsPrimitive<BOX>,
788{
789    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes);
790
791    let boxes = postprocess_boxes_index_float::<B, _, _>(
792        score_threshold.as_(),
793        boxes_tensor,
794        scores_tensor,
795    );
796    let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
797    boxes.truncate(output_boxes.capacity());
798    let boxes = decode_segdet_f32(boxes, mask_tensor, protos);
799    output_boxes.clear();
800    output_masks.clear();
801    for (b, m) in boxes.into_iter() {
802        output_boxes.push(b);
803        output_masks.push(Segmentation {
804            xmin: b.bbox.xmin,
805            ymin: b.bbox.ymin,
806            xmax: b.bbox.xmax,
807            ymax: b.bbox.ymax,
808            segmentation: m,
809        });
810    }
811}
812
813pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
814    B: BBoxTypeTrait,
815    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
816    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
817>(
818    boxes: (ArrayView2<BOX>, Quantization),
819    scores: (ArrayView2<SCORE>, Quantization),
820    score_threshold: f32,
821    iou_threshold: f32,
822    nms: Option<Nms>,
823    max_boxes: usize,
824) -> Vec<(DetectBox, usize)>
825where
826    f32: AsPrimitive<SCORE>,
827{
828    let (boxes_tensor, quant_boxes) = boxes;
829    let (scores_tensor, quant_scores) = scores;
830
831    let boxes_tensor = boxes_tensor.reversed_axes();
832    let scores_tensor = scores_tensor.reversed_axes();
833
834    let boxes = {
835        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
836        postprocess_boxes_index_quant::<B, _, _>(
837            score_threshold,
838            boxes_tensor,
839            scores_tensor,
840            quant_boxes,
841        )
842    };
843    let mut boxes = dispatch_nms_extra_int(nms, iou_threshold, boxes);
844    boxes.truncate(max_boxes);
845    boxes
846        .into_iter()
847        .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
848        .collect()
849}
850
851pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
852    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
853    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
854>(
855    boxes: Vec<(DetectBox, usize)>,
856    mask_coeff: (ArrayView2<MASK>, Quantization),
857    protos: (ArrayView3<PROTO>, Quantization),
858    output_boxes: &mut Vec<DetectBox>,
859    output_masks: &mut Vec<Segmentation>,
860) {
861    let (masks, quant_masks) = mask_coeff;
862    let (protos, quant_protos) = protos;
863
864    let masks = masks.reversed_axes();
865
866    let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos);
867    output_boxes.clear();
868    output_masks.clear();
869    for (b, m) in boxes.into_iter() {
870        output_boxes.push(b);
871        output_masks.push(Segmentation {
872            xmin: b.bbox.xmin,
873            ymin: b.bbox.ymin,
874            xmax: b.bbox.xmax,
875            ymax: b.bbox.ymax,
876            segmentation: m,
877        });
878    }
879}
880
881#[allow(clippy::too_many_arguments)]
882/// Internal implementation of YOLO split detection segmentation decoding for
883/// quantized tensors.
884///
885/// Expected shapes of inputs:
886/// - boxes_tensor: (4, num_boxes)
887/// - scores_tensor: (num_classes, num_boxes)
888/// - mask_tensor: (num_protos, num_boxes)
889/// - protos: (proto_height, proto_width, num_protos)
890///
891/// # Panics
892/// Panics if shapes don't match the expected dimensions.
893pub fn impl_yolo_split_segdet_quant<
894    B: BBoxTypeTrait,
895    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
896    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
897    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
898    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
899>(
900    boxes: (ArrayView2<BOX>, Quantization),
901    scores: (ArrayView2<SCORE>, Quantization),
902    mask_coeff: (ArrayView2<MASK>, Quantization),
903    protos: (ArrayView3<PROTO>, Quantization),
904    score_threshold: f32,
905    iou_threshold: f32,
906    nms: Option<Nms>,
907    output_boxes: &mut Vec<DetectBox>,
908    output_masks: &mut Vec<Segmentation>,
909) where
910    f32: AsPrimitive<SCORE>,
911{
912    let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
913        boxes,
914        scores,
915        score_threshold,
916        iou_threshold,
917        nms,
918        output_boxes.capacity(),
919    );
920
921    impl_yolo_split_segdet_quant_process_masks(
922        boxes,
923        mask_coeff,
924        protos,
925        output_boxes,
926        output_masks,
927    );
928}
929
930#[allow(clippy::too_many_arguments)]
931/// Internal implementation of YOLO split detection segmentation decoding for
932/// float tensors.
933///
934/// Expected shapes of inputs:
935/// - boxes_tensor: (4, num_boxes)
936/// - scores_tensor: (num_classes, num_boxes)
937/// - mask_tensor: (num_protos, num_boxes)
938/// - protos: (proto_height, proto_width, num_protos)
939///
940/// # Panics
941/// Panics if shapes don't match the expected dimensions.
942pub fn impl_yolo_split_segdet_float<
943    B: BBoxTypeTrait,
944    BOX: Float + AsPrimitive<f32> + Send + Sync,
945    SCORE: Float + AsPrimitive<f32> + Send + Sync,
946    MASK: Float + AsPrimitive<f32> + Send + Sync,
947    PROTO: Float + AsPrimitive<f32> + Send + Sync,
948>(
949    boxes_tensor: ArrayView2<BOX>,
950    scores_tensor: ArrayView2<SCORE>,
951    mask_tensor: ArrayView2<MASK>,
952    protos: ArrayView3<PROTO>,
953    score_threshold: f32,
954    iou_threshold: f32,
955    nms: Option<Nms>,
956    output_boxes: &mut Vec<DetectBox>,
957    output_masks: &mut Vec<Segmentation>,
958) where
959    f32: AsPrimitive<SCORE>,
960{
961    let boxes_tensor = boxes_tensor.reversed_axes();
962    let scores_tensor = scores_tensor.reversed_axes();
963    let mask_tensor = mask_tensor.reversed_axes();
964
965    let boxes = postprocess_boxes_index_float::<B, _, _>(
966        score_threshold.as_(),
967        boxes_tensor,
968        scores_tensor,
969    );
970    let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
971    boxes.truncate(output_boxes.capacity());
972    let boxes = decode_segdet_f32(boxes, mask_tensor, protos);
973    output_boxes.clear();
974    output_masks.clear();
975    for (b, m) in boxes.into_iter() {
976        output_boxes.push(b);
977        output_masks.push(Segmentation {
978            xmin: b.bbox.xmin,
979            ymin: b.bbox.ymin,
980            xmax: b.bbox.xmax,
981            ymax: b.bbox.ymax,
982            segmentation: m,
983        });
984    }
985}
986
987// ---------------------------------------------------------------------------
988// Proto-extraction variants: return ProtoData instead of materialized masks
989// ---------------------------------------------------------------------------
990
991/// Proto-extraction variant of `impl_yolo_segdet_quant`.
992/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
993pub fn impl_yolo_segdet_quant_proto<
994    B: BBoxTypeTrait,
995    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
996    PROTO: PrimInt
997        + AsPrimitive<i64>
998        + AsPrimitive<i128>
999        + AsPrimitive<f32>
1000        + AsPrimitive<i8>
1001        + Send
1002        + Sync,
1003>(
1004    boxes: (ArrayView2<BOX>, Quantization),
1005    protos: (ArrayView3<PROTO>, Quantization),
1006    score_threshold: f32,
1007    iou_threshold: f32,
1008    nms: Option<Nms>,
1009    output_boxes: &mut Vec<DetectBox>,
1010) -> ProtoData
1011where
1012    f32: AsPrimitive<BOX>,
1013{
1014    let (boxes_arr, quant_boxes) = boxes;
1015    let (protos_arr, quant_protos) = protos;
1016
1017    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes_arr);
1018
1019    let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1020        (boxes_tensor.reversed_axes(), quant_boxes),
1021        (scores_tensor.reversed_axes(), quant_boxes),
1022        score_threshold,
1023        iou_threshold,
1024        nms,
1025        output_boxes.capacity(),
1026    );
1027
1028    let mask_tensor = mask_tensor.reversed_axes();
1029    extract_proto_data_quant(
1030        det_indices,
1031        mask_tensor,
1032        quant_boxes,
1033        protos_arr,
1034        quant_protos,
1035        output_boxes,
1036    )
1037}
1038
1039/// Proto-extraction variant of `impl_yolo_segdet_float`.
1040/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1041pub fn impl_yolo_segdet_float_proto<
1042    B: BBoxTypeTrait,
1043    BOX: Float + AsPrimitive<f32> + Send + Sync,
1044    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1045>(
1046    boxes: ArrayView2<BOX>,
1047    protos: ArrayView3<PROTO>,
1048    score_threshold: f32,
1049    iou_threshold: f32,
1050    nms: Option<Nms>,
1051    output_boxes: &mut Vec<DetectBox>,
1052) -> ProtoData
1053where
1054    f32: AsPrimitive<BOX>,
1055{
1056    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes);
1057
1058    let det_indices = postprocess_boxes_index_float::<B, _, _>(
1059        score_threshold.as_(),
1060        boxes_tensor,
1061        scores_tensor,
1062    );
1063    let mut det_indices = dispatch_nms_extra_float(nms, iou_threshold, det_indices);
1064    det_indices.truncate(output_boxes.capacity());
1065
1066    let mask_tensor = mask_tensor.reversed_axes();
1067    extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1068}
1069
1070/// Proto-extraction variant of `impl_yolo_split_segdet_quant`.
1071/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1072#[allow(clippy::too_many_arguments)]
1073pub fn impl_yolo_split_segdet_quant_proto<
1074    B: BBoxTypeTrait,
1075    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
1076    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
1077    MASK: PrimInt + AsPrimitive<f32> + Send + Sync,
1078    PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync,
1079>(
1080    boxes: (ArrayView2<BOX>, Quantization),
1081    scores: (ArrayView2<SCORE>, Quantization),
1082    mask_coeff: (ArrayView2<MASK>, Quantization),
1083    protos: (ArrayView3<PROTO>, Quantization),
1084    score_threshold: f32,
1085    iou_threshold: f32,
1086    nms: Option<Nms>,
1087    output_boxes: &mut Vec<DetectBox>,
1088) -> ProtoData
1089where
1090    f32: AsPrimitive<SCORE>,
1091{
1092    let det_indices = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
1093        boxes,
1094        scores,
1095        score_threshold,
1096        iou_threshold,
1097        nms,
1098        output_boxes.capacity(),
1099    );
1100
1101    let (masks, quant_masks) = mask_coeff;
1102    let masks = masks.reversed_axes();
1103    let (protos_arr, quant_protos) = protos;
1104
1105    extract_proto_data_quant(
1106        det_indices,
1107        masks,
1108        quant_masks,
1109        protos_arr,
1110        quant_protos,
1111        output_boxes,
1112    )
1113}
1114
1115/// Proto-extraction variant of `impl_yolo_split_segdet_float`.
1116/// Runs NMS but returns raw `ProtoData` instead of materialized masks.
1117#[allow(clippy::too_many_arguments)]
1118pub fn impl_yolo_split_segdet_float_proto<
1119    B: BBoxTypeTrait,
1120    BOX: Float + AsPrimitive<f32> + Send + Sync,
1121    SCORE: Float + AsPrimitive<f32> + Send + Sync,
1122    MASK: Float + AsPrimitive<f32> + Send + Sync,
1123    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1124>(
1125    boxes_tensor: ArrayView2<BOX>,
1126    scores_tensor: ArrayView2<SCORE>,
1127    mask_tensor: ArrayView2<MASK>,
1128    protos: ArrayView3<PROTO>,
1129    score_threshold: f32,
1130    iou_threshold: f32,
1131    nms: Option<Nms>,
1132    output_boxes: &mut Vec<DetectBox>,
1133) -> ProtoData
1134where
1135    f32: AsPrimitive<SCORE>,
1136{
1137    let boxes_tensor = boxes_tensor.reversed_axes();
1138    let scores_tensor = scores_tensor.reversed_axes();
1139    let mask_tensor = mask_tensor.reversed_axes();
1140
1141    let det_indices = postprocess_boxes_index_float::<B, _, _>(
1142        score_threshold.as_(),
1143        boxes_tensor,
1144        scores_tensor,
1145    );
1146    let mut det_indices = dispatch_nms_extra_float(nms, iou_threshold, det_indices);
1147    det_indices.truncate(output_boxes.capacity());
1148
1149    extract_proto_data_float(det_indices, mask_tensor, protos, output_boxes)
1150}
1151
1152/// Proto-extraction variant of `decode_yolo_end_to_end_segdet_float`.
1153pub fn decode_yolo_end_to_end_segdet_float_proto<T>(
1154    output: ArrayView2<T>,
1155    protos: ArrayView3<T>,
1156    score_threshold: f32,
1157    output_boxes: &mut Vec<DetectBox>,
1158) -> Result<ProtoData, crate::DecoderError>
1159where
1160    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1161    f32: AsPrimitive<T>,
1162{
1163    if output.shape()[0] < 7 {
1164        return Err(crate::DecoderError::InvalidShape(format!(
1165            "End-to-end segdet output requires at least 7 rows, got {}",
1166            output.shape()[0]
1167        )));
1168    }
1169
1170    let num_mask_coeffs = output.shape()[0] - 6;
1171    let num_protos = protos.shape()[2];
1172    if num_mask_coeffs != num_protos {
1173        return Err(crate::DecoderError::InvalidShape(format!(
1174            "Mask coefficients count ({}) doesn't match protos count ({})",
1175            num_mask_coeffs, num_protos
1176        )));
1177    }
1178
1179    let boxes = output.slice(s![0..4, ..]).reversed_axes();
1180    let scores = output.slice(s![4..5, ..]).reversed_axes();
1181    let classes = output.slice(s![5, ..]);
1182    let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
1183    let mut det_indices =
1184        postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
1185    det_indices.truncate(output_boxes.capacity());
1186
1187    for (b, ind) in &mut det_indices {
1188        b.label = classes[*ind].as_() as usize;
1189    }
1190
1191    Ok(extract_proto_data_float(
1192        det_indices,
1193        mask_coeff,
1194        protos,
1195        output_boxes,
1196    ))
1197}
1198
1199/// Proto-extraction variant of `decode_yolo_split_end_to_end_segdet_float`.
1200#[allow(clippy::too_many_arguments)]
1201pub fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
1202    boxes: ArrayView2<T>,
1203    scores: ArrayView2<T>,
1204    classes: ArrayView2<T>,
1205    mask_coeff: ArrayView2<T>,
1206    protos: ArrayView3<T>,
1207    score_threshold: f32,
1208    output_boxes: &mut Vec<DetectBox>,
1209) -> Result<ProtoData, crate::DecoderError>
1210where
1211    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
1212    f32: AsPrimitive<T>,
1213{
1214    let n = boxes.shape()[0];
1215    if boxes.shape()[1] != 4 {
1216        return Err(crate::DecoderError::InvalidShape(format!(
1217            "Split end-to-end boxes must have 4 columns, got {}",
1218            boxes.shape()[1]
1219        )));
1220    }
1221
1222    let mut qualifying: Vec<(DetectBox, usize)> = Vec::with_capacity(output_boxes.capacity());
1223    for i in 0..n {
1224        let score: f32 = scores[[i, 0]].as_();
1225        if score < score_threshold {
1226            continue;
1227        }
1228        if qualifying.len() >= output_boxes.capacity() {
1229            break;
1230        }
1231        qualifying.push((
1232            DetectBox {
1233                bbox: BoundingBox {
1234                    xmin: boxes[[i, 0]].as_(),
1235                    ymin: boxes[[i, 1]].as_(),
1236                    xmax: boxes[[i, 2]].as_(),
1237                    ymax: boxes[[i, 3]].as_(),
1238                },
1239                score,
1240                label: classes[[i, 0]].as_() as usize,
1241            },
1242            i,
1243        ));
1244    }
1245
1246    Ok(extract_proto_data_float(
1247        qualifying,
1248        mask_coeff,
1249        protos,
1250        output_boxes,
1251    ))
1252}
1253
1254/// Helper: extract ProtoData from float mask coefficients + protos.
1255fn extract_proto_data_float<
1256    MASK: Float + AsPrimitive<f32> + Send + Sync,
1257    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1258>(
1259    det_indices: Vec<(DetectBox, usize)>,
1260    mask_tensor: ArrayView2<MASK>,
1261    protos: ArrayView3<PROTO>,
1262    output_boxes: &mut Vec<DetectBox>,
1263) -> ProtoData {
1264    let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1265    output_boxes.clear();
1266    for (det, idx) in det_indices {
1267        output_boxes.push(det);
1268        let row = mask_tensor.row(idx);
1269        mask_coefficients.push(row.iter().map(|v| v.as_()).collect());
1270    }
1271    let protos_f32 = protos.map(|v| v.as_());
1272    ProtoData {
1273        mask_coefficients,
1274        protos: ProtoTensor::Float(protos_f32),
1275    }
1276}
1277
1278/// Helper: extract ProtoData from quantized mask coefficients + protos.
1279///
1280/// Dequantizes mask coefficients to f32 (small — per-detection) but keeps
1281/// protos in raw int8 form wrapped in `ProtoTensor::Quantized` so the GPU
1282/// shader can dequantize per-texel without CPU overhead.
1283pub(crate) fn extract_proto_data_quant<
1284    MASK: PrimInt + AsPrimitive<f32> + Send + Sync,
1285    PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync,
1286>(
1287    det_indices: Vec<(DetectBox, usize)>,
1288    mask_tensor: ArrayView2<MASK>,
1289    quant_masks: Quantization,
1290    protos: ArrayView3<PROTO>,
1291    quant_protos: Quantization,
1292    output_boxes: &mut Vec<DetectBox>,
1293) -> ProtoData {
1294    let mut mask_coefficients = Vec::with_capacity(det_indices.len());
1295    output_boxes.clear();
1296    for (det, idx) in det_indices {
1297        output_boxes.push(det);
1298        let row = mask_tensor.row(idx);
1299        mask_coefficients.push(
1300            row.iter()
1301                .map(|v| (v.as_() - quant_masks.zero_point as f32) * quant_masks.scale)
1302                .collect(),
1303        );
1304    }
1305    // Keep protos in raw int8 — GPU shader will dequantize per-texel.
1306    let protos_i8 = protos.map(|v| {
1307        let v_i8: i8 = v.as_();
1308        v_i8
1309    });
1310    ProtoData {
1311        mask_coefficients,
1312        protos: ProtoTensor::Quantized {
1313            protos: protos_i8,
1314            quantization: quant_protos,
1315        },
1316    }
1317}
1318
1319fn postprocess_yolo<'a, T>(
1320    output: &'a ArrayView2<'_, T>,
1321) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
1322    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1323    let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
1324    (boxes_tensor, scores_tensor)
1325}
1326
1327fn postprocess_yolo_seg<'a, T>(
1328    output: &'a ArrayView2<'_, T>,
1329) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
1330    assert!(output.shape()[0] > 32 + 4, "Output shape is too short");
1331    let num_classes = output.shape()[0] - 4 - 32;
1332    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
1333    let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
1334    let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
1335    (boxes_tensor, scores_tensor, mask_tensor)
1336}
1337
1338fn decode_segdet_f32<
1339    MASK: Float + AsPrimitive<f32> + Send + Sync,
1340    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1341>(
1342    boxes: Vec<(DetectBox, usize)>,
1343    masks: ArrayView2<MASK>,
1344    protos: ArrayView3<PROTO>,
1345) -> Vec<(DetectBox, Array3<u8>)> {
1346    if boxes.is_empty() {
1347        return Vec::new();
1348    }
1349    assert!(masks.shape()[1] == protos.shape()[2]);
1350    boxes
1351        .into_par_iter()
1352        .map(|mut b| {
1353            let ind = b.1;
1354            let (protos, roi) = protobox(&protos, &b.0.bbox);
1355            b.0.bbox = roi;
1356            (b.0, make_segmentation(masks.row(ind), protos.view()))
1357        })
1358        .collect()
1359}
1360
1361pub(crate) fn decode_segdet_quant<
1362    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1363    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
1364>(
1365    boxes: Vec<(DetectBox, usize)>,
1366    masks: ArrayView2<MASK>,
1367    protos: ArrayView3<PROTO>,
1368    quant_masks: Quantization,
1369    quant_protos: Quantization,
1370) -> Vec<(DetectBox, Array3<u8>)> {
1371    if boxes.is_empty() {
1372        return Vec::new();
1373    }
1374    assert!(masks.shape()[1] == protos.shape()[2]);
1375
1376    let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; // 32 protos is 2^5
1377    boxes
1378        .into_iter()
1379        .map(|mut b| {
1380            let i = b.1;
1381            let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical());
1382            b.0.bbox = roi;
1383            let seg = match total_bits {
1384                0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
1385                    masks.row(i),
1386                    protos.view(),
1387                    quant_masks,
1388                    quant_protos,
1389                ),
1390                65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
1391                    masks.row(i),
1392                    protos.view(),
1393                    quant_masks,
1394                    quant_protos,
1395                ),
1396                _ => panic!("Unsupported bit width for segmentation computation"),
1397            };
1398            (b.0, seg)
1399        })
1400        .collect()
1401}
1402
1403fn protobox<'a, T>(
1404    protos: &'a ArrayView3<T>,
1405    roi: &BoundingBox,
1406) -> (ArrayView3<'a, T>, BoundingBox) {
1407    let width = protos.dim().1 as f32;
1408    let height = protos.dim().0 as f32;
1409
1410    let roi = [
1411        (roi.xmin * width).clamp(0.0, width) as usize,
1412        (roi.ymin * height).clamp(0.0, height) as usize,
1413        (roi.xmax * width).clamp(0.0, width).ceil() as usize,
1414        (roi.ymax * height).clamp(0.0, height).ceil() as usize,
1415    ];
1416
1417    let roi_norm = [
1418        roi[0] as f32 / width,
1419        roi[1] as f32 / height,
1420        roi[2] as f32 / width,
1421        roi[3] as f32 / height,
1422    ]
1423    .into();
1424
1425    let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
1426
1427    (cropped, roi_norm)
1428}
1429
1430fn make_segmentation<
1431    MASK: Float + AsPrimitive<f32> + Send + Sync,
1432    PROTO: Float + AsPrimitive<f32> + Send + Sync,
1433>(
1434    mask: ArrayView1<MASK>,
1435    protos: ArrayView3<PROTO>,
1436) -> Array3<u8> {
1437    let shape = protos.shape();
1438
1439    // Safe to unwrap since the shapes will always be compatible
1440    let mask = mask.to_shape((1, mask.len())).unwrap();
1441    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1442    let protos = protos.reversed_axes();
1443    let mask = mask.map(|x| x.as_());
1444    let protos = protos.map(|x| x.as_());
1445
1446    // Safe to unwrap since the shapes will always be compatible
1447    let mask = mask
1448        .dot(&protos)
1449        .into_shape_with_order((shape[0], shape[1], 1))
1450        .unwrap();
1451
1452    let min = *mask.min().unwrap_or(&0.0);
1453    let max = *mask.max().unwrap_or(&1.0);
1454    let max = max.max(-min);
1455    let min = -max;
1456    let u8_max = 256.0;
1457    mask.map(|x| ((*x - min) / (max - min) * u8_max) as u8)
1458}
1459
/// Integer-arithmetic counterpart of `make_segmentation` for quantized inputs.
///
/// `mask` holds one quantized coefficient per proto, and `protos` is a
/// (height, width, num_protos) crop. The steps are:
/// 1. Widen both operands to `DEST` and subtract each tensor's zero point.
/// 2. Compute the dot product of `mask` with every pixel's proto vector.
/// 3. Map the result to `u8` over a range symmetric around zero.
///
/// Returns an array of shape (height, width, 1).
///
/// The quantization scales are never applied. A common positive scale factor
/// cancels out in the min/max normalization.
/// NOTE(review): this assumes both scales are positive — confirm.
///
/// `DEST` must be wide enough for the accumulated dot product;
/// `decode_segdet_quant` picks `i64` or `i128` for this.
fn make_segmentation_quant<
    MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
    PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
    DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
>(
    mask: ArrayView1<MASK>,
    protos: ArrayView3<PROTO>,
    quant_masks: Quantization,
    quant_protos: Quantization,
) -> Array3<u8>
where
    i32: AsPrimitive<DEST>,
    f32: AsPrimitive<DEST>,
{
    let shape = protos.shape();

    // Safe to unwrap since the shapes will always be compatible
    // (1, num_protos) row vector of coefficients.
    let mask = mask.to_shape((1, mask.len())).unwrap();

    // Flatten the crop to (height * width, num_protos), then transpose to
    // (num_protos, height * width) so the dot product yields one value per pixel.
    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
    let protos = protos.reversed_axes();

    let zp = quant_masks.zero_point.as_();

    // Widen to DEST before subtracting the zero point, so the subtraction
    // cannot overflow the narrow input type.
    let mask = mask.mapv(|x| x.as_() - zp);

    let zp = quant_protos.zero_point.as_();
    let protos = protos.mapv(|x| x.as_() - zp);

    // Safe to unwrap since the shapes will always be compatible
    let segmentation = mask
        .dot(&protos)
        .into_shape_with_order((shape[0], shape[1], 1))
        .unwrap();

    // Symmetric range [-max, max] covering every value.
    // NOTE(review): `-min` would overflow if min == DEST::MIN.
    let min = *segmentation.min().unwrap_or(&DEST::zero());
    let max = *segmentation.max().unwrap_or(&DEST::one());
    let max = max.max(-min);
    let min = -max;
    // Scale by 256: the top of the range saturates to 255 in the `as u8` cast.
    // An all-zero result gives 0/0 = NaN, which casts to 0.
    segmentation.map(|x| ((*x - min).as_() / (max - min).as_() * 256.0) as u8)
}
1501
1502/// Converts Yolo Instance Segmentation into a 2D mask.
1503///
1504/// The input segmentation is expected to have shape (H, W, 1).
1505///
1506/// The output mask will have shape (H, W), with values 0 or 1 based on the
1507/// threshold.
1508///
1509/// # Errors
1510///
1511/// Returns `DecoderError::InvalidShape` if the input segmentation does not
1512/// have shape (H, W, 1).
1513pub fn yolo_segmentation_to_mask(
1514    segmentation: ArrayView3<u8>,
1515    threshold: u8,
1516) -> Result<Array2<u8>, crate::DecoderError> {
1517    if segmentation.shape()[2] != 1 {
1518        return Err(crate::DecoderError::InvalidShape(format!(
1519            "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
1520            segmentation.shape()[2]
1521        )));
1522    }
1523    Ok(segmentation
1524        .slice(s![.., .., 0])
1525        .map(|x| if *x >= threshold { 1 } else { 0 }))
1526}
1527
1528#[cfg(test)]
1529#[cfg_attr(coverage_nightly, coverage(off))]
1530mod tests {
1531    use super::*;
1532    use ndarray::Array2;
1533
1534    // ========================================================================
1535    // Tests for decode_yolo_end_to_end_det_float
1536    // ========================================================================
1537
1538    #[test]
1539    fn test_end_to_end_det_basic_filtering() {
1540        // Create synthetic end-to-end detection output: (6, N) where rows are
1541        // [x1, y1, x2, y2, conf, class]
1542        // 3 detections: one above threshold, two below
1543        let data: Vec<f32> = vec![
1544            // Detection 0: high score (0.9)
1545            0.1, 0.2, 0.3, // x1 values
1546            0.1, 0.2, 0.3, // y1 values
1547            0.5, 0.6, 0.7, // x2 values
1548            0.5, 0.6, 0.7, // y2 values
1549            0.9, 0.1, 0.2, // confidence scores
1550            0.0, 1.0, 2.0, // class indices
1551        ];
1552        let output = Array2::from_shape_vec((6, 3), data).unwrap();
1553
1554        let mut boxes = Vec::with_capacity(10);
1555        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1556
1557        // Only 1 detection should pass threshold of 0.5
1558        assert_eq!(boxes.len(), 1);
1559        assert_eq!(boxes[0].label, 0);
1560        assert!((boxes[0].score - 0.9).abs() < 0.01);
1561        assert!((boxes[0].bbox.xmin - 0.1).abs() < 0.01);
1562        assert!((boxes[0].bbox.ymin - 0.1).abs() < 0.01);
1563        assert!((boxes[0].bbox.xmax - 0.5).abs() < 0.01);
1564        assert!((boxes[0].bbox.ymax - 0.5).abs() < 0.01);
1565    }
1566
1567    #[test]
1568    fn test_end_to_end_det_all_pass_threshold() {
1569        // All detections above threshold
1570        let data: Vec<f32> = vec![
1571            10.0, 20.0, // x1
1572            10.0, 20.0, // y1
1573            50.0, 60.0, // x2
1574            50.0, 60.0, // y2
1575            0.8, 0.7, // conf (both above 0.5)
1576            1.0, 2.0, // class
1577        ];
1578        let output = Array2::from_shape_vec((6, 2), data).unwrap();
1579
1580        let mut boxes = Vec::with_capacity(10);
1581        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1582
1583        assert_eq!(boxes.len(), 2);
1584        assert_eq!(boxes[0].label, 1);
1585        assert_eq!(boxes[1].label, 2);
1586    }
1587
1588    #[test]
1589    fn test_end_to_end_det_none_pass_threshold() {
1590        // All detections below threshold
1591        let data: Vec<f32> = vec![
1592            10.0, 20.0, // x1
1593            10.0, 20.0, // y1
1594            50.0, 60.0, // x2
1595            50.0, 60.0, // y2
1596            0.1, 0.2, // conf (both below 0.5)
1597            1.0, 2.0, // class
1598        ];
1599        let output = Array2::from_shape_vec((6, 2), data).unwrap();
1600
1601        let mut boxes = Vec::with_capacity(10);
1602        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1603
1604        assert_eq!(boxes.len(), 0);
1605    }
1606
1607    #[test]
1608    fn test_end_to_end_det_capacity_limit() {
1609        // Test that output is truncated to capacity
1610        let data: Vec<f32> = vec![
1611            0.1, 0.2, 0.3, 0.4, 0.5, // x1
1612            0.1, 0.2, 0.3, 0.4, 0.5, // y1
1613            0.5, 0.6, 0.7, 0.8, 0.9, // x2
1614            0.5, 0.6, 0.7, 0.8, 0.9, // y2
1615            0.9, 0.9, 0.9, 0.9, 0.9, // conf (all pass)
1616            0.0, 1.0, 2.0, 3.0, 4.0, // class
1617        ];
1618        let output = Array2::from_shape_vec((6, 5), data).unwrap();
1619
1620        let mut boxes = Vec::with_capacity(2); // Only allow 2 boxes
1621        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1622
1623        assert_eq!(boxes.len(), 2);
1624    }
1625
1626    #[test]
1627    fn test_end_to_end_det_empty_output() {
1628        // Test with zero detections
1629        let output = Array2::<f32>::zeros((6, 0));
1630
1631        let mut boxes = Vec::with_capacity(10);
1632        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1633
1634        assert_eq!(boxes.len(), 0);
1635    }
1636
1637    #[test]
1638    fn test_end_to_end_det_pixel_coordinates() {
1639        // Test with pixel coordinates (non-normalized)
1640        let data: Vec<f32> = vec![
1641            100.0, // x1
1642            200.0, // y1
1643            300.0, // x2
1644            400.0, // y2
1645            0.95,  // conf
1646            5.0,   // class
1647        ];
1648        let output = Array2::from_shape_vec((6, 1), data).unwrap();
1649
1650        let mut boxes = Vec::with_capacity(10);
1651        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
1652
1653        assert_eq!(boxes.len(), 1);
1654        assert_eq!(boxes[0].label, 5);
1655        assert!((boxes[0].bbox.xmin - 100.0).abs() < 0.01);
1656        assert!((boxes[0].bbox.ymin - 200.0).abs() < 0.01);
1657        assert!((boxes[0].bbox.xmax - 300.0).abs() < 0.01);
1658        assert!((boxes[0].bbox.ymax - 400.0).abs() < 0.01);
1659    }
1660
1661    #[test]
1662    fn test_end_to_end_det_invalid_shape() {
1663        // Test with too few rows (needs at least 6)
1664        let output = Array2::<f32>::zeros((5, 3));
1665
1666        let mut boxes = Vec::with_capacity(10);
1667        let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);
1668
1669        assert!(result.is_err());
1670        assert!(matches!(
1671            result,
1672            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
1673        ));
1674    }
1675
1676    // ========================================================================
1677    // Tests for decode_yolo_end_to_end_segdet_float
1678    // ========================================================================
1679
1680    #[test]
1681    fn test_end_to_end_segdet_basic() {
1682        // Create synthetic segdet output: (6 + num_protos, N)
1683        // Detection format: [x1, y1, x2, y2, conf, class, mask_coeff_0..31]
1684        let num_protos = 32;
1685        let num_detections = 2;
1686        let num_features = 6 + num_protos;
1687
1688        // Build detection tensor
1689        let mut data = vec![0.0f32; num_features * num_detections];
1690        // Detection 0: passes threshold
1691        data[0] = 0.1; // x1[0]
1692        data[1] = 0.5; // x1[1]
1693        data[num_detections] = 0.1; // y1[0]
1694        data[num_detections + 1] = 0.5; // y1[1]
1695        data[2 * num_detections] = 0.4; // x2[0]
1696        data[2 * num_detections + 1] = 0.9; // x2[1]
1697        data[3 * num_detections] = 0.4; // y2[0]
1698        data[3 * num_detections + 1] = 0.9; // y2[1]
1699        data[4 * num_detections] = 0.9; // conf[0] - passes
1700        data[4 * num_detections + 1] = 0.3; // conf[1] - fails
1701        data[5 * num_detections] = 1.0; // class[0]
1702        data[5 * num_detections + 1] = 2.0; // class[1]
1703                                            // Fill mask coefficients with small values
1704        for i in 6..num_features {
1705            data[i * num_detections] = 0.1;
1706            data[i * num_detections + 1] = 0.1;
1707        }
1708
1709        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
1710
1711        // Create protos tensor: (proto_height, proto_width, num_protos)
1712        let protos = Array3::<f32>::zeros((16, 16, num_protos));
1713
1714        let mut boxes = Vec::with_capacity(10);
1715        let mut masks = Vec::with_capacity(10);
1716        decode_yolo_end_to_end_segdet_float(
1717            output.view(),
1718            protos.view(),
1719            0.5,
1720            &mut boxes,
1721            &mut masks,
1722        )
1723        .unwrap();
1724
1725        // Only detection 0 should pass
1726        assert_eq!(boxes.len(), 1);
1727        assert_eq!(masks.len(), 1);
1728        assert_eq!(boxes[0].label, 1);
1729        assert!((boxes[0].score - 0.9).abs() < 0.01);
1730    }
1731
1732    #[test]
1733    fn test_end_to_end_segdet_mask_coordinates() {
1734        // Test that mask coordinates match box coordinates
1735        let num_protos = 32;
1736        let num_features = 6 + num_protos;
1737
1738        let mut data = vec![0.0f32; num_features];
1739        data[0] = 0.2; // x1
1740        data[1] = 0.2; // y1
1741        data[2] = 0.8; // x2
1742        data[3] = 0.8; // y2
1743        data[4] = 0.95; // conf
1744        data[5] = 3.0; // class
1745
1746        let output = Array2::from_shape_vec((num_features, 1), data).unwrap();
1747        let protos = Array3::<f32>::zeros((16, 16, num_protos));
1748
1749        let mut boxes = Vec::with_capacity(10);
1750        let mut masks = Vec::with_capacity(10);
1751        decode_yolo_end_to_end_segdet_float(
1752            output.view(),
1753            protos.view(),
1754            0.5,
1755            &mut boxes,
1756            &mut masks,
1757        )
1758        .unwrap();
1759
1760        assert_eq!(boxes.len(), 1);
1761        assert_eq!(masks.len(), 1);
1762
1763        // Verify mask coordinates match box coordinates
1764        assert!((masks[0].xmin - boxes[0].bbox.xmin).abs() < 0.01);
1765        assert!((masks[0].ymin - boxes[0].bbox.ymin).abs() < 0.01);
1766        assert!((masks[0].xmax - boxes[0].bbox.xmax).abs() < 0.01);
1767        assert!((masks[0].ymax - boxes[0].bbox.ymax).abs() < 0.01);
1768    }
1769
1770    #[test]
1771    fn test_end_to_end_segdet_empty_output() {
1772        let num_protos = 32;
1773        let output = Array2::<f32>::zeros((6 + num_protos, 0));
1774        let protos = Array3::<f32>::zeros((16, 16, num_protos));
1775
1776        let mut boxes = Vec::with_capacity(10);
1777        let mut masks = Vec::with_capacity(10);
1778        decode_yolo_end_to_end_segdet_float(
1779            output.view(),
1780            protos.view(),
1781            0.5,
1782            &mut boxes,
1783            &mut masks,
1784        )
1785        .unwrap();
1786
1787        assert_eq!(boxes.len(), 0);
1788        assert_eq!(masks.len(), 0);
1789    }
1790
1791    #[test]
1792    fn test_end_to_end_segdet_capacity_limit() {
1793        let num_protos = 32;
1794        let num_detections = 5;
1795        let num_features = 6 + num_protos;
1796
1797        let mut data = vec![0.0f32; num_features * num_detections];
1798        // All detections pass threshold
1799        for i in 0..num_detections {
1800            data[i] = 0.1 * (i as f32); // x1
1801            data[num_detections + i] = 0.1 * (i as f32); // y1
1802            data[2 * num_detections + i] = 0.1 * (i as f32) + 0.2; // x2
1803            data[3 * num_detections + i] = 0.1 * (i as f32) + 0.2; // y2
1804            data[4 * num_detections + i] = 0.9; // conf
1805            data[5 * num_detections + i] = i as f32; // class
1806        }
1807
1808        let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
1809        let protos = Array3::<f32>::zeros((16, 16, num_protos));
1810
1811        let mut boxes = Vec::with_capacity(2); // Limit to 2
1812        let mut masks = Vec::with_capacity(2);
1813        decode_yolo_end_to_end_segdet_float(
1814            output.view(),
1815            protos.view(),
1816            0.5,
1817            &mut boxes,
1818            &mut masks,
1819        )
1820        .unwrap();
1821
1822        assert_eq!(boxes.len(), 2);
1823        assert_eq!(masks.len(), 2);
1824    }
1825
1826    #[test]
1827    fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
1828        // Test with too few rows (needs at least 7: 6 base + 1 mask coeff)
1829        let output = Array2::<f32>::zeros((6, 3));
1830        let protos = Array3::<f32>::zeros((16, 16, 32));
1831
1832        let mut boxes = Vec::with_capacity(10);
1833        let mut masks = Vec::with_capacity(10);
1834        let result = decode_yolo_end_to_end_segdet_float(
1835            output.view(),
1836            protos.view(),
1837            0.5,
1838            &mut boxes,
1839            &mut masks,
1840        );
1841
1842        assert!(result.is_err());
1843        assert!(matches!(
1844            result,
1845            Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
1846        ));
1847    }
1848
1849    #[test]
1850    fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
1851        // Test with mismatched mask coefficients and protos count
1852        let num_protos = 32;
1853        let output = Array2::<f32>::zeros((6 + 16, 3)); // 16 mask coeffs
1854        let protos = Array3::<f32>::zeros((16, 16, num_protos)); // 32 protos
1855
1856        let mut boxes = Vec::with_capacity(10);
1857        let mut masks = Vec::with_capacity(10);
1858        let result = decode_yolo_end_to_end_segdet_float(
1859            output.view(),
1860            protos.view(),
1861            0.5,
1862            &mut boxes,
1863            &mut masks,
1864        );
1865
1866        assert!(result.is_err());
1867        assert!(matches!(
1868            result,
1869            Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
1870        ));
1871    }
1872
1873    // ========================================================================
1874    // Tests for yolo_segmentation_to_mask
1875    // ========================================================================
1876
1877    #[test]
1878    fn test_segmentation_to_mask_basic() {
1879        // Create a 4x4x1 segmentation with values above and below threshold
1880        let data: Vec<u8> = vec![
1881            100, 200, 50, 150, // row 0
1882            10, 255, 128, 64, // row 1
1883            0, 127, 128, 255, // row 2
1884            64, 64, 192, 192, // row 3
1885        ];
1886        let segmentation = Array3::from_shape_vec((4, 4, 1), data).unwrap();
1887
1888        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
1889
1890        // Values >= 128 should be 1, others 0
1891        assert_eq!(mask[[0, 0]], 0); // 100 < 128
1892        assert_eq!(mask[[0, 1]], 1); // 200 >= 128
1893        assert_eq!(mask[[0, 2]], 0); // 50 < 128
1894        assert_eq!(mask[[0, 3]], 1); // 150 >= 128
1895        assert_eq!(mask[[1, 1]], 1); // 255 >= 128
1896        assert_eq!(mask[[1, 2]], 1); // 128 >= 128
1897        assert_eq!(mask[[2, 0]], 0); // 0 < 128
1898        assert_eq!(mask[[2, 1]], 0); // 127 < 128
1899    }
1900
1901    #[test]
1902    fn test_segmentation_to_mask_all_above() {
1903        let segmentation = Array3::from_elem((4, 4, 1), 255u8);
1904        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
1905        assert!(mask.iter().all(|&x| x == 1));
1906    }
1907
1908    #[test]
1909    fn test_segmentation_to_mask_all_below() {
1910        let segmentation = Array3::from_elem((4, 4, 1), 64u8);
1911        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
1912        assert!(mask.iter().all(|&x| x == 0));
1913    }
1914
1915    #[test]
1916    fn test_segmentation_to_mask_invalid_shape() {
1917        let segmentation = Array3::from_elem((4, 4, 3), 128u8);
1918        let result = yolo_segmentation_to_mask(segmentation.view(), 128);
1919
1920        assert!(result.is_err());
1921        assert!(matches!(
1922            result,
1923            Err(crate::DecoderError::InvalidShape(s)) if s.contains("(H, W, 1)")
1924        ));
1925    }
1926}