//! BlazeFace face detector (`rust_faces/blazeface.rs`).

1use std::sync::Arc;
2
3use image::{
4    imageops::{self, FilterType},
5    GenericImageView, ImageBuffer, Pixel, Rgb,
6};
7use itertools::Itertools;
8use ndarray::{Array3, ArrayViewD, Axis, CowArray};
9use ort::{tensor::OrtOwnedTensor, Value};
10
11use crate::{
12    detection::{FaceDetector, RustFacesResult},
13    imaging::make_border,
14    priorboxes::{PriorBoxes, PriorBoxesParams},
15    Face, Nms,
16};
17
/// Convenience alias for an owned `image` buffer with pixel type `P`.
pub type Image<P> = ImageBuffer<P, Vec<<P as Pixel>::Subpixel>>;
19
20fn resize_and_border<I: GenericImageView>(
21    image: &I,
22    output_size: (u32, u32),
23    border_color: I::Pixel,
24) -> (Image<I::Pixel>, f32)
25where
26    I::Pixel: 'static,
27    <I::Pixel as Pixel>::Subpixel: 'static,
28{
29    let (input_width, input_height) = image.dimensions();
30    let (output_width, output_height) = output_size;
31    let ratio = (output_width as f32 / input_width as f32)
32        .min(output_height as f32 / input_height as f32)
33        .min(1.0); // avoid scaling up.
34
35    let (resize_width, resize_height) = (
36        (input_width as f32 * ratio).round() as i32,
37        (input_height as f32 * ratio).round() as i32,
38    );
39    let resized = imageops::resize(
40        image,
41        resize_width as u32,
42        resize_height as u32,
43        FilterType::Nearest,
44    );
45
46    let (left, right, top, bottom) = {
47        let (x_pad, y_pad) = (
48            ((output_width as i32 - resize_width) % 16) as f32 / 2.0,
49            ((output_height as i32 - resize_height) % 16) as f32 / 2.0,
50        );
51        (
52            (x_pad - 0.1).round() as u32,
53            (x_pad + 0.1).round() as u32,
54            (y_pad - 0.1).round() as u32,
55            (y_pad + 0.1).round() as u32,
56        )
57    };
58
59    (
60        make_border(&resized, top, bottom, left, right, border_color),
61        ratio,
62    )
63}
64
/// Tuning parameters for the [`BlazeFace`] detector.
#[derive(Debug, Clone)]
pub struct BlazeFaceParams {
    /// Minimum face-class score; candidate detections at or below this value
    /// are discarded before non-maximum suppression.
    pub score_threshold: f32,
    /// Non-maximum suppression settings applied to the surviving detections.
    pub nms: Nms,
    /// Square network input edge in pixels; images are letterboxed toward a
    /// `target_size` x `target_size` canvas before inference.
    pub target_size: usize,
    /// Parameters for generating the prior (anchor) boxes used to decode the
    /// raw box/landmark regressions.
    pub prior_boxes: PriorBoxesParams,
}
72
73impl Default for BlazeFaceParams {
74    fn default() -> Self {
75        Self {
76            score_threshold: 0.95,
77            nms: Nms::default(),
78            target_size: 1280,
79            prior_boxes: PriorBoxesParams::default(),
80        }
81    }
82}
83
/// BlazeFace face detector backed by an ONNX Runtime session.
pub struct BlazeFace {
    // Inference session holding the loaded ONNX model.
    session: ort::Session,
    // Detection parameters (threshold, NMS, input size, priors).
    params: BlazeFaceParams,
}
88
89impl BlazeFace {
90    pub fn from_file(
91        env: Arc<ort::Environment>,
92        model_path: &str,
93        params: BlazeFaceParams,
94    ) -> Self {
95        let session = ort::session::SessionBuilder::new(&env)
96            .unwrap()
97            .with_model_from_file(model_path)
98            .unwrap();
99        Self { session, params }
100    }
101}
102
impl FaceDetector for BlazeFace {
    // Detects faces in an image supplied as an ndarray view.
    // The shape is read as HWC: shape[0] = height, shape[1] = width,
    // shape[2] = channels; pixel order is assumed RGB u8 (the BGR swap below
    // relies on that) — TODO confirm against callers.
    fn detect(&self, image: ArrayViewD<u8>) -> RustFacesResult<Vec<Face>> {
        let shape = image.shape().to_vec();
        let (width, height, _) = (shape[1], shape[0], shape[2]);

        // Reinterpret the raw array memory as an RGB image without copying.
        // `as_slice()` returns None for non-contiguous views, so such input
        // panics here.
        let image = ImageBuffer::<Rgb<u8>, &[u8]>::from_raw(
            width as u32,
            height as u32,
            image.as_slice().unwrap(),
        )
        .unwrap();

        // Letterbox to the square network input. The border color equals the
        // per-channel means subtracted below, so padded pixels become ~0 in
        // the network input.
        let (image, ratio) = resize_and_border(
            &image,
            (
                self.params.target_size as u32,
                self.params.target_size as u32,
            ),
            Rgb([104, 117, 123]),
        );
        let (input_width, input_height) = image.dimensions();
        // Build the CHW float tensor: swap RGB -> BGR and subtract the
        // per-channel means (104, 117, 123).
        let image = Array3::<f32>::from_shape_fn(
            (3, input_height as usize, input_width as usize),
            |(c, y, x)| {
                match c {
                    // https://github.com/zineos/blazeface/blob/main/tools/test.py seems to use OpenCV's BGR
                    0 => image.get_pixel(x as u32, y as u32)[2] as f32 - 104.0,
                    1 => image.get_pixel(x as u32, y as u32)[1] as f32 - 117.0,
                    2 => image.get_pixel(x as u32, y as u32)[0] as f32 - 123.0,
                    _ => unreachable!(),
                }
            },
        )
        .insert_axis(Axis(0)); // prepend the batch axis -> (1, 3, H, W)

        // Single-input inference; `?` propagates ort conversion/run errors.
        let output_tensors = self.session.run(vec![Value::from_array(
            self.session.allocator(),
            &CowArray::from(image).into_dyn(),
        )?])?;

        // Boxes regressions: N box with the format [start x, start y, end x, end y].
        let boxes: OrtOwnedTensor<f32, _> = output_tensors[0].try_extract()?;
        // Two values per anchor; index 1 is used as the face score below.
        let scores: OrtOwnedTensor<f32, _> = output_tensors[1].try_extract()?;
        // Ten values per anchor — five (x, y) landmark offsets, interleaved.
        let landmarks: OrtOwnedTensor<f32, _> = output_tensors[2].try_extract()?;
        let num_boxes = boxes.view().shape()[1];

        // Anchors generated for the actual letterboxed input resolution.
        let priors = PriorBoxes::new(
            &self.params.prior_boxes,
            (input_width as usize, input_height as usize),
        );

        // Scale factors mapping decoded coordinates back to the original
        // image: network input size divided by the letterbox ratio
        // (presumably decode_* returns normalized [0, 1] coords — verify in
        // priorboxes.rs).
        let scale_ratios = (input_width as f32 / ratio, input_height as f32 / ratio);

        // Walk boxes, landmarks, anchors and scores in lockstep, one row per
        // anchor, keeping only confident detections.
        let faces = boxes
            .view()
            .to_shape((num_boxes, 4))
            .unwrap()
            .axis_iter(Axis(0))
            .zip(
                landmarks
                    .view()
                    .to_shape((num_boxes, 10))
                    .unwrap()
                    .axis_iter(Axis(0)),
            )
            .zip(priors.anchors.iter())
            .zip(
                scores
                    .view()
                    .to_shape((num_boxes, 2))
                    .unwrap()
                    .axis_iter(Axis(0)),
            )
            .filter_map(|(((rect, landmarks), prior), score)| {
                // score[1]: face-class score (index 0 presumably background).
                let score = score[1];

                if score > self.params.score_threshold {
                    // Decode the box regression relative to its anchor, then
                    // scale into original-image coordinates.
                    let rect = priors.decode_box(prior, &(rect[0], rect[1], rect[2], rect[3]));
                    let rect = rect.scale(scale_ratios.0, scale_ratios.1);

                    // Decode the five landmarks from interleaved (x, y) pairs.
                    let landmarks = landmarks
                        .to_vec()
                        .chunks(2)
                        .map(|point| {
                            let point = priors.decode_landmark(prior, (point[0], point[1]));
                            (point.0 * scale_ratios.0, point.1 * scale_ratios.1)
                        })
                        .collect::<Vec<_>>();

                    Some(Face {
                        rect,
                        landmarks: Some(landmarks),
                        confidence: score,
                    })
                } else {
                    None
                }
            })
            .collect_vec();

        // Drop overlapping detections via non-maximum suppression.
        Ok(self.params.nms.suppress_non_maxima(faces))
    }
}
206
#[cfg(test)]
mod tests {
    use crate::{
        imaging::ToRgb8,
        model_repository::{GitHubRepository, ModelRepository},
        testing::{output_dir, sample_array_image, sample_image},
    };

