use metal::foreign_types::{ForeignType, ForeignTypeRef};
use metal::{CommandBuffer, Device, NSUInteger};
use objc2::msg_send;
use objc2::runtime::AnyObject;
use crate::metal::{
mps::{create_image_descriptor, MPSDataType},
MetalBuffer, MetalError, Result,
};
/// Batch-normalization layer backed by Apple's `MPSCNNBatchNormalization`
/// Metal Performance Shaders kernel.
pub struct MPSBatchNormalization {
    /// Raw Objective-C pointer to the MPSCNNBatchNormalization object;
    /// released in `Drop`.
    batch_norm: *mut AnyObject,
    /// Running mean, one entry per feature channel.
    mean: MetalBuffer,
    /// Running variance, one entry per feature channel.
    variance: MetalBuffer,
    /// Scale parameter (gamma), one entry per feature channel.
    gamma: MetalBuffer,
    /// Shift parameter (beta), one entry per feature channel.
    beta: MetalBuffer,
    /// Number of feature channels this layer normalizes over.
    num_features: usize,
    /// Small constant added to the variance for numerical stability.
    eps: f32,
    /// Momentum for running-statistics updates (stored but not used in the
    /// visible code -- TODO confirm intended use).
    momentum: f32,
}
impl MPSBatchNormalization {
pub fn new(
device: &Device,
num_features: usize,
eps: f32,
momentum: f32,
affine: bool,
) -> Result<Self> {
unsafe {
let class = objc2::class!(MPSCNNBatchNormalization);
let batch_norm: *mut AnyObject = msg_send![class, alloc];
let batch_norm: *mut AnyObject = msg_send![batch_norm,
initWithDevice: device.as_ptr() as *mut AnyObject as *mut AnyObject,
dataSource: std::ptr::null_mut::<AnyObject>()
];
let _: () = msg_send![batch_norm, setEpsilon: eps as f32];
let mean = MetalBuffer::zeros(
&torsh_core::Shape::from(vec![num_features]),
&torsh_core::DType::F32,
&crate::metal::device::MetalDevice::new()?,
)?;
let variance = MetalBuffer::zeros(
&torsh_core::Shape::from(vec![num_features]),
&torsh_core::DType::F32,
&crate::metal::device::MetalDevice::new()?,
)?;
let gamma = if affine {
MetalBuffer::ones(
&torsh_core::Shape::from(vec![num_features]),
&torsh_core::DType::F32,
&crate::metal::device::MetalDevice::new()?,
)?
} else {
MetalBuffer::zeros(
&torsh_core::Shape::from(vec![num_features]),
&torsh_core::DType::F32,
&crate::metal::device::MetalDevice::new()?,
)?
};
let beta = MetalBuffer::zeros(
&torsh_core::Shape::from(vec![num_features]),
&torsh_core::DType::F32,
&crate::metal::device::MetalDevice::new()?,
)?;
Ok(Self {
batch_norm,
mean,
variance,
gamma,
beta,
num_features,
eps,
momentum,
})
}
}
pub fn forward(
&mut self,
command_buffer: &CommandBuffer,
input: &MetalBuffer,
output: &MetalBuffer,
training: bool,
) -> Result<()> {
unsafe {
let input_shape = input.shape().dims();
if input_shape.len() != 4 {
return Err(MetalError::ShapeMismatch {
expected: vec![4],
got: vec![input_shape.len()],
});
}
let [_batch, channels, height, width] = [
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3],
];
if channels != self.num_features {
return Err(MetalError::ShapeMismatch {
expected: vec![self.num_features],
got: vec![channels],
});
}
let class = objc2::class!(MPSImage);
let input_desc = create_image_descriptor(width, height, channels, MPSDataType::Float32);
let input_image: *mut AnyObject = msg_send![class, alloc];
let input_image: *mut AnyObject = msg_send![input_image,
initWithDevice: input.buffer().device().as_ptr() as *mut AnyObject,
imageDescriptor: input_desc
];
let output_desc =
create_image_descriptor(width, height, channels, MPSDataType::Float32);
let output_image: *mut AnyObject = msg_send![class, alloc];
let output_image: *mut AnyObject = msg_send![output_image,
initWithDevice: output.buffer().device().as_ptr() as *mut AnyObject,
imageDescriptor: output_desc
];
if training {
let _: () = msg_send![self.batch_norm,
encodeToCommandBuffer: command_buffer.as_ptr() as *mut AnyObject,
sourceImage: input_image,
batchNormalizationState: std::ptr::null_mut::<AnyObject>(),
destinationImage: output_image
];
} else {
let _: () = msg_send![self.batch_norm,
encodeToCommandBuffer: command_buffer.as_ptr() as *mut AnyObject,
sourceImage: input_image,
destinationImage: output_image
];
}
Ok(())
}
}
}
impl Drop for MPSBatchNormalization {
    fn drop(&mut self) {
        // Balance the alloc/init performed in `new`. Sending `release` to a
        // nil pointer is a no-op in Objective-C, so this is safe even if
        // initialization returned nil.
        unsafe {
            let _: () = msg_send![self.batch_norm, release];
        }
    }
}
/// Multi-head attention built from four MPS linear projections
/// (query, key, value, and output).
pub struct MPSMultiHeadAttention {
    /// Number of attention heads.
    num_heads: usize,
    /// Dimension of each head (`embed_dim / num_heads`).
    head_dim: usize,
    /// Total embedding dimension.
    embed_dim: usize,
    /// Query projection.
    q_proj: MPSLinear,
    /// Key projection.
    k_proj: MPSLinear,
    /// Value projection.
    v_proj: MPSLinear,
    /// Output projection applied to the attention result.
    out_proj: MPSLinear,
    /// Dropout probability (stored but not applied in the visible code --
    /// TODO confirm whether dropout happens elsewhere).
    dropout_p: f32,
}
impl MPSMultiHeadAttention {
    /// Creates a multi-head attention module with `num_heads` heads over an
    /// `embed_dim`-dimensional embedding; all four projections are square
    /// (`embed_dim x embed_dim`) with bias.
    ///
    /// # Errors
    /// Returns `MetalError::InvalidArgument` if `embed_dim` is not divisible
    /// by `num_heads`, or propagates projection-construction failures.
    pub fn new(
        device: &Device,
        embed_dim: usize,
        num_heads: usize,
        dropout_p: f32,
    ) -> Result<Self> {
        if embed_dim % num_heads != 0 {
            return Err(MetalError::InvalidArgument(
                "embed_dim must be divisible by num_heads".to_string(),
            ));
        }
        let head_dim = embed_dim / num_heads;
        let q_proj = MPSLinear::new(device, embed_dim, embed_dim, true)?;
        let k_proj = MPSLinear::new(device, embed_dim, embed_dim, true)?;
        let v_proj = MPSLinear::new(device, embed_dim, embed_dim, true)?;
        let out_proj = MPSLinear::new(device, embed_dim, embed_dim, true)?;
        Ok(Self {
            num_heads,
            head_dim,
            embed_dim,
            q_proj,
            k_proj,
            v_proj,
            out_proj,
            dropout_p,
        })
    }

