pub struct CommandEncoder { /* private fields */ }
A batched compute command encoder.
Keeps a single Metal ComputeCommandEncoder alive across multiple
dispatches. The encoder is created on the first dispatch and ended
only when the command buffer is committed. This mirrors candle’s
compute_per_buffer pattern and avoids per-dispatch encoder overhead.
§Typical usage
let mut enc = device.command_encoder()?;
// Multiple dispatches share the same compute encoder:
enc.encode_threadgroups(pipeline1, &buffers1, tg1, tg_size1);
enc.encode_threadgroups(pipeline2, &buffers2, tg2, tg_size2);
enc.commit_and_wait()?;

Implementations§
impl CommandEncoder
pub fn start_capture(&mut self)
Enable capture mode.
All subsequent dispatch and barrier calls will be recorded into a
Vec<CapturedNode> instead of being encoded into Metal.
Call take_capture() to extract the recorded nodes.
pub fn is_capturing(&self) -> bool
Whether the encoder is currently in capture mode.
pub fn take_capture(&mut self) -> Option<Vec<CapturedNode>>
Extract the captured nodes, ending capture mode.
Returns None if capture mode was not active.
pub fn set_op_kind(&mut self, kind: CapturedOpKind)
Tag the NEXT captured dispatch with the given operation kind.
The tag is consumed (reset to Other) after the next dispatch is
captured. Only meaningful in capture mode — has no effect on
direct-dispatch encoding.
Used by op dispatch functions to annotate captures for the fusion pass (Phase 4e.2).
pub fn set_pending_buffer_ranges(
    &mut self,
    reads: Vec<(usize, usize)>,
    writes: Vec<(usize, usize)>,
)
Stash buffer range annotations for the NEXT captured dispatch.
Called by GraphSession::barrier_between() in capture mode to record
which buffers the next dispatch reads from and writes to. The ranges
are consumed by the next encode_* call and attached to the captured
CapturedNode::Dispatch.
Only meaningful in capture mode — has no effect on direct-dispatch.
pub fn annotate_last_dispatch_if_missing(
    &mut self,
    reads: Vec<(usize, usize)>,
    writes: Vec<(usize, usize)>,
)
Patch the last captured dispatch node’s empty reads/writes with the given ranges. No-op if not capturing, or if the last node isn’t a Dispatch, or if its ranges are already populated.
Used by GraphSession::track_dispatch in recording mode to annotate
dispatches that were called without a preceding barrier_between.
pub fn memory_barrier(&mut self)
Insert a memory barrier with scope MTLBarrierScopeBuffers.
When the encoder uses MTLDispatchTypeConcurrent, all dispatches can
execute concurrently unless separated by a barrier. Call this between
dispatches where the later dispatch reads a buffer written by an
earlier one.
This is the same pattern llama.cpp uses:
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]
pub fn set_pipeline(&mut self, pipeline: &ComputePipelineStateRef)
Set the compute pipeline state for subsequent dispatches.
This begins a new compute pass if one is not already active.
pub fn set_buffer(&self, index: u64, buffer: &MlxBuffer)
Bind a buffer to a compute kernel argument slot.
The index corresponds to the [[buffer(N)]] attribute in the MSL shader.
pub fn dispatch_threads(&self, grid_size: MTLSize, threadgroup_size: MTLSize)
Dispatch threads on the GPU.
pub fn encode(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    buffers: &[(u64, &MlxBuffer)],
    grid_size: MTLSize,
    threadgroup_size: MTLSize,
)
Encode a complete compute pass: set pipeline, bind buffers, dispatch.
Reuses the persistent compute encoder — no per-dispatch encoder creation overhead.
§Arguments
pipeline — The compiled compute pipeline to execute.
buffers — Slice of (index, &MlxBuffer) pairs for buffer bindings.
grid_size — Total number of threads to launch.
threadgroup_size — Threads per threadgroup.
pub fn encode_threadgroups(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    buffers: &[(u64, &MlxBuffer)],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
)
Encode a compute pass using threadgroups instead of raw thread counts.
Reuses the persistent compute encoder — no per-dispatch encoder creation overhead.

Encode a compute pass using threadgroups with shared threadgroup memory.
Like encode_threadgroups, but additionally
allocates threadgroup memory at the specified indices. This is required
for kernels that use threadgroup memory (e.g. reductions in rms_norm
and softmax).
§Arguments
pipeline — The compiled compute pipeline to execute.
buffers — Slice of (index, &MlxBuffer) pairs for buffer bindings.
threadgroup_mem — Slice of (index, byte_length) pairs for threadgroup memory.
threadgroups — Number of threadgroups to dispatch.
threadgroup_size — Threads per threadgroup.
pub fn encode_with_args(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, KernelArg<'_>)],
    grid_size: MTLSize,
    threadgroup_size: MTLSize,
)
Encode a dispatch with mixed buffer/bytes bindings (dispatch_threads).
Reuses the persistent compute encoder.
pub fn encode_threadgroups_with_args(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, KernelArg<'_>)],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
)
Encode a dispatch with mixed buffer/bytes bindings (dispatch_thread_groups).
Reuses the persistent compute encoder.

Encode a dispatch with mixed buffer/bytes bindings and shared memory.
Reuses the persistent compute encoder.
pub fn replay_dispatch(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, RecordedBinding)],
    threadgroup_memory: &[(u64, u64)],
    threads_per_grid: MTLSize,
    threads_per_threadgroup: MTLSize,
    dispatch_kind: DispatchKind,
)
Replay a single captured dispatch node into this encoder.
This is the inverse of capture: it takes a previously recorded
CapturedNode::Dispatch and encodes it into the live Metal encoder.
Barrier nodes are handled by the caller (ComputeGraph::encode_sequential).
Does NOT increment DISPATCH_COUNT — that was already counted at
capture time.
pub fn commit_and_wait(&mut self) -> Result<()>
Commit the command buffer and block until the GPU finishes execution.
§Errors
Returns MlxError::CommandBufferError if the GPU reports an error.
pub fn commit(&mut self)
Commit the command buffer WITHOUT blocking.
The GPU begins executing the encoded commands immediately. Call
wait_until_completed later to block
the CPU and check for errors. This allows the CPU to continue doing
other work (e.g. preparing the next batch) while the GPU runs.
pub fn wait_until_completed(&self) -> Result<()>
Block until a previously committed command buffer completes.
Must be called after commit. Do not call after
commit_and_wait — that method already waits.
§Errors
Returns MlxError::CommandBufferError if the GPU reports an error.
pub fn metal_command_buffer(&self) -> &CommandBuffer
Borrow the underlying Metal command buffer.
Trait Implementations§
impl Drop for CommandEncoder
impl Send for CommandEncoder
SAFETY: CommandEncoder is safe to Send across threads provided that only one
thread accesses the encoder at a time (exclusive ownership, never concurrent
use from multiple threads).
Metal command buffers and compute encoders are thread-safe for exclusive
access (Apple documentation: “You can create command buffers, encode
commands, and submit them from any thread”). The raw pointer
active_encoder borrows from cmd_buf and is valid as long as
cmd_buf is alive — this invariant holds across thread boundaries
because both fields move together.
This matches llama.cpp’s pattern of encoding command buffers on GCD
worker threads via dispatch_apply, and is used for the dual-buffer
pipeline where buf1 is encoded on a worker thread while buf0 executes.