1use core::{ffi::c_char, ptr};
5use std::{
6 ffi::{CStr, CString},
7 path::{Path, PathBuf},
8};
9
10use crate::classify::Classification;
11use crate::error::{from_swift, VisionError};
12use crate::ffi;
13use crate::request_base::ImageBasedRequest;
14
/// How the input image is cropped and scaled to fit the model's input.
///
/// Each variant's discriminant is the raw value forwarded to the native layer
/// (requests pass it as `option as i32`), so the numbers must not change.
/// The two rotating variants are `0x100` plus the matching non-rotating value
/// (`0x101` and `0x102`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CoreMLImageCropAndScaleOption {
    /// Raw value `0`.
    CenterCrop = 0,
    /// Raw value `1`.
    ScaleFit = 1,
    /// Raw value `2`.
    ScaleFill = 2,
    /// Raw value `257` (`0x101`).
    ScaleFitRotate90CCW = 257,
    /// Raw value `258` (`0x102`).
    ScaleFillRotate90CCW = 258,
}
24
/// A Core ML model on disk, optionally paired with the name of the model
/// input that should receive the image.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CoreMLModel {
    /// Filesystem location of the model.
    model_path: PathBuf,
    /// Name of the image input feature; `None` unless set through
    /// `with_input_image_feature_name`.
    input_image_feature_name: Option<String>,
}

impl CoreMLModel {
    /// Builds a model descriptor for `model_path` with no input feature name.
    #[must_use]
    pub fn new(model_path: impl AsRef<Path>) -> Self {
        let owned_path = PathBuf::from(model_path.as_ref());
        Self {
            model_path: owned_path,
            input_image_feature_name: None,
        }
    }

    /// Returns this descriptor with the image input feature name set to
    /// `input_image_feature_name`, replacing any earlier value.
    #[must_use]
    pub fn with_input_image_feature_name(
        self,
        input_image_feature_name: impl Into<String>,
    ) -> Self {
        let feature_name: String = input_image_feature_name.into();
        Self {
            input_image_feature_name: Some(feature_name),
            ..self
        }
    }

    /// Borrowed path to the model on disk.
    #[must_use]
    pub fn model_path(&self) -> &Path {
        self.model_path.as_path()
    }

    /// The configured image input feature name, if any.
    #[must_use]
    pub fn input_image_feature_name(&self) -> Option<&str> {
        self.input_image_feature_name.as_deref()
    }
}
60
/// A single feature value decoded from a Core ML prediction.
#[derive(Debug, Clone, PartialEq)]
pub enum CoreMLFeatureValue {
    /// Integer feature value.
    Int64(i64),
    /// Floating-point feature value.
    Double(f64),
    /// String feature value.
    String(String),
    /// Multi-dimensional array: its `shape` plus every element as `f64`.
    MultiArray {
        shape: Vec<usize>,
        values: Vec<f64>,
    },
    /// A value of a type this crate does not decode; `type_name` labels it.
    Unknown {
        type_name: String,
    },
}

