// coreml_native — crate root (lib.rs)

//! Safe, ergonomic Rust bindings for Apple CoreML inference with ANE acceleration.
//!
//! # Platform Support
//!
//! Requires macOS or iOS. On non-Apple targets, types exist as stubs
//! returning `Error::UnsupportedPlatform`.
7
8pub mod async_bridge;
9pub mod compile;
10pub mod description;
11pub mod error;
12pub(crate) mod ffi;
13mod model_async;
14pub mod state;
15pub mod tensor;
16pub mod batch;
17pub mod compute;
18pub mod model_lifecycle;
19#[cfg(feature = "ndarray")]
20pub mod ndarray_support;
21
22pub use async_bridge::CompletionFuture;
23pub use batch::{BatchPrediction, BatchProvider};
24pub use compile::{compile_model, compile_model_async};
25pub use compute::{available_devices, ComputeDevice};
26pub use description::{FeatureDescription, FeatureType, ModelMetadata, ShapeConstraint};
27pub use error::{Error, ErrorKind, Result};
28pub use model_lifecycle::ModelHandle;
29pub use state::State;
30pub use tensor::{AsMultiArray, BorrowedTensor, DataType, OwnedTensor};
31#[cfg(feature = "ndarray")]
32pub use ndarray_support::PredictionNdarray;
33
/// Selects which hardware units CoreML may use when loading a model.
///
/// Defaults to [`ComputeUnits::All`] — CPU, GPU (Metal), and the Apple
/// Neural Engine together — which gives maximum throughput and is the
/// whole point of using native CoreML.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ComputeUnits {
    /// Restrict inference to the CPU (no GPU, no ANE).
    CpuOnly,
    /// CPU plus the Metal GPU; the ANE stays unused.
    CpuAndGpu,
    /// CPU plus the Apple Neural Engine; the GPU stays unused.
    CpuAndNeuralEngine,
    /// Every available unit: CPU + GPU + ANE. **Use this.**
    #[default]
    All,
}

51impl std::fmt::Display for ComputeUnits {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        match self {
54            Self::CpuOnly => write!(f, "CPU only"),
55            Self::CpuAndGpu => write!(f, "CPU + GPU"),
56            Self::CpuAndNeuralEngine => write!(f, "CPU + Neural Engine"),
57            Self::All => write!(f, "All (CPU + GPU + ANE)"),
58        }
59    }
60}
61
// ─── Model ──────────────────────────────────────────────────────────────────

/// A loaded CoreML model, backed by a retained `MLModel` instance.
#[cfg(target_vendor = "apple")]
pub struct Model {
    // Strong (reference-counted) handle to the underlying Objective-C MLModel.
    inner: objc2::rc::Retained<objc2_core_ml::MLModel>,
    // Filesystem path the model was loaded from; used by `path()` and Debug.
    path: std::path::PathBuf,
}

// Manual Debug impl: only the load path is shown; the raw MLModel handle is
// deliberately omitted from the output.
#[cfg(target_vendor = "apple")]
impl std::fmt::Debug for Model {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Model").field("path", &self.path).finish()
    }
}

/// Stub `Model` for non-Apple targets; every fallible operation returns
/// `Error::UnsupportedPlatform`.
#[cfg(not(target_vendor = "apple"))]
#[derive(Debug)]
pub struct Model {
    // Zero-sized private field: prevents construction outside this crate.
    _private: (),
}

// SAFETY: Apple documents MLModel.predictionFromFeatures as thread-safe for
// concurrent read-only predictions on the same model instance, and the
// Retained handle is reference-counted, so sharing and moving the wrapper
// across threads is sound.
#[cfg(target_vendor = "apple")]
unsafe impl Send for Model {}
#[cfg(target_vendor = "apple")]
unsafe impl Sync for Model {}

