#[cfg(target_arch = "aarch64")]
use crate::arg_max_i8;
use crate::{
arg_max, float::jaccard, BBoxTypeTrait, BoundingBox, DetectBoxQuantized, Quantization,
};
use ndarray::{
parallel::prelude::{IntoParallelIterator, ParallelIterator as _},
Array1, ArrayView1, ArrayView2, Zip,
};
use num_traits::{AsPrimitive, PrimInt};
use rayon::slice::ParallelSliceMut;
/// Folds one class column into the running per-candidate maximum, 16 lanes at a time.
///
/// For every candidate `i < n`: if `col[i] >= max[i]`, then `max[i] = col[i]` and
/// `class[i] = class_idx`. Because the comparison is `>=`, ties resolve in favour of the
/// class processed later. Bytes are compared as `i8` when `signed` is true and as `u8`
/// otherwise. Elements past the last full 16-lane vector are handled by a scalar tail.
///
/// # Safety
///
/// * `col_ptr` must be valid for reads of `n` bytes.
/// * `max_ptr` and `class_ptr` must each be valid for reads and writes of `n` bytes.
/// * The three regions must not overlap.
///
/// No alignment beyond 1 byte is required, because only byte-element NEON loads and stores
/// are used.
#[cfg(target_arch = "aarch64")]
unsafe fn column_max_update_neon(
    col_ptr: *const u8,
    max_ptr: *mut u8,
    class_ptr: *mut u8,
    n: usize,
    class_idx: u8,
    signed: bool,
) {
    use std::arch::aarch64::*;

    const LANES: usize = 16;
    let class_vec = vdupq_n_u8(class_idx);
    // End of the region covered by whole 16-byte vectors; [vector_end, n) is the scalar tail.
    let vector_end = n - n % LANES;

    if signed {
        for offset in (0..vector_end).step_by(LANES) {
            let col = vld1q_s8(col_ptr.add(offset).cast::<i8>());
            let cur_max = vld1q_s8(max_ptr.add(offset).cast::<i8>());
            // All-ones lanes where the incoming value wins (ties included).
            let take = vcgeq_s8(col, cur_max);
            vst1q_s8(max_ptr.add(offset).cast::<i8>(), vmaxq_s8(col, cur_max));
            let cur_class = vld1q_u8(class_ptr.add(offset));
            vst1q_u8(class_ptr.add(offset), vbslq_u8(take, class_vec, cur_class));
        }
        for i in vector_end..n {
            let val = *col_ptr.add(i).cast::<i8>();
            let slot = max_ptr.add(i).cast::<i8>();
            if val >= *slot {
                *slot = val;
                *class_ptr.add(i) = class_idx;
            }
        }
    } else {
        for offset in (0..vector_end).step_by(LANES) {
            let col = vld1q_u8(col_ptr.add(offset));
            let cur_max = vld1q_u8(max_ptr.add(offset));
            // All-ones lanes where the incoming value wins (ties included).
            let take = vcgeq_u8(col, cur_max);
            vst1q_u8(max_ptr.add(offset), vmaxq_u8(col, cur_max));
            let cur_class = vld1q_u8(class_ptr.add(offset));
            vst1q_u8(class_ptr.add(offset), vbslq_u8(take, class_vec, cur_class));
        }
        for i in vector_end..n {
            let val = *col_ptr.add(i);
            let slot = max_ptr.add(i);
            if val >= *slot {
                *slot = val;
                *class_ptr.add(i) = class_idx;
            }
        }
    }
}
/// Returns the maximum value of `score` and the index at which it occurs.
///
/// On aarch64, a contiguous row of 1-byte signed integers (i.e. `i8`) is routed to the
/// specialised `arg_max_i8`. Every other case, including non-contiguous rows, falls back to
/// the generic `arg_max`.
#[inline(always)]
fn fast_arg_max<T: PrimInt + Copy>(score: ArrayView1<T>) -> (T, usize) {
    #[cfg(target_arch = "aarch64")]
    {
        // A one-byte integer type whose minimum is below zero is `i8` among the primitive ints.
        let is_signed_byte = std::mem::size_of::<T>() == 1 && T::min_value() < T::zero();
        if is_signed_byte {
            if let Some(slice) = score.as_slice() {
                // SAFETY: `T` is exactly one byte wide, so `slice.len()` elements of `T` span
                // exactly `slice.len()` initialised bytes, and every byte is a valid `i8`.
                let bytes =
                    unsafe { std::slice::from_raw_parts(slice.as_ptr().cast::<i8>(), slice.len()) };
                let (max_val, idx) = arg_max_i8(bytes);
                // A checked cast instead of `transmute_copy`: for a one-byte signed `T` every
                // `i8` value is representable.
                let max_val = <T as num_traits::NumCast>::from(max_val)
                    .expect("i8 value must be representable in a one-byte signed integer type");
                return (max_val, idx);
            }
        }
    }
    arg_max(score)
}
/// Decodes quantized detections from per-candidate class scores and boxes.
///
/// For each candidate row, the best class is chosen with `fast_arg_max`. The candidate is
/// kept when that score is `>= threshold`, and its box is decoded with
/// `B::ndarray_to_xyxy_dequant` using `quant_boxes`.
///
/// A `scores` view whose candidate axis has stride 1, but which is not a standard-layout
/// slice, is handed to `postprocess_boxes_quant_column_major` instead.
///
/// # Panics
///
/// Panics if `scores` and `boxes` have a different number of rows, or if `boxes` does not have
/// exactly 4 columns.
#[doc(hidden)]
pub fn postprocess_boxes_quant<
    B: BBoxTypeTrait,
    Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
    Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
