Skip to main content

mlx_native/
encoder.rs

1//! [`CommandEncoder`] — batched GPU command submission.
2//!
3//! Wraps a Metal command buffer.  Encode one or more compute kernel dispatches,
4//! then call [`commit_and_wait`](CommandEncoder::commit_and_wait) to submit the
5//! entire batch and block until the GPU finishes.
6//!
7//! # Persistent compute encoder
8//!
9//! A single Metal `ComputeCommandEncoder` is kept alive across multiple
10//! dispatches within the same command buffer.  This avoids the overhead of
11//! creating and ending a new compute encoder per dispatch — the same pattern
12//! candle uses (`compute_per_buffer`).  On a forward pass with ~800 dispatches
13//! this saves ~800 encoder create/end cycles.
14//!
15//! # Capture mode (Phase 4e.1)
16//!
17//! When `start_capture()` is called, subsequent dispatches are recorded into a
18//! `Vec<CapturedNode>` instead of being encoded into Metal.  `memory_barrier()`
19//! records a barrier sentinel.  Call `take_capture()` to extract the recorded
20//! graph for later replay via `ComputeGraph::encode_sequential()`.
21
22use std::sync::atomic::{AtomicU64, Ordering};
23
24use metal::{
25    CommandBuffer, CommandQueue, ComputeCommandEncoderRef, ComputePipelineState,
26    ComputePipelineStateRef, MTLCommandBufferStatus, MTLDispatchType, MTLSize,
27};
28#[allow(unused_imports)]
29use objc::{msg_send, sel, sel_impl};
30
31use crate::buffer::MlxBuffer;
32use crate::error::{MlxError, Result};
33use crate::mem_ranges::MemRanges;
34use crate::residency::ResidencySet;
35
/// A buffer or inline-bytes binding for a compute kernel argument slot.
///
/// The slot index is not stored in the value itself: bindings travel as
/// `(u64, KernelArg)` pairs whose first element is the argument index.
pub enum KernelArg<'a> {
    /// Bind an existing Metal buffer at the given index, starting at the
    /// buffer's own `byte_offset()`.
    Buffer(&'a MlxBuffer),
    /// Bind an existing Metal buffer at the given index with an explicit
    /// byte offset.  The explicit offset is used verbatim; the buffer's own
    /// `byte_offset()` is not consulted.
    BufferWithOffset(&'a MlxBuffer, u64),
    /// Bind inline bytes (small constant data) at the given index.
    /// The data must be `Pod` (see [`as_bytes`]) and is copied: handed to
    /// `set_bytes` on direct encoding, or into a `Vec<u8>` in capture mode.
    Bytes(&'a [u8]),
}
46
/// Convert a `Pod` value to a byte slice suitable for `KernelArg::Bytes`.
///
/// Thin wrapper over `bytemuck::bytes_of`; the returned slice borrows `val`.
///
/// # Layout contract
///
/// This function is safe to call, but the bytes are only meaningful to the
/// kernel if the caller ensures `T` has the same layout as the corresponding
/// MSL struct in the shader (matching field order, sizes, and alignment).
pub fn as_bytes<T: bytemuck::Pod>(val: &T) -> &[u8] {
    bytemuck::bytes_of(val)
}
56
57// ---------------------------------------------------------------------------
58// Capture-mode types (Phase 4e.1 — Graph IR)
59// ---------------------------------------------------------------------------
60
/// A recorded kernel argument binding.
///
/// When the encoder is in capture mode, each `set_buffer` / `set_bytes` call
/// is stored as a `RecordedBinding` instead of being applied to Metal.
/// Unlike [`KernelArg`], this type owns its data, so it carries no borrow.
#[derive(Clone)]
pub enum RecordedBinding {
    /// A Metal buffer at the given offset.
    Buffer {
        /// Owned handle to the Metal buffer (cloned from the source
        /// `MlxBuffer` when the binding is recorded).
        metal_buffer: metal::Buffer,
        /// Byte offset into `metal_buffer` at which the kernel argument starts.
        offset: u64,
    },
    /// Inline bytes (small constant data, copied).
    Bytes(Vec<u8>),
}
75
/// How to dispatch the recorded kernel.
///
/// Stored on every captured dispatch so replay can pick the matching Metal
/// dispatch call.  Derives `PartialEq`/`Eq` (like [`CapturedOpKind`]) so the
/// kind can be compared with `==` and used in assertions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DispatchKind {
    /// `dispatch_threads(grid_size, threadgroup_size)` — Metal picks threadgroup count.
    Threads,
    /// `dispatch_thread_groups(threadgroups, threadgroup_size)` — caller specifies threadgroup count.
    ThreadGroups,
}
84
/// Coarse operation tag attached to captured dispatches, consumed by the
/// fusion pass (4e.2) and the reorder pass (4e.3).
///
/// In capture mode each dispatch can carry a `CapturedOpKind`, letting later
/// graph passes recognise fuseable sequences without looking at pipeline
/// names.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CapturedOpKind {
    /// RMS normalization (with learned scale).
    RmsNorm,
    /// Elementwise multiply.
    ElemMul,
    /// Elementwise add.
    ElemAdd,
    /// Scaled dot-product attention (NOT reorderable — breaks lookahead).
    Sdpa,
    /// Softmax (NOT reorderable — breaks lookahead).
    Softmax,
    /// Any other operation — treated as reorderable by the graph optimizer.
    Other,
}

