use std::cmp::Ordering;
use crate::rotation;
use crate::utils;
#[cfg(feature = "ndarray")]
use ndarray::{ArrayView1, ArrayView2, Axis};
use num_traits::{Num, ToPrimitive};
use rstar::{RTree, RTreeNum, AABB};
/// Area of the axis-aligned box with corners `(bx, by)` and `(bxx, byy)`, computed in `f64`.
///
/// The result is `(bxx - bx) * (byy - by)`, so it is negative when the corners are reversed
/// along exactly one axis.
///
/// # Panics
///
/// Panics if any coordinate cannot be converted with `ToPrimitive::to_f64`.
#[inline(always)]
fn area_f64<N>(bx: N, by: N, bxx: N, byy: N) -> f64
where
    N: ToPrimitive,
{
    let width = bxx.to_f64().unwrap() - bx.to_f64().unwrap();
    let height = byy.to_f64().unwrap() - by.to_f64().unwrap();
    width * height
}
/// Returns the indices of `scores` ordered from highest to lowest score.
///
/// When `score_threshold` is greater than `utils::ZERO`, indices whose score is not
/// `>= score_threshold` are dropped first. NaN scores fail that comparison, so they are
/// dropped too. Otherwise every index is kept.
///
/// The sort comparator is a total order: real scores are ordered descending and any NaN
/// scores that survive the filter are placed after all of them. A comparator that treats
/// NaN as "equal to everything" is not transitive. That can leave real scores out of order,
/// and `sort_unstable_by` is allowed to panic on it.
pub fn filter_and_sort_scores(scores: &[f64], score_threshold: f64) -> Vec<usize> {
    let mut indices: Vec<usize> = if score_threshold > utils::ZERO {
        scores
            .iter()
            .enumerate()
            .filter(|(_, &score)| score >= score_threshold)
            .map(|(idx, _)| idx)
            .collect()
    } else {
        (0..scores.len()).collect()
    };
    indices.sort_unstable_by(|&a, &b| {
        let (score_a, score_b) = (scores[a], scores[b]);
        match (score_a.is_nan(), score_b.is_nan()) {
            // Neither is NaN, so `partial_cmp` always succeeds; compare `b` to `a` for descending order.
            (false, false) => score_b.partial_cmp(&score_a).unwrap_or(Ordering::Equal),
            (true, true) => Ordering::Equal,
            // NaN sorts after every real score.
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
        }
    });
    indices
}
/// Greedy non-maximum suppression over axis-aligned boxes stored in a flat slice.
///
/// Each box is read with `utils::row4` as its corners `(x1, y1, x2, y2)`, and `scores[k]` is
/// the score of box `k`. Candidates come from [`filter_and_sort_scores`], highest score first.
/// A candidate that has not been suppressed is kept. Every later candidate whose IoU with it
/// is strictly greater than `iou_threshold` is then suppressed.
///
/// Returns the kept box indices in the order they were kept.
///
/// # Panics
///
/// Panics if a coordinate cannot be converted to `f64`. It presumably also panics inside
/// `utils::row4` if `boxes` holds fewer boxes than `scores` has entries.
pub fn nms_slice<N>(
    boxes: &[N],
    scores: &[f64],
    iou_threshold: f64,
    score_threshold: f64,
) -> Vec<usize>
where
    N: Num + PartialEq + PartialOrd + ToPrimitive + Copy,
{
    let candidates = filter_and_sort_scores(scores, score_threshold);
    // `dropped[p]` refers to `candidates[p]`, i.e. a position in score order.
    let mut dropped = vec![false; candidates.len()];
    let mut kept = Vec::new();

    for (pos, &current) in candidates.iter().enumerate() {
        if dropped[pos] {
            continue;
        }
        kept.push(current);

        let (ax1, ay1, ax2, ay2) = utils::row4(boxes, current);
        let current_area = area_f64(ax1, ay1, ax2, ay2);

        for (other_pos, &other) in candidates.iter().enumerate().skip(pos + 1) {
            if dropped[other_pos] {
                continue;
            }
            let (bx1, by1, bx2, by2) = utils::row4(boxes, other);

            // Corners of the overlap rectangle.
            let (ix1, iy1) = (utils::max(ax1, bx1), utils::max(ay1, by1));
            let (ix2, iy2) = (utils::min(ax2, bx2), utils::min(ay2, by2));
            if ix1 > ix2 || iy1 > iy2 {
                // The boxes are disjoint along at least one axis.
                continue;
            }

            let overlap = area_f64(ix1, iy1, ix2, iy2);
            let other_area = area_f64(bx1, by1, bx2, by2);
            let iou = overlap / (current_area + other_area - overlap);
            if iou > iou_threshold {
                dropped[other_pos] = true;
            }
        }
    }
    kept
}
/// R-tree accelerated variant of [`nms_slice`] for axis-aligned boxes.
///
/// Every candidate from [`filter_and_sort_scores`] is bulk-loaded into an R-tree keyed by its
/// `(x1, y1)`–`(x2, y2)` corners. Each kept box is then compared only with candidates whose
/// envelopes intersect its own, rather than with every later candidate. The result is
/// intended to match `nms_slice`; the tests compare the two.
///
/// Returns the kept box indices, highest score first.
///
/// # Panics
///
/// Panics if a coordinate cannot be converted to `f64`. It presumably also panics inside
/// `utils::row4` if `boxes` holds fewer boxes than `scores` has entries.
pub fn rtree_nms_slice<N>(
    boxes: &[N],
    scores: &[f64],
    iou_threshold: f64,
    score_threshold: f64,
) -> Vec<usize>
where
    N: RTreeNum + PartialEq + PartialOrd + ToPrimitive + Copy + Send + Sync,
{
    let order = filter_and_sort_scores(scores, score_threshold);
    let mut keep: Vec<usize> = Vec::new();
    // Indexed by box index, and every entry of `order` is `< scores.len()`. Sizing this by
    // `boxes.len()` would allocate one flag per coordinate instead of one per box.
    let mut suppress = vec![false; scores.len()];
    let rtree: RTree<utils::Bbox<N>> = RTree::bulk_load(
        order
            .iter()
            .map(|&idx| {
                let (x1, y1, x2, y2) = utils::row4(boxes, idx);
                utils::Bbox {
                    x1,
                    y1,
                    x2,
                    y2,
                    index: idx,
                }
            })
            .collect(),
    );
    for &idx in &order {
        if suppress[idx] {
            continue;
        }
        keep.push(idx);
        let (b1x, b1y, b1xx, b1yy) = utils::row4(boxes, idx);
        let area1 = area_f64(b1x, b1y, b1xx, b1yy);
        let query = AABB::from_corners([b1x, b1y], [b1xx, b1yy]);
        for bbox in rtree.locate_in_envelope_intersecting(&query) {
            let idx_j = bbox.index;
            // The query also returns the current box itself, which is already kept. Skip it
            // along with anything that is already suppressed.
            if idx_j == idx || suppress[idx_j] {
                continue;
            }
            let (b2x, b2y, b2xx, b2yy) = utils::row4(boxes, idx_j);
            let x = utils::max(b1x, b2x);
            let y = utils::max(b1y, b2y);
            let xx = utils::min(b1xx, b2xx);
            let yy = utils::min(b1yy, b2yy);
            if x > xx || y > yy {
                continue;
            }
            let intersection = area_f64(x, y, xx, yy);
            let area2 = area_f64(b2x, b2y, b2xx, b2yy);
            let union = area1 + area2 - intersection;
            let iou = intersection / union;
            if iou > iou_threshold {
                suppress[idx_j] = true;
            }
        }
    }
    keep
}
/// Greedy non-maximum suppression over rotated boxes stored in a flat slice.
///
/// Each box is five values read with `utils::row5` and passed, in that order, to
/// `rotation::Rect::new`: centre x, centre y, width, height and angle. `scores[k]` is the
/// score of box `k`. Candidates come from [`filter_and_sort_scores`], highest score first.
/// A candidate that has not been suppressed is kept. Every later candidate whose IoU with it
/// is strictly greater than `iou_threshold` is then suppressed.
///
/// A box whose width × height is exactly zero is skipped in every comparison. It therefore
/// never suppresses or gets suppressed, and is kept whenever it passes the score filter.
///
/// Returns the kept box indices in the order they were kept.
///
/// # Panics
///
/// Panics if a box value cannot be converted to `f64`.
pub fn rotated_nms_slice<N>(
    boxes: &[N],
    scores: &[f64],
    iou_threshold: f64,
    score_threshold: f64,
) -> Vec<usize>
where
    N: Num + PartialEq + PartialOrd + ToPrimitive + Copy,
{
    let candidates = filter_and_sort_scores(scores, score_threshold);
    // `dropped[p]` refers to `candidates[p]`, i.e. a position in score order.
    let mut dropped = vec![false; candidates.len()];
    let mut kept = Vec::new();

    // Area and rotated rectangle of box `k`, or `None` when its area is zero. The rectangle
    // is only built for boxes with a non-zero area.
    let rect_of = |k: usize| {
        let (cx, cy, w, h, angle) = utils::row5(boxes, k);
        let width = w.to_f64().unwrap();
        let height = h.to_f64().unwrap();
        let area = width * height;
        if area == 0.0 {
            None
        } else {
            Some((
                area,
                rotation::Rect::new(
                    cx.to_f64().unwrap(),
                    cy.to_f64().unwrap(),
                    width,
                    height,
                    angle.to_f64().unwrap(),
                ),
            ))
        }
    };

    for (pos, &current) in candidates.iter().enumerate() {
        if dropped[pos] {
            continue;
        }
        kept.push(current);

        let (current_area, current_rect) = match rect_of(current) {
            Some(found) => found,
            None => continue,
        };

        for (other_pos, &other) in candidates.iter().enumerate().skip(pos + 1) {
            if dropped[other_pos] {
                continue;
            }
            let (other_area, other_rect) = match rect_of(other) {
                Some(found) => found,
                None => continue,
            };
            // Cheap envelope test before the exact polygon intersection.
            if !rotation::envelopes_intersect(&current_rect, &other_rect) {
                continue;
            }
            let overlap = rotation::intersection_area(&current_rect, &other_rect);
            if overlap == 0.0 {
                continue;
            }
            let iou: f64 = overlap / (current_area + other_area - overlap);
            if iou > iou_threshold {
                dropped[other_pos] = true;
            }
        }
    }
    kept
}
/// R-tree accelerated variant of [`rotated_nms_slice`].
///
/// Each candidate's rotated rectangle is reduced to its axis-aligned bounding rectangle with
/// `rotation::minimal_bounding_rect`, and those rectangles are bulk-loaded into an R-tree.
/// Each kept box is then compared only with candidates whose bounding rectangles intersect
/// its own. Zero-area boxes and pairs with zero intersection are skipped exactly as in
/// `rotated_nms_slice`, so both variants agree even for a negative `iou_threshold`.
///
/// Returns the kept box indices, highest score first.
///
/// # Panics
///
/// Panics if a box value cannot be converted to `f64`.
pub fn rtree_rotated_nms_slice<N>(
    boxes: &[N],
    scores: &[f64],
    iou_threshold: f64,
    score_threshold: f64,
) -> Vec<usize>
where
    N: RTreeNum + PartialEq + PartialOrd + ToPrimitive + Copy + Send + Sync,
{
    let order = filter_and_sort_scores(scores, score_threshold);
    let mut keep: Vec<usize> = Vec::new();
    // Indexed by box index, and every entry of `order` is `< scores.len()`. Sizing this by
    // `boxes.len()` would allocate one flag per value instead of one per box.
    let mut suppress = vec![false; scores.len()];
    let rtree: RTree<utils::Bbox<f64>> = RTree::bulk_load(
        order
            .iter()
            .map(|&idx| {
                let (cx, cy, w, h, a) = utils::row5(boxes, idx);
                let rect = rotation::Rect::new(
                    cx.to_f64().unwrap(),
                    cy.to_f64().unwrap(),
                    w.to_f64().unwrap(),
                    h.to_f64().unwrap(),
                    a.to_f64().unwrap(),
                );
                let (x1, y1, x2, y2) = rotation::minimal_bounding_rect(&rect.points());
                utils::Bbox {
                    x1,
                    y1,
                    x2,
                    y2,
                    index: idx,
                }
            })
            .collect(),
    );
    for &idx in &order {
        if suppress[idx] {
            continue;
        }
        keep.push(idx);
        let (cx, cy, w, h, a) = utils::row5(boxes, idx);
        let w = w.to_f64().unwrap();
        let h = h.to_f64().unwrap();
        let area1 = w * h;
        if area1 == 0.0 {
            continue;
        }
        let rect1 = rotation::Rect::new(
            cx.to_f64().unwrap(),
            cy.to_f64().unwrap(),
            w,
            h,
            a.to_f64().unwrap(),
        );
        let (b1x, b1y, b1xx, b1yy) = rotation::minimal_bounding_rect(&rect1.points());
        let query = AABB::from_corners([b1x, b1y], [b1xx, b1yy]);
        for bbox in rtree.locate_in_envelope_intersecting(&query) {
            let idx_j = bbox.index;
            // The query also returns the current box itself, which is already kept. Skip it
            // along with anything that is already suppressed.
            if idx_j == idx || suppress[idx_j] {
                continue;
            }
            let (cx2, cy2, w2, h2, a2) = utils::row5(boxes, idx_j);
            let w2 = w2.to_f64().unwrap();
            let h2 = h2.to_f64().unwrap();
            let area2 = w2 * h2;
            if area2 == 0.0 {
                continue;
            }
            let rect2 = rotation::Rect::new(
                cx2.to_f64().unwrap(),
                cy2.to_f64().unwrap(),
                w2,
                h2,
                a2.to_f64().unwrap(),
            );
            let intersection = rotation::intersection_area(&rect1, &rect2);
            // Same early-out as `rotated_nms_slice`. Without it, a negative `iou_threshold`
            // would make this variant suppress boxes that the brute-force one keeps.
            if intersection == 0.0 {
                continue;
            }
            let union = area1 + area2 - intersection;
            let iou = intersection / union;
            if iou > iou_threshold {
                suppress[idx_j] = true;
            }
        }
    }
    keep
}
/// Axis-aligned NMS on ndarray inputs; a thin wrapper around [`nms_slice`].
///
/// `boxes` must be an `(N, 4)` array with one box per row, and `scores` a length-`N` array.
/// Both are flattened and handed to `nms_slice` unchanged.
///
/// # Panics
///
/// Panics if a non-empty `boxes` does not have exactly 4 columns. It also panics if
/// `boxes.nrows()` differs from the length of `scores`, or if either array is not contiguous
/// in standard (row-major) layout.
#[cfg(feature = "ndarray")]
pub fn nms<'a, N, BA, SA>(
    boxes: BA,
    scores: SA,
    iou_threshold: f64,
    score_threshold: f64,
) -> Vec<usize>
where
    N: Num + PartialEq + PartialOrd + ToPrimitive + Copy + 'a,
    BA: Into<ArrayView2<'a, N>>,
    SA: Into<ArrayView1<'a, f64>>,
{
    let boxes = boxes.into();
    let scores = scores.into();
    // `nms_slice` reads the flattened data four values per box (`utils::row4`). Any other
    // column count would pass the row check below and then silently misread the boxes.
    assert!(
        boxes.nrows() == 0 || boxes.ncols() == 4,
        "boxes must have shape (N, 4), got {:?}",
        boxes.shape()
    );
    assert_eq!(boxes.nrows(), scores.len_of(Axis(0)));
    let boxes_slice = boxes.as_slice().expect("boxes must be contiguous");
    let scores_slice = scores.as_slice().expect("scores must be contiguous");
    nms_slice(boxes_slice, scores_slice, iou_threshold, score_threshold)
}
/// R-tree accelerated axis-aligned NMS on ndarray inputs; a thin wrapper around
/// [`rtree_nms_slice`].
///
/// Takes the same inputs as [`nms`]: an `(N, 4)` `boxes` array and a length-`N` `scores`
/// array.
///
/// # Panics
///
/// Panics if a non-empty `boxes` does not have exactly 4 columns. It also panics if
/// `boxes.nrows()` differs from the length of `scores`, or if either array is not contiguous
/// in standard (row-major) layout.
#[cfg(feature = "ndarray")]
pub fn rtree_nms<'a, N, BA, SA>(
    boxes: BA,
    scores: SA,
    iou_threshold: f64,
    score_threshold: f64,
) -> Vec<usize>
where
    N: RTreeNum + PartialEq + PartialOrd + ToPrimitive + Copy + Send + Sync + 'a,
    BA: Into<ArrayView2<'a, N>>,
    SA: Into<ArrayView1<'a, f64>>,
{
    let scores = scores.into();
    let boxes = boxes.into();
    // `rtree_nms_slice` reads the flattened data four values per box (`utils::row4`).
    assert!(
        boxes.nrows() == 0 || boxes.ncols() == 4,
        "boxes must have shape (N, 4), got {:?}",
        boxes.shape()
    );
    // Same length check as `nms`. Without it, extra scores would make `utils::row4` read
    // past the box data, and extra boxes would be silently ignored.
    assert_eq!(boxes.nrows(), scores.len_of(Axis(0)));
    let boxes_slice = boxes.as_slice().expect("boxes must be contiguous");
    let scores_slice = scores.as_slice().expect("scores must be contiguous");
    rtree_nms_slice(boxes_slice, scores_slice, iou_threshold, score_threshold)
}
/// Rotated-box NMS on ndarray inputs; a thin wrapper around [`rotated_nms_slice`].
///
/// `boxes` must be an `(N, 5)` array with one rotated box per row, in the layout read by
/// `rotated_nms_slice`. `scores` must be a length-`N` array.
///
/// # Panics
///
/// Panics if a non-empty `boxes` does not have exactly 5 columns. It also panics if
/// `boxes.nrows()` differs from the length of `scores`, or if either array is not contiguous
/// in standard (row-major) layout.
#[cfg(feature = "ndarray")]
pub fn rotated_nms<'a, N, BA, SA>(
    boxes: BA,
    scores: SA,
    iou_threshold: f64,
    score_threshold: f64,
) -> Vec<usize>
where
    N: Num + PartialEq + PartialOrd + ToPrimitive + Copy + 'a,
    BA: Into<ArrayView2<'a, N>>,
    SA: Into<ArrayView1<'a, f64>>,
{
    let boxes = boxes.into();
    let scores = scores.into();
    // `rotated_nms_slice` reads the flattened data five values per box (`utils::row5`). Any
    // other column count would pass the row check below and then silently misread the boxes.
    assert!(
        boxes.nrows() == 0 || boxes.ncols() == 5,
        "boxes must have shape (N, 5), got {:?}",
        boxes.shape()
    );
    assert_eq!(boxes.nrows(), scores.len_of(Axis(0)));
    let boxes_slice = boxes.as_slice().expect("boxes must be contiguous");
    let scores_slice = scores.as_slice().expect("scores must be contiguous");
    rotated_nms_slice(boxes_slice, scores_slice, iou_threshold, score_threshold)
}
/// R-tree accelerated rotated-box NMS on ndarray inputs; a thin wrapper around
/// [`rtree_rotated_nms_slice`].
///
/// Takes the same inputs as [`rotated_nms`]: an `(N, 5)` `boxes` array and a length-`N`
/// `scores` array.
///
/// # Panics
///
/// Panics if a non-empty `boxes` does not have exactly 5 columns. It also panics if
/// `boxes.nrows()` differs from the length of `scores`, or if either array is not contiguous
/// in standard (row-major) layout.
#[cfg(feature = "ndarray")]
pub fn rtree_rotated_nms<'a, N, BA, SA>(
    boxes: BA,
    scores: SA,
    iou_threshold: f64,
    score_threshold: f64,
) -> Vec<usize>
where
    N: RTreeNum + PartialEq + PartialOrd + ToPrimitive + Copy + Send + Sync + 'a,
    BA: Into<ArrayView2<'a, N>>,
    SA: Into<ArrayView1<'a, f64>>,
{
    let scores = scores.into();
    let boxes = boxes.into();
    // `rtree_rotated_nms_slice` reads the flattened data five values per box (`utils::row5`).
    assert!(
        boxes.nrows() == 0 || boxes.ncols() == 5,
        "boxes must have shape (N, 5), got {:?}",
        boxes.shape()
    );
    // Same length check as `rotated_nms`. Without it, extra scores would make `utils::row5`
    // read past the box data, and extra boxes would be silently ignored.
    assert_eq!(boxes.nrows(), scores.len_of(Axis(0)));
    let boxes_slice = boxes.as_slice().expect("boxes must be contiguous");
    let scores_slice = scores.as_slice().expect("scores must be contiguous");
    rtree_rotated_nms_slice(boxes_slice, scores_slice, iou_threshold, score_threshold)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_filter_and_sort_scores_no_thresh() {
        let scores = vec![0.9, 0.3, 0.7, 0.5, 0.1];
        let result = filter_and_sort_scores(&scores, 0.0);
        assert_eq!(result, vec![0, 2, 3, 1, 4]);
    }

    #[test]
    fn test_filter_and_sort_scores_with_thresh() {
        let scores = vec![0.9, 0.3, 0.7, 0.5, 0.1];
        let result = filter_and_sort_scores(&scores, 0.5);
        assert_eq!(result, vec![0, 2, 3]);
    }

    #[test]
    fn test_nms_slice_normal() {
        let boxes = vec![
            184.68927598,
            850.65932762,
            201.47437531,
            866.02327337,
            185.68927598,
            851.65932762,
            200.47437531,
            865.02327337,
            875.33814954,
            706.46958933,
            902.14487263,
            737.14697788,
            874.33814954,
            703.46958933,
            901.14487263,
            732.14697788,
            277.71729109,
            744.81869575,
            308.13768447,
            777.11413807,
            275.71729109,
            740.81869575,
            310.13768447,
            765.11413807,
        ];
        let scores = vec![0.9, 0.8, 0.7, 0.6, 0.5, 0.4];
        let keep = nms_slice(&boxes, &scores, 0.5, 0.0);
        let keep_rtree = rtree_nms_slice(&boxes, &scores, 0.5, 0.0);
        assert_eq!(keep, vec![0, 2, 4]);
        assert_eq!(keep_rtree, keep);
    }

    #[test]
    fn test_rotated_nms_slice_normal() {
        let boxes = vec![
            1.0, 2.0, 10.0, 5.0, 45.0, 0.0, 1.0, 9.0, 4.0, 30.0, 10.0, 20.0, 5.0, 8.0, -45.0,
        ];
        let scores = vec![0.9, 0.8, 0.7];
        let keep = rotated_nms_slice(&boxes, &scores, 0.5, 0.0);
        assert_eq!(keep, vec![0, 2]);
    }

    #[test]
    fn test_rtree_rotated_nms_slice_normal() {
        let boxes = vec![
            1.0, 2.0, 10.0, 5.0, 45.0, 0.0, 1.0, 9.0, 4.0, 30.0, 10.0, 20.0, 5.0, 8.0, -45.0,
        ];
        let scores = vec![0.9, 0.8, 0.7];
        let keep = rtree_rotated_nms_slice(&boxes, &scores, 0.5, 0.0);
        assert_eq!(keep, vec![0, 2]);
    }

    // Every slice kernel must accept empty input and return no indices.
    #[test]
    fn test_slice_variants_empty_input() {
        let boxes: Vec<f64> = Vec::new();
        let scores: Vec<f64> = Vec::new();
        assert!(nms_slice(&boxes, &scores, 0.5, 0.0).is_empty());
        assert!(rtree_nms_slice(&boxes, &scores, 0.5, 0.0).is_empty());
        assert!(rotated_nms_slice(&boxes, &scores, 0.5, 0.0).is_empty());
        assert!(rtree_rotated_nms_slice(&boxes, &scores, 0.5, 0.0).is_empty());
    }

    // Boxes 0 and 1 are identical, so the lower-scored one (0) is suppressed. Box 2 overlaps
    // neither and is kept. Scores are deliberately not in index order.
    #[test]
    fn test_nms_slice_duplicate_and_disjoint_boxes() {
        let boxes = vec![
            0.0, 0.0, 10.0, 10.0, 0.0, 0.0, 10.0, 10.0, 20.0, 20.0, 30.0, 30.0,
        ];
        let scores = vec![0.5, 0.9, 0.7];
        let keep = nms_slice(&boxes, &scores, 0.5, 0.0);
        assert_eq!(keep, vec![1, 2]);
        assert_eq!(rtree_nms_slice(&boxes, &scores, 0.5, 0.0), keep);
    }

    // A score exactly equal to the threshold is kept; anything below it is dropped.
    #[test]
    fn test_score_threshold_is_inclusive() {
        let boxes = vec![0.0, 0.0, 1.0, 1.0, 5.0, 5.0, 6.0, 6.0];
        let scores = vec![0.5, 0.4];
        assert_eq!(nms_slice(&boxes, &scores, 0.5, 0.5), vec![0]);
        assert_eq!(rtree_nms_slice(&boxes, &scores, 0.5, 0.5), vec![0]);
    }

    #[cfg(feature = "ndarray")]
    mod ndarray_tests {
        use super::*;
        use ndarray::{arr2, Array1};

        #[test]
        fn test_nms_normal_case() {
            let boxes = arr2(&[
                [184.68927598, 850.65932762, 201.47437531, 866.02327337],
                [185.68927598, 851.65932762, 200.47437531, 865.02327337],
                [875.33814954, 706.46958933, 902.14487263, 737.14697788],
                [874.33814954, 703.46958933, 901.14487263, 732.14697788],
                [277.71729109, 744.81869575, 308.13768447, 777.11413807],
                [275.71729109, 740.81869575, 310.13768447, 765.11413807],
            ]);
            let scores = Array1::from(vec![0.9, 0.8, 0.7, 0.6, 0.5, 0.4]);
            let keep = nms(&boxes, &scores, 0.5, 0.0);
            let keep_rtree = rtree_nms(&boxes, &scores, 0.5, 0.0);
            assert_eq!(keep, vec![0, 2, 4]);
            assert_eq!(keep_rtree, keep);
        }

        #[test]
        fn test_nms_empty_case() {
            let boxes = arr2(&[[0.0, 0.0, 2.0, 2.0], [1.0, 1.0, 3.0, 3.0]]);
            let scores = Array1::from(vec![0.0, 0.0]);
            let keep = nms(&boxes, &scores, 0.5, 1.0);
            let keep_rtree = rtree_nms(&boxes, &scores, 0.5, 1.0);
            assert_eq!(keep, vec![]);
            assert_eq!(keep, keep_rtree)
        }

        #[test]
        fn test_nms_score_threshold() {
            let boxes = arr2(&[[0.0, 0.0, 2.0, 2.0], [1.0, 1.0, 3.0, 3.0]]);
            let scores = Array1::from(vec![0.0, 1.0]);
            let keep = nms(&boxes, &scores, 0.5, 0.5);
            let keep_rtree = rtree_nms(&boxes, &scores, 0.5, 0.5);
            assert_eq!(keep, vec![1]);
            assert_eq!(keep, keep_rtree)
        }

        #[test]
        fn test_nms_iou_threshold() {
            let boxes = arr2(&[[0.0, 0.0, 2.0, 2.0], [1.0, 1.0, 3.0, 3.0]]);
            let scores = Array1::from(vec![1.0, 1.0]);
            let keep = nms(&boxes, &scores, 0.8, 0.0);
            let keep_rtree = rtree_nms(&boxes, &scores, 0.8, 0.0);
            assert_eq!(keep, vec![0, 1]);
            assert_eq!(keep, keep_rtree)
        }

        #[test]
        fn test_rotated_nms_normal_case() {
            let boxes = arr2(&[
                [1.0, 2.0, 10.0, 5.0, 45.0],
                [0.0, 1.0, 9.0, 4.0, 30.0],
                [10.0, 20.0, 5.0, 8.0, -45.0],
            ]);
            let scores = Array1::from(vec![0.9, 0.8, 0.7]);
            let keep = rotated_nms(&boxes, &scores, 0.5, 0.0);
            let keep_rtree = rtree_rotated_nms(&boxes, &scores, 0.5, 0.0);
            assert_eq!(keep, vec![0, 2]);
            assert_eq!(keep, keep_rtree);
        }

        #[test]
        fn test_rotated_nms_empty_case() {
            let boxes = arr2(&[[0.0, 0.0, 2.0, 2.0, 10.0], [1.0, 1.0, 3.0, 3.0, 10.0]]);
            let scores = Array1::from(vec![0.0, 0.0]);
            let keep = rotated_nms(&boxes, &scores, 0.5, 1.0);
            let keep_rtree = rtree_rotated_nms(&boxes, &scores, 0.5, 1.0);
            assert_eq!(keep, vec![]);
            assert_eq!(keep, keep_rtree);
        }

        #[test]
        fn test_rotated_nms_score_threshold() {
            let boxes = arr2(&[[0.0, 0.0, 2.0, 2.0, 10.0], [1.0, 1.0, 3.0, 3.0, 12.0]]);
            let scores = Array1::from(vec![0.0, 1.0]);
            let keep = rotated_nms(&boxes, &scores, 0.5, 0.5);
            let keep_rtree = rtree_rotated_nms(&boxes, &scores, 0.5, 0.5);
            assert_eq!(keep, vec![1]);
            assert_eq!(keep, keep_rtree);
        }

        #[test]
        fn test_rotated_nms_iou_threshold() {
            let boxes = arr2(&[[0.0, 0.0, 2.0, 2.0, 10.0], [1.0, 1.0, 3.0, 3.0, 45.0]]);
            let scores = Array1::from(vec![1.0, 1.0]);
            let keep = rotated_nms(&boxes, &scores, 0.8, 0.0);
            let keep_rtree = rtree_rotated_nms(&boxes, &scores, 0.8, 0.0);
            assert_eq!(keep, vec![0, 1]);
            assert_eq!(keep, keep_rtree);
        }
    }
}