Skip to main content

apple_mps/
ndarray.rs

1use crate::ffi;
2use apple_metal::{CommandBuffer as MetalCommandBuffer, MetalBuffer, MetalDevice};
3use core::ffi::c_void;
4use core::ptr;
5
macro_rules! opaque_handle {
    ($name:ident, $doc:expr) => {
        #[doc = $doc]
        pub struct $name {
            ptr: *mut c_void,
        }

        impl $name {
            /// Returns the raw Objective-C pointer this wrapper holds a +1 reference to.
            ///
            /// Ownership stays with the wrapper (it is released on drop), so callers must not
            /// release the returned pointer themselves.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }

        impl Drop for $name {
            fn drop(&mut self) {
                // Detach the handle first so the wrapper never keeps a released pointer.
                let owned = self.ptr;
                self.ptr = ptr::null_mut();
                if owned.is_null() {
                    return;
                }
                // SAFETY: `owned` was the +1 retained MPS object held by this wrapper and has
                // just been detached from it, so it is released exactly once.
                unsafe { ffi::mps_object_release(owned) };
            }
        }

        // SAFETY: MPS handles are opaque pointers to Swift/ObjC objects that this crate treats
        // as thread-safe. NOTE(review): mutable objects (e.g. descriptors with setters taking
        // `&self`) may still need external synchronization — confirm.
        unsafe impl Send for $name {}
        // SAFETY: see the `Send` impl above; the same assumption applies to shared access.
        unsafe impl Sync for $name {}
    };
}
37
38#[doc(hidden)]
39pub use crate::generated::ndarray::*;
40
41opaque_handle!(NDArrayDescriptor, "Wraps `MPSNDArrayDescriptor`.");
42impl NDArrayDescriptor {
43    /// Wraps the corresponding `MPSNDArrayDescriptor` method.
44    #[must_use]
45    pub fn with_dimension_sizes(data_type: u32, dimension_sizes: &[usize]) -> Option<Self> {
46        // SAFETY: dimension_sizes.as_ptr() is valid for dimension_sizes.len() elements.
47        let ptr = unsafe {
48            ffi::mps_ndarray_descriptor_new_with_dimension_sizes(
49                data_type,
50                dimension_sizes.len(),
51                dimension_sizes.as_ptr(),
52            )
53        };
54        if ptr.is_null() {
55            None
56        } else {
57            Some(Self { ptr })
58        }
59    }
60
61    /// Wraps the corresponding `MPSNDArrayDescriptor` method.
62    #[must_use]
63    pub fn data_type(&self) -> u32 {
64        // SAFETY: self.ptr is a valid NDArrayDescriptor.
65        unsafe { ffi::mps_ndarray_descriptor_data_type(self.ptr) }
66    }
67
68    /// Wraps the corresponding `MPSNDArrayDescriptor` setter.
69    pub fn set_data_type(&self, data_type: u32) {
70        // SAFETY: self.ptr is a valid NDArrayDescriptor.
71        unsafe { ffi::mps_ndarray_descriptor_set_data_type(self.ptr, data_type) };
72    }
73
74    /// Wraps the corresponding `MPSNDArrayDescriptor` method.
75    #[must_use]
76    pub fn number_of_dimensions(&self) -> usize {
77        // SAFETY: self.ptr is a valid NDArrayDescriptor.
78        unsafe { ffi::mps_ndarray_descriptor_number_of_dimensions(self.ptr) }
79    }
80
81    /// Wraps the corresponding `MPSNDArrayDescriptor` setter.
82    pub fn set_number_of_dimensions(&self, number_of_dimensions: usize) {
83        // SAFETY: self.ptr is a valid NDArrayDescriptor.
84        unsafe {
85            ffi::mps_ndarray_descriptor_set_number_of_dimensions(self.ptr, number_of_dimensions);
86        };
87    }
88
89    /// Wraps the corresponding `MPSNDArrayDescriptor` method.
90    #[must_use]
91    pub fn length_of_dimension(&self, dimension_index: usize) -> usize {
92        // SAFETY: self.ptr is a valid NDArrayDescriptor and dimension_index is in bounds.
93        unsafe { ffi::mps_ndarray_descriptor_length_of_dimension(self.ptr, dimension_index) }
94    }
95
96    /// Wraps the corresponding `MPSNDArrayDescriptor` method.
97    pub fn reshape_with_dimension_sizes(&self, dimension_sizes: &[usize]) {
98        // SAFETY: dimension_sizes.as_ptr() is valid for dimension_sizes.len() elements.
99        unsafe {
100            ffi::mps_ndarray_descriptor_reshape_with_dimension_sizes(
101                self.ptr,
102                dimension_sizes.len(),
103                dimension_sizes.as_ptr(),
104            );
105        };
106    }
107
108    /// Wraps the corresponding `MPSNDArrayDescriptor` method.
109    pub fn transpose_dimension(&self, dimension_index: usize, other_dimension_index: usize) {
110        // SAFETY: Both dimension indices are validated by MPS.
111        unsafe {
112            ffi::mps_ndarray_descriptor_transpose_dimension(
113                self.ptr,
114                dimension_index,
115                other_dimension_index,
116            );
117        };
118    }
119}
120
121opaque_handle!(NDArray, "Wraps `MPSNDArray`.");
122impl NDArray {
123    /// Wraps a constructor on `MPSNDArray`.
124    #[must_use]
125    pub fn new(device: &MetalDevice, descriptor: &NDArrayDescriptor) -> Option<Self> {
126        // SAFETY: Both pointers come from safe wrappers and are valid for the call.
127        let ptr =
128            unsafe { ffi::mps_ndarray_new_with_descriptor(device.as_ptr(), descriptor.as_ptr()) };
129        if ptr.is_null() {
130            None
131        } else {
132            Some(Self { ptr })
133        }
134    }
135
136    /// Wraps a constructor on `MPSNDArray`.
137    #[must_use]
138    pub fn scalar(device: &MetalDevice, value: f64) -> Option<Self> {
139        // SAFETY: device pointer is valid and we return null or a +1 retained NDArray.
140        let ptr = unsafe { ffi::mps_ndarray_new_scalar(device.as_ptr(), value) };
141        if ptr.is_null() {
142            None
143        } else {
144            Some(Self { ptr })
145        }
146    }
147
148    /// Wraps a constructor on `MPSNDArray`.
149    #[must_use]
150    pub fn new_with_buffer(
151        buffer: &MetalBuffer,
152        offset: usize,
153        descriptor: &NDArrayDescriptor,
154    ) -> Option<Self> {
155        let ptr = unsafe {
156            ffi::mps_ndarray_new_with_buffer(buffer.as_ptr(), offset, descriptor.as_ptr())
157        };
158        if ptr.is_null() {
159            None
160        } else {
161            Some(Self { ptr })
162        }
163    }
164
165    /// Wraps the corresponding `MPSNDArray` method.
166    #[must_use]
167    pub fn data_type(&self) -> u32 {
168        unsafe { ffi::mps_ndarray_data_type(self.ptr) }
169    }
170
171    /// Wraps the corresponding `MPSNDArray` method.
172    #[must_use]
173    pub fn number_of_dimensions(&self) -> usize {
174        unsafe { ffi::mps_ndarray_number_of_dimensions(self.ptr) }
175    }
176
177    /// Wraps the corresponding `MPSNDArray` method.
178    #[must_use]
179    pub fn length_of_dimension(&self, dimension_index: usize) -> usize {
180        unsafe { ffi::mps_ndarray_length_of_dimension(self.ptr, dimension_index) }
181    }
182
183    /// Wraps the corresponding `MPSNDArray` method.
184    #[must_use]
185    pub fn descriptor(&self) -> Option<NDArrayDescriptor> {
186        let ptr = unsafe { ffi::mps_ndarray_descriptor(self.ptr) };
187        if ptr.is_null() {
188            None
189        } else {
190            Some(NDArrayDescriptor { ptr })
191        }
192    }
193
194    /// Wraps the corresponding `MPSNDArray` method.
195    #[must_use]
196    pub fn resource_size(&self) -> usize {
197        unsafe { ffi::mps_ndarray_resource_size(self.ptr) }
198    }
199}
200
201opaque_handle!(NDArrayIdentity, "Wraps `MPSNDArrayIdentity`.");
202impl NDArrayIdentity {
203    /// Wraps a constructor on `MPSNDArrayIdentity`.
204    #[must_use]
205    pub fn new(device: &MetalDevice) -> Option<Self> {
206        let ptr = unsafe { ffi::mps_ndarray_identity_new(device.as_ptr()) };
207        if ptr.is_null() {
208            None
209        } else {
210            Some(Self { ptr })
211        }
212    }
213
214    /// Wraps the corresponding `MPSNDArrayIdentity` method.
215    #[must_use]
216    pub fn reshape(&self, source: &NDArray, dimension_sizes: &[usize]) -> Option<NDArray> {
217        let ptr = unsafe {
218            ffi::mps_ndarray_identity_reshape(
219                self.ptr,
220                ptr::null_mut(),
221                source.as_ptr(),
222                dimension_sizes.len(),
223                dimension_sizes.as_ptr(),
224                ptr::null_mut(),
225            )
226        };
227        if ptr.is_null() {
228            None
229        } else {
230            Some(NDArray { ptr })
231        }
232    }
233
234    /// Wraps the corresponding `MPSNDArrayIdentity` method.
235    #[must_use]
236    pub fn reshape_with_command_buffer(
237        &self,
238        command_buffer: &MetalCommandBuffer,
239        source: &NDArray,
240        dimension_sizes: &[usize],
241    ) -> Option<NDArray> {
242        let ptr = unsafe {
243            ffi::mps_ndarray_identity_reshape(
244                self.ptr,
245                command_buffer.as_ptr(),
246                source.as_ptr(),
247                dimension_sizes.len(),
248                dimension_sizes.as_ptr(),
249                ptr::null_mut(),
250            )
251        };
252        if ptr.is_null() {
253            None
254        } else {
255            Some(NDArray { ptr })
256        }
257    }
258
259    /// Wraps the corresponding `MPSNDArrayIdentity` method.
260    pub fn reshape_into(
261        &self,
262        command_buffer: Option<&MetalCommandBuffer>,
263        source: &NDArray,
264        dimension_sizes: &[usize],
265        destination: &NDArray,
266    ) -> bool {
267        let command_buffer_ptr = command_buffer.map_or(ptr::null_mut(), MetalCommandBuffer::as_ptr);
268        let ptr = unsafe {
269            ffi::mps_ndarray_identity_reshape(
270                self.ptr,
271                command_buffer_ptr,
272                source.as_ptr(),
273                dimension_sizes.len(),
274                dimension_sizes.as_ptr(),
275                destination.as_ptr(),
276            )
277        };
278        !ptr.is_null()
279    }
280}
281
282opaque_handle!(NDArrayMatrixMultiplication, "Wraps `MPSNDArrayMatrixMultiplication`.");
283impl NDArrayMatrixMultiplication {
284    /// Wraps a constructor on `MPSNDArrayMatrixMultiplication`.
285    #[must_use]
286    pub fn new(device: &MetalDevice, source_count: usize) -> Option<Self> {
287        let ptr =
288            unsafe { ffi::mps_ndarray_matrix_multiplication_new(device.as_ptr(), source_count) };
289        if ptr.is_null() {
290            None
291        } else {
292            Some(Self { ptr })
293        }
294    }
295
296    /// Wraps the corresponding `MPSNDArrayMatrixMultiplication` method.
297    #[must_use]
298    pub fn alpha(&self) -> f64 {
299        unsafe { ffi::mps_ndarray_matrix_multiplication_alpha(self.ptr) }
300    }
301
302    /// Wraps the corresponding `MPSNDArrayMatrixMultiplication` setter.
303    pub fn set_alpha(&self, alpha: f64) {
304        unsafe { ffi::mps_ndarray_matrix_multiplication_set_alpha(self.ptr, alpha) };
305    }
306
307    /// Wraps the corresponding `MPSNDArrayMatrixMultiplication` method.
308    #[must_use]
309    pub fn beta(&self) -> f64 {
310        unsafe { ffi::mps_ndarray_matrix_multiplication_beta(self.ptr) }
311    }
312
313    /// Wraps the corresponding `MPSNDArrayMatrixMultiplication` setter.
314    pub fn set_beta(&self, beta: f64) {
315        unsafe { ffi::mps_ndarray_matrix_multiplication_set_beta(self.ptr, beta) };
316    }
317
318    /// Wraps the corresponding `MPSNDArrayMatrixMultiplication` encode entry point.
319    #[must_use]
320    pub fn encode(
321        &self,
322        command_buffer: &MetalCommandBuffer,
323        source_arrays: &[&NDArray],
324    ) -> Option<NDArray> {
325        let handles: Vec<_> = source_arrays.iter().map(|array| array.as_ptr()).collect();
326        let handles_ptr = if handles.is_empty() {
327            ptr::null()
328        } else {
329            handles.as_ptr()
330        };
331        let ptr = unsafe {
332            ffi::mps_ndarray_matrix_multiplication_encode(
333                self.ptr,
334                command_buffer.as_ptr(),
335                source_arrays.len(),
336                handles_ptr,
337            )
338        };
339        if ptr.is_null() {
340            None
341        } else {
342            Some(NDArray { ptr })
343        }
344    }
345
346    /// Wraps the corresponding `MPSNDArrayMatrixMultiplication` encode entry point.
347    pub fn encode_to_destination(
348        &self,
349        command_buffer: &MetalCommandBuffer,
350        source_arrays: &[&NDArray],
351        destination: &NDArray,
352    ) {
353        let handles: Vec<_> = source_arrays.iter().map(|array| array.as_ptr()).collect();
354        let handles_ptr = if handles.is_empty() {
355            ptr::null()
356        } else {
357            handles.as_ptr()
358        };
359        unsafe {
360            ffi::mps_ndarray_matrix_multiplication_encode_to_destination(
361                self.ptr,
362                command_buffer.as_ptr(),
363                source_arrays.len(),
364                handles_ptr,
365                destination.as_ptr(),
366            );
367        };
368    }
369}