Skip to main content

apple_vision/tracking/
mod.rs

1#![allow(clippy::cast_possible_wrap, clippy::cast_sign_loss)]
2#![allow(clippy::too_long_first_doc_paragraph)]
3//! Stateful Vision tracking requests backed by retained Swift sessions.
4//!
5//! These wrappers keep the underlying Vision request alive across frames,
6//! which is required for object / rectangle tracking and the sequence-based
7//! optical-flow and image-registration trackers.
8
9use core::ffi::{c_char, c_void};
10use core::ptr;
11use std::ffi::{CStr, CString};
12use std::path::Path;
13
14use crate::error::VisionError;
15use crate::face_landmarks::LandmarkPoint;
16use crate::ffi;
17use crate::optical_flow::OpticalFlowAccuracy;
18use crate::recognize_text::BoundingBox;
19use crate::rectangles::RectangleObservation;
20use crate::registration::{HomographicAlignment, TranslationalAlignment};
21
22/// Public alias for `VNTrackOpticalFlowRequest.ComputationAccuracy`.
23pub type TrackOpticalFlowRequestComputationAccuracy = OpticalFlowAccuracy;
24
/// Owned copy of the raw optical-flow pixel buffer produced by Vision.
///
/// The bytes are copied verbatim from the bridge; no pixel-format
/// interpretation is applied on the Rust side.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OpticalFlowFrame {
    /// Buffer width as reported by the bridge (presumably pixels — confirm
    /// against the Swift side).
    pub width: usize,
    /// Number of rows in the buffer as reported by the bridge.
    pub height: usize,
    /// Stride of a single row, in bytes.
    pub bytes_per_row: usize,
    /// The copied buffer contents; empty when the bridge returned no data.
    pub bytes: Vec<u8>,
}
33
34impl OpticalFlowFrame {
35    /// Borrow the copied raw pixel-buffer bytes.
36    #[must_use]
37    pub fn as_bytes(&self) -> &[u8] {
38        &self.bytes
39    }
40}
41
/// Tracks a detected object's bounding box across a sequence of frames.
///
/// Wraps the opaque handle of the retained bridge session. The pointer is
/// never dereferenced on the Rust side; `Debug` output shows only its
/// address.
#[derive(Debug)]
pub struct ObjectTracker {
    /// Opaque bridge-owned session handle.
    handle: *mut c_void,
}
46
/// Tracks a known rectangle observation across a sequence of frames.
///
/// Wraps the opaque handle of the retained bridge session. The pointer is
/// never dereferenced on the Rust side; `Debug` output shows only its
/// address.
#[derive(Debug)]
pub struct RectangleTracker {
    /// Opaque bridge-owned session handle.
    handle: *mut c_void,
}
51
/// Tracks dense optical flow across a frame sequence.
///
/// Wraps the opaque handle of the retained bridge session. The pointer is
/// never dereferenced on the Rust side; `Debug` output shows only its
/// address.
#[derive(Debug)]
pub struct OpticalFlowTracker {
    /// Opaque bridge-owned session handle.
    handle: *mut c_void,
}
56
/// Tracks translational image registration across frames.
///
/// Wraps the opaque handle of the retained bridge session. The pointer is
/// never dereferenced on the Rust side; `Debug` output shows only its
/// address.
#[derive(Debug)]
pub struct TranslationalImageTracker {
    /// Opaque bridge-owned session handle.
    handle: *mut c_void,
}
61
/// Tracks homographic image registration across frames.
///
/// Wraps the opaque handle of the retained bridge session. The pointer is
/// never dereferenced on the Rust side; `Debug` output shows only its
/// address.
#[derive(Debug)]
pub struct HomographicImageTracker {
    /// Opaque bridge-owned session handle.
    handle: *mut c_void,
}
66
67impl ObjectTracker {
68    /// Create a new object tracker seeded from `image_path` and `bbox`.
69    ///
70    /// # Errors
71    ///
72    /// Returns [`VisionError`] if the image path is invalid, the image
73    /// fails to load, or Vision rejects the tracking request.
74    pub fn new(image_path: impl AsRef<Path>, bbox: BoundingBox) -> Result<Self, VisionError> {
75        let image_c = path_to_cstring(image_path.as_ref(), "image path")?;
76        let mut raw_bbox = ffi::SimpleRectRaw {
77            x: bbox.x,
78            y: bbox.y,
79            w: bbox.width,
80            h: bbox.height,
81            confidence: 1.0,
82            _pad: 0.0,
83        };
84        let mut handle: *mut c_void = ptr::null_mut();
85        let mut err: *mut c_char = ptr::null_mut();
86        // SAFETY: all pointer arguments are valid stack locations or bridge-owned handles; strings are valid C strings for the duration of the call.
87        let status = unsafe {
88            ffi::vn_object_tracker_create(
89                image_c.as_ptr(),
90                ptr::addr_of_mut!(raw_bbox).cast(),
91                &mut handle,
92                &mut err,
93            )
94        };
95        if status != ffi::status::OK {
96            return Err(error_from_status(status, err));
97        }
98        ensure_handle(handle, "object tracker")?;
99        Ok(Self { handle })
100    }
101
102    /// Track the object into `image_path` and return the updated bounding
103    /// box.
104    ///
105    /// # Errors
106    ///
107    /// Returns [`VisionError`] if the image path is invalid, the image
108    /// fails to load, or Vision rejects the tracking request.
109    pub fn track(&mut self, image_path: impl AsRef<Path>) -> Result<BoundingBox, VisionError> {
110        let image_c = path_to_cstring(image_path.as_ref(), "image path")?;
111        let mut raw_bbox = ffi::SimpleRectRaw {
112            x: 0.0,
113            y: 0.0,
114            w: 0.0,
115            h: 0.0,
116            confidence: 0.0,
117            _pad: 0.0,
118        };
119        let mut err: *mut c_char = ptr::null_mut();
120        // SAFETY: all pointer arguments are valid stack locations or bridge-owned handles; strings are valid C strings for the duration of the call.
121        let status = unsafe {
122            ffi::vn_object_tracker_track(
123                self.handle,
124                image_c.as_ptr(),
125                ptr::addr_of_mut!(raw_bbox).cast(),
126                &mut err,
127            )
128        };
129        if status != ffi::status::OK {
130            return Err(error_from_status(status, err));
131        }
132        Ok(BoundingBox {
133            x: raw_bbox.x,
134            y: raw_bbox.y,
135            width: raw_bbox.w,
136            height: raw_bbox.h,
137        })
138    }
139}
140
141impl RectangleTracker {
142    /// Create a new rectangle tracker seeded from `image_path` and the
143    /// known rectangle observation for that frame.
144    ///
145    /// # Errors
146    ///
147    /// Returns [`VisionError`] if the image path is invalid, the image
148    /// fails to load, or Vision rejects the tracking request.
149    pub fn new(
150        image_path: impl AsRef<Path>,
151        rect_observation: &RectangleObservation,
152    ) -> Result<Self, VisionError> {
153        let image_c = path_to_cstring(image_path.as_ref(), "image path")?;
154        let mut raw = rectangle_to_raw(rect_observation);
155        let mut handle: *mut c_void = ptr::null_mut();
156        let mut err: *mut c_char = ptr::null_mut();
157        // SAFETY: all pointer arguments are valid stack locations or bridge-owned handles; strings are valid C strings for the duration of the call.
158        let status = unsafe {
159            ffi::vn_rectangle_tracker_create(
160                image_c.as_ptr(),
161                ptr::addr_of_mut!(raw).cast(),
162                &mut handle,
163                &mut err,
164            )
165        };
166        if status != ffi::status::OK {
167            return Err(error_from_status(status, err));
168        }
169        ensure_handle(handle, "rectangle tracker")?;
170        Ok(Self { handle })
171    }
172
173    /// Track the rectangle into `image_path` and return the updated
174    /// rectangle observation.
175    ///
176    /// # Errors
177    ///
178    /// Returns [`VisionError`] if the image path is invalid, the image
179    /// fails to load, or Vision rejects the tracking request.
180    pub fn track(
181        &mut self,
182        image_path: impl AsRef<Path>,
183    ) -> Result<RectangleObservation, VisionError> {
184        let image_c = path_to_cstring(image_path.as_ref(), "image path")?;
185        let mut raw = ffi::RectangleObservationRaw {
186            bbox_x: 0.0,
187            bbox_y: 0.0,
188            bbox_w: 0.0,
189            bbox_h: 0.0,
190            confidence: 0.0,
191            tl_x: 0.0,
192            tl_y: 0.0,
193            tr_x: 0.0,
194            tr_y: 0.0,
195            bl_x: 0.0,
196            bl_y: 0.0,
197            br_x: 0.0,
198            br_y: 0.0,
199        };
200        let mut err: *mut c_char = ptr::null_mut();
201        // SAFETY: all pointer arguments are valid stack locations or bridge-owned handles; strings are valid C strings for the duration of the call.
202        let status = unsafe {
203            ffi::vn_rectangle_tracker_track(
204                self.handle,
205                image_c.as_ptr(),
206                ptr::addr_of_mut!(raw).cast(),
207                &mut err,
208            )
209        };
210        if status != ffi::status::OK {
211            return Err(error_from_status(status, err));
212        }
213        Ok(rectangle_from_raw(&raw))
214    }
215}
216
217impl OpticalFlowTracker {
218    /// The computation-accuracy values exposed by `VNTrackOpticalFlowRequest`.
219    #[must_use]
220    pub const fn supported_computation_accuracies(
221    ) -> &'static [TrackOpticalFlowRequestComputationAccuracy] {
222        &[
223            TrackOpticalFlowRequestComputationAccuracy::Low,
224            TrackOpticalFlowRequestComputationAccuracy::Medium,
225            TrackOpticalFlowRequestComputationAccuracy::High,
226            TrackOpticalFlowRequestComputationAccuracy::VeryHigh,
227        ]
228    }
229
230    /// Create a new optical-flow tracker seeded with the reference image.
231    ///
232    /// # Errors
233    ///
234    /// Returns [`VisionError`] if the image path is invalid, the image
235    /// fails to load, or Vision rejects the tracking request.
236    pub fn new(reference_path: impl AsRef<Path>) -> Result<Self, VisionError> {
237        let image_c = path_to_cstring(reference_path.as_ref(), "reference path")?;
238        let mut handle: *mut c_void = ptr::null_mut();
239        let mut err: *mut c_char = ptr::null_mut();
240        let status =
241            // SAFETY: all pointer arguments are valid stack locations or null-initialised out-params; strings are valid C strings for the duration of the call.
242            unsafe { ffi::vn_optical_flow_tracker_create(image_c.as_ptr(), &mut handle, &mut err) };
243        if status != ffi::status::OK {
244            return Err(error_from_status(status, err));
245        }
246        ensure_handle(handle, "optical-flow tracker")?;
247        Ok(Self { handle })
248    }
249
250    /// Track optical flow into `image_path` and return the copied raw
251    /// pixel-buffer bytes.
252    ///
253    /// # Errors
254    ///
255    /// Returns [`VisionError`] if the image path is invalid, the image
256    /// fails to load, or Vision rejects the tracking request.
257    pub fn track(&mut self, image_path: impl AsRef<Path>) -> Result<OpticalFlowFrame, VisionError> {
258        let image_c = path_to_cstring(image_path.as_ref(), "image path")?;
259        let mut raw = ffi::SegmentationMaskRaw {
260            width: 0,
261            height: 0,
262            bytes_per_row: 0,
263            bytes: ptr::null_mut(),
264        };
265        let mut err: *mut c_char = ptr::null_mut();
266        // SAFETY: all pointer arguments are valid stack locations or bridge-owned handles; strings are valid C strings for the duration of the call.
267        let status = unsafe {
268            ffi::vn_optical_flow_tracker_track(
269                self.handle,
270                image_c.as_ptr(),
271                ptr::addr_of_mut!(raw).cast(),
272                &mut err,
273            )
274        };
275        if status != ffi::status::OK {
276            return Err(error_from_status(status, err));
277        }
278        let frame = copy_mask(&raw);
279        // SAFETY: `raw` was populated by the bridge and has not been freed yet; unique free site.
280        unsafe { ffi::vn_segmentation_mask_free(ptr::addr_of_mut!(raw).cast()) };
281        Ok(frame)
282    }
283}
284
285impl TranslationalImageTracker {
286    /// Create a new translational-registration tracker seeded with the
287    /// reference image.
288    ///
289    /// # Errors
290    ///
291    /// Returns [`VisionError`] if the image path is invalid, the image
292    /// fails to load, or Vision rejects the tracking request.
293    pub fn new(reference_path: impl AsRef<Path>) -> Result<Self, VisionError> {
294        let image_c = path_to_cstring(reference_path.as_ref(), "reference path")?;
295        let mut handle: *mut c_void = ptr::null_mut();
296        let mut err: *mut c_char = ptr::null_mut();
297        // SAFETY: all pointer arguments are valid stack locations or null-initialised out-params; strings are valid C strings for the duration of the call.
298        let status = unsafe {
299            ffi::vn_translational_image_tracker_create(image_c.as_ptr(), &mut handle, &mut err)
300        };
301        if status != ffi::status::OK {
302            return Err(error_from_status(status, err));
303        }
304        ensure_handle(handle, "translational tracker")?;
305        Ok(Self { handle })
306    }
307
308    /// Track the translational alignment into `image_path`.
309    ///
310    /// # Errors
311    ///
312    /// Returns [`VisionError`] if the image path is invalid, the image
313    /// fails to load, or Vision rejects the tracking request.
314    pub fn track(
315        &mut self,
316        image_path: impl AsRef<Path>,
317    ) -> Result<TranslationalAlignment, VisionError> {
318        let image_c = path_to_cstring(image_path.as_ref(), "image path")?;
319        let mut raw = ffi::TranslationalAlignmentRaw { tx: 0.0, ty: 0.0 };
320        let mut err: *mut c_char = ptr::null_mut();
321        // SAFETY: all pointer arguments are valid stack locations or bridge-owned handles; strings are valid C strings for the duration of the call.
322        let status = unsafe {
323            ffi::vn_translational_image_tracker_track(
324                self.handle,
325                image_c.as_ptr(),
326                ptr::addr_of_mut!(raw).cast(),
327                &mut err,
328            )
329        };
330        if status != ffi::status::OK {
331            return Err(error_from_status(status, err));
332        }
333        Ok(TranslationalAlignment {
334            tx: raw.tx,
335            ty: raw.ty,
336        })
337    }
338}
339
340impl HomographicImageTracker {
341    /// Create a new homographic-registration tracker seeded with the
342    /// reference image.
343    ///
344    /// # Errors
345    ///
346    /// Returns [`VisionError`] if the image path is invalid, the image
347    /// fails to load, or Vision rejects the tracking request.
348    pub fn new(reference_path: impl AsRef<Path>) -> Result<Self, VisionError> {
349        let image_c = path_to_cstring(reference_path.as_ref(), "reference path")?;
350        let mut handle: *mut c_void = ptr::null_mut();
351        let mut err: *mut c_char = ptr::null_mut();
352        // SAFETY: all pointer arguments are valid stack locations or null-initialised out-params; strings are valid C strings for the duration of the call.
353        let status = unsafe {
354            ffi::vn_homographic_image_tracker_create(image_c.as_ptr(), &mut handle, &mut err)
355        };
356        if status != ffi::status::OK {
357            return Err(error_from_status(status, err));
358        }
359        ensure_handle(handle, "homographic tracker")?;
360        Ok(Self { handle })
361    }
362
363    /// Track the homographic alignment into `image_path`.
364    ///
365    /// # Errors
366    ///
367    /// Returns [`VisionError`] if the image path is invalid, the image
368    /// fails to load, or Vision rejects the tracking request.
369    pub fn track(
370        &mut self,
371        image_path: impl AsRef<Path>,
372    ) -> Result<HomographicAlignment, VisionError> {
373        let image_c = path_to_cstring(image_path.as_ref(), "image path")?;
374        let mut raw = ffi::HomographicAlignmentRaw {
375            m00: 0.0,
376            m01: 0.0,
377            m02: 0.0,
378            m10: 0.0,
379            m11: 0.0,
380            m12: 0.0,
381            m20: 0.0,
382            m21: 0.0,
383            m22: 0.0,
384            _pad: 0.0,
385        };
386        let mut err: *mut c_char = ptr::null_mut();
387        // SAFETY: all pointer arguments are valid stack locations or bridge-owned handles; strings are valid C strings for the duration of the call.
388        let status = unsafe {
389            ffi::vn_homographic_image_tracker_track(
390                self.handle,
391                image_c.as_ptr(),
392                ptr::addr_of_mut!(raw).cast(),
393                &mut err,
394            )
395        };
396        if status != ffi::status::OK {
397            return Err(error_from_status(status, err));
398        }
399        Ok(HomographicAlignment {
400            matrix: [
401                [raw.m00, raw.m01, raw.m02],
402                [raw.m10, raw.m11, raw.m12],
403                [raw.m20, raw.m21, raw.m22],
404            ],
405        })
406    }
407}
408
/// Generates a `Drop` impl for a tracker type that passes its `handle`
/// field to the given release function, skipping null handles.
macro_rules! impl_tracker_drop {
    ($ty:ident, $release_fn:path) => {
        impl Drop for $ty {
            fn drop(&mut self) {
                if self.handle.is_null() {
                    return;
                }
                // SAFETY: the handle is non-null (checked above) and is
                // released only here, after which the field is cleared.
                unsafe { $release_fn(self.handle) };
                self.handle = ptr::null_mut();
            }
        }
    };
}
422
423impl_tracker_drop!(ObjectTracker, ffi::vn_object_tracker_release);
424impl_tracker_drop!(RectangleTracker, ffi::vn_rectangle_tracker_release);
425impl_tracker_drop!(OpticalFlowTracker, ffi::vn_optical_flow_tracker_release);
426impl_tracker_drop!(
427    TranslationalImageTracker,
428    ffi::vn_translational_image_tracker_release
429);
430impl_tracker_drop!(
431    HomographicImageTracker,
432    ffi::vn_homographic_image_tracker_release
433);
434
435fn path_to_cstring(path: &Path, label: &str) -> Result<CString, VisionError> {
436    let path_str = path
437        .to_str()
438        .ok_or_else(|| VisionError::InvalidArgument(format!("non-UTF-8 {label}")))?;
439    CString::new(path_str)
440        .map_err(|e| VisionError::InvalidArgument(format!("{label} NUL byte: {e}")))
441}
442
443const fn rectangle_to_raw(rect: &RectangleObservation) -> ffi::RectangleObservationRaw {
444    ffi::RectangleObservationRaw {
445        bbox_x: rect.bounding_box.x,
446        bbox_y: rect.bounding_box.y,
447        bbox_w: rect.bounding_box.width,
448        bbox_h: rect.bounding_box.height,
449        confidence: rect.confidence,
450        tl_x: rect.top_left.x,
451        tl_y: rect.top_left.y,
452        tr_x: rect.top_right.x,
453        tr_y: rect.top_right.y,
454        bl_x: rect.bottom_left.x,
455        bl_y: rect.bottom_left.y,
456        br_x: rect.bottom_right.x,
457        br_y: rect.bottom_right.y,
458    }
459}
460
461const fn rectangle_from_raw(raw: &ffi::RectangleObservationRaw) -> RectangleObservation {
462    RectangleObservation {
463        bounding_box: BoundingBox {
464            x: raw.bbox_x,
465            y: raw.bbox_y,
466            width: raw.bbox_w,
467            height: raw.bbox_h,
468        },
469        confidence: raw.confidence,
470        top_left: LandmarkPoint {
471            x: raw.tl_x,
472            y: raw.tl_y,
473        },
474        top_right: LandmarkPoint {
475            x: raw.tr_x,
476            y: raw.tr_y,
477        },
478        bottom_left: LandmarkPoint {
479            x: raw.bl_x,
480            y: raw.bl_y,
481        },
482        bottom_right: LandmarkPoint {
483            x: raw.br_x,
484            y: raw.br_y,
485        },
486    }
487}
488
489fn copy_mask(raw: &ffi::SegmentationMaskRaw) -> OpticalFlowFrame {
490    if raw.bytes.is_null() {
491        return OpticalFlowFrame {
492            width: raw.width,
493            height: raw.height,
494            bytes_per_row: raw.bytes_per_row,
495            bytes: Vec::new(),
496        };
497    }
498    let len = raw.height.saturating_mul(raw.bytes_per_row);
499    // SAFETY: `raw.bytes` is valid for `len` bytes as guaranteed by the Swift bridge.
500    let bytes = unsafe { core::slice::from_raw_parts(raw.bytes.cast::<u8>(), len) }.to_vec();
501    OpticalFlowFrame {
502        width: raw.width,
503        height: raw.height,
504        bytes_per_row: raw.bytes_per_row,
505        bytes,
506    }
507}
508
509fn ensure_handle(handle: *mut c_void, tracker_name: &str) -> Result<(), VisionError> {
510    if handle.is_null() {
511        return Err(VisionError::Unknown {
512            code: ffi::status::UNKNOWN,
513            message: format!("{tracker_name} returned a null handle"),
514        });
515    }
516    Ok(())
517}
518
519fn error_from_status(status: i32, err: *mut c_char) -> VisionError {
520    // SAFETY: the error pointer is either null or a bridge-allocated C string; `take_err` copies and frees it.
521    let message = unsafe { take_err(err) };
522    match status {
523        ffi::status::IMAGE_LOAD_FAILED => VisionError::ImageLoadFailed(message),
524        ffi::status::REQUEST_FAILED => VisionError::RequestFailed(message),
525        ffi::status::INVALID_ARGUMENT => VisionError::InvalidArgument(message),
526        code => VisionError::Unknown { code, message },
527    }
528}
529
530/// Extract an error string from a bridge-allocated C string and free it.
531///
532/// # Safety
533///
534/// `p` must be either null or a valid null-terminated C string heap-allocated
535/// (via `malloc`) by the Swift bridge. After this call `p` is invalid.
536unsafe fn take_err(p: *mut c_char) -> String {
537    if p.is_null() {
538        return String::new();
539    }
540    // SAFETY: the C string pointer is non-null (checked above) and valid for the duration of this borrow.
541    let s = unsafe { CStr::from_ptr(p) }.to_string_lossy().into_owned();
542    // SAFETY: `p` was malloc-allocated by the bridge and has not been freed yet.
543    unsafe { libc::free(p.cast()) };
544    s
545}