Skip to main content

apple_mpsgraph/
types.rs

1use crate::data::TensorData;
2use crate::error::{Error, Result};
3use crate::ffi;
4use crate::graph::Tensor;
5use apple_metal::MetalDevice;
6use core::ffi::c_void;
7use core::ptr;
8
9fn release_handle(ptr: &mut *mut c_void) {
10    if !ptr.is_null() {
11        // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
12        unsafe { ffi::mpsgraph_object_release(*ptr) };
13        *ptr = ptr::null_mut();
14    }
15}
16
/// Read an optional signed shape from a bridge object through three accessor shims.
///
/// Returns `None` when `has_shape` reports that no shape is present. Otherwise
/// `shape_len` supplies the element count and `copy_shape` fills a
/// zero-initialised buffer of that length. `copy_shape` is not called when the
/// reported length is zero, in which case an empty vector is returned.
fn copy_optional_signed_shape(
    handle: *mut c_void,
    has_shape: unsafe extern "C" fn(*mut c_void) -> bool,
    shape_len: unsafe extern "C" fn(*mut c_void) -> usize,
    copy_shape: unsafe extern "C" fn(*mut c_void, *mut isize),
) -> Option<Vec<isize>> {
    // SAFETY: the function pointers belong to Swift shims that treat `handle` as immutable for
    // the duration of the call.
    let present = unsafe { has_shape(handle) };
    if !present {
        return None;
    }

    // SAFETY: same contract as the `has_shape` call above.
    let count = unsafe { shape_len(handle) };
    let mut dims = vec![0_isize; count];
    if !dims.is_empty() {
        // SAFETY: `dims` provides exactly `count` writable elements.
        unsafe { copy_shape(handle, dims.as_mut_ptr()) };
    }
    Some(dims)
}
36
37fn collect_tensor_array_box(handle: *mut c_void) -> Vec<Tensor> {
38    if handle.is_null() {
39        return Vec::new();
40    }
41
42    // SAFETY: `handle` is a retained tensor-array box created by the Swift bridge.
43    let len = unsafe { ffi::mpsgraph_tensor_array_box_len(handle) };
44    let mut tensors = Vec::with_capacity(len);
45    for index in 0..len {
46        // SAFETY: indices are bounded by the just-read length.
47        let tensor = unsafe { ffi::mpsgraph_tensor_array_box_get(handle, index) };
48        if !tensor.is_null() {
49            tensors.push(Tensor::from_raw(tensor));
50        }
51    }
52    let mut box_handle = handle;
53    release_handle(&mut box_handle);
54    tensors
55}
56
57pub(crate) fn collect_tensor_data_array_box(handle: *mut c_void) -> Vec<TensorData> {
58    if handle.is_null() {
59        return Vec::new();
60    }
61
62    // SAFETY: `handle` is a retained tensor-data-array box created by the Swift bridge.
63    let len = unsafe { ffi::mpsgraph_tensor_data_array_box_len(handle) };
64    let mut values = Vec::with_capacity(len);
65    for index in 0..len {
66        // SAFETY: indices are bounded by the just-read length.
67        let value = unsafe { ffi::mpsgraph_tensor_data_array_box_get(handle, index) };
68        if !value.is_null() {
69            values.push(TensorData::from_raw(value));
70        }
71    }
72    let mut box_handle = handle;
73    release_handle(&mut box_handle);
74    values
75}
76
77pub(crate) fn collect_shaped_type_array_box(handle: *mut c_void) -> Vec<ShapedType> {
78    if handle.is_null() {
79        return Vec::new();
80    }
81
82    // SAFETY: `handle` is a retained shaped-type-array box created by the Swift bridge.
83    let len = unsafe { ffi::mpsgraph_shaped_type_array_box_len(handle) };
84    let mut values = Vec::with_capacity(len);
85    for index in 0..len {
86        // SAFETY: indices are bounded by the just-read length.
87        let value = unsafe { ffi::mpsgraph_shaped_type_array_box_get(handle, index) };
88        if !value.is_null() {
89            values.push(ShapedType { ptr: value });
90        }
91    }
92    let mut box_handle = handle;
93    release_handle(&mut box_handle);
94    values
95}
96
/// Raw `MPSGraphDeviceType` values.
pub mod graph_device_type {
    /// Raw value identifying a Metal-backed graph device.
    pub const METAL: u32 = 0;
}
101
102/// Owned wrapper for `MPSGraphDevice`.
103pub struct GraphDevice {
104    ptr: *mut c_void,
105}
106
107unsafe impl Send for GraphDevice {}
108unsafe impl Sync for GraphDevice {}
109
110impl Drop for GraphDevice {
111    fn drop(&mut self) {
112        release_handle(&mut self.ptr);
113    }
114}
115
116impl GraphDevice {
117    /// Create a graph device from an existing Metal device.
118    #[must_use]
119    pub fn from_metal_device(device: &MetalDevice) -> Option<Self> {
120        // SAFETY: `device` remains valid for the duration of the bridge call.
121        let ptr = unsafe { ffi::mpsgraph_device_new_with_metal_device(device.as_ptr()) };
122        if ptr.is_null() {
123            None
124        } else {
125            Some(Self { ptr })
126        }
127    }
128
129    #[must_use]
130    pub const fn as_ptr(&self) -> *mut c_void {
131        self.ptr
132    }
133
134    /// Return the underlying `MPSGraphDeviceType` raw value.
135    #[must_use]
136    pub fn device_type(&self) -> u32 {
137        // SAFETY: `self.ptr` is a live graph-device handle.
138        unsafe { ffi::mpsgraph_device_type(self.ptr) }
139    }
140}
141
142/// Owned wrapper for `MPSGraphShapedType`.
143pub struct ShapedType {
144    ptr: *mut c_void,
145}
146
147unsafe impl Send for ShapedType {}
148unsafe impl Sync for ShapedType {}
149
150impl Drop for ShapedType {
151    fn drop(&mut self) {
152        release_handle(&mut self.ptr);
153    }
154}
155
156impl ShapedType {
157    /// Create a shaped type from an optional shape and `MPSDataType` raw value.
158    #[must_use]
159    pub fn new(shape: Option<&[isize]>, data_type: u32) -> Option<Self> {
160        let (shape_ptr, shape_len) =
161            shape.map_or((ptr::null(), 0), |shape| (shape.as_ptr(), shape.len()));
162        // SAFETY: the optional slice lives for the duration of the call.
163        let ptr = unsafe { ffi::mpsgraph_shaped_type_new(shape_ptr, shape_len, data_type) };
164        if ptr.is_null() {
165            None
166        } else {
167            Some(Self { ptr })
168        }
169    }
170
171    #[must_use]
172    pub(crate) const fn as_ptr(&self) -> *mut c_void {
173        self.ptr
174    }
175
176    /// Return the optional tensor shape. `None` corresponds to an unranked shape.
177    #[must_use]
178    pub fn shape(&self) -> Option<Vec<isize>> {
179        copy_optional_signed_shape(
180            self.ptr,
181            ffi::mpsgraph_shaped_type_has_shape,
182            ffi::mpsgraph_shaped_type_shape_len,
183            ffi::mpsgraph_shaped_type_copy_shape,
184        )
185    }
186
187    /// Return the underlying `MPSDataType` raw value.
188    #[must_use]
189    pub fn data_type(&self) -> u32 {
190        // SAFETY: `self.ptr` is a live shaped-type handle.
191        unsafe { ffi::mpsgraph_shaped_type_data_type(self.ptr) }
192    }
193
194    /// Replace the shape metadata for this shaped type.
195    pub fn set_shape(&self, shape: Option<&[isize]>) -> Result<()> {
196        let (shape_ptr, shape_len) =
197            shape.map_or((ptr::null(), 0), |shape| (shape.as_ptr(), shape.len()));
198        // SAFETY: the optional slice lives for the duration of the call.
199        let ok = unsafe { ffi::mpsgraph_shaped_type_set_shape(self.ptr, shape_ptr, shape_len) };
200        if ok {
201            Ok(())
202        } else {
203            Err(Error::OperationFailed("failed to set shaped type shape"))
204        }
205    }
206
207    /// Replace the data-type metadata for this shaped type.
208    pub fn set_data_type(&self, data_type: u32) -> Result<()> {
209        // SAFETY: `self.ptr` is a live shaped-type handle.
210        let ok = unsafe { ffi::mpsgraph_shaped_type_set_data_type(self.ptr, data_type) };
211        if ok {
212            Ok(())
213        } else {
214            Err(Error::OperationFailed(
215                "failed to set shaped type data type",
216            ))
217        }
218    }
219
220    /// Compare two shaped types using `MPSGraphShapedType.isEqual(to:)`.
221    #[must_use]
222    pub fn is_equal(&self, other: Option<&Self>) -> bool {
223        let other_ptr = other.map_or(ptr::null_mut(), Self::as_ptr);
224        // SAFETY: all handles stay alive for the duration of the call.
225        unsafe { ffi::mpsgraph_shaped_type_is_equal(self.ptr, other_ptr) }
226    }
227}
228
229/// Owned wrapper for `MPSGraphOperation`.
230pub struct Operation {
231    ptr: *mut c_void,
232}
233
234unsafe impl Send for Operation {}
235unsafe impl Sync for Operation {}
236
237impl Drop for Operation {
238    fn drop(&mut self) {
239        release_handle(&mut self.ptr);
240    }
241}
242
243impl Operation {
244    #[must_use]
245    pub(crate) const fn from_raw(ptr: *mut c_void) -> Self {
246        Self { ptr }
247    }
248
249    #[must_use]
250    pub const fn as_ptr(&self) -> *mut c_void {
251        self.ptr
252    }
253}
254
255impl Tensor {
256    /// Return the optional symbolic tensor shape.
257    #[must_use]
258    pub fn shape(&self) -> Option<Vec<isize>> {
259        copy_optional_signed_shape(
260            self.as_ptr(),
261            ffi::mpsgraph_tensor_has_shape,
262            ffi::mpsgraph_tensor_shape_len,
263            ffi::mpsgraph_tensor_copy_shape,
264        )
265    }
266
267    /// Return the tensor's `MPSDataType` raw value.
268    #[must_use]
269    pub fn data_type(&self) -> u32 {
270        // SAFETY: `self` owns a live tensor handle.
271        unsafe { ffi::mpsgraph_tensor_data_type(self.as_ptr()) }
272    }
273
274    /// Return the operation that produced this tensor.
275    #[must_use]
276    pub fn operation(&self) -> Option<Operation> {
277        // SAFETY: `self` owns a live tensor handle.
278        let ptr = unsafe { ffi::mpsgraph_tensor_operation(self.as_ptr()) };
279        if ptr.is_null() {
280            None
281        } else {
282            Some(Operation { ptr })
283        }
284    }
285}
286
287impl TensorData {
288    /// Return the graph-device type that backs this tensor data.
289    #[must_use]
290    pub fn graph_device_type(&self) -> Option<u32> {
291        // SAFETY: `self` owns a live tensor-data handle.
292        let ptr = unsafe { ffi::mpsgraph_tensor_data_device(self.as_ptr()) };
293        if ptr.is_null() {
294            return None;
295        }
296        let device = GraphDevice { ptr };
297        Some(device.device_type())
298    }
299}
300
301pub(crate) fn collect_owned_tensors(handle: *mut c_void) -> Vec<Tensor> {
302    collect_tensor_array_box(handle)
303}