mlx_native/encoder.rs
1//! [`CommandEncoder`] — batched GPU command submission.
2//!
3//! Wraps a Metal command buffer. Encode one or more compute kernel dispatches,
4//! then call [`commit_and_wait`](CommandEncoder::commit_and_wait) to submit the
5//! entire batch and block until the GPU finishes.
6//!
7//! # Persistent compute encoder
8//!
9//! A single Metal `ComputeCommandEncoder` is kept alive across multiple
10//! dispatches within the same command buffer. This avoids the overhead of
11//! creating and ending a new compute encoder per dispatch — the same pattern
12//! candle uses (`compute_per_buffer`). On a forward pass with ~800 dispatches
13//! this saves ~800 encoder create/end cycles.
14//!
15//! # Capture mode (Phase 4e.1)
16//!
17//! When `start_capture()` is called, subsequent dispatches are recorded into a
18//! `Vec<CapturedNode>` instead of being encoded into Metal. `memory_barrier()`
19//! records a barrier sentinel. Call `take_capture()` to extract the recorded
20//! graph for later replay via `ComputeGraph::encode_sequential()`.
21
22use std::sync::atomic::{AtomicI8, AtomicU64, Ordering};
23
24use metal::{
25 CommandBuffer, CommandQueue, ComputeCommandEncoderRef, ComputePipelineState,
26 ComputePipelineStateRef, CounterSampleBuffer, CounterSampleBufferDescriptor,
27 MTLCommandBufferStatus, MTLCounterSamplingPoint, MTLDispatchType, MTLSize, MTLStorageMode,
28 NSRange,
29};
30#[allow(unused_imports)]
31use objc::{msg_send, sel, sel_impl};
32
33use crate::buffer::MlxBuffer;
34use crate::error::{MlxError, Result};
35use crate::mem_ranges::MemRanges;
36use crate::residency::ResidencySet;
37
38/// A buffer or inline-bytes binding for a compute kernel argument slot.
39pub enum KernelArg<'a> {
40 /// Bind an existing Metal buffer at the given index.
41 Buffer(&'a MlxBuffer),
42 /// Bind an existing Metal buffer at the given index with a byte offset.
43 BufferWithOffset(&'a MlxBuffer, u64),
44 /// Bind inline bytes (small constant data) at the given index.
45 /// The data must be `Pod` and is copied into the command encoder.
46 Bytes(&'a [u8]),
47}
48
49/// Pre-baked dispatch record for hot decode paths.
50///
51/// ADR-029 iter-175 Step 1d — first piece of the multi-week
52/// "Option A" refactor that the gemma4 decode gap analysis localized
53/// to per-dispatch CPU orchestration (forward_mlx::forward_decode →
54/// encode_one_layer → dispatch_qmatmul → quantized_matmul_ggml →
55/// dispatch_mv → encoder.encode_threadgroups_with_args).
56///
57/// At gemma4 decode m=1, every dispatch within the inner loop has
58/// load-time-immutable shape: the kernel pipeline, threadgroup
59/// geometry, params struct bytes, and binding-slot layout are fully
60/// determined by the weight + ggml_type and never change across the
61/// thousands of decode tokens that follow. `DispatchRecord` captures
62/// that state once at model-load (or on first-call lazy-init) so the
63/// hot path skips:
64/// - `KernelRegistry::get_pipeline*` HashMap lookups
65/// - match expressions over `ggml_type` for kernel-name + geometry
66/// - `MTLSize::new` construction (already-known values)
67/// - param-struct field stores + bytemuck::bytes_of conversion
68///
69/// Only the runtime-varying buffers (input, output) need to be passed
70/// to [`CommandEncoder::dispatch_record`]. Weight buffers are baked
71/// inline via the `bake_buffers` slot list.
72///
73/// # Bake-time invariants
74///
75/// - `buffer_slots.len() == bake_buffers.len() + runtime_buffer_count`
76/// for the call site contract; the call site documents
77/// `runtime_buffer_count` and the order of runtime buffers.
78/// - `params_bytes.len()` is whatever the kernel's `KernelArg::Bytes`
79/// expects (typically 8-byte aligned per Metal struct layout).
80/// - `threadgroup_mem` is `(slot, byte_length)` pairs; empty when the
81/// kernel doesn't request `[[threadgroup]]` memory.
82///
83/// # Coherence
84///
85/// `dispatch_record` produces a byte-identical Metal command stream
86/// to the equivalent `encode_threadgroups_with_args*` call. Capture
87/// mode is supported (records a `CapturedNode::Dispatch` exactly as
88/// the unbaked path does). See `dispatch_record` for the lockstep.
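///
/// # Bake-and-dispatch sketch
///
/// A minimal illustration of the intended flow. The registry lookup and the
/// `dispatch_record` call signature below are assumptions for readability,
/// not the exact call-site API; weight baking (the `bake_buffers` list) is
/// elided.
///
/// ```ignore
/// // Bake once at model load.
/// let record = DispatchRecord {
///     pipeline: registry.get_pipeline("qmv_q4_0")?, // hypothetical lookup
///     threadgroups: MTLSize::new(n_rows as u64, 1, 1),
///     threads_per_tg: MTLSize::new(32, 1, 1),
///     threadgroup_mem: vec![],
///     params_bytes: as_bytes(&params).to_vec(),
///     params_slot: 3,
///     buffer_slots: vec![1, 2], // runtime buffers: input, output
///     op_kind: CapturedOpKind::Other,
///     kernel_name: "qmv_q4_0".into(),
/// };
///
/// // Hot decode loop: only the runtime buffers change per token.
/// enc.dispatch_record(&record, &[&input_buf, &output_buf]);
/// ```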
89#[derive(Clone)]
90pub struct DispatchRecord {
91 /// Pipeline reference, looked up once at bake time.
92 pub pipeline: ComputePipelineState,
93 /// Threadgroup count.
94 pub threadgroups: MTLSize,
95 /// Threads per threadgroup.
96 pub threads_per_tg: MTLSize,
97 /// Threadgroup shared-memory bindings: `(slot_index, byte_length)`.
98 /// Empty when the kernel doesn't allocate `[[threadgroup]]` memory.
99 pub threadgroup_mem: Vec<(u64, u64)>,
100 /// Pre-encoded params struct bytes (bound as `KernelArg::Bytes`).
101 /// Empty when the kernel has no inline-bytes parameter.
102 pub params_bytes: Vec<u8>,
103 /// Slot index for `params_bytes`. Ignored when `params_bytes` is empty.
104 pub params_slot: u64,
105 /// Slot indices for runtime buffer arguments, in caller order.
106 /// `dispatch_record` zips `runtime_buffers` against this list.
107 pub buffer_slots: Vec<u64>,
108 /// `CapturedOpKind` used when the encoder is in capture mode.
109 pub op_kind: CapturedOpKind,
110 /// Diagnostic label (kernel name) for debug/timing.
111 pub kernel_name: String,
112}
113
114/// Convert a `Pod` value to a byte slice suitable for `KernelArg::Bytes`.
115///
116/// # Layout requirement
117///
118/// `T` must match the layout of the corresponding MSL struct in the shader
119/// (field order, sizes, and alignment). The call itself is safe; a mismatch yields wrong kernel parameters rather than UB.
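///
/// A usage sketch (the `RmsNormParams` struct and its slot index are
/// illustrative, not a real kernel's layout):
///
/// ```ignore
/// #[repr(C)]
/// #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
/// struct RmsNormParams { n_cols: u32, eps: f32 }
///
/// let params = RmsNormParams { n_cols: 4096, eps: 1e-6 };
/// // Bound as inline constant data at (hypothetical) slot 2.
/// let binding = (2u64, KernelArg::Bytes(as_bytes(&params)));
/// ```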
120pub fn as_bytes<T: bytemuck::Pod>(val: &T) -> &[u8] {
121 bytemuck::bytes_of(val)
122}
123
124// ---------------------------------------------------------------------------
125// Capture-mode types (Phase 4e.1 — Graph IR)
126// ---------------------------------------------------------------------------
127
128/// A recorded kernel argument binding.
129///
130/// When the encoder is in capture mode, each `set_buffer` / `set_bytes` call
131/// is stored as a `RecordedBinding` instead of being applied to Metal.
132#[derive(Clone)]
133pub enum RecordedBinding {
134 /// A Metal buffer at the given offset.
135 Buffer {
136 metal_buffer: metal::Buffer,
137 offset: u64,
138 },
139 /// Inline bytes (small constant data, copied).
140 Bytes(Vec<u8>),
141}
142
143/// How to dispatch the recorded kernel.
144#[derive(Clone, Copy, Debug)]
145pub enum DispatchKind {
146 /// `dispatch_threads(grid_size, threadgroup_size)` — Metal picks threadgroup count.
147 Threads,
148 /// `dispatch_thread_groups(threadgroups, threadgroup_size)` — caller specifies threadgroup count.
149 ThreadGroups,
150}
151
152/// Operation kind tag for captured nodes, used by the fusion pass (4e.2).
153///
154/// When the encoder is in capture mode, each dispatch can be tagged with an
155/// `OpKind` so the fusion pass can identify fuseable sequences without
156/// inspecting pipeline names.
157#[derive(Clone, Copy, Debug, PartialEq, Eq)]
158pub enum CapturedOpKind {
159 /// RMS normalization (with learned scale).
160 RmsNorm,
161 /// Elementwise multiply.
162 ElemMul,
163 /// Elementwise add.
164 ElemAdd,
165 /// Scaled dot-product attention (NOT reorderable — breaks lookahead).
166 Sdpa,
167 /// Softmax (NOT reorderable — breaks lookahead).
168 Softmax,
169 /// Any other operation — treated as reorderable by the graph optimizer.
170 Other,
171}
172
173impl CapturedOpKind {
174 /// Whether this captured op kind is safe to reorder past in the graph
175 /// optimizer (Phase 4e.3).
176 ///
177 /// Mirrors the `h_safe` whitelist from llama.cpp's
178 /// `ggml_metal_graph_optimize_reorder`. Non-safe ops break the 64-node
179 /// lookahead — the reorder pass cannot look past them.
180 pub fn is_reorderable(&self) -> bool {
181 match self {
182 Self::Sdpa | Self::Softmax => false,
183 Self::RmsNorm | Self::ElemMul | Self::ElemAdd | Self::Other => true,
184 }
185 }
186
187 /// Stable string label suitable for embedding in the per-dispatch
188 /// profile dump (ADR-015 iter63 §A.5). Matches the variant name —
189 /// `Other` is preserved verbatim so an aggregate-by-op_kind sort
190 /// produces a clean "what isn't yet labeled" bucket.
191 pub fn name(&self) -> &'static str {
192 match self {
193 Self::RmsNorm => "RmsNorm",
194 Self::ElemMul => "ElemMul",
195 Self::ElemAdd => "ElemAdd",
196 Self::Sdpa => "Sdpa",
197 Self::Softmax => "Softmax",
198 Self::Other => "Other",
199 }
200 }
201}
202
203/// A memory range annotation: (start_address, end_address).
204///
205/// Represents a contiguous GPU buffer region for conflict detection in the
206/// reorder pass (Phase 4e.3). Addresses are CPU-visible `contents_ptr()`
207/// values; on Apple Silicon unified memory they uniquely identify the backing allocation, so overlap in these ranges implies overlap on the GPU.
208pub type MemRange = (usize, usize);
209
210/// A single captured compute dispatch or barrier sentinel.
211///
212/// Created when the encoder is in capture mode. Replayed later by
213/// `ComputeGraph::encode_sequential()`.
214#[derive(Clone)]
215pub enum CapturedNode {
216 /// A compute dispatch to replay.
217 Dispatch {
218 /// Pipeline state object to bind.
219 pipeline: ComputePipelineState,
220 /// Kernel argument bindings: (slot_index, binding).
221 bindings: Vec<(u64, RecordedBinding)>,
222 /// Grid or threadgroup count (interpretation depends on `dispatch_kind`).
223 threads_per_grid: MTLSize,
224 /// Threads per threadgroup.
225 threads_per_threadgroup: MTLSize,
226 /// Optional threadgroup memory allocations: (index, byte_length).
227 threadgroup_memory: Vec<(u64, u64)>,
228 /// Whether this is a dispatch_threads or dispatch_thread_groups call.
229 dispatch_kind: DispatchKind,
230 /// Operation kind tag for the fusion pass (4e.2).
231 /// Defaults to `Other` if not explicitly set via `set_op_kind()`.
232 op_kind: CapturedOpKind,
233 /// Read buffer ranges for reorder conflict detection (4e.3).
234 /// Populated from `barrier_between` calls in capture mode.
235 reads: Vec<MemRange>,
236 /// Write buffer ranges for reorder conflict detection (4e.3).
237 /// Populated from `barrier_between` calls in capture mode.
238 writes: Vec<MemRange>,
239 },
240 /// A memory barrier sentinel — forces a barrier at replay time.
241 Barrier,
242}
243
244/// Convert a slice of buffer references into capture-mode
245/// [`MemRange`] tuples. Used by the [`CommandEncoder::dispatch_tracked*`]
246/// family in capture mode — equivalent to the conversion
247/// `GraphSession::barrier_between` does at `graph.rs:1452-1465`.
248///
249/// `(start, end)` uses `contents_ptr() + byte_offset` as the start
250/// and `contents_ptr() + byte_offset + slice_extent` as the end.
251fn ranges_from_buffers(bufs: &[&MlxBuffer]) -> Vec<MemRange> {
252 bufs.iter()
253 .map(|b| {
254 let base = b.contents_ptr() as usize + b.byte_offset() as usize;
255 let extent = (b.byte_len()).saturating_sub(b.byte_offset() as usize);
256 (base, base + extent)
257 })
258 .collect()
259}
260
261/// Apply a slice of `KernelArg` bindings to a compute encoder.
262///
263/// `KernelArg::Buffer(buf)` propagates the `MlxBuffer::byte_offset()` so
264/// `slice_view`-derived sub-buffers are honored automatically — the
265/// kernel sees memory starting at the slice's offset. This matches the
266/// documented contract of `slice_view` and the offset-handling in the
267/// other binding paths in this file (`encode`, `encode_threadgroups`,
268/// `encode_threadgroups_with_shared`, replay). Without it, every
269/// `slice_view`-derived buffer bound via `KernelArg::Buffer` silently
270/// exposes the entire underlying allocation — surfaced by hf2q's
271/// nomic-bert iter-79 cosine parity bisection (cosine 0.098 → 0.999962
272/// after fix).
273///
274/// `KernelArg::BufferWithOffset(buf, offset)` continues to use the
275/// explicit `offset` argument verbatim (callers asking for an explicit
276/// offset get exactly that, even on sliced buffers). The two API
277/// surfaces are intentional: implicit (sliced views auto-propagate) vs.
278/// explicit (caller-controlled).
279#[inline]
280fn apply_bindings(encoder: &ComputeCommandEncoderRef, bindings: &[(u64, KernelArg<'_>)]) {
281 for &(index, ref arg) in bindings {
282 match arg {
283 KernelArg::Buffer(buf) => {
284 encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
285 }
286 KernelArg::BufferWithOffset(buf, offset) => {
287 encoder.set_buffer(index, Some(buf.metal_buffer()), *offset);
288 }
289 KernelArg::Bytes(bytes) => {
290 encoder.set_bytes(index, bytes.len() as u64, bytes.as_ptr() as *const _);
291 }
292 }
293 }
294}
295
296/// Number of times `commit_and_wait()` has been called (CPU sync points).
297static SYNC_COUNT: AtomicU64 = AtomicU64::new(0);
298
299/// Number of times an encode method has been called (GPU dispatches).
300static DISPATCH_COUNT: AtomicU64 = AtomicU64::new(0);
301
302/// Number of `MTLCommandBuffer` instances created via `CommandEncoder::new`.
303/// Increments once per `device.command_encoder()` call. Used by hf2q's
304/// `HF2Q_DECODE_PROFILE` instrumentation to measure command-buffer
305/// overhead per decode token (ADR-012 §Optimize / Task #15 follow-up).
306static CMD_BUF_COUNT: AtomicU64 = AtomicU64::new(0);
307
308/// Number of `memory_barrier()` calls that reached the
309/// `objc::msg_send![encoder, memoryBarrierWithScope:]` site. Capture-mode
310/// no-ops and pre-encoder no-ops are excluded so the count reflects
311/// actual MTL barriers issued.
312///
313/// Always tracked — the increment is one atomic op, ~5 ns. ADR-015 H4
314/// (Wave 2b hard gate #2) requires per-barrier counter resolution to
315/// confirm-or-falsify the barrier-coalescing lever; xctrace TimeProfiler
316/// at 1 ms sampling cannot resolve `memory_barrier` even though it fires
317/// ~440×/token (`docs/ADR-015-mlx-native-single-cb-decode.md` §"P3a' live
318/// profile pass" hypothesis register row H4).
319static BARRIER_COUNT: AtomicU64 = AtomicU64::new(0);
320
321/// Total nanoseconds spent inside the `objc::msg_send!` barrier site,
322/// summed across all calls. ONLY updated when the env var
323/// `MLX_PROFILE_BARRIERS=1` is set on the process (cached on first
324/// `memory_barrier` call). When disabled the timing path is a single
325/// branch + the unconditional barrier dispatch — same hot-path cost as
326/// before this counter was added.
327///
328/// Why env-gated: timing adds 2 × `Instant::now()` (~50–100 ns each via
329/// `mach_absolute_time`) per barrier. At ~440 barriers/token that is
330/// ~22–44 µs/token of measurement overhead — comparable to what we are
331/// trying to measure. Production must keep this off; profiling runs
332/// opt-in.
333static BARRIER_NS: AtomicU64 = AtomicU64::new(0);
334
335/// Reset all counters to zero.
336pub fn reset_counters() {
337 SYNC_COUNT.store(0, Ordering::Relaxed);
338 DISPATCH_COUNT.store(0, Ordering::Relaxed);
339 CMD_BUF_COUNT.store(0, Ordering::Relaxed);
340 BARRIER_COUNT.store(0, Ordering::Relaxed);
341 BARRIER_NS.store(0, Ordering::Relaxed);
342 AUTO_BARRIER_COUNT.store(0, Ordering::Relaxed);
343 AUTO_BARRIER_CONCURRENT.store(0, Ordering::Relaxed);
344}
345
346/// Read the current value of `SYNC_COUNT`.
347///
348/// Each call to `commit_and_wait()` increments this counter.
349pub fn sync_count() -> u64 {
350 SYNC_COUNT.load(Ordering::Relaxed)
351}
352
353/// Read the current value of `DISPATCH_COUNT`.
354///
355/// Each call to `encode()`, `encode_threadgroups()`, or
356/// `encode_threadgroups_with_shared()` increments this counter.
357pub fn dispatch_count() -> u64 {
358 DISPATCH_COUNT.load(Ordering::Relaxed)
359}
360
361/// Per-pipeline dispatch bucket support (ADR-028 iter-284).
362///
363/// Env-gated via `MLX_DISP_BUCKET=1`. When enabled, every
364/// `encode*` call records its pipeline's label in a global hash map.
365/// This gives a per-kernel breakdown comparable to llama.cpp's
366/// instrumented dispatch site for finding *which* kernels make up
367/// the per-token dispatch budget.
368fn pipeline_buckets()
369 -> &'static std::sync::Mutex<std::collections::HashMap<String, u64>> {
370 static BUCKETS: std::sync::OnceLock<
371 std::sync::Mutex<std::collections::HashMap<String, u64>>,
372 > = std::sync::OnceLock::new();
373 BUCKETS.get_or_init(|| std::sync::Mutex::new(std::collections::HashMap::new()))
374}
375
376/// Cached env-flag check — single load on the hot path.
377fn pipeline_bucket_enabled() -> bool {
378 static CACHED: AtomicI8 = AtomicI8::new(-1);
379 let v = CACHED.load(Ordering::Relaxed);
380 if v >= 0 {
381 return v == 1;
382 }
383 let on = std::env::var("MLX_DISP_BUCKET").as_deref() == Ok("1");
384 CACHED.store(if on { 1 } else { 0 }, Ordering::Relaxed);
385 on
386}
387
388/// Record a dispatch into the per-pipeline bucket if the env-flag is on.
389/// Called from every `encode*` site alongside the `DISPATCH_COUNT` bump.
390#[inline]
391pub(crate) fn bucket_dispatch(pipeline: &ComputePipelineStateRef) {
392 if !pipeline_bucket_enabled() {
393 return;
394 }
395 let label = pipeline.label();
396 if label.is_empty() {
397 return;
398 }
399 if let Ok(mut t) = pipeline_buckets().lock() {
400 *t.entry(label.to_string()).or_insert(0) += 1;
401 }
402}
403
404/// Public dump of `MLX_DISP_BUCKET` data: `Vec<(label, count)>` sorted
405/// descending by count. Returns empty when env-flag is off / never
406/// recorded.
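///
/// Example consumer (assumes `MLX_DISP_BUCKET=1` was exported before the run;
/// `run_decode_tokens` is a placeholder for the caller's decode loop):
///
/// ```ignore
/// reset_pipeline_dispatch_buckets(); // drop prefill / warmup contributions
/// run_decode_tokens(32);
/// for (label, count) in pipeline_dispatch_buckets() {
///     println!("{count:>8}  {label}");
/// }
/// ```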
407pub fn pipeline_dispatch_buckets() -> Vec<(String, u64)> {
408 let mut v: Vec<(String, u64)> = if let Ok(t) = pipeline_buckets().lock() {
409 t.iter().map(|(k, v)| (k.clone(), *v)).collect()
410 } else {
411 Vec::new()
412 };
413 v.sort_by(|a, b| b.1.cmp(&a.1));
414 v
415}
416
417/// Reset the per-pipeline dispatch buckets (typically called at decode
418/// start to ignore prefill / warmup contributions).
419pub fn reset_pipeline_dispatch_buckets() {
420 if let Ok(mut t) = pipeline_buckets().lock() {
421 t.clear();
422 }
423}
424
425/// Read the current value of `CMD_BUF_COUNT`.
426///
427/// Each `CommandEncoder::new` (i.e. each `MlxDevice::command_encoder()`)
428/// increments this counter. Useful for diagnosing per-dispatch Metal
429/// command-buffer overhead in inner loops.
430pub fn cmd_buf_count() -> u64 {
431 CMD_BUF_COUNT.load(Ordering::Relaxed)
432}
433
434/// Read the current value of `BARRIER_COUNT`.
435///
436/// Each `memory_barrier()` call that reaches the underlying
437/// `objc::msg_send![encoder, memoryBarrierWithScope:]` site increments this
438/// counter. Capture-mode no-ops and pre-encoder no-ops are excluded.
439/// ADR-015 H4 hypothesis: ~440 barriers/token on the qwen35 decode hot
440/// path (verify against this counter).
441pub fn barrier_count() -> u64 {
442 BARRIER_COUNT.load(Ordering::Relaxed)
443}
444
445/// Read the total nanoseconds spent in the `memoryBarrierWithScope:`
446/// `objc::msg_send!` site. Only non-zero when `MLX_PROFILE_BARRIERS=1`
447/// was in the environment at the time of the first `memory_barrier()`
448/// call (the env check is cached on first use).
449///
450/// Combined with [`barrier_count`] this gives µs/barrier =
451/// `barrier_total_ns() / 1000 / barrier_count()`.
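///
/// A direct rendering of that formula (guarding the zero-barrier case):
///
/// ```ignore
/// let us_per_barrier = if barrier_count() > 0 {
///     barrier_total_ns() as f64 / 1_000.0 / barrier_count() as f64
/// } else {
///     0.0
/// };
/// ```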
452pub fn barrier_total_ns() -> u64 {
453 BARRIER_NS.load(Ordering::Relaxed)
454}
455
456/// Whether barrier timing is enabled (env-gated, cached on first check).
457///
458/// Reading the env var via `std::env::var` is itself non-trivial; using
459/// `OnceLock` caches the decision so the per-barrier branch is a single
460/// atomic-load + compare.
461fn barrier_profile_enabled() -> bool {
462 use std::sync::OnceLock;
463 static FLAG: OnceLock<bool> = OnceLock::new();
464 *FLAG.get_or_init(|| {
465 std::env::var("MLX_PROFILE_BARRIERS")
466 .map(|v| v == "1")
467 .unwrap_or(false)
468 })
469}
470
471/// Whether `MLX_UNRETAINED_REFS=1` is set in the process environment.
472///
473/// ADR-015 iter13 — when true, `CommandEncoder::new_with_residency` opens
474/// each `MTLCommandBuffer` via
475/// [`CommandQueueRef::new_command_buffer_with_unretained_references`]
476/// instead of the default `commandBuffer`. llama.cpp's per-token decode
477/// CBs use this same call (`/opt/llama.cpp/ggml/src/ggml-metal/`
478/// `ggml-metal-context.m:512` `[queue commandBufferWithUnretainedReferences]`)
479/// and gain ~3-5% wall on M-series GPUs by skipping per-buffer-binding ARC
480/// retains on submit.
481///
482/// **Caller-side prerequisite.** Every Metal buffer bound to a dispatch
483/// must outlive the CB — see the docstring on
484/// [`CommandEncoder::new_with_residency`] for the full caller contract.
485/// In hf2q, the per-decode-token `MlxBufferPool` (`buffer_pool.rs`)
486/// already keeps ARC clones alive in its `in_use` list across the entire
487/// decode token; routing transient scratches through that pool is the
488/// canonical way to satisfy the contract.
489///
490/// Cached on first read via `OnceLock` to keep the per-CB-construction
491/// branch single-atomic-load fast. Default OFF so any production decode
492/// run that does NOT explicitly set the var preserves retained-refs
493/// behavior verbatim.
494fn unretained_refs_enabled() -> bool {
495 use std::sync::OnceLock;
496 static FLAG: OnceLock<bool> = OnceLock::new();
497 *FLAG.get_or_init(|| {
498 std::env::var("MLX_UNRETAINED_REFS")
499 .map(|v| v == "1")
500 .unwrap_or(false)
501 })
502}
503
504/// Whether `HF2Q_PIPELINE_TG_MULT_HINT=1` is set. Cached on first read.
505///
506/// ADR-029 iter-175 Step 1q safety gate: when this flag is ON, every
507/// pipeline created via `KernelRegistry` has
508/// `threadGroupSizeIsMultipleOfThreadExecutionWidth(true)`. Apple's
509/// Metal spec says this is UB unless every dispatched threadgroup is
510/// a multiple of 32. `assert_tg_size_multiple_of_32_if_hinted()`
511/// asserts the constraint before each dispatch, converting UB to a
512/// safe panic with diagnostic info.
513fn pipeline_tg_mult_hint_enabled() -> bool {
514 use std::sync::OnceLock;
515 static FLAG: OnceLock<bool> = OnceLock::new();
516 *FLAG.get_or_init(|| {
517 std::env::var("HF2Q_PIPELINE_TG_MULT_HINT")
518 .map(|v| v == "1")
519 .unwrap_or(false)
520 })
521}
522
523/// ADR-029 iter-175 Step 1q — runtime safety check for
524/// `HF2Q_PIPELINE_TG_MULT_HINT=1`.
525///
526/// When the env flag is ON, the Metal pipeline descriptor sets
527/// `threadGroupSizeIsMultipleOfThreadExecutionWidth(true)`, which
528/// requires every dispatched threadgroup to have
529/// `tg.x * tg.y * tg.z % 32 == 0` on Apple silicon (where
530/// `threadExecutionWidth == 32`). Violating this is **undefined behavior**
531/// per the Metal spec — the GPU may produce garbage, hang, or panic.
532///
533/// This function panics with a clear message before dispatching so an
534/// offending site is caught immediately instead of silently corrupting
535/// output. Returns immediately (one atomic-load cost) when the env
536/// flag is OFF. Includes the pipeline's label in the panic message
537/// (Step 1r) so the offending kernel can be identified without a
538/// debugger.
539#[inline]
540fn assert_tg_size_multiple_of_32_if_hinted(
541 tg: MTLSize,
542 pipeline: &ComputePipelineStateRef,
543) {
544 if !pipeline_tg_mult_hint_enabled() {
545 return;
546 }
547 let total = tg.width.saturating_mul(tg.height).saturating_mul(tg.depth);
548 if total % 32 != 0 {
549 let label = pipeline.label();
550 panic!(
551 "ADR-029 Step 1q safety: HF2Q_PIPELINE_TG_MULT_HINT=1 requires \
552 threadgroup_size.x * y * z to be a multiple of 32 (Apple's \
553 threadExecutionWidth). Got tg=({}, {}, {}) → total={} → \
554 {} mod 32 = {}. Pipeline label: \"{}\". Either fix the \
555 dispatch site to use a multiple-of-32 threadgroup, or unset \
556 HF2Q_PIPELINE_TG_MULT_HINT.",
557 tg.width, tg.height, tg.depth, total, total, total % 32,
558 label
559 );
560 }
561}
562
563/// Whether `HF2Q_AUTO_BARRIER=1` is set in the process environment.
564///
565/// ADR-015 iter37 — when true, every [`CommandEncoder::dispatch_tracked`]
566/// call consults a [`MemRanges`](crate::mem_ranges::MemRanges) tracker
567/// and auto-emits a `memoryBarrierWithScope:` exactly when the new
568/// dispatch's read/write ranges conflict with previously-recorded
569/// ranges (mirrors llama.cpp's `ggml_metal_op_concurrency_check` at
570/// `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp:147-225`).
571/// When false, `dispatch_tracked` collapses to the same code path as
572/// `encode*` — no tracking, no auto-barriers — preserving sourdough
573/// behavior for any caller that opts into the tracked API but runs
574/// without the env gate.
575///
576/// Cached on first read via `OnceLock`. Default OFF — production
577/// decode/prefill keeps its hand-placed `enc.memory_barrier()` calls
578/// until the migration in iter38+.
579fn auto_barrier_enabled() -> bool {
580 use std::sync::OnceLock;
581 static FLAG: OnceLock<bool> = OnceLock::new();
582 *FLAG.get_or_init(|| {
583 std::env::var("HF2Q_AUTO_BARRIER")
584 .map(|v| v == "1")
585 .unwrap_or(false)
586 })
587}
588
589/// Number of `memory_barrier()` calls auto-emitted by
590/// [`CommandEncoder::dispatch_tracked`] under
591/// `HF2Q_AUTO_BARRIER=1`. Not independent of [`BARRIER_COUNT`]:
592/// auto-barriers go through `memory_barrier()` and therefore also bump
593/// `BARRIER_COUNT`, so this counter measures only the auto-emitted
594/// subset of that total.
595static AUTO_BARRIER_COUNT: AtomicU64 = AtomicU64::new(0);
596
597/// Number of `dispatch_tracked` calls whose mem-ranges check returned
598/// "concurrent" (no barrier needed). Together with
599/// [`AUTO_BARRIER_COUNT`] this measures the elision rate of the
600/// dataflow barrier: `concurrent / (concurrent + barriers)` is the
601/// fraction of dispatches that ran inside the previous concurrent
602/// group rather than starting a new one.
603static AUTO_BARRIER_CONCURRENT: AtomicU64 = AtomicU64::new(0);
604
605// ---------------------------------------------------------------------------
606// ADR-015 iter63 — per-dispatch GPU sampling support
607// ---------------------------------------------------------------------------
608
609/// Hard cap on per-CB sample-buffer sample count (Risk R4 in
610/// PROFILING-KIT-DESIGN §A.7).
611///
612/// Empirically verified on Apple Silicon (M-series, macOS 26): the
613/// underlying `MTLCounterSampleBufferDescriptor.sampleCount` is bounded
614/// by a per-buffer **byte-size** limit of 32768 B. At 8 bytes per
615/// `MTLCounterResultTimestamp` sample that maps to a sample-count
616/// ceiling of `32_768 / 8 = 4096`. We allocate two samples per
617/// dispatch (start + end), so this ceiling = 2048 dispatches per CB.
618/// Decode CBs (~120 dispatches) fit comfortably; long prefill CBs
619/// (~6K dispatches per design §A.7) will truncate after 2048 — see
620/// [`Self::sample_dispatch_pre`] for the truncation path. Future
621/// iter can chunk-resolve every 2K dispatches.
622///
623/// The original design constant of 32_768 (PROFILING-KIT-DESIGN §A.7)
624/// was based on Apple's documented ~64K-per-buffer "practical" limit,
625/// but the measured constraint on this hardware is the 32 KB byte
626/// budget. Setting the budget below that would underutilize the
627/// buffer; setting it above causes
628/// `newCounterSampleBufferWithDescriptor` to fail with `Invalid sample
629/// buffer length: <bytes> B. Expected range: 8 -> 32768`.
630const MAX_SAMPLES_PER_CB: u64 = 4096;
631
632/// Whether the per-CB warning about a missing `MTLCommonCounterSetTimestamp`
633/// has been emitted yet. Risk R1: if `device.counter_sets()` does not
634/// return a set named `"timestamp"` (case-insensitive), we degrade the
635/// per-dispatch path to a no-op and log once via stderr.
636static TIMESTAMP_SET_WARN_LOGGED: AtomicU64 = AtomicU64::new(0);
637
638/// Pending per-dispatch metadata that pairs with sample indices `2i`
639/// (start) and `2i+1` (end) inside the CB's `MTLCounterSampleBuffer`.
640/// Resolved by `CommandEncoder::resolve_dispatch_samples` at CB
641/// commit-time and converted to [`crate::kernel_profile::DispatchEntry`]
642/// before being pushed to the global table.
643#[derive(Clone, Debug)]
644struct PendingDispatchMeta {
645 op_kind: &'static str,
646 dispatch_index: u32,
647}
648
649/// Read the cumulative number of auto-emitted barriers across all
650/// encoders since process start (or last [`reset_counters`]).
651pub fn auto_barrier_count() -> u64 {
652 AUTO_BARRIER_COUNT.load(Ordering::Relaxed)
653}
654
655/// Read the cumulative number of `dispatch_tracked` calls that did NOT
656/// emit a barrier (ran concurrent with the previous group).
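///
/// Elision-rate sketch built from the two counters:
///
/// ```ignore
/// let barriers = auto_barrier_count() as f64;
/// let concurrent = auto_barrier_concurrent_count() as f64;
/// // Fraction of tracked dispatches that joined the previous concurrent group.
/// let elision_rate = concurrent / (concurrent + barriers);
/// ```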
657pub fn auto_barrier_concurrent_count() -> u64 {
658 AUTO_BARRIER_CONCURRENT.load(Ordering::Relaxed)
659}
660
661/// Issue the underlying Metal `memoryBarrierWithScope:` ObjC msg_send.
662///
663/// Held in its own `#[inline(never)]` function so xctrace / Instruments
664/// has a stable Rust frame to attribute barrier time against, separate
665/// from the surrounding encoder accounting. Per ADR-015 §P3a' Codex
666/// review Q2: TimeProfiler at 1 ms sampling cannot see this site when
667/// inlined; an explicit non-inline frame plus the [`BARRIER_NS`] counter
668/// closes the H4 hard gate.
669#[inline(never)]
670fn issue_metal_buffer_barrier(encoder: &ComputeCommandEncoderRef) {
671 // MTLBarrierScopeBuffers = 1 << 0 = 1.
672 const MTL_BARRIER_SCOPE_BUFFERS: u64 = 1;
673 unsafe {
674 let _: () =
675 objc::msg_send![encoder, memoryBarrierWithScope: MTL_BARRIER_SCOPE_BUFFERS];
676 }
677}
678
679/// A batched compute command encoder.
680///
681/// Keeps a single Metal `ComputeCommandEncoder` alive across multiple
682/// dispatches. The encoder is created on the first dispatch and ended
683/// only when the command buffer is committed. This mirrors candle's
684/// `compute_per_buffer` pattern and avoids per-dispatch encoder overhead.
685///
686/// # Typical usage
687///
688/// ```ignore
689/// let mut enc = device.command_encoder()?;
690/// // Multiple dispatches share the same compute encoder:
691/// enc.encode_threadgroups(pipeline1, &buffers1, tg1, tg_size1);
692/// enc.encode_threadgroups(pipeline2, &buffers2, tg2, tg_size2);
693/// enc.commit_and_wait()?;
694/// ```
695pub struct CommandEncoder {
696 cmd_buf: CommandBuffer,
697 /// Owned clone of the originating command queue.
698 ///
699 /// ADR-019 Phase 0b iter89e2-A: stored at `new_with_residency` time so
700 /// downstream lifecycle code (e.g. `EncoderSession::reset_for_next_stage`
701 /// in Phase 0b-B) can open a fresh `CommandBuffer` from the same queue
702 /// after a non-blocking `commit_stage()`. metal-rs 0.33's
703 /// `CommandQueue` type is `Send + Sync` via `foreign_obj_type!`
704 /// (`/Users/robert/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/metal-0.33.0/src/lib.rs:179`),
705 /// so adding this field preserves the existing unsafe `Send` impl
706 /// on `CommandEncoder` (declared below).
707 ///
708 /// ADR-019 Phase 0b iter89e2-B (CONSUMED): read by
709 /// [`Self::reset_command_buffer`] to spawn a fresh `CommandBuffer`
710 /// after a non-blocking `commit*` so `EncoderSession::reset_for_next_stage`
711 /// can chain stage CBs without re-constructing the encoder. Holding a
712 /// clone here (rather than a `&CommandQueue` borrow) avoids a lifetime
713 /// parameter on `CommandEncoder` that would propagate through every
714 /// consumer in mlx-native and hf2q.
715 queue: CommandQueue,
716 // SAFETY marker: see unsafe Send impl below.
717 /// Raw pointer to the persistent compute encoder.
718 /// Non-null when a compute pass is active.
719 /// The encoder borrows from `cmd_buf` but we cannot express this
720 /// lifetime in safe Rust, so we use a raw pointer.
721 /// SAFETY: the pointer is valid as long as `cmd_buf` is alive and
722 /// `end_encoding()` has not been called on it.
723 active_encoder: *const ComputeCommandEncoderRef,
724 /// When `Some`, dispatches are recorded here instead of being encoded
725 /// into Metal. Set via `start_capture()`, extracted via `take_capture()`.
726 capture: Option<Vec<CapturedNode>>,
727 /// Op kind tag for the NEXT captured dispatch. Set via `set_op_kind()`,
728 /// consumed (reset to `Other`) when a dispatch is captured.
729 pending_op_kind: CapturedOpKind,
730 /// Pending read buffer ranges for the NEXT captured dispatch.
731 /// Set via `set_pending_buffer_ranges()`, consumed when the next dispatch
732 /// is captured. Used by the reorder pass (Phase 4e.3).
733 pending_reads: Vec<MemRange>,
734 /// Pending write buffer ranges for the NEXT captured dispatch.
735 pending_writes: Vec<MemRange>,
736 /// ADR-015 iter8e (Phase 3b): residency set whose pending add/remove
737 /// staging is flushed at every `commit*` boundary.
738 ///
739 /// Cloned from the device at `device.command_encoder()` time. `None`
740 /// when residency sets are disabled (HF2Q_NO_RESIDENCY=1, macOS<15,
741 /// or test-only `CommandEncoder::new` from a residency-less queue).
742 residency_set: Option<ResidencySet>,
743 /// ADR-015 iter37: dataflow barrier inference state.
744 ///
745 /// Populated only when `HF2Q_AUTO_BARRIER=1` is set at process
746 /// start (cached via [`auto_barrier_enabled`]). Each
747 /// [`Self::dispatch_tracked`] call consults this state to decide
748 /// whether a Metal memory barrier is required; on conflict the
749 /// barrier is emitted, the state is reset, and the new dispatch's
750 /// ranges seed the next concurrent group. When the env gate is
751 /// off, `dispatch_tracked` collapses to its untracked equivalent
752 /// and this field is left empty for the encoder's lifetime.
753 ///
754 /// The field is always present (zero-sized when empty) so the
755 /// gate-off branch is a single bool-load + early return rather
756 /// than an allocation/Option indirection.
757 mem_ranges: MemRanges,
758 /// ADR-015 iter63 (per-dispatch profiling): the sample buffer for
759 /// `MTLCounterSampleBuffer.sampleCounters` calls that bracket every
760 /// `encode*` dispatch in this CB. Lazily allocated on first
761 /// dispatch when `MLX_PROFILE_DISPATCH=1`; `None` otherwise.
762 /// Released (set to `None`) inside `resolve_dispatch_samples` after
763 /// the CB completes — re-allocated on the next `encode*` if the env
764 /// gate stays set.
765 sample_buffer: Option<CounterSampleBuffer>,
766 /// ADR-015 iter63: pending per-dispatch metadata that pairs with
767 /// sample indices `2*i` and `2*i+1` inside `sample_buffer`. Each
768 /// `encode*` call appends one entry (when sampling is active);
769 /// `resolve_dispatch_samples` drains the vec at commit time.
770 pending_dispatch_meta: Vec<PendingDispatchMeta>,
771 /// ADR-015 iter63: 0-based dispatch ordinal within the current CB.
772 /// Incremented in every `encode*` site after taking the pending
773 /// op_kind; reset to 0 inside `resolve_dispatch_samples`.
774 dispatch_in_cb: u32,
775 /// ADR-015 iter63: most recent label set via `apply_labels`, used
776 /// as the per-dispatch `cb_label` field. `String::new()` until
777 /// `commit_and_wait_labeled` / `commit_labeled` is called.
778 last_label: String,
779}
780
781/// SAFETY: CommandEncoder is safe to Send across threads provided that:
782/// 1. Only one thread accesses the encoder at a time (exclusive ownership).
783/// 2. The encoder is not used concurrently from multiple threads.
784///
785/// Metal command buffers and compute encoders are thread-safe for exclusive
786/// access (Apple documentation: "You can create command buffers, encode
787/// commands, and submit them from any thread"). The raw pointer
788/// `active_encoder` borrows from `cmd_buf` and is valid as long as
789/// `cmd_buf` is alive — this invariant holds across thread boundaries
790/// because both fields move together.
791///
792/// This matches llama.cpp's pattern of encoding command buffers on GCD
793/// worker threads via `dispatch_apply`, and is used for the dual-buffer
794/// pipeline where buf1 is encoded on a worker thread while buf0 executes.
795unsafe impl Send for CommandEncoder {}
796
797impl CommandEncoder {
798 /// Create a new command encoder from the given command queue.
799 ///
800 /// This immediately creates a Metal command buffer.
801 ///
802 /// # Why retained references
803 ///
804 /// We use the regular `commandBuffer` (Metal retains every bound
805 /// resource for the lifetime of the buffer) rather than
806 /// `commandBufferWithUnretainedReferences`. llama.cpp uses unretained
807 /// refs for an additional perf bump (~3-5% on M-series GPUs), but the
808 /// hf2q dispatch pattern allocates many transient scratch buffers
809 /// inside helper functions (`apply_proj` → `weight_bf16_owned`,
810 /// `apply_pre_norm` → `params`, etc.) that go out of scope at the
811 /// helper's return. With unretained refs the metal::Buffer's ARC
812 /// drops to zero, freeing the underlying GPU memory before the
813 /// dispatch executes. Verified 2026-04-26: switching to unretained
814 /// hits "Command buffer error: GPU command buffer completed with
815 /// error status" on the first MoE FFN dispatch.
816 ///
817 /// To enable unretained refs in the future, every helper that
818 /// allocates and dispatches must thread its scratch buffers up to a
819 /// caller scope that outlives the eventual commit, OR all such
820 /// scratch must come from the per-decode-token pool (which already
821 /// ARC-retains in its in_use list). Today the lm_head + router-
822 /// download paths are still unpooled.
823 #[allow(dead_code)]
824 pub(crate) fn new(queue: &CommandQueue) -> Result<Self> {
825 Self::new_with_residency(queue, None)
826 }
827
828 /// Create a new command encoder, optionally bound to a residency set so
829 /// `commit*` boundaries can flush deferred add/remove staging.
830 ///
831 /// ADR-015 iter8e (Phase 3b): the encoder's `commit_and_wait`,
832 /// `commit_and_wait_labeled`, `commit`, `commit_labeled`,
833 /// `commit_wait_with_gpu_time` all call
834 /// [`ResidencySet::flush_pending`](ResidencySet::flush_pending) before
835 /// submitting the Metal command buffer. This converts the
836 /// per-allocation `[set commit]` storm
837 /// (~880 commits/decode-token in iter8d/8e claude+codex variants) into
838 /// at most one commit per CB submission — mirrors llama.cpp's
839 /// `ggml-metal-device.m:1378-1382` pattern (batch addAllocation in
840 /// loop, commit ONCE).
841 ///
842 /// ADR-015 iter13: when the `MLX_UNRETAINED_REFS=1` env var is set at
843 /// process start, this constructor uses
844 /// [`CommandQueueRef::new_command_buffer_with_unretained_references`]
845 /// instead of `new_command_buffer`. llama.cpp's per-token decode CBs
846 /// use `commandBufferWithUnretainedReferences` (see
847 /// `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m:512`) which
848 /// skips Metal's per-buffer-binding ARC-retain on submit and saves
849 /// ~3-5% on M-series GPUs (per the docstring above).
850 ///
851 /// **Caller contract under unretained refs.** Every Metal buffer bound
852 /// to a dispatch in this CB MUST outlive the CB's GPU completion. In
853 /// the hf2q decode path, that means every transient scratch must be
854 /// either (a) backed by the per-decode-token arena pool
855 /// (`MlxBufferPool` keeps an ARC clone in `in_use` until the next
856 /// `reset` — see `buffer_pool.rs:60`) or (b) hoisted to a caller scope
857 /// that lives across the terminal `commit_and_wait_labeled`. Helpers
858 /// in `apply_proj` / `apply_pre_norm` / lm_head cast / router-download
859 /// that allocated transients via `device.alloc_buffer` and dropped
860 /// them at function return MUST be lifted to `pooled_alloc_buffer`
861 /// before `MLX_UNRETAINED_REFS=1` is enabled, or the first MoE FFN
862 /// dispatch will crash with "Command buffer error: GPU command buffer
863 /// completed with error status" (verified 2026-04-26).
864 ///
865 /// The default (`MLX_UNRETAINED_REFS` unset) preserves retained-refs
866 /// behavior verbatim — this is the sourdough-safe path.
867 pub(crate) fn new_with_residency(
868 queue: &CommandQueue,
869 residency_set: Option<ResidencySet>,
870 ) -> Result<Self> {
871 let cmd_buf = if unretained_refs_enabled() {
872 queue.new_command_buffer_with_unretained_references().to_owned()
873 } else {
874 queue.new_command_buffer().to_owned()
875 };
876 CMD_BUF_COUNT.fetch_add(1, Ordering::Relaxed);
877 Ok(Self {
878 cmd_buf,
879 queue: queue.to_owned(),
880 active_encoder: std::ptr::null(),
881 capture: None,
882 pending_op_kind: CapturedOpKind::Other,
883 pending_reads: Vec::new(),
884 pending_writes: Vec::new(),
885 residency_set,
886 mem_ranges: MemRanges::new(),
887 sample_buffer: None,
888 pending_dispatch_meta: Vec::new(),
889 dispatch_in_cb: 0,
890 last_label: String::new(),
891 })
892 }
893
894 /// Enable capture mode.
895 ///
896 /// All subsequent dispatch and barrier calls will be recorded into a
897 /// `Vec<CapturedNode>` instead of being encoded into Metal.
898 /// Call `take_capture()` to extract the recorded nodes.
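    ///
    /// A record-then-replay sketch (the `ComputeGraph` API shown here is
    /// assumed from the module docs, not defined in this file):
    ///
    /// ```ignore
    /// enc.start_capture();
    /// encode_one_layer(&mut enc, ...);         // recorded, not encoded into Metal
    /// let nodes = enc.take_capture().unwrap(); // capture mode ends here
    /// let graph = ComputeGraph::new(nodes);    // hypothetical constructor
    /// graph.encode_sequential(&mut enc);       // replay into Metal
    /// ```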
899 pub fn start_capture(&mut self) {
900 self.capture = Some(Vec::with_capacity(128));
901 }
902
903 /// Whether the encoder is currently in capture mode.
904 pub fn is_capturing(&self) -> bool {
905 self.capture.is_some()
906 }
907
908 /// Extract the captured nodes, ending capture mode.
909 ///
910 /// Returns `None` if capture mode was not active.
911 pub fn take_capture(&mut self) -> Option<Vec<CapturedNode>> {
912 self.capture.take()
913 }
914
915 /// Tag the NEXT captured dispatch with the given operation kind.
916 ///
917 /// The tag is consumed (reset to `Other`) after the next dispatch is
918 /// captured. Only meaningful in capture mode — has no effect on
919 /// direct-dispatch encoding.
920 ///
921 /// Used by op dispatch functions to annotate captures for the fusion
922 /// pass (Phase 4e.2).
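    ///
    /// Capture-mode tagging sketch (pipeline and bindings are placeholders):
    ///
    /// ```ignore
    /// enc.start_capture();
    /// enc.set_op_kind(CapturedOpKind::RmsNorm);
    /// enc.encode_threadgroups_with_args(&rms_pipeline, &bindings, tg, tg_size);
    /// // The captured node carries op_kind == RmsNorm; the pending tag is now
    /// // reset to Other for the next dispatch.
    /// ```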
923 pub fn set_op_kind(&mut self, kind: CapturedOpKind) {
924 self.pending_op_kind = kind;
925 }
926
927 /// Consume and return the pending op kind, resetting it to `Other`.
928 fn take_pending_op_kind(&mut self) -> CapturedOpKind {
929 let kind = self.pending_op_kind;
930 self.pending_op_kind = CapturedOpKind::Other;
931 kind
932 }
933
934 /// Stash buffer range annotations for the NEXT captured dispatch.
935 ///
936 /// Called by `GraphSession::barrier_between()` in capture mode to record
937 /// which buffers the next dispatch reads from and writes to. The ranges
938 /// are consumed by the next `encode_*` call and attached to the captured
939 /// `CapturedNode::Dispatch`.
940 ///
941 /// Only meaningful in capture mode — has no effect on direct-dispatch.
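    ///
    /// Annotation sketch (ranges are computed the same way as
    /// `ranges_from_buffers` in this file; the buffers are placeholders):
    ///
    /// ```ignore
    /// let range = |b: &MlxBuffer| {
    ///     let start = b.contents_ptr() as usize + b.byte_offset() as usize;
    ///     (start, start + b.byte_len().saturating_sub(b.byte_offset() as usize))
    /// };
    /// enc.set_pending_buffer_ranges(vec![range(&input)], vec![range(&output)]);
    /// enc.encode_threadgroups(&pipeline, &buffers, tg, tg_size); // captured with ranges
    /// ```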
942 pub fn set_pending_buffer_ranges(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
943 self.pending_reads = reads;
944 self.pending_writes = writes;
945 }
946
947 /// Patch the last captured dispatch node's empty reads/writes with the
948 /// given ranges. No-op if not capturing, or if the last node isn't a
949 /// Dispatch, or if its ranges are already populated.
950 ///
951 /// Used by `GraphSession::track_dispatch` in recording mode to annotate
952 /// dispatches that were called without a preceding `barrier_between`.
953 pub fn annotate_last_dispatch_if_missing(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
954 if let Some(ref mut nodes) = self.capture {
955 if let Some(CapturedNode::Dispatch { reads: r, writes: w, .. }) = nodes.last_mut() {
956 if r.is_empty() && !reads.is_empty() {
957 *r = reads;
958 }
959 if w.is_empty() && !writes.is_empty() {
960 *w = writes;
961 }
962 }
963 }
964 }
965
966 /// Consume and return the pending buffer range annotations.
967 fn take_pending_buffer_ranges(&mut self) -> (Vec<MemRange>, Vec<MemRange>) {
968 let reads = std::mem::take(&mut self.pending_reads);
969 let writes = std::mem::take(&mut self.pending_writes);
970 (reads, writes)
971 }
972
973 /// Record buffer bindings into `RecordedBinding` form.
974 fn record_buffer_bindings(buffers: &[(u64, &MlxBuffer)]) -> Vec<(u64, RecordedBinding)> {
975 buffers
976 .iter()
977 .map(|&(index, buf)| {
978 (
979 index,
980 RecordedBinding::Buffer {
981 metal_buffer: buf.metal_buffer().clone(),
982 offset: buf.byte_offset(),
983 },
984 )
985 })
986 .collect()
987 }
988
989 /// Record `KernelArg` bindings into `RecordedBinding` form.
990 ///
991 /// `KernelArg::Buffer(buf)` records `buf.byte_offset()` so capture →
992 /// replay round-trips of `slice_view`-derived buffers preserve their
993 /// offsets, matching the behavior of `record_buffer_bindings` above.
994 fn record_arg_bindings(bindings: &[(u64, KernelArg<'_>)]) -> Vec<(u64, RecordedBinding)> {
995 bindings
996 .iter()
997 .map(|(index, arg)| {
998 let recorded = match arg {
999 KernelArg::Buffer(buf) => RecordedBinding::Buffer {
1000 metal_buffer: buf.metal_buffer().clone(),
1001 offset: buf.byte_offset(),
1002 },
1003 KernelArg::BufferWithOffset(buf, offset) => RecordedBinding::Buffer {
1004 metal_buffer: buf.metal_buffer().clone(),
1005 offset: *offset,
1006 },
1007 KernelArg::Bytes(bytes) => RecordedBinding::Bytes(bytes.to_vec()),
1008 };
1009 (*index, recorded)
1010 })
1011 .collect()
1012 }
1013
1014 /// Get or create the persistent compute encoder.
1015 ///
1016 /// On the first call, creates a new compute encoder from the command
1017 /// buffer. On subsequent calls, returns the existing one.
1018 ///
1019 /// SAFETY: The returned reference borrows from `self.cmd_buf` which is
1020 /// alive for the lifetime of this `CommandEncoder`. The raw pointer is
1021 /// valid until `end_active_encoder()` is called.
1022 #[inline]
1023 fn get_or_create_encoder(&mut self) -> &ComputeCommandEncoderRef {
1024 if self.active_encoder.is_null() {
1025 // Use MTLDispatchTypeConcurrent to allow independent dispatches
1026 // to overlap on the GPU. Memory barriers are inserted between
1027 // dependent dispatches via `memory_barrier()`.
1028 //
1029 // ADR-015 iter61a-2 probe: HF2Q_FORCE_SERIAL_DISPATCH=1 falls back
1030 // to MTLDispatchType::Serial — every dispatch waits for the
1031 // previous to complete, eliminating concurrent-dispatch race
1032 // windows. Used to falsify Hypothesis (g): missing memory_barrier
1033 // calls between dependent dispatches cause cold-run logit
1034 // non-determinism via thread-race on a shared buffer.
1035 let dispatch_type = if std::env::var("HF2Q_FORCE_SERIAL_DISPATCH")
1036 .map(|v| v == "1")
1037 .unwrap_or(false)
1038 {
1039 MTLDispatchType::Serial
1040 } else {
1041 MTLDispatchType::Concurrent
1042 };
1043 let encoder = self
1044 .cmd_buf
1045 .compute_command_encoder_with_dispatch_type(dispatch_type);
1046 self.active_encoder = encoder as *const ComputeCommandEncoderRef;
1047 }
1048 // SAFETY: active_encoder is non-null and points to a valid encoder
1049 // owned by cmd_buf.
1050 unsafe { &*self.active_encoder }
1051 }
1052
1053 /// End the active compute encoder if one exists.
1054 #[inline]
1055 fn end_active_encoder(&mut self) {
1056 if !self.active_encoder.is_null() {
1057 // SAFETY: the pointer was obtained from
1058 // cmd_buf.compute_command_encoder_with_dispatch_type() and has not been ended yet.
1059 unsafe { &*self.active_encoder }.end_encoding();
1060 self.active_encoder = std::ptr::null();
1061 }
1062 }
1063
1064 /// Insert a memory barrier with scope `MTLBarrierScopeBuffers`.
1065 ///
1066 /// When the encoder uses `MTLDispatchTypeConcurrent`, all dispatches can
1067 /// execute concurrently unless separated by a barrier. Call this between
1068 /// dispatches where the later dispatch reads a buffer written by an
1069 /// earlier one.
1070 ///
1071 /// This is the same pattern llama.cpp uses:
1072 /// `[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]`
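    ///
    /// Dependent-dispatch sketch (pipelines and buffers are placeholders):
    ///
    /// ```ignore
    /// enc.encode_threadgroups(&matmul, &[(0, &w), (1, &x), (2, &y)], tg, tgs);
    /// enc.memory_barrier(); // `y` must be fully written before the next read
    /// enc.encode_threadgroups(&softmax, &[(0, &y), (1, &probs)], tg2, tgs2);
    /// ```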
1073 #[allow(unexpected_cfgs)]
1074 pub fn memory_barrier(&mut self) {
1075 if let Some(ref mut nodes) = self.capture {
1076 nodes.push(CapturedNode::Barrier);
1077 return;
1078 }
1079 if self.active_encoder.is_null() {
1080 return;
1081 }
1082 BARRIER_COUNT.fetch_add(1, Ordering::Relaxed);
1083 // ADR-029 iter-175 Step 1e: when HF2Q_AUTO_BARRIER=1, hand-placed
1084 // barriers must reset the MemRanges tracker so partial migration to
1085 // `dispatch_tracked_*` stays correct. A hand-placed `memory_barrier()`
1086 // drains the GPU at this point; any tracked dispatch after it
1087 // should start with a fresh cumulative state instead of false-
1088 // conflicting against ranges recorded before this barrier.
1089 // No-op under default HF2Q_AUTO_BARRIER=0 (tracker is empty).
1090 if auto_barrier_enabled() {
1091 self.mem_ranges.reset();
1092 }
1093 // SAFETY: active_encoder is non-null and valid.
1094 let encoder = unsafe { &*self.active_encoder };
1095 if barrier_profile_enabled() {
1096 // mach_absolute_time path — only on when MLX_PROFILE_BARRIERS=1.
1097 let start = std::time::Instant::now();
1098 issue_metal_buffer_barrier(encoder);
1099 let elapsed_ns = start.elapsed().as_nanos() as u64;
1100 BARRIER_NS.fetch_add(elapsed_ns, Ordering::Relaxed);
1101 } else {
1102 issue_metal_buffer_barrier(encoder);
1103 }
1104 }
1105
1106 /// Set the compute pipeline state for subsequent dispatches.
1107 ///
1108 /// This begins a new compute pass if one is not already active.
1109 pub fn set_pipeline(&mut self, pipeline: &ComputePipelineStateRef) {
1110 let encoder = self.get_or_create_encoder();
1111 encoder.set_compute_pipeline_state(pipeline);
1112 }
1113
1114 /// Bind a buffer to a compute kernel argument slot.
1115 ///
1116 /// The `index` corresponds to the `[[buffer(N)]]` attribute in the MSL shader. Currently a no-op: the arguments are discarded; use the `encode*` methods instead.
1117 pub fn set_buffer(&self, index: u64, buffer: &MlxBuffer) {
1118 let _ = (index, buffer);
1119 }
1120
1121 /// Dispatch threads on the GPU. Currently a no-op: the arguments are discarded; use the `encode*` methods, which bind and dispatch in one call.
1122 pub fn dispatch_threads(&self, grid_size: MTLSize, threadgroup_size: MTLSize) {
1123 let _ = (grid_size, threadgroup_size);
1124 }
1125
1126 /// Encode a complete compute pass: set pipeline, bind buffers, dispatch.
1127 ///
1128 /// Reuses the persistent compute encoder — no per-dispatch encoder
1129 /// creation overhead.
1130 ///
1131 /// # Arguments
1132 ///
1133 /// * `pipeline` — The compiled compute pipeline to execute.
1134 /// * `buffers` — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
1135 /// * `grid_size` — Total number of threads to launch.
1136 /// * `threadgroup_size` — Threads per threadgroup.
1137 pub fn encode(
1138 &mut self,
1139 pipeline: &ComputePipelineStateRef,
1140 buffers: &[(u64, &MlxBuffer)],
1141 grid_size: MTLSize,
1142 threadgroup_size: MTLSize,
1143 ) {
1144 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1145 bucket_dispatch(pipeline);
1146 let op_kind = self.take_pending_op_kind();
1147 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1148 if let Some(ref mut nodes) = self.capture {
1149 nodes.push(CapturedNode::Dispatch {
1150 pipeline: pipeline.to_owned(),
1151 bindings: Self::record_buffer_bindings(buffers),
1152 threads_per_grid: grid_size,
1153 threads_per_threadgroup: threadgroup_size,
1154 threadgroup_memory: Vec::new(),
1155 dispatch_kind: DispatchKind::Threads,
1156 op_kind,
1157 reads: pending_reads,
1158 writes: pending_writes,
1159 });
1160 return;
1161 }
1162 // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1163 self.ensure_sample_buffer();
1164 let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1165 // SAFETY: encoder_ptr aliases &self via active_encoder which we
1166 // know is non-null after get_or_create_encoder; this pattern is
1167 // used throughout the file (see memory_barrier).
1168 let encoder = unsafe { &*encoder_ptr };
1169 encoder.set_compute_pipeline_state(pipeline);
1170 for &(index, buf) in buffers {
1171 encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
1172 }
1173 let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1174 assert_tg_size_multiple_of_32_if_hinted(threadgroup_size, pipeline);
1175 encoder.dispatch_threads(grid_size, threadgroup_size);
1176 self.sample_dispatch_post(encoder, pre_idx);
1177 }
1178
1179 /// Encode a compute pass using threadgroups instead of raw thread counts.
1180 ///
1181 /// Reuses the persistent compute encoder — no per-dispatch encoder
1182 /// creation overhead.
1183 pub fn encode_threadgroups(
1184 &mut self,
1185 pipeline: &ComputePipelineStateRef,
1186 buffers: &[(u64, &MlxBuffer)],
1187 threadgroups: MTLSize,
1188 threadgroup_size: MTLSize,
1189 ) {
1190 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1191 bucket_dispatch(pipeline);
1192 let op_kind = self.take_pending_op_kind();
1193 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1194 if let Some(ref mut nodes) = self.capture {
1195 nodes.push(CapturedNode::Dispatch {
1196 pipeline: pipeline.to_owned(),
1197 bindings: Self::record_buffer_bindings(buffers),
1198 threads_per_grid: threadgroups,
1199 threads_per_threadgroup: threadgroup_size,
1200 threadgroup_memory: Vec::new(),
1201 dispatch_kind: DispatchKind::ThreadGroups,
1202 op_kind,
1203 reads: pending_reads,
1204 writes: pending_writes,
1205 });
1206 return;
1207 }
1208 // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1209 self.ensure_sample_buffer();
1210 let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1211 // SAFETY: see encode() above.
1212 let encoder = unsafe { &*encoder_ptr };
1213 encoder.set_compute_pipeline_state(pipeline);
1214 for &(index, buf) in buffers {
1215 encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
1216 }
1217 let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1218 assert_tg_size_multiple_of_32_if_hinted(threadgroup_size, pipeline);
1219 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
1220 self.sample_dispatch_post(encoder, pre_idx);
1221 }
1222
1223 /// Encode a compute pass using threadgroups with shared threadgroup memory.
1224 ///
1225 /// Like [`encode_threadgroups`](Self::encode_threadgroups), but additionally
1226 /// allocates threadgroup memory at the specified indices. This is required
1227 /// for kernels that use `threadgroup` memory (e.g. reductions in rms_norm
1228 /// and softmax).
1229 ///
1230 /// # Arguments
1231 ///
1232 /// * `pipeline` — The compiled compute pipeline to execute.
1233 /// * `buffers` — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
1234 /// * `threadgroup_mem` — Slice of `(index, byte_length)` pairs for threadgroup memory.
1235 /// * `threadgroups` — Number of threadgroups to dispatch.
1236 /// * `threadgroup_size` — Threads per threadgroup.
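    ///
    /// # Example
    ///
    /// A rough rms_norm-style call (slot indices and sizes are illustrative only):
    ///
    /// ```ignore
    /// enc.encode_threadgroups_with_shared(
    ///     &rms_norm_pipeline,
    ///     &[(0, &input), (1, &gamma), (2, &output)],
    ///     &[(0, 32 * 4)],                    // 32 f32 values of threadgroup scratch
    ///     MTLSize::new(n_rows as u64, 1, 1), // one threadgroup per row
    ///     MTLSize::new(256, 1, 1),
    /// );
    /// ```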
1237 pub fn encode_threadgroups_with_shared(
1238 &mut self,
1239 pipeline: &ComputePipelineStateRef,
1240 buffers: &[(u64, &MlxBuffer)],
1241 threadgroup_mem: &[(u64, u64)],
1242 threadgroups: MTLSize,
1243 threadgroup_size: MTLSize,
1244 ) {
1245 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1246 bucket_dispatch(pipeline);
1247 let op_kind = self.take_pending_op_kind();
1248 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1249 if let Some(ref mut nodes) = self.capture {
1250 nodes.push(CapturedNode::Dispatch {
1251 pipeline: pipeline.to_owned(),
1252 bindings: Self::record_buffer_bindings(buffers),
1253 threads_per_grid: threadgroups,
1254 threads_per_threadgroup: threadgroup_size,
1255 threadgroup_memory: threadgroup_mem.to_vec(),
1256 dispatch_kind: DispatchKind::ThreadGroups,
1257 op_kind,
1258 reads: pending_reads,
1259 writes: pending_writes,
1260 });
1261 return;
1262 }
1263 // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1264 self.ensure_sample_buffer();
1265 let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1266 // SAFETY: see encode() above.
1267 let encoder = unsafe { &*encoder_ptr };
1268 encoder.set_compute_pipeline_state(pipeline);
1269 for &(index, buf) in buffers {
1270 encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
1271 }
1272 for &(index, byte_length) in threadgroup_mem {
1273 encoder.set_threadgroup_memory_length(index, byte_length);
1274 }
1275 let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1276 assert_tg_size_multiple_of_32_if_hinted(threadgroup_size, pipeline);
1277 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
1278 self.sample_dispatch_post(encoder, pre_idx);
1279 }
1280
1281 /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_threads).
1282 ///
1283 /// Reuses the persistent compute encoder.
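    ///
    /// Binding sketch (slot indices, offset, and params are illustrative):
    ///
    /// ```ignore
    /// enc.encode_with_args(
    ///     &pipeline,
    ///     &[
    ///         (0, KernelArg::Buffer(&input)),
    ///         (1, KernelArg::BufferWithOffset(&weights, 4096)),
    ///         (2, KernelArg::Bytes(as_bytes(&params))),
    ///     ],
    ///     MTLSize::new(total_threads, 1, 1),
    ///     MTLSize::new(256, 1, 1),
    /// );
    /// ```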
1284 pub fn encode_with_args(
1285 &mut self,
1286 pipeline: &ComputePipelineStateRef,
1287 bindings: &[(u64, KernelArg<'_>)],
1288 grid_size: MTLSize,
1289 threadgroup_size: MTLSize,
1290 ) {
1291 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1292 bucket_dispatch(pipeline);
1293 let op_kind = self.take_pending_op_kind();
1294 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1295 if let Some(ref mut nodes) = self.capture {
1296 nodes.push(CapturedNode::Dispatch {
1297 pipeline: pipeline.to_owned(),
1298 bindings: Self::record_arg_bindings(bindings),
1299 threads_per_grid: grid_size,
1300 threads_per_threadgroup: threadgroup_size,
1301 threadgroup_memory: Vec::new(),
1302 dispatch_kind: DispatchKind::Threads,
1303 op_kind,
1304 reads: pending_reads,
1305 writes: pending_writes,
1306 });
1307 return;
1308 }
1309 // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1310 self.ensure_sample_buffer();
1311 let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1312 // SAFETY: see encode() above.
1313 let encoder = unsafe { &*encoder_ptr };
1314 encoder.set_compute_pipeline_state(pipeline);
1315 apply_bindings(encoder, bindings);
1316 let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1317 assert_tg_size_multiple_of_32_if_hinted(threadgroup_size, pipeline);
1318 encoder.dispatch_threads(grid_size, threadgroup_size);
1319 self.sample_dispatch_post(encoder, pre_idx);
1320 }
1321
1322 /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_thread_groups).
1323 ///
1324 /// Reuses the persistent compute encoder.
1325 pub fn encode_threadgroups_with_args(
1326 &mut self,
1327 pipeline: &ComputePipelineStateRef,
1328 bindings: &[(u64, KernelArg<'_>)],
1329 threadgroups: MTLSize,
1330 threadgroup_size: MTLSize,
1331 ) {
1332 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1333 bucket_dispatch(pipeline);
1334 let op_kind = self.take_pending_op_kind();
1335 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1336 if let Some(ref mut nodes) = self.capture {
1337 nodes.push(CapturedNode::Dispatch {
1338 pipeline: pipeline.to_owned(),
1339 bindings: Self::record_arg_bindings(bindings),
1340 threads_per_grid: threadgroups,
1341 threads_per_threadgroup: threadgroup_size,
1342 threadgroup_memory: Vec::new(),
1343 dispatch_kind: DispatchKind::ThreadGroups,
1344 op_kind,
1345 reads: pending_reads,
1346 writes: pending_writes,
1347 });
1348 return;
1349 }
1350 // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1351 self.ensure_sample_buffer();
1352 let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1353 // SAFETY: see encode() above.
1354 let encoder = unsafe { &*encoder_ptr };
1355 encoder.set_compute_pipeline_state(pipeline);
1356 apply_bindings(encoder, bindings);
1357 let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1358 assert_tg_size_multiple_of_32_if_hinted(threadgroup_size, pipeline);
1359 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
1360 self.sample_dispatch_post(encoder, pre_idx);
1361 }
1362
1363 /// Encode a dispatch with mixed buffer/bytes bindings and shared memory.
1364 ///
1365 /// Reuses the persistent compute encoder.
1366 pub fn encode_threadgroups_with_args_and_shared(
1367 &mut self,
1368 pipeline: &ComputePipelineStateRef,
1369 bindings: &[(u64, KernelArg<'_>)],
1370 threadgroup_mem: &[(u64, u64)],
1371 threadgroups: MTLSize,
1372 threadgroup_size: MTLSize,
1373 ) {
1374 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1375 bucket_dispatch(pipeline);
1376 let op_kind = self.take_pending_op_kind();
1377 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1378 if let Some(ref mut nodes) = self.capture {
1379 nodes.push(CapturedNode::Dispatch {
1380 pipeline: pipeline.to_owned(),
1381 bindings: Self::record_arg_bindings(bindings),
1382 threads_per_grid: threadgroups,
1383 threads_per_threadgroup: threadgroup_size,
1384 threadgroup_memory: threadgroup_mem.to_vec(),
1385 dispatch_kind: DispatchKind::ThreadGroups,
1386 op_kind,
1387 reads: pending_reads,
1388 writes: pending_writes,
1389 });
1390 return;
1391 }
1392 // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1393 self.ensure_sample_buffer();
1394 let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1395 // SAFETY: see encode() above.
1396 let encoder = unsafe { &*encoder_ptr };
1397 encoder.set_compute_pipeline_state(pipeline);
1398 apply_bindings(encoder, bindings);
1399 for &(index, byte_length) in threadgroup_mem {
1400 encoder.set_threadgroup_memory_length(index, byte_length);
1401 }
1402 let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1403 assert_tg_size_multiple_of_32_if_hinted(threadgroup_size, pipeline);
1404 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
1405 self.sample_dispatch_post(encoder, pre_idx);
1406 }
1407
1408 // -----------------------------------------------------------------
1409 // ADR-015 iter37 — dataflow-driven auto-barrier dispatch family.
1410 //
1411 // These mirrors of `encode_threadgroups*_with_args*` take explicit
1412 // `reads: &[&MlxBuffer]` and `writes: &[&MlxBuffer]` slices. When
1413 // the process started with `HF2Q_AUTO_BARRIER=1`, the encoder's
1414 // [`MemRanges`] tracker checks the new ranges against the
1415 // cumulative state since the last barrier; on conflict it emits
1416 // `memory_barrier()` and resets the state before recording the
1417 // new ranges. When the env gate is unset, the check is skipped
1418 // entirely and the dispatch is applied identically to the
1419 // matching `encode_*` method — sourdough-safe by construction.
1420 //
1421 // Capture mode: the `reads`/`writes` ranges are recorded onto the
1422 // captured node via the existing `pending_reads`/`pending_writes`
1423 // mechanism, so a `dispatch_tracked` call inside capture mode is
1424 // equivalent to `set_pending_buffer_ranges + encode_*`.
1425 //
1426 // No production callsite migrates in iter37 — this is the API
1427 // surface the qwen35 forward path will adopt incrementally in
1428 // iter38+. Today, every call to `dispatch_tracked` from a
1429 // production code path lives behind an explicit caller decision
1430 // to opt in.
1431 // -----------------------------------------------------------------
1432
1433 /// Auto-barrier-aware dispatch with [`KernelArg`] bindings (uses
1434 /// `dispatch_thread_groups`).
1435 ///
1436 /// Behaves identically to
1437 /// [`encode_threadgroups_with_args`](Self::encode_threadgroups_with_args)
1438 /// when `HF2Q_AUTO_BARRIER` is unset. When set, consults the
1439 /// per-encoder [`MemRanges`] tracker:
1440 ///
1441 /// * Conflict (RAW/WAR/WAW on a same-buffer range) → emit
1442 /// `memory_barrier()`, increment [`AUTO_BARRIER_COUNT`], reset
1443 /// the tracker, then dispatch and seed the new concurrent group
1444 /// with this dispatch's ranges.
1445 /// * No conflict → increment [`AUTO_BARRIER_CONCURRENT`], record
1446 /// the ranges into the cumulative state, dispatch.
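    ///
    /// # Example
    ///
    /// A sketch of an opted-in call site (buffer and pipeline names are
    /// illustrative); the tracking only has an effect under
    /// `HF2Q_AUTO_BARRIER=1`:
    ///
    /// ```ignore
    /// encoder.dispatch_tracked_threadgroups_with_args(
    ///     &matmul_pipeline,
    ///     &[
    ///         (0, KernelArg::Buffer(&x)),
    ///         (1, KernelArg::Buffer(&w)),
    ///         (2, KernelArg::Buffer(&y)),
    ///     ],
    ///     &[&x, &w], // reads
    ///     &[&y],     // writes: a later dispatch touching `y` triggers an auto-barrier
    ///     MTLSize::new(n_groups, 1, 1),
    ///     MTLSize::new(128, 1, 1),
    /// );
    /// ```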
1447 pub fn dispatch_tracked_threadgroups_with_args(
1448 &mut self,
1449 pipeline: &ComputePipelineStateRef,
1450 bindings: &[(u64, KernelArg<'_>)],
1451 reads: &[&MlxBuffer],
1452 writes: &[&MlxBuffer],
1453 threadgroups: MTLSize,
1454 threadgroup_size: MTLSize,
1455 ) {
1456 // Capture mode: stash ranges + delegate to the standard encode.
1457 // The ranges flow through `pending_reads`/`pending_writes` and
1458 // attach to the captured `Dispatch` node — identical to what
1459 // `GraphSession::barrier_between` already does in capture mode.
1460 if self.is_capturing() {
1461 let read_ranges = ranges_from_buffers(reads);
1462 let write_ranges = ranges_from_buffers(writes);
1463 self.set_pending_buffer_ranges(read_ranges, write_ranges);
1464 self.encode_threadgroups_with_args(pipeline, bindings, threadgroups, threadgroup_size);
1465 return;
1466 }
1467
1468 if auto_barrier_enabled() {
1469 self.maybe_auto_barrier(reads, writes);
1470 }
1471
1472 self.encode_threadgroups_with_args(pipeline, bindings, threadgroups, threadgroup_size);
1473 }
1474
1475 /// Auto-barrier-aware dispatch with [`KernelArg`] bindings + shared
1476 /// threadgroup memory.
1477 ///
1478 /// See [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args)
1479 /// for the behavioral contract; this variant additionally takes a
1480 /// `threadgroup_mem` slice that is forwarded to
1481 /// [`encode_threadgroups_with_args_and_shared`](Self::encode_threadgroups_with_args_and_shared).
1482 ///
1483 /// The 8-argument signature mirrors the existing
1484 /// `encode_threadgroups_with_args_and_shared` plus the two
1485 /// dataflow slices; `clippy::too_many_arguments` is allowed
1486 /// because each parameter is load-bearing for either the dispatch
1487 /// (pipeline/bindings/threadgroups/threadgroup_size/shared_mem)
1488 /// or the auto-barrier (reads/writes).
1489 #[allow(clippy::too_many_arguments)]
1490 pub fn dispatch_tracked_threadgroups_with_args_and_shared(
1491 &mut self,
1492 pipeline: &ComputePipelineStateRef,
1493 bindings: &[(u64, KernelArg<'_>)],
1494 threadgroup_mem: &[(u64, u64)],
1495 reads: &[&MlxBuffer],
1496 writes: &[&MlxBuffer],
1497 threadgroups: MTLSize,
1498 threadgroup_size: MTLSize,
1499 ) {
1500 if self.is_capturing() {
1501 let read_ranges = ranges_from_buffers(reads);
1502 let write_ranges = ranges_from_buffers(writes);
1503 self.set_pending_buffer_ranges(read_ranges, write_ranges);
1504 self.encode_threadgroups_with_args_and_shared(
1505 pipeline,
1506 bindings,
1507 threadgroup_mem,
1508 threadgroups,
1509 threadgroup_size,
1510 );
1511 return;
1512 }
1513
1514 if auto_barrier_enabled() {
1515 self.maybe_auto_barrier(reads, writes);
1516 }
1517
1518 self.encode_threadgroups_with_args_and_shared(
1519 pipeline,
1520 bindings,
1521 threadgroup_mem,
1522 threadgroups,
1523 threadgroup_size,
1524 );
1525 }
1526
1527 /// Auto-barrier-aware dispatch using `(slot, &MlxBuffer)` bindings
1528 /// (uses `dispatch_thread_groups`).
1529 ///
1530 /// Convenience wrapper for callers that don't need
1531 /// [`KernelArg::Bytes`] inline-byte arguments. See
1532 /// [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args)
1533 /// for behavioral contract.
1534 pub fn dispatch_tracked_threadgroups(
1535 &mut self,
1536 pipeline: &ComputePipelineStateRef,
1537 buffers: &[(u64, &MlxBuffer)],
1538 reads: &[&MlxBuffer],
1539 writes: &[&MlxBuffer],
1540 threadgroups: MTLSize,
1541 threadgroup_size: MTLSize,
1542 ) {
1543 if self.is_capturing() {
1544 let read_ranges = ranges_from_buffers(reads);
1545 let write_ranges = ranges_from_buffers(writes);
1546 self.set_pending_buffer_ranges(read_ranges, write_ranges);
1547 self.encode_threadgroups(pipeline, buffers, threadgroups, threadgroup_size);
1548 return;
1549 }
1550
1551 if auto_barrier_enabled() {
1552 self.maybe_auto_barrier(reads, writes);
1553 }
1554
1555 self.encode_threadgroups(pipeline, buffers, threadgroups, threadgroup_size);
1556 }
1557
1558 /// Auto-barrier-aware dispatch using `(slot, &MlxBuffer)` bindings
1559 /// **plus shared threadgroup memory** (uses `dispatch_thread_groups`).
1560 ///
1561 /// Mirrors [`encode_threadgroups_with_shared`](Self::encode_threadgroups_with_shared)
1562 /// — convenience variant for kernels that allocate threadgroup
1563 /// memory (reductions in `rms_norm`, `softmax`, etc.) but don't
1564 /// need [`KernelArg::Bytes`] inline-byte arguments. See
1565 /// [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args)
1566 /// for the behavioral contract; the only addition here is the
1567 /// `threadgroup_mem` slice forwarded to the underlying encode.
1568 ///
1569 /// Closes the iter38-audit coverage gap: the 5 `rms_norm.rs`
1570 /// callsites (`/opt/mlx-native/src/ops/rms_norm.rs:124,236,443,
1571 /// 516,589`) all use `encode_threadgroups_with_shared` and need
1572 /// dataflow tracking when migrated to auto-barrier in iter40+.
1573 ///
1574    /// 8-argument signature (as with the `_with_args_and_shared` variant
1574    /// above); `clippy::too_many_arguments` is allowed
1575 /// because each parameter is load-bearing for either the dispatch
1576 /// (pipeline/buffers/threadgroups/threadgroup_size/shared_mem) or
1577 /// the auto-barrier (reads/writes).
1578 #[allow(clippy::too_many_arguments)]
1579 pub fn dispatch_tracked_threadgroups_with_shared(
1580 &mut self,
1581 pipeline: &ComputePipelineStateRef,
1582 buffers: &[(u64, &MlxBuffer)],
1583 threadgroup_mem: &[(u64, u64)],
1584 reads: &[&MlxBuffer],
1585 writes: &[&MlxBuffer],
1586 threadgroups: MTLSize,
1587 threadgroup_size: MTLSize,
1588 ) {
1589 if self.is_capturing() {
1590 let read_ranges = ranges_from_buffers(reads);
1591 let write_ranges = ranges_from_buffers(writes);
1592 self.set_pending_buffer_ranges(read_ranges, write_ranges);
1593 self.encode_threadgroups_with_shared(
1594 pipeline,
1595 buffers,
1596 threadgroup_mem,
1597 threadgroups,
1598 threadgroup_size,
1599 );
1600 return;
1601 }
1602
1603 if auto_barrier_enabled() {
1604 self.maybe_auto_barrier(reads, writes);
1605 }
1606
1607 self.encode_threadgroups_with_shared(
1608 pipeline,
1609 buffers,
1610 threadgroup_mem,
1611 threadgroups,
1612 threadgroup_size,
1613 );
1614 }
1615
1616 /// Auto-barrier-aware `dispatch_threads` variant with
1617 /// [`KernelArg`] bindings.
1618 ///
1619 /// Mirrors [`encode_with_args`](Self::encode_with_args) — the
1620 /// `dispatch_threads` (per-thread grid) flavor, as opposed to the
1621 /// `dispatch_thread_groups` flavor of
1622 /// [`dispatch_tracked_threadgroups_with_args`](Self::dispatch_tracked_threadgroups_with_args).
1623 /// See that method for the behavioral contract.
1624 ///
1625 /// Closes the iter38-audit coverage gap: callers that use
1626 /// per-thread grids — `rope.rs:108` (IMROPE), `sigmoid_mul.rs:76`
1627 /// (sigmoid-mul), and `encode_helpers.rs:41` (kv_cache_copy) —
1628 /// need a `dispatch_threads` flavor of the tracked dispatch
1629 /// because their grid sizes are expressed in threads, not
1630 /// threadgroups.
1631 ///
1632 /// Note: the simpler `(slot, &MlxBuffer)` form (from
1633 /// [`encode`](Self::encode)) is a special case of this method —
1634 /// callers can wrap each binding as `KernelArg::Buffer(buf)` to
1635 /// reuse this single tracked variant rather than introducing a
1636 /// fifth one.
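    ///
    /// For example (illustrative names), a buffer-only caller wraps each
    /// binding rather than needing a tracked flavor of `encode`:
    ///
    /// ```ignore
    /// encoder.dispatch_tracked_threads_with_args(
    ///     &kv_copy_pipeline,
    ///     &[(0, KernelArg::Buffer(&src)), (1, KernelArg::Buffer(&dst))],
    ///     &[&src],
    ///     &[&dst],
    ///     MTLSize::new(n_elems, 1, 1),
    ///     MTLSize::new(64, 1, 1),
    /// );
    /// ```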
1637 pub fn dispatch_tracked_threads_with_args(
1638 &mut self,
1639 pipeline: &ComputePipelineStateRef,
1640 bindings: &[(u64, KernelArg<'_>)],
1641 reads: &[&MlxBuffer],
1642 writes: &[&MlxBuffer],
1643 grid_size: MTLSize,
1644 threadgroup_size: MTLSize,
1645 ) {
1646 if self.is_capturing() {
1647 let read_ranges = ranges_from_buffers(reads);
1648 let write_ranges = ranges_from_buffers(writes);
1649 self.set_pending_buffer_ranges(read_ranges, write_ranges);
1650 self.encode_with_args(pipeline, bindings, grid_size, threadgroup_size);
1651 return;
1652 }
1653
1654 if auto_barrier_enabled() {
1655 self.maybe_auto_barrier(reads, writes);
1656 }
1657
1658 self.encode_with_args(pipeline, bindings, grid_size, threadgroup_size);
1659 }
1660
1661 /// Dispatch a pre-baked record.
1662 ///
1663 /// ADR-029 iter-175 Step 1d — fast path for decode hot kernels
1664 /// whose pipeline + threadgroup geometry + params bytes are
1665 /// load-time-immutable. `runtime_buffers` must be in the same
1666 /// order as `rec.buffer_slots`.
1667 ///
1668    /// Produces a Metal command stream equivalent to:
1669 /// ```ignore
1670 /// encoder.encode_threadgroups_with_args_and_shared(
1671 /// &rec.pipeline,
1672 /// bindings, // = runtime_buffers zipped with buffer_slots + (params_slot, Bytes(&rec.params_bytes))
1673 /// &rec.threadgroup_mem,
1674 /// rec.threadgroups,
1675 /// rec.threads_per_tg,
1676 /// );
1677 /// ```
1678 /// — but skips the kernel-name lookup, ggml_type match arms,
1679 /// MTLSize::new, and param-struct field stores that the unbaked
1680 /// path performs on every call.
1681 ///
1682 /// Capture mode and auto-barrier are supported identically to
1683 /// `encode_threadgroups_with_args_and_shared`. The caller is
1684 /// expected to have called `set_pending_buffer_ranges` (capture)
1685 /// or rely on auto-barrier for dataflow correctness before this
1686 /// call, matching the contract of the unbaked dispatch_tracked_*
1687 /// family.
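    ///
    /// # Usage sketch
    ///
    /// Illustrative only; `bake_qmatmul_record` is a hypothetical load-time
    /// helper, not an item in this crate:
    ///
    /// ```ignore
    /// // Once, at model load: bake pipeline, geometry, and params bytes.
    /// let rec: DispatchRecord = bake_qmatmul_record(&weight, ggml_type);
    /// // Per decode token: only the runtime buffers vary.
    /// encoder.dispatch_record(&rec, &[&hidden_in, &hidden_out]);
    /// ```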
1688 pub fn dispatch_record(
1689 &mut self,
1690 rec: &DispatchRecord,
1691 runtime_buffers: &[&MlxBuffer],
1692 ) {
1693 debug_assert_eq!(
1694 rec.buffer_slots.len(),
1695 runtime_buffers.len(),
1696 "dispatch_record: runtime_buffers count must match buffer_slots ({}); got {}",
1697 rec.buffer_slots.len(),
1698 runtime_buffers.len(),
1699 );
1700
1701 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
1702 bucket_dispatch(&rec.pipeline);
1703 let op_kind_override = self.take_pending_op_kind();
1704 // If a caller set an op_kind override via set_op_kind(), honor it;
1705 // otherwise use the baked op_kind from the record.
1706 let op_kind = if matches!(op_kind_override, CapturedOpKind::Other) {
1707 rec.op_kind
1708 } else {
1709 op_kind_override
1710 };
1711 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
1712
1713 if let Some(ref mut nodes) = self.capture {
1714 // Reconstruct bindings for replay — runtime buffers first,
1715 // params bytes (if any) last. Order matches what the
1716 // baked-path runtime encoding produces below.
1717 let cap = runtime_buffers.len() + if rec.params_bytes.is_empty() { 0 } else { 1 };
1718 let mut bindings: Vec<(u64, RecordedBinding)> = Vec::with_capacity(cap);
1719 for (slot, buf) in rec.buffer_slots.iter().zip(runtime_buffers.iter()) {
1720 bindings.push((
1721 *slot,
1722 RecordedBinding::Buffer {
1723 metal_buffer: buf.metal_buffer().to_owned(),
1724 offset: buf.byte_offset(),
1725 },
1726 ));
1727 }
1728 if !rec.params_bytes.is_empty() {
1729 bindings.push((
1730 rec.params_slot,
1731 RecordedBinding::Bytes(rec.params_bytes.clone()),
1732 ));
1733 }
1734 nodes.push(CapturedNode::Dispatch {
1735 pipeline: rec.pipeline.clone(),
1736 bindings,
1737 threads_per_grid: rec.threadgroups,
1738 threads_per_threadgroup: rec.threads_per_tg,
1739 threadgroup_memory: rec.threadgroup_mem.clone(),
1740 dispatch_kind: DispatchKind::ThreadGroups,
1741 op_kind,
1742 reads: pending_reads,
1743 writes: pending_writes,
1744 });
1745 return;
1746 }
1747
1748 // ADR-015 iter63: per-dispatch sampling (no-op when env unset).
1749 self.ensure_sample_buffer();
1750 let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1751 // SAFETY: see encode() above — encoder reference outlives this scope
1752 // because `get_or_create_encoder` only mutates the `Option` wrapper.
1753 let encoder = unsafe { &*encoder_ptr };
1754 encoder.set_compute_pipeline_state(&rec.pipeline);
1755 for (slot, buf) in rec.buffer_slots.iter().zip(runtime_buffers.iter()) {
1756 encoder.set_buffer(*slot, Some(buf.metal_buffer()), buf.byte_offset());
1757 }
1758 if !rec.params_bytes.is_empty() {
1759 encoder.set_bytes(
1760 rec.params_slot,
1761 rec.params_bytes.len() as u64,
1762 rec.params_bytes.as_ptr() as *const _,
1763 );
1764 }
1765 for &(idx, len) in rec.threadgroup_mem.iter() {
1766 encoder.set_threadgroup_memory_length(idx, len);
1767 }
1768 let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1769 // Skip assert_tg_size_multiple_of_32_if_hinted: bake-time
1770 // construction already validated the geometry.
1771 encoder.dispatch_thread_groups(rec.threadgroups, rec.threads_per_tg);
1772 self.sample_dispatch_post(encoder, pre_idx);
1773 }
1774
1775 /// Run the dataflow check, emit a barrier on conflict, and record
1776 /// the dispatch's ranges into the cumulative state.
1777 ///
1778 /// Always called *before* the underlying `encode_*` method
1779 /// applies the dispatch. Mirrors lines 220-225 of
1780 /// `ggml-metal-ops.cpp` (`concurrency_check + concurrency_reset +
1781 /// concurrency_add` around each node).
1782 fn maybe_auto_barrier(
1783 &mut self,
1784 reads: &[&MlxBuffer],
1785 writes: &[&MlxBuffer],
1786 ) {
1787 if self.mem_ranges.check_dispatch(reads, writes) {
1788 // Concurrent — no barrier needed; just record the new ranges.
1789 self.mem_ranges.add_dispatch(reads, writes);
1790 AUTO_BARRIER_CONCURRENT.fetch_add(1, Ordering::Relaxed);
1791 } else {
1792 // Conflict — emit barrier, reset state, seed new group.
1793 //
1794 // `memory_barrier()` itself increments `BARRIER_COUNT` and,
1795 // when `MLX_PROFILE_BARRIERS=1`, accumulates `BARRIER_NS`.
1796 // We additionally bump `AUTO_BARRIER_COUNT` so the
1797 // "auto-emitted vs hand-placed" subset is queryable.
1798 self.memory_barrier();
1799 self.mem_ranges.reset();
1800 self.mem_ranges.add_dispatch(reads, writes);
1801 AUTO_BARRIER_COUNT.fetch_add(1, Ordering::Relaxed);
1802 }
1803 }
1804
1805 /// Force a barrier and reset the auto-barrier tracker.
1806 ///
1807 /// Use at boundaries where the caller knows a barrier is required
1808 /// regardless of dataflow — typically before reading data back to
1809 /// CPU, or at the end of an op group whose internal dependencies
1810 /// the tracker can't see (e.g. host-driven memcpy).
1811 ///
1812 /// Equivalent to `memory_barrier()` plus a `MemRanges::reset()`
1813 /// when `HF2Q_AUTO_BARRIER=1`; equivalent to plain
1814 /// `memory_barrier()` otherwise.
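    ///
    /// A sketch of the readback boundary (the readback helper is
    /// hypothetical):
    ///
    /// ```ignore
    /// encoder.force_barrier_and_reset_tracker();
    /// encoder.commit_and_wait()?;
    /// let logits: Vec<f32> = read_back_f32(&output); // hypothetical helper
    /// ```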
1815 pub fn force_barrier_and_reset_tracker(&mut self) {
1816 self.memory_barrier();
1817 if auto_barrier_enabled() {
1818 self.mem_ranges.reset();
1819 }
1820 }
1821
1822 /// Diagnostic accessor — number of ranges currently recorded in
1823 /// this encoder's [`MemRanges`] tracker. Always zero unless
1824 /// `HF2Q_AUTO_BARRIER=1` and at least one `dispatch_tracked` call
1825 /// has fired since the last conflict.
1826 #[inline]
1827 pub fn mem_ranges_len(&self) -> usize {
1828 self.mem_ranges.len()
1829 }
1830
1831 /// Replay a single captured dispatch node into this encoder.
1832 ///
1833 /// This is the inverse of capture: it takes a previously recorded
1834 /// `CapturedNode::Dispatch` and encodes it into the live Metal encoder.
1835 /// Barrier nodes are handled by the caller (ComputeGraph::encode_sequential).
1836 ///
1837 /// Does NOT increment `DISPATCH_COUNT` — that was already counted at
1838 /// capture time.
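    ///
    /// Caller-side sketch (simplified; the graph layer also interleaves
    /// barrier sentinels, omitted here):
    ///
    /// ```ignore
    /// for node in captured_nodes {
    ///     if let CapturedNode::Dispatch {
    ///         pipeline, bindings, threadgroup_memory,
    ///         threads_per_grid, threads_per_threadgroup, dispatch_kind, ..
    ///     } = node
    ///     {
    ///         encoder.replay_dispatch(
    ///             &pipeline, &bindings, &threadgroup_memory,
    ///             threads_per_grid, threads_per_threadgroup, dispatch_kind,
    ///         );
    ///     }
    /// }
    /// ```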
1839 pub fn replay_dispatch(
1840 &mut self,
1841 pipeline: &ComputePipelineStateRef,
1842 bindings: &[(u64, RecordedBinding)],
1843 threadgroup_memory: &[(u64, u64)],
1844 threads_per_grid: MTLSize,
1845 threads_per_threadgroup: MTLSize,
1846 dispatch_kind: DispatchKind,
1847 ) {
1848 // ADR-015 iter63 (Phase A.3): mirror the per-dispatch sampling
1849 // scaffold here so capture-mode-recorded graphs (graph.rs
1850 // encode_sequential / encode_with_barriers / encode_chunk_with
1851 // _barriers) still produce per-dispatch entries. The replay
1852 // path bypasses encode*; without this hook the per-dispatch
1853 // table would be silently empty for any model that uses
1854 // `GraphExecutor::begin_recorded`.
1855 //
1856 // Captured `op_kind` is forwarded via `pending_op_kind`: the
1857 // graph replay layer at graph.rs:197/236/727 sets it from the
1858 // CapturedNode.op_kind before calling replay_dispatch.
1859 self.ensure_sample_buffer();
1860 let op_kind = self.take_pending_op_kind();
1861 let encoder_ptr = self.get_or_create_encoder() as *const ComputeCommandEncoderRef;
1862 // SAFETY: see encode() above.
1863 let encoder = unsafe { &*encoder_ptr };
1864 encoder.set_compute_pipeline_state(pipeline);
1865 for (index, binding) in bindings {
1866 match binding {
1867 RecordedBinding::Buffer { metal_buffer, offset } => {
1868 encoder.set_buffer(*index, Some(metal_buffer), *offset);
1869 }
1870 RecordedBinding::Bytes(bytes) => {
1871 encoder.set_bytes(
1872 *index,
1873 bytes.len() as u64,
1874 bytes.as_ptr() as *const _,
1875 );
1876 }
1877 }
1878 }
1879 for &(index, byte_length) in threadgroup_memory {
1880 encoder.set_threadgroup_memory_length(index, byte_length);
1881 }
1882 let pre_idx = self.sample_dispatch_pre(encoder, op_kind);
1883 match dispatch_kind {
1884 DispatchKind::Threads => {
1885 assert_tg_size_multiple_of_32_if_hinted(threads_per_threadgroup, pipeline);
1886 encoder.dispatch_threads(threads_per_grid, threads_per_threadgroup);
1887 }
1888 DispatchKind::ThreadGroups => {
1889 assert_tg_size_multiple_of_32_if_hinted(threads_per_threadgroup, pipeline);
1890 encoder.dispatch_thread_groups(threads_per_grid, threads_per_threadgroup);
1891 }
1892 }
1893 self.sample_dispatch_post(encoder, pre_idx);
1894 }
1895
1896 /// Flush any pending residency-set add/remove staging.
1897 ///
1898 /// Hooked at every commit boundary so per-allocation
1899 /// [`ResidencySet::add_allocation`](ResidencySet::add_allocation) and
1900 /// [`ResidencySet::remove_allocation`](ResidencySet::remove_allocation)
1901 /// calls (as fired by `MlxDevice::alloc_buffer` and
1902 /// `MlxBufferStorage::Drop`) collapse into at most ONE `[set commit]`
1903 /// per CB submission. Mirrors llama.cpp's
1904 /// `ggml-metal-device.m:1378-1382` (batch addAllocation in loop,
1905 /// commit ONCE).
1906 #[inline]
1907 fn flush_residency_pending(&self) {
1908 if let Some(set) = self.residency_set.as_ref() {
1909 set.flush_pending();
1910 }
1911 }
1912
1913 // ----------------------------------------------------------------
1914 // ADR-015 iter63 — per-dispatch sample buffer lifecycle
1915 // ----------------------------------------------------------------
1916
1917 /// Allocate the per-CB `MTLCounterSampleBuffer` if it has not been
1918 /// allocated yet for this CB.
1919 ///
1920 /// No-op when `MLX_PROFILE_DISPATCH` is unset, when the buffer is
1921 /// already present, or when the device does not expose a counter
1922 /// set named `"timestamp"` (Risk R1 — graceful degrade with a
1923 /// one-shot stderr warning).
1924 ///
1925 /// The sample buffer is sized to [`MAX_SAMPLES_PER_CB`] (32_768).
1926    /// This budgets start/end sample pairs — i.e. ≤ 16,384 dispatches
1927 /// per CB. Above that ceiling, additional dispatches will skip
1928 /// sampling (see [`Self::sample_dispatch_pre`]).
1929 #[inline]
1930 fn ensure_sample_buffer(&mut self) {
1931 if !crate::kernel_profile::is_dispatch_enabled() {
1932 return;
1933 }
1934 if self.sample_buffer.is_some() {
1935 return;
1936 }
1937 // Discover the timestamp counter set. metal-rs 0.33 does not
1938 // export the `MTLCommonCounterSetTimestamp` constant, so we
1939 // name-match `"timestamp"` case-insensitively. Reach the
1940 // device via the cmd_buf's `device` selector (metal-rs 0.33
1941 // exposes `CommandQueue::device` but not `CommandBuffer::device`,
1942 // so we go through ObjC directly).
1943 let device: &metal::DeviceRef = unsafe {
1944 let cb = &*self.cmd_buf;
1945 msg_send![cb, device]
1946 };
1947 // ADR-015 iter63 — Apple Silicon hardware constraint (NEW Risk
1948 // discovered at impl time, supersedes design §A.7). M-series
1949 // GPUs (verified: AGXG17XFamilyComputeContext = M5 Max series,
1950 // macOS 26) only support counter sampling AtStageBoundary —
1951 // i.e. between compute *passes*, not between dispatches inside
1952 // a persistent compute encoder. Calling
1953 // `sampleCountersInBuffer:atSampleIndex:withBarrier:` on such
1954 // hardware aborts with `failed assertion ... not supported on
1955 // this device`. The persistent-encoder design (mlx-native uses
1956 // ONE compute encoder per CB to amortize ~800 encoder
1957 // create/end cycles per forward pass — see `get_or_create_
1958 // encoder` docstring) is incompatible with stage-boundary-only
1959 // sampling, so on Apple Silicon we degrade per-dispatch
1960 // profiling to a no-op and log once. Per-CB profiling is
1961 // unaffected (it uses MTLCommandBuffer.GPUStartTime/
1962 // GPUEndTime, which are always available).
1963 //
1964 // Future: if Apple ever ships AtDispatchBoundary support on
1965 // Apple Silicon, this branch becomes a true cap check. For
1966        // now, the kit infrastructure is in place; only sampling-point
1967        // support is missing on this hardware.
1968 if !device.supports_counter_sampling(MTLCounterSamplingPoint::AtDispatchBoundary) {
1969 if TIMESTAMP_SET_WARN_LOGGED
1970 .compare_exchange(0, 1, Ordering::Relaxed, Ordering::Relaxed)
1971 .is_ok()
1972 {
1973 eprintln!(
1974 "[mlx-native] MLX_PROFILE_DISPATCH=1 ignored: \
1975 device {:?} does NOT support \
1976 MTLCounterSamplingPointAtDispatchBoundary \
1977 (Apple Silicon limitation; only AtStageBoundary \
1978 is supported, which is incompatible with the \
1979 persistent compute-encoder pattern). \
1980 MLX_PROFILE_CB=1 still produces per-CB GPU times.",
1981 device.name()
1982 );
1983 }
1984 return;
1985 }
1986 let counter_sets = device.counter_sets();
1987 let timestamp_set = counter_sets
1988 .iter()
1989 .find(|c: &&metal::CounterSet| c.name().eq_ignore_ascii_case("timestamp"));
1990 let timestamp_set = match timestamp_set {
1991 Some(s) => s,
1992 None => {
1993 // Risk R1: device does not expose a timestamp set.
1994 // Log once and degrade to no-op (sample_buffer stays None).
1995 if TIMESTAMP_SET_WARN_LOGGED
1996 .compare_exchange(0, 1, Ordering::Relaxed, Ordering::Relaxed)
1997 .is_ok()
1998 {
1999 eprintln!(
2000 "[mlx-native] MLX_PROFILE_DISPATCH=1 ignored: \
2001 device {:?} exposes no MTLCommonCounterSetTimestamp",
2002 device.name()
2003 );
2004 }
2005 return;
2006 }
2007 };
2008 // Build descriptor. StorageMode::Shared is required by
2009 // resolveCounterRange (MTLCounters.h:185-188).
2010 let descriptor = CounterSampleBufferDescriptor::new();
2011 descriptor.set_counter_set(timestamp_set);
2012 descriptor.set_storage_mode(MTLStorageMode::Shared);
2013 descriptor.set_label("mlx_native.dispatch_samples");
2014 descriptor.set_sample_count(MAX_SAMPLES_PER_CB);
2015 match device.new_counter_sample_buffer_with_descriptor(&descriptor) {
2016 Ok(buf) => {
2017 self.sample_buffer = Some(buf);
2018 }
2019 Err(e) => {
2020 if TIMESTAMP_SET_WARN_LOGGED
2021 .compare_exchange(0, 1, Ordering::Relaxed, Ordering::Relaxed)
2022 .is_ok()
2023 {
2024 eprintln!(
2025 "[mlx-native] MLX_PROFILE_DISPATCH=1 ignored: \
2026 newCounterSampleBufferWithDescriptor failed: {}",
2027 e
2028 );
2029 }
2030 self.sample_buffer = None;
2031 }
2032 }
2033 }
2034
2035 /// Insert the start-of-dispatch counter sample (sample index `2*i`)
2036    /// and queue the per-dispatch metadata. Returns `Some(i)`, the
2037    /// dispatch ordinal, so the caller can emit the matching post-sample.
2038    ///
2039    /// Returns `None` when sampling is inactive or the per-CB sample
2040    /// ceiling has been hit; [`Self::sample_dispatch_post`] treats `None`
2041    /// as a skip, so the value is only consumed when sampling is active.
2042 ///
2043 /// `with_barrier:true` is mandatory: the encoder uses
2044 /// `MTLDispatchTypeConcurrent` and without the barrier the start
2045 /// timestamp would race against any in-flight dispatch (PROFILING-
2046 /// KIT-DESIGN §A.5).
2047 #[inline]
2048 fn sample_dispatch_pre(
2049 &mut self,
2050 encoder: &ComputeCommandEncoderRef,
2051 op_kind: CapturedOpKind,
2052 ) -> Option<u32> {
2053 let sb = self.sample_buffer.as_ref()?;
2054 let i = self.dispatch_in_cb;
2055 let pre_idx = (i as u64).checked_mul(2)?;
2056 if pre_idx >= MAX_SAMPLES_PER_CB {
2057 // Ceiling exceeded — skip sampling for the remainder of
2058 // this CB. Risk R4 (PROFILING-KIT-DESIGN §A.7): future
2059 // iter can chunk-resolve every N dispatches; for now we
2060 // accept truncation with a one-shot warning (re-uses the
2061 // R1 warn flag).
2062 return None;
2063 }
2064 encoder.sample_counters_in_buffer(sb, pre_idx, true);
2065 self.pending_dispatch_meta.push(PendingDispatchMeta {
2066 op_kind: op_kind.name(),
2067 dispatch_index: i,
2068 });
2069 Some(i)
2070 }
2071
2072 /// Insert the end-of-dispatch counter sample (sample index `2*i+1`)
2073 /// matching the most recent [`Self::sample_dispatch_pre`].
2074 ///
2075 /// No-op when sampling is inactive or when `pre_idx` is `None`.
2076 #[inline]
2077 fn sample_dispatch_post(
2078 &mut self,
2079 encoder: &ComputeCommandEncoderRef,
2080 pre_idx: Option<u32>,
2081 ) {
2082 let i = match pre_idx {
2083 Some(v) => v,
2084 None => return,
2085 };
2086 let sb = match self.sample_buffer.as_ref() {
2087 Some(b) => b,
2088 None => return,
2089 };
2090 let post_idx = match (i as u64).checked_mul(2).and_then(|v| v.checked_add(1)) {
2091 Some(v) if v < MAX_SAMPLES_PER_CB => v,
2092 _ => return,
2093 };
2094 encoder.sample_counters_in_buffer(sb, post_idx, true);
2095 // Bump the per-CB ordinal only after both samples committed
2096 // successfully so a truncation skip leaves the meta queue
2097 // length matching the buffer's resolved range.
2098 self.dispatch_in_cb = i.saturating_add(1);
2099 }
2100
2101 /// Resolve the per-CB sample buffer, push entries into
2102 /// [`crate::kernel_profile`], and reset per-CB state.
2103 ///
2104 /// Called from [`Self::commit_and_wait_labeled`] after the CB
2105 /// completes; the caller is responsible for ensuring the GPU has
2106 /// finished (otherwise `resolveCounterRange` returns garbage).
2107 ///
2108    /// Every resolve also captures a `(cpu_ns, gpu_ticks)` pair via
2109    /// `device.sampleTimestamps` so subsequent ticks→ns conversion uses a
2110    /// fresh scale factor (the device call is cheap and keeps the timebase
2111    /// robust across CBs).
2112 fn resolve_dispatch_samples(&mut self, cb_label: &str) -> Result<()> {
2113 let sb = match self.sample_buffer.take() {
2114 Some(b) => b,
2115 None => {
2116 self.pending_dispatch_meta.clear();
2117 self.dispatch_in_cb = 0;
2118 return Ok(());
2119 }
2120 };
2121 let n = self.pending_dispatch_meta.len();
2122 if n == 0 {
2123 self.dispatch_in_cb = 0;
2124 return Ok(());
2125 }
2126 // Refresh the (cpu, gpu) scale pair on every resolve; the
2127 // device call is cheap and keeps us robust against driver-side
2128 // timebase changes between CBs.
2129 let mut cpu_t: u64 = 0;
2130 let mut gpu_t: u64 = 0;
2131 let device: &metal::DeviceRef = unsafe {
2132 let cb = &*self.cmd_buf;
2133 msg_send![cb, device]
2134 };
2135 device.sample_timestamps(&mut cpu_t, &mut gpu_t);
2136 crate::kernel_profile::record_clock_pair(cpu_t, gpu_t);
2137 let length = (n as u64).saturating_mul(2);
2138 let data = sb.resolve_counter_range(NSRange {
2139 location: 0,
2140 length,
2141 });
2142 // `resolve_counter_range` returns one NSUInteger per sample.
2143 // Pair them up: data[2i] = start, data[2i+1] = end.
2144 for (i, meta) in self.pending_dispatch_meta.drain(..).enumerate() {
2145 let start_idx = 2 * i;
2146 let end_idx = 2 * i + 1;
2147 if end_idx >= data.len() {
2148 break;
2149 }
2150 let start_raw = data[start_idx] as u64;
2151 let end_raw = data[end_idx] as u64;
2152 let start_ns = crate::kernel_profile::convert_gpu_ticks_to_ns(start_raw);
2153 let end_ns = crate::kernel_profile::convert_gpu_ticks_to_ns(end_raw);
2154 let gpu_ns = end_ns.saturating_sub(start_ns);
2155 crate::kernel_profile::record_dispatch(
2156 crate::kernel_profile::DispatchEntry {
2157 cb_label: cb_label.to_string(),
2158 op_kind: meta.op_kind,
2159 dispatch_index: meta.dispatch_index,
2160 gpu_ns,
2161 start_gpu_ns: start_ns,
2162 end_gpu_ns: end_ns,
2163 },
2164 );
2165 }
2166 // Buffer dropped at end of scope releases the underlying
2167 // CounterSampleBuffer; per-CB lifetime correctly bounded.
2168 drop(sb);
2169 self.dispatch_in_cb = 0;
2170 Ok(())
2171 }
2172
2173 /// Commit the command buffer and block until the GPU finishes execution.
2174 ///
2175 /// # Errors
2176 ///
2177 /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
2178 pub fn commit_and_wait(&mut self) -> Result<()> {
2179 SYNC_COUNT.fetch_add(1, Ordering::Relaxed);
2180
2181 // End the persistent compute encoder before committing.
2182 self.end_active_encoder();
2183
2184 // ADR-015 iter8e (Phase 3b): flush deferred residency-set
2185 // add/remove staging so the residency hint covers any buffers
2186 // referenced by this CB. Single commit per CB boundary; no-op
2187 // when no residency set or no staged changes.
2188 self.flush_residency_pending();
2189
2190 self.cmd_buf.commit();
2191 self.cmd_buf.wait_until_completed();
2192
2193 match self.cmd_buf.status() {
2194 MTLCommandBufferStatus::Completed => Ok(()),
2195 MTLCommandBufferStatus::Error => {
2196 Err(MlxError::CommandBufferError(
2197 "GPU command buffer completed with error status".into(),
2198 ))
2199 }
2200 status => Err(MlxError::CommandBufferError(format!(
2201 "Unexpected command buffer status after wait: {:?}",
2202 status
2203 ))),
2204 }
2205 }
2206
2207 /// Commit + wait, accumulating GPU wall-clock time under `label` into
2208 /// the [`crate::kernel_profile`] global table when `MLX_PROFILE_CB=1`
2209 /// is set. When the env var is unset, this is identical to
2210 /// [`commit_and_wait`](Self::commit_and_wait) — zero overhead.
2211 ///
2212 /// Used by hf2q's decode hot path to attribute per-cb GPU time to
2213 /// labeled phases (per-layer attn, per-layer ffn, output_head, etc.)
2214 /// without manually wiring `commit_wait_with_gpu_time` everywhere.
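    ///
    /// A sketch (the phase label is illustrative):
    ///
    /// ```ignore
    /// // With MLX_PROFILE_CB=1 this attributes the CB's GPU time to
    /// // "layer.attn"; with the env var unset it is plain commit_and_wait().
    /// encoder.commit_and_wait_labeled("layer.attn")?;
    /// ```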
2215 ///
2216 /// # Errors
2217 ///
2218 /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
2219 pub fn commit_and_wait_labeled(&mut self, label: &str) -> Result<()> {
2220 // ADR-015 iter16 — propagate `label` to MTLCommandBuffer.setLabel and
2221 // (if a compute encoder is active) MTLComputeCommandEncoder.setLabel
2222 // BEFORE end_encoding/commit so xctrace's
2223 // `metal-application-encoders-list` table populates `cmdbuffer-label`
2224 // and `encoder-label` columns with the semantic phase name (e.g.
2225 // `layer.attn_moe_ffn`, `output_head.fused_norm_lm_argmax`,
2226 // `layer.delta_net.ops1-9`). Joined to per-CB GPU duration via
2227 // `metal-gpu-submission-to-command-buffer-id` (sub_id ↔ encoder_id) →
2228 // `metal-gpu-execution-points` (per-dispatch start/end), this enables
2229 // per-phase µs/token attribution comparing hf2q vs llama side-by-side
2230 // (iter15 §E "iter16 ATTRIBUTION PATH"). Cost is a single ObjC
2231 // msg_send per CB submission — sub-µs on M5 Max — and a no-op when
2232 // xctrace isn't recording, so this is unconditionally safe to call on
2233 // the production decode hot path.
2234 self.apply_labels(label);
2235 // ADR-015 iter63: record GPU time AND resolve per-dispatch samples
2236 // when either env gate is set. Per-dispatch sampling force-enables
2237 // the per-CB path so cross-validation per Risk R3 always has a
2238 // ground-truth comparator.
2239 let need_gpu_time =
2240 crate::kernel_profile::is_enabled() || crate::kernel_profile::is_dispatch_enabled();
2241 if need_gpu_time {
2242 let (start_s, end_s) = self.commit_wait_with_gpu_time()?;
2243 let ns = ((end_s - start_s).max(0.0) * 1_000_000_000.0) as u64;
2244 if crate::kernel_profile::is_enabled() {
2245 crate::kernel_profile::record(label, ns);
2246 }
2247 if crate::kernel_profile::is_dispatch_enabled() {
2248 self.resolve_dispatch_samples(label)?;
2249 }
2250 Ok(())
2251 } else {
2252 self.commit_and_wait()
2253 }
2254 }
2255
2256 /// Async commit, but with profiling label. When `MLX_PROFILE_CB=1`
2257    /// is set, redirects to a synchronous [`commit_and_wait_labeled`]
2258    /// call to capture per-CB GPU time (profiling intentionally defeats
2259    /// async pipelining; profile-mode is slow but informative).
2260    /// When unset, identical to [`commit`](Self::commit).
2261 pub fn commit_labeled(&mut self, label: &str) {
2262 // ADR-015 iter16 — see `commit_and_wait_labeled` for rationale.
2263 if crate::kernel_profile::is_enabled() {
2264 // Profile mode: force sync to capture GPU time. apply_labels is
2265 // called inside commit_and_wait_labeled — do NOT call it twice
2266 // here (would double the ObjC msg_send under MLX_PROFILE_CB=1).
2267 // Errors are logged via stderr because the void return matches
2268 // commit().
2269 if let Err(e) = self.commit_and_wait_labeled(label) {
2270 eprintln!("[mlx-native] commit_labeled({}) failed: {}", label, e);
2271 }
2272 } else {
2273 // Async path: apply labels here so xctrace MST traces capture
2274 // per-CB phase attribution under default decode (no
2275 // `MLX_PROFILE_CB`).
2276 self.apply_labels(label);
2277 self.commit();
2278 }
2279 }
2280
2281 /// Apply `label` to the underlying `MTLCommandBuffer` and, if a compute
2282 /// encoder is currently active, to the `MTLComputeCommandEncoder`.
2283 ///
2284 /// Called from [`commit_labeled`] and [`commit_and_wait_labeled`] BEFORE
2285 /// the encoder is ended / the CB is committed so xctrace's
2286 /// `metal-application-encoders-list` table picks up the label on the
2287 /// row emitted at the encoder's `endEncoding` / CB submission boundary.
2288 /// Single ObjC `msg_send` per call (two if an encoder is active); sub-µs
2289 /// on M5 Max; no-op when xctrace isn't recording.
2290 ///
2291 /// Skipped (debug-only assert) if `label` is empty — empty labels would
2292    /// produce a trace row indistinguishable from the metal-rs default
2293 /// `Command Buffer 0` placeholder.
2294 #[inline]
2295 fn apply_labels(&mut self, label: &str) {
2296 debug_assert!(!label.is_empty(), "commit_*_labeled called with empty label");
2297 if label.is_empty() {
2298 return;
2299 }
2300 self.cmd_buf.set_label(label);
2301 if !self.active_encoder.is_null() {
2302 // SAFETY: active_encoder is non-null and points to a live encoder
2303 // owned by cmd_buf — same invariant as get_or_create_encoder /
2304 // memory_barrier. set_label is a single property write on the
2305 // ObjC object; safe before endEncoding.
2306 unsafe { &*self.active_encoder }.set_label(label);
2307 }
2308 // ADR-015 iter63: capture the most recent label for per-dispatch
2309 // entries. Cheap String allocation — only happens at CB commit
2310 // boundaries, not per dispatch.
2311 self.last_label.clear();
2312 self.last_label.push_str(label);
2313 }
2314
2315 /// Commit + wait, returning `(gpu_start_s, gpu_end_s)` CFTimeInterval
2316 /// timestamps from `MTLCommandBuffer`'s `GPUStartTime`/`GPUEndTime`
2317 /// properties. Both are mach-absolute CFTimeInterval seconds (double).
2318 ///
2319 /// Intended for `HF2Q_PROFILE_GPU_TS=1` per-bucket GPU wall-clock
2320 /// attribution. Adds exactly two ObjC property reads per call on top
2321 /// of the regular `commit_and_wait` — measured well under 1 μs on
2322 /// M5 Max.
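    ///
    /// Conversion to nanoseconds mirrors what the labeled commit path does
    /// internally:
    ///
    /// ```ignore
    /// let (start_s, end_s) = encoder.commit_wait_with_gpu_time()?;
    /// let gpu_ns = ((end_s - start_s).max(0.0) * 1_000_000_000.0) as u64;
    /// ```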
2323 ///
2324 /// # Errors
2325 ///
2326 /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
2327 pub fn commit_wait_with_gpu_time(&mut self) -> Result<(f64, f64)> {
2328 self.commit_and_wait()?;
2329 // SAFETY: cmd_buf is a valid MTLCommandBuffer that has been
2330 // committed and awaited. GPUStartTime / GPUEndTime return
2331 // CFTimeInterval (double precision seconds). See
2332 // https://developer.apple.com/documentation/metal/mtlcommandbuffer/1639925-gpustarttime
2333 let (gpu_start, gpu_end): (f64, f64) = unsafe {
2334 let cb = &*self.cmd_buf;
2335 let s: f64 = msg_send![cb, GPUStartTime];
2336 let e: f64 = msg_send![cb, GPUEndTime];
2337 (s, e)
2338 };
2339 Ok((gpu_start, gpu_end))
2340 }
2341
2342 /// Commit the command buffer WITHOUT blocking.
2343 ///
2344 /// The GPU begins executing the encoded commands immediately. Call
2345 /// [`wait_until_completed`](Self::wait_until_completed) later to block
2346 /// the CPU and check for errors. This allows the CPU to continue doing
2347 /// other work (e.g. preparing the next batch) while the GPU runs.
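    ///
    /// Typical overlap pattern (the CPU-side work is illustrative):
    ///
    /// ```ignore
    /// encoder.commit();                 // GPU starts on the encoded work
    /// prepare_next_batch();             // hypothetical CPU-side work
    /// encoder.wait_until_completed()?;  // block and surface any GPU error
    /// ```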
2348 pub fn commit(&mut self) {
2349 self.end_active_encoder();
2350 // ADR-015 iter8e (Phase 3b): same flush hook as commit_and_wait —
2351 // this is the async-pipeline path that production decode uses.
2352 self.flush_residency_pending();
2353 self.cmd_buf.commit();
2354 }
2355
2356 /// Block until a previously committed command buffer completes.
2357 ///
2358 /// Must be called after [`commit`](Self::commit). Do not call after
2359 /// [`commit_and_wait`](Self::commit_and_wait) — that method already waits.
2360 ///
2361 /// # Errors
2362 ///
2363 /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
2364 pub fn wait_until_completed(&self) -> Result<()> {
2365 self.cmd_buf.wait_until_completed();
2366 match self.cmd_buf.status() {
2367 MTLCommandBufferStatus::Completed => Ok(()),
2368 MTLCommandBufferStatus::Error => Err(MlxError::CommandBufferError(
2369 "GPU command buffer completed with error status".into(),
2370 )),
2371 status => Err(MlxError::CommandBufferError(format!(
2372 "Unexpected command buffer status after wait: {:?}",
2373 status
2374 ))),
2375 }
2376 }
2377
2378 /// Borrow the underlying Metal command buffer.
2379 #[inline]
2380 pub fn metal_command_buffer(&self) -> &CommandBuffer {
2381 &self.cmd_buf
2382 }
2383
2384 /// Borrow the residency set bound to this encoder, if one exists.
2385 ///
2386 /// ADR-019 Phase 0b iter89e2-B: exposed `pub(crate)` so
2387 /// [`crate::EncoderSession`] can route caller-driven add/remove
2388 /// requests through the same `Arc<ResidencySetInner>` the encoder
2389 /// itself flushes at every `commit*` boundary. The single-set
2390 /// invariant from `device.rs::MlxDevice` is preserved — both the
2391 /// encoder's `flush_residency_pending` and the session's delegated
2392 /// add/remove operate on the SAME residency set. Returns `None` when
2393 /// residency sets are disabled (HF2Q_NO_RESIDENCY=1, macOS<15, or
2394 /// `CommandEncoder::new` from a residency-less queue).
2395 #[inline]
2396 pub(crate) fn residency_set(&self) -> Option<&ResidencySet> {
2397 self.residency_set.as_ref()
2398 }
2399
2400 /// Reopen `cmd_buf` with a fresh `CommandBuffer` from the originating queue.
2401 ///
2402 /// ADR-019 Phase 0b iter89e2-B: enables multi-stage chaining. After a
2403 /// non-blocking `commit*` has handed the prior CB to Metal, this method
2404 /// rotates `cmd_buf` to a freshly-allocated CB on the same queue and
2405 /// resets every per-CB scratch field so the next dispatch is encoded
2406 /// onto the new CB.
2407 ///
2408 /// # Caller contract
2409 ///
2410 /// Only valid when `active_encoder.is_null()` (the persistent compute
2411 /// encoder must have been ended via `end_active_encoder()`, which both
2412 /// `commit_and_wait` and `commit` already do). Calling this method
2413 /// while a compute encoder is open would leak the encoder (the new
2414 /// `cmd_buf` does not own it) and trip Metal's "Command encoder
2415 /// released without endEncoding" assertion when the prior `cmd_buf`
2416 /// drops. Callers are [`crate::EncoderSession::reset_for_next_stage`]
2417 /// only — the session has already committed before invoking this.
2418 ///
2419 /// # F2 / F11 / F12 fence preservation
2420 ///
2421 /// - **F2 — residency-rescission**: this method does NOT re-flush
2422 /// the residency set. The prior `commit*` already flushed; staged
2423 /// add/remove since then will flush at the next `commit*` on the
2424 /// new CB. The residency-set Arc clone is preserved.
2425 /// - **F11 — zero-init alloc_buffer**: untouched (no buffer allocs).
2426 /// - **F12 — `HF2Q_FORCE_SERIAL_DISPATCH`**: the new CB will lazily
2427 /// open its compute encoder via `get_or_create_encoder`, which
2428 /// re-reads the env var; the falsification probe still fires on
2429 /// the new CB.
2430 ///
2431 /// # Counter semantics
2432 ///
2433 /// Bumps `CMD_BUF_COUNT` exactly once per call, matching the
2434 /// `new_with_residency` accounting. Does NOT bump `SYNC_COUNT` (no
2435 /// commit/wait happens here).
2436 pub(crate) fn reset_command_buffer(&mut self) {
2437 debug_assert!(
2438 self.active_encoder.is_null(),
2439 "reset_command_buffer called with an active compute encoder \
2440 — caller must commit (which calls end_active_encoder) first"
2441 );
2442 let cmd_buf = if unretained_refs_enabled() {
2443 self.queue
2444 .new_command_buffer_with_unretained_references()
2445 .to_owned()
2446 } else {
2447 self.queue.new_command_buffer().to_owned()
2448 };
2449 CMD_BUF_COUNT.fetch_add(1, Ordering::Relaxed);
2450 self.cmd_buf = cmd_buf;
2451 // Per-CB scratch state — every field that's documented as being
2452 // bounded by a CB lifetime resets here.
2453 self.active_encoder = std::ptr::null();
2454 self.dispatch_in_cb = 0;
2455 self.last_label.clear();
2456 self.pending_dispatch_meta.clear();
2457 // `mem_ranges` is a per-CB barrier inference state; clearing on
2458 // CB rotation matches the `commit_and_wait` post-commit invariant
2459        // (any new CB starts with no pending hazards). Assigning a
2460        // fresh `MemRanges::new()` here avoids exposing a dedicated
2461        // `clear` method on the tracker.
2462 self.mem_ranges = MemRanges::new();
2463 // `sample_buffer` is dropped explicitly inside
2464 // `resolve_dispatch_samples` after a CB completes; we leave it
2465 // in whatever state the prior commit left it (typically `None`
2466 // after `commit_and_wait` finishes). A stale `Some` here would
2467 // be visible only under `MLX_PROFILE_DISPATCH=1` which fires its
2468 // own one-shot warning; not worth a special case.
2469 // `capture` (if Some) persists across CB rotation — capture mode
2470 // accumulates across stages within a session by design.
2471 // `pending_op_kind` / `pending_reads` / `pending_writes` only
2472 // hold tags for the NEXT dispatch and are consumed when that
2473 // dispatch fires — leaving them as-is is correct.
2474 }
2475
2476 /// Encode an `MTLSharedEvent` wait at `value` on the current CB.
2477 ///
2478 /// ADR-019 Phase 0b iter89e2-B: pairs with [`Self::encode_signal_event`]
2479 /// to express the inter-CB ordering D3 stage boundaries need. The new
2480 /// CB's GPU work blocks until the prior CB's signal lands on the same
2481 /// event at >= `value`.
2482 ///
2483 /// # Caller contract
2484 ///
2485 /// Must be called BEFORE any compute encoder is opened on the new
2486 /// CB — the wait is a CB-level op that must precede every dispatch
2487 /// in the new CB to actually order them. [`crate::EncoderSession::reset_for_next_stage`]
2488 /// fires this immediately after `reset_command_buffer`, before any
2489 /// dispatch lazy-opens the encoder.
2490 #[inline]
2491 pub(crate) fn encode_wait_for_event(&self, event: &metal::EventRef, value: u64) {
2492 debug_assert!(
2493 self.active_encoder.is_null(),
2494 "encode_wait_for_event called with an open compute encoder \
2495 — wait must precede the first dispatch on the new CB"
2496 );
2497 self.cmd_buf.encode_wait_for_event(event, value);
2498 }
2499
2500 /// End the active compute encoder, encode a stage-fence signal, and
2501 /// commit the CB non-blocking — atomically from the caller's view.
2502 ///
2503 /// ADR-019 Phase 0b iter89e2-B: this is the helper
2504 /// [`crate::EncoderSession::fence_stage`] uses to thread the signal
2505 /// between the encoder-end and the CB-commit boundaries that
2506 /// `commit_labeled` would otherwise serialize. Sequence:
2507 ///
2508 /// 1. End the persistent compute encoder (so `encodeSignalEvent:` is
2509    ///    encoded at CB-level, not encoder-level — Metal's validation
2510    ///    layer only accepts `encodeSignalEvent:` outside an active
2511    ///    encoder pass).
2512 /// 2. Apply `label` (when `Some`) to the CB. Note: at this point
2513 /// the encoder is already ended, so the encoder's own
2514 /// `setLabel:` is a no-op site — only the CB label propagates.
2515 /// `last_label` and per-dispatch profiling keep working as
2516 /// documented.
2517    /// 3. Encode `encodeSignalEvent:value:` with `new_value` at CB-level.
2518 /// 4. Flush the residency-set pending staging (matches the
2519 /// `commit_labeled` / `commit` flush at encoder.rs:2004).
2520 /// 5. Commit the CB non-blocking (matches `commit()` at
2521 /// encoder.rs:2026).
2522 ///
2523 /// # Counter semantics
2524 ///
2525 /// Bumps `SYNC_COUNT` zero times (non-blocking). Bumps
2526 /// `CMD_BUF_COUNT` zero times (no new CB allocated here —
2527 /// [`Self::reset_command_buffer`] does that on the next stage).
2528 ///
2529 /// # Errors
2530 ///
2531 /// Infallible (matches `commit()` semantics — errors surface only
2532 /// at `wait_until_completed`).
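    ///
    /// # Stage-boundary sketch
    ///
    /// How a session might chain two CBs on one queue (abridged; the event
    /// value and label are illustrative):
    ///
    /// ```ignore
    /// // Stage N: signal the fence and hand the CB to Metal, non-blocking.
    /// encoder.fence_signal_and_commit(&event, stage_n, Some("layer.attn"));
    /// // Stage N+1: rotate to a fresh CB, then gate its work on the signal.
    /// encoder.reset_command_buffer();
    /// encoder.encode_wait_for_event(&event, stage_n);
    /// ```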
2533 pub(crate) fn fence_signal_and_commit(
2534 &mut self,
2535 event: &metal::EventRef,
2536 new_value: u64,
2537 label: Option<&str>,
2538 ) {
2539 // Step 1: end the active compute encoder. encode_signal_event's
2540 // debug_assert requires this be done first.
2541 self.end_active_encoder();
2542 // Step 2: apply the CB label so xctrace MST attribution still
2543 // works on the fenced CB. apply_labels' debug_assert against
2544 // empty labels matches commit_labeled's semantics.
2545 if let Some(l) = label {
2546 self.apply_labels(l);
2547 }
2548 // Step 3: encode the signal at CB-level.
2549 self.cmd_buf.encode_signal_event(event, new_value);
2550 // Step 4 + 5: same as commit() — flush residency staging, then
2551 // hand the CB to Metal.
2552 self.flush_residency_pending();
2553 self.cmd_buf.commit();
2554 }
2555}
2556
2557impl Drop for CommandEncoder {
2558 fn drop(&mut self) {
2559 // End the persistent compute encoder before the command buffer
2560 // is dropped, otherwise Metal will assert:
2561 // "Command encoder released without endEncoding"
2562 self.end_active_encoder();
2563 }
2564}