use crate::ffi;
use apple_metal::{CommandBuffer, MetalBuffer, MetalDevice};
use core::ffi::c_void;
use core::ptr;
/// Element-type codes handed to the MPS FFI layer (mirroring `MPSDataType`).
///
/// Every code is a category flag in the high bits OR-ed with the element
/// width, in bits, in the low bits.
pub mod data_type {
    /// Category flag for floating-point element types.
    const FLOAT_FLAG: u32 = 0x1000_0000;
    /// Category flag for signed integer element types.
    const SIGNED_FLAG: u32 = 0x2000_0000;
    /// Category flag for normalized element types.
    const NORMALIZED_FLAG: u32 = 0x4000_0000;

    /// Unset / unknown element type.
    pub const INVALID: u32 = 0;
    /// 32-bit floating point.
    pub const FLOAT32: u32 = FLOAT_FLAG | 32;
    /// 16-bit floating point.
    pub const FLOAT16: u32 = FLOAT_FLAG | 16;
    /// 8-bit signed integer.
    pub const INT8: u32 = SIGNED_FLAG | 8;
    /// 16-bit signed integer.
    pub const INT16: u32 = SIGNED_FLAG | 16;
    /// 32-bit signed integer.
    pub const INT32: u32 = SIGNED_FLAG | 32;
    /// 8-bit unsigned integer (no category flag).
    pub const UINT8: u32 = 8;
    /// 16-bit unsigned integer (no category flag).
    pub const UINT16: u32 = 16;
    /// 32-bit unsigned integer (no category flag).
    pub const UINT32: u32 = 32;
    /// 8-bit unsigned normalized value.
    pub const UNORM8: u32 = NORMALIZED_FLAG | 8;
}

/// Returns the size in bytes of a single element of `data_type`.
///
/// Yields `None` for any code that is not one of the sized constants in
/// [`data_type`] (including [`data_type::INVALID`]).
#[must_use]
pub const fn data_type_size(data_type: u32) -> Option<usize> {
    let bytes = match data_type {
        data_type::INT8 | data_type::UINT8 | data_type::UNORM8 => 1,
        data_type::FLOAT16 | data_type::INT16 | data_type::UINT16 => 2,
        data_type::FLOAT32 | data_type::INT32 | data_type::UINT32 => 4,
        _ => return None,
    };
    Some(bytes)
}
/// Layout of one or more matrices stored in a Metal buffer.
///
/// The fields are forwarded unchanged to the MPS FFI layer by
/// [`Matrix::new_with_buffer`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MatrixDescriptor {
    /// Number of rows in each matrix.
    pub rows: usize,
    /// Number of columns in each matrix.
    pub columns: usize,
    /// Number of matrices described.
    pub matrices: usize,
    /// Bytes from the start of one row to the start of the next (row stride).
    pub row_bytes: usize,
    /// Bytes from the start of one matrix to the start of the next.
    pub matrix_bytes: usize,
    /// Element type; expected to be one of the [`data_type`] constants.
    pub data_type: u32,
}

impl MatrixDescriptor {
    /// Builds a descriptor with explicit row and matrix strides.
    ///
    /// No validation is performed; the values are stored exactly as given.
    #[must_use]
    pub const fn with_strides(
        rows: usize,
        columns: usize,
        matrices: usize,
        row_bytes: usize,
        matrix_bytes: usize,
        data_type: u32,
    ) -> Self {
        Self {
            rows,
            columns,
            matrices,
            row_bytes,
            matrix_bytes,
            data_type,
        }
    }

    /// Builds a descriptor for a single, densely packed `rows` x `columns`
    /// matrix (no padding between rows).
    ///
    /// Returns `None` if `data_type` has no known element size (see
    /// [`data_type_size`]) or if a byte size overflows `usize`.
    #[must_use]
    pub fn contiguous(rows: usize, columns: usize, data_type: u32) -> Option<Self> {
        let element_size = data_type_size(data_type)?;
        let row_bytes = columns.checked_mul(element_size)?;
        let matrix_bytes = rows.checked_mul(row_bytes)?;
        Some(Self::with_strides(
            rows,
            columns,
            1,
            row_bytes,
            matrix_bytes,
            data_type,
        ))
    }

