pub struct Graph {
pub name: String,
pub outputs: Vec<NodeId>,
/* private fields */
}
A computation graph — the core IR data structure.
§Example
use rlx_ir::*;
let mut g = Graph::new("bert_layer");
// Inputs
let x = g.input("hidden", Shape::new(&[4, 15, 384], DType::F32));
let w = g.param("qkv_weight", Shape::new(&[384, 1152], DType::F32));
let b = g.param("qkv_bias", Shape::new(&[1152], DType::F32));
// QKV projection: matmul + bias
let mm = g.matmul(x, w, Shape::new(&[4, 15, 1152], DType::F32));
let qkv = g.binary(op::BinaryOp::Add, mm, b, Shape::new(&[4, 15, 1152], DType::F32));
assert_eq!(g.len(), 5);
println!("{g}");
Fields§
name: String
outputs: Vec<NodeId> — Output node IDs (the graph’s results).
Implementations§
Source§impl Graph
impl Graph
pub fn new(name: impl Into<String>) -> Self
pub fn is_empty(&self) -> bool
Sourcepub fn nodes(&self) -> &[Node]
pub fn nodes(&self) -> &[Node]
Iterate all nodes in topological order (insertion order = topo order).
Sourcepub fn set_outputs(&mut self, outputs: Vec<NodeId>)
pub fn set_outputs(&mut self, outputs: Vec<NodeId>)
Set the graph outputs.
Sourcepub fn set_inputs(&mut self, id: NodeId, inputs: Vec<NodeId>)
pub fn set_inputs(&mut self, id: NodeId, inputs: Vec<NodeId>)
Replace the input list of a node in place. Used by post-
construction passes (quant_propagate, dce, etc.) that
rewire consumers without inserting new nodes.
Caller is responsible for shape consistency — this does no
re-inference.
pub fn node_mut(&mut self, id: NodeId) -> &mut Node
pub fn nodes_mut(&mut self) -> &mut [Node]
Sourcepub fn append_node(
&mut self,
op: Op,
inputs: Vec<NodeId>,
shape: Shape,
name: Option<String>,
) -> NodeId
pub fn append_node( &mut self, op: Op, inputs: Vec<NodeId>, shape: Shape, name: Option<String>, ) -> NodeId
Append a node to the graph. Originally pub(crate) so that per-op builder
files in rlx_ir::ops::* could call it (plan #53); it is now public so that
backends can also append nodes for graph slicing (e.g. TPU HLO segments).
Sourcepub fn topo_order(&self) -> impl Iterator<Item = NodeId> + '_
pub fn topo_order(&self) -> impl Iterator<Item = NodeId> + '_
Topological order (already guaranteed by construction — just node indices).
Sourcepub fn reverse_topo(&self) -> impl Iterator<Item = NodeId> + '_
pub fn reverse_topo(&self) -> impl Iterator<Item = NodeId> + '_
Reverse topological order (outputs first).
Sourcepub fn define(
name: impl Into<String>,
build: impl FnOnce(&mut HirModule) -> HirNodeId,
) -> GraphModule
pub fn define( name: impl Into<String>, build: impl FnOnce(&mut HirModule) -> HirNodeId, ) -> GraphModule
Fusion-first model definition at HIR level.
Returns a [GraphModule] at HIR stage; call [GraphModule::lower]
or pass to [rlx_opt::CompilePipeline::compile_module].
Sourcepub fn hir(name: impl Into<String>) -> GraphModule
pub fn hir(name: impl Into<String>) -> GraphModule
Start an empty HIR-stage [GraphModule].
Sourcepub fn module(self) -> GraphModule
pub fn module(self) -> GraphModule
Wrap this MIR graph in a [GraphModule] for pipeline operations.
Sourcepub fn from_hir(hir: HirModule) -> Result<Self, LowerError>
pub fn from_hir(hir: HirModule) -> Result<Self, LowerError>
Lower a HIR module to a MIR graph.
Sourcepub fn has_dynamic_dims(&self) -> bool
pub fn has_dynamic_dims(&self) -> bool
True if any node shape uses a [Dim::Dynamic] symbol.
Sourcepub fn dynamic_symbols(&self) -> Vec<u32>
pub fn dynamic_symbols(&self) -> Vec<u32>
All dynamic symbols referenced in this graph.
Sourcepub fn bind(&self, bindings: &DimBinding) -> Self
pub fn bind(&self, bindings: &DimBinding) -> Self
Specialize symbolic dims to concrete sizes.
Sourcepub fn inspect_module(module: &GraphModule) -> String
pub fn inspect_module(module: &GraphModule) -> String
Stage-aware dump when wrapped in [GraphModule].
Source§impl Graph
impl Graph
Sourcepub fn attention(
&mut self,
q: NodeId,
k: NodeId,
v: NodeId,
mask: NodeId,
num_heads: usize,
head_dim: usize,
shape: Shape,
) -> NodeId
pub fn attention( &mut self, q: NodeId, k: NodeId, v: NodeId, mask: NodeId, num_heads: usize, head_dim: usize, shape: Shape, ) -> NodeId
Scaled dot-product attention with a custom (caller-supplied) mask.
Equivalent to attention_kind(.., MaskKind::Custom, ..).
Sourcepub fn attention_opts(
&mut self,
q: NodeId,
k: NodeId,
v: NodeId,
mask: NodeId,
num_heads: usize,
head_dim: usize,
shape: Shape,
score_scale: Option<f32>,
attn_logit_softcap: Option<f32>,
) -> NodeId
pub fn attention_opts( &mut self, q: NodeId, k: NodeId, v: NodeId, mask: NodeId, num_heads: usize, head_dim: usize, shape: Shape, score_scale: Option<f32>, attn_logit_softcap: Option<f32>, ) -> NodeId
Like Self::attention with optional score scale and logit softcap.
Sourcepub fn attention_kind(
&mut self,
q: NodeId,
k: NodeId,
v: NodeId,
num_heads: usize,
head_dim: usize,
mask_kind: MaskKind,
shape: Shape,
) -> NodeId
pub fn attention_kind( &mut self, q: NodeId, k: NodeId, v: NodeId, num_heads: usize, head_dim: usize, mask_kind: MaskKind, shape: Shape, ) -> NodeId
Scaled dot-product attention with a kernel-synthesized mask
(None / Causal / SlidingWindow). Inputs are Q, K, V only —
no mask tensor is allocated or read in the inner loop. Use
MaskKind::None for a single un-padded sequence.
Sourcepub fn attention_kind_opts(
&mut self,
q: NodeId,
k: NodeId,
v: NodeId,
num_heads: usize,
head_dim: usize,
mask_kind: MaskKind,
shape: Shape,
score_scale: Option<f32>,
attn_logit_softcap: Option<f32>,
) -> NodeId
pub fn attention_kind_opts( &mut self, q: NodeId, k: NodeId, v: NodeId, num_heads: usize, head_dim: usize, mask_kind: MaskKind, shape: Shape, score_scale: Option<f32>, attn_logit_softcap: Option<f32>, ) -> NodeId
Like Self::attention_kind with optional score scale and logit softcap.
Sourcepub fn attention_bias(
&mut self,
q: NodeId,
k: NodeId,
v: NodeId,
bias: NodeId,
num_heads: usize,
head_dim: usize,
shape: Shape,
) -> NodeId
pub fn attention_bias( &mut self, q: NodeId, k: NodeId, v: NodeId, bias: NodeId, num_heads: usize, head_dim: usize, shape: Shape, ) -> NodeId
Scaled dot-product attention with an additive bias tensor of shape
[batch, num_heads, query_len, key_len] added to the
QK^T · scale scores before softmax. Lets boxRPB / per-query
position biases reuse the fast Op::Attention kernel path.
Source§impl Graph
impl Graph
Sourcepub fn relu_backward(&mut self, x: NodeId, dy: NodeId) -> NodeId
pub fn relu_backward(&mut self, x: NodeId, dy: NodeId) -> NodeId
ReLU backward: dx = dy where x > 0 else 0. Output shape matches x.
Sourcepub fn activation_backward(
&mut self,
kind: Activation,
x: NodeId,
dy: NodeId,
) -> NodeId
pub fn activation_backward( &mut self, kind: Activation, x: NodeId, dy: NodeId, ) -> NodeId
Element-wise activation backward — closed-form derivative of
any single-input activation other than ReLU. See
Op::ActivationBackward for the per-kind formulae.
Sourcepub fn layer_norm_backward_input(
&mut self,
x: NodeId,
gamma: NodeId,
dy: NodeId,
axis: i32,
eps: f32,
) -> NodeId
pub fn layer_norm_backward_input( &mut self, x: NodeId, gamma: NodeId, dy: NodeId, axis: i32, eps: f32, ) -> NodeId
LayerNorm backward w.r.t. the input. Inputs [x, gamma, dy].
Output shape matches x. Currently axis = -1 only.
Sourcepub fn rms_norm_backward_input(
&mut self,
x: NodeId,
gamma: NodeId,
beta: NodeId,
dy: NodeId,
axis: i32,
eps: f32,
) -> NodeId
pub fn rms_norm_backward_input( &mut self, x: NodeId, gamma: NodeId, beta: NodeId, dy: NodeId, axis: i32, eps: f32, ) -> NodeId
RMSNorm backward w.r.t. input. Inputs [x, gamma, beta, dy].
pub fn rms_norm_backward_gamma( &mut self, x: NodeId, gamma: NodeId, beta: NodeId, dy: NodeId, axis: i32, eps: f32, ) -> NodeId
pub fn rms_norm_backward_beta( &mut self, x: NodeId, gamma: NodeId, beta: NodeId, dy: NodeId, axis: i32, eps: f32, ) -> NodeId
pub fn rope_backward( &mut self, dy: NodeId, cos: NodeId, sin: NodeId, head_dim: usize, n_rot: usize, ) -> NodeId
pub fn cumsum_backward( &mut self, dy: NodeId, out_shape: Shape, axis: i32, exclusive: bool, ) -> NodeId
pub fn gather_backward( &mut self, dy: NodeId, indices: NodeId, table_shape: Shape, axis: i32, ) -> NodeId
Sourcepub fn group_norm_backward_input(
&mut self,
x: NodeId,
gamma: NodeId,
beta: NodeId,
dy: NodeId,
num_groups: usize,
eps: f32,
) -> NodeId
pub fn group_norm_backward_input( &mut self, x: NodeId, gamma: NodeId, beta: NodeId, dy: NodeId, num_groups: usize, eps: f32, ) -> NodeId
GroupNorm (NCHW) backward w.r.t. input. Inputs [x, gamma, beta, dy].
Sourcepub fn group_norm_backward_gamma(
&mut self,
x: NodeId,
dy: NodeId,
gamma_shape: Shape,
num_groups: usize,
eps: f32,
) -> NodeId
pub fn group_norm_backward_gamma( &mut self, x: NodeId, dy: NodeId, gamma_shape: Shape, num_groups: usize, eps: f32, ) -> NodeId
GroupNorm backward w.r.t. gamma. Inputs [x, dy].
Sourcepub fn group_norm_backward_beta(
&mut self,
x: NodeId,
dy: NodeId,
beta_shape: Shape,
num_groups: usize,
eps: f32,
) -> NodeId
pub fn group_norm_backward_beta( &mut self, x: NodeId, dy: NodeId, beta_shape: Shape, num_groups: usize, eps: f32, ) -> NodeId
GroupNorm backward w.r.t. beta. Inputs [x, dy].
Sourcepub fn layer_norm_backward_gamma(
&mut self,
x: NodeId,
dy: NodeId,
gamma_shape: Shape,
axis: i32,
eps: f32,
) -> NodeId
pub fn layer_norm_backward_gamma( &mut self, x: NodeId, dy: NodeId, gamma_shape: Shape, axis: i32, eps: f32, ) -> NodeId
LayerNorm backward w.r.t. gamma. Inputs [x, dy]. Output shape
is provided by the caller — typically the gamma’s shape, e.g.
[D] for a per-feature 1-D gamma.
Sourcepub fn maxpool2d_backward(
&mut self,
x: NodeId,
dy: NodeId,
kernel_size: Vec<usize>,
stride: Vec<usize>,
padding: Vec<usize>,
) -> NodeId
pub fn maxpool2d_backward( &mut self, x: NodeId, dy: NodeId, kernel_size: Vec<usize>, stride: Vec<usize>, padding: Vec<usize>, ) -> NodeId
2D max-pool backward. x is the original NCHW input; dy is
the upstream gradient with shape matching the pool’s output.
Output shape matches x.
Sourcepub fn conv2d_backward_input(
&mut self,
dy: NodeId,
w: NodeId,
x_shape: Shape,
kernel_size: Vec<usize>,
stride: Vec<usize>,
padding: Vec<usize>,
dilation: Vec<usize>,
groups: usize,
) -> NodeId
pub fn conv2d_backward_input( &mut self, dy: NodeId, w: NodeId, x_shape: Shape, kernel_size: Vec<usize>, stride: Vec<usize>, padding: Vec<usize>, dilation: Vec<usize>, groups: usize, ) -> NodeId
Conv2D backward w.r.t. input. dy has the conv output shape;
w is the forward weight [C_out, C_in/groups, kH, kW]. The
output shape (the original input shape) is supplied by the
caller because it can’t be unambiguously derived from dy.shape
alone in the presence of strides + padding.
Sourcepub fn conv2d_backward_weight(
&mut self,
x: NodeId,
dy: NodeId,
w_shape: Shape,
kernel_size: Vec<usize>,
stride: Vec<usize>,
padding: Vec<usize>,
dilation: Vec<usize>,
groups: usize,
) -> NodeId
pub fn conv2d_backward_weight( &mut self, x: NodeId, dy: NodeId, w_shape: Shape, kernel_size: Vec<usize>, stride: Vec<usize>, padding: Vec<usize>, dilation: Vec<usize>, groups: usize, ) -> NodeId
Conv2D backward w.r.t. weight. Output shape matches the forward
weight [C_out, C_in/groups, kH, kW].
Sourcepub fn softmax_cross_entropy_with_logits(
&mut self,
logits: NodeId,
labels: NodeId,
) -> NodeId
pub fn softmax_cross_entropy_with_logits( &mut self, logits: NodeId, labels: NodeId, ) -> NodeId
Fused softmax + cross-entropy with f32-encoded integer labels.
logits [N, C], labels [N] → [N] per-row loss.
Sourcepub fn softmax_cross_entropy_backward(
&mut self,
logits: NodeId,
labels: NodeId,
d_loss: NodeId,
) -> NodeId
pub fn softmax_cross_entropy_backward( &mut self, logits: NodeId, labels: NodeId, d_loss: NodeId, ) -> NodeId
Backward of softmax_cross_entropy_with_logits.
[logits, labels, d_loss] → dlogits shaped like logits.
Sourcepub fn complex_norm_sq(&mut self, z: NodeId) -> NodeId
pub fn complex_norm_sq(&mut self, z: NodeId) -> NodeId
Element-wise complex squared-magnitude: |z|² = re² + im².
Input must be DType::C64; output is same logical shape but
DType::F32. The canonical real-valued loss surface for
Wirtinger reverse-mode AD on complex graphs.
Sourcepub fn attention_backward(
&mut self,
wrt: AttentionBwdWrt,
q: NodeId,
k: NodeId,
v: NodeId,
dy: NodeId,
num_heads: usize,
head_dim: usize,
mask_kind: MaskKind,
mask: Option<NodeId>,
) -> NodeId
pub fn attention_backward( &mut self, wrt: AttentionBwdWrt, q: NodeId, k: NodeId, v: NodeId, dy: NodeId, num_heads: usize, head_dim: usize, mask_kind: MaskKind, mask: Option<NodeId>, ) -> NodeId
Scaled dot-product attention backward w.r.t. q, k, or v.
See Op::AttentionBackward. When mask_kind is MaskKind::Custom
or MaskKind::Bias, pass the same mask tensor used in forward.
Sourcepub fn attention_backward_all(
&mut self,
q: NodeId,
k: NodeId,
v: NodeId,
dy: NodeId,
num_heads: usize,
head_dim: usize,
mask_kind: MaskKind,
mask: Option<NodeId>,
) -> (NodeId, NodeId, NodeId)
pub fn attention_backward_all( &mut self, q: NodeId, k: NodeId, v: NodeId, dy: NodeId, num_heads: usize, head_dim: usize, mask_kind: MaskKind, mask: Option<NodeId>, ) -> (NodeId, NodeId, NodeId)
Emit dQ, dK, and dV for one Op::Attention forward node.
Sourcepub fn complex_norm_sq_backward(&mut self, z: NodeId, g: NodeId) -> NodeId
pub fn complex_norm_sq_backward(&mut self, z: NodeId, g: NodeId) -> NodeId
Wirtinger backward for [complex_norm_sq]: given upstream g
(real, same shape as the forward output) and the original
complex input z, returns dz = g · z as C64.
Source§impl Graph
impl Graph
Sourcepub fn linear_bias(
&mut self,
input: NodeId,
weight: NodeId,
bias: Option<NodeId>,
) -> NodeId
pub fn linear_bias( &mut self, input: NodeId, weight: NodeId, bias: Option<NodeId>, ) -> NodeId
Dense linear layer: matmul(input, weight) with optional rank-1 bias.
Sourcepub fn linear_bias_act(
&mut self,
input: NodeId,
weight: NodeId,
bias: Option<NodeId>,
activation: Option<Activation>,
) -> NodeId
pub fn linear_bias_act( &mut self, input: NodeId, weight: NodeId, bias: Option<NodeId>, activation: Option<Activation>, ) -> NodeId
Dense linear with optional bias and epilogue activation.
Sourcepub fn linear_fused(
&mut self,
input: NodeId,
weight: NodeId,
bias: NodeId,
activation: Option<Activation>,
out_shape: Shape,
) -> NodeId
pub fn linear_fused( &mut self, input: NodeId, weight: NodeId, bias: NodeId, activation: Option<Activation>, out_shape: Shape, ) -> NodeId
Emit Op::FusedMatMulBiasAct directly — deterministic fusion
without relying on the FuseMatMulBiasAct pass.
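(The following describes a separate shared-input matmul builder whose
signature is missing from this listing.) Emits two matmuls sharing the
same input — the canonical gate+up / QKV pattern for FuseSharedInputMatMul.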
Returns (first, second) in declaration order. For SwiGLU,
pass up weight first and gate weight second so the
post-concat narrow layout matches FuseSwiGLU (up @ 0, gate @ N).
Sourcepub fn swiglu_ffn(
&mut self,
input: NodeId,
up_w: NodeId,
gate_w: NodeId,
down_w: NodeId,
) -> NodeId
pub fn swiglu_ffn( &mut self, input: NodeId, up_w: NodeId, gate_w: NodeId, down_w: NodeId, ) -> NodeId
SwiGLU FFN block: shared-input gate+up → silu(gate) * up → down proj.
Weight order matches FuseSwiGLU’s canonical narrow layout
(up projection first, gate projection second).
Sourcepub fn fused_swiglu_ffn(
&mut self,
input: NodeId,
up_w: NodeId,
gate_w: NodeId,
down_w: NodeId,
out_shape: Shape,
) -> NodeId
pub fn fused_swiglu_ffn( &mut self, input: NodeId, up_w: NodeId, gate_w: NodeId, down_w: NodeId, out_shape: Shape, ) -> NodeId
Fully fused SwiGLU FFN: concat weights → single matmul →
Op::FusedSwiGLU → down projection. Matches the rewrite
performed by FuseSwiGLUDualMatmul
without relying on the pass.
Source§impl Graph
impl Graph
Source§impl Graph
impl Graph
Sourcepub fn binary(
&mut self,
op: BinaryOp,
lhs: NodeId,
rhs: NodeId,
out_shape: Shape,
) -> NodeId
pub fn binary( &mut self, op: BinaryOp, lhs: NodeId, rhs: NodeId, out_shape: Shape, ) -> NodeId
Binary element-wise operation.
Sourcepub fn activation(
&mut self,
act: Activation,
input: NodeId,
shape: Shape,
) -> NodeId
pub fn activation( &mut self, act: Activation, input: NodeId, shape: Shape, ) -> NodeId
Unary activation.
Sourcepub fn quantize(&mut self, x: NodeId, scale: f32, zero_point: i32) -> NodeId
pub fn quantize(&mut self, x: NodeId, scale: f32, zero_point: i32) -> NodeId
Per-tensor INT8 quantization. Output dtype = I8, same shape
otherwise. scale and zero_point apply uniformly to every
element. Use quantize_per_channel when weights deserve
per-channel scales (the standard PTQ improvement).
Sourcepub fn quantize_per_channel(
&mut self,
x: NodeId,
axis: usize,
scales: Vec<f32>,
zero_points: Vec<i32>,
) -> NodeId
pub fn quantize_per_channel( &mut self, x: NodeId, axis: usize, scales: Vec<f32>, zero_points: Vec<i32>, ) -> NodeId
Per-channel INT8 quantization. scales and zero_points must
each have length input.dim(axis); the kernel picks the i-th
pair when quantizing the i-th slice along axis. The most
common usage is axis = 0 for a [C_out, C_in, kH, kW]
conv weight (one scale per output channel).
Source§impl Graph
impl Graph
Sourcepub fn pad_last_axis_to_pow2(&mut self, x: NodeId) -> NodeId
pub fn pad_last_axis_to_pow2(&mut self, x: NodeId) -> NodeId
Zero-pad the last axis to the next power of two (no-op when already pow2).
Sourcepub fn split_spectrum(&mut self, spectrum: NodeId) -> (NodeId, NodeId)
pub fn split_spectrum(&mut self, spectrum: NodeId) -> (NodeId, NodeId)
Split a 2N real-block spectrum into separate real / imag tensors.
Sourcepub fn fft_real(&mut self, x: NodeId, norm: FftNorm) -> (NodeId, NodeId)
pub fn fft_real(&mut self, x: NodeId, norm: FftNorm) -> (NodeId, NodeId)
Real-input FFT (gpu-fft fft): auto zero-pads to pow2, returns (re, im).
Sourcepub fn fft_batch_real(&mut self, x: NodeId, norm: FftNorm) -> (NodeId, NodeId)
pub fn fft_batch_real(&mut self, x: NodeId, norm: FftNorm) -> (NodeId, NodeId)
Batched real-input FFT — same as fft_real when the last axis is signal
length; leading axes are independent batch dimensions.
Sourcepub fn rfft(&mut self, x: NodeId, norm: FftNorm) -> (NodeId, NodeId)
pub fn rfft(&mut self, x: NodeId, norm: FftNorm) -> (NodeId, NodeId)
Real-input FFT with half-spectrum output (n_pad/2 + 1 complex bins).
The input is zero-padded to the next power of two along the last axis
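before the transform. This is equivalent to NumPy rfft(x, n=n_pad); NumPy itself zero-pads only to an explicitly requested n and does not round up to a power of two on its own.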
Sourcepub fn irfft(
&mut self,
re_half: NodeId,
im_half: NodeId,
n: usize,
norm: FftNorm,
) -> NodeId
pub fn irfft( &mut self, re_half: NodeId, im_half: NodeId, n: usize, norm: FftNorm, ) -> NodeId
Inverse real FFT from half-spectrum (re, im) with Hermitian symmetry.
Mirrors the conjugate half of the spectrum (excluding DC and Nyquist) before
calling Self::ifft_spectrum, then truncates to length n.
Sourcepub fn stft(
&mut self,
x: NodeId,
frame_len: usize,
hop: usize,
norm: FftNorm,
) -> NodeId
pub fn stft( &mut self, x: NodeId, frame_len: usize, hop: usize, norm: FftNorm, ) -> NodeId
Short-time Fourier transform: [..., T] → [frames, ..., 2·half] (re/im block per frame).
Frames of length frame_len are taken every hop samples along the last axis, and each frame is rfft’d.
Sourcepub fn fft_conv1d(
&mut self,
a: NodeId,
b: NodeId,
n_fft: usize,
norm: FftNorm,
) -> NodeId
pub fn fft_conv1d( &mut self, a: NodeId, b: NodeId, n_fft: usize, norm: FftNorm, ) -> NodeId
1D convolution via the convolution theorem (rfft → complex multiply → irfft).
Both inputs are zero-padded to at least n_fft (or the next power of two covering
len(a) + len(b) - 1 when n_fft is small).
Sourcepub fn fftfreq_tensor(&mut self, n: usize) -> NodeId
pub fn fftfreq_tensor(&mut self, n: usize) -> NodeId
Constant tensor of FFT sample frequencies (length n, f64).
Sourcepub fn rfftfreq_tensor(&mut self, n: usize) -> NodeId
pub fn rfftfreq_tensor(&mut self, n: usize) -> NodeId
Constant tensor of rFFT sample frequencies (length n/2 + 1, f64).
Sourcepub fn psd_real(&mut self, x: NodeId, norm: FftNorm) -> NodeId
pub fn psd_real(&mut self, x: NodeId, norm: FftNorm) -> NodeId
Power spectral density from real input: rfft → psd.
Source§impl Graph
impl Graph
Sourcepub fn input(&mut self, name: impl Into<String>, shape: Shape) -> NodeId
pub fn input(&mut self, name: impl Into<String>, shape: Shape) -> NodeId
Graph input (runtime-provided tensor).
Sourcepub fn param(&mut self, name: impl Into<String>, shape: Shape) -> NodeId
pub fn param(&mut self, name: impl Into<String>, shape: Shape) -> NodeId
Model parameter (weight loaded at init).
Sourcepub fn add_node(&mut self, op: Op, inputs: Vec<NodeId>, shape: Shape) -> NodeId
pub fn add_node(&mut self, op: Op, inputs: Vec<NodeId>, shape: Shape) -> NodeId
Generic node constructor for custom ops.
Sourcepub fn custom_op(
&mut self,
name: impl Into<String>,
attrs: Vec<u8>,
inputs: Vec<NodeId>,
) -> NodeId
pub fn custom_op( &mut self, name: impl Into<String>, attrs: Vec<u8>, inputs: Vec<NodeId>, ) -> NodeId
Build an Op::Custom node, dispatching shape inference through
the global op registry. The named op must already be registered
via crate::register_op; attrs is forwarded verbatim to
the impl’s infer_shape (and later, at execution time, to its
per-backend kernel).
Panics if name is not registered or if inputs.len() does
not match the registered num_inputs() — both are programmer
errors that should fail loudly at graph-build time, not silently
at execution.
Sourcepub fn custom_op_packed(
&mut self,
name: impl Into<String>,
attrs: Vec<u8>,
inputs: Vec<NodeId>,
out_shape: Shape,
) -> NodeId
pub fn custom_op_packed( &mut self, name: impl Into<String>, attrs: Vec<u8>, inputs: Vec<NodeId>, out_shape: Shape, ) -> NodeId
Build an Op::Custom node with a caller-supplied output shape,
bypassing the registry’s infer_shape. Use this for ops
whose output shape can’t be determined by static input shapes
alone — most importantly, ops with multiple logical outputs
packed into one buffer.
The canonical multi-output pattern:
// Sparse-LU returns L_values + U_values packed end-to-end.
// Caller knows nnz_L and nnz_U from the symbolic factor.
let lu = g.custom_op_packed(
"sparse_lu",
attrs,
vec![A, b],
Shape::new(&[nnz_L + nnz_U], DType::F64),
);
let l_vals = g.narrow_(lu, 0, 0, nnz_L);
let u_vals = g.narrow_(lu, 0, nnz_L, nnz_U);
The op must still be registered (so that num_inputs validation
and autodiff routing still work); only the shape is overridden.
Sourcepub fn fft(&mut self, x: NodeId, inverse: bool) -> NodeId
pub fn fft(&mut self, x: NodeId, inverse: bool) -> NodeId
1D FFT along the last axis.
- F32 / F64 — 2N real-block layout: the last axis is [re…, im…].
- C64 — interleaved [re, im] pairs, one per complex element.
Output shape matches input. Radix-2 when N is a power of two,
Bluestein otherwise. Default normalization is unnormalized
(FftNorm::Backward; ifft(fft(x)) = N·x).
Sourcepub fn fft_norm(&mut self, x: NodeId, inverse: bool, norm: FftNorm) -> NodeId
pub fn fft_norm(&mut self, x: NodeId, inverse: bool, norm: FftNorm) -> NodeId
1D FFT with explicit normalization mode.
Sourcepub fn fft_axis(&mut self, x: NodeId, axis: usize, inverse: bool) -> NodeId
pub fn fft_axis(&mut self, x: NodeId, axis: usize, inverse: bool) -> NodeId
1D FFT along an arbitrary axis. Lowers to
Transpose(axis ↔ last) → Fft(last) → Transpose(last ↔ axis).
AD is free: both Op::Transpose and Op::Fft have VJP/JVP rules.
Sourcepub fn fftn(&mut self, x: NodeId, axes: &[usize], inverse: bool) -> NodeId
pub fn fftn(&mut self, x: NodeId, axes: &[usize], inverse: bool) -> NodeId
N-dimensional FFT along axes (NumPy fftn semantics).
Applies a 1D FFT along each listed axis in ascending order.
Empty axes is a no-op. For multi-axis transforms on tensors
with more than one spatial dimension, use DType::C64; the
F32/F64 2N-block layout only describes a single complex axis.
Source§impl Graph
impl Graph
Sourcepub fn matmul(&mut self, lhs: NodeId, rhs: NodeId, out_shape: Shape) -> NodeId
pub fn matmul(&mut self, lhs: NodeId, rhs: NodeId, out_shape: Shape) -> NodeId
Matrix multiply.
Sourcepub fn dense_solve(&mut self, a: NodeId, b: NodeId, out_shape: Shape) -> NodeId
pub fn dense_solve(&mut self, a: NodeId, b: NodeId, out_shape: Shape) -> NodeId
Dense linear solve x = A⁻¹·b. A must be [N, N]; b is
[N] for a single right-hand side or [N, K] for multiple.
out_shape matches b’s shape.
Sourcepub fn batched_dense_solve(
&mut self,
a: NodeId,
b: NodeId,
out_shape: Shape,
) -> NodeId
pub fn batched_dense_solve( &mut self, a: NodeId, b: NodeId, out_shape: Shape, ) -> NodeId
Batched dense linear solve. A is [B, N, N]; b is
[B, N] (single-RHS) or [B, N, K] (multi-RHS). Per-batch
independent — each slice solved as a separate dense_solve.
Typically constructed by vmap of dense_solve.
Sourcepub fn lora_matmul(
&mut self,
x: NodeId,
w: NodeId,
a: NodeId,
b: NodeId,
scale: f32,
shape: Shape,
) -> NodeId
pub fn lora_matmul( &mut self, x: NodeId, w: NodeId, a: NodeId, b: NodeId, scale: f32, shape: Shape, ) -> NodeId
Fused LoRA matmul: out = x·W + scale * (x·A)·B. Inputs: x [m, k], w [k, n], a [k, r], b [r, n]. r is the LoRA rank; scale is the alpha/rank coefficient.
Sourcepub fn dequant_matmul(
&mut self,
x: NodeId,
w_q: NodeId,
scale: NodeId,
zp: NodeId,
scheme: QuantScheme,
shape: Shape,
) -> NodeId
pub fn dequant_matmul( &mut self, x: NodeId, w_q: NodeId, scale: NodeId, zp: NodeId, scheme: QuantScheme, shape: Shape, ) -> NodeId
Fused dequant + matmul. See Op::DequantMatMul for per-scheme
input layout (4 inputs for legacy/NVFP4, 2 for GGUF).
Sourcepub fn dequant_matmul_packed(
&mut self,
x: NodeId,
packed_w: NodeId,
scheme: QuantScheme,
shape: Shape,
) -> NodeId
pub fn dequant_matmul_packed( &mut self, x: NodeId, packed_w: NodeId, scheme: QuantScheme, shape: Shape, ) -> NodeId
GGUF / K-quant packed weights — [x, packed_w_bytes] only.
Sourcepub fn dequant_matmul_nvfp4(
&mut self,
x: NodeId,
w_q: NodeId,
block_scales: NodeId,
global_scale: NodeId,
shape: Shape,
) -> NodeId
pub fn dequant_matmul_nvfp4( &mut self, x: NodeId, w_q: NodeId, block_scales: NodeId, global_scale: NodeId, shape: Shape, ) -> NodeId
NVFP4 (E2M1) block matmul — group size 16, FP8 block scales, optional f32 global scale (defaults to 1.0 when unset at runtime).
Sourcepub fn fused_matmul_bias_act(
&mut self,
input: NodeId,
weight: NodeId,
bias: NodeId,
activation: Option<Activation>,
shape: Shape,
) -> NodeId
pub fn fused_matmul_bias_act( &mut self, input: NodeId, weight: NodeId, bias: NodeId, activation: Option<Activation>, shape: Shape, ) -> NodeId
Fused matmul + bias + activation (created by optimization passes).
Sourcepub fn q_matmul(
&mut self,
x: NodeId,
w: NodeId,
bias: NodeId,
x_zp: i32,
w_zp: i32,
out_zp: i32,
mult: f32,
out_shape: Shape,
) -> NodeId
pub fn q_matmul( &mut self, x: NodeId, w: NodeId, bias: NodeId, x_zp: i32, w_zp: i32, out_zp: i32, mult: f32, out_shape: Shape, ) -> NodeId
Real INT8-arithmetic matmul: i8 inputs, i32 bias, i8 output.
mult = x_scale · w_scale / out_scale. The caller is responsible
for ensuring the input dtypes are correct — the builder only sets
the output shape with dtype = I8, since that is what the kernel writes.
Sourcepub fn q_conv2d(
&mut self,
x: NodeId,
w: NodeId,
bias: NodeId,
kernel_size: Vec<usize>,
stride: Vec<usize>,
padding: Vec<usize>,
dilation: Vec<usize>,
groups: usize,
x_zp: i32,
w_zp: i32,
out_zp: i32,
mult: f32,
out_shape: Shape,
) -> NodeId
pub fn q_conv2d( &mut self, x: NodeId, w: NodeId, bias: NodeId, kernel_size: Vec<usize>, stride: Vec<usize>, padding: Vec<usize>, dilation: Vec<usize>, groups: usize, x_zp: i32, w_zp: i32, out_zp: i32, mult: f32, out_shape: Shape, ) -> NodeId
Real INT8-arithmetic 2-D convolution. NCHW layout matching
Op::Conv. mult = x_scale · w_scale / out_scale.
Source§impl Graph
impl Graph
Sourcepub fn layer_norm2d(
&mut self,
input: NodeId,
gamma: NodeId,
beta: NodeId,
eps: f32,
) -> NodeId
pub fn layer_norm2d( &mut self, input: NodeId, gamma: NodeId, beta: NodeId, eps: f32, ) -> NodeId
LayerNorm2d on NCHW (normalize across channels at each spatial position).
Sourcepub fn group_norm(
&mut self,
input: NodeId,
gamma: NodeId,
beta: NodeId,
num_groups: usize,
eps: f32,
) -> NodeId
pub fn group_norm( &mut self, input: NodeId, gamma: NodeId, beta: NodeId, num_groups: usize, eps: f32, ) -> NodeId
Group normalization on NCHW.
Sourcepub fn layer_norm(
&mut self,
input: NodeId,
gamma: NodeId,
beta: NodeId,
axis: i32,
eps: f32,
shape: Shape,
) -> NodeId
pub fn layer_norm( &mut self, input: NodeId, gamma: NodeId, beta: NodeId, axis: i32, eps: f32, shape: Shape, ) -> NodeId
Layer normalization.
Source§impl Graph
impl Graph
Source§impl Graph
impl Graph
Source§impl Graph
impl Graph
Sourcepub fn selective_scan(
&mut self,
x: NodeId,
delta: NodeId,
a: NodeId,
b: NodeId,
c: NodeId,
state_size: usize,
shape: Shape,
) -> NodeId
pub fn selective_scan( &mut self, x: NodeId, delta: NodeId, a: NodeId, b: NodeId, c: NodeId, state_size: usize, shape: Shape, ) -> NodeId
Mamba-style selective scan: y = SSM(x, Δ, A, B, C). Inputs: x [b,s,h], delta [b,s,h], a [h,n], b [b,s,n], c [b,s,n]. Output [b,s,h]. n is the state size.
Sourcepub fn gated_delta_net(
&mut self,
q: NodeId,
k: NodeId,
v: NodeId,
g: NodeId,
beta: NodeId,
state_size: usize,
shape: Shape,
) -> NodeId
pub fn gated_delta_net( &mut self, q: NodeId, k: NodeId, v: NodeId, g: NodeId, beta: NodeId, state_size: usize, shape: Shape, ) -> NodeId
Gated DeltaNet linear-attention scan (Qwen3.5/3.6 trunk,
Qwen3-Next, Kimi-Linear). See Op::GatedDeltaNet for the
recurrence math. All five inputs are f32. Shapes:
q,k,v: [b, s, h_v, n]; g,beta: [b, s, h_v]. Output:
[b, s, h_v, n]. State is implicit (reset per batch) unless
carry_state is set — in that case the state is passed as a sixth
input (see Self::gated_delta_net_carry).
Sourcepub fn gated_delta_net_carry(
&mut self,
q: NodeId,
k: NodeId,
v: NodeId,
g: NodeId,
beta: NodeId,
state: NodeId,
state_size: usize,
shape: Shape,
) -> NodeId
pub fn gated_delta_net_carry( &mut self, q: NodeId, k: NodeId, v: NodeId, g: NodeId, beta: NodeId, state: NodeId, state_size: usize, shape: Shape, ) -> NodeId
Same as Self::gated_delta_net but threads state
[b, h_v, n, n] in/out for decode-mode recurrence.
Sourcepub fn scan(&mut self, init: NodeId, body: Graph, length: u32) -> NodeId
pub fn scan(&mut self, init: NodeId, body: Graph, length: u32) -> NodeId
Bounded scan returning the final carry. Body must have exactly
one Op::Input (the carry) and one output, both same shape as
init. Output shape matches init.
Sourcepub fn scan_checkpointed(
&mut self,
init: NodeId,
body: Graph,
length: u32,
num_checkpoints: u32,
) -> NodeId
pub fn scan_checkpointed( &mut self, init: NodeId, body: Graph, length: u32, num_checkpoints: u32, ) -> NodeId
Bounded scan with recursive checkpointing for memory-bounded
backward AD. Equivalent to Self::scan for the forward
computation, but during backward only num_checkpoints carry
values are cached; intermediate carries are recomputed via the
body. Memory: O(num_checkpoints · carry_size). Time: forward
unchanged; backward O(length) (segment-cached).
The AD pre-pass propagates num_checkpoints into the rewritten
trajectory-saving Scan and into the emitted ScanBackward, so a
single call to crate::Graph::scan_checkpointed is enough
to enable the memory bound across the whole forward+backward
pipeline.
Sourcepub fn scan_with_bcasts_and_xs(
&mut self,
init: NodeId,
bcasts: &[NodeId],
xs: &[NodeId],
body: Graph,
length: u32,
) -> NodeId
pub fn scan_with_bcasts_and_xs( &mut self, init: NodeId, bcasts: &[NodeId], xs: &[NodeId], body: Graph, length: u32, ) -> NodeId
Bounded scan with broadcast and per-step inputs.
The body's Op::Inputs, in NodeId order, are: [carry, bcast_0..bcast_{B-1}, x_t_0..x_t_{X-1}], where B = bcasts.len() and X = xs.len(). Bcast inputs keep their natural shape (the
CPU executor fills them once before the scan loop). xs[i] has
shape [length, *per_step] and the body sees xs[i][t] on
iteration t. Output shape matches init.
Sourcepub fn scan_with_xs(
&mut self,
init: NodeId,
xs: &[NodeId],
body: Graph,
length: u32,
) -> NodeId
pub fn scan_with_xs( &mut self, init: NodeId, xs: &[NodeId], body: Graph, length: u32, ) -> NodeId
Bounded scan with per-step xs inputs returning the final carry.
Body has 1 + xs.len() Op::Inputs in NodeId construction order
(first declared is the carry; the remaining match xs in order).
Each xs[i] has shape [length, *per_step_shape_i]; the body
sees a per_step_shape_i slice on iteration t.
Sourcepub fn scan_backward(
&mut self,
init: NodeId,
trajectory: NodeId,
upstream: NodeId,
xs: &[NodeId],
body_vjp: Graph,
length: u32,
save_trajectory: bool,
out_shape: Shape,
) -> NodeId
pub fn scan_backward( &mut self, init: NodeId, trajectory: NodeId, upstream: NodeId, xs: &[NodeId], body_vjp: Graph, length: u32, save_trajectory: bool, out_shape: Shape, ) -> NodeId
Reverse-mode AD companion to Self::scan /
Self::scan_trajectory. Typically constructed by the
autodiff pass, not by hand.
xs is the list of per-step input tensors (must match the
forward Op::Scan’s xs in count, order, and per-step shape).
body_vjp has 1 + xs.len() + 1 Op::Inputs: the forward body’s
inputs plus a fresh "d_output" Input.
Sourcepub fn scan_backward_with_checkpoints(
&mut self,
init: NodeId,
trajectory: NodeId,
upstream: NodeId,
xs: &[NodeId],
body_vjp: Graph,
length: u32,
save_trajectory: bool,
num_checkpoints: u32,
forward_body: Option<Graph>,
out_shape: Shape,
) -> NodeId
pub fn scan_backward_with_checkpoints( &mut self, init: NodeId, trajectory: NodeId, upstream: NodeId, xs: &[NodeId], body_vjp: Graph, length: u32, save_trajectory: bool, num_checkpoints: u32, forward_body: Option<Graph>, out_shape: Shape, ) -> NodeId
Lower-level scan_backward with explicit checkpointing config.
num_checkpoints == 0 (default) means no checkpointing — the
trajectory cache holds every step’s carry. 0 < K < length
enables segment-cached recompute via forward_body (must be
Some).
Sourcepub fn scan_backward_xs(
&mut self,
init: NodeId,
trajectory: NodeId,
upstream: NodeId,
xs: &[NodeId],
body_vjp: Graph,
length: u32,
save_trajectory: bool,
xs_idx: u32,
out_shape: Shape,
) -> NodeId
pub fn scan_backward_xs( &mut self, init: NodeId, trajectory: NodeId, upstream: NodeId, xs: &[NodeId], body_vjp: Graph, length: u32, save_trajectory: bool, xs_idx: u32, out_shape: Shape, ) -> NodeId
Per-step xs gradient companion to Self::scan_backward.
Same inputs and same body_vjp graph, plus an xs_idx
selecting which body_vjp output to stack into the result.
Output shape is [length, *per_step_xs_shape].
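pub fn scan_backward_xs_with_checkpoints( &mut self, init: NodeId, trajectory: NodeId, upstream: NodeId, xs: &[NodeId], body_vjp: Graph, length: u32, save_trajectory: bool, xs_idx: u32, num_checkpoints: u32, forward_body: Option<Graph>, out_shape: Shape, ) -> NodeId
Lower-level Self::scan_backward_xs with an explicit checkpointing
config; num_checkpoints and forward_body follow the same rules as
in Self::scan_backward_with_checkpoints.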
Sourcepub fn custom_fn(
&mut self,
inputs: Vec<NodeId>,
fwd_body: Graph,
vjp_body: Option<Graph>,
jvp_body: Option<Graph>,
) -> NodeId
pub fn custom_fn( &mut self, inputs: Vec<NodeId>, fwd_body: Graph, vjp_body: Option<Graph>, jvp_body: Option<Graph>, ) -> NodeId
User-defined sub-graph with optional override AD rules.
JAX-shaped custom_vjp / custom_jvp — see Op::CustomFn.
inputs.len() must equal the number of Op::Input nodes in
fwd_body. Output shape is inferred from fwd_body’s declared
output. When supplied, vjp_body and jvp_body must follow the
conventions documented on Op::CustomFn (special-named
"primal_output" / "d_output" / "tangent_*" Inputs).
Sourcepub fn custom_fn_multi(
&mut self,
inputs: Vec<NodeId>,
fwd_body: Graph,
) -> MultiOutputHandle
pub fn custom_fn_multi( &mut self, inputs: Vec<NodeId>, fwd_body: Graph, ) -> MultiOutputHandle
Multi-output custom_fn via the concat-with-Narrow design:
rewrites fwd_body to flatten + concat its K declared outputs
into a single 1-D F32 output, wraps that as Op::CustomFn,
and returns a MultiOutputHandle the caller uses to extract
each sub-output via Op::Narrow + Op::Reshape.
Per PLAN line 484, this avoids rewriting rlx’s “1 Op = 1 output”
IR contract: the wrapped Op::CustomFn still has one output (the
flat concat), and MultiOutputHandle::output(g, i) materializes
component i lazily on the outer graph.
Constraints (MVP):
- All sub-outputs must be DType::F32. Tuples of mixed dtype need
  either a per-dtype split or a future tuple-type extension.
- All sub-output shapes must be statically known (no Dim::Dynamic).
- vjp_body/jvp_body aren’t yet rewritten through the concat — the
  caller must provide bodies that already expect the flat-concat
  output convention if they need custom AD.
Sourcepub fn scan_trajectory(
&mut self,
init: NodeId,
body: Graph,
length: u32,
) -> NodeId
pub fn scan_trajectory( &mut self, init: NodeId, body: Graph, length: u32, ) -> NodeId
Bounded scan returning the stacked trajectory.
Output shape is [length, *init.shape] — row t is the carry
after step t+1, so row length-1 equals the result of plain
Self::scan.
Source§impl Graph
impl Graph
Sourcepub fn gaussian_splat_render(
&mut self,
inputs: GaussianSplatInputs,
params: GaussianSplatRenderParams,
) -> NodeId
pub fn gaussian_splat_render( &mut self, inputs: GaussianSplatInputs, params: GaussianSplatRenderParams, ) -> NodeId
First-class CPU reference Gaussian splat forward render.
See Op::GaussianSplatRender for the seven-input contract and
GaussianSplatRenderParams for framebuffer settings.
Sourcepub fn gaussian_splat_render_meta(
&mut self,
camera_position: [f32; 3],
camera_target: [f32; 3],
camera_up: [f32; 3],
fov_y_degrees: f32,
near: f32,
far: f32,
background: [f32; 3],
params: GaussianSplatRenderParams,
) -> NodeId
pub fn gaussian_splat_render_meta( &mut self, camera_position: [f32; 3], camera_target: [f32; 3], camera_up: [f32; 3], fov_y_degrees: f32, near: f32, far: f32, background: [f32; 3], params: GaussianSplatRenderParams, ) -> NodeId
Build the 23-float meta vector expected by Op::GaussianSplatRender.
Sourcepub fn gaussian_splat_prepare(
&mut self,
inputs: GaussianSplatInputs,
params: GaussianSplatRenderParams,
) -> NodeId
pub fn gaussian_splat_prepare( &mut self, inputs: GaussianSplatInputs, params: GaussianSplatRenderParams, ) -> NodeId
Strict IR stage 1: project + bin + sort + rays → packed prepare buffer.
Sourcepub fn gaussian_splat_rasterize(
&mut self,
prep: NodeId,
meta: NodeId,
params: GaussianSplatRenderParams,
) -> NodeId
pub fn gaussian_splat_rasterize( &mut self, prep: NodeId, meta: NodeId, params: GaussianSplatRenderParams, ) -> NodeId
Strict IR stage 2: rasterize from prepare buffer + meta.
Sourcepub fn gaussian_splat_render_decomposed(
&mut self,
inputs: GaussianSplatInputs,
params: GaussianSplatRenderParams,
) -> NodeId
pub fn gaussian_splat_render_decomposed( &mut self, inputs: GaussianSplatInputs, params: GaussianSplatRenderParams, ) -> NodeId
Decomposed strict-IR forward: prepare → rasterize.
Sourcepub fn gaussian_splat_render_backward(
&mut self,
inputs: GaussianSplatInputs,
d_loss_rgba: NodeId,
params: GaussianSplatBackwardParams,
) -> NodeId
pub fn gaussian_splat_render_backward( &mut self, inputs: GaussianSplatInputs, d_loss_rgba: NodeId, params: GaussianSplatBackwardParams, ) -> NodeId
Backward pass for Op::GaussianSplatRender (packed scene gradients).
Trait Implementations§
Source§impl From<Graph> for GraphModule
impl From<Graph> for GraphModule
Source§impl GraphExt for Graph
impl GraphExt for Graph
fn mm(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId
fn add(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId
fn sub(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId
fn mul(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId
fn div(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId
fn gelu(&mut self, x: NodeId) -> NodeId
Source§fn gelu_approx(&mut self, x: NodeId) -> NodeId
fn gelu_approx(&mut self, x: NodeId) -> NodeId
Tanh-approximated GELU (the tanh-form gelu formula,
also candle’s Tensor::gelu). Use this when porting models
whose reference implementations use the tanh form for
numerical parity (e.g. DINOv2, many ViTs).
fn silu(&mut self, x: NodeId) -> NodeId
fn relu(&mut self, x: NodeId) -> NodeId
fn exp(&mut self, x: NodeId) -> NodeId
fn sqrt(&mut self, x: NodeId) -> NodeId
fn neg(&mut self, x: NodeId) -> NodeId
fn tanh(&mut self, x: NodeId) -> NodeId
fn ln(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId
fn layer_norm2d( &mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32, ) -> NodeId
fn group_norm( &mut self, x: NodeId, gamma: NodeId, beta: NodeId, num_groups: usize, eps: f32, ) -> NodeId
fn conv2d( &mut self, input: NodeId, weight: NodeId, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], groups: usize, ) -> NodeId
fn conv_transpose2d( &mut self, input: NodeId, weight: NodeId, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], output_padding: [usize; 2], groups: usize, ) -> NodeId
fn rms_norm( &mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32, ) -> NodeId
fn sum(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId
fn mean(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId
fn sm(&mut self, x: NodeId, axis: i32) -> NodeId
fn reshape_(&mut self, x: NodeId, new_shape: Vec<i64>) -> NodeId
fn transpose_(&mut self, x: NodeId, perm: Vec<usize>) -> NodeId
fn narrow_( &mut self, x: NodeId, axis: usize, start: usize, len: usize, ) -> NodeId
fn concat_(&mut self, inputs: Vec<NodeId>, axis: usize) -> NodeId
fn gather_(&mut self, table: NodeId, indices: NodeId, axis: usize) -> NodeId
fn eq(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId
fn lt(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId
fn attention_( &mut self, q: NodeId, k: NodeId, v: NodeId, mask: NodeId, num_heads: usize, head_dim: usize, ) -> NodeId
fn rope( &mut self, x: NodeId, cos: NodeId, sin: NodeId, head_dim: usize, ) -> NodeId
Source§fn rope_n(
&mut self,
x: NodeId,
cos: NodeId,
sin: NodeId,
head_dim: usize,
n_rot: usize,
) -> NodeId
fn rope_n( &mut self, x: NodeId, cos: NodeId, sin: NodeId, head_dim: usize, n_rot: usize, ) -> NodeId
Partial RoPE: rotates only n_rot of the head_dim dims (NeoX offset n_rot/2).