
mlx_native/encoder.rs

1//! [`CommandEncoder`] — batched GPU command submission.
2//!
3//! Wraps a Metal command buffer.  Encode one or more compute kernel dispatches,
4//! then call [`commit_and_wait`](CommandEncoder::commit_and_wait) to submit the
5//! entire batch and block until the GPU finishes.
6//!
7//! # Persistent compute encoder
8//!
9//! A single Metal `ComputeCommandEncoder` is kept alive across multiple
10//! dispatches within the same command buffer.  This avoids the overhead of
11//! creating and ending a new compute encoder per dispatch — the same pattern
12//! candle uses (`compute_per_buffer`).  On a forward pass with ~800 dispatches
13//! this saves ~800 encoder create/end cycles.
14//!
15//! # Capture mode (Phase 4e.1)
16//!
17//! When `start_capture()` is called, subsequent dispatches are recorded into a
18//! `Vec<CapturedNode>` instead of being encoded into Metal.  `memory_barrier()`
19//! records a barrier sentinel.  Call `take_capture()` to extract the recorded
20//! graph for later replay via `ComputeGraph::encode_sequential()`.
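//!
//! A minimal capture/replay sketch (assuming a `device`, a compiled `pipeline`,
//! and pooled buffers already exist; `ComputeGraph` lives in a sibling module):
//!
//! ```ignore
//! let mut enc = device.command_encoder()?;
//! enc.start_capture();
//! enc.encode_threadgroups(&pipeline, &[(0, &x), (1, &y)], tgs, tg_size);
//! enc.memory_barrier();                    // recorded as CapturedNode::Barrier
//! let nodes = enc.take_capture().unwrap(); // Vec<CapturedNode> for the graph
//! // ... fuse / reorder, then replay via ComputeGraph::encode_sequential() ...
//! ```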
21
22use std::sync::atomic::{AtomicU64, Ordering};
23
24use metal::{
25    CommandBuffer, CommandQueue, ComputeCommandEncoderRef, ComputePipelineState,
26    ComputePipelineStateRef, MTLCommandBufferStatus, MTLDispatchType, MTLSize,
27};
28#[allow(unused_imports)]
29use objc::{msg_send, sel, sel_impl};
30
31use crate::buffer::MlxBuffer;
32use crate::error::{MlxError, Result};
33use crate::residency::ResidencySet;
34
35/// A buffer or inline-bytes binding for a compute kernel argument slot.
36pub enum KernelArg<'a> {
37    /// Bind an existing Metal buffer at the given index.
38    Buffer(&'a MlxBuffer),
39    /// Bind an existing Metal buffer at the given index with a byte offset.
40    BufferWithOffset(&'a MlxBuffer, u64),
41    /// Bind inline bytes (small constant data) at the given index.
42    /// The data must be `Pod` and is copied into the command encoder.
43    Bytes(&'a [u8]),
44}
45
46/// Convert a `Pod` value to a byte slice suitable for `KernelArg::Bytes`.
47///
48/// # Safety
49///
50/// The caller must ensure `T` has the same layout as the corresponding
51/// MSL struct in the shader (matching field order, sizes, and alignment).
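///
/// A minimal sketch, assuming a `#[repr(C)]` params struct that mirrors the
/// kernel's MSL layout (the struct name and fields here are illustrative):
///
/// ```ignore
/// #[repr(C)]
/// #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
/// struct RmsNormParams { n_cols: u32, eps: f32 }
///
/// let params = RmsNormParams { n_cols: 4096, eps: 1e-6 };
/// let arg = KernelArg::Bytes(as_bytes(&params));
/// ```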
52pub fn as_bytes<T: bytemuck::Pod>(val: &T) -> &[u8] {
53    bytemuck::bytes_of(val)
54}
55
56// ---------------------------------------------------------------------------
57// Capture-mode types (Phase 4e.1 — Graph IR)
58// ---------------------------------------------------------------------------
59
60/// A recorded kernel argument binding.
61///
62/// When the encoder is in capture mode, each `set_buffer` / `set_bytes` call
63/// is stored as a `RecordedBinding` instead of being applied to Metal.
64#[derive(Clone)]
65pub enum RecordedBinding {
66    /// A Metal buffer at the given offset.
67    Buffer {
68        metal_buffer: metal::Buffer,
69        offset: u64,
70    },
71    /// Inline bytes (small constant data, copied).
72    Bytes(Vec<u8>),
73}
74
75/// How to dispatch the recorded kernel.
76#[derive(Clone, Copy, Debug)]
77pub enum DispatchKind {
78    /// `dispatch_threads(grid_size, threadgroup_size)` — Metal picks threadgroup count.
79    Threads,
80    /// `dispatch_thread_groups(threadgroups, threadgroup_size)` — caller specifies threadgroup count.
81    ThreadGroups,
82}
83
84/// Operation kind tag for captured nodes, used by the fusion pass (4e.2).
85///
86/// When the encoder is in capture mode, each dispatch can be tagged with an
87/// `OpKind` so the fusion pass can identify fuseable sequences without
88/// inspecting pipeline names.
89#[derive(Clone, Copy, Debug, PartialEq, Eq)]
90pub enum CapturedOpKind {
91    /// RMS normalization (with learned scale).
92    RmsNorm,
93    /// Elementwise multiply.
94    ElemMul,
95    /// Elementwise add.
96    ElemAdd,
97    /// Scaled dot-product attention (NOT reorderable — breaks lookahead).
98    Sdpa,
99    /// Softmax (NOT reorderable — breaks lookahead).
100    Softmax,
101    /// Any other operation — treated as reorderable by the graph optimizer.
102    Other,
103}
104
105impl CapturedOpKind {
106    /// Whether this captured op kind is safe to reorder past in the graph
107    /// optimizer (Phase 4e.3).
108    ///
109    /// Mirrors the `h_safe` whitelist from llama.cpp's
110    /// `ggml_metal_graph_optimize_reorder`.  Non-safe ops break the 64-node
111    /// lookahead — the reorder pass cannot look past them.
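    ///
    /// For example:
    ///
    /// ```ignore
    /// assert!(CapturedOpKind::ElemAdd.is_reorderable());
    /// assert!(!CapturedOpKind::Sdpa.is_reorderable()); // blocks the lookahead
    /// ```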
112    pub fn is_reorderable(&self) -> bool {
113        match self {
114            Self::Sdpa | Self::Softmax => false,
115            Self::RmsNorm | Self::ElemMul | Self::ElemAdd | Self::Other => true,
116        }
117    }
118}
119
120/// A memory range annotation: (start_address, end_address).
121///
122/// Represents a contiguous GPU buffer region for conflict detection in the
123/// reorder pass (Phase 4e.3).  Addresses are CPU-visible `contents_ptr()`
124/// values; on Apple Silicon unified memory the CPU and GPU share the same backing allocation, so overlapping CPU ranges imply overlapping GPU accesses.
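///
/// A sketch of how a range is typically built from a buffer (assuming a
/// byte-length accessor on `MlxBuffer`; the exact helper name may differ):
///
/// ```ignore
/// let start = buf.contents_ptr() as usize;
/// let range: MemRange = (start, start + buf.len_bytes()); // len_bytes() is hypothetical
/// ```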
125pub type MemRange = (usize, usize);
126
127/// A single captured compute dispatch or barrier sentinel.
128///
129/// Created when the encoder is in capture mode.  Replayed later by
130/// `ComputeGraph::encode_sequential()`.
131#[derive(Clone)]
132pub enum CapturedNode {
133    /// A compute dispatch to replay.
134    Dispatch {
135        /// Pipeline state object to bind.
136        pipeline: ComputePipelineState,
137        /// Kernel argument bindings: (slot_index, binding).
138        bindings: Vec<(u64, RecordedBinding)>,
139        /// Grid or threadgroup count (interpretation depends on `dispatch_kind`).
140        threads_per_grid: MTLSize,
141        /// Threads per threadgroup.
142        threads_per_threadgroup: MTLSize,
143        /// Optional threadgroup memory allocations: (index, byte_length).
144        threadgroup_memory: Vec<(u64, u64)>,
145        /// Whether this is a dispatch_threads or dispatch_thread_groups call.
146        dispatch_kind: DispatchKind,
147        /// Operation kind tag for the fusion pass (4e.2).
148        /// Defaults to `Other` if not explicitly set via `set_op_kind()`.
149        op_kind: CapturedOpKind,
150        /// Read buffer ranges for reorder conflict detection (4e.3).
151        /// Populated from `barrier_between` calls in capture mode.
152        reads: Vec<MemRange>,
153        /// Write buffer ranges for reorder conflict detection (4e.3).
154        /// Populated from `barrier_between` calls in capture mode.
155        writes: Vec<MemRange>,
156    },
157    /// A memory barrier sentinel — forces a barrier at replay time.
158    Barrier,
159}
160
161/// Apply a slice of `KernelArg` bindings to a compute encoder.
162///
163/// `KernelArg::Buffer(buf)` propagates the `MlxBuffer::byte_offset()` so
164/// `slice_view`-derived sub-buffers are honored automatically — the
165/// kernel sees memory starting at the slice's offset. This matches the
166/// documented contract of `slice_view` and the offset-handling in the
167/// other binding paths in this file (`encode`, `encode_threadgroups`,
168/// `encode_threadgroups_with_shared`, replay). Without it, every
169/// `slice_view`-derived buffer bound via `KernelArg::Buffer` silently
170/// exposes the entire underlying allocation — surfaced by hf2q's
171/// nomic-bert iter-79 cosine parity bisection (cosine 0.098 → 0.999962
172/// after fix).
173///
174/// `KernelArg::BufferWithOffset(buf, offset)` continues to use the
175/// explicit `offset` argument verbatim (callers asking for an explicit
176/// offset get exactly that, even on sliced buffers). The two API
177/// surfaces are intentional: implicit (sliced views auto-propagate) vs.
178/// explicit (caller-controlled).
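///
/// A sketch of the two surfaces (assuming `sliced` came from `slice_view` and
/// therefore carries a non-zero `byte_offset()`):
///
/// ```ignore
/// let bindings = [
///     (0, KernelArg::Buffer(&sliced)),                // implicit: slice offset applied
///     (1, KernelArg::BufferWithOffset(&sliced, 256)), // explicit: exactly 256 bytes
/// ];
/// enc.encode_with_args(&pipeline, &bindings, grid, tg);
/// ```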
179#[inline]
180fn apply_bindings(encoder: &ComputeCommandEncoderRef, bindings: &[(u64, KernelArg<'_>)]) {
181    for &(index, ref arg) in bindings {
182        match arg {
183            KernelArg::Buffer(buf) => {
184                encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
185            }
186            KernelArg::BufferWithOffset(buf, offset) => {
187                encoder.set_buffer(index, Some(buf.metal_buffer()), *offset);
188            }
189            KernelArg::Bytes(bytes) => {
190                encoder.set_bytes(index, bytes.len() as u64, bytes.as_ptr() as *const _);
191            }
192        }
193    }
194}
195
196/// Number of times `commit_and_wait()` has been called (CPU sync points).
197static SYNC_COUNT: AtomicU64 = AtomicU64::new(0);
198
199/// Number of times an encode method has been called (GPU dispatches).
200static DISPATCH_COUNT: AtomicU64 = AtomicU64::new(0);
201
202/// Number of `MTLCommandBuffer` instances created via `CommandEncoder::new`.
203/// Increments once per `device.command_encoder()` call.  Used by hf2q's
204/// `HF2Q_DECODE_PROFILE` instrumentation to measure command-buffer
205/// overhead per decode token (ADR-012 §Optimize / Task #15 follow-up).
206static CMD_BUF_COUNT: AtomicU64 = AtomicU64::new(0);
207
208/// Number of `memory_barrier()` calls that reached the
209/// `objc::msg_send![encoder, memoryBarrierWithScope:]` site.  Capture-mode
210/// no-ops and pre-encoder no-ops are excluded so the count reflects
211/// actual MTL barriers issued.
212///
213/// Always tracked — the increment is one atomic op, ~5 ns.  ADR-015 H4
214/// (Wave 2b hard gate #2) requires per-barrier counter resolution to
215/// confirm-or-falsify the barrier-coalescing lever; xctrace TimeProfiler
216/// at 1 ms sampling cannot resolve `memory_barrier` even though it fires
217/// ~440×/token (`docs/ADR-015-mlx-native-single-cb-decode.md` §"P3a' live
218/// profile pass" hypothesis register row H4).
219static BARRIER_COUNT: AtomicU64 = AtomicU64::new(0);
220
221/// Total nanoseconds spent inside the `objc::msg_send!` barrier site,
222/// summed across all calls.  ONLY updated when the env var
223/// `MLX_PROFILE_BARRIERS=1` is set on the process (cached on first
224/// `memory_barrier` call).  When disabled the timing path is a single
225/// branch + the unconditional barrier dispatch — same hot-path cost as
226/// before this counter was added.
227///
228/// Why env-gated: timing adds 2 × `Instant::now()` (~50–100 ns each via
229/// `mach_absolute_time`) per barrier.  At ~440 barriers/token that is
230/// ~22–44 µs/token of measurement overhead — comparable to what we are
231/// trying to measure.  Production must keep this off; profiling runs
232/// opt-in.
233static BARRIER_NS: AtomicU64 = AtomicU64::new(0);
234
235/// Reset all counters to zero.
236pub fn reset_counters() {
237    SYNC_COUNT.store(0, Ordering::Relaxed);
238    DISPATCH_COUNT.store(0, Ordering::Relaxed);
239    CMD_BUF_COUNT.store(0, Ordering::Relaxed);
240    BARRIER_COUNT.store(0, Ordering::Relaxed);
241    BARRIER_NS.store(0, Ordering::Relaxed);
242}
243
244/// Read the current value of `SYNC_COUNT`.
245///
246/// Each call to `commit_and_wait()` increments this counter.
247pub fn sync_count() -> u64 {
248    SYNC_COUNT.load(Ordering::Relaxed)
249}
250
251/// Read the current value of `DISPATCH_COUNT`.
252///
253/// Each call to `encode()`, `encode_threadgroups()`, or
254/// `encode_threadgroups_with_shared()` increments this counter.
255pub fn dispatch_count() -> u64 {
256    DISPATCH_COUNT.load(Ordering::Relaxed)
257}
258
259/// Read the current value of `CMD_BUF_COUNT`.
260///
261/// Each `CommandEncoder::new` (i.e. each `MlxDevice::command_encoder()`)
262/// increments this counter.  Useful for diagnosing per-dispatch Metal
263/// command-buffer overhead in inner loops.
264pub fn cmd_buf_count() -> u64 {
265    CMD_BUF_COUNT.load(Ordering::Relaxed)
266}
267
268/// Read the current value of `BARRIER_COUNT`.
269///
270/// Each `memory_barrier()` call that reaches the underlying
271/// `objc::msg_send![encoder, memoryBarrierWithScope:]` site increments this
272/// counter.  Capture-mode no-ops and pre-encoder no-ops are excluded.
273/// ADR-015 H4 hypothesis: ~440 barriers/token on the qwen35 decode hot
274/// path (verify against this counter).
275pub fn barrier_count() -> u64 {
276    BARRIER_COUNT.load(Ordering::Relaxed)
277}
278
279/// Read the total nanoseconds spent in the `memoryBarrierWithScope:`
280/// `objc::msg_send!` site.  Only non-zero when `MLX_PROFILE_BARRIERS=1`
281/// was in the environment at the time of the first `memory_barrier()`
282/// call (the env check is cached on first use).
283///
284/// Combined with [`barrier_count`] this gives µs/barrier =
285/// `barrier_total_ns() / 1000 / barrier_count()`.
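///
/// For example (only meaningful when `MLX_PROFILE_BARRIERS=1` was set):
///
/// ```ignore
/// let avg_ns = barrier_total_ns() as f64 / barrier_count().max(1) as f64;
/// eprintln!("avg barrier cost: {avg_ns:.1} ns");
/// ```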
286pub fn barrier_total_ns() -> u64 {
287    BARRIER_NS.load(Ordering::Relaxed)
288}
289
290/// Whether barrier timing is enabled (env-gated, cached on first check).
291///
292/// Reading the env var via `std::env::var` is itself non-trivial; using
293/// `OnceLock` caches the decision so the per-barrier branch is a single
294/// atomic-load + compare.
295fn barrier_profile_enabled() -> bool {
296    use std::sync::OnceLock;
297    static FLAG: OnceLock<bool> = OnceLock::new();
298    *FLAG.get_or_init(|| {
299        std::env::var("MLX_PROFILE_BARRIERS")
300            .map(|v| v == "1")
301            .unwrap_or(false)
302    })
303}
304
305/// Whether `MLX_UNRETAINED_REFS=1` is set in the process environment.
306///
307/// ADR-015 iter13 — when true, `CommandEncoder::new_with_residency` opens
308/// each `MTLCommandBuffer` via
309/// [`CommandQueueRef::new_command_buffer_with_unretained_references`]
310/// instead of the default `commandBuffer`.  llama.cpp's per-token decode
311/// CBs use this same call (`/opt/llama.cpp/ggml/src/ggml-metal/`
312/// `ggml-metal-context.m:512` `[queue commandBufferWithUnretainedReferences]`)
313/// and gain ~3-5% wall on M-series GPUs by skipping per-buffer-binding ARC
314/// retains on submit.
315///
316/// **Caller-side prerequisite.**  Every Metal buffer bound to a dispatch
317/// must outlive the CB — see the docstring on
318/// [`CommandEncoder::new_with_residency`] for the full caller contract.
319/// In hf2q, the per-decode-token `MlxBufferPool` (`buffer_pool.rs`)
320/// already keeps ARC clones alive in its `in_use` list across the entire
321/// decode token; routing transient scratches through that pool is the
322/// canonical way to satisfy the contract.
323///
324/// Cached on first read via `OnceLock` to keep the per-CB-construction
325/// branch single-atomic-load fast.  Default OFF so any production decode
326/// run that does NOT explicitly set the var preserves retained-refs
327/// behavior verbatim.
328fn unretained_refs_enabled() -> bool {
329    use std::sync::OnceLock;
330    static FLAG: OnceLock<bool> = OnceLock::new();
331    *FLAG.get_or_init(|| {
332        std::env::var("MLX_UNRETAINED_REFS")
333            .map(|v| v == "1")
334            .unwrap_or(false)
335    })
336}
337
338/// Issue the underlying Metal `memoryBarrierWithScope:` ObjC msg_send.
339///
340/// Held in its own `#[inline(never)]` function so xctrace / Instruments
341/// has a stable Rust frame to attribute barrier time against, separate
342/// from the surrounding encoder accounting.  Per ADR-015 §P3a' Codex
343/// review Q2: TimeProfiler at 1 ms sampling cannot see this site when
344/// inlined; an explicit non-inline frame plus the [`BARRIER_NS`] counter
345/// closes the H4 hard gate.
346#[inline(never)]
347fn issue_metal_buffer_barrier(encoder: &ComputeCommandEncoderRef) {
348    // MTLBarrierScopeBuffers = 1 << 0 = 1.
349    const MTL_BARRIER_SCOPE_BUFFERS: u64 = 1;
350    unsafe {
351        let _: () =
352            objc::msg_send![encoder, memoryBarrierWithScope: MTL_BARRIER_SCOPE_BUFFERS];
353    }
354}
355
356/// A batched compute command encoder.
357///
358/// Keeps a single Metal `ComputeCommandEncoder` alive across multiple
359/// dispatches.  The encoder is created on the first dispatch and ended
360/// only when the command buffer is committed.  This mirrors candle's
361/// `compute_per_buffer` pattern and avoids per-dispatch encoder overhead.
362///
363/// # Typical usage
364///
365/// ```ignore
366/// let mut enc = device.command_encoder()?;
367/// // Multiple dispatches share the same compute encoder:
368/// enc.encode_threadgroups(pipeline1, &buffers1, tg1, tg_size1);
369/// enc.encode_threadgroups(pipeline2, &buffers2, tg2, tg_size2);
370/// enc.commit_and_wait()?;
371/// ```
372pub struct CommandEncoder {
373    cmd_buf: CommandBuffer,
374    // SAFETY marker: see unsafe Send impl below.
375    /// Raw pointer to the persistent compute encoder.
376    /// Non-null when a compute pass is active.
377    /// The encoder borrows from `cmd_buf` but we cannot express this
378    /// lifetime in safe Rust, so we use a raw pointer.
379    /// SAFETY: the pointer is valid as long as `cmd_buf` is alive and
380    /// `end_encoding()` has not been called on it.
381    active_encoder: *const ComputeCommandEncoderRef,
382    /// When `Some`, dispatches are recorded here instead of being encoded
383    /// into Metal.  Set via `start_capture()`, extracted via `take_capture()`.
384    capture: Option<Vec<CapturedNode>>,
385    /// Op kind tag for the NEXT captured dispatch.  Set via `set_op_kind()`,
386    /// consumed (reset to `Other`) when a dispatch is captured.
387    pending_op_kind: CapturedOpKind,
388    /// Pending read buffer ranges for the NEXT captured dispatch.
389    /// Set via `set_pending_buffer_ranges()`, consumed when the next dispatch
390    /// is captured.  Used by the reorder pass (Phase 4e.3).
391    pending_reads: Vec<MemRange>,
392    /// Pending write buffer ranges for the NEXT captured dispatch.
393    pending_writes: Vec<MemRange>,
394    /// ADR-015 iter8e (Phase 3b): residency set whose pending add/remove
395    /// staging is flushed at every `commit*` boundary.
396    ///
397    /// Cloned from the device at `device.command_encoder()` time. `None`
398    /// when residency sets are disabled (HF2Q_NO_RESIDENCY=1, macOS<15,
399    /// or test-only `CommandEncoder::new` from a residency-less queue).
400    residency_set: Option<ResidencySet>,
401}
402
403/// SAFETY: CommandEncoder is safe to Send across threads provided that:
404/// 1. Only one thread accesses the encoder at a time (exclusive ownership).
405/// 2. The encoder is not used concurrently from multiple threads.
406///
407/// Metal command buffers and compute encoders are thread-safe for exclusive
408/// access (Apple documentation: "You can create command buffers, encode
409/// commands, and submit them from any thread"). The raw pointer
410/// `active_encoder` borrows from `cmd_buf` and is valid as long as
411/// `cmd_buf` is alive — this invariant holds across thread boundaries
412/// because both fields move together.
413///
414/// This matches llama.cpp's pattern of encoding command buffers on GCD
415/// worker threads via `dispatch_apply`, and is used for the dual-buffer
416/// pipeline where buf1 is encoded on a worker thread while buf0 executes.
417unsafe impl Send for CommandEncoder {}
418
419impl CommandEncoder {
420    /// Create a new command encoder from the given command queue.
421    ///
422    /// This immediately creates a Metal command buffer.
423    ///
424    /// # Why retained references
425    ///
426    /// We use the regular `commandBuffer` (Metal retains every bound
427    /// resource for the lifetime of the buffer) rather than
428    /// `commandBufferWithUnretainedReferences`.  llama.cpp uses unretained
429    /// refs for an additional perf bump (~3-5% on M-series GPUs), but the
430    /// hf2q dispatch pattern allocates many transient scratch buffers
431    /// inside helper functions (`apply_proj` → `weight_bf16_owned`,
432    /// `apply_pre_norm` → `params`, etc.) that go out of scope at the
433    /// helper's return.  With unretained refs the metal::Buffer's ARC
434    /// drops to zero, freeing the underlying GPU memory before the
435    /// dispatch executes.  Verified 2026-04-26: switching to unretained
436    /// hits "Command buffer error: GPU command buffer completed with
437    /// error status" on the first MoE FFN dispatch.
438    ///
439    /// To enable unretained refs in the future, every helper that
440    /// allocates and dispatches must thread its scratch buffers up to a
441    /// caller scope that outlives the eventual commit, OR all such
442    /// scratch must come from the per-decode-token pool (which already
443    /// ARC-retains in its in_use list).  Today the lm_head + router-
444    /// download paths are still unpooled.
445    #[allow(dead_code)]
446    pub(crate) fn new(queue: &CommandQueue) -> Result<Self> {
447        Self::new_with_residency(queue, None)
448    }
449
450    /// Create a new command encoder, optionally bound to a residency set so
451    /// `commit*` boundaries can flush deferred add/remove staging.
452    ///
453    /// ADR-015 iter8e (Phase 3b): the encoder's `commit_and_wait`,
454    /// `commit_and_wait_labeled`, `commit`, `commit_labeled`,
455    /// `commit_wait_with_gpu_time` all call
456    /// [`ResidencySet::flush_pending`](ResidencySet::flush_pending) before
457    /// submitting the Metal command buffer. This converts the
458    /// per-allocation `[set commit]` storm
459    /// (~880 commits/decode-token in iter8d/8e claude+codex variants) into
460    /// at most one commit per CB submission — mirrors llama.cpp's
461    /// `ggml-metal-device.m:1378-1382` pattern (batch addAllocation in
462    /// loop, commit ONCE).
463    ///
464    /// ADR-015 iter13: when the `MLX_UNRETAINED_REFS=1` env var is set at
465    /// process start, this constructor uses
466    /// [`CommandQueueRef::new_command_buffer_with_unretained_references`]
467    /// instead of `new_command_buffer`.  llama.cpp's per-token decode CBs
468    /// use `commandBufferWithUnretainedReferences` (see
469    /// `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m:512`) which
470    /// skips Metal's per-buffer-binding ARC-retain on submit and saves
471    /// ~3-5% on M-series GPUs (per the docstring above).
472    ///
473    /// **Caller contract under unretained refs.**  Every Metal buffer bound
474    /// to a dispatch in this CB MUST outlive the CB's GPU completion.  In
475    /// the hf2q decode path, that means every transient scratch must be
476    /// either (a) backed by the per-decode-token arena pool
477    /// (`MlxBufferPool` keeps an ARC clone in `in_use` until the next
478    /// `reset` — see `buffer_pool.rs:60`) or (b) hoisted to a caller scope
479    /// that lives across the terminal `commit_and_wait_labeled`.  Helpers
480    /// in `apply_proj` / `apply_pre_norm` / lm_head cast / router-download
481    /// that allocated transients via `device.alloc_buffer` and dropped
482    /// them at function return MUST be lifted to `pooled_alloc_buffer`
483    /// before `MLX_UNRETAINED_REFS=1` is enabled, or the first MoE FFN
484    /// dispatch will crash with "Command buffer error: GPU command buffer
485    /// completed with error status" (verified 2026-04-26).
486    ///
487    /// The default (`MLX_UNRETAINED_REFS` unset) preserves retained-refs
488    /// behavior verbatim — this is the sourdough-safe path.
489    pub(crate) fn new_with_residency(
490        queue: &CommandQueue,
491        residency_set: Option<ResidencySet>,
492    ) -> Result<Self> {
493        let cmd_buf = if unretained_refs_enabled() {
494            queue.new_command_buffer_with_unretained_references().to_owned()
495        } else {
496            queue.new_command_buffer().to_owned()
497        };
498        CMD_BUF_COUNT.fetch_add(1, Ordering::Relaxed);
499        Ok(Self {
500            cmd_buf,
501            active_encoder: std::ptr::null(),
502            capture: None,
503            pending_op_kind: CapturedOpKind::Other,
504            pending_reads: Vec::new(),
505            pending_writes: Vec::new(),
506            residency_set,
507        })
508    }
509
510    /// Enable capture mode.
511    ///
512    /// All subsequent dispatch and barrier calls will be recorded into a
513    /// `Vec<CapturedNode>` instead of being encoded into Metal.
514    /// Call `take_capture()` to extract the recorded nodes.
515    pub fn start_capture(&mut self) {
516        self.capture = Some(Vec::with_capacity(128));
517    }
518
519    /// Whether the encoder is currently in capture mode.
520    pub fn is_capturing(&self) -> bool {
521        self.capture.is_some()
522    }
523
524    /// Extract the captured nodes, ending capture mode.
525    ///
526    /// Returns `None` if capture mode was not active.
527    pub fn take_capture(&mut self) -> Option<Vec<CapturedNode>> {
528        self.capture.take()
529    }
530
531    /// Tag the NEXT captured dispatch with the given operation kind.
532    ///
533    /// The tag is consumed (reset to `Other`) after the next dispatch is
534    /// captured.  Only meaningful in capture mode — has no effect on
535    /// direct-dispatch encoding.
536    ///
537    /// Used by op dispatch functions to annotate captures for the fusion
538    /// pass (Phase 4e.2).
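    ///
    /// A sketch of the capture-mode tagging pattern (assuming `enc` is already
    /// capturing and a pipeline plus buffers exist):
    ///
    /// ```ignore
    /// enc.set_op_kind(CapturedOpKind::RmsNorm);
    /// enc.encode_threadgroups(&rms_norm, &[(0, &x), (1, &w), (2, &out)], tgs, tg);
    /// // The captured node carries op_kind == RmsNorm; the tag resets to Other.
    /// ```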
539    pub fn set_op_kind(&mut self, kind: CapturedOpKind) {
540        self.pending_op_kind = kind;
541    }
542
543    /// Consume and return the pending op kind, resetting it to `Other`.
544    fn take_pending_op_kind(&mut self) -> CapturedOpKind {
545        let kind = self.pending_op_kind;
546        self.pending_op_kind = CapturedOpKind::Other;
547        kind
548    }
549
550    /// Stash buffer range annotations for the NEXT captured dispatch.
551    ///
552    /// Called by `GraphSession::barrier_between()` in capture mode to record
553    /// which buffers the next dispatch reads from and writes to.  The ranges
554    /// are consumed by the next `encode_*` call and attached to the captured
555    /// `CapturedNode::Dispatch`.
556    ///
557    /// Only meaningful in capture mode — has no effect on direct-dispatch.
558    pub fn set_pending_buffer_ranges(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
559        self.pending_reads = reads;
560        self.pending_writes = writes;
561    }
562
563    /// Patch the last captured dispatch node's empty reads/writes with the
564    /// given ranges. No-op if not capturing, or if the last node isn't a
565    /// Dispatch, or if its ranges are already populated.
566    ///
567    /// Used by `GraphSession::track_dispatch` in recording mode to annotate
568    /// dispatches that were called without a preceding `barrier_between`.
569    pub fn annotate_last_dispatch_if_missing(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
570        if let Some(ref mut nodes) = self.capture {
571            if let Some(CapturedNode::Dispatch { reads: r, writes: w, .. }) = nodes.last_mut() {
572                if r.is_empty() && !reads.is_empty() {
573                    *r = reads;
574                }
575                if w.is_empty() && !writes.is_empty() {
576                    *w = writes;
577                }
578            }
579        }
580    }
581
582    /// Consume and return the pending buffer range annotations.
583    fn take_pending_buffer_ranges(&mut self) -> (Vec<MemRange>, Vec<MemRange>) {
584        let reads = std::mem::take(&mut self.pending_reads);
585        let writes = std::mem::take(&mut self.pending_writes);
586        (reads, writes)
587    }
588
589    /// Record buffer bindings into `RecordedBinding` form.
590    fn record_buffer_bindings(buffers: &[(u64, &MlxBuffer)]) -> Vec<(u64, RecordedBinding)> {
591        buffers
592            .iter()
593            .map(|&(index, buf)| {
594                (
595                    index,
596                    RecordedBinding::Buffer {
597                        metal_buffer: buf.metal_buffer().clone(),
598                        offset: buf.byte_offset(),
599                    },
600                )
601            })
602            .collect()
603    }
604
605    /// Record `KernelArg` bindings into `RecordedBinding` form.
606    ///
607    /// `KernelArg::Buffer(buf)` records `buf.byte_offset()` so capture →
608    /// replay round-trips of `slice_view`-derived buffers preserve their
609    /// offsets, matching the behavior of `record_buffer_bindings` above.
610    fn record_arg_bindings(bindings: &[(u64, KernelArg<'_>)]) -> Vec<(u64, RecordedBinding)> {
611        bindings
612            .iter()
613            .map(|(index, arg)| {
614                let recorded = match arg {
615                    KernelArg::Buffer(buf) => RecordedBinding::Buffer {
616                        metal_buffer: buf.metal_buffer().clone(),
617                        offset: buf.byte_offset(),
618                    },
619                    KernelArg::BufferWithOffset(buf, offset) => RecordedBinding::Buffer {
620                        metal_buffer: buf.metal_buffer().clone(),
621                        offset: *offset,
622                    },
623                    KernelArg::Bytes(bytes) => RecordedBinding::Bytes(bytes.to_vec()),
624                };
625                (*index, recorded)
626            })
627            .collect()
628    }
629
630    /// Get or create the persistent compute encoder.
631    ///
632    /// On the first call, creates a new compute encoder from the command
633    /// buffer.  On subsequent calls, returns the existing one.
634    ///
635    /// SAFETY: The returned reference borrows from `self.cmd_buf` which is
636    /// alive for the lifetime of this `CommandEncoder`.  The raw pointer is
637    /// valid until `end_active_encoder()` is called.
638    #[inline]
639    fn get_or_create_encoder(&mut self) -> &ComputeCommandEncoderRef {
640        if self.active_encoder.is_null() {
641            // Use MTLDispatchTypeConcurrent to allow independent dispatches
642            // to overlap on the GPU.  Memory barriers are inserted between
643            // dependent dispatches via `memory_barrier()`.
644            let encoder = self
645                .cmd_buf
646                .compute_command_encoder_with_dispatch_type(MTLDispatchType::Concurrent);
647            self.active_encoder = encoder as *const ComputeCommandEncoderRef;
648        }
649        // SAFETY: active_encoder is non-null and points to a valid encoder
650        // owned by cmd_buf.
651        unsafe { &*self.active_encoder }
652    }
653
654    /// End the active compute encoder if one exists.
655    #[inline]
656    fn end_active_encoder(&mut self) {
657        if !self.active_encoder.is_null() {
658            // SAFETY: the pointer was obtained from cmd_buf.new_compute_command_encoder()
659            // and has not been ended yet.
660            unsafe { &*self.active_encoder }.end_encoding();
661            self.active_encoder = std::ptr::null();
662        }
663    }
664
665    /// Insert a memory barrier with scope `MTLBarrierScopeBuffers`.
666    ///
667    /// When the encoder uses `MTLDispatchTypeConcurrent`, all dispatches can
668    /// execute concurrently unless separated by a barrier.  Call this between
669    /// dispatches where the later dispatch reads a buffer written by an
670    /// earlier one.
671    ///
672    /// This is the same pattern llama.cpp uses:
673    /// `[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]`
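    ///
    /// A sketch of a dependent pair (assuming the second dispatch reads the
    /// buffer the first one writes):
    ///
    /// ```ignore
    /// enc.encode_threadgroups(&matmul, &[(0, &a), (1, &b), (2, &tmp)], tgs1, tg);
    /// enc.memory_barrier(); // tmp must be fully written before it is read
    /// enc.encode_threadgroups(&rms_norm, &[(0, &tmp), (1, &out)], tgs2, tg);
    /// ```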
674    #[allow(unexpected_cfgs)]
675    pub fn memory_barrier(&mut self) {
676        if let Some(ref mut nodes) = self.capture {
677            nodes.push(CapturedNode::Barrier);
678            return;
679        }
680        if self.active_encoder.is_null() {
681            return;
682        }
683        BARRIER_COUNT.fetch_add(1, Ordering::Relaxed);
684        // SAFETY: active_encoder is non-null and valid.
685        let encoder = unsafe { &*self.active_encoder };
686        if barrier_profile_enabled() {
687            // mach_absolute_time path — only on when MLX_PROFILE_BARRIERS=1.
688            let start = std::time::Instant::now();
689            issue_metal_buffer_barrier(encoder);
690            let elapsed_ns = start.elapsed().as_nanos() as u64;
691            BARRIER_NS.fetch_add(elapsed_ns, Ordering::Relaxed);
692        } else {
693            issue_metal_buffer_barrier(encoder);
694        }
695    }
696
697    /// Set the compute pipeline state for subsequent dispatches.
698    ///
699    /// This begins a new compute pass if one is not already active.
700    pub fn set_pipeline(&mut self, pipeline: &ComputePipelineStateRef) {
701        let encoder = self.get_or_create_encoder();
702        encoder.set_compute_pipeline_state(pipeline);
703    }
704
705    /// Bind a buffer to a compute kernel argument slot (`[[buffer(N)]]` in the MSL shader).
706    ///
707    /// Currently a no-op: both arguments are discarded.  Prefer the `encode*` methods, which bind buffers against the persistent compute encoder.
708    pub fn set_buffer(&self, index: u64, buffer: &MlxBuffer) {
709        let _ = (index, buffer);
710    }
711
712    /// Dispatch threads on the GPU.  Currently a no-op: both arguments are discarded.  Prefer `encode`, which also binds the pipeline and buffers before dispatching.
713    pub fn dispatch_threads(&self, grid_size: MTLSize, threadgroup_size: MTLSize) {
714        let _ = (grid_size, threadgroup_size);
715    }
716
717    /// Encode a complete compute pass: set pipeline, bind buffers, dispatch.
718    ///
719    /// Reuses the persistent compute encoder — no per-dispatch encoder
720    /// creation overhead.
721    ///
722    /// # Arguments
723    ///
724    /// * `pipeline`         — The compiled compute pipeline to execute.
725    /// * `buffers`          — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
726    /// * `grid_size`        — Total number of threads to launch.
727    /// * `threadgroup_size` — Threads per threadgroup.
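    ///
    /// A minimal sketch, assuming a compiled elementwise `pipeline` over `n`
    /// elements:
    ///
    /// ```ignore
    /// let grid = MTLSize::new(n as u64, 1, 1);
    /// let tg = MTLSize::new(256, 1, 1);
    /// enc.encode(&pipeline, &[(0, &input), (1, &output)], grid, tg);
    /// ```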
728    pub fn encode(
729        &mut self,
730        pipeline: &ComputePipelineStateRef,
731        buffers: &[(u64, &MlxBuffer)],
732        grid_size: MTLSize,
733        threadgroup_size: MTLSize,
734    ) {
735        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
736        let op_kind = self.take_pending_op_kind();
737        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
738        if let Some(ref mut nodes) = self.capture {
739            nodes.push(CapturedNode::Dispatch {
740                pipeline: pipeline.to_owned(),
741                bindings: Self::record_buffer_bindings(buffers),
742                threads_per_grid: grid_size,
743                threads_per_threadgroup: threadgroup_size,
744                threadgroup_memory: Vec::new(),
745                dispatch_kind: DispatchKind::Threads,
746                op_kind,
747                reads: pending_reads,
748                writes: pending_writes,
749            });
750            return;
751        }
752        let encoder = self.get_or_create_encoder();
753        encoder.set_compute_pipeline_state(pipeline);
754        for &(index, buf) in buffers {
755            encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
756        }
757        encoder.dispatch_threads(grid_size, threadgroup_size);
758    }
759
760    /// Encode a compute pass using threadgroups instead of raw thread counts.
761    ///
762    /// Reuses the persistent compute encoder — no per-dispatch encoder
763    /// creation overhead.
764    pub fn encode_threadgroups(
765        &mut self,
766        pipeline: &ComputePipelineStateRef,
767        buffers: &[(u64, &MlxBuffer)],
768        threadgroups: MTLSize,
769        threadgroup_size: MTLSize,
770    ) {
771        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
772        let op_kind = self.take_pending_op_kind();
773        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
774        if let Some(ref mut nodes) = self.capture {
775            nodes.push(CapturedNode::Dispatch {
776                pipeline: pipeline.to_owned(),
777                bindings: Self::record_buffer_bindings(buffers),
778                threads_per_grid: threadgroups,
779                threads_per_threadgroup: threadgroup_size,
780                threadgroup_memory: Vec::new(),
781                dispatch_kind: DispatchKind::ThreadGroups,
782                op_kind,
783                reads: pending_reads,
784                writes: pending_writes,
785            });
786            return;
787        }
788        let encoder = self.get_or_create_encoder();
789        encoder.set_compute_pipeline_state(pipeline);
790        for &(index, buf) in buffers {
791            encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
792        }
793        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
794    }
795
796    /// Encode a compute pass using threadgroups with shared threadgroup memory.
797    ///
798    /// Like [`encode_threadgroups`](Self::encode_threadgroups), but additionally
799    /// allocates threadgroup memory at the specified indices.  This is required
800    /// for kernels that use `threadgroup` memory (e.g. reductions in rms_norm
801    /// and softmax).
802    ///
803    /// # Arguments
804    ///
805    /// * `pipeline`         — The compiled compute pipeline to execute.
806    /// * `buffers`          — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
807    /// * `threadgroup_mem`  — Slice of `(index, byte_length)` pairs for threadgroup memory.
808    /// * `threadgroups`     — Number of threadgroups to dispatch.
809    /// * `threadgroup_size` — Threads per threadgroup.
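    ///
    /// A sketch for a reduction kernel that needs one `float` of threadgroup
    /// memory per thread (256 threads → 1 KiB at index 0; names are illustrative):
    ///
    /// ```ignore
    /// enc.encode_threadgroups_with_shared(
    ///     &rms_norm,
    ///     &[(0, &x), (1, &out)],
    ///     &[(0, 256 * 4)],            // threadgroup float[256]
    ///     MTLSize::new(rows, 1, 1),
    ///     MTLSize::new(256, 1, 1),
    /// );
    /// ```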
810    pub fn encode_threadgroups_with_shared(
811        &mut self,
812        pipeline: &ComputePipelineStateRef,
813        buffers: &[(u64, &MlxBuffer)],
814        threadgroup_mem: &[(u64, u64)],
815        threadgroups: MTLSize,
816        threadgroup_size: MTLSize,
817    ) {
818        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
819        let op_kind = self.take_pending_op_kind();
820        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
821        if let Some(ref mut nodes) = self.capture {
822            nodes.push(CapturedNode::Dispatch {
823                pipeline: pipeline.to_owned(),
824                bindings: Self::record_buffer_bindings(buffers),
825                threads_per_grid: threadgroups,
826                threads_per_threadgroup: threadgroup_size,
827                threadgroup_memory: threadgroup_mem.to_vec(),
828                dispatch_kind: DispatchKind::ThreadGroups,
829                op_kind,
830                reads: pending_reads,
831                writes: pending_writes,
832            });
833            return;
834        }
835        let encoder = self.get_or_create_encoder();
836        encoder.set_compute_pipeline_state(pipeline);
837        for &(index, buf) in buffers {
838            encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
839        }
840        for &(index, byte_length) in threadgroup_mem {
841            encoder.set_threadgroup_memory_length(index, byte_length);
842        }
843        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
844    }
845
846    /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_threads).
847    ///
848    /// Reuses the persistent compute encoder.
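    ///
    /// A sketch mixing a buffer binding with inline params (assuming a `Pod`
    /// params struct as in the [`as_bytes`] example):
    ///
    /// ```ignore
    /// enc.encode_with_args(
    ///     &pipeline,
    ///     &[
    ///         (0, KernelArg::Buffer(&x)),
    ///         (1, KernelArg::Bytes(as_bytes(&params))),
    ///     ],
    ///     grid,
    ///     tg,
    /// );
    /// ```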
849    pub fn encode_with_args(
850        &mut self,
851        pipeline: &ComputePipelineStateRef,
852        bindings: &[(u64, KernelArg<'_>)],
853        grid_size: MTLSize,
854        threadgroup_size: MTLSize,
855    ) {
856        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
857        let op_kind = self.take_pending_op_kind();
858        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
859        if let Some(ref mut nodes) = self.capture {
860            nodes.push(CapturedNode::Dispatch {
861                pipeline: pipeline.to_owned(),
862                bindings: Self::record_arg_bindings(bindings),
863                threads_per_grid: grid_size,
864                threads_per_threadgroup: threadgroup_size,
865                threadgroup_memory: Vec::new(),
866                dispatch_kind: DispatchKind::Threads,
867                op_kind,
868                reads: pending_reads,
869                writes: pending_writes,
870            });
871            return;
872        }
873        let encoder = self.get_or_create_encoder();
874        encoder.set_compute_pipeline_state(pipeline);
875        apply_bindings(encoder, bindings);
876        encoder.dispatch_threads(grid_size, threadgroup_size);
877    }
878
879    /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_thread_groups).
880    ///
881    /// Reuses the persistent compute encoder.
882    pub fn encode_threadgroups_with_args(
883        &mut self,
884        pipeline: &ComputePipelineStateRef,
885        bindings: &[(u64, KernelArg<'_>)],
886        threadgroups: MTLSize,
887        threadgroup_size: MTLSize,
888    ) {
889        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
890        let op_kind = self.take_pending_op_kind();
891        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
892        if let Some(ref mut nodes) = self.capture {
893            nodes.push(CapturedNode::Dispatch {
894                pipeline: pipeline.to_owned(),
895                bindings: Self::record_arg_bindings(bindings),
896                threads_per_grid: threadgroups,
897                threads_per_threadgroup: threadgroup_size,
898                threadgroup_memory: Vec::new(),
899                dispatch_kind: DispatchKind::ThreadGroups,
900                op_kind,
901                reads: pending_reads,
902                writes: pending_writes,
903            });
904            return;
905        }
906        let encoder = self.get_or_create_encoder();
907        encoder.set_compute_pipeline_state(pipeline);
908        apply_bindings(encoder, bindings);
909        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
910    }
911
912    /// Encode a dispatch with mixed buffer/bytes bindings and shared memory.
913    ///
914    /// Reuses the persistent compute encoder.
915    pub fn encode_threadgroups_with_args_and_shared(
916        &mut self,
917        pipeline: &ComputePipelineStateRef,
918        bindings: &[(u64, KernelArg<'_>)],
919        threadgroup_mem: &[(u64, u64)],
920        threadgroups: MTLSize,
921        threadgroup_size: MTLSize,
922    ) {
923        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
924        let op_kind = self.take_pending_op_kind();
925        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
926        if let Some(ref mut nodes) = self.capture {
927            nodes.push(CapturedNode::Dispatch {
928                pipeline: pipeline.to_owned(),
929                bindings: Self::record_arg_bindings(bindings),
930                threads_per_grid: threadgroups,
931                threads_per_threadgroup: threadgroup_size,
932                threadgroup_memory: threadgroup_mem.to_vec(),
933                dispatch_kind: DispatchKind::ThreadGroups,
934                op_kind,
935                reads: pending_reads,
936                writes: pending_writes,
937            });
938            return;
939        }
940        let encoder = self.get_or_create_encoder();
941        encoder.set_compute_pipeline_state(pipeline);
942        apply_bindings(encoder, bindings);
943        for &(index, byte_length) in threadgroup_mem {
944            encoder.set_threadgroup_memory_length(index, byte_length);
945        }
946        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
947    }
948
949    /// Replay a single captured dispatch node into this encoder.
950    ///
951    /// This is the inverse of capture: it takes a previously recorded
952    /// `CapturedNode::Dispatch` and encodes it into the live Metal encoder.
953    /// Barrier nodes are handled by the caller (ComputeGraph::encode_sequential).
954    ///
955    /// Does NOT increment `DISPATCH_COUNT` — that was already counted at
956    /// capture time.
957    pub fn replay_dispatch(
958        &mut self,
959        pipeline: &ComputePipelineStateRef,
960        bindings: &[(u64, RecordedBinding)],
961        threadgroup_memory: &[(u64, u64)],
962        threads_per_grid: MTLSize,
963        threads_per_threadgroup: MTLSize,
964        dispatch_kind: DispatchKind,
965    ) {
966        let encoder = self.get_or_create_encoder();
967        encoder.set_compute_pipeline_state(pipeline);
968        for (index, binding) in bindings {
969            match binding {
970                RecordedBinding::Buffer { metal_buffer, offset } => {
971                    encoder.set_buffer(*index, Some(metal_buffer), *offset);
972                }
973                RecordedBinding::Bytes(bytes) => {
974                    encoder.set_bytes(
975                        *index,
976                        bytes.len() as u64,
977                        bytes.as_ptr() as *const _,
978                    );
979                }
980            }
981        }
982        for &(index, byte_length) in threadgroup_memory {
983            encoder.set_threadgroup_memory_length(index, byte_length);
984        }
985        match dispatch_kind {
986            DispatchKind::Threads => {
987                encoder.dispatch_threads(threads_per_grid, threads_per_threadgroup);
988            }
989            DispatchKind::ThreadGroups => {
990                encoder.dispatch_thread_groups(threads_per_grid, threads_per_threadgroup);
991            }
992        }
993    }
994
995    /// Flush any pending residency-set add/remove staging.
996    ///
997    /// Hooked at every commit boundary so per-allocation
998    /// [`ResidencySet::add_allocation`](ResidencySet::add_allocation) and
999    /// [`ResidencySet::remove_allocation`](ResidencySet::remove_allocation)
1000    /// calls (as fired by `MlxDevice::alloc_buffer` and
1001    /// `MlxBufferStorage::Drop`) collapse into at most ONE `[set commit]`
1002    /// per CB submission. Mirrors llama.cpp's
1003    /// `ggml-metal-device.m:1378-1382` (batch addAllocation in loop,
1004    /// commit ONCE).
1005    #[inline]
1006    fn flush_residency_pending(&self) {
1007        if let Some(set) = self.residency_set.as_ref() {
1008            set.flush_pending();
1009        }
1010    }
1011
1012    /// Commit the command buffer and block until the GPU finishes execution.
1013    ///
1014    /// # Errors
1015    ///
1016    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
1017    pub fn commit_and_wait(&mut self) -> Result<()> {
1018        SYNC_COUNT.fetch_add(1, Ordering::Relaxed);
1019
1020        // End the persistent compute encoder before committing.
1021        self.end_active_encoder();
1022
1023        // ADR-015 iter8e (Phase 3b): flush deferred residency-set
1024        // add/remove staging so the residency hint covers any buffers
1025        // referenced by this CB. Single commit per CB boundary; no-op
1026        // when no residency set or no staged changes.
1027        self.flush_residency_pending();
1028
1029        self.cmd_buf.commit();
1030        self.cmd_buf.wait_until_completed();
1031
1032        match self.cmd_buf.status() {
1033            MTLCommandBufferStatus::Completed => Ok(()),
1034            MTLCommandBufferStatus::Error => {
1035                Err(MlxError::CommandBufferError(
1036                    "GPU command buffer completed with error status".into(),
1037                ))
1038            }
1039            status => Err(MlxError::CommandBufferError(format!(
1040                "Unexpected command buffer status after wait: {:?}",
1041                status
1042            ))),
1043        }
1044    }
1045
1046    /// Commit + wait, accumulating GPU wall-clock time under `label` into
1047    /// the [`crate::kernel_profile`] global table when `MLX_PROFILE_CB=1`
1048    /// is set.  When the env var is unset, this is identical to
1049    /// [`commit_and_wait`](Self::commit_and_wait) — zero overhead.
1050    ///
1051    /// Used by hf2q's decode hot path to attribute per-cb GPU time to
1052    /// labeled phases (per-layer attn, per-layer ffn, output_head, etc.)
1053    /// without manually wiring `commit_wait_with_gpu_time` everywhere.
1054    ///
1055    /// # Errors
1056    ///
1057    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
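    ///
    /// A sketch of the labeled hot-path call (the label text is illustrative):
    ///
    /// ```ignore
    /// enc.commit_and_wait_labeled("layer07.attn_moe_ffn")?;
    /// // With MLX_PROFILE_CB=1 the CB's GPU time is attributed to that label.
    /// ```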
1058    pub fn commit_and_wait_labeled(&mut self, label: &str) -> Result<()> {
1059        // ADR-015 iter16 — propagate `label` to MTLCommandBuffer.setLabel and
1060        // (if a compute encoder is active) MTLComputeCommandEncoder.setLabel
1061        // BEFORE end_encoding/commit so xctrace's
1062        // `metal-application-encoders-list` table populates `cmdbuffer-label`
1063        // and `encoder-label` columns with the semantic phase name (e.g.
1064        // `layer.attn_moe_ffn`, `output_head.fused_norm_lm_argmax`,
1065        // `layer.delta_net.ops1-9`).  Joined to per-CB GPU duration via
1066        // `metal-gpu-submission-to-command-buffer-id` (sub_id ↔ encoder_id) →
1067        // `metal-gpu-execution-points` (per-dispatch start/end), this enables
1068        // per-phase µs/token attribution comparing hf2q vs llama side-by-side
1069        // (iter15 §E "iter16 ATTRIBUTION PATH").  Cost is a single ObjC
1070        // msg_send per CB submission — sub-µs on M5 Max — and a no-op when
1071        // xctrace isn't recording, so this is unconditionally safe to call on
1072        // the production decode hot path.
1073        self.apply_labels(label);
1074        if crate::kernel_profile::is_enabled() {
1075            let (start_s, end_s) = self.commit_wait_with_gpu_time()?;
1076            let ns = ((end_s - start_s).max(0.0) * 1_000_000_000.0) as u64;
1077            crate::kernel_profile::record(label, ns);
1078            Ok(())
1079        } else {
1080            self.commit_and_wait()
1081        }
1082    }
1083
1084    /// Async commit, but with profiling label.  When `MLX_PROFILE_CB=1`
1085    /// is set, redirects to a synchronous [`commit_and_wait_labeled`]
1086    /// call to capture per-cb GPU time (this defeats async pipelining
1087    /// while profiling, which is the whole point — profile-mode is slow
1088    /// but informative).  When unset, identical to [`commit`](Self::commit).
1089    pub fn commit_labeled(&mut self, label: &str) {
1090        // ADR-015 iter16 — see `commit_and_wait_labeled` for rationale.
1091        if crate::kernel_profile::is_enabled() {
1092            // Profile mode: force sync to capture GPU time.  apply_labels is
1093            // called inside commit_and_wait_labeled — do NOT call it twice
1094            // here (would double the ObjC msg_send under MLX_PROFILE_CB=1).
1095            // Errors are logged via stderr because the void return matches
1096            // commit().
1097            if let Err(e) = self.commit_and_wait_labeled(label) {
1098                eprintln!("[mlx-native] commit_labeled({}) failed: {}", label, e);
1099            }
1100        } else {
1101            // Async path: apply labels here so xctrace MST traces capture
1102            // per-CB phase attribution under default decode (no
1103            // `MLX_PROFILE_CB`).
1104            self.apply_labels(label);
1105            self.commit();
1106        }
1107    }
1108
1109    /// Apply `label` to the underlying `MTLCommandBuffer` and, if a compute
1110    /// encoder is currently active, to the `MTLComputeCommandEncoder`.
1111    ///
1112    /// Called from [`commit_labeled`] and [`commit_and_wait_labeled`] BEFORE
1113    /// the encoder is ended / the CB is committed so xctrace's
1114    /// `metal-application-encoders-list` table picks up the label on the
1115    /// row emitted at the encoder's `endEncoding` / CB submission boundary.
1116    /// Single ObjC `msg_send` per call (two if an encoder is active); sub-µs
1117    /// on M5 Max; no-op when xctrace isn't recording.
1118    ///
1119    /// Skipped (debug-only assert) if `label` is empty — empty labels would
1120    /// produce an indistinguishable trace row from the metal-rs default
1121    /// `Command Buffer 0` placeholder.
1122    #[inline]
1123    fn apply_labels(&self, label: &str) {
1124        debug_assert!(!label.is_empty(), "commit_*_labeled called with empty label");
1125        if label.is_empty() {
1126            return;
1127        }
1128        self.cmd_buf.set_label(label);
1129        if !self.active_encoder.is_null() {
1130            // SAFETY: active_encoder is non-null and points to a live encoder
1131            // owned by cmd_buf — same invariant as get_or_create_encoder /
1132            // memory_barrier.  set_label is a single property write on the
1133            // ObjC object; safe before endEncoding.
1134            unsafe { &*self.active_encoder }.set_label(label);
1135        }
1136    }
1137
1138    /// Commit + wait, returning `(gpu_start_s, gpu_end_s)` CFTimeInterval
1139    /// timestamps from `MTLCommandBuffer`'s `GPUStartTime`/`GPUEndTime`
1140    /// properties.  Both are mach-absolute CFTimeInterval seconds (double).
1141    ///
1142    /// Intended for `HF2Q_PROFILE_GPU_TS=1` per-bucket GPU wall-clock
1143    /// attribution.  Adds exactly two ObjC property reads per call on top
1144    /// of the regular `commit_and_wait` — measured well under 1 μs on
1145    /// M5 Max.
1146    ///
1147    /// # Errors
1148    ///
1149    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
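    ///
    /// A sketch of per-CB GPU wall-clock attribution:
    ///
    /// ```ignore
    /// let (start_s, end_s) = enc.commit_wait_with_gpu_time()?;
    /// let gpu_us = (end_s - start_s) * 1.0e6;
    /// ```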
1150    pub fn commit_wait_with_gpu_time(&mut self) -> Result<(f64, f64)> {
1151        self.commit_and_wait()?;
1152        // SAFETY: cmd_buf is a valid MTLCommandBuffer that has been
1153        // committed and awaited.  GPUStartTime / GPUEndTime return
1154        // CFTimeInterval (double precision seconds).  See
1155        // https://developer.apple.com/documentation/metal/mtlcommandbuffer/1639925-gpustarttime
1156        let (gpu_start, gpu_end): (f64, f64) = unsafe {
1157            let cb = &*self.cmd_buf;
1158            let s: f64 = msg_send![cb, GPUStartTime];
1159            let e: f64 = msg_send![cb, GPUEndTime];
1160            (s, e)
1161        };
1162        Ok((gpu_start, gpu_end))
1163    }
1164
1165    /// Commit the command buffer WITHOUT blocking.
1166    ///
1167    /// The GPU begins executing the encoded commands immediately.  Call
1168    /// [`wait_until_completed`](Self::wait_until_completed) later to block
1169    /// the CPU and check for errors.  This allows the CPU to continue doing
1170    /// other work (e.g. preparing the next batch) while the GPU runs.
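    ///
    /// A sketch of the async pattern (assuming some CPU-side work such as the
    /// hypothetical `prepare_next_batch`):
    ///
    /// ```ignore
    /// enc.commit();                 // GPU starts executing immediately
    /// prepare_next_batch();         // CPU work overlaps with GPU execution
    /// enc.wait_until_completed()?;  // block and surface any GPU error
    /// ```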
1171    pub fn commit(&mut self) {
1172        self.end_active_encoder();
1173        // ADR-015 iter8e (Phase 3b): same flush hook as commit_and_wait —
1174        // this is the async-pipeline path that production decode uses.
1175        self.flush_residency_pending();
1176        self.cmd_buf.commit();
1177    }
1178
1179    /// Block until a previously committed command buffer completes.
1180    ///
1181    /// Must be called after [`commit`](Self::commit).  Do not call after
1182    /// [`commit_and_wait`](Self::commit_and_wait) — that method already waits.
1183    ///
1184    /// # Errors
1185    ///
1186    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
1187    pub fn wait_until_completed(&self) -> Result<()> {
1188        self.cmd_buf.wait_until_completed();
1189        match self.cmd_buf.status() {
1190            MTLCommandBufferStatus::Completed => Ok(()),
1191            MTLCommandBufferStatus::Error => Err(MlxError::CommandBufferError(
1192                "GPU command buffer completed with error status".into(),
1193            )),
1194            status => Err(MlxError::CommandBufferError(format!(
1195                "Unexpected command buffer status after wait: {:?}",
1196                status
1197            ))),
1198        }
1199    }
1200
1201    /// Borrow the underlying Metal command buffer.
1202    #[inline]
1203    pub fn metal_command_buffer(&self) -> &CommandBuffer {
1204        &self.cmd_buf
1205    }
1206}
1207
1208impl Drop for CommandEncoder {
1209    fn drop(&mut self) {
1210        // End the persistent compute encoder before the command buffer
1211        // is dropped, otherwise Metal will assert:
1212        // "Command encoder released without endEncoding"
1213        self.end_active_encoder();
1214    }
1215}