// mlx_native/encoder.rs
//! [`CommandEncoder`] — batched GPU command submission.
//!
//! Wraps a Metal command buffer.  Encode one or more compute kernel dispatches,
//! then call [`commit_and_wait`](CommandEncoder::commit_and_wait) to submit the
//! entire batch and block until the GPU finishes.
//!
//! # Persistent compute encoder
//!
//! A single Metal `ComputeCommandEncoder` is kept alive across multiple
//! dispatches within the same command buffer.  This avoids the overhead of
//! creating and ending a new compute encoder per dispatch — the same pattern
//! candle uses (`compute_per_buffer`).  On a forward pass with ~800 dispatches
//! this saves ~800 encoder create/end cycles.
//!
//! # Capture mode (Phase 4e.1)
//!
//! When `start_capture()` is called, subsequent dispatches are recorded into a
//! `Vec<CapturedNode>` instead of being encoded into Metal.  `memory_barrier()`
//! records a barrier sentinel.  Call `take_capture()` to extract the recorded
//! graph for later replay via `ComputeGraph::encode_sequential()`.
22use std::sync::atomic::{AtomicU64, Ordering};
23
24use metal::{
25    CommandBuffer, CommandQueue, ComputeCommandEncoderRef, ComputePipelineState,
26    ComputePipelineStateRef, MTLCommandBufferStatus, MTLDispatchType, MTLSize,
27};
28#[allow(unused_imports)]
29use objc::{msg_send, sel, sel_impl};
30
31use crate::buffer::MlxBuffer;
32use crate::error::{MlxError, Result};
33
/// A buffer or inline-bytes binding for a compute kernel argument slot.
///
/// Paired with a slot index (`u64`) in the `encode_*_with_args` methods;
/// the slot corresponds to the `[[buffer(N)]]` attribute in the MSL shader.
pub enum KernelArg<'a> {
    /// Bind an existing Metal buffer at the given index (offset 0).
    Buffer(&'a MlxBuffer),
    /// Bind an existing Metal buffer at the given index with a byte offset.
    BufferWithOffset(&'a MlxBuffer, u64),
    /// Bind inline bytes (small constant data) at the given index.
    /// The data must be `Pod` and is copied into the command encoder
    /// (via `set_bytes`), so the borrow need not outlive the dispatch.
    Bytes(&'a [u8]),
}
44
/// Convert a `Pod` value to a byte slice suitable for `KernelArg::Bytes`.
///
/// This function is safe in the Rust sense — `bytemuck::bytes_of` requires
/// `T: Pod`, so no invalid bytes can be observed — but there is a
/// correctness requirement on the caller:
///
/// # Correctness
///
/// The caller must ensure `T` has the same layout as the corresponding
/// MSL struct in the shader (matching field order, sizes, and alignment).
/// A mismatch produces garbage kernel arguments on the GPU side.
pub fn as_bytes<T: bytemuck::Pod>(val: &T) -> &[u8] {
    bytemuck::bytes_of(val)
}
54
55// ---------------------------------------------------------------------------
56// Capture-mode types (Phase 4e.1 — Graph IR)
57// ---------------------------------------------------------------------------
58
/// A recorded kernel argument binding.
///
/// When the encoder is in capture mode, each `set_buffer` / `set_bytes` call
/// is stored as a `RecordedBinding` instead of being applied to Metal.
/// Replayed later via `CommandEncoder::replay_dispatch`.
#[derive(Clone)]
pub enum RecordedBinding {
    /// A retained Metal buffer plus the byte offset to bind it at.
    Buffer {
        metal_buffer: metal::Buffer,
        offset: u64,
    },
    /// Inline bytes (small constant data, copied at capture time).
    Bytes(Vec<u8>),
}
73
/// How to dispatch the recorded kernel at replay time.
#[derive(Clone, Copy, Debug)]
pub enum DispatchKind {
    /// `dispatch_threads(grid_size, threadgroup_size)` — Metal picks the
    /// threadgroup count from the total thread grid.
    Threads,
    /// `dispatch_thread_groups(threadgroups, threadgroup_size)` — the caller
    /// specifies the threadgroup count explicitly.
    ThreadGroups,
}
82
/// Operation kind tag for captured nodes, used by the fusion pass (4e.2).
///
/// When the encoder is in capture mode, each dispatch can be tagged with an
/// `OpKind` so the fusion pass can identify fuseable sequences without
/// inspecting pipeline names.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CapturedOpKind {
    /// RMS normalization (with learned scale).
    RmsNorm,
    /// Elementwise multiply.
    ElemMul,
    /// Elementwise add.
    ElemAdd,
    /// Scaled dot-product attention (NOT reorderable — breaks lookahead).
    Sdpa,
    /// Softmax (NOT reorderable — breaks lookahead).
    Softmax,
    /// Any other operation — treated as reorderable by the graph optimizer.
    Other,
}

impl CapturedOpKind {
    /// Whether this captured op kind is safe to reorder past in the graph
    /// optimizer (Phase 4e.3).
    ///
    /// Mirrors the `h_safe` whitelist from llama.cpp's
    /// `ggml_metal_graph_optimize_reorder`.  Non-safe ops break the 64-node
    /// lookahead — the reorder pass cannot look past them.
    pub fn is_reorderable(&self) -> bool {
        // Only attention and softmax act as reorder fences; every other
        // kind (including `Other`) may be floated past by the optimizer.
        !matches!(self, Self::Sdpa | Self::Softmax)
    }
}
118
/// A memory range annotation: `(start_address, end_address)`.
///
/// Represents a contiguous GPU buffer region for conflict detection in the
/// reorder pass (Phase 4e.3).  Addresses are CPU-visible `contents_ptr()`
/// values, which on Apple Silicon unified memory equal the GPU addresses.
pub type MemRange = (usize, usize);
125
/// A single captured compute dispatch or barrier sentinel.
///
/// Created when the encoder is in capture mode.  Replayed later by
/// `ComputeGraph::encode_sequential()` (dispatches via
/// `CommandEncoder::replay_dispatch`, barriers via `memory_barrier`).
#[derive(Clone)]
pub enum CapturedNode {
    /// A compute dispatch to replay.
    Dispatch {
        /// Pipeline state object to bind.
        pipeline: ComputePipelineState,
        /// Kernel argument bindings: (slot_index, binding).
        bindings: Vec<(u64, RecordedBinding)>,
        /// Grid size or threadgroup count (interpretation depends on `dispatch_kind`).
        threads_per_grid: MTLSize,
        /// Threads per threadgroup.
        threads_per_threadgroup: MTLSize,
        /// Optional threadgroup memory allocations: (index, byte_length).
        threadgroup_memory: Vec<(u64, u64)>,
        /// Whether this is a dispatch_threads or dispatch_thread_groups call.
        dispatch_kind: DispatchKind,
        /// Operation kind tag for the fusion pass (4e.2).
        /// Defaults to `Other` if not explicitly set via `set_op_kind()`.
        op_kind: CapturedOpKind,
        /// Read buffer ranges for reorder conflict detection (4e.3).
        /// Populated from `barrier_between` calls in capture mode.
        reads: Vec<MemRange>,
        /// Write buffer ranges for reorder conflict detection (4e.3).
        /// Populated from `barrier_between` calls in capture mode.
        writes: Vec<MemRange>,
    },
    /// A memory barrier sentinel — forces a barrier at replay time.
    Barrier,
}
159
160/// Apply a slice of `KernelArg` bindings to a compute encoder.
161#[inline]
162fn apply_bindings(encoder: &ComputeCommandEncoderRef, bindings: &[(u64, KernelArg<'_>)]) {
163    for &(index, ref arg) in bindings {
164        match arg {
165            KernelArg::Buffer(buf) => {
166                encoder.set_buffer(index, Some(buf.metal_buffer()), 0);
167            }
168            KernelArg::BufferWithOffset(buf, offset) => {
169                encoder.set_buffer(index, Some(buf.metal_buffer()), *offset);
170            }
171            KernelArg::Bytes(bytes) => {
172                encoder.set_bytes(index, bytes.len() as u64, bytes.as_ptr() as *const _);
173            }
174        }
175    }
176}
177
/// Number of times `commit_and_wait()` has been called (CPU sync points).
static SYNC_COUNT: AtomicU64 = AtomicU64::new(0);

/// Number of times an encode method has been called (GPU dispatches).
static DISPATCH_COUNT: AtomicU64 = AtomicU64::new(0);

/// Reset both `SYNC_COUNT` and `DISPATCH_COUNT` to zero.
pub fn reset_counters() {
    for counter in [&SYNC_COUNT, &DISPATCH_COUNT] {
        counter.store(0, Ordering::Relaxed);
    }
}

/// Read the current value of `SYNC_COUNT`.
///
/// Each call to `commit_and_wait()` increments this counter.
pub fn sync_count() -> u64 {
    SYNC_COUNT.load(Ordering::Relaxed)
}

/// Read the current value of `DISPATCH_COUNT`.
///
/// Each call to `encode()`, `encode_threadgroups()`,
/// `encode_threadgroups_with_shared()`, or one of the `_with_args`
/// variants increments this counter.
pub fn dispatch_count() -> u64 {
    DISPATCH_COUNT.load(Ordering::Relaxed)
}
204
/// A batched compute command encoder.
///
/// Keeps a single Metal `ComputeCommandEncoder` alive across multiple
/// dispatches.  The encoder is created on the first dispatch and ended
/// only when the command buffer is committed.  This mirrors candle's
/// `compute_per_buffer` pattern and avoids per-dispatch encoder overhead.
///
/// # Typical usage
///
/// ```ignore
/// let mut enc = device.command_encoder()?;
/// // Multiple dispatches share the same compute encoder:
/// enc.encode_threadgroups(pipeline1, &buffers1, tg1, tg_size1);
/// enc.encode_threadgroups(pipeline2, &buffers2, tg2, tg_size2);
/// enc.commit_and_wait()?;
/// ```
pub struct CommandEncoder {
    /// The owned Metal command buffer all dispatches are encoded into.
    cmd_buf: CommandBuffer,
    // SAFETY marker: see unsafe Send impl below.
    /// Raw pointer to the persistent compute encoder.
    /// Null when no compute pass is active; set by `get_or_create_encoder()`
    /// and cleared by `end_active_encoder()`.
    /// The encoder borrows from `cmd_buf` but we cannot express this
    /// lifetime in safe Rust, so we use a raw pointer.
    /// SAFETY: the pointer is valid as long as `cmd_buf` is alive and
    /// `end_encoding()` has not been called on it.
    active_encoder: *const ComputeCommandEncoderRef,
    /// When `Some`, dispatches are recorded here instead of being encoded
    /// into Metal.  Set via `start_capture()`, extracted via `take_capture()`.
    capture: Option<Vec<CapturedNode>>,
    /// Op kind tag for the NEXT captured dispatch.  Set via `set_op_kind()`,
    /// consumed (reset to `Other`) when a dispatch is captured.
    pending_op_kind: CapturedOpKind,
    /// Pending read buffer ranges for the NEXT captured dispatch.
    /// Set via `set_pending_buffer_ranges()`, consumed when the next dispatch
    /// is captured.  Used by the reorder pass (Phase 4e.3).
    pending_reads: Vec<MemRange>,
    /// Pending write buffer ranges for the NEXT captured dispatch.
    /// Same lifecycle as `pending_reads`.
    pending_writes: Vec<MemRange>,
}
244
// SAFETY: CommandEncoder is safe to Send across threads provided that:
// 1. Only one thread accesses the encoder at a time (exclusive ownership,
//    which `&mut self` / ownership transfer already guarantees).
// 2. The encoder is not used concurrently from multiple threads.
//
// Metal command buffers and compute encoders are thread-safe for exclusive
// access (Apple documentation: "You can create command buffers, encode
// commands, and submit them from any thread"). The raw pointer
// `active_encoder` borrows from `cmd_buf` and is valid as long as
// `cmd_buf` is alive — this invariant holds across thread boundaries
// because both fields move together.
//
// This matches llama.cpp's pattern of encoding command buffers on GCD
// worker threads via `dispatch_apply`, and is used for the dual-buffer
// pipeline where buf1 is encoded on a worker thread while buf0 executes.
unsafe impl Send for CommandEncoder {}
260
261impl CommandEncoder {
262    /// Create a new command encoder from the given command queue.
263    ///
264    /// This immediately creates a Metal command buffer.
265    pub(crate) fn new(queue: &CommandQueue) -> Result<Self> {
266        let cmd_buf = queue.new_command_buffer().to_owned();
267        Ok(Self {
268            cmd_buf,
269            active_encoder: std::ptr::null(),
270            capture: None,
271            pending_op_kind: CapturedOpKind::Other,
272            pending_reads: Vec::new(),
273            pending_writes: Vec::new(),
274        })
275    }
276
277    /// Enable capture mode.
278    ///
279    /// All subsequent dispatch and barrier calls will be recorded into a
280    /// `Vec<CapturedNode>` instead of being encoded into Metal.
281    /// Call `take_capture()` to extract the recorded nodes.
282    pub fn start_capture(&mut self) {
283        self.capture = Some(Vec::with_capacity(128));
284    }
285
286    /// Whether the encoder is currently in capture mode.
287    pub fn is_capturing(&self) -> bool {
288        self.capture.is_some()
289    }
290
291    /// Extract the captured nodes, ending capture mode.
292    ///
293    /// Returns `None` if capture mode was not active.
294    pub fn take_capture(&mut self) -> Option<Vec<CapturedNode>> {
295        self.capture.take()
296    }
297
298    /// Tag the NEXT captured dispatch with the given operation kind.
299    ///
300    /// The tag is consumed (reset to `Other`) after the next dispatch is
301    /// captured.  Only meaningful in capture mode — has no effect on
302    /// direct-dispatch encoding.
303    ///
304    /// Used by op dispatch functions to annotate captures for the fusion
305    /// pass (Phase 4e.2).
306    pub fn set_op_kind(&mut self, kind: CapturedOpKind) {
307        self.pending_op_kind = kind;
308    }
309
310    /// Consume and return the pending op kind, resetting it to `Other`.
311    fn take_pending_op_kind(&mut self) -> CapturedOpKind {
312        let kind = self.pending_op_kind;
313        self.pending_op_kind = CapturedOpKind::Other;
314        kind
315    }
316
317    /// Stash buffer range annotations for the NEXT captured dispatch.
318    ///
319    /// Called by `GraphSession::barrier_between()` in capture mode to record
320    /// which buffers the next dispatch reads from and writes to.  The ranges
321    /// are consumed by the next `encode_*` call and attached to the captured
322    /// `CapturedNode::Dispatch`.
323    ///
324    /// Only meaningful in capture mode — has no effect on direct-dispatch.
325    pub fn set_pending_buffer_ranges(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
326        self.pending_reads = reads;
327        self.pending_writes = writes;
328    }
329
330    /// Patch the last captured dispatch node's empty reads/writes with the
331    /// given ranges. No-op if not capturing, or if the last node isn't a
332    /// Dispatch, or if its ranges are already populated.
333    ///
334    /// Used by `GraphSession::track_dispatch` in recording mode to annotate
335    /// dispatches that were called without a preceding `barrier_between`.
336    pub fn annotate_last_dispatch_if_missing(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
337        if let Some(ref mut nodes) = self.capture {
338            if let Some(CapturedNode::Dispatch { reads: r, writes: w, .. }) = nodes.last_mut() {
339                if r.is_empty() && !reads.is_empty() {
340                    *r = reads;
341                }
342                if w.is_empty() && !writes.is_empty() {
343                    *w = writes;
344                }
345            }
346        }
347    }
348
349    /// Consume and return the pending buffer range annotations.
350    fn take_pending_buffer_ranges(&mut self) -> (Vec<MemRange>, Vec<MemRange>) {
351        let reads = std::mem::take(&mut self.pending_reads);
352        let writes = std::mem::take(&mut self.pending_writes);
353        (reads, writes)
354    }
355
356    /// Record buffer bindings into `RecordedBinding` form.
357    fn record_buffer_bindings(buffers: &[(u64, &MlxBuffer)]) -> Vec<(u64, RecordedBinding)> {
358        buffers
359            .iter()
360            .map(|&(index, buf)| {
361                (
362                    index,
363                    RecordedBinding::Buffer {
364                        metal_buffer: buf.metal_buffer().clone(),
365                        offset: 0,
366                    },
367                )
368            })
369            .collect()
370    }
371
372    /// Record `KernelArg` bindings into `RecordedBinding` form.
373    fn record_arg_bindings(bindings: &[(u64, KernelArg<'_>)]) -> Vec<(u64, RecordedBinding)> {
374        bindings
375            .iter()
376            .map(|(index, arg)| {
377                let recorded = match arg {
378                    KernelArg::Buffer(buf) => RecordedBinding::Buffer {
379                        metal_buffer: buf.metal_buffer().clone(),
380                        offset: 0,
381                    },
382                    KernelArg::BufferWithOffset(buf, offset) => RecordedBinding::Buffer {
383                        metal_buffer: buf.metal_buffer().clone(),
384                        offset: *offset,
385                    },
386                    KernelArg::Bytes(bytes) => RecordedBinding::Bytes(bytes.to_vec()),
387                };
388                (*index, recorded)
389            })
390            .collect()
391    }
392
    /// Get or create the persistent compute encoder.
    ///
    /// On the first call, creates a new compute encoder from the command
    /// buffer.  On subsequent calls, returns the existing one.
    ///
    /// SAFETY: The returned reference borrows from `self.cmd_buf` which is
    /// alive for the lifetime of this `CommandEncoder`.  The raw pointer is
    /// valid until `end_active_encoder()` is called.
    #[inline]
    fn get_or_create_encoder(&mut self) -> &ComputeCommandEncoderRef {
        if self.active_encoder.is_null() {
            // Use MTLDispatchTypeConcurrent to allow independent dispatches
            // to overlap on the GPU.  Memory barriers are inserted between
            // dependent dispatches via `memory_barrier()`.
            let encoder = self
                .cmd_buf
                .compute_command_encoder_with_dispatch_type(MTLDispatchType::Concurrent);
            self.active_encoder = encoder as *const ComputeCommandEncoderRef;
        }
        // SAFETY: active_encoder is non-null and points to a valid encoder
        // owned by cmd_buf.
        unsafe { &*self.active_encoder }
    }

    /// End the active compute encoder if one exists, resetting the pointer
    /// to null so the next dispatch starts a fresh compute pass.
    #[inline]
    fn end_active_encoder(&mut self) {
        if !self.active_encoder.is_null() {
            // SAFETY: the pointer was obtained from
            // `compute_command_encoder_with_dispatch_type` in
            // `get_or_create_encoder` and `end_encoding()` has not been
            // called on it yet (we null the pointer immediately after).
            unsafe { &*self.active_encoder }.end_encoding();
            self.active_encoder = std::ptr::null();
        }
    }
427
    /// Insert a memory barrier with scope `MTLBarrierScopeBuffers`.
    ///
    /// When the encoder uses `MTLDispatchTypeConcurrent`, all dispatches can
    /// execute concurrently unless separated by a barrier.  Call this between
    /// dispatches where the later dispatch reads a buffer written by an
    /// earlier one.
    ///
    /// In capture mode, a `CapturedNode::Barrier` sentinel is recorded
    /// instead of touching Metal.  If no compute pass is active yet, the
    /// call is a no-op (there is nothing to order against).
    ///
    /// This is the same pattern llama.cpp uses:
    /// `[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]`
    // NOTE(review): allow(unexpected_cfgs) presumably silences cfg lints
    // from the `objc::msg_send!` expansion — confirm it is still required.
    #[allow(unexpected_cfgs)]
    pub fn memory_barrier(&mut self) {
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Barrier);
            return;
        }
        if self.active_encoder.is_null() {
            return;
        }
        // SAFETY: active_encoder is non-null and valid.
        let encoder = unsafe { &*self.active_encoder };
        // MTLBarrierScopeBuffers = 1 << 0 = 1 (not exposed by the metal crate,
        // so we send the selector manually).
        const MTL_BARRIER_SCOPE_BUFFERS: u64 = 1;
        unsafe {
            let _: () = objc::msg_send![encoder, memoryBarrierWithScope: MTL_BARRIER_SCOPE_BUFFERS];
        }
    }
454
455    /// Set the compute pipeline state for subsequent dispatches.
456    ///
457    /// This begins a new compute pass if one is not already active.
458    pub fn set_pipeline(&mut self, pipeline: &ComputePipelineStateRef) {
459        let encoder = self.get_or_create_encoder();
460        encoder.set_compute_pipeline_state(pipeline);
461    }
462
    /// Bind a buffer to a compute kernel argument slot.
    ///
    /// NOTE(review): this is currently a no-op stub — the arguments are
    /// discarded and nothing is encoded.  The real binding path is inside
    /// the `encode*` methods.  Confirm whether this public API is still
    /// needed; implementing it would require `&mut self` to reach the
    /// persistent encoder.
    ///
    /// The `index` corresponds to the `[[buffer(N)]]` attribute in the MSL shader.
    pub fn set_buffer(&self, index: u64, buffer: &MlxBuffer) {
        let _ = (index, buffer);
    }

    /// Dispatch threads on the GPU.
    ///
    /// NOTE(review): no-op stub — arguments are discarded and no dispatch
    /// is encoded; see `encode()` for the real dispatch path.
    pub fn dispatch_threads(&self, grid_size: MTLSize, threadgroup_size: MTLSize) {
        let _ = (grid_size, threadgroup_size);
    }
474
475    /// Encode a complete compute pass: set pipeline, bind buffers, dispatch.
476    ///
477    /// Reuses the persistent compute encoder — no per-dispatch encoder
478    /// creation overhead.
479    ///
480    /// # Arguments
481    ///
482    /// * `pipeline`         — The compiled compute pipeline to execute.
483    /// * `buffers`          — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
484    /// * `grid_size`        — Total number of threads to launch.
485    /// * `threadgroup_size` — Threads per threadgroup.
486    pub fn encode(
487        &mut self,
488        pipeline: &ComputePipelineStateRef,
489        buffers: &[(u64, &MlxBuffer)],
490        grid_size: MTLSize,
491        threadgroup_size: MTLSize,
492    ) {
493        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
494        let op_kind = self.take_pending_op_kind();
495        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
496        if let Some(ref mut nodes) = self.capture {
497            nodes.push(CapturedNode::Dispatch {
498                pipeline: pipeline.to_owned(),
499                bindings: Self::record_buffer_bindings(buffers),
500                threads_per_grid: grid_size,
501                threads_per_threadgroup: threadgroup_size,
502                threadgroup_memory: Vec::new(),
503                dispatch_kind: DispatchKind::Threads,
504                op_kind,
505                reads: pending_reads,
506                writes: pending_writes,
507            });
508            return;
509        }
510        let encoder = self.get_or_create_encoder();
511        encoder.set_compute_pipeline_state(pipeline);
512        for &(index, buf) in buffers {
513            encoder.set_buffer(index, Some(buf.metal_buffer()), 0);
514        }
515        encoder.dispatch_threads(grid_size, threadgroup_size);
516    }
517
518    /// Encode a compute pass using threadgroups instead of raw thread counts.
519    ///
520    /// Reuses the persistent compute encoder — no per-dispatch encoder
521    /// creation overhead.
522    pub fn encode_threadgroups(
523        &mut self,
524        pipeline: &ComputePipelineStateRef,
525        buffers: &[(u64, &MlxBuffer)],
526        threadgroups: MTLSize,
527        threadgroup_size: MTLSize,
528    ) {
529        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
530        let op_kind = self.take_pending_op_kind();
531        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
532        if let Some(ref mut nodes) = self.capture {
533            nodes.push(CapturedNode::Dispatch {
534                pipeline: pipeline.to_owned(),
535                bindings: Self::record_buffer_bindings(buffers),
536                threads_per_grid: threadgroups,
537                threads_per_threadgroup: threadgroup_size,
538                threadgroup_memory: Vec::new(),
539                dispatch_kind: DispatchKind::ThreadGroups,
540                op_kind,
541                reads: pending_reads,
542                writes: pending_writes,
543            });
544            return;
545        }
546        let encoder = self.get_or_create_encoder();
547        encoder.set_compute_pipeline_state(pipeline);
548        for &(index, buf) in buffers {
549            encoder.set_buffer(index, Some(buf.metal_buffer()), 0);
550        }
551        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
552    }
553
554    /// Encode a compute pass using threadgroups with shared threadgroup memory.
555    ///
556    /// Like [`encode_threadgroups`](Self::encode_threadgroups), but additionally
557    /// allocates threadgroup memory at the specified indices.  This is required
558    /// for kernels that use `threadgroup` memory (e.g. reductions in rms_norm
559    /// and softmax).
560    ///
561    /// # Arguments
562    ///
563    /// * `pipeline`         — The compiled compute pipeline to execute.
564    /// * `buffers`          — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
565    /// * `threadgroup_mem`  — Slice of `(index, byte_length)` pairs for threadgroup memory.
566    /// * `threadgroups`     — Number of threadgroups to dispatch.
567    /// * `threadgroup_size` — Threads per threadgroup.
568    pub fn encode_threadgroups_with_shared(
569        &mut self,
570        pipeline: &ComputePipelineStateRef,
571        buffers: &[(u64, &MlxBuffer)],
572        threadgroup_mem: &[(u64, u64)],
573        threadgroups: MTLSize,
574        threadgroup_size: MTLSize,
575    ) {
576        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
577        let op_kind = self.take_pending_op_kind();
578        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
579        if let Some(ref mut nodes) = self.capture {
580            nodes.push(CapturedNode::Dispatch {
581                pipeline: pipeline.to_owned(),
582                bindings: Self::record_buffer_bindings(buffers),
583                threads_per_grid: threadgroups,
584                threads_per_threadgroup: threadgroup_size,
585                threadgroup_memory: threadgroup_mem.to_vec(),
586                dispatch_kind: DispatchKind::ThreadGroups,
587                op_kind,
588                reads: pending_reads,
589                writes: pending_writes,
590            });
591            return;
592        }
593        let encoder = self.get_or_create_encoder();
594        encoder.set_compute_pipeline_state(pipeline);
595        for &(index, buf) in buffers {
596            encoder.set_buffer(index, Some(buf.metal_buffer()), 0);
597        }
598        for &(index, byte_length) in threadgroup_mem {
599            encoder.set_threadgroup_memory_length(index, byte_length);
600        }
601        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
602    }
603
604    /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_threads).
605    ///
606    /// Reuses the persistent compute encoder.
607    pub fn encode_with_args(
608        &mut self,
609        pipeline: &ComputePipelineStateRef,
610        bindings: &[(u64, KernelArg<'_>)],
611        grid_size: MTLSize,
612        threadgroup_size: MTLSize,
613    ) {
614        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
615        let op_kind = self.take_pending_op_kind();
616        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
617        if let Some(ref mut nodes) = self.capture {
618            nodes.push(CapturedNode::Dispatch {
619                pipeline: pipeline.to_owned(),
620                bindings: Self::record_arg_bindings(bindings),
621                threads_per_grid: grid_size,
622                threads_per_threadgroup: threadgroup_size,
623                threadgroup_memory: Vec::new(),
624                dispatch_kind: DispatchKind::Threads,
625                op_kind,
626                reads: pending_reads,
627                writes: pending_writes,
628            });
629            return;
630        }
631        let encoder = self.get_or_create_encoder();
632        encoder.set_compute_pipeline_state(pipeline);
633        apply_bindings(encoder, bindings);
634        encoder.dispatch_threads(grid_size, threadgroup_size);
635    }
636
637    /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_thread_groups).
638    ///
639    /// Reuses the persistent compute encoder.
640    pub fn encode_threadgroups_with_args(
641        &mut self,
642        pipeline: &ComputePipelineStateRef,
643        bindings: &[(u64, KernelArg<'_>)],
644        threadgroups: MTLSize,
645        threadgroup_size: MTLSize,
646    ) {
647        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
648        let op_kind = self.take_pending_op_kind();
649        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
650        if let Some(ref mut nodes) = self.capture {
651            nodes.push(CapturedNode::Dispatch {
652                pipeline: pipeline.to_owned(),
653                bindings: Self::record_arg_bindings(bindings),
654                threads_per_grid: threadgroups,
655                threads_per_threadgroup: threadgroup_size,
656                threadgroup_memory: Vec::new(),
657                dispatch_kind: DispatchKind::ThreadGroups,
658                op_kind,
659                reads: pending_reads,
660                writes: pending_writes,
661            });
662            return;
663        }
664        let encoder = self.get_or_create_encoder();
665        encoder.set_compute_pipeline_state(pipeline);
666        apply_bindings(encoder, bindings);
667        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
668    }
669
670    /// Encode a dispatch with mixed buffer/bytes bindings and shared memory.
671    ///
672    /// Reuses the persistent compute encoder.
673    pub fn encode_threadgroups_with_args_and_shared(
674        &mut self,
675        pipeline: &ComputePipelineStateRef,
676        bindings: &[(u64, KernelArg<'_>)],
677        threadgroup_mem: &[(u64, u64)],
678        threadgroups: MTLSize,
679        threadgroup_size: MTLSize,
680    ) {
681        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
682        let op_kind = self.take_pending_op_kind();
683        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
684        if let Some(ref mut nodes) = self.capture {
685            nodes.push(CapturedNode::Dispatch {
686                pipeline: pipeline.to_owned(),
687                bindings: Self::record_arg_bindings(bindings),
688                threads_per_grid: threadgroups,
689                threads_per_threadgroup: threadgroup_size,
690                threadgroup_memory: threadgroup_mem.to_vec(),
691                dispatch_kind: DispatchKind::ThreadGroups,
692                op_kind,
693                reads: pending_reads,
694                writes: pending_writes,
695            });
696            return;
697        }
698        let encoder = self.get_or_create_encoder();
699        encoder.set_compute_pipeline_state(pipeline);
700        apply_bindings(encoder, bindings);
701        for &(index, byte_length) in threadgroup_mem {
702            encoder.set_threadgroup_memory_length(index, byte_length);
703        }
704        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
705    }
706
707    /// Replay a single captured dispatch node into this encoder.
708    ///
709    /// This is the inverse of capture: it takes a previously recorded
710    /// `CapturedNode::Dispatch` and encodes it into the live Metal encoder.
711    /// Barrier nodes are handled by the caller (ComputeGraph::encode_sequential).
712    ///
713    /// Does NOT increment `DISPATCH_COUNT` — that was already counted at
714    /// capture time.
715    pub fn replay_dispatch(
716        &mut self,
717        pipeline: &ComputePipelineStateRef,
718        bindings: &[(u64, RecordedBinding)],
719        threadgroup_memory: &[(u64, u64)],
720        threads_per_grid: MTLSize,
721        threads_per_threadgroup: MTLSize,
722        dispatch_kind: DispatchKind,
723    ) {
724        let encoder = self.get_or_create_encoder();
725        encoder.set_compute_pipeline_state(pipeline);
726        for (index, binding) in bindings {
727            match binding {
728                RecordedBinding::Buffer { metal_buffer, offset } => {
729                    encoder.set_buffer(*index, Some(metal_buffer), *offset);
730                }
731                RecordedBinding::Bytes(bytes) => {
732                    encoder.set_bytes(
733                        *index,
734                        bytes.len() as u64,
735                        bytes.as_ptr() as *const _,
736                    );
737                }
738            }
739        }
740        for &(index, byte_length) in threadgroup_memory {
741            encoder.set_threadgroup_memory_length(index, byte_length);
742        }
743        match dispatch_kind {
744            DispatchKind::Threads => {
745                encoder.dispatch_threads(threads_per_grid, threads_per_threadgroup);
746            }
747            DispatchKind::ThreadGroups => {
748                encoder.dispatch_thread_groups(threads_per_grid, threads_per_threadgroup);
749            }
750        }
751    }
752
753    /// Commit the command buffer and block until the GPU finishes execution.
754    ///
755    /// # Errors
756    ///
757    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
758    pub fn commit_and_wait(&mut self) -> Result<()> {
759        SYNC_COUNT.fetch_add(1, Ordering::Relaxed);
760
761        // End the persistent compute encoder before committing.
762        self.end_active_encoder();
763
764        self.cmd_buf.commit();
765        self.cmd_buf.wait_until_completed();
766
767        match self.cmd_buf.status() {
768            MTLCommandBufferStatus::Completed => Ok(()),
769            MTLCommandBufferStatus::Error => {
770                Err(MlxError::CommandBufferError(
771                    "GPU command buffer completed with error status".into(),
772                ))
773            }
774            status => Err(MlxError::CommandBufferError(format!(
775                "Unexpected command buffer status after wait: {:?}",
776                status
777            ))),
778        }
779    }
780
    /// Commit the command buffer WITHOUT blocking.
    ///
    /// The GPU begins executing the encoded commands immediately.  Call
    /// [`wait_until_completed`](Self::wait_until_completed) later to block
    /// the CPU and check for errors.  This allows the CPU to continue doing
    /// other work (e.g. preparing the next batch) while the GPU runs.
    ///
    /// Note: unlike [`commit_and_wait`](Self::commit_and_wait), this does
    /// NOT increment `SYNC_COUNT` — no CPU sync point occurs here.
    pub fn commit(&mut self) {
        // The compute encoder must be ended before commit.
        self.end_active_encoder();
        self.cmd_buf.commit();
    }
791
792    /// Block until a previously committed command buffer completes.
793    ///
794    /// Must be called after [`commit`](Self::commit).  Do not call after
795    /// [`commit_and_wait`](Self::commit_and_wait) — that method already waits.
796    ///
797    /// # Errors
798    ///
799    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
800    pub fn wait_until_completed(&self) -> Result<()> {
801        self.cmd_buf.wait_until_completed();
802        match self.cmd_buf.status() {
803            MTLCommandBufferStatus::Completed => Ok(()),
804            MTLCommandBufferStatus::Error => Err(MlxError::CommandBufferError(
805                "GPU command buffer completed with error status".into(),
806            )),
807            status => Err(MlxError::CommandBufferError(format!(
808                "Unexpected command buffer status after wait: {:?}",
809                status
810            ))),
811        }
812    }
813
    /// Borrow the underlying Metal command buffer.
    ///
    /// Useful for interop (e.g. registering completion handlers or
    /// inspecting status) without giving up ownership.
    #[inline]
    pub fn metal_command_buffer(&self) -> &CommandBuffer {
        &self.cmd_buf
    }
819}
820
impl Drop for CommandEncoder {
    /// Ensure the persistent compute encoder is properly ended.
    fn drop(&mut self) {
        // End the persistent compute encoder before the command buffer
        // is dropped, otherwise Metal will assert:
        // "Command encoder released without endEncoding"
        // (`end_active_encoder` is a no-op if no pass is active.)
        self.end_active_encoder();
    }
}