use std::sync::atomic::{AtomicU64, Ordering};
use metal::{
CommandBuffer, CommandQueue, ComputeCommandEncoderRef, ComputePipelineState,
ComputePipelineStateRef, CounterSampleBuffer, CounterSampleBufferDescriptor,
MTLCommandBufferStatus, MTLCounterSamplingPoint, MTLDispatchType, MTLSize, MTLStorageMode,
NSRange,
};
#[allow(unused_imports)]
use objc::{msg_send, sel, sel_impl};
use crate::buffer::MlxBuffer;
use crate::error::{MlxError, Result};
use crate::mem_ranges::MemRanges;
use crate::residency::ResidencySet;
/// One argument bound to a compute-kernel slot by the `*_with_args` encode
/// methods; each variant maps to exactly one encoder call in `apply_bindings`.
pub enum KernelArg<'a> {
    /// Bind the buffer at its own `byte_offset()`.
    Buffer(&'a MlxBuffer),
    /// Bind the buffer at the given byte offset; the buffer's own
    /// `byte_offset()` is not used.
    BufferWithOffset(&'a MlxBuffer, u64),
    /// Inline constant data, passed to the encoder with `set_bytes`.
    Bytes(&'a [u8]),
}
/// Views a plain-old-data value as its raw bytes, e.g. to build a
/// [`KernelArg::Bytes`] argument.
pub fn as_bytes<T>(val: &T) -> &[u8]
where
    T: bytemuck::Pod,
{
    bytemuck::bytes_of::<T>(val)
}
/// Owned form of a kernel binding stored inside a captured dispatch, so it can
/// be replayed after the borrows held by the original `KernelArg` have ended.
#[derive(Clone)]
pub enum RecordedBinding {
    /// A Metal buffer handle (cloned at record time) bound at `offset` bytes.
    Buffer {
        metal_buffer: metal::Buffer,
        offset: u64,
    },
    /// A private copy of inline bytes, replayed with `set_bytes`.
    Bytes(Vec<u8>),
}
/// How a dispatch's grid size is interpreted when it is (re-)encoded.
///
/// `PartialEq`/`Eq`/`Hash` are derived (like `CapturedOpKind`) so captured
/// dispatches can be compared and grouped by kind.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DispatchKind {
    /// Encoded with `dispatch_threads`; the grid size counts threads.
    Threads,
    /// Encoded with `dispatch_thread_groups`; the grid size counts threadgroups.
    ThreadGroups,
}
/// Coarse classification of a dispatch. It is stored on captured dispatch nodes
/// and its `name()` labels per-dispatch profiling entries.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CapturedOpKind {
    RmsNorm,
    ElemMul,
    ElemAdd,
    Sdpa,
    Softmax,
    Other,
}
impl CapturedOpKind {
    /// Returns `false` for `Sdpa` and `Softmax`, and `true` for every other kind.
    pub fn is_reorderable(&self) -> bool {
        let pinned_in_place = match *self {
            Self::Softmax | Self::Sdpa => true,
            Self::Other | Self::ElemAdd | Self::ElemMul | Self::RmsNorm => false,
        };
        !pinned_in_place
    }

