Skip to main content

apple_mps/
matrix.rs

1use crate::ffi;
2use apple_metal::{CommandBuffer, MetalBuffer, MetalDevice};
3use core::ffi::c_void;
4use core::ptr;
5
/// Selected raw `MPSDataType` values used for matrices and vectors.
///
/// Every value is an element bit width (8, 16 or 32) ORed with at most one
/// interpretation flag: floating point, signed integer, or normalized.
/// Unsigned integers carry no flag.
pub mod data_type {
    // Interpretation flags; private so the public surface matches the plain constants.
    const FLOAT_BIT: u32 = 0x1000_0000;
    const SIGNED_BIT: u32 = 0x2000_0000;
    const NORMALIZED_BIT: u32 = 0x4000_0000;

    /// `MPSDataTypeInvalid`: no element type.
    pub const INVALID: u32 = 0;
    /// `MPSDataTypeFloat32`: 32-bit floating point.
    pub const FLOAT32: u32 = FLOAT_BIT | 32;
    /// `MPSDataTypeFloat16`: 16-bit floating point.
    pub const FLOAT16: u32 = FLOAT_BIT | 16;
    /// `MPSDataTypeInt8`: 8-bit signed integer.
    pub const INT8: u32 = SIGNED_BIT | 8;
    /// `MPSDataTypeInt16`: 16-bit signed integer.
    pub const INT16: u32 = SIGNED_BIT | 16;
    /// `MPSDataTypeInt32`: 32-bit signed integer.
    pub const INT32: u32 = SIGNED_BIT | 32;
    /// `MPSDataTypeUInt8`: 8-bit unsigned integer.
    pub const UINT8: u32 = 8;
    /// `MPSDataTypeUInt16`: 16-bit unsigned integer.
    pub const UINT16: u32 = 16;
    /// `MPSDataTypeUInt32`: 32-bit unsigned integer.
    pub const UINT32: u32 = 32;
    /// `MPSDataTypeUnorm8`: 8-bit unsigned value interpreted as normalized.
    pub const UNORM8: u32 = NORMALIZED_BIT | UINT8;
}
29
30/// Return the byte width of a supported `MPSDataType`.
31#[must_use]
32pub const fn data_type_size(data_type: u32) -> Option<usize> {
33    match data_type {
34        data_type::FLOAT16 | data_type::INT16 | data_type::UINT16 => Some(2),
35        data_type::FLOAT32 | data_type::INT32 | data_type::UINT32 => Some(4),
36        data_type::INT8 | data_type::UINT8 | data_type::UNORM8 => Some(1),
37        _ => None,
38    }
39}
40
41/// Plain-Rust configuration for `MPSMatrixDescriptor`.
42#[derive(Debug, Clone, Copy)]
43pub struct MatrixDescriptor {
44    /// Corresponds to the `rows` field on `MPSMatrixDescriptor`.
45    pub rows: usize,
46    /// Corresponds to the `columns` field on `MPSMatrixDescriptor`.
47    pub columns: usize,
48    /// Corresponds to the `matrices` field on `MPSMatrixDescriptor`.
49    pub matrices: usize,
50    /// Corresponds to the `row_bytes` field on `MPSMatrixDescriptor`.
51    pub row_bytes: usize,
52    /// Corresponds to the `matrix_bytes` field on `MPSMatrixDescriptor`.
53    pub matrix_bytes: usize,
54    /// Corresponds to the `data_type` field on `MPSMatrixDescriptor`.
55    pub data_type: u32,
56}
57
58impl MatrixDescriptor {
59    /// Construct a matrix descriptor with explicit row and matrix strides.
60    #[must_use]
61    pub const fn with_strides(
62        rows: usize,
63        columns: usize,
64        matrices: usize,
65        row_bytes: usize,
66        matrix_bytes: usize,
67        data_type: u32,
68    ) -> Self {
69        Self {
70            rows,
71            columns,
72            matrices,
73            row_bytes,
74            matrix_bytes,
75            data_type,
76        }
77    }
78
79    /// Construct a single contiguous matrix descriptor for a supported data type.
80    #[must_use]
81    pub fn contiguous(rows: usize, columns: usize, data_type: u32) -> Option<Self> {
82        let element_size = data_type_size(data_type)?;
83        let row_bytes = columns.checked_mul(element_size)?;
84        let matrix_bytes = rows.checked_mul(row_bytes)?;
85        Some(Self::with_strides(
86            rows,
87            columns,
88            1,
89            row_bytes,
90            matrix_bytes,
91            data_type,
92        ))
93    }
94
95    /// Query MPS's recommended row stride for a matrix width.
96    #[must_use]
97    pub fn recommended_row_bytes(columns: usize, data_type: u32) -> usize {
98        // SAFETY: Pure function over scalar inputs.
99        unsafe { ffi::mps_matrix_descriptor_row_bytes_for_columns(columns, data_type) }
100    }
101}
102
103/// Plain-Rust configuration for `MPSVectorDescriptor`.
104#[derive(Debug, Clone, Copy)]
105pub struct VectorDescriptor {
106    /// Corresponds to the `length` field on `MPSVectorDescriptor`.
107    pub length: usize,
108    /// Corresponds to the `vectors` field on `MPSVectorDescriptor`.
109    pub vectors: usize,
110    /// Corresponds to the `vector_bytes` field on `MPSVectorDescriptor`.
111    pub vector_bytes: usize,
112    /// Corresponds to the `data_type` field on `MPSVectorDescriptor`.
113    pub data_type: u32,
114}
115
116impl VectorDescriptor {
117    /// Construct a vector descriptor with an explicit stride.
118    #[must_use]
119    pub const fn with_stride(
120        length: usize,
121        vectors: usize,
122        vector_bytes: usize,
123        data_type: u32,
124    ) -> Self {
125        Self {
126            length,
127            vectors,
128            vector_bytes,
129            data_type,
130        }
131    }
132
133    /// Construct a contiguous vector descriptor for a supported data type.
134    #[must_use]
135    pub fn contiguous(length: usize, data_type: u32) -> Option<Self> {
136        let element_size = data_type_size(data_type)?;
137        let vector_bytes = length.checked_mul(element_size)?;
138        Some(Self::with_stride(length, 1, vector_bytes, data_type))
139    }
140
141    /// Query MPS's recommended vector stride for a vector length.
142    #[must_use]
143    pub fn recommended_vector_bytes(length: usize, data_type: u32) -> usize {
144        // SAFETY: Pure function over scalar inputs.
145        unsafe { ffi::mps_vector_descriptor_vector_bytes_for_length(length, data_type) }
146    }
147}
148
/// Declares an owning wrapper type around a retained Swift/Objective-C object pointer.
///
/// The generated type:
/// - releases its pointer with `ffi::mps_object_release` on drop;
/// - exposes the raw pointer through `as_ptr`;
/// - implements `Debug`, printing the pointer address.
///
/// Constructors are written per type and must only store a +1 retained pointer
/// that this wrapper then owns.
macro_rules! opaque_handle {
    ($name:ident, $doc:expr) => {
        #[doc = $doc]
        pub struct $name {
            ptr: *mut c_void,
        }

        // SAFETY: The wrapper owns exactly one retain on the object. Moving the wrapper
        // to another thread moves that retain with it.
        // NOTE(review): assumes `mps_object_release` may be called from any thread —
        // confirm against the Swift shim.
        unsafe impl Send for $name {}
        // SAFETY: Only `&self` accessors and encode calls are exposed; no property
        // setters exist. Apple documents that an MPSKernel may encode into different
        // command buffers concurrently only while it is treated as immutable, which this
        // API preserves.
        // NOTE(review): `encode` takes `&CommandBuffer`, so concurrent encodes into the
        // *same* command buffer are reachable from safe code. Confirm that
        // `apple_metal::CommandBuffer` is `!Sync` or otherwise serialises encoding.
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                if !self.ptr.is_null() {
                    // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
                    unsafe { ffi::mps_object_release(self.ptr) };
                    // Defensive: never leave a dangling pointer in the struct.
                    self.ptr = ptr::null_mut();
                }
            }
        }

        impl core::fmt::Debug for $name {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                f.debug_struct(stringify!($name))
                    .field("ptr", &self.ptr)
                    .finish()
            }
        }

        impl $name {
            /// Returns the retained Objective-C pointer backing this wrapper.
            ///
            /// Ownership is not transferred; the pointer is only valid while `self` is alive.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
180
181opaque_handle!(Matrix, "Wraps `MPSMatrix`.");
182impl Matrix {
183    /// Wrap an existing `MTLBuffer` as an `MPSMatrix`.
184    #[must_use]
185    pub fn new_with_buffer(buffer: &MetalBuffer, descriptor: MatrixDescriptor) -> Option<Self> {
186        // SAFETY: `buffer` is a valid `MTLBuffer` wrapper and scalar parameters are POD.
187        let ptr = unsafe {
188            ffi::mps_matrix_new_with_buffer(
189                buffer.as_ptr(),
190                descriptor.rows,
191                descriptor.columns,
192                descriptor.matrices,
193                descriptor.row_bytes,
194                descriptor.matrix_bytes,
195                descriptor.data_type,
196            )
197        };
198        if ptr.is_null() {
199            None
200        } else {
201            Some(Self { ptr })
202        }
203    }
204
205    /// Wraps the corresponding `MPSMatrix` method.
206    #[must_use]
207    pub fn rows(&self) -> usize {
208        // SAFETY: `self.ptr` is a valid `MPSMatrix` pointer while `self` is alive.
209        unsafe { ffi::mps_matrix_rows(self.ptr) }
210    }
211
212    /// Wraps the corresponding `MPSMatrix` method.
213    #[must_use]
214    pub fn columns(&self) -> usize {
215        // SAFETY: `self.ptr` is a valid `MPSMatrix` pointer while `self` is alive.
216        unsafe { ffi::mps_matrix_columns(self.ptr) }
217    }
218
219    /// Wraps the corresponding `MPSMatrix` method.
220    #[must_use]
221    pub fn matrices(&self) -> usize {
222        // SAFETY: `self.ptr` is a valid `MPSMatrix` pointer while `self` is alive.
223        unsafe { ffi::mps_matrix_matrices(self.ptr) }
224    }
225
226    /// Wraps the corresponding `MPSMatrix` method.
227    #[must_use]
228    pub fn row_bytes(&self) -> usize {
229        // SAFETY: `self.ptr` is a valid `MPSMatrix` pointer while `self` is alive.
230        unsafe { ffi::mps_matrix_row_bytes(self.ptr) }
231    }
232
233    /// Wraps the corresponding `MPSMatrix` method.
234    #[must_use]
235    pub fn matrix_bytes(&self) -> usize {
236        // SAFETY: `self.ptr` is a valid `MPSMatrix` pointer while `self` is alive.
237        unsafe { ffi::mps_matrix_matrix_bytes(self.ptr) }
238    }
239
240    /// Wraps the corresponding `MPSMatrix` method.
241    #[must_use]
242    pub fn data_type(&self) -> u32 {
243        // SAFETY: `self.ptr` is a valid `MPSMatrix` pointer while `self` is alive.
244        unsafe { ffi::mps_matrix_data_type(self.ptr) }
245    }
246}
247
248opaque_handle!(Vector, "Wraps `MPSVector`.");
249#[doc(hidden)]
250pub use crate::generated::matrix::*;
251
252impl Vector {
253    /// Wrap an existing `MTLBuffer` as an `MPSVector`.
254    #[must_use]
255    pub fn new_with_buffer(buffer: &MetalBuffer, descriptor: VectorDescriptor) -> Option<Self> {
256        // SAFETY: `buffer` is a valid `MTLBuffer` wrapper and scalar parameters are POD.
257        let ptr = unsafe {
258            ffi::mps_vector_new_with_buffer(
259                buffer.as_ptr(),
260                descriptor.length,
261                descriptor.vectors,
262                descriptor.vector_bytes,
263                descriptor.data_type,
264            )
265        };
266        if ptr.is_null() {
267            None
268        } else {
269            Some(Self { ptr })
270        }
271    }
272
273    /// Wraps the corresponding `MPSVector` method.
274    #[must_use]
275    pub fn length(&self) -> usize {
276        // SAFETY: `self.ptr` is a valid `MPSVector` pointer while `self` is alive.
277        unsafe { ffi::mps_vector_length(self.ptr) }
278    }
279
280    /// Wraps the corresponding `MPSVector` method.
281    #[must_use]
282    pub fn vectors(&self) -> usize {
283        // SAFETY: `self.ptr` is a valid `MPSVector` pointer while `self` is alive.
284        unsafe { ffi::mps_vector_vectors(self.ptr) }
285    }
286
287    /// Wraps the corresponding `MPSVector` method.
288    #[must_use]
289    pub fn vector_bytes(&self) -> usize {
290        // SAFETY: `self.ptr` is a valid `MPSVector` pointer while `self` is alive.
291        unsafe { ffi::mps_vector_vector_bytes(self.ptr) }
292    }
293
294    /// Wraps the corresponding `MPSVector` method.
295    #[must_use]
296    pub fn data_type(&self) -> u32 {
297        // SAFETY: `self.ptr` is a valid `MPSVector` pointer while `self` is alive.
298        unsafe { ffi::mps_vector_data_type(self.ptr) }
299    }
300}
301
/// Plain-Rust configuration for `MPSMatrixMultiplication`.
///
/// Mirrors the arguments of
/// `-[MPSMatrixMultiplication initWithDevice:transposeLeft:transposeRight:resultRows:resultColumns:interiorColumns:alpha:beta:]`.
/// That kernel computes `C = alpha * op(A) * op(B) + beta * C`, where `op`
/// optionally transposes its operand.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MatrixMultiplicationDescriptor {
    /// Transpose the left operand `A` before multiplying (`transposeLeft`).
    pub transpose_left: bool,
    /// Transpose the right operand `B` before multiplying (`transposeRight`).
    pub transpose_right: bool,
    /// Number of rows of the result `C` (`resultRows`).
    pub result_rows: usize,
    /// Number of columns of the result `C` (`resultColumns`).
    pub result_columns: usize,
    /// Shared inner dimension: columns of `op(A)` and rows of `op(B)` (`interiorColumns`).
    pub interior_columns: usize,
    /// Scale applied to the product `op(A) * op(B)` (`alpha`).
    pub alpha: f64,
    /// Scale applied to the existing contents of `C` (`beta`).
    pub beta: f64,
}