>(
    threshold: Scores,
    boxes: ArrayView2<Boxes>,
    scores: ArrayView2<Scores>,
    quant_boxes: Quantization,
) -> Vec<DetectBoxQuantized<Scores>> {
    assert_eq!(scores.dim().0, boxes.dim().0);
    assert_eq!(boxes.dim().1, 4);

    let column_major = scores.strides()[0] == 1 && scores.as_slice().is_none();
    if column_major {
        return postprocess_boxes_quant_column_major::<B, _, _>(
            threshold,
            boxes,
            scores,
            quant_boxes,
        );
    }

    Zip::from(scores.rows())
        .and(boxes.rows())
        .into_par_iter()
        .filter_map(|(class_scores, coords)| {
            let (best, label) = fast_arg_max(class_scores);
            (best >= threshold).then(|| DetectBoxQuantized {
                label,
                score: best,
                bbox: BoundingBox::from(B::ndarray_to_xyxy_dequant(coords.view(), quant_boxes)),
            })
        })
        .collect()
}
fn postprocess_boxes_quant_column_major<
B: BBoxTypeTrait,
Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
>(
threshold: Scores,
boxes: ArrayView2<Boxes>,
scores: ArrayView2<Scores>,
quant_boxes: Quantization,
) -> Vec<DetectBoxQuantized<Scores>> {
let (n_candidates, n_classes) = scores.dim();
if n_classes > 255 {
return Zip::from(scores.rows())
.and(boxes.rows())
.into_par_iter()
.filter_map(|(score, bbox)| {
let (score_, label) = fast_arg_max(score);
if score_ < threshold {
return None;
}
let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
Some(DetectBoxQuantized {
label,
score: score_,
bbox: BoundingBox::from(bbox_quant),
})
})
.collect();
}
let mut max_scores = vec![Scores::min_value(); n_candidates];
let mut max_classes = vec![0u8; n_candidates];
for class_idx in 0..n_classes {
let col = scores.column(class_idx);
if let Some(slice) = col.as_slice() {
#[cfg(target_arch = "aarch64")]
{
if std::mem::size_of::<Scores>() == 1 {
unsafe {
column_max_update_neon(
slice.as_ptr() as *const u8,
max_scores.as_mut_ptr() as *mut u8,
max_classes.as_mut_ptr(),
n_candidates,
class_idx as u8,
Scores::min_value() < Scores::zero(),
);
}
continue;
}
}
for (i, &val) in slice.iter().enumerate() {
if val >= max_scores[i] {
max_scores[i] = val;
max_classes[i] = class_idx as u8;
}
}
} else {
for (i, &val) in col.iter().enumerate() {
if val >= max_scores[i] {
max_scores[i] = val;
max_classes[i] = class_idx as u8;
}
}
}
}
let boxes_buf: [Vec<Boxes>; 4] = if boxes.strides()[0] == 1 && boxes.as_slice().is_none() {
let mut cols: [Vec<Boxes>; 4] = [
vec![Boxes::zero(); n_candidates],
vec![Boxes::zero(); n_candidates],
vec![Boxes::zero(); n_candidates],
vec![Boxes::zero(); n_candidates],
];
for (dim, col_buf) in cols.iter_mut().enumerate() {
let col = boxes.column(dim);
if let Some(slice) = col.as_slice() {
col_buf.copy_from_slice(slice);
} else {
for (i, &val) in col.iter().enumerate() {
col_buf[i] = val;
}
}
}
cols
} else {
[vec![], vec![], vec![], vec![]]
};
let boxes_copied = !boxes_buf[0].is_empty();
let mut result = Vec::new();
for i in 0..n_candidates {
if max_scores[i] >= threshold {
let bbox_quant = if boxes_copied {
let raw = [
boxes_buf[0][i],
boxes_buf[1][i],
boxes_buf[2][i],
boxes_buf[3][i],
];
B::to_xyxy_dequant(&raw, quant_boxes)
} else {
B::ndarray_to_xyxy_dequant(boxes.row(i), quant_boxes)
};
result.push(DetectBoxQuantized {
label: max_classes[i] as usize,
score: max_scores[i],
bbox: BoundingBox::from(bbox_quant),
});
}
}
result
}
/// Decodes quantized detections and tags each survivor with its candidate (row) index.
///
/// For each candidate row, the best class is chosen with `fast_arg_max`. The candidate is
/// kept when that score is `>= threshold`, and its box is decoded with
/// `B::ndarray_to_xyxy_dequant` using `quant_boxes`. Each result is paired with the row it
/// came from.
///
/// A `scores` view whose candidate axis has stride 1, but which is not a standard-layout
/// slice, is handed to `postprocess_boxes_index_quant_column_major` instead.
///
/// # Panics
///
/// Panics if `scores` and `boxes` have a different number of rows, or if `boxes` does not have
/// exactly 4 columns.
#[doc(hidden)]
pub fn postprocess_boxes_index_quant<
    B: BBoxTypeTrait,
    Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
    Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
