use crate::NmsOptions;
use aligned_vec::{AVec, ConstAlign};
use alloc::vec::Vec;
use burn_tensor::{Int, Shape, Tensor, TensorData, backend::Backend};
use macerator::{Scalar, Simd, Vector, vload};
/// Greedy non-maximum suppression over axis-aligned boxes, computed on the host.
///
/// * `boxes` — shape `[N, 4]`, each row laid out as `[x1, y1, x2, y2]`.
/// * `scores` — shape `[N]`, one score per box.
/// * `options` — score threshold, IoU threshold and output cap (`max_output_boxes == 0`
///   means no cap).
///
/// Returns the indices (into `boxes`/`scores`) of the kept boxes, highest score first.
/// Boxes scoring below `options.score_threshold` are discarded up front. A remaining box
/// is suppressed when its IoU with an already-kept box is strictly greater than
/// `options.iou_threshold`.
///
/// Both tensors are read back to the host and converted to `f32` before processing.
///
/// # Panics
///
/// For `N > 0`, panics if `boxes` is not `[N, 4]` or if `scores` does not have exactly
/// `N` entries.
pub fn nms<B: Backend>(
    boxes: Tensor<B, 2>,
    scores: Tensor<B, 1>,
    options: NmsOptions,
) -> Tensor<B, 1, Int> {
    let device = boxes.device();
    let [n_boxes, n_coords] = boxes.shape().dims();
    if n_boxes == 0 {
        return Tensor::<B, 1, Int>::empty([0], &device);
    }
    let [n_scores] = scores.shape().dims();
    assert_eq!(
        n_coords, 4,
        "nms: boxes must have shape [N, 4], got [{n_boxes}, {n_coords}]"
    );
    assert_eq!(
        n_scores, n_boxes,
        "nms: expected one score per box ({n_boxes}), got {n_scores}"
    );

    // Convert to f32 before extracting: `to_vec::<f32>` on its own rejects data of any
    // other float dtype (e.g. from f16/bf16/f64 backends).
    let boxes_vec: Vec<f32> = boxes
        .to_data()
        .convert::<f32>()
        .to_vec()
        .expect("box data was converted to f32");
    let scores_vec: Vec<f32> = scores
        .to_data()
        .convert::<f32>()
        .to_vec()
        .expect("score data was converted to f32");

    let keep = nms_vec(boxes_vec, scores_vec, options);
    let n_kept = keep.len();
    Tensor::<B, 1, Int>::from_data(TensorData::new(keep, Shape::new([n_kept])), &device)
}
/// Greedy NMS over host-side buffers.
///
/// `boxes_vec` is a flat buffer of `[x1, y1, x2, y2]` quadruples, one per entry of
/// `scores_vec`.
///
/// Returns the original indices of the kept boxes, highest score first. Equal scores keep
/// the lower index first, because the sort is stable. Scores below
/// `options.score_threshold` are dropped, and so are NaN scores, since they fail `>=`.
/// Selection stops once `options.max_output_boxes` boxes are kept, if that is non-zero.
///
/// # Panics
///
/// Panics if `boxes_vec` holds fewer than `4 * scores_vec.len()` values, or if a kept index
/// does not fit in `i32`.
fn nms_vec(boxes_vec: Vec<f32>, scores_vec: Vec<f32>, options: NmsOptions) -> Vec<i32> {
    let n_boxes = scores_vec.len();
    if n_boxes == 0 {
        return vec![];
    }
    assert!(
        boxes_vec.len() >= n_boxes * 4,
        "nms: expected at least {} box coordinates for {} scores, got {}",
        n_boxes * 4,
        n_boxes,
        boxes_vec.len()
    );

    let mut filtered_indices: Vec<usize> = Vec::with_capacity(n_boxes);
    filtered_indices.extend(
        scores_vec
            .iter()
            .enumerate()
            .filter(|&(_, &score)| score >= options.score_threshold)
            .map(|(idx, _)| idx),
    );
    let n_filtered = filtered_indices.len();
    if n_filtered == 0 {
        return vec![];
    }
    // Descending by score; the stable sort keeps ties in original index order.
    filtered_indices.sort_by(|&a, &b| scores_vec[b].total_cmp(&scores_vec[a]));

    // Structure-of-arrays layout in one 64-byte-aligned allocation. Each column is padded
    // to a multiple of FLOATS_PER_ALIGN, so every column starts 64-byte aligned and the
    // SIMD kernel never needs a partial chunk. Padding entries are all-zero boxes.
    const ALIGN: usize = 64;
    const FLOATS_PER_ALIGN: usize = ALIGN / size_of::<f32>();
    let stride = n_filtered.div_ceil(FLOATS_PER_ALIGN) * FLOATS_PER_ALIGN;
    let mut buf: AVec<f32, ConstAlign<ALIGN>> = AVec::with_capacity(ALIGN, stride * 5);
    buf.resize(stride * 5, 0.0);
    let (x1s, rest) = buf.split_at_mut(stride);
    let (y1s, rest) = rest.split_at_mut(stride);
    let (x2s, rest) = rest.split_at_mut(stride);
    let (y2s, areas) = rest.split_at_mut(stride);
    for (j, &orig_idx) in filtered_indices.iter().enumerate() {
        let base = orig_idx * 4;
        let (x1, y1, x2, y2) = (
            boxes_vec[base],
            boxes_vec[base + 1],
            boxes_vec[base + 2],
            boxes_vec[base + 3],
        );
        x1s[j] = x1;
        y1s[j] = y1;
        x2s[j] = x2;
        y2s[j] = y2;
        areas[j] = (x2 - x1) * (y2 - y1);
    }

    let mut suppressed = vec![false; stride];
    let mut keep = Vec::new();
    for i in 0..n_filtered {
        if suppressed[i] {
            continue;
        }
        suppressed[i] = true;
        keep.push(i32::try_from(filtered_indices[i]).expect("nms: box index exceeds i32::MAX"));
        if options.max_output_boxes > 0 && keep.len() >= options.max_output_boxes {
            break;
        }
        // Every index <= i is already decided (kept or suppressed), so only later boxes
        // need IoU tests. Starting at the FLOATS_PER_ALIGN-aligned chunk that contains
        // i + 1 keeps the sub-slices 64-byte aligned and their length a multiple of
        // FLOATS_PER_ALIGN, exactly like the full columns.
        let start = (i + 1) / FLOATS_PER_ALIGN * FLOATS_PER_ALIGN;
        suppress_overlapping(
            x1s[i],
            y1s[i],
            x2s[i],
            y2s[i],
            areas[i],
            &x1s[start..],
            &y1s[start..],
            &x2s[start..],
            &y2s[start..],
            &areas[start..],
            &mut suppressed[start..],
            stride - start,
            options.iou_threshold,
        );
    }
    keep
}
#[allow(clippy::too_many_arguments)]
#[inline(always)]
#[macerator::with_simd]
fn suppress_overlapping<'a, S: Simd>(
ref_x1: f32,
ref_y1: f32,
ref_x2: f32,
ref_y2: f32,
ref_area: f32,
x1s: &'a [f32],
y1s: &'a [f32],
x2s: &'a [f32],
y2s: &'a [f32],
areas: &'a [f32],
suppressed: &'a mut [bool],
n_boxes: usize, threshold: f32,
) where
'a: 'a,
{
let lanes = f32::lanes::<S>();
let ref_x1_v: Vector<S, f32> = ref_x1.splat();
let ref_y1_v: Vector<S, f32> = ref_y1.splat();
let ref_x2_v: Vector<S, f32> = ref_x2.splat();
let ref_y2_v: Vector<S, f32> = ref_y2.splat();
let ref_area_v: Vector<S, f32> = ref_area.splat();
let thresh_v: Vector<S, f32> = threshold.splat();
let zero_v: Vector<S, f32> = 0.0f32.splat();
let mut i = 0;
let mut mask_buf = core::mem::MaybeUninit::<[bool; 16]>::uninit();
while i + lanes <= n_boxes {
let all_suppressed = unsafe {
match lanes {
4 => *(suppressed.as_ptr().add(i) as *const u32) == 0x01010101,
8 => *(suppressed.as_ptr().add(i) as *const u64) == 0x0101010101010101,
16 => {
*(suppressed.as_ptr().add(i) as *const u128)
== 0x01010101010101010101010101010101
}
_ => unreachable!(),
}
};
if !all_suppressed {
let x1_v: Vector<S, f32> = unsafe { vload(x1s.as_ptr().add(i)) };
let y1_v: Vector<S, f32> = unsafe { vload(y1s.as_ptr().add(i)) };
let x2_v: Vector<S, f32> = unsafe { vload(x2s.as_ptr().add(i)) };
let y2_v: Vector<S, f32> = unsafe { vload(y2s.as_ptr().add(i)) };
let area_v: Vector<S, f32> = unsafe { vload(areas.as_ptr().add(i)) };
let xx1 = ref_x1_v.max(x1_v);
let yy1 = ref_y1_v.max(y1_v);
let xx2 = ref_x2_v.min(x2_v);
let yy2 = ref_y2_v.min(y2_v);
let w = (xx2 - xx1).max(zero_v);
let h = (yy2 - yy1).max(zero_v);
let inter = w * h;
let union = ref_area_v + area_v - inter;
let iou = inter / union;
let suppress_mask = iou.gt(thresh_v);
unsafe { f32::mask_store_as_bool::<S>(mask_buf.as_mut_ptr().cast(), suppress_mask) };
let mask_buf = unsafe { mask_buf.assume_init() };
for k in 0..lanes {
if mask_buf[k] {
suppressed[i + k] = true;
}
}
}
i += lanes;
}
}