90impl Model {
91    #[cfg(target_vendor = "apple")]
92    pub fn load(path: impl AsRef<std::path::Path>, compute_units: ComputeUnits) -> Result<Self> {
93        use objc2_core_ml::{MLComputeUnits, MLModel, MLModelConfiguration};
94
95        let path = path.as_ref();
96        let path_str = path.to_str().ok_or_else(|| {
97            Error::new(ErrorKind::ModelLoad, "path contains non-UTF8 characters")
98        })?;
99
100        let url = objc2_foundation::NSURL::fileURLWithPath(&ffi::str_to_nsstring(path_str));
101        let config = unsafe { MLModelConfiguration::new() };
102        let ml_units = match compute_units {
103            ComputeUnits::CpuOnly => MLComputeUnits(1),
104            ComputeUnits::CpuAndGpu => MLComputeUnits::CPUAndGPU,
105            ComputeUnits::CpuAndNeuralEngine => MLComputeUnits(2),
106            ComputeUnits::All => MLComputeUnits::All,
107        };
108        unsafe { config.setComputeUnits(ml_units) };
109
110        let inner = unsafe { MLModel::modelWithContentsOfURL_configuration_error(&url, &config) }
111            .map_err(|e| Error::from_nserror(ErrorKind::ModelLoad, &e))?;
112
113        Ok(Self { inner, path: path.to_path_buf() })
114    }
115
116    #[cfg(not(target_vendor = "apple"))]
117    pub fn load(_path: impl AsRef<std::path::Path>, _compute_units: ComputeUnits) -> Result<Self> {
118        Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
119    }
120
121    /// The filesystem path this model was loaded from.
122    pub fn path(&self) -> &std::path::Path {
123        #[cfg(target_vendor = "apple")]
124        { &self.path }
125        #[cfg(not(target_vendor = "apple"))]
126        { std::path::Path::new("") }
127    }
128
129    /// Run a synchronous prediction with named input tensors.
130    ///
131    /// Accepts any type implementing `AsMultiArray` (both `BorrowedTensor` and `OwnedTensor`).
132    #[cfg(target_vendor = "apple")]
133    pub fn predict(&self, inputs: &[(&str, &dyn AsMultiArray)]) -> Result<Prediction> {
134        use objc2::AnyThread;
135        use objc2_core_ml::{MLDictionaryFeatureProvider, MLFeatureProvider, MLFeatureValue};
136        use objc2_foundation::{NSDictionary, NSString};
137
138        objc2::rc::autoreleasepool(|_pool| {
139            let mut keys: Vec<objc2::rc::Retained<NSString>> = Vec::with_capacity(inputs.len());
140            let mut vals: Vec<objc2::rc::Retained<MLFeatureValue>> = Vec::with_capacity(inputs.len());
141
142            for &(name, tensor) in inputs {
143                keys.push(ffi::str_to_nsstring(name));
144                vals.push(unsafe { MLFeatureValue::featureValueWithMultiArray(tensor.as_ml_multi_array()) });
145            }
146
147            let key_refs: Vec<&NSString> = keys.iter().map(|k| &**k).collect();
148            let val_refs: Vec<&MLFeatureValue> = vals.iter().map(|v| &**v).collect();
149
150            let dict: objc2::rc::Retained<NSDictionary<NSString, MLFeatureValue>> =
151                NSDictionary::from_slices(&key_refs, &val_refs);
152
153            let dict_any: &NSDictionary<NSString, objc2::runtime::AnyObject> =
154                unsafe { &*((&*dict) as *const NSDictionary<NSString, MLFeatureValue>
155                    as *const NSDictionary<NSString, objc2::runtime::AnyObject>) };
156
157            let provider = unsafe {
158                MLDictionaryFeatureProvider::initWithDictionary_error(
159                    MLDictionaryFeatureProvider::alloc(),
160                    dict_any,
161                )
162            }
163            .map_err(|e| Error::from_nserror(ErrorKind::Prediction, &e))?;
164
165            let provider_ref: &objc2::runtime::ProtocolObject<dyn MLFeatureProvider> =
166                objc2::runtime::ProtocolObject::from_ref(&*provider);
167
168            let result = unsafe { self.inner.predictionFromFeatures_error(provider_ref) }
169                .map_err(|e| Error::from_nserror(ErrorKind::Prediction, &e))?;
170
171            Ok(Prediction { inner: result })
172        })
173    }
174
175    #[cfg(not(target_vendor = "apple"))]
176    pub fn predict(&self, _inputs: &[(&str, &dyn AsMultiArray)]) -> Result<Prediction> {
177        Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
178    }
179
180    /// Get descriptions of all model inputs.
181    #[cfg(target_vendor = "apple")]
182    pub fn inputs(&self) -> Vec<FeatureDescription> {
183        let desc = unsafe { self.inner.modelDescription() };
184        let input_map = unsafe { desc.inputDescriptionsByName() };
185        description::extract_features(&input_map)
186    }
187
188    /// Get descriptions of all model outputs.
189    #[cfg(target_vendor = "apple")]
190    pub fn outputs(&self) -> Vec<FeatureDescription> {
191        let desc = unsafe { self.inner.modelDescription() };
192        let output_map = unsafe { desc.outputDescriptionsByName() };
193        description::extract_features(&output_map)
194    }
195
196    /// Get model metadata (author, description, version, license).
197    #[cfg(target_vendor = "apple")]
198    pub fn metadata(&self) -> ModelMetadata {
199        let desc = unsafe { self.inner.modelDescription() };
200        description::extract_metadata(&desc)
201    }
202
203    /// Create a new state for stateful prediction (macOS 15+ / iOS 18+).
204    #[cfg(target_vendor = "apple")]
205    pub fn new_state(&self) -> Result<State> {
206        let inner = unsafe { self.inner.newState() };
207        Ok(State { inner })
208    }
209
210    #[cfg(not(target_vendor = "apple"))]
211    pub fn new_state(&self) -> Result<State> {
212        Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
213    }
214
215    /// Run prediction with persistent state (macOS 15+ / iOS 18+).
216    #[cfg(target_vendor = "apple")]
217    pub fn predict_stateful(
218        &self,
219        inputs: &[(&str, &dyn AsMultiArray)],
220        state: &State,
221    ) -> Result<Prediction> {
222        use objc2::AnyThread;
223        use objc2_core_ml::{MLDictionaryFeatureProvider, MLFeatureProvider, MLFeatureValue};
224        use objc2_foundation::{NSDictionary, NSString};
225
226        objc2::rc::autoreleasepool(|_pool| {
227            let mut keys: Vec<objc2::rc::Retained<NSString>> = Vec::with_capacity(inputs.len());
228            let mut vals: Vec<objc2::rc::Retained<MLFeatureValue>> = Vec::with_capacity(inputs.len());
229
230            for &(name, tensor) in inputs {
231                keys.push(ffi::str_to_nsstring(name));
232                vals.push(unsafe { MLFeatureValue::featureValueWithMultiArray(tensor.as_ml_multi_array()) });
233            }
234
235            let key_refs: Vec<&NSString> = keys.iter().map(|k| &**k).collect();
236            let val_refs: Vec<&MLFeatureValue> = vals.iter().map(|v| &**v).collect();
237            let dict: objc2::rc::Retained<NSDictionary<NSString, MLFeatureValue>> =
238                NSDictionary::from_slices(&key_refs, &val_refs);
239            let dict_any: &NSDictionary<NSString, objc2::runtime::AnyObject> =
240                unsafe { &*((&*dict) as *const NSDictionary<NSString, MLFeatureValue>
241                    as *const NSDictionary<NSString, objc2::runtime::AnyObject>) };
242
243            let provider = unsafe {
244                MLDictionaryFeatureProvider::initWithDictionary_error(
245                    MLDictionaryFeatureProvider::alloc(), dict_any,
246                )
247            }
248            .map_err(|e| Error::from_nserror(ErrorKind::Prediction, &e))?;
249
250            let provider_ref: &objc2::runtime::ProtocolObject<dyn MLFeatureProvider> =
251                objc2::runtime::ProtocolObject::from_ref(&*provider);
252
253            let result = unsafe {
254                self.inner.predictionFromFeatures_usingState_error(provider_ref, &state.inner)
255            }
256            .map_err(|e| Error::from_nserror(ErrorKind::Prediction, &e))?;
257
258            Ok(Prediction { inner: result })
259        })
260    }
261
262    #[cfg(not(target_vendor = "apple"))]
263    pub fn predict_stateful(
264        &self,
265        _inputs: &[(&str, &dyn AsMultiArray)],
266        _state: &State,
267    ) -> Result<Prediction> {
268        Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
269    }
270
271    /// Run batch prediction for multiple input sets at once.
272    ///
273    /// More efficient than calling `predict()` in a loop.
274    #[cfg(target_vendor = "apple")]
275    pub fn predict_batch(&self, batch: &batch::BatchProvider) -> Result<batch::BatchPrediction> {
276        use objc2_core_ml::MLBatchProvider;
277
278        let batch_ref: &objc2::runtime::ProtocolObject<dyn MLBatchProvider> =
279            objc2::runtime::ProtocolObject::from_ref(&*batch.inner);
280
281        let result = unsafe { self.inner.predictionsFromBatch_error(batch_ref) }
282            .map_err(|e| Error::from_nserror(ErrorKind::Prediction, &e))?;
283
284        Ok(batch::BatchPrediction { inner: result })
285    }
286
287    #[cfg(not(target_vendor = "apple"))]
288    pub fn predict_batch(&self, _batch: &batch::BatchProvider) -> Result<batch::BatchPrediction> {
289        Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
290    }
291
292    #[cfg(not(target_vendor = "apple"))]
293    pub fn inputs(&self) -> Vec<FeatureDescription> { vec![] }
294
295    #[cfg(not(target_vendor = "apple"))]
296    pub fn outputs(&self) -> Vec<FeatureDescription> { vec![] }
297
298    #[cfg(not(target_vendor = "apple"))]
299    pub fn metadata(&self) -> ModelMetadata { ModelMetadata::default() }
300}
301
// ─── Prediction result ──────────────────────────────────────────────────────

