1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
use std::sync::Arc;
use ort::{
execution_providers::{CUDAExecutionProviderOptions, CoreMLExecutionProviderOptions},
ExecutionProvider,
};
use crate::{
blazeface::BlazeFaceParams,
detection::{FaceDetector, RustFacesResult},
model_repository::{GitHubRepository, ModelRepository},
BlazeFace, MtCnn, MtCnnParams,
};
/// Selects which face detector to build; each variant carries that detector's parameters.
///
/// A value of this type is handed to [`FaceDetectorBuilder::new`].
pub enum FaceDetection {
/// BlazeFace detector, instantiated from a single model file.
/// NOTE(review): the `640` suffix presumably denotes the model's input resolution — confirm.
BlazeFace640(BlazeFaceParams),
/// BlazeFace detector, instantiated from a single model file.
/// NOTE(review): the `320` suffix presumably denotes the model's input resolution — confirm.
BlazeFace320(BlazeFaceParams),
/// MTCNN detector, instantiated from three model files.
MtCnn(MtCnnParams),
}
/// Where the builder obtains the model file(s) from.
#[derive(Clone, Debug)]
enum OpenMode {
/// Use the model file located at the given path.
File(String),
/// Ask the model repository for the model path(s) of the selected detector.
Download,
}
/// Runtime inference provider. Some may not be available depending on your ONNX Runtime
/// installation.
#[derive(Clone, Copy, Debug)]
pub enum Provider {
/// Uses the default CPU inference.
OrtCpu,
/// Uses CUDA inference. The payload is the CUDA device id.
OrtCuda(i32),
/// Uses Intel's OpenVINO inference. The payload is a device id.
///
/// Building a detector with this provider currently fails with an
/// "OpenVINO is not supported yet." error.
OrtVino(i32),
/// Apple's Core ML inference.
OrtCoreMl,
}
/// Inference parameters.
///
/// NOTE(review): in the code visible in this file only `provider` is read when building a
/// detector; confirm where (or whether) the thread counts are applied.
pub struct InferParams {
/// Chooses the ONNX runtime provider.
pub provider: Provider,
/// Sets the number of intra-op threads. `None` leaves the value unspecified.
pub intra_threads: Option<usize>,
/// Sets the number of inter-op threads. `None` leaves the value unspecified.
pub inter_threads: Option<usize>,
}
impl Default for InferParams {
/// Default provider is `OrtCpu` (Onnx CPU).
fn default() -> Self {
Self {
provider: Provider::OrtCpu,
intra_threads: None,
inter_threads: None,
}
}
}
/// Builder for loading or downloading, configuring, and creating face detectors.
pub struct FaceDetectorBuilder {
/// Detector to instantiate, together with its parameters.
detector: FaceDetection,
/// Whether the model comes from a local file or from the model repository.
open_mode: OpenMode,
/// Inference settings (provider and thread counts).
infer_params: InferParams,
}
impl FaceDetectorBuilder {
/// Create a new builder for the given face detector.
///
/// # Arguments
///
/// * `detector` - The face detector to build.
pub fn new(detector: FaceDetection) -> Self {
Self {
detector,
open_mode: OpenMode::Download,
infer_params: InferParams::default(),
}
}
/// Load the model from the given file path.
///
/// # Arguments
///
/// * `path` - Path to the model file.
pub fn from_file(mut self, path: String) -> Self {
self.open_mode = OpenMode::File(path);
self
}
/// Sets the model to be downloaded from the model repository.
pub fn download(mut self) -> Self {
self.open_mode = OpenMode::Download;
self
}
/// Sets the inference parameters.
pub fn infer_params(mut self, params: InferParams) -> Self {
self.infer_params = params;
self
}
/// Instantiates a new detector.
///
/// # Errors
///
/// Returns an error if the model can't be loaded.
///
/// # Returns
///
/// A new face detector.
pub fn build(&self) -> RustFacesResult<Box<dyn FaceDetector>> {
let mut ort_builder = ort::Environment::builder().with_name("RustFaces");
ort_builder = match self.infer_params.provider {
Provider::OrtCuda(device_id) => {
let provider = ExecutionProvider::CUDA(CUDAExecutionProviderOptions {
device_id: device_id as u32,
..Default::default()
});
if !provider.is_available() {
eprintln!("Warning: CUDA is not available. It'll likely use CPU inference.");
}
ort_builder.with_execution_providers([provider])
}
Provider::OrtVino(_device_id) => {
return Err(crate::RustFacesError::Other(
"OpenVINO is not supported yet.".to_string(),
));
}
Provider::OrtCoreMl => {
ort_builder.with_execution_providers([ExecutionProvider::CoreML(
CoreMLExecutionProviderOptions::default(),
)])
}
_ => ort_builder,
};
let env = Arc::new(ort_builder.build()?);
let repository = GitHubRepository::new();
let model_paths = match &self.open_mode {
OpenMode::Download => repository
.get_model(&self.detector)?
.iter()
.map(|path| path.to_str().unwrap().to_string())
.collect(),
OpenMode::File(path) => vec![path.clone()],
};
match &self.detector {
FaceDetection::BlazeFace640(params) => Ok(Box::new(BlazeFace::from_file(
env,
&model_paths[0],
params.clone(),
))),
FaceDetection::BlazeFace320(params) => Ok(Box::new(BlazeFace::from_file(
env,
&model_paths[0],
params.clone(),
))),
FaceDetection::MtCnn(params) => Ok(Box::new(
MtCnn::from_file(
env,
&model_paths[0],
&model_paths[1],
&model_paths[2],
params.clone(),
)
.unwrap(),
)),
}
}
}
// Placeholder test module: no unit tests are defined here yet.
#[cfg(test)]
mod tests {}