>(
    threshold: Scores,
    boxes: ArrayView2<Boxes>,
    scores: ArrayView2<Scores>,
    quant_boxes: Quantization,
) -> Vec<(DetectBoxQuantized<Scores>, usize)> {
    assert_eq!(scores.dim().0, boxes.dim().0);
    assert_eq!(boxes.dim().1, 4);

    let column_major = scores.strides()[0] == 1 && scores.as_slice().is_none();
    if column_major {
        return postprocess_boxes_index_quant_column_major::<B, _, _>(
            threshold,
            boxes,
            scores,
            quant_boxes,
        );
    }

    // Row ids travel alongside the rows so the parallel filter can report them.
    let row_ids: Array1<usize> = (0..boxes.dim().0).collect();
    Zip::from(scores.rows())
        .and(boxes.rows())
        .and(&row_ids)
        .into_par_iter()
        .filter_map(|(class_scores, coords, &row)| {
            let (best, label) = fast_arg_max(class_scores);
            (best >= threshold).then(|| {
                let bbox = BoundingBox::from(B::ndarray_to_xyxy_dequant(coords.view(), quant_boxes));
                (
                    DetectBoxQuantized {
                        label,
                        score: best,
                        bbox,
                    },
                    row,
                )
            })
        })
        .collect()
}
/// Column-major fast path of [`postprocess_boxes_index_quant`].
///
/// Intended for a `(n_candidates, n_classes)` `scores` view whose candidate axis has stride 1,
/// e.g. the transposed view of a `(n_classes, n_candidates)` buffer. Instead of walking each
/// candidate's class scores with a large stride, it sweeps one class column at a time. It keeps
/// a running per-candidate maximum, then emits `(detection, candidate_index)` for every
/// candidate whose best score is `>= threshold`. Ties resolve to the highest class index.
///
/// Labels are accumulated in a `u8` buffer. Models with more than 256 classes therefore fall
/// back to the per-row `fast_arg_max` path.
fn postprocess_boxes_index_quant_column_major<
    B: BBoxTypeTrait,
    Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
    Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
