use std::collections::HashMap;
use metal::{
CommandQueue, CompileOptions, ComputePipelineState, Device, Library,
MTLResourceOptions, NSUInteger,
};
/// Errors produced while setting up the Metal backend or building compute
/// pipeline states for its kernels.
#[derive(Debug, thiserror::Error)]
pub enum MetalError {
/// `Device::system_default()` returned `None` in `MetalBackend::new`.
#[error("no Metal device available (system has no GPU?)")]
NoDevice,
/// Compiling `SHADER_SOURCE` into a library failed in `MetalBackend::new`;
/// carries the error string returned by the compiler.
#[error("compiling shaders.metal: {0}")]
LibraryCompile(String),
/// `MetalBackend::pipeline` could not fetch a function called `name` from the
/// compiled library. The library's own error text is discarded; only the
/// requested name is kept.
#[error("kernel '{name}' not found in compiled library")]
FunctionNotFound { name: String },
/// The device failed to build a compute pipeline state for kernel `name`;
/// `err` is the error string returned by the device.
#[error("pipeline-state creation failed for '{name}': {err}")]
PipelineCreate { name: String, err: String },
}
/// Metal shading-language source for all kernels, embedded into the binary by
/// `include_str!` (path resolved relative to this source file) and compiled at
/// runtime by `MetalBackend::new`.
const SHADER_SOURCE: &str =
include_str!("../../shaders/shaders.metal");
/// Names of the kernel functions `MetalBackend::warm_all` builds pipelines for.
///
/// Each name is looked up in the compiled shader library; a name with no
/// matching function makes `warm_all` fail with `MetalError::FunctionNotFound`.
/// Keep entries unique, because the pipeline cache is keyed by name and the
/// compile-all test compares the cache size with this list's length. The
/// entries are currently in alphabetical order.
pub const ALL_KERNELS: &[&str] = &[
"attn_scores_batched",
"attn_softmax_batched",
"attn_values_batched",
"compute_decay_beta",
"conv1d_step",
"dequant_matvec_2bit",
"dequant_matvec_4bit",
"dequant_matvec_4bit_batched",
"dequant_matvec_4bit_fast",
"dequant_matvec_4bit_v3",
"dequant_matvec_4bit_v4",
"dequant_matvec_4bit_v5",
"dequant_matvec_8bit_v3",
"fused_gate_up_swiglu",
"gated_delta_net_step",
"gated_rms_norm",
"moe_combine_residual",
"residual_add",
"rms_norm_apply",
"rms_norm_apply_bf16",
"rms_norm_qk",
"rms_norm_sum_sq",
"sigmoid_gate",
"swiglu_fused",
"swiglu_fused_batched",
"swiglu_fused_vec4",
"weighted_sum",
];
/// Owns a Metal device, a command queue on it, the library compiled from
/// `SHADER_SOURCE`, and a cache of compute pipeline states built from that
/// library.
pub struct MetalBackend {
// Device chosen by `Device::system_default()`.
device: Device,
// Command queue created on `device`.
queue: CommandQueue,
// Library compiled from `SHADER_SOURCE`; kernel functions are fetched from here.
library: Library,
// Pipeline states keyed by kernel name, filled lazily by `pipeline()`.
pipelines: HashMap<&'static str, ComputePipelineState>,
}
impl MetalBackend {
    /// Opens the system default Metal device, creates a command queue on it,
    /// and compiles [`SHADER_SOURCE`] into a library.
    ///
    /// No pipeline states are built here. They are created on demand by
    /// [`MetalBackend::pipeline`] or all at once by [`MetalBackend::warm_all`].
    ///
    /// # Errors
    ///
    /// - [`MetalError::NoDevice`] if `Device::system_default()` returns `None`.
    /// - [`MetalError::LibraryCompile`] with the compiler's message if the
    ///   shader source fails to compile.
    pub fn new() -> Result<Self, MetalError> {
        let device = Device::system_default().ok_or(MetalError::NoDevice)?;
        let queue = device.new_command_queue();
        let options = CompileOptions::new();
        let library = device
            .new_library_with_source(SHADER_SOURCE, &options)
            .map_err(MetalError::LibraryCompile)?;
        Ok(Self {
            device,
            queue,
            library,
            pipelines: HashMap::new(),
        })
    }

    /// The Metal device this backend was created on.
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// The command queue that [`MetalBackend::new`] created on the device.
    pub fn queue(&self) -> &CommandQueue {
        &self.queue
    }

