
Module graph



GraphExecutor — batched Metal dispatch for single-encoder forward passes.

llama.cpp’s speed advantage over candle does NOT come from its kernels (Phase 0 showed that candle’s kernels are as fast or faster per call). It comes from the dispatch pattern: one encoder per command buffer instead of roughly 120, one per op. This module implements that pattern.
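To make the contrast concrete, the toy model below tracks only encoder bookkeeping. `CommandBuffer`, `Encoder`, `per_op_dispatch`, and `batched_dispatch` are hypothetical stand-ins, not this crate’s types or Metal’s API. No GPU work happens; the code just counts how many encoders each pattern opens for the same list of ops.

```rust
/// Hypothetical stand-in for a Metal command buffer: it only counts
/// how many compute encoders were opened on it.
#[derive(Default)]
pub struct CommandBuffer {
    pub encoders_created: usize,
}

/// Hypothetical stand-in for a compute command encoder.
pub struct Encoder {
    pub dispatches: usize,
}

impl CommandBuffer {
    /// Opening an encoder is the per-op overhead the batched pattern avoids.
    pub fn new_encoder(&mut self) -> Encoder {
        self.encoders_created += 1;
        Encoder { dispatches: 0 }
    }
}

impl Encoder {
    /// Records one kernel dispatch on this encoder.
    pub fn dispatch(&mut self, _op: &str) {
        self.dispatches += 1;
    }
}

/// Per-op pattern: every op opens its own encoder.
pub fn per_op_dispatch(ops: &[&str]) -> CommandBuffer {
    let mut cb = CommandBuffer::default();
    for op in ops {
        let mut enc = cb.new_encoder();
        enc.dispatch(op);
        // `enc` is dropped here, standing in for ending the encoder.
    }
    cb
}

/// Batched pattern: one encoder for the whole forward pass.
/// Returns the command buffer and the number of dispatches on that encoder.
pub fn batched_dispatch(ops: &[&str]) -> (CommandBuffer, usize) {
    let mut cb = CommandBuffer::default();
    let mut enc = cb.new_encoder();
    for op in ops {
        enc.dispatch(op);
    }
    let dispatches = enc.dispatches;
    (cb, dispatches)
}
```

For 120 ops, `per_op_dispatch` opens 120 encoders, while `batched_dispatch` opens one and records all 120 dispatches on it.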

§Usage

let mut executor = GraphExecutor::new(device.clone());
let mut session = executor.begin()?;

// All ops encode into the same command buffer — no per-op encoder creation.
session.rms_norm(&mut registry, device.metal_device(), input, weight, output, params, rows, dim)?;
session.quantized_matmul(&mut registry, &device, input, weight, scales, biases, &qparams)?;
session.elementwise_add(&mut registry, device.metal_device(), a, b, out, n, DType::F32)?;

// Single GPU sync point for the entire forward pass.
session.finish()?;

§Design

The GraphSession holds a single CommandEncoder. Each op method delegates to the existing op dispatch function in crate::ops, passing it the session’s shared encoder. No new Metal code is needed, because the ops already accept a shared encoder; the executor only prevents a new encoder from being created per op.
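A minimal sketch of that delegation, assuming op functions that already take a `&mut Encoder`. Everything here (`Encoder`, `ops::rms_norm`, `ops::elementwise_add`, `Session`) is hypothetical and far simpler than the real signatures in crate::ops. The point it illustrates is that the session adds no encoding logic: it owns the one encoder and lends it to each op in turn.

```rust
/// Hypothetical encoder that records the names of dispatched kernels.
#[derive(Default)]
pub struct Encoder {
    pub kernels: Vec<&'static str>,
}

/// Hypothetical existing op functions: they encode into whatever
/// encoder the caller passes and never create one themselves.
pub mod ops {
    use super::Encoder;

    pub fn rms_norm(enc: &mut Encoder) {
        enc.kernels.push("rms_norm");
    }

    pub fn elementwise_add(enc: &mut Encoder) {
        enc.kernels.push("elementwise_add");
    }
}

/// Owns the single shared encoder for one forward pass.
pub struct Session {
    encoder: Encoder,
}

impl Session {
    pub fn begin() -> Self {
        Session { encoder: Encoder::default() }
    }

    // Each method is a thin forward to the unchanged op function.
    pub fn rms_norm(&mut self) {
        ops::rms_norm(&mut self.encoder);
    }

    pub fn elementwise_add(&mut self) {
        ops::elementwise_add(&mut self.encoder);
    }

    /// Consumes the session and returns what was encoded. A real
    /// implementation would end encoding, commit, and wait here,
    /// giving the single sync point for the pass.
    pub fn finish(self) -> Vec<&'static str> {
        self.encoder.kernels
    }
}
```

Because the ops only see `&mut Encoder`, they cannot tell whether the encoder belongs to a session or was created just for them, which is why no op code has to change.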

§Phase 4e.1 — Graph IR

The ComputeGraph type captures dispatches into a Vec<CapturedNode> for later replay. GraphExecutor::begin_recorded() starts a session in capture mode: all op calls are intercepted at the CommandEncoder level and recorded instead of being sent to Metal. GraphSession::finish() detects capture mode, extracts the recorded graph, and replays it into a fresh encoder via ComputeGraph::encode_sequential().
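The capture-then-replay flow can be sketched with a toy encoder that either forwards commands or appends them to a graph. The `Node`, `Graph`, `Mode`, and `Encoder` types below are hypothetical simplifications of CapturedNode, ComputeGraph, and CommandEncoder. Only the control flow follows the description above: record in capture mode, then replay the nodes in order into a fresh encoder at finish time.

```rust
/// A recorded item: a kernel dispatch or a memory barrier.
#[derive(Clone, Debug, PartialEq)]
pub enum Node {
    Dispatch(&'static str),
    Barrier,
}

/// Hypothetical analogue of ComputeGraph: an ordered list of nodes.
#[derive(Default, Debug)]
pub struct Graph {
    pub nodes: Vec<Node>,
}

pub enum Mode {
    /// Commands go straight to the (simulated) GPU stream.
    Direct,
    /// Commands are recorded into a graph instead.
    Capture(Graph),
}

pub struct Encoder {
    mode: Mode,
    /// Stands in for the commands the GPU actually receives.
    pub submitted: Vec<Node>,
}

impl Encoder {
    pub fn new(mode: Mode) -> Self {
        Encoder { mode, submitted: Vec::new() }
    }

    /// Ops call this; interception happens here, so ops stay unchanged.
    pub fn push(&mut self, node: Node) {
        match &mut self.mode {
            Mode::Direct => self.submitted.push(node),
            Mode::Capture(graph) => graph.nodes.push(node),
        }
    }

    /// In capture mode, extract the graph and replay it sequentially
    /// into a fresh direct-mode encoder; in direct mode, return as is.
    pub fn finish(self) -> Encoder {
        let Encoder { mode, submitted } = self;
        match mode {
            Mode::Capture(graph) => {
                let mut fresh = Encoder::new(Mode::Direct);
                for node in graph.nodes {
                    fresh.push(node);
                }
                fresh
            }
            Mode::Direct => Encoder { mode: Mode::Direct, submitted },
        }
    }
}
```

Because replay preserves recorded order, this toy yields the same command stream as direct mode. Recording presumably exists so that the graph can be inspected before encoding; the OpKind reorder analysis suggests this, but the sketch does not model it.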

The existing direct-dispatch path (begin()) is completely unchanged.

Structs§

ComputeGraph
A recorded sequence of GPU compute dispatches and barriers.
ConflictTracker
GraphExecutor
Batched Metal dispatch: encodes multiple ops into a single CommandEncoder.
GraphSession
A single forward pass execution context.
A Metal GPU buffer annotated with element dtype and tensor shape.

Enums§

DType
Element data type carried by an MlxBuffer.
OpKind
Classification of a compute operation for reorder safety analysis.
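OpKind and ConflictTracker suggest that captured nodes are checked for conflicts before any reordering. Their actual fields are not documented on this page, so the following is only an assumed sketch of the standard hazard rule such an analysis typically applies: two dispatches may swap only if neither writes a buffer that the other reads or writes (no read-after-write, write-after-read, or write-after-write hazard). `Access` and `can_reorder` are hypothetical, and buffers are identified by plain integer ids.

```rust
use std::collections::HashSet;

/// Hypothetical per-dispatch buffer access sets (buffer ids).
pub struct Access {
    pub reads: HashSet<u32>,
    pub writes: HashSet<u32>,
}

impl Access {
    pub fn new(reads: &[u32], writes: &[u32]) -> Self {
        Access {
            reads: reads.iter().copied().collect(),
            writes: writes.iter().copied().collect(),
        }
    }
}

/// True if `a` and `b` touch no buffer in a conflicting way, so their
/// relative order cannot change the result.
pub fn can_reorder(a: &Access, b: &Access) -> bool {
    let raw = !a.writes.is_disjoint(&b.reads); // b reads what a writes
    let war = !a.reads.is_disjoint(&b.writes); // b overwrites what a reads
    let waw = !a.writes.is_disjoint(&b.writes); // both write one buffer
    !(raw || war || waw)
}
```

Shared reads are not a hazard, so two ops that only read the same input buffer can still be reordered.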