#![cfg(all(feature = "metal", target_os = "macos"))]
#![allow(dead_code)]
#![allow(deprecated)]
use crate::gpu::GpuError;
use std::sync::Arc;
#[cfg(all(feature = "metal", target_os = "macos"))]
use objc2_metal::{MTLBuffer, MTLCommandQueue, MTLDevice};
#[cfg(all(feature = "metal", target_os = "macos"))]
use objc2_metal_performance_shaders::{
MPSDataType as MPSDataTypeEnum, MPSMatrix, MPSMatrixDescriptor, MPSMatrixMultiplication,
};
#[cfg(all(feature = "metal", target_os = "macos"))]
use objc2::runtime::ProtocolObject;
#[cfg(all(feature = "metal", target_os = "macos"))]
use objc2::rc::Retained;
#[cfg(all(feature = "metal", target_os = "macos"))]
use objc2::{msg_send, msg_send_id, ClassType};
#[cfg(all(feature = "metal", target_os = "macos"))]
use objc2::runtime::AnyObject;
// Placeholder stand-ins for the Metal protocol types, so that the
// `#[cfg(not(...))]` signatures in this file can be written without objc2.
// NOTE(review): the inner `#![cfg(all(feature = "metal", target_os = "macos"))]`
// at the top of this file appears to exclude the whole module from non-Metal
// builds, which would make these aliases (and every `cfg(not(...))` item here)
// unreachable — confirm before relying on them.
#[cfg(not(all(feature = "metal", target_os = "macos")))]
type MTLDevice = ();
#[cfg(not(all(feature = "metal", target_os = "macos")))]
type MTLCommandQueue = ();
#[cfg(not(all(feature = "metal", target_os = "macos")))]
type MTLBuffer = ();
/// Shared Metal state (a device and a command queue) used to create MPS
/// kernels and to submit work for them.
///
/// Held behind an `Arc` by `MPSOperations`, `MPSConvolution`, `MPSPooling`
/// and `MPSImageOps`.
pub struct MPSContext {
/// Device on which every MPS kernel created through this context is initialised.
#[cfg(all(feature = "metal", target_os = "macos"))]
device: Retained<ProtocolObject<dyn MTLDevice>>,
/// Queue from which `create_command_buffer` obtains command buffers.
#[cfg(all(feature = "metal", target_os = "macos"))]
command_queue: Retained<ProtocolObject<dyn MTLCommandQueue>>,
#[cfg(not(all(feature = "metal", target_os = "macos")))]
device: MTLDevice,
#[cfg(not(all(feature = "metal", target_os = "macos")))]
command_queue: MTLCommandQueue,
}
// SAFETY: the only state is a retained Metal device and command queue. This
// relies on those Metal objects being usable from multiple threads (Apple's
// Metal guidance describes devices and command queues as thread-safe).
// NOTE(review): confirm against the objc2-metal version in use — it may already
// provide Send/Sync for these protocol objects, making these impls redundant.
unsafe impl Send for MPSContext {}
unsafe impl Sync for MPSContext {}
impl MPSContext {
    /// Creates a context from an existing Metal device and command queue.
    ///
    /// Kernels built by this context are initialised on `device`; command
    /// buffers are obtained from `command_queue`.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn new(
        device: Retained<ProtocolObject<dyn MTLDevice>>,
        command_queue: Retained<ProtocolObject<dyn MTLCommandQueue>>,
    ) -> Self {
        Self {
            device,
            command_queue,
        }
    }

    /// Placeholder constructor for builds without Metal support.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn new(device: MTLDevice, command_queue: MTLCommandQueue) -> Self {
        Self {
            device,
            command_queue,
        }
    }

    /// Turns a possibly-`nil` Objective-C result into a `Result`.
    ///
    /// Used so that a failed initializer is reported as
    /// `GpuError::Other("Failed to create <what>")` instead of panicking
    /// inside `msg_send_id!` (which panics when a non-`Option` return is nil).
    #[cfg(all(feature = "metal", target_os = "macos"))]
    fn require_created<T>(object: Option<T>, what: &str) -> Result<T, GpuError> {
        object.ok_or_else(|| GpuError::Other(format!("Failed to create {what}")))
    }

    /// Creates an `MPSMatrixMultiplication` kernel on this context's device.
    ///
    /// Per Apple's documentation the kernel computes
    /// `C = alpha * op(A) * op(B) + beta * C`, where `op` optionally transposes
    /// its operand according to `transpose_left` / `transpose_right`.
    ///
    /// # Errors
    /// Returns [`GpuError::Other`] if the initializer returns `nil`.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn create_matmul(
        &self,
        transpose_left: bool,
        transpose_right: bool,
        result_rows: usize,
        result_columns: usize,
        interior_columns: usize,
        alpha: f64,
        beta: f64,
    ) -> Result<Retained<MPSMatrixMultiplication>, GpuError> {
        // SAFETY: `alloc` is immediately followed by the designated initializer
        // `initWithDevice:transposeLeft:transposeRight:resultRows:resultColumns:
        // interiorColumns:alpha:beta:`; the argument types (device, BOOL x2,
        // NSUInteger x3, double x2) match the values passed.
        let matmul: Option<Retained<MPSMatrixMultiplication>> = unsafe {
            let alloc: objc2::rc::Allocated<MPSMatrixMultiplication> =
                msg_send_id![MPSMatrixMultiplication::class(), alloc];
            msg_send_id![
                alloc,
                initWithDevice: &*self.device,
                transposeLeft: transpose_left,
                transposeRight: transpose_right,
                resultRows: result_rows,
                resultColumns: result_columns,
                interiorColumns: interior_columns,
                alpha: alpha,
                beta: beta
            ]
        };
        Self::require_created(matmul, "MPSMatrixMultiplication")
    }

    /// Non-Metal stub: always returns an error.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn create_matmul(
        &self,
        _transpose_left: bool,
        _transpose_right: bool,
        _result_rows: usize,
        _result_columns: usize,
        _interior_columns: usize,
        _alpha: f64,
        _beta: f64,
    ) -> Result<(), GpuError> {
        Err(GpuError::Other(
            "Metal not available on this platform".to_string(),
        ))
    }

    /// Builds an `MPSMatrixDescriptor` for a `rows x columns` matrix whose
    /// consecutive rows start `row_bytes` bytes apart.
    ///
    /// # Errors
    /// Returns [`GpuError::Other`] for element types this function does not map
    /// (`Int16`, `Int8`, `UInt8`), or if MPS returns `nil`.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn create_descriptor(
        rows: usize,
        columns: usize,
        row_bytes: usize,
        datatype: MPSDataType,
    ) -> Result<Retained<MPSMatrixDescriptor>, GpuError> {
        let mps_datatype = match datatype {
            MPSDataType::Float32 => MPSDataTypeEnum::Float32,
            MPSDataType::Float16 => MPSDataTypeEnum::Float16,
            MPSDataType::Int32 => MPSDataTypeEnum::Int32,
            // Listed explicitly rather than with `_`, so a new `MPSDataType`
            // variant becomes a compile error here instead of a silent rejection.
            MPSDataType::Int16 | MPSDataType::Int8 | MPSDataType::UInt8 => {
                return Err(GpuError::Other(format!(
                    "Unsupported datatype: {datatype:?}"
                )));
            }
        };
        // SAFETY: the class factory `matrixDescriptorWithRows:columns:rowBytes:
        // dataType:` takes NSUInteger x3 plus an `MPSDataType`, matching the
        // arguments passed.
        let descriptor: Option<Retained<MPSMatrixDescriptor>> = unsafe {
            msg_send_id![
                MPSMatrixDescriptor::class(),
                matrixDescriptorWithRows: rows,
                columns: columns,
                rowBytes: row_bytes,
                dataType: mps_datatype
            ]
        };
        Self::require_created(descriptor, "MPSMatrixDescriptor")
    }

    /// Non-Metal stub: always returns an error.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn create_descriptor(
        _rows: usize,
        _columns: usize,
        _row_bytes: usize,
        _datatype: MPSDataType,
    ) -> Result<(), GpuError> {
        Err(GpuError::Other(
            "Metal not available on this platform".to_string(),
        ))
    }

    /// Wraps `buffer` in an `MPSMatrix` view laid out as described by `descriptor`.
    ///
    /// # Errors
    /// Returns [`GpuError::Other`] if the initializer returns `nil`.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn create_matrix(
        &self,
        buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        descriptor: &Retained<MPSMatrixDescriptor>,
    ) -> Result<Retained<MPSMatrix>, GpuError> {
        // SAFETY: `alloc` followed by `initWithBuffer:descriptor:`, passing an
        // MTLBuffer and an MPSMatrixDescriptor as that initializer expects.
        let matrix: Option<Retained<MPSMatrix>> = unsafe {
            let alloc: objc2::rc::Allocated<MPSMatrix> = msg_send_id![MPSMatrix::class(), alloc];
            msg_send_id![
                alloc,
                initWithBuffer: &**buffer,
                descriptor: &**descriptor
            ]
        };
        Self::require_created(matrix, "MPSMatrix")
    }

    /// Non-Metal stub: always returns an error.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn create_matrix(&self, _buffer: &MTLBuffer, _descriptor: &()) -> Result<(), GpuError> {
        Err(GpuError::Other(
            "Metal not available on this platform".to_string(),
        ))
    }

    /// Obtains a new command buffer from this context's queue, returned as an
    /// untyped Objective-C object.
    ///
    /// # Errors
    /// Returns [`GpuError::Other`] if the queue returns `nil`.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn create_command_buffer(&self) -> Result<Retained<AnyObject>, GpuError> {
        // SAFETY: `commandBuffer` takes no arguments and returns an object or nil.
        let command_buffer: Option<Retained<AnyObject>> =
            unsafe { msg_send_id![&self.command_queue, commandBuffer] };
        command_buffer.ok_or_else(|| GpuError::Other("Failed to create command buffer".to_string()))
    }

    /// Sends `commit` to `command_buffer`, submitting its encoded work.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn commit_command_buffer(&self, command_buffer: &Retained<AnyObject>) {
        // SAFETY: `commit` takes no arguments and returns void; the caller is
        // expected to pass an object obtained from `create_command_buffer`.
        unsafe {
            let _: () = msg_send![&**command_buffer, commit];
        }
    }

    /// Blocks the calling thread until `command_buffer` finishes executing.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn wait_for_command_buffer(&self, command_buffer: &Retained<AnyObject>) {
        // SAFETY: `waitUntilCompleted` takes no arguments and returns void.
        unsafe {
            let _: () = msg_send![&**command_buffer, waitUntilCompleted];
        }
    }

    /// Fails if a finished command buffer ended in `MTLCommandBufferStatusError`.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    fn ensure_command_buffer_succeeded(command_buffer: &Retained<AnyObject>) -> Result<(), GpuError> {
        // Value of `MTLCommandBufferStatusError` in the `MTLCommandBufferStatus`
        // NSUInteger enum (NotEnqueued=0 ... Completed=4, Error=5).
        const MTL_COMMAND_BUFFER_STATUS_ERROR: usize = 5;
        // SAFETY: `status` is a read-only NSUInteger property of MTLCommandBuffer.
        let status: usize = unsafe { msg_send![&**command_buffer, status] };
        if status == MTL_COMMAND_BUFFER_STATUS_ERROR {
            return Err(GpuError::Other(format!(
                "Metal command buffer finished with error status ({status})"
            )));
        }
        Ok(())
    }

    /// Encodes `result = matmul(left, right)` into `command_buffer` without
    /// committing it.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn encode_matrix_multiply(
        &self,
        command_buffer: &Retained<AnyObject>,
        left_matrix: &Retained<MPSMatrix>,
        right_matrix: &Retained<MPSMatrix>,
        result_matrix: &Retained<MPSMatrix>,
        matmul: &Retained<MPSMatrixMultiplication>,
    ) -> Result<(), GpuError> {
        // SAFETY: `encodeToCommandBuffer:leftMatrix:rightMatrix:resultMatrix:`
        // takes a command buffer and three MPSMatrix objects and returns void.
        unsafe {
            let _: () = msg_send![
                &**matmul,
                encodeToCommandBuffer: &**command_buffer,
                leftMatrix: &**left_matrix,
                rightMatrix: &**right_matrix,
                resultMatrix: &**result_matrix
            ];
        }
        Ok(())
    }

    /// Runs one matrix multiplication synchronously: encodes it into a fresh
    /// command buffer, commits it, and blocks until it completes.
    ///
    /// # Errors
    /// Returns [`GpuError::Other`] if the command buffer cannot be created or
    /// if it finishes in the error state.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn matrix_multiply(
        &self,
        left_matrix: &Retained<MPSMatrix>,
        right_matrix: &Retained<MPSMatrix>,
        result_matrix: &Retained<MPSMatrix>,
        matmul: &Retained<MPSMatrixMultiplication>,
    ) -> Result<(), GpuError> {
        let command_buffer = self.create_command_buffer()?;
        self.encode_matrix_multiply(
            &command_buffer,
            left_matrix,
            right_matrix,
            result_matrix,
            matmul,
        )?;
        self.commit_command_buffer(&command_buffer);
        self.wait_for_command_buffer(&command_buffer);
        // Without this check a GPU-side failure would be reported as success.
        Self::ensure_command_buffer_succeeded(&command_buffer)
    }

    /// Non-Metal stub: always returns an error.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn matrix_multiply(
        &self,
        _left: &(),
        _right: &(),
        _result: &(),
        _matmul: &(),
    ) -> Result<(), GpuError> {
        Err(GpuError::Other(
            "Metal not available on this platform".to_string(),
        ))
    }

    /// Checks that an `MPSMatrixSoftMax` kernel can be created on this device.
    ///
    /// The kernel is dropped before returning, and `_axis` is currently ignored.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn create_softmax(&self, _axis: i32) -> Result<(), GpuError> {
        use objc2_metal_performance_shaders::MPSMatrixSoftMax;
        // SAFETY: `alloc` followed by `initWithDevice:` with this context's device.
        let kernel: Option<Retained<MPSMatrixSoftMax>> = unsafe {
            let alloc: objc2::rc::Allocated<MPSMatrixSoftMax> =
                msg_send_id![MPSMatrixSoftMax::class(), alloc];
            msg_send_id![alloc, initWithDevice: &*self.device]
        };
        let _kernel = Self::require_created(kernel, "MPSMatrixSoftMax")?;
        Ok(())
    }

    /// Non-Metal stub: always returns an error.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn create_softmax(&self, _axis: i32) -> Result<(), GpuError> {
        Err(GpuError::Other(
            "Metal not available on this platform".to_string(),
        ))
    }

    /// Checks that an `MPSMatrixSum` kernel can be created on this device.
    ///
    /// The kernel is built with fixed `count = 1`, `rows = 1`, `columns = 1`,
    /// `transpose = false`, and is dropped before returning.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn create_sum(&self) -> Result<(), GpuError> {
        use objc2_metal_performance_shaders::MPSMatrixSum;
        // SAFETY: `alloc` followed by `initWithDevice:count:rows:columns:transpose:`
        // with NSUInteger x3 and a BOOL, matching that initializer.
        let kernel: Option<Retained<MPSMatrixSum>> = unsafe {
            let alloc: objc2::rc::Allocated<MPSMatrixSum> =
                msg_send_id![MPSMatrixSum::class(), alloc];
            msg_send_id![
                alloc,
                initWithDevice: &*self.device,
                count: 1usize,
                rows: 1usize,
                columns: 1usize,
                transpose: false
            ]
        };
        let _kernel = Self::require_created(kernel, "MPSMatrixSum")?;
        Ok(())
    }

    /// Non-Metal stub: always returns an error.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn create_sum(&self) -> Result<(), GpuError> {
        Err(GpuError::Other(
            "Metal not available on this platform".to_string(),
        ))
    }

    /// Checks that an `MPSMatrixFindTopK` kernel for `k` values can be created.
    ///
    /// # Errors
    /// Returns [`GpuError::Other`] when `k` is outside `1..=16`, or if the
    /// initializer returns `nil`. The kernel is dropped before returning.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn create_find_top_k(&self, k: usize) -> Result<(), GpuError> {
        use objc2_metal_performance_shaders::MPSMatrixFindTopK;
        if !(1..=16).contains(&k) {
            return Err(GpuError::Other(format!(
                "MPSMatrixFindTopK: k must be in 1..=16, got {k}"
            )));
        }
        // SAFETY: `alloc` followed by `initWithDevice:numberOfTopKValues:` with
        // an NSUInteger already validated to the accepted range.
        let kernel: Option<Retained<MPSMatrixFindTopK>> = unsafe {
            let alloc: objc2::rc::Allocated<MPSMatrixFindTopK> =
                msg_send_id![MPSMatrixFindTopK::class(), alloc];
            msg_send_id![
                alloc,
                initWithDevice: &*self.device,
                numberOfTopKValues: k
            ]
        };
        let _kernel = Self::require_created(kernel, "MPSMatrixFindTopK")?;
        Ok(())
    }

    /// Non-Metal stub: always returns an error.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn create_find_top_k(&self, _k: usize) -> Result<(), GpuError> {
        Err(GpuError::Other(
            "Metal not available on this platform".to_string(),
        ))
    }
}
/// Convolution front-end that shares an [`MPSContext`].
///
/// Only construction works today: `execute` always reports an error because a
/// buffer-backed convolution path has not been implemented.
pub struct MPSConvolution {
    pub(crate) context: Arc<MPSContext>,
}

