#![cfg(all(feature = "metal", target_os = "macos"))]
#![allow(dead_code)]
#![allow(deprecated)]
use crate::gpu::GpuError;
use std::sync::Arc;
#[cfg(all(feature = "metal", target_os = "macos"))]
use objc2_metal::{MTLBuffer, MTLCommandQueue, MTLDevice};
#[cfg(all(feature = "metal", target_os = "macos"))]
use objc2_metal_performance_shaders::{
MPSDataType as MPSDataTypeEnum, MPSImageConvolution, MPSImageGaussianBlur, MPSMatrix,
MPSMatrixDescriptor, MPSMatrixMultiplication,
};
#[cfg(all(feature = "metal", target_os = "macos"))]
use objc2::runtime::ProtocolObject;
#[cfg(all(feature = "metal", target_os = "macos"))]
use objc2::rc::Retained;
#[cfg(all(feature = "metal", target_os = "macos"))]
use objc2::{msg_send, msg_send_id, ClassType};
#[cfg(all(feature = "metal", target_os = "macos"))]
use objc2::runtime::AnyObject;
#[cfg(all(feature = "metal", target_os = "macos"))]
use objc2_metal_performance_shaders::MPSKernel;
// Placeholder aliases so that the `cfg(not(...))` signatures in this file can
// still name `MTLDevice` / `MTLCommandQueue` / `MTLBuffer` when Metal is not
// available; each one is just the unit type.
// NOTE(review): the file-level `#![cfg(all(feature = "metal", target_os = "macos"))]`
// on the first line already removes this whole module in that configuration, so
// these aliases (and every other `cfg(not(...))` item here) are never compiled.
#[cfg(not(all(feature = "metal", target_os = "macos")))]
type MTLDevice = ();
#[cfg(not(all(feature = "metal", target_os = "macos")))]
type MTLCommandQueue = ();
#[cfg(not(all(feature = "metal", target_os = "macos")))]
type MTLBuffer = ();
/// Shared Metal state for the MPS wrappers in this file.
///
/// Holds the device that MPS kernels are created on (see `create_matmul`) and
/// the command queue that command buffers are taken from (see
/// `create_command_buffer`). The `cfg(not(...))` fields use the `()` placeholder
/// aliases declared above.
pub struct MPSContext {
// Device passed to MPS kernel initializers.
#[cfg(all(feature = "metal", target_os = "macos"))]
device: Retained<ProtocolObject<dyn MTLDevice>>,
// Queue from which command buffers are created.
#[cfg(all(feature = "metal", target_os = "macos"))]
command_queue: Retained<ProtocolObject<dyn MTLCommandQueue>>,
#[cfg(not(all(feature = "metal", target_os = "macos")))]
device: MTLDevice,
#[cfg(not(all(feature = "metal", target_os = "macos")))]
command_queue: MTLCommandQueue,
}
// SAFETY: NOTE(review): these impls assert that the retained device and command
// queue may be moved to and shared between threads. This presumably relies on
// Metal device/queue objects being thread-safe — TODO confirm against Apple's
// Metal docs and the objc2-metal version in use, and re-check if fields are added.
unsafe impl Send for MPSContext {}
unsafe impl Sync for MPSContext {}
impl MPSContext {
    /// Creates a context from an existing Metal device and command queue.
    ///
    /// MPS kernels built by this context are created on `device`, while the
    /// command buffers they are encoded into come from `command_queue`, so the
    /// queue should have been created from the same device (not checked here).
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn new(
        device: Retained<ProtocolObject<dyn MTLDevice>>,
        command_queue: Retained<ProtocolObject<dyn MTLCommandQueue>>,
    ) -> Self {
        Self {
            device,
            command_queue,
        }
    }

