use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
#[cfg(feature = "runtime-tract")]
use tract_onnx::prelude::Framework;
use crate::{
verify_captured_photos, Analyzer, CapturedPhotoAnalysis, FaceVerificationError, FrameAnalysis,
ImageInput, LivenessChallenge, VerificationResult, VerificationThresholds,
};
/// The task a model in a [`ModelBundle`] is meant to perform.
///
/// Each [`ModelAsset`] carries exactly one role; `ModelBundle::has_role`
/// compares against it to decide whether a capability is present.
/// NOTE(review): the per-variant descriptions below are taken from the variant
/// names; the code visible here only inspects `FaceDetector` and
/// `FaceEmbedding`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum ModelRole {
/// Face detection model; required by `FaceVerificationEngine::analyze_image`.
FaceDetector,
/// Face landmark model.
FaceLandmarks,
/// Face embedding model; required by `FaceVerificationEngine::analyze_image`.
FaceEmbedding,
/// Face expression model.
FaceExpression,
/// Hand landmark model.
HandLandmarks,
/// Age model.
Age,
/// NSFW model.
Nsfw,
}
/// Serialization format of the bytes stored in a [`ModelAsset`].
///
/// ONNX is the only format declared here.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ModelFormat {
/// The model bytes are an ONNX model.
Onnx,
}
/// A single model together with descriptive metadata.
///
/// Only `role`, `format` and `bytes` are read by the code visible in this
/// file; the remaining fields are carried as data.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ModelAsset {
/// What this model is used for.
pub role: ModelRole,
/// Serialization format of `bytes`.
pub format: ModelFormat,
/// Raw serialized model. `ModelBundle::validate` rejects an empty vector.
pub bytes: Vec<u8>,
/// Optional model input width. NOTE(review): presumably pixels — confirm.
pub input_width: Option<u32>,
/// Optional model input height. NOTE(review): presumably pixels — confirm.
pub input_height: Option<u32>,
/// Optional name of the model input. NOTE(review): presumably the input
/// tensor/node name — confirm against the runtime that consumes it.
pub input_name: Option<String>,
/// Names of the model outputs. NOTE(review): presumably output tensor/node
/// names — confirm against the runtime that consumes them.
pub output_names: Vec<String>,
/// Optional license information for the model.
pub license: Option<String>,
/// Optional description of where the model came from.
pub source: Option<String>,
}
/// A collection of [`ModelAsset`]s handed to the verification engine.
///
/// The derived `Default` produces a bundle with no models, which
/// `ModelBundle::validate` reports as invalid.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ModelBundle {
/// The models contained in this bundle, in caller-supplied order.
pub models: Vec<ModelAsset>,
}
/// Summary returned by a successful `ModelBundle::validate` call.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ModelValidation {
/// Number of models in the validated bundle.
pub model_count: usize,
/// Role of every model, in the same order as `ModelBundle::models`
/// (one entry per model, so duplicates are possible).
pub roles: Vec<ModelRole>,
}
impl ModelBundle {
    /// Checks that the bundle can be used and summarizes its contents.
    ///
    /// The bundle must contain at least one model, every model must have
    /// non-empty bytes, and every model must pass `validate_model_format`.
    /// Models are checked in order and the first failure is returned.
    ///
    /// # Errors
    ///
    /// Returns [`FaceVerificationError::InvalidModel`] when the bundle is empty
    /// or a model has no bytes, and propagates whatever error
    /// `validate_model_format` reports.
    pub fn validate(&self) -> Result<ModelValidation, FaceVerificationError> {
        if self.models.is_empty() {
            let reason = String::from("model bundle is empty");
            return Err(FaceVerificationError::InvalidModel(reason));
        }

        self.models.iter().try_for_each(|asset| {
            if asset.bytes.is_empty() {
                return Err(FaceVerificationError::InvalidModel(format!(
                    "{:?} model bytes are empty",
                    asset.role
                )));
            }
            validate_model_format(asset)
        })?;

        let roles: Vec<ModelRole> = self.models.iter().map(|asset| asset.role).collect();
        Ok(ModelValidation {
            model_count: self.models.len(),
            roles,
        })
    }

    /// Returns `true` when at least one model in the bundle has the given role.
    pub fn has_role(&self, role: ModelRole) -> bool {
        self.models
            .iter()
            .map(|asset| asset.role)
            .any(|candidate| candidate == role)
    }
}
/// Face verification entry point holding a model bundle and the thresholds
/// passed on to `verify_captured_photos`.
///
/// Build it with `FaceVerificationEngine::new`, which validates the bundle
/// before storing it.
#[derive(Debug, Clone)]
pub struct FaceVerificationEngine {
// Models available to the engine; consulted by `require_role`.
bundle: ModelBundle,
// Thresholds forwarded to `verify_captured_photos` in `verify_images`.
thresholds: VerificationThresholds,
}
impl FaceVerificationEngine {
    /// Creates an engine from a model bundle and verification thresholds.
    ///
    /// # Errors
    ///
    /// Returns the error produced by [`ModelBundle::validate`] when the bundle
    /// is not usable.
    pub fn new(
        bundle: ModelBundle,
        thresholds: VerificationThresholds,
    ) -> Result<Self, FaceVerificationError> {
        // The validation summary itself is not needed, only its success.
        bundle.validate().map(|_| Self { bundle, thresholds })
    }

