// mlx_native/encoder.rs
1//! [`CommandEncoder`] — batched GPU command submission.
2//!
3//! Wraps a Metal command buffer. Encode one or more compute kernel dispatches,
4//! then call [`commit_and_wait`](CommandEncoder::commit_and_wait) to submit the
5//! entire batch and block until the GPU finishes.
6//!
7//! # Persistent compute encoder
8//!
9//! A single Metal `ComputeCommandEncoder` is kept alive across multiple
10//! dispatches within the same command buffer. This avoids the overhead of
11//! creating and ending a new compute encoder per dispatch — the same pattern
12//! candle uses (`compute_per_buffer`). On a forward pass with ~800 dispatches
13//! this saves ~800 encoder create/end cycles.
14//!
15//! # Capture mode (Phase 4e.1)
16//!
17//! When `start_capture()` is called, subsequent dispatches are recorded into a
18//! `Vec<CapturedNode>` instead of being encoded into Metal. `memory_barrier()`
19//! records a barrier sentinel. Call `take_capture()` to extract the recorded
20//! graph for later replay via `ComputeGraph::encode_sequential()`.
21
22use std::sync::atomic::{AtomicU64, Ordering};
23
24use metal::{
25 CommandBuffer, CommandQueue, ComputeCommandEncoderRef, ComputePipelineState,
26 ComputePipelineStateRef, MTLCommandBufferStatus, MTLDispatchType, MTLSize,
27};
28#[allow(unused_imports)]
29use objc::{msg_send, sel, sel_impl};
30
31use crate::buffer::MlxBuffer;
32use crate::error::{MlxError, Result};
33
/// A buffer or inline-bytes binding for a compute kernel argument slot.
///
/// The slot index itself travels alongside the `KernelArg` in the
/// `(u64, KernelArg)` pairs accepted by the `encode_*_with_args` methods;
/// the variant only describes *what* is bound at that slot.
pub enum KernelArg<'a> {
    /// Bind an existing Metal buffer at the given index (byte offset 0).
    Buffer(&'a MlxBuffer),
    /// Bind an existing Metal buffer at the given index with a byte offset.
    BufferWithOffset(&'a MlxBuffer, u64),
    /// Bind inline bytes (small constant data) at the given index.
    /// The data must be `Pod` and is copied into the command encoder.
    Bytes(&'a [u8]),
}
44
/// Convert a `Pod` value to a byte slice suitable for `KernelArg::Bytes`.
///
/// This function is memory-safe for any `T: bytemuck::Pod` — the `Pod`
/// bound guarantees the value is plain-old-data with no padding bytes or
/// invalid bit patterns, so no `unsafe` is involved on the CPU side.
///
/// # Correctness
///
/// The caller must ensure `T` has the same layout as the corresponding
/// MSL struct in the shader (matching field order, sizes, and alignment).
/// A mismatch produces garbage kernel arguments on the GPU, not UB here.
pub fn as_bytes<T: bytemuck::Pod>(val: &T) -> &[u8] {
    bytemuck::bytes_of(val)
}
54
55// ---------------------------------------------------------------------------
56// Capture-mode types (Phase 4e.1 — Graph IR)
57// ---------------------------------------------------------------------------
58
/// A recorded kernel argument binding.
///
/// When the encoder is in capture mode, each `set_buffer` / `set_bytes` call
/// is stored as a `RecordedBinding` instead of being applied to Metal.
#[derive(Clone)]
pub enum RecordedBinding {
    /// A Metal buffer at the given offset.
    ///
    /// Cloning the `metal::Buffer` handle retains the underlying Obj-C
    /// object, so a captured graph keeps its buffers alive until replay.
    Buffer {
        /// Retained handle to the bound Metal buffer.
        metal_buffer: metal::Buffer,
        /// Byte offset applied when the buffer is re-bound at replay time.
        offset: u64,
    },
    /// Inline bytes (small constant data, copied out of the caller's slice).
    Bytes(Vec<u8>),
}
73
/// How to dispatch the recorded kernel.
///
/// Mirrors the two Metal dispatch entry points so a captured node can be
/// replayed with exactly the call that was originally requested.
#[derive(Clone, Copy, Debug)]
pub enum DispatchKind {
    /// `dispatch_threads(grid_size, threadgroup_size)` — Metal picks the
    /// threadgroup count from the total thread grid.
    Threads,
    /// `dispatch_thread_groups(threadgroups, threadgroup_size)` — caller
    /// specifies the threadgroup count explicitly.
    ThreadGroups,
}
82
/// Operation kind tag for captured nodes, used by the fusion pass (4e.2).
///
/// When the encoder is in capture mode, each dispatch can be tagged with an
/// `OpKind` so the fusion pass can identify fuseable sequences without
/// inspecting pipeline names.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CapturedOpKind {
    /// RMS normalization (with learned scale).
    RmsNorm,
    /// Elementwise multiply.
    ElemMul,
    /// Elementwise add.
    ElemAdd,
    /// Scaled dot-product attention (NOT reorderable — breaks lookahead).
    Sdpa,
    /// Softmax (NOT reorderable — breaks lookahead).
    Softmax,
    /// Any other operation — treated as reorderable by the graph optimizer.
    Other,
}

impl CapturedOpKind {
    /// Whether this captured op kind is safe to reorder past in the graph
    /// optimizer (Phase 4e.3).
    ///
    /// Mirrors the `h_safe` whitelist from llama.cpp's
    /// `ggml_metal_graph_optimize_reorder`. Non-safe ops break the 64-node
    /// lookahead — the reorder pass cannot look past them.
    pub fn is_reorderable(&self) -> bool {
        // Deliberately exhaustive: adding a new variant forces an explicit
        // reorderability decision here instead of a silent default.
        match *self {
            CapturedOpKind::Sdpa | CapturedOpKind::Softmax => false,
            CapturedOpKind::RmsNorm
            | CapturedOpKind::ElemMul
            | CapturedOpKind::ElemAdd
            | CapturedOpKind::Other => true,
        }
    }
}
118
/// A memory range annotation: `(start_address, end_address)`.
///
/// Represents a contiguous GPU buffer region for conflict detection in the
/// reorder pass (Phase 4e.3). Addresses are CPU-visible `contents_ptr()`
/// values, which on Apple Silicon unified memory equal the GPU addresses.
///
/// NOTE(review): whether the end address is inclusive or half-open is set by
/// the producer (`barrier_between`) — confirm before writing overlap checks.
pub type MemRange = (usize, usize);
125
/// A single captured compute dispatch or barrier sentinel.
///
/// Created when the encoder is in capture mode. Replayed later by
/// `ComputeGraph::encode_sequential()`.
#[derive(Clone)]
pub enum CapturedNode {
    /// A compute dispatch to replay.
    Dispatch {
        /// Pipeline state object to bind.
        pipeline: ComputePipelineState,
        /// Kernel argument bindings: (slot_index, binding).
        bindings: Vec<(u64, RecordedBinding)>,
        /// Grid size for `DispatchKind::Threads`, or threadgroup count for
        /// `DispatchKind::ThreadGroups` — despite the field name, the
        /// interpretation depends on `dispatch_kind`.
        threads_per_grid: MTLSize,
        /// Threads per threadgroup.
        threads_per_threadgroup: MTLSize,
        /// Optional threadgroup memory allocations: (index, byte_length).
        /// Empty for kernels that use no `threadgroup` memory.
        threadgroup_memory: Vec<(u64, u64)>,
        /// Whether this is a dispatch_threads or dispatch_thread_groups call.
        dispatch_kind: DispatchKind,
        /// Operation kind tag for the fusion pass (4e.2).
        /// Defaults to `Other` if not explicitly set via `set_op_kind()`.
        op_kind: CapturedOpKind,
        /// Read buffer ranges for reorder conflict detection (4e.3).
        /// Populated from `barrier_between` calls in capture mode; empty
        /// when the producer supplied no annotations.
        reads: Vec<MemRange>,
        /// Write buffer ranges for reorder conflict detection (4e.3).
        /// Populated from `barrier_between` calls in capture mode.
        writes: Vec<MemRange>,
    },
    /// A memory barrier sentinel — forces a barrier at replay time.
    Barrier,
}
159
160/// Apply a slice of `KernelArg` bindings to a compute encoder.
161#[inline]
162fn apply_bindings(encoder: &ComputeCommandEncoderRef, bindings: &[(u64, KernelArg<'_>)]) {
163 for &(index, ref arg) in bindings {
164 match arg {
165 KernelArg::Buffer(buf) => {
166 encoder.set_buffer(index, Some(buf.metal_buffer()), 0);
167 }
168 KernelArg::BufferWithOffset(buf, offset) => {
169 encoder.set_buffer(index, Some(buf.metal_buffer()), *offset);
170 }
171 KernelArg::Bytes(bytes) => {
172 encoder.set_bytes(index, bytes.len() as u64, bytes.as_ptr() as *const _);
173 }
174 }
175 }
176}
177
/// Number of times `commit_and_wait()` has been called (CPU sync points).
static SYNC_COUNT: AtomicU64 = AtomicU64::new(0);

/// Number of times an encode method has been called (GPU dispatches).
static DISPATCH_COUNT: AtomicU64 = AtomicU64::new(0);

/// Reset both `SYNC_COUNT` and `DISPATCH_COUNT` to zero.
pub fn reset_counters() {
    for counter in [&SYNC_COUNT, &DISPATCH_COUNT] {
        counter.store(0, Ordering::Relaxed);
    }
}

/// Read the current value of `SYNC_COUNT`.
///
/// Each call to `commit_and_wait()` increments this counter.
pub fn sync_count() -> u64 {
    SYNC_COUNT.load(Ordering::Relaxed)
}

/// Read the current value of `DISPATCH_COUNT`.
///
/// Every `encode*` dispatch method — `encode()`, `encode_threadgroups()`,
/// `encode_threadgroups_with_shared()`, and the `*_with_args` variants —
/// increments this counter once per dispatch (including captured ones).
pub fn dispatch_count() -> u64 {
    DISPATCH_COUNT.load(Ordering::Relaxed)
}
204
/// A batched compute command encoder.
///
/// Keeps a single Metal `ComputeCommandEncoder` alive across multiple
/// dispatches. The encoder is created on the first dispatch and ended
/// only when the command buffer is committed. This mirrors candle's
/// `compute_per_buffer` pattern and avoids per-dispatch encoder overhead.
///
/// # Typical usage
///
/// ```ignore
/// let mut enc = device.command_encoder()?;
/// // Multiple dispatches share the same compute encoder:
/// enc.encode_threadgroups(pipeline1, &buffers1, tg1, tg_size1);
/// enc.encode_threadgroups(pipeline2, &buffers2, tg2, tg_size2);
/// enc.commit_and_wait()?;
/// ```
pub struct CommandEncoder {
    /// The Metal command buffer all dispatches are encoded into.
    cmd_buf: CommandBuffer,
    // SAFETY marker: see unsafe Send impl below.
    /// Raw pointer to the persistent compute encoder.
    /// Non-null when a compute pass is active.
    /// The encoder borrows from `cmd_buf` but we cannot express this
    /// lifetime in safe Rust, so we use a raw pointer.
    /// SAFETY: the pointer is valid as long as `cmd_buf` is alive and
    /// `end_encoding()` has not been called on it.
    active_encoder: *const ComputeCommandEncoderRef,
    /// When `Some`, dispatches are recorded here instead of being encoded
    /// into Metal. Set via `start_capture()`, extracted via `take_capture()`.
    capture: Option<Vec<CapturedNode>>,
    /// Op kind tag for the NEXT captured dispatch. Set via `set_op_kind()`,
    /// consumed (reset to `Other`) when a dispatch is captured.
    pending_op_kind: CapturedOpKind,
    /// Pending read buffer ranges for the NEXT captured dispatch.
    /// Set via `set_pending_buffer_ranges()`, consumed when the next dispatch
    /// is captured. Used by the reorder pass (Phase 4e.3).
    pending_reads: Vec<MemRange>,
    /// Pending write buffer ranges for the NEXT captured dispatch.
    /// Consumed together with `pending_reads` on the next dispatch.
    pending_writes: Vec<MemRange>,
}
244
/// SAFETY: CommandEncoder is safe to Send across threads provided that:
/// 1. Only one thread accesses the encoder at a time (exclusive ownership).
/// 2. The encoder is not used concurrently from multiple threads.
///
/// Metal command buffers and compute encoders are thread-safe for exclusive
/// access (Apple documentation: "You can create command buffers, encode
/// commands, and submit them from any thread"). The raw pointer
/// `active_encoder` borrows from `cmd_buf` and is valid as long as
/// `cmd_buf` is alive — this invariant holds across thread boundaries
/// because both fields move together.
///
/// Note that only `Send` is claimed here — `Sync` is deliberately not
/// implemented, so shared (`&`) access from multiple threads stays ruled out.
///
/// This matches llama.cpp's pattern of encoding command buffers on GCD
/// worker threads via `dispatch_apply`, and is used for the dual-buffer
/// pipeline where buf1 is encoded on a worker thread while buf0 executes.
unsafe impl Send for CommandEncoder {}
260
261impl CommandEncoder {
262 /// Create a new command encoder from the given command queue.
263 ///
264 /// This immediately creates a Metal command buffer.
265 pub(crate) fn new(queue: &CommandQueue) -> Result<Self> {
266 let cmd_buf = queue.new_command_buffer().to_owned();
267 Ok(Self {
268 cmd_buf,
269 active_encoder: std::ptr::null(),
270 capture: None,
271 pending_op_kind: CapturedOpKind::Other,
272 pending_reads: Vec::new(),
273 pending_writes: Vec::new(),
274 })
275 }
276
277 /// Enable capture mode.
278 ///
279 /// All subsequent dispatch and barrier calls will be recorded into a
280 /// `Vec<CapturedNode>` instead of being encoded into Metal.
281 /// Call `take_capture()` to extract the recorded nodes.
282 pub fn start_capture(&mut self) {
283 self.capture = Some(Vec::with_capacity(128));
284 }
285
286 /// Whether the encoder is currently in capture mode.
287 pub fn is_capturing(&self) -> bool {
288 self.capture.is_some()
289 }
290
291 /// Extract the captured nodes, ending capture mode.
292 ///
293 /// Returns `None` if capture mode was not active.
294 pub fn take_capture(&mut self) -> Option<Vec<CapturedNode>> {
295 self.capture.take()
296 }
297
298 /// Tag the NEXT captured dispatch with the given operation kind.
299 ///
300 /// The tag is consumed (reset to `Other`) after the next dispatch is
301 /// captured. Only meaningful in capture mode — has no effect on
302 /// direct-dispatch encoding.
303 ///
304 /// Used by op dispatch functions to annotate captures for the fusion
305 /// pass (Phase 4e.2).
306 pub fn set_op_kind(&mut self, kind: CapturedOpKind) {
307 self.pending_op_kind = kind;
308 }
309
310 /// Consume and return the pending op kind, resetting it to `Other`.
311 fn take_pending_op_kind(&mut self) -> CapturedOpKind {
312 let kind = self.pending_op_kind;
313 self.pending_op_kind = CapturedOpKind::Other;
314 kind
315 }
316
317 /// Stash buffer range annotations for the NEXT captured dispatch.
318 ///
319 /// Called by `GraphSession::barrier_between()` in capture mode to record
320 /// which buffers the next dispatch reads from and writes to. The ranges
321 /// are consumed by the next `encode_*` call and attached to the captured
322 /// `CapturedNode::Dispatch`.
323 ///
324 /// Only meaningful in capture mode — has no effect on direct-dispatch.
325 pub fn set_pending_buffer_ranges(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
326 self.pending_reads = reads;
327 self.pending_writes = writes;
328 }
329
330 /// Consume and return the pending buffer range annotations.
331 fn take_pending_buffer_ranges(&mut self) -> (Vec<MemRange>, Vec<MemRange>) {
332 let reads = std::mem::take(&mut self.pending_reads);
333 let writes = std::mem::take(&mut self.pending_writes);
334 (reads, writes)
335 }
336
337 /// Record buffer bindings into `RecordedBinding` form.
338 fn record_buffer_bindings(buffers: &[(u64, &MlxBuffer)]) -> Vec<(u64, RecordedBinding)> {
339 buffers
340 .iter()
341 .map(|&(index, buf)| {
342 (
343 index,
344 RecordedBinding::Buffer {
345 metal_buffer: buf.metal_buffer().clone(),
346 offset: 0,
347 },
348 )
349 })
350 .collect()
351 }
352
353 /// Record `KernelArg` bindings into `RecordedBinding` form.
354 fn record_arg_bindings(bindings: &[(u64, KernelArg<'_>)]) -> Vec<(u64, RecordedBinding)> {
355 bindings
356 .iter()
357 .map(|(index, arg)| {
358 let recorded = match arg {
359 KernelArg::Buffer(buf) => RecordedBinding::Buffer {
360 metal_buffer: buf.metal_buffer().clone(),
361 offset: 0,
362 },
363 KernelArg::BufferWithOffset(buf, offset) => RecordedBinding::Buffer {
364 metal_buffer: buf.metal_buffer().clone(),
365 offset: *offset,
366 },
367 KernelArg::Bytes(bytes) => RecordedBinding::Bytes(bytes.to_vec()),
368 };
369 (*index, recorded)
370 })
371 .collect()
372 }
373
374 /// Get or create the persistent compute encoder.
375 ///
376 /// On the first call, creates a new compute encoder from the command
377 /// buffer. On subsequent calls, returns the existing one.
378 ///
379 /// SAFETY: The returned reference borrows from `self.cmd_buf` which is
380 /// alive for the lifetime of this `CommandEncoder`. The raw pointer is
381 /// valid until `end_active_encoder()` is called.
382 #[inline]
383 fn get_or_create_encoder(&mut self) -> &ComputeCommandEncoderRef {
384 if self.active_encoder.is_null() {
385 // Use MTLDispatchTypeConcurrent to allow independent dispatches
386 // to overlap on the GPU. Memory barriers are inserted between
387 // dependent dispatches via `memory_barrier()`.
388 let encoder = self
389 .cmd_buf
390 .compute_command_encoder_with_dispatch_type(MTLDispatchType::Concurrent);
391 self.active_encoder = encoder as *const ComputeCommandEncoderRef;
392 }
393 // SAFETY: active_encoder is non-null and points to a valid encoder
394 // owned by cmd_buf.
395 unsafe { &*self.active_encoder }
396 }
397
398 /// End the active compute encoder if one exists.
399 #[inline]
400 fn end_active_encoder(&mut self) {
401 if !self.active_encoder.is_null() {
402 // SAFETY: the pointer was obtained from cmd_buf.new_compute_command_encoder()
403 // and has not been ended yet.
404 unsafe { &*self.active_encoder }.end_encoding();
405 self.active_encoder = std::ptr::null();
406 }
407 }
408
409 /// Insert a memory barrier with scope `MTLBarrierScopeBuffers`.
410 ///
411 /// When the encoder uses `MTLDispatchTypeConcurrent`, all dispatches can
412 /// execute concurrently unless separated by a barrier. Call this between
413 /// dispatches where the later dispatch reads a buffer written by an
414 /// earlier one.
415 ///
416 /// This is the same pattern llama.cpp uses:
417 /// `[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]`
418 #[allow(unexpected_cfgs)]
419 pub fn memory_barrier(&mut self) {
420 if let Some(ref mut nodes) = self.capture {
421 nodes.push(CapturedNode::Barrier);
422 return;
423 }
424 if self.active_encoder.is_null() {
425 return;
426 }
427 // SAFETY: active_encoder is non-null and valid.
428 let encoder = unsafe { &*self.active_encoder };
429 // MTLBarrierScopeBuffers = 1 << 0 = 1
430 const MTL_BARRIER_SCOPE_BUFFERS: u64 = 1;
431 unsafe {
432 let _: () = objc::msg_send![encoder, memoryBarrierWithScope: MTL_BARRIER_SCOPE_BUFFERS];
433 }
434 }
435
436 /// Set the compute pipeline state for subsequent dispatches.
437 ///
438 /// This begins a new compute pass if one is not already active.
439 pub fn set_pipeline(&mut self, pipeline: &ComputePipelineStateRef) {
440 let encoder = self.get_or_create_encoder();
441 encoder.set_compute_pipeline_state(pipeline);
442 }
443
444 /// Bind a buffer to a compute kernel argument slot.
445 ///
446 /// The `index` corresponds to the `[[buffer(N)]]` attribute in the MSL shader.
447 pub fn set_buffer(&self, index: u64, buffer: &MlxBuffer) {
448 let _ = (index, buffer);
449 }
450
451 /// Dispatch threads on the GPU.
452 pub fn dispatch_threads(&self, grid_size: MTLSize, threadgroup_size: MTLSize) {
453 let _ = (grid_size, threadgroup_size);
454 }
455
456 /// Encode a complete compute pass: set pipeline, bind buffers, dispatch.
457 ///
458 /// Reuses the persistent compute encoder — no per-dispatch encoder
459 /// creation overhead.
460 ///
461 /// # Arguments
462 ///
463 /// * `pipeline` — The compiled compute pipeline to execute.
464 /// * `buffers` — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
465 /// * `grid_size` — Total number of threads to launch.
466 /// * `threadgroup_size` — Threads per threadgroup.
467 pub fn encode(
468 &mut self,
469 pipeline: &ComputePipelineStateRef,
470 buffers: &[(u64, &MlxBuffer)],
471 grid_size: MTLSize,
472 threadgroup_size: MTLSize,
473 ) {
474 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
475 let op_kind = self.take_pending_op_kind();
476 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
477 if let Some(ref mut nodes) = self.capture {
478 nodes.push(CapturedNode::Dispatch {
479 pipeline: pipeline.to_owned(),
480 bindings: Self::record_buffer_bindings(buffers),
481 threads_per_grid: grid_size,
482 threads_per_threadgroup: threadgroup_size,
483 threadgroup_memory: Vec::new(),
484 dispatch_kind: DispatchKind::Threads,
485 op_kind,
486 reads: pending_reads,
487 writes: pending_writes,
488 });
489 return;
490 }
491 let encoder = self.get_or_create_encoder();
492 encoder.set_compute_pipeline_state(pipeline);
493 for &(index, buf) in buffers {
494 encoder.set_buffer(index, Some(buf.metal_buffer()), 0);
495 }
496 encoder.dispatch_threads(grid_size, threadgroup_size);
497 }
498
499 /// Encode a compute pass using threadgroups instead of raw thread counts.
500 ///
501 /// Reuses the persistent compute encoder — no per-dispatch encoder
502 /// creation overhead.
503 pub fn encode_threadgroups(
504 &mut self,
505 pipeline: &ComputePipelineStateRef,
506 buffers: &[(u64, &MlxBuffer)],
507 threadgroups: MTLSize,
508 threadgroup_size: MTLSize,
509 ) {
510 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
511 let op_kind = self.take_pending_op_kind();
512 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
513 if let Some(ref mut nodes) = self.capture {
514 nodes.push(CapturedNode::Dispatch {
515 pipeline: pipeline.to_owned(),
516 bindings: Self::record_buffer_bindings(buffers),
517 threads_per_grid: threadgroups,
518 threads_per_threadgroup: threadgroup_size,
519 threadgroup_memory: Vec::new(),
520 dispatch_kind: DispatchKind::ThreadGroups,
521 op_kind,
522 reads: pending_reads,
523 writes: pending_writes,
524 });
525 return;
526 }
527 let encoder = self.get_or_create_encoder();
528 encoder.set_compute_pipeline_state(pipeline);
529 for &(index, buf) in buffers {
530 encoder.set_buffer(index, Some(buf.metal_buffer()), 0);
531 }
532 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
533 }
534
535 /// Encode a compute pass using threadgroups with shared threadgroup memory.
536 ///
537 /// Like [`encode_threadgroups`](Self::encode_threadgroups), but additionally
538 /// allocates threadgroup memory at the specified indices. This is required
539 /// for kernels that use `threadgroup` memory (e.g. reductions in rms_norm
540 /// and softmax).
541 ///
542 /// # Arguments
543 ///
544 /// * `pipeline` — The compiled compute pipeline to execute.
545 /// * `buffers` — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
546 /// * `threadgroup_mem` — Slice of `(index, byte_length)` pairs for threadgroup memory.
547 /// * `threadgroups` — Number of threadgroups to dispatch.
548 /// * `threadgroup_size` — Threads per threadgroup.
549 pub fn encode_threadgroups_with_shared(
550 &mut self,
551 pipeline: &ComputePipelineStateRef,
552 buffers: &[(u64, &MlxBuffer)],
553 threadgroup_mem: &[(u64, u64)],
554 threadgroups: MTLSize,
555 threadgroup_size: MTLSize,
556 ) {
557 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
558 let op_kind = self.take_pending_op_kind();
559 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
560 if let Some(ref mut nodes) = self.capture {
561 nodes.push(CapturedNode::Dispatch {
562 pipeline: pipeline.to_owned(),
563 bindings: Self::record_buffer_bindings(buffers),
564 threads_per_grid: threadgroups,
565 threads_per_threadgroup: threadgroup_size,
566 threadgroup_memory: threadgroup_mem.to_vec(),
567 dispatch_kind: DispatchKind::ThreadGroups,
568 op_kind,
569 reads: pending_reads,
570 writes: pending_writes,
571 });
572 return;
573 }
574 let encoder = self.get_or_create_encoder();
575 encoder.set_compute_pipeline_state(pipeline);
576 for &(index, buf) in buffers {
577 encoder.set_buffer(index, Some(buf.metal_buffer()), 0);
578 }
579 for &(index, byte_length) in threadgroup_mem {
580 encoder.set_threadgroup_memory_length(index, byte_length);
581 }
582 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
583 }
584
585 /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_threads).
586 ///
587 /// Reuses the persistent compute encoder.
588 pub fn encode_with_args(
589 &mut self,
590 pipeline: &ComputePipelineStateRef,
591 bindings: &[(u64, KernelArg<'_>)],
592 grid_size: MTLSize,
593 threadgroup_size: MTLSize,
594 ) {
595 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
596 let op_kind = self.take_pending_op_kind();
597 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
598 if let Some(ref mut nodes) = self.capture {
599 nodes.push(CapturedNode::Dispatch {
600 pipeline: pipeline.to_owned(),
601 bindings: Self::record_arg_bindings(bindings),
602 threads_per_grid: grid_size,
603 threads_per_threadgroup: threadgroup_size,
604 threadgroup_memory: Vec::new(),
605 dispatch_kind: DispatchKind::Threads,
606 op_kind,
607 reads: pending_reads,
608 writes: pending_writes,
609 });
610 return;
611 }
612 let encoder = self.get_or_create_encoder();
613 encoder.set_compute_pipeline_state(pipeline);
614 apply_bindings(encoder, bindings);
615 encoder.dispatch_threads(grid_size, threadgroup_size);
616 }
617
618 /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_thread_groups).
619 ///
620 /// Reuses the persistent compute encoder.
621 pub fn encode_threadgroups_with_args(
622 &mut self,
623 pipeline: &ComputePipelineStateRef,
624 bindings: &[(u64, KernelArg<'_>)],
625 threadgroups: MTLSize,
626 threadgroup_size: MTLSize,
627 ) {
628 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
629 let op_kind = self.take_pending_op_kind();
630 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
631 if let Some(ref mut nodes) = self.capture {
632 nodes.push(CapturedNode::Dispatch {
633 pipeline: pipeline.to_owned(),
634 bindings: Self::record_arg_bindings(bindings),
635 threads_per_grid: threadgroups,
636 threads_per_threadgroup: threadgroup_size,
637 threadgroup_memory: Vec::new(),
638 dispatch_kind: DispatchKind::ThreadGroups,
639 op_kind,
640 reads: pending_reads,
641 writes: pending_writes,
642 });
643 return;
644 }
645 let encoder = self.get_or_create_encoder();
646 encoder.set_compute_pipeline_state(pipeline);
647 apply_bindings(encoder, bindings);
648 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
649 }
650
651 /// Encode a dispatch with mixed buffer/bytes bindings and shared memory.
652 ///
653 /// Reuses the persistent compute encoder.
654 pub fn encode_threadgroups_with_args_and_shared(
655 &mut self,
656 pipeline: &ComputePipelineStateRef,
657 bindings: &[(u64, KernelArg<'_>)],
658 threadgroup_mem: &[(u64, u64)],
659 threadgroups: MTLSize,
660 threadgroup_size: MTLSize,
661 ) {
662 DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
663 let op_kind = self.take_pending_op_kind();
664 let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
665 if let Some(ref mut nodes) = self.capture {
666 nodes.push(CapturedNode::Dispatch {
667 pipeline: pipeline.to_owned(),
668 bindings: Self::record_arg_bindings(bindings),
669 threads_per_grid: threadgroups,
670 threads_per_threadgroup: threadgroup_size,
671 threadgroup_memory: threadgroup_mem.to_vec(),
672 dispatch_kind: DispatchKind::ThreadGroups,
673 op_kind,
674 reads: pending_reads,
675 writes: pending_writes,
676 });
677 return;
678 }
679 let encoder = self.get_or_create_encoder();
680 encoder.set_compute_pipeline_state(pipeline);
681 apply_bindings(encoder, bindings);
682 for &(index, byte_length) in threadgroup_mem {
683 encoder.set_threadgroup_memory_length(index, byte_length);
684 }
685 encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
686 }
687
688 /// Replay a single captured dispatch node into this encoder.
689 ///
690 /// This is the inverse of capture: it takes a previously recorded
691 /// `CapturedNode::Dispatch` and encodes it into the live Metal encoder.
692 /// Barrier nodes are handled by the caller (ComputeGraph::encode_sequential).
693 ///
694 /// Does NOT increment `DISPATCH_COUNT` — that was already counted at
695 /// capture time.
696 pub fn replay_dispatch(
697 &mut self,
698 pipeline: &ComputePipelineStateRef,
699 bindings: &[(u64, RecordedBinding)],
700 threadgroup_memory: &[(u64, u64)],
701 threads_per_grid: MTLSize,
702 threads_per_threadgroup: MTLSize,
703 dispatch_kind: DispatchKind,
704 ) {
705 let encoder = self.get_or_create_encoder();
706 encoder.set_compute_pipeline_state(pipeline);
707 for (index, binding) in bindings {
708 match binding {
709 RecordedBinding::Buffer { metal_buffer, offset } => {
710 encoder.set_buffer(*index, Some(metal_buffer), *offset);
711 }
712 RecordedBinding::Bytes(bytes) => {
713 encoder.set_bytes(
714 *index,
715 bytes.len() as u64,
716 bytes.as_ptr() as *const _,
717 );
718 }
719 }
720 }
721 for &(index, byte_length) in threadgroup_memory {
722 encoder.set_threadgroup_memory_length(index, byte_length);
723 }
724 match dispatch_kind {
725 DispatchKind::Threads => {
726 encoder.dispatch_threads(threads_per_grid, threads_per_threadgroup);
727 }
728 DispatchKind::ThreadGroups => {
729 encoder.dispatch_thread_groups(threads_per_grid, threads_per_threadgroup);
730 }
731 }
732 }
733
734 /// Commit the command buffer and block until the GPU finishes execution.
735 ///
736 /// # Errors
737 ///
738 /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
739 pub fn commit_and_wait(&mut self) -> Result<()> {
740 SYNC_COUNT.fetch_add(1, Ordering::Relaxed);
741
742 // End the persistent compute encoder before committing.
743 self.end_active_encoder();
744
745 self.cmd_buf.commit();
746 self.cmd_buf.wait_until_completed();
747
748 match self.cmd_buf.status() {
749 MTLCommandBufferStatus::Completed => Ok(()),
750 MTLCommandBufferStatus::Error => {
751 Err(MlxError::CommandBufferError(
752 "GPU command buffer completed with error status".into(),
753 ))
754 }
755 status => Err(MlxError::CommandBufferError(format!(
756 "Unexpected command buffer status after wait: {:?}",
757 status
758 ))),
759 }
760 }
761
762 /// Commit the command buffer WITHOUT blocking.
763 ///
764 /// The GPU begins executing the encoded commands immediately. Call
765 /// [`wait_until_completed`](Self::wait_until_completed) later to block
766 /// the CPU and check for errors. This allows the CPU to continue doing
767 /// other work (e.g. preparing the next batch) while the GPU runs.
768 pub fn commit(&mut self) {
769 self.end_active_encoder();
770 self.cmd_buf.commit();
771 }
772
773 /// Block until a previously committed command buffer completes.
774 ///
775 /// Must be called after [`commit`](Self::commit). Do not call after
776 /// [`commit_and_wait`](Self::commit_and_wait) — that method already waits.
777 ///
778 /// # Errors
779 ///
780 /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
781 pub fn wait_until_completed(&self) -> Result<()> {
782 self.cmd_buf.wait_until_completed();
783 match self.cmd_buf.status() {
784 MTLCommandBufferStatus::Completed => Ok(()),
785 MTLCommandBufferStatus::Error => Err(MlxError::CommandBufferError(
786 "GPU command buffer completed with error status".into(),
787 )),
788 status => Err(MlxError::CommandBufferError(format!(
789 "Unexpected command buffer status after wait: {:?}",
790 status
791 ))),
792 }
793 }
794
795 /// Borrow the underlying Metal command buffer.
796 #[inline]
797 pub fn metal_command_buffer(&self) -> &CommandBuffer {
798 &self.cmd_buf
799 }
800}
801
impl Drop for CommandEncoder {
    fn drop(&mut self) {
        // End the persistent compute encoder before the command buffer
        // is dropped, otherwise Metal will assert:
        //   "Command encoder released without endEncoding"
        // NOTE(review): commands encoded but never committed are presumably
        // discarded when the command buffer is released — confirm if drop
        // of an uncommitted encoder is ever a silent data-loss path.
        self.end_active_encoder();
    }
}