Skip to main content

mlx_native/
graph.rs

1//! [`GraphExecutor`] — batched Metal dispatch for single-encoder forward passes.
2//!
3//! llama.cpp's speed advantage over candle is NOT the kernels (Phase 0 proved
4//! candle's are as fast or faster per-call).  It is the dispatch pattern:
5//! 1 encoder per command buffer instead of ~120.  This module implements that
6//! pattern.
7//!
8//! # Usage
9//!
10//! ```ignore
11//! let mut executor = GraphExecutor::new(device.clone());
12//! let mut session = executor.begin()?;
13//!
14//! // All ops encode into the same command buffer — no per-op encoder creation.
15//! session.rms_norm(&mut registry, device.metal_device(), input, weight, output, params, rows, dim)?;
16//! session.quantized_matmul(&mut registry, &device, input, weight, scales, biases, &qparams)?;
17//! session.elementwise_add(&mut registry, device.metal_device(), a, b, out, n, DType::F32)?;
18//!
19//! // Single GPU sync point for the entire forward pass.
20//! session.finish()?;
21//! ```
22//!
23//! # Design
24//!
25//! The `GraphSession` holds a single `CommandEncoder`.  Each op method delegates
26//! to the existing op dispatch functions in [`crate::ops`], passing the session's
27//! shared encoder.  No new Metal code is needed — the ops already work with a
28//! shared encoder.  The executor just prevents creating a new encoder per op.
29//!
30//! # Phase 4e.1 — Graph IR
31//!
32//! The `ComputeGraph` type captures dispatches into a `Vec<CapturedNode>` for
33//! later replay.  `GraphExecutor::begin_recorded()` starts a session in capture
34//! mode: all op calls are intercepted at the `CommandEncoder` level and recorded
35//! instead of being sent to Metal.  `GraphSession::finish()` detects capture
36//! mode, extracts the recorded graph, and replays it into a fresh encoder via
37//! `ComputeGraph::encode_sequential()`.
38//!
39//! The existing direct-dispatch path (`begin()`) is completely unchanged.
40
41use metal::foreign_types::ForeignType;
42
43use crate::device::MlxDevice;
44use crate::encoder::{CapturedNode, CapturedOpKind, CommandEncoder, MemRange, RecordedBinding};
45use crate::error::Result;
46use crate::kernel_registry::KernelRegistry;
47use crate::ops;
48
49// Re-export types used in the public API so callers don't need separate imports.
50pub use crate::buffer::MlxBuffer;
51pub use crate::dtypes::DType;
52
53// ---------------------------------------------------------------------------
54// OpKind — operation classification for the reorder safety whitelist (4e.3)
55// ---------------------------------------------------------------------------
56
/// Classification of a compute operation for reorder safety analysis.
///
/// The graph optimizer (Phase 4e.3) may move reorderable operations
/// earlier, subject to their data dependencies.  Non-reorderable
/// operations carry side effects or ordering requirements (e.g. CPU
/// readback) that pin them to their original sequential position.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OpKind {
    /// Matrix multiplication (reorderable).
    MatMul,
    /// Expert-routed matrix multiplication (reorderable).
    MatMulId,
    /// Normalization — RMS norm, layer norm (reorderable).
    Norm,
    /// Rotary position embedding (reorderable).
    Rope,
    /// Elementwise ops — add, mul, scale, gelu, softcap, etc. (reorderable).
    Elementwise,
    /// Memory copy — KV cache copy, embedding gather (reorderable).
    Copy,
    /// Gather/scatter (reorderable).
    Gather,
    /// Scaled dot-product attention (NOT reorderable).
    Sdpa,
    /// Softmax (NOT reorderable).
    Softmax,
    /// MoE gate with CPU readback dependency (NOT reorderable).
    MoeGate,
    /// Anything else (NOT reorderable).
    Other,
}

