pub struct CommandEncoder { /* private fields */ }
A batched compute command encoder.
Keeps a single Metal ComputeCommandEncoder alive across multiple
dispatches. The encoder is created on the first dispatch and ended
only when the command buffer is committed. This mirrors candle’s
compute_per_buffer pattern and avoids per-dispatch encoder overhead.
§Typical usage
let mut enc = device.command_encoder()?;
// Multiple dispatches share the same compute encoder:
enc.encode_threadgroups(pipeline1, &buffers1, tg1, tg_size1);
enc.encode_threadgroups(pipeline2, &buffers2, tg2, tg_size2);
enc.commit_and_wait()?;
Implementations§
impl CommandEncoder
pub fn start_capture(&mut self)
Enable capture mode.
All subsequent dispatch and barrier calls will be recorded into a
Vec<CapturedNode> instead of being encoded into Metal.
Call take_capture() to extract the recorded nodes.
pub fn is_capturing(&self) -> bool
Whether the encoder is currently in capture mode.
pub fn take_capture(&mut self) -> Option<Vec<CapturedNode>>
Extract the captured nodes, ending capture mode.
Returns None if capture mode was not active.
pub fn set_op_kind(&mut self, kind: CapturedOpKind)
Tag the NEXT captured dispatch with the given operation kind.
The tag is consumed (reset to Other) after the next dispatch is
captured. Only meaningful in capture mode — has no effect on
direct-dispatch encoding.
Used by op dispatch functions to annotate captures for the fusion pass (Phase 4e.2).
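The consume-once tag semantics can be sketched in plain Rust. The type and field names below are illustrative stand-ins, not the crate's actual internals:

```rust
// Sketch: the kind set by set_op_kind applies to exactly one subsequent
// captured dispatch, then resets to Other. `TagState` is hypothetical.
#[derive(Clone, Copy, Debug, PartialEq)]
enum CapturedOpKind { Other, RmsNorm, Softmax }

struct TagState { pending: CapturedOpKind }

impl TagState {
    fn new() -> Self { TagState { pending: CapturedOpKind::Other } }
    fn set_op_kind(&mut self, kind: CapturedOpKind) { self.pending = kind; }
    // Called once per captured dispatch: returns the tag and resets it.
    fn consume(&mut self) -> CapturedOpKind {
        std::mem::replace(&mut self.pending, CapturedOpKind::Other)
    }
}

fn main() {
    let mut tags = TagState::new();
    tags.set_op_kind(CapturedOpKind::RmsNorm);
    assert_eq!(tags.consume(), CapturedOpKind::RmsNorm); // first dispatch gets the tag
    assert_eq!(tags.consume(), CapturedOpKind::Other);   // next dispatch is untagged
    println!("ok");
}
```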
pub fn set_pending_buffer_ranges(
    &mut self,
    reads: Vec<(usize, usize)>,
    writes: Vec<(usize, usize)>,
)
Stash buffer range annotations for the NEXT captured dispatch.
Called by GraphSession::barrier_between() in capture mode to record
which buffers the next dispatch reads from and writes to. The ranges
are consumed by the next encode_* call and attached to the captured
CapturedNode::Dispatch.
Only meaningful in capture mode — has no effect on direct-dispatch.
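The stash-then-consume contract can be sketched with a plain `Option` take, assuming a pending slot shaped roughly like this (names are illustrative, not the crate's internals):

```rust
// Sketch: ranges stashed by set_pending_buffer_ranges are taken by the
// next encode_* call, leaving nothing for the dispatch after that.
struct Pending {
    ranges: Option<(Vec<(usize, usize)>, Vec<(usize, usize)>)>,
}

impl Pending {
    fn set_pending_buffer_ranges(
        &mut self,
        reads: Vec<(usize, usize)>,
        writes: Vec<(usize, usize)>,
    ) {
        self.ranges = Some((reads, writes));
    }
    // Each encode_* call takes the stash, leaving None for the next dispatch.
    fn consume(&mut self) -> (Vec<(usize, usize)>, Vec<(usize, usize)>) {
        self.ranges.take().unwrap_or_default()
    }
}

fn main() {
    let mut p = Pending { ranges: None };
    p.set_pending_buffer_ranges(vec![(0, 1024)], vec![(1, 4096)]);
    assert_eq!(p.consume(), (vec![(0, 1024)], vec![(1, 4096)])); // attached to next dispatch
    assert_eq!(p.consume(), (vec![], vec![])); // already consumed
    println!("ok");
}
```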
pub fn annotate_last_dispatch_if_missing(
    &mut self,
    reads: Vec<(usize, usize)>,
    writes: Vec<(usize, usize)>,
)
Patch the last captured dispatch node’s empty reads/writes with the given ranges. No-op if not capturing, or if the last node isn’t a Dispatch, or if its ranges are already populated.
Used by GraphSession::track_dispatch in recording mode to annotate
dispatches that were called without a preceding barrier_between.
pub fn memory_barrier(&mut self)
Insert a memory barrier with scope MTLBarrierScopeBuffers.
When the encoder uses MTLDispatchTypeConcurrent, all dispatches can
execute concurrently unless separated by a barrier. Call this between
dispatches where the later dispatch reads a buffer written by an
earlier one.
This is the same pattern llama.cpp uses:
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]
pub fn set_pipeline(&mut self, pipeline: &ComputePipelineStateRef)
Set the compute pipeline state for subsequent dispatches.
This begins a new compute pass if one is not already active.
pub fn set_buffer(&self, index: u64, buffer: &MlxBuffer)
Bind a buffer to a compute kernel argument slot.
The index corresponds to the [[buffer(N)]] attribute in the MSL shader.
pub fn dispatch_threads(&self, grid_size: MTLSize, threadgroup_size: MTLSize)
Dispatch threads on the GPU.
pub fn encode(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    buffers: &[(u64, &MlxBuffer)],
    grid_size: MTLSize,
    threadgroup_size: MTLSize,
)
Encode a complete compute pass: set pipeline, bind buffers, dispatch.
Reuses the persistent compute encoder — no per-dispatch encoder creation overhead.
§Arguments
- pipeline — The compiled compute pipeline to execute.
- buffers — Slice of (index, &MlxBuffer) pairs for buffer bindings.
- grid_size — Total number of threads to launch.
- threadgroup_size — Threads per threadgroup.
pub fn encode_threadgroups(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    buffers: &[(u64, &MlxBuffer)],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
)
Encode a compute pass using threadgroups instead of raw thread counts.
Reuses the persistent compute encoder — no per-dispatch encoder creation overhead.
pub fn encode_threadgroups_with_shared(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    buffers: &[(u64, &MlxBuffer)],
    threadgroup_mem: &[(u64, u64)],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
)
Encode a compute pass using threadgroups with shared threadgroup memory.
Like encode_threadgroups, but additionally
allocates threadgroup memory at the specified indices. This is required
for kernels that use threadgroup memory (e.g. reductions in rms_norm
and softmax).
§Arguments
- pipeline — The compiled compute pipeline to execute.
- buffers — Slice of (index, &MlxBuffer) pairs for buffer bindings.
- threadgroup_mem — Slice of (index, byte_length) pairs for threadgroup memory.
- threadgroups — Number of threadgroups to dispatch.
- threadgroup_size — Threads per threadgroup.
pub fn encode_with_args(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, KernelArg<'_>)],
    grid_size: MTLSize,
    threadgroup_size: MTLSize,
)
Encode a dispatch with mixed buffer/bytes bindings (dispatch_threads).
Reuses the persistent compute encoder.
pub fn encode_threadgroups_with_args(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, KernelArg<'_>)],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
)
Encode a dispatch with mixed buffer/bytes bindings (dispatch_thread_groups).
Reuses the persistent compute encoder.
pub fn encode_threadgroups_with_args_and_shared(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, KernelArg<'_>)],
    threadgroup_mem: &[(u64, u64)],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
)
Encode a dispatch with mixed buffer/bytes bindings and shared threadgroup memory.
Reuses the persistent compute encoder.
pub fn dispatch_tracked_threadgroups_with_args(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, KernelArg<'_>)],
    reads: &[&MlxBuffer],
    writes: &[&MlxBuffer],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
)
Auto-barrier-aware dispatch with KernelArg bindings (uses
dispatch_thread_groups).
Behaves identically to
encode_threadgroups_with_args
when HF2Q_AUTO_BARRIER is unset. When set, consults the
per-encoder MemRanges tracker:
- Conflict (RAW/WAR/WAW on a same-buffer range) → emit memory_barrier(), increment AUTO_BARRIER_COUNT, reset the tracker, then dispatch and seed the new concurrent group with this dispatch’s ranges.
- No conflict → increment AUTO_BARRIER_CONCURRENT, record the ranges into the cumulative state, then dispatch.
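The conflict test behind this decision can be sketched in plain Rust. The tracker below mirrors only the shape described above (the real MemRanges layout is internal to the crate, and buffers are identified here by a plain id rather than a Metal handle):

```rust
// Sketch of RAW/WAR/WAW conflict detection over (buffer_id, byte_range)
// entries for the current concurrent dispatch group.
#[derive(Default)]
struct MemRanges {
    reads: Vec<(usize, std::ops::Range<usize>)>,  // (buffer_id, byte range)
    writes: Vec<(usize, std::ops::Range<usize>)>,
}

fn overlaps(a: &std::ops::Range<usize>, b: &std::ops::Range<usize>) -> bool {
    a.start < b.end && b.start < a.end
}

impl MemRanges {
    // True if the incoming dispatch conflicts with the current group:
    // RAW (new read vs old write), WAR (new write vs old read), WAW.
    fn conflicts(
        &self,
        reads: &[(usize, std::ops::Range<usize>)],
        writes: &[(usize, std::ops::Range<usize>)],
    ) -> bool {
        let hit = |xs: &[(usize, std::ops::Range<usize>)],
                   ys: &[(usize, std::ops::Range<usize>)]| {
            xs.iter().any(|(b, r)| ys.iter().any(|(b2, r2)| b == b2 && overlaps(r, r2)))
        };
        hit(reads, &self.writes[..]) || hit(writes, &self.reads[..]) || hit(writes, &self.writes[..])
    }
    fn record(
        &mut self,
        reads: &[(usize, std::ops::Range<usize>)],
        writes: &[(usize, std::ops::Range<usize>)],
    ) {
        self.reads.extend_from_slice(reads);
        self.writes.extend_from_slice(writes);
    }
    fn reset(&mut self) { self.reads.clear(); self.writes.clear(); }
}

fn main() {
    let mut t = MemRanges::default();
    // Dispatch A writes buffer 0, bytes 0..1024.
    t.record(&[], &[(0, 0..1024)]);
    // Dispatch B reads the same range: RAW conflict → barrier, reset, reseed.
    let b_reads = [(0usize, 0..1024)];
    assert!(t.conflicts(&b_reads, &[]));
    t.reset();
    t.record(&b_reads, &[]);
    // Dispatch C touches a different buffer: no conflict, stays concurrent.
    assert!(!t.conflicts(&[(1, 0..64)], &[]));
    println!("ok");
}
```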
pub fn dispatch_tracked_threadgroups_with_args_and_shared(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, KernelArg<'_>)],
    reads: &[&MlxBuffer],
    writes: &[&MlxBuffer],
    threadgroup_mem: &[(u64, u64)],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
)
Auto-barrier-aware dispatch with KernelArg bindings plus shared threadgroup memory.
See dispatch_tracked_threadgroups_with_args
for the behavioral contract; this variant additionally takes a
threadgroup_mem slice that is forwarded to
encode_threadgroups_with_args_and_shared.
The 8-argument signature mirrors the existing
encode_threadgroups_with_args_and_shared plus the two
dataflow slices; clippy::too_many_arguments is allowed
because each parameter is load-bearing for either the dispatch
(pipeline/bindings/threadgroups/threadgroup_size/shared_mem)
or the auto-barrier (reads/writes).
pub fn dispatch_tracked_threadgroups(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    buffers: &[(u64, &MlxBuffer)],
    reads: &[&MlxBuffer],
    writes: &[&MlxBuffer],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
)
Auto-barrier-aware dispatch using (slot, &MlxBuffer) bindings
(uses dispatch_thread_groups).
Convenience wrapper for callers that don’t need
KernelArg::Bytes inline-byte arguments. See
dispatch_tracked_threadgroups_with_args
for behavioral contract.
pub fn dispatch_tracked_threadgroups_with_shared(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    buffers: &[(u64, &MlxBuffer)],
    reads: &[&MlxBuffer],
    writes: &[&MlxBuffer],
    threadgroup_mem: &[(u64, u64)],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
)
Auto-barrier-aware dispatch using (slot, &MlxBuffer) bindings plus shared threadgroup memory (uses dispatch_thread_groups).
Mirrors encode_threadgroups_with_shared
— convenience variant for kernels that allocate threadgroup
memory (reductions in rms_norm, softmax, etc.) but don’t
need KernelArg::Bytes inline-byte arguments. See
dispatch_tracked_threadgroups_with_args
for the behavioral contract; the only addition here is the
threadgroup_mem slice forwarded to the underlying encode.
Closes the iter38-audit coverage gap: the 5 rms_norm.rs callsites
(/opt/mlx-native/src/ops/rms_norm.rs:124, 236, 443, 516, 589) all use
encode_threadgroups_with_shared and need dataflow tracking when migrated
to auto-barrier in iter40+.
7-argument signature; clippy::too_many_arguments is allowed
because each parameter is load-bearing for either the dispatch
(pipeline/buffers/threadgroups/threadgroup_size/shared_mem) or
the auto-barrier (reads/writes).
pub fn dispatch_tracked_threads_with_args(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, KernelArg<'_>)],
    reads: &[&MlxBuffer],
    writes: &[&MlxBuffer],
    grid_size: MTLSize,
    threadgroup_size: MTLSize,
)
Auto-barrier-aware dispatch_threads variant with
KernelArg bindings.
Mirrors encode_with_args — the
dispatch_threads (per-thread grid) flavor, as opposed to the
dispatch_thread_groups flavor of
dispatch_tracked_threadgroups_with_args.
See that method for the behavioral contract.
Closes the iter38-audit coverage gap: callers that use
per-thread grids — rope.rs:108 (IMROPE), sigmoid_mul.rs:76
(sigmoid-mul), and encode_helpers.rs:41 (kv_cache_copy) —
need a dispatch_threads flavor of the tracked dispatch
because their grid sizes are expressed in threads, not
threadgroups.
Note: the simpler (slot, &MlxBuffer) form (from
encode) is a special case of this method —
callers can wrap each binding as KernelArg::Buffer(buf) to
reuse this single tracked variant rather than introducing a
fifth one.
pub fn dispatch_record(
    &mut self,
    rec: &DispatchRecord,
    runtime_buffers: &[&MlxBuffer],
)
Dispatch a pre-baked record.
ADR-029 iter-175 Step 1d — fast path for decode hot kernels
whose pipeline + threadgroup geometry + params bytes are
load-time-immutable. runtime_buffers must be in the same
order as rec.buffer_slots.
Equivalent Metal command stream to:
encoder.encode_threadgroups_with_args_and_shared(
    &rec.pipeline,
    bindings, // = runtime_buffers zipped with buffer_slots + (params_slot, Bytes(&rec.params_bytes))
    &rec.threadgroup_mem,
    rec.threadgroups,
    rec.threads_per_tg,
);
— but skips the kernel-name lookup, ggml_type match arms, MTLSize::new calls, and param-struct field stores that the unbaked path performs on every call.
Capture mode and auto-barrier are supported identically to
encode_threadgroups_with_args_and_shared. The caller is
expected to have called set_pending_buffer_ranges (capture)
or rely on auto-barrier for dataflow correctness before this
call, matching the contract of the unbaked dispatch_tracked_*
family.
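The per-call work the baked path leaves behind can be illustrated with a pure-Rust sketch. DispatchRecord's real fields are internal to the crate; the struct below mirrors only the shape described above, with buffers stood in by plain ids and MtlSize standing in for MTLSize:

```rust
// Sketch: a pre-baked record freezes everything that is load-time-immutable
// (pipeline handle, params bytes, geometry), so each call only has to zip
// runtime buffers against the recorded argument slots.
#[allow(dead_code)]
struct MtlSize { w: u64, h: u64, d: u64 }

#[allow(dead_code)]
struct DispatchRecord {
    pipeline_id: usize,      // stands in for the compiled pipeline handle
    buffer_slots: Vec<u64>,  // argument-table slots for the runtime buffers
    params_slot: u64,
    params_bytes: Vec<u8>,   // pre-serialized param struct, baked at load time
    threadgroups: MtlSize,
    threads_per_tg: MtlSize,
}

impl DispatchRecord {
    // Per-call work: pair each runtime buffer with its baked slot, in order.
    fn bind(&self, runtime_buffers: &[usize]) -> Vec<(u64, usize)> {
        assert_eq!(
            runtime_buffers.len(), self.buffer_slots.len(),
            "runtime_buffers must match buffer_slots order and length"
        );
        self.buffer_slots.iter().copied()
            .zip(runtime_buffers.iter().copied())
            .collect()
    }
}

fn main() {
    let rec = DispatchRecord {
        pipeline_id: 7,
        buffer_slots: vec![0, 1, 3],
        params_slot: 2,
        params_bytes: vec![16, 0, 0, 0],
        threadgroups: MtlSize { w: 32, h: 1, d: 1 },
        threads_per_tg: MtlSize { w: 256, h: 1, d: 1 },
    };
    let bindings = rec.bind(&[100, 101, 102]);
    assert_eq!(bindings, vec![(0, 100), (1, 101), (3, 102)]);
    println!("ok");
}
```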
pub fn force_barrier_and_reset_tracker(&mut self)
Force a barrier and reset the auto-barrier tracker.
Use at boundaries where the caller knows a barrier is required regardless of dataflow — typically before reading data back to CPU, or at the end of an op group whose internal dependencies the tracker can’t see (e.g. host-driven memcpy).
Equivalent to memory_barrier() plus a MemRanges::reset()
when HF2Q_AUTO_BARRIER=1; equivalent to plain
memory_barrier() otherwise.
pub fn mem_ranges_len(&self) -> usize
Diagnostic accessor — number of ranges currently recorded in
this encoder’s MemRanges tracker. Always zero unless
HF2Q_AUTO_BARRIER=1 and at least one dispatch_tracked call
has fired since the last conflict.
pub fn replay_dispatch(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, RecordedBinding)],
    threadgroup_memory: &[(u64, u64)],
    threads_per_grid: MTLSize,
    threads_per_threadgroup: MTLSize,
    dispatch_kind: DispatchKind,
)
Replay a single captured dispatch node into this encoder.
This is the inverse of capture: it takes a previously recorded
CapturedNode::Dispatch and encodes it into the live Metal encoder.
Barrier nodes are handled by the caller (ComputeGraph::encode_sequential).
Does NOT increment DISPATCH_COUNT — that was already counted at
capture time.
pub fn commit_and_wait(&mut self) -> Result<()>
Commit the command buffer and block until the GPU finishes execution.
§Errors
Returns MlxError::CommandBufferError if the GPU reports an error.
pub fn commit_and_wait_labeled(&mut self, label: &str) -> Result<()>
Commit + wait, accumulating GPU wall-clock time under label into
the crate::kernel_profile global table when MLX_PROFILE_CB=1
is set. When the env var is unset, this is identical to
commit_and_wait — zero overhead.
Used by hf2q’s decode hot path to attribute per-cb GPU time to
labeled phases (per-layer attn, per-layer ffn, output_head, etc.)
without manually wiring commit_wait_with_gpu_time everywhere.
§Errors
Returns MlxError::CommandBufferError if the GPU reports an error.
pub fn commit_labeled(&mut self, label: &str)
Async commit, but with a profiling label. When MLX_PROFILE_CB=1
is set, redirects to a synchronous commit_and_wait_labeled
call to capture per-cb GPU time (this defeats async pipelining
while profiling, which is the whole point — profile mode is slow
but informative). When unset, identical to commit.
pub fn commit_wait_with_gpu_time(&mut self) -> Result<(f64, f64)>
Commit + wait, returning (gpu_start_s, gpu_end_s) CFTimeInterval
timestamps from MTLCommandBuffer’s GPUStartTime/GPUEndTime
properties. Both are mach-absolute CFTimeInterval seconds (double).
Intended for HF2Q_PROFILE_GPU_TS=1 per-bucket GPU wall-clock
attribution. Adds exactly two ObjC property reads per call on top
of the regular commit_and_wait — measured well under 1 μs on
M5 Max.
§Errors
Returns MlxError::CommandBufferError if the GPU reports an error.
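Turning the returned timestamp pair into per-bucket wall-clock totals is a small amount of bookkeeping; a sketch (the tuple shape matches the doc above, the bucket table and function name are illustrative):

```rust
// GPUStartTime/GPUEndTime are CFTimeInterval seconds; per-bucket attribution
// accumulates (end - start) per label, converted to milliseconds here.
use std::collections::HashMap;

fn accumulate(buckets: &mut HashMap<String, f64>, label: &str, gpu_start_s: f64, gpu_end_s: f64) {
    *buckets.entry(label.to_string()).or_insert(0.0) += (gpu_end_s - gpu_start_s) * 1e3; // ms
}

fn main() {
    let mut buckets = HashMap::new();
    // Two command buffers attributed to the same phase label:
    accumulate(&mut buckets, "attn", 10.0000, 10.0021); // 2.1 ms
    accumulate(&mut buckets, "attn", 10.0050, 10.0062); // 1.2 ms
    assert!((buckets["attn"] - 3.3).abs() < 1e-6);
    println!("ok");
}
```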
pub fn commit(&mut self)
Commit the command buffer WITHOUT blocking.
The GPU begins executing the encoded commands immediately. Call
wait_until_completed later to block
the CPU and check for errors. This allows the CPU to continue doing
other work (e.g. preparing the next batch) while the GPU runs.
pub fn wait_until_completed(&self) -> Result<()>
Block until a previously committed command buffer completes.
Must be called after commit. Do not call after
commit_and_wait — that method already waits.
§Errors
Returns MlxError::CommandBufferError if the GPU reports an error.
pub fn metal_command_buffer(&self) -> &CommandBuffer
Borrow the underlying Metal command buffer.
Trait Implementations§
impl Drop for CommandEncoder
impl Send for CommandEncoder
SAFETY: CommandEncoder is safe to Send across threads provided that only
one thread accesses the encoder at a time (exclusive ownership, with no
concurrent use from multiple threads).
Metal command buffers and compute encoders are thread-safe for exclusive
access (Apple documentation: “You can create command buffers, encode
commands, and submit them from any thread”). The raw pointer
active_encoder borrows from cmd_buf and is valid as long as
cmd_buf is alive — this invariant holds across thread boundaries
because both fields move together.
This matches llama.cpp’s pattern of encoding command buffers on GCD
worker threads via dispatch_apply, and is used for the dual-buffer
pipeline where buf1 is encoded on a worker thread while buf0 executes.