Skip to main content

Graph

Struct Graph 

Source
pub struct Graph {
    pub name: String,
    pub outputs: Vec<NodeId>,
    /* private fields */
}
Expand description

A computation graph — the core IR data structure.

§Example

use rlx_ir::*;

let mut g = Graph::new("bert_layer");

// Inputs
let x = g.input("hidden", Shape::new(&[4, 15, 384], DType::F32));
let w = g.param("qkv_weight", Shape::new(&[384, 1152], DType::F32));
let b = g.param("qkv_bias", Shape::new(&[1152], DType::F32));

// QKV projection: matmul + bias
let mm = g.matmul(x, w, Shape::new(&[4, 15, 1152], DType::F32));
let qkv = g.binary(op::BinaryOp::Add, mm, b, Shape::new(&[4, 15, 1152], DType::F32));

assert_eq!(g.len(), 5);
println!("{g}");

Fields§

name: String

outputs: Vec<NodeId>

Output node IDs (the graph’s results).

Implementations§

Source§

impl Graph

Source

pub fn new(name: impl Into<String>) -> Graph

Source

pub fn len(&self) -> usize

Number of nodes in the graph.

Source

pub fn is_empty(&self) -> bool

Source

pub fn node(&self, id: NodeId) -> &Node

Get a node by ID.

Source

pub fn nodes(&self) -> &[Node]

All nodes, as a slice in topological order (insertion order = topo order).

Source

pub fn shape(&self, id: NodeId) -> &Shape

Get the shape of a node’s output.

Source

pub fn set_outputs(&mut self, outputs: Vec<NodeId>)

Set the graph outputs.

Source

pub fn set_inputs(&mut self, id: NodeId, inputs: Vec<NodeId>)

Replace the input list of a node in place. Used by post-construction passes (quant_propagate, dce, etc.) that rewire consumers without inserting new nodes. The caller is responsible for shape consistency — this does no shape re-inference.

Source

pub fn node_mut(&mut self, id: NodeId) -> &mut Node

Source

pub fn nodes_mut(&mut self) -> &mut [Node]

Source

pub fn append_node( &mut self, op: Op, inputs: Vec<NodeId>, shape: Shape, name: Option<String>, ) -> NodeId

