1use crate::ffi;
2use apple_metal::{CommandBuffer, MetalBuffer, MetalDevice};
3use core::ffi::c_void;
4use core::ptr;
5
/// Raw `MPSDataType` values understood by the descriptors in this module.
///
/// Each value is a family flag bit (floating-point, signed, or normalized;
/// unsigned integers carry no flag) OR-ed with the element width in bits.
pub mod data_type {
    /// Family flag for floating-point encodings (`MPSDataTypeFloatBit`).
    const FLOAT_BIT: u32 = 0x1000_0000;
    /// Family flag for signed integer encodings (`MPSDataTypeSignedBit`).
    const SIGNED_BIT: u32 = 0x2000_0000;
    /// Family flag for normalized encodings (`MPSDataTypeNormalizedBit`).
    const NORMALIZED_BIT: u32 = 0x4000_0000;

    /// Sentinel for "no valid data type".
    pub const INVALID: u32 = 0;

    /// 32-bit floating point.
    pub const FLOAT32: u32 = FLOAT_BIT | 32;
    /// 16-bit floating point.
    pub const FLOAT16: u32 = FLOAT_BIT | 16;

    /// Signed 8-bit integer.
    pub const INT8: u32 = SIGNED_BIT | 8;
    /// Signed 16-bit integer.
    pub const INT16: u32 = SIGNED_BIT | 16;
    /// Signed 32-bit integer.
    pub const INT32: u32 = SIGNED_BIT | 32;

    /// Unsigned 8-bit integer.
    pub const UINT8: u32 = 8;
    /// Unsigned 16-bit integer.
    pub const UINT16: u32 = 16;
    /// Unsigned 32-bit integer.
    pub const UINT32: u32 = 32;

