// apple_vision/segmentation/mod.rs

//! Segmentation mask generation —
//! `VNGeneratePersonSegmentationRequest` and
//! `VNGenerateForegroundInstanceMaskRequest`.
//!
//! Both requests produce a grayscale `CVPixelBuffer` mask; this module
//! copies it into a Rust-owned `Vec<u8>` so callers don't need a
//! `CoreVideo` dependency. Masks are 8-bit. Person masks use `0` for
//! background and `255` for foreground. In instance masks, `0` is
//! background and every other pixel value is the index of a detected
//! instance (`1..=instance_count`).

use core::ffi::c_char;
use core::ptr;
use std::ffi::CString;
use std::path::Path;

use crate::error::{from_swift, VisionError};
use crate::ffi;
use crate::request_base::PixelBufferObservation;

/// Quality level requested from Apple's person-segmentation model.
///
/// Each variant's explicit discriminant is the integer handed to the Swift
/// bridge (via `quality as i32`), so these values must stay in sync with it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SegmentationQuality {
    /// Bridge value `0`.
    Fast = 0,
    /// Bridge value `1`.
    Balanced = 1,
    /// Bridge value `2`.
    Accurate = 2,
}

/// One grayscale mask, stored row by row in `bytes`.
///
/// Masks built by this module hold exactly `height * bytes_per_row` bytes.
/// NOTE(review): `bytes_per_row` may exceed `width` if the source pixel
/// buffer pads its rows — confirm before indexing by `width`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SegmentationMask {
    /// Width of the mask as reported by the bridge.
    pub width: usize,
    /// Number of rows in the mask.
    pub height: usize,
    /// Distance in bytes from the start of one row to the start of the next.
    pub bytes_per_row: usize,
    /// Row-major mask bytes.
    pub bytes: Vec<u8>,
}

