1use std::sync::Arc;
2
3use image::{
4 imageops::{self, FilterType},
5 GenericImageView, ImageBuffer, Pixel, Rgb,
6};
7use itertools::Itertools;
8use ndarray::{Array3, ArrayViewD, Axis, CowArray};
9use ort::{tensor::OrtOwnedTensor, Value};
10
11use crate::{
12 detection::{FaceDetector, RustFacesResult},
13 imaging::make_border,
14 priorboxes::{PriorBoxes, PriorBoxesParams},
15 Face, Nms,
16};
17
/// Owned image buffer whose pixel data lives in a `Vec` of subpixels.
pub type Image<P> = ImageBuffer<P, Vec<<P as Pixel>::Subpixel>>;
19
20fn resize_and_border<I: GenericImageView>(
21 image: &I,
22 output_size: (u32, u32),
23 border_color: I::Pixel,
24) -> (Image<I::Pixel>, f32)
25where
26 I::Pixel: 'static,
27 <I::Pixel as Pixel>::Subpixel: 'static,
28{
29 let (input_width, input_height) = image.dimensions();
30 let (output_width, output_height) = output_size;
31 let ratio = (output_width as f32 / input_width as f32)
32 .min(output_height as f32 / input_height as f32)
33 .min(1.0); let (resize_width, resize_height) = (
36 (input_width as f32 * ratio).round() as i32,
37 (input_height as f32 * ratio).round() as i32,
38 );
39 let resized = imageops::resize(
40 image,
41 resize_width as u32,
42 resize_height as u32,
43 FilterType::Nearest,
44 );
45
46 let (left, right, top, bottom) = {
47 let (x_pad, y_pad) = (
48 ((output_width as i32 - resize_width) % 16) as f32 / 2.0,
49 ((output_height as i32 - resize_height) % 16) as f32 / 2.0,
50 );
51 (
52 (x_pad - 0.1).round() as u32,
53 (x_pad + 0.1).round() as u32,
54 (y_pad - 0.1).round() as u32,
55 (y_pad + 0.1).round() as u32,
56 )
57 };
58
59 (
60 make_border(&resized, top, bottom, left, right, border_color),
61 ratio,
62 )
63}
64
/// Tunable parameters for the [`BlazeFace`] detector.
#[derive(Debug, Clone)]
pub struct BlazeFaceParams {
    /// Minimum face score a candidate box must exceed to be kept.
    pub score_threshold: f32,
    /// Non-maximum-suppression settings applied to the surviving detections.
    pub nms: Nms,
    /// Side length, in pixels, of the square input the image is letterboxed to.
    pub target_size: usize,
    /// Prior/anchor-box generation parameters used to decode network outputs.
    pub prior_boxes: PriorBoxesParams,
}
72
73impl Default for BlazeFaceParams {
74 fn default() -> Self {
75 Self {
76 score_threshold: 0.95,
77 nms: Nms::default(),
78 target_size: 1280,
79 prior_boxes: PriorBoxesParams::default(),
80 }
81 }
82}
83
/// BlazeFace face detector backed by an ONNX Runtime inference session.
pub struct BlazeFace {
    // Loaded ONNX Runtime session for the BlazeFace model.
    session: ort::Session,
    // Detection parameters (score threshold, NMS, input size, priors).
    params: BlazeFaceParams,
}
88
89impl BlazeFace {
90 pub fn from_file(
91 env: Arc<ort::Environment>,
92 model_path: &str,
93 params: BlazeFaceParams,
94 ) -> Self {
95 let session = ort::session::SessionBuilder::new(&env)
96 .unwrap()
97 .with_model_from_file(model_path)
98 .unwrap();
99 Self { session, params }
100 }
101}
102
impl FaceDetector for BlazeFace {
    /// Detects faces in `image`, an HWC (height, width, channel) array of
    /// 8-bit RGB pixels, returning the faces that survive score thresholding
    /// and non-maximum suppression.
    fn detect(&self, image: ArrayViewD<u8>) -> RustFacesResult<Vec<Face>> {
        // Input layout is (height, width, channels).
        let shape = image.shape().to_vec();
        let (width, height, _) = (shape[1], shape[0], shape[2]);

        // Reinterpret the raw bytes as an RGB image without copying.
        // NOTE(review): `as_slice` returns None for non-contiguous views, so
        // these unwraps assume callers pass standard-layout arrays — confirm.
        let image = ImageBuffer::<Rgb<u8>, &[u8]>::from_raw(
            width as u32,
            height as u32,
            image.as_slice().unwrap(),
        )
        .unwrap();

        // Letterbox to the square network input size; `ratio` is the applied
        // down-scale, used below to map detections back to source pixels.
        let (image, ratio) = resize_and_border(
            &image,
            (
                self.params.target_size as u32,
                self.params.target_size as u32,
            ),
            Rgb([104, 117, 123]),
        );
        let (input_width, input_height) = image.dimensions();
        // Build the CHW float input: channels are reversed (RGB -> BGR) and
        // the per-channel means (104, 117, 123) — the same values used as the
        // border color — are subtracted.
        let image = Array3::<f32>::from_shape_fn(
            (3, input_height as usize, input_width as usize),
            |(c, y, x)| {
                match c {
                    0 => image.get_pixel(x as u32, y as u32)[2] as f32 - 104.0,
                    1 => image.get_pixel(x as u32, y as u32)[1] as f32 - 117.0,
                    2 => image.get_pixel(x as u32, y as u32)[0] as f32 - 123.0,
                    _ => unreachable!(),
                }
            },
        )
        .insert_axis(Axis(0)); // prepend the batch dimension (NCHW)

        let output_tensors = self.session.run(vec![Value::from_array(
            self.session.allocator(),
            &CowArray::from(image).into_dyn(),
        )?])?;

        // Model outputs, one row per prior/anchor box: box regressions (4),
        // two-class scores (2) and landmark regressions (10).
        let boxes: OrtOwnedTensor<f32, _> = output_tensors[0].try_extract()?;
        let scores: OrtOwnedTensor<f32, _> = output_tensors[1].try_extract()?;
        let landmarks: OrtOwnedTensor<f32, _> = output_tensors[2].try_extract()?;
        let num_boxes = boxes.view().shape()[1];

        // Anchor boxes matching the network's output grid.
        let priors = PriorBoxes::new(
            &self.params.prior_boxes,
            (input_width as usize, input_height as usize),
        );

        // Decoded coordinates appear to be normalized to the network input;
        // multiplying by input size / ratio maps them straight to
        // source-image pixels — TODO confirm against `PriorBoxes::decode_box`.
        let scale_ratios = (input_width as f32 / ratio, input_height as f32 / ratio);

        let faces = boxes
            .view()
            .to_shape((num_boxes, 4))
            .unwrap()
            .axis_iter(Axis(0))
            .zip(
                landmarks
                    .view()
                    .to_shape((num_boxes, 10))
                    .unwrap()
                    .axis_iter(Axis(0)),
            )
            .zip(priors.anchors.iter())
            .zip(
                scores
                    .view()
                    .to_shape((num_boxes, 2))
                    .unwrap()
                    .axis_iter(Axis(0)),
            )
            .filter_map(|(((rect, landmarks), prior), score)| {
                // Column 1 is used as the face confidence (column 0 is
                // presumably the background class).
                let score = score[1];

                if score > self.params.score_threshold {
                    let rect = priors.decode_box(prior, &(rect[0], rect[1], rect[2], rect[3]));
                    let rect = rect.scale(scale_ratios.0, scale_ratios.1);

                    // Five (x, y) landmark pairs, decoded and rescaled the
                    // same way as the box.
                    let landmarks = landmarks
                        .to_vec()
                        .chunks(2)
                        .map(|point| {
                            let point = priors.decode_landmark(prior, (point[0], point[1]));
                            (point.0 * scale_ratios.0, point.1 * scale_ratios.1)
                        })
                        .collect::<Vec<_>>();

                    Some(Face {
                        rect,
                        landmarks: Some(landmarks),
                        confidence: score,
                    })
                } else {
                    None
                }
            })
            .collect_vec();

        // Final overlap filtering across the thresholded detections.
        Ok(self.params.nms.suppress_non_maxima(faces))
    }
}
206
#[cfg(test)]
mod tests {
    use crate::{
        imaging::ToRgb8,
        model_repository::{GitHubRepository, ModelRepository},
        testing::{output_dir, sample_array_image, sample_image},
    };

