Skip to main content

apple_mpsgraph/
graph.rs

1use crate::data::TensorData;
2use crate::error::{Error, Result};
3use crate::ffi;
4use apple_metal::{CommandQueue, MetalDevice};
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Selected `MPSDataType` constants useful for graph inputs and outputs.
///
/// Each value is a category flag combined with the element width in bits, matching the
/// encoding Metal Performance Shaders uses for `MPSDataType` (for example `FLOAT32` is the
/// float flag `0x1000_0000` OR'd with `32`).
pub mod data_type {
    /// `MPSDataTypeFloatBit`.
    const FLOAT_BIT: u32 = 0x1000_0000;
    /// `MPSDataTypeSignedBit`.
    const SIGNED_BIT: u32 = 0x2000_0000;
    /// `MPSDataTypeNormalizedBit`.
    const NORMALIZED_BIT: u32 = 0x4000_0000;
    /// `MPSDataTypeAlternateEncodingBit`.
    const ALTERNATE_ENCODING_BIT: u32 = 0x8000_0000;

    /// `MPSDataTypeInvalid`.
    pub const INVALID: u32 = 0;

    /// `MPSDataTypeFloat32`.
    pub const FLOAT32: u32 = FLOAT_BIT | 32;
    /// `MPSDataTypeFloat16`.
    pub const FLOAT16: u32 = FLOAT_BIT | 16;

    /// `MPSDataTypeInt8`.
    pub const INT8: u32 = SIGNED_BIT | 8;
    /// `MPSDataTypeInt16`.
    pub const INT16: u32 = SIGNED_BIT | 16;
    /// `MPSDataTypeInt32`.
    pub const INT32: u32 = SIGNED_BIT | 32;
    /// `MPSDataTypeInt64`.
    pub const INT64: u32 = SIGNED_BIT | 64;

    // Unsigned integers carry no category flag: the value is just the bit width.
    /// `MPSDataTypeUInt8`.
    pub const UINT8: u32 = 8;
    /// `MPSDataTypeUInt16`.
    pub const UINT16: u32 = 16;
    /// `MPSDataTypeUInt32`.
    pub const UINT32: u32 = 32;
    /// `MPSDataTypeUInt64`.
    pub const UINT64: u32 = 64;

