use std::path::{Path, PathBuf};
use std::sync::Arc;
use crate::device::DeviceType;
use crate::error::{MlError, MlResult};
use crate::model::OnnxModel;
use crate::pipeline::{PipelineInfo, PipelineTask, TypedPipeline};
use crate::preprocess::{ImagePreprocessor, TensorLayout};
use super::types::FaceEmbedding;
/// An RGB face image held as a raw byte buffer together with its dimensions.
///
/// Use [`FaceImage::new`] to build values so the buffer length is checked
/// against the dimensions. The fields are public, so a struct literal skips
/// that check.
#[derive(Clone, Debug)]
pub struct FaceImage {
    /// Raw 8-bit samples, three bytes per pixel.
    pub pixels: Vec<u8>,
    /// Image width in pixels.
    pub width: u32,
    /// Image height in pixels.
    pub height: u32,
}

impl FaceImage {
    /// Wraps `pixels` as a `width` x `height` RGB image after validating its length.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error when:
    /// - `width * height * 3` does not fit in `usize`, or
    /// - `pixels.len()` differs from `width * height * 3`.
    pub fn new(pixels: Vec<u8>, width: u32, height: u32) -> MlResult<Self> {
        // Checked arithmetic on purpose: a plain `*` can overflow `usize` for
        // large dimensions, which panics in debug builds and silently wraps in
        // release builds. A wrapped value could equal the length of a buffer
        // that is far too small, letting it pass validation.
        let expected = (width as usize)
            .checked_mul(height as usize)
            .and_then(|pixel_count| pixel_count.checked_mul(3))
            .ok_or_else(|| {
                MlError::invalid_input(format!(
                    "face image: byte size of {width}x{height} RGB overflows usize"
                ))
            })?;

        if pixels.len() != expected {
            return Err(MlError::invalid_input(format!(
                "face image: expected {expected} bytes for {width}x{height} RGB, got {}",
                pixels.len()
            )));
        }

        Ok(Self {
            pixels,
            width,
            height,
        })
    }
}
/// Tunable settings for a face embedder.
#[derive(Clone, Debug)]
pub struct FaceEmbedderConfig {
    /// Model input resolution as `(width, height)`.
    pub input_size: (u32, u32),
    /// Per-channel mean handed to the image preprocessor.
    pub mean: [f32; 3],
    /// Per-channel standard deviation handed to the image preprocessor.
    pub std: [f32; 3],
    /// Number of values the model output must contain.
    pub embedding_dim: usize,
    /// Explicit model input name; `None` falls back to the model's first declared input.
    pub input_name: Option<String>,
    /// Explicit model output name; `None` falls back to the model's first declared output.
    pub output_name: Option<String>,
}

impl Default for FaceEmbedderConfig {
    /// Returns the stock configuration: 112x112 input, 512-value embedding,
    /// and input/output names resolved from the model.
    ///
    /// NOTE(review): the mean/std values look like the widely used ImageNet
    /// channel statistics; confirm they match the normalization the deployed
    /// model was trained with.
    fn default() -> Self {
        let side = 112;
        FaceEmbedderConfig {
            input_name: None,
            output_name: None,
            embedding_dim: 512,
            input_size: (side, side),
            mean: [0.485, 0.456, 0.406],
            std: [0.229, 0.224, 0.225],
        }
    }
}
/// Face-embedding pipeline: a shared ONNX model bundled with the configuration
/// and image preprocessor used to feed it.
pub struct FaceEmbedder {
    /// Reference-counted model handle, so several owners can share one model.
    model: Arc<OnnxModel>,
    /// Settings this embedder was built with.
    config: FaceEmbedderConfig,
    /// Preprocessor derived from `config` (input size, NCHW layout, mean, std).
    preprocessor: ImagePreprocessor,
    /// Path associated with the model; stored and reported back via `model_path`.
    model_path: PathBuf,
}

impl FaceEmbedder {
    /// Loads the model at `path` for `device` using the default configuration.
    ///
    /// # Errors
    ///
    /// Propagates any error from loading the model.
    pub fn load(path: impl AsRef<Path>, device: DeviceType) -> MlResult<Self> {
        let defaults = FaceEmbedderConfig::default();
        Self::load_with_config(path, device, defaults)
    }

    /// Loads the model at `path` for `device` and pairs it with `config`.
    ///
    /// # Errors
    ///
    /// Propagates any error from loading the model.
    pub fn load_with_config(
        path: impl AsRef<Path>,
        device: DeviceType,
        config: FaceEmbedderConfig,
    ) -> MlResult<Self> {
        let model_path = path.as_ref().to_owned();
        let loaded = OnnxModel::load(&model_path, device)?;
        Ok(Self::build(Arc::new(loaded), config, model_path))
    }

    /// Builds an embedder around an already-loaded, shared model.
    ///
    /// `model_path` is only recorded; nothing is read from disk here.
    #[must_use]
    pub fn from_shared(
        model: Arc<OnnxModel>,
        config: FaceEmbedderConfig,
        model_path: PathBuf,
    ) -> Self {
        Self::build(model, config, model_path)
    }

