Skip to main content

apple_mps/
matrix.rs

1use crate::ffi;
2use apple_metal::{CommandBuffer, MetalBuffer, MetalDevice};
3use core::ffi::c_void;
4use core::ptr;
5
6/// Selected `MPSDataType` constants used for matrices and vectors.
/// Raw `MPSDataType` values for the element types this module understands.
///
/// Each value is an encoding flag bit OR'd with the element width in bits (the low bits),
/// so e.g. `FLOAT32` is the float flag combined with `32`.
pub mod data_type {
    /// Encoding flag shared by the floating-point types.
    const FLOAT_BIT: u32 = 0x1000_0000;
    /// Encoding flag shared by the signed-integer types.
    const SIGNED_BIT: u32 = 0x2000_0000;
    /// Encoding flag shared by the normalized-integer types.
    const NORMALIZED_BIT: u32 = 0x4000_0000;

    pub const INVALID: u32 = 0;
    pub const FLOAT32: u32 = FLOAT_BIT | 32;
    pub const FLOAT16: u32 = FLOAT_BIT | 16;
    pub const INT8: u32 = SIGNED_BIT | 8;
    pub const INT16: u32 = SIGNED_BIT | 16;
    pub const INT32: u32 = SIGNED_BIT | 32;
    pub const UINT8: u32 = 8;
    pub const UINT16: u32 = 16;
    pub const UINT32: u32 = 32;
    pub const UNORM8: u32 = NORMALIZED_BIT | UINT8;
}

