use crate::{
arg_max, float::jaccard, BBoxTypeTrait, BoundingBox, DetectBoxQuantized, Quantization,
};
use ndarray::{
parallel::prelude::{IntoParallelIterator, ParallelIterator as _},
Array1, ArrayView2, Zip,
};
use num_traits::{AsPrimitive, PrimInt};
use rayon::slice::ParallelSliceMut;
/// Decodes quantized model outputs into thresholded detections.
///
/// Row `i` of `scores` holds per-class quantized scores. Row `i` of `boxes` holds
/// that detection's 4 quantized coordinates.
///
/// Each score row is reduced with `arg_max`, which is assumed to yield the best score
/// and its class index. Rows whose best score is below `threshold` are dropped. The
/// remaining boxes are converted with `B::ndarray_to_xyxy_dequant` using `quant_boxes`.
/// Rows are processed in parallel.
///
/// # Panics
///
/// Panics if `scores` and `boxes` have different row counts, or if `boxes` does not
/// have exactly 4 columns.
#[doc(hidden)]
pub fn postprocess_boxes_quant<
    B: BBoxTypeTrait,
    Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
    Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
>(
    threshold: Scores,
    boxes: ArrayView2<Boxes>,
    scores: ArrayView2<Scores>,
    quant_boxes: Quantization,
) -> Vec<DetectBoxQuantized<Scores>> {
    let (box_rows, box_cols) = boxes.dim();
    assert_eq!(scores.dim().0, box_rows);
    assert_eq!(box_cols, 4);
    Zip::from(scores.rows())
        .and(boxes.rows())
        .into_par_iter()
        .filter_map(|(class_scores, coords)| {
            let (best, label) = arg_max(class_scores);
            // Keep only rows whose best class score reaches the threshold.
            (best >= threshold).then(|| {
                let xyxy = B::ndarray_to_xyxy_dequant(coords.view(), quant_boxes);
                DetectBoxQuantized {
                    label,
                    score: best,
                    bbox: BoundingBox::from(xyxy),
                }
            })
        })
        .collect()
}
/// Decodes quantized model outputs into thresholded detections, each tagged with
/// the row index it came from.
///
/// Row `i` of `scores` holds per-class quantized scores. Row `i` of `boxes` holds
/// that detection's 4 quantized coordinates.
///
/// Each score row is reduced with `arg_max`, which is assumed to yield the best score
/// and its class index. Rows whose best score is below `threshold` are dropped. Each
/// surviving row's box is converted with `B::ndarray_to_xyxy_dequant` using
/// `quant_boxes` and returned as `(detection, i)`. Rows are processed in parallel.
///
/// # Panics
///
/// Panics if `scores` and `boxes` have different row counts, or if `boxes` does not
/// have exactly 4 columns.
#[doc(hidden)]
pub fn postprocess_boxes_index_quant<
    B: BBoxTypeTrait,
    Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
    Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