Append a node to the graph. Used by the per-op builder files in rlx_ir::ops::* (plan #53) and for backend graph slicing (e.g. TPU HLO segments).

Source

pub fn users(&self, id: NodeId) -> Vec<NodeId>

Find all nodes that use a given node’s output.

Source

pub fn use_count(&self, id: NodeId) -> usize

Count how many nodes use a given node’s output.

Source

pub fn topo_order(&self) -> impl Iterator<Item = NodeId>

Topological order (already guaranteed by construction — just node indices).

Source

pub fn reverse_topo(&self) -> impl Iterator<Item = NodeId>

Reverse topological order (outputs first).

Source

pub fn define( name: impl Into<String>, build: impl FnOnce(&mut HirModule) -> HirNodeId, ) -> GraphModule

Fusion-first model definition at HIR level.

Returns a [GraphModule] at HIR stage; call [GraphModule::lower] or pass to [rlx_opt::CompilePipeline::compile_module].

Source

pub fn hir(name: impl Into<String>) -> GraphModule

Start an empty HIR-stage [GraphModule].

Source

pub fn module(self) -> GraphModule

Wrap this MIR graph in a [GraphModule] for pipeline operations.

Source

pub fn from_hir(hir: HirModule) -> Result<Graph, LowerError>

Lower a HIR module to a MIR graph.

Source

pub fn to_mir(self) -> MirModule

View as [MirModule].

Source

pub fn from_lir(lir: LirModule) -> Graph

Extract the MIR graph from optimized LIR.

Source

pub fn inspect(&self) -> String

Annotated text dump ([inspect_graph]).

Source

pub fn has_dynamic_dims(&self) -> bool

True if any node shape uses a [Dim::Dynamic] symbol.

Source

pub fn dynamic_symbols(&self) -> Vec<u32>

All dynamic symbols referenced in this graph.

Source

pub fn bind(&self, bindings: &DimBinding) -> Graph

Specialize symbolic dims to concrete sizes.

Source

pub fn inspect_module(module: &GraphModule) -> String

Stage-aware dump when wrapped in [GraphModule].

Source§

impl Graph

Source

pub fn attention( &mut self, q: NodeId, k: NodeId, v: NodeId, mask: NodeId, num_heads: usize, head_dim: usize, shape: Shape, ) -> NodeId

Scaled dot-product attention with a custom (caller-supplied) mask. Equivalent to attention_kind(.., MaskKind::Custom, ..).

Source

pub fn attention_opts( &mut self, q: NodeId, k: NodeId, v: NodeId, mask: NodeId, num_heads: usize, head_dim: usize, shape: Shape, score_scale: Option<f32>, attn_logit_softcap: Option<f32>, ) -> NodeId

Like Self::attention with optional score scale and logit softcap.

Source

pub fn attention_kind( &mut self, q: NodeId, k: NodeId, v: NodeId, num_heads: usize, head_dim: usize, mask_kind: MaskKind, shape: Shape, ) -> NodeId

Scaled dot-product attention with a kernel-synthesized mask (None / Causal / SlidingWindow). Inputs are Q, K, V only — no mask tensor is allocated or read in the inner loop. Use MaskKind::None for a single un-padded sequence.

Source

pub fn attention_kind_opts( &mut self, q: NodeId, k: NodeId, v: NodeId, num_heads: usize, head_dim: usize, mask_kind: MaskKind, shape: Shape, score_scale: Option<f32>, attn_logit_softcap: Option<f32>, ) -> NodeId

Like Self::attention_kind with optional score scale and logit softcap.

Source

pub fn attention_bias( &mut self, q: NodeId, k: NodeId, v: NodeId, bias: NodeId, num_heads: usize, head_dim: usize, shape: Shape, ) -> NodeId

Scaled dot-product attention with an additive bias tensor of shape [batch, num_heads, query_len, key_len] added to the QK^T · scale scores before softmax. Lets boxRPB / per-query position biases reuse the fast Op::Attention kernel path.

Source§

impl Graph

Source

pub fn axial_rope2d( &mut self, x: NodeId, end_x: usize, end_y: usize, head_dim: usize, num_heads: usize, theta: f32, repeat_factor: usize, ) -> NodeId

x: [1, seq, num_heads * head_dim] → same shape.

Source§

impl Graph

Source

pub fn relu_backward(&mut self, x: NodeId, dy: NodeId) -> NodeId

ReLU backward: dx = dy where x > 0 else 0. Output shape matches x.

Source

pub fn activation_backward( &mut self, kind: Activation, x: NodeId, dy: NodeId, ) -> NodeId

Element-wise activation backward — closed-form derivative of any single-input activation other than ReLU. See Op::ActivationBackward for the per-kind formulae.

Source

pub fn layer_norm_backward_input( &mut self, x: NodeId, gamma: NodeId, dy: NodeId, axis: i32, eps: f32, ) -> NodeId

LayerNorm backward w.r.t. the input. Inputs [x, gamma, dy]. Output shape matches x. Currently axis = -1 only.

Source

pub fn rms_norm_backward_input( &mut self, x: NodeId, gamma: NodeId, beta: NodeId, dy: NodeId, axis: i32, eps: f32, ) -> NodeId

RMSNorm backward w.r.t. input. Inputs [x, gamma, beta, dy].

Source

pub fn rms_norm_backward_gamma( &mut self, x: NodeId, gamma: NodeId, beta: NodeId, dy: NodeId, axis: i32, eps: f32, ) -> NodeId

Source

pub fn rms_norm_backward_beta( &mut self, x: NodeId, gamma: NodeId, beta: NodeId, dy: NodeId, axis: i32, eps: f32, ) -> NodeId

Source

pub fn rope_backward( &mut self, dy: NodeId, cos: NodeId, sin: NodeId, head_dim: usize, n_rot: usize, ) -> NodeId

Source

pub fn cumsum_backward( &mut self, dy: NodeId, out_shape: Shape, axis: i32, exclusive: bool, ) -> NodeId

Source

pub fn gather_backward( &mut self, dy: NodeId, indices: NodeId, table_shape: Shape, axis: i32, ) -> NodeId

Source

pub fn group_norm_backward_input( &mut self, x: NodeId, gamma: NodeId, beta: NodeId, dy: NodeId, num_groups: usize, eps: f32, ) -> NodeId

GroupNorm (NCHW) backward w.r.t. input. Inputs [x, gamma, beta, dy].

Source

pub fn group_norm_backward_gamma( &mut self, x: NodeId, dy: NodeId, gamma_shape: Shape, num_groups: usize, eps: f32, ) -> NodeId

GroupNorm backward w.r.t. gamma. Inputs [x, dy].

Source

pub fn group_norm_backward_beta( &mut self, x: NodeId, dy: NodeId, beta_shape: Shape, num_groups: usize, eps: f32, ) -> NodeId

GroupNorm backward w.r.t. beta. Inputs [x, dy].

Source

pub fn layer_norm_backward_gamma( &mut self, x: NodeId, dy: NodeId, gamma_shape: Shape, axis: i32, eps: f32, ) -> NodeId

LayerNorm backward w.r.t. gamma. Inputs [x, dy]. Output shape is provided by the caller — typically the gamma’s shape, e.g. [D] for a per-feature 1-D gamma.

Source

pub fn maxpool2d_backward( &mut self, x: NodeId, dy: NodeId, kernel_size: Vec<usize>, stride: Vec<usize>, padding: Vec<usize>, ) -> NodeId

2D max-pool backward. x is the original NCHW input; dy is the upstream gradient with shape matching the pool’s output. Output shape matches x.

Source

pub fn conv2d_backward_input( &mut self, dy: NodeId, w: NodeId, x_shape: Shape, kernel_size: Vec<usize>, stride: Vec<usize>, padding: Vec<usize>, dilation: Vec<usize>, groups: usize, ) -> NodeId

Conv2D backward w.r.t. input. dy has the conv output shape; w is the forward weight [C_out, C_in/groups, kH, kW]. The output shape (the original input shape) is supplied by the caller because it can’t be unambiguously derived from dy.shape alone in the presence of strides + padding.

Source

pub fn conv2d_backward_weight( &mut self, x: NodeId, dy: NodeId, w_shape: Shape, kernel_size: Vec<usize>, stride: Vec<usize>, padding: Vec<usize>, dilation: Vec<usize>, groups: usize, ) -> NodeId

Conv2D backward w.r.t. weight. Output shape matches the forward weight [C_out, C_in/groups, kH, kW].

Source

pub fn softmax_cross_entropy_with_logits( &mut self, logits: NodeId, labels: NodeId, ) -> NodeId

Fused softmax + cross-entropy with f32-encoded integer labels: logits [N, C], labels [N] → [N] per-row loss.

Source

pub fn softmax_cross_entropy_backward( &mut self, logits: NodeId, labels: NodeId, d_loss: NodeId, ) -> NodeId

Backward of softmax_cross_entropy_with_logits. Inputs [logits, labels, d_loss] → dlogits shaped like logits.

Source

pub fn complex_norm_sq(&mut self, z: NodeId) -> NodeId

Element-wise complex squared-magnitude: |z|² = re² + im². Input must be DType::C64; output is same logical shape but DType::F32. The canonical real-valued loss surface for Wirtinger reverse-mode AD on complex graphs.

Source

pub fn attention_backward( &mut self, wrt: AttentionBwdWrt, q: NodeId, k: NodeId, v: NodeId, dy: NodeId, num_heads: usize, head_dim: usize, mask_kind: MaskKind, mask: Option<NodeId>, ) -> NodeId

Scaled dot-product attention backward w.r.t. q, k, or v. See Op::AttentionBackward. When mask_kind is MaskKind::Custom or MaskKind::Bias, pass the same mask tensor used in forward.

Source

pub fn attention_backward_all( &mut self, q: NodeId, k: NodeId, v: NodeId, dy: NodeId, num_heads: usize, head_dim: usize, mask_kind: MaskKind, mask: Option<NodeId>, ) -> (NodeId, NodeId, NodeId)

Emit dQ, dK, and dV for one Op::Attention forward node.

Source

pub fn complex_norm_sq_backward(&mut self, z: NodeId, g: NodeId) -> NodeId

Wirtinger backward for [complex_norm_sq]: given upstream g (real, same shape as the forward output) and the original complex input z, returns dz = g · z as C64.

Source

pub fn conjugate(&mut self, z: NodeId) -> NodeId

Element-wise complex conjugate: z̄ = re - i·im. Input must be DType::C64; output is the same shape and dtype. Used by Wirtinger VJP rules on C64 binary ops.

Source§

impl Graph

Source

pub fn linear_bias( &mut self, input: NodeId, weight: NodeId, bias: Option<NodeId>, ) -> NodeId

Dense linear layer: matmul(input, weight) with optional rank-1 bias.

Source

pub fn linear_bias_act( &mut self, input: NodeId, weight: NodeId, bias: Option<NodeId>, activation: Option<Activation>, ) -> NodeId

Dense linear with optional bias and epilogue activation.

Source

pub fn linear_fused( &mut self, input: NodeId, weight: NodeId, bias: NodeId, activation: Option<Activation>, out_shape: Shape, ) -> NodeId

Emit Op::FusedMatMulBiasAct directly — deterministic fusion without relying on the FuseMatMulBiasAct pass.

Source

pub fn shared_matmul_pair( &mut self, input: NodeId, w_first: NodeId, w_second: NodeId, ) -> (NodeId, NodeId)

Two matmuls sharing the same input — canonical gate+up / QKV pattern for FuseSharedInputMatMul.

Returns (first, second) in declaration order. For SwiGLU, pass up weight first and gate weight second so the post-concat narrow layout matches FuseSwiGLU (up @ 0, gate @ N).

Source

pub fn swiglu_ffn( &mut self, input: NodeId, up_w: NodeId, gate_w: NodeId, down_w: NodeId, ) -> NodeId

SwiGLU FFN block: shared-input gate+up → silu(gate) * up → down proj.

Weight order matches FuseSwiGLU’s canonical narrow layout (up projection first, gate projection second).

Source

pub fn fused_swiglu_ffn( &mut self, input: NodeId, up_w: NodeId, gate_w: NodeId, down_w: NodeId, out_shape: Shape, ) -> NodeId

Fully fused SwiGLU FFN: concat weights → single matmul → Op::FusedSwiGLU → down projection. Matches the rewrite performed by FuseSwiGLUDualMatmul without relying on the pass.

Source§

impl Graph

Source

pub fn conv2d( &mut self, input: NodeId, weight: NodeId, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], groups: usize, ) -> NodeId

