use std::collections::HashMap;
use std::sync::Mutex;
pub use metal::{
CommandBufferRef, CommandQueue, CompileOptions, ComputePipelineState, Device, Library,
MTLResourceOptions, NSUInteger,
};
use objc::{msg_send, sel, sel_impl};
use moeflux_metal::Kernels;
unsafe fn nsstring_as_str<'a>(nsstr: &'a objc::runtime::Object) -> &'a str {
let bytes: *const i8 = unsafe { msg_send![nsstr, UTF8String] };
let len: NSUInteger = unsafe { msg_send![nsstr, length] };
if bytes.is_null() || len == 0 {
return "";
}
let slice = unsafe { std::slice::from_raw_parts(bytes.cast::<u8>(), len as usize) };
std::str::from_utf8(slice).unwrap_or("<invalid utf-8>")
}
fn cmdbuf_error_detail(cmdbuf: &CommandBufferRef) -> String {
unsafe {
let err: *mut objc::runtime::Object = msg_send![cmdbuf, error];
if err.is_null() {
return "<no NSError>".to_string();
}
let code: isize = msg_send![err, code];
let desc_obj: *mut objc::runtime::Object = msg_send![err, localizedDescription];
let desc = if desc_obj.is_null() {
"<no description>"
} else {
nsstring_as_str(&*desc_obj)
};
let kind = match code {
0 => "None",
1 => "Internal",
2 => "Timeout",
3 => "PageFault",
4 => "Blacklisted",
7 => "NotPermitted",
8 => "OutOfMemory",
9 => "InvalidResource",
10 => "Memoryless",
11 => "DeviceRemoved",
_ => "Unknown",
};
format!("{kind}({code}): {desc}")
}
}
#[derive(Debug, thiserror::Error)]
pub enum MetalError {
#[error("no Metal device available (system has no GPU?)")]
NoDevice,
#[error("compiling shaders.metal: {0}")]
LibraryCompile(String),
#[error("kernel '{name}' not found in compiled library")]
FunctionNotFound { name: String },
#[error("pipeline-state creation failed for '{name}': {err}")]
PipelineCreate { name: String, err: String },
#[error("building the moeflux-metal kernel library: {0}")]
MlxKernels(String),
}
/// Metal Shading Language source embedded at build time; `MetalContext::new`
/// compiles it at runtime with `new_library_with_source`.
const SHADER_SOURCE: &str = include_str!("../../../../shaders/shaders.metal");
/// Kernel function names that `MetalContext::warm_all` builds pipelines for.
///
/// Every entry must name a function in `shaders.metal`, or pipeline creation
/// fails with `MetalError::FunctionNotFound`. The list is kept sorted and free
/// of duplicates. That way a freshly warmed context has exactly
/// `ALL_KERNELS.len()` cached pipelines, as the `metal_backend_compiles_all_kernels`
/// test asserts.
pub const ALL_KERNELS: &[&str] = &[
    "attn_scores_batched",
    "attn_softmax_batched",
    "attn_values_batched",
    "bf16_matvec",
    "compute_decay_beta",
    "conv1d_state_update",
    "conv1d_step",
    "dequant_matvec_2bit",
    "dequant_matvec_4bit",
    "dequant_matvec_4bit_batched",
    "dequant_matvec_4bit_fast",
    "dequant_matvec_4bit_v3",
    "dequant_matvec_4bit_v3_experts",
    "dequant_matvec_4bit_v4",
    "dequant_matvec_4bit_v5",
    "dequant_matvec_8bit_v3",
    "dequant_matvec_8bit_v3_n_tokens",
    "fused_gate_up_swiglu",
    "gated_delta_net_chunkwise",
    "gated_delta_net_sequential",
    "gated_delta_net_step",
    "gated_rms_norm",
    "kv_cache_append_n_tokens",
    "mla_sdpa_tile_accumulate",
    "mla_sdpa_tile_finalize",
    "moe_combine_residual",
    "moe_combine_residual_flat",
    "moe_combine_residual_n_tokens",
    "moe_normalize_weights",
    "moe_softmax_topk",
    "residual_add",
    "residual_add_n_tokens",
    "rms_norm_apply",
    "rms_norm_apply_bf16",
    "rms_norm_bf16_fused_n_tokens",
    "rms_norm_per_head_n_tokens",
    "rms_norm_qk",
    "rms_norm_sum_sq",
    "rope_n_tokens",
    "sigmoid_gate",
    "split_q_gate",
    "swiglu_fused",
    "swiglu_fused_batched",
    "swiglu_fused_vec4",
    "weighted_sum",
];
/// Running totals for all command buffers recorded under one label.
#[derive(Clone, Copy, Debug, Default)]
pub struct CmdbufStat {
    /// How many command buffers were recorded under this label.
    pub count: u64,
    /// Summed host wall-clock time, in nanoseconds, spent committing and
    /// waiting for those command buffers.
    pub cpu_wait_ns: u64,
    /// Summed GPU execution time in nanoseconds. This stays 0 unless the
    /// recorder captures GPU timestamps.
    pub gpu_ns: u64,
}
/// Bundles everything needed to run this crate's Metal kernels:
/// - the system-default device;
/// - one command queue;
/// - the library compiled from the embedded `shaders.metal`;
/// - a lazily filled cache of compute pipeline states;
/// - per-label command-buffer timing stats;
/// - the `moeflux_metal` kernel set.
///
/// NOTE(review): fields drop in declaration order; keep that in mind before
/// reordering them.
pub struct MetalContext {
    device: Device,
    queue: CommandQueue,
    library: Library,
    /// Kernel name -> pipeline state. Filled on first use by `pipeline`.
    pipelines: HashMap<&'static str, ComputePipelineState>,
    /// Behind a `Mutex` so timings can be recorded through `&self`.
    cmdbuf_stats: Mutex<HashMap<&'static str, CmdbufStat>>,
    kernels: Kernels,
}
static AGX_TIMEOUT_INIT: std::sync::Once = std::sync::Once::new();
impl MetalContext {
pub fn new() -> Result<Self, MetalError> {
AGX_TIMEOUT_INIT.call_once(|| unsafe {
std::env::set_var("AGX_RELAX_CDM_CTXSTORE_TIMEOUT", "1");
});
let device = Device::system_default().ok_or(MetalError::NoDevice)?;
let queue = device.new_command_queue();
let options = CompileOptions::new();
let library = device
.new_library_with_source(SHADER_SOURCE, &options)
.map_err(MetalError::LibraryCompile)?;
let kernels = Kernels::new(&device)
.map_err(|e| MetalError::MlxKernels(e.to_string()))?;
Ok(Self {
device,
queue,
library,
pipelines: HashMap::new(),
cmdbuf_stats: Mutex::new(HashMap::new()),
kernels,
})
}
pub fn device(&self) -> &Device {
&self.device
}
pub fn queue(&self) -> &CommandQueue {
&self.queue
}
pub fn kernels(&self) -> &Kernels {
&self.kernels
}
pub fn queue_clone(&self) -> CommandQueue {
self.queue.clone()
}
pub fn pipeline(&mut self, name: &'static str) -> Result<&ComputePipelineState, MetalError> {
if !self.pipelines.contains_key(name) {
let function = self.library.get_function(name, None).map_err(|_| {
MetalError::FunctionNotFound {
name: name.to_string(),
}
})?;
let state = self
.device
.new_compute_pipeline_state_with_function(&function)
.map_err(|err| MetalError::PipelineCreate {
name: name.to_string(),
err,
})?;
self.pipelines.insert(name, state);
}
Ok(&self.pipelines[name])
}
pub fn warm_all(&mut self) -> Result<(), MetalError> {
for &name in ALL_KERNELS {
self.pipeline(name)?;
}
Ok(())
}
pub fn pipeline_count(&self) -> usize {
self.pipelines.len()
}
pub fn commit_and_wait_labeled(&self, cmdbuf: &CommandBufferRef, label: &'static str) {
let t0 = std::time::Instant::now();
cmdbuf.commit();
cmdbuf.wait_until_completed();
let cpu_wait_ns = t0.elapsed().as_nanos() as u64;
if cmdbuf.status() == metal::MTLCommandBufferStatus::Error {
let detail = cmdbuf_error_detail(cmdbuf);
panic!(
"Metal command buffer '{label}' completed with error \
status: {detail}. Rerun with MTL_DEBUG_LAYER=1 \
MTL_SHADER_VALIDATION=1 for the fault detail."
);
}
let mut stats = self
.cmdbuf_stats
.lock()
.expect("cmdbuf_stats mutex poisoned");
let entry = stats.entry(label).or_default();
entry.count += 1;
entry.cpu_wait_ns += cpu_wait_ns;
}
pub fn cmdbuf_stats(&self) -> Vec<(&'static str, CmdbufStat)> {
let stats = self
.cmdbuf_stats
.lock()
.expect("cmdbuf_stats mutex poisoned");
let mut out: Vec<_> = stats.iter().map(|(k, v)| (*k, *v)).collect();
out.sort_by_key(|(k, _)| *k);
out
}
pub fn reset_cmdbuf_stats(&self) {
self.cmdbuf_stats
.lock()
.expect("cmdbuf_stats mutex poisoned")
.clear();
}
pub fn drain_queue(&self) {
let cmdbuf = self.queue.new_command_buffer();
cmdbuf.commit();
cmdbuf.wait_until_completed();
if cmdbuf.status() == metal::MTLCommandBufferStatus::Error {
let detail = cmdbuf_error_detail(cmdbuf);
panic!(
"drain_queue: barrier cmdbuf completed with error \
status: {detail}. Rerun with MTL_DEBUG_LAYER=1 \
MTL_SHADER_VALIDATION=1 for the fault detail."
);
}
}
}
impl std::fmt::Debug for MetalContext {
    /// Shows the device name and how many pipeline states are cached. The
    /// queue, library, stats and kernel set are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MetalContext");
        out.field("device", &self.device.name());
        out.field("pipelines_cached", &self.pipelines.len());
        out.finish()
    }
}
/// An owned heap allocation with caller-chosen alignment. On drop it is freed
/// with the exact `Layout` it was allocated with.
struct AlignedBacking {
    /// Start of the allocation; never null.
    ptr: std::ptr::NonNull<u8>,
    /// Size and alignment handed to the allocator, reused for `dealloc`.
    layout: std::alloc::Layout,
}

