Skip to main content

apple_mpsgraph/
types.rs

1use crate::data::TensorData;
2use crate::error::{Error, Result};
3use crate::ffi;
4use crate::graph::Tensor;
5use apple_metal::MetalDevice;
6use core::ffi::c_void;
7use core::ptr;
8
9fn release_handle(ptr: &mut *mut c_void) {
10    if !ptr.is_null() {
11        // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
12        unsafe { ffi::mpsgraph_object_release(*ptr) };
13        *ptr = ptr::null_mut();
14    }
15}
16
/// Read an optional signed shape from a bridge object through three shim calls.
///
/// Returns `None` when `has_shape` reports no shape. Otherwise it asks
/// `shape_len` for the rank and allocates a zero-filled buffer of that size.
/// `copy_shape` then fills the buffer; that call is skipped when the rank is 0.
///
/// NOTE(review): the rank is read in one call and the copy happens in another,
/// and `copy_shape` is never told the buffer length. If the object's shape can
/// change between the two calls, the shim could write past the buffer. One way
/// this could happen is a concurrent `ShapedType::set_shape`, which takes
/// `&self` on a `Sync` type. Confirm this cannot occur.
fn copy_optional_signed_shape(
    handle: *mut c_void,
    has_shape: unsafe extern "C" fn(*mut c_void) -> bool,
    shape_len: unsafe extern "C" fn(*mut c_void) -> usize,
    copy_shape: unsafe extern "C" fn(*mut c_void, *mut isize),
) -> Option<Vec<isize>> {
    // SAFETY: the function pointers belong to Swift shims that treat `handle` as immutable for the duration of the call.
    let ranked = unsafe { has_shape(handle) };
    if !ranked {
        return None;
    }

    // SAFETY: same contract as the `has_shape` call above.
    let rank = unsafe { shape_len(handle) };
    let mut dims: Vec<isize> = vec![0; rank];
    if !dims.is_empty() {
        // SAFETY: `dims` holds exactly `rank` initialized elements for the shim to overwrite.
        unsafe { copy_shape(handle, dims.as_mut_ptr()) };
    }
    Some(dims)
}
36
37fn collect_tensor_array_box(handle: *mut c_void) -> Vec<Tensor> {
38    if handle.is_null() {
39        return Vec::new();
40    }
41
42    // SAFETY: `handle` is a retained tensor-array box created by the Swift bridge.
43    let len = unsafe { ffi::mpsgraph_tensor_array_box_len(handle) };
44    let mut tensors = Vec::with_capacity(len);
45    for index in 0..len {
46        // SAFETY: indices are bounded by the just-read length.
47        let tensor = unsafe { ffi::mpsgraph_tensor_array_box_get(handle, index) };
48        if !tensor.is_null() {
49            tensors.push(Tensor::from_raw(tensor));
50        }
51    }
52    let mut box_handle = handle;
53    release_handle(&mut box_handle);
54    tensors
55}
56
57pub(crate) fn collect_tensor_data_array_box(handle: *mut c_void) -> Vec<TensorData> {
58    if handle.is_null() {
59        return Vec::new();
60    }
61
62    // SAFETY: `handle` is a retained tensor-data-array box created by the Swift bridge.
63    let len = unsafe { ffi::mpsgraph_tensor_data_array_box_len(handle) };
64    let mut values = Vec::with_capacity(len);
65    for index in 0..len {
66        // SAFETY: indices are bounded by the just-read length.
67        let value = unsafe { ffi::mpsgraph_tensor_data_array_box_get(handle, index) };
68        if !value.is_null() {
69            values.push(TensorData::from_raw(value));
70        }
71    }
72    let mut box_handle = handle;
73    release_handle(&mut box_handle);
74    values
75}
76
77pub(crate) fn collect_shaped_type_array_box(handle: *mut c_void) -> Vec<ShapedType> {
78    if handle.is_null() {
79        return Vec::new();
80    }
81
82    // SAFETY: `handle` is a retained shaped-type-array box created by the Swift bridge.
83    let len = unsafe { ffi::mpsgraph_shaped_type_array_box_len(handle) };
84    let mut values = Vec::with_capacity(len);
85    for index in 0..len {
86        // SAFETY: indices are bounded by the just-read length.
87        let value = unsafe { ffi::mpsgraph_shaped_type_array_box_get(handle, index) };
88        if !value.is_null() {
89            values.push(ShapedType { ptr: value });
90        }
91    }
92    let mut box_handle = handle;
93    release_handle(&mut box_handle);
94    values
95}
96
pub mod graph_device_type {
    //! `MPSGraphDeviceType` constants.

