use std::{
ffi::{c_void, CString},
future::Future,
panic::AssertUnwindSafe,
path::{Path, PathBuf},
pin::Pin,
task::{Context, Poll},
};
use doom_fish_utils::completion::{error_from_cstr, AsyncCompletion, AsyncCompletionFuture};
use doom_fish_utils::panic_safe::log_callback_panic;
use crate::{error::VisionError, ffi};
#[cfg(feature = "coreml")]
use crate::classify::Classification;
#[cfg(feature = "coreml")]
use crate::coreml::{CoreMLFeatureValueObservation, CoreMLRequest};
#[cfg(feature = "detect_barcodes")]
use crate::detect_barcodes::DetectedBarcode;
#[cfg(feature = "detect_faces")]
use crate::detect_faces::DetectedFace;
use crate::human_body_pose_3d::HumanBodyPose3DObservation;
#[cfg(feature = "recognize_text")]
use crate::recognize_text::{RecognitionLevel, RecognizedText};
#[cfg(feature = "segmentation")]
use crate::segmentation::{SegmentationMask, SegmentationQuality};
use crate::trajectories::Trajectory;
/// State behind the callback-driven futures in this module.
///
/// A future is either already resolved (used when the request is rejected
/// before any native call, e.g. an unconvertible path), or waiting on a native
/// completion.
enum FutureState<T> {
    /// Outcome known up front; taken out (leaving `None`) on the first poll.
    Ready(Option<Result<T, VisionError>>),
    /// Waiting for the native completion callback to deliver a value.
    Pending(AsyncCompletionFuture<T>),
}
impl<T> FutureState<T> {
    /// A future that resolves immediately to `Err(error)`.
    const fn ready_err(error: VisionError) -> Self {
        Self::Ready(Some(Err(error)))
    }
    /// A future that resolves when `future` completes.
    const fn pending(future: AsyncCompletionFuture<T>) -> Self {
        Self::Pending(future)
    }
}
impl<T: Unpin> Future for FutureState<T> {
    type Output = Result<T, VisionError>;
    /// Polls the pending completion, mapping its failure message to
    /// [`VisionError::RequestFailed`]; an up-front outcome is returned once.
    ///
    /// # Panics
    ///
    /// Panics if an up-front outcome is polled again after being returned.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.get_mut() {
            Self::Pending(completion) => match Pin::new(completion).poll(cx) {
                Poll::Ready(outcome) => Poll::Ready(outcome.map_err(VisionError::RequestFailed)),
                Poll::Pending => Poll::Pending,
            },
            Self::Ready(slot) => match slot.take() {
                Some(outcome) => Poll::Ready(outcome),
                None => panic!("async Vision future polled after completion"),
            },
        }
    }
}
/// Converts a filesystem path into a NUL-terminated C string for the native bridge.
///
/// # Errors
///
/// Returns [`VisionError::InvalidArgument`] when the path is not valid UTF-8 or
/// contains an interior NUL byte.
fn path_to_cstring(path: impl AsRef<Path>) -> Result<CString, VisionError> {
    let Some(utf8) = path.as_ref().to_str() else {
        return Err(VisionError::InvalidArgument("non-UTF-8 path".into()));
    };
    CString::new(utf8).map_err(|nul_error| {
        VisionError::InvalidArgument(format!("path NUL byte: {nul_error}"))
    })
}
/// Future produced by `run_sync_on_worker`.
///
/// Resolves to the worker's own `Result`. If the completion itself reports a
/// failure message, that message becomes a [`VisionError::Unknown`] carrying
/// `ffi::status::UNKNOWN`.
struct WorkerFuture<T> {
    inner: AsyncCompletionFuture<Result<T, VisionError>>,
}
impl<T> std::fmt::Debug for WorkerFuture<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `T` is not required to be `Debug`, so only the type name is printed.
        f.debug_struct("WorkerFuture").finish_non_exhaustive()
    }
}
impl<T> Future for WorkerFuture<T> {
    type Output = Result<T, VisionError>;
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let completion = Pin::new(&mut self.get_mut().inner);
        let Poll::Ready(delivered) = completion.poll(cx) else {
            return Poll::Pending;
        };
        Poll::Ready(match delivered {
            Ok(work_result) => work_result,
            Err(message) => Err(VisionError::Unknown {
                code: ffi::status::UNKNOWN,
                message,
            }),
        })
    }
}
fn run_sync_on_worker<T, F>(work: F) -> WorkerFuture<T>
where
T: Send + 'static,
F: FnOnce() -> Result<T, VisionError> + Send + 'static,
{
let (future, ctx) = AsyncCompletion::<Result<T, VisionError>>::create();
let ctx = ctx as usize;
std::thread::spawn(move || unsafe {
AsyncCompletion::complete_ok(ctx as *mut c_void, work());
});
WorkerFuture { inner: future }
}
/// Future yielding the classifications produced by
/// [`AsyncCoreMLRequest::classify_in_path`].
#[cfg(feature = "coreml")]
pub struct CoreMLClassifyFuture {
    inner: WorkerFuture<Vec<Classification>>,
}
#[cfg(feature = "coreml")]
impl std::fmt::Debug for CoreMLClassifyFuture {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CoreMLClassifyFuture").finish_non_exhaustive()
    }
}
#[cfg(feature = "coreml")]
impl Future for CoreMLClassifyFuture {
    type Output = Result<Vec<Classification>, VisionError>;
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Pure delegation to the wrapped worker future.
        Pin::new(&mut self.get_mut().inner).poll(cx)
    }
}
/// Future yielding the optional feature-value observation produced by
/// [`AsyncCoreMLRequest::feature_value_in_path`].
#[cfg(feature = "coreml")]
pub struct CoreMLFeatureValueFuture {
    inner: WorkerFuture<Option<CoreMLFeatureValueObservation>>,
}
#[cfg(feature = "coreml")]
impl std::fmt::Debug for CoreMLFeatureValueFuture {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CoreMLFeatureValueFuture").finish_non_exhaustive()
    }
}
#[cfg(feature = "coreml")]
impl Future for CoreMLFeatureValueFuture {
    type Output = Result<Option<CoreMLFeatureValueObservation>, VisionError>;
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Pure delegation to the wrapped worker future.
        Pin::new(&mut self.get_mut().inner).poll(cx)
    }
}
/// Async adapter for a [`CoreMLRequest`].
///
/// Each call clones the wrapped request and runs it on its own worker thread
/// (via `run_sync_on_worker`), so the calling task is not blocked.
#[cfg(feature = "coreml")]
#[derive(Debug, Clone)]
pub struct AsyncCoreMLRequest {
    request: CoreMLRequest,
}
#[cfg(feature = "coreml")]
impl AsyncCoreMLRequest {
    /// Wraps `request` for asynchronous use.
    #[must_use]
    pub const fn new(request: CoreMLRequest) -> Self {
        Self { request }
    }