impl MPSConvolution {
    /// Wraps `context`; this constructor never fails.
    pub fn new(context: Arc<MPSContext>) -> Result<Self, GpuError> {
        let convolution = MPSConvolution { context };
        Ok(convolution)
    }

    /// Always returns `Err`: convolution on raw `MTLBuffer`s is not supported.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn execute(
        &self,
        _input: &Retained<dyn MTLBuffer>,
        _weights: &Retained<dyn MTLBuffer>,
        _output: &mut Retained<dyn MTLBuffer>,
    ) -> Result<(), GpuError> {
        const UNSUPPORTED: &str = "MPSConvolution::execute requires MTLTexture/MPSImage input; \
                                   buffer-backed convolution needs a blit pass — use MPSImage-based API instead";
        Err(GpuError::Other(String::from(UNSUPPORTED)))
    }

    /// Non-Metal stub: always returns an error.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn execute(
        &self,
        _input: &MTLBuffer,
        _weights: &MTLBuffer,
        _output: &mut MTLBuffer,
    ) -> Result<(), GpuError> {
        const UNAVAILABLE: &str = "Metal not available on this platform";
        Err(GpuError::Other(String::from(UNAVAILABLE)))
    }
}
/// Which reduction an [`MPSPooling`] instance is configured for.
#[derive(Clone, Copy, Debug)]
pub enum PoolType {
    /// Max pooling.
    Max,
    /// Average pooling.
    Average,
}

