use crate::backend_ort::{ModelYuNetOrt, ModelArcFaceOrt, ArcFaceNorm, OrtModelError};
use crate::face_alignment::align_face_sized;
use crate::face_detection::FaceDetection;
use crate::image_buffer::ImageBuffer;
/// One detected face together with its recognition embedding.
///
/// Built by [`FacePipeline::process`]: the box geometry, confidence and landmarks are
/// copied verbatim from the detector's output, and `embedding` is what the recognizer
/// produced for the landmark-aligned crop of this face.
#[derive(Debug, Clone)]
pub struct FaceResult {
    /// Bounding-box `x` coordinate as reported by the detector
    /// (presumably the left edge in image pixels — confirm against the detector).
    pub x: f32,
    /// Bounding-box `y` coordinate as reported by the detector
    /// (presumably the top edge in image pixels — confirm against the detector).
    pub y: f32,
    /// Bounding-box width as reported by the detector.
    pub width: f32,
    /// Bounding-box height as reported by the detector.
    pub height: f32,
    /// Detector confidence score for this face.
    pub confidence: f32,
    /// Five 2-D landmark points used to align the face before embedding
    /// (presumably `[x, y]` pairs in image coordinates).
    pub landmarks: [[f32; 2]; 5],
    /// 512-element embedding returned by the recognizer.
    pub embedding: [f32; 512],
}
/// Two-stage face pipeline: a YuNet detector followed by an ArcFace recognizer.
///
/// Create one with [`FacePipeline::new`] (or a backend-specific constructor), then call
/// `process` to detect, align and embed every face in an image.
pub struct FacePipeline {
    /// Detector run over the whole input image.
    detector: ModelYuNetOrt,
    /// Recognizer run on each aligned face crop.
    recognizer: ModelArcFaceOrt,
}
impl FacePipeline {
    /// Loads the detector and recognizer models from disk, using the default
    /// [`ArcFaceNorm`] for the recognizer.
    ///
    /// # Errors
    ///
    /// Returns an [`OrtModelError`] if either model fails to load.
    pub fn new(
        detector_path: &str,
        recognizer_path: &str,
    ) -> Result<Self, OrtModelError> {
        Self::new_with_norm(detector_path, recognizer_path, ArcFaceNorm::default())
    }

    /// Like [`FacePipeline::new`], but with an explicit [`ArcFaceNorm`].
    ///
    /// `norm` is forwarded to `ModelArcFaceOrt::new_from_file_with_norm`. Presumably it
    /// selects the recognizer's input preprocessing; for example, the tests pass
    /// `ArcFaceNorm::ResNet` for an R50 model.
    ///
    /// # Errors
    ///
    /// Returns an [`OrtModelError`] if either model fails to load. The detector is loaded
    /// first, so a detector failure is reported before the recognizer is attempted.
    pub fn new_with_norm(
        detector_path: &str,
        recognizer_path: &str,
        norm: ArcFaceNorm,
    ) -> Result<Self, OrtModelError> {
        let detector = ModelYuNetOrt::new_from_file(detector_path)?;
        let recognizer = ModelArcFaceOrt::new_from_file_with_norm(recognizer_path, norm)?;
        Ok(Self { detector, recognizer })
    }

    /// Like [`FacePipeline::new`], but loads both models through the backend's CUDA
    /// constructors.
    ///
    /// # Errors
    ///
    /// Returns an [`OrtModelError`] if either model fails to load.
    #[cfg(feature = "ort-cuda-backend")]
    pub fn new_cuda(
        detector_path: &str,
        recognizer_path: &str,
    ) -> Result<Self, OrtModelError> {
        Self::new_cuda_with_norm(detector_path, recognizer_path, ArcFaceNorm::default())
    }