2D convolution on NCHW tensors (Op::Conv). Weight [C_out, C_in/g, kH, kW].

Source

pub fn conv_transpose2d( &mut self, input: NodeId, weight: NodeId, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], output_padding: [usize; 2], groups: usize, ) -> NodeId

2D transposed convolution on NCHW. Weight [C_in, C_out/g, kH, kW].

Source§

impl Graph

Source

pub fn binary( &mut self, op: BinaryOp, lhs: NodeId, rhs: NodeId, out_shape: Shape, ) -> NodeId

Binary element-wise operation.

Source

pub fn activation( &mut self, act: Activation, input: NodeId, shape: Shape, ) -> NodeId

Unary activation.

Source

pub fn quantize(&mut self, x: NodeId, scale: f32, zero_point: i32) -> NodeId

Per-tensor INT8 quantization. Output dtype = I8, same shape otherwise. scale and zero_point apply uniformly to every element. Use quantize_per_channel when weights deserve per-channel scales (the standard PTQ improvement).

Source

pub fn quantize_per_channel( &mut self, x: NodeId, axis: usize, scales: Vec<f32>, zero_points: Vec<i32>, ) -> NodeId

Per-channel INT8 quantization. scales and zero_points must each have length input.dim(axis); the kernel picks the i-th pair when quantizing the i-th slice along axis. The most common usage is axis = 0 for a [C_out, C_in, kH, kW] conv weight (one scale per output channel).