/// Pooling front-end that shares an [`MPSContext`].
///
/// Only construction works today: `execute` always reports an error because a
/// buffer-backed pooling path has not been implemented.
pub struct MPSPooling {
    pub(crate) context: Arc<MPSContext>,
    pub(crate) pool_type: PoolType,
}

impl MPSPooling {
    /// Stores `context` and `pool_type`; this constructor never fails.
    pub fn new(context: Arc<MPSContext>, pool_type: PoolType) -> Result<Self, GpuError> {
        let pooling = MPSPooling { context, pool_type };
        Ok(pooling)
    }

    /// Always returns `Err`; the message names the configured [`PoolType`].
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn execute(
        &self,
        _input: &Retained<dyn MTLBuffer>,
        _output: &mut Retained<dyn MTLBuffer>,
        _kernel_size: (usize, usize),
        _stride: (usize, usize),
    ) -> Result<(), GpuError> {
        let kind = self.pool_type;
        let message = format!(
            "MPSPooling({kind:?})::execute requires MTLTexture/MPSImage input; \
             buffer-backed pooling needs a blit pass — use MPSImage-based API instead"
        );
        Err(GpuError::Other(message))
    }

    /// Non-Metal stub: always returns an error.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn execute(
        &self,
        _input: &MTLBuffer,
        _output: &mut MTLBuffer,
        _kernel_size: (usize, usize),
        _stride: (usize, usize),
    ) -> Result<(), GpuError> {
        let message = String::from("Metal not available on this platform");
        Err(GpuError::Other(message))
    }
}
/// Element types known to this module's MPS wrappers.
///
/// `MPSContext::create_descriptor` currently accepts only `Float32`, `Float16`
/// and `Int32`. Every variant can still be converted to its raw MPS code with
/// [`MPSDataType::to_mps_datatype`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MPSDataType {
    Float32,
    Float16,
    Int32,
    Int16,
    Int8,
    UInt8,
}