    /// Returns the row stride, in bytes, recommended by the MPS FFI layer for
    /// `columns` elements of `data_type`.
    #[must_use]
    pub fn recommended_row_bytes(columns: usize, data_type: u32) -> usize {
        // SAFETY: only plain integers cross the FFI boundary; no pointers are passed.
        unsafe { ffi::mps_matrix_descriptor_row_bytes_for_columns(columns, data_type) }
    }
}
/// Layout of one or more vectors stored in a Metal buffer.
///
/// The fields are forwarded unchanged to the MPS FFI layer by
/// [`Vector::new_with_buffer`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct VectorDescriptor {
    /// Number of elements in each vector.
    pub length: usize,
    /// Number of vectors described.
    pub vectors: usize,
    /// Bytes from the start of one vector to the start of the next.
    pub vector_bytes: usize,
    /// Element type; expected to be one of the [`data_type`] constants.
    pub data_type: u32,
}

impl VectorDescriptor {
    /// Builds a descriptor with an explicit per-vector stride.
    ///
    /// No validation is performed; the values are stored exactly as given.
    #[must_use]
    pub const fn with_stride(
        length: usize,
        vectors: usize,
        vector_bytes: usize,
        data_type: u32,
    ) -> Self {
        Self {
            length,
            vectors,
            vector_bytes,
            data_type,
        }
    }

    /// Builds a descriptor for a single, densely packed vector of `length`
    /// elements.
    ///
    /// Returns `None` if `data_type` has no known element size (see
    /// [`data_type_size`]) or if the byte size overflows `usize`.
    #[must_use]
    pub fn contiguous(length: usize, data_type: u32) -> Option<Self> {
        let element_size = data_type_size(data_type)?;
        let vector_bytes = length.checked_mul(element_size)?;
        Some(Self::with_stride(length, 1, vector_bytes, data_type))
    }

    /// Returns the per-vector stride, in bytes, recommended by the MPS FFI
    /// layer for `length` elements of `data_type`.
    #[must_use]
    pub fn recommended_vector_bytes(length: usize, data_type: u32) -> usize {
        // SAFETY: only plain integers cross the FFI boundary; no pointers are passed.
        unsafe { ffi::mps_vector_descriptor_vector_bytes_for_length(length, data_type) }
    }
}
/// Declares an owning wrapper type `$name` around a raw MPS object pointer.
///
/// The generated type releases its pointer through `ffi::mps_object_release`
/// when dropped (a null pointer is skipped) and exposes the raw pointer via
/// `as_ptr`.
macro_rules! opaque_handle {
    ($name:ident) => {
        /// Owning handle to an opaque MPS object; released on drop.
        pub struct $name {
            /// Raw object pointer owned by this handle.
            ptr: *mut c_void,
        }

        // SAFETY: NOTE(review): these impls assert that the wrapped MPS object
        // may be moved to and shared between threads. Nothing in this file
        // establishes that; confirm against Apple's MPS thread-safety guidance.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                let raw = self.ptr;
                self.ptr = ptr::null_mut();
                if raw.is_null() {
                    return;
                }
                // SAFETY: `raw` is the non-null pointer this handle owned; the
                // field was cleared above, so it is released exactly once.
                unsafe { ffi::mps_object_release(raw) };
            }
        }

        impl $name {
            /// Returns the raw object pointer without transferring ownership.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
opaque_handle!(Matrix);
impl Matrix {
#[must_use]
pub fn new_with_buffer(buffer: &MetalBuffer, descriptor: MatrixDescriptor) -> Option<Self> {
let ptr = unsafe {
ffi::mps_matrix_new_with_buffer(
buffer.as_ptr(),
descriptor.rows,
descriptor.columns,
descriptor.matrices,
descriptor.row_bytes,
descriptor.matrix_bytes,
descriptor.data_type,
)
};
if ptr.is_null() {
None
} else {
Some(Self { ptr })
}
}
#[must_use]
pub fn rows(&self) -> usize {
unsafe { ffi::mps_matrix_rows(self.ptr) }
}
#[must_use]
pub fn columns(&self) -> usize {
unsafe { ffi::mps_matrix_columns(self.ptr) }
}
#[must_use]
pub fn matrices(&self) -> usize {
unsafe { ffi::mps_matrix_matrices(self.ptr) }
}
#[must_use]
pub fn row_bytes(&self) -> usize {
unsafe { ffi::mps_matrix_row_bytes(self.ptr) }
}
#[must_use]
pub fn matrix_bytes(&self) -> usize {
unsafe { ffi::mps_matrix_matrix_bytes(self.ptr) }
}
#[must_use]
pub fn data_type(&self) -> u32 {
unsafe { ffi::mps_matrix_data_type(self.ptr) }
}
}
opaque_handle!(Vector);
impl Vector {
#[must_use]
pub fn new_with_buffer(buffer: &MetalBuffer, descriptor: VectorDescriptor) -> Option<Self> {
let ptr = unsafe {
ffi::mps_vector_new_with_buffer(
buffer.as_ptr(),
descriptor.length,
descriptor.vectors,
descriptor.vector_bytes,
descriptor.data_type,
)
};
if ptr.is_null() {
None
} else {
Some(Self { ptr })
}
}
#[must_use]
pub fn length(&self) -> usize {
unsafe { ffi::mps_vector_length(self.ptr) }
}
#[must_use]
pub fn vectors(&self) -> usize {
unsafe { ffi::mps_vector_vectors(self.ptr) }
}
#[must_use]
pub fn vector_bytes(&self) -> usize {
unsafe { ffi::mps_vector_vector_bytes(self.ptr) }
}
#[must_use]
pub fn data_type(&self) -> u32 {
unsafe { ffi::mps_vector_data_type(self.ptr) }
}
}
/// Parameters for building a [`MatrixMultiplication`] kernel.
///
/// NOTE(review): the fields are forwarded to `ffi::mps_matrix_multiplication_new`;
/// they presumably follow `MPSMatrixMultiplication` semantics
/// (`C = alpha * op(A) * op(B) + beta * C`). Confirm against the FFI shim.
///
/// Only `PartialEq` (not `Eq`/`Hash`) is derived because `alpha`/`beta` are `f64`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MatrixMultiplicationDescriptor {
    /// Whether the left operand is transposed.
    pub transpose_left: bool,
    /// Whether the right operand is transposed.
    pub transpose_right: bool,
    /// Number of rows in the result.
    pub result_rows: usize,
    /// Number of columns in the result.
    pub result_columns: usize,
    /// Shared (inner) dimension of the two operands.
    pub interior_columns: usize,
    /// Scale factor for the product.
    pub alpha: f64,
    /// Scale factor for the existing result contents.
    pub beta: f64,
}