    /// Placeholder constructor for builds without Metal support.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn new(device: MTLDevice, command_queue: MTLCommandQueue) -> Self {
        Self {
            device,
            command_queue,
        }
    }

    /// Creates an `MPSMatrixMultiplication` kernel on this context's device.
    ///
    /// Per the MPS documentation the kernel computes
    /// `C = alpha * op(A) * op(B) + beta * C`, where `op` optionally transposes
    /// its operand. `result_rows` x `result_columns` is the shape of `C`, and
    /// `interior_columns` is the shared inner dimension.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::Other`] if the initializer returns `nil` instead of a
    /// kernel (previously this panicked).
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn create_matmul(
        &self,
        transpose_left: bool,
        transpose_right: bool,
        result_rows: usize,
        result_columns: usize,
        interior_columns: usize,
        alpha: f64,
        beta: f64,
    ) -> Result<Retained<MPSMatrixMultiplication>, GpuError> {
        use objc2_metal_performance_shaders::MPSMatrixMultiplication;
        // SAFETY: `alloc` is sent to the MPSMatrixMultiplication class and its
        // result goes straight into the class's designated initializer, whose
        // parameters (id<MTLDevice>, BOOL x2, NSUInteger x3, double x2) match the
        // Rust argument types. Returning `Option` lets a nil result surface as an
        // error instead of a panic.
        let matmul: Option<Retained<MPSMatrixMultiplication>> = unsafe {
            let cls = MPSMatrixMultiplication::class();
            let alloc = msg_send_id![cls, alloc];
            msg_send_id![
                alloc,
                initWithDevice: &*self.device,
                transposeLeft: transpose_left,
                transposeRight: transpose_right,
                resultRows: result_rows,
                resultColumns: result_columns,
                interiorColumns: interior_columns,
                alpha: alpha,
                beta: beta
            ]
        };
        matmul.ok_or_else(|| {
            GpuError::Other("Failed to create MPSMatrixMultiplication kernel".to_string())
        })
    }

    /// Placeholder for builds without Metal support; always returns an error.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn create_matmul(
        &self,
        _transpose_left: bool,
        _transpose_right: bool,
        _result_rows: usize,
        _result_columns: usize,
        _interior_columns: usize,
        _alpha: f64,
        _beta: f64,
    ) -> Result<(), GpuError> {
        Err(GpuError::Other(
            "Metal not available on this platform".to_string(),
        ))
    }

    /// Creates an `MPSMatrixDescriptor` for a `rows` x `columns` matrix of
    /// `datatype` elements whose consecutive rows start `row_bytes` apart.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::Other`] in any of these cases:
    /// - `row_bytes` is smaller than one row of `columns` elements.
    /// - `row_bytes` is not a multiple of the element size.
    /// - The row size overflows `usize`.
    /// - MPS returns `nil` for the descriptor.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn create_descriptor(
        rows: usize,
        columns: usize,
        row_bytes: usize,
        datatype: MPSDataType,
    ) -> Result<Retained<MPSMatrixDescriptor>, GpuError> {
        use objc2_metal_performance_shaders::MPSMatrixDescriptor;
        // Exhaustive on purpose: a new `MPSDataType` variant must be mapped here
        // rather than silently falling into an "unsupported" arm.
        let (mps_datatype, element_size) = match datatype {
            MPSDataType::Float32 => (MPSDataTypeEnum::Float32, 4_usize),
            MPSDataType::Float16 => (MPSDataTypeEnum::Float16, 2),
            MPSDataType::Int32 => (MPSDataTypeEnum::Int32, 4),
            MPSDataType::Int16 => (MPSDataTypeEnum::Int16, 2),
            MPSDataType::Int8 => (MPSDataTypeEnum::Int8, 1),
            MPSDataType::UInt8 => (MPSDataTypeEnum::UInt8, 1),
        };
        // Reject layouts that cannot hold a full row before handing them to MPS.
        let min_row_bytes = columns.checked_mul(element_size).ok_or_else(|| {
            GpuError::Other(format!(
                "Matrix row of {} {:?} elements overflows usize",
                columns, datatype
            ))
        })?;
        if row_bytes < min_row_bytes || row_bytes % element_size != 0 {
            return Err(GpuError::Other(format!(
                "Invalid row_bytes {} for {} columns of {:?}: must be a multiple of {} and at least {}",
                row_bytes, columns, datatype, element_size, min_row_bytes
            )));
        }
        // SAFETY: `matrixDescriptorWithRows:columns:rowBytes:dataType:` is a class
        // factory taking three NSUIntegers and an MPSDataType, matching the Rust
        // argument types; a nil result is mapped to an error below.
        let descriptor: Option<Retained<MPSMatrixDescriptor>> = unsafe {
            let cls = MPSMatrixDescriptor::class();
            msg_send_id![
                cls,
                matrixDescriptorWithRows: rows,
                columns: columns,
                rowBytes: row_bytes,
                dataType: mps_datatype
            ]
        };
        descriptor
            .ok_or_else(|| GpuError::Other("Failed to create MPSMatrixDescriptor".to_string()))
    }

    /// Placeholder for builds without Metal support; always returns an error.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn create_descriptor(
        _rows: usize,
        _columns: usize,
        _row_bytes: usize,
        _datatype: MPSDataType,
    ) -> Result<(), GpuError> {
        Err(GpuError::Other(
            "Metal not available on this platform".to_string(),
        ))
    }

    /// Wraps `buffer` (without copying) as an `MPSMatrix` laid out per `descriptor`.
    ///
    /// NOTE(review): the buffer's length is not checked against the descriptor
    /// here; callers are responsible for it being large enough.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::Other`] if the initializer returns `nil`.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn create_matrix(
        &self,
        buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        descriptor: &Retained<MPSMatrixDescriptor>,
    ) -> Result<Retained<MPSMatrix>, GpuError> {
        use objc2_metal_performance_shaders::MPSMatrix;
        // SAFETY: `alloc` is sent to the MPSMatrix class and immediately passed to
        // `initWithBuffer:descriptor:` with an MTLBuffer and an MPSMatrixDescriptor.
        let matrix: Option<Retained<MPSMatrix>> = unsafe {
            let cls = MPSMatrix::class();
            let alloc = msg_send_id![cls, alloc];
            msg_send_id![
                alloc,
                initWithBuffer: &**buffer,
                descriptor: &**descriptor
            ]
        };
        matrix.ok_or_else(|| GpuError::Other("Failed to create MPSMatrix".to_string()))
    }

    /// Placeholder for builds without Metal support; always returns an error.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn create_matrix(&self, _buffer: &MTLBuffer, _descriptor: &()) -> Result<(), GpuError> {
        Err(GpuError::Other(
            "Metal not available on this platform".to_string(),
        ))
    }

    /// Creates a new command buffer on this context's command queue.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::Other`] if the queue returns `nil`.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn create_command_buffer(&self) -> Result<Retained<AnyObject>, GpuError> {
        // SAFETY: `commandBuffer` takes no arguments and returns an object or nil.
        let command_buffer: Option<Retained<AnyObject>> =
            unsafe { msg_send_id![&self.command_queue, commandBuffer] };
        command_buffer.ok_or_else(|| GpuError::Other("Failed to create command buffer".to_string()))
    }

    /// Sends `commit` to `command_buffer`.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn commit_command_buffer(&self, command_buffer: &Retained<AnyObject>) {
        // SAFETY: assumes `command_buffer` is an MTLCommandBuffer (as returned by
        // `create_command_buffer`); `commit` takes no arguments and returns void.
        unsafe {
            let _: () = msg_send![&**command_buffer, commit];
        }
    }

    /// Blocks the calling thread until `command_buffer` has finished executing.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn wait_for_command_buffer(&self, command_buffer: &Retained<AnyObject>) {
        // SAFETY: assumes `command_buffer` is an MTLCommandBuffer; the
        // `waitUntilCompleted` message takes no arguments and returns void.
        unsafe {
            let _: () = msg_send![&**command_buffer, waitUntilCompleted];
        }
    }

    /// Encodes `result = matmul(left, right)` into `command_buffer` without
    /// committing it.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn encode_matrix_multiply(
        &self,
        command_buffer: &Retained<AnyObject>,
        left_matrix: &Retained<MPSMatrix>,
        right_matrix: &Retained<MPSMatrix>,
        result_matrix: &Retained<MPSMatrix>,
        matmul: &Retained<MPSMatrixMultiplication>,
    ) -> Result<(), GpuError> {
        // SAFETY: the selector
        // `encodeToCommandBuffer:leftMatrix:rightMatrix:resultMatrix:` takes a
        // command buffer and three MPSMatrix objects and returns void.
        unsafe {
            let _: () = msg_send![
                &**matmul,
                encodeToCommandBuffer: &**command_buffer,
                leftMatrix: &**left_matrix,
                rightMatrix: &**right_matrix,
                resultMatrix: &**result_matrix
            ];
        }
        Ok(())
    }

    /// Runs `matmul` synchronously and waits for it to finish.
    ///
    /// The kernel is encoded into a fresh command buffer, which is then
    /// committed. The call blocks until the GPU completes.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::Other`] if no command buffer can be created, or if
    /// the command buffer finishes in the `MTLCommandBufferStatusError` state.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn matrix_multiply(
        &self,
        left_matrix: &Retained<MPSMatrix>,
        right_matrix: &Retained<MPSMatrix>,
        result_matrix: &Retained<MPSMatrix>,
        matmul: &Retained<MPSMatrixMultiplication>,
    ) -> Result<(), GpuError> {
        // `MTLCommandBufferStatusError` in Metal's `MTLCommandBufferStatus` enum.
        const MTL_COMMAND_BUFFER_STATUS_ERROR: usize = 5;

        let command_buffer = self.create_command_buffer()?;
        self.encode_matrix_multiply(
            &command_buffer,
            left_matrix,
            right_matrix,
            result_matrix,
            matmul,
        )?;
        self.commit_command_buffer(&command_buffer);
        self.wait_for_command_buffer(&command_buffer);

        // Without this check a failed GPU execution was reported as success.
        // SAFETY: `command_buffer` came from `-[MTLCommandQueue commandBuffer]`;
        // its `status` property returns an NSUInteger-backed enum.
        let status: usize = unsafe { msg_send![&*command_buffer, status] };
        if status == MTL_COMMAND_BUFFER_STATUS_ERROR {
            return Err(GpuError::Other(
                "MPS matrix multiplication command buffer completed with an error".to_string(),
            ));
        }
        Ok(())
    }

    /// Placeholder for builds without Metal support; always returns an error.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn matrix_multiply(
        &self,
        _left: &(),
        _right: &(),
        _result: &(),
        _matmul: &(),
    ) -> Result<(), GpuError> {
        Err(GpuError::Other(
            "Metal not available on this platform".to_string(),
        ))
    }

    /// Not implemented yet; always returns an error.
    pub fn create_softmax(&self, _axis: i32) -> Result<(), GpuError> {
        Err(GpuError::Other(
            "MPS softmax not yet implemented with new objc2 API".to_string(),
        ))
    }

    /// Not implemented yet; always returns an error.
    pub fn create_sum(&self) -> Result<(), GpuError> {
        Err(GpuError::Other(
            "MPS sum not yet implemented with new objc2 API".to_string(),
        ))
    }

    /// Not implemented yet; always returns an error.
    pub fn create_find_top_k(&self, _k: usize) -> Result<(), GpuError> {
        Err(GpuError::Other(
            "MPS top-k not yet implemented with new objc2 API".to_string(),
        ))
    }
}
/// Placeholder for MPS convolution support.
///
/// The port to the objc2 bindings is unfinished: `new` and both `execute`
/// variants unconditionally return [`GpuError::Other`].
pub struct MPSConvolution {
    pub(crate) context: Arc<MPSContext>,
}