impl MPSDataType {
    /// Returns the raw value of the matching
    /// `objc2_metal_performance_shaders::MPSDataType` constant.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn to_mps_datatype(self) -> u32 {
        match self {
            MPSDataType::Float32 => MPSDataTypeEnum::Float32.0,
            MPSDataType::Float16 => MPSDataTypeEnum::Float16.0,
            MPSDataType::Int32 => MPSDataTypeEnum::Int32.0,
            MPSDataType::Int16 => MPSDataTypeEnum::Int16.0,
            MPSDataType::Int8 => MPSDataTypeEnum::Int8.0,
            MPSDataType::UInt8 => MPSDataTypeEnum::UInt8.0,
        }
    }

    /// Returns the raw `MPSDataType` code without depending on the MPS bindings.
    ///
    /// The code is a type-class flag (float or signed) OR-ed with the bit
    /// width; `UInt8` carries no flag.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn to_mps_datatype(self) -> u32 {
        const FLOAT_BIT: u32 = 0x1000_0000;
        const SIGNED_BIT: u32 = 0x2000_0000;
        match self {
            MPSDataType::Float32 => FLOAT_BIT | 32,
            MPSDataType::Float16 => FLOAT_BIT | 16,
            MPSDataType::Int32 => SIGNED_BIT | 32,
            MPSDataType::Int16 => SIGNED_BIT | 16,
            MPSDataType::Int8 => SIGNED_BIT | 8,
            MPSDataType::UInt8 => 8,
        }
    }
}
/// High-level MPS matrix operations on dense, row-major `f32` buffers.
///
/// `Send`/`Sync` are derived automatically: the only field is an
/// `Arc<MPSContext>`, and `MPSContext` is itself `Send + Sync`. No extra
/// `unsafe impl` is needed here.
pub struct MPSOperations {
    context: Arc<MPSContext>,
}

