mlx_native/encoder.rs

//! [`CommandEncoder`] — batched GPU command submission.
//!
//! Wraps a Metal command buffer.  Encode one or more compute kernel dispatches,
//! then call [`commit_and_wait`](CommandEncoder::commit_and_wait) to submit the
//! entire batch and block until the GPU finishes.
//!
//! # Persistent compute encoder
//!
//! A single Metal `ComputeCommandEncoder` is kept alive across multiple
//! dispatches within the same command buffer.  This avoids the overhead of
//! creating and ending a new compute encoder per dispatch — the same pattern
//! candle uses (`compute_per_buffer`).  On a forward pass with ~800 dispatches
//! this saves ~800 encoder create/end cycles.
//!
//! # Capture mode (Phase 4e.1)
//!
//! When `start_capture()` is called, subsequent dispatches are recorded into a
//! `Vec<CapturedNode>` instead of being encoded into Metal.  `memory_barrier()`
//! records a barrier sentinel.  Call `take_capture()` to extract the recorded
//! graph for later replay via `ComputeGraph::encode_sequential()`.
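//!
//! A minimal capture/replay sketch (not compiled here; `device`, `pipeline`,
//! and the buffer setup are assumed to come from the surrounding crate):
//!
//! ```ignore
//! let mut enc = device.command_encoder()?;
//! enc.start_capture();
//! enc.encode_threadgroups(&pipeline, &buffers, tg, tg_size); // recorded, not encoded
//! enc.memory_barrier();                                      // recorded as a Barrier node
//! let nodes = enc.take_capture().unwrap();
//! // Later: replay the recorded nodes into a live encoder, e.g. via
//! // ComputeGraph::encode_sequential().
//! ```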

use std::sync::atomic::{AtomicU64, Ordering};

use metal::{
    CommandBuffer, CommandQueue, ComputeCommandEncoderRef, ComputePipelineState,
    ComputePipelineStateRef, MTLCommandBufferStatus, MTLDispatchType, MTLSize,
};
#[allow(unused_imports)]
use objc::{msg_send, sel, sel_impl};

use crate::buffer::MlxBuffer;
use crate::error::{MlxError, Result};

/// A buffer or inline-bytes binding for a compute kernel argument slot.
pub enum KernelArg<'a> {
    /// Bind an existing Metal buffer at the given index.
    Buffer(&'a MlxBuffer),
    /// Bind an existing Metal buffer at the given index with a byte offset.
    BufferWithOffset(&'a MlxBuffer, u64),
    /// Bind inline bytes (small constant data) at the given index.
    /// The data must be `Pod` and is copied into the command encoder.
    Bytes(&'a [u8]),
}

/// Convert a `Pod` value to a byte slice suitable for `KernelArg::Bytes`.
///
/// # Correctness
///
/// The caller must ensure `T` has the same layout as the corresponding
/// MSL struct in the shader (matching field order, sizes, and alignment).
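///
/// # Example
///
/// A minimal sketch, assuming a `#[repr(C)]` parameter struct that mirrors
/// the shader's MSL struct (the field names here are illustrative only):
///
/// ```ignore
/// #[repr(C)]
/// #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
/// struct GemmParams { m: u32, n: u32, k: u32, _pad: u32 }
///
/// let params = GemmParams { m: 64, n: 64, k: 128, _pad: 0 };
/// let bindings = [
///     (0, KernelArg::Buffer(&a)),
///     (1, KernelArg::Buffer(&b)),
///     (2, KernelArg::Bytes(as_bytes(&params))),
/// ];
/// ```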
pub fn as_bytes<T: bytemuck::Pod>(val: &T) -> &[u8] {
    bytemuck::bytes_of(val)
}

// ---------------------------------------------------------------------------
// Capture-mode types (Phase 4e.1 — Graph IR)
// ---------------------------------------------------------------------------

/// A recorded kernel argument binding.
///
/// When the encoder is in capture mode, each `set_buffer` / `set_bytes` call
/// is stored as a `RecordedBinding` instead of being applied to Metal.
#[derive(Clone)]
pub enum RecordedBinding {
    /// A Metal buffer at the given offset.
    Buffer {
        metal_buffer: metal::Buffer,
        offset: u64,
    },
    /// Inline bytes (small constant data, copied).
    Bytes(Vec<u8>),
}

/// How to dispatch the recorded kernel.
#[derive(Clone, Copy, Debug)]
pub enum DispatchKind {
    /// `dispatch_threads(grid_size, threadgroup_size)` — Metal picks threadgroup count.
    Threads,
    /// `dispatch_thread_groups(threadgroups, threadgroup_size)` — caller specifies threadgroup count.
    ThreadGroups,
}

/// Operation kind tag for captured nodes, used by the fusion pass (4e.2).
///
/// When the encoder is in capture mode, each dispatch can be tagged with a
/// `CapturedOpKind` so the fusion pass can identify fuseable sequences without
/// inspecting pipeline names.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CapturedOpKind {
    /// RMS normalization (with learned scale).
    RmsNorm,
    /// Elementwise multiply.
    ElemMul,
    /// Elementwise add.
    ElemAdd,
    /// Scaled dot-product attention (NOT reorderable — breaks lookahead).
    Sdpa,
    /// Softmax (NOT reorderable — breaks lookahead).
    Softmax,
    /// Any other operation — treated as reorderable by the graph optimizer.
    Other,
}

impl CapturedOpKind {
    /// Whether this captured op kind is safe to reorder past in the graph
    /// optimizer (Phase 4e.3).
    ///
    /// Mirrors the `h_safe` whitelist from llama.cpp's
    /// `ggml_metal_graph_optimize_reorder`.  Non-safe ops break the 64-node
    /// lookahead — the reorder pass cannot look past them.
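    ///
    /// For instance (a sketch of the intended behavior, not a compiled doctest):
    ///
    /// ```ignore
    /// assert!(CapturedOpKind::ElemAdd.is_reorderable());
    /// assert!(!CapturedOpKind::Sdpa.is_reorderable());
    /// ```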
    pub fn is_reorderable(&self) -> bool {
        match self {
            Self::Sdpa | Self::Softmax => false,
            Self::RmsNorm | Self::ElemMul | Self::ElemAdd | Self::Other => true,
        }
    }
}

/// A memory range annotation: (start_address, end_address).
///
/// Represents a contiguous GPU buffer region for conflict detection in the
/// reorder pass (Phase 4e.3).  Addresses are CPU-visible `contents_ptr()`
/// values, which on Apple Silicon unified memory equal the GPU addresses.
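///
/// A sketch of how a range might be built from a buffer (the length accessor
/// name `byte_len()` is hypothetical; use whatever `MlxBuffer` actually exposes):
///
/// ```ignore
/// let start = buf.contents_ptr() as usize + buf.byte_offset() as usize;
/// let range: MemRange = (start, start + buf.byte_len());
/// ```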
pub type MemRange = (usize, usize);

/// A single captured compute dispatch or barrier sentinel.
///
/// Created when the encoder is in capture mode.  Replayed later by
/// `ComputeGraph::encode_sequential()`.
#[derive(Clone)]
pub enum CapturedNode {
    /// A compute dispatch to replay.
    Dispatch {
        /// Pipeline state object to bind.
        pipeline: ComputePipelineState,
        /// Kernel argument bindings: (slot_index, binding).
        bindings: Vec<(u64, RecordedBinding)>,
        /// Grid or threadgroup count (interpretation depends on `dispatch_kind`).
        threads_per_grid: MTLSize,
        /// Threads per threadgroup.
        threads_per_threadgroup: MTLSize,
        /// Optional threadgroup memory allocations: (index, byte_length).
        threadgroup_memory: Vec<(u64, u64)>,
        /// Whether this is a dispatch_threads or dispatch_thread_groups call.
        dispatch_kind: DispatchKind,
        /// Operation kind tag for the fusion pass (4e.2).
        /// Defaults to `Other` if not explicitly set via `set_op_kind()`.
        op_kind: CapturedOpKind,
        /// Read buffer ranges for reorder conflict detection (4e.3).
        /// Populated from `barrier_between` calls in capture mode.
        reads: Vec<MemRange>,
        /// Write buffer ranges for reorder conflict detection (4e.3).
        /// Populated from `barrier_between` calls in capture mode.
        writes: Vec<MemRange>,
    },
    /// A memory barrier sentinel — forces a barrier at replay time.
    Barrier,
}

/// Apply a slice of `KernelArg` bindings to a compute encoder.
///
/// `KernelArg::Buffer(buf)` propagates the `MlxBuffer::byte_offset()` so
/// `slice_view`-derived sub-buffers are honored automatically — the
/// kernel sees memory starting at the slice's offset. This matches the
/// documented contract of `slice_view` and the offset-handling in the
/// other binding paths in this file (`encode`, `encode_threadgroups`,
/// `encode_threadgroups_with_shared`, replay). Without it, every
/// `slice_view`-derived buffer bound via `KernelArg::Buffer` would silently
/// expose the entire underlying allocation — a bug surfaced by hf2q's
/// nomic-bert iter-79 cosine parity bisection (cosine 0.098 → 0.999962
/// after the fix).
///
/// `KernelArg::BufferWithOffset(buf, offset)` continues to use the
/// explicit `offset` argument verbatim (callers asking for an explicit
/// offset get exactly that, even on sliced buffers). The two API
/// surfaces are intentional: implicit (sliced views auto-propagate) vs.
/// explicit (caller-controlled).
#[inline]
fn apply_bindings(encoder: &ComputeCommandEncoderRef, bindings: &[(u64, KernelArg<'_>)]) {
    for &(index, ref arg) in bindings {
        match arg {
            KernelArg::Buffer(buf) => {
                encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
            }
            KernelArg::BufferWithOffset(buf, offset) => {
                encoder.set_buffer(index, Some(buf.metal_buffer()), *offset);
            }
            KernelArg::Bytes(bytes) => {
                encoder.set_bytes(index, bytes.len() as u64, bytes.as_ptr() as *const _);
            }
        }
    }
}

/// Number of times `commit_and_wait()` has been called (CPU sync points).
static SYNC_COUNT: AtomicU64 = AtomicU64::new(0);

/// Number of times an encode method has been called (GPU dispatches).
static DISPATCH_COUNT: AtomicU64 = AtomicU64::new(0);

/// Reset both `SYNC_COUNT` and `DISPATCH_COUNT` to zero.
pub fn reset_counters() {
    SYNC_COUNT.store(0, Ordering::Relaxed);
    DISPATCH_COUNT.store(0, Ordering::Relaxed);
}

/// Read the current value of `SYNC_COUNT`.
///
/// Each call to `commit_and_wait()` increments this counter.
pub fn sync_count() -> u64 {
    SYNC_COUNT.load(Ordering::Relaxed)
}

/// Read the current value of `DISPATCH_COUNT`.
///
/// Each call to one of the `encode*` dispatch methods (e.g. `encode()`,
/// `encode_threadgroups()`, or `encode_with_args()`) increments this counter.
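///
/// A typical profiling sketch (`run_forward()` stands in for caller code):
///
/// ```ignore
/// reset_counters();
/// run_forward()?;
/// println!("dispatches: {}, CPU syncs: {}", dispatch_count(), sync_count());
/// ```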
pub fn dispatch_count() -> u64 {
    DISPATCH_COUNT.load(Ordering::Relaxed)
}

/// A batched compute command encoder.
///
/// Keeps a single Metal `ComputeCommandEncoder` alive across multiple
/// dispatches.  The encoder is created on the first dispatch and ended
/// only when the command buffer is committed.  This mirrors candle's
/// `compute_per_buffer` pattern and avoids per-dispatch encoder overhead.
///
/// # Typical usage
///
/// ```ignore
/// let mut enc = device.command_encoder()?;
/// // Multiple dispatches share the same compute encoder:
/// enc.encode_threadgroups(pipeline1, &buffers1, tg1, tg_size1);
/// enc.encode_threadgroups(pipeline2, &buffers2, tg2, tg_size2);
/// enc.commit_and_wait()?;
/// ```
pub struct CommandEncoder {
    cmd_buf: CommandBuffer,
    // SAFETY marker: see unsafe Send impl below.
    /// Raw pointer to the persistent compute encoder.
    /// Non-null when a compute pass is active.
    /// The encoder borrows from `cmd_buf` but we cannot express this
    /// lifetime in safe Rust, so we use a raw pointer.
    /// SAFETY: the pointer is valid as long as `cmd_buf` is alive and
    /// `end_encoding()` has not been called on it.
    active_encoder: *const ComputeCommandEncoderRef,
    /// When `Some`, dispatches are recorded here instead of being encoded
    /// into Metal.  Set via `start_capture()`, extracted via `take_capture()`.
    capture: Option<Vec<CapturedNode>>,
    /// Op kind tag for the NEXT captured dispatch.  Set via `set_op_kind()`,
    /// consumed (reset to `Other`) when a dispatch is captured.
    pending_op_kind: CapturedOpKind,
    /// Pending read buffer ranges for the NEXT captured dispatch.
    /// Set via `set_pending_buffer_ranges()`, consumed when the next dispatch
    /// is captured.  Used by the reorder pass (Phase 4e.3).
    pending_reads: Vec<MemRange>,
    /// Pending write buffer ranges for the NEXT captured dispatch.
    pending_writes: Vec<MemRange>,
}

/// SAFETY: `CommandEncoder` is safe to `Send` across threads provided that
/// only one thread accesses the encoder at a time (exclusive ownership,
/// never concurrent use from multiple threads).
///
/// Metal command buffers and compute encoders are thread-safe for exclusive
/// access (Apple documentation: "You can create command buffers, encode
/// commands, and submit them from any thread"). The raw pointer
/// `active_encoder` borrows from `cmd_buf` and is valid as long as
/// `cmd_buf` is alive — this invariant holds across thread boundaries
/// because both fields move together.
///
/// This matches llama.cpp's pattern of encoding command buffers on GCD
/// worker threads via `dispatch_apply`, and is used for the dual-buffer
/// pipeline where buf1 is encoded on a worker thread while buf0 executes.
unsafe impl Send for CommandEncoder {}

impl CommandEncoder {
    /// Create a new command encoder from the given command queue.
    ///
    /// This immediately creates a Metal command buffer.
    pub(crate) fn new(queue: &CommandQueue) -> Result<Self> {
        let cmd_buf = queue.new_command_buffer().to_owned();
        Ok(Self {
            cmd_buf,
            active_encoder: std::ptr::null(),
            capture: None,
            pending_op_kind: CapturedOpKind::Other,
            pending_reads: Vec::new(),
            pending_writes: Vec::new(),
        })
    }

    /// Enable capture mode.
    ///
    /// All subsequent dispatch and barrier calls will be recorded into a
    /// `Vec<CapturedNode>` instead of being encoded into Metal.
    /// Call `take_capture()` to extract the recorded nodes.
    pub fn start_capture(&mut self) {
        self.capture = Some(Vec::with_capacity(128));
    }

    /// Whether the encoder is currently in capture mode.
    pub fn is_capturing(&self) -> bool {
        self.capture.is_some()
    }

    /// Extract the captured nodes, ending capture mode.
    ///
    /// Returns `None` if capture mode was not active.
    pub fn take_capture(&mut self) -> Option<Vec<CapturedNode>> {
        self.capture.take()
    }

    /// Tag the NEXT captured dispatch with the given operation kind.
    ///
    /// The tag is consumed (reset to `Other`) after the next dispatch is
    /// captured.  Only meaningful in capture mode — has no effect on
    /// direct-dispatch encoding.
    ///
    /// Used by op dispatch functions to annotate captures for the fusion
    /// pass (Phase 4e.2).
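    ///
    /// For example (a sketch; the pipeline and buffer setup are assumed):
    ///
    /// ```ignore
    /// enc.set_op_kind(CapturedOpKind::RmsNorm);
    /// enc.encode_threadgroups(&rms_norm_pipeline, &buffers, tg, tg_size); // captured as RmsNorm
    /// ```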
    pub fn set_op_kind(&mut self, kind: CapturedOpKind) {
        self.pending_op_kind = kind;
    }

    /// Consume and return the pending op kind, resetting it to `Other`.
    fn take_pending_op_kind(&mut self) -> CapturedOpKind {
        let kind = self.pending_op_kind;
        self.pending_op_kind = CapturedOpKind::Other;
        kind
    }

    /// Stash buffer range annotations for the NEXT captured dispatch.
    ///
    /// Called by `GraphSession::barrier_between()` in capture mode to record
    /// which buffers the next dispatch reads from and writes to.  The ranges
    /// are consumed by the next `encode_*` call and attached to the captured
    /// `CapturedNode::Dispatch`.
    ///
    /// Only meaningful in capture mode — has no effect on direct-dispatch.
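    ///
    /// A sketch of the intended call pattern (ranges built as in the
    /// `MemRange` docs; buffer names are illustrative):
    ///
    /// ```ignore
    /// enc.set_pending_buffer_ranges(
    ///     vec![input_range],   // read by the next dispatch
    ///     vec![output_range],  // written by the next dispatch
    /// );
    /// enc.encode_threadgroups(&pipeline, &buffers, tg, tg_size);
    /// ```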
    pub fn set_pending_buffer_ranges(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
        self.pending_reads = reads;
        self.pending_writes = writes;
    }

    /// Patch the last captured dispatch node's empty reads/writes with the
    /// given ranges. No-op if not capturing, or if the last node isn't a
    /// Dispatch, or if its ranges are already populated.
    ///
    /// Used by `GraphSession::track_dispatch` in recording mode to annotate
    /// dispatches that were called without a preceding `barrier_between`.
    pub fn annotate_last_dispatch_if_missing(&mut self, reads: Vec<MemRange>, writes: Vec<MemRange>) {
        if let Some(ref mut nodes) = self.capture {
            if let Some(CapturedNode::Dispatch { reads: r, writes: w, .. }) = nodes.last_mut() {
                if r.is_empty() && !reads.is_empty() {
                    *r = reads;
                }
                if w.is_empty() && !writes.is_empty() {
                    *w = writes;
                }
            }
        }
    }

    /// Consume and return the pending buffer range annotations.
    fn take_pending_buffer_ranges(&mut self) -> (Vec<MemRange>, Vec<MemRange>) {
        let reads = std::mem::take(&mut self.pending_reads);
        let writes = std::mem::take(&mut self.pending_writes);
        (reads, writes)
    }

    /// Record buffer bindings into `RecordedBinding` form.
    fn record_buffer_bindings(buffers: &[(u64, &MlxBuffer)]) -> Vec<(u64, RecordedBinding)> {
        buffers
            .iter()
            .map(|&(index, buf)| {
                (
                    index,
                    RecordedBinding::Buffer {
                        metal_buffer: buf.metal_buffer().clone(),
                        offset: buf.byte_offset(),
                    },
                )
            })
            .collect()
    }

    /// Record `KernelArg` bindings into `RecordedBinding` form.
    ///
    /// `KernelArg::Buffer(buf)` records `buf.byte_offset()` so capture →
    /// replay round-trips of `slice_view`-derived buffers preserve their
    /// offsets, matching `record_buffer_bindings`'s behavior.
    fn record_arg_bindings(bindings: &[(u64, KernelArg<'_>)]) -> Vec<(u64, RecordedBinding)> {
        bindings
            .iter()
            .map(|(index, arg)| {
                let recorded = match arg {
                    KernelArg::Buffer(buf) => RecordedBinding::Buffer {
                        metal_buffer: buf.metal_buffer().clone(),
                        offset: buf.byte_offset(),
                    },
                    KernelArg::BufferWithOffset(buf, offset) => RecordedBinding::Buffer {
                        metal_buffer: buf.metal_buffer().clone(),
                        offset: *offset,
                    },
                    KernelArg::Bytes(bytes) => RecordedBinding::Bytes(bytes.to_vec()),
                };
                (*index, recorded)
            })
            .collect()
    }

    /// Get or create the persistent compute encoder.
    ///
    /// On the first call, creates a new compute encoder from the command
    /// buffer.  On subsequent calls, returns the existing one.
    ///
    /// SAFETY: The returned reference borrows from `self.cmd_buf` which is
    /// alive for the lifetime of this `CommandEncoder`.  The raw pointer is
    /// valid until `end_active_encoder()` is called.
    #[inline]
    fn get_or_create_encoder(&mut self) -> &ComputeCommandEncoderRef {
        if self.active_encoder.is_null() {
            // Use MTLDispatchTypeConcurrent to allow independent dispatches
            // to overlap on the GPU.  Memory barriers are inserted between
            // dependent dispatches via `memory_barrier()`.
            let encoder = self
                .cmd_buf
                .compute_command_encoder_with_dispatch_type(MTLDispatchType::Concurrent);
            self.active_encoder = encoder as *const ComputeCommandEncoderRef;
        }
        // SAFETY: active_encoder is non-null and points to a valid encoder
        // owned by cmd_buf.
        unsafe { &*self.active_encoder }
    }

    /// End the active compute encoder if one exists.
    #[inline]
    fn end_active_encoder(&mut self) {
        if !self.active_encoder.is_null() {
            // SAFETY: the pointer was obtained from
            // `compute_command_encoder_with_dispatch_type()` on `cmd_buf`
            // and has not been ended yet.
            unsafe { &*self.active_encoder }.end_encoding();
            self.active_encoder = std::ptr::null();
        }
    }

    /// Insert a memory barrier with scope `MTLBarrierScopeBuffers`.
    ///
    /// When the encoder uses `MTLDispatchTypeConcurrent`, all dispatches can
    /// execute concurrently unless separated by a barrier.  Call this between
    /// dispatches where the later dispatch reads a buffer written by an
    /// earlier one.
    ///
    /// This is the same pattern llama.cpp uses:
    /// `[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]`
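    ///
    /// A sketch of the dependent-dispatch pattern (pipelines and buffers assumed):
    ///
    /// ```ignore
    /// enc.encode_threadgroups(&matmul_pipeline, &[(0, &a), (1, &b), (2, &tmp)], tg0, tgs0);
    /// enc.memory_barrier(); // the next dispatch reads `tmp`, written above
    /// enc.encode_threadgroups(&norm_pipeline, &[(0, &tmp), (1, &out)], tg1, tgs1);
    /// ```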
    #[allow(unexpected_cfgs)]
    pub fn memory_barrier(&mut self) {
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Barrier);
            return;
        }
        if self.active_encoder.is_null() {
            return;
        }
        // SAFETY: active_encoder is non-null and valid.
        let encoder = unsafe { &*self.active_encoder };
        // MTLBarrierScopeBuffers = 1 << 0 = 1
        const MTL_BARRIER_SCOPE_BUFFERS: u64 = 1;
        unsafe {
            let _: () = objc::msg_send![encoder, memoryBarrierWithScope: MTL_BARRIER_SCOPE_BUFFERS];
        }
    }

    /// Set the compute pipeline state for subsequent dispatches.
    ///
    /// This begins a new compute pass if one is not already active.
    pub fn set_pipeline(&mut self, pipeline: &ComputePipelineStateRef) {
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
    }

    /// Bind a buffer to a compute kernel argument slot.
    ///
    /// The `index` corresponds to the `[[buffer(N)]]` attribute in the MSL shader.
    ///
    /// Currently a no-op: the arguments are discarded.  Prefer the `encode*`
    /// methods, which set the pipeline, bind buffers, and dispatch in one call.
    pub fn set_buffer(&self, index: u64, buffer: &MlxBuffer) {
        let _ = (index, buffer);
    }

    /// Dispatch threads on the GPU.
    ///
    /// Currently a no-op: the arguments are discarded.  Prefer the `encode*`
    /// methods, which set the pipeline, bind buffers, and dispatch in one call.
    pub fn dispatch_threads(&self, grid_size: MTLSize, threadgroup_size: MTLSize) {
        let _ = (grid_size, threadgroup_size);
    }

    /// Encode a complete compute pass: set pipeline, bind buffers, dispatch.
    ///
    /// Reuses the persistent compute encoder — no per-dispatch encoder
    /// creation overhead.
    ///
    /// # Arguments
    ///
    /// * `pipeline`         — The compiled compute pipeline to execute.
    /// * `buffers`          — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
    /// * `grid_size`        — Total number of threads to launch.
    /// * `threadgroup_size` — Threads per threadgroup.
    pub fn encode(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        buffers: &[(u64, &MlxBuffer)],
        grid_size: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
        let op_kind = self.take_pending_op_kind();
        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Dispatch {
                pipeline: pipeline.to_owned(),
                bindings: Self::record_buffer_bindings(buffers),
                threads_per_grid: grid_size,
                threads_per_threadgroup: threadgroup_size,
                threadgroup_memory: Vec::new(),
                dispatch_kind: DispatchKind::Threads,
                op_kind,
                reads: pending_reads,
                writes: pending_writes,
            });
            return;
        }
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
        for &(index, buf) in buffers {
            encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
        }
        encoder.dispatch_threads(grid_size, threadgroup_size);
    }

    /// Encode a compute pass using threadgroups instead of raw thread counts.
    ///
    /// Reuses the persistent compute encoder — no per-dispatch encoder
    /// creation overhead.
    pub fn encode_threadgroups(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        buffers: &[(u64, &MlxBuffer)],
        threadgroups: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
        let op_kind = self.take_pending_op_kind();
        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Dispatch {
                pipeline: pipeline.to_owned(),
                bindings: Self::record_buffer_bindings(buffers),
                threads_per_grid: threadgroups,
                threads_per_threadgroup: threadgroup_size,
                threadgroup_memory: Vec::new(),
                dispatch_kind: DispatchKind::ThreadGroups,
                op_kind,
                reads: pending_reads,
                writes: pending_writes,
            });
            return;
        }
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
        for &(index, buf) in buffers {
            encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
        }
        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
    }

    /// Encode a compute pass using threadgroups with shared threadgroup memory.
    ///
    /// Like [`encode_threadgroups`](Self::encode_threadgroups), but additionally
    /// allocates threadgroup memory at the specified indices.  This is required
    /// for kernels that use `threadgroup` memory (e.g. reductions in rms_norm
    /// and softmax).
    ///
    /// # Arguments
    ///
    /// * `pipeline`         — The compiled compute pipeline to execute.
    /// * `buffers`          — Slice of `(index, &MlxBuffer)` pairs for buffer bindings.
    /// * `threadgroup_mem`  — Slice of `(index, byte_length)` pairs for threadgroup memory.
    /// * `threadgroups`     — Number of threadgroups to dispatch.
    /// * `threadgroup_size` — Threads per threadgroup.
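    ///
    /// A sketch of a row-wise reduction dispatch (names and sizes are
    /// illustrative; the shader is assumed to declare threadgroup storage
    /// sized to match):
    ///
    /// ```ignore
    /// let tg_size = MTLSize::new(256, 1, 1);
    /// let tgs = MTLSize::new(n_rows as u64, 1, 1);
    /// enc.encode_threadgroups_with_shared(
    ///     &rms_norm_pipeline,
    ///     &[(0, &input), (1, &scale), (2, &output)],
    ///     &[(0, 256 * std::mem::size_of::<f32>() as u64)], // threadgroup memory at index 0
    ///     tgs,
    ///     tg_size,
    /// );
    /// ```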
    pub fn encode_threadgroups_with_shared(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        buffers: &[(u64, &MlxBuffer)],
        threadgroup_mem: &[(u64, u64)],
        threadgroups: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
        let op_kind = self.take_pending_op_kind();
        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Dispatch {
                pipeline: pipeline.to_owned(),
                bindings: Self::record_buffer_bindings(buffers),
                threads_per_grid: threadgroups,
                threads_per_threadgroup: threadgroup_size,
                threadgroup_memory: threadgroup_mem.to_vec(),
                dispatch_kind: DispatchKind::ThreadGroups,
                op_kind,
                reads: pending_reads,
                writes: pending_writes,
            });
            return;
        }
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
        for &(index, buf) in buffers {
            encoder.set_buffer(index, Some(buf.metal_buffer()), buf.byte_offset());
        }
        for &(index, byte_length) in threadgroup_mem {
            encoder.set_threadgroup_memory_length(index, byte_length);
        }
        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
    }

    /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_threads).
    ///
    /// Reuses the persistent compute encoder.
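    ///
    /// A sketch combining buffer and inline-bytes bindings (`params` is a
    /// `bytemuck::Pod` struct as in the hypothetical [`as_bytes`] example):
    ///
    /// ```ignore
    /// enc.encode_with_args(
    ///     &pipeline,
    ///     &[
    ///         (0, KernelArg::Buffer(&input)),
    ///         (1, KernelArg::Buffer(&output)),
    ///         (2, KernelArg::Bytes(as_bytes(&params))),
    ///     ],
    ///     grid_size,
    ///     threadgroup_size,
    /// );
    /// ```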
    pub fn encode_with_args(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        bindings: &[(u64, KernelArg<'_>)],
        grid_size: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
        let op_kind = self.take_pending_op_kind();
        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Dispatch {
                pipeline: pipeline.to_owned(),
                bindings: Self::record_arg_bindings(bindings),
                threads_per_grid: grid_size,
                threads_per_threadgroup: threadgroup_size,
                threadgroup_memory: Vec::new(),
                dispatch_kind: DispatchKind::Threads,
                op_kind,
                reads: pending_reads,
                writes: pending_writes,
            });
            return;
        }
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
        apply_bindings(encoder, bindings);
        encoder.dispatch_threads(grid_size, threadgroup_size);
    }

    /// Encode a dispatch with mixed buffer/bytes bindings (dispatch_thread_groups).
    ///
    /// Reuses the persistent compute encoder.
    pub fn encode_threadgroups_with_args(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        bindings: &[(u64, KernelArg<'_>)],
        threadgroups: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
        let op_kind = self.take_pending_op_kind();
        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Dispatch {
                pipeline: pipeline.to_owned(),
                bindings: Self::record_arg_bindings(bindings),
                threads_per_grid: threadgroups,
                threads_per_threadgroup: threadgroup_size,
                threadgroup_memory: Vec::new(),
                dispatch_kind: DispatchKind::ThreadGroups,
                op_kind,
                reads: pending_reads,
                writes: pending_writes,
            });
            return;
        }
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
        apply_bindings(encoder, bindings);
        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
    }

    /// Encode a dispatch with mixed buffer/bytes bindings and shared memory.
    ///
    /// Reuses the persistent compute encoder.
    pub fn encode_threadgroups_with_args_and_shared(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        bindings: &[(u64, KernelArg<'_>)],
        threadgroup_mem: &[(u64, u64)],
        threadgroups: MTLSize,
        threadgroup_size: MTLSize,
    ) {
        DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed);
        let op_kind = self.take_pending_op_kind();
        let (pending_reads, pending_writes) = self.take_pending_buffer_ranges();
        if let Some(ref mut nodes) = self.capture {
            nodes.push(CapturedNode::Dispatch {
                pipeline: pipeline.to_owned(),
                bindings: Self::record_arg_bindings(bindings),
                threads_per_grid: threadgroups,
                threads_per_threadgroup: threadgroup_size,
                threadgroup_memory: threadgroup_mem.to_vec(),
                dispatch_kind: DispatchKind::ThreadGroups,
                op_kind,
                reads: pending_reads,
                writes: pending_writes,
            });
            return;
        }
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
        apply_bindings(encoder, bindings);
        for &(index, byte_length) in threadgroup_mem {
            encoder.set_threadgroup_memory_length(index, byte_length);
        }
        encoder.dispatch_thread_groups(threadgroups, threadgroup_size);
    }

    /// Replay a single captured dispatch node into this encoder.
    ///
    /// This is the inverse of capture: it takes a previously recorded
    /// `CapturedNode::Dispatch` and encodes it into the live Metal encoder.
    /// Barrier nodes are handled by the caller (`ComputeGraph::encode_sequential`).
    ///
    /// Does NOT increment `DISPATCH_COUNT` — that was already counted at
    /// capture time.
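    ///
    /// A sketch of a replay loop over captured nodes (this mirrors what a
    /// sequential graph replay would do; error handling omitted):
    ///
    /// ```ignore
    /// for node in &nodes {
    ///     match node {
    ///         CapturedNode::Barrier => enc.memory_barrier(),
    ///         CapturedNode::Dispatch {
    ///             pipeline, bindings, threadgroup_memory,
    ///             threads_per_grid, threads_per_threadgroup, dispatch_kind, ..
    ///         } => enc.replay_dispatch(
    ///             pipeline, bindings, threadgroup_memory,
    ///             *threads_per_grid, *threads_per_threadgroup, *dispatch_kind,
    ///         ),
    ///     }
    /// }
    /// ```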
    pub fn replay_dispatch(
        &mut self,
        pipeline: &ComputePipelineStateRef,
        bindings: &[(u64, RecordedBinding)],
        threadgroup_memory: &[(u64, u64)],
        threads_per_grid: MTLSize,
        threads_per_threadgroup: MTLSize,
        dispatch_kind: DispatchKind,
    ) {
        let encoder = self.get_or_create_encoder();
        encoder.set_compute_pipeline_state(pipeline);
        for (index, binding) in bindings {
            match binding {
                RecordedBinding::Buffer { metal_buffer, offset } => {
                    encoder.set_buffer(*index, Some(metal_buffer), *offset);
                }
                RecordedBinding::Bytes(bytes) => {
                    encoder.set_bytes(
                        *index,
                        bytes.len() as u64,
                        bytes.as_ptr() as *const _,
                    );
                }
            }
        }
        for &(index, byte_length) in threadgroup_memory {
            encoder.set_threadgroup_memory_length(index, byte_length);
        }
        match dispatch_kind {
            DispatchKind::Threads => {
                encoder.dispatch_threads(threads_per_grid, threads_per_threadgroup);
            }
            DispatchKind::ThreadGroups => {
                encoder.dispatch_thread_groups(threads_per_grid, threads_per_threadgroup);
            }
        }
    }

    /// Commit the command buffer and block until the GPU finishes execution.
    ///
    /// # Errors
    ///
    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
    pub fn commit_and_wait(&mut self) -> Result<()> {
        SYNC_COUNT.fetch_add(1, Ordering::Relaxed);

        // End the persistent compute encoder before committing.
        self.end_active_encoder();

        self.cmd_buf.commit();
        self.cmd_buf.wait_until_completed();

        match self.cmd_buf.status() {
            MTLCommandBufferStatus::Completed => Ok(()),
            MTLCommandBufferStatus::Error => {
                Err(MlxError::CommandBufferError(
                    "GPU command buffer completed with error status".into(),
                ))
            }
            status => Err(MlxError::CommandBufferError(format!(
                "Unexpected command buffer status after wait: {:?}",
                status
            ))),
        }
    }

    /// Commit + wait, returning `(gpu_start_s, gpu_end_s)` CFTimeInterval
    /// timestamps from `MTLCommandBuffer`'s `GPUStartTime`/`GPUEndTime`
    /// properties.  Both are mach-absolute CFTimeInterval seconds (double).
    ///
    /// Intended for `HF2Q_PROFILE_GPU_TS=1` per-bucket GPU wall-clock
    /// attribution.  Adds exactly two ObjC property reads per call on top
    /// of the regular `commit_and_wait` — measured well under 1 μs on
    /// M5 Max.
    ///
    /// # Errors
    ///
    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
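    ///
    /// For example (a sketch of the intended profiling use):
    ///
    /// ```ignore
    /// let (gpu_start, gpu_end) = enc.commit_wait_with_gpu_time()?;
    /// let gpu_ms = (gpu_end - gpu_start) * 1e3;
    /// eprintln!("bucket GPU time: {gpu_ms:.3} ms");
    /// ```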
    pub fn commit_wait_with_gpu_time(&mut self) -> Result<(f64, f64)> {
        self.commit_and_wait()?;
        // SAFETY: cmd_buf is a valid MTLCommandBuffer that has been
        // committed and awaited.  GPUStartTime / GPUEndTime return
        // CFTimeInterval (double precision seconds).  See
        // https://developer.apple.com/documentation/metal/mtlcommandbuffer/1639925-gpustarttime
        let (gpu_start, gpu_end): (f64, f64) = unsafe {
            let cb = &*self.cmd_buf;
            let s: f64 = msg_send![cb, GPUStartTime];
            let e: f64 = msg_send![cb, GPUEndTime];
            (s, e)
        };
        Ok((gpu_start, gpu_end))
    }

    /// Commit the command buffer WITHOUT blocking.
    ///
    /// The GPU begins executing the encoded commands immediately.  Call
    /// [`wait_until_completed`](Self::wait_until_completed) later to block
    /// the CPU and check for errors.  This allows the CPU to continue doing
    /// other work (e.g. preparing the next batch) while the GPU runs.
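    ///
    /// A sketch of the overlap pattern (encoding the next batch while the
    /// current one runs; names are illustrative):
    ///
    /// ```ignore
    /// enc0.commit();                 // GPU starts on batch 0, CPU does not block
    /// encode_batch(&mut enc1, ...);  // CPU encodes batch 1 meanwhile
    /// enc0.wait_until_completed()?;  // sync with batch 0 before reading results
    /// ```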
    pub fn commit(&mut self) {
        self.end_active_encoder();
        self.cmd_buf.commit();
    }

    /// Block until a previously committed command buffer completes.
    ///
    /// Must be called after [`commit`](Self::commit).  Do not call after
    /// [`commit_and_wait`](Self::commit_and_wait) — that method already waits.
    ///
    /// # Errors
    ///
    /// Returns `MlxError::CommandBufferError` if the GPU reports an error.
    pub fn wait_until_completed(&self) -> Result<()> {
        self.cmd_buf.wait_until_completed();
        match self.cmd_buf.status() {
            MTLCommandBufferStatus::Completed => Ok(()),
            MTLCommandBufferStatus::Error => Err(MlxError::CommandBufferError(
                "GPU command buffer completed with error status".into(),
            )),
            status => Err(MlxError::CommandBufferError(format!(
                "Unexpected command buffer status after wait: {:?}",
                status
            ))),
        }
    }

    /// Borrow the underlying Metal command buffer.
    #[inline]
    pub fn metal_command_buffer(&self) -> &CommandBuffer {
        &self.cmd_buf
    }
}

impl Drop for CommandEncoder {
    fn drop(&mut self) {
        // End the persistent compute encoder before the command buffer
        // is dropped, otherwise Metal will assert:
        // "Command encoder released without endEncoding"
        self.end_active_encoder();
    }
}