>(
    threshold: Scores,
    boxes: ArrayView2<Boxes>,
    scores: ArrayView2<Scores>,
    quant_boxes: Quantization,
) -> Vec<(DetectBoxQuantized<Scores>, usize)> {
    // Class indices 0..=255 all fit in the `u8` label buffer.
    const MAX_U8_CLASSES: usize = u8::MAX as usize + 1;

    let (n_candidates, n_classes) = scores.dim();
    if n_classes > MAX_U8_CLASSES {
        let indices: Array1<usize> = (0..n_candidates).collect();
        return Zip::from(scores.rows())
            .and(boxes.rows())
            .and(&indices)
            .into_par_iter()
            .filter_map(|(score, bbox, index)| {
                let (score_, label) = fast_arg_max(score);
                if score_ < threshold {
                    return None;
                }
                let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
                Some((
                    DetectBoxQuantized {
                        label,
                        score: score_,
                        bbox: BoundingBox::from(bbox_quant),
                    },
                    *index,
                ))
            })
            .collect();
    }

    let (max_scores, max_classes) = column_arg_max(scores);
    let box_columns = copy_box_columns(boxes);

    max_scores
        .iter()
        .zip(&max_classes)
        .enumerate()
        .filter(|&(_, (&score, _))| score >= threshold)
        .map(|(i, (&score, &label))| {
            let bbox_quant = match &box_columns {
                Some(cols) => B::to_xyxy_dequant(
                    &[cols[0][i], cols[1][i], cols[2][i], cols[3][i]],
                    quant_boxes,
                ),
                None => B::ndarray_to_xyxy_dequant(boxes.row(i), quant_boxes),
            };
            (
                DetectBoxQuantized {
                    label: usize::from(label),
                    score,
                    bbox: BoundingBox::from(bbox_quant),
                },
                i,
            )
        })
        .collect()
}

/// Per-candidate maximum over the class axis of `scores`, swept one class column at a time.
///
/// Returns `(max_scores, max_classes)`, both with one entry per row of `scores`. A column value
/// replaces the running maximum when it is `>=` it, so ties resolve to the highest class
/// index. The caller must ensure there are at most 256 classes, so that every class index
/// fits in a `u8`.
fn column_arg_max<Scores: PrimInt>(scores: ArrayView2<Scores>) -> (Vec<Scores>, Vec<u8>) {
    let (n_candidates, n_classes) = scores.dim();
    debug_assert!(n_classes <= u8::MAX as usize + 1);

    let mut max_scores = vec![Scores::min_value(); n_candidates];
    let mut max_classes = vec![0u8; n_candidates];
    for class_idx in 0..n_classes {
        // Cannot truncate: the caller guarantees at most 256 classes.
        let label = class_idx as u8;
        let col = scores.column(class_idx);

        #[cfg(target_arch = "aarch64")]
        {
            if std::mem::size_of::<Scores>() == 1 {
                if let Some(slice) = col.as_slice() {
                    // SAFETY: `Scores` is one byte wide. `slice`, `max_scores` and `max_classes`
                    // each hold exactly `n_candidates` one-byte elements. The two output
                    // buffers are distinct allocations owned here, so none of them overlap.
                    unsafe {
                        column_max_update_neon(
                            slice.as_ptr().cast::<u8>(),
                            max_scores.as_mut_ptr().cast::<u8>(),
                            max_classes.as_mut_ptr(),
                            n_candidates,
                            label,
                            Scores::min_value() < Scores::zero(),
                        );
                    }
                    continue;
                }
            }
        }

        for ((best, best_class), &val) in max_scores
            .iter_mut()
            .zip(max_classes.iter_mut())
            .zip(col.iter())
        {
            if val >= *best {
                *best = val;
                *best_class = label;
            }
        }
    }
    (max_scores, max_classes)
}