impl OpKind {
    /// Whether this op kind is safe to reorder in the graph optimizer.
    ///
    /// Exhaustive on purpose: adding a new variant forces an explicit
    /// reorderability decision here instead of silently defaulting.
    pub fn is_reorderable(&self) -> bool {
        match self {
            Self::MatMul
            | Self::MatMulId
            | Self::Norm
            | Self::Rope
            | Self::Elementwise
            | Self::Copy
            | Self::Gather => true,
            Self::Sdpa | Self::Softmax | Self::MoeGate | Self::Other => false,
        }
    }
}
104
105// ---------------------------------------------------------------------------
106// ComputeGraph — the recorded graph IR
107// ---------------------------------------------------------------------------
108
/// A recorded sequence of GPU compute dispatches and barriers.
///
/// Created by running a forward pass with the encoder in capture mode.
/// Can be replayed into a real `CommandEncoder` via `encode_sequential()`,
/// producing identical Metal dispatch behavior to the original direct path.
///
/// Future phases (4e.2, 4e.3) will add fusion and reorder passes that
/// transform the graph before encoding.
pub struct ComputeGraph {
    // Flat, ordered node list; replay order is vector order.  Both
    // dispatches and barrier sentinels live in the same sequence.
    nodes: Vec<CapturedNode>,
}
120
121impl ComputeGraph {
122    /// Create an empty compute graph.
123    pub fn new() -> Self {
124        Self {
125            nodes: Vec::with_capacity(128),
126        }
127    }
128
129    /// Create a compute graph from a pre-built list of captured nodes.
130    pub fn from_nodes(nodes: Vec<CapturedNode>) -> Self {
131        Self { nodes }
132    }
133
134    /// Record a captured node into the graph.
135    pub fn record(&mut self, node: CapturedNode) {
136        self.nodes.push(node);
137    }
138
139    /// Number of nodes (dispatches + barriers) in the graph.
140    pub fn len(&self) -> usize {
141        self.nodes.len()
142    }
143
144    /// Whether the graph contains no nodes.
145    pub fn is_empty(&self) -> bool {
146        self.nodes.is_empty()
147    }
148
149    /// Number of dispatch nodes (excludes barriers).
150    pub fn dispatch_count(&self) -> usize {
151        self.nodes
152            .iter()
153            .filter(|n| matches!(n, CapturedNode::Dispatch { .. }))
154            .count()
155    }
156
157    /// Number of barrier nodes.
158    pub fn barrier_count(&self) -> usize {
159        self.nodes
160            .iter()
161            .filter(|n| matches!(n, CapturedNode::Barrier))
162            .count()
163    }
164
165    /// Borrow the node list.
166    pub fn nodes(&self) -> &[CapturedNode] {
167        &self.nodes
168    }
169
170    /// Count dispatch nodes that have empty read/write range annotations.
171    ///
172    /// Used for diagnostics: if >0, the reorder pass cannot guarantee
173    /// correctness because it relies on complete annotations.
174    pub fn unannotated_dispatch_count(&self) -> usize {
175        self.nodes
176            .iter()
177            .filter(|n| matches!(n, CapturedNode::Dispatch { reads, writes, .. }
178                if reads.is_empty() || writes.is_empty()))
179            .count()
180    }
181
182    /// Take ownership of the node list, consuming the graph.
183    pub fn into_nodes(self) -> Vec<CapturedNode> {
184        self.nodes
185    }
186
187    /// Encode all nodes sequentially into the given encoder.
188    ///
189    /// Barrier sentinel nodes emit a Metal memory barrier.  Dispatch nodes
190    /// are replayed through `CommandEncoder::replay_dispatch()`.
191    ///
192    /// This produces identical GPU behavior to the direct-dispatch path —
193    /// same pipeline bindings, same dispatch dimensions, same barrier
194    /// placement.
195    ///
196    /// Returns the number of barriers emitted.
197    pub fn encode_sequential(&self, encoder: &mut CommandEncoder) -> u32 {
198        let mut barrier_count = 0u32;
199        for node in &self.nodes {
200            match node {
201                CapturedNode::Barrier => {
202                    encoder.memory_barrier();
203                    barrier_count += 1;
204                }
205                CapturedNode::Dispatch {
206                    pipeline,
207                    bindings,
208                    threads_per_grid,
209                    threads_per_threadgroup,
210                    threadgroup_memory,
211                    dispatch_kind,
212                    ..
213                } => {
214                    encoder.replay_dispatch(
215                        pipeline,
216                        bindings,
217                        threadgroup_memory,
218                        *threads_per_grid,
219                        *threads_per_threadgroup,
220                        *dispatch_kind,
221                    );
222                }
223            }
224        }
225        barrier_count
226    }
227
228    /// Encode the graph into a Metal command buffer, computing barriers on the
229    /// fly from each node's read/write buffer ranges.
230    ///
231    /// This is the correct encoding method for reordered graphs where barrier
232    /// sentinels have been stripped.  Mirrors llama.cpp's encode-time barrier
233    /// insertion via `ggml_metal_op_concurrency_check`.
234    ///
235    /// Returns the number of barriers emitted.
236    pub fn encode_with_barriers(&self, encoder: &mut CommandEncoder) -> u32 {
237        let mut tracker = ReorderConflictTracker::new();
238        let mut barrier_count = 0u32;
239
240        for node in &self.nodes {
241            match node {
242                CapturedNode::Dispatch {
243                    pipeline,
244                    bindings,
245                    threads_per_grid,
246                    threads_per_threadgroup,
247                    threadgroup_memory,
248                    dispatch_kind,
249                    reads,
250                    writes,
251                    ..
252                } => {
253                    let has_ranges = !reads.is_empty() || !writes.is_empty();
254                    if has_ranges && tracker.conflicts(reads, writes) {
255                        encoder.memory_barrier();
256                        tracker.reset();
257                        barrier_count += 1;
258                    }
259                    if has_ranges {
260                        tracker.add(reads, writes);
261                    }
262                    encoder.replay_dispatch(
263                        pipeline,
264                        bindings,
265                        threadgroup_memory,
266                        *threads_per_grid,
267                        *threads_per_threadgroup,
268                        *dispatch_kind,
269                    );
270                }
271                CapturedNode::Barrier => {
272                    // Explicit barriers still force a barrier boundary
273                    encoder.memory_barrier();
274                    tracker.reset();
275                    barrier_count += 1;
276                }
277            }
278        }
279        barrier_count
280    }
281
282    /// Encode the graph using two command buffers for CPU/GPU overlap.
283    ///
284    /// The first `n0` dispatches are encoded into `encoder0` and committed
285    /// immediately (GPU starts executing).  The remaining dispatches are encoded
286    /// into `encoder1`.  The caller is responsible for committing `encoder1`.
287    ///
288    /// This matches llama.cpp's dual command buffer pattern from
289    /// `ggml_metal_graph_compute` (ggml-metal-context.m:441-644):
290    /// `n_nodes_0 = MAX(64, 0.1 * n_nodes)` for the first buffer.
291    ///
292    /// Command buffers submitted to the same `MTLCommandQueue` execute in
293    /// submission order, so `encoder0.commit()` followed by `encoder1.commit()`
294    /// guarantees enc0 finishes before enc1 starts.  The win: the GPU starts
295    /// executing enc0 while the CPU is still encoding enc1.
296    ///
297    /// Returns `(barriers_buf0, barriers_buf1)`.
298    pub fn encode_dual_buffer(
299        &self,
300        encoder0: &mut CommandEncoder,
301        encoder1: &mut CommandEncoder,
302    ) -> (u32, u32) {
303        let dispatch_total = self.dispatch_count();
304        let n0 = std::cmp::max(64, dispatch_total / 10);
305
306        // Find the split point: the index of the n0-th dispatch node.
307        let split_idx = find_dispatch_split_index(&self.nodes, n0);
308
309        // Encode first chunk with barrier recomputation, then commit immediately.
310        let barriers0 = encode_chunk_with_barriers(&self.nodes[..split_idx], encoder0);
311        encoder0.commit();
312
313        // Encode second chunk with barrier recomputation.
314        let barriers1 = encode_chunk_with_barriers(&self.nodes[split_idx..], encoder1);
315
316        (barriers0, barriers1)
317    }
318
319    /// Run the RMS norm + MUL fusion pass over the graph.
320    ///
321    /// Scans for the pattern:
322    ///   Dispatch(RmsNorm) → Barrier(s) → Dispatch(ElemMul)
323    /// where the MUL reads the norm's output buffer, and replaces the
324    /// sequence with a single fused `rms_norm_mul_*` dispatch.
325    ///
326    /// The fused dispatch:
327    /// - Reads the norm's input (buffer 0) and weight (buffer 1)
328    /// - Reads the MUL's second operand as the scale (buffer 2)
329    /// - Writes to the MUL's output (buffer 3)
330    /// - Carries the norm's params (buffer 4)
331    /// - Uses the norm's threadgroup config and shared memory
332    ///
333    /// Returns the number of fusions applied.
334    ///
335    /// # Arguments
336    ///
337    /// * `registry` - Kernel registry for compiling the fused pipeline.
338    /// * `device`   - Metal device for pipeline compilation.
339    pub fn fuse(
340        &mut self,
341        registry: &mut KernelRegistry,
342        device: &metal::DeviceRef,
343    ) -> Result<u32> {
344        let mut result: Vec<CapturedNode> = Vec::with_capacity(self.nodes.len());
345        let mut fusions = 0u32;
346        let mut i = 0;
347
348        while i < self.nodes.len() {
349            // Check if current node is an RMS norm dispatch.
350            let is_rms_norm = matches!(
351                &self.nodes[i],
352                CapturedNode::Dispatch { op_kind: CapturedOpKind::RmsNorm, .. }
353            );
354
355            if !is_rms_norm {
356                result.push(self.nodes[i].clone());
357                i += 1;
358                continue;
359            }
360
361            // Look ahead: skip barriers, then check for ElemMul.
362            let mut j = i + 1;
363            let mut barrier_count = 0usize;
364            while j < self.nodes.len() && matches!(&self.nodes[j], CapturedNode::Barrier) {
365                barrier_count += 1;
366                j += 1;
367            }
368
369            // Must have at least one barrier and the next node must be ElemMul.
370            if barrier_count == 0 || j >= self.nodes.len() {
371                result.push(self.nodes[i].clone());
372                i += 1;
373                continue;
374            }
375
376            let is_elem_mul = matches!(
377                &self.nodes[j],
378                CapturedNode::Dispatch { op_kind: CapturedOpKind::ElemMul, .. }
379            );
380
381            if !is_elem_mul {
382                result.push(self.nodes[i].clone());
383                i += 1;
384                continue;
385            }
386
387            // Extract norm and mul dispatch fields.
388            let (norm_pipeline, norm_bindings, norm_tpg, norm_tptg, norm_tgmem, norm_dk) =
389                match &self.nodes[i] {
390                    CapturedNode::Dispatch {
391                        pipeline,
392                        bindings,
393                        threads_per_grid,
394                        threads_per_threadgroup,
395                        threadgroup_memory,
396                        dispatch_kind,
397                        ..
398                    } => (pipeline, bindings, threads_per_grid, threads_per_threadgroup, threadgroup_memory, dispatch_kind),
399                    _ => unreachable!(),
400                };
401
402            let (mul_bindings, _mul_tpg, _mul_tptg) = match &self.nodes[j] {
403                CapturedNode::Dispatch {
404                    bindings,
405                    threads_per_grid,
406                    threads_per_threadgroup,
407                    ..
408                } => (bindings, threads_per_grid, threads_per_threadgroup),
409                _ => unreachable!(),
410            };
411
412            // Verify data dependency: the norm's output buffer (slot 2) must
413            // appear as one of the MUL's input buffers (slot 0 or 1).
414            //
415            // Norm binding layout: (0=input, 1=weight, 2=output, 3=params)
416            // MUL binding layout:  (0=a, 1=b, 2=output, 3=params_bytes)
417            let norm_output_ptr = Self::buffer_ptr_for_slot(norm_bindings, 2);
418            let mul_a_ptr = Self::buffer_ptr_for_slot(mul_bindings, 0);
419            let mul_b_ptr = Self::buffer_ptr_for_slot(mul_bindings, 1);
420
421            if norm_output_ptr.is_none() || (norm_output_ptr != mul_a_ptr && norm_output_ptr != mul_b_ptr) {
422                // Data dependency not confirmed — don't fuse.
423                result.push(self.nodes[i].clone());
424                i += 1;
425                continue;
426            }
427
428            // Determine which MUL input is the scale (the one that is NOT
429            // the norm's output).
430            let scale_slot = if norm_output_ptr == mul_a_ptr { 1 } else { 0 };
431
432            // Build fused bindings:
433            //   0 = norm input
434            //   1 = norm weight
435            //   2 = scale (from MUL)
436            //   3 = MUL output
437            //   4 = norm params
438            // Gather all required bindings; bail if any are missing.
439            let (norm_input, norm_weight, scale, mul_output, norm_params) = match (
440                Self::get_binding(norm_bindings, 0),
441                Self::get_binding(norm_bindings, 1),
442                Self::get_binding(mul_bindings, scale_slot),
443                Self::get_binding(mul_bindings, 2),
444                Self::get_binding(norm_bindings, 3),
445            ) {
446                (Some(a), Some(b), Some(c), Some(d), Some(e)) => (a, b, c, d, e),
447                _ => {
448                    // Missing bindings — don't fuse.
449                    result.push(self.nodes[i].clone());
450                    i += 1;
451                    continue;
452                }
453            };
454
455            // Select fused pipeline based on the original norm pipeline name.
456            // The norm pipeline name is "rms_norm_f32", "rms_norm_f16", or
457            // "rms_norm_bf16" — we need the corresponding fused pipeline.
458            let fused_name = match Self::fused_pipeline_name(norm_pipeline) {
459                Some(name) => name,
460                None => {
461                    result.push(self.nodes[i].clone());
462                    i += 1;
463                    continue;
464                }
465            };
466
467            let fused_pipeline = registry.get_pipeline(fused_name, device)?;
468
469            let fused_bindings = vec![
470                (0, norm_input),
471                (1, norm_weight),
472                (2, scale),
473                (3, mul_output),
474                (4, norm_params),
475            ];
476
477            // Merge read/write ranges from both the norm and mul nodes for the
478            // fused dispatch.  The fused op reads everything the norm reads
479            // plus the mul's scale input, and writes to the mul's output.
480            let (fused_reads, fused_writes) = match (&self.nodes[i], &self.nodes[j]) {
481                (
482                    CapturedNode::Dispatch { reads: nr, writes: _nw, .. },
483                    CapturedNode::Dispatch { reads: mr, writes: mw, .. },
484                ) => {
485                    let mut reads = nr.clone();
486                    reads.extend_from_slice(mr);
487                    (reads, mw.clone())
488                }
489                _ => (Vec::new(), Vec::new()),
490            };
491
492            result.push(CapturedNode::Dispatch {
493                pipeline: fused_pipeline.to_owned(),
494                bindings: fused_bindings,
495                threads_per_grid: *norm_tpg,
496                threads_per_threadgroup: *norm_tptg,
497                threadgroup_memory: norm_tgmem.clone(),
498                dispatch_kind: *norm_dk,
499                op_kind: CapturedOpKind::Other, // Fused ops are not further fuseable
500                reads: fused_reads,
501                writes: fused_writes,
502            });
503
504            fusions += 1;
505            // Skip past the norm, barrier(s), and mul nodes.
506            i = j + 1;
507        }
508
509        self.nodes = result;
510        Ok(fusions)
511    }
512
513    /// Run the reorder pass over the graph to improve GPU concurrency.
514    ///
515    /// Port of llama.cpp's `ggml_metal_graph_optimize_reorder` — a greedy
516    /// 64-node lookahead that pulls independent dispatches forward to fill
517    /// larger concurrent groups between barriers.
518    ///
519    /// **Prerequisites:** Call `fuse()` first if desired.  The reorder pass
520    /// operates on the post-fusion graph.  Barrier sentinel nodes are stripped
521    /// before reordering (they will be recomputed at encode time by the
522    /// `ConflictTracker` in `encode_sequential`).
523    ///
524    /// **Algorithm (matching llama.cpp exactly):**
525    /// 1. Strip all `CapturedNode::Barrier` nodes.
526    /// 2. For each unprocessed node `i0`:
527    ///    - If it conflicts with the current concurrent group (`mrs0`):
528    ///      * Initialize `mrs1` from `i0`'s ranges (skipped-over set)
529    ///      * Lookahead up to 64 nodes for candidates that:
530    ///        (a) Are reorderable (`CapturedOpKind::is_reorderable()`)
531    ///        (b) Don't conflict with `mrs0` (current group)
532    ///        (c) Don't conflict with `mrs1` (skipped-over nodes)
533    ///      * Pull qualifying candidates into the current group
534    ///      * Non-reorderable ops break the lookahead
535    ///    - Reset `mrs0` (new concurrent group)
536    ///    - Add `i0` to the new group
537    ///
538    /// Returns the number of nodes that were moved to earlier positions.
539    pub fn reorder(&mut self) -> u32 {
540        // Step 1: Strip barrier nodes.  After fusion + reorder, barriers will
541        // be recomputed by the ConflictTracker at encode time.
542        self.nodes.retain(|n| !matches!(n, CapturedNode::Barrier));
543
544        let n = self.nodes.len();
545        if n == 0 {
546            return 0;
547        }
548
549        let mut result: Vec<usize> = Vec::with_capacity(n);
550        let mut used = vec![false; n];
551
552        // mrs0: memory ranges for the current concurrent group
553        let mut mrs0 = ReorderConflictTracker::new();
554        // mrs1: memory ranges for skipped-over (unprocessed) nodes
555        let mut mrs1 = ReorderConflictTracker::new();
556
557        const N_FORWARD: usize = 64;
558
559        for i0 in 0..n {
560            if used[i0] {
561                continue;
562            }
563
564            let node0 = &self.nodes[i0];
565
566            // Extract reads/writes for conflict check.
567            let (reads0, writes0, op_kind0) = match node0 {
568                CapturedNode::Dispatch { reads, writes, op_kind, .. } => {
569                    (reads.as_slice(), writes.as_slice(), *op_kind)
570                }
571                CapturedNode::Barrier => continue, // stripped, but be safe
572            };
573
574            // Check if node0 conflicts with the current concurrent group.
575            // Empty nodes (no ranges) never conflict — like llama.cpp's is_empty.
576            let has_ranges = !reads0.is_empty() || !writes0.is_empty();
577            if has_ranges && mrs0.conflicts(reads0, writes0) {
578                // Before starting a new group, look forward for nodes that
579                // can be pulled into the CURRENT group.
580                mrs1.reset();
581                mrs1.add(reads0, writes0);
582
583                let end = (i0 + N_FORWARD).min(n);
584                for i1 in (i0 + 1)..end {
585                    if used[i1] {
586                        continue;
587                    }
588
589                    let node1 = &self.nodes[i1];
590                    let (reads1, writes1, op_kind1) = match node1 {
591                        CapturedNode::Dispatch { reads, writes, op_kind, .. } => {
592                            (reads.as_slice(), writes.as_slice(), *op_kind)
593                        }
594                        CapturedNode::Barrier => continue,
595                    };
596
597                    // Non-reorderable ops break the lookahead.
598                    if !op_kind1.is_reorderable() {
599                        break;
600                    }
601
602                    let is_empty1 = reads1.is_empty() && writes1.is_empty();
603
604                    // A node can be reordered into the current group if:
605                    // 1. It's empty (no ranges) OR doesn't conflict with mrs0
606                    // 2. It doesn't conflict with mrs1 (skipped-over nodes)
607                    if (is_empty1 || !mrs0.conflicts(reads1, writes1))
608                        && !mrs1.conflicts(reads1, writes1)
609                    {
610                        // Pull into current concurrent group.
611                        mrs0.add(reads1, writes1);
612                        result.push(i1);
613                        used[i1] = true;
614                    } else {
615                        // Not eligible — expand the skipped-over set.
616                        mrs1.add(reads1, writes1);
617                    }
618                }
619
620                // Finalize the current concurrent group.
621                mrs0.reset();
622            }
623
624            // Expand the concurrent group with node0.
625            // (Barriers were stripped, so this is always a Dispatch.)
626            let _ = op_kind0; // suppress unused warning
627            mrs0.add(reads0, writes0);
628            result.push(i0);
629        }
630
631        // Apply the permutation to produce the reordered node list.
632        let mut reordered_count = 0u32;
633        for (pos, &orig_idx) in result.iter().enumerate() {
634            if orig_idx != pos {
635                reordered_count += 1;
636            }
637        }
638
639        // Build the reordered nodes vec.
640        let old_nodes = std::mem::take(&mut self.nodes);
641        self.nodes = result.iter().map(|&idx| old_nodes[idx].clone()).collect();
642
643        // Debug dump if requested.
644        if std::env::var("HF2Q_REORDER_DUMP").is_ok() {
645            eprintln!(
646                "  [REORDER] nodes={} reordered={} ({:.1}%)",
647                n,
648                reordered_count,
649                100.0 * reordered_count as f64 / n as f64,
650            );
651        }
652
653        reordered_count
654    }
655
656    /// Get the Metal buffer pointer for a binding at the given slot index.
657    ///
658    /// Returns `Some(ptr)` if the slot has a `RecordedBinding::Buffer`,
659    /// `None` otherwise.
660    fn buffer_ptr_for_slot(bindings: &[(u64, RecordedBinding)], slot: u64) -> Option<*const std::ffi::c_void> {
661        for (idx, binding) in bindings {
662            if *idx == slot {
663                if let RecordedBinding::Buffer { metal_buffer, offset: _ } = binding {
664                    // Use the Metal buffer's GPU address as the identity key.
665                    // On Apple Silicon unified memory, this uniquely identifies
666                    // the allocation.
667                    let ptr: *const std::ffi::c_void = metal_buffer.as_ptr() as *const _;
668                    return Some(ptr);
669                }
670            }
671        }
672        None
673    }
674
675    /// Clone the binding at the given slot index.
676    fn get_binding(bindings: &[(u64, RecordedBinding)], slot: u64) -> Option<RecordedBinding> {
677        for (idx, binding) in bindings {
678            if *idx == slot {
679                return Some(binding.clone());
680            }
681        }
682        None
683    }
684
685    /// Map a norm pipeline to its fused norm+mul pipeline name.
686    ///
687    /// The pipeline's `label()` is set by Metal to the function name, so we
688    /// can match on it.  Returns `None` if the pipeline is not a known norm.
689    fn fused_pipeline_name(pipeline: &metal::ComputePipelineState) -> Option<&'static str> {
690        match pipeline.label() {
691            "rms_norm_f32" => Some("rms_norm_mul_f32"),
692            "rms_norm_f16" => Some("rms_norm_mul_f16"),
693            "rms_norm_bf16" => Some("rms_norm_mul_bf16"),
694            _ => None,
695        }
696    }
697}
698
699impl Default for ComputeGraph {
700    fn default() -> Self {
701        Self::new()
702    }
703}
704
705// ---------------------------------------------------------------------------
706// Dual-buffer encoding helpers (Phase 4e.4)
707// ---------------------------------------------------------------------------
708
709/// Find the node index where the n0-th dispatch starts.
710///
711/// Counts `CapturedNode::Dispatch` nodes until `n0` are reached, then returns
712/// the index of the n0-th dispatch (i.e., the first node of the second chunk).
713/// If `n0 >= dispatch_count`, returns `nodes.len()` (everything in chunk 0).
714fn find_dispatch_split_index(nodes: &[CapturedNode], n0: usize) -> usize {
715    let mut dispatches_seen = 0usize;
716    for (i, node) in nodes.iter().enumerate() {
717        if matches!(node, CapturedNode::Dispatch { .. }) {
718            dispatches_seen += 1;
719            if dispatches_seen == n0 {
720                return i + 1; // split AFTER the n0-th dispatch
721            }
722        }
723    }
724    nodes.len()
725}
726
727/// Encode a slice of captured nodes into a command encoder, recomputing
728/// barriers on the fly from each node's read/write buffer ranges.
729///
730/// This is the chunked counterpart of `ComputeGraph::encode_with_barriers()`.
731/// Factored out so both halves of a dual-buffer encode can use it.
732///
733/// Returns the number of barriers emitted.
734fn encode_chunk_with_barriers(nodes: &[CapturedNode], encoder: &mut CommandEncoder) -> u32 {
735    let mut tracker = ReorderConflictTracker::new();
736    let mut barrier_count = 0u32;
737
738    for node in nodes {
739        match node {
740            CapturedNode::Dispatch {
741                pipeline,
742                bindings,
743                threads_per_grid,
744                threads_per_threadgroup,
745                threadgroup_memory,
746                dispatch_kind,
747                reads,
748                writes,
749                ..
750            } => {
751                let has_ranges = !reads.is_empty() || !writes.is_empty();
752                if has_ranges && tracker.conflicts(reads, writes) {
753                    encoder.memory_barrier();
754                    tracker.reset();
755                    barrier_count += 1;
756                }
757                if has_ranges {
758                    tracker.add(reads, writes);
759                }
760                encoder.replay_dispatch(
761                    pipeline,
762                    bindings,
763                    threadgroup_memory,
764                    *threads_per_grid,
765                    *threads_per_threadgroup,
766                    *dispatch_kind,
767                );
768            }
769            CapturedNode::Barrier => {
770                encoder.memory_barrier();
771                tracker.reset();
772                barrier_count += 1;
773            }
774        }
775    }
776    barrier_count
777}
778
779// ---------------------------------------------------------------------------
780// ReorderConflictTracker — range-based conflict detection for the reorder pass
781// ---------------------------------------------------------------------------
782
783/// Memory range conflict tracker for the reorder pass (Phase 4e.3).
784///
785/// Works with `MemRange` tuples `(start, end)` stored on `CapturedNode::Dispatch`,
786/// rather than requiring live `&MlxBuffer` references.  This is the reorder-time
787/// equivalent of the runtime `ConflictTracker`.
788///
789/// Conflict rules match llama.cpp's `ggml_mem_ranges_check`:
790/// - Two read ranges: OK (read-read is concurrent-safe)
791/// - A new read overlapping an existing write: CONFLICT (RAW)
792/// - A new write overlapping any existing range: CONFLICT (WAR/WAW)
793struct ReorderConflictTracker {
794    /// (start, end, is_write) for all ranges in the tracked set.
795    ranges: Vec<(usize, usize, bool)>,
796}
797
798impl ReorderConflictTracker {
799    fn new() -> Self {
800        Self {
801            ranges: Vec::with_capacity(64),
802        }
803    }
804
805    fn reset(&mut self) {
806        self.ranges.clear();
807    }
808
809    /// Check if a dispatch with the given read/write ranges conflicts with
810    /// any range in this tracker.
811    fn conflicts(&self, reads: &[MemRange], writes: &[MemRange]) -> bool {
812        // New reads vs existing writes (RAW)
813        for &(r_start, r_end) in reads {
814            for &(s, e, is_write) in &self.ranges {
815                if is_write && r_start < e && r_end > s {
816                    return true;
817                }
818            }
819        }
820        // New writes vs all existing ranges (WAR/WAW)
821        for &(w_start, w_end) in writes {
822            for &(s, e, _) in &self.ranges {
823                if w_start < e && w_end > s {
824                    return true;
825                }
826            }
827        }
828        false
829    }
830
831    /// Add read and write ranges to the tracked set.
832    fn add(&mut self, reads: &[MemRange], writes: &[MemRange]) {
833        for &(start, end) in reads {
834            self.ranges.push((start, end, false));
835        }
836        for &(start, end) in writes {
837            self.ranges.push((start, end, true));
838        }
839    }
840}
841
842/// Batched Metal dispatch — encodes multiple ops into a single `CommandEncoder`.
843///
844/// Create one per model (or per forward-pass loop).  Call [`begin`](Self::begin)
845/// at the start of each forward pass to get a [`GraphSession`] that holds the
846/// shared encoder.
847pub struct GraphExecutor {
848    device: MlxDevice,
849}
850
851impl GraphExecutor {
852    /// Create a new graph executor backed by the given device.
853    pub fn new(device: MlxDevice) -> Self {
854        Self { device }
855    }
856
857    /// Begin a new forward pass (direct-dispatch mode).
858    ///
859    /// Returns a [`GraphSession`] that holds a fresh `CommandEncoder`.  All ops
860    /// encoded through the session share this single encoder.  Call
861    /// [`GraphSession::finish`] to commit and wait.
862    pub fn begin(&self) -> Result<GraphSession<'_>> {
863        let encoder = self.device.command_encoder()?;
864        Ok(GraphSession {
865            encoder,
866            device: &self.device,
867            barrier_count: 0,
868            tracker: ConflictTracker::new(),
869            dispatch_in_group: 0,
870            total_dispatches: 0,
871            group_sizes: [0; 8],
872            recording: false,
873        })
874    }
875
876    /// Begin a new forward pass in capture (record) mode.
877    ///
878    /// All op calls are recorded into a `ComputeGraph` instead of being
879    /// dispatched to Metal.  When [`GraphSession::finish`] is called, the
880    /// recorded graph is replayed into a fresh encoder via
881    /// `ComputeGraph::encode_sequential()`.
882    ///
883    /// The API is identical to `begin()` — callers do not need to change
884    /// any op call code.  The only behavioral difference: GPU work happens
885    /// at `finish()` time rather than at each op call.
886    pub fn begin_recorded(&self) -> Result<GraphSession<'_>> {
887        let mut encoder = self.device.command_encoder()?;
888        encoder.start_capture();
889        Ok(GraphSession {
890            encoder,
891            device: &self.device,
892            barrier_count: 0,
893            tracker: ConflictTracker::new(),
894            dispatch_in_group: 0,
895            total_dispatches: 0,
896            group_sizes: [0; 8],
897            recording: true,
898        })
899    }
900
901    /// Borrow the underlying device.
902    pub fn device(&self) -> &MlxDevice {
903        &self.device
904    }
905}
906
/// Tracks buffer address ranges for automatic barrier elision.
///
/// Mirrors llama.cpp's `ggml_mem_ranges` — accumulates the read and write
/// ranges of all dispatches in the current concurrent group. When a new
/// dispatch's reads overlap with an existing write (RAW), or its writes
/// overlap with an existing read or write (WAR/WAW), a barrier is needed.
/// Otherwise the dispatch can run concurrently and the barrier is elided.
///
/// Uses CPU-visible `contents_ptr()` addresses, which on Apple Silicon
/// unified memory equal the GPU addresses.
pub struct ConflictTracker {
    /// (start, end, is_write) tuples for the current concurrent group.
    ranges: Vec<(usize, usize, bool)>,
}

