1use std::{
32 ffi::{c_void, CString},
33 future::Future,
34 panic::AssertUnwindSafe,
35 path::Path,
36 pin::Pin,
37 task::{Context, Poll},
38};
39
40use doom_fish_utils::completion::{error_from_cstr, AsyncCompletion, AsyncCompletionFuture};
41use doom_fish_utils::panic_safe::log_callback_panic;
42
43use crate::{error::VisionError, ffi};
44
45#[cfg(feature = "detect_barcodes")]
46use crate::detect_barcodes::DetectedBarcode;
47#[cfg(feature = "detect_faces")]
48use crate::detect_faces::DetectedFace;
49#[cfg(feature = "recognize_text")]
50use crate::recognize_text::{RecognitionLevel, RecognizedText};
51#[cfg(feature = "segmentation")]
52use crate::segmentation::{SegmentationMask, SegmentationQuality};
53
54enum FutureState<T> {
55 Ready(Option<Result<T, VisionError>>),
56 Pending(AsyncCompletionFuture<T>),
57}
58
59impl<T> FutureState<T> {
60 const fn ready_err(error: VisionError) -> Self {
61 Self::Ready(Some(Err(error)))
62 }
63
64 const fn pending(future: AsyncCompletionFuture<T>) -> Self {
65 Self::Pending(future)
66 }
67}
68
69impl<T: Unpin> Future for FutureState<T> {
70 type Output = Result<T, VisionError>;
71
72 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
73 match self.as_mut().get_mut() {
74 Self::Ready(result) => Poll::Ready(
75 result
76 .take()
77 .expect("async Vision future polled after completion"),
78 ),
79 Self::Pending(future) => Pin::new(future)
80 .poll(cx)
81 .map(|result| result.map_err(VisionError::RequestFailed)),
82 }
83 }
84}
85
86fn path_to_cstring(path: impl AsRef<Path>) -> Result<CString, VisionError> {
87 let path_str = path
88 .as_ref()
89 .to_str()
90 .ok_or_else(|| VisionError::InvalidArgument("non-UTF-8 path".into()))?;
91 CString::new(path_str)
92 .map_err(|error| VisionError::InvalidArgument(format!("path NUL byte: {error}")))
93}
94
#[cfg(feature = "recognize_text")]
/// Converts the payload delivered to [`text_result_cb`] into owned [`RecognizedText`] values.
///
/// Null text pointers become empty strings; invalid UTF-8 is replaced lossily.
///
/// # Errors
///
/// Returns the native error message when `error` is non-null, or a fixed message when
/// `result` is null.
///
/// # Safety
///
/// `error` must be null or a valid NUL-terminated C string. `result` must be null or point
/// to a live `ffi::AsyncArrayResultRaw` whose `array`, when non-null, holds `count`
/// contiguous, initialized `ffi::RecognizedTextRaw` entries whose non-null `text` pointers
/// are valid C strings. On the success path this function frees the array and `result`;
/// neither may be used afterwards.
unsafe fn parse_text_result(
    result: *const c_void,
    error: *const i8,
) -> Result<Vec<RecognizedText>, String> {
    if !error.is_null() {
        // SAFETY: a non-null `error` is a valid C string per this function's contract.
        return Err(unsafe { error_from_cstr(error) });
    }
    if result.is_null() {
        return Err("text recognition returned null".into());
    }

    // SAFETY: `result` is non-null and points to an `AsyncArrayResultRaw` (contract above).
    let raw = unsafe { &*(result.cast::<ffi::AsyncArrayResultRaw>()) };

    let texts: Vec<RecognizedText> = if raw.array.is_null() || raw.count == 0 {
        Vec::new()
    } else {
        // SAFETY: `array` is non-null and holds `count` initialized entries (contract above).
        let entries = unsafe {
            std::slice::from_raw_parts(raw.array.cast::<ffi::RecognizedTextRaw>(), raw.count)
        };
        entries
            .iter()
            .map(|entry| {
                let text = if entry.text.is_null() {
                    String::new()
                } else {
                    // SAFETY: non-null entry strings are valid C strings (contract above).
                    unsafe { std::ffi::CStr::from_ptr(entry.text) }
                        .to_string_lossy()
                        .into_owned()
                };
                RecognizedText {
                    text,
                    confidence: entry.confidence,
                    bounding_box: crate::recognize_text::BoundingBox {
                        x: entry.bbox_x,
                        y: entry.bbox_y,
                        width: entry.bbox_w,
                        height: entry.bbox_h,
                    },
                }
            })
            .collect()
    };

    // Release the entry array whenever one was handed over, even when `count == 0`, so a
    // non-null empty allocation is not leaked. Every string was copied above.
    if !raw.array.is_null() {
        // SAFETY: `array`/`count` are exactly the pair supplied by the native side.
        unsafe { ffi::vn_recognized_text_free(raw.array, raw.count) };
    }
    // SAFETY: `result` came from the native side; this is the only place it is freed.
    unsafe { ffi::vn_async_array_result_free(result.cast_mut()) };
    Ok(texts)
}
163
#[cfg(feature = "recognize_text")]
/// C-ABI completion handler passed to `ffi::vn_recognize_text_in_path_async`.
///
/// Parses the native payload and completes the [`AsyncCompletion`] behind `ctx` with either
/// the recognized texts or an error message. A panic while parsing is caught, logged, and
/// reported as an error so it never unwinds across the FFI boundary.
extern "C" fn text_result_cb(result: *const c_void, error: *const i8, ctx: *mut c_void) {
    let parsed: Result<Vec<RecognizedText>, String> =
        std::panic::catch_unwind(AssertUnwindSafe(|| unsafe { parse_text_result(result, error) }))
            .unwrap_or_else(|payload| {
                log_callback_panic("text_result_cb", payload.as_ref());
                Err("panic in Vision text_result_cb".into())
            });

    // NOTE(review): assumes `ctx` is the pointer produced by `AsyncCompletion::create` for
    // this request; each branch below completes it exactly once.
    match parsed {
        Ok(texts) => {
            unsafe { AsyncCompletion::complete_ok(ctx, texts) };
        }
        Err(message) => {
            unsafe { AsyncCompletion::<Vec<RecognizedText>>::complete_err(ctx, message) };
        }
    }
}
202
#[cfg(feature = "recognize_text")]
/// Future returned by [`AsyncRecognizeText::recognize_in_path`].
///
/// Resolves to the recognized text, or to a [`VisionError`] if the path was rejected or the
/// native request reported an error.
///
/// Marked `#[must_use]` because `recognize_in_path` has already submitted the request (or
/// already captured an argument error) by the time this value is returned; dropping it
/// unobserved silently discards that outcome.
#[must_use = "the recognized text (or error) is only observable by awaiting this future"]
pub struct RecognizeTextFuture {
    inner: FutureState<Vec<RecognizedText>>,
}

