// mlx_native/encoder.rs

//! [`CommandEncoder`] — batched GPU command submission.
//!
//! Wraps a Metal command buffer.  Encode one or more compute kernel dispatches,
//! then call [`commit_and_wait`](CommandEncoder::commit_and_wait) to submit the
//! entire batch and block until the GPU finishes.
//!
//! # Persistent compute encoder
//!
//! A single Metal `ComputeCommandEncoder` is kept alive across multiple
//! dispatches within the same command buffer.  This avoids the overhead of
//! creating and ending a new compute encoder per dispatch — the same pattern
//! candle uses (`compute_per_buffer`).  On a forward pass with ~800 dispatches
//! this saves ~800 encoder create/end cycles.
//!
//! # Capture mode (Phase 4e.1)
//!
//! When `start_capture()` is called, subsequent dispatches are recorded into a
//! `Vec<CapturedNode>` instead of being encoded into Metal.  `memory_barrier()`
//! records a barrier sentinel.  Call `take_capture()` to extract the recorded
//! graph for later replay via `ComputeGraph::encode_sequential()`.

use std::sync::atomic::{AtomicI8, AtomicU64, Ordering};

use metal::{
    CommandBuffer, CommandQueue, ComputeCommandEncoderRef, ComputePipelineState,
    ComputePipelineStateRef, CounterSampleBuffer, CounterSampleBufferDescriptor,
    MTLCommandBufferStatus, MTLCounterSamplingPoint, MTLDispatchType, MTLSize, MTLStorageMode,
    NSRange,
};
#[allow(unused_imports)]
use objc::{msg_send, sel, sel_impl};

use crate::buffer::MlxBuffer;
use crate::error::{MlxError, Result};
use crate::mem_ranges::MemRanges;
use crate::residency::ResidencySet;

/// A buffer or inline-bytes binding for a compute kernel argument slot.
pub enum KernelArg<'a> {
    /// Bind an existing Metal buffer at the given index.
    Buffer(&'a MlxBuffer),
    /// Bind an existing Metal buffer at the given index with a byte offset.
    BufferWithOffset(&'a MlxBuffer, u64),
    /// Bind inline bytes (small constant data) at the given index.
    /// The data must be `Pod` and is copied into the command encoder.
    Bytes(&'a [u8]),
}

/// Convert a `Pod` value to a byte slice suitable for `KernelArg::Bytes`.
///
/// # Safety
///
/// The caller must ensure `T` has the same layout as the corresponding
/// MSL struct in the shader (matching field order, sizes, and alignment).
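///
/// # Example
///
/// A minimal sketch — `Params` here is a hypothetical argument struct;
/// its field layout must mirror the MSL side:
///
/// ```ignore
/// #[repr(C)]
/// #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
/// struct Params { n: u32, eps: f32 }
///
/// let p = Params { n: 4096, eps: 1e-5 };
/// let arg = KernelArg::Bytes(as_bytes(&p));
/// ```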
pub fn as_bytes<T: bytemuck::Pod>(val: &T) -> &[u8] {
    bytemuck::bytes_of(val)
}

// ---------------------------------------------------------------------------
// Capture-mode types (Phase 4e.1 — Graph IR)
// ---------------------------------------------------------------------------

/// A recorded kernel argument binding.
///
/// When the encoder is in capture mode, each `set_buffer` / `set_bytes` call
/// is stored as a `RecordedBinding` instead of being applied to Metal.
#[derive(Clone)]
pub enum RecordedBinding {
    /// A Metal buffer at the given offset.
    Buffer {
        metal_buffer: metal::Buffer,
        offset: u64,
    },
    /// Inline bytes (small constant data, copied).
    Bytes(Vec<u8>),
}

/// How to dispatch the recorded kernel.
#[derive(Clone, Copy, Debug)]
pub enum DispatchKind {
    /// `dispatch_threads(grid_size, threadgroup_size)` — Metal picks threadgroup count.
    Threads,
    /// `dispatch_thread_groups(threadgroups, threadgroup_size)` — caller specifies threadgroup count.
    ThreadGroups,
}

/// Operation kind tag for captured nodes, used by the fusion pass (4e.2).
///
/// When the encoder is in capture mode, each dispatch can be tagged with an
/// `OpKind` so the fusion pass can identify fuseable sequences without
/// inspecting pipeline names.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CapturedOpKind {
    /// RMS normalization (with learned scale).
    RmsNorm,
    /// Elementwise multiply.
    ElemMul,
    /// Elementwise add.
    ElemAdd,
    /// Scaled dot-product attention (NOT reorderable — breaks lookahead).
    Sdpa,
    /// Softmax (NOT reorderable — breaks lookahead).
    Softmax,
    /// Any other operation — treated as reorderable by the graph optimizer.
    Other,
}

impl CapturedOpKind {
    /// Whether this captured op kind is safe to reorder past in the graph
    /// optimizer (Phase 4e.3).
    ///
    /// Mirrors the `h_safe` whitelist from llama.cpp's
    /// `ggml_metal_graph_optimize_reorder`.  Non-safe ops break the 64-node
    /// lookahead — the reorder pass cannot look past them.
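    ///
    /// Illustrative sketch of how a lookahead loop might consult this flag
    /// (hypothetical loop, not the actual Phase 4e.3 implementation):
    ///
    /// ```ignore
    /// for node in window {
    ///     if let CapturedNode::Dispatch { op_kind, .. } = node {
    ///         if !op_kind.is_reorderable() {
    ///             break; // stop the lookahead at SDPA / softmax
    ///         }
    ///     }
    /// }
    /// ```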
    pub fn is_reorderable(&self) -> bool {
        match self {
            Self::Sdpa | Self::Softmax => false,
            Self::RmsNorm | Self::ElemMul | Self::ElemAdd | Self::Other => true,
        }
    }

    /// Stable string label suitable for embedding in the per-dispatch
    /// profile dump (ADR-015 iter63 §A.5).  Matches the variant name —
    /// `Other` is preserved verbatim so an aggregate-by-op_kind sort
    /// produces a clean "what isn't yet labeled" bucket.
    pub fn name(&self) -> &'static str {
        match self {
            Self::RmsNorm => "RmsNorm",
            Self::ElemMul => "ElemMul",
            Self::ElemAdd => "ElemAdd",
            Self::Sdpa => "Sdpa",
            Self::Softmax => "Softmax",
            Self::Other => "Other",
        }
    }
}

/// A memory range annotation: (start_address, end_address).
///
/// Represents a contiguous GPU buffer region for conflict detection in the
/// reorder pass (Phase 4e.3).  Addresses are CPU-visible `contents_ptr()`
/// values, which on Apple Silicon unified memory equal the GPU addresses.
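///
/// Two ranges conflict when they overlap.  A minimal sketch of the check,
/// treating the ranges as half-open `[start, end)` intervals (hypothetical
/// helper, not part of this module's API):
///
/// ```ignore
/// fn ranges_overlap(a: MemRange, b: MemRange) -> bool {
///     a.0 < b.1 && b.0 < a.1
/// }
/// ```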
pub type MemRange = (usize, usize);

/// A single captured compute dispatch or barrier sentinel.
///
/// Created when the encoder is in capture mode.  Replayed later by
/// `ComputeGraph::encode_sequential()`.
#[derive(Clone)]
pub enum CapturedNode {
    /// A compute dispatch to replay.
    Dispatch {
        /// Pipeline state object to bind.
        pipeline: ComputePipelineState,
        /// Kernel argument bindings: (slot_index, binding).
        bindings: Vec<(u64, RecordedBinding)>,
        /// Grid or threadgroup count (interpretation depends on `dispatch_kind`).
        threads_per_grid: MTLSize,
        /// Threads per threadgroup.
        threads_per_threadgroup: MTLSize,
        /// Optional threadgroup memory allocations: (index, byte_length).
        threadgroup_memory: Vec<(u64, u64)>,
        /// Whether this is a dispatch_threads or dispatch_thread_groups call.
        dispatch_kind: DispatchKind,
        /// Operation kind tag for the fusion pass (4e.2).
        /// Defaults to `Other` if not explicitly set via `set_op_kind()`.
        op_kind: CapturedOpKind,
        /// Read buffer ranges for reorder conflict detection (4e.3).
        /// Populated from `barrier_between` calls in capture mode.
        reads: Vec<MemRange>,
        /// Write buffer ranges for reorder conflict detection (4e.3).
        /// Populated from `barrier_between` calls in capture mode.
        writes: Vec<MemRange>,
    },
    /// A memory barrier sentinel — forces a barrier at replay time.
    Barrier,
}

/// Convert a slice of buffer references into capture-mode
/// [`MemRange`] tuples.  Used by the [`CommandEncoder::dispatch_tracked*`]
/// family in capture mode — equivalent to the conversion
/// `GraphSession::barrier_between` does at `graph.rs:1452-1465`.
///
/// `(start, end)` uses `contents_ptr() + byte_offset` as the start
/// and `contents_ptr() + byte_offset + slice_extent` as the end.
fn ranges_from_buffers(bufs: &[&MlxBuffer]) -> Vec<MemRange> {
    bufs.iter()
        .map(|b| {
            let base = b.contents_ptr() as usize + b.byte_offset() as usize;
            let extent = b.byte_len().saturating_sub(b.byte_offset() as usize);
            (base, base + extent)
        })
        .collect()
}

/// Apply a slice of `KernelArg` bindings to a compute encoder.
///
/// `KernelArg::Buffer(buf)` propagates the `MlxBuffer::byte_offset()` so
/// `slice_view`-derived sub-buffers are honored automatically — the
/// kernel sees memory starting at the slice's offset. This matches the
/// documented contract of `slice_view` and the offset-handling in the
/// other binding paths in this file (`encode`, `encode_threadgroups`,
/// `encode_threadgroups_with_shared`, replay). Without it, every
/// `slice_view`-derived buffer bound via `KernelArg::Buffer` silently
/// exposes the entire underlying allocation — surfaced by hf2q's
/// nomic-bert iter-79 cosine parity bisection (cosine 0.098 → 0.999962
/// after fix).
///
/// `KernelArg::BufferWithOffset(buf, offset)` continues to use the
/// explicit `offset` argument verbatim (callers asking for an explicit
/// offset get exactly that, even on sliced buffers). The two API
/// surfaces are intentional: implicit (sliced views auto-propagate) vs.
/// explicit (caller-controlled).
#[inline]
fn apply_bindings(encoder: &ComputeCommandEncoderRef, bindings: &[(u64, KernelArg<'_>)]) {
    for &(index, ref arg) in bindings {
        match arg {
            KernelArg::Buffer(buf) => {
                encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
            }
            KernelArg::BufferWithOffset(buf, offset) => {
                encoder.set_buffer(index, Some(buf.metal_buffer()), *offset);
            }
            KernelArg::Bytes(bytes) => {
                encoder.set_bytes(index, bytes.len() as u64, bytes.as_ptr() as *const _);
            }
        }
    }
}

/// Number of times `commit_and_wait()` has been called (CPU sync points).
static SYNC_COUNT: AtomicU64 = AtomicU64::new(0);

/// Number of times an encode method has been called (GPU dispatches).
static DISPATCH_COUNT: AtomicU64 = AtomicU64::new(0);

/// Number of `MTLCommandBuffer` instances created via `CommandEncoder::new`.
/// Increments once per `device.command_encoder()` call.  Used by hf2q's
/// `HF2Q_DECODE_PROFILE` instrumentation to measure command-buffer
/// overhead per decode token (ADR-012 §Optimize / Task #15 follow-up).
static CMD_BUF_COUNT: AtomicU64 = AtomicU64::new(0);

/// Number of `memory_barrier()` calls that reached the
/// `objc::msg_send![encoder, memoryBarrierWithScope:]` site.  Capture-mode
/// no-ops and pre-encoder no-ops are excluded so the count reflects
/// actual MTL barriers issued.
///
/// Always tracked — the increment is one atomic op, ~5 ns.  ADR-015 H4
/// (Wave 2b hard gate #2) requires per-barrier counter resolution to
/// confirm-or-falsify the barrier-coalescing lever; xctrace TimeProfiler
/// at 1 ms sampling cannot resolve `memory_barrier` even though it fires
/// ~440×/token (`docs/ADR-015-mlx-native-single-cb-decode.md` §"P3a' live
/// profile pass" hypothesis register row H4).
static BARRIER_COUNT: AtomicU64 = AtomicU64::new(0);

/// Total nanoseconds spent inside the `objc::msg_send!` barrier site,
/// summed across all calls.  ONLY updated when the env var
/// `MLX_PROFILE_BARRIERS=1` is set on the process (cached on first
/// `memory_barrier` call).  When disabled the timing path is a single
/// branch + the unconditional barrier dispatch — same hot-path cost as
/// before this counter was added.
///
/// Why env-gated: timing adds 2 × `Instant::now()` (~50–100 ns each via
/// `mach_absolute_time`) per barrier.  At ~440 barriers/token that is
/// ~22–44 µs/token of measurement overhead — comparable to what we are
/// trying to measure.  Production must keep this off; profiling runs
/// opt-in.
static BARRIER_NS: AtomicU64 = AtomicU64::new(0);

/// Reset all counters to zero.
pub fn reset_counters() {
    SYNC_COUNT.store(0, Ordering::Relaxed);
    DISPATCH_COUNT.store(0, Ordering::Relaxed);
    CMD_BUF_COUNT.store(0, Ordering::Relaxed);
    BARRIER_COUNT.store(0, Ordering::Relaxed);
    BARRIER_NS.store(0, Ordering::Relaxed);
    AUTO_BARRIER_COUNT.store(0, Ordering::Relaxed);
    AUTO_BARRIER_CONCURRENT.store(0, Ordering::Relaxed);
}

/// Read the current value of `SYNC_COUNT`.
///
/// Each call to `commit_and_wait()` increments this counter.
pub fn sync_count() -> u64 {
    SYNC_COUNT.load(Ordering::Relaxed)
}

/// Read the current value of `DISPATCH_COUNT`.
///
/// Each call to `encode()`, `encode_threadgroups()`, or
/// `encode_threadgroups_with_shared()` increments this counter.
pub fn dispatch_count() -> u64 {
    DISPATCH_COUNT.load(Ordering::Relaxed)
}

/// Per-pipeline dispatch bucket support (ADR-028 iter-284).
///
/// Env-gated via `MLX_DISP_BUCKET=1`.  When enabled, every
/// `encode*` call records its pipeline's label in a global hash map.
/// This gives a per-kernel breakdown comparable to llama.cpp's
/// instrumented dispatch site for finding *which* kernels make up
/// the per-token dispatch budget.
fn pipeline_buckets()
    -> &'static std::sync::Mutex<std::collections::HashMap<String, u64>> {
    static BUCKETS: std::sync::OnceLock<
        std::sync::Mutex<std::collections::HashMap<String, u64>>,
    > = std::sync::OnceLock::new();
    BUCKETS.get_or_init(|| std::sync::Mutex::new(std::collections::HashMap::new()))
}

/// Cached env-flag check — single load on the hot path.
fn pipeline_bucket_enabled() -> bool {
    static CACHED: AtomicI8 = AtomicI8::new(-1);
    let v = CACHED.load(Ordering::Relaxed);
    if v >= 0 {
        return v == 1;
    }
    let on = std::env::var("MLX_DISP_BUCKET").as_deref() == Ok("1");
    CACHED.store(if on { 1 } else { 0 }, Ordering::Relaxed);
    on
}

/// Record a dispatch into the per-pipeline bucket if the env-flag is on.
/// Called from every `encode*` site alongside the `DISPATCH_COUNT` bump.
#[inline]
pub(crate) fn bucket_dispatch(pipeline: &ComputePipelineStateRef) {
    if !pipeline_bucket_enabled() {
        return;
    }
    let label = pipeline.label();
    if label.is_empty() {
        return;
    }
    if let Ok(mut t) = pipeline_buckets().lock() {
        *t.entry(label.to_string()).or_insert(0) += 1;
    }
}

/// Public dump of `MLX_DISP_BUCKET` data: `Vec<(label, count)>` sorted
/// descending by count.  Returns empty when env-flag is off / never
/// recorded.
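///
/// Illustrative use after a decode run with `MLX_DISP_BUCKET=1` set
/// (sketch only):
///
/// ```ignore
/// for (label, count) in pipeline_dispatch_buckets().iter().take(10) {
///     eprintln!("{count:>8}  {label}");
/// }
/// ```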
pub fn pipeline_dispatch_buckets() -> Vec<(String, u64)> {
    let mut v: Vec<(String, u64)> = if let Ok(t) = pipeline_buckets().lock() {
        t.iter().map(|(k, v)| (k.clone(), *v)).collect()
    } else {
        Vec::new()
    };
    v.sort_by(|a, b| b.1.cmp(&a.1));
    v
}

/// Reset the per-pipeline dispatch buckets (typically called at decode
/// start to ignore prefill / warmup contributions).
pub fn reset_pipeline_dispatch_buckets() {
    if let Ok(mut t) = pipeline_buckets().lock() {
        t.clear();
    }
}

/// Read the current value of `CMD_BUF_COUNT`.
///
/// Each `CommandEncoder::new` (i.e. each `MlxDevice::command_encoder()`)
/// increments this counter.  Useful for diagnosing per-dispatch Metal
/// command-buffer overhead in inner loops.
pub fn cmd_buf_count() -> u64 {
    CMD_BUF_COUNT.load(Ordering::Relaxed)
}

/// Read the current value of `BARRIER_COUNT`.
///
/// Each `memory_barrier()` call that reaches the underlying
/// `objc::msg_send![encoder, memoryBarrierWithScope:]` site increments this
/// counter.  Capture-mode no-ops and pre-encoder no-ops are excluded.
/// ADR-015 H4 hypothesis: ~440 barriers/token on the qwen35 decode hot
/// path (verify against this counter).
pub fn barrier_count() -> u64 {
    BARRIER_COUNT.load(Ordering::Relaxed)
}

/// Read the total nanoseconds spent in the `memoryBarrierWithScope:`
/// `objc::msg_send!` site.  Only non-zero when `MLX_PROFILE_BARRIERS=1`
/// was in the environment at the time of the first `memory_barrier()`
/// call (the env check is cached on first use).
///
/// Combined with [`barrier_count`] this gives µs/barrier =
/// `barrier_total_ns() / 1000 / barrier_count()`.
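///
/// For example (guarding against a zero count):
///
/// ```ignore
/// let us_per_barrier = barrier_total_ns() as f64
///     / 1_000.0
///     / barrier_count().max(1) as f64;
/// ```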
pub fn barrier_total_ns() -> u64 {
    BARRIER_NS.load(Ordering::Relaxed)
}

/// Whether barrier timing is enabled (env-gated, cached on first check).
///
/// Reading the env var via `std::env::var` is itself non-trivial; using
/// `OnceLock` caches the decision so the per-barrier branch is a single
/// atomic-load + compare.
fn barrier_profile_enabled() -> bool {
    use std::sync::OnceLock;
    static FLAG: OnceLock<bool> = OnceLock::new();
    *FLAG.get_or_init(|| {
        std::env::var("MLX_PROFILE_BARRIERS")
            .map(|v| v == "1")
            .unwrap_or(false)
    })
}

/// Whether `MLX_UNRETAINED_REFS=1` is set in the process environment.
///
/// ADR-015 iter13 — when true, `CommandEncoder::new_with_residency` opens
/// each `MTLCommandBuffer` via
/// [`CommandQueueRef::new_command_buffer_with_unretained_references`]
/// instead of the default `commandBuffer`.  llama.cpp's per-token decode
/// CBs use this same call (`/opt/llama.cpp/ggml/src/ggml-metal/`
/// `ggml-metal-context.m:512` `[queue commandBufferWithUnretainedReferences]`)
/// and gain ~3-5% wall on M-series GPUs by skipping per-buffer-binding ARC
/// retains on submit.
///
/// **Caller-side prerequisite.**  Every Metal buffer bound to a dispatch
/// must outlive the CB — see the docstring on
/// [`CommandEncoder::new_with_residency`] for the full caller contract.
/// In hf2q, the per-decode-token `MlxBufferPool` (`buffer_pool.rs`)
/// already keeps ARC clones alive in its `in_use` list across the entire
/// decode token; routing transient scratches through that pool is the
/// canonical way to satisfy the contract.
///
/// Cached on first read via `OnceLock` to keep the per-CB-construction
/// branch single-atomic-load fast.  Default OFF so any production decode
/// run that does NOT explicitly set the var preserves retained-refs
/// behavior verbatim.
fn unretained_refs_enabled() -> bool {
    use std::sync::OnceLock;
    static FLAG: OnceLock<bool> = OnceLock::new();
    *FLAG.get_or_init(|| {
        std::env::var("MLX_UNRETAINED_REFS")
            .map(|v| v == "1")
            .unwrap_or(false)
    })
}

/// Whether `HF2Q_AUTO_BARRIER=1` is set in the process environment.
///
/// ADR-015 iter37 — when true, every [`CommandEncoder::dispatch_tracked`]
/// call consults a [`MemRanges`](crate::mem_ranges::MemRanges) tracker
/// and auto-emits a `memoryBarrierWithScope:` exactly when the new
/// dispatch's read/write ranges conflict with previously-recorded
/// ranges (mirrors llama.cpp's `ggml_metal_op_concurrency_check` at
/// `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp:147-225`).
/// When false, `dispatch_tracked` collapses to the same code path as
/// `encode*` — no tracking, no auto-barriers — preserving sourdough
/// behavior for any caller that opts into the tracked API but runs
/// without the env gate.
///
/// Cached on first read via `OnceLock`.  Default OFF — production
/// decode/prefill keeps its hand-placed `enc.memory_barrier()` calls
/// until the migration in iter38+.
fn auto_barrier_enabled() -> bool {
    use std::sync::OnceLock;
    static FLAG: OnceLock<bool> = OnceLock::new();
    *FLAG.get_or_init(|| {
        std::env::var("HF2Q_AUTO_BARRIER")
            .map(|v| v == "1")
            .unwrap_or(false)
    })
}

/// Number of `memory_barrier()` calls auto-emitted by
/// [`CommandEncoder::dispatch_tracked`] under
/// `HF2Q_AUTO_BARRIER=1`.  A subset of [`BARRIER_COUNT`] —
/// auto-barriers go through `memory_barrier()` and therefore also bump
/// `BARRIER_COUNT`; this counter measures only the auto-emitted subset.
static AUTO_BARRIER_COUNT: AtomicU64 = AtomicU64::new(0);

/// Number of `dispatch_tracked` calls whose mem-ranges check returned
/// "concurrent" (no barrier needed).  Together with
/// [`AUTO_BARRIER_COUNT`] this measures the elision rate of the
/// dataflow barrier: `concurrent / (concurrent + barriers)` is the
/// fraction of dispatches that ran inside the previous concurrent
/// group rather than starting a new one.
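///
/// For example (sketch; guard the denominator):
///
/// ```ignore
/// let c = auto_barrier_concurrent_count() as f64;
/// let b = auto_barrier_count() as f64;
/// let elision_rate = c / (c + b).max(1.0);
/// ```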
static AUTO_BARRIER_CONCURRENT: AtomicU64 = AtomicU64::new(0);

// ---------------------------------------------------------------------------
// ADR-015 iter63 — per-dispatch GPU sampling support
// ---------------------------------------------------------------------------

/// Hard cap on per-CB sample-buffer sample count (Risk R4 in
/// PROFILING-KIT-DESIGN §A.7).
///
/// Empirically verified on Apple Silicon (M-series, macOS 26): the
/// underlying `MTLCounterSampleBufferDescriptor.sampleCount` is bounded
/// by a per-buffer **byte-size** limit of 32768 B.  At 8 bytes per
/// `MTLCounterResultTimestamp` sample that maps to a sample-count
/// ceiling of `32_768 / 8 = 4096`.  We allocate two samples per
/// dispatch (start + end), so this ceiling = 2048 dispatches per CB.
/// Decode CBs (~120 dispatches) fit comfortably; long prefill CBs
/// (~6K dispatches per design §A.7) will truncate after 2048 — see
/// [`CommandEncoder::sample_dispatch_pre`] for the truncation path.
/// A future iter can chunk-resolve every 2K dispatches.
///
/// The original design constant of 32_768 (PROFILING-KIT-DESIGN §A.7)
/// was based on Apple's documented ~64K-per-buffer "practical" limit,
/// but the measured constraint on this hardware is the 32 KB byte
/// budget.  Setting the budget below that would underutilize the
/// buffer; setting it above causes
/// `newCounterSampleBufferWithDescriptor` to fail with `Invalid sample
/// buffer length: <bytes> B. Expected range: 8 -> 32768`.
const MAX_SAMPLES_PER_CB: u64 = 4096;

/// Whether the per-CB warning about a missing `MTLCommonCounterSetTimestamp`
/// has been emitted yet.  Risk R1: if `device.counter_sets()` does not
/// return a set named `"timestamp"` (case-insensitive), we degrade the
/// per-dispatch path to a no-op and log once via stderr.
static TIMESTAMP_SET_WARN_LOGGED: AtomicU64 = AtomicU64::new(0);

/// Pending per-dispatch metadata that pairs with sample indices `2i`
/// (start) and `2i+1` (end) inside the CB's `MTLCounterSampleBuffer`.
/// Resolved by `CommandEncoder::resolve_dispatch_samples` at CB
/// commit-time and converted to [`crate::kernel_profile::DispatchEntry`]
/// before being pushed to the global table.
#[derive(Clone, Debug)]
struct PendingDispatchMeta {
    op_kind: &'static str,
    dispatch_index: u32,
}

/// Read the cumulative number of auto-emitted barriers across all
/// encoders since process start (or last [`reset_counters`]).
pub fn auto_barrier_count() -> u64 {
    AUTO_BARRIER_COUNT.load(Ordering::Relaxed)
}

/// Read the cumulative number of `dispatch_tracked` calls that did NOT
/// emit a barrier (ran concurrent with the previous group).
pub fn auto_barrier_concurrent_count() -> u64 {
    AUTO_BARRIER_CONCURRENT.load(Ordering::Relaxed)
}

/// Issue the underlying Metal `memoryBarrierWithScope:` ObjC msg_send.
///
/// Held in its own `#[inline(never)]` function so xctrace / Instruments
/// has a stable Rust frame to attribute barrier time against, separate
/// from the surrounding encoder accounting.  Per ADR-015 §P3a' Codex
/// review Q2: TimeProfiler at 1 ms sampling cannot see this site when
/// inlined; an explicit non-inline frame plus the [`BARRIER_NS`] counter
/// closes the H4 hard gate.
#[inline(never)]
fn issue_metal_buffer_barrier(encoder: &ComputeCommandEncoderRef) {
    // MTLBarrierScopeBuffers = 1 << 0 = 1.
    const MTL_BARRIER_SCOPE_BUFFERS: u64 = 1;
    unsafe {
        let _: () =
            objc::msg_send![encoder, memoryBarrierWithScope: MTL_BARRIER_SCOPE_BUFFERS];
    }
}

/// A batched compute command encoder.
///
/// Keeps a single Metal `ComputeCommandEncoder` alive across multiple
/// dispatches.  The encoder is created on the first dispatch and ended
/// only when the command buffer is committed.  This mirrors candle's
/// `compute_per_buffer` pattern and avoids per-dispatch encoder overhead.
///
/// # Typical usage
///
/// ```ignore
/// let mut enc = device.command_encoder()?;
/// // Multiple dispatches share the same compute encoder:
/// enc.encode_threadgroups(pipeline1, &buffers1, tg1, tg_size1);
/// enc.encode_threadgroups(pipeline2, &buffers2, tg2, tg_size2);
/// enc.commit_and_wait()?;
/// ```
pub struct CommandEncoder {
    cmd_buf: CommandBuffer,
    /// Owned clone of the originating command queue.
    ///
    /// ADR-019 Phase 0b iter89e2-A: stored at `new_with_residency` time so
    /// downstream lifecycle code (e.g. `EncoderSession::reset_for_next_stage`
    /// in Phase 0b-B) can open a fresh `CommandBuffer` from the same queue
    /// after a non-blocking `commit_stage()`. metal-rs 0.33's
    /// `CommandQueue` type is `Send + Sync` via `foreign_obj_type!`
    /// (`/Users/robert/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/metal-0.33.0/src/lib.rs:179`),
    /// so adding this field preserves the existing unsafe `Send` impl
    /// on `CommandEncoder` (declared below).
    ///
    /// ADR-019 Phase 0b iter89e2-B (CONSUMED): read by
    /// [`Self::reset_command_buffer`] to spawn a fresh `CommandBuffer`
    /// after a non-blocking `commit*` so `EncoderSession::reset_for_next_stage`
    /// can chain stage CBs without re-constructing the encoder. Holding a
    /// clone here (rather than a `&CommandQueue` borrow) avoids a lifetime
    /// parameter on `CommandEncoder` that would propagate through every
    /// consumer in mlx-native and hf2q.
    queue: CommandQueue,
    // SAFETY marker: see unsafe Send impl below.
    /// Raw pointer to the persistent compute encoder.
    /// Non-null when a compute pass is active.
    /// The encoder borrows from `cmd_buf` but we cannot express this
    /// lifetime in safe Rust, so we use a raw pointer.
    /// SAFETY: the pointer is valid as long as `cmd_buf` is alive and
    /// `end_encoding()` has not been called on it.
    active_encoder: *const ComputeCommandEncoderRef,
    /// When `Some`, dispatches are recorded here instead of being encoded
    /// into Metal.  Set via `start_capture()`, extracted via `take_capture()`.
    capture: Option<Vec<CapturedNode>>,
    /// Op kind tag for the NEXT captured dispatch.  Set via `set_op_kind()`,
    /// consumed (reset to `Other`) when a dispatch is captured.
    pending_op_kind: CapturedOpKind,
    /// Pending read buffer ranges for the NEXT captured dispatch.
    /// Set via `set_pending_buffer_ranges()`, consumed when the next dispatch
    /// is captured.  Used by the reorder pass (Phase 4e.3).
    pending_reads: Vec<MemRange>,
    /// Pending write buffer ranges for the NEXT captured dispatch.
    pending_writes: Vec<MemRange>,
    /// ADR-015 iter8e (Phase 3b): residency set whose pending add/remove
    /// staging is flushed at every `commit*` boundary.
    ///
    /// Cloned from the device at `device.command_encoder()` time. `None`
    /// when residency sets are disabled (HF2Q_NO_RESIDENCY=1, macOS<15,
    /// or test-only `CommandEncoder::new` from a residency-less queue).
    residency_set: Option<ResidencySet>,
    /// ADR-015 iter37: dataflow barrier inference state.
    ///
    /// Populated only when `HF2Q_AUTO_BARRIER=1` is set at process
    /// start (cached via [`auto_barrier_enabled`]).  Each
    /// [`Self::dispatch_tracked`] call consults this state to decide
    /// whether a Metal memory barrier is required; on conflict the
    /// barrier is emitted, the state is reset, and the new dispatch's
    /// ranges seed the next concurrent group.  When the env gate is
    /// off, `dispatch_tracked` collapses to its untracked equivalent
    /// and this field is left empty for the encoder's lifetime.
    ///
    /// The field is always present (zero-sized when empty) so the
    /// gate-off branch is a single bool-load + early return rather
    /// than an allocation/Option indirection.
    mem_ranges: MemRanges,
    /// ADR-015 iter63 (per-dispatch profiling): the sample buffer for
    /// `MTLCounterSampleBuffer.sampleCounters` calls that bracket every
    /// `encode*` dispatch in this CB.  Lazily allocated on first
    /// dispatch when `MLX_PROFILE_DISPATCH=1`; `None` otherwise.
    /// Released (set to `None`) inside `resolve_dispatch_samples` after
    /// the CB completes — re-allocated on the next `encode*` if the env
    /// gate stays set.
    sample_buffer: Option<CounterSampleBuffer>,
    /// ADR-015 iter63: pending per-dispatch metadata that pairs with
    /// sample indices `2*i` and `2*i+1` inside `sample_buffer`.  Each
    /// `encode*` call appends one entry (when sampling is active);
    /// `resolve_dispatch_samples` drains the vec at commit time.
    pending_dispatch_meta: Vec<PendingDispatchMeta>,
    /// ADR-015 iter63: 0-based dispatch ordinal within the current CB.
    /// Incremented in every `encode*` site after taking the pending
    /// op_kind; reset to 0 inside `resolve_dispatch_samples`.
    dispatch_in_cb: u32,
    /// ADR-015 iter63: most recent label set via `apply_labels`, used
    /// as the per-dispatch `cb_label` field.  `String::new()` until
    /// `commit_and_wait_labeled` / `commit_labeled` is called.
    last_label: String,
}

/// SAFETY: CommandEncoder is safe to Send across threads provided that:
/// 1. Only one thread accesses the encoder at a time (exclusive ownership).
/// 2. The encoder is not used concurrently from multiple threads.
///
/// Metal command buffers and compute encoders are thread-safe for exclusive
/// access (Apple documentation: "You can create command buffers, encode
/// commands, and submit them from any thread"). The raw pointer
/// `active_encoder` borrows from `cmd_buf` and is valid as long as
/// `cmd_buf` is alive — this invariant holds across thread boundaries
/// because both fields move together.
///
/// This matches llama.cpp's pattern of encoding command buffers on GCD
/// worker threads via `dispatch_apply`, and is used for the dual-buffer
/// pipeline where buf1 is encoded on a worker thread while buf0 executes.
unsafe impl Send for CommandEncoder {}

impl CommandEncoder {
    /// Create a new command encoder from the given command queue.
    ///
    /// This immediately creates a Metal command buffer.
    ///
    /// # Why retained references
    ///
    /// We use the regular `commandBuffer` (Metal retains every bound
    /// resource for the lifetime of the buffer) rather than
    /// `commandBufferWithUnretainedReferences`.  llama.cpp uses unretained
    /// refs for an additional perf bump (~3-5% on M-series GPUs), but the
    /// hf2q dispatch pattern allocates many transient scratch buffers
    /// inside helper functions (`apply_proj` → `weight_bf16_owned`,
    /// `apply_pre_norm` → `params`, etc.) that go out of scope at the
    /// helper's return.  With unretained refs the metal::Buffer's ARC
    /// drops to zero, freeing the underlying GPU memory before the
    /// dispatch executes.  Verified 2026-04-26: switching to unretained
    /// hits "Command buffer error: GPU command buffer completed with
    /// error status" on the first MoE FFN dispatch.
    ///
    /// To enable unretained refs in the future, every helper that
    /// allocates and dispatches must thread its scratch buffers up to a
    /// caller scope that outlives the eventual commit, OR all such
    /// scratch must come from the per-decode-token pool (which already
    /// ARC-retains in its in_use list).  Today the lm_head + router-
    /// download paths are still unpooled.
    #[allow(dead_code)]
    pub(crate) fn new(queue: &CommandQueue) -> Result<Self> {
        Self::new_with_residency(queue, None)
    }

    /// Create a new command encoder, optionally bound to a residency set so
    /// `commit*` boundaries can flush deferred add/remove staging.
    ///
    /// ADR-015 iter8e (Phase 3b): the encoder's `commit_and_wait`,
    /// `commit_and_wait_labeled`, `commit`, `commit_labeled`, and
    /// `commit_wait_with_gpu_time` all call
    /// [`ResidencySet::flush_pending`](ResidencySet::flush_pending) before
    /// submitting the Metal command buffer. This converts the
    /// per-allocation `[set commit]` storm
    /// (~880 commits/decode-token in iter8d/8e claude+codex variants) into
    /// at most one commit per CB submission — mirrors llama.cpp's
    /// `ggml-metal-device.m:1378-1382` pattern (batch addAllocation in a
    /// loop, commit ONCE).
    ///
    /// ADR-015 iter13: when the `MLX_UNRETAINED_REFS=1` env var is set at
    /// process start, this constructor uses
    /// [`CommandQueueRef::new_command_buffer_with_unretained_references`]
    /// instead of `new_command_buffer`.  llama.cpp's per-token decode CBs
    /// use `commandBufferWithUnretainedReferences` (see
    /// `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m:512`) which
    /// skips Metal's per-buffer-binding ARC-retain on submit and saves
    /// ~3-5% on M-series GPUs (per the docstring above).
    ///
    /// **Caller contract under unretained refs.**  Every Metal buffer bound
    /// to a dispatch in this CB MUST outlive the CB's GPU completion.  In
    /// the hf2q decode path, that means every transient scratch must be
    /// either (a) backed by the per-decode-token arena pool
    /// (`MlxBufferPool` keeps an ARC clone in `in_use` until the next
    /// `reset` — see `buffer_pool.rs:60`) or (b) hoisted to a caller scope
    /// that lives across the terminal `commit_and_wait_labeled`.  Helpers
    /// in `apply_proj` / `apply_pre_norm` / lm_head cast / router-download
    /// that allocated transients via `device.alloc_buffer` and dropped
    /// them at function return MUST be lifted to `pooled_alloc_buffer`
    /// before `MLX_UNRETAINED_REFS=1` is enabled, or the first MoE FFN
    /// dispatch will crash with "Command buffer error: GPU command buffer
    /// completed with error status" (verified 2026-04-26).
    ///
    /// The default (`MLX_UNRETAINED_REFS` unset) preserves retained-refs
    /// behavior verbatim — this is the sourdough-safe path.
    pub(crate) fn new_with_residency(
        queue: &CommandQueue,
        residency_set: Option<ResidencySet>,
    ) -> Result<Self> {
        let cmd_buf = if unretained_refs_enabled() {
            queue.new_command_buffer_with_unretained_references().to_owned()
        } else {
            queue.new_command_buffer().to_owned()
        };
        CMD_BUF_COUNT.fetch_add(1, Ordering::Relaxed);
        Ok(Self {
            cmd_buf,
            queue: queue.to_owned(),
            active_encoder: std::ptr::null(),
            capture: None,
            pending_op_kind: CapturedOpKind::Other,
            pending_reads: Vec::new(),
            pending_writes: Vec::new(),
            residency_set,
            mem_ranges: MemRanges::new(),
            sample_buffer: None,
            pending_dispatch_meta: Vec::new(),
            dispatch_in_cb: 0,
            last_label: String::new(),
        })
    }

    /// Enable capture mode.
    ///
    /// All subsequent dispatch and barrier calls will be recorded into a
    /// `Vec<CapturedNode>` instead of being encoded into Metal.
    /// Call `take_capture()` to extract the recorded nodes.
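    ///
    /// A minimal capture → replay sketch (`ComputeGraph` construction from
    /// the captured nodes is per the module docs; names illustrative):
    ///
    /// ```ignore
    /// enc.start_capture();
    /// enc.encode_threadgroups(pipeline, &buffers, tg, tg_size); // recorded, not encoded
    /// let nodes = enc.take_capture().unwrap();
    /// ```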
    pub fn start_capture(&mut self) {
        self.capture = Some(Vec::with_capacity(128));
    }

    /// Whether the encoder is currently in capture mode.
    pub fn is_capturing(&self) -> bool {
        self.capture.is_some()
    }

    /// Extract the captured nodes, ending capture mode.
    ///
    /// Returns `None` if capture mode was not active.
    pub fn take_capture(&mut self) -> Option<Vec<CapturedNode>> {
        self.capture.take()
    }

    /// Tag the NEXT captured dispatch with the given operation kind.
    ///
    /// The tag is consumed (reset to `Other`) after the next dispatch is
    /// captured.  Only meaningful in capture mode — has no effect on
    /// direct-dispatch encoding.
    ///
    /// Used by op dispatch functions to annotate captures for the fusion
    /// pass (Phase 4e.2).
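    ///
    /// For example (illustrative pipeline/buffer names):
    ///
    /// ```ignore
    /// enc.set_op_kind(CapturedOpKind::RmsNorm);
    /// enc.encode_threadgroups(rms_pipeline, &bufs, tg, tg_size); // tag consumed here
    /// ```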
    pub fn set_op_kind(&mut self, kind: CapturedOpKind) {
        self.pending_op_kind = kind;
    }

    /// Consume and return the pending op kind, resetting it to `Other`.
    fn take_pending_op_kind(&mut self) -> CapturedOpKind {
        let kind = self.pending_op_kind;
        self.pending_op_kind = CapturedOpKind::Other;
        kind
    }

    /// Stash buffer range annotations for the NEXT captured dispatch.
    ///
    /// Called by `GraphSession::barrier_between()` in capture mode to record
    /// which buffers the next dispatch reads from and writes to.  The ranges
    /// are consumed by the next `encode_*` call and attached to the captured
    /// `CapturedNode::Dispatch`.
    ///
    /// Only meaningful in capture mode — has no effect on direct-dispatch.
    pub fn set_pending_buffer_ranges(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
        self.pending_reads = reads;
        self.pending_writes = writes;
    }

    /// Patch the last captured dispatch node's empty reads/writes with the
    /// given ranges. No-op if not capturing, or if the last node isn't a
    /// Dispatch, or if its ranges are already populated.
    ///
    /// Used by `GraphSession::track_dispatch` in recording mode to annotate
    /// dispatches that were called without a preceding `barrier_between`.
    pub fn annotate_last_dispatch_if_missing(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
        if let Some(ref mut nodes) = self.capture {
            if let Some(CapturedNode::Dispatch { reads: r, writes: w, .. }) = nodes.last_mut() {
                if r.is_empty() && !reads.is_empty() {
                    *r = reads;
                }
                if w.is_empty() && !writes.is_empty() {
                    *w = writes;
                }
            }
        }
    }

    /// Consume and return the pending buffer range annotations.
    fn take_pending_buffer_ranges(&mut self) -> (Vec<MemRange>, Vec<MemRange>) {
        let reads = std::mem::take(&mut self.pending_reads);
        let writes = std::mem::take(&mut self.pending_writes);
        (reads, writes)
    }

    /// Record buffer bindings into `RecordedBinding` form.
    fn record_buffer_bindings(buffers: &[(u64, &MlxBuffer)]) -> Vec<(u64, RecordedBinding)> {
        buffers
            .iter()
            .map(|&(index, buf)| {
                (
                    index,
                    RecordedBinding::Buffer {
                        metal_buffer: buf.metal_buffer().clone(),
                        offset: buf.byte_offset(),
                    },
                )
            })
            .collect()
    }

    /// Record `KernelArg` bindings into `RecordedBinding` form.
    ///
    /// `KernelArg::Buffer(buf)` records `buf.byte_offset()` so capture →
    /// replay round-trips of `slice_view`-derived buffers preserve their
    /// offsets, matching `record_buffer_bindings`'s behavior above.
    fn record_arg_bindings(bindings: &[(u64, KernelArg<'_>)]) -> Vec<(u64, RecordedBinding)> {
        bindings
            .iter()
            .map(|(index, arg)| {
                let recorded = match arg {
                    KernelArg::Buffer(buf) => RecordedBinding::Buffer {
                        metal_buffer: buf.metal_buffer().clone(),
                        offset: buf.byte_offset(),
                    },
                    KernelArg::BufferWithOffset(buf, offset) => RecordedBinding::Buffer {
                        metal_buffer: buf.metal_buffer().clone(),
                        offset: *offset,
                    },
                    KernelArg::Bytes(bytes) => RecordedBinding::Bytes(bytes.to_vec()),
                };
                (*index, recorded)
            })
            .collect()
    }

    /// Get or create the persistent compute encoder.
    ///
    /// On the first call, creates a new compute encoder from the command
    /// buffer.  On subsequent calls, returns the existing one.
    ///
    /// SAFETY: The returned reference borrows from `self.cmd_buf` which is
    /// alive for the lifetime of this `CommandEncoder`.  The raw pointer is
    /// valid until `end_active_encoder()` is called.
    #[inline]
    fn get_or_create_encoder(&mut self) -> &ComputeCommandEncoderRef {
        if self.active_encoder.is_null() {
            // Use MTLDispatchTypeConcurrent to allow independent dispatches
            // to overlap on the GPU.  Memory barriers are inserted between
            // dependent dispatches via `memory_barrier()`.
            //
            // ADR-015 iter61a-2 probe: HF2Q_FORCE_SERIAL_DISPATCH=1 falls back
            // to MTLDispatchType::Serial — every dispatch waits for the
            // previous to complete, eliminating concurrent-dispatch race
            // windows. Used to falsify Hypothesis (g): missing memory_barrier
            // calls between dependent dispatches cause cold-run logit
            // non-determinism via thread-race on a shared buffer.
            let dispatch_type = if std::env::var("HF2Q_FORCE_SERIAL_DISPATCH")
                .map(|v| v == "1")
                .unwrap_or(false)
            {
                MTLDispatchType::Serial
            } else {
                MTLDispatchType::Concurrent
            };
            let encoder = self
                .cmd_buf
                .compute_command_encoder_with_dispatch_type(dispatch_type);
            self.active_encoder = encoder as *const ComputeCommandEncoderRef;
        }
        // SAFETY: active_encoder is non-null and points to a valid encoder
        // owned by cmd_buf.
        unsafe { &*self.active_encoder }
    }

    /// End the active compute encoder if one exists.
    #[inline]
    fn end_active_encoder(&mut self) {
        if !self.active_encoder.is_null() {
            // SAFETY: the pointer was obtained from cmd_buf.new_compute_command_encoder()
            // and has not been ended yet.
            unsafe { &*self.active_encoder }.end_encoding();
            self.active_encoder = std::ptr::null();
        }
    }

    /// Insert a memory barrier with scope `MTLBarrierScopeBuffers`.
    ///
    /// When the encoder uses `MTLDispatchTypeConcurrent`, all dispatches can
    /// execute concurrently unless separated by a barrier.  Call this between
    /// dispatches where the later dispatch reads a buffer written by an
    /// earlier one.
    ///
    /// This is the same pattern llama.cpp uses:
    /// `[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]`
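    ///
    /// For example (illustrative pipelines; `tmp` is written by the first
    /// dispatch and read by the second):
    ///
    /// ```ignore
    /// enc.encode_threadgroups(matmul, &[(0, &a), (1, &tmp)], tg1, ts1);
    /// enc.memory_barrier(); // order the write to `tmp` before the read
    /// enc.encode_threadgroups(softmax, &[(0, &tmp), (1, &out)], tg2, ts2);
    /// ```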
    #[allow(unexpected_cfgs)]
    pub fn memory_barrier(&mut self) {
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Barrier);
            return;
        }
        if self.active_encoder.is_null() {
            return;
        }
        BARRIER_COUNT.fetch_add(1, Ordering::Relaxed);
        // SAFETY: active_encoder is non-null and valid.
        let encoder = unsafe { &*self.active_encoder };
        if barrier_profile_enabled() {
            // mach_absolute_time path — only on when MLX_PROFILE_BARRIERS=1.
            let start = std::time::Instant::now();
            issue_metal_buffer_barrier(encoder);
            let elapsed_ns = start.elapsed().as_nanos() as u64;
            BARRIER_NS.fetch_add(elapsed_ns, Ordering::Relaxed);
        } else {
            issue_metal_buffer_barrier(encoder);
        }
    }

    /// Set the compute pipeline state for subsequent dispatches.
    ///
    /// This begins a new compute pass if one is not already active.
    pub fn set_pipeline(&mut self, pipeline: &ComputePipelineStateRef) {
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
    }

    /// Bind a buffer to a compute kernel argument slot.
    ///
    /// The `index` corresponds to the `[[buffer(N)]]` attribute in the MSL shader.
    ///
    /// NOTE: currently a no-op — the arguments are discarded.  Binding is
    /// performed inside the `encode*` family, which sets the pipeline,
    /// binds buffers, and dispatches in one call; prefer those methods.
    pub fn set_buffer(&self, index: u64, buffer: &MlxBuffer) {
        let _ = (index, buffer);
    }

    /// Dispatch threads on the GPU.
    ///
    /// NOTE: currently a no-op — the arguments are discarded.  Use
    /// [`Self::encode`] / [`Self::encode_with_args`], which dispatch as
    /// part of a complete compute pass.
    pub fn dispatch_threads(&self, grid_size: MTLSize, threadgroup_size: MTLSize) {
        let _ = (grid_size, threadgroup_size);
    }

    /// Encode a complete compute pass: set pipeline, bind buffers, dispatch.
    ///
    /// Reuses the persistent compute encoder — no per-dispatch encoder
    /// creation overhead.
    ///
    /// # Arguments
    ///
    /// * `pipeline`         — The compiled compute pipeline to execute.
    /// * `buffers`          — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
    /// * `grid_size`        — Total number of threads to launch.
    /// * `threadgroup_size` — Threads per threadgroup.
    pub fn encode(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        buffers: &[(u64, &MlxBuffer)],
        grid_size: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
        bucket_dispatch(pipeline);
        let op_kind = self.take_pending_op_kind();
        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Dispatch {
                pipeline: pipeline.to_owned(),
                bindings: Self::record_buffer_bindings(buffers),
                threads_per_grid: grid_size,
                threads_per_threadgroup: threadgroup_size,
                threadgroup_memory: Vec::new(),
                dispatch_kind: DispatchKind::Threads,
                op_kind,
                reads: pending_reads,
                writes: pending_writes,
            });
            return;
        }
        // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
        self.ensure_sample_buffer();
        let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
        // SAFETY: encoder_ptr aliases &self via active_encoder which we
        // know is non-null after get_or_create_encoder; this pattern is
        // used throughout the file (see memory_barrier).
        let encoder = unsafe { &*encoder_ptr };
        encoder.set_compute_pipeline_state(pipeline);
        for &(index, buf) in buffers {
            encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
        }
        let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
        encoder.dispatch_threads(grid_size, threadgroup_size);
        self.sample_dispatch_post(encoder, pre_idx);
    }

    /// Encode a compute pass using threadgroups instead of raw thread counts.
    ///
    /// Reuses the persistent compute encoder — no per-dispatch encoder
    /// creation overhead.
    pub fn encode_threadgroups(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        buffers: &[(u64, &MlxBuffer)],
        threadgroups: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
        bucket_dispatch(pipeline);
        let op_kind = self.take_pending_op_kind();
        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Dispatch {
                pipeline: pipeline.to_owned(),
                bindings: Self::record_buffer_bindings(buffers),
                threads_per_grid: threadgroups,
                threads_per_threadgroup: threadgroup_size,
                threadgroup_memory: Vec::new(),
                dispatch_kind: DispatchKind::ThreadGroups,
                op_kind,
                reads: pending_reads,
                writes: pending_writes,
            });
            return;
        }
        // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
        self.ensure_sample_buffer();
        let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
        // SAFETY: see encode() above.
        let encoder = unsafe { &*encoder_ptr };
        encoder.set_compute_pipeline_state(pipeline);
        for &(index, buf) in buffers {
            encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
        }
        let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
        self.sample_dispatch_post(encoder, pre_idx);
    }

    /// Encode a compute pass using threadgroups with shared threadgroup memory.
    ///
    /// Like [`encode_threadgroups`](Self::encode_threadgroups), but additionally
    /// allocates threadgroup memory at the specified indices.  This is required
    /// for kernels that use `threadgroup` memory (e.g. reductions in rms_norm
    /// and softmax).
    ///
    /// # Arguments
    ///
    /// * `pipeline`         — The compiled compute pipeline to execute.
    /// * `buffers`          — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
    /// * `threadgroup_mem`  — Slice of `(index, byte_length)` pairs for threadgroup memory.
    /// * `threadgroups`     — Number of threadgroups to dispatch.
    /// * `threadgroup_size` — Threads per threadgroup.
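    ///
    /// For example, a reduction that needs one shared allocation at
    /// threadgroup-memory index 0 (illustrative names and sizes):
    ///
    /// ```ignore
    /// enc.encode_threadgroups_with_shared(
    ///     &rms_pipeline, &bufs, &[(0, 32 * 4)], tg, tg_size,
    /// );
    /// ```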
    pub fn encode_threadgroups_with_shared(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        buffers: &[(u64, &MlxBuffer)],
        threadgroup_mem: &[(u64, u64)],
        threadgroups: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
        bucket_dispatch(pipeline);
        let op_kind = self.take_pending_op_kind();
        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Dispatch {
                pipeline: pipeline.to_owned(),
                bindings: Self::record_buffer_bindings(buffers),
                threads_per_grid: threadgroups,
                threads_per_threadgroup: threadgroup_size,
                threadgroup_memory: threadgroup_mem.to_vec(),
                dispatch_kind: DispatchKind::ThreadGroups,
                op_kind,
                reads: pending_reads,
                writes: pending_writes,
            });
            return;
        }
        // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
        self.ensure_sample_buffer();
        let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
        // SAFETY: see encode() above.
        let encoder = unsafe { &*encoder_ptr };
        encoder.set_compute_pipeline_state(pipeline);
        for &(index, buf) in buffers {
            encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
        }
        for &(index, byte_length) in threadgroup_mem {
            encoder.set_threadgroup_memory_length(index, byte_length);
        }
        let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
        self.sample_dispatch_post(encoder, pre_idx);
    }

    /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_threads).
    ///
    /// Reuses the persistent compute encoder.
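    ///
    /// A sketch with one buffer and one inline-bytes argument (hypothetical
    /// pipeline and params):
    ///
    /// ```ignore
    /// enc.encode_with_args(
    ///     &pipeline,
    ///     &[(0, KernelArg::Buffer(&input)),
    ///       (1, KernelArg::Bytes(as_bytes(&params)))],
    ///     grid, tg_size,
    /// );
    /// ```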
1147    pub fn encode_with_args(
1148        &mut self,
1149        pipeline: &ComputePipelineStateRef,
1150        bindings: &[(u64, KernelArg<'_>)],
1151        grid_size: MTLSize,
1152        threadgroup_size: MTLSize,
1153    ) {
1154        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1155        bucket_dispatch(pipeline);
1156        let op_kind = self.take_pending_op_kind();
1157        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1158        if let Some(ref mut nodes) = self.capture {
1159            nodes.push(CapturedNode::Dispatch {
1160                pipeline: pipeline.to_owned(),
1161                bindings: Self::record_arg_bindings(bindings),
1162                threads_per_grid: grid_size,
1163                threads_per_threadgroup: threadgroup_size,
1164                threadgroup_memory: Vec::new(),
1165                dispatch_kind: DispatchKind::Threads,
1166                op_kind,
1167                reads: pending_reads,
1168                writes: pending_writes,
1169            });
1170            return;
1171        }
1172        // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1173        self.ensure_sample_buffer();
1174        let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1175        // SAFETY: see encode() above.
1176        let encoder = unsafe { &*encoder_ptr };
1177        encoder.set_compute_pipeline_state(pipeline);
1178        apply_bindings(encoder, bindings);
1179        let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1180        encoder.dispatch_threads(grid_size, threadgroup_size);
1181        self.sample_dispatch_post(encoder, pre_idx);
1182    }
1183
1184    /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_thread_groups).
1185    ///
1186    /// Reuses the persistent compute encoder.
1187    pub fn encode_threadgroups_with_args(
1188        &mut self,
1189        pipeline: &ComputePipelineStateRef,
1190        bindings: &[(u64, KernelArg<'_>)],
1191        threadgroups: MTLSize,
1192        threadgroup_size: MTLSize,
1193    ) {
1194        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1195        bucket_dispatch(pipeline);
1196        let op_kind = self.take_pending_op_kind();
1197        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1198        if let Some(ref mut nodes) = self.capture {
1199            nodes.push(CapturedNode::Dispatch {
1200                pipeline: pipeline.to_owned(),
1201                bindings: Self::record_arg_bindings(bindings),
1202                threads_per_grid: threadgroups,
1203                threads_per_threadgroup: threadgroup_size,
1204                threadgroup_memory: Vec::new(),
1205                dispatch_kind: DispatchKind::ThreadGroups,
1206                op_kind,
1207                reads: pending_reads,
1208                writes: pending_writes,
1209            });
1210            return;
1211        }
1212        // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1213        self.ensure_sample_buffer();
1214        let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1215        // SAFETY: see encode() above.
1216        let encoder = unsafe { &*encoder_ptr };
1217        encoder.set_compute_pipeline_state(pipeline);
1218        apply_bindings(encoder, bindings);
1219        let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1220        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
1221        self.sample_dispatch_post(encoder, pre_idx);
1222    }
1223
1224    /// Encode a dispatch with mixed buffer/bytes bindings and shared memory.
1225    ///
1226    /// Reuses the persistent compute encoder.
1227    pub fn encode_threadgroups_with_args_and_shared(
1228        &mut self,
1229        pipeline: &ComputePipelineStateRef,
1230        bindings: &[(u64, KernelArg<'_>)],
1231        threadgroup_mem: &[(u64, u64)],
1232        threadgroups: MTLSize,
1233        threadgroup_size: MTLSize,
1234    ) {
1235        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1236        bucket_dispatch(pipeline);
1237        let op_kind = self.take_pending_op_kind();
1238        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1239        if let Some(ref mut nodes) = self.capture {
1240            nodes.push(CapturedNode::Dispatch {
1241                pipeline: pipeline.to_owned(),
1242                bindings: Self::record_arg_bindings(bindings),
1243                threads_per_grid: threadgroups,
1244                threads_per_threadgroup: threadgroup_size,
1245                threadgroup_memory: threadgroup_mem.to_vec(),
1246                dispatch_kind: DispatchKind::ThreadGroups,
1247                op_kind,
1248                reads: pending_reads,
1249                writes: pending_writes,
1250            });
1251            return;
1252        }
1253        // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1254        self.ensure_sample_buffer();
1255        let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1256        // SAFETY: see encode() above.
1257        let encoder = unsafe { &*encoder_ptr };
1258        encoder.set_compute_pipeline_state(pipeline);
1259        apply_bindings(encoder, bindings);
1260        for &(index, byte_length) in threadgroup_mem {
1261            encoder.set_threadgroup_memory_length(index, byte_length);
1262        }
1263        let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1264        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
1265        self.sample_dispatch_post(encoder, pre_idx);
1266    }
1267
1268    // -----------------------------------------------------------------
1269    // ADR-015 iter37 — dataflow-driven auto-barrier dispatch family.
1270    //
1271    // These methods mirror the `encode_*` dispatch family but take explicit
1272    // `reads: &[&MlxBuffer]` and `writes: &[&MlxBuffer]` slices.  When
1273    // the process started with `HF2Q_AUTO_BARRIER=1`, the encoder's
1274    // [`MemRanges`] tracker checks the new ranges against the
1275    // cumulative state since the last barrier; on conflict it emits
1276    // `memory_barrier()` and resets the state before recording the
1277    // new ranges.  When the env gate is unset, the check is skipped
1278    // entirely and the dispatch is applied identically to the
1279    // matching `encode_*` method — sourdough-safe by construction.
1280    //
1281    // Capture mode: the `reads`/`writes` ranges are recorded onto the
1282    // captured node via the existing `pending_reads`/`pending_writes`
1283    // mechanism, so a `dispatch_tracked` call inside capture mode is
1284    // equivalent to `set_pending_buffer_ranges + encode_*`.
1285    //
1286    // No production callsite migrates in iter37 — this is the API
1287    // surface the qwen35 forward path will adopt incrementally in
1288    // iter38+.  Today, every call to `dispatch_tracked` from a
1289    // production code path lives behind an explicit caller decision
1290    // to opt in.
1291    // -----------------------------------------------------------------
1292
1293    /// Auto-barrier-aware dispatch with [`KernelArg`] bindings (uses
1294    /// `dispatch_thread_groups`).
1295    ///
1296    /// Behaves identically to
1297    /// [`encode_threadgroups_with_args`](Self::encode_threadgroups_with_args)
1298    /// when `HF2Q_AUTO_BARRIER` is unset.  When set, consults the
1299    /// per-encoder [`MemRanges`] tracker:
1300    ///
1301    /// * Conflict (RAW/WAR/WAW on a same-buffer range) → emit
1302    ///   `memory_barrier()`, increment [`AUTO_BARRIER_COUNT`], reset
1303    ///   the tracker, then dispatch and seed the new concurrent group
1304    ///   with this dispatch's ranges.
1305    /// * No conflict → increment [`AUTO_BARRIER_CONCURRENT`], record
1306    ///   the ranges into the cumulative state, dispatch.
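    ///
    /// # Example (sketch)
    ///
    /// A sketch with hypothetical `pipeline` / `x` / `y` names: a tracked
    /// dispatch that reads `x` and writes `y`, leaving the barrier
    /// decision to the tracker.
    ///
    /// ```ignore
    /// enc.dispatch_tracked_threadgroups_with_args(
    ///     &pipeline,
    ///     &[(0, KernelArg::Buffer(&x)), (1, KernelArg::Buffer(&y))],
    ///     &[&x],                        // reads
    ///     &[&y],                        // writes
    ///     MTLSize::new(n_groups, 1, 1),
    ///     MTLSize::new(256, 1, 1),
    /// );
    /// ```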
1307    pub fn dispatch_tracked_threadgroups_with_args(
1308        &mut self,
1309        pipeline: &ComputePipelineStateRef,
1310        bindings: &[(u64, KernelArg<'_>)],
1311        reads: &[&MlxBuffer],
1312        writes: &[&MlxBuffer],
1313        threadgroups: MTLSize,
1314        threadgroup_size: MTLSize,
1315    ) {
1316        // Capture mode: stash ranges + delegate to the standard encode.
1317        // The ranges flow through `pending_reads`/`pending_writes` and
1318        // attach to the captured `Dispatch` node — identical to what
1319        // `GraphSession::barrier_between` already does in capture mode.
1320        if self.is_capturing() {
1321            let read_ranges = ranges_from_buffers(reads);
1322            let write_ranges = ranges_from_buffers(writes);
1323            self.set_pending_buffer_ranges(read_ranges, write_ranges);
1324            self.encode_threadgroups_with_args(pipeline, bindings, threadgroups, threadgroup_size);
1325            return;
1326        }
1327
1328        if auto_barrier_enabled() {
1329            self.maybe_auto_barrier(reads, writes);
1330        }
1331
1332        self.encode_threadgroups_with_args(pipeline, bindings, threadgroups, threadgroup_size);
1333    }
1334
1335    /// Auto-barrier-aware dispatch with [`KernelArg`] bindings + shared
1336    /// threadgroup memory.
1337    ///
1338    /// See [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args)
1339    /// for the behavioral contract; this variant additionally takes a
1340    /// `threadgroup_mem` slice that is forwarded to
1341    /// [`encode_threadgroups_with_args_and_shared`](Self::encode_threadgroups_with_args_and_shared).
1342    ///
1343    /// The 8-argument signature mirrors the existing
1344    /// `encode_threadgroups_with_args_and_shared` plus the two
1345    /// dataflow slices; `clippy::too_many_arguments` is allowed
1346    /// because each parameter is load-bearing for either the dispatch
1347    /// (pipeline/bindings/threadgroups/threadgroup_size/shared_mem)
1348    /// or the auto-barrier (reads/writes).
1349    #[allow(clippy::too_many_arguments)]
1350    pub fn dispatch_tracked_threadgroups_with_args_and_shared(
1351        &mut self,
1352        pipeline: &ComputePipelineStateRef,
1353        bindings: &[(u64, KernelArg<'_>)],
1354        threadgroup_mem: &[(u64, u64)],
1355        reads: &[&MlxBuffer],
1356        writes: &[&MlxBuffer],
1357        threadgroups: MTLSize,
1358        threadgroup_size: MTLSize,
1359    ) {
1360        if self.is_capturing() {
1361            let read_ranges = ranges_from_buffers(reads);
1362            let write_ranges = ranges_from_buffers(writes);
1363            self.set_pending_buffer_ranges(read_ranges, write_ranges);
1364            self.encode_threadgroups_with_args_and_shared(
1365                pipeline,
1366                bindings,
1367                threadgroup_mem,
1368                threadgroups,
1369                threadgroup_size,
1370            );
1371            return;
1372        }
1373
1374        if auto_barrier_enabled() {
1375            self.maybe_auto_barrier(reads, writes);
1376        }
1377
1378        self.encode_threadgroups_with_args_and_shared(
1379            pipeline,
1380            bindings,
1381            threadgroup_mem,
1382            threadgroups,
1383            threadgroup_size,
1384        );
1385    }
1386
1387    /// Auto-barrier-aware dispatch using `(slot, &MlxBuffer)` bindings
1388    /// (uses `dispatch_thread_groups`).
1389    ///
1390    /// Convenience wrapper for callers that don't need
1391    /// [`KernelArg::Bytes`] inline-byte arguments.  See
1392    /// [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args)
1393    /// for behavioral contract.
1394    pub fn dispatch_tracked_threadgroups(
1395        &mut self,
1396        pipeline: &ComputePipelineStateRef,
1397        buffers: &[(u64, &MlxBuffer)],
1398        reads: &[&MlxBuffer],
1399        writes: &[&MlxBuffer],
1400        threadgroups: MTLSize,
1401        threadgroup_size: MTLSize,
1402    ) {
1403        if self.is_capturing() {
1404            let read_ranges = ranges_from_buffers(reads);
1405            let write_ranges = ranges_from_buffers(writes);
1406            self.set_pending_buffer_ranges(read_ranges, write_ranges);
1407            self.encode_threadgroups(pipeline, buffers, threadgroups, threadgroup_size);
1408            return;
1409        }
1410
1411        if auto_barrier_enabled() {
1412            self.maybe_auto_barrier(reads, writes);
1413        }
1414
1415        self.encode_threadgroups(pipeline, buffers, threadgroups, threadgroup_size);
1416    }
1417
1418    /// Auto-barrier-aware dispatch using `(slot, &MlxBuffer)` bindings
1419    /// **plus shared threadgroup memory** (uses `dispatch_thread_groups`).
1420    ///
1421    /// Mirrors [`encode_threadgroups_with_shared`](Self::encode_threadgroups_with_shared)
1422    /// — convenience variant for kernels that allocate threadgroup
1423    /// memory (reductions in `rms_norm`, `softmax`, etc.) but don't
1424    /// need [`KernelArg::Bytes`] inline-byte arguments.  See
1425    /// [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args)
1426    /// for the behavioral contract; the only addition here is the
1427    /// `threadgroup_mem` slice forwarded to the underlying encode.
1428    ///
1429    /// Closes the iter38-audit coverage gap: the 5 `rms_norm.rs`
1430    /// callsites (`/opt/mlx-native/src/ops/rms_norm.rs:124,236,443,
1431    /// 516,589`) all use `encode_threadgroups_with_shared` and need
1432    /// dataflow tracking when migrated to auto-barrier in iter40+.
1433    ///
1434    /// 8-argument signature (counting `self`, matching the variant above);
1435    /// because each parameter is load-bearing for either the dispatch
1436    /// (pipeline/buffers/threadgroups/threadgroup_size/shared_mem) or
1437    /// the auto-barrier (reads/writes).
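    ///
    /// # Example (sketch)
    ///
    /// A hypothetical reduction kernel with one threadgroup-memory slot of
    /// 256 × 4 bytes at MSL threadgroup index 0; `pipeline` / `x` / `y`
    /// are placeholder names.
    ///
    /// ```ignore
    /// enc.dispatch_tracked_threadgroups_with_shared(
    ///     &pipeline,
    ///     &[(0, &x), (1, &y)],
    ///     &[(0, 256 * 4)],          // (threadgroup index, byte length)
    ///     &[&x],                    // reads
    ///     &[&y],                    // writes
    ///     MTLSize::new(rows, 1, 1),
    ///     MTLSize::new(256, 1, 1),
    /// );
    /// ```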
1438    #[allow(clippy::too_many_arguments)]
1439    pub fn dispatch_tracked_threadgroups_with_shared(
1440        &mut self,
1441        pipeline: &ComputePipelineStateRef,
1442        buffers: &[(u64, &MlxBuffer)],
1443        threadgroup_mem: &[(u64, u64)],
1444        reads: &[&MlxBuffer],
1445        writes: &[&MlxBuffer],
1446        threadgroups: MTLSize,
1447        threadgroup_size: MTLSize,
1448    ) {
1449        if self.is_capturing() {
1450            let read_ranges = ranges_from_buffers(reads);
1451            let write_ranges = ranges_from_buffers(writes);
1452            self.set_pending_buffer_ranges(read_ranges, write_ranges);
1453            self.encode_threadgroups_with_shared(
1454                pipeline,
1455                buffers,
1456                threadgroup_mem,
1457                threadgroups,
1458                threadgroup_size,
1459            );
1460            return;
1461        }
1462
1463        if auto_barrier_enabled() {
1464            self.maybe_auto_barrier(reads, writes);
1465        }
1466
1467        self.encode_threadgroups_with_shared(
1468            pipeline,
1469            buffers,
1470            threadgroup_mem,
1471            threadgroups,
1472            threadgroup_size,
1473        );
1474    }
1475
1476    /// Auto-barrier-aware `dispatch_threads` variant with
1477    /// [`KernelArg`] bindings.
1478    ///
1479    /// Mirrors [`encode_with_args`](Self::encode_with_args) — the
1480    /// `dispatch_threads` (per-thread grid) flavor, as opposed to the
1481    /// `dispatch_thread_groups` flavor of
1482    /// [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args).
1483    /// See that method for the behavioral contract.
1484    ///
1485    /// Closes the iter38-audit coverage gap: callers that use
1486    /// per-thread grids — `rope.rs:108` (IMROPE), `sigmoid_mul.rs:76`
1487    /// (sigmoid-mul), and `encode_helpers.rs:41` (kv_cache_copy) —
1488    /// need a `dispatch_threads` flavor of the tracked dispatch
1489    /// because their grid sizes are expressed in threads, not
1490    /// threadgroups.
1491    ///
1492    /// Note: the simpler `(slot, &MlxBuffer)` form (from
1493    /// [`encode`](Self::encode)) is a special case of this method —
1494    /// callers can wrap each binding as `KernelArg::Buffer(buf)` to
1495    /// reuse this single tracked variant rather than introducing a
1496    /// fifth one.
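    ///
    /// # Example (sketch)
    ///
    /// Wrapping plain `(slot, &MlxBuffer)` bindings as
    /// [`KernelArg::Buffer`] per the note above; all names are
    /// illustrative.
    ///
    /// ```ignore
    /// enc.dispatch_tracked_threads_with_args(
    ///     &pipeline,
    ///     &[(0, KernelArg::Buffer(&src)), (1, KernelArg::Buffer(&dst))],
    ///     &[&src],
    ///     &[&dst],
    ///     MTLSize::new(n_elems, 1, 1), // grid in threads, not threadgroups
    ///     MTLSize::new(256, 1, 1),
    /// );
    /// ```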
1497    pub fn dispatch_tracked_threads_with_args(
1498        &mut self,
1499        pipeline: &ComputePipelineStateRef,
1500        bindings: &[(u64, KernelArg<'_>)],
1501        reads: &[&MlxBuffer],
1502        writes: &[&MlxBuffer],
1503        grid_size: MTLSize,
1504        threadgroup_size: MTLSize,
1505    ) {
1506        if self.is_capturing() {
1507            let read_ranges = ranges_from_buffers(reads);
1508            let write_ranges = ranges_from_buffers(writes);
1509            self.set_pending_buffer_ranges(read_ranges, write_ranges);
1510            self.encode_with_args(pipeline, bindings, grid_size, threadgroup_size);
1511            return;
1512        }
1513
1514        if auto_barrier_enabled() {
1515            self.maybe_auto_barrier(reads, writes);
1516        }
1517
1518        self.encode_with_args(pipeline, bindings, grid_size, threadgroup_size);
1519    }
1520
1521    /// Run the dataflow check, emit a barrier on conflict, and record
1522    /// the dispatch's ranges into the cumulative state.
1523    ///
1524    /// Always called *before* the underlying `encode_*` method
1525    /// applies the dispatch.  Mirrors lines 220-225 of
1526    /// `ggml-metal-ops.cpp` (`concurrency_check + concurrency_reset +
1527    /// concurrency_add` around each node).
1528    fn maybe_auto_barrier(
1529        &mut self,
1530        reads: &[&MlxBuffer],
1531        writes: &[&MlxBuffer],
1532    ) {
1533        if self.mem_ranges.check_dispatch(reads, writes) {
1534            // Concurrent — no barrier needed; just record the new ranges.
1535            self.mem_ranges.add_dispatch(reads, writes);
1536            AUTO_BARRIER_CONCURRENT.fetch_add(1, Ordering::Relaxed);
1537        } else {
1538            // Conflict — emit barrier, reset state, seed new group.
1539            //
1540            // `memory_barrier()` itself increments `BARRIER_COUNT` and,
1541            // when `MLX_PROFILE_BARRIERS=1`, accumulates `BARRIER_NS`.
1542            // We additionally bump `AUTO_BARRIER_COUNT` so the
1543            // "auto-emitted vs hand-placed" subset is queryable.
1544            self.memory_barrier();
1545            self.mem_ranges.reset();
1546            self.mem_ranges.add_dispatch(reads, writes);
1547            AUTO_BARRIER_COUNT.fetch_add(1, Ordering::Relaxed);
1548        }
1549    }
1550
1551    /// Force a barrier and reset the auto-barrier tracker.
1552    ///
1553    /// Use at boundaries where the caller knows a barrier is required
1554    /// regardless of dataflow — typically before reading data back to
1555    /// CPU, or at the end of an op group whose internal dependencies
1556    /// the tracker can't see (e.g. host-driven memcpy).
1557    ///
1558    /// Equivalent to `memory_barrier()` plus a `MemRanges::reset()`
1559    /// when `HF2Q_AUTO_BARRIER=1`; equivalent to plain
1560    /// `memory_barrier()` otherwise.
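    ///
    /// # Example (sketch)
    ///
    /// A typical readback boundary; `y` is a placeholder output buffer.
    ///
    /// ```ignore
    /// enc.force_barrier_and_reset_tracker();
    /// enc.commit_and_wait()?;
    /// // ...safe to read `y` back on the CPU now.
    /// ```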
1561    pub fn force_barrier_and_reset_tracker(&mut self) {
1562        self.memory_barrier();
1563        if auto_barrier_enabled() {
1564            self.mem_ranges.reset();
1565        }
1566    }
1567
1568    /// Diagnostic accessor — number of ranges currently recorded in
1569    /// this encoder's [`MemRanges`] tracker.  Always zero unless
1570    /// `HF2Q_AUTO_BARRIER=1` and at least one `dispatch_tracked*` call
1571    /// has fired since the last conflict or tracker reset.
1572    #[inline]
1573    pub fn mem_ranges_len(&self) -> usize {
1574        self.mem_ranges.len()
1575    }
1576
1577    /// Replay a single captured dispatch node into this encoder.
1578    ///
1579    /// This is the inverse of capture: it takes a previously recorded
1580    /// `CapturedNode::Dispatch` and encodes it into the live Metal encoder.
1581    /// Barrier nodes are handled by the caller (ComputeGraph::encode_sequential).
1582    ///
1583    /// Does NOT increment `DISPATCH_COUNT` — that was already counted at
1584    /// capture time.
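    ///
    /// # Example (sketch)
    ///
    /// Replaying one recorded node; the field names follow
    /// `CapturedNode::Dispatch` as captured above, and `node` is a
    /// hypothetical stand-in for an entry from `take_capture()`.
    ///
    /// ```ignore
    /// // The graph replay layer also forwards the captured op_kind via
    /// // `pending_op_kind` before this call.
    /// if let CapturedNode::Dispatch {
    ///     pipeline, bindings, threadgroup_memory,
    ///     threads_per_grid, threads_per_threadgroup, dispatch_kind, ..
    /// } = node {
    ///     enc.replay_dispatch(
    ///         &pipeline, &bindings, &threadgroup_memory,
    ///         threads_per_grid, threads_per_threadgroup, dispatch_kind,
    ///     );
    /// }
    /// ```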
1585    pub fn replay_dispatch(
1586        &mut self,
1587        pipeline: &ComputePipelineStateRef,
1588        bindings: &[(u64, RecordedBinding)],
1589        threadgroup_memory: &[(u64, u64)],
1590        threads_per_grid: MTLSize,
1591        threads_per_threadgroup: MTLSize,
1592        dispatch_kind: DispatchKind,
1593    ) {
1594        // ADR-015 iter63 (Phase A.3): mirror the per-dispatch sampling
1595        // scaffold here so capture-mode-recorded graphs (graph.rs
1596        // encode_sequential / encode_with_barriers /
1597        // encode_chunk_with_barriers) still produce per-dispatch entries.  The replay
1598        // path bypasses encode*; without this hook the per-dispatch
1599        // table would be silently empty for any model that uses
1600        // `GraphExecutor::begin_recorded`.
1601        //
1602        // Captured `op_kind` is forwarded via `pending_op_kind`: the
1603        // graph replay layer at graph.rs:197/236/727 sets it from the
1604        // CapturedNode.op_kind before calling replay_dispatch.
1605        self.ensure_sample_buffer();
1606        let op_kind = self.take_pending_op_kind();
1607        let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1608        // SAFETY: see encode() above.
1609        let encoder = unsafe { &*encoder_ptr };
1610        encoder.set_compute_pipeline_state(pipeline);
1611        for (index, binding) in bindings {
1612            match binding {
1613                RecordedBinding::Buffer { metal_buffer, offset } => {
1614                    encoder.set_buffer(*index, Some(metal_buffer), *offset);
1615                }
1616                RecordedBinding::Bytes(bytes) => {
1617                    encoder.set_bytes(
1618                        *index,
1619                        bytes.len() as u64,
1620                        bytes.as_ptr() as *const _,
1621                    );
1622                }
1623            }
1624        }
1625        for &(index, byte_length) in threadgroup_memory {
1626            encoder.set_threadgroup_memory_length(index, byte_length);
1627        }
1628        let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1629        match dispatch_kind {
1630            DispatchKind::Threads => {
1631                encoder.dispatch_threads(threads_per_grid, threads_per_threadgroup);
1632            }
1633            DispatchKind::ThreadGroups => {
1634                encoder.dispatch_thread_groups(threads_per_grid, threads_per_threadgroup);
1635            }
1636        }
1637        self.sample_dispatch_post(encoder, pre_idx);
1638    }
1639
1640    /// Flush any pending residency-set add/remove staging.
1641    ///
1642    /// Hooked at every commit boundary so per-allocation
1643    /// [`ResidencySet::add_allocation`](ResidencySet::add_allocation) and
1644    /// [`ResidencySet::remove_allocation`](ResidencySet::remove_allocation)
1645    /// calls (as fired by `MlxDevice::alloc_buffer` and
1646    /// `MlxBufferStorage::Drop`) collapse into at most ONE `[set commit]`
1647    /// per CB submission. Mirrors llama.cpp's
1648    /// `ggml-metal-device.m:1378-1382` (batch addAllocation in loop,
1649    /// commit ONCE).
1650    #[inline]
1651    fn flush_residency_pending(&self) {
1652        if let Some(set) = self.residency_set.as_ref() {
1653            set.flush_pending();
1654        }
1655    }
1656
1657    // ----------------------------------------------------------------
1658    // ADR-015 iter63 — per-dispatch sample buffer lifecycle
1659    // ----------------------------------------------------------------
1660
1661    /// Allocate the per-CB `MTLCounterSampleBuffer` if it has not been
1662    /// allocated yet for this CB.
1663    ///
1664    /// No-op when `MLX_PROFILE_DISPATCH` is unset, when the buffer is
1665    /// already present, or when the device does not expose a counter
1666    /// set named `"timestamp"` (Risk R1 — graceful degrade with a
1667    /// one-shot stderr warning).
1668    ///
1669    /// The sample buffer is sized to [`MAX_SAMPLES_PER_CB`] (32_768).
1670    /// This is the start + end pair budget — i.e. ≤ 16,384 dispatches
1671    /// per CB.  Above that ceiling, additional dispatches will skip
1672    /// sampling (see [`Self::sample_dispatch_pre`]).
1673    #[inline]
1674    fn ensure_sample_buffer(&mut self) {
1675        if !crate::kernel_profile::is_dispatch_enabled() {
1676            return;
1677        }
1678        if self.sample_buffer.is_some() {
1679            return;
1680        }
1681        // Discover the timestamp counter set.  metal-rs 0.33 does not
1682        // export the `MTLCommonCounterSetTimestamp` constant, so we
1683        // name-match `"timestamp"` case-insensitively.  Reach the
1684        // device via the cmd_buf's `device` selector (metal-rs 0.33
1685        // exposes `CommandQueue::device` but not `CommandBuffer::device`,
1686        // so we go through ObjC directly).
1687        let device: &metal::DeviceRef = unsafe {
1688            let cb = &*self.cmd_buf;
1689            msg_send![cb, device]
1690        };
1691        // ADR-015 iter63 — Apple Silicon hardware constraint (NEW Risk
1692        // discovered at impl time, supersedes design §A.7).  M-series
1693        // GPUs (verified: AGXG17XFamilyComputeContext = M5 Max series,
1694        // macOS 26) only support counter sampling AtStageBoundary —
1695        // i.e. between compute *passes*, not between dispatches inside
1696        // a persistent compute encoder.  Calling
1697        // `sampleCountersInBuffer:atSampleIndex:withBarrier:` on such
1698        // hardware aborts with `failed assertion ... not supported on
1699        // this device`.  The persistent-encoder design (mlx-native uses
1700        // ONE compute encoder per CB to amortize ~800 encoder
1701        // create/end cycles per forward pass — see the
1702        // `get_or_create_encoder` docstring) is incompatible with stage-boundary-only
1703        // sampling, so on Apple Silicon we degrade per-dispatch
1704        // profiling to a no-op and log once.  Per-CB profiling is
1705        // unaffected (it uses MTLCommandBuffer.GPUStartTime/
1706        // GPUEndTime, which are always available).
1707        //
1708        // Future: if Apple ever ships AtDispatchBoundary support on
1709        // Apple Silicon, this branch becomes a true cap check.  For
1710        // now, the kit infrastructure is in place; only the sampling-
1711        // point support is missing.
1712        if !device.supports_counter_sampling(MTLCounterSamplingPoint::AtDispatchBoundary) {
1713            if TIMESTAMP_SET_WARN_LOGGED
1714                .compare_exchange(0, 1, Ordering::Relaxed, Ordering::Relaxed)
1715                .is_ok()
1716            {
1717                eprintln!(
1718                    "[mlx-native] MLX_PROFILE_DISPATCH=1 ignored: \
1719                     device {:?} does NOT support \
1720                     MTLCounterSamplingPointAtDispatchBoundary \
1721                     (Apple Silicon limitation; only AtStageBoundary \
1722                     is supported, which is incompatible with the \
1723                     persistent compute-encoder pattern). \
1724                     MLX_PROFILE_CB=1 still produces per-CB GPU times.",
1725                    device.name()
1726                );
1727            }
1728            return;
1729        }
1730        let counter_sets = device.counter_sets();
1731        let timestamp_set = counter_sets
1732            .iter()
1733            .find(|c: &&metal::CounterSet| c.name().eq_ignore_ascii_case("timestamp"));
1734        let timestamp_set = match timestamp_set {
1735            Some(s) => s,
1736            None => {
1737                // Risk R1: device does not expose a timestamp set.
1738                // Log once and degrade to no-op (sample_buffer stays None).
1739                if TIMESTAMP_SET_WARN_LOGGED
1740                    .compare_exchange(0, 1, Ordering::Relaxed, Ordering::Relaxed)
1741                    .is_ok()
1742                {
1743                    eprintln!(
1744                        "[mlx-native] MLX_PROFILE_DISPATCH=1 ignored: \
1745                         device {:?} exposes no MTLCommonCounterSetTimestamp",
1746                        device.name()
1747                    );
1748                }
1749                return;
1750            }
1751        };
1752        // Build descriptor.  StorageMode::Shared is required by
1753        // resolveCounterRange (MTLCounters.h:185-188).
1754        let descriptor = CounterSampleBufferDescriptor::new();
1755        descriptor.set_counter_set(timestamp_set);
1756        descriptor.set_storage_mode(MTLStorageMode::Shared);
1757        descriptor.set_label("mlx_native.dispatch_samples");
1758        descriptor.set_sample_count(MAX_SAMPLES_PER_CB);
1759        match device.new_counter_sample_buffer_with_descriptor(&descriptor) {
1760            Ok(buf) => {
1761                self.sample_buffer = Some(buf);
1762            }
1763            Err(e) => {
1764                if TIMESTAMP_SET_WARN_LOGGED
1765                    .compare_exchange(0, 1, Ordering::Relaxed, Ordering::Relaxed)
1766                    .is_ok()
1767                {
1768                    eprintln!(
1769                        "[mlx-native] MLX_PROFILE_DISPATCH=1 ignored: \
1770                         newCounterSampleBufferWithDescriptor failed: {}",
1771                        e
1772                    );
1773                }
1774                self.sample_buffer = None;
1775            }
1776        }
1777    }
1778
1779    /// Insert the start-of-dispatch counter sample (sample index `2*i`)
1780    /// and queue the per-dispatch metadata.  Returns the dispatch
1781    /// ordinal `i` so the caller can emit the matching post-sample.
1782    ///
1783    /// Returns `None` when sampling is inactive or the per-CB sample
1784    /// budget is exhausted; the caller passes the value through to
1785    /// [`Self::sample_dispatch_post`], which no-ops on `None`.
1786    ///
1787    /// `with_barrier:true` is mandatory: the encoder uses
1788    /// `MTLDispatchTypeConcurrent` and without the barrier the start
1789    /// timestamp would race against any in-flight dispatch (PROFILING-
1790    /// KIT-DESIGN §A.5).
1791    #[inline]
1792    fn sample_dispatch_pre(
1793        &mut self,
1794        encoder: &ComputeCommandEncoderRef,
1795        op_kind: CapturedOpKind,
1796    ) -> Option<u32> {
1797        let sb = self.sample_buffer.as_ref()?;
1798        let i = self.dispatch_in_cb;
1799        let pre_idx = (i as u64).checked_mul(2)?;
1800        if pre_idx >= MAX_SAMPLES_PER_CB {
1801            // Ceiling exceeded — skip sampling for the remainder of
1802            // this CB.  Risk R4 (PROFILING-KIT-DESIGN §A.7): future
1803            // iter can chunk-resolve every N dispatches; for now we
1804            // accept truncation with a one-shot warning (re-uses the
1805            // R1 warn flag).
1806            return None;
1807        }
1808        encoder.sample_counters_in_buffer(sb, pre_idx, true);
1809        self.pending_dispatch_meta.push(PendingDispatchMeta {
1810            op_kind: op_kind.name(),
1811            dispatch_index: i,
1812        });
1813        Some(i)
1814    }
1815
1816    /// Insert the end-of-dispatch counter sample (sample index `2*i+1`)
1817    /// matching the most recent [`Self::sample_dispatch_pre`].
1818    ///
1819    /// No-op when sampling is inactive or when `pre_idx` is `None`.
1820    #[inline]
1821    fn sample_dispatch_post(
1822        &mut self,
1823        encoder: &ComputeCommandEncoderRef,
1824        pre_idx: Option<u32>,
1825    ) {
1826        let i = match pre_idx {
1827            Some(v) => v,
1828            None => return,
1829        };
1830        let sb = match self.sample_buffer.as_ref() {
1831            Some(b) => b,
1832            None => return,
1833        };
1834        let post_idx = match (i as u64).checked_mul(2).and_then(|v| v.checked_add(1)) {
1835            Some(v) if v < MAX_SAMPLES_PER_CB => v,
1836            _ => return,
1837        };
1838        encoder.sample_counters_in_buffer(sb, post_idx, true);
1839        // Bump the per-CB ordinal only after both samples committed
1840        // successfully so a truncation skip leaves the meta queue
1841        // length matching the buffer's resolved range.
1842        self.dispatch_in_cb = i.saturating_add(1);
1843    }
1844
1845    /// Resolve the per-CB sample buffer, push entries into
1846    /// [`crate::kernel_profile`], and reset per-CB state.
1847    ///
1848    /// Called from [`Self::commit_and_wait_labeled`] after the CB
1849    /// completes; the caller is responsible for ensuring the GPU has
1850    /// finished (otherwise `resolveCounterRange` returns garbage).
1851    ///
1852    /// On the first resolve after a [`crate::kernel_profile::reset`],
1853    /// also captures a `(cpu_ns, gpu_ticks)` pair via
1854    /// `device.sampleTimestamps` so subsequent ticks→ns conversion
1855    /// uses a fresh scale factor.
1856    fn resolve_dispatch_samples(&mut self, cb_label: &str) -> Result<()> {
1857        let sb = match self.sample_buffer.take() {
1858            Some(b) => b,
1859            None => {
1860                self.pending_dispatch_meta.clear();
1861                self.dispatch_in_cb = 0;
1862                return Ok(());
1863            }
1864        };
1865        let n = self.pending_dispatch_meta.len();
1866        if n == 0 {
1867            self.dispatch_in_cb = 0;
1868            return Ok(());
1869        }
1870        // Refresh the (cpu, gpu) scale pair on every resolve; the
1871        // device call is cheap and keeps us robust against driver-side
1872        // timebase changes between CBs.
1873        let mut cpu_t: u64 = 0;
1874        let mut gpu_t: u64 = 0;
1875        let device: &metal::DeviceRef = unsafe {
1876            let cb = &*self.cmd_buf;
1877            msg_send![cb, device]
1878        };
1879        device.sample_timestamps(&mut cpu_t, &mut gpu_t);
1880        crate::kernel_profile::record_clock_pair(cpu_t, gpu_t);
1881        let length = (n as u64).saturating_mul(2);
1882        let data = sb.resolve_counter_range(NSRange {
1883            location: 0,
1884            length,
1885        });
1886        // `resolve_counter_range` returns one NSUInteger per sample.
1887        // Pair them up: data[2i] = start, data[2i+1] = end.
1888        for (i, meta) in self.pending_dispatch_meta.drain(..).enumerate() {
1889            let start_idx = 2 * i;
1890            let end_idx = 2 * i + 1;
1891            if end_idx >= data.len() {
1892                break;
1893            }
1894            let start_raw = data[start_idx] as u64;
1895            let end_raw = data[end_idx] as u64;
1896            let start_ns = crate::kernel_profile::convert_gpu_ticks_to_ns(start_raw);
1897            let end_ns = crate::kernel_profile::convert_gpu_ticks_to_ns(end_raw);
1898            let gpu_ns = end_ns.saturating_sub(start_ns);
1899            crate::kernel_profile::record_dispatch(
1900                crate::kernel_profile::DispatchEntry {
1901                    cb_label: cb_label.to_string(),
1902                    op_kind: meta.op_kind,
1903                    dispatch_index: meta.dispatch_index,
1904                    gpu_ns,
1905                    start_gpu_ns: start_ns,
1906                    end_gpu_ns: end_ns,
1907                },
1908            );
1909        }
1910        // Buffer dropped at end of scope releases the underlying
1911        // CounterSampleBuffer; per-CB lifetime correctly bounded.
1912        drop(sb);
1913        self.dispatch_in_cb = 0;
1914        Ok(())
1915    }
1916
1917    /// Commit the command buffer and block until the GPU finishes execution.
1918    ///
1919    /// # Errors
1920    ///
1921    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
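    ///
    /// # Example (sketch)
    ///
    /// Batching two dispatches with an explicit barrier, then submitting;
    /// all names are illustrative.
    ///
    /// ```ignore
    /// enc.encode_threadgroups(&pipe_a, &[(0, &x), (1, &y)], groups, tg);
    /// enc.memory_barrier(); // `y` is read by the next kernel
    /// enc.encode_threadgroups(&pipe_b, &[(0, &y), (1, &z)], groups, tg);
    /// enc.commit_and_wait()?;
    /// ```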
1922    pub fn commit_and_wait(&mut self) -> Result<()> {
1923        SYNC_COUNT.fetch_add(1, Ordering::Relaxed);
1924
1925        // End the persistent compute encoder before committing.
1926        self.end_active_encoder();
1927
1928        // ADR-015 iter8e (Phase 3b): flush deferred residency-set
1929        // add/remove staging so the residency hint covers any buffers
1930        // referenced by this CB. Single commit per CB boundary; no-op
1931        // when no residency set or no staged changes.
1932        self.flush_residency_pending();
1933
1934        self.cmd_buf.commit();
1935        self.cmd_buf.wait_until_completed();
1936
1937        match self.cmd_buf.status() {
1938            MTLCommandBufferStatus::Completed => Ok(()),
1939            MTLCommandBufferStatus::Error => {
1940                Err(MlxError::CommandBufferError(
1941                    "GPU command buffer completed with error status".into(),
1942                ))
1943            }
1944            status => Err(MlxError::CommandBufferError(format!(
1945                "Unexpected command buffer status after wait: {:?}",
1946                status
1947            ))),
1948        }
1949    }
1950
1951    /// Commit + wait, accumulating GPU wall-clock time under `label` into
1952    /// the [`crate::kernel_profile`] global table when `MLX_PROFILE_CB=1`
1953    /// is set.  When the env var is unset, this is identical to
1954    /// [`commit_and_wait`](Self::commit_and_wait) — zero overhead.
1955    ///
1956    /// Used by hf2q's decode hot path to attribute per-CB GPU time to
1957    /// labeled phases (per-layer attn, per-layer ffn, output_head, etc.)
1958    /// without manually wiring `commit_wait_with_gpu_time` everywhere.
1959    ///
1960    /// # Errors
1961    ///
1962    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
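    ///
    /// # Example (sketch)
    ///
    /// With `MLX_PROFILE_CB=1`, GPU wall-clock time for this CB accrues
    /// under the (hypothetical) `"layer.attn"` phase label.
    ///
    /// ```ignore
    /// enc.commit_and_wait_labeled("layer.attn")?;
    /// ```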
1963    pub fn commit_and_wait_labeled(&mut self, label: &str) -> Result<()> {
1964        // ADR-015 iter16 — propagate `label` to MTLCommandBuffer.setLabel and
1965        // (if a compute encoder is active) MTLComputeCommandEncoder.setLabel
1966        // BEFORE end_encoding/commit so xctrace's
1967        // `metal-application-encoders-list` table populates `cmdbuffer-label`
1968        // and `encoder-label` columns with the semantic phase name (e.g.
1969        // `layer.attn_moe_ffn`, `output_head.fused_norm_lm_argmax`,
1970        // `layer.delta_net.ops1-9`).  Joined to per-CB GPU duration via
1971        // `metal-gpu-submission-to-command-buffer-id` (sub_id ↔ encoder_id) →
1972        // `metal-gpu-execution-points` (per-dispatch start/end), this enables
1973        // per-phase µs/token attribution comparing hf2q vs llama side-by-side
1974        // (iter15 §E "iter16 ATTRIBUTION PATH").  Cost is a single ObjC
1975        // msg_send per CB submission — sub-µs on M5 Max — and a no-op when
1976        // xctrace isn't recording, so this is unconditionally safe to call on
1977        // the production decode hot path.
1978        self.apply_labels(label);
1979        // ADR-015 iter63: record GPU time AND resolve per-dispatch samples
1980        // when either env gate is set.  Per-dispatch sampling force-enables
1981        // the per-CB path so cross-validation per Risk R3 always has a
1982        // ground-truth comparator.
1983        let need_gpu_time =
1984            crate::kernel_profile::is_enabled() || crate::kernel_profile::is_dispatch_enabled();
1985        if need_gpu_time {
1986            let (start_s, end_s) = self.commit_wait_with_gpu_time()?;
1987            let ns = ((end_s - start_s).max(0.0) * 1_000_000_000.0) as u64;
1988            if crate::kernel_profile::is_enabled() {
1989                crate::kernel_profile::record(label, ns);
1990            }
1991            if crate::kernel_profile::is_dispatch_enabled() {
1992                self.resolve_dispatch_samples(label)?;
1993            }
1994            Ok(())
1995        } else {
1996            self.commit_and_wait()
1997        }
1998    }
1999
2000    /// Async commit, but with profiling label.  When `MLX_PROFILE_CB=1`
2001    /// is set, redirects to a synchronous [`commit_and_wait_labeled`]
2002    /// call to capture per-CB GPU time.  This defeats async pipelining
2003    /// while profiling, which is the accepted trade-off: profile mode is
2004    /// slow but informative.  When unset, identical to [`commit`](Self::commit).
2005    pub fn commit_labeled(&mut self, label: &str) {
2006        // ADR-015 iter16 — see `commit_and_wait_labeled` for rationale.
2007        if crate::kernel_profile::is_enabled() {
2008            // Profile mode: force sync to capture GPU time.  apply_labels is
2009            // called inside commit_and_wait_labeled — do NOT call it twice
2010            // here (would double the ObjC msg_send under MLX_PROFILE_CB=1).
2011            // Errors are logged via stderr because the void return matches
2012            // commit().
2013            if let Err(e) = self.commit_and_wait_labeled(label) {
2014                eprintln!("[mlx-native] commit_labeled({}) failed: {}", label, e);
2015            }
2016        } else {
2017            // Async path: apply labels here so xctrace MST traces capture
2018            // per-CB phase attribution under default decode (no
2019            // `MLX_PROFILE_CB`).
2020            self.apply_labels(label);
2021            self.commit();
2022        }
2023    }
2024
2025    /// Apply `label` to the underlying `MTLCommandBuffer` and, if a compute
2026    /// encoder is currently active, to the `MTLComputeCommandEncoder`.
2027    ///
2028    /// Called from [`commit_labeled`] and [`commit_and_wait_labeled`] BEFORE
2029    /// the encoder is ended / the CB is committed so xctrace's
2030    /// `metal-application-encoders-list` table picks up the label on the
2031    /// row emitted at the encoder's `endEncoding` / CB submission boundary.
2032    /// Single ObjC `msg_send` per call (two if an encoder is active); sub-µs
2033    /// on M5 Max; no-op when xctrace isn't recording.
2034    ///
2035    /// Empty labels trip a debug assert and are skipped in release; an
2036    /// empty label would produce a trace row indistinguishable from the
2037    /// metal-rs default `Command Buffer 0` placeholder.
2038    #[inline]
2039    fn apply_labels(&mut self, label: &str) {
2040        debug_assert!(!label.is_empty(), "commit_*_labeled called with empty label");
2041        if label.is_empty() {
2042            return;
2043        }
2044        self.cmd_buf.set_label(label);
2045        if !self.active_encoder.is_null() {
2046            // SAFETY: active_encoder is non-null and points to a live encoder
2047            // owned by cmd_buf — same invariant as get_or_create_encoder /
2048            // memory_barrier.  set_label is a single property write on the
2049            // ObjC object; safe before endEncoding.
2050            unsafe { &*self.active_encoder }.set_label(label);
2051        }
2052        // ADR-015 iter63: capture the most recent label for per-dispatch
2053        // entries.  Cheap String allocation — only happens at CB commit
2054        // boundaries, not per dispatch.
2055        self.last_label.clear();
2056        self.last_label.push_str(label);
2057    }
2058
2059    /// Commit + wait, returning `(gpu_start_s, gpu_end_s)` CFTimeInterval
2060    /// timestamps from `MTLCommandBuffer`'s `GPUStartTime`/`GPUEndTime`
2061    /// properties.  Both are CFTimeInterval seconds (double) on the host's
2062    /// mach-absolute-derived timebase.
2062    ///
2063    /// Intended for `HF2Q_PROFILE_GPU_TS=1` per-bucket GPU wall-clock
2064    /// attribution.  Adds exactly two ObjC property reads per call on top
2065    /// of the regular `commit_and_wait` — measured well under 1 μs on
2066    /// M5 Max.
2067    ///
2068    /// # Errors
2069    ///
2070    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
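    ///
    /// # Example (sketch)
    ///
    /// Converting the returned second-pair to milliseconds:
    ///
    /// ```ignore
    /// let (t0, t1) = enc.commit_wait_with_gpu_time()?;
    /// let gpu_ms = (t1 - t0) * 1e3; // CFTimeInterval seconds → milliseconds
    /// ```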
2071    pub fn commit_wait_with_gpu_time(&mut self) -> Result<(f64, f64)> {
2072        self.commit_and_wait()?;
2073        // SAFETY: cmd_buf is a valid MTLCommandBuffer that has been
2074        // committed and awaited.  GPUStartTime / GPUEndTime return
2075        // CFTimeInterval (double precision seconds).  See
2076        // https://developer.apple.com/documentation/metal/mtlcommandbuffer/1639925-gpustarttime
2077        let (gpu_start, gpu_end): (f64, f64) = unsafe {
2078            let cb = &*self.cmd_buf;
2079            let s: f64 = msg_send![cb, GPUStartTime];
2080            let e: f64 = msg_send![cb, GPUEndTime];
2081            (s, e)
2082        };
2083        Ok((gpu_start, gpu_end))
2084    }
2085
2086    /// Commit the command buffer WITHOUT blocking.
2087    ///
2088    /// The GPU begins executing the encoded commands immediately.  Call
2089    /// [`wait_until_completed`](Self::wait_until_completed) later to block
2090    /// the CPU and check for errors.  This allows the CPU to continue doing
2091    /// other work (e.g. preparing the next batch) while the GPU runs.
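    ///
    /// # Example (sketch)
    ///
    /// Overlapping CPU work with GPU execution; `prepare_next_batch` is a
    /// hypothetical stand-in for caller-side work.
    ///
    /// ```ignore
    /// enc.commit();                 // GPU starts executing now
    /// prepare_next_batch();         // CPU work overlaps GPU execution
    /// enc.wait_until_completed()?;  // block + surface any GPU error
    /// ```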
2092    pub fn commit(&mut self) {
2093        self.end_active_encoder();
2094        // ADR-015 iter8e (Phase 3b): same flush hook as commit_and_wait —
2095        // this is the async-pipeline path that production decode uses.
2096        self.flush_residency_pending();
2097        self.cmd_buf.commit();
2098    }
2099
2100    /// Block until a previously committed command buffer completes.
2101    ///
2102    /// Must be called after [`commit`](Self::commit).  Do not call after
2103    /// [`commit_and_wait`](Self::commit_and_wait) — that method already waits.
2104    ///
2105    /// # Errors
2106    ///
2107    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
2108    pub fn wait_until_completed(&self) -> Result<()> {
2109        self.cmd_buf.wait_until_completed();
2110        match self.cmd_buf.status() {
2111            MTLCommandBufferStatus::Completed => Ok(()),
2112            MTLCommandBufferStatus::Error => Err(MlxError::CommandBufferError(
2113                "GPU command buffer completed with error status".into(),
2114            )),
2115            status => Err(MlxError::CommandBufferError(format!(
2116                "Unexpected command buffer status after wait: {:?}",
2117                status
2118            ))),
2119        }
2120    }
2121
2122    /// Borrow the underlying Metal command buffer.
2123    #[inline]
2124    pub fn metal_command_buffer(&self) -> &CommandBuffer {
2125        &self.cmd_buf
2126    }
2127
2128    /// Borrow the residency set bound to this encoder, if one exists.
2129    ///
2130    /// ADR-019 Phase 0b iter89e2-B: exposed `pub(crate)` so
2131    /// [`crate::EncoderSession`] can route caller-driven add/remove
2132    /// requests through the same `Arc<ResidencySetInner>` the encoder
2133    /// itself flushes at every `commit*` boundary. The single-set
2134    /// invariant from `device.rs::MlxDevice` is preserved — both the
2135    /// encoder's `flush_residency_pending` and the session's delegated
2136    /// add/remove operate on the SAME residency set. Returns `None` when
2137    /// residency sets are disabled (HF2Q_NO_RESIDENCY=1, macOS<15, or
2138    /// `CommandEncoder::new` from a residency-less queue).
2139    #[inline]
2140    pub(crate) fn residency_set(&self) -> Option<&ResidencySet> {
2141        self.residency_set.as_ref()
2142    }
2143
2144    /// Reopen `cmd_buf` with a fresh `CommandBuffer` from the originating queue.
2145    ///
2146    /// ADR-019 Phase 0b iter89e2-B: enables multi-stage chaining. After a
2147    /// non-blocking `commit*` has handed the prior CB to Metal, this method
2148    /// rotates `cmd_buf` to a freshly-allocated CB on the same queue and
2149    /// resets every per-CB scratch field so the next dispatch is encoded
2150    /// onto the new CB.
2151    ///
2152    /// # Caller contract
2153    ///
2154    /// Only valid when `active_encoder.is_null()` (the persistent compute
2155    /// encoder must have been ended via `end_active_encoder()`, which both
2156    /// `commit_and_wait` and `commit` already do). Calling this method
2157    /// while a compute encoder is open would leak the encoder (the new
2158    /// `cmd_buf` does not own it) and trip Metal's "Command encoder
2159    /// released without endEncoding" assertion when the prior `cmd_buf`
2160    /// drops. Callers are [`crate::EncoderSession::reset_for_next_stage`]
2161    /// only — the session has already committed before invoking this.
2162    ///
2163    /// # F2 / F11 / F12 fence preservation
2164    ///
2165    /// - **F2 — residency-rescission**: this method does NOT re-flush
2166    ///   the residency set. The prior `commit*` already flushed; staged
2167    ///   add/remove since then will flush at the next `commit*` on the
2168    ///   new CB. The residency-set Arc clone is preserved.
2169    /// - **F11 — zero-init alloc_buffer**: untouched (no buffer allocs).
2170    /// - **F12 — `HF2Q_FORCE_SERIAL_DISPATCH`**: the new CB will lazily
2171    ///   open its compute encoder via `get_or_create_encoder`, which
2172    ///   re-reads the env var; the falsification probe still fires on
2173    ///   the new CB.
2174    ///
2175    /// # Counter semantics
2176    ///
2177    /// Bumps `CMD_BUF_COUNT` exactly once per call, matching the
2178    /// `new_with_residency` accounting. Does NOT bump `SYNC_COUNT` (no
2179    /// commit/wait happens here).
2180    pub(crate) fn reset_command_buffer(&mut self) {
2181        debug_assert!(
2182            self.active_encoder.is_null(),
2183            "reset_command_buffer called with an active compute encoder \
2184             — caller must commit (which calls end_active_encoder) first"
2185        );
2186        let cmd_buf = if unretained_refs_enabled() {
2187            self.queue
2188                .new_command_buffer_with_unretained_references()
2189                .to_owned()
2190        } else {
2191            self.queue.new_command_buffer().to_owned()
2192        };
2193        CMD_BUF_COUNT.fetch_add(1, Ordering::Relaxed);
2194        self.cmd_buf = cmd_buf;
2195        // Per-CB scratch state — every field that's documented as being
2196        // bounded by a CB lifetime resets here.
2197        self.active_encoder = std::ptr::null();
2198        self.dispatch_in_cb = 0;
2199        self.last_label.clear();
2200        self.pending_dispatch_meta.clear();
2201        // `mem_ranges` is per-CB barrier-inference state; clearing on
2202        // CB rotation matches the `commit_and_wait` post-commit invariant
2203        // (any new CB starts with no pending hazards). The tracker is
2204        // replaced wholesale with a fresh `MemRanges::new()` to avoid
2205        // exposing its internals.
2206        self.mem_ranges = MemRanges::new();
2207        // `sample_buffer` is dropped explicitly inside
2208        // `resolve_dispatch_samples` after a CB completes; we leave it
2209        // in whatever state the prior commit left it (typically `None`
2210        // after `commit_and_wait` finishes). A stale `Some` here would
2211        // be visible only under `MLX_PROFILE_DISPATCH=1` which fires its
2212        // own one-shot warning; not worth a special case.
2213        // `capture` (if Some) persists across CB rotation — capture mode
2214        // accumulates across stages within a session by design.
2215        // `pending_op_kind` / `pending_reads` / `pending_writes` only
2216        // hold tags for the NEXT dispatch and are consumed when that
2217        // dispatch fires — leaving them as-is is correct.
2218    }
2219
2220    /// Encode an `MTLSharedEvent` wait at `value` on the current CB.
2221    ///
2222    /// ADR-019 Phase 0b iter89e2-B: pairs with [`Self::encode_signal_event`]
2223    /// to express the inter-CB ordering D3 stage boundaries need. The new
2224    /// CB's GPU work blocks until the prior CB's signal lands on the same
2225    /// event at >= `value`.
2226    ///
2227    /// # Caller contract
2228    ///
2229    /// Must be called BEFORE any compute encoder is opened on the new
2230    /// CB — the wait is a CB-level op that must precede every dispatch
2231    /// in the new CB to actually order them. [`crate::EncoderSession::reset_for_next_stage`]
2232    /// fires this immediately after `reset_command_buffer`, before any
2233    /// dispatch lazy-opens the encoder.
2234    #[inline]
2235    pub(crate) fn encode_wait_for_event(&self, event: &metal::EventRef, value: u64) {
2236        debug_assert!(
2237            self.active_encoder.is_null(),
2238            "encode_wait_for_event called with an open compute encoder \
2239             — wait must precede the first dispatch on the new CB"
2240        );
2241        self.cmd_buf.encode_wait_for_event(event, value);
2242    }
2243
2244    /// End the active compute encoder, encode a stage-fence signal, and
2245    /// commit the CB non-blocking — atomically from the caller's view.
2246    ///
2247    /// ADR-019 Phase 0b iter89e2-B: this is the helper
2248    /// [`crate::EncoderSession::fence_stage`] uses to thread the signal
2249    /// between the encoder-end and the CB-commit boundaries that
2250    /// `commit_labeled` would otherwise serialize. Sequence:
2251    ///
2252    /// 1. End the persistent compute encoder (so `encodeSignalEvent:` is
2253    ///    encoded at CB-level, not encoder-level — Metal validates that
2254    ///    `encodeSignalEvent:` outside any encoder pass is the only
2255    ///    legal placement).
2256    /// 2. Apply `label` (when `Some`) to the CB. Note: at this point
2257    ///    the encoder is already ended, so the encoder's own
2258    ///    `setLabel:` is a no-op site — only the CB label propagates.
2259    ///    `last_label` and per-dispatch profiling keep working as
2260    ///    documented.
2261    /// 3. Encode `encodeSignalEvent:value:` (with `new_value`) at CB-level.
2262    /// 4. Flush the residency-set pending staging (matches the
2263    ///    `commit_labeled` / `commit` flush at encoder.rs:2004).
2264    /// 5. Commit the CB non-blocking (matches `commit()` at
2265    ///    encoder.rs:2026).
2266    ///
2267    /// # Counter semantics
2268    ///
2269    /// Bumps `SYNC_COUNT` zero times (non-blocking). Bumps
2270    /// `CMD_BUF_COUNT` zero times (no new CB allocated here —
2271    /// [`Self::reset_command_buffer`] does that on the next stage).
2272    ///
2273    /// # Errors
2274    ///
2275    /// Infallible (matches `commit()` semantics — errors surface only
2276    /// at `wait_until_completed`).
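    ///
    /// # Example (sketch)
    ///
    /// The D3 stage-boundary sequence as the session threads it; `event`
    /// and `stage` are placeholders.
    ///
    /// ```ignore
    /// enc.fence_signal_and_commit(&event, stage + 1, Some("layer.ffn"));
    /// enc.reset_command_buffer();                   // rotate to a fresh CB
    /// enc.encode_wait_for_event(&event, stage + 1); // new CB waits on the signal
    /// ```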
2277    pub(crate) fn fence_signal_and_commit(
2278        &mut self,
2279        event: &metal::EventRef,
2280        new_value: u64,
2281        label: Option<&str>,
2282    ) {
2283        // Step 1: end the active compute encoder. encode_signal_event's
2284        // debug_assert requires this be done first.
2285        self.end_active_encoder();
2286        // Step 2: apply the CB label so xctrace MST attribution still
2287        // works on the fenced CB. apply_labels' debug_assert against
2288        // empty labels matches commit_labeled's semantics.
2289        if let Some(l) = label {
2290            self.apply_labels(l);
2291        }
2292        // Step 3: encode the signal at CB-level.
2293        self.cmd_buf.encode_signal_event(event, new_value);
2294        // Step 4 + 5: same as commit() — flush residency staging, then
2295        // hand the CB to Metal.
2296        self.flush_residency_pending();
2297        self.cmd_buf.commit();
2298    }
2299}
2300
2301impl Drop for CommandEncoder {
2302    fn drop(&mut self) {
2303        // End the persistent compute encoder before the command buffer
2304        // is dropped, otherwise Metal will assert:
2305        // "Command encoder released without endEncoding"
2306        self.end_active_encoder();
2307    }
2308}