    /// Runs `CoreMLRequest::classify` on the file at `path` in a worker thread.
    #[must_use]
    pub fn classify_in_path(&self, path: impl AsRef<Path>) -> CoreMLClassifyFuture {
        let worker_request = self.request.clone();
        let owned_path = path.as_ref().to_owned();
        let inner = run_sync_on_worker(move || worker_request.classify(owned_path.as_path()));
        CoreMLClassifyFuture { inner }
    }

    /// Runs `CoreMLRequest::feature_value` on the file at `path` in a worker thread.
    #[must_use]
    pub fn feature_value_in_path(&self, path: impl AsRef<Path>) -> CoreMLFeatureValueFuture {
        let worker_request = self.request.clone();
        let owned_path = path.as_ref().to_owned();
        let inner =
            run_sync_on_worker(move || worker_request.feature_value(owned_path.as_path()));
        CoreMLFeatureValueFuture { inner }
    }
}
/// Future yielding the observations produced by
/// [`AsyncDetectHumanBodyPose3D::detect_in_path`].
pub struct DetectHumanBodyPose3DFuture {
    inner: WorkerFuture<Vec<HumanBodyPose3DObservation>>,
}
impl std::fmt::Debug for DetectHumanBodyPose3DFuture {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DetectHumanBodyPose3DFuture").finish_non_exhaustive()
    }
}
impl Future for DetectHumanBodyPose3DFuture {
    type Output = Result<Vec<HumanBodyPose3DObservation>, VisionError>;
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Pure delegation to the wrapped worker future.
        Pin::new(&mut self.get_mut().inner).poll(cx)
    }
}
/// Async front-end for
/// `crate::human_body_pose_3d::detect_human_body_pose_3d_observations`,
/// which is run on a dedicated worker thread.
#[derive(Debug, Clone, Copy, Default)]
pub struct AsyncDetectHumanBodyPose3D;
impl AsyncDetectHumanBodyPose3D {
    /// Creates the (stateless) detector.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }

    /// Runs 3D human body pose detection on the file at `path` in a worker thread.
    #[must_use]
    pub fn detect_in_path(&self, path: impl AsRef<Path>) -> DetectHumanBodyPose3DFuture {
        let owned_path = path.as_ref().to_owned();
        let inner = run_sync_on_worker(move || {
            crate::human_body_pose_3d::detect_human_body_pose_3d_observations(
                owned_path.as_path(),
            )
        });
        DetectHumanBodyPose3DFuture { inner }
    }
}
/// Future yielding the trajectories produced by
/// [`AsyncDetectTrajectories::detect_in_path`].
pub struct DetectTrajectoriesFuture {
    inner: WorkerFuture<Vec<Trajectory>>,
}
impl std::fmt::Debug for DetectTrajectoriesFuture {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DetectTrajectoriesFuture").finish_non_exhaustive()
    }
}
impl Future for DetectTrajectoriesFuture {
    type Output = Result<Vec<Trajectory>, VisionError>;
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Pure delegation to the wrapped worker future.
        Pin::new(&mut self.get_mut().inner).poll(cx)
    }
}
/// Async front-end for `crate::trajectories::detect_trajectories`, run on a
/// dedicated worker thread with the configured trajectory length.
#[derive(Debug, Clone)]
pub struct AsyncDetectTrajectories {
    trajectory_length: usize,
}
impl AsyncDetectTrajectories {
    /// Creates a detector that passes `trajectory_length` to every request.
    #[must_use]
    pub const fn new(trajectory_length: usize) -> Self {
        Self { trajectory_length }
    }

