// mlx_native/graph.rs

//! [`GraphExecutor`] — batched Metal dispatch for single-encoder forward passes.
//!
//! llama.cpp's speed advantage over candle is NOT the kernels (Phase 0 proved
//! candle's are as fast or faster per-call).  It is the dispatch pattern:
//! 1 encoder per command buffer instead of ~120.  This module implements that
//! pattern.
//!
//! # Usage
//!
//! ```ignore
//! let mut executor = GraphExecutor::new(device.clone());
//! let mut session = executor.begin()?;
//!
//! // All ops encode into the same command buffer — no per-op encoder creation.
//! session.rms_norm(&mut registry, device.metal_device(), input, weight, output, params, rows, dim)?;
//! session.quantized_matmul(&mut registry, &device, input, weight, scales, biases, &qparams)?;
//! session.elementwise_add(&mut registry, device.metal_device(), a, b, out, n, DType::F32)?;
//!
//! // Single GPU sync point for the entire forward pass.
//! session.finish()?;
//! ```
//!
//! # Design
//!
//! The `GraphSession` holds a single `CommandEncoder`.  Each op method delegates
//! to the existing op dispatch functions in [`crate::ops`], passing the session's
//! shared encoder.  No new Metal code is needed — the ops already work with a
//! shared encoder.  The executor just prevents creating a new encoder per op.
//!
//! # Phase 4e.1 — Graph IR
//!
//! The `ComputeGraph` type captures dispatches into a `Vec<CapturedNode>` for
//! later replay.  `GraphExecutor::begin_recorded()` starts a session in capture
//! mode: all op calls are intercepted at the `CommandEncoder` level and recorded
//! instead of being sent to Metal.  `GraphSession::finish()` detects capture
//! mode, extracts the recorded graph, and replays it into a fresh encoder via
//! `ComputeGraph::encode_sequential()`.
//!
//! The existing direct-dispatch path (`begin()`) is completely unchanged.

use metal::foreign_types::ForeignType;

use crate::device::MlxDevice;
use crate::encoder::{CapturedNode, CapturedOpKind, CommandEncoder, MemRange, RecordedBinding};
use crate::error::Result;
use crate::kernel_registry::KernelRegistry;
use crate::ops;

// Re-export types used in the public API so callers don't need separate imports.
pub use crate::buffer::MlxBuffer;
pub use crate::dtypes::DType;

// ---------------------------------------------------------------------------
// OpKind — operation classification for the reorder safety whitelist (4e.3)
// ---------------------------------------------------------------------------

/// Classification of a compute operation for reorder safety analysis.
///
/// Operations marked as reorderable can be freely reordered by the graph
/// optimizer (Phase 4e.3) as long as their data dependencies allow it.
/// Non-reorderable operations have side effects or dependencies that
/// require them to stay in their original sequential position.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OpKind {
    /// Matrix multiplication (reorderable).
    MatMul,
    /// Expert-routed matrix multiplication (reorderable).
    MatMulId,
    /// Normalization — RMS norm, layer norm (reorderable).
    Norm,
    /// Rotary position embedding (reorderable).
    Rope,
    /// Elementwise ops — add, mul, scale, gelu, softcap, etc. (reorderable).
    Elementwise,
    /// Memory copy — KV cache copy, embedding gather (reorderable).
    Copy,
    /// Gather/scatter (reorderable).
    Gather,
    /// Scaled dot-product attention (NOT reorderable).
    Sdpa,
    /// Softmax (NOT reorderable).
    Softmax,
    /// MoE gate with CPU readback dependency (NOT reorderable).
    MoeGate,
    /// Anything else (NOT reorderable).
    Other,
}

impl OpKind {
    /// Whether this op kind is safe to reorder in the graph optimizer.
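    ///
    /// A minimal illustrative check (hedged sketch):
    ///
    /// ```ignore
    /// assert!(OpKind::MatMul.is_reorderable());
    /// assert!(!OpKind::Sdpa.is_reorderable());
    /// ```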
    pub fn is_reorderable(&self) -> bool {
        matches!(
            self,
            Self::MatMul
                | Self::MatMulId
                | Self::Norm
                | Self::Rope
                | Self::Elementwise
                | Self::Copy
                | Self::Gather
        )
    }
}

// ---------------------------------------------------------------------------
// ComputeGraph — the recorded graph IR
// ---------------------------------------------------------------------------

/// A recorded sequence of GPU compute dispatches and barriers.
///
/// Created by running a forward pass with the encoder in capture mode.
/// Can be replayed into a real `CommandEncoder` via `encode_sequential()`,
/// producing identical Metal dispatch behavior to the original direct path.
///
/// Future phases (4e.2, 4e.3) will add fusion and reorder passes that
/// transform the graph before encoding.
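///
/// A minimal lifecycle sketch (hedged; assumes a graph extracted from a
/// recorded session plus a live `registry`, `device`, and `encoder`):
///
/// ```ignore
/// graph.fuse(&mut registry, device.metal_device())?; // Phase 4e.2
/// graph.reorder();                                   // Phase 4e.3
/// let barriers = graph.encode_with_barriers(&mut encoder);
/// ```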
pub struct ComputeGraph {
    nodes: Vec<CapturedNode>,
}

impl ComputeGraph {
    /// Create an empty compute graph.
    pub fn new() -> Self {
        Self {
            nodes: Vec::with_capacity(128),
        }
    }

    /// Create a compute graph from a pre-built list of captured nodes.
    pub fn from_nodes(nodes: Vec<CapturedNode>) -> Self {
        Self { nodes }
    }

    /// Record a captured node into the graph.
    pub fn record(&mut self, node: CapturedNode) {
        self.nodes.push(node);
    }

    /// Number of nodes (dispatches + barriers) in the graph.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Whether the graph contains no nodes.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Number of dispatch nodes (excludes barriers).
    pub fn dispatch_count(&self) -> usize {
        self.nodes
            .iter()
            .filter(|n| matches!(n, CapturedNode::Dispatch { .. }))
            .count()
    }

    /// Number of barrier nodes.
    pub fn barrier_count(&self) -> usize {
        self.nodes
            .iter()
            .filter(|n| matches!(n, CapturedNode::Barrier))
            .count()
    }

    /// Borrow the node list.
    pub fn nodes(&self) -> &[CapturedNode] {
        &self.nodes
    }

    /// Count dispatch nodes that have empty read/write range annotations.
    ///
    /// Used for diagnostics: if >0, the reorder pass cannot guarantee
    /// correctness because it relies on complete annotations.
    pub fn unannotated_dispatch_count(&self) -> usize {
        self.nodes
            .iter()
            .filter(|n| matches!(n, CapturedNode::Dispatch { reads, writes, .. }
                if reads.is_empty() || writes.is_empty()))
            .count()
    }

    /// Take ownership of the node list, consuming the graph.
    pub fn into_nodes(self) -> Vec<CapturedNode> {
        self.nodes
    }

    /// Encode all nodes sequentially into the given encoder.
    ///
    /// Barrier sentinel nodes emit a Metal memory barrier.  Dispatch nodes
    /// are replayed through `CommandEncoder::replay_dispatch()`.
    ///
    /// This produces identical GPU behavior to the direct-dispatch path —
    /// same pipeline bindings, same dispatch dimensions, same barrier
    /// placement.
    ///
    /// Returns the number of barriers emitted.
    pub fn encode_sequential(&self, encoder: &mut CommandEncoder) -> u32 {
        let mut barrier_count = 0u32;
        for node in &self.nodes {
            match node {
                CapturedNode::Barrier => {
                    encoder.memory_barrier();
                    barrier_count += 1;
                }
                CapturedNode::Dispatch {
                    pipeline,
                    bindings,
                    threads_per_grid,
                    threads_per_threadgroup,
                    threadgroup_memory,
                    dispatch_kind,
                    op_kind,
                    ..
                } => {
                    // ADR-015 iter63 (Phase A.3): forward captured op_kind
                    // into replay_dispatch via pending_op_kind so the
                    // per-dispatch profile entries from a recorded graph
                    // (GraphExecutor::begin_recorded path) are tagged with
                    // the same op label as direct-dispatch encoding.
                    encoder.set_op_kind(*op_kind);
                    encoder.replay_dispatch(
                        pipeline,
                        bindings,
                        threadgroup_memory,
                        *threads_per_grid,
                        *threads_per_threadgroup,
                        *dispatch_kind,
                    );
                }
            }
        }
        barrier_count
    }

    /// Encode the graph into a Metal command buffer, computing barriers on the
    /// fly from each node's read/write buffer ranges.
    ///
    /// This is the correct encoding method for reordered graphs where barrier
    /// sentinels have been stripped.  Mirrors llama.cpp's encode-time barrier
    /// insertion via `ggml_metal_op_concurrency_check`.
    ///
    /// Returns the number of barriers emitted.
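    ///
    /// Barrier decision sketch (hedged, illustrative address ranges only):
    ///
    /// ```ignore
    /// // dispatch A writes [0x1000, 0x2000)
    /// // dispatch B reads  [0x1800, 0x1c00)  -> overlaps A's write (RAW): barrier before B
    /// // dispatch C reads  [0x3000, 0x3400)  -> no overlap: joins the concurrent group
    /// ```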
    pub fn encode_with_barriers(&self, encoder: &mut CommandEncoder) -> u32 {
        let mut tracker = ReorderConflictTracker::new();
        let mut barrier_count = 0u32;

        for node in &self.nodes {
            match node {
                CapturedNode::Dispatch {
                    pipeline,
                    bindings,
                    threads_per_grid,
                    threads_per_threadgroup,
                    threadgroup_memory,
                    dispatch_kind,
                    reads,
                    writes,
                    op_kind,
                    ..
                } => {
                    let has_ranges = !reads.is_empty() || !writes.is_empty();
                    if has_ranges && tracker.conflicts(reads, writes) {
                        encoder.memory_barrier();
                        tracker.reset();
                        barrier_count += 1;
                    }
                    if has_ranges {
                        tracker.add(reads, writes);
                    }
                    // ADR-015 iter63 (Phase A.3): see encode_sequential note.
                    encoder.set_op_kind(*op_kind);
                    encoder.replay_dispatch(
                        pipeline,
                        bindings,
                        threadgroup_memory,
                        *threads_per_grid,
                        *threads_per_threadgroup,
                        *dispatch_kind,
                    );
                }
                CapturedNode::Barrier => {
                    // Explicit barriers still force a barrier boundary
                    encoder.memory_barrier();
                    tracker.reset();
                    barrier_count += 1;
                }
            }
        }
        barrier_count
    }

    /// Encode the graph using two command buffers for CPU/GPU overlap.
    ///
    /// The first `n0` dispatches are encoded into `encoder0` and committed
    /// immediately (GPU starts executing).  The remaining dispatches are encoded
    /// into `encoder1`.  The caller is responsible for committing `encoder1`.
    ///
    /// This matches llama.cpp's dual command buffer pattern from
    /// `ggml_metal_graph_compute` (ggml-metal-context.m:441-644):
    /// `n_nodes_0 = MAX(64, 0.1 * n_nodes)` for the first buffer.
    ///
    /// Command buffers submitted to the same `MTLCommandQueue` execute in
    /// submission order, so `encoder0.commit()` followed by `encoder1.commit()`
    /// guarantees enc0 finishes before enc1 starts.  The win: the GPU starts
    /// executing enc0 while the CPU is still encoding enc1.
    ///
    /// Returns `(barriers_buf0, barriers_buf1)`.
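    ///
    /// Hedged usage sketch (assumes both encoders come from the same queue):
    ///
    /// ```ignore
    /// let (b0, b1) = graph.encode_dual_buffer(&mut enc0, &mut enc1);
    /// // enc0 was committed inside the call; the caller commits enc1.
    /// enc1.commit();
    /// ```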
    pub fn encode_dual_buffer(
        &self,
        encoder0: &mut CommandEncoder,
        encoder1: &mut CommandEncoder,
    ) -> (u32, u32) {
        let dispatch_total = self.dispatch_count();
        let n0 = std::cmp::max(64, dispatch_total / 10);

        // Find the split point: the index just past the n0-th dispatch node.
        let split_idx = find_dispatch_split_index(&self.nodes, n0);

        // Encode first chunk with barrier recomputation, then commit immediately.
        let barriers0 = encode_chunk_with_barriers(&self.nodes[..split_idx], encoder0);
        encoder0.commit();

        // Encode second chunk with barrier recomputation.
        let barriers1 = encode_chunk_with_barriers(&self.nodes[split_idx..], encoder1);

        (barriers0, barriers1)
    }

    /// Run the RMS norm + MUL fusion pass over the graph.
    ///
    /// Scans for the pattern:
    ///   Dispatch(RmsNorm) → Barrier(s) → Dispatch(ElemMul)
    /// where the MUL reads the norm's output buffer, and replaces the
    /// sequence with a single fused `rms_norm_mul_*` dispatch.
    ///
    /// The fused dispatch:
    /// - Reads the norm's input (buffer 0) and weight (buffer 1)
    /// - Reads the MUL's second operand as the scale (buffer 2)
    /// - Writes to the MUL's output (buffer 3)
    /// - Carries the norm's params (buffer 4)
    /// - Uses the norm's threadgroup config and shared memory
    ///
    /// Returns the number of fusions applied.
    ///
    /// # Arguments
    ///
    /// * `registry` - Kernel registry for compiling the fused pipeline.
    /// * `device`   - Metal device for pipeline compilation.
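    ///
    /// Pattern sketch (hedged, illustrative buffer names):
    ///
    /// ```ignore
    /// // before: RmsNorm(in, w) -> t0  |  Barrier  |  ElemMul(t0, s) -> t1
    /// // after:  rms_norm_mul_*(in, w, scale = s) -> t1
    /// ```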
    pub fn fuse(
        &mut self,
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
    ) -> Result<u32> {
        let mut result: Vec<CapturedNode> = Vec::with_capacity(self.nodes.len());
        let mut fusions = 0u32;
        let mut i = 0;

        while i < self.nodes.len() {
            // Check if current node is an RMS norm dispatch.
            let is_rms_norm = matches!(
                &self.nodes[i],
                CapturedNode::Dispatch { op_kind: CapturedOpKind::RmsNorm, .. }
            );

            if !is_rms_norm {
                result.push(self.nodes[i].clone());
                i += 1;
                continue;
            }

            // Look ahead: skip barriers, then check for ElemMul.
            let mut j = i + 1;
            let mut barrier_count = 0usize;
            while j < self.nodes.len() && matches!(&self.nodes[j], CapturedNode::Barrier) {
                barrier_count += 1;
                j += 1;
            }

            // Must have at least one barrier and the next node must be ElemMul.
            if barrier_count == 0 || j >= self.nodes.len() {
                result.push(self.nodes[i].clone());
                i += 1;
                continue;
            }

            let is_elem_mul = matches!(
                &self.nodes[j],
                CapturedNode::Dispatch { op_kind: CapturedOpKind::ElemMul, .. }
            );

            if !is_elem_mul {
                result.push(self.nodes[i].clone());
                i += 1;
                continue;
            }

            // Extract norm and mul dispatch fields.
            let (norm_pipeline, norm_bindings, norm_tpg, norm_tptg, norm_tgmem, norm_dk) =
                match &self.nodes[i] {
                    CapturedNode::Dispatch {
                        pipeline,
                        bindings,
                        threads_per_grid,
                        threads_per_threadgroup,
                        threadgroup_memory,
                        dispatch_kind,
                        ..
                    } => (pipeline, bindings, threads_per_grid, threads_per_threadgroup, threadgroup_memory, dispatch_kind),
                    _ => unreachable!(),
                };

            let (mul_bindings, _mul_tpg, _mul_tptg) = match &self.nodes[j] {
                CapturedNode::Dispatch {
                    bindings,
                    threads_per_grid,
                    threads_per_threadgroup,
                    ..
                } => (bindings, threads_per_grid, threads_per_threadgroup),
                _ => unreachable!(),
            };

            // Verify data dependency: the norm's output buffer (slot 2) must
            // appear as one of the MUL's input buffers (slot 0 or 1).
            //
            // Norm binding layout: (0=input, 1=weight, 2=output, 3=params)
            // MUL binding layout:  (0=a, 1=b, 2=output, 3=params_bytes)
            let norm_output_ptr = Self::buffer_ptr_for_slot(norm_bindings, 2);
            let mul_a_ptr = Self::buffer_ptr_for_slot(mul_bindings, 0);
            let mul_b_ptr = Self::buffer_ptr_for_slot(mul_bindings, 1);

            if norm_output_ptr.is_none() || (norm_output_ptr != mul_a_ptr && norm_output_ptr != mul_b_ptr) {
                // Data dependency not confirmed — don't fuse.
                result.push(self.nodes[i].clone());
                i += 1;
                continue;
            }

            // Determine which MUL input is the scale (the one that is NOT
            // the norm's output).
            let scale_slot = if norm_output_ptr == mul_a_ptr { 1 } else { 0 };

            // Build fused bindings:
            //   0 = norm input
            //   1 = norm weight
            //   2 = scale (from MUL)
            //   3 = MUL output
            //   4 = norm params
            // Gather all required bindings; bail if any are missing.
            let (norm_input, norm_weight, scale, mul_output, norm_params) = match (
                Self::get_binding(norm_bindings, 0),
                Self::get_binding(norm_bindings, 1),
                Self::get_binding(mul_bindings, scale_slot),
                Self::get_binding(mul_bindings, 2),
                Self::get_binding(norm_bindings, 3),
            ) {
                (Some(a), Some(b), Some(c), Some(d), Some(e)) => (a, b, c, d, e),
                _ => {
                    // Missing bindings — don't fuse.
                    result.push(self.nodes[i].clone());
                    i += 1;
                    continue;
                }
            };

            // Select fused pipeline based on the original norm pipeline name.
            // The norm pipeline name is "rms_norm_f32", "rms_norm_f16", or
            // "rms_norm_bf16" — we need the corresponding fused pipeline.
            let fused_name = match Self::fused_pipeline_name(norm_pipeline) {
                Some(name) => name,
                None => {
                    result.push(self.nodes[i].clone());
                    i += 1;
                    continue;
                }
            };

            let fused_pipeline = registry.get_pipeline(fused_name, device)?;

            let fused_bindings = vec![
                (0, norm_input),
                (1, norm_weight),
                (2, scale),
                (3, mul_output),
                (4, norm_params),
            ];

            // Merge read/write ranges from both the norm and mul nodes for the
            // fused dispatch.  The fused op reads everything the norm reads
            // plus the mul's scale input, and writes to the mul's output.
            let (fused_reads, fused_writes) = match (&self.nodes[i], &self.nodes[j]) {
                (
                    CapturedNode::Dispatch { reads: nr, writes: _nw, .. },
                    CapturedNode::Dispatch { reads: mr, writes: mw, .. },
                ) => {
                    let mut reads = nr.clone();
                    reads.extend_from_slice(mr);
                    (reads, mw.clone())
                }
                _ => (Vec::new(), Vec::new()),
            };

            result.push(CapturedNode::Dispatch {
                pipeline: fused_pipeline.to_owned(),
                bindings: fused_bindings,
                threads_per_grid: *norm_tpg,
                threads_per_threadgroup: *norm_tptg,
                threadgroup_memory: norm_tgmem.clone(),
                dispatch_kind: *norm_dk,
                op_kind: CapturedOpKind::Other, // Fused ops are not further fuseable
                reads: fused_reads,
                writes: fused_writes,
            });

            fusions += 1;
            // Skip past the norm, barrier(s), and mul nodes.
            i = j + 1;
        }

        self.nodes = result;
        Ok(fusions)
    }

    /// Run the reorder pass over the graph to improve GPU concurrency.
    ///
    /// Port of llama.cpp's `ggml_metal_graph_optimize_reorder` — a greedy
    /// 64-node lookahead that pulls independent dispatches forward to fill
    /// larger concurrent groups between barriers.
    ///
    /// **Prerequisites:** Call `fuse()` first if desired.  The reorder pass
    /// operates on the post-fusion graph.  Barrier sentinel nodes are stripped
    /// before reordering (they are recomputed at encode time by the
    /// `ReorderConflictTracker` in `encode_with_barriers`).
    ///
    /// **Algorithm (matching llama.cpp exactly):**
    /// 1. Strip all `CapturedNode::Barrier` nodes.
    /// 2. For each unprocessed node `i0`:
    ///    - If it conflicts with the current concurrent group (`mrs0`):
    ///      * Initialize `mrs1` from `i0`'s ranges (skipped-over set)
    ///      * Lookahead up to 64 nodes for candidates that:
    ///        (a) Are reorderable (`CapturedOpKind::is_reorderable()`)
    ///        (b) Don't conflict with `mrs0` (current group)
    ///        (c) Don't conflict with `mrs1` (skipped-over nodes)
    ///      * Pull qualifying candidates into the current group
    ///      * Non-reorderable ops break the lookahead
    ///    - Reset `mrs0` (new concurrent group)
    ///    - Add `i0` to the new group
    ///
    /// Returns the number of nodes whose final position differs from their
    /// original index (see the sketch below).
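    ///
    /// Tiny worked sketch (hedged, illustrative dependencies only):
    ///
    /// ```ignore
    /// // A writes t0; B reads t0; C is independent of both.
    /// // B conflicts with the group {A}, so the lookahead runs and pulls C
    /// // forward: final order A, C, B.  {A, C} form one concurrent group;
    /// // the barrier before B is recomputed at encode time.
    /// ```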
    pub fn reorder(&mut self) -> u32 {
        // Step 1: Strip barrier nodes.  After fusion + reorder, barriers are
        // recomputed at encode time by the ReorderConflictTracker in
        // encode_with_barriers().
        self.nodes.retain(|n| !matches!(n, CapturedNode::Barrier));

        let n = self.nodes.len();
        if n == 0 {
            return 0;
        }

        let mut result: Vec<usize> = Vec::with_capacity(n);
        let mut used = vec![false; n];

        // mrs0: memory ranges for the current concurrent group
        let mut mrs0 = ReorderConflictTracker::new();
        // mrs1: memory ranges for skipped-over (unprocessed) nodes
        let mut mrs1 = ReorderConflictTracker::new();

        const N_FORWARD: usize = 64;

        for i0 in 0..n {
            if used[i0] {
                continue;
            }

            let node0 = &self.nodes[i0];

            // Extract reads/writes for the conflict check.  node0's own
            // op_kind is not needed: only lookahead candidates are gated
            // on reorderability.
            let (reads0, writes0) = match node0 {
                CapturedNode::Dispatch { reads, writes, .. } => {
                    (reads.as_slice(), writes.as_slice())
                }
                CapturedNode::Barrier => continue, // stripped, but be safe
            };
            // Check if node0 conflicts with the current concurrent group.
            // Empty nodes (no ranges) never conflict — like llama.cpp's is_empty.
            let has_ranges = !reads0.is_empty() || !writes0.is_empty();
            if has_ranges && mrs0.conflicts(reads0, writes0) {
                // Before starting a new group, look forward for nodes that
                // can be pulled into the CURRENT group.
                mrs1.reset();
                mrs1.add(reads0, writes0);

                let end = (i0 + N_FORWARD).min(n);
                for i1 in (i0 + 1)..end {
                    if used[i1] {
                        continue;
                    }

                    let node1 = &self.nodes[i1];
                    let (reads1, writes1, op_kind1) = match node1 {
                        CapturedNode::Dispatch { reads, writes, op_kind, .. } => {
                            (reads.as_slice(), writes.as_slice(), *op_kind)
                        }
                        CapturedNode::Barrier => continue,
                    };

                    // Non-reorderable ops break the lookahead.
                    if !op_kind1.is_reorderable() {
                        break;
                    }

                    let is_empty1 = reads1.is_empty() && writes1.is_empty();

                    // A node can be reordered into the current group if:
                    // 1. It's empty (no ranges) OR doesn't conflict with mrs0
                    // 2. It doesn't conflict with mrs1 (skipped-over nodes)
                    if (is_empty1 || !mrs0.conflicts(reads1, writes1))
                        && !mrs1.conflicts(reads1, writes1)
                    {
                        // Pull into current concurrent group.
                        mrs0.add(reads1, writes1);
                        result.push(i1);
                        used[i1] = true;
                    } else {
                        // Not eligible — expand the skipped-over set.
                        mrs1.add(reads1, writes1);
                    }
                }

                // Finalize the current concurrent group.
                mrs0.reset();
            }

            // Expand the concurrent group with node0.
            // (Barriers were stripped, so this is always a Dispatch.)
            mrs0.add(reads0, writes0);
            result.push(i0);
        }

        // Apply the permutation to produce the reordered node list.
        let mut reordered_count = 0u32;
        for (pos, &orig_idx) in result.iter().enumerate() {
            if orig_idx != pos {
                reordered_count += 1;
            }
        }

        // Build the reordered nodes vec.
        let old_nodes = std::mem::take(&mut self.nodes);
        self.nodes = result.iter().map(|&idx| old_nodes[idx].clone()).collect();

        // Debug dump if requested.
        if std::env::var("HF2Q_REORDER_DUMP").is_ok() {
            eprintln!(
                "  [REORDER] nodes={} reordered={} ({:.1}%)",
                n,
                reordered_count,
                100.0 * reordered_count as f64 / n as f64,
            );
        }

        reordered_count
    }

    /// Get the Metal buffer pointer for a binding at the given slot index.
    ///
    /// Returns `Some(ptr)` if the slot has a `RecordedBinding::Buffer`,
    /// `None` otherwise.
    fn buffer_ptr_for_slot(bindings: &[(u64, RecordedBinding)], slot: u64) -> Option<*const std::ffi::c_void> {
        for (idx, binding) in bindings {
            if *idx == slot {
                if let RecordedBinding::Buffer { metal_buffer, offset: _ } = binding {
                    // Use the MTLBuffer object pointer as the identity key:
                    // two bindings refer to the same allocation iff they
                    // wrap the same MTLBuffer object.
                    let ptr: *const std::ffi::c_void = metal_buffer.as_ptr() as *const _;
                    return Some(ptr);
                }
            }
        }
        None
    }

    /// Clone the binding at the given slot index.
    fn get_binding(bindings: &[(u64, RecordedBinding)], slot: u64) -> Option<RecordedBinding> {
        for (idx, binding) in bindings {
            if *idx == slot {
                return Some(binding.clone());
            }
        }
        None
    }

    /// Map a norm pipeline to its fused norm+mul pipeline name.
    ///
    /// The pipeline's `label()` is set to the kernel function name at
    /// pipeline creation, so we can match on it.  Returns `None` if the
    /// pipeline is not a known norm.
    fn fused_pipeline_name(pipeline: &metal::ComputePipelineState) -> Option<&'static str> {
        match pipeline.label() {
            "rms_norm_f32" => Some("rms_norm_mul_f32"),
            "rms_norm_f16" => Some("rms_norm_mul_f16"),
            "rms_norm_bf16" => Some("rms_norm_mul_bf16"),
            _ => None,
        }
    }
}

impl Default for ComputeGraph {
    fn default() -> Self {
        Self::new()
    }
}

// ---------------------------------------------------------------------------
// Dual-buffer encoding helpers (Phase 4e.4)
// ---------------------------------------------------------------------------

/// Find the node index where the second chunk starts.
///
/// Counts `CapturedNode::Dispatch` nodes until `n0` are reached, then returns
/// the index just past the n0-th dispatch (i.e., the first node of the second
/// chunk).  If fewer than `n0` dispatches exist, returns `nodes.len()`
/// (everything goes in chunk 0).
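///
/// Hedged example: for `[Dispatch, Dispatch, Barrier, Dispatch]` and `n0 = 2`,
/// the function returns `2`, so chunk 0 holds the first two dispatches while
/// the barrier and the final dispatch land in chunk 1.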
fn find_dispatch_split_index(nodes: &[CapturedNode], n0: usize) -> usize {
    let mut dispatches_seen = 0usize;
    for (i, node) in nodes.iter().enumerate() {
        if matches!(node, CapturedNode::Dispatch { .. }) {
            dispatches_seen += 1;
            if dispatches_seen == n0 {
                return i + 1; // split AFTER the n0-th dispatch
            }
        }
    }
    nodes.len()
}

/// Encode a slice of captured nodes into a command encoder, recomputing
/// barriers on the fly from each node's read/write buffer ranges.
///
/// This is the chunked counterpart of `ComputeGraph::encode_with_barriers()`.
/// Factored out so both halves of a dual-buffer encode can use it.
///
/// Returns the number of barriers emitted.
fn encode_chunk_with_barriers(nodes: &[CapturedNode], encoder: &mut CommandEncoder) -> u32 {
    let mut tracker = ReorderConflictTracker::new();
    let mut barrier_count = 0u32;

    for node in nodes {
        match node {
            CapturedNode::Dispatch {
                pipeline,
                bindings,
                threads_per_grid,
                threads_per_threadgroup,
                threadgroup_memory,
                dispatch_kind,
                reads,
                writes,
                op_kind,
                ..
            } => {
                let has_ranges = !reads.is_empty() || !writes.is_empty();
                if has_ranges && tracker.conflicts(reads, writes) {
                    encoder.memory_barrier();
                    tracker.reset();
                    barrier_count += 1;
                }
                if has_ranges {
                    tracker.add(reads, writes);
                }
                // ADR-015 iter63 (Phase A.3): forward captured op_kind so the
                // per-dispatch profile dump groups dual-buffer chunked replay
                // entries by the same op_kind label as direct dispatch.
                encoder.set_op_kind(*op_kind);
                encoder.replay_dispatch(
                    pipeline,
                    bindings,
                    threadgroup_memory,
                    *threads_per_grid,
                    *threads_per_threadgroup,
                    *dispatch_kind,
                );
            }
            CapturedNode::Barrier => {
                encoder.memory_barrier();
                tracker.reset();
                barrier_count += 1;
            }
        }
    }
    barrier_count
}

// ---------------------------------------------------------------------------
// ReorderConflictTracker — range-based conflict detection for the reorder pass
// ---------------------------------------------------------------------------

/// Memory range conflict tracker for the reorder pass (Phase 4e.3).
///
/// Works with `MemRange` tuples `(start, end)` stored on `CapturedNode::Dispatch`,
/// rather than requiring live `&MlxBuffer` references.  This is the reorder-time
/// equivalent of the runtime `ConflictTracker`.
///
/// Conflict rules match llama.cpp's `ggml_mem_ranges_check`:
/// - Two read ranges: OK (read-read is concurrent-safe)
/// - A new read overlapping an existing write: CONFLICT (RAW)
/// - A new write overlapping any existing range: CONFLICT (WAR/WAW)
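///
/// Minimal sketch of the rules (hedged, illustrative ranges):
///
/// ```ignore
/// let mut t = ReorderConflictTracker::new();
/// t.add(&[], &[(0x1000, 0x2000)]);                 // record a write
/// assert!(t.conflicts(&[(0x1800, 0x1c00)], &[]));  // overlapping read: RAW
/// assert!(t.conflicts(&[], &[(0x1f00, 0x2100)]));  // overlapping write: WAW
/// assert!(!t.conflicts(&[(0x3000, 0x3400)], &[])); // disjoint read: OK
/// ```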
struct ReorderConflictTracker {
    /// (start, end, is_write) for all ranges in the tracked set.
    ranges: Vec<(usize, usize, bool)>,
}

impl ReorderConflictTracker {
    fn new() -> Self {
        Self {
            ranges: Vec::with_capacity(64),
        }
    }

    fn reset(&mut self) {
        self.ranges.clear();
    }

    /// Check if a dispatch with the given read/write ranges conflicts with
    /// any range in this tracker.
    fn conflicts(&self, reads: &[MemRange], writes: &[MemRange]) -> bool {
        // New reads vs existing writes (RAW)
        for &(r_start, r_end) in reads {
            for &(s, e, is_write) in &self.ranges {
                if is_write && r_start < e && r_end > s {
                    return true;
                }
            }
        }
        // New writes vs all existing ranges (WAR/WAW)
        for &(w_start, w_end) in writes {
            for &(s, e, _) in &self.ranges {
                if w_start < e && w_end > s {
                    return true;
                }
            }
        }
        false
    }

    /// Add read and write ranges to the tracked set.
    fn add(&mut self, reads: &[MemRange], writes: &[MemRange]) {
        for &(start, end) in reads {
            self.ranges.push((start, end, false));
        }
        for &(start, end) in writes {
            self.ranges.push((start, end, true));
        }
    }
}

/// Batched Metal dispatch — encodes multiple ops into a single `CommandEncoder`.
///
/// Create one per model (or per forward-pass loop).  Call [`begin`](Self::begin)
/// at the start of each forward pass to get a [`GraphSession`] that holds the
/// shared encoder.
pub struct GraphExecutor {
    device: MlxDevice,
}

impl GraphExecutor {
    /// Create a new graph executor backed by the given device.
    pub fn new(device: MlxDevice) -> Self {
        Self { device }
    }

    /// Begin a new forward pass (direct-dispatch mode).
    ///
    /// Returns a [`GraphSession`] that holds a fresh `CommandEncoder`.  All ops
    /// encoded through the session share this single encoder.  Call
    /// [`GraphSession::finish`] to commit and wait.
    pub fn begin(&self) -> Result<GraphSession<'_>> {
        let encoder = self.device.command_encoder()?;
        // ADR-015 iter63 (Phase B): start a programmatic capture if
        // MLX_METAL_CAPTURE+METAL_CAPTURE_ENABLED are set.  No-op when
        // unset; one-shot per process so subsequent forward passes do
        // not pay the env-check cost more than once.
        let metal_capture = {
            let mut c = crate::metal_capture::MetalCapture::from_env(&self.device);
            if let Some(ref mut cap) = c {
                cap.begin();
            }
            c
        };
        Ok(GraphSession {
            encoder,
            device: &self.device,
            barrier_count: 0,
            tracker: ConflictTracker::new(),
            dispatch_in_group: 0,
            total_dispatches: 0,
            group_sizes: [0; 8],
            recording: false,
            metal_capture,
        })
    }

    /// Begin a new forward pass in capture (record) mode.
    ///
    /// All op calls are recorded into a `ComputeGraph` instead of being
    /// dispatched to Metal.  When [`GraphSession::finish`] is called, the
    /// recorded graph is replayed into a fresh encoder via
    /// `ComputeGraph::encode_sequential()`.
    ///
    /// The API is identical to `begin()` — callers do not need to change
    /// any op call code.  The only behavioral difference: GPU work happens
    /// at `finish()` time rather than at each op call.
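    ///
    /// Hedged sketch (same call shape as `begin()`):
    ///
    /// ```ignore
    /// let mut session = executor.begin_recorded()?;
    /// session.rms_norm(/* same arguments as the direct path */)?;
    /// session.finish()?; // replay + the single GPU sync happen here
    /// ```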
    pub fn begin_recorded(&self) -> Result<GraphSession<'_>> {
        let mut encoder = self.device.command_encoder()?;
        encoder.start_capture();
        // ADR-015 iter63 (Phase B): see GraphExecutor::begin().
        let metal_capture = {
            let mut c = crate::metal_capture::MetalCapture::from_env(&self.device);
            if let Some(ref mut cap) = c {
                cap.begin();
            }
            c
        };
        Ok(GraphSession {
            encoder,
            device: &self.device,
            barrier_count: 0,
            tracker: ConflictTracker::new(),
            dispatch_in_group: 0,
            total_dispatches: 0,
            group_sizes: [0; 8],
            recording: true,
            metal_capture,
        })
    }

    /// Borrow the underlying device.
    pub fn device(&self) -> &MlxDevice {
        &self.device
    }
}

/// Tracks buffer address ranges for automatic barrier elision.
///
/// Mirrors llama.cpp's `ggml_mem_ranges` — accumulates the read and write
/// ranges of all dispatches in the current concurrent group. When a new
/// dispatch's reads overlap with an existing write (RAW), or its writes
/// overlap with an existing read or write (WAR/WAW), a barrier is needed.
/// Otherwise the dispatch can run concurrently and the barrier is elided.
///
/// Uses CPU-visible `contents_ptr()` addresses, which on Apple Silicon
/// unified memory equal the GPU addresses.
pub struct ConflictTracker {
    /// (start, end, is_write) tuples for the current concurrent group.
    ranges: Vec<(usize, usize, bool)>,
}

impl ConflictTracker {
    fn new() -> Self {
        Self {
            ranges: Vec::with_capacity(32),
        }
    }

    /// Reset the tracker — called after emitting a barrier.
    fn reset(&mut self) {
        self.ranges.clear();
    }

    /// Check whether a new dispatch with the given reads and writes conflicts
    /// with the current concurrent group, returning the reason if so.
    ///
    /// Conflict rules (same as llama.cpp `ggml_mem_ranges_check`):
    /// - Two SRC (read) ranges in the same buffer: OK (read-read)
    /// - A new SRC overlapping an existing DST: CONFLICT (RAW)
    /// - A new DST overlapping an existing SRC or DST: CONFLICT (WAR/WAW)
    ///
    /// Returns `(conflict_type, new_range_start, existing_range_start)`, or
    /// `None` if there is no conflict.
    fn conflicts_reason(&self, reads: &[&MlxBuffer], writes: &[&MlxBuffer])
        -> Option<(&'static str, usize, usize)>
    {
        // Check new reads against existing writes (RAW)
        for r in reads {
            let r_start = r.contents_ptr() as usize;
            let r_end = r_start + r.byte_len();
            for &(s, e, is_write) in &self.ranges {
                if is_write && r_start < e && r_end > s {
                    return Some(("RAW", r_start, s));
                }
            }
        }
        // Check new writes against existing reads and writes (WAR/WAW)
        for w in writes {
            let w_start = w.contents_ptr() as usize;
            let w_end = w_start + w.byte_len();
            for &(s, e, is_write) in &self.ranges {
                if w_start < e && w_end > s {
                    let kind = if is_write { "WAW" } else { "WAR" };
                    return Some((kind, w_start, s));
                }
            }
        }
        None
    }

    /// Add read and write ranges to the current concurrent group.
    fn add(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
        for r in reads {
            let start = r.contents_ptr() as usize;
            let end = start + r.byte_len();
            self.ranges.push((start, end, false));
        }
        for w in writes {
            let start = w.contents_ptr() as usize;
            let end = start + w.byte_len();
            self.ranges.push((start, end, true));
        }
    }
}

/// A single forward pass execution context.
///
/// All ops are encoded into one `CommandEncoder`.  Call [`finish`](Self::finish)
/// to commit the command buffer and wait for GPU completion — this is the ONLY
/// sync point per forward pass.
///
/// If an op returns an error, the session can be dropped without committing.
/// The underlying command buffer is abandoned (never committed to the GPU).
pub struct GraphSession<'a> {
    encoder: CommandEncoder,
    device: &'a MlxDevice,
    barrier_count: u32,
    tracker: ConflictTracker,
    dispatch_in_group: u32,
    total_dispatches: u32,
    /// Histogram: group_sizes[i] = number of concurrent groups with (i+1)
    /// dispatches; the last bucket also absorbs larger groups.
    group_sizes: [u32; 8],
    /// Whether this session was created in capture/record mode.
    recording: bool,
    /// ADR-015 iter63 (Phase B): optional Metal frame capture wrapping
    /// this session's GPU work.  Populated by `MetalCapture::from_env`
    /// when `MLX_METAL_CAPTURE=<path>` + `METAL_CAPTURE_ENABLED=1` are
    /// both set in the process env AND the process-global one-shot
    /// latch has not yet flipped.  `None` in all other cases (default
    /// production path).  Capture scope is begun in
    /// [`GraphExecutor::begin`] / [`GraphExecutor::begin_recorded`]
    /// and ended in [`Self::finish`] / [`Self::commit`] BEFORE the
    /// CB is committed, so all enqueued CBs land inside the trace.
    metal_capture: Option<crate::metal_capture::MetalCapture>,
}
1051
1052impl<'a> GraphSession<'a> {
1053    /// Encode an RMS normalization into this session's encoder.
1054    ///
1055    /// Delegates to [`ops::rms_norm::dispatch_rms_norm`].
1056    pub fn rms_norm(
1057        &mut self,
1058        registry: &mut KernelRegistry,
1059        device: &metal::DeviceRef,
1060        input: &MlxBuffer,
1061        weight: &MlxBuffer,
1062        output: &MlxBuffer,
1063        params_buf: &MlxBuffer,
1064        rows: u32,
1065        dim: u32,
1066    ) -> Result<()> {
1067        ops::rms_norm::dispatch_rms_norm(
1068            &mut self.encoder,
1069            registry,
1070            device,
1071            input,
1072            weight,
1073            output,
1074            params_buf,
1075            rows,
1076            dim,
1077        )
1078    }
1079
1080    /// Encode a quantized matrix multiplication into this session's encoder.
1081    ///
1082    /// Delegates to [`ops::quantized_matmul::quantized_matmul`].
1083    /// Returns the freshly allocated output buffer.
1084    pub fn quantized_matmul(
1085        &mut self,
1086        registry: &mut KernelRegistry,
1087        device: &MlxDevice,
1088        input: &MlxBuffer,
1089        weight: &MlxBuffer,
1090        scales: &MlxBuffer,
1091        biases: &MlxBuffer,
1092        params: &ops::quantized_matmul::QuantizedMatmulParams,
1093    ) -> Result<MlxBuffer> {
1094        ops::quantized_matmul::quantized_matmul(
1095            &mut self.encoder,
1096            registry,
1097            device,
1098            input,
1099            weight,
1100            scales,
1101            biases,
1102            params,
1103        )
1104    }
1105
1106    /// Encode a SIMD-optimized quantized matmul into this session's encoder.
1107    ///
1108    /// Delegates to [`ops::quantized_matmul::quantized_matmul_simd`].
1109    /// Returns the freshly allocated output buffer.
1110    pub fn quantized_matmul_simd(
1111        &mut self,
1112        registry: &mut KernelRegistry,
1113        device: &MlxDevice,
1114        input: &MlxBuffer,
1115        weight: &MlxBuffer,
1116        scales: &MlxBuffer,
1117        biases: &MlxBuffer,
1118        params: &ops::quantized_matmul::QuantizedMatmulParams,
1119    ) -> Result<MlxBuffer> {
1120        ops::quantized_matmul::quantized_matmul_simd(
1121            &mut self.encoder,
1122            registry,
1123            device,
1124            input,
1125            weight,
1126            scales,
1127            biases,
1128            params,
1129        )
1130    }
1131
1132    /// Encode a GGML block-format quantized mat-vec into this session's encoder.
1133    ///
1134    /// Delegates to [`ops::quantized_matmul_ggml::quantized_matmul_ggml`].
1135    pub fn quantized_matmul_ggml(
1136        &mut self,
1137        registry: &mut KernelRegistry,
1138        device: &MlxDevice,
1139        input: &MlxBuffer,
1140        weight: &MlxBuffer,
1141        output: &MlxBuffer,
1142        params: &ops::quantized_matmul_ggml::GgmlQuantizedMatmulParams,
1143    ) -> Result<()> {
1144        ops::quantized_matmul_ggml::quantized_matmul_ggml(
1145            &mut self.encoder,
1146            registry,
1147            device,
1148            input,
1149            weight,
1150            output,
1151            params,
1152        )
1153    }
1154
1155    /// Encode an expert-routed GGML block-format quantized mat-vec into this session's encoder.
1156    ///
1157    /// Delegates to [`ops::quantized_matmul_id_ggml::quantized_matmul_id_ggml`].
1158    #[allow(clippy::too_many_arguments)]
1159    pub fn quantized_matmul_id_ggml(
1160        &mut self,
1161        registry: &mut KernelRegistry,
1162        device: &MlxDevice,
1163        input: &MlxBuffer,
1164        weight: &MlxBuffer,
1165        ids: &MlxBuffer,
1166        output: &MlxBuffer,
1167        params: &ops::quantized_matmul_id_ggml::GgmlQuantizedMatmulIdParams,
1168    ) -> Result<()> {
1169        ops::quantized_matmul_id_ggml::quantized_matmul_id_ggml(
1170            &mut self.encoder,
1171            registry,
1172            device,
1173            input,
1174            weight,
1175            ids,
1176            output,
1177            params,
1178        )
1179    }
1180
1181    /// Pooled-scratch variant of [`Self::quantized_matmul_id_ggml`] — the
1182    /// `IdMmScratch` is caller-owned so batched prefill amortises the
1183    /// per-call allocations that the auto entry point incurs (ADR-011
1184    /// Phase 3 Wave P3b).
1185    #[allow(clippy::too_many_arguments)]
1186    pub fn quantized_matmul_id_ggml_pooled(
1187        &mut self,
1188        registry: &mut KernelRegistry,
1189        device: &MlxDevice,
1190        input: &MlxBuffer,
1191        weight: &MlxBuffer,
1192        ids: &MlxBuffer,
1193        output: &MlxBuffer,
1194        scratch: &mut ops::quantized_matmul_id_ggml::IdMmScratch,
1195        params: &ops::quantized_matmul_id_ggml::GgmlQuantizedMatmulIdParams,
1196    ) -> Result<()> {
1197        ops::quantized_matmul_id_ggml::quantized_matmul_id_ggml_pooled(
1198            &mut self.encoder,
1199            registry,
1200            device,
1201            input,
1202            weight,
1203            ids,
1204            output,
1205            scratch,
1206            params,
1207        )
1208    }
1209
1210    /// Encode scaled dot-product attention into this session's encoder.
1211    ///
1212    /// Delegates to [`ops::sdpa::sdpa`].
1213    pub fn sdpa(
1214        &mut self,
1215        registry: &mut KernelRegistry,
1216        device: &MlxDevice,
1217        q: &MlxBuffer,
1218        k: &MlxBuffer,
1219        v: &MlxBuffer,
1220        output: &MlxBuffer,
1221        params: &ops::sdpa::SdpaParams,
1222        batch_size: u32,
1223    ) -> Result<()> {
1224        ops::sdpa::sdpa(
1225            &mut self.encoder,
1226            registry,
1227            device,
1228            q,
1229            k,
1230            v,
1231            output,
1232            params,
1233            batch_size,
1234        )
1235    }
1236
1237    /// Encode flash attention vector (SIMD-vectorized decode-path SDPA).
1238    ///
1239    /// Delegates to [`ops::flash_attn_vec::flash_attn_vec`].
1240    pub fn flash_attn_vec(
1241        &mut self,
1242        registry: &mut KernelRegistry,
1243        device: &MlxDevice,
1244        q: &MlxBuffer,
1245        k: &MlxBuffer,
1246        v: &MlxBuffer,
1247        output: &MlxBuffer,
1248        tmp: &MlxBuffer,
1249        params: &ops::flash_attn_vec::FlashAttnVecParams,
1250    ) -> Result<()> {
1251        ops::flash_attn_vec::flash_attn_vec(
1252            &mut self.encoder,
1253            registry,
1254            device,
1255            q,
1256            k,
1257            v,
1258            output,
1259            tmp,
1260            params,
1261        )
1262    }
1263
1264    /// Encode an elementwise add into this session's encoder.
1265    ///
1266    /// Delegates to [`ops::elementwise::elementwise_add`].
1267    pub fn elementwise_add(
1268        &mut self,
1269        registry: &mut KernelRegistry,
1270        device: &metal::DeviceRef,
1271        a: &MlxBuffer,
1272        b: &MlxBuffer,
1273        output: &MlxBuffer,
1274        n_elements: usize,
1275        dtype: DType,
1276    ) -> Result<()> {
1277        ops::elementwise::elementwise_add(
1278            &mut self.encoder,
1279            registry,
1280            device,
1281            a,
1282            b,
1283            output,
1284            n_elements,
1285            dtype,
1286        )
1287    }
1288
1289    /// Encode an elementwise multiply into this session's encoder.
1290    ///
1291    /// Delegates to [`ops::elementwise::elementwise_mul`].
1292    pub fn elementwise_mul(
1293        &mut self,
1294        registry: &mut KernelRegistry,
1295        device: &metal::DeviceRef,
1296        a: &MlxBuffer,
1297        b: &MlxBuffer,
1298        output: &MlxBuffer,
1299        n_elements: usize,
1300        dtype: DType,
1301    ) -> Result<()> {
1302        ops::elementwise::elementwise_mul(
1303            &mut self.encoder,
1304            registry,
1305            device,
1306            a,
1307            b,
1308            output,
1309            n_elements,
1310            dtype,
1311        )
1312    }
1313
1314    /// Encode a RoPE transform into this session's encoder.
1315    ///
1316    /// Delegates to [`ops::rope::dispatch_rope`].
1317    pub fn rope(
1318        &mut self,
1319        registry: &mut KernelRegistry,
1320        device: &metal::DeviceRef,
1321        input: &MlxBuffer,
1322        output: &MlxBuffer,
1323        params_buf: &MlxBuffer,
1324        positions_buf: &MlxBuffer,
1325        seq_len: u32,
1326        head_dim: u32,
1327    ) -> Result<()> {
1328        ops::rope::dispatch_rope(
1329            &mut self.encoder,
1330            registry,
1331            device,
1332            input,
1333            output,
1334            params_buf,
1335            positions_buf,
1336            seq_len,
1337            head_dim,
1338        )
1339    }
1340
1341    /// Encode a GELU activation into this session's encoder.
1342    ///
1343    /// Delegates to [`ops::gelu::dispatch_gelu`].
1344    pub fn gelu(
1345        &mut self,
1346        registry: &mut KernelRegistry,
1347        device: &metal::DeviceRef,
1348        input: &MlxBuffer,
1349        output: &MlxBuffer,
1350    ) -> Result<()> {
1351        ops::gelu::dispatch_gelu(
1352            &mut self.encoder,
1353            registry,
1354            device,
1355            input,
1356            output,
1357        )
1358    }
1359
1360    /// Encode a softmax into this session's encoder.
1361    ///
1362    /// Delegates to [`ops::softmax::dispatch_softmax`].
1363    pub fn softmax(
1364        &mut self,
1365        registry: &mut KernelRegistry,
1366        device: &metal::DeviceRef,
1367        input: &MlxBuffer,
1368        output: &MlxBuffer,
1369        params_buf: &MlxBuffer,
1370        rows: u32,
1371        cols: u32,
1372    ) -> Result<()> {
1373        ops::softmax::dispatch_softmax(
1374            &mut self.encoder,
1375            registry,
1376            device,
1377            input,
1378            output,
1379            params_buf,
1380            rows,
1381            cols,
1382        )
1383    }
1384
1385    /// Encode a softcap into this session's encoder.
1386    ///
1387    /// Delegates to [`ops::softcap::dispatch_softcap`].
1388    pub fn softcap(
1389        &mut self,
1390        registry: &mut KernelRegistry,
1391        device: &metal::DeviceRef,
1392        input: &MlxBuffer,
1393        output: &MlxBuffer,
1394        params_buf: &MlxBuffer,
1395        cap: f32,
1396    ) -> Result<()> {
1397        ops::softcap::dispatch_softcap(
1398            &mut self.encoder,
1399            registry,
1400            device,
1401            input,
1402            output,
1403            params_buf,
1404            cap,
1405        )
1406    }
1407
1408    /// Encode an RMS norm without learned scale (f32) into this session's encoder.
1409    ///
1410    /// Delegates to [`ops::rms_norm::dispatch_rms_norm_no_scale_f32`].
1411    pub fn rms_norm_no_scale_f32(
1412        &mut self,
1413        registry: &mut KernelRegistry,
1414        device: &metal::DeviceRef,
1415        input: &MlxBuffer,
1416        output: &MlxBuffer,
1417        params_buf: &MlxBuffer,
1418        rows: u32,
1419        dim: u32,
1420    ) -> Result<()> {
1421        ops::rms_norm::dispatch_rms_norm_no_scale_f32(
1422            &mut self.encoder,
1423            registry,
1424            device,
1425            input,
1426            output,
1427            params_buf,
1428            rows,
1429            dim,
1430        )
1431    }
1432
1433    /// Encode a NeoX RoPE (f32) with optional freq_factors into this session's encoder.
1434    ///
1435    /// Delegates to [`ops::rope::dispatch_rope_neox_f32`].
1436    #[allow(clippy::too_many_arguments)]
1437    pub fn rope_neox_f32(
1438        &mut self,
1439        registry: &mut KernelRegistry,
1440        device: &metal::DeviceRef,
1441        input: &MlxBuffer,
1442        output: &MlxBuffer,
1443        params_buf: &MlxBuffer,
1444        positions_buf: &MlxBuffer,
1445        freq_factors: Option<&MlxBuffer>,
1446        seq_len: u32,
1447        n_heads: u32,
1448        head_dim: u32,
1449        rope_dim: u32,
1450    ) -> Result<()> {
1451        ops::rope::dispatch_rope_neox_f32(
1452            &mut self.encoder,
1453            registry,
1454            device,
1455            input,
1456            output,
1457            params_buf,
1458            positions_buf,
1459            freq_factors,
1460            seq_len,
1461            n_heads,
1462            head_dim,
1463            rope_dim,
1464        )
1465    }
1466
    /// Insert a GPU memory barrier (MTLBarrierScopeBuffers).
    ///
    /// Unconditional barrier — always emits. Use `barrier_between` for
    /// automatic conflict detection that can elide unnecessary barriers.
    #[inline]
    pub fn barrier(&mut self) {
        // Record the outgoing group size (groups larger than the histogram
        // clamp into its last bucket).
        if self.dispatch_in_group > 0 {
            let idx = (self.dispatch_in_group as usize).min(self.group_sizes.len()) - 1;
            self.group_sizes[idx] += 1;
        }
        self.encoder.memory_barrier();
        self.tracker.reset();
        self.barrier_count += 1;
        self.dispatch_in_group = 0;
    }

    /// Smart barrier with conflict detection.
    ///
    /// Checks whether the next dispatch (with the given read and write buffers)
    /// actually conflicts with any dispatch in the current concurrent group.
    /// If yes, emits a Metal barrier and resets the tracker. If no, the
    /// barrier is elided and the dispatch can run concurrently.
    ///
    /// This mirrors llama.cpp's `ggml_metal_op_concurrency_check` +
    /// `ggml_metal_op_concurrency_reset` pattern.
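    ///
    /// A minimal usage sketch; the buffer names and the wrapped op are
    /// hypothetical:
    ///
    /// ```ignore
    /// // Reads q/k/v, writes attn_out.  A barrier is emitted only if those
    /// // ranges conflict with a dispatch in the current concurrent group.
    /// session.barrier_between(&[&q, &k, &v], &[&attn_out]);
    /// session.attention(/* ... */)?;  // hypothetical wrapped op
    /// ```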
    #[inline]
    pub fn barrier_between(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
        // In capture mode, stash the read/write ranges so the next captured
        // dispatch node carries them for the reorder pass (Phase 4e.3).
        if self.recording {
            let read_ranges: Vec<MemRange> = reads
                .iter()
                .map(|b| {
                    let start = b.contents_ptr() as usize;
                    (start, start + b.byte_len())
                })
                .collect();
            let write_ranges: Vec<MemRange> = writes
                .iter()
                .map(|b| {
                    let start = b.contents_ptr() as usize;
                    (start, start + b.byte_len())
                })
                .collect();
            self.encoder.set_pending_buffer_ranges(read_ranges, write_ranges);
        }

        let reason = self.tracker.conflicts_reason(reads, writes);
        if let Some((_kind, _new_ptr, _existing_ptr)) = reason {
            // Record the outgoing group size before resetting
            if self.dispatch_in_group > 0 {
                let idx = (self.dispatch_in_group as usize).min(self.group_sizes.len()) - 1;
                self.group_sizes[idx] += 1;
            }
            self.encoder.memory_barrier();
            self.tracker.reset();
            self.barrier_count += 1;
            self.dispatch_in_group = 0;
        }
        self.dispatch_in_group += 1;
        self.total_dispatches += 1;
        self.tracker.add(reads, writes);
    }

    /// Print group size histogram to stderr (for HF2Q_MLX_TIMING debug).
    pub fn dump_group_stats(&self) {
        // Copy the histogram (group_sizes is a Copy array) and fold in the
        // final, still-open group before summarising.
        let mut gs = self.group_sizes;
        if self.dispatch_in_group > 0 {
            let idx = (self.dispatch_in_group as usize).min(gs.len()) - 1;
            gs[idx] += 1;
        }
        let total_groups: u32 = gs.iter().sum();
        eprintln!("  [GROUP_STATS] dispatches={} barriers={} groups={} ratio={:.2}",
            self.total_dispatches, self.barrier_count, total_groups,
            if total_groups > 0 { self.total_dispatches as f64 / total_groups as f64 } else { 0.0 });
        for (i, &count) in gs.iter().enumerate() {
            if count > 0 {
                eprintln!("    size {}: {} groups", i + 1, count);
            }
        }
    }

    /// Register a dispatch's buffer ranges without checking for conflicts.
    ///
    /// Use after dispatching an op that doesn't need a barrier check (e.g.,
    /// the first dispatch in a session, or dispatches known to be concurrent).
    ///
    /// In recording mode, also retroactively annotates the most recently
    /// captured dispatch node with these ranges if it was missing them.
    /// That keeps the reorder pass able to reason about dispatches that
    /// were preceded by `track_dispatch` rather than `barrier_between`.
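    ///
    /// A minimal sketch; the op and buffer names are hypothetical:
    ///
    /// ```ignore
    /// // First dispatch of the session: nothing to conflict with yet, but
    /// // the tracker still needs to know which ranges it touched.
    /// session.embedding_gather(/* ... */)?;  // hypothetical wrapped op
    /// session.track_dispatch(&[&token_ids, &embed_table], &[&hidden]);
    /// ```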
    #[inline]
    pub fn track_dispatch(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
        if self.recording {
            let read_ranges: Vec<MemRange> = reads
                .iter()
                .map(|b| {
                    let start = b.contents_ptr() as usize;
                    (start, start + b.byte_len())
                })
                .collect();
            let write_ranges: Vec<MemRange> = writes
                .iter()
                .map(|b| {
                    let start = b.contents_ptr() as usize;
                    (start, start + b.byte_len())
                })
                .collect();
            self.encoder
                .annotate_last_dispatch_if_missing(read_ranges, write_ranges);
        }
        self.tracker.add(reads, writes);
    }

    /// Return the number of barriers inserted so far in this session.
    #[inline]
    pub fn barrier_count(&self) -> u32 {
        self.barrier_count
    }

    /// Cumulative nanoseconds spent in ConflictTracker checks (diagnostic).
    /// Returns 0 when timing instrumentation is not compiled in, which is
    /// the case in this build; hence the constant body.
    pub fn tracker_overhead_ns(&self) -> u64 {
        0
    }

    /// Borrow the underlying command encoder for direct op dispatch.
    ///
    /// Use this when you need to call an op function that is not wrapped by
    /// a `GraphSession` method.  The returned encoder is the same shared
    /// encoder — all dispatches still go into the same command buffer.
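    ///
    /// A minimal sketch; the op module is hypothetical:
    ///
    /// ```ignore
    /// // Dispatch an unwrapped op directly against the shared encoder.
    /// ops::some_unwrapped_op::dispatch(session.encoder_mut(), &mut registry, /* ... */)?;
    /// ```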
    pub fn encoder_mut(&mut self) -> &mut CommandEncoder {
        &mut self.encoder
    }

    /// Borrow the device reference.
    pub fn device(&self) -> &MlxDevice {
        self.device
    }

    /// Whether this session is in capture/record mode.
    pub fn is_recording(&self) -> bool {
        self.recording
    }

    /// Commit the command buffer and wait for GPU completion.
    ///
    /// This is the ONLY sync point per forward pass.  After this call, all
    /// output buffers are readable by the CPU.
    ///
    /// In recording mode: extracts the captured graph, replays it into
    /// the encoder via `ComputeGraph::encode_sequential()`, then commits
    /// and waits.  The result is identical to the direct-dispatch path.
    ///
    /// Consumes the session — no further ops can be encoded.
    pub fn finish(mut self) -> Result<()> {
        if self.recording {
            if let Some(nodes) = self.encoder.take_capture() {
                let graph = ComputeGraph::from_nodes(nodes);
                graph.encode_sequential(&mut self.encoder);
            }
        }
        self.encoder.commit_and_wait()
    }

    /// Commit the command buffer WITHOUT waiting.
    ///
    /// The GPU begins executing immediately.  Use this for fire-and-forget
    /// dispatch when you do not need results until later.
    ///
    /// In recording mode: replays the captured graph before committing.
    ///
    /// Consumes the session.
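    ///
    /// A minimal sketch of the fire-and-forget pattern:
    ///
    /// ```ignore
    /// let mut encoder = session.commit();
    /// // The CPU keeps working (e.g. encoding the next layer) while the
    /// // GPU executes.  Hold the returned encoder and wait before reading
    /// // any output buffers.
    /// encoder.wait_until_completed()?;
    /// ```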
    pub fn commit(mut self) -> CommandEncoder {
        if self.recording {
            if let Some(nodes) = self.encoder.take_capture() {
                let graph = ComputeGraph::from_nodes(nodes);
                graph.encode_sequential(&mut self.encoder);
            }
        }
        self.encoder.commit();
        // ADR-015 iter63 (Phase B): close the capture window AFTER
        // commit (so the CB is recorded inside the trace) but BEFORE
        // returning the encoder (so the trace finalizes promptly).
        // `MTLCaptureManager.stopCapture` marks the recording
        // boundary at exactly this point.  CBs committed BEFORE this
        // line are in the trace; any work done by the caller through
        // the returned encoder is NOT.  This matches llama.cpp's
        // ggml-metal-context.m:608 pattern (`stopCapture` after the
        // last `commit + waitUntilCompleted`).
        //
        // Note: for the async commit() path, the GPU may still be
        // executing the CB when stopCapture fires.  Apple's
        // MTLCaptureManager is documented to flush in-flight work
        // into the trace before finalizing the file.
        if let Some(mut cap) = self.metal_capture.take() {
            cap.end();
        }
        self.encoder
    }

    /// Commit the command buffer and wait, returning split timing.
    ///
    /// Returns `(encoding_ns, gpu_wait_ns)` where:
    /// - `encoding_ns` is the time from session begin to commit (CPU encoding)
    /// - `gpu_wait_ns` is the time from commit to GPU completion
    ///
    /// The `session_begin` instant should be captured right after `exec.begin()`.
    ///
    /// In recording mode: replays the captured graph before committing.
    ///
    /// Consumes the session.
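    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let mut session = executor.begin()?;
    /// let session_begin = std::time::Instant::now();
    /// // ... encode the forward pass ...
    /// let (encoding_ns, gpu_wait_ns) = session.finish_with_timing(session_begin)?;
    /// ```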
    pub fn finish_with_timing(mut self, session_begin: std::time::Instant) -> Result<(u64, u64)> {
        if self.recording {
            if let Some(nodes) = self.encoder.take_capture() {
                let graph = ComputeGraph::from_nodes(nodes);
                graph.encode_sequential(&mut self.encoder);
            }
        }
        let commit_start = std::time::Instant::now();
        let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
        self.encoder.commit();
        self.encoder.wait_until_completed()?;
        let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
        Ok((encoding_ns, gpu_wait_ns))
    }

    /// Finish this session and return the GPU wall-clock interval in ns.
    ///
    /// Returns `gpu_interval_ns`: the CFTimeInterval difference between
    /// `MTLCommandBuffer.GPUEndTime` and `MTLCommandBuffer.GPUStartTime`,
    /// converted to ns.  Excludes CPU commit+wait overhead — that appears
    /// in the residual when bucket sums are compared to the outer
    /// wall-clock.
    ///
    /// Used by `HF2Q_PROFILE_GPU_TS=1` to accumulate pure GPU time per
    /// op bucket.  In recording mode: replays the captured graph before
    /// committing.
    ///
    /// Consumes the session.
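    ///
    /// A minimal sketch; the per-bucket accumulator is hypothetical:
    ///
    /// ```ignore
    /// let gpu_ns = session.finish_with_gpu_time()?;
    /// gpu_ns_by_bucket[bucket] += gpu_ns;  // hypothetical accumulator
    /// ```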
    pub fn finish_with_gpu_time(mut self) -> Result<u64> {
        if self.recording {
            if let Some(nodes) = self.encoder.take_capture() {
                let graph = ComputeGraph::from_nodes(nodes);
                graph.encode_sequential(&mut self.encoder);
            }
        }
        let (gs, ge) = self.encoder.commit_wait_with_gpu_time()?;
        // GPUStartTime/GPUEndTime are CFTimeInterval (seconds, double).
        // Guard against negative deltas (can happen on the first CB of
        // a run if the kernel driver lazily initialises the timeline;
        // clamp to zero in that case).
        let delta = (ge - gs).max(0.0);
        Ok((delta * 1.0e9) as u64)
    }

    /// Finish with fusion: run the RMS norm + MUL fusion pass before
    /// replaying the graph.
    ///
    /// Only meaningful in recording mode.  In direct-dispatch mode, this
    /// behaves identically to `finish()`.
    ///
    /// Returns the number of fusions applied on success.
    pub fn finish_with_fusion(
        mut self,
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
    ) -> Result<u32> {
        let mut fusions = 0;
        if self.recording {
            if let Some(nodes) = self.encoder.take_capture() {
                let mut graph = ComputeGraph::from_nodes(nodes);
                fusions = graph.fuse(registry, device)?;
                graph.encode_sequential(&mut self.encoder);
            }
        }
        self.encoder.commit_and_wait()?;
        Ok(fusions)
    }

    /// Async-commit variant of `finish_with_fusion`.
    ///
    /// Runs the fusion pass on the captured graph, replays the fused
    /// graph into a fresh command buffer, then commits *without waiting*
    /// (fire-and-forget — GPU executes asynchronously while CPU returns).
    /// Mirrors `commit()` semantics, plus the fusion optimization pass.
    ///
    /// Used by the prefill per-layer pattern where the next layer's CPU
    /// encoding should overlap with this layer's GPU execution — llama.cpp's
    /// Metal backend uses the same async-commit-with-graph-opt pattern at
    /// `ggml-metal-context.m:617-621`.
    ///
    /// In direct-dispatch mode (recording=false), behaves identically to
    /// `commit()`: no fusion happens, just an async commit.
    ///
    /// Consumes the session.  Returns `(encoder, fusions_applied)`.
    ///
    /// ADR-029 iter-39 H40 — first step of graph_opt port to prefill.
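    ///
    /// A minimal sketch of the per-layer overlap (`dev` stands in for the
    /// `&metal::DeviceRef`):
    ///
    /// ```ignore
    /// let (mut encoder, _fusions) = session.commit_with_fusion(&mut registry, dev)?;
    /// // Encode the next layer on the CPU while this one runs on the GPU;
    /// // wait on the returned encoder before reading this layer's outputs.
    /// encoder.wait_until_completed()?;
    /// ```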
    pub fn commit_with_fusion(
        mut self,
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
    ) -> Result<(CommandEncoder, u32)> {
        let mut fusions = 0;
        if self.recording {
            if let Some(nodes) = self.encoder.take_capture() {
                let mut graph = ComputeGraph::from_nodes(nodes);
                fusions = graph.fuse(registry, device)?;
                graph.encode_sequential(&mut self.encoder);
            }
        }
        self.encoder.commit();
        // Close the metal_capture window after commit (same pattern as `commit`).
        if let Some(mut cap) = self.metal_capture.take() {
            cap.end();
        }
        Ok((self.encoder, fusions))
    }

    /// Finish with fusion and split timing.
    ///
    /// Like `finish_with_timing` but runs the fusion pass first.
    /// Returns `(encoding_ns, gpu_wait_ns, fusions_applied)`.
    pub fn finish_with_fusion_and_timing(
        mut self,
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
        session_begin: std::time::Instant,
    ) -> Result<(u64, u64, u32)> {
        let mut fusions = 0;
        if self.recording {
            if let Some(nodes) = self.encoder.take_capture() {
                let mut graph = ComputeGraph::from_nodes(nodes);
                fusions = graph.fuse(registry, device)?;
                graph.encode_sequential(&mut self.encoder);
            }
        }
        let commit_start = std::time::Instant::now();
        let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
        self.encoder.commit();
        self.encoder.wait_until_completed()?;
        let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
        Ok((encoding_ns, gpu_wait_ns, fusions))
    }

    /// Finish with fusion AND reorder: run both graph optimization passes
    /// before replaying the graph.
    ///
    /// Only meaningful in recording mode.  In direct-dispatch mode, this
    /// behaves identically to `finish()`.
    ///
    /// Returns `(fusions_applied, nodes_reordered)` on success.
    pub fn finish_with_fusion_and_reorder(
        mut self,
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
    ) -> Result<(u32, u32)> {
        let mut fusions = 0;
        let mut reordered = 0;
        if self.recording {
            if let Some(nodes) = self.encoder.take_capture() {
                let mut graph = ComputeGraph::from_nodes(nodes);
                fusions = graph.fuse(registry, device)?;
                reordered = graph.reorder();
                graph.encode_with_barriers(&mut self.encoder);
            }
        }
        self.encoder.commit_and_wait()?;
        Ok((fusions, reordered))
    }

    /// Finish with fusion, reorder, and split timing.
    ///
    /// Like `finish_with_fusion_and_timing` but also runs the reorder pass.
    /// Returns `(encoding_ns, gpu_wait_ns, fusions_applied, nodes_reordered)`.
    pub fn finish_with_fusion_reorder_and_timing(
        mut self,
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
        session_begin: std::time::Instant,
    ) -> Result<(u64, u64, u32, u32)> {
        let mut fusions = 0;
        let mut reordered = 0;
        if self.recording {
            if let Some(nodes) = self.encoder.take_capture() {
                let mut graph = ComputeGraph::from_nodes(nodes);
                fusions = graph.fuse(registry, device)?;
                reordered = graph.reorder();
                graph.encode_with_barriers(&mut self.encoder);
            }
        }
        let commit_start = std::time::Instant::now();
        let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
        self.encoder.commit();
        self.encoder.wait_until_completed()?;
        let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
        Ok((encoding_ns, gpu_wait_ns, fusions, reordered))
    }

    /// Finish with the full optimization pipeline: fuse, reorder, dual-buffer
    /// encode.
    ///
    /// Runs the fusion pass, then the reorder pass, then encodes the graph
    /// into two Metal command buffers for CPU/GPU overlap.  The first ~10% of
    /// dispatches are committed immediately so the GPU can start executing
    /// while the CPU encodes the remaining ~90%.
    ///
    /// Only meaningful in recording mode.  In direct-dispatch mode, this
    /// behaves identically to `finish()`.
    ///
    /// Returns `(fusions_applied, nodes_reordered, barriers_buf0, barriers_buf1)`.
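    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let mut session = executor.begin_recorded()?;
    /// // ... encode the forward pass ...
    /// let (fusions, reordered, b0, b1) =
    ///     session.finish_optimized(&mut registry, device.metal_device())?;
    /// ```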
    pub fn finish_optimized(
        mut self,
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
    ) -> Result<(u32, u32, u32, u32)> {
        let mut fusions = 0;
        let mut reordered = 0;
        let mut barriers0 = 0u32;
        let mut barriers1 = 0u32;

        if self.recording {
            if let Some(nodes) = self.encoder.take_capture() {
                // Commit the capture encoder's empty command buffer so its
                // MTLCommandQueue pool slot is freed (same fix as timing variant).
                self.encoder.commit();

                let mut graph = ComputeGraph::from_nodes(nodes);
                fusions = graph.fuse(registry, device)?;
                reordered = graph.reorder();

                let mut enc0 = self.device.command_encoder()?;
                let mut enc1 = self.device.command_encoder()?;

                let (b0, b1) = graph.encode_dual_buffer(&mut enc0, &mut enc1);
                barriers0 = b0;
                barriers1 = b1;

                // enc0 was already committed inside encode_dual_buffer.
                // Commit enc1 and wait — Metal queue ordering guarantees enc0
                // finishes before enc1 starts executing.
                enc1.commit_and_wait()?;

                // The capture encoder's (empty) command buffer was already
                // committed above, so dropping `self.encoder` here is clean:
                // it only ends the active encoder, if any.
                return Ok((fusions, reordered, barriers0, barriers1));
            }
        }

        // Direct-dispatch fallback: just commit the original encoder.
        self.encoder.commit_and_wait()?;
        Ok((fusions, reordered, barriers0, barriers1))
    }

    /// Finish with the full optimization pipeline and split timing.
    ///
    /// Like `finish_optimized` but returns timing information.
    /// Returns `(encoding_ns, gpu_wait_ns, fusions, reordered, barriers_buf0, barriers_buf1)`.
    ///
    /// Timing breakdown:
    /// - `encoding_ns`: CPU time from session begin to first buffer commit
    ///   (fusion + reorder + encode chunk 0)
    /// - `gpu_wait_ns`: wall time from second buffer commit to GPU completion
    ///   (includes GPU execution of both buffers, overlapped with chunk 1 encoding)
    pub fn finish_optimized_with_timing(
        mut self,
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
        session_begin: std::time::Instant,
    ) -> Result<(u64, u64, u32, u32, u32, u32)> {
        let mut fusions = 0;
        let mut reordered = 0;
        let mut barriers0 = 0u32;
        let mut barriers1 = 0u32;

        if self.recording {
            if let Some(nodes) = self.encoder.take_capture() {
                // Commit the capture encoder's empty command buffer so its
                // MTLCommandQueue pool slot is freed.  Without this, each
                // token leaks one uncommitted buffer and the queue exhausts
                // its ~64-slot pool after ~64 tokens, causing a deadlock.
                self.encoder.commit();

                let opt_t0 = std::time::Instant::now();
                let mut graph = ComputeGraph::from_nodes(nodes);
                let fuse_t0 = std::time::Instant::now();
                fusions = graph.fuse(registry, device)?;
                let fuse_us = fuse_t0.elapsed().as_micros();

                let reorder_t0 = std::time::Instant::now();
                let unannotated = graph.unannotated_dispatch_count();
                if unannotated == 0 {
                    reordered = graph.reorder();
                } else if std::env::var("HF2Q_MLX_TIMING").is_ok() {
                    eprintln!("  [GRAPH_OPT] WARN: skipping reorder — {} of {} dispatches lack range annotations",
                        unannotated, graph.dispatch_count());
                }
                let reorder_us = reorder_t0.elapsed().as_micros();
                let opt_us = opt_t0.elapsed().as_micros();

                let diag = std::env::var("HF2Q_GRAPH_DIAG").is_ok();
                let t0 = std::time::Instant::now();
                let mut enc0 = self.device.command_encoder()?;
                let mut enc1 = self.device.command_encoder()?;
                let enc_create_us = t0.elapsed().as_micros();

                let t1 = std::time::Instant::now();
                let (b0, b1) = graph.encode_dual_buffer(&mut enc0, &mut enc1);
                barriers0 = b0;
                barriers1 = b1;
                let encode_us = t1.elapsed().as_micros();

                let encoding_ns = session_begin.elapsed().as_nanos() as u64;

                let wait_start = std::time::Instant::now();
                enc1.commit_and_wait()?;
                let gpu_wait_ns = wait_start.elapsed().as_nanos() as u64;

                if diag {
                    eprintln!("  [DIAG] fuse={:.1}ms reorder={:.1}ms opt_total={:.1}ms enc_create={:.1}ms encode={:.1}ms gpu_wait={:.1}ms barriers={}+{}",
                        fuse_us as f64 / 1e3, reorder_us as f64 / 1e3, opt_us as f64 / 1e3,
                        enc_create_us as f64 / 1e3, encode_us as f64 / 1e3,
                        gpu_wait_ns as f64 / 1e6, b0, b1);
                }

                return Ok((encoding_ns, gpu_wait_ns, fusions, reordered, barriers0, barriers1));
            }
        }

        // Direct-dispatch fallback.
        let commit_start = std::time::Instant::now();
        let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
        self.encoder.commit();
        self.encoder.wait_until_completed()?;
        let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
        Ok((encoding_ns, gpu_wait_ns, fusions, reordered, barriers0, barriers1))
    }
}