    /// Normalized unsigned 8-bit value.
    pub const UNORM8: u32 = NORMALIZED_BIT | 8;
}
29
30#[must_use]
32pub const fn data_type_size(data_type: u32) -> Option<usize> {
33 match data_type {
34 data_type::FLOAT16 | data_type::INT16 | data_type::UINT16 => Some(2),
35 data_type::FLOAT32 | data_type::INT32 | data_type::UINT32 => Some(4),
36 data_type::INT8 | data_type::UINT8 | data_type::UNORM8 => Some(1),
37 _ => None,
38 }
39}
40
41#[derive(Debug, Clone, Copy)]
43pub struct MatrixDescriptor {
44 pub rows: usize,
46 pub columns: usize,
48 pub matrices: usize,
50 pub row_bytes: usize,
52 pub matrix_bytes: usize,
54 pub data_type: u32,
56}
57
58impl MatrixDescriptor {
59 #[must_use]
61 pub const fn with_strides(
62 rows: usize,
63 columns: usize,
64 matrices: usize,
65 row_bytes: usize,
66 matrix_bytes: usize,
67 data_type: u32,
68 ) -> Self {
69 Self {
70 rows,
71 columns,
72 matrices,
73 row_bytes,
74 matrix_bytes,
75 data_type,
76 }
77 }
78
79 #[must_use]
81 pub fn contiguous(rows: usize, columns: usize, data_type: u32) -> Option<Self> {
82 let element_size = data_type_size(data_type)?;
83 let row_bytes = columns.checked_mul(element_size)?;
84 let matrix_bytes = rows.checked_mul(row_bytes)?;
85 Some(Self::with_strides(
86 rows,
87 columns,
88 1,
89 row_bytes,
90 matrix_bytes,
91 data_type,
92 ))
93 }
94
95 #[must_use]
97 pub fn recommended_row_bytes(columns: usize, data_type: u32) -> usize {
98 unsafe { ffi::mps_matrix_descriptor_row_bytes_for_columns(columns, data_type) }
100 }
101}
102
103#[derive(Debug, Clone, Copy)]
105pub struct VectorDescriptor {
106 pub length: usize,
108 pub vectors: usize,
110 pub vector_bytes: usize,
112 pub data_type: u32,
114}
115
116impl VectorDescriptor {
117 #[must_use]
119 pub const fn with_stride(
120 length: usize,
121 vectors: usize,
122 vector_bytes: usize,
123 data_type: u32,
124 ) -> Self {
125 Self {
126 length,
127 vectors,
128 vector_bytes,
129 data_type,
130 }
131 }
132
133 #[must_use]
135 pub fn contiguous(length: usize, data_type: u32) -> Option<Self> {
136 let element_size = data_type_size(data_type)?;
137 let vector_bytes = length.checked_mul(element_size)?;
138 Some(Self::with_stride(length, 1, vector_bytes, data_type))
139 }
140
141 #[must_use]
143 pub fn recommended_vector_bytes(length: usize, data_type: u32) -> usize {
144 unsafe { ffi::mps_vector_descriptor_vector_bytes_for_length(length, data_type) }
146 }
147}
148
/// Declares a public newtype around a raw Objective-C object pointer.
///
/// The generated type holds one reference to the object, which it gives up
/// through `ffi::mps_object_release` when dropped. The raw pointer is
/// available, without transferring ownership, via `as_ptr`.
macro_rules! opaque_handle {
    ($name:ident, $doc:expr) => {
        #[doc = $doc]
        pub struct $name {
            ptr: *mut c_void,
        }

        // NOTE(review): Send and Sync are asserted for every wrapped MPS type.
        // Confirm this against Apple's MPSKernel thread-safety notes before
        // sharing handles across threads.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                if self.ptr.is_null() {
                    return;
                }
                // SAFETY: the pointer is non-null and this wrapper has not
                // released it yet. It is nulled right after, so it is never
                // released twice.
                unsafe { ffi::mps_object_release(self.ptr) };
                self.ptr = ptr::null_mut();
            }
        }

        impl $name {
            /// Returns the wrapped object pointer. Ownership stays with `self`,
            /// which still releases it on drop.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
180
181opaque_handle!(Matrix, "Wraps `MPSMatrix`.");
182impl Matrix {
183 #[must_use]
185 pub fn new_with_buffer(buffer: &MetalBuffer, descriptor: MatrixDescriptor) -> Option<Self> {
186 let ptr = unsafe {
188 ffi::mps_matrix_new_with_buffer(
189 buffer.as_ptr(),
190 descriptor.rows,
191 descriptor.columns,
192 descriptor.matrices,
193 descriptor.row_bytes,
194 descriptor.matrix_bytes,
195 descriptor.data_type,
196 )
197 };
198 if ptr.is_null() {
199 None
200 } else {
201 Some(Self { ptr })
202 }
203 }
204
205 #[must_use]
207 pub fn rows(&self) -> usize {
208 unsafe { ffi::mps_matrix_rows(self.ptr) }
210 }
211
212 #[must_use]
214 pub fn columns(&self) -> usize {
215 unsafe { ffi::mps_matrix_columns(self.ptr) }
217 }
218
219 #[must_use]
221 pub fn matrices(&self) -> usize {
222 unsafe { ffi::mps_matrix_matrices(self.ptr) }
224 }
225
226 #[must_use]
228 pub fn row_bytes(&self) -> usize {
229 unsafe { ffi::mps_matrix_row_bytes(self.ptr) }
231 }
232
233 #[must_use]
235 pub fn matrix_bytes(&self) -> usize {
236 unsafe { ffi::mps_matrix_matrix_bytes(self.ptr) }
238 }
239
240 #[must_use]
242 pub fn data_type(&self) -> u32 {
243 unsafe { ffi::mps_matrix_data_type(self.ptr) }
245 }
246}
247
248opaque_handle!(Vector, "Wraps `MPSVector`.");
249#[doc(hidden)]
250pub use crate::generated::matrix::*;
251
252impl Vector {
253 #[must_use]
255 pub fn new_with_buffer(buffer: &MetalBuffer, descriptor: VectorDescriptor) -> Option<Self> {
256 let ptr = unsafe {
258 ffi::mps_vector_new_with_buffer(
259 buffer.as_ptr(),
260 descriptor.length,
261 descriptor.vectors,
262 descriptor.vector_bytes,
263 descriptor.data_type,
264 )
265 };
266 if ptr.is_null() {
267 None
268 } else {
269 Some(Self { ptr })
270 }
271 }
272
273 #[must_use]
275 pub fn length(&self) -> usize {
276 unsafe { ffi::mps_vector_length(self.ptr) }
278 }
279
280 #[must_use]
282 pub fn vectors(&self) -> usize {
283 unsafe { ffi::mps_vector_vectors(self.ptr) }
285 }
286
287 #[must_use]
289 pub fn vector_bytes(&self) -> usize {
290 unsafe { ffi::mps_vector_vector_bytes(self.ptr) }
292 }
293
294 #[must_use]
296 pub fn data_type(&self) -> u32 {
297 unsafe { ffi::mps_vector_data_type(self.ptr) }
299 }
300}
301
/// Configuration forwarded to `MatrixMultiplication::new`.
///
/// NOTE(review): this mirrors the initializer of `MPSMatrixMultiplication`,
/// which computes `result = alpha * op(left) * op(right) + beta * result`,
/// where `op` optionally transposes. Confirm the FFI shim passes the fields
/// through unchanged.
#[derive(Debug, Clone, Copy)]
pub struct MatrixMultiplicationDescriptor {
    /// Whether the left operand is transposed before multiplying.
    pub transpose_left: bool,
    /// Whether the right operand is transposed before multiplying.
    pub transpose_right: bool,
    /// Rows of the result matrix.
    pub result_rows: usize,
    /// Columns of the result matrix.
    pub result_columns: usize,
    /// Shared (inner) dimension of the product.
    pub interior_columns: usize,
    /// Scale applied to the product.
    pub alpha: f64,
    /// Scale applied to the existing contents of the result.
    pub beta: f64,
}

impl MatrixMultiplicationDescriptor {
    /// Plain product: no transposes, `alpha = 1.0`, `beta = 0.0`.
    #[must_use]
    pub const fn new(result_rows: usize, result_columns: usize, interior_columns: usize) -> Self {
        Self::with_options(
            false,
            false,
            result_rows,
            result_columns,
            interior_columns,
            1.0,
            0.0,
        )
    }

    /// Fully specified descriptor; every argument is stored as given.
    #[must_use]
    pub const fn with_options(
        transpose_left: bool,
        transpose_right: bool,
        result_rows: usize,
        result_columns: usize,
        interior_columns: usize,
        alpha: f64,
        beta: f64,
    ) -> Self {
        Self {
            alpha,
            beta,
            result_rows,
            result_columns,
            interior_columns,
            transpose_left,
            transpose_right,
        }
    }
}
358
359opaque_handle!(MatrixMultiplication, "Wraps `MPSMatrixMultiplication`.");
360impl MatrixMultiplication {
361 #[must_use]
363 pub fn new(device: &MetalDevice, descriptor: MatrixMultiplicationDescriptor) -> Option<Self> {
364 let ptr = unsafe {
366 ffi::mps_matrix_multiplication_new(
367 device.as_ptr(),
368 descriptor.transpose_left,
369 descriptor.transpose_right,
370 descriptor.result_rows,
371 descriptor.result_columns,
372 descriptor.interior_columns,
373 descriptor.alpha,
374 descriptor.beta,
375 )
376 };
377 if ptr.is_null() {
378 None
379 } else {
380 Some(Self { ptr })
381 }
382 }
383
384 #[must_use]
386 pub fn new_simple(
387 device: &MetalDevice,
388 result_rows: usize,
389 result_columns: usize,
390 interior_columns: usize,
391 ) -> Option<Self> {
392 Self::new(
393 device,
394 MatrixMultiplicationDescriptor::new(result_rows, result_columns, interior_columns),
395 )
396 }
397
398 pub fn encode(
400 &self,
401 command_buffer: &CommandBuffer,
402 left: &Matrix,
403 right: &Matrix,
404 result: &Matrix,
405 ) {
406 unsafe {
408 ffi::mps_matrix_multiplication_encode(
409 self.ptr,
410 command_buffer.as_ptr(),
411 left.as_ptr(),
412 right.as_ptr(),
413 result.as_ptr(),
414 );
415 };
416 }
417}