
mlx_native/encoder.rs

1//! [`CommandEncoder`] — batched GPU command submission.
2//!
3//! Wraps a Metal command buffer.  Encode one or more compute kernel dispatches,
4//! then call [`commit_and_wait`](CommandEncoder::commit_and_wait) to submit the
5//! entire batch and block until the GPU finishes.
6//!
7//! # Persistent compute encoder
8//!
9//! A single Metal `ComputeCommandEncoder` is kept alive across multiple
10//! dispatches within the same command buffer.  This avoids the overhead of
11//! creating and ending a new compute encoder per dispatch — the same pattern
12//! candle uses (`compute_per_buffer`).  On a forward pass with ~800 dispatches
13//! this saves ~800 encoder create/end cycles.
14//!
15//! # Capture mode (Phase 4e.1)
16//!
17//! When `start_capture()` is called, subsequent dispatches are recorded into a
18//! `Vec<CapturedNode>` instead of being encoded into Metal.  `memory_barrier()`
19//! records a barrier sentinel.  Call `take_capture()` to extract the recorded
20//! graph for later replay via `ComputeGraph::encode_sequential()`.
21
22use std::sync::atomic::{AtomicI8, AtomicU64, Ordering};
23
24use metal::{
25    CommandBuffer, CommandQueue, ComputeCommandEncoderRef, ComputePipelineState,
26    ComputePipelineStateRef, CounterSampleBuffer, CounterSampleBufferDescriptor,
27    MTLCommandBufferStatus, MTLCounterSamplingPoint, MTLDispatchType, MTLSize, MTLStorageMode,
28    NSRange,
29};
30#[allow(unused_imports)]
31use objc::{msg_send, sel, sel_impl};
32
33use crate::buffer::MlxBuffer;
34use crate::error::{MlxError, Result};
35use crate::mem_ranges::MemRanges;
36use crate::residency::ResidencySet;
37
38/// A buffer or inline-bytes binding for a compute kernel argument slot.
39pub enum KernelArg<'a> {
40    /// Bind an existing Metal buffer at the given index.
41    Buffer(&'a MlxBuffer),
42    /// Bind an existing Metal buffer at the given index with a byte offset.
43    BufferWithOffset(&'a MlxBuffer, u64),
44    /// Bind inline bytes (small constant data) at the given index.
45    /// The data must be `Pod` and is copied into the command encoder.
46    Bytes(&'a [u8]),
47}
48
49/// Pre-baked dispatch record for hot decode paths.
50///
51/// ADR-029 iter-175 Step 1d — first piece of the multi-week
52/// "Option A" refactor that the gemma4 decode gap analysis localized
53/// to per-dispatch CPU orchestration (forward_mlx::forward_decode →
54/// encode_one_layer → dispatch_qmatmul → quantized_matmul_ggml →
55/// dispatch_mv → encoder.encode_threadgroups_with_args).
56///
57/// At gemma4 decode m=1, every dispatch within the inner loop has
58/// load-time-immutable shape: the kernel pipeline, threadgroup
59/// geometry, params struct bytes, and binding-slot layout are fully
60/// determined by the weight + ggml_type and never change across the
61/// thousands of decode tokens that follow.  `DispatchRecord` captures
62/// that state once at model-load (or on first-call lazy-init) so the
63/// hot path skips:
64///   - `KernelRegistry::get_pipeline*` HashMap lookups
65///   - match expressions over `ggml_type` for kernel-name + geometry
66///   - `MTLSize::new` construction (already-known values)
67///   - param-struct field stores + bytemuck::bytes_of conversion
68///
69/// Only the runtime-varying buffers (input, output) need to be passed
70/// to [`CommandEncoder::dispatch_record`].  Weight buffers are baked
71/// inline via the `bake_buffers` slot list.
72///
73/// # Bake-time invariants
74///
75/// - `buffer_slots.len() == bake_buffers.len() + runtime_buffer_count`
76///   for the call site contract; the call site documents
77///   `runtime_buffer_count` and the order of runtime buffers.
78/// - `params_bytes.len()` is whatever the kernel's `KernelArg::Bytes`
79///   expects (typically 8-byte aligned per Metal struct layout).
80/// - `threadgroup_mem` is `(slot, byte_length)` pairs; empty when the
81///   kernel doesn't request `[[threadgroup]]` memory.
82///
83/// # Coherence
84///
85/// `dispatch_record` produces a byte-identical Metal command stream
86/// to the equivalent `encode_threadgroups_with_args*` call.  Capture
87/// mode is supported (replays into `CapturedNode::Dispatch` exactly
88/// like the unbaked path).  See `dispatch_record` for the lockstep.
89#[derive(Clone)]
90pub struct DispatchRecord {
91    /// Pipeline reference, looked up once at bake time.
92    pub pipeline: ComputePipelineState,
93    /// Threadgroup count.
94    pub threadgroups: MTLSize,
95    /// Threads per threadgroup.
96    pub threads_per_tg: MTLSize,
97    /// Threadgroup shared-memory bindings: `(slot_index, byte_length)`.
98    /// Empty when the kernel doesn't allocate `[[threadgroup]]` memory.
99    pub threadgroup_mem: Vec<(u64, u64)>,
100    /// Pre-encoded params struct bytes (bound as `KernelArg::Bytes`).
101    /// Empty when the kernel has no inline-bytes parameter.
102    pub params_bytes: Vec<u8>,
103    /// Slot index for `params_bytes`.  Ignored when `params_bytes` is empty.
104    pub params_slot: u64,
105    /// Slot indices for runtime buffer arguments, in caller order.
106    /// `dispatch_record` zips `runtime_buffers` against this list.
107    pub buffer_slots: Vec<u64>,
108    /// `CapturedOpKind` used when the encoder is in capture mode.
109    pub op_kind: CapturedOpKind,
110    /// Diagnostic label (kernel name) for debug/timing.
111    pub kernel_name: String,
112}
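// Illustrative bake-time construction of a `DispatchRecord`.  The pipeline
// lookup (`registry.get_pipeline`), the concrete sizes, and the commented-out
// `dispatch_record` call shape are assumptions for the sketch — only the
// struct fields themselves come from the definition above.
//
// ```ignore
// let record = DispatchRecord {
//     pipeline: registry.get_pipeline("kernel_mul_mv_q4_K_f32")?, // hypothetical lookup
//     threadgroups: MTLSize::new(n_rows / 4, 1, 1),
//     threads_per_tg: MTLSize::new(32, 2, 1),
//     threadgroup_mem: Vec::new(),              // kernel requests no [[threadgroup]] memory
//     params_bytes: as_bytes(&params).to_vec(),
//     params_slot: 3,
//     buffer_slots: vec![0, 1, 2],              // weight (baked) + input, output (runtime)
//     op_kind: CapturedOpKind::Other,
//     kernel_name: "kernel_mul_mv_q4_K_f32".to_string(),
// };
// // Hot path — only the runtime buffers vary per decode token:
// // enc.dispatch_record(&record, &[&input, &output]);
// ```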
113
114/// Convert a `Pod` value to a byte slice suitable for `KernelArg::Bytes`.
115///
116/// # Safety
117///
118/// The caller must ensure `T` has the same layout as the corresponding
119/// MSL struct in the shader (matching field order, sizes, and alignment).
120pub fn as_bytes<T: bytemuck::Pod>(val: &T) -> &[u8] {
121    bytemuck::bytes_of(val)
122}
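// A minimal usage sketch for `as_bytes`, assuming a hypothetical
// `RmsNormParams` struct that mirrors the kernel's MSL params layout
// (field names and slot indices are illustrative):
//
// ```ignore
// #[repr(C)]
// #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
// struct RmsNormParams {
//     n_cols: u32,
//     eps: f32,
// }
//
// let params = RmsNormParams { n_cols: 4096, eps: 1e-6 };
// let bindings = [
//     (0u64, KernelArg::Buffer(&input)),
//     (1u64, KernelArg::Buffer(&weight)),
//     (2u64, KernelArg::Buffer(&output)),
//     (3u64, KernelArg::Bytes(as_bytes(&params))),  // must match [[buffer(3)]] in the shader
// ];
// ```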
123
124// ---------------------------------------------------------------------------
125// Capture-mode types (Phase 4e.1 — Graph IR)
126// ---------------------------------------------------------------------------
127
128/// A recorded kernel argument binding.
129///
130/// When the encoder is in capture mode, each `set_buffer` / `set_bytes` call
131/// is stored as a `RecordedBinding` instead of being applied to Metal.
132#[derive(Clone)]
133pub enum RecordedBinding {
134    /// A Metal buffer at the given offset.
135    Buffer {
136        metal_buffer: metal::Buffer,
137        offset: u64,
138    },
139    /// Inline bytes (small constant data, copied).
140    Bytes(Vec<u8>),
141}
142
143/// How to dispatch the recorded kernel.
144#[derive(Clone, Copy, Debug)]
145pub enum DispatchKind {
146    /// `dispatch_threads(grid_size, threadgroup_size)` — Metal picks threadgroup count.
147    Threads,
148    /// `dispatch_thread_groups(threadgroups, threadgroup_size)` — caller specifies threadgroup count.
149    ThreadGroups,
150}
151
152/// Operation kind tag for captured nodes, used by the fusion pass (4e.2).
153///
154/// When the encoder is in capture mode, each dispatch can be tagged with an
155/// `OpKind` so the fusion pass can identify fuseable sequences without
156/// inspecting pipeline names.
157#[derive(Clone, Copy, Debug, PartialEq, Eq)]
158pub enum CapturedOpKind {
159    /// RMS normalization (with learned scale).
160    RmsNorm,
161    /// Elementwise multiply.
162    ElemMul,
163    /// Elementwise add.
164    ElemAdd,
165    /// Scaled dot-product attention (NOT reorderable — breaks lookahead).
166    Sdpa,
167    /// Softmax (NOT reorderable — breaks lookahead).
168    Softmax,
169    /// Any other operation — treated as reorderable by the graph optimizer.
170    Other,
171}
172
173impl CapturedOpKind {
174    /// Whether this captured op kind is safe to reorder past in the graph
175    /// optimizer (Phase 4e.3).
176    ///
177    /// Mirrors the `h_safe` whitelist from llama.cpp's
178    /// `ggml_metal_graph_optimize_reorder`.  Non-safe ops break the 64-node
179    /// lookahead — the reorder pass cannot look past them.
180    pub fn is_reorderable(&self) -> bool {
181        match self {
182            Self::Sdpa | Self::Softmax => false,
183            Self::RmsNorm | Self::ElemMul | Self::ElemAdd | Self::Other => true,
184        }
185    }
186
187    /// Stable string label suitable for embedding in the per-dispatch
188    /// profile dump (ADR-015 iter63 §A.5).  Matches the variant name —
189    /// `Other` is preserved verbatim so an aggregate-by-op_kind sort
190    /// produces a clean "what isn't yet labeled" bucket.
191    pub fn name(&self) -> &'static str {
192        match self {
193            Self::RmsNorm => "RmsNorm",
194            Self::ElemMul => "ElemMul",
195            Self::ElemAdd => "ElemAdd",
196            Self::Sdpa => "Sdpa",
197            Self::Softmax => "Softmax",
198            Self::Other => "Other",
199        }
200    }
201}
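// Sketch of how a reorder pass can use `is_reorderable` as a lookahead fence.
// The real optimizer lives in the graph module (Phase 4e.3); this only shows
// the gating described in the docs above.
//
// ```ignore
// for node in captured.iter() {
//     if let CapturedNode::Dispatch { op_kind, .. } = node {
//         if !op_kind.is_reorderable() {
//             break; // Sdpa / Softmax end the lookahead window
//         }
//         // ...consider this dispatch for reordering...
//     }
// }
// ```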
202
203/// A memory range annotation: (start_address, end_address).
204///
205/// Represents a contiguous GPU buffer region for conflict detection in the
206/// reorder pass (Phase 4e.3).  Addresses are CPU-visible `contents_ptr()`
207/// values, which on Apple Silicon unified memory equal the GPU addresses.
208pub type MemRange = (usize, usize);
209
210/// A single captured compute dispatch or barrier sentinel.
211///
212/// Created when the encoder is in capture mode.  Replayed later by
213/// `ComputeGraph::encode_sequential()`.
214#[derive(Clone)]
215pub enum CapturedNode {
216    /// A compute dispatch to replay.
217    Dispatch {
218        /// Pipeline state object to bind.
219        pipeline: ComputePipelineState,
220        /// Kernel argument bindings: (slot_index, binding).
221        bindings: Vec<(u64, RecordedBinding)>,
222        /// Grid or threadgroup count (interpretation depends on `dispatch_kind`).
223        threads_per_grid: MTLSize,
224        /// Threads per threadgroup.
225        threads_per_threadgroup: MTLSize,
226        /// Optional threadgroup memory allocations: (index, byte_length).
227        threadgroup_memory: Vec<(u64, u64)>,
228        /// Whether this is a dispatch_threads or dispatch_thread_groups call.
229        dispatch_kind: DispatchKind,
230        /// Operation kind tag for the fusion pass (4e.2).
231        /// Defaults to `Other` if not explicitly set via `set_op_kind()`.
232        op_kind: CapturedOpKind,
233        /// Read buffer ranges for reorder conflict detection (4e.3).
234        /// Populated from `barrier_between` calls in capture mode.
235        reads: Vec<MemRange>,
236        /// Write buffer ranges for reorder conflict detection (4e.3).
237        /// Populated from `barrier_between` calls in capture mode.
238        writes: Vec<MemRange>,
239    },
240    /// A memory barrier sentinel — forces a barrier at replay time.
241    Barrier,
242}
243
244/// Convert a slice of buffer references into capture-mode
245/// [`MemRange`] tuples.  Used by the [`CommandEncoder::dispatch_tracked*`]
246/// family in capture mode — equivalent to the conversion
247/// `GraphSession::barrier_between` does at `graph.rs:1452-1465`.
248///
249/// `(start, end)` uses `contents_ptr() + byte_offset` as the start
250/// and `contents_ptr() + byte_offset + slice_extent` as the end.
251fn ranges_from_buffers(bufs: &[&MlxBuffer]) -> Vec<MemRange> {
252    bufs.iter()
253        .map(|b| {
254            let base = b.contents_ptr() as usize + b.byte_offset() as usize;
255            let extent = (b.byte_len()).saturating_sub(b.byte_offset() as usize);
256            (base, base + extent)
257        })
258        .collect()
259}
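// Two `MemRange`s conflict when their half-open `[start, end)` intervals
// overlap.  A minimal helper sketch (the actual conflict logic lives in
// `MemRanges` / the reorder pass; the helper name is illustrative):
//
// ```ignore
// fn ranges_overlap(a: MemRange, b: MemRange) -> bool {
//     a.0 < b.1 && b.0 < a.1
// }
// ```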
260
261/// Apply a slice of `KernelArg` bindings to a compute encoder.
262///
263/// `KernelArg::Buffer(buf)` propagates the `MlxBuffer::byte_offset()` so
264/// `slice_view`-derived sub-buffers are honored automatically — the
265/// kernel sees memory starting at the slice's offset. This matches the
266/// documented contract of `slice_view` and the offset-handling in the
267/// other binding paths in this file (`encode`, `encode_threadgroups`,
268/// `encode_threadgroups_with_shared`, replay). Without it, every
269/// `slice_view`-derived buffer bound via `KernelArg::Buffer` silently
270/// exposes the entire underlying allocation — surfaced by hf2q's
271/// nomic-bert iter-79 cosine parity bisection (cosine 0.098 → 0.999962
272/// after fix).
273///
274/// `KernelArg::BufferWithOffset(buf, offset)` continues to use the
275/// explicit `offset` argument verbatim (callers asking for an explicit
276/// offset get exactly that, even on sliced buffers). The two API
277/// surfaces are intentional: implicit (sliced views auto-propagate) vs.
278/// explicit (caller-controlled).
279#[inline]
280fn apply_bindings(encoder: &ComputeCommandEncoderRef, bindings: &[(u64, KernelArg<'_>)]) {
281    for &(index, ref arg) in bindings {
282        match arg {
283            KernelArg::Buffer(buf) => {
284                encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
285            }
286            KernelArg::BufferWithOffset(buf, offset) => {
287                encoder.set_buffer(index, Some(buf.metal_buffer()), *offset);
288            }
289            KernelArg::Bytes(bytes) => {
290                encoder.set_bytes(index, bytes.len() as u64, bytes.as_ptr() as *const _);
291            }
292        }
293    }
294}
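// Illustrative contrast between the two buffer-binding variants described
// above.  The `slice_view` call shape is an assumption about the buffer API;
// only the offset-propagation behavior is taken from the docs.
//
// ```ignore
// let head = big.slice_view(4096, 1024)?;                  // hypothetical: offset 4096, len 1024
// let implicit = KernelArg::Buffer(&head);                 // kernel sees memory from byte 4096
// let explicit = KernelArg::BufferWithOffset(&big, 4096);  // same effect, caller-controlled
// ```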
295
296/// Number of times `commit_and_wait()` has been called (CPU sync points).
297static SYNC_COUNT: AtomicU64 = AtomicU64::new(0);
298
299/// Number of times an encode method has been called (GPU dispatches).
300static DISPATCH_COUNT: AtomicU64 = AtomicU64::new(0);
301
302/// Number of `MTLCommandBuffer` instances created via `CommandEncoder::new`.
303/// Increments once per `device.command_encoder()` call.  Used by hf2q's
304/// `HF2Q_DECODE_PROFILE` instrumentation to measure command-buffer
305/// overhead per decode token (ADR-012 §Optimize / Task #15 follow-up).
306static CMD_BUF_COUNT: AtomicU64 = AtomicU64::new(0);
307
308/// Number of `memory_barrier()` calls that reached the
309/// `objc::msg_send![encoder, memoryBarrierWithScope:]` site.  Capture-mode
310/// no-ops and pre-encoder no-ops are excluded so the count reflects
311/// actual MTL barriers issued.
312///
313/// Always tracked — the increment is one atomic op, ~5 ns.  ADR-015 H4
314/// (Wave 2b hard gate #2) requires per-barrier counter resolution to
315/// confirm-or-falsify the barrier-coalescing lever; xctrace TimeProfiler
316/// at 1 ms sampling cannot resolve `memory_barrier` even though it fires
317/// ~440×/token (`docs/ADR-015-mlx-native-single-cb-decode.md` §"P3a' live
318/// profile pass" hypothesis register row H4).
319static BARRIER_COUNT: AtomicU64 = AtomicU64::new(0);
320
321/// Total nanoseconds spent inside the `objc::msg_send!` barrier site,
322/// summed across all calls.  ONLY updated when the env var
323/// `MLX_PROFILE_BARRIERS=1` is set on the process (cached on first
324/// `memory_barrier` call).  When disabled the timing path is a single
325/// branch + the unconditional barrier dispatch — same hot-path cost as
326/// before this counter was added.
327///
328/// Why env-gated: timing adds 2 × `Instant::now()` (~50–100 ns each via
329/// `mach_absolute_time`) per barrier.  At ~440 barriers/token that is
330/// ~22–44 µs/token of measurement overhead — comparable to what we are
331/// trying to measure.  Production must keep this off; profiling runs
332/// opt-in.
333static BARRIER_NS: AtomicU64 = AtomicU64::new(0);
334
335/// Reset all counters to zero.
336pub fn reset_counters() {
337    SYNC_COUNT.store(0, Ordering::Relaxed);
338    DISPATCH_COUNT.store(0, Ordering::Relaxed);
339    CMD_BUF_COUNT.store(0, Ordering::Relaxed);
340    BARRIER_COUNT.store(0, Ordering::Relaxed);
341    BARRIER_NS.store(0, Ordering::Relaxed);
342    AUTO_BARRIER_COUNT.store(0, Ordering::Relaxed);
343    AUTO_BARRIER_CONCURRENT.store(0, Ordering::Relaxed);
344}
345
346/// Read the current value of `SYNC_COUNT`.
347///
348/// Each call to `commit_and_wait()` increments this counter.
349pub fn sync_count() -> u64 {
350    SYNC_COUNT.load(Ordering::Relaxed)
351}
352
353/// Read the current value of `DISPATCH_COUNT`.
354///
355/// Each call to `encode()`, `encode_threadgroups()`, or
356/// `encode_threadgroups_with_shared()` increments this counter.
357pub fn dispatch_count() -> u64 {
358    DISPATCH_COUNT.load(Ordering::Relaxed)
359}
360
361/// Per-pipeline dispatch bucket support (ADR-028 iter-284).
362///
363/// Env-gated via `MLX_DISP_BUCKET=1`.  When enabled, every
364/// `encode*` call records its pipeline's label in a global hash map.
365/// This gives a per-kernel breakdown comparable to llama.cpp's
366/// instrumented dispatch site for finding *which* kernels make up
367/// the per-token dispatch budget.
368fn pipeline_buckets()
369    -> &'static std::sync::Mutex<std::collections::HashMap<String, u64>> {
370    static BUCKETS: std::sync::OnceLock<
371        std::sync::Mutex<std::collections::HashMap<String, u64>>,
372    > = std::sync::OnceLock::new();
373    BUCKETS.get_or_init(|| std::sync::Mutex::new(std::collections::HashMap::new()))
374}
375
376/// Cached env-flag check — single load on the hot path.
377fn pipeline_bucket_enabled() -> bool {
378    static CACHED: AtomicI8 = AtomicI8::new(-1);
379    let v = CACHED.load(Ordering::Relaxed);
380    if v >= 0 {
381        return v == 1;
382    }
383    let on = std::env::var("MLX_DISP_BUCKET").as_deref() == Ok("1");
384    CACHED.store(if on { 1 } else { 0 }, Ordering::Relaxed);
385    on
386}
387
388/// Record a dispatch into the per-pipeline bucket if the env-flag is on.
389/// Called from every `encode*` site alongside the `DISPATCH_COUNT` bump.
390#[inline]
391pub(crate) fn bucket_dispatch(pipeline: &ComputePipelineStateRef) {
392    if !pipeline_bucket_enabled() {
393        return;
394    }
395    let label = pipeline.label();
396    if label.is_empty() {
397        return;
398    }
399    if let Ok(mut t) = pipeline_buckets().lock() {
400        *t.entry(label.to_string()).or_insert(0) += 1;
401    }
402}
403
404/// Public dump of `MLX_DISP_BUCKET` data: `Vec<(label, count)>` sorted
405/// descending by count.  Returns empty when env-flag is off / never
406/// recorded.
407pub fn pipeline_dispatch_buckets() -> Vec<(String, u64)> {
408    let mut v: Vec<(String, u64)> = if let Ok(t) = pipeline_buckets().lock() {
409        t.iter().map(|(k, v)| (k.clone(), *v)).collect()
410    } else {
411        Vec::new()
412    };
413    v.sort_by(|a, b| b.1.cmp(&a.1));
414    v
415}
416
417/// Reset the per-pipeline dispatch buckets (typically called at decode
418/// start to ignore prefill / warmup contributions).
419pub fn reset_pipeline_dispatch_buckets() {
420    if let Ok(mut t) = pipeline_buckets().lock() {
421        t.clear();
422    }
423}
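// Typical MLX_DISP_BUCKET workflow.  The env var must be set before the
// first `encode*` call because `pipeline_bucket_enabled` caches its check.
//
// ```ignore
// // Run with MLX_DISP_BUCKET=1 in the environment.
// reset_pipeline_dispatch_buckets();   // drop prefill / warmup contributions
// // ... decode ...
// for (label, count) in pipeline_dispatch_buckets() {
//     eprintln!("{count:>8}  {label}");
// }
// ```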
424
425/// Read the current value of `CMD_BUF_COUNT`.
426///
427/// Each `CommandEncoder::new` (i.e. each `MlxDevice::command_encoder()`)
428/// increments this counter.  Useful for diagnosing per-dispatch Metal
429/// command-buffer overhead in inner loops.
430pub fn cmd_buf_count() -> u64 {
431    CMD_BUF_COUNT.load(Ordering::Relaxed)
432}
433
434/// Read the current value of `BARRIER_COUNT`.
435///
436/// Each `memory_barrier()` call that reaches the underlying
437/// `objc::msg_send![encoder, memoryBarrierWithScope:]` site increments this
438/// counter.  Capture-mode no-ops and pre-encoder no-ops are excluded.
439/// ADR-015 H4 hypothesis: ~440 barriers/token on the qwen35 decode hot
440/// path (verify against this counter).
441pub fn barrier_count() -> u64 {
442    BARRIER_COUNT.load(Ordering::Relaxed)
443}
444
445/// Read the total nanoseconds spent in the `memoryBarrierWithScope:`
446/// `objc::msg_send!` site.  Only non-zero when `MLX_PROFILE_BARRIERS=1`
447/// was in the environment at the time of the first `memory_barrier()`
448/// call (the env check is cached on first use).
449///
450/// Combined with [`barrier_count`] this gives µs/barrier =
451/// `barrier_total_ns() / 1000 / barrier_count()`.
452pub fn barrier_total_ns() -> u64 {
453    BARRIER_NS.load(Ordering::Relaxed)
454}
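// Worked example of the derived µs/barrier metric (values are illustrative):
//
// ```ignore
// // Run with MLX_PROFILE_BARRIERS=1 in the environment, then after a token:
// let n = barrier_count();                               // e.g. ~440 per the H4 hypothesis
// let us_per_barrier = barrier_total_ns() as f64 / 1_000.0 / n as f64;
// ```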
455
456/// Whether barrier timing is enabled (env-gated, cached on first check).
457///
458/// Reading the env var via `std::env::var` is itself non-trivial; using
459/// `OnceLock` caches the decision so the per-barrier branch is a single
460/// atomic-load + compare.
461fn barrier_profile_enabled() -> bool {
462    use std::sync::OnceLock;
463    static FLAG: OnceLock<bool> = OnceLock::new();
464    *FLAG.get_or_init(|| {
465        std::env::var("MLX_PROFILE_BARRIERS")
466            .map(|v| v == "1")
467            .unwrap_or(false)
468    })
469}
470
471/// Whether `MLX_UNRETAINED_REFS=1` is set in the process environment.
472///
473/// ADR-015 iter13 — when true, `CommandEncoder::new_with_residency` opens
474/// each `MTLCommandBuffer` via
475/// [`CommandQueueRef::new_command_buffer_with_unretained_references`]
476/// instead of the default `commandBuffer`.  llama.cpp's per-token decode
477/// CBs use this same call (`/opt/llama.cpp/ggml/src/ggml-metal/`
478/// `ggml-metal-context.m:512` `[queue commandBufferWithUnretainedReferences]`)
479/// and gain ~3-5% wall on M-series GPUs by skipping per-buffer-binding ARC
480/// retains on submit.
481///
482/// **Caller-side prerequisite.**  Every Metal buffer bound to a dispatch
483/// must outlive the CB — see the docstring on
484/// [`CommandEncoder::new_with_residency`] for the full caller contract.
485/// In hf2q, the per-decode-token `MlxBufferPool` (`buffer_pool.rs`)
486/// already keeps ARC clones alive in its `in_use` list across the entire
487/// decode token; routing transient scratches through that pool is the
488/// canonical way to satisfy the contract.
489///
490/// Cached on first read via `OnceLock` to keep the per-CB-construction
491/// branch single-atomic-load fast.  Default OFF so any production decode
492/// run that does NOT explicitly set the var preserves retained-refs
493/// behavior verbatim.
494fn unretained_refs_enabled() -> bool {
495    use std::sync::OnceLock;
496    static FLAG: OnceLock<bool> = OnceLock::new();
497    *FLAG.get_or_init(|| {
498        std::env::var("MLX_UNRETAINED_REFS")
499            .map(|v| v == "1")
500            .unwrap_or(false)
501    })
502}
503
504/// Whether `HF2Q_PIPELINE_TG_MULT_HINT=1` is set.  Cached on first read.
505///
506/// ADR-029 iter-175 Step 1q safety gate: when this flag is ON, every
507/// pipeline created via `KernelRegistry` has
508/// `threadGroupSizeIsMultipleOfThreadExecutionWidth(true)`.  Apple's
509/// Metal spec says this is UB unless every dispatched threadgroup is
510/// a multiple of 32.  `assert_tg_size_multiple_of_32_if_hinted()`
511/// asserts the constraint before each dispatch, converting UB to a
512/// safe panic with diagnostic info.
513fn pipeline_tg_mult_hint_enabled() -> bool {
514    use std::sync::OnceLock;
515    static FLAG: OnceLock<bool> = OnceLock::new();
516    *FLAG.get_or_init(|| {
517        std::env::var("HF2Q_PIPELINE_TG_MULT_HINT")
518            .map(|v| v == "1")
519            .unwrap_or(false)
520    })
521}
522
523/// ADR-029 iter-175 Step 1q — runtime safety check for
524/// `HF2Q_PIPELINE_TG_MULT_HINT=1`.
525///
526/// When the env flag is ON, the Metal pipeline descriptor sets
527/// `threadGroupSizeIsMultipleOfThreadExecutionWidth(true)`, which
528/// requires every dispatched threadgroup to have
529/// `tg.x * tg.y * tg.z % 32 == 0` on Apple silicon (where
530/// `threadExecutionWidth == 32`).  Violating this is **undefined behavior**
531/// per the Metal spec — the GPU may produce garbage, hang, or panic.
532///
533/// This function panics with a clear message before dispatching so an
534/// offending site is caught immediately instead of silently corrupting
535/// output.  Returns immediately (one atomic-load cost) when the env
536/// flag is OFF.  Includes the pipeline's label in the panic message
537/// (Step 1r) so the offending kernel can be identified without a
538/// debugger.
539#[inline]
540fn assert_tg_size_multiple_of_32_if_hinted(
541    tg: MTLSize,
542    pipeline: &ComputePipelineStateRef,
543) {
544    if !pipeline_tg_mult_hint_enabled() {
545        return;
546    }
547    let total = tg.width.saturating_mul(tg.height).saturating_mul(tg.depth);
548    if total % 32 != 0 {
549        let label = pipeline.label();
550        panic!(
551            "ADR-029 Step 1q safety: HF2Q_PIPELINE_TG_MULT_HINT=1 requires \
552             threadgroup_size.x * y * z to be a multiple of 32 (Apple's \
553             threadExecutionWidth).  Got tg=({}, {}, {}) → total={} → \
554             {} mod 32 = {}.  Pipeline label: \"{}\".  Either fix the \
555             dispatch site to use a multiple-of-32 threadgroup, or unset \
556             HF2Q_PIPELINE_TG_MULT_HINT.",
557            tg.width, tg.height, tg.depth, total, total, total % 32,
558            label
559        );
560    }
561}
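// What the gate accepts and rejects when HF2Q_PIPELINE_TG_MULT_HINT=1
// (sketch; `pipeline` is any bound pipeline):
//
// ```ignore
// assert_tg_size_multiple_of_32_if_hinted(MTLSize::new(32, 2, 1), &pipeline); // 64 threads → ok
// assert_tg_size_multiple_of_32_if_hinted(MTLSize::new(24, 1, 1), &pipeline); // 24 threads → panics
// ```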
562
563/// Whether `HF2Q_AUTO_BARRIER=1` is set in the process environment.
564///
565/// ADR-015 iter37 — when true, every [`CommandEncoder::dispatch_tracked`]
566/// call consults a [`MemRanges`](crate::mem_ranges::MemRanges) tracker
567/// and auto-emits a `memoryBarrierWithScope:` exactly when the new
568/// dispatch's read/write ranges conflict with previously-recorded
569/// ranges (mirrors llama.cpp's `ggml_metal_op_concurrency_check` at
570/// `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp:147-225`).
571/// When false, `dispatch_tracked` collapses to the same code path as
572/// `encode*` — no tracking, no auto-barriers — preserving sourdough
573/// behavior for any caller that opts into the tracked API but runs
574/// without the env gate.
575///
576/// Cached on first read via `OnceLock`.  Default OFF — production
577/// decode/prefill keeps its hand-placed `enc.memory_barrier()` calls
578/// until the migration in iter38+.
579fn auto_barrier_enabled() -> bool {
580    use std::sync::OnceLock;
581    static FLAG: OnceLock<bool> = OnceLock::new();
582    *FLAG.get_or_init(|| {
583        std::env::var("HF2Q_AUTO_BARRIER")
584            .map(|v| v == "1")
585            .unwrap_or(false)
586    })
587}
588
589/// Number of `memory_barrier()` calls auto-emitted by
590/// [`CommandEncoder::dispatch_tracked`] under
591/// `HF2Q_AUTO_BARRIER=1`.  A subset of [`BARRIER_COUNT`] rather than a
592/// disjoint count — auto-barriers go through `memory_barrier()` and so
593/// also bump `BARRIER_COUNT`; this counter measures only the
594/// auto-emitted subset.
595static AUTO_BARRIER_COUNT: AtomicU64 = AtomicU64::new(0);
596
597/// Number of `dispatch_tracked` calls whose mem-ranges check returned
598/// "concurrent" (no barrier needed).  Together with
599/// [`AUTO_BARRIER_COUNT`] this measures the elision rate of the
600/// dataflow barrier: `concurrent / (concurrent + barriers)` is the
601/// fraction of dispatches that ran inside the previous concurrent
602/// group rather than starting a new one.
603static AUTO_BARRIER_CONCURRENT: AtomicU64 = AtomicU64::new(0);
604
605// ---------------------------------------------------------------------------
606// ADR-015 iter63 — per-dispatch GPU sampling support
607// ---------------------------------------------------------------------------
608
609/// Hard cap on per-CB sample-buffer sample count (Risk R4 in
610/// PROFILING-KIT-DESIGN §A.7).
611///
612/// Empirically verified on Apple Silicon (M-series, macOS 26): the
613/// underlying `MTLCounterSampleBufferDescriptor.sampleCount` is bounded
614/// by a per-buffer **byte-size** limit of 32768 B.  At 8 bytes per
615/// `MTLCounterResultTimestamp` sample that maps to a sample-count
616/// ceiling of `32_768 / 8 = 4096`.  We allocate two samples per
617/// dispatch (start + end), so this ceiling = 2048 dispatches per CB.
618/// Decode CBs (~120 dispatches) fit comfortably; long prefill CBs
619/// (~6K dispatches per design §A.7) will truncate after 2048 — see
620/// [`CommandEncoder::sample_dispatch_pre`] for the truncation path.  Future
621/// iter can chunk-resolve every 2K dispatches.
622///
623/// The original design constant of 32_768 (PROFILING-KIT-DESIGN §A.7)
624/// was based on Apple's documented ~64K-per-buffer "practical" limit,
625/// but the measured constraint on this hardware is the 32 KB byte
626/// budget.  Setting the budget below that would underutilize the
627/// buffer; setting it above causes
628/// `newCounterSampleBufferWithDescriptor` to fail with `Invalid sample
629/// buffer length: <bytes> B. Expected range: 8 -> 32768`.
630const MAX_SAMPLES_PER_CB: u64 = 4096;
631
632/// Whether the per-CB warning about a missing `MTLCommonCounterSetTimestamp`
633/// has been emitted yet.  Risk R1: if `device.counter_sets()` does not
634/// return a set named `"timestamp"` (case-insensitive), we degrade the
635/// per-dispatch path to a no-op and log once via stderr.
636static TIMESTAMP_SET_WARN_LOGGED: AtomicU64 = AtomicU64::new(0);
637
638/// Pending per-dispatch metadata that pairs with sample indices `2i`
639/// (start) and `2i+1` (end) inside the CB's `MTLCounterSampleBuffer`.
640/// Resolved by `CommandEncoder::resolve_dispatch_samples` at CB
641/// commit-time and converted to [`crate::kernel_profile::DispatchEntry`]
642/// before being pushed to the global table.
643#[derive(Clone, Debug)]
644struct PendingDispatchMeta {
645    op_kind: &'static str,
646    dispatch_index: u32,
647}
648
649/// Read the cumulative number of auto-emitted barriers across all
650/// encoders since process start (or last [`reset_counters`]).
651pub fn auto_barrier_count() -> u64 {
652    AUTO_BARRIER_COUNT.load(Ordering::Relaxed)
653}
654
655/// Read the cumulative number of `dispatch_tracked` calls that did NOT
656/// emit a barrier (ran concurrent with the previous group).
657pub fn auto_barrier_concurrent_count() -> u64 {
658    AUTO_BARRIER_CONCURRENT.load(Ordering::Relaxed)
659}
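// Elision-rate computation described above (sketch):
//
// ```ignore
// let barriers = auto_barrier_count() as f64;
// let concurrent = auto_barrier_concurrent_count() as f64;
// let elision_rate = concurrent / (concurrent + barriers);
// ```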
660
661/// Issue the underlying Metal `memoryBarrierWithScope:` ObjC msg_send.
662///
663/// Held in its own `#[inline(never)]` function so xctrace / Instruments
664/// has a stable Rust frame to attribute barrier time against, separate
665/// from the surrounding encoder accounting.  Per ADR-015 §P3a' Codex
666/// review Q2: TimeProfiler at 1 ms sampling cannot see this site when
667/// inlined; an explicit non-inline frame plus the [`BARRIER_NS`] counter
668/// closes the H4 hard gate.
669#[inline(never)]
670fn issue_metal_buffer_barrier(encoder: &ComputeCommandEncoderRef) {
671    // MTLBarrierScopeBuffers = 1 << 0 = 1.
672    const MTL_BARRIER_SCOPE_BUFFERS: u64 = 1;
673    unsafe {
674        let _: () =
675            objc::msg_send![encoder, memoryBarrierWithScope: MTL_BARRIER_SCOPE_BUFFERS];
676    }
677}
678
679/// A batched compute command encoder.
680///
681/// Keeps a single Metal `ComputeCommandEncoder` alive across multiple
682/// dispatches.  The encoder is created on the first dispatch and ended
683/// only when the command buffer is committed.  This mirrors candle's
684/// `compute_per_buffer` pattern and avoids per-dispatch encoder overhead.
685///
686/// # Typical usage
687///
688/// ```ignore
689/// let mut enc = device.command_encoder()?;
690/// // Multiple dispatches share the same compute encoder:
691/// enc.encode_threadgroups(pipeline1, &buffers1, tg1, tg_size1);
692/// enc.encode_threadgroups(pipeline2, &buffers2, tg2, tg_size2);
693/// enc.commit_and_wait()?;
694/// ```
695pub struct CommandEncoder {
696    cmd_buf: CommandBuffer,
697    /// Owned clone of the originating command queue.
698    ///
699    /// ADR-019 Phase 0b iter89e2-A: stored at `new_with_residency` time so
700    /// downstream lifecycle code (e.g. `EncoderSession::reset_for_next_stage`
701    /// in Phase 0b-B) can open a fresh `CommandBuffer` from the same queue
702    /// after a non-blocking `commit_stage()`. metal-rs 0.33's
703    /// `CommandQueue` type is `Send + Sync` via `foreign_obj_type!`
704    /// (`/Users/robert/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/metal-0.33.0/src/lib.rs:179`),
705    /// so adding this field preserves the existing unsafe `Send` impl
706    /// on `CommandEncoder` (declared below).
707    ///
708    /// ADR-019 Phase 0b iter89e2-B (CONSUMED): read by
709    /// [`Self::reset_command_buffer`] to spawn a fresh `CommandBuffer`
710    /// after a non-blocking `commit*` so `EncoderSession::reset_for_next_stage`
711    /// can chain stage CBs without re-constructing the encoder. Holding a
712    /// clone here (rather than a `&CommandQueue` borrow) avoids a lifetime
713    /// parameter on `CommandEncoder` that would propagate through every
714    /// consumer in mlx-native and hf2q.
715    queue: CommandQueue,
716    // SAFETY marker: see unsafe Send impl below.
717    /// Raw pointer to the persistent compute encoder.
718    /// Non-null when a compute pass is active.
719    /// The encoder borrows from `cmd_buf` but we cannot express this
720    /// lifetime in safe Rust, so we use a raw pointer.
721    /// SAFETY: the pointer is valid as long as `cmd_buf` is alive and
722    /// `end_encoding()` has not been called on it.
723    active_encoder: *const ComputeCommandEncoderRef,
724    /// When `Some`, dispatches are recorded here instead of being encoded
725    /// into Metal.  Set via `start_capture()`, extracted via `take_capture()`.
726    capture: Option<Vec<CapturedNode>>,
727    /// Op kind tag for the NEXT captured dispatch.  Set via `set_op_kind()`,
728    /// consumed (reset to `Other`) when a dispatch is captured.
729    pending_op_kind: CapturedOpKind,
730    /// Pending read buffer ranges for the NEXT captured dispatch.
731    /// Set via `set_pending_buffer_ranges()`, consumed when the next dispatch
732    /// is captured.  Used by the reorder pass (Phase 4e.3).
733    pending_reads: Vec<MemRange>,
734    /// Pending write buffer ranges for the NEXT captured dispatch.
735    pending_writes: Vec<MemRange>,
736    /// ADR-015 iter8e (Phase 3b): residency set whose pending add/remove
737    /// staging is flushed at every `commit*` boundary.
738    ///
739    /// Cloned from the device at `device.command_encoder()` time. `None`
740    /// when residency sets are disabled (HF2Q_NO_RESIDENCY=1, macOS<15,
741    /// or test-only `CommandEncoder::new` from a residency-less queue).
742    residency_set: Option<ResidencySet>,
743    /// ADR-015 iter37: dataflow barrier inference state.
744    ///
745    /// Populated only when `HF2Q_AUTO_BARRIER=1` is set at process
746    /// start (cached via [`auto_barrier_enabled`]).  Each
747    /// [`Self::dispatch_tracked`] call consults this state to decide
748    /// whether a Metal memory barrier is required; on conflict the
749    /// barrier is emitted, the state is reset, and the new dispatch's
750    /// ranges seed the next concurrent group.  When the env gate is
751    /// off, `dispatch_tracked` collapses to its untracked equivalent
752    /// and this field is left empty for the encoder's lifetime.
753    ///
754    /// The field is always present (zero-sized when empty) so the
755    /// gate-off branch is a single bool-load + early return rather
756    /// than an allocation/Option indirection.
757    mem_ranges: MemRanges,
758    /// ADR-015 iter63 (per-dispatch profiling): the sample buffer for
759    /// `MTLCounterSampleBuffer.sampleCounters` calls that bracket every
760    /// `encode*` dispatch in this CB.  Lazily allocated on first
761    /// dispatch when `MLX_PROFILE_DISPATCH=1`; `None` otherwise.
762    /// Released (set to `None`) inside `resolve_dispatch_samples` after
763    /// the CB completes — re-allocated on the next `encode*` if the env
764    /// gate stays set.
765    sample_buffer: Option<CounterSampleBuffer>,
766    /// ADR-015 iter63: pending per-dispatch metadata that pairs with
767    /// sample indices `2*i` and `2*i+1` inside `sample_buffer`.  Each
768    /// `encode*` call appends one entry (when sampling is active);
769    /// `resolve_dispatch_samples` drains the vec at commit time.
770    pending_dispatch_meta: Vec<PendingDispatchMeta>,
771    /// ADR-015 iter63: 0-based dispatch ordinal within the current CB.
772    /// Incremented in every `encode*` site after taking the pending
773    /// op_kind; reset to 0 inside `resolve_dispatch_samples`.
774    dispatch_in_cb: u32,
775    /// ADR-015 iter63: most recent label set via `apply_labels`, used
776    /// as the per-dispatch `cb_label` field.  `String::new()` until
777    /// `commit_and_wait_labeled` / `commit_labeled` is called.
778    last_label: String,
779}
780
781/// SAFETY: CommandEncoder is safe to Send across threads provided that
782/// only one thread accesses the encoder at a time (exclusive ownership —
783/// it is never used concurrently from multiple threads).
784///
785/// Metal command buffers and compute encoders are thread-safe for exclusive
786/// access (Apple documentation: "You can create command buffers, encode
787/// commands, and submit them from any thread"). The raw pointer
788/// `active_encoder` borrows from `cmd_buf` and is valid as long as
789/// `cmd_buf` is alive — this invariant holds across thread boundaries
790/// because both fields move together.
791///
792/// This matches llama.cpp's pattern of encoding command buffers on GCD
793/// worker threads via `dispatch_apply`, and is used for the dual-buffer
794/// pipeline where buf1 is encoded on a worker thread while buf0 executes.
795unsafe impl Send for CommandEncoder {}
796
797impl CommandEncoder {
798    /// Create a new command encoder from the given command queue.
799    ///
800    /// This immediately creates a Metal command buffer.
801    ///
802    /// # Why retained references
803    ///
804    /// We use the regular `commandBuffer` (Metal retains every bound
805    /// resource for the lifetime of the buffer) rather than
806    /// `commandBufferWithUnretainedReferences`.  llama.cpp uses unretained
807    /// refs for an additional perf bump (~3-5% on M-series GPUs), but the
808    /// hf2q dispatch pattern allocates many transient scratch buffers
809    /// inside helper functions (`apply_proj` → `weight_bf16_owned`,
810    /// `apply_pre_norm` → `params`, etc.) that go out of scope at the
811    /// helper's return.  With unretained refs the metal::Buffer's ARC
812    /// drops to zero, freeing the underlying GPU memory before the
813    /// dispatch executes.  Verified 2026-04-26: switching to unretained
814    /// hits "Command buffer error: GPU command buffer completed with
815    /// error status" on the first MoE FFN dispatch.
816    ///
817    /// To enable unretained refs in the future, every helper that
818    /// allocates and dispatches must thread its scratch buffers up to a
819    /// caller scope that outlives the eventual commit, OR all such
820    /// scratch must come from the per-decode-token pool (which already
821    /// ARC-retains in its in_use list).  Today the lm_head + router-
822    /// download paths are still unpooled.
823    #[allow(dead_code)]
824    pub(crate) fn new(queue: &CommandQueue) -> Result<Self> {
825        Self::new_with_residency(queue, None)
826    }
827
828    /// Create a new command encoder, optionally bound to a residency set so
829    /// `commit*` boundaries can flush deferred add/remove staging.
830    ///
831    /// ADR-015 iter8e (Phase 3b): the encoder's `commit_and_wait`,
832    /// `commit_and_wait_labeled`, `commit`, `commit_labeled`,
833    /// `commit_wait_with_gpu_time` all call
834    /// [`ResidencySet::flush_pending`](ResidencySet::flush_pending) before
835    /// submitting the Metal command buffer. This converts the
836    /// per-allocation `[set commit]` storm
837    /// (~880 commits/decode-token in iter8d/8e claude+codex variants) into
838    /// at most one commit per CB submission — mirrors llama.cpp's
839    /// `ggml-metal-device.m:1378-1382` pattern (batch addAllocation in
840    /// loop, commit ONCE).
841    ///
842    /// ADR-015 iter13: when the `MLX_UNRETAINED_REFS=1` env var is set at
843    /// process start, this constructor uses
844    /// [`CommandQueueRef::new_command_buffer_with_unretained_references`]
845    /// instead of `new_command_buffer`.  llama.cpp's per-token decode CBs
846    /// use `commandBufferWithUnretainedReferences` (see
847    /// `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m:512`) which
848    /// skips Metal's per-buffer-binding ARC-retain on submit and saves
849    /// ~3-5% on M-series GPUs (per the docstring above).
850    ///
851    /// **Caller contract under unretained refs.**  Every Metal buffer bound
852    /// to a dispatch in this CB MUST outlive the CB's GPU completion.  In
853    /// the hf2q decode path, that means every transient scratch must be
854    /// either (a) backed by the per-decode-token arena pool
855    /// (`MlxBufferPool` keeps an ARC clone in `in_use` until the next
856    /// `reset` — see `buffer_pool.rs:60`) or (b) hoisted to a caller scope
857    /// that lives across the terminal `commit_and_wait_labeled`.  Helpers
858    /// in `apply_proj` / `apply_pre_norm` / lm_head cast / router-download
859    /// that allocated transients via `device.alloc_buffer` and dropped
860    /// them at function return MUST be lifted to `pooled_alloc_buffer`
861    /// before `MLX_UNRETAINED_REFS=1` is enabled, or the first MoE FFN
862    /// dispatch will crash with "Command buffer error: GPU command buffer
863    /// completed with error status" (verified 2026-04-26).
864    ///
865    /// The default (`MLX_UNRETAINED_REFS` unset) preserves retained-refs
866    /// behavior verbatim — this is the sourdough-safe path.
867    pub(crate) fn new_with_residency(
868        queue: &CommandQueue,
869        residency_set: Option<ResidencySet>,
870    ) -> Result<Self> {
871        let cmd_buf = if unretained_refs_enabled() {
872            queue.new_command_buffer_with_unretained_references().to_owned()
873        } else {
874            queue.new_command_buffer().to_owned()
875        };
876        CMD_BUF_COUNT.fetch_add(1, Ordering::Relaxed);
877        Ok(Self {
878            cmd_buf,
879            queue: queue.to_owned(),
880            active_encoder: std::ptr::null(),
881            capture: None,
882            pending_op_kind: CapturedOpKind::Other,
883            pending_reads: Vec::new(),
884            pending_writes: Vec::new(),
885            residency_set,
886            mem_ranges: MemRanges::new(),
887            sample_buffer: None,
888            pending_dispatch_meta: Vec::new(),
889            dispatch_in_cb: 0,
890            last_label: String::new(),
891        })
892    }
893
894    /// Enable capture mode.
895    ///
896    /// All subsequent dispatch and barrier calls will be recorded into a
897    /// `Vec<CapturedNode>` instead of being encoded into Metal.
898    /// Call `take_capture()` to extract the recorded nodes.
899    pub fn start_capture(&mut self) {
900        self.capture = Some(Vec::with_capacity(128));
901    }
902
903    /// Whether the encoder is currently in capture mode.
904    pub fn is_capturing(&self) -> bool {
905        self.capture.is_some()
906    }
907
908    /// Extract the captured nodes, ending capture mode.
909    ///
910    /// Returns `None` if capture mode was not active.
911    pub fn take_capture(&mut self) -> Option<Vec<CapturedNode>> {
912        self.capture.take()
913    }
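    // Capture → replay sketch.  `ComputeGraph` lives in the graph module and
    // its constructor is an assumption here; the encoder calls are the ones
    // defined in this impl.
    //
    // ```ignore
    // enc.start_capture();
    // enc.set_op_kind(CapturedOpKind::RmsNorm);
    // enc.encode_threadgroups(pipeline, &buffers, tg, tg_size);  // recorded, not encoded
    // enc.memory_barrier();                                      // recorded as CapturedNode::Barrier
    // let nodes = enc.take_capture().expect("capture was active");
    // // let graph = ComputeGraph::from_nodes(nodes);            // hypothetical constructor
    // // graph.encode_sequential(&mut enc);
    // ```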
914
915    /// Tag the NEXT captured dispatch with the given operation kind.
916    ///
917    /// The tag is consumed (reset to `Other`) after the next dispatch is
918    /// captured.  Only meaningful in capture mode — has no effect on
919    /// direct-dispatch encoding.
920    ///
921    /// Used by op dispatch functions to annotate captures for the fusion
922    /// pass (Phase 4e.2).
923    pub fn set_op_kind(&mut self, kind: CapturedOpKind) {
924        self.pending_op_kind = kind;
925    }
926
927    /// Consume and return the pending op kind, resetting it to `Other`.
928    fn take_pending_op_kind(&mut self) -> CapturedOpKind {
929        let kind = self.pending_op_kind;
930        self.pending_op_kind = CapturedOpKind::Other;
931        kind
932    }
933
934    /// Stash buffer range annotations for the NEXT captured dispatch.
935    ///
936    /// Called by `GraphSession::barrier_between()` in capture mode to record
937    /// which buffers the next dispatch reads from and writes to.  The ranges
938    /// are consumed by the next `encode_*` call and attached to the captured
939    /// `CapturedNode::Dispatch`.
940    ///
941    /// Only meaningful in capture mode — has no effect on direct-dispatch.
942    pub fn set_pending_buffer_ranges(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
943        self.pending_reads = reads;
944        self.pending_writes = writes;
945    }
946
947    /// Patch the last captured dispatch node's empty reads/writes with the
948    /// given ranges. No-op if not capturing, or if the last node isn't a
949    /// Dispatch, or if its ranges are already populated.
950    ///
951    /// Used by `GraphSession::track_dispatch` in recording mode to annotate
952    /// dispatches that were called without a preceding `barrier_between`.
953    pub fn annotate_last_dispatch_if_missing(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
954        if let Some(ref mut nodes) = self.capture {
955            if let Some(CapturedNode::Dispatch { reads: r, writes: w, .. }) = nodes.last_mut() {
956                if r.is_empty() && !reads.is_empty() {
957                    *r = reads;
958                }
959                if w.is_empty() && !writes.is_empty() {
960                    *w = writes;
961                }
962            }
963        }
964    }
965
966    /// Consume and return the pending buffer range annotations.
967    fn take_pending_buffer_ranges(&mut self) -> (Vec<MemRange>, Vec<MemRange>) {
968        let reads = std::mem::take(&mut self.pending_reads);
969        let writes = std::mem::take(&mut self.pending_writes);
970        (reads, writes)
971    }
972
973    /// Record buffer bindings into `RecordedBinding` form.
974    fn record_buffer_bindings(buffers: &[(u64, &MlxBuffer)]) -> Vec<(u64, RecordedBinding)> {
975        buffers
976            .iter()
977            .map(|&(index, buf)| {
978                (
979                    index,
980                    RecordedBinding::Buffer {
981                        metal_buffer: buf.metal_buffer().clone(),
982                        offset: buf.byte_offset(),
983                    },
984                )
985            })
986            .collect()
987    }
988
989    /// Record `KernelArg` bindings into `RecordedBinding` form.
990    ///
991    /// `KernelArg::Buffer(buf)` records `buf.byte_offset()` so capture →
992    /// replay round-trips of `slice_view`-derived buffers preserve their
993    /// offsets, matching `record_buffer_bindings`'s behavior at line 382.
994    fn record_arg_bindings(bindings: &[(u64, KernelArg<'_>)]) -> Vec<(u64, RecordedBinding)> {
995        bindings
996            .iter()
997            .map(|(index, arg)| {
998                let recorded = match arg {
999                    KernelArg::Buffer(buf) => RecordedBinding::Buffer {
1000                        metal_buffer: buf.metal_buffer().clone(),
1001                        offset: buf.byte_offset(),
1002                    },
1003                    KernelArg::BufferWithOffset(buf, offset) => RecordedBinding::Buffer {
1004                        metal_buffer: buf.metal_buffer().clone(),
1005                        offset: *offset,
1006                    },
1007                    KernelArg::Bytes(bytes) => RecordedBinding::Bytes(bytes.to_vec()),
1008                };
1009                (*index, recorded)
1010            })
1011            .collect()
1012    }
1013
1014    /// Get or create the persistent compute encoder.
1015    ///
1016    /// On the first call, creates a new compute encoder from the command
1017    /// buffer.  On subsequent calls, returns the existing one.
1018    ///
1019    /// SAFETY: The returned reference borrows from `self.cmd_buf` which is
1020    /// alive for the lifetime of this `CommandEncoder`.  The raw pointer is
1021    /// valid until `end_active_encoder()` is called.
1022    #[inline]
1023    fn get_or_create_encoder(&mut self) -> &ComputeCommandEncoderRef {
1024        if self.active_encoder.is_null() {
1025            // Use MTLDispatchTypeConcurrent to allow independent dispatches
1026            // to overlap on the GPU.  Memory barriers are inserted between
1027            // dependent dispatches via `memory_barrier()`.
1028            //
1029            // ADR-015 iter61a-2 probe: HF2Q_FORCE_SERIAL_DISPATCH=1 falls back
1030            // to MTLDispatchType::Serial — every dispatch waits for the
1031            // previous to complete, eliminating concurrent-dispatch race
1032            // windows. Used to falsify Hypothesis (g): missing memory_barrier
1033            // calls between dependent dispatches cause cold-run logit
1034            // non-determinism via thread-race on a shared buffer.
1035            let dispatch_type = if std::env::var("HF2Q_FORCE_SERIAL_DISPATCH")
1036                .map(|v| v == "1")
1037                .unwrap_or(false)
1038            {
1039                MTLDispatchType::Serial
1040            } else {
1041                MTLDispatchType::Concurrent
1042            };
1043            let encoder = self
1044                .cmd_buf
1045                .compute_command_encoder_with_dispatch_type(dispatch_type);
1046            self.active_encoder = encoder as *const ComputeCommandEncoderRef;
1047        }
1048        // SAFETY: active_encoder is non-null and points to a valid encoder
1049        // owned by cmd_buf.
1050        unsafe { &*self.active_encoder }
1051    }
1052
1053    /// End the active compute encoder if one exists.
1054    #[inline]
1055    fn end_active_encoder(&mut self) {
1056        if !self.active_encoder.is_null() {
1057            // SAFETY: the pointer was obtained from cmd_buf.new_compute_command_encoder()
1058            // and has not been ended yet.
1059            unsafe { &*self.active_encoder }.end_encoding();
1060            self.active_encoder = std::ptr::null();
1061        }
1062    }
1063
1064    /// Insert a memory barrier with scope `MTLBarrierScopeBuffers`.
1065    ///
1066    /// When the encoder uses `MTLDispatchTypeConcurrent`, all dispatches can
1067    /// execute concurrently unless separated by a barrier.  Call this between
1068    /// dispatches where the later dispatch reads a buffer written by an
1069    /// earlier one.
1070    ///
1071    /// This is the same pattern llama.cpp uses:
1072    /// `[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]`
1073    #[allow(unexpected_cfgs)]
1074    pub fn memory_barrier(&mut self) {
1075        if let Some(ref mut nodes) = self.capture {
1076            nodes.push(CapturedNode::Barrier);
1077            return;
1078        }
1079        if self.active_encoder.is_null() {
1080            return;
1081        }
1082        BARRIER_COUNT.fetch_add(1, Ordering::Relaxed);
1083        // ADR-029 iter-175 Step 1e: when HF2Q_AUTO_BARRIER=1, hand-placed
1084        // barriers must reset the MemRanges tracker so partial migration to
1085        // `dispatch_tracked_*` stays correct.  A hand-placed `memory_barrier()`
1086        // drains the GPU at this point; any tracked dispatch after it
1087        // should start with a fresh cumulative state instead of false-
1088        // conflicting against ranges recorded before this barrier.
1089        // No-op under default HF2Q_AUTO_BARRIER=0 (tracker is empty).
1090        if auto_barrier_enabled() {
1091            self.mem_ranges.reset();
1092        }
1093        // SAFETY: active_encoder is non-null and valid.
1094        let encoder = unsafe { &*self.active_encoder };
1095        if barrier_profile_enabled() {
1096            // mach_absolute_time path — only on when MLX_PROFILE_BARRIERS=1.
1097            let start = std::time::Instant::now();
1098            issue_metal_buffer_barrier(encoder);
1099            let elapsed_ns = start.elapsed().as_nanos() as u64;
1100            BARRIER_NS.fetch_add(elapsed_ns, Ordering::Relaxed);
1101        } else {
1102            issue_metal_buffer_barrier(encoder);
1103        }
1104    }
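    // Hand-placed barrier between dependent dispatches under the concurrent
    // dispatch type (buffer and pipeline names are illustrative):
    //
    // ```ignore
    // enc.encode_threadgroups(matmul, &[(0, &x), (1, &w), (2, &y)], tg, tg_size);
    // enc.memory_barrier();  // the next dispatch reads y, written above
    // enc.encode_threadgroups(rms_norm, &[(0, &y), (1, &gamma), (2, &out)], tg2, tg_size2);
    // ```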
1105
1106    /// Set the compute pipeline state for subsequent dispatches.
1107    ///
1108    /// This begins a new compute pass if one is not already active.
1109    pub fn set_pipeline(&mut self, pipeline: &ComputePipelineStateRef) {
1110        let encoder = self.get_or_create_encoder();
1111        encoder.set_compute_pipeline_state(pipeline);
1112    }
1113
1114    /// Bind a buffer to a compute kernel argument slot.  The `index`
1115    /// corresponds to the `[[buffer(N)]]` attribute in the MSL shader.
1116    /// Currently a no-op: the arguments are discarded (see the body below).
1117    pub fn set_buffer(&self, index: u64, buffer: &MlxBuffer) {
1118        let _ = (index, buffer);
1119    }
1120
1121    /// Dispatch threads on the GPU.  Currently a no-op: the arguments are discarded.
1122    pub fn dispatch_threads(&self, grid_size: MTLSize, threadgroup_size: MTLSize) {
1123        let _ = (grid_size, threadgroup_size);
1124    }
1125
1126    /// Encode a complete compute pass: set pipeline, bind buffers, dispatch.
1127    ///
1128    /// Reuses the persistent compute encoder — no per-dispatch encoder
1129    /// creation overhead.
1130    ///
1131    /// # Arguments
1132    ///
1133    /// * `pipeline`         — The compiled compute pipeline to execute.
1134    /// * `buffers`          — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
1135    /// * `grid_size`        — Total number of threads to launch.
1136    /// * `threadgroup_size` — Threads per threadgroup.
1137    pub fn encode(
1138        &mut self,
1139        pipeline: &ComputePipelineStateRef,
1140        buffers: &[(u64, &MlxBuffer)],
1141        grid_size: MTLSize,
1142        threadgroup_size: MTLSize,
1143    ) {
1144        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1145        bucket_dispatch(pipeline);
1146        let op_kind = self.take_pending_op_kind();
1147        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1148        if let Some(ref mut nodes) = self.capture {
1149            nodes.push(CapturedNode::Dispatch {
1150                pipeline: pipeline.to_owned(),
1151                bindings: Self::record_buffer_bindings(buffers),
1152                threads_per_grid: grid_size,
1153                threads_per_threadgroup: threadgroup_size,
1154                threadgroup_memory: Vec::new(),
1155                dispatch_kind: DispatchKind::Threads,
1156                op_kind,
1157                reads: pending_reads,
1158                writes: pending_writes,
1159            });
1160            return;
1161        }
1162        // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1163        self.ensure_sample_buffer();
1164        let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1165        // SAFETY: encoder_ptr aliases &self via active_encoder which we
1166        // know is non-null after get_or_create_encoder; this pattern is
1167        // used throughout the file (see memory_barrier).
1168        let encoder = unsafe { &*encoder_ptr };
1169        encoder.set_compute_pipeline_state(pipeline);
1170        for &(index, buf) in buffers {
1171            encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
1172        }
1173        let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1174        assert_tg_size_multiple_of_32_if_hinted(threadgroup_size, pipeline);
1175        encoder.dispatch_threads(grid_size, threadgroup_size);
1176        self.sample_dispatch_post(encoder, pre_idx);
1177    }
1178
1179    /// Encode a compute pass using threadgroups instead of raw thread counts.
1180    ///
1181    /// Reuses the persistent compute encoder — no per-dispatch encoder
1182    /// creation overhead.
1183    pub fn encode_threadgroups(
1184        &mut self,
1185        pipeline: &ComputePipelineStateRef,
1186        buffers: &[(u64, &MlxBuffer)],
1187        threadgroups: MTLSize,
1188        threadgroup_size: MTLSize,
1189    ) {
1190        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1191        bucket_dispatch(pipeline);
1192        let op_kind = self.take_pending_op_kind();
1193        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1194        if let Some(ref mut nodes) = self.capture {
1195            nodes.push(CapturedNode::Dispatch {
1196                pipeline: pipeline.to_owned(),
1197                bindings: Self::record_buffer_bindings(buffers),
1198                threads_per_grid: threadgroups,
1199                threads_per_threadgroup: threadgroup_size,
1200                threadgroup_memory: Vec::new(),
1201                dispatch_kind: DispatchKind::ThreadGroups,
1202                op_kind,
1203                reads: pending_reads,
1204                writes: pending_writes,
1205            });
1206            return;
1207        }
1208        // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1209        self.ensure_sample_buffer();
1210        let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1211        // SAFETY: see encode() above.
1212        let encoder = unsafe { &*encoder_ptr };
1213        encoder.set_compute_pipeline_state(pipeline);
1214        for &(index, buf) in buffers {
1215            encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
1216        }
1217        let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1218        assert_tg_size_multiple_of_32_if_hinted(threadgroup_size, pipeline);
1219        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
1220        self.sample_dispatch_post(encoder, pre_idx);
1221    }
1222
1223    /// Encode a compute pass using threadgroups with shared threadgroup memory.
1224    ///
1225    /// Like [`encode_threadgroups`](Self::encode_threadgroups), but additionally
1226    /// allocates threadgroup memory at the specified indices.  This is required
1227    /// for kernels that use `threadgroup` memory (e.g. reductions in rms_norm
1228    /// and softmax).
1229    ///
1230    /// # Arguments
1231    ///
1232    /// * `pipeline`         — The compiled compute pipeline to execute.
1233    /// * `buffers`          — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
1234    /// * `threadgroup_mem`  — Slice of `(index, byte_length)` pairs for threadgroup memory.
1235    /// * `threadgroups`     — Number of threadgroups to dispatch.
1236    /// * `threadgroup_size` — Threads per threadgroup.
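    ///
    /// # Example
    ///
    /// A sketch of a row-wise reduction dispatch — names are illustrative,
    /// and the threadgroup-memory length assumes one `f32` partial per
    /// thread in a 256-thread group:
    ///
    /// ```ignore
    /// enc.encode_threadgroups_with_shared(
    ///     &pipeline,
    ///     &[(0, &input), (1, &output)],
    ///     &[(0, 256 * 4)],            // 1 KiB of threadgroup memory at index 0
    ///     MTLSize::new(n_rows, 1, 1), // one threadgroup per row
    ///     MTLSize::new(256, 1, 1),
    /// );
    /// ```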
1237    pub fn encode_threadgroups_with_shared(
1238        &mut self,
1239        pipeline: &ComputePipelineStateRef,
1240        buffers: &[(u64, &MlxBuffer)],
1241        threadgroup_mem: &[(u64, u64)],
1242        threadgroups: MTLSize,
1243        threadgroup_size: MTLSize,
1244    ) {
1245        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1246        bucket_dispatch(pipeline);
1247        let op_kind = self.take_pending_op_kind();
1248        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1249        if let Some(ref mut nodes) = self.capture {
1250            nodes.push(CapturedNode::Dispatch {
1251                pipeline: pipeline.to_owned(),
1252                bindings: Self::record_buffer_bindings(buffers),
1253                threads_per_grid: threadgroups,
1254                threads_per_threadgroup: threadgroup_size,
1255                threadgroup_memory: threadgroup_mem.to_vec(),
1256                dispatch_kind: DispatchKind::ThreadGroups,
1257                op_kind,
1258                reads: pending_reads,
1259                writes: pending_writes,
1260            });
1261            return;
1262        }
1263        // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1264        self.ensure_sample_buffer();
1265        let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1266        // SAFETY: see encode() above.
1267        let encoder = unsafe { &*encoder_ptr };
1268        encoder.set_compute_pipeline_state(pipeline);
1269        for &(index, buf) in buffers {
1270            encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
1271        }
1272        for &(index, byte_length) in threadgroup_mem {
1273            encoder.set_threadgroup_memory_length(index, byte_length);
1274        }
1275        let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1276        assert_tg_size_multiple_of_32_if_hinted(threadgroup_size, pipeline);
1277        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
1278        self.sample_dispatch_post(encoder, pre_idx);
1279    }
1280
1281    /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_threads).
1282    ///
1283    /// Reuses the persistent compute encoder.
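    ///
    /// # Example
    ///
    /// A sketch mixing buffer bindings with inline params bytes.  `Params`
    /// is assumed to be a `#[repr(C)]` `Pod` struct defined by the caller;
    /// `bytemuck::bytes_of` produces the `&[u8]` that [`KernelArg::Bytes`]
    /// expects:
    ///
    /// ```ignore
    /// let params = Params { n, scale };
    /// enc.encode_with_args(
    ///     &pipeline,
    ///     &[
    ///         (0, KernelArg::Buffer(&input)),
    ///         (1, KernelArg::Buffer(&output)),
    ///         (2, KernelArg::Bytes(bytemuck::bytes_of(&params))),
    ///     ],
    ///     MTLSize::new(n, 1, 1),
    ///     MTLSize::new(256, 1, 1),
    /// );
    /// ```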
1284    pub fn encode_with_args(
1285        &mut self,
1286        pipeline: &ComputePipelineStateRef,
1287        bindings: &[(u64, KernelArg<'_>)],
1288        grid_size: MTLSize,
1289        threadgroup_size: MTLSize,
1290    ) {
1291        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1292        bucket_dispatch(pipeline);
1293        let op_kind = self.take_pending_op_kind();
1294        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1295        if let Some(ref mut nodes) = self.capture {
1296            nodes.push(CapturedNode::Dispatch {
1297                pipeline: pipeline.to_owned(),
1298                bindings: Self::record_arg_bindings(bindings),
1299                threads_per_grid: grid_size,
1300                threads_per_threadgroup: threadgroup_size,
1301                threadgroup_memory: Vec::new(),
1302                dispatch_kind: DispatchKind::Threads,
1303                op_kind,
1304                reads: pending_reads,
1305                writes: pending_writes,
1306            });
1307            return;
1308        }
1309        // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1310        self.ensure_sample_buffer();
1311        let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1312        // SAFETY: see encode() above.
1313        let encoder = unsafe { &*encoder_ptr };
1314        encoder.set_compute_pipeline_state(pipeline);
1315        apply_bindings(encoder, bindings);
1316        let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1317        assert_tg_size_multiple_of_32_if_hinted(threadgroup_size, pipeline);
1318        encoder.dispatch_threads(grid_size, threadgroup_size);
1319        self.sample_dispatch_post(encoder, pre_idx);
1320    }
1321
1322    /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_thread_groups).
1323    ///
1324    /// Reuses the persistent compute encoder.
1325    pub fn encode_threadgroups_with_args(
1326        &mut self,
1327        pipeline: &ComputePipelineStateRef,
1328        bindings: &[(u64, KernelArg<'_>)],
1329        threadgroups: MTLSize,
1330        threadgroup_size: MTLSize,
1331    ) {
1332        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1333        bucket_dispatch(pipeline);
1334        let op_kind = self.take_pending_op_kind();
1335        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1336        if let Some(ref mut nodes) = self.capture {
1337            nodes.push(CapturedNode::Dispatch {
1338                pipeline: pipeline.to_owned(),
1339                bindings: Self::record_arg_bindings(bindings),
1340                threads_per_grid: threadgroups,
1341                threads_per_threadgroup: threadgroup_size,
1342                threadgroup_memory: Vec::new(),
1343                dispatch_kind: DispatchKind::ThreadGroups,
1344                op_kind,
1345                reads: pending_reads,
1346                writes: pending_writes,
1347            });
1348            return;
1349        }
1350        // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1351        self.ensure_sample_buffer();
1352        let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1353        // SAFETY: see encode() above.
1354        let encoder = unsafe { &*encoder_ptr };
1355        encoder.set_compute_pipeline_state(pipeline);
1356        apply_bindings(encoder, bindings);
1357        let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1358        assert_tg_size_multiple_of_32_if_hinted(threadgroup_size, pipeline);
1359        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
1360        self.sample_dispatch_post(encoder, pre_idx);
1361    }
1362
1363    /// Encode a dispatch with mixed buffer/bytes bindings and shared memory.
1364    ///
1365    /// Reuses the persistent compute encoder.
1366    pub fn encode_threadgroups_with_args_and_shared(
1367        &mut self,
1368        pipeline: &ComputePipelineStateRef,
1369        bindings: &[(u64, KernelArg<'_>)],
1370        threadgroup_mem: &[(u64, u64)],
1371        threadgroups: MTLSize,
1372        threadgroup_size: MTLSize,
1373    ) {
1374        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1375        bucket_dispatch(pipeline);
1376        let op_kind = self.take_pending_op_kind();
1377        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1378        if let Some(ref mut nodes) = self.capture {
1379            nodes.push(CapturedNode::Dispatch {
1380                pipeline: pipeline.to_owned(),
1381                bindings: Self::record_arg_bindings(bindings),
1382                threads_per_grid: threadgroups,
1383                threads_per_threadgroup: threadgroup_size,
1384                threadgroup_memory: threadgroup_mem.to_vec(),
1385                dispatch_kind: DispatchKind::ThreadGroups,
1386                op_kind,
1387                reads: pending_reads,
1388                writes: pending_writes,
1389            });
1390            return;
1391        }
1392        // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1393        self.ensure_sample_buffer();
1394        let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1395        // SAFETY: see encode() above.
1396        let encoder = unsafe { &*encoder_ptr };
1397        encoder.set_compute_pipeline_state(pipeline);
1398        apply_bindings(encoder, bindings);
1399        for &(index, byte_length) in threadgroup_mem {
1400            encoder.set_threadgroup_memory_length(index, byte_length);
1401        }
1402        let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1403        assert_tg_size_multiple_of_32_if_hinted(threadgroup_size, pipeline);
1404        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
1405        self.sample_dispatch_post(encoder, pre_idx);
1406    }
1407
1408    // -----------------------------------------------------------------
1409    // ADR-015 iter37 — dataflow-driven auto-barrier dispatch family.
1410    //
1411    // These methods mirror `encode_threadgroups*_with_args*` but take explicit
1412    // `reads: &[&MlxBuffer]` and `writes: &[&MlxBuffer]` slices.  When
1413    // the process started with `HF2Q_AUTO_BARRIER=1`, the encoder's
1414    // [`MemRanges`] tracker checks the new ranges against the
1415    // cumulative state since the last barrier; on conflict it emits
1416    // `memory_barrier()` and resets the state before recording the
1417    // new ranges.  When the env gate is unset, the check is skipped
1418    // entirely and the dispatch is applied identically to the
1419    // matching `encode_*` method — sourdough-safe by construction.
1420    //
1421    // Capture mode: the `reads`/`writes` ranges are recorded onto the
1422    // captured node via the existing `pending_reads`/`pending_writes`
1423    // mechanism, so a `dispatch_tracked` call inside capture mode is
1424    // equivalent to `set_pending_buffer_ranges + encode_*`.
1425    //
1426    // No production callsite migrates in iter37 — this is the API
1427    // surface the qwen35 forward path will adopt incrementally in
1428    // iter38+.  Today, every call to `dispatch_tracked` from a
1429    // production code path lives behind an explicit caller decision
1430    // to opt in.
1431    // -----------------------------------------------------------------
1432
1433    /// Auto-barrier-aware dispatch with [`KernelArg`] bindings (uses
1434    /// `dispatch_thread_groups`).
1435    ///
1436    /// Behaves identically to
1437    /// [`encode_threadgroups_with_args`](Self::encode_threadgroups_with_args)
1438    /// when `HF2Q_AUTO_BARRIER` is unset.  When set, consults the
1439    /// per-encoder [`MemRanges`] tracker:
1440    ///
1441    /// * Conflict (RAW/WAR/WAW on a same-buffer range) → emit
1442    ///   `memory_barrier()`, increment [`AUTO_BARRIER_COUNT`], reset
1443    ///   the tracker, then dispatch and seed the new concurrent group
1444    ///   with this dispatch's ranges.
1445    /// * No conflict → increment [`AUTO_BARRIER_CONCURRENT`], record
1446    ///   the ranges into the cumulative state, dispatch.
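    ///
    /// # Example
    ///
    /// A sketch of two dependent dispatches — names are illustrative.  With
    /// `HF2Q_AUTO_BARRIER=1`, the second call detects the overlap on
    /// `hidden` (written by the first, read by the second) and emits the
    /// barrier itself:
    ///
    /// ```ignore
    /// enc.dispatch_tracked_threadgroups_with_args(
    ///     &norm_pipeline, &norm_bindings,
    ///     &[&input], &[&hidden],          // reads, writes
    ///     threadgroups, tg_size,
    /// );
    /// enc.dispatch_tracked_threadgroups_with_args(
    ///     &proj_pipeline, &proj_bindings,
    ///     &[&hidden, &weight], &[&output],
    ///     threadgroups, tg_size,
    /// );
    /// ```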
1447    pub fn dispatch_tracked_threadgroups_with_args(
1448        &mut self,
1449        pipeline: &ComputePipelineStateRef,
1450        bindings: &[(u64, KernelArg<'_>)],
1451        reads: &[&MlxBuffer],
1452        writes: &[&MlxBuffer],
1453        threadgroups: MTLSize,
1454        threadgroup_size: MTLSize,
1455    ) {
1456        // Capture mode: stash ranges + delegate to the standard encode.
1457        // The ranges flow through `pending_reads`/`pending_writes` and
1458        // attach to the captured `Dispatch` node — identical to what
1459        // `GraphSession::barrier_between` already does in capture mode.
1460        if self.is_capturing() {
1461            let read_ranges = ranges_from_buffers(reads);
1462            let write_ranges = ranges_from_buffers(writes);
1463            self.set_pending_buffer_ranges(read_ranges, write_ranges);
1464            self.encode_threadgroups_with_args(pipeline, bindings, threadgroups, threadgroup_size);
1465            return;
1466        }
1467
1468        if auto_barrier_enabled() {
1469            self.maybe_auto_barrier(reads, writes);
1470        }
1471
1472        self.encode_threadgroups_with_args(pipeline, bindings, threadgroups, threadgroup_size);
1473    }
1474
1475    /// Auto-barrier-aware dispatch with [`KernelArg`] bindings + shared
1476    /// threadgroup memory.
1477    ///
1478    /// See [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args)
1479    /// for the behavioral contract; this variant additionally takes a
1480    /// `threadgroup_mem` slice that is forwarded to
1481    /// [`encode_threadgroups_with_args_and_shared`](Self::encode_threadgroups_with_args_and_shared).
1482    ///
1483    /// The 8-argument signature mirrors the existing
1484    /// `encode_threadgroups_with_args_and_shared` plus the two
1485    /// dataflow slices; `clippy::too_many_arguments` is allowed
1486    /// because each parameter is load-bearing for either the dispatch
1487    /// (pipeline/bindings/threadgroups/threadgroup_size/shared_mem)
1488    /// or the auto-barrier (reads/writes).
1489    #[allow(clippy::too_many_arguments)]
1490    pub fn dispatch_tracked_threadgroups_with_args_and_shared(
1491        &mut self,
1492        pipeline: &ComputePipelineStateRef,
1493        bindings: &[(u64, KernelArg<'_>)],
1494        threadgroup_mem: &[(u64, u64)],
1495        reads: &[&MlxBuffer],
1496        writes: &[&MlxBuffer],
1497        threadgroups: MTLSize,
1498        threadgroup_size: MTLSize,
1499    ) {
1500        if self.is_capturing() {
1501            let read_ranges = ranges_from_buffers(reads);
1502            let write_ranges = ranges_from_buffers(writes);
1503            self.set_pending_buffer_ranges(read_ranges, write_ranges);
1504            self.encode_threadgroups_with_args_and_shared(
1505                pipeline,
1506                bindings,
1507                threadgroup_mem,
1508                threadgroups,
1509                threadgroup_size,
1510            );
1511            return;
1512        }
1513
1514        if auto_barrier_enabled() {
1515            self.maybe_auto_barrier(reads, writes);
1516        }
1517
1518        self.encode_threadgroups_with_args_and_shared(
1519            pipeline,
1520            bindings,
1521            threadgroup_mem,
1522            threadgroups,
1523            threadgroup_size,
1524        );
1525    }
1526
1527    /// Auto-barrier-aware dispatch using `(slot, &MlxBuffer)` bindings
1528    /// (uses `dispatch_thread_groups`).
1529    ///
1530    /// Convenience wrapper for callers that don't need
1531    /// [`KernelArg::Bytes`] inline-byte arguments.  See
1532    /// [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args)
1533    /// for behavioral contract.
1534    pub fn dispatch_tracked_threadgroups(
1535        &mut self,
1536        pipeline: &ComputePipelineStateRef,
1537        buffers: &[(u64, &MlxBuffer)],
1538        reads: &[&MlxBuffer],
1539        writes: &[&MlxBuffer],
1540        threadgroups: MTLSize,
1541        threadgroup_size: MTLSize,
1542    ) {
1543        if self.is_capturing() {
1544            let read_ranges = ranges_from_buffers(reads);
1545            let write_ranges = ranges_from_buffers(writes);
1546            self.set_pending_buffer_ranges(read_ranges, write_ranges);
1547            self.encode_threadgroups(pipeline, buffers, threadgroups, threadgroup_size);
1548            return;
1549        }
1550
1551        if auto_barrier_enabled() {
1552            self.maybe_auto_barrier(reads, writes);
1553        }
1554
1555        self.encode_threadgroups(pipeline, buffers, threadgroups, threadgroup_size);
1556    }
1557
1558    /// Auto-barrier-aware dispatch using `(slot, &MlxBuffer)` bindings
1559    /// **plus shared threadgroup memory** (uses `dispatch_thread_groups`).
1560    ///
1561    /// Mirrors [`encode_threadgroups_with_shared`](Self::encode_threadgroups_with_shared)
1562    /// — convenience variant for kernels that allocate threadgroup
1563    /// memory (reductions in `rms_norm`, `softmax`, etc.) but don't
1564    /// need [`KernelArg::Bytes`] inline-byte arguments.  See
1565    /// [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args)
1566    /// for the behavioral contract; the only addition here is the
1567    /// `threadgroup_mem` slice forwarded to the underlying encode.
1568    ///
1569    /// Closes the iter38-audit coverage gap: the 5 `rms_norm.rs`
1570    /// callsites (`/opt/mlx-native/src/ops/rms_norm.rs:124,236,443,
1571    /// 516,589`) all use `encode_threadgroups_with_shared` and need
1572    /// dataflow tracking when migrated to auto-barrier in iter40+.
1573    ///
1574    /// 8-argument signature; `clippy::too_many_arguments` is allowed
1575    /// because each parameter is load-bearing for either the dispatch
1576    /// (pipeline/buffers/threadgroups/threadgroup_size/shared_mem) or
1577    /// the auto-barrier (reads/writes).
1578    #[allow(clippy::too_many_arguments)]
1579    pub fn dispatch_tracked_threadgroups_with_shared(
1580        &mut self,
1581        pipeline: &ComputePipelineStateRef,
1582        buffers: &[(u64, &MlxBuffer)],
1583        threadgroup_mem: &[(u64, u64)],
1584        reads: &[&MlxBuffer],
1585        writes: &[&MlxBuffer],
1586        threadgroups: MTLSize,
1587        threadgroup_size: MTLSize,
1588    ) {
1589        if self.is_capturing() {
1590            let read_ranges = ranges_from_buffers(reads);
1591            let write_ranges = ranges_from_buffers(writes);
1592            self.set_pending_buffer_ranges(read_ranges, write_ranges);
1593            self.encode_threadgroups_with_shared(
1594                pipeline,
1595                buffers,
1596                threadgroup_mem,
1597                threadgroups,
1598                threadgroup_size,
1599            );
1600            return;
1601        }
1602
1603        if auto_barrier_enabled() {
1604            self.maybe_auto_barrier(reads, writes);
1605        }
1606
1607        self.encode_threadgroups_with_shared(
1608            pipeline,
1609            buffers,
1610            threadgroup_mem,
1611            threadgroups,
1612            threadgroup_size,
1613        );
1614    }
1615
1616    /// Auto-barrier-aware `dispatch_threads` variant with
1617    /// [`KernelArg`] bindings.
1618    ///
1619    /// Mirrors [`encode_with_args`](Self::encode_with_args) — the
1620    /// `dispatch_threads` (per-thread grid) flavor, as opposed to the
1621    /// `dispatch_thread_groups` flavor of
1622    /// [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args).
1623    /// See that method for the behavioral contract.
1624    ///
1625    /// Closes the iter38-audit coverage gap: callers that use
1626    /// per-thread grids — `rope.rs:108` (IMROPE), `sigmoid_mul.rs:76`
1627    /// (sigmoid-mul), and `encode_helpers.rs:41` (kv_cache_copy) —
1628    /// need a `dispatch_threads` flavor of the tracked dispatch
1629    /// because their grid sizes are expressed in threads, not
1630    /// threadgroups.
1631    ///
1632    /// Note: the simpler `(slot, &MlxBuffer)` form (from
1633    /// [`encode`](Self::encode)) is a special case of this method —
1634    /// callers can wrap each binding as `KernelArg::Buffer(buf)` to
1635    /// reuse this single tracked variant rather than introducing a
1636    /// fifth one.
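    ///
    /// # Example
    ///
    /// A sketch of that wrapping — `copy_pipeline`, `src`, and `dst` are
    /// illustrative names:
    ///
    /// ```ignore
    /// enc.dispatch_tracked_threads_with_args(
    ///     &copy_pipeline,
    ///     &[(0, KernelArg::Buffer(&src)), (1, KernelArg::Buffer(&dst))],
    ///     &[&src], &[&dst],        // reads, writes
    ///     MTLSize::new(n, 1, 1),   // grid expressed in threads
    ///     MTLSize::new(256, 1, 1),
    /// );
    /// ```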
1637    pub fn dispatch_tracked_threads_with_args(
1638        &mut self,
1639        pipeline: &ComputePipelineStateRef,
1640        bindings: &[(u64, KernelArg<'_>)],
1641        reads: &[&MlxBuffer],
1642        writes: &[&MlxBuffer],
1643        grid_size: MTLSize,
1644        threadgroup_size: MTLSize,
1645    ) {
1646        if self.is_capturing() {
1647            let read_ranges = ranges_from_buffers(reads);
1648            let write_ranges = ranges_from_buffers(writes);
1649            self.set_pending_buffer_ranges(read_ranges, write_ranges);
1650            self.encode_with_args(pipeline, bindings, grid_size, threadgroup_size);
1651            return;
1652        }
1653
1654        if auto_barrier_enabled() {
1655            self.maybe_auto_barrier(reads, writes);
1656        }
1657
1658        self.encode_with_args(pipeline, bindings, grid_size, threadgroup_size);
1659    }
1660
1661    /// Dispatch a pre-baked record.
1662    ///
1663    /// ADR-029 iter-175 Step 1d — fast path for decode hot kernels
1664    /// whose pipeline + threadgroup geometry + params bytes are
1665    /// load-time-immutable.  `runtime_buffers` must be in the same
1666    /// order as `rec.buffer_slots`.
1667    ///
1668    /// Equivalent Metal command stream to:
1669    /// ```ignore
1670    /// encoder.encode_threadgroups_with_args_and_shared(
1671    ///     &rec.pipeline,
1672    ///     bindings,  // = runtime_buffers zipped with buffer_slots + (params_slot, Bytes(&rec.params_bytes))
1673    ///     &rec.threadgroup_mem,
1674    ///     rec.threadgroups,
1675    ///     rec.threads_per_tg,
1676    /// );
1677    /// ```
1678    /// — but skips the kernel-name lookup, ggml_type match arms,
1679    /// MTLSize::new, and param-struct field stores that the unbaked
1680    /// path performs on every call.
1681    ///
1682    /// Capture mode and auto-barrier are supported identically to
1683    /// `encode_threadgroups_with_args_and_shared`.  The caller is
1684    /// expected to have called `set_pending_buffer_ranges` (capture)
1685    /// or rely on auto-barrier for dataflow correctness before this
1686    /// call, matching the contract of the unbaked dispatch_tracked_*
1687    /// family.
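    ///
    /// # Example
    ///
    /// A sketch of the intended call pattern — `rec` is a [`DispatchRecord`]
    /// baked once at model load (how it is baked is the caller's concern);
    /// only the per-token buffers vary inside the decode loop:
    ///
    /// ```ignore
    /// // Runtime buffers in the same order as rec.buffer_slots.
    /// enc.dispatch_record(&rec, &[&input, &output]);
    /// ```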
1688    pub fn dispatch_record(
1689        &mut self,
1690        rec: &DispatchRecord,
1691        runtime_buffers: &[&MlxBuffer],
1692    ) {
1693        debug_assert_eq!(
1694            rec.buffer_slots.len(),
1695            runtime_buffers.len(),
1696            "dispatch_record: runtime_buffers count must match buffer_slots ({}); got {}",
1697            rec.buffer_slots.len(),
1698            runtime_buffers.len(),
1699        );
1700
1701        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1702        bucket_dispatch(&rec.pipeline);
1703        let op_kind_override = self.take_pending_op_kind();
1704        // If a caller set an op_kind override via set_op_kind(), honor it;
1705        // otherwise use the baked op_kind from the record.
1706        let op_kind = if matches!(op_kind_override, CapturedOpKind::Other) {
1707            rec.op_kind
1708        } else {
1709            op_kind_override
1710        };
1711        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1712
1713        if let Some(ref mut nodes) = self.capture {
1714            // Reconstruct bindings for replay — runtime buffers first,
1715            // params bytes (if any) last.  Order matches what the
1716            // baked-path runtime encoding produces below.
1717            let cap = runtime_buffers.len() + if rec.params_bytes.is_empty() { 0 } else { 1 };
1718            let mut bindings: Vec<(u64, RecordedBinding)> = Vec::with_capacity(cap);
1719            for (slot, buf) in rec.buffer_slots.iter().zip(runtime_buffers.iter()) {
1720                bindings.push((
1721                    *slot,
1722                    RecordedBinding::Buffer {
1723                        metal_buffer: buf.metal_buffer().to_owned(),
1724                        offset: buf.byte_offset(),
1725                    },
1726                ));
1727            }
1728            if !rec.params_bytes.is_empty() {
1729                bindings.push((
1730                    rec.params_slot,
1731                    RecordedBinding::Bytes(rec.params_bytes.clone()),
1732                ));
1733            }
1734            nodes.push(CapturedNode::Dispatch {
1735                pipeline: rec.pipeline.clone(),
1736                bindings,
1737                threads_per_grid: rec.threadgroups,
1738                threads_per_threadgroup: rec.threads_per_tg,
1739                threadgroup_memory: rec.threadgroup_mem.clone(),
1740                dispatch_kind: DispatchKind::ThreadGroups,
1741                op_kind,
1742                reads: pending_reads,
1743                writes: pending_writes,
1744            });
1745            return;
1746        }
1747
1748        // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1749        self.ensure_sample_buffer();
1750        let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1751        // SAFETY: see encode() above — encoder reference outlives this scope
1752        // because `get_or_create_encoder` only mutates the `Option` wrapper.
1753        let encoder = unsafe { &*encoder_ptr };
1754        encoder.set_compute_pipeline_state(&rec.pipeline);
1755        for (slot, buf) in rec.buffer_slots.iter().zip(runtime_buffers.iter()) {
1756            encoder.set_buffer(*slot, Some(buf.metal_buffer()), buf.byte_offset());
1757        }
1758        if !rec.params_bytes.is_empty() {
1759            encoder.set_bytes(
1760                rec.params_slot,
1761                rec.params_bytes.len() as u64,
1762                rec.params_bytes.as_ptr() as *const _,
1763            );
1764        }
1765        for &(idx, len) in rec.threadgroup_mem.iter() {
1766            encoder.set_threadgroup_memory_length(idx, len);
1767        }
1768        let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1769        // Skip assert_tg_size_multiple_of_32_if_hinted: bake-time
1770        // construction already validated the geometry.
1771        encoder.dispatch_thread_groups(rec.threadgroups, rec.threads_per_tg);
1772        self.sample_dispatch_post(encoder, pre_idx);
1773    }
1774
1775    /// Run the dataflow check, emit a barrier on conflict, and record
1776    /// the dispatch's ranges into the cumulative state.
1777    ///
1778    /// Always called *before* the underlying `encode_*` method
1779    /// applies the dispatch.  Mirrors lines 220-225 of
1780    /// `ggml-metal-ops.cpp` (`concurrency_check + concurrency_reset +
1781    /// concurrency_add` around each node).
1782    fn maybe_auto_barrier(
1783        &mut self,
1784        reads: &[&MlxBuffer],
1785        writes: &[&MlxBuffer],
1786    ) {
1787        if self.mem_ranges.check_dispatch(reads, writes) {
1788            // Concurrent — no barrier needed; just record the new ranges.
1789            self.mem_ranges.add_dispatch(reads, writes);
1790            AUTO_BARRIER_CONCURRENT.fetch_add(1, Ordering::Relaxed);
1791        } else {
1792            // Conflict — emit barrier, reset state, seed new group.
1793            //
1794            // `memory_barrier()` itself increments `BARRIER_COUNT` and,
1795            // when `MLX_PROFILE_BARRIERS=1`, accumulates `BARRIER_NS`.
1796            // We additionally bump `AUTO_BARRIER_COUNT` so the
1797            // "auto-emitted vs hand-placed" subset is queryable.
1798            self.memory_barrier();
1799            self.mem_ranges.reset();
1800            self.mem_ranges.add_dispatch(reads, writes);
1801            AUTO_BARRIER_COUNT.fetch_add(1, Ordering::Relaxed);
1802        }
1803    }
1804
1805    /// Force a barrier and reset the auto-barrier tracker.
1806    ///
1807    /// Use at boundaries where the caller knows a barrier is required
1808    /// regardless of dataflow — typically before reading data back to
1809    /// CPU, or at the end of an op group whose internal dependencies
1810    /// the tracker can't see (e.g. host-driven memcpy).
1811    ///
1812    /// Equivalent to `memory_barrier()` plus a `MemRanges::reset()`
1813    /// when `HF2Q_AUTO_BARRIER=1`; equivalent to plain
1814    /// `memory_barrier()` otherwise.
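    ///
    /// # Example
    ///
    /// A sketch of the readback boundary described above:
    ///
    /// ```ignore
    /// enc.force_barrier_and_reset_tracker();
    /// enc.commit_and_wait()?;
    /// // ... read `output` back on the CPU ...
    /// ```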
1815    pub fn force_barrier_and_reset_tracker(&mut self) {
1816        self.memory_barrier();
1817        if auto_barrier_enabled() {
1818            self.mem_ranges.reset();
1819        }
1820    }
1821
1822    /// Diagnostic accessor — number of ranges currently recorded in
1823    /// this encoder's [`MemRanges`] tracker.  Always zero unless
1824    /// `HF2Q_AUTO_BARRIER=1` and at least one `dispatch_tracked` call
1825    /// has fired since the last conflict.
1826    #[inline]
1827    pub fn mem_ranges_len(&self) -> usize {
1828        self.mem_ranges.len()
1829    }
1830
1831    /// Replay a single captured dispatch node into this encoder.
1832    ///
1833    /// This is the inverse of capture: it takes a previously recorded
1834    /// `CapturedNode::Dispatch` and encodes it into the live Metal encoder.
1835    /// Barrier nodes are handled by the caller (ComputeGraph::encode_sequential).
1836    ///
1837    /// Does NOT increment `DISPATCH_COUNT` — that was already counted at
1838    /// capture time.
1839    pub fn replay_dispatch(
1840        &mut self,
1841        pipeline: &ComputePipelineStateRef,
1842        bindings: &[(u64, RecordedBinding)],
1843        threadgroup_memory: &[(u64, u64)],
1844        threads_per_grid: MTLSize,
1845        threads_per_threadgroup: MTLSize,
1846        dispatch_kind: DispatchKind,
1847    ) {
1848        // ADR-015 iter63 (Phase A.3): mirror the per-dispatch sampling
1849        // scaffold here so capture-mode-recorded graphs (graph.rs
1850        // encode_sequential / encode_with_barriers / encode_chunk_with
1851        // _barriers) still produce per-dispatch entries.  The replay
1852        // path bypasses encode*; without this hook the per-dispatch
1853        // table would be silently empty for any model that uses
1854        // `GraphExecutor::begin_recorded`.
1855        //
1856        // Captured `op_kind` is forwarded via `pending_op_kind`: the
1857        // graph replay layer at graph.rs:197/236/727 sets it from the
1858        // CapturedNode.op_kind before calling replay_dispatch.
1859        self.ensure_sample_buffer();
1860        let op_kind = self.take_pending_op_kind();
1861        let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1862        // SAFETY: see encode() above.
1863        let encoder = unsafe { &*encoder_ptr };
1864        encoder.set_compute_pipeline_state(pipeline);
1865        for (index, binding) in bindings {
1866            match binding {
1867                RecordedBinding::Buffer { metal_buffer, offset } => {
1868                    encoder.set_buffer(*index, Some(metal_buffer), *offset);
1869                }
1870                RecordedBinding::Bytes(bytes) => {
1871                    encoder.set_bytes(
1872                        *index,
1873                        bytes.len() as u64,
1874                        bytes.as_ptr() as *const _,
1875                    );
1876                }
1877            }
1878        }
1879        for &(index, byte_length) in threadgroup_memory {
1880            encoder.set_threadgroup_memory_length(index, byte_length);
1881        }
1882        let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1883        match dispatch_kind {
1884            DispatchKind::Threads => {
1885                assert_tg_size_multiple_of_32_if_hinted(threads_per_threadgroup, pipeline);
1886                encoder.dispatch_threads(threads_per_grid, threads_per_threadgroup);
1887            }
1888            DispatchKind::ThreadGroups => {
1889                assert_tg_size_multiple_of_32_if_hinted(threads_per_threadgroup, pipeline);
1890                encoder.dispatch_thread_groups(threads_per_grid, threads_per_threadgroup);
1891            }
1892        }
1893        self.sample_dispatch_post(encoder, pre_idx);
1894    }
1895
1896    /// Flush any pending residency-set add/remove staging.
1897    ///
1898    /// Hooked at every commit boundary so per-allocation
1899    /// [`ResidencySet::add_allocation`](ResidencySet::add_allocation) and
1900    /// [`ResidencySet::remove_allocation`](ResidencySet::remove_allocation)
1901    /// calls (as fired by `MlxDevice::alloc_buffer` and
1902    /// `MlxBufferStorage::Drop`) collapse into at most ONE `[set commit]`
1903    /// per CB submission. Mirrors llama.cpp's
1904    /// `ggml-metal-device.m:1378-1382` (batch addAllocation in loop,
1905    /// commit ONCE).
1906    #[inline]
1907    fn flush_residency_pending(&self) {
1908        if let Some(set) = self.residency_set.as_ref() {
1909            set.flush_pending();
1910        }
1911    }
1912
1913    // ----------------------------------------------------------------
1914    // ADR-015 iter63 — per-dispatch sample buffer lifecycle
1915    // ----------------------------------------------------------------
1916
1917    /// Allocate the per-CB `MTLCounterSampleBuffer` if it has not been
1918    /// allocated yet for this CB.
1919    ///
1920    /// No-op when `MLX_PROFILE_DISPATCH` is unset, when the buffer is
1921    /// already present, or when the device does not expose a counter
1922    /// set named `"timestamp"` (Risk R1 — graceful degrade with a
1923    /// one-shot stderr warning).
1924    ///
1925    /// The sample buffer is sized to [`MAX_SAMPLES_PER_CB`] (32_768).
1926    /// This budget counts start + end sample pairs — i.e. ≤ 16,384 dispatches
1927    /// per CB.  Above that ceiling, additional dispatches will skip
1928    /// sampling (see [`Self::sample_dispatch_pre`]).
1929    #[inline]
1930    fn ensure_sample_buffer(&mut self) {
1931        if !crate::kernel_profile::is_dispatch_enabled() {
1932            return;
1933        }
1934        if self.sample_buffer.is_some() {
1935            return;
1936        }
1937        // Discover the timestamp counter set.  metal-rs 0.33 does not
1938        // export the `MTLCommonCounterSetTimestamp` constant, so we
1939        // name-match `"timestamp"` case-insensitively.  Reach the
1940        // device via the cmd_buf's `device` selector (metal-rs 0.33
1941        // exposes `CommandQueue::device` but not `CommandBuffer::device`,
1942        // so we go through ObjC directly).
1943        let device: &metal::DeviceRef = unsafe {
1944            let cb = &*self.cmd_buf;
1945            msg_send![cb, device]
1946        };
1947        // ADR-015 iter63 — Apple Silicon hardware constraint (NEW Risk
1948        // discovered at impl time, supersedes design §A.7).  M-series
1949        // GPUs (verified: AGXG17XFamilyComputeContext = M5 Max series,
1950        // macOS 26) only support counter sampling AtStageBoundary —
1951        // i.e. between compute *passes*, not between dispatches inside
1952        // a persistent compute encoder.  Calling
1953        // `sampleCountersInBuffer:atSampleIndex:withBarrier:` on such
1954        // hardware aborts with `failed assertion ... not supported on
1955        // this device`.  The persistent-encoder design (mlx-native uses
1956        // ONE compute encoder per CB to amortize ~800 encoder
1957        // create/end cycles per forward pass — see `get_or_create_
1958        // encoder` docstring) is incompatible with stage-boundary-only
1959        // sampling, so on Apple Silicon we degrade per-dispatch
1960        // profiling to a no-op and log once.  Per-CB profiling is
1961        // unaffected (it uses MTLCommandBuffer.GPUStartTime/
1962        // GPUEndTime, which are always available).
1963        //
1964        // Future: if Apple ever ships AtDispatchBoundary support on
1965        // Apple Silicon, this branch becomes a true cap check.  For
1966        // now, the kit infrastructure is in place; only hardware support
1967        // for the dispatch-boundary sample point is missing.
1968        if !device.supports_counter_sampling(MTLCounterSamplingPoint::AtDispatchBoundary) {
1969            if TIMESTAMP_SET_WARN_LOGGED
1970                .compare_exchange(0, 1, Ordering::Relaxed, Ordering::Relaxed)
1971                .is_ok()
1972            {
1973                eprintln!(
1974                    "[mlx-native] MLX_PROFILE_DISPATCH=1 ignored: \
1975                     device {:?} does NOT support \
1976                     MTLCounterSamplingPointAtDispatchBoundary \
1977                     (Apple Silicon limitation; only AtStageBoundary \
1978                     is supported, which is incompatible with the \
1979                     persistent compute-encoder pattern). \
1980                     MLX_PROFILE_CB=1 still produces per-CB GPU times.",
1981                    device.name()
1982                );
1983            }
1984            return;
1985        }
1986        let counter_sets = device.counter_sets();
1987        let timestamp_set = counter_sets
1988            .iter()
1989            .find(|c: &&metal::CounterSet| c.name().eq_ignore_ascii_case("timestamp"));
1990        let timestamp_set = match timestamp_set {
1991            Some(s) => s,
1992            None => {
1993                // Risk R1: device does not expose a timestamp set.
1994                // Log once and degrade to no-op (sample_buffer stays None).
1995                if TIMESTAMP_SET_WARN_LOGGED
1996                    .compare_exchange(0, 1, Ordering::Relaxed, Ordering::Relaxed)
1997                    .is_ok()
1998                {
1999                    eprintln!(
2000                        "[mlx-native] MLX_PROFILE_DISPATCH=1 ignored: \
2001                         device {:?} exposes no MTLCommonCounterSetTimestamp",
2002                        device.name()
2003                    );
2004                }
2005                return;
2006            }
2007        };
2008        // Build descriptor.  StorageMode::Shared is required by
2009        // resolveCounterRange (MTLCounters.h:185-188).
2010        let descriptor = CounterSampleBufferDescriptor::new();
2011        descriptor.set_counter_set(timestamp_set);
2012        descriptor.set_storage_mode(MTLStorageMode::Shared);
2013        descriptor.set_label("mlx_native.dispatch_samples");
2014        descriptor.set_sample_count(MAX_SAMPLES_PER_CB);
2015        match device.new_counter_sample_buffer_with_descriptor(&descriptor) {
2016            Ok(buf) => {
2017                self.sample_buffer = Some(buf);
2018            }
2019            Err(e) => {
2020                if TIMESTAMP_SET_WARN_LOGGED
2021                    .compare_exchange(0, 1, Ordering::Relaxed, Ordering::Relaxed)
2022                    .is_ok()
2023                {
2024                    eprintln!(
2025                        "[mlx-native] MLX_PROFILE_DISPATCH=1 ignored: \
2026                         newCounterSampleBufferWithDescriptor failed: {}",
2027                        e
2028                    );
2029                }
2030                self.sample_buffer = None;
2031            }
2032        }
2033    }
2034
2035    /// Insert the start-of-dispatch counter sample (sample index `2*i`)
2036    /// and queue the per-dispatch metadata.  Returns `Some(i)` — the
2037    /// dispatch ordinal — so the caller can emit the matching post-sample.
2038    ///
2039    /// No-op when sampling is inactive — returns `None` in that case;
2040    /// the matching [`Self::sample_dispatch_post`] treats a `None`
2041    /// `pre_idx` as a skip.
2042    ///
2043    /// `with_barrier:true` is mandatory: the encoder uses
2044    /// `MTLDispatchTypeConcurrent` and without the barrier the start
2045    /// timestamp would race against any in-flight dispatch (PROFILING-
2046    /// KIT-DESIGN §A.5).
2047    #[inline]
2048    fn sample_dispatch_pre(
2049        &mut self,
2050        encoder: &ComputeCommandEncoderRef,
2051        op_kind: CapturedOpKind,
2052    ) -> Option<u32> {
2053        let sb = self.sample_buffer.as_ref()?;
2054        let i = self.dispatch_in_cb;
2055        let pre_idx = (i as u64).checked_mul(2)?;
2056        if pre_idx >= MAX_SAMPLES_PER_CB {
2057            // Ceiling exceeded — skip sampling for the remainder of
2058            // this CB.  Risk R4 (PROFILING-KIT-DESIGN §A.7): future
2059            // iter can chunk-resolve every N dispatches; for now we
2060            // accept truncation with a one-shot warning (re-uses the
2061            // R1 warn flag).
2062            return None;
2063        }
2064        encoder.sample_counters_in_buffer(sb, pre_idx, true);
2065        self.pending_dispatch_meta.push(PendingDispatchMeta {
2066            op_kind: op_kind.name(),
2067            dispatch_index: i,
2068        });
2069        Some(i)
2070    }
2071
2072    /// Insert the end-of-dispatch counter sample (sample index `2*i+1`)
2073    /// matching the most recent [`Self::sample_dispatch_pre`].
2074    ///
2075    /// No-op when sampling is inactive or when `pre_idx` is `None`.
2076    #[inline]
2077    fn sample_dispatch_post(
2078        &mut self,
2079        encoder: &ComputeCommandEncoderRef,
2080        pre_idx: Option<u32>,
2081    ) {
2082        let i = match pre_idx {
2083            Some(v) => v,
2084            None => return,
2085        };
2086        let sb = match self.sample_buffer.as_ref() {
2087            Some(b) => b,
2088            None => return,
2089        };
2090        let post_idx = match (i as u64).checked_mul(2).and_then(|v| v.checked_add(1)) {
2091            Some(v) if v < MAX_SAMPLES_PER_CB => v,
2092            _ => return,
2093        };
2094        encoder.sample_counters_in_buffer(sb, post_idx, true);
2095        // Bump the per-CB ordinal only after both samples committed
2096        // successfully so a truncation skip leaves the meta queue
2097        // length matching the buffer's resolved range.
2098        self.dispatch_in_cb = i.saturating_add(1);
2099    }
2100
2101    /// Resolve the per-CB sample buffer, push entries into
2102    /// [`crate::kernel_profile`], and reset per-CB state.
2103    ///
2104    /// Called from [`Self::commit_and_wait_labeled`] after the CB
2105    /// completes; the caller is responsible for ensuring the GPU has
2106    /// finished (otherwise `resolveCounterRange` returns garbage).
2107    ///
2108    /// On the first resolve after a [`crate::kernel_profile::reset`],
2109    /// also captures a `(cpu_ns, gpu_ticks)` pair via
2110    /// `device.sampleTimestamps` so subsequent ticks→ns conversion
2111    /// uses a fresh scale factor.
2112    fn resolve_dispatch_samples(&mut self, cb_label: &str) -> Result<()> {
2113        let sb = match self.sample_buffer.take() {
2114            Some(b) => b,
2115            None => {
2116                self.pending_dispatch_meta.clear();
2117                self.dispatch_in_cb = 0;
2118                return Ok(());
2119            }
2120        };
2121        let n = self.pending_dispatch_meta.len();
2122        if n == 0 {
2123            self.dispatch_in_cb = 0;
2124            return Ok(());
2125        }
2126        // Refresh the (cpu, gpu) scale pair on every resolve; the
2127        // device call is cheap and keeps us robust against driver-side
2128        // timebase changes between CBs.
2129        let mut cpu_t: u64 = 0;
2130        let mut gpu_t: u64 = 0;
2131        let device: &metal::DeviceRef = unsafe {
2132            let cb = &*self.cmd_buf;
2133            msg_send![cb, device]
2134        };
2135        device.sample_timestamps(&mut cpu_t, &mut gpu_t);
2136        crate::kernel_profile::record_clock_pair(cpu_t, gpu_t);
2137        let length = (n as u64).saturating_mul(2);
2138        let data = sb.resolve_counter_range(NSRange {
2139            location: 0,
2140            length,
2141        });
2142        // `resolve_counter_range` returns one NSUInteger per sample.
2143        // Pair them up: data[2i] = start, data[2i+1] = end.
2144        for (i, meta) in self.pending_dispatch_meta.drain(..).enumerate() {
2145            let start_idx = 2 * i;
2146            let end_idx = 2 * i + 1;
2147            if end_idx >= data.len() {
2148                break;
2149            }
2150            let start_raw = data[start_idx] as u64;
2151            let end_raw = data[end_idx] as u64;
2152            let start_ns = crate::kernel_profile::convert_gpu_ticks_to_ns(start_raw);
2153            let end_ns = crate::kernel_profile::convert_gpu_ticks_to_ns(end_raw);
2154            let gpu_ns = end_ns.saturating_sub(start_ns);
2155            crate::kernel_profile::record_dispatch(
2156                crate::kernel_profile::DispatchEntry {
2157                    cb_label: cb_label.to_string(),
2158                    op_kind: meta.op_kind,
2159                    dispatch_index: meta.dispatch_index,
2160                    gpu_ns,
2161                    start_gpu_ns: start_ns,
2162                    end_gpu_ns: end_ns,
2163                },
2164            );
2165        }
2166        // Buffer dropped at end of scope releases the underlying
2167        // CounterSampleBuffer; per-CB lifetime correctly bounded.
2168        drop(sb);
2169        self.dispatch_in_cb = 0;
2170        Ok(())
2171    }
2172
2173    /// Commit the command buffer and block until the GPU finishes execution.
2174    ///
2175    /// # Errors
2176    ///
2177    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
2178    pub fn commit_and_wait(&mut self) -> Result<()> {
2179        SYNC_COUNT.fetch_add(1, Ordering::Relaxed);
2180
2181        // End the persistent compute encoder before committing.
2182        self.end_active_encoder();
2183
2184        // ADR-015 iter8e (Phase 3b): flush deferred residency-set
2185        // add/remove staging so the residency hint covers any buffers
2186        // referenced by this CB. Single commit per CB boundary; no-op
2187        // when no residency set or no staged changes.
2188        self.flush_residency_pending();
2189
2190        self.cmd_buf.commit();
2191        self.cmd_buf.wait_until_completed();
2192
2193        match self.cmd_buf.status() {
2194            MTLCommandBufferStatus::Completed => Ok(()),
2195            MTLCommandBufferStatus::Error => {
2196                Err(MlxError::CommandBufferError(
2197                    "GPU command buffer completed with error status".into(),
2198                ))
2199            }
2200            status => Err(MlxError::CommandBufferError(format!(
2201                "Unexpected command buffer status after wait: {:?}",
2202                status
2203            ))),
2204        }
2205    }
2206
2207    /// Commit + wait, accumulating GPU wall-clock time under `label` into
2208    /// the [`crate::kernel_profile`] global table when `MLX_PROFILE_CB=1`
2209    /// is set.  When the env var is unset, this is identical to
2210    /// [`commit_and_wait`](Self::commit_and_wait) — zero overhead.
2211    ///
2212    /// Used by hf2q's decode hot path to attribute per-cb GPU time to
2213    /// labeled phases (per-layer attn, per-layer ffn, output_head, etc.)
2214    /// without manually wiring `commit_wait_with_gpu_time` everywhere.
2215    ///
2216    /// # Errors
2217    ///
2218    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
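    ///
    /// # Example
    ///
    /// A sketch — with `MLX_PROFILE_CB=1` set, the GPU time of this CB is
    /// accumulated under the given label (the label string is illustrative):
    ///
    /// ```ignore
    /// enc.commit_and_wait_labeled("layer.attn")?;
    /// ```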
2219    pub fn commit_and_wait_labeled(&mut self, label: &str) -> Result<()> {
2220        // ADR-015 iter16 — propagate `label` to MTLCommandBuffer.setLabel and
2221        // (if a compute encoder is active) MTLComputeCommandEncoder.setLabel
2222        // BEFORE end_encoding/commit so xctrace's
2223        // `metal-application-encoders-list` table populates `cmdbuffer-label`
2224        // and `encoder-label` columns with the semantic phase name (e.g.
2225        // `layer.attn_moe_ffn`, `output_head.fused_norm_lm_argmax`,
2226        // `layer.delta_net.ops1-9`).  Joined to per-CB GPU duration via
2227        // `metal-gpu-submission-to-command-buffer-id` (sub_id ↔ encoder_id) →
2228        // `metal-gpu-execution-points` (per-dispatch start/end), this enables
2229        // per-phase µs/token attribution comparing hf2q vs llama side-by-side
2230        // (iter15 §E "iter16 ATTRIBUTION PATH").  Cost is a single ObjC
2231        // msg_send per CB submission — sub-µs on M5 Max — and a no-op when
2232        // xctrace isn't recording, so this is unconditionally safe to call on
2233        // the production decode hot path.
2234        self.apply_labels(label);
2235        // ADR-015 iter63: record GPU time AND resolve per-dispatch samples
2236        // when either env gate is set.  Per-dispatch sampling force-enables
2237        // the per-CB path so cross-validation per Risk R3 always has a
2238        // ground-truth comparator.
2239        let need_gpu_time =
2240            crate::kernel_profile::is_enabled() || crate::kernel_profile::is_dispatch_enabled();
2241        if need_gpu_time {
2242            let (start_s, end_s) = self.commit_wait_with_gpu_time()?;
2243            let ns = ((end_s - start_s).max(0.0) * 1_000_000_000.0) as u64;
2244            if crate::kernel_profile::is_enabled() {
2245                crate::kernel_profile::record(label, ns);
2246            }
2247            if crate::kernel_profile::is_dispatch_enabled() {
2248                self.resolve_dispatch_samples(label)?;
2249            }
2250            Ok(())
2251        } else {
2252            self.commit_and_wait()
2253        }
2254    }
2255
2256    /// Async commit with a profiling label.  When `MLX_PROFILE_CB=1`
2257    /// is set, redirects to a synchronous [`commit_and_wait_labeled`]
2258    /// call to capture per-CB GPU time (this defeats async pipelining
2259    /// while profiling — an accepted trade-off: profile mode is slow
2260    /// but informative).  When unset, identical to [`commit`](Self::commit).
2261    pub fn commit_labeled(&mut self, label: &str) {
2262        // ADR-015 iter16 — see `commit_and_wait_labeled` for rationale.
2263        if crate::kernel_profile::is_enabled() {
2264            // Profile mode: force sync to capture GPU time.  apply_labels is
2265            // called inside commit_and_wait_labeled — do NOT call it twice
2266            // here (would double the ObjC msg_send under MLX_PROFILE_CB=1).
2267            // Errors are logged via stderr because the void return matches
2268            // commit().
2269            if let Err(e) = self.commit_and_wait_labeled(label) {
2270                eprintln!("[mlx-native] commit_labeled({}) failed: {}", label, e);
2271            }
2272        } else {
2273            // Async path: apply labels here so xctrace MST traces capture
2274            // per-CB phase attribution under default decode (no
2275            // `MLX_PROFILE_CB`).
2276            self.apply_labels(label);
2277            self.commit();
2278        }
2279    }
2280
2281    /// Apply `label` to the underlying `MTLCommandBuffer` and, if a compute
2282    /// encoder is currently active, to the `MTLComputeCommandEncoder`.
2283    ///
2284    /// Called from [`commit_labeled`] and [`commit_and_wait_labeled`] BEFORE
2285    /// the encoder is ended / the CB is committed so xctrace's
2286    /// `metal-application-encoders-list` table picks up the label on the
2287    /// row emitted at the encoder's `endEncoding` / CB submission boundary.
2288    /// Single ObjC `msg_send` per call (two if an encoder is active); sub-µs
2289    /// on M5 Max; no-op when xctrace isn't recording.
2290    ///
2291    /// Skipped (with a debug-only assert) if `label` is empty — an empty label
2292    /// would produce a trace row indistinguishable from the metal-rs default
2293    /// `Command Buffer 0` placeholder.
2294    #[inline]
2295    fn apply_labels(&mut self, label: &str) {
2296        debug_assert!(!label.is_empty(), "commit_*_labeled called with empty label");
2297        if label.is_empty() {
2298            return;
2299        }
2300        self.cmd_buf.set_label(label);
2301        if !self.active_encoder.is_null() {
2302            // SAFETY: active_encoder is non-null and points to a live encoder
2303            // owned by cmd_buf — same invariant as get_or_create_encoder /
2304            // memory_barrier.  set_label is a single property write on the
2305            // ObjC object; safe before endEncoding.
2306            unsafe { &*self.active_encoder }.set_label(label);
2307        }
2308        // ADR-015 iter63: capture the most recent label for per-dispatch
2309        // entries.  Cheap String allocation — only happens at CB commit
2310        // boundaries, not per dispatch.
2311        self.last_label.clear();
2312        self.last_label.push_str(label);
2313    }
2314
2315    /// Commit + wait, returning `(gpu_start_s, gpu_end_s)` CFTimeInterval
2316    /// timestamps from `MTLCommandBuffer`'s `GPUStartTime`/`GPUEndTime`
2317    /// properties.  Both are mach-absolute CFTimeInterval seconds (double).
2318    ///
2319    /// Intended for `HF2Q_PROFILE_GPU_TS=1` per-bucket GPU wall-clock
2320    /// attribution.  Adds exactly two ObjC property reads per call on top
2321    /// of the regular `commit_and_wait` — measured well under 1 μs on
2322    /// M5 Max.
2323    ///
2324    /// # Errors
2325    ///
2326    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
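    ///
    /// # Example
    ///
    /// A sketch of converting the returned CFTimeInterval pair into elapsed
    /// GPU nanoseconds (the same conversion `commit_and_wait_labeled` uses):
    ///
    /// ```ignore
    /// let (start_s, end_s) = enc.commit_wait_with_gpu_time()?;
    /// let gpu_ns = ((end_s - start_s).max(0.0) * 1e9) as u64;
    /// ```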
2327    pub fn commit_wait_with_gpu_time(&mut self) -> Result<(f64, f64)> {
2328        self.commit_and_wait()?;
2329        // SAFETY: cmd_buf is a valid MTLCommandBuffer that has been
2330        // committed and awaited.  GPUStartTime / GPUEndTime return
2331        // CFTimeInterval (double precision seconds).  See
2332        // https://developer.apple.com/documentation/metal/mtlcommandbuffer/1639925-gpustarttime
2333        let (gpu_start, gpu_end): (f64, f64) = unsafe {
2334            let cb = &*self.cmd_buf;
2335            let s: f64 = msg_send![cb, GPUStartTime];
2336            let e: f64 = msg_send![cb, GPUEndTime];
2337            (s, e)
2338        };
2339        Ok((gpu_start, gpu_end))
2340    }
2341
2342    /// Commit the command buffer WITHOUT blocking.
2343    ///
2344    /// The GPU begins executing the encoded commands immediately.  Call
2345    /// [`wait_until_completed`](Self::wait_until_completed) later to block
2346    /// the CPU and check for errors.  This allows the CPU to continue doing
2347    /// other work (e.g. preparing the next batch) while the GPU runs.
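    ///
    /// # Example
    ///
    /// A sketch of the commit/wait split — the CPU-side work in the middle
    /// is illustrative:
    ///
    /// ```ignore
    /// enc.commit();                 // GPU starts executing immediately
    /// prepare_next_batch();         // hypothetical overlapped CPU work
    /// enc.wait_until_completed()?;  // block and surface any GPU error
    /// ```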
2348    pub fn commit(&mut self) {
2349        self.end_active_encoder();
2350        // ADR-015 iter8e (Phase 3b): same flush hook as commit_and_wait —
2351        // this is the async-pipeline path that production decode uses.
2352        self.flush_residency_pending();
2353        self.cmd_buf.commit();
2354    }
2355
2356    /// Block until a previously committed command buffer completes.
2357    ///
2358    /// Must be called after [`commit`](Self::commit).  Do not call after
2359    /// [`commit_and_wait`](Self::commit_and_wait) — that method already waits.
2360    ///
2361    /// # Errors
2362    ///
2363    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
2364    pub fn wait_until_completed(&self) -> Result<()> {
2365        self.cmd_buf.wait_until_completed();
2366        match self.cmd_buf.status() {
2367            MTLCommandBufferStatus::Completed => Ok(()),
2368            MTLCommandBufferStatus::Error => Err(MlxError::CommandBufferError(
2369                "GPU command buffer completed with error status".into(),
2370            )),
2371            status => Err(MlxError::CommandBufferError(format!(
2372                "Unexpected command buffer status after wait: {:?}",
2373                status
2374            ))),
2375        }
2376    }
2377
2378    /// Borrow the underlying Metal command buffer.
2379    #[inline]
2380    pub fn metal_command_buffer(&self) -> &CommandBuffer {
2381        &self.cmd_buf
2382    }
2383
2384    /// Borrow the residency set bound to this encoder, if one exists.
2385    ///
2386    /// ADR-019 Phase 0b iter89e2-B: exposed `pub(crate)` so
2387    /// [`crate::EncoderSession`] can route caller-driven add/remove
2388    /// requests through the same `Arc<ResidencySetInner>` the encoder
2389    /// itself flushes at every `commit*` boundary. The single-set
2390    /// invariant from `device.rs::MlxDevice` is preserved — both the
2391    /// encoder's `flush_residency_pending` and the session's delegated
2392    /// add/remove operate on the SAME residency set. Returns `None` when
2393    /// residency sets are disabled (HF2Q_NO_RESIDENCY=1, macOS<15, or
2394    /// `CommandEncoder::new` from a residency-less queue).
2395    #[inline]
2396    pub(crate) fn residency_set(&self) -> Option<&ResidencySet> {
2397        self.residency_set.as_ref()
2398    }
2399
2400    /// Reopen `cmd_buf` with a fresh `CommandBuffer` from the originating queue.
2401    ///
2402    /// ADR-019 Phase 0b iter89e2-B: enables multi-stage chaining. After a
2403    /// non-blocking `commit*` has handed the prior CB to Metal, this method
2404    /// rotates `cmd_buf` to a freshly-allocated CB on the same queue and
2405    /// resets every per-CB scratch field so the next dispatch is encoded
2406    /// onto the new CB.
2407    ///
2408    /// # Caller contract
2409    ///
2410    /// Only valid when `active_encoder.is_null()` (the persistent compute
2411    /// encoder must have been ended via `end_active_encoder()`, which both
2412    /// `commit_and_wait` and `commit` already do). Calling this method
2413    /// while a compute encoder is open would leak the encoder (the new
2414    /// `cmd_buf` does not own it) and trip Metal's "Command encoder
2415    /// released without endEncoding" assertion when the prior `cmd_buf`
2416    /// drops. The only caller is [`crate::EncoderSession::reset_for_next_stage`];
2417    /// the session has already committed before invoking this.
2418    ///
2419    /// # F2 / F11 / F12 fence preservation
2420    ///
2421    /// - **F2 — residency-rescission**: this method does NOT re-flush
2422    ///   the residency set. The prior `commit*` already flushed; staged
2423    ///   add/remove since then will flush at the next `commit*` on the
2424    ///   new CB. The residency-set Arc clone is preserved.
2425    /// - **F11 — zero-init alloc_buffer**: untouched (no buffer allocs).
2426    /// - **F12 — `HF2Q_FORCE_SERIAL_DISPATCH`**: the new CB will lazily
2427    ///   open its compute encoder via `get_or_create_encoder`, which
2428    ///   re-reads the env var; the falsification probe still fires on
2429    ///   the new CB.
2430    ///
2431    /// # Counter semantics
2432    ///
2433    /// Bumps `CMD_BUF_COUNT` exactly once per call, matching the
2434    /// `new_with_residency` accounting. Does NOT bump `SYNC_COUNT` (no
2435    /// commit/wait happens here).
2436    pub(crate) fn reset_command_buffer(&mut self) {
2437        debug_assert!(
2438            self.active_encoder.is_null(),
2439            "reset_command_buffer called with an active compute encoder \
2440             — caller must commit (which calls end_active_encoder) first"
2441        );
2442        let cmd_buf = if unretained_refs_enabled() {
2443            self.queue
2444                .new_command_buffer_with_unretained_references()
2445                .to_owned()
2446        } else {
2447            self.queue.new_command_buffer().to_owned()
2448        };
2449        CMD_BUF_COUNT.fetch_add(1, Ordering::Relaxed);
2450        self.cmd_buf = cmd_buf;
2451        // Per-CB scratch state — every field that's documented as being
2452        // bounded by a CB lifetime resets here.
2453        self.active_encoder = std::ptr::null();
2454        self.dispatch_in_cb = 0;
2455        self.last_label.clear();
2456        self.pending_dispatch_meta.clear();
2457        // `mem_ranges` is a per-CB barrier inference state; clearing on
2458        // CB rotation matches the `commit_and_wait` post-commit invariant
2459        // (any new CB starts with no pending hazards). Rather than calling
2460        // the field's own `clear`, we replace it with a fresh
2461        // `MemRanges::new()`, keeping its internals unexposed.
2462        self.mem_ranges = MemRanges::new();
2463        // `sample_buffer` is dropped explicitly inside
2464        // `resolve_dispatch_samples` after a CB completes; we leave it
2465        // in whatever state the prior commit left it (typically `None`
2466        // after `commit_and_wait` finishes). A stale `Some` here would
2467        // be visible only under `MLX_PROFILE_DISPATCH=1` which fires its
2468        // own one-shot warning; not worth a special case.
2469        // `capture` (if Some) persists across CB rotation — capture mode
2470        // accumulates across stages within a session by design.
2471        // `pending_op_kind` / `pending_reads` / `pending_writes` only
2472        // hold tags for the NEXT dispatch and are consumed when that
2473        // dispatch fires — leaving them as-is is correct.
2474    }
2475
2476    /// Encode an `MTLSharedEvent` wait at `value` on the current CB.
2477    ///
2478    /// ADR-019 Phase 0b iter89e2-B: pairs with [`Self::encode_signal_event`]
2479    /// to express the inter-CB ordering D3 stage boundaries need. The new
2480    /// CB's GPU work blocks until the prior CB's signal lands on the same
2481    /// event at >= `value`.
2482    ///
2483    /// # Caller contract
2484    ///
2485    /// Must be called BEFORE any compute encoder is opened on the new
2486    /// CB — the wait is a CB-level op that must precede every dispatch
2487    /// in the new CB to actually order them. [`crate::EncoderSession::reset_for_next_stage`]
2488    /// fires this immediately after `reset_command_buffer`, before any
2489    /// dispatch lazy-opens the encoder.
2490    #[inline]
2491    pub(crate) fn encode_wait_for_event(&self, event: &metal::EventRef, value: u64) {
2492        debug_assert!(
2493            self.active_encoder.is_null(),
2494            "encode_wait_for_event called with an open compute encoder \
2495             — wait must precede the first dispatch on the new CB"
2496        );
2497        self.cmd_buf.encode_wait_for_event(event, value);
2498    }
2499
2500    /// End the active compute encoder, encode a stage-fence signal, and
2501    /// commit the CB non-blocking — atomically from the caller's view.
2502    ///
2503    /// ADR-019 Phase 0b iter89e2-B: this is the helper
2504    /// [`crate::EncoderSession::fence_stage`] uses to thread the signal
2505    /// between the encoder-end and the CB-commit boundaries that
2506    /// `commit_labeled` would otherwise serialize. Sequence:
2507    ///
2508    /// 1. End the persistent compute encoder (so `encodeSignalEvent:` is
2509    ///    encoded at CB-level, not encoder-level — Metal validation only
2510    ///    permits `encodeSignalEvent:value:` to be encoded outside an
2511    ///    active encoder pass).
2512    /// 2. Apply `label` (when `Some`) to the CB. Note: at this point
2513    ///    the encoder is already ended, so `apply_labels` skips the
2514    ///    encoder-side `setLabel:` branch — only the CB label propagates.
2515    ///    `last_label` and per-dispatch profiling keep working as
2516    ///    documented.
2517    /// 3. Encode `encodeSignalEvent:value:` with `new_value` at CB-level.
2518    /// 4. Flush the residency-set pending staging (matches the
2519    ///    `commit_labeled` / `commit` flush at encoder.rs:2004).
2520    /// 5. Commit the CB non-blocking (matches `commit()` at
2521    ///    encoder.rs:2026).
2522    ///
2523    /// # Counter semantics
2524    ///
2525    /// Bumps `SYNC_COUNT` zero times (non-blocking). Bumps
2526    /// `CMD_BUF_COUNT` zero times (no new CB allocated here —
2527    /// [`Self::reset_command_buffer`] does that on the next stage).
2528    ///
2529    /// # Errors
2530    ///
2531    /// Infallible (matches `commit()` semantics — errors surface only
2532    /// at `wait_until_completed`).
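    ///
    /// # Example (sketch)
    ///
    /// A minimal sketch of the two-stage chaining this helper enables,
    /// assuming an encoder `enc`, a shared `event`, and hypothetical
    /// `encode_stage_a` / `encode_stage_b` helpers that issue dispatches:
    ///
    /// ```ignore
    /// encode_stage_a(&mut enc);
    /// // End the encoder, signal `event` at 1, commit stage A non-blocking.
    /// enc.fence_signal_and_commit(&event, 1, Some("stage_a"));
    ///
    /// // Rotate to a fresh CB on the same queue, then make its GPU work
    /// // wait for stage A's signal before any stage-B dispatch runs.
    /// enc.reset_command_buffer();
    /// enc.encode_wait_for_event(&event, 1);
    ///
    /// encode_stage_b(&mut enc);
    /// enc.commit_and_wait()?; // the final stage blocks and checks errors
    /// ```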
2533    pub(crate) fn fence_signal_and_commit(
2534        &mut self,
2535        event: &metal::EventRef,
2536        new_value: u64,
2537        label: Option<&str>,
2538    ) {
2539        // Step 1: end the active compute encoder. encode_signal_event's
2540        // debug_assert requires this be done first.
2541        self.end_active_encoder();
2542        // Step 2: apply the CB label so xctrace MST attribution still
2543        // works on the fenced CB. apply_labels' debug_assert against
2544        // empty labels matches commit_labeled's semantics.
2545        if let Some(l) = label {
2546            self.apply_labels(l);
2547        }
2548        // Step 3: encode the signal at CB-level.
2549        self.cmd_buf.encode_signal_event(event, new_value);
2550        // Step 4 + 5: same as commit() — flush residency staging, then
2551        // hand the CB to Metal.
2552        self.flush_residency_pending();
2553        self.cmd_buf.commit();
2554    }
2555}
2556
2557impl Drop for CommandEncoder {
2558    fn drop(&mut self) {
2559        // End the persistent compute encoder before the command buffer
2560        // is dropped, otherwise Metal will assert:
2561        // "Command encoder released without endEncoding"
2562        self.end_active_encoder();
2563    }
2564}