//! apple-mps 0.2.1
//!
//! Safe Rust bindings for Apple's MetalPerformanceShaders framework on macOS, backed by a
//! Swift bridge.
use crate::ffi;
use apple_metal::{CommandBuffer as MetalCommandBuffer, MetalBuffer, MetalDevice};
use core::ffi::c_void;
use core::ptr;

/// Declares an owning wrapper type around an opaque object handle returned by the Swift bridge.
///
/// The generated type stores the raw handle, releases it through `ffi::mps_object_release`
/// when dropped, and exposes the raw handle via `as_ptr` without transferring ownership.
macro_rules! opaque_handle {
    ($name:ident) => {
        /// Owning wrapper around an opaque handle produced by the Swift bridge.
        ///
        /// The handle is released with `ffi::mps_object_release` when this value is dropped.
        pub struct $name {
            ptr: *mut c_void,
        }

        // SAFETY: NOTE(review): these impls assert that the bridged object may be moved to and
        // shared between threads. Nothing in this wrapper synchronizes access, and some wrapper
        // types expose setters through `&self`; confirm the underlying objects tolerate that.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                let handle = self.ptr;
                if handle.is_null() {
                    // Nothing was ever adopted (or it was already released): nothing to do.
                    return;
                }
                unsafe { ffi::mps_object_release(handle) };
                // Clear the field so the released handle is never observed again.
                self.ptr = ptr::null_mut();
            }
        }

        impl $name {
            /// Returns the raw bridge handle; ownership stays with `self`.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}

opaque_handle!(NDArrayDescriptor);
impl NDArrayDescriptor {
    /// Creates a descriptor from an element type and a list of dimension sizes.
    ///
    /// `data_type` is forwarded unchanged to the bridge (presumably a raw `MPSDataType`
    /// value — confirm). The whole `dimension_sizes` slice is handed over as a count plus a
    /// pointer. Returns `None` when the bridge returns a null handle.
    #[must_use]
    pub fn with_dimension_sizes(data_type: u32, dimension_sizes: &[usize]) -> Option<Self> {
        let count = dimension_sizes.len();
        let sizes = dimension_sizes.as_ptr();
        let handle = unsafe {
            ffi::mps_ndarray_descriptor_new_with_dimension_sizes(data_type, count, sizes)
        };
        if handle.is_null() {
            return None;
        }
        Some(Self { ptr: handle })
    }

    /// Returns the element data type recorded in the descriptor.
    #[must_use]
    pub fn data_type(&self) -> u32 {
        let handle = self.as_ptr();
        unsafe { ffi::mps_ndarray_descriptor_data_type(handle) }
    }

    /// Overwrites the element data type recorded in the descriptor.
    pub fn set_data_type(&self, data_type: u32) {
        let handle = self.as_ptr();
        unsafe { ffi::mps_ndarray_descriptor_set_data_type(handle, data_type) }
    }

    /// Returns how many dimensions the descriptor currently has.
    #[must_use]
    pub fn number_of_dimensions(&self) -> usize {
        let handle = self.as_ptr();
        unsafe { ffi::mps_ndarray_descriptor_number_of_dimensions(handle) }
    }

    /// Sets how many dimensions the descriptor has.
    pub fn set_number_of_dimensions(&self, number_of_dimensions: usize) {
        let handle = self.as_ptr();
        unsafe { ffi::mps_ndarray_descriptor_set_number_of_dimensions(handle, number_of_dimensions) }
    }

    /// Returns the length of dimension `dimension_index`.
    ///
    /// The index is not validated here; it is passed straight to the bridge.
    #[must_use]
    pub fn length_of_dimension(&self, dimension_index: usize) -> usize {
        let handle = self.as_ptr();
        unsafe { ffi::mps_ndarray_descriptor_length_of_dimension(handle, dimension_index) }
    }

    /// Replaces the descriptor's shape with `dimension_sizes`.
    pub fn reshape_with_dimension_sizes(&self, dimension_sizes: &[usize]) {
        let handle = self.as_ptr();
        let count = dimension_sizes.len();
        let sizes = dimension_sizes.as_ptr();
        unsafe { ffi::mps_ndarray_descriptor_reshape_with_dimension_sizes(handle, count, sizes) }
    }

    /// Swaps dimension `dimension_index` with dimension `other_dimension_index`.
    pub fn transpose_dimension(&self, dimension_index: usize, other_dimension_index: usize) {
        let (first, second) = (dimension_index, other_dimension_index);
        let handle = self.as_ptr();
        unsafe { ffi::mps_ndarray_descriptor_transpose_dimension(handle, first, second) }
    }
}

opaque_handle!(NDArray);
impl NDArray {
    /// Allocates a new array on `device` whose type and shape come from `descriptor`.
    ///
    /// Returns `None` when the bridge returns a null handle.
    #[must_use]
    pub fn new(device: &MetalDevice, descriptor: &NDArrayDescriptor) -> Option<Self> {
        let device_ptr = device.as_ptr();
        let descriptor_ptr = descriptor.as_ptr();
        let handle = unsafe { ffi::mps_ndarray_new_with_descriptor(device_ptr, descriptor_ptr) };
        if handle.is_null() {
            return None;
        }
        Some(Self { ptr: handle })
    }

    /// Creates a scalar array on `device` from `value` via the bridge.
    ///
    /// Returns `None` when the bridge returns a null handle.
    #[must_use]
    pub fn scalar(device: &MetalDevice, value: f64) -> Option<Self> {
        let device_ptr = device.as_ptr();
        let handle = unsafe { ffi::mps_ndarray_new_scalar(device_ptr, value) };
        if handle.is_null() {
            return None;
        }
        Some(Self { ptr: handle })
    }

    /// Creates an array backed by `buffer`, starting at `offset` (presumably a byte offset —
    /// confirm against the bridge), with the layout given by `descriptor`.
    ///
    /// Returns `None` when the bridge returns a null handle.
    #[must_use]
    pub fn new_with_buffer(
        buffer: &MetalBuffer,
        offset: usize,
        descriptor: &NDArrayDescriptor,
    ) -> Option<Self> {
        let buffer_ptr = buffer.as_ptr();
        let descriptor_ptr = descriptor.as_ptr();
        let handle = unsafe { ffi::mps_ndarray_new_with_buffer(buffer_ptr, offset, descriptor_ptr) };
        if handle.is_null() {
            return None;
        }
        Some(Self { ptr: handle })
    }

    /// Returns the element data type of the array.
    #[must_use]
    pub fn data_type(&self) -> u32 {
        let handle = self.as_ptr();
        unsafe { ffi::mps_ndarray_data_type(handle) }
    }

    /// Returns how many dimensions the array has.
    #[must_use]
    pub fn number_of_dimensions(&self) -> usize {
        let handle = self.as_ptr();
        unsafe { ffi::mps_ndarray_number_of_dimensions(handle) }
    }

    /// Returns the length of dimension `dimension_index`.
    ///
    /// The index is not validated here; it is passed straight to the bridge.
    #[must_use]
    pub fn length_of_dimension(&self, dimension_index: usize) -> usize {
        let handle = self.as_ptr();
        unsafe { ffi::mps_ndarray_length_of_dimension(handle, dimension_index) }
    }

    /// Returns a descriptor for this array, owned (and released on drop) by the returned
    /// wrapper.
    ///
    /// Returns `None` when the bridge returns a null handle.
    #[must_use]
    pub fn descriptor(&self) -> Option<NDArrayDescriptor> {
        let handle = self.as_ptr();
        let descriptor_ptr = unsafe { ffi::mps_ndarray_descriptor(handle) };
        if descriptor_ptr.is_null() {
            return None;
        }
        Some(NDArrayDescriptor { ptr: descriptor_ptr })
    }

    /// Returns the resource size the bridge reports for this array (presumably in bytes —
    /// confirm).
    #[must_use]
    pub fn resource_size(&self) -> usize {
        let handle = self.as_ptr();
        unsafe { ffi::mps_ndarray_resource_size(handle) }
    }
}

opaque_handle!(NDArrayIdentity);
impl NDArrayIdentity {
    #[must_use]
    pub fn new(device: &MetalDevice) -> Option<Self> {
        let ptr = unsafe { ffi::mps_ndarray_identity_new(device.as_ptr()) };
        if ptr.is_null() {
            None
        } else {
            Some(Self { ptr })
        }
    }

    #[must_use]
    pub fn reshape(&self, source: &NDArray, dimension_sizes: &[usize]) -> Option<NDArray> {
        let ptr = unsafe {
            ffi::mps_ndarray_identity_reshape(
                self.ptr,
                ptr::null_mut(),
                source.as_ptr(),
                dimension_sizes.len(),
                dimension_sizes.as_ptr(),
                ptr::null_mut(),
            )
        };
        if ptr.is_null() {
            None
        } else {
            Some(NDArray { ptr })
        }
    }

    #[must_use]
    pub fn reshape_with_command_buffer(
        &self,
        command_buffer: &MetalCommandBuffer,
        source: &NDArray,
        dimension_sizes: &[usize],
    ) -> Option<NDArray> {
        let ptr = unsafe {
            ffi::mps_ndarray_identity_reshape(
                self.ptr,
                command_buffer.as_ptr(),
                source.as_ptr(),
                dimension_sizes.len(),
                dimension_sizes.as_ptr(),
                ptr::null_mut(),
            )
        };
        if ptr.is_null() {
            None
        } else {
            Some(NDArray { ptr })
        }
    }

    pub fn reshape_into(
        &self,
        command_buffer: Option<&MetalCommandBuffer>,
        source: &NDArray,
        dimension_sizes: &[usize],
        destination: &NDArray,
    ) -> bool {
        let command_buffer_ptr = command_buffer.map_or(ptr::null_mut(), MetalCommandBuffer::as_ptr);
        let ptr = unsafe {
            ffi::mps_ndarray_identity_reshape(
                self.ptr,
                command_buffer_ptr,
                source.as_ptr(),
                dimension_sizes.len(),
                dimension_sizes.as_ptr(),
                destination.as_ptr(),
            )
        };
        !ptr.is_null()
    }
}

opaque_handle!(NDArrayMatrixMultiplication);
impl NDArrayMatrixMultiplication {
    #[must_use]
    pub fn new(device: &MetalDevice, source_count: usize) -> Option<Self> {
        let ptr = unsafe { ffi::mps_ndarray_matrix_multiplication_new(device.as_ptr(), source_count) };
        if ptr.is_null() {
            None
        } else {
            Some(Self { ptr })
        }
    }

    #[must_use]
    pub fn alpha(&self) -> f64 {
        unsafe { ffi::mps_ndarray_matrix_multiplication_alpha(self.ptr) }
    }

    pub fn set_alpha(&self, alpha: f64) {
        unsafe { ffi::mps_ndarray_matrix_multiplication_set_alpha(self.ptr, alpha) };
    }

    #[must_use]
    pub fn beta(&self) -> f64 {
        unsafe { ffi::mps_ndarray_matrix_multiplication_beta(self.ptr) }
    }

    pub fn set_beta(&self, beta: f64) {
        unsafe { ffi::mps_ndarray_matrix_multiplication_set_beta(self.ptr, beta) };
    }

    #[must_use]
    pub fn encode(
        &self,
        command_buffer: &MetalCommandBuffer,
        source_arrays: &[&NDArray],
    ) -> Option<NDArray> {
        let handles: Vec<_> = source_arrays.iter().map(|array| array.as_ptr()).collect();
        let handles_ptr = if handles.is_empty() {
            ptr::null()
        } else {
            handles.as_ptr()
        };
        let ptr = unsafe {
            ffi::mps_ndarray_matrix_multiplication_encode(
                self.ptr,
                command_buffer.as_ptr(),
                source_arrays.len(),
                handles_ptr,
            )
        };
        if ptr.is_null() {
            None
        } else {
            Some(NDArray { ptr })
        }
    }

    pub fn encode_to_destination(
        &self,
        command_buffer: &MetalCommandBuffer,
        source_arrays: &[&NDArray],
        destination: &NDArray,
    ) {
        let handles: Vec<_> = source_arrays.iter().map(|array| array.as_ptr()).collect();
        let handles_ptr = if handles.is_empty() {
            ptr::null()
        } else {
            handles.as_ptr()
        };
        unsafe {
            ffi::mps_ndarray_matrix_multiplication_encode_to_destination(
                self.ptr,
                command_buffer.as_ptr(),
                source_arrays.len(),
                handles_ptr,
                destination.as_ptr(),
            );
        };
    }
}