Skip to main content

apple_vision/
geometry.rs

1#![allow(
2    clippy::cast_possible_truncation,
3    clippy::cast_precision_loss,
4    clippy::imprecise_flops,
5    clippy::missing_const_for_fn,
6    clippy::should_implement_trait,
7    clippy::suboptimal_flops
8)]
9//! Geometry wrappers and utility helpers mirroring Vision's `VNGeometry*` and
10//! `VNUtils` surfaces.
11
12use crate::{request_base::NormalizedRect, sdk::ElementType};
13
14/// A two-dimensional Vision point (`VNPoint`).
15#[derive(Debug, Clone, Copy, PartialEq)]
16#[repr(C)]
17pub struct VisionPoint {
18    pub x: f64,
19    pub y: f64,
20}
21
22impl VisionPoint {
23    #[must_use]
24    pub const fn new(x: f64, y: f64) -> Self {
25        Self { x, y }
26    }
27
28    #[must_use]
29    pub const fn zero() -> Self {
30        Self { x: 0.0, y: 0.0 }
31    }
32
33    #[must_use]
34    pub fn apply_vector(self, vector: VisionVector) -> Self {
35        Self {
36            x: self.x + vector.x,
37            y: self.y + vector.y,
38        }
39    }
40
41    #[must_use]
42    pub fn distance_to(self, point: Self) -> f64 {
43        ((point.x - self.x).powi(2) + (point.y - self.y).powi(2)).sqrt()
44    }
45}
46
47/// A two-dimensional Vision vector (`VNVector`).
48#[derive(Debug, Clone, Copy, PartialEq)]
49pub struct VisionVector {
50    pub x: f64,
51    pub y: f64,
52}
53
54impl VisionVector {
55    #[must_use]
56    pub const fn new(x: f64, y: f64) -> Self {
57        Self { x, y }
58    }
59
60    #[must_use]
61    pub const fn zero() -> Self {
62        Self { x: 0.0, y: 0.0 }
63    }
64
65    #[must_use]
66    pub fn from_points(head: VisionPoint, tail: VisionPoint) -> Self {
67        Self {
68            x: head.x - tail.x,
69            y: head.y - tail.y,
70        }
71    }
72
73    #[must_use]
74    pub fn unit(self) -> Self {
75        let length = self.length();
76        if length <= f64::EPSILON {
77            Self::zero()
78        } else {
79            Self {
80                x: self.x / length,
81                y: self.y / length,
82            }
83        }
84    }
85
86    #[must_use]
87    pub fn multiply(self, scalar: f64) -> Self {
88        Self {
89            x: self.x * scalar,
90            y: self.y * scalar,
91        }
92    }
93
94    #[must_use]
95    pub fn add(self, other: Self) -> Self {
96        Self {
97            x: self.x + other.x,
98            y: self.y + other.y,
99        }
100    }
101
102    #[must_use]
103    pub fn subtract(self, other: Self) -> Self {
104        Self {
105            x: self.x - other.x,
106            y: self.y - other.y,
107        }
108    }
109
110    #[must_use]
111    pub fn dot(self, other: Self) -> f64 {
112        self.x * other.x + self.y * other.y
113    }
114
115    #[must_use]
116    pub fn r(self) -> f64 {
117        self.length()
118    }
119
120    #[must_use]
121    pub fn theta(self) -> f64 {
122        self.y.atan2(self.x)
123    }
124
125    #[must_use]
126    pub fn length(self) -> f64 {
127        self.squared_length().sqrt()
128    }
129
130    #[must_use]
131    pub fn squared_length(self) -> f64 {
132        self.x.powi(2) + self.y.powi(2)
133    }
134}
135
136/// A two-dimensional Vision circle (`VNCircle`).
137#[derive(Debug, Clone, Copy, PartialEq)]
138pub struct VisionCircle {
139    pub center: VisionPoint,
140    pub radius: f64,
141}
142
143impl VisionCircle {
144    #[must_use]
145    pub const fn new(center: VisionPoint, radius: f64) -> Self {
146        Self { center, radius }
147    }
148
149    #[must_use]
150    pub const fn zero() -> Self {
151        Self {
152            center: VisionPoint::zero(),
153            radius: 0.0,
154        }
155    }
156
157    #[must_use]
158    pub const fn from_diameter(center: VisionPoint, diameter: f64) -> Self {
159        Self {
160            center,
161            radius: diameter / 2.0,
162        }
163    }
164
165    #[must_use]
166    pub const fn diameter(self) -> f64 {
167        self.radius * 2.0
168    }
169
170    #[must_use]
171    pub fn contains_point(self, point: VisionPoint) -> bool {
172        self.center.distance_to(point) <= self.radius + 1e-9
173    }
174
175    #[must_use]
176    pub fn contains_point_in_circumferential_ring(
177        self,
178        point: VisionPoint,
179        ring_width: f64,
180    ) -> bool {
181        let distance = self.center.distance_to(point);
182        let delta = ring_width.abs() / 2.0;
183        distance >= self.radius - delta - 1e-9 && distance <= self.radius + delta + 1e-9
184    }
185}
186
/// Column-major 4×4 transform used by Vision's 3-D point wrappers.
///
/// `columns[c][r]` addresses row `r` of column `c`; the translation lives in
/// the first three entries of `columns[3]`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Transform3D {
    /// The four columns of the matrix, each holding four `f32` rows.
    pub columns: [[f32; 4]; 4],
}

