RLX optimizer facade — re-exports rlx_fusion, rlx_autodiff, and
rlx_compile for backward-compatible rlx_opt:: paths.
§Crates
| Crate | Role |
|---|---|
| `rlx_fusion` | Fusion passes + `rlx_fusion::unfuse_fused_for_autodiff` |
| `rlx_autodiff` | `grad_with_loss`, `jvp`, `vmap`, `prepare_graph_for_ad` (feature `training`) |
| `rlx_compile` | `CompilePipeline`, memory plan, legalization (feature `compile`) |
§Features
- `compile` (default): HIR → MIR → LIR pipeline
- `training` (default): autodiff transforms
- `full`: both
Re-exports§
- `pub use rlx_fusion;`
- `pub use rlx_autodiff;`
- `pub use rlx_compile;`
Modules§
- `autodiff`
- `autodiff_fwd`
- `compiler`
- `const_fold`
- `control_flow`: Control-flow lowering passes: `Op::If` → `Where` + inlined branches; `Op::While` → bounded unroll of body replicas.
- `dce`
- `fusion`: Fusion passes: pattern-match and replace subgraphs with fused ops.
- `fusion_pipeline`
- `fusion_report`: Fusion diagnostics: what fused, what missed, and why.
- `inline`
- `inspect`
- `legalize`
- `legalize_broadcast`
- `lower_dot_general`: Lower `Op::DotGeneral` to primitive ops (MatMul + Transpose + Reshape).
- `memory`
- `pass`: Pass infrastructure: trait + pipeline runner.
- `precision`
- `prepare_ad`
- `promote_params`
- `quant_insert`
- `quant_propagate`
- `svg`
- `unfuse`: Decompose tier-2 fused MIR ops into primitives for autodiff and backends.
- `vmap`
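The `control_flow` entry above describes lowering `Op::While` to a bounded unroll of body replicas. The following is a minimal scalar sketch of that idea, not the crate's implementation: `select` is a hypothetical stand-in for `Where`, plain closures stand in for the `cond` and `body` sub-graphs, and `unroll_while_sketch` is an illustrative name that does not exist in `rlx_opt`.

```rust
/// Hypothetical stand-in for a `Where` op: returns `a` when `pred` is true,
/// otherwise `b`. Both operands are already computed when this runs.
fn select(pred: bool, a: i64, b: i64) -> i64 {
    if pred { a } else { b }
}

/// Bounded unroll of `while cond(x) { x = body(x) }` (illustrative only).
/// Every replica executes; once `cond` fails, `active` stays false and the
/// carried value passes through each remaining replica unchanged.
fn unroll_while_sketch(
    init: i64,
    max_iterations: usize,
    cond: impl Fn(i64) -> bool,
    body: impl Fn(i64) -> i64,
) -> i64 {
    let mut carried = init;
    let mut active = true;
    for _ in 0..max_iterations {
        // The loop is active only if it was active before and `cond` still holds.
        active = active && cond(carried);
        // The body replica is evaluated unconditionally.
        let body_out = body(carried);
        // Keep the new value only while the loop is still active.
        carried = select(active, body_out, carried);
    }
    carried
}
```

Because every replica runs regardless of the predicate, the body must be safe to evaluate on values the original loop would never have reached, and a loop that needs more than `max_iterations` steps is silently truncated.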
Structs§
- `AlgebraicSimplify`
- `AutoMixedPrecision`: Pass that rewrites a graph according to a `PrecisionPolicy`.
- `CalibrationEntry`: One calibrated quant entry per tap. `axis = None` is per-tensor; `axis = Some(d)` is per-channel along axis `d`, in which case `scales` and `zero_points` must each have length `tap.shape.dim(d)`.
- `CastConfig`: Cast configuration carried by ops that emit a typed output.
- `CompilePipeline`: End-to-end compiler pipeline configuration.
- `CompileResult`: End-to-end compiler output: optimized LIR + fusion diagnostics.
- `ConstantFolding`
- `DeadCodeElimination`
- `ForceEnergyLossWeights`: Weights for force + energy MSE terms.
- `FuseAttentionBlock`: Fuses `matmul(QKV) → narrow(Q,K,V) → [rope] → attention → matmul(out)` into a single `FusedAttentionBlock` when `batch*seq` is small.
- `FuseMatMulBiasAct`: Fuses `matmul → add(bias) → activation` into a single `FusedMatMulBiasAct`.
- `FuseResidualLN`: Fuses `add(x, residual) → layer_norm` into `FusedResidualLN`.
- `FuseResidualRmsNorm`: Fuses `add(x, residual) → rms_norm` into `Op::FusedResidualRmsNorm`.
- `FuseRmsNormReshape`: Fuses `rms_norm([…, H]) → reshape([∏leading, H])` into a single `RmsNorm` with the flattened output shape, eliminating a memcpy.
- `FuseSharedInputMatMul`: Detects two MatMul nodes with the same input and concatenates their weight matrices into a single larger MatMul.
- `FuseSwiGLU`: Detects the post-`FuseSharedInputMatMul` SwiGLU pattern and replaces it with a single `Op::FusedSwiGLU` node consuming the concatenated matmul.
- `FuseSwiGLUDualMatmul`: Fuses the common LLM FFN pattern in one rewrite: `gate = matmul(x, wg); up = matmul(x, wu); out = mul(silu(gate), up)`.
- `FuseTransformerLayer`: Fuses an entire BERT-style transformer layer (attention block + residual+LN + FFN + residual+LN) into one `Op::FusedTransformerLayer` node.
- `FusionLimits`: Hardware / encoder limits for fusion passes.
- `FusionOptions`: Per-target fusion toggles (env-driven on Metal today).
- `FusionReport`: Before/after fusion statistics and missed-pattern tally.
- `GradWithLossOptions`: Compute the reverse-mode gradient graph and the loss value.
- `HigherOrderOptions`: Options for `nth_order_grad_with_options`.
- `KernelDispatchConfig`: Per-compile overrides on top of `KernelDispatchPolicy`.
- `KernelDispatchReport`: Full report after rewrite + legalization probe (same path as `crate::rewrite::legalize_or_rewrite_for_backend_with_config`).
- `KindDispatchSummary`: Per-`OpKind` summary for one graph + backend claim set.
- `LegalizeBroadcast`: Pass that materializes non-trailing broadcasts via `Op::Expand`.
- `LowerControlFlow`: Pass form: rewrites `Op::If` and `Op::While` into primitive ops. No-op when neither op is present.
- `LowerDotGeneral`
- `MarkElementwiseRegions`
- `MemoryPlanOptions`: Assign buffers using a greedy best-fit algorithm.
- `MissedFusion`: A single fusion opportunity that remains in the graph.
- `PipelineInspect`: Text dump of each compiler pipeline stage.
- `PrepareForAutodiff`: `Pass` wrapper for `prepare_graph_for_ad`.
- `SharedWeightLayout`: Persistent parameter slots extracted from a forward `MemoryPlan`.
- `SpecializeParams`: Pass wrapper for the fusion pipeline / runtime preprocess hook.
- `TrainingCompileResult`: Forward + backward LIR with a single shared weight region.
- `UnfuseElementwiseRegions`
- `WeightSlot`: One named parameter and its byte range in the shared weight region.
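The `FuseSharedInputMatMul` and `FuseSwiGLU` entries above rely on one fact: two matmuls that share an input can be replaced by a single matmul against concatenated weights. The sketch below checks that equivalence on plain row-major slices. It is illustrative only: every function name is hypothetical, and concatenating along the output (column) axis is an assumption, since the listing does not say which axis the pass uses.

```rust
/// Row-major matrix multiply: `a` is [m, k], `b` is [k, n], result is [m, n].
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut out = vec![0.0; m * n];
    for i in 0..m {
        for j in 0..n {
            for p in 0..k {
                out[i * n + j] += a[i * k + p] * b[p * n + j];
            }
        }
    }
    out
}

/// SiLU activation: v * sigmoid(v).
fn silu(v: f32) -> f32 {
    v / (1.0 + (-v).exp())
}

/// Unfused form: gate = matmul(x, wg); up = matmul(x, wu); out = silu(gate) * up.
fn swiglu_dual(x: &[f32], wg: &[f32], wu: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let gate = matmul(x, wg, m, k, n);
    let up = matmul(x, wu, m, k, n);
    gate.iter().zip(&up).map(|(g, u)| silu(*g) * *u).collect()
}

/// Shared-input form (assumed layout): concatenate `wg` and `wu` along columns
/// into one [k, 2n] weight, run a single matmul, then read the gate and up
/// halves out of each output row.
fn swiglu_concat(x: &[f32], wg: &[f32], wu: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut w = Vec::with_capacity(k * 2 * n);
    for p in 0..k {
        w.extend_from_slice(&wg[p * n..(p + 1) * n]);
        w.extend_from_slice(&wu[p * n..(p + 1) * n]);
    }
    let both = matmul(x, &w, m, k, 2 * n);
    let mut out = Vec::with_capacity(m * n);
    for i in 0..m {
        let row = &both[i * 2 * n..(i + 1) * 2 * n];
        for j in 0..n {
            out.push(silu(row[j]) * row[n + j]);
        }
    }
    out
}
```

Calling both functions with the same `x`, `wg`, and `wu` should give matching outputs up to floating-point tolerance. The single-matmul form reads `x` once instead of twice.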
Enums§
- `AutodiffError`: Error from `grad_with_loss_module` / `jvp_module`.
- `DispatchPath`: How a logical / fused op reaches the backend executable.
- `FusionTarget`: Compile target that selects a fusion pipeline.
- `KernelDispatchPolicy`: When to use native backend kernels vs the shared IR common body.
- `MissReason`: Why a recognizable fusion pattern was not collapsed.
- `OpKind`: High-level op categorization for precision policies.
- `Precision`: Which numeric precision to use for an op. (Subset of `DType`: only the ones we currently dispatch on.)
- `PrecisionPolicy`: Declarative precision policy for graph compilation.
- `TrainingCompileError`: Error from `CompilePipeline::compile_training`.
Traits§
- `MirAutodiffExt`: MIR extensions for the training pipeline.
- `Pass`: A graph-to-graph transformation pass.
Functions§
- `analyze_dispatch`: Analyze the graph before rewrite (static; does not run unfuse passes).
- `backward_cleanup_passes`: Passes safe on backward MIR after AD (no forward-only fusion).
- `build_force_energy_loss`: Build `w_f·MSE(−∇E, F_ref) + w_e·MSE(E, E_ref)`.
- `convert_scans_for_ad`
- `cse`: Structural-equality CSE for pure (non-leaf) ops.
- `decompose_backward_for_ad`: Prepare a `crate::autodiff::grad_with_loss` graph for another reverse pass.
- `directional_nth_grad`: ND `wrt` via per-level direction contraction.
- `format_dispatch_report`: Human-readable report for logs / CI / REPL.
- `format_legalize_error`: Helper: format the legalize error as a single human-readable diagnostic. Used by backend `compile` paths to panic with a clear message when legalization fails.
- `fuse_elementwise`: Opt-in elementwise fusion after higher-order stacking.
- `fusion_limits_for_target`: Elementwise-region caps for `target` (matches GPU kernel encoders).
- `fusion_passes`: Return the ordered fusion passes for `target`.
- `fusion_passes_for_supported`: Return the ordered fusion passes allowed for `supported`.
- `grad`: Backwards-compatible single-output alias (parameter gradients only, no loss). Kept for the existing tests; prefer `grad_with_loss` for training.
- `grad_subgraph`: Decomposed, AD-ready gradient graph: `[energy, dE/d(positions…)]`.
- `grad_subgraph_for_jvp`: Prepared for an outer `crate::autodiff_fwd::jvp`.
- `grad_with_loss`: Build a backward graph with scalar loss + gradients w.r.t. `wrt`.
- `grad_with_loss_module`: Reverse-mode AD on a `GraphModule` at HIR or MIR stage.
- `grad_with_loss_opts`: Like `grad_with_loss` with configurable unused-parameter handling.
- `hvp`: Hessian-vector product via forward-over-reverse.
- `hvp_module`: Hessian-vector product module wrapper (`wrt` params get tangent inputs).
- `inline_custom_fn_for_autodiff`: Pre-AD pass: inline `Op::CustomFn` nodes that have neither a `vjp_body` nor a `jvp_body` by expanding their `fwd_body` into the parent graph. When either override body is present, keep the `CustomFn` wrapper so reverse- / forward-mode AD can dispatch to it.
- `inline_if`: Inline `Op::If` sub-graphs into the parent and replace the If node with `Where(predicate, then_output, else_output)`. Both branches are present in the rewritten graph and always evaluate.
- `inline_into`: Inline `source` into `target`. Returns the `target` NodeIds that correspond to `source.outputs`.
- `inline_subgraph_into`: Helper: copy `sub`’s nodes into `out`, mapping each `Op::Input` by position to the corresponding capture. Returns the new NodeId in `out` of the sub-graph’s first declared output.
- `insert_q_dq`: Insert `Quantize → Dequantize` pairs at every tap in `record`. Returns a graph where each tagged node is followed by a `Quantize → Dequantize` pair, and every consumer of the original tap reads from the dequantized output instead.
- `inspect_compiled`: Inspect a completed `CompileResult` plus the original HIR text.
- `inspect_fusion`: Fusion report only (post-optimize diagnostics).
- `inspect_pipeline`: Inspect every lowering stage for `hir` through `pipeline`.
- `is_pure_view`: Public predicate for backends: true iff this op should compile to a Nop because its output aliases a parent buffer (the memory planner has already aliased its slot).
- `jvp`: Compute the JVP graph for `forward`, perturbing each `Input`/`Param` named in `tangent_for`. Returns a new graph whose outputs are `[primals..., tangents...]`, in the order `forward` listed them.
- `jvp_module`: Forward-mode AD on a `GraphModule` at HIR or MIR stage.
- `legalize_for_backend`: Check `graph` against the backend’s `supported` op set.
- `legalize_or_rewrite_for_backend`: Legalize, rewriting unsupported ops first when possible.
- `legalize_or_rewrite_for_backend_with_config`: Legalize with full `KernelDispatchConfig`.
- `legalize_or_rewrite_for_backend_with_dispatch`: Legalize with explicit logical-kernel dispatch policy.
- `maybe_dump_pipeline`: Write a full pipeline dump when `RLX_IR_DUMP` is set (path prefix or directory).
- `maybe_log_dispatch_report`: Print when `RLX_VERBOSE=1` or `RLX_DISPATCH_REPORT=1`.
- `nth_order_grad`: Scalar `wrt`, scalar output: differentiate `order` times.
- `nth_order_grad_module`: Higher-order reverse-mode AD on a `GraphModule`.
- `nth_order_grad_with_options`: Like `nth_order_grad` with optional post-layer fusion.
- `plan_memory_backward`: Plan backward activations, then alias params onto `weights`.
- `plan_memory_f32_uniform`: Liveness-aware planning with every slot sized as `num_elements * 4` bytes (wgpu / uniform-f32 arenas). Reuses dead tensor slots so large `[n, n]` pairwise graphs stay under WebGPU’s 128 MiB binding cap.
- `plan_memory_with_options`: Plan memory with custom alignment and boundary allocation policy.
- `prepare_grad_graph_for_jvp`: Prepare a `crate::autodiff::grad_with_loss` graph for an outer forward-mode (`jvp`) pass: decompose `*Backward` ops and bake in `d_output = 1`.
- `prepare_graph_for_ad`: Canonical MIR pre-passes before reverse- or forward-mode AD.
- `prepare_graph_for_backend_with_report`: Rewrite toward `supported`, then report native / common / rewritten / missing.
- `prepare_mir_for_ad`: Return MIR suitable for inspection or a custom AD walk.
- `prepare_module_for_ad`: Lower HIR if needed, then run `prepare_graph_for_ad`.
- `promote_params_to_inputs`: Convert every `Op::Param { name }` whose `name` is in the list to `Op::Input { name }` of the same shape. All other ops are copied through unchanged. Returns a fresh `Graph`.
- `quantized_weight_bits`: Project a gradient back to a smaller shape it was broadcasted from. `target_shape` is the broadcast source shape (e.g. `[C]` for a bias added to `[N, C, H, W]`). Sums over leading prepended axes and over any axis where target was 1 but the gradient is larger. Then reshapes to drop the size-1 axes if the rank shrunk. Returns `Some(bits)` if `node_id` is the output of an `Op::FakeQuantize { bits, .. }` (or `FakeQuantizeLSQ`) in the forward graph. Used by the autodiff Conv backward to detect the QAT pattern and emit a specialized weight-grad kernel that can skip dead bins (weights that round to the same code share the gradient). Today only the detection is exposed; the specialization is a follow-up commit.
- `rewrite_for_backend`: Rewrite `graph` toward `supported` op kinds. Idempotent when already legal.
- `rewrite_for_backend_with_config`: Full dispatch control (policy + per-`OpKind` overrides).
- `rewrite_for_backend_with_dispatch`: Like `rewrite_for_backend` but applies logical-kernel common lowers first.
- `run_passes`: Run a sequence of passes, printing the graph after each if `verbose`.
- `specialize_params`: Substitute listed params with constants. Unlisted params are unchanged.
- `supported_for_target`: Per-target op claims used when a backend doesn’t supply an explicit `supported_ops` slice. Must stay aligned with each backend’s `*_SUPPORTED_OPS` in `rlx-runtime/src/backend.rs`.
- `supports_op`: True when `supported` is empty (no claim) or contains `kind`.
- `unfuse_fused_for_autodiff`: Expand fused blocks so per-op VJP rules apply.
- `unroll_while`: Bounded-unroll `Op::While` up to `max_iterations`. Each iteration inlines `cond` and `body` with all loop-carried captures, then applies `Where(active, body_out, carried)` per carry (MLX semantics).
- `vmap`: Vectorize `forward` over a leading batch axis.
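The `build_force_energy_loss` entry above gives the loss as `w_f·MSE(−∇E, F_ref) + w_e·MSE(E, E_ref)`. The sketch below evaluates that formula on plain slices rather than building a graph. It is illustrative only: the function names and signatures are hypothetical, and taking MSE as the mean over all elements of a non-empty slice is an assumption.

```rust
/// Mean squared error over two equal-length, non-empty slices
/// (assumed reduction: mean over all elements).
fn mse(pred: &[f64], target: &[f64]) -> f64 {
    assert_eq!(pred.len(), target.len());
    let sum: f64 = pred
        .iter()
        .zip(target)
        .map(|(p, t)| (p - t).powi(2))
        .sum();
    sum / pred.len() as f64
}

/// Hypothetical numeric version of the formula:
/// w_f * MSE(-grad_e, f_ref) + w_e * MSE(e, e_ref).
/// `grad_e` holds dE/d(positions), so the predicted forces are its negation.
fn force_energy_loss_sketch(
    grad_e: &[f64],
    f_ref: &[f64],
    e: &[f64],
    e_ref: &[f64],
    w_f: f64,
    w_e: f64,
) -> f64 {
    let forces: Vec<f64> = grad_e.iter().map(|g| -g).collect();
    w_f * mse(&forces, f_ref) + w_e * mse(e, e_ref)
}
```

For example, with `grad_e = [1.0, -2.0]`, `f_ref = [-1.0, 1.0]`, `e = [3.0]`, `e_ref = [1.0]`, `w_f = 2.0`, and `w_e = 0.5`, the force term is `2.0 * 0.5` and the energy term is `0.5 * 4.0`, for a total of `3.0`.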
Type Aliases§
- `CalibrationRecord`: Map of tap NodeId → calibrated quant params.
- `LegalizeResult`: Result of `legalize_for_backend`: a list of `(node, kind)` pairs whose op is outside the backend’s claimed set. `Ok(())` when the graph is fully legalized.
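The `supports_op` and `LegalizeResult` descriptions together suggest a simple contract: an empty claim set accepts everything, and legalization reports each `(node, kind)` pair that falls outside a non-empty claim set. The sketch below models that contract. It is not the crate's code: the `Kind` enum, the flat slice used as a "graph", and both function names are hypothetical.

```rust
/// Hypothetical op-kind enum standing in for the crate's `OpKind`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Kind {
    MatMul,
    Add,
    Softmax,
}

/// True when `supported` is empty (no claim) or contains `kind`.
fn supports_sketch(supported: &[Kind], kind: Kind) -> bool {
    supported.is_empty() || supported.contains(&kind)
}

/// Walks a flat list of op kinds (a stand-in for a graph) and collects every
/// (node index, kind) pair outside the claimed set. `Ok(())` means fully legal.
fn legalize_sketch(graph: &[Kind], supported: &[Kind]) -> Result<(), Vec<(usize, Kind)>> {
    let missing: Vec<(usize, Kind)> = graph
        .iter()
        .copied()
        .enumerate()
        .filter(|&(_, kind)| !supports_sketch(supported, kind))
        .collect();
    if missing.is_empty() {
        Ok(())
    } else {
        Err(missing)
    }
}
```

With a graph of `[MatMul, Softmax, Add]` and a claim set of `[MatMul, Add]`, the sketch returns `Err(vec![(1, Kind::Softmax)])`. With an empty claim set it returns `Ok(())`.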