use crate::{arg_max, BBoxTypeTrait, BoundingBox, DetectBox};
use ndarray::{
parallel::prelude::{IntoParallelIterator, ParallelIterator as _},
Array1, ArrayView2, Zip,
};
use num_traits::{AsPrimitive, Float};
use rayon::slice::ParallelSliceMut;
/// Turns per-row class scores and box coordinates into [`DetectBox`]es, keeping
/// only rows whose best score reaches `threshold`.
///
/// Row `k` of `scores` is paired with row `k` of `boxes`. For each pair, `arg_max`
/// over the score row gives the best score and its class index (stored as
/// `label`). A row is kept only when `best_score >= threshold`. Because every
/// comparison with NaN is false, NaN best scores are rejected as well. The box row
/// of a kept pair is converted with `B::ndarray_to_xyxy_float` and then `.into()`
/// the `bbox` field. Rows are processed in parallel (rayon via ndarray's `Zip`).
///
/// # Panics
///
/// Panics if `scores` and `boxes` have a different number of rows, or if `boxes`
/// does not have exactly 4 columns.
#[must_use]
pub fn postprocess_boxes_float<
    B: BBoxTypeTrait,
    BOX: Float + AsPrimitive<f32> + Send + Sync,
    SCORE: Float + AsPrimitive<f32> + Send + Sync,
>(
    threshold: SCORE,
    boxes: ArrayView2<BOX>,
    scores: ArrayView2<SCORE>,
) -> Vec<DetectBox> {
    assert_eq!(scores.dim().0, boxes.dim().0);
    assert_eq!(boxes.dim().1, 4);
    Zip::from(scores.rows())
        .and(boxes.rows())
        .into_par_iter()
        .filter_map(|(score, bbox)| {
            let (score_, label) = arg_max(score);
            // A positive keep-test (rather than `score_ < threshold` as the reject
            // test) also drops NaN scores, since `NaN >= x` is false.
            (score_ >= threshold).then(|| DetectBox {
                label,
                score: score_.as_(),
                bbox: B::ndarray_to_xyxy_float(bbox).into(),
            })
        })
        .collect()
}
/// Like `postprocess_boxes_float`, but also returns the input row index of each
/// detection, so callers can map a [`DetectBox`] back to its source row.
///
/// Row `k` of `scores` is paired with row `k` of `boxes`. For each pair, `arg_max`
/// over the score row gives the best score and its class index (stored as
/// `label`). A row is kept only when `best_score >= threshold`, so NaN best scores
/// are rejected too. A kept row yields `(detection, k)`. Its box is converted with
/// `B::ndarray_to_xyxy_float` and then `.into()` the `bbox` field. Rows are
/// processed in parallel (rayon via ndarray's `Zip`).
///
/// # Panics
///
/// Panics if `scores` and `boxes` have a different number of rows, or if `boxes`
/// does not have exactly 4 columns.
#[must_use]
pub fn postprocess_boxes_index_float<
    B: BBoxTypeTrait,
    BOX: Float + AsPrimitive<f32> + Send + Sync,
    SCORE: Float + AsPrimitive<f32> + Send + Sync,
