Skip to main content

apple_vision/segmentation/
mod.rs

//! Segmentation mask generation —
//! `VNGeneratePersonSegmentationRequest` and
//! `VNGenerateForegroundInstanceMaskRequest`.
//!
//! Both requests produce a grayscale `CVPixelBuffer` mask; this module
//! copies it into a Rust-owned `Vec<u8>` so callers don't need a
//! `CoreVideo` dependency. Masks are 8 bits per pixel. For person
//! masks, `0` is background and `255` is foreground. For instance
//! masks, `0` is background and each non-zero value is the index of a
//! detected instance (`1..=instance_count`).
10
11use core::ffi::c_char;
12use core::ptr;
13use std::ffi::CString;
14use std::path::Path;
15
16use crate::error::{from_swift, VisionError};
17use crate::ffi;
18
/// Apple person-segmentation quality.
///
/// The discriminant of each variant is the integer handed to the native
/// bridge (the enum is cast with `as i32` at the call site), so the
/// representation is pinned to `i32` and the values must not be
/// renumbered.
#[repr(i32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SegmentationQuality {
    /// Quality level `0`.
    Fast = 0,
    /// Quality level `1`.
    Balanced = 1,
    /// Quality level `2`.
    Accurate = 2,
}
26
/// A single grayscale mask in row-major byte order.
///
/// The pixel data is owned by Rust; it is a copy of the buffer produced
/// by the native request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SegmentationMask {
    /// Mask width as reported by the native bridge.
    pub width: usize,
    /// Mask height (number of rows) as reported by the native bridge.
    pub height: usize,
    /// Number of bytes between the start of consecutive rows.
    // NOTE(review): presumably this can exceed `width` when rows are
    // padded — confirm against the bridge before indexing by `width`.
    pub bytes_per_row: usize,
    /// Row-major pixel bytes; the buffer is copied with a length of
    /// `height * bytes_per_row`.
    pub bytes: Vec<u8>,
}
35
36/// A foreground-instance mask plus the number of distinct instances
37/// the model identified. Pixel values are `0` for background and
38/// `1..=instance_count` for each detected instance.
39#[derive(Debug, Clone, PartialEq, Eq)]
40pub struct InstanceMask {
41    pub mask: SegmentationMask,
42    pub instance_count: usize,
43}
44
45/// Generate a person/body silhouette mask for the image at `path`.
46///
47/// # Errors
48///
49/// Returns [`VisionError::ImageLoadFailed`] / [`VisionError::RequestFailed`].
50pub fn generate_person_segmentation_in_path(
51    path: impl AsRef<Path>,
52    quality: SegmentationQuality,
53) -> Result<Option<SegmentationMask>, VisionError> {
54    let path_str = path
55        .as_ref()
56        .to_str()
57        .ok_or_else(|| VisionError::InvalidArgument("non-UTF-8 path".into()))?;
58    let path_c = CString::new(path_str)
59        .map_err(|e| VisionError::InvalidArgument(format!("path NUL byte: {e}")))?;
60
61    let mut raw = ffi::SegmentationMaskRaw {
62        width: 0,
63        height: 0,
64        bytes_per_row: 0,
65        bytes: ptr::null_mut(),
66    };
67    let mut has_value = false;
68    let mut err_msg: *mut c_char = ptr::null_mut();
69    let status = unsafe {
70        ffi::vn_generate_person_segmentation_in_path(
71            path_c.as_ptr(),
72            quality as i32,
73            &mut raw,
74            &mut has_value,
75            &mut err_msg,
76        )
77    };
78    if status != ffi::status::OK {
79        return Err(unsafe { from_swift(status, err_msg) });
80    }
81    if !has_value || raw.bytes.is_null() {
82        return Ok(None);
83    }
84    let mask = take_raw(&mut raw);
85    Ok(Some(mask))
86}
87
88/// Generate an instance segmentation mask of all foreground objects
89/// in the image at `path` (macOS 14+).
90///
91/// # Errors
92///
93/// Returns [`VisionError::ImageLoadFailed`] / [`VisionError::RequestFailed`].
94pub fn generate_foreground_instance_mask_in_path(
95    path: impl AsRef<Path>,
96) -> Result<Option<InstanceMask>, VisionError> {
97    let path_str = path
98        .as_ref()
99        .to_str()
100        .ok_or_else(|| VisionError::InvalidArgument("non-UTF-8 path".into()))?;
101    let path_c = CString::new(path_str)
102        .map_err(|e| VisionError::InvalidArgument(format!("path NUL byte: {e}")))?;
103
104    let mut raw = ffi::SegmentationMaskRaw {
105        width: 0,
106        height: 0,
107        bytes_per_row: 0,
108        bytes: ptr::null_mut(),
109    };
110    let mut instance_count: usize = 0;
111    let mut has_value = false;
112    let mut err_msg: *mut c_char = ptr::null_mut();
113    let status = unsafe {
114        ffi::vn_generate_foreground_instance_mask_in_path(
115            path_c.as_ptr(),
116            &mut raw,
117            &mut instance_count,
118            &mut has_value,
119            &mut err_msg,
120        )
121    };
122    if status != ffi::status::OK {
123        return Err(unsafe { from_swift(status, err_msg) });
124    }
125    if !has_value || raw.bytes.is_null() {
126        return Ok(None);
127    }
128    let mask = take_raw(&mut raw);
129    Ok(Some(InstanceMask {
130        mask,
131        instance_count,
132    }))
133}
134
135fn take_raw(raw: &mut ffi::SegmentationMaskRaw) -> SegmentationMask {
136    let len = raw.height.saturating_mul(raw.bytes_per_row);
137    let slice = unsafe { core::slice::from_raw_parts(raw.bytes.cast::<u8>(), len) };
138    let bytes = slice.to_vec();
139    unsafe { ffi::vn_segmentation_mask_free(raw) };
140    SegmentationMask {
141        width: raw.width,
142        height: raw.height,
143        bytes_per_row: raw.bytes_per_row,
144        bytes,
145    }
146}