// mlx_native/encoder.rs

1//! [`CommandEncoder`] — batched GPU command submission.
2//!
3//! Wraps a Metal command buffer.  Encode one or more compute kernel dispatches,
4//! then call [`commit_and_wait`](CommandEncoder::commit_and_wait) to submit the
5//! entire batch and block until the GPU finishes.
6//!
7//! # Persistent compute encoder
8//!
9//! A single Metal `ComputeCommandEncoder` is kept alive across multiple
10//! dispatches within the same command buffer.  This avoids the overhead of
11//! creating and ending a new compute encoder per dispatch — the same pattern
12//! candle uses (`compute_per_buffer`).  On a forward pass with ~800 dispatches
13//! this saves ~800 encoder create/end cycles.
14//!
15//! # Capture mode (Phase 4e.1)
16//!
17//! When `start_capture()` is called, subsequent dispatches are recorded into a
18//! `Vec<CapturedNode>` instead of being encoded into Metal.  `memory_barrier()`
19//! records a barrier sentinel.  Call `take_capture()` to extract the recorded
20//! graph for later replay via `ComputeGraph::encode_sequential()`.
21
22use std::sync::atomic::{AtomicU64, Ordering};
23
24use metal::{
25    CommandBuffer, CommandQueue, ComputeCommandEncoderRef, ComputePipelineState,
26    ComputePipelineStateRef, MTLCommandBufferStatus, MTLDispatchType, MTLSize,
27};
28#[allow(unused_imports)]
29use objc::{msg_send, sel, sel_impl};
30
31use crate::buffer::MlxBuffer;
32use crate::error::{MlxError, Result};
33
/// A buffer or inline-bytes binding for a compute kernel argument slot.
///
/// Paired with a slot index (`u64`) by the `encode_*_with_args` methods;
/// the index corresponds to the `[[buffer(N)]]` attribute in the MSL shader.
pub enum KernelArg<'a> {
    /// Bind an existing Metal buffer at the given index (offset 0).
    Buffer(&'a MlxBuffer),
    /// Bind an existing Metal buffer at the given index with a byte offset.
    BufferWithOffset(&'a MlxBuffer, u64),
    /// Bind inline bytes (small constant data) at the given index.
    /// The data must be `Pod` and is copied into the command encoder.
    Bytes(&'a [u8]),
}
44
45/// Convert a `Pod` value to a byte slice suitable for `KernelArg::Bytes`.
46///
47/// # Safety
48///
49/// The caller must ensure `T` has the same layout as the corresponding
50/// MSL struct in the shader (matching field order, sizes, and alignment).
51pub fn as_bytes<T: bytemuck::Pod>(val: &T) -> &[u8] {
52    bytemuck::bytes_of(val)
53}
54
55// ---------------------------------------------------------------------------
56// Capture-mode types (Phase 4e.1 — Graph IR)
57// ---------------------------------------------------------------------------
58
/// A recorded kernel argument binding.
///
/// When the encoder is in capture mode, each `set_buffer` / `set_bytes` call
/// is stored as a `RecordedBinding` instead of being applied to Metal.
// NOTE(review): cloning `metal::Buffer` presumably retains the underlying
// Objective-C object so the capture keeps the buffer alive until replay —
// confirm against the `metal` crate's retain semantics.
#[derive(Clone)]
pub enum RecordedBinding {
    /// A Metal buffer at the given offset.
    Buffer {
        metal_buffer: metal::Buffer,
        offset: u64,
    },
    /// Inline bytes (small constant data, copied).
    Bytes(Vec<u8>),
}
73
/// How to dispatch the recorded kernel.
///
/// Determines how `CapturedNode::Dispatch::threads_per_grid` is interpreted
/// at replay time: raw thread count vs. threadgroup count.
#[derive(Clone, Copy, Debug)]
pub enum DispatchKind {
    /// `dispatch_threads(grid_size, threadgroup_size)` — Metal picks threadgroup count.
    Threads,
    /// `dispatch_thread_groups(threadgroups, threadgroup_size)` — caller specifies threadgroup count.
    ThreadGroups,
}
82
/// Operation kind tag for captured nodes, used by the fusion pass (4e.2).
///
/// When the encoder is in capture mode, each dispatch can be tagged with an
/// `OpKind` so the fusion pass can identify fuseable sequences without
/// inspecting pipeline names.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CapturedOpKind {
    /// RMS normalization (with learned scale).
    RmsNorm,
    /// Elementwise multiply.
    ElemMul,
    /// Elementwise add.
    ElemAdd,
    /// Scaled dot-product attention (NOT reorderable — breaks lookahead).
    Sdpa,
    /// Softmax (NOT reorderable — breaks lookahead).
    Softmax,
    /// Any other operation — treated as reorderable by the graph optimizer.
    Other,
}

impl CapturedOpKind {
    /// Whether this captured op kind is safe to reorder past in the graph
    /// optimizer (Phase 4e.3).
    ///
    /// Mirrors the `h_safe` whitelist from llama.cpp's
    /// `ggml_metal_graph_optimize_reorder`.  Non-safe ops break the 64-node
    /// lookahead — the reorder pass cannot look past them.
    pub fn is_reorderable(&self) -> bool {
        // Only attention and softmax act as reorder fences; everything else
        // (including `Other`) may be moved past by the optimizer.
        !matches!(self, Self::Sdpa | Self::Softmax)
    }
}
118
/// A memory range annotation: (start_address, end_address).
///
/// Represents a contiguous GPU buffer region for conflict detection in the
/// reorder pass (Phase 4e.3).  Addresses are CPU-visible `contents_ptr()`
/// values, which on Apple Silicon unified memory equal the GPU addresses.
// NOTE(review): whether `end_address` is inclusive or one-past-the-end is
// not visible here — confirm against the overlap check in the reorder pass.
pub type MemRange = (usize, usize);
125
/// A single captured compute dispatch or barrier sentinel.
///
/// Created when the encoder is in capture mode.  Replayed later by
/// `ComputeGraph::encode_sequential()`.
#[derive(Clone)]
pub enum CapturedNode {
    /// A compute dispatch to replay.
    Dispatch {
        /// Pipeline state object to bind.
        pipeline: ComputePipelineState,
        /// Kernel argument bindings: (slot_index, binding).
        bindings: Vec<(u64, RecordedBinding)>,
        /// Grid or threadgroup count (interpretation depends on `dispatch_kind`:
        /// raw thread count for `Threads`, threadgroup count for `ThreadGroups`).
        threads_per_grid: MTLSize,
        /// Threads per threadgroup.
        threads_per_threadgroup: MTLSize,
        /// Optional threadgroup memory allocations: (index, byte_length).
        threadgroup_memory: Vec<(u64, u64)>,
        /// Whether this is a dispatch_threads or dispatch_thread_groups call.
        dispatch_kind: DispatchKind,
        /// Operation kind tag for the fusion pass (4e.2).
        /// Defaults to `Other` if not explicitly set via `set_op_kind()`.
        op_kind: CapturedOpKind,
        /// Read buffer ranges for reorder conflict detection (4e.3).
        /// Populated from `barrier_between` calls in capture mode.
        reads: Vec<MemRange>,
        /// Write buffer ranges for reorder conflict detection (4e.3).
        /// Populated from `barrier_between` calls in capture mode.
        writes: Vec<MemRange>,
    },
    /// A memory barrier sentinel — forces a barrier at replay time.
    Barrier,
}
159
160/// Apply a slice of `KernelArg` bindings to a compute encoder.
161#[inline]
162fn apply_bindings(encoder: &ComputeCommandEncoderRef, bindings: &[(u64, KernelArg<'_>)]) {
163    for &(index, ref arg) in bindings {
164        match arg {
165            KernelArg::Buffer(buf) => {
166                encoder.set_buffer(index, Some(buf.metal_buffer()), 0);
167            }
168            KernelArg::BufferWithOffset(buf, offset) => {
169                encoder.set_buffer(index, Some(buf.metal_buffer()), *offset);
170            }
171            KernelArg::Bytes(bytes) => {
172                encoder.set_bytes(index, bytes.len() as u64, bytes.as_ptr() as *const _);
173            }
174        }
175    }
176}
177
/// Number of times `commit_and_wait()` has been called (CPU sync points).
static SYNC_COUNT: AtomicU64 = AtomicU64::new(0);

/// Number of times an encode method has been called (GPU dispatches).
static DISPATCH_COUNT: AtomicU64 = AtomicU64::new(0);

/// Reset both `SYNC_COUNT` and `DISPATCH_COUNT` to zero.
///
/// Relaxed ordering suffices: these are diagnostics, not synchronization.
pub fn reset_counters() {
    for counter in [&SYNC_COUNT, &DISPATCH_COUNT] {
        counter.store(0, Ordering::Relaxed);
    }
}

/// Read the current value of `SYNC_COUNT`.
///
/// Each call to `commit_and_wait()` increments this counter.
pub fn sync_count() -> u64 {
    SYNC_COUNT.load(Ordering::Relaxed)
}

/// Read the current value of `DISPATCH_COUNT`.
///
/// Each call to `encode()`, `encode_threadgroups()`, or
/// `encode_threadgroups_with_shared()` increments this counter.
pub fn dispatch_count() -> u64 {
    DISPATCH_COUNT.load(Ordering::Relaxed)
}
204
/// A batched compute command encoder.
///
/// Keeps a single Metal `ComputeCommandEncoder` alive across multiple
/// dispatches.  The encoder is created on the first dispatch and ended
/// only when the command buffer is committed.  This mirrors candle's
/// `compute_per_buffer` pattern and avoids per-dispatch encoder overhead.
///
/// # Typical usage
///
/// ```ignore
/// let mut enc = device.command_encoder()?;
/// // Multiple dispatches share the same compute encoder:
/// enc.encode_threadgroups(pipeline1, &buffers1, tg1, tg_size1);
/// enc.encode_threadgroups(pipeline2, &buffers2, tg2, tg_size2);
/// enc.commit_and_wait()?;
/// ```
pub struct CommandEncoder {
    /// The Metal command buffer all dispatches are encoded into.
    cmd_buf: CommandBuffer,
    // SAFETY marker: see unsafe Send impl below.
    /// Raw pointer to the persistent compute encoder.
    /// Non-null when a compute pass is active.
    /// The encoder borrows from `cmd_buf` but we cannot express this
    /// lifetime in safe Rust, so we use a raw pointer.
    /// SAFETY: the pointer is valid as long as `cmd_buf` is alive and
    /// `end_encoding()` has not been called on it.
    active_encoder: *const ComputeCommandEncoderRef,
    /// When `Some`, dispatches are recorded here instead of being encoded
    /// into Metal.  Set via `start_capture()`, extracted via `take_capture()`.
    capture: Option<Vec<CapturedNode>>,
    /// Op kind tag for the NEXT captured dispatch.  Set via `set_op_kind()`,
    /// consumed (reset to `Other`) when a dispatch is captured.
    pending_op_kind: CapturedOpKind,
    /// Pending read buffer ranges for the NEXT captured dispatch.
    /// Set via `set_pending_buffer_ranges()`, consumed when the next dispatch
    /// is captured.  Used by the reorder pass (Phase 4e.3).
    pending_reads: Vec<MemRange>,
    /// Pending write buffer ranges for the NEXT captured dispatch.
    pending_writes: Vec<MemRange>,
}
244
/// SAFETY: CommandEncoder is safe to Send across threads provided that:
/// 1. Only one thread accesses the encoder at a time (exclusive ownership).
/// 2. The encoder is not used concurrently from multiple threads.
///
/// Metal command buffers and compute encoders are thread-safe for exclusive
/// access (Apple documentation: "You can create command buffers, encode
/// commands, and submit them from any thread"). The raw pointer
/// `active_encoder` borrows from `cmd_buf` and is valid as long as
/// `cmd_buf` is alive — this invariant holds across thread boundaries
/// because both fields move together.
///
/// This matches llama.cpp's pattern of encoding command buffers on GCD
/// worker threads via `dispatch_apply`, and is used for the dual-buffer
/// pipeline where buf1 is encoded on a worker thread while buf0 executes.
///
/// Note: `Sync` is deliberately NOT implemented — shared references from
/// multiple threads would violate the exclusive-access requirement above.
unsafe impl Send for CommandEncoder {}
260
261impl CommandEncoder {
262    /// Create a new command encoder from the given command queue.
263    ///
264    /// This immediately creates a Metal command buffer.
265    pub(crate) fn new(queue: &CommandQueue) -> Result<Self> {
266        let cmd_buf = queue.new_command_buffer().to_owned();
267        Ok(Self {
268            cmd_buf,
269            active_encoder: std::ptr::null(),
270            capture: None,
271            pending_op_kind: CapturedOpKind::Other,
272            pending_reads: Vec::new(),
273            pending_writes: Vec::new(),
274        })
275    }
276
277    /// Enable capture mode.
278    ///
279    /// All subsequent dispatch and barrier calls will be recorded into a
280    /// `Vec<CapturedNode>` instead of being encoded into Metal.
281    /// Call `take_capture()` to extract the recorded nodes.
282    pub fn start_capture(&mut self) {
283        self.capture = Some(Vec::with_capacity(128));
284    }
285
286    /// Whether the encoder is currently in capture mode.
287    pub fn is_capturing(&self) -> bool {
288        self.capture.is_some()
289    }
290
291    /// Extract the captured nodes, ending capture mode.
292    ///
293    /// Returns `None` if capture mode was not active.
294    pub fn take_capture(&mut self) -> Option<Vec<CapturedNode>> {
295        self.capture.take()
296    }
297
298    /// Tag the NEXT captured dispatch with the given operation kind.
299    ///
300    /// The tag is consumed (reset to `Other`) after the next dispatch is
301    /// captured.  Only meaningful in capture mode — has no effect on
302    /// direct-dispatch encoding.
303    ///
304    /// Used by op dispatch functions to annotate captures for the fusion
305    /// pass (Phase 4e.2).
306    pub fn set_op_kind(&mut self, kind: CapturedOpKind) {
307        self.pending_op_kind = kind;
308    }
309
310    /// Consume and return the pending op kind, resetting it to `Other`.
311    fn take_pending_op_kind(&mut self) -> CapturedOpKind {
312        let kind = self.pending_op_kind;
313        self.pending_op_kind = CapturedOpKind::Other;
314        kind
315    }
316
317    /// Stash buffer range annotations for the NEXT captured dispatch.
318    ///
319    /// Called by `GraphSession::barrier_between()` in capture mode to record
320    /// which buffers the next dispatch reads from and writes to.  The ranges
321    /// are consumed by the next `encode_*` call and attached to the captured
322    /// `CapturedNode::Dispatch`.
323    ///
324    /// Only meaningful in capture mode — has no effect on direct-dispatch.
325    pub fn set_pending_buffer_ranges(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
326        self.pending_reads = reads;
327        self.pending_writes = writes;
328    }
329
330    /// Consume and return the pending buffer range annotations.
331    fn take_pending_buffer_ranges(&mut self) -> (Vec<MemRange>, Vec<MemRange>) {
332        let reads = std::mem::take(&mut self.pending_reads);
333        let writes = std::mem::take(&mut self.pending_writes);
334        (reads, writes)
335    }
336
337    /// Record buffer bindings into `RecordedBinding` form.
338    fn record_buffer_bindings(buffers: &[(u64, &MlxBuffer)]) -> Vec<(u64, RecordedBinding)> {
339        buffers
340            .iter()
341            .map(|&(index, buf)| {
342                (
343                    index,
344                    RecordedBinding::Buffer {
345                        metal_buffer: buf.metal_buffer().clone(),
346                        offset: 0,
347                    },
348                )
349            })
350            .collect()
351    }
352
353    /// Record `KernelArg` bindings into `RecordedBinding` form.
354    fn record_arg_bindings(bindings: &[(u64, KernelArg<'_>)]) -> Vec<(u64, RecordedBinding)> {
355        bindings
356            .iter()
357            .map(|(index, arg)| {
358                let recorded = match arg {
359                    KernelArg::Buffer(buf) => RecordedBinding::Buffer {
360                        metal_buffer: buf.metal_buffer().clone(),
361                        offset: 0,
362                    },
363                    KernelArg::BufferWithOffset(buf, offset) => RecordedBinding::Buffer {
364                        metal_buffer: buf.metal_buffer().clone(),
365                        offset: *offset,
366                    },
367                    KernelArg::Bytes(bytes) => RecordedBinding::Bytes(bytes.to_vec()),
368                };
369                (*index, recorded)
370            })
371            .collect()
372    }
373
    /// Get or create the persistent compute encoder.
    ///
    /// On the first call, creates a new compute encoder from the command
    /// buffer.  On subsequent calls, returns the existing one.
    ///
    /// SAFETY: The returned reference borrows from `self.cmd_buf` which is
    /// alive for the lifetime of this `CommandEncoder`.  The raw pointer is
    /// valid until `end_active_encoder()` is called.
    #[inline]
    fn get_or_create_encoder(&mut self) -> &ComputeCommandEncoderRef {
        if self.active_encoder.is_null() {
            // Use MTLDispatchTypeConcurrent to allow independent dispatches
            // to overlap on the GPU.  Memory barriers are inserted between
            // dependent dispatches via `memory_barrier()`.
            let encoder = self
                .cmd_buf
                .compute_command_encoder_with_dispatch_type(MTLDispatchType::Concurrent);
            // Stash as a raw pointer to sidestep the self-referential borrow.
            self.active_encoder = encoder as *const ComputeCommandEncoderRef;
        }
        // SAFETY: active_encoder is non-null and points to a valid encoder
        // owned by cmd_buf.
        unsafe { &*self.active_encoder }
    }
397
    /// End the active compute encoder if one exists.
    ///
    /// Idempotent: the pointer is nulled afterwards, so repeated calls
    /// (e.g. from both `commit` and `Drop`) are safe.
    #[inline]
    fn end_active_encoder(&mut self) {
        if !self.active_encoder.is_null() {
            // SAFETY: the pointer was obtained from cmd_buf.new_compute_command_encoder()
            // and has not been ended yet.
            unsafe { &*self.active_encoder }.end_encoding();
            self.active_encoder = std::ptr::null();
        }
    }
408
    /// Insert a memory barrier with scope `MTLBarrierScopeBuffers`.
    ///
    /// When the encoder uses `MTLDispatchTypeConcurrent`, all dispatches can
    /// execute concurrently unless separated by a barrier.  Call this between
    /// dispatches where the later dispatch reads a buffer written by an
    /// earlier one.
    ///
    /// In capture mode a `CapturedNode::Barrier` sentinel is recorded instead.
    /// If no compute pass is active yet, the call is a no-op — there is
    /// nothing to order against.
    ///
    /// This is the same pattern llama.cpp uses:
    /// `[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]`
    #[allow(unexpected_cfgs)]
    pub fn memory_barrier(&mut self) {
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Barrier);
            return;
        }
        if self.active_encoder.is_null() {
            return;
        }
        // SAFETY: active_encoder is non-null and valid.
        let encoder = unsafe { &*self.active_encoder };
        // MTLBarrierScopeBuffers = 1 << 0 = 1
        const MTL_BARRIER_SCOPE_BUFFERS: u64 = 1;
        // Raw objc call: the `metal` crate does not expose memoryBarrierWithScope.
        unsafe {
            let _: () = objc::msg_send![encoder, memoryBarrierWithScope: MTL_BARRIER_SCOPE_BUFFERS];
        }
    }
435
436    /// Set the compute pipeline state for subsequent dispatches.
437    ///
438    /// This begins a new compute pass if one is not already active.
439    pub fn set_pipeline(&mut self, pipeline: &ComputePipelineStateRef) {
440        let encoder = self.get_or_create_encoder();
441        encoder.set_compute_pipeline_state(pipeline);
442    }
443
    /// Bind a buffer to a compute kernel argument slot.
    ///
    /// The `index` corresponds to the `[[buffer(N)]]` attribute in the MSL shader.
    ///
    /// NOTE(review): this method is currently a NO-OP — both arguments are
    /// discarded and nothing is encoded.  It looks like a leftover stub
    /// (it takes `&self`, so it cannot reach the persistent encoder which
    /// requires `&mut self`).  Use the `encode*` methods instead; confirm
    /// before relying on this, or remove it.
    pub fn set_buffer(&self, index: u64, buffer: &MlxBuffer) {
        let _ = (index, buffer);
    }

    /// Dispatch threads on the GPU.
    ///
    /// NOTE(review): this method is currently a NO-OP — both arguments are
    /// discarded and nothing is dispatched.  Use `encode()` /
    /// `encode_with_args()` instead; confirm before relying on this, or
    /// remove it.
    pub fn dispatch_threads(&self, grid_size: MTLSize, threadgroup_size: MTLSize) {
        let _ = (grid_size, threadgroup_size);
    }
455
456    /// Encode a complete compute pass: set pipeline, bind buffers, dispatch.
457    ///
458    /// Reuses the persistent compute encoder — no per-dispatch encoder
459    /// creation overhead.
460    ///
461    /// # Arguments
462    ///
463    /// * `pipeline`         — The compiled compute pipeline to execute.
464    /// * `buffers`          — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
465    /// * `grid_size`        — Total number of threads to launch.
466    /// * `threadgroup_size` — Threads per threadgroup.
467    pub fn encode(
468        &mut self,
469        pipeline: &ComputePipelineStateRef,
470        buffers: &[(u64, &MlxBuffer)],
471        grid_size: MTLSize,
472        threadgroup_size: MTLSize,
473    ) {
474        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
475        let op_kind = self.take_pending_op_kind();
476        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
477        if let Some(ref mut nodes) = self.capture {
478            nodes.push(CapturedNode::Dispatch {
479                pipeline: pipeline.to_owned(),
480                bindings: Self::record_buffer_bindings(buffers),
481                threads_per_grid: grid_size,
482                threads_per_threadgroup: threadgroup_size,
483                threadgroup_memory: Vec::new(),
484                dispatch_kind: DispatchKind::Threads,
485                op_kind,
486                reads: pending_reads,
487                writes: pending_writes,
488            });
489            return;
490        }
491        let encoder = self.get_or_create_encoder();
492        encoder.set_compute_pipeline_state(pipeline);
493        for &(index, buf) in buffers {
494            encoder.set_buffer(index, Some(buf.metal_buffer()), 0);
495        }
496        encoder.dispatch_threads(grid_size, threadgroup_size);
497    }
498
499    /// Encode a compute pass using threadgroups instead of raw thread counts.
500    ///
501    /// Reuses the persistent compute encoder — no per-dispatch encoder
502    /// creation overhead.
503    pub fn encode_threadgroups(
504        &mut self,
505        pipeline: &ComputePipelineStateRef,
506        buffers: &[(u64, &MlxBuffer)],
507        threadgroups: MTLSize,
508        threadgroup_size: MTLSize,
509    ) {
510        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
511        let op_kind = self.take_pending_op_kind();
512        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
513        if let Some(ref mut nodes) = self.capture {
514            nodes.push(CapturedNode::Dispatch {
515                pipeline: pipeline.to_owned(),
516                bindings: Self::record_buffer_bindings(buffers),
517                threads_per_grid: threadgroups,
518                threads_per_threadgroup: threadgroup_size,
519                threadgroup_memory: Vec::new(),
520                dispatch_kind: DispatchKind::ThreadGroups,
521                op_kind,
522                reads: pending_reads,
523                writes: pending_writes,
524            });
525            return;
526        }
527        let encoder = self.get_or_create_encoder();
528        encoder.set_compute_pipeline_state(pipeline);
529        for &(index, buf) in buffers {
530            encoder.set_buffer(index, Some(buf.metal_buffer()), 0);
531        }
532        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
533    }
534
535    /// Encode a compute pass using threadgroups with shared threadgroup memory.
536    ///
537    /// Like [`encode_threadgroups`](Self::encode_threadgroups), but additionally
538    /// allocates threadgroup memory at the specified indices.  This is required
539    /// for kernels that use `threadgroup` memory (e.g. reductions in rms_norm
540    /// and softmax).
541    ///
542    /// # Arguments
543    ///
544    /// * `pipeline`         — The compiled compute pipeline to execute.
545    /// * `buffers`          — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
546    /// * `threadgroup_mem`  — Slice of `(index, byte_length)` pairs for threadgroup memory.
547    /// * `threadgroups`     — Number of threadgroups to dispatch.
548    /// * `threadgroup_size` — Threads per threadgroup.
549    pub fn encode_threadgroups_with_shared(
550        &mut self,
551        pipeline: &ComputePipelineStateRef,
552        buffers: &[(u64, &MlxBuffer)],
553        threadgroup_mem: &[(u64, u64)],
554        threadgroups: MTLSize,
555        threadgroup_size: MTLSize,
556    ) {
557        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
558        let op_kind = self.take_pending_op_kind();
559        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
560        if let Some(ref mut nodes) = self.capture {
561            nodes.push(CapturedNode::Dispatch {
562                pipeline: pipeline.to_owned(),
563                bindings: Self::record_buffer_bindings(buffers),
564                threads_per_grid: threadgroups,
565                threads_per_threadgroup: threadgroup_size,
566                threadgroup_memory: threadgroup_mem.to_vec(),
567                dispatch_kind: DispatchKind::ThreadGroups,
568                op_kind,
569                reads: pending_reads,
570                writes: pending_writes,
571            });
572            return;
573        }
574        let encoder = self.get_or_create_encoder();
575        encoder.set_compute_pipeline_state(pipeline);
576        for &(index, buf) in buffers {
577            encoder.set_buffer(index, Some(buf.metal_buffer()), 0);
578        }
579        for &(index, byte_length) in threadgroup_mem {
580            encoder.set_threadgroup_memory_length(index, byte_length);
581        }
582        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
583    }
584
585    /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_threads).
586    ///
587    /// Reuses the persistent compute encoder.
588    pub fn encode_with_args(
589        &mut self,
590        pipeline: &ComputePipelineStateRef,
591        bindings: &[(u64, KernelArg<'_>)],
592        grid_size: MTLSize,
593        threadgroup_size: MTLSize,
594    ) {
595        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
596        let op_kind = self.take_pending_op_kind();
597        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
598        if let Some(ref mut nodes) = self.capture {
599            nodes.push(CapturedNode::Dispatch {
600                pipeline: pipeline.to_owned(),
601                bindings: Self::record_arg_bindings(bindings),
602                threads_per_grid: grid_size,
603                threads_per_threadgroup: threadgroup_size,
604                threadgroup_memory: Vec::new(),
605                dispatch_kind: DispatchKind::Threads,
606                op_kind,
607                reads: pending_reads,
608                writes: pending_writes,
609            });
610            return;
611        }
612        let encoder = self.get_or_create_encoder();
613        encoder.set_compute_pipeline_state(pipeline);
614        apply_bindings(encoder, bindings);
615        encoder.dispatch_threads(grid_size, threadgroup_size);
616    }
617
618    /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_thread_groups).
619    ///
620    /// Reuses the persistent compute encoder.
621    pub fn encode_threadgroups_with_args(
622        &mut self,
623        pipeline: &ComputePipelineStateRef,
624        bindings: &[(u64, KernelArg<'_>)],
625        threadgroups: MTLSize,
626        threadgroup_size: MTLSize,
627    ) {
628        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
629        let op_kind = self.take_pending_op_kind();
630        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
631        if let Some(ref mut nodes) = self.capture {
632            nodes.push(CapturedNode::Dispatch {
633                pipeline: pipeline.to_owned(),
634                bindings: Self::record_arg_bindings(bindings),
635                threads_per_grid: threadgroups,
636                threads_per_threadgroup: threadgroup_size,
637                threadgroup_memory: Vec::new(),
638                dispatch_kind: DispatchKind::ThreadGroups,
639                op_kind,
640                reads: pending_reads,
641                writes: pending_writes,
642            });
643            return;
644        }
645        let encoder = self.get_or_create_encoder();
646        encoder.set_compute_pipeline_state(pipeline);
647        apply_bindings(encoder, bindings);
648        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
649    }
650
651    /// Encode a dispatch with mixed buffer/bytes bindings and shared memory.
652    ///
653    /// Reuses the persistent compute encoder.
654    pub fn encode_threadgroups_with_args_and_shared(
655        &mut self,
656        pipeline: &ComputePipelineStateRef,
657        bindings: &[(u64, KernelArg<'_>)],
658        threadgroup_mem: &[(u64, u64)],
659        threadgroups: MTLSize,
660        threadgroup_size: MTLSize,
661    ) {
662        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
663        let op_kind = self.take_pending_op_kind();
664        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
665        if let Some(ref mut nodes) = self.capture {
666            nodes.push(CapturedNode::Dispatch {
667                pipeline: pipeline.to_owned(),
668                bindings: Self::record_arg_bindings(bindings),
669                threads_per_grid: threadgroups,
670                threads_per_threadgroup: threadgroup_size,
671                threadgroup_memory: threadgroup_mem.to_vec(),
672                dispatch_kind: DispatchKind::ThreadGroups,
673                op_kind,
674                reads: pending_reads,
675                writes: pending_writes,
676            });
677            return;
678        }
679        let encoder = self.get_or_create_encoder();
680        encoder.set_compute_pipeline_state(pipeline);
681        apply_bindings(encoder, bindings);
682        for &(index, byte_length) in threadgroup_mem {
683            encoder.set_threadgroup_memory_length(index, byte_length);
684        }
685        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
686    }
687
688    /// Replay a single captured dispatch node into this encoder.
689    ///
690    /// This is the inverse of capture: it takes a previously recorded
691    /// `CapturedNode::Dispatch` and encodes it into the live Metal encoder.
692    /// Barrier nodes are handled by the caller (ComputeGraph::encode_sequential).
693    ///
694    /// Does NOT increment `DISPATCH_COUNT` — that was already counted at
695    /// capture time.
696    pub fn replay_dispatch(
697        &mut self,
698        pipeline: &ComputePipelineStateRef,
699        bindings: &[(u64, RecordedBinding)],
700        threadgroup_memory: &[(u64, u64)],
701        threads_per_grid: MTLSize,
702        threads_per_threadgroup: MTLSize,
703        dispatch_kind: DispatchKind,
704    ) {
705        let encoder = self.get_or_create_encoder();
706        encoder.set_compute_pipeline_state(pipeline);
707        for (index, binding) in bindings {
708            match binding {
709                RecordedBinding::Buffer { metal_buffer, offset } => {
710                    encoder.set_buffer(*index, Some(metal_buffer), *offset);
711                }
712                RecordedBinding::Bytes(bytes) => {
713                    encoder.set_bytes(
714                        *index,
715                        bytes.len() as u64,
716                        bytes.as_ptr() as *const _,
717                    );
718                }
719            }
720        }
721        for &(index, byte_length) in threadgroup_memory {
722            encoder.set_threadgroup_memory_length(index, byte_length);
723        }
724        match dispatch_kind {
725            DispatchKind::Threads => {
726                encoder.dispatch_threads(threads_per_grid, threads_per_threadgroup);
727            }
728            DispatchKind::ThreadGroups => {
729                encoder.dispatch_thread_groups(threads_per_grid, threads_per_threadgroup);
730            }
731        }
732    }
733
734    /// Commit the command buffer and block until the GPU finishes execution.
735    ///
736    /// # Errors
737    ///
738    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
739    pub fn commit_and_wait(&mut self) -> Result<()> {
740        SYNC_COUNT.fetch_add(1, Ordering::Relaxed);
741
742        // End the persistent compute encoder before committing.
743        self.end_active_encoder();
744
745        self.cmd_buf.commit();
746        self.cmd_buf.wait_until_completed();
747
748        match self.cmd_buf.status() {
749            MTLCommandBufferStatus::Completed => Ok(()),
750            MTLCommandBufferStatus::Error => {
751                Err(MlxError::CommandBufferError(
752                    "GPU command buffer completed with error status".into(),
753                ))
754            }
755            status => Err(MlxError::CommandBufferError(format!(
756                "Unexpected command buffer status after wait: {:?}",
757                status
758            ))),
759        }
760    }
761
    /// Commit the command buffer WITHOUT blocking.
    ///
    /// The GPU begins executing the encoded commands immediately.  Call
    /// [`wait_until_completed`](Self::wait_until_completed) later to block
    /// the CPU and check for errors.  This allows the CPU to continue doing
    /// other work (e.g. preparing the next batch) while the GPU runs.
    ///
    /// Does not bump `SYNC_COUNT` — only the blocking
    /// [`commit_and_wait`](Self::commit_and_wait) counts as a CPU sync point.
    pub fn commit(&mut self) {
        // Metal requires endEncoding before commit.
        self.end_active_encoder();
        self.cmd_buf.commit();
    }
772
773    /// Block until a previously committed command buffer completes.
774    ///
775    /// Must be called after [`commit`](Self::commit).  Do not call after
776    /// [`commit_and_wait`](Self::commit_and_wait) — that method already waits.
777    ///
778    /// # Errors
779    ///
780    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
781    pub fn wait_until_completed(&self) -> Result<()> {
782        self.cmd_buf.wait_until_completed();
783        match self.cmd_buf.status() {
784            MTLCommandBufferStatus::Completed => Ok(()),
785            MTLCommandBufferStatus::Error => Err(MlxError::CommandBufferError(
786                "GPU command buffer completed with error status".into(),
787            )),
788            status => Err(MlxError::CommandBufferError(format!(
789                "Unexpected command buffer status after wait: {:?}",
790                status
791            ))),
792        }
793    }
794
    /// Borrow the underlying Metal command buffer.
    ///
    /// Useful for interop (e.g. attaching completion handlers); the borrow
    /// does not end the active compute encoder.
    #[inline]
    pub fn metal_command_buffer(&self) -> &CommandBuffer {
        &self.cmd_buf
    }
800}
801
impl Drop for CommandEncoder {
    fn drop(&mut self) {
        // End the persistent compute encoder before the command buffer
        // is dropped, otherwise Metal will assert:
        // "Command encoder released without endEncoding"
        // (`end_active_encoder` is idempotent, so dropping after commit is fine.)
        self.end_active_encoder();
    }
}
809}