Skip to main content

apple_mpsgraph/
graph.rs

1use crate::data::TensorData;
2use crate::error::{Error, Result};
3use crate::ffi;
4use apple_metal::{CommandQueue, MetalDevice};
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Selected `MPSDataType` constants useful for graph inputs and outputs.
///
/// An `MPSDataType` keeps the element bit width in its low bits and ORs in flag bits that mark
/// floating-point, signed, normalized, or alternately-encoded types. The constants below are
/// written in that composed form so each value shows how it is built.
pub mod data_type {
    /// Flag bit carried by floating-point types.
    const FLOAT_BIT: u32 = 0x1000_0000;
    /// Flag bit carried by signed integer types.
    const SIGNED_BIT: u32 = 0x2000_0000;
    /// Flag bit carried by normalized integer types.
    const NORMALIZED_BIT: u32 = 0x4000_0000;
    /// Flag bit marking an alternate encoding of a bit width (used by `BOOL`).
    const ALTERNATE_ENCODING_BIT: u32 = 0x8000_0000;

    /// Not a valid data type.
    pub const INVALID: u32 = 0;

    /// 32-bit floating point.
    pub const FLOAT32: u32 = FLOAT_BIT | 32;
    /// 16-bit floating point.
    pub const FLOAT16: u32 = FLOAT_BIT | 16;

    /// Signed 8-bit integer.
    pub const INT8: u32 = SIGNED_BIT | 8;
    /// Signed 16-bit integer.
    pub const INT16: u32 = SIGNED_BIT | 16;
    /// Signed 32-bit integer.
    pub const INT32: u32 = SIGNED_BIT | 32;
    /// Signed 64-bit integer.
    pub const INT64: u32 = SIGNED_BIT | 64;

    /// Unsigned 8-bit integer.
    pub const UINT8: u32 = 8;
    /// Unsigned 16-bit integer.
    pub const UINT16: u32 = 16;
    /// Unsigned 32-bit integer.
    pub const UINT32: u32 = 32;
    /// Unsigned 64-bit integer.
    pub const UINT64: u32 = 64;

    /// Boolean stored in one byte.
    pub const BOOL: u32 = ALTERNATE_ENCODING_BIT | 8;
    /// Normalized unsigned 8-bit value.
    pub const UNORM8: u32 = NORMALIZED_BIT | UINT8;
}

