Skip to main content

apple_mps/
ndarray.rs

1use crate::ffi;
2use apple_metal::{CommandBuffer as MetalCommandBuffer, MetalBuffer, MetalDevice};
3use core::ffi::c_void;
4use core::ptr;
5
/// Declares a public owning wrapper around an opaque FFI object pointer.
///
/// The generated type:
/// - keeps the raw pointer in a private `ptr` field and exposes it through `as_ptr`;
/// - calls `ffi::mps_object_release` exactly once on drop, and only for a non-null pointer;
/// - is marked `Send` and `Sync`.
///
/// NOTE(review): some wrappers built with this macro expose setters taking `&self`. Combined
/// with the blanket `Sync` impl, that permits unsynchronized mutation from several threads.
/// Confirm the underlying MPS objects tolerate this, or document that callers must synchronize.
macro_rules! opaque_handle {
    ($name:ident) => {
        pub struct $name {
            ptr: *mut c_void,
        }

        // SAFETY (assumed, TODO confirm): the wrapper only stores a pointer to a
        // reference-counted foreign object, which is presumed safe to move and share
        // across threads.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                // Take the pointer out first so the field never holds a released handle.
                let handle = core::mem::replace(&mut self.ptr, ptr::null_mut());
                if handle.is_null() {
                    return;
                }
                unsafe { ffi::mps_object_release(handle) };
            }
        }

        impl $name {
            /// Returns the wrapped raw pointer without transferring ownership.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
32
33opaque_handle!(NDArrayDescriptor);
34impl NDArrayDescriptor {
35    #[must_use]
36    pub fn with_dimension_sizes(data_type: u32, dimension_sizes: &[usize]) -> Option<Self> {
37        let ptr = unsafe {
38            ffi::mps_ndarray_descriptor_new_with_dimension_sizes(
39                data_type,
40                dimension_sizes.len(),
41                dimension_sizes.as_ptr(),
42            )
43        };
44        if ptr.is_null() {
45            None
46        } else {
47            Some(Self { ptr })
48        }
49    }
50
51    #[must_use]
52    pub fn data_type(&self) -> u32 {
53        unsafe { ffi::mps_ndarray_descriptor_data_type(self.ptr) }
54    }
55
56    pub fn set_data_type(&self, data_type: u32) {
57        unsafe { ffi::mps_ndarray_descriptor_set_data_type(self.ptr, data_type) };
58    }
59
60    #[must_use]
61    pub fn number_of_dimensions(&self) -> usize {
62        unsafe { ffi::mps_ndarray_descriptor_number_of_dimensions(self.ptr) }
63    }
64
65    pub fn set_number_of_dimensions(&self, number_of_dimensions: usize) {
66        unsafe {
67            ffi::mps_ndarray_descriptor_set_number_of_dimensions(self.ptr, number_of_dimensions);
68        };
69    }
70
71    #[must_use]
72    pub fn length_of_dimension(&self, dimension_index: usize) -> usize {
73        unsafe { ffi::mps_ndarray_descriptor_length_of_dimension(self.ptr, dimension_index) }
74    }
75
76    pub fn reshape_with_dimension_sizes(&self, dimension_sizes: &[usize]) {
77        unsafe {
78            ffi::mps_ndarray_descriptor_reshape_with_dimension_sizes(
79                self.ptr,
80                dimension_sizes.len(),
81                dimension_sizes.as_ptr(),
82            );
83        };
84    }
85
86    pub fn transpose_dimension(&self, dimension_index: usize, other_dimension_index: usize) {
87        unsafe {
88            ffi::mps_ndarray_descriptor_transpose_dimension(
89                self.ptr,
90                dimension_index,
91                other_dimension_index,
92            );
93        };
94    }
95}
96
97opaque_handle!(NDArray);
98impl NDArray {
99    #[must_use]
100    pub fn new(device: &MetalDevice, descriptor: &NDArrayDescriptor) -> Option<Self> {
101        let ptr =
102            unsafe { ffi::mps_ndarray_new_with_descriptor(device.as_ptr(), descriptor.as_ptr()) };
103        if ptr.is_null() {
104            None
105        } else {
106            Some(Self { ptr })
107        }
108    }
109
110    #[must_use]
111    pub fn scalar(device: &MetalDevice, value: f64) -> Option<Self> {
112        let ptr = unsafe { ffi::mps_ndarray_new_scalar(device.as_ptr(), value) };
113        if ptr.is_null() {
114            None
115        } else {
116            Some(Self { ptr })
117        }
118    }
119
120    #[must_use]
121    pub fn new_with_buffer(
122        buffer: &MetalBuffer,
123        offset: usize,
124        descriptor: &NDArrayDescriptor,
125    ) -> Option<Self> {
126        let ptr = unsafe {
127            ffi::mps_ndarray_new_with_buffer(buffer.as_ptr(), offset, descriptor.as_ptr())
128        };
129        if ptr.is_null() {
130            None
131        } else {
132            Some(Self { ptr })
133        }
134    }
135
136    #[must_use]
137    pub fn data_type(&self) -> u32 {
138        unsafe { ffi::mps_ndarray_data_type(self.ptr) }
139    }
140
141    #[must_use]
142    pub fn number_of_dimensions(&self) -> usize {
143        unsafe { ffi::mps_ndarray_number_of_dimensions(self.ptr) }
144    }
145
146    #[must_use]
147    pub fn length_of_dimension(&self, dimension_index: usize) -> usize {
148        unsafe { ffi::mps_ndarray_length_of_dimension(self.ptr, dimension_index) }
149    }
150
151    #[must_use]
152    pub fn descriptor(&self) -> Option<NDArrayDescriptor> {
153        let ptr = unsafe { ffi::mps_ndarray_descriptor(self.ptr) };
154        if ptr.is_null() {
155            None
156        } else {
157            Some(NDArrayDescriptor { ptr })
158        }
159    }
160
161    #[must_use]
162    pub fn resource_size(&self) -> usize {
163        unsafe { ffi::mps_ndarray_resource_size(self.ptr) }
164    }
165}
166
167opaque_handle!(NDArrayIdentity);
168impl NDArrayIdentity {
169    #[must_use]
170    pub fn new(device: &MetalDevice) -> Option<Self> {
171        let ptr = unsafe { ffi::mps_ndarray_identity_new(device.as_ptr()) };
172        if ptr.is_null() {
173            None
174        } else {
175            Some(Self { ptr })
176        }
177    }
178
179    #[must_use]
180    pub fn reshape(&self, source: &NDArray, dimension_sizes: &[usize]) -> Option<NDArray> {
181        let ptr = unsafe {
182            ffi::mps_ndarray_identity_reshape(
183                self.ptr,
184                ptr::null_mut(),
185                source.as_ptr(),
186                dimension_sizes.len(),
187                dimension_sizes.as_ptr(),
188                ptr::null_mut(),
189            )
190        };
191        if ptr.is_null() {
192            None
193        } else {
194            Some(NDArray { ptr })
195        }
196    }
197
198    #[must_use]
199    pub fn reshape_with_command_buffer(
200        &self,
201        command_buffer: &MetalCommandBuffer,
202        source: &NDArray,
203        dimension_sizes: &[usize],
204    ) -> Option<NDArray> {
205        let ptr = unsafe {
206            ffi::mps_ndarray_identity_reshape(
207                self.ptr,
208                command_buffer.as_ptr(),
209                source.as_ptr(),
210                dimension_sizes.len(),
211                dimension_sizes.as_ptr(),
212                ptr::null_mut(),
213            )
214        };
215        if ptr.is_null() {
216            None
217        } else {
218            Some(NDArray { ptr })
219        }
220    }
221
222    pub fn reshape_into(
223        &self,
224        command_buffer: Option<&MetalCommandBuffer>,
225        source: &NDArray,
226        dimension_sizes: &[usize],
227        destination: &NDArray,
228    ) -> bool {
229        let command_buffer_ptr = command_buffer.map_or(ptr::null_mut(), MetalCommandBuffer::as_ptr);
230        let ptr = unsafe {
231            ffi::mps_ndarray_identity_reshape(
232                self.ptr,
233                command_buffer_ptr,
234                source.as_ptr(),
235                dimension_sizes.len(),
236                dimension_sizes.as_ptr(),
237                destination.as_ptr(),
238            )
239        };
240        !ptr.is_null()
241    }
242}
243
244opaque_handle!(NDArrayMatrixMultiplication);
245impl NDArrayMatrixMultiplication {
246    #[must_use]
247    pub fn new(device: &MetalDevice, source_count: usize) -> Option<Self> {
248        let ptr = unsafe { ffi::mps_ndarray_matrix_multiplication_new(device.as_ptr(), source_count) };
249        if ptr.is_null() {
250            None
251        } else {
252            Some(Self { ptr })
253        }
254    }
255
256    #[must_use]
257    pub fn alpha(&self) -> f64 {
258        unsafe { ffi::mps_ndarray_matrix_multiplication_alpha(self.ptr) }
259    }
260
261    pub fn set_alpha(&self, alpha: f64) {
262        unsafe { ffi::mps_ndarray_matrix_multiplication_set_alpha(self.ptr, alpha) };
263    }
264
265    #[must_use]
266    pub fn beta(&self) -> f64 {
267        unsafe { ffi::mps_ndarray_matrix_multiplication_beta(self.ptr) }
268    }
269
270    pub fn set_beta(&self, beta: f64) {
271        unsafe { ffi::mps_ndarray_matrix_multiplication_set_beta(self.ptr, beta) };
272    }
273
274    #[must_use]
275    pub fn encode(
276        &self,
277        command_buffer: &MetalCommandBuffer,
278        source_arrays: &[&NDArray],
279    ) -> Option<NDArray> {
280        let handles: Vec<_> = source_arrays.iter().map(|array| array.as_ptr()).collect();
281        let handles_ptr = if handles.is_empty() {
282            ptr::null()
283        } else {
284            handles.as_ptr()
285        };
286        let ptr = unsafe {
287            ffi::mps_ndarray_matrix_multiplication_encode(
288                self.ptr,
289                command_buffer.as_ptr(),
290                source_arrays.len(),
291                handles_ptr,
292            )
293        };
294        if ptr.is_null() {
295            None
296        } else {
297            Some(NDArray { ptr })
298        }
299    }
300
301    pub fn encode_to_destination(
302        &self,
303        command_buffer: &MetalCommandBuffer,
304        source_arrays: &[&NDArray],
305        destination: &NDArray,
306    ) {
307        let handles: Vec<_> = source_arrays.iter().map(|array| array.as_ptr()).collect();
308        let handles_ptr = if handles.is_empty() {
309            ptr::null()
310        } else {
311            handles.as_ptr()
312        };
313        unsafe {
314            ffi::mps_ndarray_matrix_multiplication_encode_to_destination(
315                self.ptr,
316                command_buffer.as_ptr(),
317                source_arrays.len(),
318                handles_ptr,
319                destination.as_ptr(),
320            );
321        };
322    }
323}