1use crate::ffi;
2use apple_metal::{CommandBuffer as MetalCommandBuffer, MetalBuffer, MetalDevice};
3use core::ffi::c_void;
4use core::ptr;
5
/// Declares a public owning wrapper type around an opaque `*mut c_void` handle.
///
/// The generated type exposes the raw handle through `as_ptr` and, when dropped,
/// passes a non-null handle to `ffi::mps_object_release`.
macro_rules! opaque_handle {
    ($name:ident) => {
        /// Owning wrapper around an opaque FFI object handle.
        pub struct $name {
            ptr: *mut c_void,
        }

        impl $name {
            /// Returns the raw handle; the wrapper keeps ownership of it.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }

        impl Drop for $name {
            fn drop(&mut self) {
                // A null handle owns nothing, so there is nothing to release.
                if self.ptr.is_null() {
                    return;
                }
                unsafe { ffi::mps_object_release(self.ptr) };
                // Clear the field so the released handle cannot be observed afterwards.
                self.ptr = ptr::null_mut();
            }
        }

        // NOTE(review): these impls assert that the wrapped object may be moved to and
        // shared between threads; nothing in this file demonstrates that — confirm
        // against the underlying framework's threading guarantees.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}
    };
}
32
33opaque_handle!(NDArrayDescriptor);
34impl NDArrayDescriptor {
35 #[must_use]
36 pub fn with_dimension_sizes(data_type: u32, dimension_sizes: &[usize]) -> Option<Self> {
37 let ptr = unsafe {
38 ffi::mps_ndarray_descriptor_new_with_dimension_sizes(
39 data_type,
40 dimension_sizes.len(),
41 dimension_sizes.as_ptr(),
42 )
43 };
44 if ptr.is_null() {
45 None
46 } else {
47 Some(Self { ptr })
48 }
49 }
50
51 #[must_use]
52 pub fn data_type(&self) -> u32 {
53 unsafe { ffi::mps_ndarray_descriptor_data_type(self.ptr) }
54 }
55
56 pub fn set_data_type(&self, data_type: u32) {
57 unsafe { ffi::mps_ndarray_descriptor_set_data_type(self.ptr, data_type) };
58 }
59
60 #[must_use]
61 pub fn number_of_dimensions(&self) -> usize {
62 unsafe { ffi::mps_ndarray_descriptor_number_of_dimensions(self.ptr) }
63 }
64
65 pub fn set_number_of_dimensions(&self, number_of_dimensions: usize) {
66 unsafe {
67 ffi::mps_ndarray_descriptor_set_number_of_dimensions(self.ptr, number_of_dimensions);
68 };
69 }
70
71 #[must_use]
72 pub fn length_of_dimension(&self, dimension_index: usize) -> usize {
73 unsafe { ffi::mps_ndarray_descriptor_length_of_dimension(self.ptr, dimension_index) }
74 }
75
76 pub fn reshape_with_dimension_sizes(&self, dimension_sizes: &[usize]) {
77 unsafe {
78 ffi::mps_ndarray_descriptor_reshape_with_dimension_sizes(
79 self.ptr,
80 dimension_sizes.len(),
81 dimension_sizes.as_ptr(),
82 );
83 };
84 }
85
86 pub fn transpose_dimension(&self, dimension_index: usize, other_dimension_index: usize) {
87 unsafe {
88 ffi::mps_ndarray_descriptor_transpose_dimension(
89 self.ptr,
90 dimension_index,
91 other_dimension_index,
92 );
93 };
94 }
95}
96
97opaque_handle!(NDArray);
98impl NDArray {
99 #[must_use]
100 pub fn new(device: &MetalDevice, descriptor: &NDArrayDescriptor) -> Option<Self> {
101 let ptr =
102 unsafe { ffi::mps_ndarray_new_with_descriptor(device.as_ptr(), descriptor.as_ptr()) };
103 if ptr.is_null() {
104 None
105 } else {
106 Some(Self { ptr })
107 }
108 }
109
110 #[must_use]
111 pub fn scalar(device: &MetalDevice, value: f64) -> Option<Self> {
112 let ptr = unsafe { ffi::mps_ndarray_new_scalar(device.as_ptr(), value) };
113 if ptr.is_null() {
114 None
115 } else {
116 Some(Self { ptr })
117 }
118 }
119
120 #[must_use]
121 pub fn new_with_buffer(
122 buffer: &MetalBuffer,
123 offset: usize,
124 descriptor: &NDArrayDescriptor,
125 ) -> Option<Self> {
126 let ptr = unsafe {
127 ffi::mps_ndarray_new_with_buffer(buffer.as_ptr(), offset, descriptor.as_ptr())
128 };
129 if ptr.is_null() {
130 None
131 } else {
132 Some(Self { ptr })
133 }
134 }
135
136 #[must_use]
137 pub fn data_type(&self) -> u32 {
138 unsafe { ffi::mps_ndarray_data_type(self.ptr) }
139 }
140
141 #[must_use]
142 pub fn number_of_dimensions(&self) -> usize {
143 unsafe { ffi::mps_ndarray_number_of_dimensions(self.ptr) }
144 }
145
146 #[must_use]
147 pub fn length_of_dimension(&self, dimension_index: usize) -> usize {
148 unsafe { ffi::mps_ndarray_length_of_dimension(self.ptr, dimension_index) }
149 }
150
151 #[must_use]
152 pub fn descriptor(&self) -> Option<NDArrayDescriptor> {
153 let ptr = unsafe { ffi::mps_ndarray_descriptor(self.ptr) };
154 if ptr.is_null() {
155 None
156 } else {
157 Some(NDArrayDescriptor { ptr })
158 }
159 }
160
161 #[must_use]
162 pub fn resource_size(&self) -> usize {
163 unsafe { ffi::mps_ndarray_resource_size(self.ptr) }
164 }
165}
166
167opaque_handle!(NDArrayIdentity);
168impl NDArrayIdentity {
169 #[must_use]
170 pub fn new(device: &MetalDevice) -> Option<Self> {
171 let ptr = unsafe { ffi::mps_ndarray_identity_new(device.as_ptr()) };
172 if ptr.is_null() {
173 None
174 } else {
175 Some(Self { ptr })
176 }
177 }
178
179 #[must_use]
180 pub fn reshape(&self, source: &NDArray, dimension_sizes: &[usize]) -> Option<NDArray> {
181 let ptr = unsafe {
182 ffi::mps_ndarray_identity_reshape(
183 self.ptr,
184 ptr::null_mut(),
185 source.as_ptr(),
186 dimension_sizes.len(),
187 dimension_sizes.as_ptr(),
188 ptr::null_mut(),
189 )
190 };
191 if ptr.is_null() {
192 None
193 } else {
194 Some(NDArray { ptr })
195 }
196 }
197
198 #[must_use]
199 pub fn reshape_with_command_buffer(
200 &self,
201 command_buffer: &MetalCommandBuffer,
202 source: &NDArray,
203 dimension_sizes: &[usize],
204 ) -> Option<NDArray> {
205 let ptr = unsafe {
206 ffi::mps_ndarray_identity_reshape(
207 self.ptr,
208 command_buffer.as_ptr(),
209 source.as_ptr(),
210 dimension_sizes.len(),
211 dimension_sizes.as_ptr(),
212 ptr::null_mut(),
213 )
214 };
215 if ptr.is_null() {
216 None
217 } else {
218 Some(NDArray { ptr })
219 }
220 }
221
222 pub fn reshape_into(
223 &self,
224 command_buffer: Option<&MetalCommandBuffer>,
225 source: &NDArray,
226 dimension_sizes: &[usize],
227 destination: &NDArray,
228 ) -> bool {
229 let command_buffer_ptr = command_buffer.map_or(ptr::null_mut(), MetalCommandBuffer::as_ptr);
230 let ptr = unsafe {
231 ffi::mps_ndarray_identity_reshape(
232 self.ptr,
233 command_buffer_ptr,
234 source.as_ptr(),
235 dimension_sizes.len(),
236 dimension_sizes.as_ptr(),
237 destination.as_ptr(),
238 )
239 };
240 !ptr.is_null()
241 }
242}
243
244opaque_handle!(NDArrayMatrixMultiplication);
245impl NDArrayMatrixMultiplication {
246 #[must_use]
247 pub fn new(device: &MetalDevice, source_count: usize) -> Option<Self> {
248 let ptr = unsafe { ffi::mps_ndarray_matrix_multiplication_new(device.as_ptr(), source_count) };
249 if ptr.is_null() {
250 None
251 } else {
252 Some(Self { ptr })
253 }
254 }
255
256 #[must_use]
257 pub fn alpha(&self) -> f64 {
258 unsafe { ffi::mps_ndarray_matrix_multiplication_alpha(self.ptr) }
259 }
260
261 pub fn set_alpha(&self, alpha: f64) {
262 unsafe { ffi::mps_ndarray_matrix_multiplication_set_alpha(self.ptr, alpha) };
263 }
264
265 #[must_use]
266 pub fn beta(&self) -> f64 {
267 unsafe { ffi::mps_ndarray_matrix_multiplication_beta(self.ptr) }
268 }
269
270 pub fn set_beta(&self, beta: f64) {
271 unsafe { ffi::mps_ndarray_matrix_multiplication_set_beta(self.ptr, beta) };
272 }
273
274 #[must_use]
275 pub fn encode(
276 &self,
277 command_buffer: &MetalCommandBuffer,
278 source_arrays: &[&NDArray],
279 ) -> Option<NDArray> {
280 let handles: Vec<_> = source_arrays.iter().map(|array| array.as_ptr()).collect();
281 let handles_ptr = if handles.is_empty() {
282 ptr::null()
283 } else {
284 handles.as_ptr()
285 };
286 let ptr = unsafe {
287 ffi::mps_ndarray_matrix_multiplication_encode(
288 self.ptr,
289 command_buffer.as_ptr(),
290 source_arrays.len(),
291 handles_ptr,
292 )
293 };
294 if ptr.is_null() {
295 None
296 } else {
297 Some(NDArray { ptr })
298 }
299 }
300
301 pub fn encode_to_destination(
302 &self,
303 command_buffer: &MetalCommandBuffer,
304 source_arrays: &[&NDArray],
305 destination: &NDArray,
306 ) {
307 let handles: Vec<_> = source_arrays.iter().map(|array| array.as_ptr()).collect();
308 let handles_ptr = if handles.is_empty() {
309 ptr::null()
310 } else {
311 handles.as_ptr()
312 };
313 unsafe {
314 ffi::mps_ndarray_matrix_multiplication_encode_to_destination(
315 self.ptr,
316 command_buffer.as_ptr(),
317 source_arrays.len(),
318 handles_ptr,
319 destination.as_ptr(),
320 );
321 };
322 }
323}