pub mod face;
pub mod logo;
pub mod nms;
pub mod object;
pub mod pyramid;
pub mod text;
pub use face::{FaceDetection, FaceDetector};
pub use logo::{LogoDetection, LogoDetector};
pub use nms::{non_maximum_suppression, soft_nms_boxes, Detection, DetectionBox};
pub use object::{ObjectDetection, ObjectDetector, ObjectType};
pub use text::{TextDetection, TextDetector};
use crate::common::Rect;
pub fn nms<T, B, C>(detections: &mut Vec<T>, bbox_fn: B, conf_fn: C, iou_threshold: f32)
where
T: Clone,
B: Fn(&T) -> Rect,
C: Fn(&T) -> f32,
{
detections.sort_by(|a, b| {
conf_fn(b)
.partial_cmp(&conf_fn(a))
.unwrap_or(std::cmp::Ordering::Equal)
});
let n = detections.len();
let mut suppressed = vec![false; n];
for i in 0..n {
if suppressed[i] {
continue;
}
let bbox_i = bbox_fn(&detections[i]);
for j in (i + 1)..n {
if suppressed[j] {
continue;
}
let bbox_j = bbox_fn(&detections[j]);
if bbox_i.iou(&bbox_j) > iou_threshold {
suppressed[j] = true;
}
}
}
let mut out = Vec::with_capacity(n);
for (i, det) in detections.drain(..).enumerate() {
if !suppressed[i] {
out.push(det);
}
}
*detections = out;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal detection record used to drive the generic `nms`.
    #[derive(Clone)]
    struct Det {
        bbox: Rect,
        conf: f32,
    }

    /// Surviving confidences, in output order.
    fn confs(dets: &[Det]) -> Vec<f32> {
        dets.iter().map(|d| d.conf).collect()
    }

    #[test]
    fn test_nms_removes_overlapping() {
        let mut dets = vec![
            Det {
                bbox: Rect::new(0.0, 0.0, 100.0, 100.0),
                conf: 0.9,
            },
            Det {
                bbox: Rect::new(5.0, 5.0, 100.0, 100.0),
                conf: 0.7,
            },
            Det {
                bbox: Rect::new(200.0, 200.0, 50.0, 50.0),
                conf: 0.8,
            },
        ];
        nms(&mut dets, |d| d.bbox, |d| d.conf, 0.5);
        assert_eq!(dets.len(), 2, "should keep two non-overlapping detections");
        assert!((dets[0].conf - 0.9).abs() < f32::EPSILON);
        // Survivors come back highest confidence first.
        assert!((dets[1].conf - 0.8).abs() < f32::EPSILON);
    }

    #[test]
    fn test_nms_empty() {
        let mut dets: Vec<Det> = Vec::new();
        nms(&mut dets, |d| d.bbox, |d| d.conf, 0.5);
        assert!(dets.is_empty());
    }

    #[test]
    fn test_nms_single_detection_kept() {
        let mut dets = vec![Det {
            bbox: Rect::new(0.0, 0.0, 10.0, 10.0),
            conf: 0.4,
        }];
        nms(&mut dets, |d| d.bbox, |d| d.conf, 0.5);
        assert_eq!(confs(&dets), vec![0.4]);
    }

    #[test]
    fn test_nms_no_overlap() {
        let mut dets = vec![
            Det {
                bbox: Rect::new(0.0, 0.0, 10.0, 10.0),
                conf: 0.9,
            },
            Det {
                bbox: Rect::new(100.0, 100.0, 10.0, 10.0),
                conf: 0.7,
            },
        ];
        nms(&mut dets, |d| d.bbox, |d| d.conf, 0.5);
        assert_eq!(dets.len(), 2, "non-overlapping boxes should both survive");
    }

    #[test]
    fn test_nms_ranks_unsorted_input() {
        // The best detection is last in the input; it must win over the
        // identical lower-confidence box and lead the output.
        let mut dets = vec![
            Det {
                bbox: Rect::new(0.0, 0.0, 10.0, 10.0),
                conf: 0.3,
            },
            Det {
                bbox: Rect::new(100.0, 100.0, 10.0, 10.0),
                conf: 0.6,
            },
            Det {
                bbox: Rect::new(0.0, 0.0, 10.0, 10.0),
                conf: 0.9,
            },
        ];
        nms(&mut dets, |d| d.bbox, |d| d.conf, 0.5);
        assert_eq!(confs(&dets), vec![0.9, 0.6]);
    }
}