    /// Shared constructor: derives the preprocessor from `config` and assembles the struct.
    fn build(model: Arc<OnnxModel>, config: FaceEmbedderConfig, model_path: PathBuf) -> Self {
        let (width, height) = config.input_size;
        let sized = ImagePreprocessor::new(width, height);
        let laid_out = sized.with_tensor_layout(TensorLayout::Nchw);
        let preprocessor = laid_out.with_mean(config.mean).with_std(config.std);
        Self {
            preprocessor,
            model,
            config,
            model_path,
        }
    }

    /// Path recorded for the underlying model.
    #[must_use]
    pub fn model_path(&self) -> &Path {
        self.model_path.as_path()
    }

    /// Returns another handle to the shared model (bumps the reference count).
    #[must_use]
    pub fn shared_model(&self) -> Arc<OnnxModel> {
        self.model.clone()
    }

    /// Configuration this embedder was built with.
    #[must_use]
    pub fn config(&self) -> &FaceEmbedderConfig {
        &self.config
    }

    /// Input tensor shape as `[1, 3, height, width]`, taken from the configured input size.
    #[must_use]
    pub fn expected_input_shape(&self) -> [usize; 4] {
        let width = self.config.input_size.0 as usize;
        let height = self.config.input_size.1 as usize;
        [1, 3, height, width]
    }
}
impl TypedPipeline for FaceEmbedder {
    type Input = FaceImage;
    type Output = FaceEmbedding;

    /// Preprocesses `input`, runs the model on it, and wraps the selected
    /// output as a [`FaceEmbedding`].
    ///
    /// # Errors
    ///
    /// With the `onnx` feature enabled, fails when preprocessing or the model
    /// run fails, when the model declares no input/output and none is
    /// configured, when the chosen output is absent from the run results, or
    /// when the output length differs from `embedding_dim`.
    /// Without the `onnx` feature, always returns `MlError::FeatureDisabled("onnx")`.
    fn run(&self, input: Self::Input) -> MlResult<Self::Output> {
        #[cfg(feature = "onnx")]
        {
            use oxionnx::Tensor;
            use std::collections::HashMap;

            let data = self
                .preprocessor
                .process_u8_rgb(&input.pixels, input.width, input.height)?;
            let tensor = Tensor {
                data,
                shape: self.preprocessor.batch_shape(),
            };

            // A configured name wins; otherwise use the model's first declared entry.
            let input_name = match self.config.input_name.clone() {
                Some(name) => name,
                None => self
                    .model
                    .info()
                    .inputs
                    .first()
                    .map(|spec| spec.name.clone())
                    .ok_or_else(|| MlError::invalid_input("model has no declared inputs"))?,
            };
            let output_name = match self.config.output_name.clone() {
                Some(name) => name,
                None => self
                    .model
                    .info()
                    .outputs
                    .first()
                    .map(|spec| spec.name.clone())
                    .ok_or_else(|| MlError::invalid_input("model has no declared outputs"))?,
            };

            let mut feeds: HashMap<&str, Tensor> = HashMap::with_capacity(1);
            feeds.insert(input_name.as_str(), tensor);
            let results = self.model.run(&feeds)?;

            let embedding = match results.get(&output_name) {
                Some(found) => found,
                None => {
                    return Err(MlError::postprocess(format!(
                        "output '{output_name}' missing from model run"
                    )))
                }
            };

            // The output must carry exactly one embedding of the configured size.
            let actual_dim = embedding.data.len();
            if actual_dim == self.config.embedding_dim {
                Ok(FaceEmbedding::from_raw(embedding.data.clone()))
            } else {
                Err(MlError::postprocess(format!(
                    "face embedder expected output dim {}, got {}",
                    self.config.embedding_dim, actual_dim
                )))
            }
        }
        #[cfg(not(feature = "onnx"))]
        {
            // Touch the argument so this build does not warn about it being unused.
            let _ = input;
            Err(MlError::FeatureDisabled("onnx"))
        }
    }

    /// Static pipeline metadata plus the configured input size.
    fn info(&self) -> PipelineInfo {
        let input_size = Some(self.config.input_size);
        PipelineInfo {
            id: "face-embedder/arcface",
            name: "Face Embedder",
            task: PipelineTask::FaceEmbedding,
            input_size,
        }
    }
}
impl std::fmt::Debug for FaceEmbedder {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("FaceEmbedder")
.field("model_path", &self.model_path)
.field("input_size", &self.config.input_size)
.field("embedding_dim", &self.config.embedding_dim)
.finish()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A 2x2 RGB image needs 12 bytes, so a 10-byte buffer must be refused.
    #[test]
    fn face_image_rejects_wrong_buffer() {
        let short_buffer = vec![0u8; 10];
        let outcome = FaceImage::new(short_buffer, 2, 2);
        let err = outcome.expect_err("must fail");
        assert!(matches!(err, MlError::InvalidInput(_)));
    }

    /// A buffer of exactly `4 * 4 * 3` bytes is accepted for a 4x4 image.
    #[test]
    fn face_image_accepts_correct_buffer() {
        let exact_buffer = vec![0u8; 4 * 4 * 3];
        let img = FaceImage::new(exact_buffer, 4, 4).expect("ok");
        assert_eq!(img.width, 4);
    }

    /// The default configuration targets 112x112 input and 512-value embeddings.
    #[test]
    fn default_config_is_arcface_112() {
        let FaceEmbedderConfig {
            input_size,
            embedding_dim,
            ..
        } = FaceEmbedderConfig::default();
        assert_eq!(input_size, (112, 112));
        assert_eq!(embedding_dim, 512);
    }
}