impl CapturedOpKind {
    /// Returns `true` when the graph optimizer (Phase 4e.3) may reorder
    /// other nodes past a dispatch of this kind.
    ///
    /// Mirrors the `h_safe` whitelist from llama.cpp's
    /// `ggml_metal_graph_optimize_reorder`.  Kinds that return `false`
    /// (`Sdpa`, `Softmax`) break the 64-node lookahead: the reorder pass
    /// cannot look past them.
    pub fn is_reorderable(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        let breaks_lookahead = match self {
            Self::Sdpa | Self::Softmax => true,
            Self::RmsNorm | Self::ElemMul | Self::ElemAdd | Self::Other => false,
        };
        !breaks_lookahead
    }
}
120
/// A memory range annotation: (start_address, end_address).
///
/// Represents a contiguous GPU buffer region for conflict detection in the
/// reorder pass (Phase 4e.3).  Addresses are CPU-visible `contents_ptr()`
/// values, which on Apple Silicon unified memory equal the GPU addresses.
///
/// As produced by `ranges_from_buffers`, the end address is `start + extent`,
/// i.e. one past the last byte of the region.
pub type MemRange = (usize, usize);
127
/// A single captured compute dispatch or barrier sentinel.
///
/// Created when the encoder is in capture mode.  Replayed later by
/// `ComputeGraph::encode_sequential()`.
#[derive(Clone)]
pub enum CapturedNode {
    /// A compute dispatch to replay.
    Dispatch {
        /// Pipeline state object to bind.
        pipeline: ComputePipelineState,
        /// Kernel argument bindings: (slot_index, binding).
        bindings: Vec<(u64, RecordedBinding)>,
        /// Grid or threadgroup count (interpretation depends on `dispatch_kind`).
        threads_per_grid: MTLSize,
        /// Threads per threadgroup.
        threads_per_threadgroup: MTLSize,
        /// Optional threadgroup memory allocations: (index, byte_length).
        threadgroup_memory: Vec<(u64, u64)>,
        /// Whether this is a dispatch_threads or dispatch_thread_groups call.
        dispatch_kind: DispatchKind,
        /// Operation kind tag for the fusion pass (4e.2).
        /// Defaults to `Other` if not explicitly set via `set_op_kind()`.
        op_kind: CapturedOpKind,
        /// Read buffer ranges for reorder conflict detection (4e.3).
        /// Populated from `barrier_between` calls in capture mode; if still
        /// empty, may be patched by `annotate_last_dispatch_if_missing`.
        reads: Vec<MemRange>,
        /// Write buffer ranges for reorder conflict detection (4e.3).
        /// Populated from `barrier_between` calls in capture mode; if still
        /// empty, may be patched by `annotate_last_dispatch_if_missing`.
        writes: Vec<MemRange>,
    },
    /// A memory barrier sentinel — forces a barrier at replay time.
    Barrier,
}
161
162/// Convert a slice of buffer references into capture-mode
163/// [`MemRange`] tuples.  Used by the [`CommandEncoder::dispatch_tracked*`]
164/// family in capture mode — equivalent to the conversion
165/// `GraphSession::barrier_between` does at `graph.rs:1452-1465`.
166///
167/// `(start, end)` uses `contents_ptr() + byte_offset` as the start
168/// and `contents_ptr() + byte_offset + slice_extent` as the end.
169fn ranges_from_buffers(bufs: &[&MlxBuffer]) -> Vec<MemRange> {
170    bufs.iter()
171        .map(|b| {
172            let base = b.contents_ptr() as usize + b.byte_offset() as usize;
173            let extent = (b.byte_len()).saturating_sub(b.byte_offset() as usize);
174            (base, base + extent)
175        })
176        .collect()
177}
178
179/// Apply a slice of `KernelArg` bindings to a compute encoder.
180///
181/// `KernelArg::Buffer(buf)` propagates the `MlxBuffer::byte_offset()` so
182/// `slice_view`-derived sub-buffers are honored automatically — the
183/// kernel sees memory starting at the slice's offset. This matches the
184/// documented contract of `slice_view` and the offset-handling in the
185/// other binding paths in this file (`encode`, `encode_threadgroups`,
186/// `encode_threadgroups_with_shared`, replay). Without it, every
187/// `slice_view`-derived buffer bound via `KernelArg::Buffer` silently
188/// exposes the entire underlying allocation — surfaced by hf2q's
189/// nomic-bert iter-79 cosine parity bisection (cosine 0.098 → 0.999962
190/// after fix).
191///
192/// `KernelArg::BufferWithOffset(buf, offset)` continues to use the
193/// explicit `offset` argument verbatim (callers asking for an explicit
194/// offset get exactly that, even on sliced buffers). The two API
195/// surfaces are intentional: implicit (sliced views auto-propagate) vs.
196/// explicit (caller-controlled).
197#[inline]
198fn apply_bindings(encoder: &ComputeCommandEncoderRef, bindings: &[(u64, KernelArg<'_>)]) {
199    for &(index, ref arg) in bindings {
200        match arg {
201            KernelArg::Buffer(buf) => {
202                encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
203            }
204            KernelArg::BufferWithOffset(buf, offset) => {
205                encoder.set_buffer(index, Some(buf.metal_buffer()), *offset);
206            }
207            KernelArg::Bytes(bytes) => {
208                encoder.set_bytes(index, bytes.len() as u64, bytes.as_ptr() as *const _);
209            }
210        }
211    }
212}
213
/// Number of times `commit_and_wait()` has been called (CPU sync points).
static SYNC_COUNT: AtomicU64 = AtomicU64::new(0);

/// Number of times an encode method has been called (GPU dispatches).
static DISPATCH_COUNT: AtomicU64 = AtomicU64::new(0);

/// Number of `MTLCommandBuffer` instances created by
/// `CommandEncoder::new_with_residency` (which `CommandEncoder::new`
/// delegates to).  Increments once per `device.command_encoder()` call.
/// Used by hf2q's `HF2Q_DECODE_PROFILE` instrumentation to measure
/// command-buffer overhead per decode token (ADR-012 §Optimize / Task #15
/// follow-up).
static CMD_BUF_COUNT: AtomicU64 = AtomicU64::new(0);

/// Number of `memory_barrier()` calls that reached the
/// `objc::msg_send![encoder, memoryBarrierWithScope:]` site.  Capture-mode
/// no-ops and pre-encoder no-ops are excluded so the count reflects
/// actual MTL barriers issued.
///
/// Always tracked — the increment is one atomic op, ~5 ns.  ADR-015 H4
/// (Wave 2b hard gate #2) requires per-barrier counter resolution to
/// confirm-or-falsify the barrier-coalescing lever; xctrace TimeProfiler
/// at 1 ms sampling cannot resolve `memory_barrier` even though it fires
/// ~440×/token (`docs/ADR-015-mlx-native-single-cb-decode.md` §"P3a' live
/// profile pass" hypothesis register row H4).
static BARRIER_COUNT: AtomicU64 = AtomicU64::new(0);

/// Total nanoseconds spent inside the `objc::msg_send!` barrier site,
/// summed across all calls.  ONLY updated when the env var
/// `MLX_PROFILE_BARRIERS=1` is set on the process (cached on first
/// `memory_barrier` call).  When disabled the timing path is a single
/// branch + the unconditional barrier dispatch — same hot-path cost as
/// before this counter was added.
///
/// Why env-gated: timing adds 2 × `Instant::now()` (~50–100 ns each via
/// `mach_absolute_time`) per barrier, i.e. ~100–200 ns per barrier.  At
/// ~440 barriers/token that is ~44–88 µs/token of measurement overhead —
/// comparable to what we are trying to measure.  Production must keep
/// this off; profiling runs opt-in.
static BARRIER_NS: AtomicU64 = AtomicU64::new(0);
252
253/// Reset all counters to zero.
254pub fn reset_counters() {
255    SYNC_COUNT.store(0, Ordering::Relaxed);
256    DISPATCH_COUNT.store(0, Ordering::Relaxed);
257    CMD_BUF_COUNT.store(0, Ordering::Relaxed);
258    BARRIER_COUNT.store(0, Ordering::Relaxed);
259    BARRIER_NS.store(0, Ordering::Relaxed);
260    AUTO_BARRIER_COUNT.store(0, Ordering::Relaxed);
261    AUTO_BARRIER_CONCURRENT.store(0, Ordering::Relaxed);
262}
263
/// Read the current value of `SYNC_COUNT`.
///
/// Each call to `commit_and_wait()` increments this counter.  The value is
/// process-wide and is zeroed by [`reset_counters`].
pub fn sync_count() -> u64 {
    SYNC_COUNT.load(Ordering::Relaxed)
}
270
/// Read the current value of `DISPATCH_COUNT`.
///
/// Each call to `encode()`, `encode_threadgroups()`, or
/// `encode_threadgroups_with_shared()` increments this counter.  The value
/// is process-wide and is zeroed by [`reset_counters`].
pub fn dispatch_count() -> u64 {
    DISPATCH_COUNT.load(Ordering::Relaxed)
}
278
/// Read the current value of `CMD_BUF_COUNT`.
///
/// Each encoder construction (`CommandEncoder::new` /
/// `CommandEncoder::new_with_residency`, i.e. each
/// `MlxDevice::command_encoder()`) increments this counter.  Useful for
/// diagnosing per-dispatch Metal command-buffer overhead in inner loops.
pub fn cmd_buf_count() -> u64 {
    CMD_BUF_COUNT.load(Ordering::Relaxed)
}
287
/// Read the current value of `BARRIER_COUNT`.
///
/// Each `memory_barrier()` call that reaches the underlying
/// `objc::msg_send![encoder, memoryBarrierWithScope:]` site increments this
/// counter.  Capture-mode no-ops and pre-encoder no-ops are excluded.
/// ADR-015 H4 hypothesis: ~440 barriers/token on the qwen35 decode hot
/// path (verify against this counter).  The value is process-wide and is
/// zeroed by [`reset_counters`].
pub fn barrier_count() -> u64 {
    BARRIER_COUNT.load(Ordering::Relaxed)
}
298
/// Read the total nanoseconds spent in the `memoryBarrierWithScope:`
/// `objc::msg_send!` site.  Only non-zero when `MLX_PROFILE_BARRIERS=1`
/// was in the environment at the time of the first `memory_barrier()`
/// call (the env check is cached on first use).
///
/// Combined with [`barrier_count`] this gives µs/barrier =
/// `barrier_total_ns() / 1000 / barrier_count()`.  Callers must guard
/// against `barrier_count() == 0` before dividing, and should convert to
/// floating point first if sub-microsecond resolution matters.
pub fn barrier_total_ns() -> u64 {
    BARRIER_NS.load(Ordering::Relaxed)
}
309
/// Reports whether barrier timing is switched on for this process.
///
/// The answer is `true` only when the `MLX_PROFILE_BARRIERS` environment
/// variable is exactly `"1"`; unset, non-Unicode, or any other value means
/// `false`.  The variable is read once and the result cached in a
/// `OnceLock`, so changing the variable after the first call has no effect.
fn barrier_profile_enabled() -> bool {
    use std::sync::OnceLock;
    static ENABLED: OnceLock<bool> = OnceLock::new();
    *ENABLED.get_or_init(|| {
        matches!(std::env::var("MLX_PROFILE_BARRIERS").as_deref(), Ok("1"))
    })
}
324
/// Reports whether `MLX_UNRETAINED_REFS=1` is set in the process environment.
///
/// ADR-015 iter13 — when this returns `true`,
/// `CommandEncoder::new_with_residency` opens each `MTLCommandBuffer` via
/// `CommandQueueRef::new_command_buffer_with_unretained_references` instead
/// of the default `commandBuffer`.  llama.cpp's per-token decode CBs make
/// the same call (`ggml-metal-context.m`,
/// `[queue commandBufferWithUnretainedReferences]`) and reportedly gain
/// ~3-5% wall time on M-series GPUs by skipping per-buffer-binding ARC
/// retains on submit.
///
/// **Caller-side prerequisite.**  Every Metal buffer bound to a dispatch
/// must outlive the CB — the full contract is spelled out on
/// `CommandEncoder::new_with_residency`.  In hf2q the per-decode-token
/// `MlxBufferPool` (`buffer_pool.rs`) keeps ARC clones alive in its
/// `in_use` list for the whole decode token, so routing transient scratch
/// buffers through that pool is the canonical way to satisfy the contract.
///
/// Only the exact value `"1"` enables the flag.  The variable is read once
/// and cached in a `OnceLock`; the default is OFF, so runs that do not set
/// the variable keep retained-reference behavior unchanged.
fn unretained_refs_enabled() -> bool {
    use std::sync::OnceLock;
    static ENABLED: OnceLock<bool> = OnceLock::new();
    *ENABLED.get_or_init(|| {
        matches!(std::env::var("MLX_UNRETAINED_REFS").as_deref(), Ok("1"))
    })
}
357
/// Reports whether `HF2Q_AUTO_BARRIER=1` is set in the process environment.
///
/// ADR-015 iter37 — when this returns `true`, every
/// `CommandEncoder::dispatch_tracked` call consults a
/// [`MemRanges`](crate::mem_ranges::MemRanges) tracker and auto-emits a
/// `memoryBarrierWithScope:` exactly when the new dispatch's read/write
/// ranges conflict with previously recorded ranges (modelled on llama.cpp's
/// `ggml_metal_op_concurrency_check` in `ggml-metal-ops.cpp`).  When it
/// returns `false`, `dispatch_tracked` collapses to the same code path as
/// `encode*` — no tracking, no auto-barriers — so callers that adopt the
/// tracked API without the env gate see unchanged behavior.
///
/// Only the exact value `"1"` enables the flag.  The variable is read once
/// and cached in a `OnceLock`; the default is OFF, so production
/// decode/prefill keeps its hand-placed `enc.memory_barrier()` calls until
/// the migration planned for iter38+.
fn auto_barrier_enabled() -> bool {
    use std::sync::OnceLock;
    static ENABLED: OnceLock<bool> = OnceLock::new();
    *ENABLED.get_or_init(|| {
        matches!(std::env::var("HF2Q_AUTO_BARRIER").as_deref(), Ok("1"))
    })
}
383
/// Number of `memory_barrier()` calls auto-emitted by
/// [`CommandEncoder::dispatch_tracked`] under
/// `HF2Q_AUTO_BARRIER=1`.  NOT disjoint from [`BARRIER_COUNT`] —
/// auto-barriers also bump `BARRIER_COUNT` since they go through
/// `memory_barrier()`, so this counter measures only the
/// auto-emitted subset of that total.
static AUTO_BARRIER_COUNT: AtomicU64 = AtomicU64::new(0);

/// Number of `dispatch_tracked` calls whose mem-ranges check returned
/// "concurrent" (no barrier needed).  Together with
/// [`AUTO_BARRIER_COUNT`] this measures the elision rate of the
/// dataflow barrier: `concurrent / (concurrent + barriers)` is the
/// fraction of dispatches that ran inside the previous concurrent
/// group rather than starting a new one.
static AUTO_BARRIER_CONCURRENT: AtomicU64 = AtomicU64::new(0);
399
/// Read the cumulative number of auto-emitted barriers across all
/// encoders since process start (or last [`reset_counters`]).
///
/// This is the value of `AUTO_BARRIER_COUNT`, loaded with `Relaxed`
/// ordering.
pub fn auto_barrier_count() -> u64 {
    AUTO_BARRIER_COUNT.load(Ordering::Relaxed)
}
405
/// Read the cumulative number of `dispatch_tracked` calls that did NOT
/// emit a barrier (ran concurrent with the previous group).
///
/// This is the value of `AUTO_BARRIER_CONCURRENT`, loaded with `Relaxed`
/// ordering; it is zeroed by [`reset_counters`].
pub fn auto_barrier_concurrent_count() -> u64 {
    AUTO_BARRIER_CONCURRENT.load(Ordering::Relaxed)
}
411
/// Issue the underlying Metal `memoryBarrierWithScope:` ObjC msg_send.
///
/// Held in its own `#[inline(never)]` function so xctrace / Instruments
/// has a stable Rust frame to attribute barrier time against, separate
/// from the surrounding encoder accounting.  Per ADR-015 §P3a' Codex
/// review Q2: TimeProfiler at 1 ms sampling cannot see this site when
/// inlined; an explicit non-inline frame plus the [`BARRIER_NS`] counter
/// closes the H4 hard gate.
///
/// This function performs no counting or timing itself; it only sends the
/// message to `encoder`.
#[inline(never)]
fn issue_metal_buffer_barrier(encoder: &ComputeCommandEncoderRef) {
    // MTLBarrierScopeBuffers = 1 << 0 = 1.
    const MTL_BARRIER_SCOPE_BUFFERS: u64 = 1;
    // SAFETY: `encoder` is a valid reference to a compute command encoder
    // for the duration of the call, and the message is sent with a single
    // integer scope argument and a `()` return.
    // NOTE(review): assumes the `u64` argument matches the selector's
    // NSUInteger-sized `MTLBarrierScope` parameter on all supported targets.
    unsafe {
        let _: () =
            objc::msg_send![encoder, memoryBarrierWithScope: MTL_BARRIER_SCOPE_BUFFERS];
    }
}
429
/// A batched compute command encoder.
///
/// Keeps a single Metal `ComputeCommandEncoder` alive across multiple
/// dispatches.  The encoder is created on the first dispatch and ended
/// only when the command buffer is committed.  This mirrors candle's
/// `compute_per_buffer` pattern and avoids per-dispatch encoder overhead.
///
/// The same type also supports a capture mode (see `start_capture`) in
/// which dispatches are recorded as [`CapturedNode`]s instead of being
/// encoded into Metal.
///
/// # Typical usage
///
/// ```ignore
/// let mut enc = device.command_encoder()?;
/// // Multiple dispatches share the same compute encoder:
/// enc.encode_threadgroups(pipeline1, &buffers1, tg1, tg_size1);
/// enc.encode_threadgroups(pipeline2, &buffers2, tg2, tg_size2);
/// enc.commit_and_wait()?;
/// ```
pub struct CommandEncoder {
    /// The Metal command buffer every dispatch of this encoder goes into.
    cmd_buf: CommandBuffer,
    // SAFETY marker: see unsafe Send impl below.
    /// Raw pointer to the persistent compute encoder.
    /// Non-null when a compute pass is active; initialised to null by the
    /// constructor.
    /// The encoder borrows from `cmd_buf` but we cannot express this
    /// lifetime in safe Rust, so we use a raw pointer.
    /// SAFETY: the pointer is valid as long as `cmd_buf` is alive and
    /// `end_encoding()` has not been called on it.
    active_encoder: *const ComputeCommandEncoderRef,
    /// When `Some`, dispatches are recorded here instead of being encoded
    /// into Metal.  Set via `start_capture()`, extracted via `take_capture()`.
    capture: Option<Vec<CapturedNode>>,
    /// Op kind tag for the NEXT captured dispatch.  Set via `set_op_kind()`,
    /// consumed (reset to `Other`) when a dispatch is captured.
    pending_op_kind: CapturedOpKind,
    /// Pending read buffer ranges for the NEXT captured dispatch.
    /// Set via `set_pending_buffer_ranges()`, consumed when the next dispatch
    /// is captured.  Used by the reorder pass (Phase 4e.3).
    pending_reads: Vec<MemRange>,
    /// Pending write buffer ranges for the NEXT captured dispatch.
    /// Set and consumed together with `pending_reads`.
    pending_writes: Vec<MemRange>,
    /// ADR-015 iter8e (Phase 3b): residency set whose pending add/remove
    /// staging is flushed at every `commit*` boundary.
    ///
    /// Cloned from the device at `device.command_encoder()` time. `None`
    /// when residency sets are disabled (HF2Q_NO_RESIDENCY=1, macOS<15,
    /// or test-only `CommandEncoder::new` from a residency-less queue).
    residency_set: Option<ResidencySet>,
    /// ADR-015 iter37: dataflow barrier inference state.
    ///
    /// Populated only when `HF2Q_AUTO_BARRIER=1` is set at process
    /// start (cached via [`auto_barrier_enabled`]).  Each
    /// [`Self::dispatch_tracked`] call consults this state to decide
    /// whether a Metal memory barrier is required; on conflict the
    /// barrier is emitted, the state is reset, and the new dispatch's
    /// ranges seed the next concurrent group.  When the env gate is
    /// off, `dispatch_tracked` collapses to its untracked equivalent
    /// and this field is left empty for the encoder's lifetime.
    ///
    /// The field is always present (zero-sized when empty) so the
    /// gate-off branch is a single bool-load + early return rather
    /// than an allocation/Option indirection.
    mem_ranges: MemRanges,
}
491
/// SAFETY: CommandEncoder is safe to Send across threads provided that:
/// 1. Only one thread accesses the encoder at a time (exclusive ownership).
/// 2. The encoder is not used concurrently from multiple threads.
///
/// The manual impl is needed because the raw-pointer field
/// `active_encoder` makes the struct `!Send` by default.
///
/// Metal command buffers and compute encoders are thread-safe for exclusive
/// access (Apple documentation: "You can create command buffers, encode
/// commands, and submit them from any thread"). The raw pointer
/// `active_encoder` borrows from `cmd_buf` and is valid as long as
/// `cmd_buf` is alive — this invariant holds across thread boundaries
/// because both fields move together.
///
/// This matches llama.cpp's pattern of encoding command buffers on GCD
/// worker threads via `dispatch_apply`, and is used for the dual-buffer
/// pipeline where buf1 is encoded on a worker thread while buf0 executes.
///
/// Note that only `Send` is implemented, not `Sync`: the encoder may be
/// moved to another thread but not shared by reference between threads.
unsafe impl Send for CommandEncoder {}
507
508impl CommandEncoder {
    /// Create a new command encoder from the given command queue, without
    /// a residency set.
    ///
    /// This immediately creates a Metal command buffer by delegating to
    /// [`Self::new_with_residency`] with `None`.
    ///
    /// # Why retained references by default
    ///
    /// Unless `MLX_UNRETAINED_REFS=1` opts in (see
    /// [`Self::new_with_residency`]), we use the regular `commandBuffer`
    /// (Metal retains every bound resource for the lifetime of the buffer)
    /// rather than `commandBufferWithUnretainedReferences`.  llama.cpp uses
    /// unretained refs for an additional perf bump (~3-5% on M-series GPUs),
    /// but the hf2q dispatch pattern allocates many transient scratch
    /// buffers inside helper functions (`apply_proj` → `weight_bf16_owned`,
    /// `apply_pre_norm` → `params`, etc.) that go out of scope at the
    /// helper's return.  With unretained refs the metal::Buffer's ARC
    /// drops to zero, freeing the underlying GPU memory before the
    /// dispatch executes.  Verified 2026-04-26: switching to unretained
    /// hits "Command buffer error: GPU command buffer completed with
    /// error status" on the first MoE FFN dispatch.
    ///
    /// To make unretained refs safe everywhere, every helper that
    /// allocates and dispatches must thread its scratch buffers up to a
    /// caller scope that outlives the eventual commit, OR all such
    /// scratch must come from the per-decode-token pool (which already
    /// ARC-retains in its in_use list).  Today the lm_head + router-
    /// download paths are still unpooled.
    // Described elsewhere in this file as test-only, so it can be unused in
    // non-test builds; the lint is silenced rather than removing the helper.
    #[allow(dead_code)]
    pub(crate) fn new(queue: &CommandQueue) -> Result<Self> {
        Self::new_with_residency(queue, None)
    }
538
539    /// Create a new command encoder, optionally bound to a residency set so
540    /// `commit*` boundaries can flush deferred add/remove staging.
541    ///
542    /// ADR-015 iter8e (Phase 3b): the encoder's `commit_and_wait`,
543    /// `commit_and_wait_labeled`, `commit`, `commit_labeled`,
544    /// `commit_wait_with_gpu_time` all call
545    /// [`ResidencySet::flush_pending`](ResidencySet::flush_pending) before
546    /// submitting the Metal command buffer. This converts the
547    /// per-allocation `[set commit]` storm
548    /// (~880 commits/decode-token in iter8d/8e claude+codex variants) into
549    /// at most one commit per CB submission — mirrors llama.cpp's
550    /// `ggml-metal-device.m:1378-1382` pattern (batch addAllocation in
551    /// loop, commit ONCE).
552    ///
553    /// ADR-015 iter13: when the `MLX_UNRETAINED_REFS=1` env var is set at
554    /// process start, this constructor uses
555    /// [`CommandQueueRef::new_command_buffer_with_unretained_references`]
556    /// instead of `new_command_buffer`.  llama.cpp's per-token decode CBs
557    /// use `commandBufferWithUnretainedReferences` (see
558    /// `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m:512`) which
559    /// skips Metal's per-buffer-binding ARC-retain on submit and saves
560    /// ~3-5% on M-series GPUs (per the docstring above).
561    ///
562    /// **Caller contract under unretained refs.**  Every Metal buffer bound
563    /// to a dispatch in this CB MUST outlive the CB's GPU completion.  In
564    /// the hf2q decode path, that means every transient scratch must be
565    /// either (a) backed by the per-decode-token arena pool
566    /// (`MlxBufferPool` keeps an ARC clone in `in_use` until the next
567    /// `reset` — see `buffer_pool.rs:60`) or (b) hoisted to a caller scope
568    /// that lives across the terminal `commit_and_wait_labeled`.  Helpers
569    /// in `apply_proj` / `apply_pre_norm` / lm_head cast / router-download
570    /// that allocated transients via `device.alloc_buffer` and dropped
571    /// them at function return MUST be lifted to `pooled_alloc_buffer`
572    /// before `MLX_UNRETAINED_REFS=1` is enabled, or the first MoE FFN
573    /// dispatch will crash with "Command buffer error: GPU command buffer
574    /// completed with error status" (verified 2026-04-26).
575    ///
576    /// The default (`MLX_UNRETAINED_REFS` unset) preserves retained-refs
577    /// behavior verbatim — this is the sourdough-safe path.
578    pub(crate) fn new_with_residency(
579        queue: &CommandQueue,
580        residency_set: Option<ResidencySet>,
581    ) -> Result<Self> {
582        let cmd_buf = if unretained_refs_enabled() {
583            queue.new_command_buffer_with_unretained_references().to_owned()
584        } else {
585            queue.new_command_buffer().to_owned()
586        };
587        CMD_BUF_COUNT.fetch_add(1, Ordering::Relaxed);
588        Ok(Self {
589            cmd_buf,
590            active_encoder: std::ptr::null(),
591            capture: None,
592            pending_op_kind: CapturedOpKind::Other,
593            pending_reads: Vec::new(),
594            pending_writes: Vec::new(),
595            residency_set,
596            mem_ranges: MemRanges::new(),
597        })
598    }
599
    /// Enable capture mode.
    ///
    /// All subsequent dispatch and barrier calls will be recorded into a
    /// `Vec<CapturedNode>` instead of being encoded into Metal.
    /// Call `take_capture()` to extract the recorded nodes.
    ///
    /// Calling this while a capture is already active replaces the current
    /// recording with a fresh, empty one; previously captured nodes are
    /// dropped.
    pub fn start_capture(&mut self) {
        // Pre-size for 128 nodes to avoid early reallocations while recording.
        self.capture = Some(Vec::with_capacity(128));
    }
608
    /// Whether the encoder is currently in capture mode.
    ///
    /// Returns `true` between `start_capture()` and the next
    /// `take_capture()`.
    pub fn is_capturing(&self) -> bool {
        self.capture.is_some()
    }
613
    /// Extract the captured nodes, ending capture mode.
    ///
    /// Returns `None` if capture mode was not active.  After this call the
    /// encoder is no longer capturing until `start_capture()` is called
    /// again.
    pub fn take_capture(&mut self) -> Option<Vec<CapturedNode>> {
        self.capture.take()
    }
620
    /// Tag the NEXT captured dispatch with the given operation kind.
    ///
    /// The tag is consumed (reset to `Other`) after the next dispatch is
    /// captured.  Only meaningful in capture mode — has no effect on
    /// direct-dispatch encoding.  Calling it twice before a dispatch keeps
    /// only the most recent kind.
    ///
    /// Used by op dispatch functions to annotate captures for the fusion
    /// pass (Phase 4e.2).
    pub fn set_op_kind(&mut self, kind: CapturedOpKind) {
        self.pending_op_kind = kind;
    }
632
633    /// Consume and return the pending op kind, resetting it to `Other`.
634    fn take_pending_op_kind(&mut self) -> CapturedOpKind {
635        let kind = self.pending_op_kind;
636        self.pending_op_kind = CapturedOpKind::Other;
637        kind
638    }
639
640    /// Stash buffer range annotations for the NEXT captured dispatch.
641    ///
642    /// Called by `GraphSession::barrier_between()` in capture mode to record
643    /// which buffers the next dispatch reads from and writes to.  The ranges
644    /// are consumed by the next `encode_*` call and attached to the captured
645    /// `CapturedNode::Dispatch`.
646    ///
647    /// Only meaningful in capture mode — has no effect on direct-dispatch.
648    pub fn set_pending_buffer_ranges(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
649        self.pending_reads = reads;
650        self.pending_writes = writes;
651    }
652
653    /// Patch the last captured dispatch node's empty reads/writes with the
654    /// given ranges. No-op if not capturing, or if the last node isn't a
655    /// Dispatch, or if its ranges are already populated.
656    ///
657    /// Used by `GraphSession::track_dispatch` in recording mode to annotate
658    /// dispatches that were called without a preceding `barrier_between`.
659    pub fn annotate_last_dispatch_if_missing(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
660        if let Some(ref mut nodes) = self.capture {
661            if let Some(CapturedNode::Dispatch { reads: r, writes: w, .. }) = nodes.last_mut() {
662                if r.is_empty() && !reads.is_empty() {
663                    *r = reads;
664                }
665                if w.is_empty() && !writes.is_empty() {
666                    *w = writes;
667                }
668            }
669        }
670    }
671
672    /// Consume and return the pending buffer range annotations.
673    fn take_pending_buffer_ranges(&mut self) -> (Vec<MemRange>, Vec<MemRange>) {
674        let reads = std::mem::take(&mut self.pending_reads);
675        let writes = std::mem::take(&mut self.pending_writes);
676        (reads, writes)
677    }
678
679    /// Record buffer bindings into `RecordedBinding` form.
680    fn record_buffer_bindings(buffers: &[(u64, &MlxBuffer)]) -> Vec<(u64, RecordedBinding)> {
681        buffers
682            .iter()
683            .map(|&(index, buf)| {
684                (
685                    index,
686                    RecordedBinding::Buffer {
687                        metal_buffer: buf.metal_buffer().clone(),
688                        offset: buf.byte_offset(),
689                    },
690                )
691            })
692            .collect()
693    }
694
695    /// Record `KernelArg` bindings into `RecordedBinding` form.
696    ///
697    /// `KernelArg::Buffer(buf)` records `buf.byte_offset()` so capture →
698    /// replay round-trips of `slice_view`-derived buffers preserve their
699    /// offsets, matching `record_buffer_bindings`'s behavior at line 382.
700    fn record_arg_bindings(bindings: &[(u64, KernelArg<'_>)]) -> Vec<(u64, RecordedBinding)> {
701        bindings
702            .iter()
703            .map(|(index, arg)| {
704                let recorded = match arg {
705                    KernelArg::Buffer(buf) => RecordedBinding::Buffer {
706                        metal_buffer: buf.metal_buffer().clone(),
707                        offset: buf.byte_offset(),
708                    },
709                    KernelArg::BufferWithOffset(buf, offset) => RecordedBinding::Buffer {
710                        metal_buffer: buf.metal_buffer().clone(),
711                        offset: *offset,
712                    },
713                    KernelArg::Bytes(bytes) => RecordedBinding::Bytes(bytes.to_vec()),
714                };
715                (*index, recorded)
716            })
717            .collect()
718    }
719
720    /// Get or create the persistent compute encoder.
721    ///
722    /// On the first call, creates a new compute encoder from the command
723    /// buffer.  On subsequent calls, returns the existing one.
724    ///
725    /// SAFETY: The returned reference borrows from `self.cmd_buf` which is
726    /// alive for the lifetime of this `CommandEncoder`.  The raw pointer is
727    /// valid until `end_active_encoder()` is called.
728    #[inline]
729    fn get_or_create_encoder(&mut self) -> &ComputeCommandEncoderRef {
730        if self.active_encoder.is_null() {
731            // Use MTLDispatchTypeConcurrent to allow independent dispatches
732            // to overlap on the GPU.  Memory barriers are inserted between
733            // dependent dispatches via `memory_barrier()`.
734            let encoder = self
735                .cmd_buf
736                .compute_command_encoder_with_dispatch_type(MTLDispatchType::Concurrent);
737            self.active_encoder = encoder as *const ComputeCommandEncoderRef;
738        }
739        // SAFETY: active_encoder is non-null and points to a valid encoder
740        // owned by cmd_buf.
741        unsafe { &*self.active_encoder }
742    }
743
744    /// End the active compute encoder if one exists.
745    #[inline]
746    fn end_active_encoder(&mut self) {
747        if !self.active_encoder.is_null() {
748            // SAFETY: the pointer was obtained from cmd_buf.new_compute_command_encoder()
749            // and has not been ended yet.
750            unsafe { &*self.active_encoder }.end_encoding();
751            self.active_encoder = std::ptr::null();
752        }
753    }
754
755    /// Insert a memory barrier with scope `MTLBarrierScopeBuffers`.
756    ///
757    /// When the encoder uses `MTLDispatchTypeConcurrent`, all dispatches can
758    /// execute concurrently unless separated by a barrier.  Call this between
759    /// dispatches where the later dispatch reads a buffer written by an
760    /// earlier one.
761    ///
762    /// This is the same pattern llama.cpp uses:
763    /// `[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]`
764    #[allow(unexpected_cfgs)]
765    pub fn memory_barrier(&mut self) {
766        if let Some(ref mut nodes) = self.capture {
767            nodes.push(CapturedNode::Barrier);
768            return;
769        }
770        if self.active_encoder.is_null() {
771            return;
772        }
773        BARRIER_COUNT.fetch_add(1, Ordering::Relaxed);
774        // SAFETY: active_encoder is non-null and valid.
775        let encoder = unsafe { &*self.active_encoder };
776        if barrier_profile_enabled() {
777            // mach_absolute_time path — only on when MLX_PROFILE_BARRIERS=1.
778            let start = std::time::Instant::now();
779            issue_metal_buffer_barrier(encoder);
780            let elapsed_ns = start.elapsed().as_nanos() as u64;
781            BARRIER_NS.fetch_add(elapsed_ns, Ordering::Relaxed);
782        } else {
783            issue_metal_buffer_barrier(encoder);
784        }
785    }
786
787    /// Set the compute pipeline state for subsequent dispatches.
788    ///
789    /// This begins a new compute pass if one is not already active.
790    pub fn set_pipeline(&mut self, pipeline: &ComputePipelineStateRef) {
791        let encoder = self.get_or_create_encoder();
792        encoder.set_compute_pipeline_state(pipeline);
793    }
794
795    /// Bind a buffer to a compute kernel argument slot.
796    ///
797    /// The `index` corresponds to the `[[buffer(N)]]` attribute in the MSL shader.
798    pub fn set_buffer(&self, index: u64, buffer: &MlxBuffer) {
799        let _ = (index, buffer);
800    }
801
802    /// Dispatch threads on the GPU.
803    pub fn dispatch_threads(&self, grid_size: MTLSize, threadgroup_size: MTLSize) {
804        let _ = (grid_size, threadgroup_size);
805    }
806
807    /// Encode a complete compute pass: set pipeline, bind buffers, dispatch.
808    ///
809    /// Reuses the persistent compute encoder — no per-dispatch encoder
810    /// creation overhead.
811    ///
812    /// # Arguments
813    ///
814    /// * `pipeline`         — The compiled compute pipeline to execute.
815    /// * `buffers`          — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
816    /// * `grid_size`        — Total number of threads to launch.
817    /// * `threadgroup_size` — Threads per threadgroup.
818    pub fn encode(
819        &mut self,
820        pipeline: &ComputePipelineStateRef,
821        buffers: &[(u64, &MlxBuffer)],
822        grid_size: MTLSize,
823        threadgroup_size: MTLSize,
824    ) {
825        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
826        let op_kind = self.take_pending_op_kind();
827        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
828        if let Some(ref mut nodes) = self.capture {
829            nodes.push(CapturedNode::Dispatch {
830                pipeline: pipeline.to_owned(),
831                bindings: Self::record_buffer_bindings(buffers),
832                threads_per_grid: grid_size,
833                threads_per_threadgroup: threadgroup_size,
834                threadgroup_memory: Vec::new(),
835                dispatch_kind: DispatchKind::Threads,
836                op_kind,
837                reads: pending_reads,
838                writes: pending_writes,
839            });
840            return;
841        }
842        let encoder = self.get_or_create_encoder();
843        encoder.set_compute_pipeline_state(pipeline);
844        for &(index, buf) in buffers {
845            encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
846        }
847        encoder.dispatch_threads(grid_size, threadgroup_size);
848    }
849
850    /// Encode a compute pass using threadgroups instead of raw thread counts.
851    ///
852    /// Reuses the persistent compute encoder — no per-dispatch encoder
853    /// creation overhead.
854    pub fn encode_threadgroups(
855        &mut self,
856        pipeline: &ComputePipelineStateRef,
857        buffers: &[(u64, &MlxBuffer)],
858        threadgroups: MTLSize,
859        threadgroup_size: MTLSize,
860    ) {
861        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
862        let op_kind = self.take_pending_op_kind();
863        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
864        if let Some(ref mut nodes) = self.capture {
865            nodes.push(CapturedNode::Dispatch {
866                pipeline: pipeline.to_owned(),
867                bindings: Self::record_buffer_bindings(buffers),
868                threads_per_grid: threadgroups,
869                threads_per_threadgroup: threadgroup_size,
870                threadgroup_memory: Vec::new(),
871                dispatch_kind: DispatchKind::ThreadGroups,
872                op_kind,
873                reads: pending_reads,
874                writes: pending_writes,
875            });
876            return;
877        }
878        let encoder = self.get_or_create_encoder();
879        encoder.set_compute_pipeline_state(pipeline);
880        for &(index, buf) in buffers {
881            encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
882        }
883        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
884    }
885
886    /// Encode a compute pass using threadgroups with shared threadgroup memory.
887    ///
888    /// Like [`encode_threadgroups`](Self::encode_threadgroups), but additionally
889    /// allocates threadgroup memory at the specified indices.  This is required
890    /// for kernels that use `threadgroup` memory (e.g. reductions in rms_norm
891    /// and softmax).
892    ///
893    /// # Arguments
894    ///
895    /// * `pipeline`         — The compiled compute pipeline to execute.
896    /// * `buffers`          — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
897    /// * `threadgroup_mem`  — Slice of `(index, byte_length)` pairs for threadgroup memory.
898    /// * `threadgroups`     — Number of threadgroups to dispatch.
899    /// * `threadgroup_size` — Threads per threadgroup.
900    pub fn encode_threadgroups_with_shared(
901        &mut self,
902        pipeline: &ComputePipelineStateRef,
903        buffers: &[(u64, &MlxBuffer)],
904        threadgroup_mem: &[(u64, u64)],
905        threadgroups: MTLSize,
906        threadgroup_size: MTLSize,
907    ) {
908        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
909        let op_kind = self.take_pending_op_kind();
910        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
911        if let Some(ref mut nodes) = self.capture {
912            nodes.push(CapturedNode::Dispatch {
913                pipeline: pipeline.to_owned(),
914                bindings: Self::record_buffer_bindings(buffers),
915                threads_per_grid: threadgroups,
916                threads_per_threadgroup: threadgroup_size,
917                threadgroup_memory: threadgroup_mem.to_vec(),
918                dispatch_kind: DispatchKind::ThreadGroups,
919                op_kind,
920                reads: pending_reads,
921                writes: pending_writes,
922            });
923            return;
924        }
925        let encoder = self.get_or_create_encoder();
926        encoder.set_compute_pipeline_state(pipeline);
927        for &(index, buf) in buffers {
928            encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
929        }
930        for &(index, byte_length) in threadgroup_mem {
931            encoder.set_threadgroup_memory_length(index, byte_length);
932        }
933        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
934    }
935
936    /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_threads).
937    ///
938    /// Reuses the persistent compute encoder.
939    pub fn encode_with_args(
940        &mut self,
941        pipeline: &ComputePipelineStateRef,
942        bindings: &[(u64, KernelArg<'_>)],
943        grid_size: MTLSize,
944        threadgroup_size: MTLSize,
945    ) {
946        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
947        let op_kind = self.take_pending_op_kind();
948        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
949        if let Some(ref mut nodes) = self.capture {
950            nodes.push(CapturedNode::Dispatch {
951                pipeline: pipeline.to_owned(),
952                bindings: Self::record_arg_bindings(bindings),
953                threads_per_grid: grid_size,
954                threads_per_threadgroup: threadgroup_size,
955                threadgroup_memory: Vec::new(),
956                dispatch_kind: DispatchKind::Threads,
957                op_kind,
958                reads: pending_reads,
959                writes: pending_writes,
960            });
961            return;
962        }
963        let encoder = self.get_or_create_encoder();
964        encoder.set_compute_pipeline_state(pipeline);
965        apply_bindings(encoder, bindings);
966        encoder.dispatch_threads(grid_size, threadgroup_size);
967    }
968
969    /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_thread_groups).
970    ///
971    /// Reuses the persistent compute encoder.
972    pub fn encode_threadgroups_with_args(
973        &mut self,
974        pipeline: &ComputePipelineStateRef,
975        bindings: &[(u64, KernelArg<'_>)],
976        threadgroups: MTLSize,
977        threadgroup_size: MTLSize,
978    ) {
979        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
980        let op_kind = self.take_pending_op_kind();
981        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
982        if let Some(ref mut nodes) = self.capture {
983            nodes.push(CapturedNode::Dispatch {
984                pipeline: pipeline.to_owned(),
985                bindings: Self::record_arg_bindings(bindings),
986                threads_per_grid: threadgroups,
987                threads_per_threadgroup: threadgroup_size,
988                threadgroup_memory: Vec::new(),
989                dispatch_kind: DispatchKind::ThreadGroups,
990                op_kind,
991                reads: pending_reads,
992                writes: pending_writes,
993            });
994            return;
995        }
996        let encoder = self.get_or_create_encoder();
997        encoder.set_compute_pipeline_state(pipeline);
998        apply_bindings(encoder, bindings);
999        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
1000    }
1001
1002    /// Encode a dispatch with mixed buffer/bytes bindings and shared memory.
1003    ///
1004    /// Reuses the persistent compute encoder.
1005    pub fn encode_threadgroups_with_args_and_shared(
1006        &mut self,
1007        pipeline: &ComputePipelineStateRef,
1008        bindings: &[(u64, KernelArg<'_>)],
1009        threadgroup_mem: &[(u64, u64)],
1010        threadgroups: MTLSize,
1011        threadgroup_size: MTLSize,
1012    ) {
1013        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1014        let op_kind = self.take_pending_op_kind();
1015        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1016        if let Some(ref mut nodes) = self.capture {
1017            nodes.push(CapturedNode::Dispatch {
1018                pipeline: pipeline.to_owned(),
1019                bindings: Self::record_arg_bindings(bindings),
1020                threads_per_grid: threadgroups,
1021                threads_per_threadgroup: threadgroup_size,
1022                threadgroup_memory: threadgroup_mem.to_vec(),
1023                dispatch_kind: DispatchKind::ThreadGroups,
1024                op_kind,
1025                reads: pending_reads,
1026                writes: pending_writes,
1027            });
1028            return;
1029        }
1030        let encoder = self.get_or_create_encoder();
1031        encoder.set_compute_pipeline_state(pipeline);
1032        apply_bindings(encoder, bindings);
1033        for &(index, byte_length) in threadgroup_mem {
1034            encoder.set_threadgroup_memory_length(index, byte_length);
1035        }
1036        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
1037    }
1038
1039    // -----------------------------------------------------------------
1040    // ADR-015 iter37 — dataflow-driven auto-barrier dispatch family.
1041    //
1042    // These mirrors of `encode_threadgroups*_with_args*` take explicit
1043    // `reads: &[&MlxBuffer]` and `writes: &[&MlxBuffer]` slices.  When
1044    // the process started with `HF2Q_AUTO_BARRIER=1`, the encoder's
1045    // [`MemRanges`] tracker checks the new ranges against the
1046    // cumulative state since the last barrier; on conflict it emits
1047    // `memory_barrier()` and resets the state before recording the
1048    // new ranges.  When the env gate is unset, the check is skipped
1049    // entirely and the dispatch is applied identically to the
1050    // matching `encode_*` method — sourdough-safe by construction.
1051    //
1052    // Capture mode: the `reads`/`writes` ranges are recorded onto the
1053    // captured node via the existing `pending_reads`/`pending_writes`
1054    // mechanism, so a `dispatch_tracked` call inside capture mode is
1055    // equivalent to `set_pending_buffer_ranges + encode_*`.
1056    //
1057    // No production callsite migrates in iter37 — this is the API
1058    // surface the qwen35 forward path will adopt incrementally in
1059    // iter38+.  Today, every call to `dispatch_tracked` from a
1060    // production code path lives behind an explicit caller decision
1061    // to opt in.
1062    // -----------------------------------------------------------------
1063
1064    /// Auto-barrier-aware dispatch with [`KernelArg`] bindings (uses
1065    /// `dispatch_thread_groups`).
1066    ///
1067    /// Behaves identically to
1068    /// [`encode_threadgroups_with_args`](Self::encode_threadgroups_with_args)
1069    /// when `HF2Q_AUTO_BARRIER` is unset.  When set, consults the
1070    /// per-encoder [`MemRanges`] tracker:
1071    ///
1072    /// * Conflict (RAW/WAR/WAW on a same-buffer range) → emit
1073    ///   `memory_barrier()`, increment [`AUTO_BARRIER_COUNT`], reset
1074    ///   the tracker, then dispatch and seed the new concurrent group
1075    ///   with this dispatch's ranges.
1076    /// * No conflict → increment [`AUTO_BARRIER_CONCURRENT`], record
1077    ///   the ranges into the cumulative state, dispatch.
1078    pub fn dispatch_tracked_threadgroups_with_args(
1079        &mut self,
1080        pipeline: &ComputePipelineStateRef,
1081        bindings: &[(u64, KernelArg<'_>)],
1082        reads: &[&MlxBuffer],
1083        writes: &[&MlxBuffer],
1084        threadgroups: MTLSize,
1085        threadgroup_size: MTLSize,
1086    ) {
1087        // Capture mode: stash ranges + delegate to the standard encode.
1088        // The ranges flow through `pending_reads`/`pending_writes` and
1089        // attach to the captured `Dispatch` node — identical to what
1090        // `GraphSession::barrier_between` already does in capture mode.
1091        if self.is_capturing() {
1092            let read_ranges = ranges_from_buffers(reads);
1093            let write_ranges = ranges_from_buffers(writes);
1094            self.set_pending_buffer_ranges(read_ranges, write_ranges);
1095            self.encode_threadgroups_with_args(pipeline, bindings, threadgroups, threadgroup_size);
1096            return;
1097        }
1098
1099        if auto_barrier_enabled() {
1100            self.maybe_auto_barrier(reads, writes);
1101        }
1102
1103        self.encode_threadgroups_with_args(pipeline, bindings, threadgroups, threadgroup_size);
1104    }
1105
1106    /// Auto-barrier-aware dispatch with [`KernelArg`] bindings + shared
1107    /// threadgroup memory.
1108    ///
1109    /// See [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args)
1110    /// for the behavioral contract; this variant additionally takes a
1111    /// `threadgroup_mem` slice that is forwarded to
1112    /// [`encode_threadgroups_with_args_and_shared`](Self::encode_threadgroups_with_args_and_shared).
1113    ///
1114    /// The 8-argument signature mirrors the existing
1115    /// `encode_threadgroups_with_args_and_shared` plus the two
1116    /// dataflow slices; `clippy::too_many_arguments` is allowed
1117    /// because each parameter is load-bearing for either the dispatch
1118    /// (pipeline/bindings/threadgroups/threadgroup_size/shared_mem)
1119    /// or the auto-barrier (reads/writes).
1120    #[allow(clippy::too_many_arguments)]
1121    pub fn dispatch_tracked_threadgroups_with_args_and_shared(
1122        &mut self,
1123        pipeline: &ComputePipelineStateRef,
1124        bindings: &[(u64, KernelArg<'_>)],
1125        threadgroup_mem: &[(u64, u64)],
1126        reads: &[&MlxBuffer],
1127        writes: &[&MlxBuffer],
1128        threadgroups: MTLSize,
1129        threadgroup_size: MTLSize,
1130    ) {
1131        if self.is_capturing() {
1132            let read_ranges = ranges_from_buffers(reads);
1133            let write_ranges = ranges_from_buffers(writes);
1134            self.set_pending_buffer_ranges(read_ranges, write_ranges);
1135            self.encode_threadgroups_with_args_and_shared(
1136                pipeline,
1137                bindings,
1138                threadgroup_mem,
1139                threadgroups,
1140                threadgroup_size,
1141            );
1142            return;
1143        }
1144
1145        if auto_barrier_enabled() {
1146            self.maybe_auto_barrier(reads, writes);
1147        }
1148
1149        self.encode_threadgroups_with_args_and_shared(
1150            pipeline,
1151            bindings,
1152            threadgroup_mem,
1153            threadgroups,
1154            threadgroup_size,
1155        );
1156    }
1157
1158    /// Auto-barrier-aware dispatch using `(slot, &MlxBuffer)` bindings
1159    /// (uses `dispatch_thread_groups`).
1160    ///
1161    /// Convenience wrapper for callers that don't need
1162    /// [`KernelArg::Bytes`] inline-byte arguments.  See
1163    /// [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args)
1164    /// for behavioral contract.
1165    pub fn dispatch_tracked_threadgroups(
1166        &mut self,
1167        pipeline: &ComputePipelineStateRef,
1168        buffers: &[(u64, &MlxBuffer)],
1169        reads: &[&MlxBuffer],
1170        writes: &[&MlxBuffer],
1171        threadgroups: MTLSize,
1172        threadgroup_size: MTLSize,
1173    ) {
1174        if self.is_capturing() {
1175            let read_ranges = ranges_from_buffers(reads);
1176            let write_ranges = ranges_from_buffers(writes);
1177            self.set_pending_buffer_ranges(read_ranges, write_ranges);
1178            self.encode_threadgroups(pipeline, buffers, threadgroups, threadgroup_size);
1179            return;
1180        }
1181
1182        if auto_barrier_enabled() {
1183            self.maybe_auto_barrier(reads, writes);
1184        }
1185
1186        self.encode_threadgroups(pipeline, buffers, threadgroups, threadgroup_size);
1187    }
1188
1189    /// Auto-barrier-aware dispatch using `(slot, &MlxBuffer)` bindings
1190    /// **plus shared threadgroup memory** (uses `dispatch_thread_groups`).
1191    ///
1192    /// Mirrors [`encode_threadgroups_with_shared`](Self::encode_threadgroups_with_shared)
1193    /// — convenience variant for kernels that allocate threadgroup
1194    /// memory (reductions in `rms_norm`, `softmax`, etc.) but don't
1195    /// need [`KernelArg::Bytes`] inline-byte arguments.  See
1196    /// [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args)
1197    /// for the behavioral contract; the only addition here is the
1198    /// `threadgroup_mem` slice forwarded to the underlying encode.
1199    ///
1200    /// Closes the iter38-audit coverage gap: the 5 `rms_norm.rs`
1201    /// callsites (`/opt/mlx-native/src/ops/rms_norm.rs:124,236,443,
1202    /// 516,589`) all use `encode_threadgroups_with_shared` and need
1203    /// dataflow tracking when migrated to auto-barrier in iter40+.
1204    ///
1205    /// 7-argument signature; `clippy::too_many_arguments` is allowed
1206    /// because each parameter is load-bearing for either the dispatch
1207    /// (pipeline/buffers/threadgroups/threadgroup_size/shared_mem) or
1208    /// the auto-barrier (reads/writes).
1209    #[allow(clippy::too_many_arguments)]
1210    pub fn dispatch_tracked_threadgroups_with_shared(
1211        &mut self,
1212        pipeline: &ComputePipelineStateRef,
1213        buffers: &[(u64, &MlxBuffer)],
1214        threadgroup_mem: &[(u64, u64)],
1215        reads: &[&MlxBuffer],
1216        writes: &[&MlxBuffer],
1217        threadgroups: MTLSize,
1218        threadgroup_size: MTLSize,
1219    ) {
1220        if self.is_capturing() {
1221            let read_ranges = ranges_from_buffers(reads);
1222            let write_ranges = ranges_from_buffers(writes);
1223            self.set_pending_buffer_ranges(read_ranges, write_ranges);
1224            self.encode_threadgroups_with_shared(
1225                pipeline,
1226                buffers,
1227                threadgroup_mem,
1228                threadgroups,
1229                threadgroup_size,
1230            );
1231            return;
1232        }
1233
1234        if auto_barrier_enabled() {
1235            self.maybe_auto_barrier(reads, writes);
1236        }
1237
1238        self.encode_threadgroups_with_shared(
1239            pipeline,
1240            buffers,
1241            threadgroup_mem,
1242            threadgroups,
1243            threadgroup_size,
1244        );
1245    }
1246
1247    /// Auto-barrier-aware `dispatch_threads` variant with
1248    /// [`KernelArg`] bindings.
1249    ///
1250    /// Mirrors [`encode_with_args`](Self::encode_with_args) — the
1251    /// `dispatch_threads` (per-thread grid) flavor, as opposed to the
1252    /// `dispatch_thread_groups` flavor of
1253    /// [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args).
1254    /// See that method for the behavioral contract.
1255    ///
1256    /// Closes the iter38-audit coverage gap: callers that use
1257    /// per-thread grids — `rope.rs:108` (IMROPE), `sigmoid_mul.rs:76`
1258    /// (sigmoid-mul), and `encode_helpers.rs:41` (kv_cache_copy) —
1259    /// need a `dispatch_threads` flavor of the tracked dispatch
1260    /// because their grid sizes are expressed in threads, not
1261    /// threadgroups.
1262    ///
1263    /// Note: the simpler `(slot, &MlxBuffer)` form (from
1264    /// [`encode`](Self::encode)) is a special case of this method —
1265    /// callers can wrap each binding as `KernelArg::Buffer(buf)` to
1266    /// reuse this single tracked variant rather than introducing a
1267    /// fifth one.
1268    pub fn dispatch_tracked_threads_with_args(
1269        &mut self,
1270        pipeline: &ComputePipelineStateRef,
1271        bindings: &[(u64, KernelArg<'_>)],
1272        reads: &[&MlxBuffer],
1273        writes: &[&MlxBuffer],
1274        grid_size: MTLSize,
1275        threadgroup_size: MTLSize,
1276    ) {
1277        if self.is_capturing() {
1278            let read_ranges = ranges_from_buffers(reads);
1279            let write_ranges = ranges_from_buffers(writes);
1280            self.set_pending_buffer_ranges(read_ranges, write_ranges);
1281            self.encode_with_args(pipeline, bindings, grid_size, threadgroup_size);
1282            return;
1283        }
1284
1285        if auto_barrier_enabled() {
1286            self.maybe_auto_barrier(reads, writes);
1287        }
1288
1289        self.encode_with_args(pipeline, bindings, grid_size, threadgroup_size);
1290    }
1291
1292    /// Run the dataflow check, emit a barrier on conflict, and record
1293    /// the dispatch's ranges into the cumulative state.
1294    ///
1295    /// Always called *before* the underlying `encode_*` method
1296    /// applies the dispatch.  Mirrors lines 220-225 of
1297    /// `ggml-metal-ops.cpp` (`concurrency_check + concurrency_reset +
1298    /// concurrency_add` around each node).
1299    fn maybe_auto_barrier(
1300        &mut self,
1301        reads: &[&MlxBuffer],
1302        writes: &[&MlxBuffer],
1303    ) {
1304        if self.mem_ranges.check_dispatch(reads, writes) {
1305            // Concurrent — no barrier needed; just record the new ranges.
1306            self.mem_ranges.add_dispatch(reads, writes);
1307            AUTO_BARRIER_CONCURRENT.fetch_add(1, Ordering::Relaxed);
1308        } else {
1309            // Conflict — emit barrier, reset state, seed new group.
1310            //
1311            // `memory_barrier()` itself increments `BARRIER_COUNT` and,
1312            // when `MLX_PROFILE_BARRIERS=1`, accumulates `BARRIER_NS`.
1313            // We additionally bump `AUTO_BARRIER_COUNT` so the
1314            // "auto-emitted vs hand-placed" subset is queryable.
1315            self.memory_barrier();
1316            self.mem_ranges.reset();
1317            self.mem_ranges.add_dispatch(reads, writes);
1318            AUTO_BARRIER_COUNT.fetch_add(1, Ordering::Relaxed);
1319        }
1320    }
1321
1322    /// Force a barrier and reset the auto-barrier tracker.
1323    ///
1324    /// Use at boundaries where the caller knows a barrier is required
1325    /// regardless of dataflow — typically before reading data back to
1326    /// CPU, or at the end of an op group whose internal dependencies
1327    /// the tracker can't see (e.g. host-driven memcpy).
1328    ///
1329    /// Equivalent to `memory_barrier()` plus a `MemRanges::reset()`
1330    /// when `HF2Q_AUTO_BARRIER=1`; equivalent to plain
1331    /// `memory_barrier()` otherwise.
1332    pub fn force_barrier_and_reset_tracker(&mut self) {
1333        self.memory_barrier();
1334        if auto_barrier_enabled() {
1335            self.mem_ranges.reset();
1336        }
1337    }
1338
1339    /// Diagnostic accessor — number of ranges currently recorded in
1340    /// this encoder's [`MemRanges`] tracker.  Always zero unless
1341    /// `HF2Q_AUTO_BARRIER=1` and at least one `dispatch_tracked` call
1342    /// has fired since the last conflict.
1343    #[inline]
1344    pub fn mem_ranges_len(&self) -> usize {
1345        self.mem_ranges.len()
1346    }
1347
1348    /// Replay a single captured dispatch node into this encoder.
1349    ///
1350    /// This is the inverse of capture: it takes a previously recorded
1351    /// `CapturedNode::Dispatch` and encodes it into the live Metal encoder.
1352    /// Barrier nodes are handled by the caller (ComputeGraph::encode_sequential).
1353    ///
1354    /// Does NOT increment `DISPATCH_COUNT` — that was already counted at
1355    /// capture time.
1356    pub fn replay_dispatch(
1357        &mut self,
1358        pipeline: &ComputePipelineStateRef,
1359        bindings: &[(u64, RecordedBinding)],
1360        threadgroup_memory: &[(u64, u64)],
1361        threads_per_grid: MTLSize,
1362        threads_per_threadgroup: MTLSize,
1363        dispatch_kind: DispatchKind,
1364    ) {
1365        let encoder = self.get_or_create_encoder();
1366        encoder.set_compute_pipeline_state(pipeline);
1367        for (index, binding) in bindings {
1368            match binding {
1369                RecordedBinding::Buffer { metal_buffer, offset } => {
1370                    encoder.set_buffer(*index, Some(metal_buffer), *offset);
1371                }
1372                RecordedBinding::Bytes(bytes) => {
1373                    encoder.set_bytes(
1374                        *index,
1375                        bytes.len() as u64,
1376                        bytes.as_ptr() as *const _,
1377                    );
1378                }
1379            }
1380        }
1381        for &(index, byte_length) in threadgroup_memory {
1382            encoder.set_threadgroup_memory_length(index, byte_length);
1383        }
1384        match dispatch_kind {
1385            DispatchKind::Threads => {
1386                encoder.dispatch_threads(threads_per_grid, threads_per_threadgroup);
1387            }
1388            DispatchKind::ThreadGroups => {
1389                encoder.dispatch_thread_groups(threads_per_grid, threads_per_threadgroup);
1390            }
1391        }
1392    }
1393
1394    /// Flush any pending residency-set add/remove staging.
1395    ///
1396    /// Hooked at every commit boundary so per-allocation
1397    /// [`ResidencySet::add_allocation`](ResidencySet::add_allocation) and
1398    /// [`ResidencySet::remove_allocation`](ResidencySet::remove_allocation)
1399    /// calls (as fired by `MlxDevice::alloc_buffer` and
1400    /// `MlxBufferStorage::Drop`) collapse into at most ONE `[set commit]`
1401    /// per CB submission. Mirrors llama.cpp's
1402    /// `ggml-metal-device.m:1378-1382` (batch addAllocation in loop,
1403    /// commit ONCE).
1404    #[inline]
1405    fn flush_residency_pending(&self) {
1406        if let Some(set) = self.residency_set.as_ref() {
1407            set.flush_pending();
1408        }
1409    }
1410
1411    /// Commit the command buffer and block until the GPU finishes execution.
1412    ///
1413    /// # Errors
1414    ///
1415    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
1416    pub fn commit_and_wait(&mut self) -> Result<()> {
1417        SYNC_COUNT.fetch_add(1, Ordering::Relaxed);
1418
1419        // End the persistent compute encoder before committing.
1420        self.end_active_encoder();
1421
1422        // ADR-015 iter8e (Phase 3b): flush deferred residency-set
1423        // add/remove staging so the residency hint covers any buffers
1424        // referenced by this CB. Single commit per CB boundary; no-op
1425        // when no residency set or no staged changes.
1426        self.flush_residency_pending();
1427
1428        self.cmd_buf.commit();
1429        self.cmd_buf.wait_until_completed();
1430
1431        match self.cmd_buf.status() {
1432            MTLCommandBufferStatus::Completed => Ok(()),
1433            MTLCommandBufferStatus::Error => {
1434                Err(MlxError::CommandBufferError(
1435                    "GPU command buffer completed with error status".into(),
1436                ))
1437            }
1438            status => Err(MlxError::CommandBufferError(format!(
1439                "Unexpected command buffer status after wait: {:?}",
1440                status
1441            ))),
1442        }
1443    }
1444
1445    /// Commit + wait, accumulating GPU wall-clock time under `label` into
1446    /// the [`crate::kernel_profile`] global table when `MLX_PROFILE_CB=1`
1447    /// is set.  When the env var is unset, this is identical to
1448    /// [`commit_and_wait`](Self::commit_and_wait) — zero overhead.
1449    ///
1450    /// Used by hf2q's decode hot path to attribute per-cb GPU time to
1451    /// labeled phases (per-layer attn, per-layer ffn, output_head, etc.)
1452    /// without manually wiring `commit_wait_with_gpu_time` everywhere.
1453    ///
1454    /// # Errors
1455    ///
1456    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
1457    pub fn commit_and_wait_labeled(&mut self, label: &str) -> Result<()> {
1458        // ADR-015 iter16 — propagate `label` to MTLCommandBuffer.setLabel and
1459        // (if a compute encoder is active) MTLComputeCommandEncoder.setLabel
1460        // BEFORE end_encoding/commit so xctrace's
1461        // `metal-application-encoders-list` table populates `cmdbuffer-label`
1462        // and `encoder-label` columns with the semantic phase name (e.g.
1463        // `layer.attn_moe_ffn`, `output_head.fused_norm_lm_argmax`,
1464        // `layer.delta_net.ops1-9`).  Joined to per-CB GPU duration via
1465        // `metal-gpu-submission-to-command-buffer-id` (sub_id ↔ encoder_id) →
1466        // `metal-gpu-execution-points` (per-dispatch start/end), this enables
1467        // per-phase µs/token attribution comparing hf2q vs llama side-by-side
1468        // (iter15 §E "iter16 ATTRIBUTION PATH").  Cost is a single ObjC
1469        // msg_send per CB submission — sub-µs on M5 Max — and a no-op when
1470        // xctrace isn't recording, so this is unconditionally safe to call on
1471        // the production decode hot path.
1472        self.apply_labels(label);
1473        if crate::kernel_profile::is_enabled() {
1474            let (start_s, end_s) = self.commit_wait_with_gpu_time()?;
1475            let ns = ((end_s - start_s).max(0.0) * 1_000_000_000.0) as u64;
1476            crate::kernel_profile::record(label, ns);
1477            Ok(())
1478        } else {
1479            self.commit_and_wait()
1480        }
1481    }
1482
1483    /// Async commit, but with profiling label.  When `MLX_PROFILE_CB=1`
1484    /// is set, redirects to a synchronous [`commit_and_wait_labeled`]
1485    /// call to capture per-cb GPU time (this defeats async pipelining
1486    /// while profiling, which is the whole point — profile-mode is slow
1487    /// but informative).  When unset, identical to [`commit`](Self::commit).
1488    pub fn commit_labeled(&mut self, label: &str) {
1489        // ADR-015 iter16 — see `commit_and_wait_labeled` for rationale.
1490        if crate::kernel_profile::is_enabled() {
1491            // Profile mode: force sync to capture GPU time.  apply_labels is
1492            // called inside commit_and_wait_labeled — do NOT call it twice
1493            // here (would double the ObjC msg_send under MLX_PROFILE_CB=1).
1494            // Errors are logged via stderr because the void return matches
1495            // commit().
1496            if let Err(e) = self.commit_and_wait_labeled(label) {
1497                eprintln!("[mlx-native] commit_labeled({}) failed: {}", label, e);
1498            }
1499        } else {
1500            // Async path: apply labels here so xctrace MST traces capture
1501            // per-CB phase attribution under default decode (no
1502            // `MLX_PROFILE_CB`).
1503            self.apply_labels(label);
1504            self.commit();
1505        }
1506    }
1507
1508    /// Apply `label` to the underlying `MTLCommandBuffer` and, if a compute
1509    /// encoder is currently active, to the `MTLComputeCommandEncoder`.
1510    ///
1511    /// Called from [`commit_labeled`] and [`commit_and_wait_labeled`] BEFORE
1512    /// the encoder is ended / the CB is committed so xctrace's
1513    /// `metal-application-encoders-list` table picks up the label on the
1514    /// row emitted at the encoder's `endEncoding` / CB submission boundary.
1515    /// Single ObjC `msg_send` per call (two if an encoder is active); sub-µs
1516    /// on M5 Max; no-op when xctrace isn't recording.
1517    ///
1518    /// Skipped (debug-only assert) if `label` is empty — empty labels would
1519    /// produce an indistinguishable trace row from the metal-rs default
1520    /// `Command Buffer 0` placeholder.
1521    #[inline]
1522    fn apply_labels(&self, label: &str) {
1523        debug_assert!(!label.is_empty(), "commit_*_labeled called with empty label");
1524        if label.is_empty() {
1525            return;
1526        }
1527        self.cmd_buf.set_label(label);
1528        if !self.active_encoder.is_null() {
1529            // SAFETY: active_encoder is non-null and points to a live encoder
1530            // owned by cmd_buf — same invariant as get_or_create_encoder /
1531            // memory_barrier.  set_label is a single property write on the
1532            // ObjC object; safe before endEncoding.
1533            unsafe { &*self.active_encoder }.set_label(label);
1534        }
1535    }
1536
1537    /// Commit + wait, returning `(gpu_start_s, gpu_end_s)` CFTimeInterval
1538    /// timestamps from `MTLCommandBuffer`'s `GPUStartTime`/`GPUEndTime`
1539    /// properties.  Both are mach-absolute CFTimeInterval seconds (double).
1540    ///
1541    /// Intended for `HF2Q_PROFILE_GPU_TS=1` per-bucket GPU wall-clock
1542    /// attribution.  Adds exactly two ObjC property reads per call on top
1543    /// of the regular `commit_and_wait` — measured well under 1 μs on
1544    /// M5 Max.
1545    ///
1546    /// # Errors
1547    ///
1548    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
1549    pub fn commit_wait_with_gpu_time(&mut self) -> Result<(f64, f64)> {
1550        self.commit_and_wait()?;
1551        // SAFETY: cmd_buf is a valid MTLCommandBuffer that has been
1552        // committed and awaited.  GPUStartTime / GPUEndTime return
1553        // CFTimeInterval (double precision seconds).  See
1554        // https://developer.apple.com/documentation/metal/mtlcommandbuffer/1639925-gpustarttime
1555        let (gpu_start, gpu_end): (f64, f64) = unsafe {
1556            let cb = &*self.cmd_buf;
1557            let s: f64 = msg_send![cb, GPUStartTime];
1558            let e: f64 = msg_send![cb, GPUEndTime];
1559            (s, e)
1560        };
1561        Ok((gpu_start, gpu_end))
1562    }
1563
1564    /// Commit the command buffer WITHOUT blocking.
1565    ///
1566    /// The GPU begins executing the encoded commands immediately.  Call
1567    /// [`wait_until_completed`](Self::wait_until_completed) later to block
1568    /// the CPU and check for errors.  This allows the CPU to continue doing
1569    /// other work (e.g. preparing the next batch) while the GPU runs.
1570    pub fn commit(&mut self) {
1571        self.end_active_encoder();
1572        // ADR-015 iter8e (Phase 3b): same flush hook as commit_and_wait —
1573        // this is the async-pipeline path that production decode uses.
1574        self.flush_residency_pending();
1575        self.cmd_buf.commit();
1576    }
1577
1578    /// Block until a previously committed command buffer completes.
1579    ///
1580    /// Must be called after [`commit`](Self::commit).  Do not call after
1581    /// [`commit_and_wait`](Self::commit_and_wait) — that method already waits.
1582    ///
1583    /// # Errors
1584    ///
1585    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
1586    pub fn wait_until_completed(&self) -> Result<()> {
1587        self.cmd_buf.wait_until_completed();
1588        match self.cmd_buf.status() {
1589            MTLCommandBufferStatus::Completed => Ok(()),
1590            MTLCommandBufferStatus::Error => Err(MlxError::CommandBufferError(
1591                "GPU command buffer completed with error status".into(),
1592            )),
1593            status => Err(MlxError::CommandBufferError(format!(
1594                "Unexpected command buffer status after wait: {:?}",
1595                status
1596            ))),
1597        }
1598    }
1599
1600    /// Borrow the underlying Metal command buffer.
1601    #[inline]
1602    pub fn metal_command_buffer(&self) -> &CommandBuffer {
1603        &self.cmd_buf
1604    }
1605}
1606
1607impl Drop for CommandEncoder {
1608    fn drop(&mut self) {
1609        // End the persistent compute encoder before the command buffer
1610        // is dropped, otherwise Metal will assert:
1611        // "Command encoder released without endEncoding"
1612        self.end_active_encoder();
1613    }
1614}