use crate::error::{RusTorchError, RusTorchResult};
use std::collections::HashMap;
use std::ffi::c_void;
use std::marker::PhantomData;
#[cfg(feature = "metal")]
use metal::foreign_types::ForeignType;
#[cfg(feature = "metal")]
use metal::{
CommandQueue, CompileOptions, ComputePipelineState, Device, Library, MTLResourceOptions,
MTLSize,
};
/// Keys used to look up pre-built compute pipelines in the Metal executor.
///
/// The enum is `Copy` and hashable so it can be used directly as a `HashMap` key.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MetalKernelType {
    /// Element-wise kernels (the visible executor registers the addition pipeline here).
    ElementWise,
    /// Straightforward (untiled) matrix multiplication.
    MatMul,
    /// Block-wise reductions such as sum.
    Reduction,
    /// Convolution; NOTE(review): the visible executor stores its *tiled matmul*
    /// pipeline under this key.
    Convolution,
    /// Batch normalization; no pipeline is registered for it in the visible code.
    BatchNorm,
}
/// Launch geometry for a Metal compute dispatch.
#[derive(Debug, Clone)]
pub struct MetalKernelParams {
    /// Threadgroup extent as `(x, y, z)`.
    pub threads_per_threadgroup: (u32, u32, u32),
    /// Number of threadgroups in the grid as `(x, y, z)`.
    pub threadgroups_per_grid: (u32, u32, u32),
}

impl Default for MetalKernelParams {
    /// Returns the smallest possible launch: one threadgroup containing one thread.
    fn default() -> Self {
        const UNIT: (u32, u32, u32) = (1, 1, 1);
        MetalKernelParams {
            threads_per_threadgroup: UNIT,
            threadgroups_per_grid: UNIT,
        }
    }
}
/// Typed wrapper around a shared-storage Metal buffer holding `size` elements of `T`.
///
/// Without the `metal` feature the type still exists, but every operation reports
/// `UnsupportedDevice`.
pub struct MetalBuffer<T> {
    /// CPU-visible address of the buffer *contents* (not the Metal object handle).
    _buffer: *mut c_void,
    /// Owning handle; keeps the allocation behind `_buffer` alive for the lifetime of
    /// this wrapper.
    #[cfg(feature = "metal")]
    _owner: metal::Buffer,
    /// Capacity in elements of `T`.
    size: usize,
    _phantom: PhantomData<T>,
}

impl<T> MetalBuffer<T> {
    /// Allocates a shared-storage buffer large enough for `size` elements of `T`.
    ///
    /// # Errors
    /// Returns `InvalidOperation` if the byte size overflows `usize`, and `KernelError`
    /// if the allocation exposes no CPU-visible contents pointer.
    #[cfg(feature = "metal")]
    pub fn new(size: usize, device: &Device) -> RusTorchResult<Self> {
        let buffer_size = size.checked_mul(std::mem::size_of::<T>()).ok_or_else(|| {
            RusTorchError::InvalidOperation(
                "Requested Metal buffer size overflows usize".to_string(),
            )
        })?;
        // Request at least one byte so an empty (or zero-sized-element) buffer never asks
        // Metal for a zero-length allocation.
        let buffer = device.new_buffer(
            buffer_size.max(1) as u64,
            MTLResourceOptions::StorageModeShared,
        );
        // `contents()` is the address host code may read and write. `as_ptr()` would be the
        // Objective-C object handle, which must never be used as element storage.
        let contents = buffer.contents();
        if contents.is_null() {
            return Err(RusTorchError::KernelError(
                "Failed to allocate Metal buffer".to_string(),
            ));
        }
        Ok(Self {
            _buffer: contents,
            // Retain the buffer: dropping it here would release the allocation and leave
            // `_buffer` dangling.
            _owner: buffer,
            size,
            _phantom: PhantomData,
        })
    }

    /// Stub used when the crate is built without Metal support; always fails.
    #[cfg(not(feature = "metal"))]
    pub fn new(_size: usize, _device: &()) -> RusTorchResult<Self> {
        Err(RusTorchError::UnsupportedDevice(
            "Metal not available".to_string(),
        ))
    }

    /// Copies `host_data` into the buffer.
    ///
    /// The copy is bitwise, so `T` is expected to be plain data.
    ///
    /// # Errors
    /// Returns `InvalidOperation` when `host_data.len()` differs from the buffer's element
    /// count, and `UnsupportedDevice` when built without the `metal` feature.
    pub fn copy_from_host(&mut self, host_data: &[T]) -> RusTorchResult<()> {
        if host_data.len() != self.size {
            return Err(RusTorchError::InvalidOperation(
                "Size mismatch in host-to-device copy".to_string(),
            ));
        }
        #[cfg(feature = "metal")]
        {
            // SAFETY: `_buffer` points at the contents of `_owner`, which was allocated with
            // room for `self.size` elements of `T` and is kept alive by `self`. The source
            // slice has exactly `self.size` elements and cannot overlap device memory.
            unsafe {
                std::ptr::copy_nonoverlapping(
                    host_data.as_ptr(),
                    self._buffer as *mut T,
                    host_data.len(),
                );
            }
            Ok(())
        }
        #[cfg(not(feature = "metal"))]
        {
            Err(RusTorchError::UnsupportedDevice(
                "Metal not available".to_string(),
            ))
        }
    }