#[cfg(feature = "recognize_text")]
impl std::fmt::Debug for RecognizeTextFuture {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `FutureState` does not implement `Debug`, so only the type name is shown.
        f.debug_struct("RecognizeTextFuture")
            .finish_non_exhaustive()
    }
}

#[cfg(feature = "recognize_text")]
impl Future for RecognizeTextFuture {
    type Output = Result<Vec<RecognizedText>, VisionError>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Delegate to the shared state machine; re-pinning the field relies on it being `Unpin`.
        Pin::new(&mut self.inner).poll(cx)
    }
}
225
#[cfg(feature = "recognize_text")]
/// Configuration for asynchronous text recognition.
///
/// Holds the recognition level and the language-correction flag. Each call to
/// [`AsyncRecognizeText::recognize_in_path`] submits its own native request.
#[derive(Debug, Clone)]
pub struct AsyncRecognizeText {
    recognition_level: RecognitionLevel,
    uses_language_correction: bool,
}

#[cfg(feature = "recognize_text")]
impl Default for AsyncRecognizeText {
    /// Accurate recognition with language correction enabled.
    fn default() -> Self {
        Self {
            recognition_level: RecognitionLevel::Accurate,
            uses_language_correction: true,
        }
    }
}

#[cfg(feature = "recognize_text")]
impl AsyncRecognizeText {
    /// Creates a configuration with the given recognition level and language-correction flag.
    #[must_use]
    pub const fn new(recognition_level: RecognitionLevel, uses_language_correction: bool) -> Self {
        Self {
            recognition_level,
            uses_language_correction,
        }
    }

