1use crate::error::{VisionError, VisionResult};
20
/// An axis-aligned rectangle with an attached score.
///
/// [`BBox::area`] measures the box as `(x2 - x1) * (y2 - y1)`, so a box only
/// has a positive area when `x2 > x1` and `y2 > y1`. The constructor does not
/// enforce that ordering.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BBox {
    /// First x coordinate; the width is computed as `x2 - x1`.
    pub x1: f32,
    /// First y coordinate; the height is computed as `y2 - y1`.
    pub y1: f32,
    /// Second x coordinate.
    pub x2: f32,
    /// Second y coordinate.
    pub y2: f32,
    /// Score attached to the box; [`nms`] and [`soft_nms`] process boxes from
    /// the highest score downwards.
    pub score: f32,
}
40
41impl BBox {
42 #[must_use]
44 #[inline]
45 pub fn new(x1: f32, y1: f32, x2: f32, y2: f32, score: f32) -> Self {
46 Self {
47 x1,
48 y1,
49 x2,
50 y2,
51 score,
52 }
53 }
54
55 #[must_use]
57 #[inline]
58 pub fn area(&self) -> f32 {
59 let w = self.x2 - self.x1;
60 let h = self.y2 - self.y1;
61 if w <= 0.0 || h <= 0.0 { 0.0 } else { w * h }
62 }
63}
64
65#[must_use]
71pub fn iou(a: &BBox, b: &BBox) -> f32 {
72 let area_a = a.area();
73 let area_b = b.area();
74 if area_a <= 0.0 || area_b <= 0.0 {
75 return 0.0;
76 }
77 let ix1 = a.x1.max(b.x1);
78 let iy1 = a.y1.max(b.y1);
79 let ix2 = a.x2.min(b.x2);
80 let iy2 = a.y2.min(b.y2);
81 let iw = (ix2 - ix1).max(0.0);
82 let ih = (iy2 - iy1).max(0.0);
83 let inter = iw * ih;
84 let union = area_a + area_b - inter;
85 if union <= 0.0 {
86 return 0.0;
87 }
88 (inter / union).clamp(0.0, 1.0)
89}
90
91pub fn nms(boxes: &[BBox], iou_threshold: f32) -> VisionResult<Vec<usize>> {
107 if !(0.0..=1.0).contains(&iou_threshold) || !iou_threshold.is_finite() {
108 return Err(VisionError::NonFinite("nms iou_threshold"));
109 }
110 if boxes.is_empty() {
111 return Ok(Vec::new());
112 }
113
114 let mut order: Vec<usize> = (0..boxes.len()).collect();
115 order.sort_by(|&a, &b| {
116 boxes[b]
117 .score
118 .partial_cmp(&boxes[a].score)
119 .unwrap_or(std::cmp::Ordering::Equal)
120 .then(a.cmp(&b))
121 });
122
123 let mut kept: Vec<usize> = Vec::new();
124 for &idx in &order {
125 let candidate = &boxes[idx];
126 let mut suppressed = false;
127 for &k in &kept {
128 if iou(candidate, &boxes[k]) > iou_threshold {
129 suppressed = true;
130 break;
131 }
132 }
133 if !suppressed {
134 kept.push(idx);
135 }
136 }
137 Ok(kept)
138}
139
140pub fn soft_nms(
157 boxes: &[BBox],
158 sigma: f32,
159 score_threshold: f32,
160) -> VisionResult<Vec<(usize, f32)>> {
161 if sigma <= 0.0 || !sigma.is_finite() {
162 return Err(VisionError::NonFinite("soft_nms sigma"));
163 }
164 if !score_threshold.is_finite() {
165 return Err(VisionError::NonFinite("soft_nms score_threshold"));
166 }
167 if boxes.is_empty() {
168 return Ok(Vec::new());
169 }
170
171 let inv_sigma = 1.0_f32 / sigma;
172 let mut pool: Vec<(usize, f32)> = boxes.iter().map(|b| (b.score, b)).enumerate().fold(
175 Vec::with_capacity(boxes.len()),
176 |mut acc, (i, (s, _))| {
177 acc.push((i, s));
178 acc
179 },
180 );
181
182 let mut out: Vec<(usize, f32)> = Vec::new();
183
184 while !pool.is_empty() {
185 let (max_pos, max_score) = pool.iter().enumerate().fold(
187 (0usize, f32::NEG_INFINITY),
188 |(best_i, best_s), (i, &(_, s))| {
189 if s > best_s { (i, s) } else { (best_i, best_s) }
190 },
191 );
192
193 if max_score <= score_threshold {
194 break;
196 }
197
198 let pivot = pool.swap_remove(max_pos);
199 out.push(pivot);
200
201 let pivot_box = &boxes[pivot.0];
202 for entry in pool.iter_mut() {
203 let ov = iou(pivot_box, &boxes[entry.0]);
204 let decay = (-(ov * ov) * inv_sigma).exp();
205 entry.1 *= decay;
206 }
207 }
208
209 Ok(out)
210}
211
// Unit tests for `iou`, `nms` and `soft_nms`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for `BBox::new` that keeps the fixtures below compact.
    fn b(x1: f32, y1: f32, x2: f32, y2: f32, s: f32) -> BBox {
        BBox::new(x1, y1, x2, y2, s)
    }

    #[test]
    fn iou_identical_is_1() {
        let a = b(0.0, 0.0, 10.0, 10.0, 0.9);
        assert!((iou(&a, &a) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn iou_disjoint_is_0() {
        let a = b(0.0, 0.0, 1.0, 1.0, 0.9);
        let c = b(2.0, 2.0, 3.0, 3.0, 0.8);
        assert!(iou(&a, &c).abs() < 1e-7);
    }

    #[test]
    fn iou_half_overlap() {
        let a = b(0.0, 0.0, 10.0, 10.0, 0.9);
        // Intersection is 5 x 10 = 50, union is 100 + 100 - 50 = 150.
        let c = b(5.0, 0.0, 15.0, 10.0, 0.8);
        assert!((iou(&a, &c) - 1.0 / 3.0).abs() < 1e-5);
    }

    #[test]
    fn nms_keeps_highest() {
        let boxes = vec![b(0.0, 0.0, 10.0, 10.0, 0.4), b(1.0, 1.0, 10.0, 10.0, 0.9)];
        let kept = nms(&boxes, 0.3).expect("ok");
        assert_eq!(kept, vec![1]);
    }

    #[test]
    fn nms_suppresses_overlap() {
        // Boxes 0 and 1 overlap almost completely; box 2 is far away.
        let boxes = vec![
            b(0.0, 0.0, 10.0, 10.0, 0.9),
            b(0.5, 0.5, 10.5, 10.5, 0.8),
            b(50.0, 50.0, 60.0, 60.0, 0.7),
        ];
        let kept = nms(&boxes, 0.3).expect("ok");
        assert_eq!(kept, vec![0, 2]);
    }

    #[test]
    fn nms_keeps_disjoint() {
        let boxes = vec![b(0.0, 0.0, 1.0, 1.0, 0.9), b(5.0, 5.0, 6.0, 6.0, 0.8)];
        let kept = nms(&boxes, 0.5).expect("ok");
        assert_eq!(kept, vec![0, 1]);
    }

    #[test]
    fn soft_nms_decays_scores() {
        let boxes = vec![b(0.0, 0.0, 10.0, 10.0, 0.9), b(1.0, 1.0, 10.0, 10.0, 0.8)];
        let out = soft_nms(&boxes, 0.5, 0.0).expect("ok");
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].0, 0);
        assert_eq!(out[1].0, 1);
        assert!(out[1].1 < 0.8, "expected decay, got {}", out[1].1);
    }

    #[test]
    fn empty_boxes() {
        let boxes: Vec<BBox> = Vec::new();
        assert_eq!(nms(&boxes, 0.5).expect("ok"), Vec::<usize>::new());
        assert_eq!(
            soft_nms(&boxes, 0.5, 0.0).expect("ok"),
            Vec::<(usize, f32)>::new()
        );
    }

    #[test]
    fn threshold_1_keeps_all() {
        // Suppression requires IoU strictly above the threshold, so even
        // identical boxes survive at 1.0.
        let boxes = vec![b(0.0, 0.0, 1.0, 1.0, 0.9), b(0.0, 0.0, 1.0, 1.0, 0.8)];
        let kept = nms(&boxes, 1.0).expect("ok");
        assert_eq!(kept.len(), 2);
    }

    #[test]
    fn soft_nms_threshold_filters() {
        // The second box coincides with the first, so its score decays to
        // 0.5 * exp(-1 / 0.5), which is below the 0.4 cut-off.
        let boxes = vec![b(0.0, 0.0, 10.0, 10.0, 1.0), b(0.0, 0.0, 10.0, 10.0, 0.5)];
        let out = soft_nms(&boxes, 0.5, 0.4).expect("ok");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].0, 0);
    }

    #[test]
    fn nms_invalid_threshold_errors() {
        let boxes = vec![b(0.0, 0.0, 1.0, 1.0, 0.9)];
        assert!(matches!(nms(&boxes, 1.5), Err(VisionError::NonFinite(_))));
        assert!(matches!(nms(&boxes, -0.1), Err(VisionError::NonFinite(_))));
        assert!(matches!(
            nms(&boxes, f32::NAN),
            Err(VisionError::NonFinite(_))
        ));
    }

    #[test]
    fn soft_nms_invalid_sigma_errors() {
        let boxes = vec![b(0.0, 0.0, 1.0, 1.0, 0.9)];
        assert!(matches!(
            soft_nms(&boxes, 0.0, 0.0),
            Err(VisionError::NonFinite(_))
        ));
        assert!(matches!(
            soft_nms(&boxes, -1.0, 0.0),
            Err(VisionError::NonFinite(_))
        ));
    }

    #[test]
    fn nms_returns_descending_score_order() {
        let boxes = vec![
            b(0.0, 0.0, 1.0, 1.0, 0.3),
            b(5.0, 0.0, 6.0, 1.0, 0.9),
            b(10.0, 0.0, 11.0, 1.0, 0.5),
        ];
        let kept = nms(&boxes, 0.5).expect("ok");
        assert_eq!(kept, vec![1, 2, 0]);
    }

    #[test]
    fn soft_nms_disjoint_no_decay() {
        let boxes = vec![b(0.0, 0.0, 1.0, 1.0, 0.7), b(5.0, 5.0, 6.0, 6.0, 0.6)];
        let out = soft_nms(&boxes, 0.5, 0.0).expect("ok");
        assert_eq!(out.len(), 2);
        assert!((out[0].1 - 0.7).abs() < 1e-5);
        assert!((out[1].1 - 0.6).abs() < 1e-5);
    }

    #[test]
    fn iou_degenerate_is_0() {
        let a = b(0.0, 0.0, 10.0, 10.0, 0.9);
        // A zero-area box has no overlap with anything.
        let degenerate = b(5.0, 5.0, 5.0, 5.0, 0.8);
        assert!(iou(&a, &degenerate).abs() < 1e-7);
    }
}