impl MPSConvolution {
    /// Always fails: convolution has not been ported to the objc2 API yet.
    pub fn new(_context: Arc<MPSContext>) -> Result<Self, GpuError> {
        let reason = "MPS convolution not yet implemented with new objc2 API";
        Err(GpuError::Other(String::from(reason)))
    }

    /// Always fails: convolution execution has not been ported yet.
    ///
    /// NOTE(review): unlike the rest of this file, these buffers are typed as
    /// `Retained<dyn MTLBuffer>` rather than `Retained<ProtocolObject<dyn MTLBuffer>>`.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn execute(
        &self,
        _input: &objc2::rc::Retained<dyn MTLBuffer>,
        _weights: &objc2::rc::Retained<dyn MTLBuffer>,
        _output: &mut objc2::rc::Retained<dyn MTLBuffer>,
    ) -> Result<(), GpuError> {
        let reason = "MPS convolution execution not yet implemented with new objc2 API";
        Err(GpuError::Other(String::from(reason)))
    }

    /// Always fails: Metal is unavailable in this build configuration.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn execute(
        &self,
        _input: &MTLBuffer,
        _weights: &MTLBuffer,
        _output: &mut MTLBuffer,
    ) -> Result<(), GpuError> {
        let reason = "Metal not available on this platform";
        Err(GpuError::Other(String::from(reason)))
    }
}
/// Placeholder wrapper for MPS pooling kernels.
///
/// `MPSPooling::new` currently always returns an error. The fields are
/// `pub(crate)`, so code outside this crate cannot construct the type.
pub struct MPSPooling {
// Shared Metal device/queue state.
pub(crate) context: Arc<MPSContext>,
// Which pooling reduction this instance was requested for.
pub(crate) pool_type: PoolType,
}
/// Pooling reduction requested from [`MPSPooling`].
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare pooling modes or use
/// them as map keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PoolType {
    /// Max pooling.
    Max,
    /// Average pooling.
    Average,
}
impl MPSPooling {
    /// Always fails: pooling has not been ported to the objc2 API yet.
    pub fn new(_context: Arc<MPSContext>, _pool_type: PoolType) -> Result<Self, GpuError> {
        let message = "MPS pooling not yet implemented with new objc2 API".to_owned();
        Err(GpuError::Other(message))
    }