impl MatrixMultiplicationDescriptor {
    /// Descriptor with no transposes, `alpha = 1.0` and `beta = 0.0`.
    #[must_use]
    pub const fn new(result_rows: usize, result_columns: usize, interior_columns: usize) -> Self {
        Self {
            transpose_left: false,
            transpose_right: false,
            result_rows,
            result_columns,
            interior_columns,
            alpha: 1.0,
            beta: 0.0,
        }
    }

    /// Descriptor with every field supplied explicitly.
    #[must_use]
    pub const fn with_options(
        transpose_left: bool,
        transpose_right: bool,
        result_rows: usize,
        result_columns: usize,
        interior_columns: usize,
        alpha: f64,
        beta: f64,
    ) -> Self {
        Self {
            transpose_left,
            transpose_right,
            result_rows,
            result_columns,
            interior_columns,
            alpha,
            beta,
        }
    }
}
opaque_handle!(MatrixMultiplication);
impl MatrixMultiplication {
#[must_use]
pub fn new(device: &MetalDevice, descriptor: MatrixMultiplicationDescriptor) -> Option<Self> {
let ptr = unsafe {
ffi::mps_matrix_multiplication_new(
device.as_ptr(),
descriptor.transpose_left,
descriptor.transpose_right,
descriptor.result_rows,
descriptor.result_columns,
descriptor.interior_columns,
descriptor.alpha,
descriptor.beta,
)
};
if ptr.is_null() {
None
} else {
Some(Self { ptr })
}
}
#[must_use]
pub fn new_simple(
device: &MetalDevice,
result_rows: usize,
result_columns: usize,
interior_columns: usize,
) -> Option<Self> {
Self::new(
device,
MatrixMultiplicationDescriptor::new(result_rows, result_columns, interior_columns),
)
}
pub fn encode(
&self,
command_buffer: &CommandBuffer,
left: &Matrix,
right: &Matrix,
result: &Matrix,
) {
unsafe {
ffi::mps_matrix_multiplication_encode(
self.ptr,
command_buffer.as_ptr(),
left.as_ptr(),
right.as_ptr(),
result.as_ptr(),
);
};
}
}