Source

pub fn dequantize(&mut self, x: NodeId, scale: f32, zero_point: i32) -> NodeId

Per-tensor INT8 dequantization (inverse of quantize). Output dtype is f32.

Source

pub fn dequantize_per_channel( &mut self, x: NodeId, axis: usize, scales: Vec<f32>, zero_points: Vec<i32>, ) -> NodeId

Per-channel INT8 dequantization (inverse of quantize_per_channel).

Source§

impl Graph

Source

pub fn pad_last_axis_to_pow2(&mut self, x: NodeId) -> NodeId

Zero-pad the last axis to the next power of two (no-op when already pow2).

Source

pub fn split_spectrum(&mut self, spectrum: NodeId) -> (NodeId, NodeId)

Split a 2N real-block spectrum into separate real / imag tensors.

Source

pub fn fft_real(&mut self, x: NodeId, norm: FftNorm) -> (NodeId, NodeId)

Real-input FFT (gpu-fft fft): auto zero-pads to pow2, returns (re, im).

Source

pub fn fft_batch_real(&mut self, x: NodeId, norm: FftNorm) -> (NodeId, NodeId)

Batched real-input FFT — same as fft_real: the last axis holds the signal, and all leading axes are treated as independent batch dimensions.

Source

pub fn rfft(&mut self, x: NodeId, norm: FftNorm) -> (NodeId, NodeId)

Real-input FFT with half-spectrum output (n_pad/2 + 1 complex bins).

The input is zero-padded to the next power of two along the last axis before the transform. This matches NumPy's rfft zero-padding semantics when its n is set to that padded length.

Source

pub fn irfft( &mut self, re_half: NodeId, im_half: NodeId, n: usize, norm: FftNorm, ) -> NodeId

Inverse real FFT from half-spectrum (re, im) with Hermitian symmetry.

Mirrors the conjugate half of the spectrum (excluding DC and Nyquist) before calling Self::ifft_spectrum, then truncates to length n.

Source

pub fn stft( &mut self, x: NodeId, frame_len: usize, hop: usize, norm: FftNorm, ) -> NodeId

Short-time Fourier transform: [..., T] → [frames, ..., 2·half] (re/im block per frame).

Each frame has length frame_len, consecutive frames start hop samples apart along the last axis, and each frame is rfft’d.

Source

pub fn fft_conv1d( &mut self, a: NodeId, b: NodeId, n_fft: usize, norm: FftNorm, ) -> NodeId

1D convolution via the convolution theorem (rfft → complex multiply → irfft).

Both inputs are zero-padded to at least n_fft (or the next power of two covering len(a) + len(b) - 1 when n_fft is small).

Source

pub fn fftfreq_tensor(&mut self, n: usize) -> NodeId

Constant tensor of FFT sample frequencies (length n, f64).

Source

pub fn rfftfreq_tensor(&mut self, n: usize) -> NodeId

Constant tensor of rFFT sample frequencies (length n/2 + 1, f64).

Source

pub fn psd_real(&mut self, x: NodeId, norm: FftNorm) -> NodeId

Power spectral density from real input: rfft → psd.

Source

pub fn ifft_spectrum(&mut self, re: NodeId, im: NodeId, norm: FftNorm) -> NodeId

Inverse FFT from separate real / imag spectra (gpu-fft ifft real part).

Source

pub fn psd(&mut self, re: NodeId, im: NodeId) -> NodeId

Power spectral density: (re² + im²) / N (gpu-fft psd::psd).

Source§

impl Graph

Source

pub fn input(&mut self, name: impl Into<String>, shape: Shape) -> NodeId

Graph input (runtime-provided tensor).

Source

pub fn param(&mut self, name: impl Into<String>, shape: Shape) -> NodeId

Model parameter (weight loaded at init).

Source

pub fn add_node(&mut self, op: Op, inputs: Vec<NodeId>, shape: Shape) -> NodeId

Generic node constructor for custom ops.

Source

pub fn custom_op( &mut self, name: impl Into<String>, attrs: Vec<u8>, inputs: Vec<NodeId>, ) -> NodeId

Build an Op::Custom node, dispatching shape inference through the global op registry. The named op must already be registered via crate::register_op; attrs is forwarded verbatim to the impl’s infer_shape (and later, at execution time, to its per-backend kernel).

Panics if name is not registered or if inputs.len() does not match the registered num_inputs() — both are programmer errors that should fail loudly at graph-build time, not silently at execution.

Source

pub fn custom_op_packed( &mut self, name: impl Into<String>, attrs: Vec<u8>, inputs: Vec<NodeId>, out_shape: Shape, ) -> NodeId

Build an Op::Custom node with a caller-supplied output shape, bypassing the registry’s infer_shape. Use this for ops whose output shape can’t be determined by static input shapes alone — most importantly, ops with multiple logical outputs packed into one buffer.

The canonical multi-output pattern:

// Sparse-LU returns L_values + U_values packed end-to-end.
// Caller knows nnz_L and nnz_U from the symbolic factor.
let lu = g.custom_op_packed(
    "sparse_lu",
    attrs,
    vec![A, b],
    Shape::new(&[nnz_L + nnz_U], DType::F64),
);
let l_vals = g.narrow_(lu, 0, 0, nnz_L);
let u_vals = g.narrow_(lu, 0, nnz_L, nnz_U);

