use metal::CommandBuffer;
use metal::NSUInteger;
use objc2::msg_send;
use objc2::runtime::AnyObject;
use crate::metal::error::Result;
mod activation;
mod conv;
mod matmul;
mod mixed_precision;
mod networks;
mod neural_ops;
mod normalization;
mod pooling;
pub use activation::{ActivationType, MPSActivation};
pub use conv::{Conv2dParams, MPSConv2d};
pub use matmul::MPSMatMul;
pub use mixed_precision::{
AMPConfig, MPSAutocast, MPSGradScaler, MPSMixedPrecision, MixedPrecisionStats, OptLevel,
};
pub use networks::{
MPSConvBlock, MPSConvBlockBuilder, MPSFeedForward, MPSLayerNorm, MPSOptimizations,
MPSResidualBlock, MPSTransformerEncoderLayer, MemoryLayout,
};
pub use neural_ops::{
Conv2dParams as OptimizedConv2dParams, ConvolutionAlgorithm, MPSBatchNormalization,
MPSFusedOps, MPSLinear, MPSMultiHeadAttention, MPSOptimizedConv2d,
};
pub use normalization::MPSBatchNorm;
pub use pooling::{MPSAvgPool2d, MPSMaxPool2d};
/// An operation that can be encoded into a Metal command buffer.
///
/// Only the signature is declared here; what gets encoded is decided by each
/// implementor.
pub trait MPSOperation {
/// Encodes this operation into `command_buffer`.
///
/// Returns `Ok(())` on success, or an error through the crate's Metal
/// `Result` alias on failure.
fn encode(&self, command_buffer: &CommandBuffer) -> Result<()>;
}
/// Element types that the MPS wrappers in this module can describe.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MPSDataType {
    /// 16-bit floating point (2 bytes per element).
    Float16,
    /// 32-bit floating point (4 bytes per element).
    Float32,
    /// Signed 8-bit integer.
    Int8,
    /// Unsigned 8-bit integer.
    UInt8,
}

impl MPSDataType {
    /// Returns the raw `u32` value handed to MPS for this data type.
    ///
    /// NOTE(review): Apple's `MPSDataType` enum appears to build its raw
    /// values from a type bit OR'd with the bit width (for example
    /// `MPSDataTypeFloatBit | 32`), which does not obviously match the
    /// literals below. Verify them against `MPSCoreTypes.h` before relying on
    /// them; the unit test in this file currently pins the existing values.
    pub fn to_mps_constant(&self) -> u32 {
        match *self {
            Self::UInt8 => 0x1008,
            Self::Int8 => 0x1020,
            Self::Float16 => 0x10DE,
            Self::Float32 => 0x10E0,
        }
    }
}
/// Groups a tensor's shape, element type and layout.
///
/// NOTE(review): no constructor or accessor for this type is visible in this
/// file, which is presumably why `dead_code` is allowed below — confirm
/// whether the type is still needed.
#[allow(dead_code)]
pub struct MPSTensorDescriptor {
/// Size of each dimension (presumably ordered according to `layout` — confirm).
shape: Vec<usize>,
/// Element type of the tensor.
dtype: MPSDataType,
/// Dimension ordering of the tensor.
layout: MPSLayout,
}
/// Dimension ordering tag for tensor data.
///
/// NOTE(review): nothing in this file interprets the variants. The names
/// follow the common convention where N = batch, C = channels, H = height and
/// W = width — confirm against the code that consumes this enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MPSLayout {
/// Presumably channels ahead of the spatial dimensions — confirm.
NCHW,
/// Presumably channels after the spatial dimensions — confirm.
NHWC,
}
pub(crate) unsafe fn create_matrix_descriptor(
rows: usize,
columns: usize,
dtype: MPSDataType,
) -> *mut AnyObject {
let class = objc2::class!(MPSMatrixDescriptor);
let descriptor: *mut AnyObject = msg_send![class, alloc];
let descriptor: *mut AnyObject = msg_send![descriptor, init];
let _: () = msg_send![descriptor, setRows: rows as NSUInteger];
let _: () = msg_send![descriptor, setColumns: columns as NSUInteger];
let _: () = msg_send![descriptor, setDataType: dtype.to_mps_constant()];
let element_size = match dtype {
MPSDataType::Float16 => 2,
MPSDataType::Float32 => 4,
MPSDataType::Int8 | MPSDataType::UInt8 => 1,
};
let row_bytes = columns * element_size;
let _: () = msg_send![descriptor, setRowBytes: row_bytes as NSUInteger];
descriptor
}
/// Allocates an `MPSImageDescriptor` for an image of the given size and
/// feature-channel count.
///
/// The channel-format argument is chosen from `channels` alone: one literal
/// for single-channel images and another for everything else. `_dtype` is
/// accepted but ignored.
///
/// NOTE(review): the selector `initWithChannelFormat:width:height:featureChannels:`
/// and the two format literals should be checked against Apple's
/// `MPSImageDescriptor` / `MPSImageFeatureChannelFormat` documentation — the
/// documented factory appears to be the class method
/// `imageDescriptorWithChannelFormat:width:height:featureChannels:`. Confirm
/// on a real device before relying on this.
///
/// # Safety
///
/// The body sends raw Objective-C messages. The returned pointer is whatever
/// the initializer hands back; it is neither null-checked nor released here,
/// so ownership presumably passes to the caller — confirm.
pub(crate) unsafe fn create_image_descriptor(
    width: usize,
    height: usize,
    channels: usize,
    _dtype: MPSDataType,
) -> *mut AnyObject {
    let descriptor_class = objc2::class!(MPSImageDescriptor);
    let allocated: *mut AnyObject = msg_send![descriptor_class, alloc];

    // Single-channel images get their own format value; every other channel
    // count shares the second one.
    let channel_format = if channels == 1 { 0x10DE } else { 0x7310 };

    let image_desc: *mut AnyObject = msg_send![allocated,
        initWithChannelFormat: channel_format,
        width: width as NSUInteger,
        height: height as NSUInteger,
        featureChannels: channels as NSUInteger
    ];
    image_desc
}
pub(crate) unsafe fn create_conv_descriptor(
kernel_height: usize,
kernel_width: usize,
input_channels: usize,
output_channels: usize,
) -> *mut AnyObject {
let class = objc2::class!(MPSCNNConvolutionDescriptor);
let descriptor: *mut AnyObject = msg_send![class, alloc];
let descriptor: *mut AnyObject = msg_send![descriptor, init];
let _: () = msg_send![descriptor, setKernelHeight: kernel_height as NSUInteger];
let _: () = msg_send![descriptor, setKernelWidth: kernel_width as NSUInteger];
let _: () = msg_send![descriptor, setInputFeatureChannels: input_channels as NSUInteger];
let _: () = msg_send![descriptor, setOutputFeatureChannels: output_channels as NSUInteger];
descriptor
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the raw constants returned for the two floating-point data types.
    #[test]
    fn test_mps_data_type() {
        let expected = [
            (MPSDataType::Float32, 0x10E0),
            (MPSDataType::Float16, 0x10DE),
        ];
        for (dtype, raw) in expected {
            assert_eq!(dtype.to_mps_constant(), raw);
        }
    }
}