    /// `MPSDataTypeBool` (one byte per element).
    pub const BOOL: u32 = ALTERNATE_ENCODING_BIT | 8;
    /// `MPSDataTypeUnorm8`.
    pub const UNORM8: u32 = NORMALIZED_BIT | 8;
}
25
26/// Return the byte width of a supported `MPSDataType`.
27#[must_use]
28pub const fn data_type_size(data_type: u32) -> Option<usize> {
29    match data_type {
30        data_type::FLOAT16 | data_type::INT16 | data_type::UINT16 => Some(2),
31        data_type::FLOAT32 | data_type::INT32 | data_type::UINT32 => Some(4),
32        data_type::INT64 | data_type::UINT64 => Some(8),
33        data_type::INT8 | data_type::UINT8 | data_type::BOOL | data_type::UNORM8 => Some(1),
34        _ => None,
35    }
36}
37
/// `MPSGraphTensorNamedDataLayout` constants.
///
/// Letters name tensor axes from outermost to innermost: `N` batch, `C` channels, `D` depth,
/// `H` height, `W` width; weight layouts use `O` (output channels) and `I` (input channels).
pub mod tensor_named_data_layout {
    /// `MPSGraphTensorNamedDataLayoutNCHW`.
    pub const NCHW: usize = 0;
    /// `MPSGraphTensorNamedDataLayoutNHWC`.
    pub const NHWC: usize = 1;
    /// `MPSGraphTensorNamedDataLayoutOIHW`.
    pub const OIHW: usize = 2;
    /// `MPSGraphTensorNamedDataLayoutHWIO`.
    pub const HWIO: usize = 3;
    /// `MPSGraphTensorNamedDataLayoutCHW`.
    pub const CHW: usize = 4;
    /// `MPSGraphTensorNamedDataLayoutHWC`.
    pub const HWC: usize = 5;
    /// `MPSGraphTensorNamedDataLayoutHW`.
    pub const HW: usize = 6;
    /// `MPSGraphTensorNamedDataLayoutNCDHW`.
    pub const NCDHW: usize = 7;
    /// `MPSGraphTensorNamedDataLayoutNDHWC`.
    pub const NDHWC: usize = 8;
    /// `MPSGraphTensorNamedDataLayoutOIDHW`.
    pub const OIDHW: usize = 9;
    /// `MPSGraphTensorNamedDataLayoutDHWIO`.
    pub const DHWIO: usize = 10;
}
52
/// `MPSGraphPaddingStyle` constants.
///
/// Used for the `padding_style` field of the convolution and pooling descriptor infos.
pub mod padding_style {
    /// `MPSGraphPaddingStyleExplicit`: use the explicit padding values from the descriptor.
    pub const EXPLICIT: usize = 0;
    /// `MPSGraphPaddingStyleTF_VALID`: TensorFlow-style `VALID` padding.
    pub const TF_VALID: usize = 1;
    /// `MPSGraphPaddingStyleTF_SAME`: TensorFlow-style `SAME` padding.
    pub const TF_SAME: usize = 2;
    /// `MPSGraphPaddingStyleExplicitOffset`.
    pub const EXPLICIT_OFFSET: usize = 3;
    /// `MPSGraphPaddingStyleONNX_SAME_LOWER`: ONNX-style `SAME_LOWER` padding.
    pub const ONNX_SAME_LOWER: usize = 4;
}
61
/// `MPSGraphPaddingMode` constants.
///
/// Typed as `isize` to match the `NSInteger`-sized raw values used by the framework.
pub mod padding_mode {
    /// `MPSGraphPaddingModeConstant`.
    pub const CONSTANT: isize = 0;
    /// `MPSGraphPaddingModeReflect`.
    pub const REFLECT: isize = 1;
    /// `MPSGraphPaddingModeSymmetric`.
    pub const SYMMETRIC: isize = 2;
    /// `MPSGraphPaddingModeClampToEdge`.
    pub const CLAMP_TO_EDGE: isize = 3;
    /// `MPSGraphPaddingModeZero`.
    pub const ZERO: isize = 4;
    /// `MPSGraphPaddingModePeriodic`.
    pub const PERIODIC: isize = 5;
    /// `MPSGraphPaddingModeAntiPeriodic`.
    pub const ANTI_PERIODIC: isize = 6;
}
72
// Declares an owning wrapper around a +1 retained Swift/ObjC object pointer handed out by the
// bridge. The wrapper releases the object exactly once when dropped.
macro_rules! opaque_handle {
    ($name:ident) => {
        #[doc = concat!(
            "Owning handle `",
            stringify!($name),
            "` around a retained bridge object; released on drop."
        )]
        pub struct $name {
            ptr: *mut c_void,
        }

        // SAFETY: NOTE(review): assumes the bridged object may be retained, released and used
        // from any thread. The framework does not document concurrent mutation of these objects
        // as thread-safe, so confirm before relying on `Sync`.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                // Null the field first so the handle can never be released twice.
                let owned = core::mem::replace(&mut self.ptr, ptr::null_mut());
                if owned.is_null() {
                    return;
                }
                // SAFETY: `owned` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
                unsafe { ffi::mpsgraph_object_release(owned) };
            }
        }

        impl $name {
            /// Borrow the raw object pointer; ownership stays with `self`.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
100
101fn checked_byte_len(shape: &[usize], data_type: u32) -> Option<usize> {
102    let element_size = data_type_size(data_type)?;
103    shape
104        .iter()
105        .try_fold(element_size, |acc, dimension| acc.checked_mul(*dimension))
106}
107
/// Convert an optional operation name into an owned C string.
///
/// Returns `None` when no name is given. It also returns `None` when the name contains an
/// interior NUL byte: such a name cannot be represented as a C string, so the operation is
/// left unnamed rather than failing.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
111
/// Borrow a raw C-string pointer from an optional owned name, or null when there is none.
///
/// The returned pointer aliases `value`'s buffer, so `value` must outlive every use of it.
// Takes `&Option<CString>` (rather than `Option<&CString>`) so call sites can pass `&name`
// straight from `optional_cstring`; that is why clippy's `ref_option` lint is allowed.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(owned) => owned.as_ptr(),
        None => ptr::null(),
    }
}
116
117fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
118    if ptr.is_null() {
119        None
120    } else {
121        Some(Tensor { ptr })
122    }
123}
124
125fn wrap_tensor_data_results(
126    handles: Vec<*mut c_void>,
127    message: &'static str,
128) -> Result<Vec<TensorData>> {
129    let mut results = Vec::with_capacity(handles.len());
130    for handle in handles {
131        if handle.is_null() {
132            return Err(Error::OperationFailed(message));
133        }
134        results.push(TensorData::from_raw(handle));
135    }
136    Ok(results)
137}
138
// Generates a `Graph` method for a two-operand bridge call of the shape
// `ffi::$ffi_name(graph, primary, secondary, name) -> tensor`.
macro_rules! impl_binary_tensor_op {
    ($fn_name:ident, $ffi_name:ident) => {
        #[doc = concat!(
            "Wrapper for `ffi::",
            stringify!($ffi_name),
            "`; returns `None` if the bridge returns a null tensor."
        )]
        #[must_use]
        pub fn $fn_name(
            &self,
            primary: &Tensor,
            secondary: &Tensor,
            name: Option<&str>,
        ) -> Option<Tensor> {
            let label = optional_cstring(name);
            let graph = self.ptr;
            let (lhs, rhs) = (primary.as_ptr(), secondary.as_ptr());
            // SAFETY: All pointers originate from safe wrappers, and `label` outlives the call so
            // the name pointer stays valid.
            let raw = unsafe { ffi::$ffi_name(graph, lhs, rhs, cstring_ptr(&label)) };
            wrap_tensor(raw)
        }
    };
}
162
// Generates a `Graph` method for a single-operand bridge call of the shape
// `ffi::$ffi_name(graph, tensor, name) -> tensor`.
macro_rules! impl_unary_tensor_op {
    ($fn_name:ident, $ffi_name:ident) => {
        #[doc = concat!(
            "Wrapper for `ffi::",
            stringify!($ffi_name),
            "`; returns `None` if the bridge returns a null tensor."
        )]
        #[must_use]
        pub fn $fn_name(&self, tensor: &Tensor, name: Option<&str>) -> Option<Tensor> {
            let label = optional_cstring(name);
            let input = tensor.as_ptr();
            // SAFETY: All pointers originate from safe wrappers, and `label` outlives the call so
            // the name pointer stays valid.
            let raw = unsafe { ffi::$ffi_name(self.ptr, input, cstring_ptr(&label)) };
            wrap_tensor(raw)
        }
    };
}
174
// Generates a `Graph` method for a bridge call that takes a tensor plus a list of axes, of the
// shape `ffi::$ffi_name(graph, tensor, axes_ptr, axes_len, name) -> tensor`.
macro_rules! impl_axes_tensor_op {
    ($fn_name:ident, $ffi_name:ident) => {
        #[doc = concat!(
            "Wrapper for `ffi::",
            stringify!($ffi_name),
            "` over `axes`; returns `None` if the bridge returns a null tensor."
        )]
        #[must_use]
        pub fn $fn_name(
            &self,
            tensor: &Tensor,
            axes: &[usize],
            name: Option<&str>,
        ) -> Option<Tensor> {
            let label = optional_cstring(name);
            let (axes_ptr, axes_len) = (axes.as_ptr(), axes.len());
            // SAFETY: All pointers originate from safe wrappers; `axes` and `label` both outlive
            // the call, so the array and name pointers stay valid.
            let raw = unsafe {
                ffi::$ffi_name(
                    self.ptr,
                    tensor.as_ptr(),
                    axes_ptr,
                    axes_len,
                    cstring_ptr(&label),
                )
            };
            wrap_tensor(raw)
        }
    };
}
199
200/// Ordered placeholder feed pairing used for graph execution.
201#[derive(Clone, Copy)]
202pub struct Feed<'a> {
203    pub tensor: &'a Tensor,
204    pub data: &'a TensorData,
205}
206
207impl<'a> Feed<'a> {
208    #[must_use]
209    pub const fn new(tensor: &'a Tensor, data: &'a TensorData) -> Self {
210        Self { tensor, data }
211    }
212}
213
214/// Feed metadata used to compile a graph into an executable.
215#[derive(Clone, Copy)]
216pub struct FeedDescription<'a> {
217    pub tensor: &'a Tensor,
218    pub shape: &'a [usize],
219    pub data_type: u32,
220}
221
222impl<'a> FeedDescription<'a> {
223    #[must_use]
224    pub const fn new(tensor: &'a Tensor, shape: &'a [usize], data_type: u32) -> Self {
225        Self {
226            tensor,
227            shape,
228            data_type,
229        }
230    }
231}
232
233/// Plain-Rust configuration for `MPSGraphConvolution2DOpDescriptor`.
234#[derive(Debug, Clone, Copy)]
235pub struct Convolution2DDescriptorInfo {
236    pub stride_in_x: usize,
237    pub stride_in_y: usize,
238    pub dilation_rate_in_x: usize,
239    pub dilation_rate_in_y: usize,
240    pub groups: usize,
241    pub padding_left: usize,
242    pub padding_right: usize,
243    pub padding_top: usize,
244    pub padding_bottom: usize,
245    pub padding_style: usize,
246    pub data_layout: usize,
247    pub weights_layout: usize,
248}
249
250impl Default for Convolution2DDescriptorInfo {
251    fn default() -> Self {
252        Self {
253            stride_in_x: 1,
254            stride_in_y: 1,
255            dilation_rate_in_x: 1,
256            dilation_rate_in_y: 1,
257            groups: 1,
258            padding_left: 0,
259            padding_right: 0,
260            padding_top: 0,
261            padding_bottom: 0,
262            padding_style: padding_style::EXPLICIT,
263            data_layout: tensor_named_data_layout::NHWC,
264            weights_layout: tensor_named_data_layout::HWIO,
265        }
266    }
267}
268
269opaque_handle!(Convolution2DDescriptor);
270impl Convolution2DDescriptor {
271    #[must_use]
272    pub fn new(info: Convolution2DDescriptorInfo) -> Option<Self> {
273        // SAFETY: All scalar configuration values are POD.
274        let ptr = unsafe {
275            ffi::mpsgraph_convolution2d_descriptor_new(
276                info.stride_in_x,
277                info.stride_in_y,
278                info.dilation_rate_in_x,
279                info.dilation_rate_in_y,
280                info.groups,
281                info.padding_left,
282                info.padding_right,
283                info.padding_top,
284                info.padding_bottom,
285                info.padding_style,
286                info.data_layout,
287                info.weights_layout,
288            )
289        };
290        if ptr.is_null() {
291            None
292        } else {
293            Some(Self { ptr })
294        }
295    }
296}
297
298/// Plain-Rust configuration for `MPSGraphPooling2DOpDescriptor`.
299#[derive(Debug, Clone, Copy)]
300pub struct Pooling2DDescriptorInfo {
301    pub kernel_width: usize,
302    pub kernel_height: usize,
303    pub stride_in_x: usize,
304    pub stride_in_y: usize,
305    pub dilation_rate_in_x: usize,
306    pub dilation_rate_in_y: usize,
307    pub padding_left: usize,
308    pub padding_right: usize,
309    pub padding_top: usize,
310    pub padding_bottom: usize,
311    pub padding_style: usize,
312    pub data_layout: usize,
313}
314
315impl Pooling2DDescriptorInfo {
316    #[must_use]
317    pub const fn new(kernel_width: usize, kernel_height: usize) -> Self {
318        Self {
319            kernel_width,
320            kernel_height,
321            stride_in_x: 1,
322            stride_in_y: 1,
323            dilation_rate_in_x: 1,
324            dilation_rate_in_y: 1,
325            padding_left: 0,
326            padding_right: 0,
327            padding_top: 0,
328            padding_bottom: 0,
329            padding_style: padding_style::EXPLICIT,
330            data_layout: tensor_named_data_layout::NHWC,
331        }
332    }
333}
334
335opaque_handle!(Pooling2DDescriptor);
336impl Pooling2DDescriptor {
337    #[must_use]
338    pub fn new(info: Pooling2DDescriptorInfo) -> Option<Self> {
339        // SAFETY: All scalar configuration values are POD.
340        let ptr = unsafe {
341            ffi::mpsgraph_pooling2d_descriptor_new(
342                info.kernel_width,
343                info.kernel_height,
344                info.stride_in_x,
345                info.stride_in_y,
346                info.dilation_rate_in_x,
347                info.dilation_rate_in_y,
348                info.padding_left,
349                info.padding_right,
350                info.padding_top,
351                info.padding_bottom,
352                info.padding_style,
353                info.data_layout,
354            )
355        };
356        if ptr.is_null() {
357            None
358        } else {
359            Some(Self { ptr })
360        }
361    }
362}
363
364opaque_handle!(Graph);
365opaque_handle!(Tensor);
366
367impl Tensor {
368    pub(crate) const fn from_raw(ptr: *mut c_void) -> Self {
369        Self { ptr }
370    }
371}
372
373impl Graph {
374    #[must_use]
375    pub fn new() -> Option<Self> {
376        // SAFETY: Pure constructor with no inputs.
377        let ptr = unsafe { ffi::mpsgraph_graph_new() };
378        if ptr.is_null() {
379            None
380        } else {
381            Some(Self { ptr })
382        }
383    }
384
385    #[must_use]
386    pub fn placeholder(
387        &self,
388        shape: Option<&[usize]>,
389        data_type: u32,
390        name: Option<&str>,
391    ) -> Option<Tensor> {
392        let name = optional_cstring(name);
393        let (shape_ptr, shape_len) =
394            shape.map_or((ptr::null(), 0), |shape| (shape.as_ptr(), shape.len()));
395
396        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
397        let ptr = unsafe {
398            ffi::mpsgraph_graph_placeholder(
399                self.ptr,
400                shape_ptr,
401                shape_len,
402                data_type,
403                cstring_ptr(&name),
404            )
405        };
406        wrap_tensor(ptr)
407    }
408
409    #[must_use]
410    pub fn constant_bytes(&self, data: &[u8], shape: &[usize], data_type: u32) -> Option<Tensor> {
411        let expected = checked_byte_len(shape, data_type)?;
412        if data.len() != expected {
413            return None;
414        }
415
416        // SAFETY: The byte slice remains valid for the duration of the FFI call.
417        let ptr = unsafe {
418            ffi::mpsgraph_graph_constant_data(
419                self.ptr,
420                data.as_ptr().cast(),
421                data.len(),
422                shape.as_ptr(),
423                shape.len(),
424                data_type,
425            )
426        };
427        wrap_tensor(ptr)
428    }
429
430    #[must_use]
431    pub fn constant_f32_slice(&self, values: &[f32], shape: &[usize]) -> Option<Tensor> {
432        // SAFETY: `values` is a contiguous slice of `f32` that may be viewed as bytes.
433        let bytes = unsafe {
434            core::slice::from_raw_parts(
435                values.as_ptr().cast::<u8>(),
436                core::mem::size_of_val(values),
437            )
438        };
439        self.constant_bytes(bytes, shape, data_type::FLOAT32)
440    }
441
442    #[must_use]
443    pub fn constant_scalar(&self, scalar: f64, data_type: u32) -> Option<Tensor> {
444        // SAFETY: Pure constructor over scalar inputs.
445        let ptr = unsafe { ffi::mpsgraph_graph_constant_scalar(self.ptr, scalar, data_type) };
446        wrap_tensor(ptr)
447    }
448
449    #[must_use]
450    pub fn constant_scalar_shaped(
451        &self,
452        scalar: f64,
453        shape: &[usize],
454        data_type: u32,
455    ) -> Option<Tensor> {
456        // SAFETY: Shape slice stays valid for the duration of the FFI call.
457        let ptr = unsafe {
458            ffi::mpsgraph_graph_constant_scalar_shaped(
459                self.ptr,
460                scalar,
461                shape.as_ptr(),
462                shape.len(),
463                data_type,
464            )
465        };
466        wrap_tensor(ptr)
467    }
468
469    impl_binary_tensor_op!(addition, mpsgraph_graph_addition);
470    impl_binary_tensor_op!(subtraction, mpsgraph_graph_subtraction);
471    impl_binary_tensor_op!(multiplication, mpsgraph_graph_multiplication);
472    impl_binary_tensor_op!(division, mpsgraph_graph_division);
473    impl_binary_tensor_op!(matrix_multiplication, mpsgraph_graph_matrix_multiplication);
474    impl_unary_tensor_op!(relu, mpsgraph_graph_relu);
475    impl_unary_tensor_op!(sigmoid, mpsgraph_graph_sigmoid);
476    impl_axes_tensor_op!(reduction_sum, mpsgraph_graph_reduction_sum);
477    impl_axes_tensor_op!(reduction_maximum, mpsgraph_graph_reduction_maximum);
478    impl_axes_tensor_op!(reduction_minimum, mpsgraph_graph_reduction_minimum);
479    impl_axes_tensor_op!(mean, mpsgraph_graph_mean);
480
481    #[must_use]
482    pub fn softmax(&self, tensor: &Tensor, axis: isize, name: Option<&str>) -> Option<Tensor> {
483        let name = optional_cstring(name);
484        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
485        let ptr = unsafe {
486            ffi::mpsgraph_graph_softmax(self.ptr, tensor.as_ptr(), axis, cstring_ptr(&name))
487        };
488        wrap_tensor(ptr)
489    }
490
491    #[must_use]
492    pub fn reshape(&self, tensor: &Tensor, shape: &[usize], name: Option<&str>) -> Option<Tensor> {
493        let name = optional_cstring(name);
494        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
495        let ptr = unsafe {
496            ffi::mpsgraph_graph_reshape(
497                self.ptr,
498                tensor.as_ptr(),
499                shape.as_ptr(),
500                shape.len(),
501                cstring_ptr(&name),
502            )
503        };
504        wrap_tensor(ptr)
505    }
506
507    #[must_use]
508    pub fn transpose(
509        &self,
510        tensor: &Tensor,
511        permutation: &[usize],
512        name: Option<&str>,
513    ) -> Option<Tensor> {
514        let name = optional_cstring(name);
515        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
516        let ptr = unsafe {
517            ffi::mpsgraph_graph_transpose(
518                self.ptr,
519                tensor.as_ptr(),
520                permutation.as_ptr(),
521                permutation.len(),
522                cstring_ptr(&name),
523            )
524        };
525        wrap_tensor(ptr)
526    }
527
528    #[must_use]
529    pub fn slice(
530        &self,
531        tensor: &Tensor,
532        dimension: usize,
533        start: isize,
534        length: isize,
535        name: Option<&str>,
536    ) -> Option<Tensor> {
537        let name = optional_cstring(name);
538        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
539        let ptr = unsafe {
540            ffi::mpsgraph_graph_slice(
541                self.ptr,
542                tensor.as_ptr(),
543                dimension,
544                start,
545                length,
546                cstring_ptr(&name),
547            )
548        };
549        wrap_tensor(ptr)
550    }
551
552    #[must_use]
553    pub fn broadcast(
554        &self,
555        tensor: &Tensor,
556        shape: &[usize],
557        name: Option<&str>,
558    ) -> Option<Tensor> {
559        let name = optional_cstring(name);
560        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
561        let ptr = unsafe {
562            ffi::mpsgraph_graph_broadcast(
563                self.ptr,
564                tensor.as_ptr(),
565                shape.as_ptr(),
566                shape.len(),
567                cstring_ptr(&name),
568            )
569        };
570        wrap_tensor(ptr)
571    }
572
573    #[must_use]
574    pub fn convolution2d(
575        &self,
576        source: &Tensor,
577        weights: &Tensor,
578        descriptor: &Convolution2DDescriptor,
579        name: Option<&str>,
580    ) -> Option<Tensor> {
581        let name = optional_cstring(name);
582        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
583        let ptr = unsafe {
584            ffi::mpsgraph_graph_convolution2d(
585                self.ptr,
586                source.as_ptr(),
587                weights.as_ptr(),
588                descriptor.as_ptr(),
589                cstring_ptr(&name),
590            )
591        };
592        wrap_tensor(ptr)
593    }
594
595    #[must_use]
596    pub fn max_pooling2d(
597        &self,
598        source: &Tensor,
599        descriptor: &Pooling2DDescriptor,
600        name: Option<&str>,
601    ) -> Option<Tensor> {
602        let name = optional_cstring(name);
603        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
604        let ptr = unsafe {
605            ffi::mpsgraph_graph_max_pooling2d(
606                self.ptr,
607                source.as_ptr(),
608                descriptor.as_ptr(),
609                cstring_ptr(&name),
610            )
611        };
612        wrap_tensor(ptr)
613    }
614
615    #[allow(clippy::too_many_arguments)]
616    #[must_use]
617    pub fn normalize(
618        &self,
619        tensor: &Tensor,
620        mean: &Tensor,
621        variance: &Tensor,
622        gamma: Option<&Tensor>,
623        beta: Option<&Tensor>,
624        epsilon: f32,
625        name: Option<&str>,
626    ) -> Option<Tensor> {
627        let name = optional_cstring(name);
628        let gamma_ptr = gamma.map_or(ptr::null_mut(), Tensor::as_ptr);
629        let beta_ptr = beta.map_or(ptr::null_mut(), Tensor::as_ptr);
630        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
631        let ptr = unsafe {
632            ffi::mpsgraph_graph_normalize(
633                self.ptr,
634                tensor.as_ptr(),
635                mean.as_ptr(),
636                variance.as_ptr(),
637                gamma_ptr,
638                beta_ptr,
639                epsilon,
640                cstring_ptr(&name),
641            )
642        };
643        wrap_tensor(ptr)
644    }
645
646    pub fn run(&self, feeds: &[Feed<'_>], targets: &[&Tensor]) -> Result<Vec<TensorData>> {
647        let feed_tensors = feeds
648            .iter()
649            .map(|feed| feed.tensor.as_ptr())
650            .collect::<Vec<_>>();
651        let feed_data = feeds
652            .iter()
653            .map(|feed| feed.data.as_ptr())
654            .collect::<Vec<_>>();
655        let target_tensors = targets
656            .iter()
657            .map(|tensor| tensor.as_ptr())
658            .collect::<Vec<_>>();
659        let mut results = vec![ptr::null_mut(); targets.len()];
660
661        // SAFETY: The pointer arrays are valid for the duration of the FFI call.
662        let ok = unsafe {
663            ffi::mpsgraph_graph_run(
664                self.ptr,
665                feed_tensors.as_ptr(),
666                feed_data.as_ptr(),
667                feeds.len(),
668                target_tensors.as_ptr(),
669                targets.len(),
670                results.as_mut_ptr(),
671            )
672        };
673        if ok {
674            wrap_tensor_data_results(results, "failed to run graph")
675        } else {
676            Err(Error::OperationFailed("failed to run graph"))
677        }
678    }
679
680    pub fn run_with_command_queue(
681        &self,
682        command_queue: &CommandQueue,
683        feeds: &[Feed<'_>],
684        targets: &[&Tensor],
685    ) -> Result<Vec<TensorData>> {
686        let feed_tensors = feeds
687            .iter()
688            .map(|feed| feed.tensor.as_ptr())
689            .collect::<Vec<_>>();
690        let feed_data = feeds
691            .iter()
692            .map(|feed| feed.data.as_ptr())
693            .collect::<Vec<_>>();
694        let target_tensors = targets
695            .iter()
696            .map(|tensor| tensor.as_ptr())
697            .collect::<Vec<_>>();
698        let mut results = vec![ptr::null_mut(); targets.len()];
699
700        // SAFETY: The pointer arrays are valid for the duration of the FFI call.
701        let ok = unsafe {
702            ffi::mpsgraph_graph_run_with_command_queue(
703                self.ptr,
704                command_queue.as_ptr(),
705                feed_tensors.as_ptr(),
706                feed_data.as_ptr(),
707                feeds.len(),
708                target_tensors.as_ptr(),
709                targets.len(),
710                results.as_mut_ptr(),
711            )
712        };
713        if ok {
714            wrap_tensor_data_results(results, "failed to run graph with command queue")
715        } else {
716            Err(Error::OperationFailed(
717                "failed to run graph with command queue",
718            ))
719        }
720    }
721
722    #[must_use]
723    pub fn compile(
724        &self,
725        device: &MetalDevice,
726        feeds: &[FeedDescription<'_>],
727        targets: &[&Tensor],
728    ) -> Option<Executable> {
729        let feed_tensors = feeds
730            .iter()
731            .map(|feed| feed.tensor.as_ptr())
732            .collect::<Vec<_>>();
733        let shape_lengths = feeds
734            .iter()
735            .map(|feed| feed.shape.len())
736            .collect::<Vec<_>>();
737        let data_types = feeds.iter().map(|feed| feed.data_type).collect::<Vec<_>>();
738        let flat_shapes = feeds
739            .iter()
740            .flat_map(|feed| feed.shape.iter().copied())
741            .collect::<Vec<_>>();
742        let target_tensors = targets
743            .iter()
744            .map(|tensor| tensor.as_ptr())
745            .collect::<Vec<_>>();
746
747        // SAFETY: The pointer arrays are valid for the duration of the FFI call.
748        let ptr = unsafe {
749            ffi::mpsgraph_graph_compile(
750                self.ptr,
751                device.as_ptr(),
752                feed_tensors.as_ptr(),
753                feeds.len(),
754                flat_shapes.as_ptr(),
755                shape_lengths.as_ptr(),
756                data_types.as_ptr(),
757                target_tensors.as_ptr(),
758                targets.len(),
759            )
760        };
761        if ptr.is_null() {
762            None
763        } else {
764            Some(Executable::from_raw(ptr, targets.len()))
765        }
766    }
767}
768
769/// Safe owner for a compiled `MPSGraphExecutable`.
770pub struct Executable {
771    ptr: *mut c_void,
772    output_count: usize,
773}
774
775unsafe impl Send for Executable {}
776unsafe impl Sync for Executable {}
777
778impl Drop for Executable {
779    fn drop(&mut self) {
780        if !self.ptr.is_null() {
781            // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
782            unsafe { ffi::mpsgraph_object_release(self.ptr) };
783            self.ptr = ptr::null_mut();
784        }
785    }
786}
787
788impl Executable {
789    pub(crate) const fn from_raw(ptr: *mut c_void, output_count: usize) -> Self {
790        Self { ptr, output_count }
791    }
792
793    #[must_use]
794    pub const fn as_ptr(&self) -> *mut c_void {
795        self.ptr
796    }
797
798    #[must_use]
799    pub const fn output_count(&self) -> usize {
800        self.output_count
801    }
802
803    pub fn run(
804        &self,
805        command_queue: &CommandQueue,
806        inputs: &[&TensorData],
807    ) -> Result<Vec<TensorData>> {
808        let input_data = inputs
809            .iter()
810            .map(|tensor_data| tensor_data.as_ptr())
811            .collect::<Vec<_>>();
812        let mut results = vec![ptr::null_mut(); self.output_count];
813
814        // SAFETY: The pointer arrays are valid for the duration of the FFI call.
815        let ok = unsafe {
816            ffi::mpsgraph_executable_run(
817                self.ptr,
818                command_queue.as_ptr(),
819                input_data.as_ptr(),
820                inputs.len(),
821                self.output_count,
822                results.as_mut_ptr(),
823            )
824        };
825        if ok {
826            wrap_tensor_data_results(results, "failed to run executable")
827        } else {
828            Err(Error::OperationFailed("failed to run executable"))
829        }
830    }
831}