    /// Copies the buffer contents into `host_data`.
    ///
    /// # Errors
    /// Returns `InvalidOperation` when `host_data.len()` differs from the buffer's element
    /// count, and `UnsupportedDevice` when built without the `metal` feature.
    pub fn copy_to_host(&self, host_data: &mut [T]) -> RusTorchResult<()> {
        if host_data.len() != self.size {
            return Err(RusTorchError::InvalidOperation(
                "Size mismatch in device-to-host copy".to_string(),
            ));
        }
        #[cfg(feature = "metal")]
        {
            // SAFETY: see `copy_from_host`; the destination slice has exactly `self.size`
            // elements and the source allocation is at least that large.
            unsafe {
                std::ptr::copy_nonoverlapping(
                    self._buffer as *const T,
                    host_data.as_mut_ptr(),
                    host_data.len(),
                );
            }
            Ok(())
        }
        #[cfg(not(feature = "metal"))]
        {
            Err(RusTorchError::UnsupportedDevice(
                "Metal not available".to_string(),
            ))
        }
    }
}
/// Executes compute kernels on a Metal device.
///
/// Holds the device, a command queue created from it, the shader library compiled
/// from the embedded source, and the compute pipelines built at construction time.
#[cfg(feature = "metal")]
pub struct MetalKernelExecutor {
    /// Device used for buffer allocation and pipeline creation.
    device: Device,
    /// Queue from which each operation obtains its command buffer.
    command_queue: CommandQueue,
    /// Compiled shader library; kernels are looked up in it by name.
    library: Library,
    /// Pipelines pre-built in `new`, keyed by kernel category.
    pipeline_states: HashMap<MetalKernelType, ComputePipelineState>,
}
#[cfg(feature = "metal")]
impl MetalKernelExecutor {
/// Creates an executor bound to the system default Metal device.
///
/// Compiles the embedded Metal shader source into a library and pre-builds compute
/// pipelines for the element-wise add, plain matmul, tiled matmul and sum-reduction
/// kernels. The tiled matmul pipeline is registered under
/// `MetalKernelType::Convolution`.
///
/// # Errors
/// Fails when no Metal device is available, when the shader source does not compile,
/// or when any of the pre-built kernels cannot be looked up or turned into a pipeline.
pub fn new() -> RusTorchResult<Self> {
    const SHADER_SOURCE: &str = r#"
#include <metal_stdlib>
using namespace metal;
// Optimized element-wise addition with vectorization
kernel void elementwise_add_f32(
device const float* a [[buffer(0)]],
device const float* b [[buffer(1)]],
device float* result [[buffer(2)]],
uint index [[thread_position_in_grid]]
) {
result[index] = a[index] + b[index];
}
// Optimized element-wise multiplication with vectorization
kernel void elementwise_mul_f32(
device const float* a [[buffer(0)]],
device const float* b [[buffer(1)]],
device float* result [[buffer(2)]],
uint index [[thread_position_in_grid]]
) {
result[index] = a[index] * b[index];
}
// High-performance matrix multiplication optimized for Vega 56 architecture
kernel void matrix_multiply_f32(
device const float* a [[buffer(0)]],
device const float* b [[buffer(1)]],
device float* c [[buffer(2)]],
constant uint& M [[buffer(3)]],
constant uint& N [[buffer(4)]],
constant uint& K [[buffer(5)]],
uint2 gid [[thread_position_in_grid]]
) {
uint row = gid.y;
uint col = gid.x;
if (row >= M || col >= N) return;
float sum = 0.0;
// Unroll loop for better performance on Vega architecture
for (uint k = 0; k < K; k++) {
sum += a[row * K + k] * b[k * N + col];
}
c[row * N + col] = sum;
}
// Optimized tiled matrix multiplication for large matrices
kernel void tiled_matrix_multiply_f32(
device const float* a [[buffer(0)]],
device const float* b [[buffer(1)]],
device float* c [[buffer(2)]],
constant uint& M [[buffer(3)]],
constant uint& N [[buffer(4)]],
constant uint& K [[buffer(5)]],
threadgroup float* tile_a [[threadgroup(0)]],
threadgroup float* tile_b [[threadgroup(1)]],
uint2 gid [[thread_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]]
) {
const uint TILE_SIZE = 16; // Optimized for Vega 56 memory hierarchy
uint row = gid.y;
uint col = gid.x;
uint local_row = lid.y;
uint local_col = lid.x;
float sum = 0.0;
// Tiled computation for better memory bandwidth utilization
for (uint tile = 0; tile < (K + TILE_SIZE - 1) / TILE_SIZE; tile++) {
// Load tiles into shared memory
uint a_idx = row * K + tile * TILE_SIZE + local_col;
uint b_idx = (tile * TILE_SIZE + local_row) * N + col;
tile_a[local_row * TILE_SIZE + local_col] =
(row < M && tile * TILE_SIZE + local_col < K) ? a[a_idx] : 0.0;
tile_b[local_row * TILE_SIZE + local_col] =
(tile * TILE_SIZE + local_row < K && col < N) ? b[b_idx] : 0.0;
threadgroup_barrier(mem_flags::mem_threadgroup);
// Compute partial result
for (uint k = 0; k < TILE_SIZE; k++) {
sum += tile_a[local_row * TILE_SIZE + k] *
tile_b[k * TILE_SIZE + local_col];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
if (row < M && col < N) {
c[row * N + col] = sum;
}
}
kernel void reduce_sum_f32(
device const float* input [[buffer(0)]],
device float* output [[buffer(1)]],
threadgroup float* shared_data [[threadgroup(0)]],
uint tid [[thread_index_in_threadgroup]],
uint bid [[threadgroup_position_in_grid]],
uint block_size [[threads_per_threadgroup]],
uint grid_size [[threadgroups_per_grid]]
) {
uint index = bid * block_size + tid;
// Load data into shared memory
shared_data[tid] = (index < grid_size * block_size) ? input[index] : 0.0;
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction in shared memory
for (uint s = block_size / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_data[tid] += shared_data[tid + s];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// Write result for this block
if (tid == 0) {
output[bid] = shared_data[0];
}
}
// Activation functions optimized for Vega 56 architecture
// ReLU activation: max(0, x)
kernel void relu_f32(
device const float* input [[buffer(0)]],
device float* output [[buffer(1)]],
constant uint& n [[buffer(2)]],
uint index [[thread_position_in_grid]]
) {
if (index >= n) return;
output[index] = fmax(0.0f, input[index]);
}
// Sigmoid activation: 1 / (1 + exp(-x))
kernel void sigmoid_f32(
device const float* input [[buffer(0)]],
device float* output [[buffer(1)]],
constant uint& n [[buffer(2)]],
uint index [[thread_position_in_grid]]
) {
if (index >= n) return;
output[index] = 1.0f / (1.0f + exp(-input[index]));
}
// Tanh activation: (exp(x) - exp(-x)) / (exp(x) + exp(-x))
kernel void tanh_f32(
device const float* input [[buffer(0)]],
device float* output [[buffer(1)]],
constant uint& n [[buffer(2)]],
uint index [[thread_position_in_grid]]
) {
if (index >= n) return;
output[index] = tanh(input[index]);
}
// GELU activation: x * 0.5 * (1 + erf(x / sqrt(2))) - using Abramowitz and Stegun approximation
kernel void gelu_f32(
device const float* input [[buffer(0)]],
device float* output [[buffer(1)]],
constant uint& n [[buffer(2)]],
uint index [[thread_position_in_grid]]
) {
if (index >= n) return;
float x = input[index];
float sqrt_2_inv = 0.70710678118f; // 1 / sqrt(2)
float t = x * sqrt_2_inv;
// Abramowitz and Stegun approximation for erf function
float a1 = 0.254829592f;
float a2 = -0.284496736f;
float a3 = 1.421413741f;
float a4 = -1.453152027f;
float a5 = 1.061405429f;
float p = 0.3275911f;
float sign = (t >= 0.0f) ? 1.0f : -1.0f;
t = fabs(t);
float t_p = 1.0f / (1.0f + p * t);
float erf_approx = 1.0f - (((((a5 * t_p + a4) * t_p) + a3) * t_p + a2) * t_p + a1) * t_p * exp(-t * t);
erf_approx *= sign;
output[index] = x * 0.5f * (1.0f + erf_approx);
}
// Softmax activation (per element, requires separate reduction for proper implementation)
kernel void softmax_f32(
device const float* input [[buffer(0)]],
device float* output [[buffer(1)]],
constant uint& n [[buffer(2)]],
constant float& max_val [[buffer(3)]],
constant float& sum_exp [[buffer(4)]],
uint index [[thread_position_in_grid]]
) {
if (index >= n) return;
float exp_val = exp(input[index] - max_val);
output[index] = exp_val / sum_exp;
}
// Leaky ReLU activation: max(alpha * x, x)
kernel void leaky_relu_f32(
device const float* input [[buffer(0)]],
device float* output [[buffer(1)]],
constant uint& n [[buffer(2)]],
constant float& alpha [[buffer(3)]],
uint index [[thread_position_in_grid]]
) {
if (index >= n) return;
float x = input[index];
output[index] = (x > 0.0f) ? x : alpha * x;
}
// ELU activation: x if x > 0, alpha * (exp(x) - 1) if x <= 0
kernel void elu_f32(
device const float* input [[buffer(0)]],
device float* output [[buffer(1)]],
constant uint& n [[buffer(2)]],
constant float& alpha [[buffer(3)]],
uint index [[thread_position_in_grid]]
) {
if (index >= n) return;
float x = input[index];
output[index] = (x > 0.0f) ? x : alpha * (exp(x) - 1.0f);
}
// Swish activation: x * sigmoid(x)
kernel void swish_f32(
device const float* input [[buffer(0)]],
device float* output [[buffer(1)]],
constant uint& n [[buffer(2)]],
uint index [[thread_position_in_grid]]
) {
if (index >= n) return;
float x = input[index];
float sigmoid_x = 1.0f / (1.0f + exp(-x));
output[index] = x * sigmoid_x;
}
"#;

    let device = Device::system_default()
        .ok_or_else(|| RusTorchError::tensor_op("No Metal device available"))?;
    let command_queue = device.new_command_queue();

    let library = device
        .new_library_with_source(SHADER_SOURCE, &CompileOptions::new())
        .map_err(|e| {
            RusTorchError::KernelError(format!("Failed to compile Metal shaders: {:?}", e))
        })?;

    // (lookup key, kernel entry point, label used in error messages), in build order.
    let prebuilt = [
        (MetalKernelType::ElementWise, "elementwise_add_f32", "add"),
        (MetalKernelType::MatMul, "matrix_multiply_f32", "matmul"),
        (
            MetalKernelType::Convolution,
            "tiled_matrix_multiply_f32",
            "tiled matmul",
        ),
        (MetalKernelType::Reduction, "reduce_sum_f32", "reduce"),
    ];

    let mut pipeline_states = HashMap::new();
    for (kernel_type, entry_point, label) in prebuilt {
        let function = library.get_function(entry_point, None).map_err(|e| {
            RusTorchError::KernelError(format!("Failed to get {} function: {:?}", label, e))
        })?;
        let pipeline = device
            .new_compute_pipeline_state_with_function(&function)
            .map_err(|e| {
                RusTorchError::KernelError(format!(
                    "Failed to create {} pipeline: {:?}",
                    label, e
                ))
            })?;
        pipeline_states.insert(kernel_type, pipeline);
    }

    Ok(Self {
        device,
        command_queue,
        library,
        pipeline_states,
    })
}
pub fn elementwise_add_f32(&self, a: &[f32], b: &[f32], c: &mut [f32]) -> RusTorchResult<()> {
let size = a.len();
if b.len() != size || c.len() != size {
return Err(RusTorchError::InvalidOperation(
"Array size mismatch in element-wise addition".to_string(),
));
}
let a_buffer = self.device.new_buffer_with_data(
a.as_ptr() as *const c_void,
(size * std::mem::size_of::<f32>()) as u64,
MTLResourceOptions::StorageModeShared,
);
let b_buffer = self.device.new_buffer_with_data(
b.as_ptr() as *const c_void,
(size * std::mem::size_of::<f32>()) as u64,
MTLResourceOptions::StorageModeShared,
);
let c_buffer = self.device.new_buffer(
(size * std::mem::size_of::<f32>()) as u64,
MTLResourceOptions::StorageModeShared,
);
let pipeline_state = self
.pipeline_states
.get(&MetalKernelType::ElementWise)
.ok_or_else(|| {
RusTorchError::KernelError("ElementWise pipeline not found".to_string())
})?;
let command_buffer = self.command_queue.new_command_buffer();
let compute_encoder = command_buffer.new_compute_command_encoder();
compute_encoder.set_compute_pipeline_state(pipeline_state);
compute_encoder.set_buffer(0, Some(&a_buffer), 0);
compute_encoder.set_buffer(1, Some(&b_buffer), 0);
compute_encoder.set_buffer(2, Some(&c_buffer), 0);
let threads_per_threadgroup = MTLSize::new(256, 1, 1);
let threadgroups_per_grid = MTLSize::new(((size + 255) / 256) as u64, 1, 1);
compute_encoder.dispatch_thread_groups(threadgroups_per_grid, threads_per_threadgroup);
compute_encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
unsafe {
let result_ptr = c_buffer.contents() as *const f32;
std::ptr::copy_nonoverlapping(result_ptr, c.as_mut_ptr(), size);
}
Ok(())
}
/// Multiplies the row-major `m x k` matrix `a` by the row-major `k x n` matrix `b` using
/// the tiled kernel and writes the `m x n` product into the front of `c`.
///
/// The pipeline is looked up under `MetalKernelType::Convolution`, which is where the
/// constructor registers the tiled matmul kernel.
///
/// # Errors
/// Returns `InvalidOperation` when a slice is shorter than its dimensions require or when
/// a dimension or element count does not fit the kernel's 32-bit indices. Returns
/// `KernelError` when the pipeline is missing.
pub fn tiled_matmul_f32(
    &self,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
    m: usize,
    n: usize,
    k: usize,
) -> RusTorchResult<()> {
    const TILE_SIZE: usize = 16;

    let mismatch = || {
        RusTorchError::InvalidOperation("Matrix dimension mismatch in tiled matmul".to_string())
    };
    let a_len = m.checked_mul(k).ok_or_else(mismatch)?;
    let b_len = k.checked_mul(n).ok_or_else(mismatch)?;
    let c_len = m.checked_mul(n).ok_or_else(mismatch)?;
    // The GPU copies below use raw pointers, so the slices must really cover the
    // dimensions the caller claims.
    if a.len() < a_len || b.len() < b_len || c.len() < c_len {
        return Err(mismatch());
    }

    // The kernel receives the dimensions as `uint` and indexes with 32-bit arithmetic.
    let fits_u32 = |value: usize| u32::try_from(value).is_ok();
    let (Ok(m_u32), Ok(n_u32), Ok(k_u32)) =
        (u32::try_from(m), u32::try_from(n), u32::try_from(k))
    else {
        return Err(RusTorchError::InvalidOperation(
            "Matrix dimensions exceed the u32 range supported by the Metal kernel".to_string(),
        ));
    };
    if !(fits_u32(a_len) && fits_u32(b_len) && fits_u32(c_len)) {
        return Err(RusTorchError::InvalidOperation(
            "Matrix dimensions exceed the u32 range supported by the Metal kernel".to_string(),
        ));
    }

    if c_len == 0 {
        return Ok(());
    }
    if k == 0 {
        // An empty inner dimension yields an all-zero product; skip the GPU so no
        // zero-length buffers are requested.
        c[..c_len].fill(0.0);
        return Ok(());
    }

    let pipeline_state = self
        .pipeline_states
        .get(&MetalKernelType::Convolution)
        .ok_or_else(|| {
            RusTorchError::KernelError("Tiled MatMul pipeline not found".to_string())
        })?;

    let float_size = std::mem::size_of::<f32>();
    let a_buffer = self.device.new_buffer_with_data(
        a.as_ptr() as *const c_void,
        (a_len * float_size) as u64,
        MTLResourceOptions::StorageModeShared,
    );
    let b_buffer = self.device.new_buffer_with_data(
        b.as_ptr() as *const c_void,
        (b_len * float_size) as u64,
        MTLResourceOptions::StorageModeShared,
    );
    let c_buffer = self.device.new_buffer(
        (c_len * float_size) as u64,
        MTLResourceOptions::StorageModeShared,
    );

    let command_buffer = self.command_queue.new_command_buffer();
    let compute_encoder = command_buffer.new_compute_command_encoder();
    compute_encoder.set_compute_pipeline_state(pipeline_state);
    compute_encoder.set_buffer(0, Some(&a_buffer), 0);
    compute_encoder.set_buffer(1, Some(&b_buffer), 0);
    compute_encoder.set_buffer(2, Some(&c_buffer), 0);
    // Small constants are passed inline instead of through dedicated buffers.
    let dim_size = std::mem::size_of::<u32>() as u64;
    compute_encoder.set_bytes(3, dim_size, &m_u32 as *const u32 as *const c_void);
    compute_encoder.set_bytes(4, dim_size, &n_u32 as *const u32 as *const c_void);
    compute_encoder.set_bytes(5, dim_size, &k_u32 as *const u32 as *const c_void);

    // Each of the two tiles holds TILE_SIZE * TILE_SIZE floats.
    let tile_bytes = (TILE_SIZE * TILE_SIZE * float_size) as u64;
    compute_encoder.set_threadgroup_memory_length(0, tile_bytes);
    compute_encoder.set_threadgroup_memory_length(1, tile_bytes);

    let threads_per_threadgroup = MTLSize::new(TILE_SIZE as u64, TILE_SIZE as u64, 1);
    let threadgroups_per_grid = MTLSize::new(
        n.div_ceil(TILE_SIZE) as u64,
        m.div_ceil(TILE_SIZE) as u64,
        1,
    );
    compute_encoder.dispatch_thread_groups(threadgroups_per_grid, threads_per_threadgroup);
    compute_encoder.end_encoding();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    // SAFETY: `c_buffer` is a shared-storage buffer of exactly `c_len` floats and the
    // command buffer has completed.
    let gpu_result =
        unsafe { std::slice::from_raw_parts(c_buffer.contents() as *const f32, c_len) };
    c[..c_len].copy_from_slice(gpu_result);
    Ok(())
}
/// Multiplies the row-major `m x k` matrix `a` by the row-major `k x n` matrix `b` and
/// writes the `m x n` product into the front of `c`.
///
/// Delegates to `tiled_matmul_f32` when any dimension is at least 256; otherwise runs the
/// plain kernel registered under `MetalKernelType::MatMul`.
///
/// # Errors
/// Returns `InvalidOperation` when a slice is shorter than its dimensions require or when
/// a dimension or element count does not fit the kernel's 32-bit indices. Returns
/// `KernelError` when the pipeline is missing.
pub fn matmul_f32(
    &self,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
    m: usize,
    n: usize,
    k: usize,
) -> RusTorchResult<()> {
    const GROUP_EDGE: usize = 16;
    const TILED_THRESHOLD: usize = 256;

    let mismatch =
        || RusTorchError::InvalidOperation("Matrix dimension mismatch in matmul".to_string());
    let a_len = m.checked_mul(k).ok_or_else(mismatch)?;
    let b_len = k.checked_mul(n).ok_or_else(mismatch)?;
    let c_len = m.checked_mul(n).ok_or_else(mismatch)?;
    // The GPU copies below use raw pointers, so the slices must really cover the
    // dimensions the caller claims.
    if a.len() < a_len || b.len() < b_len || c.len() < c_len {
        return Err(mismatch());
    }

    if c_len == 0 {
        return Ok(());
    }
    if k == 0 {
        // An empty inner dimension yields an all-zero product; skip the GPU so no
        // zero-length buffers are requested.
        c[..c_len].fill(0.0);
        return Ok(());
    }

    if m >= TILED_THRESHOLD || n >= TILED_THRESHOLD || k >= TILED_THRESHOLD {
        return self.tiled_matmul_f32(a, b, c, m, n, k);
    }

    // Every dimension is below 256 here, so these conversions cannot fail; they are kept
    // checked so a later change to the threshold cannot reintroduce silent truncation.
    let (Ok(m_u32), Ok(n_u32), Ok(k_u32)) =
        (u32::try_from(m), u32::try_from(n), u32::try_from(k))
    else {
        return Err(RusTorchError::InvalidOperation(
            "Matrix dimensions exceed the u32 range supported by the Metal kernel".to_string(),
        ));
    };

    let pipeline_state = self
        .pipeline_states
        .get(&MetalKernelType::MatMul)
        .ok_or_else(|| RusTorchError::KernelError("MatMul pipeline not found".to_string()))?;

    let float_size = std::mem::size_of::<f32>();
    let a_buffer = self.device.new_buffer_with_data(
        a.as_ptr() as *const c_void,
        (a_len * float_size) as u64,
        MTLResourceOptions::StorageModeShared,
    );
    let b_buffer = self.device.new_buffer_with_data(
        b.as_ptr() as *const c_void,
        (b_len * float_size) as u64,
        MTLResourceOptions::StorageModeShared,
    );
    let c_buffer = self.device.new_buffer(
        (c_len * float_size) as u64,
        MTLResourceOptions::StorageModeShared,
    );

    let command_buffer = self.command_queue.new_command_buffer();
    let compute_encoder = command_buffer.new_compute_command_encoder();
    compute_encoder.set_compute_pipeline_state(pipeline_state);
    compute_encoder.set_buffer(0, Some(&a_buffer), 0);
    compute_encoder.set_buffer(1, Some(&b_buffer), 0);
    compute_encoder.set_buffer(2, Some(&c_buffer), 0);
    // Small constants are passed inline instead of through dedicated buffers.
    let dim_size = std::mem::size_of::<u32>() as u64;
    compute_encoder.set_bytes(3, dim_size, &m_u32 as *const u32 as *const c_void);
    compute_encoder.set_bytes(4, dim_size, &n_u32 as *const u32 as *const c_void);
    compute_encoder.set_bytes(5, dim_size, &k_u32 as *const u32 as *const c_void);

    // One thread per output element, rounded up to whole 16x16 groups; the kernel's
    // `row >= M || col >= N` guard discards the excess threads.
    let threads_per_threadgroup = MTLSize::new(GROUP_EDGE as u64, GROUP_EDGE as u64, 1);
    let threadgroups_per_grid = MTLSize::new(
        n.div_ceil(GROUP_EDGE) as u64,
        m.div_ceil(GROUP_EDGE) as u64,
        1,
    );
    compute_encoder.dispatch_thread_groups(threadgroups_per_grid, threads_per_threadgroup);
    compute_encoder.end_encoding();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    // SAFETY: `c_buffer` is a shared-storage buffer of exactly `c_len` floats and the
    // command buffer has completed.
    let gpu_result =
        unsafe { std::slice::from_raw_parts(c_buffer.contents() as *const f32, c_len) };
    c[..c_len].copy_from_slice(gpu_result);
    Ok(())
}
/// Multiplies the row-major `m x k` matrix `a` by the row-major `k x n` matrix `b` and
/// returns the `m x n` product as a newly allocated vector.
///
/// # Errors
/// Returns `InvalidOperation("Matrix dimension mismatch")` when `a.len() != m * k` or
/// `b.len() != k * n`, or when any of the element counts overflows `usize`. Errors from
/// `matmul_f32` are propagated.
pub fn matrix_multiply_f32(
    &self,
    a: &[f32],
    b: &[f32],
    m: usize,
    n: usize,
    k: usize,
) -> RusTorchResult<Vec<f32>> {
    let mismatch = || RusTorchError::InvalidOperation("Matrix dimension mismatch".to_string());

    // Checked products: a wrapped multiplication could otherwise make the length
    // comparison pass for dimensions that do not describe the slices at all.
    let a_len = m.checked_mul(k).ok_or_else(mismatch)?;
    let b_len = k.checked_mul(n).ok_or_else(mismatch)?;
    let c_len = m.checked_mul(n).ok_or_else(mismatch)?;
    if a.len() != a_len || b.len() != b_len {
        return Err(mismatch());
    }

    let mut result = vec![0.0f32; c_len];
    self.matmul_f32(a, b, &mut result, m, n, k)?;
    Ok(result)
}
pub fn conv2d_f32(
&self,
input: &[f32],
kernel: &[f32],
output: &mut [f32],
input_height: usize,
input_width: usize,
input_channels: usize,
output_channels: usize,
kernel_height: usize,
kernel_width: usize,
stride_h: usize,
stride_w: usize,
pad_h: usize,
pad_w: usize,
) -> RusTorchResult<()> {
let output_height = (input_height + 2 * pad_h - kernel_height) / stride_h + 1;
let output_width = (input_width + 2 * pad_w - kernel_width) / stride_w + 1;
let patch_size = kernel_height * kernel_width * input_channels;
let num_patches = output_height * output_width;
let mut im2col_matrix = vec![0.0f32; patch_size * num_patches];
for out_h in 0..output_height {
for out_w in 0..output_width {
let patch_idx = out_h * output_width + out_w;
let start_h = out_h * stride_h;
let start_w = out_w * stride_w;
for kh in 0..kernel_height {
for kw in 0..kernel_width {
for ch in 0..input_channels {
let in_h = start_h + kh;
let in_w = start_w + kw;
let im2col_row =
ch * kernel_height * kernel_width + kh * kernel_width + kw;
let im2col_idx = im2col_row * num_patches + patch_idx;
if in_h >= pad_h
&& in_h < input_height + pad_h
&& in_w >= pad_w
&& in_w < input_width + pad_w
{
let actual_in_h = in_h - pad_h;
let actual_in_w = in_w - pad_w;
let input_idx = ch * input_height * input_width
+ actual_in_h * input_width
+ actual_in_w;
im2col_matrix[im2col_idx] = input[input_idx];
} else {
im2col_matrix[im2col_idx] = 0.0; }
}
}
}
}
}
let mut result_matrix = vec![0.0f32; output_channels * num_patches];
self.matmul_f32(
kernel,
&im2col_matrix,
&mut result_matrix,
output_channels,
num_patches,
patch_size,
)?;
for out_ch in 0..output_channels {
for out_h in 0..output_height {
for out_w in 0..output_width {
let result_idx = out_ch * num_patches + out_h * output_width + out_w;
let output_idx =
out_ch * output_height * output_width + out_h * output_width + out_w;
output[output_idx] = result_matrix[result_idx];
}
}
}
Ok(())
}
pub fn reduce_sum_f32(&self, input: &[f32]) -> RusTorchResult<f32> {
let size = input.len();
let block_size = 256;
let grid_size = size.div_ceil(block_size);
let input_buffer = self.device.new_buffer_with_data(
input.as_ptr() as *const c_void,
(size * std::mem::size_of::<f32>()) as u64,
MTLResourceOptions::StorageModeShared,
);
let output_buffer = self.device.new_buffer(
(grid_size * std::mem::size_of::<f32>()) as u64,
MTLResourceOptions::StorageModeShared,
);
let pipeline_state = self
.pipeline_states
.get(&MetalKernelType::Reduction)
.ok_or_else(|| {
RusTorchError::KernelError("Reduction pipeline not found".to_string())
})?;
let command_buffer = self.command_queue.new_command_buffer();
let compute_encoder = command_buffer.new_compute_command_encoder();
compute_encoder.set_compute_pipeline_state(pipeline_state);
compute_encoder.set_buffer(0, Some(&input_buffer), 0);
compute_encoder.set_buffer(1, Some(&output_buffer), 0);
compute_encoder
.set_threadgroup_memory_length(0, (block_size * std::mem::size_of::<f32>()) as u64);
let threads_per_threadgroup = MTLSize::new(block_size as u64, 1, 1);
let threadgroups_per_grid = MTLSize::new(grid_size as u64, 1, 1);
compute_encoder.dispatch_thread_groups(threadgroups_per_grid, threads_per_threadgroup);
compute_encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
let mut partial_results = vec![0.0f32; grid_size];
unsafe {
let result_ptr = output_buffer.contents() as *const f32;
std::ptr::copy_nonoverlapping(result_ptr, partial_results.as_mut_ptr(), grid_size);
}
Ok(partial_results.iter().sum())
}
pub fn relu_f32(&self, input: &[f32], output: &mut [f32]) -> RusTorchResult<()> {
if input.len() != output.len() {
return Err(RusTorchError::InvalidOperation(
"Input and output lengths must match".to_string(),
));
}
let n = input.len();
let input_buffer = self.device.new_buffer_with_data(
input.as_ptr() as *const c_void,
(n * std::mem::size_of::<f32>()) as u64,
metal::MTLResourceOptions::StorageModeShared,
);
let output_buffer = self.device.new_buffer(
(n * std::mem::size_of::<f32>()) as u64,
metal::MTLResourceOptions::StorageModeShared,
);
let function = self
.library
.get_function("relu_f32", None)
.map_err(|e| RusTorchError::gpu(&format!("Failed to get ReLU function: {}", e)))?;
let pipeline_state = self
.device
.new_compute_pipeline_state_with_function(&function)
.map_err(|e| RusTorchError::gpu(&format!("Failed to create ReLU pipeline: {}", e)))?;
let command_buffer = self.command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline_state);
encoder.set_buffer(0, Some(&input_buffer), 0);
encoder.set_buffer(1, Some(&output_buffer), 0);
encoder.set_bytes(
2,
std::mem::size_of::<u32>() as u64,
&(n as u32) as *const u32 as *const c_void,
);
let threads_per_threadgroup = metal::MTLSize::new(256, 1, 1);
let threadgroups = metal::MTLSize::new(((n + 255) / 256).try_into().unwrap(), 1, 1);
encoder.dispatch_thread_groups(threadgroups, threads_per_threadgroup);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
unsafe {
let result_ptr = output_buffer.contents() as *const f32;
std::ptr::copy_nonoverlapping(result_ptr, output.as_mut_ptr(), n);
}
Ok(())
}
pub fn sigmoid_f32(&self, input: &[f32], output: &mut [f32]) -> RusTorchResult<()> {
if input.len() != output.len() {
return Err(RusTorchError::InvalidOperation(
"Input and output lengths must match".to_string(),
));
}
let n = input.len();
let input_buffer = self.device.new_buffer_with_data(
input.as_ptr() as *const c_void,
(n * std::mem::size_of::<f32>()) as u64,
metal::MTLResourceOptions::StorageModeShared,
);
let output_buffer = self.device.new_buffer(
(n * std::mem::size_of::<f32>()) as u64,
metal::MTLResourceOptions::StorageModeShared,
);
let function = self
.library
.get_function("sigmoid_f32", None)
.map_err(|e| RusTorchError::gpu(&format!("Failed to get Sigmoid function: {}", e)))?;
let pipeline_state = self
.device
.new_compute_pipeline_state_with_function(&function)
.map_err(|e| {
RusTorchError::gpu(&format!("Failed to create Sigmoid pipeline: {}", e))
})?;
let command_buffer = self.command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline_state);
encoder.set_buffer(0, Some(&input_buffer), 0);
encoder.set_buffer(1, Some(&output_buffer), 0);
encoder.set_bytes(
2,
std::mem::size_of::<u32>() as u64,
&(n as u32) as *const u32 as *const c_void,
);
let threads_per_threadgroup = metal::MTLSize::new(256, 1, 1);
let threadgroups = metal::MTLSize::new(((n + 255) / 256).try_into().unwrap(), 1, 1);
encoder.dispatch_thread_groups(threadgroups, threads_per_threadgroup);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
unsafe {
let result_ptr = output_buffer.contents() as *const f32;
std::ptr::copy_nonoverlapping(result_ptr, output.as_mut_ptr(), n);
}
Ok(())
}
pub fn tanh_f32(&self, input: &[f32], output: &mut [f32]) -> RusTorchResult<()> {
if input.len() != output.len() {
return Err(RusTorchError::InvalidOperation(
"Input and output lengths must match".to_string(),
));
}
let n = input.len();
let input_buffer = self.device.new_buffer_with_data(
input.as_ptr() as *const c_void,
(n * std::mem::size_of::<f32>()) as u64,
metal::MTLResourceOptions::StorageModeShared,
);
let output_buffer = self.device.new_buffer(
(n * std::mem::size_of::<f32>()) as u64,
metal::MTLResourceOptions::StorageModeShared,
);
let function = self
.library
.get_function("tanh_f32", None)
.map_err(|e| RusTorchError::gpu(&format!("Failed to get Tanh function: {}", e)))?;
let pipeline_state = self
.device
.new_compute_pipeline_state_with_function(&function)
.map_err(|e| RusTorchError::gpu(&format!("Failed to create Tanh pipeline: {}", e)))?;
let command_buffer = self.command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline_state);
encoder.set_buffer(0, Some(&input_buffer), 0);
encoder.set_buffer(1, Some(&output_buffer), 0);
encoder.set_bytes(
2,
std::mem::size_of::<u32>() as u64,
&(n as u32) as *const u32 as *const c_void,
);
let threads_per_threadgroup = metal::MTLSize::new(256, 1, 1);
let threadgroups = metal::MTLSize::new(((n + 255) / 256).try_into().unwrap(), 1, 1);
encoder.dispatch_thread_groups(threadgroups, threads_per_threadgroup);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
unsafe {
let result_ptr = output_buffer.contents() as *const f32;
std::ptr::copy_nonoverlapping(result_ptr, output.as_mut_ptr(), n);
}
Ok(())
}
pub fn gelu_f32(&self, input: &[f32], output: &mut [f32]) -> RusTorchResult<()> {
if input.len() != output.len() {
return Err(RusTorchError::InvalidOperation(
"Input and output lengths must match".to_string(),
));
}
let n = input.len();
let input_buffer = self.device.new_buffer_with_data(
input.as_ptr() as *const c_void,
(n * std::mem::size_of::<f32>()) as u64,
metal::MTLResourceOptions::StorageModeShared,
);
let output_buffer = self.device.new_buffer(
(n * std::mem::size_of::<f32>()) as u64,
metal::MTLResourceOptions::StorageModeShared,
);
let function = self
.library
.get_function("gelu_f32", None)
.map_err(|e| RusTorchError::gpu(&format!("Failed to get GELU function: {}", e)))?;
let pipeline_state = self
.device
.new_compute_pipeline_state_with_function(&function)
.map_err(|e| RusTorchError::gpu(&format!("Failed to create GELU pipeline: {}", e)))?;
let command_buffer = self.command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline_state);
encoder.set_buffer(0, Some(&input_buffer), 0);
encoder.set_buffer(1, Some(&output_buffer), 0);
encoder.set_bytes(
2,
std::mem::size_of::<u32>() as u64,
&(n as u32) as *const u32 as *const c_void,
);
let threads_per_threadgroup = metal::MTLSize::new(256, 1, 1);
let threadgroups = metal::MTLSize::new(((n + 255) / 256).try_into().unwrap(), 1, 1);
encoder.dispatch_thread_groups(threadgroups, threads_per_threadgroup);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
unsafe {
let result_ptr = output_buffer.contents() as *const f32;
std::ptr::copy_nonoverlapping(result_ptr, output.as_mut_ptr(), n);
}
Ok(())
}
pub fn leaky_relu_f32(
&self,
input: &[f32],
output: &mut [f32],
alpha: f32,
) -> RusTorchResult<()> {
if input.len() != output.len() {
return Err(RusTorchError::InvalidOperation(
"Input and output lengths must match".to_string(),
));
}
let n = input.len();
let input_buffer = self.device.new_buffer_with_data(
input.as_ptr() as *const c_void,
(n * std::mem::size_of::<f32>()) as u64,
metal::MTLResourceOptions::StorageModeShared,
);
let output_buffer = self.device.new_buffer(
(n * std::mem::size_of::<f32>()) as u64,
metal::MTLResourceOptions::StorageModeShared,
);
let function = self
.library
.get_function("leaky_relu_f32", None)
.map_err(|e| {
RusTorchError::gpu(&format!("Failed to get Leaky ReLU function: {}", e))
})?;
let pipeline_state = self
.device
.new_compute_pipeline_state_with_function(&function)
.map_err(|e| {
RusTorchError::gpu(&format!("Failed to create Leaky ReLU pipeline: {}", e))
})?;
let command_buffer = self.command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline_state);
encoder.set_buffer(0, Some(&input_buffer), 0);
encoder.set_buffer(1, Some(&output_buffer), 0);
encoder.set_bytes(
2,
std::mem::size_of::<u32>() as u64,
&(n as u32) as *const u32 as *const c_void,
);
encoder.set_bytes(
3,
std::mem::size_of::<f32>() as u64,
&alpha as *const f32 as *const c_void,
);
let threads_per_threadgroup = metal::MTLSize::new(256, 1, 1);
let threadgroups = metal::MTLSize::new(((n + 255) / 256).try_into().unwrap(), 1, 1);
encoder.dispatch_thread_groups(threadgroups, threads_per_threadgroup);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
unsafe {
let result_ptr = output_buffer.contents() as *const f32;
std::ptr::copy_nonoverlapping(result_ptr, output.as_mut_ptr(), n);
}
Ok(())
}
/// Runs the `elu_f32` compute kernel (ELU activation) over `input` and copies the
/// result into `output`.
///
/// The kernel is bound with the input buffer at index 0, the output buffer at
/// index 1, the element count (`u32`) at index 2 and `alpha` at index 3. The call
/// blocks until the command buffer has completed.
///
/// # Errors
///
/// Returns `RusTorchError::InvalidOperation` when the slices differ in length or
/// when the length does not fit in the `u32` element count handed to the kernel.
/// Returns a GPU error when the kernel function cannot be looked up or its
/// pipeline state cannot be created.
pub fn elu_f32(&self, input: &[f32], output: &mut [f32], alpha: f32) -> RusTorchResult<()> {
    if input.len() != output.len() {
        return Err(RusTorchError::InvalidOperation(
            "Input and output lengths must match".to_string(),
        ));
    }
    let n = input.len();
    // Nothing to compute. Returning here also avoids creating zero-length Metal
    // buffers and dispatching an empty grid.
    if n == 0 {
        return Ok(());
    }
    // The kernel receives the element count as a `u32`; reject lengths that a
    // plain `as` cast would silently truncate.
    let element_count: u32 = n.try_into().map_err(|_| {
        RusTorchError::InvalidOperation(
            "Input length exceeds the maximum supported by the Metal kernel".to_string(),
        )
    })?;
    let byte_len = (n * std::mem::size_of::<f32>()) as u64;
    let input_buffer = self.device.new_buffer_with_data(
        input.as_ptr() as *const c_void,
        byte_len,
        metal::MTLResourceOptions::StorageModeShared,
    );
    let output_buffer = self
        .device
        .new_buffer(byte_len, metal::MTLResourceOptions::StorageModeShared);
    let function = self
        .library
        .get_function("elu_f32", None)
        .map_err(|e| RusTorchError::gpu(&format!("Failed to get ELU function: {}", e)))?;
    let pipeline_state = self
        .device
        .new_compute_pipeline_state_with_function(&function)
        .map_err(|e| RusTorchError::gpu(&format!("Failed to create ELU pipeline: {}", e)))?;
    let command_buffer = self.command_queue.new_command_buffer();
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline_state);
    encoder.set_buffer(0, Some(&input_buffer), 0);
    encoder.set_buffer(1, Some(&output_buffer), 0);
    encoder.set_bytes(
        2,
        std::mem::size_of::<u32>() as u64,
        &element_count as *const u32 as *const c_void,
    );
    encoder.set_bytes(
        3,
        std::mem::size_of::<f32>() as u64,
        &alpha as *const f32 as *const c_void,
    );
    // One thread per element, rounded up to whole 256-thread groups.
    let threads_per_threadgroup = metal::MTLSize::new(256, 1, 1);
    let threadgroups = metal::MTLSize::new(
        ((n + 255) / 256)
            .try_into()
            .expect("threadgroup count fits in the Metal size type"),
        1,
        1,
    );
    encoder.dispatch_thread_groups(threadgroups, threads_per_threadgroup);
    encoder.end_encoding();
    command_buffer.commit();
    command_buffer.wait_until_completed();
    // SAFETY: `output_buffer` was allocated with room for `n` f32 values in shared
    // storage, the command buffer has completed, and `output` holds exactly `n`
    // elements (checked against `input.len()` above).
    unsafe {
        let result_ptr = output_buffer.contents() as *const f32;
        std::ptr::copy_nonoverlapping(result_ptr, output.as_mut_ptr(), n);
    }
    Ok(())
}
/// Runs the `swish_f32` compute kernel (Swish activation) over `input` and copies
/// the result into `output`.
///
/// The kernel is bound with the input buffer at index 0, the output buffer at
/// index 1 and the element count (`u32`) at index 2. The call blocks until the
/// command buffer has completed.
///
/// # Errors
///
/// Returns `RusTorchError::InvalidOperation` when the slices differ in length or
/// when the length does not fit in the `u32` element count handed to the kernel.
/// Returns a GPU error when the kernel function cannot be looked up or its
/// pipeline state cannot be created.
pub fn swish_f32(&self, input: &[f32], output: &mut [f32]) -> RusTorchResult<()> {
    if input.len() != output.len() {
        return Err(RusTorchError::InvalidOperation(
            "Input and output lengths must match".to_string(),
        ));
    }
    let n = input.len();
    // Nothing to compute. Returning here also avoids creating zero-length Metal
    // buffers and dispatching an empty grid.
    if n == 0 {
        return Ok(());
    }
    // The kernel receives the element count as a `u32`; reject lengths that a
    // plain `as` cast would silently truncate.
    let element_count: u32 = n.try_into().map_err(|_| {
        RusTorchError::InvalidOperation(
            "Input length exceeds the maximum supported by the Metal kernel".to_string(),
        )
    })?;
    let byte_len = (n * std::mem::size_of::<f32>()) as u64;
    let input_buffer = self.device.new_buffer_with_data(
        input.as_ptr() as *const c_void,
        byte_len,
        metal::MTLResourceOptions::StorageModeShared,
    );
    let output_buffer = self
        .device
        .new_buffer(byte_len, metal::MTLResourceOptions::StorageModeShared);
    let function = self
        .library
        .get_function("swish_f32", None)
        .map_err(|e| RusTorchError::gpu(&format!("Failed to get Swish function: {}", e)))?;
    let pipeline_state = self
        .device
        .new_compute_pipeline_state_with_function(&function)
        .map_err(|e| RusTorchError::gpu(&format!("Failed to create Swish pipeline: {}", e)))?;
    let command_buffer = self.command_queue.new_command_buffer();
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline_state);
    encoder.set_buffer(0, Some(&input_buffer), 0);
    encoder.set_buffer(1, Some(&output_buffer), 0);
    encoder.set_bytes(
        2,
        std::mem::size_of::<u32>() as u64,
        &element_count as *const u32 as *const c_void,
    );
    // One thread per element, rounded up to whole 256-thread groups.
    let threads_per_threadgroup = metal::MTLSize::new(256, 1, 1);
    let threadgroups = metal::MTLSize::new(
        ((n + 255) / 256)
            .try_into()
            .expect("threadgroup count fits in the Metal size type"),
        1,
        1,
    );
    encoder.dispatch_thread_groups(threadgroups, threads_per_threadgroup);
    encoder.end_encoding();
    command_buffer.commit();
    command_buffer.wait_until_completed();
    // SAFETY: `output_buffer` was allocated with room for `n` f32 values in shared
    // storage, the command buffer has completed, and `output` holds exactly `n`
    // elements (checked against `input.len()` above).
    unsafe {
        let result_ptr = output_buffer.contents() as *const f32;
        std::ptr::copy_nonoverlapping(result_ptr, output.as_mut_ptr(), n);
    }
    Ok(())
}
}
/// Placeholder executor compiled when the `metal` feature is disabled.
///
/// It carries no state; its constructor and methods all report that Metal is
/// unavailable.
#[cfg(not(feature = "metal"))]
pub struct MetalKernelExecutor;
#[cfg(not(feature = "metal"))]
impl MetalKernelExecutor {
    /// Builds the error shared by every stub entry point in this impl.
    fn unsupported<T>() -> RusTorchResult<T> {
        Err(RusTorchError::UnsupportedDevice(String::from(
            "Metal not available",
        )))
    }

