
Struct ComputeGraph 

pub struct ComputeGraph { /* private fields */ }

A recorded sequence of GPU compute dispatches and barriers.

Created by running a forward pass with the encoder in capture mode. Can be replayed into a real CommandEncoder via encode_sequential(), producing identical Metal dispatch behavior to the original direct path.
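The record-then-replay shape can be sketched with stand-in types; this is a hypothetical, simplified model, not the crate's API — the real CapturedNode and CommandEncoder carry Metal pipeline state, while here a dispatch is just a kernel name:

```rust
// Stand-ins for CapturedNode / CommandEncoder (hypothetical).
#[derive(Clone, Debug, PartialEq)]
enum Node {
    Dispatch(&'static str), // kernel name in place of real pipeline state
    Barrier,
}

#[derive(Default)]
struct Graph {
    nodes: Vec<Node>,
}

impl Graph {
    fn record(&mut self, n: Node) {
        self.nodes.push(n);
    }

    // Replay every node in recorded order, mirroring the shape of
    // encode_sequential(): returns the number of barriers emitted.
    fn encode_sequential(&self, log: &mut Vec<String>) -> u32 {
        let mut barriers = 0;
        for n in &self.nodes {
            match n {
                Node::Dispatch(name) => log.push(format!("dispatch {name}")),
                Node::Barrier => {
                    barriers += 1;
                    log.push("barrier".to_string());
                }
            }
        }
        barriers
    }
}

fn main() {
    let mut graph = Graph::default();
    graph.record(Node::Dispatch("rms_norm"));
    graph.record(Node::Barrier);
    graph.record(Node::Dispatch("elem_mul"));

    let mut log = Vec::new();
    let barriers = graph.encode_sequential(&mut log);
    assert_eq!(barriers, 1);
    assert_eq!(log, vec!["dispatch rms_norm", "barrier", "dispatch elem_mul"]);
    println!("{log:?}");
}
```

Because replay walks the node list in recorded order, the captured graph produces the same dispatch/barrier sequence the direct path would have emitted.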

Future phases (4e.2, 4e.3) will add fusion and reorder passes that transform the graph before encoding.

Implementations§

impl ComputeGraph

pub fn new() -> Self

Create an empty compute graph.

pub fn from_nodes(nodes: Vec<CapturedNode>) -> Self

Create a compute graph from a pre-built list of captured nodes.

pub fn record(&mut self, node: CapturedNode)

Record a captured node into the graph.

pub fn len(&self) -> usize

Number of nodes (dispatches + barriers) in the graph.

pub fn is_empty(&self) -> bool

Whether the graph contains no nodes.

pub fn dispatch_count(&self) -> usize

Number of dispatch nodes (excludes barriers).

pub fn barrier_count(&self) -> usize

Number of barrier nodes.

pub fn nodes(&self) -> &[CapturedNode]

Borrow the node list.

pub fn unannotated_dispatch_count(&self) -> usize

Count dispatch nodes that have empty read/write range annotations.

Used for diagnostics: if >0, the reorder pass cannot guarantee correctness because it relies on complete annotations.

pub fn into_nodes(self) -> Vec<CapturedNode>

Take ownership of the node list, consuming the graph.

pub fn encode_sequential(&self, encoder: &mut CommandEncoder) -> u32

Encode all nodes sequentially into the given encoder.

Barrier sentinel nodes emit a Metal memory barrier. Dispatch nodes are replayed through CommandEncoder::replay_dispatch().

This produces identical GPU behavior to the direct-dispatch path — same pipeline bindings, same dispatch dimensions, same barrier placement.

Returns the number of barriers emitted.

pub fn encode_with_barriers(&self, encoder: &mut CommandEncoder) -> u32

Encode the graph into a Metal command buffer, computing barriers on the fly from each node’s read/write buffer ranges.

This is the correct encoding method for reordered graphs where barrier sentinels have been stripped. Mirrors llama.cpp’s encode-time barrier insertion via ggml_metal_op_concurrency_check.

Returns the number of barriers emitted.
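The encode-time barrier logic can be illustrated with a self-contained sketch. This is a hypothetical simplification, not the crate's ConflictTracker: dispatches carry (buffer, start, end) read/write ranges, and a barrier is emitted whenever a dispatch would race with a write pending in the current concurrent group (only writes are tracked here, so WAR hazards against earlier reads are not modeled):

```rust
// (buffer id, byte range) — `end` is exclusive.
#[derive(Clone, Copy)]
struct Range {
    buf: u32,
    start: u64,
    end: u64,
}

fn overlaps(a: Range, b: Range) -> bool {
    a.buf == b.buf && a.start < b.end && b.start < a.end
}

struct Dispatch {
    reads: Vec<Range>,
    writes: Vec<Range>,
}

// Walk the dispatches in order; emit a barrier whenever one would race
// with a write from the current concurrent group, then start a new group.
fn encode_with_barriers(dispatches: &[Dispatch]) -> u32 {
    let mut barriers = 0;
    let mut pending_writes: Vec<Range> = Vec::new();
    for d in dispatches {
        let hazard = d
            .reads
            .iter()
            .chain(d.writes.iter())
            .any(|&r| pending_writes.iter().any(|&w| overlaps(r, w)));
        if hazard {
            barriers += 1; // make the previous group's writes visible
            pending_writes.clear();
        }
        pending_writes.extend(d.writes.iter().copied());
    }
    barriers
}

fn main() {
    let r = |buf, start, end| Range { buf, start, end };
    let graph = [
        // A writes buf 0.
        Dispatch { reads: vec![], writes: vec![r(0, 0, 64)] },
        // B touches only bufs 1 and 2: concurrent with A, no barrier.
        Dispatch { reads: vec![r(1, 0, 64)], writes: vec![r(2, 0, 64)] },
        // C reads A's output: one barrier before C.
        Dispatch { reads: vec![r(0, 0, 32)], writes: vec![r(3, 0, 64)] },
    ];
    assert_eq!(encode_with_barriers(&graph), 1);
    println!("ok");
}
```

This is why stripped barrier sentinels are safe for reordered graphs: the barriers are derived from the annotations rather than replayed from the recording.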

pub fn encode_dual_buffer( &self, encoder0: &mut CommandEncoder, encoder1: &mut CommandEncoder, ) -> (u32, u32)

Encode the graph using two command buffers for CPU/GPU overlap.

The first n0 dispatches are encoded into encoder0 and committed immediately, so the GPU starts executing them right away. The remaining dispatches are encoded into encoder1; the caller is responsible for committing encoder1.

This matches llama.cpp’s dual command-buffer pattern from ggml_metal_graph_compute (ggml-metal-context.m:441-644), where the first buffer receives n0 = n_nodes_0 = MAX(64, 0.1 * n_nodes) nodes.

Command buffers submitted to the same MTLCommandQueue execute in submission order, so encoder0.commit() followed by encoder1.commit() guarantees encoder0’s work finishes before encoder1’s begins. The win: the GPU starts executing encoder0’s buffer while the CPU is still encoding encoder1’s.

Returns (barriers_buf0, barriers_buf1).
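The split heuristic above can be written out directly. Integer arithmetic is an assumption here (llama.cpp computes 0.1 * n_nodes in floating point), and the cap at the total node count handles graphs smaller than the 64-node floor:

```rust
// First command buffer gets max(64, 10% of the nodes), capped at the total.
fn first_buffer_len(n_nodes: usize) -> usize {
    (n_nodes / 10).max(64).min(n_nodes)
}

fn main() {
    assert_eq!(first_buffer_len(1000), 100); // 10% wins on large graphs
    assert_eq!(first_buffer_len(200), 64);   // the floor of 64 wins
    assert_eq!(first_buffer_len(30), 30);    // tiny graphs fit in buffer 0
    println!("ok");
}
```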

pub fn fuse( &mut self, registry: &mut KernelRegistry, device: &DeviceRef, ) -> Result<u32>

Run the RMS norm + MUL fusion pass over the graph.

Scans for the pattern: Dispatch(RmsNorm) → Barrier(s) → Dispatch(ElemMul) where the MUL reads the norm’s output buffer, and replaces the sequence with a single fused rms_norm_mul_* dispatch.

The fused dispatch:

  • Reads the norm’s input (buffer 0) and weight (buffer 1)
  • Reads the MUL’s second operand as the scale (buffer 2)
  • Writes to the MUL’s output (buffer 3)
  • Carries the norm’s params (buffer 4)
  • Uses the norm’s threadgroup config and shared memory

Returns the number of fusions applied.

§Arguments
  • registry - Kernel registry for compiling the fused pipeline.
  • device - Metal device for pipeline compilation.
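The pattern rewrite at the heart of the pass can be sketched on stand-in nodes. This is a hypothetical simplification: the real pass also matches the buffer bindings listed above and compiles the fused rms_norm_mul_* pipeline through the KernelRegistry, while here nodes carry only the one buffer index needed to match the pattern:

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
enum Node {
    RmsNorm { out: u32 },  // writes its result to buffer `out`
    ElemMul { lhs: u32 },  // reads its first operand from buffer `lhs`
    FusedRmsNormMul,
    Barrier,
    Other,
}

// Scan for Dispatch(RmsNorm) -> Barrier(s) -> Dispatch(ElemMul) where the
// mul reads the norm's output, and collapse the window into one fused node.
fn fuse(nodes: &mut Vec<Node>) -> u32 {
    let mut fused = 0;
    let mut i = 0;
    while i < nodes.len() {
        if let Node::RmsNorm { out } = nodes[i] {
            // Skip the barrier(s) between the norm and the candidate mul.
            let mut j = i + 1;
            while matches!(nodes.get(j), Some(Node::Barrier)) {
                j += 1;
            }
            if nodes.get(j).copied() == Some(Node::ElemMul { lhs: out }) {
                // Replace the whole window with a single fused dispatch.
                nodes.splice(i..=j, [Node::FusedRmsNormMul]);
                fused += 1;
            }
        }
        i += 1;
    }
    fused
}

fn main() {
    let mut g = vec![
        Node::RmsNorm { out: 7 },
        Node::Barrier,
        Node::ElemMul { lhs: 7 },
        Node::Other,
    ];
    assert_eq!(fuse(&mut g), 1);
    assert_eq!(g, vec![Node::FusedRmsNormMul, Node::Other]);
    println!("{g:?}");
}
```

Note the intermediate barrier disappears along with the two dispatches: the fused kernel keeps the norm result in registers/threadgroup memory, so no device-memory synchronization is needed between the two stages.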

pub fn reorder(&mut self) -> u32

Run the reorder pass over the graph to improve GPU concurrency.

Port of llama.cpp’s ggml_metal_graph_optimize_reorder — a greedy 64-node lookahead that pulls independent dispatches forward to fill larger concurrent groups between barriers.

Prerequisites: call fuse() first if desired; the reorder pass operates on the post-fusion graph. Barrier sentinel nodes are stripped before reordering (they are recomputed at encode time by the ConflictTracker in encode_with_barriers()).

Algorithm (matching llama.cpp exactly):

  1. Strip all CapturedNode::Barrier nodes.
  2. For each unprocessed node i0:
    • If it conflicts with the current concurrent group (mrs0):
      • Initialize mrs1 (the skipped-over set) from i0’s ranges
      • Look ahead up to 64 nodes for candidates that (a) are reorderable (CapturedOpKind::is_reorderable()), (b) don’t conflict with mrs0 (the current group), and (c) don’t conflict with mrs1 (the skipped-over nodes)
      • Pull qualifying candidates into the current group
      • Stop the lookahead at the first non-reorderable op
      • Reset mrs0 (start a new concurrent group)
    • Add i0 to the current group

Returns the number of nodes that were moved to earlier positions.
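The greedy lookahead can be modeled in a self-contained sketch. This is a hypothetical simplification of the pass, not a port: hazards are tracked at whole-buffer granularity via read/write sets of buffer ids, where the real pass uses byte ranges:

```rust
use std::collections::HashSet;

#[derive(Clone, Debug)]
struct Node {
    id: u32,
    reorderable: bool,
    reads: HashSet<u32>,  // buffer ids (whole-buffer granularity)
    writes: HashSet<u32>,
}

// RAW, WAR and WAW hazard check against a group's combined ranges.
fn conflicts(n: &Node, reads: &HashSet<u32>, writes: &HashSet<u32>) -> bool {
    n.reads.iter().any(|b| writes.contains(b))
        || n.writes.iter().any(|b| writes.contains(b) || reads.contains(b))
}

// Greedy lookahead in the spirit of ggml_metal_graph_optimize_reorder.
fn reorder(mut rest: Vec<Node>) -> (Vec<Node>, u32) {
    let mut out = Vec::new();
    let (mut g_reads, mut g_writes) = (HashSet::new(), HashSet::new()); // mrs0
    let mut moved = 0;
    while !rest.is_empty() {
        let head = rest.remove(0);
        if conflicts(&head, &g_reads, &g_writes) {
            // mrs1: the head's ranges plus everything we skip over.
            let (mut s_reads, mut s_writes) = (head.reads.clone(), head.writes.clone());
            let mut k = 0;
            while k < rest.len().min(64) {
                if !rest[k].reorderable {
                    break; // non-reorderable ops end the lookahead
                }
                let ok = !conflicts(&rest[k], &g_reads, &g_writes)
                    && !conflicts(&rest[k], &s_reads, &s_writes);
                if ok {
                    // Pull the candidate forward into the current group.
                    let cand = rest.remove(k);
                    g_reads.extend(cand.reads.iter().copied());
                    g_writes.extend(cand.writes.iter().copied());
                    out.push(cand);
                    moved += 1;
                } else {
                    s_reads.extend(rest[k].reads.iter().copied());
                    s_writes.extend(rest[k].writes.iter().copied());
                    k += 1;
                }
            }
            // The conflicting head starts a new concurrent group.
            g_reads.clear();
            g_writes.clear();
        }
        g_reads.extend(head.reads.iter().copied());
        g_writes.extend(head.writes.iter().copied());
        out.push(head);
    }
    (out, moved)
}

fn main() {
    let node = |id, reads: &[u32], writes: &[u32]| Node {
        id,
        reorderable: true,
        reads: reads.iter().copied().collect(),
        writes: writes.iter().copied().collect(),
    };
    // B depends on A's output, so C (independent) is pulled ahead of B.
    let (out, moved) = reorder(vec![
        node(0, &[], &[0]),  // A: writes buf 0
        node(1, &[0], &[3]), // B: reads buf 0
        node(2, &[1], &[2]), // C: independent
    ]);
    let order: Vec<u32> = out.iter().map(|n| n.id).collect();
    assert_eq!(order, vec![0, 2, 1]);
    assert_eq!(moved, 1);
    println!("order {order:?}, moved {moved}");
}
```

Pulling C ahead of B means A and C land in the same concurrent group, so encode-time barrier insertion emits one barrier instead of two.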

Trait Implementations§

impl Default for ComputeGraph

fn default() -> Self

Returns the “default value” for a type.