    /// Starts text recognition on the image at `path` and returns a future for the result.
    ///
    /// The native request is submitted before this method returns. If `path` is not valid
    /// UTF-8 or contains a NUL byte, no request is made and the future resolves to
    /// `VisionError::InvalidArgument`.
    pub fn recognize_in_path(&self, path: impl AsRef<Path>) -> RecognizeTextFuture {
        let inner = match path_to_cstring(path) {
            Ok(c_path) => {
                let (completion, ctx) = AsyncCompletion::create();
                // SAFETY: `c_path` stays alive for the whole call, `text_result_cb` matches the
                // native callback signature, and `ctx` is handed to that callback.
                // NOTE(review): assumes the native side copies the path before returning.
                unsafe {
                    ffi::vn_recognize_text_in_path_async(
                        c_path.as_ptr(),
                        self.recognition_level.as_raw(),
                        self.uses_language_correction,
                        text_result_cb,
                        ctx,
                    );
                }
                FutureState::pending(completion)
            }
            Err(error) => FutureState::ready_err(error),
        };
        RecognizeTextFuture { inner }
    }
}
287
#[cfg(feature = "detect_faces")]
/// Converts the payload delivered to [`face_result_cb`] into owned [`DetectedFace`] values.
///
/// NaN roll/yaw/pitch values are mapped to `None`.
///
/// # Errors
///
/// Returns the native error message when `error` is non-null, or a fixed message when
/// `result` is null.
///
/// # Safety
///
/// `error` must be null or a valid NUL-terminated C string. `result` must be null or point
/// to a live `ffi::AsyncArrayResultRaw` whose `array`, when non-null, holds `count`
/// contiguous, initialized `ffi::DetectedFaceRaw` entries. On the success path this function
/// frees the array and `result`; neither may be used afterwards.
unsafe fn parse_face_result(
    result: *const c_void,
    error: *const i8,
) -> Result<Vec<DetectedFace>, String> {
    if !error.is_null() {
        // SAFETY: a non-null `error` is a valid C string per this function's contract.
        return Err(unsafe { error_from_cstr(error) });
    }
    if result.is_null() {
        return Err("face detection returned null".into());
    }

    // SAFETY: `result` is non-null and points to an `AsyncArrayResultRaw` (contract above).
    let raw = unsafe { &*(result.cast::<ffi::AsyncArrayResultRaw>()) };
    let nan_to_none = |value: f32| if value.is_nan() { None } else { Some(value) };

    let faces: Vec<DetectedFace> = if raw.array.is_null() || raw.count == 0 {
        Vec::new()
    } else {
        // SAFETY: `array` is non-null and holds `count` initialized entries (contract above).
        let entries = unsafe {
            std::slice::from_raw_parts(raw.array.cast::<ffi::DetectedFaceRaw>(), raw.count)
        };
        entries
            .iter()
            .map(|entry| DetectedFace {
                bounding_box: crate::recognize_text::BoundingBox {
                    x: entry.bbox_x,
                    y: entry.bbox_y,
                    width: entry.bbox_w,
                    height: entry.bbox_h,
                },
                confidence: entry.confidence,
                roll: nan_to_none(entry.roll),
                yaw: nan_to_none(entry.yaw),
                pitch: nan_to_none(entry.pitch),
            })
            .collect()
    };

    // Release the entry array whenever one was handed over, even when `count == 0`, so a
    // non-null empty allocation is not leaked.
    if !raw.array.is_null() {
        // SAFETY: `array`/`count` are exactly the pair supplied by the native side.
        unsafe { ffi::vn_detected_faces_free(raw.array, raw.count) };
    }
    // SAFETY: `result` came from the native side; this is the only place it is freed.
    unsafe { ffi::vn_async_array_result_free(result.cast_mut()) };
    Ok(faces)
}
346
#[cfg(feature = "detect_faces")]
/// C-ABI completion handler passed to `ffi::vn_detect_faces_in_path_async`.
///
/// Parses the native payload and completes the [`AsyncCompletion`] behind `ctx` with either
/// the detected faces or an error message. A panic while parsing is caught, logged, and
/// reported as an error so it never unwinds across the FFI boundary.
extern "C" fn face_result_cb(result: *const c_void, error: *const i8, ctx: *mut c_void) {
    let parsed: Result<Vec<DetectedFace>, String> =
        std::panic::catch_unwind(AssertUnwindSafe(|| unsafe { parse_face_result(result, error) }))
            .unwrap_or_else(|payload| {
                log_callback_panic("face_result_cb", payload.as_ref());
                Err("panic in Vision face_result_cb".into())
            });

    // NOTE(review): assumes `ctx` is the pointer produced by `AsyncCompletion::create` for
    // this request; each branch below completes it exactly once.
    match parsed {
        Ok(faces) => {
            unsafe { AsyncCompletion::complete_ok(ctx, faces) };
        }
        Err(message) => {
            unsafe { AsyncCompletion::<Vec<DetectedFace>>::complete_err(ctx, message) };
        }
    }
}
377
#[cfg(feature = "detect_faces")]
/// Future returned by [`AsyncDetectFaces::detect_in_path`].
///
/// Resolves to the detected faces, or to a [`VisionError`] if the path was rejected or the
/// native request reported an error.
///
/// Marked `#[must_use]` because `detect_in_path` has already submitted the request (or
/// already captured an argument error) by the time this value is returned; dropping it
/// unobserved silently discards that outcome.
#[must_use = "the detected faces (or error) are only observable by awaiting this future"]
pub struct DetectFacesFuture {
    inner: FutureState<Vec<DetectedFace>>,
}

