use core::{ffi::c_char, ptr};
use std::{
ffi::{CStr, CString},
path::{Path, PathBuf},
};
use crate::classify::Classification;
use crate::error::{from_swift, VisionError};
use crate::ffi;
use crate::request_base::ImageBasedRequest;
/// How the input image is cropped and/or scaled before it is handed to the model.
///
/// The discriminant of each variant is the raw value sent across the FFI
/// boundary (the request casts the selected variant with `as i32`).
// NOTE(review): the names and values look like they mirror Vision's
// `VNImageCropAndScaleOption` — confirm against the Swift bridge.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CoreMLImageCropAndScaleOption {
    /// Raw value `0`.
    CenterCrop = 0,
    /// Raw value `1`.
    ScaleFit = 1,
    /// Raw value `2`.
    ScaleFill = 2,
    /// Raw value `0x101`, i.e. `0x100` plus the `ScaleFit` value.
    ScaleFitRotate90CCW = 0x100 + 1,
    /// Raw value `0x102`, i.e. `0x100` plus the `ScaleFill` value.
    ScaleFillRotate90CCW = 0x100 + 2,
}
/// Describes a Core ML model by its location on disk, plus an optional
/// override for the name of the model's image input feature.
///
/// This type only stores the path; nothing is loaded or validated here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CoreMLModel {
    /// Path to the model, exactly as supplied by the caller.
    model_path: PathBuf,
    /// Name of the image input feature, if the caller set one.
    input_image_feature_name: Option<String>,
}

impl CoreMLModel {
    /// Builds a model description for `model_path` with no input feature
    /// name configured.
    #[must_use]
    pub fn new(model_path: impl AsRef<Path>) -> Self {
        let model_path = PathBuf::from(model_path.as_ref());
        Self {
            model_path,
            input_image_feature_name: None,
        }
    }

    /// Returns this model with the image input feature name set to
    /// `input_image_feature_name`, replacing any earlier value.
    #[must_use]
    pub fn with_input_image_feature_name(
        self,
        input_image_feature_name: impl Into<String>,
    ) -> Self {
        let feature_name: String = input_image_feature_name.into();
        Self {
            input_image_feature_name: Some(feature_name),
            ..self
        }
    }

    /// The model path given at construction time.
    #[must_use]
    pub fn model_path(&self) -> &Path {
        self.model_path.as_path()
    }

    /// The configured image input feature name, or `None` when unset.
    #[must_use]
    pub fn input_image_feature_name(&self) -> Option<&str> {
        self.input_image_feature_name.as_deref()
    }
}
/// An owned copy of a single feature value produced by a Core ML request.
///
/// The variant is chosen from the kind tag reported by the native bridge;
/// any tag without a dedicated variant becomes [`CoreMLFeatureValue::Unknown`].
#[derive(Debug, Clone, PartialEq)]
pub enum CoreMLFeatureValue {
    /// A 64-bit signed integer value.
    Int64(i64),
    /// A double-precision floating point value.
    Double(f64),
    /// A string value (empty when the bridge supplied no string pointer).
    String(String),
    /// A multi-dimensional array.
    MultiArray {
        /// Dimension sizes as reported by the bridge.
        shape: Vec<usize>,
        /// Element values as `f64`, in the order the bridge supplied them.
        // NOTE(review): presumably a flattened buffer whose length is the
        // product of `shape` — confirm against the Swift side.
        values: Vec<f64>,
    },
    /// A value of a kind this wrapper does not decode.
    Unknown {
        /// Type name reported by the bridge, or `"unknown"` when none was given.
        type_name: String,
    },
}
/// A feature value produced by a Core ML request, together with the name of
/// the feature it belongs to.
#[derive(Debug, Clone, PartialEq)]
pub struct CoreMLFeatureValueObservation {
    /// Name of the feature; `None` when the bridge reported no name.
    pub feature_name: Option<String>,
    /// The decoded value.
    pub value: CoreMLFeatureValue,
}
/// A configured Core ML request: which model to run, the shared image-request
/// settings, and how the input image is cropped/scaled.
#[derive(Debug, Clone, PartialEq)]
pub struct CoreMLRequest {
    /// Model the request runs.
    model: CoreMLModel,
    /// Settings shared with other image-based requests (region of interest,
    /// background-processing preference, CPU-only flag, revision).
    image_based: ImageBasedRequest,
    /// Crop/scale mode forwarded to the bridge as its raw `i32` value.
    image_crop_and_scale_option: CoreMLImageCropAndScaleOption,
}
impl CoreMLRequest {
    /// Creates a request for the model at `model_path` with a fresh
    /// [`ImageBasedRequest`] and [`CoreMLImageCropAndScaleOption::CenterCrop`].
    #[must_use]
    pub fn new(model_path: impl AsRef<Path>) -> Self {
        Self {
            model: CoreMLModel::new(model_path),
            image_based: ImageBasedRequest::new(),
            image_crop_and_scale_option: CoreMLImageCropAndScaleOption::CenterCrop,
        }
    }

