// mlx_native/graph.rs
1//! [`GraphExecutor`] — batched Metal dispatch for single-encoder forward passes.
2//!
3//! llama.cpp's speed advantage over candle is NOT the kernels (Phase 0 proved
4//! candle's are as fast or faster per-call). It is the dispatch pattern:
5//! 1 encoder per command buffer instead of ~120. This module implements that
6//! pattern.
7//!
8//! # Usage
9//!
10//! ```ignore
11//! let mut executor = GraphExecutor::new(device.clone());
12//! let mut session = executor.begin()?;
13//!
14//! // All ops encode into the same command buffer — no per-op encoder creation.
15//! session.rms_norm(&mut registry, device.metal_device(), input, weight, output, params, rows, dim)?;
16//! session.quantized_matmul(&mut registry, &device, input, weight, scales, biases, &qparams)?;
17//! session.elementwise_add(&mut registry, device.metal_device(), a, b, out, n, DType::F32)?;
18//!
19//! // Single GPU sync point for the entire forward pass.
20//! session.finish()?;
21//! ```
22//!
23//! # Design
24//!
25//! The `GraphSession` holds a single `CommandEncoder`. Each op method delegates
26//! to the existing op dispatch functions in [`crate::ops`], passing the session's
27//! shared encoder. No new Metal code is needed — the ops already work with a
28//! shared encoder. The executor just prevents creating a new encoder per op.
29//!
30//! # Phase 4e.1 — Graph IR
31//!
32//! The `ComputeGraph` type captures dispatches into a `Vec<CapturedNode>` for
33//! later replay. `GraphExecutor::begin_recorded()` starts a session in capture
34//! mode: all op calls are intercepted at the `CommandEncoder` level and recorded
35//! instead of being sent to Metal. `GraphSession::finish()` detects capture
36//! mode, extracts the recorded graph, and replays it into a fresh encoder via
37//! `ComputeGraph::encode_sequential()`.
38//!
39//! The existing direct-dispatch path (`begin()`) is completely unchanged.
40
41use metal::foreign_types::ForeignType;
42
43use crate::device::MlxDevice;
44use crate::encoder::{CapturedNode, CapturedOpKind, CommandEncoder, MemRange, RecordedBinding};
45use crate::error::Result;
46use crate::kernel_registry::KernelRegistry;
47use crate::ops;
48
49// Re-export types used in the public API so callers don't need separate imports.
50pub use crate::buffer::MlxBuffer;
51pub use crate::dtypes::DType;
52
53// ---------------------------------------------------------------------------
54// OpKind — operation classification for the reorder safety whitelist (4e.3)
55// ---------------------------------------------------------------------------
56
57/// Classification of a compute operation for reorder safety analysis.
58///
59/// Operations marked as reorderable can be freely reordered by the graph
60/// optimizer (Phase 4e.3) as long as their data dependencies allow it.
61/// Non-reorderable operations have side effects or dependencies that
62/// require them to stay in their original sequential position.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OpKind {
    /// Matrix multiplication (reorderable).
    MatMul,
    /// Expert-routed matrix multiplication (reorderable).
    MatMulId,
    /// Normalization — RMS norm, layer norm (reorderable).
    Norm,
    /// Rotary position embedding (reorderable).
    Rope,
    /// Elementwise ops — add, mul, scale, gelu, softcap, etc. (reorderable).
    Elementwise,
    /// Memory copy — KV cache copy, embedding gather (reorderable).
    Copy,
    /// Gather/scatter (reorderable).
    Gather,
    /// Scaled dot-product attention (NOT reorderable).
    Sdpa,
    /// Softmax (NOT reorderable).
    Softmax,
    /// MoE gate with CPU readback dependency (NOT reorderable).
    MoeGate,
    /// Anything else (NOT reorderable).
    Other,
}

