// mlx_native/graph.rs
1//! [`GraphExecutor`] — batched Metal dispatch for single-encoder forward passes.
2//!
3//! llama.cpp's speed advantage over candle is NOT the kernels (Phase 0 proved
4//! candle's are as fast or faster per-call). It is the dispatch pattern:
5//! 1 encoder per command buffer instead of ~120. This module implements that
6//! pattern.
7//!
8//! # Usage
9//!
10//! ```ignore
11//! let mut executor = GraphExecutor::new(device.clone());
12//! let mut session = executor.begin()?;
13//!
14//! // All ops encode into the same command buffer — no per-op encoder creation.
15//! session.rms_norm(&mut registry, device.metal_device(), input, weight, output, params, rows, dim)?;
16//! session.quantized_matmul(&mut registry, &device, input, weight, scales, biases, &qparams)?;
17//! session.elementwise_add(&mut registry, device.metal_device(), a, b, out, n, DType::F32)?;
18//!
19//! // Single GPU sync point for the entire forward pass.
20//! session.finish()?;
21//! ```
22//!
23//! # Design
24//!
25//! The `GraphSession` holds a single `CommandEncoder`. Each op method delegates
26//! to the existing op dispatch functions in [`crate::ops`], passing the session's
27//! shared encoder. No new Metal code is needed — the ops already work with a
28//! shared encoder. The executor just prevents creating a new encoder per op.
29//!
30//! # Phase 4e.1 — Graph IR
31//!
32//! The `ComputeGraph` type captures dispatches into a `Vec<CapturedNode>` for
33//! later replay. `GraphExecutor::begin_recorded()` starts a session in capture
34//! mode: all op calls are intercepted at the `CommandEncoder` level and recorded
35//! instead of being sent to Metal. `GraphSession::finish()` detects capture
36//! mode, extracts the recorded graph, and replays it into a fresh encoder via
37//! `ComputeGraph::encode_sequential()`.
38//!
39//! The existing direct-dispatch path (`begin()`) is completely unchanged.
40
41use metal::foreign_types::ForeignType;
42
43use crate::device::MlxDevice;
44use crate::encoder::{CapturedNode, CapturedOpKind, CommandEncoder, MemRange, RecordedBinding};
45use crate::error::Result;
46use crate::kernel_registry::KernelRegistry;
47use crate::ops;
48
49// Re-export types used in the public API so callers don't need separate imports.
50pub use crate::buffer::MlxBuffer;
51pub use crate::dtypes::DType;
52
53// ---------------------------------------------------------------------------
54// OpKind — operation classification for the reorder safety whitelist (4e.3)
55// ---------------------------------------------------------------------------
56
/// Classification of a compute operation for reorder safety analysis.
///
/// The graph optimizer (Phase 4e.3) may move reorderable operations to
/// earlier positions as long as their data dependencies allow it.
/// Non-reorderable operations carry side effects or dependencies that pin
/// them to their original sequential position.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OpKind {
    /// Matrix multiplication (reorderable).
    MatMul,
    /// Expert-routed matrix multiplication (reorderable).
    MatMulId,
    /// Normalization — RMS norm, layer norm (reorderable).
    Norm,
    /// Rotary position embedding (reorderable).
    Rope,
    /// Elementwise ops — add, mul, scale, gelu, softcap, etc. (reorderable).
    Elementwise,
    /// Memory copy — KV cache copy, embedding gather (reorderable).
    Copy,
    /// Gather/scatter (reorderable).
    Gather,
    /// Scaled dot-product attention (NOT reorderable).
    Sdpa,
    /// Softmax (NOT reorderable).
    Softmax,
    /// MoE gate with CPU readback dependency (NOT reorderable).
    MoeGate,
    /// Anything else (NOT reorderable).
    Other,
}