impl ConflictTracker {
    fn new() -> Self {
        Self {
            ranges: Vec::with_capacity(32),
        }
    }

    /// Reset the tracker — called after emitting a barrier.
    fn reset(&mut self) {
        self.ranges.clear();
    }

    /// Check whether a new dispatch with the given reads and writes
    /// conflicts with the current concurrent group, returning the reason
    /// if one is found.
    ///
    /// Conflict rules (same as llama.cpp `ggml_mem_ranges_check`):
    /// - Two SRC (read) ranges in the same buffer: OK (read-read)
    /// - A new SRC overlapping an existing DST: CONFLICT (RAW)
    /// - A new DST overlapping an existing SRC or DST: CONFLICT (WAR/WAW)
    ///
    /// Returns `Some((conflict_kind, new_range_start, existing_range_start))`
    /// on conflict ("RAW"/"WAR"/"WAW"), or `None` when the dispatch can run
    /// concurrently with the current group.
    fn conflicts_reason(&self, reads: &[&MlxBuffer], writes: &[&MlxBuffer])
        -> Option<(&'static str, usize, usize)>
    {
        // Check new reads against existing writes (RAW)
        for r in reads {
            let r_start = r.contents_ptr() as usize;
            let r_end = r_start + r.byte_len();
            for &(s, e, is_write) in &self.ranges {
                // Half-open interval overlap test: [r_start, r_end) ∩ [s, e) ≠ ∅
                if is_write && r_start < e && r_end > s {
                    return Some(("RAW", r_start, s));
                }
            }
        }
        // Check new writes against existing reads and writes (WAR/WAW)
        for w in writes {
            let w_start = w.contents_ptr() as usize;
            let w_end = w_start + w.byte_len();
            for &(s, e, is_write) in &self.ranges {
                if w_start < e && w_end > s {
                    let kind = if is_write { "WAW" } else { "WAR" };
                    return Some((kind, w_start, s));
                }
            }
        }
        None
    }

    /// Add read and write ranges to the current concurrent group.
    fn add(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
        for r in reads {
            let start = r.contents_ptr() as usize;
            let end = start + r.byte_len();
            self.ranges.push((start, end, false));
        }
        for w in writes {
            let start = w.contents_ptr() as usize;
            let end = start + w.byte_len();
            self.ranges.push((start, end, true));
        }
    }
}
992
/// A single forward pass execution context.
///
/// All ops are encoded into one `CommandEncoder`.  Call [`finish`](Self::finish)
/// to commit the command buffer and wait for GPU completion — this is the ONLY
/// sync point per forward pass.
///
/// If an op returns an error, the session can be dropped without committing.
/// The underlying command buffer is abandoned (never committed to the GPU).
pub struct GraphSession<'a> {
    /// The shared encoder that every op in this session encodes into.
    encoder: CommandEncoder,
    /// Device the encoder was created from, borrowed for the session's lifetime.
    device: &'a MlxDevice,
    /// Number of memory barriers emitted so far in this session.
    barrier_count: u32,
    /// Range-overlap tracker used by `barrier_between` to elide barriers.
    tracker: ConflictTracker,
    /// Dispatches encoded since the last barrier (current concurrent group).
    dispatch_in_group: u32,
    /// Total dispatches tracked via `barrier_between` in this session.
    total_dispatches: u32,
    /// Histogram: group_sizes[i] = number of concurrent groups with (i+1) dispatches
    /// (groups larger than 8 are clamped into the last bucket).
    group_sizes: [u32; 8],
    /// Whether this session was created in capture/record mode.
    recording: bool,
}
1005
1006impl<'a> GraphSession<'a> {
1007    /// Encode an RMS normalization into this session's encoder.
1008    ///
1009    /// Delegates to [`ops::rms_norm::dispatch_rms_norm`].
1010    pub fn rms_norm(
1011        &mut self,
1012        registry: &mut KernelRegistry,
1013        device: &metal::DeviceRef,
1014        input: &MlxBuffer,
1015        weight: &MlxBuffer,
1016        output: &MlxBuffer,
1017        params_buf: &MlxBuffer,
1018        rows: u32,
1019        dim: u32,
1020    ) -> Result<()> {
1021        ops::rms_norm::dispatch_rms_norm(
1022            &mut self.encoder,
1023            registry,
1024            device,
1025            input,
1026            weight,
1027            output,
1028            params_buf,
1029            rows,
1030            dim,
1031        )
1032    }
1033
1034    /// Encode a quantized matrix multiplication into this session's encoder.
1035    ///
1036    /// Delegates to [`ops::quantized_matmul::quantized_matmul`].
1037    /// Returns the freshly allocated output buffer.
1038    pub fn quantized_matmul(
1039        &mut self,
1040        registry: &mut KernelRegistry,
1041        device: &MlxDevice,
1042        input: &MlxBuffer,
1043        weight: &MlxBuffer,
1044        scales: &MlxBuffer,
1045        biases: &MlxBuffer,
1046        params: &ops::quantized_matmul::QuantizedMatmulParams,
1047    ) -> Result<MlxBuffer> {
1048        ops::quantized_matmul::quantized_matmul(
1049            &mut self.encoder,
1050            registry,
1051            device,
1052            input,
1053            weight,
1054            scales,
1055            biases,
1056            params,
1057        )
1058    }
1059
1060    /// Encode a SIMD-optimized quantized matmul into this session's encoder.
1061    ///
1062    /// Delegates to [`ops::quantized_matmul::quantized_matmul_simd`].
1063    /// Returns the freshly allocated output buffer.
1064    pub fn quantized_matmul_simd(
1065        &mut self,
1066        registry: &mut KernelRegistry,
1067        device: &MlxDevice,
1068        input: &MlxBuffer,
1069        weight: &MlxBuffer,
1070        scales: &MlxBuffer,
1071        biases: &MlxBuffer,
1072        params: &ops::quantized_matmul::QuantizedMatmulParams,
1073    ) -> Result<MlxBuffer> {
1074        ops::quantized_matmul::quantized_matmul_simd(
1075            &mut self.encoder,
1076            registry,
1077            device,
1078            input,
1079            weight,
1080            scales,
1081            biases,
1082            params,
1083        )
1084    }
1085
1086    /// Encode a GGML block-format quantized mat-vec into this session's encoder.
1087    ///
1088    /// Delegates to [`ops::quantized_matmul_ggml::quantized_matmul_ggml`].
1089    pub fn quantized_matmul_ggml(
1090        &mut self,
1091        registry: &mut KernelRegistry,
1092        device: &MlxDevice,
1093        input: &MlxBuffer,
1094        weight: &MlxBuffer,
1095        output: &mut MlxBuffer,
1096        params: &ops::quantized_matmul_ggml::GgmlQuantizedMatmulParams,
1097    ) -> Result<()> {
1098        ops::quantized_matmul_ggml::quantized_matmul_ggml(
1099            &mut self.encoder,
1100            registry,
1101            device,
1102            input,
1103            weight,
1104            output,
1105            params,
1106        )
1107    }
1108
1109    /// Encode an expert-routed GGML block-format quantized mat-vec into this session's encoder.
1110    ///
1111    /// Delegates to [`ops::quantized_matmul_id_ggml::quantized_matmul_id_ggml`].
1112    #[allow(clippy::too_many_arguments)]
1113    pub fn quantized_matmul_id_ggml(
1114        &mut self,
1115        registry: &mut KernelRegistry,
1116        device: &MlxDevice,
1117        input: &MlxBuffer,
1118        weight: &MlxBuffer,
1119        ids: &MlxBuffer,
1120        output: &mut MlxBuffer,
1121        params: &ops::quantized_matmul_id_ggml::GgmlQuantizedMatmulIdParams,
1122    ) -> Result<()> {
1123        ops::quantized_matmul_id_ggml::quantized_matmul_id_ggml(
1124            &mut self.encoder,
1125            registry,
1126            device,
1127            input,
1128            weight,
1129            ids,
1130            output,
1131            params,
1132        )
1133    }
1134
1135    /// Encode scaled dot-product attention into this session's encoder.
1136    ///
1137    /// Delegates to [`ops::sdpa::sdpa`].
1138    pub fn sdpa(
1139        &mut self,
1140        registry: &mut KernelRegistry,
1141        device: &MlxDevice,
1142        q: &MlxBuffer,
1143        k: &MlxBuffer,
1144        v: &MlxBuffer,
1145        output: &MlxBuffer,
1146        params: &ops::sdpa::SdpaParams,
1147        batch_size: u32,
1148    ) -> Result<()> {
1149        ops::sdpa::sdpa(
1150            &mut self.encoder,
1151            registry,
1152            device,
1153            q,
1154            k,
1155            v,
1156            output,
1157            params,
1158            batch_size,
1159        )
1160    }
1161
1162    /// Encode flash attention vector (SIMD-vectorized decode-path SDPA).
1163    ///
1164    /// Delegates to [`ops::flash_attn_vec::flash_attn_vec`].
1165    pub fn flash_attn_vec(
1166        &mut self,
1167        registry: &mut KernelRegistry,
1168        device: &MlxDevice,
1169        q: &MlxBuffer,
1170        k: &MlxBuffer,
1171        v: &MlxBuffer,
1172        output: &MlxBuffer,
1173        tmp: &MlxBuffer,
1174        params: &ops::flash_attn_vec::FlashAttnVecParams,
1175    ) -> Result<()> {
1176        ops::flash_attn_vec::flash_attn_vec(
1177            &mut self.encoder,
1178            registry,
1179            device,
1180            q,
1181            k,
1182            v,
1183            output,
1184            tmp,
1185            params,
1186        )
1187    }
1188
1189    /// Encode an elementwise add into this session's encoder.
1190    ///
1191    /// Delegates to [`ops::elementwise::elementwise_add`].
1192    pub fn elementwise_add(
1193        &mut self,
1194        registry: &mut KernelRegistry,
1195        device: &metal::DeviceRef,
1196        a: &MlxBuffer,
1197        b: &MlxBuffer,
1198        output: &MlxBuffer,
1199        n_elements: usize,
1200        dtype: DType,
1201    ) -> Result<()> {
1202        ops::elementwise::elementwise_add(
1203            &mut self.encoder,
1204            registry,
1205            device,
1206            a,
1207            b,
1208            output,
1209            n_elements,
1210            dtype,
1211        )
1212    }
1213
1214    /// Encode an elementwise multiply into this session's encoder.
1215    ///
1216    /// Delegates to [`ops::elementwise::elementwise_mul`].
1217    pub fn elementwise_mul(
1218        &mut self,
1219        registry: &mut KernelRegistry,
1220        device: &metal::DeviceRef,
1221        a: &MlxBuffer,
1222        b: &MlxBuffer,
1223        output: &MlxBuffer,
1224        n_elements: usize,
1225        dtype: DType,
1226    ) -> Result<()> {
1227        ops::elementwise::elementwise_mul(
1228            &mut self.encoder,
1229            registry,
1230            device,
1231            a,
1232            b,
1233            output,
1234            n_elements,
1235            dtype,
1236        )
1237    }
1238
1239    /// Encode a RoPE transform into this session's encoder.
1240    ///
1241    /// Delegates to [`ops::rope::dispatch_rope`].
1242    pub fn rope(
1243        &mut self,
1244        registry: &mut KernelRegistry,
1245        device: &metal::DeviceRef,
1246        input: &MlxBuffer,
1247        output: &MlxBuffer,
1248        params_buf: &MlxBuffer,
1249        positions_buf: &MlxBuffer,
1250        seq_len: u32,
1251        head_dim: u32,
1252    ) -> Result<()> {
1253        ops::rope::dispatch_rope(
1254            &mut self.encoder,
1255            registry,
1256            device,
1257            input,
1258            output,
1259            params_buf,
1260            positions_buf,
1261            seq_len,
1262            head_dim,
1263        )
1264    }
1265
1266    /// Encode a GELU activation into this session's encoder.
1267    ///
1268    /// Delegates to [`ops::gelu::dispatch_gelu`].
1269    pub fn gelu(
1270        &mut self,
1271        registry: &mut KernelRegistry,
1272        device: &metal::DeviceRef,
1273        input: &MlxBuffer,
1274        output: &MlxBuffer,
1275    ) -> Result<()> {
1276        ops::gelu::dispatch_gelu(
1277            &mut self.encoder,
1278            registry,
1279            device,
1280            input,
1281            output,
1282        )
1283    }
1284
1285    /// Encode a softmax into this session's encoder.
1286    ///
1287    /// Delegates to [`ops::softmax::dispatch_softmax`].
1288    pub fn softmax(
1289        &mut self,
1290        registry: &mut KernelRegistry,
1291        device: &metal::DeviceRef,
1292        input: &MlxBuffer,
1293        output: &MlxBuffer,
1294        params_buf: &MlxBuffer,
1295        rows: u32,
1296        cols: u32,
1297    ) -> Result<()> {
1298        ops::softmax::dispatch_softmax(
1299            &mut self.encoder,
1300            registry,
1301            device,
1302            input,
1303            output,
1304            params_buf,
1305            rows,
1306            cols,
1307        )
1308    }
1309
1310    /// Encode a softcap into this session's encoder.
1311    ///
1312    /// Delegates to [`ops::softcap::dispatch_softcap`].
1313    pub fn softcap(
1314        &mut self,
1315        registry: &mut KernelRegistry,
1316        device: &metal::DeviceRef,
1317        input: &MlxBuffer,
1318        output: &MlxBuffer,
1319        params_buf: &MlxBuffer,
1320        cap: f32,
1321    ) -> Result<()> {
1322        ops::softcap::dispatch_softcap(
1323            &mut self.encoder,
1324            registry,
1325            device,
1326            input,
1327            output,
1328            params_buf,
1329            cap,
1330        )
1331    }
1332
1333    /// Encode an RMS norm without learned scale (f32) into this session's encoder.
1334    ///
1335    /// Delegates to [`ops::rms_norm::dispatch_rms_norm_no_scale_f32`].
1336    pub fn rms_norm_no_scale_f32(
1337        &mut self,
1338        registry: &mut KernelRegistry,
1339        device: &metal::DeviceRef,
1340        input: &MlxBuffer,
1341        output: &MlxBuffer,
1342        params_buf: &MlxBuffer,
1343        rows: u32,
1344        dim: u32,
1345    ) -> Result<()> {
1346        ops::rms_norm::dispatch_rms_norm_no_scale_f32(
1347            &mut self.encoder,
1348            registry,
1349            device,
1350            input,
1351            output,
1352            params_buf,
1353            rows,
1354            dim,
1355        )
1356    }
1357
1358    /// Encode a NeoX RoPE (f32) with optional freq_factors into this session's encoder.
1359    ///
1360    /// Delegates to [`ops::rope::dispatch_rope_neox_f32`].
1361    #[allow(clippy::too_many_arguments)]
1362    pub fn rope_neox_f32(
1363        &mut self,
1364        registry: &mut KernelRegistry,
1365        device: &metal::DeviceRef,
1366        input: &MlxBuffer,
1367        output: &MlxBuffer,
1368        params_buf: &MlxBuffer,
1369        positions_buf: &MlxBuffer,
1370        freq_factors: Option<&MlxBuffer>,
1371        seq_len: u32,
1372        n_heads: u32,
1373        head_dim: u32,
1374        rope_dim: u32,
1375    ) -> Result<()> {
1376        ops::rope::dispatch_rope_neox_f32(
1377            &mut self.encoder,
1378            registry,
1379            device,
1380            input,
1381            output,
1382            params_buf,
1383            positions_buf,
1384            freq_factors,
1385            seq_len,
1386            n_heads,
1387            head_dim,
1388            rope_dim,
1389        )
1390    }
1391
1392    /// Insert a GPU memory barrier (MTLBarrierScopeBuffers).
1393    ///
1394    /// Unconditional barrier — always emits. Use `barrier_between` for
1395    /// automatic conflict detection that can elide unnecessary barriers.
1396    #[inline]
1397    pub fn barrier(&mut self) {
1398        // Record the outgoing group size
1399        if self.dispatch_in_group > 0 {
1400            let idx = (self.dispatch_in_group as usize).min(self.group_sizes.len()) - 1;
1401            self.group_sizes[idx] += 1;
1402        }
1403        self.encoder.memory_barrier();
1404        self.tracker.reset();
1405        self.barrier_count += 1;
1406        self.dispatch_in_group = 0;
1407    }
1408
1409    /// Smart barrier with conflict detection.
1410    ///
1411    /// Checks if the next dispatch (with the given read and write buffers)
1412    /// actually conflicts with any dispatch in the current concurrent group.
1413    /// If yes, emits a Metal barrier and resets the tracker. If no, the
1414    /// barrier is elided and the dispatch can run concurrently.
1415    ///
1416    /// This mirrors llama.cpp's `ggml_metal_op_concurrency_check` +
1417    /// `ggml_metal_op_concurrency_reset` pattern.
1418    #[inline]
1419    pub fn barrier_between(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
1420        // In capture mode, stash the read/write ranges so the next captured
1421        // dispatch node carries them for the reorder pass (Phase 4e.3).
1422        if self.recording {
1423            let read_ranges: Vec<MemRange> = reads
1424                .iter()
1425                .map(|b| {
1426                    let start = b.contents_ptr() as usize;
1427                    (start, start + b.byte_len())
1428                })
1429                .collect();
1430            let write_ranges: Vec<MemRange> = writes
1431                .iter()
1432                .map(|b| {
1433                    let start = b.contents_ptr() as usize;
1434                    (start, start + b.byte_len())
1435                })
1436                .collect();
1437            self.encoder.set_pending_buffer_ranges(read_ranges, write_ranges);
1438        }
1439
1440        let reason = self.tracker.conflicts_reason(reads, writes);
1441        if let Some((_kind, _new_ptr, _existing_ptr)) = reason {
1442            // Record the outgoing group size before resetting
1443            if self.dispatch_in_group > 0 {
1444                let idx = (self.dispatch_in_group as usize).min(self.group_sizes.len()) - 1;
1445                self.group_sizes[idx] += 1;
1446            }
1447            self.encoder.memory_barrier();
1448            self.tracker.reset();
1449            self.barrier_count += 1;
1450            self.dispatch_in_group = 0;
1451        }
1452        self.dispatch_in_group += 1;
1453        self.total_dispatches += 1;
1454        self.tracker.add(reads, writes);
1455    }
1456
1457    /// Print group size histogram to stderr (for HF2Q_MLX_TIMING debug).
1458    pub fn dump_group_stats(&self) {
1459        // Record the final (unterminated) group
1460        let mut gs = self.group_sizes;
1461        if self.dispatch_in_group > 0 {
1462            let idx = (self.dispatch_in_group as usize).min(gs.len()) - 1;
1463            gs[idx] += 1;
1464        }
1465        let total_groups: u32 = gs.iter().sum();
1466        eprintln!("  [GROUP_STATS] dispatches={} barriers={} groups={} ratio={:.2}",
1467            self.total_dispatches, self.barrier_count, total_groups,
1468            if total_groups > 0 { self.total_dispatches as f64 / total_groups as f64 } else { 0.0 });
1469        for (i, &count) in gs.iter().enumerate() {
1470            if count > 0 {
1471                eprintln!("    size {}: {} groups", i + 1, count);
1472            }
1473        }
1474    }
1475
1476    /// Register a dispatch's buffer ranges without checking for conflicts.
1477    ///
1478    /// Use after dispatching an op that doesn't need a barrier check (e.g.,
1479    /// the first dispatch in a session, or dispatches known to be concurrent).
1480    ///
1481    /// In recording mode, also retroactively annotates the most recently
1482    /// captured dispatch node with these ranges if it was missing them.
1483    /// That keeps the reorder pass able to reason about dispatches that
1484    /// were preceded by `track_dispatch` rather than `barrier_between`.
1485    #[inline]
1486    pub fn track_dispatch(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
1487        if self.recording {
1488            let read_ranges: Vec<MemRange> = reads
1489                .iter()
1490                .map(|b| {
1491                    let start = b.contents_ptr() as usize;
1492                    (start, start + b.byte_len())
1493                })
1494                .collect();
1495            let write_ranges: Vec<MemRange> = writes
1496                .iter()
1497                .map(|b| {
1498                    let start = b.contents_ptr() as usize;
1499                    (start, start + b.byte_len())
1500                })
1501                .collect();
1502            self.encoder
1503                .annotate_last_dispatch_if_missing(read_ranges, write_ranges);
1504        }
1505        self.tracker.add(reads, writes);
1506    }
1507
    /// Return the number of barriers inserted so far in this session.
    #[inline]
    pub fn barrier_count(&self) -> u32 {
        self.barrier_count
    }

    /// Cumulative nanoseconds spent in ConflictTracker checks (diagnostic).
    /// Returns 0 when timing is not compiled in — currently always 0; the
    /// hook is kept so callers do not need cfg-gating.
    pub fn tracker_overhead_ns(&self) -> u64 {
        0
    }

    /// Borrow the underlying command encoder for direct op dispatch.
    ///
    /// Use this when you need to call an op function that is not wrapped by
    /// a `GraphSession` method.  The returned encoder is the same shared
    /// encoder — all dispatches still go into the same command buffer.
    pub fn encoder_mut(&mut self) -> &mut CommandEncoder {
        &mut self.encoder
    }

    /// Borrow the device reference this session was created from.
    pub fn device(&self) -> &MlxDevice {
        self.device
    }

    /// Whether this session is in capture/record mode.
    pub fn is_recording(&self) -> bool {
        self.recording
    }
1538
1539    /// Commit the command buffer and wait for GPU completion.
1540    ///
1541    /// This is the ONLY sync point per forward pass.  After this call, all
1542    /// output buffers are readable by the CPU.
1543    ///
1544    /// In recording mode: extracts the captured graph, replays it into
1545    /// the encoder via `ComputeGraph::encode_sequential()`, then commits
1546    /// and waits.  The result is identical to the direct-dispatch path.
1547    ///
1548    /// Consumes the session — no further ops can be encoded.
1549    pub fn finish(mut self) -> Result<()> {
1550        if self.recording {
1551            if let Some(nodes) = self.encoder.take_capture() {
1552                let graph = ComputeGraph::from_nodes(nodes);
1553                graph.encode_sequential(&mut self.encoder);
1554            }
1555        }
1556        self.encoder.commit_and_wait()
1557    }
1558
1559    /// Commit the command buffer WITHOUT waiting.
1560    ///
1561    /// The GPU begins executing immediately.  Use this for fire-and-forget
1562    /// dispatch when you do not need results until later.
1563    ///
1564    /// In recording mode: replays the captured graph before committing.
1565    ///
1566    /// Consumes the session.
1567    pub fn commit(mut self) -> CommandEncoder {
1568        if self.recording {
1569            if let Some(nodes) = self.encoder.take_capture() {
1570                let graph = ComputeGraph::from_nodes(nodes);
1571                graph.encode_sequential(&mut self.encoder);
1572            }
1573        }
1574        self.encoder.commit();
1575        self.encoder
1576    }
1577
1578    /// Commit the command buffer and wait, returning split timing.
1579    ///
1580    /// Returns `(encoding_ns, gpu_wait_ns)` where:
1581    /// - `encoding_ns` is the time from session begin to commit (CPU encoding)
1582    /// - `gpu_wait_ns` is the time from commit to GPU completion
1583    ///
1584    /// The `session_begin` instant should be captured right after `exec.begin()`.
1585    ///
1586    /// In recording mode: replays the captured graph before committing.
1587    ///
1588    /// Consumes the session.
1589    pub fn finish_with_timing(mut self, session_begin: std::time::Instant) -> Result<(u64, u64)> {
1590        if self.recording {
1591            if let Some(nodes) = self.encoder.take_capture() {
1592                let graph = ComputeGraph::from_nodes(nodes);
1593                graph.encode_sequential(&mut self.encoder);
1594            }
1595        }
1596        let commit_start = std::time::Instant::now();
1597        let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
1598        self.encoder.commit();
1599        self.encoder.wait_until_completed()?;
1600        let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
1601        Ok((encoding_ns, gpu_wait_ns))
1602    }
1603
    /// Finish with fusion: run the RMS norm + MUL fusion pass before
    /// replaying the graph.
    ///
    /// Only meaningful in recording mode.  In direct-dispatch mode, this
    /// behaves identically to `finish()`.
    ///
    /// Returns the number of fusions applied on success (always 0 when
    /// not recording or when nothing was captured).
    pub fn finish_with_fusion(
        mut self,
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
    ) -> Result<u32> {
        let mut fusions = 0;
        if self.recording {
            if let Some(nodes) = self.encoder.take_capture() {
                // Fuse eligible node pairs, then replay the optimized graph
                // into the now-empty encoder.
                let mut graph = ComputeGraph::from_nodes(nodes);
                fusions = graph.fuse(registry, device)?;
                graph.encode_sequential(&mut self.encoder);
            }
        }
        self.encoder.commit_and_wait()?;
        Ok(fusions)
    }
1627
1628    /// Finish with fusion and split timing.
1629    ///
1630    /// Like `finish_with_timing` but runs the fusion pass first.
1631    /// Returns `(encoding_ns, gpu_wait_ns, fusions_applied)`.
1632    pub fn finish_with_fusion_and_timing(
1633        mut self,
1634        registry: &mut KernelRegistry,
1635        device: &metal::DeviceRef,
1636        session_begin: std::time::Instant,
1637    ) -> Result<(u64, u64, u32)> {
1638        let mut fusions = 0;
1639        if self.recording {
1640            if let Some(nodes) = self.encoder.take_capture() {
1641                let mut graph = ComputeGraph::from_nodes(nodes);
1642                fusions = graph.fuse(registry, device)?;
1643                graph.encode_sequential(&mut self.encoder);
1644            }
1645        }
1646        let commit_start = std::time::Instant::now();
1647        let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
1648        self.encoder.commit();
1649        self.encoder.wait_until_completed()?;
1650        let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
1651        Ok((encoding_ns, gpu_wait_ns, fusions))
1652    }
1653
1654    /// Finish with fusion AND reorder: run both graph optimization passes
1655    /// before replaying the graph.
1656    ///
1657    /// Only meaningful in recording mode.  In direct-dispatch mode, this
1658    /// behaves identically to `finish()`.
1659    ///
1660    /// Returns `(fusions_applied, nodes_reordered)` on success.
1661    pub fn finish_with_fusion_and_reorder(
1662        mut self,
1663        registry: &mut KernelRegistry,
1664        device: &metal::DeviceRef,
1665    ) -> Result<(u32, u32)> {
1666        let mut fusions = 0;
1667        let mut reordered = 0;
1668        if self.recording {
1669            if let Some(nodes) = self.encoder.take_capture() {
1670                let mut graph = ComputeGraph::from_nodes(nodes);
1671                fusions = graph.fuse(registry, device)?;
1672                reordered = graph.reorder();
1673                graph.encode_with_barriers(&mut self.encoder);
1674            }
1675        }
1676        self.encoder.commit_and_wait()?;
1677        Ok((fusions, reordered))
1678    }
1679
1680    /// Finish with fusion, reorder, and split timing.
1681    ///
1682    /// Like `finish_with_fusion_and_timing` but also runs the reorder pass.
1683    /// Returns `(encoding_ns, gpu_wait_ns, fusions_applied, nodes_reordered)`.
1684    pub fn finish_with_fusion_reorder_and_timing(
1685        mut self,
1686        registry: &mut KernelRegistry,
1687        device: &metal::DeviceRef,
1688        session_begin: std::time::Instant,
1689    ) -> Result<(u64, u64, u32, u32)> {
1690        let mut fusions = 0;
1691        let mut reordered = 0;
1692        if self.recording {
1693            if let Some(nodes) = self.encoder.take_capture() {
1694                let mut graph = ComputeGraph::from_nodes(nodes);
1695                fusions = graph.fuse(registry, device)?;
1696                reordered = graph.reorder();
1697                graph.encode_with_barriers(&mut self.encoder);
1698            }
1699        }
1700        let commit_start = std::time::Instant::now();
1701        let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
1702        self.encoder.commit();
1703        self.encoder.wait_until_completed()?;
1704        let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
1705        Ok((encoding_ns, gpu_wait_ns, fusions, reordered))
1706    }
1707
1708    /// Finish with the full optimization pipeline: fuse, reorder, dual-buffer
1709    /// encode.
1710    ///
1711    /// Runs the fusion pass, reorder pass, then encodes the graph into two
1712    /// Metal command buffers for CPU/GPU overlap.  The first ~10% of dispatches
1713    /// are committed immediately so the GPU can start executing while the CPU
1714    /// encodes the remaining ~90%.
1715    ///
1716    /// Only meaningful in recording mode.  In direct-dispatch mode, this
1717    /// behaves identically to `finish()`.
1718    ///
1719    /// Returns `(fusions_applied, nodes_reordered, barriers_buf0, barriers_buf1)`.
1720    pub fn finish_optimized(
1721        mut self,
1722        registry: &mut KernelRegistry,
1723        device: &metal::DeviceRef,
1724    ) -> Result<(u32, u32, u32, u32)> {
1725        let mut fusions = 0;
1726        let mut reordered = 0;
1727        let mut barriers0 = 0u32;
1728        let mut barriers1 = 0u32;
1729
1730        if self.recording {
1731            if let Some(nodes) = self.encoder.take_capture() {
1732                // Commit the capture encoder's empty command buffer so its
1733                // MTLCommandQueue pool slot is freed (same fix as timing variant).
1734                self.encoder.commit();
1735
1736                let mut graph = ComputeGraph::from_nodes(nodes);
1737                fusions = graph.fuse(registry, device)?;
1738                reordered = graph.reorder();
1739
1740                let mut enc0 = self.device.command_encoder()?;
1741                let mut enc1 = self.device.command_encoder()?;
1742
1743                let (b0, b1) = graph.encode_dual_buffer(&mut enc0, &mut enc1);
1744                barriers0 = b0;
1745                barriers1 = b1;
1746
1747                // enc0 was already committed inside encode_dual_buffer.
1748                // Commit enc1 and wait — Metal queue ordering guarantees enc0
1749                // finishes before enc1 starts executing.
1750                enc1.commit_and_wait()?;
1751
1752                // The original encoder was never committed (capture mode drained
1753                // it). We need to end it cleanly — dropping it will end the
1754                // active encoder if any, and the uncommitted command buffer is
1755                // abandoned.  That is safe: Metal silently drops uncommitted
1756                // command buffers.
1757                return Ok((fusions, reordered, barriers0, barriers1));
1758            }
1759        }
1760
1761        // Direct-dispatch fallback: just commit the original encoder.
1762        self.encoder.commit_and_wait()?;
1763        Ok((fusions, reordered, barriers0, barriers1))
1764    }
1765
    /// Finish with the full optimization pipeline and split timing.
    ///
    /// Like `finish_optimized` but returns timing information.
    /// Returns `(encoding_ns, gpu_wait_ns, fusions, reordered, barriers_buf0, barriers_buf1)`.
    ///
    /// Timing breakdown:
    /// - `encoding_ns`: CPU time from session begin to first buffer commit
    ///   (fusion + reorder + encode chunk 0)
    /// - `gpu_wait_ns`: wall time from second buffer commit to GPU completion
    ///   (includes GPU execution of both buffers, overlapped with chunk 1 encoding)
    ///
    /// Setting the `HF2Q_GRAPH_DIAG` env var additionally prints a per-phase
    /// breakdown (fuse / reorder / encoder creation / encode / GPU wait).
    pub fn finish_optimized_with_timing(
        mut self,
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
        session_begin: std::time::Instant,
    ) -> Result<(u64, u64, u32, u32, u32, u32)> {
        let mut fusions = 0;
        let mut reordered = 0;
        let mut barriers0 = 0u32;
        let mut barriers1 = 0u32;

        if self.recording {
            if let Some(nodes) = self.encoder.take_capture() {
                // Commit the capture encoder's empty command buffer so its
                // MTLCommandQueue pool slot is freed.  Without this, each
                // token leaks one uncommitted buffer and the queue exhausts
                // its ~64-slot pool after ~64 tokens, causing a deadlock.
                self.encoder.commit();

                // --- Optimization passes (CPU-side) -------------------------
                let opt_t0 = std::time::Instant::now();
                let mut graph = ComputeGraph::from_nodes(nodes);
                let fuse_t0 = std::time::Instant::now();
                fusions = graph.fuse(registry, device)?;
                let fuse_us = fuse_t0.elapsed().as_micros();

                // Reorder only when every dispatch carries range annotations;
                // otherwise the safety of a reordering cannot be established,
                // so the pass is skipped (warned under HF2Q_MLX_TIMING).
                let reorder_t0 = std::time::Instant::now();
                let unannotated = graph.unannotated_dispatch_count();
                if unannotated == 0 {
                    reordered = graph.reorder();
                } else if std::env::var("HF2Q_MLX_TIMING").is_ok() {
                    eprintln!("  [GRAPH_OPT] WARN: skipping reorder — {} of {} dispatches lack range annotations",
                        unannotated, graph.dispatch_count());
                }
                let reorder_us = reorder_t0.elapsed().as_micros();
                let opt_us = opt_t0.elapsed().as_micros();

                // --- Dual-buffer encode -------------------------------------
                // HF2Q_GRAPH_DIAG enables the per-phase dump at the end.
                let diag = std::env::var("HF2Q_GRAPH_DIAG").is_ok();
                let t0 = std::time::Instant::now();
                let mut enc0 = self.device.command_encoder()?;
                let mut enc1 = self.device.command_encoder()?;
                let enc_create_us = t0.elapsed().as_micros();

                let t1 = std::time::Instant::now();
                let (b0, b1) = graph.encode_dual_buffer(&mut enc0, &mut enc1);
                barriers0 = b0;
                barriers1 = b1;
                let encode_us = t1.elapsed().as_micros();

                // CPU-side encoding cost: everything from session begin
                // through the dual-buffer encode above.
                let encoding_ns = session_begin.elapsed().as_nanos() as u64;

                // enc0 was committed inside encode_dual_buffer; waiting on
                // enc1 covers both buffers via Metal queue ordering.
                let wait_start = std::time::Instant::now();
                enc1.commit_and_wait()?;
                let gpu_wait_ns = wait_start.elapsed().as_nanos() as u64;

                if diag {
                    eprintln!("  [DIAG] fuse={:.1}ms reorder={:.1}ms opt_total={:.1}ms enc_create={:.1}ms encode={:.1}ms gpu_wait={:.1}ms barriers={}+{}",
                        fuse_us as f64 / 1e3, reorder_us as f64 / 1e3, opt_us as f64 / 1e3,
                        enc_create_us as f64 / 1e3, encode_us as f64 / 1e3,
                        gpu_wait_ns as f64 / 1e6, b0, b1);
                }

                return Ok((encoding_ns, gpu_wait_ns, fusions, reordered, barriers0, barriers1));
            }
        }

        // Direct-dispatch fallback.
        let commit_start = std::time::Instant::now();
        let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
        self.encoder.commit();
        self.encoder.wait_until_completed()?;
        let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
        Ok((encoding_ns, gpu_wait_ns, fusions, reordered, barriers0, barriers1))
    }
1849}