use ndarray::{ArrayView, IxDyn, s};
use crate::{YoloResult, BoundingBox};
use core::cmp::Ordering;
fn iou(a: &YoloResult, b: &YoloResult) -> f32 {
let xx1 = a.bbox.x1.max(b.bbox.x1);
let yy1 = a.bbox.y1.max(b.bbox.y1);
let xx2 = a.bbox.x2.min(b.bbox.x2);
let yy2 = a.bbox.y2.min(b.bbox.y2);
let w = (xx2 - xx1).max(0.0);
let h = (yy2 - yy1).max(0.0);
let inter = w * h;
let area_a = (a.bbox.x2 - a.bbox.x1) * (a.bbox.y2 - a.bbox.y1);
let area_b = (b.bbox.x2 - b.bbox.x1) * (b.bbox.y2 - b.bbox.y1);
inter / (area_a + area_b - inter + 1e-6)
}
/// Greedy non-maximum suppression over `dets`.
///
/// Repeatedly takes the remaining detection with the highest `confidence` and keeps it. It
/// then discards every remaining detection whose IoU with the kept box is not below
/// `iou_thresh`.
///
/// Suppression is class-agnostic: `class_id` is never consulted, so overlapping boxes of
/// different classes also suppress each other.
///
/// Returns the kept detections in descending-confidence order.
fn nms(mut dets: Vec<YoloResult>, iou_thresh: f32) -> Vec<YoloResult> {
    // Sort ascending so that `pop()` yields the highest-confidence candidate.
    // `total_cmp` is a total order even with NaN. `partial_cmp(..).unwrap_or(Equal)` is
    // not, and `sort_by` documents that it may panic on a non-total comparator.
    dets.sort_by(|a, b| a.confidence.total_cmp(&b.confidence));

    let mut keep = Vec::new();
    while let Some(best) = dets.pop() {
        // `retain` preserves order, so `dets` stays sorted for the next `pop()`.
        dets.retain(|d| iou(&best, d) < iou_thresh);
        keep.push(best);
    }
    keep
}
/// Decodes raw YOLO detections from `output` and filters them with non-maximum suppression.
///
/// Layout of `output`, as indexed by this function (`output[[i, channel, 0]]`):
/// - axis 0 enumerates candidates;
/// - axis 1 holds each candidate's channels: `0..4` are `(cx, cy, w, h)` and `4..` are
///   per-class scores;
/// - only index `0` of axis 2 is read.
///
/// NOTE(review): any extent beyond 1 on axis 2 is ignored. Confirm that callers pass an
/// `(N, C, 1)`-shaped view.
///
/// For each candidate, the highest class score is taken as the confidence, and its channel
/// offset (relative to channel 4) becomes `class_id`. Candidates whose confidence is
/// strictly greater than `conf_thresh` are converted from centre/size form to corner form.
/// The survivors are then handed to `nms` with `iou_thresh`.
///
/// # Panics
/// Indexing panics if `output` does not match the layout above. Examples: fewer than two
/// dimensions, or (when there is at least one candidate) fewer than 4 channels on axis 1.
pub fn yolo_nms(output: ArrayView<f32, IxDyn>, conf_thresh: f32, iou_thresh: f32) -> Vec<YoloResult> {
    let dims = output.shape();
    let num_candidates = dims[0];
    let num_channels = dims[1];

    let candidates: Vec<YoloResult> = (0..num_candidates)
        .filter_map(|row| {
            let center_x = output[[row, 0, 0]];
            let center_y = output[[row, 1, 0]];
            let half_w = output[[row, 2, 0]] / 2.0;
            let half_h = output[[row, 3, 0]] / 2.0;

            // Best-scoring class. A row with no class channels yields `None` and is skipped.
            let (class_id, &confidence) = output
                .slice(s![row, 4..num_channels, 0])
                .iter()
                .enumerate()
                .max_by(|lhs, rhs| lhs.1.partial_cmp(rhs.1).unwrap_or(Ordering::Equal))?;

            (confidence > conf_thresh).then(|| YoloResult {
                bbox: BoundingBox {
                    x1: center_x - half_w,
                    y1: center_y - half_h,
                    x2: center_x + half_w,
                    y2: center_y + half_h,
                },
                confidence,
                class_id,
            })
        })
        .collect();

    nms(candidates, iou_thresh)
}