Skip to main content

apple_mpsgraph/
execution.rs

1use crate::data::TensorData;
2use crate::error::{Error, Result};
3use crate::ffi;
4use crate::graph::{Executable, FeedDescription, Graph, Tensor};
5use crate::types::{
6    collect_owned_tensors, collect_shaped_type_array_box, collect_tensor_data_array_box, ShapedType,
7};
8use apple_metal::{CommandQueue, MetalDevice};
9use core::ffi::c_void;
10use core::ptr;
11use std::ffi::CString;
12
13fn release_handle(ptr: &mut *mut c_void) {
14    if !ptr.is_null() {
15        // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
16        unsafe { ffi::mpsgraph_object_release(*ptr) };
17        *ptr = ptr::null_mut();
18    }
19}
20
21fn copy_string(
22    len: unsafe extern "C" fn(*mut c_void) -> usize,
23    copy: unsafe extern "C" fn(*mut c_void, *mut u8, usize) -> bool,
24    handle: *mut c_void,
25) -> Result<String> {
26    // SAFETY: the function pointers belong to Swift shims that treat `handle` as immutable for the duration of the call.
27    let len = unsafe { len(handle) };
28    let mut bytes = vec![0_u8; len];
29    // SAFETY: the buffer is valid for exactly `len` bytes.
30    let ok = unsafe { copy(handle, bytes.as_mut_ptr(), len) };
31    if ok {
32        String::from_utf8(bytes)
33            .map_err(|_| Error::OperationFailed("bridge returned invalid UTF-8"))
34    } else {
35        Err(Error::OperationFailed("failed to copy string from bridge"))
36    }
37}
38
/// `MPSGraphOptions` constants.
///
/// The values form a bitmask, as read and written by the `options` / `set_options`
/// accessors on `Graph` and `Executable`.
pub mod graph_options {
    /// No option bits set.
    pub const NONE: u64 = 0;
    /// Bit 0 (`MPSGraphOptionsSynchronizeResults`).
    pub const SYNCHRONIZE_RESULTS: u64 = 0b01;
    /// Bit 1 (`MPSGraphOptionsVerbose`).
    pub const VERBOSE: u64 = 0b10;
    /// Default option set; identical to `SYNCHRONIZE_RESULTS`.
    pub const DEFAULT: u64 = SYNCHRONIZE_RESULTS;
}
46
/// `MPSGraphOptimization` constants.
///
/// Intended as the `u64` value for `CompilationDescriptor::set_optimization_level`.
pub mod optimization {
    /// Optimization level 0 (raw value `0`).
    pub const LEVEL0: u64 = 0;
    /// Optimization level 1 (raw value `1`).
    pub const LEVEL1: u64 = 1;
}
52
/// `MPSGraphOptimizationProfile` constants.
///
/// Intended as the `u64` value for `CompilationDescriptor::set_optimization_profile`.
pub mod optimization_profile {
    /// Performance profile (raw value `0`).
    pub const PERFORMANCE: u64 = 0;
    /// Power-efficiency profile (raw value `1`).
    pub const POWER_EFFICIENCY: u64 = 1;
}
58
/// `MPSGraphReducedPrecisionFastMath` bit flags.
///
/// Intended as the `usize` value for
/// `CompilationDescriptor::set_reduced_precision_fast_math`.
pub mod reduced_precision_fast_math {
    /// No flags set.
    pub const NONE: usize = 0;
    /// Bit 1; mirrors
    /// `MPSGraphReducedPrecisionFastMathAllowFP16Conv2DWinogradTransformIntermediate`.
    pub const ALLOW_FP16_CONV2D_WINOGRAD_TRANSFORM_INTERMEDIATE: usize = 0b10;
    /// Alias of `ALLOW_FP16_CONV2D_WINOGRAD_TRANSFORM_INTERMEDIATE` (same bit).
    pub const ALLOW_FP16_INTERMEDIATES: usize = ALLOW_FP16_CONV2D_WINOGRAD_TRANSFORM_INTERMEDIATE;
    /// Default flag set; identical to `NONE`.
    pub const DEFAULT: usize = NONE;
}
66
/// `MPSGraphDeploymentPlatform` constants.
///
/// Intended as the `u64` value for
/// `ExecutableSerializationDescriptor::set_deployment_platform`.
pub mod deployment_platform {
    /// macOS (raw value `0`).
    pub const MACOS: u64 = 0;
    /// iOS (raw value `1`).
    pub const IOS: u64 = 1;
    /// tvOS (raw value `2`).
    pub const TVOS: u64 = 2;
    /// visionOS (raw value `3`).
    pub const VISIONOS: u64 = 3;
}
74
75/// Safe owner for `MPSGraphCompilationDescriptor`.
76pub struct CompilationDescriptor {
77    ptr: *mut c_void,
78}
79
80unsafe impl Send for CompilationDescriptor {}
81unsafe impl Sync for CompilationDescriptor {}
82
83impl Drop for CompilationDescriptor {
84    fn drop(&mut self) {
85        release_handle(&mut self.ptr);
86    }
87}
88
89impl CompilationDescriptor {
90    #[must_use]
91    pub fn new() -> Option<Self> {
92        // SAFETY: pure constructor.
93        let ptr = unsafe { ffi::mpsgraph_compilation_descriptor_new() };
94        if ptr.is_null() {
95            None
96        } else {
97            Some(Self { ptr })
98        }
99    }
100
101    #[must_use]
102    pub(crate) const fn as_ptr(&self) -> *mut c_void {
103        self.ptr
104    }
105
106    pub fn disable_type_inference(&self) -> Result<()> {
107        // SAFETY: `self.ptr` is a live descriptor handle.
108        let ok = unsafe { ffi::mpsgraph_compilation_descriptor_disable_type_inference(self.ptr) };
109        if ok {
110            Ok(())
111        } else {
112            Err(Error::OperationFailed("failed to disable type inference"))
113        }
114    }
115
116    #[must_use]
117    pub fn optimization_level(&self) -> u64 {
118        // SAFETY: `self.ptr` is a live descriptor handle.
119        unsafe { ffi::mpsgraph_compilation_descriptor_optimization_level(self.ptr) }
120    }
121
122    pub fn set_optimization_level(&self, value: u64) -> Result<()> {
123        // SAFETY: `self.ptr` is a live descriptor handle.
124        let ok =
125            unsafe { ffi::mpsgraph_compilation_descriptor_set_optimization_level(self.ptr, value) };
126        if ok {
127            Ok(())
128        } else {
129            Err(Error::OperationFailed("failed to set optimization level"))
130        }
131    }
132
133    #[must_use]
134    pub fn wait_for_compilation_completion(&self) -> bool {
135        // SAFETY: `self.ptr` is a live descriptor handle.
136        unsafe { ffi::mpsgraph_compilation_descriptor_wait_for_completion(self.ptr) }
137    }
138
139    pub fn set_wait_for_compilation_completion(&self, value: bool) -> Result<()> {
140        // SAFETY: `self.ptr` is a live descriptor handle.
141        let ok = unsafe {
142            ffi::mpsgraph_compilation_descriptor_set_wait_for_completion(self.ptr, value)
143        };
144        if ok {
145            Ok(())
146        } else {
147            Err(Error::OperationFailed(
148                "failed to set waitForCompilationCompletion",
149            ))
150        }
151    }
152
153    #[must_use]
154    pub fn optimization_profile(&self) -> u64 {
155        // SAFETY: `self.ptr` is a live descriptor handle.
156        unsafe { ffi::mpsgraph_compilation_descriptor_optimization_profile(self.ptr) }
157    }
158
159    pub fn set_optimization_profile(&self, value: u64) -> Result<()> {
160        // SAFETY: `self.ptr` is a live descriptor handle.
161        let ok = unsafe {
162            ffi::mpsgraph_compilation_descriptor_set_optimization_profile(self.ptr, value)
163        };
164        if ok {
165            Ok(())
166        } else {
167            Err(Error::OperationFailed("failed to set optimization profile"))
168        }
169    }
170
171    #[must_use]
172    pub fn reduced_precision_fast_math(&self) -> usize {
173        // SAFETY: `self.ptr` is a live descriptor handle.
174        unsafe { ffi::mpsgraph_compilation_descriptor_reduced_precision_fast_math(self.ptr) }
175    }
176
177    pub fn set_reduced_precision_fast_math(&self, value: usize) -> Result<()> {
178        // SAFETY: `self.ptr` is a live descriptor handle.
179        let ok = unsafe {
180            ffi::mpsgraph_compilation_descriptor_set_reduced_precision_fast_math(self.ptr, value)
181        };
182        if ok {
183            Ok(())
184        } else {
185            Err(Error::OperationFailed(
186                "failed to set reducedPrecisionFastMath",
187            ))
188        }
189    }
190
191    pub fn set_callable(&self, symbol_name: &str, executable: Option<&Executable>) -> Result<()> {
192        let symbol_name = CString::new(symbol_name)
193            .map_err(|_| Error::OperationFailed("call symbol name contained NUL"))?;
194        let executable_ptr = executable.map_or(ptr::null_mut(), Executable::as_ptr);
195        // SAFETY: all handles remain valid for the duration of the call.
196        let ok = unsafe {
197            ffi::mpsgraph_compilation_descriptor_set_callable(
198                self.ptr,
199                symbol_name.as_ptr(),
200                executable_ptr,
201            )
202        };
203        if ok {
204            Ok(())
205        } else {
206            Err(Error::OperationFailed(
207                "failed to set compilation descriptor callable",
208            ))
209        }
210    }
211}
212
213/// Safe owner for `MPSGraphExecutionDescriptor`.
214pub struct ExecutionDescriptor {
215    ptr: *mut c_void,
216}
217
218unsafe impl Send for ExecutionDescriptor {}
219unsafe impl Sync for ExecutionDescriptor {}
220
221impl Drop for ExecutionDescriptor {
222    fn drop(&mut self) {
223        release_handle(&mut self.ptr);
224    }
225}
226
227impl ExecutionDescriptor {
228    #[must_use]
229    pub fn new() -> Option<Self> {
230        // SAFETY: pure constructor.
231        let ptr = unsafe { ffi::mpsgraph_execution_descriptor_new() };
232        if ptr.is_null() {
233            None
234        } else {
235            Some(Self { ptr })
236        }
237    }
238
239    #[must_use]
240    pub const fn as_ptr(&self) -> *mut c_void {
241        self.ptr
242    }
243
244    #[must_use]
245    pub fn wait_until_completed(&self) -> bool {
246        // SAFETY: `self.ptr` is a live descriptor handle.
247        unsafe { ffi::mpsgraph_execution_descriptor_wait_until_completed(self.ptr) }
248    }
249
250    pub fn set_wait_until_completed(&self, value: bool) -> Result<()> {
251        // SAFETY: `self.ptr` is a live descriptor handle.
252        let ok =
253            unsafe { ffi::mpsgraph_execution_descriptor_set_wait_until_completed(self.ptr, value) };
254        if ok {
255            Ok(())
256        } else {
257            Err(Error::OperationFailed("failed to set waitUntilCompleted"))
258        }
259    }
260
261    #[must_use]
262    pub fn compilation_descriptor(&self) -> Option<CompilationDescriptor> {
263        // SAFETY: `self.ptr` is a live descriptor handle.
264        let ptr = unsafe { ffi::mpsgraph_execution_descriptor_compilation_descriptor(self.ptr) };
265        if ptr.is_null() {
266            None
267        } else {
268            Some(CompilationDescriptor { ptr })
269        }
270    }
271
272    pub fn set_compilation_descriptor(
273        &self,
274        descriptor: Option<&CompilationDescriptor>,
275    ) -> Result<()> {
276        let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
277        // SAFETY: all handles remain valid for the duration of the call.
278        let ok = unsafe {
279            ffi::mpsgraph_execution_descriptor_set_compilation_descriptor(self.ptr, descriptor_ptr)
280        };
281        if ok {
282            Ok(())
283        } else {
284            Err(Error::OperationFailed(
285                "failed to set compilation descriptor",
286            ))
287        }
288    }
289}
290
291/// Safe owner for `MPSGraphExecutableExecutionDescriptor`.
292pub struct ExecutableExecutionDescriptor {
293    ptr: *mut c_void,
294}
295
296unsafe impl Send for ExecutableExecutionDescriptor {}
297unsafe impl Sync for ExecutableExecutionDescriptor {}
298
299impl Drop for ExecutableExecutionDescriptor {
300    fn drop(&mut self) {
301        release_handle(&mut self.ptr);
302    }
303}
304
305impl ExecutableExecutionDescriptor {
306    #[must_use]
307    pub fn new() -> Option<Self> {
308        // SAFETY: pure constructor.
309        let ptr = unsafe { ffi::mpsgraph_executable_execution_descriptor_new() };
310        if ptr.is_null() {
311            None
312        } else {
313            Some(Self { ptr })
314        }
315    }
316
317    #[must_use]
318    pub(crate) const fn as_ptr(&self) -> *mut c_void {
319        self.ptr
320    }
321
322    #[must_use]
323    pub fn wait_until_completed(&self) -> bool {
324        // SAFETY: `self.ptr` is a live descriptor handle.
325        unsafe { ffi::mpsgraph_executable_execution_descriptor_wait_until_completed(self.ptr) }
326    }
327
328    pub fn set_wait_until_completed(&self, value: bool) -> Result<()> {
329        // SAFETY: `self.ptr` is a live descriptor handle.
330        let ok = unsafe {
331            ffi::mpsgraph_executable_execution_descriptor_set_wait_until_completed(self.ptr, value)
332        };
333        if ok {
334            Ok(())
335        } else {
336            Err(Error::OperationFailed(
337                "failed to set executable waitUntilCompleted",
338            ))
339        }
340    }
341}
342
343/// Safe owner for `MPSGraphExecutableSerializationDescriptor`.
344pub struct ExecutableSerializationDescriptor {
345    ptr: *mut c_void,
346}
347
348unsafe impl Send for ExecutableSerializationDescriptor {}
349unsafe impl Sync for ExecutableSerializationDescriptor {}
350
351impl Drop for ExecutableSerializationDescriptor {
352    fn drop(&mut self) {
353        release_handle(&mut self.ptr);
354    }
355}
356
357impl ExecutableSerializationDescriptor {
358    #[must_use]
359    pub fn new() -> Option<Self> {
360        // SAFETY: pure constructor.
361        let ptr = unsafe { ffi::mpsgraph_executable_serialization_descriptor_new() };
362        if ptr.is_null() {
363            None
364        } else {
365            Some(Self { ptr })
366        }
367    }
368
369    #[must_use]
370    pub(crate) const fn as_ptr(&self) -> *mut c_void {
371        self.ptr
372    }
373
374    #[must_use]
375    pub fn append(&self) -> bool {
376        // SAFETY: `self.ptr` is a live descriptor handle.
377        unsafe { ffi::mpsgraph_executable_serialization_descriptor_append(self.ptr) }
378    }
379
380    pub fn set_append(&self, value: bool) -> Result<()> {
381        // SAFETY: `self.ptr` is a live descriptor handle.
382        let ok = unsafe {
383            ffi::mpsgraph_executable_serialization_descriptor_set_append(self.ptr, value)
384        };
385        if ok {
386            Ok(())
387        } else {
388            Err(Error::OperationFailed("failed to set append"))
389        }
390    }
391
392    #[must_use]
393    pub fn deployment_platform(&self) -> u64 {
394        // SAFETY: `self.ptr` is a live descriptor handle.
395        unsafe { ffi::mpsgraph_executable_serialization_descriptor_deployment_platform(self.ptr) }
396    }
397
398    pub fn set_deployment_platform(&self, value: u64) -> Result<()> {
399        // SAFETY: `self.ptr` is a live descriptor handle.
400        let ok = unsafe {
401            ffi::mpsgraph_executable_serialization_descriptor_set_deployment_platform(
402                self.ptr, value,
403            )
404        };
405        if ok {
406            Ok(())
407        } else {
408            Err(Error::OperationFailed("failed to set deployment platform"))
409        }
410    }
411
412    pub fn minimum_deployment_target(&self) -> Result<String> {
413        copy_string(
414            ffi::mpsgraph_executable_serialization_descriptor_minimum_deployment_target_len,
415            ffi::mpsgraph_executable_serialization_descriptor_copy_minimum_deployment_target,
416            self.ptr,
417        )
418    }
419
420    pub fn set_minimum_deployment_target(&self, value: &str) -> Result<()> {
421        let value = CString::new(value)
422            .map_err(|_| Error::OperationFailed("minimum deployment target contained NUL"))?;
423        // SAFETY: the CString stays alive for the duration of the call.
424        let ok = unsafe {
425            ffi::mpsgraph_executable_serialization_descriptor_set_minimum_deployment_target(
426                self.ptr,
427                value.as_ptr(),
428            )
429        };
430        if ok {
431            Ok(())
432        } else {
433            Err(Error::OperationFailed(
434                "failed to set minimum deployment target",
435            ))
436        }
437    }
438}
439
440impl Graph {
441    /// Return the graph's `MPSGraphOptions` bitmask.
442    #[must_use]
443    pub fn options(&self) -> u64 {
444        // SAFETY: `self` owns a live graph handle.
445        unsafe { ffi::mpsgraph_graph_options(self.as_ptr()) }
446    }
447
448    /// Replace the graph's options bitmask.
449    pub fn set_options(&self, options: u64) -> Result<()> {
450        // SAFETY: `self` owns a live graph handle.
451        let ok = unsafe { ffi::mpsgraph_graph_set_options(self.as_ptr(), options) };
452        if ok {
453            Ok(())
454        } else {
455            Err(Error::OperationFailed("failed to set graph options"))
456        }
457    }
458
459    /// Return the graph's placeholder tensors in insertion order.
460    #[must_use]
461    pub fn placeholder_tensors(&self) -> Vec<Tensor> {
462        // SAFETY: `self` owns a live graph handle.
463        let box_handle = unsafe { ffi::mpsgraph_graph_placeholder_tensors(self.as_ptr()) };
464        collect_owned_tensors(box_handle)
465    }
466
467    /// Compile the graph with an optional compilation descriptor.
468    #[must_use]
469    pub fn compile_with_descriptor(
470        &self,
471        device: Option<&MetalDevice>,
472        feeds: &[FeedDescription<'_>],
473        targets: &[&Tensor],
474        descriptor: Option<&CompilationDescriptor>,
475    ) -> Option<Executable> {
476        let feed_tensors = feeds
477            .iter()
478            .map(|feed| feed.tensor.as_ptr())
479            .collect::<Vec<_>>();
480        let shape_lengths = feeds
481            .iter()
482            .map(|feed| feed.shape.len())
483            .collect::<Vec<_>>();
484        let data_types = feeds.iter().map(|feed| feed.data_type).collect::<Vec<_>>();
485        let flat_shapes = feeds
486            .iter()
487            .flat_map(|feed| feed.shape.iter().copied())
488            .collect::<Vec<_>>();
489        let target_tensors = targets
490            .iter()
491            .map(|tensor| tensor.as_ptr())
492            .collect::<Vec<_>>();
493        let device_ptr = device.map_or(ptr::null_mut(), MetalDevice::as_ptr);
494        let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
495
496        // SAFETY: all pointer arrays stay alive for the duration of the call.
497        let ptr = unsafe {
498            ffi::mpsgraph_graph_compile_with_descriptor(
499                self.as_ptr(),
500                device_ptr,
501                feed_tensors.as_ptr(),
502                feeds.len(),
503                flat_shapes.as_ptr(),
504                shape_lengths.as_ptr(),
505                data_types.as_ptr(),
506                target_tensors.as_ptr(),
507                targets.len(),
508                descriptor_ptr,
509            )
510        };
511        if ptr.is_null() {
512            None
513        } else {
514            Some(Executable::from_raw(ptr, targets.len()))
515        }
516    }
517}
518
519impl Executable {
520    /// Return the executable's `MPSGraphOptions` bitmask.
521    #[must_use]
522    pub fn options(&self) -> u64 {
523        // SAFETY: `self` owns a live executable handle.
524        unsafe { ffi::mpsgraph_executable_options(self.as_ptr()) }
525    }
526
527    /// Replace the executable's options bitmask.
528    pub fn set_options(&self, options: u64) -> Result<()> {
529        // SAFETY: `self` owns a live executable handle.
530        let ok = unsafe { ffi::mpsgraph_executable_set_options(self.as_ptr(), options) };
531        if ok {
532            Ok(())
533        } else {
534            Err(Error::OperationFailed("failed to set executable options"))
535        }
536    }
537
538    /// Return feed tensors if this executable was compiled from a graph.
539    #[must_use]
540    pub fn feed_tensors(&self) -> Vec<Tensor> {
541        // SAFETY: `self` owns a live executable handle.
542        let box_handle = unsafe { ffi::mpsgraph_executable_feed_tensors(self.as_ptr()) };
543        collect_owned_tensors(box_handle)
544    }
545
546    /// Return target tensors if this executable was compiled from a graph.
547    #[must_use]
548    pub fn target_tensors(&self) -> Vec<Tensor> {
549        // SAFETY: `self` owns a live executable handle.
550        let box_handle = unsafe { ffi::mpsgraph_executable_target_tensors(self.as_ptr()) };
551        collect_owned_tensors(box_handle)
552    }
553
554    /// Specialize the executable for the provided input types.
555    pub fn specialize(
556        &self,
557        device: Option<&MetalDevice>,
558        input_types: &[&ShapedType],
559        descriptor: Option<&CompilationDescriptor>,
560    ) -> Result<()> {
561        let input_type_handles = input_types
562            .iter()
563            .map(|value| value.as_ptr())
564            .collect::<Vec<_>>();
565        let device_ptr = device.map_or(ptr::null_mut(), MetalDevice::as_ptr);
566        let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
567
568        // SAFETY: all pointer arrays stay alive for the duration of the call.
569        let ok = unsafe {
570            ffi::mpsgraph_executable_specialize(
571                self.as_ptr(),
572                device_ptr,
573                input_type_handles.as_ptr(),
574                input_types.len(),
575                descriptor_ptr,
576            )
577        };
578        if ok {
579            Ok(())
580        } else {
581            Err(Error::OperationFailed("failed to specialize executable"))
582        }
583    }
584
585    /// Query specialized output types for the provided input types.
586    pub fn output_types(
587        &self,
588        device: Option<&MetalDevice>,
589        input_types: &[&ShapedType],
590        descriptor: Option<&CompilationDescriptor>,
591    ) -> Result<Vec<ShapedType>> {
592        let input_type_handles = input_types
593            .iter()
594            .map(|value| value.as_ptr())
595            .collect::<Vec<_>>();
596        let device_ptr = device.map_or(ptr::null_mut(), MetalDevice::as_ptr);
597        let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
598
599        // SAFETY: all pointer arrays stay alive for the duration of the call.
600        let box_handle = unsafe {
601            ffi::mpsgraph_executable_get_output_types(
602                self.as_ptr(),
603                device_ptr,
604                input_type_handles.as_ptr(),
605                input_types.len(),
606                descriptor_ptr,
607            )
608        };
609        if box_handle.is_null() {
610            Err(Error::OperationFailed(
611                "failed to get executable output types",
612            ))
613        } else {
614            Ok(collect_shaped_type_array_box(box_handle))
615        }
616    }
617
618    /// Run the executable with an optional execution descriptor and optional preallocated results.
619    pub fn run_with_descriptor(
620        &self,
621        command_queue: &CommandQueue,
622        inputs: &[&TensorData],
623        results: Option<&[&TensorData]>,
624        descriptor: Option<&ExecutableExecutionDescriptor>,
625    ) -> Result<Vec<TensorData>> {
626        let input_handles = inputs
627            .iter()
628            .map(|value| value.as_ptr())
629            .collect::<Vec<_>>();
630        let result_handles = results
631            .map(|values| {
632                values
633                    .iter()
634                    .map(|value| value.as_ptr())
635                    .collect::<Vec<_>>()
636            })
637            .unwrap_or_default();
638        let descriptor_ptr =
639            descriptor.map_or(ptr::null_mut(), ExecutableExecutionDescriptor::as_ptr);
640
641        // SAFETY: all pointer arrays stay alive for the duration of the call.
642        let box_handle = unsafe {
643            ffi::mpsgraph_executable_run_with_descriptor(
644                self.as_ptr(),
645                command_queue.as_ptr(),
646                input_handles.as_ptr(),
647                inputs.len(),
648                result_handles.as_ptr(),
649                result_handles.len(),
650                descriptor_ptr,
651            )
652        };
653        if box_handle.is_null() {
654            Err(Error::OperationFailed("failed to run executable"))
655        } else {
656            Ok(collect_tensor_data_array_box(box_handle))
657        }
658    }
659
660    /// Asynchronously run the executable with an optional execution descriptor.
661    pub fn run_async_with_descriptor(
662        &self,
663        command_queue: &CommandQueue,
664        inputs: &[&TensorData],
665        results: Option<&[&TensorData]>,
666        descriptor: Option<&ExecutableExecutionDescriptor>,
667    ) -> Result<Vec<TensorData>> {
668        let input_handles = inputs
669            .iter()
670            .map(|value| value.as_ptr())
671            .collect::<Vec<_>>();
672        let result_handles = results
673            .map(|values| {
674                values
675                    .iter()
676                    .map(|value| value.as_ptr())
677                    .collect::<Vec<_>>()
678            })
679            .unwrap_or_default();
680        let descriptor_ptr =
681            descriptor.map_or(ptr::null_mut(), ExecutableExecutionDescriptor::as_ptr);
682
683        // SAFETY: all pointer arrays stay alive for the duration of the call.
684        let box_handle = unsafe {
685            ffi::mpsgraph_executable_run_async_with_descriptor(
686                self.as_ptr(),
687                command_queue.as_ptr(),
688                input_handles.as_ptr(),
689                inputs.len(),
690                result_handles.as_ptr(),
691                result_handles.len(),
692                descriptor_ptr,
693            )
694        };
695        if box_handle.is_null() {
696            Err(Error::OperationFailed(
697                "failed to run executable asynchronously",
698            ))
699        } else {
700            Ok(collect_tensor_data_array_box(box_handle))
701        }
702    }
703
704    /// Serialize the executable to an `.mpsgraphpackage` path.
705    pub fn serialize_package(
706        &self,
707        path: &str,
708        descriptor: Option<&ExecutableSerializationDescriptor>,
709    ) -> Result<()> {
710        let path =
711            CString::new(path).map_err(|_| Error::OperationFailed("package path contained NUL"))?;
712        let descriptor_ptr =
713            descriptor.map_or(ptr::null_mut(), ExecutableSerializationDescriptor::as_ptr);
714        // SAFETY: the CString stays alive for the duration of the call.
715        let ok = unsafe {
716            ffi::mpsgraph_executable_serialize_package(self.as_ptr(), path.as_ptr(), descriptor_ptr)
717        };
718        if ok {
719            Ok(())
720        } else {
721            Err(Error::OperationFailed(
722                "failed to serialize executable package",
723            ))
724        }
725    }
726
727    /// Load an executable from an existing `.mpsgraphpackage`.
728    pub fn from_package(path: &str, descriptor: Option<&CompilationDescriptor>) -> Result<Self> {
729        let path =
730            CString::new(path).map_err(|_| Error::OperationFailed("package path contained NUL"))?;
731        let descriptor_ptr = descriptor.map_or(ptr::null_mut(), CompilationDescriptor::as_ptr);
732        // SAFETY: the CString stays alive for the duration of the call.
733        let ptr =
734            unsafe { ffi::mpsgraph_executable_new_with_package(path.as_ptr(), descriptor_ptr) };
735        if ptr.is_null() {
736            return Err(Error::OperationFailed("failed to load executable package"));
737        }
738        let output_count = {
739            // SAFETY: `ptr` is a live executable handle returned just above.
740            let box_handle = unsafe { ffi::mpsgraph_executable_target_tensors(ptr) };
741            collect_owned_tensors(box_handle).len()
742        };
743        Ok(Self::from_raw(ptr, output_count))
744    }
745}