    use super::*;
    use image::RgbImage;
    use rstest::rstest;

    use std::path::PathBuf;

    /// Letterboxing the portrait sample into a 1280x1280 target must keep the
    /// aspect ratio: the width shrinks to 896 while the height fills 1280.
    /// The result is also saved for visual inspection.
    #[rstest]
    pub fn test_resize_and_border(sample_image: RgbImage, output_dir: PathBuf) {
        let (resized, _) = resize_and_border(&sample_image, (1280, 1280), Rgb([0, 255, 0]));

        resized.save(output_dir.join("test_resized.jpg")).unwrap();
        // assert_eq! prints both sides on failure, unlike `assert!(a == b)`.
        assert_eq!(resized.width(), 896);
        assert_eq!(resized.height(), 1280);
    }

    /// Shared driver for the detection smoke tests: downloads the requested
    /// BlazeFace model, runs detection on the sample image, and saves a
    /// visualization of the detected faces.
    #[cfg(feature = "viz")]
    fn should_detect_impl(
        blaze_model: crate::FaceDetection,
        sample_array_image: Array3<u8>,
        output_dir: PathBuf,
    ) {
        use crate::viz;
        let environment = Arc::new(
            ort::Environment::builder()
                .with_name("BlazeFace")
                .build()
                .unwrap(),
        );

        // Reuse the params carried by the requested model variant.
        let params = match &blaze_model {
            crate::FaceDetection::BlazeFace640(params) => params.clone(),
            crate::FaceDetection::BlazeFace320(params) => params.clone(),
            _ => unreachable!(),
        };

        let drive = GitHubRepository::new();
        let model_path = drive.get_model(&blaze_model).expect("Can't download model")[0].clone();

        let face_detector = BlazeFace::from_file(environment, model_path.to_str().unwrap(), params);
        let mut canvas = sample_array_image.to_rgb8();
        let faces = face_detector
            .detect(sample_array_image.into_dyn().view())
            .unwrap();

        viz::draw_faces(&mut canvas, faces);

        canvas
            .save(output_dir.join("blazefaces.png"))
            .expect("Can't save image");
    }

    #[rstest]
    #[cfg(feature = "viz")]
    fn should_detect_640(sample_array_image: Array3<u8>, output_dir: PathBuf) {
        should_detect_impl(
            crate::FaceDetection::BlazeFace640(BlazeFaceParams::default()),
            sample_array_image,
            output_dir,
        );
    }

    #[rstest]
    #[cfg(feature = "viz")]
    fn should_detect_320(sample_array_image: Array3<u8>, output_dir: PathBuf) {
        should_detect_impl(
            crate::FaceDetection::BlazeFace320(BlazeFaceParams::default()),
            sample_array_image,
            output_dir,
        );
    }
}