/// Matrix views and kernel prepared for one `C = alpha * A * B` product.
#[cfg(all(feature = "metal", target_os = "macos"))]
struct F32MatmulPlan {
    a: Retained<MPSMatrix>,
    b: Retained<MPSMatrix>,
    c: Retained<MPSMatrix>,
    kernel: Retained<MPSMatrixMultiplication>,
}

impl MPSOperations {
    /// Creates the operations facade with a fresh shared [`MPSContext`].
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn new(
        device: Retained<ProtocolObject<dyn MTLDevice>>,
        command_queue: Retained<ProtocolObject<dyn MTLCommandQueue>>,
    ) -> Self {
        Self {
            context: Arc::new(MPSContext::new(device, command_queue)),
        }
    }

    /// Non-Metal placeholder constructor.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn new(device: MTLDevice, command_queue: MTLCommandQueue) -> Self {
        Self {
            context: Arc::new(MPSContext::new(device, command_queue)),
        }
    }

    /// Returns the shared context backing these operations.
    pub fn context(&self) -> &Arc<MPSContext> {
        &self.context
    }

    /// Bytes per row of a tightly packed `f32` matrix with `columns` columns.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    fn f32_row_bytes(columns: usize) -> Result<usize, GpuError> {
        columns
            .checked_mul(std::mem::size_of::<f32>())
            .ok_or_else(|| {
                GpuError::Other(format!("row size overflows usize for {columns} f32 columns"))
            })
    }

    /// Ensures `buffer` holds at least `rows * row_bytes` bytes, so MPS never
    /// reads or writes past the end of the allocation.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    fn ensure_buffer_len(
        name: &str,
        buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        rows: usize,
        row_bytes: usize,
    ) -> Result<(), GpuError> {
        let required = rows.checked_mul(row_bytes).ok_or_else(|| {
            GpuError::Other(format!(
                "matrix {name}: {rows} rows x {row_bytes} bytes overflows usize"
            ))
        })?;
        // SAFETY: `length` is a read-only NSUInteger property of every MTLBuffer.
        let available: usize = unsafe { msg_send![&**buffer, length] };
        if available < required {
            return Err(GpuError::Other(format!(
                "matrix {name}: buffer holds {available} bytes but {required} are required"
            )));
        }
        Ok(())
    }

    /// Validates sizes and builds the matrix views and kernel for
    /// `C (m x n) = alpha * A (m x k) * B (k x n)` with `beta = 0`.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    #[allow(clippy::too_many_arguments)] // mirrors the public matmul signatures
    fn prepare_f32_matmul(
        &self,
        a_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        b_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        c_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        m: usize,
        k: usize,
        n: usize,
        alpha: f64,
    ) -> Result<F32MatmulPlan, GpuError> {
        // Checked arithmetic: a wrapped `k * 4` in release builds would describe
        // a much smaller matrix than the caller intended.
        let a_row_bytes = Self::f32_row_bytes(k)?;
        let bc_row_bytes = Self::f32_row_bytes(n)?;
        Self::ensure_buffer_len("A", a_buffer, m, a_row_bytes)?;
        Self::ensure_buffer_len("B", b_buffer, k, bc_row_bytes)?;
        Self::ensure_buffer_len("C", c_buffer, m, bc_row_bytes)?;

        let a_desc = MPSContext::create_descriptor(m, k, a_row_bytes, MPSDataType::Float32)?;
        let b_desc = MPSContext::create_descriptor(k, n, bc_row_bytes, MPSDataType::Float32)?;
        let c_desc = MPSContext::create_descriptor(m, n, bc_row_bytes, MPSDataType::Float32)?;
        let a = self.context.create_matrix(a_buffer, &a_desc)?;
        let b = self.context.create_matrix(b_buffer, &b_desc)?;
        let c = self.context.create_matrix(c_buffer, &c_desc)?;
        let kernel = self
            .context
            .create_matmul(false, false, m, n, k, alpha, 0.0)?;
        Ok(F32MatmulPlan { a, b, c, kernel })
    }

    /// Encodes `C = A * B` into `command_buffer` without committing it.
    ///
    /// `A` is `m x k`, `B` is `k x n`, `C` is `m x n`, all tightly packed
    /// row-major `f32`.
    ///
    /// # Errors
    /// Returns [`GpuError::Other`] if a size overflows, a buffer is too small,
    /// or an MPS object cannot be created.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn encode_matmul_f32(
        &self,
        command_buffer: &Retained<AnyObject>,
        a_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        b_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        c_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<(), GpuError> {
        let plan = self.prepare_f32_matmul(a_buffer, b_buffer, c_buffer, m, k, n, 1.0)?;
        self.context.encode_matrix_multiply(
            command_buffer,
            &plan.a,
            &plan.b,
            &plan.c,
            &plan.kernel,
        )
    }

    /// Computes `C = A * B` synchronously (see [`Self::matmul_f32_scaled`]).
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn matmul_f32(
        &self,
        a_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        b_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        c_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<(), GpuError> {
        self.matmul_f32_scaled(a_buffer, b_buffer, c_buffer, m, k, n, 1.0)
    }

    /// Non-Metal stub: always returns an error.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn matmul_f32(
        &self,
        _a_buffer: &(),
        _b_buffer: &(),
        _c_buffer: &(),
        _m: usize,
        _k: usize,
        _n: usize,
    ) -> Result<(), GpuError> {
        Err(GpuError::Other(
            "Metal not available on this platform".to_string(),
        ))
    }

    /// Computes `C = alpha * A * B` and blocks until the GPU finishes.
    ///
    /// `A` is `m x k`, `B` is `k x n`, `C` is `m x n`, all tightly packed
    /// row-major `f32`. Existing contents of `C` are ignored (`beta = 0`).
    ///
    /// # Errors
    /// Returns [`GpuError::Other`] if a size overflows, a buffer is too small,
    /// an MPS object cannot be created, or the command buffer fails.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn matmul_f32_scaled(
        &self,
        a_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        b_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        c_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        m: usize,
        k: usize,
        n: usize,
        alpha: f32,
    ) -> Result<(), GpuError> {
        let plan =
            self.prepare_f32_matmul(a_buffer, b_buffer, c_buffer, m, k, n, f64::from(alpha))?;
        self.context
            .matrix_multiply(&plan.a, &plan.b, &plan.c, &plan.kernel)
    }

    /// Non-Metal stub: always returns an error.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn matmul_f32_scaled(
        &self,
        _a_buffer: &(),
        _b_buffer: &(),
        _c_buffer: &(),
        _m: usize,
        _k: usize,
        _n: usize,
        _alpha: f32,
    ) -> Result<(), GpuError> {
        Err(GpuError::Other(
            "Metal not available on this platform".to_string(),
        ))
    }
}
/// Image-processing front-end that shares an [`MPSContext`].
///
/// Only construction works today: every operation reports an error because a
/// buffer-backed image path has not been implemented.
pub struct MPSImageOps {
    pub(crate) context: Arc<MPSContext>,
}