    use super::*;
    use image::RgbImage;
    use rstest::rstest;

    use std::path::PathBuf;

    // Letterboxing keeps aspect ratio: the sample image should come out
    // 896x1280 inside a 1280x1280 target, with green padding.
    #[rstest]
    pub fn test_resize_and_border(sample_image: RgbImage, output_dir: PathBuf) {
        let (resized, _) = resize_and_border(&sample_image, (1280, 1280), Rgb([0, 255, 0]));

        resized.save(output_dir.join("test_resized.jpg")).unwrap();
        assert!(resized.width() == 896);
        assert!(resized.height() == 1280);
    }

    // Shared body for the end-to-end detection tests: downloads the given
    // BlazeFace model, runs detection on the sample image and writes a
    // visualization (needs the "viz" feature and network access).
    #[cfg(feature = "viz")]
    fn should_detect_impl(
        blaze_model: crate::FaceDetection,
        sample_array_image: Array3<u8>,
        output_dir: PathBuf,
    ) {
        use crate::viz;
        let environment = Arc::new(
            ort::Environment::builder()
                .with_name("BlazeFace")
                .build()
                .unwrap(),
        );

        // Pull the params back out of the detection-variant enum.
        let params = match &blaze_model {
            crate::FaceDetection::BlazeFace640(params) => params.clone(),
            crate::FaceDetection::BlazeFace320(params) => params.clone(),
            _ => unreachable!(),
        };

        let drive = GitHubRepository::new();
        let model_path = drive.get_model(&blaze_model).expect("Can't download model")[0].clone();

        let face_detector = BlazeFace::from_file(environment, model_path.to_str().unwrap(), params);
        let mut canvas = sample_array_image.to_rgb8();
        let faces = face_detector
            .detect(sample_array_image.into_dyn().view())
            .unwrap();

        viz::draw_faces(&mut canvas, faces);

        canvas
            .save(output_dir.join("blazefaces.png"))
            .expect("Can't save image");
    }

    // End-to-end smoke test for the 640-input BlazeFace model.
    #[rstest]
    #[cfg(feature = "viz")]
    fn should_detect_640(sample_array_image: Array3<u8>, output_dir: PathBuf) {
        should_detect_impl(
            crate::FaceDetection::BlazeFace640(BlazeFaceParams::default()),
            sample_array_image,
            output_dir,
        );
    }

    // End-to-end smoke test for the 320-input BlazeFace model.
    #[rstest]
    #[cfg(feature = "viz")]
    fn should_detect_320(sample_array_image: Array3<u8>, output_dir: PathBuf) {
        should_detect_impl(
            crate::FaceDetection::BlazeFace320(BlazeFaceParams::default()),
            sample_array_image,
            output_dir,
        );
    }
}