>(
    threshold: Scores,
    boxes: ArrayView2<Boxes>,
    scores: ArrayView2<Scores>,
    quant_boxes: Quantization,
) -> Vec<(DetectBoxQuantized<Scores>, usize)> {
    let row_count = boxes.dim().0;
    assert_eq!(scores.dim().0, row_count);
    assert_eq!(boxes.dim().1, 4);
    // Row numbers travel alongside each parallel item so the origin of every
    // detection survives the filtering step.
    let row_ids = (0..row_count).collect::<Array1<usize>>();
    Zip::from(scores.rows())
        .and(boxes.rows())
        .and(&row_ids)
        .into_par_iter()
        .filter_map(|(class_scores, coords, &row)| {
            let (best, label) = arg_max(class_scores);
            (best >= threshold).then(|| {
                let xyxy = B::ndarray_to_xyxy_dequant(coords.view(), quant_boxes);
                let detection = DetectBoxQuantized {
                    label,
                    score: best,
                    bbox: BoundingBox::from(xyxy),
                };
                (detection, row)
            })
        })
        .collect()
}
/// Greedy, class-agnostic non-maximum suppression over quantized detections.
///
/// Boxes are first sorted by descending score. The sort is stable, so ties keep their
/// input order. Walking that order, each box that is still alive suppresses every
/// later live box `b` for which `jaccard(&b.bbox, &current.bbox, iou)` returns true
/// (presumably when their IoU exceeds `iou`; confirm against `float::jaccard`).
///
/// If `iou >= 1.0`, the sorted boxes are returned with no suppression.
///
/// Suppression is tracked in a separate mask rather than by overwriting scores with
/// `SCORE::min_value()`. As a result, detections whose real score equals that minimum
/// (e.g. a `u8` score of 0 that passed a zero threshold) are no longer silently
/// dropped. This also makes the result consistent with the `iou >= 1.0` path.
#[doc(hidden)]
#[must_use]
pub fn nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
    iou: f32,
    mut boxes: Vec<DetectBoxQuantized<SCORE>>,
) -> Vec<DetectBoxQuantized<SCORE>> {
    boxes.par_sort_by(|a, b| b.score.cmp(&a.score));
    if iou >= 1.0 {
        return boxes;
    }
    let mut suppressed = vec![false; boxes.len()];
    for i in 0..boxes.len() {
        // A suppressed box never suppresses others.
        if suppressed[i] {
            continue;
        }
        for j in (i + 1)..boxes.len() {
            if !suppressed[j] && jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
                suppressed[j] = true;
            }
        }
    }
    boxes
        .into_iter()
        .zip(suppressed)
        .filter_map(|(b, is_suppressed)| (!is_suppressed).then_some(b))
        .collect()
}
/// Greedy, class-agnostic non-maximum suppression over `(detection, extra)` pairs.
///
/// `E` is an arbitrary payload carried through unchanged alongside each detection.
///
/// Pairs are first sorted by descending detection score. The sort is stable, so ties
/// keep their input order. Walking that order, each live detection suppresses every
/// later live detection `b` for which `jaccard(&b.bbox, &current.bbox, iou)` returns
/// true (presumably when their IoU exceeds `iou`; confirm against `float::jaccard`).
///
/// If `iou >= 1.0`, the sorted pairs are returned with no suppression.
///
/// Suppression is tracked in a separate mask rather than by overwriting scores with
/// `SCORE::min_value()`. As a result, detections whose real score equals that minimum
/// are no longer silently dropped, consistent with the `iou >= 1.0` path.
#[doc(hidden)]
#[must_use]
pub fn nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
    iou: f32,
    mut boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
    boxes.par_sort_by(|a, b| b.0.score.cmp(&a.0.score));
    if iou >= 1.0 {
        return boxes;
    }
    let mut suppressed = vec![false; boxes.len()];
    for i in 0..boxes.len() {
        // A suppressed box never suppresses others.
        if suppressed[i] {
            continue;
        }
        for j in (i + 1)..boxes.len() {
            if !suppressed[j] && jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou) {
                suppressed[j] = true;
            }
        }
    }
    boxes
        .into_iter()
        .zip(suppressed)
        .filter_map(|(b, is_suppressed)| (!is_suppressed).then_some(b))
        .collect()
}
/// Greedy, per-class non-maximum suppression over quantized detections.
///
/// Boxes are first sorted by descending score. The sort is stable, so ties keep their
/// input order. Walking that order, each live box suppresses every later live box `b`
/// that has the same `label` and for which `jaccard(&b.bbox, &current.bbox, iou)`
/// returns true (presumably when their IoU exceeds `iou`; confirm against
/// `float::jaccard`). Boxes with different labels never suppress each other.
///
/// If `iou >= 1.0`, the sorted boxes are returned with no suppression.
///
/// Suppression is tracked in a separate mask rather than by overwriting scores with
/// `SCORE::min_value()`. As a result, detections whose real score equals that minimum
/// are no longer silently dropped, consistent with the `iou >= 1.0` path.
#[doc(hidden)]
#[must_use]
pub fn nms_class_aware_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
    iou: f32,
    mut boxes: Vec<DetectBoxQuantized<SCORE>>,
) -> Vec<DetectBoxQuantized<SCORE>> {
    boxes.par_sort_by(|a, b| b.score.cmp(&a.score));
    if iou >= 1.0 {
        return boxes;
    }
    let mut suppressed = vec![false; boxes.len()];
    for i in 0..boxes.len() {
        // A suppressed box never suppresses others.
        if suppressed[i] {
            continue;
        }
        for j in (i + 1)..boxes.len() {
            if !suppressed[j]
                && boxes[j].label == boxes[i].label
                && jaccard(&boxes[j].bbox, &boxes[i].bbox, iou)
            {
                suppressed[j] = true;
            }
        }
    }
    boxes
        .into_iter()
        .zip(suppressed)
        .filter_map(|(b, is_suppressed)| (!is_suppressed).then_some(b))
        .collect()
}
/// Greedy, per-class non-maximum suppression over `(detection, extra)` pairs.
///
/// `E` is an arbitrary payload carried through unchanged alongside each detection.
///
/// Pairs are first sorted by descending detection score. The sort is stable, so ties
/// keep their input order. Walking that order, each live detection suppresses every
/// later live detection `b` that has the same `label` and for which
/// `jaccard(&b.bbox, &current.bbox, iou)` returns true (presumably when their IoU
/// exceeds `iou`; confirm against `float::jaccard`).
///
/// If `iou >= 1.0`, the sorted pairs are returned with no suppression.
///
/// Suppression is tracked in a separate mask rather than by overwriting scores with
/// `SCORE::min_value()`. As a result, detections whose real score equals that minimum
/// are no longer silently dropped, consistent with the `iou >= 1.0` path.
#[doc(hidden)]
#[must_use]
pub fn nms_extra_class_aware_int<
    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
    E: Send + Sync,