    /// Returns the compute pipeline state for kernel `name`, building and
    /// caching it on the first request.
    ///
    /// Later calls with the same `name` return the cached state without
    /// touching the library or device again.
    ///
    /// # Errors
    ///
    /// - [`MetalError::FunctionNotFound`] if the compiled library has no
    ///   function called `name`.
    /// - [`MetalError::PipelineCreate`] if the device fails to build a pipeline
    ///   state from that function.
    ///
    /// Nothing is cached on failure, so a later call retries from scratch.
    pub fn pipeline(
        &mut self,
        name: &'static str,
    ) -> Result<&ComputePipelineState, MetalError> {
        use std::collections::hash_map::Entry;

        // A single hash lookup decides hit vs. miss. The vacant slot is kept so
        // the insert below does not hash `name` again, and the returned
        // reference comes from the entry rather than a panicking `map[name]`.
        let slot = match self.pipelines.entry(name) {
            Entry::Occupied(cached) => return Ok(cached.into_mut()),
            Entry::Vacant(slot) => slot,
        };
        // The library's own error text is dropped here because
        // `FunctionNotFound` only carries the requested name.
        let function = self
            .library
            .get_function(name, None)
            .map_err(|_| MetalError::FunctionNotFound {
                name: name.to_string(),
            })?;
        let state = self
            .device
            .new_compute_pipeline_state_with_function(&function)
            .map_err(|err| MetalError::PipelineCreate {
                name: name.to_string(),
                err,
            })?;
        Ok(slot.insert(state))
    }

    /// Builds (or confirms cached) pipeline states for every name in
    /// [`ALL_KERNELS`].
    ///
    /// # Errors
    ///
    /// Returns the first error from [`MetalBackend::pipeline`]. Pipelines built
    /// before that failure stay cached.
    pub fn warm_all(&mut self) -> Result<(), MetalError> {
        for &name in ALL_KERNELS {
            self.pipeline(name)?;
        }
        Ok(())
    }

