use std::sync::atomic::{AtomicU64, Ordering};
use metal::{
CommandBuffer, CommandQueue, ComputeCommandEncoderRef, ComputePipelineState,
ComputePipelineStateRef, MTLCommandBufferStatus, MTLDispatchType, MTLSize,
};
#[allow(unused_imports)]
use objc::{msg_send, sel, sel_impl};
use crate::buffer::MlxBuffer;
use crate::error::{MlxError, Result};
use crate::mem_ranges::MemRanges;
use crate::residency::ResidencySet;
/// One kernel argument to bind at a compute-function argument-table index.
pub enum KernelArg<'a> {
/// Bind the buffer at the offset the buffer itself carries.
Buffer(&'a MlxBuffer),
/// Bind the buffer at an explicit byte offset, overriding the buffer's own.
BufferWithOffset(&'a MlxBuffer, u64),
/// Small inline constant data, copied into the command stream via `set_bytes`.
Bytes(&'a [u8]),
}
/// View any `Pod` value as its raw in-memory byte representation, suitable
/// for passing as a `KernelArg::Bytes` argument.
pub fn as_bytes<T: bytemuck::Pod>(val: &T) -> &[u8] {
bytemuck::bytes_of(val)
}
/// An owned snapshot of a kernel argument, taken during capture so a dispatch
/// can be replayed later without borrowing the caller's arguments.
#[derive(Clone)]
pub enum RecordedBinding {
/// A retained Metal buffer plus the byte offset to bind it at.
Buffer {
metal_buffer: metal::Buffer,
offset: u64,
},
/// Copied inline-constant payload, replayed via `set_bytes`.
Bytes(Vec<u8>),
}
/// How a recorded grid size is interpreted when (re-)encoding a dispatch.
#[derive(Clone, Copy, Debug)]
pub enum DispatchKind {
/// Encoded with `dispatch_threads`: the size is a total thread count.
Threads,
/// Encoded with `dispatch_thread_groups`: the size is a threadgroup count.
ThreadGroups,
}
/// Classification of a captured GPU dispatch; drives scheduling decisions in
/// the capture/replay path.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CapturedOpKind {
    RmsNorm,
    ElemMul,
    ElemAdd,
    Sdpa,
    Softmax,
    /// Fallback for dispatches that were never explicitly tagged.
    Other,
}

impl CapturedOpKind {
    /// Whether a dispatch of this kind may be reordered relative to its
    /// neighbours; attention (SDPA) and softmax dispatches are kept in place.
    pub fn is_reorderable(&self) -> bool {
        !matches!(self, Self::Sdpa | Self::Softmax)
    }
}
/// Half-open `[start, end)` range of host byte addresses covered by a buffer
/// binding; used for read/write hazard tracking.
pub type MemRange = (usize, usize);
/// One node in a captured command stream: a recorded compute dispatch or an
/// explicit barrier between dispatches.
#[derive(Clone)]
pub enum CapturedNode {
/// A fully recorded compute dispatch, replayable via `replay_dispatch`.
Dispatch {
pipeline: ComputePipelineState,
// (argument index, recorded argument) pairs.
bindings: Vec<(u64, RecordedBinding)>,
// Interpreted per `dispatch_kind`: total threads or threadgroup count.
threads_per_grid: MTLSize,
threads_per_threadgroup: MTLSize,
// (threadgroup-memory index, byte length) pairs.
threadgroup_memory: Vec<(u64, u64)>,
dispatch_kind: DispatchKind,
op_kind: CapturedOpKind,
// Host byte ranges this dispatch reads/writes, for hazard analysis.
reads: Vec<MemRange>,
writes: Vec<MemRange>,
},
/// An explicit ordering point between the surrounding dispatches.
Barrier,
}
/// Translate buffers into absolute host-address byte ranges `[start, end)`.
///
/// `start` is the buffer's contents pointer advanced by its byte offset.
/// NOTE(review): the extent is `byte_len() - byte_offset()`, which assumes
/// `byte_len()` reports the length of the underlying allocation rather than
/// the already-offset view — confirm against `MlxBuffer`'s contract.
fn ranges_from_buffers(bufs: &[&MlxBuffer]) -> Vec<MemRange> {
    let mut ranges = Vec::with_capacity(bufs.len());
    for buf in bufs {
        let start = buf.contents_ptr() as usize + buf.byte_offset() as usize;
        let len = buf.byte_len().saturating_sub(buf.byte_offset() as usize);
        ranges.push((start, start + len));
    }
    ranges
}
/// Bind every kernel argument onto `encoder` at its argument-table index.
#[inline]
fn apply_bindings(encoder: &ComputeCommandEncoderRef, bindings: &[(u64, KernelArg<'_>)]) {
    for (slot, arg) in bindings {
        match arg {
            // Buffer bound at its own stored byte offset.
            KernelArg::Buffer(buf) => {
                encoder.set_buffer(*slot, Some(buf.metal_buffer()), buf.byte_offset())
            }
            // Buffer bound at an explicit caller-supplied offset.
            KernelArg::BufferWithOffset(buf, off) => {
                encoder.set_buffer(*slot, Some(buf.metal_buffer()), *off)
            }
            // Inline constants copied directly into the command stream.
            KernelArg::Bytes(data) => {
                encoder.set_bytes(*slot, data.len() as u64, data.as_ptr() as *const _)
            }
        }
    }
}
// Process-wide instrumentation counters (all Relaxed; read via getters below).
static SYNC_COUNT: AtomicU64 = AtomicU64::new(0);
static DISPATCH_COUNT: AtomicU64 = AtomicU64::new(0);
static CMD_BUF_COUNT: AtomicU64 = AtomicU64::new(0);
static BARRIER_COUNT: AtomicU64 = AtomicU64::new(0);
// Total nanoseconds spent encoding barriers (only when MLX_PROFILE_BARRIERS=1).
static BARRIER_NS: AtomicU64 = AtomicU64::new(0);
/// Zero every instrumentation counter, including the auto-barrier statistics.
pub fn reset_counters() {
    let counters = [
        &SYNC_COUNT,
        &DISPATCH_COUNT,
        &CMD_BUF_COUNT,
        &BARRIER_COUNT,
        &BARRIER_NS,
        &AUTO_BARRIER_COUNT,
        &AUTO_BARRIER_CONCURRENT,
    ];
    for counter in counters {
        counter.store(0, Ordering::Relaxed);
    }
}
/// Number of commit-and-wait synchronizations since the last reset.
pub fn sync_count() -> u64 {
SYNC_COUNT.load(Ordering::Relaxed)
}
/// Number of compute dispatches encoded or captured since the last reset.
pub fn dispatch_count() -> u64 {
DISPATCH_COUNT.load(Ordering::Relaxed)
}
/// Number of command buffers created since the last reset.
pub fn cmd_buf_count() -> u64 {
CMD_BUF_COUNT.load(Ordering::Relaxed)
}
/// Number of memory barriers actually encoded (capture-mode barrier nodes
/// are not counted here).
pub fn barrier_count() -> u64 {
BARRIER_COUNT.load(Ordering::Relaxed)
}
/// Total nanoseconds spent encoding barriers; nonzero only when
/// `MLX_PROFILE_BARRIERS=1`.
pub fn barrier_total_ns() -> u64 {
BARRIER_NS.load(Ordering::Relaxed)
}
/// Whether per-barrier wall-clock profiling is on (`MLX_PROFILE_BARRIERS=1`).
/// The environment is consulted once and the answer cached for the process.
fn barrier_profile_enabled() -> bool {
    use std::sync::OnceLock;
    static FLAG: OnceLock<bool> = OnceLock::new();
    *FLAG.get_or_init(|| {
        matches!(std::env::var("MLX_PROFILE_BARRIERS").as_deref(), Ok("1"))
    })
}
/// Whether command buffers should skip per-resource retain/release
/// bookkeeping (`MLX_UNRETAINED_REFS=1`); read once and cached.
fn unretained_refs_enabled() -> bool {
    use std::sync::OnceLock;
    static FLAG: OnceLock<bool> = OnceLock::new();
    *FLAG.get_or_init(|| {
        matches!(std::env::var("MLX_UNRETAINED_REFS").as_deref(), Ok("1"))
    })
}
/// Whether automatic hazard-based barriers are enabled
/// (`HF2Q_AUTO_BARRIER=1`); read once and cached.
fn auto_barrier_enabled() -> bool {
    use std::sync::OnceLock;
    static FLAG: OnceLock<bool> = OnceLock::new();
    *FLAG.get_or_init(|| {
        matches!(std::env::var("HF2Q_AUTO_BARRIER").as_deref(), Ok("1"))
    })
}
// Auto-barrier statistics: dispatches that forced a barrier vs. those that
// joined the current concurrent window (see `maybe_auto_barrier`).
static AUTO_BARRIER_COUNT: AtomicU64 = AtomicU64::new(0);
static AUTO_BARRIER_CONCURRENT: AtomicU64 = AtomicU64::new(0);
/// Number of auto-inserted barriers (hazard detected) since the last reset.
pub fn auto_barrier_count() -> u64 {
AUTO_BARRIER_COUNT.load(Ordering::Relaxed)
}
/// Number of tracked dispatches that ran without needing a barrier.
pub fn auto_barrier_concurrent_count() -> u64 {
AUTO_BARRIER_CONCURRENT.load(Ordering::Relaxed)
}
/// Encode a buffer-scope memory barrier by invoking the Objective-C selector
/// `memoryBarrierWithScope:` directly — presumably because the `metal` crate
/// wrapper in use does not expose it; confirm before refactoring.
#[inline(never)]
fn issue_metal_buffer_barrier(encoder: &ComputeCommandEncoderRef) {
// Matches Metal's MTLBarrierScopeBuffers constant.
const MTL_BARRIER_SCOPE_BUFFERS: u64 = 1;
// SAFETY: `encoder` is a live compute-command encoder, and the selector
// takes a single integer scope argument; nothing escapes the call.
unsafe {
let _: () =
objc::msg_send![encoder, memoryBarrierWithScope: MTL_BARRIER_SCOPE_BUFFERS];
}
}
/// Wrapper around one Metal command buffer plus a lazily-created concurrent
/// compute encoder, with optional capture/replay and hazard tracking.
pub struct CommandEncoder {
cmd_buf: CommandBuffer,
// Raw pointer to the current compute encoder; null when none is active.
// Set/cleared by `get_or_create_encoder` / `end_active_encoder`.
active_encoder: *const ComputeCommandEncoderRef,
// When `Some`, dispatches are recorded here instead of being encoded.
capture: Option<Vec<CapturedNode>>,
// Annotations consumed by the next dispatch (see `set_op_kind` and
// `set_pending_buffer_ranges`).
pending_op_kind: CapturedOpKind,
pending_reads: Vec<MemRange>,
pending_writes: Vec<MemRange>,
// Flushed just before each commit (see `flush_residency_pending`).
residency_set: Option<ResidencySet>,
// Byte ranges touched since the last barrier, for auto-barrier mode.
mem_ranges: MemRanges,
}
// SAFETY(review): the raw `active_encoder` pointer suppresses auto-Send.
// Sending is assumed sound because the pointer is only dereferenced through
// `&mut self` methods on the owning value — confirm against Metal's
// threading requirements for command encoding.
unsafe impl Send for CommandEncoder {}
impl CommandEncoder {
/// Create an encoder on `queue` with no residency set attached.
#[allow(dead_code)]
pub(crate) fn new(queue: &CommandQueue) -> Result<Self> {
Self::new_with_residency(queue, None)
}
/// Create an encoder on `queue`, optionally attaching a residency set that
/// is flushed before every commit.
pub(crate) fn new_with_residency(
    queue: &CommandQueue,
    residency_set: Option<ResidencySet>,
) -> Result<Self> {
    // MLX_UNRETAINED_REFS=1 opts into command buffers that skip per-resource
    // retain/release bookkeeping.
    let buffer = if unretained_refs_enabled() {
        queue.new_command_buffer_with_unretained_references().to_owned()
    } else {
        queue.new_command_buffer().to_owned()
    };
    CMD_BUF_COUNT.fetch_add(1, Ordering::Relaxed);
    Ok(Self {
        cmd_buf: buffer,
        active_encoder: std::ptr::null(),
        capture: None,
        pending_op_kind: CapturedOpKind::Other,
        pending_reads: Vec::new(),
        pending_writes: Vec::new(),
        residency_set,
        mem_ranges: MemRanges::new(),
    })
}
/// Begin capture mode: subsequent dispatches are recorded, not encoded.
pub fn start_capture(&mut self) {
self.capture = Some(Vec::with_capacity(128));
}
/// Whether dispatches are currently being recorded instead of encoded.
pub fn is_capturing(&self) -> bool {
self.capture.is_some()
}
/// End capture mode and return the recorded nodes, if capturing.
pub fn take_capture(&mut self) -> Option<Vec<CapturedNode>> {
self.capture.take()
}
/// Tag the next dispatch with an operation kind (used by capture analysis).
pub fn set_op_kind(&mut self, kind: CapturedOpKind) {
    self.pending_op_kind = kind;
}
/// Consume the pending op kind, resetting it to `Other` so it never leaks
/// into a later, unrelated dispatch.
fn take_pending_op_kind(&mut self) -> CapturedOpKind {
    std::mem::replace(&mut self.pending_op_kind, CapturedOpKind::Other)
}
/// Stage read/write byte ranges to attach to the next recorded dispatch.
pub fn set_pending_buffer_ranges(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
self.pending_reads = reads;
self.pending_writes = writes;
}
pub fn annotate_last_dispatch_if_missing(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
if let Some(ref mut nodes) = self.capture {
if let Some(CapturedNode::Dispatch { reads: r, writes: w, .. }) = nodes.last_mut() {
if r.is_empty() && !reads.is_empty() {
*r = reads;
}
if w.is_empty() && !writes.is_empty() {
*w = writes;
}
}
}
}
/// Take and clear the staged read/write range annotations.
fn take_pending_buffer_ranges(&mut self) -> (Vec<MemRange>, Vec<MemRange>) {
    (
        std::mem::take(&mut self.pending_reads),
        std::mem::take(&mut self.pending_writes),
    )
}
/// Snapshot plain buffer bindings into owned `RecordedBinding`s for replay.
fn record_buffer_bindings(buffers: &[(u64, &MlxBuffer)]) -> Vec<(u64, RecordedBinding)> {
    let mut recorded = Vec::with_capacity(buffers.len());
    for &(slot, buf) in buffers {
        recorded.push((
            slot,
            RecordedBinding::Buffer {
                metal_buffer: buf.metal_buffer().clone(),
                offset: buf.byte_offset(),
            },
        ));
    }
    recorded
}
/// Snapshot heterogeneous kernel arguments into owned `RecordedBinding`s;
/// inline byte payloads are copied so replay does not borrow the caller.
fn record_arg_bindings(bindings: &[(u64, KernelArg<'_>)]) -> Vec<(u64, RecordedBinding)> {
    let mut recorded = Vec::with_capacity(bindings.len());
    for (slot, arg) in bindings {
        let binding = match arg {
            KernelArg::Buffer(buf) => RecordedBinding::Buffer {
                metal_buffer: buf.metal_buffer().clone(),
                offset: buf.byte_offset(),
            },
            KernelArg::BufferWithOffset(buf, off) => RecordedBinding::Buffer {
                metal_buffer: buf.metal_buffer().clone(),
                offset: *off,
            },
            KernelArg::Bytes(data) => RecordedBinding::Bytes(data.to_vec()),
        };
        recorded.push((*slot, binding));
    }
    recorded
}
/// Return the active compute encoder, creating a concurrent-dispatch encoder
/// on the command buffer if none is active.
#[inline]
fn get_or_create_encoder(&mut self) -> &ComputeCommandEncoderRef {
if self.active_encoder.is_null() {
let encoder = self
.cmd_buf
.compute_command_encoder_with_dispatch_type(MTLDispatchType::Concurrent);
// Stored as a raw pointer so the reference can be handed out without a
// long-lived self-borrow; cleared in `end_active_encoder`.
self.active_encoder = encoder as *const ComputeCommandEncoderRef;
}
// SAFETY: the pointer is non-null here and refers to the encoder created
// from `cmd_buf`; it is only invalidated by `end_active_encoder`, which
// also nulls it.
unsafe { &*self.active_encoder }
}
/// Finish and clear the active encoder, if any; subsequent work will open a
/// fresh encoder on demand.
#[inline]
fn end_active_encoder(&mut self) {
if !self.active_encoder.is_null() {
// SAFETY: non-null implies the pointer still refers to the live encoder
// created in `get_or_create_encoder`.
unsafe { &*self.active_encoder }.end_encoding();
self.active_encoder = std::ptr::null();
}
}
#[allow(unexpected_cfgs)]
pub fn memory_barrier(&mut self) {
if let Some(ref mut nodes) = self.capture {
nodes.push(CapturedNode::Barrier);
return;
}
if self.active_encoder.is_null() {
return;
}
BARRIER_COUNT.fetch_add(1, Ordering::Relaxed);
let encoder = unsafe { &*self.active_encoder };
if barrier_profile_enabled() {
let start = std::time::Instant::now();
issue_metal_buffer_barrier(encoder);
let elapsed_ns = start.elapsed().as_nanos() as u64;
BARRIER_NS.fetch_add(elapsed_ns, Ordering::Relaxed);
} else {
issue_metal_buffer_barrier(encoder);
}
}
/// Set the compute pipeline on the active encoder (creating one if needed).
pub fn set_pipeline(&mut self, pipeline: &ComputePipelineStateRef) {
let encoder = self.get_or_create_encoder();
encoder.set_compute_pipeline_state(pipeline);
}
/// No-op: arguments are discarded. NOTE(review): appears to be a placeholder
/// kept for API compatibility — use the `encode*` methods to bind buffers.
pub fn set_buffer(&self, index: u64, buffer: &MlxBuffer) {
let _ = (index, buffer);
}
/// No-op: arguments are discarded. NOTE(review): appears to be a placeholder
/// — use `encode*`/`dispatch_tracked*` methods to actually dispatch work.
pub fn dispatch_threads(&self, grid_size: MTLSize, threadgroup_size: MTLSize) {
let _ = (grid_size, threadgroup_size);
}
pub fn encode(
&mut self,
pipeline: &ComputePipelineStateRef,
buffers: &[(u64, &MlxBuffer)],
grid_size: MTLSize,
threadgroup_size: MTLSize,
) {
DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
let op_kind = self.take_pending_op_kind();
let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
if let Some(ref mut nodes) = self.capture {
nodes.push(CapturedNode::Dispatch {
pipeline: pipeline.to_owned(),
bindings: Self::record_buffer_bindings(buffers),
threads_per_grid: grid_size,
threads_per_threadgroup: threadgroup_size,
threadgroup_memory: Vec::new(),
dispatch_kind: DispatchKind::Threads,
op_kind,
reads: pending_reads,
writes: pending_writes,
});
return;
}
let encoder = self.get_or_create_encoder();
encoder.set_compute_pipeline_state(pipeline);
for &(index, buf) in buffers {
encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
}
encoder.dispatch_threads(grid_size, threadgroup_size);
}
pub fn encode_threadgroups(
&mut self,
pipeline: &ComputePipelineStateRef,
buffers: &[(u64, &MlxBuffer)],
threadgroups: MTLSize,
threadgroup_size: MTLSize,
) {
DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
let op_kind = self.take_pending_op_kind();
let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
if let Some(ref mut nodes) = self.capture {
nodes.push(CapturedNode::Dispatch {
pipeline: pipeline.to_owned(),
bindings: Self::record_buffer_bindings(buffers),
threads_per_grid: threadgroups,
threads_per_threadgroup: threadgroup_size,
threadgroup_memory: Vec::new(),
dispatch_kind: DispatchKind::ThreadGroups,
op_kind,
reads: pending_reads,
writes: pending_writes,
});
return;
}
let encoder = self.get_or_create_encoder();
encoder.set_compute_pipeline_state(pipeline);
for &(index, buf) in buffers {
encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
}
encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
}
pub fn encode_threadgroups_with_shared(
&mut self,
pipeline: &ComputePipelineStateRef,
buffers: &[(u64, &MlxBuffer)],
threadgroup_mem: &[(u64, u64)],
threadgroups: MTLSize,
threadgroup_size: MTLSize,
) {
DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
let op_kind = self.take_pending_op_kind();
let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
if let Some(ref mut nodes) = self.capture {
nodes.push(CapturedNode::Dispatch {
pipeline: pipeline.to_owned(),
bindings: Self::record_buffer_bindings(buffers),
threads_per_grid: threadgroups,
threads_per_threadgroup: threadgroup_size,
threadgroup_memory: threadgroup_mem.to_vec(),
dispatch_kind: DispatchKind::ThreadGroups,
op_kind,
reads: pending_reads,
writes: pending_writes,
});
return;
}
let encoder = self.get_or_create_encoder();
encoder.set_compute_pipeline_state(pipeline);
for &(index, buf) in buffers {
encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
}
for &(index, byte_length) in threadgroup_mem {
encoder.set_threadgroup_memory_length(index, byte_length);
}
encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
}
pub fn encode_with_args(
&mut self,
pipeline: &ComputePipelineStateRef,
bindings: &[(u64, KernelArg<'_>)],
grid_size: MTLSize,
threadgroup_size: MTLSize,
) {
DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
let op_kind = self.take_pending_op_kind();
let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
if let Some(ref mut nodes) = self.capture {
nodes.push(CapturedNode::Dispatch {
pipeline: pipeline.to_owned(),
bindings: Self::record_arg_bindings(bindings),
threads_per_grid: grid_size,
threads_per_threadgroup: threadgroup_size,
threadgroup_memory: Vec::new(),
dispatch_kind: DispatchKind::Threads,
op_kind,
reads: pending_reads,
writes: pending_writes,
});
return;
}
let encoder = self.get_or_create_encoder();
encoder.set_compute_pipeline_state(pipeline);
apply_bindings(encoder, bindings);
encoder.dispatch_threads(grid_size, threadgroup_size);
}
pub fn encode_threadgroups_with_args(
&mut self,
pipeline: &ComputePipelineStateRef,
bindings: &[(u64, KernelArg<'_>)],
threadgroups: MTLSize,
threadgroup_size: MTLSize,
) {
DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
let op_kind = self.take_pending_op_kind();
let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
if let Some(ref mut nodes) = self.capture {
nodes.push(CapturedNode::Dispatch {
pipeline: pipeline.to_owned(),
bindings: Self::record_arg_bindings(bindings),
threads_per_grid: threadgroups,
threads_per_threadgroup: threadgroup_size,
threadgroup_memory: Vec::new(),
dispatch_kind: DispatchKind::ThreadGroups,
op_kind,
reads: pending_reads,
writes: pending_writes,
});
return;
}
let encoder = self.get_or_create_encoder();
encoder.set_compute_pipeline_state(pipeline);
apply_bindings(encoder, bindings);
encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
}
pub fn encode_threadgroups_with_args_and_shared(
&mut self,
pipeline: &ComputePipelineStateRef,
bindings: &[(u64, KernelArg<'_>)],
threadgroup_mem: &[(u64, u64)],
threadgroups: MTLSize,
threadgroup_size: MTLSize,
) {
DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
let op_kind = self.take_pending_op_kind();
let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
if let Some(ref mut nodes) = self.capture {
nodes.push(CapturedNode::Dispatch {
pipeline: pipeline.to_owned(),
bindings: Self::record_arg_bindings(bindings),
threads_per_grid: threadgroups,
threads_per_threadgroup: threadgroup_size,
threadgroup_memory: threadgroup_mem.to_vec(),
dispatch_kind: DispatchKind::ThreadGroups,
op_kind,
reads: pending_reads,
writes: pending_writes,
});
return;
}
let encoder = self.get_or_create_encoder();
encoder.set_compute_pipeline_state(pipeline);
apply_bindings(encoder, bindings);
for &(index, byte_length) in threadgroup_mem {
encoder.set_threadgroup_memory_length(index, byte_length);
}
encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
}
/// Tracked variant of `encode_threadgroups_with_args`: when capturing, the
/// recorded dispatch is annotated with `reads`/`writes` byte ranges; when
/// live and `HF2Q_AUTO_BARRIER=1`, a barrier is auto-inserted on hazard.
pub fn dispatch_tracked_threadgroups_with_args(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, KernelArg<'_>)],
    reads: &[&MlxBuffer],
    writes: &[&MlxBuffer],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
) {
    if self.is_capturing() {
        self.set_pending_buffer_ranges(ranges_from_buffers(reads), ranges_from_buffers(writes));
    } else if auto_barrier_enabled() {
        self.maybe_auto_barrier(reads, writes);
    }
    self.encode_threadgroups_with_args(pipeline, bindings, threadgroups, threadgroup_size);
}
/// Tracked variant of `encode_threadgroups_with_args_and_shared`: annotates
/// captured dispatches with byte ranges, or auto-barriers live dispatches on
/// hazard when `HF2Q_AUTO_BARRIER=1`.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_tracked_threadgroups_with_args_and_shared(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, KernelArg<'_>)],
    threadgroup_mem: &[(u64, u64)],
    reads: &[&MlxBuffer],
    writes: &[&MlxBuffer],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
) {
    if self.is_capturing() {
        self.set_pending_buffer_ranges(ranges_from_buffers(reads), ranges_from_buffers(writes));
    } else if auto_barrier_enabled() {
        self.maybe_auto_barrier(reads, writes);
    }
    self.encode_threadgroups_with_args_and_shared(
        pipeline,
        bindings,
        threadgroup_mem,
        threadgroups,
        threadgroup_size,
    );
}
/// Tracked variant of `encode_threadgroups`: annotates captured dispatches
/// with byte ranges, or auto-barriers live dispatches on hazard when
/// `HF2Q_AUTO_BARRIER=1`.
pub fn dispatch_tracked_threadgroups(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    buffers: &[(u64, &MlxBuffer)],
    reads: &[&MlxBuffer],
    writes: &[&MlxBuffer],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
) {
    if self.is_capturing() {
        self.set_pending_buffer_ranges(ranges_from_buffers(reads), ranges_from_buffers(writes));
    } else if auto_barrier_enabled() {
        self.maybe_auto_barrier(reads, writes);
    }
    self.encode_threadgroups(pipeline, buffers, threadgroups, threadgroup_size);
}
/// Tracked variant of `encode_threadgroups_with_shared`: annotates captured
/// dispatches with byte ranges, or auto-barriers live dispatches on hazard
/// when `HF2Q_AUTO_BARRIER=1`.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_tracked_threadgroups_with_shared(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    buffers: &[(u64, &MlxBuffer)],
    threadgroup_mem: &[(u64, u64)],
    reads: &[&MlxBuffer],
    writes: &[&MlxBuffer],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
) {
    if self.is_capturing() {
        self.set_pending_buffer_ranges(ranges_from_buffers(reads), ranges_from_buffers(writes));
    } else if auto_barrier_enabled() {
        self.maybe_auto_barrier(reads, writes);
    }
    self.encode_threadgroups_with_shared(
        pipeline,
        buffers,
        threadgroup_mem,
        threadgroups,
        threadgroup_size,
    );
}
/// Tracked variant of `encode_with_args` (thread granularity): annotates
/// captured dispatches with byte ranges, or auto-barriers live dispatches on
/// hazard when `HF2Q_AUTO_BARRIER=1`.
pub fn dispatch_tracked_threads_with_args(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, KernelArg<'_>)],
    reads: &[&MlxBuffer],
    writes: &[&MlxBuffer],
    grid_size: MTLSize,
    threadgroup_size: MTLSize,
) {
    if self.is_capturing() {
        self.set_pending_buffer_ranges(ranges_from_buffers(reads), ranges_from_buffers(writes));
    } else if auto_barrier_enabled() {
        self.maybe_auto_barrier(reads, writes);
    }
    self.encode_with_args(pipeline, bindings, grid_size, threadgroup_size);
}
/// Decide whether the next dispatch may run concurrently with those already
/// tracked; on a detected hazard, emit a barrier and restart the tracker.
fn maybe_auto_barrier(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
    if self.mem_ranges.check_dispatch(reads, writes) {
        // No conflict: the dispatch joins the current concurrent window.
        AUTO_BARRIER_CONCURRENT.fetch_add(1, Ordering::Relaxed);
    } else {
        // Conflict: order it behind everything tracked so far.
        self.memory_barrier();
        self.mem_ranges.reset();
        AUTO_BARRIER_COUNT.fetch_add(1, Ordering::Relaxed);
    }
    // Either way, the dispatch's ranges become part of the tracked window.
    self.mem_ranges.add_dispatch(reads, writes);
}
/// Unconditionally emit a barrier; when auto-barrier mode is on, also restart
/// the hazard tracker so the next dispatch starts a fresh concurrent window.
pub fn force_barrier_and_reset_tracker(&mut self) {
self.memory_barrier();
if auto_barrier_enabled() {
self.mem_ranges.reset();
}
}
/// Number of entries currently held by the auto-barrier hazard tracker.
#[inline]
pub fn mem_ranges_len(&self) -> usize {
self.mem_ranges.len()
}
/// Re-encode a previously captured dispatch from its recorded state.
pub fn replay_dispatch(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, RecordedBinding)],
    threadgroup_memory: &[(u64, u64)],
    threads_per_grid: MTLSize,
    threads_per_threadgroup: MTLSize,
    dispatch_kind: DispatchKind,
) {
    let encoder = self.get_or_create_encoder();
    encoder.set_compute_pipeline_state(pipeline);
    for (slot, binding) in bindings {
        match binding {
            RecordedBinding::Buffer { metal_buffer, offset } => {
                encoder.set_buffer(*slot, Some(metal_buffer), *offset)
            }
            RecordedBinding::Bytes(data) => {
                encoder.set_bytes(*slot, data.len() as u64, data.as_ptr() as *const _)
            }
        }
    }
    for &(slot, len) in threadgroup_memory {
        encoder.set_threadgroup_memory_length(slot, len);
    }
    // `threads_per_grid` is interpreted per the recorded dispatch kind.
    match dispatch_kind {
        DispatchKind::Threads => {
            encoder.dispatch_threads(threads_per_grid, threads_per_threadgroup)
        }
        DispatchKind::ThreadGroups => {
            encoder.dispatch_thread_groups(threads_per_grid, threads_per_threadgroup)
        }
    }
}
/// Flush pending residency-set additions, if a residency set is attached;
/// called just before each commit (see `commit` / `commit_and_wait`).
#[inline]
fn flush_residency_pending(&self) {
if let Some(set) = self.residency_set.as_ref() {
set.flush_pending();
}
}
pub fn commit_and_wait(&mut self) -> Result<()> {
SYNC_COUNT.fetch_add(1, Ordering::Relaxed);
self.end_active_encoder();
self.flush_residency_pending();
self.cmd_buf.commit();
self.cmd_buf.wait_until_completed();
match self.cmd_buf.status() {
MTLCommandBufferStatus::Completed => Ok(()),
MTLCommandBufferStatus::Error => {
Err(MlxError::CommandBufferError(
"GPU command buffer completed with error status".into(),
))
}
status => Err(MlxError::CommandBufferError(format!(
"Unexpected command buffer status after wait: {:?}",
status
))),
}
}
/// Like `commit_and_wait`, but labels the command buffer/encoder first and,
/// when kernel profiling is enabled, records the GPU time under `label`.
pub fn commit_and_wait_labeled(&mut self, label: &str) -> Result<()> {
    self.apply_labels(label);
    if !crate::kernel_profile::is_enabled() {
        return self.commit_and_wait();
    }
    let (start_s, end_s) = self.commit_wait_with_gpu_time()?;
    // GPU timestamps are in seconds; clamp to zero before converting to ns.
    let ns = ((end_s - start_s).max(0.0) * 1_000_000_000.0) as u64;
    crate::kernel_profile::record(label, ns);
    Ok(())
}
/// Commit without waiting — unless kernel profiling is enabled, in which case
/// reading GPU timestamps requires a wait, so that path synchronizes and
/// logs (rather than propagates) any failure.
pub fn commit_labeled(&mut self, label: &str) {
    if !crate::kernel_profile::is_enabled() {
        self.apply_labels(label);
        self.commit();
        return;
    }
    if let Err(e) = self.commit_and_wait_labeled(label) {
        eprintln!("[mlx-native] commit_labeled({}) failed: {}", label, e);
    }
}
/// Label the command buffer (and active encoder, if any) for GPU debuggers
/// and profilers. Empty labels assert in debug builds and are skipped in
/// release builds.
#[inline]
fn apply_labels(&self, label: &str) {
debug_assert!(!label.is_empty(), "commit_*_labeled called with empty label");
if label.is_empty() {
return;
}
self.cmd_buf.set_label(label);
if !self.active_encoder.is_null() {
// SAFETY: non-null implies the encoder created in
// `get_or_create_encoder` is still live.
unsafe { &*self.active_encoder }.set_label(label);
}
}
/// Commit, wait, then return the command buffer's `(GPUStartTime, GPUEndTime)`
/// timestamps in seconds, for kernel profiling.
pub fn commit_wait_with_gpu_time(&mut self) -> Result<(f64, f64)> {
self.commit_and_wait()?;
// SAFETY: the command buffer is a valid Objective-C object and, per the
// Metal API, GPUStartTime/GPUEndTime are CFTimeInterval (f64) properties
// readable after completion.
let (gpu_start, gpu_end): (f64, f64) = unsafe {
let cb = &*self.cmd_buf;
let s: f64 = msg_send![cb, GPUStartTime];
let e: f64 = msg_send![cb, GPUEndTime];
(s, e)
};
Ok((gpu_start, gpu_end))
}
/// End encoding, flush residency, and commit without waiting for completion.
pub fn commit(&mut self) {
self.end_active_encoder();
self.flush_residency_pending();
self.cmd_buf.commit();
}
/// Block until the already-committed command buffer finishes.
///
/// # Errors
/// `MlxError::CommandBufferError` on an error status or any status other
/// than `Completed` after the wait.
pub fn wait_until_completed(&self) -> Result<()> {
    self.cmd_buf.wait_until_completed();
    match self.cmd_buf.status() {
        MTLCommandBufferStatus::Completed => Ok(()),
        MTLCommandBufferStatus::Error => Err(MlxError::CommandBufferError(
            "GPU command buffer completed with error status".into(),
        )),
        other => Err(MlxError::CommandBufferError(format!(
            "Unexpected command buffer status after wait: {:?}",
            other
        ))),
    }
}
/// Borrow the underlying Metal command buffer.
#[inline]
pub fn metal_command_buffer(&self) -> &CommandBuffer {
&self.cmd_buf
}
}
impl Drop for CommandEncoder {
// Ensure any open encoder is ended before the command buffer is released.
// Note this does NOT commit pending work — uncommitted commands are dropped.
fn drop(&mut self) {
self.end_active_encoder();
}
}