use std::borrow::Cow;
use ndarray::{Array3, ArrayView2, ArrayView3, Axis};
use tracing::trace_span;
use crate::per_scale::outputs::{BufferRef, DecodedOutputsRef, ProtosView};
use crate::yolo::{
extract_proto_data_float, impl_yolo_segdet_get_boxes, impl_yolo_split_segdet_process_masks,
};
use crate::{DecoderError, DetectBox, Nms, ProtoData, Segmentation, XYWH};
/// A 3-D `f32` protos tensor that is either borrowed or owned.
///
/// `widen_to_f32` produces `Borrowed` when the source protos are already
/// `f32` and `Owned` when they had to be converted from `f16`.
enum ProtosF32<'a> {
/// A view into `f32` data that lives elsewhere.
Borrowed(ArrayView3<'a, f32>),
/// A freshly allocated `f32` array.
Owned(Array3<f32>),
}
impl ProtosF32<'_> {
fn view(&self) -> ArrayView3<'_, f32> {
match self {
Self::Borrowed(v) => v.reborrow(),
Self::Owned(a) => a.view(),
}
}
}
/// The decoded per-scale outputs with every buffer available as `f32`.
///
/// Built by `widen_to_f32`; each buffer is borrowed when the source was
/// already `f32` and owned when it was converted from `f16`.
struct WidenedF32<'a> {
/// Flat box buffer; callers in this file view it as `(n, 4)`.
boxes: Cow<'a, [f32]>,
/// Flat score buffer; callers in this file view it as `(n, nc)`.
scores: Cow<'a, [f32]>,
/// Optional flat mask-coefficient buffer; viewed as `(n, nm)` when present.
mask_coefs: Option<Cow<'a, [f32]>>,
/// Optional 3-D protos tensor.
protos: Option<ProtosF32<'a>>,
/// Copied from `DecodedOutputsRef::total_anchors`.
n: usize,
/// Copied from `DecodedOutputsRef::num_classes`.
nc: usize,
/// Copied from `DecodedOutputsRef::num_mask_coefs`.
nm: usize,
}
fn widen_to_f32<'a>(decoded: &'a DecodedOutputsRef<'a>) -> WidenedF32<'a> {
let kind = match &decoded.boxes {
BufferRef::F32(_) => "f32_borrow",
BufferRef::F16(_) => "f16_widen",
};
let _span = trace_span!("per_scale_bridge::widen_to_f32", kind = kind).entered();
let n = decoded.total_anchors;
let nc = decoded.num_classes;
let nm = decoded.num_mask_coefs;
let boxes: Cow<'a, [f32]> = match &decoded.boxes {
BufferRef::F32(s) => Cow::Borrowed(*s),
BufferRef::F16(s) => Cow::Owned(s.iter().map(|v| v.to_f32()).collect()),
};
let scores: Cow<'a, [f32]> = match &decoded.scores {
BufferRef::F32(s) => Cow::Borrowed(*s),
BufferRef::F16(s) => Cow::Owned(s.iter().map(|v| v.to_f32()).collect()),
};
let mask_coefs = decoded.mask_coefs.as_ref().map(|b| match b {
BufferRef::F32(s) => Cow::Borrowed(*s),
BufferRef::F16(s) => Cow::Owned(s.iter().map(|v| v.to_f32()).collect()),
});
let protos = decoded.protos.as_ref().map(|p| match p {
ProtosView::F32(a) => ProtosF32::Borrowed(a.index_axis(Axis(0), 0)),
ProtosView::F16(a) => ProtosF32::Owned(a.index_axis(Axis(0), 0).mapv(|v| v.to_f32())),
});
WidenedF32 {
boxes,
scores,
mask_coefs,
protos,
n,
nc,
nm,
}
}
#[allow(clippy::too_many_arguments)]
pub(super) fn per_scale_to_proto_data<'a>(
decoded: &'a DecodedOutputsRef<'a>,
output_boxes: &mut Vec<DetectBox>,
iou_threshold: f32,
score_threshold: f32,
nms_mode: Option<Nms>,
pre_nms_top_k: usize,
max_det: usize,
normalized: Option<bool>,
input_dims: Option<(usize, usize)>,
) -> Result<Option<ProtoData>, DecoderError> {
let _span = trace_span!("per_scale_bridge::per_scale_to_proto_data").entered();
let widened = widen_to_f32(decoded);
let n = widened.n;
let nc = widened.nc;
let nm = widened.nm;
let boxes_view = ArrayView2::<f32>::from_shape((n, 4), widened.boxes.as_ref())
.map_err(|e| DecoderError::Internal(format!("per_scale boxes view: {e}")))?;
let scores_view = ArrayView2::<f32>::from_shape((n, nc), widened.scores.as_ref())
.map_err(|e| DecoderError::Internal(format!("per_scale scores view: {e}")))?;
output_boxes.clear();
let mut det_indices = {
let _s = trace_span!("per_scale_bridge::nms_get_boxes").entered();
impl_yolo_segdet_get_boxes::<XYWH, _, _>(
boxes_view,
scores_view,
score_threshold,
iou_threshold,
nms_mode,
pre_nms_top_k,
max_det,
)
};
crate::yolo::maybe_normalize_boxes_in_place(&mut det_indices, normalized, input_dims);
match (widened.mask_coefs.as_ref(), widened.protos.as_ref()) {
(Some(mc), Some(pr)) => {
let _s = trace_span!("per_scale_bridge::extract_proto_data").entered();
let mc_view = ArrayView2::<f32>::from_shape((n, nm), mc.as_ref())
.map_err(|e| DecoderError::Internal(format!("per_scale mc view: {e}")))?;
let pr_view = pr.view();
let proto_data = extract_proto_data_float(det_indices, mc_view, pr_view, output_boxes);
Ok(Some(proto_data))
}
_ => {
for (db, _) in det_indices {
output_boxes.push(db);
}
Ok(None)
}
}
}
#[allow(clippy::too_many_arguments)]
pub(super) fn per_scale_to_masks<'a>(
decoded: &'a DecodedOutputsRef<'a>,
output_boxes: &mut Vec<DetectBox>,
output_masks: &mut Vec<Segmentation>,
iou_threshold: f32,
score_threshold: f32,
nms_mode: Option<Nms>,
pre_nms_top_k: usize,
max_det: usize,
normalized: Option<bool>,
input_dims: Option<(usize, usize)>,
) -> Result<(), DecoderError> {
let _span = trace_span!("per_scale_bridge::per_scale_to_masks").entered();
let widened = widen_to_f32(decoded);
let n = widened.n;
let nc = widened.nc;
let nm = widened.nm;
let boxes_view = ArrayView2::<f32>::from_shape((n, 4), widened.boxes.as_ref())
.map_err(|e| DecoderError::Internal(format!("per_scale boxes view: {e}")))?;
let scores_view = ArrayView2::<f32>::from_shape((n, nc), widened.scores.as_ref())
.map_err(|e| DecoderError::Internal(format!("per_scale scores view: {e}")))?;
output_boxes.clear();
output_masks.clear();
let mut det_indices = {
let _s = trace_span!("per_scale_bridge::nms_get_boxes").entered();
impl_yolo_segdet_get_boxes::<XYWH, _, _>(
boxes_view,
scores_view,
score_threshold,
iou_threshold,
nms_mode,
pre_nms_top_k,
max_det,
)
};
crate::yolo::maybe_normalize_boxes_in_place(&mut det_indices, normalized, input_dims);
match (widened.mask_coefs.as_ref(), widened.protos.as_ref()) {
(Some(mc), Some(pr)) => {
let _s = trace_span!("per_scale_bridge::process_masks").entered();
let mc_view = ArrayView2::<f32>::from_shape((n, nm), mc.as_ref())
.map_err(|e| DecoderError::Internal(format!("per_scale mc view: {e}")))?;
let pr_view = pr.view();
impl_yolo_split_segdet_process_masks(
det_indices,
mc_view,
pr_view,
output_boxes,
output_masks,
)
}
_ => {
for (db, _) in det_indices {
output_boxes.push(db);
}
Ok(())
}
}
}