

Struct CommandEncoder 

pub struct CommandEncoder { /* private fields */ }

A batched compute command encoder.

Keeps a single Metal ComputeCommandEncoder alive across multiple dispatches. The encoder is created on the first dispatch and ended only when the command buffer is committed. This mirrors candle’s compute_per_buffer pattern and avoids per-dispatch encoder overhead.

§Typical usage

let mut enc = device.command_encoder()?;
// Multiple dispatches share the same compute encoder:
enc.encode_threadgroups(pipeline1, &buffers1, tg1, tg_size1);
enc.encode_threadgroups(pipeline2, &buffers2, tg2, tg_size2);
enc.commit_and_wait()?;
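The lifecycle described above (one compute encoder opened lazily on the first dispatch, closed only at commit) can be modeled without Metal. This is a minimal sketch under stated assumptions: MockEncoder, encoders_created, and the string pipeline names are hypothetical stand-ins, not part of this crate, and the mock's commit_and_wait returns nothing instead of a Result.

```rust
/// Hypothetical model of the CommandEncoder lifecycle: one compute encoder
/// is opened lazily on the first dispatch and closed only at commit.
#[derive(Default)]
pub struct MockEncoder {
    active: bool,              // is a compute encoder currently open?
    pub encoders_created: u32, // how many encoders were opened in total
    pub dispatches: Vec<String>,
}

impl MockEncoder {
    /// Record a dispatch, opening an encoder only if none is active yet.
    pub fn encode_threadgroups(&mut self, pipeline: &str) {
        if !self.active {
            self.active = true;
            self.encoders_created += 1;
        }
        self.dispatches.push(pipeline.to_string());
    }

    /// End the shared encoder (if any) and "commit" the command buffer.
    pub fn commit_and_wait(&mut self) {
        self.active = false;
    }
}
```

Two dispatches followed by a commit create exactly one encoder, which is the per-dispatch overhead the real type avoids.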

Implementations


impl CommandEncoder


pub fn start_capture(&mut self)

Enable capture mode.

All subsequent dispatch and barrier calls will be recorded into a Vec<CapturedNode> instead of being encoded into Metal. Call take_capture() to extract the recorded nodes.


pub fn is_capturing(&self) -> bool

Whether the encoder is currently in capture mode.


pub fn take_capture(&mut self) -> Option<Vec<CapturedNode>>

Extract the captured nodes, ending capture mode.

Returns None if capture mode was not active.
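The start_capture / is_capturing / take_capture trio maps naturally onto an Option<Vec<_>>: None means direct encoding, Some means capture mode. The sketch below is illustrative only; Node and CaptureState are hypothetical names, not the crate's CapturedNode or its internal fields.

```rust
/// Hypothetical stand-in for CapturedNode.
#[derive(Debug, Clone, PartialEq)]
pub enum Node {
    Dispatch(&'static str),
    Barrier,
}

#[derive(Default)]
pub struct CaptureState {
    // None = direct mode; Some(nodes) = capture mode.
    capture: Option<Vec<Node>>,
}

impl CaptureState {
    pub fn start_capture(&mut self) {
        self.capture = Some(Vec::new());
    }

    pub fn is_capturing(&self) -> bool {
        self.capture.is_some()
    }

    /// Records the node if capturing; returns false when the caller
    /// should encode it into Metal directly instead.
    pub fn record(&mut self, node: Node) -> bool {
        match self.capture.as_mut() {
            Some(nodes) => {
                nodes.push(node);
                true
            }
            None => false,
        }
    }

    /// Ends capture mode; returns None if capture was never started.
    pub fn take_capture(&mut self) -> Option<Vec<Node>> {
        self.capture.take()
    }
}
```

Using Option::take means extracting the nodes and leaving capture mode happen in a single step.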


pub fn set_op_kind(&mut self, kind: CapturedOpKind)

Tag the NEXT captured dispatch with the given operation kind.

The tag is consumed (reset to Other) after the next dispatch is captured. Only meaningful in capture mode — has no effect on direct-dispatch encoding.

Used by op dispatch functions to annotate captures for the fusion pass (Phase 4e.2).


pub fn set_pending_buffer_ranges(&mut self, reads: Vec<(usize, usize)>, writes: Vec<(usize, usize)>)

Stash buffer range annotations for the NEXT captured dispatch.

Called by GraphSession::barrier_between() in capture mode to record which buffers the next dispatch reads from and writes to. The ranges are consumed by the next encode_* call and attached to the captured CapturedNode::Dispatch.

Only meaningful in capture mode — has no effect on direct-dispatch.
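Both set_op_kind and set_pending_buffer_ranges stash state that the next captured dispatch consumes. A minimal sketch of that "consume and reset" behavior, assuming a Default of Other for the tag and empty vectors for the ranges; OpKind, Pending, and Captured are hypothetical names, and the sketch does not interpret what the (usize, usize) pairs encode.

```rust
use std::mem;

/// Hypothetical op tags; the real CapturedOpKind resets to Other.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum OpKind {
    #[default]
    Other,
    RmsNorm,
}

/// What a captured dispatch would carry away from the pending state.
#[derive(Debug, PartialEq)]
pub struct Captured {
    pub kind: OpKind,
    pub reads: Vec<(usize, usize)>,
    pub writes: Vec<(usize, usize)>,
}

#[derive(Default)]
pub struct Pending {
    kind: OpKind,
    reads: Vec<(usize, usize)>,
    writes: Vec<(usize, usize)>,
}

impl Pending {
    pub fn set_op_kind(&mut self, kind: OpKind) {
        self.kind = kind;
    }

    pub fn set_pending_buffer_ranges(&mut self, reads: Vec<(usize, usize)>, writes: Vec<(usize, usize)>) {
        self.reads = reads;
        self.writes = writes;
    }

    /// Called once per captured dispatch: takes the annotations and leaves
    /// defaults behind (tag back to Other, ranges empty).
    pub fn consume(&mut self) -> Captured {
        Captured {
            kind: mem::take(&mut self.kind),
            reads: mem::take(&mut self.reads),
            writes: mem::take(&mut self.writes),
        }
    }
}
```

mem::take makes the one-shot semantics explicit: a second dispatch without fresh annotations sees Other and empty ranges.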


pub fn annotate_last_dispatch_if_missing(&mut self, reads: Vec<(usize, usize)>, writes: Vec<(usize, usize)>)

Patch the last captured dispatch node’s empty reads/writes with the given ranges. No-op if not capturing, or if the last node isn’t a Dispatch, or if its ranges are already populated.

Used by GraphSession::track_dispatch in recording mode to annotate dispatches that were called without a preceding barrier_between.
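The three no-op conditions of annotate_last_dispatch_if_missing can be expressed as early exits. This is a hypothetical sketch: Node and annotate_last_if_missing are illustrative names, and it assumes a dispatch counts as "already populated" when either its reads or its writes are non-empty.

```rust
/// Hypothetical captured-node type with dataflow annotations.
#[derive(Debug, PartialEq)]
pub enum Node {
    Dispatch { reads: Vec<(usize, usize)>, writes: Vec<(usize, usize)> },
    Barrier,
}

/// Patch the last node only if we are capturing, it is a Dispatch,
/// and both of its range lists are still empty.
pub fn annotate_last_if_missing(
    capture: Option<&mut Vec<Node>>,
    reads: Vec<(usize, usize)>,
    writes: Vec<(usize, usize)>,
) {
    let Some(nodes) = capture else { return }; // not capturing: no-op
    if let Some(Node::Dispatch { reads: r, writes: w }) = nodes.last_mut() {
        if r.is_empty() && w.is_empty() {
            *r = reads;
            *w = writes;
        }
    }
    // Last node is a Barrier, or ranges already set: no-op.
}
```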


pub fn memory_barrier(&mut self)

Insert a memory barrier with scope MTLBarrierScopeBuffers.

When the encoder uses MTLDispatchTypeConcurrent, all dispatches can execute concurrently unless separated by a barrier. Call this between dispatches where the later dispatch reads a buffer written by an earlier one.

This is the same pattern llama.cpp uses: [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]


pub fn set_pipeline(&mut self, pipeline: &ComputePipelineStateRef)

Set the compute pipeline state for subsequent dispatches.

This begins a new compute pass if one is not already active.


pub fn set_buffer(&self, index: u64, buffer: &MlxBuffer)

Bind a buffer to a compute kernel argument slot.

The index corresponds to the [[buffer(N)]] attribute in the MSL shader.


pub fn dispatch_threads(&self, grid_size: MTLSize, threadgroup_size: MTLSize)

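Dispatch grid_size threads in total, grouped into threadgroups of threadgroup_size, using the currently set pipeline and bound buffers.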


pub fn encode(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    buffers: &[(u64, &MlxBuffer)],
    grid_size: MTLSize,
    threadgroup_size: MTLSize,
)

Encode a complete compute pass: set pipeline, bind buffers, dispatch.

Reuses the persistent compute encoder — no per-dispatch encoder creation overhead.

§Arguments
  • pipeline — The compiled compute pipeline to execute.
  • buffers — Slice of (index, &MlxBuffer) pairs for buffer bindings.
  • grid_size — Total number of threads to launch.
  • threadgroup_size — Threads per threadgroup.

pub fn encode_threadgroups(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    buffers: &[(u64, &MlxBuffer)],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
)

Encode a compute pass using threadgroups instead of raw thread counts.

Reuses the persistent compute encoder — no per-dispatch encoder creation overhead.
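With encode, you pass the total thread count; with encode_threadgroups, you pass the number of threadgroups, so the caller has to round up. A minimal sketch of that ceiling division, assuming a 1-D problem; threadgroups_for is a hypothetical helper, and in real code its result would feed something like MTLSize::new(groups, 1, 1).

```rust
/// Hypothetical helper: number of threadgroups needed so that
/// groups * threads_per_group >= total_threads (ceiling division).
pub fn threadgroups_for(total_threads: u64, threads_per_group: u64) -> u64 {
    assert!(threads_per_group > 0, "threadgroup size must be non-zero");
    total_threads.div_ceil(threads_per_group)
}
```

Because whole threadgroups are launched, the last group may contain threads past the end of the data (1000 elements at 256 threads per group launches 1024 threads), so kernels dispatched this way should bounds-check their thread index.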


pub fn encode_threadgroups_with_shared(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    buffers: &[(u64, &MlxBuffer)],
    threadgroup_mem: &[(u64, u64)],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
)

Encode a compute pass using threadgroups with shared threadgroup memory.

Like encode_threadgroups, but additionally allocates threadgroup memory at the specified indices. This is required for kernels that use threadgroup memory (e.g. reductions in rms_norm and softmax).

§Arguments
  • pipeline — The compiled compute pipeline to execute.
  • buffers — Slice of (index, &MlxBuffer) pairs for buffer bindings.
  • threadgroup_mem — Slice of (index, byte_length) pairs for threadgroup memory.
  • threadgroups — Number of threadgroups to dispatch.
  • threadgroup_size — Threads per threadgroup.
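The threadgroup_mem pairs are (index, byte_length). As an illustration of how such a length might be derived, here is a sketch for a reduction that keeps one f32 partial sum per SIMD-group. The helper name, the use of index 0, and the 32-thread SIMD-group width are assumptions; the 16-byte rounding follows Metal's documented requirement that threadgroup memory lengths be multiples of 16 bytes.

```rust
/// Hypothetical sizing helper for a reduction kernel that stores one
/// f32 partial per SIMD-group in threadgroup memory at index 0.
pub fn reduction_threadgroup_mem(threads_per_group: u64) -> Vec<(u64, u64)> {
    const SIMD_WIDTH: u64 = 32; // assumed SIMD-group width
    const F32_BYTES: u64 = 4;
    let simdgroups = threads_per_group.div_ceil(SIMD_WIDTH);
    let raw = simdgroups * F32_BYTES;
    // Round up to a multiple of 16 bytes, as Metal requires for
    // threadgroup memory lengths.
    let len = raw.div_ceil(16) * 16;
    vec![(0, len)]
}
```

For 256 threads this gives 8 partials, 32 bytes; a single SIMD-group still gets a 16-byte allocation because of the rounding.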

pub fn encode_with_args(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, KernelArg<'_>)],
    grid_size: MTLSize,
    threadgroup_size: MTLSize,
)

Encode a dispatch with mixed buffer/bytes bindings (dispatch_threads).

Reuses the persistent compute encoder.


pub fn encode_threadgroups_with_args(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, KernelArg<'_>)],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
)

Encode a dispatch with mixed buffer/bytes bindings (dispatch_thread_groups).

Reuses the persistent compute encoder.


pub fn encode_threadgroups_with_args_and_shared(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, KernelArg<'_>)],
    threadgroup_mem: &[(u64, u64)],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
)

Encode a dispatch with mixed buffer/bytes bindings and shared memory.

Reuses the persistent compute encoder.
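The *_with_args variants accept a mix of device buffers and small inline byte payloads (the KernelArg::Bytes case, which corresponds to Metal's setBytes path, intended for small data under 4 KB). The sketch below uses hypothetical stand-ins (Buf, Arg, Params, describe) rather than the crate's MlxBuffer and KernelArg, and assumes a params struct whose MSL counterpart is laid out as a uint followed by a float.

```rust
/// Hypothetical stand-in for MlxBuffer.
pub struct Buf {
    pub id: u32,
}

/// Hypothetical analogue of KernelArg: a device buffer or inline bytes.
pub enum Arg<'a> {
    Buffer(&'a Buf),
    Bytes(&'a [u8]),
}

/// Hypothetical kernel params, serialized in declaration order.
pub struct Params {
    pub n: u32,
    pub eps: f32,
}

impl Params {
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut out = Vec::with_capacity(8);
        out.extend_from_slice(&self.n.to_le_bytes());
        out.extend_from_slice(&self.eps.to_le_bytes());
        out
    }
}

/// Describe each binding the way an encoder would dispatch on it.
pub fn describe(bindings: &[(u64, Arg<'_>)]) -> Vec<String> {
    bindings
        .iter()
        .map(|(slot, arg)| match arg {
            Arg::Buffer(b) => format!("slot {slot}: buffer #{}", b.id),
            Arg::Bytes(bytes) => format!("slot {slot}: {} inline bytes", bytes.len()),
        })
        .collect()
}
```

Passing params inline avoids allocating a one-off MlxBuffer for a few scalars per dispatch.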


pub fn dispatch_tracked_threadgroups_with_args(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, KernelArg<'_>)],
    reads: &[&MlxBuffer],
    writes: &[&MlxBuffer],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
)

Auto-barrier-aware dispatch with KernelArg bindings (uses dispatch_thread_groups).

Behaves identically to encode_threadgroups_with_args when HF2Q_AUTO_BARRIER is unset. When set, consults the per-encoder MemRanges tracker:

  • Conflict (RAW/WAR/WAW on a same-buffer range) → emit memory_barrier(), increment AUTO_BARRIER_COUNT, reset the tracker, then dispatch and seed the new concurrent group with this dispatch’s ranges.
  • No conflict → increment AUTO_BARRIER_CONCURRENT, record the ranges into the cumulative state, dispatch.
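The conflict rule above can be sketched as an interval check against the reads and writes accumulated since the last barrier. This is a hypothetical model, not the crate's MemRanges: Range, the half-open [start, end) byte intervals, and the counter fields are assumptions made for illustration.

```rust
/// Hypothetical half-open byte interval [start, end) within one buffer.
#[derive(Clone, Copy, Debug)]
pub struct Range {
    pub buf: u32,
    pub start: usize,
    pub end: usize,
}

fn overlaps(a: &Range, b: &Range) -> bool {
    a.buf == b.buf && a.start < b.end && b.start < a.end
}

/// Hypothetical model of the per-encoder tracker.
#[derive(Default)]
pub struct MemRanges {
    reads: Vec<Range>,
    writes: Vec<Range>,
    pub barriers: u32,   // analogue of AUTO_BARRIER_COUNT
    pub concurrent: u32, // analogue of AUTO_BARRIER_CONCURRENT
}

impl MemRanges {
    /// RAW: new read vs earlier write. WAR: new write vs earlier read.
    /// WAW: new write vs earlier write. Read-after-read never conflicts.
    fn conflicts(&self, reads: &[Range], writes: &[Range]) -> bool {
        reads.iter().any(|r| self.writes.iter().any(|w| overlaps(r, w)))
            || writes.iter().any(|w| {
                self.reads.iter().chain(self.writes.iter()).any(|o| overlaps(w, o))
            })
    }

    /// Returns true if a barrier would be emitted before this dispatch.
    pub fn track(&mut self, reads: &[Range], writes: &[Range]) -> bool {
        let barrier = self.conflicts(reads, writes);
        if barrier {
            self.barriers += 1;
            self.reads.clear();
            self.writes.clear();
        } else {
            self.concurrent += 1;
        }
        // Either way, this dispatch's ranges join the current group.
        self.reads.extend_from_slice(reads);
        self.writes.extend_from_slice(writes);
        barrier
    }
}
```

Two dispatches that only share an input stay in one concurrent group; a dispatch that reads what an earlier one wrote forces a barrier and starts a new group seeded with its own ranges.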

pub fn dispatch_tracked_threadgroups_with_args_and_shared(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, KernelArg<'_>)],
    threadgroup_mem: &[(u64, u64)],
    reads: &[&MlxBuffer],
    writes: &[&MlxBuffer],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
)

Auto-barrier-aware dispatch with KernelArg bindings + shared threadgroup memory.

See dispatch_tracked_threadgroups_with_args for the behavioral contract; this variant additionally takes a threadgroup_mem slice that is forwarded to encode_threadgroups_with_args_and_shared.

The 8-argument signature (counting self) mirrors encode_threadgroups_with_args_and_shared plus the two dataflow slices. clippy::too_many_arguments is allowed because every parameter is load-bearing, either for the dispatch (pipeline, bindings, threadgroup_mem, threadgroups, threadgroup_size) or for the auto-barrier (reads, writes).


pub fn dispatch_tracked_threadgroups(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    buffers: &[(u64, &MlxBuffer)],
    reads: &[&MlxBuffer],
    writes: &[&MlxBuffer],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
)

Auto-barrier-aware dispatch using (slot, &MlxBuffer) bindings (uses dispatch_thread_groups).

Convenience wrapper for callers that don’t need KernelArg::Bytes inline-byte arguments. See dispatch_tracked_threadgroups_with_args for behavioral contract.


pub fn dispatch_tracked_threadgroups_with_shared(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    buffers: &[(u64, &MlxBuffer)],
    threadgroup_mem: &[(u64, u64)],
    reads: &[&MlxBuffer],
    writes: &[&MlxBuffer],
    threadgroups: MTLSize,
    threadgroup_size: MTLSize,
)

Auto-barrier-aware dispatch using (slot, &MlxBuffer) bindings plus shared threadgroup memory (uses dispatch_thread_groups).

Mirrors encode_threadgroups_with_shared — convenience variant for kernels that allocate threadgroup memory (reductions in rms_norm, softmax, etc.) but don’t need KernelArg::Bytes inline-byte arguments. See dispatch_tracked_threadgroups_with_args for the behavioral contract; the only addition here is the threadgroup_mem slice forwarded to the underlying encode.

Closes the iter38-audit coverage gap: the five rms_norm.rs call sites (/opt/mlx-native/src/ops/rms_norm.rs:124, 236, 443, 516, 589) all use encode_threadgroups_with_shared and will need dataflow tracking when they are migrated to auto-barrier in iter40+.

8-argument signature (counting self); clippy::too_many_arguments is allowed because every parameter is load-bearing, either for the dispatch (pipeline, buffers, threadgroup_mem, threadgroups, threadgroup_size) or for the auto-barrier (reads, writes).


pub fn dispatch_tracked_threads_with_args(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, KernelArg<'_>)],
    reads: &[&MlxBuffer],
    writes: &[&MlxBuffer],
    grid_size: MTLSize,
    threadgroup_size: MTLSize,
)

Auto-barrier-aware dispatch_threads variant with KernelArg bindings.

Mirrors encode_with_args — the dispatch_threads (per-thread grid) flavor, as opposed to the dispatch_thread_groups flavor of dispatch_tracked_threadgroups_with_args. See that method for the behavioral contract.

Closes the iter38-audit coverage gap: callers that use per-thread grids — rope.rs:108 (IMROPE), sigmoid_mul.rs:76 (sigmoid-mul), and encode_helpers.rs:41 (kv_cache_copy) — need a dispatch_threads flavor of the tracked dispatch because their grid sizes are expressed in threads, not threadgroups.

Note: the simpler (slot, &MlxBuffer) form (from encode) is a special case of this method: callers can wrap each binding as KernelArg::Buffer(buf) and reuse this single tracked variant instead of adding yet another one.
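Wrapping (slot, &MlxBuffer) pairs as KernelArg::Buffer is a one-line map. A sketch with hypothetical stand-ins (Buf for MlxBuffer, Arg for KernelArg, as_args for the conversion), not the crate's actual types:

```rust
/// Hypothetical stand-ins for MlxBuffer and KernelArg.
pub struct Buf {
    pub id: u32,
}

pub enum Arg<'a> {
    Buffer(&'a Buf),
    Bytes(&'a [u8]),
}

/// Convert the simple (slot, &buffer) form used by encode into
/// (slot, Arg::Buffer) pairs for a KernelArg-based tracked dispatch.
pub fn as_args<'a>(buffers: &[(u64, &'a Buf)]) -> Vec<(u64, Arg<'a>)> {
    buffers.iter().map(|&(slot, buf)| (slot, Arg::Buffer(buf))).collect()
}
```

The conversion only borrows the buffers, so it adds one small Vec allocation per call and no buffer copies.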


pub fn dispatch_record(&mut self, rec: &DispatchRecord, runtime_buffers: &[&MlxBuffer])

Dispatch a pre-baked record.

ADR-029 iter-175 Step 1d — fast path for decode hot kernels whose pipeline + threadgroup geometry + params bytes are load-time-immutable. runtime_buffers must be in the same order as rec.buffer_slots.

Equivalent Metal command stream to:

encoder.encode_threadgroups_with_args_and_shared(
    &rec.pipeline,
    bindings,  // = runtime_buffers zipped with buffer_slots + (params_slot, Bytes(&rec.params_bytes))
    &rec.threadgroup_mem,
    rec.threadgroups,
    rec.threads_per_tg,
);

— but skips the kernel-name lookup, ggml_type match arms, MTLSize::new, and param-struct field stores that the unbaked path performs on every call.
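The bindings comment in the equivalent call above (runtime buffers zipped with buffer_slots, then the baked params bytes at params_slot) can be sketched as follows. Buf, Arg, Record, and record_bindings are hypothetical stand-ins; the field names buffer_slots, params_slot, and params_bytes come from the description above, and the real DispatchRecord also carries the pipeline and geometry.

```rust
/// Hypothetical stand-ins for MlxBuffer and KernelArg.
pub struct Buf {
    pub id: u32,
}

pub enum Arg<'a> {
    Buffer(&'a Buf),
    Bytes(&'a [u8]),
}

/// Hypothetical subset of a pre-baked dispatch record.
pub struct Record {
    pub buffer_slots: Vec<u64>,
    pub params_slot: u64,
    pub params_bytes: Vec<u8>,
}

/// Pair runtime buffers with the record's slots (same order), then
/// append the pre-baked params bytes at params_slot.
pub fn record_bindings<'a>(rec: &'a Record, runtime: &[&'a Buf]) -> Vec<(u64, Arg<'a>)> {
    assert_eq!(rec.buffer_slots.len(), runtime.len(), "one runtime buffer per slot");
    let mut out: Vec<(u64, Arg<'a>)> = rec
        .buffer_slots
        .iter()
        .zip(runtime)
        .map(|(&slot, &buf)| (slot, Arg::Buffer(buf)))
        .collect();
    out.push((rec.params_slot, Arg::Bytes(&rec.params_bytes[..])));
    out
}
```

Only the buffer pointers change per call; the params bytes are borrowed from the record, which is why the per-call struct field stores disappear.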

Capture mode and auto-barrier are supported exactly as in encode_threadgroups_with_args_and_shared. Before this call, the caller is expected either to have called set_pending_buffer_ranges (capture mode) or to rely on auto-barrier for dataflow correctness, matching the contract of the unbaked dispatch_tracked_* family.


pub fn force_barrier_and_reset_tracker(&mut self)

Force a barrier and reset the auto-barrier tracker.

Use at boundaries where the caller knows a barrier is required regardless of dataflow — typically before reading data back to CPU, or at the end of an op group whose internal dependencies the tracker can’t see (e.g. host-driven memcpy).

Equivalent to memory_barrier() plus a MemRanges::reset() when HF2Q_AUTO_BARRIER=1; equivalent to plain memory_barrier() otherwise.


pub fn mem_ranges_len(&self) -> usize

Diagnostic accessor: the number of ranges currently recorded in this encoder’s MemRanges tracker. Always zero unless HF2Q_AUTO_BARRIER=1 and at least one dispatch_tracked call has fired since the tracker was last reset.


pub fn replay_dispatch(
    &mut self,
    pipeline: &ComputePipelineStateRef,
    bindings: &[(u64, RecordedBinding)],
    threadgroup_memory: &[(u64, u64)],
    threads_per_grid: MTLSize,
    threads_per_threadgroup: MTLSize,
    dispatch_kind: DispatchKind,
)

Replay a single captured dispatch node into this encoder.

This is the inverse of capture: it takes a previously recorded CapturedNode::Dispatch and encodes it into the live Metal encoder. Barrier nodes are handled by the caller (ComputeGraph::encode_sequential).

Does NOT increment DISPATCH_COUNT — that was already counted at capture time.
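Since barrier nodes are handled by the caller, a sequential replay is essentially a loop that emits barriers itself and hands each dispatch node to replay_dispatch. A hypothetical sketch of that division of labor, with Node and encode_sequential as illustrative names and a Vec<String> standing in for the live Metal encoder:

```rust
/// Hypothetical captured graph node.
pub enum Node {
    Dispatch { kernel: &'static str },
    Barrier,
}

/// Walk captured nodes in order: the walker emits barriers itself and
/// "replays" each dispatch (here, by logging it). No dispatch counter
/// is touched, since dispatches were counted when they were captured.
pub fn encode_sequential(nodes: &[Node]) -> Vec<String> {
    let mut ops = Vec::new();
    for node in nodes {
        match node {
            Node::Barrier => ops.push("barrier".to_string()),
            Node::Dispatch { kernel } => ops.push(format!("dispatch {kernel}")),
        }
    }
    ops
}
```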


pub fn commit_and_wait(&mut self) -> Result<()>

Commit the command buffer and block until the GPU finishes execution.

§Errors

Returns MlxError::CommandBufferError if the GPU reports an error.


pub fn commit_and_wait_labeled(&mut self, label: &str) -> Result<()>

Commit + wait, accumulating GPU wall-clock time under label into the crate::kernel_profile global table when MLX_PROFILE_CB=1 is set. When the env var is unset, this is identical to commit_and_wait — zero overhead.

Used by hf2q’s decode hot path to attribute per-command-buffer GPU time to labeled phases (per-layer attn, per-layer ffn, output_head, etc.) without manually wiring commit_wait_with_gpu_time everywhere.

§Errors

Returns MlxError::CommandBufferError if the GPU reports an error.


pub fn commit_labeled(&mut self, label: &str)

Asynchronous commit with a profiling label. When MLX_PROFILE_CB=1 is set, this redirects to a synchronous commit_and_wait_labeled call so that per-command-buffer GPU time can be captured; this deliberately gives up async pipelining while profiling, trading speed for accurate attribution. When the variable is unset, it is identical to commit.
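The labeled commits accumulate GPU time per label. A minimal sketch of such a table, assuming each command buffer contributes one (gpu_start_s, gpu_end_s) pair like the one commit_wait_with_gpu_time returns; CbProfile, record, and mean_ms are hypothetical names, not the crate::kernel_profile API.

```rust
use std::collections::HashMap;

/// Hypothetical per-label GPU time table.
#[derive(Default)]
pub struct CbProfile {
    totals: HashMap<String, (f64, u64)>, // label -> (total seconds, count)
}

impl CbProfile {
    /// Add one command buffer's GPU interval under `label`.
    pub fn record(&mut self, label: &str, gpu_start_s: f64, gpu_end_s: f64) {
        let entry = self.totals.entry(label.to_string()).or_insert((0.0, 0));
        entry.0 += gpu_end_s - gpu_start_s;
        entry.1 += 1;
    }

    /// Mean GPU milliseconds per command buffer for `label`, if any.
    pub fn mean_ms(&self, label: &str) -> Option<f64> {
        self.totals.get(label).map(|&(total, n)| total * 1000.0 / n as f64)
    }
}
```

Keeping a count alongside the total makes per-call averages available without storing every sample.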


pub fn commit_wait_with_gpu_time(&mut self) -> Result<(f64, f64)>

Commit + wait, returning (gpu_start_s, gpu_end_s) CFTimeInterval timestamps from MTLCommandBuffer’s GPUStartTime/GPUEndTime properties. Both are mach-absolute CFTimeInterval seconds (double).

Intended for HF2Q_PROFILE_GPU_TS=1 per-bucket GPU wall-clock attribution. Adds exactly two ObjC property reads per call on top of the regular commit_and_wait — measured well under 1 μs on M5 Max.

§Errors

Returns MlxError::CommandBufferError if the GPU reports an error.


pub fn commit(&mut self)

Commit the command buffer WITHOUT blocking.

The GPU begins executing the encoded commands immediately. Call wait_until_completed later to block the CPU and check for errors. This allows the CPU to continue doing other work (e.g. preparing the next batch) while the GPU runs.


pub fn wait_until_completed(&self) -> Result<()>

Block until a previously committed command buffer completes.

Must be called after commit. Do not call after commit_and_wait — that method already waits.

§Errors

Returns MlxError::CommandBufferError if the GPU reports an error.
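The commit / wait_until_completed / commit_and_wait contract can be written down as a small state machine. This mock is hypothetical (Cb and CbState are illustrative names, errors are plain strings) and enforces the documented ordering strictly: waiting before commit, or waiting again after commit_and_wait, is reported as an error.

```rust
/// Hypothetical command-buffer states.
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum CbState {
    Encoding,
    Committed,
    Completed,
}

pub struct Cb {
    pub state: CbState,
}

impl Cb {
    pub fn new() -> Self {
        Cb { state: CbState::Encoding }
    }

    /// Non-blocking submit; the CPU may keep working afterwards.
    pub fn commit(&mut self) {
        assert_eq!(self.state, CbState::Encoding, "already committed");
        self.state = CbState::Committed;
    }

    /// Must follow commit, and only once.
    pub fn wait_until_completed(&mut self) -> Result<(), String> {
        match self.state {
            CbState::Committed => {
                self.state = CbState::Completed;
                Ok(())
            }
            CbState::Encoding => Err("wait before commit".into()),
            CbState::Completed => Err("already waited".into()),
        }
    }

    /// Commit and block in one step.
    pub fn commit_and_wait(&mut self) -> Result<(), String> {
        self.commit();
        self.wait_until_completed()
    }
}
```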


pub fn metal_command_buffer(&self) -> &CommandBuffer

Borrow the underlying Metal command buffer.

Trait Implementations


impl Drop for CommandEncoder


fn drop(&mut self)

Executes the destructor for this type.

impl Send for CommandEncoder

SAFETY: CommandEncoder is safe to Send across threads provided that:

  1. Only one thread accesses the encoder at a time (exclusive ownership).
  2. The encoder is not used concurrently from multiple threads.

Metal command buffers and compute encoders are thread-safe for exclusive access (Apple documentation: “You can create command buffers, encode commands, and submit them from any thread”). The raw pointer active_encoder borrows from cmd_buf and is valid as long as cmd_buf is alive — this invariant holds across thread boundaries because both fields move together.

This matches llama.cpp’s pattern of encoding command buffers on GCD worker threads via dispatch_apply, and is used for the dual-buffer pipeline where buf1 is encoded on a worker thread while buf0 executes.
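The Send argument rests on an owned resource and a raw pointer into it always moving together. A hypothetical sketch of the same shape: Owner plays the role of cmd_buf plus active_encoder, using a heap allocation so that moving the struct never moves the pointee. Without the unsafe impl Send, the thread::spawn call would not compile, because raw pointers are not Send.

```rust
use std::thread;

/// Hypothetical owner of a heap value plus a raw pointer into it,
/// mirroring how active_encoder points into cmd_buf.
pub struct Owner {
    data: *mut u64,  // owned allocation (from Box::into_raw)
    ptr: *const u64, // borrowed view into `data`
}

impl Owner {
    pub fn new(v: u64) -> Self {
        let data = Box::into_raw(Box::new(v));
        Owner { data, ptr: data as *const u64 }
    }

    pub fn read(&self) -> u64 {
        // SAFETY: ptr points into the allocation owned by `data`, which
        // lives exactly as long as `self` and does not move with it.
        unsafe { *self.ptr }
    }
}

impl Drop for Owner {
    fn drop(&mut self) {
        // SAFETY: `data` came from Box::into_raw and is freed exactly once.
        unsafe { drop(Box::from_raw(self.data)) }
    }
}

// SAFETY: `data` and `ptr` always move together, and only the owning
// thread can reach them, so sending the whole struct is sound.
unsafe impl Send for Owner {}

/// Move an Owner to a worker thread and read through the pointer there.
pub fn read_on_worker(owner: Owner) -> u64 {
    thread::spawn(move || owner.read()).join().unwrap()
}
```

Note that the type is deliberately not Sync: sharing one Owner between threads concurrently is exactly what the safety conditions above rule out.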