    /// Raw value for the Metal device type (`MPSGraphDeviceType.metal`).
    pub const METAL: u32 = 0;
}
101
102/// Owned wrapper for `MPSGraphDevice`.
103pub struct GraphDevice {
104    ptr: *mut c_void,
105}
106
107unsafe impl Send for GraphDevice {}
108unsafe impl Sync for GraphDevice {}
109
110impl Drop for GraphDevice {
111    fn drop(&mut self) {
112        release_handle(&mut self.ptr);
113    }
114}
115
116impl GraphDevice {
117    /// Create a graph device from an existing Metal device.
118    #[must_use]
119    pub fn from_metal_device(device: &MetalDevice) -> Option<Self> {
120        // SAFETY: `device` remains valid for the duration of the bridge call.
121        let ptr = unsafe { ffi::mpsgraph_device_new_with_metal_device(device.as_ptr()) };
122        if ptr.is_null() {
123            None
124        } else {
125            Some(Self { ptr })
126        }
127    }
128
129    #[must_use]
130    pub const fn as_ptr(&self) -> *mut c_void {
131        self.ptr
132    }
133
134    /// Return the underlying `MPSGraphDeviceType` raw value.
135    #[must_use]
136    pub fn device_type(&self) -> u32 {
137        // SAFETY: `self.ptr` is a live graph-device handle.
138        unsafe { ffi::mpsgraph_device_type(self.ptr) }
139    }
140}
141
142/// Owned wrapper for `MPSGraphShapedType`.
143pub struct ShapedType {
144    ptr: *mut c_void,
145}
146
147unsafe impl Send for ShapedType {}
148unsafe impl Sync for ShapedType {}
149
150impl Drop for ShapedType {
151    fn drop(&mut self) {
152        release_handle(&mut self.ptr);
153    }
154}
155
156impl ShapedType {
157    /// Create a shaped type from an optional shape and `MPSDataType` raw value.
158    #[must_use]
159    pub fn new(shape: Option<&[isize]>, data_type: u32) -> Option<Self> {
160        let (shape_ptr, shape_len) = shape.map_or((ptr::null(), 0), |shape| (shape.as_ptr(), shape.len()));
161        // SAFETY: the optional slice lives for the duration of the call.
162        let ptr = unsafe { ffi::mpsgraph_shaped_type_new(shape_ptr, shape_len, data_type) };
163        if ptr.is_null() {
164            None
165        } else {
166            Some(Self { ptr })
167        }
168    }
169
170    #[must_use]
171    pub(crate) const fn as_ptr(&self) -> *mut c_void {
172        self.ptr
173    }
174
175    /// Return the optional tensor shape. `None` corresponds to an unranked shape.
176    #[must_use]
177    pub fn shape(&self) -> Option<Vec<isize>> {
178        copy_optional_signed_shape(
179            self.ptr,
180            ffi::mpsgraph_shaped_type_has_shape,
181            ffi::mpsgraph_shaped_type_shape_len,
182            ffi::mpsgraph_shaped_type_copy_shape,
183        )
184    }
185
186    /// Return the underlying `MPSDataType` raw value.
187    #[must_use]
188    pub fn data_type(&self) -> u32 {
189        // SAFETY: `self.ptr` is a live shaped-type handle.
190        unsafe { ffi::mpsgraph_shaped_type_data_type(self.ptr) }
191    }
192
193    /// Replace the shape metadata for this shaped type.
194    pub fn set_shape(&self, shape: Option<&[isize]>) -> Result<()> {
195        let (shape_ptr, shape_len) = shape.map_or((ptr::null(), 0), |shape| (shape.as_ptr(), shape.len()));
196        // SAFETY: the optional slice lives for the duration of the call.
197        let ok = unsafe { ffi::mpsgraph_shaped_type_set_shape(self.ptr, shape_ptr, shape_len) };
198        if ok {
199            Ok(())
200        } else {
201            Err(Error::OperationFailed("failed to set shaped type shape"))
202        }
203    }
204
205    /// Replace the data-type metadata for this shaped type.
206    pub fn set_data_type(&self, data_type: u32) -> Result<()> {
207        // SAFETY: `self.ptr` is a live shaped-type handle.
208        let ok = unsafe { ffi::mpsgraph_shaped_type_set_data_type(self.ptr, data_type) };
209        if ok {
210            Ok(())
211        } else {
212            Err(Error::OperationFailed("failed to set shaped type data type"))
213        }
214    }
215
216    /// Compare two shaped types using `MPSGraphShapedType.isEqual(to:)`.
217    #[must_use]
218    pub fn is_equal(&self, other: Option<&Self>) -> bool {
219        let other_ptr = other.map_or(ptr::null_mut(), Self::as_ptr);
220        // SAFETY: all handles stay alive for the duration of the call.
221        unsafe { ffi::mpsgraph_shaped_type_is_equal(self.ptr, other_ptr) }
222    }
223}
224
225/// Owned wrapper for `MPSGraphOperation`.
226pub struct Operation {
227    ptr: *mut c_void,
228}
229
230unsafe impl Send for Operation {}
231unsafe impl Sync for Operation {}
232
233impl Drop for Operation {
234    fn drop(&mut self) {
235        release_handle(&mut self.ptr);
236    }
237}
238
239impl Operation {
240    #[must_use]
241    pub const fn as_ptr(&self) -> *mut c_void {
242        self.ptr
243    }
244}
245
246impl Tensor {
247    /// Return the optional symbolic tensor shape.
248    #[must_use]
249    pub fn shape(&self) -> Option<Vec<isize>> {
250        copy_optional_signed_shape(
251            self.as_ptr(),
252            ffi::mpsgraph_tensor_has_shape,
253            ffi::mpsgraph_tensor_shape_len,
254            ffi::mpsgraph_tensor_copy_shape,
255        )
256    }
257
258    /// Return the tensor's `MPSDataType` raw value.
259    #[must_use]
260    pub fn data_type(&self) -> u32 {
261        // SAFETY: `self` owns a live tensor handle.
262        unsafe { ffi::mpsgraph_tensor_data_type(self.as_ptr()) }
263    }
264
265    /// Return the operation that produced this tensor.
266    #[must_use]
267    pub fn operation(&self) -> Option<Operation> {
268        // SAFETY: `self` owns a live tensor handle.
269        let ptr = unsafe { ffi::mpsgraph_tensor_operation(self.as_ptr()) };
270        if ptr.is_null() {
271            None
272        } else {
273            Some(Operation { ptr })
274        }
275    }
276}
277
278impl TensorData {
279    /// Return the graph-device type that backs this tensor data.
280    #[must_use]
281    pub fn graph_device_type(&self) -> Option<u32> {
282        // SAFETY: `self` owns a live tensor-data handle.
283        let ptr = unsafe { ffi::mpsgraph_tensor_data_device(self.as_ptr()) };
284        if ptr.is_null() {
285            return None;
286        }
287        let device = GraphDevice { ptr };
288        Some(device.device_type())
289    }
290}
291
292pub(crate) fn collect_owned_tensors(handle: *mut c_void) -> Vec<Tensor> {
293    collect_tensor_array_box(handle)
294}