Skip to main content

apple_mpsgraph/
execution.rs

1use crate::data::TensorData;
2use crate::error::{Error, Result};
3use crate::ffi;
4use crate::graph::{Executable, FeedDescription, Graph, Tensor};
5use crate::types::{
6    collect_owned_tensors, collect_shaped_type_array_box, collect_tensor_data_array_box, ShapedType,
7};
8use apple_metal::{CommandQueue, MetalDevice};
9use core::ffi::c_void;
10use core::ptr;
11use std::ffi::CString;
12
13fn release_handle(ptr: &mut *mut c_void) {
14    if !ptr.is_null() {
15        // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
16        unsafe { ffi::mpsgraph_object_release(*ptr) };
17        *ptr = ptr::null_mut();
18    }
19}
20
21fn copy_string(
22    len: unsafe extern "C" fn(*mut c_void) -> usize,
23    copy: unsafe extern "C" fn(*mut c_void, *mut u8, usize) -> bool,
24    handle: *mut c_void,
25) -> Result<String> {
26    // SAFETY: the function pointers belong to Swift shims that treat `handle` as immutable for the duration of the call.
27    let len = unsafe { len(handle) };
28    let mut bytes = vec![0_u8; len];
29    // SAFETY: the buffer is valid for exactly `len` bytes.
30    let ok = unsafe { copy(handle, bytes.as_mut_ptr(), len) };
31    if ok {
32        String::from_utf8(bytes)
33            .map_err(|_| Error::OperationFailed("bridge returned invalid UTF-8"))
34    } else {
35        Err(Error::OperationFailed("failed to copy string from bridge"))
36    }
37}
38
/// `MPSGraphOptions` constants.
pub mod graph_options {
    /// `MPSGraphOptions` value `NONE` (raw value 0).
    pub const NONE: u64 = 0;
    /// `MPSGraphOptions` value `SYNCHRONIZE_RESULTS` (raw value 1).
    pub const SYNCHRONIZE_RESULTS: u64 = 1;
    /// `MPSGraphOptions` value `VERBOSE` (raw value 2).
    pub const VERBOSE: u64 = 2;
    /// `MPSGraphOptions` value `DEFAULT`; an alias of [`SYNCHRONIZE_RESULTS`].
    pub const DEFAULT: u64 = SYNCHRONIZE_RESULTS;
}
50
/// `MPSGraphOptimization` constants.
pub mod optimization {
    /// `MPSGraphOptimization` value `LEVEL0` (raw value 0).
    pub const LEVEL0: u64 = 0;
    /// `MPSGraphOptimization` value `LEVEL1` (raw value 1).
    pub const LEVEL1: u64 = 1;
}
58
/// `MPSGraphOptimizationProfile` constants.
pub mod optimization_profile {
    /// `MPSGraphOptimizationProfile` value `PERFORMANCE` (raw value 0).
    pub const PERFORMANCE: u64 = 0;
    /// `MPSGraphOptimizationProfile` value `POWER_EFFICIENCY` (raw value 1).
    pub const POWER_EFFICIENCY: u64 = 1;
}
66
/// `MPSGraphReducedPrecisionFastMath` bit flags.
pub mod reduced_precision_fast_math {
    /// `MPSGraphReducedPrecisionFastMath` value `NONE`: no bits set.
    pub const NONE: usize = 0;
    /// `MPSGraphReducedPrecisionFastMath` flag
    /// `ALLOW_FP16_CONV2D_WINOGRAD_TRANSFORM_INTERMEDIATE` (bit 1, raw value 2).
    pub const ALLOW_FP16_CONV2D_WINOGRAD_TRANSFORM_INTERMEDIATE: usize = 1 << 1;
    /// `MPSGraphReducedPrecisionFastMath` flag `ALLOW_FP16_INTERMEDIATES`; an alias of
    /// [`ALLOW_FP16_CONV2D_WINOGRAD_TRANSFORM_INTERMEDIATE`].
    pub const ALLOW_FP16_INTERMEDIATES: usize = ALLOW_FP16_CONV2D_WINOGRAD_TRANSFORM_INTERMEDIATE;
    /// `MPSGraphReducedPrecisionFastMath` value `DEFAULT`; an alias of [`NONE`].
    pub const DEFAULT: usize = NONE;
}
78
/// `MPSGraphDeploymentPlatform` constants.
pub mod deployment_platform {
    /// `MPSGraphDeploymentPlatform` value `MACOS` (raw value 0).
    pub const MACOS: u64 = 0;
    /// `MPSGraphDeploymentPlatform` value `IOS` (raw value 1).
    pub const IOS: u64 = 1;
    /// `MPSGraphDeploymentPlatform` value `TVOS` (raw value 2).
    pub const TVOS: u64 = 2;
    /// `MPSGraphDeploymentPlatform` value `VISIONOS` (raw value 3).
    pub const VISIONOS: u64 = 3;
}
90
91/// Safe owner for `MPSGraphCompilationDescriptor`.
92pub struct CompilationDescriptor {
93    ptr: *mut c_void,
94}
95
96unsafe impl Send for CompilationDescriptor {}
97unsafe impl Sync for CompilationDescriptor {}
98
99impl Drop for CompilationDescriptor {
100    fn drop(&mut self) {
101        release_handle(&mut self.ptr);
102    }
103}
104
105impl CompilationDescriptor {
106/// Calls the `MPSGraph` framework counterpart for `new`.
107    #[must_use]
108    pub fn new() -> Option<Self> {
109        // SAFETY: pure constructor.
110        let ptr = unsafe { ffi::mpsgraph_compilation_descriptor_new() };
111        if ptr.is_null() {
112            None
113        } else {
114            Some(Self { ptr })
115        }
116    }
117
118    #[must_use]
119    pub(crate) const fn as_ptr(&self) -> *mut c_void {
120        self.ptr
121    }
122
123/// Calls the `MPSGraph` framework counterpart for `disable_type_inference`.
124    pub fn disable_type_inference(&self) -> Result<()> {
125        // SAFETY: `self.ptr` is a live descriptor handle.
126        let ok = unsafe { ffi::mpsgraph_compilation_descriptor_disable_type_inference(self.ptr) };
127        if ok {
128            Ok(())
129        } else {
130            Err(Error::OperationFailed("failed to disable type inference"))
131        }
132    }
133
134/// Calls the `MPSGraph` framework counterpart for `optimization_level`.
135    #[must_use]
136    pub fn optimization_level(&self) -> u64 {
137        // SAFETY: `self.ptr` is a live descriptor handle.
138        unsafe { ffi::mpsgraph_compilation_descriptor_optimization_level(self.ptr) }
139    }
140
141/// Calls the `MPSGraph` framework counterpart for `set_optimization_level`.
142    pub fn set_optimization_level(&self, value: u64) -> Result<()> {
143        // SAFETY: `self.ptr` is a live descriptor handle.
144        let ok =
145            unsafe { ffi::mpsgraph_compilation_descriptor_set_optimization_level(self.ptr, value) };
146        if ok {
147            Ok(())
148        } else {
149            Err(Error::OperationFailed("failed to set optimization level"))
150        }
151    }
152
153/// Calls the `MPSGraph` framework counterpart for `wait_for_compilation_completion`.
154    #[must_use]
155    pub fn wait_for_compilation_completion(&self) -> bool {
156        // SAFETY: `self.ptr` is a live descriptor handle.
157        unsafe { ffi::mpsgraph_compilation_descriptor_wait_for_completion(self.ptr) }
158    }
159
160/// Calls the `MPSGraph` framework counterpart for `set_wait_for_compilation_completion`.
161    pub fn set_wait_for_compilation_completion(&self, value: bool) -> Result<()> {
162        // SAFETY: `self.ptr` is a live descriptor handle.
163        let ok = unsafe {
164            ffi::mpsgraph_compilation_descriptor_set_wait_for_completion(self.ptr, value)
165        };
166        if ok {
167            Ok(())
168        } else {
169            Err(Error::OperationFailed(
170                "failed to set waitForCompilationCompletion",
171            ))
172        }
173    }
174
175/// Calls the `MPSGraph` framework counterpart for `optimization_profile`.
176    #[must_use]
177    pub fn optimization_profile(&self) -> u64 {
178        // SAFETY: `self.ptr` is a live descriptor handle.
179        unsafe { ffi::mpsgraph_compilation_descriptor_optimization_profile(self.ptr) }
180    }
181
182/// Calls the `MPSGraph` framework counterpart for `set_optimization_profile`.
183    pub fn set_optimization_profile(&self, value: u64) -> Result<()> {
184        // SAFETY: `self.ptr` is a live descriptor handle.
185        let ok = unsafe {
186            ffi::mpsgraph_compilation_descriptor_set_optimization_profile(self.ptr, value)
187        };
188        if ok {
189            Ok(())
190        } else {
191            Err(Error::OperationFailed("failed to set optimization profile"))
192        }
193    }
194
195/// Calls the `MPSGraph` framework counterpart for `reduced_precision_fast_math`.
196    #[must_use]
197    pub fn reduced_precision_fast_math(&self) -> usize {
198        // SAFETY: `self.ptr` is a live descriptor handle.
199        unsafe { ffi::mpsgraph_compilation_descriptor_reduced_precision_fast_math(self.ptr) }
200    }
201
202/// Calls the `MPSGraph` framework counterpart for `set_reduced_precision_fast_math`.
203    pub fn set_reduced_precision_fast_math(&self, value: usize) -> Result<()> {
204        // SAFETY: `self.ptr` is a live descriptor handle.
205        let ok = unsafe {
206            ffi::mpsgraph_compilation_descriptor_set_reduced_precision_fast_math(self.ptr, value)
207        };
208        if ok {
209            Ok(())
210        } else {
211            Err(Error::OperationFailed(
212                "failed to set reducedPrecisionFastMath",
213            ))
214        }
215    }
216
217/// Calls the `MPSGraph` framework counterpart for `set_callable`.
218    pub fn set_callable(&self, symbol_name: &str, executable: Option<&Executable>) -> Result<()> {
219        let symbol_name = CString::new(symbol_name)
220            .map_err(|_| Error::OperationFailed("call symbol name contained NUL"))?;
221        let executable_ptr = executable.map_or(ptr::null_mut(), Executable::as_ptr);
222        // SAFETY: all handles remain valid for the duration of the call.
223        let ok = unsafe {
224            ffi::mpsgraph_compilation_descriptor_set_callable(
225                self.ptr,
226                symbol_name.as_ptr(),
227                executable_ptr,
228            )
229        };
230        if ok {
231            Ok(())
232        } else {
233            Err(Error::OperationFailed(
234                "failed to set compilation descriptor callable",
235            ))
236        }
237    }
238}
239
240/// Safe owner for `MPSGraphExecutionDescriptor`.
241pub struct ExecutionDescriptor {
242    ptr: *mut c_void,
243}
244
245unsafe impl Send for ExecutionDescriptor {}
246unsafe impl Sync for ExecutionDescriptor {}
247
248impl Drop for ExecutionDescriptor {
249    fn drop(&mut self) {
250        release_handle(&mut self.ptr);
251    }
252}
253
254impl ExecutionDescriptor {
255/// Calls the `MPSGraph` framework counterpart for `new`.
256    #[must_use]
257    pub fn new() -> Option<Self> {
258        // SAFETY: pure constructor.
259        let ptr = unsafe { ffi::mpsgraph_execution_descriptor_new() };
260        if ptr.is_null() {
261            None
262        } else {
263            Some(Self { ptr })
264        }
265    }
266
267/// Mirrors the `MPSGraph` framework constant `fn`.
268    #[must_use]
269    pub const fn as_ptr(&self) -> *mut c_void {
270        self.ptr
271    }
272
273/// Calls the `MPSGraph` framework counterpart for `wait_until_completed`.
274    #[must_use]
275    pub fn wait_until_completed(&self) -> bool {
276        // SAFETY: `self.ptr` is a live descriptor handle.
277        unsafe { ffi::mpsgraph_execution_descriptor_wait_until_completed(self.ptr) }
278    }
279
280/// Calls the `MPSGraph` framework counterpart for `set_wait_until_completed`.
281    pub fn set_wait_until_completed(&self, value: bool) -> Result<()> {
282        // SAFETY: `self.ptr` is a live descriptor handle.
283        let ok =
284            unsafe { ffi::mpsgraph_execution_descriptor_set_wait_until_completed(self.ptr, value) };
285        if ok {
286            Ok(())
287        } else {
288            Err(Error::OperationFailed("failed to set waitUntilCompleted"))
289        }
290    }
291
292/// Calls the `MPSGraph` framework counterpart for `compilation_descriptor`.
293    #[must_use]
294    pub fn compilation_descriptor(&self) -> Option<CompilationDescriptor> {
295        // SAFETY: `self.ptr` is a live descriptor handle.
296        let ptr = unsafe { ffi::mpsgraph_execution_descriptor_compilation_descriptor(self.ptr) };
297        if ptr.is_null() {
298            None
299        } else {
300            Some(CompilationDescriptor { ptr })
301        }
302    }
303
304/// Calls the `MPSGraph` framework counterpart for `set_compilation_descriptor`.
305    pub fn set_compilation_descriptor(
306        &self,
307        descriptor: Option<&CompilationDescriptor>,
308    ) -> Result<()> {
309        let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
310        // SAFETY: all handles remain valid for the duration of the call.
311        let ok = unsafe {
312            ffi::mpsgraph_execution_descriptor_set_compilation_descriptor(self.ptr, descriptor_ptr)
313        };
314        if ok {
315            Ok(())
316        } else {
317            Err(Error::OperationFailed(
318                "failed to set compilation descriptor",
319            ))
320        }
321    }
322}
323
324/// Safe owner for `MPSGraphExecutableExecutionDescriptor`.
325pub struct ExecutableExecutionDescriptor {
326    ptr: *mut c_void,
327}
328
329unsafe impl Send for ExecutableExecutionDescriptor {}
330unsafe impl Sync for ExecutableExecutionDescriptor {}
331
332impl Drop for ExecutableExecutionDescriptor {
333    fn drop(&mut self) {
334        release_handle(&mut self.ptr);
335    }
336}
337
338impl ExecutableExecutionDescriptor {
339/// Calls the `MPSGraph` framework counterpart for `new`.
340    #[must_use]
341    pub fn new() -> Option<Self> {
342        // SAFETY: pure constructor.
343        let ptr = unsafe { ffi::mpsgraph_executable_execution_descriptor_new() };
344        if ptr.is_null() {
345            None
346        } else {
347            Some(Self { ptr })
348        }
349    }
350
351    #[must_use]
352    pub(crate) const fn as_ptr(&self) -> *mut c_void {
353        self.ptr
354    }
355
356/// Calls the `MPSGraph` framework counterpart for `wait_until_completed`.
357    #[must_use]
358    pub fn wait_until_completed(&self) -> bool {
359        // SAFETY: `self.ptr` is a live descriptor handle.
360        unsafe { ffi::mpsgraph_executable_execution_descriptor_wait_until_completed(self.ptr) }
361    }
362
363/// Calls the `MPSGraph` framework counterpart for `set_wait_until_completed`.
364    pub fn set_wait_until_completed(&self, value: bool) -> Result<()> {
365        // SAFETY: `self.ptr` is a live descriptor handle.
366        let ok = unsafe {
367            ffi::mpsgraph_executable_execution_descriptor_set_wait_until_completed(self.ptr, value)
368        };
369        if ok {
370            Ok(())
371        } else {
372            Err(Error::OperationFailed(
373                "failed to set executable waitUntilCompleted",
374            ))
375        }
376    }
377}
378
379/// Safe owner for `MPSGraphExecutableSerializationDescriptor`.
380pub struct ExecutableSerializationDescriptor {
381    ptr: *mut c_void,
382}
383
384unsafe impl Send for ExecutableSerializationDescriptor {}
385unsafe impl Sync for ExecutableSerializationDescriptor {}
386
387impl Drop for ExecutableSerializationDescriptor {
388    fn drop(&mut self) {
389        release_handle(&mut self.ptr);
390    }
391}
392
393impl ExecutableSerializationDescriptor {
394/// Calls the `MPSGraph` framework counterpart for `new`.
395    #[must_use]
396    pub fn new() -> Option<Self> {
397        // SAFETY: pure constructor.
398        let ptr = unsafe { ffi::mpsgraph_executable_serialization_descriptor_new() };
399        if ptr.is_null() {
400            None
401        } else {
402            Some(Self { ptr })
403        }
404    }
405
406    #[must_use]
407    pub(crate) const fn as_ptr(&self) -> *mut c_void {
408        self.ptr
409    }
410
411/// Calls the `MPSGraph` framework counterpart for `append`.
412    #[must_use]
413    pub fn append(&self) -> bool {
414        // SAFETY: `self.ptr` is a live descriptor handle.
415        unsafe { ffi::mpsgraph_executable_serialization_descriptor_append(self.ptr) }
416    }
417
418/// Calls the `MPSGraph` framework counterpart for `set_append`.
419    pub fn set_append(&self, value: bool) -> Result<()> {
420        // SAFETY: `self.ptr` is a live descriptor handle.
421        let ok = unsafe {
422            ffi::mpsgraph_executable_serialization_descriptor_set_append(self.ptr, value)
423        };
424        if ok {
425            Ok(())
426        } else {
427            Err(Error::OperationFailed("failed to set append"))
428        }
429    }
430
431/// Calls the `MPSGraph` framework counterpart for `deployment_platform`.
432    #[must_use]
433    pub fn deployment_platform(&self) -> u64 {
434        // SAFETY: `self.ptr` is a live descriptor handle.
435        unsafe { ffi::mpsgraph_executable_serialization_descriptor_deployment_platform(self.ptr) }
436    }
437
438/// Calls the `MPSGraph` framework counterpart for `set_deployment_platform`.
439    pub fn set_deployment_platform(&self, value: u64) -> Result<()> {
440        // SAFETY: `self.ptr` is a live descriptor handle.
441        let ok = unsafe {
442            ffi::mpsgraph_executable_serialization_descriptor_set_deployment_platform(
443                self.ptr, value,
444            )
445        };
446        if ok {
447            Ok(())
448        } else {
449            Err(Error::OperationFailed("failed to set deployment platform"))
450        }
451    }
452
453/// Calls the `MPSGraph` framework counterpart for `minimum_deployment_target`.
454    pub fn minimum_deployment_target(&self) -> Result<String> {
455        copy_string(
456            ffi::mpsgraph_executable_serialization_descriptor_minimum_deployment_target_len,
457            ffi::mpsgraph_executable_serialization_descriptor_copy_minimum_deployment_target,
458            self.ptr,
459        )
460    }
461
462/// Calls the `MPSGraph` framework counterpart for `set_minimum_deployment_target`.
463    pub fn set_minimum_deployment_target(&self, value: &str) -> Result<()> {
464        let value = CString::new(value)
465            .map_err(|_| Error::OperationFailed("minimum deployment target contained NUL"))?;
466        // SAFETY: the CString stays alive for the duration of the call.
467        let ok = unsafe {
468            ffi::mpsgraph_executable_serialization_descriptor_set_minimum_deployment_target(
469                self.ptr,
470                value.as_ptr(),
471            )
472        };
473        if ok {
474            Ok(())
475        } else {
476            Err(Error::OperationFailed(
477                "failed to set minimum deployment target",
478            ))
479        }
480    }
481}
482
483impl Graph {
484    /// Return the graph's `MPSGraphOptions` bitmask.
485    #[must_use]
486    pub fn options(&self) -> u64 {
487        // SAFETY: `self` owns a live graph handle.
488        unsafe { ffi::mpsgraph_graph_options(self.as_ptr()) }
489    }
490
491    /// Replace the graph's options bitmask.
492    pub fn set_options(&self, options: u64) -> Result<()> {
493        // SAFETY: `self` owns a live graph handle.
494        let ok = unsafe { ffi::mpsgraph_graph_set_options(self.as_ptr(), options) };
495        if ok {
496            Ok(())
497        } else {
498            Err(Error::OperationFailed("failed to set graph options"))
499        }
500    }
501
502    /// Return the graph's placeholder tensors in insertion order.
503    #[must_use]
504    pub fn placeholder_tensors(&self) -> Vec<Tensor> {
505        // SAFETY: `self` owns a live graph handle.
506        let box_handle = unsafe { ffi::mpsgraph_graph_placeholder_tensors(self.as_ptr()) };
507        collect_owned_tensors(box_handle)
508    }
509
510    /// Compile the graph with an optional compilation descriptor.
511    #[must_use]
512    pub fn compile_with_descriptor(
513        &self,
514        device: Option<&MetalDevice>,
515        feeds: &[FeedDescription<'_>],
516        targets: &[&Tensor],
517        descriptor: Option<&CompilationDescriptor>,
518    ) -> Option<Executable> {
519        let feed_tensors = feeds
520            .iter()
521            .map(|feed| feed.tensor.as_ptr())
522            .collect::<Vec<_>>();
523        let shape_lengths = feeds
524            .iter()
525            .map(|feed| feed.shape.len())
526            .collect::<Vec<_>>();
527        let data_types = feeds.iter().map(|feed| feed.data_type).collect::<Vec<_>>();
528        let flat_shapes = feeds
529            .iter()
530            .flat_map(|feed| feed.shape.iter().copied())
531            .collect::<Vec<_>>();
532        let target_tensors = targets
533            .iter()
534            .map(|tensor| tensor.as_ptr())
535            .collect::<Vec<_>>();
536        let device_ptr = device.map_or(ptr::null_mut(), MetalDevice::as_ptr);
537        let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
538
539        // SAFETY: all pointer arrays stay alive for the duration of the call.
540        let ptr = unsafe {
541            ffi::mpsgraph_graph_compile_with_descriptor(
542                self.as_ptr(),
543                device_ptr,
544                feed_tensors.as_ptr(),
545                feeds.len(),
546                flat_shapes.as_ptr(),
547                shape_lengths.as_ptr(),
548                data_types.as_ptr(),
549                target_tensors.as_ptr(),
550                targets.len(),
551                descriptor_ptr,
552            )
553        };
554        if ptr.is_null() {
555            None
556        } else {
557            Some(Executable::from_raw(ptr, targets.len()))
558        }
559    }
560}
561
562impl Executable {
563    /// Return the executable's `MPSGraphOptions` bitmask.
564    #[must_use]
565    pub fn options(&self) -> u64 {
566        // SAFETY: `self` owns a live executable handle.
567        unsafe { ffi::mpsgraph_executable_options(self.as_ptr()) }
568    }
569
570    /// Replace the executable's options bitmask.
571    pub fn set_options(&self, options: u64) -> Result<()> {
572        // SAFETY: `self` owns a live executable handle.
573        let ok = unsafe { ffi::mpsgraph_executable_set_options(self.as_ptr(), options) };
574        if ok {
575            Ok(())
576        } else {
577            Err(Error::OperationFailed("failed to set executable options"))
578        }
579    }
580
581    /// Return feed tensors if this executable was compiled from a graph.
582    #[must_use]
583    pub fn feed_tensors(&self) -> Vec<Tensor> {
584        // SAFETY: `self` owns a live executable handle.
585        let box_handle = unsafe { ffi::mpsgraph_executable_feed_tensors(self.as_ptr()) };
586        collect_owned_tensors(box_handle)
587    }
588
589    /// Return target tensors if this executable was compiled from a graph.
590    #[must_use]
591    pub fn target_tensors(&self) -> Vec<Tensor> {
592        // SAFETY: `self` owns a live executable handle.
593        let box_handle = unsafe { ffi::mpsgraph_executable_target_tensors(self.as_ptr()) };
594        collect_owned_tensors(box_handle)
595    }
596
597    /// Specialize the executable for the provided input types.
598    pub fn specialize(
599        &self,
600        device: Option<&MetalDevice>,
601        input_types: &[&ShapedType],
602        descriptor: Option<&CompilationDescriptor>,
603    ) -> Result<()> {
604        let input_type_handles = input_types
605            .iter()
606            .map(|value| value.as_ptr())
607            .collect::<Vec<_>>();
608        let device_ptr = device.map_or(ptr::null_mut(), MetalDevice::as_ptr);
609        let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
610
611        // SAFETY: all pointer arrays stay alive for the duration of the call.
612        let ok = unsafe {
613            ffi::mpsgraph_executable_specialize(
614                self.as_ptr(),
615                device_ptr,
616                input_type_handles.as_ptr(),
617                input_types.len(),
618                descriptor_ptr,
619            )
620        };
621        if ok {
622            Ok(())
623        } else {
624            Err(Error::OperationFailed("failed to specialize executable"))
625        }
626    }
627
628    /// Query specialized output types for the provided input types.
629    pub fn output_types(
630        &self,
631        device: Option<&MetalDevice>,
632        input_types: &[&ShapedType],
633        descriptor: Option<&CompilationDescriptor>,
634    ) -> Result<Vec<ShapedType>> {
635        let input_type_handles = input_types
636            .iter()
637            .map(|value| value.as_ptr())
638            .collect::<Vec<_>>();
639        let device_ptr = device.map_or(ptr::null_mut(), MetalDevice::as_ptr);
640        let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
641
642        // SAFETY: all pointer arrays stay alive for the duration of the call.
643        let box_handle = unsafe {
644            ffi::mpsgraph_executable_get_output_types(
645                self.as_ptr(),
646                device_ptr,
647                input_type_handles.as_ptr(),
648                input_types.len(),
649                descriptor_ptr,
650            )
651        };
652        if box_handle.is_null() {
653            Err(Error::OperationFailed(
654                "failed to get executable output types",
655            ))
656        } else {
657            Ok(collect_shaped_type_array_box(box_handle))
658        }
659    }
660
661    /// Run the executable with an optional execution descriptor and optional preallocated results.
662    pub fn run_with_descriptor(
663        &self,
664        command_queue: &CommandQueue,
665        inputs: &[&TensorData],
666        results: Option<&[&TensorData]>,
667        descriptor: Option<&ExecutableExecutionDescriptor>,
668    ) -> Result<Vec<TensorData>> {
669        let input_handles = inputs
670            .iter()
671            .map(|value| value.as_ptr())
672            .collect::<Vec<_>>();
673        let result_handles = results
674            .map(|values| {
675                values
676                    .iter()
677                    .map(|value| value.as_ptr())
678                    .collect::<Vec<_>>()
679            })
680            .unwrap_or_default();
681        let descriptor_ptr =
682            descriptor.map_or(ptr::null_mut(), ExecutableExecutionDescriptor::as_ptr);
683
684        // SAFETY: all pointer arrays stay alive for the duration of the call.
685        let box_handle = unsafe {
686            ffi::mpsgraph_executable_run_with_descriptor(
687                self.as_ptr(),
688                command_queue.as_ptr(),
689                input_handles.as_ptr(),
690                inputs.len(),
691                result_handles.as_ptr(),
692                result_handles.len(),
693                descriptor_ptr,
694            )
695        };
696        if box_handle.is_null() {
697            Err(Error::OperationFailed("failed to run executable"))
698        } else {
699            Ok(collect_tensor_data_array_box(box_handle))
700        }
701    }
702
703    /// Asynchronously run the executable with an optional execution descriptor.
704    pub fn run_async_with_descriptor(
705        &self,
706        command_queue: &CommandQueue,
707        inputs: &[&TensorData],
708        results: Option<&[&TensorData]>,
709        descriptor: Option<&ExecutableExecutionDescriptor>,
710    ) -> Result<Vec<TensorData>> {
711        let input_handles = inputs
712            .iter()
713            .map(|value| value.as_ptr())
714            .collect::<Vec<_>>();
715        let result_handles = results
716            .map(|values| {
717                values
718                    .iter()
719                    .map(|value| value.as_ptr())
720                    .collect::<Vec<_>>()
721            })
722            .unwrap_or_default();
723        let descriptor_ptr =
724            descriptor.map_or(ptr::null_mut(), ExecutableExecutionDescriptor::as_ptr);
725
726        // SAFETY: all pointer arrays stay alive for the duration of the call.
727        let box_handle = unsafe {
728            ffi::mpsgraph_executable_run_async_with_descriptor(
729                self.as_ptr(),
730                command_queue.as_ptr(),
731                input_handles.as_ptr(),
732                inputs.len(),
733                result_handles.as_ptr(),
734                result_handles.len(),
735                descriptor_ptr,
736            )
737        };
738        if box_handle.is_null() {
739            Err(Error::OperationFailed(
740                "failed to run executable asynchronously",
741            ))
742        } else {
743            Ok(collect_tensor_data_array_box(box_handle))
744        }
745    }
746
747    /// Serialize the executable to an `.mpsgraphpackage` path.
748    pub fn serialize_package(
749        &self,
750        path: &str,
751        descriptor: Option<&ExecutableSerializationDescriptor>,
752    ) -> Result<()> {
753        let path =
754            CString::new(path).map_err(|_| Error::OperationFailed("package path contained NUL"))?;
755        let descriptor_ptr =
756            descriptor.map_or(ptr::null_mut(), ExecutableSerializationDescriptor::as_ptr);
757        // SAFETY: the CString stays alive for the duration of the call.
758        let ok = unsafe {
759            ffi::mpsgraph_executable_serialize_package(self.as_ptr(), path.as_ptr(), descriptor_ptr)
760        };
761        if ok {
762            Ok(())
763        } else {
764            Err(Error::OperationFailed(
765                "failed to serialize executable package",
766            ))
767        }
768    }
769
770    /// Load an executable from an existing `.mpsgraphpackage`.
771    pub fn from_package(path: &str, descriptor: Option<&CompilationDescriptor>) -> Result<Self> {
772        let path =
773            CString::new(path).map_err(|_| Error::OperationFailed("package path contained NUL"))?;
774        let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
775        // SAFETY: the CString stays alive for the duration of the call.
776        let ptr =
777            unsafe { ffi::mpsgraph_executable_new_with_package(path.as_ptr(), descriptor_ptr) };
778        if ptr.is_null() {
779            return Err(Error::OperationFailed("failed to load executable package"));
780        }
781        let output_count = {
782            // SAFETY: `ptr` is a live executable handle returned just above.
783            let box_handle = unsafe { ffi::mpsgraph_executable_target_tensors(ptr) };
784            collect_owned_tensors(box_handle).len()
785        };
786        Ok(Self::from_raw(ptr, output_count))
787    }
788}