    /// Encodes the full attention pass: project Q/K/V, compute scaled
    /// dot-product attention, then apply the output projection into `output`.
    ///
    /// # Errors
    /// Propagates buffer-allocation and encoding failures, and shape errors
    /// from the attention step.
    pub fn forward(
        &self,
        command_buffer: &CommandBuffer,
        query: &MetalBuffer,
        key: &MetalBuffer,
        value: &MetalBuffer,
        output: &MetalBuffer,
        mask: Option<&MetalBuffer>,
    ) -> Result<()> {
        // Standard attention scaling: 1 / sqrt(head_dim).
        let scale = 1.0 / (self.head_dim as f32).sqrt();
        // One wrapper device reused for all three projection outputs instead
        // of constructing a fresh MetalDevice per allocation.
        let metal_device = crate::metal::device::MetalDevice::new()?;
        let q = MetalBuffer::zeros(query.shape(), &query.dtype(), &metal_device)?;
        let k = MetalBuffer::zeros(key.shape(), &key.dtype(), &metal_device)?;
        let v = MetalBuffer::zeros(value.shape(), &value.dtype(), &metal_device)?;
        self.q_proj.forward(command_buffer, query, &q)?;
        self.k_proj.forward(command_buffer, key, &k)?;
        self.v_proj.forward(command_buffer, value, &v)?;
        let scores = self.scaled_dot_product_attention(command_buffer, &q, &k, &v, scale, mask)?;
        self.out_proj.forward(command_buffer, &scores, output)?;
        Ok(())
    }

