// mlx_native/graph.rs
//! [`GraphExecutor`] — batched Metal dispatch for single-encoder forward passes.
//!
//! llama.cpp's speed advantage over candle is NOT the kernels (Phase 0 proved
//! candle's are as fast or faster per-call).  It is the dispatch pattern:
//! 1 encoder per command buffer instead of ~120.  This module implements that
//! pattern.
//!
//! # Usage
//!
//! ```ignore
//! let mut executor = GraphExecutor::new(device.clone());
//! let mut session = executor.begin()?;
//!
//! // All ops encode into the same command buffer — no per-op encoder creation.
//! session.rms_norm(&mut registry, device.metal_device(), input, weight, output, params, rows, dim)?;
//! session.quantized_matmul(&mut registry, &device, input, weight, scales, biases, &qparams)?;
//! session.elementwise_add(&mut registry, device.metal_device(), a, b, out, n, DType::F32)?;
//!
//! // Single GPU sync point for the entire forward pass.
//! session.finish()?;
//! ```
//!
//! # Design
//!
//! The `GraphSession` holds a single `CommandEncoder`.  Each op method delegates
//! to the existing op dispatch functions in [`crate::ops`], passing the session's
//! shared encoder.  No new Metal code is needed — the ops already work with a
//! shared encoder.  The executor just prevents creating a new encoder per op.
//!
//! # Phase 4e.1 — Graph IR
//!
//! The `ComputeGraph` type captures dispatches into a `Vec<CapturedNode>` for
//! later replay.  `GraphExecutor::begin_recorded()` starts a session in capture
//! mode: all op calls are intercepted at the `CommandEncoder` level and recorded
//! instead of being sent to Metal.  `GraphSession::finish()` detects capture
//! mode, extracts the recorded graph, and replays it into a fresh encoder via
//! `ComputeGraph::encode_sequential()`.
//!
//! The existing direct-dispatch path (`begin()`) is completely unchanged.

use metal::foreign_types::ForeignType;

use crate::device::MlxDevice;
use crate::encoder::{CapturedNode, CapturedOpKind, CommandEncoder, MemRange, RecordedBinding};
use crate::error::Result;
use crate::kernel_registry::KernelRegistry;
use crate::ops;

// Re-export types used in the public API so callers don't need separate imports.
pub use crate::buffer::MlxBuffer;
pub use crate::dtypes::DType;
52
53// ---------------------------------------------------------------------------
54// OpKind — operation classification for the reorder safety whitelist (4e.3)
55// ---------------------------------------------------------------------------
56
/// Classification of a compute operation for reorder safety analysis.
///
/// Reorderable operations may be moved by the graph optimizer (Phase 4e.3)
/// whenever their data dependencies permit it.  The remaining kinds have
/// side effects or ordering requirements that pin them to their original
/// sequential position.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OpKind {
    /// Matrix multiplication (reorderable).
    MatMul,
    /// Expert-routed matrix multiplication (reorderable).
    MatMulId,
    /// Normalization — RMS norm, layer norm (reorderable).
    Norm,
    /// Rotary position embedding (reorderable).
    Rope,
    /// Elementwise ops — add, mul, scale, gelu, softcap, etc. (reorderable).
    Elementwise,
    /// Memory copy — KV cache copy, embedding gather (reorderable).
    Copy,
    /// Gather/scatter (reorderable).
    Gather,
    /// Scaled dot-product attention (NOT reorderable).
    Sdpa,
    /// Softmax (NOT reorderable).
    Softmax,
    /// MoE gate with CPU readback dependency (NOT reorderable).
    MoeGate,
    /// Anything else (NOT reorderable).
    Other,
}

