Skip to main content

apple_mps/
matrix.rs

1use crate::ffi;
2use apple_metal::{CommandBuffer, MetalBuffer, MetalDevice};
3use core::ffi::c_void;
4use core::ptr;
5
/// A subset of the `MPSDataType` enumeration, as raw `u32` values, covering the
/// element types understood by the matrix and vector helpers in this module.
///
/// Each value is a kind flag (floating-point, signed or normalized; no flag for
/// unsigned integers) OR-ed with the element width in bits.
pub mod data_type {
    /// Kind flag for floating-point element types.
    const FLOAT_BIT: u32 = 0x1000_0000;
    /// Kind flag for signed-integer element types.
    const SIGNED_BIT: u32 = 0x2000_0000;
    /// Kind flag for normalized element types.
    const NORMALIZED_BIT: u32 = 0x4000_0000;

    /// Value that names no element type.
    pub const INVALID: u32 = 0;

    /// 32-bit floating point.
    pub const FLOAT32: u32 = FLOAT_BIT | 32;
    /// 16-bit floating point.
    pub const FLOAT16: u32 = FLOAT_BIT | 16;

    /// 8-bit signed integer.
    pub const INT8: u32 = SIGNED_BIT | 8;
    /// 16-bit signed integer.
    pub const INT16: u32 = SIGNED_BIT | 16;
    /// 32-bit signed integer.
    pub const INT32: u32 = SIGNED_BIT | 32;

    /// 8-bit unsigned integer.
    pub const UINT8: u32 = 8;
    /// 16-bit unsigned integer.
    pub const UINT16: u32 = 16;
    /// 32-bit unsigned integer.
    pub const UINT32: u32 = 32;