/// A decoded feature value together with the name of the feature that
/// produced it (when one was reported).
#[derive(Debug, Clone, PartialEq)]
pub struct CoreMLFeatureValueObservation {
    /// Output feature name, if available.
    pub feature_name: Option<String>,
    /// The decoded value.
    pub value: CoreMLFeatureValue,
}
77
78#[derive(Debug, Clone, PartialEq)]
80pub struct CoreMLRequest {
81 model: CoreMLModel,
82 image_based: ImageBasedRequest,
83 image_crop_and_scale_option: CoreMLImageCropAndScaleOption,
84}
85
86impl CoreMLRequest {
87 #[must_use]
88 pub fn new(model_path: impl AsRef<Path>) -> Self {
89 Self {
90 model: CoreMLModel::new(model_path),
91 image_based: ImageBasedRequest::new(),
92 image_crop_and_scale_option: CoreMLImageCropAndScaleOption::CenterCrop,
93 }
94 }
95
96 #[must_use]
97 pub fn with_model(mut self, model: CoreMLModel) -> Self {
98 self.model = model;
99 self
100 }
101
102 #[must_use]
103 pub const fn with_image_based_request(mut self, image_based: ImageBasedRequest) -> Self {
104 self.image_based = image_based;
105 self
106 }
107
108 #[must_use]
109 pub const fn with_image_crop_and_scale_option(
110 mut self,
111 image_crop_and_scale_option: CoreMLImageCropAndScaleOption,
112 ) -> Self {
113 self.image_crop_and_scale_option = image_crop_and_scale_option;
114 self
115 }
116
117 #[must_use]
118 pub const fn image_based_request(&self) -> &ImageBasedRequest {
119 &self.image_based
120 }
121
122 #[must_use]
123 pub const fn image_crop_and_scale_option(&self) -> CoreMLImageCropAndScaleOption {
124 self.image_crop_and_scale_option
125 }
126
127 #[must_use]
128 pub const fn model(&self) -> &CoreMLModel {
129 &self.model
130 }
131
132 pub fn classify(
140 &self,
141 image_path: impl AsRef<Path>,
142 ) -> Result<Vec<Classification>, VisionError> {
143 let image_c = path_to_cstring(image_path.as_ref(), "image path")?;
144 let model_c = path_to_cstring(self.model.model_path(), "model path")?;
145 let input_feature_c = self
146 .model
147 .input_image_feature_name()
148 .map(|name| {
149 CString::new(name).map_err(|err| {
150 VisionError::InvalidArgument(format!(
151 "input image feature name NUL byte: {err}"
152 ))
153 })
154 })
155 .transpose()?;
156 let roi = self.image_based.region_of_interest();
157 let mut out_array = ptr::null_mut();
158 let mut out_count = 0;
159 let mut err_msg: *mut c_char = ptr::null_mut();
160 let status = unsafe {
162 ffi::vn_coreml_request_classify_in_path(
163 image_c.as_ptr(),
164 model_c.as_ptr(),
165 input_feature_c
166 .as_ref()
167 .map_or(ptr::null(), |name| name.as_ptr()),
168 input_feature_c.is_some(),
169 self.image_crop_and_scale_option as i32,
170 roi.map_or(0.0, |rect| rect.x),
171 roi.map_or(0.0, |rect| rect.y),
172 roi.map_or(1.0, |rect| rect.width),
173 roi.map_or(1.0, |rect| rect.height),
174 roi.is_some(),
175 self.image_based.prefer_background_processing(),
176 self.image_based.uses_cpu_only(),
177 self.image_based.revision().unwrap_or_default(),
178 self.image_based.revision().is_some(),
179 &mut out_array,
180 &mut out_count,
181 &mut err_msg,
182 )
183 };
184 if status != ffi::status::OK {
185 return Err(unsafe { from_swift(status, err_msg) });
187 }
188 Ok(collect_classifications(out_array, out_count))
189 }
190
191 pub fn feature_value(
199 &self,
200 image_path: impl AsRef<Path>,
201 ) -> Result<Option<CoreMLFeatureValueObservation>, VisionError> {
202 let image_c = path_to_cstring(image_path.as_ref(), "image path")?;
203 let model_c = path_to_cstring(self.model.model_path(), "model path")?;
204 let input_feature_c = self
205 .model
206 .input_image_feature_name()
207 .map(|name| {
208 CString::new(name).map_err(|err| {
209 VisionError::InvalidArgument(format!(
210 "input image feature name NUL byte: {err}"
211 ))
212 })
213 })
214 .transpose()?;
215 let roi = self.image_based.region_of_interest();
216 let mut raw = ffi::CoreMLFeatureValueRaw {
217 feature_name: ptr::null_mut(),
218 type_name: ptr::null_mut(),
219 kind: 0,
220 int64_value: 0,
221 double_value: 0.0,
222 string_value: ptr::null_mut(),
223 multi_array_shape: ptr::null_mut(),
224 multi_array_shape_count: 0,
225 multi_array_values: ptr::null_mut(),
226 multi_array_value_count: 0,
227 };
228 let mut has_value = false;
229 let mut err_msg: *mut c_char = ptr::null_mut();
230 let status = unsafe {
232 ffi::vn_coreml_feature_value_in_path(
233 image_c.as_ptr(),
234 model_c.as_ptr(),
235 input_feature_c
236 .as_ref()
237 .map_or(ptr::null(), |name| name.as_ptr()),
238 input_feature_c.is_some(),
239 self.image_crop_and_scale_option as i32,
240 roi.map_or(0.0, |rect| rect.x),
241 roi.map_or(0.0, |rect| rect.y),
242 roi.map_or(1.0, |rect| rect.width),
243 roi.map_or(1.0, |rect| rect.height),
244 roi.is_some(),
245 self.image_based.prefer_background_processing(),
246 self.image_based.uses_cpu_only(),
247 self.image_based.revision().unwrap_or_default(),
248 self.image_based.revision().is_some(),
249 &mut raw,
250 &mut has_value,
251 &mut err_msg,
252 )
253 };
254 if status != ffi::status::OK {
255 return Err(unsafe { from_swift(status, err_msg) });
257 }
258 if !has_value {
259 return Ok(None);
260 }
261 let observation = CoreMLFeatureValueObservation {
262 feature_name: string_from_ptr(raw.feature_name),
263 value: match raw.kind {
264 1 => CoreMLFeatureValue::Int64(raw.int64_value),
265 2 => CoreMLFeatureValue::Double(raw.double_value),
266 3 => CoreMLFeatureValue::String(
267 string_from_ptr(raw.string_value).unwrap_or_default(),
268 ),
269 4 => {
270 let shape =
271 if raw.multi_array_shape.is_null() || raw.multi_array_shape_count == 0 {
272 Vec::new()
273 } else {
274 unsafe {
276 std::slice::from_raw_parts(
277 raw.multi_array_shape,
278 raw.multi_array_shape_count,
279 )
280 }
281 .to_vec()
282 };
283 let values =
284 if raw.multi_array_values.is_null() || raw.multi_array_value_count == 0 {
285 Vec::new()
286 } else {
287 unsafe {
289 std::slice::from_raw_parts(
290 raw.multi_array_values,
291 raw.multi_array_value_count,
292 )
293 }
294 .to_vec()
295 };
296 CoreMLFeatureValue::MultiArray { shape, values }
297 }
298 _ => CoreMLFeatureValue::Unknown {
299 type_name: string_from_ptr(raw.type_name)
300 .unwrap_or_else(|| "unknown".to_string()),
301 },
302 },
303 };
304 unsafe { ffi::vn_coreml_feature_value_free(&mut raw) };
306 Ok(Some(observation))
307 }
308}
309
310pub fn coreml_classify_in_path(
316 image_path: impl AsRef<Path>,
317 model_path: impl AsRef<Path>,
318) -> Result<Vec<Classification>, VisionError> {
319 CoreMLRequest::new(model_path).classify(image_path)
320}
321
322pub fn coreml_feature_value_in_path(
328 image_path: impl AsRef<Path>,
329 model_path: impl AsRef<Path>,
330) -> Result<Option<CoreMLFeatureValueObservation>, VisionError> {
331 CoreMLRequest::new(model_path).feature_value(image_path)
332}
333
334fn collect_classifications(
335 out_array: *mut core::ffi::c_void,
336 out_count: usize,
337) -> Vec<Classification> {
338 if out_array.is_null() || out_count == 0 {
339 return Vec::new();
340 }
341 let typed = out_array.cast::<ffi::ClassificationRaw>();
342 let mut values = Vec::with_capacity(out_count);
343 for index in 0..out_count {
344 let raw = unsafe { &*typed.add(index) };
346 values.push(Classification {
347 identifier: string_from_ptr(raw.identifier).unwrap_or_default(),
348 confidence: raw.confidence,
349 });
350 }
351 unsafe { ffi::vn_classifications_free(out_array, out_count) };
353 values
354}
355
356fn path_to_cstring(path: &Path, label: &str) -> Result<CString, VisionError> {
357 let path = path
358 .to_str()
359 .ok_or_else(|| VisionError::InvalidArgument(format!("non-UTF-8 {label}")))?;
360 CString::new(path)
361 .map_err(|err| VisionError::InvalidArgument(format!("{label} NUL byte: {err}")))
362}
363
/// Copies a C string into an owned `String`, replacing invalid UTF-8 with
/// U+FFFD. Returns `None` for a null pointer.
fn string_from_ptr(ptr: *mut c_char) -> Option<String> {
    if ptr.is_null() {
        return None;
    }
    // SAFETY: callers pass pointers obtained from the native layer; assumed to
    // be NUL-terminated and valid for the duration of this call.
    let borrowed = unsafe { CStr::from_ptr(ptr) };
    Some(borrowed.to_string_lossy().into_owned())
}