    /// Replaces the model description used by this request.
    #[must_use]
    pub fn with_model(mut self, model: CoreMLModel) -> Self {
        self.model = model;
        self
    }

    /// Replaces the shared image-request settings.
    #[must_use]
    pub const fn with_image_based_request(mut self, image_based: ImageBasedRequest) -> Self {
        self.image_based = image_based;
        self
    }

    /// Replaces the crop/scale mode.
    #[must_use]
    pub const fn with_image_crop_and_scale_option(
        mut self,
        image_crop_and_scale_option: CoreMLImageCropAndScaleOption,
    ) -> Self {
        self.image_crop_and_scale_option = image_crop_and_scale_option;
        self
    }

    /// The shared image-request settings.
    #[must_use]
    pub const fn image_based_request(&self) -> &ImageBasedRequest {
        &self.image_based
    }

    /// The crop/scale mode.
    #[must_use]
    pub const fn image_crop_and_scale_option(&self) -> CoreMLImageCropAndScaleOption {
        self.image_crop_and_scale_option
    }

    /// The model description.
    #[must_use]
    pub const fn model(&self) -> &CoreMLModel {
        &self.model
    }

    /// Runs the model on the image at `image_path` through the native bridge
    /// and returns the classifications it reports.
    ///
    /// When no region of interest is configured, the bridge receives the
    /// rectangle `(0, 0, 1, 1)` together with a `false` "has ROI" flag.
    ///
    /// # Errors
    ///
    /// Returns [`VisionError::InvalidArgument`] when the image path or model
    /// path is not valid UTF-8 or contains a NUL byte, or when the input image
    /// feature name contains a NUL byte. Returns the error produced by
    /// `from_swift` when the bridge reports a non-OK status.
    pub fn classify(
        &self,
        image_path: impl AsRef<Path>,
    ) -> Result<Vec<Classification>, VisionError> {
        let image_c = path_to_cstring(image_path.as_ref(), "image path")?;
        let model_c = path_to_cstring(self.model.model_path(), "model path")?;
        let input_feature_c = self.input_feature_cstring()?;
        let roi = self.image_based.region_of_interest();

        let mut out_array = ptr::null_mut();
        let mut out_count = 0;
        let mut err_msg: *mut c_char = ptr::null_mut();
        // SAFETY: every string pointer is either null or borrowed from a
        // `CString` local that outlives the call, and the three out-pointers
        // refer to live locals declared just above.
        let status = unsafe {
            ffi::vn_coreml_request_classify_in_path(
                image_c.as_ptr(),
                model_c.as_ptr(),
                input_feature_c
                    .as_ref()
                    .map_or(ptr::null(), |name| name.as_ptr()),
                input_feature_c.is_some(),
                self.image_crop_and_scale_option as i32,
                roi.map_or(0.0, |rect| rect.x),
                roi.map_or(0.0, |rect| rect.y),
                roi.map_or(1.0, |rect| rect.width),
                roi.map_or(1.0, |rect| rect.height),
                roi.is_some(),
                self.image_based.prefer_background_processing(),
                self.image_based.uses_cpu_only(),
                self.image_based.revision().unwrap_or_default(),
                self.image_based.revision().is_some(),
                &mut out_array,
                &mut out_count,
                &mut err_msg,
            )
        };
        if status != ffi::status::OK {
            // SAFETY: `err_msg` is either still null or was set by the bridge
            // call above. NOTE(review): assumes `from_swift` accepts exactly
            // that — confirm against its contract.
            return Err(unsafe { from_swift(status, err_msg) });
        }
        Ok(collect_classifications(out_array, out_count))
    }

