// face_verification_core 0.2.0
//
// Cross-platform on-device face liveness and verification core.

use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
#[cfg(feature = "runtime-tract")]
use tract_onnx::prelude::Framework;

use crate::{
    verify_captured_photos, Analyzer, CapturedPhotoAnalysis, FaceVerificationError, FrameAnalysis,
    ImageInput, LivenessChallenge, VerificationResult, VerificationThresholds,
};

/// Role a model plays in the full image-analysis pipeline.
///
/// A [`ModelBundle`] may be queried for a role with [`ModelBundle::has_role`];
/// `FaceVerificationEngine::analyze_image` requires `FaceDetector` and `FaceEmbedding`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelRole {
    /// Face detection model.
    FaceDetector,
    /// Face landmark model.
    FaceLandmarks,
    /// Face embedding model.
    FaceEmbedding,
    /// Face expression model.
    FaceExpression,
    /// Hand landmark model.
    HandLandmarks,
    /// Age model.
    Age,
    /// NSFW model.
    Nsfw,
}

/// Serialized model format accepted by the Rust runtime.
///
/// ONNX is currently the only variant.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelFormat {
    /// ONNX model; with the `runtime-tract` feature, bundle validation parses it with tract.
    Onnx,
}

/// One model file plus the metadata needed to wire it into a pipeline.
///
/// `Debug` is implemented by hand rather than derived. Formatting an asset, or a
/// [`ModelBundle`] or engine that contains one, prints the length of `bytes` instead of
/// every byte of the serialized model.
#[derive(Clone, Serialize, Deserialize, PartialEq)]
pub struct ModelAsset {
    /// Pipeline slot this model fills.
    pub role: ModelRole,
    /// Serialization format of `bytes`.
    pub format: ModelFormat,
    /// Raw serialized model. [`ModelBundle::validate`] rejects empty bytes.
    pub bytes: Vec<u8>,
    /// Expected input width, if known (presumably in pixels; TODO confirm).
    pub input_width: Option<u32>,
    /// Expected input height, if known (presumably in pixels; TODO confirm).
    pub input_height: Option<u32>,
    /// Input name, if it must be addressed explicitly (assumed to be the ONNX input
    /// tensor name; TODO confirm).
    pub input_name: Option<String>,
    /// Output names (assumed to be ONNX output tensor names; TODO confirm).
    pub output_names: Vec<String>,
    /// Optional license information for the model.
    pub license: Option<String>,
    /// Optional description of where the model came from.
    pub source: Option<String>,
}

impl std::fmt::Debug for ModelAsset {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ModelAsset")
            .field("role", &self.role)
            .field("format", &self.format)
            // Summarize the payload: a derived impl would print every byte of the model.
            .field("bytes", &format_args!("<{} bytes>", self.bytes.len()))
            .field("input_width", &self.input_width)
            .field("input_height", &self.input_height)
            .field("input_name", &self.input_name)
            .field("output_names", &self.output_names)
            .field("license", &self.license)
            .field("source", &self.source)
            .finish()
    }
}

/// A complete, swappable model set.
///
/// `ModelBundle::default()` is an empty bundle, which [`ModelBundle::validate`] rejects.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct ModelBundle {
    /// Models in the set. `validate` checks and reports them in this order.
    pub models: Vec<ModelAsset>,
}

/// Summary produced by a successful [`ModelBundle::validate`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelValidation {
    /// Number of models in the bundle.
    pub model_count: usize,
    /// Role of each model in bundle order. Duplicate roles are kept.
    pub roles: Vec<ModelRole>,
}

impl ModelBundle {
    /// Checks that every model in the bundle is present and well-formed.
    ///
    /// Models are checked in order, and the first failure is returned. On success the
    /// returned [`ModelValidation`] lists each model's role in bundle order.
    ///
    /// # Errors
    ///
    /// Returns [`FaceVerificationError::InvalidModel`] when:
    /// - the bundle contains no models,
    /// - a model's `bytes` are empty, or
    /// - the format-specific check rejects a model (with the `runtime-tract` feature,
    ///   this happens when tract cannot parse the ONNX bytes).
    pub fn validate(&self) -> Result<ModelValidation, FaceVerificationError> {
        if self.models.is_empty() {
            return Err(FaceVerificationError::InvalidModel(String::from(
                "model bundle is empty",
            )));
        }

        self.models.iter().try_for_each(|asset| {
            if asset.bytes.is_empty() {
                Err(FaceVerificationError::InvalidModel(format!(
                    "{:?} model bytes are empty",
                    asset.role
                )))
            } else {
                validate_model_format(asset)
            }
        })?;

        let roles: Vec<ModelRole> = self.models.iter().map(|asset| asset.role).collect();
        Ok(ModelValidation {
            model_count: roles.len(),
            roles,
        })
    }

    /// Returns `true` if at least one model in the bundle fills `role`.
    pub fn has_role(&self, role: ModelRole) -> bool {
        self.models
            .iter()
            .map(|asset| asset.role)
            .any(|candidate| candidate == role)
    }
}

/// Headless verification engine.
///
/// Holds a [`ModelBundle`] (validated by `FaceVerificationEngine::new`) together with
/// the [`VerificationThresholds`] used to decide verification results. It is intended to
/// host the ONNX/tract image-analysis pipeline, which is not implemented yet.
#[derive(Clone, Debug)]
pub struct FaceVerificationEngine {
    /// Model set the engine analyzes images with.
    bundle: ModelBundle,
    /// Thresholds passed to `verify_captured_photos`.
    thresholds: VerificationThresholds,
}

