use crate::error::{Error, ErrorKind, Result};
use crate::tensor::AsMultiArray;
/// A batch of CoreML input feature sets, backed by an `MLArrayBatchProvider`.
///
/// Built with [`BatchProvider::new`] from one list of named tensors per batch entry.
#[cfg(target_vendor = "apple")]
pub struct BatchProvider {
    /// The wrapped CoreML batch provider.
    pub(crate) inner: objc2::rc::Retained<objc2_core_ml::MLArrayBatchProvider>,
}
#[cfg(target_vendor = "apple")]
impl BatchProvider {
    /// Builds a CoreML batch provider from a list of named-input sets.
    ///
    /// Each element of `inputs` becomes one `MLDictionaryFeatureProvider`. Its keys are the
    /// given names, and each tensor is wrapped in an `MLFeatureValue`. The providers are
    /// then collected, in order, into an `MLArrayBatchProvider`.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::Prediction`] error if CoreML rejects one of the input
    /// dictionaries. Processing stops at the first rejected set.
    pub fn new(inputs: &[&[(&str, &dyn AsMultiArray)]]) -> Result<Self> {
        use objc2::rc::Retained;
        use objc2::runtime::{AnyObject, ProtocolObject};
        use objc2::AnyThread;
        use objc2_core_ml::{
            MLArrayBatchProvider, MLDictionaryFeatureProvider, MLFeatureProvider, MLFeatureValue,
        };
        use objc2_foundation::{NSArray, NSDictionary, NSString};

        let providers = inputs
            .iter()
            .map(|pairs| -> Result<Retained<ProtocolObject<dyn MLFeatureProvider>>> {
                let (names, values): (Vec<Retained<NSString>>, Vec<Retained<MLFeatureValue>>) =
                    pairs
                        .iter()
                        .map(|&(name, tensor)| {
                            let key = crate::ffi::str_to_nsstring(name);
                            // SAFETY: objc2 marks this constructor unsafe; the multi-array
                            // reference is valid for the duration of the call.
                            let value = unsafe {
                                MLFeatureValue::featureValueWithMultiArray(
                                    tensor.as_ml_multi_array(),
                                )
                            };
                            (key, value)
                        })
                        .unzip();

                let name_refs: Vec<&NSString> = names.iter().map(|n| &**n).collect();
                let value_refs: Vec<&MLFeatureValue> = values.iter().map(|v| &**v).collect();
                let dict: Retained<NSDictionary<NSString, MLFeatureValue>> =
                    NSDictionary::from_slices(&name_refs, &value_refs);

                // SAFETY: the dictionary's generic parameters only affect the Rust-side view.
                // The initializer expects `AnyObject` values, and every `MLFeatureValue` is
                // an Objective-C object. `dict` outlives this borrow.
                let typed_ptr: *const NSDictionary<NSString, MLFeatureValue> = &*dict;
                let erased: &NSDictionary<NSString, AnyObject> =
                    unsafe { &*typed_ptr.cast::<NSDictionary<NSString, AnyObject>>() };

                let provider = unsafe {
                    MLDictionaryFeatureProvider::initWithDictionary_error(
                        MLDictionaryFeatureProvider::alloc(),
                        erased,
                    )
                }
                .map_err(|e| Error::from_nserror(ErrorKind::Prediction, &e))?;
                Ok(ProtocolObject::from_retained(provider))
            })
            .collect::<Result<Vec<_>>>()?;

        let provider_refs: Vec<&ProtocolObject<dyn MLFeatureProvider>> =
            providers.iter().map(|p| &**p).collect();
        let array = NSArray::from_slice(&provider_refs);
        let inner = unsafe {
            MLArrayBatchProvider::initWithFeatureProviderArray(
                MLArrayBatchProvider::alloc(),
                &array,
            )
        };
        Ok(Self { inner })
    }