>(
    iou: f32,
    mut boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
    boxes.par_sort_by(|a, b| b.0.score.cmp(&a.0.score));
    if iou >= 1.0 {
        return boxes;
    }
    let mut suppressed = vec![false; boxes.len()];
    for i in 0..boxes.len() {
        // A suppressed box never suppresses others.
        if suppressed[i] {
            continue;
        }
        for j in (i + 1)..boxes.len() {
            if !suppressed[j]
                && boxes[j].0.label == boxes[i].0.label
                && jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou)
            {
                suppressed[j] = true;
            }
        }
    }
    boxes
        .into_iter()
        .zip(suppressed)
        .filter_map(|(b, is_suppressed)| (!is_suppressed).then_some(b))
        .collect()
}
/// Maps a floating-point score threshold into the quantized integer domain `T`.
///
/// Computes `ceil(score / quant.scale + quant.zero_point)`, clamps it to
/// `[T::min_value(), T::max_value()]`, and converts it to `T`. Rounding up is
/// presumably chosen so that a quantized score `q >= result` dequantizes to at least
/// `score`. That reading assumes a positive scale and dequantization of the form
/// `(q - zero_point) * scale`; confirm against `Quantization`.
///
/// A zero `scale` cannot be inverted, so `T::max_value()` is returned in that case.
#[doc(hidden)]
pub fn quantize_score_threshold<T: PrimInt + AsPrimitive<f32>>(score: f32, quant: Quantization) -> T
where
    f32: AsPrimitive<T>,
{
    if quant.scale == 0.0 {
        return T::max_value();
    }
    let lowest: f32 = T::min_value().as_();
    let highest: f32 = T::max_value().as_();
    let unrounded = score / quant.scale + quant.zero_point as f32;
    unrounded.ceil().clamp(lowest, highest).as_()
}