//! `apple_mpsgraph` graph module (`graph.rs`): Rust wrappers for building, running, and
//! compiling `MPSGraph` graphs.

1use crate::data::TensorData;
2use crate::error::{Error, Result};
3use crate::ffi;
4use apple_metal::{CommandQueue, MetalDevice};
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Selected `MPSDataType` values accepted by graph inputs, constants, and outputs.
///
/// Each value is a type-class flag combined with the element width in bits, mirroring how
/// the framework's `MPSDataType` enumeration is composed.
pub mod data_type {
    /// Flag marking floating-point element types (`MPSDataTypeFloatBit`).
    const FLOAT_BIT: u32 = 0x1000_0000;
    /// Flag marking signed integer element types (`MPSDataTypeSignedBit`).
    const SIGNED_BIT: u32 = 0x2000_0000;
    /// Flag marking normalized element types (`MPSDataTypeNormalizedBit`).
    const NORMALIZED_BIT: u32 = 0x4000_0000;
    /// Flag marking alternate encodings such as boolean (`MPSDataTypeAlternateEncodingBit`).
    const ALTERNATE_ENCODING_BIT: u32 = 0x8000_0000;

    /// `MPSDataTypeInvalid`.
    pub const INVALID: u32 = 0;

    /// `MPSDataTypeFloat32` — 32-bit IEEE float.
    pub const FLOAT32: u32 = FLOAT_BIT | 32;
    /// `MPSDataTypeFloat16` — 16-bit IEEE half float.
    pub const FLOAT16: u32 = FLOAT_BIT | 16;

    /// `MPSDataTypeInt8`.
    pub const INT8: u32 = SIGNED_BIT | 8;
    /// `MPSDataTypeInt16`.
    pub const INT16: u32 = SIGNED_BIT | 16;
    /// `MPSDataTypeInt32`.
    pub const INT32: u32 = SIGNED_BIT | 32;
    /// `MPSDataTypeInt64`.
    pub const INT64: u32 = SIGNED_BIT | 64;

    /// `MPSDataTypeUInt8` (unsigned types carry no type-class flag).
    pub const UINT8: u32 = 8;
    /// `MPSDataTypeUInt16`.
    pub const UINT16: u32 = 16;
    /// `MPSDataTypeUInt32`.
    pub const UINT32: u32 = 32;
    /// `MPSDataTypeUInt64`.
    pub const UINT64: u32 = 64;