/// Copies the four coordinate columns of `boxes` into owned buffers when the view is
/// column-major, i.e. it has stride 1 along the candidate axis and is not a standard-layout
/// slice.
///
/// Returns `None` for any other layout. Callers then read rows straight from the view.
fn copy_box_columns<Boxes: Copy>(boxes: ArrayView2<Boxes>) -> Option<[Vec<Boxes>; 4]> {
    if boxes.strides()[0] != 1 || boxes.as_slice().is_some() {
        return None;
    }
    Some([0, 1, 2, 3].map(|dim| boxes.column(dim).to_vec()))
}
/// Greedy non-maximum suppression over `items`. The caller must already have sorted `items`
/// by descending score.
///
/// Items are visited front to back. Each item that has not been suppressed is kept. Every
/// later, not-yet-suppressed item for which `overlaps(later, kept)` returns `true` is then
/// suppressed. At most `max_det` items are returned (`None` means unlimited), in their
/// original order.
///
/// When `iou >= 1.0`, no suppression is performed and the list is only truncated to
/// `max_det`.
///
/// Suppression is tracked in a separate mask rather than by overwriting scores with a
/// sentinel. Items whose score equals the score type's minimum value are therefore treated
/// like any other item.
fn greedy_suppress<T>(
    mut items: Vec<T>,
    iou: f32,
    max_det: Option<usize>,
    overlaps: impl Fn(&T, &T) -> bool,
) -> Vec<T> {
    let cap = max_det.unwrap_or(usize::MAX);
    if iou >= 1.0 {
        items.truncate(cap);
        return items;
    }

    let mut suppressed = vec![false; items.len()];
    let mut kept = 0usize;
    for i in 0..items.len() {
        if suppressed[i] {
            continue;
        }
        kept += 1;
        // Once the cap is reached, suppression by this item cannot change the returned set.
        if kept >= cap {
            break;
        }
        let (head, tail) = items.split_at(i + 1);
        let anchor = &head[i];
        for (candidate, is_suppressed) in tail.iter().zip(&mut suppressed[i + 1..]) {
            if !*is_suppressed && overlaps(candidate, anchor) {
                *is_suppressed = true;
            }
        }
    }

    // Items after an early break were never examined, but `take(cap)` cuts the output off
    // right after the last kept item.
    items
        .into_iter()
        .zip(suppressed)
        .filter_map(|(item, is_suppressed)| (!is_suppressed).then_some(item))
        .take(cap)
        .collect()
}

/// Class-agnostic greedy NMS on quantized detections.
///
/// Sorts `boxes` by descending score with a stable sort. It then drops every box for which
/// `jaccard(box, higher_scoring_kept_box, iou)` is true, and returns at most `max_det`
/// survivors in descending score order. With `iou >= 1.0`, only the sort and truncation
/// are applied.
#[doc(hidden)]
#[must_use]
pub fn nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
    iou: f32,
    max_det: Option<usize>,
    mut boxes: Vec<DetectBoxQuantized<SCORE>>,
) -> Vec<DetectBoxQuantized<SCORE>> {
    boxes.par_sort_by(|a, b| b.score.cmp(&a.score));
    greedy_suppress(boxes, iou, max_det, |candidate, kept| {
        jaccard(&candidate.bbox, &kept.bbox, iou)
    })
}

/// Same as [`nms_int`], but each box carries an extra payload `E` that is returned alongside
/// it.
#[doc(hidden)]
#[must_use]
pub fn nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
    iou: f32,
    max_det: Option<usize>,
    mut boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
    boxes.par_sort_by(|a, b| b.0.score.cmp(&a.0.score));
    greedy_suppress(boxes, iou, max_det, |candidate, kept| {
        jaccard(&candidate.0.bbox, &kept.0.bbox, iou)
    })
}

/// Class-aware variant of [`nms_int`]: a box can only be suppressed by a higher-scoring kept
/// box with the same `label`.
#[doc(hidden)]
#[must_use]
pub fn nms_class_aware_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
    iou: f32,
    max_det: Option<usize>,
    mut boxes: Vec<DetectBoxQuantized<SCORE>>,
) -> Vec<DetectBoxQuantized<SCORE>> {
    boxes.par_sort_by(|a, b| b.score.cmp(&a.score));
    greedy_suppress(boxes, iou, max_det, |candidate, kept| {
        candidate.label == kept.label && jaccard(&candidate.bbox, &kept.bbox, iou)
    })
}

