Skip to main content

apple_mpsgraph/
graph.rs

1use crate::data::TensorData;
2use crate::error::{Error, Result};
3use crate::ffi;
4use apple_metal::{CommandQueue, MetalDevice};
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Selected `MPSDataType` raw values useful for graph inputs and outputs.
///
/// Each value is an element bit-width combined with an encoding flag in the top bits.
pub mod data_type {
    /// Flag carried by the floating-point types.
    const FLOAT_BIT: u32 = 0x1000_0000;
    /// Flag carried by the signed integer types.
    const SIGNED_BIT: u32 = 0x2000_0000;
    /// Flag carried by the normalized types.
    const NORMALIZED_BIT: u32 = 0x4000_0000;
    /// Flag carried by the alternate-encoding types (here: `BOOL`).
    const ALTERNATE_ENCODING_BIT: u32 = 0x8000_0000;

    pub const INVALID: u32 = 0;
    pub const FLOAT32: u32 = FLOAT_BIT | 32;
    pub const FLOAT16: u32 = FLOAT_BIT | 16;
    pub const INT8: u32 = SIGNED_BIT | 8;
    pub const INT16: u32 = SIGNED_BIT | 16;
    pub const INT32: u32 = SIGNED_BIT | 32;
    pub const INT64: u32 = SIGNED_BIT | 64;
    pub const UINT8: u32 = 8;
    pub const UINT16: u32 = 16;
    pub const UINT32: u32 = 32;
    pub const UINT64: u32 = 64;
    pub const BOOL: u32 = ALTERNATE_ENCODING_BIT | 8;
    pub const UNORM8: u32 = NORMALIZED_BIT | 8;
}