    /// Encodes softmax(scale * Q * K^T) * V and returns the result buffer.
    ///
    /// The scale factor is folded directly into the Q*K^T matmul's alpha.
    /// The original code encoded an unscaled matmul and then, when
    /// `scale != 1.0`, a second scaled one -- discarding the first pass
    /// entirely; the single scaled matmul produces the same scores.
    ///
    /// # Errors
    /// Returns `MetalError::ShapeMismatch` when any operand has fewer than
    /// two dimensions.
    fn scaled_dot_product_attention(
        &self,
        command_buffer: &CommandBuffer,
        q: &MetalBuffer,
        k: &MetalBuffer,
        v: &MetalBuffer,
        scale: f32,
        mask: Option<&MetalBuffer>,
    ) -> Result<MetalBuffer> {
        let q_shape = q.shape().dims();
        let k_shape = k.shape().dims();
        let v_shape = v.shape().dims();
        if q_shape.len() < 2 || k_shape.len() < 2 || v_shape.len() < 2 {
            return Err(crate::metal::error::MetalError::ShapeMismatch {
                expected: vec![2],
                got: vec![q_shape.len(), k_shape.len(), v_shape.len()],
            });
        }
        // scores = scale * Q * K^T (K transposed via the trailing `true`).
        let qk_matmul = crate::metal::mps::MPSMatMul::new(
            &q.buffer().device().to_owned(),
            q,
            k,
            None,
            scale,
            0.0,
            false,
            true,
        )?;
        qk_matmul.encode_matmul(command_buffer, q, k)?;
        let scores = qk_matmul.output().clone();
        // TODO(review): the attention mask is currently ignored (matching the
        // original behavior); apply it to `scores` before the softmax once a
        // masked-add kernel is available.
        let _ = mask;
        let softmax = crate::metal::mps::MPSActivation::new(
            &q.buffer().device().to_owned(),
            crate::metal::mps::ActivationType::Softmax,
        )?;
        let attention_weights = MetalBuffer::zeros(
            &scores.shape(),
            &torsh_core::DType::F32,
            &crate::metal::device::MetalDevice::new()?,
        )?;
        softmax.apply(command_buffer, &scores, &attention_weights)?;
        // output = attention_weights * V.
        let output_matmul = crate::metal::mps::MPSMatMul::new(
            &q.buffer().device().to_owned(),
            &attention_weights,
            v,
            None,
            1.0,
            0.0,
            false,
            false,
        )?;
        output_matmul.encode_matmul(command_buffer, &attention_weights, v)?;
        Ok(output_matmul.output().clone())
    }
}
/// Linear (fully-connected) layer executed via MPS matrix multiplication.
pub struct MPSLinear {
    /// Weight matrix with shape `[out_features, in_features]`.
    weight: MetalBuffer,
    /// Optional bias vector with shape `[out_features]`.
    bias: Option<MetalBuffer>,
    /// Input feature dimension.
    in_features: usize,
    /// Output feature dimension.
    out_features: usize,
}
impl MPSLinear {
    /// Creates a linear layer with a `[out_features, in_features]` weight
    /// matrix and an optional zero-initialized bias.
    ///
    /// # Errors
    /// Returns an error if the Metal device wrapper or parameter buffers
    /// cannot be created.
    pub fn new(
        _device: &Device,
        in_features: usize,
        out_features: usize,
        bias: bool,
    ) -> Result<Self> {
        // One wrapper device shared by the weight and bias allocations.
        let metal_device = crate::metal::device::MetalDevice::new()?;
        // TODO(review): weights are plain uniform random. The original code
        // computed the Xavier bound sqrt(6 / (fan_in + fan_out)) but never
        // applied it; that dead computation has been removed. Scale the
        // weights here if Xavier initialization is actually required.
        let weight = MetalBuffer::rand(
            &torsh_core::Shape::from(vec![out_features, in_features]),
            &torsh_core::DType::F32,
            &metal_device,
        )?;
        let bias_buffer = if bias {
            Some(MetalBuffer::zeros(
                &torsh_core::Shape::from(vec![out_features]),
                &torsh_core::DType::F32,
                &metal_device,
            )?)
        } else {
            None
        };
        Ok(Self {
            weight,
            bias: bias_buffer,
            in_features,
            out_features,
        })
    }

