use std::ffi::c_void;
use std::ptr::NonNull;
use block2::RcBlock;
use objc2::AnyThread;
use objc2::rc::Retained;
use objc2_core_ml::{MLMultiArray, MLMultiArrayDataType};
use objc2_foundation::{NSArray, NSNumber};
use super::{CachedInputShape, CoreMlError};
pub(super) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
let mut strides = vec![1usize; shape.len()];
for i in (0..shape.len().saturating_sub(1)).rev() {
strides[i] = strides[i + 1] * shape[i + 1];
}
strides
}
pub(super) fn ns_number_array(values: &[usize]) -> Retained<NSArray<NSNumber>> {
let numbers: Vec<Retained<NSNumber>> = values
.iter()
.copied()
.map(|value| NSNumber::new_isize(value as isize))
.collect();
NSArray::from_retained_slice(&numbers)
}
pub(super) fn create_multi_array_with_deallocator(
data: &[f32],
shape: &[usize],
deallocator: &RcBlock<dyn Fn(NonNull<c_void>)>,
) -> Result<Retained<MLMultiArray>, CoreMlError> {
let ns_shape = ns_number_array(shape);
let ns_strides = ns_number_array(&contiguous_strides(shape));
let ptr = NonNull::new(data.as_ptr() as *mut c_void)
.ok_or_else(|| CoreMlError::ArrayCreationFailed("null data pointer".into()))?;
#[allow(deprecated)]
unsafe {
MLMultiArray::initWithDataPointer_shape_dataType_strides_deallocator_error(
MLMultiArray::alloc(),
ptr,
&ns_shape,
MLMultiArrayDataType::Float32,
&ns_strides,
Some(deallocator),
)
}
.map_err(|e| CoreMlError::ArrayCreationFailed(format!("{e}")))
}
pub(super) fn create_multi_array_cached_with_deallocator(
data: &[f32],
cached: &CachedInputShape,
deallocator: &RcBlock<dyn Fn(NonNull<c_void>)>,
) -> Result<Retained<MLMultiArray>, CoreMlError> {
let ptr = NonNull::new(data.as_ptr() as *mut c_void)
.ok_or_else(|| CoreMlError::ArrayCreationFailed("null data pointer".into()))?;
#[allow(deprecated)]
unsafe {
MLMultiArray::initWithDataPointer_shape_dataType_strides_deallocator_error(
MLMultiArray::alloc(),
ptr,
&cached.ns_shape,
MLMultiArrayDataType::Float32,
&cached.ns_strides,
Some(deallocator),
)
}
.map_err(|e| CoreMlError::ArrayCreationFailed(format!("{e}")))
}
#[allow(deprecated)]
pub(super) fn extract_output(array: &MLMultiArray) -> Result<(Vec<f32>, Vec<usize>), CoreMlError> {
let (count, ptr, dtype, ns_shape) = unsafe {
(
array.count() as usize,
array.dataPointer(),
array.dataType(),
array.shape(),
)
};
let shape: Vec<usize> = (0..ns_shape.len())
.map(|i| ns_shape.objectAtIndex(i).as_isize() as usize)
.collect();
let data = if dtype == MLMultiArrayDataType::Float16 {
let fp16_data = unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const u16, count) };
fp16_data.iter().copied().map(f16_to_f32).collect()
} else {
let fp32_data = unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const f32, count) };
fp32_data.to_vec()
};
Ok((data, shape))
}
fn f16_to_f32(bits: u16) -> f32 {
let sign = ((bits >> 15) & 1) as u32;
let exp = ((bits >> 10) & 0x1f) as u32;
let mant = (bits & 0x3ff) as u32;
if exp == 0 {
if mant == 0 {
return f32::from_bits(sign << 31);
}
let mut e: i32 = exp as i32;
let mut m = mant;
while m & 0x400 == 0 {
m <<= 1;
e -= 1;
}
m &= 0x3ff;
let f32_exp = ((127 - 15) + e + 1) as u32;
return f32::from_bits((sign << 31) | (f32_exp << 23) | (m << 13));
}
if exp == 0x1f {
return f32::from_bits((sign << 31) | (0xff_u32 << 23) | (mant << 13));
}
let f32_exp = exp - 15 + 127;
f32::from_bits((sign << 31) | (f32_exp << 23) | (mant << 13))
}