/// Byte width of one element of a supported `MPSDataType`.
///
/// Returns `None` for `INVALID` and for any value that is not one of the sized
/// constants in [`data_type`].
#[must_use]
pub const fn data_type_size(data_type: u32) -> Option<usize> {
    let bytes = match data_type {
        data_type::INT8 | data_type::UINT8 | data_type::UNORM8 => 1,
        data_type::FLOAT16 | data_type::INT16 | data_type::UINT16 => 2,
        data_type::FLOAT32 | data_type::INT32 | data_type::UINT32 => 4,
        _ => return None,
    };
    Some(bytes)
}
30
/// Plain-Rust configuration for `MPSMatrixDescriptor`.
///
/// Describes `matrices` matrices of `rows` x `columns` elements stored in one buffer,
/// with explicit byte strides between rows and between matrices.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MatrixDescriptor {
    /// Number of rows in each matrix.
    pub rows: usize,
    /// Number of columns in each matrix.
    pub columns: usize,
    /// Number of matrices described (1 for a single matrix).
    pub matrices: usize,
    /// Byte stride between the starts of consecutive rows.
    pub row_bytes: usize,
    /// Byte stride between the starts of consecutive matrices.
    pub matrix_bytes: usize,
    /// Raw `MPSDataType` of each element (see the `data_type` module).
    pub data_type: u32,
}
41
42impl MatrixDescriptor {
43    /// Construct a matrix descriptor with explicit row and matrix strides.
44    #[must_use]
45    pub const fn with_strides(
46        rows: usize,
47        columns: usize,
48        matrices: usize,
49        row_bytes: usize,
50        matrix_bytes: usize,
51        data_type: u32,
52    ) -> Self {
53        Self {
54            rows,
55            columns,
56            matrices,
57            row_bytes,
58            matrix_bytes,
59            data_type,
60        }
61    }
62
63    /// Construct a single contiguous matrix descriptor for a supported data type.
64    #[must_use]
65    pub fn contiguous(rows: usize, columns: usize, data_type: u32) -> Option<Self> {
66        let element_size = data_type_size(data_type)?;
67        let row_bytes = columns.checked_mul(element_size)?;
68        let matrix_bytes = rows.checked_mul(row_bytes)?;
69        Some(Self::with_strides(
70            rows,
71            columns,
72            1,
73            row_bytes,
74            matrix_bytes,
75            data_type,
76        ))
77    }
78
79    /// Query MPS's recommended row stride for a matrix width.
80    #[must_use]
81    pub fn recommended_row_bytes(columns: usize, data_type: u32) -> usize {
82        // SAFETY: Pure function over scalar inputs.
83        unsafe { ffi::mps_matrix_descriptor_row_bytes_for_columns(columns, data_type) }
84    }
85}
86
/// Plain-Rust configuration for `MPSVectorDescriptor`.
///
/// Describes `vectors` vectors of `length` elements stored in one buffer, with an
/// explicit byte stride between consecutive vectors.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct VectorDescriptor {
    /// Number of elements in each vector.
    pub length: usize,
    /// Number of vectors described (1 for a single vector).
    pub vectors: usize,
    /// Byte stride between the starts of consecutive vectors.
    pub vector_bytes: usize,
    /// Raw `MPSDataType` of each element (see the `data_type` module).
    pub data_type: u32,
}
95
96impl VectorDescriptor {
97    /// Construct a vector descriptor with an explicit stride.
98    #[must_use]
99    pub const fn with_stride(
100        length: usize,
101        vectors: usize,
102        vector_bytes: usize,
103        data_type: u32,
104    ) -> Self {
105        Self {
106            length,
107            vectors,
108            vector_bytes,
109            data_type,
110        }
111    }
112
113    /// Construct a contiguous vector descriptor for a supported data type.
114    #[must_use]
115    pub fn contiguous(length: usize, data_type: u32) -> Option<Self> {
116        let element_size = data_type_size(data_type)?;
117        let vector_bytes = length.checked_mul(element_size)?;
118        Some(Self::with_stride(length, 1, vector_bytes, data_type))
119    }
120
121    /// Query MPS's recommended vector stride for a vector length.
122    #[must_use]
123    pub fn recommended_vector_bytes(length: usize, data_type: u32) -> usize {
124        // SAFETY: Pure function over scalar inputs.
125        unsafe { ffi::mps_vector_descriptor_vector_bytes_for_length(length, data_type) }
126    }
127}
128
/// Define an owning wrapper type around a retained Swift/Objective-C object pointer.
///
/// The generated `$name` stores the raw pointer and releases it with
/// `ffi::mps_object_release` when dropped (a null pointer is skipped). It exposes the
/// pointer through `as_ptr` for passing back across the FFI boundary.
macro_rules! opaque_handle {
    ($name:ident) => {
        /// Owning handle to a retained MPS object; the object is released when the handle
        /// is dropped.
        ///
        /// `Debug` output shows the raw object address, which is handy when logging handles.
        #[derive(Debug)]
        pub struct $name {
            ptr: *mut c_void,
        }

        // SAFETY: The handle owns exactly one +1 retain on the object and never hands out
        // that ownership, so sending the handle to another thread only moves the retain.
        // NOTE(review): assumes `ffi::mps_object_release` may be called from any thread
        // (Objective-C retain/release is atomic) — confirm in the Swift shim.
        unsafe impl Send for $name {}
        // SAFETY: Through `&self`, the wrappers in this file only pass the pointer to FFI
        // property getters or encode calls.
        // NOTE(review): assumes the wrapped MPS object tolerates concurrent `&self` use;
        // Apple documents `MPSKernel` objects as only partially thread-safe — confirm per type.
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                if !self.ptr.is_null() {
                    // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
                    unsafe { ffi::mps_object_release(self.ptr) };
                    // Defensive: leave no dangling pointer behind in the field.
                    self.ptr = ptr::null_mut();
                }
            }
        }

        impl $name {
            /// Raw object pointer for FFI calls; ownership stays with `self`, so the pointer
            /// must not be released or used after `self` is dropped.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
156
157opaque_handle!(Matrix);
158impl Matrix {
159    /// Wrap an existing `MTLBuffer` as an `MPSMatrix`.
160    #[must_use]
161    pub fn new_with_buffer(buffer: &MetalBuffer, descriptor: MatrixDescriptor) -> Option<Self> {
162        // SAFETY: `buffer` is a valid `MTLBuffer` wrapper and scalar parameters are POD.
163        let ptr = unsafe {
164            ffi::mps_matrix_new_with_buffer(
165                buffer.as_ptr(),
166                descriptor.rows,
167                descriptor.columns,
168                descriptor.matrices,
169                descriptor.row_bytes,
170                descriptor.matrix_bytes,
171                descriptor.data_type,
172            )
173        };
174        if ptr.is_null() {
175            None
176        } else {
177            Some(Self { ptr })
178        }
179    }
180
181    #[must_use]
182    pub fn rows(&self) -> usize {
183        // SAFETY: `self.ptr` is a valid `MPSMatrix` pointer while `self` is alive.
184        unsafe { ffi::mps_matrix_rows(self.ptr) }
185    }
186
187    #[must_use]
188    pub fn columns(&self) -> usize {
189        // SAFETY: `self.ptr` is a valid `MPSMatrix` pointer while `self` is alive.
190        unsafe { ffi::mps_matrix_columns(self.ptr) }
191    }
192
193    #[must_use]
194    pub fn matrices(&self) -> usize {
195        // SAFETY: `self.ptr` is a valid `MPSMatrix` pointer while `self` is alive.
196        unsafe { ffi::mps_matrix_matrices(self.ptr) }
197    }
198
199    #[must_use]
200    pub fn row_bytes(&self) -> usize {
201        // SAFETY: `self.ptr` is a valid `MPSMatrix` pointer while `self` is alive.
202        unsafe { ffi::mps_matrix_row_bytes(self.ptr) }
203    }
204
205    #[must_use]
206    pub fn matrix_bytes(&self) -> usize {
207        // SAFETY: `self.ptr` is a valid `MPSMatrix` pointer while `self` is alive.
208        unsafe { ffi::mps_matrix_matrix_bytes(self.ptr) }
209    }
210
211    #[must_use]
212    pub fn data_type(&self) -> u32 {
213        // SAFETY: `self.ptr` is a valid `MPSMatrix` pointer while `self` is alive.
214        unsafe { ffi::mps_matrix_data_type(self.ptr) }
215    }
216}
217
218opaque_handle!(Vector);
219impl Vector {
220    /// Wrap an existing `MTLBuffer` as an `MPSVector`.
221    #[must_use]
222    pub fn new_with_buffer(buffer: &MetalBuffer, descriptor: VectorDescriptor) -> Option<Self> {
223        // SAFETY: `buffer` is a valid `MTLBuffer` wrapper and scalar parameters are POD.
224        let ptr = unsafe {
225            ffi::mps_vector_new_with_buffer(
226                buffer.as_ptr(),
227                descriptor.length,
228                descriptor.vectors,
229                descriptor.vector_bytes,
230                descriptor.data_type,
231            )
232        };
233        if ptr.is_null() {
234            None
235        } else {
236            Some(Self { ptr })
237        }
238    }
239
240    #[must_use]
241    pub fn length(&self) -> usize {
242        // SAFETY: `self.ptr` is a valid `MPSVector` pointer while `self` is alive.
243        unsafe { ffi::mps_vector_length(self.ptr) }
244    }
245
246    #[must_use]
247    pub fn vectors(&self) -> usize {
248        // SAFETY: `self.ptr` is a valid `MPSVector` pointer while `self` is alive.
249        unsafe { ffi::mps_vector_vectors(self.ptr) }
250    }
251
252    #[must_use]
253    pub fn vector_bytes(&self) -> usize {
254        // SAFETY: `self.ptr` is a valid `MPSVector` pointer while `self` is alive.
255        unsafe { ffi::mps_vector_vector_bytes(self.ptr) }
256    }
257
258    #[must_use]
259    pub fn data_type(&self) -> u32 {
260        // SAFETY: `self.ptr` is a valid `MPSVector` pointer while `self` is alive.
261        unsafe { ffi::mps_vector_data_type(self.ptr) }
262    }
263}
264
/// Plain-Rust configuration for `MPSMatrixMultiplication`.
///
/// Describes the GEMM `C = alpha * op(A) * op(B) + beta * C`, where `op(X)` is `X` or its
/// transpose depending on the corresponding `transpose_*` flag.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MatrixMultiplicationDescriptor {
    /// Use the transpose of the left operand `A`.
    pub transpose_left: bool,
    /// Use the transpose of the right operand `B`.
    pub transpose_right: bool,
    /// Number of rows in the result `C` (and in `op(A)`).
    pub result_rows: usize,
    /// Number of columns in the result `C` (and in `op(B)`).
    pub result_columns: usize,
    /// Shared dimension: columns of `op(A)` and rows of `op(B)`.
    pub interior_columns: usize,
    /// Scale applied to the product `op(A) * op(B)`.
    pub alpha: f64,
    /// Scale applied to the existing contents of `C` before the product is added.
    pub beta: f64,
}