impl OpKind {
    /// Whether the graph optimizer may move this op relative to its neighbors.
    ///
    /// Exhaustive on purpose: adding a new variant forces an explicit decision
    /// here instead of silently defaulting to one side.
    pub fn is_reorderable(&self) -> bool {
        match self {
            Self::MatMul
            | Self::MatMulId
            | Self::Norm
            | Self::Rope
            | Self::Elementwise
            | Self::Copy
            | Self::Gather => true,
            Self::Sdpa | Self::Softmax | Self::MoeGate | Self::Other => false,
        }
    }
}
104
105// ---------------------------------------------------------------------------
106// ComputeGraph — the recorded graph IR
107// ---------------------------------------------------------------------------
108
/// A recorded sequence of GPU compute dispatches and barriers.
///
/// Created by running a forward pass with the encoder in capture mode.
/// Can be replayed into a real `CommandEncoder` via `encode_sequential()`,
/// producing identical Metal dispatch behavior to the original direct path.
///
/// Future phases (4e.2, 4e.3) will add fusion and reorder passes that
/// transform the graph before encoding.
pub struct ComputeGraph {
    // Ordered node list: dispatches interleaved with barrier sentinels,
    // in the exact order they were captured.
    nodes: Vec<CapturedNode>,
}
120
121impl ComputeGraph {
122    /// Create an empty compute graph.
123    pub fn new() -> Self {
124        Self {
125            nodes: Vec::with_capacity(128),
126        }
127    }
128
129    /// Create a compute graph from a pre-built list of captured nodes.
130    pub fn from_nodes(nodes: Vec<CapturedNode>) -> Self {
131        Self { nodes }
132    }
133
134    /// Record a captured node into the graph.
135    pub fn record(&mut self, node: CapturedNode) {
136        self.nodes.push(node);
137    }
138
139    /// Number of nodes (dispatches + barriers) in the graph.
140    pub fn len(&self) -> usize {
141        self.nodes.len()
142    }
143
144    /// Whether the graph contains no nodes.
145    pub fn is_empty(&self) -> bool {
146        self.nodes.is_empty()
147    }
148
149    /// Number of dispatch nodes (excludes barriers).
150    pub fn dispatch_count(&self) -> usize {
151        self.nodes
152            .iter()
153            .filter(|n| matches!(n, CapturedNode::Dispatch { .. }))
154            .count()
155    }
156
157    /// Number of barrier nodes.
158    pub fn barrier_count(&self) -> usize {
159        self.nodes
160            .iter()
161            .filter(|n| matches!(n, CapturedNode::Barrier))
162            .count()
163    }
164
165    /// Borrow the node list.
166    pub fn nodes(&self) -> &[CapturedNode] {
167        &self.nodes
168    }
169
170    /// Count dispatch nodes that have empty read/write range annotations.
171    ///
172    /// Used for diagnostics: if >0, the reorder pass cannot guarantee
173    /// correctness because it relies on complete annotations.
174    pub fn unannotated_dispatch_count(&self) -> usize {
175        self.nodes
176            .iter()
177            .filter(|n| matches!(n, CapturedNode::Dispatch { reads, writes, .. }
178                if reads.is_empty() || writes.is_empty()))
179            .count()
180    }
181
182    /// Take ownership of the node list, consuming the graph.
183    pub fn into_nodes(self) -> Vec<CapturedNode> {
184        self.nodes
185    }
186
187    /// Encode all nodes sequentially into the given encoder.
188    ///
189    /// Barrier sentinel nodes emit a Metal memory barrier.  Dispatch nodes
190    /// are replayed through `CommandEncoder::replay_dispatch()`.
191    ///
192    /// This produces identical GPU behavior to the direct-dispatch path —
193    /// same pipeline bindings, same dispatch dimensions, same barrier
194    /// placement.
195    ///
196    /// Returns the number of barriers emitted.
197    pub fn encode_sequential(&self, encoder: &mut CommandEncoder) -> u32 {
198        let mut barrier_count = 0u32;
199        for node in &self.nodes {
200            match node {
201                CapturedNode::Barrier => {
202                    encoder.memory_barrier();
203                    barrier_count += 1;
204                }
205                CapturedNode::Dispatch {
206                    pipeline,
207                    bindings,
208                    threads_per_grid,
209                    threads_per_threadgroup,
210                    threadgroup_memory,
211                    dispatch_kind,
212                    ..
213                } => {
214                    encoder.replay_dispatch(
215                        pipeline,
216                        bindings,
217                        threadgroup_memory,
218                        *threads_per_grid,
219                        *threads_per_threadgroup,
220                        *dispatch_kind,
221                    );
222                }
223            }
224        }
225        barrier_count
226    }
227
228    /// Encode the graph into a Metal command buffer, computing barriers on the
229    /// fly from each node's read/write buffer ranges.
230    ///
231    /// This is the correct encoding method for reordered graphs where barrier
232    /// sentinels have been stripped.  Mirrors llama.cpp's encode-time barrier
233    /// insertion via `ggml_metal_op_concurrency_check`.
234    ///
235    /// Returns the number of barriers emitted.
236    pub fn encode_with_barriers(&self, encoder: &mut CommandEncoder) -> u32 {
237        let mut tracker = ReorderConflictTracker::new();
238        let mut barrier_count = 0u32;
239
240        for node in &self.nodes {
241            match node {
242                CapturedNode::Dispatch {
243                    pipeline,
244                    bindings,
245                    threads_per_grid,
246                    threads_per_threadgroup,
247                    threadgroup_memory,
248                    dispatch_kind,
249                    reads,
250                    writes,
251                    ..
252                } => {
253                    let has_ranges = !reads.is_empty() || !writes.is_empty();
254                    if has_ranges && tracker.conflicts(reads, writes) {
255                        encoder.memory_barrier();
256                        tracker.reset();
257                        barrier_count += 1;
258                    }
259                    if has_ranges {
260                        tracker.add(reads, writes);
261                    }
262                    encoder.replay_dispatch(
263                        pipeline,
264                        bindings,
265                        threadgroup_memory,
266                        *threads_per_grid,
267                        *threads_per_threadgroup,
268                        *dispatch_kind,
269                    );
270                }
271                CapturedNode::Barrier => {
272                    // Explicit barriers still force a barrier boundary
273                    encoder.memory_barrier();
274                    tracker.reset();
275                    barrier_count += 1;
276                }
277            }
278        }
279        barrier_count
280    }
281
282    /// Encode the graph using two command buffers for CPU/GPU overlap.
283    ///
284    /// The first `n0` dispatches are encoded into `encoder0` and committed
285    /// immediately (GPU starts executing).  The remaining dispatches are encoded
286    /// into `encoder1`.  The caller is responsible for committing `encoder1`.
287    ///
288    /// This matches llama.cpp's dual command buffer pattern from
289    /// `ggml_metal_graph_compute` (ggml-metal-context.m:441-644):
290    /// `n_nodes_0 = MAX(64, 0.1 * n_nodes)` for the first buffer.
291    ///
292    /// Command buffers submitted to the same `MTLCommandQueue` execute in
293    /// submission order, so `encoder0.commit()` followed by `encoder1.commit()`
294    /// guarantees enc0 finishes before enc1 starts.  The win: the GPU starts
295    /// executing enc0 while the CPU is still encoding enc1.
296    ///
297    /// Returns `(barriers_buf0, barriers_buf1)`.
298    pub fn encode_dual_buffer(
299        &self,
300        encoder0: &mut CommandEncoder,
301        encoder1: &mut CommandEncoder,
302    ) -> (u32, u32) {
303        let dispatch_total = self.dispatch_count();
304        let n0 = std::cmp::max(64, dispatch_total / 10);
305
306        // Find the split point: the index of the n0-th dispatch node.
307        let split_idx = find_dispatch_split_index(&self.nodes, n0);
308
309        // Encode first chunk with barrier recomputation, then commit immediately.
310        let barriers0 = encode_chunk_with_barriers(&self.nodes[..split_idx], encoder0);
311        encoder0.commit();
312
313        // Encode second chunk with barrier recomputation.
314        let barriers1 = encode_chunk_with_barriers(&self.nodes[split_idx..], encoder1);
315
316        (barriers0, barriers1)
317    }
318
    /// Run the RMS norm + MUL fusion pass over the graph.
    ///
    /// Scans for the pattern:
    ///   Dispatch(RmsNorm) → Barrier(s) → Dispatch(ElemMul)
    /// where the MUL reads the norm's output buffer, and replaces the
    /// sequence with a single fused `rms_norm_mul_*` dispatch.
    ///
    /// The fused dispatch:
    /// - Reads the norm's input (buffer 0) and weight (buffer 1)
    /// - Reads the MUL's second operand as the scale (buffer 2)
    /// - Writes to the MUL's output (buffer 3)
    /// - Carries the norm's params (buffer 4)
    /// - Uses the norm's threadgroup config and shared memory
    ///
    /// Returns the number of fusions applied.
    ///
    /// # Arguments
    ///
    /// * `registry` - Kernel registry for compiling the fused pipeline.
    /// * `device`   - Metal device for pipeline compilation.
    pub fn fuse(
        &mut self,
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
    ) -> Result<u32> {
        // Rebuild the node list into `result`, copying nodes through unless a
        // fusible Norm → Barrier(s) → Mul window starts at the cursor `i`.
        let mut result: Vec<CapturedNode> = Vec::with_capacity(self.nodes.len());
        let mut fusions = 0u32;
        let mut i = 0;

        while i < self.nodes.len() {
            // Check if current node is an RMS norm dispatch.
            let is_rms_norm = matches!(
                &self.nodes[i],
                CapturedNode::Dispatch { op_kind: CapturedOpKind::RmsNorm, .. }
            );

            if !is_rms_norm {
                result.push(self.nodes[i].clone());
                i += 1;
                continue;
            }

            // Look ahead: skip barriers, then check for ElemMul.
            let mut j = i + 1;
            let mut barrier_count = 0usize;
            while j < self.nodes.len() && matches!(&self.nodes[j], CapturedNode::Barrier) {
                barrier_count += 1;
                j += 1;
            }

            // Must have at least one barrier and the next node must be ElemMul.
            // (Zero barriers would mean the mul was already concurrent with the
            // norm, i.e. not a consumer of its output.)
            if barrier_count == 0 || j >= self.nodes.len() {
                result.push(self.nodes[i].clone());
                i += 1;
                continue;
            }

            let is_elem_mul = matches!(
                &self.nodes[j],
                CapturedNode::Dispatch { op_kind: CapturedOpKind::ElemMul, .. }
            );

            if !is_elem_mul {
                result.push(self.nodes[i].clone());
                i += 1;
                continue;
            }

            // Extract norm and mul dispatch fields.
            let (norm_pipeline, norm_bindings, norm_tpg, norm_tptg, norm_tgmem, norm_dk) =
                match &self.nodes[i] {
                    CapturedNode::Dispatch {
                        pipeline,
                        bindings,
                        threads_per_grid,
                        threads_per_threadgroup,
                        threadgroup_memory,
                        dispatch_kind,
                        ..
                    } => (pipeline, bindings, threads_per_grid, threads_per_threadgroup, threadgroup_memory, dispatch_kind),
                    _ => unreachable!(),
                };

            let (mul_bindings, _mul_tpg, _mul_tptg) = match &self.nodes[j] {
                CapturedNode::Dispatch {
                    bindings,
                    threads_per_grid,
                    threads_per_threadgroup,
                    ..
                } => (bindings, threads_per_grid, threads_per_threadgroup),
                _ => unreachable!(),
            };

            // Verify data dependency: the norm's output buffer (slot 2) must
            // appear as one of the MUL's input buffers (slot 0 or 1).
            //
            // Norm binding layout: (0=input, 1=weight, 2=output, 3=params)
            // MUL binding layout:  (0=a, 1=b, 2=output, 3=params_bytes)
            //
            // NOTE(review): only the MUL is checked as a consumer of the norm's
            // output.  If any LATER node also reads the norm output buffer,
            // fusing removes the standalone write it depends on — confirm the
            // capture invariants guarantee a single consumer of the norm output.
            let norm_output_ptr = Self::buffer_ptr_for_slot(norm_bindings, 2);
            let mul_a_ptr = Self::buffer_ptr_for_slot(mul_bindings, 0);
            let mul_b_ptr = Self::buffer_ptr_for_slot(mul_bindings, 1);

            if norm_output_ptr.is_none() || (norm_output_ptr != mul_a_ptr && norm_output_ptr != mul_b_ptr) {
                // Data dependency not confirmed — don't fuse.
                result.push(self.nodes[i].clone());
                i += 1;
                continue;
            }

            // Determine which MUL input is the scale (the one that is NOT
            // the norm's output).
            let scale_slot = if norm_output_ptr == mul_a_ptr { 1 } else { 0 };

            // Build fused bindings:
            //   0 = norm input
            //   1 = norm weight
            //   2 = scale (from MUL)
            //   3 = MUL output
            //   4 = norm params
            // Gather all required bindings; bail if any are missing.
            let (norm_input, norm_weight, scale, mul_output, norm_params) = match (
                Self::get_binding(norm_bindings, 0),
                Self::get_binding(norm_bindings, 1),
                Self::get_binding(mul_bindings, scale_slot),
                Self::get_binding(mul_bindings, 2),
                Self::get_binding(norm_bindings, 3),
            ) {
                (Some(a), Some(b), Some(c), Some(d), Some(e)) => (a, b, c, d, e),
                _ => {
                    // Missing bindings — don't fuse.
                    result.push(self.nodes[i].clone());
                    i += 1;
                    continue;
                }
            };

            // Select fused pipeline based on the original norm pipeline name.
            // The norm pipeline name is "rms_norm_f32", "rms_norm_f16", or
            // "rms_norm_bf16" — we need the corresponding fused pipeline.
            let fused_name = match Self::fused_pipeline_name(norm_pipeline) {
                Some(name) => name,
                None => {
                    result.push(self.nodes[i].clone());
                    i += 1;
                    continue;
                }
            };

            let fused_pipeline = registry.get_pipeline(fused_name, device)?;

            let fused_bindings = vec![
                (0, norm_input),
                (1, norm_weight),
                (2, scale),
                (3, mul_output),
                (4, norm_params),
            ];

            // Merge read/write ranges from both the norm and mul nodes for the
            // fused dispatch.  The fused op reads everything the norm reads
            // plus the mul's scale input, and writes to the mul's output.
            // (The norm's own write — its intermediate output — is dropped,
            // since the fused kernel no longer materializes it.)
            let (fused_reads, fused_writes) = match (&self.nodes[i], &self.nodes[j]) {
                (
                    CapturedNode::Dispatch { reads: nr, writes: _nw, .. },
                    CapturedNode::Dispatch { reads: mr, writes: mw, .. },
                ) => {
                    let mut reads = nr.clone();
                    reads.extend_from_slice(mr);
                    (reads, mw.clone())
                }
                _ => (Vec::new(), Vec::new()),
            };

            result.push(CapturedNode::Dispatch {
                pipeline: fused_pipeline.to_owned(),
                bindings: fused_bindings,
                threads_per_grid: *norm_tpg,
                threads_per_threadgroup: *norm_tptg,
                threadgroup_memory: norm_tgmem.clone(),
                dispatch_kind: *norm_dk,
                op_kind: CapturedOpKind::Other, // Fused ops are not further fuseable
                reads: fused_reads,
                writes: fused_writes,
            });

            fusions += 1;
            // Skip past the norm, barrier(s), and mul nodes.  The intervening
            // barrier sentinels are intentionally dropped: the fused dispatch
            // subsumes the norm→mul ordering they enforced.
            i = j + 1;
        }

        self.nodes = result;
        Ok(fusions)
    }
512
513    /// Run the reorder pass over the graph to improve GPU concurrency.
514    ///
515    /// Port of llama.cpp's `ggml_metal_graph_optimize_reorder` — a greedy
516    /// 64-node lookahead that pulls independent dispatches forward to fill
517    /// larger concurrent groups between barriers.
518    ///
519    /// **Prerequisites:** Call `fuse()` first if desired.  The reorder pass
520    /// operates on the post-fusion graph.  Barrier sentinel nodes are stripped
521    /// before reordering (they will be recomputed at encode time by the
522    /// `ConflictTracker` in `encode_sequential`).
523    ///
524    /// **Algorithm (matching llama.cpp exactly):**
525    /// 1. Strip all `CapturedNode::Barrier` nodes.
526    /// 2. For each unprocessed node `i0`:
527    ///    - If it conflicts with the current concurrent group (`mrs0`):
528    ///      * Initialize `mrs1` from `i0`'s ranges (skipped-over set)
529    ///      * Lookahead up to 64 nodes for candidates that:
530    ///        (a) Are reorderable (`CapturedOpKind::is_reorderable()`)
531    ///        (b) Don't conflict with `mrs0` (current group)
532    ///        (c) Don't conflict with `mrs1` (skipped-over nodes)
533    ///      * Pull qualifying candidates into the current group
534    ///      * Non-reorderable ops break the lookahead
535    ///    - Reset `mrs0` (new concurrent group)
536    ///    - Add `i0` to the new group
537    ///
538    /// Returns the number of nodes that were moved to earlier positions.
539    pub fn reorder(&mut self) -> u32 {
540        // Step 1: Strip barrier nodes.  After fusion + reorder, barriers will
541        // be recomputed by the ConflictTracker at encode time.
542        self.nodes.retain(|n| !matches!(n, CapturedNode::Barrier));
543
544        let n = self.nodes.len();
545        if n == 0 {
546            return 0;
547        }
548
549        let mut result: Vec<usize> = Vec::with_capacity(n);
550        let mut used = vec![false; n];
551
552        // mrs0: memory ranges for the current concurrent group
553        let mut mrs0 = ReorderConflictTracker::new();
554        // mrs1: memory ranges for skipped-over (unprocessed) nodes
555        let mut mrs1 = ReorderConflictTracker::new();
556
557        const N_FORWARD: usize = 64;
558
559        for i0 in 0..n {
560            if used[i0] {
561                continue;
562            }
563
564            let node0 = &self.nodes[i0];
565
566            // Extract reads/writes for conflict check.
567            let (reads0, writes0, op_kind0) = match node0 {
568                CapturedNode::Dispatch { reads, writes, op_kind, .. } => {
569                    (reads.as_slice(), writes.as_slice(), *op_kind)
570                }
571                CapturedNode::Barrier => continue, // stripped, but be safe
572            };
573
574            // Check if node0 conflicts with the current concurrent group.
575            // Empty nodes (no ranges) never conflict — like llama.cpp's is_empty.
576            let has_ranges = !reads0.is_empty() || !writes0.is_empty();
577            if has_ranges && mrs0.conflicts(reads0, writes0) {
578                // Before starting a new group, look forward for nodes that
579                // can be pulled into the CURRENT group.
580                mrs1.reset();
581                mrs1.add(reads0, writes0);
582
583                let end = (i0 + N_FORWARD).min(n);
584                for i1 in (i0 + 1)..end {
585                    if used[i1] {
586                        continue;
587                    }
588
589                    let node1 = &self.nodes[i1];
590                    let (reads1, writes1, op_kind1) = match node1 {
591                        CapturedNode::Dispatch { reads, writes, op_kind, .. } => {
592                            (reads.as_slice(), writes.as_slice(), *op_kind)
593                        }
594                        CapturedNode::Barrier => continue,
595                    };
596
597                    // Non-reorderable ops break the lookahead.
598                    if !op_kind1.is_reorderable() {
599                        break;
600                    }
601
602                    let is_empty1 = reads1.is_empty() && writes1.is_empty();
603
604                    // A node can be reordered into the current group if:
605                    // 1. It's empty (no ranges) OR doesn't conflict with mrs0
606                    // 2. It doesn't conflict with mrs1 (skipped-over nodes)
607                    if (is_empty1 || !mrs0.conflicts(reads1, writes1))
608                        && !mrs1.conflicts(reads1, writes1)
609                    {
610                        // Pull into current concurrent group.
611                        mrs0.add(reads1, writes1);
612                        result.push(i1);
613                        used[i1] = true;
614                    } else {
615                        // Not eligible — expand the skipped-over set.
616                        mrs1.add(reads1, writes1);
617                    }
618                }
619
620                // Finalize the current concurrent group.
621                mrs0.reset();
622            }
623
624            // Expand the concurrent group with node0.
625            // (Barriers were stripped, so this is always a Dispatch.)
626            let _ = op_kind0; // suppress unused warning
627            mrs0.add(reads0, writes0);
628            result.push(i0);
629        }
630
631        // Apply the permutation to produce the reordered node list.
632        let mut reordered_count = 0u32;
633        for (pos, &orig_idx) in result.iter().enumerate() {
634            if orig_idx != pos {
635                reordered_count += 1;
636            }
637        }
638
639        // Build the reordered nodes vec.
640        let old_nodes = std::mem::take(&mut self.nodes);
641        self.nodes = result.iter().map(|&idx| old_nodes[idx].clone()).collect();
642
643        // Debug dump if requested.
644        if std::env::var("HF2Q_REORDER_DUMP").is_ok() {
645            eprintln!(
646                "  [REORDER] nodes={} reordered={} ({:.1}%)",
647                n,
648                reordered_count,
649                100.0 * reordered_count as f64 / n as f64,
650            );
651        }
652
653        reordered_count
654    }
655
656    /// Get the Metal buffer pointer for a binding at the given slot index.
657    ///
658    /// Returns `Some(ptr)` if the slot has a `RecordedBinding::Buffer`,
659    /// `None` otherwise.
660    fn buffer_ptr_for_slot(bindings: &[(u64, RecordedBinding)], slot: u64) -> Option<*const std::ffi::c_void> {
661        for (idx, binding) in bindings {
662            if *idx == slot {
663                if let RecordedBinding::Buffer { metal_buffer, offset: _ } = binding {
664                    // Use the Metal buffer's GPU address as the identity key.
665                    // On Apple Silicon unified memory, this uniquely identifies
666                    // the allocation.
667                    let ptr: *const std::ffi::c_void = metal_buffer.as_ptr() as *const _;
668                    return Some(ptr);
669                }
670            }
671        }
672        None
673    }
674
675    /// Clone the binding at the given slot index.
676    fn get_binding(bindings: &[(u64, RecordedBinding)], slot: u64) -> Option<RecordedBinding> {
677        for (idx, binding) in bindings {
678            if *idx == slot {
679                return Some(binding.clone());
680            }
681        }
682        None
683    }
684
    /// Map a norm pipeline to its fused norm+mul pipeline name.
    ///
    /// The pipeline's `label()` is set by Metal to the function name, so we
    /// can match on it.  Returns `None` if the pipeline is not a known norm.
    ///
    /// NOTE(review): this assumes the pipeline label is always populated with
    /// the kernel function name at creation time — confirm against the
    /// registry's pipeline construction.  An empty or unexpected label simply
    /// disables fusion (returns `None`), which is safe but easy to miss.
    fn fused_pipeline_name(pipeline: &metal::ComputePipelineState) -> Option<&'static str> {
        match pipeline.label() {
            "rms_norm_f32" => Some("rms_norm_mul_f32"),
            "rms_norm_f16" => Some("rms_norm_mul_f16"),
            "rms_norm_bf16" => Some("rms_norm_mul_bf16"),
            _ => None,
        }
    }