impl MPSImageOps {
    /// Wraps `context`; this constructor never fails.
    pub fn new(context: Arc<MPSContext>) -> Result<Self, GpuError> {
        let ops = MPSImageOps { context };
        Ok(ops)
    }

    /// Builds the "texture input required" error for the named operation.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    fn texture_input_required(operation: &str) -> GpuError {
        GpuError::Other(format!(
            "MPSImageOps::{operation} requires MTLTexture/MPSImage input; \
             buffer-backed image ops need a blit pass — use MPSImage-based API instead"
        ))
    }

    /// Always returns `Err`: blurring raw `MTLBuffer`s is not supported.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn gaussian_blur(
        &self,
        _input: &Retained<dyn MTLBuffer>,
        _output: &mut Retained<dyn MTLBuffer>,
        _sigma: f32,
    ) -> Result<(), GpuError> {
        Err(Self::texture_input_required("gaussian_blur"))
    }

    /// Non-Metal stub: always returns an error.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn gaussian_blur(
        &self,
        _input: &MTLBuffer,
        _output: &mut MTLBuffer,
        _sigma: f32,
    ) -> Result<(), GpuError> {
        Err(GpuError::Other(String::from(
            "Metal not available on this platform",
        )))
    }

    /// Always returns `Err`: edge detection on raw `MTLBuffer`s is not supported.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn edge_detection(
        &self,
        _input: &Retained<dyn MTLBuffer>,
        _output: &mut Retained<dyn MTLBuffer>,
        _threshold: f32,
    ) -> Result<(), GpuError> {
        Err(Self::texture_input_required("edge_detection"))
    }

    /// Non-Metal stub: always returns an error.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn edge_detection(
        &self,
        _input: &MTLBuffer,
        _output: &mut MTLBuffer,
        _threshold: f32,
    ) -> Result<(), GpuError> {
        Err(GpuError::Other(String::from(
            "Metal not available on this platform",
        )))
    }
}
#[cfg(test)]
mod tests {
    use super::{MPSContext, MPSConvolution, MPSDataType, MPSImageOps, MPSPooling, PoolType};
    use std::sync::Arc;

