mlx_native/encoder.rs
1//! [`CommandEncoder`] — batched GPU command submission.
2//!
3//! Wraps a Metal command buffer. Encode one or more compute kernel dispatches,
4//! then call [`commit_and_wait`](CommandEncoder::commit_and_wait) to submit the
5//! entire batch and block until the GPU finishes.
6//!
7//! # Persistent compute encoder
8//!
9//! A single Metal `ComputeCommandEncoder` is kept alive across multiple
10//! dispatches within the same command buffer. This avoids the overhead of
11//! creating and ending a new compute encoder per dispatch — the same pattern
12//! candle uses (`compute_per_buffer`). On a forward pass with ~800 dispatches
13//! this saves ~800 encoder create/end cycles.
14//!
15//! # Capture mode (Phase 4e.1)
16//!
17//! When `start_capture()` is called, subsequent dispatches are recorded into a
18//! `Vec<CapturedNode>` instead of being encoded into Metal. `memory_barrier()`
19//! records a barrier sentinel. Call `take_capture()` to extract the recorded
20//! graph for later replay via `ComputeGraph::encode_sequential()`.
21
22use std::sync::atomic::{AtomicU64, Ordering};
23
24use metal::{
25 CommandBuffer, CommandQueue, ComputeCommandEncoderRef, ComputePipelineState,
26 ComputePipelineStateRef, MTLCommandBufferStatus, MTLDispatchType, MTLSize,
27};
28#[allow(unused_imports)]
29use objc::{msg_send, sel, sel_impl};
30
31use crate::buffer::MlxBuffer;
32use crate::error::{MlxError, Result};
33use crate::mem_ranges::MemRanges;
34use crate::residency::ResidencySet;
35
36/// A buffer or inline-bytes binding for a compute kernel argument slot.
37pub enum KernelArg<'a> {
38 /// Bind an existing Metal buffer at the given index.
39 Buffer(&'a MlxBuffer),
40 /// Bind an existing Metal buffer at the given index with a byte offset.
41 BufferWithOffset(&'a MlxBuffer, u64),
42 /// Bind inline bytes (small constant data) at the given index.
43 /// The data must be `Pod` and is copied into the command encoder.
44 Bytes(&'a [u8]),
45}
46
47/// Convert a `Pod` value to a byte slice suitable for `KernelArg::Bytes`.
48///
49/// # Safety
50///
51/// The caller must ensure `T` has the same layout as the corresponding
52/// MSL struct in the shader (matching field order, sizes, and alignment).
53pub fn as_bytes<T: bytemuck::Pod>(val: &T) -> &[u8] {
54 bytemuck::bytes_of(val)
55}
56
57// ---------------------------------------------------------------------------
58// Capture-mode types (Phase 4e.1 — Graph IR)
59// ---------------------------------------------------------------------------
60
61/// A recorded kernel argument binding.
62///
63/// When the encoder is in capture mode, each `set_buffer` / `set_bytes` call
64/// is stored as a `RecordedBinding` instead of being applied to Metal.
65#[derive(Clone)]
66pub enum RecordedBinding {
67 /// A Metal buffer at the given offset.
68 Buffer {
69 metal_buffer: metal::Buffer,
70 offset: u64,
71 },
72 /// Inline bytes (small constant data, copied).
73 Bytes(Vec<u8>),
74}
75
76/// How to dispatch the recorded kernel.
77#[derive(Clone, Copy, Debug)]
78pub enum DispatchKind {
79 /// `dispatch_threads(grid_size, threadgroup_size)` — Metal picks threadgroup count.
80 Threads,
81 /// `dispatch_thread_groups(threadgroups, threadgroup_size)` — caller specifies threadgroup count.
82 ThreadGroups,
83}
84
85/// Operation kind tag for captured nodes, used by the fusion pass (4e.2).
86///
87/// When the encoder is in capture mode, each dispatch can be tagged with an
88/// `OpKind` so the fusion pass can identify fuseable sequences without
89/// inspecting pipeline names.
90#[derive(Clone, Copy, Debug, PartialEq, Eq)]
91pub enum CapturedOpKind {
92 /// RMS normalization (with learned scale).
93 RmsNorm,
94 /// Elementwise multiply.
95 ElemMul,
96 /// Elementwise add.
97 ElemAdd,
98 /// Scaled dot-product attention (NOT reorderable — breaks lookahead).
99 Sdpa,
100 /// Softmax (NOT reorderable — breaks lookahead).
101 Softmax,
102 /// Any other operation — treated as reorderable by the graph optimizer.
103 Other,
104}
105
106impl CapturedOpKind {
107 /// Whether this captured op kind is safe to reorder past in the graph
108 /// optimizer (Phase 4e.3).
109 ///
110 /// Mirrors the `h_safe` whitelist from llama.cpp's
111 /// `ggml_metal_graph_optimize_reorder`. Non-safe ops break the 64-node
112 /// lookahead — the reorder pass cannot look past them.
113 pub fn is_reorderable(&self) -> bool {
114 match self {
115 Self::Sdpa | Self::Softmax => false,
116 Self::RmsNorm | Self::ElemMul | Self::ElemAdd | Self::Other => true,
117 }
118 }
119}
120
121/// A memory range annotation: (start_address, end_address).
122///
123/// Represents a contiguous GPU buffer region for conflict detection in the
124/// reorder pass (Phase 4e.3). Addresses are CPU-visible `contents_ptr()`
125/// values, which on Apple Silicon unified memory equal the GPU addresses.
126pub type MemRange = (usize, usize);
127
128/// A single captured compute dispatch or barrier sentinel.
129///
130/// Created when the encoder is in capture mode. Replayed later by
131/// `ComputeGraph::encode_sequential()`.
132#[derive(Clone)]
133pub enum CapturedNode {
134 /// A compute dispatch to replay.
135 Dispatch {
136 /// Pipeline state object to bind.
137 pipeline: ComputePipelineState,
138 /// Kernel argument bindings: (slot_index, binding).
139 bindings: Vec<(u64, RecordedBinding)>,
140 /// Grid or threadgroup count (interpretation depends on `dispatch_kind`).
141 threads_per_grid: MTLSize,
142 /// Threads per threadgroup.
143 threads_per_threadgroup: MTLSize,
144 /// Optional threadgroup memory allocations: (index, byte_length).
145 threadgroup_memory: Vec<(u64, u64)>,
146 /// Whether this is a dispatch_threads or dispatch_thread_groups call.
147 dispatch_kind: DispatchKind,
148 /// Operation kind tag for the fusion pass (4e.2).
149 /// Defaults to `Other` if not explicitly set via `set_op_kind()`.
150 op_kind: CapturedOpKind,
151 /// Read buffer ranges for reorder conflict detection (4e.3).
152 /// Populated from `barrier_between` calls in capture mode.
153 reads: Vec<MemRange>,
154 /// Write buffer ranges for reorder conflict detection (4e.3).
155 /// Populated from `barrier_between` calls in capture mode.
156 writes: Vec<MemRange>,
157 },
158 /// A memory barrier sentinel — forces a barrier at replay time.
159 Barrier,
160}
161
162/// Convert a slice of buffer references into capture-mode
163/// [`MemRange`] tuples. Used by the [`CommandEncoder::dispatch_tracked*`]
164/// family in capture mode — equivalent to the conversion
165/// `GraphSession::barrier_between` does at `graph.rs:1452-1465`.
166///
167/// `(start, end)` uses `contents_ptr() + byte_offset` as the start
168/// and `contents_ptr() + byte_offset + slice_extent` as the end.
169fn ranges_from_buffers(bufs: &[&MlxBuffer]) -> Vec<MemRange> {
170 bufs.iter()
171 .map(|b| {
172 let base = b.contents_ptr() as usize + b.byte_offset() as usize;
173 let extent = (b.byte_len()).saturating_sub(b.byte_offset() as usize);
174 (base, base + extent)
175 })
176 .collect()
177}
178
179/// Apply a slice of `KernelArg` bindings to a compute encoder.
180///
181/// `KernelArg::Buffer(buf)` propagates the `MlxBuffer::byte_offset()` so
182/// `slice_view`-derived sub-buffers are honored automatically — the
183/// kernel sees memory starting at the slice's offset. This matches the
184/// documented contract of `slice_view` and the offset-handling in the
185/// other binding paths in this file (`encode`, `encode_threadgroups`,
186/// `encode_threadgroups_with_shared`, replay). Without it, every
187/// `slice_view`-derived buffer bound via `KernelArg::Buffer` silently
188/// exposes the entire underlying allocation — surfaced by hf2q's
189/// nomic-bert iter-79 cosine parity bisection (cosine 0.098 → 0.999962
190/// after fix).
191///
192/// `KernelArg::BufferWithOffset(buf, offset)` continues to use the
193/// explicit `offset` argument verbatim (callers asking for an explicit
194/// offset get exactly that, even on sliced buffers). The two API
195/// surfaces are intentional: implicit (sliced views auto-propagate) vs.
196/// explicit (caller-controlled).
197#[inline]
198fn apply_bindings(encoder: &ComputeCommandEncoderRef, bindings: &[(u64, KernelArg<'_>)]) {
199 for &(index, ref arg) in bindings {
200 match arg {
201 KernelArg::Buffer(buf) => {
202 encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
203 }
204 KernelArg::BufferWithOffset(buf, offset) => {
205 encoder.set_buffer(index, Some(buf.metal_buffer()), *offset);
206 }
207 KernelArg::Bytes(bytes) => {
208 encoder.set_bytes(index, bytes.len() as u64, bytes.as_ptr() as *const _);
209 }
210 }
211 }
212}
213
214/// Number of times `commit_and_wait()` has been called (CPU sync points).
215static SYNC_COUNT: AtomicU64 = AtomicU64::new(0);
216
217/// Number of times an encode method has been called (GPU dispatches).
218static DISPATCH_COUNT: AtomicU64 = AtomicU64::new(0);
219
220/// Number of `MTLCommandBuffer` instances created via `CommandEncoder::new`.
221/// Increments once per `device.command_encoder()` call. Used by hf2q's
222/// `HF2Q_DECODE_PROFILE` instrumentation to measure command-buffer
223/// overhead per decode token (ADR-012 §Optimize / Task #15 follow-up).
224static CMD_BUF_COUNT: AtomicU64 = AtomicU64::new(0);
225
226/// Number of `memory_barrier()` calls that reached the
227/// `objc::msg_send![encoder, memoryBarrierWithScope:]` site. Capture-mode
228/// no-ops and pre-encoder no-ops are excluded so the count reflects
229/// actual MTL barriers issued.
230///
231/// Always tracked — the increment is one atomic op, ~5 ns. ADR-015 H4
232/// (Wave 2b hard gate #2) requires per-barrier counter resolution to
233/// confirm-or-falsify the barrier-coalescing lever; xctrace TimeProfiler
234/// at 1 ms sampling cannot resolve `memory_barrier` even though it fires
235/// ~440×/token (`docs/ADR-015-mlx-native-single-cb-decode.md` §"P3a' live
236/// profile pass" hypothesis register row H4).
237static BARRIER_COUNT: AtomicU64 = AtomicU64::new(0);
238
239/// Total nanoseconds spent inside the `objc::msg_send!` barrier site,
240/// summed across all calls. ONLY updated when the env var
241/// `MLX_PROFILE_BARRIERS=1` is set on the process (cached on first
242/// `memory_barrier` call). When disabled the timing path is a single
243/// branch + the unconditional barrier dispatch — same hot-path cost as
244/// before this counter was added.
245///
246/// Why env-gated: timing adds 2 × `Instant::now()` (~50–100 ns each via
247/// `mach_absolute_time`) per barrier. At ~440 barriers/token that is
248/// ~22–44 µs/token of measurement overhead — comparable to what we are
249/// trying to measure. Production must keep this off; profiling runs
250/// opt-in.
251static BARRIER_NS: AtomicU64 = AtomicU64::new(0);
252
253/// Reset all counters to zero.
254pub fn reset_counters() {
255 SYNC_COUNT.store(0, Ordering::Relaxed);
256 DISPATCH_COUNT.store(0, Ordering::Relaxed);
257 CMD_BUF_COUNT.store(0, Ordering::Relaxed);
258 BARRIER_COUNT.store(0, Ordering::Relaxed);
259 BARRIER_NS.store(0, Ordering::Relaxed);
260 AUTO_BARRIER_COUNT.store(0, Ordering::Relaxed);
261 AUTO_BARRIER_CONCURRENT.store(0, Ordering::Relaxed);
262}
263
264/// Read the current value of `SYNC_COUNT`.
265///
266/// Each call to `commit_and_wait()` increments this counter.
267pub fn sync_count() -> u64 {
268 SYNC_COUNT.load(Ordering::Relaxed)
269}
270
271/// Read the current value of `DISPATCH_COUNT`.
272///
273/// Each call to `encode()`, `encode_threadgroups()`, or
274/// `encode_threadgroups_with_shared()` increments this counter.
275pub fn dispatch_count() -> u64 {
276 DISPATCH_COUNT.load(Ordering::Relaxed)
277}
278
279/// Read the current value of `CMD_BUF_COUNT`.
280///
281/// Each `CommandEncoder::new` (i.e. each `MlxDevice::command_encoder()`)
282/// increments this counter. Useful for diagnosing per-dispatch Metal
283/// command-buffer overhead in inner loops.
284pub fn cmd_buf_count() -> u64 {
285 CMD_BUF_COUNT.load(Ordering::Relaxed)
286}
287
288/// Read the current value of `BARRIER_COUNT`.
289///
290/// Each `memory_barrier()` call that reaches the underlying
291/// `objc::msg_send![encoder, memoryBarrierWithScope:]` site increments this
292/// counter. Capture-mode no-ops and pre-encoder no-ops are excluded.
293/// ADR-015 H4 hypothesis: ~440 barriers/token on the qwen35 decode hot
294/// path (verify against this counter).
295pub fn barrier_count() -> u64 {
296 BARRIER_COUNT.load(Ordering::Relaxed)
297}
298
299/// Read the total nanoseconds spent in the `memoryBarrierWithScope:`
300/// `objc::msg_send!` site. Only non-zero when `MLX_PROFILE_BARRIERS=1`
301/// was in the environment at the time of the first `memory_barrier()`
302/// call (the env check is cached on first use).
303///
304/// Combined with [`barrier_count`] this gives µs/barrier =
305/// `barrier_total_ns() / 1000 / barrier_count()`.
306pub fn barrier_total_ns() -> u64 {
307 BARRIER_NS.load(Ordering::Relaxed)
308}
309
310/// Whether barrier timing is enabled (env-gated, cached on first check).
311///
312/// Reading the env var via `std::env::var` is itself non-trivial; using
313/// `OnceLock` caches the decision so the per-barrier branch is a single
314/// atomic-load + compare.
315fn barrier_profile_enabled() -> bool {
316 use std::sync::OnceLock;
317 static FLAG: OnceLock<bool> = OnceLock::new();
318 *FLAG.get_or_init(|| {
319 std::env::var("MLX_PROFILE_BARRIERS")
320 .map(|v| v == "1")
321 .unwrap_or(false)
322 })
323}
324
325/// Whether `MLX_UNRETAINED_REFS=1` is set in the process environment.
326///
327/// ADR-015 iter13 — when true, `CommandEncoder::new_with_residency` opens
328/// each `MTLCommandBuffer` via
329/// [`CommandQueueRef::new_command_buffer_with_unretained_references`]
330/// instead of the default `commandBuffer`. llama.cpp's per-token decode
331/// CBs use this same call (`/opt/llama.cpp/ggml/src/ggml-metal/`
332/// `ggml-metal-context.m:512` `[queue commandBufferWithUnretainedReferences]`)
333/// and gain ~3-5% wall on M-series GPUs by skipping per-buffer-binding ARC
334/// retains on submit.
335///
336/// **Caller-side prerequisite.** Every Metal buffer bound to a dispatch
337/// must outlive the CB — see the docstring on
338/// [`CommandEncoder::new_with_residency`] for the full caller contract.
339/// In hf2q, the per-decode-token `MlxBufferPool` (`buffer_pool.rs`)
340/// already keeps ARC clones alive in its `in_use` list across the entire
341/// decode token; routing transient scratches through that pool is the
342/// canonical way to satisfy the contract.
343///
344/// Cached on first read via `OnceLock` to keep the per-CB-construction
345/// branch single-atomic-load fast. Default OFF so any production decode
346/// run that does NOT explicitly set the var preserves retained-refs
347/// behavior verbatim.
348fn unretained_refs_enabled() -> bool {
349 use std::sync::OnceLock;
350 static FLAG: OnceLock<bool> = OnceLock::new();
351 *FLAG.get_or_init(|| {
352 std::env::var("MLX_UNRETAINED_REFS")
353 .map(|v| v == "1")
354 .unwrap_or(false)
355 })
356}
357
358/// Whether `HF2Q_AUTO_BARRIER=1` is set in the process environment.
359///
360/// ADR-015 iter37 — when true, every [`CommandEncoder::dispatch_tracked`]
361/// call consults a [`MemRanges`](crate::mem_ranges::MemRanges) tracker
362/// and auto-emits a `memoryBarrierWithScope:` exactly when the new
363/// dispatch's read/write ranges conflict with previously-recorded
364/// ranges (mirrors llama.cpp's `ggml_metal_op_concurrency_check` at
365/// `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp:147-225`).
366/// When false, `dispatch_tracked` collapses to the same code path as
367/// `encode*` — no tracking, no auto-barriers — preserving sourdough
368/// behavior for any caller that opts into the tracked API but runs
369/// without the env gate.
370///
371/// Cached on first read via `OnceLock`. Default OFF — production
372/// decode/prefill keeps its hand-placed `enc.memory_barrier()` calls
373/// until the migration in iter38+.
374fn auto_barrier_enabled() -> bool {
375 use std::sync::OnceLock;
376 static FLAG: OnceLock<bool> = OnceLock::new();
377 *FLAG.get_or_init(|| {
378 std::env::var("HF2Q_AUTO_BARRIER")
379 .map(|v| v == "1")
380 .unwrap_or(false)
381 })
382}
383
384/// Number of `memory_barrier()` calls auto-emitted by
385/// [`CommandEncoder::dispatch_tracked`] under
386/// `HF2Q_AUTO_BARRIER=1`. Disjoint from [`BARRIER_COUNT`] —
387/// auto-barriers also bump `BARRIER_COUNT` since they go through
388/// `memory_barrier()`, so this counter measures only the
389/// auto-emitted subset.
390static AUTO_BARRIER_COUNT: AtomicU64 = AtomicU64::new(0);
391
392/// Number of `dispatch_tracked` calls whose mem-ranges check returned
393/// "concurrent" (no barrier needed). Together with
394/// [`AUTO_BARRIER_COUNT`] this measures the elision rate of the
395/// dataflow barrier: `concurrent / (concurrent + barriers)` is the
396/// fraction of dispatches that ran inside the previous concurrent
397/// group rather than starting a new one.
398static AUTO_BARRIER_CONCURRENT: AtomicU64 = AtomicU64::new(0);
399
400/// Read the cumulative number of auto-emitted barriers across all
401/// encoders since process start (or last [`reset_counters`]).
402pub fn auto_barrier_count() -> u64 {
403 AUTO_BARRIER_COUNT.load(Ordering::Relaxed)
404}
405
406/// Read the cumulative number of `dispatch_tracked` calls that did NOT
407/// emit a barrier (ran concurrent with the previous group).
408pub fn auto_barrier_concurrent_count() -> u64 {
409 AUTO_BARRIER_CONCURRENT.load(Ordering::Relaxed)
410}
411
412/// Issue the underlying Metal `memoryBarrierWithScope:` ObjC msg_send.
413///
414/// Held in its own `#[inline(never)]` function so xctrace / Instruments
415/// has a stable Rust frame to attribute barrier time against, separate
416/// from the surrounding encoder accounting. Per ADR-015 §P3a' Codex
417/// review Q2: TimeProfiler at 1 ms sampling cannot see this site when
418/// inlined; an explicit non-inline frame plus the [`BARRIER_NS`] counter
419/// closes the H4 hard gate.
420#[inline(never)]
421fn issue_metal_buffer_barrier(encoder: &ComputeCommandEncoderRef) {
422 // MTLBarrierScopeBuffers = 1 << 0 = 1.
423 const MTL_BARRIER_SCOPE_BUFFERS: u64 = 1;
424 unsafe {
425 let _: () =
426 objc::msg_send![encoder, memoryBarrierWithScope: MTL_BARRIER_SCOPE_BUFFERS];
427 }
428}
429
430/// A batched compute command encoder.
431///
432/// Keeps a single Metal `ComputeCommandEncoder` alive across multiple
433/// dispatches. The encoder is created on the first dispatch and ended
434/// only when the command buffer is committed. This mirrors candle's
435/// `compute_per_buffer` pattern and avoids per-dispatch encoder overhead.
436///
437/// # Typical usage
438///
439/// ```ignore
440/// let mut enc = device.command_encoder()?;
441/// // Multiple dispatches share the same compute encoder:
442/// enc.encode_threadgroups(pipeline1, &buffers1, tg1, tg_size1);
443/// enc.encode_threadgroups(pipeline2, &buffers2, tg2, tg_size2);
444/// enc.commit_and_wait()?;
445/// ```
446pub struct CommandEncoder {
447 cmd_buf: CommandBuffer,
448 // SAFETY marker: see unsafe Send impl below.
449 /// Raw pointer to the persistent compute encoder.
450 /// Non-null when a compute pass is active.
451 /// The encoder borrows from `cmd_buf` but we cannot express this
452 /// lifetime in safe Rust, so we use a raw pointer.
453 /// SAFETY: the pointer is valid as long as `cmd_buf` is alive and
454 /// `end_encoding()` has not been called on it.
455 active_encoder: *const ComputeCommandEncoderRef,
456 /// When `Some`, dispatches are recorded here instead of being encoded
457 /// into Metal. Set via `start_capture()`, extracted via `take_capture()`.
458 capture: Option<Vec<CapturedNode>>,
459 /// Op kind tag for the NEXT captured dispatch. Set via `set_op_kind()`,
460 /// consumed (reset to `Other`) when a dispatch is captured.
461 pending_op_kind: CapturedOpKind,
462 /// Pending read buffer ranges for the NEXT captured dispatch.
463 /// Set via `set_pending_buffer_ranges()`, consumed when the next dispatch
464 /// is captured. Used by the reorder pass (Phase 4e.3).
465 pending_reads: Vec<MemRange>,
466 /// Pending write buffer ranges for the NEXT captured dispatch.
467 pending_writes: Vec<MemRange>,
468 /// ADR-015 iter8e (Phase 3b): residency set whose pending add/remove
469 /// staging is flushed at every `commit*` boundary.
470 ///
471 /// Cloned from the device at `device.command_encoder()` time. `None`
472 /// when residency sets are disabled (HF2Q_NO_RESIDENCY=1, macOS<15,
473 /// or test-only `CommandEncoder::new` from a residency-less queue).
474 residency_set: Option<ResidencySet>,
475 /// ADR-015 iter37: dataflow barrier inference state.
476 ///
477 /// Populated only when `HF2Q_AUTO_BARRIER=1` is set at process
478 /// start (cached via [`auto_barrier_enabled`]). Each
479 /// [`Self::dispatch_tracked`] call consults this state to decide
480 /// whether a Metal memory barrier is required; on conflict the
481 /// barrier is emitted, the state is reset, and the new dispatch's
482 /// ranges seed the next concurrent group. When the env gate is
483 /// off, `dispatch_tracked` collapses to its untracked equivalent
484 /// and this field is left empty for the encoder's lifetime.
485 ///
486 /// The field is always present (zero-sized when empty) so the
487 /// gate-off branch is a single bool-load + early return rather
488 /// than an allocation/Option indirection.
489 mem_ranges: MemRanges,
490}
491
492/// SAFETY: CommandEncoder is safe to Send across threads provided that:
493/// 1. Only one thread accesses the encoder at a time (exclusive ownership).
494/// 2. The encoder is not used concurrently from multiple threads.
495///
496/// Metal command buffers and compute encoders are thread-safe for exclusive
497/// access (Apple documentation: "You can create command buffers, encode
498/// commands, and submit them from any thread"). The raw pointer
499/// `active_encoder` borrows from `cmd_buf` and is valid as long as
500/// `cmd_buf` is alive — this invariant holds across thread boundaries
501/// because both fields move together.
502///
503/// This matches llama.cpp's pattern of encoding command buffers on GCD
504/// worker threads via `dispatch_apply`, and is used for the dual-buffer
505/// pipeline where buf1 is encoded on a worker thread while buf0 executes.
506unsafe impl Send for CommandEncoder {}
507
508impl CommandEncoder {
509 /// Create a new command encoder from the given command queue.
510 ///
511 /// This immediately creates a Metal command buffer.
512 ///
513 /// # Why retained references
514 ///
515 /// We use the regular `commandBuffer` (Metal retains every bound
516 /// resource for the lifetime of the buffer) rather than
517 /// `commandBufferWithUnretainedReferences`. llama.cpp uses unretained
518 /// refs for an additional perf bump (~3-5% on M-series GPUs), but the
519 /// hf2q dispatch pattern allocates many transient scratch buffers
520 /// inside helper functions (`apply_proj` → `weight_bf16_owned`,
521 /// `apply_pre_norm` → `params`, etc.) that go out of scope at the
522 /// helper's return. With unretained refs the metal::Buffer's ARC
523 /// drops to zero, freeing the underlying GPU memory before the
524 /// dispatch executes. Verified 2026-04-26: switching to unretained
525 /// hits "Command buffer error: GPU command buffer completed with
526 /// error status" on the first MoE FFN dispatch.
527 ///
528 /// To enable unretained refs in the future, every helper that
529 /// allocates and dispatches must thread its scratch buffers up to a
530 /// caller scope that outlives the eventual commit, OR all such
531 /// scratch must come from the per-decode-token pool (which already
532 /// ARC-retains in its in_use list). Today the lm_head + router-
533 /// download paths are still unpooled.
534 #[allow(dead_code)]
535 pub(crate) fn new(queue: &CommandQueue) -> Result<Self> {
536 Self::new_with_residency(queue, None)
537 }
538
539 /// Create a new command encoder, optionally bound to a residency set so
540 /// `commit*` boundaries can flush deferred add/remove staging.
541 ///
542 /// ADR-015 iter8e (Phase 3b): the encoder's `commit_and_wait`,
543 /// `commit_and_wait_labeled`, `commit`, `commit_labeled`,
544 /// `commit_wait_with_gpu_time` all call
545 /// [`ResidencySet::flush_pending`](ResidencySet::flush_pending) before
546 /// submitting the Metal command buffer. This converts the
547 /// per-allocation `[set commit]` storm
548 /// (~880 commits/decode-token in iter8d/8e claude+codex variants) into
549 /// at most one commit per CB submission — mirrors llama.cpp's
550 /// `ggml-metal-device.m:1378-1382` pattern (batch addAllocation in
551 /// loop, commit ONCE).
552 ///
553 /// ADR-015 iter13: when the `MLX_UNRETAINED_REFS=1` env var is set at
554 /// process start, this constructor uses
555 /// [`CommandQueueRef::new_command_buffer_with_unretained_references`]
556 /// instead of `new_command_buffer`. llama.cpp's per-token decode CBs
557 /// use `commandBufferWithUnretainedReferences` (see
558 /// `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m:512`) which
559 /// skips Metal's per-buffer-binding ARC-retain on submit and saves
560 /// ~3-5% on M-series GPUs (per the docstring above).
561 ///
562 /// **Caller contract under unretained refs.** Every Metal buffer bound
563 /// to a dispatch in this CB MUST outlive the CB's GPU completion. In
564 /// the hf2q decode path, that means every transient scratch must be
565 /// either (a) backed by the per-decode-token arena pool
566 /// (`MlxBufferPool` keeps an ARC clone in `in_use` until the next
567 /// `reset` — see `buffer_pool.rs:60`) or (b) hoisted to a caller scope
568 /// that lives across the terminal `commit_and_wait_labeled`. Helpers
569 /// in `apply_proj` / `apply_pre_norm` / lm_head cast / router-download
570 /// that allocated transients via `device.alloc_buffer` and dropped
571 /// them at function return MUST be lifted to `pooled_alloc_buffer`
572 /// before `MLX_UNRETAINED_REFS=1` is enabled, or the first MoE FFN
573 /// dispatch will crash with "Command buffer error: GPU command buffer
574 /// completed with error status" (verified 2026-04-26).
575 ///
576 /// The default (`MLX_UNRETAINED_REFS` unset) preserves retained-refs
577 /// behavior verbatim — this is the sourdough-safe path.
578 pub(crate) fn new_with_residency(
579 queue: &CommandQueue,
580 residency_set: Option<ResidencySet>,
581 ) -> Result<Self> {
582 let cmd_buf = if unretained_refs_enabled() {
583 queue.new_command_buffer_with_unretained_references().to_owned()
584 } else {
585 queue.new_command_buffer().to_owned()
586 };
587 CMD_BUF_COUNT.fetch_add(1, Ordering::Relaxed);
588 Ok(Self {
589 cmd_buf,
590 active_encoder: std::ptr::null(),
591 capture: None,
592 pending_op_kind: CapturedOpKind::Other,
593 pending_reads: Vec::new(),
594 pending_writes: Vec::new(),
595 residency_set,
596 mem_ranges: MemRanges::new(),
597 })
598 }
599
600 /// Enable capture mode.
601 ///
602 /// All subsequent dispatch and barrier calls will be recorded into a
603 /// `Vec<CapturedNode>` instead of being encoded into Metal.
604 /// Call `take_capture()` to extract the recorded nodes.
605 pub fn start_capture(&mut self) {
606 self.capture = Some(Vec::with_capacity(128));
607 }
608
609 /// Whether the encoder is currently in capture mode.
610 pub fn is_capturing(&self) -> bool {
611 self.capture.is_some()
612 }
613
614 /// Extract the captured nodes, ending capture mode.
615 ///
616 /// Returns `None` if capture mode was not active.
617 pub fn take_capture(&mut self) -> Option<Vec<CapturedNode>> {
618 self.capture.take()
619 }
620
621 /// Tag the NEXT captured dispatch with the given operation kind.
622 ///
623 /// The tag is consumed (reset to `Other`) after the next dispatch is
624 /// captured. Only meaningful in capture mode — has no effect on
625 /// direct-dispatch encoding.
626 ///
627 /// Used by op dispatch functions to annotate captures for the fusion
628 /// pass (Phase 4e.2).
629 pub fn set_op_kind(&mut self, kind: CapturedOpKind) {
630 self.pending_op_kind = kind;
631 }
632
633 /// Consume and return the pending op kind, resetting it to `Other`.
634 fn take_pending_op_kind(&mut self) -> CapturedOpKind {
635 let kind = self.pending_op_kind;
636 self.pending_op_kind = CapturedOpKind::Other;
637 kind
638 }
639
640 /// Stash buffer range annotations for the NEXT captured dispatch.
641 ///
642 /// Called by `GraphSession::barrier_between()` in capture mode to record
643 /// which buffers the next dispatch reads from and writes to. The ranges
644 /// are consumed by the next `encode_*` call and attached to the captured
645 /// `CapturedNode::Dispatch`.
646 ///
647 /// Only meaningful in capture mode — has no effect on direct-dispatch.
648 pub fn set_pending_buffer_ranges(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
649 self.pending_reads = reads;
650 self.pending_writes = writes;
651 }
652
653 /// Patch the last captured dispatch node's empty reads/writes with the
654 /// given ranges. No-op if not capturing, or if the last node isn't a
655 /// Dispatch, or if its ranges are already populated.
656 ///
657 /// Used by `GraphSession::track_dispatch` in recording mode to annotate
658 /// dispatches that were called without a preceding `barrier_between`.
659 pub fn annotate_last_dispatch_if_missing(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
660 if let Some(ref mut nodes) = self.capture {
661 if let Some(CapturedNode::Dispatch { reads: r, writes: w, .. }) = nodes.last_mut() {
662 if r.is_empty() && !reads.is_empty() {
663 *r = reads;
664 }
665 if w.is_empty() && !writes.is_empty() {
666 *w = writes;
667 }
668 }
669 }
670 }
671
672 /// Consume and return the pending buffer range annotations.
673 fn take_pending_buffer_ranges(&mut self) -> (Vec<MemRange>, Vec<MemRange>) {
674 let reads = std::mem::take(&mut self.pending_reads);
675 let writes = std::mem::take(&mut self.pending_writes);
676 (reads, writes)
677 }
678
679 /// Record buffer bindings into `RecordedBinding` form.
680 fn record_buffer_bindings(buffers: &[(u64, &MlxBuffer)]) -> Vec<(u64, RecordedBinding)> {
681 buffers
682 .iter()
683 .map(|&(index, buf)| {
684 (
685 index,
686 RecordedBinding::Buffer {
687 metal_buffer: buf.metal_buffer().clone(),
688 offset: buf.byte_offset(),
689 },
690 )
691 })
692 .collect()
693 }
694
695 /// Record `KernelArg` bindings into `RecordedBinding` form.
696 ///
697 /// `KernelArg::Buffer(buf)` records `buf.byte_offset()` so capture →
698 /// replay round-trips of `slice_view`-derived buffers preserve their
699 /// offsets, matching `record_buffer_bindings`'s behavior at line 382.
700 fn record_arg_bindings(bindings: &[(u64, KernelArg<'_>)]) -> Vec<(u64, RecordedBinding)> {
701 bindings
702 .iter()
703 .map(|(index, arg)| {
704 let recorded = match arg {
705 KernelArg::Buffer(buf) => RecordedBinding::Buffer {
706 metal_buffer: buf.metal_buffer().clone(),
707 offset: buf.byte_offset(),
708 },
709 KernelArg::BufferWithOffset(buf, offset) => RecordedBinding::Buffer {
710 metal_buffer: buf.metal_buffer().clone(),
711 offset: *offset,
712 },
713 KernelArg::Bytes(bytes) => RecordedBinding::Bytes(bytes.to_vec()),
714 };
715 (*index, recorded)
716 })
717 .collect()
718 }
719
720 /// Get or create the persistent compute encoder.
721 ///
722 /// On the first call, creates a new compute encoder from the command
723 /// buffer. On subsequent calls, returns the existing one.
724 ///
725 /// SAFETY: The returned reference borrows from `self.cmd_buf` which is
726 /// alive for the lifetime of this `CommandEncoder`. The raw pointer is
727 /// valid until `end_active_encoder()` is called.
728 #[inline]
729 fn get_or_create_encoder(&mut self) -> &ComputeCommandEncoderRef {
730 if self.active_encoder.is_null() {
731 // Use MTLDispatchTypeConcurrent to allow independent dispatches
732 // to overlap on the GPU. Memory barriers are inserted between
733 // dependent dispatches via `memory_barrier()`.
734 //
735 // ADR-015 iter61a-2 probe: HF2Q_FORCE_SERIAL_DISPATCH=1 falls back
736 // to MTLDispatchType::Serial — every dispatch waits for the
737 // previous to complete, eliminating concurrent-dispatch race
738 // windows. Used to falsify Hypothesis (g): missing memory_barrier
739 // calls between dependent dispatches cause cold-run logit
740 // non-determinism via thread-race on a shared buffer.
741 let dispatch_type = if std::env::var("HF2Q_FORCE_SERIAL_DISPATCH")
742 .map(|v| v == "1")
743 .unwrap_or(false)
744 {
745 MTLDispatchType::Serial
746 } else {
747 MTLDispatchType::Concurrent
748 };
749 let encoder = self
750 .cmd_buf
751 .compute_command_encoder_with_dispatch_type(dispatch_type);
752 self.active_encoder = encoder as *const ComputeCommandEncoderRef;
753 }
754 // SAFETY: active_encoder is non-null and points to a valid encoder
755 // owned by cmd_buf.
756 unsafe { &*self.active_encoder }
757 }
758
759 /// End the active compute encoder if one exists.
760 #[inline]
761 fn end_active_encoder(&mut self) {
762 if !self.active_encoder.is_null() {
763 // SAFETY: the pointer was obtained from cmd_buf.new_compute_command_encoder()
764 // and has not been ended yet.
765 unsafe { &*self.active_encoder }.end_encoding();
766 self.active_encoder = std::ptr::null();
767 }
768 }
769
770 /// Insert a memory barrier with scope `MTLBarrierScopeBuffers`.
771 ///
772 /// When the encoder uses `MTLDispatchTypeConcurrent`, all dispatches can
773 /// execute concurrently unless separated by a barrier. Call this between
774 /// dispatches where the later dispatch reads a buffer written by an
775 /// earlier one.
776 ///
777 /// This is the same pattern llama.cpp uses:
778 /// `[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]`
779 #[allow(unexpected_cfgs)]
780 pub fn memory_barrier(&mut self) {
781 if let Some(ref mut nodes) = self.capture {
782 nodes.push(CapturedNode::Barrier);
783 return;
784 }
785 if self.active_encoder.is_null() {
786 return;
787 }
788 BARRIER_COUNT.fetch_add(1, Ordering::Relaxed);
789 // SAFETY: active_encoder is non-null and valid.
790 let encoder = unsafe { &*self.active_encoder };
791 if barrier_profile_enabled() {
792 // mach_absolute_time path — only on when MLX_PROFILE_BARRIERS=1.
793 let start = std::time::Instant::now();
794 issue_metal_buffer_barrier(encoder);
795 let elapsed_ns = start.elapsed().as_nanos() as u64;
796 BARRIER_NS.fetch_add(elapsed_ns, Ordering::Relaxed);
797 } else {
798 issue_metal_buffer_barrier(encoder);
799 }
800 }
801
802 /// Set the compute pipeline state for subsequent dispatches.
803 ///
804 /// This begins a new compute pass if one is not already active.
805 pub fn set_pipeline(&mut self, pipeline: &ComputePipelineStateRef) {
806 let encoder = self.get_or_create_encoder();
807 encoder.set_compute_pipeline_state(pipeline);
808 }
809
810 /// Bind a buffer to a compute kernel argument slot.
811 ///
812 /// The `index` corresponds to the `[[buffer(N)]]` attribute in the MSL shader.
813 pub fn set_buffer(&self, index: u64, buffer: &MlxBuffer) {
814 let _ = (index, buffer);
815 }
816
817 /// Dispatch threads on the GPU.
818 pub fn dispatch_threads(&self, grid_size: MTLSize, threadgroup_size: MTLSize) {
819 let _ = (grid_size, threadgroup_size);
820 }
821
822 /// Encode a complete compute pass: set pipeline, bind buffers, dispatch.
823 ///
824 /// Reuses the persistent compute encoder — no per-dispatch encoder
825 /// creation overhead.
826 ///
827 /// # Arguments
828 ///
829 /// * `pipeline` — The compiled compute pipeline to execute.
830 /// * `buffers` — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
831 /// * `grid_size` — Total number of threads to launch.
832 /// * `threadgroup_size` — Threads per threadgroup.
833 pub fn encode(
834 &mut self,
835 pipeline: &ComputePipelineStateRef,
836 buffers: &[(u64, &MlxBuffer)],
837 grid_size: MTLSize,
838 threadgroup_size: MTLSize,
839 ) {
840 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
841 let op_kind = self.take_pending_op_kind();
842 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
843 if let Some(ref mut nodes) = self.capture {
844 nodes.push(CapturedNode::Dispatch {
845 pipeline: pipeline.to_owned(),
846 bindings: Self::record_buffer_bindings(buffers),
847 threads_per_grid: grid_size,
848 threads_per_threadgroup: threadgroup_size,
849 threadgroup_memory: Vec::new(),
850 dispatch_kind: DispatchKind::Threads,
851 op_kind,
852 reads: pending_reads,
853 writes: pending_writes,
854 });
855 return;
856 }
857 let encoder = self.get_or_create_encoder();
858 encoder.set_compute_pipeline_state(pipeline);
859 for &(index, buf) in buffers {
860 encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
861 }
862 encoder.dispatch_threads(grid_size, threadgroup_size);
863 }
864
865 /// Encode a compute pass using threadgroups instead of raw thread counts.
866 ///
867 /// Reuses the persistent compute encoder — no per-dispatch encoder
868 /// creation overhead.
869 pub fn encode_threadgroups(
870 &mut self,
871 pipeline: &ComputePipelineStateRef,
872 buffers: &[(u64, &MlxBuffer)],
873 threadgroups: MTLSize,
874 threadgroup_size: MTLSize,
875 ) {
876 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
877 let op_kind = self.take_pending_op_kind();
878 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
879 if let Some(ref mut nodes) = self.capture {
880 nodes.push(CapturedNode::Dispatch {
881 pipeline: pipeline.to_owned(),
882 bindings: Self::record_buffer_bindings(buffers),
883 threads_per_grid: threadgroups,
884 threads_per_threadgroup: threadgroup_size,
885 threadgroup_memory: Vec::new(),
886 dispatch_kind: DispatchKind::ThreadGroups,
887 op_kind,
888 reads: pending_reads,
889 writes: pending_writes,
890 });
891 return;
892 }
893 let encoder = self.get_or_create_encoder();
894 encoder.set_compute_pipeline_state(pipeline);
895 for &(index, buf) in buffers {
896 encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
897 }
898 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
899 }
900
901 /// Encode a compute pass using threadgroups with shared threadgroup memory.
902 ///
903 /// Like [`encode_threadgroups`](Self::encode_threadgroups), but additionally
904 /// allocates threadgroup memory at the specified indices. This is required
905 /// for kernels that use `threadgroup` memory (e.g. reductions in rms_norm
906 /// and softmax).
907 ///
908 /// # Arguments
909 ///
910 /// * `pipeline` — The compiled compute pipeline to execute.
911 /// * `buffers` — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
912 /// * `threadgroup_mem` — Slice of `(index, byte_length)` pairs for threadgroup memory.
913 /// * `threadgroups` — Number of threadgroups to dispatch.
914 /// * `threadgroup_size` — Threads per threadgroup.
915 pub fn encode_threadgroups_with_shared(
916 &mut self,
917 pipeline: &ComputePipelineStateRef,
918 buffers: &[(u64, &MlxBuffer)],
919 threadgroup_mem: &[(u64, u64)],
920 threadgroups: MTLSize,
921 threadgroup_size: MTLSize,
922 ) {
923 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
924 let op_kind = self.take_pending_op_kind();
925 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
926 if let Some(ref mut nodes) = self.capture {
927 nodes.push(CapturedNode::Dispatch {
928 pipeline: pipeline.to_owned(),
929 bindings: Self::record_buffer_bindings(buffers),
930 threads_per_grid: threadgroups,
931 threads_per_threadgroup: threadgroup_size,
932 threadgroup_memory: threadgroup_mem.to_vec(),
933 dispatch_kind: DispatchKind::ThreadGroups,
934 op_kind,
935 reads: pending_reads,
936 writes: pending_writes,
937 });
938 return;
939 }
940 let encoder = self.get_or_create_encoder();
941 encoder.set_compute_pipeline_state(pipeline);
942 for &(index, buf) in buffers {
943 encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
944 }
945 for &(index, byte_length) in threadgroup_mem {
946 encoder.set_threadgroup_memory_length(index, byte_length);
947 }
948 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
949 }
950
951 /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_threads).
952 ///
953 /// Reuses the persistent compute encoder.
954 pub fn encode_with_args(
955 &mut self,
956 pipeline: &ComputePipelineStateRef,
957 bindings: &[(u64, KernelArg<'_>)],
958 grid_size: MTLSize,
959 threadgroup_size: MTLSize,
960 ) {
961 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
962 let op_kind = self.take_pending_op_kind();
963 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
964 if let Some(ref mut nodes) = self.capture {
965 nodes.push(CapturedNode::Dispatch {
966 pipeline: pipeline.to_owned(),
967 bindings: Self::record_arg_bindings(bindings),
968 threads_per_grid: grid_size,
969 threads_per_threadgroup: threadgroup_size,
970 threadgroup_memory: Vec::new(),
971 dispatch_kind: DispatchKind::Threads,
972 op_kind,
973 reads: pending_reads,
974 writes: pending_writes,
975 });
976 return;
977 }
978 let encoder = self.get_or_create_encoder();
979 encoder.set_compute_pipeline_state(pipeline);
980 apply_bindings(encoder, bindings);
981 encoder.dispatch_threads(grid_size, threadgroup_size);
982 }
983
984 /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_thread_groups).
985 ///
986 /// Reuses the persistent compute encoder.
987 pub fn encode_threadgroups_with_args(
988 &mut self,
989 pipeline: &ComputePipelineStateRef,
990 bindings: &[(u64, KernelArg<'_>)],
991 threadgroups: MTLSize,
992 threadgroup_size: MTLSize,
993 ) {
994 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
995 let op_kind = self.take_pending_op_kind();
996 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
997 if let Some(ref mut nodes) = self.capture {
998 nodes.push(CapturedNode::Dispatch {
999 pipeline: pipeline.to_owned(),
1000 bindings: Self::record_arg_bindings(bindings),
1001 threads_per_grid: threadgroups,
1002 threads_per_threadgroup: threadgroup_size,
1003 threadgroup_memory: Vec::new(),
1004 dispatch_kind: DispatchKind::ThreadGroups,
1005 op_kind,
1006 reads: pending_reads,
1007 writes: pending_writes,
1008 });
1009 return;
1010 }
1011 let encoder = self.get_or_create_encoder();
1012 encoder.set_compute_pipeline_state(pipeline);
1013 apply_bindings(encoder, bindings);
1014 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
1015 }
1016
1017 /// Encode a dispatch with mixed buffer/bytes bindings and shared memory.
1018 ///
1019 /// Reuses the persistent compute encoder.
1020 pub fn encode_threadgroups_with_args_and_shared(
1021 &mut self,
1022 pipeline: &ComputePipelineStateRef,
1023 bindings: &[(u64, KernelArg<'_>)],
1024 threadgroup_mem: &[(u64, u64)],
1025 threadgroups: MTLSize,
1026 threadgroup_size: MTLSize,
1027 ) {
1028 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1029 let op_kind = self.take_pending_op_kind();
1030 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1031 if let Some(ref mut nodes) = self.capture {
1032 nodes.push(CapturedNode::Dispatch {
1033 pipeline: pipeline.to_owned(),
1034 bindings: Self::record_arg_bindings(bindings),
1035 threads_per_grid: threadgroups,
1036 threads_per_threadgroup: threadgroup_size,
1037 threadgroup_memory: threadgroup_mem.to_vec(),
1038 dispatch_kind: DispatchKind::ThreadGroups,
1039 op_kind,
1040 reads: pending_reads,
1041 writes: pending_writes,
1042 });
1043 return;
1044 }
1045 let encoder = self.get_or_create_encoder();
1046 encoder.set_compute_pipeline_state(pipeline);
1047 apply_bindings(encoder, bindings);
1048 for &(index, byte_length) in threadgroup_mem {
1049 encoder.set_threadgroup_memory_length(index, byte_length);
1050 }
1051 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
1052 }
1053
1054 // -----------------------------------------------------------------
1055 // ADR-015 iter37 — dataflow-driven auto-barrier dispatch family.
1056 //
1057 // These mirrors of `encode_threadgroups*_with_args*` take explicit
1058 // `reads: &[&MlxBuffer]` and `writes: &[&MlxBuffer]` slices. When
1059 // the process started with `HF2Q_AUTO_BARRIER=1`, the encoder's
1060 // [`MemRanges`] tracker checks the new ranges against the
1061 // cumulative state since the last barrier; on conflict it emits
1062 // `memory_barrier()` and resets the state before recording the
1063 // new ranges. When the env gate is unset, the check is skipped
1064 // entirely and the dispatch is applied identically to the
1065 // matching `encode_*` method — sourdough-safe by construction.
1066 //
1067 // Capture mode: the `reads`/`writes` ranges are recorded onto the
1068 // captured node via the existing `pending_reads`/`pending_writes`
1069 // mechanism, so a `dispatch_tracked` call inside capture mode is
1070 // equivalent to `set_pending_buffer_ranges + encode_*`.
1071 //
1072 // No production callsite migrates in iter37 — this is the API
1073 // surface the qwen35 forward path will adopt incrementally in
1074 // iter38+. Today, every call to `dispatch_tracked` from a
1075 // production code path lives behind an explicit caller decision
1076 // to opt in.
1077 // -----------------------------------------------------------------
1078
1079 /// Auto-barrier-aware dispatch with [`KernelArg`] bindings (uses
1080 /// `dispatch_thread_groups`).
1081 ///
1082 /// Behaves identically to
1083 /// [`encode_threadgroups_with_args`](Self::encode_threadgroups_with_args)
1084 /// when `HF2Q_AUTO_BARRIER` is unset. When set, consults the
1085 /// per-encoder [`MemRanges`] tracker:
1086 ///
1087 /// * Conflict (RAW/WAR/WAW on a same-buffer range) → emit
1088 /// `memory_barrier()`, increment [`AUTO_BARRIER_COUNT`], reset
1089 /// the tracker, then dispatch and seed the new concurrent group
1090 /// with this dispatch's ranges.
1091 /// * No conflict → increment [`AUTO_BARRIER_CONCURRENT`], record
1092 /// the ranges into the cumulative state, dispatch.
1093 pub fn dispatch_tracked_threadgroups_with_args(
1094 &mut self,
1095 pipeline: &ComputePipelineStateRef,
1096 bindings: &[(u64, KernelArg<'_>)],
1097 reads: &[&MlxBuffer],
1098 writes: &[&MlxBuffer],
1099 threadgroups: MTLSize,
1100 threadgroup_size: MTLSize,
1101 ) {
1102 // Capture mode: stash ranges + delegate to the standard encode.
1103 // The ranges flow through `pending_reads`/`pending_writes` and
1104 // attach to the captured `Dispatch` node — identical to what
1105 // `GraphSession::barrier_between` already does in capture mode.
1106 if self.is_capturing() {
1107 let read_ranges = ranges_from_buffers(reads);
1108 let write_ranges = ranges_from_buffers(writes);
1109 self.set_pending_buffer_ranges(read_ranges, write_ranges);
1110 self.encode_threadgroups_with_args(pipeline, bindings, threadgroups, threadgroup_size);
1111 return;
1112 }
1113
1114 if auto_barrier_enabled() {
1115 self.maybe_auto_barrier(reads, writes);
1116 }
1117
1118 self.encode_threadgroups_with_args(pipeline, bindings, threadgroups, threadgroup_size);
1119 }
1120
1121 /// Auto-barrier-aware dispatch with [`KernelArg`] bindings + shared
1122 /// threadgroup memory.
1123 ///
1124 /// See [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args)
1125 /// for the behavioral contract; this variant additionally takes a
1126 /// `threadgroup_mem` slice that is forwarded to
1127 /// [`encode_threadgroups_with_args_and_shared`](Self::encode_threadgroups_with_args_and_shared).
1128 ///
1129 /// The 8-argument signature mirrors the existing
1130 /// `encode_threadgroups_with_args_and_shared` plus the two
1131 /// dataflow slices; `clippy::too_many_arguments` is allowed
1132 /// because each parameter is load-bearing for either the dispatch
1133 /// (pipeline/bindings/threadgroups/threadgroup_size/shared_mem)
1134 /// or the auto-barrier (reads/writes).
1135 #[allow(clippy::too_many_arguments)]
1136 pub fn dispatch_tracked_threadgroups_with_args_and_shared(
1137 &mut self,
1138 pipeline: &ComputePipelineStateRef,
1139 bindings: &[(u64, KernelArg<'_>)],
1140 threadgroup_mem: &[(u64, u64)],
1141 reads: &[&MlxBuffer],
1142 writes: &[&MlxBuffer],
1143 threadgroups: MTLSize,
1144 threadgroup_size: MTLSize,
1145 ) {
1146 if self.is_capturing() {
1147 let read_ranges = ranges_from_buffers(reads);
1148 let write_ranges = ranges_from_buffers(writes);
1149 self.set_pending_buffer_ranges(read_ranges, write_ranges);
1150 self.encode_threadgroups_with_args_and_shared(
1151 pipeline,
1152 bindings,
1153 threadgroup_mem,
1154 threadgroups,
1155 threadgroup_size,
1156 );
1157 return;
1158 }
1159
1160 if auto_barrier_enabled() {
1161 self.maybe_auto_barrier(reads, writes);
1162 }
1163
1164 self.encode_threadgroups_with_args_and_shared(
1165 pipeline,
1166 bindings,
1167 threadgroup_mem,
1168 threadgroups,
1169 threadgroup_size,
1170 );
1171 }
1172
1173 /// Auto-barrier-aware dispatch using `(slot, &MlxBuffer)` bindings
1174 /// (uses `dispatch_thread_groups`).
1175 ///
1176 /// Convenience wrapper for callers that don't need
1177 /// [`KernelArg::Bytes`] inline-byte arguments. See
1178 /// [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args)
1179 /// for behavioral contract.
1180 pub fn dispatch_tracked_threadgroups(
1181 &mut self,
1182 pipeline: &ComputePipelineStateRef,
1183 buffers: &[(u64, &MlxBuffer)],
1184 reads: &[&MlxBuffer],
1185 writes: &[&MlxBuffer],
1186 threadgroups: MTLSize,
1187 threadgroup_size: MTLSize,
1188 ) {
1189 if self.is_capturing() {
1190 let read_ranges = ranges_from_buffers(reads);
1191 let write_ranges = ranges_from_buffers(writes);
1192 self.set_pending_buffer_ranges(read_ranges, write_ranges);
1193 self.encode_threadgroups(pipeline, buffers, threadgroups, threadgroup_size);
1194 return;
1195 }
1196
1197 if auto_barrier_enabled() {
1198 self.maybe_auto_barrier(reads, writes);
1199 }
1200
1201 self.encode_threadgroups(pipeline, buffers, threadgroups, threadgroup_size);
1202 }
1203
1204 /// Auto-barrier-aware dispatch using `(slot, &MlxBuffer)` bindings
1205 /// **plus shared threadgroup memory** (uses `dispatch_thread_groups`).
1206 ///
1207 /// Mirrors [`encode_threadgroups_with_shared`](Self::encode_threadgroups_with_shared)
1208 /// — convenience variant for kernels that allocate threadgroup
1209 /// memory (reductions in `rms_norm`, `softmax`, etc.) but don't
1210 /// need [`KernelArg::Bytes`] inline-byte arguments. See
1211 /// [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args)
1212 /// for the behavioral contract; the only addition here is the
1213 /// `threadgroup_mem` slice forwarded to the underlying encode.
1214 ///
1215 /// Closes the iter38-audit coverage gap: the 5 `rms_norm.rs`
1216 /// callsites (`/opt/mlx-native/src/ops/rms_norm.rs:124,236,443,
1217 /// 516,589`) all use `encode_threadgroups_with_shared` and need
1218 /// dataflow tracking when migrated to auto-barrier in iter40+.
1219 ///
1220 /// 7-argument signature; `clippy::too_many_arguments` is allowed
1221 /// because each parameter is load-bearing for either the dispatch
1222 /// (pipeline/buffers/threadgroups/threadgroup_size/shared_mem) or
1223 /// the auto-barrier (reads/writes).
1224 #[allow(clippy::too_many_arguments)]
1225 pub fn dispatch_tracked_threadgroups_with_shared(
1226 &mut self,
1227 pipeline: &ComputePipelineStateRef,
1228 buffers: &[(u64, &MlxBuffer)],
1229 threadgroup_mem: &[(u64, u64)],
1230 reads: &[&MlxBuffer],
1231 writes: &[&MlxBuffer],
1232 threadgroups: MTLSize,
1233 threadgroup_size: MTLSize,
1234 ) {
1235 if self.is_capturing() {
1236 let read_ranges = ranges_from_buffers(reads);
1237 let write_ranges = ranges_from_buffers(writes);
1238 self.set_pending_buffer_ranges(read_ranges, write_ranges);
1239 self.encode_threadgroups_with_shared(
1240 pipeline,
1241 buffers,
1242 threadgroup_mem,
1243 threadgroups,
1244 threadgroup_size,
1245 );
1246 return;
1247 }
1248
1249 if auto_barrier_enabled() {
1250 self.maybe_auto_barrier(reads, writes);
1251 }
1252
1253 self.encode_threadgroups_with_shared(
1254 pipeline,
1255 buffers,
1256 threadgroup_mem,
1257 threadgroups,
1258 threadgroup_size,
1259 );
1260 }
1261
1262 /// Auto-barrier-aware `dispatch_threads` variant with
1263 /// [`KernelArg`] bindings.
1264 ///
1265 /// Mirrors [`encode_with_args`](Self::encode_with_args) — the
1266 /// `dispatch_threads` (per-thread grid) flavor, as opposed to the
1267 /// `dispatch_thread_groups` flavor of
1268 /// [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args).
1269 /// See that method for the behavioral contract.
1270 ///
1271 /// Closes the iter38-audit coverage gap: callers that use
1272 /// per-thread grids — `rope.rs:108` (IMROPE), `sigmoid_mul.rs:76`
1273 /// (sigmoid-mul), and `encode_helpers.rs:41` (kv_cache_copy) —
1274 /// need a `dispatch_threads` flavor of the tracked dispatch
1275 /// because their grid sizes are expressed in threads, not
1276 /// threadgroups.
1277 ///
1278 /// Note: the simpler `(slot, &MlxBuffer)` form (from
1279 /// [`encode`](Self::encode)) is a special case of this method —
1280 /// callers can wrap each binding as `KernelArg::Buffer(buf)` to
1281 /// reuse this single tracked variant rather than introducing a
1282 /// fifth one.
1283 pub fn dispatch_tracked_threads_with_args(
1284 &mut self,
1285 pipeline: &ComputePipelineStateRef,
1286 bindings: &[(u64, KernelArg<'_>)],
1287 reads: &[&MlxBuffer],
1288 writes: &[&MlxBuffer],
1289 grid_size: MTLSize,
1290 threadgroup_size: MTLSize,
1291 ) {
1292 if self.is_capturing() {
1293 let read_ranges = ranges_from_buffers(reads);
1294 let write_ranges = ranges_from_buffers(writes);
1295 self.set_pending_buffer_ranges(read_ranges, write_ranges);
1296 self.encode_with_args(pipeline, bindings, grid_size, threadgroup_size);
1297 return;
1298 }
1299
1300 if auto_barrier_enabled() {
1301 self.maybe_auto_barrier(reads, writes);
1302 }
1303
1304 self.encode_with_args(pipeline, bindings, grid_size, threadgroup_size);
1305 }
1306
1307 /// Run the dataflow check, emit a barrier on conflict, and record
1308 /// the dispatch's ranges into the cumulative state.
1309 ///
1310 /// Always called *before* the underlying `encode_*` method
1311 /// applies the dispatch. Mirrors lines 220-225 of
1312 /// `ggml-metal-ops.cpp` (`concurrency_check + concurrency_reset +
1313 /// concurrency_add` around each node).
1314 fn maybe_auto_barrier(
1315 &mut self,
1316 reads: &[&MlxBuffer],
1317 writes: &[&MlxBuffer],
1318 ) {
1319 if self.mem_ranges.check_dispatch(reads, writes) {
1320 // Concurrent — no barrier needed; just record the new ranges.
1321 self.mem_ranges.add_dispatch(reads, writes);
1322 AUTO_BARRIER_CONCURRENT.fetch_add(1, Ordering::Relaxed);
1323 } else {
1324 // Conflict — emit barrier, reset state, seed new group.
1325 //
1326 // `memory_barrier()` itself increments `BARRIER_COUNT` and,
1327 // when `MLX_PROFILE_BARRIERS=1`, accumulates `BARRIER_NS`.
1328 // We additionally bump `AUTO_BARRIER_COUNT` so the
1329 // "auto-emitted vs hand-placed" subset is queryable.
1330 self.memory_barrier();
1331 self.mem_ranges.reset();
1332 self.mem_ranges.add_dispatch(reads, writes);
1333 AUTO_BARRIER_COUNT.fetch_add(1, Ordering::Relaxed);
1334 }
1335 }
1336
1337 /// Force a barrier and reset the auto-barrier tracker.
1338 ///
1339 /// Use at boundaries where the caller knows a barrier is required
1340 /// regardless of dataflow — typically before reading data back to
1341 /// CPU, or at the end of an op group whose internal dependencies
1342 /// the tracker can't see (e.g. host-driven memcpy).
1343 ///
1344 /// Equivalent to `memory_barrier()` plus a `MemRanges::reset()`
1345 /// when `HF2Q_AUTO_BARRIER=1`; equivalent to plain
1346 /// `memory_barrier()` otherwise.
1347 pub fn force_barrier_and_reset_tracker(&mut self) {
1348 self.memory_barrier();
1349 if auto_barrier_enabled() {
1350 self.mem_ranges.reset();
1351 }
1352 }
1353
1354 /// Diagnostic accessor — number of ranges currently recorded in
1355 /// this encoder's [`MemRanges`] tracker. Always zero unless
1356 /// `HF2Q_AUTO_BARRIER=1` and at least one `dispatch_tracked` call
1357 /// has fired since the last conflict.
1358 #[inline]
1359 pub fn mem_ranges_len(&self) -> usize {
1360 self.mem_ranges.len()
1361 }
1362
1363 /// Replay a single captured dispatch node into this encoder.
1364 ///
1365 /// This is the inverse of capture: it takes a previously recorded
1366 /// `CapturedNode::Dispatch` and encodes it into the live Metal encoder.
1367 /// Barrier nodes are handled by the caller (ComputeGraph::encode_sequential).
1368 ///
1369 /// Does NOT increment `DISPATCH_COUNT` — that was already counted at
1370 /// capture time.
1371 pub fn replay_dispatch(
1372 &mut self,
1373 pipeline: &ComputePipelineStateRef,
1374 bindings: &[(u64, RecordedBinding)],
1375 threadgroup_memory: &[(u64, u64)],
1376 threads_per_grid: MTLSize,
1377 threads_per_threadgroup: MTLSize,
1378 dispatch_kind: DispatchKind,
1379 ) {
1380 let encoder = self.get_or_create_encoder();
1381 encoder.set_compute_pipeline_state(pipeline);
1382 for (index, binding) in bindings {
1383 match binding {
1384 RecordedBinding::Buffer { metal_buffer, offset } => {
1385 encoder.set_buffer(*index, Some(metal_buffer), *offset);
1386 }
1387 RecordedBinding::Bytes(bytes) => {
1388 encoder.set_bytes(
1389 *index,
1390 bytes.len() as u64,
1391 bytes.as_ptr() as *const _,
1392 );
1393 }
1394 }
1395 }
1396 for &(index, byte_length) in threadgroup_memory {
1397 encoder.set_threadgroup_memory_length(index, byte_length);
1398 }
1399 match dispatch_kind {
1400 DispatchKind::Threads => {
1401 encoder.dispatch_threads(threads_per_grid, threads_per_threadgroup);
1402 }
1403 DispatchKind::ThreadGroups => {
1404 encoder.dispatch_thread_groups(threads_per_grid, threads_per_threadgroup);
1405 }
1406 }
1407 }
1408
1409 /// Flush any pending residency-set add/remove staging.
1410 ///
1411 /// Hooked at every commit boundary so per-allocation
1412 /// [`ResidencySet::add_allocation`](ResidencySet::add_allocation) and
1413 /// [`ResidencySet::remove_allocation`](ResidencySet::remove_allocation)
1414 /// calls (as fired by `MlxDevice::alloc_buffer` and
1415 /// `MlxBufferStorage::Drop`) collapse into at most ONE `[set commit]`
1416 /// per CB submission. Mirrors llama.cpp's
1417 /// `ggml-metal-device.m:1378-1382` (batch addAllocation in loop,
1418 /// commit ONCE).
1419 #[inline]
1420 fn flush_residency_pending(&self) {
1421 if let Some(set) = self.residency_set.as_ref() {
1422 set.flush_pending();
1423 }
1424 }
1425
1426 /// Commit the command buffer and block until the GPU finishes execution.
1427 ///
1428 /// # Errors
1429 ///
1430 /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
1431 pub fn commit_and_wait(&mut self) -> Result<()> {
1432 SYNC_COUNT.fetch_add(1, Ordering::Relaxed);
1433
1434 // End the persistent compute encoder before committing.
1435 self.end_active_encoder();
1436
1437 // ADR-015 iter8e (Phase 3b): flush deferred residency-set
1438 // add/remove staging so the residency hint covers any buffers
1439 // referenced by this CB. Single commit per CB boundary; no-op
1440 // when no residency set or no staged changes.
1441 self.flush_residency_pending();
1442
1443 self.cmd_buf.commit();
1444 self.cmd_buf.wait_until_completed();
1445
1446 match self.cmd_buf.status() {
1447 MTLCommandBufferStatus::Completed => Ok(()),
1448 MTLCommandBufferStatus::Error => {
1449 Err(MlxError::CommandBufferError(
1450 "GPU command buffer completed with error status".into(),
1451 ))
1452 }
1453 status => Err(MlxError::CommandBufferError(format!(
1454 "Unexpected command buffer status after wait: {:?}",
1455 status
1456 ))),
1457 }
1458 }
1459
1460 /// Commit + wait, accumulating GPU wall-clock time under `label` into
1461 /// the [`crate::kernel_profile`] global table when `MLX_PROFILE_CB=1`
1462 /// is set. When the env var is unset, this is identical to
1463 /// [`commit_and_wait`](Self::commit_and_wait) — zero overhead.
1464 ///
1465 /// Used by hf2q's decode hot path to attribute per-cb GPU time to
1466 /// labeled phases (per-layer attn, per-layer ffn, output_head, etc.)
1467 /// without manually wiring `commit_wait_with_gpu_time` everywhere.
1468 ///
1469 /// # Errors
1470 ///
1471 /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
1472 pub fn commit_and_wait_labeled(&mut self, label: &str) -> Result<()> {
1473 // ADR-015 iter16 — propagate `label` to MTLCommandBuffer.setLabel and
1474 // (if a compute encoder is active) MTLComputeCommandEncoder.setLabel
1475 // BEFORE end_encoding/commit so xctrace's
1476 // `metal-application-encoders-list` table populates `cmdbuffer-label`
1477 // and `encoder-label` columns with the semantic phase name (e.g.
1478 // `layer.attn_moe_ffn`, `output_head.fused_norm_lm_argmax`,
1479 // `layer.delta_net.ops1-9`). Joined to per-CB GPU duration via
1480 // `metal-gpu-submission-to-command-buffer-id` (sub_id ↔ encoder_id) →
1481 // `metal-gpu-execution-points` (per-dispatch start/end), this enables
1482 // per-phase µs/token attribution comparing hf2q vs llama side-by-side
1483 // (iter15 §E "iter16 ATTRIBUTION PATH"). Cost is a single ObjC
1484 // msg_send per CB submission — sub-µs on M5 Max — and a no-op when
1485 // xctrace isn't recording, so this is unconditionally safe to call on
1486 // the production decode hot path.
1487 self.apply_labels(label);
1488 if crate::kernel_profile::is_enabled() {
1489 let (start_s, end_s) = self.commit_wait_with_gpu_time()?;
1490 let ns = ((end_s - start_s).max(0.0) * 1_000_000_000.0) as u64;
1491 crate::kernel_profile::record(label, ns);
1492 Ok(())
1493 } else {
1494 self.commit_and_wait()
1495 }
1496 }
1497
1498 /// Async commit, but with profiling label. When `MLX_PROFILE_CB=1`
1499 /// is set, redirects to a synchronous [`commit_and_wait_labeled`]
1500 /// call to capture per-cb GPU time (this defeats async pipelining
1501 /// while profiling, which is the whole point — profile-mode is slow
1502 /// but informative). When unset, identical to [`commit`](Self::commit).
1503 pub fn commit_labeled(&mut self, label: &str) {
1504 // ADR-015 iter16 — see `commit_and_wait_labeled` for rationale.
1505 if crate::kernel_profile::is_enabled() {
1506 // Profile mode: force sync to capture GPU time. apply_labels is
1507 // called inside commit_and_wait_labeled — do NOT call it twice
1508 // here (would double the ObjC msg_send under MLX_PROFILE_CB=1).
1509 // Errors are logged via stderr because the void return matches
1510 // commit().
1511 if let Err(e) = self.commit_and_wait_labeled(label) {
1512 eprintln!("[mlx-native] commit_labeled({}) failed: {}", label, e);
1513 }
1514 } else {
1515 // Async path: apply labels here so xctrace MST traces capture
1516 // per-CB phase attribution under default decode (no
1517 // `MLX_PROFILE_CB`).
1518 self.apply_labels(label);
1519 self.commit();
1520 }
1521 }
1522
1523 /// Apply `label` to the underlying `MTLCommandBuffer` and, if a compute
1524 /// encoder is currently active, to the `MTLComputeCommandEncoder`.
1525 ///
1526 /// Called from [`commit_labeled`] and [`commit_and_wait_labeled`] BEFORE
1527 /// the encoder is ended / the CB is committed so xctrace's
1528 /// `metal-application-encoders-list` table picks up the label on the
1529 /// row emitted at the encoder's `endEncoding` / CB submission boundary.
1530 /// Single ObjC `msg_send` per call (two if an encoder is active); sub-µs
1531 /// on M5 Max; no-op when xctrace isn't recording.
1532 ///
1533 /// Skipped (debug-only assert) if `label` is empty — empty labels would
1534 /// produce an indistinguishable trace row from the metal-rs default
1535 /// `Command Buffer 0` placeholder.
1536 #[inline]
1537 fn apply_labels(&self, label: &str) {
1538 debug_assert!(!label.is_empty(), "commit_*_labeled called with empty label");
1539 if label.is_empty() {
1540 return;
1541 }
1542 self.cmd_buf.set_label(label);
1543 if !self.active_encoder.is_null() {
1544 // SAFETY: active_encoder is non-null and points to a live encoder
1545 // owned by cmd_buf — same invariant as get_or_create_encoder /
1546 // memory_barrier. set_label is a single property write on the
1547 // ObjC object; safe before endEncoding.
1548 unsafe { &*self.active_encoder }.set_label(label);
1549 }
1550 }
1551
1552 /// Commit + wait, returning `(gpu_start_s, gpu_end_s)` CFTimeInterval
1553 /// timestamps from `MTLCommandBuffer`'s `GPUStartTime`/`GPUEndTime`
1554 /// properties. Both are mach-absolute CFTimeInterval seconds (double).
1555 ///
1556 /// Intended for `HF2Q_PROFILE_GPU_TS=1` per-bucket GPU wall-clock
1557 /// attribution. Adds exactly two ObjC property reads per call on top
1558 /// of the regular `commit_and_wait` — measured well under 1 μs on
1559 /// M5 Max.
1560 ///
1561 /// # Errors
1562 ///
1563 /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
1564 pub fn commit_wait_with_gpu_time(&mut self) -> Result<(f64, f64)> {
1565 self.commit_and_wait()?;
1566 // SAFETY: cmd_buf is a valid MTLCommandBuffer that has been
1567 // committed and awaited. GPUStartTime / GPUEndTime return
1568 // CFTimeInterval (double precision seconds). See
1569 // https://developer.apple.com/documentation/metal/mtlcommandbuffer/1639925-gpustarttime
1570 let (gpu_start, gpu_end): (f64, f64) = unsafe {
1571 let cb = &*self.cmd_buf;
1572 let s: f64 = msg_send![cb, GPUStartTime];
1573 let e: f64 = msg_send![cb, GPUEndTime];
1574 (s, e)
1575 };
1576 Ok((gpu_start, gpu_end))
1577 }
1578
    /// Commit the command buffer WITHOUT blocking.
    ///
    /// The GPU begins executing the encoded commands immediately. Call
    /// [`wait_until_completed`](Self::wait_until_completed) later to block
    /// the CPU and check for errors. This allows the CPU to continue doing
    /// other work (e.g. preparing the next batch) while the GPU runs.
    ///
    /// The steps run in a fixed order: end the persistent compute encoder,
    /// flush pending residency updates, then submit the command buffer. No
    /// error is reported here; GPU-side failures surface from
    /// [`wait_until_completed`](Self::wait_until_completed).
    pub fn commit(&mut self) {
        // Close the persistent compute encoder first; Metal expects every
        // encoder on a command buffer to be ended before it is committed.
        self.end_active_encoder();
        // ADR-015 iter8e (Phase 3b): same flush hook as commit_and_wait —
        // this is the async-pipeline path that production decode uses.
        self.flush_residency_pending();
        self.cmd_buf.commit();
    }
1592
1593 /// Block until a previously committed command buffer completes.
1594 ///
1595 /// Must be called after [`commit`](Self::commit). Do not call after
1596 /// [`commit_and_wait`](Self::commit_and_wait) — that method already waits.
1597 ///
1598 /// # Errors
1599 ///
1600 /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
1601 pub fn wait_until_completed(&self) -> Result<()> {
1602 self.cmd_buf.wait_until_completed();
1603 match self.cmd_buf.status() {
1604 MTLCommandBufferStatus::Completed => Ok(()),
1605 MTLCommandBufferStatus::Error => Err(MlxError::CommandBufferError(
1606 "GPU command buffer completed with error status".into(),
1607 )),
1608 status => Err(MlxError::CommandBufferError(format!(
1609 "Unexpected command buffer status after wait: {:?}",
1610 status
1611 ))),
1612 }
1613 }
1614
1615 /// Borrow the underlying Metal command buffer.
1616 #[inline]
1617 pub fn metal_command_buffer(&self) -> &CommandBuffer {
1618 &self.cmd_buf
1619 }
1620}
1621
1622impl Drop for CommandEncoder {
1623 fn drop(&mut self) {
1624 // End the persistent compute encoder before the command buffer
1625 // is dropped, otherwise Metal will assert:
1626 // "Command encoder released without endEncoding"
1627 self.end_active_encoder();
1628 }
1629}