/// Byte width of one element of `data_type`.
///
/// Returns `None` for [`data_type::INVALID`] and for any value not listed in [`data_type`].
#[must_use]
pub const fn data_type_size(data_type: u32) -> Option<usize> {
    let bytes = match data_type {
        data_type::INT8 | data_type::UINT8 | data_type::BOOL | data_type::UNORM8 => 1,
        data_type::FLOAT16 | data_type::INT16 | data_type::UINT16 => 2,
        data_type::FLOAT32 | data_type::INT32 | data_type::UINT32 => 4,
        data_type::INT64 | data_type::UINT64 => 8,
        _ => return None,
    };
    Some(bytes)
}
37
/// `MPSGraphTensorNamedDataLayout` constants.
///
/// Each name spells a tensor's dimension order. These values are stored in the
/// `data_layout` / `weights_layout` fields of the convolution and pooling descriptor infos.
pub mod tensor_named_data_layout {
    /// Batch, channels, height, width.
    pub const NCHW: usize = 0;
    /// Batch, height, width, channels.
    pub const NHWC: usize = 1;
    /// Output channels, input channels, height, width.
    pub const OIHW: usize = 2;
    /// Height, width, input channels, output channels.
    pub const HWIO: usize = 3;
    /// Channels, height, width.
    pub const CHW: usize = 4;
    /// Height, width, channels.
    pub const HWC: usize = 5;
    /// Height, width.
    pub const HW: usize = 6;
}
48
/// `MPSGraphPaddingStyle` constants.
///
/// Stored in the `padding_style` field of the convolution and pooling descriptor infos.
pub mod padding_style {
    /// Explicit padding, taken from the descriptor's `padding_*` fields.
    pub const EXPLICIT: usize = 0;
    /// TensorFlow-style `VALID` padding.
    pub const TF_VALID: usize = 1;
    /// TensorFlow-style `SAME` padding.
    pub const TF_SAME: usize = 2;
    /// Explicit-offset padding.
    pub const EXPLICIT_OFFSET: usize = 3;
    /// ONNX-style `SAME_LOWER` padding.
    pub const ONNX_SAME_LOWER: usize = 4;
}
57
/// Define an owning wrapper type around a retained bridge object pointer.
///
/// The generated type releases its pointer through `ffi::mpsgraph_object_release` when
/// dropped. It exposes the raw pointer via `as_ptr` for passing back into FFI calls.
macro_rules! opaque_handle {
    ($name:ident) => {
        #[doc = concat!(
            "Owning handle to a bridge-side `",
            stringify!($name),
            "` object; the underlying reference is released on drop."
        )]
        pub struct $name {
            ptr: *mut c_void,
        }

        // NOTE(review): `Send`/`Sync` are asserted for every handle type with no
        // synchronization on the Rust side. Whether the underlying Objective-C object tolerates
        // concurrent use (e.g. adding ops to one graph from several threads) is not
        // established here; confirm before relying on it.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                // Move the pointer out (leaving null behind) and release it exactly once.
                let handle = core::mem::replace(&mut self.ptr, ptr::null_mut());
                if handle.is_null() {
                    return;
                }
                // SAFETY: `handle` is a +1 retained Swift/ObjC object pointer owned by
                // this wrapper.
                unsafe { ffi::mpsgraph_object_release(handle) };
            }
        }

        impl $name {
            /// Raw object pointer for FFI calls; ownership stays with `self`.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
85
86fn checked_byte_len(shape: &[usize], data_type: u32) -> Option<usize> {
87    let element_size = data_type_size(data_type)?;
88    shape
89        .iter()
90        .try_fold(element_size, |acc, dimension| acc.checked_mul(*dimension))
91}
92
/// Convert an optional debug name into a NUL-terminated C string.
///
/// A name containing an interior NUL byte cannot be represented as a C string. Such a name is
/// dropped (yielding `None`) rather than reported as an error.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
96
/// Pointer to the contents of an optional C string, or null when absent.
///
/// The returned pointer borrows from `value` and is only valid while `value` is alive.
// `&Option<CString>` is accepted (hence the lint allow) so call sites can pass `&name`
// straight from `optional_cstring` without converting it first.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(text) => text.as_ptr(),
        None => ptr::null(),
    }
}
101
102fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
103    if ptr.is_null() {
104        None
105    } else {
106        Some(Tensor { ptr })
107    }
108}
109
110fn wrap_tensor_data_results(
111    handles: Vec<*mut c_void>,
112    message: &'static str,
113) -> Result<Vec<TensorData>> {
114    let mut results = Vec::with_capacity(handles.len());
115    for handle in handles {
116        if handle.is_null() {
117            return Err(Error::OperationFailed(message));
118        }
119        results.push(TensorData::from_raw(handle));
120    }
121    Ok(results)
122}
123
/// Generate a `Graph` method that adds a two-input operation via the bridge function
/// `ffi::$ffi_name`.
///
/// The generated method takes `(primary, secondary, name)`. It returns `None` when the bridge
/// returns a null tensor.
macro_rules! impl_binary_tensor_op {
    ($fn_name:ident, $ffi_name:ident) => {
        #[doc = concat!(
            "Add a two-input operation backed by `ffi::",
            stringify!($ffi_name),
            "`; returns `None` if the bridge yields a null tensor."
        )]
        #[must_use]
        pub fn $fn_name(
            &self,
            primary: &Tensor,
            secondary: &Tensor,
            name: Option<&str>,
        ) -> Option<Tensor> {
            let label = optional_cstring(name);
            // SAFETY: `self`, `primary` and `secondary` are live wrappers borrowed for the
            // whole call, and `label` outlives it, so every pointer passed stays valid.
            let raw = unsafe {
                ffi::$ffi_name(
                    self.ptr,
                    primary.as_ptr(),
                    secondary.as_ptr(),
                    cstring_ptr(&label),
                )
            };
            wrap_tensor(raw)
        }
    };
}
147
/// Generate a `Graph` method that adds a single-input operation via the bridge function
/// `ffi::$ffi_name`.
///
/// The generated method takes `(tensor, name)`. It returns `None` when the bridge returns a
/// null tensor.
macro_rules! impl_unary_tensor_op {
    ($fn_name:ident, $ffi_name:ident) => {
        #[doc = concat!(
            "Add a single-input operation backed by `ffi::",
            stringify!($ffi_name),
            "`; returns `None` if the bridge yields a null tensor."
        )]
        #[must_use]
        pub fn $fn_name(&self, tensor: &Tensor, name: Option<&str>) -> Option<Tensor> {
            let label = optional_cstring(name);
            // SAFETY: `self` and `tensor` are live wrappers borrowed for the whole call, and
            // `label` outlives it, so every pointer passed stays valid.
            let raw = unsafe { ffi::$ffi_name(self.ptr, tensor.as_ptr(), cstring_ptr(&label)) };
            wrap_tensor(raw)
        }
    };
}
159
/// Generate a `Graph` method that adds an operation over a list of axes via the bridge
/// function `ffi::$ffi_name`.
///
/// The generated method takes `(tensor, axes, name)`. It returns `None` when the bridge
/// returns a null tensor.
macro_rules! impl_axes_tensor_op {
    ($fn_name:ident, $ffi_name:ident) => {
        #[doc = concat!(
            "Add an operation over `axes` backed by `ffi::",
            stringify!($ffi_name),
            "`; returns `None` if the bridge yields a null tensor."
        )]
        #[must_use]
        pub fn $fn_name(
            &self,
            tensor: &Tensor,
            axes: &[usize],
            name: Option<&str>,
        ) -> Option<Tensor> {
            let label = optional_cstring(name);
            let (axes_ptr, axes_len) = (axes.as_ptr(), axes.len());
            // SAFETY: `self` and `tensor` are live wrappers, `axes` is borrowed for the whole
            // call (so `axes_ptr` is valid for `axes_len` elements), and `label` outlives it.
            let raw = unsafe {
                ffi::$ffi_name(
                    self.ptr,
                    tensor.as_ptr(),
                    axes_ptr,
                    axes_len,
                    cstring_ptr(&label),
                )
            };
            wrap_tensor(raw)
        }
    };
}
184
185/// Ordered placeholder feed pairing used for graph execution.
186#[derive(Clone, Copy)]
187pub struct Feed<'a> {
188    pub tensor: &'a Tensor,
189    pub data: &'a TensorData,
190}
191
192impl<'a> Feed<'a> {
193    #[must_use]
194    pub const fn new(tensor: &'a Tensor, data: &'a TensorData) -> Self {
195        Self { tensor, data }
196    }
197}
198
199/// Feed metadata used to compile a graph into an executable.
200#[derive(Clone, Copy)]
201pub struct FeedDescription<'a> {
202    pub tensor: &'a Tensor,
203    pub shape: &'a [usize],
204    pub data_type: u32,
205}
206
207impl<'a> FeedDescription<'a> {
208    #[must_use]
209    pub const fn new(tensor: &'a Tensor, shape: &'a [usize], data_type: u32) -> Self {
210        Self {
211            tensor,
212            shape,
213            data_type,
214        }
215    }
216}
217
218/// Plain-Rust configuration for `MPSGraphConvolution2DOpDescriptor`.
219#[derive(Debug, Clone, Copy)]
220pub struct Convolution2DDescriptorInfo {
221    pub stride_in_x: usize,
222    pub stride_in_y: usize,
223    pub dilation_rate_in_x: usize,
224    pub dilation_rate_in_y: usize,
225    pub groups: usize,
226    pub padding_left: usize,
227    pub padding_right: usize,
228    pub padding_top: usize,
229    pub padding_bottom: usize,
230    pub padding_style: usize,
231    pub data_layout: usize,
232    pub weights_layout: usize,
233}
234
235impl Default for Convolution2DDescriptorInfo {
236    fn default() -> Self {
237        Self {
238            stride_in_x: 1,
239            stride_in_y: 1,
240            dilation_rate_in_x: 1,
241            dilation_rate_in_y: 1,
242            groups: 1,
243            padding_left: 0,
244            padding_right: 0,
245            padding_top: 0,
246            padding_bottom: 0,
247            padding_style: padding_style::EXPLICIT,
248            data_layout: tensor_named_data_layout::NHWC,
249            weights_layout: tensor_named_data_layout::HWIO,
250        }
251    }
252}
253
254opaque_handle!(Convolution2DDescriptor);
255impl Convolution2DDescriptor {
256    #[must_use]
257    pub fn new(info: Convolution2DDescriptorInfo) -> Option<Self> {
258        // SAFETY: All scalar configuration values are POD.
259        let ptr = unsafe {
260            ffi::mpsgraph_convolution2d_descriptor_new(
261                info.stride_in_x,
262                info.stride_in_y,
263                info.dilation_rate_in_x,
264                info.dilation_rate_in_y,
265                info.groups,
266                info.padding_left,
267                info.padding_right,
268                info.padding_top,
269                info.padding_bottom,
270                info.padding_style,
271                info.data_layout,
272                info.weights_layout,
273            )
274        };
275        if ptr.is_null() {
276            None
277        } else {
278            Some(Self { ptr })
279        }
280    }
281}
282
283/// Plain-Rust configuration for `MPSGraphPooling2DOpDescriptor`.
284#[derive(Debug, Clone, Copy)]
285pub struct Pooling2DDescriptorInfo {
286    pub kernel_width: usize,
287    pub kernel_height: usize,
288    pub stride_in_x: usize,
289    pub stride_in_y: usize,
290    pub dilation_rate_in_x: usize,
291    pub dilation_rate_in_y: usize,
292    pub padding_left: usize,
293    pub padding_right: usize,
294    pub padding_top: usize,
295    pub padding_bottom: usize,
296    pub padding_style: usize,
297    pub data_layout: usize,
298}
299
300impl Pooling2DDescriptorInfo {
301    #[must_use]
302    pub const fn new(kernel_width: usize, kernel_height: usize) -> Self {
303        Self {
304            kernel_width,
305            kernel_height,
306            stride_in_x: 1,
307            stride_in_y: 1,
308            dilation_rate_in_x: 1,
309            dilation_rate_in_y: 1,
310            padding_left: 0,
311            padding_right: 0,
312            padding_top: 0,
313            padding_bottom: 0,
314            padding_style: padding_style::EXPLICIT,
315            data_layout: tensor_named_data_layout::NHWC,
316        }
317    }
318}
319
320opaque_handle!(Pooling2DDescriptor);
321impl Pooling2DDescriptor {
322    #[must_use]
323    pub fn new(info: Pooling2DDescriptorInfo) -> Option<Self> {
324        // SAFETY: All scalar configuration values are POD.
325        let ptr = unsafe {
326            ffi::mpsgraph_pooling2d_descriptor_new(
327                info.kernel_width,
328                info.kernel_height,
329                info.stride_in_x,
330                info.stride_in_y,
331                info.dilation_rate_in_x,
332                info.dilation_rate_in_y,
333                info.padding_left,
334                info.padding_right,
335                info.padding_top,
336                info.padding_bottom,
337                info.padding_style,
338                info.data_layout,
339            )
340        };
341        if ptr.is_null() {
342            None
343        } else {
344            Some(Self { ptr })
345        }
346    }
347}
348
349opaque_handle!(Graph);
350opaque_handle!(Tensor);
351
352impl Graph {
353    #[must_use]
354    pub fn new() -> Option<Self> {
355        // SAFETY: Pure constructor with no inputs.
356        let ptr = unsafe { ffi::mpsgraph_graph_new() };
357        if ptr.is_null() {
358            None
359        } else {
360            Some(Self { ptr })
361        }
362    }
363
364    #[must_use]
365    pub fn placeholder(
366        &self,
367        shape: Option<&[usize]>,
368        data_type: u32,
369        name: Option<&str>,
370    ) -> Option<Tensor> {
371        let name = optional_cstring(name);
372        let (shape_ptr, shape_len) =
373            shape.map_or((ptr::null(), 0), |shape| (shape.as_ptr(), shape.len()));
374
375        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
376        let ptr = unsafe {
377            ffi::mpsgraph_graph_placeholder(
378                self.ptr,
379                shape_ptr,
380                shape_len,
381                data_type,
382                cstring_ptr(&name),
383            )
384        };
385        wrap_tensor(ptr)
386    }
387
388    #[must_use]
389    pub fn constant_bytes(&self, data: &[u8], shape: &[usize], data_type: u32) -> Option<Tensor> {
390        let expected = checked_byte_len(shape, data_type)?;
391        if data.len() != expected {
392            return None;
393        }
394
395        // SAFETY: The byte slice remains valid for the duration of the FFI call.
396        let ptr = unsafe {
397            ffi::mpsgraph_graph_constant_data(
398                self.ptr,
399                data.as_ptr().cast(),
400                data.len(),
401                shape.as_ptr(),
402                shape.len(),
403                data_type,
404            )
405        };
406        wrap_tensor(ptr)
407    }
408
409    #[must_use]
410    pub fn constant_f32_slice(&self, values: &[f32], shape: &[usize]) -> Option<Tensor> {
411        // SAFETY: `values` is a contiguous slice of `f32` that may be viewed as bytes.
412        let bytes = unsafe {
413            core::slice::from_raw_parts(
414                values.as_ptr().cast::<u8>(),
415                core::mem::size_of_val(values),
416            )
417        };
418        self.constant_bytes(bytes, shape, data_type::FLOAT32)
419    }
420
421    #[must_use]
422    pub fn constant_scalar(&self, scalar: f64, data_type: u32) -> Option<Tensor> {
423        // SAFETY: Pure constructor over scalar inputs.
424        let ptr = unsafe { ffi::mpsgraph_graph_constant_scalar(self.ptr, scalar, data_type) };
425        wrap_tensor(ptr)
426    }
427
428    #[must_use]
429    pub fn constant_scalar_shaped(
430        &self,
431        scalar: f64,
432        shape: &[usize],
433        data_type: u32,
434    ) -> Option<Tensor> {
435        // SAFETY: Shape slice stays valid for the duration of the FFI call.
436        let ptr = unsafe {
437            ffi::mpsgraph_graph_constant_scalar_shaped(
438                self.ptr,
439                scalar,
440                shape.as_ptr(),
441                shape.len(),
442                data_type,
443            )
444        };
445        wrap_tensor(ptr)
446    }
447
448    impl_binary_tensor_op!(addition, mpsgraph_graph_addition);
449    impl_binary_tensor_op!(subtraction, mpsgraph_graph_subtraction);
450    impl_binary_tensor_op!(multiplication, mpsgraph_graph_multiplication);
451    impl_binary_tensor_op!(division, mpsgraph_graph_division);
452    impl_binary_tensor_op!(matrix_multiplication, mpsgraph_graph_matrix_multiplication);
453    impl_unary_tensor_op!(relu, mpsgraph_graph_relu);
454    impl_unary_tensor_op!(sigmoid, mpsgraph_graph_sigmoid);
455    impl_axes_tensor_op!(reduction_sum, mpsgraph_graph_reduction_sum);
456    impl_axes_tensor_op!(reduction_maximum, mpsgraph_graph_reduction_maximum);
457    impl_axes_tensor_op!(reduction_minimum, mpsgraph_graph_reduction_minimum);
458    impl_axes_tensor_op!(mean, mpsgraph_graph_mean);
459
460    #[must_use]
461    pub fn softmax(&self, tensor: &Tensor, axis: isize, name: Option<&str>) -> Option<Tensor> {
462        let name = optional_cstring(name);
463        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
464        let ptr = unsafe {
465            ffi::mpsgraph_graph_softmax(self.ptr, tensor.as_ptr(), axis, cstring_ptr(&name))
466        };
467        wrap_tensor(ptr)
468    }
469
470    #[must_use]
471    pub fn reshape(&self, tensor: &Tensor, shape: &[usize], name: Option<&str>) -> Option<Tensor> {
472        let name = optional_cstring(name);
473        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
474        let ptr = unsafe {
475            ffi::mpsgraph_graph_reshape(
476                self.ptr,
477                tensor.as_ptr(),
478                shape.as_ptr(),
479                shape.len(),
480                cstring_ptr(&name),
481            )
482        };
483        wrap_tensor(ptr)
484    }
485
486    #[must_use]
487    pub fn transpose(
488        &self,
489        tensor: &Tensor,
490        permutation: &[usize],
491        name: Option<&str>,
492    ) -> Option<Tensor> {
493        let name = optional_cstring(name);
494        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
495        let ptr = unsafe {
496            ffi::mpsgraph_graph_transpose(
497                self.ptr,
498                tensor.as_ptr(),
499                permutation.as_ptr(),
500                permutation.len(),
501                cstring_ptr(&name),
502            )
503        };
504        wrap_tensor(ptr)
505    }
506
507    #[must_use]
508    pub fn slice(
509        &self,
510        tensor: &Tensor,
511        dimension: usize,
512        start: isize,
513        length: isize,
514        name: Option<&str>,
515    ) -> Option<Tensor> {
516        let name = optional_cstring(name);
517        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
518        let ptr = unsafe {
519            ffi::mpsgraph_graph_slice(
520                self.ptr,
521                tensor.as_ptr(),
522                dimension,
523                start,
524                length,
525                cstring_ptr(&name),
526            )
527        };
528        wrap_tensor(ptr)
529    }
530
531    #[must_use]
532    pub fn broadcast(
533        &self,
534        tensor: &Tensor,
535        shape: &[usize],
536        name: Option<&str>,
537    ) -> Option<Tensor> {
538        let name = optional_cstring(name);
539        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
540        let ptr = unsafe {
541            ffi::mpsgraph_graph_broadcast(
542                self.ptr,
543                tensor.as_ptr(),
544                shape.as_ptr(),
545                shape.len(),
546                cstring_ptr(&name),
547            )
548        };
549        wrap_tensor(ptr)
550    }
551
552    #[must_use]
553    pub fn convolution2d(
554        &self,
555        source: &Tensor,
556        weights: &Tensor,
557        descriptor: &Convolution2DDescriptor,
558        name: Option<&str>,
559    ) -> Option<Tensor> {
560        let name = optional_cstring(name);
561        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
562        let ptr = unsafe {
563            ffi::mpsgraph_graph_convolution2d(
564                self.ptr,
565                source.as_ptr(),
566                weights.as_ptr(),
567                descriptor.as_ptr(),
568                cstring_ptr(&name),
569            )
570        };
571        wrap_tensor(ptr)
572    }
573
574    #[must_use]
575    pub fn max_pooling2d(
576        &self,
577        source: &Tensor,
578        descriptor: &Pooling2DDescriptor,
579        name: Option<&str>,
580    ) -> Option<Tensor> {
581        let name = optional_cstring(name);
582        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
583        let ptr = unsafe {
584            ffi::mpsgraph_graph_max_pooling2d(
585                self.ptr,
586                source.as_ptr(),
587                descriptor.as_ptr(),
588                cstring_ptr(&name),
589            )
590        };
591        wrap_tensor(ptr)
592    }
593
594    #[allow(clippy::too_many_arguments)]
595    #[must_use]
596    pub fn normalize(
597        &self,
598        tensor: &Tensor,
599        mean: &Tensor,
600        variance: &Tensor,
601        gamma: Option<&Tensor>,
602        beta: Option<&Tensor>,
603        epsilon: f32,
604        name: Option<&str>,
605    ) -> Option<Tensor> {
606        let name = optional_cstring(name);
607        let gamma_ptr = gamma.map_or(ptr::null_mut(), Tensor::as_ptr);
608        let beta_ptr = beta.map_or(ptr::null_mut(), Tensor::as_ptr);
609        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
610        let ptr = unsafe {
611            ffi::mpsgraph_graph_normalize(
612                self.ptr,
613                tensor.as_ptr(),
614                mean.as_ptr(),
615                variance.as_ptr(),
616                gamma_ptr,
617                beta_ptr,
618                epsilon,
619                cstring_ptr(&name),
620            )
621        };
622        wrap_tensor(ptr)
623    }
624
625    pub fn run(&self, feeds: &[Feed<'_>], targets: &[&Tensor]) -> Result<Vec<TensorData>> {
626        let feed_tensors = feeds
627            .iter()
628            .map(|feed| feed.tensor.as_ptr())
629            .collect::<Vec<_>>();
630        let feed_data = feeds
631            .iter()
632            .map(|feed| feed.data.as_ptr())
633            .collect::<Vec<_>>();
634        let target_tensors = targets
635            .iter()
636            .map(|tensor| tensor.as_ptr())
637            .collect::<Vec<_>>();
638        let mut results = vec![ptr::null_mut(); targets.len()];
639
640        // SAFETY: The pointer arrays are valid for the duration of the FFI call.
641        let ok = unsafe {
642            ffi::mpsgraph_graph_run(
643                self.ptr,
644                feed_tensors.as_ptr(),
645                feed_data.as_ptr(),
646                feeds.len(),
647                target_tensors.as_ptr(),
648                targets.len(),
649                results.as_mut_ptr(),
650            )
651        };
652        if ok {
653            wrap_tensor_data_results(results, "failed to run graph")
654        } else {
655            Err(Error::OperationFailed("failed to run graph"))
656        }
657    }
658
659    pub fn run_with_command_queue(
660        &self,
661        command_queue: &CommandQueue,
662        feeds: &[Feed<'_>],
663        targets: &[&Tensor],
664    ) -> Result<Vec<TensorData>> {
665        let feed_tensors = feeds
666            .iter()
667            .map(|feed| feed.tensor.as_ptr())
668            .collect::<Vec<_>>();
669        let feed_data = feeds
670            .iter()
671            .map(|feed| feed.data.as_ptr())
672            .collect::<Vec<_>>();
673        let target_tensors = targets
674            .iter()
675            .map(|tensor| tensor.as_ptr())
676            .collect::<Vec<_>>();
677        let mut results = vec![ptr::null_mut(); targets.len()];
678
679        // SAFETY: The pointer arrays are valid for the duration of the FFI call.
680        let ok = unsafe {
681            ffi::mpsgraph_graph_run_with_command_queue(
682                self.ptr,
683                command_queue.as_ptr(),
684                feed_tensors.as_ptr(),
685                feed_data.as_ptr(),
686                feeds.len(),
687                target_tensors.as_ptr(),
688                targets.len(),
689                results.as_mut_ptr(),
690            )
691        };
692        if ok {
693            wrap_tensor_data_results(results, "failed to run graph with command queue")
694        } else {
695            Err(Error::OperationFailed(
696                "failed to run graph with command queue",
697            ))
698        }
699    }
700
701    #[must_use]
702    pub fn compile(
703        &self,
704        device: &MetalDevice,
705        feeds: &[FeedDescription<'_>],
706        targets: &[&Tensor],
707    ) -> Option<Executable> {
708        let feed_tensors = feeds
709            .iter()
710            .map(|feed| feed.tensor.as_ptr())
711            .collect::<Vec<_>>();
712        let shape_lengths = feeds
713            .iter()
714            .map(|feed| feed.shape.len())
715            .collect::<Vec<_>>();
716        let data_types = feeds.iter().map(|feed| feed.data_type).collect::<Vec<_>>();
717        let flat_shapes = feeds
718            .iter()
719            .flat_map(|feed| feed.shape.iter().copied())
720            .collect::<Vec<_>>();
721        let target_tensors = targets
722            .iter()
723            .map(|tensor| tensor.as_ptr())
724            .collect::<Vec<_>>();
725
726        // SAFETY: The pointer arrays are valid for the duration of the FFI call.
727        let ptr = unsafe {
728            ffi::mpsgraph_graph_compile(
729                self.ptr,
730                device.as_ptr(),
731                feed_tensors.as_ptr(),
732                feeds.len(),
733                flat_shapes.as_ptr(),
734                shape_lengths.as_ptr(),
735                data_types.as_ptr(),
736                target_tensors.as_ptr(),
737                targets.len(),
738            )
739        };
740        if ptr.is_null() {
741            None
742        } else {
743            Some(Executable {
744                ptr,
745                output_count: targets.len(),
746            })
747        }
748    }
749}
750
751/// Safe owner for a compiled `MPSGraphExecutable`.
752pub struct Executable {
753    ptr: *mut c_void,
754    output_count: usize,
755}
756
757unsafe impl Send for Executable {}
758unsafe impl Sync for Executable {}
759
760impl Drop for Executable {
761    fn drop(&mut self) {
762        if !self.ptr.is_null() {
763            // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
764            unsafe { ffi::mpsgraph_object_release(self.ptr) };
765            self.ptr = ptr::null_mut();
766        }
767    }
768}
769
770impl Executable {
771    #[must_use]
772    pub const fn as_ptr(&self) -> *mut c_void {
773        self.ptr
774    }
775
776    #[must_use]
777    pub const fn output_count(&self) -> usize {
778        self.output_count
779    }
780
781    pub fn run(
782        &self,
783        command_queue: &CommandQueue,
784        inputs: &[&TensorData],
785    ) -> Result<Vec<TensorData>> {
786        let input_data = inputs
787            .iter()
788            .map(|tensor_data| tensor_data.as_ptr())
789            .collect::<Vec<_>>();
790        let mut results = vec![ptr::null_mut(); self.output_count];
791
792        // SAFETY: The pointer arrays are valid for the duration of the FFI call.
793        let ok = unsafe {
794            ffi::mpsgraph_executable_run(
795                self.ptr,
796                command_queue.as_ptr(),
797                input_data.as_ptr(),
798                inputs.len(),
799                self.output_count,
800                results.as_mut_ptr(),
801            )
802        };
803        if ok {
804            wrap_tensor_data_results(results, "failed to run executable")
805        } else {
806            Err(Error::OperationFailed("failed to run executable"))
807        }
808    }
809}