Struct CommandEncoder 

pub struct CommandEncoder { /* private fields */ }

A batched compute command encoder.

Keeps a single Metal ComputeCommandEncoder alive across multiple dispatches. The encoder is created on the first dispatch and ended only when the command buffer is committed. This mirrors candle’s compute_per_buffer pattern and avoids per-dispatch encoder overhead.

Typical usage

let mut enc = device.command_encoder()?;
// Multiple dispatches share the same compute encoder:
enc.encode_threadgroups(pipeline1, &buffers1, tg1, tg_size1);
enc.encode_threadgroups(pipeline2, &buffers2, tg2, tg_size2);
enc.commit_and_wait()?;
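
The sketch below models that lifecycle without touching Metal. MockEncoder and its fields are hypothetical stand-ins, not this crate's types; they only count how many compute encoders get opened, showing that several dispatches share one encoder that is ended at commit time.

```rust
// Hypothetical stand-in: models the encoder lifecycle only, no Metal calls.
#[derive(Default)]
pub struct MockEncoder {
    pub encoders_created: usize, // how many compute encoders were opened
    pub dispatches: usize,       // how many dispatches were recorded
    active: bool,                // is a compute encoder currently open?
    pub committed: bool,
}

impl MockEncoder {
    /// Record one dispatch, opening a compute encoder only if none is active.
    pub fn encode_threadgroups(&mut self) {
        if !self.active {
            self.encoders_created += 1;
            self.active = true;
        }
        self.dispatches += 1;
    }

    /// End the single shared encoder, then commit the command buffer.
    pub fn commit(&mut self) {
        self.active = false; // the encoder is ended here, not after each dispatch
        self.committed = true;
    }
}
```

Three calls to encode_threadgroups followed by commit leave encoders_created at 1, which is the per-dispatch overhead the batched design avoids.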

Implementations

impl CommandEncoder

pub fn start_capture(&mut self)

Enable capture mode.

All subsequent dispatch and barrier calls will be recorded into a Vec<CapturedNode> instead of being encoded into Metal. Call take_capture() to extract the recorded nodes.

pub fn is_capturing(&self) -> bool

Whether the encoder is currently in capture mode.

pub fn take_capture(&mut self) -> Option<Vec<CapturedNode>>

Extract the captured nodes, ending capture mode.

Returns None if capture mode was not active.
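
A minimal model of the capture-mode state, assuming (hypothetically) that the capture buffer is stored as an Option<Vec<_>>: Some while capturing, None otherwise. Node, CaptureState, and dispatch are illustrative names, not this crate's CapturedNode API.

```rust
// Hypothetical model of capture mode: `capture` is `Some` while capturing.
#[derive(Debug, Clone, PartialEq)]
pub enum Node {
    Dispatch { name: &'static str },
    Barrier,
}

#[derive(Default)]
pub struct CaptureState {
    capture: Option<Vec<Node>>,
    pub encoded_directly: usize, // dispatches that went straight to the encoder
}

impl CaptureState {
    pub fn start_capture(&mut self) {
        self.capture = Some(Vec::new());
    }

    pub fn is_capturing(&self) -> bool {
        self.capture.is_some()
    }

    /// Record into the capture buffer if active, otherwise "encode" directly.
    pub fn dispatch(&mut self, name: &'static str) {
        match self.capture.as_mut() {
            Some(nodes) => nodes.push(Node::Dispatch { name }),
            None => self.encoded_directly += 1,
        }
    }

    /// Outside capture the real encoder would emit a Metal barrier;
    /// this model simply ignores it.
    pub fn memory_barrier(&mut self) {
        if let Some(nodes) = self.capture.as_mut() {
            nodes.push(Node::Barrier);
        }
    }

    /// `Option::take` returns the nodes and leaves `None`, ending capture.
    pub fn take_capture(&mut self) -> Option<Vec<Node>> {
        self.capture.take()
    }
}
```

Storing the buffer in an Option makes "not capturing" and "capturing with nothing recorded yet" distinct states, which is why take_capture can return None.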

pub fn set_op_kind(&mut self, kind: CapturedOpKind)

Tag the NEXT captured dispatch with the given operation kind.

The tag is consumed (reset to Other) after the next dispatch is captured. Only meaningful in capture mode — has no effect on direct-dispatch encoding.

Used by op dispatch functions to annotate captures for the fusion pass (Phase 4e.2).
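
The consume-once behavior of the tag can be sketched with std::mem::take on an enum whose Default is Other. OpKind and Tagger are hypothetical; only the reset value Other comes from the description above.

```rust
use std::mem;

// Hypothetical op kinds; only `Other` (the reset value) is taken from the docs.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum OpKind {
    #[default]
    Other,
    RmsNorm,
    Softmax,
}

#[derive(Default)]
pub struct Tagger {
    pending_kind: OpKind,
    pub captured: Vec<OpKind>, // kind recorded for each captured dispatch
}

impl Tagger {
    pub fn set_op_kind(&mut self, kind: OpKind) {
        self.pending_kind = kind;
    }

    /// Capture one dispatch, consuming the pending tag (it resets to `Other`).
    pub fn capture_dispatch(&mut self) {
        let kind = mem::take(&mut self.pending_kind);
        self.captured.push(kind);
    }
}
```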

pub fn set_pending_buffer_ranges(
    &mut self,
    reads: Vec<(usize, usize)>,
    writes: Vec<(usize, usize)>,
)

Stash buffer range annotations for the NEXT captured dispatch.

Called by GraphSession::barrier_between() in capture mode to record which buffers the next dispatch reads from and writes to. The ranges are consumed by the next encode_* call and attached to the captured CapturedNode::Dispatch.

Only meaningful in capture mode — has no effect on direct-dispatch.
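
A sketch of the stash-and-consume pattern for buffer ranges. Recorder and DispatchNode are hypothetical, and reading each (usize, usize) as a (start, end) pair is an assumption; the documentation only gives the type. A dispatch encoded without a preceding call gets empty ranges.

```rust
// Hypothetical: each (usize, usize) is assumed to be a (start, end) pair.
pub type Ranges = Vec<(usize, usize)>;

#[derive(Debug, PartialEq)]
pub struct DispatchNode {
    pub reads: Ranges,
    pub writes: Ranges,
}

#[derive(Default)]
pub struct Recorder {
    pending: Option<(Ranges, Ranges)>,
    pub nodes: Vec<DispatchNode>,
}

impl Recorder {
    pub fn set_pending_buffer_ranges(&mut self, reads: Ranges, writes: Ranges) {
        self.pending = Some((reads, writes));
    }

    /// Stands in for any `encode_*` call: attach and clear the stashed ranges.
    pub fn encode(&mut self) {
        let (reads, writes) = self.pending.take().unwrap_or_default();
        self.nodes.push(DispatchNode { reads, writes });
    }
}
```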

pub fn annotate_last_dispatch_if_missing(
    &mut self,
    reads: Vec<(usize, usize)>,
    writes: Vec<(usize, usize)>,
)

Patch the last captured dispatch node’s empty reads/writes with the given ranges. No-op if not capturing, or if the last node isn’t a Dispatch, or if its ranges are already populated.

Used by GraphSession::track_dispatch in recording mode to annotate dispatches that were called without a preceding barrier_between.
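
The patch rule can be sketched as follows. The Node enum is a hypothetical stand-in for CapturedNode, and treating a node's ranges as missing only when both reads and writes are empty is an assumption about what "already populated" means.

```rust
pub type Ranges = Vec<(usize, usize)>;

// Hypothetical stand-in for CapturedNode.
#[derive(Debug, PartialEq)]
pub enum Node {
    Dispatch { reads: Ranges, writes: Ranges },
    Barrier,
}

/// Patch the last node only if it is a Dispatch whose reads and writes are both empty.
/// `capture` is `None` when not capturing.
pub fn annotate_last_if_missing(capture: Option<&mut Vec<Node>>, reads: Ranges, writes: Ranges) {
    let Some(nodes) = capture else { return }; // not capturing: no-op
    if let Some(Node::Dispatch { reads: r, writes: w }) = nodes.last_mut() {
        if r.is_empty() && w.is_empty() {
            *r = reads;
            *w = writes;
        }
    }
    // Last node is a Barrier (or there are no nodes): no-op.
}
```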

pub fn memory_barrier(&mut self)

Insert a memory barrier with scope MTLBarrierScopeBuffers.

When the encoder uses MTLDispatchTypeConcurrent, dispatches may execute concurrently unless a barrier separates them. Call this between two dispatches whenever they access the same buffer and at least one of them writes it; the common case is a later dispatch reading a buffer that an earlier one wrote.

This is the same pattern llama.cpp uses: [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]
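
With concurrent dispatch, a barrier is needed whenever two dispatches touch overlapping memory and at least one of them writes it (read-after-write, write-after-read, or write-after-write). A minimal sketch of that test, assuming each (usize, usize) range is a half-open [start, end) interval in one shared address space; needs_barrier is a hypothetical helper, not part of this API.

```rust
/// Half-open range [start, end) in one shared address space (assumed encoding).
pub type Range = (usize, usize);

fn overlaps(a: Range, b: Range) -> bool {
    a.0 < b.1 && b.0 < a.1
}

fn any_overlap(xs: &[Range], ys: &[Range]) -> bool {
    xs.iter().any(|&x| ys.iter().any(|&y| overlaps(x, y)))
}

/// True if `next` must not run concurrently with `prev`.
/// Two reads of the same range never conflict.
pub fn needs_barrier(
    prev_reads: &[Range],
    prev_writes: &[Range],
    next_reads: &[Range],
    next_writes: &[Range],
) -> bool {
    any_overlap(prev_writes, next_reads)          // read-after-write
        || any_overlap(prev_reads, next_writes)   // write-after-read
        || any_overlap(prev_writes, next_writes)  // write-after-write
}
```

Inserting a barrier only where this returns true keeps independent dispatches free to overlap; inserting one after every dispatch is always safe but serializes the pass.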

pub fn set_pipeline(&mut self, pipeline: &ComputePipelineStateRef)

Set the compute pipeline state for subsequent dispatches.

This begins a new compute pass if one is not already active.

pub fn set_buffer(&self, index: u64, buffer: &MlxBuffer)

Bind a buffer to a compute kernel argument slot.

The index corresponds to the [[buffer(N)]] attribute in the MSL shader.

pub fn dispatch_threads(&self, grid_size: MTLSize, threadgroup_size: MTLSize)

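Dispatch grid_size threads in total, grouped into threadgroups of threadgroup_size threads, using the currently set pipeline and bound buffers.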

pub fn encode(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    buffers: &[(u64, &MlxBuffer)],
    grid_size: MTLSize,
    threadgroup_size: MTLSize,
)

Encode a complete compute pass: set pipeline, bind buffers, dispatch.

Reuses the persistent compute encoder — no per-dispatch encoder creation overhead.

Arguments
  • pipeline — The compiled compute pipeline to execute.
  • buffers — Slice of (index, &MlxBuffer) pairs for buffer bindings.
  • grid_size — Total number of threads to launch.
  • threadgroup_size — Threads per threadgroup.
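
To illustrate what "set pipeline, bind buffers, dispatch" means in order, here is a mock that logs calls instead of talking to Metal. MockEncoder, Call, and Size are hypothetical, with pipelines and buffers reduced to string labels; the real encode also reuses the persistent Metal encoder, which this mock does not model.

```rust
// Hypothetical mock that records calls instead of encoding Metal commands.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Size {
    pub w: u64,
    pub h: u64,
    pub d: u64,
}

#[derive(Debug, PartialEq)]
pub enum Call {
    SetPipeline(&'static str),
    SetBuffer { index: u64, buffer: &'static str },
    DispatchThreads { grid: Size, group: Size },
}

#[derive(Default)]
pub struct MockEncoder {
    pub log: Vec<Call>,
}

impl MockEncoder {
    /// One complete compute pass: pipeline, then every binding, then the dispatch.
    pub fn encode(&mut self, pipeline: &'static str, buffers: &[(u64, &'static str)], grid: Size, group: Size) {
        self.log.push(Call::SetPipeline(pipeline));
        for &(index, buffer) in buffers {
            self.log.push(Call::SetBuffer { index, buffer });
        }
        self.log.push(Call::DispatchThreads { grid, group });
    }
}
```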

pub fn encode_threadgroups(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    buffers: &[(u64, &MlxBuffer)],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
)

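Encode a compute pass sized in threadgroups rather than raw thread counts: threadgroups is the number of threadgroups to launch and threadgroup_size is the number of threads in each.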

Reuses the persistent compute encoder — no per-dispatch encoder creation overhead.
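
Because encode takes a total thread count while encode_threadgroups takes a threadgroup count, callers typically round up when converting between the two. A sketch of that arithmetic, with Size as a hypothetical stand-in for MTLSize; the kernel is assumed to bounds-check thread indices that fall past the end of the data.

```rust
/// Hypothetical stand-in for Metal's MTLSize (width, height, depth).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Size {
    pub width: u64,
    pub height: u64,
    pub depth: u64,
}

/// Ceiling division; `b` must be non-zero.
fn ceil_div(a: u64, b: u64) -> u64 {
    (a + b - 1) / b
}

/// Threadgroups needed to cover `grid` threads with groups of `group` threads,
/// rounding up per dimension, so the last group may be partly out of range.
pub fn threadgroups_for(grid: Size, group: Size) -> Size {
    Size {
        width: ceil_div(grid.width, group.width),
        height: ceil_div(grid.height, group.height),
        depth: ceil_div(grid.depth, group.depth),
    }
}
```

For 1000 elements and 256-thread groups this yields 4 threadgroups, i.e. 1024 launched threads, 24 of which must do nothing.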

pub fn encode_threadgroups_with_shared(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    buffers: &[(u64, &MlxBuffer)],
    threadgroup_mem: &[(u64, u64)],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
)

Encode a compute pass using threadgroups with shared threadgroup memory.

Like encode_threadgroups, but additionally sets the threadgroup memory length at each specified index. This is required for kernels that take a threadgroup memory argument ([[threadgroup(N)]] in MSL), e.g. the reductions in rms_norm and softmax.

Arguments
  • pipeline — The compiled compute pipeline to execute.
  • buffers — Slice of (index, &MlxBuffer) pairs for buffer bindings.
  • threadgroup_mem — Slice of (index, byte_length) pairs for threadgroup memory.
  • threadgroups — Number of threadgroups to dispatch.
  • threadgroup_size — Threads per threadgroup.
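
A sketch of sizing the threadgroup_mem argument for a simple reduction that keeps one f32 partial sum per SIMD group. The layout (a single allocation at index 0), the SIMD width of 32, and the helper names are assumptions for illustration, not how this crate's rms_norm or softmax kernels are laid out. Metal requires each threadgroup memory length to be a multiple of 16 bytes, so the size is rounded up.

```rust
/// Metal requires threadgroup memory lengths to be a multiple of 16 bytes.
const TG_MEM_ALIGN: u64 = 16;

/// Hypothetical helper: bytes for one f32 partial sum per SIMD group.
/// The SIMD width is passed in; real code would query the pipeline's
/// threadExecutionWidth rather than assume 32.
pub fn reduction_scratch(threads_per_group: u64, simd_width: u64) -> u64 {
    let simd_groups = (threads_per_group + simd_width - 1) / simd_width;
    let bytes = simd_groups * std::mem::size_of::<f32>() as u64;
    // Round up to the next multiple of 16.
    (bytes + TG_MEM_ALIGN - 1) / TG_MEM_ALIGN * TG_MEM_ALIGN
}

/// Build an (index, byte_length) slice in the shape threadgroup_mem expects,
/// assuming the kernel declares its scratch space at [[threadgroup(0)]].
pub fn threadgroup_mem_args(threads_per_group: u64) -> Vec<(u64, u64)> {
    vec![(0, reduction_scratch(threads_per_group, 32))]
}
```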

pub fn encode_with_args(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, KernelArg<'_>)],
    grid_size: MTLSize,
    threadgroup_size: MTLSize,
)

Encode a dispatch whose bindings mix buffers and inline bytes (KernelArg), launched with dispatch_threads: grid_size is the total number of threads.

Reuses the persistent compute encoder.

pub fn encode_threadgroups_with_args(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, KernelArg<'_>)],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
)

Encode a dispatch whose bindings mix buffers and inline bytes (KernelArg), launched with dispatch_thread_groups: threadgroups is the number of threadgroups.

Reuses the persistent compute encoder.

pub fn encode_threadgroups_with_args_and_shared(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, KernelArg<'_>)],
    threadgroup_mem: &[(u64, u64)],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
)

Like encode_threadgroups_with_args, but also sets threadgroup memory lengths from threadgroup_mem, given as (index, byte_length) pairs.

Reuses the persistent compute encoder.
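
The three variants above differ only in how they dispatch; the bindings themselves mix buffers with small constants copied inline. A hypothetical mirror of that idea follows: Arg, Buffer, pack_params, and describe are illustrative names, and KernelArg's real variants are not listed in this documentation. The parameter struct is packed field by field into native-endian bytes.

```rust
// Hypothetical mirror of a buffer-or-bytes binding.
pub struct Buffer {
    pub label: &'static str,
}

pub enum Arg<'a> {
    Buffer(&'a Buffer),
    Bytes(&'a [u8]), // small constants copied inline rather than bound as a buffer
}

/// Pack a two-field parameter struct (u32, f32) as native-endian bytes.
pub fn pack_params(n: u32, eps: f32) -> Vec<u8> {
    let mut out = Vec::with_capacity(8);
    out.extend_from_slice(&n.to_ne_bytes());
    out.extend_from_slice(&eps.to_ne_bytes());
    out
}

/// Describe each binding slot the way an encoder would apply it.
pub fn describe(bindings: &[(u64, Arg<'_>)]) -> Vec<String> {
    bindings
        .iter()
        .map(|(index, arg)| match arg {
            Arg::Buffer(b) => format!("buffer({index}) <- {}", b.label),
            Arg::Bytes(bytes) => format!("buffer({index}) <- {} inline bytes", bytes.len()),
        })
        .collect()
}
```

Packing field by field matches a C-layout MSL struct only when the fields are naturally aligned with no padding, as here; larger structs need the layout checked against the shader.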

pub fn replay_dispatch(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, RecordedBinding)],
    threadgroup_memory: &[(u64, u64)],
    threads_per_grid: MTLSize,
    threads_per_threadgroup: MTLSize,
    dispatch_kind: DispatchKind,
)

Replay a single captured dispatch node into this encoder.

This is the inverse of capture: it takes a previously recorded CapturedNode::Dispatch and encodes it into the live Metal encoder. Barrier nodes are handled by the caller (ComputeGraph::encode_sequential).

Does NOT increment DISPATCH_COUNT — that was already counted at capture time.
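
A sketch of the caller-side loop described above, in the spirit of ComputeGraph::encode_sequential: barrier nodes become memory barriers and dispatch nodes are replayed without touching the dispatch counter. Node, LiveEncoder, and encode_sequential here are hypothetical simplifications, not this crate's types.

```rust
// Hypothetical simplification of a captured graph and a live encoder.
pub enum Node {
    Dispatch { kernel: &'static str },
    Barrier,
}

#[derive(Default)]
pub struct LiveEncoder {
    pub ops: Vec<String>,
    pub dispatch_count: usize, // incremented only at capture time
}

impl LiveEncoder {
    pub fn memory_barrier(&mut self) {
        self.ops.push("barrier".to_string());
    }

    pub fn replay_dispatch(&mut self, kernel: &str) {
        self.ops.push(format!("dispatch {kernel}"));
        // Deliberately no `self.dispatch_count += 1`: already counted at capture.
    }
}

/// The caller, not replay_dispatch, is responsible for barrier nodes.
pub fn encode_sequential(nodes: &[Node], enc: &mut LiveEncoder) {
    for node in nodes {
        match node {
            Node::Dispatch { kernel } => enc.replay_dispatch(kernel),
            Node::Barrier => enc.memory_barrier(),
        }
    }
}
```

Keeping the counter out of the replay path means a graph captured once and replayed many times is counted once.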

pub fn commit_and_wait(&mut self) -> Result<()>

Commit the command buffer and block until the GPU finishes execution.

Errors

Returns MlxError::CommandBufferError if the GPU reports an error.

pub fn commit(&mut self)

Commit the command buffer WITHOUT blocking.

The command buffer is submitted for execution and this call returns without waiting for the GPU. Call wait_until_completed later to block the CPU and check for errors. In the meantime the CPU can do other work (e.g. preparing the next batch) while the GPU runs.

pub fn wait_until_completed(&self) -> Result<()>

Block until a previously committed command buffer completes.

Must be called after commit. Do not call after commit_and_wait — that method already waits.

Errors

Returns MlxError::CommandBufferError if the GPU reports an error.
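
The commit / wait split can be modeled with a thread standing in for the GPU. MockCommandBuffer is hypothetical: its wait_until_completed takes &mut self only because JoinHandle::join consumes the handle, and it reports a second wait as an error to mirror the documented rule that wait_until_completed follows commit, not commit_and_wait.

```rust
use std::thread::{self, JoinHandle};

type Work = Box<dyn FnOnce() -> Result<u64, String> + Send>;

// Hypothetical stand-in: a spawned thread plays the role of the GPU.
pub struct MockCommandBuffer {
    work: Option<Work>,
    running: Option<JoinHandle<Result<u64, String>>>,
}

impl MockCommandBuffer {
    pub fn new(work: impl FnOnce() -> Result<u64, String> + Send + 'static) -> Self {
        Self { work: Some(Box::new(work)), running: None }
    }

    /// Non-blocking: start the "GPU" work and return immediately.
    pub fn commit(&mut self) {
        if let Some(work) = self.work.take() {
            self.running = Some(thread::spawn(work));
        }
    }

    /// Block until the committed work finishes and surface its error, if any.
    pub fn wait_until_completed(&mut self) -> Result<u64, String> {
        let handle = self.running.take().ok_or("not committed (or already waited)")?;
        handle.join().map_err(|_| "worker panicked".to_string())?
    }

    /// Blocking convenience: commit, then wait.
    pub fn commit_and_wait(&mut self) -> Result<u64, String> {
        self.commit();
        self.wait_until_completed()
    }
}
```

Between commit and wait_until_completed, the calling thread is free to prepare the next batch.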

pub fn metal_command_buffer(&self) -> &CommandBuffer

Borrow the underlying Metal command buffer.

Trait Implementations

impl Drop for CommandEncoder

fn drop(&mut self)

Executes the destructor for this type.

impl Send for CommandEncoder

SAFETY: CommandEncoder is safe to Send across threads provided that:

  1. Only one thread accesses the encoder at a time (exclusive ownership).
  2. The encoder is not used concurrently from multiple threads.

Metal command buffers and compute encoders are thread-safe for exclusive access (Apple documentation: “You can create command buffers, encode commands, and submit them from any thread”). The raw pointer active_encoder borrows from cmd_buf and is valid as long as cmd_buf is alive — this invariant holds across thread boundaries because both fields move together.

This matches llama.cpp’s pattern of encoding command buffers on GCD worker threads via dispatch_apply, and is used for the dual-buffer pipeline where buf1 is encoded on a worker thread while buf0 executes.
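
The ownership argument can be illustrated with a small analogue. In this hypothetical Encoder, cmd_buf is an owned heap pointer (standing in for the retained command buffer) and active aliases it (standing in for active_encoder). Raw pointer fields make a struct !Send by default, so Send is asserted manually; because both pointers move together, a worker thread can encode one buffer while the main thread works on another.

```rust
use std::thread;

// Hypothetical analogue of the invariant above, not this crate's struct.
pub struct Encoder {
    cmd_buf: *mut Vec<u32>, // owned; freed in Drop (stands in for a retained object)
    active: *mut Vec<u32>,  // aliases cmd_buf (stands in for active_encoder)
}

// SAFETY (of this sketch): `active` only points into the allocation owned via
// `cmd_buf`, which does not move when `Encoder` moves. `Encoder` is not Sync,
// so only the thread that owns it can use it.
unsafe impl Send for Encoder {}

impl Encoder {
    pub fn new() -> Self {
        let cmd_buf = Box::into_raw(Box::new(Vec::new()));
        Encoder { cmd_buf, active: cmd_buf }
    }

    pub fn encode(&mut self, op: u32) {
        // SAFETY: `active` is valid while `self` is alive; `&mut self` is exclusive.
        unsafe { (*self.active).push(op) }
    }

    pub fn commands(&self) -> Vec<u32> {
        // SAFETY: `cmd_buf` is valid while `self` is alive.
        unsafe { (*self.cmd_buf).clone() }
    }
}

impl Drop for Encoder {
    fn drop(&mut self) {
        // SAFETY: `cmd_buf` came from Box::into_raw and is freed exactly once.
        unsafe { drop(Box::from_raw(self.cmd_buf)) }
    }
}

/// Encode buf1 on a worker thread while the main thread works on buf0.
pub fn dual_buffer_demo() -> (Vec<u32>, Vec<u32>) {
    let mut buf1 = Encoder::new();
    let worker = thread::spawn(move || {
        for op in 100..103 {
            buf1.encode(op);
        }
        buf1 // hand ownership back to the main thread
    });
    let mut buf0 = Encoder::new();
    buf0.encode(1);
    let buf1 = worker.join().expect("worker panicked");
    (buf0.commands(), buf1.commands())
}
```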

Auto Trait Implementations

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.