1use crate::{BoundingBox, DetectError, Detection};
2
3pub fn iou(a: BoundingBox, b: BoundingBox) -> f32 {
5 let inter_x1 = a.x1.max(b.x1);
6 let inter_y1 = a.y1.max(b.y1);
7 let inter_x2 = a.x2.min(b.x2);
8 let inter_y2 = a.y2.min(b.y2);
9
10 let inter_w = (inter_x2 - inter_x1).max(0.0);
11 let inter_h = (inter_y2 - inter_y1).max(0.0);
12 let inter = inter_w * inter_h;
13 let union = a.area() + b.area() - inter;
14 if union <= 0.0 { 0.0 } else { inter / union }
15}
16
17pub fn non_max_suppression(
19 detections: &[Detection],
20 iou_threshold: f32,
21 max_detections: usize,
22) -> Vec<Detection> {
23 let mut sorted = detections
24 .iter()
25 .copied()
26 .filter(is_finite_detection)
27 .collect::<Vec<_>>();
28 sorted.sort_by(|a, b| b.score.total_cmp(&a.score));
29
30 let mut selected: Vec<Detection> = Vec::new();
31 for candidate in sorted {
32 if selected.len() >= max_detections {
33 break;
34 }
35 let mut suppressed = false;
36 for chosen in &selected {
37 if chosen.class_id == candidate.class_id
38 && iou(chosen.bbox, candidate.bbox) > iou_threshold
39 {
40 suppressed = true;
41 break;
42 }
43 }
44 if !suppressed {
45 selected.push(candidate);
46 }
47 }
48 selected
49}
50
51pub(crate) fn validate_nms_args(
52 iou_threshold: f32,
53 max_detections: usize,
54) -> Result<(), DetectError> {
55 if !iou_threshold.is_finite() || !(0.0..=1.0).contains(&iou_threshold) {
56 return Err(DetectError::InvalidIouThreshold { iou_threshold });
57 }
58 if max_detections == 0 {
59 return Err(DetectError::InvalidMaxDetections { max_detections });
60 }
61 Ok(())
62}
63
64fn is_finite_detection(detection: &Detection) -> bool {
65 detection.score.is_finite()
66 && detection.bbox.x1.is_finite()
67 && detection.bbox.y1.is_finite()
68 && detection.bbox.x2.is_finite()
69 && detection.bbox.y2.is_finite()
70}
71
72pub fn soft_nms(detections: &mut Vec<Detection>, sigma: f32, score_threshold: f32) {
79 detections.retain(is_finite_detection);
81
82 let mut i = 0;
85 while i < detections.len() {
86 let mut max_idx = i;
88 for j in (i + 1)..detections.len() {
89 if detections[j].score > detections[max_idx].score {
90 max_idx = j;
91 }
92 }
93 detections.swap(i, max_idx);
94
95 let current = detections[i];
97 let mut j = i + 1;
98 while j < detections.len() {
99 let overlap = iou(current.bbox, detections[j].bbox);
100 detections[j].score *= (-overlap * overlap / sigma).exp();
101 if detections[j].score < score_threshold {
102 detections.swap_remove(j);
103 } else {
105 j += 1;
106 }
107 }
108 i += 1;
109 }
110}
111
112pub fn batched_nms(detections: &[Detection], iou_threshold: f32) -> Vec<Detection> {
118 use std::collections::HashMap;
119
120 let mut by_class: HashMap<usize, Vec<Detection>> = HashMap::new();
121 for det in detections {
122 by_class.entry(det.class_id).or_default().push(*det);
123 }
124
125 let mut results: Vec<Detection> = Vec::new();
126 for class_dets in by_class.values() {
127 let kept = non_max_suppression(class_dets, iou_threshold, class_dets.len());
128 results.extend(kept);
129 }
130
131 results.sort_by(|a, b| b.score.total_cmp(&a.score));
132 results
133}