    /// Always fails: pooling execution has not been ported yet.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn execute(
        &self,
        _input: &objc2::rc::Retained<dyn MTLBuffer>,
        _output: &mut objc2::rc::Retained<dyn MTLBuffer>,
        _kernel_size: (usize, usize),
        _stride: (usize, usize),
    ) -> Result<(), GpuError> {
        let message = "MPS pooling execution not yet implemented with new objc2 API".to_owned();
        Err(GpuError::Other(message))
    }

    /// Always fails: Metal is unavailable in this build configuration.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn execute(
        &self,
        _input: &MTLBuffer,
        _output: &mut MTLBuffer,
        _kernel_size: (usize, usize),
        _stride: (usize, usize),
    ) -> Result<(), GpuError> {
        let message = "Metal not available on this platform".to_owned();
        Err(GpuError::Other(message))
    }
}
/// Element types that can be described to Metal Performance Shaders.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MPSDataType {
    /// 32-bit floating point.
    Float32,
    /// 16-bit (half-precision) floating point.
    Float16,
    /// Signed 32-bit integer.
    Int32,
    /// Signed 16-bit integer.
    Int16,
    /// Signed 8-bit integer.
    Int8,
    /// Unsigned 8-bit integer.
    UInt8,
}

impl MPSDataType {
    /// `MPSDataTypeFloatBit` from `MPSCoreTypes.h`.
    const FLOAT_BIT: u32 = 0x1000_0000;
    /// `MPSDataTypeSignedBit` from `MPSCoreTypes.h`.
    const SIGNED_BIT: u32 = 0x2000_0000;