/// Return the byte width of one element of a supported `MPSDataType`.
///
/// Yields `None` for [`data_type::INVALID`] and for any value that is not one of the
/// constants in [`data_type`].
#[must_use]
pub const fn data_type_size(data_type: u32) -> Option<usize> {
    let width = match data_type {
        data_type::INT8 | data_type::UINT8 | data_type::BOOL | data_type::UNORM8 => 1,
        data_type::FLOAT16 | data_type::INT16 | data_type::UINT16 => 2,
        data_type::FLOAT32 | data_type::INT32 | data_type::UINT32 => 4,
        data_type::INT64 | data_type::UINT64 => 8,
        _ => return None,
    };
    Some(width)
}
37
/// Raw values of the `MPSGraphTensorNamedDataLayout` cases used by this crate.
///
/// The constants are plain `usize` values so they can be stored directly in the
/// descriptor info structs.
pub mod tensor_named_data_layout {
    /// Raw value of the `NCHW` layout.
    pub const NCHW: usize = 0;
    /// Raw value of the `NHWC` layout.
    pub const NHWC: usize = 1;
    /// Raw value of the `OIHW` layout.
    pub const OIHW: usize = 2;
    /// Raw value of the `HWIO` layout.
    pub const HWIO: usize = 3;
    /// Raw value of the `CHW` layout.
    pub const CHW: usize = 4;
    /// Raw value of the `HWC` layout.
    pub const HWC: usize = 5;
    /// Raw value of the `HW` layout.
    pub const HW: usize = 6;
}
48
/// Raw values of the `MPSGraphPaddingStyle` cases used by this crate.
///
/// The constants are plain `usize` values so they can be stored directly in the
/// descriptor info structs.
pub mod padding_style {
    /// Raw value of the explicit padding style.
    pub const EXPLICIT: usize = 0;
    /// Raw value of the `TF_VALID` padding style.
    pub const TF_VALID: usize = 1;
    /// Raw value of the `TF_SAME` padding style.
    pub const TF_SAME: usize = 2;
    /// Raw value of the explicit-offset padding style.
    pub const EXPLICIT_OFFSET: usize = 3;
    /// Raw value of the `ONNX_SAME_LOWER` padding style.
    pub const ONNX_SAME_LOWER: usize = 4;
}
57
/// Declares a public owning wrapper around a raw object pointer.
///
/// The generated type stores a single `*mut c_void`, exposes it through `as_ptr`, and on
/// drop passes a non-null pointer to `ffi::mpsgraph_object_release` exactly once.
macro_rules! opaque_handle {
    ($name:ident) => {
        pub struct $name {
            ptr: *mut c_void,
        }

        // NOTE(review): `Send`/`Sync` are asserted unconditionally for every generated
        // handle — confirm the wrapped objects may be moved and shared across threads.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}

        impl $name {
            /// Returns the wrapped raw pointer without transferring ownership.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }

        impl Drop for $name {
            fn drop(&mut self) {
                // A null handle owns nothing, so there is nothing to release.
                if self.ptr.is_null() {
                    return;
                }
                // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
                unsafe { ffi::mpsgraph_object_release(self.ptr) };
                self.ptr = ptr::null_mut();
            }
        }
    };
}
85
86fn checked_byte_len(shape: &[usize], data_type: u32) -> Option<usize> {
87    let element_size = data_type_size(data_type)?;
88    shape
89        .iter()
90        .try_fold(element_size, |acc, dimension| acc.checked_mul(*dimension))
91}
92
/// Converts an optional name into an optional C string.
///
/// Returns `None` both when no name was given and when the name cannot be represented as a
/// C string because it contains an interior NUL byte.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
96
/// Returns a pointer to the C string's bytes, or a null pointer when there is none.
///
/// The pointer borrows from `value` and is only valid while that `CString` is alive.
// Taking `&Option<CString>` keeps the owned string in the caller's frame for the whole
// FFI call, hence the lint exemption.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(text) => text.as_ptr(),
        None => ptr::null(),
    }
}
101
102fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
103    if ptr.is_null() {
104        None
105    } else {
106        Some(Tensor { ptr })
107    }
108}
109
110fn wrap_tensor_data_results(
111    handles: Vec<*mut c_void>,
112    message: &'static str,
113) -> Result<Vec<TensorData>> {
114    let mut results = Vec::with_capacity(handles.len());
115    for handle in handles {
116        if handle.is_null() {
117            return Err(Error::OperationFailed(message));
118        }
119        results.push(TensorData::from_raw(handle));
120    }
121    Ok(results)
122}
123
/// Generates a `Graph` method `$fn_name` that forwards two tensor operands and an optional
/// name to the FFI function `$ffi_name` and wraps the returned tensor pointer.
///
/// The generated method returns `None` when the FFI call yields a null pointer.
macro_rules! impl_binary_tensor_op {
    ($fn_name:ident, $ffi_name:ident) => {
        #[must_use]
        pub fn $fn_name(
            &self,
            primary: &Tensor,
            secondary: &Tensor,
            name: Option<&str>,
        ) -> Option<Tensor> {
            // Keep the C string alive in this frame until the call has returned.
            let label = optional_cstring(name);
            let (lhs, rhs) = (primary.as_ptr(), secondary.as_ptr());
            // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
            let raw = unsafe { ffi::$ffi_name(self.ptr, lhs, rhs, cstring_ptr(&label)) };
            wrap_tensor(raw)
        }
    };
}
147
/// Generates a `Graph` method `$fn_name` that forwards one tensor operand and an optional
/// name to the FFI function `$ffi_name` and wraps the returned tensor pointer.
///
/// The generated method returns `None` when the FFI call yields a null pointer.
macro_rules! impl_unary_tensor_op {
    ($fn_name:ident, $ffi_name:ident) => {
        #[must_use]
        pub fn $fn_name(&self, tensor: &Tensor, name: Option<&str>) -> Option<Tensor> {
            // Keep the C string alive in this frame until the call has returned.
            let label = optional_cstring(name);
            let operand = tensor.as_ptr();
            // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
            let raw = unsafe { ffi::$ffi_name(self.ptr, operand, cstring_ptr(&label)) };
            wrap_tensor(raw)
        }
    };
}
159
/// Generates a `Graph` method `$fn_name` that forwards a tensor, a list of axes and an
/// optional name to the FFI function `$ffi_name` and wraps the returned tensor pointer.
///
/// The axes are passed as a pointer/length pair. The generated method returns `None` when
/// the FFI call yields a null pointer.
macro_rules! impl_axes_tensor_op {
    ($fn_name:ident, $ffi_name:ident) => {
        #[must_use]
        pub fn $fn_name(
            &self,
            tensor: &Tensor,
            axes: &[usize],
            name: Option<&str>,
        ) -> Option<Tensor> {
            // Keep the C string alive in this frame until the call has returned.
            let label = optional_cstring(name);
            let (axes_ptr, axes_len) = (axes.as_ptr(), axes.len());
            // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
            let raw = unsafe {
                ffi::$ffi_name(
                    self.ptr,
                    tensor.as_ptr(),
                    axes_ptr,
                    axes_len,
                    cstring_ptr(&label),
                )
            };
            wrap_tensor(raw)
        }
    };
}
184
185/// Ordered placeholder feed pairing used for graph execution.
186#[derive(Clone, Copy)]
187pub struct Feed<'a> {
188    pub tensor: &'a Tensor,
189    pub data: &'a TensorData,
190}
191
192impl<'a> Feed<'a> {
193    #[must_use]
194    pub const fn new(tensor: &'a Tensor, data: &'a TensorData) -> Self {
195        Self { tensor, data }
196    }
197}
198
199/// Feed metadata used to compile a graph into an executable.
200#[derive(Clone, Copy)]
201pub struct FeedDescription<'a> {
202    pub tensor: &'a Tensor,
203    pub shape: &'a [usize],
204    pub data_type: u32,
205}
206
207impl<'a> FeedDescription<'a> {
208    #[must_use]
209    pub const fn new(tensor: &'a Tensor, shape: &'a [usize], data_type: u32) -> Self {
210        Self {
211            tensor,
212            shape,
213            data_type,
214        }
215    }
216}
217
218/// Plain-Rust configuration for `MPSGraphConvolution2DOpDescriptor`.
219#[derive(Debug, Clone, Copy)]
220pub struct Convolution2DDescriptorInfo {
221    pub stride_in_x: usize,
222    pub stride_in_y: usize,
223    pub dilation_rate_in_x: usize,
224    pub dilation_rate_in_y: usize,
225    pub groups: usize,
226    pub padding_left: usize,
227    pub padding_right: usize,
228    pub padding_top: usize,
229    pub padding_bottom: usize,
230    pub padding_style: usize,
231    pub data_layout: usize,
232    pub weights_layout: usize,
233}
234
235impl Default for Convolution2DDescriptorInfo {
236    fn default() -> Self {
237        Self {
238            stride_in_x: 1,
239            stride_in_y: 1,
240            dilation_rate_in_x: 1,
241            dilation_rate_in_y: 1,
242            groups: 1,
243            padding_left: 0,
244            padding_right: 0,
245            padding_top: 0,
246            padding_bottom: 0,
247            padding_style: padding_style::EXPLICIT,
248            data_layout: tensor_named_data_layout::NHWC,
249            weights_layout: tensor_named_data_layout::HWIO,
250        }
251    }
252}
253
254opaque_handle!(Convolution2DDescriptor);
255impl Convolution2DDescriptor {
256    #[must_use]
257    pub fn new(info: Convolution2DDescriptorInfo) -> Option<Self> {
258        // SAFETY: All scalar configuration values are POD.
259        let ptr = unsafe {
260            ffi::mpsgraph_convolution2d_descriptor_new(
261                info.stride_in_x,
262                info.stride_in_y,
263                info.dilation_rate_in_x,
264                info.dilation_rate_in_y,
265                info.groups,
266                info.padding_left,
267                info.padding_right,
268                info.padding_top,
269                info.padding_bottom,
270                info.padding_style,
271                info.data_layout,
272                info.weights_layout,
273            )
274        };
275        if ptr.is_null() {
276            None
277        } else {
278            Some(Self { ptr })
279        }
280    }
281}
282
283/// Plain-Rust configuration for `MPSGraphPooling2DOpDescriptor`.
284#[derive(Debug, Clone, Copy)]
285pub struct Pooling2DDescriptorInfo {
286    pub kernel_width: usize,
287    pub kernel_height: usize,
288    pub stride_in_x: usize,
289    pub stride_in_y: usize,
290    pub dilation_rate_in_x: usize,
291    pub dilation_rate_in_y: usize,
292    pub padding_left: usize,
293    pub padding_right: usize,
294    pub padding_top: usize,
295    pub padding_bottom: usize,
296    pub padding_style: usize,
297    pub data_layout: usize,
298}
299
300impl Pooling2DDescriptorInfo {
301    #[must_use]
302    pub const fn new(kernel_width: usize, kernel_height: usize) -> Self {
303        Self {
304            kernel_width,
305            kernel_height,
306            stride_in_x: 1,
307            stride_in_y: 1,
308            dilation_rate_in_x: 1,
309            dilation_rate_in_y: 1,
310            padding_left: 0,
311            padding_right: 0,
312            padding_top: 0,
313            padding_bottom: 0,
314            padding_style: padding_style::EXPLICIT,
315            data_layout: tensor_named_data_layout::NHWC,
316        }
317    }
318}
319
320opaque_handle!(Pooling2DDescriptor);
321impl Pooling2DDescriptor {
322    #[must_use]
323    pub fn new(info: Pooling2DDescriptorInfo) -> Option<Self> {
324        // SAFETY: All scalar configuration values are POD.
325        let ptr = unsafe {
326            ffi::mpsgraph_pooling2d_descriptor_new(
327                info.kernel_width,
328                info.kernel_height,
329                info.stride_in_x,
330                info.stride_in_y,
331                info.dilation_rate_in_x,
332                info.dilation_rate_in_y,
333                info.padding_left,
334                info.padding_right,
335                info.padding_top,
336                info.padding_bottom,
337                info.padding_style,
338                info.data_layout,
339            )
340        };
341        if ptr.is_null() {
342            None
343        } else {
344            Some(Self { ptr })
345        }
346    }
347}
348
349opaque_handle!(Graph);
350opaque_handle!(Tensor);
351
352impl Tensor {
353    pub(crate) const fn from_raw(ptr: *mut c_void) -> Self {
354        Self { ptr }
355    }
356}
357
358impl Graph {
359    #[must_use]
360    pub fn new() -> Option<Self> {
361        // SAFETY: Pure constructor with no inputs.
362        let ptr = unsafe { ffi::mpsgraph_graph_new() };
363        if ptr.is_null() {
364            None
365        } else {
366            Some(Self { ptr })
367        }
368    }
369
370    #[must_use]
371    pub fn placeholder(
372        &self,
373        shape: Option<&[usize]>,
374        data_type: u32,
375        name: Option<&str>,
376    ) -> Option<Tensor> {
377        let name = optional_cstring(name);
378        let (shape_ptr, shape_len) =
379            shape.map_or((ptr::null(), 0), |shape| (shape.as_ptr(), shape.len()));
380
381        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
382        let ptr = unsafe {
383            ffi::mpsgraph_graph_placeholder(
384                self.ptr,
385                shape_ptr,
386                shape_len,
387                data_type,
388                cstring_ptr(&name),
389            )
390        };
391        wrap_tensor(ptr)
392    }
393
394    #[must_use]
395    pub fn constant_bytes(&self, data: &[u8], shape: &[usize], data_type: u32) -> Option<Tensor> {
396        let expected = checked_byte_len(shape, data_type)?;
397        if data.len() != expected {
398            return None;
399        }
400
401        // SAFETY: The byte slice remains valid for the duration of the FFI call.
402        let ptr = unsafe {
403            ffi::mpsgraph_graph_constant_data(
404                self.ptr,
405                data.as_ptr().cast(),
406                data.len(),
407                shape.as_ptr(),
408                shape.len(),
409                data_type,
410            )
411        };
412        wrap_tensor(ptr)
413    }
414
415    #[must_use]
416    pub fn constant_f32_slice(&self, values: &[f32], shape: &[usize]) -> Option<Tensor> {
417        // SAFETY: `values` is a contiguous slice of `f32` that may be viewed as bytes.
418        let bytes = unsafe {
419            core::slice::from_raw_parts(
420                values.as_ptr().cast::<u8>(),
421                core::mem::size_of_val(values),
422            )
423        };
424        self.constant_bytes(bytes, shape, data_type::FLOAT32)
425    }
426
427    #[must_use]
428    pub fn constant_scalar(&self, scalar: f64, data_type: u32) -> Option<Tensor> {
429        // SAFETY: Pure constructor over scalar inputs.
430        let ptr = unsafe { ffi::mpsgraph_graph_constant_scalar(self.ptr, scalar, data_type) };
431        wrap_tensor(ptr)
432    }
433
434    #[must_use]
435    pub fn constant_scalar_shaped(
436        &self,
437        scalar: f64,
438        shape: &[usize],
439        data_type: u32,
440    ) -> Option<Tensor> {
441        // SAFETY: Shape slice stays valid for the duration of the FFI call.
442        let ptr = unsafe {
443            ffi::mpsgraph_graph_constant_scalar_shaped(
444                self.ptr,
445                scalar,
446                shape.as_ptr(),
447                shape.len(),
448                data_type,
449            )
450        };
451        wrap_tensor(ptr)
452    }
453
    // Two-operand methods: `(primary, secondary, name) -> Option<Tensor>`.
    impl_binary_tensor_op!(addition, mpsgraph_graph_addition);
    impl_binary_tensor_op!(subtraction, mpsgraph_graph_subtraction);
    impl_binary_tensor_op!(multiplication, mpsgraph_graph_multiplication);
    impl_binary_tensor_op!(division, mpsgraph_graph_division);
    impl_binary_tensor_op!(matrix_multiplication, mpsgraph_graph_matrix_multiplication);
    // One-operand methods: `(tensor, name) -> Option<Tensor>`.
    impl_unary_tensor_op!(relu, mpsgraph_graph_relu);
    impl_unary_tensor_op!(sigmoid, mpsgraph_graph_sigmoid);
    // Methods taking a list of axes: `(tensor, axes, name) -> Option<Tensor>`.
    impl_axes_tensor_op!(reduction_sum, mpsgraph_graph_reduction_sum);
    impl_axes_tensor_op!(reduction_maximum, mpsgraph_graph_reduction_maximum);
    impl_axes_tensor_op!(reduction_minimum, mpsgraph_graph_reduction_minimum);
    impl_axes_tensor_op!(mean, mpsgraph_graph_mean);
465
466    #[must_use]
467    pub fn softmax(&self, tensor: &Tensor, axis: isize, name: Option<&str>) -> Option<Tensor> {
468        let name = optional_cstring(name);
469        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
470        let ptr = unsafe {
471            ffi::mpsgraph_graph_softmax(self.ptr, tensor.as_ptr(), axis, cstring_ptr(&name))
472        };
473        wrap_tensor(ptr)
474    }
475
476    #[must_use]
477    pub fn reshape(&self, tensor: &Tensor, shape: &[usize], name: Option<&str>) -> Option<Tensor> {
478        let name = optional_cstring(name);
479        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
480        let ptr = unsafe {
481            ffi::mpsgraph_graph_reshape(
482                self.ptr,
483                tensor.as_ptr(),
484                shape.as_ptr(),
485                shape.len(),
486                cstring_ptr(&name),
487            )
488        };
489        wrap_tensor(ptr)
490    }
491
492    #[must_use]
493    pub fn transpose(
494        &self,
495        tensor: &Tensor,
496        permutation: &[usize],
497        name: Option<&str>,
498    ) -> Option<Tensor> {
499        let name = optional_cstring(name);
500        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
501        let ptr = unsafe {
502            ffi::mpsgraph_graph_transpose(
503                self.ptr,
504                tensor.as_ptr(),
505                permutation.as_ptr(),
506                permutation.len(),
507                cstring_ptr(&name),
508            )
509        };
510        wrap_tensor(ptr)
511    }
512
513    #[must_use]
514    pub fn slice(
515        &self,
516        tensor: &Tensor,
517        dimension: usize,
518        start: isize,
519        length: isize,
520        name: Option<&str>,
521    ) -> Option<Tensor> {
522        let name = optional_cstring(name);
523        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
524        let ptr = unsafe {
525            ffi::mpsgraph_graph_slice(
526                self.ptr,
527                tensor.as_ptr(),
528                dimension,
529                start,
530                length,
531                cstring_ptr(&name),
532            )
533        };
534        wrap_tensor(ptr)
535    }
536
537    #[must_use]
538    pub fn broadcast(
539        &self,
540        tensor: &Tensor,
541        shape: &[usize],
542        name: Option<&str>,
543    ) -> Option<Tensor> {
544        let name = optional_cstring(name);
545        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
546        let ptr = unsafe {
547            ffi::mpsgraph_graph_broadcast(
548                self.ptr,
549                tensor.as_ptr(),
550                shape.as_ptr(),
551                shape.len(),
552                cstring_ptr(&name),
553            )
554        };
555        wrap_tensor(ptr)
556    }
557
558    #[must_use]
559    pub fn convolution2d(
560        &self,
561        source: &Tensor,
562        weights: &Tensor,
563        descriptor: &Convolution2DDescriptor,
564        name: Option<&str>,
565    ) -> Option<Tensor> {
566        let name = optional_cstring(name);
567        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
568        let ptr = unsafe {
569            ffi::mpsgraph_graph_convolution2d(
570                self.ptr,
571                source.as_ptr(),
572                weights.as_ptr(),
573                descriptor.as_ptr(),
574                cstring_ptr(&name),
575            )
576        };
577        wrap_tensor(ptr)
578    }
579
580    #[must_use]
581    pub fn max_pooling2d(
582        &self,
583        source: &Tensor,
584        descriptor: &Pooling2DDescriptor,
585        name: Option<&str>,
586    ) -> Option<Tensor> {
587        let name = optional_cstring(name);
588        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
589        let ptr = unsafe {
590            ffi::mpsgraph_graph_max_pooling2d(
591                self.ptr,
592                source.as_ptr(),
593                descriptor.as_ptr(),
594                cstring_ptr(&name),
595            )
596        };
597        wrap_tensor(ptr)
598    }
599
600    #[allow(clippy::too_many_arguments)]
601    #[must_use]
602    pub fn normalize(
603        &self,
604        tensor: &Tensor,
605        mean: &Tensor,
606        variance: &Tensor,
607        gamma: Option<&Tensor>,
608        beta: Option<&Tensor>,
609        epsilon: f32,
610        name: Option<&str>,
611    ) -> Option<Tensor> {
612        let name = optional_cstring(name);
613        let gamma_ptr = gamma.map_or(ptr::null_mut(), Tensor::as_ptr);
614        let beta_ptr = beta.map_or(ptr::null_mut(), Tensor::as_ptr);
615        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
616        let ptr = unsafe {
617            ffi::mpsgraph_graph_normalize(
618                self.ptr,
619                tensor.as_ptr(),
620                mean.as_ptr(),
621                variance.as_ptr(),
622                gamma_ptr,
623                beta_ptr,
624                epsilon,
625                cstring_ptr(&name),
626            )
627        };
628        wrap_tensor(ptr)
629    }
630
631    pub fn run(&self, feeds: &[Feed<'_>], targets: &[&Tensor]) -> Result<Vec<TensorData>> {
632        let feed_tensors = feeds
633            .iter()
634            .map(|feed| feed.tensor.as_ptr())
635            .collect::<Vec<_>>();
636        let feed_data = feeds
637            .iter()
638            .map(|feed| feed.data.as_ptr())
639            .collect::<Vec<_>>();
640        let target_tensors = targets
641            .iter()
642            .map(|tensor| tensor.as_ptr())
643            .collect::<Vec<_>>();
644        let mut results = vec![ptr::null_mut(); targets.len()];
645
646        // SAFETY: The pointer arrays are valid for the duration of the FFI call.
647        let ok = unsafe {
648            ffi::mpsgraph_graph_run(
649                self.ptr,
650                feed_tensors.as_ptr(),
651                feed_data.as_ptr(),
652                feeds.len(),
653                target_tensors.as_ptr(),
654                targets.len(),
655                results.as_mut_ptr(),
656            )
657        };
658        if ok {
659            wrap_tensor_data_results(results, "failed to run graph")
660        } else {
661            Err(Error::OperationFailed("failed to run graph"))
662        }
663    }
664
665    pub fn run_with_command_queue(
666        &self,
667        command_queue: &CommandQueue,
668        feeds: &[Feed<'_>],
669        targets: &[&Tensor],
670    ) -> Result<Vec<TensorData>> {
671        let feed_tensors = feeds
672            .iter()
673            .map(|feed| feed.tensor.as_ptr())
674            .collect::<Vec<_>>();
675        let feed_data = feeds
676            .iter()
677            .map(|feed| feed.data.as_ptr())
678            .collect::<Vec<_>>();
679        let target_tensors = targets
680            .iter()
681            .map(|tensor| tensor.as_ptr())
682            .collect::<Vec<_>>();
683        let mut results = vec![ptr::null_mut(); targets.len()];
684
685        // SAFETY: The pointer arrays are valid for the duration of the FFI call.
686        let ok = unsafe {
687            ffi::mpsgraph_graph_run_with_command_queue(
688                self.ptr,
689                command_queue.as_ptr(),
690                feed_tensors.as_ptr(),
691                feed_data.as_ptr(),
692                feeds.len(),
693                target_tensors.as_ptr(),
694                targets.len(),
695                results.as_mut_ptr(),
696            )
697        };
698        if ok {
699            wrap_tensor_data_results(results, "failed to run graph with command queue")
700        } else {
701            Err(Error::OperationFailed(
702                "failed to run graph with command queue",
703            ))
704        }
705    }
706
707    #[must_use]
708    pub fn compile(
709        &self,
710        device: &MetalDevice,
711        feeds: &[FeedDescription<'_>],
712        targets: &[&Tensor],
713    ) -> Option<Executable> {
714        let feed_tensors = feeds
715            .iter()
716            .map(|feed| feed.tensor.as_ptr())
717            .collect::<Vec<_>>();
718        let shape_lengths = feeds
719            .iter()
720            .map(|feed| feed.shape.len())
721            .collect::<Vec<_>>();
722        let data_types = feeds.iter().map(|feed| feed.data_type).collect::<Vec<_>>();
723        let flat_shapes = feeds
724            .iter()
725            .flat_map(|feed| feed.shape.iter().copied())
726            .collect::<Vec<_>>();
727        let target_tensors = targets
728            .iter()
729            .map(|tensor| tensor.as_ptr())
730            .collect::<Vec<_>>();
731
732        // SAFETY: The pointer arrays are valid for the duration of the FFI call.
733        let ptr = unsafe {
734            ffi::mpsgraph_graph_compile(
735                self.ptr,
736                device.as_ptr(),
737                feed_tensors.as_ptr(),
738                feeds.len(),
739                flat_shapes.as_ptr(),
740                shape_lengths.as_ptr(),
741                data_types.as_ptr(),
742                target_tensors.as_ptr(),
743                targets.len(),
744            )
745        };
746        if ptr.is_null() {
747            None
748        } else {
749            Some(Executable::from_raw(ptr, targets.len()))
750        }
751    }
752}
753
754/// Safe owner for a compiled `MPSGraphExecutable`.
755pub struct Executable {
756    ptr: *mut c_void,
757    output_count: usize,
758}
759
760unsafe impl Send for Executable {}
761unsafe impl Sync for Executable {}
762
763impl Drop for Executable {
764    fn drop(&mut self) {
765        if !self.ptr.is_null() {
766            // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
767            unsafe { ffi::mpsgraph_object_release(self.ptr) };
768            self.ptr = ptr::null_mut();
769        }
770    }
771}
772
773impl Executable {
774    pub(crate) const fn from_raw(ptr: *mut c_void, output_count: usize) -> Self {
775        Self { ptr, output_count }
776    }
777
778    #[must_use]
779    pub const fn as_ptr(&self) -> *mut c_void {
780        self.ptr
781    }
782
783    #[must_use]
784    pub const fn output_count(&self) -> usize {
785        self.output_count
786    }
787
788    pub fn run(
789        &self,
790        command_queue: &CommandQueue,
791        inputs: &[&TensorData],
792    ) -> Result<Vec<TensorData>> {
793        let input_data = inputs
794            .iter()
795            .map(|tensor_data| tensor_data.as_ptr())
796            .collect::<Vec<_>>();
797        let mut results = vec![ptr::null_mut(); self.output_count];
798
799        // SAFETY: The pointer arrays are valid for the duration of the FFI call.
800        let ok = unsafe {
801            ffi::mpsgraph_executable_run(
802                self.ptr,
803                command_queue.as_ptr(),
804                input_data.as_ptr(),
805                inputs.len(),
806                self.output_count,
807                results.as_mut_ptr(),
808            )
809        };
810        if ok {
811            wrap_tensor_data_results(results, "failed to run executable")
812        } else {
813            Err(Error::OperationFailed("failed to run executable"))
814        }
815    }
816}