    /// Encodes `input * weight^T (+ bias)` on the command buffer.
    ///
    /// NOTE(review): `_output` is never written -- the result lives in the
    /// MPSMatMul's internal output buffer (cf. `forward_with_output`).
    /// Confirm callers do not rely on `_output` being filled.
    pub fn forward(
        &self,
        command_buffer: &CommandBuffer,
        input: &MetalBuffer,
        _output: &MetalBuffer,
    ) -> Result<()> {
        // beta = 1.0 only when a bias buffer participates; the trailing
        // `true` requests transposition of the weight matrix.
        let matmul = crate::metal::mps::MPSMatMul::new(
            &input.buffer().device().to_owned(),
            input,
            &self.weight,
            self.bias.as_ref(),
            1.0,
            if self.bias.is_some() { 1.0 } else { 0.0 },
            false,
            true,
        )?;
        matmul.encode_matmul(command_buffer, input, &self.weight)?;
        Ok(())
    }

    /// Runs `forward` into a freshly allocated buffer shaped
    /// `[..input_dims[..-1], out_features]` and returns it.
    ///
    /// # Errors
    /// Propagates allocation and encoding failures from `forward`.
    pub fn forward_with_output(
        &self,
        command_buffer: &CommandBuffer,
        input: &MetalBuffer,
    ) -> Result<MetalBuffer> {
        let input_shape = input.shape().dims();
        // Replace the last (feature) dimension with out_features.
        let mut output_shape = input_shape[..input_shape.len() - 1].to_vec();
        output_shape.push(self.out_features);
        let output = MetalBuffer::zeros(
            &torsh_core::Shape::from(output_shape),
            &torsh_core::DType::F32,
            &crate::metal::device::MetalDevice::new()?,
        )?;
        self.forward(command_buffer, input, &output)?;
        Ok(output)
    }
}
/// 2-D convolution wrapper that selects among several MPS-backed algorithms.
pub struct MPSOptimizedConv2d {
    /// Raw pointer to the underlying MPSCNNConvolution; released in `Drop`.
    conv: *mut AnyObject,
    /// Algorithm chosen at construction time.
    algorithm: ConvolutionAlgorithm,
    /// Convolution geometry parameters.
    params: Conv2dParams,
    /// Metal device the kernel was created on.
    device: Device,
}
/// Convolution algorithm variants. Only `Direct` is fully implemented; the
/// Winograd and FFT creation/encoding paths currently delegate to it.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConvolutionAlgorithm {
    /// Plain MPSCNNConvolution.
    Direct,
    /// Winograd-transform convolution (placeholder; delegates to Direct).
    Winograd,
    /// FFT-based convolution (placeholder; delegates to Direct).
    FFT,
    /// im2col + GEMM lowering (not constructed anywhere in the visible code).
    Im2ColGemm,
}
/// Geometry parameters for a 2-D convolution.
#[derive(Debug, Clone)]
pub struct Conv2dParams {
    /// Input feature channels.
    pub in_channels: usize,
    /// Output feature channels.
    pub out_channels: usize,
    /// Kernel height in pixels.
    pub kernel_height: usize,
    /// Kernel width in pixels.
    pub kernel_width: usize,
    /// Vertical stride.
    pub stride_height: usize,
    /// Horizontal stride.
    pub stride_width: usize,
    /// Vertical padding (NOTE: not applied to the MPS descriptor in the
    /// visible code).
    pub padding_height: usize,
    /// Horizontal padding (NOTE: not applied in the visible code).
    pub padding_width: usize,
    /// Vertical dilation (NOTE: not applied in the visible code).
    pub dilation_height: usize,
    /// Horizontal dilation (NOTE: not applied in the visible code).
    pub dilation_width: usize,
    /// Group count for grouped convolution (NOTE: not applied in the visible
    /// code).
    pub groups: usize,
}
impl MPSOptimizedConv2d {
    /// Builds a convolution kernel, optionally auto-selecting the algorithm
    /// from the kernel/stride geometry; otherwise uses the direct algorithm.
    ///
    /// # Errors
    /// Propagates failures from the underlying kernel-creation helpers.
    pub fn new(
        device: &Device,
        params: Conv2dParams,
        weights: &MetalBuffer,
        bias: Option<&MetalBuffer>,
        auto_select_algorithm: bool,
    ) -> Result<Self> {
        let algorithm = if auto_select_algorithm {
            Self::select_optimal_algorithm(&params)
        } else {
            ConvolutionAlgorithm::Direct
        };
        unsafe {
            // All arms currently build the same MPSCNNConvolution object:
            // the Winograd/FFT helpers below delegate to create_direct_conv.
            let conv = match algorithm {
                ConvolutionAlgorithm::Winograd => {
                    Self::create_winograd_conv(device, &params, weights, bias)?
                }
                ConvolutionAlgorithm::FFT => Self::create_fft_conv(device, &params, weights, bias)?,
                _ => Self::create_direct_conv(device, &params, weights, bias)?,
            };
            Ok(Self {
                conv,
                algorithm,
                params,
                device: device.clone(),
            })
        }
    }