impl MatrixMultiplicationDescriptor {
    /// Construct the common `C = A * B` descriptor.
    ///
    /// No transposition, `alpha = 1.0` and `beta = 0.0`, so any previous contents
    /// of `C` are ignored.
    #[must_use]
    pub const fn new(result_rows: usize, result_columns: usize, interior_columns: usize) -> Self {
        // Delegate so the default configuration is spelled out in exactly one place.
        Self::with_options(
            false,
            false,
            result_rows,
            result_columns,
            interior_columns,
            1.0,
            0.0,
        )
    }

    /// Construct a fully configurable descriptor.
    ///
    /// Arguments map one-to-one onto the struct fields; nothing is validated here.
    #[must_use]
    pub const fn with_options(
        transpose_left: bool,
        transpose_right: bool,
        result_rows: usize,
        result_columns: usize,
        interior_columns: usize,
        alpha: f64,
        beta: f64,
    ) -> Self {
        Self {
            transpose_left,
            transpose_right,
            result_rows,
            result_columns,
            interior_columns,
            alpha,
            beta,
        }
    }
}
358
359opaque_handle!(MatrixMultiplication, "Wraps `MPSMatrixMultiplication`.");
360impl MatrixMultiplication {
361    /// Build a configurable GEMM kernel with optional transposition and scaling.
362    #[must_use]
363    pub fn new(device: &MetalDevice, descriptor: MatrixMultiplicationDescriptor) -> Option<Self> {
364        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
365        let ptr = unsafe {
366            ffi::mps_matrix_multiplication_new(
367                device.as_ptr(),
368                descriptor.transpose_left,
369                descriptor.transpose_right,
370                descriptor.result_rows,
371                descriptor.result_columns,
372                descriptor.interior_columns,
373                descriptor.alpha,
374                descriptor.beta,
375            )
376        };
377        if ptr.is_null() {
378            None
379        } else {
380            Some(Self { ptr })
381        }
382    }
383
384    /// Convenience constructor for the common `C = A * B` case.
385    #[must_use]
386    pub fn new_simple(
387        device: &MetalDevice,
388        result_rows: usize,
389        result_columns: usize,
390        interior_columns: usize,
391    ) -> Option<Self> {
392        Self::new(
393            device,
394            MatrixMultiplicationDescriptor::new(result_rows, result_columns, interior_columns),
395        )
396    }
397
398    /// Encode the matrix multiplication onto a command buffer.
399    pub fn encode(
400        &self,
401        command_buffer: &CommandBuffer,
402        left: &Matrix,
403        right: &Matrix,
404        result: &Matrix,
405    ) {
406        // SAFETY: All handles come from safe wrappers and remain alive for the call.
407        unsafe {
408            ffi::mps_matrix_multiplication_encode(
409                self.ptr,
410                command_buffer.as_ptr(),
411                left.as_ptr(),
412                right.as_ptr(),
413                result.as_ptr(),
414            );
415        };
416    }
417}