    /// Runs the model on the image at `image_path` through the native bridge
    /// and returns the feature value it reports, if any.
    ///
    /// Returns `Ok(None)` when the bridge succeeds but signals that it has no
    /// value. All strings and array buffers are copied into owned Rust values
    /// before the native result is released.
    ///
    /// # Errors
    ///
    /// Returns [`VisionError::InvalidArgument`] when the image path or model
    /// path is not valid UTF-8 or contains a NUL byte, or when the input image
    /// feature name contains a NUL byte. Returns the error produced by
    /// `from_swift` when the bridge reports a non-OK status.
    pub fn feature_value(
        &self,
        image_path: impl AsRef<Path>,
    ) -> Result<Option<CoreMLFeatureValueObservation>, VisionError> {
        let image_c = path_to_cstring(image_path.as_ref(), "image path")?;
        let model_c = path_to_cstring(self.model.model_path(), "model path")?;
        let input_feature_c = self.input_feature_cstring()?;
        let roi = self.image_based.region_of_interest();

        // Start from an all-null / all-zero result for the bridge to fill in.
        let mut raw = ffi::CoreMLFeatureValueRaw {
            feature_name: ptr::null_mut(),
            type_name: ptr::null_mut(),
            kind: 0,
            int64_value: 0,
            double_value: 0.0,
            string_value: ptr::null_mut(),
            multi_array_shape: ptr::null_mut(),
            multi_array_shape_count: 0,
            multi_array_values: ptr::null_mut(),
            multi_array_value_count: 0,
        };
        let mut has_value = false;
        let mut err_msg: *mut c_char = ptr::null_mut();
        // SAFETY: every string pointer is either null or borrowed from a
        // `CString` local that outlives the call, and the three out-pointers
        // refer to live locals declared just above.
        let status = unsafe {
            ffi::vn_coreml_feature_value_in_path(
                image_c.as_ptr(),
                model_c.as_ptr(),
                input_feature_c
                    .as_ref()
                    .map_or(ptr::null(), |name| name.as_ptr()),
                input_feature_c.is_some(),
                self.image_crop_and_scale_option as i32,
                roi.map_or(0.0, |rect| rect.x),
                roi.map_or(0.0, |rect| rect.y),
                roi.map_or(1.0, |rect| rect.width),
                roi.map_or(1.0, |rect| rect.height),
                roi.is_some(),
                self.image_based.prefer_background_processing(),
                self.image_based.uses_cpu_only(),
                self.image_based.revision().unwrap_or_default(),
                self.image_based.revision().is_some(),
                &mut raw,
                &mut has_value,
                &mut err_msg,
            )
        };
        if status != ffi::status::OK {
            // SAFETY: `err_msg` is either still null or was set by the bridge
            // call above. NOTE(review): assumes `from_swift` accepts exactly
            // that — confirm against its contract.
            return Err(unsafe { from_swift(status, err_msg) });
        }
        if !has_value {
            // NOTE(review): `raw` is not passed to the free function on this
            // path; presumably the bridge leaves it untouched when it reports
            // no value — confirm.
            return Ok(None);
        }

        let value = match raw.kind {
            1 => CoreMLFeatureValue::Int64(raw.int64_value),
            2 => CoreMLFeatureValue::Double(raw.double_value),
            3 => CoreMLFeatureValue::String(
                string_from_ptr(raw.string_value).unwrap_or_default(),
            ),
            4 => {
                // SAFETY: the bridge reported success with a value, so each
                // non-null buffer is taken to hold the advertised number of
                // elements; `raw` has not been freed yet.
                let shape = unsafe {
                    Self::copy_raw_slice(raw.multi_array_shape, raw.multi_array_shape_count)
                };
                // SAFETY: as for `shape` above.
                let values = unsafe {
                    Self::copy_raw_slice(raw.multi_array_values, raw.multi_array_value_count)
                };
                CoreMLFeatureValue::MultiArray { shape, values }
            }
            _ => CoreMLFeatureValue::Unknown {
                type_name: string_from_ptr(raw.type_name)
                    .unwrap_or_else(|| "unknown".to_string()),
            },
        };
        let observation = CoreMLFeatureValueObservation {
            feature_name: string_from_ptr(raw.feature_name),
            value,
        };
        // Everything needed has been copied out, so the native result can go.
        // SAFETY: `raw` was filled in by the successful bridge call above and
        // is released exactly once here.
        unsafe { ffi::vn_coreml_feature_value_free(&mut raw) };
        Ok(Some(observation))
    }

    /// Converts the model's optional input image feature name to a C string.
    ///
    /// Yields `Ok(None)` when no name is configured and
    /// [`VisionError::InvalidArgument`] when the name contains a NUL byte.
    fn input_feature_cstring(&self) -> Result<Option<CString>, VisionError> {
        self.model
            .input_image_feature_name()
            .map(|name| {
                CString::new(name).map_err(|err| {
                    VisionError::InvalidArgument(format!(
                        "input image feature name NUL byte: {err}"
                    ))
                })
            })
            .transpose()
    }

