Skip to main content

apple_vision/optical_flow/
mod.rs

//! Optical flow generation (`VNGenerateOpticalFlowRequest`).
//!
//! Apple's optical flow request runs over **two** frames (A and B)
//! and returns a per-pixel displacement field describing how pixels
//! in A moved to land in B. The raw `CVPixelBuffer` bytes are copied into
//! an owned [`SegmentationMask`] so they can be transported easily.
//!
//! Trajectory detection (`VNDetectTrajectoriesRequest`) is intentionally
//! deferred: it is a `VNStatefulRequest` that must be fed many frames
//! through the same request instance over time. That does not fit the
//! synchronous, one-shot request pattern this crate uses.
12
13use core::ffi::c_char;
14use core::ptr;
15use std::ffi::CString;
16use std::path::Path;
17
18use crate::error::{from_swift, VisionError};
19use crate::ffi;
20use crate::request_base::PixelBufferObservation;
21use crate::segmentation::SegmentationMask;
22
/// Accuracy level for optical flow, mirroring
/// `VNGenerateOpticalFlowRequest.ComputationAccuracy`.
///
/// Each variant is cast with `as i32` and handed to the Swift bridge, so
/// the explicit discriminants below are the wire values. They are assumed
/// to match Vision's raw values; keep them in sync if Apple ever changes them.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OpticalFlowAccuracy {
    /// `ComputationAccuracy.low` (wire value `0`).
    Low = 0,
    /// `ComputationAccuracy.medium` (wire value `1`).
    Medium = 1,
    /// `ComputationAccuracy.high` (wire value `2`).
    High = 2,
    /// `ComputationAccuracy.veryHigh` (wire value `3`).
    VeryHigh = 3,
}
31
32/// Compute the optical flow between `path_a` (start) and `path_b`
33/// (end). Returns raw flow bytes wrapped in a [`SegmentationMask`].
34///
35/// # Errors
36///
37/// Returns [`VisionError::ImageLoadFailed`] / [`VisionError::RequestFailed`].
38pub fn generate_optical_flow_in_paths(
39    path_a: impl AsRef<Path>,
40    path_b: impl AsRef<Path>,
41    accuracy: OpticalFlowAccuracy,
42) -> Result<Option<SegmentationMask>, VisionError> {
43    let a_str = path_a
44        .as_ref()
45        .to_str()
46        .ok_or_else(|| VisionError::InvalidArgument("non-UTF-8 path A".into()))?;
47    let b_str = path_b
48        .as_ref()
49        .to_str()
50        .ok_or_else(|| VisionError::InvalidArgument("non-UTF-8 path B".into()))?;
51    let a_c = CString::new(a_str)
52        .map_err(|e| VisionError::InvalidArgument(format!("path A NUL byte: {e}")))?;
53    let b_c = CString::new(b_str)
54        .map_err(|e| VisionError::InvalidArgument(format!("path B NUL byte: {e}")))?;
55
56    let mut raw = ffi::SegmentationMaskRaw {
57        width: 0,
58        height: 0,
59        bytes_per_row: 0,
60        bytes: ptr::null_mut(),
61    };
62    let mut has_value = false;
63    let mut err_msg: *mut c_char = ptr::null_mut();
64    // SAFETY: all pointer arguments are valid stack locations or bridge-owned handles; strings are valid C strings for the duration of the call.
65    let status = unsafe {
66        ffi::vn_generate_optical_flow_in_paths(
67            a_c.as_ptr(),
68            b_c.as_ptr(),
69            accuracy as i32,
70            &mut raw,
71            &mut has_value,
72            &mut err_msg,
73        )
74    };
75    if status != ffi::status::OK {
76        // SAFETY: the error pointer is either null or a bridge-allocated C string; `from_swift` frees it.
77        return Err(unsafe { from_swift(status, err_msg) });
78    }
79    if !has_value || raw.bytes.is_null() {
80        return Ok(None);
81    }
82    let len = raw.height.saturating_mul(raw.bytes_per_row);
83    // SAFETY: `raw.bytes` is valid for `len` bytes as guaranteed by the Swift bridge.
84    let slice = unsafe { core::slice::from_raw_parts(raw.bytes.cast::<u8>(), len) };
85    let bytes = slice.to_vec();
86    // SAFETY: `raw` was populated by the bridge and has not been freed yet; unique free site.
87    unsafe { ffi::vn_segmentation_mask_free(&mut raw) };
88    Ok(Some(SegmentationMask {
89        width: raw.width,
90        height: raw.height,
91        bytes_per_row: raw.bytes_per_row,
92        bytes,
93    }))
94}
95
96/// Compute the optical flow and wrap the result as a generic
97/// `VNPixelBufferObservation`.
98///
99/// # Errors
100///
101/// Returns [`VisionError::ImageLoadFailed`] / [`VisionError::RequestFailed`].
102pub fn generate_optical_flow_observation_in_paths(
103    path_a: impl AsRef<Path>,
104    path_b: impl AsRef<Path>,
105    accuracy: OpticalFlowAccuracy,
106) -> Result<Option<PixelBufferObservation>, VisionError> {
107    generate_optical_flow_in_paths(path_a, path_b, accuracy)
108        .map(|mask| mask.map(PixelBufferObservation::from))
109}