1pub mod async_bridge;
9pub mod compile;
10pub mod description;
11pub mod error;
12pub(crate) mod ffi;
13mod model_async;
14pub mod state;
15pub mod tensor;
16pub mod batch;
17pub mod compute;
18pub mod model_lifecycle;
19#[cfg(feature = "ndarray")]
20pub mod ndarray_support;
21
22pub use async_bridge::CompletionFuture;
23pub use batch::{BatchPrediction, BatchProvider};
24pub use compile::{compile_model, compile_model_async};
25pub use compute::{available_devices, ComputeDevice};
26pub use description::{FeatureDescription, FeatureType, ModelMetadata, ShapeConstraint};
27pub use error::{Error, ErrorKind, Result};
28pub use model_lifecycle::ModelHandle;
29pub use state::State;
30pub use tensor::{AsMultiArray, BorrowedTensor, DataType, OwnedTensor};
31#[cfg(feature = "ndarray")]
32pub use ndarray_support::PredictionNdarray;
33
/// Hardware Core ML is allowed to schedule a model on.
///
/// Passed to [`Model::load`]. The default is [`ComputeUnits::All`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ComputeUnits {
    /// Run on the CPU only.
    CpuOnly,
    /// Allow the CPU and the GPU.
    CpuAndGpu,
    /// Allow the CPU and the Neural Engine.
    CpuAndNeuralEngine,
    /// Allow every compute unit (CPU, GPU and Neural Engine).
    #[default]
    All,
}

impl std::fmt::Display for ComputeUnits {
    /// Writes a short human-readable label, e.g. `CPU + GPU`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            ComputeUnits::CpuOnly => "CPU only",
            ComputeUnits::CpuAndGpu => "CPU + GPU",
            ComputeUnits::CpuAndNeuralEngine => "CPU + Neural Engine",
            ComputeUnits::All => "All (CPU + GPU + ANE)",
        };
        f.write_str(label)
    }
}
61
/// A Core ML model obtained from [`Model::load`].
///
/// Holds the underlying `MLModel` object together with the path it was loaded
/// from.
#[cfg(target_vendor = "apple")]
pub struct Model {
    // The Objective-C model object that predictions are issued against.
    inner: objc2::rc::Retained<objc2_core_ml::MLModel>,
    // The path handed to `Model::load`; returned by `Model::path`.
    path: std::path::PathBuf,
}

// Hand-written `Debug`: only `path` is printed; the wrapped `MLModel` is omitted.
#[cfg(target_vendor = "apple")]
impl std::fmt::Debug for Model {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Model").field("path", &self.path).finish()
    }
}

/// Placeholder `Model` for non-Apple targets.
///
/// The private field keeps it from being constructed outside this crate;
/// [`Model::load`] returns an `UnsupportedPlatform` error on these targets.
#[cfg(not(target_vendor = "apple"))]
#[derive(Debug)]
pub struct Model {
    _private: (),
}

