mlx_native/graph.rs
1//! [`GraphExecutor`] — batched Metal dispatch for single-encoder forward passes.
2//!
3//! llama.cpp's speed advantage over candle is NOT the kernels (Phase 0 proved
4//! candle's are as fast or faster per-call). It is the dispatch pattern:
5//! 1 encoder per command buffer instead of ~120. This module implements that
6//! pattern.
7//!
8//! # Usage
9//!
10//! ```ignore
11//! let mut executor = GraphExecutor::new(device.clone());
12//! let mut session = executor.begin()?;
13//!
14//! // All ops encode into the same command buffer — no per-op encoder creation.
15//! session.rms_norm(&mut registry, device.metal_device(), input, weight, output, params, rows, dim)?;
16//! session.quantized_matmul(&mut registry, &device, input, weight, scales, biases, &qparams)?;
17//! session.elementwise_add(&mut registry, device.metal_device(), a, b, out, n, DType::F32)?;
18//!
19//! // Single GPU sync point for the entire forward pass.
20//! session.finish()?;
21//! ```
22//!
23//! # Design
24//!
25//! The `GraphSession` holds a single `CommandEncoder`. Each op method delegates
26//! to the existing op dispatch functions in [`crate::ops`], passing the session's
27//! shared encoder. No new Metal code is needed — the ops already work with a
28//! shared encoder. The executor just prevents creating a new encoder per op.
29//!
30//! # Phase 4e.1 — Graph IR
31//!
32//! The `ComputeGraph` type captures dispatches into a `Vec<CapturedNode>` for
33//! later replay. `GraphExecutor::begin_recorded()` starts a session in capture
34//! mode: all op calls are intercepted at the `CommandEncoder` level and recorded
35//! instead of being sent to Metal. `GraphSession::finish()` detects capture
36//! mode, extracts the recorded graph, and replays it into a fresh encoder via
37//! `ComputeGraph::encode_sequential()`.
38//!
39//! The existing direct-dispatch path (`begin()`) is completely unchanged.
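//!
//! A minimal sketch of the recorded path (same op calls as the direct-dispatch
//! example above; only the `begin*` call changes, and the GPU work happens at
//! finish time):
//!
//! ```ignore
//! let executor = GraphExecutor::new(device.clone());
//! let mut session = executor.begin_recorded()?;
//! // ... encode ops exactly as in the direct-dispatch example ...
//! // Replay the captured graph, then commit and wait (single sync point).
//! session.finish()?;
//! ```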
40
41use metal::foreign_types::ForeignType;
42
43use crate::device::MlxDevice;
44use crate::encoder::{CapturedNode, CapturedOpKind, CommandEncoder, MemRange, RecordedBinding};
45use crate::error::Result;
46use crate::kernel_registry::KernelRegistry;
47use crate::ops;
48
49// Re-export types used in the public API so callers don't need separate imports.
50pub use crate::buffer::MlxBuffer;
51pub use crate::dtypes::DType;
52
53// ---------------------------------------------------------------------------
54// OpKind — operation classification for the reorder safety whitelist (4e.3)
55// ---------------------------------------------------------------------------
56
57/// Classification of a compute operation for reorder safety analysis.
58///
59/// Operations marked as reorderable can be freely reordered by the graph
60/// optimizer (Phase 4e.3) as long as their data dependencies allow it.
61/// Non-reorderable operations have side effects or dependencies that
62/// require them to stay in their original sequential position.
63#[derive(Clone, Copy, Debug, PartialEq, Eq)]
64pub enum OpKind {
65 /// Matrix multiplication (reorderable).
66 MatMul,
67 /// Expert-routed matrix multiplication (reorderable).
68 MatMulId,
69 /// Normalization — RMS norm, layer norm (reorderable).
70 Norm,
71 /// Rotary position embedding (reorderable).
72 Rope,
73 /// Elementwise ops — add, mul, scale, gelu, softcap, etc. (reorderable).
74 Elementwise,
75 /// Memory copy — KV cache copy, embedding gather (reorderable).
76 Copy,
77 /// Gather/scatter (reorderable).
78 Gather,
79 /// Scaled dot-product attention (NOT reorderable).
80 Sdpa,
81 /// Softmax (NOT reorderable).
82 Softmax,
83 /// MoE gate with CPU readback dependency (NOT reorderable).
84 MoeGate,
85 /// Anything else (NOT reorderable).
86 Other,
87}
88
89impl OpKind {
90 /// Whether this op kind is safe to reorder in the graph optimizer.
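    ///
    /// A trivial illustration (not compiled as a doctest):
    ///
    /// ```ignore
    /// assert!(OpKind::MatMul.is_reorderable());
    /// assert!(!OpKind::Sdpa.is_reorderable());
    /// ```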
91 pub fn is_reorderable(&self) -> bool {
92 matches!(
93 self,
94 Self::MatMul
95 | Self::MatMulId
96 | Self::Norm
97 | Self::Rope
98 | Self::Elementwise
99 | Self::Copy
100 | Self::Gather
101 )
102 }
103}
104
105// ---------------------------------------------------------------------------
106// ComputeGraph — the recorded graph IR
107// ---------------------------------------------------------------------------
108
109/// A recorded sequence of GPU compute dispatches and barriers.
110///
111/// Created by running a forward pass with the encoder in capture mode.
112/// Can be replayed into a real `CommandEncoder` via `encode_sequential()`,
113/// producing identical Metal dispatch behavior to the original direct path.
114///
115/// Future phases (4e.2, 4e.3) will add fusion and reorder passes that
116/// transform the graph before encoding.
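///
/// A minimal replay sketch (assuming `nodes` came from a recording session and
/// `encoder` is a fresh `CommandEncoder`; not compiled as a doctest):
///
/// ```ignore
/// let graph = ComputeGraph::from_nodes(nodes);
/// let barriers = graph.encode_sequential(&mut encoder);
/// encoder.commit_and_wait()?;
/// ```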
117pub struct ComputeGraph {
118 nodes: Vec<CapturedNode>,
119}
120
121impl ComputeGraph {
122 /// Create an empty compute graph.
123 pub fn new() -> Self {
124 Self {
125 nodes: Vec::with_capacity(128),
126 }
127 }
128
129 /// Create a compute graph from a pre-built list of captured nodes.
130 pub fn from_nodes(nodes: Vec<CapturedNode>) -> Self {
131 Self { nodes }
132 }
133
134 /// Record a captured node into the graph.
135 pub fn record(&mut self, node: CapturedNode) {
136 self.nodes.push(node);
137 }
138
139 /// Number of nodes (dispatches + barriers) in the graph.
140 pub fn len(&self) -> usize {
141 self.nodes.len()
142 }
143
144 /// Whether the graph contains no nodes.
145 pub fn is_empty(&self) -> bool {
146 self.nodes.is_empty()
147 }
148
149 /// Number of dispatch nodes (excludes barriers).
150 pub fn dispatch_count(&self) -> usize {
151 self.nodes
152 .iter()
153 .filter(|n| matches!(n, CapturedNode::Dispatch { .. }))
154 .count()
155 }
156
157 /// Number of barrier nodes.
158 pub fn barrier_count(&self) -> usize {
159 self.nodes
160 .iter()
161 .filter(|n| matches!(n, CapturedNode::Barrier))
162 .count()
163 }
164
165 /// Borrow the node list.
166 pub fn nodes(&self) -> &[CapturedNode] {
167 &self.nodes
168 }
169
170 /// Count dispatch nodes that have empty read/write range annotations.
171 ///
172 /// Used for diagnostics: if >0, the reorder pass cannot guarantee
173 /// correctness because it relies on complete annotations.
174 pub fn unannotated_dispatch_count(&self) -> usize {
175 self.nodes
176 .iter()
177 .filter(|n| matches!(n, CapturedNode::Dispatch { reads, writes, .. }
178 if reads.is_empty() || writes.is_empty()))
179 .count()
180 }
181
182 /// Take ownership of the node list, consuming the graph.
183 pub fn into_nodes(self) -> Vec<CapturedNode> {
184 self.nodes
185 }
186
187 /// Encode all nodes sequentially into the given encoder.
188 ///
189 /// Barrier sentinel nodes emit a Metal memory barrier. Dispatch nodes
190 /// are replayed through `CommandEncoder::replay_dispatch()`.
191 ///
192 /// This produces identical GPU behavior to the direct-dispatch path —
193 /// same pipeline bindings, same dispatch dimensions, same barrier
194 /// placement.
195 ///
196 /// Returns the number of barriers emitted.
197 pub fn encode_sequential(&self, encoder: &mut CommandEncoder) -> u32 {
198 let mut barrier_count = 0u32;
199 for node in &self.nodes {
200 match node {
201 CapturedNode::Barrier => {
202 encoder.memory_barrier();
203 barrier_count += 1;
204 }
205 CapturedNode::Dispatch {
206 pipeline,
207 bindings,
208 threads_per_grid,
209 threads_per_threadgroup,
210 threadgroup_memory,
211 dispatch_kind,
212 op_kind,
213 ..
214 } => {
215 // ADR-015 iter63 (Phase A.3): forward captured op_kind
216 // into replay_dispatch via pending_op_kind so the
217 // per-dispatch profile entries from a recorded graph
218 // (GraphExecutor::begin_recorded path) are tagged with
219 // the same op label as direct-dispatch encoding.
220 encoder.set_op_kind(*op_kind);
221 encoder.replay_dispatch(
222 pipeline,
223 bindings,
224 threadgroup_memory,
225 *threads_per_grid,
226 *threads_per_threadgroup,
227 *dispatch_kind,
228 );
229 }
230 }
231 }
232 barrier_count
233 }
234
235 /// Encode the graph into a Metal command buffer, computing barriers on the
236 /// fly from each node's read/write buffer ranges.
237 ///
238 /// This is the correct encoding method for reordered graphs where barrier
239 /// sentinels have been stripped. Mirrors llama.cpp's encode-time barrier
240 /// insertion via `ggml_metal_op_concurrency_check`.
241 ///
242 /// Returns the number of barriers emitted.
243 pub fn encode_with_barriers(&self, encoder: &mut CommandEncoder) -> u32 {
244 let mut tracker = ReorderConflictTracker::new();
245 let mut barrier_count = 0u32;
246
247 for node in &self.nodes {
248 match node {
249 CapturedNode::Dispatch {
250 pipeline,
251 bindings,
252 threads_per_grid,
253 threads_per_threadgroup,
254 threadgroup_memory,
255 dispatch_kind,
256 reads,
257 writes,
258 op_kind,
259 ..
260 } => {
261 let has_ranges = !reads.is_empty() || !writes.is_empty();
262 if has_ranges && tracker.conflicts(reads, writes) {
263 encoder.memory_barrier();
264 tracker.reset();
265 barrier_count += 1;
266 }
267 if has_ranges {
268 tracker.add(reads, writes);
269 }
270 // ADR-015 iter63 (Phase A.3): see encode_sequential note.
271 encoder.set_op_kind(*op_kind);
272 encoder.replay_dispatch(
273 pipeline,
274 bindings,
275 threadgroup_memory,
276 *threads_per_grid,
277 *threads_per_threadgroup,
278 *dispatch_kind,
279 );
280 }
281 CapturedNode::Barrier => {
282 // Explicit barriers still force a barrier boundary
283 encoder.memory_barrier();
284 tracker.reset();
285 barrier_count += 1;
286 }
287 }
288 }
289 barrier_count
290 }
291
292 /// Encode the graph using two command buffers for CPU/GPU overlap.
293 ///
294 /// The first `n0` dispatches are encoded into `encoder0` and committed
295 /// immediately (GPU starts executing). The remaining dispatches are encoded
296 /// into `encoder1`. The caller is responsible for committing `encoder1`.
297 ///
298 /// This matches llama.cpp's dual command buffer pattern from
299 /// `ggml_metal_graph_compute` (ggml-metal-context.m:441-644):
300 /// `n_nodes_0 = MAX(64, 0.1 * n_nodes)` for the first buffer.
301 ///
302 /// Command buffers submitted to the same `MTLCommandQueue` execute in
303 /// submission order, so `encoder0.commit()` followed by `encoder1.commit()`
304 /// guarantees enc0 finishes before enc1 starts. The win: the GPU starts
305 /// executing enc0 while the CPU is still encoding enc1.
306 ///
307 /// Returns `(barriers_buf0, barriers_buf1)`.
308 pub fn encode_dual_buffer(
309 &self,
310 encoder0: &mut CommandEncoder,
311 encoder1: &mut CommandEncoder,
312 ) -> (u32, u32) {
313 let dispatch_total = self.dispatch_count();
314 let n0 = std::cmp::max(64, dispatch_total / 10);
315
316 // Find the split point: the index of the n0-th dispatch node.
317 let split_idx = find_dispatch_split_index(&self.nodes, n0);
318
319 // Encode first chunk with barrier recomputation, then commit immediately.
320 let barriers0 = encode_chunk_with_barriers(&self.nodes[..split_idx], encoder0);
321 encoder0.commit();
322
323 // Encode second chunk with barrier recomputation.
324 let barriers1 = encode_chunk_with_barriers(&self.nodes[split_idx..], encoder1);
325
326 (barriers0, barriers1)
327 }
328
329 /// Run the RMS norm + MUL fusion pass over the graph.
330 ///
331 /// Scans for the pattern:
332 /// Dispatch(RmsNorm) → Barrier(s) → Dispatch(ElemMul)
333 /// where the MUL reads the norm's output buffer, and replaces the
334 /// sequence with a single fused `rms_norm_mul_*` dispatch.
335 ///
336 /// The fused dispatch:
337 /// - Reads the norm's input (buffer 0) and weight (buffer 1)
338 /// - Reads the MUL's second operand as the scale (buffer 2)
339 /// - Writes to the MUL's output (buffer 3)
340 /// - Carries the norm's params (buffer 4)
341 /// - Uses the norm's threadgroup config and shared memory
342 ///
343 /// Returns the number of fusions applied.
344 ///
345 /// # Arguments
346 ///
347 /// * `registry` - Kernel registry for compiling the fused pipeline.
348 /// * `device` - Metal device for pipeline compilation.
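    ///
    /// A usage sketch (assuming `captured_nodes` came from a recording session
    /// and `registry` is the shared `KernelRegistry`; not compiled as a doctest):
    ///
    /// ```ignore
    /// let mut graph = ComputeGraph::from_nodes(captured_nodes);
    /// let fusions = graph.fuse(&mut registry, device.metal_device())?;
    /// ```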
349 pub fn fuse(
350 &mut self,
351 registry: &mut KernelRegistry,
352 device: &metal::DeviceRef,
353 ) -> Result<u32> {
354 let mut result: Vec<CapturedNode> = Vec::with_capacity(self.nodes.len());
355 let mut fusions = 0u32;
356 let mut i = 0;
357
358 while i < self.nodes.len() {
359 // Check if current node is an RMS norm dispatch.
360 let is_rms_norm = matches!(
361 &self.nodes[i],
362 CapturedNode::Dispatch { op_kind: CapturedOpKind::RmsNorm, .. }
363 );
364
365 if !is_rms_norm {
366 result.push(self.nodes[i].clone());
367 i += 1;
368 continue;
369 }
370
371 // Look ahead: skip barriers, then check for ElemMul.
372 let mut j = i + 1;
373 let mut barrier_count = 0usize;
374 while j < self.nodes.len() && matches!(&self.nodes[j], CapturedNode::Barrier) {
375 barrier_count += 1;
376 j += 1;
377 }
378
379 // Must have at least one barrier and the next node must be ElemMul.
380 if barrier_count == 0 || j >= self.nodes.len() {
381 result.push(self.nodes[i].clone());
382 i += 1;
383 continue;
384 }
385
386 let is_elem_mul = matches!(
387 &self.nodes[j],
388 CapturedNode::Dispatch { op_kind: CapturedOpKind::ElemMul, .. }
389 );
390
391 if !is_elem_mul {
392 result.push(self.nodes[i].clone());
393 i += 1;
394 continue;
395 }
396
397 // Extract norm and mul dispatch fields.
398 let (norm_pipeline, norm_bindings, norm_tpg, norm_tptg, norm_tgmem, norm_dk) =
399 match &self.nodes[i] {
400 CapturedNode::Dispatch {
401 pipeline,
402 bindings,
403 threads_per_grid,
404 threads_per_threadgroup,
405 threadgroup_memory,
406 dispatch_kind,
407 ..
408 } => (pipeline, bindings, threads_per_grid, threads_per_threadgroup, threadgroup_memory, dispatch_kind),
409 _ => unreachable!(),
410 };
411
412 let (mul_bindings, _mul_tpg, _mul_tptg) = match &self.nodes[j] {
413 CapturedNode::Dispatch {
414 bindings,
415 threads_per_grid,
416 threads_per_threadgroup,
417 ..
418 } => (bindings, threads_per_grid, threads_per_threadgroup),
419 _ => unreachable!(),
420 };
421
422 // Verify data dependency: the norm's output buffer (slot 2) must
423 // appear as one of the MUL's input buffers (slot 0 or 1).
424 //
425 // Norm binding layout: (0=input, 1=weight, 2=output, 3=params)
426 // MUL binding layout: (0=a, 1=b, 2=output, 3=params_bytes)
427 let norm_output_ptr = Self::buffer_ptr_for_slot(norm_bindings, 2);
428 let mul_a_ptr = Self::buffer_ptr_for_slot(mul_bindings, 0);
429 let mul_b_ptr = Self::buffer_ptr_for_slot(mul_bindings, 1);
430
431 if norm_output_ptr.is_none() || (norm_output_ptr != mul_a_ptr && norm_output_ptr != mul_b_ptr) {
432 // Data dependency not confirmed — don't fuse.
433 result.push(self.nodes[i].clone());
434 i += 1;
435 continue;
436 }
437
438 // Determine which MUL input is the scale (the one that is NOT
439 // the norm's output).
440 let scale_slot = if norm_output_ptr == mul_a_ptr { 1 } else { 0 };
441
442 // Build fused bindings:
443 // 0 = norm input
444 // 1 = norm weight
445 // 2 = scale (from MUL)
446 // 3 = MUL output
447 // 4 = norm params
448 // Gather all required bindings; bail if any are missing.
449 let (norm_input, norm_weight, scale, mul_output, norm_params) = match (
450 Self::get_binding(norm_bindings, 0),
451 Self::get_binding(norm_bindings, 1),
452 Self::get_binding(mul_bindings, scale_slot),
453 Self::get_binding(mul_bindings, 2),
454 Self::get_binding(norm_bindings, 3),
455 ) {
456 (Some(a), Some(b), Some(c), Some(d), Some(e)) => (a, b, c, d, e),
457 _ => {
458 // Missing bindings — don't fuse.
459 result.push(self.nodes[i].clone());
460 i += 1;
461 continue;
462 }
463 };
464
465 // Select fused pipeline based on the original norm pipeline name.
466 // The norm pipeline name is "rms_norm_f32", "rms_norm_f16", or
467 // "rms_norm_bf16" — we need the corresponding fused pipeline.
468 let fused_name = match Self::fused_pipeline_name(norm_pipeline) {
469 Some(name) => name,
470 None => {
471 result.push(self.nodes[i].clone());
472 i += 1;
473 continue;
474 }
475 };
476
477 let fused_pipeline = registry.get_pipeline(fused_name, device)?;
478
479 let fused_bindings = vec![
480 (0, norm_input),
481 (1, norm_weight),
482 (2, scale),
483 (3, mul_output),
484 (4, norm_params),
485 ];
486
487 // Merge read/write ranges from both the norm and mul nodes for the
488 // fused dispatch. The fused op reads everything the norm reads
489 // plus the mul's scale input, and writes to the mul's output.
490 let (fused_reads, fused_writes) = match (&self.nodes[i], &self.nodes[j]) {
491 (
492 CapturedNode::Dispatch { reads: nr, writes: _nw, .. },
493 CapturedNode::Dispatch { reads: mr, writes: mw, .. },
494 ) => {
495 let mut reads = nr.clone();
496 reads.extend_from_slice(mr);
497 (reads, mw.clone())
498 }
499 _ => (Vec::new(), Vec::new()),
500 };
501
502 result.push(CapturedNode::Dispatch {
503 pipeline: fused_pipeline.to_owned(),
504 bindings: fused_bindings,
505 threads_per_grid: *norm_tpg,
506 threads_per_threadgroup: *norm_tptg,
507 threadgroup_memory: norm_tgmem.clone(),
508 dispatch_kind: *norm_dk,
509 op_kind: CapturedOpKind::Other, // Fused ops are not further fuseable
510 reads: fused_reads,
511 writes: fused_writes,
512 });
513
514 fusions += 1;
515 // Skip past the norm, barrier(s), and mul nodes.
516 i = j + 1;
517 }
518
519 self.nodes = result;
520 Ok(fusions)
521 }
522
523 /// Run the reorder pass over the graph to improve GPU concurrency.
524 ///
525 /// Port of llama.cpp's `ggml_metal_graph_optimize_reorder` — a greedy
526 /// 64-node lookahead that pulls independent dispatches forward to fill
527 /// larger concurrent groups between barriers.
528 ///
    /// **Prerequisites:** Call `fuse()` first if desired. The reorder pass
    /// operates on the post-fusion graph. Barrier sentinel nodes are stripped
    /// before reordering (they are recomputed at encode time by the
    /// `ReorderConflictTracker` in `encode_with_barriers()`).
533 ///
534 /// **Algorithm (matching llama.cpp exactly):**
535 /// 1. Strip all `CapturedNode::Barrier` nodes.
536 /// 2. For each unprocessed node `i0`:
537 /// - If it conflicts with the current concurrent group (`mrs0`):
538 /// * Initialize `mrs1` from `i0`'s ranges (skipped-over set)
539 /// * Lookahead up to 64 nodes for candidates that:
540 /// (a) Are reorderable (`CapturedOpKind::is_reorderable()`)
541 /// (b) Don't conflict with `mrs0` (current group)
542 /// (c) Don't conflict with `mrs1` (skipped-over nodes)
543 /// * Pull qualifying candidates into the current group
544 /// * Non-reorderable ops break the lookahead
545 /// - Reset `mrs0` (new concurrent group)
546 /// - Add `i0` to the new group
547 ///
548 /// Returns the number of nodes that were moved to earlier positions.
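    ///
    /// A usage sketch of the full fuse + reorder + encode pipeline (not
    /// compiled as a doctest):
    ///
    /// ```ignore
    /// let mut graph = ComputeGraph::from_nodes(captured_nodes);
    /// let fusions = graph.fuse(&mut registry, device.metal_device())?;
    /// let moved = graph.reorder();
    /// // reorder() strips barrier sentinels; recompute barriers while encoding.
    /// let barriers = graph.encode_with_barriers(&mut encoder);
    /// ```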
549 pub fn reorder(&mut self) -> u32 {
        // Step 1: Strip barrier nodes. After fusion + reorder, barriers are
        // recomputed at encode time by `encode_with_barriers()`.
552 self.nodes.retain(|n| !matches!(n, CapturedNode::Barrier));
553
554 let n = self.nodes.len();
555 if n == 0 {
556 return 0;
557 }
558
559 let mut result: Vec<usize> = Vec::with_capacity(n);
560 let mut used = vec![false; n];
561
562 // mrs0: memory ranges for the current concurrent group
563 let mut mrs0 = ReorderConflictTracker::new();
564 // mrs1: memory ranges for skipped-over (unprocessed) nodes
565 let mut mrs1 = ReorderConflictTracker::new();
566
567 const N_FORWARD: usize = 64;
568
569 for i0 in 0..n {
570 if used[i0] {
571 continue;
572 }
573
574 let node0 = &self.nodes[i0];
575
576 // Extract reads/writes for conflict check.
577 let (reads0, writes0, op_kind0) = match node0 {
578 CapturedNode::Dispatch { reads, writes, op_kind, .. } => {
579 (reads.as_slice(), writes.as_slice(), *op_kind)
580 }
581 CapturedNode::Barrier => continue, // stripped, but be safe
582 };
583
584 // Check if node0 conflicts with the current concurrent group.
585 // Empty nodes (no ranges) never conflict — like llama.cpp's is_empty.
586 let has_ranges = !reads0.is_empty() || !writes0.is_empty();
587 if has_ranges && mrs0.conflicts(reads0, writes0) {
588 // Before starting a new group, look forward for nodes that
589 // can be pulled into the CURRENT group.
590 mrs1.reset();
591 mrs1.add(reads0, writes0);
592
593 let end = (i0 + N_FORWARD).min(n);
594 for i1 in (i0 + 1)..end {
595 if used[i1] {
596 continue;
597 }
598
599 let node1 = &self.nodes[i1];
600 let (reads1, writes1, op_kind1) = match node1 {
601 CapturedNode::Dispatch { reads, writes, op_kind, .. } => {
602 (reads.as_slice(), writes.as_slice(), *op_kind)
603 }
604 CapturedNode::Barrier => continue,
605 };
606
607 // Non-reorderable ops break the lookahead.
608 if !op_kind1.is_reorderable() {
609 break;
610 }
611
612 let is_empty1 = reads1.is_empty() && writes1.is_empty();
613
614 // A node can be reordered into the current group if:
615 // 1. It's empty (no ranges) OR doesn't conflict with mrs0
616 // 2. It doesn't conflict with mrs1 (skipped-over nodes)
617 if (is_empty1 || !mrs0.conflicts(reads1, writes1))
618 && !mrs1.conflicts(reads1, writes1)
619 {
620 // Pull into current concurrent group.
621 mrs0.add(reads1, writes1);
622 result.push(i1);
623 used[i1] = true;
624 } else {
625 // Not eligible — expand the skipped-over set.
626 mrs1.add(reads1, writes1);
627 }
628 }
629
630 // Finalize the current concurrent group.
631 mrs0.reset();
632 }
633
634 // Expand the concurrent group with node0.
635 // (Barriers were stripped, so this is always a Dispatch.)
636 let _ = op_kind0; // suppress unused warning
637 mrs0.add(reads0, writes0);
638 result.push(i0);
639 }
640
641 // Apply the permutation to produce the reordered node list.
642 let mut reordered_count = 0u32;
643 for (pos, &orig_idx) in result.iter().enumerate() {
644 if orig_idx != pos {
645 reordered_count += 1;
646 }
647 }
648
649 // Build the reordered nodes vec.
650 let old_nodes = std::mem::take(&mut self.nodes);
651 self.nodes = result.iter().map(|&idx| old_nodes[idx].clone()).collect();
652
653 // Debug dump if requested.
654 if std::env::var("HF2Q_REORDER_DUMP").is_ok() {
655 eprintln!(
656 " [REORDER] nodes={} reordered={} ({:.1}%)",
657 n,
658 reordered_count,
659 100.0 * reordered_count as f64 / n as f64,
660 );
661 }
662
663 reordered_count
664 }
665
666 /// Get the Metal buffer pointer for a binding at the given slot index.
667 ///
668 /// Returns `Some(ptr)` if the slot has a `RecordedBinding::Buffer`,
669 /// `None` otherwise.
670 fn buffer_ptr_for_slot(bindings: &[(u64, RecordedBinding)], slot: u64) -> Option<*const std::ffi::c_void> {
671 for (idx, binding) in bindings {
672 if *idx == slot {
673 if let RecordedBinding::Buffer { metal_buffer, offset: _ } = binding {
                    // Use the MTLBuffer object pointer as the identity key:
                    // two bindings that reference the same allocation share
                    // the same object pointer.
677 let ptr: *const std::ffi::c_void = metal_buffer.as_ptr() as *const _;
678 return Some(ptr);
679 }
680 }
681 }
682 None
683 }
684
685 /// Clone the binding at the given slot index.
686 fn get_binding(bindings: &[(u64, RecordedBinding)], slot: u64) -> Option<RecordedBinding> {
687 for (idx, binding) in bindings {
688 if *idx == slot {
689 return Some(binding.clone());
690 }
691 }
692 None
693 }
694
695 /// Map a norm pipeline to its fused norm+mul pipeline name.
696 ///
    /// The pipeline's `label()` is expected to carry the kernel function name,
    /// so we can match on it. Returns `None` if the pipeline is not a known norm.
699 fn fused_pipeline_name(pipeline: &metal::ComputePipelineState) -> Option<&'static str> {
700 match pipeline.label() {
701 "rms_norm_f32" => Some("rms_norm_mul_f32"),
702 "rms_norm_f16" => Some("rms_norm_mul_f16"),
703 "rms_norm_bf16" => Some("rms_norm_mul_bf16"),
704 _ => None,
705 }
706 }
707}
708
709impl Default for ComputeGraph {
710 fn default() -> Self {
711 Self::new()
712 }
713}
714
715// ---------------------------------------------------------------------------
716// Dual-buffer encoding helpers (Phase 4e.4)
717// ---------------------------------------------------------------------------
718
719/// Find the node index where the n0-th dispatch starts.
720///
/// Counts `CapturedNode::Dispatch` nodes until `n0` have been seen, then returns
/// the index just past the n0-th dispatch (i.e., the first node of the second chunk).
723/// If `n0 >= dispatch_count`, returns `nodes.len()` (everything in chunk 0).
724fn find_dispatch_split_index(nodes: &[CapturedNode], n0: usize) -> usize {
725 let mut dispatches_seen = 0usize;
726 for (i, node) in nodes.iter().enumerate() {
727 if matches!(node, CapturedNode::Dispatch { .. }) {
728 dispatches_seen += 1;
729 if dispatches_seen == n0 {
730 return i + 1; // split AFTER the n0-th dispatch
731 }
732 }
733 }
734 nodes.len()
735}
736
737/// Encode a slice of captured nodes into a command encoder, recomputing
738/// barriers on the fly from each node's read/write buffer ranges.
739///
740/// This is the chunked counterpart of `ComputeGraph::encode_with_barriers()`.
741/// Factored out so both halves of a dual-buffer encode can use it.
742///
743/// Returns the number of barriers emitted.
744fn encode_chunk_with_barriers(nodes: &[CapturedNode], encoder: &mut CommandEncoder) -> u32 {
745 let mut tracker = ReorderConflictTracker::new();
746 let mut barrier_count = 0u32;
747
748 for node in nodes {
749 match node {
750 CapturedNode::Dispatch {
751 pipeline,
752 bindings,
753 threads_per_grid,
754 threads_per_threadgroup,
755 threadgroup_memory,
756 dispatch_kind,
757 reads,
758 writes,
759 op_kind,
760 ..
761 } => {
762 let has_ranges = !reads.is_empty() || !writes.is_empty();
763 if has_ranges && tracker.conflicts(reads, writes) {
764 encoder.memory_barrier();
765 tracker.reset();
766 barrier_count += 1;
767 }
768 if has_ranges {
769 tracker.add(reads, writes);
770 }
771 // ADR-015 iter63 (Phase A.3): forward captured op_kind so the
772 // per-dispatch profile dump groups dual-buffer chunked replay
773 // entries by the same op_kind label as direct dispatch.
774 encoder.set_op_kind(*op_kind);
775 encoder.replay_dispatch(
776 pipeline,
777 bindings,
778 threadgroup_memory,
779 *threads_per_grid,
780 *threads_per_threadgroup,
781 *dispatch_kind,
782 );
783 }
784 CapturedNode::Barrier => {
785 encoder.memory_barrier();
786 tracker.reset();
787 barrier_count += 1;
788 }
789 }
790 }
791 barrier_count
792}
793
794// ---------------------------------------------------------------------------
795// ReorderConflictTracker — range-based conflict detection for the reorder pass
796// ---------------------------------------------------------------------------
797
798/// Memory range conflict tracker for the reorder pass (Phase 4e.3).
799///
800/// Works with `MemRange` tuples `(start, end)` stored on `CapturedNode::Dispatch`,
801/// rather than requiring live `&MlxBuffer` references. This is the reorder-time
802/// equivalent of the runtime `ConflictTracker`.
803///
804/// Conflict rules match llama.cpp's `ggml_mem_ranges_check`:
805/// - Two read ranges: OK (read-read is concurrent-safe)
806/// - A new read overlapping an existing write: CONFLICT (RAW)
807/// - A new write overlapping any existing range: CONFLICT (WAR/WAW)
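///
/// A minimal illustration of the rules (not compiled as a doctest):
///
/// ```ignore
/// let mut t = ReorderConflictTracker::new();
/// t.add(&[], &[(0, 64)]);                  // a write covering bytes [0, 64)
/// assert!(t.conflicts(&[(0, 16)], &[]));   // RAW: the read overlaps the write
/// assert!(!t.conflicts(&[(64, 96)], &[])); // disjoint read: no conflict
/// ```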
808struct ReorderConflictTracker {
809 /// (start, end, is_write) for all ranges in the tracked set.
810 ranges: Vec<(usize, usize, bool)>,
811}
812
813impl ReorderConflictTracker {
814 fn new() -> Self {
815 Self {
816 ranges: Vec::with_capacity(64),
817 }
818 }
819
820 fn reset(&mut self) {
821 self.ranges.clear();
822 }
823
824 /// Check if a dispatch with the given read/write ranges conflicts with
825 /// any range in this tracker.
826 fn conflicts(&self, reads: &[MemRange], writes: &[MemRange]) -> bool {
827 // New reads vs existing writes (RAW)
828 for &(r_start, r_end) in reads {
829 for &(s, e, is_write) in &self.ranges {
830 if is_write && r_start < e && r_end > s {
831 return true;
832 }
833 }
834 }
835 // New writes vs all existing ranges (WAR/WAW)
836 for &(w_start, w_end) in writes {
837 for &(s, e, _) in &self.ranges {
838 if w_start < e && w_end > s {
839 return true;
840 }
841 }
842 }
843 false
844 }
845
846 /// Add read and write ranges to the tracked set.
847 fn add(&mut self, reads: &[MemRange], writes: &[MemRange]) {
848 for &(start, end) in reads {
849 self.ranges.push((start, end, false));
850 }
851 for &(start, end) in writes {
852 self.ranges.push((start, end, true));
853 }
854 }
855}
856
857/// Batched Metal dispatch — encodes multiple ops into a single `CommandEncoder`.
858///
859/// Create one per model (or per forward-pass loop). Call [`begin`](Self::begin)
860/// at the start of each forward pass to get a [`GraphSession`] that holds the
861/// shared encoder.
862pub struct GraphExecutor {
863 device: MlxDevice,
864}
865
866impl GraphExecutor {
867 /// Create a new graph executor backed by the given device.
868 pub fn new(device: MlxDevice) -> Self {
869 Self { device }
870 }
871
872 /// Begin a new forward pass (direct-dispatch mode).
873 ///
874 /// Returns a [`GraphSession`] that holds a fresh `CommandEncoder`. All ops
875 /// encoded through the session share this single encoder. Call
876 /// [`GraphSession::finish`] to commit and wait.
877 pub fn begin(&self) -> Result<GraphSession<'_>> {
878 let encoder = self.device.command_encoder()?;
879 // ADR-015 iter63 (Phase B): start a programmatic capture if
880 // MLX_METAL_CAPTURE+METAL_CAPTURE_ENABLED are set. No-op when
881 // unset; one-shot per process so subsequent forward passes do
882 // not pay the env-check cost more than once.
883 let metal_capture = {
884 let mut c = crate::metal_capture::MetalCapture::from_env(&self.device);
885 if let Some(ref mut cap) = c {
886 cap.begin();
887 }
888 c
889 };
890 Ok(GraphSession {
891 encoder,
892 device: &self.device,
893 barrier_count: 0,
894 tracker: ConflictTracker::new(),
895 dispatch_in_group: 0,
896 total_dispatches: 0,
897 group_sizes: [0; 8],
898 recording: false,
899 metal_capture,
900 })
901 }
902
903 /// Begin a new forward pass in capture (record) mode.
904 ///
905 /// All op calls are recorded into a `ComputeGraph` instead of being
906 /// dispatched to Metal. When [`GraphSession::finish`] is called, the
907 /// recorded graph is replayed into a fresh encoder via
908 /// `ComputeGraph::encode_sequential()`.
909 ///
910 /// The API is identical to `begin()` — callers do not need to change
911 /// any op call code. The only behavioral difference: GPU work happens
912 /// at `finish()` time rather than at each op call.
913 pub fn begin_recorded(&self) -> Result<GraphSession<'_>> {
914 let mut encoder = self.device.command_encoder()?;
915 encoder.start_capture();
916 // ADR-015 iter63 (Phase B): see GraphExecutor::begin().
917 let metal_capture = {
918 let mut c = crate::metal_capture::MetalCapture::from_env(&self.device);
919 if let Some(ref mut cap) = c {
920 cap.begin();
921 }
922 c
923 };
924 Ok(GraphSession {
925 encoder,
926 device: &self.device,
927 barrier_count: 0,
928 tracker: ConflictTracker::new(),
929 dispatch_in_group: 0,
930 total_dispatches: 0,
931 group_sizes: [0; 8],
932 recording: true,
933 metal_capture,
934 })
935 }
936
937 /// Borrow the underlying device.
938 pub fn device(&self) -> &MlxDevice {
939 &self.device
940 }
941}
942
951/// Tracks buffer address ranges for automatic barrier elision.
952///
953/// Mirrors llama.cpp's `ggml_mem_ranges` — accumulates the read and write
954/// ranges of all dispatches in the current concurrent group. When a new
955/// dispatch's reads overlap with an existing write (RAW), or its writes
956/// overlap with an existing read or write (WAR/WAW), a barrier is needed.
957/// Otherwise the dispatch can run concurrently and the barrier is elided.
958///
959/// Uses CPU-visible `contents_ptr()` addresses, which on Apple Silicon
960/// unified memory equal the GPU addresses.
961pub struct ConflictTracker {
962 /// (start, end, is_write) tuples for the current concurrent group.
963 ranges: Vec<(usize, usize, bool)>,
964}
965
966impl ConflictTracker {
967 fn new() -> Self {
968 Self {
969 ranges: Vec::with_capacity(32),
970 }
971 }
972
973 /// Reset the tracker — called after emitting a barrier.
974 fn reset(&mut self) {
975 self.ranges.clear();
976 }
977
978 /// Check if a new dispatch with the given reads and writes conflicts
979 /// with the current concurrent group.
980 ///
981 /// Conflict rules (same as llama.cpp `ggml_mem_ranges_check`):
982 /// - Two SRC (read) ranges in the same buffer: OK (read-read)
983 /// - A new SRC overlapping an existing DST: CONFLICT (RAW)
984 /// - A new DST overlapping an existing SRC or DST: CONFLICT (WAR/WAW)
    ///
    /// Returns `Some((conflict_kind, new_range_start, existing_range_start))`
    /// if a conflict is found, `None` otherwise.
987 fn conflicts_reason(&self, reads: &[&MlxBuffer], writes: &[&MlxBuffer])
988 -> Option<(&'static str, usize, usize)>
989 {
990 // Check new reads against existing writes (RAW)
991 for r in reads {
992 let r_start = r.contents_ptr() as usize;
993 let r_end = r_start + r.byte_len();
994 for &(s, e, is_write) in &self.ranges {
995 if is_write && r_start < e && r_end > s {
996 return Some(("RAW", r_start, s));
997 }
998 }
999 }
1000 // Check new writes against existing reads and writes (WAR/WAW)
1001 for w in writes {
1002 let w_start = w.contents_ptr() as usize;
1003 let w_end = w_start + w.byte_len();
1004 for &(s, e, is_write) in &self.ranges {
1005 if w_start < e && w_end > s {
1006 let kind = if is_write { "WAW" } else { "WAR" };
1007 return Some((kind, w_start, s));
1008 }
1009 }
1010 }
1011 None
1012 }
1013
1014 /// Add read and write ranges to the current concurrent group.
1015 fn add(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
1016 for r in reads {
1017 let start = r.contents_ptr() as usize;
1018 let end = start + r.byte_len();
1019 self.ranges.push((start, end, false));
1020 }
1021 for w in writes {
1022 let start = w.contents_ptr() as usize;
1023 let end = start + w.byte_len();
1024 self.ranges.push((start, end, true));
1025 }
1026 }
1027}
1028
/// A single forward pass execution context.
///
/// All ops are encoded into one `CommandEncoder`. Call [`finish`](Self::finish)
/// to commit the command buffer and wait for GPU completion — this is the ONLY
/// sync point per forward pass.
///
/// If an op returns an error, the session can be dropped without committing.
/// The underlying command buffer is abandoned (never committed to the GPU).
1029pub struct GraphSession<'a> {
1030 encoder: CommandEncoder,
1031 device: &'a MlxDevice,
1032 barrier_count: u32,
1033 tracker: ConflictTracker,
1034 dispatch_in_group: u32,
1035 total_dispatches: u32,
    /// Histogram: `group_sizes[i]` = number of concurrent groups with `i + 1`
    /// dispatches (the last bucket also absorbs groups larger than 8).
1037 group_sizes: [u32; 8],
1038 /// Whether this session was created in capture/record mode.
1039 recording: bool,
1040 /// ADR-015 iter63 (Phase B): optional Metal frame capture wrapping
1041 /// this session's GPU work. Populated by `MetalCapture::from_env`
1042 /// when `MLX_METAL_CAPTURE=<path>` + `METAL_CAPTURE_ENABLED=1` are
1043 /// both set in the process env AND the process-global one-shot
1044 /// latch has not yet flipped. `None` in all other cases (default
1045 /// production path). Capture scope is begun in
1046 /// [`GraphExecutor::begin`] / [`GraphExecutor::begin_recorded`]
1047 /// and ended in [`Self::finish`] / [`Self::commit`] BEFORE the
1048 /// CB is committed, so all enqueued CBs land inside the trace.
1049 metal_capture: Option<crate::metal_capture::MetalCapture>,
1050}
1051
1052impl<'a> GraphSession<'a> {
1053 /// Encode an RMS normalization into this session's encoder.
1054 ///
1055 /// Delegates to [`ops::rms_norm::dispatch_rms_norm`].
1056 pub fn rms_norm(
1057 &mut self,
1058 registry: &mut KernelRegistry,
1059 device: &metal::DeviceRef,
1060 input: &MlxBuffer,
1061 weight: &MlxBuffer,
1062 output: &MlxBuffer,
1063 params_buf: &MlxBuffer,
1064 rows: u32,
1065 dim: u32,
1066 ) -> Result<()> {
1067 ops::rms_norm::dispatch_rms_norm(
1068 &mut self.encoder,
1069 registry,
1070 device,
1071 input,
1072 weight,
1073 output,
1074 params_buf,
1075 rows,
1076 dim,
1077 )
1078 }
1079
1080 /// Encode a quantized matrix multiplication into this session's encoder.
1081 ///
1082 /// Delegates to [`ops::quantized_matmul::quantized_matmul`].
1083 /// Returns the freshly allocated output buffer.
1084 pub fn quantized_matmul(
1085 &mut self,
1086 registry: &mut KernelRegistry,
1087 device: &MlxDevice,
1088 input: &MlxBuffer,
1089 weight: &MlxBuffer,
1090 scales: &MlxBuffer,
1091 biases: &MlxBuffer,
1092 params: &ops::quantized_matmul::QuantizedMatmulParams,
1093 ) -> Result<MlxBuffer> {
1094 ops::quantized_matmul::quantized_matmul(
1095 &mut self.encoder,
1096 registry,
1097 device,
1098 input,
1099 weight,
1100 scales,
1101 biases,
1102 params,
1103 )
1104 }
1105
1106 /// Encode a SIMD-optimized quantized matmul into this session's encoder.
1107 ///
1108 /// Delegates to [`ops::quantized_matmul::quantized_matmul_simd`].
1109 /// Returns the freshly allocated output buffer.
1110 pub fn quantized_matmul_simd(
1111 &mut self,
1112 registry: &mut KernelRegistry,
1113 device: &MlxDevice,
1114 input: &MlxBuffer,
1115 weight: &MlxBuffer,
1116 scales: &MlxBuffer,
1117 biases: &MlxBuffer,
1118 params: &ops::quantized_matmul::QuantizedMatmulParams,
1119 ) -> Result<MlxBuffer> {
1120 ops::quantized_matmul::quantized_matmul_simd(
1121 &mut self.encoder,
1122 registry,
1123 device,
1124 input,
1125 weight,
1126 scales,
1127 biases,
1128 params,
1129 )
1130 }
1131
1132 /// Encode a GGML block-format quantized mat-vec into this session's encoder.
1133 ///
1134 /// Delegates to [`ops::quantized_matmul_ggml::quantized_matmul_ggml`].
1135 pub fn quantized_matmul_ggml(
1136 &mut self,
1137 registry: &mut KernelRegistry,
1138 device: &MlxDevice,
1139 input: &MlxBuffer,
1140 weight: &MlxBuffer,
1141 output: &mut MlxBuffer,
1142 params: &ops::quantized_matmul_ggml::GgmlQuantizedMatmulParams,
1143 ) -> Result<()> {
1144 ops::quantized_matmul_ggml::quantized_matmul_ggml(
1145 &mut self.encoder,
1146 registry,
1147 device,
1148 input,
1149 weight,
1150 output,
1151 params,
1152 )
1153 }
1154
1155 /// Encode an expert-routed GGML block-format quantized mat-vec into this session's encoder.
1156 ///
1157 /// Delegates to [`ops::quantized_matmul_id_ggml::quantized_matmul_id_ggml`].
1158 #[allow(clippy::too_many_arguments)]
1159 pub fn quantized_matmul_id_ggml(
1160 &mut self,
1161 registry: &mut KernelRegistry,
1162 device: &MlxDevice,
1163 input: &MlxBuffer,
1164 weight: &MlxBuffer,
1165 ids: &MlxBuffer,
1166 output: &mut MlxBuffer,
1167 params: &ops::quantized_matmul_id_ggml::GgmlQuantizedMatmulIdParams,
1168 ) -> Result<()> {
1169 ops::quantized_matmul_id_ggml::quantized_matmul_id_ggml(
1170 &mut self.encoder,
1171 registry,
1172 device,
1173 input,
1174 weight,
1175 ids,
1176 output,
1177 params,
1178 )
1179 }
1180
1181 /// Pooled-scratch variant of [`Self::quantized_matmul_id_ggml`] — the
1182 /// `IdMmScratch` is caller-owned so batched prefill amortises the
1183 /// per-call allocations that the auto entry point incurs (ADR-011
1184 /// Phase 3 Wave P3b).
1185 #[allow(clippy::too_many_arguments)]
1186 pub fn quantized_matmul_id_ggml_pooled(
1187 &mut self,
1188 registry: &mut KernelRegistry,
1189 device: &MlxDevice,
1190 input: &MlxBuffer,
1191 weight: &MlxBuffer,
1192 ids: &MlxBuffer,
1193 output: &mut MlxBuffer,
1194 scratch: &mut ops::quantized_matmul_id_ggml::IdMmScratch,
1195 params: &ops::quantized_matmul_id_ggml::GgmlQuantizedMatmulIdParams,
1196 ) -> Result<()> {
1197 ops::quantized_matmul_id_ggml::quantized_matmul_id_ggml_pooled(
1198 &mut self.encoder,
1199 registry,
1200 device,
1201 input,
1202 weight,
1203 ids,
1204 output,
1205 scratch,
1206 params,
1207 )
1208 }
1209
1210 /// Encode scaled dot-product attention into this session's encoder.
1211 ///
1212 /// Delegates to [`ops::sdpa::sdpa`].
1213 pub fn sdpa(
1214 &mut self,
1215 registry: &mut KernelRegistry,
1216 device: &MlxDevice,
1217 q: &MlxBuffer,
1218 k: &MlxBuffer,
1219 v: &MlxBuffer,
1220 output: &MlxBuffer,
1221 params: &ops::sdpa::SdpaParams,
1222 batch_size: u32,
1223 ) -> Result<()> {
1224 ops::sdpa::sdpa(
1225 &mut self.encoder,
1226 registry,
1227 device,
1228 q,
1229 k,
1230 v,
1231 output,
1232 params,
1233 batch_size,
1234 )
1235 }
1236
1237 /// Encode flash attention vector (SIMD-vectorized decode-path SDPA).
1238 ///
1239 /// Delegates to [`ops::flash_attn_vec::flash_attn_vec`].
1240 pub fn flash_attn_vec(
1241 &mut self,
1242 registry: &mut KernelRegistry,
1243 device: &MlxDevice,
1244 q: &MlxBuffer,
1245 k: &MlxBuffer,
1246 v: &MlxBuffer,
1247 output: &MlxBuffer,
1248 tmp: &MlxBuffer,
1249 params: &ops::flash_attn_vec::FlashAttnVecParams,
1250 ) -> Result<()> {
1251 ops::flash_attn_vec::flash_attn_vec(
1252 &mut self.encoder,
1253 registry,
1254 device,
1255 q,
1256 k,
1257 v,
1258 output,
1259 tmp,
1260 params,
1261 )
1262 }
1263
1264 /// Encode an elementwise add into this session's encoder.
1265 ///
1266 /// Delegates to [`ops::elementwise::elementwise_add`].
1267 pub fn elementwise_add(
1268 &mut self,
1269 registry: &mut KernelRegistry,
1270 device: &metal::DeviceRef,
1271 a: &MlxBuffer,
1272 b: &MlxBuffer,
1273 output: &MlxBuffer,
1274 n_elements: usize,
1275 dtype: DType,
1276 ) -> Result<()> {
1277 ops::elementwise::elementwise_add(
1278 &mut self.encoder,
1279 registry,
1280 device,
1281 a,
1282 b,
1283 output,
1284 n_elements,
1285 dtype,
1286 )
1287 }
1288
1289 /// Encode an elementwise multiply into this session's encoder.
1290 ///
1291 /// Delegates to [`ops::elementwise::elementwise_mul`].
1292 pub fn elementwise_mul(
1293 &mut self,
1294 registry: &mut KernelRegistry,
1295 device: &metal::DeviceRef,
1296 a: &MlxBuffer,
1297 b: &MlxBuffer,
1298 output: &MlxBuffer,
1299 n_elements: usize,
1300 dtype: DType,
1301 ) -> Result<()> {
1302 ops::elementwise::elementwise_mul(
1303 &mut self.encoder,
1304 registry,
1305 device,
1306 a,
1307 b,
1308 output,
1309 n_elements,
1310 dtype,
1311 )
1312 }
1313
1314 /// Encode a RoPE transform into this session's encoder.
1315 ///
1316 /// Delegates to [`ops::rope::dispatch_rope`].
1317 pub fn rope(
1318 &mut self,
1319 registry: &mut KernelRegistry,
1320 device: &metal::DeviceRef,
1321 input: &MlxBuffer,
1322 output: &MlxBuffer,
1323 params_buf: &MlxBuffer,
1324 positions_buf: &MlxBuffer,
1325 seq_len: u32,
1326 head_dim: u32,
1327 ) -> Result<()> {
1328 ops::rope::dispatch_rope(
1329 &mut self.encoder,
1330 registry,
1331 device,
1332 input,
1333 output,
1334 params_buf,
1335 positions_buf,
1336 seq_len,
1337 head_dim,
1338 )
1339 }
1340
1341 /// Encode a GELU activation into this session's encoder.
1342 ///
1343 /// Delegates to [`ops::gelu::dispatch_gelu`].
1344 pub fn gelu(
1345 &mut self,
1346 registry: &mut KernelRegistry,
1347 device: &metal::DeviceRef,
1348 input: &MlxBuffer,
1349 output: &MlxBuffer,
1350 ) -> Result<()> {
1351 ops::gelu::dispatch_gelu(
1352 &mut self.encoder,
1353 registry,
1354 device,
1355 input,
1356 output,
1357 )
1358 }
1359
1360 /// Encode a softmax into this session's encoder.
1361 ///
1362 /// Delegates to [`ops::softmax::dispatch_softmax`].
1363 pub fn softmax(
1364 &mut self,
1365 registry: &mut KernelRegistry,
1366 device: &metal::DeviceRef,
1367 input: &MlxBuffer,
1368 output: &MlxBuffer,
1369 params_buf: &MlxBuffer,
1370 rows: u32,
1371 cols: u32,
1372 ) -> Result<()> {
1373 ops::softmax::dispatch_softmax(
1374 &mut self.encoder,
1375 registry,
1376 device,
1377 input,
1378 output,
1379 params_buf,
1380 rows,
1381 cols,
1382 )
1383 }
1384
1385 /// Encode a softcap into this session's encoder.
1386 ///
1387 /// Delegates to [`ops::softcap::dispatch_softcap`].
1388 pub fn softcap(
1389 &mut self,
1390 registry: &mut KernelRegistry,
1391 device: &metal::DeviceRef,
1392 input: &MlxBuffer,
1393 output: &MlxBuffer,
1394 params_buf: &MlxBuffer,
1395 cap: f32,
1396 ) -> Result<()> {
1397 ops::softcap::dispatch_softcap(
1398 &mut self.encoder,
1399 registry,
1400 device,
1401 input,
1402 output,
1403 params_buf,
1404 cap,
1405 )
1406 }
1407
1408 /// Encode an RMS norm without learned scale (f32) into this session's encoder.
1409 ///
1410 /// Delegates to [`ops::rms_norm::dispatch_rms_norm_no_scale_f32`].
1411 pub fn rms_norm_no_scale_f32(
1412 &mut self,
1413 registry: &mut KernelRegistry,
1414 device: &metal::DeviceRef,
1415 input: &MlxBuffer,
1416 output: &MlxBuffer,
1417 params_buf: &MlxBuffer,
1418 rows: u32,
1419 dim: u32,
1420 ) -> Result<()> {
1421 ops::rms_norm::dispatch_rms_norm_no_scale_f32(
1422 &mut self.encoder,
1423 registry,
1424 device,
1425 input,
1426 output,
1427 params_buf,
1428 rows,
1429 dim,
1430 )
1431 }
1432
1433 /// Encode a NeoX RoPE (f32) with optional freq_factors into this session's encoder.
1434 ///
1435 /// Delegates to [`ops::rope::dispatch_rope_neox_f32`].
1436 #[allow(clippy::too_many_arguments)]
1437 pub fn rope_neox_f32(
1438 &mut self,
1439 registry: &mut KernelRegistry,
1440 device: &metal::DeviceRef,
1441 input: &MlxBuffer,
1442 output: &MlxBuffer,
1443 params_buf: &MlxBuffer,
1444 positions_buf: &MlxBuffer,
1445 freq_factors: Option<&MlxBuffer>,
1446 seq_len: u32,
1447 n_heads: u32,
1448 head_dim: u32,
1449 rope_dim: u32,
1450 ) -> Result<()> {
1451 ops::rope::dispatch_rope_neox_f32(
1452 &mut self.encoder,
1453 registry,
1454 device,
1455 input,
1456 output,
1457 params_buf,
1458 positions_buf,
1459 freq_factors,
1460 seq_len,
1461 n_heads,
1462 head_dim,
1463 rope_dim,
1464 )
1465 }
1466
1467 /// Insert a GPU memory barrier (MTLBarrierScopeBuffers).
1468 ///
1469 /// Unconditional barrier — always emits. Use `barrier_between` for
1470 /// automatic conflict detection that can elide unnecessary barriers.
1471 #[inline]
1472 pub fn barrier(&mut self) {
1473 // Record the outgoing group size
1474 if self.dispatch_in_group > 0 {
1475 let idx = (self.dispatch_in_group as usize).min(self.group_sizes.len()) - 1;
1476 self.group_sizes[idx] += 1;
1477 }
1478 self.encoder.memory_barrier();
1479 self.tracker.reset();
1480 self.barrier_count += 1;
1481 self.dispatch_in_group = 0;
1482 }
1483
1484 /// Smart barrier with conflict detection.
1485 ///
1486 /// Checks if the next dispatch (with the given read and write buffers)
1487 /// actually conflicts with any dispatch in the current concurrent group.
1488 /// If yes, emits a Metal barrier and resets the tracker. If no, the
1489 /// barrier is elided and the dispatch can run concurrently.
1490 ///
1491 /// This mirrors llama.cpp's `ggml_metal_op_concurrency_check` +
1492 /// `ggml_metal_op_concurrency_reset` pattern.
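    ///
    /// A typical per-op call pattern (sketch; `x`, `w`, `y` are caller-owned
    /// `MlxBuffer`s, not compiled as a doctest):
    ///
    /// ```ignore
    /// session.barrier_between(&[&x, &w], &[&y]); // barrier only on a real conflict
    /// session.rms_norm(&mut registry, device.metal_device(), &x, &w, &y, &params, rows, dim)?;
    /// ```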
1493 #[inline]
1494 pub fn barrier_between(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
1495 // In capture mode, stash the read/write ranges so the next captured
1496 // dispatch node carries them for the reorder pass (Phase 4e.3).
1497 if self.recording {
1498 let read_ranges: Vec<MemRange> = reads
1499 .iter()
1500 .map(|b| {
1501 let start = b.contents_ptr() as usize;
1502 (start, start + b.byte_len())
1503 })
1504 .collect();
1505 let write_ranges: Vec<MemRange> = writes
1506 .iter()
1507 .map(|b| {
1508 let start = b.contents_ptr() as usize;
1509 (start, start + b.byte_len())
1510 })
1511 .collect();
1512 self.encoder.set_pending_buffer_ranges(read_ranges, write_ranges);
1513 }
1514
1515 let reason = self.tracker.conflicts_reason(reads, writes);
1516 if let Some((_kind, _new_ptr, _existing_ptr)) = reason {
1517 // Record the outgoing group size before resetting
1518 if self.dispatch_in_group > 0 {
1519 let idx = (self.dispatch_in_group as usize).min(self.group_sizes.len()) - 1;
1520 self.group_sizes[idx] += 1;
1521 }
1522 self.encoder.memory_barrier();
1523 self.tracker.reset();
1524 self.barrier_count += 1;
1525 self.dispatch_in_group = 0;
1526 }
1527 self.dispatch_in_group += 1;
1528 self.total_dispatches += 1;
1529 self.tracker.add(reads, writes);
1530 }
1531
1532 /// Print group size histogram to stderr (for HF2Q_MLX_TIMING debug).
1533 pub fn dump_group_stats(&self) {
1534 // Record the final (unterminated) group
1535 let mut gs = self.group_sizes;
1536 if self.dispatch_in_group > 0 {
1537 let idx = (self.dispatch_in_group as usize).min(gs.len()) - 1;
1538 gs[idx] += 1;
1539 }
1540 let total_groups: u32 = gs.iter().sum();
1541 eprintln!(" [GROUP_STATS] dispatches={} barriers={} groups={} ratio={:.2}",
1542 self.total_dispatches, self.barrier_count, total_groups,
1543 if total_groups > 0 { self.total_dispatches as f64 / total_groups as f64 } else { 0.0 });
1544 for (i, &count) in gs.iter().enumerate() {
1545 if count > 0 {
1546 eprintln!(" size {}: {} groups", i + 1, count);
1547 }
1548 }
1549 }
1550
1551 /// Register a dispatch's buffer ranges without checking for conflicts.
1552 ///
1553 /// Use after dispatching an op that doesn't need a barrier check (e.g.,
1554 /// the first dispatch in a session, or dispatches known to be concurrent).
1555 ///
1556 /// In recording mode, also retroactively annotates the most recently
1557 /// captured dispatch node with these ranges if it was missing them.
1558 /// That keeps the reorder pass able to reason about dispatches that
1559 /// were preceded by `track_dispatch` rather than `barrier_between`.
1560 #[inline]
1561 pub fn track_dispatch(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
1562 if self.recording {
1563 let read_ranges: Vec<MemRange> = reads
1564 .iter()
1565 .map(|b| {
1566 let start = b.contents_ptr() as usize;
1567 (start, start + b.byte_len())
1568 })
1569 .collect();
1570 let write_ranges: Vec<MemRange> = writes
1571 .iter()
1572 .map(|b| {
1573 let start = b.contents_ptr() as usize;
1574 (start, start + b.byte_len())
1575 })
1576 .collect();
1577 self.encoder
1578 .annotate_last_dispatch_if_missing(read_ranges, write_ranges);
1579 }
1580 self.tracker.add(reads, writes);
1581 }
1582
1583 /// Return the number of barriers inserted so far in this session.
1584 #[inline]
1585 pub fn barrier_count(&self) -> u32 {
1586 self.barrier_count
1587 }
1588
1589 /// Cumulative nanoseconds spent in ConflictTracker checks (diagnostic).
1590 /// Returns 0 when timing is not compiled in.
1591 pub fn tracker_overhead_ns(&self) -> u64 {
1592 0
1593 }
1594
1595 /// Borrow the underlying command encoder for direct op dispatch.
1596 ///
1597 /// Use this when you need to call an op function that is not wrapped by
1598 /// a `GraphSession` method. The returned encoder is the same shared
1599 /// encoder — all dispatches still go into the same command buffer.
1600 pub fn encoder_mut(&mut self) -> &mut CommandEncoder {
1601 &mut self.encoder
1602 }
1603
1604 /// Borrow the device reference.
1605 pub fn device(&self) -> &MlxDevice {
1606 self.device
1607 }
1608
1609 /// Whether this session is in capture/record mode.
1610 pub fn is_recording(&self) -> bool {
1611 self.recording
1612 }
1613
1614 /// Commit the command buffer and wait for GPU completion.
1615 ///
1616 /// This is the ONLY sync point per forward pass. After this call, all
1617 /// output buffers are readable by the CPU.
1618 ///
1619 /// In recording mode: extracts the captured graph, replays it into
1620 /// the encoder via `ComputeGraph::encode_sequential()`, then commits
1621 /// and waits. The result is identical to the direct-dispatch path.
1622 ///
1623 /// Consumes the session — no further ops can be encoded.
1624 pub fn finish(mut self) -> Result<()> {
1625 if self.recording {
1626 if let Some(nodes) = self.encoder.take_capture() {
1627 let graph = ComputeGraph::from_nodes(nodes);
1628 graph.encode_sequential(&mut self.encoder);
1629 }
1630 }
1631 self.encoder.commit_and_wait()
1632 }
1633
1634 /// Commit the command buffer WITHOUT waiting.
1635 ///
1636 /// The GPU begins executing immediately. Use this for fire-and-forget
1637 /// dispatch when you do not need results until later.
1638 ///
1639 /// In recording mode: replays the captured graph before committing.
1640 ///
1641 /// Consumes the session.
1642 pub fn commit(mut self) -> CommandEncoder {
1643 if self.recording {
1644 if let Some(nodes) = self.encoder.take_capture() {
1645 let graph = ComputeGraph::from_nodes(nodes);
1646 graph.encode_sequential(&mut self.encoder);
1647 }
1648 }
1649 self.encoder.commit();
1650 // ADR-015 iter63 (Phase B): close the capture window AFTER
1651 // commit (so the CB is recorded inside the trace) but BEFORE
1652 // returning the encoder (so the trace finalizes promptly).
1653 // `MTLCaptureManager.stopCapture` marks the recording
1654 // boundary at exactly this point. CBs committed BEFORE this
1655 // line are in the trace; any work done by the caller through
1656 // the returned encoder is NOT. This matches llama.cpp's
1657 // ggml-metal-context.m:608 pattern (`stopCapture` after the
1658 // last `commit + waitUntilCompleted`).
1659 //
1660 // Note: for the async commit() path, the GPU may still be
1661 // executing the CB when stopCapture fires. Apple's
1662 // MTLCaptureManager is documented to flush in-flight work
1663 // into the trace before finalizing the file.
1664 if let Some(mut cap) = self.metal_capture.take() {
1665 cap.end();
1666 }
1667 self.encoder
1668 }
1669
1670 /// Commit the command buffer and wait, returning split timing.
1671 ///
1672 /// Returns `(encoding_ns, gpu_wait_ns)` where:
1673 /// - `encoding_ns` is the time from session begin to commit (CPU encoding)
1674 /// - `gpu_wait_ns` is the time from commit to GPU completion
1675 ///
1676 /// The `session_begin` instant should be captured right after `exec.begin()`.
1677 ///
1678 /// In recording mode: replays the captured graph before committing.
1679 ///
1680 /// Consumes the session.
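    ///
    /// A usage sketch (not compiled as a doctest):
    ///
    /// ```ignore
    /// let mut session = executor.begin()?;
    /// let session_begin = std::time::Instant::now();
    /// // ... encode ops ...
    /// let (encoding_ns, gpu_wait_ns) = session.finish_with_timing(session_begin)?;
    /// ```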
1681 pub fn finish_with_timing(mut self, session_begin: std::time::Instant) -> Result<(u64, u64)> {
1682 if self.recording {
1683 if let Some(nodes) = self.encoder.take_capture() {
1684 let graph = ComputeGraph::from_nodes(nodes);
1685 graph.encode_sequential(&mut self.encoder);
1686 }
1687 }
1688 let commit_start = std::time::Instant::now();
1689 let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
1690 self.encoder.commit();
1691 self.encoder.wait_until_completed()?;
1692 let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
1693 Ok((encoding_ns, gpu_wait_ns))
1694 }
1695
1696 /// Finish this session and return the GPU wall-clock interval in ns.
1697 ///
    /// Returns the GPU interval in nanoseconds: the CFTimeInterval
1699 /// difference between `MTLCommandBuffer.GPUEndTime` and
1700 /// `MTLCommandBuffer.GPUStartTime`, converted to ns. Excludes CPU
1701 /// commit+wait overhead — that appears in the residual when bucket
1702 /// sums are compared to the outer wall-clock.
1703 ///
1704 /// Used by `HF2Q_PROFILE_GPU_TS=1` to accumulate pure GPU time per
1705 /// op bucket. In recording mode: replays the captured graph before
1706 /// committing.
1707 ///
1708 /// Consumes the session.
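    ///
    /// A usage sketch (not compiled as a doctest):
    ///
    /// ```ignore
    /// let gpu_ns = session.finish_with_gpu_time()?;
    /// eprintln!("pure GPU time: {:.3} ms", gpu_ns as f64 / 1.0e6);
    /// ```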
1709 pub fn finish_with_gpu_time(mut self) -> Result<u64> {
1710 if self.recording {
1711 if let Some(nodes) = self.encoder.take_capture() {
1712 let graph = ComputeGraph::from_nodes(nodes);
1713 graph.encode_sequential(&mut self.encoder);
1714 }
1715 }
1716 let (gs, ge) = self.encoder.commit_wait_with_gpu_time()?;
1717 // GPUStartTime/GPUEndTime are CFTimeInterval (seconds, double).
1718 // Guard against negative deltas (can happen on the first CB of
1719 // a run if the kernel driver lazily initialises the timeline;
1720 // clamp to zero in that case).
1721 let delta = (ge - gs).max(0.0);
1722 Ok((delta * 1.0e9) as u64)
1723 }
1724
1725 /// Finish with fusion: run the RMS norm + MUL fusion pass before
1726 /// replaying the graph.
1727 ///
1728 /// Only meaningful in recording mode. In direct-dispatch mode, this
1729 /// behaves identically to `finish()`.
1730 ///
    /// Returns the number of fusions applied on success.
1732 pub fn finish_with_fusion(
1733 mut self,
1734 registry: &mut KernelRegistry,
1735 device: &metal::DeviceRef,
1736 ) -> Result<u32> {
1737 let mut fusions = 0;
1738 if self.recording {
1739 if let Some(nodes) = self.encoder.take_capture() {
1740 let mut graph = ComputeGraph::from_nodes(nodes);
1741 fusions = graph.fuse(registry, device)?;
1742 graph.encode_sequential(&mut self.encoder);
1743 }
1744 }
1745 self.encoder.commit_and_wait()?;
1746 Ok(fusions)
1747 }
1748
    /// Finish with fusion and split timing.
    ///
    /// Like `finish_with_timing` but runs the fusion pass first.
    /// Returns `(encoding_ns, gpu_wait_ns, fusions_applied)`.
    pub fn finish_with_fusion_and_timing(
        mut self,
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
        session_begin: std::time::Instant,
    ) -> Result<(u64, u64, u32)> {
        let mut fusions = 0;
        if self.recording {
            if let Some(nodes) = self.encoder.take_capture() {
                let mut graph = ComputeGraph::from_nodes(nodes);
                fusions = graph.fuse(registry, device)?;
                graph.encode_sequential(&mut self.encoder);
            }
        }
        let commit_start = std::time::Instant::now();
        let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
        self.encoder.commit();
        self.encoder.wait_until_completed()?;
        let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
        Ok((encoding_ns, gpu_wait_ns, fusions))
    }

    /// Finish with fusion AND reorder: run both graph optimization passes
    /// before replaying the graph.
    ///
    /// Only meaningful in recording mode. In direct-dispatch mode, this
    /// behaves identically to `finish()`.
    ///
    /// Returns `(fusions_applied, nodes_reordered)` on success.
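    ///
    /// # Example
    ///
    /// A minimal sketch (same illustrative setup as in the module docs),
    /// showing the two counters returned:
    ///
    /// ```ignore
    /// let mut session = executor.begin_recorded()?;
    /// // ... record the forward pass ...
    /// let (fusions, reordered) =
    ///     session.finish_with_fusion_and_reorder(&mut registry, device.metal_device())?;
    /// eprintln!("fused {} pairs, reordered {} nodes", fusions, reordered);
    /// ```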
    pub fn finish_with_fusion_and_reorder(
        mut self,
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
    ) -> Result<(u32, u32)> {
        let mut fusions = 0;
        let mut reordered = 0;
        if self.recording {
            if let Some(nodes) = self.encoder.take_capture() {
                let mut graph = ComputeGraph::from_nodes(nodes);
                fusions = graph.fuse(registry, device)?;
                reordered = graph.reorder();
                graph.encode_with_barriers(&mut self.encoder);
            }
        }
        self.encoder.commit_and_wait()?;
        Ok((fusions, reordered))
    }

    /// Finish with fusion, reorder, and split timing.
    ///
    /// Like `finish_with_fusion_and_timing` but also runs the reorder pass.
    /// Returns `(encoding_ns, gpu_wait_ns, fusions_applied, nodes_reordered)`.
    pub fn finish_with_fusion_reorder_and_timing(
        mut self,
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
        session_begin: std::time::Instant,
    ) -> Result<(u64, u64, u32, u32)> {
        let mut fusions = 0;
        let mut reordered = 0;
        if self.recording {
            if let Some(nodes) = self.encoder.take_capture() {
                let mut graph = ComputeGraph::from_nodes(nodes);
                fusions = graph.fuse(registry, device)?;
                reordered = graph.reorder();
                graph.encode_with_barriers(&mut self.encoder);
            }
        }
        let commit_start = std::time::Instant::now();
        let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
        self.encoder.commit();
        self.encoder.wait_until_completed()?;
        let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
        Ok((encoding_ns, gpu_wait_ns, fusions, reordered))
    }

    /// Finish with the full optimization pipeline: fuse, reorder, dual-buffer
    /// encode.
    ///
    /// Runs the fusion pass, reorder pass, then encodes the graph into two
    /// Metal command buffers for CPU/GPU overlap. The first ~10% of dispatches
    /// are committed immediately so the GPU can start executing while the CPU
    /// encodes the remaining ~90%.
    ///
    /// Only meaningful in recording mode. In direct-dispatch mode, this
    /// behaves identically to `finish()`.
    ///
    /// Returns `(fusions_applied, nodes_reordered, barriers_buf0, barriers_buf1)`.
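    ///
    /// # Example
    ///
    /// A minimal sketch of the full pipeline; `executor`, `registry`, and
    /// `device` are illustrative and set up as in the module-level usage
    /// example:
    ///
    /// ```ignore
    /// let mut session = executor.begin_recorded()?;
    /// // ... record the forward pass ...
    /// let (fusions, reordered, b0, b1) =
    ///     session.finish_optimized(&mut registry, device.metal_device())?;
    /// eprintln!("fusions={} reordered={} barriers={}+{}", fusions, reordered, b0, b1);
    /// ```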
    pub fn finish_optimized(
        mut self,
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
    ) -> Result<(u32, u32, u32, u32)> {
        let mut fusions = 0;
        let mut reordered = 0;
        let mut barriers0 = 0u32;
        let mut barriers1 = 0u32;

        if self.recording {
            if let Some(nodes) = self.encoder.take_capture() {
                // Commit the capture encoder's empty command buffer so its
                // MTLCommandQueue pool slot is freed (same fix as timing variant).
                self.encoder.commit();

                let mut graph = ComputeGraph::from_nodes(nodes);
                fusions = graph.fuse(registry, device)?;
                reordered = graph.reorder();

                let mut enc0 = self.device.command_encoder()?;
                let mut enc1 = self.device.command_encoder()?;

                let (b0, b1) = graph.encode_dual_buffer(&mut enc0, &mut enc1);
                barriers0 = b0;
                barriers1 = b1;

                // enc0 was already committed inside encode_dual_buffer.
                // Commit enc1 and wait — Metal queue ordering guarantees enc0
                // finishes before enc1 starts executing.
                enc1.commit_and_wait()?;

                // The capture encoder's (empty) command buffer was already
                // committed above, so its queue pool slot is free. Dropping
                // `self.encoder` here just ends the active compute encoder,
                // if any; nothing is left uncommitted.
                return Ok((fusions, reordered, barriers0, barriers1));
            }
        }

        // Direct-dispatch fallback: just commit the original encoder.
        self.encoder.commit_and_wait()?;
        Ok((fusions, reordered, barriers0, barriers1))
    }

    /// Finish with the full optimization pipeline and split timing.
    ///
    /// Like `finish_optimized` but returns timing information.
    /// Returns `(encoding_ns, gpu_wait_ns, fusions, reordered, barriers_buf0, barriers_buf1)`.
    ///
    /// Timing breakdown:
    /// - `encoding_ns`: CPU time from session begin through fusion, reorder,
    ///   and encoding of both chunks (chunk 0 is committed partway through,
    ///   so the GPU starts executing buffer 0 while chunk 1 is still being
    ///   encoded)
    /// - `gpu_wait_ns`: wall time from the second buffer's commit to GPU
    ///   completion of both buffers
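    ///
    /// # Example
    ///
    /// A minimal sketch of reading the timing buckets; the residual against
    /// the outer wall-clock is whatever neither bucket covers (illustrative
    /// setup as in the module docs):
    ///
    /// ```ignore
    /// let session_begin = std::time::Instant::now();
    /// let mut session = executor.begin_recorded()?;
    /// // ... record the forward pass ...
    /// let (encoding_ns, gpu_wait_ns, fusions, reordered, b0, b1) =
    ///     session.finish_optimized_with_timing(&mut registry, device.metal_device(), session_begin)?;
    /// let total_ns = session_begin.elapsed().as_nanos() as u64;
    /// let residual_ns = total_ns.saturating_sub(encoding_ns + gpu_wait_ns);
    /// eprintln!("encode={}us gpu={}us residual={}us ({} fusions, {} reordered, {}+{} barriers)",
    ///           encoding_ns / 1_000, gpu_wait_ns / 1_000, residual_ns / 1_000,
    ///           fusions, reordered, b0, b1);
    /// ```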
    pub fn finish_optimized_with_timing(
        mut self,
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
        session_begin: std::time::Instant,
    ) -> Result<(u64, u64, u32, u32, u32, u32)> {
        let mut fusions = 0;
        let mut reordered = 0;
        let mut barriers0 = 0u32;
        let mut barriers1 = 0u32;

        if self.recording {
            if let Some(nodes) = self.encoder.take_capture() {
                // Commit the capture encoder's empty command buffer so its
                // MTLCommandQueue pool slot is freed. Without this, each
                // token leaks one uncommitted buffer and the queue exhausts
                // its ~64-slot pool after ~64 tokens, causing a deadlock.
                self.encoder.commit();

                let opt_t0 = std::time::Instant::now();
                let mut graph = ComputeGraph::from_nodes(nodes);
                let fuse_t0 = std::time::Instant::now();
                fusions = graph.fuse(registry, device)?;
                let fuse_us = fuse_t0.elapsed().as_micros();

                let reorder_t0 = std::time::Instant::now();
                let unannotated = graph.unannotated_dispatch_count();
                if unannotated == 0 {
                    reordered = graph.reorder();
                } else if std::env::var("HF2Q_MLX_TIMING").is_ok() {
                    eprintln!(" [GRAPH_OPT] WARN: skipping reorder — {} of {} dispatches lack range annotations",
                        unannotated, graph.dispatch_count());
                }
                let reorder_us = reorder_t0.elapsed().as_micros();
                let opt_us = opt_t0.elapsed().as_micros();

                let diag = std::env::var("HF2Q_GRAPH_DIAG").is_ok();
                let t0 = std::time::Instant::now();
                let mut enc0 = self.device.command_encoder()?;
                let mut enc1 = self.device.command_encoder()?;
                let enc_create_us = t0.elapsed().as_micros();

                let t1 = std::time::Instant::now();
                let (b0, b1) = graph.encode_dual_buffer(&mut enc0, &mut enc1);
                barriers0 = b0;
                barriers1 = b1;
                let encode_us = t1.elapsed().as_micros();

                let encoding_ns = session_begin.elapsed().as_nanos() as u64;

                let wait_start = std::time::Instant::now();
                enc1.commit_and_wait()?;
                let gpu_wait_ns = wait_start.elapsed().as_nanos() as u64;

                if diag {
                    eprintln!(" [DIAG] fuse={:.1}ms reorder={:.1}ms opt_total={:.1}ms enc_create={:.1}ms encode={:.1}ms gpu_wait={:.1}ms barriers={}+{}",
                        fuse_us as f64 / 1e3, reorder_us as f64 / 1e3, opt_us as f64 / 1e3,
                        enc_create_us as f64 / 1e3, encode_us as f64 / 1e3,
                        gpu_wait_ns as f64 / 1e6, b0, b1);
                }

                return Ok((encoding_ns, gpu_wait_ns, fusions, reordered, barriers0, barriers1));
            }
        }

        // Direct-dispatch fallback.
        let commit_start = std::time::Instant::now();
        let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
        self.encoder.commit();
        self.encoder.wait_until_completed()?;
        let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
        Ok((encoding_ns, gpu_wait_ns, fusions, reordered, barriers0, barriers1))
    }
}
1970}