mlx_native/encoder.rs
1//! [`CommandEncoder`] — batched GPU command submission.
2//!
3//! Wraps a Metal command buffer. Encode one or more compute kernel dispatches,
4//! then call [`commit_and_wait`](CommandEncoder::commit_and_wait) to submit the
5//! entire batch and block until the GPU finishes.
6//!
7//! # Persistent compute encoder
8//!
9//! A single Metal `ComputeCommandEncoder` is kept alive across multiple
10//! dispatches within the same command buffer. This avoids the overhead of
11//! creating and ending a new compute encoder per dispatch — the same pattern
12//! candle uses (`compute_per_buffer`). On a forward pass with ~800 dispatches
13//! this saves ~800 encoder create/end cycles.
14//!
15//! # Capture mode (Phase 4e.1)
16//!
17//! When `start_capture()` is called, subsequent dispatches are recorded into a
18//! `Vec<CapturedNode>` instead of being encoded into Metal. `memory_barrier()`
19//! records a barrier sentinel. Call `take_capture()` to extract the recorded
20//! graph for later replay via `ComputeGraph::encode_sequential()`.
21
22use std::sync::atomic::{AtomicU64, Ordering};
23
24use metal::{
25 CommandBuffer, CommandQueue, ComputeCommandEncoderRef, ComputePipelineState,
26 ComputePipelineStateRef, CounterSampleBuffer, CounterSampleBufferDescriptor,
27 MTLCommandBufferStatus, MTLCounterSamplingPoint, MTLDispatchType, MTLSize, MTLStorageMode,
28 NSRange,
29};
30#[allow(unused_imports)]
31use objc::{msg_send, sel, sel_impl};
32
33use crate::buffer::MlxBuffer;
34use crate::error::{MlxError, Result};
35use crate::mem_ranges::MemRanges;
36use crate::residency::ResidencySet;
37
/// A buffer or inline-bytes binding for a compute kernel argument slot.
///
/// Each `KernelArg` is paired with a `u64` argument-table index and applied
/// to a compute encoder by `apply_bindings`.
pub enum KernelArg<'a> {
    /// Bind an existing Metal buffer at the given index.
    ///
    /// The buffer's own `byte_offset()` is applied, so `slice_view`-derived
    /// sub-buffers expose only their slice to the kernel.
    Buffer(&'a MlxBuffer),
    /// Bind an existing Metal buffer at the given index with a byte offset.
    ///
    /// The explicit offset is used verbatim; the buffer's own
    /// `byte_offset()` is NOT added to it.
    BufferWithOffset(&'a MlxBuffer, u64),
    /// Bind inline bytes (small constant data) at the given index.
    ///
    /// The bytes are copied into the command encoder at bind time. Typically
    /// produced from a `Pod` value via [`as_bytes`].
    Bytes(&'a [u8]),
}
48
49/// Convert a `Pod` value to a byte slice suitable for `KernelArg::Bytes`.
50///
51/// # Safety
52///
53/// The caller must ensure `T` has the same layout as the corresponding
54/// MSL struct in the shader (matching field order, sizes, and alignment).
55pub fn as_bytes<T: bytemuck::Pod>(val: &T) -> &[u8] {
56 bytemuck::bytes_of(val)
57}
58
59// ---------------------------------------------------------------------------
60// Capture-mode types (Phase 4e.1 — Graph IR)
61// ---------------------------------------------------------------------------
62
/// A recorded kernel argument binding.
///
/// When the encoder is in capture mode, each `set_buffer` / `set_bytes` call
/// is stored as a `RecordedBinding` instead of being applied to Metal.
#[derive(Clone)]
pub enum RecordedBinding {
    /// A Metal buffer at the given offset.
    Buffer {
        /// Owned handle to the bound buffer. Holding it presumably keeps the
        /// underlying Metal object retained until replay; confirm in metal-rs.
        metal_buffer: metal::Buffer,
        /// Byte offset into `metal_buffer` to bind at replay time.
        offset: u64,
    },
    /// Inline bytes (small constant data, copied).
    Bytes(Vec<u8>),
}
77
/// How to dispatch the recorded kernel.
///
/// Selects how a captured dispatch's `threads_per_grid` size is interpreted
/// at replay time.
#[derive(Clone, Copy, Debug)]
pub enum DispatchKind {
    /// `dispatch_threads(grid_size, threadgroup_size)` — the size is a total
    /// thread count and Metal picks the threadgroup count.
    Threads,
    /// `dispatch_thread_groups(threadgroups, threadgroup_size)` — the caller
    /// specifies the threadgroup count.
    ThreadGroups,
}
86
/// Operation-kind tag attached to captured dispatches.
///
/// In capture mode each dispatch may be tagged with a `CapturedOpKind`. This
/// lets the fusion pass (4e.2) spot fuseable sequences without inspecting
/// pipeline names. The reorder pass (4e.3) consults
/// [`CapturedOpKind::is_reorderable`], and the per-dispatch profile dump uses
/// [`CapturedOpKind::name`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CapturedOpKind {
    /// RMS normalization with a learned scale.
    RmsNorm,
    /// Elementwise multiplication.
    ElemMul,
    /// Elementwise addition.
    ElemAdd,
    /// Scaled dot-product attention. Not reorderable: it ends the lookahead.
    Sdpa,
    /// Softmax. Not reorderable: it ends the lookahead.
    Softmax,
    /// Any other operation; the graph optimizer treats it as reorderable.
    Other,
}

impl CapturedOpKind {
    /// Returns `true` when the graph optimizer (Phase 4e.3) may reorder
    /// across a node of this kind.
    ///
    /// Mirrors the `h_safe` whitelist in llama.cpp's
    /// `ggml_metal_graph_optimize_reorder`. A kind that returns `false` is a
    /// hard stop for the 64-node lookahead window.
    pub fn is_reorderable(&self) -> bool {
        let blocks_lookahead = match *self {
            CapturedOpKind::Sdpa | CapturedOpKind::Softmax => true,
            CapturedOpKind::RmsNorm
            | CapturedOpKind::ElemMul
            | CapturedOpKind::ElemAdd
            | CapturedOpKind::Other => false,
        };
        !blocks_lookahead
    }

    /// Stable label for this kind, spelled exactly like the variant.
    ///
    /// Embedded in the per-dispatch profile dump (ADR-015 iter63 §A.5).
    /// `Other` is kept verbatim, so aggregating by op kind yields a clean
    /// "not yet labeled" bucket.
    pub fn name(&self) -> &'static str {
        let label = match *self {
            Self::RmsNorm => "RmsNorm",
            Self::ElemMul => "ElemMul",
            Self::ElemAdd => "ElemAdd",
            Self::Sdpa => "Sdpa",
            Self::Softmax => "Softmax",
            Self::Other => "Other",
        };
        label
    }
}
137
/// A memory range annotation: (start_address, end_address).
///
/// Represents a contiguous GPU buffer region for conflict detection in the
/// reorder pass (Phase 4e.3). Addresses are CPU-visible `contents_ptr()`
/// values, which on Apple Silicon unified memory equal the GPU addresses.
///
/// `ranges_from_buffers` builds `end` as `start + extent`, so `end` is
/// presumably exclusive (half-open `[start, end)`). NOTE(review): confirm
/// against the overlap check used by the reorder pass / `MemRanges`.
pub type MemRange = (usize, usize);
144
/// A single captured compute dispatch or barrier sentinel.
///
/// Created when the encoder is in capture mode. Replayed later by
/// `ComputeGraph::encode_sequential()`.
#[derive(Clone)]
pub enum CapturedNode {
    /// A compute dispatch to replay.
    Dispatch {
        /// Pipeline state object to bind.
        pipeline: ComputePipelineState,
        /// Kernel argument bindings: (slot_index, binding).
        bindings: Vec<(u64, RecordedBinding)>,
        /// Grid or threadgroup count (interpretation depends on `dispatch_kind`):
        /// a total thread count for [`DispatchKind::Threads`], a threadgroup
        /// count for [`DispatchKind::ThreadGroups`].
        threads_per_grid: MTLSize,
        /// Threads per threadgroup.
        threads_per_threadgroup: MTLSize,
        /// Optional threadgroup memory allocations: (index, byte_length).
        threadgroup_memory: Vec<(u64, u64)>,
        /// Whether this is a dispatch_threads or dispatch_thread_groups call.
        dispatch_kind: DispatchKind,
        /// Operation kind tag for the fusion pass (4e.2).
        /// Defaults to `Other` if not explicitly set via `set_op_kind()`.
        op_kind: CapturedOpKind,
        /// Read buffer ranges for reorder conflict detection (4e.3).
        /// Populated in capture mode from `barrier_between` calls and,
        /// presumably, from the `dispatch_tracked*` family (which converts
        /// buffers via `ranges_from_buffers`).
        reads: Vec<MemRange>,
        /// Write buffer ranges for reorder conflict detection (4e.3).
        /// Populated the same way as `reads`.
        writes: Vec<MemRange>,
    },
    /// A memory barrier sentinel — forces a barrier at replay time.
    Barrier,
}
178
179/// Convert a slice of buffer references into capture-mode
180/// [`MemRange`] tuples. Used by the [`CommandEncoder::dispatch_tracked*`]
181/// family in capture mode — equivalent to the conversion
182/// `GraphSession::barrier_between` does at `graph.rs:1452-1465`.
183///
184/// `(start, end)` uses `contents_ptr() + byte_offset` as the start
185/// and `contents_ptr() + byte_offset + slice_extent` as the end.
186fn ranges_from_buffers(bufs: &[&MlxBuffer]) -> Vec<MemRange> {
187 bufs.iter()
188 .map(|b| {
189 let base = b.contents_ptr() as usize + b.byte_offset() as usize;
190 let extent = (b.byte_len()).saturating_sub(b.byte_offset() as usize);
191 (base, base + extent)
192 })
193 .collect()
194}
195
196/// Apply a slice of `KernelArg` bindings to a compute encoder.
197///
198/// `KernelArg::Buffer(buf)` propagates the `MlxBuffer::byte_offset()` so
199/// `slice_view`-derived sub-buffers are honored automatically — the
200/// kernel sees memory starting at the slice's offset. This matches the
201/// documented contract of `slice_view` and the offset-handling in the
202/// other binding paths in this file (`encode`, `encode_threadgroups`,
203/// `encode_threadgroups_with_shared`, replay). Without it, every
204/// `slice_view`-derived buffer bound via `KernelArg::Buffer` silently
205/// exposes the entire underlying allocation — surfaced by hf2q's
206/// nomic-bert iter-79 cosine parity bisection (cosine 0.098 → 0.999962
207/// after fix).
208///
209/// `KernelArg::BufferWithOffset(buf, offset)` continues to use the
210/// explicit `offset` argument verbatim (callers asking for an explicit
211/// offset get exactly that, even on sliced buffers). The two API
212/// surfaces are intentional: implicit (sliced views auto-propagate) vs.
213/// explicit (caller-controlled).
214#[inline]
215fn apply_bindings(encoder: &ComputeCommandEncoderRef, bindings: &[(u64, KernelArg<'_>)]) {
216 for &(index, ref arg) in bindings {
217 match arg {
218 KernelArg::Buffer(buf) => {
219 encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
220 }
221 KernelArg::BufferWithOffset(buf, offset) => {
222 encoder.set_buffer(index, Some(buf.metal_buffer()), *offset);
223 }
224 KernelArg::Bytes(bytes) => {
225 encoder.set_bytes(index, bytes.len() as u64, bytes.as_ptr() as *const _);
226 }
227 }
228 }
229}
230
/// Number of times `commit_and_wait()` has been called (CPU sync points).
/// Read via [`sync_count`].
static SYNC_COUNT: AtomicU64 = AtomicU64::new(0);

/// Number of times an encode method has been called (GPU dispatches).
/// Read via [`dispatch_count`].
static DISPATCH_COUNT: AtomicU64 = AtomicU64::new(0);

/// Number of `MTLCommandBuffer` instances created via `CommandEncoder::new`.
/// Increments once per `device.command_encoder()` call. Used by hf2q's
/// `HF2Q_DECODE_PROFILE` instrumentation to measure command-buffer
/// overhead per decode token (ADR-012 §Optimize / Task #15 follow-up).
/// Read via [`cmd_buf_count`].
static CMD_BUF_COUNT: AtomicU64 = AtomicU64::new(0);

/// Number of `memory_barrier()` calls that reached the
/// `objc::msg_send![encoder, memoryBarrierWithScope:]` site. Capture-mode
/// no-ops and pre-encoder no-ops are excluded so the count reflects
/// actual MTL barriers issued.
///
/// Always tracked — the increment is one atomic op, ~5 ns. ADR-015 H4
/// (Wave 2b hard gate #2) requires per-barrier counter resolution to
/// confirm-or-falsify the barrier-coalescing lever; xctrace TimeProfiler
/// at 1 ms sampling cannot resolve `memory_barrier` even though it fires
/// ~440×/token (`docs/ADR-015-mlx-native-single-cb-decode.md` §"P3a' live
/// profile pass" hypothesis register row H4). Read via [`barrier_count`].
static BARRIER_COUNT: AtomicU64 = AtomicU64::new(0);

/// Total nanoseconds spent inside the `objc::msg_send!` barrier site,
/// summed across all calls. ONLY updated when the env var
/// `MLX_PROFILE_BARRIERS=1` is set on the process (cached on first
/// `memory_barrier` call). When disabled the timing path is a single
/// branch + the unconditional barrier dispatch — same hot-path cost as
/// before this counter was added.
///
/// Why env-gated: timing adds 2 × `Instant::now()` (~50–100 ns each via
/// `mach_absolute_time`) per barrier. At ~440 barriers/token that is
/// ~22–44 µs/token of measurement overhead — comparable to what we are
/// trying to measure. Production must keep this off; profiling runs
/// opt-in. Read via [`barrier_total_ns`].
static BARRIER_NS: AtomicU64 = AtomicU64::new(0);
269
270/// Reset all counters to zero.
271pub fn reset_counters() {
272 SYNC_COUNT.store(0, Ordering::Relaxed);
273 DISPATCH_COUNT.store(0, Ordering::Relaxed);
274 CMD_BUF_COUNT.store(0, Ordering::Relaxed);
275 BARRIER_COUNT.store(0, Ordering::Relaxed);
276 BARRIER_NS.store(0, Ordering::Relaxed);
277 AUTO_BARRIER_COUNT.store(0, Ordering::Relaxed);
278 AUTO_BARRIER_CONCURRENT.store(0, Ordering::Relaxed);
279}
280
281/// Read the current value of `SYNC_COUNT`.
282///
283/// Each call to `commit_and_wait()` increments this counter.
284pub fn sync_count() -> u64 {
285 SYNC_COUNT.load(Ordering::Relaxed)
286}
287
288/// Read the current value of `DISPATCH_COUNT`.
289///
290/// Each call to `encode()`, `encode_threadgroups()`, or
291/// `encode_threadgroups_with_shared()` increments this counter.
292pub fn dispatch_count() -> u64 {
293 DISPATCH_COUNT.load(Ordering::Relaxed)
294}
295
296/// Read the current value of `CMD_BUF_COUNT`.
297///
298/// Each `CommandEncoder::new` (i.e. each `MlxDevice::command_encoder()`)
299/// increments this counter. Useful for diagnosing per-dispatch Metal
300/// command-buffer overhead in inner loops.
301pub fn cmd_buf_count() -> u64 {
302 CMD_BUF_COUNT.load(Ordering::Relaxed)
303}
304
305/// Read the current value of `BARRIER_COUNT`.
306///
307/// Each `memory_barrier()` call that reaches the underlying
308/// `objc::msg_send![encoder, memoryBarrierWithScope:]` site increments this
309/// counter. Capture-mode no-ops and pre-encoder no-ops are excluded.
310/// ADR-015 H4 hypothesis: ~440 barriers/token on the qwen35 decode hot
311/// path (verify against this counter).
312pub fn barrier_count() -> u64 {
313 BARRIER_COUNT.load(Ordering::Relaxed)
314}
315
316/// Read the total nanoseconds spent in the `memoryBarrierWithScope:`
317/// `objc::msg_send!` site. Only non-zero when `MLX_PROFILE_BARRIERS=1`
318/// was in the environment at the time of the first `memory_barrier()`
319/// call (the env check is cached on first use).
320///
321/// Combined with [`barrier_count`] this gives µs/barrier =
322/// `barrier_total_ns() / 1000 / barrier_count()`.
323pub fn barrier_total_ns() -> u64 {
324 BARRIER_NS.load(Ordering::Relaxed)
325}
326
/// Returns whether barrier timing is on, i.e. whether `MLX_PROFILE_BARRIERS`
/// is exactly `"1"`.
///
/// The environment is consulted only once. The answer is memoized in a
/// `OnceLock`, so every later call is a read of the cached flag rather than
/// a `std::env::var` lookup. Any other value counts as disabled, including
/// an unset or non-UTF-8 variable.
fn barrier_profile_enabled() -> bool {
    use std::sync::OnceLock;
    static ENABLED: OnceLock<bool> = OnceLock::new();
    let cached = ENABLED.get_or_init(|| {
        matches!(std::env::var("MLX_PROFILE_BARRIERS").as_deref(), Ok("1"))
    });
    *cached
}
341
/// Reports whether `MLX_UNRETAINED_REFS` is exactly `"1"` in the process
/// environment. The variable is checked once, then memoized in a `OnceLock`.
///
/// ADR-015 iter13: when this returns `true`,
/// `CommandEncoder::new_with_residency` opens each `MTLCommandBuffer` with
/// `CommandQueueRef::new_command_buffer_with_unretained_references` instead
/// of the default `commandBuffer`.
///
/// llama.cpp does the same for its per-token decode CBs: see
/// `[queue commandBufferWithUnretainedReferences]` at
/// `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m:512`. Skipping
/// per-binding ARC retains on submit gains roughly 3-5% wall time on
/// M-series GPUs.
///
/// **Prerequisite for callers:** every Metal buffer bound to a dispatch must
/// outlive the CB. `CommandEncoder::new_with_residency` documents the full
/// contract. In hf2q, the per-decode-token `MlxBufferPool` (`buffer_pool.rs`)
/// already holds ARC clones in its `in_use` list for the whole decode token.
/// Routing transient scratch through that pool is the canonical way to meet
/// the requirement.
///
/// Off by default: any run that does not set the variable keeps the
/// retained-references behavior unchanged. An unset or non-UTF-8 variable
/// reads as `false`.
fn unretained_refs_enabled() -> bool {
    static FLAG: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    *FLAG.get_or_init(|| match std::env::var("MLX_UNRETAINED_REFS") {
        Ok(value) => value == "1",
        Err(_) => false,
    })
}
374
/// Reports whether `HF2Q_AUTO_BARRIER` is exactly `"1"` in the process
/// environment. The lookup happens once and is cached in a `OnceLock`.
///
/// ADR-015 iter37: when enabled, every [`CommandEncoder::dispatch_tracked`]
/// call checks a [`MemRanges`](crate::mem_ranges::MemRanges) tracker. It
/// emits `memoryBarrierWithScope:` only when the new dispatch's read/write
/// ranges conflict with previously recorded ones. This is the same policy as
/// llama.cpp's `ggml_metal_op_concurrency_check`
/// (`/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp:147-225`).
///
/// When disabled, `dispatch_tracked` follows the same path as `encode*`,
/// with no tracking and no automatic barriers. This preserves sourdough
/// behavior for callers that use the tracked API without opting into the
/// gate.
///
/// Off by default: production decode/prefill keeps its hand-placed
/// `enc.memory_barrier()` calls until the iter38+ migration.
fn auto_barrier_enabled() -> bool {
    use std::sync::OnceLock;
    static GATE: OnceLock<bool> = OnceLock::new();
    let enabled = GATE.get_or_init(|| {
        std::env::var("HF2Q_AUTO_BARRIER").is_ok_and(|v| v == "1")
    });
    *enabled
}
400
/// Number of `memory_barrier()` calls auto-emitted by
/// [`CommandEncoder::dispatch_tracked`] under `HF2Q_AUTO_BARRIER=1`.
///
/// This is a *subset* of [`BARRIER_COUNT`], not disjoint from it:
/// auto-barriers go through `memory_barrier()` and therefore also bump
/// `BARRIER_COUNT`. This counter isolates the auto-emitted share.
static AUTO_BARRIER_COUNT: AtomicU64 = AtomicU64::new(0);

/// Number of `dispatch_tracked` calls whose mem-ranges check returned
/// "concurrent" (no barrier needed).
///
/// Together with [`AUTO_BARRIER_COUNT`] this measures the elision rate of
/// the dataflow barrier. `concurrent / (concurrent + barriers)` is the
/// fraction of dispatches that ran inside the previous concurrent group
/// rather than starting a new one.
static AUTO_BARRIER_CONCURRENT: AtomicU64 = AtomicU64::new(0);
416
417// ---------------------------------------------------------------------------
418// ADR-015 iter63 — per-dispatch GPU sampling support
419// ---------------------------------------------------------------------------
420
/// Hard cap on per-CB sample-buffer sample count (Risk R4 in
/// PROFILING-KIT-DESIGN §A.7).
///
/// Empirically verified on Apple Silicon (M-series, macOS 26): the
/// underlying `MTLCounterSampleBufferDescriptor.sampleCount` is bounded
/// by a per-buffer **byte-size** limit of 32768 B. At 8 bytes per
/// `MTLCounterResultTimestamp` sample that maps to a sample-count
/// ceiling of `32_768 / 8 = 4096`.
///
/// We allocate two samples per dispatch (start + end), so this ceiling
/// allows 2048 dispatches per CB. Decode CBs (~120 dispatches) fit
/// comfortably. Long prefill CBs (~6K dispatches per design §A.7) will
/// truncate after 2048; see [`CommandEncoder::sample_dispatch_pre`] for the
/// truncation path. A future iteration can chunk-resolve every 2K
/// dispatches.
///
/// The original design constant of 32_768 (PROFILING-KIT-DESIGN §A.7)
/// was based on Apple's documented ~64K-per-buffer "practical" limit,
/// but the measured constraint on this hardware is the 32 KB byte
/// budget. Setting the budget below that would underutilize the
/// buffer. Setting it above causes `newCounterSampleBufferWithDescriptor`
/// to fail with `Invalid sample buffer length: <bytes> B. Expected range:
/// 8 -> 32768`.
const MAX_SAMPLES_PER_CB: u64 = 4096;
443
/// Whether the per-CB warning about a missing `MTLCommonCounterSetTimestamp`
/// has been emitted yet.
///
/// Risk R1: if `device.counter_sets()` does not return a set named
/// `"timestamp"` (case-insensitive), we degrade the per-dispatch path to a
/// no-op and log once via stderr.
///
/// Used as a flag rather than a count. NOTE(review): presumably `0` means
/// "not yet logged" and any non-zero value means "logged"; confirm at the
/// use site. An `AtomicBool` would state that intent more directly.
static TIMESTAMP_SET_WARN_LOGGED: AtomicU64 = AtomicU64::new(0);
449
/// Pending per-dispatch metadata that pairs with sample indices `2i`
/// (start) and `2i+1` (end) inside the CB's `MTLCounterSampleBuffer`.
///
/// Resolved by `CommandEncoder::resolve_dispatch_samples` at CB commit time
/// and converted to [`crate::kernel_profile::DispatchEntry`] before being
/// pushed to the global table.
#[derive(Clone, Debug)]
struct PendingDispatchMeta {
    /// Op-kind label for the dispatch. Presumably produced by
    /// `CapturedOpKind::name()`, the `&'static str` op-kind source visible
    /// here. Confirm at the `encode*` sites.
    op_kind: &'static str,
    /// Ordinal of the dispatch within its command buffer. Presumably the
    /// encoder's `dispatch_in_cb` value at encode time; confirm.
    dispatch_index: u32,
}
460
461/// Read the cumulative number of auto-emitted barriers across all
462/// encoders since process start (or last [`reset_counters`]).
463pub fn auto_barrier_count() -> u64 {
464 AUTO_BARRIER_COUNT.load(Ordering::Relaxed)
465}
466
467/// Read the cumulative number of `dispatch_tracked` calls that did NOT
468/// emit a barrier (ran concurrent with the previous group).
469pub fn auto_barrier_concurrent_count() -> u64 {
470 AUTO_BARRIER_CONCURRENT.load(Ordering::Relaxed)
471}
472
/// Issue the underlying Metal `memoryBarrierWithScope:` ObjC msg_send with
/// buffer scope.
///
/// Held in its own `#[inline(never)]` function so xctrace / Instruments
/// has a stable Rust frame to attribute barrier time against, separate
/// from the surrounding encoder accounting. Per ADR-015 §P3a' Codex
/// review Q2: TimeProfiler at 1 ms sampling cannot see this site when
/// inlined; an explicit non-inline frame plus the [`BARRIER_NS`] counter
/// closes the H4 hard gate.
///
/// The caller must pass an encoder on which `end_encoding()` has not yet
/// been called.
#[inline(never)]
fn issue_metal_buffer_barrier(encoder: &ComputeCommandEncoderRef) {
    // MTLBarrierScopeBuffers = 1 << 0 = 1.
    const MTL_BARRIER_SCOPE_BUFFERS: u64 = 1;
    // SAFETY: `encoder` is a borrowed compute-encoder reference, valid for
    // the duration of this call. Per Apple's MTLComputeCommandEncoder docs,
    // `memoryBarrierWithScope:` takes a single `MTLBarrierScope`
    // (NSUInteger, i.e. 64-bit here) and returns void. That matches the
    // `u64` argument and the `()` result type below.
    unsafe {
        let _: () =
            objc::msg_send![encoder, memoryBarrierWithScope: MTL_BARRIER_SCOPE_BUFFERS];
    }
}
490
/// A batched compute command encoder.
///
/// Keeps a single Metal `ComputeCommandEncoder` alive across multiple
/// dispatches. The encoder is created on the first dispatch and ended
/// only when the command buffer is committed. This mirrors candle's
/// `compute_per_buffer` pattern and avoids per-dispatch encoder overhead.
///
/// Besides direct encoding, the struct carries state for several features:
/// * capture mode (`capture`, `pending_*`)
/// * residency-set flushing (`residency_set`)
/// * dataflow auto-barriers (`mem_ranges`)
/// * per-dispatch GPU timestamp sampling (`sample_buffer`,
///   `pending_dispatch_meta`, `dispatch_in_cb`, `last_label`)
///
/// # Typical usage
///
/// ```ignore
/// let mut enc = device.command_encoder()?;
/// // Multiple dispatches share the same compute encoder:
/// enc.encode_threadgroups(pipeline1, &buffers1, tg1, tg_size1);
/// enc.encode_threadgroups(pipeline2, &buffers2, tg2, tg_size2);
/// enc.commit_and_wait()?;
/// ```
pub struct CommandEncoder {
    /// The Metal command buffer that this batch's dispatches are encoded
    /// into. It is submitted by the `commit*` family.
    cmd_buf: CommandBuffer,
    /// Owned clone of the originating command queue.
    ///
    /// ADR-019 Phase 0b iter89e2-A: stored at `new_with_residency` time so
    /// downstream lifecycle code (e.g. `EncoderSession::reset_for_next_stage`
    /// in Phase 0b-B) can open a fresh `CommandBuffer` from the same queue
    /// after a non-blocking `commit_stage()`.
    ///
    /// metal-rs 0.33's `CommandQueue` type is `Send + Sync` via
    /// `foreign_obj_type!` (metal-rs 0.33.0 `src/lib.rs:179`). Adding this
    /// field therefore preserves the existing unsafe `Send` impl on
    /// `CommandEncoder` (declared below).
    ///
    /// ADR-019 Phase 0b iter89e2-B (CONSUMED): read by
    /// [`Self::reset_command_buffer`] to spawn a fresh `CommandBuffer`
    /// after a non-blocking `commit*`. This lets
    /// `EncoderSession::reset_for_next_stage` chain stage CBs without
    /// re-constructing the encoder.
    ///
    /// Holding a clone here rather than a `&CommandQueue` borrow avoids a
    /// lifetime parameter on `CommandEncoder` that would propagate through
    /// every consumer in mlx-native and hf2q.
    queue: CommandQueue,
    // SAFETY marker: see unsafe Send impl below.
    /// Raw pointer to the persistent compute encoder.
    /// Non-null when a compute pass is active.
    /// The encoder borrows from `cmd_buf` but we cannot express this
    /// lifetime in safe Rust, so we use a raw pointer.
    /// SAFETY: the pointer is valid as long as `cmd_buf` is alive and
    /// `end_encoding()` has not been called on it.
    active_encoder: *const ComputeCommandEncoderRef,
    /// When `Some`, dispatches are recorded here instead of being encoded
    /// into Metal. Set via `start_capture()`, extracted via `take_capture()`.
    capture: Option<Vec<CapturedNode>>,
    /// Op kind tag for the NEXT dispatch. Set via `set_op_kind()`,
    /// consumed (reset to `Other`) when a dispatch is taken.
    pending_op_kind: CapturedOpKind,
    /// Pending read buffer ranges for the NEXT captured dispatch.
    /// Set via `set_pending_buffer_ranges()`, consumed when the next dispatch
    /// is captured. Used by the reorder pass (Phase 4e.3).
    pending_reads: Vec<MemRange>,
    /// Pending write buffer ranges for the NEXT captured dispatch.
    /// Counterpart of `pending_reads`; set and consumed the same way.
    pending_writes: Vec<MemRange>,
    /// ADR-015 iter8e (Phase 3b): residency set whose pending add/remove
    /// staging is flushed at every `commit*` boundary.
    ///
    /// Cloned from the device at `device.command_encoder()` time. `None`
    /// when residency sets are disabled (HF2Q_NO_RESIDENCY=1, macOS<15,
    /// or test-only `CommandEncoder::new` from a residency-less queue).
    residency_set: Option<ResidencySet>,
    /// ADR-015 iter37: dataflow barrier inference state.
    ///
    /// Populated only when `HF2Q_AUTO_BARRIER=1` is set at process
    /// start (cached via [`auto_barrier_enabled`]). Each
    /// [`Self::dispatch_tracked`] call consults this state to decide
    /// whether a Metal memory barrier is required. On conflict the
    /// barrier is emitted, the state is reset, and the new dispatch's
    /// ranges seed the next concurrent group. When the env gate is
    /// off, `dispatch_tracked` collapses to its untracked equivalent
    /// and this field is left empty for the encoder's lifetime.
    ///
    /// The field is a plain `MemRanges` rather than an `Option`, so the
    /// gate-off branch is a single bool-load + early return with no Option
    /// indirection. NOTE(review): this presumes `MemRanges::new()` does not
    /// allocate, so an unused tracker costs no heap memory; confirm in
    /// `mem_ranges.rs`.
    mem_ranges: MemRanges,
    /// ADR-015 iter63 (per-dispatch profiling): the sample buffer for
    /// `MTLCounterSampleBuffer.sampleCounters` calls that bracket every
    /// `encode*` dispatch in this CB.
    ///
    /// Lazily allocated on the first dispatch when `MLX_PROFILE_DISPATCH=1`;
    /// `None` otherwise. Released (set to `None`) inside
    /// `resolve_dispatch_samples` after the CB completes, and re-allocated on
    /// the next `encode*` if the env gate stays set.
    sample_buffer: Option<CounterSampleBuffer>,
    /// ADR-015 iter63: pending per-dispatch metadata that pairs with
    /// sample indices `2*i` and `2*i+1` inside `sample_buffer`. Each
    /// `encode*` call appends one entry (when sampling is active);
    /// `resolve_dispatch_samples` drains the vec at commit time.
    pending_dispatch_meta: Vec<PendingDispatchMeta>,
    /// ADR-015 iter63: 0-based dispatch ordinal within the current CB.
    /// Incremented in every `encode*` site after taking the pending
    /// op_kind; reset to 0 inside `resolve_dispatch_samples`.
    dispatch_in_cb: u32,
    /// ADR-015 iter63: most recent label set via `apply_labels`, used
    /// as the per-dispatch `cb_label` field. `String::new()` until
    /// `commit_and_wait_labeled` / `commit_labeled` is called.
    last_label: String,
}
592
/// SAFETY: CommandEncoder is safe to Send across threads provided that:
/// 1. Only one thread accesses the encoder at a time (exclusive ownership).
/// 2. The encoder is not used concurrently from multiple threads.
///
/// Metal command buffers and compute encoders are thread-safe for exclusive
/// access (Apple documentation: "You can create command buffers, encode
/// commands, and submit them from any thread").
///
/// The raw pointer `active_encoder` borrows from `cmd_buf` and is valid as
/// long as `cmd_buf` is alive. This invariant holds across thread boundaries
/// because both fields move together.
///
/// Only `Send` is asserted here. The raw-pointer field makes the type
/// `!Sync` by default, so `&CommandEncoder` cannot be shared between threads
/// unless `Sync` is implemented elsewhere. That is what enforces conditions
/// 1 and 2 at compile time. NOTE(review): confirm no `unsafe impl Sync`
/// exists for this type.
///
/// This matches llama.cpp's pattern of encoding command buffers on GCD
/// worker threads via `dispatch_apply`. It is used for the dual-buffer
/// pipeline, where buf1 is encoded on a worker thread while buf0 executes.
unsafe impl Send for CommandEncoder {}
608
609impl CommandEncoder {
610 /// Create a new command encoder from the given command queue.
611 ///
612 /// This immediately creates a Metal command buffer.
613 ///
614 /// # Why retained references
615 ///
616 /// We use the regular `commandBuffer` (Metal retains every bound
617 /// resource for the lifetime of the buffer) rather than
618 /// `commandBufferWithUnretainedReferences`. llama.cpp uses unretained
619 /// refs for an additional perf bump (~3-5% on M-series GPUs), but the
620 /// hf2q dispatch pattern allocates many transient scratch buffers
621 /// inside helper functions (`apply_proj` → `weight_bf16_owned`,
622 /// `apply_pre_norm` → `params`, etc.) that go out of scope at the
623 /// helper's return. With unretained refs the metal::Buffer's ARC
624 /// drops to zero, freeing the underlying GPU memory before the
625 /// dispatch executes. Verified 2026-04-26: switching to unretained
626 /// hits "Command buffer error: GPU command buffer completed with
627 /// error status" on the first MoE FFN dispatch.
628 ///
629 /// To enable unretained refs in the future, every helper that
630 /// allocates and dispatches must thread its scratch buffers up to a
631 /// caller scope that outlives the eventual commit, OR all such
632 /// scratch must come from the per-decode-token pool (which already
633 /// ARC-retains in its in_use list). Today the lm_head + router-
634 /// download paths are still unpooled.
635 #[allow(dead_code)]
636 pub(crate) fn new(queue: &CommandQueue) -> Result<Self> {
637 Self::new_with_residency(queue, None)
638 }
639
640 /// Create a new command encoder, optionally bound to a residency set so
641 /// `commit*` boundaries can flush deferred add/remove staging.
642 ///
643 /// ADR-015 iter8e (Phase 3b): the encoder's `commit_and_wait`,
644 /// `commit_and_wait_labeled`, `commit`, `commit_labeled`,
645 /// `commit_wait_with_gpu_time` all call
646 /// [`ResidencySet::flush_pending`](ResidencySet::flush_pending) before
647 /// submitting the Metal command buffer. This converts the
648 /// per-allocation `[set commit]` storm
649 /// (~880 commits/decode-token in iter8d/8e claude+codex variants) into
650 /// at most one commit per CB submission — mirrors llama.cpp's
651 /// `ggml-metal-device.m:1378-1382` pattern (batch addAllocation in
652 /// loop, commit ONCE).
653 ///
654 /// ADR-015 iter13: when the `MLX_UNRETAINED_REFS=1` env var is set at
655 /// process start, this constructor uses
656 /// [`CommandQueueRef::new_command_buffer_with_unretained_references`]
657 /// instead of `new_command_buffer`. llama.cpp's per-token decode CBs
658 /// use `commandBufferWithUnretainedReferences` (see
659 /// `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m:512`) which
660 /// skips Metal's per-buffer-binding ARC-retain on submit and saves
661 /// ~3-5% on M-series GPUs (per the docstring above).
662 ///
663 /// **Caller contract under unretained refs.** Every Metal buffer bound
664 /// to a dispatch in this CB MUST outlive the CB's GPU completion. In
665 /// the hf2q decode path, that means every transient scratch must be
666 /// either (a) backed by the per-decode-token arena pool
667 /// (`MlxBufferPool` keeps an ARC clone in `in_use` until the next
668 /// `reset` — see `buffer_pool.rs:60`) or (b) hoisted to a caller scope
669 /// that lives across the terminal `commit_and_wait_labeled`. Helpers
670 /// in `apply_proj` / `apply_pre_norm` / lm_head cast / router-download
671 /// that allocated transients via `device.alloc_buffer` and dropped
672 /// them at function return MUST be lifted to `pooled_alloc_buffer`
673 /// before `MLX_UNRETAINED_REFS=1` is enabled, or the first MoE FFN
674 /// dispatch will crash with "Command buffer error: GPU command buffer
675 /// completed with error status" (verified 2026-04-26).
676 ///
677 /// The default (`MLX_UNRETAINED_REFS` unset) preserves retained-refs
678 /// behavior verbatim — this is the sourdough-safe path.
679 pub(crate) fn new_with_residency(
680 queue: &CommandQueue,
681 residency_set: Option<ResidencySet>,
682 ) -> Result<Self> {
683 let cmd_buf = if unretained_refs_enabled() {
684 queue.new_command_buffer_with_unretained_references().to_owned()
685 } else {
686 queue.new_command_buffer().to_owned()
687 };
688 CMD_BUF_COUNT.fetch_add(1, Ordering::Relaxed);
689 Ok(Self {
690 cmd_buf,
691 queue: queue.to_owned(),
692 active_encoder: std::ptr::null(),
693 capture: None,
694 pending_op_kind: CapturedOpKind::Other,
695 pending_reads: Vec::new(),
696 pending_writes: Vec::new(),
697 residency_set,
698 mem_ranges: MemRanges::new(),
699 sample_buffer: None,
700 pending_dispatch_meta: Vec::new(),
701 dispatch_in_cb: 0,
702 last_label: String::new(),
703 })
704 }
705
706 /// Enable capture mode.
707 ///
708 /// All subsequent dispatch and barrier calls will be recorded into a
709 /// `Vec<CapturedNode>` instead of being encoded into Metal.
710 /// Call `take_capture()` to extract the recorded nodes.
711 pub fn start_capture(&mut self) {
712 self.capture = Some(Vec::with_capacity(128));
713 }
714
715 /// Whether the encoder is currently in capture mode.
716 pub fn is_capturing(&self) -> bool {
717 self.capture.is_some()
718 }
719
720 /// Extract the captured nodes, ending capture mode.
721 ///
722 /// Returns `None` if capture mode was not active.
723 pub fn take_capture(&mut self) -> Option<Vec<CapturedNode>> {
724 self.capture.take()
725 }
726
727 /// Tag the NEXT captured dispatch with the given operation kind.
728 ///
729 /// The tag is consumed (reset to `Other`) after the next dispatch is
730 /// captured. Only meaningful in capture mode — has no effect on
731 /// direct-dispatch encoding.
732 ///
733 /// Used by op dispatch functions to annotate captures for the fusion
734 /// pass (Phase 4e.2).
735 pub fn set_op_kind(&mut self, kind: CapturedOpKind) {
736 self.pending_op_kind = kind;
737 }
738
739 /// Consume and return the pending op kind, resetting it to `Other`.
740 fn take_pending_op_kind(&mut self) -> CapturedOpKind {
741 let kind = self.pending_op_kind;
742 self.pending_op_kind = CapturedOpKind::Other;
743 kind
744 }
745
746 /// Stash buffer range annotations for the NEXT captured dispatch.
747 ///
748 /// Called by `GraphSession::barrier_between()` in capture mode to record
749 /// which buffers the next dispatch reads from and writes to. The ranges
750 /// are consumed by the next `encode_*` call and attached to the captured
751 /// `CapturedNode::Dispatch`.
752 ///
753 /// Only meaningful in capture mode — has no effect on direct-dispatch.
754 pub fn set_pending_buffer_ranges(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
755 self.pending_reads = reads;
756 self.pending_writes = writes;
757 }
758
759 /// Patch the last captured dispatch node's empty reads/writes with the
760 /// given ranges. No-op if not capturing, or if the last node isn't a
761 /// Dispatch, or if its ranges are already populated.
762 ///
763 /// Used by `GraphSession::track_dispatch` in recording mode to annotate
764 /// dispatches that were called without a preceding `barrier_between`.
765 pub fn annotate_last_dispatch_if_missing(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
766 if let Some(ref mut nodes) = self.capture {
767 if let Some(CapturedNode::Dispatch { reads: r, writes: w, .. }) = nodes.last_mut() {
768 if r.is_empty() && !reads.is_empty() {
769 *r = reads;
770 }
771 if w.is_empty() && !writes.is_empty() {
772 *w = writes;
773 }
774 }
775 }
776 }
777
778 /// Consume and return the pending buffer range annotations.
779 fn take_pending_buffer_ranges(&mut self) -> (Vec<MemRange>, Vec<MemRange>) {
780 let reads = std::mem::take(&mut self.pending_reads);
781 let writes = std::mem::take(&mut self.pending_writes);
782 (reads, writes)
783 }
784
785 /// Record buffer bindings into `RecordedBinding` form.
786 fn record_buffer_bindings(buffers: &[(u64, &MlxBuffer)]) -> Vec<(u64, RecordedBinding)> {
787 buffers
788 .iter()
789 .map(|&(index, buf)| {
790 (
791 index,
792 RecordedBinding::Buffer {
793 metal_buffer: buf.metal_buffer().clone(),
794 offset: buf.byte_offset(),
795 },
796 )
797 })
798 .collect()
799 }
800
801 /// Record `KernelArg` bindings into `RecordedBinding` form.
802 ///
803 /// `KernelArg::Buffer(buf)` records `buf.byte_offset()` so capture →
804 /// replay round-trips of `slice_view`-derived buffers preserve their
805 /// offsets, matching `record_buffer_bindings`'s behavior at line 382.
806 fn record_arg_bindings(bindings: &[(u64, KernelArg<'_>)]) -> Vec<(u64, RecordedBinding)> {
807 bindings
808 .iter()
809 .map(|(index, arg)| {
810 let recorded = match arg {
811 KernelArg::Buffer(buf) => RecordedBinding::Buffer {
812 metal_buffer: buf.metal_buffer().clone(),
813 offset: buf.byte_offset(),
814 },
815 KernelArg::BufferWithOffset(buf, offset) => RecordedBinding::Buffer {
816 metal_buffer: buf.metal_buffer().clone(),
817 offset: *offset,
818 },
819 KernelArg::Bytes(bytes) => RecordedBinding::Bytes(bytes.to_vec()),
820 };
821 (*index, recorded)
822 })
823 .collect()
824 }
825
826 /// Get or create the persistent compute encoder.
827 ///
828 /// On the first call, creates a new compute encoder from the command
829 /// buffer. On subsequent calls, returns the existing one.
830 ///
831 /// SAFETY: The returned reference borrows from `self.cmd_buf` which is
832 /// alive for the lifetime of this `CommandEncoder`. The raw pointer is
833 /// valid until `end_active_encoder()` is called.
834 #[inline]
835 fn get_or_create_encoder(&mut self) -> &ComputeCommandEncoderRef {
836 if self.active_encoder.is_null() {
837 // Use MTLDispatchTypeConcurrent to allow independent dispatches
838 // to overlap on the GPU. Memory barriers are inserted between
839 // dependent dispatches via `memory_barrier()`.
840 //
841 // ADR-015 iter61a-2 probe: HF2Q_FORCE_SERIAL_DISPATCH=1 falls back
842 // to MTLDispatchType::Serial — every dispatch waits for the
843 // previous to complete, eliminating concurrent-dispatch race
844 // windows. Used to falsify Hypothesis (g): missing memory_barrier
845 // calls between dependent dispatches cause cold-run logit
846 // non-determinism via thread-race on a shared buffer.
847 let dispatch_type = if std::env::var("HF2Q_FORCE_SERIAL_DISPATCH")
848 .map(|v| v == "1")
849 .unwrap_or(false)
850 {
851 MTLDispatchType::Serial
852 } else {
853 MTLDispatchType::Concurrent
854 };
855 let encoder = self
856 .cmd_buf
857 .compute_command_encoder_with_dispatch_type(dispatch_type);
858 self.active_encoder = encoder as *const ComputeCommandEncoderRef;
859 }
860 // SAFETY: active_encoder is non-null and points to a valid encoder
861 // owned by cmd_buf.
862 unsafe { &*self.active_encoder }
863 }
864
865 /// End the active compute encoder if one exists.
866 #[inline]
867 fn end_active_encoder(&mut self) {
868 if !self.active_encoder.is_null() {
869 // SAFETY: the pointer was obtained from cmd_buf.new_compute_command_encoder()
870 // and has not been ended yet.
871 unsafe { &*self.active_encoder }.end_encoding();
872 self.active_encoder = std::ptr::null();
873 }
874 }
875
876 /// Insert a memory barrier with scope `MTLBarrierScopeBuffers`.
877 ///
878 /// When the encoder uses `MTLDispatchTypeConcurrent`, all dispatches can
879 /// execute concurrently unless separated by a barrier. Call this between
880 /// dispatches where the later dispatch reads a buffer written by an
881 /// earlier one.
882 ///
883 /// This is the same pattern llama.cpp uses:
884 /// `[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]`
885 #[allow(unexpected_cfgs)]
886 pub fn memory_barrier(&mut self) {
887 if let Some(ref mut nodes) = self.capture {
888 nodes.push(CapturedNode::Barrier);
889 return;
890 }
891 if self.active_encoder.is_null() {
892 return;
893 }
894 BARRIER_COUNT.fetch_add(1, Ordering::Relaxed);
895 // SAFETY: active_encoder is non-null and valid.
896 let encoder = unsafe { &*self.active_encoder };
897 if barrier_profile_enabled() {
898 // mach_absolute_time path — only on when MLX_PROFILE_BARRIERS=1.
899 let start = std::time::Instant::now();
900 issue_metal_buffer_barrier(encoder);
901 let elapsed_ns = start.elapsed().as_nanos() as u64;
902 BARRIER_NS.fetch_add(elapsed_ns, Ordering::Relaxed);
903 } else {
904 issue_metal_buffer_barrier(encoder);
905 }
906 }
907
908 /// Set the compute pipeline state for subsequent dispatches.
909 ///
910 /// This begins a new compute pass if one is not already active.
911 pub fn set_pipeline(&mut self, pipeline: &ComputePipelineStateRef) {
912 let encoder = self.get_or_create_encoder();
913 encoder.set_compute_pipeline_state(pipeline);
914 }
915
916 /// Bind a buffer to a compute kernel argument slot.
917 ///
918 /// The `index` corresponds to the `[[buffer(N)]]` attribute in the MSL shader.
919 pub fn set_buffer(&self, index: u64, buffer: &MlxBuffer) {
920 let _ = (index, buffer);
921 }
922
923 /// Dispatch threads on the GPU.
924 pub fn dispatch_threads(&self, grid_size: MTLSize, threadgroup_size: MTLSize) {
925 let _ = (grid_size, threadgroup_size);
926 }
927
928 /// Encode a complete compute pass: set pipeline, bind buffers, dispatch.
929 ///
930 /// Reuses the persistent compute encoder — no per-dispatch encoder
931 /// creation overhead.
932 ///
933 /// # Arguments
934 ///
935 /// * `pipeline` — The compiled compute pipeline to execute.
936 /// * `buffers` — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
937 /// * `grid_size` — Total number of threads to launch.
938 /// * `threadgroup_size` — Threads per threadgroup.
939 pub fn encode(
940 &mut self,
941 pipeline: &ComputePipelineStateRef,
942 buffers: &[(u64, &MlxBuffer)],
943 grid_size: MTLSize,
944 threadgroup_size: MTLSize,
945 ) {
946 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
947 let op_kind = self.take_pending_op_kind();
948 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
949 if let Some(ref mut nodes) = self.capture {
950 nodes.push(CapturedNode::Dispatch {
951 pipeline: pipeline.to_owned(),
952 bindings: Self::record_buffer_bindings(buffers),
953 threads_per_grid: grid_size,
954 threads_per_threadgroup: threadgroup_size,
955 threadgroup_memory: Vec::new(),
956 dispatch_kind: DispatchKind::Threads,
957 op_kind,
958 reads: pending_reads,
959 writes: pending_writes,
960 });
961 return;
962 }
963 // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
964 self.ensure_sample_buffer();
965 let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
966 // SAFETY: encoder_ptr aliases &self via active_encoder which we
967 // know is non-null after get_or_create_encoder; this pattern is
968 // used throughout the file (see memory_barrier).
969 let encoder = unsafe { &*encoder_ptr };
970 encoder.set_compute_pipeline_state(pipeline);
971 for &(index, buf) in buffers {
972 encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
973 }
974 let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
975 encoder.dispatch_threads(grid_size, threadgroup_size);
976 self.sample_dispatch_post(encoder, pre_idx);
977 }
978
979 /// Encode a compute pass using threadgroups instead of raw thread counts.
980 ///
981 /// Reuses the persistent compute encoder — no per-dispatch encoder
982 /// creation overhead.
983 pub fn encode_threadgroups(
984 &mut self,
985 pipeline: &ComputePipelineStateRef,
986 buffers: &[(u64, &MlxBuffer)],
987 threadgroups: MTLSize,
988 threadgroup_size: MTLSize,
989 ) {
990 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
991 let op_kind = self.take_pending_op_kind();
992 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
993 if let Some(ref mut nodes) = self.capture {
994 nodes.push(CapturedNode::Dispatch {
995 pipeline: pipeline.to_owned(),
996 bindings: Self::record_buffer_bindings(buffers),
997 threads_per_grid: threadgroups,
998 threads_per_threadgroup: threadgroup_size,
999 threadgroup_memory: Vec::new(),
1000 dispatch_kind: DispatchKind::ThreadGroups,
1001 op_kind,
1002 reads: pending_reads,
1003 writes: pending_writes,
1004 });
1005 return;
1006 }
1007 // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1008 self.ensure_sample_buffer();
1009 let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1010 // SAFETY: see encode() above.
1011 let encoder = unsafe { &*encoder_ptr };
1012 encoder.set_compute_pipeline_state(pipeline);
1013 for &(index, buf) in buffers {
1014 encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
1015 }
1016 let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1017 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
1018 self.sample_dispatch_post(encoder, pre_idx);
1019 }
1020
1021 /// Encode a compute pass using threadgroups with shared threadgroup memory.
1022 ///
1023 /// Like [`encode_threadgroups`](Self::encode_threadgroups), but additionally
1024 /// allocates threadgroup memory at the specified indices. This is required
1025 /// for kernels that use `threadgroup` memory (e.g. reductions in rms_norm
1026 /// and softmax).
1027 ///
1028 /// # Arguments
1029 ///
1030 /// * `pipeline` — The compiled compute pipeline to execute.
1031 /// * `buffers` — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
1032 /// * `threadgroup_mem` — Slice of `(index, byte_length)` pairs for threadgroup memory.
1033 /// * `threadgroups` — Number of threadgroups to dispatch.
1034 /// * `threadgroup_size` — Threads per threadgroup.
1035 pub fn encode_threadgroups_with_shared(
1036 &mut self,
1037 pipeline: &ComputePipelineStateRef,
1038 buffers: &[(u64, &MlxBuffer)],
1039 threadgroup_mem: &[(u64, u64)],
1040 threadgroups: MTLSize,
1041 threadgroup_size: MTLSize,
1042 ) {
1043 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1044 let op_kind = self.take_pending_op_kind();
1045 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1046 if let Some(ref mut nodes) = self.capture {
1047 nodes.push(CapturedNode::Dispatch {
1048 pipeline: pipeline.to_owned(),
1049 bindings: Self::record_buffer_bindings(buffers),
1050 threads_per_grid: threadgroups,
1051 threads_per_threadgroup: threadgroup_size,
1052 threadgroup_memory: threadgroup_mem.to_vec(),
1053 dispatch_kind: DispatchKind::ThreadGroups,
1054 op_kind,
1055 reads: pending_reads,
1056 writes: pending_writes,
1057 });
1058 return;
1059 }
1060 // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1061 self.ensure_sample_buffer();
1062 let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1063 // SAFETY: see encode() above.
1064 let encoder = unsafe { &*encoder_ptr };
1065 encoder.set_compute_pipeline_state(pipeline);
1066 for &(index, buf) in buffers {
1067 encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
1068 }
1069 for &(index, byte_length) in threadgroup_mem {
1070 encoder.set_threadgroup_memory_length(index, byte_length);
1071 }
1072 let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1073 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
1074 self.sample_dispatch_post(encoder, pre_idx);
1075 }
1076
1077 /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_threads).
1078 ///
1079 /// Reuses the persistent compute encoder.
1080 pub fn encode_with_args(
1081 &mut self,
1082 pipeline: &ComputePipelineStateRef,
1083 bindings: &[(u64, KernelArg<'_>)],
1084 grid_size: MTLSize,
1085 threadgroup_size: MTLSize,
1086 ) {
1087 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1088 let op_kind = self.take_pending_op_kind();
1089 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1090 if let Some(ref mut nodes) = self.capture {
1091 nodes.push(CapturedNode::Dispatch {
1092 pipeline: pipeline.to_owned(),
1093 bindings: Self::record_arg_bindings(bindings),
1094 threads_per_grid: grid_size,
1095 threads_per_threadgroup: threadgroup_size,
1096 threadgroup_memory: Vec::new(),
1097 dispatch_kind: DispatchKind::Threads,
1098 op_kind,
1099 reads: pending_reads,
1100 writes: pending_writes,
1101 });
1102 return;
1103 }
1104 // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1105 self.ensure_sample_buffer();
1106 let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1107 // SAFETY: see encode() above.
1108 let encoder = unsafe { &*encoder_ptr };
1109 encoder.set_compute_pipeline_state(pipeline);
1110 apply_bindings(encoder, bindings);
1111 let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1112 encoder.dispatch_threads(grid_size, threadgroup_size);
1113 self.sample_dispatch_post(encoder, pre_idx);
1114 }
1115
1116 /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_thread_groups).
1117 ///
1118 /// Reuses the persistent compute encoder.
1119 pub fn encode_threadgroups_with_args(
1120 &mut self,
1121 pipeline: &ComputePipelineStateRef,
1122 bindings: &[(u64, KernelArg<'_>)],
1123 threadgroups: MTLSize,
1124 threadgroup_size: MTLSize,
1125 ) {
1126 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1127 let op_kind = self.take_pending_op_kind();
1128 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1129 if let Some(ref mut nodes) = self.capture {
1130 nodes.push(CapturedNode::Dispatch {
1131 pipeline: pipeline.to_owned(),
1132 bindings: Self::record_arg_bindings(bindings),
1133 threads_per_grid: threadgroups,
1134 threads_per_threadgroup: threadgroup_size,
1135 threadgroup_memory: Vec::new(),
1136 dispatch_kind: DispatchKind::ThreadGroups,
1137 op_kind,
1138 reads: pending_reads,
1139 writes: pending_writes,
1140 });
1141 return;
1142 }
1143 // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1144 self.ensure_sample_buffer();
1145 let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1146 // SAFETY: see encode() above.
1147 let encoder = unsafe { &*encoder_ptr };
1148 encoder.set_compute_pipeline_state(pipeline);
1149 apply_bindings(encoder, bindings);
1150 let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1151 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
1152 self.sample_dispatch_post(encoder, pre_idx);
1153 }
1154
1155 /// Encode a dispatch with mixed buffer/bytes bindings and shared memory.
1156 ///
1157 /// Reuses the persistent compute encoder.
1158 pub fn encode_threadgroups_with_args_and_shared(
1159 &mut self,
1160 pipeline: &ComputePipelineStateRef,
1161 bindings: &[(u64, KernelArg<'_>)],
1162 threadgroup_mem: &[(u64, u64)],
1163 threadgroups: MTLSize,
1164 threadgroup_size: MTLSize,
1165 ) {
1166 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1167 let op_kind = self.take_pending_op_kind();
1168 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1169 if let Some(ref mut nodes) = self.capture {
1170 nodes.push(CapturedNode::Dispatch {
1171 pipeline: pipeline.to_owned(),
1172 bindings: Self::record_arg_bindings(bindings),
1173 threads_per_grid: threadgroups,
1174 threads_per_threadgroup: threadgroup_size,
1175 threadgroup_memory: threadgroup_mem.to_vec(),
1176 dispatch_kind: DispatchKind::ThreadGroups,
1177 op_kind,
1178 reads: pending_reads,
1179 writes: pending_writes,
1180 });
1181 return;
1182 }
1183 // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1184 self.ensure_sample_buffer();
1185 let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1186 // SAFETY: see encode() above.
1187 let encoder = unsafe { &*encoder_ptr };
1188 encoder.set_compute_pipeline_state(pipeline);
1189 apply_bindings(encoder, bindings);
1190 for &(index, byte_length) in threadgroup_mem {
1191 encoder.set_threadgroup_memory_length(index, byte_length);
1192 }
1193 let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1194 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
1195 self.sample_dispatch_post(encoder, pre_idx);
1196 }
1197
1198 // -----------------------------------------------------------------
1199 // ADR-015 iter37 — dataflow-driven auto-barrier dispatch family.
1200 //
1201 // These mirrors of `encode_threadgroups*_with_args*` take explicit
1202 // `reads: &[&MlxBuffer]` and `writes: &[&MlxBuffer]` slices. When
1203 // the process started with `HF2Q_AUTO_BARRIER=1`, the encoder's
1204 // [`MemRanges`] tracker checks the new ranges against the
1205 // cumulative state since the last barrier; on conflict it emits
1206 // `memory_barrier()` and resets the state before recording the
1207 // new ranges. When the env gate is unset, the check is skipped
1208 // entirely and the dispatch is applied identically to the
1209 // matching `encode_*` method — sourdough-safe by construction.
1210 //
1211 // Capture mode: the `reads`/`writes` ranges are recorded onto the
1212 // captured node via the existing `pending_reads`/`pending_writes`
1213 // mechanism, so a `dispatch_tracked` call inside capture mode is
1214 // equivalent to `set_pending_buffer_ranges + encode_*`.
1215 //
1216 // No production callsite migrates in iter37 — this is the API
1217 // surface the qwen35 forward path will adopt incrementally in
1218 // iter38+. Today, every call to `dispatch_tracked` from a
1219 // production code path lives behind an explicit caller decision
1220 // to opt in.
1221 // -----------------------------------------------------------------
1222
1223 /// Auto-barrier-aware dispatch with [`KernelArg`] bindings (uses
1224 /// `dispatch_thread_groups`).
1225 ///
1226 /// Behaves identically to
1227 /// [`encode_threadgroups_with_args`](Self::encode_threadgroups_with_args)
1228 /// when `HF2Q_AUTO_BARRIER` is unset. When set, consults the
1229 /// per-encoder [`MemRanges`] tracker:
1230 ///
1231 /// * Conflict (RAW/WAR/WAW on a same-buffer range) → emit
1232 /// `memory_barrier()`, increment [`AUTO_BARRIER_COUNT`], reset
1233 /// the tracker, then dispatch and seed the new concurrent group
1234 /// with this dispatch's ranges.
1235 /// * No conflict → increment [`AUTO_BARRIER_CONCURRENT`], record
1236 /// the ranges into the cumulative state, dispatch.
1237 pub fn dispatch_tracked_threadgroups_with_args(
1238 &mut self,
1239 pipeline: &ComputePipelineStateRef,
1240 bindings: &[(u64, KernelArg<'_>)],
1241 reads: &[&MlxBuffer],
1242 writes: &[&MlxBuffer],
1243 threadgroups: MTLSize,
1244 threadgroup_size: MTLSize,
1245 ) {
1246 // Capture mode: stash ranges + delegate to the standard encode.
1247 // The ranges flow through `pending_reads`/`pending_writes` and
1248 // attach to the captured `Dispatch` node — identical to what
1249 // `GraphSession::barrier_between` already does in capture mode.
1250 if self.is_capturing() {
1251 let read_ranges = ranges_from_buffers(reads);
1252 let write_ranges = ranges_from_buffers(writes);
1253 self.set_pending_buffer_ranges(read_ranges, write_ranges);
1254 self.encode_threadgroups_with_args(pipeline, bindings, threadgroups, threadgroup_size);
1255 return;
1256 }
1257
1258 if auto_barrier_enabled() {
1259 self.maybe_auto_barrier(reads, writes);
1260 }
1261
1262 self.encode_threadgroups_with_args(pipeline, bindings, threadgroups, threadgroup_size);
1263 }
1264
1265 /// Auto-barrier-aware dispatch with [`KernelArg`] bindings + shared
1266 /// threadgroup memory.
1267 ///
1268 /// See [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args)
1269 /// for the behavioral contract; this variant additionally takes a
1270 /// `threadgroup_mem` slice that is forwarded to
1271 /// [`encode_threadgroups_with_args_and_shared`](Self::encode_threadgroups_with_args_and_shared).
1272 ///
1273 /// The 8-argument signature mirrors the existing
1274 /// `encode_threadgroups_with_args_and_shared` plus the two
1275 /// dataflow slices; `clippy::too_many_arguments` is allowed
1276 /// because each parameter is load-bearing for either the dispatch
1277 /// (pipeline/bindings/threadgroups/threadgroup_size/shared_mem)
1278 /// or the auto-barrier (reads/writes).
1279 #[allow(clippy::too_many_arguments)]
1280 pub fn dispatch_tracked_threadgroups_with_args_and_shared(
1281 &mut self,
1282 pipeline: &ComputePipelineStateRef,
1283 bindings: &[(u64, KernelArg<'_>)],
1284 threadgroup_mem: &[(u64, u64)],
1285 reads: &[&MlxBuffer],
1286 writes: &[&MlxBuffer],
1287 threadgroups: MTLSize,
1288 threadgroup_size: MTLSize,
1289 ) {
1290 if self.is_capturing() {
1291 let read_ranges = ranges_from_buffers(reads);
1292 let write_ranges = ranges_from_buffers(writes);
1293 self.set_pending_buffer_ranges(read_ranges, write_ranges);
1294 self.encode_threadgroups_with_args_and_shared(
1295 pipeline,
1296 bindings,
1297 threadgroup_mem,
1298 threadgroups,
1299 threadgroup_size,
1300 );
1301 return;
1302 }
1303
1304 if auto_barrier_enabled() {
1305 self.maybe_auto_barrier(reads, writes);
1306 }
1307
1308 self.encode_threadgroups_with_args_and_shared(
1309 pipeline,
1310 bindings,
1311 threadgroup_mem,
1312 threadgroups,
1313 threadgroup_size,
1314 );
1315 }
1316
1317 /// Auto-barrier-aware dispatch using `(slot, &MlxBuffer)` bindings
1318 /// (uses `dispatch_thread_groups`).
1319 ///
1320 /// Convenience wrapper for callers that don't need
1321 /// [`KernelArg::Bytes`] inline-byte arguments. See
1322 /// [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args)
1323 /// for behavioral contract.
1324 pub fn dispatch_tracked_threadgroups(
1325 &mut self,
1326 pipeline: &ComputePipelineStateRef,
1327 buffers: &[(u64, &MlxBuffer)],
1328 reads: &[&MlxBuffer],
1329 writes: &[&MlxBuffer],
1330 threadgroups: MTLSize,
1331 threadgroup_size: MTLSize,
1332 ) {
1333 if self.is_capturing() {
1334 let read_ranges = ranges_from_buffers(reads);
1335 let write_ranges = ranges_from_buffers(writes);
1336 self.set_pending_buffer_ranges(read_ranges, write_ranges);
1337 self.encode_threadgroups(pipeline, buffers, threadgroups, threadgroup_size);
1338 return;
1339 }
1340
1341 if auto_barrier_enabled() {
1342 self.maybe_auto_barrier(reads, writes);
1343 }
1344
1345 self.encode_threadgroups(pipeline, buffers, threadgroups, threadgroup_size);
1346 }
1347
1348 /// Auto-barrier-aware dispatch using `(slot, &MlxBuffer)` bindings
1349 /// **plus shared threadgroup memory** (uses `dispatch_thread_groups`).
1350 ///
1351 /// Mirrors [`encode_threadgroups_with_shared`](Self::encode_threadgroups_with_shared)
1352 /// — convenience variant for kernels that allocate threadgroup
1353 /// memory (reductions in `rms_norm`, `softmax`, etc.) but don't
1354 /// need [`KernelArg::Bytes`] inline-byte arguments. See
1355 /// [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args)
1356 /// for the behavioral contract; the only addition here is the
1357 /// `threadgroup_mem` slice forwarded to the underlying encode.
1358 ///
1359 /// Closes the iter38-audit coverage gap: the 5 `rms_norm.rs`
1360 /// callsites (`/opt/mlx-native/src/ops/rms_norm.rs:124,236,443,
1361 /// 516,589`) all use `encode_threadgroups_with_shared` and need
1362 /// dataflow tracking when migrated to auto-barrier in iter40+.
1363 ///
1364 /// 7-argument signature; `clippy::too_many_arguments` is allowed
1365 /// because each parameter is load-bearing for either the dispatch
1366 /// (pipeline/buffers/threadgroups/threadgroup_size/shared_mem) or
1367 /// the auto-barrier (reads/writes).
1368 #[allow(clippy::too_many_arguments)]
1369 pub fn dispatch_tracked_threadgroups_with_shared(
1370 &mut self,
1371 pipeline: &ComputePipelineStateRef,
1372 buffers: &[(u64, &MlxBuffer)],
1373 threadgroup_mem: &[(u64, u64)],
1374 reads: &[&MlxBuffer],
1375 writes: &[&MlxBuffer],
1376 threadgroups: MTLSize,
1377 threadgroup_size: MTLSize,
1378 ) {
1379 if self.is_capturing() {
1380 let read_ranges = ranges_from_buffers(reads);
1381 let write_ranges = ranges_from_buffers(writes);
1382 self.set_pending_buffer_ranges(read_ranges, write_ranges);
1383 self.encode_threadgroups_with_shared(
1384 pipeline,
1385 buffers,
1386 threadgroup_mem,
1387 threadgroups,
1388 threadgroup_size,
1389 );
1390 return;
1391 }
1392
1393 if auto_barrier_enabled() {
1394 self.maybe_auto_barrier(reads, writes);
1395 }
1396
1397 self.encode_threadgroups_with_shared(
1398 pipeline,
1399 buffers,
1400 threadgroup_mem,
1401 threadgroups,
1402 threadgroup_size,
1403 );
1404 }
1405
1406 /// Auto-barrier-aware `dispatch_threads` variant with
1407 /// [`KernelArg`] bindings.
1408 ///
1409 /// Mirrors [`encode_with_args`](Self::encode_with_args) — the
1410 /// `dispatch_threads` (per-thread grid) flavor, as opposed to the
1411 /// `dispatch_thread_groups` flavor of
1412 /// [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args).
1413 /// See that method for the behavioral contract.
1414 ///
1415 /// Closes the iter38-audit coverage gap: callers that use
1416 /// per-thread grids — `rope.rs:108` (IMROPE), `sigmoid_mul.rs:76`
1417 /// (sigmoid-mul), and `encode_helpers.rs:41` (kv_cache_copy) —
1418 /// need a `dispatch_threads` flavor of the tracked dispatch
1419 /// because their grid sizes are expressed in threads, not
1420 /// threadgroups.
1421 ///
1422 /// Note: the simpler `(slot, &MlxBuffer)` form (from
1423 /// [`encode`](Self::encode)) is a special case of this method —
1424 /// callers can wrap each binding as `KernelArg::Buffer(buf)` to
1425 /// reuse this single tracked variant rather than introducing a
1426 /// fifth one.
1427 pub fn dispatch_tracked_threads_with_args(
1428 &mut self,
1429 pipeline: &ComputePipelineStateRef,
1430 bindings: &[(u64, KernelArg<'_>)],
1431 reads: &[&MlxBuffer],
1432 writes: &[&MlxBuffer],
1433 grid_size: MTLSize,
1434 threadgroup_size: MTLSize,
1435 ) {
1436 if self.is_capturing() {
1437 let read_ranges = ranges_from_buffers(reads);
1438 let write_ranges = ranges_from_buffers(writes);
1439 self.set_pending_buffer_ranges(read_ranges, write_ranges);
1440 self.encode_with_args(pipeline, bindings, grid_size, threadgroup_size);
1441 return;
1442 }
1443
1444 if auto_barrier_enabled() {
1445 self.maybe_auto_barrier(reads, writes);
1446 }
1447
1448 self.encode_with_args(pipeline, bindings, grid_size, threadgroup_size);
1449 }
1450
1451 /// Run the dataflow check, emit a barrier on conflict, and record
1452 /// the dispatch's ranges into the cumulative state.
1453 ///
1454 /// Always called *before* the underlying `encode_*` method
1455 /// applies the dispatch. Mirrors lines 220-225 of
1456 /// `ggml-metal-ops.cpp` (`concurrency_check + concurrency_reset +
1457 /// concurrency_add` around each node).
1458 fn maybe_auto_barrier(
1459 &mut self,
1460 reads: &[&MlxBuffer],
1461 writes: &[&MlxBuffer],
1462 ) {
1463 if self.mem_ranges.check_dispatch(reads, writes) {
1464 // Concurrent — no barrier needed; just record the new ranges.
1465 self.mem_ranges.add_dispatch(reads, writes);
1466 AUTO_BARRIER_CONCURRENT.fetch_add(1, Ordering::Relaxed);
1467 } else {
1468 // Conflict — emit barrier, reset state, seed new group.
1469 //
1470 // `memory_barrier()` itself increments `BARRIER_COUNT` and,
1471 // when `MLX_PROFILE_BARRIERS=1`, accumulates `BARRIER_NS`.
1472 // We additionally bump `AUTO_BARRIER_COUNT` so the
1473 // "auto-emitted vs hand-placed" subset is queryable.
1474 self.memory_barrier();
1475 self.mem_ranges.reset();
1476 self.mem_ranges.add_dispatch(reads, writes);
1477 AUTO_BARRIER_COUNT.fetch_add(1, Ordering::Relaxed);
1478 }
1479 }
1480
1481 /// Force a barrier and reset the auto-barrier tracker.
1482 ///
1483 /// Use at boundaries where the caller knows a barrier is required
1484 /// regardless of dataflow — typically before reading data back to
1485 /// CPU, or at the end of an op group whose internal dependencies
1486 /// the tracker can't see (e.g. host-driven memcpy).
1487 ///
1488 /// Equivalent to `memory_barrier()` plus a `MemRanges::reset()`
1489 /// when `HF2Q_AUTO_BARRIER=1`; equivalent to plain
1490 /// `memory_barrier()` otherwise.
1491 pub fn force_barrier_and_reset_tracker(&mut self) {
1492 self.memory_barrier();
1493 if auto_barrier_enabled() {
1494 self.mem_ranges.reset();
1495 }
1496 }
1497
1498 /// Diagnostic accessor — number of ranges currently recorded in
1499 /// this encoder's [`MemRanges`] tracker. Always zero unless
1500 /// `HF2Q_AUTO_BARRIER=1` and at least one `dispatch_tracked` call
1501 /// has fired since the last conflict.
1502 #[inline]
1503 pub fn mem_ranges_len(&self) -> usize {
1504 self.mem_ranges.len()
1505 }
1506
1507 /// Replay a single captured dispatch node into this encoder.
1508 ///
1509 /// This is the inverse of capture: it takes a previously recorded
1510 /// `CapturedNode::Dispatch` and encodes it into the live Metal encoder.
1511 /// Barrier nodes are handled by the caller (ComputeGraph::encode_sequential).
1512 ///
1513 /// Does NOT increment `DISPATCH_COUNT` — that was already counted at
1514 /// capture time.
1515 pub fn replay_dispatch(
1516 &mut self,
1517 pipeline: &ComputePipelineStateRef,
1518 bindings: &[(u64, RecordedBinding)],
1519 threadgroup_memory: &[(u64, u64)],
1520 threads_per_grid: MTLSize,
1521 threads_per_threadgroup: MTLSize,
1522 dispatch_kind: DispatchKind,
1523 ) {
1524 // ADR-015 iter63 (Phase A.3): mirror the per-dispatch sampling
1525 // scaffold here so capture-mode-recorded graphs (graph.rs
1526 // encode_sequential / encode_with_barriers / encode_chunk_with
1527 // _barriers) still produce per-dispatch entries. The replay
1528 // path bypasses encode*; without this hook the per-dispatch
1529 // table would be silently empty for any model that uses
1530 // `GraphExecutor::begin_recorded`.
1531 //
1532 // Captured `op_kind` is forwarded via `pending_op_kind`: the
1533 // graph replay layer at graph.rs:197/236/727 sets it from the
1534 // CapturedNode.op_kind before calling replay_dispatch.
1535 self.ensure_sample_buffer();
1536 let op_kind = self.take_pending_op_kind();
1537 let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1538 // SAFETY: see encode() above.
1539 let encoder = unsafe { &*encoder_ptr };
1540 encoder.set_compute_pipeline_state(pipeline);
1541 for (index, binding) in bindings {
1542 match binding {
1543 RecordedBinding::Buffer { metal_buffer, offset } => {
1544 encoder.set_buffer(*index, Some(metal_buffer), *offset);
1545 }
1546 RecordedBinding::Bytes(bytes) => {
1547 encoder.set_bytes(
1548 *index,
1549 bytes.len() as u64,
1550 bytes.as_ptr() as *const _,
1551 );
1552 }
1553 }
1554 }
1555 for &(index, byte_length) in threadgroup_memory {
1556 encoder.set_threadgroup_memory_length(index, byte_length);
1557 }
1558 let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1559 match dispatch_kind {
1560 DispatchKind::Threads => {
1561 encoder.dispatch_threads(threads_per_grid, threads_per_threadgroup);
1562 }
1563 DispatchKind::ThreadGroups => {
1564 encoder.dispatch_thread_groups(threads_per_grid, threads_per_threadgroup);
1565 }
1566 }
1567 self.sample_dispatch_post(encoder, pre_idx);
1568 }
1569
1570 /// Flush any pending residency-set add/remove staging.
1571 ///
1572 /// Hooked at every commit boundary so per-allocation
1573 /// [`ResidencySet::add_allocation`](ResidencySet::add_allocation) and
1574 /// [`ResidencySet::remove_allocation`](ResidencySet::remove_allocation)
1575 /// calls (as fired by `MlxDevice::alloc_buffer` and
1576 /// `MlxBufferStorage::Drop`) collapse into at most ONE `[set commit]`
1577 /// per CB submission. Mirrors llama.cpp's
1578 /// `ggml-metal-device.m:1378-1382` (batch addAllocation in loop,
1579 /// commit ONCE).
1580 #[inline]
1581 fn flush_residency_pending(&self) {
1582 if let Some(set) = self.residency_set.as_ref() {
1583 set.flush_pending();
1584 }
1585 }
1586
1587 // ----------------------------------------------------------------
1588 // ADR-015 iter63 — per-dispatch sample buffer lifecycle
1589 // ----------------------------------------------------------------
1590
1591 /// Allocate the per-CB `MTLCounterSampleBuffer` if it has not been
1592 /// allocated yet for this CB.
1593 ///
1594 /// No-op when `MLX_PROFILE_DISPATCH` is unset, when the buffer is
1595 /// already present, or when the device does not expose a counter
1596 /// set named `"timestamp"` (Risk R1 — graceful degrade with a
1597 /// one-shot stderr warning).
1598 ///
1599 /// The sample buffer is sized to [`MAX_SAMPLES_PER_CB`] (32_768).
1600 /// This is the start-+-end pair budget — i.e. ≤ 16,384 dispatches
1601 /// per CB. Above that ceiling, additional dispatches will skip
1602 /// sampling (see [`Self::sample_dispatch_pre`]).
1603 #[inline]
1604 fn ensure_sample_buffer(&mut self) {
1605 if !crate::kernel_profile::is_dispatch_enabled() {
1606 return;
1607 }
1608 if self.sample_buffer.is_some() {
1609 return;
1610 }
1611 // Discover the timestamp counter set. metal-rs 0.33 does not
1612 // export the `MTLCommonCounterSetTimestamp` constant, so we
1613 // name-match `"timestamp"` case-insensitively. Reach the
1614 // device via the cmd_buf's `device` selector (metal-rs 0.33
1615 // exposes `CommandQueue::device` but not `CommandBuffer::device`,
1616 // so we go through ObjC directly).
1617 let device: &metal::DeviceRef = unsafe {
1618 let cb = &*self.cmd_buf;
1619 msg_send![cb, device]
1620 };
1621 // ADR-015 iter63 — Apple Silicon hardware constraint (NEW Risk
1622 // discovered at impl time, supersedes design §A.7). M-series
1623 // GPUs (verified: AGXG17XFamilyComputeContext = M5 Max series,
1624 // macOS 26) only support counter sampling AtStageBoundary —
1625 // i.e. between compute *passes*, not between dispatches inside
1626 // a persistent compute encoder. Calling
1627 // `sampleCountersInBuffer:atSampleIndex:withBarrier:` on such
1628 // hardware aborts with `failed assertion ... not supported on
1629 // this device`. The persistent-encoder design (mlx-native uses
1630 // ONE compute encoder per CB to amortize ~800 encoder
1631 // create/end cycles per forward pass — see `get_or_create_
1632 // encoder` docstring) is incompatible with stage-boundary-only
1633 // sampling, so on Apple Silicon we degrade per-dispatch
1634 // profiling to a no-op and log once. Per-CB profiling is
1635 // unaffected (it uses MTLCommandBuffer.GPUStartTime/
1636 // GPUEndTime, which are always available).
1637 //
1638 // Future: if Apple ever ships AtDispatchBoundary support on
1639 // Apple Silicon, this branch becomes a true cap check. For
1640 // now, the kit infrastructure is in place; only the sample-
1641 // point cooperates.
1642 if !device.supports_counter_sampling(MTLCounterSamplingPoint::AtDispatchBoundary) {
1643 if TIMESTAMP_SET_WARN_LOGGED
1644 .compare_exchange(0, 1, Ordering::Relaxed, Ordering::Relaxed)
1645 .is_ok()
1646 {
1647 eprintln!(
1648 "[mlx-native] MLX_PROFILE_DISPATCH=1 ignored: \
1649 device {:?} does NOT support \
1650 MTLCounterSamplingPointAtDispatchBoundary \
1651 (Apple Silicon limitation; only AtStageBoundary \
1652 is supported, which is incompatible with the \
1653 persistent compute-encoder pattern). \
1654 MLX_PROFILE_CB=1 still produces per-CB GPU times.",
1655 device.name()
1656 );
1657 }
1658 return;
1659 }
1660 let counter_sets = device.counter_sets();
1661 let timestamp_set = counter_sets
1662 .iter()
1663 .find(|c: &&metal::CounterSet| c.name().eq_ignore_ascii_case("timestamp"));
1664 let timestamp_set = match timestamp_set {
1665 Some(s) => s,
1666 None => {
1667 // Risk R1: device does not expose a timestamp set.
1668 // Log once and degrade to no-op (sample_buffer stays None).
1669 if TIMESTAMP_SET_WARN_LOGGED
1670 .compare_exchange(0, 1, Ordering::Relaxed, Ordering::Relaxed)
1671 .is_ok()
1672 {
1673 eprintln!(
1674 "[mlx-native] MLX_PROFILE_DISPATCH=1 ignored: \
1675 device {:?} exposes no MTLCommonCounterSetTimestamp",
1676 device.name()
1677 );
1678 }
1679 return;
1680 }
1681 };
1682 // Build descriptor. StorageMode::Shared is required by
1683 // resolveCounterRange (MTLCounters.h:185-188).
1684 let descriptor = CounterSampleBufferDescriptor::new();
1685 descriptor.set_counter_set(timestamp_set);
1686 descriptor.set_storage_mode(MTLStorageMode::Shared);
1687 descriptor.set_label("mlx_native.dispatch_samples");
1688 descriptor.set_sample_count(MAX_SAMPLES_PER_CB);
1689 match device.new_counter_sample_buffer_with_descriptor(&descriptor) {
1690 Ok(buf) => {
1691 self.sample_buffer = Some(buf);
1692 }
1693 Err(e) => {
1694 if TIMESTAMP_SET_WARN_LOGGED
1695 .compare_exchange(0, 1, Ordering::Relaxed, Ordering::Relaxed)
1696 .is_ok()
1697 {
1698 eprintln!(
1699 "[mlx-native] MLX_PROFILE_DISPATCH=1 ignored: \
1700 newCounterSampleBufferWithDescriptor failed: {}",
1701 e
1702 );
1703 }
1704 self.sample_buffer = None;
1705 }
1706 }
1707 }
1708
1709 /// Insert the start-of-dispatch counter sample (sample index `2*i`)
1710 /// and queue the per-dispatch metadata. Returns the dispatch
1711 /// ordinal `i` so the caller can emit the matching post-sample.
1712 ///
1713 /// No-op when sampling is inactive — returns 0 in that case (the
1714 /// returned value is only consumed when the sample buffer is
1715 /// active, so this is safe).
1716 ///
1717 /// `with_barrier:true` is mandatory: the encoder uses
1718 /// `MTLDispatchTypeConcurrent` and without the barrier the start
1719 /// timestamp would race against any in-flight dispatch (PROFILING-
1720 /// KIT-DESIGN §A.5).
1721 #[inline]
1722 fn sample_dispatch_pre(
1723 &mut self,
1724 encoder: &ComputeCommandEncoderRef,
1725 op_kind: CapturedOpKind,
1726 ) -> Option<u32> {
1727 let sb = self.sample_buffer.as_ref()?;
1728 let i = self.dispatch_in_cb;
1729 let pre_idx = (i as u64).checked_mul(2)?;
1730 if pre_idx >= MAX_SAMPLES_PER_CB {
1731 // Ceiling exceeded — skip sampling for the remainder of
1732 // this CB. Risk R4 (PROFILING-KIT-DESIGN §A.7): future
1733 // iter can chunk-resolve every N dispatches; for now we
1734 // accept truncation with a one-shot warning (re-uses the
1735 // R1 warn flag).
1736 return None;
1737 }
1738 encoder.sample_counters_in_buffer(sb, pre_idx, true);
1739 self.pending_dispatch_meta.push(PendingDispatchMeta {
1740 op_kind: op_kind.name(),
1741 dispatch_index: i,
1742 });
1743 Some(i)
1744 }
1745
1746 /// Insert the end-of-dispatch counter sample (sample index `2*i+1`)
1747 /// matching the most recent [`Self::sample_dispatch_pre`].
1748 ///
1749 /// No-op when sampling is inactive or when `pre_idx` is `None`.
1750 #[inline]
1751 fn sample_dispatch_post(
1752 &mut self,
1753 encoder: &ComputeCommandEncoderRef,
1754 pre_idx: Option<u32>,
1755 ) {
1756 let i = match pre_idx {
1757 Some(v) => v,
1758 None => return,
1759 };
1760 let sb = match self.sample_buffer.as_ref() {
1761 Some(b) => b,
1762 None => return,
1763 };
1764 let post_idx = match (i as u64).checked_mul(2).and_then(|v| v.checked_add(1)) {
1765 Some(v) if v < MAX_SAMPLES_PER_CB => v,
1766 _ => return,
1767 };
1768 encoder.sample_counters_in_buffer(sb, post_idx, true);
1769 // Bump the per-CB ordinal only after both samples committed
1770 // successfully so a truncation skip leaves the meta queue
1771 // length matching the buffer's resolved range.
1772 self.dispatch_in_cb = i.saturating_add(1);
1773 }
1774
1775 /// Resolve the per-CB sample buffer, push entries into
1776 /// [`crate::kernel_profile`], and reset per-CB state.
1777 ///
1778 /// Called from [`Self::commit_and_wait_labeled`] after the CB
1779 /// completes; the caller is responsible for ensuring the GPU has
1780 /// finished (otherwise `resolveCounterRange` returns garbage).
1781 ///
1782 /// On the first resolve after a [`crate::kernel_profile::reset`],
1783 /// also captures a `(cpu_ns, gpu_ticks)` pair via
1784 /// `device.sampleTimestamps` so subsequent ticks→ns conversion
1785 /// uses a fresh scale factor.
1786 fn resolve_dispatch_samples(&mut self, cb_label: &str) -> Result<()> {
1787 let sb = match self.sample_buffer.take() {
1788 Some(b) => b,
1789 None => {
1790 self.pending_dispatch_meta.clear();
1791 self.dispatch_in_cb = 0;
1792 return Ok(());
1793 }
1794 };
1795 let n = self.pending_dispatch_meta.len();
1796 if n == 0 {
1797 self.dispatch_in_cb = 0;
1798 return Ok(());
1799 }
1800 // Refresh the (cpu, gpu) scale pair on every resolve; the
1801 // device call is cheap and keeps us robust against driver-side
1802 // timebase changes between CBs.
1803 let mut cpu_t: u64 = 0;
1804 let mut gpu_t: u64 = 0;
1805 let device: &metal::DeviceRef = unsafe {
1806 let cb = &*self.cmd_buf;
1807 msg_send![cb, device]
1808 };
1809 device.sample_timestamps(&mut cpu_t, &mut gpu_t);
1810 crate::kernel_profile::record_clock_pair(cpu_t, gpu_t);
1811 let length = (n as u64).saturating_mul(2);
1812 let data = sb.resolve_counter_range(NSRange {
1813 location: 0,
1814 length,
1815 });
1816 // `resolve_counter_range` returns one NSUInteger per sample.
1817 // Pair them up: data[2i] = start, data[2i+1] = end.
1818 for (i, meta) in self.pending_dispatch_meta.drain(..).enumerate() {
1819 let start_idx = 2 * i;
1820 let end_idx = 2 * i + 1;
1821 if end_idx >= data.len() {
1822 break;
1823 }
1824 let start_raw = data[start_idx] as u64;
1825 let end_raw = data[end_idx] as u64;
1826 let start_ns = crate::kernel_profile::convert_gpu_ticks_to_ns(start_raw);
1827 let end_ns = crate::kernel_profile::convert_gpu_ticks_to_ns(end_raw);
1828 let gpu_ns = end_ns.saturating_sub(start_ns);
1829 crate::kernel_profile::record_dispatch(
1830 crate::kernel_profile::DispatchEntry {
1831 cb_label: cb_label.to_string(),
1832 op_kind: meta.op_kind,
1833 dispatch_index: meta.dispatch_index,
1834 gpu_ns,
1835 start_gpu_ns: start_ns,
1836 end_gpu_ns: end_ns,
1837 },
1838 );
1839 }
1840 // Buffer dropped at end of scope releases the underlying
1841 // CounterSampleBuffer; per-CB lifetime correctly bounded.
1842 drop(sb);
1843 self.dispatch_in_cb = 0;
1844 Ok(())
1845 }
1846
1847 /// Commit the command buffer and block until the GPU finishes execution.
1848 ///
1849 /// # Errors
1850 ///
1851 /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
1852 pub fn commit_and_wait(&mut self) -> Result<()> {
1853 SYNC_COUNT.fetch_add(1, Ordering::Relaxed);
1854
1855 // End the persistent compute encoder before committing.
1856 self.end_active_encoder();
1857
1858 // ADR-015 iter8e (Phase 3b): flush deferred residency-set
1859 // add/remove staging so the residency hint covers any buffers
1860 // referenced by this CB. Single commit per CB boundary; no-op
1861 // when no residency set or no staged changes.
1862 self.flush_residency_pending();
1863
1864 self.cmd_buf.commit();
1865 self.cmd_buf.wait_until_completed();
1866
1867 match self.cmd_buf.status() {
1868 MTLCommandBufferStatus::Completed => Ok(()),
1869 MTLCommandBufferStatus::Error => {
1870 Err(MlxError::CommandBufferError(
1871 "GPU command buffer completed with error status".into(),
1872 ))
1873 }
1874 status => Err(MlxError::CommandBufferError(format!(
1875 "Unexpected command buffer status after wait: {:?}",
1876 status
1877 ))),
1878 }
1879 }
1880
1881 /// Commit + wait, accumulating GPU wall-clock time under `label` into
1882 /// the [`crate::kernel_profile`] global table when `MLX_PROFILE_CB=1`
1883 /// is set. When the env var is unset, this is identical to
1884 /// [`commit_and_wait`](Self::commit_and_wait) — zero overhead.
1885 ///
1886 /// Used by hf2q's decode hot path to attribute per-cb GPU time to
1887 /// labeled phases (per-layer attn, per-layer ffn, output_head, etc.)
1888 /// without manually wiring `commit_wait_with_gpu_time` everywhere.
1889 ///
1890 /// # Errors
1891 ///
1892 /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
1893 pub fn commit_and_wait_labeled(&mut self, label: &str) -> Result<()> {
1894 // ADR-015 iter16 — propagate `label` to MTLCommandBuffer.setLabel and
1895 // (if a compute encoder is active) MTLComputeCommandEncoder.setLabel
1896 // BEFORE end_encoding/commit so xctrace's
1897 // `metal-application-encoders-list` table populates `cmdbuffer-label`
1898 // and `encoder-label` columns with the semantic phase name (e.g.
1899 // `layer.attn_moe_ffn`, `output_head.fused_norm_lm_argmax`,
1900 // `layer.delta_net.ops1-9`). Joined to per-CB GPU duration via
1901 // `metal-gpu-submission-to-command-buffer-id` (sub_id ↔ encoder_id) →
1902 // `metal-gpu-execution-points` (per-dispatch start/end), this enables
1903 // per-phase µs/token attribution comparing hf2q vs llama side-by-side
1904 // (iter15 §E "iter16 ATTRIBUTION PATH"). Cost is a single ObjC
1905 // msg_send per CB submission — sub-µs on M5 Max — and a no-op when
1906 // xctrace isn't recording, so this is unconditionally safe to call on
1907 // the production decode hot path.
1908 self.apply_labels(label);
1909 // ADR-015 iter63: record GPU time AND resolve per-dispatch samples
1910 // when either env gate is set. Per-dispatch sampling force-enables
1911 // the per-CB path so cross-validation per Risk R3 always has a
1912 // ground-truth comparator.
1913 let need_gpu_time =
1914 crate::kernel_profile::is_enabled() || crate::kernel_profile::is_dispatch_enabled();
1915 if need_gpu_time {
1916 let (start_s, end_s) = self.commit_wait_with_gpu_time()?;
1917 let ns = ((end_s - start_s).max(0.0) * 1_000_000_000.0) as u64;
1918 if crate::kernel_profile::is_enabled() {
1919 crate::kernel_profile::record(label, ns);
1920 }
1921 if crate::kernel_profile::is_dispatch_enabled() {
1922 self.resolve_dispatch_samples(label)?;
1923 }
1924 Ok(())
1925 } else {
1926 self.commit_and_wait()
1927 }
1928 }
1929
1930 /// Async commit, but with profiling label. When `MLX_PROFILE_CB=1`
1931 /// is set, redirects to a synchronous [`commit_and_wait_labeled`]
1932 /// call to capture per-cb GPU time (this defeats async pipelining
1933 /// while profiling, which is the whole point — profile-mode is slow
1934 /// but informative). When unset, identical to [`commit`](Self::commit).
1935 pub fn commit_labeled(&mut self, label: &str) {
1936 // ADR-015 iter16 — see `commit_and_wait_labeled` for rationale.
1937 if crate::kernel_profile::is_enabled() {
1938 // Profile mode: force sync to capture GPU time. apply_labels is
1939 // called inside commit_and_wait_labeled — do NOT call it twice
1940 // here (would double the ObjC msg_send under MLX_PROFILE_CB=1).
1941 // Errors are logged via stderr because the void return matches
1942 // commit().
1943 if let Err(e) = self.commit_and_wait_labeled(label) {
1944 eprintln!("[mlx-native] commit_labeled({}) failed: {}", label, e);
1945 }
1946 } else {
1947 // Async path: apply labels here so xctrace MST traces capture
1948 // per-CB phase attribution under default decode (no
1949 // `MLX_PROFILE_CB`).
1950 self.apply_labels(label);
1951 self.commit();
1952 }
1953 }
1954
1955 /// Apply `label` to the underlying `MTLCommandBuffer` and, if a compute
1956 /// encoder is currently active, to the `MTLComputeCommandEncoder`.
1957 ///
1958 /// Called from [`commit_labeled`] and [`commit_and_wait_labeled`] BEFORE
1959 /// the encoder is ended / the CB is committed so xctrace's
1960 /// `metal-application-encoders-list` table picks up the label on the
1961 /// row emitted at the encoder's `endEncoding` / CB submission boundary.
1962 /// Single ObjC `msg_send` per call (two if an encoder is active); sub-µs
1963 /// on M5 Max; no-op when xctrace isn't recording.
1964 ///
1965 /// Skipped (debug-only assert) if `label` is empty — empty labels would
1966 /// produce an indistinguishable trace row from the metal-rs default
1967 /// `Command Buffer 0` placeholder.
1968 #[inline]
1969 fn apply_labels(&mut self, label: &str) {
1970 debug_assert!(!label.is_empty(), "commit_*_labeled called with empty label");
1971 if label.is_empty() {
1972 return;
1973 }
1974 self.cmd_buf.set_label(label);
1975 if !self.active_encoder.is_null() {
1976 // SAFETY: active_encoder is non-null and points to a live encoder
1977 // owned by cmd_buf — same invariant as get_or_create_encoder /
1978 // memory_barrier. set_label is a single property write on the
1979 // ObjC object; safe before endEncoding.
1980 unsafe { &*self.active_encoder }.set_label(label);
1981 }
1982 // ADR-015 iter63: capture the most recent label for per-dispatch
1983 // entries. Cheap String allocation — only happens at CB commit
1984 // boundaries, not per dispatch.
1985 self.last_label.clear();
1986 self.last_label.push_str(label);
1987 }
1988
1989 /// Commit + wait, returning `(gpu_start_s, gpu_end_s)` CFTimeInterval
1990 /// timestamps from `MTLCommandBuffer`'s `GPUStartTime`/`GPUEndTime`
1991 /// properties. Both are mach-absolute CFTimeInterval seconds (double).
1992 ///
1993 /// Intended for `HF2Q_PROFILE_GPU_TS=1` per-bucket GPU wall-clock
1994 /// attribution. Adds exactly two ObjC property reads per call on top
1995 /// of the regular `commit_and_wait` — measured well under 1 μs on
1996 /// M5 Max.
1997 ///
1998 /// # Errors
1999 ///
2000 /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
2001 pub fn commit_wait_with_gpu_time(&mut self) -> Result<(f64, f64)> {
2002 self.commit_and_wait()?;
2003 // SAFETY: cmd_buf is a valid MTLCommandBuffer that has been
2004 // committed and awaited. GPUStartTime / GPUEndTime return
2005 // CFTimeInterval (double precision seconds). See
2006 // https://developer.apple.com/documentation/metal/mtlcommandbuffer/1639925-gpustarttime
2007 let (gpu_start, gpu_end): (f64, f64) = unsafe {
2008 let cb = &*self.cmd_buf;
2009 let s: f64 = msg_send![cb, GPUStartTime];
2010 let e: f64 = msg_send![cb, GPUEndTime];
2011 (s, e)
2012 };
2013 Ok((gpu_start, gpu_end))
2014 }
2015
2016 /// Commit the command buffer WITHOUT blocking.
2017 ///
2018 /// The GPU begins executing the encoded commands immediately. Call
2019 /// [`wait_until_completed`](Self::wait_until_completed) later to block
2020 /// the CPU and check for errors. This allows the CPU to continue doing
2021 /// other work (e.g. preparing the next batch) while the GPU runs.
2022 pub fn commit(&mut self) {
2023 self.end_active_encoder();
2024 // ADR-015 iter8e (Phase 3b): same flush hook as commit_and_wait —
2025 // this is the async-pipeline path that production decode uses.
2026 self.flush_residency_pending();
2027 self.cmd_buf.commit();
2028 }
2029
2030 /// Block until a previously committed command buffer completes.
2031 ///
2032 /// Must be called after [`commit`](Self::commit). Do not call after
2033 /// [`commit_and_wait`](Self::commit_and_wait) — that method already waits.
2034 ///
2035 /// # Errors
2036 ///
2037 /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
2038 pub fn wait_until_completed(&self) -> Result<()> {
2039 self.cmd_buf.wait_until_completed();
2040 match self.cmd_buf.status() {
2041 MTLCommandBufferStatus::Completed => Ok(()),
2042 MTLCommandBufferStatus::Error => Err(MlxError::CommandBufferError(
2043 "GPU command buffer completed with error status".into(),
2044 )),
2045 status => Err(MlxError::CommandBufferError(format!(
2046 "Unexpected command buffer status after wait: {:?}",
2047 status
2048 ))),
2049 }
2050 }
2051
2052 /// Borrow the underlying Metal command buffer.
2053 #[inline]
2054 pub fn metal_command_buffer(&self) -> &CommandBuffer {
2055 &self.cmd_buf
2056 }
2057
2058 /// Borrow the residency set bound to this encoder, if one exists.
2059 ///
2060 /// ADR-019 Phase 0b iter89e2-B: exposed `pub(crate)` so
2061 /// [`crate::EncoderSession`] can route caller-driven add/remove
2062 /// requests through the same `Arc<ResidencySetInner>` the encoder
2063 /// itself flushes at every `commit*` boundary. The single-set
2064 /// invariant from `device.rs::MlxDevice` is preserved — both the
2065 /// encoder's `flush_residency_pending` and the session's delegated
2066 /// add/remove operate on the SAME residency set. Returns `None` when
2067 /// residency sets are disabled (HF2Q_NO_RESIDENCY=1, macOS<15, or
2068 /// `CommandEncoder::new` from a residency-less queue).
2069 #[inline]
2070 pub(crate) fn residency_set(&self) -> Option<&ResidencySet> {
2071 self.residency_set.as_ref()
2072 }
2073
2074 /// Reopen `cmd_buf` with a fresh `CommandBuffer` from the originating queue.
2075 ///
2076 /// ADR-019 Phase 0b iter89e2-B: enables multi-stage chaining. After a
2077 /// non-blocking `commit*` has handed the prior CB to Metal, this method
2078 /// rotates `cmd_buf` to a freshly-allocated CB on the same queue and
2079 /// resets every per-CB scratch field so the next dispatch is encoded
2080 /// onto the new CB.
2081 ///
2082 /// # Caller contract
2083 ///
2084 /// Only valid when `active_encoder.is_null()` (the persistent compute
2085 /// encoder must have been ended via `end_active_encoder()`, which both
2086 /// `commit_and_wait` and `commit` already do). Calling this method
2087 /// while a compute encoder is open would leak the encoder (the new
2088 /// `cmd_buf` does not own it) and trip Metal's "Command encoder
2089 /// released without endEncoding" assertion when the prior `cmd_buf`
2090 /// drops. Callers are [`crate::EncoderSession::reset_for_next_stage`]
2091 /// only — the session has already committed before invoking this.
2092 ///
2093 /// # F2 / F11 / F12 fence preservation
2094 ///
2095 /// - **F2 — residency-rescission**: this method does NOT re-flush
2096 /// the residency set. The prior `commit*` already flushed; staged
2097 /// add/remove since then will flush at the next `commit*` on the
2098 /// new CB. The residency-set Arc clone is preserved.
2099 /// - **F11 — zero-init alloc_buffer**: untouched (no buffer allocs).
2100 /// - **F12 — `HF2Q_FORCE_SERIAL_DISPATCH`**: the new CB will lazily
2101 /// open its compute encoder via `get_or_create_encoder`, which
2102 /// re-reads the env var; the falsification probe still fires on
2103 /// the new CB.
2104 ///
2105 /// # Counter semantics
2106 ///
2107 /// Bumps `CMD_BUF_COUNT` exactly once per call, matching the
2108 /// `new_with_residency` accounting. Does NOT bump `SYNC_COUNT` (no
2109 /// commit/wait happens here).
2110 pub(crate) fn reset_command_buffer(&mut self) {
2111 debug_assert!(
2112 self.active_encoder.is_null(),
2113 "reset_command_buffer called with an active compute encoder \
2114 — caller must commit (which calls end_active_encoder) first"
2115 );
2116 let cmd_buf = if unretained_refs_enabled() {
2117 self.queue
2118 .new_command_buffer_with_unretained_references()
2119 .to_owned()
2120 } else {
2121 self.queue.new_command_buffer().to_owned()
2122 };
2123 CMD_BUF_COUNT.fetch_add(1, Ordering::Relaxed);
2124 self.cmd_buf = cmd_buf;
2125 // Per-CB scratch state — every field that's documented as being
2126 // bounded by a CB lifetime resets here.
2127 self.active_encoder = std::ptr::null();
2128 self.dispatch_in_cb = 0;
2129 self.last_label.clear();
2130 self.pending_dispatch_meta.clear();
2131 // `mem_ranges` is a per-CB barrier inference state; clearing on
2132 // CB rotation matches the `commit_and_wait` post-commit invariant
2133 // (any new CB starts with no pending hazards). The field's own
2134 // `clear` is invoked via `MemRanges::default` here to avoid
2135 // exposing internals.
2136 self.mem_ranges = MemRanges::new();
2137 // `sample_buffer` is dropped explicitly inside
2138 // `resolve_dispatch_samples` after a CB completes; we leave it
2139 // in whatever state the prior commit left it (typically `None`
2140 // after `commit_and_wait` finishes). A stale `Some` here would
2141 // be visible only under `MLX_PROFILE_DISPATCH=1` which fires its
2142 // own one-shot warning; not worth a special case.
2143 // `capture` (if Some) persists across CB rotation — capture mode
2144 // accumulates across stages within a session by design.
2145 // `pending_op_kind` / `pending_reads` / `pending_writes` only
2146 // hold tags for the NEXT dispatch and are consumed when that
2147 // dispatch fires — leaving them as-is is correct.
2148 }
2149
2150 /// Encode an `MTLSharedEvent` wait at `value` on the current CB.
2151 ///
2152 /// ADR-019 Phase 0b iter89e2-B: pairs with [`Self::encode_signal_event`]
2153 /// to express the inter-CB ordering D3 stage boundaries need. The new
2154 /// CB's GPU work blocks until the prior CB's signal lands on the same
2155 /// event at >= `value`.
2156 ///
2157 /// # Caller contract
2158 ///
2159 /// Must be called BEFORE any compute encoder is opened on the new
2160 /// CB — the wait is a CB-level op that must precede every dispatch
2161 /// in the new CB to actually order them. [`crate::EncoderSession::reset_for_next_stage`]
2162 /// fires this immediately after `reset_command_buffer`, before any
2163 /// dispatch lazy-opens the encoder.
2164 #[inline]
2165 pub(crate) fn encode_wait_for_event(&self, event: &metal::EventRef, value: u64) {
2166 debug_assert!(
2167 self.active_encoder.is_null(),
2168 "encode_wait_for_event called with an open compute encoder \
2169 — wait must precede the first dispatch on the new CB"
2170 );
2171 self.cmd_buf.encode_wait_for_event(event, value);
2172 }
2173
2174 /// End the active compute encoder, encode a stage-fence signal, and
2175 /// commit the CB non-blocking — atomically from the caller's view.
2176 ///
2177 /// ADR-019 Phase 0b iter89e2-B: this is the helper
2178 /// [`crate::EncoderSession::fence_stage`] uses to thread the signal
2179 /// between the encoder-end and the CB-commit boundaries that
2180 /// `commit_labeled` would otherwise serialize. Sequence:
2181 ///
2182 /// 1. End the persistent compute encoder (so `encodeSignalEvent:` is
2183 /// encoded at CB-level, not encoder-level — Metal validates that
2184 /// `encodeSignalEvent:` outside any encoder pass is the only
2185 /// legal placement).
2186 /// 2. Apply `label` (when `Some`) to the CB. Note: at this point
2187 /// the encoder is already ended, so the encoder's own
2188 /// `setLabel:` is a no-op site — only the CB label propagates.
2189 /// `last_label` and per-dispatch profiling keep working as
2190 /// documented.
2191 /// 3. Encode `encodeSignalEvent:event:value:new_value` at CB-level.
2192 /// 4. Flush the residency-set pending staging (matches the
2193 /// `commit_labeled` / `commit` flush at encoder.rs:2004).
2194 /// 5. Commit the CB non-blocking (matches `commit()` at
2195 /// encoder.rs:2026).
2196 ///
2197 /// # Counter semantics
2198 ///
2199 /// Bumps `SYNC_COUNT` zero times (non-blocking). Bumps
2200 /// `CMD_BUF_COUNT` zero times (no new CB allocated here —
2201 /// [`Self::reset_command_buffer`] does that on the next stage).
2202 ///
2203 /// # Errors
2204 ///
2205 /// Infallible (matches `commit()` semantics — errors surface only
2206 /// at `wait_until_completed`).
2207 pub(crate) fn fence_signal_and_commit(
2208 &mut self,
2209 event: &metal::EventRef,
2210 new_value: u64,
2211 label: Option<&str>,
2212 ) {
2213 // Step 1: end the active compute encoder. encode_signal_event's
2214 // debug_assert requires this be done first.
2215 self.end_active_encoder();
2216 // Step 2: apply the CB label so xctrace MST attribution still
2217 // works on the fenced CB. apply_labels' debug_assert against
2218 // empty labels matches commit_labeled's semantics.
2219 if let Some(l) = label {
2220 self.apply_labels(l);
2221 }
2222 // Step 3: encode the signal at CB-level.
2223 self.cmd_buf.encode_signal_event(event, new_value);
2224 // Step 4 + 5: same as commit() — flush residency staging, then
2225 // hand the CB to Metal.
2226 self.flush_residency_pending();
2227 self.cmd_buf.commit();
2228 }
2229}
2230
2231impl Drop for CommandEncoder {
2232 fn drop(&mut self) {
2233 // End the persistent compute encoder before the command buffer
2234 // is dropped, otherwise Metal will assert:
2235 // "Command encoder released without endEncoding"
2236 self.end_active_encoder();
2237 }
2238}