/// Class-aware variant of [`nms_extra_int`]: a box can only be suppressed by a higher-scoring
/// kept box with the same `label`. The extra payload `E` travels with each box.
#[doc(hidden)]
#[must_use]
pub fn nms_extra_class_aware_int<
    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
    E: Send + Sync,
>(
    iou: f32,
    max_det: Option<usize>,
    mut boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
    boxes.par_sort_by(|a, b| b.0.score.cmp(&a.0.score));
    greedy_suppress(boxes, iou, max_det, |candidate, kept| {
        candidate.0.label == kept.0.label && jaccard(&candidate.0.bbox, &kept.0.bbox, iou)
    })
}
/// Maps a floating-point score threshold into the quantized score domain of `T`.
///
/// Computes `ceil(score / scale + zero_point)` and clamps it to `T`'s representable range.
/// Assuming the usual `real = (q - zero_point) * scale` dequantization with a positive
/// `scale`, this is the smallest quantized value whose real score is at least `score`.
/// A zero `scale` yields `T::max_value()`.
#[doc(hidden)]
pub fn quantize_score_threshold<T: PrimInt + AsPrimitive<f32>>(score: f32, quant: Quantization) -> T
where
    f32: AsPrimitive<T>,
{
    if quant.scale == 0.0 {
        return T::max_value();
    }
    let lower: f32 = T::min_value().as_();
    let upper: f32 = T::max_value().as_();
    let unclamped = (score / quant.scale + quant.zero_point as f32).ceil();
    unclamped.clamp(lower, upper).as_()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::XYWH;
    use ndarray::Array2;

    /// Builds a physical `(4, n_candidates)` box buffer. Column `i` holds
    /// `(10i, 20i, 10i + 50, 20i + 100)`.
    fn make_boxes_physical(n_candidates: usize) -> Array2<i16> {
        let mut boxes = Array2::<i16>::zeros((4, n_candidates));
        for i in 0..n_candidates {
            boxes[[0, i]] = (i * 10) as i16;
            boxes[[1, i]] = (i * 20) as i16;
            boxes[[2, i]] = (i * 10 + 50) as i16;
            boxes[[3, i]] = (i * 20 + 100) as i16;
        }
        boxes
    }

    /// Returns the transpose of `physical` as a freshly built standard-layout (row-major)
    /// array.
    ///
    /// `reversed_axes().to_owned()` is not enough here. `to_owned` keeps the memory order of
    /// a contiguous input, so that copy would still be column-major and would take the same
    /// code path as the view under test.
    fn transposed_standard<T: Copy>(physical: &Array2<T>) -> Array2<T> {
        let (rows, cols) = physical.dim();
        Array2::from_shape_fn((cols, rows), |(i, j)| physical[[j, i]])
    }

    #[test]
    fn column_major_matches_row_major() {
        let n_classes = 80usize;
        let n_candidates = 100usize;
        let mut scores_physical = Array2::<u8>::zeros((n_classes, n_candidates));
        for c in 0..n_classes {
            for i in 0..n_candidates {
                scores_physical[[c, i]] = ((c * 3 + i * 7) % 256) as u8;
            }
        }
        let boxes_physical = make_boxes_physical(n_candidates);
        let quant = Quantization {
            scale: 0.00390625,
            zero_point: 0,
        };
        let threshold: u8 = 10;

        let scores_contiguous = transposed_standard(&scores_physical);
        let boxes_contiguous = transposed_standard(&boxes_physical);
        assert!(
            scores_contiguous.is_standard_layout(),
            "reference scores must take the row-major path"
        );
        let row_result = postprocess_boxes_index_quant::<XYWH, _, _>(
            threshold,
            boxes_contiguous.view(),
            scores_contiguous.view(),
            quant,
        );

        let scores_view = scores_physical.view().reversed_axes();
        let boxes_view = boxes_physical.view().reversed_axes();
        assert!(scores_view.as_slice().is_none(), "should be non-contiguous");
        assert_eq!(scores_view.strides()[0], 1);
        let col_result =
            postprocess_boxes_index_quant::<XYWH, _, _>(threshold, boxes_view, scores_view, quant);

        assert!(!row_result.is_empty(), "fixture should produce detections");
        assert_eq!(
            row_result.len(),
            col_result.len(),
            "different number of results: row={}, col={}",
            row_result.len(),
            col_result.len()
        );
        for (i, (row, col)) in row_result.iter().zip(col_result.iter()).enumerate() {
            assert_eq!(
                row.0.label, col.0.label,
                "candidate {i}: label mismatch row={} col={}",
                row.0.label, col.0.label
            );
            assert_eq!(row.0.score, col.0.score, "candidate {i}: score mismatch");
            assert_eq!(row.1, col.1, "candidate {i}: index mismatch");
            assert_eq!(row.0.bbox, col.0.bbox, "candidate {i}: bbox mismatch");
        }
    }

    #[test]
    fn column_major_matches_row_major_i8() {
        let n_classes = 80usize;
        let n_candidates = 50usize;
        let mut scores_physical = Array2::<i8>::zeros((n_classes, n_candidates));
        for c in 0..n_classes {
            for i in 0..n_candidates {
                scores_physical[[c, i]] = ((c as i16 * 3 + i as i16 * 7) % 256 - 128) as i8;
            }
        }
        let boxes_physical = make_boxes_physical(n_candidates);
        let quant = Quantization {
            scale: 0.0256,
            zero_point: -116,
        };
        let threshold: i8 = -100;

        let scores_contiguous = transposed_standard(&scores_physical);
        let boxes_contiguous = transposed_standard(&boxes_physical);
        assert!(
            scores_contiguous.is_standard_layout(),
            "reference scores must take the row-major path"
        );
        let row_result = postprocess_boxes_index_quant::<XYWH, _, _>(
            threshold,
            boxes_contiguous.view(),
            scores_contiguous.view(),
            quant,
        );

        let scores_view = scores_physical.view().reversed_axes();
        let boxes_view = boxes_physical.view().reversed_axes();
        assert!(scores_view.as_slice().is_none(), "should be non-contiguous");
        assert_eq!(scores_view.strides()[0], 1);
        let col_result =
            postprocess_boxes_index_quant::<XYWH, _, _>(threshold, boxes_view, scores_view, quant);

        assert!(!row_result.is_empty(), "fixture should produce detections");
        assert_eq!(row_result.len(), col_result.len());
        for (i, (row, col)) in row_result.iter().zip(col_result.iter()).enumerate() {
            assert_eq!(row.0.label, col.0.label, "i8 candidate {i}: label mismatch");
            assert_eq!(row.0.score, col.0.score, "i8 candidate {i}: score mismatch");
            assert_eq!(row.1, col.1, "i8 candidate {i}: index mismatch");
        }
    }

    fn make_nms_boxes_int(n: usize) -> Vec<DetectBoxQuantized<u8>> {
        (0..n)
            .map(|i| DetectBoxQuantized {
                bbox: BoundingBox {
                    xmin: i as f32 * 100.0,
                    ymin: 0.0,
                    xmax: i as f32 * 100.0 + 10.0,
                    ymax: 10.0,
                },
                label: 0,
                score: (200 - i as u32).min(255) as u8,
            })
            .collect()
    }

    #[test]
    fn nms_int_max_det_matches_full_truncated() {
        let boxes = make_nms_boxes_int(20);
        let n = 5;
        let full = nms_int(0.5, None, boxes.clone());
        let capped = nms_int(0.5, Some(n), boxes);
        assert_eq!(capped.len(), n);
        assert_eq!(&full[..n], &capped[..]);
    }

    #[test]
    fn nms_int_max_det_zero_returns_empty() {
        let boxes = make_nms_boxes_int(10);
        let result = nms_int(0.5, Some(0), boxes);
        assert!(result.is_empty());
    }

    #[test]
    fn nms_int_max_det_iou_ge_1_returns_sorted_truncated() {
        let boxes = make_nms_boxes_int(10);
        let result = nms_int(1.0, Some(3), boxes);
        assert_eq!(result.len(), 3);
        assert!(result[0].score >= result[1].score);
        assert!(result[1].score >= result[2].score);
    }

    #[test]
    fn nms_int_max_det_larger_than_input() {
        let boxes = make_nms_boxes_int(5);
        let full = nms_int(0.5, None, boxes.clone());
        let capped = nms_int(0.5, Some(100), boxes);
        assert_eq!(full.len(), capped.len());
    }
}