Skip to main content

apple_vision/optical_flow/
mod.rs

//! Optical flow generation (`VNGenerateOpticalFlowRequest`).
//!
//! Apple's optical flow request runs over **two** frames (A and B)
//! and returns a per-pixel displacement field describing how pixels
//! in A moved to land in B. Raw `CVPixelBuffer` bytes are copied into
//! a [`SegmentationMask`] for easy transport.
//!
//! Trajectory detection (`VNDetectTrajectoriesRequest`) is intentionally
//! deferred — it's a `VNStatefulRequest` that requires feeding many
//! frames into the same request instance over time, which doesn't fit
//! the synchronous one-shot request pattern this crate uses.
12
13use core::ffi::c_char;
14use core::ptr;
15use std::ffi::CString;
16use std::path::Path;
17
18use crate::error::{from_swift, VisionError};
19use crate::ffi;
20use crate::segmentation::SegmentationMask;
21
/// Mirror of `VNGenerateOpticalFlowRequest.ComputationAccuracy`.
///
/// The explicit discriminants are the integer values handed to the Swift
/// bridge (the caller converts with `accuracy as i32`). `#[repr(i32)]` pins
/// the discriminant type so it matches that cast exactly.
#[repr(i32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpticalFlowAccuracy {
    /// Lowest accuracy level (raw value `0`).
    Low = 0,
    /// Medium accuracy level (raw value `1`).
    Medium = 1,
    /// High accuracy level (raw value `2`).
    High = 2,
    /// Highest accuracy level (raw value `3`).
    VeryHigh = 3,
}
30
31/// Compute the optical flow between `path_a` (start) and `path_b`
32/// (end). Returns raw flow bytes wrapped in a [`SegmentationMask`].
33///
34/// # Errors
35///
36/// Returns [`VisionError::ImageLoadFailed`] / [`VisionError::RequestFailed`].
37pub fn generate_optical_flow_in_paths(
38    path_a: impl AsRef<Path>,
39    path_b: impl AsRef<Path>,
40    accuracy: OpticalFlowAccuracy,
41) -> Result<Option<SegmentationMask>, VisionError> {
42    let a_str = path_a
43        .as_ref()
44        .to_str()
45        .ok_or_else(|| VisionError::InvalidArgument("non-UTF-8 path A".into()))?;
46    let b_str = path_b
47        .as_ref()
48        .to_str()
49        .ok_or_else(|| VisionError::InvalidArgument("non-UTF-8 path B".into()))?;
50    let a_c = CString::new(a_str)
51        .map_err(|e| VisionError::InvalidArgument(format!("path A NUL byte: {e}")))?;
52    let b_c = CString::new(b_str)
53        .map_err(|e| VisionError::InvalidArgument(format!("path B NUL byte: {e}")))?;
54
55    let mut raw = ffi::SegmentationMaskRaw {
56        width: 0,
57        height: 0,
58        bytes_per_row: 0,
59        bytes: ptr::null_mut(),
60    };
61    let mut has_value = false;
62    let mut err_msg: *mut c_char = ptr::null_mut();
63    let status = unsafe {
64        ffi::vn_generate_optical_flow_in_paths(
65            a_c.as_ptr(),
66            b_c.as_ptr(),
67            accuracy as i32,
68            &mut raw,
69            &mut has_value,
70            &mut err_msg,
71        )
72    };
73    if status != ffi::status::OK {
74        return Err(unsafe { from_swift(status, err_msg) });
75    }
76    if !has_value || raw.bytes.is_null() {
77        return Ok(None);
78    }
79    let len = raw.height.saturating_mul(raw.bytes_per_row);
80    let slice = unsafe { core::slice::from_raw_parts(raw.bytes.cast::<u8>(), len) };
81    let bytes = slice.to_vec();
82    unsafe { ffi::vn_segmentation_mask_free(&mut raw) };
83    Ok(Some(SegmentationMask {
84        width: raw.width,
85        height: raw.height,
86        bytes_per_row: raw.bytes_per_row,
87        bytes,
88    }))
89}