1use crate::ffi;
2use apple_metal::{CommandBuffer, MetalBuffer, MetalDevice};
3use core::ffi::c_void;
4use core::ptr;
5
/// Raw element-type codes passed to the descriptors and FFI calls in this file.
///
/// Every non-zero code is a bit width (8, 16 or 32) optionally OR'd with one
/// class flag in the top nibble. The values appear to mirror Apple's
/// `MPSDataType` raw values — confirm against the MPS headers before adding
/// new entries.
pub mod data_type {
    /// Class flag present in the floating-point codes.
    const FLOAT_BIT: u32 = 0x1000_0000;
    /// Class flag present in the signed-integer codes.
    const SIGNED_BIT: u32 = 0x2000_0000;
    /// Class flag present in the normalized-integer codes.
    const NORMALIZED_BIT: u32 = 0x4000_0000;

    /// Placeholder code that names no element type.
    pub const INVALID: u32 = 0;

    /// 32-bit floating point.
    pub const FLOAT32: u32 = FLOAT_BIT | 32;
    /// 16-bit floating point.
    pub const FLOAT16: u32 = FLOAT_BIT | 16;

    /// 8-bit signed integer.
    pub const INT8: u32 = SIGNED_BIT | 8;
    /// 16-bit signed integer.
    pub const INT16: u32 = SIGNED_BIT | 16;
    /// 32-bit signed integer.
    pub const INT32: u32 = SIGNED_BIT | 32;

    /// 8-bit unsigned integer (no class flag).
    pub const UINT8: u32 = 8;
    /// 16-bit unsigned integer (no class flag).
    pub const UINT16: u32 = 16;
    /// 32-bit unsigned integer (no class flag).
    pub const UINT32: u32 = 32;