The op must still be registered (so num_inputs validation and autodiff routing still work); only the shape is overridden.

Source

pub fn fft(&mut self, x: NodeId, inverse: bool) -> NodeId

1D FFT along the last axis.

  • F32 / F64 — 2N real-block layout: last axis is [re…, im…].
  • C64 — interleaved [re, im] pairs per complex element.

Output shape matches input. Radix-2 when N is a power of two, Bluestein otherwise. Default normalization is unnormalized (FftNorm::Backward; ifft(fft(x)) = N·x).

Source

pub fn fft_norm(&mut self, x: NodeId, inverse: bool, norm: FftNorm) -> NodeId

1D FFT with explicit normalization mode.

Source

pub fn fft_axis(&mut self, x: NodeId, axis: usize, inverse: bool) -> NodeId

1D FFT along an arbitrary axis. Lowers to Transpose(axis ↔ last) → Fft(last) → Transpose(last ↔ axis).

AD is free: both Op::Transpose and Op::Fft have VJP/JVP rules.

Source

pub fn fftn(&mut self, x: NodeId, axes: &[usize], inverse: bool) -> NodeId

N-dimensional FFT along axes (NumPy fftn semantics).

Applies a 1D FFT along each listed axis in ascending order. Empty axes is a no-op. For multi-axis transforms on tensors with more than one spatial dimension, use DType::C64; the F32/F64 2N-block layout only describes a single complex axis.

Source

pub fn ifftn(&mut self, x: NodeId, axes: &[usize]) -> NodeId

Inverse N-dimensional FFT — alias for fftn(..., inverse: true).

Source§

impl Graph

Source

pub fn matmul(&mut self, lhs: NodeId, rhs: NodeId, out_shape: Shape) -> NodeId

Matrix multiply.

Source

pub fn dense_solve(&mut self, a: NodeId, b: NodeId, out_shape: Shape) -> NodeId

Dense linear solve x = A⁻¹·b. A must be [N, N]; b is [N] for a single right-hand side or [N, K] for multiple. out_shape matches b’s shape.

Source

pub fn batched_dense_solve( &mut self, a: NodeId, b: NodeId, out_shape: Shape, ) -> NodeId

Batched dense linear solve. A is [B, N, N]; b is [B, N] (single-RHS) or [B, N, K] (multi-RHS). Per-batch independent — each slice solved as a separate dense_solve. Typically constructed by vmap of dense_solve.

Source

pub fn lora_matmul( &mut self, x: NodeId, w: NodeId, a: NodeId, b: NodeId, scale: f32, shape: Shape, ) -> NodeId

Fused LoRA matmul: out = x·W + scale * (x·A)·B. Inputs: x [m, k], w [k, n], a [k, r], b [r, n]. r is the LoRA rank; scale is the alpha/rank coefficient.

Source

pub fn dequant_matmul( &mut self, x: NodeId, w_q: NodeId, scale: NodeId, zp: NodeId, scheme: QuantScheme, shape: Shape, ) -> NodeId

Fused dequant + matmul. See Op::DequantMatMul for per-scheme input layout (4 inputs for legacy/NVFP4, 2 for GGUF).

Source

pub fn dequant_matmul_packed( &mut self, x: NodeId, packed_w: NodeId, scheme: QuantScheme, shape: Shape, ) -> NodeId

GGUF / K-quant packed weights — [x, packed_w_bytes] only.

Source

pub fn dequant_matmul_nvfp4( &mut self, x: NodeId, w_q: NodeId, block_scales: NodeId, global_scale: NodeId, shape: Shape, ) -> NodeId

NVFP4 (E2M1) block matmul — group size 16, FP8 block scales, optional f32 global scale (defaults to 1.0 when unset at runtime).

Source

pub fn fused_matmul_bias_act( &mut self, input: NodeId, weight: NodeId, bias: NodeId, activation: Option<Activation>, shape: Shape, ) -> NodeId

Fused matmul + bias + activation (created by optimization passes).

Source

pub fn q_matmul( &mut self, x: NodeId, w: NodeId, bias: NodeId, x_zp: i32, w_zp: i32, out_zp: i32, mult: f32, out_shape: Shape, ) -> NodeId

Real INT8-arithmetic matmul: i8 inputs, i32 bias, i8 output. mult = x_scale · w_scale / out_scale. The caller is responsible for checking the input dtypes — the builder only sets the output shape, with dtype = I8 since that is what the kernel writes.

Source

pub fn q_conv2d( &mut self, x: NodeId, w: NodeId, bias: NodeId, kernel_size: Vec<usize>, stride: Vec<usize>, padding: Vec<usize>, dilation: Vec<usize>, groups: usize, x_zp: i32, w_zp: i32, out_zp: i32, mult: f32, out_shape: Shape, ) -> NodeId

Real INT8-arithmetic 2-D convolution. NCHW layout matching Op::Conv. mult = x_scale · w_scale / out_scale.

Source§

impl Graph

Source

pub fn layer_norm2d( &mut self, input: NodeId, gamma: NodeId, beta: NodeId, eps: f32, ) -> NodeId

LayerNorm2d on NCHW (normalize across channels at each spatial position).

Source

pub fn group_norm( &mut self, input: NodeId, gamma: NodeId, beta: NodeId, num_groups: usize, eps: f32, ) -> NodeId

Group normalization on NCHW.

Source

pub fn layer_norm( &mut self, input: NodeId, gamma: NodeId, beta: NodeId, axis: i32, eps: f32, shape: Shape, ) -> NodeId