    /// Returns the raw `MPSDataType` value that Metal Performance Shaders uses
    /// for this element type.
    ///
    /// Each value is a category bit ORed with the element width in bits.
    /// Unsigned integers carry no category bit (`MPSDataTypeUInt8 == 8`). The
    /// previous values used 0x10000/0x20000/0x30000 as category bits, which do
    /// not match any MPS constant.
    pub const fn to_mps_datatype(self) -> u32 {
        match self {
            MPSDataType::Float32 => Self::FLOAT_BIT | 32,
            MPSDataType::Float16 => Self::FLOAT_BIT | 16,
            MPSDataType::Int32 => Self::SIGNED_BIT | 32,
            MPSDataType::Int16 => Self::SIGNED_BIT | 16,
            MPSDataType::Int8 => Self::SIGNED_BIT | 8,
            MPSDataType::UInt8 => 8,
        }
    }
}
/// `f32` matrix-multiplication helpers built on a shared `MPSContext`.
pub struct MPSOperations {
// Shared device/queue state; cloned out via `context()`.
context: Arc<MPSContext>,
}
// NOTE(review): `Arc<MPSContext>` is already `Send + Sync` because `MPSContext`
// has its own `unsafe impl Send`/`Sync`, so these two impls look redundant.
// Removing them would let the compiler re-check thread safety if fields are added.
unsafe impl Send for MPSOperations {}
unsafe impl Sync for MPSOperations {}
impl MPSOperations {
    /// Creates the operations helper, wrapping `device` and `command_queue` in a
    /// shared [`MPSContext`].
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn new(
        device: Retained<ProtocolObject<dyn MTLDevice>>,
        command_queue: Retained<ProtocolObject<dyn MTLCommandQueue>>,
    ) -> Self {
        Self {
            context: Arc::new(MPSContext::new(device, command_queue)),
        }
    }

