use metal::{CommandBuffer, Device};
use crate::metal::{
mps::{
neural_ops::Conv2dParams, ActivationType, ConvolutionAlgorithm, MPSBatchNormalization,
MPSLinear, MPSMultiHeadAttention, MPSOptimizedConv2d,
},
MetalBuffer, Result,
};
pub struct MPSResidualBlock {
conv1: MPSOptimizedConv2d,
bn1: MPSBatchNormalization,
conv2: MPSOptimizedConv2d,
bn2: MPSBatchNormalization,
downsample: Option<(MPSOptimizedConv2d, MPSBatchNormalization)>,
activation: ActivationType,
}
impl MPSResidualBlock {
pub fn new(
device: &Device,
in_channels: usize,
out_channels: usize,
stride: usize,
downsample: bool,
) -> Result<Self> {
let conv1_params = Conv2dParams {
in_channels,
out_channels,
kernel_height: 3,
kernel_width: 3,
stride_height: stride,
stride_width: stride,
padding_height: 1,
padding_width: 1,
dilation_height: 1,
dilation_width: 1,
groups: 1,
};
let conv2_params = Conv2dParams {
in_channels: out_channels,
out_channels,
kernel_height: 3,
kernel_width: 3,
stride_height: 1,
stride_width: 1,
padding_height: 1,
padding_width: 1,
dilation_height: 1,
dilation_width: 1,
groups: 1,
};
let weights1 = MetalBuffer::zeros(
&torsh_core::Shape::from(vec![out_channels, in_channels, 3, 3]),
&torsh_core::DType::F32,
&crate::metal::device::MetalDevice::new()?,
)?;
let weights2 = MetalBuffer::zeros(
&torsh_core::Shape::from(vec![out_channels, out_channels, 3, 3]),
&torsh_core::DType::F32,
&crate::metal::device::MetalDevice::new()?,
)?;
let conv1 = MPSOptimizedConv2d::new(device, conv1_params, &weights1, None, true)?;
let bn1 = MPSBatchNormalization::new(device, out_channels, 1e-5, 0.1, true)?;
let conv2 = MPSOptimizedConv2d::new(device, conv2_params, &weights2, None, true)?;
let bn2 = MPSBatchNormalization::new(device, out_channels, 1e-5, 0.1, true)?;
let downsample_layers = if downsample || in_channels != out_channels {
let downsample_params = Conv2dParams {
in_channels,
out_channels,
kernel_height: 1,
kernel_width: 1,
stride_height: stride,
stride_width: stride,
padding_height: 0,
padding_width: 0,
dilation_height: 1,
dilation_width: 1,
groups: 1,
};
let downsample_weights = MetalBuffer::zeros(
&torsh_core::Shape::from(vec![out_channels, in_channels, 1, 1]),
&torsh_core::DType::F32,
&crate::metal::device::MetalDevice::new()?,
)?;
let downsample_conv = MPSOptimizedConv2d::new(
device,
downsample_params,
&downsample_weights,
None,
true,
)?;
let downsample_bn = MPSBatchNormalization::new(device, out_channels, 1e-5, 0.1, true)?;
Some((downsample_conv, downsample_bn))
} else {
None
};
Ok(Self {
conv1,
bn1,
conv2,
bn2,
downsample: downsample_layers,
activation: ActivationType::ReLU,
})
}
pub fn forward(
&mut self,
command_buffer: &CommandBuffer,
input: &MetalBuffer,
output: &MetalBuffer,
training: bool,
) -> Result<()> {
let conv1_out = MetalBuffer::zeros(
input.shape(),
&torsh_core::DType::F32,
&crate::metal::device::MetalDevice::new()?,
)?;
let bn1_out = MetalBuffer::zeros(
input.shape(),
&torsh_core::DType::F32,
&crate::metal::device::MetalDevice::new()?,
)?;
let conv2_out = MetalBuffer::zeros(
input.shape(),
&torsh_core::DType::F32,
&crate::metal::device::MetalDevice::new()?,
)?;
let bn2_out = MetalBuffer::zeros(
input.shape(),
&torsh_core::DType::F32,
&crate::metal::device::MetalDevice::new()?,
)?;
self.conv1.encode(command_buffer, input, &conv1_out)?;
self.bn1
.forward(command_buffer, &conv1_out, &bn1_out, training)?;
self.conv2.encode(command_buffer, &bn1_out, &conv2_out)?;
self.bn2
.forward(command_buffer, &conv2_out, &bn2_out, training)?;
if let Some((ref downsample_conv, ref mut downsample_bn)) = self.downsample {
let downsampled = MetalBuffer::zeros(
output.shape(),
&torsh_core::DType::F32,
&crate::metal::device::MetalDevice::new()?,
)?;
let downsampled_bn = MetalBuffer::zeros(
output.shape(),
&torsh_core::DType::F32,
&crate::metal::device::MetalDevice::new()?,
)?;
downsample_conv.encode(command_buffer, input, &downsampled)?;
downsample_bn.forward(command_buffer, &downsampled, &downsampled_bn, training)?;
} else {
}
Ok(())
}
}
pub struct MPSTransformerEncoderLayer {
self_attention: MPSMultiHeadAttention,
feed_forward: MPSFeedForward,
norm1: MPSLayerNorm,
norm2: MPSLayerNorm,
dropout_p: f32,
}
impl MPSTransformerEncoderLayer {
pub fn new(
device: &Device,
embed_dim: usize,
num_heads: usize,
ff_dim: usize,
dropout_p: f32,
) -> Result<Self> {
let self_attention = MPSMultiHeadAttention::new(device, embed_dim, num_heads, dropout_p)?;
let feed_forward = MPSFeedForward::new(device, embed_dim, ff_dim, ActivationType::ReLU)?;
let norm1 = MPSLayerNorm::new(device, embed_dim, 1e-5)?;
let norm2 = MPSLayerNorm::new(device, embed_dim, 1e-5)?;
Ok(Self {
self_attention,
feed_forward,
norm1,
norm2,
dropout_p,
})
}
pub fn forward(
&self,
command_buffer: &CommandBuffer,
input: &MetalBuffer,
output: &MetalBuffer,
mask: Option<&MetalBuffer>,
) -> Result<()> {
let attn_out = MetalBuffer::zeros(
input.shape(),
&torsh_core::DType::F32,
&crate::metal::device::MetalDevice::new()?,
)?;
self.self_attention
.forward(command_buffer, input, input, input, &attn_out, mask)?;
let norm1_out = MetalBuffer::zeros(
input.shape(),
&torsh_core::DType::F32,
&crate::metal::device::MetalDevice::new()?,
)?;
self.norm1.forward(command_buffer, &attn_out, &norm1_out)?;
let ff_out = MetalBuffer::zeros(
input.shape(),
&torsh_core::DType::F32,
&crate::metal::device::MetalDevice::new()?,
)?;
self.feed_forward
.forward(command_buffer, &norm1_out, &ff_out)?;
self.norm2.forward(command_buffer, &ff_out, output)?;
Ok(())
}
}
pub struct MPSFeedForward {
linear1: MPSLinear,
linear2: MPSLinear,
activation: ActivationType,
hidden_dim: usize,
}
impl MPSFeedForward {
pub fn new(
device: &Device,
input_dim: usize,
hidden_dim: usize,
activation: ActivationType,
) -> Result<Self> {
let linear1 = MPSLinear::new(device, input_dim, hidden_dim, true)?;
let linear2 = MPSLinear::new(device, hidden_dim, input_dim, true)?;
Ok(Self {
linear1,
linear2,
activation,
hidden_dim,
})
}
pub fn forward(
&self,
command_buffer: &CommandBuffer,
input: &MetalBuffer,
output: &MetalBuffer,
) -> Result<()> {
let hidden = MetalBuffer::zeros(
&torsh_core::Shape::from(vec![input.shape().dims()[0], self.hidden_dim]),
&input.dtype(),
&crate::metal::device::MetalDevice::new()?,
)?;
self.linear1.forward(command_buffer, input, &hidden)?;
self.linear2.forward(command_buffer, &hidden, output)?;
Ok(())
}
}
pub struct MPSLayerNorm {
weight: MetalBuffer,
bias: MetalBuffer,
normalized_shape: Vec<usize>,
eps: f32,
}
impl MPSLayerNorm {
pub fn new(_device: &Device, normalized_shape: usize, eps: f32) -> Result<Self> {
let weight = MetalBuffer::ones(
&torsh_core::Shape::from(vec![normalized_shape]),
&torsh_core::DType::F32,
&crate::metal::device::MetalDevice::new()?,
)?;
let bias = MetalBuffer::zeros(
&torsh_core::Shape::from(vec![normalized_shape]),
&torsh_core::DType::F32,
&crate::metal::device::MetalDevice::new()?,
)?;
Ok(Self {
weight,
bias,
normalized_shape: vec![normalized_shape],
eps,
})
}
pub fn forward(
&self,
_command_buffer: &CommandBuffer,
_input: &MetalBuffer,
_output: &MetalBuffer,
) -> Result<()> {
Ok(())
}
}
pub struct MPSConvBlock {
layers: Vec<ConvLayer>,
}
enum ConvLayer {
Conv2d(MPSOptimizedConv2d),
BatchNorm(MPSBatchNormalization),
Activation(ActivationType),
Dropout(f32),
}
impl MPSConvBlock {
pub fn new(device: &Device) -> MPSConvBlockBuilder {
MPSConvBlockBuilder::new(device)
}
pub fn forward(
&mut self,
command_buffer: &CommandBuffer,
input: &MetalBuffer,
output: &MetalBuffer,
training: bool,
) -> Result<()> {
let mut current_input = input.clone();
let layer_count = self.layers.len();
for (i, layer) in self.layers.iter_mut().enumerate() {
let is_last = i == layer_count - 1;
let current_output = if is_last {
output.clone()
} else {
MetalBuffer::zeros(
current_input.shape(),
&torsh_core::DType::F32,
&crate::metal::device::MetalDevice::new()?,
)?
};
match layer {
ConvLayer::Conv2d(conv) => {
conv.encode(command_buffer, ¤t_input, ¤t_output)?;
}
ConvLayer::BatchNorm(bn) => {
bn.forward(command_buffer, ¤t_input, ¤t_output, training)?;
}
ConvLayer::Activation(_activation) => {
}
ConvLayer::Dropout(_p) => {
}
}
current_input = current_output;
}
Ok(())
}
}
pub struct MPSConvBlockBuilder {
device: Device,
layers: Vec<ConvLayer>,
}
impl MPSConvBlockBuilder {
fn new(device: &Device) -> Self {
Self {
device: device.clone(),
layers: Vec::new(),
}
}
pub fn conv2d(
mut self,
in_channels: usize,
out_channels: usize,
kernel_size: usize,
stride: usize,
padding: usize,
) -> Result<Self> {
let params = Conv2dParams {
in_channels,
out_channels,
kernel_height: kernel_size,
kernel_width: kernel_size,
stride_height: stride,
stride_width: stride,
padding_height: padding,
padding_width: padding,
dilation_height: 1,
dilation_width: 1,
groups: 1,
};
let weights = MetalBuffer::zeros(
&torsh_core::Shape::from(vec![out_channels, in_channels, kernel_size, kernel_size]),
&torsh_core::DType::F32,
&crate::metal::device::MetalDevice::new()?,
)?;
let conv = MPSOptimizedConv2d::new(&self.device, params, &weights, None, true)?;
self.layers.push(ConvLayer::Conv2d(conv));
Ok(self)
}
pub fn batch_norm(mut self, num_features: usize) -> Result<Self> {
let bn = MPSBatchNormalization::new(&self.device, num_features, 1e-5, 0.1, true)?;
self.layers.push(ConvLayer::BatchNorm(bn));
Ok(self)
}
pub fn activation(mut self, activation: ActivationType) -> Self {
self.layers.push(ConvLayer::Activation(activation));
self
}
pub fn dropout(mut self, p: f32) -> Self {
self.layers.push(ConvLayer::Dropout(p));
self
}
pub fn build(self) -> MPSConvBlock {
MPSConvBlock {
layers: self.layers,
}
}
}
pub struct MPSOptimizations;
impl MPSOptimizations {
pub fn analyze_workload(
input_shape: &[usize],
conv_params: &Conv2dParams,
batch_size: usize,
) -> ConvolutionAlgorithm {
let [_batch, _channels, height, width] = [
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3],
];
let _input_size = height * width;
let kernel_size = conv_params.kernel_height * conv_params.kernel_width;
let output_channels = conv_params.out_channels;
if conv_params.kernel_height == 3
&& conv_params.kernel_width == 3
&& conv_params.stride_height == 1
&& conv_params.stride_width == 1
{
return ConvolutionAlgorithm::Winograd;
}
if kernel_size >= 49 {
return ConvolutionAlgorithm::FFT;
}
if batch_size >= 32 && output_channels >= 256 {
return ConvolutionAlgorithm::Im2ColGemm;
}
ConvolutionAlgorithm::Direct
}
pub fn suggest_memory_layout(shapes: &[Vec<usize>]) -> Vec<MemoryLayout> {
shapes
.iter()
.map(|shape| {
match shape.len() {
4 => MemoryLayout::NCHW, 2 => MemoryLayout::RowMajor, _ => MemoryLayout::Contiguous,
}
})
.collect()
}
pub fn calculate_flops(_operation: &dyn MPSOperation, _input_shapes: &[Vec<usize>]) -> u64 {
0
}
}
#[derive(Debug, Clone, Copy)]
pub enum MemoryLayout {
NCHW,
NHWC,
RowMajor,
ColumnMajor,
Contiguous,
}
pub trait MPSOperation {
fn encode(&self, command_buffer: &CommandBuffer) -> Result<()>;
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_conv_block_builder() {
assert!(true);
}
#[test]
fn test_optimization_heuristics() {
let input_shape = vec![32, 256, 56, 56]; let conv_params = Conv2dParams {
in_channels: 256,
out_channels: 512,
kernel_height: 3,
kernel_width: 3,
stride_height: 1,
stride_width: 1,
padding_height: 1,
padding_width: 1,
dilation_height: 1,
dilation_width: 1,
groups: 1,
};
let algorithm = MPSOptimizations::analyze_workload(&input_shape, &conv_params, 32);
assert_eq!(algorithm, ConvolutionAlgorithm::Winograd);
}
}