use std::fmt::Debug;
use ndarray::{
parallel::prelude::{IntoParallelIterator, ParallelIterator},
s, Array2, Array3, ArrayView1, ArrayView2, ArrayView3,
};
use num_traits::{AsPrimitive, Float, PrimInt, Signed};
use rayon::slice::ParallelSliceMut;
use crate::{
byte::{
nms_class_aware_int, nms_extra_class_aware_int, nms_extra_int, nms_int,
postprocess_boxes_index_quant, postprocess_boxes_quant, quantize_score_threshold,
},
configs::Nms,
dequant_detect_box,
float::{
nms_class_aware_float, nms_extra_class_aware_float, nms_extra_float, nms_float,
postprocess_boxes_float, postprocess_boxes_index_float,
},
BBoxTypeTrait, BoundingBox, DetectBox, DetectBoxQuantized, ProtoData, ProtoTensor,
Quantization, Segmentation, XYWH, XYXY,
};
/// Upper bound on the number of scored candidates that the index-carrying
/// (segmentation) box paths in this file hand to NMS; when more candidates
/// survive the score threshold, only the highest-scoring ones are kept.
pub(crate) const MAX_NMS_CANDIDATES: usize = 30_000;
/// Limits `boxes` to the `MAX_NMS_CANDIDATES` highest-scoring entries.
///
/// A vector at or below the limit is left untouched (and therefore unsorted).
/// A larger one is sorted in parallel by descending score (`f32::total_cmp`)
/// and then cut down to the limit.
fn truncate_to_top_k_by_score<E: Send>(boxes: &mut Vec<(DetectBox, E)>) {
    if boxes.len() <= MAX_NMS_CANDIDATES {
        return;
    }
    boxes.par_sort_unstable_by(|lhs, rhs| rhs.0.score.total_cmp(&lhs.0.score));
    boxes.truncate(MAX_NMS_CANDIDATES);
}
/// Quantized counterpart of the float top-k truncation: limits `boxes` to the
/// `MAX_NMS_CANDIDATES` entries with the largest (still quantized) score.
///
/// Nothing happens when the vector is already within the limit; otherwise it
/// is sorted in parallel by descending integer score and truncated.
fn truncate_to_top_k_by_score_quant<S: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send>(
    boxes: &mut Vec<(DetectBoxQuantized<S>, E)>,
) {
    if boxes.len() <= MAX_NMS_CANDIDATES {
        return;
    }
    boxes.par_sort_unstable_by(|lhs, rhs| rhs.0.score.cmp(&lhs.0.score));
    boxes.truncate(MAX_NMS_CANDIDATES);
}
/// Runs the NMS variant selected by `nms` over float boxes.
///
/// `None` disables suppression and returns `boxes` unchanged.
fn dispatch_nms_float(nms: Option<Nms>, iou: f32, boxes: Vec<DetectBox>) -> Vec<DetectBox> {
    match nms {
        None => boxes,
        Some(Nms::ClassAware) => nms_class_aware_float(iou, boxes),
        Some(Nms::ClassAgnostic) => nms_float(iou, boxes),
    }
}
/// Runs the NMS variant selected by `nms` over float boxes that each carry an
/// extra payload `E` alongside the box.
///
/// `None` disables suppression and returns `boxes` unchanged.
pub(super) fn dispatch_nms_extra_float<E: Send + Sync>(
    nms: Option<Nms>,
    iou: f32,
    boxes: Vec<(DetectBox, E)>,
) -> Vec<(DetectBox, E)> {
    match nms {
        None => boxes,
        Some(Nms::ClassAware) => nms_extra_class_aware_float(iou, boxes),
        Some(Nms::ClassAgnostic) => nms_extra_float(iou, boxes),
    }
}
/// Runs the NMS variant selected by `nms` over boxes with quantized scores.
///
/// `None` disables suppression and returns `boxes` unchanged.
fn dispatch_nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
    nms: Option<Nms>,
    iou: f32,
    boxes: Vec<DetectBoxQuantized<SCORE>>,
) -> Vec<DetectBoxQuantized<SCORE>> {
    match nms {
        None => boxes,
        Some(Nms::ClassAware) => nms_class_aware_int(iou, boxes),
        Some(Nms::ClassAgnostic) => nms_int(iou, boxes),
    }
}
/// Runs the NMS variant selected by `nms` over quantized-score boxes that each
/// carry an extra payload `E` alongside the box.
///
/// `None` disables suppression and returns `boxes` unchanged.
fn dispatch_nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
    nms: Option<Nms>,
    iou: f32,
    boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
    match nms {
        None => boxes,
        Some(Nms::ClassAware) => nms_extra_class_aware_int(iou, boxes),
        Some(Nms::ClassAgnostic) => nms_extra_int(iou, boxes),
    }
}
/// Decodes a quantized, combined YOLO detection tensor using the `XYWH` box
/// type.
///
/// Thin public entry point over `impl_yolo_quant`: `output_boxes` is cleared
/// and refilled with at most `output_boxes.capacity()` dequantized detections.
/// `nms` selects the suppression mode (`None` skips NMS).
pub fn decode_yolo_det<BOX: PrimInt + AsPrimitive<f32> + Send + Sync>(
    output: (ArrayView2<BOX>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
) where
    f32: AsPrimitive<BOX>,
{
    impl_yolo_quant::<XYWH, _>(
        output,
        score_threshold,
        iou_threshold,
        nms,
        output_boxes,
    );
}
/// Decodes a floating-point, combined YOLO detection tensor using the `XYWH`
/// box type.
///
/// Thin public entry point over `impl_yolo_float`: `output_boxes` is cleared
/// and refilled with at most `output_boxes.capacity()` detections. `nms`
/// selects the suppression mode (`None` skips NMS).
pub fn decode_yolo_det_float<T>(
    output: ArrayView2<T>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
) where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    impl_yolo_float::<XYWH, _>(
        output,
        score_threshold,
        iou_threshold,
        nms,
        output_boxes,
    );
}
/// Decodes a quantized YOLO segmentation-detection output (combined box tensor
/// plus prototype tensor) using the `XYWH` box type.
///
/// Thin public entry point over `impl_yolo_segdet_quant`; `output_boxes` and
/// `output_masks` are overwritten with the results.
///
/// # Errors
/// Returns whatever `DecoderError` the underlying implementation reports.
pub fn decode_yolo_segdet_quant<
    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    protos: (ArrayView3<PROTO>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    f32: AsPrimitive<BOX>,
{
    impl_yolo_segdet_quant::<XYWH, _, _>(
        boxes,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        output_boxes,
        output_masks,
    )
}
/// Decodes a floating-point YOLO segmentation-detection output (combined box
/// tensor plus prototype tensor) using the `XYWH` box type.
///
/// Thin public entry point over `impl_yolo_segdet_float`; `output_boxes` and
/// `output_masks` are overwritten with the results.
///
/// # Errors
/// Returns whatever `DecoderError` the underlying implementation reports.
pub fn decode_yolo_segdet_float<T>(
    boxes: ArrayView2<T>,
    protos: ArrayView3<T>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    impl_yolo_segdet_float::<XYWH, _, _>(
        boxes,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        output_boxes,
        output_masks,
    )
}
/// Decodes quantized YOLO detection output delivered as separate box and score
/// tensors, using the `XYWH` box type.
///
/// Thin public entry point over `impl_yolo_split_quant`: `output_boxes` is
/// cleared and refilled with at most `output_boxes.capacity()` detections.
pub fn decode_yolo_split_det_quant<
    BOX: PrimInt + AsPrimitive<i32> + AsPrimitive<f32> + Send + Sync,
    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    scores: (ArrayView2<SCORE>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
) where
    f32: AsPrimitive<SCORE>,
{
    impl_yolo_split_quant::<XYWH, _, _>(
        boxes,
        scores,
        score_threshold,
        iou_threshold,
        nms,
        output_boxes,
    );
}
/// Decodes floating-point YOLO detection output delivered as separate box and
/// score tensors, using the `XYWH` box type.
///
/// Thin public entry point over `impl_yolo_split_float`: `output_boxes` is
/// cleared and refilled with at most `output_boxes.capacity()` detections.
pub fn decode_yolo_split_det_float<T>(
    boxes: ArrayView2<T>,
    scores: ArrayView2<T>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
) where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    impl_yolo_split_float::<XYWH, _, _>(
        boxes,
        scores,
        score_threshold,
        iou_threshold,
        nms,
        output_boxes,
    );
}
/// Decodes quantized YOLO segmentation-detection output delivered as separate
/// box, score, mask-coefficient and prototype tensors, using the `XYWH` box
/// type.
///
/// Thin public entry point over `impl_yolo_split_segdet_quant`; `output_boxes`
/// and `output_masks` are overwritten with the results.
///
/// # Errors
/// Returns whatever `DecoderError` the underlying implementation reports.
#[allow(clippy::too_many_arguments)]
pub fn decode_yolo_split_segdet<
    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    scores: (ArrayView2<SCORE>, Quantization),
    mask_coeff: (ArrayView2<MASK>, Quantization),
    protos: (ArrayView3<PROTO>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    f32: AsPrimitive<SCORE>,
{
    impl_yolo_split_segdet_quant::<XYWH, _, _, _, _>(
        boxes,
        scores,
        mask_coeff,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        output_boxes,
        output_masks,
    )
}
/// Decodes floating-point YOLO segmentation-detection output delivered as
/// separate box, score, mask-coefficient and prototype tensors, using the
/// `XYWH` box type.
///
/// Thin public entry point over `impl_yolo_split_segdet_float`; `output_boxes`
/// and `output_masks` are overwritten with the results.
///
/// # Errors
/// Returns whatever `DecoderError` the underlying implementation reports.
#[allow(clippy::too_many_arguments)]
pub fn decode_yolo_split_segdet_float<T>(
    boxes: ArrayView2<T>,
    scores: ArrayView2<T>,
    mask_coeff: ArrayView2<T>,
    protos: ArrayView3<T>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    impl_yolo_split_segdet_float::<XYWH, _, _, _, _>(
        boxes,
        scores,
        mask_coeff,
        protos,
        score_threshold,
        iou_threshold,
        nms,
        output_boxes,
        output_masks,
    )
}
/// Decodes an end-to-end YOLO detection tensor; no NMS is run here.
///
/// `output` is read row-wise: rows `0..4` are box coordinates (interpreted as
/// `XYXY`), row `4` is the score and row `5` is the class id, with one column
/// per candidate. Candidates passing `score_threshold` are written to
/// `output_boxes`, which is cleared first and receives at most
/// `output_boxes.capacity()` entries.
///
/// The class id is a float in the tensor; it is rounded to the nearest integer
/// (not truncated) before being stored as the label, matching
/// `impl_yolo_end_to_end_segdet_get_boxes`, so a value such as `2.9998` maps to
/// label `3`.
///
/// # Errors
/// Returns `DecoderError::InvalidShape` when `output` has fewer than 6 rows.
pub fn decode_yolo_end_to_end_det_float<T>(
    output: ArrayView2<T>,
    score_threshold: f32,
    output_boxes: &mut Vec<DetectBox>,
) -> Result<(), crate::DecoderError>
where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    if output.shape()[0] < 6 {
        return Err(crate::DecoderError::InvalidShape(format!(
            "End-to-end detection output requires at least 6 rows, got {}",
            output.shape()[0]
        )));
    }
    let boxes = output.slice(s![0..4, ..]).reversed_axes();
    let scores = output.slice(s![4..5, ..]).reversed_axes();
    let classes = output.slice(s![5, ..]);
    let mut boxes =
        postprocess_boxes_index_float::<XYXY, _, _>(score_threshold.as_(), boxes, scores);
    // The caller communicates the maximum number of results via the capacity.
    boxes.truncate(output_boxes.capacity());
    output_boxes.clear();
    for (mut b, i) in boxes {
        let class_id: f32 = classes[i].as_();
        b.label = class_id.round() as usize;
        output_boxes.push(b);
    }
    Ok(())
}
/// Decodes an end-to-end YOLO segmentation-detection tensor into boxes and
/// masks; no NMS is run here.
///
/// The combined `output` tensor is split by
/// `postprocess_yolo_end_to_end_segdet`, thresholded boxes (as `XYXY`, at most
/// `output_boxes.capacity()` of them) are selected, and their masks are built
/// from `protos`. Both output vectors are overwritten.
///
/// # Errors
/// Returns `DecoderError` when the tensor shapes are rejected by the split
/// step or by mask decoding.
pub fn decode_yolo_end_to_end_segdet_float<T>(
    output: ArrayView2<T>,
    protos: ArrayView3<T>,
    score_threshold: f32,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<crate::Segmentation>,
) -> Result<(), crate::DecoderError>
where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    let num_protos = protos.dim().2;
    let (box_view, score_view, class_view, coeff_view) =
        postprocess_yolo_end_to_end_segdet(&output, num_protos)?;
    let max_boxes = output_boxes.capacity();
    let detections = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
        box_view,
        score_view,
        class_view,
        score_threshold,
        max_boxes,
    );
    impl_yolo_split_segdet_process_masks(
        detections,
        coeff_view,
        protos,
        output_boxes,
        output_masks,
    )
}
/// Decodes end-to-end YOLO detection output delivered as separate box, score
/// and class tensors; no NMS is run here.
///
/// Boxes are read as `(xmin, ymin, xmax, ymax)`. Candidates whose score is at
/// least `score_threshold` are appended in tensor order until
/// `output_boxes.capacity()` entries have been written.
///
/// Shapes are validated *before* `output_boxes` is cleared, so a shape error
/// leaves the caller's vector untouched. The float class id is rounded to the
/// nearest integer (not truncated) before being stored as the label, matching
/// `impl_yolo_end_to_end_segdet_get_boxes`.
///
/// # Errors
/// Returns `DecoderError::InvalidShape` when the tensors fail the checks in
/// `postprocess_yolo_split_end_to_end_det`.
pub fn decode_yolo_split_end_to_end_det_float<T: Float + AsPrimitive<f32>>(
    boxes: ArrayView2<T>,
    scores: ArrayView2<T>,
    classes: ArrayView2<T>,
    score_threshold: f32,
    output_boxes: &mut Vec<DetectBox>,
) -> Result<(), crate::DecoderError> {
    let (boxes, scores, classes) = postprocess_yolo_split_end_to_end_det(boxes, scores, &classes)?;
    // After validation the box view is transposed: one row per candidate.
    let n = boxes.shape()[0];
    output_boxes.clear();
    // Pushes never exceed the capacity, so it stays constant during the loop.
    let max_boxes = output_boxes.capacity();
    for i in 0..n {
        if output_boxes.len() >= max_boxes {
            break;
        }
        let score: f32 = scores[[i, 0]].as_();
        if score < score_threshold {
            continue;
        }
        let class_id: f32 = classes[i].as_();
        output_boxes.push(DetectBox {
            bbox: BoundingBox {
                xmin: boxes[[i, 0]].as_(),
                ymin: boxes[[i, 1]].as_(),
                xmax: boxes[[i, 2]].as_(),
                ymax: boxes[[i, 3]].as_(),
            },
            score,
            label: class_id.round() as usize,
        });
    }
    Ok(())
}
/// Decodes end-to-end YOLO segmentation-detection output delivered as separate
/// box, score, class, mask-coefficient and prototype tensors; no NMS is run
/// here.
///
/// Inputs are validated and transposed by
/// `postprocess_yolo_split_end_to_end_segdet`, thresholded boxes (as `XYXY`,
/// at most `output_boxes.capacity()` of them) are selected, and their masks
/// are built from `protos`. Both output vectors are overwritten.
///
/// # Errors
/// Returns `DecoderError` when shape validation or mask decoding fails.
#[allow(clippy::too_many_arguments)]
pub fn decode_yolo_split_end_to_end_segdet_float<T>(
    boxes: ArrayView2<T>,
    scores: ArrayView2<T>,
    classes: ArrayView2<T>,
    mask_coeff: ArrayView2<T>,
    protos: ArrayView3<T>,
    score_threshold: f32,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<crate::Segmentation>,
) -> Result<(), crate::DecoderError>
where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    let (box_view, score_view, class_view, coeff_view) =
        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;
    let max_boxes = output_boxes.capacity();
    let detections = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
        box_view,
        score_view,
        class_view,
        score_threshold,
        max_boxes,
    );
    impl_yolo_split_segdet_process_masks(
        detections,
        coeff_view,
        protos,
        output_boxes,
        output_masks,
    )
}
/// Splits a combined end-to-end segmentation-detection tensor into its parts.
///
/// Row layout of `output`: `0..4` box coordinates, `4` score, `5` class id,
/// `6..` mask coefficients. The returned views are, in order: boxes, scores and
/// mask coefficients (each transposed so candidates are rows) and the class
/// row as a 1-D view.
///
/// # Errors
/// Returns `DecoderError::InvalidShape` when `output` has fewer than 7 rows or
/// when the number of mask-coefficient rows differs from `num_protos`.
#[allow(clippy::type_complexity)]
pub(crate) fn postprocess_yolo_end_to_end_segdet<'a, T>(
    output: &'a ArrayView2<'_, T>,
    num_protos: usize,
) -> Result<
    (
        ArrayView2<'a, T>,
        ArrayView2<'a, T>,
        ArrayView1<'a, T>,
        ArrayView2<'a, T>,
    ),
    crate::DecoderError,