    /// Placeholder constructor for builds without Metal support.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn new(device: MTLDevice, command_queue: MTLCommandQueue) -> Self {
        Self {
            context: Arc::new(MPSContext::new(device, command_queue)),
        }
    }

    /// Returns the shared context used by these operations.
    pub fn context(&self) -> &Arc<MPSContext> {
        &self.context
    }

    /// Encodes `C = A * B` into `command_buffer` without committing it.
    ///
    /// The buffers are read as densely packed row-major `f32` matrices: `A` is
    /// `m` x `k`, `B` is `k` x `n` and `C` is `m` x `n`.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - A buffer is too small for its matrix.
    /// - A matrix byte size overflows `usize`.
    /// - Creating an MPS object fails.
    /// - Encoding fails.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn encode_matmul_f32(
        &self,
        command_buffer: &Retained<AnyObject>,
        a_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        b_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        c_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<(), GpuError> {
        let (a_matrix, b_matrix, c_matrix) =
            self.wrap_f32_operands(a_buffer, b_buffer, c_buffer, m, k, n)?;
        let matmul = self.context.create_matmul(false, false, m, n, k, 1.0, 0.0)?;
        self.context
            .encode_matrix_multiply(command_buffer, &a_matrix, &b_matrix, &c_matrix, &matmul)
    }

    /// Computes `C = A * B` synchronously; shapes as in
    /// [`MPSOperations::encode_matmul_f32`].
    ///
    /// # Errors
    ///
    /// Same as [`MPSOperations::matmul_f32_scaled`].
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn matmul_f32(
        &self,
        a_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        b_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        c_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<(), GpuError> {
        self.matmul_f32_scaled(a_buffer, b_buffer, c_buffer, m, k, n, 1.0)
    }

    /// Placeholder for builds without Metal support; always returns an error.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn matmul_f32(
        &self,
        _a_buffer: &(),
        _b_buffer: &(),
        _c_buffer: &(),
        _m: usize,
        _k: usize,
        _n: usize,
    ) -> Result<(), GpuError> {
        Err(GpuError::Other(
            "Metal not available on this platform".to_string(),
        ))
    }

    /// Computes `C = alpha * A * B` synchronously and waits for the GPU.
    ///
    /// Shapes are as in [`MPSOperations::encode_matmul_f32`].
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - A buffer is too small for its matrix.
    /// - A matrix byte size overflows `usize`.
    /// - Creating an MPS object fails.
    /// - The GPU command buffer fails.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn matmul_f32_scaled(
        &self,
        a_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        b_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        c_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        m: usize,
        k: usize,
        n: usize,
        alpha: f32,
    ) -> Result<(), GpuError> {
        let (a_matrix, b_matrix, c_matrix) =
            self.wrap_f32_operands(a_buffer, b_buffer, c_buffer, m, k, n)?;
        let matmul = self
            .context
            .create_matmul(false, false, m, n, k, f64::from(alpha), 0.0)?;
        self.context
            .matrix_multiply(&a_matrix, &b_matrix, &c_matrix, &matmul)
    }

    /// Placeholder for builds without Metal support; always returns an error.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn matmul_f32_scaled(
        &self,
        _a_buffer: &(),
        _b_buffer: &(),
        _c_buffer: &(),
        _m: usize,
        _k: usize,
        _n: usize,
        _alpha: f32,
    ) -> Result<(), GpuError> {
        Err(GpuError::Other(
            "Metal not available on this platform".to_string(),
        ))
    }

    /// Wraps the three operand buffers of `C (m x n) = A (m x k) * B (k x n)`
    /// as `MPSMatrix` objects, validating each buffer's size.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    fn wrap_f32_operands(
        &self,
        a_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        b_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        c_buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<(Retained<MPSMatrix>, Retained<MPSMatrix>, Retained<MPSMatrix>), GpuError> {
        let a_matrix = self.wrap_f32_matrix("A", a_buffer, m, k)?;
        let b_matrix = self.wrap_f32_matrix("B", b_buffer, k, n)?;
        let c_matrix = self.wrap_f32_matrix("C", c_buffer, m, n)?;
        Ok((a_matrix, b_matrix, c_matrix))
    }

    /// Wraps `buffer` as a densely packed row-major `rows` x `columns` `f32` matrix.
    ///
    /// This checks the byte arithmetic for overflow and checks that the buffer
    /// holds the whole matrix, so the GPU is never asked to access memory past
    /// the end of the buffer.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    fn wrap_f32_matrix(
        &self,
        label: &str,
        buffer: &Retained<ProtocolObject<dyn MTLBuffer>>,
        rows: usize,
        columns: usize,
    ) -> Result<Retained<MPSMatrix>, GpuError> {
        let (row_bytes, required_bytes) = columns
            .checked_mul(std::mem::size_of::<f32>())
            .and_then(|row_bytes| rows.checked_mul(row_bytes).map(|total| (row_bytes, total)))
            .ok_or_else(|| {
                GpuError::Other(format!(
                    "Matrix {} ({}x{} f32) is too large: byte size overflows usize",
                    label, rows, columns
                ))
            })?;
        // SAFETY: `buffer` is a retained MTLBuffer; `length` takes no arguments
        // and returns an NSUInteger, which is `usize` on Apple targets.
        let available_bytes: usize = unsafe { msg_send![&**buffer, length] };
        if available_bytes < required_bytes {
            return Err(GpuError::Other(format!(
                "Buffer for matrix {} holds {} bytes but a {}x{} f32 matrix needs {}",
                label, available_bytes, rows, columns, required_bytes
            )));
        }
        let descriptor =
            MPSContext::create_descriptor(rows, columns, row_bytes, MPSDataType::Float32)?;
        self.context.create_matrix(buffer, &descriptor)
    }
}
/// Placeholder for MPS image-processing kernels (blur, edge detection).
///
/// None of these operations have been ported to the objc2 bindings, so every
/// method returns [`GpuError::Other`].
pub struct MPSImageOps {
    pub(crate) context: Arc<MPSContext>,
}