impl OpKind {
    /// Whether this op kind is safe to reorder in the graph optimizer.
    ///
    /// Written as an exhaustive match so that adding a new variant forces
    /// an explicit decision about its reorderability.
    pub fn is_reorderable(&self) -> bool {
        match self {
            Self::MatMul
            | Self::MatMulId
            | Self::Norm
            | Self::Rope
            | Self::Elementwise
            | Self::Copy
            | Self::Gather => true,
            Self::Sdpa | Self::Softmax | Self::MoeGate | Self::Other => false,
        }
    }
}
104
105// ---------------------------------------------------------------------------
106// ComputeGraph — the recorded graph IR
107// ---------------------------------------------------------------------------
108
109/// A recorded sequence of GPU compute dispatches and barriers.
110///
111/// Created by running a forward pass with the encoder in capture mode.
112/// Can be replayed into a real `CommandEncoder` via `encode_sequential()`,
113/// producing identical Metal dispatch behavior to the original direct path.
114///
115/// Future phases (4e.2, 4e.3) will add fusion and reorder passes that
116/// transform the graph before encoding.
pub struct ComputeGraph {
    /// Ordered list of captured dispatch and barrier sentinel nodes.
    nodes: Vec<CapturedNode>,
}
120
121impl ComputeGraph {
122 /// Create an empty compute graph.
123 pub fn new() -> Self {
124 Self {
125 nodes: Vec::with_capacity(128),
126 }
127 }
128
129 /// Create a compute graph from a pre-built list of captured nodes.
130 pub fn from_nodes(nodes: Vec<CapturedNode>) -> Self {
131 Self { nodes }
132 }
133
134 /// Record a captured node into the graph.
135 pub fn record(&mut self, node: CapturedNode) {
136 self.nodes.push(node);
137 }
138
139 /// Number of nodes (dispatches + barriers) in the graph.
140 pub fn len(&self) -> usize {
141 self.nodes.len()
142 }
143
144 /// Whether the graph contains no nodes.
145 pub fn is_empty(&self) -> bool {
146 self.nodes.is_empty()
147 }
148
149 /// Number of dispatch nodes (excludes barriers).
150 pub fn dispatch_count(&self) -> usize {
151 self.nodes
152 .iter()
153 .filter(|n| matches!(n, CapturedNode::Dispatch { .. }))
154 .count()
155 }
156
157 /// Number of barrier nodes.
158 pub fn barrier_count(&self) -> usize {
159 self.nodes
160 .iter()
161 .filter(|n| matches!(n, CapturedNode::Barrier))
162 .count()
163 }
164
165 /// Borrow the node list.
166 pub fn nodes(&self) -> &[CapturedNode] {
167 &self.nodes
168 }
169
170 /// Count dispatch nodes that have empty read/write range annotations.
171 ///
172 /// Used for diagnostics: if >0, the reorder pass cannot guarantee
173 /// correctness because it relies on complete annotations.
174 pub fn unannotated_dispatch_count(&self) -> usize {
175 self.nodes
176 .iter()
177 .filter(|n| matches!(n, CapturedNode::Dispatch { reads, writes, .. }
178 if reads.is_empty() || writes.is_empty()))
179 .count()
180 }
181
182 /// Take ownership of the node list, consuming the graph.
183 pub fn into_nodes(self) -> Vec<CapturedNode> {
184 self.nodes
185 }
186
187 /// Encode all nodes sequentially into the given encoder.
188 ///
189 /// Barrier sentinel nodes emit a Metal memory barrier. Dispatch nodes
190 /// are replayed through `CommandEncoder::replay_dispatch()`.
191 ///
192 /// This produces identical GPU behavior to the direct-dispatch path —
193 /// same pipeline bindings, same dispatch dimensions, same barrier
194 /// placement.
195 ///
196 /// Returns the number of barriers emitted.
197 pub fn encode_sequential(&self, encoder: &mut CommandEncoder) -> u32 {
198 let mut barrier_count = 0u32;
199 for node in &self.nodes {
200 match node {
201 CapturedNode::Barrier => {
202 encoder.memory_barrier();
203 barrier_count += 1;
204 }
205 CapturedNode::Dispatch {
206 pipeline,
207 bindings,
208 threads_per_grid,
209 threads_per_threadgroup,
210 threadgroup_memory,
211 dispatch_kind,
212 ..
213 } => {
214 encoder.replay_dispatch(
215 pipeline,
216 bindings,
217 threadgroup_memory,
218 *threads_per_grid,
219 *threads_per_threadgroup,
220 *dispatch_kind,
221 );
222 }
223 }
224 }
225 barrier_count
226 }
227
228 /// Encode the graph into a Metal command buffer, computing barriers on the
229 /// fly from each node's read/write buffer ranges.
230 ///
231 /// This is the correct encoding method for reordered graphs where barrier
232 /// sentinels have been stripped. Mirrors llama.cpp's encode-time barrier
233 /// insertion via `ggml_metal_op_concurrency_check`.
234 ///
235 /// Returns the number of barriers emitted.
236 pub fn encode_with_barriers(&self, encoder: &mut CommandEncoder) -> u32 {
237 let mut tracker = ReorderConflictTracker::new();
238 let mut barrier_count = 0u32;
239
240 for node in &self.nodes {
241 match node {
242 CapturedNode::Dispatch {
243 pipeline,
244 bindings,
245 threads_per_grid,
246 threads_per_threadgroup,
247 threadgroup_memory,
248 dispatch_kind,
249 reads,
250 writes,
251 ..
252 } => {
253 let has_ranges = !reads.is_empty() || !writes.is_empty();
254 if has_ranges && tracker.conflicts(reads, writes) {
255 encoder.memory_barrier();
256 tracker.reset();
257 barrier_count += 1;
258 }
259 if has_ranges {
260 tracker.add(reads, writes);
261 }
262 encoder.replay_dispatch(
263 pipeline,
264 bindings,
265 threadgroup_memory,
266 *threads_per_grid,
267 *threads_per_threadgroup,
268 *dispatch_kind,
269 );
270 }
271 CapturedNode::Barrier => {
272 // Explicit barriers still force a barrier boundary
273 encoder.memory_barrier();
274 tracker.reset();
275 barrier_count += 1;
276 }
277 }
278 }
279 barrier_count
280 }
281
282 /// Encode the graph using two command buffers for CPU/GPU overlap.
283 ///
284 /// The first `n0` dispatches are encoded into `encoder0` and committed
285 /// immediately (GPU starts executing). The remaining dispatches are encoded
286 /// into `encoder1`. The caller is responsible for committing `encoder1`.
287 ///
288 /// This matches llama.cpp's dual command buffer pattern from
289 /// `ggml_metal_graph_compute` (ggml-metal-context.m:441-644):
290 /// `n_nodes_0 = MAX(64, 0.1 * n_nodes)` for the first buffer.
291 ///
292 /// Command buffers submitted to the same `MTLCommandQueue` execute in
293 /// submission order, so `encoder0.commit()` followed by `encoder1.commit()`
294 /// guarantees enc0 finishes before enc1 starts. The win: the GPU starts
295 /// executing enc0 while the CPU is still encoding enc1.
296 ///
297 /// Returns `(barriers_buf0, barriers_buf1)`.
298 pub fn encode_dual_buffer(
299 &self,
300 encoder0: &mut CommandEncoder,
301 encoder1: &mut CommandEncoder,
302 ) -> (u32, u32) {
303 let dispatch_total = self.dispatch_count();
304 let n0 = std::cmp::max(64, dispatch_total / 10);
305
306 // Find the split point: the index of the n0-th dispatch node.
307 let split_idx = find_dispatch_split_index(&self.nodes, n0);
308
309 // Encode first chunk with barrier recomputation, then commit immediately.
310 let barriers0 = encode_chunk_with_barriers(&self.nodes[..split_idx], encoder0);
311 encoder0.commit();
312
313 // Encode second chunk with barrier recomputation.
314 let barriers1 = encode_chunk_with_barriers(&self.nodes[split_idx..], encoder1);
315
316 (barriers0, barriers1)
317 }
318
319 /// Run the RMS norm + MUL fusion pass over the graph.
320 ///
321 /// Scans for the pattern:
322 /// Dispatch(RmsNorm) → Barrier(s) → Dispatch(ElemMul)
323 /// where the MUL reads the norm's output buffer, and replaces the
324 /// sequence with a single fused `rms_norm_mul_*` dispatch.
325 ///
326 /// The fused dispatch:
327 /// - Reads the norm's input (buffer 0) and weight (buffer 1)
328 /// - Reads the MUL's second operand as the scale (buffer 2)
329 /// - Writes to the MUL's output (buffer 3)
330 /// - Carries the norm's params (buffer 4)
331 /// - Uses the norm's threadgroup config and shared memory
332 ///
333 /// Returns the number of fusions applied.
334 ///
335 /// # Arguments
336 ///
337 /// * `registry` - Kernel registry for compiling the fused pipeline.
338 /// * `device` - Metal device for pipeline compilation.
    pub fn fuse(
        &mut self,
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
    ) -> Result<u32> {
        // Rebuild the node list into `result`; nodes are copied through
        // unchanged unless a fusable norm→mul pattern starts at `i`.
        let mut result: Vec<CapturedNode> = Vec::with_capacity(self.nodes.len());
        let mut fusions = 0u32;
        let mut i = 0;

        while i < self.nodes.len() {
            // Check if current node is an RMS norm dispatch.
            let is_rms_norm = matches!(
                &self.nodes[i],
                CapturedNode::Dispatch { op_kind: CapturedOpKind::RmsNorm, .. }
            );

            if !is_rms_norm {
                result.push(self.nodes[i].clone());
                i += 1;
                continue;
            }

            // Look ahead: skip barriers, then check for ElemMul.
            // Only barriers may sit between the norm and the mul — any
            // other node ends the pattern (j stops on the first non-barrier).
            let mut j = i + 1;
            let mut barrier_count = 0usize;
            while j < self.nodes.len() && matches!(&self.nodes[j], CapturedNode::Barrier) {
                barrier_count += 1;
                j += 1;
            }

            // Must have at least one barrier and the next node must be ElemMul.
            if barrier_count == 0 || j >= self.nodes.len() {
                result.push(self.nodes[i].clone());
                i += 1;
                continue;
            }

            let is_elem_mul = matches!(
                &self.nodes[j],
                CapturedNode::Dispatch { op_kind: CapturedOpKind::ElemMul, .. }
            );

            if !is_elem_mul {
                result.push(self.nodes[i].clone());
                i += 1;
                continue;
            }

            // Extract norm and mul dispatch fields.
            let (norm_pipeline, norm_bindings, norm_tpg, norm_tptg, norm_tgmem, norm_dk) =
                match &self.nodes[i] {
                    CapturedNode::Dispatch {
                        pipeline,
                        bindings,
                        threads_per_grid,
                        threads_per_threadgroup,
                        threadgroup_memory,
                        dispatch_kind,
                        ..
                    } => (pipeline, bindings, threads_per_grid, threads_per_threadgroup, threadgroup_memory, dispatch_kind),
                    _ => unreachable!(),
                };

            let (mul_bindings, _mul_tpg, _mul_tptg) = match &self.nodes[j] {
                CapturedNode::Dispatch {
                    bindings,
                    threads_per_grid,
                    threads_per_threadgroup,
                    ..
                } => (bindings, threads_per_grid, threads_per_threadgroup),
                _ => unreachable!(),
            };

            // Verify data dependency: the norm's output buffer (slot 2) must
            // appear as one of the MUL's input buffers (slot 0 or 1).
            //
            // Norm binding layout: (0=input, 1=weight, 2=output, 3=params)
            // MUL binding layout:  (0=a, 1=b, 2=output, 3=params_bytes)
            let norm_output_ptr = Self::buffer_ptr_for_slot(norm_bindings, 2);
            let mul_a_ptr = Self::buffer_ptr_for_slot(mul_bindings, 0);
            let mul_b_ptr = Self::buffer_ptr_for_slot(mul_bindings, 1);

            if norm_output_ptr.is_none() || (norm_output_ptr != mul_a_ptr && norm_output_ptr != mul_b_ptr) {
                // Data dependency not confirmed — don't fuse.
                result.push(self.nodes[i].clone());
                i += 1;
                continue;
            }

            // Determine which MUL input is the scale (the one that is NOT
            // the norm's output).
            let scale_slot = if norm_output_ptr == mul_a_ptr { 1 } else { 0 };

            // Build fused bindings:
            //   0 = norm input
            //   1 = norm weight
            //   2 = scale (from MUL)
            //   3 = MUL output
            //   4 = norm params
            // Gather all required bindings; bail if any are missing.
            let (norm_input, norm_weight, scale, mul_output, norm_params) = match (
                Self::get_binding(norm_bindings, 0),
                Self::get_binding(norm_bindings, 1),
                Self::get_binding(mul_bindings, scale_slot),
                Self::get_binding(mul_bindings, 2),
                Self::get_binding(norm_bindings, 3),
            ) {
                (Some(a), Some(b), Some(c), Some(d), Some(e)) => (a, b, c, d, e),
                _ => {
                    // Missing bindings — don't fuse.
                    result.push(self.nodes[i].clone());
                    i += 1;
                    continue;
                }
            };

            // Select fused pipeline based on the original norm pipeline name.
            // The norm pipeline name is "rms_norm_f32", "rms_norm_f16", or
            // "rms_norm_bf16" — we need the corresponding fused pipeline.
            let fused_name = match Self::fused_pipeline_name(norm_pipeline) {
                Some(name) => name,
                None => {
                    result.push(self.nodes[i].clone());
                    i += 1;
                    continue;
                }
            };

            let fused_pipeline = registry.get_pipeline(fused_name, device)?;

            let fused_bindings = vec![
                (0, norm_input),
                (1, norm_weight),
                (2, scale),
                (3, mul_output),
                (4, norm_params),
            ];

            // Merge read/write ranges from both the norm and mul nodes for the
            // fused dispatch. The fused op reads everything the norm reads
            // plus the mul's scale input, and writes to the mul's output.
            // (The norm's own writes — its intermediate output — are dropped:
            // the fused kernel no longer materializes that buffer.)
            let (fused_reads, fused_writes) = match (&self.nodes[i], &self.nodes[j]) {
                (
                    CapturedNode::Dispatch { reads: nr, writes: _nw, .. },
                    CapturedNode::Dispatch { reads: mr, writes: mw, .. },
                ) => {
                    let mut reads = nr.clone();
                    reads.extend_from_slice(mr);
                    (reads, mw.clone())
                }
                _ => (Vec::new(), Vec::new()),
            };

            result.push(CapturedNode::Dispatch {
                pipeline: fused_pipeline.to_owned(),
                bindings: fused_bindings,
                threads_per_grid: *norm_tpg,
                threads_per_threadgroup: *norm_tptg,
                threadgroup_memory: norm_tgmem.clone(),
                dispatch_kind: *norm_dk,
                op_kind: CapturedOpKind::Other, // Fused ops are not further fuseable
                reads: fused_reads,
                writes: fused_writes,
            });

            fusions += 1;
            // Skip past the norm, barrier(s), and mul nodes. The barriers
            // between them are intentionally dropped — they only ordered the
            // norm against the mul, which the fused kernel subsumes.
            i = j + 1;
        }

        self.nodes = result;
        Ok(fusions)
    }
512
513 /// Run the reorder pass over the graph to improve GPU concurrency.
514 ///
515 /// Port of llama.cpp's `ggml_metal_graph_optimize_reorder` — a greedy
516 /// 64-node lookahead that pulls independent dispatches forward to fill
517 /// larger concurrent groups between barriers.
518 ///
519 /// **Prerequisites:** Call `fuse()` first if desired. The reorder pass
520 /// operates on the post-fusion graph. Barrier sentinel nodes are stripped
521 /// before reordering (they will be recomputed at encode time by the
522 /// `ConflictTracker` in `encode_sequential`).
523 ///
524 /// **Algorithm (matching llama.cpp exactly):**
525 /// 1. Strip all `CapturedNode::Barrier` nodes.
526 /// 2. For each unprocessed node `i0`:
527 /// - If it conflicts with the current concurrent group (`mrs0`):
528 /// * Initialize `mrs1` from `i0`'s ranges (skipped-over set)
529 /// * Lookahead up to 64 nodes for candidates that:
530 /// (a) Are reorderable (`CapturedOpKind::is_reorderable()`)
531 /// (b) Don't conflict with `mrs0` (current group)
532 /// (c) Don't conflict with `mrs1` (skipped-over nodes)
533 /// * Pull qualifying candidates into the current group
534 /// * Non-reorderable ops break the lookahead
535 /// - Reset `mrs0` (new concurrent group)
536 /// - Add `i0` to the new group
537 ///
538 /// Returns the number of nodes that were moved to earlier positions.
539 pub fn reorder(&mut self) -> u32 {
540 // Step 1: Strip barrier nodes. After fusion + reorder, barriers will
541 // be recomputed by the ConflictTracker at encode time.
542 self.nodes.retain(|n| !matches!(n, CapturedNode::Barrier));
543
544 let n = self.nodes.len();
545 if n == 0 {
546 return 0;
547 }
548
549 let mut result: Vec<usize> = Vec::with_capacity(n);
550 let mut used = vec![false; n];
551
552 // mrs0: memory ranges for the current concurrent group
553 let mut mrs0 = ReorderConflictTracker::new();
554 // mrs1: memory ranges for skipped-over (unprocessed) nodes
555 let mut mrs1 = ReorderConflictTracker::new();
556
557 const N_FORWARD: usize = 64;
558
559 for i0 in 0..n {
560 if used[i0] {
561 continue;
562 }
563
564 let node0 = &self.nodes[i0];
565
566 // Extract reads/writes for conflict check.
567 let (reads0, writes0, op_kind0) = match node0 {
568 CapturedNode::Dispatch { reads, writes, op_kind, .. } => {
569 (reads.as_slice(), writes.as_slice(), *op_kind)
570 }
571 CapturedNode::Barrier => continue, // stripped, but be safe
572 };
573
574 // Check if node0 conflicts with the current concurrent group.
575 // Empty nodes (no ranges) never conflict — like llama.cpp's is_empty.
576 let has_ranges = !reads0.is_empty() || !writes0.is_empty();
577 if has_ranges && mrs0.conflicts(reads0, writes0) {
578 // Before starting a new group, look forward for nodes that
579 // can be pulled into the CURRENT group.
580 mrs1.reset();
581 mrs1.add(reads0, writes0);
582
583 let end = (i0 + N_FORWARD).min(n);
584 for i1 in (i0 + 1)..end {
585 if used[i1] {
586 continue;
587 }
588
589 let node1 = &self.nodes[i1];
590 let (reads1, writes1, op_kind1) = match node1 {
591 CapturedNode::Dispatch { reads, writes, op_kind, .. } => {
592 (reads.as_slice(), writes.as_slice(), *op_kind)
593 }
594 CapturedNode::Barrier => continue,
595 };
596
597 // Non-reorderable ops break the lookahead.
598 if !op_kind1.is_reorderable() {
599 break;
600 }
601
602 let is_empty1 = reads1.is_empty() && writes1.is_empty();
603
604 // A node can be reordered into the current group if:
605 // 1. It's empty (no ranges) OR doesn't conflict with mrs0
606 // 2. It doesn't conflict with mrs1 (skipped-over nodes)
607 if (is_empty1 || !mrs0.conflicts(reads1, writes1))
608 && !mrs1.conflicts(reads1, writes1)
609 {
610 // Pull into current concurrent group.
611 mrs0.add(reads1, writes1);
612 result.push(i1);
613 used[i1] = true;
614 } else {
615 // Not eligible — expand the skipped-over set.
616 mrs1.add(reads1, writes1);
617 }
618 }
619
620 // Finalize the current concurrent group.
621 mrs0.reset();
622 }
623
624 // Expand the concurrent group with node0.
625 // (Barriers were stripped, so this is always a Dispatch.)
626 let _ = op_kind0; // suppress unused warning
627 mrs0.add(reads0, writes0);
628 result.push(i0);
629 }
630
631 // Apply the permutation to produce the reordered node list.
632 let mut reordered_count = 0u32;
633 for (pos, &orig_idx) in result.iter().enumerate() {
634 if orig_idx != pos {
635 reordered_count += 1;
636 }
637 }
638
639 // Build the reordered nodes vec.
640 let old_nodes = std::mem::take(&mut self.nodes);
641 self.nodes = result.iter().map(|&idx| old_nodes[idx].clone()).collect();
642
643 // Debug dump if requested.
644 if std::env::var("HF2Q_REORDER_DUMP").is_ok() {
645 eprintln!(
646 " [REORDER] nodes={} reordered={} ({:.1}%)",
647 n,
648 reordered_count,
649 100.0 * reordered_count as f64 / n as f64,
650 );
651 }
652
653 reordered_count
654 }
655
656 /// Get the Metal buffer pointer for a binding at the given slot index.
657 ///
658 /// Returns `Some(ptr)` if the slot has a `RecordedBinding::Buffer`,
659 /// `None` otherwise.
660 fn buffer_ptr_for_slot(bindings: &[(u64, RecordedBinding)], slot: u64) -> Option<*const std::ffi::c_void> {
661 for (idx, binding) in bindings {
662 if *idx == slot {
663 if let RecordedBinding::Buffer { metal_buffer, offset: _ } = binding {
664 // Use the Metal buffer's GPU address as the identity key.
665 // On Apple Silicon unified memory, this uniquely identifies
666 // the allocation.
667 let ptr: *const std::ffi::c_void = metal_buffer.as_ptr() as *const _;
668 return Some(ptr);
669 }
670 }
671 }
672 None
673 }
674
675 /// Clone the binding at the given slot index.
676 fn get_binding(bindings: &[(u64, RecordedBinding)], slot: u64) -> Option<RecordedBinding> {
677 for (idx, binding) in bindings {
678 if *idx == slot {
679 return Some(binding.clone());
680 }
681 }
682 None
683 }
684
    /// Map a norm pipeline to its fused norm+mul pipeline name.
    ///
    /// Matching is done on the pipeline's `label()` string.
    /// NOTE(review): this assumes every norm pipeline was created with its
    /// label set to the kernel function name — Metal does not set labels
    /// automatically; confirm the kernel registry does. An unrecognized
    /// (or empty) label returns `None`, which simply disables fusion.
    fn fused_pipeline_name(pipeline: &metal::ComputePipelineState) -> Option<&'static str> {
        match pipeline.label() {
            "rms_norm_f32" => Some("rms_norm_mul_f32"),
            "rms_norm_f16" => Some("rms_norm_mul_f16"),
            "rms_norm_bf16" => Some("rms_norm_mul_bf16"),
            _ => None,
        }
    }