37/// A foreground-instance mask plus the number of distinct instances
38/// the model identified. Pixel values are `0` for background and
39/// `1..=instance_count` for each detected instance.
40#[derive(Debug, Clone, PartialEq, Eq)]
41pub struct InstanceMask {
42    pub mask: SegmentationMask,
43    pub instance_count: usize,
44}
45
46/// A dedicated `VNInstanceMaskObservation` wrapper.
47#[derive(Debug, Clone, PartialEq, Eq)]
48pub struct InstanceMaskObservation {
49    pub pixel_buffer_observation: PixelBufferObservation,
50    pub instance_count: usize,
51}
52
53impl InstanceMaskObservation {
54    #[must_use]
55    pub fn into_instance_mask(self) -> InstanceMask {
56        self.into()
57    }
58}
59
60impl From<SegmentationMask> for PixelBufferObservation {
61    fn from(value: SegmentationMask) -> Self {
62        Self::new(value.width, value.height, value.bytes_per_row, value.bytes)
63    }
64}
65
66impl From<InstanceMask> for InstanceMaskObservation {
67    fn from(value: InstanceMask) -> Self {
68        Self {
69            pixel_buffer_observation: value.mask.into(),
70            instance_count: value.instance_count,
71        }
72    }
73}
74
75impl From<InstanceMaskObservation> for InstanceMask {
76    fn from(value: InstanceMaskObservation) -> Self {
77        Self {
78            mask: SegmentationMask {
79                width: value.pixel_buffer_observation.width,
80                height: value.pixel_buffer_observation.height,
81                bytes_per_row: value.pixel_buffer_observation.bytes_per_row,
82                bytes: value.pixel_buffer_observation.bytes,
83            },
84            instance_count: value.instance_count,
85        }
86    }
87}
88
89/// Generate a person/body silhouette mask for the image at `path`.
90///
91/// # Errors
92///
93/// Returns [`VisionError::ImageLoadFailed`] / [`VisionError::RequestFailed`].
94pub fn generate_person_segmentation_in_path(
95    path: impl AsRef<Path>,
96    quality: SegmentationQuality,
97) -> Result<Option<SegmentationMask>, VisionError> {
98    let path_str = path
99        .as_ref()
100        .to_str()
101        .ok_or_else(|| VisionError::InvalidArgument("non-UTF-8 path".into()))?;
102    let path_c = CString::new(path_str)
103        .map_err(|e| VisionError::InvalidArgument(format!("path NUL byte: {e}")))?;
104
105    let mut raw = ffi::SegmentationMaskRaw {
106        width: 0,
107        height: 0,
108        bytes_per_row: 0,
109        bytes: ptr::null_mut(),
110    };
111    let mut has_value = false;
112    let mut err_msg: *mut c_char = ptr::null_mut();
113    // SAFETY: all pointer arguments are valid stack locations or bridge-owned handles; strings are valid C strings for the duration of the call.
114    let status = unsafe {
115        ffi::vn_generate_person_segmentation_in_path(
116            path_c.as_ptr(),
117            quality as i32,
118            &mut raw,
119            &mut has_value,
120            &mut err_msg,
121        )
122    };
123    if status != ffi::status::OK {
124        // SAFETY: the error pointer is either null or a bridge-allocated C string; `from_swift` frees it.
125        return Err(unsafe { from_swift(status, err_msg) });
126    }
127    if !has_value || raw.bytes.is_null() {
128        return Ok(None);
129    }
130    let mask = take_raw(&mut raw);
131    Ok(Some(mask))
132}
133
134/// Generate an instance segmentation mask of all foreground objects
135/// in the image at `path` (macOS 14+).
136///
137/// # Errors
138///
139/// Returns [`VisionError::ImageLoadFailed`] / [`VisionError::RequestFailed`].
140pub fn generate_foreground_instance_mask_in_path(
141    path: impl AsRef<Path>,
142) -> Result<Option<InstanceMask>, VisionError> {
143    let path_str = path
144        .as_ref()
145        .to_str()
146        .ok_or_else(|| VisionError::InvalidArgument("non-UTF-8 path".into()))?;
147    let path_c = CString::new(path_str)
148        .map_err(|e| VisionError::InvalidArgument(format!("path NUL byte: {e}")))?;
149
150    let mut raw = ffi::SegmentationMaskRaw {
151        width: 0,
152        height: 0,
153        bytes_per_row: 0,
154        bytes: ptr::null_mut(),
155    };
156    let mut instance_count: usize = 0;
157    let mut has_value = false;
158    let mut err_msg: *mut c_char = ptr::null_mut();
159    // SAFETY: all pointer arguments are valid stack locations or bridge-owned handles; strings are valid C strings for the duration of the call.
160    let status = unsafe {
161        ffi::vn_generate_foreground_instance_mask_in_path(
162            path_c.as_ptr(),
163            &mut raw,
164            &mut instance_count,
165            &mut has_value,
166            &mut err_msg,
167        )
168    };
169    if status != ffi::status::OK {
170        // SAFETY: the error pointer is either null or a bridge-allocated C string; `from_swift` frees it.
171        return Err(unsafe { from_swift(status, err_msg) });
172    }
173    if !has_value || raw.bytes.is_null() {
174        return Ok(None);
175    }
176    let mask = take_raw(&mut raw);
177    Ok(Some(InstanceMask {
178        mask,
179        instance_count,
180    }))
181}
182
183/// Generate a dedicated `VNInstanceMaskObservation` wrapper for the image at
184/// `path`.
185///
186/// # Errors
187///
188/// Returns [`VisionError::ImageLoadFailed`] / [`VisionError::RequestFailed`].
189pub fn generate_foreground_instance_mask_observation_in_path(
190    path: impl AsRef<Path>,
191) -> Result<Option<InstanceMaskObservation>, VisionError> {
192    generate_foreground_instance_mask_in_path(path)
193        .map(|mask| mask.map(InstanceMaskObservation::from))
194}
195
196fn take_raw(raw: &mut ffi::SegmentationMaskRaw) -> SegmentationMask {
197    let len = raw.height.saturating_mul(raw.bytes_per_row);
198    // SAFETY: `raw.bytes` is valid for `len` bytes as guaranteed by the Swift bridge.
199    let slice = unsafe { core::slice::from_raw_parts(raw.bytes.cast::<u8>(), len) };
200    let bytes = slice.to_vec();
201    // SAFETY: the pointer/count pair was allocated by the bridge and is freed exactly once here.
202    unsafe { ffi::vn_segmentation_mask_free(raw) };
203    SegmentationMask {
204        width: raw.width,
205        height: raw.height,
206        bytes_per_row: raw.bytes_per_row,
207        bytes,
208    }
209}