#[cfg(target_vendor = "apple")]
use objc2::rc::Retained;
#[cfg(target_vendor = "apple")]
use objc2_foundation::{NSArray, NSNumber, NSString};
use crate::tensor::DataType;
/// Converts a tensor shape into an `NSArray` of `NSNumber`s.
///
/// Each dimension is boxed with `NSNumber::new_isize`; the `as isize` cast
/// would wrap for dimensions above `isize::MAX`.
#[cfg(target_vendor = "apple")]
pub(crate) fn shape_to_nsarray(shape: &[usize]) -> Retained<NSArray<NSNumber>> {
    // Owned handles must stay alive while the borrowed slice below is in use.
    let boxed_dims: Vec<Retained<NSNumber>> = shape
        .iter()
        .copied()
        .map(|dim| NSNumber::new_isize(dim as isize))
        .collect();

    // `NSArray::from_slice` takes plain references, so borrow through each handle.
    let borrowed: Vec<&NSNumber> = boxed_dims.iter().map(std::ops::Deref::deref).collect();
    NSArray::from_slice(&borrowed)
}
/// Converts an `NSArray` of `NSNumber`s back into a tensor shape.
///
/// Each element is read with `as_isize` and cast with `as usize`; a negative
/// element would therefore wrap to a very large `usize` rather than error.
#[cfg(target_vendor = "apple")]
pub(crate) fn nsarray_to_shape(array: &NSArray<NSNumber>) -> Vec<usize> {
    (0..array.len())
        .map(|idx| array.objectAtIndex(idx).as_isize() as usize)
        .collect()
}
/// Builds a retained `NSString` holding a copy of the given Rust string slice.
#[cfg(target_vendor = "apple")]
pub(crate) fn str_to_nsstring(s: &str) -> Retained<NSString> {
    NSString::from_str(s)
}
/// Copies the contents of an `NSString` into an owned Rust `String`.
#[cfg(target_vendor = "apple")]
pub(crate) fn nsstring_to_string(s: &NSString) -> String {
    s.to_string()
}
/// Encodes a [`DataType`] as a raw integer of the form `family | bit_width`.
///
/// The float and signed-int families appear to mirror CoreML's
/// `MLMultiArrayDataType` encoding (NOTE(review): confirm against the SDK
/// header for Int8/Int16). The unsigned family (`0x30000`) is a crate-local
/// value that `ml_to_datatype` deliberately does not map back.
pub(crate) fn datatype_to_ml(dt: DataType) -> isize {
    const FLOAT: isize = 0x10000;
    const SIGNED_INT: isize = 0x20000;
    const UNSIGNED_INT: isize = 0x30000;

    let (family, bits) = match dt {
        DataType::Float16 => (FLOAT, 16),
        DataType::Float32 => (FLOAT, 32),
        DataType::Float64 => (FLOAT, 64),
        DataType::Int32 => (SIGNED_INT, 32),
        DataType::Int8 => (SIGNED_INT, 8),
        DataType::Int16 => (SIGNED_INT, 16),
        DataType::UInt32 => (UNSIGNED_INT, 32),
        DataType::UInt16 => (UNSIGNED_INT, 16),
        DataType::UInt8 => (UNSIGNED_INT, 8),
    };
    family | bits
}
/// Decodes a raw `family | bit_width` integer back into a [`DataType`].
///
/// Only the float and signed-int encodings are recognised; the unsigned
/// family and any other value yield `None`.
pub(crate) fn ml_to_datatype(raw: isize) -> Option<DataType> {
    // Written in the same `family | bits` form used by `datatype_to_ml`.
    const FLOAT16: isize = 0x10000 | 16;
    const FLOAT32: isize = 0x10000 | 32;
    const FLOAT64: isize = 0x10000 | 64;
    const INT32: isize = 0x20000 | 32;
    const INT8: isize = 0x20000 | 8;
    const INT16: isize = 0x20000 | 16;

    let decoded = match raw {
        FLOAT16 => DataType::Float16,
        FLOAT32 => DataType::Float32,
        FLOAT64 => DataType::Float64,
        INT32 => DataType::Int32,
        INT8 => DataType::Int8,
        INT16 => DataType::Int16,
        _ => return None,
    };
    Some(decoded)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Float and signed-int types must survive a
    /// `datatype_to_ml` -> `ml_to_datatype` round trip.
    #[test]
    fn datatype_roundtrip() {
        for dt in [
            DataType::Float16,
            DataType::Float32,
            DataType::Float64,
            DataType::Int32,
            DataType::Int8,
            DataType::Int16,
        ] {
            let raw = datatype_to_ml(dt);
            assert_eq!(
                ml_to_datatype(raw),
                Some(dt),
                "{dt} encoded as {raw} did not decode back to itself"
            );
        }
    }

    /// The round trip above only proves the two functions agree with each
    /// other; a constant that is wrong in both directions would still pass.
    /// Pin the values to CoreML's `MLMultiArrayDataType` header constants.
    #[test]
    fn datatype_matches_coreml_constants() {
        // MLMultiArrayDataTypeFloat16 = 0x10000 | 16
        assert_eq!(datatype_to_ml(DataType::Float16), 0x10000 | 16);
        // MLMultiArrayDataTypeFloat32 = 0x10000 | 32
        assert_eq!(datatype_to_ml(DataType::Float32), 0x10000 | 32);
        // MLMultiArrayDataTypeDouble = 0x10000 | 64
        assert_eq!(datatype_to_ml(DataType::Float64), 0x10000 | 64);
        // MLMultiArrayDataTypeInt32 = 0x20000 | 32
        assert_eq!(datatype_to_ml(DataType::Int32), 0x20000 | 32);
    }

    /// Unsigned types encode to crate-local values that must not decode.
    #[test]
    fn unsigned_types_no_coreml_mapping() {
        for dt in [DataType::UInt32, DataType::UInt16, DataType::UInt8] {
            let raw = datatype_to_ml(dt);
            assert_eq!(
                ml_to_datatype(raw),
                None,
                "unsigned type {dt} sentinel value {raw} should not reverse-map"
            );
        }
    }

    #[test]
    fn ml_to_datatype_unknown() {
        assert_eq!(ml_to_datatype(999), None);
    }

    #[cfg(target_vendor = "apple")]
    mod apple_tests {
        use super::super::*;

        #[test]
        fn shape_roundtrip() {
            let shape = vec![1, 128, 500];
            let ns = shape_to_nsarray(&shape);
            let back = nsarray_to_shape(&ns);
            assert_eq!(shape, back);
        }

        #[test]
        fn shape_empty() {
            let shape: Vec<usize> = vec![];
            let ns = shape_to_nsarray(&shape);
            let back = nsarray_to_shape(&ns);
            assert_eq!(shape, back);
        }

        #[test]
        fn string_roundtrip() {
            let s = "audio_signal";
            let ns = str_to_nsstring(s);
            let back = nsstring_to_string(&ns);
            assert_eq!(s, back);
        }

        /// Non-ASCII (U+2581, the SentencePiece word marker) must round-trip intact.
        #[test]
        fn string_unicode() {
            let s = "input_\u{2581}test";
            let ns = str_to_nsstring(s);
            let back = nsstring_to_string(&ns);
            assert_eq!(s, back);
        }
    }
}