    /// Returns the number of input sets in this batch, as reported by CoreML.
    pub fn count(&self) -> usize {
        use objc2_core_ml::MLBatchProvider;
        let raw = unsafe { self.inner.count() };
        raw as usize
    }
}
#[cfg(target_vendor = "apple")]
impl std::fmt::Debug for BatchProvider {
    /// Formats as `BatchProvider { count: N }`; the wrapped CoreML object is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.count();
        let mut out = f.debug_struct("BatchProvider");
        out.field("count", &count);
        out.finish()
    }
}
#[cfg(target_vendor = "apple")]
// SAFETY (review): asserts the wrapped `MLArrayBatchProvider` may be moved to another
// thread. Nothing in this impl enforces that; the methods visible here only read from it
// (`count`). Only `Send` is asserted, not `Sync`.
// TODO: confirm against CoreML's thread-safety guarantees for batch providers.
unsafe impl Send for BatchProvider {}
#[cfg(not(target_vendor = "apple"))]
/// Stand-in for the CoreML batch provider on non-Apple targets.
///
/// It cannot be obtained through `BatchProvider::new`, which always fails with an
/// unsupported-platform error on these targets.
#[derive(Debug)]
pub struct BatchProvider {
    /// Prevents construction outside this module.
    _private: (),
}
#[cfg(not(target_vendor = "apple"))]
impl BatchProvider {
pub fn new(_inputs: &[&[(&str, &dyn AsMultiArray)]]) -> Result<Self> {
Err(Error::new(
ErrorKind::UnsupportedPlatform,
"CoreML requires Apple platform",
))
}
pub fn count(&self) -> usize {
0
}
}
/// Per-entry outputs of a CoreML batch, held as an `MLBatchProvider` protocol object.
///
/// Outputs are read by batch index with [`BatchPrediction::get_f32`] or
/// [`BatchPrediction::feature_provider`].
#[cfg(target_vendor = "apple")]
pub struct BatchPrediction {
    /// The wrapped CoreML batch provider; entries are addressed by batch index.
    pub(crate) inner: objc2::rc::Retained<
        objc2::runtime::ProtocolObject<dyn objc2_core_ml::MLBatchProvider>,
    >,
}
#[cfg(target_vendor = "apple")]
impl BatchPrediction {
    /// Returns the number of entries in this batch, as reported by CoreML.
    pub fn count(&self) -> usize {
        use objc2_core_ml::MLBatchProvider;
        (unsafe { self.inner.count() }) as usize
    }

    /// Returns an [`ErrorKind::Prediction`] error when `index` is not below [`Self::count`].
    fn check_index(&self, index: usize) -> Result<()> {
        let count = self.count();
        if index >= count {
            return Err(Error::new(
                ErrorKind::Prediction,
                format!("batch index {index} out of range (count: {count})"),
            ));
        }
        Ok(())
    }

