1use crate::ffi;
2use apple_metal::{CommandBuffer as MetalCommandBuffer, MetalBuffer, MetalDevice};
3use core::ffi::c_void;
4use core::ptr;
5
/// Declares a `pub` owning wrapper type around an opaque MPS object pointer.
///
/// The generated type stores a single `*mut c_void`, exposes it through
/// `as_ptr`, and passes it to `ffi::mps_object_release` when dropped.
macro_rules! opaque_handle {
    ($name:ident) => {
        /// Owning handle to an opaque Metal Performance Shaders object.
        ///
        /// The wrapped pointer is released through `ffi::mps_object_release`
        /// when this value is dropped.
        pub struct $name {
            ptr: *mut c_void,
        }

        // SAFETY: NOTE(review) — these impls assume the wrapped MPS object may be
        // used and released from any thread; confirm against the FFI shim.
        // Some wrappers expose setters taking `&self`, so `Sync` also permits
        // concurrent mutation through shared references.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}

        impl $name {
            /// Returns the raw object pointer; ownership stays with `self`.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }

        impl Drop for $name {
            fn drop(&mut self) {
                let raw = self.ptr;
                if raw.is_null() {
                    return;
                }
                // Clear the field before releasing so the handle never holds a
                // pointer that has already been released.
                self.ptr = ptr::null_mut();
                // SAFETY: `raw` is non-null; the wrapper assumes it holds an owned
                // reference, since constructors store FFI results directly.
                unsafe { ffi::mps_object_release(raw) };
            }
        }
    };
}
35
36pub use crate::generated::ndarray::*;
37
38opaque_handle!(NDArrayDescriptor);
39impl NDArrayDescriptor {
40 #[must_use]
41 pub fn with_dimension_sizes(data_type: u32, dimension_sizes: &[usize]) -> Option<Self> {
42 let ptr = unsafe {
44 ffi::mps_ndarray_descriptor_new_with_dimension_sizes(
45 data_type,
46 dimension_sizes.len(),
47 dimension_sizes.as_ptr(),
48 )
49 };
50 if ptr.is_null() {
51 None
52 } else {
53 Some(Self { ptr })
54 }
55 }
56
57 #[must_use]
58 pub fn data_type(&self) -> u32 {
59 unsafe { ffi::mps_ndarray_descriptor_data_type(self.ptr) }
61 }
62
63 pub fn set_data_type(&self, data_type: u32) {
64 unsafe { ffi::mps_ndarray_descriptor_set_data_type(self.ptr, data_type) };
66 }
67
68 #[must_use]
69 pub fn number_of_dimensions(&self) -> usize {
70 unsafe { ffi::mps_ndarray_descriptor_number_of_dimensions(self.ptr) }
72 }
73
74 pub fn set_number_of_dimensions(&self, number_of_dimensions: usize) {
75 unsafe {
77 ffi::mps_ndarray_descriptor_set_number_of_dimensions(self.ptr, number_of_dimensions);
78 };
79 }
80
81 #[must_use]
82 pub fn length_of_dimension(&self, dimension_index: usize) -> usize {
83 unsafe { ffi::mps_ndarray_descriptor_length_of_dimension(self.ptr, dimension_index) }
85 }
86
87 pub fn reshape_with_dimension_sizes(&self, dimension_sizes: &[usize]) {
88 unsafe {
90 ffi::mps_ndarray_descriptor_reshape_with_dimension_sizes(
91 self.ptr,
92 dimension_sizes.len(),
93 dimension_sizes.as_ptr(),
94 );
95 };
96 }
97
98 pub fn transpose_dimension(&self, dimension_index: usize, other_dimension_index: usize) {
99 unsafe {
101 ffi::mps_ndarray_descriptor_transpose_dimension(
102 self.ptr,
103 dimension_index,
104 other_dimension_index,
105 );
106 };
107 }
108}
109
110opaque_handle!(NDArray);
111impl NDArray {
112 #[must_use]
113 pub fn new(device: &MetalDevice, descriptor: &NDArrayDescriptor) -> Option<Self> {
114 let ptr =
116 unsafe { ffi::mps_ndarray_new_with_descriptor(device.as_ptr(), descriptor.as_ptr()) };
117 if ptr.is_null() {
118 None
119 } else {
120 Some(Self { ptr })
121 }
122 }
123
124 #[must_use]
125 pub fn scalar(device: &MetalDevice, value: f64) -> Option<Self> {
126 let ptr = unsafe { ffi::mps_ndarray_new_scalar(device.as_ptr(), value) };
128 if ptr.is_null() {
129 None
130 } else {
131 Some(Self { ptr })
132 }
133 }
134
135 #[must_use]
136 pub fn new_with_buffer(
137 buffer: &MetalBuffer,
138 offset: usize,
139 descriptor: &NDArrayDescriptor,
140 ) -> Option<Self> {
141 let ptr = unsafe {
142 ffi::mps_ndarray_new_with_buffer(buffer.as_ptr(), offset, descriptor.as_ptr())
143 };
144 if ptr.is_null() {
145 None
146 } else {
147 Some(Self { ptr })
148 }
149 }
150
151 #[must_use]
152 pub fn data_type(&self) -> u32 {
153 unsafe { ffi::mps_ndarray_data_type(self.ptr) }
154 }
155
156 #[must_use]
157 pub fn number_of_dimensions(&self) -> usize {
158 unsafe { ffi::mps_ndarray_number_of_dimensions(self.ptr) }
159 }
160
161 #[must_use]
162 pub fn length_of_dimension(&self, dimension_index: usize) -> usize {
163 unsafe { ffi::mps_ndarray_length_of_dimension(self.ptr, dimension_index) }
164 }
165
166 #[must_use]
167 pub fn descriptor(&self) -> Option<NDArrayDescriptor> {
168 let ptr = unsafe { ffi::mps_ndarray_descriptor(self.ptr) };
169 if ptr.is_null() {
170 None
171 } else {
172 Some(NDArrayDescriptor { ptr })
173 }
174 }
175
176 #[must_use]
177 pub fn resource_size(&self) -> usize {
178 unsafe { ffi::mps_ndarray_resource_size(self.ptr) }
179 }
180}
181
182opaque_handle!(NDArrayIdentity);
183impl NDArrayIdentity {
184 #[must_use]
185 pub fn new(device: &MetalDevice) -> Option<Self> {
186 let ptr = unsafe { ffi::mps_ndarray_identity_new(device.as_ptr()) };
187 if ptr.is_null() {
188 None
189 } else {
190 Some(Self { ptr })
191 }
192 }
193
194 #[must_use]
195 pub fn reshape(&self, source: &NDArray, dimension_sizes: &[usize]) -> Option<NDArray> {
196 let ptr = unsafe {
197 ffi::mps_ndarray_identity_reshape(
198 self.ptr,
199 ptr::null_mut(),
200 source.as_ptr(),
201 dimension_sizes.len(),
202 dimension_sizes.as_ptr(),
203 ptr::null_mut(),
204 )
205 };
206 if ptr.is_null() {
207 None
208 } else {
209 Some(NDArray { ptr })
210 }
211 }
212
213 #[must_use]
214 pub fn reshape_with_command_buffer(
215 &self,
216 command_buffer: &MetalCommandBuffer,
217 source: &NDArray,
218 dimension_sizes: &[usize],
219 ) -> Option<NDArray> {
220 let ptr = unsafe {
221 ffi::mps_ndarray_identity_reshape(
222 self.ptr,
223 command_buffer.as_ptr(),
224 source.as_ptr(),
225 dimension_sizes.len(),
226 dimension_sizes.as_ptr(),
227 ptr::null_mut(),
228 )
229 };
230 if ptr.is_null() {
231 None
232 } else {
233 Some(NDArray { ptr })
234 }
235 }
236
237 pub fn reshape_into(
238 &self,
239 command_buffer: Option<&MetalCommandBuffer>,
240 source: &NDArray,
241 dimension_sizes: &[usize],
242 destination: &NDArray,
243 ) -> bool {
244 let command_buffer_ptr = command_buffer.map_or(ptr::null_mut(), MetalCommandBuffer::as_ptr);
245 let ptr = unsafe {
246 ffi::mps_ndarray_identity_reshape(
247 self.ptr,
248 command_buffer_ptr,
249 source.as_ptr(),
250 dimension_sizes.len(),
251 dimension_sizes.as_ptr(),
252 destination.as_ptr(),
253 )
254 };
255 !ptr.is_null()
256 }
257}
258
259opaque_handle!(NDArrayMatrixMultiplication);
260impl NDArrayMatrixMultiplication {
261 #[must_use]
262 pub fn new(device: &MetalDevice, source_count: usize) -> Option<Self> {
263 let ptr =
264 unsafe { ffi::mps_ndarray_matrix_multiplication_new(device.as_ptr(), source_count) };
265 if ptr.is_null() {
266 None
267 } else {
268 Some(Self { ptr })
269 }
270 }
271
272 #[must_use]
273 pub fn alpha(&self) -> f64 {
274 unsafe { ffi::mps_ndarray_matrix_multiplication_alpha(self.ptr) }
275 }
276
277 pub fn set_alpha(&self, alpha: f64) {
278 unsafe { ffi::mps_ndarray_matrix_multiplication_set_alpha(self.ptr, alpha) };
279 }
280
281 #[must_use]
282 pub fn beta(&self) -> f64 {
283 unsafe { ffi::mps_ndarray_matrix_multiplication_beta(self.ptr) }
284 }
285
286 pub fn set_beta(&self, beta: f64) {
287 unsafe { ffi::mps_ndarray_matrix_multiplication_set_beta(self.ptr, beta) };
288 }
289
290 #[must_use]
291 pub fn encode(
292 &self,
293 command_buffer: &MetalCommandBuffer,
294 source_arrays: &[&NDArray],
295 ) -> Option<NDArray> {
296 let handles: Vec<_> = source_arrays.iter().map(|array| array.as_ptr()).collect();
297 let handles_ptr = if handles.is_empty() {
298 ptr::null()
299 } else {
300 handles.as_ptr()
301 };
302 let ptr = unsafe {
303 ffi::mps_ndarray_matrix_multiplication_encode(
304 self.ptr,
305 command_buffer.as_ptr(),
306 source_arrays.len(),
307 handles_ptr,
308 )
309 };
310 if ptr.is_null() {
311 None
312 } else {
313 Some(NDArray { ptr })
314 }
315 }
316
317 pub fn encode_to_destination(
318 &self,
319 command_buffer: &MetalCommandBuffer,
320 source_arrays: &[&NDArray],
321 destination: &NDArray,
322 ) {
323 let handles: Vec<_> = source_arrays.iter().map(|array| array.as_ptr()).collect();
324 let handles_ptr = if handles.is_empty() {
325 ptr::null()
326 } else {
327 handles.as_ptr()
328 };
329 unsafe {
330 ffi::mps_ndarray_matrix_multiplication_encode_to_destination(
331 self.ptr,
332 command_buffer.as_ptr(),
333 source_arrays.len(),
334 handles_ptr,
335 destination.as_ptr(),
336 );
337 };
338 }
339}