    /// Always fails with `UnsupportedDevice`, because Metal support is compiled out.
    pub fn new() -> RusTorchResult<Self> {
        Self::unsupported()
    }

    /// Stub for element-wise addition; always returns `UnsupportedDevice`.
    pub fn elementwise_add_f32(
        &self,
        _a: &[f32],
        _b: &[f32],
        _c: &mut [f32],
    ) -> RusTorchResult<()> {
        Self::unsupported()
    }

    /// Stub for matrix multiplication; always returns `UnsupportedDevice`.
    pub fn matmul_f32(
        &self,
        _a: &[f32],
        _b: &[f32],
        _c: &mut [f32],
        _m: usize,
        _n: usize,
        _k: usize,
    ) -> RusTorchResult<()> {
        Self::unsupported()
    }

    /// Stub for sum reduction; always returns `UnsupportedDevice`.
    pub fn reduce_sum_f32(&self, _input: &[f32]) -> RusTorchResult<f32> {
        Self::unsupported()
    }
}
/// Convenience wrapper around `MetalKernelExecutor::matmul_f32`.
///
/// With the `metal` feature enabled, a fresh executor is created for the call and
/// all arguments are forwarded unchanged. Without it, the function returns
/// `UnsupportedDevice` and ignores its arguments.
pub fn metal_matmul_f32(
    _a: &[f32],
    _b: &[f32],
    _c: &mut [f32],
    _m: usize,
    _n: usize,
    _k: usize,
) -> RusTorchResult<()> {
    #[cfg(feature = "metal")]
    {
        MetalKernelExecutor::new()?.matmul_f32(_a, _b, _c, _m, _n, _k)
    }
    #[cfg(not(feature = "metal"))]
    {
        Err(RusTorchError::UnsupportedDevice(String::from(
            "Metal not available",
        )))
    }
}
/// Convenience wrapper around `MetalKernelExecutor::elementwise_add_f32`.
///
/// With the `metal` feature enabled, a fresh executor is created for the call and
/// the three slices are forwarded unchanged. Without it, the function returns
/// `UnsupportedDevice` and ignores its arguments.
pub fn metal_elementwise_add_f32(_a: &[f32], _b: &[f32], _c: &mut [f32]) -> RusTorchResult<()> {
    #[cfg(feature = "metal")]
    {
        MetalKernelExecutor::new()?.elementwise_add_f32(_a, _b, _c)
    }
    #[cfg(not(feature = "metal"))]
    {
        Err(RusTorchError::UnsupportedDevice(String::from(
            "Metal not available",
        )))
    }
}
/// Convenience wrapper around `MetalKernelExecutor::reduce_sum_f32`.
///
/// With the `metal` feature enabled, a fresh executor is created for the call and
/// its result is returned as-is. Without it, the function returns
/// `UnsupportedDevice` and ignores its argument.
pub fn metal_reduce_sum_f32(_input: &[f32]) -> RusTorchResult<f32> {
    #[cfg(feature = "metal")]
    {
        MetalKernelExecutor::new()?.reduce_sum_f32(_input)
    }
    #[cfg(not(feature = "metal"))]
    {
        Err(RusTorchError::UnsupportedDevice(String::from(
            "Metal not available",
        )))
    }
}
/// Convenience wrapper around `MetalKernelExecutor::conv2d_f32`.
///
/// With the `metal` feature enabled, a fresh executor is created for the call and
/// every argument is forwarded unchanged, in the same order. Without it, the
/// function returns `UnsupportedDevice` and ignores its arguments.
///
/// # Errors
///
/// Propagates any error from creating the executor or running the kernel, or
/// returns `RusTorchError::UnsupportedDevice` when Metal support is compiled out.
pub fn metal_conv2d_f32(
    input: &[f32],
    kernel: &[f32],
    output: &mut [f32],
    input_height: usize,
    input_width: usize,
    input_channels: usize,
    output_channels: usize,
    kernel_height: usize,
    kernel_width: usize,
    stride_h: usize,
    stride_w: usize,
    pad_h: usize,
    pad_w: usize,
) -> RusTorchResult<()> {
    #[cfg(feature = "metal")]
    {
        let executor = MetalKernelExecutor::new()?;
        executor.conv2d_f32(
            input,
            kernel,
            output,
            input_height,
            input_width,
            input_channels,
            output_channels,
            kernel_height,
            kernel_width,
            stride_h,
            stride_w,
            pad_h,
            pad_w,
        )
    }
    #[cfg(not(feature = "metal"))]
    {
        // Only the Metal path consumes the arguments; mention them here so builds
        // without the feature do not emit `unused_variables` warnings.
        let _ = (
            input,
            kernel,
            output,
            input_height,
            input_width,
            input_channels,
            output_channels,
            kernel_height,
            kernel_width,
            stride_h,
            stride_w,
            pad_h,
            pad_w,
        );
        Err(RusTorchError::UnsupportedDevice(
            "Metal not available".to_string(),
        ))
    }
}
/// Creates a fresh `MetalKernelExecutor` and delegates to its `relu_f32` method.
#[cfg(feature = "metal")]
pub fn metal_relu_f32(input: &[f32], output: &mut [f32]) -> RusTorchResult<()> {
    MetalKernelExecutor::new()?.relu_f32(input, output)
}
/// Stub used when the `metal` feature is disabled; always returns `UnsupportedDevice`.
#[cfg(not(feature = "metal"))]
pub fn metal_relu_f32(_input: &[f32], _output: &mut [f32]) -> RusTorchResult<()> {
    let reason = String::from("Metal not available");
    Err(RusTorchError::UnsupportedDevice(reason))
}
/// Creates a fresh `MetalKernelExecutor` and delegates to its `sigmoid_f32` method.
#[cfg(feature = "metal")]
pub fn metal_sigmoid_f32(input: &[f32], output: &mut [f32]) -> RusTorchResult<()> {
    MetalKernelExecutor::new()?.sigmoid_f32(input, output)
}
/// Stub used when the `metal` feature is disabled; always returns `UnsupportedDevice`.
#[cfg(not(feature = "metal"))]
pub fn metal_sigmoid_f32(_input: &[f32], _output: &mut [f32]) -> RusTorchResult<()> {
    let reason = String::from("Metal not available");
    Err(RusTorchError::UnsupportedDevice(reason))
}
/// Creates a fresh `MetalKernelExecutor` and delegates to its `tanh_f32` method.
#[cfg(feature = "metal")]
pub fn metal_tanh_f32(input: &[f32], output: &mut [f32]) -> RusTorchResult<()> {
    MetalKernelExecutor::new()?.tanh_f32(input, output)
}
/// Stub used when the `metal` feature is disabled; always returns `UnsupportedDevice`.
#[cfg(not(feature = "metal"))]
pub fn metal_tanh_f32(_input: &[f32], _output: &mut [f32]) -> RusTorchResult<()> {
    let reason = String::from("Metal not available");
    Err(RusTorchError::UnsupportedDevice(reason))
}
/// Creates a fresh `MetalKernelExecutor` and delegates to its `gelu_f32` method.
#[cfg(feature = "metal")]
pub fn metal_gelu_f32(input: &[f32], output: &mut [f32]) -> RusTorchResult<()> {
    MetalKernelExecutor::new()?.gelu_f32(input, output)
}
/// Stub used when the `metal` feature is disabled; always returns `UnsupportedDevice`.
#[cfg(not(feature = "metal"))]
pub fn metal_gelu_f32(_input: &[f32], _output: &mut [f32]) -> RusTorchResult<()> {
    let reason = String::from("Metal not available");
    Err(RusTorchError::UnsupportedDevice(reason))
}
/// Creates a fresh `MetalKernelExecutor` and delegates to its `leaky_relu_f32`
/// method, forwarding `alpha` unchanged.
#[cfg(feature = "metal")]
pub fn metal_leaky_relu_f32(input: &[f32], output: &mut [f32], alpha: f32) -> RusTorchResult<()> {
    MetalKernelExecutor::new()?.leaky_relu_f32(input, output, alpha)
}
/// Stub used when the `metal` feature is disabled; always returns `UnsupportedDevice`.
#[cfg(not(feature = "metal"))]
pub fn metal_leaky_relu_f32(
    _input: &[f32],
    _output: &mut [f32],
    _alpha: f32,
) -> RusTorchResult<()> {
    let reason = String::from("Metal not available");
    Err(RusTorchError::UnsupportedDevice(reason))
}
/// Creates a fresh `MetalKernelExecutor` and delegates to its `elu_f32` method,
/// forwarding `alpha` unchanged.
#[cfg(feature = "metal")]
pub fn metal_elu_f32(input: &[f32], output: &mut [f32], alpha: f32) -> RusTorchResult<()> {
    MetalKernelExecutor::new()?.elu_f32(input, output, alpha)
}
/// Stub used when the `metal` feature is disabled; always returns `UnsupportedDevice`.
#[cfg(not(feature = "metal"))]
pub fn metal_elu_f32(_input: &[f32], _output: &mut [f32], _alpha: f32) -> RusTorchResult<()> {
    let reason = String::from("Metal not available");
    Err(RusTorchError::UnsupportedDevice(reason))
}
/// Creates a fresh `MetalKernelExecutor` and delegates to its `swish_f32` method.
#[cfg(feature = "metal")]
pub fn metal_swish_f32(input: &[f32], output: &mut [f32]) -> RusTorchResult<()> {
    MetalKernelExecutor::new()?.swish_f32(input, output)
}
/// Stub used when the `metal` feature is disabled; always returns `UnsupportedDevice`.
#[cfg(not(feature = "metal"))]
pub fn metal_swish_f32(_input: &[f32], _output: &mut [f32]) -> RusTorchResult<()> {
    let reason = String::from("Metal not available");
    Err(RusTorchError::UnsupportedDevice(reason))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default launch parameters are a single thread in a single threadgroup.
    #[test]
    fn test_metal_kernel_params() {
        let params = MetalKernelParams::default();
        assert_eq!(params.threads_per_threadgroup, (1, 1, 1));
        assert_eq!(params.threadgroups_per_grid, (1, 1, 1));
    }

    /// Without the `metal` feature, executor creation must fail.
    #[test]
    fn test_metal_executor_creation() {
        let result = MetalKernelExecutor::new();
        #[cfg(not(feature = "metal"))]
        assert!(result.is_err());
        // With the feature enabled no outcome is asserted; mention the value so the
        // build does not emit an `unused_variables` warning.
        #[cfg(feature = "metal")]
        let _ = result;
    }

    /// Kernel type variants compare equal to themselves and unequal to each other.
    #[test]
    fn test_metal_kernel_types() {
        assert_eq!(MetalKernelType::ElementWise, MetalKernelType::ElementWise);
        assert_ne!(MetalKernelType::ElementWise, MetalKernelType::MatMul);
    }
}