Skip to main content

apple_vision/feature_print/
mod.rs

1//! Image feature print (`VNGenerateImageFeaturePrintRequest`) —
2//! semantic image embedding for content-based similarity.
3
4use core::ffi::c_char;
5use core::ptr;
6use std::ffi::CString;
7use std::path::Path;
8
9use crate::error::{from_swift, VisionError};
10use crate::ffi;
11
12/// A semantic image embedding produced by Apple's vision pipeline.
13///
14/// Distances between two prints (e.g. cosine or L2) measure
15/// content similarity — useful for clustering, deduplication, and
16/// content-based image search.
17#[derive(Debug, Clone, PartialEq)]
18#[allow(clippy::derive_partial_eq_without_eq)]
19pub struct FeaturePrint {
20    /// Underlying element type — `1 = Float32`, `2 = Float64`.
21    pub element_type: i32,
22    /// Vector dimensionality.
23    pub element_count: usize,
24    /// Raw element bytes (length = `element_count * 4` or `* 8`).
25    pub data: Vec<u8>,
26}
27
28impl FeaturePrint {
29    /// Decode the vector as `f32` (only valid when
30    /// `element_type == 1`).
31    #[must_use]
32    pub fn as_f32(&self) -> Option<Vec<f32>> {
33        if self.element_type != 1 {
34            return None;
35        }
36        let mut out = Vec::with_capacity(self.element_count);
37        for chunk in self.data.chunks_exact(4) {
38            let arr: [u8; 4] = chunk.try_into().ok()?;
39            out.push(f32::from_le_bytes(arr));
40        }
41        Some(out)
42    }
43
44    /// Decode the vector as `f64` (only valid when
45    /// `element_type == 2`).
46    #[must_use]
47    pub fn as_f64(&self) -> Option<Vec<f64>> {
48        if self.element_type != 2 {
49            return None;
50        }
51        let mut out = Vec::with_capacity(self.element_count);
52        for chunk in self.data.chunks_exact(8) {
53            let arr: [u8; 8] = chunk.try_into().ok()?;
54            out.push(f64::from_le_bytes(arr));
55        }
56        Some(out)
57    }
58
59    /// Compute Euclidean (L2) distance to another print. Smaller =
60    /// more similar.
61    ///
62    /// # Errors
63    ///
64    /// Returns [`VisionError::InvalidArgument`] if the two prints
65    /// have different element types or counts.
66    pub fn l2_distance(&self, other: &Self) -> Result<f64, VisionError> {
67        if self.element_type != other.element_type
68            || self.element_count != other.element_count
69        {
70            return Err(VisionError::InvalidArgument(
71                "feature print element type / count mismatch".into(),
72            ));
73        }
74        let sumsq: f64 = match self.element_type {
75            1 => self
76                .as_f32()
77                .unwrap_or_default()
78                .iter()
79                .zip(other.as_f32().unwrap_or_default().iter())
80                .map(|(a, b)| f64::from(a - b).powi(2))
81                .sum(),
82            2 => self
83                .as_f64()
84                .unwrap_or_default()
85                .iter()
86                .zip(other.as_f64().unwrap_or_default().iter())
87                .map(|(a, b)| (a - b).powi(2))
88                .sum(),
89            _ => 0.0,
90        };
91        Ok(sumsq.sqrt())
92    }
93}
94
95/// Generate a feature print for the image at `path`.
96///
97/// # Errors
98///
99/// Returns [`VisionError::ImageLoadFailed`] / [`VisionError::RequestFailed`].
100pub fn generate_image_feature_print_in_path(
101    path: impl AsRef<Path>,
102) -> Result<Option<FeaturePrint>, VisionError> {
103    let path_str = path
104        .as_ref()
105        .to_str()
106        .ok_or_else(|| VisionError::InvalidArgument("non-UTF-8 path".into()))?;
107    let path_c = CString::new(path_str)
108        .map_err(|e| VisionError::InvalidArgument(format!("path NUL byte: {e}")))?;
109
110    let mut raw = ffi::FeaturePrintRaw {
111        element_type: 0,
112        element_count: 0,
113        bytes: ptr::null_mut(),
114    };
115    let mut err_msg: *mut c_char = ptr::null_mut();
116    let status = unsafe {
117        ffi::vn_generate_image_feature_print_in_path(path_c.as_ptr(), &mut raw, &mut err_msg)
118    };
119    if status != ffi::status::OK {
120        return Err(unsafe { from_swift(status, err_msg) });
121    }
122    if raw.bytes.is_null() || raw.element_count == 0 {
123        return Ok(None);
124    }
125    let bytes_per_elem = match raw.element_type {
126        1 => 4_usize,
127        2 => 8_usize,
128        _ => 0_usize,
129    };
130    let len = raw.element_count.saturating_mul(bytes_per_elem);
131    let slice = unsafe { core::slice::from_raw_parts(raw.bytes.cast::<u8>(), len) };
132    let data = slice.to_vec();
133    unsafe { ffi::vn_feature_print_free(&mut raw) };
134
135    Ok(Some(FeaturePrint {
136        element_type: raw.element_type,
137        element_count: raw.element_count,
138        data,
139    }))
140}