1use crate::ffi;
2use apple_metal::{CommandBuffer as MetalCommandBuffer, MetalBuffer, MetalDevice};
3use core::ffi::c_void;
4use core::ptr;
5
// Generates a public wrapper struct around a raw `*mut c_void` handle.
//
// The generated type:
// - stores the pointer in a private `ptr` field,
// - exposes it read-only through `as_ptr`,
// - passes a non-null pointer to `ffi::mps_object_release` exactly once on drop,
// - is marked `Send` and `Sync`.
//
// `$handle` is the name of the generated struct; `$summary` becomes its doc string.
macro_rules! opaque_handle {
    ($handle:ident, $summary:expr) => {
        #[doc = $summary]
        pub struct $handle {
            ptr: *mut c_void,
        }

        impl $handle {
            /// Returns the raw pointer stored in this handle.
            ///
            /// The handle keeps its pointer; it is still passed to the release
            /// function when the handle is dropped.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }

        impl Drop for $handle {
            fn drop(&mut self) {
                // A null handle owns nothing, so there is nothing to release.
                if self.ptr.is_null() {
                    return;
                }
                // SAFETY: the pointer is non-null here. NOTE(review): this assumes the
                // handle owns one reference to the underlying object — confirm against
                // the bridge that produces these pointers.
                unsafe { ffi::mps_object_release(self.ptr) };
                // Clear the field so the released pointer cannot be observed again.
                self.ptr = ptr::null_mut();
            }
        }

        // SAFETY: NOTE(review): nothing in this file proves the wrapped objects are
        // thread-safe; these impls assume the underlying object may be moved to and
        // shared between threads — confirm against the wrapped framework's guarantees.
        unsafe impl Send for $handle {}
        unsafe impl Sync for $handle {}
    };
}
37
38#[doc(hidden)]
39pub use crate::generated::ndarray::*;
40
41opaque_handle!(NDArrayDescriptor, "Wraps `MPSNDArrayDescriptor`.");
42impl NDArrayDescriptor {
43 #[must_use]
45 pub fn with_dimension_sizes(data_type: u32, dimension_sizes: &[usize]) -> Option<Self> {
46 let ptr = unsafe {
48 ffi::mps_ndarray_descriptor_new_with_dimension_sizes(
49 data_type,
50 dimension_sizes.len(),
51 dimension_sizes.as_ptr(),
52 )
53 };
54 if ptr.is_null() {
55 None
56 } else {
57 Some(Self { ptr })
58 }
59 }
60
61 #[must_use]
63 pub fn data_type(&self) -> u32 {
64 unsafe { ffi::mps_ndarray_descriptor_data_type(self.ptr) }
66 }
67
68 pub fn set_data_type(&self, data_type: u32) {
70 unsafe { ffi::mps_ndarray_descriptor_set_data_type(self.ptr, data_type) };
72 }
73
74 #[must_use]
76 pub fn number_of_dimensions(&self) -> usize {
77 unsafe { ffi::mps_ndarray_descriptor_number_of_dimensions(self.ptr) }
79 }
80
81 pub fn set_number_of_dimensions(&self, number_of_dimensions: usize) {
83 unsafe {
85 ffi::mps_ndarray_descriptor_set_number_of_dimensions(self.ptr, number_of_dimensions);
86 };
87 }
88
89 #[must_use]
91 pub fn length_of_dimension(&self, dimension_index: usize) -> usize {
92 unsafe { ffi::mps_ndarray_descriptor_length_of_dimension(self.ptr, dimension_index) }
94 }
95
96 pub fn reshape_with_dimension_sizes(&self, dimension_sizes: &[usize]) {
98 unsafe {
100 ffi::mps_ndarray_descriptor_reshape_with_dimension_sizes(
101 self.ptr,
102 dimension_sizes.len(),
103 dimension_sizes.as_ptr(),
104 );
105 };
106 }
107
108 pub fn transpose_dimension(&self, dimension_index: usize, other_dimension_index: usize) {
110 unsafe {
112 ffi::mps_ndarray_descriptor_transpose_dimension(
113 self.ptr,
114 dimension_index,
115 other_dimension_index,
116 );
117 };
118 }
119}
120
121opaque_handle!(NDArray, "Wraps `MPSNDArray`.");
122impl NDArray {
123 #[must_use]
125 pub fn new(device: &MetalDevice, descriptor: &NDArrayDescriptor) -> Option<Self> {
126 let ptr =
128 unsafe { ffi::mps_ndarray_new_with_descriptor(device.as_ptr(), descriptor.as_ptr()) };
129 if ptr.is_null() {
130 None
131 } else {
132 Some(Self { ptr })
133 }
134 }
135
136 #[must_use]
138 pub fn scalar(device: &MetalDevice, value: f64) -> Option<Self> {
139 let ptr = unsafe { ffi::mps_ndarray_new_scalar(device.as_ptr(), value) };
141 if ptr.is_null() {
142 None
143 } else {
144 Some(Self { ptr })
145 }
146 }
147
148 #[must_use]
150 pub fn new_with_buffer(
151 buffer: &MetalBuffer,
152 offset: usize,
153 descriptor: &NDArrayDescriptor,
154 ) -> Option<Self> {
155 let ptr = unsafe {
156 ffi::mps_ndarray_new_with_buffer(buffer.as_ptr(), offset, descriptor.as_ptr())
157 };
158 if ptr.is_null() {
159 None
160 } else {
161 Some(Self { ptr })
162 }
163 }
164
165 #[must_use]
167 pub fn data_type(&self) -> u32 {
168 unsafe { ffi::mps_ndarray_data_type(self.ptr) }
169 }
170
171 #[must_use]
173 pub fn number_of_dimensions(&self) -> usize {
174 unsafe { ffi::mps_ndarray_number_of_dimensions(self.ptr) }
175 }
176
177 #[must_use]
179 pub fn length_of_dimension(&self, dimension_index: usize) -> usize {
180 unsafe { ffi::mps_ndarray_length_of_dimension(self.ptr, dimension_index) }
181 }
182
183 #[must_use]
185 pub fn descriptor(&self) -> Option<NDArrayDescriptor> {
186 let ptr = unsafe { ffi::mps_ndarray_descriptor(self.ptr) };
187 if ptr.is_null() {
188 None
189 } else {
190 Some(NDArrayDescriptor { ptr })
191 }
192 }
193
194 #[must_use]
196 pub fn resource_size(&self) -> usize {
197 unsafe { ffi::mps_ndarray_resource_size(self.ptr) }
198 }
199}
200
201opaque_handle!(NDArrayIdentity, "Wraps `MPSNDArrayIdentity`.");
202impl NDArrayIdentity {
203 #[must_use]
205 pub fn new(device: &MetalDevice) -> Option<Self> {
206 let ptr = unsafe { ffi::mps_ndarray_identity_new(device.as_ptr()) };
207 if ptr.is_null() {
208 None
209 } else {
210 Some(Self { ptr })
211 }
212 }
213
214 #[must_use]
216 pub fn reshape(&self, source: &NDArray, dimension_sizes: &[usize]) -> Option<NDArray> {
217 let ptr = unsafe {
218 ffi::mps_ndarray_identity_reshape(
219 self.ptr,
220 ptr::null_mut(),
221 source.as_ptr(),
222 dimension_sizes.len(),
223 dimension_sizes.as_ptr(),
224 ptr::null_mut(),
225 )
226 };
227 if ptr.is_null() {
228 None
229 } else {
230 Some(NDArray { ptr })
231 }
232 }
233
234 #[must_use]
236 pub fn reshape_with_command_buffer(
237 &self,
238 command_buffer: &MetalCommandBuffer,
239 source: &NDArray,
240 dimension_sizes: &[usize],
241 ) -> Option<NDArray> {
242 let ptr = unsafe {
243 ffi::mps_ndarray_identity_reshape(
244 self.ptr,
245 command_buffer.as_ptr(),
246 source.as_ptr(),
247 dimension_sizes.len(),
248 dimension_sizes.as_ptr(),
249 ptr::null_mut(),
250 )
251 };
252 if ptr.is_null() {
253 None
254 } else {
255 Some(NDArray { ptr })
256 }
257 }
258
259 pub fn reshape_into(
261 &self,
262 command_buffer: Option<&MetalCommandBuffer>,
263 source: &NDArray,
264 dimension_sizes: &[usize],
265 destination: &NDArray,
266 ) -> bool {
267 let command_buffer_ptr = command_buffer.map_or(ptr::null_mut(), MetalCommandBuffer::as_ptr);
268 let ptr = unsafe {
269 ffi::mps_ndarray_identity_reshape(
270 self.ptr,
271 command_buffer_ptr,
272 source.as_ptr(),
273 dimension_sizes.len(),
274 dimension_sizes.as_ptr(),
275 destination.as_ptr(),
276 )
277 };
278 !ptr.is_null()
279 }
280}
281
282opaque_handle!(NDArrayMatrixMultiplication, "Wraps `MPSNDArrayMatrixMultiplication`.");
283impl NDArrayMatrixMultiplication {
284 #[must_use]
286 pub fn new(device: &MetalDevice, source_count: usize) -> Option<Self> {
287 let ptr =
288 unsafe { ffi::mps_ndarray_matrix_multiplication_new(device.as_ptr(), source_count) };
289 if ptr.is_null() {
290 None
291 } else {
292 Some(Self { ptr })
293 }
294 }
295
296 #[must_use]
298 pub fn alpha(&self) -> f64 {
299 unsafe { ffi::mps_ndarray_matrix_multiplication_alpha(self.ptr) }
300 }
301
302 pub fn set_alpha(&self, alpha: f64) {
304 unsafe { ffi::mps_ndarray_matrix_multiplication_set_alpha(self.ptr, alpha) };
305 }
306
307 #[must_use]
309 pub fn beta(&self) -> f64 {
310 unsafe { ffi::mps_ndarray_matrix_multiplication_beta(self.ptr) }
311 }
312
313 pub fn set_beta(&self, beta: f64) {
315 unsafe { ffi::mps_ndarray_matrix_multiplication_set_beta(self.ptr, beta) };
316 }
317
318 #[must_use]
320 pub fn encode(
321 &self,
322 command_buffer: &MetalCommandBuffer,
323 source_arrays: &[&NDArray],
324 ) -> Option<NDArray> {
325 let handles: Vec<_> = source_arrays.iter().map(|array| array.as_ptr()).collect();
326 let handles_ptr = if handles.is_empty() {
327 ptr::null()
328 } else {
329 handles.as_ptr()
330 };
331 let ptr = unsafe {
332 ffi::mps_ndarray_matrix_multiplication_encode(
333 self.ptr,
334 command_buffer.as_ptr(),
335 source_arrays.len(),
336 handles_ptr,
337 )
338 };
339 if ptr.is_null() {
340 None
341 } else {
342 Some(NDArray { ptr })
343 }
344 }
345
346 pub fn encode_to_destination(
348 &self,
349 command_buffer: &MetalCommandBuffer,
350 source_arrays: &[&NDArray],
351 destination: &NDArray,
352 ) {
353 let handles: Vec<_> = source_arrays.iter().map(|array| array.as_ptr()).collect();
354 let handles_ptr = if handles.is_empty() {
355 ptr::null()
356 } else {
357 handles.as_ptr()
358 };
359 unsafe {
360 ffi::mps_ndarray_matrix_multiplication_encode_to_destination(
361 self.ptr,
362 command_buffer.as_ptr(),
363 source_arrays.len(),
364 handles_ptr,
365 destination.as_ptr(),
366 );
367 };
368 }
369}