1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
use std::sync::Arc;

use ort::ExecutionProvider;

use crate::{
    detection::{DetectionParams, FaceDetector, RustFacesResult},
    model_repository::{GitHubRepository, ModelRepository},
    BlazeFace, Nms,
};

/// Face detection models that [`FaceDetectorBuilder`] knows how to build.
///
/// The explicit discriminants are part of the enum's representation and are kept stable.
#[derive(Debug, Clone, Copy)]
pub enum FaceDetection {
    /// BlazeFace, "640" variant (presumably the input resolution — confirm against the model).
    BlazeFace640 = 0,
    /// BlazeFace, "320" variant (presumably the input resolution — confirm against the model).
    BlazeFace320 = 1,
}

/// Where the builder obtains the model file from.
#[derive(Debug, Clone)]
enum OpenMode {
    /// Use the model stored at this path.
    File(String),
    /// Fetch the model from the model repository.
    Download,
}

/// Runtime inference provider.
///
/// A provider may not be available, depending on your ONNX runtime installation.
#[derive(Clone, Copy, Debug)]
pub enum Provider {
    /// Uses the default CPU inference.
    OrtCpu,
    /// Uses CUDA inference on the given device id.
    OrtCuda(i32),
    /// Uses Intel's OpenVINO inference on the given device id.
    ///
    /// Not supported yet: [`FaceDetectorBuilder::build`] returns an error for this provider.
    OrtVino(i32),
    /// Uses Apple's Core ML inference.
    OrtCoreMl,
}

/// Inference parameters.
///
/// The [`Default`] value selects [`Provider::OrtCpu`] and sets no explicit thread counts.
#[derive(Clone, Debug)]
pub struct InferParams {
    /// Chooses the ONNX runtime provider.
    pub provider: Provider,
    // NOTE(review): `FaceDetectorBuilder::build` in this file does not read this value yet.
    /// Sets the number of intra-op threads.
    pub intra_threads: Option<usize>,
    // NOTE(review): `FaceDetectorBuilder::build` in this file does not read this value yet.
    /// Sets the number of inter-op threads.
    pub inter_threads: Option<usize>,
}

impl Default for InferParams {
    fn default() -> Self {
        Self {
            provider: Provider::OrtCpu,
            intra_threads: None,
            inter_threads: None,
        }
    }
}

/// Builder for loading or downloading and creating face detectors.
pub struct FaceDetectorBuilder {
    /// Which detection model to build.
    detector: FaceDetection,
    /// Whether the model comes from a local file or the model repository.
    open_mode: OpenMode,
    /// Detection parameters handed to the detector.
    params: DetectionParams,
    /// ONNX runtime configuration (provider, threads).
    infer_params: InferParams,
}

impl FaceDetectorBuilder {
    /// Create a new builder for the given face detector.
    ///
    /// The builder starts in download mode, with [`DetectionParams::default`] and
    /// [`InferParams::default`].
    ///
    /// # Arguments
    ///
    /// * `detector` - The face detector to build.
    pub fn new(detector: FaceDetection) -> Self {
        Self {
            detector,
            open_mode: OpenMode::Download,
            params: DetectionParams::default(),
            infer_params: InferParams::default(),
        }
    }

    /// Load the model from the given file path instead of downloading it.
    pub fn from_file(mut self, path: String) -> Self {
        self.open_mode = OpenMode::File(path);
        self
    }

    /// Download the model from the model repository (the default mode).
    pub fn download(mut self) -> Self {
        self.open_mode = OpenMode::Download;
        self
    }

    /// Set the detection parameters.
    ///
    /// This replaces all detection parameters, including any value previously set with
    /// [`FaceDetectorBuilder::nms`].
    pub fn detect_params(mut self, params: DetectionParams) -> Self {
        self.params = params;
        self
    }

    /// Set the non-maximum suppression, keeping the other detection parameters.
    pub fn nms(mut self, nms: Nms) -> Self {
        self.params.nms = nms;
        self
    }

