Skip to main content

apple_mpsgraph/
types.rs

1use crate::data::TensorData;
2use crate::error::{Error, Result};
3use crate::ffi;
4use crate::graph::Tensor;
5use apple_metal::MetalDevice;
6use core::ffi::c_void;
7use core::ptr;
8
9fn release_handle(ptr: &mut *mut c_void) {
10    if !ptr.is_null() {
11        // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
12        unsafe { ffi::mpsgraph_object_release(*ptr) };
13        *ptr = ptr::null_mut();
14    }
15}
16
/// Read an optional signed shape out of a bridged object through three Swift shims.
///
/// * `has_shape` decides whether a shape exists. If it reports `false`, `None` is returned
///   and neither of the other shims is called.
/// * `shape_len` reports the number of dimensions, which sizes the output buffer.
/// * `copy_shape` fills that buffer. It is skipped entirely when the length is zero.
fn copy_optional_signed_shape(
    handle: *mut c_void,
    has_shape: unsafe extern "C" fn(*mut c_void) -> bool,
    shape_len: unsafe extern "C" fn(*mut c_void) -> usize,
    copy_shape: unsafe extern "C" fn(*mut c_void, *mut isize),
) -> Option<Vec<isize>> {
    // SAFETY: the function pointers belong to Swift shims that treat `handle` as immutable
    // for the duration of the call.
    let ranked = unsafe { has_shape(handle) };
    ranked.then(|| {
        // SAFETY: as above.
        let rank = unsafe { shape_len(handle) };
        let mut dims = vec![0_isize; rank];
        if rank != 0 {
            // SAFETY: `dims` owns exactly `rank` writable elements, matching the length the
            // shim just reported.
            unsafe { copy_shape(handle, dims.as_mut_ptr()) };
        }
        dims
    })
}
36
37fn collect_tensor_array_box(handle: *mut c_void) -> Vec<Tensor> {
38    if handle.is_null() {
39        return Vec::new();
40    }
41
42    // SAFETY: `handle` is a retained tensor-array box created by the Swift bridge.
43    let len = unsafe { ffi::mpsgraph_tensor_array_box_len(handle) };
44    let mut tensors = Vec::with_capacity(len);
45    for index in 0..len {
46        // SAFETY: indices are bounded by the just-read length.
47        let tensor = unsafe { ffi::mpsgraph_tensor_array_box_get(handle, index) };
48        if !tensor.is_null() {
49            tensors.push(Tensor::from_raw(tensor));
50        }
51    }
52    let mut box_handle = handle;
53    release_handle(&mut box_handle);
54    tensors
55}
56
57pub(crate) fn collect_tensor_data_array_box(handle: *mut c_void) -> Vec<TensorData> {
58    if handle.is_null() {
59        return Vec::new();
60    }
61
62    // SAFETY: `handle` is a retained tensor-data-array box created by the Swift bridge.
63    let len = unsafe { ffi::mpsgraph_tensor_data_array_box_len(handle) };
64    let mut values = Vec::with_capacity(len);
65    for index in 0..len {
66        // SAFETY: indices are bounded by the just-read length.
67        let value = unsafe { ffi::mpsgraph_tensor_data_array_box_get(handle, index) };
68        if !value.is_null() {
69            values.push(TensorData::from_raw(value));
70        }
71    }
72    let mut box_handle = handle;
73    release_handle(&mut box_handle);
74    values
75}
76
77pub(crate) fn collect_shaped_type_array_box(handle: *mut c_void) -> Vec<ShapedType> {
78    if handle.is_null() {
79        return Vec::new();
80    }
81
82    // SAFETY: `handle` is a retained shaped-type-array box created by the Swift bridge.
83    let len = unsafe { ffi::mpsgraph_shaped_type_array_box_len(handle) };
84    let mut values = Vec::with_capacity(len);
85    for index in 0..len {
86        // SAFETY: indices are bounded by the just-read length.
87        let value = unsafe { ffi::mpsgraph_shaped_type_array_box_get(handle, index) };
88        if !value.is_null() {
89            values.push(ShapedType { ptr: value });
90        }
91    }
92    let mut box_handle = handle;
93    release_handle(&mut box_handle);
94    values
95}
96
/// Raw `MPSGraphDeviceType` values, as returned by `GraphDevice::device_type` and
/// `TensorData::graph_device_type`.
pub mod graph_device_type {
    /// Raw value of `MPSGraphDeviceTypeMetal`.
    pub const METAL: u32 = 0;
}
102
103/// Owned wrapper for `MPSGraphDevice`.
104pub struct GraphDevice {
105    ptr: *mut c_void,
106}
107
108unsafe impl Send for GraphDevice {}
109unsafe impl Sync for GraphDevice {}
110
111impl Drop for GraphDevice {
112    fn drop(&mut self) {
113        release_handle(&mut self.ptr);
114    }
115}
116
117impl GraphDevice {
118    /// Create a graph device from an existing Metal device.
119    #[must_use]
120    pub fn from_metal_device(device: &MetalDevice) -> Option<Self> {
121        // SAFETY: `device` remains valid for the duration of the bridge call.
122        let ptr = unsafe { ffi::mpsgraph_device_new_with_metal_device(device.as_ptr()) };
123        if ptr.is_null() {
124            None
125        } else {
126            Some(Self { ptr })
127        }
128    }
129
130/// Mirrors the `MPSGraph` framework constant `fn`.
131    #[must_use]
132    pub const fn as_ptr(&self) -> *mut c_void {
133        self.ptr
134    }
135
136    /// Return the underlying `MPSGraphDeviceType` raw value.
137    #[must_use]
138    pub fn device_type(&self) -> u32 {
139        // SAFETY: `self.ptr` is a live graph-device handle.
140        unsafe { ffi::mpsgraph_device_type(self.ptr) }
141    }
142}
143
144/// Owned wrapper for `MPSGraphShapedType`.
145pub struct ShapedType {
146    ptr: *mut c_void,
147}
148
149unsafe impl Send for ShapedType {}
150unsafe impl Sync for ShapedType {}
151
152impl Drop for ShapedType {
153    fn drop(&mut self) {
154        release_handle(&mut self.ptr);
155    }
156}
157
158impl ShapedType {
159    /// Create a shaped type from an optional shape and `MPSDataType` raw value.
160    #[must_use]
161    pub fn new(shape: Option<&[isize]>, data_type: u32) -> Option<Self> {
162        let (shape_ptr, shape_len) =
163            shape.map_or((ptr::null(), 0), |shape| (shape.as_ptr(), shape.len()));
164        // SAFETY: the optional slice lives for the duration of the call.
165        let ptr = unsafe { ffi::mpsgraph_shaped_type_new(shape_ptr, shape_len, data_type) };
166        if ptr.is_null() {
167            None
168        } else {
169            Some(Self { ptr })
170        }
171    }
172
173    #[must_use]
174    pub(crate) const fn as_ptr(&self) -> *mut c_void {
175        self.ptr
176    }
177
178    /// Return the optional tensor shape. `None` corresponds to an unranked shape.
179    #[must_use]
180    pub fn shape(&self) -> Option<Vec<isize>> {
181        copy_optional_signed_shape(
182            self.ptr,
183            ffi::mpsgraph_shaped_type_has_shape,
184            ffi::mpsgraph_shaped_type_shape_len,
185            ffi::mpsgraph_shaped_type_copy_shape,
186        )
187    }
188
189    /// Return the underlying `MPSDataType` raw value.
190    #[must_use]
191    pub fn data_type(&self) -> u32 {
192        // SAFETY: `self.ptr` is a live shaped-type handle.
193        unsafe { ffi::mpsgraph_shaped_type_data_type(self.ptr) }
194    }
195
196    /// Replace the shape metadata for this shaped type.
197    pub fn set_shape(&self, shape: Option<&[isize]>) -> Result<()> {
198        let (shape_ptr, shape_len) =
199            shape.map_or((ptr::null(), 0), |shape| (shape.as_ptr(), shape.len()));
200        // SAFETY: the optional slice lives for the duration of the call.
201        let ok = unsafe { ffi::mpsgraph_shaped_type_set_shape(self.ptr, shape_ptr, shape_len) };
202        if ok {
203            Ok(())
204        } else {
205            Err(Error::OperationFailed("failed to set shaped type shape"))
206        }
207    }
208
209    /// Replace the data-type metadata for this shaped type.
210    pub fn set_data_type(&self, data_type: u32) -> Result<()> {
211        // SAFETY: `self.ptr` is a live shaped-type handle.
212        let ok = unsafe { ffi::mpsgraph_shaped_type_set_data_type(self.ptr, data_type) };
213        if ok {
214            Ok(())
215        } else {
216            Err(Error::OperationFailed(
217                "failed to set shaped type data type",
218            ))
219        }
220    }
221
222    /// Compare two shaped types using `MPSGraphShapedType.isEqual(to:)`.
223    #[must_use]
224    pub fn is_equal(&self, other: Option<&Self>) -> bool {
225        let other_ptr = other.map_or(ptr::null_mut(), Self::as_ptr);
226        // SAFETY: all handles stay alive for the duration of the call.
227        unsafe { ffi::mpsgraph_shaped_type_is_equal(self.ptr, other_ptr) }
228    }
229}
230
231/// Owned wrapper for `MPSGraphOperation`.
232pub struct Operation {
233    ptr: *mut c_void,
234}
235
236unsafe impl Send for Operation {}
237unsafe impl Sync for Operation {}
238
239impl Drop for Operation {
240    fn drop(&mut self) {
241        release_handle(&mut self.ptr);
242    }
243}
244
245impl Operation {
246    #[must_use]
247    pub(crate) const fn from_raw(ptr: *mut c_void) -> Self {
248        Self { ptr }
249    }
250
251/// Mirrors the `MPSGraph` framework constant `fn`.
252    #[must_use]
253    pub const fn as_ptr(&self) -> *mut c_void {
254        self.ptr
255    }
256}
257
258impl Tensor {
259    /// Return the optional symbolic tensor shape.
260    #[must_use]
261    pub fn shape(&self) -> Option<Vec<isize>> {
262        copy_optional_signed_shape(
263            self.as_ptr(),
264            ffi::mpsgraph_tensor_has_shape,
265            ffi::mpsgraph_tensor_shape_len,
266            ffi::mpsgraph_tensor_copy_shape,
267        )
268    }
269
270    /// Return the tensor's `MPSDataType` raw value.
271    #[must_use]
272    pub fn data_type(&self) -> u32 {
273        // SAFETY: `self` owns a live tensor handle.
274        unsafe { ffi::mpsgraph_tensor_data_type(self.as_ptr()) }
275    }
276
277    /// Return the operation that produced this tensor.
278    #[must_use]
279    pub fn operation(&self) -> Option<Operation> {
280        // SAFETY: `self` owns a live tensor handle.
281        let ptr = unsafe { ffi::mpsgraph_tensor_operation(self.as_ptr()) };
282        if ptr.is_null() {
283            None
284        } else {
285            Some(Operation { ptr })
286        }
287    }
288}
289
290impl TensorData {
291    /// Return the graph-device type that backs this tensor data.
292    #[must_use]
293    pub fn graph_device_type(&self) -> Option<u32> {
294        // SAFETY: `self` owns a live tensor-data handle.
295        let ptr = unsafe { ffi::mpsgraph_tensor_data_device(self.as_ptr()) };
296        if ptr.is_null() {
297            return None;
298        }
299        let device = GraphDevice { ptr };
300        Some(device.device_type())
301    }
302}
303
/// Collect the tensors held by a bridge-created tensor-array box into owned [`Tensor`]s.
///
/// Crate-visible entry point that forwards to `collect_tensor_array_box`:
/// * A null `handle` yields an empty vector.
/// * Null elements are skipped.
/// * The box itself is released.
pub(crate) fn collect_owned_tensors(handle: *mut c_void) -> Vec<Tensor> {
    collect_tensor_array_box(handle)
}