> {
    let num_rows = output.shape()[0];
    if num_rows < 7 {
        return Err(crate::DecoderError::InvalidShape(format!(
            "End-to-end segdet output requires at least 7 rows, got {}",
            num_rows
        )));
    }

    // Everything after the 4 box rows, the score row and the class row.
    let num_mask_coeffs = num_rows - 6;
    if num_mask_coeffs != num_protos {
        return Err(crate::DecoderError::InvalidShape(format!(
            "Mask coefficients count ({}) doesn't match protos count ({})",
            num_mask_coeffs, num_protos
        )));
    }

    Ok((
        output.slice(s![0..4, ..]).reversed_axes(),
        output.slice(s![4..5, ..]).reversed_axes(),
        output.slice(s![5, ..]),
        output.slice(s![6.., ..]).reversed_axes(),
    ))
}
/// Validates and reorients split end-to-end detection tensors.
///
/// Expected shapes: `boxes` is `4 x N`, `scores` is `1 x N`, `classes` is
/// `1 x N`. On success, `boxes` and `scores` are returned transposed (one row
/// per candidate) and `classes` is returned as its first row.
///
/// # Errors
/// Returns `DecoderError::InvalidShape` for the first expectation that fails,
/// checked in this order: box rows, score rows, class rows, score columns,
/// class columns.
#[allow(clippy::type_complexity)]
pub(crate) fn postprocess_yolo_split_end_to_end_det<'a, 'b, 'c, BOXES, SCORES, CLASS>(
    boxes: ArrayView2<'a, BOXES>,
    scores: ArrayView2<'b, SCORES>,
    classes: &'c ArrayView2<CLASS>,
) -> Result<
    (
        ArrayView2<'a, BOXES>,
        ArrayView2<'b, SCORES>,
        ArrayView1<'c, CLASS>,
    ),
    crate::DecoderError,
> {
    let (box_rows, num_boxes) = boxes.dim();
    let (score_rows, score_cols) = scores.dim();
    let (class_rows, class_cols) = classes.dim();

    if box_rows != 4 {
        return Err(crate::DecoderError::InvalidShape(format!(
            "Split end-to-end box_coords must be 4, got {}",
            box_rows
        )));
    }
    if score_rows != 1 {
        return Err(crate::DecoderError::InvalidShape(format!(
            "Split end-to-end scores num_classes must be 1, got {}",
            score_rows
        )));
    }
    if class_rows != 1 {
        return Err(crate::DecoderError::InvalidShape(format!(
            "Split end-to-end classes num_classes must be 1, got {}",
            class_rows
        )));
    }
    if score_cols != num_boxes {
        return Err(crate::DecoderError::InvalidShape(format!(
            "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
            num_boxes, score_cols
        )));
    }
    if class_cols != num_boxes {
        return Err(crate::DecoderError::InvalidShape(format!(
            "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
            num_boxes, class_cols
        )));
    }

    Ok((
        boxes.reversed_axes(),
        scores.reversed_axes(),
        classes.slice(s![0, ..]),
    ))
}
/// Validates and reorients split end-to-end segmentation-detection tensors.
///
/// Expected shapes: `boxes` is `4 x N`, `scores` is `1 x N`, `classes` is
/// `1 x N`, and `mask_coeff` has `N` columns. On success, `boxes`, `scores` and
/// `mask_coeff` are returned transposed (one row per candidate) and `classes`
/// is returned as its first row.
///
/// # Errors
/// Returns `DecoderError::InvalidShape` for the first expectation that fails,
/// checked in this order: box rows, score rows, class rows, score columns,
/// class columns, mask-coefficient columns.
#[allow(clippy::type_complexity)]
pub(crate) fn postprocess_yolo_split_end_to_end_segdet<
    'a,
    'b,
    'c,
    'd,
    BOXES,
    SCORES,
    CLASS,
    MASK,
>(
    boxes: ArrayView2<'a, BOXES>,
    scores: ArrayView2<'b, SCORES>,
    classes: &'c ArrayView2<CLASS>,
    mask_coeff: ArrayView2<'d, MASK>,
) -> Result<
    (
        ArrayView2<'a, BOXES>,
        ArrayView2<'b, SCORES>,
        ArrayView1<'c, CLASS>,
        ArrayView2<'d, MASK>,
    ),
    crate::DecoderError,