// SAFETY(review): these impls assert that the wrapped `MLModel` may be moved
// to and shared between threads. They are needed presumably because the objc2
// binding does not mark `MLModel` as `Send`/`Sync` itself, so soundness rests on
// Core ML's own thread-safety guarantees for `MLModel` — confirm against
// Apple's documentation before relying on concurrent `predict` calls.
#[cfg(target_vendor = "apple")]
unsafe impl Send for Model {}
#[cfg(target_vendor = "apple")]
unsafe impl Sync for Model {}
89
90impl Model {
91 #[cfg(target_vendor = "apple")]
92 pub fn load(path: impl AsRef<std::path::Path>, compute_units: ComputeUnits) -> Result<Self> {
93 use objc2_core_ml::{MLComputeUnits, MLModel, MLModelConfiguration};
94
95 let path = path.as_ref();
96 let path_str = path.to_str().ok_or_else(|| {
97 Error::new(ErrorKind::ModelLoad, "path contains non-UTF8 characters")
98 })?;
99
100 let url = objc2_foundation::NSURL::fileURLWithPath(&ffi::str_to_nsstring(path_str));
101 let config = unsafe { MLModelConfiguration::new() };
102 let ml_units = match compute_units {
103 ComputeUnits::CpuOnly => MLComputeUnits(1),
104 ComputeUnits::CpuAndGpu => MLComputeUnits::CPUAndGPU,
105 ComputeUnits::CpuAndNeuralEngine => MLComputeUnits(2),
106 ComputeUnits::All => MLComputeUnits::All,
107 };
108 unsafe { config.setComputeUnits(ml_units) };
109
110 let inner = unsafe { MLModel::modelWithContentsOfURL_configuration_error(&url, &config) }
111 .map_err(|e| Error::from_nserror(ErrorKind::ModelLoad, &e))?;
112
113 Ok(Self { inner, path: path.to_path_buf() })
114 }
115
116 #[cfg(not(target_vendor = "apple"))]
117 pub fn load(_path: impl AsRef<std::path::Path>, _compute_units: ComputeUnits) -> Result<Self> {
118 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
119 }
120
121 pub fn path(&self) -> &std::path::Path {
123 #[cfg(target_vendor = "apple")]
124 { &self.path }
125 #[cfg(not(target_vendor = "apple"))]
126 { std::path::Path::new("") }
127 }
128
129 #[cfg(target_vendor = "apple")]
133 pub fn predict(&self, inputs: &[(&str, &dyn AsMultiArray)]) -> Result<Prediction> {
134 use objc2::AnyThread;
135 use objc2_core_ml::{MLDictionaryFeatureProvider, MLFeatureProvider, MLFeatureValue};
136 use objc2_foundation::{NSDictionary, NSString};
137
138 objc2::rc::autoreleasepool(|_pool| {
139 let mut keys: Vec<objc2::rc::Retained<NSString>> = Vec::with_capacity(inputs.len());
140 let mut vals: Vec<objc2::rc::Retained<MLFeatureValue>> = Vec::with_capacity(inputs.len());
141
142 for &(name, tensor) in inputs {
143 keys.push(ffi::str_to_nsstring(name));
144 vals.push(unsafe { MLFeatureValue::featureValueWithMultiArray(tensor.as_ml_multi_array()) });
145 }
146
147 let key_refs: Vec<&NSString> = keys.iter().map(|k| &**k).collect();
148 let val_refs: Vec<&MLFeatureValue> = vals.iter().map(|v| &**v).collect();
149
150 let dict: objc2::rc::Retained<NSDictionary<NSString, MLFeatureValue>> =
151 NSDictionary::from_slices(&key_refs, &val_refs);
152
153 let dict_any: &NSDictionary<NSString, objc2::runtime::AnyObject> =
154 unsafe { &*((&*dict) as *const NSDictionary<NSString, MLFeatureValue>
155 as *const NSDictionary<NSString, objc2::runtime::AnyObject>) };
156
157 let provider = unsafe {
158 MLDictionaryFeatureProvider::initWithDictionary_error(
159 MLDictionaryFeatureProvider::alloc(),
160 dict_any,
161 )
162 }
163 .map_err(|e| Error::from_nserror(ErrorKind::Prediction, &e))?;
164
165 let provider_ref: &objc2::runtime::ProtocolObject<dyn MLFeatureProvider> =
166 objc2::runtime::ProtocolObject::from_ref(&*provider);
167
168 let result = unsafe { self.inner.predictionFromFeatures_error(provider_ref) }
169 .map_err(|e| Error::from_nserror(ErrorKind::Prediction, &e))?;
170
171 Ok(Prediction { inner: result })
172 })
173 }
174
175 #[cfg(not(target_vendor = "apple"))]
176 pub fn predict(&self, _inputs: &[(&str, &dyn AsMultiArray)]) -> Result<Prediction> {
177 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
178 }
179
180 #[cfg(target_vendor = "apple")]
182 pub fn inputs(&self) -> Vec<FeatureDescription> {
183 let desc = unsafe { self.inner.modelDescription() };
184 let input_map = unsafe { desc.inputDescriptionsByName() };
185 description::extract_features(&input_map)
186 }
187
188 #[cfg(target_vendor = "apple")]
190 pub fn outputs(&self) -> Vec<FeatureDescription> {
191 let desc = unsafe { self.inner.modelDescription() };
192 let output_map = unsafe { desc.outputDescriptionsByName() };
193 description::extract_features(&output_map)
194 }
195
196 #[cfg(target_vendor = "apple")]
198 pub fn metadata(&self) -> ModelMetadata {
199 let desc = unsafe { self.inner.modelDescription() };
200 description::extract_metadata(&desc)
201 }
202
203 #[cfg(target_vendor = "apple")]
205 pub fn new_state(&self) -> Result<State> {
206 let inner = unsafe { self.inner.newState() };
207 Ok(State { inner })
208 }
209
210 #[cfg(not(target_vendor = "apple"))]
211 pub fn new_state(&self) -> Result<State> {
212 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
213 }
214
215 #[cfg(target_vendor = "apple")]
217 pub fn predict_stateful(
218 &self,
219 inputs: &[(&str, &dyn AsMultiArray)],
220 state: &State,
221 ) -> Result<Prediction> {
222 use objc2::AnyThread;
223 use objc2_core_ml::{MLDictionaryFeatureProvider, MLFeatureProvider, MLFeatureValue};
224 use objc2_foundation::{NSDictionary, NSString};
225
226 objc2::rc::autoreleasepool(|_pool| {
227 let mut keys: Vec<objc2::rc::Retained<NSString>> = Vec::with_capacity(inputs.len());
228 let mut vals: Vec<objc2::rc::Retained<MLFeatureValue>> = Vec::with_capacity(inputs.len());
229
230 for &(name, tensor) in inputs {
231 keys.push(ffi::str_to_nsstring(name));
232 vals.push(unsafe { MLFeatureValue::featureValueWithMultiArray(tensor.as_ml_multi_array()) });
233 }
234
235 let key_refs: Vec<&NSString> = keys.iter().map(|k| &**k).collect();
236 let val_refs: Vec<&MLFeatureValue> = vals.iter().map(|v| &**v).collect();
237 let dict: objc2::rc::Retained<NSDictionary<NSString, MLFeatureValue>> =
238 NSDictionary::from_slices(&key_refs, &val_refs);
239 let dict_any: &NSDictionary<NSString, objc2::runtime::AnyObject> =
240 unsafe { &*((&*dict) as *const NSDictionary<NSString, MLFeatureValue>
241 as *const NSDictionary<NSString, objc2::runtime::AnyObject>) };
242
243 let provider = unsafe {
244 MLDictionaryFeatureProvider::initWithDictionary_error(
245 MLDictionaryFeatureProvider::alloc(), dict_any,
246 )
247 }
248 .map_err(|e| Error::from_nserror(ErrorKind::Prediction, &e))?;
249
250 let provider_ref: &objc2::runtime::ProtocolObject<dyn MLFeatureProvider> =
251 objc2::runtime::ProtocolObject::from_ref(&*provider);
252
253 let result = unsafe {
254 self.inner.predictionFromFeatures_usingState_error(provider_ref, &state.inner)
255 }
256 .map_err(|e| Error::from_nserror(ErrorKind::Prediction, &e))?;
257
258 Ok(Prediction { inner: result })
259 })
260 }
261
262 #[cfg(not(target_vendor = "apple"))]
263 pub fn predict_stateful(
264 &self,
265 _inputs: &[(&str, &dyn AsMultiArray)],
266 _state: &State,
267 ) -> Result<Prediction> {
268 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
269 }
270
271 #[cfg(target_vendor = "apple")]
275 pub fn predict_batch(&self, batch: &batch::BatchProvider) -> Result<batch::BatchPrediction> {
276 use objc2_core_ml::MLBatchProvider;
277
278 let batch_ref: &objc2::runtime::ProtocolObject<dyn MLBatchProvider> =
279 objc2::runtime::ProtocolObject::from_ref(&*batch.inner);
280
281 let result = unsafe { self.inner.predictionsFromBatch_error(batch_ref) }
282 .map_err(|e| Error::from_nserror(ErrorKind::Prediction, &e))?;
283
284 Ok(batch::BatchPrediction { inner: result })
285 }
286
287 #[cfg(not(target_vendor = "apple"))]
288 pub fn predict_batch(&self, _batch: &batch::BatchProvider) -> Result<batch::BatchPrediction> {
289 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
290 }
291
292 #[cfg(not(target_vendor = "apple"))]
293 pub fn inputs(&self) -> Vec<FeatureDescription> { vec![] }
294
295 #[cfg(not(target_vendor = "apple"))]
296 pub fn outputs(&self) -> Vec<FeatureDescription> { vec![] }
297
298 #[cfg(not(target_vendor = "apple"))]
299 pub fn metadata(&self) -> ModelMetadata { ModelMetadata::default() }
300}
301
/// The outputs of one prediction, returned by [`Model::predict`] and
/// [`Model::predict_stateful`].
///
/// Read individual outputs by name with the `get_*` accessors.
#[cfg(target_vendor = "apple")]
pub struct Prediction {
    // Core ML's feature provider holding the named output values.
    inner: objc2::rc::Retained<objc2::runtime::ProtocolObject<dyn objc2_core_ml::MLFeatureProvider>>,
}