impl FaceVerificationEngine {
    /// Creates an engine from a model bundle and verification thresholds.
    ///
    /// # Errors
    ///
    /// Returns any error from [`ModelBundle::validate`]. The bundle is stored only after
    /// validation succeeds.
    pub fn new(
        bundle: ModelBundle,
        thresholds: VerificationThresholds,
    ) -> Result<Self, FaceVerificationError> {
        bundle.validate()?;
        Ok(Self { bundle, thresholds })
    }

    /// Analyzes a single image.
    ///
    /// # Errors
    ///
    /// - [`FaceVerificationError::InvalidImage`] if `image` has no bytes or a zero
    ///   width or height.
    /// - [`FaceVerificationError::ModelNotLoaded`] if the bundle has no
    ///   [`ModelRole::FaceDetector`] or no [`ModelRole::FaceEmbedding`] model.
    /// - [`FaceVerificationError::NotImplemented`] otherwise. The ONNX analysis runtime
    ///   is not wired up yet, so this method currently never returns `Ok`.
    pub fn analyze_image(
        &self,
        image: &ImageInput,
    ) -> Result<FrameAnalysis, FaceVerificationError> {
        if image.bytes.is_empty() || image.width == 0 || image.height == 0 {
            return Err(FaceVerificationError::InvalidImage);
        }
        self.require_role(ModelRole::FaceDetector)?;
        self.require_role(ModelRole::FaceEmbedding)?;

        Err(FaceVerificationError::NotImplemented(
            "ONNX image analysis runtime",
        ))
    }

    /// Analyzes every image, then runs `verify_captured_photos` over the results using
    /// this engine's thresholds.
    ///
    /// Images are processed in order, and the first analysis error aborts the call. An
    /// empty `images` slice is passed to `verify_captured_photos` as an empty photo list.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by [`Self::analyze_image`].
    pub fn verify_images(
        &self,
        images: &[ImageInput],
        challenge: &LivenessChallenge,
    ) -> Result<VerificationResult, FaceVerificationError> {
        let photos = images
            .iter()
            .map(|image| {
                // Analyze before hashing. On failure `?` returns early, so no SHA-256 pass
                // over the image bytes is wasted on an image that is discarded anyway.
                let frame = self.analyze_image(image)?;
                Ok(CapturedPhotoAnalysis {
                    hash: image_hash(image),
                    frame,
                    // No NSFW result is computed here; the field is always `None`.
                    nsfw: None,
                })
            })
            .collect::<Result<Vec<_>, FaceVerificationError>>()?;

        Ok(verify_captured_photos(&photos, challenge, &self.thresholds))
    }

    /// Succeeds if the bundle contains a model for `role`. Otherwise returns
    /// [`FaceVerificationError::ModelNotLoaded`] with a human-readable role name.
    fn require_role(&self, role: ModelRole) -> Result<(), FaceVerificationError> {
        if self.bundle.has_role(role) {
            return Ok(());
        }

        let name = match role {
            ModelRole::FaceDetector => "face detector",
            ModelRole::FaceLandmarks => "face landmarks",
            ModelRole::FaceEmbedding => "face embedding",
            ModelRole::FaceExpression => "face expression",
            ModelRole::HandLandmarks => "hand landmarks",
            ModelRole::Age => "age",
            ModelRole::Nsfw => "nsfw",
        };
        Err(FaceVerificationError::ModelNotLoaded(name))
    }
}

impl Analyzer for FaceVerificationEngine {
    fn analyze(&self, image: &ImageInput) -> Result<FrameAnalysis, FaceVerificationError> {
        self.analyze_image(image)
    }
}

/// Returns the lowercase hexadecimal SHA-256 digest of `image.bytes`.
///
/// Only the raw bytes are hashed; no other `ImageInput` field contributes.
fn image_hash(image: &ImageInput) -> String {
    let digest = Sha256::digest(&image.bytes);
    format!("{digest:x}")
}

/// Format check used when the `runtime-tract` feature is disabled.
///
/// No runtime is compiled in to parse the model, so every known [`ModelFormat`] is
/// accepted as-is.
#[cfg(not(feature = "runtime-tract"))]
fn validate_model_format(model: &ModelAsset) -> Result<(), FaceVerificationError> {
    // Irrefutable only while `Onnx` is the sole variant. Adding a format turns this into
    // a compile error, which forces the new variant to be handled here.
    let ModelFormat::Onnx = model.format;
    Ok(())
}

/// Format check used when the `runtime-tract` feature is enabled.
///
/// For ONNX, asks tract to read the bytes into a model and discards the result. Only
/// parsing is attempted; the model is not optimized or executed here.
///
/// # Errors
///
/// Returns [`FaceVerificationError::InvalidModel`] carrying tract's error message when
/// parsing fails.
#[cfg(feature = "runtime-tract")]
fn validate_model_format(model: &ModelAsset) -> Result<(), FaceVerificationError> {
    match model.format {
        ModelFormat::Onnx => {
            let mut reader = std::io::Cursor::new(model.bytes.as_slice());
            match tract_onnx::onnx().model_for_read(&mut reader) {
                Ok(_) => Ok(()),
                Err(err) => Err(FaceVerificationError::InvalidModel(format!(
                    "{:?} ONNX model could not be parsed by tract: {err}",
                    model.role
                ))),
            }
        }
    }
}