    /// Runs trajectory detection on the file at `path` in a worker thread.
    #[must_use]
    pub fn detect_in_path(&self, path: impl AsRef<Path>) -> DetectTrajectoriesFuture {
        let owned_path = PathBuf::from(path.as_ref());
        let Self { trajectory_length } = *self;
        let inner = run_sync_on_worker(move || {
            crate::trajectories::detect_trajectories(owned_path.as_path(), trajectory_length)
        });
        DetectTrajectoriesFuture { inner }
    }
}
/// Converts the payload delivered to `text_result_cb` into owned
/// [`RecognizedText`] values and frees the native allocations it read.
///
/// A null `array` or a zero `count` yields an empty vector; in that case only the
/// outer result is freed.
///
/// # Safety
///
/// `error`, if non-null, must be a valid NUL-terminated C string. `result`, if
/// non-null, must point to a live `ffi::AsyncArrayResultRaw` whose `array` holds
/// `count` `ffi::RecognizedTextRaw` entries. Both are freed here and must not be
/// used afterwards.
#[cfg(feature = "recognize_text")]
unsafe fn parse_text_result(
    result: *const c_void,
    error: *const i8,
) -> Result<Vec<RecognizedText>, String> {
    if !error.is_null() {
        return Err(unsafe { error_from_cstr(error) });
    }
    if result.is_null() {
        return Err("text recognition returned null".into());
    }
    let header = unsafe { &*(result.cast::<ffi::AsyncArrayResultRaw>()) };
    let texts: Vec<RecognizedText> = if !header.array.is_null() && header.count > 0 {
        let entries = header.array.cast::<ffi::RecognizedTextRaw>();
        let converted = (0..header.count)
            .map(|index| {
                let entry = unsafe { &*entries.add(index) };
                // A null text pointer becomes an empty string.
                let text = (!entry.text.is_null())
                    .then(|| {
                        unsafe { std::ffi::CStr::from_ptr(entry.text) }
                            .to_string_lossy()
                            .into_owned()
                    })
                    .unwrap_or_default();
                RecognizedText {
                    text,
                    confidence: entry.confidence,
                    bounding_box: crate::recognize_text::BoundingBox {
                        x: entry.bbox_x,
                        y: entry.bbox_y,
                        width: entry.bbox_w,
                        height: entry.bbox_h,
                    },
                }
            })
            .collect();
        // Every entry has been copied out, so the native array can go.
        unsafe { ffi::vn_recognized_text_free(header.array, header.count) };
        converted
    } else {
        Vec::new()
    };
    unsafe { ffi::vn_async_array_result_free(result.cast_mut()) };
    Ok(texts)
}
/// Native completion callback for text recognition: parses the payload and
/// completes the `AsyncCompletion` identified by `ctx`.
///
/// Panics are caught so they never unwind across the FFI boundary. A caught panic
/// is logged and reported to the awaiting future as an error.
#[cfg(feature = "recognize_text")]
extern "C" fn text_result_cb(result: *const c_void, error: *const i8, ctx: *mut c_void) {
    let parsed: Result<Vec<RecognizedText>, String> =
        std::panic::catch_unwind(AssertUnwindSafe(|| unsafe { parse_text_result(result, error) }))
            .unwrap_or_else(|payload| {
                log_callback_panic("text_result_cb", payload.as_ref());
                Err("panic in Vision text_result_cb".into())
            });
    match parsed {
        Ok(texts) => unsafe {
            AsyncCompletion::complete_ok(ctx, texts);
        },
        Err(message) => unsafe {
            AsyncCompletion::<Vec<RecognizedText>>::complete_err(ctx, message);
        },
    }
}
/// Future yielding the text produced by
/// [`AsyncRecognizeText::recognize_in_path`].
#[cfg(feature = "recognize_text")]
pub struct RecognizeTextFuture {
    inner: FutureState<Vec<RecognizedText>>,
}
#[cfg(feature = "recognize_text")]
impl std::fmt::Debug for RecognizeTextFuture {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RecognizeTextFuture").finish_non_exhaustive()
    }
}
#[cfg(feature = "recognize_text")]
impl Future for RecognizeTextFuture {
    type Output = Result<Vec<RecognizedText>, VisionError>;
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Pure delegation to the wrapped state machine.
        Pin::new(&mut self.get_mut().inner).poll(cx)
    }
}
/// Asynchronous Vision text recognition for files on disk.
///
/// The recognition level and language-correction flag are fixed at
/// construction and forwarded to the native bridge on every call.
#[cfg(feature = "recognize_text")]
#[derive(Debug, Clone)]
pub struct AsyncRecognizeText {
    recognition_level: RecognitionLevel,
    uses_language_correction: bool,
}
#[cfg(feature = "recognize_text")]
impl Default for AsyncRecognizeText {
    /// [`RecognitionLevel::Accurate`] with language correction enabled.
    fn default() -> Self {
        Self::new(RecognitionLevel::Accurate, true)
    }
}
#[cfg(feature = "recognize_text")]
impl AsyncRecognizeText {
    /// Creates a recognizer with the given level and language-correction setting.
    #[must_use]
    pub const fn new(recognition_level: RecognitionLevel, uses_language_correction: bool) -> Self {
        Self {
            recognition_level,
            uses_language_correction,
        }
    }