#[cfg(feature = "detect_faces")]
impl std::fmt::Debug for DetectFacesFuture {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `FutureState` does not implement `Debug`, so only the type name is shown.
        f.debug_struct("DetectFacesFuture").finish_non_exhaustive()
    }
}

#[cfg(feature = "detect_faces")]
impl Future for DetectFacesFuture {
    type Output = Result<Vec<DetectedFace>, VisionError>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Delegate to the shared state machine; re-pinning the field relies on it being `Unpin`.
        Pin::new(&mut self.inner).poll(cx)
    }
}
399
#[cfg(feature = "detect_faces")]
/// Asynchronous face-detection request.
///
/// Carries no configuration. Each call to [`AsyncDetectFaces::detect_in_path`] submits its
/// own native request.
#[derive(Debug, Clone, Copy, Default)]
pub struct AsyncDetectFaces;

#[cfg(feature = "detect_faces")]
impl AsyncDetectFaces {
    /// Creates a face-detection request.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }

    /// Starts face detection on the image at `path` and returns a future for the result.
    ///
    /// The native request is submitted before this method returns. If `path` is not valid
    /// UTF-8 or contains a NUL byte, no request is made and the future resolves to
    /// `VisionError::InvalidArgument`.
    pub fn detect_in_path(&self, path: impl AsRef<Path>) -> DetectFacesFuture {
        let inner = match path_to_cstring(path) {
            Ok(c_path) => {
                let (completion, ctx) = AsyncCompletion::create();
                // SAFETY: `c_path` stays alive for the whole call, `face_result_cb` matches the
                // native callback signature, and `ctx` is handed to that callback.
                // NOTE(review): assumes the native side copies the path before returning.
                unsafe {
                    ffi::vn_detect_faces_in_path_async(c_path.as_ptr(), face_result_cb, ctx);
                }
                FutureState::pending(completion)
            }
            Err(error) => FutureState::ready_err(error),
        };
        DetectFacesFuture { inner }
    }
}
436
#[cfg(feature = "detect_barcodes")]
/// Converts the payload delivered to [`barcode_result_cb`] into owned [`DetectedBarcode`]
/// values.
///
/// Null payload/symbology pointers become empty strings; invalid UTF-8 is replaced lossily.
///
/// # Errors
///
/// Returns the native error message when `error` is non-null, or a fixed message when
/// `result` is null.
///
/// # Safety
///
/// `error` must be null or a valid NUL-terminated C string. `result` must be null or point
/// to a live `ffi::AsyncArrayResultRaw` whose `array`, when non-null, holds `count`
/// contiguous, initialized `ffi::DetectedBarcodeRaw` entries whose non-null string pointers
/// are valid C strings. On the success path this function frees the array and `result`;
/// neither may be used afterwards.
unsafe fn parse_barcode_result(
    result: *const c_void,
    error: *const i8,
) -> Result<Vec<DetectedBarcode>, String> {
    /// Copies a possibly-null C string into an owned `String` (null becomes `""`).
    ///
    /// # Safety
    ///
    /// `ptr` must be null or a valid NUL-terminated C string.
    unsafe fn owned_lossy(ptr: *const std::ffi::c_char) -> String {
        if ptr.is_null() {
            String::new()
        } else {
            // SAFETY: a non-null `ptr` is a valid C string per this helper's contract.
            unsafe { std::ffi::CStr::from_ptr(ptr) }
                .to_string_lossy()
                .into_owned()
        }
    }

    if !error.is_null() {
        // SAFETY: a non-null `error` is a valid C string per this function's contract.
        return Err(unsafe { error_from_cstr(error) });
    }
    if result.is_null() {
        return Err("barcode detection returned null".into());
    }

    // SAFETY: `result` is non-null and points to an `AsyncArrayResultRaw` (contract above).
    let raw = unsafe { &*(result.cast::<ffi::AsyncArrayResultRaw>()) };

    let barcodes: Vec<DetectedBarcode> = if raw.array.is_null() || raw.count == 0 {
        Vec::new()
    } else {
        // SAFETY: `array` is non-null and holds `count` initialized entries (contract above).
        let entries = unsafe {
            std::slice::from_raw_parts(raw.array.cast::<ffi::DetectedBarcodeRaw>(), raw.count)
        };
        entries
            .iter()
            .map(|entry| DetectedBarcode {
                // SAFETY: non-null entry strings are valid C strings (contract above).
                payload: unsafe { owned_lossy(entry.payload) },
                // SAFETY: as above.
                symbology: unsafe { owned_lossy(entry.symbology) },
                confidence: entry.confidence,
                bounding_box: crate::recognize_text::BoundingBox {
                    x: entry.bbox_x,
                    y: entry.bbox_y,
                    width: entry.bbox_w,
                    height: entry.bbox_h,
                },
            })
            .collect()
    };

    // Release the entry array whenever one was handed over, even when `count == 0`, so a
    // non-null empty allocation is not leaked. Every string was copied above.
    if !raw.array.is_null() {
        // SAFETY: `array`/`count` are exactly the pair supplied by the native side.
        unsafe { ffi::vn_detected_barcodes_free(raw.array, raw.count) };
    }
    // SAFETY: `result` came from the native side; this is the only place it is freed.
    unsafe { ffi::vn_async_array_result_free(result.cast_mut()) };
    Ok(barcodes)
}
508
#[cfg(feature = "detect_barcodes")]
/// C-ABI completion handler passed to `ffi::vn_detect_barcodes_in_path_async`.
///
/// Parses the native payload and completes the [`AsyncCompletion`] behind `ctx` with either
/// the detected barcodes or an error message. A panic while parsing is caught, logged, and
/// reported as an error so it never unwinds across the FFI boundary.
extern "C" fn barcode_result_cb(result: *const c_void, error: *const i8, ctx: *mut c_void) {
    let parsed: Result<Vec<DetectedBarcode>, String> =
        std::panic::catch_unwind(AssertUnwindSafe(|| unsafe {
            parse_barcode_result(result, error)
        }))
        .unwrap_or_else(|payload| {
            log_callback_panic("barcode_result_cb", payload.as_ref());
            Err("panic in Vision barcode_result_cb".into())
        });

    // NOTE(review): assumes `ctx` is the pointer produced by `AsyncCompletion::create` for
    // this request; each branch below completes it exactly once.
    match parsed {
        Ok(barcodes) => {
            unsafe { AsyncCompletion::complete_ok(ctx, barcodes) };
        }
        Err(message) => {
            unsafe { AsyncCompletion::<Vec<DetectedBarcode>>::complete_err(ctx, message) };
        }
    }
}
540
#[cfg(feature = "detect_barcodes")]
/// Future returned by [`AsyncDetectBarcodes::detect_in_path`].
///
/// Resolves to the detected barcodes, or to a [`VisionError`] if the path was rejected or
/// the native request reported an error.
///
/// Marked `#[must_use]` because `detect_in_path` has already submitted the request (or
/// already captured an argument error) by the time this value is returned; dropping it
/// unobserved silently discards that outcome.
#[must_use = "the detected barcodes (or error) are only observable by awaiting this future"]
pub struct DetectBarcodesFuture {
    inner: FutureState<Vec<DetectedBarcode>>,
}