    /// Static display name of the kind (identical to the variant name).
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Other => "Other",
            Self::Softmax => "Softmax",
            Self::Sdpa => "Sdpa",
            Self::ElemAdd => "ElemAdd",
            Self::ElemMul => "ElemMul",
            Self::RmsNorm => "RmsNorm",
        }
    }
}
/// A `(start, end)` pair of byte addresses; `ranges_from_buffers` builds these
/// as `(base, base + extent)`.
pub type MemRange = (usize, usize);
/// One recorded step of a capture started with `CommandEncoder::start_capture`.
#[derive(Clone)]
pub enum CapturedNode {
    /// A compute dispatch with everything needed to re-encode it later
    /// (see `CommandEncoder::replay_dispatch`).
    Dispatch {
        pipeline: ComputePipelineState,
        /// `(slot index, binding)` pairs.
        bindings: Vec<(u64, RecordedBinding)>,
        /// Grid size in threads for `DispatchKind::Threads`, or the
        /// threadgroup count for `DispatchKind::ThreadGroups`.
        threads_per_grid: MTLSize,
        threads_per_threadgroup: MTLSize,
        /// `(index, byte_length)` pairs for `set_threadgroup_memory_length`.
        threadgroup_memory: Vec<(u64, u64)>,
        dispatch_kind: DispatchKind,
        op_kind: CapturedOpKind,
        /// Address ranges read by the dispatch; empty when the caller did not
        /// supply them.
        reads: Vec<MemRange>,
        /// Address ranges written by the dispatch; empty when the caller did
        /// not supply them.
        writes: Vec<MemRange>,
    },
    /// A buffer memory barrier recorded by `CommandEncoder::memory_barrier`.
    Barrier,
}
/// Converts each buffer into a `(start, end)` address range.
///
/// `start` is the buffer's contents pointer advanced by its byte offset, and the
/// extent is `byte_len() - byte_offset()`, clamped at zero.
/// NOTE(review): that extent is exact only if `byte_len()` is measured from the
/// start of the allocation rather than from the offset; confirm in `MlxBuffer`.
fn ranges_from_buffers(bufs: &[&MlxBuffer]) -> Vec<MemRange> {
    let mut ranges = Vec::with_capacity(bufs.len());
    for buf in bufs {
        let offset = buf.byte_offset() as usize;
        let start = buf.contents_ptr() as usize + offset;
        let len = buf.byte_len().saturating_sub(offset);
        ranges.push((start, start + len));
    }
    ranges
}
/// Applies every `(slot, KernelArg)` pair to `encoder`.
///
/// Buffers are bound at their own byte offset (`Buffer`) or at the explicit
/// offset (`BufferWithOffset`); inline bytes go through `set_bytes`.
#[inline]
fn apply_bindings(encoder: &ComputeCommandEncoderRef, bindings: &[(u64, KernelArg<'_>)]) {
    for (slot, arg) in bindings.iter() {
        let slot = *slot;
        match *arg {
            KernelArg::Buffer(buf) => {
                encoder.set_buffer(slot, Some(buf.metal_buffer()), buf.byte_offset())
            }
            KernelArg::BufferWithOffset(buf, offset) => {
                encoder.set_buffer(slot, Some(buf.metal_buffer()), offset)
            }
            KernelArg::Bytes(bytes) => {
                encoder.set_bytes(slot, bytes.len() as u64, bytes.as_ptr().cast())
            }
        }
    }
}
// Process-wide instrumentation counters. Every access uses `Ordering::Relaxed`
// because they are plain tallies and order no other memory.
static SYNC_COUNT: AtomicU64 = AtomicU64::new(0);
static DISPATCH_COUNT: AtomicU64 = AtomicU64::new(0);
static CMD_BUF_COUNT: AtomicU64 = AtomicU64::new(0);
static BARRIER_COUNT: AtomicU64 = AtomicU64::new(0);
static BARRIER_NS: AtomicU64 = AtomicU64::new(0);

/// Zeroes every instrumentation counter, including the auto-barrier counters.
pub fn reset_counters() {
    let counters: [&AtomicU64; 7] = [
        &SYNC_COUNT,
        &DISPATCH_COUNT,
        &CMD_BUF_COUNT,
        &BARRIER_COUNT,
        &BARRIER_NS,
        &AUTO_BARRIER_COUNT,
        &AUTO_BARRIER_CONCURRENT,
    ];
    for counter in counters.iter() {
        counter.store(0, Ordering::Relaxed);
    }
}

/// Relaxed read shared by the getters below.
#[inline]
fn load_counter(counter: &AtomicU64) -> u64 {
    counter.load(Ordering::Relaxed)
}

/// Number of blocking `CommandEncoder::commit_and_wait` calls.
pub fn sync_count() -> u64 {
    load_counter(&SYNC_COUNT)
}

/// Number of `encode*` calls, whether encoded live or captured.
pub fn dispatch_count() -> u64 {
    load_counter(&DISPATCH_COUNT)
}

/// Number of command buffers created by `CommandEncoder` constructors.
pub fn cmd_buf_count() -> u64 {
    load_counter(&CMD_BUF_COUNT)
}

/// Number of buffer barriers issued on an open encoder (captured barriers excluded).
pub fn barrier_count() -> u64 {
    load_counter(&BARRIER_COUNT)
}

/// CPU nanoseconds spent issuing barriers; only accumulated with `MLX_PROFILE_BARRIERS=1`.
pub fn barrier_total_ns() -> u64 {
    load_counter(&BARRIER_NS)
}
/// Returns `true` only when environment variable `name` is set to exactly `"1"`
/// (unset or non-UTF-8 values count as disabled).
fn env_flag_is_one(name: &str) -> bool {
    matches!(std::env::var(name).as_deref(), Ok("1"))
}

/// `MLX_PROFILE_BARRIERS=1`: time each barrier. Read once per process.
fn barrier_profile_enabled() -> bool {
    static FLAG: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    *FLAG.get_or_init(|| env_flag_is_one("MLX_PROFILE_BARRIERS"))
}

/// `MLX_UNRETAINED_REFS=1`: create command buffers with unretained references.
/// Read once per process.
fn unretained_refs_enabled() -> bool {
    static FLAG: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    *FLAG.get_or_init(|| env_flag_is_one("MLX_UNRETAINED_REFS"))
}

/// `HF2Q_AUTO_BARRIER=1`: let tracked dispatches insert barriers automatically.
/// Read once per process.
fn auto_barrier_enabled() -> bool {
    static FLAG: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    *FLAG.get_or_init(|| env_flag_is_one("HF2Q_AUTO_BARRIER"))
}
// Tallies kept by `CommandEncoder::maybe_auto_barrier`: dispatches that
// required a barrier vs. dispatches allowed to run concurrently.
static AUTO_BARRIER_COUNT: AtomicU64 = AtomicU64::new(0);
static AUTO_BARRIER_CONCURRENT: AtomicU64 = AtomicU64::new(0);
// Capacity of the per-command-buffer timestamp sample buffer; each sampled
// dispatch uses two slots (before and after).
const MAX_SAMPLES_PER_CB: u64 = 4096;
// Warn-once latch for the dispatch-profiling warnings (0 = not yet printed).
static TIMESTAMP_SET_WARN_LOGGED: AtomicU64 = AtomicU64::new(0);

/// Metadata for one sampled dispatch, matched with its timestamp pair when
/// the samples are resolved.
#[derive(Clone, Debug)]
struct PendingDispatchMeta {
    op_kind: &'static str,
    dispatch_index: u32,
}

/// Number of tracked dispatches for which the auto-barrier tracker requested a barrier.
pub fn auto_barrier_count() -> u64 {
    AtomicU64::load(&AUTO_BARRIER_COUNT, Ordering::Relaxed)
}

/// Number of tracked dispatches the auto-barrier tracker let run without a barrier.
pub fn auto_barrier_concurrent_count() -> u64 {
    AtomicU64::load(&AUTO_BARRIER_CONCURRENT, Ordering::Relaxed)
}
/// Sends `memoryBarrierWithScope:` with `MTLBarrierScopeBuffers` to `encoder`
/// through a raw Objective-C message send.
///
/// NOTE(review): `#[inline(never)]` presumably keeps this a distinct frame so
/// barrier cost is visible when profiling; confirm intent.
#[inline(never)]
fn issue_metal_buffer_barrier(encoder: &ComputeCommandEncoderRef) {
    // `MTLBarrierScopeBuffers` (1 << 0) in Metal's `MTLBarrierScope` option set.
    const MTL_BARRIER_SCOPE_BUFFERS: u64 = 1;
    unsafe {
        let _: () =
            objc::msg_send![encoder, memoryBarrierWithScope: MTL_BARRIER_SCOPE_BUFFERS];
    }
}
/// Wraps one Metal command buffer and lazily opens a single compute encoder on
/// it, which is reused for every dispatch until commit (or drop).
///
/// Optionally records dispatches instead of encoding them (capture mode),
/// tracks buffer ranges for automatic barriers, and samples per-dispatch GPU
/// timestamps when dispatch profiling is enabled.
pub struct CommandEncoder {
    /// Command buffer that all work is encoded into.
    cmd_buf: CommandBuffer,
    /// Currently open compute encoder, or null when none is open.
    active_encoder: *const ComputeCommandEncoderRef,
    /// `Some` while capturing: dispatches/barriers are recorded here instead.
    capture: Option<Vec<CapturedNode>>,
    /// Op kind attached to the next dispatch (reset to `Other` after use).
    pending_op_kind: CapturedOpKind,
    /// Read ranges attached to the next captured dispatch.
    pending_reads: Vec<MemRange>,
    /// Write ranges attached to the next captured dispatch.
    pending_writes: Vec<MemRange>,
    /// Optional residency set whose pending changes are flushed before commit.
    residency_set: Option<ResidencySet>,
    /// Range tracker used by the automatic-barrier logic.
    mem_ranges: MemRanges,
    /// Timestamp sample buffer for per-dispatch profiling, created lazily.
    sample_buffer: Option<CounterSampleBuffer>,
    /// Metadata for each dispatch sampled in this command buffer.
    pending_dispatch_meta: Vec<PendingDispatchMeta>,
    /// Index of the next dispatch to be sampled in this command buffer.
    dispatch_in_cb: u32,
    /// Most recent label applied via `apply_labels`.
    last_label: String,
}
// NOTE(review): the raw `active_encoder` pointer makes this type `!Send` by
// default; this impl asserts it may move between threads. Soundness assumes the
// wrapped Metal objects tolerate being used from another thread and that the
// encoder is never used from two threads at once — confirm.
unsafe impl Send for CommandEncoder {}
impl CommandEncoder {
    /// Creates an encoder on a new command buffer from `queue`, without a
    /// residency set.
    #[allow(dead_code)]
    pub(crate) fn new(queue: &CommandQueue) -> Result<Self> {
        Self::new_with_residency(queue, None)
    }

    /// Creates an encoder on a new command buffer from `queue`.
    ///
    /// With `MLX_UNRETAINED_REFS=1` the command buffer is created with
    /// unretained references. Metal then does not retain the bound resources,
    /// so they must outlive GPU execution. When `residency_set` is given, its
    /// pending changes are flushed before every commit.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub(crate) fn new_with_residency(
        queue: &CommandQueue,
        residency_set: Option<ResidencySet>,
    ) -> Result<Self> {
        let cmd_buf = if unretained_refs_enabled() {
            queue.new_command_buffer_with_unretained_references().to_owned()
        } else {
            queue.new_command_buffer().to_owned()
        };
        CMD_BUF_COUNT.fetch_add(1, Ordering::Relaxed);
        Ok(Self {
            cmd_buf,
            active_encoder: std::ptr::null(),
            capture: None,
            pending_op_kind: CapturedOpKind::Other,
            pending_reads: Vec::new(),
            pending_writes: Vec::new(),
            residency_set,
            mem_ranges: MemRanges::new(),
            sample_buffer: None,
            pending_dispatch_meta: Vec::new(),
            dispatch_in_cb: 0,
            last_label: String::new(),
        })
    }

    /// Enters capture mode: subsequent dispatches and barriers are recorded as
    /// [`CapturedNode`]s instead of being encoded.
    ///
    /// Any previously recorded capture that was not taken is discarded.
    pub fn start_capture(&mut self) {
        self.capture = Some(Vec::with_capacity(128));
    }

    /// Returns `true` while capture mode is active.
    pub fn is_capturing(&self) -> bool {
        self.capture.is_some()
    }

    /// Leaves capture mode and returns the recorded nodes (`None` if not capturing).
    pub fn take_capture(&mut self) -> Option<Vec<CapturedNode>> {
        self.capture.take()
    }

    /// Sets the op kind attached to the next dispatch.
    pub fn set_op_kind(&mut self, kind: CapturedOpKind) {
        self.pending_op_kind = kind;
    }

    /// Returns the pending op kind and resets it to `CapturedOpKind::Other`.
    fn take_pending_op_kind(&mut self) -> CapturedOpKind {
        std::mem::replace(&mut self.pending_op_kind, CapturedOpKind::Other)
    }

    /// Sets the read/write ranges attached to the next dispatch.
    ///
    /// The next dispatch always consumes and clears them. They are stored on
    /// the node when capturing and discarded otherwise.
    pub fn set_pending_buffer_ranges(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
        self.pending_reads = reads;
        self.pending_writes = writes;
    }

    /// Fills in the read and/or write ranges of the most recently captured
    /// dispatch.
    ///
    /// Each list is replaced only if it is currently empty and the replacement
    /// is non-empty. This is a no-op when not capturing or when the last node
    /// is a barrier.
    pub fn annotate_last_dispatch_if_missing(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
        if let Some(ref mut nodes) = self.capture {
            if let Some(CapturedNode::Dispatch { reads: r, writes: w, .. }) = nodes.last_mut() {
                if r.is_empty() && !reads.is_empty() {
                    *r = reads;
                }
                if w.is_empty() && !writes.is_empty() {
                    *w = writes;
                }
            }
        }
    }

    /// Takes (and clears) the pending read/write ranges.
    fn take_pending_buffer_ranges(&mut self) -> (Vec<MemRange>, Vec<MemRange>) {
        let reads = std::mem::take(&mut self.pending_reads);
        let writes = std::mem::take(&mut self.pending_writes);
        (reads, writes)
    }

    /// Converts `(index, buffer)` pairs into owned bindings at each buffer's
    /// own byte offset.
    fn record_buffer_bindings(buffers: &[(u64, &MlxBuffer)]) -> Vec<(u64, RecordedBinding)> {
        buffers
            .iter()
            .map(|&(index, buf)| {
                (
                    index,
                    RecordedBinding::Buffer {
                        metal_buffer: buf.metal_buffer().clone(),
                        offset: buf.byte_offset(),
                    },
                )
            })
            .collect()
    }

    /// Converts borrowed `KernelArg` bindings into owned [`RecordedBinding`]s;
    /// inline bytes are copied.
    fn record_arg_bindings(bindings: &[(u64, KernelArg<'_>)]) -> Vec<(u64, RecordedBinding)> {
        bindings
            .iter()
            .map(|(index, arg)| {
                let recorded = match arg {
                    KernelArg::Buffer(buf) => RecordedBinding::Buffer {
                        metal_buffer: buf.metal_buffer().clone(),
                        offset: buf.byte_offset(),
                    },
                    KernelArg::BufferWithOffset(buf, offset) => RecordedBinding::Buffer {
                        metal_buffer: buf.metal_buffer().clone(),
                        offset: *offset,
                    },
                    KernelArg::Bytes(bytes) => RecordedBinding::Bytes(bytes.to_vec()),
                };
                (*index, recorded)
            })
            .collect()
    }

    /// Returns the open compute encoder, creating one on `cmd_buf` if needed.
    ///
    /// The encoder uses concurrent dispatch unless
    /// `HF2Q_FORCE_SERIAL_DISPATCH=1`, which is read once per process like the
    /// other environment flags (previously it was re-read on every encoder
    /// creation).
    #[inline]
    fn get_or_create_encoder(&mut self) -> &ComputeCommandEncoderRef {
        // Cached `HF2Q_FORCE_SERIAL_DISPATCH=1` check.
        fn force_serial_dispatch() -> bool {
            use std::sync::OnceLock;
            static FLAG: OnceLock<bool> = OnceLock::new();
            *FLAG.get_or_init(|| {
                std::env::var("HF2Q_FORCE_SERIAL_DISPATCH")
                    .map(|v| v == "1")
                    .unwrap_or(false)
            })
        }

        if self.active_encoder.is_null() {
            let dispatch_type = if force_serial_dispatch() {
                MTLDispatchType::Serial
            } else {
                MTLDispatchType::Concurrent
            };
            let encoder = self
                .cmd_buf
                .compute_command_encoder_with_dispatch_type(dispatch_type);
            self.active_encoder = encoder as *const ComputeCommandEncoderRef;
        }
        // SAFETY: non-null here, and `active_encoder` is only ever set from an
        // encoder created on `self.cmd_buf` and cleared by `end_active_encoder`.
        // NOTE(review): validity also relies on the command buffer keeping the
        // encoder object alive until `end_encoding`; confirm metal-rs semantics.
        unsafe { &*self.active_encoder }
    }

    /// Ends the open compute encoder, if any, and clears the pointer.
    #[inline]
    fn end_active_encoder(&mut self) {
        if !self.active_encoder.is_null() {
            // SAFETY: see `get_or_create_encoder`; the pointer is cleared right
            // after, so the encoder is never used again.
            unsafe { &*self.active_encoder }.end_encoding();
            self.active_encoder = std::ptr::null();
        }
    }

    /// Orders buffer memory accesses between dispatches.
    ///
    /// When capturing, this records a `CapturedNode::Barrier`. Otherwise it
    /// issues a buffer-scope barrier on the open encoder, or does nothing when
    /// no encoder is open. With `MLX_PROFILE_BARRIERS=1`, the CPU time spent
    /// issuing the barrier is added to `barrier_total_ns()`.
    // NOTE(review): no `cfg` is visible in this method; this allow looks vestigial.
    #[allow(unexpected_cfgs)]
    pub fn memory_barrier(&mut self) {
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Barrier);
            return;
        }
        if self.active_encoder.is_null() {
            return;
        }
        BARRIER_COUNT.fetch_add(1, Ordering::Relaxed);
        // SAFETY: non-null; see `get_or_create_encoder`.
        let encoder = unsafe { &*self.active_encoder };
        if barrier_profile_enabled() {
            let start = std::time::Instant::now();
            issue_metal_buffer_barrier(encoder);
            let elapsed_ns = start.elapsed().as_nanos() as u64;
            BARRIER_NS.fetch_add(elapsed_ns, Ordering::Relaxed);
        } else {
            issue_metal_buffer_barrier(encoder);
        }
    }

    /// Opens the compute encoder if needed and sets `pipeline` on it.
    ///
    /// This is not recorded in capture mode, and each `encode*` method sets its
    /// own pipeline anyway.
    pub fn set_pipeline(&mut self, pipeline: &ComputePipelineStateRef) {
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
    }

    /// No-op: both arguments are ignored.
    ///
    /// NOTE(review): presumably kept for API compatibility; bind buffers through
    /// the `encode*` methods instead.
    pub fn set_buffer(&self, index: u64, buffer: &MlxBuffer) {
        let _ = (index, buffer);
    }

    /// No-op: both arguments are ignored and nothing is dispatched.
    ///
    /// NOTE(review): presumably kept for API compatibility; use `encode` instead.
    pub fn dispatch_threads(&self, grid_size: MTLSize, threadgroup_size: MTLSize) {
        let _ = (grid_size, threadgroup_size);
    }

    /// Binds each `(index, buffer)` pair at the buffer's own byte offset.
    #[inline]
    fn bind_buffers(encoder: &ComputeCommandEncoderRef, buffers: &[(u64, &MlxBuffer)]) {
        for &(index, buf) in buffers {
            encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
        }
    }

    /// Sets the threadgroup memory length for each `(index, byte_length)` pair.
    #[inline]
    fn bind_threadgroup_memory(encoder: &ComputeCommandEncoderRef, threadgroup_mem: &[(u64, u64)]) {
        for &(index, byte_length) in threadgroup_mem {
            encoder.set_threadgroup_memory_length(index, byte_length);
        }
    }

    /// Issues the dispatch call that matches `dispatch_kind`.
    #[inline]
    fn issue_dispatch(
        encoder: &ComputeCommandEncoderRef,
        dispatch_kind: DispatchKind,
        threads_per_grid: MTLSize,
        threads_per_threadgroup: MTLSize,
    ) {
        match dispatch_kind {
            DispatchKind::Threads => {
                encoder.dispatch_threads(threads_per_grid, threads_per_threadgroup)
            }
            DispatchKind::ThreadGroups => {
                encoder.dispatch_thread_groups(threads_per_grid, threads_per_threadgroup)
            }
        }
    }

    /// Common prologue of every `encode*` method.
    ///
    /// Counts the dispatch and consumes the pending op kind and buffer ranges.
    /// When capturing, it records the dispatch (building bindings via
    /// `record_bindings`) and returns `None`. Otherwise it returns
    /// `Some(op_kind)` so the caller encodes the dispatch live.
    fn begin_dispatch(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        record_bindings: impl FnOnce() -> Vec<(u64, RecordedBinding)>,
        threadgroup_memory: &[(u64, u64)],
        threads_per_grid: MTLSize,
        threads_per_threadgroup: MTLSize,
        dispatch_kind: DispatchKind,
    ) -> Option<CapturedOpKind> {
        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
        let op_kind = self.take_pending_op_kind();
        let (reads, writes) = self.take_pending_buffer_ranges();
        if let Some(nodes) = self.capture.as_mut() {
            nodes.push(CapturedNode::Dispatch {
                pipeline: pipeline.to_owned(),
                bindings: record_bindings(),
                threads_per_grid,
                threads_per_threadgroup,
                threadgroup_memory: threadgroup_memory.to_vec(),
                dispatch_kind,
                op_kind,
                reads,
                writes,
            });
            return None;
        }
        Some(op_kind)
    }

    /// Encodes one dispatch on the open encoder.
    ///
    /// Sequence: ensure the profiling sample buffer, set the pipeline, run
    /// `bind`, take the pre-dispatch timestamp, dispatch, then take the
    /// post-dispatch timestamp.
    fn encode_live(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        op_kind: CapturedOpKind,
        bind: impl FnOnce(&ComputeCommandEncoderRef),
        dispatch_kind: DispatchKind,
        threads_per_grid: MTLSize,
        threads_per_threadgroup: MTLSize,
    ) {
        self.ensure_sample_buffer();
        // Go through a raw pointer so the encoder reference does not hold a
        // borrow of `self` across the `&mut self` sampling calls below.
        let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
        // SAFETY: freshly obtained from `get_or_create_encoder`; nothing below
        // ends the encoder.
        let encoder = unsafe { &*encoder_ptr };
        encoder.set_compute_pipeline_state(pipeline);
        bind(encoder);
        let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
        Self::issue_dispatch(encoder, dispatch_kind, threads_per_grid, threads_per_threadgroup);
        self.sample_dispatch_post(encoder, pre_idx);
    }

    /// Dispatches `grid_size` threads with `buffers` bound at their own offsets
    /// (recorded instead when capturing).
    pub fn encode(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        buffers: &[(u64, &MlxBuffer)],
        grid_size: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        let Some(op_kind) = self.begin_dispatch(
            pipeline,
            || Self::record_buffer_bindings(buffers),
            &[],
            grid_size,
            threadgroup_size,
            DispatchKind::Threads,
        ) else {
            return;
        };
        self.encode_live(
            pipeline,
            op_kind,
            |encoder| Self::bind_buffers(encoder, buffers),
            DispatchKind::Threads,
            grid_size,
            threadgroup_size,
        );
    }

    /// Dispatches `threadgroups` threadgroups with `buffers` bound at their own
    /// offsets (recorded instead when capturing).
    pub fn encode_threadgroups(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        buffers: &[(u64, &MlxBuffer)],
        threadgroups: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        let Some(op_kind) = self.begin_dispatch(
            pipeline,
            || Self::record_buffer_bindings(buffers),
            &[],
            threadgroups,
            threadgroup_size,
            DispatchKind::ThreadGroups,
        ) else {
            return;
        };
        self.encode_live(
            pipeline,
            op_kind,
            |encoder| Self::bind_buffers(encoder, buffers),
            DispatchKind::ThreadGroups,
            threadgroups,
            threadgroup_size,
        );
    }

    /// Like [`Self::encode_threadgroups`], additionally setting threadgroup
    /// memory lengths from `(index, byte_length)` pairs.
    pub fn encode_threadgroups_with_shared(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        buffers: &[(u64, &MlxBuffer)],
        threadgroup_mem: &[(u64, u64)],
        threadgroups: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        let Some(op_kind) = self.begin_dispatch(
            pipeline,
            || Self::record_buffer_bindings(buffers),
            threadgroup_mem,
            threadgroups,
            threadgroup_size,
            DispatchKind::ThreadGroups,
        ) else {
            return;
        };
        self.encode_live(
            pipeline,
            op_kind,
            |encoder| {
                Self::bind_buffers(encoder, buffers);
                Self::bind_threadgroup_memory(encoder, threadgroup_mem);
            },
            DispatchKind::ThreadGroups,
            threadgroups,
            threadgroup_size,
        );
    }

    /// Dispatches `grid_size` threads with [`KernelArg`] bindings (recorded
    /// instead when capturing).
    pub fn encode_with_args(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        bindings: &[(u64, KernelArg<'_>)],
        grid_size: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        let Some(op_kind) = self.begin_dispatch(
            pipeline,
            || Self::record_arg_bindings(bindings),
            &[],
            grid_size,
            threadgroup_size,
            DispatchKind::Threads,
        ) else {
            return;
        };
        self.encode_live(
            pipeline,
            op_kind,
            |encoder| apply_bindings(encoder, bindings),
            DispatchKind::Threads,
            grid_size,
            threadgroup_size,
        );
    }

    /// Dispatches `threadgroups` threadgroups with [`KernelArg`] bindings
    /// (recorded instead when capturing).
    pub fn encode_threadgroups_with_args(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        bindings: &[(u64, KernelArg<'_>)],
        threadgroups: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        let Some(op_kind) = self.begin_dispatch(
            pipeline,
            || Self::record_arg_bindings(bindings),
            &[],
            threadgroups,
            threadgroup_size,
            DispatchKind::ThreadGroups,
        ) else {
            return;
        };
        self.encode_live(
            pipeline,
            op_kind,
            |encoder| apply_bindings(encoder, bindings),
            DispatchKind::ThreadGroups,
            threadgroups,
            threadgroup_size,
        );
    }

    /// Like [`Self::encode_threadgroups_with_args`], additionally setting
    /// threadgroup memory lengths from `(index, byte_length)` pairs.
    pub fn encode_threadgroups_with_args_and_shared(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        bindings: &[(u64, KernelArg<'_>)],
        threadgroup_mem: &[(u64, u64)],
        threadgroups: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        let Some(op_kind) = self.begin_dispatch(
            pipeline,
            || Self::record_arg_bindings(bindings),
            threadgroup_mem,
            threadgroups,
            threadgroup_size,
            DispatchKind::ThreadGroups,
        ) else {
            return;
        };
        self.encode_live(
            pipeline,
            op_kind,
            |encoder| {
                apply_bindings(encoder, bindings);
                Self::bind_threadgroup_memory(encoder, threadgroup_mem);
            },
            DispatchKind::ThreadGroups,
            threadgroups,
            threadgroup_size,
        );
    }

    /// Prologue shared by the `dispatch_tracked_*` methods.
    ///
    /// When capturing, it attaches the address ranges of `reads`/`writes` to
    /// the next captured dispatch. Otherwise, with `HF2Q_AUTO_BARRIER=1`, it
    /// lets the range tracker decide whether a barrier is needed first.
    fn prepare_tracked_dispatch(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
        if self.is_capturing() {
            let read_ranges = ranges_from_buffers(reads);
            let write_ranges = ranges_from_buffers(writes);
            self.set_pending_buffer_ranges(read_ranges, write_ranges);
        } else if auto_barrier_enabled() {
            self.maybe_auto_barrier(reads, writes);
        }
    }

    /// [`Self::encode_threadgroups_with_args`] with read/write tracking.
    pub fn dispatch_tracked_threadgroups_with_args(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        bindings: &[(u64, KernelArg<'_>)],
        reads: &[&MlxBuffer],
        writes: &[&MlxBuffer],
        threadgroups: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        self.prepare_tracked_dispatch(reads, writes);
        self.encode_threadgroups_with_args(pipeline, bindings, threadgroups, threadgroup_size);
    }

    /// [`Self::encode_threadgroups_with_args_and_shared`] with read/write tracking.
    #[allow(clippy::too_many_arguments)]
    pub fn dispatch_tracked_threadgroups_with_args_and_shared(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        bindings: &[(u64, KernelArg<'_>)],
        threadgroup_mem: &[(u64, u64)],
        reads: &[&MlxBuffer],
        writes: &[&MlxBuffer],
        threadgroups: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        self.prepare_tracked_dispatch(reads, writes);
        self.encode_threadgroups_with_args_and_shared(
            pipeline,
            bindings,
            threadgroup_mem,
            threadgroups,
            threadgroup_size,
        );
    }

    /// [`Self::encode_threadgroups`] with read/write tracking.
    pub fn dispatch_tracked_threadgroups(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        buffers: &[(u64, &MlxBuffer)],
        reads: &[&MlxBuffer],
        writes: &[&MlxBuffer],
        threadgroups: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        self.prepare_tracked_dispatch(reads, writes);
        self.encode_threadgroups(pipeline, buffers, threadgroups, threadgroup_size);
    }

    /// [`Self::encode_threadgroups_with_shared`] with read/write tracking.
    #[allow(clippy::too_many_arguments)]
    pub fn dispatch_tracked_threadgroups_with_shared(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        buffers: &[(u64, &MlxBuffer)],
        threadgroup_mem: &[(u64, u64)],
        reads: &[&MlxBuffer],
        writes: &[&MlxBuffer],
        threadgroups: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        self.prepare_tracked_dispatch(reads, writes);
        self.encode_threadgroups_with_shared(
            pipeline,
            buffers,
            threadgroup_mem,
            threadgroups,
            threadgroup_size,
        );
    }

    /// [`Self::encode_with_args`] with read/write tracking.
    pub fn dispatch_tracked_threads_with_args(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        bindings: &[(u64, KernelArg<'_>)],
        reads: &[&MlxBuffer],
        writes: &[&MlxBuffer],
        grid_size: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        self.prepare_tracked_dispatch(reads, writes);
        self.encode_with_args(pipeline, bindings, grid_size, threadgroup_size);
    }

    /// Adds the dispatch to the range tracker when `check_dispatch` accepts it.
    ///
    /// Otherwise it issues a memory barrier, resets the tracker and starts a
    /// new group with this dispatch. Each outcome bumps its own counter.
    fn maybe_auto_barrier(
        &mut self,
        reads: &[&MlxBuffer],
        writes: &[&MlxBuffer],
    ) {
        if self.mem_ranges.check_dispatch(reads, writes) {
            self.mem_ranges.add_dispatch(reads, writes);
            AUTO_BARRIER_CONCURRENT.fetch_add(1, Ordering::Relaxed);
        } else {
            self.memory_barrier();
            self.mem_ranges.reset();
            self.mem_ranges.add_dispatch(reads, writes);
            AUTO_BARRIER_COUNT.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Issues a memory barrier unconditionally and, with `HF2Q_AUTO_BARRIER=1`,
    /// resets the range tracker.
    pub fn force_barrier_and_reset_tracker(&mut self) {
        self.memory_barrier();
        if auto_barrier_enabled() {
            self.mem_ranges.reset();
        }
    }

    /// Returns `self.mem_ranges.len()`, i.e. the size of the auto-barrier tracker.
    #[inline]
    pub fn mem_ranges_len(&self) -> usize {
        self.mem_ranges.len()
    }

    /// Re-encodes a previously captured dispatch on the live encoder.
    ///
    /// This is always encoded live, even while capturing, and is not counted
    /// in `dispatch_count()`. It consumes the pending op kind, which is used
    /// for profiling samples.
    pub fn replay_dispatch(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        bindings: &[(u64, RecordedBinding)],
        threadgroup_memory: &[(u64, u64)],
        threads_per_grid: MTLSize,
        threads_per_threadgroup: MTLSize,
        dispatch_kind: DispatchKind,
    ) {
        let op_kind = self.take_pending_op_kind();
        self.encode_live(
            pipeline,
            op_kind,
            |encoder| {
                for (index, binding) in bindings {
                    match binding {
                        RecordedBinding::Buffer { metal_buffer, offset } => {
                            encoder.set_buffer(*index, Some(metal_buffer), *offset);
                        }
                        RecordedBinding::Bytes(bytes) => {
                            encoder.set_bytes(
                                *index,
                                bytes.len() as u64,
                                bytes.as_ptr() as *const _,
                            );
                        }
                    }
                }
                Self::bind_threadgroup_memory(encoder, threadgroup_memory);
            },
            dispatch_kind,
            threads_per_grid,
            threads_per_threadgroup,
        );
    }

    /// Flushes pending residency-set changes, if a residency set is attached.
    #[inline]
    fn flush_residency_pending(&self) {
        if let Some(set) = self.residency_set.as_ref() {
            set.flush_pending();
        }
    }

    /// Prints a dispatch-profiling warning at most once per process; `message`
    /// is only rendered when it will actually be printed.
    fn warn_dispatch_profiling_once(message: impl FnOnce() -> String) {
        if TIMESTAMP_SET_WARN_LOGGED
            .compare_exchange(0, 1, Ordering::Relaxed, Ordering::Relaxed)
            .is_ok()
        {
            eprintln!("{}", message());
        }
    }

    /// Lazily creates the timestamp sample buffer when dispatch profiling is
    /// enabled.
    ///
    /// If the device cannot sample at dispatch boundaries, exposes no
    /// timestamp counter set, or fails to allocate the buffer, no sample buffer
    /// is created. In that case a warning is printed (once per process) and the
    /// check is repeated on the next dispatch.
    #[inline]
    fn ensure_sample_buffer(&mut self) {
        if !crate::kernel_profile::is_dispatch_enabled() || self.sample_buffer.is_some() {
            return;
        }
        // SAFETY: `-[MTLCommandBuffer device]` takes no arguments and returns
        // the device object the command buffer was created from.
        let device: &metal::DeviceRef = unsafe {
            let cb = &*self.cmd_buf;
            msg_send![cb, device]
        };
        if !device.supports_counter_sampling(MTLCounterSamplingPoint::AtDispatchBoundary) {
            Self::warn_dispatch_profiling_once(|| {
                format!(
                    "[mlx-native] MLX_PROFILE_DISPATCH=1 ignored: \
                     device {:?} does NOT support \
                     MTLCounterSamplingPointAtDispatchBoundary \
                     (Apple Silicon limitation; only AtStageBoundary \
                     is supported, which is incompatible with the \
                     persistent compute-encoder pattern). \
                     MLX_PROFILE_CB=1 still produces per-CB GPU times.",
                    device.name()
                )
            });
            return;
        }
        let counter_sets = device.counter_sets();
        let timestamp_set = counter_sets
            .iter()
            .find(|c: &&metal::CounterSet| c.name().eq_ignore_ascii_case("timestamp"));
        let Some(timestamp_set) = timestamp_set else {
            Self::warn_dispatch_profiling_once(|| {
                format!(
                    "[mlx-native] MLX_PROFILE_DISPATCH=1 ignored: \
                     device {:?} exposes no MTLCommonCounterSetTimestamp",
                    device.name()
                )
            });
            return;
        };
        let descriptor = CounterSampleBufferDescriptor::new();
        descriptor.set_counter_set(timestamp_set);
        descriptor.set_storage_mode(MTLStorageMode::Shared);
        descriptor.set_label("mlx_native.dispatch_samples");
        descriptor.set_sample_count(MAX_SAMPLES_PER_CB);
        match device.new_counter_sample_buffer_with_descriptor(&descriptor) {
            Ok(buf) => self.sample_buffer = Some(buf),
            // `sample_buffer` is already `None` (checked at the top).
            Err(e) => Self::warn_dispatch_profiling_once(|| {
                format!(
                    "[mlx-native] MLX_PROFILE_DISPATCH=1 ignored: \
                     newCounterSampleBufferWithDescriptor failed: {}",
                    e
                )
            }),
        }
    }

    /// Samples the counters at slot `2 * dispatch_in_cb` before a dispatch.
    ///
    /// It also records the dispatch's metadata and returns its index. Returns
    /// `None` (and takes no sample) when there is no sample buffer or the
    /// buffer is full.
    #[inline]
    fn sample_dispatch_pre(
        &mut self,
        encoder: &ComputeCommandEncoderRef,
        op_kind: CapturedOpKind,
    ) -> Option<u32> {
        let sb = self.sample_buffer.as_ref()?;
        let i = self.dispatch_in_cb;
        let pre_idx = (i as u64).checked_mul(2)?;
        if pre_idx >= MAX_SAMPLES_PER_CB {
            return None;
        }
        encoder.sample_counters_in_buffer(sb, pre_idx, true);
        self.pending_dispatch_meta.push(PendingDispatchMeta {
            op_kind: op_kind.name(),
            dispatch_index: i,
        });
        Some(i)
    }

    /// Samples the counters at slot `2 * i + 1` after the dispatch sampled by
    /// `sample_dispatch_pre`, then advances `dispatch_in_cb`. No-op for `None`.
    #[inline]
    fn sample_dispatch_post(
        &mut self,
        encoder: &ComputeCommandEncoderRef,
        pre_idx: Option<u32>,
    ) {
        let i = match pre_idx {
            Some(v) => v,
            None => return,
        };
        let sb = match self.sample_buffer.as_ref() {
            Some(b) => b,
            None => return,
        };
        let post_idx = match (i as u64).checked_mul(2).and_then(|v| v.checked_add(1)) {
            Some(v) if v < MAX_SAMPLES_PER_CB => v,
            _ => return,
        };
        encoder.sample_counters_in_buffer(sb, post_idx, true);
        self.dispatch_in_cb = i.saturating_add(1);
    }

    /// Consumes the sample buffer and records one
    /// `kernel_profile::DispatchEntry` per sampled dispatch, labelled
    /// `cb_label`. Per-CB sampling state is reset afterwards.
    ///
    /// Called after the command buffer has completed (see
    /// `commit_and_wait_labeled`). Sample pairs holding Metal's
    /// `MTLCounterErrorValue` sentinel are skipped rather than converted into a
    /// bogus duration.
    fn resolve_dispatch_samples(&mut self, cb_label: &str) -> Result<()> {
        // Value Metal writes for a counter sample it could not take.
        const MTL_COUNTER_ERROR_VALUE: u64 = u64::MAX;

        let sb = match self.sample_buffer.take() {
            Some(b) => b,
            None => {
                self.pending_dispatch_meta.clear();
                self.dispatch_in_cb = 0;
                return Ok(());
            }
        };
        let n = self.pending_dispatch_meta.len();
        if n == 0 {
            self.dispatch_in_cb = 0;
            return Ok(());
        }
        let mut cpu_t: u64 = 0;
        let mut gpu_t: u64 = 0;
        // SAFETY: see `ensure_sample_buffer`.
        let device: &metal::DeviceRef = unsafe {
            let cb = &*self.cmd_buf;
            msg_send![cb, device]
        };
        device.sample_timestamps(&mut cpu_t, &mut gpu_t);
        crate::kernel_profile::record_clock_pair(cpu_t, gpu_t);
        let length = (n as u64).saturating_mul(2);
        let data = sb.resolve_counter_range(NSRange {
            location: 0,
            length,
        });
        for (i, meta) in self.pending_dispatch_meta.drain(..).enumerate() {
            let start_idx = 2 * i;
            let end_idx = 2 * i + 1;
            if end_idx >= data.len() {
                break;
            }
            let start_raw = data[start_idx] as u64;
            let end_raw = data[end_idx] as u64;
            if start_raw == MTL_COUNTER_ERROR_VALUE || end_raw == MTL_COUNTER_ERROR_VALUE {
                continue;
            }
            let start_ns = crate::kernel_profile::convert_gpu_ticks_to_ns(start_raw);
            let end_ns = crate::kernel_profile::convert_gpu_ticks_to_ns(end_raw);
            let gpu_ns = end_ns.saturating_sub(start_ns);
            crate::kernel_profile::record_dispatch(
                crate::kernel_profile::DispatchEntry {
                    cb_label: cb_label.to_string(),
                    op_kind: meta.op_kind,
                    dispatch_index: meta.dispatch_index,
                    gpu_ns,
                    start_gpu_ns: start_ns,
                    end_gpu_ns: end_ns,
                },
            );
        }
        drop(sb);
        self.dispatch_in_cb = 0;
        Ok(())
    }

    /// Commits the command buffer and blocks until it finishes.
    ///
    /// # Errors
    ///
    /// `MlxError::CommandBufferError` if the final status is not `Completed`.
    pub fn commit_and_wait(&mut self) -> Result<()> {
        SYNC_COUNT.fetch_add(1, Ordering::Relaxed);
        self.commit();
        self.wait_until_completed()
    }

    /// Labels, commits and waits.
    ///
    /// When per-CB profiling (`kernel_profile::is_enabled`) or dispatch
    /// profiling is on, the CB's GPU time is recorded under `label` (per-CB
    /// profiling) and the dispatch timestamp samples are resolved (dispatch
    /// profiling).
    ///
    /// # Errors
    ///
    /// Same as [`Self::commit_and_wait`].
    pub fn commit_and_wait_labeled(&mut self, label: &str) -> Result<()> {
        self.apply_labels(label);
        let need_gpu_time =
            crate::kernel_profile::is_enabled() || crate::kernel_profile::is_dispatch_enabled();
        if need_gpu_time {
            let (start_s, end_s) = self.commit_wait_with_gpu_time()?;
            let ns = ((end_s - start_s).max(0.0) * 1_000_000_000.0) as u64;
            if crate::kernel_profile::is_enabled() {
                crate::kernel_profile::record(label, ns);
            }
            if crate::kernel_profile::is_dispatch_enabled() {
                self.resolve_dispatch_samples(label)?;
            }
            Ok(())
        } else {
            self.commit_and_wait()
        }
    }

    /// Labels and commits the command buffer.
    ///
    /// Normally this does not wait. It blocks via
    /// [`Self::commit_and_wait_labeled`] when per-CB profiling is enabled, or
    /// when this command buffer holds dispatch timestamp samples. Without that
    /// wait those samples could never be resolved, and the sample buffer would
    /// be dropped while the GPU may still write to it. Errors on the blocking
    /// path are printed to stderr instead of being returned.
    pub fn commit_labeled(&mut self, label: &str) {
        let must_wait = crate::kernel_profile::is_enabled() || self.sample_buffer.is_some();
        if must_wait {
            if let Err(e) = self.commit_and_wait_labeled(label) {
                eprintln!("[mlx-native] commit_labeled({}) failed: {}", label, e);
            }
        } else {
            self.apply_labels(label);
            self.commit();
        }
    }

    /// Sets `label` on the command buffer and the open encoder (if any) and
    /// remembers it in `last_label`.
    ///
    /// An empty label trips a `debug_assert!` in debug builds and is ignored in
    /// release builds.
    #[inline]
    fn apply_labels(&mut self, label: &str) {
        debug_assert!(!label.is_empty(), "commit_*_labeled called with empty label");
        if label.is_empty() {
            return;
        }
        self.cmd_buf.set_label(label);
        if !self.active_encoder.is_null() {
            // SAFETY: non-null; see `get_or_create_encoder`.
            unsafe { &*self.active_encoder }.set_label(label);
        }
        self.last_label.clear();
        self.last_label.push_str(label);
    }

    /// Commits, waits, then returns the command buffer's
    /// `(GPUStartTime, GPUEndTime)` in seconds.
    ///
    /// # Errors
    ///
    /// Same as [`Self::commit_and_wait`].
    pub fn commit_wait_with_gpu_time(&mut self) -> Result<(f64, f64)> {
        self.commit_and_wait()?;
        // SAFETY: both selectors take no arguments and return a
        // `CFTimeInterval` (f64).
        let (gpu_start, gpu_end): (f64, f64) = unsafe {
            let cb = &*self.cmd_buf;
            let s: f64 = msg_send![cb, GPUStartTime];
            let e: f64 = msg_send![cb, GPUEndTime];
            (s, e)
        };
        Ok((gpu_start, gpu_end))
    }

    /// Ends the open encoder, flushes residency changes and commits without
    /// waiting.
    pub fn commit(&mut self) {
        self.end_active_encoder();
        self.flush_residency_pending();
        self.cmd_buf.commit();
    }

    /// Blocks until the command buffer finishes.
    ///
    /// # Errors
    ///
    /// `MlxError::CommandBufferError` if the final status is `Error` or any
    /// status other than `Completed`.
    pub fn wait_until_completed(&self) -> Result<()> {
        self.cmd_buf.wait_until_completed();
        match self.cmd_buf.status() {
            MTLCommandBufferStatus::Completed => Ok(()),
            MTLCommandBufferStatus::Error => Err(MlxError::CommandBufferError(
                "GPU command buffer completed with error status".into(),
            )),
            status => Err(MlxError::CommandBufferError(format!(
                "Unexpected command buffer status after wait: {:?}",
                status
            ))),
        }
    }

    /// Underlying Metal command buffer.
    #[inline]
    pub fn metal_command_buffer(&self) -> &CommandBuffer {
        &self.cmd_buf
    }
}
/// Ends any compute encoder still open on the command buffer, e.g. when the
/// `CommandEncoder` is dropped without having been committed.
impl Drop for CommandEncoder {
    fn drop(&mut self) {
        // `end_active_encoder` does nothing when no encoder is open (e.g. after a commit).
        self.end_active_encoder();
    }
}