// SAFETY: `AlignedBacking` uniquely owns its allocation and exposes no access
// to the bytes (private fields, no methods). Moving it across threads therefore
// cannot race.
unsafe impl Send for AlignedBacking {}
// SAFETY: a shared `&AlignedBacking` offers no way to read or write the
// allocation, so sharing it between threads cannot race either.
unsafe impl Sync for AlignedBacking {}

impl Drop for AlignedBacking {
    fn drop(&mut self) {
        let (start, layout) = (self.ptr, self.layout);
        // SAFETY: the allocation behind `start` was made with `layout` and is
        // released only here, exactly once.
        unsafe { std::alloc::dealloc(start.as_ptr(), layout) }
    }
}
/// A Metal buffer viewed as `len` elements of `T`.
///
/// Every constructor in this module uses `MTLResourceOptions::StorageModeShared`.
pub struct MtlBuffer<T> {
    /// The underlying Metal buffer. Declared before `_backing`, so it is
    /// dropped first and our reference to the buffer is released before any
    /// backing allocation is freed.
    inner: metal::Buffer,
    /// Number of `T` elements (not bytes).
    len: usize,
    /// Rust-owned memory wrapped by `with_aligned_len_u8`; `None` when Metal
    /// allocated the storage. NOTE(review): this memory is freed when the
    /// `MtlBuffer` drops, even if another retain of `inner` is still alive
    /// (e.g. a clone of `buffer()` or an in-flight command buffer). Confirm
    /// callers never let the Metal buffer outlive this struct.
    _backing: Option<AlignedBacking>,
    /// Ties the element type to the struct without storing a `T`.
    _phantom: std::marker::PhantomData<T>,
}
impl<T: Copy> MtlBuffer<T> {
    /// Allocates a shared-storage buffer with room for `len` elements of `T`.
    ///
    /// # Panics
    /// If `len * size_of::<T>()` overflows `usize`. Without this check the
    /// product would wrap and create an undersized buffer that the slice views
    /// would then read past.
    pub fn with_len(device: &Device, len: usize) -> Self {
        let bytes = len
            .checked_mul(std::mem::size_of::<T>())
            .expect("MtlBuffer::with_len: byte size overflows usize");
        let inner = device.new_buffer(bytes as NSUInteger, MTLResourceOptions::StorageModeShared);
        Self {
            inner,
            len,
            _backing: None,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Allocates a shared-storage buffer and copies `data` into it.
    pub fn with_data(device: &Device, data: &[T]) -> Self {
        let bytes = std::mem::size_of_val(data) as NSUInteger;
        let inner = device.new_buffer_with_data(
            data.as_ptr().cast(),
            bytes,
            MTLResourceOptions::StorageModeShared,
        );
        Self {
            inner,
            len: data.len(),
            _backing: None,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Borrows the underlying buffer as a `BufferRef`.
    pub fn raw(&self) -> &metal::BufferRef {
        &self.inner
    }

    /// Borrows the owned `metal::Buffer` handle.
    pub fn buffer(&self) -> &metal::Buffer {
        &self.inner
    }

    /// Number of `T` elements.
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when the buffer holds no elements.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Copies the contents into a new `Vec`.
    pub fn to_vec(&self) -> Vec<T> {
        self.as_slice().to_vec()
    }

    /// Borrows the contents as `&[T]`.
    ///
    /// NOTE(review): sound only while the GPU is not writing this buffer, and
    /// only if every bit pattern it may hold is a valid `T` (true for plain
    /// numeric types).
    pub fn as_slice(&self) -> &[T] {
        if self.len == 0 {
            // Avoid calling `contents()`, which may be null; an empty slice
            // still needs a non-null pointer.
            return &[];
        }
        // SAFETY: every constructor sizes `inner` for `len` elements of `T` in
        // shared storage, so `contents()` is a CPU pointer to enough bytes.
        unsafe { buffer_as_slice::<T>(&self.inner, self.len) }
    }

    /// Borrows the contents as `&mut [T]`.
    ///
    /// NOTE(review): sound only while the GPU is not reading or writing this
    /// buffer.
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        if self.len == 0 {
            return Default::default();
        }
        // SAFETY: same sizing argument as `as_slice`. `&mut self` rules out any
        // other CPU view obtained through this `MtlBuffer`.
        unsafe { buffer_as_mut_slice::<T>(&self.inner, self.len) }
    }
}
/// Views the first `n` elements of a CPU-visible Metal buffer as `&[T]`.
///
/// When `n == 0` this returns an empty slice without calling `contents()`,
/// which can be null (e.g. for private-storage buffers).
///
/// # Safety
/// - `buf` must be CPU-accessible, so `contents()` is a valid pointer, and must
///   hold at least `n * size_of::<T>()` bytes.
/// - `contents()` must be aligned for `T`, and those bytes must be valid `T`s.
/// - Nothing (CPU or GPU) may write the buffer while the returned slice is
///   alive.
pub unsafe fn buffer_as_slice<T>(buf: &metal::BufferRef, n: usize) -> &[T] {
    if n == 0 {
        return &[];
    }
    let ptr = buf.contents() as *const T;
    debug_assert!(!ptr.is_null(), "buffer_as_slice: buffer has no CPU-visible contents");
    debug_assert!(
        n.checked_mul(std::mem::size_of::<T>())
            .is_some_and(|bytes| bytes as u64 <= buf.length()),
        "buffer_as_slice: {n} elements exceed the buffer length"
    );
    // SAFETY: upheld by the caller per the contract above.
    unsafe { std::slice::from_raw_parts(ptr, n) }
}
/// Views the first `n` elements of a CPU-visible Metal buffer as `&mut [T]`.
///
/// When `n == 0` this returns an empty slice without calling `contents()`,
/// which can be null (e.g. for private-storage buffers).
///
/// # Safety
/// - `buf` must be CPU-accessible, so `contents()` is a valid pointer, and must
///   hold at least `n * size_of::<T>()` bytes.
/// - `contents()` must be aligned for `T`, and those bytes must be valid `T`s.
/// - No other slice (shared or mutable) into the same buffer may be alive, and
///   the GPU must not access it while the returned slice lives.
pub unsafe fn buffer_as_mut_slice<T>(
    buf: &metal::BufferRef,
    n: usize,
) -> &mut [T] {
    if n == 0 {
        return Default::default();
    }
    let ptr = buf.contents() as *mut T;
    debug_assert!(!ptr.is_null(), "buffer_as_mut_slice: buffer has no CPU-visible contents");
    debug_assert!(
        n.checked_mul(std::mem::size_of::<T>())
            .is_some_and(|bytes| bytes as u64 <= buf.length()),
        "buffer_as_mut_slice: {n} elements exceed the buffer length"
    );
    // SAFETY: upheld by the caller per the contract above.
    unsafe { std::slice::from_raw_parts_mut(ptr, n) }
}
impl MtlBuffer<u8> {
    /// Creates a `len`-byte shared-storage buffer whose CPU memory starts at an
    /// `align`-aligned address.
    ///
    /// The memory is allocated by Rust and zero-filled. It is handed to Metal
    /// via `new_buffer_with_bytes_no_copy` with no deallocator; this
    /// `MtlBuffer` owns it and frees it on drop.
    ///
    /// NOTE(review): Metal expects the no-copy pointer to be page-aligned, and
    /// possibly the length to be a page multiple. Callers should pass the VM
    /// page size (or a multiple) as `align`; confirm at call sites.
    ///
    /// # Panics
    /// - If `align` is not a power of two or `len` is 0.
    /// - If `len` rounded up to `align` overflows `isize`.
    /// - Via `handle_alloc_error` if the allocation fails.
    pub fn with_aligned_len_u8(device: &Device, len: usize, align: usize) -> Self {
        assert!(align.is_power_of_two(), "align must be power of two");
        assert!(len > 0, "with_aligned_len_u8 len must be > 0");
        let layout =
            std::alloc::Layout::from_size_align(len, align).expect("invalid alignment for len");
        // Zero-fill: `as_slice`/`to_vec` expose these bytes as initialised
        // `u8`s, and reading plain `alloc` memory before anything writes it is
        // undefined behaviour.
        // SAFETY: `layout` has a non-zero size (asserted above).
        let raw = unsafe { std::alloc::alloc_zeroed(layout) };
        let ptr =
            std::ptr::NonNull::new(raw).unwrap_or_else(|| std::alloc::handle_alloc_error(layout));
        let inner = device.new_buffer_with_bytes_no_copy(
            ptr.as_ptr() as *const std::ffi::c_void,
            len as NSUInteger,
            MTLResourceOptions::StorageModeShared,
            None,
        );
        Self {
            inner,
            len,
            _backing: Some(AlignedBacking { ptr, layout }),
            _phantom: std::marker::PhantomData,
        }
    }
}
impl<T> std::fmt::Debug for MtlBuffer<T> {
    /// Reports the element count and sizes only; the contents are never read.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let element_size = std::mem::size_of::<T>();
        let byte_size = self.len * element_size;
        f.debug_struct("MtlBuffer")
            .field("len", &self.len)
            .field("element_size", &element_size)
            .field("byte_size", &byte_size)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles every kernel listed in `ALL_KERNELS` into a pipeline state.
    #[test]
    #[ignore = "needs Metal device + access to shaders.metal source"]
    fn metal_backend_compiles_all_kernels() {
        let mut ctx = MetalContext::new().expect("MetalContext::new failed");
        eprintln!("[metal] device: {}", ctx.device().name());
        eprintln!("[metal] kernels to compile: {}", ALL_KERNELS.len());
        ctx.warm_all().expect("warm_all failed");
        let compiled = ctx.pipeline_count();
        assert_eq!(compiled, ALL_KERNELS.len());
        eprintln!("[metal] all {compiled} kernels compiled successfully");
    }

    /// Data uploaded with `with_data` reads back unchanged.
    #[test]
    #[ignore = "needs Metal device"]
    fn buffer_round_trip() {
        let ctx = MetalContext::new().expect("MetalContext::new");
        let expected: Vec<f32> = (0..1024_i32).map(|i| 0.5 * i as f32).collect();
        let buf = MtlBuffer::with_data(ctx.device(), &expected);
        assert_eq!(buf.len(), 1024);
        assert_eq!(buf.to_vec(), expected);
    }

    /// Back-to-back barriers on an idle queue complete without error.
    #[test]
    #[ignore = "needs Metal device"]
    fn drain_queue_smoke() {
        let ctx = MetalContext::new().expect("MetalContext::new");
        for _ in 0..2 {
            ctx.drain_queue();
        }
    }
}