    /// Copies `len` elements starting at `data` into a new `Vec`; a null
    /// pointer or a zero length yields an empty `Vec`.
    ///
    /// # Safety
    ///
    /// When `data` is non-null and `len` is non-zero, `data` must point to
    /// `len` initialized, properly aligned values of `T` that remain valid
    /// for the duration of the call.
    unsafe fn copy_raw_slice<T: Copy>(data: *const T, len: usize) -> Vec<T> {
        if data.is_null() || len == 0 {
            return Vec::new();
        }
        // SAFETY: guaranteed by this function's contract (see `# Safety`).
        unsafe { std::slice::from_raw_parts(data, len) }.to_vec()
    }
}
/// Classifies the image at `image_path` with the model at `model_path`,
/// using a [`CoreMLRequest`] left at its defaults.
///
/// # Errors
///
/// Propagates whatever [`CoreMLRequest::classify`] returns.
pub fn coreml_classify_in_path(
    image_path: impl AsRef<Path>,
    model_path: impl AsRef<Path>,
) -> Result<Vec<Classification>, VisionError> {
    let request = CoreMLRequest::new(model_path);
    request.classify(image_path)
}
/// Fetches the feature value for the image at `image_path` from the model at
/// `model_path`, using a [`CoreMLRequest`] left at its defaults.
///
/// # Errors
///
/// Propagates whatever [`CoreMLRequest::feature_value`] returns.
pub fn coreml_feature_value_in_path(
    image_path: impl AsRef<Path>,
    model_path: impl AsRef<Path>,
) -> Result<Option<CoreMLFeatureValueObservation>, VisionError> {
    let request = CoreMLRequest::new(model_path);
    request.feature_value(image_path)
}
/// Copies `out_count` raw classification records starting at `out_array`
/// into owned [`Classification`] values, then releases the native buffer.
///
/// A null `out_array` or a zero `out_count` yields an empty `Vec` without
/// calling the free function. A null identifier pointer becomes an empty
/// string.
fn collect_classifications(
    out_array: *mut core::ffi::c_void,
    out_count: usize,
) -> Vec<Classification> {
    // NOTE(review): a non-null buffer reported with a zero count is returned
    // here without being freed — confirm the bridge never produces that.
    if out_count == 0 || out_array.is_null() {
        return Vec::new();
    }

    let first = out_array.cast::<ffi::ClassificationRaw>();
    let values: Vec<Classification> = (0..out_count)
        .map(|offset| {
            // SAFETY: `first` is non-null and the caller is trusted to pass a
            // buffer holding `out_count` records, so `offset` is in bounds.
            let entry = unsafe { &*first.add(offset) };
            Classification {
                identifier: string_from_ptr(entry.identifier).unwrap_or_default(),
                confidence: entry.confidence,
            }
        })
        .collect();

    // SAFETY: every record has been copied into owned data above, and the
    // buffer is handed back with the same pointer and count it arrived with.
    unsafe { ffi::vn_classifications_free(out_array, out_count) };
    values
}
/// Converts `path` into a NUL-terminated C string.
///
/// `label` names the argument in error messages (for example `"image path"`).
///
/// # Errors
///
/// Returns [`VisionError::InvalidArgument`] when `path` is not valid UTF-8 or
/// contains an interior NUL byte.
fn path_to_cstring(path: &Path, label: &str) -> Result<CString, VisionError> {
    match path.to_str() {
        None => Err(VisionError::InvalidArgument(format!("non-UTF-8 {label}"))),
        Some(utf8) => match CString::new(utf8) {
            Ok(c_string) => Ok(c_string),
            Err(err) => Err(VisionError::InvalidArgument(format!(
                "{label} NUL byte: {err}"
            ))),
        },
    }
}
/// Copies the C string at `ptr` into an owned `String`, replacing invalid
/// UTF-8 sequences with U+FFFD. Returns `None` for a null pointer.
///
/// The pointed-to memory is only read; it is neither modified nor freed.
fn string_from_ptr(ptr: *mut c_char) -> Option<String> {
    if ptr.is_null() {
        return None;
    }
    // SAFETY: `ptr` is non-null; callers are relied upon to pass a pointer to
    // a valid NUL-terminated string that stays alive for this call.
    let c_str = unsafe { CStr::from_ptr(ptr) };
    Some(c_str.to_string_lossy().into_owned())
}