    /// `MPSDataTypeBool` — one byte per element.
    pub const BOOL: u32 = ALTERNATE_ENCODING_BIT | 8;
    /// `MPSDataTypeUnorm8` — 8-bit unsigned normalized.
    pub const UNORM8: u32 = NORMALIZED_BIT | 8;
}
38
39/// Return the byte width of a supported `MPSDataType`.
40#[must_use]
41pub const fn data_type_size(data_type: u32) -> Option<usize> {
42    match data_type {
43        data_type::FLOAT16 | data_type::INT16 | data_type::UINT16 => Some(2),
44        data_type::FLOAT32 | data_type::INT32 | data_type::UINT32 => Some(4),
45        data_type::INT64 | data_type::UINT64 => Some(8),
46        data_type::INT8 | data_type::UINT8 | data_type::BOOL | data_type::UNORM8 => Some(1),
47        _ => None,
48    }
49}
50
/// `MPSGraphTensorNamedDataLayout` values describing axis order.
///
/// Letters follow the framework's naming: N = batch, C = channels, D = depth, H = height,
/// W = width, O = output channels, I = input channels.
pub mod tensor_named_data_layout {
    /// `MPSGraphTensorNamedDataLayoutNCHW`.
    pub const NCHW: usize = 0;
    /// `MPSGraphTensorNamedDataLayoutNHWC`.
    pub const NHWC: usize = 1;
    /// `MPSGraphTensorNamedDataLayoutOIHW` (weights layout).
    pub const OIHW: usize = 2;
    /// `MPSGraphTensorNamedDataLayoutHWIO` (weights layout).
    pub const HWIO: usize = 3;
    /// `MPSGraphTensorNamedDataLayoutCHW`.
    pub const CHW: usize = 4;
    /// `MPSGraphTensorNamedDataLayoutHWC`.
    pub const HWC: usize = 5;
    /// `MPSGraphTensorNamedDataLayoutHW`.
    pub const HW: usize = 6;
    /// `MPSGraphTensorNamedDataLayoutNCDHW`.
    pub const NCDHW: usize = 7;
    /// `MPSGraphTensorNamedDataLayoutNDHWC`.
    pub const NDHWC: usize = 8;
    /// `MPSGraphTensorNamedDataLayoutOIDHW` (weights layout).
    pub const OIDHW: usize = 9;
    /// `MPSGraphTensorNamedDataLayoutDHWIO` (weights layout).
    pub const DHWIO: usize = 10;
}
76
/// `MPSGraphPaddingStyle` values selecting how convolution/pooling padding is derived.
pub mod padding_style {
    /// `MPSGraphPaddingStyleExplicit` — use the explicit padding fields as given.
    pub const EXPLICIT: usize = 0;
    /// `MPSGraphPaddingStyleTF_VALID`.
    pub const TF_VALID: usize = 1;
    /// `MPSGraphPaddingStyleTF_SAME`.
    pub const TF_SAME: usize = 2;
    /// `MPSGraphPaddingStyleExplicitOffset`.
    pub const EXPLICIT_OFFSET: usize = 3;
    /// `MPSGraphPaddingStyleONNX_SAME_LOWER`.
    pub const ONNX_SAME_LOWER: usize = 4;
}
90
/// `MPSGraphPaddingMode` values (signed, matching the framework's `NSInteger` enum).
pub mod padding_mode {
    /// `MPSGraphPaddingModeConstant`.
    pub const CONSTANT: isize = 0;
    /// `MPSGraphPaddingModeReflect`.
    pub const REFLECT: isize = 1;
    /// `MPSGraphPaddingModeSymmetric`.
    pub const SYMMETRIC: isize = 2;
    /// `MPSGraphPaddingModeClampToEdge`.
    pub const CLAMP_TO_EDGE: isize = 3;
    /// `MPSGraphPaddingModeZero`.
    pub const ZERO: isize = 4;
    /// `MPSGraphPaddingModePeriodic`.
    pub const PERIODIC: isize = 5;
    /// `MPSGraphPaddingModeAntiPeriodic`.
    pub const ANTI_PERIODIC: isize = 6;
}
108
/// Declares an owning wrapper type around a retained framework object pointer.
///
/// The generated type releases its pointer through `ffi::mpsgraph_object_release` on drop
/// and exposes the raw pointer via `as_ptr` without transferring ownership.
macro_rules! opaque_handle {
    ($name:ident) => {
        #[doc = concat!(
            "Owning wrapper around a retained `MPSGraph` framework object (`",
            stringify!($name),
            "`).\n\nThe object is released when the wrapper is dropped."
        )]
        pub struct $name {
            ptr: *mut c_void,
        }

        // SAFETY: NOTE(review): these impls assert that the wrapped framework object may be
        // moved and shared across threads; confirm against Apple's MPSGraph thread-safety
        // guidance.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                // Clear the field first so the object can never be released twice.
                let raw = core::mem::replace(&mut self.ptr, ptr::null_mut());
                if raw.is_null() {
                    return;
                }
                // SAFETY: `raw` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
                unsafe { ffi::mpsgraph_object_release(raw) };
            }
        }

        impl $name {
            /// Returns the raw object pointer without transferring ownership.
            ///
            /// The wrapper releases the object when dropped, so the pointer must not be
            /// used after that.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
138
139fn checked_byte_len(shape: &[usize], data_type: u32) -> Option<usize> {
140    let element_size = data_type_size(data_type)?;
141    shape
142        .iter()
143        .try_fold(element_size, |acc, dimension| acc.checked_mul(*dimension))
144}
145
/// Converts an optional node name into an owned C string for the FFI layer.
///
/// Returns `None` when no name is given. A name containing an interior NUL byte cannot be
/// represented as a C string and is silently discarded (also `None`), so the node is then
/// created unnamed.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
149
/// Borrows the C string pointer from an optional name, or null when there is none.
///
/// The returned pointer is only valid while `value` is alive and unmodified.
// Takes `&Option<CString>` (not `Option<&CString>`) so call sites can simply pass `&name`.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(text) => text.as_ptr(),
        None => ptr::null(),
    }
}
154
155fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
156    if ptr.is_null() {
157        None
158    } else {
159        Some(Tensor { ptr })
160    }
161}
162
163fn wrap_tensor_data_results(
164    handles: Vec<*mut c_void>,
165    message: &'static str,
166) -> Result<Vec<TensorData>> {
167    let mut results = Vec::with_capacity(handles.len());
168    for handle in handles {
169        if handle.is_null() {
170            return Err(Error::OperationFailed(message));
171        }
172        results.push(TensorData::from_raw(handle));
173    }
174    Ok(results)
175}
176
/// Generates a `Graph` method that adds a two-input node via the named FFI function.
macro_rules! impl_binary_tensor_op {
    ($fn_name:ident, $ffi_name:ident) => {
        #[doc = concat!(
            "Adds a `",
            stringify!($fn_name),
            "` node combining `primary` and `secondary`.\n\n",
            "`name` optionally labels the node; a name containing an interior NUL byte is ",
            "ignored. Returns `None` if the FFI call yields a null tensor."
        )]
        #[must_use]
        pub fn $fn_name(
            &self,
            primary: &Tensor,
            secondary: &Tensor,
            name: Option<&str>,
        ) -> Option<Tensor> {
            let label = optional_cstring(name);
            let label_ptr = cstring_ptr(&label);
            // SAFETY: `self`, `primary`, and `secondary` are live wrappers, and `label`, which
            // owns the string behind `label_ptr`, outlives the call.
            let raw = unsafe {
                ffi::$ffi_name(self.ptr, primary.as_ptr(), secondary.as_ptr(), label_ptr)
            };
            wrap_tensor(raw)
        }
    };
}
201
/// Generates a `Graph` method that adds a single-input node via the named FFI function.
macro_rules! impl_unary_tensor_op {
    ($fn_name:ident, $ffi_name:ident) => {
        #[doc = concat!(
            "Adds a `",
            stringify!($fn_name),
            "` node applied to `tensor`.\n\n",
            "`name` optionally labels the node; a name containing an interior NUL byte is ",
            "ignored. Returns `None` if the FFI call yields a null tensor."
        )]
        #[must_use]
        pub fn $fn_name(&self, tensor: &Tensor, name: Option<&str>) -> Option<Tensor> {
            let label = optional_cstring(name);
            let label_ptr = cstring_ptr(&label);
            // SAFETY: `self` and `tensor` are live wrappers, and `label`, which owns the
            // string behind `label_ptr`, outlives the call.
            let raw = unsafe { ffi::$ffi_name(self.ptr, tensor.as_ptr(), label_ptr) };
            wrap_tensor(raw)
        }
    };
}
214
/// Generates a `Graph` method that adds a node parameterised by a list of axes.
macro_rules! impl_axes_tensor_op {
    ($fn_name:ident, $ffi_name:ident) => {
        #[doc = concat!(
            "Adds a `",
            stringify!($fn_name),
            "` node applied to `tensor` over the listed `axes`.\n\n",
            "`name` optionally labels the node; a name containing an interior NUL byte is ",
            "ignored. Returns `None` if the FFI call yields a null tensor."
        )]
        #[must_use]
        pub fn $fn_name(
            &self,
            tensor: &Tensor,
            axes: &[usize],
            name: Option<&str>,
        ) -> Option<Tensor> {
            let label = optional_cstring(name);
            let label_ptr = cstring_ptr(&label);
            let (axes_ptr, axes_len) = (axes.as_ptr(), axes.len());
            // SAFETY: `self` and `tensor` are live wrappers, `axes` is borrowed for the whole
            // call, and `label`, which owns the string behind `label_ptr`, outlives the call.
            let raw = unsafe { ffi::$ffi_name(self.ptr, tensor.as_ptr(), axes_ptr, axes_len, label_ptr) };
            wrap_tensor(raw)
        }
    };
}
240
241/// Ordered placeholder feed pairing used for graph execution.
242#[derive(Clone, Copy)]
243pub struct Feed<'a> {
244/// Mirrors the `MPSGraph` framework property for `tensor`.
245    pub tensor: &'a Tensor,
246/// Mirrors the `MPSGraph` framework property for `data`.
247    pub data: &'a TensorData,
248}
249
250impl<'a> Feed<'a> {
251/// Mirrors the `MPSGraph` framework constant `fn`.
252    #[must_use]
253    pub const fn new(tensor: &'a Tensor, data: &'a TensorData) -> Self {
254        Self { tensor, data }
255    }
256}
257
258/// Feed metadata used to compile a graph into an executable.
259#[derive(Clone, Copy)]
260pub struct FeedDescription<'a> {
261/// Mirrors the `MPSGraph` framework property for `tensor`.
262    pub tensor: &'a Tensor,
263/// Mirrors the `MPSGraph` framework property for `shape`.
264    pub shape: &'a [usize],
265/// Mirrors the `MPSGraph` framework property for `data_type`.
266    pub data_type: u32,
267}
268
269impl<'a> FeedDescription<'a> {
270/// Mirrors the `MPSGraph` framework constant `fn`.
271    #[must_use]
272    pub const fn new(tensor: &'a Tensor, shape: &'a [usize], data_type: u32) -> Self {
273        Self {
274            tensor,
275            shape,
276            data_type,
277        }
278    }
279}
280
281/// Plain-Rust configuration for `MPSGraphConvolution2DOpDescriptor`.
282#[derive(Debug, Clone, Copy)]
283pub struct Convolution2DDescriptorInfo {
284/// Mirrors the `MPSGraph` framework property for `stride_in_x`.
285    pub stride_in_x: usize,
286/// Mirrors the `MPSGraph` framework property for `stride_in_y`.
287    pub stride_in_y: usize,
288/// Mirrors the `MPSGraph` framework property for `dilation_rate_in_x`.
289    pub dilation_rate_in_x: usize,
290/// Mirrors the `MPSGraph` framework property for `dilation_rate_in_y`.
291    pub dilation_rate_in_y: usize,
292/// Mirrors the `MPSGraph` framework property for `groups`.
293    pub groups: usize,
294/// Mirrors the `MPSGraph` framework property for `padding_left`.
295    pub padding_left: usize,
296/// Mirrors the `MPSGraph` framework property for `padding_right`.
297    pub padding_right: usize,
298/// Mirrors the `MPSGraph` framework property for `padding_top`.
299    pub padding_top: usize,
300/// Mirrors the `MPSGraph` framework property for `padding_bottom`.
301    pub padding_bottom: usize,
302/// Mirrors the `MPSGraph` framework property for `padding_style`.
303    pub padding_style: usize,
304/// Mirrors the `MPSGraph` framework property for `data_layout`.
305    pub data_layout: usize,
306/// Mirrors the `MPSGraph` framework property for `weights_layout`.
307    pub weights_layout: usize,
308}
309
310impl Default for Convolution2DDescriptorInfo {
311    fn default() -> Self {
312        Self {
313            stride_in_x: 1,
314            stride_in_y: 1,
315            dilation_rate_in_x: 1,
316            dilation_rate_in_y: 1,
317            groups: 1,
318            padding_left: 0,
319            padding_right: 0,
320            padding_top: 0,
321            padding_bottom: 0,
322            padding_style: padding_style::EXPLICIT,
323            data_layout: tensor_named_data_layout::NHWC,
324            weights_layout: tensor_named_data_layout::HWIO,
325        }
326    }
327}
328
329opaque_handle!(Convolution2DDescriptor);
330impl Convolution2DDescriptor {
331/// Calls the `MPSGraph` framework counterpart for `new`.
332    #[must_use]
333    pub fn new(info: Convolution2DDescriptorInfo) -> Option<Self> {
334        // SAFETY: All scalar configuration values are POD.
335        let ptr = unsafe {
336            ffi::mpsgraph_convolution2d_descriptor_new(
337                info.stride_in_x,
338                info.stride_in_y,
339                info.dilation_rate_in_x,
340                info.dilation_rate_in_y,
341                info.groups,
342                info.padding_left,
343                info.padding_right,
344                info.padding_top,
345                info.padding_bottom,
346                info.padding_style,
347                info.data_layout,
348                info.weights_layout,
349            )
350        };
351        if ptr.is_null() {
352            None
353        } else {
354            Some(Self { ptr })
355        }
356    }
357}
358
359/// Plain-Rust configuration for `MPSGraphPooling2DOpDescriptor`.
360#[derive(Debug, Clone, Copy)]
361pub struct Pooling2DDescriptorInfo {
362/// Mirrors the `MPSGraph` framework property for `kernel_width`.
363    pub kernel_width: usize,
364/// Mirrors the `MPSGraph` framework property for `kernel_height`.
365    pub kernel_height: usize,
366/// Mirrors the `MPSGraph` framework property for `stride_in_x`.
367    pub stride_in_x: usize,
368/// Mirrors the `MPSGraph` framework property for `stride_in_y`.
369    pub stride_in_y: usize,
370/// Mirrors the `MPSGraph` framework property for `dilation_rate_in_x`.
371    pub dilation_rate_in_x: usize,
372/// Mirrors the `MPSGraph` framework property for `dilation_rate_in_y`.
373    pub dilation_rate_in_y: usize,
374/// Mirrors the `MPSGraph` framework property for `padding_left`.
375    pub padding_left: usize,
376/// Mirrors the `MPSGraph` framework property for `padding_right`.
377    pub padding_right: usize,
378/// Mirrors the `MPSGraph` framework property for `padding_top`.
379    pub padding_top: usize,
380/// Mirrors the `MPSGraph` framework property for `padding_bottom`.
381    pub padding_bottom: usize,
382/// Mirrors the `MPSGraph` framework property for `padding_style`.
383    pub padding_style: usize,
384/// Mirrors the `MPSGraph` framework property for `data_layout`.
385    pub data_layout: usize,
386}
387
388impl Pooling2DDescriptorInfo {
389/// Mirrors the `MPSGraph` framework constant `fn`.
390    #[must_use]
391    pub const fn new(kernel_width: usize, kernel_height: usize) -> Self {
392        Self {
393            kernel_width,
394            kernel_height,
395            stride_in_x: 1,
396            stride_in_y: 1,
397            dilation_rate_in_x: 1,
398            dilation_rate_in_y: 1,
399            padding_left: 0,
400            padding_right: 0,
401            padding_top: 0,
402            padding_bottom: 0,
403            padding_style: padding_style::EXPLICIT,
404            data_layout: tensor_named_data_layout::NHWC,
405        }
406    }
407}
408
409opaque_handle!(Pooling2DDescriptor);
410impl Pooling2DDescriptor {
411/// Calls the `MPSGraph` framework counterpart for `new`.
412    #[must_use]
413    pub fn new(info: Pooling2DDescriptorInfo) -> Option<Self> {
414        // SAFETY: All scalar configuration values are POD.
415        let ptr = unsafe {
416            ffi::mpsgraph_pooling2d_descriptor_new(
417                info.kernel_width,
418                info.kernel_height,
419                info.stride_in_x,
420                info.stride_in_y,
421                info.dilation_rate_in_x,
422                info.dilation_rate_in_y,
423                info.padding_left,
424                info.padding_right,
425                info.padding_top,
426                info.padding_bottom,
427                info.padding_style,
428                info.data_layout,
429            )
430        };
431        if ptr.is_null() {
432            None
433        } else {
434            Some(Self { ptr })
435        }
436    }
437}
438
439opaque_handle!(Graph);
440opaque_handle!(Tensor);
441
442impl Tensor {
443    pub(crate) const fn from_raw(ptr: *mut c_void) -> Self {
444        Self { ptr }
445    }
446}
447
448impl Graph {
449/// Calls the `MPSGraph` framework counterpart for `new`.
450    #[must_use]
451    pub fn new() -> Option<Self> {
452        // SAFETY: Pure constructor with no inputs.
453        let ptr = unsafe { ffi::mpsgraph_graph_new() };
454        if ptr.is_null() {
455            None
456        } else {
457            Some(Self { ptr })
458        }
459    }
460
461/// Calls the `MPSGraph` framework counterpart for `placeholder`.
462    #[must_use]
463    pub fn placeholder(
464        &self,
465        shape: Option<&[usize]>,
466        data_type: u32,
467        name: Option<&str>,
468    ) -> Option<Tensor> {
469        let name = optional_cstring(name);
470        let (shape_ptr, shape_len) =
471            shape.map_or((ptr::null(), 0), |shape| (shape.as_ptr(), shape.len()));
472
473        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
474        let ptr = unsafe {
475            ffi::mpsgraph_graph_placeholder(
476                self.ptr,
477                shape_ptr,
478                shape_len,
479                data_type,
480                cstring_ptr(&name),
481            )
482        };
483        wrap_tensor(ptr)
484    }
485
486/// Calls the `MPSGraph` framework counterpart for `constant_bytes`.
487    #[must_use]
488    pub fn constant_bytes(&self, data: &[u8], shape: &[usize], data_type: u32) -> Option<Tensor> {
489        let expected = checked_byte_len(shape, data_type)?;
490        if data.len() != expected {
491            return None;
492        }
493
494        // SAFETY: The byte slice remains valid for the duration of the FFI call.
495        let ptr = unsafe {
496            ffi::mpsgraph_graph_constant_data(
497                self.ptr,
498                data.as_ptr().cast(),
499                data.len(),
500                shape.as_ptr(),
501                shape.len(),
502                data_type,
503            )
504        };
505        wrap_tensor(ptr)
506    }
507
508/// Calls the `MPSGraph` framework counterpart for `constant_f32_slice`.
509    #[must_use]
510    pub fn constant_f32_slice(&self, values: &[f32], shape: &[usize]) -> Option<Tensor> {
511        // SAFETY: `values` is a contiguous slice of `f32` that may be viewed as bytes.
512        let bytes = unsafe {
513            core::slice::from_raw_parts(
514                values.as_ptr().cast::<u8>(),
515                core::mem::size_of_val(values),
516            )
517        };
518        self.constant_bytes(bytes, shape, data_type::FLOAT32)
519    }
520
521/// Calls the `MPSGraph` framework counterpart for `constant_scalar`.
522    #[must_use]
523    pub fn constant_scalar(&self, scalar: f64, data_type: u32) -> Option<Tensor> {
524        // SAFETY: Pure constructor over scalar inputs.
525        let ptr = unsafe { ffi::mpsgraph_graph_constant_scalar(self.ptr, scalar, data_type) };
526        wrap_tensor(ptr)
527    }
528
529/// Calls the `MPSGraph` framework counterpart for `constant_scalar_shaped`.
530    #[must_use]
531    pub fn constant_scalar_shaped(
532        &self,
533        scalar: f64,
534        shape: &[usize],
535        data_type: u32,
536    ) -> Option<Tensor> {
537        // SAFETY: Shape slice stays valid for the duration of the FFI call.
538        let ptr = unsafe {
539            ffi::mpsgraph_graph_constant_scalar_shaped(
540                self.ptr,
541                scalar,
542                shape.as_ptr(),
543                shape.len(),
544                data_type,
545            )
546        };
547        wrap_tensor(ptr)
548    }
549
550    impl_binary_tensor_op!(addition, mpsgraph_graph_addition);
551    impl_binary_tensor_op!(subtraction, mpsgraph_graph_subtraction);
552    impl_binary_tensor_op!(multiplication, mpsgraph_graph_multiplication);
553    impl_binary_tensor_op!(division, mpsgraph_graph_division);
554    impl_binary_tensor_op!(matrix_multiplication, mpsgraph_graph_matrix_multiplication);
555    impl_unary_tensor_op!(relu, mpsgraph_graph_relu);
556    impl_unary_tensor_op!(sigmoid, mpsgraph_graph_sigmoid);
557    impl_axes_tensor_op!(reduction_sum, mpsgraph_graph_reduction_sum);
558    impl_axes_tensor_op!(reduction_maximum, mpsgraph_graph_reduction_maximum);
559    impl_axes_tensor_op!(reduction_minimum, mpsgraph_graph_reduction_minimum);
560    impl_axes_tensor_op!(mean, mpsgraph_graph_mean);
561
562/// Calls the `MPSGraph` framework counterpart for `softmax`.
563    #[must_use]
564    pub fn softmax(&self, tensor: &Tensor, axis: isize, name: Option<&str>) -> Option<Tensor> {
565        let name = optional_cstring(name);
566        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
567        let ptr = unsafe {
568            ffi::mpsgraph_graph_softmax(self.ptr, tensor.as_ptr(), axis, cstring_ptr(&name))
569        };
570        wrap_tensor(ptr)
571    }
572
573/// Calls the `MPSGraph` framework counterpart for `reshape`.
574    #[must_use]
575    pub fn reshape(&self, tensor: &Tensor, shape: &[usize], name: Option<&str>) -> Option<Tensor> {
576        let name = optional_cstring(name);
577        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
578        let ptr = unsafe {
579            ffi::mpsgraph_graph_reshape(
580                self.ptr,
581                tensor.as_ptr(),
582                shape.as_ptr(),
583                shape.len(),
584                cstring_ptr(&name),
585            )
586        };
587        wrap_tensor(ptr)
588    }
589
590/// Calls the `MPSGraph` framework counterpart for `transpose`.
591    #[must_use]
592    pub fn transpose(
593        &self,
594        tensor: &Tensor,
595        permutation: &[usize],
596        name: Option<&str>,
597    ) -> Option<Tensor> {
598        let name = optional_cstring(name);
599        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
600        let ptr = unsafe {
601            ffi::mpsgraph_graph_transpose(
602                self.ptr,
603                tensor.as_ptr(),
604                permutation.as_ptr(),
605                permutation.len(),
606                cstring_ptr(&name),
607            )
608        };
609        wrap_tensor(ptr)
610    }
611
612/// Calls the `MPSGraph` framework counterpart for `slice`.
613    #[must_use]
614    pub fn slice(
615        &self,
616        tensor: &Tensor,
617        dimension: usize,
618        start: isize,
619        length: isize,
620        name: Option<&str>,
621    ) -> Option<Tensor> {
622        let name = optional_cstring(name);
623        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
624        let ptr = unsafe {
625            ffi::mpsgraph_graph_slice(
626                self.ptr,
627                tensor.as_ptr(),
628                dimension,
629                start,
630                length,
631                cstring_ptr(&name),
632            )
633        };
634        wrap_tensor(ptr)
635    }
636
637/// Calls the `MPSGraph` framework counterpart for `broadcast`.
638    #[must_use]
639    pub fn broadcast(
640        &self,
641        tensor: &Tensor,
642        shape: &[usize],
643        name: Option<&str>,
644    ) -> Option<Tensor> {
645        let name = optional_cstring(name);
646        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
647        let ptr = unsafe {
648            ffi::mpsgraph_graph_broadcast(
649                self.ptr,
650                tensor.as_ptr(),
651                shape.as_ptr(),
652                shape.len(),
653                cstring_ptr(&name),
654            )
655        };
656        wrap_tensor(ptr)
657    }
658
659/// Calls the `MPSGraph` framework counterpart for `convolution2d`.
660    #[must_use]
661    pub fn convolution2d(
662        &self,
663        source: &Tensor,
664        weights: &Tensor,
665        descriptor: &Convolution2DDescriptor,
666        name: Option<&str>,
667    ) -> Option<Tensor> {
668        let name = optional_cstring(name);
669        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
670        let ptr = unsafe {
671            ffi::mpsgraph_graph_convolution2d(
672                self.ptr,
673                source.as_ptr(),
674                weights.as_ptr(),
675                descriptor.as_ptr(),
676                cstring_ptr(&name),
677            )
678        };
679        wrap_tensor(ptr)
680    }
681
682/// Calls the `MPSGraph` framework counterpart for `max_pooling2d`.
683    #[must_use]
684    pub fn max_pooling2d(
685        &self,
686        source: &Tensor,
687        descriptor: &Pooling2DDescriptor,
688        name: Option<&str>,
689    ) -> Option<Tensor> {
690        let name = optional_cstring(name);
691        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
692        let ptr = unsafe {
693            ffi::mpsgraph_graph_max_pooling2d(
694                self.ptr,
695                source.as_ptr(),
696                descriptor.as_ptr(),
697                cstring_ptr(&name),
698            )
699        };
700        wrap_tensor(ptr)
701    }
702
703/// Calls the `MPSGraph` framework counterpart for `normalize`.
704    #[allow(clippy::too_many_arguments)]
705    #[must_use]
706    pub fn normalize(
707        &self,
708        tensor: &Tensor,
709        mean: &Tensor,
710        variance: &Tensor,
711        gamma: Option<&Tensor>,
712        beta: Option<&Tensor>,
713        epsilon: f32,
714        name: Option<&str>,
715    ) -> Option<Tensor> {
716        let name = optional_cstring(name);
717        let gamma_ptr = gamma.map_or(ptr::null_mut(), Tensor::as_ptr);
718        let beta_ptr = beta.map_or(ptr::null_mut(), Tensor::as_ptr);
719        // SAFETY: All pointers originate from safe wrappers and remain alive for the duration of the call.
720        let ptr = unsafe {
721            ffi::mpsgraph_graph_normalize(
722                self.ptr,
723                tensor.as_ptr(),
724                mean.as_ptr(),
725                variance.as_ptr(),
726                gamma_ptr,
727                beta_ptr,
728                epsilon,
729                cstring_ptr(&name),
730            )
731        };
732        wrap_tensor(ptr)
733    }
734
735/// Calls the `MPSGraph` framework counterpart for `run`.
736    pub fn run(&self, feeds: &[Feed<'_>], targets: &[&Tensor]) -> Result<Vec<TensorData>> {
737        let feed_tensors = feeds
738            .iter()
739            .map(|feed| feed.tensor.as_ptr())
740            .collect::<Vec<_>>();
741        let feed_data = feeds
742            .iter()
743            .map(|feed| feed.data.as_ptr())
744            .collect::<Vec<_>>();
745        let target_tensors = targets
746            .iter()
747            .map(|tensor| tensor.as_ptr())
748            .collect::<Vec<_>>();
749        let mut results = vec![ptr::null_mut(); targets.len()];
750
751        // SAFETY: The pointer arrays are valid for the duration of the FFI call.
752        let ok = unsafe {
753            ffi::mpsgraph_graph_run(
754                self.ptr,
755                feed_tensors.as_ptr(),
756                feed_data.as_ptr(),
757                feeds.len(),
758                target_tensors.as_ptr(),
759                targets.len(),
760                results.as_mut_ptr(),
761            )
762        };
763        if ok {
764            wrap_tensor_data_results(results, "failed to run graph")
765        } else {
766            Err(Error::OperationFailed("failed to run graph"))
767        }
768    }
769
770/// Calls the `MPSGraph` framework counterpart for `run_with_command_queue`.
771    pub fn run_with_command_queue(
772        &self,
773        command_queue: &CommandQueue,
774        feeds: &[Feed<'_>],
775        targets: &[&Tensor],
776    ) -> Result<Vec<TensorData>> {
777        let feed_tensors = feeds
778            .iter()
779            .map(|feed| feed.tensor.as_ptr())
780            .collect::<Vec<_>>();
781        let feed_data = feeds
782            .iter()
783            .map(|feed| feed.data.as_ptr())
784            .collect::<Vec<_>>();
785        let target_tensors = targets
786            .iter()
787            .map(|tensor| tensor.as_ptr())
788            .collect::<Vec<_>>();
789        let mut results = vec![ptr::null_mut(); targets.len()];
790
791        // SAFETY: The pointer arrays are valid for the duration of the FFI call.
792        let ok = unsafe {
793            ffi::mpsgraph_graph_run_with_command_queue(
794                self.ptr,
795                command_queue.as_ptr(),
796                feed_tensors.as_ptr(),
797                feed_data.as_ptr(),
798                feeds.len(),
799                target_tensors.as_ptr(),
800                targets.len(),
801                results.as_mut_ptr(),
802            )
803        };
804        if ok {
805            wrap_tensor_data_results(results, "failed to run graph with command queue")
806        } else {
807            Err(Error::OperationFailed(
808                "failed to run graph with command queue",
809            ))
810        }
811    }
812
813/// Calls the `MPSGraph` framework counterpart for `compile`.
814    #[must_use]
815    pub fn compile(
816        &self,
817        device: &MetalDevice,
818        feeds: &[FeedDescription<'_>],
819        targets: &[&Tensor],
820    ) -> Option<Executable> {
821        let feed_tensors = feeds
822            .iter()
823            .map(|feed| feed.tensor.as_ptr())
824            .collect::<Vec<_>>();
825        let shape_lengths = feeds
826            .iter()
827            .map(|feed| feed.shape.len())
828            .collect::<Vec<_>>();
829        let data_types = feeds.iter().map(|feed| feed.data_type).collect::<Vec<_>>();
830        let flat_shapes = feeds
831            .iter()
832            .flat_map(|feed| feed.shape.iter().copied())
833            .collect::<Vec<_>>();
834        let target_tensors = targets
835            .iter()
836            .map(|tensor| tensor.as_ptr())
837            .collect::<Vec<_>>();
838
839        // SAFETY: The pointer arrays are valid for the duration of the FFI call.
840        let ptr = unsafe {
841            ffi::mpsgraph_graph_compile(
842                self.ptr,
843                device.as_ptr(),
844                feed_tensors.as_ptr(),
845                feeds.len(),
846                flat_shapes.as_ptr(),
847                shape_lengths.as_ptr(),
848                data_types.as_ptr(),
849                target_tensors.as_ptr(),
850                targets.len(),
851            )
852        };
853        if ptr.is_null() {
854            None
855        } else {
856            Some(Executable::from_raw(ptr, targets.len()))
857        }
858    }
859}
860
861/// Safe owner for a compiled `MPSGraphExecutable`.
862pub struct Executable {
863    ptr: *mut c_void,
864    output_count: usize,
865}
866
867unsafe impl Send for Executable {}
868unsafe impl Sync for Executable {}
869
870impl Drop for Executable {
871    fn drop(&mut self) {
872        if !self.ptr.is_null() {
873            // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
874            unsafe { ffi::mpsgraph_object_release(self.ptr) };
875            self.ptr = ptr::null_mut();
876        }
877    }
878}
879
880impl Executable {
881    pub(crate) const fn from_raw(ptr: *mut c_void, output_count: usize) -> Self {
882        Self { ptr, output_count }
883    }
884
885/// Mirrors the `MPSGraph` framework constant `fn`.
886    #[must_use]
887    pub const fn as_ptr(&self) -> *mut c_void {
888        self.ptr
889    }
890
891/// Mirrors the `MPSGraph` framework constant `fn`.
892    #[must_use]
893    pub const fn output_count(&self) -> usize {
894        self.output_count
895    }
896
897/// Calls the `MPSGraph` framework counterpart for `run`.
898    pub fn run(
899        &self,
900        command_queue: &CommandQueue,
901        inputs: &[&TensorData],
902    ) -> Result<Vec<TensorData>> {
903        let input_data = inputs
904            .iter()
905            .map(|tensor_data| tensor_data.as_ptr())
906            .collect::<Vec<_>>();
907        let mut results = vec![ptr::null_mut(); self.output_count];
908
909        // SAFETY: The pointer arrays are valid for the duration of the FFI call.
910        let ok = unsafe {
911            ffi::mpsgraph_executable_run(
912                self.ptr,
913                command_queue.as_ptr(),
914                input_data.as_ptr(),
915                inputs.len(),
916                self.output_count,
917                results.as_mut_ptr(),
918            )
919        };
920        if ok {
921            wrap_tensor_data_results(results, "failed to run executable")
922        } else {
923            Err(Error::OperationFailed("failed to run executable"))
924        }
925    }
926}