    /// Opens the system default Metal device and a command queue on it.
    ///
    /// Returns `None` when either is unavailable, in which case the
    /// GPU-dependent tests below finish without asserting anything.
    fn try_make_context() -> Option<MPSContext> {
        use objc2_metal::{MTLCreateSystemDefaultDevice, MTLDevice};
        let device = MTLCreateSystemDefaultDevice()?;
        device
            .newCommandQueue()
            .map(|queue| MPSContext::new(device, queue))
    }

    /// Runs `check` against a fresh context, or does nothing without Metal.
    fn with_context(check: impl FnOnce(MPSContext)) {
        if let Some(ctx) = try_make_context() {
            check(ctx);
        }
    }

    #[test]
    fn test_mps_convolution_new_succeeds() {
        with_context(|ctx| {
            let outcome = MPSConvolution::new(Arc::new(ctx));
            assert!(outcome.is_ok(), "MPSConvolution::new should succeed");
        });
    }

    #[test]
    fn test_mps_pooling_new_succeeds() {
        with_context(|ctx| {
            let outcome = MPSPooling::new(Arc::new(ctx), PoolType::Max);
            assert!(outcome.is_ok(), "MPSPooling::new should succeed");
        });
    }

    #[test]
    fn test_mps_image_ops_new_succeeds() {
        with_context(|ctx| {
            let outcome = MPSImageOps::new(Arc::new(ctx));
            assert!(outcome.is_ok(), "MPSImageOps::new should succeed");
        });
    }