impl Transform3D {
    /// Returns the identity transform (a translation by zero).
    #[must_use]
    pub const fn identity() -> Self {
        Self::from_translation(0.0, 0.0, 0.0)
    }

    /// Returns a pure translation by `(x, y, z)`: identity everywhere except the last column.
    #[must_use]
    pub const fn from_translation(x: f32, y: f32, z: f32) -> Self {
        let columns = [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [x, y, z, 1.0],
        ];
        Self { columns }
    }

    /// Returns the `(x, y, z)` translation stored in the last column.
    #[must_use]
    pub const fn translation(self) -> (f32, f32, f32) {
        let [x, y, z, _] = self.columns[3];
        (x, y, z)
    }
}
223
224/// A three-dimensional Vision point (`VNPoint3D`).
225#[derive(Debug, Clone, Copy, PartialEq)]
226pub struct VisionPoint3D {
227    pub position: Transform3D,
228}
229
230impl VisionPoint3D {
231    #[must_use]
232    pub const fn new(position: Transform3D) -> Self {
233        Self { position }
234    }
235
236    #[must_use]
237    pub const fn from_xyz(x: f32, y: f32, z: f32) -> Self {
238        Self {
239            position: Transform3D::from_translation(x, y, z),
240        }
241    }
242
243    #[must_use]
244    pub const fn x(self) -> f32 {
245        self.position.columns[3][0]
246    }
247
248    #[must_use]
249    pub const fn y(self) -> f32 {
250        self.position.columns[3][1]
251    }
252
253    #[must_use]
254    pub const fn z(self) -> f32 {
255        self.position.columns[3][2]
256    }
257}
258
/// FFI-side point: two `f64` fields in C layout, used for the `CGPoint`-typed
/// arguments and results of the extern helpers declared in this file.
#[repr(C)]
#[derive(Copy, Clone)]
struct CGPointRaw {
    x: f64,
    y: f64,
}
265
/// FFI-side size: `width` then `height`, both `f64`, in C layout.
/// Only used as the `size` half of `CGRectRaw`.
#[repr(C)]
#[derive(Copy, Clone)]
struct CGSizeRaw {
    width: f64,
    height: f64,
}
272
273#[derive(Clone, Copy)]
274#[repr(C)]
275struct CGRectRaw {
276    origin: CGPointRaw,
277    size: CGSizeRaw,
278}
279
/// FFI-side pair of `f32`s passed as the face-landmark point to the two
/// face-landmark extern helpers declared in this file.
///
/// Derives `Clone`/`Copy` like the other raw FFI value types here, so it can
/// be handed to extern functions by value without giving up the original.
///
/// NOTE(review): Vision declares these parameters as `vector_float2`, a SIMD
/// vector type. A `#[repr(C)]` struct of two `f32`s may not be passed the same
/// way as a SIMD vector on every target — confirm the calling convention
/// (especially on arm64) before relying on the face-landmark helpers.
#[repr(C)]
#[derive(Clone, Copy)]
struct VectorFloat2Raw {
    x: f32,
    y: f32,
}
285
286extern "C" {
287    static VNNormalizedIdentityRect: CGRectRaw;
288
289    fn VNNormalizedRectIsIdentityRect(normalized_rect: CGRectRaw) -> bool;
290    fn VNImagePointForNormalizedPoint(
291        normalized_point: CGPointRaw,
292        image_width: usize,
293        image_height: usize,
294    ) -> CGPointRaw;
295    fn VNImagePointForNormalizedPointUsingRegionOfInterest(
296        normalized_point: CGPointRaw,
297        image_width: usize,
298        image_height: usize,
299        roi: CGRectRaw,
300    ) -> CGPointRaw;
301    fn VNNormalizedPointForImagePoint(
302        image_point: CGPointRaw,
303        image_width: usize,
304        image_height: usize,
305    ) -> CGPointRaw;
306    fn VNNormalizedPointForImagePointUsingRegionOfInterest(
307        image_point: CGPointRaw,
308        image_width: usize,
309        image_height: usize,
310        roi: CGRectRaw,
311    ) -> CGPointRaw;
312    fn VNImageRectForNormalizedRect(
313        normalized_rect: CGRectRaw,
314        image_width: usize,
315        image_height: usize,
316    ) -> CGRectRaw;
317    fn VNImageRectForNormalizedRectUsingRegionOfInterest(
318        normalized_rect: CGRectRaw,
319        image_width: usize,
320        image_height: usize,
321        roi: CGRectRaw,
322    ) -> CGRectRaw;
323    fn VNNormalizedRectForImageRect(
324        image_rect: CGRectRaw,
325        image_width: usize,
326        image_height: usize,
327    ) -> CGRectRaw;
328    fn VNNormalizedRectForImageRectUsingRegionOfInterest(
329        image_rect: CGRectRaw,
330        image_width: usize,
331        image_height: usize,
332        roi: CGRectRaw,
333    ) -> CGRectRaw;
334    fn VNNormalizedFaceBoundingBoxPointForLandmarkPoint(
335        face_landmark_point: VectorFloat2Raw,
336        face_bounding_box: CGRectRaw,
337        image_width: usize,
338        image_height: usize,
339    ) -> CGPointRaw;
340    fn VNImagePointForFaceLandmarkPoint(
341        face_landmark_point: VectorFloat2Raw,
342        face_bounding_box: CGRectRaw,
343        image_width: usize,
344        image_height: usize,
345    ) -> CGPointRaw;
346    fn VNElementTypeSize(element_type: usize) -> usize;
347}
348
349fn point_to_raw(point: VisionPoint) -> CGPointRaw {
350    CGPointRaw {
351        x: point.x,
352        y: point.y,
353    }
354}
355
356fn point_from_raw(raw: CGPointRaw) -> VisionPoint {
357    VisionPoint { x: raw.x, y: raw.y }
358}
359
360fn rect_to_raw(rect: NormalizedRect) -> CGRectRaw {
361    CGRectRaw {
362        origin: CGPointRaw {
363            x: rect.x,
364            y: rect.y,
365        },
366        size: CGSizeRaw {
367            width: rect.width,
368            height: rect.height,
369        },
370    }
371}
372
373fn rect_from_raw(raw: CGRectRaw) -> NormalizedRect {
374    NormalizedRect::new(raw.origin.x, raw.origin.y, raw.size.width, raw.size.height)
375}
376
377/// Mirrors `VNNormalizedIdentityRect`.
378#[must_use]
379pub fn normalized_identity_rect() -> NormalizedRect {
380    // SAFETY: `VNNormalizedIdentityRect` is a valid extern static provided by the Vision framework.
381    rect_from_raw(unsafe { VNNormalizedIdentityRect })
382}
383
384/// Mirrors `VNNormalizedRectIsIdentityRect`.
385#[must_use]
386pub fn normalized_rect_is_identity_rect(normalized_rect: NormalizedRect) -> bool {
387    // SAFETY: the arguments are plain value types passed by copy; the function is a pure math helper from the Vision framework.
388    unsafe { VNNormalizedRectIsIdentityRect(rect_to_raw(normalized_rect)) }
389}
390
391/// Mirrors `VNImagePointForNormalizedPoint`.
392#[must_use]
393pub fn image_point_for_normalized_point(
394    normalized_point: VisionPoint,
395    image_width: usize,
396    image_height: usize,
397) -> VisionPoint {
398    // SAFETY: the arguments are plain value types passed by copy; the function is a pure math helper from the Vision framework.
399    point_from_raw(unsafe {
400        VNImagePointForNormalizedPoint(point_to_raw(normalized_point), image_width, image_height)
401    })
402}
403
404/// Mirrors `VNImagePointForNormalizedPointUsingRegionOfInterest`.
405#[must_use]
406pub fn image_point_for_normalized_point_using_region_of_interest(
407    normalized_point: VisionPoint,
408    image_width: usize,
409    image_height: usize,
410    region_of_interest: NormalizedRect,
411) -> VisionPoint {
412    // SAFETY: the arguments are plain value types passed by copy; the function is a pure math helper from the Vision framework.
413    point_from_raw(unsafe {
414        VNImagePointForNormalizedPointUsingRegionOfInterest(
415            point_to_raw(normalized_point),
416            image_width,
417            image_height,
418            rect_to_raw(region_of_interest),
419        )
420    })
421}
422
423/// Mirrors `VNNormalizedPointForImagePoint`.
424#[must_use]
425pub fn normalized_point_for_image_point(
426    image_point: VisionPoint,
427    image_width: usize,
428    image_height: usize,
429) -> VisionPoint {
430    // SAFETY: the arguments are plain value types passed by copy; the function is a pure math helper from the Vision framework.
431    point_from_raw(unsafe {
432        VNNormalizedPointForImagePoint(point_to_raw(image_point), image_width, image_height)
433    })
434}
435
436/// Mirrors `VNNormalizedPointForImagePointUsingRegionOfInterest`.
437#[must_use]
438pub fn normalized_point_for_image_point_using_region_of_interest(
439    image_point: VisionPoint,
440    image_width: usize,
441    image_height: usize,
442    region_of_interest: NormalizedRect,
443) -> VisionPoint {
444    // SAFETY: the arguments are plain value types passed by copy; the function is a pure math helper from the Vision framework.
445    point_from_raw(unsafe {
446        VNNormalizedPointForImagePointUsingRegionOfInterest(
447            point_to_raw(image_point),
448            image_width,
449            image_height,
450            rect_to_raw(region_of_interest),
451        )
452    })
453}
454
455/// Mirrors `VNImageRectForNormalizedRect`.
456#[must_use]
457pub fn image_rect_for_normalized_rect(
458    normalized_rect: NormalizedRect,
459    image_width: usize,
460    image_height: usize,
461) -> NormalizedRect {
462    // SAFETY: the arguments are plain value types passed by copy; the function is a pure math helper from the Vision framework.
463    rect_from_raw(unsafe {
464        VNImageRectForNormalizedRect(rect_to_raw(normalized_rect), image_width, image_height)
465    })
466}
467
468/// Mirrors `VNImageRectForNormalizedRectUsingRegionOfInterest`.
469#[must_use]
470pub fn image_rect_for_normalized_rect_using_region_of_interest(
471    normalized_rect: NormalizedRect,
472    image_width: usize,
473    image_height: usize,
474    region_of_interest: NormalizedRect,
475) -> NormalizedRect {
476    // SAFETY: the arguments are plain value types passed by copy; the function is a pure math helper from the Vision framework.
477    rect_from_raw(unsafe {
478        VNImageRectForNormalizedRectUsingRegionOfInterest(
479            rect_to_raw(normalized_rect),
480            image_width,
481            image_height,
482            rect_to_raw(region_of_interest),
483        )
484    })
485}
486
487/// Mirrors `VNNormalizedRectForImageRect`.
488#[must_use]
489pub fn normalized_rect_for_image_rect(
490    image_rect: NormalizedRect,
491    image_width: usize,
492    image_height: usize,
493) -> NormalizedRect {
494    // SAFETY: the arguments are plain value types passed by copy; the function is a pure math helper from the Vision framework.
495    rect_from_raw(unsafe {
496        VNNormalizedRectForImageRect(rect_to_raw(image_rect), image_width, image_height)
497    })
498}
499
500/// Mirrors `VNNormalizedRectForImageRectUsingRegionOfInterest`.
501#[must_use]
502pub fn normalized_rect_for_image_rect_using_region_of_interest(
503    image_rect: NormalizedRect,
504    image_width: usize,
505    image_height: usize,
506    region_of_interest: NormalizedRect,
507) -> NormalizedRect {
508    // SAFETY: the arguments are plain value types passed by copy; the function is a pure math helper from the Vision framework.
509    rect_from_raw(unsafe {
510        VNNormalizedRectForImageRectUsingRegionOfInterest(
511            rect_to_raw(image_rect),
512            image_width,
513            image_height,
514            rect_to_raw(region_of_interest),
515        )
516    })
517}
518
519/// Mirrors `VNNormalizedFaceBoundingBoxPointForLandmarkPoint`.
520#[must_use]
521pub fn normalized_face_bounding_box_point_for_landmark_point(
522    face_landmark_point: VisionPoint,
523    face_bounding_box: NormalizedRect,
524    image_width: usize,
525    image_height: usize,
526) -> VisionPoint {
527    // SAFETY: the arguments are plain value types passed by copy; the function is a pure math helper from the Vision framework.
528    point_from_raw(unsafe {
529        VNNormalizedFaceBoundingBoxPointForLandmarkPoint(
530            VectorFloat2Raw {
531                x: face_landmark_point.x as f32,
532                y: face_landmark_point.y as f32,
533            },
534            rect_to_raw(face_bounding_box),
535            image_width,
536            image_height,
537        )
538    })
539}
540
541/// Mirrors `VNImagePointForFaceLandmarkPoint`.
542#[must_use]
543pub fn image_point_for_face_landmark_point(
544    face_landmark_point: VisionPoint,
545    face_bounding_box: NormalizedRect,
546    image_width: usize,
547    image_height: usize,
548) -> VisionPoint {
549    // SAFETY: the arguments are plain value types passed by copy; the function is a pure math helper from the Vision framework.
550    point_from_raw(unsafe {
551        VNImagePointForFaceLandmarkPoint(
552            VectorFloat2Raw {
553                x: face_landmark_point.x as f32,
554                y: face_landmark_point.y as f32,
555            },
556            rect_to_raw(face_bounding_box),
557            image_width,
558            image_height,
559        )
560    })
561}
562
563/// Mirrors `VNElementTypeSize`.
564#[must_use]
565pub fn element_type_size(element_type: ElementType) -> usize {
566    // SAFETY: the arguments are plain value types passed by copy; the function is a pure math helper from the Vision framework.
567    unsafe { VNElementTypeSize(element_type.as_raw()) }
568}
569
570/// Pure-Rust helpers mirroring `VNGeometryUtils`.
571pub struct VisionGeometryUtils;
572
573impl VisionGeometryUtils {
574    /// Compute a bounding circle covering every point.
575    #[must_use]
576    pub fn bounding_circle_for_points(points: &[VisionPoint]) -> Option<VisionCircle> {
577        minimal_enclosing_circle(points)
578    }
579
580    /// Compute a polygon area using Green's theorem.
581    #[must_use]
582    pub fn calculate_area(points: &[VisionPoint], oriented: bool) -> Option<f64> {
583        if points.len() < 3 {
584            return None;
585        }
586        let mut area = 0.0;
587        for index in 0..points.len() {
588            let next = points[(index + 1) % points.len()];
589            area += points[index].x * next.y - next.x * points[index].y;
590        }
591        area /= 2.0;
592        Some(if oriented { area } else { area.abs() })
593    }
594
595    /// Compute the closed polygon perimeter.
596    #[must_use]
597    pub fn calculate_perimeter(points: &[VisionPoint]) -> Option<f64> {
598        if points.len() < 2 {
599            return None;
600        }
601        let mut perimeter = 0.0;
602        for index in 0..points.len() {
603            perimeter += points[index].distance_to(points[(index + 1) % points.len()]);
604        }
605        Some(perimeter)
606    }
607}
608
609fn minimal_enclosing_circle(points: &[VisionPoint]) -> Option<VisionCircle> {
610    match points.len() {
611        0 => None,
612        1 => Some(VisionCircle::new(points[0], 0.0)),
613        _ => {
614            let mut best: Option<VisionCircle> = None;
615
616            for &point in points {
617                let candidate = VisionCircle::new(point, 0.0);
618                if contains_all(candidate, points) {
619                    best = Some(select_smaller(best, candidate));
620                }
621            }
622
623            for first in 0..points.len() {
624                for second in (first + 1)..points.len() {
625                    let candidate = circle_from_two(points[first], points[second]);
626                    if contains_all(candidate, points) {
627                        best = Some(select_smaller(best, candidate));
628                    }
629                }
630            }
631
632            for first in 0..points.len() {
633                for second in (first + 1)..points.len() {
634                    for third in (second + 1)..points.len() {
635                        if let Some(candidate) =
636                            circle_from_three(points[first], points[second], points[third])
637                        {
638                            if contains_all(candidate, points) {
639                                best = Some(select_smaller(best, candidate));
640                            }
641                        }
642                    }
643                }
644            }
645
646            best.or_else(|| {
647                let min_x = points
648                    .iter()
649                    .map(|point| point.x)
650                    .fold(f64::INFINITY, f64::min);
651                let max_x = points
652                    .iter()
653                    .map(|point| point.x)
654                    .fold(f64::NEG_INFINITY, f64::max);
655                let min_y = points
656                    .iter()
657                    .map(|point| point.y)
658                    .fold(f64::INFINITY, f64::min);
659                let max_y = points
660                    .iter()
661                    .map(|point| point.y)
662                    .fold(f64::NEG_INFINITY, f64::max);
663                let center = VisionPoint::new((min_x + max_x) / 2.0, (min_y + max_y) / 2.0);
664                Some(VisionCircle::new(
665                    center,
666                    center.distance_to(VisionPoint::new(max_x, max_y)),
667                ))
668            })
669        }
670    }
671}
672
673fn select_smaller(current: Option<VisionCircle>, candidate: VisionCircle) -> VisionCircle {
674    current.map_or(candidate, |existing| {
675        if candidate.radius < existing.radius {
676            candidate
677        } else {
678            existing
679        }
680    })
681}
682
683fn contains_all(circle: VisionCircle, points: &[VisionPoint]) -> bool {
684    points
685        .iter()
686        .copied()
687        .all(|point| circle.contains_point(point))
688}
689
690fn circle_from_two(first: VisionPoint, second: VisionPoint) -> VisionCircle {
691    let center = VisionPoint::new((first.x + second.x) / 2.0, (first.y + second.y) / 2.0);
692    VisionCircle::new(center, center.distance_to(first))
693}
694
695fn circle_from_three(
696    first: VisionPoint,
697    second: VisionPoint,
698    third: VisionPoint,
699) -> Option<VisionCircle> {
700    let d = 2.0
701        * (first.x * (second.y - third.y)
702            + second.x * (third.y - first.y)
703            + third.x * (first.y - second.y));
704    if d.abs() <= f64::EPSILON {
705        return None;
706    }
707
708    let first_sq = first.x.powi(2) + first.y.powi(2);
709    let second_sq = second.x.powi(2) + second.y.powi(2);
710    let third_sq = third.x.powi(2) + third.y.powi(2);
711
712    let ux = (first_sq * (second.y - third.y)
713        + second_sq * (third.y - first.y)
714        + third_sq * (first.y - second.y))
715        / d;
716    let uy = (first_sq * (third.x - second.x)
717        + second_sq * (first.x - third.x)
718        + third_sq * (second.x - first.x))
719        / d;
720    let center = VisionPoint::new(ux, uy);
721    Some(VisionCircle::new(center, center.distance_to(first)))
722}