use crate::detector::{DetectedFace, ScrfdDetector};
use crate::embedder::ArcFaceEmbedder;
use crate::error::FaceIdError;
use crate::face_align::norm_crop;
use crate::gender_age::{Gender, GenderAgeEstimator};
use crate::model_manager::HfModel;
use bon::bon;
use image::DynamicImage;
use ort::ep::ExecutionProviderDispatch;
use rayon::prelude::*;
use std::path::Path;
use std::sync::Mutex;
/// Combined per-face output of [`FaceAnalyzer::analyze`].
///
/// `analyze` produces one value per face returned by the detector, pairing the
/// detection with the embedding and the gender/age estimate computed for it.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FaceAnalysis {
/// Detection record for this face, passed through unchanged from the detector.
pub detection: DetectedFace,
/// Embedding vector returned by the embedder for this face's aligned crop.
pub embedding: Vec<f32>,
/// Gender reported by the gender/age estimator.
pub gender: Gender,
/// Age reported by the gender/age estimator.
/// NOTE(review): presumably whole years; confirm against `GenderAgeEstimator`.
pub age: u8,
}
/// Face-analysis pipeline bundling a detector, an embedder and a gender/age estimator.
///
/// Each model sits behind its own [`Mutex`], which lets `analyze` take `&self`
/// and lock the models one at a time.
/// NOTE(review): the locks presumably exist because inference needs `&mut` access
/// to the underlying sessions; confirm against the model types.
pub struct FaceAnalyzer {
/// Face detector; locked by `analyze` for the detection step.
pub detector: Mutex<ScrfdDetector>,
/// Embedding model; locked by `analyze` for the batched embedding step.
pub embedder: Mutex<ArcFaceEmbedder>,
/// Gender/age estimator; locked by `analyze` for the batched attribute step.
pub gender_age: Mutex<GenderAgeEstimator>,
}
#[bon]
impl FaceAnalyzer {
/// Creates a [`FaceAnalyzer`] from three [`HfModel`] descriptors.
///
/// Only compiled when the `hf-hub` feature is enabled. Each model is obtained
/// through the corresponding component's own `from_hf()` builder, one after the
/// other (detector, then embedder, then gender/age estimator).
/// NOTE(review): the models are presumably fetched from the Hugging Face Hub;
/// confirm against the component builders.
///
/// # Arguments
///
/// * `detector_model` - detector model; defaults to `HfModel::default_detector()`.
/// * `embedder_model` - embedder model; defaults to `HfModel::default_embedder()`.
/// * `gender_age_model` - gender/age model; defaults to `HfModel::default_gender_age()`.
/// * `detector_input_size` - forwarded to the detector's `input_size`; defaults to `(640, 640)`.
/// * `detector_score_threshold` - forwarded to the detector's `score_threshold`; defaults to `0.5`.
/// * `detector_iou_threshold` - forwarded to the detector's `iou_threshold`; defaults to `0.4`.
/// * `with_execution_providers` - passed to all three model builders; defaults to an empty slice.
///
/// # Errors
///
/// Returns the first [`FaceIdError`] produced by any of the three model builders.
#[cfg(feature = "hf-hub")]
#[builder(finish_fn = build)]
pub async fn from_hf(
    #[builder(default = HfModel::default_detector())] detector_model: HfModel,
    #[builder(default = HfModel::default_embedder())] embedder_model: HfModel,
    #[builder(default = HfModel::default_gender_age())] gender_age_model: HfModel,
    #[builder(default = (640, 640))] detector_input_size: (u32, u32),
    #[builder(default = 0.5)] detector_score_threshold: f32,
    #[builder(default = 0.4)] detector_iou_threshold: f32,
    #[builder(default = &[])] with_execution_providers: &[ExecutionProviderDispatch],
) -> Result<Self, FaceIdError> {
    // Every model shares the same execution-provider list.
    let providers = with_execution_providers;

    // Each model is wrapped in its mutex as soon as it has been built.
    let detector = Mutex::new(
        ScrfdDetector::from_hf()
            .input_size(detector_input_size)
            .score_threshold(detector_score_threshold)
            .iou_threshold(detector_iou_threshold)
            .model(detector_model)
            .with_execution_providers(providers)
            .build()
            .await?,
    );

    let embedder = Mutex::new(
        ArcFaceEmbedder::from_hf()
            .model(embedder_model)
            .with_execution_providers(providers)
            .build()
            .await?,
    );

    let gender_age = Mutex::new(
        GenderAgeEstimator::from_hf()
            .model(gender_age_model)
            .with_execution_providers(providers)
            .build()
            .await?,
    );

    Ok(Self {
        detector,
        embedder,
        gender_age,
    })
}
/// Creates a [`FaceAnalyzer`] from three model paths.
///
/// The three paths are marked `#[builder(start_fn)]`, so they are the required
/// positional arguments of the generated builder; the remaining parameters are
/// optional and fall back to the defaults listed below.
///
/// # Arguments
///
/// * `det_model` - path handed to `ScrfdDetector::builder`.
/// * `rec_model` - path handed to `ArcFaceEmbedder::builder`.
/// * `attr_model` - path handed to `GenderAgeEstimator::builder`.
/// * `detector_input_size` - forwarded to the detector's `input_size`; defaults to `(640, 640)`.
/// * `detector_score_threshold` - forwarded to the detector's `score_threshold`; defaults to `0.5`.
/// * `detector_iou_threshold` - forwarded to the detector's `iou_threshold`; defaults to `0.4`.
/// * `with_execution_providers` - passed to all three model builders; defaults to an empty slice.
///
/// # Errors
///
/// Returns the first [`FaceIdError`] produced by any of the three model builders.
#[builder(finish_fn = build)]
pub fn new(
    #[builder(start_fn)] det_model: impl AsRef<Path>,
    #[builder(start_fn)] rec_model: impl AsRef<Path>,
    #[builder(start_fn)] attr_model: impl AsRef<Path>,
    #[builder(default = (640, 640))] detector_input_size: (u32, u32),
    #[builder(default = 0.5)] detector_score_threshold: f32,
    #[builder(default = 0.4)] detector_iou_threshold: f32,
    #[builder(default = &[])] with_execution_providers: &[ExecutionProviderDispatch],
) -> Result<Self, FaceIdError> {
    // Every model shares the same execution-provider list.
    let providers = with_execution_providers;

    // Each model is wrapped in its mutex as soon as it has been built.
    let detector = Mutex::new(
        ScrfdDetector::builder(det_model)
            .input_size(detector_input_size)
            .score_threshold(detector_score_threshold)
            .iou_threshold(detector_iou_threshold)
            .with_execution_providers(providers)
            .build()?,
    );

    let embedder = Mutex::new(
        ArcFaceEmbedder::builder(rec_model)
            .with_execution_providers(providers)
            .build()?,
    );

    let gender_age = Mutex::new(
        GenderAgeEstimator::builder(attr_model)
            .with_execution_providers(providers)
            .build()?,
    );

    Ok(Self {
        detector,
        embedder,
        gender_age,
    })
}
/// Detects every face in `img` and computes its embedding, gender and age.
///
/// The pipeline runs in this order:
/// 1. the detector is run on `img`;
/// 2. each face is cropped with `norm_crop` from its five landmarks and the
///    crops are embedded in one batch;
/// 3. each face is cropped with `GenderAgeEstimator::align_crop` from its
///    bounding box and the crops are estimated in one batch;
/// 4. the three result lists are zipped into [`FaceAnalysis`] values, in
///    detector order.
///
/// Returns an empty vector when the detector finds no faces.
///
/// # Errors
///
/// * [`FaceIdError::MutexPoisoned`] if any of the three model locks is poisoned.
/// * [`FaceIdError::InvalidModel`] if a detection has no landmarks, or its
///   landmarks are not exactly five points.
/// * [`FaceIdError::Ort`] if the embedder or estimator returns a different
///   number of results than there are detections.
/// * Any error propagated from `detect`, `compute_embeddings_batch` or
///   `estimate_batch`.
pub fn analyze(&self, img: &DynamicImage) -> Result<Vec<FaceAnalysis>, FaceIdError> {
    // The guard is a temporary, so the detector lock is released at the end
    // of this statement.
    let results = self
        .detector
        .lock()
        .map_err(|_| FaceIdError::MutexPoisoned("Detector".into()))?
        .detect(img)?;
    if results.is_empty() {
        return Ok(Vec::new());
    }

    // Convert to RGB only once faces are known to exist: `to_rgb8` copies the
    // whole image, which is wasted work when detection fails or finds nothing.
    let rgb_img = img.to_rgb8();
    // Loop-invariant scale factors for the landmark coordinates.
    let (width, height) = (rgb_img.width() as f32, rgb_img.height() as f32);

    // Landmark-aligned crops for the embedder, built in parallel.
    let embed_crops: Vec<_> = results
        .par_iter()
        .map(|res| {
            let landmarks = res.landmarks.as_ref().ok_or_else(|| {
                FaceIdError::InvalidModel(
                    "One or more faces missing landmarks for embedding".into(),
                )
            })?;
            // Landmarks are scaled by the image size before cropping.
            // NOTE(review): this assumes the detector reports them as fractions
            // of the image dimensions; confirm against `ScrfdDetector`.
            let lms_array: [(f32, f32); 5] = landmarks
                .iter()
                .map(|&(x, y)| (x * width, y * height))
                .collect::<Vec<_>>()
                .try_into()
                .map_err(|_| {
                    FaceIdError::InvalidModel("Landmarks were not 5-point keypoints".into())
                })?;
            Ok(norm_crop(&rgb_img, &lms_array, 112))
        })
        .collect::<Result<Vec<_>, FaceIdError>>()?;

    let embeddings = self
        .embedder
        .lock()
        .map_err(|_| FaceIdError::MutexPoisoned("Embedder".into()))?
        .compute_embeddings_batch(&embed_crops)?;

    // Bounding-box crops for the gender/age estimator, built in parallel.
    let ga_crops: Vec<_> = results
        .par_iter()
        .map(|res| GenderAgeEstimator::align_crop(&rgb_img, &res.bbox, 96))
        .collect();

    let ga_results = self
        .gender_age
        .lock()
        .map_err(|_| FaceIdError::MutexPoisoned("GenderAge".into()))?
        .estimate_batch(&ga_crops)?;

    // `zip` would silently truncate on a length mismatch, so check explicitly.
    if embeddings.len() != results.len() || ga_results.len() != results.len() {
        return Err(FaceIdError::Ort(format!(
            "Inconsistent batch results: expected {}, got {} embeddings and {} ga results",
            results.len(),
            embeddings.len(),
            ga_results.len()
        )));
    }

    Ok(results
        .into_iter()
        .zip(embeddings)
        .zip(ga_results)
        .map(|((det, emb), ga)| FaceAnalysis {
            detection: det,
            embedding: emb,
            gender: ga.gender,
            age: ga.age,
        })
        .collect())
}
}