    /// 8-bit normalized integer.
    pub const UNORM8: u32 = NORMALIZED_BIT | 8;
}
19
20#[must_use]
22pub const fn data_type_size(data_type: u32) -> Option<usize> {
23 match data_type {
24 data_type::FLOAT16 | data_type::INT16 | data_type::UINT16 => Some(2),
25 data_type::FLOAT32 | data_type::INT32 | data_type::UINT32 => Some(4),
26 data_type::INT8 | data_type::UINT8 | data_type::UNORM8 => Some(1),
27 _ => None,
28 }
29}
30
31#[derive(Debug, Clone, Copy)]
33pub struct MatrixDescriptor {
34 pub rows: usize,
35 pub columns: usize,
36 pub matrices: usize,
37 pub row_bytes: usize,
38 pub matrix_bytes: usize,
39 pub data_type: u32,
40}
41
42impl MatrixDescriptor {
43 #[must_use]
45 pub const fn with_strides(
46 rows: usize,
47 columns: usize,
48 matrices: usize,
49 row_bytes: usize,
50 matrix_bytes: usize,
51 data_type: u32,
52 ) -> Self {
53 Self {
54 rows,
55 columns,
56 matrices,
57 row_bytes,
58 matrix_bytes,
59 data_type,
60 }
61 }
62
63 #[must_use]
65 pub fn contiguous(rows: usize, columns: usize, data_type: u32) -> Option<Self> {
66 let element_size = data_type_size(data_type)?;
67 let row_bytes = columns.checked_mul(element_size)?;
68 let matrix_bytes = rows.checked_mul(row_bytes)?;
69 Some(Self::with_strides(
70 rows,
71 columns,
72 1,
73 row_bytes,
74 matrix_bytes,
75 data_type,
76 ))
77 }
78
79 #[must_use]
81 pub fn recommended_row_bytes(columns: usize, data_type: u32) -> usize {
82 unsafe { ffi::mps_matrix_descriptor_row_bytes_for_columns(columns, data_type) }
84 }
85}
86
87#[derive(Debug, Clone, Copy)]
89pub struct VectorDescriptor {
90 pub length: usize,
91 pub vectors: usize,
92 pub vector_bytes: usize,
93 pub data_type: u32,
94}
95
96impl VectorDescriptor {
97 #[must_use]
99 pub const fn with_stride(
100 length: usize,
101 vectors: usize,
102 vector_bytes: usize,
103 data_type: u32,
104 ) -> Self {
105 Self {
106 length,
107 vectors,
108 vector_bytes,
109 data_type,
110 }
111 }
112
113 #[must_use]
115 pub fn contiguous(length: usize, data_type: u32) -> Option<Self> {
116 let element_size = data_type_size(data_type)?;
117 let vector_bytes = length.checked_mul(element_size)?;
118 Some(Self::with_stride(length, 1, vector_bytes, data_type))
119 }
120
121 #[must_use]
123 pub fn recommended_vector_bytes(length: usize, data_type: u32) -> usize {
124 unsafe { ffi::mps_vector_descriptor_vector_bytes_for_length(length, data_type) }
126 }
127}
128
/// Declares `$name` as an owning wrapper around a raw native object pointer.
///
/// The generated type exposes the pointer through `as_ptr` and hands it to
/// `ffi::mps_object_release` exactly once when dropped (a null pointer is
/// skipped).
macro_rules! opaque_handle {
    ($name:ident) => {
        pub struct $name {
            ptr: *mut c_void,
        }

        // SAFETY: NOTE(review): these impls assert that the wrapped native
        // object may be moved to, and shared between, threads. That cannot be
        // verified from this file — confirm against the native API's
        // threading guarantees.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}

        impl $name {
            /// Returns the raw native pointer without giving up ownership.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }

        impl Drop for $name {
            fn drop(&mut self) {
                if self.ptr.is_null() {
                    return;
                }
                // SAFETY: the pointer is non-null and is nulled right after,
                // so it is released at most once. NOTE(review): assumes this
                // wrapper holds its own reference to the object — confirm.
                unsafe { ffi::mps_object_release(self.ptr) };
                self.ptr = ptr::null_mut();
            }
        }
    };
}
156
157opaque_handle!(Matrix);
158impl Matrix {
159 #[must_use]
161 pub fn new_with_buffer(buffer: &MetalBuffer, descriptor: MatrixDescriptor) -> Option<Self> {
162 let ptr = unsafe {
164 ffi::mps_matrix_new_with_buffer(
165 buffer.as_ptr(),
166 descriptor.rows,
167 descriptor.columns,
168 descriptor.matrices,
169 descriptor.row_bytes,
170 descriptor.matrix_bytes,
171 descriptor.data_type,
172 )
173 };
174 if ptr.is_null() {
175 None
176 } else {
177 Some(Self { ptr })
178 }
179 }
180
181 #[must_use]
182 pub fn rows(&self) -> usize {
183 unsafe { ffi::mps_matrix_rows(self.ptr) }
185 }
186
187 #[must_use]
188 pub fn columns(&self) -> usize {
189 unsafe { ffi::mps_matrix_columns(self.ptr) }
191 }
192
193 #[must_use]
194 pub fn matrices(&self) -> usize {
195 unsafe { ffi::mps_matrix_matrices(self.ptr) }
197 }
198
199 #[must_use]
200 pub fn row_bytes(&self) -> usize {
201 unsafe { ffi::mps_matrix_row_bytes(self.ptr) }
203 }
204
205 #[must_use]
206 pub fn matrix_bytes(&self) -> usize {
207 unsafe { ffi::mps_matrix_matrix_bytes(self.ptr) }
209 }
210
211 #[must_use]
212 pub fn data_type(&self) -> u32 {
213 unsafe { ffi::mps_matrix_data_type(self.ptr) }
215 }
216}
217
218opaque_handle!(Vector);
219impl Vector {
220 #[must_use]
222 pub fn new_with_buffer(buffer: &MetalBuffer, descriptor: VectorDescriptor) -> Option<Self> {
223 let ptr = unsafe {
225 ffi::mps_vector_new_with_buffer(
226 buffer.as_ptr(),
227 descriptor.length,
228 descriptor.vectors,
229 descriptor.vector_bytes,
230 descriptor.data_type,
231 )
232 };
233 if ptr.is_null() {
234 None
235 } else {
236 Some(Self { ptr })
237 }
238 }
239
240 #[must_use]
241 pub fn length(&self) -> usize {
242 unsafe { ffi::mps_vector_length(self.ptr) }
244 }
245
246 #[must_use]
247 pub fn vectors(&self) -> usize {
248 unsafe { ffi::mps_vector_vectors(self.ptr) }
250 }
251
252 #[must_use]
253 pub fn vector_bytes(&self) -> usize {
254 unsafe { ffi::mps_vector_vector_bytes(self.ptr) }
256 }
257
258 #[must_use]
259 pub fn data_type(&self) -> u32 {
260 unsafe { ffi::mps_vector_data_type(self.ptr) }
262 }
263}
264
/// Parameters used to construct a [`MatrixMultiplication`] kernel.
///
/// NOTE(review): the field names match MPS's
/// `C = alpha * op(A) * op(B) + beta * C` formulation; the exact semantics are
/// decided by the native side — confirm there.
#[derive(Debug, Clone, Copy)]
pub struct MatrixMultiplicationDescriptor {
    /// Transpose flag for the left operand.
    pub transpose_left: bool,
    /// Transpose flag for the right operand.
    pub transpose_right: bool,
    /// Number of rows in the result.
    pub result_rows: usize,
    /// Number of columns in the result.
    pub result_columns: usize,
    /// Size of the shared (interior) dimension.
    pub interior_columns: usize,
    /// Scale factor `alpha`; [`MatrixMultiplicationDescriptor::new`] uses `1.0`.
    pub alpha: f64,
    /// Scale factor `beta`; [`MatrixMultiplicationDescriptor::new`] uses `0.0`.
    pub beta: f64,
}

impl MatrixMultiplicationDescriptor {
    /// Builds a descriptor with default options: no transposition,
    /// `alpha = 1.0` and `beta = 0.0`.
    #[must_use]
    pub const fn new(result_rows: usize, result_columns: usize, interior_columns: usize) -> Self {
        Self::with_options(
            false,
            false,
            result_rows,
            result_columns,
            interior_columns,
            1.0,
            0.0,
        )
    }

    /// Builds a descriptor with every option given explicitly.
    ///
    /// The values are stored verbatim; nothing is validated here.
    #[must_use]
    pub const fn with_options(
        transpose_left: bool,
        transpose_right: bool,
        result_rows: usize,
        result_columns: usize,
        interior_columns: usize,
        alpha: f64,
        beta: f64,
    ) -> Self {
        Self {
            transpose_left,
            transpose_right,
            result_rows,
            result_columns,
            interior_columns,
            alpha,
            beta,
        }
    }
}
314
315opaque_handle!(MatrixMultiplication);
316impl MatrixMultiplication {
317 #[must_use]
319 pub fn new(device: &MetalDevice, descriptor: MatrixMultiplicationDescriptor) -> Option<Self> {
320 let ptr = unsafe {
322 ffi::mps_matrix_multiplication_new(
323 device.as_ptr(),
324 descriptor.transpose_left,
325 descriptor.transpose_right,
326 descriptor.result_rows,
327 descriptor.result_columns,
328 descriptor.interior_columns,
329 descriptor.alpha,
330 descriptor.beta,
331 )
332 };
333 if ptr.is_null() {
334 None
335 } else {
336 Some(Self { ptr })
337 }
338 }
339
340 #[must_use]
342 pub fn new_simple(
343 device: &MetalDevice,
344 result_rows: usize,
345 result_columns: usize,
346 interior_columns: usize,
347 ) -> Option<Self> {
348 Self::new(
349 device,
350 MatrixMultiplicationDescriptor::new(result_rows, result_columns, interior_columns),
351 )
352 }
353
354 pub fn encode(
356 &self,
357 command_buffer: &CommandBuffer,
358 left: &Matrix,
359 right: &Matrix,
360 result: &Matrix,
361 ) {
362 unsafe {
364 ffi::mps_matrix_multiplication_encode(
365 self.ptr,
366 command_buffer.as_ptr(),
367 left.as_ptr(),
368 right.as_ptr(),
369 result.as_ptr(),
370 );
371 };
372 }
373}