Skip to main content

apple_vision/coreml/
mod.rs

//! `CoreML` inference through the Vision framework (`VNCoreMLModel`,
//! `VNCoreMLRequest`, and `VNCoreMLFeatureValueObservation`).
3
4use core::{ffi::c_char, ptr};
5use std::{
6    ffi::{CStr, CString},
7    path::{Path, PathBuf},
8};
9
10use crate::classify::Classification;
11use crate::error::{from_swift, VisionError};
12use crate::ffi;
13use crate::request_base::ImageBasedRequest;
14
/// Mirrors `VNImageCropAndScaleOption`.
///
/// The discriminants are the raw values of the corresponding Vision constants.
/// They are what the bridge receives (the value is passed across FFI as an
/// `i32`, hence `#[repr(i32)]`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[repr(i32)]
pub enum CoreMLImageCropAndScaleOption {
    /// `VNImageCropAndScaleOptionCenterCrop` (raw value `0`); the default, as
    /// used by `CoreMLRequest::new`.
    #[default]
    CenterCrop = 0,
    /// `VNImageCropAndScaleOptionScaleFit` (raw value `1`).
    ScaleFit = 1,
    /// `VNImageCropAndScaleOptionScaleFill` (raw value `2`).
    ScaleFill = 2,
    /// `VNImageCropAndScaleOptionScaleFitRotate90CCW` (raw value `0x101`).
    ScaleFitRotate90CCW = 0x101,
    /// `VNImageCropAndScaleOptionScaleFillRotate90CCW` (raw value `0x102`).
    ScaleFillRotate90CCW = 0x102,
}
24
/// A safe wrapper for `VNCoreMLModel`.
///
/// This value only records where the compiled model lives and, optionally,
/// which model input should receive the image. It does not open, load or
/// validate the model itself.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CoreMLModel {
    model_path: PathBuf,
    input_image_feature_name: Option<String>,
}

impl CoreMLModel {
    /// Describe the model stored at `model_path`, with no explicit input image
    /// feature name.
    #[must_use]
    pub fn new(model_path: impl AsRef<Path>) -> Self {
        let location: &Path = model_path.as_ref();
        Self {
            input_image_feature_name: None,
            model_path: PathBuf::from(location),
        }
    }

    /// Return this description with the image routed to the model input
    /// named `input_image_feature_name`.
    #[must_use]
    pub fn with_input_image_feature_name(
        self,
        input_image_feature_name: impl Into<String>,
    ) -> Self {
        Self {
            input_image_feature_name: Some(input_image_feature_name.into()),
            ..self
        }
    }

    /// Filesystem location of the model.
    #[must_use]
    pub fn model_path(&self) -> &Path {
        self.model_path.as_path()
    }

    /// Name of the model input that receives the image, if one was set.
    #[must_use]
    pub fn input_image_feature_name(&self) -> Option<&str> {
        self.input_image_feature_name.as_deref()
    }
}
60
/// A safe `MLFeatureValue` wrapper for `VNCoreMLFeatureValueObservation`.
///
/// Every variant owns its data, so values stay valid after the bridge's
/// buffers have been released.
#[derive(Debug, Clone, PartialEq)]
pub enum CoreMLFeatureValue {
    /// A 64-bit integer feature value.
    Int64(i64),
    /// A double-precision floating-point feature value.
    Double(f64),
    /// A string feature value.
    String(String),
    /// A multi-dimensional array feature value.
    MultiArray {
        /// Size of each dimension.
        shape: Vec<usize>,
        /// Element values as `f64`, in the order the bridge supplies them.
        /// NOTE(review): presumably the array flattened in row-major order — confirm.
        values: Vec<f64>,
    },
    /// A feature value whose type this wrapper does not decode.
    Unknown {
        /// Type name reported by the bridge (`"unknown"` when it reports none).
        type_name: String,
    },
}

