use std::sync::Arc;
use ort::ExecutionProvider;
use crate::{
detection::{DetectionParams, FaceDetector, RustFacesResult},
model_repository::{GitHubRepository, ModelRepository},
BlazeFace, Nms,
};
/// Face-detection model variants that `FaceDetectorBuilder` can construct.
///
/// Both variants are currently built with the BlazeFace implementation; in download
/// mode the variant is passed to the model repository to pick which model to fetch.
///
/// NOTE(review): the explicit discriminants may be relied upon elsewhere (e.g. by the
/// model repository) — confirm before reordering or renumbering variants.
#[derive(Clone, Copy, Debug)]
pub enum FaceDetection {
    /// BlazeFace model; the `640` suffix presumably denotes the input resolution — confirm.
    BlazeFace640 = 0,
    /// BlazeFace model; the `320` suffix presumably denotes the input resolution — confirm.
    BlazeFace320 = 1,
}
/// Where the builder obtains the model file from.
#[derive(Clone, Debug)]
enum OpenMode {
    /// Path to a local model file, used as-is.
    File(String),
    /// Resolve the model through `GitHubRepository::get_model`.
    Download,
}
/// ONNX Runtime execution provider to register on the inference environment.
#[derive(Clone, Copy, Debug)]
pub enum Provider {
    /// Plain ONNX Runtime; no additional execution provider is registered.
    OrtCpu,
    /// CUDA execution provider on the given device id.
    OrtCuda(i32),
    /// OpenVINO. Not supported yet: building a detector with this provider returns
    /// an error, and the `i32` is currently unused.
    OrtVino(i32),
    /// CoreML execution provider.
    OrtCoreMl,
}
/// Inference-runtime settings used when building a face detector.
///
/// Derives `Clone` and `Debug` so it can be logged and reused, consistent with
/// [`Provider`], whose fields are all `Clone + Debug` as well.
#[derive(Clone, Debug)]
pub struct InferParams {
    /// Execution provider registered on the ONNX Runtime environment.
    pub provider: Provider,
    /// Presumably the ONNX Runtime intra-op thread count (`None` = runtime default).
    ///
    /// NOTE(review): not currently read by `FaceDetectorBuilder::build` — confirm
    /// whether it is applied anywhere before relying on it.
    pub intra_threads: Option<usize>,
    /// Presumably the ONNX Runtime inter-op thread count (`None` = runtime default).
    ///
    /// NOTE(review): not currently read by `FaceDetectorBuilder::build` — confirm
    /// whether it is applied anywhere before relying on it.
    pub inter_threads: Option<usize>,
}
impl Default for InferParams {
fn default() -> Self {
Self {
provider: Provider::OrtCpu,
intra_threads: None,
inter_threads: None,
}
}
}
/// Builder that configures and constructs a boxed [`FaceDetector`].
///
/// `FaceDetectorBuilder::new` defaults to downloading the model and to the default
/// detection and inference parameters.
pub struct FaceDetectorBuilder {
    /// Model variant to build.
    detector: FaceDetection,
    /// Whether to load the model from a local file or download it.
    open_mode: OpenMode,
    /// Detection parameters passed to the constructed detector.
    params: DetectionParams,
    /// Runtime settings (execution provider and thread counts).
    infer_params: InferParams,
}
impl FaceDetectorBuilder {
    /// Creates a builder for the given model variant.
    ///
    /// Defaults: the model is downloaded through the GitHub model repository, and
    /// [`DetectionParams::default`] / [`InferParams::default`] are used.
    pub fn new(detector: FaceDetection) -> Self {
        Self {
            detector,
            open_mode: OpenMode::Download,
            params: DetectionParams::default(),
            infer_params: InferParams::default(),
        }
    }

    /// Loads the model from the local file at `path` instead of downloading it.
    pub fn from_file(mut self, path: String) -> Self {
        self.open_mode = OpenMode::File(path);
        self
    }

    /// Resolves the model through the GitHub model repository (the default mode).
    pub fn download(mut self) -> Self {
        self.open_mode = OpenMode::Download;
        self
    }

    /// Replaces all detection parameters, including any NMS settings set earlier
    /// with [`Self::nms`].
    pub fn detect_params(mut self, params: DetectionParams) -> Self {
        self.params = params;
        self
    }

    /// Overrides only the NMS settings inside the detection parameters.
    pub fn nms(mut self, nms: Nms) -> Self {
        self.params.nms = nms;
        self
    }

    /// Replaces the inference-runtime settings.
    pub fn infer_params(mut self, params: InferParams) -> Self {
        self.infer_params = params;
        self
    }

    /// Builds the configured face detector. The builder is not consumed.
    ///
    /// Creates an ONNX Runtime environment named `"RustFaces"` with the requested
    /// execution provider, resolves the model path, and constructs the detector.
    ///
    /// NOTE(review): only `infer_params.provider` is read here; `intra_threads` and
    /// `inter_threads` are not applied by this function.
    ///
    /// # Errors
    ///
    /// - `RustFacesError::Other` if the provider is [`Provider::OrtVino`] (not supported yet).
    /// - Any error from building the ONNX Runtime environment.
    /// - Any error from the model repository while downloading the model.
    /// - `RustFacesError::Other` if the downloaded model path is not valid UTF-8.
    pub fn build(&self) -> RustFacesResult<Box<dyn FaceDetector>> {
        let ort_builder = ort::Environment::builder().with_name("RustFaces");
        // Every variant is listed explicitly so adding a provider forces a decision here.
        let ort_builder = match self.infer_params.provider {
            Provider::OrtCpu => ort_builder,
            Provider::OrtCuda(device_id) => ort_builder
                .with_execution_providers([ExecutionProvider::cuda().with_device_id(device_id)]),
            Provider::OrtCoreMl => {
                ort_builder.with_execution_providers([ExecutionProvider::coreml()])
            }
            Provider::OrtVino(_device_id) => {
                return Err(crate::RustFacesError::Other(
                    "OpenVINO is not supported yet.".to_string(),
                ));
            }
        };
        let env = Arc::new(ort_builder.build()?);

        let model_path = match &self.open_mode {
            OpenMode::File(path) => path.clone(),
            OpenMode::Download => {
                // The repository is only needed when downloading, so create it lazily.
                let repository = GitHubRepository::new();
                let downloaded = repository.get_model(self.detector)?;
                // Library code: report a non-UTF-8 path as an error instead of panicking.
                downloaded
                    .to_str()
                    .map(str::to_string)
                    .ok_or_else(|| {
                        crate::RustFacesError::Other(format!(
                            "Downloaded model path is not valid UTF-8: {:?}",
                            downloaded
                        ))
                    })?
            }
        };

        // Both variants are constructed the same way; the variant only influences which
        // model `get_model` is asked for in download mode.
        Ok(Box::new(match self.detector {
            FaceDetection::BlazeFace640 | FaceDetection::BlazeFace320 => {
                BlazeFace::from_file(env, &model_path, self.params)
            }
        }))
    }
}
// NOTE(review): placeholder test module — no unit tests are defined yet.
#[cfg(test)]
mod tests {}