Skip to main content

mlx_native/
encoder.rs

1//! [`CommandEncoder`] — batched GPU command submission.
2//!
3//! Wraps a Metal command buffer.  Encode one or more compute kernel dispatches,
4//! then call [`commit_and_wait`](CommandEncoder::commit_and_wait) to submit the
5//! entire batch and block until the GPU finishes.
6//!
7//! # Persistent compute encoder
8//!
9//! A single Metal `ComputeCommandEncoder` is kept alive across multiple
10//! dispatches within the same command buffer.  This avoids the overhead of
11//! creating and ending a new compute encoder per dispatch — the same pattern
12//! candle uses (`compute_per_buffer`).  On a forward pass with ~800 dispatches
13//! this saves ~800 encoder create/end cycles.
14//!
15//! # Capture mode (Phase 4e.1)
16//!
17//! When `start_capture()` is called, subsequent dispatches are recorded into a
18//! `Vec<CapturedNode>` instead of being encoded into Metal.  `memory_barrier()`
19//! records a barrier sentinel.  Call `take_capture()` to extract the recorded
20//! graph for later replay via `ComputeGraph::encode_sequential()`.
21
22use std::sync::atomic::{AtomicU64, Ordering};
23
24use metal::{
25    CommandBuffer, CommandQueue, ComputeCommandEncoderRef, ComputePipelineState,
26    ComputePipelineStateRef, CounterSampleBuffer, CounterSampleBufferDescriptor,
27    MTLCommandBufferStatus, MTLCounterSamplingPoint, MTLDispatchType, MTLSize, MTLStorageMode,
28    NSRange,
29};
30#[allow(unused_imports)]
31use objc::{msg_send, sel, sel_impl};
32
33use crate::buffer::MlxBuffer;
34use crate::error::{MlxError, Result};
35use crate::mem_ranges::MemRanges;
36use crate::residency::ResidencySet;
37
38/// A buffer or inline-bytes binding for a compute kernel argument slot.
39pub enum KernelArg<'a> {
40    /// Bind an existing Metal buffer at the given index.
41    Buffer(&'a MlxBuffer),
42    /// Bind an existing Metal buffer at the given index with a byte offset.
43    BufferWithOffset(&'a MlxBuffer, u64),
44    /// Bind inline bytes (small constant data) at the given index.
45    /// The data must be `Pod` and is copied into the command encoder.
46    Bytes(&'a [u8]),
47}
48
49/// Convert a `Pod` value to a byte slice suitable for `KernelArg::Bytes`.
50///
51/// # Safety
52///
53/// The caller must ensure `T` has the same layout as the corresponding
54/// MSL struct in the shader (matching field order, sizes, and alignment).
55pub fn as_bytes<T: bytemuck::Pod>(val: &T) -> &[u8] {
56    bytemuck::bytes_of(val)
57}
58
59// ---------------------------------------------------------------------------
60// Capture-mode types (Phase 4e.1 — Graph IR)
61// ---------------------------------------------------------------------------
62
63/// A recorded kernel argument binding.
64///
65/// When the encoder is in capture mode, each `set_buffer` / `set_bytes` call
66/// is stored as a `RecordedBinding` instead of being applied to Metal.
67#[derive(Clone)]
68pub enum RecordedBinding {
69    /// A Metal buffer at the given offset.
70    Buffer {
71        metal_buffer: metal::Buffer,
72        offset: u64,
73    },
74    /// Inline bytes (small constant data, copied).
75    Bytes(Vec<u8>),
76}
77
/// Selects which Metal dispatch call is used to replay a captured kernel.
#[derive(Clone, Copy, Debug)]
pub enum DispatchKind {
    /// Replay via `dispatch_threads(grid_size, threadgroup_size)`.  The grid
    /// is expressed in threads and Metal derives the threadgroup count.
    Threads,
    /// Replay via `dispatch_thread_groups(threadgroups, threadgroup_size)`.
    /// The caller supplies the threadgroup count directly.
    ThreadGroups,
}
86
/// Coarse operation tag attached to each captured dispatch.
///
/// The fusion pass (4e.2) reads this tag to spot fuseable sequences without
/// having to inspect pipeline names.  In capture mode the tag for the next
/// dispatch is set via `set_op_kind()`.  Dispatches that were never tagged
/// carry `Other`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CapturedOpKind {
    /// RMS normalization with a learned scale.
    RmsNorm,
    /// Elementwise multiplication.
    ElemMul,
    /// Elementwise addition.
    ElemAdd,
    /// Scaled dot-product attention.  Not reorderable: it ends the lookahead.
    Sdpa,
    /// Softmax.  Not reorderable: it ends the lookahead.
    Softmax,
    /// Anything else.  The graph optimizer treats it as reorderable.
    Other,
}

impl CapturedOpKind {
    /// Returns `true` when the graph optimizer (Phase 4e.3) may reorder
    /// nodes past this one.
    ///
    /// Modeled on the `h_safe` whitelist in llama.cpp's
    /// `ggml_metal_graph_optimize_reorder`.  A kind that returns `false` acts
    /// as a wall: the reorder pass's 64-node lookahead cannot see past it.
    pub fn is_reorderable(&self) -> bool {
        // Deliberately exhaustive (no `_` arm) so a new variant forces an
        // explicit decision here.
        match *self {
            Self::RmsNorm | Self::ElemMul | Self::ElemAdd | Self::Other => true,
            Self::Sdpa | Self::Softmax => false,
        }
    }

    /// Stable label used in the per-dispatch profile dump (ADR-015 iter63
    /// §A.5).
    ///
    /// The label always equals the variant name.  `Other` is kept verbatim
    /// so that sorting a dump by op kind gathers every still-unlabeled
    /// dispatch into a single bucket.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Other => "Other",
            Self::RmsNorm => "RmsNorm",
            Self::ElemMul => "ElemMul",
            Self::ElemAdd => "ElemAdd",
            Self::Sdpa => "Sdpa",
            Self::Softmax => "Softmax",
        }
    }
}
137
/// Memory-range annotation `(start_address, end_address)` used for conflict
/// detection in the reorder pass (Phase 4e.3).
///
/// Each tuple names one contiguous GPU buffer region.  Ranges built by
/// `ranges_from_buffers` are half-open (`end = start + extent`).  Addresses
/// are CPU-visible `contents_ptr()` values, which on Apple Silicon unified
/// memory coincide with the GPU addresses.
pub type MemRange = (usize, usize);
144
145/// A single captured compute dispatch or barrier sentinel.
146///
147/// Created when the encoder is in capture mode.  Replayed later by
148/// `ComputeGraph::encode_sequential()`.
149#[derive(Clone)]
150pub enum CapturedNode {
151    /// A compute dispatch to replay.
152    Dispatch {
153        /// Pipeline state object to bind.
154        pipeline: ComputePipelineState,
155        /// Kernel argument bindings: (slot_index, binding).
156        bindings: Vec<(u64, RecordedBinding)>,
157        /// Grid or threadgroup count (interpretation depends on `dispatch_kind`).
158        threads_per_grid: MTLSize,
159        /// Threads per threadgroup.
160        threads_per_threadgroup: MTLSize,
161        /// Optional threadgroup memory allocations: (index, byte_length).
162        threadgroup_memory: Vec<(u64, u64)>,
163        /// Whether this is a dispatch_threads or dispatch_thread_groups call.
164        dispatch_kind: DispatchKind,
165        /// Operation kind tag for the fusion pass (4e.2).
166        /// Defaults to `Other` if not explicitly set via `set_op_kind()`.
167        op_kind: CapturedOpKind,
168        /// Read buffer ranges for reorder conflict detection (4e.3).
169        /// Populated from `barrier_between` calls in capture mode.
170        reads: Vec<MemRange>,
171        /// Write buffer ranges for reorder conflict detection (4e.3).
172        /// Populated from `barrier_between` calls in capture mode.
173        writes: Vec<MemRange>,
174    },
175    /// A memory barrier sentinel — forces a barrier at replay time.
176    Barrier,
177}
178
179/// Convert a slice of buffer references into capture-mode
180/// [`MemRange`] tuples.  Used by the [`CommandEncoder::dispatch_tracked*`]
181/// family in capture mode — equivalent to the conversion
182/// `GraphSession::barrier_between` does at `graph.rs:1452-1465`.
183///
184/// `(start, end)` uses `contents_ptr() + byte_offset` as the start
185/// and `contents_ptr() + byte_offset + slice_extent` as the end.
186fn ranges_from_buffers(bufs: &[&MlxBuffer]) -> Vec<MemRange> {
187    bufs.iter()
188        .map(|b| {
189            let base = b.contents_ptr() as usize + b.byte_offset() as usize;
190            let extent = (b.byte_len()).saturating_sub(b.byte_offset() as usize);
191            (base, base + extent)
192        })
193        .collect()
194}
195
196/// Apply a slice of `KernelArg` bindings to a compute encoder.
197///
198/// `KernelArg::Buffer(buf)` propagates the `MlxBuffer::byte_offset()` so
199/// `slice_view`-derived sub-buffers are honored automatically — the
200/// kernel sees memory starting at the slice's offset. This matches the
201/// documented contract of `slice_view` and the offset-handling in the
202/// other binding paths in this file (`encode`, `encode_threadgroups`,
203/// `encode_threadgroups_with_shared`, replay). Without it, every
204/// `slice_view`-derived buffer bound via `KernelArg::Buffer` silently
205/// exposes the entire underlying allocation — surfaced by hf2q's
206/// nomic-bert iter-79 cosine parity bisection (cosine 0.098 → 0.999962
207/// after fix).
208///
209/// `KernelArg::BufferWithOffset(buf, offset)` continues to use the
210/// explicit `offset` argument verbatim (callers asking for an explicit
211/// offset get exactly that, even on sliced buffers). The two API
212/// surfaces are intentional: implicit (sliced views auto-propagate) vs.
213/// explicit (caller-controlled).
214#[inline]
215fn apply_bindings(encoder: &ComputeCommandEncoderRef, bindings: &[(u64, KernelArg<'_>)]) {
216    for &(index, ref arg) in bindings {
217        match arg {
218            KernelArg::Buffer(buf) => {
219                encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
220            }
221            KernelArg::BufferWithOffset(buf, offset) => {
222                encoder.set_buffer(index, Some(buf.metal_buffer()), *offset);
223            }
224            KernelArg::Bytes(bytes) => {
225                encoder.set_bytes(index, bytes.len() as u64, bytes.as_ptr() as *const _);
226            }
227        }
228    }
229}
230
/// CPU sync points: incremented once per `commit_and_wait()` call.
static SYNC_COUNT: AtomicU64 = AtomicU64::new(0);

/// GPU dispatches: incremented once per call to an encode method.
static DISPATCH_COUNT: AtomicU64 = AtomicU64::new(0);

/// `MTLCommandBuffer` instances created via `CommandEncoder::new`.
///
/// Incremented once per `device.command_encoder()` call.  hf2q's
/// `HF2Q_DECODE_PROFILE` instrumentation reads it to measure command-buffer
/// overhead per decode token (ADR-012 §Optimize / Task #15 follow-up).
static CMD_BUF_COUNT: AtomicU64 = AtomicU64::new(0);

/// `memory_barrier()` calls that actually reached the
/// `objc::msg_send![encoder, memoryBarrierWithScope:]` site.
///
/// Capture-mode no-ops and calls made before an encoder exists are not
/// counted, so the value reflects real MTL barriers issued.  It is always
/// tracked, because a single atomic increment (~5 ns) is negligible.
/// ADR-015 H4 (Wave 2b hard gate #2) needs per-barrier resolution to
/// confirm or falsify the barrier-coalescing lever.  xctrace TimeProfiler
/// at 1 ms sampling cannot resolve `memory_barrier` even though it fires
/// ~440×/token (`docs/ADR-015-mlx-native-single-cb-decode.md`, §"P3a' live
/// profile pass", hypothesis register row H4).
static BARRIER_COUNT: AtomicU64 = AtomicU64::new(0);

/// Total nanoseconds spent inside the `objc::msg_send!` barrier site,
/// summed over all calls.
///
/// Updated only when `MLX_PROFILE_BARRIERS=1` is set (checked once and
/// cached on the first `memory_barrier` call).  With profiling off, the hot
/// path costs one branch plus the unconditional barrier dispatch, the same
/// as before this counter existed.
///
/// Why it is opt-in: timing adds two `Instant::now()` calls (~50–100 ns each
/// via `mach_absolute_time`) per barrier.  At ~440 barriers/token that is
/// ~22–44 µs/token of measurement overhead, comparable to the quantity
/// being measured.  Keep it off in production and enable it only for
/// profiling runs.
static BARRIER_NS: AtomicU64 = AtomicU64::new(0);
269
270/// Reset all counters to zero.
271pub fn reset_counters() {
272    SYNC_COUNT.store(0, Ordering::Relaxed);
273    DISPATCH_COUNT.store(0, Ordering::Relaxed);
274    CMD_BUF_COUNT.store(0, Ordering::Relaxed);
275    BARRIER_COUNT.store(0, Ordering::Relaxed);
276    BARRIER_NS.store(0, Ordering::Relaxed);
277    AUTO_BARRIER_COUNT.store(0, Ordering::Relaxed);
278    AUTO_BARRIER_CONCURRENT.store(0, Ordering::Relaxed);
279}
280
281/// Read the current value of `SYNC_COUNT`.
282///
283/// Each call to `commit_and_wait()` increments this counter.
284pub fn sync_count() -> u64 {
285    SYNC_COUNT.load(Ordering::Relaxed)
286}
287
288/// Read the current value of `DISPATCH_COUNT`.
289///
290/// Each call to `encode()`, `encode_threadgroups()`, or
291/// `encode_threadgroups_with_shared()` increments this counter.
292pub fn dispatch_count() -> u64 {
293    DISPATCH_COUNT.load(Ordering::Relaxed)
294}
295
296/// Read the current value of `CMD_BUF_COUNT`.
297///
298/// Each `CommandEncoder::new` (i.e. each `MlxDevice::command_encoder()`)
299/// increments this counter.  Useful for diagnosing per-dispatch Metal
300/// command-buffer overhead in inner loops.
301pub fn cmd_buf_count() -> u64 {
302    CMD_BUF_COUNT.load(Ordering::Relaxed)
303}
304
305/// Read the current value of `BARRIER_COUNT`.
306///
307/// Each `memory_barrier()` call that reaches the underlying
308/// `objc::msg_send![encoder, memoryBarrierWithScope:]` site increments this
309/// counter.  Capture-mode no-ops and pre-encoder no-ops are excluded.
310/// ADR-015 H4 hypothesis: ~440 barriers/token on the qwen35 decode hot
311/// path (verify against this counter).
312pub fn barrier_count() -> u64 {
313    BARRIER_COUNT.load(Ordering::Relaxed)
314}
315
316/// Read the total nanoseconds spent in the `memoryBarrierWithScope:`
317/// `objc::msg_send!` site.  Only non-zero when `MLX_PROFILE_BARRIERS=1`
318/// was in the environment at the time of the first `memory_barrier()`
319/// call (the env check is cached on first use).
320///
321/// Combined with [`barrier_count`] this gives µs/barrier =
322/// `barrier_total_ns() / 1000 / barrier_count()`.
323pub fn barrier_total_ns() -> u64 {
324    BARRIER_NS.load(Ordering::Relaxed)
325}
326
/// Return `true` iff the environment variable `name` is set to exactly
/// `"1"`.
///
/// Any other value (`"0"`, `"true"`, `" 1"`, ...), an unset variable, or a
/// value that is not valid Unicode all read as `false`.  The cached feature
/// gates below share this helper so they parse their variables identically.
fn env_flag_is_one(name: &str) -> bool {
    std::env::var(name).is_ok_and(|value| value == "1")
}

/// Whether barrier timing is enabled (env-gated, cached on first check).
///
/// Reading the env var via `std::env::var` is itself non-trivial.  Caching
/// the decision in a `OnceLock` makes the per-barrier branch a single
/// atomic load + compare.
fn barrier_profile_enabled() -> bool {
    static FLAG: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    *FLAG.get_or_init(|| env_flag_is_one("MLX_PROFILE_BARRIERS"))
}

/// Whether `MLX_UNRETAINED_REFS=1` is set in the process environment.
///
/// ADR-015 iter13 — when true, `CommandEncoder::new_with_residency` opens
/// each `MTLCommandBuffer` via
/// [`metal::CommandQueueRef::new_command_buffer_with_unretained_references`]
/// instead of the default `commandBuffer`.  llama.cpp's per-token decode
/// command buffers use the same call (`/opt/llama.cpp/ggml/src/ggml-metal/`
/// `ggml-metal-context.m:512`, `[queue commandBufferWithUnretainedReferences]`).
/// They gain ~3-5% wall time on M-series GPUs by skipping per-buffer-binding
/// ARC retains on submit.
///
/// **Caller-side prerequisite.**  Every Metal buffer bound to a dispatch
/// must outlive the command buffer; see the docs on
/// [`CommandEncoder::new_with_residency`] for the full caller contract.  In
/// hf2q, the per-decode-token `MlxBufferPool` (`buffer_pool.rs`) already
/// keeps ARC clones alive in its `in_use` list for the whole decode token.
/// Routing transient scratches through that pool is the canonical way to
/// satisfy the contract.
///
/// Cached on first read via `OnceLock`, so the per-construction branch is a
/// single atomic load.  Default OFF: any production decode run that does
/// NOT set the variable keeps retained-refs behavior verbatim.
fn unretained_refs_enabled() -> bool {
    static FLAG: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    *FLAG.get_or_init(|| env_flag_is_one("MLX_UNRETAINED_REFS"))
}

/// Whether `HF2Q_AUTO_BARRIER=1` is set in the process environment.
///
/// ADR-015 iter37 — when true, every [`CommandEncoder::dispatch_tracked`]
/// call consults a [`MemRanges`](crate::mem_ranges::MemRanges) tracker.  It
/// auto-emits a `memoryBarrierWithScope:` exactly when the new dispatch's
/// read/write ranges conflict with previously recorded ranges.  This
/// mirrors llama.cpp's `ggml_metal_op_concurrency_check` at
/// `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp:147-225`.
///
/// When false, `dispatch_tracked` collapses to the same code path as
/// `encode*`, with no tracking and no auto-barriers.  Callers that opt into
/// the tracked API but run without the env gate keep sourdough behavior.
///
/// Cached on first read via `OnceLock`.  Default OFF: production
/// decode/prefill keeps its hand-placed `enc.memory_barrier()` calls until
/// the migration in iter38+.
fn auto_barrier_enabled() -> bool {
    static FLAG: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    *FLAG.get_or_init(|| env_flag_is_one("HF2Q_AUTO_BARRIER"))
}
400
/// Number of `memory_barrier()` calls auto-emitted by
/// [`CommandEncoder::dispatch_tracked`] under `HF2Q_AUTO_BARRIER=1`.
///
/// This is a *subset* of [`BARRIER_COUNT`], not disjoint from it.
/// Auto-barriers are issued through `memory_barrier()`, so each one also
/// bumps `BARRIER_COUNT`.  This counter isolates only the auto-emitted
/// share.
static AUTO_BARRIER_COUNT: AtomicU64 = AtomicU64::new(0);

/// Number of `dispatch_tracked` calls whose mem-ranges check returned
/// "concurrent" (no barrier needed).
///
/// Together with [`AUTO_BARRIER_COUNT`] this measures the elision rate of
/// the dataflow barrier.  `concurrent / (concurrent + barriers)` is the
/// fraction of dispatches that joined the previous concurrent group instead
/// of starting a new one.
static AUTO_BARRIER_CONCURRENT: AtomicU64 = AtomicU64::new(0);
416
417// ---------------------------------------------------------------------------
418// ADR-015 iter63 — per-dispatch GPU sampling support
419// ---------------------------------------------------------------------------
420
/// Hard cap on the number of samples in one command buffer's counter sample
/// buffer (Risk R4 in PROFILING-KIT-DESIGN §A.7).
///
/// Empirically verified on Apple Silicon (M-series, macOS 26): the
/// underlying `MTLCounterSampleBufferDescriptor.sampleCount` is bounded by a
/// per-buffer **byte-size** limit of 32768 B.
///
/// * Each `MTLCounterResultTimestamp` sample is 8 bytes, so the
///   sample-count ceiling is `32_768 / 8 = 4096` (the expression used below).
/// * Two samples are taken per dispatch (start + end), so the ceiling is
///   2048 dispatches per command buffer.
/// * Decode command buffers (~120 dispatches) fit comfortably.  Long
///   prefill command buffers (~6K dispatches per design §A.7) truncate after
///   2048.  The truncation path lives in
///   `CommandEncoder::sample_dispatch_pre`; a future iteration could
///   chunk-resolve every 2K dispatches.
///
/// The original design constant of 32_768 (PROFILING-KIT-DESIGN §A.7) came
/// from Apple's documented ~64K-per-buffer "practical" limit.  The measured
/// constraint on this hardware is the 32 KB byte budget instead.
/// * A lower cap would underutilize the buffer.
/// * A higher cap makes `newCounterSampleBufferWithDescriptor` fail with
///   `Invalid sample buffer length: <bytes> B. Expected range: 8 -> 32768`.
const MAX_SAMPLES_PER_CB: u64 = 32_768 / 8;
443
/// Records whether the warning about a missing `MTLCommonCounterSetTimestamp`
/// has already been emitted.
///
/// Risk R1: if `device.counter_sets()` has no set named `"timestamp"`
/// (case-insensitive), per-dispatch sampling degrades to a no-op and a
/// single stderr warning is logged.  NOTE(review): presumably 0 means "not
/// yet logged"; confirm at the use site.
static TIMESTAMP_SET_WARN_LOGGED: AtomicU64 = AtomicU64::new(0);
449
/// Per-dispatch metadata awaiting timestamp resolution.
///
/// Entry `i` pairs with sample indices `2i` (start) and `2i+1` (end) in the
/// command buffer's `MTLCounterSampleBuffer`.  At commit time
/// `CommandEncoder::resolve_dispatch_samples` resolves these entries,
/// converts them into [`crate::kernel_profile::DispatchEntry`] values, and
/// pushes them to the global table.
#[derive(Clone, Debug)]
struct PendingDispatchMeta {
    /// Op-kind label of the dispatch (presumably a
    /// `CapturedOpKind::name()` value — confirm at the push site).
    op_kind: &'static str,
    /// Ordinal of the dispatch within its command buffer.
    dispatch_index: u32,
}
460
461/// Read the cumulative number of auto-emitted barriers across all
462/// encoders since process start (or last [`reset_counters`]).
463pub fn auto_barrier_count() -> u64 {
464    AUTO_BARRIER_COUNT.load(Ordering::Relaxed)
465}
466
467/// Read the cumulative number of `dispatch_tracked` calls that did NOT
468/// emit a barrier (ran concurrent with the previous group).
469pub fn auto_barrier_concurrent_count() -> u64 {
470    AUTO_BARRIER_CONCURRENT.load(Ordering::Relaxed)
471}
472
473/// Issue the underlying Metal `memoryBarrierWithScope:` ObjC msg_send.
474///
475/// Held in its own `#[inline(never)]` function so xctrace / Instruments
476/// has a stable Rust frame to attribute barrier time against, separate
477/// from the surrounding encoder accounting.  Per ADR-015 §P3a' Codex
478/// review Q2: TimeProfiler at 1 ms sampling cannot see this site when
479/// inlined; an explicit non-inline frame plus the [`BARRIER_NS`] counter
480/// closes the H4 hard gate.
481#[inline(never)]
482fn issue_metal_buffer_barrier(encoder: &ComputeCommandEncoderRef) {
483    // MTLBarrierScopeBuffers = 1 << 0 = 1.
484    const MTL_BARRIER_SCOPE_BUFFERS: u64 = 1;
485    unsafe {
486        let _: () =
487            objc::msg_send![encoder, memoryBarrierWithScope: MTL_BARRIER_SCOPE_BUFFERS];
488    }
489}
490
491/// A batched compute command encoder.
492///
493/// Keeps a single Metal `ComputeCommandEncoder` alive across multiple
494/// dispatches.  The encoder is created on the first dispatch and ended
495/// only when the command buffer is committed.  This mirrors candle's
496/// `compute_per_buffer` pattern and avoids per-dispatch encoder overhead.
497///
498/// # Typical usage
499///
500/// ```ignore
501/// let mut enc = device.command_encoder()?;
502/// // Multiple dispatches share the same compute encoder:
503/// enc.encode_threadgroups(pipeline1, &buffers1, tg1, tg_size1);
504/// enc.encode_threadgroups(pipeline2, &buffers2, tg2, tg_size2);
505/// enc.commit_and_wait()?;
506/// ```
507pub struct CommandEncoder {
508    cmd_buf: CommandBuffer,
509    // SAFETY marker: see unsafe Send impl below.
510    /// Raw pointer to the persistent compute encoder.
511    /// Non-null when a compute pass is active.
512    /// The encoder borrows from `cmd_buf` but we cannot express this
513    /// lifetime in safe Rust, so we use a raw pointer.
514    /// SAFETY: the pointer is valid as long as `cmd_buf` is alive and
515    /// `end_encoding()` has not been called on it.
516    active_encoder: *const ComputeCommandEncoderRef,
517    /// When `Some`, dispatches are recorded here instead of being encoded
518    /// into Metal.  Set via `start_capture()`, extracted via `take_capture()`.
519    capture: Option<Vec<CapturedNode>>,
520    /// Op kind tag for the NEXT captured dispatch.  Set via `set_op_kind()`,
521    /// consumed (reset to `Other`) when a dispatch is captured.
522    pending_op_kind: CapturedOpKind,
523    /// Pending read buffer ranges for the NEXT captured dispatch.
524    /// Set via `set_pending_buffer_ranges()`, consumed when the next dispatch
525    /// is captured.  Used by the reorder pass (Phase 4e.3).
526    pending_reads: Vec<MemRange>,
527    /// Pending write buffer ranges for the NEXT captured dispatch.
528    pending_writes: Vec<MemRange>,
529    /// ADR-015 iter8e (Phase 3b): residency set whose pending add/remove
530    /// staging is flushed at every `commit*` boundary.
531    ///
532    /// Cloned from the device at `device.command_encoder()` time. `None`
533    /// when residency sets are disabled (HF2Q_NO_RESIDENCY=1, macOS<15,
534    /// or test-only `CommandEncoder::new` from a residency-less queue).
535    residency_set: Option<ResidencySet>,
536    /// ADR-015 iter37: dataflow barrier inference state.
537    ///
538    /// Populated only when `HF2Q_AUTO_BARRIER=1` is set at process
539    /// start (cached via [`auto_barrier_enabled`]).  Each
540    /// [`Self::dispatch_tracked`] call consults this state to decide
541    /// whether a Metal memory barrier is required; on conflict the
542    /// barrier is emitted, the state is reset, and the new dispatch's
543    /// ranges seed the next concurrent group.  When the env gate is
544    /// off, `dispatch_tracked` collapses to its untracked equivalent
545    /// and this field is left empty for the encoder's lifetime.
546    ///
547    /// The field is always present (zero-sized when empty) so the
548    /// gate-off branch is a single bool-load + early return rather
549    /// than an allocation/Option indirection.
550    mem_ranges: MemRanges,
551    /// ADR-015 iter63 (per-dispatch profiling): the sample buffer for
552    /// `MTLCounterSampleBuffer.sampleCounters` calls that bracket every
553    /// `encode*` dispatch in this CB.  Lazily allocated on first
554    /// dispatch when `MLX_PROFILE_DISPATCH=1`; `None` otherwise.
555    /// Released (set to `None`) inside `resolve_dispatch_samples` after
556    /// the CB completes — re-allocated on the next `encode*` if the env
557    /// gate stays set.
558    sample_buffer: Option<CounterSampleBuffer>,
559    /// ADR-015 iter63: pending per-dispatch metadata that pairs with
560    /// sample indices `2*i` and `2*i+1` inside `sample_buffer`.  Each
561    /// `encode*` call appends one entry (when sampling is active);
562    /// `resolve_dispatch_samples` drains the vec at commit time.
563    pending_dispatch_meta: Vec<PendingDispatchMeta>,
564    /// ADR-015 iter63: 0-based dispatch ordinal within the current CB.
565    /// Incremented in every `encode*` site after taking the pending
566    /// op_kind; reset to 0 inside `resolve_dispatch_samples`.
567    dispatch_in_cb: u32,
568    /// ADR-015 iter63: most recent label set via `apply_labels`, used
569    /// as the per-dispatch `cb_label` field.  `String::new()` until
570    /// `commit_and_wait_labeled` / `commit_labeled` is called.
571    last_label: String,
572}
573
574/// SAFETY: CommandEncoder is safe to Send across threads provided that:
575/// 1. Only one thread accesses the encoder at a time (exclusive ownership).
576/// 2. The encoder is not used concurrently from multiple threads.
577///
578/// Metal command buffers and compute encoders are thread-safe for exclusive
579/// access (Apple documentation: "You can create command buffers, encode
580/// commands, and submit them from any thread"). The raw pointer
581/// `active_encoder` borrows from `cmd_buf` and is valid as long as
582/// `cmd_buf` is alive — this invariant holds across thread boundaries
583/// because both fields move together.
584///
585/// This matches llama.cpp's pattern of encoding command buffers on GCD
586/// worker threads via `dispatch_apply`, and is used for the dual-buffer
587/// pipeline where buf1 is encoded on a worker thread while buf0 executes.
588unsafe impl Send for CommandEncoder {}
589
590impl CommandEncoder {
591    /// Create a new command encoder from the given command queue.
592    ///
593    /// This immediately creates a Metal command buffer.
594    ///
595    /// # Why retained references
596    ///
597    /// We use the regular `commandBuffer` (Metal retains every bound
598    /// resource for the lifetime of the buffer) rather than
599    /// `commandBufferWithUnretainedReferences`.  llama.cpp uses unretained
600    /// refs for an additional perf bump (~3-5% on M-series GPUs), but the
601    /// hf2q dispatch pattern allocates many transient scratch buffers
602    /// inside helper functions (`apply_proj` → `weight_bf16_owned`,
603    /// `apply_pre_norm` → `params`, etc.) that go out of scope at the
604    /// helper's return.  With unretained refs the metal::Buffer's ARC
605    /// drops to zero, freeing the underlying GPU memory before the
606    /// dispatch executes.  Verified 2026-04-26: switching to unretained
607    /// hits "Command buffer error: GPU command buffer completed with
608    /// error status" on the first MoE FFN dispatch.
609    ///
610    /// To enable unretained refs in the future, every helper that
611    /// allocates and dispatches must thread its scratch buffers up to a
612    /// caller scope that outlives the eventual commit, OR all such
613    /// scratch must come from the per-decode-token pool (which already
614    /// ARC-retains in its in_use list).  Today the lm_head + router-
615    /// download paths are still unpooled.
616    #[allow(dead_code)]
617    pub(crate) fn new(queue: &CommandQueue) -> Result<Self> {
618        Self::new_with_residency(queue, None)
619    }
620
621    /// Create a new command encoder, optionally bound to a residency set so
622    /// `commit*` boundaries can flush deferred add/remove staging.
623    ///
624    /// ADR-015 iter8e (Phase 3b): the encoder's `commit_and_wait`,
625    /// `commit_and_wait_labeled`, `commit`, `commit_labeled`,
626    /// `commit_wait_with_gpu_time` all call
627    /// [`ResidencySet::flush_pending`](ResidencySet::flush_pending) before
628    /// submitting the Metal command buffer. This converts the
629    /// per-allocation `[set commit]` storm
630    /// (~880 commits/decode-token in iter8d/8e claude+codex variants) into
631    /// at most one commit per CB submission — mirrors llama.cpp's
632    /// `ggml-metal-device.m:1378-1382` pattern (batch addAllocation in
633    /// loop, commit ONCE).
634    ///
635    /// ADR-015 iter13: when the `MLX_UNRETAINED_REFS=1` env var is set at
636    /// process start, this constructor uses
637    /// [`CommandQueueRef::new_command_buffer_with_unretained_references`]
638    /// instead of `new_command_buffer`.  llama.cpp's per-token decode CBs
639    /// use `commandBufferWithUnretainedReferences` (see
640    /// `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m:512`) which
641    /// skips Metal's per-buffer-binding ARC-retain on submit and saves
642    /// ~3-5% on M-series GPUs (per the docstring above).
643    ///
644    /// **Caller contract under unretained refs.**  Every Metal buffer bound
645    /// to a dispatch in this CB MUST outlive the CB's GPU completion.  In
646    /// the hf2q decode path, that means every transient scratch must be
647    /// either (a) backed by the per-decode-token arena pool
648    /// (`MlxBufferPool` keeps an ARC clone in `in_use` until the next
649    /// `reset` — see `buffer_pool.rs:60`) or (b) hoisted to a caller scope
650    /// that lives across the terminal `commit_and_wait_labeled`.  Helpers
651    /// in `apply_proj` / `apply_pre_norm` / lm_head cast / router-download
652    /// that allocated transients via `device.alloc_buffer` and dropped
653    /// them at function return MUST be lifted to `pooled_alloc_buffer`
654    /// before `MLX_UNRETAINED_REFS=1` is enabled, or the first MoE FFN
655    /// dispatch will crash with "Command buffer error: GPU command buffer
656    /// completed with error status" (verified 2026-04-26).
657    ///
658    /// The default (`MLX_UNRETAINED_REFS` unset) preserves retained-refs
659    /// behavior verbatim — this is the sourdough-safe path.
660    pub(crate) fn new_with_residency(
661        queue: &CommandQueue,
662        residency_set: Option<ResidencySet>,
663    ) -> Result<Self> {
664        let cmd_buf = if unretained_refs_enabled() {
665            queue.new_command_buffer_with_unretained_references().to_owned()
666        } else {
667            queue.new_command_buffer().to_owned()
668        };
669        CMD_BUF_COUNT.fetch_add(1, Ordering::Relaxed);
670        Ok(Self {
671            cmd_buf,
672            active_encoder: std::ptr::null(),
673            capture: None,
674            pending_op_kind: CapturedOpKind::Other,
675            pending_reads: Vec::new(),
676            pending_writes: Vec::new(),
677            residency_set,
678            mem_ranges: MemRanges::new(),
679            sample_buffer: None,
680            pending_dispatch_meta: Vec::new(),
681            dispatch_in_cb: 0,
682            last_label: String::new(),
683        })
684    }
685
686    /// Enable capture mode.
687    ///
688    /// All subsequent dispatch and barrier calls will be recorded into a
689    /// `Vec<CapturedNode>` instead of being encoded into Metal.
690    /// Call `take_capture()` to extract the recorded nodes.
691    pub fn start_capture(&mut self) {
692        self.capture = Some(Vec::with_capacity(128));
693    }
694
695    /// Whether the encoder is currently in capture mode.
696    pub fn is_capturing(&self) -> bool {
697        self.capture.is_some()
698    }
699
700    /// Extract the captured nodes, ending capture mode.
701    ///
702    /// Returns `None` if capture mode was not active.
703    pub fn take_capture(&mut self) -> Option<Vec<CapturedNode>> {
704        self.capture.take()
705    }
706
707    /// Tag the NEXT captured dispatch with the given operation kind.
708    ///
709    /// The tag is consumed (reset to `Other`) after the next dispatch is
710    /// captured.  Only meaningful in capture mode — has no effect on
711    /// direct-dispatch encoding.
712    ///
713    /// Used by op dispatch functions to annotate captures for the fusion
714    /// pass (Phase 4e.2).
715    pub fn set_op_kind(&mut self, kind: CapturedOpKind) {
716        self.pending_op_kind = kind;
717    }
718
719    /// Consume and return the pending op kind, resetting it to `Other`.
720    fn take_pending_op_kind(&mut self) -> CapturedOpKind {
721        let kind = self.pending_op_kind;
722        self.pending_op_kind = CapturedOpKind::Other;
723        kind
724    }
725
726    /// Stash buffer range annotations for the NEXT captured dispatch.
727    ///
728    /// Called by `GraphSession::barrier_between()` in capture mode to record
729    /// which buffers the next dispatch reads from and writes to.  The ranges
730    /// are consumed by the next `encode_*` call and attached to the captured
731    /// `CapturedNode::Dispatch`.
732    ///
733    /// Only meaningful in capture mode — has no effect on direct-dispatch.
734    pub fn set_pending_buffer_ranges(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
735        self.pending_reads = reads;
736        self.pending_writes = writes;
737    }
738
739    /// Patch the last captured dispatch node's empty reads/writes with the
740    /// given ranges. No-op if not capturing, or if the last node isn't a
741    /// Dispatch, or if its ranges are already populated.
742    ///
743    /// Used by `GraphSession::track_dispatch` in recording mode to annotate
744    /// dispatches that were called without a preceding `barrier_between`.
745    pub fn annotate_last_dispatch_if_missing(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
746        if let Some(ref mut nodes) = self.capture {
747            if let Some(CapturedNode::Dispatch { reads: r, writes: w, .. }) = nodes.last_mut() {
748                if r.is_empty() && !reads.is_empty() {
749                    *r = reads;
750                }
751                if w.is_empty() && !writes.is_empty() {
752                    *w = writes;
753                }
754            }
755        }
756    }
757
758    /// Consume and return the pending buffer range annotations.
759    fn take_pending_buffer_ranges(&mut self) -> (Vec<MemRange>, Vec<MemRange>) {
760        let reads = std::mem::take(&mut self.pending_reads);
761        let writes = std::mem::take(&mut self.pending_writes);
762        (reads, writes)
763    }
764
765    /// Record buffer bindings into `RecordedBinding` form.
766    fn record_buffer_bindings(buffers: &[(u64, &MlxBuffer)]) -> Vec<(u64, RecordedBinding)> {
767        buffers
768            .iter()
769            .map(|&(index, buf)| {
770                (
771                    index,
772                    RecordedBinding::Buffer {
773                        metal_buffer: buf.metal_buffer().clone(),
774                        offset: buf.byte_offset(),
775                    },
776                )
777            })
778            .collect()
779    }
780
781    /// Record `KernelArg` bindings into `RecordedBinding` form.
782    ///
783    /// `KernelArg::Buffer(buf)` records `buf.byte_offset()` so capture →
784    /// replay round-trips of `slice_view`-derived buffers preserve their
785    /// offsets, matching `record_buffer_bindings`'s behavior at line 382.
786    fn record_arg_bindings(bindings: &[(u64, KernelArg<'_>)]) -> Vec<(u64, RecordedBinding)> {
787        bindings
788            .iter()
789            .map(|(index, arg)| {
790                let recorded = match arg {
791                    KernelArg::Buffer(buf) => RecordedBinding::Buffer {
792                        metal_buffer: buf.metal_buffer().clone(),
793                        offset: buf.byte_offset(),
794                    },
795                    KernelArg::BufferWithOffset(buf, offset) => RecordedBinding::Buffer {
796                        metal_buffer: buf.metal_buffer().clone(),
797                        offset: *offset,
798                    },
799                    KernelArg::Bytes(bytes) => RecordedBinding::Bytes(bytes.to_vec()),
800                };
801                (*index, recorded)
802            })
803            .collect()
804    }
805
806    /// Get or create the persistent compute encoder.
807    ///
808    /// On the first call, creates a new compute encoder from the command
809    /// buffer.  On subsequent calls, returns the existing one.
810    ///
811    /// SAFETY: The returned reference borrows from `self.cmd_buf` which is
812    /// alive for the lifetime of this `CommandEncoder`.  The raw pointer is
813    /// valid until `end_active_encoder()` is called.
814    #[inline]
815    fn get_or_create_encoder(&mut self) -> &ComputeCommandEncoderRef {
816        if self.active_encoder.is_null() {
817            // Use MTLDispatchTypeConcurrent to allow independent dispatches
818            // to overlap on the GPU.  Memory barriers are inserted between
819            // dependent dispatches via `memory_barrier()`.
820            //
821            // ADR-015 iter61a-2 probe: HF2Q_FORCE_SERIAL_DISPATCH=1 falls back
822            // to MTLDispatchType::Serial — every dispatch waits for the
823            // previous to complete, eliminating concurrent-dispatch race
824            // windows. Used to falsify Hypothesis (g): missing memory_barrier
825            // calls between dependent dispatches cause cold-run logit
826            // non-determinism via thread-race on a shared buffer.
827            let dispatch_type = if std::env::var("HF2Q_FORCE_SERIAL_DISPATCH")
828                .map(|v| v == "1")
829                .unwrap_or(false)
830            {
831                MTLDispatchType::Serial
832            } else {
833                MTLDispatchType::Concurrent
834            };
835            let encoder = self
836                .cmd_buf
837                .compute_command_encoder_with_dispatch_type(dispatch_type);
838            self.active_encoder = encoder as *const ComputeCommandEncoderRef;
839        }
840        // SAFETY: active_encoder is non-null and points to a valid encoder
841        // owned by cmd_buf.
842        unsafe { &*self.active_encoder }
843    }
844
845    /// End the active compute encoder if one exists.
846    #[inline]
847    fn end_active_encoder(&mut self) {
848        if !self.active_encoder.is_null() {
849            // SAFETY: the pointer was obtained from cmd_buf.new_compute_command_encoder()
850            // and has not been ended yet.
851            unsafe { &*self.active_encoder }.end_encoding();
852            self.active_encoder = std::ptr::null();
853        }
854    }
855
856    /// Insert a memory barrier with scope `MTLBarrierScopeBuffers`.
857    ///
858    /// When the encoder uses `MTLDispatchTypeConcurrent`, all dispatches can
859    /// execute concurrently unless separated by a barrier.  Call this between
860    /// dispatches where the later dispatch reads a buffer written by an
861    /// earlier one.
862    ///
863    /// This is the same pattern llama.cpp uses:
864    /// `[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]`
865    #[allow(unexpected_cfgs)]
866    pub fn memory_barrier(&mut self) {
867        if let Some(ref mut nodes) = self.capture {
868            nodes.push(CapturedNode::Barrier);
869            return;
870        }
871        if self.active_encoder.is_null() {
872            return;
873        }
874        BARRIER_COUNT.fetch_add(1, Ordering::Relaxed);
875        // SAFETY: active_encoder is non-null and valid.
876        let encoder = unsafe { &*self.active_encoder };
877        if barrier_profile_enabled() {
878            // mach_absolute_time path — only on when MLX_PROFILE_BARRIERS=1.
879            let start = std::time::Instant::now();
880            issue_metal_buffer_barrier(encoder);
881            let elapsed_ns = start.elapsed().as_nanos() as u64;
882            BARRIER_NS.fetch_add(elapsed_ns, Ordering::Relaxed);
883        } else {
884            issue_metal_buffer_barrier(encoder);
885        }
886    }
887
888    /// Set the compute pipeline state for subsequent dispatches.
889    ///
890    /// This begins a new compute pass if one is not already active.
891    pub fn set_pipeline(&mut self, pipeline: &ComputePipelineStateRef) {
892        let encoder = self.get_or_create_encoder();
893        encoder.set_compute_pipeline_state(pipeline);
894    }
895
896    /// Bind a buffer to a compute kernel argument slot.
897    ///
898    /// The `index` corresponds to the `[[buffer(N)]]` attribute in the MSL shader.
899    pub fn set_buffer(&self, index: u64, buffer: &MlxBuffer) {
900        let _ = (index, buffer);
901    }
902
903    /// Dispatch threads on the GPU.
904    pub fn dispatch_threads(&self, grid_size: MTLSize, threadgroup_size: MTLSize) {
905        let _ = (grid_size, threadgroup_size);
906    }
907
908    /// Encode a complete compute pass: set pipeline, bind buffers, dispatch.
909    ///
910    /// Reuses the persistent compute encoder — no per-dispatch encoder
911    /// creation overhead.
912    ///
913    /// # Arguments
914    ///
915    /// * `pipeline`         — The compiled compute pipeline to execute.
916    /// * `buffers`          — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
917    /// * `grid_size`        — Total number of threads to launch.
918    /// * `threadgroup_size` — Threads per threadgroup.
919    pub fn encode(
920        &mut self,
921        pipeline: &ComputePipelineStateRef,
922        buffers: &[(u64, &MlxBuffer)],
923        grid_size: MTLSize,
924        threadgroup_size: MTLSize,
925    ) {
926        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
927        let op_kind = self.take_pending_op_kind();
928        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
929        if let Some(ref mut nodes) = self.capture {
930            nodes.push(CapturedNode::Dispatch {
931                pipeline: pipeline.to_owned(),
932                bindings: Self::record_buffer_bindings(buffers),
933                threads_per_grid: grid_size,
934                threads_per_threadgroup: threadgroup_size,
935                threadgroup_memory: Vec::new(),
936                dispatch_kind: DispatchKind::Threads,
937                op_kind,
938                reads: pending_reads,
939                writes: pending_writes,
940            });
941            return;
942        }
943        // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
944        self.ensure_sample_buffer();
945        let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
946        // SAFETY: encoder_ptr aliases &self via active_encoder which we
947        // know is non-null after get_or_create_encoder; this pattern is
948        // used throughout the file (see memory_barrier).
949        let encoder = unsafe { &*encoder_ptr };
950        encoder.set_compute_pipeline_state(pipeline);
951        for &(index, buf) in buffers {
952            encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
953        }
954        let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
955        encoder.dispatch_threads(grid_size, threadgroup_size);
956        self.sample_dispatch_post(encoder, pre_idx);
957    }
958
959    /// Encode a compute pass using threadgroups instead of raw thread counts.
960    ///
961    /// Reuses the persistent compute encoder — no per-dispatch encoder
962    /// creation overhead.
963    pub fn encode_threadgroups(
964        &mut self,
965        pipeline: &ComputePipelineStateRef,
966        buffers: &[(u64, &MlxBuffer)],
967        threadgroups: MTLSize,
968        threadgroup_size: MTLSize,
969    ) {
970        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
971        let op_kind = self.take_pending_op_kind();
972        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
973        if let Some(ref mut nodes) = self.capture {
974            nodes.push(CapturedNode::Dispatch {
975                pipeline: pipeline.to_owned(),
976                bindings: Self::record_buffer_bindings(buffers),
977                threads_per_grid: threadgroups,
978                threads_per_threadgroup: threadgroup_size,
979                threadgroup_memory: Vec::new(),
980                dispatch_kind: DispatchKind::ThreadGroups,
981                op_kind,
982                reads: pending_reads,
983                writes: pending_writes,
984            });
985            return;
986        }
987        // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
988        self.ensure_sample_buffer();
989        let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
990        // SAFETY: see encode() above.
991        let encoder = unsafe { &*encoder_ptr };
992        encoder.set_compute_pipeline_state(pipeline);
993        for &(index, buf) in buffers {
994            encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
995        }
996        let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
997        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
998        self.sample_dispatch_post(encoder, pre_idx);
999    }
1000
1001    /// Encode a compute pass using threadgroups with shared threadgroup memory.
1002    ///
1003    /// Like [`encode_threadgroups`](Self::encode_threadgroups), but additionally
1004    /// allocates threadgroup memory at the specified indices.  This is required
1005    /// for kernels that use `threadgroup` memory (e.g. reductions in rms_norm
1006    /// and softmax).
1007    ///
1008    /// # Arguments
1009    ///
1010    /// * `pipeline`         — The compiled compute pipeline to execute.
1011    /// * `buffers`          — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
1012    /// * `threadgroup_mem`  — Slice of `(index, byte_length)` pairs for threadgroup memory.
1013    /// * `threadgroups`     — Number of threadgroups to dispatch.
1014    /// * `threadgroup_size` — Threads per threadgroup.
1015    pub fn encode_threadgroups_with_shared(
1016        &mut self,
1017        pipeline: &ComputePipelineStateRef,
1018        buffers: &[(u64, &MlxBuffer)],
1019        threadgroup_mem: &[(u64, u64)],
1020        threadgroups: MTLSize,
1021        threadgroup_size: MTLSize,
1022    ) {
1023        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1024        let op_kind = self.take_pending_op_kind();
1025        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1026        if let Some(ref mut nodes) = self.capture {
1027            nodes.push(CapturedNode::Dispatch {
1028                pipeline: pipeline.to_owned(),
1029                bindings: Self::record_buffer_bindings(buffers),
1030                threads_per_grid: threadgroups,
1031                threads_per_threadgroup: threadgroup_size,
1032                threadgroup_memory: threadgroup_mem.to_vec(),
1033                dispatch_kind: DispatchKind::ThreadGroups,
1034                op_kind,
1035                reads: pending_reads,
1036                writes: pending_writes,
1037            });
1038            return;
1039        }
1040        // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1041        self.ensure_sample_buffer();
1042        let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1043        // SAFETY: see encode() above.
1044        let encoder = unsafe { &*encoder_ptr };
1045        encoder.set_compute_pipeline_state(pipeline);
1046        for &(index, buf) in buffers {
1047            encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
1048        }
1049        for &(index, byte_length) in threadgroup_mem {
1050            encoder.set_threadgroup_memory_length(index, byte_length);
1051        }
1052        let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1053        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
1054        self.sample_dispatch_post(encoder, pre_idx);
1055    }
1056
1057    /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_threads).
1058    ///
1059    /// Reuses the persistent compute encoder.
1060    pub fn encode_with_args(
1061        &mut self,
1062        pipeline: &ComputePipelineStateRef,
1063        bindings: &[(u64, KernelArg<'_>)],
1064        grid_size: MTLSize,
1065        threadgroup_size: MTLSize,
1066    ) {
1067        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1068        let op_kind = self.take_pending_op_kind();
1069        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1070        if let Some(ref mut nodes) = self.capture {
1071            nodes.push(CapturedNode::Dispatch {
1072                pipeline: pipeline.to_owned(),
1073                bindings: Self::record_arg_bindings(bindings),
1074                threads_per_grid: grid_size,
1075                threads_per_threadgroup: threadgroup_size,
1076                threadgroup_memory: Vec::new(),
1077                dispatch_kind: DispatchKind::Threads,
1078                op_kind,
1079                reads: pending_reads,
1080                writes: pending_writes,
1081            });
1082            return;
1083        }
1084        // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1085        self.ensure_sample_buffer();
1086        let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1087        // SAFETY: see encode() above.
1088        let encoder = unsafe { &*encoder_ptr };
1089        encoder.set_compute_pipeline_state(pipeline);
1090        apply_bindings(encoder, bindings);
1091        let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1092        encoder.dispatch_threads(grid_size, threadgroup_size);
1093        self.sample_dispatch_post(encoder, pre_idx);
1094    }
1095
1096    /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_thread_groups).
1097    ///
1098    /// Reuses the persistent compute encoder.
1099    pub fn encode_threadgroups_with_args(
1100        &mut self,
1101        pipeline: &ComputePipelineStateRef,
1102        bindings: &[(u64, KernelArg<'_>)],
1103        threadgroups: MTLSize,
1104        threadgroup_size: MTLSize,
1105    ) {
1106        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1107        let op_kind = self.take_pending_op_kind();
1108        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1109        if let Some(ref mut nodes) = self.capture {
1110            nodes.push(CapturedNode::Dispatch {
1111                pipeline: pipeline.to_owned(),
1112                bindings: Self::record_arg_bindings(bindings),
1113                threads_per_grid: threadgroups,
1114                threads_per_threadgroup: threadgroup_size,
1115                threadgroup_memory: Vec::new(),
1116                dispatch_kind: DispatchKind::ThreadGroups,
1117                op_kind,
1118                reads: pending_reads,
1119                writes: pending_writes,
1120            });
1121            return;
1122        }
1123        // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1124        self.ensure_sample_buffer();
1125        let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1126        // SAFETY: see encode() above.
1127        let encoder = unsafe { &*encoder_ptr };
1128        encoder.set_compute_pipeline_state(pipeline);
1129        apply_bindings(encoder, bindings);
1130        let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1131        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
1132        self.sample_dispatch_post(encoder, pre_idx);
1133    }
1134
1135    /// Encode a dispatch with mixed buffer/bytes bindings and shared memory.
1136    ///
1137    /// Reuses the persistent compute encoder.
1138    pub fn encode_threadgroups_with_args_and_shared(
1139        &mut self,
1140        pipeline: &ComputePipelineStateRef,
1141        bindings: &[(u64, KernelArg<'_>)],
1142        threadgroup_mem: &[(u64, u64)],
1143        threadgroups: MTLSize,
1144        threadgroup_size: MTLSize,
1145    ) {
1146        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1147        let op_kind = self.take_pending_op_kind();
1148        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1149        if let Some(ref mut nodes) = self.capture {
1150            nodes.push(CapturedNode::Dispatch {
1151                pipeline: pipeline.to_owned(),
1152                bindings: Self::record_arg_bindings(bindings),
1153                threads_per_grid: threadgroups,
1154                threads_per_threadgroup: threadgroup_size,
1155                threadgroup_memory: threadgroup_mem.to_vec(),
1156                dispatch_kind: DispatchKind::ThreadGroups,
1157                op_kind,
1158                reads: pending_reads,
1159                writes: pending_writes,
1160            });
1161            return;
1162        }
1163        // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1164        self.ensure_sample_buffer();
1165        let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1166        // SAFETY: see encode() above.
1167        let encoder = unsafe { &*encoder_ptr };
1168        encoder.set_compute_pipeline_state(pipeline);
1169        apply_bindings(encoder, bindings);
1170        for &(index, byte_length) in threadgroup_mem {
1171            encoder.set_threadgroup_memory_length(index, byte_length);
1172        }
1173        let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1174        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
1175        self.sample_dispatch_post(encoder, pre_idx);
1176    }
1177
1178    // -----------------------------------------------------------------
1179    // ADR-015 iter37 — dataflow-driven auto-barrier dispatch family.
1180    //
1181    // These mirrors of `encode_threadgroups*_with_args*` take explicit
1182    // `reads: &[&MlxBuffer]` and `writes: &[&MlxBuffer]` slices.  When
1183    // the process started with `HF2Q_AUTO_BARRIER=1`, the encoder's
1184    // [`MemRanges`] tracker checks the new ranges against the
1185    // cumulative state since the last barrier; on conflict it emits
1186    // `memory_barrier()` and resets the state before recording the
1187    // new ranges.  When the env gate is unset, the check is skipped
1188    // entirely and the dispatch is applied identically to the
1189    // matching `encode_*` method — sourdough-safe by construction.
1190    //
1191    // Capture mode: the `reads`/`writes` ranges are recorded onto the
1192    // captured node via the existing `pending_reads`/`pending_writes`
1193    // mechanism, so a `dispatch_tracked` call inside capture mode is
1194    // equivalent to `set_pending_buffer_ranges + encode_*`.
1195    //
1196    // No production callsite migrates in iter37 — this is the API
1197    // surface the qwen35 forward path will adopt incrementally in
1198    // iter38+.  Today, every call to `dispatch_tracked` from a
1199    // production code path lives behind an explicit caller decision
1200    // to opt in.
1201    // -----------------------------------------------------------------
1202
1203    /// Auto-barrier-aware dispatch with [`KernelArg`] bindings (uses
1204    /// `dispatch_thread_groups`).
1205    ///
1206    /// Behaves identically to
1207    /// [`encode_threadgroups_with_args`](Self::encode_threadgroups_with_args)
1208    /// when `HF2Q_AUTO_BARRIER` is unset.  When set, consults the
1209    /// per-encoder [`MemRanges`] tracker:
1210    ///
1211    /// * Conflict (RAW/WAR/WAW on a same-buffer range) → emit
1212    ///   `memory_barrier()`, increment [`AUTO_BARRIER_COUNT`], reset
1213    ///   the tracker, then dispatch and seed the new concurrent group
1214    ///   with this dispatch's ranges.
1215    /// * No conflict → increment [`AUTO_BARRIER_CONCURRENT`], record
1216    ///   the ranges into the cumulative state, dispatch.
1217    pub fn dispatch_tracked_threadgroups_with_args(
1218        &mut self,
1219        pipeline: &ComputePipelineStateRef,
1220        bindings: &[(u64, KernelArg<'_>)],
1221        reads: &[&MlxBuffer],
1222        writes: &[&MlxBuffer],
1223        threadgroups: MTLSize,
1224        threadgroup_size: MTLSize,
1225    ) {
1226        // Capture mode: stash ranges + delegate to the standard encode.
1227        // The ranges flow through `pending_reads`/`pending_writes` and
1228        // attach to the captured `Dispatch` node — identical to what
1229        // `GraphSession::barrier_between` already does in capture mode.
1230        if self.is_capturing() {
1231            let read_ranges = ranges_from_buffers(reads);
1232            let write_ranges = ranges_from_buffers(writes);
1233            self.set_pending_buffer_ranges(read_ranges, write_ranges);
1234            self.encode_threadgroups_with_args(pipeline, bindings, threadgroups, threadgroup_size);
1235            return;
1236        }
1237
1238        if auto_barrier_enabled() {
1239            self.maybe_auto_barrier(reads, writes);
1240        }
1241
1242        self.encode_threadgroups_with_args(pipeline, bindings, threadgroups, threadgroup_size);
1243    }
1244
1245    /// Auto-barrier-aware dispatch with [`KernelArg`] bindings + shared
1246    /// threadgroup memory.
1247    ///
1248    /// See [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args)
1249    /// for the behavioral contract; this variant additionally takes a
1250    /// `threadgroup_mem` slice that is forwarded to
1251    /// [`encode_threadgroups_with_args_and_shared`](Self::encode_threadgroups_with_args_and_shared).
1252    ///
1253    /// The 8-argument signature mirrors the existing
1254    /// `encode_threadgroups_with_args_and_shared` plus the two
1255    /// dataflow slices; `clippy::too_many_arguments` is allowed
1256    /// because each parameter is load-bearing for either the dispatch
1257    /// (pipeline/bindings/threadgroups/threadgroup_size/shared_mem)
1258    /// or the auto-barrier (reads/writes).
1259    #[allow(clippy::too_many_arguments)]
1260    pub fn dispatch_tracked_threadgroups_with_args_and_shared(
1261        &mut self,
1262        pipeline: &ComputePipelineStateRef,
1263        bindings: &[(u64, KernelArg<'_>)],
1264        threadgroup_mem: &[(u64, u64)],
1265        reads: &[&MlxBuffer],
1266        writes: &[&MlxBuffer],
1267        threadgroups: MTLSize,
1268        threadgroup_size: MTLSize,
1269    ) {
1270        if self.is_capturing() {
1271            let read_ranges = ranges_from_buffers(reads);
1272            let write_ranges = ranges_from_buffers(writes);
1273            self.set_pending_buffer_ranges(read_ranges, write_ranges);
1274            self.encode_threadgroups_with_args_and_shared(
1275                pipeline,
1276                bindings,
1277                threadgroup_mem,
1278                threadgroups,
1279                threadgroup_size,
1280            );
1281            return;
1282        }
1283
1284        if auto_barrier_enabled() {
1285            self.maybe_auto_barrier(reads, writes);
1286        }
1287
1288        self.encode_threadgroups_with_args_and_shared(
1289            pipeline,
1290            bindings,
1291            threadgroup_mem,
1292            threadgroups,
1293            threadgroup_size,
1294        );
1295    }
1296
1297    /// Auto-barrier-aware dispatch using `(slot, &MlxBuffer)` bindings
1298    /// (uses `dispatch_thread_groups`).
1299    ///
1300    /// Convenience wrapper for callers that don't need
1301    /// [`KernelArg::Bytes`] inline-byte arguments.  See
1302    /// [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args)
1303    /// for behavioral contract.
1304    pub fn dispatch_tracked_threadgroups(
1305        &mut self,
1306        pipeline: &ComputePipelineStateRef,
1307        buffers: &[(u64, &MlxBuffer)],
1308        reads: &[&MlxBuffer],
1309        writes: &[&MlxBuffer],
1310        threadgroups: MTLSize,
1311        threadgroup_size: MTLSize,
1312    ) {
1313        if self.is_capturing() {
1314            let read_ranges = ranges_from_buffers(reads);
1315            let write_ranges = ranges_from_buffers(writes);
1316            self.set_pending_buffer_ranges(read_ranges, write_ranges);
1317            self.encode_threadgroups(pipeline, buffers, threadgroups, threadgroup_size);
1318            return;
1319        }
1320
1321        if auto_barrier_enabled() {
1322            self.maybe_auto_barrier(reads, writes);
1323        }
1324
1325        self.encode_threadgroups(pipeline, buffers, threadgroups, threadgroup_size);
1326    }
1327
1328    /// Auto-barrier-aware dispatch using `(slot, &MlxBuffer)` bindings
1329    /// **plus shared threadgroup memory** (uses `dispatch_thread_groups`).
1330    ///
1331    /// Mirrors [`encode_threadgroups_with_shared`](Self::encode_threadgroups_with_shared)
1332    /// — convenience variant for kernels that allocate threadgroup
1333    /// memory (reductions in `rms_norm`, `softmax`, etc.) but don't
1334    /// need [`KernelArg::Bytes`] inline-byte arguments.  See
1335    /// [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args)
1336    /// for the behavioral contract; the only addition here is the
1337    /// `threadgroup_mem` slice forwarded to the underlying encode.
1338    ///
1339    /// Closes the iter38-audit coverage gap: the 5 `rms_norm.rs`
1340    /// callsites (`/opt/mlx-native/src/ops/rms_norm.rs:124,236,443,
1341    /// 516,589`) all use `encode_threadgroups_with_shared` and need
1342    /// dataflow tracking when migrated to auto-barrier in iter40+.
1343    ///
1344    /// 7-argument signature; `clippy::too_many_arguments` is allowed
1345    /// because each parameter is load-bearing for either the dispatch
1346    /// (pipeline/buffers/threadgroups/threadgroup_size/shared_mem) or
1347    /// the auto-barrier (reads/writes).
1348    #[allow(clippy::too_many_arguments)]
1349    pub fn dispatch_tracked_threadgroups_with_shared(
1350        &mut self,
1351        pipeline: &ComputePipelineStateRef,
1352        buffers: &[(u64, &MlxBuffer)],
1353        threadgroup_mem: &[(u64, u64)],
1354        reads: &[&MlxBuffer],
1355        writes: &[&MlxBuffer],
1356        threadgroups: MTLSize,
1357        threadgroup_size: MTLSize,
1358    ) {
1359        if self.is_capturing() {
1360            let read_ranges = ranges_from_buffers(reads);
1361            let write_ranges = ranges_from_buffers(writes);
1362            self.set_pending_buffer_ranges(read_ranges, write_ranges);
1363            self.encode_threadgroups_with_shared(
1364                pipeline,
1365                buffers,
1366                threadgroup_mem,
1367                threadgroups,
1368                threadgroup_size,
1369            );
1370            return;
1371        }
1372
1373        if auto_barrier_enabled() {
1374            self.maybe_auto_barrier(reads, writes);
1375        }
1376
1377        self.encode_threadgroups_with_shared(
1378            pipeline,
1379            buffers,
1380            threadgroup_mem,
1381            threadgroups,
1382            threadgroup_size,
1383        );
1384    }
1385
1386    /// Auto-barrier-aware `dispatch_threads` variant with
1387    /// [`KernelArg`] bindings.
1388    ///
1389    /// Mirrors [`encode_with_args`](Self::encode_with_args) — the
1390    /// `dispatch_threads` (per-thread grid) flavor, as opposed to the
1391    /// `dispatch_thread_groups` flavor of
1392    /// [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args).
1393    /// See that method for the behavioral contract.
1394    ///
1395    /// Closes the iter38-audit coverage gap: callers that use
1396    /// per-thread grids — `rope.rs:108` (IMROPE), `sigmoid_mul.rs:76`
1397    /// (sigmoid-mul), and `encode_helpers.rs:41` (kv_cache_copy) —
1398    /// need a `dispatch_threads` flavor of the tracked dispatch
1399    /// because their grid sizes are expressed in threads, not
1400    /// threadgroups.
1401    ///
1402    /// Note: the simpler `(slot, &MlxBuffer)` form (from
1403    /// [`encode`](Self::encode)) is a special case of this method —
1404    /// callers can wrap each binding as `KernelArg::Buffer(buf)` to
1405    /// reuse this single tracked variant rather than introducing a
1406    /// fifth one.
1407    pub fn dispatch_tracked_threads_with_args(
1408        &mut self,
1409        pipeline: &ComputePipelineStateRef,
1410        bindings: &[(u64, KernelArg<'_>)],
1411        reads: &[&MlxBuffer],
1412        writes: &[&MlxBuffer],
1413        grid_size: MTLSize,
1414        threadgroup_size: MTLSize,
1415    ) {
1416        if self.is_capturing() {
1417            let read_ranges = ranges_from_buffers(reads);
1418            let write_ranges = ranges_from_buffers(writes);
1419            self.set_pending_buffer_ranges(read_ranges, write_ranges);
1420            self.encode_with_args(pipeline, bindings, grid_size, threadgroup_size);
1421            return;
1422        }
1423
1424        if auto_barrier_enabled() {
1425            self.maybe_auto_barrier(reads, writes);
1426        }
1427
1428        self.encode_with_args(pipeline, bindings, grid_size, threadgroup_size);
1429    }
1430
1431    /// Run the dataflow check, emit a barrier on conflict, and record
1432    /// the dispatch's ranges into the cumulative state.
1433    ///
1434    /// Always called *before* the underlying `encode_*` method
1435    /// applies the dispatch.  Mirrors lines 220-225 of
1436    /// `ggml-metal-ops.cpp` (`concurrency_check + concurrency_reset +
1437    /// concurrency_add` around each node).
1438    fn maybe_auto_barrier(
1439        &mut self,
1440        reads: &[&MlxBuffer],
1441        writes: &[&MlxBuffer],
1442    ) {
1443        if self.mem_ranges.check_dispatch(reads, writes) {
1444            // Concurrent — no barrier needed; just record the new ranges.
1445            self.mem_ranges.add_dispatch(reads, writes);
1446            AUTO_BARRIER_CONCURRENT.fetch_add(1, Ordering::Relaxed);
1447        } else {
1448            // Conflict — emit barrier, reset state, seed new group.
1449            //
1450            // `memory_barrier()` itself increments `BARRIER_COUNT` and,
1451            // when `MLX_PROFILE_BARRIERS=1`, accumulates `BARRIER_NS`.
1452            // We additionally bump `AUTO_BARRIER_COUNT` so the
1453            // "auto-emitted vs hand-placed" subset is queryable.
1454            self.memory_barrier();
1455            self.mem_ranges.reset();
1456            self.mem_ranges.add_dispatch(reads, writes);
1457            AUTO_BARRIER_COUNT.fetch_add(1, Ordering::Relaxed);
1458        }
1459    }
1460
1461    /// Force a barrier and reset the auto-barrier tracker.
1462    ///
1463    /// Use at boundaries where the caller knows a barrier is required
1464    /// regardless of dataflow — typically before reading data back to
1465    /// CPU, or at the end of an op group whose internal dependencies
1466    /// the tracker can't see (e.g. host-driven memcpy).
1467    ///
1468    /// Equivalent to `memory_barrier()` plus a `MemRanges::reset()`
1469    /// when `HF2Q_AUTO_BARRIER=1`; equivalent to plain
1470    /// `memory_barrier()` otherwise.
1471    pub fn force_barrier_and_reset_tracker(&mut self) {
1472        self.memory_barrier();
1473        if auto_barrier_enabled() {
1474            self.mem_ranges.reset();
1475        }
1476    }
1477
    /// Diagnostic accessor — number of ranges currently held by this
    /// encoder's [`MemRanges`] auto-barrier tracker.
    ///
    /// The tracker is fed by `dispatch_tracked_*` calls made outside
    /// capture mode while `HF2Q_AUTO_BARRIER=1`, so it stays zero unless
    /// that gate is set.  It describes the current concurrent group only:
    /// a conflicting dispatch clears the tracker and then re-seeds it with
    /// that dispatch's own ranges, and
    /// [`force_barrier_and_reset_tracker`](Self::force_barrier_and_reset_tracker)
    /// clears it when auto-barriers are enabled.
    #[inline]
    pub fn mem_ranges_len(&self) -> usize {
        self.mem_ranges.len()
    }
1486
    /// Replay a single captured dispatch node into this encoder.
    ///
    /// This is the inverse of capture: it takes the fields of a previously
    /// recorded `CapturedNode::Dispatch` and encodes them into the live
    /// Metal compute encoder (obtained via `get_or_create_encoder`).
    /// Barrier nodes are handled by the caller (ComputeGraph::encode_sequential
    /// and the other graph replay paths).
    ///
    /// * `pipeline` — compute pipeline to bind.
    /// * `bindings` — `(index, binding)` pairs; recorded buffers are bound
    ///   with their recorded offset, recorded bytes via `set_bytes`.
    /// * `threadgroup_memory` — `(index, byte_length)` pairs passed to
    ///   `set_threadgroup_memory_length`.
    /// * `threads_per_grid` — grid size in threads for
    ///   `DispatchKind::Threads`; for `DispatchKind::ThreadGroups` it is
    ///   passed straight to `dispatch_thread_groups`, i.e. it is the
    ///   threadgroup count.
    /// * `threads_per_threadgroup` — threadgroup shape for either kind.
    /// * `dispatch_kind` — selects `dispatch_threads` vs
    ///   `dispatch_thread_groups`.
    ///
    /// Does NOT increment `DISPATCH_COUNT` — that was already counted at
    /// capture time.  When per-dispatch sampling is active, the dispatch is
    /// bracketed by `sample_dispatch_pre` / `sample_dispatch_post`.
    pub fn replay_dispatch(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        bindings: &[(u64, RecordedBinding)],
        threadgroup_memory: &[(u64, u64)],
        threads_per_grid: MTLSize,
        threads_per_threadgroup: MTLSize,
        dispatch_kind: DispatchKind,
    ) {
        // ADR-015 iter63 (Phase A.3): mirror the per-dispatch sampling
        // scaffold here so capture-mode-recorded graphs (graph.rs
        // encode_sequential / encode_with_barriers / encode_chunk_with
        // _barriers) still produce per-dispatch entries.  The replay
        // path bypasses encode*; without this hook the per-dispatch
        // table would be silently empty for any model that uses
        // `GraphExecutor::begin_recorded`.
        //
        // Captured `op_kind` is forwarded via `pending_op_kind`: the
        // graph replay layer at graph.rs:197/236/727 sets it from the
        // CapturedNode.op_kind before calling replay_dispatch.
        self.ensure_sample_buffer();
        let op_kind = self.take_pending_op_kind();
        // The raw-pointer round trip presumably detaches the encoder borrow
        // from `self`, so the `&mut self` sampling helpers below can still
        // be called while `encoder` is in use.
        let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
        // SAFETY: see encode() above.
        let encoder = unsafe { &*encoder_ptr };
        encoder.set_compute_pipeline_state(pipeline);
        for (index, binding) in bindings {
            match binding {
                RecordedBinding::Buffer { metal_buffer, offset } => {
                    encoder.set_buffer(*index, Some(metal_buffer), *offset);
                }
                RecordedBinding::Bytes(bytes) => {
                    // Inline constant data copied into the command stream.
                    encoder.set_bytes(
                        *index,
                        bytes.len() as u64,
                        bytes.as_ptr() as *const _,
                    );
                }
            }
        }
        for &(index, byte_length) in threadgroup_memory {
            encoder.set_threadgroup_memory_length(index, byte_length);
        }
        // Start sample goes after all state is bound, immediately before
        // the dispatch itself; `None` means sampling is inactive/truncated.
        let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
        match dispatch_kind {
            DispatchKind::Threads => {
                encoder.dispatch_threads(threads_per_grid, threads_per_threadgroup);
            }
            DispatchKind::ThreadGroups => {
                encoder.dispatch_thread_groups(threads_per_grid, threads_per_threadgroup);
            }
        }
        self.sample_dispatch_post(encoder, pre_idx);
    }
1549
1550    /// Flush any pending residency-set add/remove staging.
1551    ///
1552    /// Hooked at every commit boundary so per-allocation
1553    /// [`ResidencySet::add_allocation`](ResidencySet::add_allocation) and
1554    /// [`ResidencySet::remove_allocation`](ResidencySet::remove_allocation)
1555    /// calls (as fired by `MlxDevice::alloc_buffer` and
1556    /// `MlxBufferStorage::Drop`) collapse into at most ONE `[set commit]`
1557    /// per CB submission. Mirrors llama.cpp's
1558    /// `ggml-metal-device.m:1378-1382` (batch addAllocation in loop,
1559    /// commit ONCE).
1560    #[inline]
1561    fn flush_residency_pending(&self) {
1562        if let Some(set) = self.residency_set.as_ref() {
1563            set.flush_pending();
1564        }
1565    }
1566
1567    // ----------------------------------------------------------------
1568    // ADR-015 iter63 — per-dispatch sample buffer lifecycle
1569    // ----------------------------------------------------------------
1570
    /// Allocate the per-CB `MTLCounterSampleBuffer` if it has not been
    /// allocated yet for this CB.
    ///
    /// Leaves `sample_buffer` untouched / `None` (i.e. is a no-op) when:
    /// * `MLX_PROFILE_DISPATCH` is unset;
    /// * a sample buffer is already present;
    /// * the device does not support
    ///   `MTLCounterSamplingPointAtDispatchBoundary` (the Apple Silicon
    ///   case — see the inline comment below);
    /// * the device exposes no counter set named `"timestamp"` (Risk R1);
    /// * `newCounterSampleBufferWithDescriptor` fails.
    ///
    /// The last three cases print a stderr warning, but all three share
    /// the `TIMESTAMP_SET_WARN_LOGGED` flag, so only the first such failure
    /// is ever reported.  Nothing is cached on failure, so the capability
    /// probes run again on every call while sampling stays unavailable.
    ///
    /// The sample buffer is sized to [`MAX_SAMPLES_PER_CB`] (32_768).
    /// This is the start-+-end pair budget — i.e. ≤ 16,384 dispatches
    /// per CB.  Above that ceiling, additional dispatches will skip
    /// sampling (see [`Self::sample_dispatch_pre`]).
    #[inline]
    fn ensure_sample_buffer(&mut self) {
        if !crate::kernel_profile::is_dispatch_enabled() {
            return;
        }
        if self.sample_buffer.is_some() {
            return;
        }
        // Discover the timestamp counter set.  metal-rs 0.33 does not
        // export the `MTLCommonCounterSetTimestamp` constant, so we
        // name-match `"timestamp"` case-insensitively.  Reach the
        // device via the cmd_buf's `device` selector (metal-rs 0.33
        // exposes `CommandQueue::device` but not `CommandBuffer::device`,
        // so we go through ObjC directly).
        //
        // SAFETY (review note): relies on `cmd_buf` being a live
        // MTLCommandBuffer whose zero-argument `device` getter returns a
        // valid device object for the duration of this call.
        let device: &metal::DeviceRef = unsafe {
            let cb = &*self.cmd_buf;
            msg_send![cb, device]
        };
        // ADR-015 iter63 — Apple Silicon hardware constraint (NEW Risk
        // discovered at impl time, supersedes design §A.7).  M-series
        // GPUs (verified: AGXG17XFamilyComputeContext = M5 Max series,
        // macOS 26) only support counter sampling AtStageBoundary —
        // i.e. between compute *passes*, not between dispatches inside
        // a persistent compute encoder.  Calling
        // `sampleCountersInBuffer:atSampleIndex:withBarrier:` on such
        // hardware aborts with `failed assertion ... not supported on
        // this device`.  The persistent-encoder design (mlx-native uses
        // ONE compute encoder per CB to amortize ~800 encoder
        // create/end cycles per forward pass — see `get_or_create_
        // encoder` docstring) is incompatible with stage-boundary-only
        // sampling, so on Apple Silicon we degrade per-dispatch
        // profiling to a no-op and log once.  Per-CB profiling is
        // unaffected (it uses MTLCommandBuffer.GPUStartTime/
        // GPUEndTime, which are always available).
        //
        // Future: if Apple ever ships AtDispatchBoundary support on
        // Apple Silicon, this branch becomes a true cap check.  For
        // now, the kit infrastructure is in place; only the sample-
        // point cooperates.
        if !device.supports_counter_sampling(MTLCounterSamplingPoint::AtDispatchBoundary) {
            if TIMESTAMP_SET_WARN_LOGGED
                .compare_exchange(0, 1, Ordering::Relaxed, Ordering::Relaxed)
                .is_ok()
            {
                eprintln!(
                    "[mlx-native] MLX_PROFILE_DISPATCH=1 ignored: \
                     device {:?} does NOT support \
                     MTLCounterSamplingPointAtDispatchBoundary \
                     (Apple Silicon limitation; only AtStageBoundary \
                     is supported, which is incompatible with the \
                     persistent compute-encoder pattern). \
                     MLX_PROFILE_CB=1 still produces per-CB GPU times.",
                    device.name()
                );
            }
            return;
        }
        let counter_sets = device.counter_sets();
        let timestamp_set = counter_sets
            .iter()
            .find(|c: &&metal::CounterSet| c.name().eq_ignore_ascii_case("timestamp"));
        let timestamp_set = match timestamp_set {
            Some(s) => s,
            None => {
                // Risk R1: device does not expose a timestamp set.
                // Log once and degrade to no-op (sample_buffer stays None).
                if TIMESTAMP_SET_WARN_LOGGED
                    .compare_exchange(0, 1, Ordering::Relaxed, Ordering::Relaxed)
                    .is_ok()
                {
                    eprintln!(
                        "[mlx-native] MLX_PROFILE_DISPATCH=1 ignored: \
                         device {:?} exposes no MTLCommonCounterSetTimestamp",
                        device.name()
                    );
                }
                return;
            }
        };
        // Build descriptor.  StorageMode::Shared is required by
        // resolveCounterRange (MTLCounters.h:185-188).
        let descriptor = CounterSampleBufferDescriptor::new();
        descriptor.set_counter_set(timestamp_set);
        descriptor.set_storage_mode(MTLStorageMode::Shared);
        descriptor.set_label("mlx_native.dispatch_samples");
        descriptor.set_sample_count(MAX_SAMPLES_PER_CB);
        match device.new_counter_sample_buffer_with_descriptor(&descriptor) {
            Ok(buf) => {
                self.sample_buffer = Some(buf);
            }
            Err(e) => {
                // Allocation failed: warn (shared one-shot flag) and keep
                // sampling disabled for this CB.
                if TIMESTAMP_SET_WARN_LOGGED
                    .compare_exchange(0, 1, Ordering::Relaxed, Ordering::Relaxed)
                    .is_ok()
                {
                    eprintln!(
                        "[mlx-native] MLX_PROFILE_DISPATCH=1 ignored: \
                         newCounterSampleBufferWithDescriptor failed: {}",
                        e
                    );
                }
                self.sample_buffer = None;
            }
        }
    }
1688
1689    /// Insert the start-of-dispatch counter sample (sample index `2*i`)
1690    /// and queue the per-dispatch metadata.  Returns the dispatch
1691    /// ordinal `i` so the caller can emit the matching post-sample.
1692    ///
1693    /// No-op when sampling is inactive — returns 0 in that case (the
1694    /// returned value is only consumed when the sample buffer is
1695    /// active, so this is safe).
1696    ///
1697    /// `with_barrier:true` is mandatory: the encoder uses
1698    /// `MTLDispatchTypeConcurrent` and without the barrier the start
1699    /// timestamp would race against any in-flight dispatch (PROFILING-
1700    /// KIT-DESIGN §A.5).
1701    #[inline]
1702    fn sample_dispatch_pre(
1703        &mut self,
1704        encoder: &ComputeCommandEncoderRef,
1705        op_kind: CapturedOpKind,
1706    ) -> Option<u32> {
1707        let sb = self.sample_buffer.as_ref()?;
1708        let i = self.dispatch_in_cb;
1709        let pre_idx = (i as u64).checked_mul(2)?;
1710        if pre_idx >= MAX_SAMPLES_PER_CB {
1711            // Ceiling exceeded — skip sampling for the remainder of
1712            // this CB.  Risk R4 (PROFILING-KIT-DESIGN §A.7): future
1713            // iter can chunk-resolve every N dispatches; for now we
1714            // accept truncation with a one-shot warning (re-uses the
1715            // R1 warn flag).
1716            return None;
1717        }
1718        encoder.sample_counters_in_buffer(sb, pre_idx, true);
1719        self.pending_dispatch_meta.push(PendingDispatchMeta {
1720            op_kind: op_kind.name(),
1721            dispatch_index: i,
1722        });
1723        Some(i)
1724    }
1725
1726    /// Insert the end-of-dispatch counter sample (sample index `2*i+1`)
1727    /// matching the most recent [`Self::sample_dispatch_pre`].
1728    ///
1729    /// No-op when sampling is inactive or when `pre_idx` is `None`.
1730    #[inline]
1731    fn sample_dispatch_post(
1732        &mut self,
1733        encoder: &ComputeCommandEncoderRef,
1734        pre_idx: Option<u32>,
1735    ) {
1736        let i = match pre_idx {
1737            Some(v) => v,
1738            None => return,
1739        };
1740        let sb = match self.sample_buffer.as_ref() {
1741            Some(b) => b,
1742            None => return,
1743        };
1744        let post_idx = match (i as u64).checked_mul(2).and_then(|v| v.checked_add(1)) {
1745            Some(v) if v < MAX_SAMPLES_PER_CB => v,
1746            _ => return,
1747        };
1748        encoder.sample_counters_in_buffer(sb, post_idx, true);
1749        // Bump the per-CB ordinal only after both samples committed
1750        // successfully so a truncation skip leaves the meta queue
1751        // length matching the buffer's resolved range.
1752        self.dispatch_in_cb = i.saturating_add(1);
1753    }
1754
1755    /// Resolve the per-CB sample buffer, push entries into
1756    /// [`crate::kernel_profile`], and reset per-CB state.
1757    ///
1758    /// Called from [`Self::commit_and_wait_labeled`] after the CB
1759    /// completes; the caller is responsible for ensuring the GPU has
1760    /// finished (otherwise `resolveCounterRange` returns garbage).
1761    ///
1762    /// On the first resolve after a [`crate::kernel_profile::reset`],
1763    /// also captures a `(cpu_ns, gpu_ticks)` pair via
1764    /// `device.sampleTimestamps` so subsequent ticks→ns conversion
1765    /// uses a fresh scale factor.
1766    fn resolve_dispatch_samples(&mut self, cb_label: &str) -> Result<()> {
1767        let sb = match self.sample_buffer.take() {
1768            Some(b) => b,
1769            None => {
1770                self.pending_dispatch_meta.clear();
1771                self.dispatch_in_cb = 0;
1772                return Ok(());
1773            }
1774        };
1775        let n = self.pending_dispatch_meta.len();
1776        if n == 0 {
1777            self.dispatch_in_cb = 0;
1778            return Ok(());
1779        }
1780        // Refresh the (cpu, gpu) scale pair on every resolve; the
1781        // device call is cheap and keeps us robust against driver-side
1782        // timebase changes between CBs.
1783        let mut cpu_t: u64 = 0;
1784        let mut gpu_t: u64 = 0;
1785        let device: &metal::DeviceRef = unsafe {
1786            let cb = &*self.cmd_buf;
1787            msg_send![cb, device]
1788        };
1789        device.sample_timestamps(&mut cpu_t, &mut gpu_t);
1790        crate::kernel_profile::record_clock_pair(cpu_t, gpu_t);
1791        let length = (n as u64).saturating_mul(2);
1792        let data = sb.resolve_counter_range(NSRange {
1793            location: 0,
1794            length,
1795        });
1796        // `resolve_counter_range` returns one NSUInteger per sample.
1797        // Pair them up: data[2i] = start, data[2i+1] = end.
1798        for (i, meta) in self.pending_dispatch_meta.drain(..).enumerate() {
1799            let start_idx = 2 * i;
1800            let end_idx = 2 * i + 1;
1801            if end_idx >= data.len() {
1802                break;
1803            }
1804            let start_raw = data[start_idx] as u64;
1805            let end_raw = data[end_idx] as u64;
1806            let start_ns = crate::kernel_profile::convert_gpu_ticks_to_ns(start_raw);
1807            let end_ns = crate::kernel_profile::convert_gpu_ticks_to_ns(end_raw);
1808            let gpu_ns = end_ns.saturating_sub(start_ns);
1809            crate::kernel_profile::record_dispatch(
1810                crate::kernel_profile::DispatchEntry {
1811                    cb_label: cb_label.to_string(),
1812                    op_kind: meta.op_kind,
1813                    dispatch_index: meta.dispatch_index,
1814                    gpu_ns,
1815                    start_gpu_ns: start_ns,
1816                    end_gpu_ns: end_ns,
1817                },
1818            );
1819        }
1820        // Buffer dropped at end of scope releases the underlying
1821        // CounterSampleBuffer; per-CB lifetime correctly bounded.
1822        drop(sb);
1823        self.dispatch_in_cb = 0;
1824        Ok(())
1825    }
1826
1827    /// Commit the command buffer and block until the GPU finishes execution.
1828    ///
1829    /// # Errors
1830    ///
1831    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
1832    pub fn commit_and_wait(&mut self) -> Result<()> {
1833        SYNC_COUNT.fetch_add(1, Ordering::Relaxed);
1834
1835        // End the persistent compute encoder before committing.
1836        self.end_active_encoder();
1837
1838        // ADR-015 iter8e (Phase 3b): flush deferred residency-set
1839        // add/remove staging so the residency hint covers any buffers
1840        // referenced by this CB. Single commit per CB boundary; no-op
1841        // when no residency set or no staged changes.
1842        self.flush_residency_pending();
1843
1844        self.cmd_buf.commit();
1845        self.cmd_buf.wait_until_completed();
1846
1847        match self.cmd_buf.status() {
1848            MTLCommandBufferStatus::Completed => Ok(()),
1849            MTLCommandBufferStatus::Error => {
1850                Err(MlxError::CommandBufferError(
1851                    "GPU command buffer completed with error status".into(),
1852                ))
1853            }
1854            status => Err(MlxError::CommandBufferError(format!(
1855                "Unexpected command buffer status after wait: {:?}",
1856                status
1857            ))),
1858        }
1859    }
1860
1861    /// Commit + wait, accumulating GPU wall-clock time under `label` into
1862    /// the [`crate::kernel_profile`] global table when `MLX_PROFILE_CB=1`
1863    /// is set.  When the env var is unset, this is identical to
1864    /// [`commit_and_wait`](Self::commit_and_wait) — zero overhead.
1865    ///
1866    /// Used by hf2q's decode hot path to attribute per-cb GPU time to
1867    /// labeled phases (per-layer attn, per-layer ffn, output_head, etc.)
1868    /// without manually wiring `commit_wait_with_gpu_time` everywhere.
1869    ///
1870    /// # Errors
1871    ///
1872    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
1873    pub fn commit_and_wait_labeled(&mut self, label: &str) -> Result<()> {
1874        // ADR-015 iter16 — propagate `label` to MTLCommandBuffer.setLabel and
1875        // (if a compute encoder is active) MTLComputeCommandEncoder.setLabel
1876        // BEFORE end_encoding/commit so xctrace's
1877        // `metal-application-encoders-list` table populates `cmdbuffer-label`
1878        // and `encoder-label` columns with the semantic phase name (e.g.
1879        // `layer.attn_moe_ffn`, `output_head.fused_norm_lm_argmax`,
1880        // `layer.delta_net.ops1-9`).  Joined to per-CB GPU duration via
1881        // `metal-gpu-submission-to-command-buffer-id` (sub_id ↔ encoder_id) →
1882        // `metal-gpu-execution-points` (per-dispatch start/end), this enables
1883        // per-phase µs/token attribution comparing hf2q vs llama side-by-side
1884        // (iter15 §E "iter16 ATTRIBUTION PATH").  Cost is a single ObjC
1885        // msg_send per CB submission — sub-µs on M5 Max — and a no-op when
1886        // xctrace isn't recording, so this is unconditionally safe to call on
1887        // the production decode hot path.
1888        self.apply_labels(label);
1889        // ADR-015 iter63: record GPU time AND resolve per-dispatch samples
1890        // when either env gate is set.  Per-dispatch sampling force-enables
1891        // the per-CB path so cross-validation per Risk R3 always has a
1892        // ground-truth comparator.
1893        let need_gpu_time =
1894            crate::kernel_profile::is_enabled() || crate::kernel_profile::is_dispatch_enabled();
1895        if need_gpu_time {
1896            let (start_s, end_s) = self.commit_wait_with_gpu_time()?;
1897            let ns = ((end_s - start_s).max(0.0) * 1_000_000_000.0) as u64;
1898            if crate::kernel_profile::is_enabled() {
1899                crate::kernel_profile::record(label, ns);
1900            }
1901            if crate::kernel_profile::is_dispatch_enabled() {
1902                self.resolve_dispatch_samples(label)?;
1903            }
1904            Ok(())
1905        } else {
1906            self.commit_and_wait()
1907        }
1908    }
1909
1910    /// Async commit, but with profiling label.  When `MLX_PROFILE_CB=1`
1911    /// is set, redirects to a synchronous [`commit_and_wait_labeled`]
1912    /// call to capture per-cb GPU time (this defeats async pipelining
1913    /// while profiling, which is the whole point — profile-mode is slow
1914    /// but informative).  When unset, identical to [`commit`](Self::commit).
1915    pub fn commit_labeled(&mut self, label: &str) {
1916        // ADR-015 iter16 — see `commit_and_wait_labeled` for rationale.
1917        if crate::kernel_profile::is_enabled() {
1918            // Profile mode: force sync to capture GPU time.  apply_labels is
1919            // called inside commit_and_wait_labeled — do NOT call it twice
1920            // here (would double the ObjC msg_send under MLX_PROFILE_CB=1).
1921            // Errors are logged via stderr because the void return matches
1922            // commit().
1923            if let Err(e) = self.commit_and_wait_labeled(label) {
1924                eprintln!("[mlx-native] commit_labeled({}) failed: {}", label, e);
1925            }
1926        } else {
1927            // Async path: apply labels here so xctrace MST traces capture
1928            // per-CB phase attribution under default decode (no
1929            // `MLX_PROFILE_CB`).
1930            self.apply_labels(label);
1931            self.commit();
1932        }
1933    }
1934
1935    /// Apply `label` to the underlying `MTLCommandBuffer` and, if a compute
1936    /// encoder is currently active, to the `MTLComputeCommandEncoder`.
1937    ///
1938    /// Called from [`commit_labeled`] and [`commit_and_wait_labeled`] BEFORE
1939    /// the encoder is ended / the CB is committed so xctrace's
1940    /// `metal-application-encoders-list` table picks up the label on the
1941    /// row emitted at the encoder's `endEncoding` / CB submission boundary.
1942    /// Single ObjC `msg_send` per call (two if an encoder is active); sub-µs
1943    /// on M5 Max; no-op when xctrace isn't recording.
1944    ///
1945    /// Skipped (debug-only assert) if `label` is empty — empty labels would
1946    /// produce an indistinguishable trace row from the metal-rs default
1947    /// `Command Buffer 0` placeholder.
1948    #[inline]
1949    fn apply_labels(&mut self, label: &str) {
1950        debug_assert!(!label.is_empty(), "commit_*_labeled called with empty label");
1951        if label.is_empty() {
1952            return;
1953        }
1954        self.cmd_buf.set_label(label);
1955        if !self.active_encoder.is_null() {
1956            // SAFETY: active_encoder is non-null and points to a live encoder
1957            // owned by cmd_buf — same invariant as get_or_create_encoder /
1958            // memory_barrier.  set_label is a single property write on the
1959            // ObjC object; safe before endEncoding.
1960            unsafe { &*self.active_encoder }.set_label(label);
1961        }
1962        // ADR-015 iter63: capture the most recent label for per-dispatch
1963        // entries.  Cheap String allocation — only happens at CB commit
1964        // boundaries, not per dispatch.
1965        self.last_label.clear();
1966        self.last_label.push_str(label);
1967    }
1968
1969    /// Commit + wait, returning `(gpu_start_s, gpu_end_s)` CFTimeInterval
1970    /// timestamps from `MTLCommandBuffer`'s `GPUStartTime`/`GPUEndTime`
1971    /// properties.  Both are mach-absolute CFTimeInterval seconds (double).
1972    ///
1973    /// Intended for `HF2Q_PROFILE_GPU_TS=1` per-bucket GPU wall-clock
1974    /// attribution.  Adds exactly two ObjC property reads per call on top
1975    /// of the regular `commit_and_wait` — measured well under 1 μs on
1976    /// M5 Max.
1977    ///
1978    /// # Errors
1979    ///
1980    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
1981    pub fn commit_wait_with_gpu_time(&mut self) -> Result<(f64, f64)> {
1982        self.commit_and_wait()?;
1983        // SAFETY: cmd_buf is a valid MTLCommandBuffer that has been
1984        // committed and awaited.  GPUStartTime / GPUEndTime return
1985        // CFTimeInterval (double precision seconds).  See
1986        // https://developer.apple.com/documentation/metal/mtlcommandbuffer/1639925-gpustarttime
1987        let (gpu_start, gpu_end): (f64, f64) = unsafe {
1988            let cb = &*self.cmd_buf;
1989            let s: f64 = msg_send![cb, GPUStartTime];
1990            let e: f64 = msg_send![cb, GPUEndTime];
1991            (s, e)
1992        };
1993        Ok((gpu_start, gpu_end))
1994    }
1995
    /// Commit the command buffer WITHOUT blocking.
    ///
    /// Ends the persistent compute encoder, flushes any staged
    /// residency-set changes, then submits the CB.  The GPU begins
    /// executing the encoded commands immediately.  Call
    /// [`wait_until_completed`](Self::wait_until_completed) later to block
    /// the CPU and check for errors.  This allows the CPU to continue doing
    /// other work (e.g. preparing the next batch) while the GPU runs.
    ///
    /// Unlike [`commit_and_wait`](Self::commit_and_wait), this does not
    /// bump `SYNC_COUNT`.
    pub fn commit(&mut self) {
        // The persistent encoder must be ended before the CB is committed;
        // keep this ordering.
        self.end_active_encoder();
        // ADR-015 iter8e (Phase 3b): same flush hook as commit_and_wait —
        // this is the async-pipeline path that production decode uses.
        self.flush_residency_pending();
        self.cmd_buf.commit();
    }
2009
2010    /// Block until a previously committed command buffer completes.
2011    ///
2012    /// Must be called after [`commit`](Self::commit).  Do not call after
2013    /// [`commit_and_wait`](Self::commit_and_wait) — that method already waits.
2014    ///
2015    /// # Errors
2016    ///
2017    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
2018    pub fn wait_until_completed(&self) -> Result<()> {
2019        self.cmd_buf.wait_until_completed();
2020        match self.cmd_buf.status() {
2021            MTLCommandBufferStatus::Completed => Ok(()),
2022            MTLCommandBufferStatus::Error => Err(MlxError::CommandBufferError(
2023                "GPU command buffer completed with error status".into(),
2024            )),
2025            status => Err(MlxError::CommandBufferError(format!(
2026                "Unexpected command buffer status after wait: {:?}",
2027                status
2028            ))),
2029        }
2030    }
2031
    /// Borrow the underlying Metal command buffer.
    ///
    /// Returns a shared reference to the wrapped `metal::CommandBuffer`;
    /// ownership stays with this encoder.  NOTE(review): encoding into it
    /// directly while this type's persistent compute encoder is open is
    /// presumably unsupported — confirm with callers.
    #[inline]
    pub fn metal_command_buffer(&self) -> &CommandBuffer {
        &self.cmd_buf
    }
2037}
2038
/// Ensures the persistent compute encoder is ended when a
/// `CommandEncoder` goes out of scope.
impl Drop for CommandEncoder {
    /// End the persistent compute encoder before the command buffer is
    /// released.  This also runs after `commit()` has already ended the
    /// encoder, so `end_active_encoder` is presumably a no-op when no
    /// encoder is open — NOTE(review): confirm.
    fn drop(&mut self) {
        // End the persistent compute encoder before the command buffer
        // is dropped, otherwise Metal will assert:
        // "Command encoder released without endEncoding"
        self.end_active_encoder();
    }
}