    /// Heuristic choice: Winograd for 3x3 stride-1 kernels, FFT for large
    /// (>= 7x7) kernels, direct convolution otherwise.
    fn select_optimal_algorithm(params: &Conv2dParams) -> ConvolutionAlgorithm {
        if params.kernel_height == 3
            && params.kernel_width == 3
            && params.stride_height == 1
            && params.stride_width == 1
        {
            ConvolutionAlgorithm::Winograd
        } else if params.kernel_height >= 7 && params.kernel_width >= 7 {
            ConvolutionAlgorithm::FFT
        } else {
            ConvolutionAlgorithm::Direct
        }
    }

    /// Creates the underlying `MPSCNNConvolution` object.
    ///
    /// # Safety
    /// Performs raw Objective-C messaging; `device` must be a valid Metal
    /// device for the lifetime of the returned object.
    unsafe fn create_direct_conv(
        device: &Device,
        params: &Conv2dParams,
        _weights: &MetalBuffer,
        _bias: Option<&MetalBuffer>,
    ) -> Result<*mut AnyObject> {
        let class = objc2::class!(MPSCNNConvolution);
        let conv: *mut AnyObject = msg_send![class, alloc];
        let desc_class = objc2::class!(MPSCNNConvolutionDescriptor);
        let desc: *mut AnyObject = msg_send![desc_class, alloc];
        let desc: *mut AnyObject = msg_send![desc, init];
        let _: () = msg_send![desc, setKernelHeight: params.kernel_height as NSUInteger];
        let _: () = msg_send![desc, setKernelWidth: params.kernel_width as NSUInteger];
        let _: () = msg_send![desc, setInputFeatureChannels: params.in_channels as NSUInteger];
        let _: () = msg_send![desc, setOutputFeatureChannels: params.out_channels as NSUInteger];
        let _: () = msg_send![desc, setStrideInPixelsX: params.stride_width as NSUInteger];
        let _: () = msg_send![desc, setStrideInPixelsY: params.stride_height as NSUInteger];
        // NOTE(review): padding, dilation, and groups from `params` are never
        // applied to the descriptor, and `_weights`/`_bias` are ignored --
        // null kernelWeights/biasTerms are passed below, which the
        // initWithDevice:convolutionDescriptor:kernelWeights:biasTerms:flags:
        // initializer may reject at runtime. Confirm weight upload happens
        // elsewhere before relying on this path. The descriptor `desc` is
        // also alloc'd and never released.
        let conv: *mut AnyObject = msg_send![conv,
            initWithDevice: device.as_ptr() as *mut AnyObject,
            convolutionDescriptor: desc,
            kernelWeights: std::ptr::null::<f32>(),
            biasTerms: std::ptr::null::<f32>(),
            flags: 0 as NSUInteger
        ];
        Ok(conv)
    }