    /// Analyzes a single image.
    ///
    /// The analysis runtime is not implemented yet, so this currently always
    /// returns an error; the checks below only decide which one.
    ///
    /// # Errors
    ///
    /// * [`FaceVerificationError::InvalidImage`] when the image has no bytes or
    ///   a zero width or height.
    /// * [`FaceVerificationError::ModelNotLoaded`] when the bundle lacks a face
    ///   detector or a face embedding model (checked in that order).
    /// * [`FaceVerificationError::NotImplemented`] otherwise.
    pub fn analyze_image(
        &self,
        image: &ImageInput,
    ) -> Result<FrameAnalysis, FaceVerificationError> {
        let has_content = !image.bytes.is_empty() && image.width != 0 && image.height != 0;
        if !has_content {
            return Err(FaceVerificationError::InvalidImage);
        }

        for required in [ModelRole::FaceDetector, ModelRole::FaceEmbedding] {
            self.require_role(required)?;
        }

        Err(FaceVerificationError::NotImplemented(
            "ONNX image analysis runtime",
        ))
    }

    /// Analyzes every image and runs `verify_captured_photos` on the results.
    ///
    /// Each image is hashed with `image_hash` and analyzed with
    /// [`Self::analyze_image`]; the `nsfw` field of every captured photo is
    /// left as `None`.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error reported by
    /// [`Self::analyze_image`].
    pub fn verify_images(
        &self,
        images: &[ImageInput],
        challenge: &LivenessChallenge,
    ) -> Result<VerificationResult, FaceVerificationError> {
        let mut photos = Vec::with_capacity(images.len());
        for image in images {
            let hash = image_hash(image);
            let frame = self.analyze_image(image)?;
            photos.push(CapturedPhotoAnalysis {
                hash,
                frame,
                nsfw: None,
            });
        }
        Ok(verify_captured_photos(&photos, challenge, &self.thresholds))
    }

    /// Succeeds when the bundle contains a model with `role`; otherwise returns
    /// [`FaceVerificationError::ModelNotLoaded`] with a human-readable label
    /// for the missing role.
    fn require_role(&self, role: ModelRole) -> Result<(), FaceVerificationError> {
        if self.bundle.has_role(role) {
            return Ok(());
        }

        // Exhaustive on purpose: adding a role must force a label to be chosen.
        let label = match role {
            ModelRole::FaceDetector => "face detector",
            ModelRole::FaceLandmarks => "face landmarks",
            ModelRole::FaceEmbedding => "face embedding",
            ModelRole::FaceExpression => "face expression",
            ModelRole::HandLandmarks => "hand landmarks",
            ModelRole::Age => "age",
            ModelRole::Nsfw => "nsfw",
        };
        Err(FaceVerificationError::ModelNotLoaded(label))
    }
}
impl Analyzer for FaceVerificationEngine {
fn analyze(&self, image: &ImageInput) -> Result<FrameAnalysis, FaceVerificationError> {
self.analyze_image(image)
}
}
/// Returns the SHA-256 digest of the image bytes as a lowercase hex string.
///
/// Only `image.bytes` is hashed; no other field of the image contributes.
fn image_hash(image: &ImageInput) -> String {
    let digest = Sha256::digest(&image.bytes);
    format!("{digest:x}")
}
/// Format check used when the `runtime-tract` feature is disabled.
///
/// Without a runtime to parse the model, the only declared format (ONNX) is
/// accepted without inspecting the bytes.
#[cfg(not(feature = "runtime-tract"))]
fn validate_model_format(model: &ModelAsset) -> Result<(), FaceVerificationError> {
    // Irrefutable while `Onnx` is the only variant; adding a format makes this
    // line fail to compile, forcing the new format to be handled here.
    let ModelFormat::Onnx = model.format;
    Ok(())
}
/// Format check used when the `runtime-tract` feature is enabled.
///
/// Asks tract to read the ONNX bytes and discards the loaded model; only
/// whether reading succeeded matters.
///
/// # Errors
///
/// Returns [`FaceVerificationError::InvalidModel`] naming the model role and
/// tract's error when the bytes cannot be read as an ONNX model.
#[cfg(feature = "runtime-tract")]
fn validate_model_format(model: &ModelAsset) -> Result<(), FaceVerificationError> {
    // Irrefutable while `Onnx` is the only variant; adding a format makes this
    // line fail to compile, forcing the new format to be handled here.
    let ModelFormat::Onnx = model.format;

    let mut reader = std::io::Cursor::new(&model.bytes);
    match tract_onnx::onnx().model_for_read(&mut reader) {
        Ok(_) => Ok(()),
        Err(err) => Err(FaceVerificationError::InvalidModel(format!(
            "{:?} ONNX model could not be parsed by tract: {err}",
            model.role
        ))),
    }
}