1use crate::ffi;
2use apple_metal::{CommandBuffer, MetalBuffer, MetalDevice};
3use core::ffi::c_void;
4use core::ptr;
5
/// Numeric encodings for MPS element types, passed across the FFI boundary as `u32`.
///
/// Every value is a category flag in the high bits OR'd with the element width
/// in bits in the low bits (unsigned integers have no category flag).
pub mod data_type {
    // Category flags shared by the public constants below.
    const FLOAT_BIT: u32 = 0x1000_0000;
    const SIGNED_BIT: u32 = 0x2000_0000;
    const NORMALIZED_BIT: u32 = 0x4000_0000;

    /// The zero encoding; carries no category and no width.
    pub const INVALID: u32 = 0;
    /// 32-bit floating point.
    pub const FLOAT32: u32 = FLOAT_BIT | 32;
    /// 16-bit floating point.
    pub const FLOAT16: u32 = FLOAT_BIT | 16;
    /// 8-bit signed integer.
    pub const INT8: u32 = SIGNED_BIT | 8;
    /// 16-bit signed integer.
    pub const INT16: u32 = SIGNED_BIT | 16;
    /// 32-bit signed integer.
    pub const INT32: u32 = SIGNED_BIT | 32;
    /// 8-bit unsigned integer.
    pub const UINT8: u32 = 8;
    /// 16-bit unsigned integer.
    pub const UINT16: u32 = 16;
    /// 32-bit unsigned integer.
    pub const UINT32: u32 = 32;
    /// 8-bit unsigned value with the normalized flag set.
    pub const UNORM8: u32 = NORMALIZED_BIT | 8;
}
19
20#[must_use]
22pub const fn data_type_size(data_type: u32) -> Option<usize> {
23 match data_type {
24 data_type::FLOAT16 | data_type::INT16 | data_type::UINT16 => Some(2),
25 data_type::FLOAT32 | data_type::INT32 | data_type::UINT32 => Some(4),
26 data_type::INT8 | data_type::UINT8 | data_type::UNORM8 => Some(1),
27 _ => None,
28 }
29}
30
31#[derive(Debug, Clone, Copy)]
33pub struct MatrixDescriptor {
34 pub rows: usize,
35 pub columns: usize,
36 pub matrices: usize,
37 pub row_bytes: usize,
38 pub matrix_bytes: usize,
39 pub data_type: u32,
40}
41
42impl MatrixDescriptor {
43 #[must_use]
45 pub const fn with_strides(
46 rows: usize,
47 columns: usize,
48 matrices: usize,
49 row_bytes: usize,
50 matrix_bytes: usize,
51 data_type: u32,
52 ) -> Self {
53 Self {
54 rows,
55 columns,
56 matrices,
57 row_bytes,
58 matrix_bytes,
59 data_type,
60 }
61 }
62
63 #[must_use]
65 pub fn contiguous(rows: usize, columns: usize, data_type: u32) -> Option<Self> {
66 let element_size = data_type_size(data_type)?;
67 let row_bytes = columns.checked_mul(element_size)?;
68 let matrix_bytes = rows.checked_mul(row_bytes)?;
69 Some(Self::with_strides(
70 rows,
71 columns,
72 1,
73 row_bytes,
74 matrix_bytes,
75 data_type,
76 ))
77 }
78
79 #[must_use]
81 pub fn recommended_row_bytes(columns: usize, data_type: u32) -> usize {
82 unsafe { ffi::mps_matrix_descriptor_row_bytes_for_columns(columns, data_type) }
84 }
85}
86
87#[derive(Debug, Clone, Copy)]
89pub struct VectorDescriptor {
90 pub length: usize,
91 pub vectors: usize,
92 pub vector_bytes: usize,
93 pub data_type: u32,
94}
95
96impl VectorDescriptor {
97 #[must_use]
99 pub const fn with_stride(
100 length: usize,
101 vectors: usize,
102 vector_bytes: usize,
103 data_type: u32,
104 ) -> Self {
105 Self {
106 length,
107 vectors,
108 vector_bytes,
109 data_type,
110 }
111 }
112
113 #[must_use]
115 pub fn contiguous(length: usize, data_type: u32) -> Option<Self> {
116 let element_size = data_type_size(data_type)?;
117 let vector_bytes = length.checked_mul(element_size)?;
118 Some(Self::with_stride(length, 1, vector_bytes, data_type))
119 }
120
121 #[must_use]
123 pub fn recommended_vector_bytes(length: usize, data_type: u32) -> usize {
124 unsafe { ffi::mps_vector_descriptor_vector_bytes_for_length(length, data_type) }
126 }
127}
128
/// Defines an owning wrapper type around a raw MPS object pointer.
///
/// The generated type hands its pointer to `ffi::mps_object_release` when
/// dropped (a null pointer is skipped) and exposes it through `as_ptr`.
macro_rules! opaque_handle {
    ($name:ident) => {
        pub struct $name {
            ptr: *mut c_void,
        }

        // SAFETY: the wrapper holds nothing but the object pointer.
        // NOTE(review): `Sync` is asserted for every generated type, so `&self`
        // methods (e.g. kernel encoding) may run on several threads at once.
        // Apple's MPSKernel documentation places restrictions on concurrent use —
        // confirm this is sound for each wrapped object type.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                // Take the pointer out of `self` first so it can never be
                // released a second time through this handle.
                let raw = self.ptr;
                self.ptr = ptr::null_mut();
                if raw.is_null() {
                    return;
                }
                // SAFETY: `raw` is the non-null pointer this handle owned.
                unsafe { ffi::mps_object_release(raw) };
            }
        }

        impl $name {
            /// Returns the raw object pointer; ownership stays with `self`.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
158
159opaque_handle!(Matrix);
160impl Matrix {
161 #[must_use]
163 pub fn new_with_buffer(buffer: &MetalBuffer, descriptor: MatrixDescriptor) -> Option<Self> {
164 let ptr = unsafe {
166 ffi::mps_matrix_new_with_buffer(
167 buffer.as_ptr(),
168 descriptor.rows,
169 descriptor.columns,
170 descriptor.matrices,
171 descriptor.row_bytes,
172 descriptor.matrix_bytes,
173 descriptor.data_type,
174 )
175 };
176 if ptr.is_null() {
177 None
178 } else {
179 Some(Self { ptr })
180 }
181 }
182
183 #[must_use]
184 pub fn rows(&self) -> usize {
185 unsafe { ffi::mps_matrix_rows(self.ptr) }
187 }
188
189 #[must_use]
190 pub fn columns(&self) -> usize {
191 unsafe { ffi::mps_matrix_columns(self.ptr) }
193 }
194
195 #[must_use]
196 pub fn matrices(&self) -> usize {
197 unsafe { ffi::mps_matrix_matrices(self.ptr) }
199 }
200
201 #[must_use]
202 pub fn row_bytes(&self) -> usize {
203 unsafe { ffi::mps_matrix_row_bytes(self.ptr) }
205 }
206
207 #[must_use]
208 pub fn matrix_bytes(&self) -> usize {
209 unsafe { ffi::mps_matrix_matrix_bytes(self.ptr) }
211 }
212
213 #[must_use]
214 pub fn data_type(&self) -> u32 {
215 unsafe { ffi::mps_matrix_data_type(self.ptr) }
217 }
218}
219
220opaque_handle!(Vector);
221pub use crate::generated::matrix::*;
222
223impl Vector {
224 #[must_use]
226 pub fn new_with_buffer(buffer: &MetalBuffer, descriptor: VectorDescriptor) -> Option<Self> {
227 let ptr = unsafe {
229 ffi::mps_vector_new_with_buffer(
230 buffer.as_ptr(),
231 descriptor.length,
232 descriptor.vectors,
233 descriptor.vector_bytes,
234 descriptor.data_type,
235 )
236 };
237 if ptr.is_null() {
238 None
239 } else {
240 Some(Self { ptr })
241 }
242 }
243
244 #[must_use]
245 pub fn length(&self) -> usize {
246 unsafe { ffi::mps_vector_length(self.ptr) }
248 }
249
250 #[must_use]
251 pub fn vectors(&self) -> usize {
252 unsafe { ffi::mps_vector_vectors(self.ptr) }
254 }
255
256 #[must_use]
257 pub fn vector_bytes(&self) -> usize {
258 unsafe { ffi::mps_vector_vector_bytes(self.ptr) }
260 }
261
262 #[must_use]
263 pub fn data_type(&self) -> u32 {
264 unsafe { ffi::mps_vector_data_type(self.ptr) }
266 }
267}
268
/// Configuration for a matrix-multiplication kernel.
///
/// The fields are forwarded unchanged to the FFI layer by `MatrixMultiplication::new`.
/// NOTE(review): presumably follows MPS GEMM semantics,
/// `result = alpha * op(left) * op(right) + beta * result` — confirm against the shim.
#[derive(Debug, Clone, Copy)]
pub struct MatrixMultiplicationDescriptor {
    /// Whether the left operand is transposed.
    pub transpose_left: bool,
    /// Whether the right operand is transposed.
    pub transpose_right: bool,
    /// Number of rows in the result.
    pub result_rows: usize,
    /// Number of columns in the result.
    pub result_columns: usize,
    /// Shared inner dimension of the two operands.
    pub interior_columns: usize,
    /// Scale factor for the product.
    pub alpha: f64,
    /// Scale factor for the existing result contents.
    pub beta: f64,
}

impl MatrixMultiplicationDescriptor {
    /// Descriptor with no transposes, `alpha = 1.0` and `beta = 0.0`.
    #[must_use]
    pub const fn new(result_rows: usize, result_columns: usize, interior_columns: usize) -> Self {
        Self::with_options(
            false,
            false,
            result_rows,
            result_columns,
            interior_columns,
            1.0,
            0.0,
        )
    }

    /// Descriptor with every option specified explicitly.
    #[must_use]
    pub const fn with_options(
        transpose_left: bool,
        transpose_right: bool,
        result_rows: usize,
        result_columns: usize,
        interior_columns: usize,
        alpha: f64,
        beta: f64,
    ) -> Self {
        Self {
            transpose_left,
            transpose_right,
            result_rows,
            result_columns,
            interior_columns,
            alpha,
            beta,
        }
    }
}
318
319opaque_handle!(MatrixMultiplication);
320impl MatrixMultiplication {
321 #[must_use]
323 pub fn new(device: &MetalDevice, descriptor: MatrixMultiplicationDescriptor) -> Option<Self> {
324 let ptr = unsafe {
326 ffi::mps_matrix_multiplication_new(
327 device.as_ptr(),
328 descriptor.transpose_left,
329 descriptor.transpose_right,
330 descriptor.result_rows,
331 descriptor.result_columns,
332 descriptor.interior_columns,
333 descriptor.alpha,
334 descriptor.beta,
335 )
336 };
337 if ptr.is_null() {
338 None
339 } else {
340 Some(Self { ptr })
341 }
342 }
343
344 #[must_use]
346 pub fn new_simple(
347 device: &MetalDevice,
348 result_rows: usize,
349 result_columns: usize,
350 interior_columns: usize,
351 ) -> Option<Self> {
352 Self::new(
353 device,
354 MatrixMultiplicationDescriptor::new(result_rows, result_columns, interior_columns),
355 )
356 }
357
358 pub fn encode(
360 &self,
361 command_buffer: &CommandBuffer,
362 left: &Matrix,
363 right: &Matrix,
364 result: &Matrix,
365 ) {
366 unsafe {
368 ffi::mps_matrix_multiplication_encode(
369 self.ptr,
370 command_buffer.as_ptr(),
371 left.as_ptr(),
372 right.as_ptr(),
373 result.as_ptr(),
374 );
375 };
376 }
377}