#[cfg(feature = "detect_barcodes")]
impl std::fmt::Debug for DetectBarcodesFuture {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `FutureState` does not implement `Debug`, so only the type name is shown.
        f.debug_struct("DetectBarcodesFuture")
            .finish_non_exhaustive()
    }
}

#[cfg(feature = "detect_barcodes")]
impl Future for DetectBarcodesFuture {
    type Output = Result<Vec<DetectedBarcode>, VisionError>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Delegate to the shared state machine; re-pinning the field relies on it being `Unpin`.
        Pin::new(&mut self.inner).poll(cx)
    }
}
563
#[cfg(feature = "detect_barcodes")]
/// Asynchronous barcode-detection request.
///
/// Carries no configuration. Each call to [`AsyncDetectBarcodes::detect_in_path`] submits
/// its own native request.
#[derive(Debug, Clone, Copy, Default)]
pub struct AsyncDetectBarcodes;

#[cfg(feature = "detect_barcodes")]
impl AsyncDetectBarcodes {
    /// Creates a barcode-detection request.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }

    /// Starts barcode detection on the image at `path` and returns a future for the result.
    ///
    /// The native request is submitted before this method returns. If `path` is not valid
    /// UTF-8 or contains a NUL byte, no request is made and the future resolves to
    /// `VisionError::InvalidArgument`.
    pub fn detect_in_path(&self, path: impl AsRef<Path>) -> DetectBarcodesFuture {
        let inner = match path_to_cstring(path) {
            Ok(c_path) => {
                let (completion, ctx) = AsyncCompletion::create();
                // SAFETY: `c_path` stays alive for the whole call, `barcode_result_cb` matches
                // the native callback signature, and `ctx` is handed to that callback.
                // NOTE(review): assumes the native side copies the path before returning.
                unsafe {
                    ffi::vn_detect_barcodes_in_path_async(c_path.as_ptr(), barcode_result_cb, ctx);
                }
                FutureState::pending(completion)
            }
            Err(error) => FutureState::ready_err(error),
        };
        DetectBarcodesFuture { inner }
    }
}
600
#[cfg(feature = "segmentation")]
/// Converts the payload delivered to [`seg_result_cb`] into an owned [`SegmentationMask`].
///
/// The pixel buffer (`height * bytes_per_row` bytes) is copied so the native result can be
/// released before returning.
///
/// # Errors
///
/// Returns the native error message when `error` is non-null. Returns a descriptive
/// message when `result` is null, when its byte buffer is null, or when
/// `height * bytes_per_row` overflows or exceeds `isize::MAX`.
///
/// # Safety
///
/// `error` must be null or a valid NUL-terminated C string. `result` must be null or point
/// to a live `ffi::AsyncSegResultRaw` whose `bytes`, when non-null, is readable for
/// `height * bytes_per_row` bytes. Whenever `error` is null and `result` is non-null, this
/// function frees `result`; it must not be used afterwards.
unsafe fn parse_seg_result(
    result: *const c_void,
    error: *const i8,
) -> Result<SegmentationMask, String> {
    if !error.is_null() {
        // SAFETY: a non-null `error` is a valid C string per this function's contract.
        return Err(unsafe { error_from_cstr(error) });
    }
    if result.is_null() {
        return Err("segmentation returned null".into());
    }

    // SAFETY: `result` is non-null and points to an `AsyncSegResultRaw` (contract above).
    let raw = unsafe { &*(result.cast::<ffi::AsyncSegResultRaw>()) };

    let outcome = if raw.bytes.is_null() {
        Err("segmentation bytes were null".to_owned())
    } else {
        // `slice::from_raw_parts` requires `len <= isize::MAX`. A saturating product would
        // silently turn an overflow into `usize::MAX` and read far past the buffer, so
        // reject it instead.
        let checked_len = raw
            .height
            .checked_mul(raw.bytes_per_row)
            .filter(|&len| len <= isize::MAX.unsigned_abs());
        match checked_len {
            Some(len) => {
                // SAFETY: `bytes` is non-null and readable for `len` bytes (contract above),
                // and `len` fits in `isize`.
                let bytes = unsafe { std::slice::from_raw_parts(raw.bytes, len) }.to_vec();
                Ok(SegmentationMask {
                    width: raw.width,
                    height: raw.height,
                    bytes_per_row: raw.bytes_per_row,
                    bytes,
                })
            }
            None => Err(format!(
                "segmentation buffer size overflows: height {} x bytes_per_row {}",
                raw.height, raw.bytes_per_row
            )),
        }
    };

    // SAFETY: `result` came from the native side; this is the only place it is freed, and
    // the mask above owns a copy of the pixel data.
    unsafe { ffi::vn_async_seg_result_free(result.cast_mut()) };
    outcome
}
648
#[cfg(feature = "segmentation")]
/// C-ABI completion handler passed to `ffi::vn_generate_person_segmentation_async`.
///
/// Parses the native payload and completes the [`AsyncCompletion`] behind `ctx` with either
/// the segmentation mask or an error message. A panic while parsing is caught, logged, and
/// reported as an error so it never unwinds across the FFI boundary.
extern "C" fn seg_result_cb(result: *const c_void, error: *const i8, ctx: *mut c_void) {
    let parsed: Result<SegmentationMask, String> =
        std::panic::catch_unwind(AssertUnwindSafe(|| unsafe { parse_seg_result(result, error) }))
            .unwrap_or_else(|payload| {
                log_callback_panic("seg_result_cb", payload.as_ref());
                Err("panic in Vision seg_result_cb".into())
            });

    // NOTE(review): assumes `ctx` is the pointer produced by `AsyncCompletion::create` for
    // this request; each branch below completes it exactly once.
    match parsed {
        Ok(mask) => {
            unsafe { AsyncCompletion::complete_ok(ctx, mask) };
        }
        Err(message) => {
            unsafe { AsyncCompletion::<SegmentationMask>::complete_err(ctx, message) };
        }
    }
}
679
#[cfg(feature = "segmentation")]
/// Future returned by [`AsyncPersonSegmentation::generate_in_path`].
///
/// Resolves to the person-segmentation mask, or to a [`VisionError`] if the path was
/// rejected or the native request reported an error.
///
/// Marked `#[must_use]` because `generate_in_path` has already submitted the request (or
/// already captured an argument error) by the time this value is returned; dropping it
/// unobserved silently discards that outcome.
#[must_use = "the segmentation mask (or error) is only observable by awaiting this future"]
pub struct PersonSegmentationFuture {
    inner: FutureState<SegmentationMask>,
}