    /// Winograd variant -- currently a thin delegate to the direct path.
    ///
    /// # Safety
    /// Same contract as `create_direct_conv`.
    unsafe fn create_winograd_conv(
        device: &Device,
        params: &Conv2dParams,
        _weights: &MetalBuffer,
        _bias: Option<&MetalBuffer>,
    ) -> Result<*mut AnyObject> {
        Self::create_direct_conv(device, params, _weights, _bias)
    }

    /// FFT variant -- currently a thin delegate to the direct path.
    ///
    /// # Safety
    /// Same contract as `create_direct_conv`.
    unsafe fn create_fft_conv(
        device: &Device,
        params: &Conv2dParams,
        _weights: &MetalBuffer,
        _bias: Option<&MetalBuffer>,
    ) -> Result<*mut AnyObject> {
        Self::create_direct_conv(device, params, _weights, _bias)
    }

    /// Dispatches encoding based on the algorithm chosen at construction.
    pub fn encode(
        &self,
        command_buffer: &CommandBuffer,
        input: &MetalBuffer,
        output: &MetalBuffer,
    ) -> Result<()> {
        match self.algorithm {
            ConvolutionAlgorithm::Winograd => self.encode_winograd(command_buffer, input, output),
            ConvolutionAlgorithm::FFT => self.encode_fft(command_buffer, input, output),
            _ => self.encode_direct(command_buffer, input, output),
        }
    }

    /// Encodes the convolution via MPSImage wrappers built from the buffer
    /// shapes (indexed as [N, C, H, W] -- TODO confirm layout).
    ///
    /// NOTE(review): the MPSImage objects are alloc'd but never released and
    /// never populated from the buffers' contents; confirm image lifetime
    /// and data upload are handled elsewhere.
    fn encode_direct(
        &self,
        command_buffer: &CommandBuffer,
        input: &MetalBuffer,
        output: &MetalBuffer,
    ) -> Result<()> {
        unsafe {
            let input_shape = input.shape().dims();
            let output_shape = output.shape().dims();
            let class = objc2::class!(MPSImage);
            let input_desc = create_image_descriptor(
                input_shape[3],
                input_shape[2],
                input_shape[1],
                MPSDataType::Float32,
            );
            let input_image: *mut AnyObject = msg_send![class, alloc];
            let input_image: *mut AnyObject = msg_send![input_image,
                initWithDevice: self.device.as_ptr() as *mut AnyObject,
                imageDescriptor: input_desc
            ];
            let output_desc = create_image_descriptor(
                output_shape[3],
                output_shape[2],
                output_shape[1],
                MPSDataType::Float32,
            );
            let output_image: *mut AnyObject = msg_send![class, alloc];
            let output_image: *mut AnyObject = msg_send![output_image,
                initWithDevice: self.device.as_ptr() as *mut AnyObject,
                imageDescriptor: output_desc
            ];
            let _: () = msg_send![self.conv,
                encodeToCommandBuffer: command_buffer.as_ptr() as *mut AnyObject,
                sourceImage: input_image,
                destinationImage: output_image
            ];
            Ok(())
        }
    }

