1use yscv_video::Frame;
2
3use crate::heatmap::{HeatmapDetectScratch, detect_from_heatmap_data_with_scratch, map_shape};
4use crate::nms::validate_nms_args;
5use crate::{CLASS_ID_FACE, DetectError, Detection, non_max_suppression};
6
/// Reusable working buffers for [`detect_people_from_rgb8_with_scratch`].
///
/// Keeping one instance alive across calls lets the buffers be reused: they are only
/// resized when the image size changes.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct Rgb8PeopleDetectScratch {
    // Per-pixel mean of the three RGB8 channels, scaled into 0.0..=1.0.
    grayscale_heatmap: Vec<f32>,
    // State forwarded to `detect_from_heatmap_data_with_scratch`.
    heatmap: HeatmapDetectScratch,
}
16
/// Reusable working buffers for [`detect_faces_from_rgb8_with_scratch`].
///
/// Keeping one instance alive across calls lets the buffers be reused: they are only
/// resized when the image size changes.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct Rgb8FaceDetectScratch {
    // Per-pixel skin likelihood in 0.0..=1.0 computed from the RGB8 input.
    skin_heatmap: Vec<f32>,
    // State forwarded to `detect_from_heatmap_data_with_scratch`.
    heatmap: HeatmapDetectScratch,
}
26
/// Reusable working buffers for [`detect_people_from_frame_with_scratch`].
///
/// Keeping one instance alive across frames lets the buffers be reused: they are only
/// resized when the frame size changes.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct FramePeopleDetectScratch {
    // Per-pixel channel mean; only filled for three-channel frames (single-channel
    // frames are used as the heatmap directly).
    grayscale_heatmap: Vec<f32>,
    // State forwarded to `detect_from_heatmap_data_with_scratch`.
    heatmap: HeatmapDetectScratch,
}
36
/// Reusable working buffers for [`detect_faces_from_frame_with_scratch`].
///
/// Keeping one instance alive across frames lets the buffers be reused: they are only
/// resized when the frame size changes.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct FrameFaceDetectScratch {
    // Per-pixel skin likelihood in 0.0..=1.0 computed from the frame's RGB samples.
    skin_heatmap: Vec<f32>,
    // State forwarded to `detect_from_heatmap_data_with_scratch`.
    heatmap: HeatmapDetectScratch,
}
46
47pub fn detect_people_from_frame(
49 frame: &Frame,
50 score_threshold: f32,
51 min_area: usize,
52 iou_threshold: f32,
53 max_detections: usize,
54) -> Result<Vec<Detection>, DetectError> {
55 let mut scratch = FramePeopleDetectScratch::default();
56 detect_people_from_frame_with_scratch(
57 frame,
58 score_threshold,
59 min_area,
60 iou_threshold,
61 max_detections,
62 &mut scratch,
63 )
64}
65
66pub fn detect_people_from_frame_with_scratch(
68 frame: &Frame,
69 score_threshold: f32,
70 min_area: usize,
71 iou_threshold: f32,
72 max_detections: usize,
73 scratch: &mut FramePeopleDetectScratch,
74) -> Result<Vec<Detection>, DetectError> {
75 let image = frame.image();
76 let (h, w, c) = map_shape(image)?;
77 match c {
78 1 => detect_from_heatmap_data_with_scratch(
79 (h, w),
80 image.data(),
81 score_threshold,
82 min_area,
83 iou_threshold,
84 max_detections,
85 &mut scratch.heatmap,
86 ),
87 3 => {
88 fill_frame_rgb_grayscale_heatmap((h, w), image.data(), &mut scratch.grayscale_heatmap);
89 detect_from_heatmap_data_with_scratch(
90 (h, w),
91 &scratch.grayscale_heatmap,
92 score_threshold,
93 min_area,
94 iou_threshold,
95 max_detections,
96 &mut scratch.heatmap,
97 )
98 }
99 other => Err(DetectError::InvalidChannelCount {
100 expected: 1,
101 got: other,
102 }),
103 }
104}
105
106pub fn detect_people_from_rgb8(
111 width: usize,
112 height: usize,
113 rgb8: &[u8],
114 score_threshold: f32,
115 min_area: usize,
116 iou_threshold: f32,
117 max_detections: usize,
118) -> Result<Vec<Detection>, DetectError> {
119 let mut scratch = Rgb8PeopleDetectScratch::default();
120 detect_people_from_rgb8_with_scratch(
121 (width, height),
122 rgb8,
123 score_threshold,
124 min_area,
125 iou_threshold,
126 max_detections,
127 &mut scratch,
128 )
129}
130
131pub fn detect_people_from_rgb8_with_scratch(
133 shape: (usize, usize),
134 rgb8: &[u8],
135 score_threshold: f32,
136 min_area: usize,
137 iou_threshold: f32,
138 max_detections: usize,
139 scratch: &mut Rgb8PeopleDetectScratch,
140) -> Result<Vec<Detection>, DetectError> {
141 let (width, height) = shape;
142 fill_rgb8_grayscale_heatmap(width, height, rgb8, &mut scratch.grayscale_heatmap)?;
143 detect_from_heatmap_data_with_scratch(
144 (height, width),
145 &scratch.grayscale_heatmap,
146 score_threshold,
147 min_area,
148 iou_threshold,
149 max_detections,
150 &mut scratch.heatmap,
151 )
152}
153
154pub fn detect_faces_from_frame(
159 frame: &Frame,
160 score_threshold: f32,
161 min_area: usize,
162 iou_threshold: f32,
163 max_detections: usize,
164) -> Result<Vec<Detection>, DetectError> {
165 let mut scratch = FrameFaceDetectScratch::default();
166 detect_faces_from_frame_with_scratch(
167 frame,
168 score_threshold,
169 min_area,
170 iou_threshold,
171 max_detections,
172 &mut scratch,
173 )
174}
175
176pub fn detect_faces_from_frame_with_scratch(
178 frame: &Frame,
179 score_threshold: f32,
180 min_area: usize,
181 iou_threshold: f32,
182 max_detections: usize,
183 scratch: &mut FrameFaceDetectScratch,
184) -> Result<Vec<Detection>, DetectError> {
185 validate_nms_args(iou_threshold, max_detections)?;
186 let image = frame.image();
187 let (h, w, c) = map_shape(image)?;
188 if c != 3 {
189 return Err(DetectError::InvalidChannelCount {
190 expected: 3,
191 got: c,
192 });
193 }
194
195 fill_frame_rgb_skin_heatmap((h, w), image.data(), &mut scratch.skin_heatmap);
196 detect_faces_from_skin_heatmap_data_with_scratch(
197 (h, w),
198 &scratch.skin_heatmap,
199 score_threshold,
200 min_area,
201 iou_threshold,
202 max_detections,
203 &mut scratch.heatmap,
204 )
205}
206
207pub fn detect_faces_from_rgb8(
212 width: usize,
213 height: usize,
214 rgb8: &[u8],
215 score_threshold: f32,
216 min_area: usize,
217 iou_threshold: f32,
218 max_detections: usize,
219) -> Result<Vec<Detection>, DetectError> {
220 let mut scratch = Rgb8FaceDetectScratch::default();
221 detect_faces_from_rgb8_with_scratch(
222 (width, height),
223 rgb8,
224 score_threshold,
225 min_area,
226 iou_threshold,
227 max_detections,
228 &mut scratch,
229 )
230}
231
232pub fn detect_faces_from_rgb8_with_scratch(
234 shape: (usize, usize),
235 rgb8: &[u8],
236 score_threshold: f32,
237 min_area: usize,
238 iou_threshold: f32,
239 max_detections: usize,
240 scratch: &mut Rgb8FaceDetectScratch,
241) -> Result<Vec<Detection>, DetectError> {
242 let (width, height) = shape;
243 validate_nms_args(iou_threshold, max_detections)?;
244 fill_rgb8_skin_heatmap(width, height, rgb8, &mut scratch.skin_heatmap)?;
245 detect_faces_from_skin_heatmap_data_with_scratch(
246 (height, width),
247 &scratch.skin_heatmap,
248 score_threshold,
249 min_area,
250 iou_threshold,
251 max_detections,
252 &mut scratch.heatmap,
253 )
254}
255
256fn detect_faces_from_skin_heatmap_data_with_scratch(
257 shape: (usize, usize),
258 skin_heatmap_data: &[f32],
259 score_threshold: f32,
260 min_area: usize,
261 iou_threshold: f32,
262 max_detections: usize,
263 heatmap_scratch: &mut HeatmapDetectScratch,
264) -> Result<Vec<Detection>, DetectError> {
265 let candidate_limit = max_detections.saturating_mul(4).max(max_detections);
266 let candidates = detect_from_heatmap_data_with_scratch(
267 shape,
268 skin_heatmap_data,
269 score_threshold,
270 min_area,
271 iou_threshold,
272 candidate_limit,
273 heatmap_scratch,
274 )?;
275
276 let mut faces = Vec::with_capacity(candidates.len());
277 for candidate in candidates {
278 let height = candidate.bbox.height();
279 if height <= 1.0e-6 {
280 continue;
281 }
282 let aspect_ratio = candidate.bbox.width() / height;
283 if !(0.65..=1.8).contains(&aspect_ratio) {
284 continue;
285 }
286
287 let shape_score = triangular_score(aspect_ratio, 0.65, 1.8, 1.0);
288 let score = clamp01(0.75 * candidate.score + 0.25 * shape_score);
289 faces.push(Detection {
290 bbox: candidate.bbox,
291 score,
292 class_id: CLASS_ID_FACE,
293 });
294 }
295
296 Ok(non_max_suppression(&faces, iou_threshold, max_detections))
297}
298
/// Writes the mean of each interleaved RGB triple in `rgb` into `out`.
///
/// `out` is resized to `shape.0 * shape.1` elements; callers pass `(h, w)`. Samples are
/// not rescaled, so the output keeps the input's value range. If `rgb` holds fewer
/// than that many triples, the trailing entries of `out` are left untouched by the loop.
fn fill_frame_rgb_grayscale_heatmap(shape: (usize, usize), rgb: &[f32], out: &mut Vec<f32>) {
    let (rows, cols) = shape;
    // `resize` is a no-op when the length already matches, so the buffer is reused.
    out.resize(rows * cols, 0.0);

    for (dst, px) in out.iter_mut().zip(rgb.chunks_exact(3)) {
        *dst = (px[0] + px[1] + px[2]) / 3.0;
    }
}
309
310fn fill_frame_rgb_skin_heatmap(shape: (usize, usize), rgb: &[f32], out: &mut Vec<f32>) {
311 let pixel_count = shape.0 * shape.1;
312 if out.len() != pixel_count {
313 out.resize(pixel_count, 0.0);
314 }
315
316 let max_value = rgb.iter().copied().fold(0.0f32, f32::max);
317 let scale = if max_value > 1.5 { 1.0 / 255.0 } else { 1.0 };
318 for (rgb, value) in rgb.chunks_exact(3).zip(out.iter_mut()) {
319 let r = clamp01(rgb[0] * scale);
320 let g = clamp01(rgb[1] * scale);
321 let b = clamp01(rgb[2] * scale);
322 *value = skin_probability(r, g, b);
323 }
324}
325
326fn fill_rgb8_skin_heatmap(
327 width: usize,
328 height: usize,
329 rgb8: &[u8],
330 out: &mut Vec<f32>,
331) -> Result<(), DetectError> {
332 validate_rgb8_buffer_size(width, height, rgb8)?;
333 let pixel_count = width
334 .checked_mul(height)
335 .ok_or(DetectError::Rgb8DimensionsOverflow { width, height })?;
336 if out.len() != pixel_count {
337 out.resize(pixel_count, 0.0);
338 }
339
340 const SCALE: f32 = 1.0 / 255.0;
341 for (rgb, value) in rgb8.chunks_exact(3).zip(out.iter_mut()) {
342 let r = rgb[0] as f32 * SCALE;
343 let g = rgb[1] as f32 * SCALE;
344 let b = rgb[2] as f32 * SCALE;
345 *value = skin_probability(r, g, b);
346 }
347 Ok(())
348}
349
350fn fill_rgb8_grayscale_heatmap(
351 width: usize,
352 height: usize,
353 rgb8: &[u8],
354 out: &mut Vec<f32>,
355) -> Result<(), DetectError> {
356 validate_rgb8_buffer_size(width, height, rgb8)?;
357 let pixel_count = width
358 .checked_mul(height)
359 .ok_or(DetectError::Rgb8DimensionsOverflow { width, height })?;
360 if out.len() != pixel_count {
361 out.resize(pixel_count, 0.0);
362 }
363
364 const SCALE: f32 = 1.0 / 255.0;
365 for (rgb, value) in rgb8.chunks_exact(3).zip(out.iter_mut()) {
366 *value = (rgb[0] as f32 + rgb[1] as f32 + rgb[2] as f32) * (SCALE / 3.0);
367 }
368 Ok(())
369}
370
371fn validate_rgb8_buffer_size(width: usize, height: usize, rgb8: &[u8]) -> Result<(), DetectError> {
372 let expected = width
373 .checked_mul(height)
374 .and_then(|pixels| pixels.checked_mul(3))
375 .ok_or(DetectError::Rgb8DimensionsOverflow { width, height })?;
376 if rgb8.len() != expected {
377 return Err(DetectError::InvalidRgb8BufferSize {
378 expected,
379 got: rgb8.len(),
380 });
381 }
382 Ok(())
383}
384
385fn skin_probability(r: f32, g: f32, b: f32) -> f32 {
386 let y = 0.299 * r + 0.587 * g + 0.114 * b;
387 let cb = 0.5 + 0.564 * (b - y);
388 let cr = 0.5 + 0.713 * (r - y);
389
390 let cb_score = triangular_score(cb, 0.28, 0.57, 0.43);
391 let cr_score = triangular_score(cr, 0.36, 0.76, 0.56);
392 let luminance_score = triangular_score(y, 0.08, 0.95, 0.55);
393
394 let rg_bias = clamp01((r - g + 0.15) / 0.35);
395 let gb_bias = clamp01((g - b + 0.10) / 0.35);
396 let chroma = ((r - g).abs() + (g - b).abs() + (r - b).abs()) / 3.0;
397 let saturation_score = clamp01(chroma / 0.45);
398
399 let score = 0.32 * cb_score
400 + 0.32 * cr_score
401 + 0.16 * luminance_score
402 + 0.10 * rg_bias
403 + 0.10 * gb_bias;
404 clamp01(score * saturation_score.max(0.3))
405}
406
/// Piecewise-linear "tent" score.
///
/// The score behaves as follows:
/// * It is 0 outside `min..=max`.
/// * It is 1 when `value` is within `f32::EPSILON` of `center`.
/// * Between those, it rises linearly from `min` to `center` and falls linearly from
///   `center` to `max`.
///
/// A NaN `value` propagates to a NaN result.
fn triangular_score(value: f32, min: f32, max: f32, center: f32) -> f32 {
    match value {
        v if v < min || v > max => 0.0,
        v if (v - center).abs() <= f32::EPSILON => 1.0,
        v if v < center => (v - min) / (center - min),
        v => (max - v) / (max - center),
    }
}
419
/// Clamps `value` into the closed unit interval; NaN is passed through unchanged.
fn clamp01(value: f32) -> f32 {
    const LOWER: f32 = 0.0;
    const UPPER: f32 = 1.0;
    f32::clamp(value, LOWER, UPPER)
}