> {
    let (box_rows, num_boxes) = boxes.dim();
    let (score_rows, score_cols) = scores.dim();
    let (class_rows, class_cols) = classes.dim();
    let mask_cols = mask_coeff.dim().1;

    if box_rows != 4 {
        return Err(crate::DecoderError::InvalidShape(format!(
            "Split end-to-end box_coords must be 4, got {}",
            box_rows
        )));
    }
    if score_rows != 1 {
        return Err(crate::DecoderError::InvalidShape(format!(
            "Split end-to-end scores num_classes must be 1, got {}",
            score_rows
        )));
    }
    if class_rows != 1 {
        return Err(crate::DecoderError::InvalidShape(format!(
            "Split end-to-end classes num_classes must be 1, got {}",
            class_rows
        )));
    }
    if score_cols != num_boxes {
        return Err(crate::DecoderError::InvalidShape(format!(
            "Split end-to-end scores must have same num_boxes as boxes ({}), got {}",
            num_boxes, score_cols
        )));
    }
    if class_cols != num_boxes {
        return Err(crate::DecoderError::InvalidShape(format!(
            "Split end-to-end classes must have same num_boxes as boxes ({}), got {}",
            num_boxes, class_cols
        )));
    }
    if mask_cols != num_boxes {
        return Err(crate::DecoderError::InvalidShape(format!(
            "Split end-to-end mask_coeff must have same num_boxes as boxes ({}), got {}",
            num_boxes, mask_cols
        )));
    }

    Ok((
        boxes.reversed_axes(),
        scores.reversed_axes(),
        classes.slice(s![0, ..]),
        mask_coeff.reversed_axes(),
    ))
}
/// Decodes a quantized, combined YOLO detection tensor for box type `B`.
///
/// Steps: split the tensor into box and score views, threshold in the
/// quantized domain, run the NMS mode selected by `nms`, then dequantize.
/// `output_boxes` is cleared and receives at most `output_boxes.capacity()`
/// detections.
pub(crate) fn impl_yolo_quant<B: BBoxTypeTrait, T: PrimInt + AsPrimitive<f32> + Send + Sync>(
    output: (ArrayView2<T>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
) where
    f32: AsPrimitive<T>,
{
    let (tensor, quant) = output;
    let (boxes_tensor, scores_tensor) = postprocess_yolo(&tensor);

    // Boxes and scores share one tensor, hence one set of quantization params.
    let quant_threshold = quantize_score_threshold(score_threshold, quant);
    let candidates =
        postprocess_boxes_quant::<B, _, _>(quant_threshold, boxes_tensor, scores_tensor, quant);
    let kept = dispatch_nms_int(nms, iou_threshold, candidates);

    // The limit never exceeds the capacity, so `extend` does not reallocate.
    let limit = output_boxes.capacity().min(kept.len());
    output_boxes.clear();
    output_boxes.extend(
        kept.iter()
            .take(limit)
            .map(|det| dequant_detect_box(det, quant)),
    );
}
/// Decodes a floating-point, combined YOLO detection tensor for box type `B`.
///
/// Steps: split the tensor into box and score views, threshold, then run the
/// NMS mode selected by `nms`. `output_boxes` is cleared and receives at most
/// `output_boxes.capacity()` detections.
pub(crate) fn impl_yolo_float<B: BBoxTypeTrait, T: Float + AsPrimitive<f32> + Send + Sync>(
    output: ArrayView2<T>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
) where
    f32: AsPrimitive<T>,
{
    let (boxes_tensor, scores_tensor) = postprocess_yolo(&output);
    let candidates =
        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
    let kept = dispatch_nms_float(nms, iou_threshold, candidates);

    // The limit never exceeds the capacity, so `extend` does not reallocate.
    let limit = output_boxes.capacity().min(kept.len());
    output_boxes.clear();
    output_boxes.extend(kept.into_iter().take(limit));
}
/// Decodes quantized YOLO detection output given as separate box and score
/// tensors, for box type `B`.
///
/// Both tensors are transposed so candidates become rows. The threshold is
/// quantized with the score quantization, box coordinates use the box
/// quantization, and surviving detections are dequantized with the score
/// quantization. `output_boxes` is cleared and receives at most
/// `output_boxes.capacity()` detections.
pub(crate) fn impl_yolo_split_quant<
    B: BBoxTypeTrait,
    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    scores: (ArrayView2<SCORE>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
) where
    f32: AsPrimitive<SCORE>,
{
    let (boxes_view, quant_boxes) = boxes;
    let (scores_view, quant_scores) = scores;

    let quant_threshold = quantize_score_threshold(score_threshold, quant_scores);
    let candidates = postprocess_boxes_quant::<B, _, _>(
        quant_threshold,
        boxes_view.reversed_axes(),
        scores_view.reversed_axes(),
        quant_boxes,
    );
    let kept = dispatch_nms_int(nms, iou_threshold, candidates);

    // The limit never exceeds the capacity, so `extend` does not reallocate.
    let limit = output_boxes.capacity().min(kept.len());
    output_boxes.clear();
    output_boxes.extend(
        kept.iter()
            .take(limit)
            .map(|det| dequant_detect_box(det, quant_scores)),
    );
}
/// Decodes floating-point YOLO detection output given as separate box and
/// score tensors, for box type `B`.
///
/// Both tensors are transposed so candidates become rows, thresholded, and
/// passed through the NMS mode selected by `nms`. `output_boxes` is cleared
/// and receives at most `output_boxes.capacity()` detections.
pub(crate) fn impl_yolo_split_float<
    B: BBoxTypeTrait,
    BOX: Float + AsPrimitive<f32> + Send + Sync,
    SCORE: Float + AsPrimitive<f32> + Send + Sync,
>(
    boxes_tensor: ArrayView2<BOX>,
    scores_tensor: ArrayView2<SCORE>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
) where
    f32: AsPrimitive<SCORE>,
{
    let candidates = postprocess_boxes_float::<B, _, _>(
        score_threshold.as_(),
        boxes_tensor.reversed_axes(),
        scores_tensor.reversed_axes(),
    );
    let kept = dispatch_nms_float(nms, iou_threshold, candidates);

    // The limit never exceeds the capacity, so `extend` does not reallocate.
    let limit = output_boxes.capacity().min(kept.len());
    output_boxes.clear();
    output_boxes.extend(kept.into_iter().take(limit));
}
/// Decodes a quantized YOLO segmentation-detection output (combined box tensor
/// plus prototypes) for box type `B`.
///
/// The combined tensor is split into box, score and mask-coefficient views;
/// all three share the box tensor's quantization. Boxes are selected (at most
/// `output_boxes.capacity()`), then masks are computed from the prototypes.
///
/// # Errors
/// Returns the `DecoderError` reported by mask decoding.
///
/// # Panics
/// Panics inside `postprocess_yolo_seg` when the combined tensor has too few
/// rows for the prototype count.
pub(crate) fn impl_yolo_segdet_quant<
    B: BBoxTypeTrait,
    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    protos: (ArrayView3<PROTO>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    f32: AsPrimitive<BOX>,
{
    let (combined, quant) = boxes;
    let num_protos = protos.0.dim().2;
    let (boxes_tensor, scores_tensor, mask_tensor) = postprocess_yolo_seg(&combined, num_protos);

    let max_boxes = output_boxes.capacity();
    let detections = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
        (boxes_tensor, quant),
        (scores_tensor, quant),
        score_threshold,
        iou_threshold,
        nms,
        max_boxes,
    );

    impl_yolo_split_segdet_quant_process_masks(
        detections,
        (mask_tensor, quant),
        protos,
        output_boxes,
        output_masks,
    )
}
/// Decodes a floating-point YOLO segmentation-detection output (combined box
/// tensor plus prototypes) for box type `B`.
///
/// The combined tensor is split into box, score and mask-coefficient views.
/// Boxes are selected (at most `output_boxes.capacity()`), then masks are
/// computed from the prototypes.
///
/// # Errors
/// Returns the `DecoderError` reported by mask decoding.
///
/// # Panics
/// Panics inside `postprocess_yolo_seg` when the combined tensor has too few
/// rows for the prototype count.
pub(crate) fn impl_yolo_segdet_float<
    B: BBoxTypeTrait,
    BOX: Float + AsPrimitive<f32> + Send + Sync,
    PROTO: Float + AsPrimitive<f32> + Send + Sync,
>(
    boxes: ArrayView2<BOX>,
    protos: ArrayView3<PROTO>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    f32: AsPrimitive<BOX>,
{
    let (boxes_tensor, scores_tensor, mask_tensor) =
        postprocess_yolo_seg(&boxes, protos.dim().2);

    let max_boxes = output_boxes.capacity();
    let detections = impl_yolo_segdet_get_boxes::<B, _, _>(
        boxes_tensor,
        scores_tensor,
        score_threshold,
        iou_threshold,
        nms,
        max_boxes,
    );

    impl_yolo_split_segdet_process_masks(
        detections,
        mask_tensor,
        protos,
        output_boxes,
        output_masks,
    )
}
/// Selects detection boxes for the float segmentation paths.
///
/// Candidates above `score_threshold` are collected together with an index
/// (later used to look up the matching mask-coefficient row), capped at
/// `MAX_NMS_CANDIDATES` by score, filtered by the NMS mode in `nms`, and
/// finally truncated to `max_boxes`.
pub(crate) fn impl_yolo_segdet_get_boxes<
    B: BBoxTypeTrait,
    BOX: Float + AsPrimitive<f32> + Send + Sync,
    SCORE: Float + AsPrimitive<f32> + Send + Sync,
>(
    boxes_tensor: ArrayView2<BOX>,
    scores_tensor: ArrayView2<SCORE>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    max_boxes: usize,
) -> Vec<(DetectBox, usize)>
where
    f32: AsPrimitive<SCORE>,
{
    let mut candidates = postprocess_boxes_index_float::<B, _, _>(
        score_threshold.as_(),
        boxes_tensor,
        scores_tensor,
    );
    truncate_to_top_k_by_score(&mut candidates);

    let mut kept = dispatch_nms_extra_float(nms, iou_threshold, candidates);
    kept.truncate(max_boxes);
    kept
}
/// Selects detection boxes for the end-to-end segmentation paths (no NMS).
///
/// Candidates above `score_threshold` are collected together with their
/// index, truncated to `max_boxes`, and labelled with the class id read from
/// `classes` at that index, rounded to the nearest integer.
pub(crate) fn impl_yolo_end_to_end_segdet_get_boxes<
    B: BBoxTypeTrait,
    BOX: Float + AsPrimitive<f32> + Send + Sync,
    SCORE: Float + AsPrimitive<f32> + Send + Sync,
    CLASS: AsPrimitive<f32> + Send + Sync,
>(
    boxes: ArrayView2<BOX>,
    scores: ArrayView2<SCORE>,
    classes: ArrayView1<CLASS>,
    score_threshold: f32,
    max_boxes: usize,
) -> Vec<(DetectBox, usize)>
where
    f32: AsPrimitive<SCORE>,
{
    let mut selected =
        postprocess_boxes_index_float::<B, _, _>(score_threshold.as_(), boxes, scores);
    selected.truncate(max_boxes);
    for (det, idx) in selected.iter_mut() {
        det.label = classes[*idx].as_().round() as usize;
    }
    selected
}
pub(crate) fn impl_yolo_split_segdet_process_masks<
MASK: Float + AsPrimitive<f32> + Send + Sync,
PROTO: Float + AsPrimitive<f32> + Send + Sync,
>(
boxes: Vec<(DetectBox, usize)>,
masks_tensor: ArrayView2<MASK>,
protos_tensor: ArrayView3<PROTO>,
output_boxes: &mut Vec<DetectBox>,
output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError> {
let boxes = decode_segdet_f32(boxes, masks_tensor, protos_tensor)?;
output_boxes.clear();
output_masks.clear();
for (b, m) in boxes.into_iter() {
output_boxes.push(b);
output_masks.push(Segmentation {
xmin: b.bbox.xmin,
ymin: b.bbox.ymin,
xmax: b.bbox.xmax,
ymax: b.bbox.ymax,
segmentation: m,
});
}
Ok(())
}
/// Selects detection boxes for the quantized segmentation paths.
///
/// The threshold is quantized with the score quantization; candidates are
/// collected with their index, capped at `MAX_NMS_CANDIDATES` by score,
/// filtered by the NMS mode in `nms`, truncated to `max_boxes`, and finally
/// dequantized with the score quantization.
pub(crate) fn impl_yolo_split_segdet_quant_get_boxes<
    B: BBoxTypeTrait,
    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    scores: (ArrayView2<SCORE>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    max_boxes: usize,
) -> Vec<(DetectBox, usize)>
where
    f32: AsPrimitive<SCORE>,
{
    let (boxes_tensor, quant_boxes) = boxes;
    let (scores_tensor, quant_scores) = scores;

    let quant_threshold = quantize_score_threshold(score_threshold, quant_scores);
    let mut candidates = postprocess_boxes_index_quant::<B, _, _>(
        quant_threshold,
        boxes_tensor,
        scores_tensor,
        quant_boxes,
    );
    truncate_to_top_k_by_score_quant(&mut candidates);

    let mut kept = dispatch_nms_extra_int(nms, iou_threshold, candidates);
    kept.truncate(max_boxes);
    kept.into_iter()
        .map(|(det, idx)| (dequant_detect_box(&det, quant_scores), idx))
        .collect()
}
pub(crate) fn impl_yolo_split_segdet_quant_process_masks<
MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
>(
boxes: Vec<(DetectBox, usize)>,
mask_coeff: (ArrayView2<MASK>, Quantization),
protos: (ArrayView3<PROTO>, Quantization),
output_boxes: &mut Vec<DetectBox>,
output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError> {
let (masks, quant_masks) = mask_coeff;
let (protos, quant_protos) = protos;
let boxes = decode_segdet_quant(boxes, masks, protos, quant_masks, quant_protos)?;
output_boxes.clear();
output_masks.clear();
for (b, m) in boxes.into_iter() {
output_boxes.push(b);
output_masks.push(Segmentation {
xmin: b.bbox.xmin,
ymin: b.bbox.ymin,
xmax: b.bbox.xmax,
ymax: b.bbox.ymax,
segmentation: m,
});
}
Ok(())
}
/// Decodes quantized YOLO segmentation-detection output given as separate box,
/// score, mask-coefficient and prototype tensors, for box type `B`.
///
/// The three 2-D tensors are transposed so candidates become rows, boxes are
/// selected (at most `output_boxes.capacity()`), and masks are computed from
/// the prototypes. Each tensor keeps its own quantization parameters.
///
/// # Errors
/// Returns the `DecoderError` reported by mask decoding.
#[allow(clippy::too_many_arguments)]
pub(crate) fn impl_yolo_split_segdet_quant<
    B: BBoxTypeTrait,
    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    scores: (ArrayView2<SCORE>, Quantization),
    mask_coeff: (ArrayView2<MASK>, Quantization),
    protos: (ArrayView3<PROTO>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    f32: AsPrimitive<SCORE>,
{
    let (boxes_view, quant_boxes) = boxes;
    let (scores_view, quant_scores) = scores;
    let (coeff_view, quant_masks) = mask_coeff;
    let (boxes_view, scores_view, coeff_view) =
        postprocess_yolo_split_segdet(boxes_view, scores_view, coeff_view);

    let max_boxes = output_boxes.capacity();
    let detections = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
        (boxes_view, quant_boxes),
        (scores_view, quant_scores),
        score_threshold,
        iou_threshold,
        nms,
        max_boxes,
    );

    impl_yolo_split_segdet_quant_process_masks(
        detections,
        (coeff_view, quant_masks),
        protos,
        output_boxes,
        output_masks,
    )
}
/// Decodes floating-point YOLO segmentation-detection output given as separate
/// box, score, mask-coefficient and prototype tensors, for box type `B`.
///
/// The three 2-D tensors are transposed so candidates become rows, boxes are
/// selected (at most `output_boxes.capacity()`), and masks are computed from
/// the prototypes.
///
/// # Errors
/// Returns the `DecoderError` reported by mask decoding.
#[allow(clippy::too_many_arguments)]
pub(crate) fn impl_yolo_split_segdet_float<
    B: BBoxTypeTrait,
    BOX: Float + AsPrimitive<f32> + Send + Sync,
    SCORE: Float + AsPrimitive<f32> + Send + Sync,
    MASK: Float + AsPrimitive<f32> + Send + Sync,
    PROTO: Float + AsPrimitive<f32> + Send + Sync,
>(
    boxes_tensor: ArrayView2<BOX>,
    scores_tensor: ArrayView2<SCORE>,
    mask_tensor: ArrayView2<MASK>,
    protos: ArrayView3<PROTO>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
    output_masks: &mut Vec<Segmentation>,
) -> Result<(), crate::DecoderError>
where
    f32: AsPrimitive<SCORE>,
{
    let (boxes_view, scores_view, coeff_view) =
        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);

    let max_boxes = output_boxes.capacity();
    let detections = impl_yolo_segdet_get_boxes::<B, _, _>(
        boxes_view,
        scores_view,
        score_threshold,
        iou_threshold,
        nms,
        max_boxes,
    );

    impl_yolo_split_segdet_process_masks(
        detections,
        coeff_view,
        protos,
        output_boxes,
        output_masks,
    )
}
/// Decodes boxes from a quantized, combined segmentation-detection tensor and
/// returns the raw prototype data instead of rendered masks.
///
/// The combined tensor is split into box, score and mask-coefficient views
/// that all share the box tensor's quantization. Selected boxes (at most
/// `output_boxes.capacity()`) are written to `output_boxes`; their mask
/// coefficients and the prototype tensor are packaged into the returned
/// `ProtoData`.
///
/// # Panics
/// Panics inside `postprocess_yolo_seg` when the combined tensor has too few
/// rows for the prototype count.
pub fn impl_yolo_segdet_quant_proto<
    B: BBoxTypeTrait,
    BOX: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + AsPrimitive<f32> + Send + Sync,
    PROTO: PrimInt
        + AsPrimitive<i64>
        + AsPrimitive<i128>
        + AsPrimitive<f32>
        + AsPrimitive<i8>
        + Send
        + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    protos: (ArrayView3<PROTO>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
) -> ProtoData
where
    f32: AsPrimitive<BOX>,
{
    let (combined, quant_boxes) = boxes;
    let (proto_view, quant_protos) = protos;
    let (boxes_tensor, scores_tensor, mask_tensor) =
        postprocess_yolo_seg(&combined, proto_view.dim().2);

    let max_boxes = output_boxes.capacity();
    let detections = impl_yolo_split_segdet_quant_get_boxes::<B, _, _>(
        (boxes_tensor, quant_boxes),
        (scores_tensor, quant_boxes),
        score_threshold,
        iou_threshold,
        nms,
        max_boxes,
    );

    extract_proto_data_quant(
        detections,
        mask_tensor,
        quant_boxes,
        proto_view,
        quant_protos,
        output_boxes,
    )
}
/// Decodes boxes from a floating-point, combined segmentation-detection tensor
/// and returns the raw prototype data instead of rendered masks.
///
/// Selected boxes (at most `output_boxes.capacity()`) are written to
/// `output_boxes`; their mask coefficients and the prototype tensor are
/// packaged into the returned `ProtoData`.
///
/// # Panics
/// Panics inside `postprocess_yolo_seg` when the combined tensor has too few
/// rows for the prototype count.
pub(crate) fn impl_yolo_segdet_float_proto<
    B: BBoxTypeTrait,
    BOX: Float + AsPrimitive<f32> + Send + Sync,
    PROTO: Float + AsPrimitive<f32> + Send + Sync,
>(
    boxes: ArrayView2<BOX>,
    protos: ArrayView3<PROTO>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
) -> ProtoData
where
    f32: AsPrimitive<BOX>,
{
    let (boxes_tensor, scores_tensor, mask_tensor) =
        postprocess_yolo_seg(&boxes, protos.dim().2);

    let max_boxes = output_boxes.capacity();
    let detections = impl_yolo_segdet_get_boxes::<B, _, _>(
        boxes_tensor,
        scores_tensor,
        score_threshold,
        iou_threshold,
        nms,
        max_boxes,
    );

    extract_proto_data_float(detections, mask_tensor, protos, output_boxes)
}
/// Decodes boxes from floating-point split segmentation-detection tensors and
/// returns the raw prototype data instead of rendered masks.
///
/// The three 2-D tensors are transposed so candidates become rows. Selected
/// boxes (at most `output_boxes.capacity()`) are written to `output_boxes`;
/// their mask coefficients and the prototype tensor are packaged into the
/// returned `ProtoData`.
#[allow(clippy::too_many_arguments)]
pub(crate) fn impl_yolo_split_segdet_float_proto<
    B: BBoxTypeTrait,
    BOX: Float + AsPrimitive<f32> + Send + Sync,
    SCORE: Float + AsPrimitive<f32> + Send + Sync,
    MASK: Float + AsPrimitive<f32> + Send + Sync,
    PROTO: Float + AsPrimitive<f32> + Send + Sync,
>(
    boxes_tensor: ArrayView2<BOX>,
    scores_tensor: ArrayView2<SCORE>,
    mask_tensor: ArrayView2<MASK>,
    protos: ArrayView3<PROTO>,
    score_threshold: f32,
    iou_threshold: f32,
    nms: Option<Nms>,
    output_boxes: &mut Vec<DetectBox>,
) -> ProtoData
where
    f32: AsPrimitive<SCORE>,
{
    let (boxes_view, scores_view, coeff_view) =
        postprocess_yolo_split_segdet(boxes_tensor, scores_tensor, mask_tensor);

    let max_boxes = output_boxes.capacity();
    let detections = impl_yolo_segdet_get_boxes::<B, _, _>(
        boxes_view,
        scores_view,
        score_threshold,
        iou_threshold,
        nms,
        max_boxes,
    );

    extract_proto_data_float(detections, coeff_view, protos, output_boxes)
}
/// Decodes boxes from an end-to-end segmentation-detection tensor and returns
/// the raw prototype data instead of rendered masks; no NMS is run here.
///
/// Selected boxes (as `XYXY`, at most `output_boxes.capacity()`) are written
/// to `output_boxes`; their mask coefficients and the prototype tensor are
/// packaged into the returned `ProtoData`.
///
/// # Errors
/// Returns `DecoderError::InvalidShape` when
/// `postprocess_yolo_end_to_end_segdet` rejects the tensor shape.
pub fn decode_yolo_end_to_end_segdet_float_proto<T>(
    output: ArrayView2<T>,
    protos: ArrayView3<T>,
    score_threshold: f32,
    output_boxes: &mut Vec<DetectBox>,
) -> Result<ProtoData, crate::DecoderError>
where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    let num_protos = protos.dim().2;
    let (box_view, score_view, class_view, coeff_view) =
        postprocess_yolo_end_to_end_segdet(&output, num_protos)?;

    let max_boxes = output_boxes.capacity();
    let detections = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
        box_view,
        score_view,
        class_view,
        score_threshold,
        max_boxes,
    );

    let proto_data = extract_proto_data_float(detections, coeff_view, protos, output_boxes);
    Ok(proto_data)
}
/// Decodes boxes from split end-to-end segmentation-detection tensors and
/// returns the raw prototype data instead of rendered masks; no NMS is run
/// here.
///
/// Selected boxes (as `XYXY`, at most `output_boxes.capacity()`) are written
/// to `output_boxes`; their mask coefficients and the prototype tensor are
/// packaged into the returned `ProtoData`.
///
/// # Errors
/// Returns `DecoderError::InvalidShape` when
/// `postprocess_yolo_split_end_to_end_segdet` rejects the tensor shapes.
#[allow(clippy::too_many_arguments)]
pub fn decode_yolo_split_end_to_end_segdet_float_proto<T>(
    boxes: ArrayView2<T>,
    scores: ArrayView2<T>,
    classes: ArrayView2<T>,
    mask_coeff: ArrayView2<T>,
    protos: ArrayView3<T>,
    score_threshold: f32,
    output_boxes: &mut Vec<DetectBox>,
) -> Result<ProtoData, crate::DecoderError>
where
    T: Float + AsPrimitive<f32> + Send + Sync + 'static,
    f32: AsPrimitive<T>,
{
    let (box_view, score_view, class_view, coeff_view) =
        postprocess_yolo_split_end_to_end_segdet(boxes, scores, &classes, mask_coeff)?;

    let max_boxes = output_boxes.capacity();
    let detections = impl_yolo_end_to_end_segdet_get_boxes::<XYXY, _, _, _>(
        box_view,
        score_view,
        class_view,
        score_threshold,
        max_boxes,
    );

    let proto_data = extract_proto_data_float(detections, coeff_view, protos, output_boxes);
    Ok(proto_data)
}
/// Packages float mask coefficients and prototypes for the selected boxes.
///
/// `det_indices` pairs each detection with the row of `mask_tensor` holding
/// its coefficients. `output_boxes` is cleared and receives the detections in
/// order; the returned `ProtoData` holds one converted coefficient vector per
/// detection plus a converted copy of the whole prototype tensor.
pub(super) fn extract_proto_data_float<
    MASK: Float + AsPrimitive<f32> + Send + Sync,
    PROTO: Float + AsPrimitive<f32> + Send + Sync,
>(
    det_indices: Vec<(DetectBox, usize)>,
    mask_tensor: ArrayView2<MASK>,
    protos: ArrayView3<PROTO>,
    output_boxes: &mut Vec<DetectBox>,
) -> ProtoData {
    output_boxes.clear();

    let mut mask_coefficients = Vec::with_capacity(det_indices.len());
    for (det, idx) in det_indices {
        output_boxes.push(det);
        let coeffs = mask_tensor.row(idx).iter().map(|v| v.as_()).collect();
        mask_coefficients.push(coeffs);
    }

    ProtoData {
        mask_coefficients,
        protos: ProtoTensor::Float(protos.map(|v| v.as_())),
    }
}
/// Packages mask coefficients and quantized prototypes for the selected boxes.
///
/// `det_indices` pairs each detection with the row of `mask_tensor` holding
/// its coefficients. `output_boxes` is cleared and receives the detections in
/// order. The coefficients are dequantized to `f32` using `quant_masks`
/// (`(value - zero_point) * scale`), while the prototypes are kept quantized
/// as `i8` and returned together with `quant_protos`.
pub(crate) fn extract_proto_data_quant<
MASK: PrimInt + AsPrimitive<f32> + Send + Sync,
PROTO: PrimInt + AsPrimitive<f32> + AsPrimitive<i8> + Send + Sync + 'static,
>(
det_indices: Vec<(DetectBox, usize)>,
mask_tensor: ArrayView2<MASK>,
quant_masks: Quantization,
protos: ArrayView3<PROTO>,
quant_protos: Quantization,
output_boxes: &mut Vec<DetectBox>,
) -> ProtoData {
let mut mask_coefficients = Vec::with_capacity(det_indices.len());
output_boxes.clear();
for (det, idx) in det_indices {
output_boxes.push(det);
let row = mask_tensor.row(idx);
mask_coefficients.push(
row.iter()
// Dequantize each coefficient: (q - zero_point) * scale.
.map(|v| (v.as_() - quant_masks.zero_point as f32) * quant_masks.scale)
.collect(),
);
}
// Fast path: when PROTO is already i8 the view is copied as-is instead of
// being converted element by element.
let protos_i8 = if std::any::TypeId::of::<PROTO>() == std::any::TypeId::of::<i8>() {
let view_i8 =
// SAFETY: the `TypeId` comparison above proves `PROTO` is exactly `i8`, so
// `ArrayView3<'_, PROTO>` and `ArrayView3<'_, i8>` are the same type and the
// pointer cast only renames it. The reference does not outlive `protos`.
unsafe { &*(&protos as *const ArrayView3<'_, PROTO> as *const ArrayView3<'_, i8>) };
view_i8.to_owned()
} else {
// NOTE(review): for any other PROTO the values go through `AsPrimitive<i8>`
// while `quant_protos` is passed on unchanged; for wider or unsigned types
// this presumably wraps/truncates — confirm callers only rely on this path
// where that is intended.
protos.map(|v| {
let v_i8: i8 = v.as_();
v_i8
})
};
ProtoData {
mask_coefficients,
protos: ProtoTensor::Quantized {
protos: protos_i8,
quantization: quant_protos,
},
}
}
/// Splits a combined YOLO detection tensor into box and score views.
///
/// The first 4 rows are the box coordinates and all remaining rows are the
/// scores; both views are returned transposed so each candidate is a row.
fn postprocess_yolo<'a, T>(
    output: &'a ArrayView2<'_, T>,
) -> (ArrayView2<'a, T>, ArrayView2<'a, T>) {
    (
        output.slice(s![..4, ..]).reversed_axes(),
        output.slice(s![4.., ..]).reversed_axes(),
    )
}
/// Splits a combined YOLO segmentation-detection tensor into box, score and
/// mask-coefficient views.
///
/// Row layout: 4 box rows, then the class scores, then the trailing
/// `num_protos` mask-coefficient rows. All three views are returned transposed
/// so each candidate is a row.
///
/// # Panics
/// Panics when the tensor does not have more than `num_protos + 4` rows, i.e.
/// when no row would be left for class scores.
pub(crate) fn postprocess_yolo_seg<'a, T>(
    output: &'a ArrayView2<'_, T>,
    num_protos: usize,
) -> (ArrayView2<'a, T>, ArrayView2<'a, T>, ArrayView2<'a, T>) {
    let num_rows = output.shape()[0];
    assert!(
        num_rows > num_protos + 4,
        "Output shape is too short: {} <= {} + 4",
        num_rows,
        num_protos
    );

    let num_classes = num_rows - 4 - num_protos;
    // First row index that belongs to the mask coefficients.
    let mask_start = num_classes + 4;

    (
        output.slice(s![..4, ..]).reversed_axes(),
        output.slice(s![4..mask_start, ..]).reversed_axes(),
        output.slice(s![mask_start.., ..]).reversed_axes(),
    )
}
/// Transposes the split box, score and mask-coefficient tensors.
///
/// Each view is returned with its two axes swapped; no validation is done.
pub(crate) fn postprocess_yolo_split_segdet<'a, 'b, 'c, BOX, SCORE, MASK>(
    boxes_tensor: ArrayView2<'a, BOX>,
    scores_tensor: ArrayView2<'b, SCORE>,
    mask_tensor: ArrayView2<'c, MASK>,
) -> (
    ArrayView2<'a, BOX>,
    ArrayView2<'b, SCORE>,
    ArrayView2<'c, MASK>,
) {
    (
        boxes_tensor.reversed_axes(),
        scores_tensor.reversed_axes(),
        mask_tensor.reversed_axes(),
    )
}
/// Builds a segmentation mask for every selected box (float inputs).
///
/// `boxes` pairs each detection with the row of `masks` holding its
/// coefficients. For each one, the prototype region for the box is obtained
/// from `protobox`, the detection's bbox is replaced by the region that call
/// returns, and the mask is computed from the coefficient row. Detections are
/// processed in parallel.
///
/// # Errors
/// Returns `DecoderError::InvalidShape` when the number of coefficients per
/// row differs from the prototype channel count, and forwards any error from
/// `protobox`. An empty `boxes` short-circuits to `Ok` before any check.
fn decode_segdet_f32<
    MASK: Float + AsPrimitive<f32> + Send + Sync,
    PROTO: Float + AsPrimitive<f32> + Send + Sync,