impl OpKind {
    /// Whether this op kind is safe to reorder in the graph optimizer.
    pub fn is_reorderable(&self) -> bool {
        // Exhaustive match: adding a new variant forces a decision here.
        match self {
            Self::MatMul
            | Self::MatMulId
            | Self::Norm
            | Self::Rope
            | Self::Elementwise
            | Self::Copy
            | Self::Gather => true,
            Self::Sdpa | Self::Softmax | Self::MoeGate | Self::Other => false,
        }
    }
}
104
105// ---------------------------------------------------------------------------
106// ComputeGraph — the recorded graph IR
107// ---------------------------------------------------------------------------
108
/// A recorded sequence of GPU compute dispatches and barriers.
///
/// Created by running a forward pass with the encoder in capture mode.
/// Can be replayed into a real `CommandEncoder` via `encode_sequential()`,
/// producing identical Metal dispatch behavior to the original direct path.
///
/// Future phases (4e.2, 4e.3) will add fusion and reorder passes that
/// transform the graph before encoding.
pub struct ComputeGraph {
    /// Ordered list of captured dispatch nodes and barrier sentinels.
    nodes: Vec<CapturedNode>,
}
120
121impl ComputeGraph {
122 /// Create an empty compute graph.
123 pub fn new() -> Self {
124 Self {
125 nodes: Vec::with_capacity(128),
126 }
127 }
128
129 /// Create a compute graph from a pre-built list of captured nodes.
130 pub fn from_nodes(nodes: Vec<CapturedNode>) -> Self {
131 Self { nodes }
132 }
133
134 /// Record a captured node into the graph.
135 pub fn record(&mut self, node: CapturedNode) {
136 self.nodes.push(node);
137 }
138
139 /// Number of nodes (dispatches + barriers) in the graph.
140 pub fn len(&self) -> usize {
141 self.nodes.len()
142 }
143
144 /// Whether the graph contains no nodes.
145 pub fn is_empty(&self) -> bool {
146 self.nodes.is_empty()
147 }
148
149 /// Number of dispatch nodes (excludes barriers).
150 pub fn dispatch_count(&self) -> usize {
151 self.nodes
152 .iter()
153 .filter(|n| matches!(n, CapturedNode::Dispatch { .. }))
154 .count()
155 }
156
157 /// Number of barrier nodes.
158 pub fn barrier_count(&self) -> usize {
159 self.nodes
160 .iter()
161 .filter(|n| matches!(n, CapturedNode::Barrier))
162 .count()
163 }
164
165 /// Borrow the node list.
166 pub fn nodes(&self) -> &[CapturedNode] {
167 &self.nodes
168 }
169
170 /// Count dispatch nodes that have empty read/write range annotations.
171 ///
172 /// Used for diagnostics: if >0, the reorder pass cannot guarantee
173 /// correctness because it relies on complete annotations.
174 pub fn unannotated_dispatch_count(&self) -> usize {
175 self.nodes
176 .iter()
177 .filter(|n| matches!(n, CapturedNode::Dispatch { reads, writes, .. }
178 if reads.is_empty() || writes.is_empty()))
179 .count()
180 }
181
182 /// Take ownership of the node list, consuming the graph.
183 pub fn into_nodes(self) -> Vec<CapturedNode> {
184 self.nodes
185 }
186
187 /// Encode all nodes sequentially into the given encoder.
188 ///
189 /// Barrier sentinel nodes emit a Metal memory barrier. Dispatch nodes
190 /// are replayed through `CommandEncoder::replay_dispatch()`.
191 ///
192 /// This produces identical GPU behavior to the direct-dispatch path —
193 /// same pipeline bindings, same dispatch dimensions, same barrier
194 /// placement.
195 ///
196 /// Returns the number of barriers emitted.
197 pub fn encode_sequential(&self, encoder: &mut CommandEncoder) -> u32 {
198 let mut barrier_count = 0u32;
199 for node in &self.nodes {
200 match node {
201 CapturedNode::Barrier => {
202 encoder.memory_barrier();
203 barrier_count += 1;
204 }
205 CapturedNode::Dispatch {
206 pipeline,
207 bindings,
208 threads_per_grid,
209 threads_per_threadgroup,
210 threadgroup_memory,
211 dispatch_kind,
212 ..
213 } => {
214 encoder.replay_dispatch(
215 pipeline,
216 bindings,
217 threadgroup_memory,
218 *threads_per_grid,
219 *threads_per_threadgroup,
220 *dispatch_kind,
221 );
222 }
223 }
224 }
225 barrier_count
226 }
227
228 /// Encode the graph into a Metal command buffer, computing barriers on the
229 /// fly from each node's read/write buffer ranges.
230 ///
231 /// This is the correct encoding method for reordered graphs where barrier
232 /// sentinels have been stripped. Mirrors llama.cpp's encode-time barrier
233 /// insertion via `ggml_metal_op_concurrency_check`.
234 ///
235 /// Returns the number of barriers emitted.
236 pub fn encode_with_barriers(&self, encoder: &mut CommandEncoder) -> u32 {
237 let mut tracker = ReorderConflictTracker::new();
238 let mut barrier_count = 0u32;
239
240 for node in &self.nodes {
241 match node {
242 CapturedNode::Dispatch {
243 pipeline,
244 bindings,
245 threads_per_grid,
246 threads_per_threadgroup,
247 threadgroup_memory,
248 dispatch_kind,
249 reads,
250 writes,
251 ..
252 } => {
253 let has_ranges = !reads.is_empty() || !writes.is_empty();
254 if has_ranges && tracker.conflicts(reads, writes) {
255 encoder.memory_barrier();
256 tracker.reset();
257 barrier_count += 1;
258 }
259 if has_ranges {
260 tracker.add(reads, writes);
261 }
262 encoder.replay_dispatch(
263 pipeline,
264 bindings,
265 threadgroup_memory,
266 *threads_per_grid,
267 *threads_per_threadgroup,
268 *dispatch_kind,
269 );
270 }
271 CapturedNode::Barrier => {
272 // Explicit barriers still force a barrier boundary
273 encoder.memory_barrier();
274 tracker.reset();
275 barrier_count += 1;
276 }
277 }
278 }
279 barrier_count
280 }
281
282 /// Encode the graph using two command buffers for CPU/GPU overlap.
283 ///
284 /// The first `n0` dispatches are encoded into `encoder0` and committed
285 /// immediately (GPU starts executing). The remaining dispatches are encoded
286 /// into `encoder1`. The caller is responsible for committing `encoder1`.
287 ///
288 /// This matches llama.cpp's dual command buffer pattern from
289 /// `ggml_metal_graph_compute` (ggml-metal-context.m:441-644):
290 /// `n_nodes_0 = MAX(64, 0.1 * n_nodes)` for the first buffer.
291 ///
292 /// Command buffers submitted to the same `MTLCommandQueue` execute in
293 /// submission order, so `encoder0.commit()` followed by `encoder1.commit()`
294 /// guarantees enc0 finishes before enc1 starts. The win: the GPU starts
295 /// executing enc0 while the CPU is still encoding enc1.
296 ///
297 /// Returns `(barriers_buf0, barriers_buf1)`.
298 pub fn encode_dual_buffer(
299 &self,
300 encoder0: &mut CommandEncoder,
301 encoder1: &mut CommandEncoder,
302 ) -> (u32, u32) {
303 let dispatch_total = self.dispatch_count();
304 let n0 = std::cmp::max(64, dispatch_total / 10);
305
306 // Find the split point: the index of the n0-th dispatch node.
307 let split_idx = find_dispatch_split_index(&self.nodes, n0);
308
309 // Encode first chunk with barrier recomputation, then commit immediately.
310 let barriers0 = encode_chunk_with_barriers(&self.nodes[..split_idx], encoder0);
311 encoder0.commit();
312
313 // Encode second chunk with barrier recomputation.
314 let barriers1 = encode_chunk_with_barriers(&self.nodes[split_idx..], encoder1);
315
316 (barriers0, barriers1)
317 }
318
    /// Run the RMS norm + MUL fusion pass over the graph.
    ///
    /// Scans for the pattern:
    /// Dispatch(RmsNorm) → Barrier(s) → Dispatch(ElemMul)
    /// where the MUL reads the norm's output buffer, and replaces the
    /// sequence with a single fused `rms_norm_mul_*` dispatch.
    ///
    /// The fused dispatch:
    /// - Reads the norm's input (buffer 0) and weight (buffer 1)
    /// - Reads the MUL's second operand as the scale (buffer 2)
    /// - Writes to the MUL's output (buffer 3)
    /// - Carries the norm's params (buffer 4)
    /// - Uses the norm's threadgroup config and shared memory
    ///
    /// Returns the number of fusions applied.
    ///
    /// # Arguments
    ///
    /// * `registry` - Kernel registry for compiling the fused pipeline.
    /// * `device` - Metal device for pipeline compilation.
    pub fn fuse(
        &mut self,
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
    ) -> Result<u32> {
        // Rebuild the node list into `result`; unfused nodes are copied
        // through unchanged, fused patterns are collapsed into one dispatch.
        let mut result: Vec<CapturedNode> = Vec::with_capacity(self.nodes.len());
        let mut fusions = 0u32;
        let mut i = 0;

        while i < self.nodes.len() {
            // Check if current node is an RMS norm dispatch.
            let is_rms_norm = matches!(
                &self.nodes[i],
                CapturedNode::Dispatch { op_kind: CapturedOpKind::RmsNorm, .. }
            );

            if !is_rms_norm {
                result.push(self.nodes[i].clone());
                i += 1;
                continue;
            }

            // Look ahead: skip barriers, then check for ElemMul.
            let mut j = i + 1;
            let mut barrier_count = 0usize;
            while j < self.nodes.len() && matches!(&self.nodes[j], CapturedNode::Barrier) {
                barrier_count += 1;
                j += 1;
            }

            // Must have at least one barrier and the next node must be ElemMul.
            // (Zero barriers would mean the MUL doesn't wait on the norm.)
            if barrier_count == 0 || j >= self.nodes.len() {
                result.push(self.nodes[i].clone());
                i += 1;
                continue;
            }

            let is_elem_mul = matches!(
                &self.nodes[j],
                CapturedNode::Dispatch { op_kind: CapturedOpKind::ElemMul, .. }
            );

            if !is_elem_mul {
                result.push(self.nodes[i].clone());
                i += 1;
                continue;
            }

            // Extract norm and mul dispatch fields.
            let (norm_pipeline, norm_bindings, norm_tpg, norm_tptg, norm_tgmem, norm_dk) =
                match &self.nodes[i] {
                    CapturedNode::Dispatch {
                        pipeline,
                        bindings,
                        threads_per_grid,
                        threads_per_threadgroup,
                        threadgroup_memory,
                        dispatch_kind,
                        ..
                    } => (pipeline, bindings, threads_per_grid, threads_per_threadgroup, threadgroup_memory, dispatch_kind),
                    _ => unreachable!(),
                };

            let (mul_bindings, _mul_tpg, _mul_tptg) = match &self.nodes[j] {
                CapturedNode::Dispatch {
                    bindings,
                    threads_per_grid,
                    threads_per_threadgroup,
                    ..
                } => (bindings, threads_per_grid, threads_per_threadgroup),
                _ => unreachable!(),
            };

            // Verify data dependency: the norm's output buffer (slot 2) must
            // appear as one of the MUL's input buffers (slot 0 or 1).
            //
            // Norm binding layout: (0=input, 1=weight, 2=output, 3=params)
            // MUL binding layout: (0=a, 1=b, 2=output, 3=params_bytes)
            let norm_output_ptr = Self::buffer_ptr_for_slot(norm_bindings, 2);
            let mul_a_ptr = Self::buffer_ptr_for_slot(mul_bindings, 0);
            let mul_b_ptr = Self::buffer_ptr_for_slot(mul_bindings, 1);

            if norm_output_ptr.is_none() || (norm_output_ptr != mul_a_ptr && norm_output_ptr != mul_b_ptr) {
                // Data dependency not confirmed — don't fuse.
                result.push(self.nodes[i].clone());
                i += 1;
                continue;
            }

            // Determine which MUL input is the scale (the one that is NOT
            // the norm's output).
            let scale_slot = if norm_output_ptr == mul_a_ptr { 1 } else { 0 };

            // Build fused bindings:
            // 0 = norm input
            // 1 = norm weight
            // 2 = scale (from MUL)
            // 3 = MUL output
            // 4 = norm params
            // Gather all required bindings; bail if any are missing.
            let (norm_input, norm_weight, scale, mul_output, norm_params) = match (
                Self::get_binding(norm_bindings, 0),
                Self::get_binding(norm_bindings, 1),
                Self::get_binding(mul_bindings, scale_slot),
                Self::get_binding(mul_bindings, 2),
                Self::get_binding(norm_bindings, 3),
            ) {
                (Some(a), Some(b), Some(c), Some(d), Some(e)) => (a, b, c, d, e),
                _ => {
                    // Missing bindings — don't fuse.
                    result.push(self.nodes[i].clone());
                    i += 1;
                    continue;
                }
            };

            // Select fused pipeline based on the original norm pipeline name.
            // The norm pipeline name is "rms_norm_f32", "rms_norm_f16", or
            // "rms_norm_bf16" — we need the corresponding fused pipeline.
            let fused_name = match Self::fused_pipeline_name(norm_pipeline) {
                Some(name) => name,
                None => {
                    result.push(self.nodes[i].clone());
                    i += 1;
                    continue;
                }
            };

            let fused_pipeline = registry.get_pipeline(fused_name, device)?;

            let fused_bindings = vec![
                (0, norm_input),
                (1, norm_weight),
                (2, scale),
                (3, mul_output),
                (4, norm_params),
            ];

            // Merge read/write ranges from both the norm and mul nodes for the
            // fused dispatch. The fused op reads everything the norm reads
            // plus the mul's scale input, and writes to the mul's output.
            // (The norm's writes — its intermediate output — are dropped:
            // the fused kernel never materializes that buffer.)
            let (fused_reads, fused_writes) = match (&self.nodes[i], &self.nodes[j]) {
                (
                    CapturedNode::Dispatch { reads: nr, writes: _nw, .. },
                    CapturedNode::Dispatch { reads: mr, writes: mw, .. },
                ) => {
                    let mut reads = nr.clone();
                    reads.extend_from_slice(mr);
                    (reads, mw.clone())
                }
                _ => (Vec::new(), Vec::new()),
            };

            result.push(CapturedNode::Dispatch {
                pipeline: fused_pipeline.to_owned(),
                bindings: fused_bindings,
                threads_per_grid: *norm_tpg,
                threads_per_threadgroup: *norm_tptg,
                threadgroup_memory: norm_tgmem.clone(),
                dispatch_kind: *norm_dk,
                op_kind: CapturedOpKind::Other, // Fused ops are not further fuseable
                reads: fused_reads,
                writes: fused_writes,
            });

            fusions += 1;
            // Skip past the norm, barrier(s), and mul nodes. The intermediate
            // barriers are intentionally dropped: the fused op is a single
            // dispatch, so no intra-pattern synchronization is needed.
            i = j + 1;
        }

        self.nodes = result;
        Ok(fusions)
    }
512
513 /// Run the reorder pass over the graph to improve GPU concurrency.
514 ///
515 /// Port of llama.cpp's `ggml_metal_graph_optimize_reorder` — a greedy
516 /// 64-node lookahead that pulls independent dispatches forward to fill
517 /// larger concurrent groups between barriers.
518 ///
519 /// **Prerequisites:** Call `fuse()` first if desired. The reorder pass
520 /// operates on the post-fusion graph. Barrier sentinel nodes are stripped
521 /// before reordering (they will be recomputed at encode time by the
522 /// `ConflictTracker` in `encode_sequential`).
523 ///
524 /// **Algorithm (matching llama.cpp exactly):**
525 /// 1. Strip all `CapturedNode::Barrier` nodes.
526 /// 2. For each unprocessed node `i0`:
527 /// - If it conflicts with the current concurrent group (`mrs0`):
528 /// * Initialize `mrs1` from `i0`'s ranges (skipped-over set)
529 /// * Lookahead up to 64 nodes for candidates that:
530 /// (a) Are reorderable (`CapturedOpKind::is_reorderable()`)
531 /// (b) Don't conflict with `mrs0` (current group)
532 /// (c) Don't conflict with `mrs1` (skipped-over nodes)
533 /// * Pull qualifying candidates into the current group
534 /// * Non-reorderable ops break the lookahead
535 /// - Reset `mrs0` (new concurrent group)
536 /// - Add `i0` to the new group
537 ///
538 /// Returns the number of nodes that were moved to earlier positions.
539 pub fn reorder(&mut self) -> u32 {
540 // Step 1: Strip barrier nodes. After fusion + reorder, barriers will
541 // be recomputed by the ConflictTracker at encode time.
542 self.nodes.retain(|n| !matches!(n, CapturedNode::Barrier));
543
544 let n = self.nodes.len();
545 if n == 0 {
546 return 0;
547 }
548
549 let mut result: Vec<usize> = Vec::with_capacity(n);
550 let mut used = vec![false; n];
551
552 // mrs0: memory ranges for the current concurrent group
553 let mut mrs0 = ReorderConflictTracker::new();
554 // mrs1: memory ranges for skipped-over (unprocessed) nodes
555 let mut mrs1 = ReorderConflictTracker::new();
556
557 const N_FORWARD: usize = 64;
558
559 for i0 in 0..n {
560 if used[i0] {
561 continue;
562 }
563
564 let node0 = &self.nodes[i0];
565
566 // Extract reads/writes for conflict check.
567 let (reads0, writes0, op_kind0) = match node0 {
568 CapturedNode::Dispatch { reads, writes, op_kind, .. } => {
569 (reads.as_slice(), writes.as_slice(), *op_kind)
570 }
571 CapturedNode::Barrier => continue, // stripped, but be safe
572 };
573
574 // Check if node0 conflicts with the current concurrent group.
575 // Empty nodes (no ranges) never conflict — like llama.cpp's is_empty.
576 let has_ranges = !reads0.is_empty() || !writes0.is_empty();
577 if has_ranges && mrs0.conflicts(reads0, writes0) {
578 // Before starting a new group, look forward for nodes that
579 // can be pulled into the CURRENT group.
580 mrs1.reset();
581 mrs1.add(reads0, writes0);
582
583 let end = (i0 + N_FORWARD).min(n);
584 for i1 in (i0 + 1)..end {
585 if used[i1] {
586 continue;
587 }
588
589 let node1 = &self.nodes[i1];
590 let (reads1, writes1, op_kind1) = match node1 {
591 CapturedNode::Dispatch { reads, writes, op_kind, .. } => {
592 (reads.as_slice(), writes.as_slice(), *op_kind)
593 }
594 CapturedNode::Barrier => continue,
595 };
596
597 // Non-reorderable ops break the lookahead.
598 if !op_kind1.is_reorderable() {
599 break;
600 }
601
602 let is_empty1 = reads1.is_empty() && writes1.is_empty();
603
604 // A node can be reordered into the current group if:
605 // 1. It's empty (no ranges) OR doesn't conflict with mrs0
606 // 2. It doesn't conflict with mrs1 (skipped-over nodes)
607 if (is_empty1 || !mrs0.conflicts(reads1, writes1))
608 && !mrs1.conflicts(reads1, writes1)
609 {
610 // Pull into current concurrent group.
611 mrs0.add(reads1, writes1);
612 result.push(i1);
613 used[i1] = true;
614 } else {
615 // Not eligible — expand the skipped-over set.
616 mrs1.add(reads1, writes1);
617 }
618 }
619
620 // Finalize the current concurrent group.
621 mrs0.reset();
622 }
623
624 // Expand the concurrent group with node0.
625 // (Barriers were stripped, so this is always a Dispatch.)
626 let _ = op_kind0; // suppress unused warning
627 mrs0.add(reads0, writes0);
628 result.push(i0);
629 }
630
631 // Apply the permutation to produce the reordered node list.
632 let mut reordered_count = 0u32;
633 for (pos, &orig_idx) in result.iter().enumerate() {
634 if orig_idx != pos {
635 reordered_count += 1;
636 }
637 }
638
639 // Build the reordered nodes vec.
640 let old_nodes = std::mem::take(&mut self.nodes);
641 self.nodes = result.iter().map(|&idx| old_nodes[idx].clone()).collect();
642
643 // Debug dump if requested.
644 if std::env::var("HF2Q_REORDER_DUMP").is_ok() {
645 eprintln!(
646 " [REORDER] nodes={} reordered={} ({:.1}%)",
647 n,
648 reordered_count,
649 100.0 * reordered_count as f64 / n as f64,
650 );
651 }
652
653 reordered_count
654 }
655
656 /// Get the Metal buffer pointer for a binding at the given slot index.
657 ///
658 /// Returns `Some(ptr)` if the slot has a `RecordedBinding::Buffer`,
659 /// `None` otherwise.
660 fn buffer_ptr_for_slot(bindings: &[(u64, RecordedBinding)], slot: u64) -> Option<*const std::ffi::c_void> {
661 for (idx, binding) in bindings {
662 if *idx == slot {
663 if let RecordedBinding::Buffer { metal_buffer, offset: _ } = binding {
664 // Use the Metal buffer's GPU address as the identity key.
665 // On Apple Silicon unified memory, this uniquely identifies
666 // the allocation.
667 let ptr: *const std::ffi::c_void = metal_buffer.as_ptr() as *const _;
668 return Some(ptr);
669 }
670 }
671 }
672 None
673 }
674
675 /// Clone the binding at the given slot index.
676 fn get_binding(bindings: &[(u64, RecordedBinding)], slot: u64) -> Option<RecordedBinding> {
677 for (idx, binding) in bindings {
678 if *idx == slot {
679 return Some(binding.clone());
680 }
681 }
682 None
683 }
684
685 /// Map a norm pipeline to its fused norm+mul pipeline name.
686 ///
687 /// The pipeline's `label()` is set by Metal to the function name, so we
688 /// can match on it. Returns `None` if the pipeline is not a known norm.
689 fn fused_pipeline_name(pipeline: &metal::ComputePipelineState) -> Option<&'static str> {
690 match pipeline.label() {
691 "rms_norm_f32" => Some("rms_norm_mul_f32"),
692 "rms_norm_f16" => Some("rms_norm_mul_f16"),
693 "rms_norm_bf16" => Some("rms_norm_mul_bf16"),
694 _ => None,
695 }
696 }
697}
698
impl Default for ComputeGraph {
    /// Equivalent to [`ComputeGraph::new`] — an empty graph.
    fn default() -> Self {
        Self::new()
    }
}
704
705// ---------------------------------------------------------------------------
706// Dual-buffer encoding helpers (Phase 4e.4)
707// ---------------------------------------------------------------------------
708
709/// Find the node index where the n0-th dispatch starts.
710///
711/// Counts `CapturedNode::Dispatch` nodes until `n0` are reached, then returns
712/// the index of the n0-th dispatch (i.e., the first node of the second chunk).
713/// If `n0 >= dispatch_count`, returns `nodes.len()` (everything in chunk 0).
714fn find_dispatch_split_index(nodes: &[CapturedNode], n0: usize) -> usize {
715 let mut dispatches_seen = 0usize;
716 for (i, node) in nodes.iter().enumerate() {
717 if matches!(node, CapturedNode::Dispatch { .. }) {
718 dispatches_seen += 1;
719 if dispatches_seen == n0 {
720 return i + 1; // split AFTER the n0-th dispatch
721 }
722 }
723 }
724 nodes.len()
725}
726
727/// Encode a slice of captured nodes into a command encoder, recomputing
728/// barriers on the fly from each node's read/write buffer ranges.
729///
730/// This is the chunked counterpart of `ComputeGraph::encode_with_barriers()`.
731/// Factored out so both halves of a dual-buffer encode can use it.
732///
733/// Returns the number of barriers emitted.
734fn encode_chunk_with_barriers(nodes: &[CapturedNode], encoder: &mut CommandEncoder) -> u32 {
735 let mut tracker = ReorderConflictTracker::new();
736 let mut barrier_count = 0u32;
737
738 for node in nodes {
739 match node {
740 CapturedNode::Dispatch {
741 pipeline,
742 bindings,
743 threads_per_grid,
744 threads_per_threadgroup,
745 threadgroup_memory,
746 dispatch_kind,
747 reads,
748 writes,
749 ..
750 } => {
751 let has_ranges = !reads.is_empty() || !writes.is_empty();
752 if has_ranges && tracker.conflicts(reads, writes) {
753 encoder.memory_barrier();
754 tracker.reset();
755 barrier_count += 1;
756 }
757 if has_ranges {
758 tracker.add(reads, writes);
759 }
760 encoder.replay_dispatch(
761 pipeline,
762 bindings,
763 threadgroup_memory,
764 *threads_per_grid,
765 *threads_per_threadgroup,
766 *dispatch_kind,
767 );
768 }
769 CapturedNode::Barrier => {
770 encoder.memory_barrier();
771 tracker.reset();
772 barrier_count += 1;
773 }
774 }
775 }
776 barrier_count
777}
778
779// ---------------------------------------------------------------------------
780// ReorderConflictTracker — range-based conflict detection for the reorder pass
781// ---------------------------------------------------------------------------
782
783/// Memory range conflict tracker for the reorder pass (Phase 4e.3).
784///
785/// Works with `MemRange` tuples `(start, end)` stored on `CapturedNode::Dispatch`,
786/// rather than requiring live `&MlxBuffer` references. This is the reorder-time
787/// equivalent of the runtime `ConflictTracker`.
788///
789/// Conflict rules match llama.cpp's `ggml_mem_ranges_check`:
790/// - Two read ranges: OK (read-read is concurrent-safe)
791/// - A new read overlapping an existing write: CONFLICT (RAW)
792/// - A new write overlapping any existing range: CONFLICT (WAR/WAW)
793struct ReorderConflictTracker {
794 /// (start, end, is_write) for all ranges in the tracked set.
795 ranges: Vec<(usize, usize, bool)>,
796}
797
798impl ReorderConflictTracker {
799 fn new() -> Self {
800 Self {
801 ranges: Vec::with_capacity(64),
802 }
803 }
804
805 fn reset(&mut self) {
806 self.ranges.clear();
807 }
808
809 /// Check if a dispatch with the given read/write ranges conflicts with
810 /// any range in this tracker.
811 fn conflicts(&self, reads: &[MemRange], writes: &[MemRange]) -> bool {
812 // New reads vs existing writes (RAW)
813 for &(r_start, r_end) in reads {
814 for &(s, e, is_write) in &self.ranges {
815 if is_write && r_start < e && r_end > s {
816 return true;
817 }
818 }
819 }
820 // New writes vs all existing ranges (WAR/WAW)
821 for &(w_start, w_end) in writes {
822 for &(s, e, _) in &self.ranges {
823 if w_start < e && w_end > s {
824 return true;
825 }
826 }
827 }
828 false
829 }
830
831 /// Add read and write ranges to the tracked set.
832 fn add(&mut self, reads: &[MemRange], writes: &[MemRange]) {
833 for &(start, end) in reads {
834 self.ranges.push((start, end, false));
835 }
836 for &(start, end) in writes {
837 self.ranges.push((start, end, true));
838 }
839 }
840}
841
842/// Batched Metal dispatch — encodes multiple ops into a single `CommandEncoder`.
843///
844/// Create one per model (or per forward-pass loop). Call [`begin`](Self::begin)
845/// at the start of each forward pass to get a [`GraphSession`] that holds the
846/// shared encoder.
847pub struct GraphExecutor {
848 device: MlxDevice,
849}
850
851impl GraphExecutor {
852 /// Create a new graph executor backed by the given device.
853 pub fn new(device: MlxDevice) -> Self {
854 Self { device }
855 }
856
857 /// Begin a new forward pass (direct-dispatch mode).
858 ///
859 /// Returns a [`GraphSession`] that holds a fresh `CommandEncoder`. All ops
860 /// encoded through the session share this single encoder. Call
861 /// [`GraphSession::finish`] to commit and wait.
862 pub fn begin(&self) -> Result<GraphSession<'_>> {
863 let encoder = self.device.command_encoder()?;
864 Ok(GraphSession {
865 encoder,
866 device: &self.device,
867 barrier_count: 0,
868 tracker: ConflictTracker::new(),
869 dispatch_in_group: 0,
870 total_dispatches: 0,
871 group_sizes: [0; 8],
872 recording: false,
873 })
874 }
875
876 /// Begin a new forward pass in capture (record) mode.
877 ///
878 /// All op calls are recorded into a `ComputeGraph` instead of being
879 /// dispatched to Metal. When [`GraphSession::finish`] is called, the
880 /// recorded graph is replayed into a fresh encoder via
881 /// `ComputeGraph::encode_sequential()`.
882 ///
883 /// The API is identical to `begin()` — callers do not need to change
884 /// any op call code. The only behavioral difference: GPU work happens
885 /// at `finish()` time rather than at each op call.
886 pub fn begin_recorded(&self) -> Result<GraphSession<'_>> {
887 let mut encoder = self.device.command_encoder()?;
888 encoder.start_capture();
889 Ok(GraphSession {
890 encoder,
891 device: &self.device,
892 barrier_count: 0,
893 tracker: ConflictTracker::new(),
894 dispatch_in_group: 0,
895 total_dispatches: 0,
896 group_sizes: [0; 8],
897 recording: true,
898 })
899 }
900
901 /// Borrow the underlying device.
902 pub fn device(&self) -> &MlxDevice {
903 &self.device
904 }
905}
906
/// Tracks buffer address ranges for automatic barrier elision.
///
/// NOTE(review): the prose about "a single forward pass execution context"
/// that used to sit here described `GraphSession`, not this type — the two
/// doc blocks had been merged onto `ConflictTracker` by mistake.
///
/// Mirrors llama.cpp's `ggml_mem_ranges` — accumulates the read and write
/// ranges of all dispatches in the current concurrent group. When a new
/// dispatch's reads overlap with an existing write (RAW), or its writes
/// overlap with an existing read or write (WAR/WAW), a barrier is needed.
/// Otherwise the dispatch can run concurrently and the barrier is elided.
///
/// Uses CPU-visible `contents_ptr()` addresses, which on Apple Silicon
/// unified memory equal the GPU addresses.
pub struct ConflictTracker {
    /// (start, end, is_write) tuples for the current concurrent group.
    ranges: Vec<(usize, usize, bool)>,
}