    /// Starts text recognition on the file at `path` and returns a future for
    /// the result.
    ///
    /// The native request is issued before this method returns; the future
    /// only waits for its completion callback. If `path` is not valid UTF-8 or
    /// contains a NUL byte, no native call is made and the future resolves to
    /// [`VisionError::InvalidArgument`].
    #[must_use]
    pub fn recognize_in_path(&self, path: impl AsRef<Path>) -> RecognizeTextFuture {
        match path_to_cstring(path) {
            Err(error) => RecognizeTextFuture {
                inner: FutureState::ready_err(error),
            },
            Ok(path_c) => {
                let (future, ctx) = AsyncCompletion::create();
                // SAFETY: `path_c` is a valid NUL-terminated string alive for the
                // duration of the call, and `ctx` comes from `AsyncCompletion::create`
                // for the matching result type consumed by `text_result_cb`.
                // NOTE(review): assumes the bridge copies the path before returning,
                // since `path_c` is dropped right after this call — confirm.
                unsafe {
                    ffi::vn_recognize_text_in_path_async(
                        path_c.as_ptr(),
                        self.recognition_level.as_raw(),
                        self.uses_language_correction,
                        text_result_cb,
                        ctx,
                    );
                };
                RecognizeTextFuture {
                    inner: FutureState::pending(future),
                }
            }
        }
    }
}
/// Converts the payload delivered to `face_result_cb` into owned
/// [`DetectedFace`] values and frees the native allocations it read.
///
/// A NaN roll/yaw/pitch is reported as `None`. A null `array` or a zero `count`
/// yields an empty vector; in that case only the outer result is freed.
///
/// # Safety
///
/// `error`, if non-null, must be a valid NUL-terminated C string. `result`, if
/// non-null, must point to a live `ffi::AsyncArrayResultRaw` whose `array` holds
/// `count` `ffi::DetectedFaceRaw` entries. Both are freed here and must not be
/// used afterwards.
#[cfg(feature = "detect_faces")]
unsafe fn parse_face_result(
    result: *const c_void,
    error: *const i8,
) -> Result<Vec<DetectedFace>, String> {
    if !error.is_null() {
        return Err(unsafe { error_from_cstr(error) });
    }
    if result.is_null() {
        return Err("face detection returned null".into());
    }
    let header = unsafe { &*(result.cast::<ffi::AsyncArrayResultRaw>()) };
    let faces: Vec<DetectedFace> = if !header.array.is_null() && header.count > 0 {
        let entries = header.array.cast::<ffi::DetectedFaceRaw>();
        let angle = |value: f32| (!value.is_nan()).then_some(value);
        let converted = (0..header.count)
            .map(|index| {
                let entry = unsafe { &*entries.add(index) };
                DetectedFace {
                    bounding_box: crate::recognize_text::BoundingBox {
                        x: entry.bbox_x,
                        y: entry.bbox_y,
                        width: entry.bbox_w,
                        height: entry.bbox_h,
                    },
                    confidence: entry.confidence,
                    roll: angle(entry.roll),
                    yaw: angle(entry.yaw),
                    pitch: angle(entry.pitch),
                }
            })
            .collect();
        // Every entry has been copied out, so the native array can go.
        unsafe { ffi::vn_detected_faces_free(header.array, header.count) };
        converted
    } else {
        Vec::new()
    };
    unsafe { ffi::vn_async_array_result_free(result.cast_mut()) };
    Ok(faces)
}
/// Native completion callback for face detection: parses the payload and
/// completes the `AsyncCompletion` identified by `ctx`.
///
/// Panics are caught so they never unwind across the FFI boundary. A caught panic
/// is logged and reported to the awaiting future as an error.
#[cfg(feature = "detect_faces")]
extern "C" fn face_result_cb(result: *const c_void, error: *const i8, ctx: *mut c_void) {
    let parsed: Result<Vec<DetectedFace>, String> =
        std::panic::catch_unwind(AssertUnwindSafe(|| unsafe { parse_face_result(result, error) }))
            .unwrap_or_else(|payload| {
                log_callback_panic("face_result_cb", payload.as_ref());
                Err("panic in Vision face_result_cb".into())
            });
    match parsed {
        Ok(faces) => unsafe {
            AsyncCompletion::complete_ok(ctx, faces);
        },
        Err(message) => unsafe {
            AsyncCompletion::<Vec<DetectedFace>>::complete_err(ctx, message);
        },
    }
}
/// Future yielding the faces produced by [`AsyncDetectFaces::detect_in_path`].
#[cfg(feature = "detect_faces")]
pub struct DetectFacesFuture {
    inner: FutureState<Vec<DetectedFace>>,
}
#[cfg(feature = "detect_faces")]
impl std::fmt::Debug for DetectFacesFuture {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DetectFacesFuture").finish_non_exhaustive()
    }
}
#[cfg(feature = "detect_faces")]
impl Future for DetectFacesFuture {
    type Output = Result<Vec<DetectedFace>, VisionError>;
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Pure delegation to the wrapped state machine.
        Pin::new(&mut self.get_mut().inner).poll(cx)
    }
}
/// Asynchronous Vision face detection for files on disk.
#[cfg(feature = "detect_faces")]
#[derive(Debug, Clone, Copy, Default)]
pub struct AsyncDetectFaces;
#[cfg(feature = "detect_faces")]
impl AsyncDetectFaces {
    /// Creates the (stateless) detector.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }

    /// Starts face detection on the file at `path` and returns a future for
    /// the result.
    ///
    /// The native request is issued before this method returns; the future
    /// only waits for its completion callback. If `path` is not valid UTF-8 or
    /// contains a NUL byte, no native call is made and the future resolves to
    /// [`VisionError::InvalidArgument`].
    #[must_use]
    pub fn detect_in_path(&self, path: impl AsRef<Path>) -> DetectFacesFuture {
        match path_to_cstring(path) {
            Err(error) => DetectFacesFuture {
                inner: FutureState::ready_err(error),
            },
            Ok(path_c) => {
                let (future, ctx) = AsyncCompletion::create();
                // SAFETY: `path_c` is a valid NUL-terminated string alive for the
                // duration of the call, and `ctx` comes from `AsyncCompletion::create`
                // for the matching result type consumed by `face_result_cb`.
                // NOTE(review): assumes the bridge copies the path before returning,
                // since `path_c` is dropped right after this call — confirm.
                unsafe {
                    ffi::vn_detect_faces_in_path_async(path_c.as_ptr(), face_result_cb, ctx);
                };
                DetectFacesFuture {
                    inner: FutureState::pending(future),
                }
            }
        }
    }
}
/// Converts the payload delivered to `barcode_result_cb` into owned
/// [`DetectedBarcode`] values and frees the native allocations it read.
///
/// Null payload/symbology pointers become empty strings. A null `array` or a
/// zero `count` yields an empty vector; in that case only the outer result is
/// freed.
///
/// # Safety
///
/// `error`, if non-null, must be a valid NUL-terminated C string. `result`, if
/// non-null, must point to a live `ffi::AsyncArrayResultRaw` whose `array` holds
/// `count` `ffi::DetectedBarcodeRaw` entries. Both are freed here and must not be
/// used afterwards.
#[cfg(feature = "detect_barcodes")]
unsafe fn parse_barcode_result(
    result: *const c_void,
    error: *const i8,
) -> Result<Vec<DetectedBarcode>, String> {
    if !error.is_null() {
        return Err(unsafe { error_from_cstr(error) });
    }
    if result.is_null() {
        return Err("barcode detection returned null".into());
    }
    let header = unsafe { &*(result.cast::<ffi::AsyncArrayResultRaw>()) };
    let barcodes: Vec<DetectedBarcode> = if !header.array.is_null() && header.count > 0 {
        let entries = header.array.cast::<ffi::DetectedBarcodeRaw>();
        let converted = (0..header.count)
            .map(|index| {
                let entry = unsafe { &*entries.add(index) };
                let payload = (!entry.payload.is_null())
                    .then(|| {
                        unsafe { std::ffi::CStr::from_ptr(entry.payload) }
                            .to_string_lossy()
                            .into_owned()
                    })
                    .unwrap_or_default();
                let symbology = (!entry.symbology.is_null())
                    .then(|| {
                        unsafe { std::ffi::CStr::from_ptr(entry.symbology) }
                            .to_string_lossy()
                            .into_owned()
                    })
                    .unwrap_or_default();
                DetectedBarcode {
                    payload,
                    symbology,
                    confidence: entry.confidence,
                    bounding_box: crate::recognize_text::BoundingBox {
                        x: entry.bbox_x,
                        y: entry.bbox_y,
                        width: entry.bbox_w,
                        height: entry.bbox_h,
                    },
                }
            })
            .collect();
        // Every entry has been copied out, so the native array can go.
        unsafe { ffi::vn_detected_barcodes_free(header.array, header.count) };
        converted
    } else {
        Vec::new()
    };
    unsafe { ffi::vn_async_array_result_free(result.cast_mut()) };
    Ok(barcodes)
}
/// Native completion callback for barcode detection: parses the payload and
/// completes the `AsyncCompletion` identified by `ctx`.
///
/// Panics are caught so they never unwind across the FFI boundary. A caught panic
/// is logged and reported to the awaiting future as an error.
#[cfg(feature = "detect_barcodes")]
extern "C" fn barcode_result_cb(result: *const c_void, error: *const i8, ctx: *mut c_void) {
    let parsed: Result<Vec<DetectedBarcode>, String> =
        std::panic::catch_unwind(AssertUnwindSafe(|| unsafe {
            parse_barcode_result(result, error)
        }))
        .unwrap_or_else(|payload| {
            log_callback_panic("barcode_result_cb", payload.as_ref());
            Err("panic in Vision barcode_result_cb".into())
        });
    match parsed {
        Ok(barcodes) => unsafe {
            AsyncCompletion::complete_ok(ctx, barcodes);
        },
        Err(message) => unsafe {
            AsyncCompletion::<Vec<DetectedBarcode>>::complete_err(ctx, message);
        },
    }
}
/// Future yielding the barcodes produced by
/// [`AsyncDetectBarcodes::detect_in_path`].
#[cfg(feature = "detect_barcodes")]
pub struct DetectBarcodesFuture {
    inner: FutureState<Vec<DetectedBarcode>>,
}
#[cfg(feature = "detect_barcodes")]
impl std::fmt::Debug for DetectBarcodesFuture {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DetectBarcodesFuture").finish_non_exhaustive()
    }
}
#[cfg(feature = "detect_barcodes")]
impl Future for DetectBarcodesFuture {
    type Output = Result<Vec<DetectedBarcode>, VisionError>;
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Pure delegation to the wrapped state machine.
        Pin::new(&mut self.get_mut().inner).poll(cx)
    }
}
/// Asynchronous Vision barcode detection for files on disk.
#[cfg(feature = "detect_barcodes")]
#[derive(Debug, Clone, Copy, Default)]
pub struct AsyncDetectBarcodes;
#[cfg(feature = "detect_barcodes")]
impl AsyncDetectBarcodes {
    /// Creates the (stateless) detector.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }

    /// Starts barcode detection on the file at `path` and returns a future for
    /// the result.
    ///
    /// The native request is issued before this method returns; the future
    /// only waits for its completion callback. If `path` is not valid UTF-8 or
    /// contains a NUL byte, no native call is made and the future resolves to
    /// [`VisionError::InvalidArgument`].
    #[must_use]
    pub fn detect_in_path(&self, path: impl AsRef<Path>) -> DetectBarcodesFuture {
        match path_to_cstring(path) {
            Err(error) => DetectBarcodesFuture {
                inner: FutureState::ready_err(error),
            },
            Ok(path_c) => {
                let (future, ctx) = AsyncCompletion::create();
                // SAFETY: `path_c` is a valid NUL-terminated string alive for the
                // duration of the call, and `ctx` comes from `AsyncCompletion::create`
                // for the matching result type consumed by `barcode_result_cb`.
                // NOTE(review): assumes the bridge copies the path before returning,
                // since `path_c` is dropped right after this call — confirm.
                unsafe {
                    ffi::vn_detect_barcodes_in_path_async(path_c.as_ptr(), barcode_result_cb, ctx);
                };
                DetectBarcodesFuture {
                    inner: FutureState::pending(future),
                }
            }
        }
    }
}
/// Converts the payload delivered to `seg_result_cb` into an owned
/// [`SegmentationMask`], copying `height * bytes_per_row` bytes out of the
/// native buffer.
///
/// Once `result` is known to be non-null, it is freed exactly once on every
/// path, success or failure.
///
/// # Errors
///
/// Returns a message when the bridge reported an error, returned a null
/// result, returned a null byte buffer, or reported dimensions whose byte size
/// overflows `usize` or exceeds `isize::MAX`.
///
/// # Safety
///
/// `error`, if non-null, must be a valid NUL-terminated C string. `result`, if
/// non-null, must point to a live `ffi::AsyncSegResultRaw` whose `bytes` buffer
/// (when non-null) holds at least `height * bytes_per_row` readable bytes.
/// `result` is freed here and must not be used afterwards.
#[cfg(feature = "segmentation")]
unsafe fn parse_seg_result(
    result: *const c_void,
    error: *const i8,
) -> Result<SegmentationMask, String> {
    if !error.is_null() {
        return Err(unsafe { error_from_cstr(error) });
    }
    if result.is_null() {
        return Err("segmentation returned null".into());
    }
    let raw = unsafe { &*(result.cast::<ffi::AsyncSegResultRaw>()) };
    let mask: Result<SegmentationMask, String> = if raw.bytes.is_null() {
        Err("segmentation bytes were null".into())
    } else {
        // `slice::from_raw_parts` requires the length to be in bounds and at
        // most `isize::MAX` bytes. Saturating on overflow would have built an
        // out-of-bounds slice (undefined behavior), so reject it instead.
        match raw
            .height
            .checked_mul(raw.bytes_per_row)
            .filter(|&len| len <= isize::MAX.unsigned_abs())
        {
            Some(len) => {
                let bytes = unsafe { core::slice::from_raw_parts(raw.bytes, len) }.to_vec();
                Ok(SegmentationMask {
                    width: raw.width,
                    height: raw.height,
                    bytes_per_row: raw.bytes_per_row,
                    bytes,
                })
            }
            None => Err(format!(
                "segmentation buffer size overflows: {} rows x {} bytes per row",
                raw.height, raw.bytes_per_row
            )),
        }
    };
    // `raw` is not touched after this point; release the native result once.
    unsafe { ffi::vn_async_seg_result_free(result.cast_mut()) };
    mask
}
/// Native completion callback for person segmentation: parses the payload and
/// completes the `AsyncCompletion` identified by `ctx`.
///
/// Panics are caught so they never unwind across the FFI boundary. A caught panic
/// is logged and reported to the awaiting future as an error.
#[cfg(feature = "segmentation")]
extern "C" fn seg_result_cb(result: *const c_void, error: *const i8, ctx: *mut c_void) {
    let parsed: Result<SegmentationMask, String> =
        std::panic::catch_unwind(AssertUnwindSafe(|| unsafe { parse_seg_result(result, error) }))
            .unwrap_or_else(|payload| {
                log_callback_panic("seg_result_cb", payload.as_ref());
                Err("panic in Vision seg_result_cb".into())
            });
    match parsed {
        Ok(mask) => unsafe {
            AsyncCompletion::complete_ok(ctx, mask);
        },
        Err(message) => unsafe {
            AsyncCompletion::<SegmentationMask>::complete_err(ctx, message);
        },
    }
}
/// Future yielding the mask produced by
/// [`AsyncPersonSegmentation::generate_in_path`].
#[cfg(feature = "segmentation")]
pub struct PersonSegmentationFuture {
    inner: FutureState<SegmentationMask>,
}
#[cfg(feature = "segmentation")]
impl std::fmt::Debug for PersonSegmentationFuture {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PersonSegmentationFuture").finish_non_exhaustive()
    }
}
#[cfg(feature = "segmentation")]
impl Future for PersonSegmentationFuture {
    type Output = Result<SegmentationMask, VisionError>;
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Pure delegation to the wrapped state machine.
        Pin::new(&mut self.get_mut().inner).poll(cx)
    }
}
/// Asynchronous Vision person segmentation for files on disk.
///
/// The segmentation quality is fixed at construction and forwarded to the
/// native bridge on every call.
#[cfg(feature = "segmentation")]
#[derive(Debug, Clone, Copy)]
pub struct AsyncPersonSegmentation {
    quality: SegmentationQuality,
}
#[cfg(feature = "segmentation")]
impl Default for AsyncPersonSegmentation {
    /// Uses [`SegmentationQuality::Balanced`].
    fn default() -> Self {
        Self::new(SegmentationQuality::Balanced)
    }
}
#[cfg(feature = "segmentation")]
impl AsyncPersonSegmentation {
    /// Creates a segmenter that requests the given `quality`.
    #[must_use]
    pub const fn new(quality: SegmentationQuality) -> Self {
        Self { quality }
    }

