Skip to main content

apple_mps/
ndarray.rs

1use crate::ffi;
2use apple_metal::{CommandBuffer as MetalCommandBuffer, MetalBuffer, MetalDevice};
3use core::ffi::c_void;
4use core::ptr;
5
macro_rules! opaque_handle {
    ($name:ident) => {
        /// Owning handle to an opaque MPS object obtained through the FFI layer.
        ///
        /// The wrapped pointer carries one retain (+1) that belongs to this value; it is
        /// handed back through `ffi::mps_object_release` when the value is dropped.
        pub struct $name {
            ptr: *mut c_void,
        }

        impl $name {
            /// Borrows the raw object pointer without transferring ownership.
            ///
            /// The caller must not release it, and must not use it after `self` is dropped
            /// (dropping releases the object).
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }

        impl Drop for $name {
            fn drop(&mut self) {
                // Take the pointer out first so the field can never be released twice.
                let owned = core::mem::replace(&mut self.ptr, ptr::null_mut());
                if owned.is_null() {
                    return;
                }
                // SAFETY: `owned` is the +1 retained object this wrapper held, and the field
                // no longer refers to it, so this is the single balancing release.
                unsafe { ffi::mps_object_release(owned) };
            }
        }

        // SAFETY: the wrapped pointer refers to an MPS (Swift/ObjC) object, which this crate
        // treats as thread-safe and therefore movable between threads.
        unsafe impl Send for $name {}
        // SAFETY: as above. NOTE(review): some wrappers (e.g. `NDArrayDescriptor`) mutate the
        // underlying object through `&self`, so `Sync` permits concurrent mutation from
        // several threads; confirm the Objective-C object actually tolerates that.
        unsafe impl Sync for $name {}
    };
}
35
36pub use crate::generated::ndarray::*;
37
38opaque_handle!(NDArrayDescriptor);
39impl NDArrayDescriptor {
40    #[must_use]
41    pub fn with_dimension_sizes(data_type: u32, dimension_sizes: &[usize]) -> Option<Self> {
42        // SAFETY: dimension_sizes.as_ptr() is valid for dimension_sizes.len() elements.
43        let ptr = unsafe {
44            ffi::mps_ndarray_descriptor_new_with_dimension_sizes(
45                data_type,
46                dimension_sizes.len(),
47                dimension_sizes.as_ptr(),
48            )
49        };
50        if ptr.is_null() {
51            None
52        } else {
53            Some(Self { ptr })
54        }
55    }
56
57    #[must_use]
58    pub fn data_type(&self) -> u32 {
59        // SAFETY: self.ptr is a valid NDArrayDescriptor.
60        unsafe { ffi::mps_ndarray_descriptor_data_type(self.ptr) }
61    }
62
63    pub fn set_data_type(&self, data_type: u32) {
64        // SAFETY: self.ptr is a valid NDArrayDescriptor.
65        unsafe { ffi::mps_ndarray_descriptor_set_data_type(self.ptr, data_type) };
66    }
67
68    #[must_use]
69    pub fn number_of_dimensions(&self) -> usize {
70        // SAFETY: self.ptr is a valid NDArrayDescriptor.
71        unsafe { ffi::mps_ndarray_descriptor_number_of_dimensions(self.ptr) }
72    }
73
74    pub fn set_number_of_dimensions(&self, number_of_dimensions: usize) {
75        // SAFETY: self.ptr is a valid NDArrayDescriptor.
76        unsafe {
77            ffi::mps_ndarray_descriptor_set_number_of_dimensions(self.ptr, number_of_dimensions);
78        };
79    }
80
81    #[must_use]
82    pub fn length_of_dimension(&self, dimension_index: usize) -> usize {
83        // SAFETY: self.ptr is a valid NDArrayDescriptor and dimension_index is in bounds.
84        unsafe { ffi::mps_ndarray_descriptor_length_of_dimension(self.ptr, dimension_index) }
85    }
86
87    pub fn reshape_with_dimension_sizes(&self, dimension_sizes: &[usize]) {
88        // SAFETY: dimension_sizes.as_ptr() is valid for dimension_sizes.len() elements.
89        unsafe {
90            ffi::mps_ndarray_descriptor_reshape_with_dimension_sizes(
91                self.ptr,
92                dimension_sizes.len(),
93                dimension_sizes.as_ptr(),
94            );
95        };
96    }
97
98    pub fn transpose_dimension(&self, dimension_index: usize, other_dimension_index: usize) {
99        // SAFETY: Both dimension indices are validated by MPS.
100        unsafe {
101            ffi::mps_ndarray_descriptor_transpose_dimension(
102                self.ptr,
103                dimension_index,
104                other_dimension_index,
105            );
106        };
107    }
108}
109
110opaque_handle!(NDArray);
111impl NDArray {
112    #[must_use]
113    pub fn new(device: &MetalDevice, descriptor: &NDArrayDescriptor) -> Option<Self> {
114        // SAFETY: Both pointers come from safe wrappers and are valid for the call.
115        let ptr =
116            unsafe { ffi::mps_ndarray_new_with_descriptor(device.as_ptr(), descriptor.as_ptr()) };
117        if ptr.is_null() {
118            None
119        } else {
120            Some(Self { ptr })
121        }
122    }
123
124    #[must_use]
125    pub fn scalar(device: &MetalDevice, value: f64) -> Option<Self> {
126        // SAFETY: device pointer is valid and we return null or a +1 retained NDArray.
127        let ptr = unsafe { ffi::mps_ndarray_new_scalar(device.as_ptr(), value) };
128        if ptr.is_null() {
129            None
130        } else {
131            Some(Self { ptr })
132        }
133    }
134
135    #[must_use]
136    pub fn new_with_buffer(
137        buffer: &MetalBuffer,
138        offset: usize,
139        descriptor: &NDArrayDescriptor,
140    ) -> Option<Self> {
141        let ptr = unsafe {
142            ffi::mps_ndarray_new_with_buffer(buffer.as_ptr(), offset, descriptor.as_ptr())
143        };
144        if ptr.is_null() {
145            None
146        } else {
147            Some(Self { ptr })
148        }
149    }
150
151    #[must_use]
152    pub fn data_type(&self) -> u32 {
153        unsafe { ffi::mps_ndarray_data_type(self.ptr) }
154    }
155
156    #[must_use]
157    pub fn number_of_dimensions(&self) -> usize {
158        unsafe { ffi::mps_ndarray_number_of_dimensions(self.ptr) }
159    }
160
161    #[must_use]
162    pub fn length_of_dimension(&self, dimension_index: usize) -> usize {
163        unsafe { ffi::mps_ndarray_length_of_dimension(self.ptr, dimension_index) }
164    }
165
166    #[must_use]
167    pub fn descriptor(&self) -> Option<NDArrayDescriptor> {
168        let ptr = unsafe { ffi::mps_ndarray_descriptor(self.ptr) };
169        if ptr.is_null() {
170            None
171        } else {
172            Some(NDArrayDescriptor { ptr })
173        }
174    }
175
176    #[must_use]
177    pub fn resource_size(&self) -> usize {
178        unsafe { ffi::mps_ndarray_resource_size(self.ptr) }
179    }
180}
181
182opaque_handle!(NDArrayIdentity);
183impl NDArrayIdentity {
184    #[must_use]
185    pub fn new(device: &MetalDevice) -> Option<Self> {
186        let ptr = unsafe { ffi::mps_ndarray_identity_new(device.as_ptr()) };
187        if ptr.is_null() {
188            None
189        } else {
190            Some(Self { ptr })
191        }
192    }
193
194    #[must_use]
195    pub fn reshape(&self, source: &NDArray, dimension_sizes: &[usize]) -> Option<NDArray> {
196        let ptr = unsafe {
197            ffi::mps_ndarray_identity_reshape(
198                self.ptr,
199                ptr::null_mut(),
200                source.as_ptr(),
201                dimension_sizes.len(),
202                dimension_sizes.as_ptr(),
203                ptr::null_mut(),
204            )
205        };
206        if ptr.is_null() {
207            None
208        } else {
209            Some(NDArray { ptr })
210        }
211    }
212
213    #[must_use]
214    pub fn reshape_with_command_buffer(
215        &self,
216        command_buffer: &MetalCommandBuffer,
217        source: &NDArray,
218        dimension_sizes: &[usize],
219    ) -> Option<NDArray> {
220        let ptr = unsafe {
221            ffi::mps_ndarray_identity_reshape(
222                self.ptr,
223                command_buffer.as_ptr(),
224                source.as_ptr(),
225                dimension_sizes.len(),
226                dimension_sizes.as_ptr(),
227                ptr::null_mut(),
228            )
229        };
230        if ptr.is_null() {
231            None
232        } else {
233            Some(NDArray { ptr })
234        }
235    }
236
237    pub fn reshape_into(
238        &self,
239        command_buffer: Option<&MetalCommandBuffer>,
240        source: &NDArray,
241        dimension_sizes: &[usize],
242        destination: &NDArray,
243    ) -> bool {
244        let command_buffer_ptr = command_buffer.map_or(ptr::null_mut(), MetalCommandBuffer::as_ptr);
245        let ptr = unsafe {
246            ffi::mps_ndarray_identity_reshape(
247                self.ptr,
248                command_buffer_ptr,
249                source.as_ptr(),
250                dimension_sizes.len(),
251                dimension_sizes.as_ptr(),
252                destination.as_ptr(),
253            )
254        };
255        !ptr.is_null()
256    }
257}
258
259opaque_handle!(NDArrayMatrixMultiplication);
260impl NDArrayMatrixMultiplication {
261    #[must_use]
262    pub fn new(device: &MetalDevice, source_count: usize) -> Option<Self> {
263        let ptr =
264            unsafe { ffi::mps_ndarray_matrix_multiplication_new(device.as_ptr(), source_count) };
265        if ptr.is_null() {
266            None
267        } else {
268            Some(Self { ptr })
269        }
270    }
271
272    #[must_use]
273    pub fn alpha(&self) -> f64 {
274        unsafe { ffi::mps_ndarray_matrix_multiplication_alpha(self.ptr) }
275    }
276
277    pub fn set_alpha(&self, alpha: f64) {
278        unsafe { ffi::mps_ndarray_matrix_multiplication_set_alpha(self.ptr, alpha) };
279    }
280
281    #[must_use]
282    pub fn beta(&self) -> f64 {
283        unsafe { ffi::mps_ndarray_matrix_multiplication_beta(self.ptr) }
284    }
285
286    pub fn set_beta(&self, beta: f64) {
287        unsafe { ffi::mps_ndarray_matrix_multiplication_set_beta(self.ptr, beta) };
288    }
289
290    #[must_use]
291    pub fn encode(
292        &self,
293        command_buffer: &MetalCommandBuffer,
294        source_arrays: &[&NDArray],
295    ) -> Option<NDArray> {
296        let handles: Vec<_> = source_arrays.iter().map(|array| array.as_ptr()).collect();
297        let handles_ptr = if handles.is_empty() {
298            ptr::null()
299        } else {
300            handles.as_ptr()
301        };
302        let ptr = unsafe {
303            ffi::mps_ndarray_matrix_multiplication_encode(
304                self.ptr,
305                command_buffer.as_ptr(),
306                source_arrays.len(),
307                handles_ptr,
308            )
309        };
310        if ptr.is_null() {
311            None
312        } else {
313            Some(NDArray { ptr })
314        }
315    }
316
317    pub fn encode_to_destination(
318        &self,
319        command_buffer: &MetalCommandBuffer,
320        source_arrays: &[&NDArray],
321        destination: &NDArray,
322    ) {
323        let handles: Vec<_> = source_arrays.iter().map(|array| array.as_ptr()).collect();
324        let handles_ptr = if handles.is_empty() {
325            ptr::null()
326        } else {
327            handles.as_ptr()
328        };
329        unsafe {
330            ffi::mps_ndarray_matrix_multiplication_encode_to_destination(
331                self.ptr,
332                command_buffer.as_ptr(),
333                source_arrays.len(),
334                handles_ptr,
335                destination.as_ptr(),
336            );
337        };
338    }
339}