impl ConflictTracker {
    fn new() -> Self {
        Self {
            ranges: Vec::with_capacity(32),
        }
    }

    /// Reset the tracker — called after emitting a barrier.
    fn reset(&mut self) {
        self.ranges.clear();
    }

    /// Check if a new dispatch with the given reads and writes conflicts
    /// with the current concurrent group.
    ///
    /// Conflict rules (same as llama.cpp `ggml_mem_ranges_check`):
    /// - Two SRC (read) ranges in the same buffer: OK (read-read)
    /// - A new SRC overlapping an existing DST: CONFLICT (RAW)
    /// - A new DST overlapping an existing SRC or DST: CONFLICT (WAR/WAW)
    ///
    /// Returns `Some((conflict_kind, new_range_start, existing_range_start))`
    /// for the first conflict found, or `None` if the dispatch is safe to run
    /// concurrently with the current group.
    fn conflicts_reason(&self, reads: &[&MlxBuffer], writes: &[&MlxBuffer])
        -> Option<(&'static str, usize, usize)>
    {
        // Check new reads against existing writes (RAW)
        for r in reads {
            let r_start = r.contents_ptr() as usize;
            let r_end = r_start + r.byte_len();
            for &(s, e, is_write) in &self.ranges {
                // Half-open interval overlap: [r_start, r_end) vs [s, e).
                if is_write && r_start < e && r_end > s {
                    return Some(("RAW", r_start, s));
                }
            }
        }
        // Check new writes against existing reads and writes (WAR/WAW)
        for w in writes {
            let w_start = w.contents_ptr() as usize;
            let w_end = w_start + w.byte_len();
            for &(s, e, is_write) in &self.ranges {
                if w_start < e && w_end > s {
                    let kind = if is_write { "WAW" } else { "WAR" };
                    return Some((kind, w_start, s));
                }
            }
        }
        None
    }

    /// Add read and write ranges to the current concurrent group.
    fn add(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
        for r in reads {
            let start = r.contents_ptr() as usize;
            let end = start + r.byte_len();
            self.ranges.push((start, end, false));
        }
        for w in writes {
            let start = w.contents_ptr() as usize;
            let end = start + w.byte_len();
            self.ranges.push((start, end, true));
        }
    }
}
992
/// A batched dispatch session — every op encodes into one shared encoder,
/// giving a single command buffer (and a single sync point) per forward pass.
pub struct GraphSession<'a> {
    /// Shared command encoder; all op methods dispatch into it.
    encoder: CommandEncoder,
    /// Device the session was created on, borrowed for the session lifetime.
    device: &'a MlxDevice,
    /// Number of memory barriers emitted so far (unconditional + non-elided).
    barrier_count: u32,
    /// Read/write range tracker used by `barrier_between` to elide barriers.
    tracker: ConflictTracker,
    /// Dispatches tracked since the last barrier (current concurrent group).
    dispatch_in_group: u32,
    /// Total dispatches tracked via `barrier_between`.
    total_dispatches: u32,
    /// Histogram: group_sizes[i] = number of concurrent groups with (i+1) dispatches
    group_sizes: [u32; 8],
    /// Whether this session was created in capture/record mode.
    recording: bool,
}
1005
1006impl<'a> GraphSession<'a> {
1007 /// Encode an RMS normalization into this session's encoder.
1008 ///
1009 /// Delegates to [`ops::rms_norm::dispatch_rms_norm`].
1010 pub fn rms_norm(
1011 &mut self,
1012 registry: &mut KernelRegistry,
1013 device: &metal::DeviceRef,
1014 input: &MlxBuffer,
1015 weight: &MlxBuffer,
1016 output: &MlxBuffer,
1017 params_buf: &MlxBuffer,
1018 rows: u32,
1019 dim: u32,
1020 ) -> Result<()> {
1021 ops::rms_norm::dispatch_rms_norm(
1022 &mut self.encoder,
1023 registry,
1024 device,
1025 input,
1026 weight,
1027 output,
1028 params_buf,
1029 rows,
1030 dim,
1031 )
1032 }
1033
1034 /// Encode a quantized matrix multiplication into this session's encoder.
1035 ///
1036 /// Delegates to [`ops::quantized_matmul::quantized_matmul`].
1037 /// Returns the freshly allocated output buffer.
1038 pub fn quantized_matmul(
1039 &mut self,
1040 registry: &mut KernelRegistry,
1041 device: &MlxDevice,
1042 input: &MlxBuffer,
1043 weight: &MlxBuffer,
1044 scales: &MlxBuffer,
1045 biases: &MlxBuffer,
1046 params: &ops::quantized_matmul::QuantizedMatmulParams,
1047 ) -> Result<MlxBuffer> {
1048 ops::quantized_matmul::quantized_matmul(
1049 &mut self.encoder,
1050 registry,
1051 device,
1052 input,
1053 weight,
1054 scales,
1055 biases,
1056 params,
1057 )
1058 }
1059
1060 /// Encode a SIMD-optimized quantized matmul into this session's encoder.
1061 ///
1062 /// Delegates to [`ops::quantized_matmul::quantized_matmul_simd`].
1063 /// Returns the freshly allocated output buffer.
1064 pub fn quantized_matmul_simd(
1065 &mut self,
1066 registry: &mut KernelRegistry,
1067 device: &MlxDevice,
1068 input: &MlxBuffer,
1069 weight: &MlxBuffer,
1070 scales: &MlxBuffer,
1071 biases: &MlxBuffer,
1072 params: &ops::quantized_matmul::QuantizedMatmulParams,
1073 ) -> Result<MlxBuffer> {
1074 ops::quantized_matmul::quantized_matmul_simd(
1075 &mut self.encoder,
1076 registry,
1077 device,
1078 input,
1079 weight,
1080 scales,
1081 biases,
1082 params,
1083 )
1084 }
1085
1086 /// Encode a GGML block-format quantized mat-vec into this session's encoder.
1087 ///
1088 /// Delegates to [`ops::quantized_matmul_ggml::quantized_matmul_ggml`].
1089 pub fn quantized_matmul_ggml(
1090 &mut self,
1091 registry: &mut KernelRegistry,
1092 device: &MlxDevice,
1093 input: &MlxBuffer,
1094 weight: &MlxBuffer,
1095 output: &mut MlxBuffer,
1096 params: &ops::quantized_matmul_ggml::GgmlQuantizedMatmulParams,
1097 ) -> Result<()> {
1098 ops::quantized_matmul_ggml::quantized_matmul_ggml(
1099 &mut self.encoder,
1100 registry,
1101 device,
1102 input,
1103 weight,
1104 output,
1105 params,
1106 )
1107 }
1108
1109 /// Encode an expert-routed GGML block-format quantized mat-vec into this session's encoder.
1110 ///
1111 /// Delegates to [`ops::quantized_matmul_id_ggml::quantized_matmul_id_ggml`].
1112 #[allow(clippy::too_many_arguments)]
1113 pub fn quantized_matmul_id_ggml(
1114 &mut self,
1115 registry: &mut KernelRegistry,
1116 device: &MlxDevice,
1117 input: &MlxBuffer,
1118 weight: &MlxBuffer,
1119 ids: &MlxBuffer,
1120 output: &mut MlxBuffer,
1121 params: &ops::quantized_matmul_id_ggml::GgmlQuantizedMatmulIdParams,
1122 ) -> Result<()> {
1123 ops::quantized_matmul_id_ggml::quantized_matmul_id_ggml(
1124 &mut self.encoder,
1125 registry,
1126 device,
1127 input,
1128 weight,
1129 ids,
1130 output,
1131 params,
1132 )
1133 }
1134
1135 /// Pooled-scratch variant of [`Self::quantized_matmul_id_ggml`] — the
1136 /// `IdMmScratch` is caller-owned so batched prefill amortises the
1137 /// per-call allocations that the auto entry point incurs (ADR-011
1138 /// Phase 3 Wave P3b).
1139 #[allow(clippy::too_many_arguments)]
1140 pub fn quantized_matmul_id_ggml_pooled(
1141 &mut self,
1142 registry: &mut KernelRegistry,
1143 device: &MlxDevice,
1144 input: &MlxBuffer,
1145 weight: &MlxBuffer,
1146 ids: &MlxBuffer,
1147 output: &mut MlxBuffer,
1148 scratch: &mut ops::quantized_matmul_id_ggml::IdMmScratch,
1149 params: &ops::quantized_matmul_id_ggml::GgmlQuantizedMatmulIdParams,
1150 ) -> Result<()> {
1151 ops::quantized_matmul_id_ggml::quantized_matmul_id_ggml_pooled(
1152 &mut self.encoder,
1153 registry,
1154 device,
1155 input,
1156 weight,
1157 ids,
1158 output,
1159 scratch,
1160 params,
1161 )
1162 }
1163
1164 /// Encode scaled dot-product attention into this session's encoder.
1165 ///
1166 /// Delegates to [`ops::sdpa::sdpa`].
1167 pub fn sdpa(
1168 &mut self,
1169 registry: &mut KernelRegistry,
1170 device: &MlxDevice,
1171 q: &MlxBuffer,
1172 k: &MlxBuffer,
1173 v: &MlxBuffer,
1174 output: &MlxBuffer,
1175 params: &ops::sdpa::SdpaParams,
1176 batch_size: u32,
1177 ) -> Result<()> {
1178 ops::sdpa::sdpa(
1179 &mut self.encoder,
1180 registry,
1181 device,
1182 q,
1183 k,
1184 v,
1185 output,
1186 params,
1187 batch_size,
1188 )
1189 }
1190
1191 /// Encode flash attention vector (SIMD-vectorized decode-path SDPA).
1192 ///
1193 /// Delegates to [`ops::flash_attn_vec::flash_attn_vec`].
1194 pub fn flash_attn_vec(
1195 &mut self,
1196 registry: &mut KernelRegistry,
1197 device: &MlxDevice,
1198 q: &MlxBuffer,
1199 k: &MlxBuffer,
1200 v: &MlxBuffer,
1201 output: &MlxBuffer,
1202 tmp: &MlxBuffer,
1203 params: &ops::flash_attn_vec::FlashAttnVecParams,
1204 ) -> Result<()> {
1205 ops::flash_attn_vec::flash_attn_vec(
1206 &mut self.encoder,
1207 registry,
1208 device,
1209 q,
1210 k,
1211 v,
1212 output,
1213 tmp,
1214 params,
1215 )
1216 }
1217
1218 /// Encode an elementwise add into this session's encoder.
1219 ///
1220 /// Delegates to [`ops::elementwise::elementwise_add`].
1221 pub fn elementwise_add(
1222 &mut self,
1223 registry: &mut KernelRegistry,
1224 device: &metal::DeviceRef,
1225 a: &MlxBuffer,
1226 b: &MlxBuffer,
1227 output: &MlxBuffer,
1228 n_elements: usize,
1229 dtype: DType,
1230 ) -> Result<()> {
1231 ops::elementwise::elementwise_add(
1232 &mut self.encoder,
1233 registry,
1234 device,
1235 a,
1236 b,
1237 output,
1238 n_elements,
1239 dtype,
1240 )
1241 }
1242
1243 /// Encode an elementwise multiply into this session's encoder.
1244 ///
1245 /// Delegates to [`ops::elementwise::elementwise_mul`].
1246 pub fn elementwise_mul(
1247 &mut self,
1248 registry: &mut KernelRegistry,
1249 device: &metal::DeviceRef,
1250 a: &MlxBuffer,
1251 b: &MlxBuffer,
1252 output: &MlxBuffer,
1253 n_elements: usize,
1254 dtype: DType,
1255 ) -> Result<()> {
1256 ops::elementwise::elementwise_mul(
1257 &mut self.encoder,
1258 registry,
1259 device,
1260 a,
1261 b,
1262 output,
1263 n_elements,
1264 dtype,
1265 )
1266 }
1267
1268 /// Encode a RoPE transform into this session's encoder.
1269 ///
1270 /// Delegates to [`ops::rope::dispatch_rope`].
1271 pub fn rope(
1272 &mut self,
1273 registry: &mut KernelRegistry,
1274 device: &metal::DeviceRef,
1275 input: &MlxBuffer,
1276 output: &MlxBuffer,
1277 params_buf: &MlxBuffer,
1278 positions_buf: &MlxBuffer,
1279 seq_len: u32,
1280 head_dim: u32,
1281 ) -> Result<()> {
1282 ops::rope::dispatch_rope(
1283 &mut self.encoder,
1284 registry,
1285 device,
1286 input,
1287 output,
1288 params_buf,
1289 positions_buf,
1290 seq_len,
1291 head_dim,
1292 )
1293 }
1294
1295 /// Encode a GELU activation into this session's encoder.
1296 ///
1297 /// Delegates to [`ops::gelu::dispatch_gelu`].
1298 pub fn gelu(
1299 &mut self,
1300 registry: &mut KernelRegistry,
1301 device: &metal::DeviceRef,
1302 input: &MlxBuffer,
1303 output: &MlxBuffer,
1304 ) -> Result<()> {
1305 ops::gelu::dispatch_gelu(
1306 &mut self.encoder,
1307 registry,
1308 device,
1309 input,
1310 output,
1311 )
1312 }
1313
1314 /// Encode a softmax into this session's encoder.
1315 ///
1316 /// Delegates to [`ops::softmax::dispatch_softmax`].
1317 pub fn softmax(
1318 &mut self,
1319 registry: &mut KernelRegistry,
1320 device: &metal::DeviceRef,
1321 input: &MlxBuffer,
1322 output: &MlxBuffer,
1323 params_buf: &MlxBuffer,
1324 rows: u32,
1325 cols: u32,
1326 ) -> Result<()> {
1327 ops::softmax::dispatch_softmax(
1328 &mut self.encoder,
1329 registry,
1330 device,
1331 input,
1332 output,
1333 params_buf,
1334 rows,
1335 cols,
1336 )
1337 }
1338
1339 /// Encode a softcap into this session's encoder.
1340 ///
1341 /// Delegates to [`ops::softcap::dispatch_softcap`].
1342 pub fn softcap(
1343 &mut self,
1344 registry: &mut KernelRegistry,
1345 device: &metal::DeviceRef,
1346 input: &MlxBuffer,
1347 output: &MlxBuffer,
1348 params_buf: &MlxBuffer,
1349 cap: f32,
1350 ) -> Result<()> {
1351 ops::softcap::dispatch_softcap(
1352 &mut self.encoder,
1353 registry,
1354 device,
1355 input,
1356 output,
1357 params_buf,
1358 cap,
1359 )
1360 }
1361
1362 /// Encode an RMS norm without learned scale (f32) into this session's encoder.
1363 ///
1364 /// Delegates to [`ops::rms_norm::dispatch_rms_norm_no_scale_f32`].
1365 pub fn rms_norm_no_scale_f32(
1366 &mut self,
1367 registry: &mut KernelRegistry,
1368 device: &metal::DeviceRef,
1369 input: &MlxBuffer,
1370 output: &MlxBuffer,
1371 params_buf: &MlxBuffer,
1372 rows: u32,
1373 dim: u32,
1374 ) -> Result<()> {
1375 ops::rms_norm::dispatch_rms_norm_no_scale_f32(
1376 &mut self.encoder,
1377 registry,
1378 device,
1379 input,
1380 output,
1381 params_buf,
1382 rows,
1383 dim,
1384 )
1385 }
1386
1387 /// Encode a NeoX RoPE (f32) with optional freq_factors into this session's encoder.
1388 ///
1389 /// Delegates to [`ops::rope::dispatch_rope_neox_f32`].
1390 #[allow(clippy::too_many_arguments)]
1391 pub fn rope_neox_f32(
1392 &mut self,
1393 registry: &mut KernelRegistry,
1394 device: &metal::DeviceRef,
1395 input: &MlxBuffer,
1396 output: &MlxBuffer,
1397 params_buf: &MlxBuffer,
1398 positions_buf: &MlxBuffer,
1399 freq_factors: Option<&MlxBuffer>,
1400 seq_len: u32,
1401 n_heads: u32,
1402 head_dim: u32,
1403 rope_dim: u32,
1404 ) -> Result<()> {
1405 ops::rope::dispatch_rope_neox_f32(
1406 &mut self.encoder,
1407 registry,
1408 device,
1409 input,
1410 output,
1411 params_buf,
1412 positions_buf,
1413 freq_factors,
1414 seq_len,
1415 n_heads,
1416 head_dim,
1417 rope_dim,
1418 )
1419 }
1420
1421 /// Insert a GPU memory barrier (MTLBarrierScopeBuffers).
1422 ///
1423 /// Unconditional barrier — always emits. Use `barrier_between` for
1424 /// automatic conflict detection that can elide unnecessary barriers.
1425 #[inline]
1426 pub fn barrier(&mut self) {
1427 // Record the outgoing group size
1428 if self.dispatch_in_group > 0 {
1429 let idx = (self.dispatch_in_group as usize).min(self.group_sizes.len()) - 1;
1430 self.group_sizes[idx] += 1;
1431 }
1432 self.encoder.memory_barrier();
1433 self.tracker.reset();
1434 self.barrier_count += 1;
1435 self.dispatch_in_group = 0;
1436 }
1437
1438 /// Smart barrier with conflict detection.
1439 ///
1440 /// Checks if the next dispatch (with the given read and write buffers)
1441 /// actually conflicts with any dispatch in the current concurrent group.
1442 /// If yes, emits a Metal barrier and resets the tracker. If no, the
1443 /// barrier is elided and the dispatch can run concurrently.
1444 ///
1445 /// This mirrors llama.cpp's `ggml_metal_op_concurrency_check` +
1446 /// `ggml_metal_op_concurrency_reset` pattern.
1447 #[inline]
1448 pub fn barrier_between(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
1449 // In capture mode, stash the read/write ranges so the next captured
1450 // dispatch node carries them for the reorder pass (Phase 4e.3).
1451 if self.recording {
1452 let read_ranges: Vec<MemRange> = reads
1453 .iter()
1454 .map(|b| {
1455 let start = b.contents_ptr() as usize;
1456 (start, start + b.byte_len())
1457 })
1458 .collect();
1459 let write_ranges: Vec<MemRange> = writes
1460 .iter()
1461 .map(|b| {
1462 let start = b.contents_ptr() as usize;
1463 (start, start + b.byte_len())
1464 })
1465 .collect();
1466 self.encoder.set_pending_buffer_ranges(read_ranges, write_ranges);
1467 }
1468
1469 let reason = self.tracker.conflicts_reason(reads, writes);
1470 if let Some((_kind, _new_ptr, _existing_ptr)) = reason {
1471 // Record the outgoing group size before resetting
1472 if self.dispatch_in_group > 0 {
1473 let idx = (self.dispatch_in_group as usize).min(self.group_sizes.len()) - 1;
1474 self.group_sizes[idx] += 1;
1475 }
1476 self.encoder.memory_barrier();
1477 self.tracker.reset();
1478 self.barrier_count += 1;
1479 self.dispatch_in_group = 0;
1480 }
1481 self.dispatch_in_group += 1;
1482 self.total_dispatches += 1;
1483 self.tracker.add(reads, writes);
1484 }
1485
    /// Print group size histogram to stderr (for HF2Q_MLX_TIMING debug).
    ///
    /// Read-only: operates on a copy of the histogram, so the session's
    /// live counters are untouched. `ratio` is mean dispatches per group.
    pub fn dump_group_stats(&self) {
        // Record the final (unterminated) group
        let mut gs = self.group_sizes;
        if self.dispatch_in_group > 0 {
            let idx = (self.dispatch_in_group as usize).min(gs.len()) - 1;
            gs[idx] += 1;
        }
        let total_groups: u32 = gs.iter().sum();
        eprintln!("  [GROUP_STATS] dispatches={} barriers={} groups={} ratio={:.2}",
            self.total_dispatches, self.barrier_count, total_groups,
            if total_groups > 0 { self.total_dispatches as f64 / total_groups as f64 } else { 0.0 });
        for (i, &count) in gs.iter().enumerate() {
            if count > 0 {
                // Bucket i holds groups of size i+1; the last bucket also
                // absorbs any group larger than the histogram (clamped above).
                eprintln!("    size {}: {} groups", i + 1, count);
            }
        }
    }
1504
1505 /// Register a dispatch's buffer ranges without checking for conflicts.
1506 ///
1507 /// Use after dispatching an op that doesn't need a barrier check (e.g.,
1508 /// the first dispatch in a session, or dispatches known to be concurrent).
1509 ///
1510 /// In recording mode, also retroactively annotates the most recently
1511 /// captured dispatch node with these ranges if it was missing them.
1512 /// That keeps the reorder pass able to reason about dispatches that
1513 /// were preceded by `track_dispatch` rather than `barrier_between`.
1514 #[inline]
1515 pub fn track_dispatch(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
1516 if self.recording {
1517 let read_ranges: Vec<MemRange> = reads
1518 .iter()
1519 .map(|b| {
1520 let start = b.contents_ptr() as usize;
1521 (start, start + b.byte_len())
1522 })
1523 .collect();
1524 let write_ranges: Vec<MemRange> = writes
1525 .iter()
1526 .map(|b| {
1527 let start = b.contents_ptr() as usize;
1528 (start, start + b.byte_len())
1529 })
1530 .collect();
1531 self.encoder
1532 .annotate_last_dispatch_if_missing(read_ranges, write_ranges);
1533 }
1534 self.tracker.add(reads, writes);
1535 }
1536
    /// Return the number of barriers inserted so far in this session.
    ///
    /// Counts both unconditional [`Self::barrier`] calls and non-elided
    /// [`Self::barrier_between`] emissions.
    #[inline]
    pub fn barrier_count(&self) -> u32 {
        self.barrier_count
    }
1542
    /// Cumulative nanoseconds spent in ConflictTracker checks (diagnostic).
    ///
    /// Always returns 0 in this build — the timing instrumentation is not
    /// compiled in; the accessor remains so callers need no cfg guards.
    pub fn tracker_overhead_ns(&self) -> u64 {
        0
    }
1548
    /// Borrow the underlying command encoder for direct op dispatch.
    ///
    /// Use this when you need to call an op function that is not wrapped by
    /// a `GraphSession` method. The returned encoder is the same shared
    /// encoder — all dispatches still go into the same command buffer.
    /// NOTE: direct dispatches bypass the session's conflict tracker; call
    /// [`Self::track_dispatch`] afterwards if barrier elision should still
    /// see their buffer ranges.
    pub fn encoder_mut(&mut self) -> &mut CommandEncoder {
        &mut self.encoder
    }
1557
    /// Borrow the device reference this session was created with.
    pub fn device(&self) -> &MlxDevice {
        self.device
    }
1562
    /// Whether this session is in capture/record mode (created via the
    /// recorded entry point rather than direct dispatch).
    pub fn is_recording(&self) -> bool {
        self.recording
    }
1567
1568 /// Commit the command buffer and wait for GPU completion.
1569 ///
1570 /// This is the ONLY sync point per forward pass. After this call, all
1571 /// output buffers are readable by the CPU.
1572 ///
1573 /// In recording mode: extracts the captured graph, replays it into
1574 /// the encoder via `ComputeGraph::encode_sequential()`, then commits
1575 /// and waits. The result is identical to the direct-dispatch path.
1576 ///
1577 /// Consumes the session — no further ops can be encoded.
1578 pub fn finish(mut self) -> Result<()> {
1579 if self.recording {
1580 if let Some(nodes) = self.encoder.take_capture() {
1581 let graph = ComputeGraph::from_nodes(nodes);
1582 graph.encode_sequential(&mut self.encoder);
1583 }
1584 }
1585 self.encoder.commit_and_wait()
1586 }
1587
1588 /// Commit the command buffer WITHOUT waiting.
1589 ///
1590 /// The GPU begins executing immediately. Use this for fire-and-forget
1591 /// dispatch when you do not need results until later.
1592 ///
1593 /// In recording mode: replays the captured graph before committing.
1594 ///
1595 /// Consumes the session.
1596 pub fn commit(mut self) -> CommandEncoder {
1597 if self.recording {
1598 if let Some(nodes) = self.encoder.take_capture() {
1599 let graph = ComputeGraph::from_nodes(nodes);
1600 graph.encode_sequential(&mut self.encoder);
1601 }
1602 }
1603 self.encoder.commit();
1604 self.encoder
1605 }
1606
1607 /// Commit the command buffer and wait, returning split timing.
1608 ///
1609 /// Returns `(encoding_ns, gpu_wait_ns)` where:
1610 /// - `encoding_ns` is the time from session begin to commit (CPU encoding)
1611 /// - `gpu_wait_ns` is the time from commit to GPU completion
1612 ///
1613 /// The `session_begin` instant should be captured right after `exec.begin()`.
1614 ///
1615 /// In recording mode: replays the captured graph before committing.
1616 ///
1617 /// Consumes the session.
1618 pub fn finish_with_timing(mut self, session_begin: std::time::Instant) -> Result<(u64, u64)> {
1619 if self.recording {
1620 if let Some(nodes) = self.encoder.take_capture() {
1621 let graph = ComputeGraph::from_nodes(nodes);
1622 graph.encode_sequential(&mut self.encoder);
1623 }
1624 }
1625 let commit_start = std::time::Instant::now();
1626 let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
1627 self.encoder.commit();
1628 self.encoder.wait_until_completed()?;
1629 let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
1630 Ok((encoding_ns, gpu_wait_ns))
1631 }
1632
1633 /// Finish this session and return the GPU wall-clock interval in ns.
1634 ///
1635 /// Returns `(gpu_interval_ns,)` where `gpu_interval_ns` is the CFTimeInterval
1636 /// difference between `MTLCommandBuffer.GPUEndTime` and
1637 /// `MTLCommandBuffer.GPUStartTime`, converted to ns. Excludes CPU
1638 /// commit+wait overhead — that appears in the residual when bucket
1639 /// sums are compared to the outer wall-clock.
1640 ///
1641 /// Used by `HF2Q_PROFILE_GPU_TS=1` to accumulate pure GPU time per
1642 /// op bucket. In recording mode: replays the captured graph before
1643 /// committing.
1644 ///
1645 /// Consumes the session.
1646 pub fn finish_with_gpu_time(mut self) -> Result<u64> {
1647 if self.recording {
1648 if let Some(nodes) = self.encoder.take_capture() {
1649 let graph = ComputeGraph::from_nodes(nodes);
1650 graph.encode_sequential(&mut self.encoder);
1651 }
1652 }
1653 let (gs, ge) = self.encoder.commit_wait_with_gpu_time()?;
1654 // GPUStartTime/GPUEndTime are CFTimeInterval (seconds, double).
1655 // Guard against negative deltas (can happen on the first CB of
1656 // a run if the kernel driver lazily initialises the timeline;
1657 // clamp to zero in that case).
1658 let delta = (ge - gs).max(0.0);
1659 Ok((delta * 1.0e9) as u64)
1660 }
1661
1662 /// Finish with fusion: run the RMS norm + MUL fusion pass before
1663 /// replaying the graph.
1664 ///
1665 /// Only meaningful in recording mode. In direct-dispatch mode, this
1666 /// behaves identically to `finish()`.
1667 ///
1668 /// Returns `(fusions_applied,)` on success.
1669 pub fn finish_with_fusion(
1670 mut self,
1671 registry: &mut KernelRegistry,
1672 device: &metal::DeviceRef,
1673 ) -> Result<u32> {
1674 let mut fusions = 0;
1675 if self.recording {
1676 if let Some(nodes) = self.encoder.take_capture() {
1677 let mut graph = ComputeGraph::from_nodes(nodes);
1678 fusions = graph.fuse(registry, device)?;
1679 graph.encode_sequential(&mut self.encoder);
1680 }
1681 }
1682 self.encoder.commit_and_wait()?;
1683 Ok(fusions)
1684 }
1685
1686 /// Finish with fusion and split timing.
1687 ///
1688 /// Like `finish_with_timing` but runs the fusion pass first.
1689 /// Returns `(encoding_ns, gpu_wait_ns, fusions_applied)`.
1690 pub fn finish_with_fusion_and_timing(
1691 mut self,
1692 registry: &mut KernelRegistry,
1693 device: &metal::DeviceRef,
1694 session_begin: std::time::Instant,
1695 ) -> Result<(u64, u64, u32)> {
1696 let mut fusions = 0;
1697 if self.recording {
1698 if let Some(nodes) = self.encoder.take_capture() {
1699 let mut graph = ComputeGraph::from_nodes(nodes);
1700 fusions = graph.fuse(registry, device)?;
1701 graph.encode_sequential(&mut self.encoder);
1702 }
1703 }
1704 let commit_start = std::time::Instant::now();
1705 let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
1706 self.encoder.commit();
1707 self.encoder.wait_until_completed()?;
1708 let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
1709 Ok((encoding_ns, gpu_wait_ns, fusions))
1710 }
1711
1712 /// Finish with fusion AND reorder: run both graph optimization passes
1713 /// before replaying the graph.
1714 ///
1715 /// Only meaningful in recording mode. In direct-dispatch mode, this
1716 /// behaves identically to `finish()`.
1717 ///
1718 /// Returns `(fusions_applied, nodes_reordered)` on success.
1719 pub fn finish_with_fusion_and_reorder(
1720 mut self,
1721 registry: &mut KernelRegistry,
1722 device: &metal::DeviceRef,
1723 ) -> Result<(u32, u32)> {
1724 let mut fusions = 0;
1725 let mut reordered = 0;
1726 if self.recording {
1727 if let Some(nodes) = self.encoder.take_capture() {
1728 let mut graph = ComputeGraph::from_nodes(nodes);
1729 fusions = graph.fuse(registry, device)?;
1730 reordered = graph.reorder();
1731 graph.encode_with_barriers(&mut self.encoder);
1732 }
1733 }
1734 self.encoder.commit_and_wait()?;
1735 Ok((fusions, reordered))
1736 }
1737
1738 /// Finish with fusion, reorder, and split timing.
1739 ///
1740 /// Like `finish_with_fusion_and_timing` but also runs the reorder pass.
1741 /// Returns `(encoding_ns, gpu_wait_ns, fusions_applied, nodes_reordered)`.
1742 pub fn finish_with_fusion_reorder_and_timing(
1743 mut self,
1744 registry: &mut KernelRegistry,
1745 device: &metal::DeviceRef,
1746 session_begin: std::time::Instant,
1747 ) -> Result<(u64, u64, u32, u32)> {
1748 let mut fusions = 0;
1749 let mut reordered = 0;
1750 if self.recording {
1751 if let Some(nodes) = self.encoder.take_capture() {
1752 let mut graph = ComputeGraph::from_nodes(nodes);
1753 fusions = graph.fuse(registry, device)?;
1754 reordered = graph.reorder();
1755 graph.encode_with_barriers(&mut self.encoder);
1756 }
1757 }
1758 let commit_start = std::time::Instant::now();
1759 let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
1760 self.encoder.commit();
1761 self.encoder.wait_until_completed()?;
1762 let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
1763 Ok((encoding_ns, gpu_wait_ns, fusions, reordered))
1764 }
1765
    /// Finish with the full optimization pipeline: fuse, reorder, dual-buffer
    /// encode.
    ///
    /// Runs the fusion pass, reorder pass, then encodes the graph into two
    /// Metal command buffers for CPU/GPU overlap. The first ~10% of dispatches
    /// are committed immediately so the GPU can start executing while the CPU
    /// encodes the remaining ~90%.
    ///
    /// Only meaningful in recording mode. In direct-dispatch mode, this
    /// behaves identically to `finish()`.
    ///
    /// Returns `(fusions_applied, nodes_reordered, barriers_buf0, barriers_buf1)`.
    pub fn finish_optimized(
        mut self,
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
    ) -> Result<(u32, u32, u32, u32)> {
        let mut fusions = 0;
        let mut reordered = 0;
        let mut barriers0 = 0u32;
        let mut barriers1 = 0u32;

        if self.recording {
            if let Some(nodes) = self.encoder.take_capture() {
                // Commit the capture encoder's empty command buffer so its
                // MTLCommandQueue pool slot is freed (same fix as timing variant).
                self.encoder.commit();

                let mut graph = ComputeGraph::from_nodes(nodes);
                fusions = graph.fuse(registry, device)?;
                reordered = graph.reorder();

                // Two fresh encoders for the dual-buffer CPU/GPU overlap.
                let mut enc0 = self.device.command_encoder()?;
                let mut enc1 = self.device.command_encoder()?;

                let (b0, b1) = graph.encode_dual_buffer(&mut enc0, &mut enc1);
                barriers0 = b0;
                barriers1 = b1;

                // enc0 was already committed inside encode_dual_buffer.
                // Commit enc1 and wait — Metal queue ordering guarantees enc0
                // finishes before enc1 starts executing.
                enc1.commit_and_wait()?;

                // The original encoder was never committed (capture mode drained
                // it). We need to end it cleanly — dropping it will end the
                // active encoder if any, and the uncommitted command buffer is
                // abandoned. That is safe: Metal silently drops uncommitted
                // command buffers.
                return Ok((fusions, reordered, barriers0, barriers1));
            }
        }

        // Direct-dispatch fallback: just commit the original encoder.
        // (Also reached when recording but nothing was captured.)
        self.encoder.commit_and_wait()?;
        Ok((fusions, reordered, barriers0, barriers1))
    }
1823
    /// Finish with the full optimization pipeline and split timing.
    ///
    /// Like `finish_optimized` but returns timing information.
    /// Returns `(encoding_ns, gpu_wait_ns, fusions, reordered, barriers_buf0, barriers_buf1)`.
    ///
    /// Timing breakdown:
    /// - `encoding_ns`: CPU time from session begin to first buffer commit
    ///   (fusion + reorder + encode chunk 0)
    /// - `gpu_wait_ns`: wall time from second buffer commit to GPU completion
    ///   (includes GPU execution of both buffers, overlapped with chunk 1 encoding)
    pub fn finish_optimized_with_timing(
        mut self,
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
        session_begin: std::time::Instant,
    ) -> Result<(u64, u64, u32, u32, u32, u32)> {
        let mut fusions = 0;
        let mut reordered = 0;
        let mut barriers0 = 0u32;
        let mut barriers1 = 0u32;

        if self.recording {
            if let Some(nodes) = self.encoder.take_capture() {
                // Commit the capture encoder's empty command buffer so its
                // MTLCommandQueue pool slot is freed. Without this, each
                // token leaks one uncommitted buffer and the queue exhausts
                // its ~64-slot pool after ~64 tokens, causing a deadlock.
                self.encoder.commit();

                let opt_t0 = std::time::Instant::now();
                let mut graph = ComputeGraph::from_nodes(nodes);
                let fuse_t0 = std::time::Instant::now();
                fusions = graph.fuse(registry, device)?;
                let fuse_us = fuse_t0.elapsed().as_micros();

                let reorder_t0 = std::time::Instant::now();
                // Reordering needs every dispatch annotated with buffer
                // ranges; otherwise dependencies cannot be proven, so skip.
                let unannotated = graph.unannotated_dispatch_count();
                if unannotated == 0 {
                    reordered = graph.reorder();
                } else if std::env::var("HF2Q_MLX_TIMING").is_ok() {
                    eprintln!("  [GRAPH_OPT] WARN: skipping reorder — {} of {} dispatches lack range annotations",
                        unannotated, graph.dispatch_count());
                }
                let reorder_us = reorder_t0.elapsed().as_micros();
                let opt_us = opt_t0.elapsed().as_micros();

                let diag = std::env::var("HF2Q_GRAPH_DIAG").is_ok();
                let t0 = std::time::Instant::now();
                let mut enc0 = self.device.command_encoder()?;
                let mut enc1 = self.device.command_encoder()?;
                let enc_create_us = t0.elapsed().as_micros();

                let t1 = std::time::Instant::now();
                // enc0 is committed inside encode_dual_buffer so the GPU can
                // start while chunk 1 is still being encoded on the CPU.
                let (b0, b1) = graph.encode_dual_buffer(&mut enc0, &mut enc1);
                barriers0 = b0;
                barriers1 = b1;
                let encode_us = t1.elapsed().as_micros();

                let encoding_ns = session_begin.elapsed().as_nanos() as u64;

                let wait_start = std::time::Instant::now();
                enc1.commit_and_wait()?;
                let gpu_wait_ns = wait_start.elapsed().as_nanos() as u64;

                if diag {
                    eprintln!("  [DIAG] fuse={:.1}ms reorder={:.1}ms opt_total={:.1}ms enc_create={:.1}ms encode={:.1}ms gpu_wait={:.1}ms barriers={}+{}",
                        fuse_us as f64 / 1e3, reorder_us as f64 / 1e3, opt_us as f64 / 1e3,
                        enc_create_us as f64 / 1e3, encode_us as f64 / 1e3,
                        gpu_wait_ns as f64 / 1e6, b0, b1);
                }

                return Ok((encoding_ns, gpu_wait_ns, fusions, reordered, barriers0, barriers1));
            }
        }

        // Direct-dispatch fallback.
        let commit_start = std::time::Instant::now();
        let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
        self.encoder.commit();
        self.encoder.wait_until_completed()?;
        let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
        Ok((encoding_ns, gpu_wait_ns, fusions, reordered, barriers0, barriers1))
    }
1907}