use std::path::Path;
use super::base::FaceDetectorTrait;
use super::location::FaceLocations;
use crate::base::path_as_cstring;
use crate::matrix::ImageMatrix;
/// Face detector backed by a dlib CNN network (C++ type `face_detection_cnn`).
///
/// Instances are created with [`FaceDetectorCnn::open`], or with `default()` when the
/// `embed-fd-nn` feature is enabled.
#[derive(Clone)]
pub struct FaceDetectorCnn {
    /// `UnsafeCell<()>` is `!Sync`, so this zero-sized marker keeps `FaceDetectorCnn`
    /// from being `Sync` (shared by reference across threads).
    /// NOTE(review): presumably required because `face_locations` runs the network
    /// through `&self` and the C++ forward pass may mutate internal state — confirm.
    data: std::marker::PhantomData<std::cell::UnsafeCell<()>>,
    /// Opaque Rust-side storage for the C++ network object.
    inner: FaceDetectorCnnInner,
}

// Opaque Rust wrapper around the C++ `face_detection_cnn` type, which is declared in C++
// code elsewhere in the crate (not visible here).
cpp_class!(unsafe struct FaceDetectorCnnInner as "face_detection_cnn");
impl FaceDetectorCnn {
    /// Loads the CNN face detector from the crate's default model location.
    ///
    /// The path comes from `crate::embed::path_for_file`. Before loading,
    /// `check_file_or_download` is called for the same model file (presumably fetching it
    /// when it is missing — see `crate::embed`).
    ///
    /// # Errors
    ///
    /// Returns the same errors as [`FaceDetectorCnn::open`].
    #[cfg(feature = "embed-fd-nn")]
    pub fn default() -> Result<Self, String> {
        use crate::embed::{check_file_or_download, ModelFile};

        let filename = ModelFile::FaceDetectorCnn;
        let default_filepath = crate::embed::path_for_file(&filename);
        check_file_or_download(&filename);
        Self::open(default_filepath)
    }

    /// Loads a serialized dlib CNN face detection network from `filename`.
    ///
    /// # Errors
    ///
    /// Returns `Err` in two cases:
    /// - the path cannot be converted by `path_as_cstring`; its error is propagated;
    /// - dlib throws while deserializing the file into the network. The message is
    ///   `Failed to deserialize '<path>'`.
    pub fn open<P: AsRef<Path>>(filename: P) -> Result<Self, String> {
        let filename = filename.as_ref();
        let c_filename = path_as_cstring(filename)?;
        let mut inner = FaceDetectorCnnInner::default();

        // SAFETY: `c_filename` stays alive for the whole call, so the pointer handed to C++
        // is valid. `network` is an *exclusive* borrow of `inner` because the C++ side
        // overwrites it in place; writing through a pointer derived from `&inner` would be
        // mutation through a shared reference. Every C++ exception is caught (`catch (...)`),
        // because letting one unwind into Rust is undefined behaviour.
        let deserialized = unsafe {
            let filename = c_filename.as_ptr();
            let network = &mut inner;
            cpp!([filename as "const char*", network as "face_detection_cnn*"] -> bool as "bool" {
                try {
                    dlib::deserialize(filename) >> *network;
                    return true;
                } catch (...) {
                    return false;
                }
            })
        };

        if deserialized {
            Ok(Self {
                inner,
                data: std::marker::PhantomData,
            })
        } else {
            Err(format!("Failed to deserialize '{}'", filename.display()))
        }
    }
}
impl FaceDetectorTrait for FaceDetectorCnn {
    /// Runs the CNN network on `image` and returns one rectangle per detection.
    ///
    /// Only the `rect` field of each `dlib::mmod_rect` produced by the network is kept.
    /// The rectangles are returned in the order the network reports them.
    fn face_locations(&self, image: &ImageMatrix) -> FaceLocations {
        let network = &self.inner;
        // SAFETY: both pointers come from live Rust references for the duration of the
        // call, and `image` is only read on the C++ side (hence the `const` pointer).
        // NOTE(review): the network is invoked through a pointer derived from `&self`. If
        // dlib's forward pass mutates internal buffers, this is mutation through a shared
        // reference. The `!Sync` marker on the struct prevents concurrent use but does not
        // make this sound by itself — confirm.
        unsafe {
            cpp!([network as "face_detection_cnn*", image as "const dlib::matrix<dlib::rgb_pixel>*"] -> FaceLocations as "std::vector<dlib::rectangle>" {
                const std::vector<dlib::mmod_rect>& found = (*network)(*image);
                std::vector<dlib::rectangle> boxes;
                boxes.reserve(found.size());
                for (const dlib::mmod_rect& hit : found) {
                    boxes.emplace_back(hit.rect);
                }
                return boxes;
            })
        }
    }
}