697}
698
699impl Default for ComputeGraph {
700 fn default() -> Self {
701 Self::new()
702 }
703}
704
705// ---------------------------------------------------------------------------
706// Dual-buffer encoding helpers (Phase 4e.4)
707// ---------------------------------------------------------------------------
708
709/// Find the node index where the n0-th dispatch starts.
710///
711/// Counts `CapturedNode::Dispatch` nodes until `n0` are reached, then returns
712/// the index of the n0-th dispatch (i.e., the first node of the second chunk).
713/// If `n0 >= dispatch_count`, returns `nodes.len()` (everything in chunk 0).
714fn find_dispatch_split_index(nodes: &[CapturedNode], n0: usize) -> usize {
715 let mut dispatches_seen = 0usize;
716 for (i, node) in nodes.iter().enumerate() {
717 if matches!(node, CapturedNode::Dispatch { .. }) {
718 dispatches_seen += 1;
719 if dispatches_seen == n0 {
720 return i + 1; // split AFTER the n0-th dispatch
721 }
722 }
723 }
724 nodes.len()
725}
726
/// Encode a slice of captured nodes into a command encoder, recomputing
/// barriers on the fly from each node's read/write buffer ranges.
///
/// This is the chunked counterpart of `ComputeGraph::encode_with_barriers()`.
/// Factored out so both halves of a dual-buffer encode can use it.
///
/// Returns the number of barriers emitted.
fn encode_chunk_with_barriers(nodes: &[CapturedNode], encoder: &mut CommandEncoder) -> u32 {
    let mut tracker = ReorderConflictTracker::new();
    let mut barrier_count = 0u32;

    for node in nodes {
        match node {
            CapturedNode::Dispatch {
                pipeline,
                bindings,
                threads_per_grid,
                threads_per_threadgroup,
                threadgroup_memory,
                dispatch_kind,
                reads,
                writes,
                ..
            } => {
                // Dispatches without range annotations are treated as
                // conflict-free and are not tracked (see
                // `unannotated_dispatch_count` for the diagnostic).
                let has_ranges = !reads.is_empty() || !writes.is_empty();
                if has_ranges && tracker.conflicts(reads, writes) {
                    // Hazard against the current concurrent group: close it
                    // with a barrier and start a new group.
                    encoder.memory_barrier();
                    tracker.reset();
                    barrier_count += 1;
                }
                if has_ranges {
                    tracker.add(reads, writes);
                }
                encoder.replay_dispatch(
                    pipeline,
                    bindings,
                    threadgroup_memory,
                    *threads_per_grid,
                    *threads_per_threadgroup,
                    *dispatch_kind,
                );
            }
            CapturedNode::Barrier => {
                // Explicit barrier sentinels always force a boundary.
                encoder.memory_barrier();
                tracker.reset();
                barrier_count += 1;
            }
        }
    }
    barrier_count
}
778
779// ---------------------------------------------------------------------------
780// ReorderConflictTracker — range-based conflict detection for the reorder pass
781// ---------------------------------------------------------------------------
782
783/// Memory range conflict tracker for the reorder pass (Phase 4e.3).
784///
785/// Works with `MemRange` tuples `(start, end)` stored on `CapturedNode::Dispatch`,
786/// rather than requiring live `&MlxBuffer` references. This is the reorder-time
787/// equivalent of the runtime `ConflictTracker`.
788///
789/// Conflict rules match llama.cpp's `ggml_mem_ranges_check`:
790/// - Two read ranges: OK (read-read is concurrent-safe)
791/// - A new read overlapping an existing write: CONFLICT (RAW)
792/// - A new write overlapping any existing range: CONFLICT (WAR/WAW)
793struct ReorderConflictTracker {
794 /// (start, end, is_write) for all ranges in the tracked set.
795 ranges: Vec<(usize, usize, bool)>,
796}
797
798impl ReorderConflictTracker {
799 fn new() -> Self {
800 Self {
801 ranges: Vec::with_capacity(64),
802 }
803 }
804
805 fn reset(&mut self) {
806 self.ranges.clear();
807 }
808
809 /// Check if a dispatch with the given read/write ranges conflicts with
810 /// any range in this tracker.
811 fn conflicts(&self, reads: &[MemRange], writes: &[MemRange]) -> bool {
812 // New reads vs existing writes (RAW)
813 for &(r_start, r_end) in reads {
814 for &(s, e, is_write) in &self.ranges {
815 if is_write && r_start < e && r_end > s {
816 return true;
817 }
818 }
819 }
820 // New writes vs all existing ranges (WAR/WAW)
821 for &(w_start, w_end) in writes {
822 for &(s, e, _) in &self.ranges {
823 if w_start < e && w_end > s {
824 return true;
825 }
826 }
827 }
828 false
829 }
830
831 /// Add read and write ranges to the tracked set.
832 fn add(&mut self, reads: &[MemRange], writes: &[MemRange]) {
833 for &(start, end) in reads {
834 self.ranges.push((start, end, false));
835 }
836 for &(start, end) in writes {
837 self.ranges.push((start, end, true));
838 }
839 }
840}
841
/// Batched Metal dispatch — encodes multiple ops into a single `CommandEncoder`.
///
/// Create one per model (or per forward-pass loop). Call [`begin`](Self::begin)
/// at the start of each forward pass to get a [`GraphSession`] that holds the
/// shared encoder.
pub struct GraphExecutor {
    /// Device used to create a fresh command encoder for each session.
    device: MlxDevice,
}
850
851impl GraphExecutor {
852 /// Create a new graph executor backed by the given device.
853 pub fn new(device: MlxDevice) -> Self {
854 Self { device }
855 }
856
857 /// Begin a new forward pass (direct-dispatch mode).
858 ///
859 /// Returns a [`GraphSession`] that holds a fresh `CommandEncoder`. All ops
860 /// encoded through the session share this single encoder. Call
861 /// [`GraphSession::finish`] to commit and wait.
862 pub fn begin(&self) -> Result<GraphSession<'_>> {
863 let encoder = self.device.command_encoder()?;
864 Ok(GraphSession {
865 encoder,
866 device: &self.device,
867 barrier_count: 0,
868 tracker: ConflictTracker::new(),
869 dispatch_in_group: 0,
870 total_dispatches: 0,
871 group_sizes: [0; 8],
872 recording: false,
873 })
874 }
875
876 /// Begin a new forward pass in capture (record) mode.
877 ///
878 /// All op calls are recorded into a `ComputeGraph` instead of being
879 /// dispatched to Metal. When [`GraphSession::finish`] is called, the
880 /// recorded graph is replayed into a fresh encoder via
881 /// `ComputeGraph::encode_sequential()`.
882 ///
883 /// The API is identical to `begin()` — callers do not need to change
884 /// any op call code. The only behavioral difference: GPU work happens
885 /// at `finish()` time rather than at each op call.
886 pub fn begin_recorded(&self) -> Result<GraphSession<'_>> {
887 let mut encoder = self.device.command_encoder()?;
888 encoder.start_capture();
889 Ok(GraphSession {
890 encoder,
891 device: &self.device,
892 barrier_count: 0,
893 tracker: ConflictTracker::new(),
894 dispatch_in_group: 0,
895 total_dispatches: 0,
896 group_sizes: [0; 8],
897 recording: true,
898 })
899 }
900
901 /// Borrow the underlying device.
902 pub fn device(&self) -> &MlxDevice {
903 &self.device
904 }
905}
906
907/// A single forward pass execution context.
908///
909/// All ops are encoded into one `CommandEncoder`. Call [`finish`](Self::finish)
910/// to commit the command buffer and wait for GPU completion — this is the ONLY
911/// sync point per forward pass.
912///
913/// If an op returns an error, the session can be dropped without committing.
914/// The underlying command buffer is abandoned (never committed to the GPU).
915/// Tracks buffer address ranges for automatic barrier elision.
916///
917/// Mirrors llama.cpp's `ggml_mem_ranges` — accumulates the read and write
918/// ranges of all dispatches in the current concurrent group. When a new
919/// dispatch's reads overlap with an existing write (RAW), or its writes
920/// overlap with an existing read or write (WAR/WAW), a barrier is needed.
921/// Otherwise the dispatch can run concurrently and the barrier is elided.
922///
923/// Uses CPU-visible `contents_ptr()` addresses, which on Apple Silicon
924/// unified memory equal the GPU addresses.
925pub struct ConflictTracker {
926 /// (start, end, is_write) tuples for the current concurrent group.
927 ranges: Vec<(usize, usize, bool)>,
928}
929
930impl ConflictTracker {
931 fn new() -> Self {
932 Self {
933 ranges: Vec::with_capacity(32),
934 }
935 }
936
937 /// Reset the tracker — called after emitting a barrier.
938 fn reset(&mut self) {
939 self.ranges.clear();
940 }
941
942 /// Check if a new dispatch with the given reads and writes conflicts
943 /// with the current concurrent group.
944 ///
945 /// Conflict rules (same as llama.cpp `ggml_mem_ranges_check`):
946 /// - Two SRC (read) ranges in the same buffer: OK (read-read)
947 /// - A new SRC overlapping an existing DST: CONFLICT (RAW)
948 /// - A new DST overlapping an existing SRC or DST: CONFLICT (WAR/WAW)
949 /// Check for conflicts and return the reason if one is found.
950 /// Returns (conflict_type, new_buf_ptr, existing_buf_ptr) or None.
951 fn conflicts_reason(&self, reads: &[&MlxBuffer], writes: &[&MlxBuffer])
952 -> Option<(&'static str, usize, usize)>
953 {
954 // Check new reads against existing writes (RAW)
955 for r in reads {
956 let r_start = r.contents_ptr() as usize;
957 let r_end = r_start + r.byte_len();
958 for &(s, e, is_write) in &self.ranges {
959 if is_write && r_start < e && r_end > s {
960 return Some(("RAW", r_start, s));
961 }
962 }
963 }
964 // Check new writes against existing reads and writes (WAR/WAW)
965 for w in writes {
966 let w_start = w.contents_ptr() as usize;
967 let w_end = w_start + w.byte_len();
968 for &(s, e, is_write) in &self.ranges {
969 if w_start < e && w_end > s {
970 let kind = if is_write { "WAW" } else { "WAR" };
971 return Some((kind, w_start, s));
972 }
973 }
974 }
975 None
976 }
977
978 /// Add read and write ranges to the current concurrent group.
979 fn add(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
980 for r in reads {
981 let start = r.contents_ptr() as usize;
982 let end = start + r.byte_len();
983 self.ranges.push((start, end, false));
984 }
985 for w in writes {
986 let start = w.contents_ptr() as usize;
987 let end = start + w.byte_len();
988 self.ranges.push((start, end, true));
989 }
990 }
991}
992
pub struct GraphSession<'a> {
    /// The single shared encoder every op in this session dispatches into.
    encoder: CommandEncoder,
    /// Device handle; used by the optimized finish paths to create fresh
    /// encoders for dual-buffer replay.
    device: &'a MlxDevice,
    /// Number of memory barriers emitted so far in this session.
    barrier_count: u32,
    /// Read/write range tracker used to elide unnecessary barriers.
    tracker: ConflictTracker,
    /// Dispatches encoded since the last barrier (current concurrent group).
    dispatch_in_group: u32,
    /// Total dispatches registered via `barrier_between` in this session.
    total_dispatches: u32,
    /// Histogram: group_sizes[i] = number of concurrent groups with (i+1) dispatches
    /// (the last bucket also absorbs any group larger than the histogram).
    group_sizes: [u32; 8],
    /// Whether this session was created in capture/record mode.
    recording: bool,
}
1005
1006impl<'a> GraphSession<'a> {
1007 /// Encode an RMS normalization into this session's encoder.
1008 ///
1009 /// Delegates to [`ops::rms_norm::dispatch_rms_norm`].
1010 pub fn rms_norm(
1011 &mut self,
1012 registry: &mut KernelRegistry,
1013 device: &metal::DeviceRef,
1014 input: &MlxBuffer,
1015 weight: &MlxBuffer,
1016 output: &MlxBuffer,
1017 params_buf: &MlxBuffer,
1018 rows: u32,
1019 dim: u32,
1020 ) -> Result<()> {
1021 ops::rms_norm::dispatch_rms_norm(
1022 &mut self.encoder,
1023 registry,
1024 device,
1025 input,
1026 weight,
1027 output,
1028 params_buf,
1029 rows,
1030 dim,
1031 )
1032 }
1033
1034 /// Encode a quantized matrix multiplication into this session's encoder.
1035 ///
1036 /// Delegates to [`ops::quantized_matmul::quantized_matmul`].
1037 /// Returns the freshly allocated output buffer.
1038 pub fn quantized_matmul(
1039 &mut self,
1040 registry: &mut KernelRegistry,
1041 device: &MlxDevice,
1042 input: &MlxBuffer,
1043 weight: &MlxBuffer,
1044 scales: &MlxBuffer,
1045 biases: &MlxBuffer,
1046 params: &ops::quantized_matmul::QuantizedMatmulParams,
1047 ) -> Result<MlxBuffer> {
1048 ops::quantized_matmul::quantized_matmul(
1049 &mut self.encoder,
1050 registry,
1051 device,
1052 input,
1053 weight,
1054 scales,
1055 biases,
1056 params,
1057 )
1058 }
1059
1060 /// Encode a SIMD-optimized quantized matmul into this session's encoder.
1061 ///
1062 /// Delegates to [`ops::quantized_matmul::quantized_matmul_simd`].
1063 /// Returns the freshly allocated output buffer.
1064 pub fn quantized_matmul_simd(
1065 &mut self,
1066 registry: &mut KernelRegistry,
1067 device: &MlxDevice,
1068 input: &MlxBuffer,
1069 weight: &MlxBuffer,
1070 scales: &MlxBuffer,
1071 biases: &MlxBuffer,
1072 params: &ops::quantized_matmul::QuantizedMatmulParams,
1073 ) -> Result<MlxBuffer> {
1074 ops::quantized_matmul::quantized_matmul_simd(
1075 &mut self.encoder,
1076 registry,
1077 device,
1078 input,
1079 weight,
1080 scales,
1081 biases,
1082 params,
1083 )
1084 }
1085
1086 /// Encode a GGML block-format quantized mat-vec into this session's encoder.
1087 ///
1088 /// Delegates to [`ops::quantized_matmul_ggml::quantized_matmul_ggml`].
1089 pub fn quantized_matmul_ggml(
1090 &mut self,
1091 registry: &mut KernelRegistry,
1092 device: &MlxDevice,
1093 input: &MlxBuffer,
1094 weight: &MlxBuffer,
1095 output: &mut MlxBuffer,
1096 params: &ops::quantized_matmul_ggml::GgmlQuantizedMatmulParams,
1097 ) -> Result<()> {
1098 ops::quantized_matmul_ggml::quantized_matmul_ggml(
1099 &mut self.encoder,
1100 registry,
1101 device,
1102 input,
1103 weight,
1104 output,
1105 params,
1106 )
1107 }
1108
1109 /// Encode an expert-routed GGML block-format quantized mat-vec into this session's encoder.
1110 ///
1111 /// Delegates to [`ops::quantized_matmul_id_ggml::quantized_matmul_id_ggml`].
1112 #[allow(clippy::too_many_arguments)]
1113 pub fn quantized_matmul_id_ggml(
1114 &mut self,
1115 registry: &mut KernelRegistry,
1116 device: &MlxDevice,
1117 input: &MlxBuffer,
1118 weight: &MlxBuffer,
1119 ids: &MlxBuffer,
1120 output: &mut MlxBuffer,
1121 params: &ops::quantized_matmul_id_ggml::GgmlQuantizedMatmulIdParams,
1122 ) -> Result<()> {
1123 ops::quantized_matmul_id_ggml::quantized_matmul_id_ggml(
1124 &mut self.encoder,
1125 registry,
1126 device,
1127 input,
1128 weight,
1129 ids,
1130 output,
1131 params,
1132 )
1133 }
1134
1135 /// Encode scaled dot-product attention into this session's encoder.
1136 ///
1137 /// Delegates to [`ops::sdpa::sdpa`].
1138 pub fn sdpa(
1139 &mut self,
1140 registry: &mut KernelRegistry,
1141 device: &MlxDevice,
1142 q: &MlxBuffer,
1143 k: &MlxBuffer,
1144 v: &MlxBuffer,
1145 output: &MlxBuffer,
1146 params: &ops::sdpa::SdpaParams,
1147 batch_size: u32,
1148 ) -> Result<()> {
1149 ops::sdpa::sdpa(
1150 &mut self.encoder,
1151 registry,
1152 device,
1153 q,
1154 k,
1155 v,
1156 output,
1157 params,
1158 batch_size,
1159 )
1160 }
1161
1162 /// Encode flash attention vector (SIMD-vectorized decode-path SDPA).
1163 ///
1164 /// Delegates to [`ops::flash_attn_vec::flash_attn_vec`].
1165 pub fn flash_attn_vec(
1166 &mut self,
1167 registry: &mut KernelRegistry,
1168 device: &MlxDevice,
1169 q: &MlxBuffer,
1170 k: &MlxBuffer,
1171 v: &MlxBuffer,
1172 output: &MlxBuffer,
1173 tmp: &MlxBuffer,
1174 params: &ops::flash_attn_vec::FlashAttnVecParams,
1175 ) -> Result<()> {
1176 ops::flash_attn_vec::flash_attn_vec(
1177 &mut self.encoder,
1178 registry,
1179 device,
1180 q,
1181 k,
1182 v,
1183 output,
1184 tmp,
1185 params,
1186 )
1187 }
1188
1189 /// Encode an elementwise add into this session's encoder.
1190 ///
1191 /// Delegates to [`ops::elementwise::elementwise_add`].
1192 pub fn elementwise_add(
1193 &mut self,
1194 registry: &mut KernelRegistry,
1195 device: &metal::DeviceRef,
1196 a: &MlxBuffer,
1197 b: &MlxBuffer,
1198 output: &MlxBuffer,
1199 n_elements: usize,
1200 dtype: DType,
1201 ) -> Result<()> {
1202 ops::elementwise::elementwise_add(
1203 &mut self.encoder,
1204 registry,
1205 device,
1206 a,
1207 b,
1208 output,
1209 n_elements,
1210 dtype,
1211 )
1212 }
1213
1214 /// Encode an elementwise multiply into this session's encoder.
1215 ///
1216 /// Delegates to [`ops::elementwise::elementwise_mul`].
1217 pub fn elementwise_mul(
1218 &mut self,
1219 registry: &mut KernelRegistry,
1220 device: &metal::DeviceRef,
1221 a: &MlxBuffer,
1222 b: &MlxBuffer,
1223 output: &MlxBuffer,
1224 n_elements: usize,
1225 dtype: DType,
1226 ) -> Result<()> {
1227 ops::elementwise::elementwise_mul(
1228 &mut self.encoder,
1229 registry,
1230 device,
1231 a,
1232 b,
1233 output,
1234 n_elements,
1235 dtype,
1236 )
1237 }
1238
1239 /// Encode a RoPE transform into this session's encoder.
1240 ///
1241 /// Delegates to [`ops::rope::dispatch_rope`].
1242 pub fn rope(
1243 &mut self,
1244 registry: &mut KernelRegistry,
1245 device: &metal::DeviceRef,
1246 input: &MlxBuffer,
1247 output: &MlxBuffer,
1248 params_buf: &MlxBuffer,
1249 positions_buf: &MlxBuffer,
1250 seq_len: u32,
1251 head_dim: u32,
1252 ) -> Result<()> {
1253 ops::rope::dispatch_rope(
1254 &mut self.encoder,
1255 registry,
1256 device,
1257 input,
1258 output,
1259 params_buf,
1260 positions_buf,
1261 seq_len,
1262 head_dim,
1263 )
1264 }
1265
1266 /// Encode a GELU activation into this session's encoder.
1267 ///
1268 /// Delegates to [`ops::gelu::dispatch_gelu`].
1269 pub fn gelu(
1270 &mut self,
1271 registry: &mut KernelRegistry,
1272 device: &metal::DeviceRef,
1273 input: &MlxBuffer,
1274 output: &MlxBuffer,
1275 ) -> Result<()> {
1276 ops::gelu::dispatch_gelu(
1277 &mut self.encoder,
1278 registry,
1279 device,
1280 input,
1281 output,
1282 )
1283 }
1284
1285 /// Encode a softmax into this session's encoder.
1286 ///
1287 /// Delegates to [`ops::softmax::dispatch_softmax`].
1288 pub fn softmax(
1289 &mut self,
1290 registry: &mut KernelRegistry,
1291 device: &metal::DeviceRef,
1292 input: &MlxBuffer,
1293 output: &MlxBuffer,
1294 params_buf: &MlxBuffer,
1295 rows: u32,
1296 cols: u32,
1297 ) -> Result<()> {
1298 ops::softmax::dispatch_softmax(
1299 &mut self.encoder,
1300 registry,
1301 device,
1302 input,
1303 output,
1304 params_buf,
1305 rows,
1306 cols,
1307 )
1308 }
1309
1310 /// Encode a softcap into this session's encoder.
1311 ///
1312 /// Delegates to [`ops::softcap::dispatch_softcap`].
1313 pub fn softcap(
1314 &mut self,
1315 registry: &mut KernelRegistry,
1316 device: &metal::DeviceRef,
1317 input: &MlxBuffer,
1318 output: &MlxBuffer,
1319 params_buf: &MlxBuffer,
1320 cap: f32,
1321 ) -> Result<()> {
1322 ops::softcap::dispatch_softcap(
1323 &mut self.encoder,
1324 registry,
1325 device,
1326 input,
1327 output,
1328 params_buf,
1329 cap,
1330 )
1331 }
1332
1333 /// Encode an RMS norm without learned scale (f32) into this session's encoder.
1334 ///
1335 /// Delegates to [`ops::rms_norm::dispatch_rms_norm_no_scale_f32`].
1336 pub fn rms_norm_no_scale_f32(
1337 &mut self,
1338 registry: &mut KernelRegistry,
1339 device: &metal::DeviceRef,
1340 input: &MlxBuffer,
1341 output: &MlxBuffer,
1342 params_buf: &MlxBuffer,
1343 rows: u32,
1344 dim: u32,
1345 ) -> Result<()> {
1346 ops::rms_norm::dispatch_rms_norm_no_scale_f32(
1347 &mut self.encoder,
1348 registry,
1349 device,
1350 input,
1351 output,
1352 params_buf,
1353 rows,
1354 dim,
1355 )
1356 }
1357
1358 /// Encode a NeoX RoPE (f32) with optional freq_factors into this session's encoder.
1359 ///
1360 /// Delegates to [`ops::rope::dispatch_rope_neox_f32`].
1361 #[allow(clippy::too_many_arguments)]
1362 pub fn rope_neox_f32(
1363 &mut self,
1364 registry: &mut KernelRegistry,
1365 device: &metal::DeviceRef,
1366 input: &MlxBuffer,
1367 output: &MlxBuffer,
1368 params_buf: &MlxBuffer,
1369 positions_buf: &MlxBuffer,
1370 freq_factors: Option<&MlxBuffer>,
1371 seq_len: u32,
1372 n_heads: u32,
1373 head_dim: u32,
1374 rope_dim: u32,
1375 ) -> Result<()> {
1376 ops::rope::dispatch_rope_neox_f32(
1377 &mut self.encoder,
1378 registry,
1379 device,
1380 input,
1381 output,
1382 params_buf,
1383 positions_buf,
1384 freq_factors,
1385 seq_len,
1386 n_heads,
1387 head_dim,
1388 rope_dim,
1389 )
1390 }
1391
1392 /// Insert a GPU memory barrier (MTLBarrierScopeBuffers).
1393 ///
1394 /// Unconditional barrier — always emits. Use `barrier_between` for
1395 /// automatic conflict detection that can elide unnecessary barriers.
1396 #[inline]
1397 pub fn barrier(&mut self) {
1398 // Record the outgoing group size
1399 if self.dispatch_in_group > 0 {
1400 let idx = (self.dispatch_in_group as usize).min(self.group_sizes.len()) - 1;
1401 self.group_sizes[idx] += 1;
1402 }
1403 self.encoder.memory_barrier();
1404 self.tracker.reset();
1405 self.barrier_count += 1;
1406 self.dispatch_in_group = 0;
1407 }
1408
1409 /// Smart barrier with conflict detection.
1410 ///
1411 /// Checks if the next dispatch (with the given read and write buffers)
1412 /// actually conflicts with any dispatch in the current concurrent group.
1413 /// If yes, emits a Metal barrier and resets the tracker. If no, the
1414 /// barrier is elided and the dispatch can run concurrently.
1415 ///
1416 /// This mirrors llama.cpp's `ggml_metal_op_concurrency_check` +
1417 /// `ggml_metal_op_concurrency_reset` pattern.
1418 #[inline]
1419 pub fn barrier_between(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
1420 // In capture mode, stash the read/write ranges so the next captured
1421 // dispatch node carries them for the reorder pass (Phase 4e.3).
1422 if self.recording {
1423 let read_ranges: Vec<MemRange> = reads
1424 .iter()
1425 .map(|b| {
1426 let start = b.contents_ptr() as usize;
1427 (start, start + b.byte_len())
1428 })
1429 .collect();
1430 let write_ranges: Vec<MemRange> = writes
1431 .iter()
1432 .map(|b| {
1433 let start = b.contents_ptr() as usize;
1434 (start, start + b.byte_len())
1435 })
1436 .collect();
1437 self.encoder.set_pending_buffer_ranges(read_ranges, write_ranges);
1438 }
1439
1440 let reason = self.tracker.conflicts_reason(reads, writes);
1441 if let Some((_kind, _new_ptr, _existing_ptr)) = reason {
1442 // Record the outgoing group size before resetting
1443 if self.dispatch_in_group > 0 {
1444 let idx = (self.dispatch_in_group as usize).min(self.group_sizes.len()) - 1;
1445 self.group_sizes[idx] += 1;
1446 }
1447 self.encoder.memory_barrier();
1448 self.tracker.reset();
1449 self.barrier_count += 1;
1450 self.dispatch_in_group = 0;
1451 }
1452 self.dispatch_in_group += 1;
1453 self.total_dispatches += 1;
1454 self.tracker.add(reads, writes);
1455 }
1456
1457 /// Print group size histogram to stderr (for HF2Q_MLX_TIMING debug).
1458 pub fn dump_group_stats(&self) {
1459 // Record the final (unterminated) group
1460 let mut gs = self.group_sizes;
1461 if self.dispatch_in_group > 0 {
1462 let idx = (self.dispatch_in_group as usize).min(gs.len()) - 1;
1463 gs[idx] += 1;
1464 }
1465 let total_groups: u32 = gs.iter().sum();
1466 eprintln!(" [GROUP_STATS] dispatches={} barriers={} groups={} ratio={:.2}",
1467 self.total_dispatches, self.barrier_count, total_groups,
1468 if total_groups > 0 { self.total_dispatches as f64 / total_groups as f64 } else { 0.0 });
1469 for (i, &count) in gs.iter().enumerate() {
1470 if count > 0 {
1471 eprintln!(" size {}: {} groups", i + 1, count);
1472 }
1473 }
1474 }
1475
1476 /// Register a dispatch's buffer ranges without checking for conflicts.
1477 ///
1478 /// Use after dispatching an op that doesn't need a barrier check (e.g.,
1479 /// the first dispatch in a session, or dispatches known to be concurrent).
1480 ///
1481 /// In recording mode, also retroactively annotates the most recently
1482 /// captured dispatch node with these ranges if it was missing them.
1483 /// That keeps the reorder pass able to reason about dispatches that
1484 /// were preceded by `track_dispatch` rather than `barrier_between`.
1485 #[inline]
1486 pub fn track_dispatch(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
1487 if self.recording {
1488 let read_ranges: Vec<MemRange> = reads
1489 .iter()
1490 .map(|b| {
1491 let start = b.contents_ptr() as usize;
1492 (start, start + b.byte_len())
1493 })
1494 .collect();
1495 let write_ranges: Vec<MemRange> = writes
1496 .iter()
1497 .map(|b| {
1498 let start = b.contents_ptr() as usize;
1499 (start, start + b.byte_len())
1500 })
1501 .collect();
1502 self.encoder
1503 .annotate_last_dispatch_if_missing(read_ranges, write_ranges);
1504 }
1505 self.tracker.add(reads, writes);
1506 }
1507
    /// Return the number of barriers inserted so far in this session.
    ///
    /// Counts both unconditional `barrier()` calls and conflict-triggered
    /// barriers from `barrier_between`.
    #[inline]
    pub fn barrier_count(&self) -> u32 {
        self.barrier_count
    }
1513
    /// Cumulative nanoseconds spent in ConflictTracker checks (diagnostic).
    /// Returns 0 when timing is not compiled in.
    ///
    /// NOTE: this build hard-codes 0 — the timing instrumentation is not
    /// compiled in; the method exists to keep the diagnostic API stable.
    pub fn tracker_overhead_ns(&self) -> u64 {
        0
    }
1519
    /// Borrow the underlying command encoder for direct op dispatch.
    ///
    /// Use this when you need to call an op function that is not wrapped by
    /// a `GraphSession` method. The returned encoder is the same shared
    /// encoder — all dispatches still go into the same command buffer.
    ///
    /// Note: ops dispatched this way bypass the session's conflict tracker;
    /// pair them with `track_dispatch`/`barrier_between` as appropriate.
    pub fn encoder_mut(&mut self) -> &mut CommandEncoder {
        &mut self.encoder
    }
1528
    /// Borrow the device reference this session was created with.
    pub fn device(&self) -> &MlxDevice {
        self.device
    }
1533
    /// Whether this session is in capture/record mode (ops are being
    /// intercepted into a `ComputeGraph` instead of dispatched directly).
    pub fn is_recording(&self) -> bool {
        self.recording
    }
1538
1539 /// Commit the command buffer and wait for GPU completion.
1540 ///
1541 /// This is the ONLY sync point per forward pass. After this call, all
1542 /// output buffers are readable by the CPU.
1543 ///
1544 /// In recording mode: extracts the captured graph, replays it into
1545 /// the encoder via `ComputeGraph::encode_sequential()`, then commits
1546 /// and waits. The result is identical to the direct-dispatch path.
1547 ///
1548 /// Consumes the session — no further ops can be encoded.
1549 pub fn finish(mut self) -> Result<()> {
1550 if self.recording {
1551 if let Some(nodes) = self.encoder.take_capture() {
1552 let graph = ComputeGraph::from_nodes(nodes);
1553 graph.encode_sequential(&mut self.encoder);
1554 }
1555 }
1556 self.encoder.commit_and_wait()
1557 }
1558
1559 /// Commit the command buffer WITHOUT waiting.
1560 ///
1561 /// The GPU begins executing immediately. Use this for fire-and-forget
1562 /// dispatch when you do not need results until later.
1563 ///
1564 /// In recording mode: replays the captured graph before committing.
1565 ///
1566 /// Consumes the session.
1567 pub fn commit(mut self) -> CommandEncoder {
1568 if self.recording {
1569 if let Some(nodes) = self.encoder.take_capture() {
1570 let graph = ComputeGraph::from_nodes(nodes);
1571 graph.encode_sequential(&mut self.encoder);
1572 }
1573 }
1574 self.encoder.commit();
1575 self.encoder
1576 }
1577
1578 /// Commit the command buffer and wait, returning split timing.
1579 ///
1580 /// Returns `(encoding_ns, gpu_wait_ns)` where:
1581 /// - `encoding_ns` is the time from session begin to commit (CPU encoding)
1582 /// - `gpu_wait_ns` is the time from commit to GPU completion
1583 ///
1584 /// The `session_begin` instant should be captured right after `exec.begin()`.
1585 ///
1586 /// In recording mode: replays the captured graph before committing.
1587 ///
1588 /// Consumes the session.
1589 pub fn finish_with_timing(mut self, session_begin: std::time::Instant) -> Result<(u64, u64)> {
1590 if self.recording {
1591 if let Some(nodes) = self.encoder.take_capture() {
1592 let graph = ComputeGraph::from_nodes(nodes);
1593 graph.encode_sequential(&mut self.encoder);
1594 }
1595 }
1596 let commit_start = std::time::Instant::now();
1597 let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
1598 self.encoder.commit();
1599 self.encoder.wait_until_completed()?;
1600 let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
1601 Ok((encoding_ns, gpu_wait_ns))
1602 }
1603
1604 /// Finish with fusion: run the RMS norm + MUL fusion pass before
1605 /// replaying the graph.
1606 ///
1607 /// Only meaningful in recording mode. In direct-dispatch mode, this
1608 /// behaves identically to `finish()`.
1609 ///
1610 /// Returns `(fusions_applied,)` on success.
1611 pub fn finish_with_fusion(
1612 mut self,
1613 registry: &mut KernelRegistry,
1614 device: &metal::DeviceRef,
1615 ) -> Result<u32> {
1616 let mut fusions = 0;
1617 if self.recording {
1618 if let Some(nodes) = self.encoder.take_capture() {
1619 let mut graph = ComputeGraph::from_nodes(nodes);
1620 fusions = graph.fuse(registry, device)?;
1621 graph.encode_sequential(&mut self.encoder);
1622 }
1623 }
1624 self.encoder.commit_and_wait()?;
1625 Ok(fusions)
1626 }
1627
1628 /// Finish with fusion and split timing.
1629 ///
1630 /// Like `finish_with_timing` but runs the fusion pass first.
1631 /// Returns `(encoding_ns, gpu_wait_ns, fusions_applied)`.
1632 pub fn finish_with_fusion_and_timing(
1633 mut self,
1634 registry: &mut KernelRegistry,
1635 device: &metal::DeviceRef,
1636 session_begin: std::time::Instant,
1637 ) -> Result<(u64, u64, u32)> {
1638 let mut fusions = 0;
1639 if self.recording {
1640 if let Some(nodes) = self.encoder.take_capture() {
1641 let mut graph = ComputeGraph::from_nodes(nodes);
1642 fusions = graph.fuse(registry, device)?;
1643 graph.encode_sequential(&mut self.encoder);
1644 }
1645 }
1646 let commit_start = std::time::Instant::now();
1647 let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
1648 self.encoder.commit();
1649 self.encoder.wait_until_completed()?;
1650 let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
1651 Ok((encoding_ns, gpu_wait_ns, fusions))
1652 }
1653
1654 /// Finish with fusion AND reorder: run both graph optimization passes
1655 /// before replaying the graph.
1656 ///
1657 /// Only meaningful in recording mode. In direct-dispatch mode, this
1658 /// behaves identically to `finish()`.
1659 ///
1660 /// Returns `(fusions_applied, nodes_reordered)` on success.
1661 pub fn finish_with_fusion_and_reorder(
1662 mut self,
1663 registry: &mut KernelRegistry,
1664 device: &metal::DeviceRef,
1665 ) -> Result<(u32, u32)> {
1666 let mut fusions = 0;
1667 let mut reordered = 0;
1668 if self.recording {
1669 if let Some(nodes) = self.encoder.take_capture() {
1670 let mut graph = ComputeGraph::from_nodes(nodes);
1671 fusions = graph.fuse(registry, device)?;
1672 reordered = graph.reorder();
1673 graph.encode_with_barriers(&mut self.encoder);
1674 }
1675 }
1676 self.encoder.commit_and_wait()?;
1677 Ok((fusions, reordered))
1678 }
1679
1680 /// Finish with fusion, reorder, and split timing.
1681 ///
1682 /// Like `finish_with_fusion_and_timing` but also runs the reorder pass.
1683 /// Returns `(encoding_ns, gpu_wait_ns, fusions_applied, nodes_reordered)`.
1684 pub fn finish_with_fusion_reorder_and_timing(
1685 mut self,
1686 registry: &mut KernelRegistry,
1687 device: &metal::DeviceRef,
1688 session_begin: std::time::Instant,
1689 ) -> Result<(u64, u64, u32, u32)> {
1690 let mut fusions = 0;
1691 let mut reordered = 0;
1692 if self.recording {
1693 if let Some(nodes) = self.encoder.take_capture() {
1694 let mut graph = ComputeGraph::from_nodes(nodes);
1695 fusions = graph.fuse(registry, device)?;
1696 reordered = graph.reorder();
1697 graph.encode_with_barriers(&mut self.encoder);
1698 }
1699 }
1700 let commit_start = std::time::Instant::now();
1701 let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
1702 self.encoder.commit();
1703 self.encoder.wait_until_completed()?;
1704 let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
1705 Ok((encoding_ns, gpu_wait_ns, fusions, reordered))
1706 }
1707
1708 /// Finish with the full optimization pipeline: fuse, reorder, dual-buffer
1709 /// encode.
1710 ///
1711 /// Runs the fusion pass, reorder pass, then encodes the graph into two
1712 /// Metal command buffers for CPU/GPU overlap. The first ~10% of dispatches
1713 /// are committed immediately so the GPU can start executing while the CPU
1714 /// encodes the remaining ~90%.
1715 ///
1716 /// Only meaningful in recording mode. In direct-dispatch mode, this
1717 /// behaves identically to `finish()`.
1718 ///
1719 /// Returns `(fusions_applied, nodes_reordered, barriers_buf0, barriers_buf1)`.
1720 pub fn finish_optimized(
1721 mut self,
1722 registry: &mut KernelRegistry,
1723 device: &metal::DeviceRef,
1724 ) -> Result<(u32, u32, u32, u32)> {
1725 let mut fusions = 0;
1726 let mut reordered = 0;
1727 let mut barriers0 = 0u32;
1728 let mut barriers1 = 0u32;
1729
1730 if self.recording {
1731 if let Some(nodes) = self.encoder.take_capture() {
1732 // Commit the capture encoder's empty command buffer so its
1733 // MTLCommandQueue pool slot is freed (same fix as timing variant).
1734 self.encoder.commit();
1735
1736 let mut graph = ComputeGraph::from_nodes(nodes);
1737 fusions = graph.fuse(registry, device)?;
1738 reordered = graph.reorder();
1739
1740 let mut enc0 = self.device.command_encoder()?;
1741 let mut enc1 = self.device.command_encoder()?;
1742
1743 let (b0, b1) = graph.encode_dual_buffer(&mut enc0, &mut enc1);
1744 barriers0 = b0;
1745 barriers1 = b1;
1746
1747 // enc0 was already committed inside encode_dual_buffer.
1748 // Commit enc1 and wait — Metal queue ordering guarantees enc0
1749 // finishes before enc1 starts executing.
1750 enc1.commit_and_wait()?;
1751
1752 // The original encoder was never committed (capture mode drained
1753 // it). We need to end it cleanly — dropping it will end the
1754 // active encoder if any, and the uncommitted command buffer is
1755 // abandoned. That is safe: Metal silently drops uncommitted
1756 // command buffers.
1757 return Ok((fusions, reordered, barriers0, barriers1));
1758 }
1759 }
1760
1761 // Direct-dispatch fallback: just commit the original encoder.
1762 self.encoder.commit_and_wait()?;
1763 Ok((fusions, reordered, barriers0, barriers1))
1764 }
1765
    /// Finish with the full optimization pipeline and split timing.
    ///
    /// Like `finish_optimized` but returns timing information.
    /// Returns `(encoding_ns, gpu_wait_ns, fusions, reordered, barriers_buf0, barriers_buf1)`.
    ///
    /// Timing breakdown:
    /// - `encoding_ns`: CPU time from session begin to first buffer commit
    ///   (fusion + reorder + encode chunk 0)
    /// - `gpu_wait_ns`: wall time from second buffer commit to GPU completion
    ///   (includes GPU execution of both buffers, overlapped with chunk 1 encoding)
    pub fn finish_optimized_with_timing(
        mut self,
        registry: &mut KernelRegistry,
        device: &metal::DeviceRef,
        session_begin: std::time::Instant,
    ) -> Result<(u64, u64, u32, u32, u32, u32)> {
        let mut fusions = 0;
        let mut reordered = 0;
        let mut barriers0 = 0u32;
        let mut barriers1 = 0u32;

        if self.recording {
            if let Some(nodes) = self.encoder.take_capture() {
                // Commit the capture encoder's empty command buffer so its
                // MTLCommandQueue pool slot is freed. Without this, each
                // token leaks one uncommitted buffer and the queue exhausts
                // its ~64-slot pool after ~64 tokens, causing a deadlock.
                self.encoder.commit();

                // Optimization passes, each individually timed for [DIAG].
                let opt_t0 = std::time::Instant::now();
                let mut graph = ComputeGraph::from_nodes(nodes);
                let fuse_t0 = std::time::Instant::now();
                fusions = graph.fuse(registry, device)?;
                let fuse_us = fuse_t0.elapsed().as_micros();

                // Reorder only when every dispatch carries range annotations;
                // the pass cannot reason about unannotated nodes.
                let reorder_t0 = std::time::Instant::now();
                let unannotated = graph.unannotated_dispatch_count();
                if unannotated == 0 {
                    reordered = graph.reorder();
                } else if std::env::var("HF2Q_MLX_TIMING").is_ok() {
                    eprintln!("  [GRAPH_OPT] WARN: skipping reorder — {} of {} dispatches lack range annotations",
                        unannotated, graph.dispatch_count());
                }
                let reorder_us = reorder_t0.elapsed().as_micros();
                let opt_us = opt_t0.elapsed().as_micros();

                let diag = std::env::var("HF2Q_GRAPH_DIAG").is_ok();
                let t0 = std::time::Instant::now();
                // Two encoders: enc0 takes the head chunk (committed inside
                // encode_dual_buffer), enc1 carries the tail for overlap.
                let mut enc0 = self.device.command_encoder()?;
                let mut enc1 = self.device.command_encoder()?;
                let enc_create_us = t0.elapsed().as_micros();

                let t1 = std::time::Instant::now();
                let (b0, b1) = graph.encode_dual_buffer(&mut enc0, &mut enc1);
                barriers0 = b0;
                barriers1 = b1;
                let encode_us = t1.elapsed().as_micros();

                let encoding_ns = session_begin.elapsed().as_nanos() as u64;

                // Waiting on enc1 suffices: queue ordering runs enc0 first.
                let wait_start = std::time::Instant::now();
                enc1.commit_and_wait()?;
                let gpu_wait_ns = wait_start.elapsed().as_nanos() as u64;

                if diag {
                    eprintln!("  [DIAG] fuse={:.1}ms reorder={:.1}ms opt_total={:.1}ms enc_create={:.1}ms encode={:.1}ms gpu_wait={:.1}ms barriers={}+{}",
                        fuse_us as f64 / 1e3, reorder_us as f64 / 1e3, opt_us as f64 / 1e3,
                        enc_create_us as f64 / 1e3, encode_us as f64 / 1e3,
                        gpu_wait_ns as f64 / 1e6, b0, b1);
                }

                return Ok((encoding_ns, gpu_wait_ns, fusions, reordered, barriers0, barriers1));
            }
        }

        // Direct-dispatch fallback.
        let commit_start = std::time::Instant::now();
        let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
        self.encoder.commit();
        self.encoder.wait_until_completed()?;
        let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
        Ok((encoding_ns, gpu_wait_ns, fusions, reordered, barriers0, barriers1))
    }
1849}