    /// Starts person segmentation on the file at `path` and returns a future
    /// for the resulting mask.
    ///
    /// The native request is issued before this method returns; the future
    /// only waits for its completion callback. If `path` is not valid UTF-8 or
    /// contains a NUL byte, no native call is made and the future resolves to
    /// [`VisionError::InvalidArgument`].
    #[must_use]
    pub fn generate_in_path(&self, path: impl AsRef<Path>) -> PersonSegmentationFuture {
        match path_to_cstring(path) {
            Err(error) => PersonSegmentationFuture {
                inner: FutureState::ready_err(error),
            },
            Ok(path_c) => {
                let (future, ctx) = AsyncCompletion::create();
                // SAFETY: `path_c` is a valid NUL-terminated string alive for the
                // duration of the call, and `ctx` comes from `AsyncCompletion::create`
                // for the matching result type consumed by `seg_result_cb`.
                // NOTE(review): assumes the bridge copies the path before returning,
                // since `path_c` is dropped right after this call — confirm.
                unsafe {
                    ffi::vn_generate_person_segmentation_async(
                        path_c.as_ptr(),
                        self.quality as i32,
                        seg_result_cb,
                        ctx,
                    );
                };
                PersonSegmentationFuture {
                    inner: FutureState::pending(future),
                }
            }
        }
    }
}