use std::sync::atomic::{AtomicI8, AtomicU64, Ordering};
use metal::{
CommandBuffer, CommandQueue, ComputeCommandEncoderRef, ComputePipelineState,
ComputePipelineStateRef, CounterSampleBuffer, CounterSampleBufferDescriptor,
MTLCommandBufferStatus, MTLCounterSamplingPoint, MTLDispatchType, MTLSize, MTLStorageMode,
NSRange,
};
#[allow(unused_imports)]
use objc::{msg_send, sel, sel_impl};
use crate::buffer::MlxBuffer;
use crate::error::{MlxError, Result};
use crate::mem_ranges::MemRanges;
use crate::residency::ResidencySet;
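/// One kernel argument binding: a buffer bound at its own byte offset, a buffer
/// bound at an explicit offset, or an inline parameter blob passed via `setBytes`.
///
/// Illustrative sketch only; `encoder`, `pipeline`, `input`, `output`, `grid`,
/// `tg`, and `Params` are placeholders, not items defined in this module:
/// ```ignore
/// #[repr(C)]
/// #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
/// struct Params { n: u32, _pad: [u32; 3] }
///
/// let params = Params { n: 1024, _pad: [0; 3] };
/// let bindings = [
///     (0u64, KernelArg::Buffer(&input)),
///     (1u64, KernelArg::Buffer(&output)),
///     (2u64, KernelArg::Bytes(as_bytes(&params))),
/// ];
/// encoder.encode_with_args(&pipeline, &bindings, grid, tg);
/// ```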
pub enum KernelArg<'a> {
Buffer(&'a MlxBuffer),
BufferWithOffset(&'a MlxBuffer, u64),
Bytes(&'a [u8]),
}
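/// A pre-baked compute dispatch: pipeline, threadgroup counts and sizes,
/// threadgroup-memory lengths, constant parameter bytes plus their slot, and the
/// argument slots that runtime buffers are bound to by `dispatch_record`.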
#[derive(Clone)]
pub struct DispatchRecord {
pub pipeline: ComputePipelineState,
pub threadgroups: MTLSize,
pub threads_per_tg: MTLSize,
pub threadgroup_mem: Vec<(u64, u64)>,
pub params_bytes: Vec<u8>,
pub params_slot: u64,
pub buffer_slots: Vec<u64>,
pub op_kind: CapturedOpKind,
pub kernel_name: String,
}
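/// Reinterpret a `Pod` value as its raw bytes (e.g. for `KernelArg::Bytes` or
/// `params_bytes`).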
pub fn as_bytes<T: bytemuck::Pod>(val: &T) -> &[u8] {
bytemuck::bytes_of(val)
}
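/// Owned form of a kernel argument captured for later replay: a retained Metal
/// buffer plus offset, or a copied parameter blob.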
#[derive(Clone)]
pub enum RecordedBinding {
Buffer {
metal_buffer: metal::Buffer,
offset: u64,
},
Bytes(Vec<u8>),
}
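/// Whether a captured dispatch used `dispatch_threads` (grid size) or
/// `dispatch_thread_groups` (threadgroup count).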
#[derive(Clone, Copy, Debug)]
pub enum DispatchKind {
Threads,
ThreadGroups,
}
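/// Coarse classification of a dispatch, used for profiling labels and to decide
/// whether a captured node may be reordered (`is_reorderable`).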
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CapturedOpKind {
RmsNorm,
ElemMul,
ElemAdd,
Sdpa,
Softmax,
Other,
}
impl CapturedOpKind {
pub fn is_reorderable(&self) -> bool {
match self {
Self::Sdpa | Self::Softmax => false,
Self::RmsNorm | Self::ElemMul | Self::ElemAdd | Self::Other => true,
}
}
pub fn name(&self) -> &'static str {
match self {
Self::RmsNorm => "RmsNorm",
Self::ElemMul => "ElemMul",
Self::ElemAdd => "ElemAdd",
Self::Sdpa => "Sdpa",
Self::Softmax => "Softmax",
Self::Other => "Other",
}
}
}
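/// Half-open CPU address range `[start, end)` of a buffer region, used for
/// read/write hazard tracking.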
pub type MemRange = (usize, usize);
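/// One node of a captured command stream: a recorded compute dispatch (with its
/// bindings, sizes, and read/write ranges) or an explicit barrier.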
#[derive(Clone)]
pub enum CapturedNode {
Dispatch {
pipeline: ComputePipelineState,
bindings: Vec<(u64, RecordedBinding)>,
threads_per_grid: MTLSize,
threads_per_threadgroup: MTLSize,
threadgroup_memory: Vec<(u64, u64)>,
dispatch_kind: DispatchKind,
op_kind: CapturedOpKind,
reads: Vec<MemRange>,
writes: Vec<MemRange>,
},
Barrier,
}
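/// Address range of each buffer: `contents_ptr + byte_offset` for the remaining
/// `byte_len - byte_offset` bytes.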
fn ranges_from_buffers(bufs: &[&MlxBuffer]) -> Vec<MemRange> {
bufs.iter()
.map(|b| {
let base = b.contents_ptr() as usize + b.byte_offset() as usize;
let extent = (b.byte_len()).saturating_sub(b.byte_offset() as usize);
(base, base + extent)
})
.collect()
}
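/// Bind a slice of `KernelArg`s onto a live compute encoder.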
#[inline]
fn apply_bindings(encoder: &ComputeCommandEncoderRef, bindings: &[(u64, KernelArg<'_>)]) {
for &(index, ref arg) in bindings {
match arg {
KernelArg::Buffer(buf) => {
encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
}
KernelArg::BufferWithOffset(buf, offset) => {
encoder.set_buffer(index, Some(buf.metal_buffer()), *offset);
}
KernelArg::Bytes(bytes) => {
encoder.set_bytes(index, bytes.len() as u64, bytes.as_ptr() as *const _);
}
}
}
}
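// Process-wide instrumentation counters: CPU/GPU syncs, dispatches, command
// buffers, explicit barriers, and (with MLX_PROFILE_BARRIERS=1) nanoseconds
// spent issuing barriers.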
static SYNC_COUNT: AtomicU64 = AtomicU64::new(0);
static DISPATCH_COUNT: AtomicU64 = AtomicU64::new(0);
static CMD_BUF_COUNT: AtomicU64 = AtomicU64::new(0);
static BARRIER_COUNT: AtomicU64 = AtomicU64::new(0);
static BARRIER_NS: AtomicU64 = AtomicU64::new(0);
pub fn reset_counters() {
SYNC_COUNT.store(0, Ordering::Relaxed);
DISPATCH_COUNT.store(0, Ordering::Relaxed);
CMD_BUF_COUNT.store(0, Ordering::Relaxed);
BARRIER_COUNT.store(0, Ordering::Relaxed);
BARRIER_NS.store(0, Ordering::Relaxed);
AUTO_BARRIER_COUNT.store(0, Ordering::Relaxed);
AUTO_BARRIER_CONCURRENT.store(0, Ordering::Relaxed);
}
pub fn sync_count() -> u64 {
SYNC_COUNT.load(Ordering::Relaxed)
}
pub fn dispatch_count() -> u64 {
DISPATCH_COUNT.load(Ordering::Relaxed)
}
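/// Optional per-pipeline dispatch histogram keyed by pipeline label; populated
/// only when MLX_DISP_BUCKET=1.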
fn pipeline_buckets() -> &'static std::sync::Mutex<std::collections::HashMap<String, u64>> {
static BUCKETS: std::sync::OnceLock<
std::sync::Mutex<std::collections::HashMap<String, u64>>,
> = std::sync::OnceLock::new();
BUCKETS.get_or_init(|| std::sync::Mutex::new(std::collections::HashMap::new()))
}
fn pipeline_bucket_enabled() -> bool {
static CACHED: AtomicI8 = AtomicI8::new(-1);
let v = CACHED.load(Ordering::Relaxed);
if v >= 0 {
return v == 1;
}
let on = std::env::var("MLX_DISP_BUCKET").as_deref() == Ok("1");
CACHED.store(if on { 1 } else { 0 }, Ordering::Relaxed);
on
}
#[inline]
pub(crate) fn bucket_dispatch(pipeline: &ComputePipelineStateRef) {
if !pipeline_bucket_enabled() {
return;
}
let label = pipeline.label();
if label.is_empty() {
return;
}
if let Ok(mut t) = pipeline_buckets().lock() {
*t.entry(label.to_string()).or_insert(0) += 1;
}
}
pub fn pipeline_dispatch_buckets() -> Vec<(String, u64)> {
let mut v: Vec<(String, u64)> = if let Ok(t) = pipeline_buckets().lock() {
t.iter().map(|(k, v)| (k.clone(), *v)).collect()
} else {
Vec::new()
};
v.sort_by(|a, b| b.1.cmp(&a.1));
v
}
pub fn reset_pipeline_dispatch_buckets() {
if let Ok(mut t) = pipeline_buckets().lock() {
t.clear();
}
}
pub fn cmd_buf_count() -> u64 {
CMD_BUF_COUNT.load(Ordering::Relaxed)
}
pub fn barrier_count() -> u64 {
BARRIER_COUNT.load(Ordering::Relaxed)
}
pub fn barrier_total_ns() -> u64 {
BARRIER_NS.load(Ordering::Relaxed)
}
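// The flags below are read from the environment once and cached for the life of
// the process.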
fn barrier_profile_enabled() -> bool {
use std::sync::OnceLock;
static FLAG: OnceLock<bool> = OnceLock::new();
*FLAG.get_or_init(|| {
std::env::var("MLX_PROFILE_BARRIERS")
.map(|v| v == "1")
.unwrap_or(false)
})
}
fn unretained_refs_enabled() -> bool {
use std::sync::OnceLock;
static FLAG: OnceLock<bool> = OnceLock::new();
*FLAG.get_or_init(|| {
std::env::var("MLX_UNRETAINED_REFS")
.map(|v| v == "1")
.unwrap_or(false)
})
}
fn pipeline_tg_mult_hint_enabled() -> bool {
use std::sync::OnceLock;
static FLAG: OnceLock<bool> = OnceLock::new();
*FLAG.get_or_init(|| {
std::env::var("HF2Q_PIPELINE_TG_MULT_HINT")
.map(|v| v == "1")
.unwrap_or(false)
})
}
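/// With HF2Q_PIPELINE_TG_MULT_HINT=1, panic if the threadgroup size is not a
/// multiple of 32 (the thread-execution width the hint assumes).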
#[inline]
fn assert_tg_size_multiple_of_32_if_hinted(
tg: MTLSize,
pipeline: &ComputePipelineStateRef,
) {
if !pipeline_tg_mult_hint_enabled() {
return;
}
let total = tg.width.saturating_mul(tg.height).saturating_mul(tg.depth);
if total % 32 != 0 {
let label = pipeline.label();
panic!(
"ADR-029 Step 1q safety: HF2Q_PIPELINE_TG_MULT_HINT=1 requires \
threadgroup_size.x * y * z to be a multiple of 32 (Apple's \
threadExecutionWidth). Got tg=({}, {}, {}) → total={} → \
{} mod 32 = {}. Pipeline label: \"{}\". Either fix the \
dispatch site to use a multiple-of-32 threadgroup, or unset \
HF2Q_PIPELINE_TG_MULT_HINT.",
tg.width, tg.height, tg.depth, total, total, total % 32,
label
);
}
}
fn auto_barrier_enabled() -> bool {
use std::sync::OnceLock;
static FLAG: OnceLock<bool> = OnceLock::new();
*FLAG.get_or_init(|| {
std::env::var("HF2Q_AUTO_BARRIER")
.map(|v| v == "1")
.unwrap_or(false)
})
}
static AUTO_BARRIER_COUNT: AtomicU64 = AtomicU64::new(0);
static AUTO_BARRIER_CONCURRENT: AtomicU64 = AtomicU64::new(0);
const MAX_SAMPLES_PER_CB: u64 = 4096;
static TIMESTAMP_SET_WARN_LOGGED: AtomicU64 = AtomicU64::new(0);
#[derive(Clone, Debug)]
struct PendingDispatchMeta {
op_kind: &'static str,
dispatch_index: u32,
}
pub fn auto_barrier_count() -> u64 {
AUTO_BARRIER_COUNT.load(Ordering::Relaxed)
}
pub fn auto_barrier_concurrent_count() -> u64 {
AUTO_BARRIER_CONCURRENT.load(Ordering::Relaxed)
}
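/// Issue an `MTLBarrierScopeBuffers` memory barrier on the encoder via a raw
/// `memoryBarrierWithScope:` message (the scope constant's raw value is 1).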
#[inline(never)]
fn issue_metal_buffer_barrier(encoder: &ComputeCommandEncoderRef) {
const MTL_BARRIER_SCOPE_BUFFERS: u64 = 1;
unsafe {
let _: () =
objc::msg_send![encoder, memoryBarrierWithScope: MTL_BARRIER_SCOPE_BUFFERS];
}
}
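/// Wrapper around one Metal command buffer and a lazily opened compute encoder.
/// Supports capture/replay of dispatches, optional automatic hazard barriers,
/// residency-set flushing, and per-dispatch GPU timestamp sampling.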
pub struct CommandEncoder {
cmd_buf: CommandBuffer,
queue: CommandQueue,
active_encoder: *const ComputeCommandEncoderRef,
capture: Option<Vec<CapturedNode>>,
pending_op_kind: CapturedOpKind,
pending_reads: Vec<MemRange>,
pending_writes: Vec<MemRange>,
residency_set: Option<ResidencySet>,
mem_ranges: MemRanges,
sample_buffer: Option<CounterSampleBuffer>,
pending_dispatch_meta: Vec<PendingDispatchMeta>,
dispatch_in_cb: u32,
last_label: String,
}
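// SAFETY: Send is asserted on the assumption that a `CommandEncoder` is owned
// and driven by one thread at a time, so the raw `active_encoder` pointer (and
// the Objective-C objects behind the other fields) are never used concurrently.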
unsafe impl Send for CommandEncoder {}
impl CommandEncoder {
#[allow(dead_code)]
pub(crate) fn new(queue: &CommandQueue) -> Result<Self> {
Self::new_with_residency(queue, None)
}
pub(crate) fn new_with_residency(
queue: &CommandQueue,
residency_set: Option<ResidencySet>,
) -> Result<Self> {
let cmd_buf = if unretained_refs_enabled() {
queue.new_command_buffer_with_unretained_references().to_owned()
} else {
queue.new_command_buffer().to_owned()
};
CMD_BUF_COUNT.fetch_add(1, Ordering::Relaxed);
Ok(Self {
cmd_buf,
queue: queue.to_owned(),
active_encoder: std::ptr::null(),
capture: None,
pending_op_kind: CapturedOpKind::Other,
pending_reads: Vec::new(),
pending_writes: Vec::new(),
residency_set,
mem_ranges: MemRanges::new(),
sample_buffer: None,
pending_dispatch_meta: Vec::new(),
dispatch_in_cb: 0,
last_label: String::new(),
})
}
pub fn start_capture(&mut self) {
self.capture = Some(Vec::with_capacity(128));
}
pub fn is_capturing(&self) -> bool {
self.capture.is_some()
}
pub fn take_capture(&mut self) -> Option<Vec<CapturedNode>> {
self.capture.take()
}
pub fn set_op_kind(&mut self, kind: CapturedOpKind) {
self.pending_op_kind = kind;
}
fn take_pending_op_kind(&mut self) -> CapturedOpKind {
let kind = self.pending_op_kind;
self.pending_op_kind = CapturedOpKind::Other;
kind
}
pub fn set_pending_buffer_ranges(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
self.pending_reads = reads;
self.pending_writes = writes;
}
pub fn annotate_last_dispatch_if_missing(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
if let Some(ref mut nodes) = self.capture {
if let Some(CapturedNode::Dispatch { reads: r, writes: w, .. }) = nodes.last_mut() {
if r.is_empty() && !reads.is_empty() {
*r = reads;
}
if w.is_empty() && !writes.is_empty() {
*w = writes;
}
}
}
}
fn take_pending_buffer_ranges(&mut self) -> (Vec<MemRange>, Vec<MemRange>) {
let reads = std::mem::take(&mut self.pending_reads);
let writes = std::mem::take(&mut self.pending_writes);
(reads, writes)
}
fn record_buffer_bindings(buffers: &[(u64, &MlxBuffer)]) -> Vec<(u64, RecordedBinding)> {
buffers
.iter()
.map(|&(index, buf)| {
(
index,
RecordedBinding::Buffer {
metal_buffer: buf.metal_buffer().clone(),
offset: buf.byte_offset(),
},
)
})
.collect()
}
fn record_arg_bindings(bindings: &[(u64, KernelArg<'_>)]) -> Vec<(u64, RecordedBinding)> {
bindings
.iter()
.map(|(index, arg)| {
let recorded = match arg {
KernelArg::Buffer(buf) => RecordedBinding::Buffer {
metal_buffer: buf.metal_buffer().clone(),
offset: buf.byte_offset(),
},
KernelArg::BufferWithOffset(buf, offset) => RecordedBinding::Buffer {
metal_buffer: buf.metal_buffer().clone(),
offset: *offset,
},
KernelArg::Bytes(bytes) => RecordedBinding::Bytes(bytes.to_vec()),
};
(*index, recorded)
})
.collect()
}
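/// Lazily open a compute encoder on the current command buffer (concurrent
/// dispatch by default, serial when HF2Q_FORCE_SERIAL_DISPATCH=1) and cache it
/// as a raw pointer until `end_active_encoder` is called.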
#[inline]
fn get_or_create_encoder(&mut self) -> &ComputeCommandEncoderRef {
if self.active_encoder.is_null() {
let dispatch_type = if std::env::var("HF2Q_FORCE_SERIAL_DISPATCH")
.map(|v| v == "1")
.unwrap_or(false)
{
MTLDispatchType::Serial
} else {
MTLDispatchType::Concurrent
};
let encoder = self
.cmd_buf
.compute_command_encoder_with_dispatch_type(dispatch_type);
self.active_encoder = encoder as *const ComputeCommandEncoderRef;
}
unsafe { &*self.active_encoder }
}
#[inline]
fn end_active_encoder(&mut self) {
if !self.active_encoder.is_null() {
unsafe { &*self.active_encoder }.end_encoding();
self.active_encoder = std::ptr::null();
}
}
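/// While capturing, record a `Barrier` node; otherwise issue a buffers-scope
/// barrier on the open encoder (a no-op when no encoder is open), resetting the
/// auto-barrier tracker and timing the call when the respective flags are set.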
#[allow(unexpected_cfgs)]
pub fn memory_barrier(&mut self) {
if let Some(ref mut nodes) = self.capture {
nodes.push(CapturedNode::Barrier);
return;
}
if self.active_encoder.is_null() {
return;
}
BARRIER_COUNT.fetch_add(1, Ordering::Relaxed);
if auto_barrier_enabled() {
self.mem_ranges.reset();
}
let encoder = unsafe { &*self.active_encoder };
if barrier_profile_enabled() {
let start = std::time::Instant::now();
issue_metal_buffer_barrier(encoder);
let elapsed_ns = start.elapsed().as_nanos() as u64;
BARRIER_NS.fetch_add(elapsed_ns, Ordering::Relaxed);
} else {
issue_metal_buffer_barrier(encoder);
}
}
pub fn set_pipeline(&mut self, pipeline: &ComputePipelineStateRef) {
let encoder = self.get_or_create_encoder();
encoder.set_compute_pipeline_state(pipeline);
}
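// No-op stubs: argument binding and dispatch go through the `encode_*` /
// `dispatch_tracked_*` methods below.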
pub fn set_buffer(&self, index: u64, buffer: &MlxBuffer) {
let _ = (index, buffer);
}
pub fn dispatch_threads(&self, grid_size: MTLSize, threadgroup_size: MTLSize) {
let _ = (grid_size, threadgroup_size);
}
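/// Encode a `dispatch_threads`-style dispatch, or record it as a captured node
/// while capturing. The `encode_*` variants that follow differ only in dispatch
/// kind, argument form, and threadgroup memory.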
pub fn encode(
&mut self,
pipeline: &ComputePipelineStateRef,
buffers: &[(u64, &MlxBuffer)],
grid_size: MTLSize,
threadgroup_size: MTLSize,
) {
DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
bucket_dispatch(pipeline);
let op_kind = self.take_pending_op_kind();
let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
if let Some(ref mut nodes) = self.capture {
nodes.push(CapturedNode::Dispatch {
pipeline: pipeline.to_owned(),
bindings: Self::record_buffer_bindings(buffers),
threads_per_grid: grid_size,
threads_per_threadgroup: threadgroup_size,
threadgroup_memory: Vec::new(),
dispatch_kind: DispatchKind::Threads,
op_kind,
reads: pending_reads,
writes: pending_writes,
});
return;
}
self.ensure_sample_buffer();
let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
let encoder = unsafe { &*encoder_ptr };
encoder.set_compute_pipeline_state(pipeline);
for &(index, buf) in buffers {
encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
}
let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
assert_tg_size_multiple_of_32_if_hinted(threadgroup_size, pipeline);
encoder.dispatch_threads(grid_size, threadgroup_size);
self.sample_dispatch_post(encoder, pre_idx);
}
pub fn encode_threadgroups(
&mut self,
pipeline: &ComputePipelineStateRef,
buffers: &[(u64, &MlxBuffer)],
threadgroups: MTLSize,
threadgroup_size: MTLSize,
) {
DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
bucket_dispatch(pipeline);
let op_kind = self.take_pending_op_kind();
let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
if let Some(ref mut nodes) = self.capture {
nodes.push(CapturedNode::Dispatch {
pipeline: pipeline.to_owned(),
bindings: Self::record_buffer_bindings(buffers),
threads_per_grid: threadgroups,
threads_per_threadgroup: threadgroup_size,
threadgroup_memory: Vec::new(),
dispatch_kind: DispatchKind::ThreadGroups,
op_kind,
reads: pending_reads,
writes: pending_writes,
});
return;
}
self.ensure_sample_buffer();
let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
let encoder = unsafe { &*encoder_ptr };
encoder.set_compute_pipeline_state(pipeline);
for &(index, buf) in buffers {
encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
}
let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
assert_tg_size_multiple_of_32_if_hinted(threadgroup_size, pipeline);
encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
self.sample_dispatch_post(encoder, pre_idx);
}
pub fn encode_threadgroups_with_shared(
&mut self,
pipeline: &ComputePipelineStateRef,
buffers: &[(u64, &MlxBuffer)],
threadgroup_mem: &[(u64, u64)],
threadgroups: MTLSize,
threadgroup_size: MTLSize,
) {
DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
bucket_dispatch(pipeline);
let op_kind = self.take_pending_op_kind();
let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
if let Some(ref mut nodes) = self.capture {
nodes.push(CapturedNode::Dispatch {
pipeline: pipeline.to_owned(),
bindings: Self::record_buffer_bindings(buffers),
threads_per_grid: threadgroups,
threads_per_threadgroup: threadgroup_size,
threadgroup_memory: threadgroup_mem.to_vec(),
dispatch_kind: DispatchKind::ThreadGroups,
op_kind,
reads: pending_reads,
writes: pending_writes,
});
return;
}
self.ensure_sample_buffer();
let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
let encoder = unsafe { &*encoder_ptr };
encoder.set_compute_pipeline_state(pipeline);
for &(index, buf) in buffers {
encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
}
for &(index, byte_length) in threadgroup_mem {
encoder.set_threadgroup_memory_length(index, byte_length);
}
let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
assert_tg_size_multiple_of_32_if_hinted(threadgroup_size, pipeline);
encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
self.sample_dispatch_post(encoder, pre_idx);
}
pub fn encode_with_args(
&mut self,
pipeline: &ComputePipelineStateRef,
bindings: &[(u64, KernelArg<'_>)],
grid_size: MTLSize,
threadgroup_size: MTLSize,
) {
DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
bucket_dispatch(pipeline);
let op_kind = self.take_pending_op_kind();
let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
if let Some(ref mut nodes) = self.capture {
nodes.push(CapturedNode::Dispatch {
pipeline: pipeline.to_owned(),
bindings: Self::record_arg_bindings(bindings),
threads_per_grid: grid_size,
threads_per_threadgroup: threadgroup_size,
threadgroup_memory: Vec::new(),
dispatch_kind: DispatchKind::Threads,
op_kind,
reads: pending_reads,
writes: pending_writes,
});
return;
}
self.ensure_sample_buffer();
let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
let encoder = unsafe { &*encoder_ptr };
encoder.set_compute_pipeline_state(pipeline);
apply_bindings(encoder, bindings);
let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
assert_tg_size_multiple_of_32_if_hinted(threadgroup_size, pipeline);
encoder.dispatch_threads(grid_size, threadgroup_size);
self.sample_dispatch_post(encoder, pre_idx);
}
pub fn encode_threadgroups_with_args(
&mut self,
pipeline: &ComputePipelineStateRef,
bindings: &[(u64, KernelArg<'_>)],
threadgroups: MTLSize,
threadgroup_size: MTLSize,
) {
DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
bucket_dispatch(pipeline);
let op_kind = self.take_pending_op_kind();
let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
if let Some(ref mut nodes) = self.capture {
nodes.push(CapturedNode::Dispatch {
pipeline: pipeline.to_owned(),
bindings: Self::record_arg_bindings(bindings),
threads_per_grid: threadgroups,
threads_per_threadgroup: threadgroup_size,
threadgroup_memory: Vec::new(),
dispatch_kind: DispatchKind::ThreadGroups,
op_kind,
reads: pending_reads,
writes: pending_writes,
});
return;
}
self.ensure_sample_buffer();
let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
let encoder = unsafe { &*encoder_ptr };
encoder.set_compute_pipeline_state(pipeline);
apply_bindings(encoder, bindings);
let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
assert_tg_size_multiple_of_32_if_hinted(threadgroup_size, pipeline);
encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
self.sample_dispatch_post(encoder, pre_idx);
}
pub fn encode_threadgroups_with_args_and_shared(
&mut self,
pipeline: &ComputePipelineStateRef,
bindings: &[(u64, KernelArg<'_>)],
threadgroup_mem: &[(u64, u64)],
threadgroups: MTLSize,
threadgroup_size: MTLSize,
) {
DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
bucket_dispatch(pipeline);
let op_kind = self.take_pending_op_kind();
let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
if let Some(ref mut nodes) = self.capture {
nodes.push(CapturedNode::Dispatch {
pipeline: pipeline.to_owned(),
bindings: Self::record_arg_bindings(bindings),
threads_per_grid: threadgroups,
threads_per_threadgroup: threadgroup_size,
threadgroup_memory: threadgroup_mem.to_vec(),
dispatch_kind: DispatchKind::ThreadGroups,
op_kind,
reads: pending_reads,
writes: pending_writes,
});
return;
}
self.ensure_sample_buffer();
let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
let encoder = unsafe { &*encoder_ptr };
encoder.set_compute_pipeline_state(pipeline);
apply_bindings(encoder, bindings);
for &(index, byte_length) in threadgroup_mem {
encoder.set_threadgroup_memory_length(index, byte_length);
}
let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
assert_tg_size_multiple_of_32_if_hinted(threadgroup_size, pipeline);
encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
self.sample_dispatch_post(encoder, pre_idx);
}
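/// The `dispatch_tracked_*` variants also take explicit read/write buffer sets:
/// while capturing, the ranges are attached to the captured node; otherwise,
/// with HF2Q_AUTO_BARRIER=1, they are checked against work encoded since the
/// last barrier and a barrier is inserted only when a hazard is found.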
pub fn dispatch_tracked_threadgroups_with_args(
&mut self,
pipeline: &ComputePipelineStateRef,
bindings: &[(u64, KernelArg<'_>)],
reads: &[&MlxBuffer],
writes: &[&MlxBuffer],
threadgroups: MTLSize,
threadgroup_size: MTLSize,
) {
if self.is_capturing() {
let read_ranges = ranges_from_buffers(reads);
let write_ranges = ranges_from_buffers(writes);
self.set_pending_buffer_ranges(read_ranges, write_ranges);
self.encode_threadgroups_with_args(pipeline, bindings, threadgroups, threadgroup_size);
return;
}
if auto_barrier_enabled() {
self.maybe_auto_barrier(reads, writes);
}
self.encode_threadgroups_with_args(pipeline, bindings, threadgroups, threadgroup_size);
}
#[allow(clippy::too_many_arguments)]
pub fn dispatch_tracked_threadgroups_with_args_and_shared(
&mut self,
pipeline: &ComputePipelineStateRef,
bindings: &[(u64, KernelArg<'_>)],
threadgroup_mem: &[(u64, u64)],
reads: &[&MlxBuffer],
writes: &[&MlxBuffer],
threadgroups: MTLSize,
threadgroup_size: MTLSize,
) {
if self.is_capturing() {
let read_ranges = ranges_from_buffers(reads);
let write_ranges = ranges_from_buffers(writes);
self.set_pending_buffer_ranges(read_ranges, write_ranges);
self.encode_threadgroups_with_args_and_shared(
pipeline,
bindings,
threadgroup_mem,
threadgroups,
threadgroup_size,
);
return;
}
if auto_barrier_enabled() {
self.maybe_auto_barrier(reads, writes);
}
self.encode_threadgroups_with_args_and_shared(
pipeline,
bindings,
threadgroup_mem,
threadgroups,
threadgroup_size,
);
}
pub fn dispatch_tracked_threadgroups(
&mut self,
pipeline: &ComputePipelineStateRef,
buffers: &[(u64, &MlxBuffer)],
reads: &[&MlxBuffer],
writes: &[&MlxBuffer],
threadgroups: MTLSize,
threadgroup_size: MTLSize,
) {
if self.is_capturing() {
let read_ranges = ranges_from_buffers(reads);
let write_ranges = ranges_from_buffers(writes);
self.set_pending_buffer_ranges(read_ranges, write_ranges);
self.encode_threadgroups(pipeline, buffers, threadgroups, threadgroup_size);
return;
}
if auto_barrier_enabled() {
self.maybe_auto_barrier(reads, writes);
}
self.encode_threadgroups(pipeline, buffers, threadgroups, threadgroup_size);
}
#[allow(clippy::too_many_arguments)]
pub fn dispatch_tracked_threadgroups_with_shared(
&mut self,
pipeline: &ComputePipelineStateRef,
buffers: &[(u64, &MlxBuffer)],
threadgroup_mem: &[(u64, u64)],
reads: &[&MlxBuffer],
writes: &[&MlxBuffer],
threadgroups: MTLSize,
threadgroup_size: MTLSize,
) {
if self.is_capturing() {
let read_ranges = ranges_from_buffers(reads);
let write_ranges = ranges_from_buffers(writes);
self.set_pending_buffer_ranges(read_ranges, write_ranges);
self.encode_threadgroups_with_shared(
pipeline,
buffers,
threadgroup_mem,
threadgroups,
threadgroup_size,
);
return;
}
if auto_barrier_enabled() {
self.maybe_auto_barrier(reads, writes);
}
self.encode_threadgroups_with_shared(
pipeline,
buffers,
threadgroup_mem,
threadgroups,
threadgroup_size,
);
}
pub fn dispatch_tracked_threads_with_args(
&mut self,
pipeline: &ComputePipelineStateRef,
bindings: &[(u64, KernelArg<'_>)],
reads: &[&MlxBuffer],
writes: &[&MlxBuffer],
grid_size: MTLSize,
threadgroup_size: MTLSize,
) {
if self.is_capturing() {
let read_ranges = ranges_from_buffers(reads);
let write_ranges = ranges_from_buffers(writes);
self.set_pending_buffer_ranges(read_ranges, write_ranges);
self.encode_with_args(pipeline, bindings, grid_size, threadgroup_size);
return;
}
if auto_barrier_enabled() {
self.maybe_auto_barrier(reads, writes);
}
self.encode_with_args(pipeline, bindings, grid_size, threadgroup_size);
}
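/// Encode (or capture) a pre-baked `DispatchRecord`, binding `runtime_buffers`
/// at the record's `buffer_slots` and re-sending its parameter bytes. A pending
/// op kind set via `set_op_kind` overrides the record's own.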
pub fn dispatch_record(
&mut self,
rec: &DispatchRecord,
runtime_buffers: &[&MlxBuffer],
) {
debug_assert_eq!(
rec.buffer_slots.len(),
runtime_buffers.len(),
"dispatch_record: runtime_buffers count must match buffer_slots ({}); got {}",
rec.buffer_slots.len(),
runtime_buffers.len(),
);
DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
bucket_dispatch(&rec.pipeline);
let op_kind_override = self.take_pending_op_kind();
let op_kind = if matches!(op_kind_override, CapturedOpKind::Other) {
rec.op_kind
} else {
op_kind_override
};
let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
if let Some(ref mut nodes) = self.capture {
let cap = runtime_buffers.len() + if rec.params_bytes.is_empty() { 0 } else { 1 };
let mut bindings: Vec<(u64, RecordedBinding)> = Vec::with_capacity(cap);
for (slot, buf) in rec.buffer_slots.iter().zip(runtime_buffers.iter()) {
bindings.push((
*slot,
RecordedBinding::Buffer {
metal_buffer: buf.metal_buffer().to_owned(),
offset: buf.byte_offset(),
},
));
}
if !rec.params_bytes.is_empty() {
bindings.push((
rec.params_slot,
RecordedBinding::Bytes(rec.params_bytes.clone()),
));
}
nodes.push(CapturedNode::Dispatch {
pipeline: rec.pipeline.clone(),
bindings,
threads_per_grid: rec.threadgroups,
threads_per_threadgroup: rec.threads_per_tg,
threadgroup_memory: rec.threadgroup_mem.clone(),
dispatch_kind: DispatchKind::ThreadGroups,
op_kind,
reads: pending_reads,
writes: pending_writes,
});
return;
}
self.ensure_sample_buffer();
let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
let encoder = unsafe { &*encoder_ptr };
encoder.set_compute_pipeline_state(&rec.pipeline);
for (slot, buf) in rec.buffer_slots.iter().zip(runtime_buffers.iter()) {
encoder.set_buffer(*slot, Some(buf.metal_buffer()), buf.byte_offset());
}
if !rec.params_bytes.is_empty() {
encoder.set_bytes(
rec.params_slot,
rec.params_bytes.len() as u64,
rec.params_bytes.as_ptr() as *const _,
);
}
for &(idx, len) in rec.threadgroup_mem.iter() {
encoder.set_threadgroup_memory_length(idx, len);
}
let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
encoder.dispatch_thread_groups(rec.threadgroups, rec.threads_per_tg);
self.sample_dispatch_post(encoder, pre_idx);
}
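/// If this dispatch's reads/writes can safely overlap everything tracked since
/// the last barrier, just track it; otherwise issue a barrier, reset the
/// tracker, and then track it.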
fn maybe_auto_barrier(
&mut self,
reads: &[&MlxBuffer],
writes: &[&MlxBuffer],
) {
if self.mem_ranges.check_dispatch(reads, writes) {
self.mem_ranges.add_dispatch(reads, writes);
AUTO_BARRIER_CONCURRENT.fetch_add(1, Ordering::Relaxed);
} else {
self.memory_barrier();
self.mem_ranges.reset();
self.mem_ranges.add_dispatch(reads, writes);
AUTO_BARRIER_COUNT.fetch_add(1, Ordering::Relaxed);
}
}
pub fn force_barrier_and_reset_tracker(&mut self) {
self.memory_barrier();
if auto_barrier_enabled() {
self.mem_ranges.reset();
}
}
#[inline]
pub fn mem_ranges_len(&self) -> usize {
self.mem_ranges.len()
}
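/// Re-encode a previously recorded dispatch (typically a
/// `CapturedNode::Dispatch`) onto the live encoder.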
pub fn replay_dispatch(
&mut self,
pipeline: &ComputePipelineStateRef,
bindings: &[(u64, RecordedBinding)],
threadgroup_memory: &[(u64, u64)],
threads_per_grid: MTLSize,
threads_per_threadgroup: MTLSize,
dispatch_kind: DispatchKind,
) {
self.ensure_sample_buffer();
let op_kind = self.take_pending_op_kind();
let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
let encoder = unsafe { &*encoder_ptr };
encoder.set_compute_pipeline_state(pipeline);
for (index, binding) in bindings {
match binding {
RecordedBinding::Buffer { metal_buffer, offset } => {
encoder.set_buffer(*index, Some(metal_buffer), *offset);
}
RecordedBinding::Bytes(bytes) => {
encoder.set_bytes(
*index,
bytes.len() as u64,
bytes.as_ptr() as *const _,
);
}
}
}
for &(index, byte_length) in threadgroup_memory {
encoder.set_threadgroup_memory_length(index, byte_length);
}
let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
match dispatch_kind {
DispatchKind::Threads => {
assert_tg_size_multiple_of_32_if_hinted(threads_per_threadgroup, pipeline);
encoder.dispatch_threads(threads_per_grid, threads_per_threadgroup);
}
DispatchKind::ThreadGroups => {
assert_tg_size_multiple_of_32_if_hinted(threads_per_threadgroup, pipeline);
encoder.dispatch_thread_groups(threads_per_grid, threads_per_threadgroup);
}
}
self.sample_dispatch_post(encoder, pre_idx);
}
#[inline]
fn flush_residency_pending(&self) {
if let Some(set) = self.residency_set.as_ref() {
set.flush_pending();
}
}
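/// Lazily create a shared-storage timestamp counter sample buffer when
/// per-dispatch profiling is enabled; warn once and bail out if the device
/// lacks dispatch-boundary sampling or exposes no timestamp counter set.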
#[inline]
fn ensure_sample_buffer(&mut self) {
if !crate::kernel_profile::is_dispatch_enabled() {
return;
}
if self.sample_buffer.is_some() {
return;
}
let device: &metal::DeviceRef = unsafe {
let cb = &*self.cmd_buf;
msg_send![cb, device]
};
if !device.supports_counter_sampling(MTLCounterSamplingPoint::AtDispatchBoundary) {
if TIMESTAMP_SET_WARN_LOGGED
.compare_exchange(0, 1, Ordering::Relaxed, Ordering::Relaxed)
.is_ok()
{
eprintln!(
"[mlx-native] MLX_PROFILE_DISPATCH=1 ignored: \
device {:?} does NOT support \
MTLCounterSamplingPointAtDispatchBoundary \
(Apple Silicon limitation; only AtStageBoundary \
is supported, which is incompatible with the \
persistent compute-encoder pattern). \
MLX_PROFILE_CB=1 still produces per-CB GPU times.",
device.name()
);
}
return;
}
let counter_sets = device.counter_sets();
let timestamp_set = counter_sets
.iter()
.find(|c: &&metal::CounterSet| c.name().eq_ignore_ascii_case("timestamp"));
let timestamp_set = match timestamp_set {
Some(s) => s,
None => {
if TIMESTAMP_SET_WARN_LOGGED
.compare_exchange(0, 1, Ordering::Relaxed, Ordering::Relaxed)
.is_ok()
{
eprintln!(
"[mlx-native] MLX_PROFILE_DISPATCH=1 ignored: \
device {:?} exposes no MTLCommonCounterSetTimestamp",
device.name()
);
}
return;
}
};
let descriptor = CounterSampleBufferDescriptor::new();
descriptor.set_counter_set(timestamp_set);
descriptor.set_storage_mode(MTLStorageMode::Shared);
descriptor.set_label("mlx_native.dispatch_samples");
descriptor.set_sample_count(MAX_SAMPLES_PER_CB);
match device.new_counter_sample_buffer_with_descriptor(&descriptor) {
Ok(buf) => {
self.sample_buffer = Some(buf);
}
Err(e) => {
if TIMESTAMP_SET_WARN_LOGGED
.compare_exchange(0, 1, Ordering::Relaxed, Ordering::Relaxed)
.is_ok()
{
eprintln!(
"[mlx-native] MLX_PROFILE_DISPATCH=1 ignored: \
newCounterSampleBufferWithDescriptor failed: {}",
e
);
}
self.sample_buffer = None;
}
}
}
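/// Dispatch `i` samples GPU timestamps at counter indices `2*i` (pre-dispatch)
/// and `2*i + 1` (post-dispatch), capped at MAX_SAMPLES_PER_CB per command
/// buffer.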
#[inline]
fn sample_dispatch_pre(
&mut self,
encoder: &ComputeCommandEncoderRef,
op_kind: CapturedOpKind,
) -> Option<u32> {
let sb = self.sample_buffer.as_ref()?;
let i = self.dispatch_in_cb;
let pre_idx = (i as u64).checked_mul(2)?;
if pre_idx >= MAX_SAMPLES_PER_CB {
return None;
}
encoder.sample_counters_in_buffer(sb, pre_idx, true);
self.pending_dispatch_meta.push(PendingDispatchMeta {
op_kind: op_kind.name(),
dispatch_index: i,
});
Some(i)
}
#[inline]
fn sample_dispatch_post(
&mut self,
encoder: &ComputeCommandEncoderRef,
pre_idx: Option<u32>,
) {
let i = match pre_idx {
Some(v) => v,
None => return,
};
let sb = match self.sample_buffer.as_ref() {
Some(b) => b,
None => return,
};
let post_idx = match (i as u64).checked_mul(2).and_then(|v| v.checked_add(1)) {
Some(v) if v < MAX_SAMPLES_PER_CB => v,
_ => return,
};
encoder.sample_counters_in_buffer(sb, post_idx, true);
self.dispatch_in_cb = i.saturating_add(1);
}
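/// Resolve the pre/post timestamp pairs recorded for this command buffer,
/// convert GPU ticks to nanoseconds (after sampling a fresh CPU/GPU clock pair),
/// and feed one `DispatchEntry` per dispatch to the kernel profiler.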
fn resolve_dispatch_samples(&mut self, cb_label: &str) -> Result<()> {
let sb = match self.sample_buffer.take() {
Some(b) => b,
None => {
self.pending_dispatch_meta.clear();
self.dispatch_in_cb = 0;
return Ok(());
}
};
let n = self.pending_dispatch_meta.len();
if n == 0 {
self.dispatch_in_cb = 0;
return Ok(());
}
let mut cpu_t: u64 = 0;
let mut gpu_t: u64 = 0;
let device: &metal::DeviceRef = unsafe {
let cb = &*self.cmd_buf;
msg_send![cb, device]
};
device.sample_timestamps(&mut cpu_t, &mut gpu_t);
crate::kernel_profile::record_clock_pair(cpu_t, gpu_t);
let length = (n as u64).saturating_mul(2);
let data = sb.resolve_counter_range(NSRange {
location: 0,
length,
});
for (i, meta) in self.pending_dispatch_meta.drain(..).enumerate() {
let start_idx = 2 * i;
let end_idx = 2 * i + 1;
if end_idx >= data.len() {
break;
}
let start_raw = data[start_idx] as u64;
let end_raw = data[end_idx] as u64;
let start_ns = crate::kernel_profile::convert_gpu_ticks_to_ns(start_raw);
let end_ns = crate::kernel_profile::convert_gpu_ticks_to_ns(end_raw);
let gpu_ns = end_ns.saturating_sub(start_ns);
crate::kernel_profile::record_dispatch(
crate::kernel_profile::DispatchEntry {
cb_label: cb_label.to_string(),
op_kind: meta.op_kind,
dispatch_index: meta.dispatch_index,
gpu_ns,
start_gpu_ns: start_ns,
end_gpu_ns: end_ns,
},
);
}
drop(sb);
self.dispatch_in_cb = 0;
Ok(())
}
pub fn commit_and_wait(&mut self) -> Result<()> {
SYNC_COUNT.fetch_add(1, Ordering::Relaxed);
self.end_active_encoder();
self.flush_residency_pending();
self.cmd_buf.commit();
self.cmd_buf.wait_until_completed();
match self.cmd_buf.status() {
MTLCommandBufferStatus::Completed => Ok(()),
MTLCommandBufferStatus::Error => {
Err(MlxError::CommandBufferError(
"GPU command buffer completed with error status".into(),
))
}
status => Err(MlxError::CommandBufferError(format!(
"Unexpected command buffer status after wait: {:?}",
status
))),
}
}
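/// Label, commit, and wait. When kernel or dispatch profiling is enabled, also
/// record the command buffer's GPU time under `label` and resolve per-dispatch
/// samples.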
pub fn commit_and_wait_labeled(&mut self, label: &str) -> Result<()> {
self.apply_labels(label);
let need_gpu_time =
crate::kernel_profile::is_enabled() || crate::kernel_profile::is_dispatch_enabled();
if need_gpu_time {
let (start_s, end_s) = self.commit_wait_with_gpu_time()?;
let ns = ((end_s - start_s).max(0.0) * 1_000_000_000.0) as u64;
if crate::kernel_profile::is_enabled() {
crate::kernel_profile::record(label, ns);
}
if crate::kernel_profile::is_dispatch_enabled() {
self.resolve_dispatch_samples(label)?;
}
Ok(())
} else {
self.commit_and_wait()
}
}
pub fn commit_labeled(&mut self, label: &str) {
if crate::kernel_profile::is_enabled() {
if let Err(e) = self.commit_and_wait_labeled(label) {
eprintln!("[mlx-native] commit_labeled({}) failed: {}", label, e);
}
} else {
self.apply_labels(label);
self.commit();
}
}
#[inline]
fn apply_labels(&mut self, label: &str) {
debug_assert!(!label.is_empty(), "commit_*_labeled called with empty label");
if label.is_empty() {
return;
}
self.cmd_buf.set_label(label);
if !self.active_encoder.is_null() {
unsafe { &*self.active_encoder }.set_label(label);
}
self.last_label.clear();
self.last_label.push_str(label);
}
pub fn commit_wait_with_gpu_time(&mut self) -> Result<(f64, f64)> {
self.commit_and_wait()?;
let (gpu_start, gpu_end): (f64, f64) = unsafe {
let cb = &*self.cmd_buf;
let s: f64 = msg_send![cb, GPUStartTime];
let e: f64 = msg_send![cb, GPUEndTime];
(s, e)
};
Ok((gpu_start, gpu_end))
}
pub fn commit(&mut self) {
self.end_active_encoder();
self.flush_residency_pending();
self.cmd_buf.commit();
}
pub fn wait_until_completed(&self) -> Result<()> {
self.cmd_buf.wait_until_completed();
match self.cmd_buf.status() {
MTLCommandBufferStatus::Completed => Ok(()),
MTLCommandBufferStatus::Error => Err(MlxError::CommandBufferError(
"GPU command buffer completed with error status".into(),
)),
status => Err(MlxError::CommandBufferError(format!(
"Unexpected command buffer status after wait: {:?}",
status
))),
}
}
#[inline]
pub fn metal_command_buffer(&self) -> &CommandBuffer {
&self.cmd_buf
}
#[inline]
pub(crate) fn residency_set(&self) -> Option<&ResidencySet> {
self.residency_set.as_ref()
}
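/// Start a fresh command buffer on the same queue and clear the per-command-buffer
/// state (encoder pointer, dispatch counter, label, pending samples, hazard
/// tracker).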
pub(crate) fn reset_command_buffer(&mut self) {
debug_assert!(
self.active_encoder.is_null(),
"reset_command_buffer called with an active compute encoder \
— caller must commit (which calls end_active_encoder) first"
);
let cmd_buf = if unretained_refs_enabled() {
self.queue
.new_command_buffer_with_unretained_references()
.to_owned()
} else {
self.queue.new_command_buffer().to_owned()
};
CMD_BUF_COUNT.fetch_add(1, Ordering::Relaxed);
self.cmd_buf = cmd_buf;
self.active_encoder = std::ptr::null();
self.dispatch_in_cb = 0;
self.last_label.clear();
self.pending_dispatch_meta.clear();
self.mem_ranges = MemRanges::new();
}
#[inline]
pub(crate) fn encode_wait_for_event(&self, event: &metal::EventRef, value: u64) {
debug_assert!(
self.active_encoder.is_null(),
"encode_wait_for_event called with an open compute encoder \
— wait must precede the first dispatch on the new CB"
);
self.cmd_buf.encode_wait_for_event(event, value);
}
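/// End the open encoder, optionally apply `label`, encode a signal of `event`
/// to `new_value`, flush pending residency work, and commit without waiting.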
pub(crate) fn fence_signal_and_commit(
&mut self,
event: &metal::EventRef,
new_value: u64,
label: Option<&str>,
) {
self.end_active_encoder();
if let Some(l) = label {
self.apply_labels(l);
}
self.cmd_buf.encode_signal_event(event, new_value);
self.flush_residency_pending();
self.cmd_buf.commit();
}
}
impl Drop for CommandEncoder {
fn drop(&mut self) {
self.end_active_encoder();
}
}