    /// Winograd encoding -- currently delegates to the direct path.
    fn encode_winograd(
        &self,
        command_buffer: &CommandBuffer,
        input: &MetalBuffer,
        output: &MetalBuffer,
    ) -> Result<()> {
        self.encode_direct(command_buffer, input, output)
    }

    /// FFT encoding -- currently delegates to the direct path.
    fn encode_fft(
        &self,
        command_buffer: &CommandBuffer,
        input: &MetalBuffer,
        output: &MetalBuffer,
    ) -> Result<()> {
        self.encode_direct(command_buffer, input, output)
    }
}
impl Drop for MPSOptimizedConv2d {
    fn drop(&mut self) {
        // Balance the alloc/init from the create_* helpers. Sending `release`
        // to nil is a no-op in Objective-C, so this is safe even if kernel
        // creation returned nil.
        unsafe {
            let _: () = msg_send![self.conv, release];
        }
    }
}
/// Namespace struct for fused operator sequences.
pub struct MPSFusedOps;

impl MPSFusedOps {
    /// Fused convolution -> batch-norm -> activation.
    ///
    /// NOTE(review): currently a stub -- every parameter is ignored and no
    /// GPU work is encoded; it always returns `Ok(())`.
    pub fn conv_bn_activation(
        _device: &Device,
        _command_buffer: &CommandBuffer,
        _input: &MetalBuffer,
        _conv_params: &Conv2dParams,
        _conv_weights: &MetalBuffer,
        _conv_bias: Option<&MetalBuffer>,
        _bn_weight: &MetalBuffer,
        _bn_bias: &MetalBuffer,
        _bn_mean: &MetalBuffer,
        _bn_var: &MetalBuffer,
        _activation: ActivationType,
        _output: &MetalBuffer,
    ) -> Result<()> {
        Ok(())
    }

    /// Fused linear -> bias -> activation.
    ///
    /// NOTE(review): currently a stub -- every parameter is ignored and no
    /// GPU work is encoded; it always returns `Ok(())`.
    pub fn linear_bias_activation(
        _device: &Device,
        _command_buffer: &CommandBuffer,
        _input: &MetalBuffer,
        _weight: &MetalBuffer,
        _bias: Option<&MetalBuffer>,
        _activation: ActivationType,
        _output: &MetalBuffer,
    ) -> Result<()> {
        Ok(())
    }
}
/// Activation functions usable with the MPS neuron kernels in this module.
#[derive(Debug, Clone, Copy)]
pub enum ActivationType {
    ReLU,
    ReLU6,
    Sigmoid,
    Tanh,
    Swish,
    GELU,
    LeakyReLU(f32),
    ELU(f32),
}

impl ActivationType {
    /// Maps this activation onto the numeric MPS neuron-type code used by
    /// the kernels. Variants without a dedicated code (`Swish`, `GELU`)
    /// fall back to the ReLU code (1), preserving the previous behavior.
    pub fn to_mps_neuron_type(&self) -> u32 {
        // Exhaustive match so a newly added variant forces a decision here
        // instead of silently falling through to ReLU.
        match self {
            Self::Tanh => 2,
            Self::Sigmoid => 3,
            Self::ReLU6 => 4,
            Self::LeakyReLU(_) => 5,
            Self::ELU(_) => 6,
            Self::ReLU | Self::Swish | Self::GELU => 1,
        }
    }

    /// Returns the `a` parameter for parameterized activations (the slope
    /// for `LeakyReLU`, alpha for `ELU`); `0.0` for every other variant.
    pub fn get_param_a(&self) -> f32 {
        if let Self::LeakyReLU(alpha) | Self::ELU(alpha) = self {
            *alpha
        } else {
            0.0
        }
    }
}