    /// Sets the inference parameters.
    pub fn infer_params(mut self, params: InferParams) -> Self {
        self.infer_params = params;
        self
    }

    /// Builds a new detector.
    ///
    /// Creates an ONNX runtime environment configured with the selected execution provider,
    /// resolves the model path (fetching it from the repository in download mode), and loads
    /// the detector from that path.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// * the provider is [`Provider::OrtVino`] (not supported yet);
    /// * the runtime environment cannot be built;
    /// * the model cannot be obtained from the repository;
    /// * the downloaded model path is not valid UTF-8.
    pub fn build(&self) -> RustFacesResult<Box<dyn FaceDetector>> {
        let ort_builder = ort::Environment::builder().with_name("RustFaces");

        // Exhaustive on purpose: adding a `Provider` variant must force a decision here.
        // NOTE(review): `infer_params.intra_threads` / `inter_threads` are not applied
        // anywhere in this method; wiring them up needs the detector's session setup.
        let ort_builder = match self.infer_params.provider {
            Provider::OrtCpu => ort_builder,
            Provider::OrtCuda(device_id) => ort_builder
                .with_execution_providers([ExecutionProvider::cuda().with_device_id(device_id)]),
            Provider::OrtVino(_device_id) => {
                return Err(crate::RustFacesError::Other(
                    "OpenVINO is not supported yet.".to_string(),
                ));
            }
            Provider::OrtCoreMl => {
                ort_builder.with_execution_providers([ExecutionProvider::coreml()])
            }
        };

        let env = Arc::new(ort_builder.build()?);

        let model_path = match &self.open_mode {
            OpenMode::Download => {
                // The repository is only needed when a download was actually requested.
                let path = GitHubRepository::new().get_model(self.detector)?;
                // A non-UTF-8 path is reported as an error instead of panicking.
                path.to_str()
                    .ok_or_else(|| {
                        crate::RustFacesError::Other(format!(
                            "Downloaded model path is not valid UTF-8: {:?}",
                            path
                        ))
                    })?
                    .to_string()
            }
            OpenMode::File(path) => path.clone(),
        };

        let detector: Box<dyn FaceDetector> = match self.detector {
            // Both BlazeFace variants are loaded through the same constructor.
            FaceDetection::BlazeFace640 | FaceDetection::BlazeFace320 => {
                Box::new(BlazeFace::from_file(env, &model_path, self.params))
            }
        };
        Ok(detector)
    }
}

#[cfg(test)]
mod tests {
    //! Offline checks of builder state; none of these touch the network or the ONNX runtime.

    use super::*;

    #[test]
    fn new_defaults_to_download_and_cpu() {
        let builder = FaceDetectorBuilder::new(FaceDetection::BlazeFace640);
        assert!(matches!(builder.detector, FaceDetection::BlazeFace640));
        assert!(matches!(builder.open_mode, OpenMode::Download));
        assert!(matches!(builder.infer_params.provider, Provider::OrtCpu));
        assert!(builder.infer_params.intra_threads.is_none());
        assert!(builder.infer_params.inter_threads.is_none());
    }

    #[test]
    fn from_file_then_download_switches_open_mode() {
        let builder = FaceDetectorBuilder::new(FaceDetection::BlazeFace320)
            .from_file("model.onnx".to_string());
        assert!(matches!(&builder.open_mode, OpenMode::File(p) if p == "model.onnx"));

        let builder = builder.download();
        assert!(matches!(builder.open_mode, OpenMode::Download));
    }

    #[test]
    fn infer_params_replaces_provider() {
        let builder = FaceDetectorBuilder::new(FaceDetection::BlazeFace320).infer_params(
            InferParams {
                provider: Provider::OrtCuda(1),
                intra_threads: Some(2),
                inter_threads: None,
            },
        );
        assert!(matches!(builder.infer_params.provider, Provider::OrtCuda(1)));
        assert_eq!(builder.infer_params.intra_threads, Some(2));
    }
}