    /// Copies the multi-array output `output_name` of batch entry `index` into a `Vec<f32>`.
    ///
    /// Returns `(values, shape)`, with `values` laid out in row-major order of `shape`.
    /// Element positions come from the array's `strides`, so arrays whose rows are padded
    /// in memory are read correctly instead of picking up padding.
    ///
    /// Non-`f32` element types are converted: Float16 via `crate::f16_to_f32`, everything
    /// else with an `as f32` cast. That cast can round large integers and `f64` values.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::Prediction`] error if any of these hold:
    /// - `index` is out of range;
    /// - no output named `output_name` exists at that index;
    /// - the output is not a multi-array;
    /// - its data type is not supported.
    // `dataPointer` is deprecated by CoreML; it is still used here to read the buffer directly.
    #[allow(deprecated)]
    pub fn get_f32(&self, index: usize, output_name: &str) -> Result<(Vec<f32>, Vec<usize>)> {
        use crate::tensor::DataType;
        use objc2_core_ml::{MLBatchProvider, MLFeatureProvider};

        self.check_index(index)?;
        objc2::rc::autoreleasepool(|_pool| {
            let provider = unsafe { self.inner.featuresAtIndex(index as isize) };
            let ns_name = crate::ffi::str_to_nsstring(output_name);
            let feature_val =
                unsafe { provider.featureValueForName(&ns_name) }.ok_or_else(|| {
                    Error::new(
                        ErrorKind::Prediction,
                        format!("output '{output_name}' not found at batch index {index}"),
                    )
                })?;
            let array = unsafe { feature_val.multiArrayValue() }.ok_or_else(|| {
                Error::new(
                    ErrorKind::Prediction,
                    format!(
                        "output '{output_name}' is not a multi-array at batch index {index}"
                    ),
                )
            })?;
            let shape = crate::ffi::nsarray_to_shape(unsafe { &array.shape() });
            // Strides are element strides per dimension. They may be larger than the dense
            // row-major strides (e.g. padded rows), so the buffer cannot be assumed packed.
            let strides = crate::ffi::nsarray_to_shape(unsafe { &array.strides() });
            let count = crate::tensor::element_count(&shape);
            let data_type = crate::ffi::ml_to_datatype(unsafe { array.dataType() }.0);
            let mut buf = vec![0.0f32; count];
            let base = unsafe { array.dataPointer() }.as_ptr();
            // SAFETY: `base` points at `array`'s element buffer, which `array` keeps alive
            // for the rest of this closure. It is reinterpreted as the element type CoreML
            // reports via `dataType`, and `shape`/`strides` come from the same array.
            unsafe {
                match data_type {
                    Some(DataType::Float32) => Self::gather_as_f32(
                        base as *const f32,
                        &shape,
                        &strides,
                        &mut buf,
                        |v: f32| v,
                    ),
                    Some(DataType::Float16) => Self::gather_as_f32(
                        base as *const u16,
                        &shape,
                        &strides,
                        &mut buf,
                        crate::f16_to_f32,
                    ),
                    Some(DataType::Float64) => Self::gather_as_f32(
                        base as *const f64,
                        &shape,
                        &strides,
                        &mut buf,
                        |v: f64| v as f32,
                    ),
                    Some(DataType::Int32) => Self::gather_as_f32(
                        base as *const i32,
                        &shape,
                        &strides,
                        &mut buf,
                        |v: i32| v as f32,
                    ),
                    Some(DataType::Int16) => Self::gather_as_f32(
                        base as *const i16,
                        &shape,
                        &strides,
                        &mut buf,
                        |v: i16| v as f32,
                    ),
                    Some(DataType::Int8) => Self::gather_as_f32(
                        base as *const i8,
                        &shape,
                        &strides,
                        &mut buf,
                        |v: i8| v as f32,
                    ),
                    Some(DataType::UInt32) => Self::gather_as_f32(
                        base as *const u32,
                        &shape,
                        &strides,
                        &mut buf,
                        |v: u32| v as f32,
                    ),
                    Some(DataType::UInt16) => Self::gather_as_f32(
                        base as *const u16,
                        &shape,
                        &strides,
                        &mut buf,
                        |v: u16| v as f32,
                    ),
                    Some(DataType::UInt8) => Self::gather_as_f32(
                        base as *const u8,
                        &shape,
                        &strides,
                        &mut buf,
                        |v: u8| v as f32,
                    ),
                    None => {
                        return Err(Error::new(
                            ErrorKind::Prediction,
                            "unsupported output data type",
                        ));
                    }
                }
            }
            Ok((buf, shape))
        })
    }

    /// Converts the elements of a (possibly strided) buffer into `out`, in row-major order
    /// of `shape`.
    ///
    /// `strides` gives the element stride of each dimension. If its length differs from
    /// `shape`, the buffer is treated as densely packed.
    ///
    /// # Safety
    ///
    /// `base` must be non-null, aligned for `T`, and valid for reads of `T` at every offset
    /// `sum(i[d] * strides[d])` with `i[d] < shape[d]`. In the dense case, it must be valid
    /// at offsets `0..out.len()`. `out.len()` must equal the product of `shape`.
    unsafe fn gather_as_f32<T: Copy>(
        base: *const T,
        shape: &[usize],
        strides: &[usize],
        out: &mut [f32],
        convert: impl Fn(T) -> f32,
    ) {
        if strides.len() != shape.len() || strides == Self::row_major_strides(shape).as_slice()
        {
            // Dense layout: memory order already is row-major order.
            let src = unsafe { std::slice::from_raw_parts(base, out.len()) };
            for (dst, &value) in out.iter_mut().zip(src) {
                *dst = convert(value);
            }
            return;
        }
        // Advance the multi-index like an odometer (last dimension fastest), keeping
        // `offset == sum(index[d] * strides[d])` at every step.
        let mut index = vec![0usize; shape.len()];
        let mut offset = 0usize;
        for dst in out.iter_mut() {
            *dst = convert(unsafe { *base.add(offset) });
            for ((i, &dim), &stride) in index.iter_mut().zip(shape).zip(strides).rev() {
                *i += 1;
                offset += stride;
                if *i < dim {
                    break;
                }
                // This dimension wrapped: rewind it and carry into the next-outer one.
                offset -= stride * dim;
                *i = 0;
            }
        }
    }