Layer normalization.

Source

pub fn fused_residual_ln( &mut self, x: NodeId, residual: NodeId, bias: Option<NodeId>, gamma: NodeId, beta: NodeId, eps: f32, shape: Shape, ) -> NodeId

Fused residual + bias + layer norm (created by optimization passes).

Source

pub fn fused_residual_rms_norm( &mut self, x: NodeId, residual: NodeId, bias: Option<NodeId>, gamma: NodeId, beta: NodeId, eps: f32, shape: Shape, ) -> NodeId

Fused residual + bias + RMS norm (created by optimization passes).

Source§

impl Graph

Source

pub fn reduce( &mut self, input: NodeId, op: ReduceOp, axes: Vec<usize>, keep_dim: bool, shape: Shape, ) -> NodeId

Reduce.

Source

pub fn softmax(&mut self, input: NodeId, axis: i32, shape: Shape) -> NodeId

Softmax.

Source

pub fn cumsum( &mut self, input: NodeId, axis: i32, exclusive: bool, shape: Shape, ) -> NodeId

Cumulative sum along an axis (output shape == input shape).

Source

pub fn sample( &mut self, logits: NodeId, top_k: usize, top_p: f32, temperature: f32, seed: u64, output_shape: Shape, ) -> NodeId

Fused sample: logits → token id (one f32-encoded id per row). output_shape should be [batch] (one id per logit row).

Source§

impl Graph

Source

pub fn reshape( &mut self, input: NodeId, new_shape: Vec<i64>, out_shape: Shape, ) -> NodeId

Reshape.

Source

pub fn gather( &mut self, table: NodeId, indices: NodeId, axis: usize, shape: Shape, ) -> NodeId

Gather (embedding lookup).

Source

pub fn concat( &mut self, inputs: Vec<NodeId>, axis: usize, shape: Shape, ) -> NodeId

Concatenate tensors along an axis.

Source§

impl Graph

Source

pub fn selective_scan( &mut self, x: NodeId, delta: NodeId, a: NodeId, b: NodeId, c: NodeId, state_size: usize, shape: Shape, ) -> NodeId

Mamba-style selective scan: y = SSM(x, Δ, A, B, C). Inputs: x [b,s,h], delta [b,s,h], a [h,n], b [b,s,n], c [b,s,n]. Output [b,s,h]. n is the state size.

Source

pub fn gated_delta_net( &mut self, q: NodeId, k: NodeId, v: NodeId, g: NodeId, beta: NodeId, state_size: usize, shape: Shape, ) -> NodeId

Gated DeltaNet linear-attention scan (Qwen3.5/3.6 trunk, Qwen3-Next, Kimi-Linear). See Op::GatedDeltaNet for the recurrence math. All five inputs are f32. Shapes: q,k,v: [b, s, h_v, n]; g,beta: [b, s, h_v]. Output: [b, s, h_v, n]. State is implicit (reset per batch) unless carry_state is set — then pass state as a sixth input.

Source

pub fn gated_delta_net_carry( &mut self, q: NodeId, k: NodeId, v: NodeId, g: NodeId, beta: NodeId, state: NodeId, state_size: usize, shape: Shape, ) -> NodeId

Same as Self::gated_delta_net but threads state [b, h_v, n, n] in/out for decode-mode recurrence.

Source

pub fn scan(&mut self, init: NodeId, body: Graph, length: u32) -> NodeId

Bounded scan returning the final carry. Body must have exactly one Op::Input (the carry) and one output, both same shape as init. Output shape matches init.

Source

pub fn scan_checkpointed( &mut self, init: NodeId, body: Graph, length: u32, num_checkpoints: u32, ) -> NodeId

Bounded scan with recursive checkpointing for memory-bounded backward AD. Equivalent to Self::scan for the forward computation, but during backward only num_checkpoints carry values are cached; intermediate carries are recomputed via the body. Memory: O(num_checkpoints · carry_size). Time: forward unchanged; backward O(length) (segment-cached).

The AD pre-pass propagates num_checkpoints into the rewritten trajectory-saving Scan and into the emitted ScanBackward, so a single call to crate::Graph::scan_checkpointed is enough to enable the memory bound across the whole forward+backward pipeline.

Source

pub fn scan_with_bcasts_and_xs( &mut self, init: NodeId, bcasts: &[NodeId], xs: &[NodeId], body: Graph, length: u32, ) -> NodeId

Bounded scan with broadcast and per-step inputs.

Body Op::Inputs in NodeId order: [carry, bcast_0..bcast_{B-1}, x_t_0..x_t_{X-1}]. Bcast inputs keep their natural shape (the CPU executor fills them once before the scan loop). xs[i] has shape [length, *per_step] and the body sees xs[i][t] per iteration. Output shape matches init.

Source

pub fn scan_with_xs( &mut self, init: NodeId, xs: &[NodeId], body: Graph, length: u32, ) -> NodeId

Bounded scan with per-step xs inputs returning the final carry. Body has 1 + xs.len() Op::Inputs in NodeId construction order (first declared is the carry; the remaining match xs in order). Each xs[i] has shape [length, *per_step_shape_i]; the body sees a per_step_shape_i slice on iteration t.

Source

pub fn scan_backward( &mut self, init: NodeId, trajectory: NodeId, upstream: NodeId, xs: &[NodeId], body_vjp: Graph, length: u32, save_trajectory: bool, out_shape: Shape, ) -> NodeId