    /// 8-bit unsigned normalized value.
    pub const UNORM8: u32 = NORMALIZED_BIT | 8;
}
19
20/// Return the byte width of a supported `MPSDataType`.
21#[must_use]
22pub const fn data_type_size(data_type: u32) -> Option<usize> {
23    match data_type {
24        data_type::FLOAT16 | data_type::INT16 | data_type::UINT16 => Some(2),
25        data_type::FLOAT32 | data_type::INT32 | data_type::UINT32 => Some(4),
26        data_type::INT8 | data_type::UINT8 | data_type::UNORM8 => Some(1),
27        _ => None,
28    }
29}
30
31/// Plain-Rust configuration for `MPSMatrixDescriptor`.
32#[derive(Debug, Clone, Copy)]
33pub struct MatrixDescriptor {
34    pub rows: usize,
35    pub columns: usize,
36    pub matrices: usize,
37    pub row_bytes: usize,
38    pub matrix_bytes: usize,
39    pub data_type: u32,
40}
41
42impl MatrixDescriptor {
43    /// Construct a matrix descriptor with explicit row and matrix strides.
44    #[must_use]
45    pub const fn with_strides(
46        rows: usize,
47        columns: usize,
48        matrices: usize,
49        row_bytes: usize,
50        matrix_bytes: usize,
51        data_type: u32,
52    ) -> Self {
53        Self {
54            rows,
55            columns,
56            matrices,
57            row_bytes,
58            matrix_bytes,
59            data_type,
60        }
61    }
62
63    /// Construct a single contiguous matrix descriptor for a supported data type.
64    #[must_use]
65    pub fn contiguous(rows: usize, columns: usize, data_type: u32) -> Option<Self> {
66        let element_size = data_type_size(data_type)?;
67        let row_bytes = columns.checked_mul(element_size)?;
68        let matrix_bytes = rows.checked_mul(row_bytes)?;
69        Some(Self::with_strides(
70            rows,
71            columns,
72            1,
73            row_bytes,
74            matrix_bytes,
75            data_type,
76        ))
77    }
78
79    /// Query MPS's recommended row stride for a matrix width.
80    #[must_use]
81    pub fn recommended_row_bytes(columns: usize, data_type: u32) -> usize {
82        // SAFETY: Pure function over scalar inputs.
83        unsafe { ffi::mps_matrix_descriptor_row_bytes_for_columns(columns, data_type) }
84    }
85}
86
87/// Plain-Rust configuration for `MPSVectorDescriptor`.
88#[derive(Debug, Clone, Copy)]
89pub struct VectorDescriptor {
90    pub length: usize,
91    pub vectors: usize,
92    pub vector_bytes: usize,
93    pub data_type: u32,
94}
95
96impl VectorDescriptor {
97    /// Construct a vector descriptor with an explicit stride.
98    #[must_use]
99    pub const fn with_stride(
100        length: usize,
101        vectors: usize,
102        vector_bytes: usize,
103        data_type: u32,
104    ) -> Self {
105        Self {
106            length,
107            vectors,
108            vector_bytes,
109            data_type,
110        }
111    }
112
113    /// Construct a contiguous vector descriptor for a supported data type.
114    #[must_use]
115    pub fn contiguous(length: usize, data_type: u32) -> Option<Self> {
116        let element_size = data_type_size(data_type)?;
117        let vector_bytes = length.checked_mul(element_size)?;
118        Some(Self::with_stride(length, 1, vector_bytes, data_type))
119    }
120
121    /// Query MPS's recommended vector stride for a vector length.
122    #[must_use]
123    pub fn recommended_vector_bytes(length: usize, data_type: u32) -> usize {
124        // SAFETY: Pure function over scalar inputs.
125        unsafe { ffi::mps_vector_descriptor_vector_bytes_for_length(length, data_type) }
126    }
127}
128
/// Define a `pub` owning wrapper around a +1-retained MPS object pointer.
///
/// The generated type:
/// - stores the raw pointer in a private `ptr` field,
/// - releases it through `ffi::mps_object_release` when dropped (a null pointer is
///   skipped),
/// - is `Send` and `Sync`,
/// - exposes the pointer through `as_ptr`.
///
/// Outer attributes written before the type name, including `///` doc comments,
/// are forwarded to the generated struct:
///
/// ```text
/// opaque_handle!(
///     /// Owned `MPSFoo` handle.
///     Foo
/// );
/// ```
macro_rules! opaque_handle {
    ($(#[$meta:meta])* $name:ident) => {
        $(#[$meta])*
        pub struct $name {
            // Owned +1 reference; released exactly once, in `Drop`.
            ptr: *mut c_void,
        }

        // SAFETY: the wrapper uniquely owns its +1 reference, so moving the value to
        // another thread moves that ownership with it.
        // NOTE(review): this assumes the wrapped MPS objects may be used and released
        // from any thread; confirm against Apple's MPS threading guidance.
        unsafe impl Send for $name {}
        // SAFETY: every `&self` method in this module only reads `ptr` and forwards it
        // to FFI; no Rust-side state is mutated through a shared reference.
        // NOTE(review): Apple documents `MPSKernel` objects as not generally
        // thread-safe; sharing one across threads is only safe while it is treated
        // as immutable. Keep `Sync` only while no wrapper method mutates kernel
        // properties (re-check whenever setters are added).
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                if !self.ptr.is_null() {
                    // SAFETY: `ptr` is the +1 reference owned by this wrapper; it is
                    // released only here and nulled immediately afterwards.
                    unsafe { ffi::mps_object_release(self.ptr) };
                    self.ptr = ptr::null_mut();
                }
            }
        }

        impl $name {
            /// Borrow the underlying object pointer without transferring ownership.
            ///
            /// The pointer remains valid only while `self` is alive.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
158
159opaque_handle!(Matrix);
160impl Matrix {
161    /// Wrap an existing `MTLBuffer` as an `MPSMatrix`.
162    #[must_use]
163    pub fn new_with_buffer(buffer: &MetalBuffer, descriptor: MatrixDescriptor) -> Option<Self> {
164        // SAFETY: `buffer` is a valid `MTLBuffer` wrapper and scalar parameters are POD.
165        let ptr = unsafe {
166            ffi::mps_matrix_new_with_buffer(
167                buffer.as_ptr(),
168                descriptor.rows,
169                descriptor.columns,
170                descriptor.matrices,
171                descriptor.row_bytes,
172                descriptor.matrix_bytes,
173                descriptor.data_type,
174            )
175        };
176        if ptr.is_null() {
177            None
178        } else {
179            Some(Self { ptr })
180        }
181    }
182
183    #[must_use]
184    pub fn rows(&self) -> usize {
185        // SAFETY: `self.ptr` is a valid `MPSMatrix` pointer while `self` is alive.
186        unsafe { ffi::mps_matrix_rows(self.ptr) }
187    }
188
189    #[must_use]
190    pub fn columns(&self) -> usize {
191        // SAFETY: `self.ptr` is a valid `MPSMatrix` pointer while `self` is alive.
192        unsafe { ffi::mps_matrix_columns(self.ptr) }
193    }
194
195    #[must_use]
196    pub fn matrices(&self) -> usize {
197        // SAFETY: `self.ptr` is a valid `MPSMatrix` pointer while `self` is alive.
198        unsafe { ffi::mps_matrix_matrices(self.ptr) }
199    }
200
201    #[must_use]
202    pub fn row_bytes(&self) -> usize {
203        // SAFETY: `self.ptr` is a valid `MPSMatrix` pointer while `self` is alive.
204        unsafe { ffi::mps_matrix_row_bytes(self.ptr) }
205    }
206
207    #[must_use]
208    pub fn matrix_bytes(&self) -> usize {
209        // SAFETY: `self.ptr` is a valid `MPSMatrix` pointer while `self` is alive.
210        unsafe { ffi::mps_matrix_matrix_bytes(self.ptr) }
211    }
212
213    #[must_use]
214    pub fn data_type(&self) -> u32 {
215        // SAFETY: `self.ptr` is a valid `MPSMatrix` pointer while `self` is alive.
216        unsafe { ffi::mps_matrix_data_type(self.ptr) }
217    }
218}
219
220opaque_handle!(Vector);
221pub use crate::generated::matrix::*;
222
223impl Vector {
224    /// Wrap an existing `MTLBuffer` as an `MPSVector`.
225    #[must_use]
226    pub fn new_with_buffer(buffer: &MetalBuffer, descriptor: VectorDescriptor) -> Option<Self> {
227        // SAFETY: `buffer` is a valid `MTLBuffer` wrapper and scalar parameters are POD.
228        let ptr = unsafe {
229            ffi::mps_vector_new_with_buffer(
230                buffer.as_ptr(),
231                descriptor.length,
232                descriptor.vectors,
233                descriptor.vector_bytes,
234                descriptor.data_type,
235            )
236        };
237        if ptr.is_null() {
238            None
239        } else {
240            Some(Self { ptr })
241        }
242    }
243
244    #[must_use]
245    pub fn length(&self) -> usize {
246        // SAFETY: `self.ptr` is a valid `MPSVector` pointer while `self` is alive.
247        unsafe { ffi::mps_vector_length(self.ptr) }
248    }
249
250    #[must_use]
251    pub fn vectors(&self) -> usize {
252        // SAFETY: `self.ptr` is a valid `MPSVector` pointer while `self` is alive.
253        unsafe { ffi::mps_vector_vectors(self.ptr) }
254    }
255
256    #[must_use]
257    pub fn vector_bytes(&self) -> usize {
258        // SAFETY: `self.ptr` is a valid `MPSVector` pointer while `self` is alive.
259        unsafe { ffi::mps_vector_vector_bytes(self.ptr) }
260    }
261
262    #[must_use]
263    pub fn data_type(&self) -> u32 {
264        // SAFETY: `self.ptr` is a valid `MPSVector` pointer while `self` is alive.
265        unsafe { ffi::mps_vector_data_type(self.ptr) }
266    }
267}
268
/// Plain-Rust mirror of the parameters used to build an `MPSMatrixMultiplication`
/// kernel.
///
/// NOTE(review): `alpha`/`beta` presumably follow the usual GEMM form
/// `C = alpha * op(A) * op(B) + beta * C`; confirm against the MPS documentation.
#[derive(Debug, Clone, Copy)]
pub struct MatrixMultiplicationDescriptor {
    /// Whether the left operand is transposed.
    pub transpose_left: bool,
    /// Whether the right operand is transposed.
    pub transpose_right: bool,
    /// Number of rows of the result.
    pub result_rows: usize,
    /// Number of columns of the result.
    pub result_columns: usize,
    /// Shared inner dimension of the product.
    pub interior_columns: usize,
    /// Scale factor applied to the product.
    pub alpha: f64,
    /// Scale factor applied to the existing result contents.
    pub beta: f64,
}

impl MatrixMultiplicationDescriptor {
    /// Descriptor for plain `C = A * B`: no transposition, `alpha = 1`, `beta = 0`.
    #[must_use]
    pub const fn new(result_rows: usize, result_columns: usize, interior_columns: usize) -> Self {
        Self::with_options(
            false, // transpose_left
            false, // transpose_right
            result_rows,
            result_columns,
            interior_columns,
            1.0, // alpha
            0.0, // beta
        )
    }

    /// Descriptor with every option given explicitly; values are stored unchecked.
    #[must_use]
    pub const fn with_options(
        transpose_left: bool,
        transpose_right: bool,
        result_rows: usize,
        result_columns: usize,
        interior_columns: usize,
        alpha: f64,
        beta: f64,
    ) -> Self {
        Self {
            transpose_left,
            transpose_right,
            result_rows,
            result_columns,
            interior_columns,
            alpha,
            beta,
        }
    }
}
318
319opaque_handle!(MatrixMultiplication);
320impl MatrixMultiplication {
321    /// Build a configurable GEMM kernel with optional transposition and scaling.
322    #[must_use]
323    pub fn new(device: &MetalDevice, descriptor: MatrixMultiplicationDescriptor) -> Option<Self> {
324        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
325        let ptr = unsafe {
326            ffi::mps_matrix_multiplication_new(
327                device.as_ptr(),
328                descriptor.transpose_left,
329                descriptor.transpose_right,
330                descriptor.result_rows,
331                descriptor.result_columns,
332                descriptor.interior_columns,
333                descriptor.alpha,
334                descriptor.beta,
335            )
336        };
337        if ptr.is_null() {
338            None
339        } else {
340            Some(Self { ptr })
341        }
342    }
343
344    /// Convenience constructor for the common `C = A * B` case.
345    #[must_use]
346    pub fn new_simple(
347        device: &MetalDevice,
348        result_rows: usize,
349        result_columns: usize,
350        interior_columns: usize,
351    ) -> Option<Self> {
352        Self::new(
353            device,
354            MatrixMultiplicationDescriptor::new(result_rows, result_columns, interior_columns),
355        )
356    }
357
358    /// Encode the matrix multiplication onto a command buffer.
359    pub fn encode(
360        &self,
361        command_buffer: &CommandBuffer,
362        left: &Matrix,
363        right: &Matrix,
364        result: &Matrix,
365    ) {
366        // SAFETY: All handles come from safe wrappers and remain alive for the call.
367        unsafe {
368            ffi::mps_matrix_multiplication_encode(
369                self.ptr,
370                command_buffer.as_ptr(),
371                left.as_ptr(),
372                right.as_ptr(),
373                result.as_ptr(),
374            );
375        };
376    }
377}