Skip to main content

apple_vision/tracking/
mod.rs

1#![allow(clippy::cast_possible_wrap, clippy::cast_sign_loss)]
2#![allow(clippy::too_long_first_doc_paragraph)]
3//! Stateful Vision tracking requests backed by retained Swift sessions.
4//!
5//! These wrappers keep the underlying Vision request alive across frames,
6//! which is required for object / rectangle tracking and the sequence-based
7//! optical-flow and image-registration trackers.
8
9use core::ffi::{c_char, c_void};
10use core::ptr;
11use std::ffi::{CStr, CString};
12use std::path::Path;
13
14use crate::error::VisionError;
15use crate::face_landmarks::LandmarkPoint;
16use crate::ffi;
17use crate::recognize_text::BoundingBox;
18use crate::rectangles::RectangleObservation;
19use crate::registration::{HomographicAlignment, TranslationalAlignment};
20
/// Owned copy of the optical-flow pixel buffer Vision produced for one frame.
///
/// The bytes are copied out of the bridge's buffer, so the frame stays valid
/// independently of the tracker that produced it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OpticalFlowFrame {
    /// Buffer width as reported by Vision.
    pub width: usize,
    /// Buffer height (number of rows) as reported by Vision.
    pub height: usize,
    /// Distance in bytes between the starts of consecutive rows.
    pub bytes_per_row: usize,
    /// Raw buffer contents, or empty when Vision returned no buffer.
    pub bytes: Vec<u8>,
}
29
30impl OpticalFlowFrame {
31    /// Borrow the copied raw pixel-buffer bytes.
32    #[must_use]
33    pub fn as_bytes(&self) -> &[u8] {
34        &self.bytes
35    }
36}
37
/// Follows a detected object's bounding box from frame to frame.
///
/// Wraps an opaque handle to a retained Swift/Vision tracking session. The
/// handle is a raw pointer, so the type is neither `Send` nor `Sync`.
pub struct ObjectTracker {
    /// Session handle from the bridge; non-null once construction succeeds.
    handle: *mut c_void,
}
42
/// Follows a previously detected rectangle observation from frame to frame.
///
/// Wraps an opaque handle to a retained Swift/Vision tracking session. The
/// handle is a raw pointer, so the type is neither `Send` nor `Sync`.
pub struct RectangleTracker {
    /// Session handle from the bridge; non-null once construction succeeds.
    handle: *mut c_void,
}
47
/// Computes dense optical flow between consecutive frames of a sequence.
///
/// Wraps an opaque handle to a retained Swift/Vision tracking session. The
/// handle is a raw pointer, so the type is neither `Send` nor `Sync`.
pub struct OpticalFlowTracker {
    /// Session handle from the bridge; non-null once construction succeeds.
    handle: *mut c_void,
}
52
/// Estimates translational image registration across a frame sequence.
///
/// Wraps an opaque handle to a retained Swift/Vision tracking session. The
/// handle is a raw pointer, so the type is neither `Send` nor `Sync`.
pub struct TranslationalImageTracker {
    /// Session handle from the bridge; non-null once construction succeeds.
    handle: *mut c_void,
}
57
/// Estimates homographic image registration across a frame sequence.
///
/// Wraps an opaque handle to a retained Swift/Vision tracking session. The
/// handle is a raw pointer, so the type is neither `Send` nor `Sync`.
pub struct HomographicImageTracker {
    /// Session handle from the bridge; non-null once construction succeeds.
    handle: *mut c_void,
}
62
63impl ObjectTracker {
64    /// Create a new object tracker seeded from `image_path` and `bbox`.
65    ///
66    /// # Errors
67    ///
68    /// Returns [`VisionError`] if the image path is invalid, the image
69    /// fails to load, or Vision rejects the tracking request.
70    pub fn new(image_path: impl AsRef<Path>, bbox: BoundingBox) -> Result<Self, VisionError> {
71        let image_c = path_to_cstring(image_path.as_ref(), "image path")?;
72        let mut raw_bbox = ffi::SimpleRectRaw {
73            x: bbox.x,
74            y: bbox.y,
75            w: bbox.width,
76            h: bbox.height,
77            confidence: 1.0,
78            _pad: 0.0,
79        };
80        let mut handle: *mut c_void = ptr::null_mut();
81        let mut err: *mut c_char = ptr::null_mut();
82        let status = unsafe {
83            ffi::vn_object_tracker_create(
84                image_c.as_ptr(),
85                ptr::addr_of_mut!(raw_bbox).cast(),
86                &mut handle,
87                &mut err,
88            )
89        };
90        if status != ffi::status::OK {
91            return Err(error_from_status(status, err));
92        }
93        ensure_handle(handle, "object tracker")?;
94        Ok(Self { handle })
95    }
96
97    /// Track the object into `image_path` and return the updated bounding
98    /// box.
99    ///
100    /// # Errors
101    ///
102    /// Returns [`VisionError`] if the image path is invalid, the image
103    /// fails to load, or Vision rejects the tracking request.
104    pub fn track(&mut self, image_path: impl AsRef<Path>) -> Result<BoundingBox, VisionError> {
105        let image_c = path_to_cstring(image_path.as_ref(), "image path")?;
106        let mut raw_bbox = ffi::SimpleRectRaw {
107            x: 0.0,
108            y: 0.0,
109            w: 0.0,
110            h: 0.0,
111            confidence: 0.0,
112            _pad: 0.0,
113        };
114        let mut err: *mut c_char = ptr::null_mut();
115        let status = unsafe {
116            ffi::vn_object_tracker_track(
117                self.handle,
118                image_c.as_ptr(),
119                ptr::addr_of_mut!(raw_bbox).cast(),
120                &mut err,
121            )
122        };
123        if status != ffi::status::OK {
124            return Err(error_from_status(status, err));
125        }
126        Ok(BoundingBox {
127            x: raw_bbox.x,
128            y: raw_bbox.y,
129            width: raw_bbox.w,
130            height: raw_bbox.h,
131        })
132    }
133}
134
135impl RectangleTracker {
136    /// Create a new rectangle tracker seeded from `image_path` and the
137    /// known rectangle observation for that frame.
138    ///
139    /// # Errors
140    ///
141    /// Returns [`VisionError`] if the image path is invalid, the image
142    /// fails to load, or Vision rejects the tracking request.
143    pub fn new(
144        image_path: impl AsRef<Path>,
145        rect_observation: &RectangleObservation,
146    ) -> Result<Self, VisionError> {
147        let image_c = path_to_cstring(image_path.as_ref(), "image path")?;
148        let mut raw = rectangle_to_raw(rect_observation);
149        let mut handle: *mut c_void = ptr::null_mut();
150        let mut err: *mut c_char = ptr::null_mut();
151        let status = unsafe {
152            ffi::vn_rectangle_tracker_create(
153                image_c.as_ptr(),
154                ptr::addr_of_mut!(raw).cast(),
155                &mut handle,
156                &mut err,
157            )
158        };
159        if status != ffi::status::OK {
160            return Err(error_from_status(status, err));
161        }
162        ensure_handle(handle, "rectangle tracker")?;
163        Ok(Self { handle })
164    }
165
166    /// Track the rectangle into `image_path` and return the updated
167    /// rectangle observation.
168    ///
169    /// # Errors
170    ///
171    /// Returns [`VisionError`] if the image path is invalid, the image
172    /// fails to load, or Vision rejects the tracking request.
173    pub fn track(
174        &mut self,
175        image_path: impl AsRef<Path>,
176    ) -> Result<RectangleObservation, VisionError> {
177        let image_c = path_to_cstring(image_path.as_ref(), "image path")?;
178        let mut raw = ffi::RectangleObservationRaw {
179            bbox_x: 0.0,
180            bbox_y: 0.0,
181            bbox_w: 0.0,
182            bbox_h: 0.0,
183            confidence: 0.0,
184            tl_x: 0.0,
185            tl_y: 0.0,
186            tr_x: 0.0,
187            tr_y: 0.0,
188            bl_x: 0.0,
189            bl_y: 0.0,
190            br_x: 0.0,
191            br_y: 0.0,
192        };
193        let mut err: *mut c_char = ptr::null_mut();
194        let status = unsafe {
195            ffi::vn_rectangle_tracker_track(
196                self.handle,
197                image_c.as_ptr(),
198                ptr::addr_of_mut!(raw).cast(),
199                &mut err,
200            )
201        };
202        if status != ffi::status::OK {
203            return Err(error_from_status(status, err));
204        }
205        Ok(rectangle_from_raw(&raw))
206    }
207}
208
209impl OpticalFlowTracker {
210    /// Create a new optical-flow tracker seeded with the reference image.
211    ///
212    /// # Errors
213    ///
214    /// Returns [`VisionError`] if the image path is invalid, the image
215    /// fails to load, or Vision rejects the tracking request.
216    pub fn new(reference_path: impl AsRef<Path>) -> Result<Self, VisionError> {
217        let image_c = path_to_cstring(reference_path.as_ref(), "reference path")?;
218        let mut handle: *mut c_void = ptr::null_mut();
219        let mut err: *mut c_char = ptr::null_mut();
220        let status = unsafe {
221            ffi::vn_optical_flow_tracker_create(image_c.as_ptr(), &mut handle, &mut err)
222        };
223        if status != ffi::status::OK {
224            return Err(error_from_status(status, err));
225        }
226        ensure_handle(handle, "optical-flow tracker")?;
227        Ok(Self { handle })
228    }
229
230    /// Track optical flow into `image_path` and return the copied raw
231    /// pixel-buffer bytes.
232    ///
233    /// # Errors
234    ///
235    /// Returns [`VisionError`] if the image path is invalid, the image
236    /// fails to load, or Vision rejects the tracking request.
237    pub fn track(&mut self, image_path: impl AsRef<Path>) -> Result<OpticalFlowFrame, VisionError> {
238        let image_c = path_to_cstring(image_path.as_ref(), "image path")?;
239        let mut raw = ffi::SegmentationMaskRaw {
240            width: 0,
241            height: 0,
242            bytes_per_row: 0,
243            bytes: ptr::null_mut(),
244        };
245        let mut err: *mut c_char = ptr::null_mut();
246        let status = unsafe {
247            ffi::vn_optical_flow_tracker_track(
248                self.handle,
249                image_c.as_ptr(),
250                ptr::addr_of_mut!(raw).cast(),
251                &mut err,
252            )
253        };
254        if status != ffi::status::OK {
255            return Err(error_from_status(status, err));
256        }
257        let frame = copy_mask(&raw);
258        unsafe { ffi::vn_segmentation_mask_free(ptr::addr_of_mut!(raw).cast()) };
259        Ok(frame)
260    }
261}
262
263impl TranslationalImageTracker {
264    /// Create a new translational-registration tracker seeded with the
265    /// reference image.
266    ///
267    /// # Errors
268    ///
269    /// Returns [`VisionError`] if the image path is invalid, the image
270    /// fails to load, or Vision rejects the tracking request.
271    pub fn new(reference_path: impl AsRef<Path>) -> Result<Self, VisionError> {
272        let image_c = path_to_cstring(reference_path.as_ref(), "reference path")?;
273        let mut handle: *mut c_void = ptr::null_mut();
274        let mut err: *mut c_char = ptr::null_mut();
275        let status = unsafe {
276            ffi::vn_translational_image_tracker_create(image_c.as_ptr(), &mut handle, &mut err)
277        };
278        if status != ffi::status::OK {
279            return Err(error_from_status(status, err));
280        }
281        ensure_handle(handle, "translational tracker")?;
282        Ok(Self { handle })
283    }
284
285    /// Track the translational alignment into `image_path`.
286    ///
287    /// # Errors
288    ///
289    /// Returns [`VisionError`] if the image path is invalid, the image
290    /// fails to load, or Vision rejects the tracking request.
291    pub fn track(
292        &mut self,
293        image_path: impl AsRef<Path>,
294    ) -> Result<TranslationalAlignment, VisionError> {
295        let image_c = path_to_cstring(image_path.as_ref(), "image path")?;
296        let mut raw = ffi::TranslationalAlignmentRaw { tx: 0.0, ty: 0.0 };
297        let mut err: *mut c_char = ptr::null_mut();
298        let status = unsafe {
299            ffi::vn_translational_image_tracker_track(
300                self.handle,
301                image_c.as_ptr(),
302                ptr::addr_of_mut!(raw).cast(),
303                &mut err,
304            )
305        };
306        if status != ffi::status::OK {
307            return Err(error_from_status(status, err));
308        }
309        Ok(TranslationalAlignment {
310            tx: raw.tx,
311            ty: raw.ty,
312        })
313    }
314}
315
316impl HomographicImageTracker {
317    /// Create a new homographic-registration tracker seeded with the
318    /// reference image.
319    ///
320    /// # Errors
321    ///
322    /// Returns [`VisionError`] if the image path is invalid, the image
323    /// fails to load, or Vision rejects the tracking request.
324    pub fn new(reference_path: impl AsRef<Path>) -> Result<Self, VisionError> {
325        let image_c = path_to_cstring(reference_path.as_ref(), "reference path")?;
326        let mut handle: *mut c_void = ptr::null_mut();
327        let mut err: *mut c_char = ptr::null_mut();
328        let status = unsafe {
329            ffi::vn_homographic_image_tracker_create(image_c.as_ptr(), &mut handle, &mut err)
330        };
331        if status != ffi::status::OK {
332            return Err(error_from_status(status, err));
333        }
334        ensure_handle(handle, "homographic tracker")?;
335        Ok(Self { handle })
336    }
337
338    /// Track the homographic alignment into `image_path`.
339    ///
340    /// # Errors
341    ///
342    /// Returns [`VisionError`] if the image path is invalid, the image
343    /// fails to load, or Vision rejects the tracking request.
344    pub fn track(
345        &mut self,
346        image_path: impl AsRef<Path>,
347    ) -> Result<HomographicAlignment, VisionError> {
348        let image_c = path_to_cstring(image_path.as_ref(), "image path")?;
349        let mut raw = ffi::HomographicAlignmentRaw {
350            m00: 0.0,
351            m01: 0.0,
352            m02: 0.0,
353            m10: 0.0,
354            m11: 0.0,
355            m12: 0.0,
356            m20: 0.0,
357            m21: 0.0,
358            m22: 0.0,
359            _pad: 0.0,
360        };
361        let mut err: *mut c_char = ptr::null_mut();
362        let status = unsafe {
363            ffi::vn_homographic_image_tracker_track(
364                self.handle,
365                image_c.as_ptr(),
366                ptr::addr_of_mut!(raw).cast(),
367                &mut err,
368            )
369        };
370        if status != ffi::status::OK {
371            return Err(error_from_status(status, err));
372        }
373        Ok(HomographicAlignment {
374            matrix: [
375                [raw.m00, raw.m01, raw.m02],
376                [raw.m10, raw.m11, raw.m12],
377                [raw.m20, raw.m21, raw.m22],
378            ],
379        })
380    }
381}
382
383macro_rules! impl_tracker_drop {
384    ($tracker:ident, $release:path) => {
385        impl Drop for $tracker {
386            fn drop(&mut self) {
387                if !self.handle.is_null() {
388                    unsafe { $release(self.handle) };
389                    self.handle = ptr::null_mut();
390                }
391            }
392        }
393    };
394}
395
396impl_tracker_drop!(ObjectTracker, ffi::vn_object_tracker_release);
397impl_tracker_drop!(RectangleTracker, ffi::vn_rectangle_tracker_release);
398impl_tracker_drop!(OpticalFlowTracker, ffi::vn_optical_flow_tracker_release);
399impl_tracker_drop!(TranslationalImageTracker, ffi::vn_translational_image_tracker_release);
400impl_tracker_drop!(HomographicImageTracker, ffi::vn_homographic_image_tracker_release);
401
402fn path_to_cstring(path: &Path, label: &str) -> Result<CString, VisionError> {
403    let path_str = path
404        .to_str()
405        .ok_or_else(|| VisionError::InvalidArgument(format!("non-UTF-8 {label}")))?;
406    CString::new(path_str).map_err(|e| VisionError::InvalidArgument(format!("{label} NUL byte: {e}")))
407}
408
409const fn rectangle_to_raw(rect: &RectangleObservation) -> ffi::RectangleObservationRaw {
410    ffi::RectangleObservationRaw {
411        bbox_x: rect.bounding_box.x,
412        bbox_y: rect.bounding_box.y,
413        bbox_w: rect.bounding_box.width,
414        bbox_h: rect.bounding_box.height,
415        confidence: rect.confidence,
416        tl_x: rect.top_left.x,
417        tl_y: rect.top_left.y,
418        tr_x: rect.top_right.x,
419        tr_y: rect.top_right.y,
420        bl_x: rect.bottom_left.x,
421        bl_y: rect.bottom_left.y,
422        br_x: rect.bottom_right.x,
423        br_y: rect.bottom_right.y,
424    }
425}
426
427const fn rectangle_from_raw(raw: &ffi::RectangleObservationRaw) -> RectangleObservation {
428    RectangleObservation {
429        bounding_box: BoundingBox {
430            x: raw.bbox_x,
431            y: raw.bbox_y,
432            width: raw.bbox_w,
433            height: raw.bbox_h,
434        },
435        confidence: raw.confidence,
436        top_left: LandmarkPoint {
437            x: raw.tl_x,
438            y: raw.tl_y,
439        },
440        top_right: LandmarkPoint {
441            x: raw.tr_x,
442            y: raw.tr_y,
443        },
444        bottom_left: LandmarkPoint {
445            x: raw.bl_x,
446            y: raw.bl_y,
447        },
448        bottom_right: LandmarkPoint {
449            x: raw.br_x,
450            y: raw.br_y,
451        },
452    }
453}
454
455fn copy_mask(raw: &ffi::SegmentationMaskRaw) -> OpticalFlowFrame {
456    if raw.bytes.is_null() {
457        return OpticalFlowFrame {
458            width: raw.width,
459            height: raw.height,
460            bytes_per_row: raw.bytes_per_row,
461            bytes: Vec::new(),
462        };
463    }
464    let len = raw.height.saturating_mul(raw.bytes_per_row);
465    let bytes = unsafe { core::slice::from_raw_parts(raw.bytes.cast::<u8>(), len) }.to_vec();
466    OpticalFlowFrame {
467        width: raw.width,
468        height: raw.height,
469        bytes_per_row: raw.bytes_per_row,
470        bytes,
471    }
472}
473
474fn ensure_handle(handle: *mut c_void, tracker_name: &str) -> Result<(), VisionError> {
475    if handle.is_null() {
476        return Err(VisionError::Unknown {
477            code: ffi::status::UNKNOWN,
478            message: format!("{tracker_name} returned a null handle"),
479        });
480    }
481    Ok(())
482}
483
484fn error_from_status(status: i32, err: *mut c_char) -> VisionError {
485    let message = unsafe { take_err(err) };
486    match status {
487        ffi::status::IMAGE_LOAD_FAILED => VisionError::ImageLoadFailed(message),
488        ffi::status::REQUEST_FAILED => VisionError::RequestFailed(message),
489        ffi::status::INVALID_ARGUMENT => VisionError::InvalidArgument(message),
490        code => VisionError::Unknown { code, message },
491    }
492}
493
494unsafe fn take_err(p: *mut c_char) -> String {
495    if p.is_null() {
496        return String::new();
497    }
498    let s = unsafe { CStr::from_ptr(p) }.to_string_lossy().into_owned();
499    unsafe { libc::free(p.cast()) };
500    s
501}