Reverse-mode AD companion to Self::scan / Self::scan_trajectory. Typically constructed by the autodiff pass, not by hand.

xs is the list of per-step input tensors (must match the forward Op::Scan’s xs in count, order, and per-step shape). body_vjp’s 1 + xs.len() + 1 Op::Inputs match the forward body’s inputs plus a fresh "d_output" Input.

Source

pub fn scan_backward_with_checkpoints( &mut self, init: NodeId, trajectory: NodeId, upstream: NodeId, xs: &[NodeId], body_vjp: Graph, length: u32, save_trajectory: bool, num_checkpoints: u32, forward_body: Option<Graph>, out_shape: Shape, ) -> NodeId

Lower-level scan_backward with explicit checkpointing config. num_checkpoints == 0 (default) means no checkpointing — the trajectory cache holds every step’s carry. 0 < K < length enables segment-cached recompute via forward_body (must be Some).

Source

pub fn scan_backward_xs( &mut self, init: NodeId, trajectory: NodeId, upstream: NodeId, xs: &[NodeId], body_vjp: Graph, length: u32, save_trajectory: bool, xs_idx: u32, out_shape: Shape, ) -> NodeId

Per-step xs gradient companion to Self::scan_backward. Same inputs and same body_vjp graph, plus an xs_idx selecting which body_vjp output to stack into the result. Output shape is [length, *per_step_xs_shape].

Source

pub fn scan_backward_xs_with_checkpoints( &mut self, init: NodeId, trajectory: NodeId, upstream: NodeId, xs: &[NodeId], body_vjp: Graph, length: u32, save_trajectory: bool, xs_idx: u32, num_checkpoints: u32, forward_body: Option<Graph>, out_shape: Shape, ) -> NodeId

Source

pub fn custom_fn( &mut self, inputs: Vec<NodeId>, fwd_body: Graph, vjp_body: Option<Graph>, jvp_body: Option<Graph>, ) -> NodeId

User-defined sub-graph with optional override AD rules. JAX-shaped custom_vjp / custom_jvp — see Op::CustomFn.

inputs.len() must equal the number of Op::Input nodes in fwd_body. Output shape is inferred from fwd_body’s declared output. When supplied, vjp_body and jvp_body must follow the conventions documented on Op::CustomFn (special-named "primal_output" / "d_output" / "tangent_*" Inputs).

Source

pub fn custom_fn_multi( &mut self, inputs: Vec<NodeId>, fwd_body: Graph, ) -> MultiOutputHandle

Multi-output custom_fn via the concat-with-Narrow design: rewrites fwd_body to flatten + concat its K declared outputs into a single 1-D F32 output, wraps that as Op::CustomFn, and returns a MultiOutputHandle the caller uses to extract each sub-output via Op::Narrow + Op::Reshape.

Per PLAN line 484, this avoids rewriting rlx’s “1 Op = 1 output” IR contract: the wrapped Op::CustomFn still has one output (the flat concat), and MultiOutputHandle::output(g, i) materializes component i lazily on the outer graph.

Constraints (MVP):

  • All sub-outputs must be DType::F32. Tuples-of-mixed-dtype need either a per-dtype split or a future tuple-type extension.
  • All sub-output shapes must be statically known (no Dim::Dynamic).
  • vjp_body / jvp_body aren’t yet rewritten through the concat — caller must provide bodies that already expect the flat-concat output convention if they need custom AD.
Source

pub fn scan_trajectory( &mut self, init: NodeId, body: Graph, length: u32, ) -> NodeId

Bounded scan returning the stacked trajectory. Output shape is [length, *init.shape] — row t is the carry after step t+1, so row length-1 equals the result of plain Self::scan.

Source§

impl Graph

Source

pub fn gaussian_splat_render( &mut self, inputs: GaussianSplatInputs, params: GaussianSplatRenderParams, ) -> NodeId

First-class CPU reference Gaussian splat forward render.

See Op::GaussianSplatRender for the seven-input contract and GaussianSplatRenderParams for framebuffer settings.

Source

pub fn gaussian_splat_render_meta( &mut self, camera_position: [f32; 3], camera_target: [f32; 3], camera_up: [f32; 3], fov_y_degrees: f32, near: f32, far: f32, background: [f32; 3], params: GaussianSplatRenderParams, ) -> NodeId

Build the 23-float meta vector expected by Op::GaussianSplatRender.

Source

pub fn gaussian_splat_prepare( &mut self, inputs: GaussianSplatInputs, params: GaussianSplatRenderParams, ) -> NodeId

Strict IR stage 1: project + bin + sort + rays → packed prepare buffer.

Source

pub fn gaussian_splat_rasterize( &mut self, prep: NodeId, meta: NodeId, params: GaussianSplatRenderParams, ) -> NodeId

Strict IR stage 2: rasterize from prepare buffer + meta.

Source

pub fn gaussian_splat_render_decomposed( &mut self, inputs: GaussianSplatInputs, params: GaussianSplatRenderParams, ) -> NodeId

Decomposed strict-IR forward: prepare → rasterize.

Source

pub fn gaussian_splat_render_backward( &mut self, inputs: GaussianSplatInputs, d_loss_rgba: NodeId, params: GaussianSplatBackwardParams, ) -> NodeId

Backward pass for Op::GaussianSplatRender (packed scene gradients).

Trait Implementations§

Source§

impl Clone for Graph