#[cfg(feature = "segmentation")]
impl std::fmt::Debug for PersonSegmentationFuture {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `FutureState` does not implement `Debug`, so only the type name is shown.
        f.debug_struct("PersonSegmentationFuture")
            .finish_non_exhaustive()
    }
}

#[cfg(feature = "segmentation")]
impl Future for PersonSegmentationFuture {
    type Output = Result<SegmentationMask, VisionError>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Delegate to the shared state machine; re-pinning the field relies on it being `Unpin`.
        Pin::new(&mut self.inner).poll(cx)
    }
}
702
#[cfg(feature = "segmentation")]
/// Configuration for asynchronous person segmentation.
///
/// Holds the requested [`SegmentationQuality`]. Each call to
/// [`AsyncPersonSegmentation::generate_in_path`] submits its own native request.
#[derive(Debug, Clone, Copy)]
pub struct AsyncPersonSegmentation {
    quality: SegmentationQuality,
}

#[cfg(feature = "segmentation")]
impl Default for AsyncPersonSegmentation {
    /// Uses [`SegmentationQuality::Balanced`].
    fn default() -> Self {
        Self {
            quality: SegmentationQuality::Balanced,
        }
    }
}

#[cfg(feature = "segmentation")]
impl AsyncPersonSegmentation {
    /// Creates a configuration that requests the given segmentation `quality`.
    #[must_use]
    pub const fn new(quality: SegmentationQuality) -> Self {
        Self { quality }
    }

    /// Starts person segmentation on the image at `path` and returns a future for the mask.
    ///
    /// The native request is submitted before this method returns. If `path` is not valid
    /// UTF-8 or contains a NUL byte, no request is made and the future resolves to
    /// `VisionError::InvalidArgument`.
    pub fn generate_in_path(&self, path: impl AsRef<Path>) -> PersonSegmentationFuture {
        let inner = match path_to_cstring(path) {
            Ok(c_path) => {
                let (completion, ctx) = AsyncCompletion::create();
                // SAFETY: `c_path` stays alive for the whole call, `seg_result_cb` matches the
                // native callback signature, and `ctx` is handed to that callback.
                // NOTE(review): assumes the native side copies the path before returning, and
                // that the quality enum's discriminants match the native raw values.
                unsafe {
                    ffi::vn_generate_person_segmentation_async(
                        c_path.as_ptr(),
                        self.quality as i32,
                        seg_result_cb,
                        ctx,
                    );
                }
                FutureState::pending(completion)
            }
            Err(error) => FutureState::ready_err(error),
        };
        PersonSegmentationFuture { inner }
    }
}