/// Placeholder `Prediction` for non-Apple targets; never produced there, since
/// the prediction methods return `UnsupportedPlatform`.
#[cfg(not(target_vendor = "apple"))]
pub struct Prediction {
    _private: (),
}

// SAFETY(review): these impls assert the wrapped feature provider may be moved
// to and shared between threads. The object comes back from Core ML's
// prediction call; its thread-safety is not established by anything visible
// here — confirm it is effectively immutable/thread-safe before relying on it.
#[cfg(target_vendor = "apple")]
unsafe impl Send for Prediction {}
#[cfg(target_vendor = "apple")]
unsafe impl Sync for Prediction {}
320
321impl Prediction {
322 #[cfg(target_vendor = "apple")]
325 #[allow(deprecated)]
326 pub fn get_f32(&self, name: &str) -> Result<(Vec<f32>, Vec<usize>)> {
327 objc2::rc::autoreleasepool(|_pool| {
328 let (count, shape, data_type, array) = self.get_output_array(name)?;
329 let mut buf = vec![0.0f32; count];
330 Self::copy_array_to_f32(&array, data_type, count, &mut buf)?;
331 Ok((buf, shape))
332 })
333 }
334
335 #[cfg(target_vendor = "apple")]
339 #[allow(deprecated)]
340 pub fn get_f32_into(&self, name: &str, buf: &mut [f32]) -> Result<Vec<usize>> {
341 objc2::rc::autoreleasepool(|_pool| {
342 let (count, shape, data_type, array) = self.get_output_array(name)?;
343 if buf.len() < count {
344 return Err(Error::new(
345 ErrorKind::InvalidShape,
346 format!("buffer length {} < output element count {count}", buf.len()),
347 ));
348 }
349 Self::copy_array_to_f32(&array, data_type, count, buf)?;
350 Ok(shape)
351 })
352 }
353
354 #[cfg(target_vendor = "apple")]
355 #[allow(deprecated)]
356 #[allow(clippy::type_complexity)]
357 fn get_output_array(
358 &self,
359 name: &str,
360 ) -> Result<(
361 usize,
362 Vec<usize>,
363 Option<DataType>,
364 objc2::rc::Retained<objc2_core_ml::MLMultiArray>,
365 )> {
366 use objc2_core_ml::MLFeatureProvider;
367
368 let ns_name = ffi::str_to_nsstring(name);
369 let feature_val = unsafe { self.inner.featureValueForName(&ns_name) }.ok_or_else(|| {
370 Error::new(ErrorKind::Prediction, format!("output '{name}' not found"))
371 })?;
372
373 let array = unsafe { feature_val.multiArrayValue() }.ok_or_else(|| {
374 Error::new(ErrorKind::Prediction, format!("output '{name}' is not a multi-array"))
375 })?;
376
377 let shape = ffi::nsarray_to_shape(unsafe { &array.shape() });
378 let count = tensor::element_count(&shape);
379 let dt_raw = unsafe { array.dataType() };
380 let data_type = ffi::ml_to_datatype(dt_raw.0);
381
382 Ok((count, shape, data_type, array))
383 }
384
385 #[cfg(target_vendor = "apple")]
392 #[allow(deprecated)]
393 #[allow(clippy::needless_range_loop)]
394 fn copy_array_to_f32(
395 array: &objc2_core_ml::MLMultiArray,
396 data_type: Option<DataType>,
397 count: usize,
398 buf: &mut [f32],
399 ) -> Result<()> {
400 unsafe {
401 let ptr = array.dataPointer();
402 let shape = ffi::nsarray_to_shape(&array.shape());
403 let strides = ffi::nsarray_to_shape(&array.strides());
404 let row_major_strides = tensor::compute_strides(&shape);
405 let is_contiguous = strides == row_major_strides;
406
407 if is_contiguous {
408 match data_type {
410 Some(DataType::Float32) => {
411 let src = ptr.as_ptr() as *const f32;
412 std::ptr::copy_nonoverlapping(src, buf.as_mut_ptr(), count);
413 }
414 Some(DataType::Float16) => {
415 let src = ptr.as_ptr() as *const u16;
416 for i in 0..count {
417 buf[i] = f16_to_f32(*src.add(i));
418 }
419 }
420 Some(DataType::Float64) => {
421 let src = ptr.as_ptr() as *const f64;
422 for i in 0..count {
423 buf[i] = *src.add(i) as f32;
424 }
425 }
426 Some(DataType::Int32) => {
427 let src = ptr.as_ptr() as *const i32;
428 for i in 0..count {
429 buf[i] = *src.add(i) as f32;
430 }
431 }
432 Some(DataType::Int16) => {
433 let src = ptr.as_ptr() as *const i16;
434 for i in 0..count {
435 buf[i] = *src.add(i) as f32;
436 }
437 }
438 Some(DataType::Int8) => {
439 let src = ptr.as_ptr() as *const i8;
440 for i in 0..count {
441 buf[i] = *src.add(i) as f32;
442 }
443 }
444 Some(DataType::UInt32) => {
445 let src = ptr.as_ptr() as *const u32;
446 for i in 0..count {
447 buf[i] = *src.add(i) as f32;
448 }
449 }
450 Some(DataType::UInt16) => {
451 let src = ptr.as_ptr() as *const u16;
452 for i in 0..count {
453 buf[i] = *src.add(i) as f32;
454 }
455 }
456 Some(DataType::UInt8) => {
457 let src = ptr.as_ptr() as *const u8;
458 for i in 0..count {
459 buf[i] = *src.add(i) as f32;
460 }
461 }
462 None => {
463 return Err(Error::new(
464 ErrorKind::Prediction,
465 "unsupported output data type",
466 ));
467 }
468 }
469 } else {
470 let ndims = shape.len();
473 let mut indices = vec![0usize; ndims];
474
475 macro_rules! strided_copy {
476 ($src_type:ty, $convert:expr) => {{
477 let src = ptr.as_ptr() as *const $src_type;
478 for logical_idx in 0..count {
479 let physical: usize = indices.iter()
480 .zip(strides.iter())
481 .map(|(&i, &s)| i * s)
482 .sum();
483 buf[logical_idx] = $convert(*src.add(physical));
484 for d in (0..ndims).rev() {
486 indices[d] += 1;
487 if indices[d] < shape[d] {
488 break;
489 }
490 indices[d] = 0;
491 }
492 }
493 }};
494 }
495
496 match data_type {
497 Some(DataType::Float32) => strided_copy!(f32, |v: f32| v),
498 Some(DataType::Float16) => strided_copy!(u16, |v: u16| f16_to_f32(v)),
499 Some(DataType::Float64) => strided_copy!(f64, |v: f64| v as f32),
500 Some(DataType::Int32) => strided_copy!(i32, |v: i32| v as f32),
501 Some(DataType::Int16) => strided_copy!(i16, |v: i16| v as f32),
502 Some(DataType::Int8) => strided_copy!(i8, |v: i8| v as f32),
503 Some(DataType::UInt32) => strided_copy!(u32, |v: u32| v as f32),
504 Some(DataType::UInt16) => strided_copy!(u16, |v: u16| v as f32),
505 Some(DataType::UInt8) => strided_copy!(u8, |v: u8| v as f32),
506 None => {
507 return Err(Error::new(
508 ErrorKind::Prediction,
509 "unsupported output data type",
510 ));
511 }
512 }
513 }
514 }
515 Ok(())
516 }
517
518 #[cfg(not(target_vendor = "apple"))]
519 pub fn get_f32(&self, _name: &str) -> Result<(Vec<f32>, Vec<usize>)> {
520 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
521 }
522
523 #[cfg(not(target_vendor = "apple"))]
524 pub fn get_f32_into(&self, _name: &str, _buf: &mut [f32]) -> Result<Vec<usize>> {
525 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
526 }
527
528 #[cfg(target_vendor = "apple")]
530 #[allow(deprecated)]
531 pub fn get_i32(&self, name: &str) -> Result<(Vec<i32>, Vec<usize>)> {
532 objc2::rc::autoreleasepool(|_pool| {
533 let (count, shape, data_type, array) = self.get_output_array(name)?;
534 match data_type {
535 Some(DataType::Int32) => {
536 let mut buf = vec![0i32; count];
537 unsafe {
538 let ptr = array.dataPointer();
539 let src = ptr.as_ptr() as *const i32;
540 std::ptr::copy_nonoverlapping(src, buf.as_mut_ptr(), count);
541 }
542 Ok((buf, shape))
543 }
544 Some(dt) => Err(Error::new(
545 ErrorKind::Prediction,
546 format!("output '{name}' is {dt}, not Int32"),
547 )),
548 None => Err(Error::new(ErrorKind::Prediction, "unsupported output data type")),
549 }
550 })
551 }
552
553 #[cfg(target_vendor = "apple")]
555 #[allow(deprecated)]
556 pub fn get_f64(&self, name: &str) -> Result<(Vec<f64>, Vec<usize>)> {
557 objc2::rc::autoreleasepool(|_pool| {
558 let (count, shape, data_type, array) = self.get_output_array(name)?;
559 match data_type {
560 Some(DataType::Float64) => {
561 let mut buf = vec![0.0f64; count];
562 unsafe {
563 let ptr = array.dataPointer();
564 let src = ptr.as_ptr() as *const f64;
565 std::ptr::copy_nonoverlapping(src, buf.as_mut_ptr(), count);
566 }
567 Ok((buf, shape))
568 }
569 Some(dt) => Err(Error::new(
570 ErrorKind::Prediction,
571 format!("output '{name}' is {dt}, not Float64"),
572 )),
573 None => Err(Error::new(ErrorKind::Prediction, "unsupported output data type")),
574 }
575 })
576 }
577
578 #[cfg(target_vendor = "apple")]
580 #[allow(deprecated)]
581 pub fn get_raw(&self, name: &str) -> Result<(Vec<u8>, Vec<usize>, Option<DataType>)> {
582 objc2::rc::autoreleasepool(|_pool| {
583 let (count, shape, data_type, array) = self.get_output_array(name)?;
584 let byte_size = data_type.map(|dt| dt.byte_size()).unwrap_or(4);
585 let total_bytes = count * byte_size;
586 let mut buf = vec![0u8; total_bytes];
587 unsafe {
588 let ptr = array.dataPointer();
589 std::ptr::copy_nonoverlapping(
590 ptr.as_ptr() as *const u8,
591 buf.as_mut_ptr(),
592 total_bytes,
593 );
594 }
595 Ok((buf, shape, data_type))
596 })
597 }
598
599 #[cfg(not(target_vendor = "apple"))]
600 pub fn get_i32(&self, _name: &str) -> Result<(Vec<i32>, Vec<usize>)> {
601 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
602 }
603
604 #[cfg(not(target_vendor = "apple"))]
605 pub fn get_f64(&self, _name: &str) -> Result<(Vec<f64>, Vec<usize>)> {
606 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
607 }
608
609 #[cfg(not(target_vendor = "apple"))]
610 pub fn get_raw(&self, _name: &str) -> Result<(Vec<u8>, Vec<usize>, Option<DataType>)> {
611 Err(Error::new(ErrorKind::UnsupportedPlatform, "CoreML requires Apple platform"))
612 }
613}
614
/// Converts an IEEE 754 binary16 (half-precision) bit pattern to `f32`.
///
/// Every binary16 value is exactly representable as binary32, so the
/// conversion is lossless: signed zeros, subnormals, infinities and NaN
/// payloads are all preserved.
///
/// This is pure bit arithmetic with no Core ML dependency, so it is compiled on
/// every target (which lets it be unit-tested off Apple hardware). Outside
/// Apple targets nothing calls it, hence the `dead_code` allowance.
#[cfg_attr(not(target_vendor = "apple"), allow(dead_code))]
fn f16_to_f32(bits: u16) -> f32 {
    let sign = u32::from(bits >> 15) << 31;
    let exp = u32::from((bits >> 10) & 0x1f);
    let frac = u32::from(bits & 0x3ff);

    let magnitude = match (exp, frac) {
        // ±0.
        (0, 0) => 0,
        // Subnormal: value = frac * 2^-24. Shift the highest set bit of `frac`
        // up to bit 10 (the implicit leading 1) and lower the exponent to match.
        // `frac` < 2^10, so `leading_zeros` is in 22..=31 and `shift` in 1..=10.
        (0, _) => {
            let shift = frac.leading_zeros() - 21;
            let exp32 = (127 - 15 + 1) - shift;
            (exp32 << 23) | (((frac << shift) & 0x3ff) << 13)
        }
        // Infinity (frac == 0) or NaN (payload carried over).
        (31, _) => (0xff << 23) | (frac << 13),
        // Normal: re-bias the exponent from 15 to 127.
        _ => ((exp + (127 - 15)) << 23) | (frac << 13),
    };

    f32::from_bits(sign | magnitude)
}
647
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn compute_units_default_is_all() {
        assert_eq!(ComputeUnits::default(), ComputeUnits::All);
    }

    #[test]
    fn compute_units_display() {
        assert_eq!(format!("{}", ComputeUnits::CpuAndGpu), "CPU + GPU");
        assert_eq!(format!("{}", ComputeUnits::All), "All (CPU + GPU + ANE)");
    }

    #[test]
    fn compute_units_display_cpu_only() {
        assert_eq!(format!("{}", ComputeUnits::CpuOnly), "CPU only");
    }

    #[test]
    fn compute_units_display_cpu_and_neural_engine() {
        assert_eq!(format!("{}", ComputeUnits::CpuAndNeuralEngine), "CPU + Neural Engine");
    }

    // Covers zero, subnormal, normal, infinity and NaN encodings of binary16.
    #[cfg(target_vendor = "apple")]
    #[test]
    fn f16_to_f32_handles_every_class() {
        assert_eq!(f16_to_f32(0x0000).to_bits(), 0x0000_0000);
        assert_eq!(f16_to_f32(0x8000).to_bits(), 0x8000_0000);
        assert_eq!(f16_to_f32(0x3c00), 1.0);
        assert_eq!(f16_to_f32(0xc000), -2.0);
        assert_eq!(f16_to_f32(0x7bff), 65504.0);
        assert_eq!(f16_to_f32(0x0400), 1.0 / 16384.0);
        assert_eq!(f16_to_f32(0x0001), 1.0 / 16_777_216.0);
        assert_eq!(f16_to_f32(0x03ff), 1023.0 / 16_777_216.0);
        assert_eq!(f16_to_f32(0x7c00), f32::INFINITY);
        assert_eq!(f16_to_f32(0xfc00), f32::NEG_INFINITY);
        assert!(f16_to_f32(0x7e00).is_nan());
    }

    #[cfg(not(target_vendor = "apple"))]
    #[test]
    fn model_load_fails_on_non_apple() {
        let err = Model::load("/tmp/fake.mlmodelc", ComputeUnits::All).unwrap_err();
        assert_eq!(err.kind(), &ErrorKind::UnsupportedPlatform);
    }

    #[cfg(not(target_vendor = "apple"))]
    #[test]
    fn model_stubs_on_non_apple() {
        let model = Model { _private: () };
        let err = model.predict(&[]).err().expect("predict must fail off Apple");
        assert_eq!(err.kind(), &ErrorKind::UnsupportedPlatform);
        let err = model.new_state().err().expect("new_state must fail off Apple");
        assert_eq!(err.kind(), &ErrorKind::UnsupportedPlatform);
        assert!(model.inputs().is_empty());
        assert!(model.outputs().is_empty());
        assert_eq!(model.path(), std::path::Path::new(""));
    }
}