>(
    threshold: SCORE,
    boxes: ArrayView2<BOX>,
    scores: ArrayView2<SCORE>,
) -> Vec<(DetectBox, usize)> {
    assert_eq!(scores.dim().0, boxes.dim().0);
    assert_eq!(boxes.dim().1, 4);
    // Row numbers are zipped in as a third producer so each parallel item knows
    // which input row it came from.
    let indices: Array1<usize> = (0..boxes.dim().0).collect();
    Zip::from(scores.rows())
        .and(boxes.rows())
        .and(&indices)
        .into_par_iter()
        .filter_map(|(score, bbox, &row)| {
            let (score_, label) = arg_max(score);
            // Positive keep-test: also rejects NaN scores (`NaN >= x` is false).
            (score_ >= threshold).then(|| {
                let detection = DetectBox {
                    label,
                    score: score_.as_(),
                    bbox: B::ndarray_to_xyxy_float(bbox).into(),
                };
                (detection, row)
            })
        })
        .collect()
}
/// Class-agnostic greedy non-maximum suppression (NMS).
///
/// `boxes` is first sorted by descending `score` (via `f32::total_cmp`). Walking
/// that order, each box that has not itself been suppressed suppresses every
/// lower-ranked box for which [`jaccard`] reports an overlap above `iou`. The
/// surviving boxes are returned in descending-score order.
///
/// When `iou >= 1.0` no pair can pass the overlap test (an intersection cannot
/// exceed the union), so the sorted input is returned without pairwise checks.
///
/// Suppression is tracked in a separate mask rather than by overwriting scores.
/// A box's score therefore never decides by itself whether it survives: boxes with
/// negative scores are kept or suppressed like any other, exactly as on the
/// `iou >= 1.0` path.
#[must_use]
pub fn nms_float(iou: f32, mut boxes: Vec<DetectBox>) -> Vec<DetectBox> {
    boxes.par_sort_by(|a, b| b.score.total_cmp(&a.score));
    if iou >= 1.0 {
        return boxes;
    }
    let mut suppressed = vec![false; boxes.len()];
    for i in 0..boxes.len() {
        // A suppressed box does not get to suppress others (greedy NMS).
        if suppressed[i] {
            continue;
        }
        let kept = &boxes[i].bbox;
        // Only lower-ranked boxes (after `i` in sorted order) can be suppressed.
        for (candidate, flag) in boxes[i + 1..].iter().zip(&mut suppressed[i + 1..]) {
            if !*flag && jaccard(&candidate.bbox, kept, iou) {
                *flag = true;
            }
        }
    }
    boxes
        .into_iter()
        .zip(suppressed)
        .filter_map(|(b, is_suppressed)| (!is_suppressed).then_some(b))
        .collect()
}
/// Class-agnostic greedy non-maximum suppression over boxes that carry an extra
/// payload `E` (for example, a source row index).
///
/// The pairs are sorted by descending box `score` (via `f32::total_cmp`). Walking
/// that order, each pair that has not been suppressed suppresses every
/// lower-ranked pair whose box overlaps it above `iou` according to [`jaccard`].
/// Each payload stays attached to its box. Survivors are returned in
/// descending-score order.
///
/// When `iou >= 1.0` no pair can pass the overlap test, so the sorted input is
/// returned directly.
///
/// Suppression is tracked in a separate mask rather than by overwriting scores,
/// so boxes with negative scores are handled like any other box.
#[must_use]
pub fn nms_extra_float<E: Send + Sync>(
    iou: f32,
    mut boxes: Vec<(DetectBox, E)>,
) -> Vec<(DetectBox, E)> {
    boxes.par_sort_by(|a, b| b.0.score.total_cmp(&a.0.score));
    if iou >= 1.0 {
        return boxes;
    }
    let mut suppressed = vec![false; boxes.len()];
    for i in 0..boxes.len() {
        // A suppressed box does not get to suppress others (greedy NMS).
        if suppressed[i] {
            continue;
        }
        let kept = &boxes[i].0.bbox;
        // Only lower-ranked entries (after `i` in sorted order) can be suppressed.
        for (candidate, flag) in boxes[i + 1..].iter().zip(&mut suppressed[i + 1..]) {
            if !*flag && jaccard(&candidate.0.bbox, kept, iou) {
                *flag = true;
            }
        }
    }
    boxes
        .into_iter()
        .zip(suppressed)
        .filter_map(|(entry, is_suppressed)| (!is_suppressed).then_some(entry))
        .collect()
}
/// Class-aware greedy non-maximum suppression: a box can only suppress boxes that
/// have the same `label`.
///
/// `boxes` is sorted by descending `score` (via `f32::total_cmp`). Walking that
/// order, each box that has not been suppressed suppresses every lower-ranked box
/// with an equal `label` whose overlap with it exceeds `iou` according to
/// [`jaccard`]. Survivors are returned in descending-score order.
///
/// When `iou >= 1.0` no pair can pass the overlap test, so the sorted input is
/// returned directly.
///
/// Suppression is tracked in a separate mask rather than by overwriting scores,
/// so boxes with negative scores are handled like any other box.
#[must_use]
pub fn nms_class_aware_float(iou: f32, mut boxes: Vec<DetectBox>) -> Vec<DetectBox> {
    boxes.par_sort_by(|a, b| b.score.total_cmp(&a.score));
    if iou >= 1.0 {
        return boxes;
    }
    let mut suppressed = vec![false; boxes.len()];
    for i in 0..boxes.len() {
        // A suppressed box does not get to suppress others (greedy NMS).
        if suppressed[i] {
            continue;
        }
        let kept = &boxes[i];
        // Only lower-ranked boxes of the same class can be suppressed.
        for (candidate, flag) in boxes[i + 1..].iter().zip(&mut suppressed[i + 1..]) {
            if !*flag
                && candidate.label == kept.label
                && jaccard(&candidate.bbox, &kept.bbox, iou)
            {
                *flag = true;
            }
        }
    }
    boxes
        .into_iter()
        .zip(suppressed)
        .filter_map(|(b, is_suppressed)| (!is_suppressed).then_some(b))
        .collect()
}
/// Class-aware greedy non-maximum suppression over boxes that carry an extra
/// payload `E`. A box can only suppress boxes that have the same `label`.
///
/// The pairs are sorted by descending box `score` (via `f32::total_cmp`). Walking
/// that order, each pair that has not been suppressed suppresses every
/// lower-ranked pair with an equal `label` whose box overlaps it above `iou`
/// according to [`jaccard`]. Each payload stays attached to its box. Survivors are
/// returned in descending-score order.
///
/// When `iou >= 1.0` no pair can pass the overlap test, so the sorted input is
/// returned directly.
///
/// Suppression is tracked in a separate mask rather than by overwriting scores,
/// so boxes with negative scores are handled like any other box.
#[must_use]
pub fn nms_extra_class_aware_float<E: Send + Sync>(
    iou: f32,
    mut boxes: Vec<(DetectBox, E)>,
) -> Vec<(DetectBox, E)> {
    boxes.par_sort_by(|a, b| b.0.score.total_cmp(&a.0.score));
    if iou >= 1.0 {
        return boxes;
    }
    let mut suppressed = vec![false; boxes.len()];
    for i in 0..boxes.len() {
        // A suppressed box does not get to suppress others (greedy NMS).
        if suppressed[i] {
            continue;
        }
        let kept = &boxes[i].0;
        // Only lower-ranked entries of the same class can be suppressed.
        for (candidate, flag) in boxes[i + 1..].iter().zip(&mut suppressed[i + 1..]) {
            if !*flag
                && candidate.0.label == kept.label
                && jaccard(&candidate.0.bbox, &kept.bbox, iou)
            {
                *flag = true;
            }
        }
    }
    boxes
        .into_iter()
        .zip(suppressed)
        .filter_map(|(entry, is_suppressed)| (!is_suppressed).then_some(entry))
        .collect()
}
/// Returns `true` when the intersection-over-union of `a` and `b` is strictly
/// greater than `iou`.
///
/// The test is evaluated as `intersection > iou * union` so no division is
/// needed. It only counts when `union > 0.0`. A box whose max coordinate lies
/// below its min coordinate on one axis has a negative area, which can drive the
/// union negative. Without the guard, `0 > iou * negative_union` would report an
/// overlap between boxes that do not touch at all. The guard also returns `false`
/// for a zero or NaN union.
#[must_use]
pub fn jaccard(a: &BoundingBox, b: &BoundingBox, iou: f32) -> bool {
    let left = a.xmin.max(b.xmin);
    let top = a.ymin.max(b.ymin);
    let right = a.xmax.min(b.xmax);
    let bottom = a.ymax.min(b.ymax);
    let intersection = (right - left).max(0.0) * (bottom - top).max(0.0);
    let area_a = (a.xmax - a.xmin) * (a.ymax - a.ymin);
    let area_b = (b.xmax - b.xmin) * (b.ymax - b.ymin);
    let union = area_a + area_b - intersection;
    union > 0.0 && intersection > iou * union
}