impl MPSImageOps {
    /// Always fails: image operations have not been ported to the objc2 API yet.
    pub fn new(_context: Arc<MPSContext>) -> Result<Self, GpuError> {
        let text = String::from("MPS image operations not yet implemented with new objc2 API");
        Err(GpuError::Other(text))
    }

    /// Always fails: Gaussian blur has not been ported yet.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn gaussian_blur(
        &self,
        _input: &objc2::rc::Retained<dyn MTLBuffer>,
        _output: &mut objc2::rc::Retained<dyn MTLBuffer>,
        _sigma: f32,
    ) -> Result<(), GpuError> {
        let text = String::from("MPS Gaussian blur not yet implemented with new objc2 API");
        Err(GpuError::Other(text))
    }

    /// Always fails: Metal is unavailable in this build configuration.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn gaussian_blur(
        &self,
        _input: &MTLBuffer,
        _output: &mut MTLBuffer,
        _sigma: f32,
    ) -> Result<(), GpuError> {
        let text = String::from("Metal not available on this platform");
        Err(GpuError::Other(text))
    }

    /// Always fails: edge detection has not been ported yet.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    pub fn edge_detection(
        &self,
        _input: &objc2::rc::Retained<dyn MTLBuffer>,
        _output: &mut objc2::rc::Retained<dyn MTLBuffer>,
        _threshold: f32,
    ) -> Result<(), GpuError> {
        let text = String::from("MPS edge detection not yet implemented with new objc2 API");
        Err(GpuError::Other(text))
    }

    /// Always fails: Metal is unavailable in this build configuration.
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    pub fn edge_detection(
        &self,
        _input: &MTLBuffer,
        _output: &mut MTLBuffer,
        _threshold: f32,
    ) -> Result<(), GpuError> {
        let text = String::from("Metal not available on this platform");
        Err(GpuError::Other(text))
    }
}