use core::ffi::c_char;
use core::ptr;
use std::ffi::CString;
use std::path::Path;
use crate::error::{from_swift, VisionError};
use crate::ffi;
use crate::segmentation::SegmentationMask;
/// Requested computation accuracy for optical-flow generation.
///
/// The explicit discriminants are the values that cross the FFI boundary
/// (the enum is converted with `as i32` before being handed to the native call).
///
/// NOTE(review): presumably mirrors the ordering of Vision's
/// `VNGenerateOpticalFlowRequest.ComputationAccuracy` — confirm against the Swift side.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OpticalFlowAccuracy {
    /// Lowest accuracy level; discriminant `0`.
    Low = 0,
    /// Discriminant `1`.
    Medium = 1,
    /// Discriminant `2`.
    High = 2,
    /// Highest accuracy level; discriminant `3`.
    VeryHigh = 3,
}
/// Computes optical flow between the images at `path_a` and `path_b`.
///
/// Both paths are handed to the native (Swift) side as NUL-terminated UTF-8 strings,
/// together with `accuracy` converted to its `i32` discriminant. On success, the native
/// buffer is copied into an owned [`SegmentationMask`], and the native allocation is released.
///
/// # Returns
///
/// * `Ok(Some(mask))` — the native call reported a value and returned a non-null buffer.
/// * `Ok(None)` — the native call reported no value, or returned a null buffer.
///
/// # Errors
///
/// * [`VisionError::InvalidArgument`] if either path is not valid UTF-8, contains an
///   interior NUL byte, or if the reported `height * bytes_per_row` overflows `usize`.
/// * The error built by [`from_swift`] when the native call returns a non-OK status.
pub fn generate_optical_flow_in_paths(
    path_a: impl AsRef<Path>,
    path_b: impl AsRef<Path>,
    accuracy: OpticalFlowAccuracy,
) -> Result<Option<SegmentationMask>, VisionError> {
    let a_str = path_a
        .as_ref()
        .to_str()
        .ok_or_else(|| VisionError::InvalidArgument("non-UTF-8 path A".into()))?;
    let b_str = path_b
        .as_ref()
        .to_str()
        .ok_or_else(|| VisionError::InvalidArgument("non-UTF-8 path B".into()))?;
    let a_c = CString::new(a_str)
        .map_err(|e| VisionError::InvalidArgument(format!("path A NUL byte: {e}")))?;
    let b_c = CString::new(b_str)
        .map_err(|e| VisionError::InvalidArgument(format!("path B NUL byte: {e}")))?;

    let mut raw = ffi::SegmentationMaskRaw {
        width: 0,
        height: 0,
        bytes_per_row: 0,
        bytes: ptr::null_mut(),
    };
    let mut has_value = false;
    let mut err_msg: *mut c_char = ptr::null_mut();

    // SAFETY: both C strings live until the end of this function, and every out-pointer
    // refers to a live, initialized local.
    let status = unsafe {
        ffi::vn_generate_optical_flow_in_paths(
            a_c.as_ptr(),
            b_c.as_ptr(),
            accuracy as i32,
            &mut raw,
            &mut has_value,
            &mut err_msg,
        )
    };
    if status != ffi::status::OK {
        // SAFETY: `err_msg` is either still null or was written by the failed native call.
        return Err(unsafe { from_swift(status, err_msg) });
    }
    if raw.bytes.is_null() {
        return Ok(None);
    }

    // Snapshot the dimensions before freeing. The free routine receives `&mut raw` and may
    // reset its fields, so they must not be read after that call.
    let width = raw.width;
    let height = raw.height;
    let bytes_per_row = raw.bytes_per_row;

    // `checked_mul`, not `saturating_mul`: a saturated length would make `from_raw_parts`
    // read past the end of the native allocation.
    let copied = if has_value {
        height.checked_mul(bytes_per_row).map(|len| {
            // SAFETY: `raw.bytes` is non-null, and (per the native contract) it points to at
            // least `height * bytes_per_row` initialized bytes that remain valid until the
            // free call below.
            unsafe { core::slice::from_raw_parts(raw.bytes.cast::<u8>(), len) }.to_vec()
        })
    } else {
        None
    };

    // Release the native buffer whenever it is non-null, including when `has_value` is
    // false, so that the allocation is never leaked.
    // SAFETY: `raw.bytes` is non-null, was produced by the native call, and is freed once.
    unsafe { ffi::vn_segmentation_mask_free(&mut raw) };

    if !has_value {
        return Ok(None);
    }
    let bytes = copied.ok_or_else(|| {
        VisionError::InvalidArgument(format!(
            "mask size overflows usize: {height} rows x {bytes_per_row} bytes per row"
        ))
    })?;
    Ok(Some(SegmentationMask {
        width,
        height,
        bytes_per_row,
        bytes,
    }))
}