/// An operation in the RLX IR graph.
///
/// Operations are categorized for fusion analysis:
/// - Element-wise ops fuse with anything reading their output
/// - Matmul/Conv are BLAS-dispatched and form fusion boundaries
/// - Reductions are fusion roots (drive the loop iteration)
pub enum Op {
// NOTE(review): the next line reads like rendered-rustdoc UI text, not Rust — confirm and drop before compiling.
Show 92 variants
/// Model input with a name (shape on the Node).
Input {
name: String,
},
/// Model parameter (weight/bias) with a name.
Param {
name: String,
},
/// Constant tensor embedded in the graph.
Constant {
data: Vec<u8>,
},
/// Unary activation: one input, same shape output.
Activation(Activation),
/// Cast to a different dtype.
Cast {
to: DType,
},
/// INT8 quantization. Input f32; output i8, same shape:
/// `q[i] = saturate_i8(round(x[i] / scale[c]) + zero_point[c])`.
Quantize {
/// `Some(d)` selects per-channel parameters along dim `d`; `None` is per-tensor (index 0).
axis: Option<usize>,
/// Length 1 for per-tensor, `input.dim(d)` for per-channel.
scales: Vec<f32>,
/// Same length contract as `scales`.
zero_points: Vec<i32>,
},
/// INT8 dequantization (inverse of `Op::Quantize`). Input i8; output f32, same shape:
/// `x[i] = (q[i] - zero_point[c]) * scale[c]`, with `c` selected by `axis` as in `Quantize`.
Dequantize {
axis: Option<usize>,
scales: Vec<f32>,
zero_points: Vec<i32>,
},
/// "Fake-quantize" op for quantization-aware training (QAT). Input f32; output f32,
/// same shape. Computes a symmetric max-abs scale on the fly (zero-point always 0),
/// then quantizes and dequantizes.
FakeQuantize {
/// Bit width; `q_max` is 127 for 8, 7 for 4, 1 for 2 (ternary).
bits: u8,
/// Per-axis scale when `Some`, per-tensor when `None`.
axis: Option<usize>,
/// Straight-through-estimator variant used by the backward pass.
ste: SteKind,
scale_mode: ScaleMode,
},
/// Learned Step Size Quantization (LSQ). Like `FakeQuantize`, but the scale is a
/// learned parameter passed as the second input. Inputs: `[x, scale]`.
FakeQuantizeLSQ {
bits: u8,
/// `scale` is `[chan_dim]` for `Some`, `[1]` for `None`.
axis: Option<usize>,
},
/// Backward of `Op::FakeQuantizeLSQ` producing `dx`. Inputs: `[x, scale, dy]`;
/// output has the shape of `x`.
FakeQuantizeLSQBackwardX {
bits: u8,
axis: Option<usize>,
},
/// Companion to `FakeQuantizeLSQBackwardX`: the per-channel scale gradient.
/// Inputs: `[x, scale, dy]`; output shape matches `scale`.
FakeQuantizeLSQBackwardScale {
bits: u8,
axis: Option<usize>,
},
/// Binary op with broadcasting: two inputs, output shape is the broadcast result.
Binary(BinaryOp),
/// Element-wise comparison: two inputs, Bool output.
Compare(CmpOp),
/// Select elements: cond (Bool), on_true, on_false -> output.
Where,
/// Fused element-wise region holding an N-step chain of element-wise operations.
/// The output is the last step's result.
ElementwiseRegion {
chain: Vec<ChainStep>,
/// Inputs are referenced by index `0..num_inputs`.
num_inputs: u32,
/// Per-input bitfield: bit `i` set means input `i` is a scalar broadcast (shape `[1]`).
scalar_input_mask: u32,
/// Per-input element count used for modulo broadcast: 0 = no broadcast,
/// 1 = scalar, any other value = input tiles by modulo.
input_modulus: [u32; 16],
},
/// Matrix multiply: `[.., M, K] x [.., K, N] -> [.., M, N]`; batch dims broadcast.
MatMul,
/// Matrix multiply with explicit batch/contracting dimensions (like XLA's DotGeneral).
DotGeneral {
lhs_contracting: Vec<usize>,
rhs_contracting: Vec<usize>,
lhs_batch: Vec<usize>,
rhs_batch: Vec<usize>,
},
/// Batched dense linear solve. Inputs: `A [B, N, N]`, `b [B, N]` or `[B, N, K]`;
/// output has the shape of `b`.
BatchedDenseSolve,
/// Dense linear solve `x = A^-1 * b` via LU factorization. Inputs: `A [N, N]`,
/// `b [N]` or `[N, K]`; output has the shape of `b`.
DenseSolve,
/// Layer normalization: input, gamma, beta -> normalized output.
LayerNorm {
/// Feature dimension (usually -1).
axis: i32,
eps: f32,
},
/// Group normalization on NCHW tensors: input, gamma, beta -> same shape.
GroupNorm {
num_groups: usize,
eps: f32,
},
/// LayerNorm2d on NCHW: normalizes across the channel axis at each spatial position.
LayerNorm2d {
eps: f32,
},
/// Nearest-neighbor 2x upsample on NCHW (doubles spatial dims 2 and 3).
ResizeNearest2x,
/// RMS normalization: input, gamma -> normalized output.
RmsNorm {
axis: i32,
eps: f32,
},
/// Scaled dot-product attention: Q, K, V, [mask] -> output.
Attention {
num_heads: usize,
head_dim: usize,
/// Controls how masking is applied; `Custom` reads the 4th input tensor.
mask_kind: MaskKind,
/// When `Some(s)`, scores are multiplied by `s` instead of `1/sqrt(head_dim)`.
score_scale: Option<f32>,
/// When `Some(c)`, applies `c * tanh(s/c)` to scores before softmax.
attn_logit_softcap: Option<f32>,
},
/// Rotary position embedding applied to one tensor: x, cos, sin -> x_rotated.
Rope {
/// Per-head width.
head_dim: usize,
/// Number of leading dims that get rotated; trailing dims are copied unchanged.
n_rot: usize,
},
/// SAM2 axial 2-D RoPE on `[batch, seq, num_heads * head_dim]`.
AxialRope2d {
end_x: usize,
end_y: usize,
head_dim: usize,
num_heads: usize,
theta: f32,
repeat_factor: usize,
},
Reshape {
new_shape: Vec<i64>,
},
Transpose {
perm: Vec<usize>,
},
/// Select a contiguous slice along an axis.
Narrow {
axis: usize,
start: usize,
len: usize,
},
/// Concatenate along an axis.
Concat {
axis: usize,
},
/// Expand (broadcast) to a target shape.
Expand {
target_shape: Vec<i64>,
},
/// Gather elements by index along an axis (embedding lookup).
Gather {
axis: usize,
},
/// Reduce along specified axes.
Reduce {
op: ReduceOp,
axes: Vec<usize>,
keep_dim: bool,
},
/// Selective scan: Mamba-style state-space model step.
/// Inputs (in order): x, delta, a, b, c. Output: `[b, s, h]` f32.
SelectiveScan {
/// State dimension `n`; exposed for the cost model.
state_size: usize,
},
/// Gated DeltaNet linear-attention recurrence.
/// Inputs: q, k, v, g, beta. Output: `[b, s, h_v, n]` f32.
GatedDeltaNet {
state_size: usize,
/// When true, a sixth input `state [b, h_v, n, n]` supplies the initial state
/// and receives the final state in place.
carry_state: bool,
},
/// Fused dequant + matmul. The input count depends on `scheme`
/// (2 for GGUF K-quant schemes, 4 otherwise). Output: `[m, n]` f32.
DequantMatMul {
scheme: QuantScheme,
},
/// INT8-arithmetic matrix multiply with i32 accumulation (2-D only).
/// Inputs: x `[M, K]` i8, w `[K, N]` i8, bias `[N]` i32. Output: `[M, N]` i8.
QMatMul {
x_zp: i32,
w_zp: i32,
out_zp: i32,
/// Requantize multiplier: `x_scale * w_scale / out_scale`.
mult: f32,
},
/// INT8-arithmetic 2-D convolution with i32 accumulation (NCHW).
/// Same requantize math as `Op::QMatMul`.
QConv2d {
kernel_size: Vec<usize>,
stride: Vec<usize>,
padding: Vec<usize>,
dilation: Vec<usize>,
groups: usize,
x_zp: i32,
w_zp: i32,
out_zp: i32,
mult: f32,
},
/// Fused LoRA matmul: `out = x*W + scale * x*A*B`.
/// Inputs (in order): x `[m, k]`, w `[k, n]`, a `[k, r]`, b `[r, n]`.
LoraMatMul {
/// Per-adapter `alpha / rank` knob.
scale: f32,
},
/// Fused sampling kernel: logits -> optional top-k -> optional top-p -> softmax ->
/// multinomial sample. Output shape `[batch]`, f32-encoded token ids.
Sample {
/// 0 disables the top-k filter.
top_k: usize,
/// 1.0 disables top-p truncation.
top_p: f32,
temperature: f32,
/// Philox seed; 0 means "use process-global counter".
seed: u64,
},
/// Cumulative sum along an axis. Same shape in/out.
Cumsum {
axis: i32,
/// When true, shifts the result so `output[0] = 0`.
exclusive: bool,
},
/// Softmax along an axis (reduction + element-wise).
Softmax {
axis: i32,
},
/// Top-K indices along the last axis. Output shape `[..., k]`, f32-encoded indices.
TopK {
k: usize,
},
/// Indexed batched matmul (the MoE GEMM primitive).
/// Inputs: `[input, weight, expert_idx]`; output `[M, N]`.
GroupedMatMul,
/// Fused GGUF K-quant dequant + `Op::GroupedMatMul`; `weight` is a packed U8 tensor.
DequantGroupedMatMul {
scheme: QuantScheme,
},
/// Dequantizes a packed MoE expert stack to F32 `[num_experts, K, N]`.
DequantMoEWeights {
scheme: QuantScheme,
},
/// Scatter-add into a zero-initialized destination. Inputs: `[updates, indices]`;
/// the output row count comes from the node's declared output shape.
ScatterAdd,
/// 2D convolution on NCHW tensors. Weight layout: `[C_out, C_in / groups, kH, kW]`.
Conv {
kernel_size: Vec<usize>,
stride: Vec<usize>,
padding: Vec<usize>,
dilation: Vec<usize>,
groups: usize,
},
/// 2D transposed convolution on NCHW. Weight layout: `[C_in, C_out / groups, kH, kW]`.
ConvTranspose2d {
kernel_size: Vec<usize>,
stride: Vec<usize>,
padding: Vec<usize>,
dilation: Vec<usize>,
output_padding: Vec<usize>,
groups: usize,
},
Pool {
kind: ReduceOp,
kernel_size: Vec<usize>,
stride: Vec<usize>,
padding: Vec<usize>,
},
/// ReLU backward: `dx = dy where x > 0 else 0`. Inputs: `[x, dy]`.
ReluBackward,
/// Element-wise complex squared magnitude. Input C64; output F32, same shape.
ComplexNormSq,
/// Element-wise complex conjugate. Input C64; output C64, same shape.
Conjugate,
/// Backward for `Op::ComplexNormSq`. Inputs: `[z (C64), g (F32)]`; output matches `z`.
ComplexNormSqBackward,
/// LayerNorm backward w.r.t. the input. Inputs: `[x, gamma, dy]`; output shape = x.shape.
LayerNormBackwardInput {
axis: i32,
eps: f32,
},
/// LayerNorm backward w.r.t. gamma. Inputs: `[x, dy]`; output shape = gamma.shape.
LayerNormBackwardGamma {
axis: i32,
eps: f32,
},
/// RMSNorm backward w.r.t. input. Inputs `[x, gamma, beta, dy]`; output = x.shape.
RmsNormBackwardInput {
axis: i32,
eps: f32,
},
/// RMSNorm backward w.r.t. gamma. Inputs `[x, gamma, beta, dy]`; output = gamma.shape.
RmsNormBackwardGamma {
axis: i32,
eps: f32,
},
/// RMSNorm backward w.r.t. beta. Inputs `[x, gamma, beta, dy]`; output = beta.shape.
RmsNormBackwardBeta {
axis: i32,
eps: f32,
},
/// RoPE backward w.r.t. x. Inputs `[dy, cos, sin]`; output = dy.shape.
RopeBackward {
head_dim: usize,
n_rot: usize,
},
/// GroupNorm (NCHW) backward w.r.t. input. Inputs `[x, gamma, beta, dy]`.
GroupNormBackwardInput {
num_groups: usize,
eps: f32,
},
/// GroupNorm backward w.r.t. gamma. Inputs `[x, dy]`; output = gamma.shape.
GroupNormBackwardGamma {
num_groups: usize,
eps: f32,
},
/// GroupNorm backward w.r.t. beta. Inputs `[x, dy]`; output = beta.shape.
GroupNormBackwardBeta {
num_groups: usize,
eps: f32,
},
/// Cumsum backward along `axis`. Inputs `[dy]`; output matches the forward input shape.
CumsumBackward {
axis: i32,
exclusive: bool,
},
/// Gather backward (scatter-add into table). Inputs `[dy, indices]`; output = table shape.
GatherBackward {
// NOTE(review): `i32` here, while the forward `Gather` carries `axis: usize` — confirm this is intended.
axis: i32,
},
/// Generic element-wise activation backward. Inputs: `[x, dy]`; output shape matches `x`.
ActivationBackward {
/// Selects which activation's closed-form derivative is applied.
kind: Activation,
},
/// Backward for `Op::FakeQuantize` under a non-default STE. Inputs `[x, dy]`;
/// `bits`/`axis`/`ste` must match the forward op.
FakeQuantizeBackward {
bits: u8,
axis: Option<usize>,
ste: SteKind,
},
/// 2D max-pool backward. Inputs: `[x, dy]`; output has the shape of `x`.
/// Carries the forward pool's geometry so the argmax can be recomputed.
MaxPool2dBackward {
kernel_size: Vec<usize>,
stride: Vec<usize>,
padding: Vec<usize>,
},
/// 2D conv backward w.r.t. input. Inputs: `[dy, w]`; output `[N, C_in, H, W]`
/// (declared on the node). Geometry fields are the forward conv's parameters.
Conv2dBackwardInput {
kernel_size: Vec<usize>,
stride: Vec<usize>,
padding: Vec<usize>,
dilation: Vec<usize>,
groups: usize,
},
/// 2D conv backward w.r.t. weight. Inputs: `[x, dy]`;
/// output `[C_out, C_in/groups, kH, kW]`.
Conv2dBackwardWeight {
kernel_size: Vec<usize>,
stride: Vec<usize>,
padding: Vec<usize>,
dilation: Vec<usize>,
groups: usize,
},
/// Fused softmax + cross-entropy loss with f32-encoded integer targets.
/// Inputs: `[logits [N, C], labels [N]]`; output `[N]`.
SoftmaxCrossEntropyWithLogits,
/// Backward of `SoftmaxCrossEntropyWithLogits`. Inputs: `[logits, labels, d_loss]`;
/// output `[N, C]`.
SoftmaxCrossEntropyBackward,
/// Backward of `Op::Attention`; emits one of dQ, dK, or dV.
/// Inputs: `[q, k, v, dy]` plus an optional mask.
AttentionBackward {
num_heads: usize,
head_dim: usize,
mask_kind: MaskKind,
/// Selects which gradient (dQ, dK, or dV) this node produces.
wrt: AttentionBwdWrt,
},
/// Fused matmul + bias + activation. Created from MatMul -> Add -> Activation.
FusedMatMulBiasAct {
activation: Option<Activation>,
},
/// Fused residual + optional bias + layer norm.
FusedResidualLN {
has_bias: bool,
eps: f32,
},
/// Fused residual + optional bias + RMS norm.
FusedResidualRmsNorm {
has_bias: bool,
eps: f32,
},
/// Fused SwiGLU: split input into up/gate halves, `silu(gate) * up`.
FusedSwiGLU {
/// Optional output dtype; when `Some(dt)` the kernel casts its result to `dt`.
cast_to: Option<DType>,
gate_first: bool,
},
/// Fused full transformer layer: attention block + residual+LN + FFN + residual+LN.
/// Output has the shape of `hidden`.
FusedTransformerLayer {
num_heads: usize,
head_dim: usize,
intermediate_size: usize,
eps1: f32,
eps2: f32,
activation: Activation,
has_bias: bool,
},
/// Fused attention block: QKV projection -> split -> [RoPE] -> SDPA -> output projection.
FusedAttentionBlock {
num_heads: usize,
head_dim: usize,
/// When true, `qkv_b` and `out_b` inputs are present.
has_bias: bool,
/// When true, `rope_cos` and `rope_sin` inputs are present.
has_rope: bool,
},
/// Conditional: picks between two subgraphs based on a boolean predicate.
/// Inputs: `[predicate, ...captures]`; both branches must produce identically-shaped outputs.
If {
then_branch: Box<Graph>,
else_branch: Box<Graph>,
},
/// Loop: iterates `body` while `cond` evaluates true.
/// Inputs are the initial loop-carried values; outputs are the values after the final iteration.
While {
/// Sub-graph whose single output is a Bool scalar.
cond: Box<Graph>,
/// Sub-graph whose outputs become the next iteration's loop-carried inputs.
body: Box<Graph>,
max_iterations: Option<usize>,
},
/// Bounded-length loop with a fixed-shape carry, optional per-step inputs, and
/// optional stacked output. Mirrors JAX's `lax.scan`.
Scan {
body: Box<Graph>,
length: u32,
/// false: output is the final carry; true: output is the stacked trajectory of carries.
save_trajectory: bool,
/// Number of inputs that are constant across iterations (0 keeps the carry+xs shape).
num_bcast: u32,
/// Number of per-step stacked inputs.
num_xs: u32,
/// Number of trajectory checkpoints when `save_trajectory` is true; 0 saves all rows.
num_checkpoints: u32,
},
/// Reverse-mode AD companion to `Op::Scan` that extracts the carry gradient `dinit`.
/// Inputs (in order): `[init, trajectory, upstream, xs_0, ..., xs_{num_xs-1}]`.
ScanBackward {
body_vjp: Box<Graph>,
length: u32,
save_trajectory: bool,
num_xs: u32,
/// When `0 < K < length`, the trajectory input holds only `K` checkpoint rows.
num_checkpoints: u32,
/// Used to recompute intermediate carries between checkpoints.
forward_body: Option<Box<Graph>>,
},
/// Companion to `ScanBackward` that extracts one stacked per-step `dxs_i`.
/// Same inputs and same `body_vjp` graph as `ScanBackward`.
ScanBackwardXs {
body_vjp: Box<Graph>,
length: u32,
save_trajectory: bool,
num_xs: u32,
/// Selects which `body_vjp` output to stack into the result.
xs_idx: u32,
num_checkpoints: u32,
forward_body: Option<Box<Graph>>,
},
/// CPU reference 3D Gaussian splat forward render. Takes seven flat F32 inputs
/// (scene buffers plus camera/render meta).
GaussianSplatRender {
width: u32,
height: u32,
tile_size: u32,
radius_scale: f32,
alpha_cutoff: f32,
max_splat_steps: u32,
transmittance_threshold: f32,
max_list_entries: u32,
},
// NOTE(review): presumably the backward companion of `GaussianSplatRender` — no description visible here; confirm.
GaussianSplatRenderBackward {
width: u32,
height: u32,
tile_size: u32,
radius_scale: f32,
alpha_cutoff: f32,
max_splat_steps: u32,
transmittance_threshold: f32,
max_list_entries: u32,
loss_grad_clip: f32,
sh_band: u32,
max_anisotropy: f32,
},
GaussianSplatPrepare {
width: u32,
height: u32,
tile_size: u32,
radius_scale: f32,
alpha_cutoff: f32,
max_splat_steps: u32,
transmittance_threshold: f32,
max_list_entries: u32,
},
GaussianSplatRasterize {
width: u32,
height: u32,
tile_size: u32,
alpha_cutoff: f32,
max_splat_steps: u32,
transmittance_threshold: f32,
max_list_entries: u32,
},
Custom {
name: String,
num_inputs: u32,
attrs: Vec<u8>,
},
Fft {
inverse: bool,
},
CustomFn {
fwd_body: Box<Graph>,
vjp_body: Option<Box<Graph>>,
jvp_body: Option<Box<Graph>>,
num_inputs: u32,
},
}Expand description
An operation in the RLX IR graph.
Operations are categorized for fusion analysis:
- Element-wise ops fuse with anything reading their output
- Matmul/Conv are BLAS-dispatched and form fusion boundaries
- Reductions are fusion roots (drive the loop iteration)
Variants§
Input
Model input with a name (shape on the Node).
Param
Model parameter (weight/bias) with a name.
Constant
Constant tensor embedded in the graph.
Activation(Activation)
Unary activation: one input, same shape output.
Cast
Cast to a different dtype.
Quantize
INT8 quantization. Input f32; output i8 same shape.
q[i] = saturate_i8(round(x[i] / scale[c]) + zero_point[c])
where c selects the per-channel scale/zp when axis = Some(d)
(c = idx[d]), or always uses index 0 when axis = None
(per-tensor). The scales / zero_points payload length must
match 1 for per-tensor and input.dim(d) for per-channel.
Static — typically produced at calibration time and baked
into the loaded model. Use Op::Dequantize for the inverse.
Dequantize
INT8 dequantization (inverse of Op::Quantize). Input i8;
output f32 same shape.
x[i] = (q[i] - zero_point[c]) · scale[c]
where c is selected by axis exactly as in Op::Quantize.
FakeQuantize
“Fake-quantize” op for quantization-aware training (QAT).
Input f32; output f32 same shape. Forward computes a per-axis
(or per-tensor when axis = None) max-abs scale on the fly:
s[c] = max(|x[..., c, ...]|) / q_max(bits)
then quantizes-then-dequantizes:
out[i] = clamp(round(x[i] / s[c]), -q_max, q_max) * s[c]
where q_max is 127 for bits=8, 7 for bits=4, 1 for
bits=2 (ternary). Symmetric only — zero-point is always 0.
The point of this op is to make the SGD optimizer “see” the
deployment-time rounding during training. Backward is the
straight-through estimator (STE): the gradient passes
through (variant chosen by ste), ignoring the discontinuity
at the round. Without STE the rounding would have zero
gradient almost everywhere and learning would stop.
Inserted by the trainer on conv / FC weight tensors when
--qat is on; the existing Op::Quantize / packing path at
the end of training still handles the deployment-side
conversion to i8/i4/i2 codes.
FakeQuantizeLSQ
Learned Step Size Quantization (LSQ; Esser et al. 2020,
arXiv:1902.08153). Like FakeQuantize but the per-channel
scale is a learned parameter, passed as the second input.
Forward is identical to FakeQuantize with a fixed scale:
out[i] = clamp(round(x[i]/s[c]), -q_max, q_max) * s[c]
Backward computes both dx (STE) and dscale[c] via the
closed-form gradient:
dscale[c] = sum_i ψ(x[i]/s[c]) · upstream[i]
where ψ(z) = -z + round(z) if |z| ≤ q_max else
sign(z) · q_max. Routinely beats per-batch and EMA at
tight bit widths (i2 / i3).
Inputs: [x, scale]. scale is [chan_dim] f32 (matches
axis); for axis = None it’s [1].
FakeQuantizeLSQBackwardX
Backward pass for Op::FakeQuantizeLSQ with respect to x. Computes
the gradient w.r.t. x via the straight-through estimator (STE).
Output shape matches x; the closed-form gradient w.r.t. scale is
computed separately by the companion FakeQuantizeLSQBackwardScale op.
Inputs: [x, scale, dy]. Output: dx, same shape as x.
FakeQuantizeLSQBackwardScale
Companion to FakeQuantizeLSQBackwardX: computes the
[chan_dim] per-channel scale gradient. Inputs [x, scale, dy].
Output shape matches scale.
Binary(BinaryOp)
Binary op with broadcasting: two inputs, output shape is broadcast result.
Compare(CmpOp)
Element-wise comparison: two inputs, Bool output.
Where
Select elements: cond (Bool), on_true, on_false → output.
ElementwiseRegion
Fused element-wise region (PLAN L2). Holds an N-step chain of
element-wise operations. Inputs are referenced by index 0..num_inputs;
each step’s result can be referenced by later steps via
ChainOperand::Step(idx). The output is the last step’s result.
Emitted by MarkElementwiseRegions in rlx-opt from chains of
Activation/Cast/Binary/Compare/Where ops with single-consumer
intermediates and broadcast-compatible shapes. Backends that
don’t have a region kernel can decompose back to the original
chain via unfuse::unfuse_elementwise_regions.
scalar_input_mask is a per-input bitfield (bit i set ⇒
input i is a scalar broadcast — has shape [1]). Kept as a
fast-path indicator that lets kernels skip the modulo entirely
when they detect a scalar.
input_modulus[i] is the per-input element count, used by
kernels to compute arena[input_offs[i] + (gid % input_modulus[i])]
— the trailing-shape broadcast pattern. 0 means “no broadcast”
(input matches the output element count; kernel reads gid
directly). 1 means scalar; any other value means the input
has fewer elements than the output and they tile by modulo.
The encoder only allows broadcasts where out_elems % in_elems == 0,
so the modulo divides cleanly. This lets chains include bias /
scale / eps / mask factors that previously broke the chain at
a Binary op with mismatched shapes.
MatMul
Matrix multiply. Inputs: [.., M, K] × [.., K, N] → [.., M, N]. Batch dimensions are broadcast.
DotGeneral
Matrix multiply with explicit dimension specification. Like XLA’s DotGeneral — handles arbitrary batch/contracting dims.
Fields
BatchedDenseSolve
Batched dense linear solve. Inputs: A [B, N, N],
b [B, N] or b [B, N, K]. Output: same shape as b.
Per-batch independent solve — each A[i] and b[i] are
solved as a separate Op::DenseSolve. Emitted by vmap of
Op::DenseSolve. The CPU lowering loops over the batch
dimension calling dgesv per slice (LAPACK doesn’t expose a
batched solve on Accelerate; cuSOLVER does on NVIDIA).
DenseSolve
Dense linear solve x = A⁻¹ · b via LU factorization.
Inputs: A [N, N], b [N] (or b [N, K] for multi-RHS).
Output: same shape as b.
VJP via the implicit-function theorem:
dx = solve(Aᵀ, upstream)
dA = -outer(dx, x) (x is the forward output)
db = dx
The rule is dtype-agnostic; lowering is per-backend (Accelerate
dgesv / sgesv, cuSOLVER, etc.).
LayerNorm
Layer normalization: input, gamma, beta → normalized output.
axis is the feature dimension (usually -1).
GroupNorm
Group normalization on NCHW tensors: input, gamma, beta → same shape.
Normalizes over (C/num_groups) × H × W per group.
LayerNorm2d
LayerNorm2d on NCHW: normalize across the channel axis at each spatial
position (candle / SAM LayerNorm2d semantics — not PyTorch’s H×W norm).
ResizeNearest2x
Nearest-neighbor 2× upsample on NCHW (doubles spatial dims 2 and 3).
RmsNorm
RMS normalization: input, gamma → normalized output.
Attention
Scaled dot-product attention: Q, K, V, [mask] → output.
The compiler can lower this to fused SDPA or flash attention.
mask_kind controls how masking is applied — Custom reads from
the 4th input tensor; None / Causal / SlidingWindow skip the
mask load and apply the mask directly in the inner loop. See
MaskKind for the rationale.
score_scale: when Some(s), dot-product scores are multiplied by
s instead of the default 1/sqrt(head_dim) (Gemma uses head_dim^-0.5
explicitly in config). attn_logit_softcap: when Some(c), applies
c * tanh(s/c) to scores before softmax (Gemma 2).
Fields
Rope
Rotary position embedding applied to one tensor: x, cos, sin → x_rotated.
Apply separately to Q and K. head_dim is the per-head width; n_rot
is how many leading dims get NeoX RoPE (pair offset n_rot/2). When
n_rot < head_dim, trailing dims are copied unchanged (Qwen3.5 MRoPE).
AxialRope2d
SAM2 axial 2-D RoPE on [batch, seq, num_heads * head_dim].
Reshape
Transpose
Narrow
Select a contiguous slice along an axis.
Concat
Concatenate along an axis.
Expand
Expand (broadcast) to a target shape.
Gather
Gather elements by index along an axis (embedding lookup).
Reduce
Reduce along specified axes.
SelectiveScan
Selective scan (plan #15) — Mamba-style state-space model
step. The recurrence:
h[t] = exp(Δ[t] * A) * h[t-1] + Δ[t] * B[t] * x[t]
y[t] = C[t] * h[t]
where state h has dimension state_size and the input has
(batch, seq, hidden).
Inputs (in order):
x [b, s, h] f32 input
delta [b, s, h] f32 step size (per-position, per-channel)
a [h, n] f32 transition matrix (one per channel)
b [b, s, n] f32 input projection
c [b, s, n] f32 output projection
Output: [b, s, h] f32. State h is implicit; the kernel
scans through the seq dimension carrying it.
state_size = n is exposed for the cost model.
GatedDeltaNet
Gated DeltaNet linear-attention recurrence — the per-layer
kernel used by Qwen3.5/3.6 trunk “linear attention” blocks
(and Qwen3-Next, Kimi-Linear). Mirrors
llama.cpp / src/models/delta-net-base.cpp autoregressive
path; chunked + fused variants ride the same op identity.
Math (per token t, head h, state size n):
state matrix S[h, i, j] is implicit (reset per batch).
S[h] *= exp(g[t,h]) # scalar gate
sk[h,j] = Σ_i S[h,i,j] * k[t,h,i]
d[h,j] = (v[t,h,j] - sk[h,j]) * b[t,h] # b = beta
S[h,i,j] += k[t,h,i] * d[h,j] # outer-prod
o[t,h,j] = Σ_i S[h,i,j] * (q[t,h,i] / √n)
Inputs:
q [b, s, h_v, n] f32 queries (L2-normed by caller)
k [b, s, h_v, n] f32 keys (L2-normed by caller;
GQA-repeated to match h_v)
v [b, s, h_v, n] f32 values
g [b, s, h_v] f32 log-gate (exp’d inside kernel)
beta [b, s, h_v] f32 delta-rule mixing factor
Output: [b, s, h_v, n] f32.
When carry_state is true, a sixth input state [b, h_v, n, n]
provides the initial SSM matrix per head; the kernel updates it
in place across the sequence and leaves the final state in the
same buffer (same layout as the internal scan state:
state[h, i, j] row-major over (n, n) per head).
DequantMatMul
Fused dequant + matmul (plan #5). The biggest LLM-bandwidth win on Apple Silicon: dequantizes weights inside the matmul inner loop, never materializing f32 weights.
BREAKING CHANGE in 0.2.0: num_inputs() is now
scheme-dependent — 4 for legacy Int8 schemes, 2 for
the new GGUF K-quant schemes (their scales/mins live inside
the packed bytes, so no side-channel scale / zp tensors
are fed in). Callers that assumed a fixed 4-input contract
must inspect scheme.is_gguf() before reading inputs.
Inputs (Int8 schemes — scheme.is_gguf() == false):
x [m, k] f32 activations
w_q [k, n] packed quantized weight bytes (i8 per
element for Int8 schemes; 4-bit
packed two-per-byte for Int4)
scale [k/block, n] per-block f32 dequant scale
zp [k/block, n] per-block f32 zero-point
(zero-tensor if symmetric)
Inputs (Nvfp4Block — fixed group size 16 along K):
x [m, k] f32 activations
w_q [k,n/2] packed FP4 E2M1 codes (unsigned nibble 0..15)
scale [k/16, n] u8 FP8 E4M3 block scales (one byte / group)
global_scale [1] f32 per-tensor scale (pass [1.0] if unused)
Inputs (GGUF schemes — scheme.is_gguf() == true):
x [m, k] f32 activations
packed_w [bytes] raw GGUF super-block bytes; the
dequantizer reads the per-sub-block
scales / mins / quants directly out
of the buffer per the K-quant block
layout (no side tensors).
Output: [m, n] f32.
block_size (on the Int8 schemes only) is the number of
consecutive elements that share one (scale, zero_point) pair.
The Op carries enough metadata that the kernel doesn’t need
a separate QuantMap lookup at run time.
Fields
scheme: QuantScheme
QMatMul
Real INT8-arithmetic matrix multiply with i32 accumulation.
Inputs (in order):
x [M, K] i8 activations (zero-point = x_zp)
w [K, N] i8 weights (zero-point = w_zp)
bias [N] i32 (in accumulator scale = x_scale·w_scale),
pass a zeros tensor for “no bias”
Output: [M, N] i8 (zero-point = out_zp)
Per-element compute:
out[m,n] = requantize(bias[n] + Σₖ (x[m,k]-x_zp)·(w[k,n]-w_zp), mult, out_zp)
where mult = x_scale · w_scale / out_scale.
This is the same kernel shape rlx-cortexm/src/dense.rs
uses for on-device int8 inference, lifted into the IR so the
rlx-cpu backend can run a quantized graph directly (instead
of round-tripping through fake-quant Dequantize → MatMul →
Quantize). 2-D only — generalizing to batched comes when a
real workload demands it.
QConv2d
Real INT8-arithmetic 2-D convolution with i32 accumulation.
Inputs:
x [N, C_in, H, W] i8 (zero-point = x_zp)
w [C_out, C_in/groups, kH, kW] i8 (zero-point = w_zp)
bias [C_out] i32 in accumulator scale
Output: [N, C_out, H_out, W_out] i8 (zero-point = out_zp).
Same NCHW geometry contract as Op::Conv; same requantize
math as Op::QMatMul (per-element acc·mult rounded to i8).
Fields
LoraMatMul
Fused LoRA matmul: out = x·W + scale * x·A·B.
Inputs (in order): x [m, k], w [k, n], a [k, r], b [r, n].
r is the LoRA rank (typically 4-64). scale is the
per-adapter alpha / rank knob.
Plan #9: lifts LoRA from “three matmuls + an add” into one
kernel that keeps the rank-r intermediate in registers.
Sample
Fused sampling kernel: logits → optional top-k filter →
optional top-p truncation → softmax → multinomial sample.
One f32-encoded sampled token id per batch row (output
shape [batch]).
temperature == 1.0 matches a plain argmax-of-softmax;
lower → sharper, higher → flatter. top_k == 0 disables.
top_p == 1.0 disables. seed is the Philox seed; pass 0
for “use process-global counter” (still deterministic
given the call order).
Borrowed from MAX’s nn/sampling.mojo (#42 in PLAN.md).
Latency-critical: never materializes the full softmax
distribution on the host.
Cumsum
Cumulative sum along an axis (inclusive unless exclusive=true). Same shape in/out.
Underpins ragged-tensor offsets, sampling (top-p prefix sum),
and sequence-position math (#44 in PLAN.md).
exclusive=true shifts the result so output[0] = 0 (useful
for offset arrays where the first segment starts at 0).
Softmax
Softmax along an axis (reduction + element-wise).
TopK
Top-K indices along the last axis. Output shape [..., k],
f32-encoded indices (rlx is f32-only at the I/O boundary).
To recover the values, follow with a Gather against the
original tensor — works because Gather already supports any axis.
Ties broken by smaller index (matches NumPy / PyTorch
torch.topk(..., largest=True, sorted=True)).
Used by MoE gating; also useful for beam search.
GroupedMatMul
Indexed batched matmul. The MoE GEMM primitive.
Inputs: [input, weight, expert_idx]
input : [M, K] — per-token activations
weight : [num_experts, K, N] — stacked expert weights
expert_idx : [M] — f32-encoded expert id per token
Output : [M, N] — output[i] = input[i] @ weight[expert_idx[i]]
Naive impl on both backends; future work can replace with a
segmented/grouped GEMM when there’s a real workload.
DequantGroupedMatMul
Fused GGUF K-quant dequant + Op::GroupedMatMul. Same three
inputs as GroupedMatMul, but weight is a U8 tensor holding
num_experts contiguous packed expert slabs (GGML layout, expert
dimension outermost). Scales live inside the packed bytes.
Fields
scheme: QuantScheme
DequantMoEWeights
Dequant a packed MoE expert stack to F32 [num_experts, K, N] in
GroupedMatMul layout. Input: U8 packed bytes; output shape is
declared on the node ([E, K, N]).
Fields
scheme: QuantScheme
ScatterAdd
Scatter-add into a destination tensor. The “unpermute” half of
MoE routing (also useful for embedding gradient updates).
Inputs: [updates, indices]
updates : [num_updates, trailing] — values to add
indices : [num_updates] — f32-encoded destination row
Output : [out_dim, trailing] — output[indices[i]] += updates[i]
out_dim is taken from the node’s declared output shape.
Initial output is zero; multiple updates to the same row
accumulate (sequentially on CPU; with atomic-add on Metal).
Conv
2D convolution on NCHW tensors. Also exposed as OpKind::Conv / conv2d.
Weight layout: [C_out, C_in / groups, kH, kW].
Fields
ConvTranspose2d
2D transposed convolution on NCHW. Weight layout (PyTorch):
[C_in, C_out / groups, kH, kW].
Fields
Pool
ReluBackward
ReLU backward: dx = dy where x > 0 else 0.
Inputs: [x, dy] — both same shape; output matches.
ComplexNormSq
Element-wise complex squared-magnitude: |z|² = z.re² + z.im².
Input: 1 tensor with DType::C64. Output: same shape but
DType::F32. The natural real-valued loss surface for
Wirtinger reverse-mode AD on complex graphs — pair with
Op::ComplexNormSqBackward.
Conjugate
Element-wise complex conjugate: z̄ = z.re - i·z.im per element.
Input: 1 tensor with DType::C64. Output: same shape, same dtype.
Used by Wirtinger VJP rules on Op::Binary over C64 (the rule
for y = a·b is dL/dā = upstream · conj(b), etc.).
ComplexNormSqBackward
Backward for Op::ComplexNormSq under Wirtinger calculus.
f(z) = |z|² = z·z̄, so ∂f/∂z̄ = z. Given upstream real
cotangent g (same shape as the forward output), the C64
gradient with respect to z is g · z element-wise, returned
in C64 storage [re_g·re_z, re_g·im_z] per element.
Inputs: [z (C64), g (F32)] — both same logical shape; output
matches z (C64).
LayerNormBackwardInput
LayerNorm backward w.r.t. the input. Computes
d_x[..., d] = inv_std · (dy·γ - mean(dy·γ) - x̂·mean(dy·γ·x̂))
over the feature axis, where x̂ = (x - mean)/std is recomputed
inline from x. Inputs: [x, gamma, dy]; output shape = x.shape.
Currently lowers axis=-1 only (matches the forward thunk).
LayerNormBackwardGamma
LayerNorm backward w.r.t. gamma. Computes
d_gamma[d] = Σ_{batch} dy[..., d] · x̂[..., d]
— sums the per-element product of upstream and the (recomputed)
normalized input over the leading axes. Inputs: [x, dy];
output shape = gamma.shape (= 1-D feature axis).
RmsNormBackwardInput
RMSNorm backward w.r.t. input. Inputs [x, gamma, beta, dy]; output = x.shape.
RmsNormBackwardGamma
RMSNorm backward w.r.t. gamma. Inputs [x, gamma, beta, dy]; output = gamma.shape.
RmsNormBackwardBeta
RMSNorm backward w.r.t. beta. Inputs [x, gamma, beta, dy]; output = beta.shape.
RopeBackward
RoPE backward w.r.t. x. Inputs [dy, cos, sin]; output = dy.shape.
GroupNormBackwardInput
GroupNorm (NCHW) backward w.r.t. input. Inputs [x, gamma, beta, dy].
GroupNormBackwardGamma
GroupNorm backward w.r.t. gamma. Inputs [x, dy]; output = gamma.shape.
GroupNormBackwardBeta
GroupNorm backward w.r.t. beta. Inputs [x, dy]; output = beta.shape.
CumsumBackward
Cumsum backward along axis. Inputs [dy]; output matches forward input shape.
GatherBackward
Gather backward (scatter-add into table). Inputs [dy, indices]; output = table shape.
axis matches forward Op::Gather.
ActivationBackward
Generic element-wise activation backward. kind selects the
closed-form derivative d/dx act(x). Inputs: [x, dy]; output
shape matches x. The kernel computes d/dx · dy per element.
Closed forms (all element-wise):
- Gelu — exact derivative of erf-based GELU.
- GeluApprox — derivative of the tanh approximation 0.5 x (1 + tanh(√(2/π)(x + 0.044715 x³))).
- Silu — σ(x)·(1 + x·(1 - σ(x))).
- Sigmoid — σ(x)·(1 - σ(x)).
- Tanh — 1 - tanh(x)².
- Exp — exp(x).
- Log — 1 / x.
- Sqrt — 0.5 / sqrt(x).
- Rsqrt — -0.5 · x^(-3/2).
- Neg — -1.
- Abs — sign(x) (zero at x=0).
- Sin — cos(x).
- Cos — -sin(x).
- Tan — 1 + tan²(x).
- Atan — 1 / (1 + x²).
- Relu — kept here for completeness; the dedicated ReluBackward op is preferred for relu and is what the autodiff pass actually emits.
Fields
kind: Activation
FakeQuantizeBackward
Backward for Op::FakeQuantize under a non-default STE.
Inputs [x, dy]: the forward input and the upstream
gradient. Output dx same shape. The bits/axis/ste
fields must match the forward op so the kernel computes the
same per-channel scale and applies the right STE attenuation.
For SteKind::Identity this op is unnecessary — autodiff
just routes upstream through unchanged.
MaxPool2dBackward
2D max-pool backward. Routes each element of dy back into the
position in x’s window where the forward max was taken.
Inputs: [x, dy] with x [N, C, H, W] and
dy [N, C, H_out, W_out]. Output: same shape as x.
Carries the forward pool’s geometry so the kernel can recompute
the argmax position per window without a saved-indices tensor.
Conv2dBackwardInput
2D conv backward w.r.t. input. Computes dx = conv_transpose(dy, w).
Inputs: [dy, w] with dy [N, C_out, H_out, W_out] and
w [C_out, C_in/groups, kH, kW]. Output: [N, C_in, H, W]
(declared on the node — caller knows the original input shape).
Geometry is the forward conv’s parameters, not the transposed
conv’s.
Fields
Conv2dBackwardWeight
2D conv backward w.r.t. weight. Computes
dw[c_out, c_in, kh, kw] = sum_{n,h_out,w_out} x[n,c_in,...] * dy[n,c_out,h_out,w_out].
Inputs: [x, dy]. Output: [C_out, C_in/groups, kH, kW].
Fields
SoftmaxCrossEntropyWithLogits
Fused softmax + cross-entropy loss with integer (f32-encoded)
targets — the standard classification loss. Per-row output:
loss[n] = -log(softmax(logits[n])[labels[n]]).
Inputs: [logits, labels] with logits [N, C] and
labels [N] (f32-encoded class indices). Output: [N].
Caller does the Reduce::Mean if they want a scalar.
SoftmaxCrossEntropyBackward
Backward of the fused loss above. Emits
dlogits[n,c] = (softmax(logits[n])[c] - one_hot(labels)[n,c]) * d_loss[n].
Inputs: [logits, labels, d_loss]. Output: [N, C] (same shape
as logits). Recomputes the softmax inline rather than threading
it through from the forward node.
AttentionBackward
Backward of Op::Attention. Recomputes scaled QK^T, applies
the same mask_kind as the forward op, softmaxes scores, then
emits one of dQ, dK, or dV selected by AttentionBwdWrt.
Autodiff emits three nodes (one per wrt) so each output shape
stays a normal single-output MIR node.
Inputs: [q, k, v, dy] plus optional mask when mask_kind is
MaskKind::Custom or MaskKind::Bias (same convention as
forward). Output shape matches q, k, or v respectively.
FusedMatMulBiasAct
Fused matmul + bias + activation. Created from MatMul → Add → Activation.
Fields
activation: Option<Activation>
FusedResidualLN
Fused residual + optional bias + layer norm. Created from Add(x, residual) → [Add(bias)] → LayerNorm.
FusedResidualRmsNorm
Fused residual + optional bias + RMS norm. Created from Add(x, residual) → [Add(bias)] → RmsNorm.
FusedSwiGLU
Fused SwiGLU: split input into up/gate halves, silu(gate) * up. Created from Split → Silu → Mul when fed by a concatenated matmul.
cast_to: optional output dtype — when Some(dt) the kernel casts
its result from the input dtype to dt in-register, saving a
separate cast pass. Reserved for future fp8/fp4 quantization paths;
for f32→f16 mixed precision the AutoMixedPrecision pass already
inserts a Cast node so this stays None in current pipelines.
Fields
FusedTransformerLayer
Fused full transformer layer: attention block + residual+LN + FFN + residual+LN. All intermediates resident in registers/threadgroup memory; one kernel per layer instead of ~30 (the CPU’s batch=1 win, lifted to IR so any backend can implement it as a monolithic kernel).
Inputs: hidden, qkv_w, qkv_b, out_w, out_b, ln1_g, ln1_b, fc1_w, fc1_b, fc2_w, fc2_b, ln2_g, ln2_b, mask.
Output: same shape as hidden.
Backend status: same as FusedAttentionBlock. CPU implements the L1-cache-resident merge at the thunk level. Metal deferred — requires a single MSL kernel for the whole layer to actually beat the unfused path. Multi-day work; revisit when there’s a model whose Metal inference is bottlenecked here rather than on the wait latency floor.
Fields
activation: Activation
FusedAttentionBlock
Fused attention block: QKV projection → split → [RoPE] → SDPA → output projection. Created by FuseAttentionBlock pass when batch*seq is small. All intermediates stay in L1 cache — no arena writes between ops.
Inputs (in order): hidden, qkv_w, out_w, mask, [qkv_b, out_b] if has_bias, [rope_cos, rope_sin] if has_rope
Backend status (Phase C finalize):
CPU — implemented at the thunk level: the CPU schedule
recognizes the multi-thunk pattern and merges into
a single FusedAttnBlock that keeps Q/K/V in stack
buffers across stages (the L1-cache win).
Metal — deferred. A dispatch-wrapper version (chaining
existing kernels) buys nothing the unfused Metal path
doesn’t already get, since per-run cost is dominated
by wait_until_completed (~150 µs), not encode. The
real win is a single MSL kernel keeping Q/K/V in
threadgroup memory across stages — multi-day work.
Until then, Metal runs the unfused chain (one matmul,
three narrows, two ropes, attention, one matmul) — all
covered in op_coverage and parity_harness.
If
Conditional: pick between two subgraphs based on a boolean predicate.
Inputs: [predicate, …captures (used inside both branches)].
then_branch and else_branch are sub-graphs that share the
captured inputs and must produce identically-shaped outputs.
Used for: shape-dependent execution, batched inference of
dynamic-length sequences with padding masks.
While
Loop: iterate body while cond evaluates true.
Inputs: […initial loop-carried values].
cond’s single output is a Bool scalar.
body’s outputs become the next iteration’s loop-carried inputs.
Outputs of While are the values after the final iteration.
Used for: KV-cache-driven autoregressive generation, beam search.
Scan
Bounded-length loop with a fixed-shape carry, optional per-step
inputs, and optional stacked output. Mirrors JAX’s lax.scan.
Body signature: (carry, x_t_0, ..., x_t_{num_xs-1}) → carry_next
— 1 + num_xs Op::Inputs in NodeId construction order (first
declared is the carry; the remaining num_xs are per-step
slices). Single output (the next carry).
Outer Op::Scan inputs (in order):
[init_carry, xs_0, xs_1, ..., xs_{num_xs-1}]
Each xs_i has shape [length, *per_step_shape_i]; the body
sees xs_i[t] (a per_step_shape_i slice) on iteration t.
Outer Op::Scan output:
- save_trajectory == false — final carry, shape *carry.
- save_trajectory == true — stacked trajectory of carries, shape [length, *carry]. Row t is the carry after step t+1, so row length-1 matches the no-trajectory case.
Mirrors JAX’s lax.scan. Common uses include time-stepping
integrators with time-varying drives, Mamba-style SSM scans
reading per-step inputs, and RNN-style sequence processing.
Fields
num_bcast: u32
Number of “broadcast” inputs — values that are constant
across iterations. Outer scan inputs in order:
[init, bcast_0..bcast_{B-1}, xs_0..xs_{X-1}]
Body Op::Inputs in NodeId order:
[carry, bcast_0..bcast_{B-1}, x_t_0..x_t_{X-1}]
CPU executor fills bcast slots ONCE before the iteration
loop (xs slots are filled per-step). The reverse-mode AD
pre-pass materialises each bcast into an xs of shape
[length, *bcast] via broadcast Mul so the rest of the
VJP / executor pipeline can stay unchanged. 0 (default)
keeps the original carry+xs scan shape.
num_checkpoints: u32
Number of trajectory checkpoints when save_trajectory == true. 0 means “save all length rows” (default). A
positive value K means save only K evenly-spaced rows
at indices floor(t * length / K) for t in 0..K. Used
by recursive checkpointed AD: store O(√T) carries during
forward, recompute the rest in the backward pass.
When 0 (or K == length), the saved trajectory has
shape [length, *carry] — same as the original behavior.
When 0 < K < length, the saved trajectory has shape
[K, *carry].
ScanBackward
Reverse-mode AD companion to Op::Scan — extracts the carry
gradient dinit. Walks t = length-1 .. 0, applying body_vjp
to thread dcarry back through the time loop.
Inputs (in order):
[init, trajectory, upstream, xs_0, ..., xs_{num_xs-1}]
Output: dinit, shape = carry shape.
body_vjp is the result of
autodiff::grad(body, [carry_id, xs_0_id, ..., xs_{num_xs-1}_id])
— a graph with 1 + num_xs + 1 Inputs (carry + x_t_i for each
xs + "d_output") and 1 + num_xs outputs
(dcarry + dx_t_i for each xs). This op reads outputs[0] =
dcarry; the sibling Self::ScanBackwardXs reads the
outputs[1 + xs_idx] slot for each xs gradient.
Fields
num_checkpoints: u32
When 0 or equal to length, the trajectory input has
shape [length, *carry] — every step’s carry is cached
(CheckpointStrategy::All). When 0 < K < length, the
trajectory input has shape [K, *carry] and the executor
recomputes intermediate carries via forward_body between
checkpoints. forward_body must be Some whenever
0 < num_checkpoints < length (0 means every carry is cached, so no recompute is needed).
ScanBackwardXs
Companion to Self::ScanBackward that extracts one stacked
per-step dxs_i (shape [length, *per_step_xs_i]). Same inputs
and same body_vjp graph as ScanBackward — xs_idx selects
which body_vjp output to stack into the result.
Note: each ScanBackwardXs runs its own backward loop. A future
optimization can fuse them into a single multi-output backward
kernel; for now it’s 1 + num_xs independent sweeps.
Fields
GaussianSplatRender
CPU reference 3D Gaussian splat forward render.
Seven flat F32 inputs (scene buffers + camera/render meta):
0. positions [N*3]
1. scales [N*3] (log-space)
2. rotations [N*4] (xyzw)
3. opacities [N] (logit)
4. colors [N*3] (linear RGB)
5. sh_coeffs [N * sh_coeff_count * 3]
6. meta [23] — camera position/target/up/fov/near/far, background RGB, then width/height/tile_size/radius_scale/alpha_cutoff/max_splat_steps/transmittance_threshold/max_list_entries as f32 bit-patterns.
Output: [height * width * 4] linear RGBA (display gamma baked in).
Build via crate::Graph::gaussian_splat_render.
Differentiable backward is not implemented in v1; autodiff treats this
op as non-differentiable (same as Op::Sample).
Fields
GaussianSplatRenderBackward
Backward pass for Self::GaussianSplatRender.
Eight inputs: the same seven as forward plus d_loss_rgba [W*H*4]
(only RGB channels are used). Re-runs the training forward internally.
Output: packed gradients
[positions(3N) | scales(3N) | rotations(4N) | opacities(N) | colors(3N) | sh(N*sh*3)].
Unpack via crate::ops::splat::unpack_gaussian_splat_packed_grads.
Fields
GaussianSplatPrepare
Strict IR stage 1: project, bin, sort, build per-pixel rays.
Seven inputs (same scene + meta as Self::GaussianSplatRender). Output: packed
prepare buffer (see rlx_splat::prep_layout::prep_packed_len).
Fields
GaussianSplatRasterize
Strict IR stage 2: tile raster from Self::GaussianSplatPrepare output.
Inputs: prep packed buffer, meta [23]. Output: [width * height * 4] RGBA.
Fields
Custom
User-registered custom op. name keys into the
crate::op_registry for shape inference, autodiff, and
per-backend execution. attrs is an opaque blob passed
through to those callbacks (FFT direction, SparseLU
reordering strategy, etc.). num_inputs is captured at
construction time so Op::num_inputs stays infallible
without a registry lookup. Build via crate::Graph::custom_op.
Fft
1D Fast Fourier Transform along the last axis.
Convention: complex tensors are represented in a 2N-real-block layout
— the input shape is [..., 2N] along the last axis, with
the first N elements the real part and the second N the
imaginary part. Output shape matches the input. Last axis
length must be even (and a power of 2 for the v1 radix-2
kernel; other sizes will eventually go through mixed-radix).
Both forward and inverse are unnormalized (no 1/N scale):
fft(x)[k] = Σ x[n]·exp(-2πi·nk/N)
ifft(y)[n] = Σ y[k]·exp(+2πi·nk/N)
so ifft(fft(x)) = N·x. Users divide by N themselves for a round-trip
identity; this matches numpy’s fft.fft / fft.ifft·N convention.
The unnormalized choice keeps both AD rules free of scaling:
- reverse-mode VJP: VJP(fft) = ifft, VJP(ifft) = fft (transpose of the DFT matrix over the 2N-real-block view equals the unnormalized inverse).
- forward-mode JVP: same op, same direction — FFT is linear, so the JVP is the linear map itself, not its transpose.
CPU paths exist for both DType::F32 and DType::F64 on the
2N-real-block layout. Native DType::C64 and non-power-of-two
sizes (Bluestein / mixed-radix) are not implemented; ND FFT
and non-CPU backend lowerings are deferred.
CustomFn
User-defined sub-graph with optional override AD rules.
Mirrors JAX’s custom_vjp / custom_jvp decorators: the
caller wraps a forward computation and supplies its own
reverse- and/or forward-mode AD bodies. Useful when:
- The forward is iterative (Newton, fixed-point) and differentiating through the loop is wasteful — the vjp_body computes the implicit-function gradient at the converged point in one shot.
- The math has a closed-form gradient that’s much cheaper than autodiff.
- The forward op is non-differentiable by tracing (sampling, argmax) and the user wants a smooth surrogate.
fwd_body: num_inputs Op::Inputs in NodeId construction
order, one Op::Output (the primal y). Forward execution
inlines this body once.
vjp_body (optional): Op::Inputs are num_inputs primal
inputs in NodeId order, plus two special-named Inputs —
"primal_output" (the y from forward) and "d_output" (the
upstream gradient). Outputs: num_inputs tensors in
set_outputs order, matching the gradients of each primal
input. When None, reverse-mode AD recurses into fwd_body
— same as if the op were inlined.
jvp_body (optional): Op::Inputs are num_inputs primal
inputs in NodeId order, num_inputs special-named Inputs
"tangent_0"..="tangent_{num_inputs-1}" carrying each input’s
tangent, and an optional special-named "primal_output" Input
(the y from forward, useful when the JVP must be evaluated at
a converged / nonlinear point — e.g. IFT-style forward-mode AD
of an iterative solver). Output: 1 tensor (the tangent of y).
When None, forward-mode AD recurses into fwd_body.
num_inputs is captured so Op::num_inputs stays
infallible. Build via crate::Graph::custom_fn.
Implementations§
Source§impl Op
impl Op
Sourcepub fn kind(&self) -> OpKind
pub fn kind(&self) -> OpKind
PLAN L4: discriminant for backend-supported-set checks.
Stable, parameter-free identity per variant — Op::Activation(_)
and Op::Activation(Relu) share the same OpKind::Activation.
Sourcepub fn is_elementwise(&self) -> bool
pub fn is_elementwise(&self) -> bool
True if this op is element-wise (same shape in, same shape out). Element-wise ops are prime fusion candidates.
Sourcepub fn is_blas(&self) -> bool
pub fn is_blas(&self) -> bool
True if this op is a BLAS/compute-intensive op that forms a fusion boundary.
Sourcepub fn is_fusion_boundary(&self) -> bool
pub fn is_fusion_boundary(&self) -> bool
True if element-wise fusion must not span across this op.
Sourcepub fn is_reduction(&self) -> bool
pub fn is_reduction(&self) -> bool
True if this op is a reduction (drives loop iteration in fused kernels).
Sourcepub fn num_inputs(&self) -> usize
pub fn num_inputs(&self) -> usize
Number of tensor inputs this op expects.
Trait Implementations§
Source§impl<'de> Deserialize<'de> for Op
impl<'de> Deserialize<'de> for Op
Source§fn deserialize<__D>(
__deserializer: __D,
) -> Result<Op, <__D as Deserializer<'de>>::Error>where
__D: Deserializer<'de>,
fn deserialize<__D>(
__deserializer: __D,
) -> Result<Op, <__D as Deserializer<'de>>::Error>where
__D: Deserializer<'de>,
Source§impl Serialize for Op
impl Serialize for Op
Source§fn serialize<__S>(
&self,
__serializer: __S,
) -> Result<<__S as Serializer>::Ok, <__S as Serializer>::Error>where
__S: Serializer,
fn serialize<__S>(
&self,
__serializer: __S,
) -> Result<<__S as Serializer>::Ok, <__S as Serializer>::Error>where
__S: Serializer,
impl StructuralPartialEq for Op
Auto Trait Implementations§
impl Freeze for Op
impl RefUnwindSafe for Op
impl Send for Op
impl Sync for Op
impl Unpin for Op
impl UnsafeUnpin for Op
impl UnwindSafe for Op
Blanket Implementations§
Source§impl<T> BorrowMut<T> for Twhere
T: ?Sized,
impl<T> BorrowMut<T> for Twhere
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Source§impl<T> CloneToUninit for Twhere
T: Clone,
impl<T> CloneToUninit for Twhere
T: Clone,
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left is true.
Converts self into a Right variant of Either<Self, Self>
otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self>
otherwise. Read more