#![allow(clippy::similar_names)]
#[cfg(feature = "clustering")]
use crate::analyzer::{FaceAnalysis, FaceAnalyzer};
use crate::detector::BoundingBox;
#[cfg(feature = "clustering")]
use crate::error::FaceIdError;
#[cfg(feature = "clustering")]
use hdbscan::{DistanceMetric, Hdbscan, HdbscanHyperParams, NnAlgorithm};
use image::{DynamicImage, GenericImageView, ImageBuffer, Rgb};
use rayon::prelude::*;
#[cfg(feature = "clustering")]
use std::collections::HashMap;
#[cfg(feature = "clustering")]
use std::path::{Path, PathBuf};
/// Crops a padded square around a face bounding box and scales it to fit a thumbnail.
///
/// The box is first passed through `bbox.scale(img_w, img_h)` (presumably mapping normalized
/// coordinates to pixels — confirm against `BoundingBox::scale`). A square whose side is
/// `max(width, height) * padding_factor` is then centred on the box, clipped to the image
/// bounds, cropped, and resized with a Catmull-Rom filter.
///
/// # Arguments
/// * `img` - Source image the box refers to.
/// * `bbox` - Face bounding box, in the coordinate space expected by `BoundingBox::scale`.
/// * `padding_factor` - Multiplier applied to the longer box side to get the crop side.
/// * `size` - Maximum width and height of the returned thumbnail, in pixels.
///
/// # Returns
/// The resized crop as an RGB buffer. The resize keeps the crop's aspect ratio, so when the
/// padded square is clipped by an image edge the result fits inside `size` x `size` but is not
/// square. If the padded square does not overlap the image at all, a zero-filled
/// `size` x `size` buffer is returned instead.
#[must_use]
pub fn extract_face_thumbnail(
    img: &DynamicImage,
    bbox: &BoundingBox,
    padding_factor: f32,
    size: u32,
) -> ImageBuffer<Rgb<u8>, Vec<u8>> {
    /// Clips a signed pixel coordinate to `0..=limit`.
    fn clip_to_axis(coord: i64, limit: u32) -> u32 {
        // The clamp guarantees the value fits in `u32`, so the narrowing cast is lossless.
        coord.clamp(0, i64::from(limit)) as u32
    }

    let (img_w, img_h) = img.dimensions();
    let bbox = bbox.scale(img_w, img_h);
    let width = bbox.width();
    let height = bbox.height();
    let cx = bbox.x1 + width / 2.0;
    let cy = bbox.y1 + height / 2.0;
    let side = width.max(height) * padding_factor;

    // Float-to-int `as` casts saturate (and map NaN to 0), so these never wrap. Widening to
    // `i64` afterwards means `origin + side` cannot overflow, and a negative far edge can be
    // clamped to 0 rather than reinterpreted as a huge unsigned value.
    let side_px = i64::from(side.round() as u32);
    let left = i64::from((cx - side / 2.0).round() as i32);
    let top = i64::from((cy - side / 2.0).round() as i32);

    let src_x1 = clip_to_axis(left, img_w);
    let src_y1 = clip_to_axis(top, img_h);
    let src_x2 = clip_to_axis(left + side_px, img_w);
    let src_y2 = clip_to_axis(top + side_px, img_h);

    if src_x2 > src_x1 && src_y2 > src_y1 {
        let crop_w = src_x2 - src_x1;
        let crop_h = src_y2 - src_y1;
        let sub_img = img.view(src_x1, src_y1, crop_w, crop_h).to_image();
        DynamicImage::ImageRgba8(sub_img)
            .resize(size, size, image::imageops::FilterType::CatmullRom)
            .to_rgb8()
    } else {
        // The padded square lies entirely outside the image (or has zero area).
        ImageBuffer::new(size, size)
    }
}
/// Detects faces in a set of images and groups them by embedding with HDBSCAN.
///
/// Images are opened and analysed in parallel; every face found is paired with the path of the
/// image it came from. The face embeddings are then clustered and the faces are returned
/// grouped by the label assigned to them.
///
/// # Arguments
/// * `analyzer` - Face analyzer run on each image.
/// * `paths` - Image files to process.
/// * `min_cluster_size` - Forwarded to the HDBSCAN hyper-parameters (default `5`).
/// * `max_cluster_size` - Forwarded to the HDBSCAN hyper-parameters (default `usize::MAX`).
/// * `allow_single_cluster` - Forwarded to the HDBSCAN hyper-parameters (default `false`).
/// * `min_samples` - Forwarded to the HDBSCAN hyper-parameters; falls back to
///   `min_cluster_size` when not set.
/// * `epsilon` - Forwarded to the HDBSCAN hyper-parameters (default `0.0`).
/// * `dist_metric` - Distance metric used by HDBSCAN (default `DistanceMetric::Euclidean`).
/// * `nn_algo` - Nearest-neighbour algorithm used by HDBSCAN (default `NnAlgorithm::Auto`).
///
/// # Returns
/// A map from cluster label to the `(path, face)` pairs carrying that label. The labels are
/// exactly those produced by `Hdbscan::cluster` (the hdbscan crate documents `-1` as the noise
/// label). The map is empty when no faces were found.
///
/// # Errors
/// Returns an error if any image cannot be opened or analysed, in which case no clustering is
/// attempted, or `FaceIdError::Clustering` carrying the clusterer's message if HDBSCAN fails.
#[cfg(feature = "clustering")]
#[bon::builder]
pub fn cluster_faces<P: AsRef<Path> + Sync + Send>(
    #[builder(start_fn)] analyzer: &FaceAnalyzer,
    #[builder(start_fn)] paths: Vec<P>,
    #[builder(default = 5)] min_cluster_size: usize,
    #[builder(default = usize::MAX)] max_cluster_size: usize,
    #[builder(default = false)] allow_single_cluster: bool,
    min_samples: Option<usize>,
    #[builder(default = 0.0)] epsilon: f64,
    #[builder(default = DistanceMetric::Euclidean)] dist_metric: DistanceMetric,
    #[builder(default = NnAlgorithm::Auto)] nn_algo: NnAlgorithm,
) -> Result<HashMap<i32, Vec<(PathBuf, FaceAnalysis)>>, FaceIdError> {
    // Analyse every image in parallel; a failure on any image aborts the whole call.
    let all_faces: Vec<(PathBuf, FaceAnalysis)> = paths
        .into_par_iter()
        .map(
            |path_ref| -> Result<Vec<(PathBuf, FaceAnalysis)>, FaceIdError> {
                let path = path_ref.as_ref().to_path_buf();
                let img = image::open(&path)?;
                let faces = analyzer.analyze(&img)?;
                Ok(faces.into_iter().map(|f| (path.clone(), f)).collect())
            },
        )
        .collect::<Result<Vec<Vec<_>>, _>>()?
        .into_iter()
        .flatten()
        .collect();

    if all_faces.is_empty() {
        return Ok(HashMap::new());
    }

    // The clusterer borrows its input, so the embeddings are copied out once. This leaves
    // `all_faces` unborrowed and free to be moved into the result below.
    let embeddings: Vec<Vec<f32>> = all_faces
        .iter()
        .map(|(_, face)| face.embedding.clone())
        .collect();

    let hyper_params = HdbscanHyperParams::builder()
        .min_cluster_size(min_cluster_size)
        .max_cluster_size(max_cluster_size)
        .allow_single_cluster(allow_single_cluster)
        .epsilon(epsilon)
        .dist_metric(dist_metric)
        .nn_algorithm(nn_algo)
        .min_samples(min_samples.unwrap_or(min_cluster_size))
        .build();

    let clusterer = Hdbscan::new(&embeddings, hyper_params);
    let labels: Vec<i32> = clusterer
        .cluster()
        .map_err(|e| FaceIdError::Clustering(e.to_string()))?;
    debug_assert_eq!(labels.len(), all_faces.len());

    // Labels are positionally aligned with the embeddings, and therefore with `all_faces`, so
    // each face is moved into its cluster without cloning.
    let mut clusters: HashMap<i32, Vec<(PathBuf, FaceAnalysis)>> = HashMap::new();
    for (label, entry) in labels.into_iter().zip(all_faces) {
        clusters.entry(label).or_default().push(entry);
    }
    Ok(clusters)
}
#[cfg(test)]
mod tests {
    use super::*;
    use image::RgbImage;

    /// A padded crop that is clipped by the right image edge yields a non-square thumbnail.
    #[test]
    fn test_extract_face_thumbnail_edge_case() {
        // 50x100 all-white source image.
        let white = Rgb([255, 255, 255]);
        let source = DynamicImage::ImageRgb8(RgbImage::from_pixel(50, 100, white));

        // Small box touching the right edge of the image.
        let face_box = BoundingBox {
            x1: 0.8,
            y1: 0.5,
            x2: 1.0,
            y2: 0.6,
        };

        let thumb = extract_face_thumbnail(&source, &face_box, 4.0, 100);
        let (thumb_w, thumb_h) = (thumb.width(), thumb.height());

        // The expected size corresponds to a 25x40 crop scaled by 2.5 to fit inside 100x100,
        // with the width rounded up from 62.5.
        assert_ne!(thumb_w, thumb_h);
        assert_eq!(thumb_w, 63);
        assert_eq!(thumb_h, 100);
    }
}