mlx_native/encoder.rs
1//! [`CommandEncoder`] — batched GPU command submission.
2//!
3//! Wraps a Metal command buffer. Encode one or more compute kernel dispatches,
4//! then call [`commit_and_wait`](CommandEncoder::commit_and_wait) to submit the
5//! entire batch and block until the GPU finishes.
6//!
7//! # Persistent compute encoder
8//!
9//! A single Metal `ComputeCommandEncoder` is kept alive across multiple
10//! dispatches within the same command buffer. This avoids the overhead of
11//! creating and ending a new compute encoder per dispatch — the same pattern
12//! candle uses (`compute_per_buffer`). On a forward pass with ~800 dispatches
13//! this saves ~800 encoder create/end cycles.
14//!
15//! # Capture mode (Phase 4e.1)
16//!
17//! When `start_capture()` is called, subsequent dispatches are recorded into a
18//! `Vec<CapturedNode>` instead of being encoded into Metal. `memory_barrier()`
19//! records a barrier sentinel. Call `take_capture()` to extract the recorded
20//! graph for later replay via `ComputeGraph::encode_sequential()`.
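//!
//! A minimal capture/replay sketch (pipeline handles, buffers, and sizes are
//! placeholders; the graph type is the `ComputeGraph` referred to above):
//!
//! ```ignore
//! let mut enc = device.command_encoder()?;
//!
//! // Record instead of encoding directly into Metal.
//! enc.start_capture();
//! enc.set_op_kind(CapturedOpKind::RmsNorm);
//! enc.encode_threadgroups(rms_norm_pipeline, &rms_buffers, tg, tg_size);
//! enc.memory_barrier(); // recorded as CapturedNode::Barrier
//! enc.encode_threadgroups(mul_pipeline, &mul_buffers, tg, tg_size);
//!
//! // Extract the recorded nodes and replay them later,
//! // e.g. via a ComputeGraph built from `nodes`.
//! let nodes = enc.take_capture().expect("capture mode was active");
//! ```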
21
22use std::sync::atomic::{AtomicU64, Ordering};
23
24use metal::{
25 CommandBuffer, CommandQueue, ComputeCommandEncoderRef, ComputePipelineState,
26 ComputePipelineStateRef, MTLCommandBufferStatus, MTLDispatchType, MTLSize,
27};
28#[allow(unused_imports)]
29use objc::{msg_send, sel, sel_impl};
30
31use crate::buffer::MlxBuffer;
32use crate::error::{MlxError, Result};
33
34/// A buffer or inline-bytes binding for a compute kernel argument slot.
35pub enum KernelArg<'a> {
36 /// Bind an existing Metal buffer at the given index.
37 Buffer(&'a MlxBuffer),
38 /// Bind an existing Metal buffer at the given index with a byte offset.
39 BufferWithOffset(&'a MlxBuffer, u64),
40 /// Bind inline bytes (small constant data) at the given index.
41 /// The data must be `Pod` and is copied into the command encoder.
42 Bytes(&'a [u8]),
43}
44
45/// Convert a `Pod` value to a byte slice suitable for `KernelArg::Bytes`.
46///
47/// # Safety
48///
49/// The caller must ensure `T` has the same layout as the corresponding
50/// MSL struct in the shader (matching field order, sizes, and alignment).
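///
/// # Example
///
/// A sketch of passing a small parameter block to a kernel (the struct and
/// the kernel it feeds are illustrative, not part of this crate):
///
/// ```ignore
/// #[repr(C)]
/// #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
/// struct SoftmaxParams {
///     n_rows: u32,
///     n_cols: u32,
/// }
///
/// let params = SoftmaxParams { n_rows: 32, n_cols: 4096 };
/// let bindings = [
///     (0, KernelArg::Buffer(&input)),
///     (1, KernelArg::Buffer(&output)),
///     (2, KernelArg::Bytes(as_bytes(&params))),
/// ];
/// enc.encode_with_args(softmax_pipeline, &bindings, grid, tg_size);
/// ```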
51pub fn as_bytes<T: bytemuck::Pod>(val: &T) -> &[u8] {
52 bytemuck::bytes_of(val)
53}
54
55// ---------------------------------------------------------------------------
56// Capture-mode types (Phase 4e.1 — Graph IR)
57// ---------------------------------------------------------------------------
58
59/// A recorded kernel argument binding.
60///
61/// When the encoder is in capture mode, each `set_buffer` / `set_bytes` call
62/// is stored as a `RecordedBinding` instead of being applied to Metal.
63#[derive(Clone)]
64pub enum RecordedBinding {
65 /// A Metal buffer at the given offset.
66 Buffer {
67 metal_buffer: metal::Buffer,
68 offset: u64,
69 },
70 /// Inline bytes (small constant data, copied).
71 Bytes(Vec<u8>),
72}
73
74/// How to dispatch the recorded kernel.
75#[derive(Clone, Copy, Debug)]
76pub enum DispatchKind {
77 /// `dispatch_threads(grid_size, threadgroup_size)` — Metal picks threadgroup count.
78 Threads,
79 /// `dispatch_thread_groups(threadgroups, threadgroup_size)` — caller specifies threadgroup count.
80 ThreadGroups,
81}
82
83/// Operation kind tag for captured nodes, used by the fusion pass (4e.2).
84///
85/// When the encoder is in capture mode, each dispatch can be tagged with an
86/// `OpKind` so the fusion pass can identify fuseable sequences without
87/// inspecting pipeline names.
88#[derive(Clone, Copy, Debug, PartialEq, Eq)]
89pub enum CapturedOpKind {
90 /// RMS normalization (with learned scale).
91 RmsNorm,
92 /// Elementwise multiply.
93 ElemMul,
94 /// Elementwise add.
95 ElemAdd,
96 /// Scaled dot-product attention (NOT reorderable — breaks lookahead).
97 Sdpa,
98 /// Softmax (NOT reorderable — breaks lookahead).
99 Softmax,
100 /// Any other operation — treated as reorderable by the graph optimizer.
101 Other,
102}
103
104impl CapturedOpKind {
105 /// Whether this captured op kind is safe to reorder past in the graph
106 /// optimizer (Phase 4e.3).
107 ///
108 /// Mirrors the `h_safe` whitelist from llama.cpp's
109 /// `ggml_metal_graph_optimize_reorder`. Non-safe ops break the 64-node
110 /// lookahead — the reorder pass cannot look past them.
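    ///
    /// A sketch of how a reorder pass might consume this flag when scanning
    /// a captured graph (the surrounding pass is illustrative):
    ///
    /// ```ignore
    /// for node in &nodes {
    ///     if let CapturedNode::Dispatch { op_kind, .. } = node {
    ///         if !op_kind.is_reorderable() {
    ///             break; // lookahead window ends at Sdpa / Softmax
    ///         }
    ///         // ...consider this dispatch as a reorder candidate...
    ///     }
    /// }
    /// ```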
111 pub fn is_reorderable(&self) -> bool {
112 match self {
113 Self::Sdpa | Self::Softmax => false,
114 Self::RmsNorm | Self::ElemMul | Self::ElemAdd | Self::Other => true,
115 }
116 }
117}
118
119/// A memory range annotation: (start_address, end_address).
120///
121/// Represents a contiguous GPU buffer region for conflict detection in the
/// reorder pass (Phase 4e.3). Addresses are CPU-visible `contents_ptr()`
/// values; on Apple Silicon unified memory the CPU and GPU share the same
/// backing storage, so overlapping host ranges imply overlapping GPU accesses.
124pub type MemRange = (usize, usize);
125
126/// A single captured compute dispatch or barrier sentinel.
127///
128/// Created when the encoder is in capture mode. Replayed later by
129/// `ComputeGraph::encode_sequential()`.
130#[derive(Clone)]
131pub enum CapturedNode {
132 /// A compute dispatch to replay.
133 Dispatch {
134 /// Pipeline state object to bind.
135 pipeline: ComputePipelineState,
136 /// Kernel argument bindings: (slot_index, binding).
137 bindings: Vec<(u64, RecordedBinding)>,
138 /// Grid or threadgroup count (interpretation depends on `dispatch_kind`).
139 threads_per_grid: MTLSize,
140 /// Threads per threadgroup.
141 threads_per_threadgroup: MTLSize,
142 /// Optional threadgroup memory allocations: (index, byte_length).
143 threadgroup_memory: Vec<(u64, u64)>,
144 /// Whether this is a dispatch_threads or dispatch_thread_groups call.
145 dispatch_kind: DispatchKind,
146 /// Operation kind tag for the fusion pass (4e.2).
147 /// Defaults to `Other` if not explicitly set via `set_op_kind()`.
148 op_kind: CapturedOpKind,
149 /// Read buffer ranges for reorder conflict detection (4e.3).
150 /// Populated from `barrier_between` calls in capture mode.
151 reads: Vec<MemRange>,
152 /// Write buffer ranges for reorder conflict detection (4e.3).
153 /// Populated from `barrier_between` calls in capture mode.
154 writes: Vec<MemRange>,
155 },
156 /// A memory barrier sentinel — forces a barrier at replay time.
157 Barrier,
158}
159
160/// Apply a slice of `KernelArg` bindings to a compute encoder.
161///
162/// `KernelArg::Buffer(buf)` propagates the `MlxBuffer::byte_offset()` so
163/// `slice_view`-derived sub-buffers are honored automatically — the
164/// kernel sees memory starting at the slice's offset. This matches the
165/// documented contract of `slice_view` and the offset-handling in the
166/// other binding paths in this file (`encode`, `encode_threadgroups`,
167/// `encode_threadgroups_with_shared`, replay). Without it, every
168/// `slice_view`-derived buffer bound via `KernelArg::Buffer` silently
169/// exposes the entire underlying allocation — surfaced by hf2q's
170/// nomic-bert iter-79 cosine parity bisection (cosine 0.098 → 0.999962
171/// after fix).
172///
173/// `KernelArg::BufferWithOffset(buf, offset)` continues to use the
174/// explicit `offset` argument verbatim (callers asking for an explicit
175/// offset get exactly that, even on sliced buffers). The two API
176/// surfaces are intentional: implicit (sliced views auto-propagate) vs.
177/// explicit (caller-controlled).
178#[inline]
179fn apply_bindings(encoder: &ComputeCommandEncoderRef, bindings: &[(u64, KernelArg<'_>)]) {
180 for &(index, ref arg) in bindings {
181 match arg {
182 KernelArg::Buffer(buf) => {
183 encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
184 }
185 KernelArg::BufferWithOffset(buf, offset) => {
186 encoder.set_buffer(index, Some(buf.metal_buffer()), *offset);
187 }
188 KernelArg::Bytes(bytes) => {
189 encoder.set_bytes(index, bytes.len() as u64, bytes.as_ptr() as *const _);
190 }
191 }
192 }
193}
194
195/// Number of times `commit_and_wait()` has been called (CPU sync points).
196static SYNC_COUNT: AtomicU64 = AtomicU64::new(0);
197
198/// Number of times an encode method has been called (GPU dispatches).
199static DISPATCH_COUNT: AtomicU64 = AtomicU64::new(0);
200
201/// Reset both `SYNC_COUNT` and `DISPATCH_COUNT` to zero.
202pub fn reset_counters() {
203 SYNC_COUNT.store(0, Ordering::Relaxed);
204 DISPATCH_COUNT.store(0, Ordering::Relaxed);
205}
206
207/// Read the current value of `SYNC_COUNT`.
208///
209/// Each call to `commit_and_wait()` increments this counter.
210pub fn sync_count() -> u64 {
211 SYNC_COUNT.load(Ordering::Relaxed)
212}
213
214/// Read the current value of `DISPATCH_COUNT`.
215///
/// Each call to one of the `encode*` dispatch methods (`encode()`,
/// `encode_threadgroups()`, `encode_threadgroups_with_shared()`, and the
/// `*_with_args` variants) increments this counter.
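///
/// For example, to count dispatches and CPU sync points over one forward
/// pass (the `forward` call is a placeholder for caller code):
///
/// ```ignore
/// reset_counters();
/// model.forward(&input)?;
/// println!("dispatches: {}, syncs: {}", dispatch_count(), sync_count());
/// ```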
218pub fn dispatch_count() -> u64 {
219 DISPATCH_COUNT.load(Ordering::Relaxed)
220}
221
222/// A batched compute command encoder.
223///
224/// Keeps a single Metal `ComputeCommandEncoder` alive across multiple
225/// dispatches. The encoder is created on the first dispatch and ended
226/// only when the command buffer is committed. This mirrors candle's
227/// `compute_per_buffer` pattern and avoids per-dispatch encoder overhead.
228///
229/// # Typical usage
230///
231/// ```ignore
232/// let mut enc = device.command_encoder()?;
233/// // Multiple dispatches share the same compute encoder:
234/// enc.encode_threadgroups(pipeline1, &buffers1, tg1, tg_size1);
235/// enc.encode_threadgroups(pipeline2, &buffers2, tg2, tg_size2);
236/// enc.commit_and_wait()?;
237/// ```
238pub struct CommandEncoder {
239 cmd_buf: CommandBuffer,
240 // SAFETY marker: see unsafe Send impl below.
241 /// Raw pointer to the persistent compute encoder.
242 /// Non-null when a compute pass is active.
243 /// The encoder borrows from `cmd_buf` but we cannot express this
244 /// lifetime in safe Rust, so we use a raw pointer.
245 /// SAFETY: the pointer is valid as long as `cmd_buf` is alive and
246 /// `end_encoding()` has not been called on it.
247 active_encoder: *const ComputeCommandEncoderRef,
248 /// When `Some`, dispatches are recorded here instead of being encoded
249 /// into Metal. Set via `start_capture()`, extracted via `take_capture()`.
250 capture: Option<Vec<CapturedNode>>,
251 /// Op kind tag for the NEXT captured dispatch. Set via `set_op_kind()`,
252 /// consumed (reset to `Other`) when a dispatch is captured.
253 pending_op_kind: CapturedOpKind,
254 /// Pending read buffer ranges for the NEXT captured dispatch.
255 /// Set via `set_pending_buffer_ranges()`, consumed when the next dispatch
256 /// is captured. Used by the reorder pass (Phase 4e.3).
257 pending_reads: Vec<MemRange>,
258 /// Pending write buffer ranges for the NEXT captured dispatch.
259 pending_writes: Vec<MemRange>,
260}
261
262/// SAFETY: CommandEncoder is safe to Send across threads provided that:
263/// 1. Only one thread accesses the encoder at a time (exclusive ownership).
264/// 2. The encoder is not used concurrently from multiple threads.
265///
266/// Metal command buffers and compute encoders are thread-safe for exclusive
267/// access (Apple documentation: "You can create command buffers, encode
268/// commands, and submit them from any thread"). The raw pointer
269/// `active_encoder` borrows from `cmd_buf` and is valid as long as
270/// `cmd_buf` is alive — this invariant holds across thread boundaries
271/// because both fields move together.
272///
273/// This matches llama.cpp's pattern of encoding command buffers on GCD
274/// worker threads via `dispatch_apply`, and is used for the dual-buffer
275/// pipeline where buf1 is encoded on a worker thread while buf0 executes.
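///
/// A sketch of the dual-buffer hand-off this enables (the encode steps and
/// the surrounding loop are placeholders):
///
/// ```ignore
/// let mut enc0 = device.command_encoder()?;
/// // ...encode batch 0 into enc0...
/// enc0.commit(); // GPU starts executing batch 0
///
/// let mut enc1 = device.command_encoder()?;
/// let worker = std::thread::spawn(move || {
///     // ...encode batch 1 into enc1 while batch 0 runs on the GPU...
///     enc1
/// });
///
/// enc0.wait_until_completed()?;
/// let mut enc1 = worker.join().unwrap();
/// enc1.commit_and_wait()?;
/// ```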
276unsafe impl Send for CommandEncoder {}
277
278impl CommandEncoder {
279 /// Create a new command encoder from the given command queue.
280 ///
281 /// This immediately creates a Metal command buffer.
282 pub(crate) fn new(queue: &CommandQueue) -> Result<Self> {
283 let cmd_buf = queue.new_command_buffer().to_owned();
284 Ok(Self {
285 cmd_buf,
286 active_encoder: std::ptr::null(),
287 capture: None,
288 pending_op_kind: CapturedOpKind::Other,
289 pending_reads: Vec::new(),
290 pending_writes: Vec::new(),
291 })
292 }
293
294 /// Enable capture mode.
295 ///
296 /// All subsequent dispatch and barrier calls will be recorded into a
297 /// `Vec<CapturedNode>` instead of being encoded into Metal.
298 /// Call `take_capture()` to extract the recorded nodes.
299 pub fn start_capture(&mut self) {
300 self.capture = Some(Vec::with_capacity(128));
301 }
302
303 /// Whether the encoder is currently in capture mode.
304 pub fn is_capturing(&self) -> bool {
305 self.capture.is_some()
306 }
307
308 /// Extract the captured nodes, ending capture mode.
309 ///
310 /// Returns `None` if capture mode was not active.
311 pub fn take_capture(&mut self) -> Option<Vec<CapturedNode>> {
312 self.capture.take()
313 }
314
315 /// Tag the NEXT captured dispatch with the given operation kind.
316 ///
317 /// The tag is consumed (reset to `Other`) after the next dispatch is
318 /// captured. Only meaningful in capture mode — has no effect on
319 /// direct-dispatch encoding.
320 ///
321 /// Used by op dispatch functions to annotate captures for the fusion
322 /// pass (Phase 4e.2).
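    ///
    /// A sketch of tagging a capture-mode dispatch (pipeline and buffers are
    /// placeholders):
    ///
    /// ```ignore
    /// enc.start_capture();
    /// enc.set_op_kind(CapturedOpKind::RmsNorm);
    /// enc.encode_threadgroups(rms_norm_pipeline, &buffers, tg, tg_size);
    /// // The recorded CapturedNode::Dispatch carries op_kind == RmsNorm;
    /// // the tag has been reset to Other for the next dispatch.
    /// ```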
323 pub fn set_op_kind(&mut self, kind: CapturedOpKind) {
324 self.pending_op_kind = kind;
325 }
326
327 /// Consume and return the pending op kind, resetting it to `Other`.
328 fn take_pending_op_kind(&mut self) -> CapturedOpKind {
329 let kind = self.pending_op_kind;
330 self.pending_op_kind = CapturedOpKind::Other;
331 kind
332 }
333
334 /// Stash buffer range annotations for the NEXT captured dispatch.
335 ///
336 /// Called by `GraphSession::barrier_between()` in capture mode to record
337 /// which buffers the next dispatch reads from and writes to. The ranges
338 /// are consumed by the next `encode_*` call and attached to the captured
339 /// `CapturedNode::Dispatch`.
340 ///
341 /// Only meaningful in capture mode — has no effect on direct-dispatch.
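    ///
    /// A sketch of annotating the next dispatch (buffer handles are
    /// placeholders; `contents_ptr()` and the byte lengths are assumed to be
    /// available to the caller):
    ///
    /// ```ignore
    /// let src = (input.contents_ptr() as usize,
    ///            input.contents_ptr() as usize + input_len);
    /// let dst = (output.contents_ptr() as usize,
    ///            output.contents_ptr() as usize + output_len);
    /// enc.set_pending_buffer_ranges(vec![src], vec![dst]);
    /// enc.encode_threadgroups(pipeline, &buffers, tg, tg_size);
    /// // The captured node now carries `reads = [src]`, `writes = [dst]`.
    /// ```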
342 pub fn set_pending_buffer_ranges(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
343 self.pending_reads = reads;
344 self.pending_writes = writes;
345 }
346
347 /// Patch the last captured dispatch node's empty reads/writes with the
348 /// given ranges. No-op if not capturing, or if the last node isn't a
349 /// Dispatch, or if its ranges are already populated.
350 ///
351 /// Used by `GraphSession::track_dispatch` in recording mode to annotate
352 /// dispatches that were called without a preceding `barrier_between`.
353 pub fn annotate_last_dispatch_if_missing(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
354 if let Some(ref mut nodes) = self.capture {
355 if let Some(CapturedNode::Dispatch { reads: r, writes: w, .. }) = nodes.last_mut() {
356 if r.is_empty() && !reads.is_empty() {
357 *r = reads;
358 }
359 if w.is_empty() && !writes.is_empty() {
360 *w = writes;
361 }
362 }
363 }
364 }
365
366 /// Consume and return the pending buffer range annotations.
367 fn take_pending_buffer_ranges(&mut self) -> (Vec<MemRange>, Vec<MemRange>) {
368 let reads = std::mem::take(&mut self.pending_reads);
369 let writes = std::mem::take(&mut self.pending_writes);
370 (reads, writes)
371 }
372
373 /// Record buffer bindings into `RecordedBinding` form.
374 fn record_buffer_bindings(buffers: &[(u64, &MlxBuffer)]) -> Vec<(u64, RecordedBinding)> {
375 buffers
376 .iter()
377 .map(|&(index, buf)| {
378 (
379 index,
380 RecordedBinding::Buffer {
381 metal_buffer: buf.metal_buffer().clone(),
382 offset: buf.byte_offset(),
383 },
384 )
385 })
386 .collect()
387 }
388
389 /// Record `KernelArg` bindings into `RecordedBinding` form.
390 ///
391 /// `KernelArg::Buffer(buf)` records `buf.byte_offset()` so capture →
392 /// replay round-trips of `slice_view`-derived buffers preserve their
    /// offsets, matching the offset handling in `record_buffer_bindings` above.
394 fn record_arg_bindings(bindings: &[(u64, KernelArg<'_>)]) -> Vec<(u64, RecordedBinding)> {
395 bindings
396 .iter()
397 .map(|(index, arg)| {
398 let recorded = match arg {
399 KernelArg::Buffer(buf) => RecordedBinding::Buffer {
400 metal_buffer: buf.metal_buffer().clone(),
401 offset: buf.byte_offset(),
402 },
403 KernelArg::BufferWithOffset(buf, offset) => RecordedBinding::Buffer {
404 metal_buffer: buf.metal_buffer().clone(),
405 offset: *offset,
406 },
407 KernelArg::Bytes(bytes) => RecordedBinding::Bytes(bytes.to_vec()),
408 };
409 (*index, recorded)
410 })
411 .collect()
412 }
413
414 /// Get or create the persistent compute encoder.
415 ///
416 /// On the first call, creates a new compute encoder from the command
417 /// buffer. On subsequent calls, returns the existing one.
418 ///
419 /// SAFETY: The returned reference borrows from `self.cmd_buf` which is
420 /// alive for the lifetime of this `CommandEncoder`. The raw pointer is
421 /// valid until `end_active_encoder()` is called.
422 #[inline]
423 fn get_or_create_encoder(&mut self) -> &ComputeCommandEncoderRef {
424 if self.active_encoder.is_null() {
425 // Use MTLDispatchTypeConcurrent to allow independent dispatches
426 // to overlap on the GPU. Memory barriers are inserted between
427 // dependent dispatches via `memory_barrier()`.
428 let encoder = self
429 .cmd_buf
430 .compute_command_encoder_with_dispatch_type(MTLDispatchType::Concurrent);
431 self.active_encoder = encoder as *const ComputeCommandEncoderRef;
432 }
433 // SAFETY: active_encoder is non-null and points to a valid encoder
434 // owned by cmd_buf.
435 unsafe { &*self.active_encoder }
436 }
437
438 /// End the active compute encoder if one exists.
439 #[inline]
440 fn end_active_encoder(&mut self) {
441 if !self.active_encoder.is_null() {
            // SAFETY: the pointer was obtained from
            // `compute_command_encoder_with_dispatch_type()` in
            // `get_or_create_encoder()` and has not been ended yet.
444 unsafe { &*self.active_encoder }.end_encoding();
445 self.active_encoder = std::ptr::null();
446 }
447 }
448
449 /// Insert a memory barrier with scope `MTLBarrierScopeBuffers`.
450 ///
451 /// When the encoder uses `MTLDispatchTypeConcurrent`, all dispatches can
452 /// execute concurrently unless separated by a barrier. Call this between
453 /// dispatches where the later dispatch reads a buffer written by an
454 /// earlier one.
455 ///
456 /// This is the same pattern llama.cpp uses:
457 /// `[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]`
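    ///
    /// A sketch of ordering two dependent dispatches (pipelines and buffers
    /// are placeholders):
    ///
    /// ```ignore
    /// // Dispatch A writes `hidden`.
    /// enc.encode_threadgroups(rms_norm_pipeline, &[(0, &input), (1, &hidden)], tg, tg_size);
    /// // Dispatch B reads `hidden`, so it must not overlap with A.
    /// enc.memory_barrier();
    /// enc.encode_threadgroups(mul_pipeline, &[(0, &hidden), (1, &gate), (2, &out)], tg, tg_size);
    /// ```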
458 #[allow(unexpected_cfgs)]
459 pub fn memory_barrier(&mut self) {
460 if let Some(ref mut nodes) = self.capture {
461 nodes.push(CapturedNode::Barrier);
462 return;
463 }
464 if self.active_encoder.is_null() {
465 return;
466 }
467 // SAFETY: active_encoder is non-null and valid.
468 let encoder = unsafe { &*self.active_encoder };
469 // MTLBarrierScopeBuffers = 1 << 0 = 1
470 const MTL_BARRIER_SCOPE_BUFFERS: u64 = 1;
471 unsafe {
472 let _: () = objc::msg_send![encoder, memoryBarrierWithScope: MTL_BARRIER_SCOPE_BUFFERS];
473 }
474 }
475
476 /// Set the compute pipeline state for subsequent dispatches.
477 ///
478 /// This begins a new compute pass if one is not already active.
479 pub fn set_pipeline(&mut self, pipeline: &ComputePipelineStateRef) {
480 let encoder = self.get_or_create_encoder();
481 encoder.set_compute_pipeline_state(pipeline);
482 }
483
484 /// Bind a buffer to a compute kernel argument slot.
485 ///
486 /// The `index` corresponds to the `[[buffer(N)]]` attribute in the MSL shader.
    pub fn set_buffer(&mut self, index: u64, buffer: &MlxBuffer) {
        let encoder = self.get_or_create_encoder();
        encoder.set_buffer(index, Some(buffer.metal_buffer()), buffer.byte_offset());
    }

    /// Dispatch threads on the GPU using the pipeline previously bound via
    /// [`set_pipeline`](Self::set_pipeline).
    ///
    /// Begins a new compute pass if one is not already active. Does not
    /// increment `DISPATCH_COUNT`; use the `encode*` helpers for counted,
    /// capture-aware dispatches.
    pub fn dispatch_threads(&mut self, grid_size: MTLSize, threadgroup_size: MTLSize) {
        let encoder = self.get_or_create_encoder();
        encoder.dispatch_threads(grid_size, threadgroup_size);
    }
495
496 /// Encode a complete compute pass: set pipeline, bind buffers, dispatch.
497 ///
498 /// Reuses the persistent compute encoder — no per-dispatch encoder
499 /// creation overhead.
500 ///
501 /// # Arguments
502 ///
503 /// * `pipeline` — The compiled compute pipeline to execute.
504 /// * `buffers` — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
505 /// * `grid_size` — Total number of threads to launch.
506 /// * `threadgroup_size` — Threads per threadgroup.
507 pub fn encode(
508 &mut self,
509 pipeline: &ComputePipelineStateRef,
510 buffers: &[(u64, &MlxBuffer)],
511 grid_size: MTLSize,
512 threadgroup_size: MTLSize,
513 ) {
514 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
515 let op_kind = self.take_pending_op_kind();
516 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
517 if let Some(ref mut nodes) = self.capture {
518 nodes.push(CapturedNode::Dispatch {
519 pipeline: pipeline.to_owned(),
520 bindings: Self::record_buffer_bindings(buffers),
521 threads_per_grid: grid_size,
522 threads_per_threadgroup: threadgroup_size,
523 threadgroup_memory: Vec::new(),
524 dispatch_kind: DispatchKind::Threads,
525 op_kind,
526 reads: pending_reads,
527 writes: pending_writes,
528 });
529 return;
530 }
531 let encoder = self.get_or_create_encoder();
532 encoder.set_compute_pipeline_state(pipeline);
533 for &(index, buf) in buffers {
534 encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
535 }
536 encoder.dispatch_threads(grid_size, threadgroup_size);
537 }
538
539 /// Encode a compute pass using threadgroups instead of raw thread counts.
540 ///
541 /// Reuses the persistent compute encoder — no per-dispatch encoder
542 /// creation overhead.
543 pub fn encode_threadgroups(
544 &mut self,
545 pipeline: &ComputePipelineStateRef,
546 buffers: &[(u64, &MlxBuffer)],
547 threadgroups: MTLSize,
548 threadgroup_size: MTLSize,
549 ) {
550 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
551 let op_kind = self.take_pending_op_kind();
552 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
553 if let Some(ref mut nodes) = self.capture {
554 nodes.push(CapturedNode::Dispatch {
555 pipeline: pipeline.to_owned(),
556 bindings: Self::record_buffer_bindings(buffers),
557 threads_per_grid: threadgroups,
558 threads_per_threadgroup: threadgroup_size,
559 threadgroup_memory: Vec::new(),
560 dispatch_kind: DispatchKind::ThreadGroups,
561 op_kind,
562 reads: pending_reads,
563 writes: pending_writes,
564 });
565 return;
566 }
567 let encoder = self.get_or_create_encoder();
568 encoder.set_compute_pipeline_state(pipeline);
569 for &(index, buf) in buffers {
570 encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
571 }
572 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
573 }
574
575 /// Encode a compute pass using threadgroups with shared threadgroup memory.
576 ///
577 /// Like [`encode_threadgroups`](Self::encode_threadgroups), but additionally
578 /// allocates threadgroup memory at the specified indices. This is required
579 /// for kernels that use `threadgroup` memory (e.g. reductions in rms_norm
580 /// and softmax).
581 ///
582 /// # Arguments
583 ///
584 /// * `pipeline` — The compiled compute pipeline to execute.
585 /// * `buffers` — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
586 /// * `threadgroup_mem` — Slice of `(index, byte_length)` pairs for threadgroup memory.
587 /// * `threadgroups` — Number of threadgroups to dispatch.
588 /// * `threadgroup_size` — Threads per threadgroup.
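    ///
    /// # Example
    ///
    /// A sketch of a row-wise reduction that needs one `float` of threadgroup
    /// memory per simdgroup (kernel, buffers, and sizes are placeholders):
    ///
    /// ```ignore
    /// let tg_size = MTLSize::new(256, 1, 1);
    /// let threadgroups = MTLSize::new(n_rows as u64, 1, 1);
    /// // 256 threads / 32 lanes per simdgroup = 8 partial sums of 4 bytes each.
    /// let shared = [(0u64, 8 * 4u64)];
    /// enc.encode_threadgroups_with_shared(
    ///     rms_norm_pipeline,
    ///     &[(0, &input), (1, &scale), (2, &output)],
    ///     &shared,
    ///     threadgroups,
    ///     tg_size,
    /// );
    /// ```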
589 pub fn encode_threadgroups_with_shared(
590 &mut self,
591 pipeline: &ComputePipelineStateRef,
592 buffers: &[(u64, &MlxBuffer)],
593 threadgroup_mem: &[(u64, u64)],
594 threadgroups: MTLSize,
595 threadgroup_size: MTLSize,
596 ) {
597 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
598 let op_kind = self.take_pending_op_kind();
599 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
600 if let Some(ref mut nodes) = self.capture {
601 nodes.push(CapturedNode::Dispatch {
602 pipeline: pipeline.to_owned(),
603 bindings: Self::record_buffer_bindings(buffers),
604 threads_per_grid: threadgroups,
605 threads_per_threadgroup: threadgroup_size,
606 threadgroup_memory: threadgroup_mem.to_vec(),
607 dispatch_kind: DispatchKind::ThreadGroups,
608 op_kind,
609 reads: pending_reads,
610 writes: pending_writes,
611 });
612 return;
613 }
614 let encoder = self.get_or_create_encoder();
615 encoder.set_compute_pipeline_state(pipeline);
616 for &(index, buf) in buffers {
617 encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
618 }
619 for &(index, byte_length) in threadgroup_mem {
620 encoder.set_threadgroup_memory_length(index, byte_length);
621 }
622 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
623 }
624
625 /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_threads).
626 ///
627 /// Reuses the persistent compute encoder.
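    ///
    /// A sketch of an elementwise launch over `n` elements (pipeline, buffers,
    /// and the params struct are placeholders):
    ///
    /// ```ignore
    /// let bindings = [
    ///     (0, KernelArg::Buffer(&a)),
    ///     (1, KernelArg::Buffer(&b)),
    ///     (2, KernelArg::Buffer(&out)),
    ///     (3, KernelArg::Bytes(as_bytes(&params))),
    /// ];
    /// let grid = MTLSize::new(n as u64, 1, 1); // one thread per element
    /// let tg = MTLSize::new(256u64.min(n as u64), 1, 1);
    /// enc.encode_with_args(elem_add_pipeline, &bindings, grid, tg);
    /// ```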
628 pub fn encode_with_args(
629 &mut self,
630 pipeline: &ComputePipelineStateRef,
631 bindings: &[(u64, KernelArg<'_>)],
632 grid_size: MTLSize,
633 threadgroup_size: MTLSize,
634 ) {
635 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
636 let op_kind = self.take_pending_op_kind();
637 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
638 if let Some(ref mut nodes) = self.capture {
639 nodes.push(CapturedNode::Dispatch {
640 pipeline: pipeline.to_owned(),
641 bindings: Self::record_arg_bindings(bindings),
642 threads_per_grid: grid_size,
643 threads_per_threadgroup: threadgroup_size,
644 threadgroup_memory: Vec::new(),
645 dispatch_kind: DispatchKind::Threads,
646 op_kind,
647 reads: pending_reads,
648 writes: pending_writes,
649 });
650 return;
651 }
652 let encoder = self.get_or_create_encoder();
653 encoder.set_compute_pipeline_state(pipeline);
654 apply_bindings(encoder, bindings);
655 encoder.dispatch_threads(grid_size, threadgroup_size);
656 }
657
658 /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_thread_groups).
659 ///
660 /// Reuses the persistent compute encoder.
661 pub fn encode_threadgroups_with_args(
662 &mut self,
663 pipeline: &ComputePipelineStateRef,
664 bindings: &[(u64, KernelArg<'_>)],
665 threadgroups: MTLSize,
666 threadgroup_size: MTLSize,
667 ) {
668 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
669 let op_kind = self.take_pending_op_kind();
670 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
671 if let Some(ref mut nodes) = self.capture {
672 nodes.push(CapturedNode::Dispatch {
673 pipeline: pipeline.to_owned(),
674 bindings: Self::record_arg_bindings(bindings),
675 threads_per_grid: threadgroups,
676 threads_per_threadgroup: threadgroup_size,
677 threadgroup_memory: Vec::new(),
678 dispatch_kind: DispatchKind::ThreadGroups,
679 op_kind,
680 reads: pending_reads,
681 writes: pending_writes,
682 });
683 return;
684 }
685 let encoder = self.get_or_create_encoder();
686 encoder.set_compute_pipeline_state(pipeline);
687 apply_bindings(encoder, bindings);
688 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
689 }
690
691 /// Encode a dispatch with mixed buffer/bytes bindings and shared memory.
692 ///
693 /// Reuses the persistent compute encoder.
694 pub fn encode_threadgroups_with_args_and_shared(
695 &mut self,
696 pipeline: &ComputePipelineStateRef,
697 bindings: &[(u64, KernelArg<'_>)],
698 threadgroup_mem: &[(u64, u64)],
699 threadgroups: MTLSize,
700 threadgroup_size: MTLSize,
701 ) {
702 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
703 let op_kind = self.take_pending_op_kind();
704 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
705 if let Some(ref mut nodes) = self.capture {
706 nodes.push(CapturedNode::Dispatch {
707 pipeline: pipeline.to_owned(),
708 bindings: Self::record_arg_bindings(bindings),
709 threads_per_grid: threadgroups,
710 threads_per_threadgroup: threadgroup_size,
711 threadgroup_memory: threadgroup_mem.to_vec(),
712 dispatch_kind: DispatchKind::ThreadGroups,
713 op_kind,
714 reads: pending_reads,
715 writes: pending_writes,
716 });
717 return;
718 }
719 let encoder = self.get_or_create_encoder();
720 encoder.set_compute_pipeline_state(pipeline);
721 apply_bindings(encoder, bindings);
722 for &(index, byte_length) in threadgroup_mem {
723 encoder.set_threadgroup_memory_length(index, byte_length);
724 }
725 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
726 }
727
728 /// Replay a single captured dispatch node into this encoder.
729 ///
730 /// This is the inverse of capture: it takes a previously recorded
731 /// `CapturedNode::Dispatch` and encodes it into the live Metal encoder.
732 /// Barrier nodes are handled by the caller (ComputeGraph::encode_sequential).
733 ///
734 /// Does NOT increment `DISPATCH_COUNT` — that was already counted at
735 /// capture time.
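    ///
    /// A sketch of replaying captured nodes in order, roughly what the doc
    /// above attributes to `ComputeGraph::encode_sequential()`:
    ///
    /// ```ignore
    /// for node in &nodes {
    ///     match node {
    ///         CapturedNode::Dispatch {
    ///             pipeline,
    ///             bindings,
    ///             threadgroup_memory,
    ///             threads_per_grid,
    ///             threads_per_threadgroup,
    ///             dispatch_kind,
    ///             ..
    ///         } => enc.replay_dispatch(
    ///             pipeline,
    ///             bindings,
    ///             threadgroup_memory,
    ///             *threads_per_grid,
    ///             *threads_per_threadgroup,
    ///             *dispatch_kind,
    ///         ),
    ///         CapturedNode::Barrier => enc.memory_barrier(),
    ///     }
    /// }
    /// enc.commit_and_wait()?;
    /// ```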
736 pub fn replay_dispatch(
737 &mut self,
738 pipeline: &ComputePipelineStateRef,
739 bindings: &[(u64, RecordedBinding)],
740 threadgroup_memory: &[(u64, u64)],
741 threads_per_grid: MTLSize,
742 threads_per_threadgroup: MTLSize,
743 dispatch_kind: DispatchKind,
744 ) {
745 let encoder = self.get_or_create_encoder();
746 encoder.set_compute_pipeline_state(pipeline);
747 for (index, binding) in bindings {
748 match binding {
749 RecordedBinding::Buffer { metal_buffer, offset } => {
750 encoder.set_buffer(*index, Some(metal_buffer), *offset);
751 }
752 RecordedBinding::Bytes(bytes) => {
753 encoder.set_bytes(
754 *index,
755 bytes.len() as u64,
756 bytes.as_ptr() as *const _,
757 );
758 }
759 }
760 }
761 for &(index, byte_length) in threadgroup_memory {
762 encoder.set_threadgroup_memory_length(index, byte_length);
763 }
764 match dispatch_kind {
765 DispatchKind::Threads => {
766 encoder.dispatch_threads(threads_per_grid, threads_per_threadgroup);
767 }
768 DispatchKind::ThreadGroups => {
769 encoder.dispatch_thread_groups(threads_per_grid, threads_per_threadgroup);
770 }
771 }
772 }
773
774 /// Commit the command buffer and block until the GPU finishes execution.
775 ///
776 /// # Errors
777 ///
778 /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
779 pub fn commit_and_wait(&mut self) -> Result<()> {
780 SYNC_COUNT.fetch_add(1, Ordering::Relaxed);
781
782 // End the persistent compute encoder before committing.
783 self.end_active_encoder();
784
785 self.cmd_buf.commit();
786 self.cmd_buf.wait_until_completed();
787
788 match self.cmd_buf.status() {
789 MTLCommandBufferStatus::Completed => Ok(()),
790 MTLCommandBufferStatus::Error => {
791 Err(MlxError::CommandBufferError(
792 "GPU command buffer completed with error status".into(),
793 ))
794 }
795 status => Err(MlxError::CommandBufferError(format!(
796 "Unexpected command buffer status after wait: {:?}",
797 status
798 ))),
799 }
800 }
801
802 /// Commit + wait, returning `(gpu_start_s, gpu_end_s)` CFTimeInterval
803 /// timestamps from `MTLCommandBuffer`'s `GPUStartTime`/`GPUEndTime`
    /// properties. Both are CFTimeInterval values: host-clock seconds as `f64`.
805 ///
806 /// Intended for `HF2Q_PROFILE_GPU_TS=1` per-bucket GPU wall-clock
807 /// attribution. Adds exactly two ObjC property reads per call on top
808 /// of the regular `commit_and_wait` — measured well under 1 μs on
809 /// M5 Max.
810 ///
811 /// # Errors
812 ///
813 /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
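    ///
    /// # Example
    ///
    /// A sketch of attributing GPU wall-clock time to one batch:
    ///
    /// ```ignore
    /// let (gpu_start, gpu_end) = enc.commit_wait_with_gpu_time()?;
    /// let gpu_ms = (gpu_end - gpu_start) * 1e3;
    /// eprintln!("batch GPU time: {gpu_ms:.3} ms");
    /// ```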
814 pub fn commit_wait_with_gpu_time(&mut self) -> Result<(f64, f64)> {
815 self.commit_and_wait()?;
816 // SAFETY: cmd_buf is a valid MTLCommandBuffer that has been
817 // committed and awaited. GPUStartTime / GPUEndTime return
818 // CFTimeInterval (double precision seconds). See
819 // https://developer.apple.com/documentation/metal/mtlcommandbuffer/1639925-gpustarttime
820 let (gpu_start, gpu_end): (f64, f64) = unsafe {
821 let cb = &*self.cmd_buf;
822 let s: f64 = msg_send![cb, GPUStartTime];
823 let e: f64 = msg_send![cb, GPUEndTime];
824 (s, e)
825 };
826 Ok((gpu_start, gpu_end))
827 }
828
829 /// Commit the command buffer WITHOUT blocking.
830 ///
831 /// The GPU begins executing the encoded commands immediately. Call
832 /// [`wait_until_completed`](Self::wait_until_completed) later to block
833 /// the CPU and check for errors. This allows the CPU to continue doing
834 /// other work (e.g. preparing the next batch) while the GPU runs.
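    ///
    /// A sketch of overlapping CPU work with GPU execution (the CPU-side
    /// preparation step is a placeholder):
    ///
    /// ```ignore
    /// enc.commit();                 // GPU starts on the encoded batch
    /// prepare_next_batch();         // CPU work overlaps with GPU execution
    /// enc.wait_until_completed()?;  // block and surface any GPU error
    /// ```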
835 pub fn commit(&mut self) {
836 self.end_active_encoder();
837 self.cmd_buf.commit();
838 }
839
840 /// Block until a previously committed command buffer completes.
841 ///
842 /// Must be called after [`commit`](Self::commit). Do not call after
843 /// [`commit_and_wait`](Self::commit_and_wait) — that method already waits.
844 ///
845 /// # Errors
846 ///
847 /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
848 pub fn wait_until_completed(&self) -> Result<()> {
849 self.cmd_buf.wait_until_completed();
850 match self.cmd_buf.status() {
851 MTLCommandBufferStatus::Completed => Ok(()),
852 MTLCommandBufferStatus::Error => Err(MlxError::CommandBufferError(
853 "GPU command buffer completed with error status".into(),
854 )),
855 status => Err(MlxError::CommandBufferError(format!(
856 "Unexpected command buffer status after wait: {:?}",
857 status
858 ))),
859 }
860 }
861
862 /// Borrow the underlying Metal command buffer.
863 #[inline]
864 pub fn metal_command_buffer(&self) -> &CommandBuffer {
865 &self.cmd_buf
866 }
867}
868
869impl Drop for CommandEncoder {
870 fn drop(&mut self) {
871 // End the persistent compute encoder before the command buffer
872 // is dropped, otherwise Metal will assert:
873 // "Command encoder released without endEncoding"
874 self.end_active_encoder();
875 }
876}