impl MatrixMultiplicationDescriptor {
    /// Construct the common `C = A * B` descriptor: no transposes, `alpha = 1`, `beta = 0`.
    #[must_use]
    pub const fn new(result_rows: usize, result_columns: usize, interior_columns: usize) -> Self {
        // Delegate so the field list lives in exactly one constructor.
        Self::with_options(false, false, result_rows, result_columns, interior_columns, 1.0, 0.0)
    }

    /// Construct a fully configurable descriptor; every field is taken as given.
    #[must_use]
    pub const fn with_options(
        transpose_left: bool,
        transpose_right: bool,
        result_rows: usize,
        result_columns: usize,
        interior_columns: usize,
        alpha: f64,
        beta: f64,
    ) -> Self {
        Self {
            transpose_left,
            transpose_right,
            result_rows,
            result_columns,
            interior_columns,
            alpha,
            beta,
        }
    }
}
314
315opaque_handle!(MatrixMultiplication);
316impl MatrixMultiplication {
317    /// Build a configurable GEMM kernel with optional transposition and scaling.
318    #[must_use]
319    pub fn new(device: &MetalDevice, descriptor: MatrixMultiplicationDescriptor) -> Option<Self> {
320        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
321        let ptr = unsafe {
322            ffi::mps_matrix_multiplication_new(
323                device.as_ptr(),
324                descriptor.transpose_left,
325                descriptor.transpose_right,
326                descriptor.result_rows,
327                descriptor.result_columns,
328                descriptor.interior_columns,
329                descriptor.alpha,
330                descriptor.beta,
331            )
332        };
333        if ptr.is_null() {
334            None
335        } else {
336            Some(Self { ptr })
337        }
338    }
339
340    /// Convenience constructor for the common `C = A * B` case.
341    #[must_use]
342    pub fn new_simple(
343        device: &MetalDevice,
344        result_rows: usize,
345        result_columns: usize,
346        interior_columns: usize,
347    ) -> Option<Self> {
348        Self::new(
349            device,
350            MatrixMultiplicationDescriptor::new(result_rows, result_columns, interior_columns),
351        )
352    }
353
354    /// Encode the matrix multiplication onto a command buffer.
355    pub fn encode(
356        &self,
357        command_buffer: &CommandBuffer,
358        left: &Matrix,
359        right: &Matrix,
360        result: &Matrix,
361    ) {
362        // SAFETY: All handles come from safe wrappers and remain alive for the call.
363        unsafe {
364            ffi::mps_matrix_multiplication_encode(
365                self.ptr,
366                command_buffer.as_ptr(),
367                left.as_ptr(),
368                right.as_ptr(),
369                result.as_ptr(),
370            );
371        };
372    }
373}