// mlx_native/encoder.rs
//! [`CommandEncoder`] — batched GPU command submission.
//!
//! Wraps a Metal command buffer. Encode one or more compute kernel dispatches,
//! then call [`commit_and_wait`](CommandEncoder::commit_and_wait) to submit the
//! entire batch and block until the GPU finishes.
//!
//! # Persistent compute encoder
//!
//! A single Metal `ComputeCommandEncoder` is kept alive across multiple
//! dispatches within the same command buffer. This avoids the overhead of
//! creating and ending a new compute encoder per dispatch — the same pattern
//! candle uses (`compute_per_buffer`). On a forward pass with ~800 dispatches
//! this saves ~800 encoder create/end cycles.
//!
//! # Capture mode (Phase 4e.1)
//!
//! When `start_capture()` is called, subsequent dispatches are recorded into a
//! `Vec<CapturedNode>` instead of being encoded into Metal. `memory_barrier()`
//! records a barrier sentinel. Call `take_capture()` to extract the recorded
//! graph for later replay via `ComputeGraph::encode_sequential()`.
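//!
//! A minimal sketch of the capture → replay flow (the `device`, `pipeline`,
//! and buffer bindings are assumed to exist; `ComputeGraph` lives outside
//! this module):
//!
//! ```ignore
//! let mut enc = device.command_encoder()?;
//! enc.start_capture();
//! enc.encode_threadgroups(&pipeline, &[(0, &input), (1, &output)], tg, tg_size);
//! enc.memory_barrier(); // recorded as a CapturedNode::Barrier sentinel
//! let nodes = enc.take_capture().unwrap();
//! // Later: replay the recorded nodes into a live encoder, e.g. via
//! // ComputeGraph::encode_sequential(...).
//! ```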

use std::sync::atomic::{AtomicU64, Ordering};

use metal::{
    CommandBuffer, CommandQueue, ComputeCommandEncoderRef, ComputePipelineState,
    ComputePipelineStateRef, MTLCommandBufferStatus, MTLDispatchType, MTLSize,
};
#[allow(unused_imports)]
use objc::{msg_send, sel, sel_impl};

use crate::buffer::MlxBuffer;
use crate::error::{MlxError, Result};
use crate::residency::ResidencySet;

/// A buffer or inline-bytes binding for a compute kernel argument slot.
pub enum KernelArg<'a> {
    /// Bind an existing Metal buffer at the given index.
    Buffer(&'a MlxBuffer),
    /// Bind an existing Metal buffer at the given index with a byte offset.
    BufferWithOffset(&'a MlxBuffer, u64),
    /// Bind inline bytes (small constant data) at the given index.
    /// The data must be `Pod` and is copied into the command encoder.
    Bytes(&'a [u8]),
}

/// Convert a `Pod` value to a byte slice suitable for `KernelArg::Bytes`.
///
/// # Safety
///
/// The caller must ensure `T` has the same layout as the corresponding
/// MSL struct in the shader (matching field order, sizes, and alignment).
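///
/// For example, a sketch passing a small params struct to a kernel
/// (`GemvParams` is a hypothetical shader-side struct, not part of this crate):
///
/// ```ignore
/// #[repr(C)]
/// #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
/// struct GemvParams {
///     rows: u32,
///     cols: u32,
/// }
///
/// let params = GemvParams { rows: 4096, cols: 11008 };
/// let arg = KernelArg::Bytes(as_bytes(&params));
/// ```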
pub fn as_bytes<T: bytemuck::Pod>(val: &T) -> &[u8] {
    bytemuck::bytes_of(val)
}

// ---------------------------------------------------------------------------
// Capture-mode types (Phase 4e.1 — Graph IR)
// ---------------------------------------------------------------------------

/// A recorded kernel argument binding.
///
/// When the encoder is in capture mode, each `set_buffer` / `set_bytes` call
/// is stored as a `RecordedBinding` instead of being applied to Metal.
#[derive(Clone)]
pub enum RecordedBinding {
    /// A Metal buffer at the given offset.
    Buffer {
        metal_buffer: metal::Buffer,
        offset: u64,
    },
    /// Inline bytes (small constant data, copied).
    Bytes(Vec<u8>),
}

/// How to dispatch the recorded kernel.
#[derive(Clone, Copy, Debug)]
pub enum DispatchKind {
    /// `dispatch_threads(grid_size, threadgroup_size)` — Metal picks threadgroup count.
    Threads,
    /// `dispatch_thread_groups(threadgroups, threadgroup_size)` — caller specifies threadgroup count.
    ThreadGroups,
}

/// Operation kind tag for captured nodes, used by the fusion pass (4e.2).
///
/// When the encoder is in capture mode, each dispatch can be tagged with an
/// `OpKind` so the fusion pass can identify fuseable sequences without
/// inspecting pipeline names.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CapturedOpKind {
    /// RMS normalization (with learned scale).
    RmsNorm,
    /// Elementwise multiply.
    ElemMul,
    /// Elementwise add.
    ElemAdd,
    /// Scaled dot-product attention (NOT reorderable — breaks lookahead).
    Sdpa,
    /// Softmax (NOT reorderable — breaks lookahead).
    Softmax,
    /// Any other operation — treated as reorderable by the graph optimizer.
    Other,
}

impl CapturedOpKind {
    /// Whether this captured op kind is safe to reorder past in the graph
    /// optimizer (Phase 4e.3).
    ///
    /// Mirrors the `h_safe` whitelist from llama.cpp's
    /// `ggml_metal_graph_optimize_reorder`. Non-safe ops break the 64-node
    /// lookahead — the reorder pass cannot look past them.
    pub fn is_reorderable(&self) -> bool {
        match self {
            Self::Sdpa | Self::Softmax => false,
            Self::RmsNorm | Self::ElemMul | Self::ElemAdd | Self::Other => true,
        }
    }
}

/// A memory range annotation: (start_address, end_address).
///
/// Represents a contiguous GPU buffer region for conflict detection in the
/// reorder pass (Phase 4e.3). Addresses are CPU-visible `contents_ptr()`
/// values, which on Apple Silicon unified memory equal the GPU addresses.
pub type MemRange = (usize, usize);

/// A single captured compute dispatch or barrier sentinel.
///
/// Created when the encoder is in capture mode. Replayed later by
/// `ComputeGraph::encode_sequential()`.
#[derive(Clone)]
pub enum CapturedNode {
    /// A compute dispatch to replay.
    Dispatch {
        /// Pipeline state object to bind.
        pipeline: ComputePipelineState,
        /// Kernel argument bindings: (slot_index, binding).
        bindings: Vec<(u64, RecordedBinding)>,
        /// Grid or threadgroup count (interpretation depends on `dispatch_kind`).
        threads_per_grid: MTLSize,
        /// Threads per threadgroup.
        threads_per_threadgroup: MTLSize,
        /// Optional threadgroup memory allocations: (index, byte_length).
        threadgroup_memory: Vec<(u64, u64)>,
        /// Whether this is a dispatch_threads or dispatch_thread_groups call.
        dispatch_kind: DispatchKind,
        /// Operation kind tag for the fusion pass (4e.2).
        /// Defaults to `Other` if not explicitly set via `set_op_kind()`.
        op_kind: CapturedOpKind,
        /// Read buffer ranges for reorder conflict detection (4e.3).
        /// Populated from `barrier_between` calls in capture mode.
        reads: Vec<MemRange>,
        /// Write buffer ranges for reorder conflict detection (4e.3).
        /// Populated from `barrier_between` calls in capture mode.
        writes: Vec<MemRange>,
    },
    /// A memory barrier sentinel — forces a barrier at replay time.
    Barrier,
}

/// Apply a slice of `KernelArg` bindings to a compute encoder.
///
/// `KernelArg::Buffer(buf)` propagates the `MlxBuffer::byte_offset()` so
/// `slice_view`-derived sub-buffers are honored automatically — the
/// kernel sees memory starting at the slice's offset. This matches the
/// documented contract of `slice_view` and the offset-handling in the
/// other binding paths in this file (`encode`, `encode_threadgroups`,
/// `encode_threadgroups_with_shared`, replay). Without it, every
/// `slice_view`-derived buffer bound via `KernelArg::Buffer` silently
/// exposes the entire underlying allocation — surfaced by hf2q's
/// nomic-bert iter-79 cosine parity bisection (cosine 0.098 → 0.999962
/// after fix).
///
/// `KernelArg::BufferWithOffset(buf, offset)` continues to use the
/// explicit `offset` argument verbatim (callers asking for an explicit
/// offset get exactly that, even on sliced buffers). The two API
/// surfaces are intentional: implicit (sliced views auto-propagate) vs.
/// explicit (caller-controlled).
#[inline]
fn apply_bindings(encoder: &ComputeCommandEncoderRef, bindings: &[(u64, KernelArg<'_>)]) {
    for &(index, ref arg) in bindings {
        match arg {
            KernelArg::Buffer(buf) => {
                encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
            }
            KernelArg::BufferWithOffset(buf, offset) => {
                encoder.set_buffer(index, Some(buf.metal_buffer()), *offset);
            }
            KernelArg::Bytes(bytes) => {
                encoder.set_bytes(index, bytes.len() as u64, bytes.as_ptr() as *const _);
            }
        }
    }
}

/// Number of times `commit_and_wait()` has been called (CPU sync points).
static SYNC_COUNT: AtomicU64 = AtomicU64::new(0);

/// Number of times an encode method has been called (GPU dispatches).
static DISPATCH_COUNT: AtomicU64 = AtomicU64::new(0);

/// Number of `MTLCommandBuffer` instances created via `CommandEncoder::new`.
/// Increments once per `device.command_encoder()` call. Used by hf2q's
/// `HF2Q_DECODE_PROFILE` instrumentation to measure command-buffer
/// overhead per decode token (ADR-012 §Optimize / Task #15 follow-up).
static CMD_BUF_COUNT: AtomicU64 = AtomicU64::new(0);

/// Number of `memory_barrier()` calls that reached the
/// `objc::msg_send![encoder, memoryBarrierWithScope:]` site. Capture-mode
/// no-ops and pre-encoder no-ops are excluded so the count reflects
/// actual MTL barriers issued.
///
/// Always tracked — the increment is one atomic op, ~5 ns. ADR-015 H4
/// (Wave 2b hard gate #2) requires per-barrier counter resolution to
/// confirm-or-falsify the barrier-coalescing lever; xctrace TimeProfiler
/// at 1 ms sampling cannot resolve `memory_barrier` even though it fires
/// ~440×/token (`docs/ADR-015-mlx-native-single-cb-decode.md` §"P3a' live
/// profile pass" hypothesis register row H4).
static BARRIER_COUNT: AtomicU64 = AtomicU64::new(0);

/// Total nanoseconds spent inside the `objc::msg_send!` barrier site,
/// summed across all calls. ONLY updated when the env var
/// `MLX_PROFILE_BARRIERS=1` is set on the process (cached on first
/// `memory_barrier` call). When disabled the timing path is a single
/// branch + the unconditional barrier dispatch — same hot-path cost as
/// before this counter was added.
///
/// Why env-gated: timing adds 2 × `Instant::now()` (~50–100 ns each via
/// `mach_absolute_time`) per barrier. At ~440 barriers/token that is
/// ~22–44 µs/token of measurement overhead — comparable to what we are
/// trying to measure. Production must keep this off; profiling runs
/// opt-in.
static BARRIER_NS: AtomicU64 = AtomicU64::new(0);

/// Reset all counters to zero.
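///
/// A typical profiling readout looks like the sketch below
/// (`run_decode_token()` is a hypothetical caller-side helper):
///
/// ```ignore
/// reset_counters();
/// run_decode_token()?;
/// eprintln!(
///     "dispatches={} syncs={} cbs={} barriers={}",
///     dispatch_count(),
///     sync_count(),
///     cmd_buf_count(),
///     barrier_count(),
/// );
/// ```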
pub fn reset_counters() {
    SYNC_COUNT.store(0, Ordering::Relaxed);
    DISPATCH_COUNT.store(0, Ordering::Relaxed);
    CMD_BUF_COUNT.store(0, Ordering::Relaxed);
    BARRIER_COUNT.store(0, Ordering::Relaxed);
    BARRIER_NS.store(0, Ordering::Relaxed);
}

/// Read the current value of `SYNC_COUNT`.
///
/// Each call to `commit_and_wait()` increments this counter.
pub fn sync_count() -> u64 {
    SYNC_COUNT.load(Ordering::Relaxed)
}

/// Read the current value of `DISPATCH_COUNT`.
///
/// Every `encode*` method (`encode()`, `encode_threadgroups()`,
/// `encode_threadgroups_with_shared()`, and the `*_with_args` variants)
/// increments this counter.
pub fn dispatch_count() -> u64 {
    DISPATCH_COUNT.load(Ordering::Relaxed)
}

/// Read the current value of `CMD_BUF_COUNT`.
///
/// Each `CommandEncoder::new` (i.e. each `MlxDevice::command_encoder()`)
/// increments this counter. Useful for diagnosing per-dispatch Metal
/// command-buffer overhead in inner loops.
pub fn cmd_buf_count() -> u64 {
    CMD_BUF_COUNT.load(Ordering::Relaxed)
}

/// Read the current value of `BARRIER_COUNT`.
///
/// Each `memory_barrier()` call that reaches the underlying
/// `objc::msg_send![encoder, memoryBarrierWithScope:]` site increments this
/// counter. Capture-mode no-ops and pre-encoder no-ops are excluded.
/// ADR-015 H4 hypothesis: ~440 barriers/token on the qwen35 decode hot
/// path (verify against this counter).
pub fn barrier_count() -> u64 {
    BARRIER_COUNT.load(Ordering::Relaxed)
}

/// Read the total nanoseconds spent in the `memoryBarrierWithScope:`
/// `objc::msg_send!` site. Only non-zero when `MLX_PROFILE_BARRIERS=1`
/// was in the environment at the time of the first `memory_barrier()`
/// call (the env check is cached on first use).
///
/// Combined with [`barrier_count`] this gives µs/barrier =
/// `barrier_total_ns() / 1000 / barrier_count()`.
pub fn barrier_total_ns() -> u64 {
    BARRIER_NS.load(Ordering::Relaxed)
}

/// Whether barrier timing is enabled (env-gated, cached on first check).
///
/// Reading the env var via `std::env::var` is itself non-trivial; using
/// `OnceLock` caches the decision so the per-barrier branch is a single
/// atomic-load + compare.
fn barrier_profile_enabled() -> bool {
    use std::sync::OnceLock;
    static FLAG: OnceLock<bool> = OnceLock::new();
    *FLAG.get_or_init(|| {
        std::env::var("MLX_PROFILE_BARRIERS")
            .map(|v| v == "1")
            .unwrap_or(false)
    })
}

/// Whether `MLX_UNRETAINED_REFS=1` is set in the process environment.
///
/// ADR-015 iter13 — when true, `CommandEncoder::new_with_residency` opens
/// each `MTLCommandBuffer` via
/// [`CommandQueueRef::new_command_buffer_with_unretained_references`]
/// instead of the default `commandBuffer`. llama.cpp's per-token decode
/// CBs use this same call (`/opt/llama.cpp/ggml/src/ggml-metal/`
/// `ggml-metal-context.m:512` `[queue commandBufferWithUnretainedReferences]`)
/// and gain ~3-5% wall on M-series GPUs by skipping per-buffer-binding ARC
/// retains on submit.
///
/// **Caller-side prerequisite.** Every Metal buffer bound to a dispatch
/// must outlive the CB — see the docstring on
/// [`CommandEncoder::new_with_residency`] for the full caller contract.
/// In hf2q, the per-decode-token `MlxBufferPool` (`buffer_pool.rs`)
/// already keeps ARC clones alive in its `in_use` list across the entire
/// decode token; routing transient scratches through that pool is the
/// canonical way to satisfy the contract.
///
/// Cached on first read via `OnceLock` to keep the per-CB-construction
/// branch single-atomic-load fast. Default OFF so any production decode
/// run that does NOT explicitly set the var preserves retained-refs
/// behavior verbatim.
fn unretained_refs_enabled() -> bool {
    use std::sync::OnceLock;
    static FLAG: OnceLock<bool> = OnceLock::new();
    *FLAG.get_or_init(|| {
        std::env::var("MLX_UNRETAINED_REFS")
            .map(|v| v == "1")
            .unwrap_or(false)
    })
}

/// Issue the underlying Metal `memoryBarrierWithScope:` ObjC msg_send.
///
/// Held in its own `#[inline(never)]` function so xctrace / Instruments
/// has a stable Rust frame to attribute barrier time against, separate
/// from the surrounding encoder accounting. Per ADR-015 §P3a' Codex
/// review Q2: TimeProfiler at 1 ms sampling cannot see this site when
/// inlined; an explicit non-inline frame plus the [`BARRIER_NS`] counter
/// closes the H4 hard gate.
#[inline(never)]
fn issue_metal_buffer_barrier(encoder: &ComputeCommandEncoderRef) {
    // MTLBarrierScopeBuffers = 1 << 0 = 1.
    const MTL_BARRIER_SCOPE_BUFFERS: u64 = 1;
    unsafe {
        let _: () =
            objc::msg_send![encoder, memoryBarrierWithScope: MTL_BARRIER_SCOPE_BUFFERS];
    }
}

/// A batched compute command encoder.
///
/// Keeps a single Metal `ComputeCommandEncoder` alive across multiple
/// dispatches. The encoder is created on the first dispatch and ended
/// only when the command buffer is committed. This mirrors candle's
/// `compute_per_buffer` pattern and avoids per-dispatch encoder overhead.
///
/// # Typical usage
///
/// ```ignore
/// let mut enc = device.command_encoder()?;
/// // Multiple dispatches share the same compute encoder:
/// enc.encode_threadgroups(pipeline1, &buffers1, tg1, tg_size1);
/// enc.encode_threadgroups(pipeline2, &buffers2, tg2, tg_size2);
/// enc.commit_and_wait()?;
/// ```
pub struct CommandEncoder {
    cmd_buf: CommandBuffer,
    // SAFETY marker: see unsafe Send impl below.
    /// Raw pointer to the persistent compute encoder.
    /// Non-null when a compute pass is active.
    /// The encoder borrows from `cmd_buf` but we cannot express this
    /// lifetime in safe Rust, so we use a raw pointer.
    /// SAFETY: the pointer is valid as long as `cmd_buf` is alive and
    /// `end_encoding()` has not been called on it.
    active_encoder: *const ComputeCommandEncoderRef,
    /// When `Some`, dispatches are recorded here instead of being encoded
    /// into Metal. Set via `start_capture()`, extracted via `take_capture()`.
    capture: Option<Vec<CapturedNode>>,
    /// Op kind tag for the NEXT captured dispatch. Set via `set_op_kind()`,
    /// consumed (reset to `Other`) when a dispatch is captured.
    pending_op_kind: CapturedOpKind,
    /// Pending read buffer ranges for the NEXT captured dispatch.
    /// Set via `set_pending_buffer_ranges()`, consumed when the next dispatch
    /// is captured. Used by the reorder pass (Phase 4e.3).
    pending_reads: Vec<MemRange>,
    /// Pending write buffer ranges for the NEXT captured dispatch.
    pending_writes: Vec<MemRange>,
    /// ADR-015 iter8e (Phase 3b): residency set whose pending add/remove
    /// staging is flushed at every `commit*` boundary.
    ///
    /// Cloned from the device at `device.command_encoder()` time. `None`
    /// when residency sets are disabled (HF2Q_NO_RESIDENCY=1, macOS<15,
    /// or test-only `CommandEncoder::new` from a residency-less queue).
    residency_set: Option<ResidencySet>,
}

/// SAFETY: CommandEncoder is safe to Send across threads provided that:
/// 1. Only one thread accesses the encoder at a time (exclusive ownership).
/// 2. The encoder is not used concurrently from multiple threads.
///
/// Metal command buffers and compute encoders are thread-safe for exclusive
/// access (Apple documentation: "You can create command buffers, encode
/// commands, and submit them from any thread"). The raw pointer
/// `active_encoder` borrows from `cmd_buf` and is valid as long as
/// `cmd_buf` is alive — this invariant holds across thread boundaries
/// because both fields move together.
///
/// This matches llama.cpp's pattern of encoding command buffers on GCD
/// worker threads via `dispatch_apply`, and is used for the dual-buffer
/// pipeline where buf1 is encoded on a worker thread while buf0 executes.
unsafe impl Send for CommandEncoder {}

impl CommandEncoder {
    /// Create a new command encoder from the given command queue.
    ///
    /// This immediately creates a Metal command buffer.
    ///
    /// # Why retained references
    ///
    /// We use the regular `commandBuffer` (Metal retains every bound
    /// resource for the lifetime of the buffer) rather than
    /// `commandBufferWithUnretainedReferences`. llama.cpp uses unretained
    /// refs for an additional perf bump (~3-5% on M-series GPUs), but the
    /// hf2q dispatch pattern allocates many transient scratch buffers
    /// inside helper functions (`apply_proj` → `weight_bf16_owned`,
    /// `apply_pre_norm` → `params`, etc.) that go out of scope at the
    /// helper's return. With unretained refs the metal::Buffer's ARC
    /// drops to zero, freeing the underlying GPU memory before the
    /// dispatch executes. Verified 2026-04-26: switching to unretained
    /// hits "Command buffer error: GPU command buffer completed with
    /// error status" on the first MoE FFN dispatch.
    ///
    /// To enable unretained refs in the future, every helper that
    /// allocates and dispatches must thread its scratch buffers up to a
    /// caller scope that outlives the eventual commit, OR all such
    /// scratch must come from the per-decode-token pool (which already
    /// ARC-retains in its in_use list). Today the lm_head + router-
    /// download paths are still unpooled.
    #[allow(dead_code)]
    pub(crate) fn new(queue: &CommandQueue) -> Result<Self> {
        Self::new_with_residency(queue, None)
    }

    /// Create a new command encoder, optionally bound to a residency set so
    /// `commit*` boundaries can flush deferred add/remove staging.
    ///
    /// ADR-015 iter8e (Phase 3b): the encoder's `commit_and_wait`,
    /// `commit_and_wait_labeled`, `commit`, `commit_labeled`,
    /// `commit_wait_with_gpu_time` all call
    /// [`ResidencySet::flush_pending`](ResidencySet::flush_pending) before
    /// submitting the Metal command buffer. This converts the
    /// per-allocation `[set commit]` storm
    /// (~880 commits/decode-token in iter8d/8e claude+codex variants) into
    /// at most one commit per CB submission — mirrors llama.cpp's
    /// `ggml-metal-device.m:1378-1382` pattern (batch addAllocation in
    /// loop, commit ONCE).
    ///
    /// ADR-015 iter13: when the `MLX_UNRETAINED_REFS=1` env var is set at
    /// process start, this constructor uses
    /// [`CommandQueueRef::new_command_buffer_with_unretained_references`]
    /// instead of `new_command_buffer`. llama.cpp's per-token decode CBs
    /// use `commandBufferWithUnretainedReferences` (see
    /// `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m:512`) which
    /// skips Metal's per-buffer-binding ARC-retain on submit and saves
    /// ~3-5% on M-series GPUs (per the docstring above).
    ///
    /// **Caller contract under unretained refs.** Every Metal buffer bound
    /// to a dispatch in this CB MUST outlive the CB's GPU completion. In
    /// the hf2q decode path, that means every transient scratch must be
    /// either (a) backed by the per-decode-token arena pool
    /// (`MlxBufferPool` keeps an ARC clone in `in_use` until the next
    /// `reset` — see `buffer_pool.rs:60`) or (b) hoisted to a caller scope
    /// that lives across the terminal `commit_and_wait_labeled`. Helpers
    /// in `apply_proj` / `apply_pre_norm` / lm_head cast / router-download
    /// that allocated transients via `device.alloc_buffer` and dropped
    /// them at function return MUST be lifted to `pooled_alloc_buffer`
    /// before `MLX_UNRETAINED_REFS=1` is enabled, or the first MoE FFN
    /// dispatch will crash with "Command buffer error: GPU command buffer
    /// completed with error status" (verified 2026-04-26).
    ///
    /// The default (`MLX_UNRETAINED_REFS` unset) preserves retained-refs
    /// behavior verbatim — this is the sourdough-safe path.
    pub(crate) fn new_with_residency(
        queue: &CommandQueue,
        residency_set: Option<ResidencySet>,
    ) -> Result<Self> {
        let cmd_buf = if unretained_refs_enabled() {
            queue.new_command_buffer_with_unretained_references().to_owned()
        } else {
            queue.new_command_buffer().to_owned()
        };
        CMD_BUF_COUNT.fetch_add(1, Ordering::Relaxed);
        Ok(Self {
            cmd_buf,
            active_encoder: std::ptr::null(),
            capture: None,
            pending_op_kind: CapturedOpKind::Other,
            pending_reads: Vec::new(),
            pending_writes: Vec::new(),
            residency_set,
        })
    }

    /// Enable capture mode.
    ///
    /// All subsequent dispatch and barrier calls will be recorded into a
    /// `Vec<CapturedNode>` instead of being encoded into Metal.
    /// Call `take_capture()` to extract the recorded nodes.
    pub fn start_capture(&mut self) {
        self.capture = Some(Vec::with_capacity(128));
    }

    /// Whether the encoder is currently in capture mode.
    pub fn is_capturing(&self) -> bool {
        self.capture.is_some()
    }

    /// Extract the captured nodes, ending capture mode.
    ///
    /// Returns `None` if capture mode was not active.
    pub fn take_capture(&mut self) -> Option<Vec<CapturedNode>> {
        self.capture.take()
    }

    /// Tag the NEXT captured dispatch with the given operation kind.
    ///
    /// The tag is consumed (reset to `Other`) after the next dispatch is
    /// captured. Only meaningful in capture mode — has no effect on
    /// direct-dispatch encoding.
    ///
    /// Used by op dispatch functions to annotate captures for the fusion
    /// pass (Phase 4e.2).
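    ///
    /// For example, a sketch tagging an RMS-norm dispatch during capture
    /// (the pipeline and buffers are assumed to exist):
    ///
    /// ```ignore
    /// enc.set_op_kind(CapturedOpKind::RmsNorm);
    /// enc.encode_threadgroups(&rms_norm_pipeline, &bufs, tg, tg_size);
    /// // The tag was consumed; the next dispatch defaults back to `Other`.
    /// ```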
    pub fn set_op_kind(&mut self, kind: CapturedOpKind) {
        self.pending_op_kind = kind;
    }

    /// Consume and return the pending op kind, resetting it to `Other`.
    fn take_pending_op_kind(&mut self) -> CapturedOpKind {
        let kind = self.pending_op_kind;
        self.pending_op_kind = CapturedOpKind::Other;
        kind
    }

    /// Stash buffer range annotations for the NEXT captured dispatch.
    ///
    /// Called by `GraphSession::barrier_between()` in capture mode to record
    /// which buffers the next dispatch reads from and writes to. The ranges
    /// are consumed by the next `encode_*` call and attached to the captured
    /// `CapturedNode::Dispatch`.
    ///
    /// Only meaningful in capture mode — has no effect on direct-dispatch.
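    ///
    /// A sketch of building the ranges from a buffer's CPU-visible address
    /// (assumes `contents_ptr()` as described on [`MemRange`]; the input
    /// range would be computed the same way):
    ///
    /// ```ignore
    /// let out_start = out_buf.contents_ptr() as usize;
    /// let out_end = out_start + out_buf.metal_buffer().length() as usize;
    /// enc.set_pending_buffer_ranges(
    ///     vec![(in_start, in_end)],   // read by the next dispatch
    ///     vec![(out_start, out_end)], // written by the next dispatch
    /// );
    /// ```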
    pub fn set_pending_buffer_ranges(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
        self.pending_reads = reads;
        self.pending_writes = writes;
    }

    /// Patch the last captured dispatch node's empty reads/writes with the
    /// given ranges. No-op if not capturing, or if the last node isn't a
    /// Dispatch, or if its ranges are already populated.
    ///
    /// Used by `GraphSession::track_dispatch` in recording mode to annotate
    /// dispatches that were called without a preceding `barrier_between`.
    pub fn annotate_last_dispatch_if_missing(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
        if let Some(ref mut nodes) = self.capture {
            if let Some(CapturedNode::Dispatch { reads: r, writes: w, .. }) = nodes.last_mut() {
                if r.is_empty() && !reads.is_empty() {
                    *r = reads;
                }
                if w.is_empty() && !writes.is_empty() {
                    *w = writes;
                }
            }
        }
    }

    /// Consume and return the pending buffer range annotations.
    fn take_pending_buffer_ranges(&mut self) -> (Vec<MemRange>, Vec<MemRange>) {
        let reads = std::mem::take(&mut self.pending_reads);
        let writes = std::mem::take(&mut self.pending_writes);
        (reads, writes)
    }

    /// Record buffer bindings into `RecordedBinding` form.
    fn record_buffer_bindings(buffers: &[(u64, &MlxBuffer)]) -> Vec<(u64, RecordedBinding)> {
        buffers
            .iter()
            .map(|&(index, buf)| {
                (
                    index,
                    RecordedBinding::Buffer {
                        metal_buffer: buf.metal_buffer().clone(),
                        offset: buf.byte_offset(),
                    },
                )
            })
            .collect()
    }

    /// Record `KernelArg` bindings into `RecordedBinding` form.
    ///
    /// `KernelArg::Buffer(buf)` records `buf.byte_offset()` so capture →
    /// replay round-trips of `slice_view`-derived buffers preserve their
    /// offsets, matching the behavior of `record_buffer_bindings` above.
    fn record_arg_bindings(bindings: &[(u64, KernelArg<'_>)]) -> Vec<(u64, RecordedBinding)> {
        bindings
            .iter()
            .map(|(index, arg)| {
                let recorded = match arg {
                    KernelArg::Buffer(buf) => RecordedBinding::Buffer {
                        metal_buffer: buf.metal_buffer().clone(),
                        offset: buf.byte_offset(),
                    },
                    KernelArg::BufferWithOffset(buf, offset) => RecordedBinding::Buffer {
                        metal_buffer: buf.metal_buffer().clone(),
                        offset: *offset,
                    },
                    KernelArg::Bytes(bytes) => RecordedBinding::Bytes(bytes.to_vec()),
                };
                (*index, recorded)
            })
            .collect()
    }

    /// Get or create the persistent compute encoder.
    ///
    /// On the first call, creates a new compute encoder from the command
    /// buffer. On subsequent calls, returns the existing one.
    ///
    /// SAFETY: The returned reference borrows from `self.cmd_buf` which is
    /// alive for the lifetime of this `CommandEncoder`. The raw pointer is
    /// valid until `end_active_encoder()` is called.
    #[inline]
    fn get_or_create_encoder(&mut self) -> &ComputeCommandEncoderRef {
        if self.active_encoder.is_null() {
            // Use MTLDispatchTypeConcurrent to allow independent dispatches
            // to overlap on the GPU. Memory barriers are inserted between
            // dependent dispatches via `memory_barrier()`.
            let encoder = self
                .cmd_buf
                .compute_command_encoder_with_dispatch_type(MTLDispatchType::Concurrent);
            self.active_encoder = encoder as *const ComputeCommandEncoderRef;
        }
        // SAFETY: active_encoder is non-null and points to a valid encoder
        // owned by cmd_buf.
        unsafe { &*self.active_encoder }
    }

    /// End the active compute encoder if one exists.
    #[inline]
    fn end_active_encoder(&mut self) {
        if !self.active_encoder.is_null() {
            // SAFETY: the pointer was obtained from cmd_buf.new_compute_command_encoder()
            // and has not been ended yet.
            unsafe { &*self.active_encoder }.end_encoding();
            self.active_encoder = std::ptr::null();
        }
    }

    /// Insert a memory barrier with scope `MTLBarrierScopeBuffers`.
    ///
    /// When the encoder uses `MTLDispatchTypeConcurrent`, all dispatches can
    /// execute concurrently unless separated by a barrier. Call this between
    /// dispatches where the later dispatch reads a buffer written by an
    /// earlier one.
    ///
    /// This is the same pattern llama.cpp uses:
    /// `[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]`
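    ///
    /// For example, a sketch ordering a read-after-write pair (the pipelines
    /// and buffers are assumed to exist):
    ///
    /// ```ignore
    /// enc.encode_threadgroups(&matmul, &[(0, &x), (1, &w), (2, &tmp)], tg0, tgs0);
    /// enc.memory_barrier(); // the next dispatch reads `tmp`
    /// enc.encode_threadgroups(&bias_add, &[(0, &tmp), (1, &bias), (2, &y)], tg1, tgs1);
    /// ```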
    #[allow(unexpected_cfgs)]
    pub fn memory_barrier(&mut self) {
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Barrier);
            return;
        }
        if self.active_encoder.is_null() {
            return;
        }
        BARRIER_COUNT.fetch_add(1, Ordering::Relaxed);
        // SAFETY: active_encoder is non-null and valid.
        let encoder = unsafe { &*self.active_encoder };
        if barrier_profile_enabled() {
            // mach_absolute_time path — only on when MLX_PROFILE_BARRIERS=1.
            let start = std::time::Instant::now();
            issue_metal_buffer_barrier(encoder);
            let elapsed_ns = start.elapsed().as_nanos() as u64;
            BARRIER_NS.fetch_add(elapsed_ns, Ordering::Relaxed);
        } else {
            issue_metal_buffer_barrier(encoder);
        }
    }

    /// Set the compute pipeline state for subsequent dispatches.
    ///
    /// This begins a new compute pass if one is not already active.
    pub fn set_pipeline(&mut self, pipeline: &ComputePipelineStateRef) {
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
    }

    /// Bind a buffer to a compute kernel argument slot.
    ///
    /// The `index` corresponds to the `[[buffer(N)]]` attribute in the MSL shader.
    ///
    /// NOTE: this is currently a no-op placeholder; the arguments are ignored.
    /// Use [`encode`](Self::encode) or [`encode_with_args`](Self::encode_with_args)
    /// to bind buffers as part of a dispatch.
    pub fn set_buffer(&self, index: u64, buffer: &MlxBuffer) {
        let _ = (index, buffer);
    }

    /// Dispatch threads on the GPU.
    ///
    /// NOTE: this is currently a no-op placeholder; the arguments are ignored
    /// and no dispatch is encoded. Use [`encode`](Self::encode) or
    /// [`encode_with_args`](Self::encode_with_args) instead.
    pub fn dispatch_threads(&self, grid_size: MTLSize, threadgroup_size: MTLSize) {
        let _ = (grid_size, threadgroup_size);
    }

    /// Encode a complete compute pass: set pipeline, bind buffers, dispatch.
    ///
    /// Reuses the persistent compute encoder — no per-dispatch encoder
    /// creation overhead.
    ///
    /// # Arguments
    ///
    /// * `pipeline` — The compiled compute pipeline to execute.
    /// * `buffers` — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
    /// * `grid_size` — Total number of threads to launch.
    /// * `threadgroup_size` — Threads per threadgroup.
    pub fn encode(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        buffers: &[(u64, &MlxBuffer)],
        grid_size: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
        let op_kind = self.take_pending_op_kind();
        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Dispatch {
                pipeline: pipeline.to_owned(),
                bindings: Self::record_buffer_bindings(buffers),
                threads_per_grid: grid_size,
                threads_per_threadgroup: threadgroup_size,
                threadgroup_memory: Vec::new(),
                dispatch_kind: DispatchKind::Threads,
                op_kind,
                reads: pending_reads,
                writes: pending_writes,
            });
            return;
        }
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
        for &(index, buf) in buffers {
            encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
        }
        encoder.dispatch_threads(grid_size, threadgroup_size);
    }

    /// Encode a compute pass using threadgroups instead of raw thread counts.
    ///
    /// Reuses the persistent compute encoder — no per-dispatch encoder
    /// creation overhead.
    pub fn encode_threadgroups(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        buffers: &[(u64, &MlxBuffer)],
        threadgroups: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
        let op_kind = self.take_pending_op_kind();
        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Dispatch {
                pipeline: pipeline.to_owned(),
                bindings: Self::record_buffer_bindings(buffers),
                threads_per_grid: threadgroups,
                threads_per_threadgroup: threadgroup_size,
                threadgroup_memory: Vec::new(),
                dispatch_kind: DispatchKind::ThreadGroups,
                op_kind,
                reads: pending_reads,
                writes: pending_writes,
            });
            return;
        }
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
        for &(index, buf) in buffers {
            encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
        }
        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
    }

    /// Encode a compute pass using threadgroups with shared threadgroup memory.
    ///
    /// Like [`encode_threadgroups`](Self::encode_threadgroups), but additionally
    /// allocates threadgroup memory at the specified indices. This is required
    /// for kernels that use `threadgroup` memory (e.g. reductions in rms_norm
    /// and softmax).
    ///
    /// # Arguments
    ///
    /// * `pipeline` — The compiled compute pipeline to execute.
    /// * `buffers` — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
    /// * `threadgroup_mem` — Slice of `(index, byte_length)` pairs for threadgroup memory.
    /// * `threadgroups` — Number of threadgroups to dispatch.
    /// * `threadgroup_size` — Threads per threadgroup.
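    ///
    /// A sketch for a row-wise reduction that wants one `f32` of threadgroup
    /// memory per thread (the kernel and sizes are assumptions):
    ///
    /// ```ignore
    /// let tg_size = MTLSize::new(256, 1, 1);
    /// let shared_bytes = (256 * std::mem::size_of::<f32>()) as u64;
    /// enc.encode_threadgroups_with_shared(
    ///     &rms_norm_pipeline,
    ///     &[(0, &input), (1, &scale), (2, &output)],
    ///     &[(0, shared_bytes)],       // threadgroup memory at index 0
    ///     MTLSize::new(n_rows, 1, 1), // one threadgroup per row
    ///     tg_size,
    /// );
    /// ```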
    pub fn encode_threadgroups_with_shared(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        buffers: &[(u64, &MlxBuffer)],
        threadgroup_mem: &[(u64, u64)],
        threadgroups: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
        let op_kind = self.take_pending_op_kind();
        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Dispatch {
                pipeline: pipeline.to_owned(),
                bindings: Self::record_buffer_bindings(buffers),
                threads_per_grid: threadgroups,
                threads_per_threadgroup: threadgroup_size,
                threadgroup_memory: threadgroup_mem.to_vec(),
                dispatch_kind: DispatchKind::ThreadGroups,
                op_kind,
                reads: pending_reads,
                writes: pending_writes,
            });
            return;
        }
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
        for &(index, buf) in buffers {
            encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
        }
        for &(index, byte_length) in threadgroup_mem {
            encoder.set_threadgroup_memory_length(index, byte_length);
        }
        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
    }

    /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_threads).
    ///
    /// Reuses the persistent compute encoder.
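    ///
    /// A sketch mixing a buffer binding with inline constants via [`KernelArg`]
    /// and [`as_bytes`] (`SoftmaxParams` is a hypothetical `Pod` struct):
    ///
    /// ```ignore
    /// let params = SoftmaxParams { n: 4096 };
    /// enc.encode_with_args(
    ///     &softmax_pipeline,
    ///     &[
    ///         (0, KernelArg::Buffer(&logits)),
    ///         (1, KernelArg::Bytes(as_bytes(&params))),
    ///     ],
    ///     MTLSize::new(4096, 1, 1),
    ///     MTLSize::new(256, 1, 1),
    /// );
    /// ```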
    pub fn encode_with_args(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        bindings: &[(u64, KernelArg<'_>)],
        grid_size: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
        let op_kind = self.take_pending_op_kind();
        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Dispatch {
                pipeline: pipeline.to_owned(),
                bindings: Self::record_arg_bindings(bindings),
                threads_per_grid: grid_size,
                threads_per_threadgroup: threadgroup_size,
                threadgroup_memory: Vec::new(),
                dispatch_kind: DispatchKind::Threads,
                op_kind,
                reads: pending_reads,
                writes: pending_writes,
            });
            return;
        }
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
        apply_bindings(encoder, bindings);
        encoder.dispatch_threads(grid_size, threadgroup_size);
    }

    /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_thread_groups).
    ///
    /// Reuses the persistent compute encoder.
    pub fn encode_threadgroups_with_args(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        bindings: &[(u64, KernelArg<'_>)],
        threadgroups: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
        let op_kind = self.take_pending_op_kind();
        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Dispatch {
                pipeline: pipeline.to_owned(),
                bindings: Self::record_arg_bindings(bindings),
                threads_per_grid: threadgroups,
                threads_per_threadgroup: threadgroup_size,
                threadgroup_memory: Vec::new(),
                dispatch_kind: DispatchKind::ThreadGroups,
                op_kind,
                reads: pending_reads,
                writes: pending_writes,
            });
            return;
        }
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
        apply_bindings(encoder, bindings);
        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
    }

    /// Encode a dispatch with mixed buffer/bytes bindings and shared memory.
    ///
    /// Reuses the persistent compute encoder.
    pub fn encode_threadgroups_with_args_and_shared(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        bindings: &[(u64, KernelArg<'_>)],
        threadgroup_mem: &[(u64, u64)],
        threadgroups: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
        let op_kind = self.take_pending_op_kind();
        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Dispatch {
                pipeline: pipeline.to_owned(),
                bindings: Self::record_arg_bindings(bindings),
                threads_per_grid: threadgroups,
                threads_per_threadgroup: threadgroup_size,
                threadgroup_memory: threadgroup_mem.to_vec(),
                dispatch_kind: DispatchKind::ThreadGroups,
                op_kind,
                reads: pending_reads,
                writes: pending_writes,
            });
            return;
        }
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
        apply_bindings(encoder, bindings);
        for &(index, byte_length) in threadgroup_mem {
            encoder.set_threadgroup_memory_length(index, byte_length);
        }
        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
    }

    /// Replay a single captured dispatch node into this encoder.
    ///
    /// This is the inverse of capture: it takes a previously recorded
    /// `CapturedNode::Dispatch` and encodes it into the live Metal encoder.
    /// Barrier nodes are handled by the caller (ComputeGraph::encode_sequential).
    ///
    /// Does NOT increment `DISPATCH_COUNT` — that was already counted at
    /// capture time.
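    ///
    /// A sketch of the replay loop a graph runner might use (the real
    /// `ComputeGraph::encode_sequential` lives outside this module):
    ///
    /// ```ignore
    /// for node in &nodes {
    ///     match node {
    ///         CapturedNode::Barrier => enc.memory_barrier(),
    ///         CapturedNode::Dispatch {
    ///             pipeline, bindings, threadgroup_memory,
    ///             threads_per_grid, threads_per_threadgroup, dispatch_kind, ..
    ///         } => enc.replay_dispatch(
    ///             pipeline, bindings, threadgroup_memory,
    ///             *threads_per_grid, *threads_per_threadgroup, *dispatch_kind,
    ///         ),
    ///     }
    /// }
    /// ```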
    pub fn replay_dispatch(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        bindings: &[(u64, RecordedBinding)],
        threadgroup_memory: &[(u64, u64)],
        threads_per_grid: MTLSize,
        threads_per_threadgroup: MTLSize,
        dispatch_kind: DispatchKind,
    ) {
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
        for (index, binding) in bindings {
            match binding {
                RecordedBinding::Buffer { metal_buffer, offset } => {
                    encoder.set_buffer(*index, Some(metal_buffer), *offset);
                }
                RecordedBinding::Bytes(bytes) => {
                    encoder.set_bytes(
                        *index,
                        bytes.len() as u64,
                        bytes.as_ptr() as *const _,
                    );
                }
            }
        }
        for &(index, byte_length) in threadgroup_memory {
            encoder.set_threadgroup_memory_length(index, byte_length);
        }
        match dispatch_kind {
            DispatchKind::Threads => {
                encoder.dispatch_threads(threads_per_grid, threads_per_threadgroup);
            }
            DispatchKind::ThreadGroups => {
                encoder.dispatch_thread_groups(threads_per_grid, threads_per_threadgroup);
            }
        }
    }

    /// Flush any pending residency-set add/remove staging.
    ///
    /// Hooked at every commit boundary so per-allocation
    /// [`ResidencySet::add_allocation`](ResidencySet::add_allocation) and
    /// [`ResidencySet::remove_allocation`](ResidencySet::remove_allocation)
    /// calls (as fired by `MlxDevice::alloc_buffer` and
    /// `MlxBufferStorage::Drop`) collapse into at most ONE `[set commit]`
    /// per CB submission. Mirrors llama.cpp's
    /// `ggml-metal-device.m:1378-1382` (batch addAllocation in loop,
    /// commit ONCE).
    #[inline]
    fn flush_residency_pending(&self) {
        if let Some(set) = self.residency_set.as_ref() {
            set.flush_pending();
        }
    }

    /// Commit the command buffer and block until the GPU finishes execution.
    ///
    /// # Errors
    ///
    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
    pub fn commit_and_wait(&mut self) -> Result<()> {
        SYNC_COUNT.fetch_add(1, Ordering::Relaxed);

        // End the persistent compute encoder before committing.
        self.end_active_encoder();

        // ADR-015 iter8e (Phase 3b): flush deferred residency-set
        // add/remove staging so the residency hint covers any buffers
        // referenced by this CB. Single commit per CB boundary; no-op
        // when no residency set or no staged changes.
        self.flush_residency_pending();

        self.cmd_buf.commit();
        self.cmd_buf.wait_until_completed();

        match self.cmd_buf.status() {
            MTLCommandBufferStatus::Completed => Ok(()),
            MTLCommandBufferStatus::Error => {
                Err(MlxError::CommandBufferError(
                    "GPU command buffer completed with error status".into(),
                ))
            }
            status => Err(MlxError::CommandBufferError(format!(
                "Unexpected command buffer status after wait: {:?}",
                status
            ))),
        }
    }

    /// Commit + wait, accumulating GPU wall-clock time under `label` into
    /// the [`crate::kernel_profile`] global table when `MLX_PROFILE_CB=1`
    /// is set. When the env var is unset, this is identical to
    /// [`commit_and_wait`](Self::commit_and_wait) — zero overhead.
    ///
    /// Used by hf2q's decode hot path to attribute per-cb GPU time to
    /// labeled phases (per-layer attn, per-layer ffn, output_head, etc.)
    /// without manually wiring `commit_wait_with_gpu_time` everywhere.
    ///
    /// # Errors
    ///
    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
    pub fn commit_and_wait_labeled(&mut self, label: &str) -> Result<()> {
        // ADR-015 iter16 — propagate `label` to MTLCommandBuffer.setLabel and
        // (if a compute encoder is active) MTLComputeCommandEncoder.setLabel
        // BEFORE end_encoding/commit so xctrace's
        // `metal-application-encoders-list` table populates `cmdbuffer-label`
        // and `encoder-label` columns with the semantic phase name (e.g.
        // `layer.attn_moe_ffn`, `output_head.fused_norm_lm_argmax`,
        // `layer.delta_net.ops1-9`). Joined to per-CB GPU duration via
        // `metal-gpu-submission-to-command-buffer-id` (sub_id ↔ encoder_id) →
        // `metal-gpu-execution-points` (per-dispatch start/end), this enables
        // per-phase µs/token attribution comparing hf2q vs llama side-by-side
        // (iter15 §E "iter16 ATTRIBUTION PATH"). Cost is a single ObjC
        // msg_send per CB submission — sub-µs on M5 Max — and a no-op when
        // xctrace isn't recording, so this is unconditionally safe to call on
        // the production decode hot path.
        self.apply_labels(label);
        if crate::kernel_profile::is_enabled() {
            let (start_s, end_s) = self.commit_wait_with_gpu_time()?;
            let ns = ((end_s - start_s).max(0.0) * 1_000_000_000.0) as u64;
            crate::kernel_profile::record(label, ns);
            Ok(())
        } else {
            self.commit_and_wait()
        }
    }

    /// Async commit, but with profiling label. When `MLX_PROFILE_CB=1`
    /// is set, redirects to a synchronous [`commit_and_wait_labeled`]
    /// call to capture per-cb GPU time (this defeats async pipelining
    /// while profiling, which is the whole point — profile-mode is slow
    /// but informative). When unset, identical to [`commit`](Self::commit).
    pub fn commit_labeled(&mut self, label: &str) {
        // ADR-015 iter16 — see `commit_and_wait_labeled` for rationale.
        if crate::kernel_profile::is_enabled() {
            // Profile mode: force sync to capture GPU time. apply_labels is
            // called inside commit_and_wait_labeled — do NOT call it twice
            // here (would double the ObjC msg_send under MLX_PROFILE_CB=1).
            // Errors are logged via stderr because the void return matches
            // commit().
            if let Err(e) = self.commit_and_wait_labeled(label) {
                eprintln!("[mlx-native] commit_labeled({}) failed: {}", label, e);
            }
        } else {
            // Async path: apply labels here so xctrace MST traces capture
            // per-CB phase attribution under default decode (no
            // `MLX_PROFILE_CB`).
            self.apply_labels(label);
            self.commit();
        }
    }

    /// Apply `label` to the underlying `MTLCommandBuffer` and, if a compute
    /// encoder is currently active, to the `MTLComputeCommandEncoder`.
    ///
    /// Called from [`commit_labeled`] and [`commit_and_wait_labeled`] BEFORE
    /// the encoder is ended / the CB is committed so xctrace's
    /// `metal-application-encoders-list` table picks up the label on the
    /// row emitted at the encoder's `endEncoding` / CB submission boundary.
    /// Single ObjC `msg_send` per call (two if an encoder is active); sub-µs
    /// on M5 Max; no-op when xctrace isn't recording.
    ///
    /// Skipped (debug-only assert) if `label` is empty — empty labels would
    /// produce an indistinguishable trace row from the metal-rs default
    /// `Command Buffer 0` placeholder.
    #[inline]
    fn apply_labels(&self, label: &str) {
        debug_assert!(!label.is_empty(), "commit_*_labeled called with empty label");
        if label.is_empty() {
            return;
        }
        self.cmd_buf.set_label(label);
        if !self.active_encoder.is_null() {
            // SAFETY: active_encoder is non-null and points to a live encoder
            // owned by cmd_buf — same invariant as get_or_create_encoder /
            // memory_barrier. set_label is a single property write on the
            // ObjC object; safe before endEncoding.
            unsafe { &*self.active_encoder }.set_label(label);
        }
    }

    /// Commit + wait, returning `(gpu_start_s, gpu_end_s)` CFTimeInterval
    /// timestamps from `MTLCommandBuffer`'s `GPUStartTime`/`GPUEndTime`
    /// properties. Both are mach-absolute CFTimeInterval seconds (double).
    ///
    /// Intended for `HF2Q_PROFILE_GPU_TS=1` per-bucket GPU wall-clock
    /// attribution. Adds exactly two ObjC property reads per call on top
    /// of the regular `commit_and_wait` — measured well under 1 μs on
    /// M5 Max.
    ///
    /// # Errors
    ///
    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
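    ///
    /// For example (a sketch; `enc` is a live `CommandEncoder`):
    ///
    /// ```ignore
    /// let (gpu_start_s, gpu_end_s) = enc.commit_wait_with_gpu_time()?;
    /// let gpu_ms = (gpu_end_s - gpu_start_s) * 1e3;
    /// eprintln!("cb gpu time: {gpu_ms:.3} ms");
    /// ```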
    pub fn commit_wait_with_gpu_time(&mut self) -> Result<(f64, f64)> {
        self.commit_and_wait()?;
        // SAFETY: cmd_buf is a valid MTLCommandBuffer that has been
        // committed and awaited. GPUStartTime / GPUEndTime return
        // CFTimeInterval (double precision seconds). See
        // https://developer.apple.com/documentation/metal/mtlcommandbuffer/1639925-gpustarttime
        let (gpu_start, gpu_end): (f64, f64) = unsafe {
            let cb = &*self.cmd_buf;
            let s: f64 = msg_send![cb, GPUStartTime];
            let e: f64 = msg_send![cb, GPUEndTime];
            (s, e)
        };
        Ok((gpu_start, gpu_end))
    }

    /// Commit the command buffer WITHOUT blocking.
    ///
    /// The GPU begins executing the encoded commands immediately. Call
    /// [`wait_until_completed`](Self::wait_until_completed) later to block
    /// the CPU and check for errors. This allows the CPU to continue doing
    /// other work (e.g. preparing the next batch) while the GPU runs.
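    ///
    /// A sketch of the double-buffered pattern (`prepare_next_batch()` is a
    /// hypothetical caller-side helper for the overlapping CPU work):
    ///
    /// ```ignore
    /// enc.commit();                     // GPU starts executing this batch
    /// let next = prepare_next_batch();  // CPU work overlaps GPU execution
    /// enc.wait_until_completed()?;      // block and surface GPU errors
    /// ```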
    pub fn commit(&mut self) {
        self.end_active_encoder();
        // ADR-015 iter8e (Phase 3b): same flush hook as commit_and_wait —
        // this is the async-pipeline path that production decode uses.
        self.flush_residency_pending();
        self.cmd_buf.commit();
    }

    /// Block until a previously committed command buffer completes.
    ///
    /// Must be called after [`commit`](Self::commit). Do not call after
    /// [`commit_and_wait`](Self::commit_and_wait) — that method already waits.
    ///
    /// # Errors
    ///
    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
    pub fn wait_until_completed(&self) -> Result<()> {
        self.cmd_buf.wait_until_completed();
        match self.cmd_buf.status() {
            MTLCommandBufferStatus::Completed => Ok(()),
            MTLCommandBufferStatus::Error => Err(MlxError::CommandBufferError(
                "GPU command buffer completed with error status".into(),
            )),
            status => Err(MlxError::CommandBufferError(format!(
                "Unexpected command buffer status after wait: {:?}",
                status
            ))),
        }
    }

    /// Borrow the underlying Metal command buffer.
    #[inline]
    pub fn metal_command_buffer(&self) -> &CommandBuffer {
        &self.cmd_buf
    }
}

impl Drop for CommandEncoder {
    fn drop(&mut self) {
        // End the persistent compute encoder before the command buffer
        // is dropped, otherwise Metal will assert:
        // "Command encoder released without endEncoding"
        self.end_active_encoder();
    }
}