/// A dedicated `VNCoreMLFeatureValueObservation` wrapper.
#[derive(Debug, Clone, PartialEq)]
pub struct CoreMLFeatureValueObservation {
    /// Name of the model output feature, when the bridge reports one.
    pub feature_name: Option<String>,
    /// The decoded feature value.
    pub value: CoreMLFeatureValue,
}
77
78/// A dedicated `VNCoreMLRequest` wrapper.
79#[derive(Debug, Clone, PartialEq)]
80pub struct CoreMLRequest {
81    model: CoreMLModel,
82    image_based: ImageBasedRequest,
83    image_crop_and_scale_option: CoreMLImageCropAndScaleOption,
84}
85
86impl CoreMLRequest {
87    #[must_use]
88    pub fn new(model_path: impl AsRef<Path>) -> Self {
89        Self {
90            model: CoreMLModel::new(model_path),
91            image_based: ImageBasedRequest::new(),
92            image_crop_and_scale_option: CoreMLImageCropAndScaleOption::CenterCrop,
93        }
94    }
95
96    #[must_use]
97    pub fn with_model(mut self, model: CoreMLModel) -> Self {
98        self.model = model;
99        self
100    }
101
102    #[must_use]
103    pub const fn with_image_based_request(mut self, image_based: ImageBasedRequest) -> Self {
104        self.image_based = image_based;
105        self
106    }
107
108    #[must_use]
109    pub const fn with_image_crop_and_scale_option(
110        mut self,
111        image_crop_and_scale_option: CoreMLImageCropAndScaleOption,
112    ) -> Self {
113        self.image_crop_and_scale_option = image_crop_and_scale_option;
114        self
115    }
116
117    #[must_use]
118    pub const fn image_based_request(&self) -> &ImageBasedRequest {
119        &self.image_based
120    }
121
122    #[must_use]
123    pub const fn image_crop_and_scale_option(&self) -> CoreMLImageCropAndScaleOption {
124        self.image_crop_and_scale_option
125    }
126
127    #[must_use]
128    pub const fn model(&self) -> &CoreMLModel {
129        &self.model
130    }
131
132    /// Run the request as a classifier and return `VNClassificationObservation`
133    /// values.
134    ///
135    /// # Errors
136    ///
137    /// Returns [`VisionError`] if the image/model cannot be loaded or Vision
138    /// rejects the request.
139    pub fn classify(
140        &self,
141        image_path: impl AsRef<Path>,
142    ) -> Result<Vec<Classification>, VisionError> {
143        let image_c = path_to_cstring(image_path.as_ref(), "image path")?;
144        let model_c = path_to_cstring(self.model.model_path(), "model path")?;
145        let input_feature_c = self
146            .model
147            .input_image_feature_name()
148            .map(|name| {
149                CString::new(name).map_err(|err| {
150                    VisionError::InvalidArgument(format!(
151                        "input image feature name NUL byte: {err}"
152                    ))
153                })
154            })
155            .transpose()?;
156        let roi = self.image_based.region_of_interest();
157        let mut out_array = ptr::null_mut();
158        let mut out_count = 0;
159        let mut err_msg: *mut c_char = ptr::null_mut();
160        // SAFETY: all pointer arguments are valid stack locations or bridge-owned handles; strings are valid C strings for the duration of the call.
161        let status = unsafe {
162            ffi::vn_coreml_request_classify_in_path(
163                image_c.as_ptr(),
164                model_c.as_ptr(),
165                input_feature_c
166                    .as_ref()
167                    .map_or(ptr::null(), |name| name.as_ptr()),
168                input_feature_c.is_some(),
169                self.image_crop_and_scale_option as i32,
170                roi.map_or(0.0, |rect| rect.x),
171                roi.map_or(0.0, |rect| rect.y),
172                roi.map_or(1.0, |rect| rect.width),
173                roi.map_or(1.0, |rect| rect.height),
174                roi.is_some(),
175                self.image_based.prefer_background_processing(),
176                self.image_based.uses_cpu_only(),
177                self.image_based.revision().unwrap_or_default(),
178                self.image_based.revision().is_some(),
179                &mut out_array,
180                &mut out_count,
181                &mut err_msg,
182            )
183        };
184        if status != ffi::status::OK {
185            // SAFETY: the error pointer is either null or a bridge-allocated C string; `from_swift` frees it.
186            return Err(unsafe { from_swift(status, err_msg) });
187        }
188        Ok(collect_classifications(out_array, out_count))
189    }
190
191    /// Run the request and return a dedicated
192    /// `VNCoreMLFeatureValueObservation`.
193    ///
194    /// # Errors
195    ///
196    /// Returns [`VisionError`] if the image/model cannot be loaded or Vision
197    /// rejects the request.
198    pub fn feature_value(
199        &self,
200        image_path: impl AsRef<Path>,
201    ) -> Result<Option<CoreMLFeatureValueObservation>, VisionError> {
202        let image_c = path_to_cstring(image_path.as_ref(), "image path")?;
203        let model_c = path_to_cstring(self.model.model_path(), "model path")?;
204        let input_feature_c = self
205            .model
206            .input_image_feature_name()
207            .map(|name| {
208                CString::new(name).map_err(|err| {
209                    VisionError::InvalidArgument(format!(
210                        "input image feature name NUL byte: {err}"
211                    ))
212                })
213            })
214            .transpose()?;
215        let roi = self.image_based.region_of_interest();
216        let mut raw = ffi::CoreMLFeatureValueRaw {
217            feature_name: ptr::null_mut(),
218            type_name: ptr::null_mut(),
219            kind: 0,
220            int64_value: 0,
221            double_value: 0.0,
222            string_value: ptr::null_mut(),
223            multi_array_shape: ptr::null_mut(),
224            multi_array_shape_count: 0,
225            multi_array_values: ptr::null_mut(),
226            multi_array_value_count: 0,
227        };
228        let mut has_value = false;
229        let mut err_msg: *mut c_char = ptr::null_mut();
230        // SAFETY: all pointer arguments are valid stack locations or bridge-owned handles; strings are valid C strings for the duration of the call.
231        let status = unsafe {
232            ffi::vn_coreml_feature_value_in_path(
233                image_c.as_ptr(),
234                model_c.as_ptr(),
235                input_feature_c
236                    .as_ref()
237                    .map_or(ptr::null(), |name| name.as_ptr()),
238                input_feature_c.is_some(),
239                self.image_crop_and_scale_option as i32,
240                roi.map_or(0.0, |rect| rect.x),
241                roi.map_or(0.0, |rect| rect.y),
242                roi.map_or(1.0, |rect| rect.width),
243                roi.map_or(1.0, |rect| rect.height),
244                roi.is_some(),
245                self.image_based.prefer_background_processing(),
246                self.image_based.uses_cpu_only(),
247                self.image_based.revision().unwrap_or_default(),
248                self.image_based.revision().is_some(),
249                &mut raw,
250                &mut has_value,
251                &mut err_msg,
252            )
253        };
254        if status != ffi::status::OK {
255            // SAFETY: the error pointer is either null or a bridge-allocated C string; `from_swift` frees it.
256            return Err(unsafe { from_swift(status, err_msg) });
257        }
258        if !has_value {
259            return Ok(None);
260        }
261        let observation = CoreMLFeatureValueObservation {
262            feature_name: string_from_ptr(raw.feature_name),
263            value: match raw.kind {
264                1 => CoreMLFeatureValue::Int64(raw.int64_value),
265                2 => CoreMLFeatureValue::Double(raw.double_value),
266                3 => CoreMLFeatureValue::String(
267                    string_from_ptr(raw.string_value).unwrap_or_default(),
268                ),
269                4 => {
270                    let shape =
271                        if raw.multi_array_shape.is_null() || raw.multi_array_shape_count == 0 {
272                            Vec::new()
273                        } else {
274                            // SAFETY: the pointer is valid for the reported element count as guaranteed by the bridge.
275                            unsafe {
276                                std::slice::from_raw_parts(
277                                    raw.multi_array_shape,
278                                    raw.multi_array_shape_count,
279                                )
280                            }
281                            .to_vec()
282                        };
283                    let values =
284                        if raw.multi_array_values.is_null() || raw.multi_array_value_count == 0 {
285                            Vec::new()
286                        } else {
287                            // SAFETY: the pointer is valid for the reported element count as guaranteed by the bridge.
288                            unsafe {
289                                std::slice::from_raw_parts(
290                                    raw.multi_array_values,
291                                    raw.multi_array_value_count,
292                                )
293                            }
294                            .to_vec()
295                        };
296                    CoreMLFeatureValue::MultiArray { shape, values }
297                }
298                _ => CoreMLFeatureValue::Unknown {
299                    type_name: string_from_ptr(raw.type_name)
300                        .unwrap_or_else(|| "unknown".to_string()),
301                },
302            },
303        };
304        // SAFETY: `raw` was populated by the bridge and has not been freed yet; unique free site.
305        unsafe { ffi::vn_coreml_feature_value_free(&mut raw) };
306        Ok(Some(observation))
307    }
308}
309
310/// Run a Core ML classifier model on the image at `path`.
311///
312/// # Errors
313///
314/// Returns [`VisionError::ImageLoadFailed`] / [`VisionError::RequestFailed`].
315pub fn coreml_classify_in_path(
316    image_path: impl AsRef<Path>,
317    model_path: impl AsRef<Path>,
318) -> Result<Vec<Classification>, VisionError> {
319    CoreMLRequest::new(model_path).classify(image_path)
320}
321
322/// Run a Core ML model that returns a feature value on the image at `path`.
323///
324/// # Errors
325///
326/// Returns [`VisionError::ImageLoadFailed`] / [`VisionError::RequestFailed`].
327pub fn coreml_feature_value_in_path(
328    image_path: impl AsRef<Path>,
329    model_path: impl AsRef<Path>,
330) -> Result<Option<CoreMLFeatureValueObservation>, VisionError> {
331    CoreMLRequest::new(model_path).feature_value(image_path)
332}
333
334fn collect_classifications(
335    out_array: *mut core::ffi::c_void,
336    out_count: usize,
337) -> Vec<Classification> {
338    if out_array.is_null() || out_count == 0 {
339        return Vec::new();
340    }
341    let typed = out_array.cast::<ffi::ClassificationRaw>();
342    let mut values = Vec::with_capacity(out_count);
343    for index in 0..out_count {
344        // SAFETY: the pointer is valid for the reported element count; the index is in bounds.
345        let raw = unsafe { &*typed.add(index) };
346        values.push(Classification {
347            identifier: string_from_ptr(raw.identifier).unwrap_or_default(),
348            confidence: raw.confidence,
349        });
350    }
351    // SAFETY: the pointer/count pair was allocated by the bridge and is freed exactly once here.
352    unsafe { ffi::vn_classifications_free(out_array, out_count) };
353    values
354}
355
356fn path_to_cstring(path: &Path, label: &str) -> Result<CString, VisionError> {
357    let path = path
358        .to_str()
359        .ok_or_else(|| VisionError::InvalidArgument(format!("non-UTF-8 {label}")))?;
360    CString::new(path)
361        .map_err(|err| VisionError::InvalidArgument(format!("{label} NUL byte: {err}")))
362}
363
/// Copy a bridge-provided C string into an owned `String`.
///
/// Returns `None` for a null pointer. Invalid UTF-8 sequences are replaced
/// with U+FFFD rather than rejected.
fn string_from_ptr(ptr: *mut c_char) -> Option<String> {
    if ptr.is_null() {
        return None;
    }
    // SAFETY: the C string pointer is non-null (checked above) and valid for the duration of this borrow.
    let borrowed = unsafe { CStr::from_ptr(ptr) };
    Some(String::from(borrowed.to_string_lossy()))
}