    /// Number of pipeline states currently cached.
    pub fn pipeline_count(&self) -> usize {
        self.pipelines.len()
    }
}
impl std::fmt::Debug for MetalBackend {
    /// Summarises the backend as its device name plus the number of cached
    /// pipeline states; the queue and library handles are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let device_name = self.device.name();
        let cached = self.pipelines.len();
        let mut summary = f.debug_struct("MetalBackend");
        summary.field("device", &device_name);
        summary.field("pipelines_cached", &cached);
        summary.finish()
    }
}
/// A `metal::Buffer` viewed as a sequence of `len` elements of type `T`.
pub struct MtlBuffer<T> {
// Underlying Metal buffer; both constructors allocate it with
// `StorageModeShared` so its contents are read/written from the CPU.
inner: metal::Buffer,
// Element count (not byte count).
len: usize,
// Records the element type without storing any `T`.
_phantom: std::marker::PhantomData<T>,
}
impl<T: Copy> MtlBuffer<T> {
    /// Allocates a `StorageModeShared` buffer with room for `len` elements of `T`.
    ///
    /// NOTE(review): the initial contents are whatever the new Metal buffer
    /// holds, and `T: Copy` does not guarantee every bit pattern is a valid `T`.
    /// Reading before writing (e.g. via [`MtlBuffer::to_vec`]) is therefore only
    /// sound for plain-data element types. Confirm this for any new `T`.
    ///
    /// # Panics
    ///
    /// Panics if `len * size_of::<T>()` overflows `usize`. Without this check
    /// the byte count would wrap in release builds and allocate a buffer smaller
    /// than `len` claims. `to_vec` / `as_mut_slice` would then access memory
    /// past its end.
    pub fn with_len(device: &Device, len: usize) -> Self {
        let bytes = len
            .checked_mul(std::mem::size_of::<T>())
            .expect("MtlBuffer::with_len: len * size_of::<T>() overflows usize");
        // usize -> NSUInteger does not truncate on the 64-bit targets Metal runs on.
        let inner = device.new_buffer(bytes as NSUInteger, MTLResourceOptions::StorageModeShared);
        Self {
            inner,
            len,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Allocates a `StorageModeShared` buffer and copies `data` into it. The
    /// resulting length is `data.len()`.
    pub fn with_data(device: &Device, data: &[T]) -> Self {
        // `size_of_val` of an existing slice cannot overflow.
        let bytes = std::mem::size_of_val(data) as NSUInteger;
        let inner = device.new_buffer_with_data(
            data.as_ptr().cast(),
            bytes,
            MTLResourceOptions::StorageModeShared,
        );
        Self {
            inner,
            len: data.len(),
            _phantom: std::marker::PhantomData,
        }
    }

    /// The underlying Metal buffer, e.g. for binding to a compute encoder.
    pub fn raw(&self) -> &metal::BufferRef {
        &self.inner
    }

    /// Number of `T` elements the buffer holds.
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when the buffer holds no elements.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Copies the buffer's `len` elements into a new `Vec`.
    ///
    /// NOTE(review): the caller must ensure no GPU command that writes this
    /// buffer is still in flight; this type does not track that.
    pub fn to_vec(&self) -> Vec<T> {
        if self.len == 0 {
            // `slice::from_raw_parts` requires a non-null pointer even for an
            // empty slice, and nothing here guarantees `contents()` is non-null
            // for a zero-length buffer, so do not touch the pointer at all.
            return Vec::new();
        }
        let ptr = self.inner.contents() as *const T;
        // SAFETY: both constructors size the buffer to `len * size_of::<T>()`
        // bytes (overflow is rejected in `with_len`), so `len` elements are in
        // bounds. Assumes `contents()` of a non-empty shared buffer is non-null
        // and aligned for `T`, and that the GPU is not writing it concurrently.
        unsafe { std::slice::from_raw_parts(ptr, self.len).to_vec() }
    }

    /// Mutable CPU view of the buffer's `len` elements.
    ///
    /// NOTE(review): the caller must ensure the GPU is not reading or writing
    /// this buffer while the slice is alive; this type does not track that.
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        if self.len == 0 {
            // Same non-null requirement as in `to_vec`; an empty slice needs no
            // backing memory.
            return &mut [];
        }
        let ptr = self.inner.contents() as *mut T;
        // SAFETY: same bounds argument as in `to_vec`. `&mut self` prevents any
        // other borrow obtained through this wrapper from overlapping the
        // returned slice. Assumes a non-null, `T`-aligned `contents()` pointer
        // and no concurrent GPU access.
        unsafe { std::slice::from_raw_parts_mut(ptr, self.len) }
    }
}
impl<T> std::fmt::Debug for MtlBuffer<T> {
    /// Describes the buffer by element count, per-element size, and total byte
    /// size; the buffer contents are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let element_size = std::mem::size_of::<T>();
        let mut summary = f.debug_struct("MtlBuffer");
        summary.field("len", &self.len);
        summary.field("element_size", &element_size);
        summary.field("byte_size", &(self.len * element_size));
        summary.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a pipeline for every entry of ALL_KERNELS on the real device and
    // expects exactly one cached pipeline per entry.
    #[test]
    #[ignore = "needs Metal device + access to shaders.metal source"]
    fn metal_backend_compiles_all_kernels() {
        let mut backend =
            MetalBackend::new().expect("MetalBackend::new failed");
        eprintln!("[metal] device: {}", backend.device().name());
        eprintln!("[metal] kernels to compile: {}", ALL_KERNELS.len());
        backend.warm_all().expect("warm_all failed");
        assert_eq!(backend.pipeline_count(), ALL_KERNELS.len());
        eprintln!(
            "[metal] all {} kernels compiled successfully",
            backend.pipeline_count()
        );
    }

    // Uploads a known f32 pattern and checks it reads back unchanged.
    #[test]
    #[ignore = "needs Metal device"]
    fn buffer_round_trip() {
        let backend = MetalBackend::new().expect("MetalBackend::new");
        let data: Vec<f32> = (0..1024).map(|i| i as f32 * 0.5).collect();
        let buf = MtlBuffer::with_data(backend.device(), &data);
        assert_eq!(buf.len(), 1024);
        let read = buf.to_vec();
        assert_eq!(read, data);
    }

    // The count assertion in `metal_backend_compiles_all_kernels` only holds
    // if no kernel name is listed twice, since the pipeline cache is keyed by
    // name. Check that without a GPU so a duplicate fails with a clear message
    // on every machine instead of a count mismatch on Metal hosts only.
    #[test]
    fn all_kernels_are_unique() {
        let mut seen = std::collections::HashSet::new();
        for &name in ALL_KERNELS {
            assert!(
                seen.insert(name),
                "kernel '{}' appears more than once in ALL_KERNELS",
                name
            );
        }
    }
}