Source§

fn clone(&self) -> Graph

Returns a duplicate of the value. Read more
1.0.0 (const: unstable) · Source§

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more
Source§

impl Debug for Graph

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error>

Formats the value using the given formatter. Read more
Source§

impl<'de> Deserialize<'de> for Graph

Source§

fn deserialize<__D>( __deserializer: __D, ) -> Result<Graph, <__D as Deserializer<'de>>::Error>
where __D: Deserializer<'de>,

Deserialize this value from the given Serde deserializer. Read more
Source§

impl Display for Graph

Pretty-print the graph in a readable IR format.

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error>

Formats the value using the given formatter. Read more
Source§

impl From<MirModule> for Graph

Source§

fn from(mir: MirModule) -> Graph

Converts to this type from the input type.
Source§

impl GraphExt for Graph

Source§

fn mm(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId

Source§

fn add(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId

Source§

fn sub(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId

Source§

fn mul(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId

Source§

fn div(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId

Source§

fn gelu(&mut self, x: NodeId) -> NodeId

Source§

fn gelu_approx(&mut self, x: NodeId) -> NodeId

Tanh-approximation GELU (PyTorch’s gelu with approximate="tanh" — note PyTorch’s default gelu uses the exact erf form — and candle’s Tensor::gelu). Use this when porting models whose reference implementations use the tanh form, for numerical parity (e.g. DINOv2, many ViTs).
Source§

fn silu(&mut self, x: NodeId) -> NodeId

Source§

fn relu(&mut self, x: NodeId) -> NodeId

Source§

fn exp(&mut self, x: NodeId) -> NodeId

Source§

fn sqrt(&mut self, x: NodeId) -> NodeId

Source§

fn neg(&mut self, x: NodeId) -> NodeId

Source§

fn tanh(&mut self, x: NodeId) -> NodeId

Source§

fn ln(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId

Source§

fn layer_norm2d( &mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32, ) -> NodeId

Source§

fn group_norm( &mut self, x: NodeId, gamma: NodeId, beta: NodeId, num_groups: usize, eps: f32, ) -> NodeId

Source§

fn conv2d( &mut self, input: NodeId, weight: NodeId, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], groups: usize, ) -> NodeId

Source§

fn conv_transpose2d( &mut self, input: NodeId, weight: NodeId, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], output_padding: [usize; 2], groups: usize, ) -> NodeId

Source§

fn rms_norm( &mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32, ) -> NodeId

Source§

fn sum(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId

Source§

fn mean(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId

Source§

fn sm(&mut self, x: NodeId, axis: i32) -> NodeId

Source§

fn reshape_(&mut self, x: NodeId, new_shape: Vec<i64>) -> NodeId

Source§

fn transpose_(&mut self, x: NodeId, perm: Vec<usize>) -> NodeId

Source§

fn narrow_( &mut self, x: NodeId, axis: usize, start: usize, len: usize, ) -> NodeId

Source§

fn concat_(&mut self, inputs: Vec<NodeId>, axis: usize) -> NodeId

Source§

fn gather_(&mut self, table: NodeId, indices: NodeId, axis: usize) -> NodeId

Source§

fn eq(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId

Source§

fn lt(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId

Source§

fn attention_( &mut self, q: NodeId, k: NodeId, v: NodeId, mask: NodeId, num_heads: usize, head_dim: usize, ) -> NodeId

Source§

fn rope( &mut self, x: NodeId, cos: NodeId, sin: NodeId, head_dim: usize, ) -> NodeId

Source§

fn rope_n( &mut self, x: NodeId, cos: NodeId, sin: NodeId, head_dim: usize, n_rot: usize, ) -> NodeId

Partial RoPE: rotate only the first n_rot dims, using NeoX-style pairing (dimension i is rotated together with dimension i + n_rot/2).
Source§

fn cast(&mut self, x: NodeId, to: DType) -> NodeId

Source§

impl PartialEq for Graph

Source§

fn eq(&self, other: &Graph) -> bool

Tests for self and other values to be equal, and is used by ==.
1.0.0 (const: unstable) · Source§

fn ne(&self, other: &Rhs) -> bool

Tests for !=. The default implementation is almost always sufficient, and should not be overridden without very good reason.
Source§

impl Serialize for Graph

Source§

fn serialize<__S>( &self, __serializer: __S, ) -> Result<<__S as Serializer>::Ok, <__S as Serializer>::Error>
where __S: Serializer,

Serialize this value into the given Serde serializer. Read more
Source§

impl TryFrom<GraphModule> for Graph

Source§

type Error = LowerError

The type returned in the event of a conversion error.
Source§

fn try_from(module: GraphModule) -> Result<Graph, LowerError>

Performs the conversion.

Auto Trait Implementations§

§

impl Freeze for Graph

§

impl RefUnwindSafe for Graph

§

impl Send for Graph

§

impl Sync for Graph

§

impl Unpin for Graph

§

impl UnsafeUnpin for Graph

§

impl UnwindSafe for Graph

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of the pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes an object with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
Source§

impl<T> ToString for T
where T: Display + ?Sized,

Source§

fn to_string(&self) -> String

Converts the given value to a String. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<T> DeserializeOwned for T
where T: for<'de> Deserialize<'de>,

Source§

impl<T> WasmNotSend for T
where T: Send,

Source§

impl<T> WasmNotSendSync for T

Source§

impl<T> WasmNotSync for T
where T: Sync,