rust_faces/
nms.rs

1use std::collections::HashMap;
2
3use crate::Face;
4
/// Configuration for greedy non-maximum suppression (NMS) over detected faces.
///
/// The [`Default`] configuration uses an overlap threshold of `0.3`.
#[derive(Copy, Clone, Debug)]
pub struct Nms {
    /// Overlap score at or above which a lower-confidence face is discarded
    /// in favour of an already-kept, higher-confidence one.
    pub iou_threshold: f32,
}

impl Default for Nms {
    /// Returns an `Nms` with `iou_threshold` set to `0.3`.
    fn default() -> Self {
        Nms { iou_threshold: 0.3 }
    }
}
16
17impl Nms {
18    /// Suppress non-maxima faces.
19    ///
20    /// # Arguments
21    ///
22    /// * `faces` - Faces to suppress.
23    ///
24    /// # Returns
25    ///
26    /// * `Vec<Face>` - Suppressed faces.
27    pub fn suppress_non_maxima(&self, mut faces: Vec<Face>) -> Vec<Face> {
28        faces.sort_by(|a, b| a.confidence.partial_cmp(&b.confidence).unwrap());
29
30        let mut faces_map = HashMap::new();
31        faces.iter().rev().enumerate().for_each(|(i, face)| {
32            faces_map.insert(i, face);
33        });
34
35        let mut nms_faces = Vec::with_capacity(faces.len());
36        let mut count = 0;
37        while !faces_map.is_empty() {
38            if let Some((_, face)) = faces_map.remove_entry(&count) {
39                nms_faces.push(face.clone());
40                //faces_map.retain(|_, face2| face.rect.iou(&face2.rect) < self.iou_threshold);
41                faces_map.retain(|_, face2| face.rect.iou(&face2.rect) < self.iou_threshold);
42            }
43            count += 1;
44        }
45
46        nms_faces
47    }
48
49    /// Suppress non-maxima faces.
50    ///
51    /// # Arguments
52    ///
53    /// * `faces` - Faces to suppress.
54    ///
55    /// # Returns
56    ///
57    /// * `Vec<Face>` - Suppressed faces.
58    pub fn suppress_non_maxima_min(&self, mut faces: Vec<Face>) -> Vec<Face> {
59        faces.sort_by(|a, b| a.confidence.partial_cmp(&b.confidence).unwrap());
60
61        let mut faces_map = HashMap::new();
62        faces.iter().rev().enumerate().for_each(|(i, face)| {
63            faces_map.insert(i, face);
64        });
65
66        let mut nms_faces = Vec::with_capacity(faces.len());
67        let mut count = 0;
68        while !faces_map.is_empty() {
69            if let Some((_, face)) = faces_map.remove_entry(&count) {
70                nms_faces.push(face.clone());
71                //faces_map.retain(|_, face2| face.rect.iou(&face2.rect) < self.iou_threshold);
72                faces_map.retain(|_, face2| face.rect.iou_min(&face2.rect) < self.iou_threshold);
73            }
74            count += 1;
75        }
76
77        nms_faces
78    }
79}
80
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::{Face, Rect};

    #[rstest]
    fn test_nms() {
        // Three identical unit boxes at the origin; only the most confident should survive.
        let unit_face_with = |confidence| Face {
            rect: Rect {
                x: 0.0,
                y: 0.0,
                width: 1.0,
                height: 1.0,
            },
            confidence,
            landmarks: None,
        };
        let candidates = vec![unit_face_with(0.9), unit_face_with(0.8), unit_face_with(0.7)];

        let survivors = Nms::default().suppress_non_maxima(candidates);

        assert_eq!(survivors.len(), 1);
        assert_eq!(survivors[0].confidence, 0.9);
    }
}