Skip to main content

edgefirst_decoder/
yolo.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::fmt::Debug;
5
6use ndarray::{
7    parallel::prelude::{IntoParallelIterator, ParallelIterator},
8    s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
9};
10use ndarray_stats::QuantileExt;
11use num_traits::{AsPrimitive, Float, PrimInt, Signed};
12
13use crate::{
14    byte::{
15        nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
16        postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
17    },
18    configs::Nms,
19    dequant_detect_box,
20    float::{
21        nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
22        postprocess_boxes_float, postprocess_boxes_index_float,
23    },
24    BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, Quantization, Segmentation, XYWH,
25    XYXY,
26};
27
28/// Dispatches to the appropriate NMS function based on mode for float boxes.
29fn dispatch_nms_float(nms: Option<Nms>, iou: f32, boxes: Vec<DetectBox>) -> Vec<DetectBox> {
30    match nms {
31        Some(Nms::ClassAgnostic) => nms_float(iou, boxes),
32        Some(Nms::ClassAware) => nms_class_aware_float(iou, boxes),
33        None => boxes, // bypass NMS
34    }
35}
36
37/// Dispatches to the appropriate NMS function based on mode for float boxes
38/// with extra data.
39fn dispatch_nms_extra_float<E: Send + Sync>(
40    nms: Option<Nms>,
41    iou: f32,
42    boxes: Vec<(DetectBox, E)>,
43) -> Vec<(DetectBox, E)> {
44    match nms {
45        Some(Nms::ClassAgnostic) => nms_extra_float(iou, boxes),
46        Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, boxes),
47        None => boxes, // bypass NMS
48    }
49}
50
51/// Dispatches to the appropriate NMS function based on mode for quantized
52/// boxes.
53fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
54    nms: Option<Nms>,
55    iou: f32,
56    boxes: Vec<DetectBoxQuantized<SCORE>>,
57) -> Vec<DetectBoxQuantized<SCORE>> {
58    match nms {
59        Some(Nms::ClassAgnostic) => nms_int(iou, boxes),
60        Some(Nms::ClassAware) => nms_class_aware_int(iou, boxes),
61        None => boxes, // bypass NMS
62    }
63}
64
65/// Dispatches to the appropriate NMS function based on mode for quantized boxes
66/// with extra data.
67fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
68    nms: Option<Nms>,
69    iou: f32,
70    boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
71) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
72    match nms {
73        Some(Nms::ClassAgnostic) => nms_extra_int(iou, boxes),
74        Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, boxes),
75        None => boxes, // bypass NMS
76    }
77}
78
79/// Decodes YOLO detection outputs from quantized tensors into detection boxes.
80///
81/// Boxes are expected to be in XYWH format.
82///
83/// Expected shapes of inputs:
84/// - output: (4 + num_classes, num_boxes)
85pub fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
86    output: (ArrayView2<BOX>, Quantization),
87    score_threshold: f32,
88    iou_threshold: f32,
89    nms: Option<Nms>,
90    output_boxes: &mut Vec<DetectBox>,
91) where
92    f32: AsPrimitive<BOX>,
93{
94    impl_yolo_quant::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
95}
96
97/// Decodes YOLO detection outputs from float tensors into detection boxes.
98///
99/// Boxes are expected to be in XYWH format.
100///
101/// Expected shapes of inputs:
102/// - output: (4 + num_classes, num_boxes)
103pub fn decode_yolo_det_float<T>(
104    output: ArrayView2<T>,
105    score_threshold: f32,
106    iou_threshold: f32,
107    nms: Option<Nms>,
108    output_boxes: &mut Vec<DetectBox>,
109) where
110    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
111    f32: AsPrimitive<T>,
112{
113    impl_yolo_float::<XYWH, _>(output, score_threshold, iou_threshold, nms, output_boxes);
114}
115
116/// Decodes YOLO detection and segmentation outputs from quantized tensors into
117/// detection boxes and segmentation masks.
118///
119/// Boxes are expected to be in XYWH format.
120///
121/// Expected shapes of inputs:
122/// - boxes: (4 + num_classes + num_protos, num_boxes)
123/// - protos: (proto_height, proto_width, num_protos)
124///
125/// # Panics
126/// Panics if shapes don't match the expected dimensions.
127pub fn decode_yolo_segdet_quant<
128    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
129    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
130>(
131    boxes: (ArrayView2<BOX>, Quantization),
132    protos: (ArrayView3<PROTO>, Quantization),
133    score_threshold: f32,
134    iou_threshold: f32,
135    nms: Option<Nms>,
136    output_boxes: &mut Vec<DetectBox>,
137    output_masks: &mut Vec<Segmentation>,
138) where
139    f32: AsPrimitive<BOX>,
140{
141    impl_yolo_segdet_quant::<XYWH, _, _>(
142        boxes,
143        protos,
144        score_threshold,
145        iou_threshold,
146        nms,
147        output_boxes,
148        output_masks,
149    );
150}
151
152/// Decodes YOLO detection and segmentation outputs from float tensors into
153/// detection boxes and segmentation masks.
154///
155/// Boxes are expected to be in XYWH format.
156///
157/// Expected shapes of inputs:
158/// - boxes: (4 + num_classes + num_protos, num_boxes)
159/// - protos: (proto_height, proto_width, num_protos)
160///
161/// # Panics
162/// Panics if shapes don't match the expected dimensions.
163pub fn decode_yolo_segdet_float<T>(
164    boxes: ArrayView2<T>,
165    protos: ArrayView3<T>,
166    score_threshold: f32,
167    iou_threshold: f32,
168    nms: Option<Nms>,
169    output_boxes: &mut Vec<DetectBox>,
170    output_masks: &mut Vec<Segmentation>,
171) where
172    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
173    f32: AsPrimitive<T>,
174{
175    impl_yolo_segdet_float::<XYWH, _, _>(
176        boxes,
177        protos,
178        score_threshold,
179        iou_threshold,
180        nms,
181        output_boxes,
182        output_masks,
183    );
184}
185
186/// Decodes YOLO split detection outputs from quantized tensors into detection
187/// boxes.
188///
189/// Boxes are expected to be in XYWH format.
190///
191/// Expected shapes of inputs:
192/// - boxes: (4, num_boxes)
193/// - scores: (num_classes, num_boxes)
194///
195/// # Panics
196/// Panics if shapes don't match the expected dimensions.
197pub fn decode_yolo_split_det_quant<
198    BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
199    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
200>(
201    boxes: (ArrayView2<BOX>, Quantization),
202    scores: (ArrayView2<SCORE>, Quantization),
203    score_threshold: f32,
204    iou_threshold: f32,
205    nms: Option<Nms>,
206    output_boxes: &mut Vec<DetectBox>,
207) where
208    f32: AsPrimitive<SCORE>,
209{
210    impl_yolo_split_quant::<XYWH, _, _>(
211        boxes,
212        scores,
213        score_threshold,
214        iou_threshold,
215        nms,
216        output_boxes,
217    );
218}
219
220/// Decodes YOLO split detection outputs from float tensors into detection
221/// boxes.
222///
223/// Boxes are expected to be in XYWH format.
224///
225/// Expected shapes of inputs:
226/// - boxes: (4, num_boxes)
227/// - scores: (num_classes, num_boxes)
228///
229/// # Panics
230/// Panics if shapes don't match the expected dimensions.
231pub fn decode_yolo_split_det_float<T>(
232    boxes: ArrayView2<T>,
233    scores: ArrayView2<T>,
234    score_threshold: f32,
235    iou_threshold: f32,
236    nms: Option<Nms>,
237    output_boxes: &mut Vec<DetectBox>,
238) where
239    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
240    f32: AsPrimitive<T>,
241{
242    impl_yolo_split_float::<XYWH, _, _>(
243        boxes,
244        scores,
245        score_threshold,
246        iou_threshold,
247        nms,
248        output_boxes,
249    );
250}
251
252/// Decodes YOLO split detection segmentation outputs from quantized tensors
253/// into detection boxes and segmentation masks.
254///
255/// Boxes are expected to be in XYWH format.
256///
257/// Expected shapes of inputs:
258/// - boxes_tensor: (4, num_boxes)
259/// - scores_tensor: (num_classes, num_boxes)
260/// - mask_tensor: (num_protos, num_boxes)
261/// - protos: (proto_height, proto_width, num_protos)
262///
263/// # Panics
264/// Panics if shapes don't match the expected dimensions.
265#[allow(clippy::too_many_arguments)]
266pub fn decode_yolo_split_segdet<
267    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
268    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
269    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
270    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
271>(
272    boxes: (ArrayView2<BOX>, Quantization),
273    scores: (ArrayView2<SCORE>, Quantization),
274    mask_coeff: (ArrayView2<MASK>, Quantization),
275    protos: (ArrayView3<PROTO>, Quantization),
276    score_threshold: f32,
277    iou_threshold: f32,
278    nms: Option<Nms>,
279    output_boxes: &mut Vec<DetectBox>,
280    output_masks: &mut Vec<Segmentation>,
281) where
282    f32: AsPrimitive<SCORE>,
283{
284    impl_yolo_split_segdet_quant::<XYWH, _, _, _, _>(
285        boxes,
286        scores,
287        mask_coeff,
288        protos,
289        score_threshold,
290        iou_threshold,
291        nms,
292        output_boxes,
293        output_masks,
294    );
295}
296
297/// Decodes YOLO split detection segmentation outputs from float tensors
298/// into detection boxes and segmentation masks.
299///
300/// Boxes are expected to be in XYWH format.
301///
302/// Expected shapes of inputs:
303/// - boxes_tensor: (4, num_boxes)
304/// - scores_tensor: (num_classes, num_boxes)
305/// - mask_tensor: (num_protos, num_boxes)
306/// - protos: (proto_height, proto_width, num_protos)
307///
308/// # Panics
309/// Panics if shapes don't match the expected dimensions.
310#[allow(clippy::too_many_arguments)]
311pub fn decode_yolo_split_segdet_float<T>(
312    boxes: ArrayView2<T>,
313    scores: ArrayView2<T>,
314    mask_coeff: ArrayView2<T>,
315    protos: ArrayView3<T>,
316    score_threshold: f32,
317    iou_threshold: f32,
318    nms: Option<Nms>,
319    output_boxes: &mut Vec<DetectBox>,
320    output_masks: &mut Vec<Segmentation>,
321) where
322    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
323    f32: AsPrimitive<T>,
324{
325    impl_yolo_split_segdet_float::<XYWH, _, _, _, _>(
326        boxes,
327        scores,
328        mask_coeff,
329        protos,
330        score_threshold,
331        iou_threshold,
332        nms,
333        output_boxes,
334        output_masks,
335    );
336}
337
338/// Decodes end-to-end YOLO detection outputs (post-NMS from model).
339/// Expects an array of shape `(6, N)`, where the first dimension (rows)
340/// corresponds to the 6 per-detection features
341/// `[x1, y1, x2, y2, conf, class]` and the second dimension (columns)
342/// indexes the `N` detections.
343/// Boxes are output directly without NMS (the model already applied NMS).
344///
345/// Coordinates may be normalized `[0, 1]` or absolute pixel values depending
346/// on the model configuration. The caller should check
347/// `decoder.normalized_boxes()` to determine which.
348///
349/// # Errors
350///
351/// Returns `DecoderError::InvalidShape` if `output` has fewer than 6 rows.
352pub fn decode_yolo_end_to_end_det_float<T>(
353    output: ArrayView2<T>,
354    score_threshold: f32,
355    output_boxes: &mut Vec<DetectBox>,
356) -> Result<(), crate::DecoderError>
357where
358    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
359    f32: AsPrimitive<T>,
360{
361    // Validate input shape: need at least 6 rows (x1, y1, x2, y2, conf, class)
362    if output.shape()[0] < 6 {
363        return Err(crate::DecoderError::InvalidShape(format!(
364            "End-to-end detection output requires at least 6 rows, got {}",
365            output.shape()[0]
366        )));
367    }
368
369    // Input shape: (6, N) -> transpose to (N, 4) for boxes and (N, 1) for scores
370    let boxes = output.slice(s![0..4, ..]).reversed_axes();
371    let scores = output.slice(s![4..5, ..]).reversed_axes();
372    let classes = output.slice(s![5, ..]);
373    let mut boxes =
374        postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
375    boxes.truncate(output_boxes.capacity());
376    output_boxes.clear();
377    for (mut b, i) in boxes.into_iter() {
378        b.label = classes[i].as_() as usize;
379        output_boxes.push(b);
380    }
381    // No NMS — model output is already post-NMS
382    Ok(())
383}
384
385/// Decodes end-to-end YOLO detection + segmentation outputs (post-NMS from
386/// model).
387///
388/// Input shapes:
389/// - detection: (6 + num_protos, N) where rows are [x1, y1, x2, y2, conf,
390///   class, mask_coeff_0, ..., mask_coeff_31]
391/// - protos: (proto_height, proto_width, num_protos)
392///
393/// Boxes are output directly without NMS (model already applied NMS).
394/// Coordinates may be normalized [0,1] or pixel values depending on model
395/// config.
396///
397/// # Errors
398///
399/// Returns `DecoderError::InvalidShape` if:
400/// - output has fewer than 7 rows (6 base + at least 1 mask coefficient)
401/// - protos shape doesn't match mask coefficients count
402pub fn decode_yolo_end_to_end_segdet_float<T>(
403    output: ArrayView2<T>,
404    protos: ArrayView3<T>,
405    score_threshold: f32,
406    output_boxes: &mut Vec<DetectBox>,
407    output_masks: &mut Vec<crate::Segmentation>,
408) -> Result<(), crate::DecoderError>
409where
410    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
411    f32: AsPrimitive<T>,
412{
413    // Validate input shape: need at least 7 rows (6 base + at least 1 mask coeff)
414    if output.shape()[0] < 7 {
415        return Err(crate::DecoderError::InvalidShape(format!(
416            "End-to-end segdet output requires at least 7 rows, got {}",
417            output.shape()[0]
418        )));
419    }
420
421    let num_mask_coeffs = output.shape()[0] - 6;
422    let num_protos = protos.shape()[2];
423    if num_mask_coeffs != num_protos {
424        return Err(crate::DecoderError::InvalidShape(format!(
425            "Mask coefficients count ({}) doesn't match protos count ({})",
426            num_mask_coeffs, num_protos
427        )));
428    }
429
430    // Input shape: (6+num_protos, N) -> transpose for postprocessing
431    let boxes = output.slice(s![0..4, ..]).reversed_axes();
432    let scores = output.slice(s![4..5, ..]).reversed_axes();
433    let classes = output.slice(s![5, ..]);
434    let mask_coeff = output.slice(s![6.., ..]).reversed_axes();
435    let mut boxes =
436        postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
437    boxes.truncate(output_boxes.capacity());
438
439    for (b, ind) in &mut boxes {
440        b.label = classes[*ind].as_() as usize;
441    }
442
443    // No NMS — model output is already post-NMS
444
445    let boxes = decode_segdet_f32(boxes, mask_coeff, protos);
446
447    output_boxes.clear();
448    output_masks.clear();
449    for (b, m) in boxes.into_iter() {
450        output_boxes.push(b);
451        output_masks.push(Segmentation {
452            xmin: b.bbox.xmin,
453            ymin: b.bbox.ymin,
454            xmax: b.bbox.xmax,
455            ymax: b.bbox.ymax,
456            segmentation: m,
457        });
458    }
459    Ok(())
460}
461/// Internal implementation of YOLO decoding for quantized tensors.
462///
463/// Expected shapes of inputs:
464/// - output: (4 + num_classes, num_boxes)
465pub fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
466    output: (ArrayView2<T>, Quantization),
467    score_threshold: f32,
468    iou_threshold: f32,
469    nms: Option<Nms>,
470    output_boxes: &mut Vec<DetectBox>,
471) where
472    f32: AsPrimitive<T>,
473{
474    let (boxes, quant_boxes) = output;
475    let (boxes_tensor, scores_tensor) = postprocess_yolo(&boxes);
476
477    let boxes = {
478        let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
479        postprocess_boxes_quant::<B, _, _>(
480            score_threshold,
481            boxes_tensor,
482            scores_tensor,
483            quant_boxes,
484        )
485    };
486
487    let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
488    let len = output_boxes.capacity().min(boxes.len());
489    output_boxes.clear();
490    for b in boxes.iter().take(len) {
491        output_boxes.push(dequant_detect_box(b, quant_boxes));
492    }
493}
494
495/// Internal implementation of YOLO decoding for float tensors.
496///
497/// Expected shapes of inputs:
498/// - output: (4 + num_classes, num_boxes)
499pub fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
500    output: ArrayView2<T>,
501    score_threshold: f32,
502    iou_threshold: f32,
503    nms: Option<Nms>,
504    output_boxes: &mut Vec<DetectBox>,
505) where
506    f32: AsPrimitive<T>,
507{
508    let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
509    let boxes =
510        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
511    let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
512    let len = output_boxes.capacity().min(boxes.len());
513    output_boxes.clear();
514    for b in boxes.into_iter().take(len) {
515        output_boxes.push(b);
516    }
517}
518
519/// Internal implementation of YOLO split detection decoding for quantized
520/// tensors.
521///
522/// Expected shapes of inputs:
523/// - boxes: (4, num_boxes)
524/// - scores: (num_classes, num_boxes)
525///
526/// # Panics
527/// Panics if shapes don't match the expected dimensions.
528pub fn impl_yolo_split_quant<
529    B: BBoxTypeTrait,
530    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
531    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
532>(
533    boxes: (ArrayView2<BOX>, Quantization),
534    scores: (ArrayView2<SCORE>, Quantization),
535    score_threshold: f32,
536    iou_threshold: f32,
537    nms: Option<Nms>,
538    output_boxes: &mut Vec<DetectBox>,
539) where
540    f32: AsPrimitive<SCORE>,
541{
542    let (boxes_tensor, quant_boxes) = boxes;
543    let (scores_tensor, quant_scores) = scores;
544
545    let boxes_tensor = boxes_tensor.reversed_axes();
546    let scores_tensor = scores_tensor.reversed_axes();
547
548    let boxes = {
549        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
550        postprocess_boxes_quant::<B, _, _>(
551            score_threshold,
552            boxes_tensor,
553            scores_tensor,
554            quant_boxes,
555        )
556    };
557
558    let boxes = dispatch_nms_int(nms, iou_threshold, boxes);
559    let len = output_boxes.capacity().min(boxes.len());
560    output_boxes.clear();
561    for b in boxes.iter().take(len) {
562        output_boxes.push(dequant_detect_box(b, quant_scores));
563    }
564}
565
566/// Internal implementation of YOLO split detection decoding for float tensors.
567///
568/// Expected shapes of inputs:
569/// - boxes: (4, num_boxes)
570/// - scores: (num_classes, num_boxes)
571///
572/// # Panics
573/// Panics if shapes don't match the expected dimensions.
574pub fn impl_yolo_split_float<
575    B: BBoxTypeTrait,
576    BOX: Float + AsPrimitive<f32> + Send + Sync,
577    SCORE: Float + AsPrimitive<f32> + Send + Sync,
578>(
579    boxes_tensor: ArrayView2<BOX>,
580    scores_tensor: ArrayView2<SCORE>,
581    score_threshold: f32,
582    iou_threshold: f32,
583    nms: Option<Nms>,
584    output_boxes: &mut Vec<DetectBox>,
585) where
586    f32: AsPrimitive<SCORE>,
587{
588    let boxes_tensor = boxes_tensor.reversed_axes();
589    let scores_tensor = scores_tensor.reversed_axes();
590    let boxes =
591        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
592    let boxes = dispatch_nms_float(nms, iou_threshold, boxes);
593    let len = output_boxes.capacity().min(boxes.len());
594    output_boxes.clear();
595    for b in boxes.into_iter().take(len) {
596        output_boxes.push(b);
597    }
598}
599
600/// Internal implementation of YOLO detection segmentation decoding for
601/// quantized tensors.
602///
603/// Expected shapes of inputs:
604/// - boxes: (4 + num_classes + num_protos, num_boxes)
605/// - protos: (proto_height, proto_width, num_protos)
606///
607/// # Panics
608/// Panics if shapes don't match the expected dimensions.
609pub fn impl_yolo_segdet_quant<
610    B: BBoxTypeTrait,
611    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
612    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
613>(
614    boxes: (ArrayView2<BOX>, Quantization),
615    protos: (ArrayView3<PROTO>, Quantization),
616    score_threshold: f32,
617    iou_threshold: f32,
618    nms: Option<Nms>,
619    output_boxes: &mut Vec<DetectBox>,
620    output_masks: &mut Vec<Segmentation>,
621) where
622    f32: AsPrimitive<BOX>,
623{
624    let (boxes, quant_boxes) = boxes;
625    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes);
626
627    let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
628        (boxes_tensor.reversed_axes(), quant_boxes),
629        (scores_tensor.reversed_axes(), quant_boxes),
630        score_threshold,
631        iou_threshold,
632        nms,
633        output_boxes.capacity(),
634    );
635
636    impl_yolo_split_segdet_quant_process_masks::<_, _>(
637        boxes,
638        (mask_tensor.reversed_axes(), quant_boxes),
639        protos,
640        output_boxes,
641        output_masks,
642    );
643}
644
645/// Internal implementation of YOLO detection segmentation decoding for
646/// float tensors.
647///
648/// Expected shapes of inputs:
649/// - boxes: (4 + num_classes + num_protos, num_boxes)
650/// - protos: (proto_height, proto_width, num_protos)
651///
652/// # Panics
653/// Panics if shapes don't match the expected dimensions.
654pub fn impl_yolo_segdet_float<
655    B: BBoxTypeTrait,
656    BOX: Float + AsPrimitive<f32> + Send + Sync,
657    PROTO: Float + AsPrimitive<f32> + Send + Sync,
658>(
659    boxes: ArrayView2<BOX>,
660    protos: ArrayView3<PROTO>,
661    score_threshold: f32,
662    iou_threshold: f32,
663    nms: Option<Nms>,
664    output_boxes: &mut Vec<DetectBox>,
665    output_masks: &mut Vec<Segmentation>,
666) where
667    f32: AsPrimitive<BOX>,
668{
669    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&boxes);
670
671    let boxes = postprocess_boxes_index_float::<B, _, _>(
672        score_threshold.as_(),
673        boxes_tensor,
674        scores_tensor,
675    );
676    let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
677    boxes.truncate(output_boxes.capacity());
678    let boxes = decode_segdet_f32(boxes, mask_tensor, protos);
679    output_boxes.clear();
680    output_masks.clear();
681    for (b, m) in boxes.into_iter() {
682        output_boxes.push(b);
683        output_masks.push(Segmentation {
684            xmin: b.bbox.xmin,
685            ymin: b.bbox.ymin,
686            xmax: b.bbox.xmax,
687            ymax: b.bbox.ymax,
688            segmentation: m,
689        });
690    }
691}
692
693pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
694    B: BBoxTypeTrait,
695    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
696    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
697>(
698    boxes: (ArrayView2<BOX>, Quantization),
699    scores: (ArrayView2<SCORE>, Quantization),
700    score_threshold: f32,
701    iou_threshold: f32,
702    nms: Option<Nms>,
703    max_boxes: usize,
704) -> Vec<(DetectBox, usize)>
705where
706    f32: AsPrimitive<SCORE>,
707{
708    let (boxes_tensor, quant_boxes) = boxes;
709    let (scores_tensor, quant_scores) = scores;
710
711    let boxes_tensor = boxes_tensor.reversed_axes();
712    let scores_tensor = scores_tensor.reversed_axes();
713
714    let boxes = {
715        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
716        postprocess_boxes_index_quant::<B, _, _>(
717            score_threshold,
718            boxes_tensor,
719            scores_tensor,
720            quant_boxes,
721        )
722    };
723    let mut boxes = dispatch_nms_extra_int(nms, iou_threshold, boxes);
724    boxes.truncate(max_boxes);
725    boxes
726        .into_iter()
727        .map(|(b, i)| (dequant_detect_box(&b, quant_scores), i))
728        .collect()
729}
730
731pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
732    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
733    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
734>(
735    boxes: Vec<(DetectBox, usize)>,
736    mask_coeff: (ArrayView2<MASK>, Quantization),
737    protos: (ArrayView3<PROTO>, Quantization),
738    output_boxes: &mut Vec<DetectBox>,
739    output_masks: &mut Vec<Segmentation>,
740) {
741    let (masks, quant_masks) = mask_coeff;
742    let (protos, quant_protos) = protos;
743
744    let masks = masks.reversed_axes();
745
746    let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos);
747    output_boxes.clear();
748    output_masks.clear();
749    for (b, m) in boxes.into_iter() {
750        output_boxes.push(b);
751        output_masks.push(Segmentation {
752            xmin: b.bbox.xmin,
753            ymin: b.bbox.ymin,
754            xmax: b.bbox.xmax,
755            ymax: b.bbox.ymax,
756            segmentation: m,
757        });
758    }
759}
760
761#[allow(clippy::too_many_arguments)]
762/// Internal implementation of YOLO split detection segmentation decoding for
763/// quantized tensors.
764///
765/// Expected shapes of inputs:
766/// - boxes_tensor: (4, num_boxes)
767/// - scores_tensor: (num_classes, num_boxes)
768/// - mask_tensor: (num_protos, num_boxes)
769/// - protos: (proto_height, proto_width, num_protos)
770///
771/// # Panics
772/// Panics if shapes don't match the expected dimensions.
773pub fn impl_yolo_split_segdet_quant<
774    B: BBoxTypeTrait,
775    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
776    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
777    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
778    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
779>(
780    boxes: (ArrayView2<BOX>, Quantization),
781    scores: (ArrayView2<SCORE>, Quantization),
782    mask_coeff: (ArrayView2<MASK>, Quantization),
783    protos: (ArrayView3<PROTO>, Quantization),
784    score_threshold: f32,
785    iou_threshold: f32,
786    nms: Option<Nms>,
787    output_boxes: &mut Vec<DetectBox>,
788    output_masks: &mut Vec<Segmentation>,
789) where
790    f32: AsPrimitive<SCORE>,
791{
792    let boxes = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
793        boxes,
794        scores,
795        score_threshold,
796        iou_threshold,
797        nms,
798        output_boxes.capacity(),
799    );
800
801    impl_yolo_split_segdet_quant_process_masks(
802        boxes,
803        mask_coeff,
804        protos,
805        output_boxes,
806        output_masks,
807    );
808}
809
810#[allow(clippy::too_many_arguments)]
811/// Internal implementation of YOLO split detection segmentation decoding for
812/// float tensors.
813///
814/// Expected shapes of inputs:
815/// - boxes_tensor: (4, num_boxes)
816/// - scores_tensor: (num_classes, num_boxes)
817/// - mask_tensor: (num_protos, num_boxes)
818/// - protos: (proto_height, proto_width, num_protos)
819///
820/// # Panics
821/// Panics if shapes don't match the expected dimensions.
822pub fn impl_yolo_split_segdet_float<
823    B: BBoxTypeTrait,
824    BOX: Float + AsPrimitive<f32> + Send + Sync,
825    SCORE: Float + AsPrimitive<f32> + Send + Sync,
826    MASK: Float + AsPrimitive<f32> + Send + Sync,
827    PROTO: Float + AsPrimitive<f32> + Send + Sync,
828>(
829    boxes_tensor: ArrayView2<BOX>,
830    scores_tensor: ArrayView2<SCORE>,
831    mask_tensor: ArrayView2<MASK>,
832    protos: ArrayView3<PROTO>,
833    score_threshold: f32,
834    iou_threshold: f32,
835    nms: Option<Nms>,
836    output_boxes: &mut Vec<DetectBox>,
837    output_masks: &mut Vec<Segmentation>,
838) where
839    f32: AsPrimitive<SCORE>,
840{
841    let boxes_tensor = boxes_tensor.reversed_axes();
842    let scores_tensor = scores_tensor.reversed_axes();
843    let mask_tensor = mask_tensor.reversed_axes();
844
845    let boxes = postprocess_boxes_index_float::<B, _, _>(
846        score_threshold.as_(),
847        boxes_tensor,
848        scores_tensor,
849    );
850    let mut boxes = dispatch_nms_extra_float(nms, iou_threshold, boxes);
851    boxes.truncate(output_boxes.capacity());
852    let boxes = decode_segdet_f32(boxes, mask_tensor, protos);
853    output_boxes.clear();
854    output_masks.clear();
855    for (b, m) in boxes.into_iter() {
856        output_boxes.push(b);
857        output_masks.push(Segmentation {
858            xmin: b.bbox.xmin,
859            ymin: b.bbox.ymin,
860            xmax: b.bbox.xmax,
861            ymax: b.bbox.ymax,
862            segmentation: m,
863        });
864    }
865}
866
867fn postprocess_yolo<'a, T>(
868    output: &'a ArrayView2<'_, T>,
869) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
870    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
871    let scores_tensor = output.slice(s![4.., ..,]).reversed_axes();
872    (boxes_tensor, scores_tensor)
873}
874
875fn postprocess_yolo_seg<'a, T>(
876    output: &'a ArrayView2<'_, T>,
877) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
878    assert!(output.shape()[0] > 32 + 4, "Output shape is too short");
879    let num_classes = output.shape()[0] - 4 - 32;
880    let boxes_tensor = output.slice(s![..4, ..,]).reversed_axes();
881    let scores_tensor = output.slice(s![4..(num_classes + 4), ..,]).reversed_axes();
882    let mask_tensor = output.slice(s![(num_classes + 4).., ..,]).reversed_axes();
883    (boxes_tensor, scores_tensor, mask_tensor)
884}
885
886fn decode_segdet_f32<
887    MASK: Float + AsPrimitive<f32> + Send + Sync,
888    PROTO: Float + AsPrimitive<f32> + Send + Sync,
889>(
890    boxes: Vec<(DetectBox, usize)>,
891    masks: ArrayView2<MASK>,
892    protos: ArrayView3<PROTO>,
893) -> Vec<(DetectBox, Array3<u8>)> {
894    if boxes.is_empty() {
895        return Vec::new();
896    }
897    assert!(masks.shape()[1] == protos.shape()[2]);
898    boxes
899        .into_par_iter()
900        .map(|mut b| {
901            let ind = b.1;
902            let (protos, roi) = protobox(&protos, &b.0.bbox);
903            b.0.bbox = roi;
904            (b.0, make_segmentation(masks.row(ind), protos.view()))
905        })
906        .collect()
907}
908
909pub(crate) fn decode_segdet_quant<
910    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
911    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
912>(
913    boxes: Vec<(DetectBox, usize)>,
914    masks: ArrayView2<MASK>,
915    protos: ArrayView3<PROTO>,
916    quant_masks: Quantization,
917    quant_protos: Quantization,
918) -> Vec<(DetectBox, Array3<u8>)> {
919    if boxes.is_empty() {
920        return Vec::new();
921    }
922    assert!(masks.shape()[1] == protos.shape()[2]);
923
924    let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5; // 32 protos is 2^5
925    boxes
926        .into_iter()
927        .map(|mut b| {
928            let i = b.1;
929            let (protos, roi) = protobox(&protos, &b.0.bbox.to_canonical());
930            b.0.bbox = roi;
931            let seg = match total_bits {
932                0..=64 => make_segmentation_quant::<MASK, PROTO, i64>(
933                    masks.row(i),
934                    protos.view(),
935                    quant_masks,
936                    quant_protos,
937                ),
938                65..=128 => make_segmentation_quant::<MASK, PROTO, i128>(
939                    masks.row(i),
940                    protos.view(),
941                    quant_masks,
942                    quant_protos,
943                ),
944                _ => panic!("Unsupported bit width for segmentation computation"),
945            };
946            (b.0, seg)
947        })
948        .collect()
949}
950
951fn protobox<'a, T>(
952    protos: &'a ArrayView3<T>,
953    roi: &BoundingBox,
954) -> (ArrayView3<'a, T>, BoundingBox) {
955    let width = protos.dim().1 as f32;
956    let height = protos.dim().0 as f32;
957
958    let roi = [
959        (roi.xmin * width).clamp(0.0, width) as usize,
960        (roi.ymin * height).clamp(0.0, height) as usize,
961        (roi.xmax * width).clamp(0.0, width).ceil() as usize,
962        (roi.ymax * height).clamp(0.0, height).ceil() as usize,
963    ];
964
965    let roi_norm = [
966        roi[0] as f32 / width,
967        roi[1] as f32 / height,
968        roi[2] as f32 / width,
969        roi[3] as f32 / height,
970    ]
971    .into();
972
973    let cropped = protos.slice(s![roi[1]..roi[3], roi[0]..roi[2], ..]);
974
975    (cropped, roi_norm)
976}
977
978fn make_segmentation<
979    MASK: Float + AsPrimitive<f32> + Send + Sync,
980    PROTO: Float + AsPrimitive<f32> + Send + Sync,
981>(
982    mask: ArrayView1<MASK>,
983    protos: ArrayView3<PROTO>,
984) -> Array3<u8> {
985    let shape = protos.shape();
986
987    // Safe to unwrap since the shapes will always be compatible
988    let mask = mask.to_shape((1, mask.len())).unwrap();
989    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
990    let protos = protos.reversed_axes();
991    let mask = mask.map(|x| x.as_());
992    let protos = protos.map(|x| x.as_());
993
994    // Safe to unwrap since the shapes will always be compatible
995    let mask = mask
996        .dot(&protos)
997        .into_shape_with_order((shape[0], shape[1], 1))
998        .unwrap();
999
1000    let min = *mask.min().unwrap_or(&0.0);
1001    let max = *mask.max().unwrap_or(&1.0);
1002    let max = max.max(-min);
1003    let min = -max;
1004    let u8_max = 256.0;
1005    mask.map(|x| ((*x - min) / (max - min) * u8_max) as u8)
1006}
1007
1008fn make_segmentation_quant<
1009    MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
1010    PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
1011    DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
1012>(
1013    mask: ArrayView1<MASK>,
1014    protos: ArrayView3<PROTO>,
1015    quant_masks: Quantization,
1016    quant_protos: Quantization,
1017) -> Array3<u8>
1018where
1019    i32: AsPrimitive<DEST>,
1020    f32: AsPrimitive<DEST>,
1021{
1022    let shape = protos.shape();
1023
1024    // Safe to unwrap since the shapes will always be compatible
1025    let mask = mask.to_shape((1, mask.len())).unwrap();
1026
1027    let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
1028    let protos = protos.reversed_axes();
1029
1030    let zp = quant_masks.zero_point.as_();
1031
1032    let mask = mask.mapv(|x| x.as_() - zp);
1033
1034    let zp = quant_protos.zero_point.as_();
1035    let protos = protos.mapv(|x| x.as_() - zp);
1036
1037    // Safe to unwrap since the shapes will always be compatible
1038    let segmentation = mask
1039        .dot(&protos)
1040        .into_shape_with_order((shape[0], shape[1], 1))
1041        .unwrap();
1042
1043    let min = *segmentation.min().unwrap_or(&DEST::zero());
1044    let max = *segmentation.max().unwrap_or(&DEST::one());
1045    let max = max.max(-min);
1046    let min = -max;
1047    segmentation.map(|x| ((*x - min).as_() / (max - min).as_() * 256.0) as u8)
1048}
1049
1050/// Converts Yolo Instance Segmentation into a 2D mask.
1051///
1052/// The input segmentation is expected to have shape (H, W, 1).
1053///
1054/// The output mask will have shape (H, W), with values 0 or 1 based on the
1055/// threshold.
1056///
1057/// # Errors
1058///
1059/// Returns `DecoderError::InvalidShape` if the input segmentation does not
1060/// have shape (H, W, 1).
1061pub fn yolo_segmentation_to_mask(
1062    segmentation: ArrayView3<u8>,
1063    threshold: u8,
1064) -> Result<Array2<u8>, crate::DecoderError> {
1065    if segmentation.shape()[2] != 1 {
1066        return Err(crate::DecoderError::InvalidShape(format!(
1067            "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
1068            segmentation.shape()[2]
1069        )));
1070    }
1071    Ok(segmentation
1072        .slice(s![.., .., 0])
1073        .map(|x| if *x >= threshold { 1 } else { 0 }))
1074}
1075
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Asserts that two floating-point expressions differ by less than 0.01.
    macro_rules! assert_close {
        ($actual:expr, $expected:expr) => {
            assert!(($actual - $expected).abs() < 0.01)
        };
    }

    // ========================================================================
    // Tests for decode_yolo_end_to_end_det_float
    // ========================================================================

    #[test]
    fn test_end_to_end_det_basic_filtering() {
        // Synthetic end-to-end detection output of shape (6, N): one row per
        // feature, one column per detection. Only detection 0 scores above the
        // threshold.
        let rows: [[f32; 3]; 6] = [
            [0.1, 0.2, 0.3], // x1
            [0.1, 0.2, 0.3], // y1
            [0.5, 0.6, 0.7], // x2
            [0.5, 0.6, 0.7], // y2
            [0.9, 0.1, 0.2], // confidence
            [0.0, 1.0, 2.0], // class index
        ];
        let output = Array2::from_shape_vec((6, 3), rows.concat()).unwrap();

        let mut detections = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut detections).unwrap();

        // Only 1 detection should pass threshold of 0.5
        assert_eq!(detections.len(), 1);
        assert_eq!(detections[0].label, 0);
        assert_close!(detections[0].score, 0.9);
        assert_close!(detections[0].bbox.xmin, 0.1);
        assert_close!(detections[0].bbox.ymin, 0.1);
        assert_close!(detections[0].bbox.xmax, 0.5);
        assert_close!(detections[0].bbox.ymax, 0.5);
    }

    #[test]
    fn test_end_to_end_det_all_pass_threshold() {
        // Both detections score above the threshold.
        let rows: [[f32; 2]; 6] = [
            [10.0, 20.0], // x1
            [10.0, 20.0], // y1
            [50.0, 60.0], // x2
            [50.0, 60.0], // y2
            [0.8, 0.7],   // conf (both above 0.5)
            [1.0, 2.0],   // class
        ];
        let output = Array2::from_shape_vec((6, 2), rows.concat()).unwrap();

        let mut detections = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut detections).unwrap();

        assert_eq!(detections.len(), 2);
        assert_eq!(detections[0].label, 1);
        assert_eq!(detections[1].label, 2);
    }

    #[test]
    fn test_end_to_end_det_none_pass_threshold() {
        // Both detections score below the threshold.
        let rows: [[f32; 2]; 6] = [
            [10.0, 20.0], // x1
            [10.0, 20.0], // y1
            [50.0, 60.0], // x2
            [50.0, 60.0], // y2
            [0.1, 0.2],   // conf (both below 0.5)
            [1.0, 2.0],   // class
        ];
        let output = Array2::from_shape_vec((6, 2), rows.concat()).unwrap();

        let mut detections = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut detections).unwrap();

        assert_eq!(detections.len(), 0);
    }

    #[test]
    fn test_end_to_end_det_capacity_limit() {
        // Five passing detections are offered to a vector with capacity 2.
        let rows: [[f32; 5]; 6] = [
            [0.1, 0.2, 0.3, 0.4, 0.5], // x1
            [0.1, 0.2, 0.3, 0.4, 0.5], // y1
            [0.5, 0.6, 0.7, 0.8, 0.9], // x2
            [0.5, 0.6, 0.7, 0.8, 0.9], // y2
            [0.9, 0.9, 0.9, 0.9, 0.9], // conf (all pass)
            [0.0, 1.0, 2.0, 3.0, 4.0], // class
        ];
        let output = Array2::from_shape_vec((6, 5), rows.concat()).unwrap();

        let mut detections = Vec::with_capacity(2); // Only allow 2 boxes
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut detections).unwrap();

        assert_eq!(detections.len(), 2);
    }

    #[test]
    fn test_end_to_end_det_empty_output() {
        // Zero detection columns.
        let output = Array2::<f32>::zeros((6, 0));

        let mut detections = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut detections).unwrap();

        assert_eq!(detections.len(), 0);
    }

    #[test]
    fn test_end_to_end_det_pixel_coordinates() {
        // A single detection given in pixel (non-normalized) coordinates.
        let column: [f32; 6] = [
            100.0, // x1
            200.0, // y1
            300.0, // x2
            400.0, // y2
            0.95,  // conf
            5.0,   // class
        ];
        let output = Array2::from_shape_vec((6, 1), column.to_vec()).unwrap();

        let mut detections = Vec::with_capacity(10);
        decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut detections).unwrap();

        assert_eq!(detections.len(), 1);
        assert_eq!(detections[0].label, 5);
        assert_close!(detections[0].bbox.xmin, 100.0);
        assert_close!(detections[0].bbox.ymin, 200.0);
        assert_close!(detections[0].bbox.xmax, 300.0);
        assert_close!(detections[0].bbox.ymax, 400.0);
    }

    #[test]
    fn test_end_to_end_det_invalid_shape() {
        // Five rows is one short of the six the decoder requires.
        let output = Array2::<f32>::zeros((5, 3));

        let mut detections = Vec::with_capacity(10);
        let outcome = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut detections);

        assert!(outcome.is_err());
        assert!(matches!(
            outcome,
            Err(crate::DecoderError::InvalidShape(msg)) if msg.contains("at least 6 rows")
        ));
    }

    // ========================================================================
    // Tests for decode_yolo_end_to_end_segdet_float
    // ========================================================================

    #[test]
    fn test_end_to_end_segdet_basic() {
        // Synthetic segdet output of shape (6 + num_protos, N).
        // Feature rows: [x1, y1, x2, y2, conf, class, mask_coeff_0..31]
        let num_protos = 32;
        let num_detections = 2;
        let num_features = 6 + num_protos;

        // Column 0 passes the threshold, column 1 does not.
        let head: [[f32; 2]; 6] = [
            [0.1, 0.5], // x1
            [0.1, 0.5], // y1
            [0.4, 0.9], // x2
            [0.4, 0.9], // y2
            [0.9, 0.3], // conf
            [1.0, 2.0], // class
        ];
        let mut output = Array2::<f32>::zeros((num_features, num_detections));
        for (row, values) in head.iter().enumerate() {
            output[[row, 0]] = values[0];
            output[[row, 1]] = values[1];
        }
        // Fill mask coefficients with small values
        output.slice_mut(s![6.., ..]).fill(0.1);

        // Protos tensor: (proto_height, proto_width, num_protos)
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut detections = Vec::with_capacity(10);
        let mut segmentations = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut detections,
            &mut segmentations,
        )
        .unwrap();

        // Only detection 0 should pass
        assert_eq!(detections.len(), 1);
        assert_eq!(segmentations.len(), 1);
        assert_eq!(detections[0].label, 1);
        assert_close!(detections[0].score, 0.9);
    }

    #[test]
    fn test_end_to_end_segdet_mask_coordinates() {
        // The mask region reported for a detection must match its box.
        let num_protos = 32;
        let num_features = 6 + num_protos;

        let head: [f32; 6] = [
            0.2,  // x1
            0.2,  // y1
            0.8,  // x2
            0.8,  // y2
            0.95, // conf
            3.0,  // class
        ];
        let mut output = Array2::<f32>::zeros((num_features, 1));
        for (row, value) in head.iter().enumerate() {
            output[[row, 0]] = *value;
        }
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut detections = Vec::with_capacity(10);
        let mut segmentations = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut detections,
            &mut segmentations,
        )
        .unwrap();

        assert_eq!(detections.len(), 1);
        assert_eq!(segmentations.len(), 1);

        // Verify mask coordinates match box coordinates
        assert_close!(segmentations[0].xmin, detections[0].bbox.xmin);
        assert_close!(segmentations[0].ymin, detections[0].bbox.ymin);
        assert_close!(segmentations[0].xmax, detections[0].bbox.xmax);
        assert_close!(segmentations[0].ymax, detections[0].bbox.ymax);
    }

    #[test]
    fn test_end_to_end_segdet_empty_output() {
        let num_protos = 32;
        let output = Array2::<f32>::zeros((6 + num_protos, 0));
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut detections = Vec::with_capacity(10);
        let mut segmentations = Vec::with_capacity(10);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut detections,
            &mut segmentations,
        )
        .unwrap();

        assert_eq!(detections.len(), 0);
        assert_eq!(segmentations.len(), 0);
    }

    #[test]
    fn test_end_to_end_segdet_capacity_limit() {
        let num_protos = 32;
        let num_detections = 5;
        let num_features = 6 + num_protos;

        // All detections pass the threshold; mask coefficients stay at zero.
        let output =
            Array2::<f32>::from_shape_fn((num_features, num_detections), |(row, col)| {
                let offset = 0.1 * (col as f32);
                match row {
                    0 | 1 => offset,       // x1, y1
                    2 | 3 => offset + 0.2, // x2, y2
                    4 => 0.9,              // conf
                    5 => col as f32,       // class
                    _ => 0.0,
                }
            });
        let protos = Array3::<f32>::zeros((16, 16, num_protos));

        let mut detections = Vec::with_capacity(2); // Limit to 2
        let mut segmentations = Vec::with_capacity(2);
        decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut detections,
            &mut segmentations,
        )
        .unwrap();

        assert_eq!(detections.len(), 2);
        assert_eq!(segmentations.len(), 2);
    }

    #[test]
    fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
        // Six rows leaves no mask coefficient row (7 are needed: 6 base + 1).
        let output = Array2::<f32>::zeros((6, 3));
        let protos = Array3::<f32>::zeros((16, 16, 32));

        let mut detections = Vec::with_capacity(10);
        let mut segmentations = Vec::with_capacity(10);
        let outcome = decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut detections,
            &mut segmentations,
        );

        assert!(outcome.is_err());
        assert!(matches!(
            outcome,
            Err(crate::DecoderError::InvalidShape(msg)) if msg.contains("at least 7 rows")
        ));
    }

    #[test]
    fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
        // 16 mask coefficients per detection against 32 prototypes.
        let num_protos = 32;
        let output = Array2::<f32>::zeros((6 + 16, 3)); // 16 mask coeffs
        let protos = Array3::<f32>::zeros((16, 16, num_protos)); // 32 protos

        let mut detections = Vec::with_capacity(10);
        let mut segmentations = Vec::with_capacity(10);
        let outcome = decode_yolo_end_to_end_segdet_float(
            output.view(),
            protos.view(),
            0.5,
            &mut detections,
            &mut segmentations,
        );

        assert!(outcome.is_err());
        assert!(matches!(
            outcome,
            Err(crate::DecoderError::InvalidShape(msg))
                if msg.contains("doesn't match protos count")
        ));
    }

    // ========================================================================
    // Tests for yolo_segmentation_to_mask
    // ========================================================================

    #[test]
    fn test_segmentation_to_mask_basic() {
        // A 4x4x1 segmentation with values on both sides of the threshold.
        let rows: [[u8; 4]; 4] = [
            [100, 200, 50, 150],
            [10, 255, 128, 64],
            [0, 127, 128, 255],
            [64, 64, 192, 192],
        ];
        let segmentation = Array3::from_shape_vec((4, 4, 1), rows.concat()).unwrap();

        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();

        // Values >= 128 should be 1, others 0
        assert_eq!(mask[[0, 0]], 0); // 100 < 128
        assert_eq!(mask[[0, 1]], 1); // 200 >= 128
        assert_eq!(mask[[0, 2]], 0); // 50 < 128
        assert_eq!(mask[[0, 3]], 1); // 150 >= 128
        assert_eq!(mask[[1, 1]], 1); // 255 >= 128
        assert_eq!(mask[[1, 2]], 1); // 128 >= 128
        assert_eq!(mask[[2, 0]], 0); // 0 < 128
        assert_eq!(mask[[2, 1]], 0); // 127 < 128
    }

    #[test]
    fn test_segmentation_to_mask_all_above() {
        let segmentation = Array3::from_elem((4, 4, 1), 255u8);
        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
        assert!(mask.iter().all(|&pixel| pixel == 1));
    }

    #[test]
    fn test_segmentation_to_mask_all_below() {
        let segmentation = Array3::from_elem((4, 4, 1), 64u8);
        let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
        assert!(mask.iter().all(|&pixel| pixel == 0));
    }

    #[test]
    fn test_segmentation_to_mask_invalid_shape() {
        // Three channels instead of the required single channel.
        let segmentation = Array3::from_elem((4, 4, 3), 128u8);
        let outcome = yolo_segmentation_to_mask(segmentation.view(), 128);

        assert!(outcome.is_err());
        assert!(matches!(
            outcome,
            Err(crate::DecoderError::InvalidShape(msg)) if msg.contains("(H, W, 1)")
        ));
    }
}