Skip to main content

apple_vision/body_pose/
mod.rs

1//! Human body pose detection (`VNDetectHumanBodyPoseRequest`).
2
3use core::ffi::c_char;
4use core::ptr;
5use std::collections::HashMap;
6use std::ffi::CString;
7use std::path::Path;
8
9use crate::error::{from_swift, VisionError};
10use crate::ffi;
11use crate::recognize_text::BoundingBox;
12use crate::recognized_points::{
13    RecognizedPoint, RecognizedPointsObservation, VisionRecognizedPoint,
14};
15use crate::request_base::NormalizedRect;
16
// Generates a fieldless `Copy` enum in which every variant is paired with a
// fixed string literal.
//
// The generated type gets three inherent items:
// * `ALL`      - every variant, in declaration order;
// * `as_str`   - the literal paired with a variant;
// * `from_str` - the reverse lookup, yielding `None` for unknown strings.
macro_rules! string_enum {
    (
        $(#[$attr:meta])*
        pub enum $ty:ident {
            $( $case:ident => $raw:literal ),+ $(,)?
        }
    ) => {
        $(#[$attr])*
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
        pub enum $ty {
            $( $case ),+
        }

        impl $ty {
            /// Every variant, in declaration order.
            pub const ALL: &'static [Self] = &[ $( Self::$case ),+ ];

            /// Returns the string literal paired with this variant.
            #[must_use]
            pub const fn as_str(self) -> &'static str {
                match self {
                    $( Self::$case => $raw, )+
                }
            }

            /// Returns the variant whose string literal equals `value`
            /// exactly, or `None` when there is no such variant.
            // This is an inherent method returning `Option`, not an
            // implementation of `std::str::FromStr`, hence the lint opt-out.
            #[allow(clippy::should_implement_trait)]
            #[must_use]
            pub fn from_str(value: &str) -> Option<Self> {
                match value {
                    $( $raw => Some(Self::$case), )+
                    _ => None,
                }
            }
        }
    };
}
53
string_enum! {
    /// Mirrors `VNHumanBodyPoseObservationJointName`.
    ///
    /// Each variant is paired with the string that `as_str` returns and that
    /// `from_str` accepts. `ALL` lists the variants in the order written here.
    pub enum HumanBodyPoseJointName {
        // NOTE: only plain `//` comments may appear between variants; the
        // `string_enum!` pattern does not accept attributes (and therefore
        // doc comments) on individual variants.

        // Head and face. `Nose` is paired with "head_joint".
        Nose => "head_joint",
        LeftEye => "left_eye_joint",
        RightEye => "right_eye_joint",
        LeftEar => "left_ear_joint",
        RightEar => "right_ear_joint",
        // Shoulders and neck.
        LeftShoulder => "left_shoulder_1_joint",
        RightShoulder => "right_shoulder_1_joint",
        Neck => "neck_1_joint",
        // Arms: elbows use the "forearm" strings, wrists the "hand" strings.
        LeftElbow => "left_forearm_joint",
        RightElbow => "right_forearm_joint",
        LeftWrist => "left_hand_joint",
        RightWrist => "right_hand_joint",
        // Hips and root.
        LeftHip => "left_upLeg_joint",
        RightHip => "right_upLeg_joint",
        Root => "root",
        // Legs: knees use the "leg" strings, ankles the "foot" strings.
        LeftKnee => "left_leg_joint",
        RightKnee => "right_leg_joint",
        LeftAnkle => "left_foot_joint",
        RightAnkle => "right_foot_joint",
    }
}
78
string_enum! {
    /// Mirrors `VNHumanBodyPoseObservationJointsGroupName`.
    ///
    /// Each variant is paired with the string that `as_str` returns and that
    /// `from_str` accepts. `ALL` lists the variants in the order written here.
    pub enum HumanBodyPoseJointGroupName {
        // NOTE: only plain `//` comments may appear between variants; the
        // `string_enum!` pattern does not accept attributes on variants.
        Face => "VNBLKFACE",
        Torso => "VNBLKTORSO",
        LeftArm => "VNBLKLARM",
        RightArm => "VNBLKRARM",
        LeftLeg => "VNBLKLLEG",
        RightLeg => "VNBLKRLEG",
        // Presumably the group containing every joint - confirm against the
        // Vision SDK.
        All => "VNIPOAll",
    }
}
91
/// A single detected joint in normalised image coordinates
/// (`0.0..=1.0`, bottom-left origin).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct JointPoint {
    /// X coordinate of the joint, in the normalised space described above.
    pub x: f64,
    /// Y coordinate of the joint, in the normalised space described above.
    pub y: f64,
    /// Confidence value reported for this joint.
    pub confidence: f32,
}
100
/// One detected human body with its recognised joints.
///
/// `joints` keys use Apple's `VNHumanBodyPoseObservationJointName` raw values.
#[derive(Debug, Clone, PartialEq)]
pub struct DetectedBodyPose {
    /// Bounding box reported for the detected body.
    pub bounding_box: BoundingBox,
    /// Confidence value reported for the observation as a whole.
    pub confidence: f32,
    /// Recognised joints, keyed by joint-name string.
    pub joints: HashMap<String, JointPoint>,
}
110
/// A dedicated `VNHumanBodyPoseObservation` wrapper built on top of the generic
/// `VNRecognizedPointsObservation` surface.
#[derive(Debug, Clone, PartialEq)]
pub struct HumanBodyPoseObservation {
    /// Generic recognised-points data (bounding box, confidence and keyed
    /// points) backing this observation.
    pub recognized_points: RecognizedPointsObservation,
    /// Joint-name strings available on this observation. When converted from a
    /// [`DetectedBodyPose`], these are its `joints` keys in sorted order.
    pub available_joint_names: Vec<String>,
    /// Joint-group name strings. When converted from a [`DetectedBodyPose`],
    /// this holds every supported group name.
    pub available_joint_group_names: Vec<String>,
}
119
120const SUPPORTED_JOINT_NAMES: &[&str] = &[
121    HumanBodyPoseJointName::Nose.as_str(),
122    HumanBodyPoseJointName::LeftEye.as_str(),
123    HumanBodyPoseJointName::RightEye.as_str(),
124    HumanBodyPoseJointName::LeftEar.as_str(),
125    HumanBodyPoseJointName::RightEar.as_str(),
126    HumanBodyPoseJointName::LeftShoulder.as_str(),
127    HumanBodyPoseJointName::RightShoulder.as_str(),
128    HumanBodyPoseJointName::Neck.as_str(),
129    HumanBodyPoseJointName::LeftElbow.as_str(),
130    HumanBodyPoseJointName::RightElbow.as_str(),
131    HumanBodyPoseJointName::LeftWrist.as_str(),
132    HumanBodyPoseJointName::RightWrist.as_str(),
133    HumanBodyPoseJointName::LeftHip.as_str(),
134    HumanBodyPoseJointName::RightHip.as_str(),
135    HumanBodyPoseJointName::Root.as_str(),
136    HumanBodyPoseJointName::LeftKnee.as_str(),
137    HumanBodyPoseJointName::RightKnee.as_str(),
138    HumanBodyPoseJointName::LeftAnkle.as_str(),
139    HumanBodyPoseJointName::RightAnkle.as_str(),
140];
141
142const SUPPORTED_JOINT_GROUP_NAMES: &[&str] = &[
143    HumanBodyPoseJointGroupName::Face.as_str(),
144    HumanBodyPoseJointGroupName::Torso.as_str(),
145    HumanBodyPoseJointGroupName::LeftArm.as_str(),
146    HumanBodyPoseJointGroupName::RightArm.as_str(),
147    HumanBodyPoseJointGroupName::LeftLeg.as_str(),
148    HumanBodyPoseJointGroupName::RightLeg.as_str(),
149    HumanBodyPoseJointGroupName::All.as_str(),
150];
151
152impl HumanBodyPoseObservation {
153    #[must_use]
154    pub const fn supported_joint_name_keys() -> &'static [HumanBodyPoseJointName] {
155        HumanBodyPoseJointName::ALL
156    }
157
158    #[must_use]
159    pub const fn supported_joint_names() -> &'static [&'static str] {
160        SUPPORTED_JOINT_NAMES
161    }
162
163    #[must_use]
164    pub const fn supported_joint_group_name_keys() -> &'static [HumanBodyPoseJointGroupName] {
165        HumanBodyPoseJointGroupName::ALL
166    }
167
168    #[must_use]
169    pub const fn supported_joint_group_names() -> &'static [&'static str] {
170        SUPPORTED_JOINT_GROUP_NAMES
171    }
172
173    #[must_use]
174    pub fn recognized_point(
175        &self,
176        joint_name: HumanBodyPoseJointName,
177    ) -> Option<VisionRecognizedPoint> {
178        self.recognized_points.recognized_point(joint_name.as_str())
179    }
180
181    #[must_use]
182    pub fn into_detected_body_pose(self) -> DetectedBodyPose {
183        self.into()
184    }
185}
186
187impl From<DetectedBodyPose> for HumanBodyPoseObservation {
188    fn from(value: DetectedBodyPose) -> Self {
189        let mut available_joint_names = value.joints.keys().cloned().collect::<Vec<_>>();
190        available_joint_names.sort();
191        Self {
192            recognized_points: RecognizedPointsObservation {
193                bounding_box: NormalizedRect::new(
194                    value.bounding_box.x,
195                    value.bounding_box.y,
196                    value.bounding_box.width,
197                    value.bounding_box.height,
198                ),
199                confidence: value.confidence,
200                available_keys: available_joint_names.clone(),
201                available_group_keys: Self::supported_joint_group_names()
202                    .iter()
203                    .map(|name| (*name).to_string())
204                    .collect(),
205                points: value
206                    .joints
207                    .iter()
208                    .map(|(name, point)| {
209                        (
210                            name.clone(),
211                            RecognizedPoint {
212                                x: point.x,
213                                y: point.y,
214                                confidence: point.confidence,
215                            },
216                        )
217                    })
218                    .collect(),
219            },
220            available_joint_names,
221            available_joint_group_names: Self::supported_joint_group_names()
222                .iter()
223                .map(|name| (*name).to_string())
224                .collect(),
225        }
226    }
227}
228
229impl From<HumanBodyPoseObservation> for DetectedBodyPose {
230    fn from(value: HumanBodyPoseObservation) -> Self {
231        Self {
232            bounding_box: BoundingBox {
233                x: value.recognized_points.bounding_box.x,
234                y: value.recognized_points.bounding_box.y,
235                width: value.recognized_points.bounding_box.width,
236                height: value.recognized_points.bounding_box.height,
237            },
238            confidence: value.recognized_points.confidence,
239            joints: value
240                .recognized_points
241                .points
242                .into_iter()
243                .map(|(name, point)| {
244                    (
245                        name,
246                        JointPoint {
247                            x: point.x,
248                            y: point.y,
249                            confidence: point.confidence,
250                        },
251                    )
252                })
253                .collect(),
254        }
255    }
256}
257
258/// Detect human body-pose observations in the image at `path`.
259///
260/// # Errors
261///
262/// Returns [`VisionError::ImageLoadFailed`] / [`VisionError::RequestFailed`].
263pub fn detect_human_body_pose_observations_in_path(
264    path: impl AsRef<Path>,
265) -> Result<Vec<HumanBodyPoseObservation>, VisionError> {
266    detect_human_body_pose_in_path(path).map(|poses| poses.into_iter().map(Into::into).collect())
267}
268
269/// Detect human body poses in the image at `path`.
270///
271/// # Errors
272///
273/// Returns [`VisionError::ImageLoadFailed`] / [`VisionError::RequestFailed`].
274pub fn detect_human_body_pose_in_path(
275    path: impl AsRef<Path>,
276) -> Result<Vec<DetectedBodyPose>, VisionError> {
277    // SAFETY: `run` validates the path and the supplied bridge function matches its expected contract.
278    unsafe { run(path, ffi::vn_detect_human_body_pose_in_path) }
279}
280
281/// Call the supplied pose-detection bridge function and collect owned results.
282///
283/// # Safety
284///
285/// `f` must be a valid bridge function matching this signature and must
286/// populate the out-parameters using the allocation contract expected by
287/// [`collect`].
288pub(crate) unsafe fn run(
289    path: impl AsRef<Path>,
290    f: unsafe extern "C" fn(
291        *const c_char,
292        *mut *mut core::ffi::c_void,
293        *mut usize,
294        *mut *mut c_char,
295    ) -> i32,
296) -> Result<Vec<DetectedBodyPose>, VisionError> {
297    let path_str = path
298        .as_ref()
299        .to_str()
300        .ok_or_else(|| VisionError::InvalidArgument("non-UTF-8 path".into()))?;
301    let path_c = CString::new(path_str)
302        .map_err(|e| VisionError::InvalidArgument(format!("path NUL byte: {e}")))?;
303
304    let mut out_array: *mut core::ffi::c_void = ptr::null_mut();
305    let mut out_count: usize = 0;
306    let mut err_msg: *mut c_char = ptr::null_mut();
307    let status = f(
308        path_c.as_ptr(),
309        &mut out_array,
310        &mut out_count,
311        &mut err_msg,
312    );
313    if status != ffi::status::OK {
314        return Err(from_swift(status, err_msg));
315    }
316    Ok(collect(out_array, out_count))
317}
318
319/// Convert a bridge-allocated pose array into Rust-owned observations.
320///
321/// # Safety
322///
323/// `out_array` must be either null or point to `out_count` consecutive
324/// `PoseObservationRaw` elements allocated by the Swift bridge. This
325/// function consumes that allocation and frees it exactly once.
326pub(crate) unsafe fn collect(
327    out_array: *mut core::ffi::c_void,
328    out_count: usize,
329) -> Vec<DetectedBodyPose> {
330    if out_array.is_null() || out_count == 0 {
331        return Vec::new();
332    }
333    let typed = out_array.cast::<ffi::PoseObservationRaw>();
334    let mut v = Vec::with_capacity(out_count);
335    for i in 0..out_count {
336        let raw = &*typed.add(i);
337        let mut joints = HashMap::with_capacity(raw.joint_count);
338        for j in 0..raw.joint_count {
339            let name_ptr = *raw.joint_names.add(j);
340            if name_ptr.is_null() {
341                continue;
342            }
343            let name = core::ffi::CStr::from_ptr(name_ptr)
344                .to_string_lossy()
345                .into_owned();
346            joints.insert(
347                name,
348                JointPoint {
349                    x: *raw.joint_xs.add(j),
350                    y: *raw.joint_ys.add(j),
351                    confidence: *raw.joint_confidences.add(j),
352                },
353            );
354        }
355        v.push(DetectedBodyPose {
356            bounding_box: BoundingBox {
357                x: raw.bbox_x,
358                y: raw.bbox_y,
359                width: raw.bbox_w,
360                height: raw.bbox_h,
361            },
362            confidence: raw.confidence,
363            joints,
364        });
365    }
366    ffi::vn_pose_observations_free(out_array, out_count);
367    v
368}