    /// CUDA variant of [`FacePipeline::new_with_norm`].
    ///
    /// # Errors
    ///
    /// Returns an [`OrtModelError`] if either model fails to load.
    #[cfg(feature = "ort-cuda-backend")]
    pub fn new_cuda_with_norm(
        detector_path: &str,
        recognizer_path: &str,
        norm: ArcFaceNorm,
    ) -> Result<Self, OrtModelError> {
        let detector = ModelYuNetOrt::new_from_file_cuda(detector_path)?;
        let recognizer = ModelArcFaceOrt::new_from_file_cuda_with_norm(recognizer_path, norm)?;
        Ok(Self { detector, recognizer })
    }

    /// Like [`FacePipeline::new`], but loads both models through the backend's TensorRT
    /// constructors.
    ///
    /// # Errors
    ///
    /// Returns an [`OrtModelError`] if either model fails to load.
    #[cfg(feature = "ort-tensorrt-backend")]
    pub fn new_tensorrt(
        detector_path: &str,
        recognizer_path: &str,
    ) -> Result<Self, OrtModelError> {
        Self::new_tensorrt_with_norm(detector_path, recognizer_path, ArcFaceNorm::default())
    }

    /// TensorRT variant of [`FacePipeline::new_with_norm`].
    ///
    /// # Errors
    ///
    /// Returns an [`OrtModelError`] if either model fails to load.
    #[cfg(feature = "ort-tensorrt-backend")]
    pub fn new_tensorrt_with_norm(
        detector_path: &str,
        recognizer_path: &str,
        norm: ArcFaceNorm,
    ) -> Result<Self, OrtModelError> {
        let detector = ModelYuNetOrt::new_from_file_tensorrt(detector_path)?;
        let recognizer = ModelArcFaceOrt::new_from_file_tensorrt_with_norm(recognizer_path, norm)?;
        Ok(Self { detector, recognizer })
    }

    /// Detector input size, exactly as reported by the underlying detector model
    /// (presumably `(width, height)` — confirm against `ModelYuNetOrt::input_size`).
    pub fn input_size(&self) -> (u32, u32) {
        self.detector.input_size()
    }

    /// Size of the aligned face crop fed to the recognizer, taken from the recognizer
    /// model. `process` passes this to `align_face_sized`.
    pub fn aligned_size(&self) -> u32 {
        self.recognizer.input_size()
    }

    /// Enables or disables letterboxing in the detector's preprocessing. The flag is
    /// forwarded unchanged to the detector.
    pub fn set_letterbox(&mut self, enabled: bool) {
        self.detector.set_letterbox(enabled);
    }

    /// Detects faces in `image`, aligns each one using its landmarks, and embeds it.
    ///
    /// `conf_threshold` and `nms_threshold` are passed straight to the detector. Results
    /// are returned in the same order as the detector's output.
    ///
    /// # Errors
    ///
    /// Returns the first [`OrtModelError`] raised by the detector or by the recognizer on
    /// any face. No partial results are returned in that case.
    pub fn process(
        &mut self,
        image: &ImageBuffer,
        conf_threshold: f32,
        nms_threshold: f32,
    ) -> Result<Vec<FaceResult>, OrtModelError> {
        let detections = self.detector.forward(image, conf_threshold, nms_threshold)?;
        // The crop size depends only on the recognizer model, so query it once rather than
        // once per face (assumes `recognizer.forward` does not change its input size).
        let aligned_size = self.aligned_size();
        let mut results = Vec::with_capacity(detections.len());
        for det in &detections {
            let aligned = align_face_sized(image, &det.landmarks, aligned_size);
            let embedding = self.recognizer.forward(&aligned)?;
            results.push(FaceResult {
                x: det.x,
                y: det.y,
                width: det.width,
                height: det.height,
                confidence: det.confidence,
                landmarks: det.landmarks,
                embedding,
            });
        }
        Ok(results)
    }

