use core::ffi::c_char;
use core::ptr;
use std::ffi::CString;
use std::path::Path;
use crate::error::{from_swift, VisionError};
use crate::ffi;
/// Quality level requested for person segmentation.
///
/// The variant names suggest a speed/accuracy trade-off; the exact meaning of
/// each level is decided by the Swift bridge. The explicit discriminants are the
/// integer codes passed across the FFI boundary (`quality as i32`), so they
/// presumably have to mirror the bridge's own enumeration — keep them in sync.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SegmentationQuality {
    /// Sent to the bridge as `0`.
    Fast = 0,
    /// Sent to the bridge as `1`.
    Balanced = 1,
    /// Sent to the bridge as `2`.
    Accurate = 2,
}
/// An owned segmentation mask copied out of a buffer allocated by the Vision bridge.
///
/// When produced by this module, `bytes` holds exactly `height * bytes_per_row`
/// bytes. Rows may be padded, so `bytes_per_row` is not necessarily derivable
/// from `width`; the per-pixel encoding is defined by the bridge (not visible here).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SegmentationMask {
    /// Mask width as reported by the bridge.
    pub width: usize,
    /// Number of rows in the mask.
    pub height: usize,
    /// Stride, in bytes, between the starts of consecutive rows.
    pub bytes_per_row: usize,
    /// Raw row-major mask bytes (`height` rows of `bytes_per_row` bytes each).
    pub bytes: Vec<u8>,
}
/// A foreground instance mask together with the instance count the bridge reported.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InstanceMask {
    /// The mask pixels and their layout.
    pub mask: SegmentationMask,
    /// Number of instances reported by the bridge alongside the mask
    /// (whether a background label is included is not visible here — confirm
    /// against the Swift side).
    pub instance_count: usize,
}
/// Runs person segmentation on the image at `path` with the given `quality`.
///
/// Returns `Ok(None)` when the bridge reports no mask (or returns no buffer).
/// Otherwise it returns an owned copy of the mask; the bridge-allocated buffer
/// is always released before this function returns.
///
/// # Errors
///
/// * [`VisionError::InvalidArgument`] if `path` is not valid UTF-8 or contains
///   an interior NUL byte.
/// * The error built by `from_swift` when the bridge returns a non-OK status.
pub fn generate_person_segmentation_in_path(
    path: impl AsRef<Path>,
    quality: SegmentationQuality,
) -> Result<Option<SegmentationMask>, VisionError> {
    let path_str = path
        .as_ref()
        .to_str()
        .ok_or_else(|| VisionError::InvalidArgument("non-UTF-8 path".into()))?;
    let path_c = CString::new(path_str)
        .map_err(|e| VisionError::InvalidArgument(format!("path NUL byte: {e}")))?;
    let mut raw = ffi::SegmentationMaskRaw {
        width: 0,
        height: 0,
        bytes_per_row: 0,
        bytes: ptr::null_mut(),
    };
    let mut has_value = false;
    let mut err_msg: *mut c_char = ptr::null_mut();
    // SAFETY: `path_c` is a valid NUL-terminated string that outlives the call,
    // and every out-pointer refers to a live, initialized local.
    let status = unsafe {
        ffi::vn_generate_person_segmentation_in_path(
            path_c.as_ptr(),
            quality as i32,
            &mut raw,
            &mut has_value,
            &mut err_msg,
        )
    };
    if status != ffi::status::OK {
        // SAFETY: `err_msg` was written (or left null) by the bridge call above;
        // `from_swift` presumably consumes it — confirm it also frees it.
        return Err(unsafe { from_swift(status, err_msg) });
    }
    if raw.bytes.is_null() {
        return Ok(None);
    }
    if !has_value {
        // The bridge returned a buffer but reports no mask: hand the buffer back
        // instead of leaking it.
        // SAFETY: `raw.bytes` is non-null and was produced by the bridge, whose
        // buffers are released through `vn_segmentation_mask_free`.
        unsafe { ffi::vn_segmentation_mask_free(&mut raw) };
        return Ok(None);
    }
    Ok(Some(take_raw(&mut raw)))
}
/// Generates a foreground instance mask for the image at `path`.
///
/// Returns `Ok(None)` when the bridge reports no mask (or returns no buffer).
/// Otherwise it returns the owned mask together with the instance count the
/// bridge reported. The bridge-allocated buffer is always released before this
/// function returns.
///
/// # Errors
///
/// * [`VisionError::InvalidArgument`] if `path` is not valid UTF-8 or contains
///   an interior NUL byte.
/// * The error built by `from_swift` when the bridge returns a non-OK status.
pub fn generate_foreground_instance_mask_in_path(
    path: impl AsRef<Path>,
) -> Result<Option<InstanceMask>, VisionError> {
    let path_str = path
        .as_ref()
        .to_str()
        .ok_or_else(|| VisionError::InvalidArgument("non-UTF-8 path".into()))?;
    let path_c = CString::new(path_str)
        .map_err(|e| VisionError::InvalidArgument(format!("path NUL byte: {e}")))?;
    let mut raw = ffi::SegmentationMaskRaw {
        width: 0,
        height: 0,
        bytes_per_row: 0,
        bytes: ptr::null_mut(),
    };
    let mut instance_count: usize = 0;
    let mut has_value = false;
    let mut err_msg: *mut c_char = ptr::null_mut();
    // SAFETY: `path_c` is a valid NUL-terminated string that outlives the call,
    // and every out-pointer refers to a live, initialized local.
    let status = unsafe {
        ffi::vn_generate_foreground_instance_mask_in_path(
            path_c.as_ptr(),
            &mut raw,
            &mut instance_count,
            &mut has_value,
            &mut err_msg,
        )
    };
    if status != ffi::status::OK {
        // SAFETY: `err_msg` was written (or left null) by the bridge call above;
        // `from_swift` presumably consumes it — confirm it also frees it.
        return Err(unsafe { from_swift(status, err_msg) });
    }
    if raw.bytes.is_null() {
        return Ok(None);
    }
    if !has_value {
        // The bridge returned a buffer but reports no mask: hand the buffer back
        // instead of leaking it.
        // SAFETY: `raw.bytes` is non-null and was produced by the bridge, whose
        // buffers are released through `vn_segmentation_mask_free`.
        unsafe { ffi::vn_segmentation_mask_free(&mut raw) };
        return Ok(None);
    }
    let mask = take_raw(&mut raw);
    Ok(Some(InstanceMask {
        mask,
        instance_count,
    }))
}
/// Copies a bridge-allocated mask into an owned [`SegmentationMask`] and
/// releases the bridge buffer via `vn_segmentation_mask_free`.
///
/// Exactly `height * bytes_per_row` bytes are copied. A null `bytes` pointer
/// yields an empty byte vector instead of touching memory. `raw` is always
/// passed to the free routine, so its data must not be accessed afterwards.
///
/// # Panics
///
/// Panics after releasing the buffer if `height * bytes_per_row` overflows
/// `usize` or exceeds `isize::MAX`. A slice of that length cannot exist, and
/// constructing one would be undefined behaviour.
fn take_raw(raw: &mut ffi::SegmentationMaskRaw) -> SegmentationMask {
    // Snapshot the metadata first: the free routine receives a mutable pointer
    // to `raw` and may reset these fields.
    let width = raw.width;
    let height = raw.height;
    let bytes_per_row = raw.bytes_per_row;
    let data = raw.bytes.cast::<u8>();

    let copied = if data.is_null() {
        Some(Vec::new())
    } else {
        height
            .checked_mul(bytes_per_row)
            .filter(|&len| len <= isize::MAX as usize)
            .map(|len| {
                // SAFETY: `data` is non-null and, per the bridge contract, points
                // to at least `height * bytes_per_row` initialized bytes that
                // remain valid until `vn_segmentation_mask_free` below; `len`
                // has been checked to fit in `isize`.
                unsafe { core::slice::from_raw_parts(data, len) }.to_vec()
            })
    };

    // SAFETY: `raw` describes a buffer allocated by the bridge; it is released
    // exactly once here and never read through again.
    unsafe { ffi::vn_segmentation_mask_free(raw) };

    let bytes =
        copied.expect("segmentation mask size (height * bytes_per_row) exceeds isize::MAX");
    SegmentationMask {
        width,
        height,
        bytes_per_row,
        bytes,
    }
}