>(
    boxes: Vec<(DetectBox, usize)>,
    masks: ArrayView2<MASK>,
    protos: ArrayView3<PROTO>,
) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError> {
    if boxes.is_empty() {
        return Ok(Vec::new());
    }

    let num_coeffs = masks.shape()[1];
    let num_channels = protos.shape()[2];
    if num_coeffs != num_channels {
        return Err(crate::DecoderError::InvalidShape(format!(
            "Mask coefficients count ({}) doesn't match protos channel count ({})",
            num_coeffs, num_channels,
        )));
    }

    boxes
        .into_par_iter()
        .map(|(mut det, idx)| {
            let (cropped, roi) = protobox(&protos, &det.bbox)?;
            det.bbox = roi;
            let mask = make_segmentation(masks.row(idx), cropped.view());
            Ok((det, mask))
        })
        .collect()
}
/// Builds a per-detection segmentation mask for quantized (integer) outputs.
///
/// Each entry of `boxes` pairs a detection with the row index of its mask
/// coefficients in `masks`. For every detection the prototype tensor is
/// cropped with `protobox` (using the canonicalized bbox), the detection's
/// bbox is replaced by the box `protobox` returns, and
/// `make_segmentation_quant` combines the coefficient row with the cropped
/// prototypes using `quant_masks` / `quant_protos`.
///
/// The integer accumulator type is chosen from the combined bit width of the
/// `MASK` and `PROTO` types plus 5 extra bits: up to 64 bits uses `i64`, up
/// to 128 bits uses `i128`.
///
/// # Errors
///
/// * `DecoderError::InvalidShape` if the second dimension of `masks` differs
///   from the third dimension of `protos`.
/// * Any error returned by `protobox` for an individual detection.
/// * `DecoderError::NotSupported` if the computed bit width exceeds 128.
///
/// An empty `boxes` vector short-circuits to `Ok(vec![])` before any check.
pub(crate) fn decode_segdet_quant<MASK, PROTO>(
    boxes: Vec<(DetectBox, usize)>,
    masks: ArrayView2<MASK>,
    protos: ArrayView3<PROTO>,
    quant_masks: Quantization,
    quant_protos: Quantization,
) -> Result<Vec<(DetectBox, Array3<u8>)>, crate::DecoderError>
where
    MASK: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
    PROTO: PrimInt + AsPrimitive<i64> + AsPrimitive<i128> + Send + Sync,
{
    if boxes.is_empty() {
        return Ok(Vec::new());
    }

    let coeff_count = masks.shape()[1];
    let proto_channels = protos.shape()[2];
    if coeff_count != proto_channels {
        return Err(crate::DecoderError::InvalidShape(format!(
            "Mask coefficients count ({}) doesn't match protos channel count ({})",
            coeff_count, proto_channels,
        )));
    }

    // `count_zeros()` of a zero value equals the bit width of the type.
    // NOTE(review): the `+ 5` presumably reserves headroom for summing up to
    // 32 (2^5) products — confirm against the expected proto channel count.
    let total_bits = MASK::zero().count_zeros() + PROTO::zero().count_zeros() + 5;

    boxes
        .into_iter()
        .map(|(mut detection, coeff_row)| {
            let (cropped, snapped) = protobox(&protos, &detection.bbox.to_canonical())?;
            detection.bbox = snapped;

            let seg = if total_bits <= 64 {
                make_segmentation_quant::<MASK, PROTO, i64>(
                    masks.row(coeff_row),
                    cropped.view(),
                    quant_masks,
                    quant_protos,
                )
            } else if total_bits <= 128 {
                make_segmentation_quant::<MASK, PROTO, i128>(
                    masks.row(coeff_row),
                    cropped.view(),
                    quant_masks,
                    quant_protos,
                )
            } else {
                return Err(crate::DecoderError::NotSupported(format!(
                    "Unsupported bit width ({total_bits}) for segmentation computation"
                )));
            };

            Ok((detection, seg))
        })
        .collect()
}
/// Crops the prototype tensor to a normalized bounding box.
///
/// `roi` is scaled by the prototype width (axis 1) and height (axis 0),
/// clamped to the tensor extent, and snapped outward to whole cells: the
/// minimum corner is truncated and the maximum corner is rounded up.
///
/// Returns the cropped view (all channels kept) together with the snapped
/// box converted back to normalized coordinates.
///
/// # Errors
///
/// Returns `DecoderError::InvalidShape` if any coordinate of `roi` is greater
/// than `NORM_LIMIT` (2.0), which is taken as a sign that the box is in pixel
/// space rather than normalized. Values between 1.0 and the limit are
/// accepted and clamped.
fn protobox<'a, T>(
    protos: &'a ArrayView3<T>,
    roi: &BoundingBox,
) -> Result<(ArrayView3<'a, T>, BoundingBox), crate::DecoderError> {
    const NORM_LIMIT: f32 = 2.0;

    let (rows, cols, _) = protos.dim();
    let width = cols as f32;
    let height = rows as f32;

    let looks_pixel_space = [roi.xmin, roi.ymin, roi.xmax, roi.ymax]
        .iter()
        .any(|&coord| coord > NORM_LIMIT);
    if looks_pixel_space {
        return Err(crate::DecoderError::InvalidShape(format!(
            "Bounding box coordinates appear un-normalized (pixel-space). \
             Got bbox=({:.2}, {:.2}, {:.2}, {:.2}) but expected values in [0, 1]. \
             ONNX models output pixel-space boxes — normalize them by dividing by \
             the input dimensions before calling decode().",
            roi.xmin, roi.ymin, roi.xmax, roi.ymax,
        )));
    }

    // Lower bounds truncate toward zero, upper bounds round up, so the crop
    // always covers the requested region.
    let x_start = (roi.xmin * width).clamp(0.0, width) as usize;
    let y_start = (roi.ymin * height).clamp(0.0, height) as usize;
    let x_end = (roi.xmax * width).clamp(0.0, width).ceil() as usize;
    let y_end = (roi.ymax * height).clamp(0.0, height).ceil() as usize;

    let snapped: BoundingBox = [
        x_start as f32 / width,
        y_start as f32 / height,
        x_end as f32 / width,
        y_end as f32 / height,
    ]
    .into();

    let cropped = protos.slice(s![y_start..y_end, x_start..x_end, ..]);
    Ok((cropped, snapped))
}
/// Computes an 8-bit instance mask from float mask coefficients and
/// prototypes.
///
/// `protos` has shape `(H, W, C)` and `mask` has `C` coefficients. Both are
/// converted to `f32`, the coefficient vector is multiplied with the
/// flattened prototypes, and a sigmoid is applied to every resulting value.
/// The result is scaled to `0..=255`, rounded, and returned with shape
/// `(H, W, 1)`.
///
/// # Panics
///
/// Panics if the reshapes fail or if `mask.len()` does not match the channel
/// count of `protos` (the matrix product requires matching inner dimensions).
fn make_segmentation<MASK, PROTO>(mask: ArrayView1<MASK>, protos: ArrayView3<PROTO>) -> Array3<u8>
where
    MASK: Float + AsPrimitive<f32> + Send + Sync,
    PROTO: Float + AsPrimitive<f32> + Send + Sync,
{
    let (height, width, channels) = protos.dim();

    // (1, C) row vector of coefficients.
    let coeffs = mask
        .to_shape((1, mask.len()))
        .unwrap()
        .map(|c| -> f32 { c.as_() });

    // (C, H*W) matrix of prototypes: flatten the spatial dims, then transpose.
    let flat_protos = protos
        .to_shape([height * width, channels])
        .unwrap()
        .reversed_axes()
        .map(|p| -> f32 { p.as_() });

    let logits = coeffs
        .dot(&flat_protos)
        .into_shape_with_order((height, width, 1))
        .unwrap();

    logits.map(|logit| {
        let sigmoid = 1.0 / (1.0 + (-*logit).exp());
        (sigmoid * 255.0).round() as u8
    })
}
fn make_segmentation_quant<
MASK: PrimInt + AsPrimitive<DEST> + Send + Sync,
PROTO: PrimInt + AsPrimitive<DEST> + Send + Sync,
DEST: PrimInt + 'static + Signed + AsPrimitive<f32> + Debug,
>(
mask: ArrayView1<MASK>,
protos: ArrayView3<PROTO>,
quant_masks: Quantization,
quant_protos: Quantization,
) -> Array3<u8>
where
i32: AsPrimitive<DEST>,
f32: AsPrimitive<DEST>,
{
let shape = protos.shape();
let mask = mask.to_shape((1, mask.len())).unwrap();
let protos = protos.to_shape([shape[0] * shape[1], shape[2]]).unwrap();
let protos = protos.reversed_axes();
let zp = quant_masks.zero_point.as_();
let mask = mask.mapv(|x| x.as_() - zp);
let zp = quant_protos.zero_point.as_();
let protos = protos.mapv(|x| x.as_() - zp);
let segmentation = mask
.dot(&protos)
.into_shape_with_order((shape[0], shape[1], 1))
.unwrap();
let combined_scale = quant_masks.scale * quant_protos.scale;
segmentation.map(|x| {
let val: f32 = (*x).as_() * combined_scale;
let sigmoid = 1.0 / (1.0 + (-val).exp());
(sigmoid * 255.0).round() as u8
})
}
/// Converts an `(H, W, 1)` 8-bit segmentation into a binary `(H, W)` mask.
///
/// Every element greater than or equal to `threshold` becomes `1`; all other
/// elements become `0`.
///
/// # Errors
///
/// Returns `DecoderError::InvalidShape` if the last dimension of
/// `segmentation` is not exactly 1.
pub fn yolo_segmentation_to_mask(
    segmentation: ArrayView3<u8>,
    threshold: u8,
) -> Result<Array2<u8>, crate::DecoderError> {
    let channels = segmentation.shape()[2];
    if channels != 1 {
        return Err(crate::DecoderError::InvalidShape(format!(
            "Yolo Instance Segmentation should have shape (H, W, 1), got (H, W, {})",
            channels
        )));
    }

    let plane = segmentation.slice(s![.., .., 0]);
    Ok(plane.mapv(|value| u8::from(value >= threshold)))
}
// Unit tests for the end-to-end decoders, the segmentation-to-mask helper,
// `protobox`, and the pre-NMS candidate cap. The decoders under test are
// defined elsewhere in the parent module; the comments below describe only
// the inputs each test builds and the results it asserts.
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
use super::*;
use ndarray::Array2;
// Feeds a (6, 3) tensor with threshold 0.5 and expects exactly one detection
// whose label, score and bbox equal the values stored in column 0
// (bbox 0.1, 0.1, 0.5, 0.5; score 0.9; label 0).
#[test]
fn test_end_to_end_det_basic_filtering() {
let data: Vec<f32> = vec![
0.1, 0.2, 0.3, 0.1, 0.2, 0.3, 0.5, 0.6, 0.7, 0.5, 0.6, 0.7, 0.9, 0.1, 0.2, 0.0, 1.0, 2.0, ];
let output = Array2::from_shape_vec((6, 3), data).unwrap();
let mut boxes = Vec::with_capacity(10);
decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
assert_eq!(boxes.len(), 1);
assert_eq!(boxes[0].label, 0);
assert!((boxes[0].score - 0.9).abs() < 0.01);
assert!((boxes[0].bbox.xmin - 0.1).abs() < 0.01);
assert!((boxes[0].bbox.ymin - 0.1).abs() < 0.01);
assert!((boxes[0].bbox.xmax - 0.5).abs() < 0.01);
assert!((boxes[0].bbox.ymax - 0.5).abs() < 0.01);
}
// Two columns with row-4 values 0.8 and 0.7 (both above the 0.5 threshold):
// expects both detections, with labels 1 and 2 in input order.
#[test]
fn test_end_to_end_det_all_pass_threshold() {
let data: Vec<f32> = vec![
10.0, 20.0, 10.0, 20.0, 50.0, 60.0, 50.0, 60.0, 0.8, 0.7, 1.0, 2.0, ];
let output = Array2::from_shape_vec((6, 2), data).unwrap();
let mut boxes = Vec::with_capacity(10);
decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
assert_eq!(boxes.len(), 2);
assert_eq!(boxes[0].label, 1);
assert_eq!(boxes[1].label, 2);
}
// Two columns with row-4 values 0.1 and 0.2 (both below the 0.5 threshold):
// expects no detections.
#[test]
fn test_end_to_end_det_none_pass_threshold() {
let data: Vec<f32> = vec![
10.0, 20.0, 10.0, 20.0, 50.0, 60.0, 50.0, 60.0, 0.1, 0.2, 1.0, 2.0, ];
let output = Array2::from_shape_vec((6, 2), data).unwrap();
let mut boxes = Vec::with_capacity(10);
decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
assert_eq!(boxes.len(), 0);
}
// Five columns all carry 0.9 in row 4, but the output Vec is created with
// capacity 2: expects exactly two detections to be written.
#[test]
fn test_end_to_end_det_capacity_limit() {
let data: Vec<f32> = vec![
0.1, 0.2, 0.3, 0.4, 0.5, 0.1, 0.2, 0.3, 0.4, 0.5, 0.5, 0.6, 0.7, 0.8, 0.9, 0.5, 0.6, 0.7, 0.8, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.0, 1.0, 2.0, 3.0, 4.0, ];
let output = Array2::from_shape_vec((6, 5), data).unwrap();
let mut boxes = Vec::with_capacity(2); decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
assert_eq!(boxes.len(), 2);
}
// A (6, 0) tensor (no columns) must decode successfully to zero detections.
#[test]
fn test_end_to_end_det_empty_output() {
let output = Array2::<f32>::zeros((6, 0));
let mut boxes = Vec::with_capacity(10);
decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
assert_eq!(boxes.len(), 0);
}
// Box values far above 1.0 (100, 200, 300, 400): expects them to come back
// unchanged in the decoded bbox, with label 5.
#[test]
fn test_end_to_end_det_pixel_coordinates() {
let data: Vec<f32> = vec![
100.0, 200.0, 300.0, 400.0, 0.95, 5.0, ];
let output = Array2::from_shape_vec((6, 1), data).unwrap();
let mut boxes = Vec::with_capacity(10);
decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes).unwrap();
assert_eq!(boxes.len(), 1);
assert_eq!(boxes[0].label, 5);
assert!((boxes[0].bbox.xmin - 100.0).abs() < 0.01);
assert!((boxes[0].bbox.ymin - 200.0).abs() < 0.01);
assert!((boxes[0].bbox.xmax - 300.0).abs() < 0.01);
assert!((boxes[0].bbox.ymax - 400.0).abs() < 0.01);
}
// A tensor with only 5 rows must be rejected with an `InvalidShape` error
// whose message contains "at least 6 rows".
#[test]
fn test_end_to_end_det_invalid_shape() {
let output = Array2::<f32>::zeros((5, 3));
let mut boxes = Vec::with_capacity(10);
let result = decode_yolo_end_to_end_det_float(output.view(), 0.5, &mut boxes);
assert!(result.is_err());
assert!(matches!(
result,
Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 6 rows")
));
}
// Two detections with 32 trailing coefficient rows (all 0.1) and all-zero
// (16, 16, 32) protos. Column 0 has row-4 value 0.9 and row-5 value 1.0,
// column 1 has 0.3 and 2.0; with threshold 0.5 the test expects one box and
// one mask, with label 1 and score 0.9.
#[test]
fn test_end_to_end_segdet_basic() {
let num_protos = 32;
let num_detections = 2;
let num_features = 6 + num_protos;
let mut data = vec![0.0f32; num_features * num_detections];
data[0] = 0.1; data[1] = 0.5; data[num_detections] = 0.1; data[num_detections + 1] = 0.5; data[2 * num_detections] = 0.4; data[2 * num_detections + 1] = 0.9; data[3 * num_detections] = 0.4; data[3 * num_detections + 1] = 0.9; data[4 * num_detections] = 0.9; data[4 * num_detections + 1] = 0.3; data[5 * num_detections] = 1.0; data[5 * num_detections + 1] = 2.0; for i in 6..num_features {
data[i * num_detections] = 0.1;
data[i * num_detections + 1] = 0.1;
}
let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
let protos = Array3::<f32>::zeros((16, 16, num_protos));
let mut boxes = Vec::with_capacity(10);
let mut masks = Vec::with_capacity(10);
decode_yolo_end_to_end_segdet_float(
output.view(),
protos.view(),
0.5,
&mut boxes,
&mut masks,
)
.unwrap();
assert_eq!(boxes.len(), 1);
assert_eq!(masks.len(), 1);
assert_eq!(boxes[0].label, 1);
assert!((boxes[0].score - 0.9).abs() < 0.01);
}
// Single detection (rows 0-5 = 0.2, 0.2, 0.8, 0.8, 0.95, 3.0): expects the
// returned mask's xmin/ymin/xmax/ymax to agree with the returned box's bbox
// to within 0.01.
#[test]
fn test_end_to_end_segdet_mask_coordinates() {
let num_protos = 32;
let num_features = 6 + num_protos;
let mut data = vec![0.0f32; num_features];
data[0] = 0.2; data[1] = 0.2; data[2] = 0.8; data[3] = 0.8; data[4] = 0.95; data[5] = 3.0;
let output = Array2::from_shape_vec((num_features, 1), data).unwrap();
let protos = Array3::<f32>::zeros((16, 16, num_protos));
let mut boxes = Vec::with_capacity(10);
let mut masks = Vec::with_capacity(10);
decode_yolo_end_to_end_segdet_float(
output.view(),
protos.view(),
0.5,
&mut boxes,
&mut masks,
)
.unwrap();
assert_eq!(boxes.len(), 1);
assert_eq!(masks.len(), 1);
assert!((masks[0].xmin - boxes[0].bbox.xmin).abs() < 0.01);
assert!((masks[0].ymin - boxes[0].bbox.ymin).abs() < 0.01);
assert!((masks[0].xmax - boxes[0].bbox.xmax).abs() < 0.01);
assert!((masks[0].ymax - boxes[0].bbox.ymax).abs() < 0.01);
}
// A (6 + 32, 0) tensor (no columns) must decode successfully to zero boxes
// and zero masks.
#[test]
fn test_end_to_end_segdet_empty_output() {
let num_protos = 32;
let output = Array2::<f32>::zeros((6 + num_protos, 0));
let protos = Array3::<f32>::zeros((16, 16, num_protos));
let mut boxes = Vec::with_capacity(10);
let mut masks = Vec::with_capacity(10);
decode_yolo_end_to_end_segdet_float(
output.view(),
protos.view(),
0.5,
&mut boxes,
&mut masks,
)
.unwrap();
assert_eq!(boxes.len(), 0);
assert_eq!(masks.len(), 0);
}
// Five detections all carry 0.9 in row 4, but both output Vecs are created
// with capacity 2: expects exactly two boxes and two masks.
#[test]
fn test_end_to_end_segdet_capacity_limit() {
let num_protos = 32;
let num_detections = 5;
let num_features = 6 + num_protos;
let mut data = vec![0.0f32; num_features * num_detections];
for i in 0..num_detections {
data[i] = 0.1 * (i as f32); data[num_detections + i] = 0.1 * (i as f32); data[2 * num_detections + i] = 0.1 * (i as f32) + 0.2; data[3 * num_detections + i] = 0.1 * (i as f32) + 0.2; data[4 * num_detections + i] = 0.9; data[5 * num_detections + i] = i as f32; }
let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
let protos = Array3::<f32>::zeros((16, 16, num_protos));
let mut boxes = Vec::with_capacity(2); let mut masks = Vec::with_capacity(2);
decode_yolo_end_to_end_segdet_float(
output.view(),
protos.view(),
0.5,
&mut boxes,
&mut masks,
)
.unwrap();
assert_eq!(boxes.len(), 2);
assert_eq!(masks.len(), 2);
}
// A tensor with only 6 rows (no coefficient rows) must be rejected with an
// `InvalidShape` error whose message contains "at least 7 rows".
#[test]
fn test_end_to_end_segdet_invalid_shape_too_few_rows() {
let output = Array2::<f32>::zeros((6, 3));
let protos = Array3::<f32>::zeros((16, 16, 32));
let mut boxes = Vec::with_capacity(10);
let mut masks = Vec::with_capacity(10);
let result = decode_yolo_end_to_end_segdet_float(
output.view(),
protos.view(),
0.5,
&mut boxes,
&mut masks,
);
assert!(result.is_err());
assert!(matches!(
result,
Err(crate::DecoderError::InvalidShape(s)) if s.contains("at least 7 rows")
));
}
// The output has 6 + 16 rows while the protos have 32 channels: expects an
// `InvalidShape` error whose message contains "doesn't match protos count".
#[test]
fn test_end_to_end_segdet_invalid_shape_protos_mismatch() {
let num_protos = 32;
let output = Array2::<f32>::zeros((6 + 16, 3)); let protos = Array3::<f32>::zeros((16, 16, num_protos));
let mut boxes = Vec::with_capacity(10);
let mut masks = Vec::with_capacity(10);
let result = decode_yolo_end_to_end_segdet_float(
output.view(),
protos.view(),
0.5,
&mut boxes,
&mut masks,
);
assert!(result.is_err());
assert!(matches!(
result,
Err(crate::DecoderError::InvalidShape(s)) if s.contains("doesn't match protos count")
));
}
// Same data as `test_end_to_end_segdet_basic`, but the tensor is sliced into
// rows 0..4, 4..5, 5..6 and 6.. and handed to the split decoder as four
// separate views. Expects the same single detection (label 1, score 0.9).
#[test]
fn test_split_end_to_end_segdet_basic() {
let num_protos = 32;
let num_detections = 2;
let num_features = 6 + num_protos;
let mut data = vec![0.0f32; num_features * num_detections];
data[0] = 0.1; data[1] = 0.5; data[num_detections] = 0.1; data[num_detections + 1] = 0.5; data[2 * num_detections] = 0.4; data[2 * num_detections + 1] = 0.9; data[3 * num_detections] = 0.4; data[3 * num_detections + 1] = 0.9; data[4 * num_detections] = 0.9; data[4 * num_detections + 1] = 0.3; data[5 * num_detections] = 1.0; data[5 * num_detections + 1] = 2.0; for i in 6..num_features {
data[i * num_detections] = 0.1;
data[i * num_detections + 1] = 0.1;
}
let output = Array2::from_shape_vec((num_features, num_detections), data).unwrap();
let box_coords = output.slice(s![..4, ..]);
let scores = output.slice(s![4..5, ..]);
let classes = output.slice(s![5..6, ..]);
let mask_coeff = output.slice(s![6.., ..]);
let protos = Array3::<f32>::zeros((16, 16, num_protos));
let mut boxes = Vec::with_capacity(10);
let mut masks = Vec::with_capacity(10);
decode_yolo_split_end_to_end_segdet_float(
box_coords,
scores,
classes,
mask_coeff,
protos.view(),
0.5,
&mut boxes,
&mut masks,
)
.unwrap();
assert_eq!(boxes.len(), 1);
assert_eq!(masks.len(), 1);
assert_eq!(boxes[0].label, 1);
assert!((boxes[0].score - 0.9).abs() < 0.01);
}
// Thresholds a 4x4x1 segmentation at 128 and spot-checks cells on both sides
// of the threshold, including the boundary: 128 maps to 1 (`[1, 2]`) and 127
// maps to 0 (`[2, 1]`).
#[test]
fn test_segmentation_to_mask_basic() {
let data: Vec<u8> = vec![
100, 200, 50, 150, 10, 255, 128, 64, 0, 127, 128, 255, 64, 64, 192, 192, ];
let segmentation = Array3::from_shape_vec((4, 4, 1), data).unwrap();
let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
assert_eq!(mask[[0, 0]], 0); assert_eq!(mask[[0, 1]], 1); assert_eq!(mask[[0, 2]], 0); assert_eq!(mask[[0, 3]], 1); assert_eq!(mask[[1, 1]], 1); assert_eq!(mask[[1, 2]], 1); assert_eq!(mask[[2, 0]], 0); assert_eq!(mask[[2, 1]], 0); }
// Every element is 255 (above the 128 threshold): the whole mask must be 1.
#[test]
fn test_segmentation_to_mask_all_above() {
let segmentation = Array3::from_elem((4, 4, 1), 255u8);
let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
assert!(mask.iter().all(|&x| x == 1));
}
// Every element is 64 (below the 128 threshold): the whole mask must be 0.
#[test]
fn test_segmentation_to_mask_all_below() {
let segmentation = Array3::from_elem((4, 4, 1), 64u8);
let mask = yolo_segmentation_to_mask(segmentation.view(), 128).unwrap();
assert!(mask.iter().all(|&x| x == 0));
}
// A segmentation with 3 channels instead of 1 must be rejected with an
// `InvalidShape` error whose message contains "(H, W, 1)".
#[test]
fn test_segmentation_to_mask_invalid_shape() {
let segmentation = Array3::from_elem((4, 4, 3), 128u8);
let result = yolo_segmentation_to_mask(segmentation.view(), 128);
assert!(result.is_err());
assert!(matches!(
result,
Err(crate::DecoderError::InvalidShape(s)) if s.contains("(H, W, 1)")
));
}
// A box whose max corner sits exactly on 1.0 must be accepted by `protobox`
// and yield a non-empty crop that keeps all 4 channels.
#[test]
fn test_protobox_clamps_edge_coordinates() {
let protos = Array3::<f32>::zeros((16, 16, 4));
let view = protos.view();
let roi = BoundingBox {
xmin: 0.5,
ymin: 0.5,
xmax: 1.0,
ymax: 1.0,
};
let result = protobox(&view, &roi);
assert!(result.is_ok(), "protobox should accept xmax=1.0");
let (cropped, _roi_norm) = result.unwrap();
assert!(cropped.shape()[0] > 0);
assert!(cropped.shape()[1] > 0);
assert_eq!(cropped.shape()[2], 4);
}
// Coordinates of 3.0 exceed `protobox`'s limit of 2.0: expects an
// `InvalidShape` error whose message contains "un-normalized".
#[test]
fn test_protobox_rejects_wildly_out_of_range() {
let protos = Array3::<f32>::zeros((16, 16, 4));
let view = protos.view();
let roi = BoundingBox {
xmin: 0.0,
ymin: 0.0,
xmax: 3.0,
ymax: 3.0,
};
let result = protobox(&view, &roi);
assert!(
matches!(result, Err(crate::DecoderError::InvalidShape(s)) if s.contains("un-normalized")),
"protobox should reject coords > NORM_LIMIT"
);
}
// Coordinates of 1.5 are above 1.0 but within the limit of 2.0: expects
// success, with the crop clamped to the full 16x16 extent.
#[test]
fn test_protobox_accepts_slightly_over_one() {
let protos = Array3::<f32>::zeros((16, 16, 4));
let view = protos.view();
let roi = BoundingBox {
xmin: 0.0,
ymin: 0.0,
xmax: 1.5,
ymax: 1.5,
};
let result = protobox(&view, &roi);
assert!(
result.is_ok(),
"protobox should accept coords <= NORM_LIMIT (2.0)"
);
let (cropped, _roi_norm) = result.unwrap();
assert_eq!(cropped.shape()[0], 16);
assert_eq!(cropped.shape()[1], 16);
}
// Runs the float proto path on 100 proposals whose box rows hold values well
// above 1.0 (320, 320, 50, 50) and whose row 4 is 0.9. Expects at least one
// output box, one coefficient vector per output box, and 32 coefficients in
// each vector.
#[test]
fn test_segdet_float_proto_no_panic() {
let num_proposals = 100; let num_classes = 80;
let num_mask_coeffs = 32;
let rows = 4 + num_classes + num_mask_coeffs;
let mut data = vec![0.0f32; rows * num_proposals];
for i in 0..num_proposals {
let row = |r: usize| r * num_proposals + i;
data[row(0)] = 320.0; data[row(1)] = 320.0; data[row(2)] = 50.0; data[row(3)] = 50.0; data[row(4)] = 0.9; }
let boxes = ndarray::Array2::from_shape_vec((rows, num_proposals), data).unwrap();
let protos = ndarray::Array3::<f32>::zeros((160, 160, num_mask_coeffs));
let mut output_boxes = Vec::with_capacity(300);
let proto_data = impl_yolo_segdet_float_proto::<XYWH, _, _>(
boxes.view(),
protos.view(),
0.5,
0.7,
Some(Nms::default()),
&mut output_boxes,
);
assert!(!output_boxes.is_empty());
assert_eq!(proto_data.mask_coefficients.len(), output_boxes.len());
for coeffs in &proto_data.mask_coefficients {
assert_eq!(coeffs.len(), num_mask_coeffs);
}
}
// Builds 50,000 identical boxes with slowly decreasing scores (from 0.99),
// passes `None` for NMS, and expects the result to be cut down to
// `MAX_NMS_CANDIDATES`, with the first survivor's score above 0.98.
#[test]
fn test_pre_nms_cap_truncates_excess_candidates() {
let n: usize = 50_000;
let num_classes = 1;
let mut boxes_data = Vec::with_capacity(n * 4);
let mut scores_data = Vec::with_capacity(n * num_classes);
for i in 0..n {
boxes_data.extend_from_slice(&[0.1f32, 0.1, 0.5, 0.5]);
scores_data.push(0.99 - (i as f32) * 1e-7);
}
let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();
let result = impl_yolo_segdet_get_boxes::<XYXY, _, _>(
boxes.view(),
scores.view(),
0.1,
1.0,
None, usize::MAX, );
assert_eq!(
result.len(),
crate::yolo::MAX_NMS_CANDIDATES,
"pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
result.len()
);
let top_score = result[0].0.score;
assert!(
top_score > 0.98,
"highest-ranked survivor should have the largest score, got {top_score}"
);
}
// Quantized counterpart of the pre-NMS cap test: 50,000 i8 boxes with u8
// scores cycling between 51 and 250, `None` for NMS. Expects the result to
// be cut down to `MAX_NMS_CANDIDATES`.
#[test]
fn test_pre_nms_cap_truncates_excess_candidates_quant() {
use crate::Quantization;
let n: usize = 50_000;
let num_classes = 1;
let boxes_data = (0..n).flat_map(|_| [10i8, 10, 50, 50]).collect::<Vec<_>>();
let boxes = Array2::from_shape_vec((n, 4), boxes_data).unwrap();
let quant_boxes = Quantization {
scale: 0.01,
zero_point: 0,
};
let scores_data: Vec<u8> = (0..n)
.map(|i| 250u8.saturating_sub((i % 200) as u8))
.collect();
let scores = Array2::from_shape_vec((n, num_classes), scores_data).unwrap();
let quant_scores = Quantization {
scale: 0.00392,
zero_point: 0,
};
let result = impl_yolo_split_segdet_quant_get_boxes::<XYXY, _, _>(
(boxes.view(), quant_boxes),
(scores.view(), quant_scores),
0.1,
1.0,
None,
usize::MAX,
);
assert_eq!(
result.len(),
crate::yolo::MAX_NMS_CANDIDATES,
"quant path pre-NMS cap should truncate to MAX_NMS_CANDIDATES; got {}",
result.len()
);
}
// Checks that each surviving detection is paired with its own anchor's mask
// coefficients. Three anchors are built: anchor 0 (box values around 0.2,
// score 0.9, coefficients +3), anchor 1 (no score set, so it stays at 0.0)
// and anchor 2 (box values around 0.8, score 0.8, coefficients -3). With
// all-ones protos the test expects two detections; the one centred left of
// 0.3 must have a high mean mask value (> 200) and the one centred right of
// 0.7 a low mean (< 50). Any other centre is a failure.
#[test]
fn segdet_combined_tensor_pairs_detection_with_matching_mask_row() {
let nc = 2; let nm = 2; let n = 3; let feat = 4 + nc + nm;
let mut data = vec![0.0f32; feat * n];
let set = |d: &mut [f32], r: usize, c: usize, v: f32| d[r * n + c] = v;
set(&mut data, 0, 0, 0.2);
set(&mut data, 1, 0, 0.2);
set(&mut data, 2, 0, 0.1);
set(&mut data, 3, 0, 0.1);
set(&mut data, 0, 1, 0.5);
set(&mut data, 1, 1, 0.5);
set(&mut data, 2, 1, 0.1);
set(&mut data, 3, 1, 0.1);
set(&mut data, 0, 2, 0.8);
set(&mut data, 1, 2, 0.8);
set(&mut data, 2, 2, 0.1);
set(&mut data, 3, 2, 0.1);
set(&mut data, 4, 0, 0.9);
set(&mut data, 4, 2, 0.8);
set(&mut data, 6, 0, 3.0);
set(&mut data, 7, 0, 3.0);
set(&mut data, 6, 2, -3.0);
set(&mut data, 7, 2, -3.0);
let output = Array2::from_shape_vec((feat, n), data).unwrap();
let protos = Array3::<f32>::from_elem((8, 8, nm), 1.0);
let mut boxes: Vec<DetectBox> = Vec::with_capacity(10);
let mut masks: Vec<Segmentation> = Vec::with_capacity(10);
decode_yolo_segdet_float(
output.view(),
protos.view(),
0.5,
0.5,
Some(Nms::ClassAgnostic),
&mut boxes,
&mut masks,
)
.unwrap();
assert_eq!(
boxes.len(),
2,
"two anchors above threshold should survive (a0 score=0.9, a2 score=0.8); got {}",
boxes.len()
);
for (b, m) in boxes.iter().zip(masks.iter()) {
let cx = (b.bbox.xmin + b.bbox.xmax) * 0.5;
let mean = {
let s = &m.segmentation;
let total: u32 = s.iter().map(|&v| v as u32).sum();
total as f32 / s.len() as f32
};
if cx < 0.3 {
assert!(
mean > 200.0,
"anchor 0 detection (centre {cx:.2}) should have high-value mask; got mean {mean}"
);
} else if cx > 0.7 {
assert!(
mean < 50.0,
"anchor 2 detection (centre {cx:.2}) should have low-value mask; got mean {mean}"
);
} else {
panic!("unexpected detection centre {cx:.2}");
}
}
}
}