    /// Runs only the recognizer on a face crop that the caller has already aligned.
    ///
    /// The crop is presumably expected to be aligned the same way `process` does it, i.e.
    /// to [`FacePipeline::aligned_size`].
    ///
    /// # Errors
    ///
    /// Propagates any [`OrtModelError`] from the recognizer.
    pub fn embed(&mut self, aligned_face: &ImageBuffer) -> Result<[f32; 512], OrtModelError> {
        self.recognizer.forward(aligned_face)
    }

    /// Runs only the detector, without alignment or embedding.
    ///
    /// # Errors
    ///
    /// Propagates any [`OrtModelError`] from the detector.
    pub fn detect(
        &mut self,
        image: &ImageBuffer,
        conf_threshold: f32,
        nms_threshold: f32,
    ) -> Result<Vec<FaceDetection>, OrtModelError> {
        self.detector.forward(image, conf_threshold, nms_threshold)
    }
}
/// Cosine similarity between two 512-element embeddings.
///
/// Computes `dot(a, b) / (‖a‖ · ‖b‖)`. The result therefore lies in `[-1, 1]`, up to
/// floating-point rounding, whether or not the inputs are L2-normalized. For unit-length
/// inputs it reduces to the plain dot product.
///
/// Returns `0.0` when either vector has zero norm, because the angle is undefined there.
/// Without that guard the division would produce NaN. NaN components in the inputs still
/// propagate to the result.
pub fn cosine_similarity(a: &[f32; 512], b: &[f32; 512]) -> f32 {
    // Accumulate the dot product and both squared norms in a single pass.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b.iter()).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );
    // Take the square roots separately rather than sqrt(na * nb), which could
    // overflow or underflow for extreme magnitudes.
    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if denom == 0.0 {
        return 0.0;
    }
    dot / denom
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean (L2) norm of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Returns `true` only if every path in `paths` exists. The model-dependent tests use
    /// this to skip themselves when their fixture files are absent.
    fn all_exist(paths: &[&str]) -> bool {
        paths.iter().all(|p| std::path::Path::new(p).exists())
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let mut a = [0.0f32; 512];
        a[0] = 1.0;
        let sim = cosine_similarity(&a, &a);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let mut a = [0.0f32; 512];
        let mut b = [0.0f32; 512];
        a[0] = 1.0;
        b[1] = 1.0;
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let mut a = [0.0f32; 512];
        let mut b = [0.0f32; 512];
        a[0] = 1.0;
        b[0] = -1.0;
        let sim = cosine_similarity(&a, &b);
        assert!((sim + 1.0).abs() < 1e-6);
    }

    /// End-to-end: detect faces in a fixture image and check that every embedding is
    /// non-zero and has unit L2 norm.
    #[test]
    #[ignore]
    fn test_pipeline_arnold() {
        ort::init().commit();
        let detector_path = "pretrained/face_detection_yunet_2023mar.onnx";
        let recognizer_path = "pretrained/w600k_mbf.onnx";
        let image_path = "images/arnold.jpg";
        if !all_exist(&[detector_path, recognizer_path, image_path]) {
            eprintln!("Skipping test_pipeline_arnold: model or image files not found");
            return;
        }
        let mut pipeline = FacePipeline::new(detector_path, recognizer_path)
            .expect("Failed to create pipeline");
        let img = image::open(image_path).expect("Failed to load image");
        let img_buffer = ImageBuffer::from_dynamic_image(img);
        let faces = pipeline.process(&img_buffer, 0.7, 0.3)
            .expect("Pipeline failed");
        assert!(!faces.is_empty(), "No faces detected in arnold.jpg");
        for (i, face) in faces.iter().enumerate() {
            assert!(face.confidence > 0.5, "Face #{} confidence too low: {}", i, face.confidence);
            let norm = l2_norm(&face.embedding);
            assert!(
                (norm - 1.0).abs() < 0.01,
                "Face #{} embedding L2 norm is {}, expected ~1.0", i, norm
            );
            let nonzero = face.embedding.iter().any(|&v| v.abs() > 1e-6);
            assert!(nonzero, "Face #{} embedding is all zeros", i);
        }
    }

    /// Aligns the first detected face to a 112x112 crop and checks its shape and that it
    /// is not entirely black.
    #[test]
    #[ignore]
    fn test_alignment_on_arnold() {
        ort::init().commit();
        let detector_path = "pretrained/face_detection_yunet_2023mar.onnx";
        let image_path = "images/arnold.jpg";
        if !all_exist(&[detector_path, image_path]) {
            eprintln!("Skipping test_alignment_on_arnold: files not found");
            return;
        }
        let mut detector = crate::backend_ort::ModelYuNetOrt::new_from_file(detector_path)
            .expect("Failed to load YuNet");
        let img = image::open(image_path).expect("Failed to load image");
        let img_buffer = ImageBuffer::from_dynamic_image(img);
        let detections = detector.forward(&img_buffer, 0.7, 0.3)
            .expect("Detection failed");
        assert!(!detections.is_empty(), "No faces detected");
        // `align_face` was never imported into this module, so the module failed to
        // compile. `align_face_sized` is in scope via `use super::*`; request 112 explicitly
        // to match the asserts below.
        let aligned = align_face_sized(&img_buffer, &detections[0].landmarks, 112);
        assert_eq!(aligned.width(), 112);
        assert_eq!(aligned.height(), 112);
        assert_eq!(aligned.channels(), 3);
        let data = aligned.as_array();
        let sum: u64 = data.iter().map(|&v| v as u64).sum();
        assert!(sum > 0, "Aligned face is all black");
    }

    /// Compares the MBF and R50 recognizers on the same face. Both embeddings must be
    /// unit-norm; their cross similarity is printed for information only.
    #[test]
    #[ignore]
    fn test_mbf_vs_r50() {
        ort::init().commit();
        let detector_path = "pretrained/face_detection_yunet_2023mar.onnx";
        let mbf_path = "pretrained/w600k_mbf.onnx";
        let r50_path = "pretrained/w600k_r50.onnx";
        let image_path = "images/arnold.jpg";
        if !all_exist(&[detector_path, mbf_path, r50_path, image_path]) {
            eprintln!("Skipping test_mbf_vs_r50: model or image files not found");
            return;
        }
        let mut pipeline_mbf = FacePipeline::new(detector_path, mbf_path)
            .expect("Failed to create MBF pipeline");
        let mut pipeline_r50 = FacePipeline::new_with_norm(
            detector_path, r50_path, ArcFaceNorm::ResNet,
        ).expect("Failed to create R50 pipeline");
        let img = image::open(image_path).expect("Failed to load image");
        let img_buffer = ImageBuffer::from_dynamic_image(img);
        let faces_mbf = pipeline_mbf.process(&img_buffer, 0.7, 0.3)
            .expect("MBF pipeline failed");
        let faces_r50 = pipeline_r50.process(&img_buffer, 0.7, 0.3)
            .expect("R50 pipeline failed");
        assert!(!faces_mbf.is_empty(), "MBF: no faces detected");
        assert!(!faces_r50.is_empty(), "R50: no faces detected");
        let norm_mbf = l2_norm(&faces_mbf[0].embedding);
        let norm_r50 = l2_norm(&faces_r50[0].embedding);
        println!("MBF L2 norm: {:.4}", norm_mbf);
        println!("R50 L2 norm: {:.4}", norm_r50);
        assert!((norm_mbf - 1.0).abs() < 0.01, "MBF norm not ~1.0: {}", norm_mbf);
        assert!((norm_r50 - 1.0).abs() < 0.01, "R50 norm not ~1.0: {}", norm_r50);
        let cross_sim = cosine_similarity(&faces_mbf[0].embedding, &faces_r50[0].embedding);
        println!("MBF vs R50 cosine similarity (informational): {:.4}", cross_sim);
    }
}