697}
698
699impl Default for ComputeGraph {
700    fn default() -> Self {
701        Self::new()
702    }
703}
704
705// ---------------------------------------------------------------------------
706// Dual-buffer encoding helpers (Phase 4e.4)
707// ---------------------------------------------------------------------------
708
709/// Find the node index where the n0-th dispatch starts.
710///
711/// Counts `CapturedNode::Dispatch` nodes until `n0` are reached, then returns
712/// the index of the n0-th dispatch (i.e., the first node of the second chunk).
713/// If `n0 >= dispatch_count`, returns `nodes.len()` (everything in chunk 0).
714fn find_dispatch_split_index(nodes: &[CapturedNode], n0: usize) -> usize {
715    let mut dispatches_seen = 0usize;
716    for (i, node) in nodes.iter().enumerate() {
717        if matches!(node, CapturedNode::Dispatch { .. }) {
718            dispatches_seen += 1;
719            if dispatches_seen == n0 {
720                return i + 1; // split AFTER the n0-th dispatch
721            }
722        }
723    }
724    nodes.len()
725}
726
727/// Encode a slice of captured nodes into a command encoder, recomputing
728/// barriers on the fly from each node's read/write buffer ranges.
729///
730/// This is the chunked counterpart of `ComputeGraph::encode_with_barriers()`.
731/// Factored out so both halves of a dual-buffer encode can use it.
732///
733/// Returns the number of barriers emitted.
734fn encode_chunk_with_barriers(nodes: &[CapturedNode], encoder: &mut CommandEncoder) -> u32 {
735    let mut tracker = ReorderConflictTracker::new();
736    let mut barrier_count = 0u32;
737
738    for node in nodes {
739        match node {
740            CapturedNode::Dispatch {
741                pipeline,
742                bindings,
743                threads_per_grid,
744                threads_per_threadgroup,
745                threadgroup_memory,
746                dispatch_kind,
747                reads,
748                writes,
749                ..
750            } => {
751                let has_ranges = !reads.is_empty() || !writes.is_empty();
752                if has_ranges && tracker.conflicts(reads, writes) {
753                    encoder.memory_barrier();
754                    tracker.reset();
755                    barrier_count += 1;
756                }
757                if has_ranges {
758                    tracker.add(reads, writes);
759                }
760                encoder.replay_dispatch(
761                    pipeline,
762                    bindings,
763                    threadgroup_memory,
764                    *threads_per_grid,
765                    *threads_per_threadgroup,
766                    *dispatch_kind,
767                );
768            }
769            CapturedNode::Barrier => {
770                encoder.memory_barrier();
771                tracker.reset();
772                barrier_count += 1;
773            }
774        }
775    }
776    barrier_count
777}
778
779// ---------------------------------------------------------------------------
780// ReorderConflictTracker — range-based conflict detection for the reorder pass
781// ---------------------------------------------------------------------------
782
783/// Memory range conflict tracker for the reorder pass (Phase 4e.3).
784///
785/// Works with `MemRange` tuples `(start, end)` stored on `CapturedNode::Dispatch`,
786/// rather than requiring live `&MlxBuffer` references.  This is the reorder-time
787/// equivalent of the runtime `ConflictTracker`.
788///
789/// Conflict rules match llama.cpp's `ggml_mem_ranges_check`:
790/// - Two read ranges: OK (read-read is concurrent-safe)
791/// - A new read overlapping an existing write: CONFLICT (RAW)
792/// - A new write overlapping any existing range: CONFLICT (WAR/WAW)
793struct ReorderConflictTracker {
794    /// (start, end, is_write) for all ranges in the tracked set.
795    ranges: Vec<(usize, usize, bool)>,
796}
797
798impl ReorderConflictTracker {
799    fn new() -> Self {
800        Self {
801            ranges: Vec::with_capacity(64),
802        }
803    }
804
805    fn reset(&mut self) {
806        self.ranges.clear();
807    }
808
809    /// Check if a dispatch with the given read/write ranges conflicts with
810    /// any range in this tracker.
811    fn conflicts(&self, reads: &[MemRange], writes: &[MemRange]) -> bool {
812        // New reads vs existing writes (RAW)
813        for &(r_start, r_end) in reads {
814            for &(s, e, is_write) in &self.ranges {
815                if is_write && r_start < e && r_end > s {
816                    return true;
817                }
818            }
819        }
820        // New writes vs all existing ranges (WAR/WAW)
821        for &(w_start, w_end) in writes {
822            for &(s, e, _) in &self.ranges {
823                if w_start < e && w_end > s {
824                    return true;
825                }
826            }
827        }
828        false
829    }
830
831    /// Add read and write ranges to the tracked set.
832    fn add(&mut self, reads: &[MemRange], writes: &[MemRange]) {
833        for &(start, end) in reads {
834            self.ranges.push((start, end, false));
835        }
836        for &(start, end) in writes {
837            self.ranges.push((start, end, true));
838        }
839    }
840}
841
842/// Batched Metal dispatch — encodes multiple ops into a single `CommandEncoder`.
843///
844/// Create one per model (or per forward-pass loop).  Call [`begin`](Self::begin)
845/// at the start of each forward pass to get a [`GraphSession`] that holds the
846/// shared encoder.
847pub struct GraphExecutor {
848    device: MlxDevice,
849}
850
851impl GraphExecutor {
852    /// Create a new graph executor backed by the given device.
853    pub fn new(device: MlxDevice) -> Self {
854        Self { device }
855    }
856
857    /// Begin a new forward pass (direct-dispatch mode).
858    ///
859    /// Returns a [`GraphSession`] that holds a fresh `CommandEncoder`.  All ops
860    /// encoded through the session share this single encoder.  Call
861    /// [`GraphSession::finish`] to commit and wait.
862    pub fn begin(&self) -> Result<GraphSession<'_>> {
863        let encoder = self.device.command_encoder()?;
864        Ok(GraphSession {
865            encoder,
866            device: &self.device,
867            barrier_count: 0,
868            tracker: ConflictTracker::new(),
869            dispatch_in_group: 0,
870            total_dispatches: 0,
871            group_sizes: [0; 8],
872            recording: false,
873        })
874    }
875
876    /// Begin a new forward pass in capture (record) mode.
877    ///
878    /// All op calls are recorded into a `ComputeGraph` instead of being
879    /// dispatched to Metal.  When [`GraphSession::finish`] is called, the
880    /// recorded graph is replayed into a fresh encoder via
881    /// `ComputeGraph::encode_sequential()`.
882    ///
883    /// The API is identical to `begin()` — callers do not need to change
884    /// any op call code.  The only behavioral difference: GPU work happens
885    /// at `finish()` time rather than at each op call.
886    pub fn begin_recorded(&self) -> Result<GraphSession<'_>> {
887        let mut encoder = self.device.command_encoder()?;
888        encoder.start_capture();
889        Ok(GraphSession {
890            encoder,
891            device: &self.device,
892            barrier_count: 0,
893            tracker: ConflictTracker::new(),
894            dispatch_in_group: 0,
895            total_dispatches: 0,
896            group_sizes: [0; 8],
897            recording: true,
898        })
899    }
900
901    /// Borrow the underlying device.
902    pub fn device(&self) -> &MlxDevice {
903        &self.device
904    }
905}
906
/// Tracks buffer address ranges for automatic barrier elision.
///
/// Mirrors llama.cpp's `ggml_mem_ranges` — accumulates the read and write
/// ranges of all dispatches in the current concurrent group. When a new
/// dispatch's reads overlap with an existing write (RAW), or its writes
/// overlap with an existing read or write (WAR/WAW), a barrier is needed.
/// Otherwise the dispatch can run concurrently and the barrier is elided.
///
/// Uses CPU-visible `contents_ptr()` addresses, which on Apple Silicon
/// unified memory equal the GPU addresses.
pub struct ConflictTracker {
    /// (start, end, is_write) tuples for the current concurrent group.
    ranges: Vec<(usize, usize, bool)>,
}
929
930impl ConflictTracker {
931    fn new() -> Self {
932        Self {
933            ranges: Vec::with_capacity(32),
934        }
935    }
936
937    /// Reset the tracker — called after emitting a barrier.
938    fn reset(&mut self) {
939        self.ranges.clear();
940    }
941
942    /// Check if a new dispatch with the given reads and writes conflicts
943    /// with the current concurrent group.
944    ///
945    /// Conflict rules (same as llama.cpp `ggml_mem_ranges_check`):
946    /// - Two SRC (read) ranges in the same buffer: OK (read-read)
947    /// - A new SRC overlapping an existing DST: CONFLICT (RAW)
948    /// - A new DST overlapping an existing SRC or DST: CONFLICT (WAR/WAW)
949    /// Check for conflicts and return the reason if one is found.
950    /// Returns (conflict_type, new_buf_ptr, existing_buf_ptr) or None.
951    fn conflicts_reason(&self, reads: &[&MlxBuffer], writes: &[&MlxBuffer])
952        -> Option<(&'static str, usize, usize)>
953    {
954        // Check new reads against existing writes (RAW)
955        for r in reads {
956            let r_start = r.contents_ptr() as usize;
957            let r_end = r_start + r.byte_len();
958            for &(s, e, is_write) in &self.ranges {
959                if is_write && r_start < e && r_end > s {
960                    return Some(("RAW", r_start, s));
961                }
962            }
963        }
964        // Check new writes against existing reads and writes (WAR/WAW)
965        for w in writes {
966            let w_start = w.contents_ptr() as usize;
967            let w_end = w_start + w.byte_len();
968            for &(s, e, is_write) in &self.ranges {
969                if w_start < e && w_end > s {
970                    let kind = if is_write { "WAW" } else { "WAR" };
971                    return Some((kind, w_start, s));
972                }
973            }
974        }
975        None
976    }
977
978    /// Add read and write ranges to the current concurrent group.
979    fn add(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
980        for r in reads {
981            let start = r.contents_ptr() as usize;
982            let end = start + r.byte_len();
983            self.ranges.push((start, end, false));
984        }
985        for w in writes {
986            let start = w.contents_ptr() as usize;
987            let end = start + w.byte_len();
988            self.ranges.push((start, end, true));
989        }
990    }
991}
992
/// A single forward pass execution context.
///
/// All ops are encoded into one `CommandEncoder`.  Call [`finish`](Self::finish)
/// to commit the command buffer and wait for GPU completion — this is the ONLY
/// sync point per forward pass.
///
/// If an op returns an error, the session can be dropped without committing.
/// The underlying command buffer is abandoned (never committed to the GPU).
pub struct GraphSession<'a> {
    /// Shared encoder every op dispatches into (one command buffer per session).
    encoder: CommandEncoder,
    /// Device the session encodes against.
    device: &'a MlxDevice,
    /// Number of memory barriers emitted so far in this session.
    barrier_count: u32,
    /// Range tracker backing `barrier_between`'s conflict detection.
    tracker: ConflictTracker,
    /// Dispatches since the last barrier (size of the current concurrent group).
    dispatch_in_group: u32,
    /// Total dispatches counted by `barrier_between`.
    total_dispatches: u32,
    /// Histogram: group_sizes[i] = number of concurrent groups with (i+1) dispatches
    group_sizes: [u32; 8],
    /// Whether this session was created in capture/record mode.
    recording: bool,
}
1005
1006impl<'a> GraphSession<'a> {
1007    /// Encode an RMS normalization into this session's encoder.
1008    ///
1009    /// Delegates to [`ops::rms_norm::dispatch_rms_norm`].
1010    pub fn rms_norm(
1011        &mut self,
1012        registry: &mut KernelRegistry,
1013        device: &metal::DeviceRef,
1014        input: &MlxBuffer,
1015        weight: &MlxBuffer,
1016        output: &MlxBuffer,
1017        params_buf: &MlxBuffer,
1018        rows: u32,
1019        dim: u32,
1020    ) -> Result<()> {
1021        ops::rms_norm::dispatch_rms_norm(
1022            &mut self.encoder,
1023            registry,
1024            device,
1025            input,
1026            weight,
1027            output,
1028            params_buf,
1029            rows,
1030            dim,
1031        )
1032    }
1033
1034    /// Encode a quantized matrix multiplication into this session's encoder.
1035    ///
1036    /// Delegates to [`ops::quantized_matmul::quantized_matmul`].
1037    /// Returns the freshly allocated output buffer.
1038    pub fn quantized_matmul(
1039        &mut self,
1040        registry: &mut KernelRegistry,
1041        device: &MlxDevice,
1042        input: &MlxBuffer,
1043        weight: &MlxBuffer,
1044        scales: &MlxBuffer,
1045        biases: &MlxBuffer,
1046        params: &ops::quantized_matmul::QuantizedMatmulParams,
1047    ) -> Result<MlxBuffer> {
1048        ops::quantized_matmul::quantized_matmul(
1049            &mut self.encoder,
1050            registry,
1051            device,
1052            input,
1053            weight,
1054            scales,
1055            biases,
1056            params,
1057        )
1058    }
1059
1060    /// Encode a SIMD-optimized quantized matmul into this session's encoder.
1061    ///
1062    /// Delegates to [`ops::quantized_matmul::quantized_matmul_simd`].
1063    /// Returns the freshly allocated output buffer.
1064    pub fn quantized_matmul_simd(
1065        &mut self,
1066        registry: &mut KernelRegistry,
1067        device: &MlxDevice,
1068        input: &MlxBuffer,
1069        weight: &MlxBuffer,
1070        scales: &MlxBuffer,
1071        biases: &MlxBuffer,
1072        params: &ops::quantized_matmul::QuantizedMatmulParams,
1073    ) -> Result<MlxBuffer> {
1074        ops::quantized_matmul::quantized_matmul_simd(
1075            &mut self.encoder,
1076            registry,
1077            device,
1078            input,
1079            weight,
1080            scales,
1081            biases,
1082            params,
1083        )
1084    }
1085
1086    /// Encode a GGML block-format quantized mat-vec into this session's encoder.
1087    ///
1088    /// Delegates to [`ops::quantized_matmul_ggml::quantized_matmul_ggml`].
1089    pub fn quantized_matmul_ggml(
1090        &mut self,
1091        registry: &mut KernelRegistry,
1092        device: &MlxDevice,
1093        input: &MlxBuffer,
1094        weight: &MlxBuffer,
1095        output: &mut MlxBuffer,
1096        params: &ops::quantized_matmul_ggml::GgmlQuantizedMatmulParams,
1097    ) -> Result<()> {
1098        ops::quantized_matmul_ggml::quantized_matmul_ggml(
1099            &mut self.encoder,
1100            registry,
1101            device,
1102            input,
1103            weight,
1104            output,
1105            params,
1106        )
1107    }
1108
1109    /// Encode an expert-routed GGML block-format quantized mat-vec into this session's encoder.
1110    ///
1111    /// Delegates to [`ops::quantized_matmul_id_ggml::quantized_matmul_id_ggml`].
1112    #[allow(clippy::too_many_arguments)]
1113    pub fn quantized_matmul_id_ggml(
1114        &mut self,
1115        registry: &mut KernelRegistry,
1116        device: &MlxDevice,
1117        input: &MlxBuffer,
1118        weight: &MlxBuffer,
1119        ids: &MlxBuffer,
1120        output: &mut MlxBuffer,
1121        params: &ops::quantized_matmul_id_ggml::GgmlQuantizedMatmulIdParams,
1122    ) -> Result<()> {
1123        ops::quantized_matmul_id_ggml::quantized_matmul_id_ggml(
1124            &mut self.encoder,
1125            registry,
1126            device,
1127            input,
1128            weight,
1129            ids,
1130            output,
1131            params,
1132        )
1133    }
1134
1135    /// Pooled-scratch variant of [`Self::quantized_matmul_id_ggml`] — the
1136    /// `IdMmScratch` is caller-owned so batched prefill amortises the
1137    /// per-call allocations that the auto entry point incurs (ADR-011
1138    /// Phase 3 Wave P3b).
1139    #[allow(clippy::too_many_arguments)]
1140    pub fn quantized_matmul_id_ggml_pooled(
1141        &mut self,
1142        registry: &mut KernelRegistry,
1143        device: &MlxDevice,
1144        input: &MlxBuffer,
1145        weight: &MlxBuffer,
1146        ids: &MlxBuffer,
1147        output: &mut MlxBuffer,
1148        scratch: &mut ops::quantized_matmul_id_ggml::IdMmScratch,
1149        params: &ops::quantized_matmul_id_ggml::GgmlQuantizedMatmulIdParams,
1150    ) -> Result<()> {
1151        ops::quantized_matmul_id_ggml::quantized_matmul_id_ggml_pooled(
1152            &mut self.encoder,
1153            registry,
1154            device,
1155            input,
1156            weight,
1157            ids,
1158            output,
1159            scratch,
1160            params,
1161        )
1162    }
1163
1164    /// Encode scaled dot-product attention into this session's encoder.
1165    ///
1166    /// Delegates to [`ops::sdpa::sdpa`].
1167    pub fn sdpa(
1168        &mut self,
1169        registry: &mut KernelRegistry,
1170        device: &MlxDevice,
1171        q: &MlxBuffer,
1172        k: &MlxBuffer,
1173        v: &MlxBuffer,
1174        output: &MlxBuffer,
1175        params: &ops::sdpa::SdpaParams,
1176        batch_size: u32,
1177    ) -> Result<()> {
1178        ops::sdpa::sdpa(
1179            &mut self.encoder,
1180            registry,
1181            device,
1182            q,
1183            k,
1184            v,
1185            output,
1186            params,
1187            batch_size,
1188        )
1189    }
1190
1191    /// Encode flash attention vector (SIMD-vectorized decode-path SDPA).
1192    ///
1193    /// Delegates to [`ops::flash_attn_vec::flash_attn_vec`].
1194    pub fn flash_attn_vec(
1195        &mut self,
1196        registry: &mut KernelRegistry,
1197        device: &MlxDevice,
1198        q: &MlxBuffer,
1199        k: &MlxBuffer,
1200        v: &MlxBuffer,
1201        output: &MlxBuffer,
1202        tmp: &MlxBuffer,
1203        params: &ops::flash_attn_vec::FlashAttnVecParams,
1204    ) -> Result<()> {
1205        ops::flash_attn_vec::flash_attn_vec(
1206            &mut self.encoder,
1207            registry,
1208            device,
1209            q,
1210            k,
1211            v,
1212            output,
1213            tmp,
1214            params,
1215        )
1216    }
1217
1218    /// Encode an elementwise add into this session's encoder.
1219    ///
1220    /// Delegates to [`ops::elementwise::elementwise_add`].
1221    pub fn elementwise_add(
1222        &mut self,
1223        registry: &mut KernelRegistry,
1224        device: &metal::DeviceRef,
1225        a: &MlxBuffer,
1226        b: &MlxBuffer,
1227        output: &MlxBuffer,
1228        n_elements: usize,
1229        dtype: DType,
1230    ) -> Result<()> {
1231        ops::elementwise::elementwise_add(
1232            &mut self.encoder,
1233            registry,
1234            device,
1235            a,
1236            b,
1237            output,
1238            n_elements,
1239            dtype,
1240        )
1241    }
1242
1243    /// Encode an elementwise multiply into this session's encoder.
1244    ///
1245    /// Delegates to [`ops::elementwise::elementwise_mul`].
1246    pub fn elementwise_mul(
1247        &mut self,
1248        registry: &mut KernelRegistry,
1249        device: &metal::DeviceRef,
1250        a: &MlxBuffer,
1251        b: &MlxBuffer,
1252        output: &MlxBuffer,
1253        n_elements: usize,
1254        dtype: DType,
1255    ) -> Result<()> {
1256        ops::elementwise::elementwise_mul(
1257            &mut self.encoder,
1258            registry,
1259            device,
1260            a,
1261            b,
1262            output,
1263            n_elements,
1264            dtype,
1265        )
1266    }
1267
1268    /// Encode a RoPE transform into this session's encoder.
1269    ///
1270    /// Delegates to [`ops::rope::dispatch_rope`].
1271    pub fn rope(
1272        &mut self,
1273        registry: &mut KernelRegistry,
1274        device: &metal::DeviceRef,
1275        input: &MlxBuffer,
1276        output: &MlxBuffer,
1277        params_buf: &MlxBuffer,
1278        positions_buf: &MlxBuffer,
1279        seq_len: u32,
1280        head_dim: u32,
1281    ) -> Result<()> {
1282        ops::rope::dispatch_rope(
1283            &mut self.encoder,
1284            registry,
1285            device,
1286            input,
1287            output,
1288            params_buf,
1289            positions_buf,
1290            seq_len,
1291            head_dim,
1292        )
1293    }
1294
1295    /// Encode a GELU activation into this session's encoder.
1296    ///
1297    /// Delegates to [`ops::gelu::dispatch_gelu`].
1298    pub fn gelu(
1299        &mut self,
1300        registry: &mut KernelRegistry,
1301        device: &metal::DeviceRef,
1302        input: &MlxBuffer,
1303        output: &MlxBuffer,
1304    ) -> Result<()> {
1305        ops::gelu::dispatch_gelu(
1306            &mut self.encoder,
1307            registry,
1308            device,
1309            input,
1310            output,
1311        )
1312    }
1313
1314    /// Encode a softmax into this session's encoder.
1315    ///
1316    /// Delegates to [`ops::softmax::dispatch_softmax`].
1317    pub fn softmax(
1318        &mut self,
1319        registry: &mut KernelRegistry,
1320        device: &metal::DeviceRef,
1321        input: &MlxBuffer,
1322        output: &MlxBuffer,
1323        params_buf: &MlxBuffer,
1324        rows: u32,
1325        cols: u32,
1326    ) -> Result<()> {
1327        ops::softmax::dispatch_softmax(
1328            &mut self.encoder,
1329            registry,
1330            device,
1331            input,
1332            output,
1333            params_buf,
1334            rows,
1335            cols,
1336        )
1337    }
1338
1339    /// Encode a softcap into this session's encoder.
1340    ///
1341    /// Delegates to [`ops::softcap::dispatch_softcap`].
1342    pub fn softcap(
1343        &mut self,
1344        registry: &mut KernelRegistry,
1345        device: &metal::DeviceRef,
1346        input: &MlxBuffer,
1347        output: &MlxBuffer,
1348        params_buf: &MlxBuffer,
1349        cap: f32,
1350    ) -> Result<()> {
1351        ops::softcap::dispatch_softcap(
1352            &mut self.encoder,
1353            registry,
1354            device,
1355            input,
1356            output,
1357            params_buf,
1358            cap,
1359        )
1360    }
1361
1362    /// Encode an RMS norm without learned scale (f32) into this session's encoder.
1363    ///
1364    /// Delegates to [`ops::rms_norm::dispatch_rms_norm_no_scale_f32`].
1365    pub fn rms_norm_no_scale_f32(
1366        &mut self,
1367        registry: &mut KernelRegistry,
1368        device: &metal::DeviceRef,
1369        input: &MlxBuffer,
1370        output: &MlxBuffer,
1371        params_buf: &MlxBuffer,
1372        rows: u32,
1373        dim: u32,
1374    ) -> Result<()> {
1375        ops::rms_norm::dispatch_rms_norm_no_scale_f32(
1376            &mut self.encoder,
1377            registry,
1378            device,
1379            input,
1380            output,
1381            params_buf,
1382            rows,
1383            dim,
1384        )
1385    }
1386
1387    /// Encode a NeoX RoPE (f32) with optional freq_factors into this session's encoder.
1388    ///
1389    /// Delegates to [`ops::rope::dispatch_rope_neox_f32`].
1390    #[allow(clippy::too_many_arguments)]
1391    pub fn rope_neox_f32(
1392        &mut self,
1393        registry: &mut KernelRegistry,
1394        device: &metal::DeviceRef,
1395        input: &MlxBuffer,
1396        output: &MlxBuffer,
1397        params_buf: &MlxBuffer,
1398        positions_buf: &MlxBuffer,
1399        freq_factors: Option<&MlxBuffer>,
1400        seq_len: u32,
1401        n_heads: u32,
1402        head_dim: u32,
1403        rope_dim: u32,
1404    ) -> Result<()> {
1405        ops::rope::dispatch_rope_neox_f32(
1406            &mut self.encoder,
1407            registry,
1408            device,
1409            input,
1410            output,
1411            params_buf,
1412            positions_buf,
1413            freq_factors,
1414            seq_len,
1415            n_heads,
1416            head_dim,
1417            rope_dim,
1418        )
1419    }
1420
1421    /// Insert a GPU memory barrier (MTLBarrierScopeBuffers).
1422    ///
1423    /// Unconditional barrier — always emits. Use `barrier_between` for
1424    /// automatic conflict detection that can elide unnecessary barriers.
1425    #[inline]
1426    pub fn barrier(&mut self) {
1427        // Record the outgoing group size
1428        if self.dispatch_in_group > 0 {
1429            let idx = (self.dispatch_in_group as usize).min(self.group_sizes.len()) - 1;
1430            self.group_sizes[idx] += 1;
1431        }
1432        self.encoder.memory_barrier();
1433        self.tracker.reset();
1434        self.barrier_count += 1;
1435        self.dispatch_in_group = 0;
1436    }
1437
    /// Smart barrier with conflict detection.
    ///
    /// Checks if the next dispatch (with the given read and write buffers)
    /// actually conflicts with any dispatch in the current concurrent group.
    /// If yes, emits a Metal barrier and resets the tracker. If no, the
    /// barrier is elided and the dispatch can run concurrently.
    ///
    /// Call this BEFORE encoding the dispatch it guards: in capture mode the
    /// ranges are stashed on the encoder so the NEXT captured node picks them
    /// up, and the tracker is updated as if that dispatch had been encoded.
    ///
    /// This mirrors llama.cpp's `ggml_metal_op_concurrency_check` +
    /// `ggml_metal_op_concurrency_reset` pattern.
    #[inline]
    pub fn barrier_between(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
        // In capture mode, stash the read/write ranges so the next captured
        // dispatch node carries them for the reorder pass (Phase 4e.3).
        if self.recording {
            let read_ranges: Vec<MemRange> = reads
                .iter()
                .map(|b| {
                    let start = b.contents_ptr() as usize;
                    (start, start + b.byte_len())
                })
                .collect();
            let write_ranges: Vec<MemRange> = writes
                .iter()
                .map(|b| {
                    let start = b.contents_ptr() as usize;
                    (start, start + b.byte_len())
                })
                .collect();
            self.encoder.set_pending_buffer_ranges(read_ranges, write_ranges);
        }

        // RAW / WAR / WAW check against the current concurrent group.
        let reason = self.tracker.conflicts_reason(reads, writes);
        if let Some((_kind, _new_ptr, _existing_ptr)) = reason {
            // Record the outgoing group size before resetting
            if self.dispatch_in_group > 0 {
                let idx = (self.dispatch_in_group as usize).min(self.group_sizes.len()) - 1;
                self.group_sizes[idx] += 1;
            }
            self.encoder.memory_barrier();
            self.tracker.reset();
            self.barrier_count += 1;
            self.dispatch_in_group = 0;
        }
        // The guarded dispatch joins the (possibly fresh) concurrent group.
        self.dispatch_in_group += 1;
        self.total_dispatches += 1;
        self.tracker.add(reads, writes);
    }
1485
1486    /// Print group size histogram to stderr (for HF2Q_MLX_TIMING debug).
1487    pub fn dump_group_stats(&self) {
1488        // Record the final (unterminated) group
1489        let mut gs = self.group_sizes;
1490        if self.dispatch_in_group > 0 {
1491            let idx = (self.dispatch_in_group as usize).min(gs.len()) - 1;
1492            gs[idx] += 1;
1493        }
1494        let total_groups: u32 = gs.iter().sum();
1495        eprintln!("  [GROUP_STATS] dispatches={} barriers={} groups={} ratio={:.2}",
1496            self.total_dispatches, self.barrier_count, total_groups,
1497            if total_groups > 0 { self.total_dispatches as f64 / total_groups as f64 } else { 0.0 });
1498        for (i, &count) in gs.iter().enumerate() {
1499            if count > 0 {
1500                eprintln!("    size {}: {} groups", i + 1, count);
1501            }
1502        }
1503    }
1504
1505    /// Register a dispatch's buffer ranges without checking for conflicts.
1506    ///
1507    /// Use after dispatching an op that doesn't need a barrier check (e.g.,
1508    /// the first dispatch in a session, or dispatches known to be concurrent).
1509    ///
1510    /// In recording mode, also retroactively annotates the most recently
1511    /// captured dispatch node with these ranges if it was missing them.
1512    /// That keeps the reorder pass able to reason about dispatches that
1513    /// were preceded by `track_dispatch` rather than `barrier_between`.
1514    #[inline]
1515    pub fn track_dispatch(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
1516        if self.recording {
1517            let read_ranges: Vec<MemRange> = reads
1518                .iter()
1519                .map(|b| {
1520                    let start = b.contents_ptr() as usize;
1521                    (start, start + b.byte_len())
1522                })
1523                .collect();
1524            let write_ranges: Vec<MemRange> = writes
1525                .iter()
1526                .map(|b| {
1527                    let start = b.contents_ptr() as usize;
1528                    (start, start + b.byte_len())
1529                })
1530                .collect();
1531            self.encoder
1532                .annotate_last_dispatch_if_missing(read_ranges, write_ranges);
1533        }
1534        self.tracker.add(reads, writes);
1535    }
1536
    /// Return the number of barriers inserted so far in this session
    /// (by `barrier` and the conflicting cases of `barrier_between`).
    #[inline]
    pub fn barrier_count(&self) -> u32 {
        self.barrier_count
    }
1542
    /// Cumulative nanoseconds spent in ConflictTracker checks (diagnostic).
    /// Returns 0 when timing is not compiled in.
    pub fn tracker_overhead_ns(&self) -> u64 {
        // Timing instrumentation is not compiled into this build.
        0
    }
1548
    /// Borrow the underlying command encoder for direct op dispatch.
    ///
    /// Use this when you need to call an op function that is not wrapped by
    /// a `GraphSession` method.  The returned encoder is the same shared
    /// encoder — all dispatches still go into the same command buffer.
    /// Note that dispatches made this way bypass the session's conflict
    /// tracker unless you also call `track_dispatch`/`barrier_between`.
    pub fn encoder_mut(&mut self) -> &mut CommandEncoder {
        &mut self.encoder
    }
1557
    /// Borrow the device this session was created from.
    pub fn device(&self) -> &MlxDevice {
        self.device
    }
1562
    /// Whether this session is in capture/record mode (created via
    /// `GraphExecutor::begin_recorded`).
    pub fn is_recording(&self) -> bool {
        self.recording
    }
1567
1568    /// Commit the command buffer and wait for GPU completion.
1569    ///
1570    /// This is the ONLY sync point per forward pass.  After this call, all
1571    /// output buffers are readable by the CPU.
1572    ///
1573    /// In recording mode: extracts the captured graph, replays it into
1574    /// the encoder via `ComputeGraph::encode_sequential()`, then commits
1575    /// and waits.  The result is identical to the direct-dispatch path.
1576    ///
1577    /// Consumes the session — no further ops can be encoded.
1578    pub fn finish(mut self) -> Result<()> {
1579        if self.recording {
1580            if let Some(nodes) = self.encoder.take_capture() {
1581                let graph = ComputeGraph::from_nodes(nodes);
1582                graph.encode_sequential(&mut self.encoder);
1583            }
1584        }
1585        self.encoder.commit_and_wait()
1586    }
1587
1588    /// Commit the command buffer WITHOUT waiting.
1589    ///
1590    /// The GPU begins executing immediately.  Use this for fire-and-forget
1591    /// dispatch when you do not need results until later.
1592    ///
1593    /// In recording mode: replays the captured graph before committing.
1594    ///
1595    /// Consumes the session.
1596    pub fn commit(mut self) -> CommandEncoder {
1597        if self.recording {
1598            if let Some(nodes) = self.encoder.take_capture() {
1599                let graph = ComputeGraph::from_nodes(nodes);
1600                graph.encode_sequential(&mut self.encoder);
1601            }
1602        }
1603        self.encoder.commit();
1604        self.encoder
1605    }
1606
1607    /// Commit the command buffer and wait, returning split timing.
1608    ///
1609    /// Returns `(encoding_ns, gpu_wait_ns)` where:
1610    /// - `encoding_ns` is the time from session begin to commit (CPU encoding)
1611    /// - `gpu_wait_ns` is the time from commit to GPU completion
1612    ///
1613    /// The `session_begin` instant should be captured right after `exec.begin()`.
1614    ///
1615    /// In recording mode: replays the captured graph before committing.
1616    ///
1617    /// Consumes the session.
1618    pub fn finish_with_timing(mut self, session_begin: std::time::Instant) -> Result<(u64, u64)> {
1619        if self.recording {
1620            if let Some(nodes) = self.encoder.take_capture() {
1621                let graph = ComputeGraph::from_nodes(nodes);
1622                graph.encode_sequential(&mut self.encoder);
1623            }
1624        }
1625        let commit_start = std::time::Instant::now();
1626        let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
1627        self.encoder.commit();
1628        self.encoder.wait_until_completed()?;
1629        let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
1630        Ok((encoding_ns, gpu_wait_ns))
1631    }
1632
1633    /// Finish this session and return the GPU wall-clock interval in ns.
1634    ///
1635    /// Returns `(gpu_interval_ns,)` where `gpu_interval_ns` is the CFTimeInterval
1636    /// difference between `MTLCommandBuffer.GPUEndTime` and
1637    /// `MTLCommandBuffer.GPUStartTime`, converted to ns.  Excludes CPU
1638    /// commit+wait overhead — that appears in the residual when bucket
1639    /// sums are compared to the outer wall-clock.
1640    ///
1641    /// Used by `HF2Q_PROFILE_GPU_TS=1` to accumulate pure GPU time per
1642    /// op bucket.  In recording mode: replays the captured graph before
1643    /// committing.
1644    ///
1645    /// Consumes the session.
1646    pub fn finish_with_gpu_time(mut self) -> Result<u64> {
1647        if self.recording {
1648            if let Some(nodes) = self.encoder.take_capture() {
1649                let graph = ComputeGraph::from_nodes(nodes);
1650                graph.encode_sequential(&mut self.encoder);
1651            }
1652        }
1653        let (gs, ge) = self.encoder.commit_wait_with_gpu_time()?;
1654        // GPUStartTime/GPUEndTime are CFTimeInterval (seconds, double).
1655        // Guard against negative deltas (can happen on the first CB of
1656        // a run if the kernel driver lazily initialises the timeline;
1657        // clamp to zero in that case).
1658        let delta = (ge - gs).max(0.0);
1659        Ok((delta * 1.0e9) as u64)
1660    }
1661
1662    /// Finish with fusion: run the RMS norm + MUL fusion pass before
1663    /// replaying the graph.
1664    ///
1665    /// Only meaningful in recording mode.  In direct-dispatch mode, this
1666    /// behaves identically to `finish()`.
1667    ///
1668    /// Returns `(fusions_applied,)` on success.
1669    pub fn finish_with_fusion(
1670        mut self,
1671        registry: &mut KernelRegistry,
1672        device: &metal::DeviceRef,
1673    ) -> Result<u32> {
1674        let mut fusions = 0;
1675        if self.recording {
1676            if let Some(nodes) = self.encoder.take_capture() {
1677                let mut graph = ComputeGraph::from_nodes(nodes);
1678                fusions = graph.fuse(registry, device)?;
1679                graph.encode_sequential(&mut self.encoder);
1680            }
1681        }
1682        self.encoder.commit_and_wait()?;
1683        Ok(fusions)
1684    }
1685
1686    /// Finish with fusion and split timing.
1687    ///
1688    /// Like `finish_with_timing` but runs the fusion pass first.
1689    /// Returns `(encoding_ns, gpu_wait_ns, fusions_applied)`.
1690    pub fn finish_with_fusion_and_timing(
1691        mut self,
1692        registry: &mut KernelRegistry,
1693        device: &metal::DeviceRef,
1694        session_begin: std::time::Instant,
1695    ) -> Result<(u64, u64, u32)> {
1696        let mut fusions = 0;
1697        if self.recording {
1698            if let Some(nodes) = self.encoder.take_capture() {
1699                let mut graph = ComputeGraph::from_nodes(nodes);
1700                fusions = graph.fuse(registry, device)?;
1701                graph.encode_sequential(&mut self.encoder);
1702            }
1703        }
1704        let commit_start = std::time::Instant::now();
1705        let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
1706        self.encoder.commit();
1707        self.encoder.wait_until_completed()?;
1708        let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
1709        Ok((encoding_ns, gpu_wait_ns, fusions))
1710    }
1711
1712    /// Finish with fusion AND reorder: run both graph optimization passes
1713    /// before replaying the graph.
1714    ///
1715    /// Only meaningful in recording mode.  In direct-dispatch mode, this
1716    /// behaves identically to `finish()`.
1717    ///
1718    /// Returns `(fusions_applied, nodes_reordered)` on success.
1719    pub fn finish_with_fusion_and_reorder(
1720        mut self,
1721        registry: &mut KernelRegistry,
1722        device: &metal::DeviceRef,
1723    ) -> Result<(u32, u32)> {
1724        let mut fusions = 0;
1725        let mut reordered = 0;
1726        if self.recording {
1727            if let Some(nodes) = self.encoder.take_capture() {
1728                let mut graph = ComputeGraph::from_nodes(nodes);
1729                fusions = graph.fuse(registry, device)?;
1730                reordered = graph.reorder();
1731                graph.encode_with_barriers(&mut self.encoder);
1732            }
1733        }
1734        self.encoder.commit_and_wait()?;
1735        Ok((fusions, reordered))
1736    }
1737
1738    /// Finish with fusion, reorder, and split timing.
1739    ///
1740    /// Like `finish_with_fusion_and_timing` but also runs the reorder pass.
1741    /// Returns `(encoding_ns, gpu_wait_ns, fusions_applied, nodes_reordered)`.
1742    pub fn finish_with_fusion_reorder_and_timing(
1743        mut self,
1744        registry: &mut KernelRegistry,
1745        device: &metal::DeviceRef,
1746        session_begin: std::time::Instant,
1747    ) -> Result<(u64, u64, u32, u32)> {
1748        let mut fusions = 0;
1749        let mut reordered = 0;
1750        if self.recording {
1751            if let Some(nodes) = self.encoder.take_capture() {
1752                let mut graph = ComputeGraph::from_nodes(nodes);
1753                fusions = graph.fuse(registry, device)?;
1754                reordered = graph.reorder();
1755                graph.encode_with_barriers(&mut self.encoder);
1756            }
1757        }
1758        let commit_start = std::time::Instant::now();
1759        let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
1760        self.encoder.commit();
1761        self.encoder.wait_until_completed()?;
1762        let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
1763        Ok((encoding_ns, gpu_wait_ns, fusions, reordered))
1764    }
1765
    /// Finish with the full optimization pipeline: fuse, reorder, dual-buffer
    /// encode.
    ///
    /// Runs the fusion pass, reorder pass, then encodes the graph into two
    /// Metal command buffers for CPU/GPU overlap.  The first ~10% of dispatches
    /// are committed immediately so the GPU can start executing while the CPU
    /// encodes the remaining ~90%.
    ///
    /// Only meaningful in recording mode.  In direct-dispatch mode, this
    /// behaves identically to `finish()`.
    ///
    /// Returns `(fusions_applied, nodes_reordered, barriers_buf0, barriers_buf1)`.
    pub fn finish_optimized(
        mut self,
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
    ) -> Result<(u32, u32, u32, u32)> {
        let mut fusions = 0;
        let mut reordered = 0;
        let mut barriers0 = 0u32;
        let mut barriers1 = 0u32;

        if self.recording {
            if let Some(nodes) = self.encoder.take_capture() {
                // Commit the capture encoder's empty command buffer so its
                // MTLCommandQueue pool slot is freed (same fix as timing variant).
                self.encoder.commit();

                let mut graph = ComputeGraph::from_nodes(nodes);
                fusions = graph.fuse(registry, device)?;
                // NOTE(review): `reorder()` runs without the
                // `unannotated_dispatch_count()` guard used by
                // `finish_optimized_with_timing` — confirm reordering is safe
                // when dispatches lack range annotations.
                reordered = graph.reorder();

                // Two fresh command buffers: chunk 0 lets the GPU start early,
                // chunk 1 is encoded while chunk 0 executes.
                let mut enc0 = self.device.command_encoder()?;
                let mut enc1 = self.device.command_encoder()?;

                let (b0, b1) = graph.encode_dual_buffer(&mut enc0, &mut enc1);
                barriers0 = b0;
                barriers1 = b1;

                // enc0 was already committed inside encode_dual_buffer.
                // Commit enc1 and wait — Metal queue ordering guarantees enc0
                // finishes before enc1 starts executing.
                enc1.commit_and_wait()?;

                // The capture encoder's (empty) command buffer was committed
                // at the top of this branch, so nothing is left pending on
                // `self.encoder`; dropping the session here is clean.
                return Ok((fusions, reordered, barriers0, barriers1));
            }
        }

        // Direct-dispatch fallback: just commit the original encoder.
        self.encoder.commit_and_wait()?;
        Ok((fusions, reordered, barriers0, barriers1))
    }
1823
    /// Finish with the full optimization pipeline and split timing.
    ///
    /// Like `finish_optimized` but returns timing information.
    /// Returns `(encoding_ns, gpu_wait_ns, fusions, reordered, barriers_buf0, barriers_buf1)`.
    ///
    /// Timing breakdown:
    /// - `encoding_ns`: CPU time from session begin to first buffer commit
    ///   (fusion + reorder + encode chunk 0)
    /// - `gpu_wait_ns`: wall time from second buffer commit to GPU completion
    ///   (includes GPU execution of both buffers, overlapped with chunk 1 encoding)
    pub fn finish_optimized_with_timing(
        mut self,
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
        session_begin: std::time::Instant,
    ) -> Result<(u64, u64, u32, u32, u32, u32)> {
        let mut fusions = 0;
        let mut reordered = 0;
        let mut barriers0 = 0u32;
        let mut barriers1 = 0u32;

        if self.recording {
            if let Some(nodes) = self.encoder.take_capture() {
                // Commit the capture encoder's empty command buffer so its
                // MTLCommandQueue pool slot is freed.  Without this, each
                // token leaks one uncommitted buffer and the queue exhausts
                // its ~64-slot pool after ~64 tokens, causing a deadlock.
                self.encoder.commit();

                // Per-phase micro-timers.  They run unconditionally (Instant
                // reads are cheap) but are only printed under HF2Q_GRAPH_DIAG.
                let opt_t0 = std::time::Instant::now();
                let mut graph = ComputeGraph::from_nodes(nodes);
                let fuse_t0 = std::time::Instant::now();
                fusions = graph.fuse(registry, device)?;
                let fuse_us = fuse_t0.elapsed().as_micros();

                let reorder_t0 = std::time::Instant::now();
                // The reorder pass relies on per-dispatch range annotations;
                // skip it entirely if any dispatch lacks them, warning only
                // when timing output is enabled.
                let unannotated = graph.unannotated_dispatch_count();
                if unannotated == 0 {
                    reordered = graph.reorder();
                } else if std::env::var("HF2Q_MLX_TIMING").is_ok() {
                    eprintln!("  [GRAPH_OPT] WARN: skipping reorder — {} of {} dispatches lack range annotations",
                        unannotated, graph.dispatch_count());
                }
                let reorder_us = reorder_t0.elapsed().as_micros();
                let opt_us = opt_t0.elapsed().as_micros();

                let diag = std::env::var("HF2Q_GRAPH_DIAG").is_ok();
                let t0 = std::time::Instant::now();
                let mut enc0 = self.device.command_encoder()?;
                let mut enc1 = self.device.command_encoder()?;
                let enc_create_us = t0.elapsed().as_micros();

                let t1 = std::time::Instant::now();
                // Splits the graph across the two command buffers; enc0 is
                // committed inside so the GPU starts while enc1 is encoded.
                let (b0, b1) = graph.encode_dual_buffer(&mut enc0, &mut enc1);
                barriers0 = b0;
                barriers1 = b1;
                let encode_us = t1.elapsed().as_micros();

                let encoding_ns = session_begin.elapsed().as_nanos() as u64;

                let wait_start = std::time::Instant::now();
                // enc0 was already committed; Metal queue ordering guarantees
                // it executes before enc1.
                enc1.commit_and_wait()?;
                let gpu_wait_ns = wait_start.elapsed().as_nanos() as u64;

                if diag {
                    eprintln!("  [DIAG] fuse={:.1}ms reorder={:.1}ms opt_total={:.1}ms enc_create={:.1}ms encode={:.1}ms gpu_wait={:.1}ms barriers={}+{}",
                        fuse_us as f64 / 1e3, reorder_us as f64 / 1e3, opt_us as f64 / 1e3,
                        enc_create_us as f64 / 1e3, encode_us as f64 / 1e3,
                        gpu_wait_ns as f64 / 1e6, b0, b1);
                }

                return Ok((encoding_ns, gpu_wait_ns, fusions, reordered, barriers0, barriers1));
            }
        }

        // Direct-dispatch fallback.
        let commit_start = std::time::Instant::now();
        let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
        self.encoder.commit();
        self.encoder.wait_until_completed()?;
        let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
        Ok((encoding_ns, gpu_wait_ns, fusions, reordered, barriers0, barriers1))
    }
1907}