use std::sync::atomic::{AtomicU64, Ordering};
use metal::{
CommandBuffer, CommandQueue, ComputeCommandEncoderRef, ComputePipelineState,
ComputePipelineStateRef, MTLCommandBufferStatus, MTLDispatchType, MTLSize,
};
#[allow(unused_imports)]
use objc::{msg_send, sel, sel_impl};
use crate::buffer::MlxBuffer;
use crate::error::{MlxError, Result};
/// A kernel argument to bind at a given slot index when encoding a dispatch.
pub enum KernelArg<'a> {
    /// Whole buffer, bound at byte offset 0.
    Buffer(&'a MlxBuffer),
    /// Buffer bound at the given byte offset.
    BufferWithOffset(&'a MlxBuffer, u64),
    /// Small inline constant data, passed via `set_bytes` (no buffer object).
    Bytes(&'a [u8]),
}
/// Reinterprets a plain-old-data value as its raw byte slice, e.g. to build a
/// [`KernelArg::Bytes`] from a kernel-parameter struct.
pub fn as_bytes<T: bytemuck::Pod>(val: &T) -> &[u8] {
    bytemuck::bytes_of(val)
}
/// An owned snapshot of one kernel argument, recorded during capture so a
/// dispatch can be replayed later without borrowing the caller's data.
#[derive(Clone)]
pub enum RecordedBinding {
    /// A retained Metal buffer plus the byte offset it was bound at.
    Buffer {
        metal_buffer: metal::Buffer,
        offset: u64,
    },
    /// Inline constant data, copied out of the original `set_bytes` payload.
    Bytes(Vec<u8>),
}
/// Which Metal dispatch entry point a captured node should be replayed with.
#[derive(Clone, Copy, Debug)]
pub enum DispatchKind {
    /// `dispatch_threads` — grid size expressed in individual threads.
    Threads,
    /// `dispatch_thread_groups` — grid size expressed in threadgroups.
    ThreadGroups,
}
/// Coarse classification of a captured dispatch, used to decide how freely a
/// replay scheduler may move the node relative to its neighbours.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CapturedOpKind {
    RmsNorm,
    ElemMul,
    ElemAdd,
    Sdpa,
    Softmax,
    Other,
}
impl CapturedOpKind {
    /// Whether a scheduler is allowed to reorder this op relative to others.
    ///
    /// SDPA and softmax dispatches must keep their recorded position; every
    /// other kind (including `Other`) may be moved. The match is exhaustive on
    /// purpose so adding a variant forces a decision here.
    pub fn is_reorderable(&self) -> bool {
        match *self {
            CapturedOpKind::Sdpa | CapturedOpKind::Softmax => false,
            CapturedOpKind::RmsNorm
            | CapturedOpKind::ElemMul
            | CapturedOpKind::ElemAdd
            | CapturedOpKind::Other => true,
        }
    }
}
pub type MemRange = (usize, usize);
/// One recorded unit of GPU work: either a full dispatch snapshot or an
/// explicit barrier between dispatches.
#[derive(Clone)]
pub enum CapturedNode {
    /// A fully self-contained dispatch that can be re-encoded later.
    Dispatch {
        pipeline: ComputePipelineState,
        /// (slot index, recorded argument) pairs.
        bindings: Vec<(u64, RecordedBinding)>,
        /// Grid size; holds threadgroup counts when `dispatch_kind` is
        /// `ThreadGroups` (the field is reused for both dispatch styles).
        threads_per_grid: MTLSize,
        threads_per_threadgroup: MTLSize,
        /// (slot index, byte length) pairs for threadgroup shared memory.
        threadgroup_memory: Vec<(u64, u64)>,
        dispatch_kind: DispatchKind,
        op_kind: CapturedOpKind,
        /// Buffer ranges this dispatch reads/writes, for dependency analysis.
        reads: Vec<MemRange>,
        writes: Vec<MemRange>,
    },
    /// A memory barrier recorded at this position in the stream.
    Barrier,
}
/// Binds every argument in `bindings` at its declared slot index on `encoder`.
#[inline]
fn apply_bindings(encoder: &ComputeCommandEncoderRef, bindings: &[(u64, KernelArg<'_>)]) {
    for (slot, payload) in bindings.iter() {
        // All variant fields are Copy (references / u64), so matching by value
        // on the dereferenced tuple element is free.
        match *payload {
            KernelArg::Buffer(buf) => encoder.set_buffer(*slot, Some(buf.metal_buffer()), 0),
            KernelArg::BufferWithOffset(buf, offset) => {
                encoder.set_buffer(*slot, Some(buf.metal_buffer()), offset)
            }
            KernelArg::Bytes(bytes) => {
                encoder.set_bytes(*slot, bytes.len() as u64, bytes.as_ptr() as *const _)
            }
        }
    }
}
/// Host-side waits on GPU completion since the last reset (see `commit_and_wait`).
static SYNC_COUNT: AtomicU64 = AtomicU64::new(0);
/// Kernel dispatches encoded (or captured) since the last reset.
static DISPATCH_COUNT: AtomicU64 = AtomicU64::new(0);

/// Zeroes both instrumentation counters.
pub fn reset_counters() {
    for counter in [&SYNC_COUNT, &DISPATCH_COUNT] {
        counter.store(0, Ordering::Relaxed);
    }
}

/// Current value of the sync (commit-and-wait) counter.
pub fn sync_count() -> u64 {
    SYNC_COUNT.load(Ordering::Relaxed)
}

/// Current value of the dispatch counter.
pub fn dispatch_count() -> u64 {
    DISPATCH_COUNT.load(Ordering::Relaxed)
}
/// Batches compute work into a single Metal command buffer, lazily opening one
/// concurrent compute encoder and optionally capturing dispatches for replay.
pub struct CommandEncoder {
    // The command buffer all work is encoded into.
    cmd_buf: CommandBuffer,
    // Raw pointer to the currently open compute encoder; null when none is
    // open. NOTE(review): the pointee is presumably kept alive by `cmd_buf`
    // (Metal object ownership) until `end_encoding` — confirm.
    active_encoder: *const ComputeCommandEncoderRef,
    // When `Some`, `encode*` calls record nodes here instead of encoding.
    capture: Option<Vec<CapturedNode>>,
    // Metadata staged by callers before the next dispatch; consumed (and reset
    // to defaults) by the next `encode*` call.
    pending_op_kind: CapturedOpKind,
    pending_reads: Vec<MemRange>,
    pending_writes: Vec<MemRange>,
}
unsafe impl Send for CommandEncoder {}
impl CommandEncoder {
    /// Creates an encoder over a fresh command buffer taken from `queue`.
    pub(crate) fn new(queue: &CommandQueue) -> Result<Self> {
        let cmd_buf = queue.new_command_buffer().to_owned();
        Ok(Self {
            cmd_buf,
            // No compute encoder open yet; created lazily on first use.
            active_encoder: std::ptr::null(),
            capture: None,
            pending_op_kind: CapturedOpKind::Other,
            pending_reads: Vec::new(),
            pending_writes: Vec::new(),
        })
    }
    /// Begins recording subsequent dispatches into a capture list instead of
    /// encoding them onto the GPU.
    pub fn start_capture(&mut self) {
        self.capture = Some(Vec::with_capacity(128));
    }
    /// True while capture mode is active (between `start_capture` and
    /// `take_capture`).
    pub fn is_capturing(&self) -> bool {
        self.capture.is_some()
    }
    /// Ends capture mode and returns the recorded nodes, if any were being
    /// captured.
    pub fn take_capture(&mut self) -> Option<Vec<CapturedNode>> {
        self.capture.take()
    }
    /// Tags the next dispatch with `kind`; consumed (and reset to `Other`) by
    /// the next `encode*` call.
    pub fn set_op_kind(&mut self, kind: CapturedOpKind) {
        self.pending_op_kind = kind;
    }
    // Consumes the staged op kind, restoring the `Other` default.
    fn take_pending_op_kind(&mut self) -> CapturedOpKind {
        let kind = self.pending_op_kind;
        self.pending_op_kind = CapturedOpKind::Other;
        kind
    }
    /// Stages the read/write memory ranges of the next dispatch; recorded into
    /// the captured node for dependency tracking, then cleared.
    pub fn set_pending_buffer_ranges(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
        self.pending_reads = reads;
        self.pending_writes = writes;
    }
    // Consumes the staged ranges, leaving both lists empty for the next op.
    fn take_pending_buffer_ranges(&mut self) -> (Vec<MemRange>, Vec<MemRange>) {
        let reads = std::mem::take(&mut self.pending_reads);
        let writes = std::mem::take(&mut self.pending_writes);
        (reads, writes)
    }
    // Snapshots plain buffer bindings (always bound at offset 0) for replay.
    fn record_buffer_bindings(buffers: &[(u64, &MlxBuffer)]) -> Vec<(u64, RecordedBinding)> {
        buffers
            .iter()
            .map(|&(index, buf)| {
                (
                    index,
                    RecordedBinding::Buffer {
                        metal_buffer: buf.metal_buffer().clone(),
                        offset: 0,
                    },
                )
            })
            .collect()
    }
    // Snapshots mixed `KernelArg` bindings, copying inline byte payloads so
    // the recording owns all its data.
    fn record_arg_bindings(bindings: &[(u64, KernelArg<'_>)]) -> Vec<(u64, RecordedBinding)> {
        bindings
            .iter()
            .map(|(index, arg)| {
                let recorded = match arg {
                    KernelArg::Buffer(buf) => RecordedBinding::Buffer {
                        metal_buffer: buf.metal_buffer().clone(),
                        offset: 0,
                    },
                    KernelArg::BufferWithOffset(buf, offset) => RecordedBinding::Buffer {
                        metal_buffer: buf.metal_buffer().clone(),
                        offset: *offset,
                    },
                    KernelArg::Bytes(bytes) => RecordedBinding::Bytes(bytes.to_vec()),
                };
                (*index, recorded)
            })
            .collect()
    }
    /// Returns the open compute encoder, creating a concurrent-dispatch one on
    /// the command buffer if none is active.
    #[inline]
    fn get_or_create_encoder(&mut self) -> &ComputeCommandEncoderRef {
        if self.active_encoder.is_null() {
            let encoder = self
                .cmd_buf
                .compute_command_encoder_with_dispatch_type(MTLDispatchType::Concurrent);
            // Cache the raw pointer; the pointee is presumably kept alive by
            // `cmd_buf` until `end_encoding` — confirm against Metal ownership
            // rules.
            self.active_encoder = encoder as *const ComputeCommandEncoderRef;
        }
        unsafe { &*self.active_encoder }
    }
    // Ends and clears the active encoder, if one is open.
    #[inline]
    fn end_active_encoder(&mut self) {
        if !self.active_encoder.is_null() {
            unsafe { &*self.active_encoder }.end_encoding();
            self.active_encoder = std::ptr::null();
        }
    }
    /// Inserts a buffers-scope memory barrier between already-encoded and
    /// subsequent dispatches. In capture mode a `Barrier` node is recorded
    /// instead; with no open encoder the call is a no-op (nothing has been
    /// encoded yet to order against).
    #[allow(unexpected_cfgs)]
    pub fn memory_barrier(&mut self) {
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Barrier);
            return;
        }
        if self.active_encoder.is_null() {
            return;
        }
        let encoder = unsafe { &*self.active_encoder };
        // Raw value assumed to be MTLBarrierScopeBuffers; sent via objc
        // msg_send rather than a crate binding — confirm the constant against
        // the Metal headers.
        const MTL_BARRIER_SCOPE_BUFFERS: u64 = 1;
        unsafe {
            let _: () = objc::msg_send![encoder, memoryBarrierWithScope: MTL_BARRIER_SCOPE_BUFFERS];
        }
    }
    /// Binds `pipeline` on the (possibly newly created) active encoder.
    pub fn set_pipeline(&mut self, pipeline: &ComputePipelineStateRef) {
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
    }
    /// No-op: arguments are deliberately discarded. Presumably retained for
    /// API compatibility with an older interface — TODO confirm before
    /// relying on it; real binding goes through the `encode*` methods.
    pub fn set_buffer(&self, index: u64, buffer: &MlxBuffer) {
        let _ = (index, buffer);
    }
    /// No-op: see `set_buffer` — real dispatching goes through `encode*`.
    pub fn dispatch_threads(&self, grid_size: MTLSize, threadgroup_size: MTLSize) {
        let _ = (grid_size, threadgroup_size);
    }
    /// Encodes (or captures) a per-thread dispatch of `pipeline` with plain
    /// buffer bindings at offset 0.
    ///
    /// Consumes the pending op kind and read/write ranges. Increments the
    /// global dispatch counter even in capture mode.
    pub fn encode(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        buffers: &[(u64, &MlxBuffer)],
        grid_size: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
        let op_kind = self.take_pending_op_kind();
        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Dispatch {
                pipeline: pipeline.to_owned(),
                bindings: Self::record_buffer_bindings(buffers),
                threads_per_grid: grid_size,
                threads_per_threadgroup: threadgroup_size,
                threadgroup_memory: Vec::new(),
                dispatch_kind: DispatchKind::Threads,
                op_kind,
                reads: pending_reads,
                writes: pending_writes,
            });
            return;
        }
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
        for &(index, buf) in buffers {
            encoder.set_buffer(index, Some(buf.metal_buffer()), 0);
        }
        encoder.dispatch_threads(grid_size, threadgroup_size);
    }
    /// Like [`Self::encode`] but dispatches by threadgroup counts. When
    /// captured, `threadgroups` is stored in the node's `threads_per_grid`
    /// field (the field is shared by both dispatch styles).
    pub fn encode_threadgroups(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        buffers: &[(u64, &MlxBuffer)],
        threadgroups: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
        let op_kind = self.take_pending_op_kind();
        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Dispatch {
                pipeline: pipeline.to_owned(),
                bindings: Self::record_buffer_bindings(buffers),
                threads_per_grid: threadgroups,
                threads_per_threadgroup: threadgroup_size,
                threadgroup_memory: Vec::new(),
                dispatch_kind: DispatchKind::ThreadGroups,
                op_kind,
                reads: pending_reads,
                writes: pending_writes,
            });
            return;
        }
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
        for &(index, buf) in buffers {
            encoder.set_buffer(index, Some(buf.metal_buffer()), 0);
        }
        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
    }
    /// Like [`Self::encode_threadgroups`] but also sets threadgroup shared
    /// memory lengths: `threadgroup_mem` is (slot index, byte length) pairs.
    pub fn encode_threadgroups_with_shared(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        buffers: &[(u64, &MlxBuffer)],
        threadgroup_mem: &[(u64, u64)],
        threadgroups: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
        let op_kind = self.take_pending_op_kind();
        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Dispatch {
                pipeline: pipeline.to_owned(),
                bindings: Self::record_buffer_bindings(buffers),
                threads_per_grid: threadgroups,
                threads_per_threadgroup: threadgroup_size,
                threadgroup_memory: threadgroup_mem.to_vec(),
                dispatch_kind: DispatchKind::ThreadGroups,
                op_kind,
                reads: pending_reads,
                writes: pending_writes,
            });
            return;
        }
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
        for &(index, buf) in buffers {
            encoder.set_buffer(index, Some(buf.metal_buffer()), 0);
        }
        for &(index, byte_length) in threadgroup_mem {
            encoder.set_threadgroup_memory_length(index, byte_length);
        }
        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
    }
    /// Per-thread dispatch with mixed [`KernelArg`] bindings (buffers,
    /// offset buffers, or inline bytes).
    pub fn encode_with_args(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        bindings: &[(u64, KernelArg<'_>)],
        grid_size: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
        let op_kind = self.take_pending_op_kind();
        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Dispatch {
                pipeline: pipeline.to_owned(),
                bindings: Self::record_arg_bindings(bindings),
                threads_per_grid: grid_size,
                threads_per_threadgroup: threadgroup_size,
                threadgroup_memory: Vec::new(),
                dispatch_kind: DispatchKind::Threads,
                op_kind,
                reads: pending_reads,
                writes: pending_writes,
            });
            return;
        }
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
        apply_bindings(encoder, bindings);
        encoder.dispatch_threads(grid_size, threadgroup_size);
    }
    /// Threadgroup-count dispatch with mixed [`KernelArg`] bindings.
    pub fn encode_threadgroups_with_args(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        bindings: &[(u64, KernelArg<'_>)],
        threadgroups: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
        let op_kind = self.take_pending_op_kind();
        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Dispatch {
                pipeline: pipeline.to_owned(),
                bindings: Self::record_arg_bindings(bindings),
                threads_per_grid: threadgroups,
                threads_per_threadgroup: threadgroup_size,
                threadgroup_memory: Vec::new(),
                dispatch_kind: DispatchKind::ThreadGroups,
                op_kind,
                reads: pending_reads,
                writes: pending_writes,
            });
            return;
        }
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
        apply_bindings(encoder, bindings);
        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
    }
    /// Threadgroup-count dispatch with mixed [`KernelArg`] bindings plus
    /// threadgroup shared-memory lengths ((slot index, byte length) pairs).
    pub fn encode_threadgroups_with_args_and_shared(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        bindings: &[(u64, KernelArg<'_>)],
        threadgroup_mem: &[(u64, u64)],
        threadgroups: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
        let op_kind = self.take_pending_op_kind();
        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Dispatch {
                pipeline: pipeline.to_owned(),
                bindings: Self::record_arg_bindings(bindings),
                threads_per_grid: threadgroups,
                threads_per_threadgroup: threadgroup_size,
                threadgroup_memory: threadgroup_mem.to_vec(),
                dispatch_kind: DispatchKind::ThreadGroups,
                op_kind,
                reads: pending_reads,
                writes: pending_writes,
            });
            return;
        }
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
        apply_bindings(encoder, bindings);
        for &(index, byte_length) in threadgroup_mem {
            encoder.set_threadgroup_memory_length(index, byte_length);
        }
        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
    }
    /// Re-encodes a previously captured dispatch onto the live encoder,
    /// ignoring any capture state.
    ///
    /// NOTE(review): this does not increment DISPATCH_COUNT — presumably the
    /// dispatch was already counted when it was recorded; confirm that is the
    /// intended accounting for repeated replays.
    pub fn replay_dispatch(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        bindings: &[(u64, RecordedBinding)],
        threadgroup_memory: &[(u64, u64)],
        threads_per_grid: MTLSize,
        threads_per_threadgroup: MTLSize,
        dispatch_kind: DispatchKind,
    ) {
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
        for (index, binding) in bindings {
            match binding {
                RecordedBinding::Buffer { metal_buffer, offset } => {
                    encoder.set_buffer(*index, Some(metal_buffer), *offset);
                }
                RecordedBinding::Bytes(bytes) => {
                    encoder.set_bytes(
                        *index,
                        bytes.len() as u64,
                        bytes.as_ptr() as *const _,
                    );
                }
            }
        }
        for &(index, byte_length) in threadgroup_memory {
            encoder.set_threadgroup_memory_length(index, byte_length);
        }
        // Replay with the same dispatch style that was recorded.
        match dispatch_kind {
            DispatchKind::Threads => {
                encoder.dispatch_threads(threads_per_grid, threads_per_threadgroup);
            }
            DispatchKind::ThreadGroups => {
                encoder.dispatch_thread_groups(threads_per_grid, threads_per_threadgroup);
            }
        }
    }
    /// Commits the command buffer and blocks until the GPU finishes.
    ///
    /// # Errors
    /// Returns `MlxError::CommandBufferError` if the buffer completes with an
    /// error status or any status other than `Completed`.
    pub fn commit_and_wait(&mut self) -> Result<()> {
        SYNC_COUNT.fetch_add(1, Ordering::Relaxed);
        // Metal requires the encoder to be ended before commit.
        self.end_active_encoder();
        self.cmd_buf.commit();
        self.cmd_buf.wait_until_completed();
        match self.cmd_buf.status() {
            MTLCommandBufferStatus::Completed => Ok(()),
            MTLCommandBufferStatus::Error => {
                Err(MlxError::CommandBufferError(
                    "GPU command buffer completed with error status".into(),
                ))
            }
            status => Err(MlxError::CommandBufferError(format!(
                "Unexpected command buffer status after wait: {:?}",
                status
            ))),
        }
    }
    /// Commits the command buffer without waiting for completion.
    pub fn commit(&mut self) {
        self.end_active_encoder();
        self.cmd_buf.commit();
    }
    /// Blocks until a previously committed buffer finishes.
    ///
    /// # Errors
    /// Same status handling as [`Self::commit_and_wait`], but does not bump
    /// the sync counter.
    pub fn wait_until_completed(&self) -> Result<()> {
        self.cmd_buf.wait_until_completed();
        match self.cmd_buf.status() {
            MTLCommandBufferStatus::Completed => Ok(()),
            MTLCommandBufferStatus::Error => Err(MlxError::CommandBufferError(
                "GPU command buffer completed with error status".into(),
            )),
            status => Err(MlxError::CommandBufferError(format!(
                "Unexpected command buffer status after wait: {:?}",
                status
            ))),
        }
    }
    /// Borrow of the underlying Metal command buffer.
    #[inline]
    pub fn metal_command_buffer(&self) -> &CommandBuffer {
        &self.cmd_buf
    }
}
impl Drop for CommandEncoder {
    /// Closes any still-open compute encoder so the command buffer is left in
    /// a valid state even if the caller never committed.
    fn drop(&mut self) {
        self.end_active_encoder();
    }
}