/// The outputs of a single prediction, backed by the feature provider CoreML returned.
#[cfg(target_vendor = "apple")]
pub struct Prediction {
    // Retained MLFeatureProvider holding the model's output feature values.
    inner: objc2::rc::Retained<objc2::runtime::ProtocolObject<dyn objc2_core_ml::MLFeatureProvider>>,
}

/// Stub `Prediction` for non-Apple targets; every accessor returns
/// `Error::UnsupportedPlatform`.
#[cfg(not(target_vendor = "apple"))]
pub struct Prediction {
    // Zero-sized private field: prevents construction outside this crate.
    _private: (),
}

// SAFETY: the Prediction holds a Retained MLFeatureProvider, which is
// reference-counted, so moving it to another thread for output extraction is
// safe. NOTE(review): `Sync` additionally assumes concurrent reads through
// the provider are safe — confirm against the CoreML documentation.
#[cfg(target_vendor = "apple")]
unsafe impl Send for Prediction {}
#[cfg(target_vendor = "apple")]
unsafe impl Sync for Prediction {}

321impl Prediction {
322    /// Get an output as (Vec<f32>, shape), converting from the model's native data type.
323    /// Allocates a new Vec for the output data.
324    #[cfg(target_vendor = "apple")]
325    #[allow(deprecated)]
326    pub fn get_f32(&self, name: &str) -> Result<(Vec<f32>, Vec<usize>)> {
327        objc2::rc::autoreleasepool(|_pool| {
328            let (count, shape, data_type, array) = self.get_output_array(name)?;
329            let mut buf = vec![0.0f32; count];
330            Self::copy_array_to_f32(&array, data_type, count, &mut buf)?;
331            Ok((buf, shape))
332        })
333    }
334
335    /// Copy an output into a caller-provided f32 buffer (zero-alloc hot path).
336    ///
337    /// Returns the shape. The buffer must be large enough to hold all elements.
338    #[cfg(target_vendor = "apple")]
339    #[allow(deprecated)]
340    pub fn get_f32_into(&self, name: &str, buf: &mut [f32]) -> Result<Vec<usize>> {
341        objc2::rc::autoreleasepool(|_pool| {
342            let (count, shape, data_type, array) = self.get_output_array(name)?;
343            if buf.len() < count {
344                return Err(Error::new(
345                    ErrorKind::InvalidShape,
346                    format!("buffer length {} < output element count {count}", buf.len()),
347                ));
348            }
349            Self::copy_array_to_f32(&array, data_type, count, buf)?;
350            Ok(shape)
351        })
352    }
353
354    #[cfg(target_vendor = "apple")]
355    #[allow(deprecated)]
356    #[allow(clippy::type_complexity)]
357    fn get_output_array(
358        &self,
359        name: &str,
360    ) -> Result<(
361        usize,
362        Vec<usize>,
363        Option<DataType>,
364        objc2::rc::Retained<objc2_core_ml::MLMultiArray>,
365    )> {
366        use objc2_core_ml::MLFeatureProvider;
367
368        let ns_name = ffi::str_to_nsstring(name);
369        let feature_val = unsafe { self.inner.featureValueForName(&ns_name) }.ok_or_else(|| {
370            Error::new(ErrorKind::Prediction, format!("output '{name}' not found"))
371        })?;
372
373        let array = unsafe { feature_val.multiArrayValue() }.ok_or_else(|| {
374            Error::new(ErrorKind::Prediction, format!("output '{name}' is not a multi-array"))
375        })?;
376
377        let shape = ffi::nsarray_to_shape(unsafe { &array.shape() });
378        let count = tensor::element_count(&shape);
379        let dt_raw = unsafe { array.dataType() };
380        let data_type = ffi::ml_to_datatype(dt_raw.0);
381
382        Ok((count, shape, data_type, array))
383    }
384
385    /// Copy MLMultiArray data into a flat f32 buffer in row-major (C-contiguous) order.
386    ///
387    /// CoreML MLMultiArray outputs may have non-row-major strides (especially when
388    /// inference runs on GPU or ANE). This function reads the array's actual strides
389    /// and iterates in logical (row-major) index order, computing the physical offset
390    /// for each element using the strides.
391    #[cfg(target_vendor = "apple")]
392    #[allow(deprecated)]
393    #[allow(clippy::needless_range_loop)]
394    fn copy_array_to_f32(
395        array: &objc2_core_ml::MLMultiArray,
396        data_type: Option<DataType>,
397        count: usize,
398        buf: &mut [f32],
399    ) -> Result<()> {
400        unsafe {
401            let ptr = array.dataPointer();
402            let shape = ffi::nsarray_to_shape(&array.shape());
403            let strides = ffi::nsarray_to_shape(&array.strides());
404            let row_major_strides = tensor::compute_strides(&shape);
405            let is_contiguous = strides == row_major_strides;
406
407            if is_contiguous {
408                // Fast path: data is already row-major contiguous
409                match data_type {
410                    Some(DataType::Float32) => {
411                        let src = ptr.as_ptr() as *const f32;
412                        std::ptr::copy_nonoverlapping(src, buf.as_mut_ptr(), count);
413                    }
414                    Some(DataType::Float16) => {
415                        let src = ptr.as_ptr() as *const u16;
416                        for i in 0..count {
417                            buf[i] = f16_to_f32(*src.add(i));
418                        }
419                    }
420                    Some(DataType::Float64) => {
421                        let src = ptr.as_ptr() as *const f64;
422                        for i in 0..count {
423                            buf[i] = *src.add(i) as f32;
424                        }
425                    }
426                    Some(DataType::Int32) => {
427                        let src = ptr.as_ptr() as *const i32;
428                        for i in 0..count {
429                            buf[i] = *src.add(i) as f32;
430                        }
431                    }
432                    Some(DataType::Int16) => {
433                        let src = ptr.as_ptr() as *const i16;
434                        for i in 0..count {
435                            buf[i] = *src.add(i) as f32;
436                        }
437                    }
438                    Some(DataType::Int8) => {
439                        let src = ptr.as_ptr() as *const i8;
440                        for i in 0..count {
441                            buf[i] = *src.add(i) as f32;
442                        }
443                    }
444                    Some(DataType::UInt32) => {
445                        let src = ptr.as_ptr() as *const u32;
446                        for i in 0..count {
447                            buf[i] = *src.add(i) as f32;
448                        }
449                    }
450                    Some(DataType::UInt16) => {
451                        let src = ptr.as_ptr() as *const u16;
452                        for i in 0..count {
453                            buf[i] = *src.add(i) as f32;
454                        }
455                    }
456                    Some(DataType::UInt8) => {
457                        let src = ptr.as_ptr() as *const u8;
458                        for i in 0..count {
459                            buf[i] = *src.add(i) as f32;
460                        }
461                    }
462                    None => {
463                        return Err(Error::new(
464                            ErrorKind::Prediction,
465                            "unsupported output data type",
466                        ));
467                    }
468                }
469            } else {
470                // Slow path: non-contiguous strides — iterate in logical row-major order,
471                // compute physical offset for each element using the actual strides.
472                let ndims = shape.len();
473                let mut indices = vec![0usize; ndims];
474
475                macro_rules! strided_copy {
476                    ($src_type:ty, $convert:expr) => {{
477                        let src = ptr.as_ptr() as *const $src_type;
478                        for logical_idx in 0..count {
479                            let physical: usize = indices.iter()
480                                .zip(strides.iter())
481                                .map(|(&i, &s)| i * s)
482                                .sum();
483                            buf[logical_idx] = $convert(*src.add(physical));
484                            // Increment indices in row-major order (last dim fastest)
485                            for d in (0..ndims).rev() {
486                                indices[d] += 1;
487                                if indices[d] < shape[d] {
488                                    break;
489                                }
490                                indices[d] = 0;
491                            }
492                        }
493                    }};
494                }
495
496                match data_type {
497                    Some(DataType::Float32) => strided_copy!(f32, |v: f32| v),
498                    Some(DataType::Float16) => strided_copy!(u16, |v: u16| f16_to_f32(v)),
499                    Some(DataType::Float64) => strided_copy!(f64, |v: f64| v as f32),
500                    Some(DataType::Int32)   => strided_copy!(i32, |v: i32| v as f32),
501                    Some(DataType::Int16)   => strided_copy!(i16, |v: i16| v as f32),
502                    Some(DataType::Int8)    => strided_copy!(i8,  |v: i8|  v as f32),
503                    Some(DataType::UInt32)  => strided_copy!(u32, |v: u32| v as f32),
504                    Some(DataType::UInt16)  => strided_copy!(u16, |v: u16| v as f32),
505                    Some(DataType::UInt8)   => strided_copy!(u8,  |v: u8|  v as f32),
506                    None => {
507                        return Err(Error::new(
508                            ErrorKind::Prediction,
509                            "unsupported output data type",
510                        ));
511                    }
512                }
513            }
514        }
515        Ok(())
516    }
517
518    #[cfg(not(target_vendor = "apple"))]
519    pub fn get_f32(&self, _name: &str) -> Result<(Vec<f32>, Vec<usize>)> {
520        Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
521    }
522
523    #[cfg(not(target_vendor = "apple"))]
524    pub fn get_f32_into(&self, _name: &str, _buf: &mut [f32]) -> Result<Vec<usize>> {
525        Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
526    }
527
528    /// Get an output as (Vec<i32>, shape). Only works if the output is Int32.
529    #[cfg(target_vendor = "apple")]
530    #[allow(deprecated)]
531    pub fn get_i32(&self, name: &str) -> Result<(Vec<i32>, Vec<usize>)> {
532        objc2::rc::autoreleasepool(|_pool| {
533            let (count, shape, data_type, array) = self.get_output_array(name)?;
534            match data_type {
535                Some(DataType::Int32) => {
536                    let mut buf = vec![0i32; count];
537                    unsafe {
538                        let ptr = array.dataPointer();
539                        let src = ptr.as_ptr() as *const i32;
540                        std::ptr::copy_nonoverlapping(src, buf.as_mut_ptr(), count);
541                    }
542                    Ok((buf, shape))
543                }
544                Some(dt) => Err(Error::new(
545                    ErrorKind::Prediction,
546                    format!("output '{name}' is {dt}, not Int32"),
547                )),
548                None => Err(Error::new(ErrorKind::Prediction, "unsupported output data type")),
549            }
550        })
551    }
552
553    /// Get an output as (Vec<f64>, shape). Only works if the output is Float64.
554    #[cfg(target_vendor = "apple")]
555    #[allow(deprecated)]
556    pub fn get_f64(&self, name: &str) -> Result<(Vec<f64>, Vec<usize>)> {
557        objc2::rc::autoreleasepool(|_pool| {
558            let (count, shape, data_type, array) = self.get_output_array(name)?;
559            match data_type {
560                Some(DataType::Float64) => {
561                    let mut buf = vec![0.0f64; count];
562                    unsafe {
563                        let ptr = array.dataPointer();
564                        let src = ptr.as_ptr() as *const f64;
565                        std::ptr::copy_nonoverlapping(src, buf.as_mut_ptr(), count);
566                    }
567                    Ok((buf, shape))
568                }
569                Some(dt) => Err(Error::new(
570                    ErrorKind::Prediction,
571                    format!("output '{name}' is {dt}, not Float64"),
572                )),
573                None => Err(Error::new(ErrorKind::Prediction, "unsupported output data type")),
574            }
575        })
576    }
577
578    /// Get an output as raw bytes and its shape + data type.
579    #[cfg(target_vendor = "apple")]
580    #[allow(deprecated)]
581    pub fn get_raw(&self, name: &str) -> Result<(Vec<u8>, Vec<usize>, Option<DataType>)> {
582        objc2::rc::autoreleasepool(|_pool| {
583            let (count, shape, data_type, array) = self.get_output_array(name)?;
584            let byte_size = data_type.map(|dt| dt.byte_size()).unwrap_or(4);
585            let total_bytes = count * byte_size;
586            let mut buf = vec![0u8; total_bytes];
587            unsafe {
588                let ptr = array.dataPointer();
589                std::ptr::copy_nonoverlapping(
590                    ptr.as_ptr() as *const u8,
591                    buf.as_mut_ptr(),
592                    total_bytes,
593                );
594            }
595            Ok((buf, shape, data_type))
596        })
597    }
598
599    #[cfg(not(target_vendor = "apple"))]
600    pub fn get_i32(&self, _name: &str) -> Result<(Vec<i32>, Vec<usize>)> {
601        Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
602    }
603
604    #[cfg(not(target_vendor = "apple"))]
605    pub fn get_f64(&self, _name: &str) -> Result<(Vec<f64>, Vec<usize>)> {
606        Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
607    }
608
609    #[cfg(not(target_vendor = "apple"))]
610    pub fn get_raw(&self, _name: &str) -> Result<(Vec<u8>, Vec<usize>, Option<DataType>)> {
611        Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
612    }
613}
614
/// Convert an IEEE 754 half-precision float, given as its raw `u16` bit
/// pattern, into the equivalent `f32`.
///
/// Handles signed zero, subnormals, infinities, and NaN (whose payload bits
/// are widened into the f32 mantissa).
#[cfg(target_vendor = "apple")]
fn f16_to_f32(bits: u16) -> f32 {
    // Decompose the half-precision fields; place the sign bit in f32 position.
    let sign32 = ((bits as u32) >> 15) << 31;
    let exponent = ((bits >> 10) & 0x1f) as u32;
    let mantissa = (bits & 0x3ff) as u32;

    let result_bits = match (exponent, mantissa) {
        // Signed zero.
        (0, 0) => sign32,
        // Subnormal: shift the mantissa up until its implicit leading bit
        // (bit 10) appears, lowering the exponent once per shift, then
        // rebias into f32's exponent field.
        (0, _) => {
            let mut m = mantissa;
            let mut adjust = 0i32;
            while m & 0x400 == 0 {
                m <<= 1;
                adjust -= 1;
            }
            let exp32 = (127 - 15 + 1 + adjust) as u32;
            sign32 | (exp32 << 23) | ((m & 0x3ff) << 13)
        }
        // Infinity (mantissa 0) and NaN (mantissa != 0) both map to an
        // all-ones f32 exponent, with the mantissa widened in place.
        (31, m) => sign32 | (0xff << 23) | (m << 13),
        // Normal number: rebias exponent from 15 to 127 (+112), widen mantissa.
        _ => sign32 | ((exponent + 112) << 23) | (mantissa << 13),
    };
    f32::from_bits(result_bits)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_compute_units_is_all() {
        assert_eq!(ComputeUnits::default(), ComputeUnits::All);
    }

    #[test]
    fn display_labels_match_expected_text() {
        assert_eq!(ComputeUnits::CpuAndGpu.to_string(), "CPU + GPU");
        assert_eq!(ComputeUnits::All.to_string(), "All (CPU + GPU + ANE)");
    }

    #[test]
    fn display_label_for_cpu_only() {
        assert_eq!(ComputeUnits::CpuOnly.to_string(), "CPU only");
    }

    #[cfg(not(target_vendor = "apple"))]
    #[test]
    fn load_is_rejected_off_apple_platforms() {
        let err = Model::load("/tmp/fake.mlmodelc", ComputeUnits::All).unwrap_err();
        assert_eq!(err.kind(), &ErrorKind::UnsupportedPlatform);
    }
}