    #[test]
    fn test_mps_create_softmax_succeeds() {
        with_context(|ctx| {
            let result = ctx.create_softmax(1);
            assert!(result.is_ok(), "create_softmax should succeed: {result:?}");
        });
    }

    #[test]
    fn test_mps_create_sum_succeeds() {
        with_context(|ctx| {
            let result = ctx.create_sum();
            assert!(result.is_ok(), "create_sum should succeed: {result:?}");
        });
    }

    #[test]
    fn test_mps_create_find_top_k_valid() {
        with_context(|ctx| {
            let result = ctx.create_find_top_k(4);
            assert!(
                result.is_ok(),
                "create_find_top_k(4) should succeed: {result:?}"
            );
        });
    }

    #[test]
    fn test_mps_create_find_top_k_out_of_range() {
        with_context(|ctx| {
            assert!(ctx.create_find_top_k(0).is_err(), "k=0 must be rejected");
            assert!(ctx.create_find_top_k(17).is_err(), "k=17 must be rejected");
        });
    }

    #[test]
    fn test_mps_datatype_roundtrip() {
        const FLOAT_BIT: u32 = 0x10000000;
        const SIGNED_BIT: u32 = 0x20000000;
        let expected: [(MPSDataType, u32); 6] = [
            (MPSDataType::Float32, FLOAT_BIT | 32),
            (MPSDataType::Float16, FLOAT_BIT | 16),
            (MPSDataType::Int32, SIGNED_BIT | 32),
            (MPSDataType::Int16, SIGNED_BIT | 16),
            (MPSDataType::Int8, SIGNED_BIT | 8),
            (MPSDataType::UInt8, 8),
        ];
        for (dtype, raw) in expected {
            assert_eq!(dtype.to_mps_datatype(), raw, "{dtype:?}");
        }
    }
}