    /// Returns the dense row-major (C-order) element strides for `shape`.
    fn row_major_strides(shape: &[usize]) -> Vec<usize> {
        let mut strides = vec![1usize; shape.len()];
        for d in (1..shape.len()).rev() {
            strides[d - 1] = strides[d] * shape[d];
        }
        strides
    }

    /// Returns the raw CoreML feature provider for batch entry `index`.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::Prediction`] error if `index` is out of range.
    pub fn feature_provider(
        &self,
        index: usize,
    ) -> Result<
        objc2::rc::Retained<
            objc2::runtime::ProtocolObject<dyn objc2_core_ml::MLFeatureProvider>,
        >,
    > {
        use objc2_core_ml::MLBatchProvider;
        self.check_index(index)?;
        Ok(unsafe { self.inner.featuresAtIndex(index as isize) })
    }
}
#[cfg(target_vendor = "apple")]
impl std::fmt::Debug for BatchPrediction {
    /// Formats as `BatchPrediction { count: N }`; the wrapped CoreML object is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.count();
        let mut out = f.debug_struct("BatchPrediction");
        out.field("count", &count);
        out.finish()
    }
}
#[cfg(target_vendor = "apple")]
// SAFETY (review): asserts the wrapped `MLBatchProvider` protocol object may be moved to
// another thread. Nothing in this impl enforces that; the methods visible here only read
// from it. Only `Send` is asserted, not `Sync`.
// TODO: confirm against CoreML's thread-safety guarantees for batch outputs.
unsafe impl Send for BatchPrediction {}
#[cfg(not(target_vendor = "apple"))]
/// Stand-in for CoreML batch outputs on non-Apple targets.
///
/// Its accessors report an empty batch or an unsupported-platform error.
#[derive(Debug)]
pub struct BatchPrediction {
    /// Prevents construction outside this module.
    _private: (),
}
#[cfg(not(target_vendor = "apple"))]
impl BatchPrediction {
pub fn count(&self) -> usize {
0
}
pub fn get_f32(
&self,
_index: usize,
_output_name: &str,
) -> Result<(Vec<f32>, Vec<usize>)> {
Err(Error::new(
ErrorKind::UnsupportedPlatform,
"CoreML requires Apple platform",
))
}
}
#[cfg(test)]
mod tests {
    #[cfg(not(target_vendor = "apple"))]
    use super::*;

    /// Asserts that `err` reports the current platform as unsupported.
    #[cfg(not(target_vendor = "apple"))]
    fn assert_unsupported(err: &Error) {
        assert_eq!(err.kind(), &ErrorKind::UnsupportedPlatform);
    }

    /// The non-Apple `BatchProvider::new` must reject even an empty batch.
    #[cfg(not(target_vendor = "apple"))]
    #[test]
    fn batch_provider_fails_on_non_apple() {
        let no_inputs: &[&[(&str, &dyn AsMultiArray)]] = &[];
        assert_unsupported(&BatchProvider::new(no_inputs).unwrap_err());
    }

    /// The non-Apple `BatchPrediction` stub is empty and refuses to read outputs.
    #[cfg(not(target_vendor = "apple"))]
    #[test]
    fn batch_prediction_fails_on_non_apple() {
        let stub = BatchPrediction { _private: () };
        assert_eq!(stub.count(), 0);
        assert_unsupported(&stub.get_f32(0, "output").unwrap_err());
    }
}