use scirs2_core::ndarray::{Array2, ArrayView2};
use std::cmp::Ordering;
#[derive(Debug, Clone, Copy)]
pub struct BoundingBox {
pub x: f32,
pub y: f32,
pub width: f32,
pub height: f32,
pub score: f32,
pub class_id: Option<usize>,
}
impl BoundingBox {
pub fn new(x: f32, y: f32, width: f32, height: f32, score: f32) -> Self {
Self {
x,
y,
width,
height,
score,
class_id: None,
}
}
pub fn with_class(
x: f32,
y: f32,
width: f32,
height: f32,
score: f32,
class_id: usize,
) -> Self {
Self {
x,
y,
width,
height,
score,
class_id: Some(class_id),
}
}
pub fn area(&self) -> f32 {
self.width * self.height
}
pub fn iou(&self, other: &BoundingBox) -> f32 {
let x1 = self.x.max(other.x);
let y1 = self.y.max(other.y);
let x2 = (self.x + self.width).min(other.x + other.width);
let y2 = (self.y + self.height).min(other.y + other.height);
if x2 < x1 || y2 < y1 {
return 0.0;
}
let intersection = (x2 - x1) * (y2 - y1);
let union = self.area() + other.area() - intersection;
if union > 0.0 {
intersection / union
} else {
0.0
}
}
}
#[allow(dead_code)]
pub fn nms_boxes(_boxes: &[BoundingBox], iouthreshold: f32) -> Vec<usize> {
if _boxes.is_empty() {
return vec![];
}
let mut indices: Vec<usize> = (0.._boxes.len()).collect();
indices.sort_by(|&a, &b| {
_boxes[b]
.score
.partial_cmp(&_boxes[a].score)
.unwrap_or(Ordering::Equal)
});
let mut keep = Vec::new();
let mut suppressed = vec![false; _boxes.len()];
for &idx in &indices {
if suppressed[idx] {
continue;
}
keep.push(idx);
for &other_idx in &indices {
if other_idx != idx && !suppressed[other_idx] {
let iou = _boxes[idx].iou(&_boxes[other_idx]);
if iou > iouthreshold {
suppressed[other_idx] = true;
}
}
}
}
keep
}
#[allow(dead_code)]
pub fn nms_boxes_class_aware(_boxes: &[BoundingBox], iouthreshold: f32) -> Vec<usize> {
if _boxes.is_empty() {
return vec![];
}
let mut class_groups: std::collections::HashMap<Option<usize>, Vec<usize>> =
std::collections::HashMap::new();
for (idx, box_) in _boxes.iter().enumerate() {
class_groups.entry(box_.class_id).or_default().push(idx);
}
let mut keep = Vec::new();
for indices in class_groups.values() {
let class_boxes: Vec<BoundingBox> = indices.iter().map(|&idx| _boxes[idx]).collect();
let kept_in_class = nms_boxes(&class_boxes, iouthreshold);
for kept_idx in kept_in_class {
keep.push(indices[kept_idx]);
}
}
keep.sort_unstable();
keep
}
#[allow(dead_code)]
pub fn soft_nms(
boxes: &mut [BoundingBox],
iou_threshold: f32,
sigma: f32,
score_threshold: f32,
) -> Vec<usize> {
if boxes.is_empty() {
return vec![];
}
let mut indices: Vec<usize> = (0..boxes.len()).collect();
let mut keep = Vec::new();
while !indices.is_empty() {
let max_idx = indices
.iter()
.max_by(|&&a, &&b| {
boxes[a]
.score
.partial_cmp(&boxes[b].score)
.unwrap_or(Ordering::Equal)
})
.copied()
.expect("Operation failed");
let max_pos = indices
.iter()
.position(|&x| x == max_idx)
.expect("Operation failed");
indices.remove(max_pos);
if boxes[max_idx].score >= score_threshold {
keep.push(max_idx);
}
let indices_to_update = indices.clone();
for &idx in &indices_to_update {
let iou = boxes[max_idx].iou(&boxes[idx]);
if iou > iou_threshold {
boxes[idx].score *= (-iou * iou / sigma).exp();
if boxes[idx].score < score_threshold {
indices.retain(|&x| x != idx);
}
}
}
}
keep
}
#[allow(dead_code)]
pub fn nms_2d(
response: &ArrayView2<f32>,
window_size: usize,
threshold: f32,
) -> Vec<(usize, usize, f32)> {
let (height, width) = response.dim();
let half_window = window_size / 2;
let mut peaks = Vec::new();
for y in half_window..(height - half_window) {
for x in half_window..(width - half_window) {
let value = response[[y, x]];
if value < threshold {
continue;
}
let mut is_max = true;
'window: for dy in 0..window_size {
for dx in 0..window_size {
let ny = y + dy - half_window;
let nx = x + dx - half_window;
if (ny != y || nx != x) && response[[ny, nx]] >= value {
is_max = false;
break 'window;
}
}
}
if is_max {
peaks.push((x, y, value));
}
}
}
peaks
}
#[allow(dead_code)]
pub fn fast_nms_2d(
response: &ArrayView2<f32>,
window_size: usize,
threshold: f32,
) -> Vec<(usize, usize, f32)> {
let (height, width) = response.dim();
let half_window = window_size / 2;
let mut max_filtered = Array2::zeros((height, width));
for y in 0..height {
for x in 0..width {
let mut max_val = f32::NEG_INFINITY;
for dy in y.saturating_sub(half_window)..=(y + half_window).min(height - 1) {
for dx in x.saturating_sub(half_window)..=(x + half_window).min(width - 1) {
max_val = max_val.max(response[[dy, dx]]);
}
}
max_filtered[[y, x]] = max_val;
}
}
let mut peaks = Vec::new();
for y in 0..height {
for x in 0..width {
let value = response[[y, x]];
if value >= threshold && value == max_filtered[[y, x]] {
peaks.push((x, y, value));
}
}
}
peaks
}
#[allow(dead_code)]
pub fn oriented_nms(
magnitude: &ArrayView2<f32>,
orientation: &ArrayView2<f32>,
threshold: f32,
) -> Array2<f32> {
let (height, width) = magnitude.dim();
let mut suppressed = Array2::zeros((height, width));
for y in 1..(height - 1) {
for x in 1..(width - 1) {
let mag = magnitude[[y, x]];
if mag < threshold {
continue;
}
let angle = orientation[[y, x]];
let direction =
((angle + std::f32::consts::PI) / (std::f32::consts::PI / 4.0)) as i32 % 8;
let (dx1, dy1, dx2, dy2) = match direction {
0 | 4 => (1, 0, -1, 0), 1 | 5 => (1, 1, -1, -1), 2 | 6 => (0, 1, 0, -1), 3 | 7 => (-1, 1, 1, -1), _ => (1, 0, -1, 0),
};
let mag1 = magnitude[[(y as i32 + dy1) as usize, (x as i32 + dx1) as usize]];
let mag2 = magnitude[[(y as i32 + dy2) as usize, (x as i32 + dx2) as usize]];
if mag >= mag1 && mag >= mag2 {
suppressed[[y, x]] = mag;
}
}
}
suppressed
}
#[allow(dead_code)]
pub fn scale_space_nms(
responses: &[Array2<f32>],
window_size: usize,
threshold: f32,
) -> Vec<(usize, usize, usize, f32)> {
let mut all_peaks = Vec::new();
for (scale, response) in responses.iter().enumerate() {
let peaks = nms_2d(&response.view(), window_size, threshold);
for (x, y, value) in peaks {
let mut is_scale_max = true;
if scale > 0 && responses[scale - 1][[y, x]] >= value {
is_scale_max = false;
}
if scale < responses.len() - 1 && responses[scale + 1][[y, x]] >= value {
is_scale_max = false;
}
if is_scale_max {
all_peaks.push((x, y, scale, value));
}
}
}
all_peaks
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_bounding_box_iou() {
let box1 = BoundingBox::new(0.0, 0.0, 10.0, 10.0, 0.9);
let box2 = BoundingBox::new(5.0, 5.0, 10.0, 10.0, 0.8);
let iou = box1.iou(&box2);
assert!(iou > 0.0 && iou < 1.0);
let box3 = BoundingBox::new(20.0, 20.0, 10.0, 10.0, 0.7);
assert_eq!(box1.iou(&box3), 0.0);
}
#[test]
fn test_nms_boxes() {
let boxes = vec![
BoundingBox::new(0.0, 0.0, 10.0, 10.0, 0.9),
BoundingBox::new(1.0, 1.0, 10.0, 10.0, 0.8),
BoundingBox::new(20.0, 20.0, 10.0, 10.0, 0.95),
];
let keep = nms_boxes(&boxes, 0.5);
assert_eq!(keep.len(), 2); }
#[test]
fn test_nms_2d() {
let mut response = Array2::zeros((10, 10));
response[[5, 5]] = 1.0;
response[[2, 2]] = 0.8;
response[[2, 3]] = 0.6;
let peaks = nms_2d(&response.view(), 3, 0.5);
assert_eq!(peaks.len(), 2);
}
#[test]
fn test_soft_nms() {
let mut boxes = vec![
BoundingBox::new(0.0, 0.0, 10.0, 10.0, 0.9),
BoundingBox::new(1.0, 1.0, 10.0, 10.0, 0.8),
BoundingBox::new(20.0, 20.0, 10.0, 10.0, 0.95),
];
let keep = soft_nms(&mut boxes, 0.5, 0.5, 0.3);
assert!(keep.len() >= 2);
assert!(boxes[1].score < 0.8);
}
}