#[cfg(all(target_os = "macos", feature = "metal"))]
use super::device::MetalDevice;
#[cfg(all(target_os = "macos", feature = "metal"))]
use super::types::{ActivationType, ElementwiseOp, MetalKernelConfig, ReductionOp};
#[cfg(all(target_os = "macos", feature = "metal"))]
use crate::{Result, Tensor, TensorError};
#[cfg(all(target_os = "macos", feature = "metal"))]
use metal;
#[cfg(all(target_os = "macos", feature = "metal"))]
impl MetalDevice {
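    /// Matrix multiplication of two 2D tensors. Despite the `_mps` suffix, this
    /// currently dispatches the custom `optimized_matmul` compute kernel rather
    /// than MPSMatrixMultiplication.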
pub fn matmul_mps<T>(&mut self, a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>>
where
T: Clone + Default + Send + Sync + 'static,
{
self.execute_mps_gemm(a, b)
}
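    /// 2D convolution over NCHW tensors with optional per-channel bias, using the
    /// given `stride` and `padding` in [height, width] order.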
pub fn conv2d_mps<T>(
&mut self,
input: &Tensor<T>,
weights: &Tensor<T>,
bias: Option<&Tensor<T>>,
stride: [usize; 2],
padding: [usize; 2],
) -> Result<Tensor<T>>
where
        T: Clone + Default + Send + Sync + std::ops::Add<Output = T> + 'static,
{
self.execute_mps_conv2d(input, weights, bias, stride, padding)
}
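    /// Reduces a tensor with the GPU kernel matching `operation`. `axes` is accepted
    /// for API compatibility, but axis-wise reduction is not yet implemented; the
    /// kernels currently reduce over all elements.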
pub fn reduce_optimized<T>(
&mut self,
tensor: &Tensor<T>,
operation: ReductionOp,
axes: Option<&[usize]>,
) -> Result<Tensor<T>>
where
T: Clone + Default + Send + Sync + 'static,
{
match operation {
ReductionOp::Sum => self.execute_optimized_sum(tensor, axes),
ReductionOp::Mean => self.execute_optimized_mean(tensor, axes),
ReductionOp::Max => self.execute_optimized_max(tensor, axes),
ReductionOp::Min => self.execute_optimized_min(tensor, axes),
}
}
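    /// Applies an activation function in a single fused kernel pass.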
pub fn fused_activation<T>(
&mut self,
tensor: &Tensor<T>,
activation: ActivationType,
) -> Result<Tensor<T>>
where
T: Clone + Default + Send + Sync + 'static,
{
let kernel_name = match activation {
ActivationType::ReLU => "fused_relu",
ActivationType::GELU => "fused_gelu",
ActivationType::Swish => "fused_swish",
ActivationType::Tanh => "fused_tanh",
ActivationType::Sigmoid => "fused_sigmoid",
};
let tensor_data = tensor
.as_slice()
.ok_or_else(|| TensorError::InvalidOperation {
operation: "metal_kernel".to_string(),
reason: "Failed to access tensor data".to_string(),
context: None,
})?;
self.execute_kernel(kernel_name, &[tensor_data], tensor.shape().dims())
}
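    /// Elementwise binary operation between two same-shape tensors, using kernels
    /// with coalesced memory access.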
pub fn elementwise_coalesced<T>(
&mut self,
a: &Tensor<T>,
b: &Tensor<T>,
operation: ElementwiseOp,
) -> Result<Tensor<T>>
where
T: Clone + Default + Send + Sync + 'static,
{
        if a.shape() != b.shape() {
            return Err(TensorError::invalid_operation_simple(
                "Elementwise operations require tensors of the same shape".to_string(),
            ));
        }
        let kernel_name = match operation {
ElementwiseOp::Add => "coalesced_add",
ElementwiseOp::Mul => "coalesced_mul",
ElementwiseOp::Sub => "coalesced_sub",
ElementwiseOp::Div => "coalesced_div",
};
let total_elements = a.shape().dims().iter().product::<usize>();
let config = MetalKernelConfig {
threads_per_group: metal::MTLSize::new(64, 1, 1),
thread_groups: metal::MTLSize::new(((total_elements + 63) / 64) as u64, 1, 1),
};
        let a_data = a.as_slice().ok_or_else(|| TensorError::InvalidOperation {
            operation: "elementwise_coalesced".to_string(),
            reason: "Failed to access tensor data".to_string(),
            context: None,
        })?;
        let b_data = b.as_slice().ok_or_else(|| TensorError::InvalidOperation {
            operation: "elementwise_coalesced".to_string(),
            reason: "Failed to access tensor data".to_string(),
            context: None,
        })?;
self.execute_kernel_with_config(kernel_name, &[a_data, b_data], &config)
}
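    /// Layer normalization over the trailing feature dimensions of an N-D input,
    /// with learned scale (`gamma`) and shift (`beta`).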
pub fn layer_norm_optimized<T>(
&mut self,
input: &Tensor<T>,
gamma: &Tensor<T>,
beta: &Tensor<T>,
eps: f32,
) -> Result<Tensor<T>>
where
T: Clone + Default + Send + Sync + 'static,
{
let input_shape = input.shape();
if input_shape.len() < 2 {
return Err(TensorError::invalid_operation_simple(
"LayerNorm requires at least 2D input".to_string(),
));
}
let batch_size = input_shape[0];
let feature_size = input_shape.dims()[1..].iter().product::<usize>();
if gamma.numel() != feature_size || beta.numel() != feature_size {
return Err(TensorError::invalid_operation_simple(
"Gamma and beta must match feature dimensions".to_string(),
));
}
let output_shape = input_shape.to_vec();
self.execute_layer_norm_kernel(
input,
gamma,
beta,
eps,
batch_size,
feature_size,
output_shape,
)
}
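    /// Group normalization over 4D NCHW input; `channels` must be divisible by `groups`.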
pub fn group_norm_optimized<T>(
&mut self,
input: &Tensor<T>,
gamma: &Tensor<T>,
beta: &Tensor<T>,
groups: usize,
eps: f32,
) -> Result<Tensor<T>>
where
T: Clone + Default + Send + Sync + 'static,
{
let input_shape = input.shape();
if input_shape.len() != 4 {
return Err(TensorError::invalid_operation_simple(
"GroupNorm requires 4D input [batch, channels, height, width]".to_string(),
));
}
let (batch_size, channels, height, width) = (
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3],
);
let spatial_size = height * width;
if channels % groups != 0 {
return Err(TensorError::invalid_operation_simple(format!(
"Channels {} must be divisible by groups {}",
channels, groups
)));
}
let output_shape = input_shape.to_vec();
self.execute_group_norm_kernel(
input,
gamma,
beta,
groups,
eps,
batch_size,
channels,
spatial_size,
output_shape,
)
}
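    /// Memory-efficient scaled dot-product attention over 4D
    /// `[batch, heads, seq_len, head_dim]` tensors.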
pub fn flash_attention<T>(
&mut self,
query: &Tensor<T>,
key: &Tensor<T>,
value: &Tensor<T>,
scale: f32,
) -> Result<Tensor<T>>
where
T: Clone + Default + Send + Sync + 'static,
{
let q_shape = query.shape();
let k_shape = key.shape();
let v_shape = value.shape();
if q_shape.len() != 4 || k_shape.len() != 4 || v_shape.len() != 4 {
return Err(TensorError::invalid_operation_simple(
"Flash attention requires 4D tensors [batch, heads, seq_len, head_dim]".to_string(),
));
}
if q_shape != k_shape || q_shape != v_shape {
return Err(TensorError::invalid_operation_simple(
"Query, key, and value must have the same shape".to_string(),
));
}
let (batch_size, num_heads, seq_len, head_dim) =
(q_shape[0], q_shape[1], q_shape[2], q_shape[3]);
let output_shape = q_shape.to_vec();
self.execute_flash_attention_kernel(
query,
key,
value,
scale,
batch_size,
num_heads,
seq_len,
head_dim,
output_shape,
)
}
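    /// Vectorized addition using 4-wide SIMD lanes on Apple Silicon; requires the
    /// element count to be a multiple of 4.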
pub fn apple_silicon_simd_add<T>(&mut self, a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>>
where
T: Clone + Default + Send + Sync + 'static,
{
        if a.shape() != b.shape() {
            return Err(TensorError::invalid_operation_simple(
                "SIMD add requires tensors of the same shape".to_string(),
            ));
        }
        let numel = a.numel();
        if numel % 4 != 0 {
            return Err(TensorError::invalid_operation_simple(
                "SIMD operations require the element count to be a multiple of 4".to_string(),
            ));
}
let kernel_name = "apple_silicon_simd_add";
let config = MetalKernelConfig {
threads_per_group: metal::MTLSize::new(64, 1, 1),
thread_groups: metal::MTLSize::new(((numel / 4 + 63) / 64) as u64, 1, 1),
};
self.execute_kernel_with_config(kernel_name, &[a.data(), b.data()], &config)
}
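    /// Runs the bandwidth test kernel over `data_size` elements of `T` and returns
    /// the measured throughput in MB/s together with (currently placeholder) counters.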
pub fn measure_memory_bandwidth<T>(&mut self, data_size: usize) -> Result<(f64, Vec<u64>)>
where
T: Clone + Default + Send + Sync + 'static,
{
        let input_data = vec![T::default(); data_size];
        let output_data = vec![T::default(); data_size];
        // Placeholder counters; the bandwidth kernel does not yet report statistics.
        let stats = vec![0u64; 4];
        let kernel_name = "performance_bandwidth_test";
        let command_queue = self.command_queue().clone();
        let command_buffer = command_queue.new_command_buffer();
        let encoder = command_buffer.new_compute_command_encoder();
        let pipeline = self.get_or_create_pipeline(kernel_name)?;
        encoder.set_compute_pipeline_state(pipeline);
        // Bind source and destination buffers (assumed by the kernel at indices 0
        // and 1) so the dispatch actually streams memory.
        let input_buffer = self.create_metal_buffer(&input_data)?;
        let output_buffer = self.create_metal_buffer(&output_data)?;
        encoder.set_buffer(0, Some(&input_buffer), 0);
        encoder.set_buffer(1, Some(&output_buffer), 0);
        let config = MetalKernelConfig {
            threads_per_group: metal::MTLSize::new(256, 1, 1),
            thread_groups: metal::MTLSize::new(((data_size + 255) / 256) as u64, 1, 1),
        };
        // Time only the dispatch and completion of the command buffer, not the
        // buffer setup above.
        let start_time = std::time::Instant::now();
        encoder.dispatch_thread_groups(config.thread_groups, config.threads_per_group);
        encoder.end_encoding();
        command_buffer.commit();
        command_buffer.wait_until_completed();
        let elapsed = start_time.elapsed();
        // One read plus one write per element, reported in megabytes per second.
        let bandwidth_mbps = (data_size * std::mem::size_of::<T>() * 2) as f64
            / (elapsed.as_secs_f64() * 1_000_000.0);
Ok((bandwidth_mbps, stats))
}
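    /// GEMM via the `optimized_matmul` compute kernel: C[m, n] = A[m, k] * B[k, n].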
fn execute_mps_gemm<T>(&mut self, a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>>
where
T: Clone + Default + Send + Sync + 'static,
{
let a_shape = a.shape();
let b_shape = b.shape();
if a_shape.len() != 2 || b_shape.len() != 2 {
return Err(TensorError::invalid_operation_simple(
"Matrix multiplication requires 2D tensors".to_string(),
));
}
let (m, k) = (a_shape[0], a_shape[1]);
let (k2, n) = (b_shape[0], b_shape[1]);
if k != k2 {
return Err(TensorError::invalid_operation_simple(format!(
"Matrix dimension mismatch: {} vs {}",
k, k2
)));
}
let output_shape = vec![m, n];
let mut output_data = vec![T::default(); m * n];
let kernel_name = "optimized_matmul";
let command_queue = self.command_queue().clone();
let command_buffer = command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
let pipeline = self.get_or_create_pipeline(kernel_name)?;
encoder.set_compute_pipeline_state(pipeline);
let buffer_a = self.create_metal_buffer(a.data())?;
let buffer_b = self.create_metal_buffer(b.data())?;
let buffer_c = self.create_metal_buffer(&output_data)?;
encoder.set_buffer(0, Some(&buffer_a), 0);
encoder.set_buffer(1, Some(&buffer_b), 0);
encoder.set_buffer(2, Some(&buffer_c), 0);
        let u32_size = std::mem::size_of::<u32>() as u64;
        encoder.set_bytes(
            3,
            u32_size,
            &(m as u32) as *const u32 as *const std::ffi::c_void,
        );
        encoder.set_bytes(
            4,
            u32_size,
            &(n as u32) as *const u32 as *const std::ffi::c_void,
        );
        encoder.set_bytes(
            5,
            u32_size,
            &(k as u32) as *const u32 as *const std::ffi::c_void,
        );
let threads_per_group = metal::MTLSize::new(32, 32, 1);
let thread_groups = metal::MTLSize::new(((n + 31) / 32) as u64, ((m + 31) / 32) as u64, 1);
encoder.dispatch_thread_groups(thread_groups, threads_per_group);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
        // Read the result back from the shared-storage output buffer; `T` is
        // treated as plain bytes here, matching how the buffers were populated.
        unsafe {
            std::ptr::copy_nonoverlapping(
                buffer_c.contents() as *const T,
                output_data.as_mut_ptr(),
                m * n,
            );
        }
        Tensor::from_vec(output_data, &output_shape)
}
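    /// Direct convolution via the `optimized_conv2d` kernel, with the optional bias
    /// added on the host afterwards.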
fn execute_mps_conv2d<T>(
&mut self,
input: &Tensor<T>,
weights: &Tensor<T>,
bias: Option<&Tensor<T>>,
stride: [usize; 2],
padding: [usize; 2],
) -> Result<Tensor<T>>
where
        T: Clone + Default + Send + Sync + std::ops::Add<Output = T> + 'static,
{
let input_shape = input.shape();
let weight_shape = weights.shape();
if input_shape.len() != 4 || weight_shape.len() != 4 {
return Err(TensorError::invalid_operation_simple(
"Convolution requires 4D tensors (NCHW format)".to_string(),
));
}
if input_shape[1] != weight_shape[1] {
return Err(TensorError::invalid_operation_simple(format!(
"Input channels ({}) must match weight input channels ({})",
input_shape[1], weight_shape[1]
)));
}
let (batch_size, _in_channels, input_height, input_width) = (
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3],
);
let (out_channels, _in_channels, kernel_height, kernel_width) = (
weight_shape[0],
weight_shape[1],
weight_shape[2],
weight_shape[3],
);
        if input_height + 2 * padding[0] < kernel_height
            || input_width + 2 * padding[1] < kernel_width
        {
            return Err(TensorError::invalid_operation_simple(
                "Kernel size exceeds padded input size".to_string(),
            ));
        }
        let output_height = (input_height + 2 * padding[0] - kernel_height) / stride[0] + 1;
        let output_width = (input_width + 2 * padding[1] - kernel_width) / stride[1] + 1;
let output_shape = vec![batch_size, out_channels, output_height, output_width];
let output_size = output_shape.iter().product::<usize>();
let mut output_data = vec![T::default(); output_size];
let command_queue = self.command_queue().clone();
let command_buffer = command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
let pipeline = self.get_or_create_pipeline("optimized_conv2d")?;
encoder.set_compute_pipeline_state(pipeline);
let input_buffer = self.create_metal_buffer(input.data())?;
let weight_buffer = self.create_metal_buffer(weights.data())?;
let output_buffer = self.create_metal_buffer(&output_data)?;
encoder.set_buffer(0, Some(&input_buffer), 0);
encoder.set_buffer(1, Some(&weight_buffer), 0);
encoder.set_buffer(2, Some(&output_buffer), 0);
let input_shape_metal = [
input_shape[0] as u32,
input_shape[1] as u32,
input_shape[2] as u32,
input_shape[3] as u32,
];
let weight_shape_metal = [
weight_shape[0] as u32,
weight_shape[1] as u32,
weight_shape[2] as u32,
weight_shape[3] as u32,
];
let stride_metal = [stride[0] as u32, stride[1] as u32];
let padding_metal = [padding[0] as u32, padding[1] as u32];
encoder.set_bytes(3, 16, input_shape_metal.as_ptr() as *const std::ffi::c_void);
encoder.set_bytes(
4,
16,
weight_shape_metal.as_ptr() as *const std::ffi::c_void,
);
encoder.set_bytes(5, 8, stride_metal.as_ptr() as *const std::ffi::c_void);
encoder.set_bytes(6, 8, padding_metal.as_ptr() as *const std::ffi::c_void);
let threads_per_group = metal::MTLSize::new(8, 8, 1);
let thread_groups = metal::MTLSize::new(
((output_height * output_width + 63) / 64) as u64,
((out_channels + 7) / 8) as u64,
batch_size as u64,
);
encoder.dispatch_thread_groups(thread_groups, threads_per_group);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
        // Read the result back from the shared-storage output buffer before any
        // host-side post-processing.
        unsafe {
            std::ptr::copy_nonoverlapping(
                output_buffer.contents() as *const T,
                output_data.as_mut_ptr(),
                output_size,
            );
        }
        if let Some(bias_tensor) = bias {
            if bias_tensor.shape().len() != 1 || bias_tensor.shape()[0] != out_channels {
                return Err(TensorError::invalid_operation_simple(
                    "Bias must be 1D with size equal to output channels".to_string(),
                ));
            }
            // Add the per-channel bias on the host.
            for batch in 0..batch_size {
                for ch in 0..out_channels {
                    let bias_val = bias_tensor.data()[ch].clone();
                    for h in 0..output_height {
                        for w in 0..output_width {
                            let idx = batch * out_channels * output_height * output_width
                                + ch * output_height * output_width
                                + h * output_width
                                + w;
                            output_data[idx] = output_data[idx].clone() + bias_val.clone();
                        }
                    }
                }
            }
        }
Tensor::from_vec(output_data, &output_shape)
}
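    /// Full-tensor sum via the hierarchical reduction kernel.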
    fn execute_optimized_sum<T>(
        &mut self,
        tensor: &Tensor<T>,
        _axes: Option<&[usize]>,
    ) -> Result<Tensor<T>>
    where
        T: Clone + Default + Send + Sync + 'static,
    {
        // Axis-wise reduction is not yet wired up; the kernel reduces all elements.
        self.execute_kernel("hierarchical_sum", &[tensor.data()], tensor.shape().dims())
    }
    /// Full-tensor mean via the `optimized_mean` kernel.
    fn execute_optimized_mean<T>(
        &mut self,
        tensor: &Tensor<T>,
        _axes: Option<&[usize]>,
    ) -> Result<Tensor<T>>
    where
        T: Clone + Default + Send + Sync + 'static,
    {
        self.execute_kernel("optimized_mean", &[tensor.data()], tensor.shape().dims())
    }
    /// Full-tensor max via the `optimized_max` kernel.
    fn execute_optimized_max<T>(
        &mut self,
        tensor: &Tensor<T>,
        _axes: Option<&[usize]>,
    ) -> Result<Tensor<T>>
    where
        T: Clone + Default + Send + Sync + 'static,
    {
        self.execute_kernel("optimized_max", &[tensor.data()], tensor.shape().dims())
    }
    /// Full-tensor min via the `optimized_min` kernel.
    fn execute_optimized_min<T>(
        &mut self,
        tensor: &Tensor<T>,
        _axes: Option<&[usize]>,
    ) -> Result<Tensor<T>>
    where
        T: Clone + Default + Send + Sync + 'static,
    {
        self.execute_kernel("optimized_min", &[tensor.data()], tensor.shape().dims())
}
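    /// Layer-norm kernel dispatch. Currently a stub that returns zeros (see body).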
    fn execute_layer_norm_kernel<T>(
        &mut self,
        _input: &Tensor<T>,
        _gamma: &Tensor<T>,
        _beta: &Tensor<T>,
        _eps: f32,
        _batch_size: usize,
        _feature_size: usize,
        output_shape: Vec<usize>,
    ) -> Result<Tensor<T>>
    where
        T: Clone + Default + Send + Sync + 'static,
    {
        // Placeholder: returns a zero-initialized tensor of the correct shape until
        // the layer-norm Metal kernel is wired up.
        let output_data = vec![T::default(); output_shape.iter().product()];
        Tensor::from_vec(output_data, &output_shape)
    }
    /// Group-norm kernel dispatch. Currently a stub that returns zeros.
    fn execute_group_norm_kernel<T>(
        &mut self,
        _input: &Tensor<T>,
        _gamma: &Tensor<T>,
        _beta: &Tensor<T>,
        _groups: usize,
        _eps: f32,
        _batch_size: usize,
        _channels: usize,
        _spatial_size: usize,
        output_shape: Vec<usize>,
    ) -> Result<Tensor<T>>
    where
        T: Clone + Default + Send + Sync + 'static,
    {
        let output_data = vec![T::default(); output_shape.iter().product()];
        Tensor::from_vec(output_data, &output_shape)
    }
    /// Flash-attention kernel dispatch. Currently a stub that returns zeros.
    fn execute_flash_attention_kernel<T>(
        &mut self,
        _query: &Tensor<T>,
        _key: &Tensor<T>,
        _value: &Tensor<T>,
        _scale: f32,
        _batch_size: usize,
        _num_heads: usize,
        _seq_len: usize,
        _head_dim: usize,
        output_shape: Vec<usize>,
    ) -> Result<Tensor<T>>
    where
        T: Clone + Default + Send + Sync + 'static,
    {
        let output_data = vec![T::default(); output_shape.iter().product()];
        Tensor::from_vec(output_data, &output_shape)
}
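    /// Dispatches `kernel_name` with a default 1D launch configuration sized to
    /// the total element count of `shape`.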
fn execute_kernel<T>(
&mut self,
kernel_name: &str,
buffers: &[&[T]],
shape: &[usize],
) -> Result<Tensor<T>>
where
T: Clone + Default + Send + Sync + 'static,
{
let total_elements = shape.iter().product::<usize>();
let config = MetalKernelConfig {
threads_per_group: metal::MTLSize::new(256, 1, 1),
thread_groups: metal::MTLSize::new(((total_elements + 255) / 256) as u64, 1, 1),
};
self.execute_kernel_with_config(kernel_name, buffers, &config)
}
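    /// Dispatches `kernel_name` with an explicit launch configuration, binding the
    /// given slices as consecutive buffer arguments. The result is returned as a
    /// flat 1D tensor sized like the first buffer.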
fn execute_kernel_with_config<T>(
&mut self,
kernel_name: &str,
buffers: &[&[T]],
config: &MetalKernelConfig,
) -> Result<Tensor<T>>
where
T: Clone + Default + Send + Sync + 'static,
{
        if buffers.is_empty() {
            return Err(TensorError::invalid_operation_simple(
                "No input buffers provided".to_string(),
            ));
        }
        let mut metal_buffers = Vec::new();
        for buffer in buffers.iter() {
            let metal_buffer = self.create_metal_buffer(buffer)?;
            metal_buffers.push(metal_buffer);
        }
        let command_queue = self.command_queue().clone();
        let pipeline = self.get_or_create_pipeline(kernel_name)?;
        let command_buffer = command_queue.new_command_buffer();
        let encoder = command_buffer.new_compute_command_encoder();
        encoder.set_compute_pipeline_state(pipeline);
        for (index, metal_buffer) in metal_buffers.iter().enumerate() {
            encoder.set_buffer(index as u64, Some(metal_buffer), 0);
        }
        encoder.dispatch_thread_groups(config.thread_groups, config.threads_per_group);
        encoder.end_encoding();
        command_buffer.commit();
        command_buffer.wait_until_completed();
        let output_shape = if buffers[0].is_empty() {
            vec![1]
        } else {
            vec![buffers[0].len()]
        };
        let mut output_data = vec![T::default(); output_shape.iter().product()];
        // Read the result back from the first buffer; the compute kernels are
        // assumed to write their output in place over the first input.
        if !buffers[0].is_empty() {
            unsafe {
                std::ptr::copy_nonoverlapping(
                    metal_buffers[0].contents() as *const T,
                    output_data.as_mut_ptr(),
                    buffers[0].len(),
                );
            }
        }
        Tensor::from_vec(output_data, &output_shape)
}
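    /// Creates a shared-storage Metal buffer by byte-copying the slice; `T` is
    /// treated as plain-old-data throughout this module.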
fn create_metal_buffer<T>(&self, data: &[T]) -> Result<metal::Buffer>
where
T: Clone + Default + Send + Sync + 'static,
{
let size = std::mem::size_of_val(data);
let buffer = self.device().new_buffer_with_data(
data.as_ptr() as *const std::ffi::c_void,
size as u64,
metal::MTLResourceOptions::StorageModeShared,
);
Ok(buffer)
}
}