1use std::collections::HashMap;
2
3use crate::Face;
4
/// Configuration for greedy non-maximum suppression (NMS) over detected faces.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct Nms {
    /// Overlap threshold. A candidate face survives only if its overlap with
    /// every already-kept face is strictly below this value.
    pub iou_threshold: f32,
}
10
11impl Default for Nms {
12 fn default() -> Self {
13 Self { iou_threshold: 0.3 }
14 }
15}
16
17impl Nms {
18 pub fn suppress_non_maxima(&self, mut faces: Vec<Face>) -> Vec<Face> {
28 faces.sort_by(|a, b| a.confidence.partial_cmp(&b.confidence).unwrap());
29
30 let mut faces_map = HashMap::new();
31 faces.iter().rev().enumerate().for_each(|(i, face)| {
32 faces_map.insert(i, face);
33 });
34
35 let mut nms_faces = Vec::with_capacity(faces.len());
36 let mut count = 0;
37 while !faces_map.is_empty() {
38 if let Some((_, face)) = faces_map.remove_entry(&count) {
39 nms_faces.push(face.clone());
40 faces_map.retain(|_, face2| face.rect.iou(&face2.rect) < self.iou_threshold);
42 }
43 count += 1;
44 }
45
46 nms_faces
47 }
48
49 pub fn suppress_non_maxima_min(&self, mut faces: Vec<Face>) -> Vec<Face> {
59 faces.sort_by(|a, b| a.confidence.partial_cmp(&b.confidence).unwrap());
60
61 let mut faces_map = HashMap::new();
62 faces.iter().rev().enumerate().for_each(|(i, face)| {
63 faces_map.insert(i, face);
64 });
65
66 let mut nms_faces = Vec::with_capacity(faces.len());
67 let mut count = 0;
68 while !faces_map.is_empty() {
69 if let Some((_, face)) = faces_map.remove_entry(&count) {
70 nms_faces.push(face.clone());
71 faces_map.retain(|_, face2| face.rect.iou_min(&face2.rect) < self.iou_threshold);
73 }
74 count += 1;
75 }
76
77 nms_faces
78 }
79}
80
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::{Face, Rect};

    /// Builds a square face box anchored at `(x, y)` with no landmarks.
    fn face(x: f32, y: f32, size: f32, confidence: f32) -> Face {
        Face {
            rect: Rect {
                x,
                y,
                width: size,
                height: size,
            },
            confidence,
            landmarks: None,
        }
    }

    /// Three identical unit boxes at decreasing confidence.
    fn identical_faces() -> Vec<Face> {
        vec![
            face(0.0, 0.0, 1.0, 0.9),
            face(0.0, 0.0, 1.0, 0.8),
            face(0.0, 0.0, 1.0, 0.7),
        ]
    }

    #[rstest]
    fn test_nms() {
        let nms = Nms::default();

        let faces = nms.suppress_non_maxima(identical_faces());

        assert_eq!(faces.len(), 1);
        assert_eq!(faces[0].confidence, 0.9);
    }

    #[rstest]
    fn test_nms_min() {
        let nms = Nms::default();

        let faces = nms.suppress_non_maxima_min(identical_faces());

        assert_eq!(faces.len(), 1);
        assert_eq!(faces[0].confidence, 0.9);
    }

    #[rstest]
    fn test_nms_nested_boxes() {
        let nms = Nms::default();
        // A 1x1 box fully inside a 2x2 box.
        let faces = vec![face(0.0, 0.0, 2.0, 0.9), face(0.0, 0.0, 1.0, 0.8)];

        // IoU is 1/4, below the default 0.3 threshold, so both survive.
        assert_eq!(nms.suppress_non_maxima(faces.clone()).len(), 2);

        // Measured against the smaller box the overlap is total, so the
        // lower-confidence face is suppressed.
        let kept = nms.suppress_non_maxima_min(faces);
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0].confidence, 0.9);
    }
}