// mlx_native/encoder.rs
//! [`CommandEncoder`] — batched GPU command submission.
//!
//! Wraps a Metal command buffer. Encode one or more compute kernel dispatches,
//! then call [`commit_and_wait`](CommandEncoder::commit_and_wait) to submit the
//! entire batch and block until the GPU finishes.
//!
//! # Persistent compute encoder
//!
//! A single Metal `ComputeCommandEncoder` is kept alive across multiple
//! dispatches within the same command buffer. This avoids the overhead of
//! creating and ending a new compute encoder per dispatch — the same pattern
//! candle uses (`compute_per_buffer`). On a forward pass with ~800 dispatches
//! this saves ~800 encoder create/end cycles.
//!
//! # Capture mode (Phase 4e.1)
//!
//! When `start_capture()` is called, subsequent dispatches are recorded into a
//! `Vec<CapturedNode>` instead of being encoded into Metal. `memory_barrier()`
//! records a barrier sentinel. Call `take_capture()` to extract the recorded
//! graph for later replay via `ComputeGraph::encode_sequential()`.
22use std::sync::atomic::{AtomicU64, Ordering};
23
24use metal::{
25 CommandBuffer, CommandQueue, ComputeCommandEncoderRef, ComputePipelineState,
26 ComputePipelineStateRef, MTLCommandBufferStatus, MTLDispatchType, MTLSize,
27};
28#[allow(unused_imports)]
29use objc::{msg_send, sel, sel_impl};
30
31use crate::buffer::MlxBuffer;
32use crate::error::{MlxError, Result};
33
34/// A buffer or inline-bytes binding for a compute kernel argument slot.
/// A buffer or inline-bytes binding for a compute kernel argument slot.
///
/// Used as `(slot_index, KernelArg)` pairs by the `encode_*_with_args`
/// methods; see `apply_bindings` for how each variant is applied to the
/// Metal encoder.
pub enum KernelArg<'a> {
    /// Bind an existing Metal buffer at the given index (offset 0).
    Buffer(&'a MlxBuffer),
    /// Bind an existing Metal buffer at the given index with a byte offset.
    BufferWithOffset(&'a MlxBuffer, u64),
    /// Bind inline bytes (small constant data) at the given index.
    /// The data must be `Pod` and is copied into the command encoder
    /// (via `setBytes`), so the slice need not outlive the dispatch.
    Bytes(&'a [u8]),
}
44
/// Convert a `Pod` value to a byte slice suitable for `KernelArg::Bytes`.
///
/// This is a thin safe wrapper over `bytemuck::bytes_of`; it involves no
/// `unsafe` on the Rust side.
///
/// # Safety
///
/// (Logical contract, not Rust-level safety — the function itself is safe.)
/// The caller must ensure `T` has the same layout as the corresponding
/// MSL struct in the shader (matching field order, sizes, and alignment);
/// a mismatch yields wrong kernel arguments on the GPU.
pub fn as_bytes<T: bytemuck::Pod>(val: &T) -> &[u8] {
    bytemuck::bytes_of(val)
}
54
55// ---------------------------------------------------------------------------
56// Capture-mode types (Phase 4e.1 — Graph IR)
57// ---------------------------------------------------------------------------
58
/// A recorded kernel argument binding.
///
/// When the encoder is in capture mode, each `set_buffer` / `set_bytes` call
/// is stored as a `RecordedBinding` instead of being applied to Metal.
/// Owned (cloned buffer handle / copied bytes) so the recording can outlive
/// the original call site.
#[derive(Clone)]
pub enum RecordedBinding {
    /// A Metal buffer at the given offset.
    Buffer {
        metal_buffer: metal::Buffer,
        offset: u64,
    },
    /// Inline bytes (small constant data, copied at record time).
    Bytes(Vec<u8>),
}
73
/// How to dispatch the recorded kernel at replay time.
#[derive(Clone, Copy, Debug)]
pub enum DispatchKind {
    /// `dispatch_threads(grid_size, threadgroup_size)` — Metal picks threadgroup count.
    Threads,
    /// `dispatch_thread_groups(threadgroups, threadgroup_size)` — caller specifies threadgroup count.
    ThreadGroups,
}
82
/// Operation kind tag for captured nodes, used by the fusion pass (4e.2).
///
/// In capture mode a dispatch can be tagged with an `OpKind` so the fusion
/// pass can recognize fuseable sequences without parsing pipeline names.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CapturedOpKind {
    /// RMS normalization (with learned scale).
    RmsNorm,
    /// Elementwise multiply.
    ElemMul,
    /// Elementwise add.
    ElemAdd,
    /// Scaled dot-product attention (NOT reorderable — breaks lookahead).
    Sdpa,
    /// Softmax (NOT reorderable — breaks lookahead).
    Softmax,
    /// Any other operation — treated as reorderable by the graph optimizer.
    Other,
}

impl CapturedOpKind {
    /// Whether the graph optimizer (Phase 4e.3) may reorder past this op.
    ///
    /// Mirrors the `h_safe` whitelist from llama.cpp's
    /// `ggml_metal_graph_optimize_reorder`: non-safe ops break the 64-node
    /// lookahead, so the reorder pass cannot look past them.
    ///
    /// The match is deliberately exhaustive (no `_` arm) so adding a new
    /// variant forces a decision here.
    pub fn is_reorderable(&self) -> bool {
        match self {
            Self::RmsNorm | Self::ElemMul | Self::ElemAdd | Self::Other => true,
            Self::Sdpa | Self::Softmax => false,
        }
    }
}
118
/// A memory range annotation: `(start_address, end_address)`.
///
/// Represents a contiguous GPU buffer region for conflict detection in the
/// reorder pass (Phase 4e.3). Addresses are CPU-visible `contents_ptr()`
/// values, which on Apple Silicon unified memory equal the GPU addresses.
/// NOTE(review): whether `end_address` is inclusive or exclusive is not
/// visible here — confirm against the conflict-detection code.
pub type MemRange = (usize, usize);
125
/// A single captured compute dispatch or barrier sentinel.
///
/// Created when the encoder is in capture mode. Replayed later by
/// `ComputeGraph::encode_sequential()`.
#[derive(Clone)]
pub enum CapturedNode {
    /// A compute dispatch to replay.
    Dispatch {
        /// Pipeline state object to bind.
        pipeline: ComputePipelineState,
        /// Kernel argument bindings: (slot_index, binding).
        bindings: Vec<(u64, RecordedBinding)>,
        /// Grid size or threadgroup count (interpretation depends on `dispatch_kind`).
        threads_per_grid: MTLSize,
        /// Threads per threadgroup.
        threads_per_threadgroup: MTLSize,
        /// Optional threadgroup memory allocations: (index, byte_length).
        threadgroup_memory: Vec<(u64, u64)>,
        /// Whether this is a dispatch_threads or dispatch_thread_groups call.
        dispatch_kind: DispatchKind,
        /// Operation kind tag for the fusion pass (4e.2).
        /// Defaults to `Other` if not explicitly set via `set_op_kind()`.
        op_kind: CapturedOpKind,
        /// Read buffer ranges for reorder conflict detection (4e.3).
        /// Populated from `barrier_between` calls in capture mode.
        reads: Vec<MemRange>,
        /// Write buffer ranges for reorder conflict detection (4e.3).
        /// Populated from `barrier_between` calls in capture mode.
        writes: Vec<MemRange>,
    },
    /// A memory barrier sentinel — forces a barrier at replay time.
    Barrier,
}
159
160/// Apply a slice of `KernelArg` bindings to a compute encoder.
161#[inline]
162fn apply_bindings(encoder: &ComputeCommandEncoderRef, bindings: &[(u64, KernelArg<'_>)]) {
163 for &(index, ref arg) in bindings {
164 match arg {
165 KernelArg::Buffer(buf) => {
166 encoder.set_buffer(index, Some(buf.metal_buffer()), 0);
167 }
168 KernelArg::BufferWithOffset(buf, offset) => {
169 encoder.set_buffer(index, Some(buf.metal_buffer()), *offset);
170 }
171 KernelArg::Bytes(bytes) => {
172 encoder.set_bytes(index, bytes.len() as u64, bytes.as_ptr() as *const _);
173 }
174 }
175 }
176}
177
/// Number of times `commit_and_wait()` has been called (CPU sync points).
static SYNC_COUNT: AtomicU64 = AtomicU64::new(0);

/// Number of times an encode method has been called (GPU dispatches).
static DISPATCH_COUNT: AtomicU64 = AtomicU64::new(0);

/// Reset both `SYNC_COUNT` and `DISPATCH_COUNT` to zero.
pub fn reset_counters() {
    // Relaxed is sufficient: these are pure statistics counters with no
    // ordering relationship to other memory.
    for counter in [&SYNC_COUNT, &DISPATCH_COUNT] {
        counter.store(0, Ordering::Relaxed);
    }
}

/// Current value of `SYNC_COUNT` — incremented by every `commit_and_wait()`.
pub fn sync_count() -> u64 {
    SYNC_COUNT.load(Ordering::Relaxed)
}

/// Current value of `DISPATCH_COUNT` — incremented by every `encode()`,
/// `encode_threadgroups()`, or `encode_threadgroups_with_shared()` call
/// (and the `*_with_args` variants).
pub fn dispatch_count() -> u64 {
    DISPATCH_COUNT.load(Ordering::Relaxed)
}
204
/// A batched compute command encoder.
///
/// Keeps a single Metal `ComputeCommandEncoder` alive across multiple
/// dispatches. The encoder is created on the first dispatch and ended
/// only when the command buffer is committed. This mirrors candle's
/// `compute_per_buffer` pattern and avoids per-dispatch encoder overhead.
///
/// # Typical usage
///
/// ```ignore
/// let mut enc = device.command_encoder()?;
/// // Multiple dispatches share the same compute encoder:
/// enc.encode_threadgroups(pipeline1, &buffers1, tg1, tg_size1);
/// enc.encode_threadgroups(pipeline2, &buffers2, tg2, tg_size2);
/// enc.commit_and_wait()?;
/// ```
pub struct CommandEncoder {
    /// The underlying Metal command buffer; owned for the encoder's lifetime.
    cmd_buf: CommandBuffer,
    // SAFETY marker: see unsafe Send impl below.
    /// Raw pointer to the persistent compute encoder.
    /// Null when no compute pass is active; non-null while one is.
    /// The encoder borrows from `cmd_buf` but we cannot express this
    /// lifetime in safe Rust, so we use a raw pointer.
    /// SAFETY: the pointer is valid as long as `cmd_buf` is alive and
    /// `end_encoding()` has not been called on it.
    active_encoder: *const ComputeCommandEncoderRef,
    /// When `Some`, dispatches are recorded here instead of being encoded
    /// into Metal. Set via `start_capture()`, extracted via `take_capture()`.
    capture: Option<Vec<CapturedNode>>,
    /// Op kind tag for the NEXT captured dispatch. Set via `set_op_kind()`,
    /// consumed (reset to `Other`) when a dispatch is captured.
    pending_op_kind: CapturedOpKind,
    /// Pending read buffer ranges for the NEXT captured dispatch.
    /// Set via `set_pending_buffer_ranges()`, consumed when the next dispatch
    /// is captured. Used by the reorder pass (Phase 4e.3).
    pending_reads: Vec<MemRange>,
    /// Pending write buffer ranges for the NEXT captured dispatch.
    /// Same lifecycle as `pending_reads`.
    pending_writes: Vec<MemRange>,
}
244
/// SAFETY: CommandEncoder is safe to Send across threads provided that:
/// 1. Only one thread accesses the encoder at a time (exclusive ownership).
/// 2. The encoder is not used concurrently from multiple threads.
///
/// Metal command buffers and compute encoders are thread-safe for exclusive
/// access (Apple documentation: "You can create command buffers, encode
/// commands, and submit them from any thread"). The raw pointer
/// `active_encoder` borrows from `cmd_buf` and is valid as long as
/// `cmd_buf` is alive — this invariant holds across thread boundaries
/// because both fields move together.
///
/// Note: `Sync` is deliberately NOT implemented — shared references from
/// multiple threads would violate the exclusive-access requirement above.
///
/// This matches llama.cpp's pattern of encoding command buffers on GCD
/// worker threads via `dispatch_apply`, and is used for the dual-buffer
/// pipeline where buf1 is encoded on a worker thread while buf0 executes.
unsafe impl Send for CommandEncoder {}
260
261impl CommandEncoder {
262 /// Create a new command encoder from the given command queue.
263 ///
264 /// This immediately creates a Metal command buffer.
265 pub(crate) fn new(queue: &CommandQueue) -> Result<Self> {
266 let cmd_buf = queue.new_command_buffer().to_owned();
267 Ok(Self {
268 cmd_buf,
269 active_encoder: std::ptr::null(),
270 capture: None,
271 pending_op_kind: CapturedOpKind::Other,
272 pending_reads: Vec::new(),
273 pending_writes: Vec::new(),
274 })
275 }
276
277 /// Enable capture mode.
278 ///
279 /// All subsequent dispatch and barrier calls will be recorded into a
280 /// `Vec<CapturedNode>` instead of being encoded into Metal.
281 /// Call `take_capture()` to extract the recorded nodes.
282 pub fn start_capture(&mut self) {
283 self.capture = Some(Vec::with_capacity(128));
284 }
285
286 /// Whether the encoder is currently in capture mode.
287 pub fn is_capturing(&self) -> bool {
288 self.capture.is_some()
289 }
290
291 /// Extract the captured nodes, ending capture mode.
292 ///
293 /// Returns `None` if capture mode was not active.
294 pub fn take_capture(&mut self) -> Option<Vec<CapturedNode>> {
295 self.capture.take()
296 }
297
298 /// Tag the NEXT captured dispatch with the given operation kind.
299 ///
300 /// The tag is consumed (reset to `Other`) after the next dispatch is
301 /// captured. Only meaningful in capture mode — has no effect on
302 /// direct-dispatch encoding.
303 ///
304 /// Used by op dispatch functions to annotate captures for the fusion
305 /// pass (Phase 4e.2).
306 pub fn set_op_kind(&mut self, kind: CapturedOpKind) {
307 self.pending_op_kind = kind;
308 }
309
310 /// Consume and return the pending op kind, resetting it to `Other`.
311 fn take_pending_op_kind(&mut self) -> CapturedOpKind {
312 let kind = self.pending_op_kind;
313 self.pending_op_kind = CapturedOpKind::Other;
314 kind
315 }
316
317 /// Stash buffer range annotations for the NEXT captured dispatch.
318 ///
319 /// Called by `GraphSession::barrier_between()` in capture mode to record
320 /// which buffers the next dispatch reads from and writes to. The ranges
321 /// are consumed by the next `encode_*` call and attached to the captured
322 /// `CapturedNode::Dispatch`.
323 ///
324 /// Only meaningful in capture mode — has no effect on direct-dispatch.
325 pub fn set_pending_buffer_ranges(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
326 self.pending_reads = reads;
327 self.pending_writes = writes;
328 }
329
330 /// Patch the last captured dispatch node's empty reads/writes with the
331 /// given ranges. No-op if not capturing, or if the last node isn't a
332 /// Dispatch, or if its ranges are already populated.
333 ///
334 /// Used by `GraphSession::track_dispatch` in recording mode to annotate
335 /// dispatches that were called without a preceding `barrier_between`.
336 pub fn annotate_last_dispatch_if_missing(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
337 if let Some(ref mut nodes) = self.capture {
338 if let Some(CapturedNode::Dispatch { reads: r, writes: w, .. }) = nodes.last_mut() {
339 if r.is_empty() && !reads.is_empty() {
340 *r = reads;
341 }
342 if w.is_empty() && !writes.is_empty() {
343 *w = writes;
344 }
345 }
346 }
347 }
348
349 /// Consume and return the pending buffer range annotations.
350 fn take_pending_buffer_ranges(&mut self) -> (Vec<MemRange>, Vec<MemRange>) {
351 let reads = std::mem::take(&mut self.pending_reads);
352 let writes = std::mem::take(&mut self.pending_writes);
353 (reads, writes)
354 }
355
356 /// Record buffer bindings into `RecordedBinding` form.
357 fn record_buffer_bindings(buffers: &[(u64, &MlxBuffer)]) -> Vec<(u64, RecordedBinding)> {
358 buffers
359 .iter()
360 .map(|&(index, buf)| {
361 (
362 index,
363 RecordedBinding::Buffer {
364 metal_buffer: buf.metal_buffer().clone(),
365 offset: 0,
366 },
367 )
368 })
369 .collect()
370 }
371
372 /// Record `KernelArg` bindings into `RecordedBinding` form.
373 fn record_arg_bindings(bindings: &[(u64, KernelArg<'_>)]) -> Vec<(u64, RecordedBinding)> {
374 bindings
375 .iter()
376 .map(|(index, arg)| {
377 let recorded = match arg {
378 KernelArg::Buffer(buf) => RecordedBinding::Buffer {
379 metal_buffer: buf.metal_buffer().clone(),
380 offset: 0,
381 },
382 KernelArg::BufferWithOffset(buf, offset) => RecordedBinding::Buffer {
383 metal_buffer: buf.metal_buffer().clone(),
384 offset: *offset,
385 },
386 KernelArg::Bytes(bytes) => RecordedBinding::Bytes(bytes.to_vec()),
387 };
388 (*index, recorded)
389 })
390 .collect()
391 }
392
    /// Get or create the persistent compute encoder.
    ///
    /// On the first call, creates a new compute encoder from the command
    /// buffer. On subsequent calls, returns the existing one.
    ///
    /// SAFETY: The returned reference borrows from `self.cmd_buf` which is
    /// alive for the lifetime of this `CommandEncoder`. The raw pointer is
    /// valid until `end_active_encoder()` is called.
    #[inline]
    fn get_or_create_encoder(&mut self) -> &ComputeCommandEncoderRef {
        if self.active_encoder.is_null() {
            // Use MTLDispatchTypeConcurrent to allow independent dispatches
            // to overlap on the GPU. Memory barriers are inserted between
            // dependent dispatches via `memory_barrier()`.
            let encoder = self
                .cmd_buf
                .compute_command_encoder_with_dispatch_type(MTLDispatchType::Concurrent);
            self.active_encoder = encoder as *const ComputeCommandEncoderRef;
        }
        // SAFETY: active_encoder is non-null and points to a valid encoder
        // owned by cmd_buf (just created above, or left over from a previous
        // dispatch with no intervening end_active_encoder()).
        unsafe { &*self.active_encoder }
    }
416
    /// End the active compute encoder if one exists; no-op otherwise.
    ///
    /// Must run before committing or dropping the command buffer — Metal
    /// requires `endEncoding` before either.
    #[inline]
    fn end_active_encoder(&mut self) {
        if !self.active_encoder.is_null() {
            // SAFETY: the pointer was obtained from cmd_buf.new_compute_command_encoder()
            // and has not been ended yet (we null it out immediately below,
            // so end_encoding() can never be called twice).
            unsafe { &*self.active_encoder }.end_encoding();
            self.active_encoder = std::ptr::null();
        }
    }
427
    /// Insert a memory barrier with scope `MTLBarrierScopeBuffers`.
    ///
    /// When the encoder uses `MTLDispatchTypeConcurrent`, all dispatches can
    /// execute concurrently unless separated by a barrier. Call this between
    /// dispatches where the later dispatch reads a buffer written by an
    /// earlier one.
    ///
    /// In capture mode, records a `CapturedNode::Barrier` sentinel instead.
    /// If no compute pass is active yet, this is a no-op (there is nothing
    /// to order against).
    ///
    /// This is the same pattern llama.cpp uses:
    /// `[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]`
    #[allow(unexpected_cfgs)]
    pub fn memory_barrier(&mut self) {
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Barrier);
            return;
        }
        if self.active_encoder.is_null() {
            return;
        }
        // SAFETY: active_encoder is non-null and valid (checked above; only
        // end_active_encoder() invalidates it, and that nulls the pointer).
        let encoder = unsafe { &*self.active_encoder };
        // MTLBarrierScopeBuffers = 1 << 0 = 1 — raw msg_send because the
        // metal crate does not expose memoryBarrierWithScope:.
        const MTL_BARRIER_SCOPE_BUFFERS: u64 = 1;
        unsafe {
            let _: () = objc::msg_send![encoder, memoryBarrierWithScope: MTL_BARRIER_SCOPE_BUFFERS];
        }
    }
454
455 /// Set the compute pipeline state for subsequent dispatches.
456 ///
457 /// This begins a new compute pass if one is not already active.
458 pub fn set_pipeline(&mut self, pipeline: &ComputePipelineStateRef) {
459 let encoder = self.get_or_create_encoder();
460 encoder.set_compute_pipeline_state(pipeline);
461 }
462
    /// Bind a buffer to a compute kernel argument slot.
    ///
    /// NOTE(review): this method is currently a no-op — both arguments are
    /// discarded. Actual buffer binding happens inside the `encode_*`
    /// methods. Confirm whether this stub exists only for API compatibility,
    /// and either implement it or deprecate it.
    ///
    /// The `index` corresponds to the `[[buffer(N)]]` attribute in the MSL shader.
    pub fn set_buffer(&self, index: u64, buffer: &MlxBuffer) {
        let _ = (index, buffer);
    }
469
    /// Dispatch threads on the GPU.
    ///
    /// NOTE(review): this method is currently a no-op — both arguments are
    /// discarded. Actual dispatching happens inside the `encode_*` methods.
    /// Confirm whether this stub is intentional (API compatibility) before
    /// relying on it.
    pub fn dispatch_threads(&self, grid_size: MTLSize, threadgroup_size: MTLSize) {
        let _ = (grid_size, threadgroup_size);
    }
474
475 /// Encode a complete compute pass: set pipeline, bind buffers, dispatch.
476 ///
477 /// Reuses the persistent compute encoder — no per-dispatch encoder
478 /// creation overhead.
479 ///
480 /// # Arguments
481 ///
482 /// * `pipeline` — The compiled compute pipeline to execute.
483 /// * `buffers` — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
484 /// * `grid_size` — Total number of threads to launch.
485 /// * `threadgroup_size` — Threads per threadgroup.
486 pub fn encode(
487 &mut self,
488 pipeline: &ComputePipelineStateRef,
489 buffers: &[(u64, &MlxBuffer)],
490 grid_size: MTLSize,
491 threadgroup_size: MTLSize,
492 ) {
493 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
494 let op_kind = self.take_pending_op_kind();
495 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
496 if let Some(ref mut nodes) = self.capture {
497 nodes.push(CapturedNode::Dispatch {
498 pipeline: pipeline.to_owned(),
499 bindings: Self::record_buffer_bindings(buffers),
500 threads_per_grid: grid_size,
501 threads_per_threadgroup: threadgroup_size,
502 threadgroup_memory: Vec::new(),
503 dispatch_kind: DispatchKind::Threads,
504 op_kind,
505 reads: pending_reads,
506 writes: pending_writes,
507 });
508 return;
509 }
510 let encoder = self.get_or_create_encoder();
511 encoder.set_compute_pipeline_state(pipeline);
512 for &(index, buf) in buffers {
513 encoder.set_buffer(index, Some(buf.metal_buffer()), 0);
514 }
515 encoder.dispatch_threads(grid_size, threadgroup_size);
516 }
517
518 /// Encode a compute pass using threadgroups instead of raw thread counts.
519 ///
520 /// Reuses the persistent compute encoder — no per-dispatch encoder
521 /// creation overhead.
522 pub fn encode_threadgroups(
523 &mut self,
524 pipeline: &ComputePipelineStateRef,
525 buffers: &[(u64, &MlxBuffer)],
526 threadgroups: MTLSize,
527 threadgroup_size: MTLSize,
528 ) {
529 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
530 let op_kind = self.take_pending_op_kind();
531 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
532 if let Some(ref mut nodes) = self.capture {
533 nodes.push(CapturedNode::Dispatch {
534 pipeline: pipeline.to_owned(),
535 bindings: Self::record_buffer_bindings(buffers),
536 threads_per_grid: threadgroups,
537 threads_per_threadgroup: threadgroup_size,
538 threadgroup_memory: Vec::new(),
539 dispatch_kind: DispatchKind::ThreadGroups,
540 op_kind,
541 reads: pending_reads,
542 writes: pending_writes,
543 });
544 return;
545 }
546 let encoder = self.get_or_create_encoder();
547 encoder.set_compute_pipeline_state(pipeline);
548 for &(index, buf) in buffers {
549 encoder.set_buffer(index, Some(buf.metal_buffer()), 0);
550 }
551 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
552 }
553
554 /// Encode a compute pass using threadgroups with shared threadgroup memory.
555 ///
556 /// Like [`encode_threadgroups`](Self::encode_threadgroups), but additionally
557 /// allocates threadgroup memory at the specified indices. This is required
558 /// for kernels that use `threadgroup` memory (e.g. reductions in rms_norm
559 /// and softmax).
560 ///
561 /// # Arguments
562 ///
563 /// * `pipeline` — The compiled compute pipeline to execute.
564 /// * `buffers` — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
565 /// * `threadgroup_mem` — Slice of `(index, byte_length)` pairs for threadgroup memory.
566 /// * `threadgroups` — Number of threadgroups to dispatch.
567 /// * `threadgroup_size` — Threads per threadgroup.
568 pub fn encode_threadgroups_with_shared(
569 &mut self,
570 pipeline: &ComputePipelineStateRef,
571 buffers: &[(u64, &MlxBuffer)],
572 threadgroup_mem: &[(u64, u64)],
573 threadgroups: MTLSize,
574 threadgroup_size: MTLSize,
575 ) {
576 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
577 let op_kind = self.take_pending_op_kind();
578 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
579 if let Some(ref mut nodes) = self.capture {
580 nodes.push(CapturedNode::Dispatch {
581 pipeline: pipeline.to_owned(),
582 bindings: Self::record_buffer_bindings(buffers),
583 threads_per_grid: threadgroups,
584 threads_per_threadgroup: threadgroup_size,
585 threadgroup_memory: threadgroup_mem.to_vec(),
586 dispatch_kind: DispatchKind::ThreadGroups,
587 op_kind,
588 reads: pending_reads,
589 writes: pending_writes,
590 });
591 return;
592 }
593 let encoder = self.get_or_create_encoder();
594 encoder.set_compute_pipeline_state(pipeline);
595 for &(index, buf) in buffers {
596 encoder.set_buffer(index, Some(buf.metal_buffer()), 0);
597 }
598 for &(index, byte_length) in threadgroup_mem {
599 encoder.set_threadgroup_memory_length(index, byte_length);
600 }
601 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
602 }
603
604 /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_threads).
605 ///
606 /// Reuses the persistent compute encoder.
607 pub fn encode_with_args(
608 &mut self,
609 pipeline: &ComputePipelineStateRef,
610 bindings: &[(u64, KernelArg<'_>)],
611 grid_size: MTLSize,
612 threadgroup_size: MTLSize,
613 ) {
614 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
615 let op_kind = self.take_pending_op_kind();
616 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
617 if let Some(ref mut nodes) = self.capture {
618 nodes.push(CapturedNode::Dispatch {
619 pipeline: pipeline.to_owned(),
620 bindings: Self::record_arg_bindings(bindings),
621 threads_per_grid: grid_size,
622 threads_per_threadgroup: threadgroup_size,
623 threadgroup_memory: Vec::new(),
624 dispatch_kind: DispatchKind::Threads,
625 op_kind,
626 reads: pending_reads,
627 writes: pending_writes,
628 });
629 return;
630 }
631 let encoder = self.get_or_create_encoder();
632 encoder.set_compute_pipeline_state(pipeline);
633 apply_bindings(encoder, bindings);
634 encoder.dispatch_threads(grid_size, threadgroup_size);
635 }
636
637 /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_thread_groups).
638 ///
639 /// Reuses the persistent compute encoder.
640 pub fn encode_threadgroups_with_args(
641 &mut self,
642 pipeline: &ComputePipelineStateRef,
643 bindings: &[(u64, KernelArg<'_>)],
644 threadgroups: MTLSize,
645 threadgroup_size: MTLSize,
646 ) {
647 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
648 let op_kind = self.take_pending_op_kind();
649 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
650 if let Some(ref mut nodes) = self.capture {
651 nodes.push(CapturedNode::Dispatch {
652 pipeline: pipeline.to_owned(),
653 bindings: Self::record_arg_bindings(bindings),
654 threads_per_grid: threadgroups,
655 threads_per_threadgroup: threadgroup_size,
656 threadgroup_memory: Vec::new(),
657 dispatch_kind: DispatchKind::ThreadGroups,
658 op_kind,
659 reads: pending_reads,
660 writes: pending_writes,
661 });
662 return;
663 }
664 let encoder = self.get_or_create_encoder();
665 encoder.set_compute_pipeline_state(pipeline);
666 apply_bindings(encoder, bindings);
667 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
668 }
669
670 /// Encode a dispatch with mixed buffer/bytes bindings and shared memory.
671 ///
672 /// Reuses the persistent compute encoder.
673 pub fn encode_threadgroups_with_args_and_shared(
674 &mut self,
675 pipeline: &ComputePipelineStateRef,
676 bindings: &[(u64, KernelArg<'_>)],
677 threadgroup_mem: &[(u64, u64)],
678 threadgroups: MTLSize,
679 threadgroup_size: MTLSize,
680 ) {
681 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
682 let op_kind = self.take_pending_op_kind();
683 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
684 if let Some(ref mut nodes) = self.capture {
685 nodes.push(CapturedNode::Dispatch {
686 pipeline: pipeline.to_owned(),
687 bindings: Self::record_arg_bindings(bindings),
688 threads_per_grid: threadgroups,
689 threads_per_threadgroup: threadgroup_size,
690 threadgroup_memory: threadgroup_mem.to_vec(),
691 dispatch_kind: DispatchKind::ThreadGroups,
692 op_kind,
693 reads: pending_reads,
694 writes: pending_writes,
695 });
696 return;
697 }
698 let encoder = self.get_or_create_encoder();
699 encoder.set_compute_pipeline_state(pipeline);
700 apply_bindings(encoder, bindings);
701 for &(index, byte_length) in threadgroup_mem {
702 encoder.set_threadgroup_memory_length(index, byte_length);
703 }
704 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
705 }
706
707 /// Replay a single captured dispatch node into this encoder.
708 ///
709 /// This is the inverse of capture: it takes a previously recorded
710 /// `CapturedNode::Dispatch` and encodes it into the live Metal encoder.
711 /// Barrier nodes are handled by the caller (ComputeGraph::encode_sequential).
712 ///
713 /// Does NOT increment `DISPATCH_COUNT` — that was already counted at
714 /// capture time.
715 pub fn replay_dispatch(
716 &mut self,
717 pipeline: &ComputePipelineStateRef,
718 bindings: &[(u64, RecordedBinding)],
719 threadgroup_memory: &[(u64, u64)],
720 threads_per_grid: MTLSize,
721 threads_per_threadgroup: MTLSize,
722 dispatch_kind: DispatchKind,
723 ) {
724 let encoder = self.get_or_create_encoder();
725 encoder.set_compute_pipeline_state(pipeline);
726 for (index, binding) in bindings {
727 match binding {
728 RecordedBinding::Buffer { metal_buffer, offset } => {
729 encoder.set_buffer(*index, Some(metal_buffer), *offset);
730 }
731 RecordedBinding::Bytes(bytes) => {
732 encoder.set_bytes(
733 *index,
734 bytes.len() as u64,
735 bytes.as_ptr() as *const _,
736 );
737 }
738 }
739 }
740 for &(index, byte_length) in threadgroup_memory {
741 encoder.set_threadgroup_memory_length(index, byte_length);
742 }
743 match dispatch_kind {
744 DispatchKind::Threads => {
745 encoder.dispatch_threads(threads_per_grid, threads_per_threadgroup);
746 }
747 DispatchKind::ThreadGroups => {
748 encoder.dispatch_thread_groups(threads_per_grid, threads_per_threadgroup);
749 }
750 }
751 }
752
753 /// Commit the command buffer and block until the GPU finishes execution.
754 ///
755 /// # Errors
756 ///
757 /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
758 pub fn commit_and_wait(&mut self) -> Result<()> {
759 SYNC_COUNT.fetch_add(1, Ordering::Relaxed);
760
761 // End the persistent compute encoder before committing.
762 self.end_active_encoder();
763
764 self.cmd_buf.commit();
765 self.cmd_buf.wait_until_completed();
766
767 match self.cmd_buf.status() {
768 MTLCommandBufferStatus::Completed => Ok(()),
769 MTLCommandBufferStatus::Error => {
770 Err(MlxError::CommandBufferError(
771 "GPU command buffer completed with error status".into(),
772 ))
773 }
774 status => Err(MlxError::CommandBufferError(format!(
775 "Unexpected command buffer status after wait: {:?}",
776 status
777 ))),
778 }
779 }
780
781 /// Commit the command buffer WITHOUT blocking.
782 ///
783 /// The GPU begins executing the encoded commands immediately. Call
784 /// [`wait_until_completed`](Self::wait_until_completed) later to block
785 /// the CPU and check for errors. This allows the CPU to continue doing
786 /// other work (e.g. preparing the next batch) while the GPU runs.
787 pub fn commit(&mut self) {
788 self.end_active_encoder();
789 self.cmd_buf.commit();
790 }
791
792 /// Block until a previously committed command buffer completes.
793 ///
794 /// Must be called after [`commit`](Self::commit). Do not call after
795 /// [`commit_and_wait`](Self::commit_and_wait) — that method already waits.
796 ///
797 /// # Errors
798 ///
799 /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
800 pub fn wait_until_completed(&self) -> Result<()> {
801 self.cmd_buf.wait_until_completed();
802 match self.cmd_buf.status() {
803 MTLCommandBufferStatus::Completed => Ok(()),
804 MTLCommandBufferStatus::Error => Err(MlxError::CommandBufferError(
805 "GPU command buffer completed with error status".into(),
806 )),
807 status => Err(MlxError::CommandBufferError(format!(
808 "Unexpected command buffer status after wait: {:?}",
809 status
810 ))),
811 }
812 }
813
    /// Borrow the underlying Metal command buffer.
    ///
    /// Gives callers direct access to the raw `CommandBuffer` (e.g. for
    /// status inspection) without taking ownership.
    #[inline]
    pub fn metal_command_buffer(&self) -> &CommandBuffer {
        &self.cmd_buf
    }
819}
820
impl Drop for CommandEncoder {
    fn drop(&mut self) {
        // End the persistent compute encoder before the command buffer
        // is dropped, otherwise Metal will assert:
        //   "Command encoder released without endEncoding"
        // end_active_encoder() is a no-op if no pass is active, so this is
        // safe after commit()/commit_and_wait() as well.
        self.end_active_encoder();
    }
}