Crate rlx_opt

RLX optimizer facade — re-exports rlx_fusion, rlx_autodiff, and rlx_compile for backward-compatible rlx_opt:: paths.

§Crates

| Crate | Role |
|---|---|
| rlx_fusion | Fusion passes + rlx_fusion::unfuse_fused_for_autodiff |
| rlx_autodiff | grad_with_loss, jvp, vmap, prepare_graph_for_ad (feature training) |
| rlx_compile | CompilePipeline, memory plan, legalization (feature compile) |

§Features

  • compile (default) — HIR → MIR → LIR pipeline
  • training (default) — autodiff transforms
  • full — both

Re-exports§

pub use rlx_fusion;
pub use rlx_autodiff;
pub use rlx_compile;

Modules§

autodiff
autodiff_fwd
compiler
const_fold
control_flow
Control-flow lowering passes: Op::If → Where + inlined branches; Op::While → bounded unroll of body replicas.
dce
fusion
Fusion passes — pattern-match and replace subgraphs with fused ops.
fusion_pipeline
fusion_report
Fusion diagnostics — what fused, what missed, and why.
inline
inspect
legalize
legalize_broadcast
lower_dot_general
Lower Op::DotGeneral to primitive ops (MatMul + Transpose + Reshape).
memory
pass
Pass infrastructure — trait + pipeline runner.
precision
prepare_ad
promote_params
quant_insert
quant_propagate
svg
unfuse
Decompose tier-2 fused MIR ops into primitives for autodiff and backends.
vmap
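
The Op::If lowering listed under control_flow replaces a conditional with a select over both branch results. The following is a minimal sketch of that semantics on plain f32 slices, not the crate's graph types. The names where_select and lowered_if are hypothetical, and the elementwise predicate is an assumption made for illustration.

```rust
/// Elementwise select: out[i] = then_out[i] if pred[i] is true, else else_out[i].
/// Hypothetical helper that mirrors the Where(predicate, then_output, else_output) form.
fn where_select(pred: &[bool], then_out: &[f32], else_out: &[f32]) -> Vec<f32> {
    assert_eq!(pred.len(), then_out.len());
    assert_eq!(pred.len(), else_out.len());
    pred.iter()
        .zip(then_out.iter().zip(else_out.iter()))
        .map(|(&p, (&t, &e))| if p { t } else { e })
        .collect()
}

/// Lowered form of `if pred { x * 2 } else { x + 1 }`.
/// Both branches are computed for every element, then one result is selected.
fn lowered_if(pred: &[bool], x: &[f32]) -> Vec<f32> {
    let then_out: Vec<f32> = x.iter().map(|v| v * 2.0).collect();
    let else_out: Vec<f32> = x.iter().map(|v| v + 1.0).collect();
    where_select(pred, &then_out, &else_out)
}
```

Since both branches always evaluate in this form, the lowering trades skipped work for a graph without control flow.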

Structs§

AlgebraicSimplify
AutoMixedPrecision
Pass that rewrites a graph according to a PrecisionPolicy.
CalibrationEntry
One calibrated quant entry per tap. axis = None is per-tensor; axis = Some(d) is per-channel along axis d, in which case scales and zero_points must each have length tap.shape.dim(d).
CastConfig
Cast configuration carried by ops that emit a typed output.
CompilePipeline
End-to-end compiler pipeline configuration.
CompileResult
End-to-end compiler output: optimized LIR + fusion diagnostics.
ConstantFolding
DeadCodeElimination
ForceEnergyLossWeights
Weights for force + energy MSE terms.
FuseAttentionBlock
Fuses matmul(QKV) → narrow(Q,K,V) → [rope] → attention → matmul(out) into a single FusedAttentionBlock when batch*seq is small.
FuseMatMulBiasAct
Fuses matmul → add(bias) → activation into a single FusedMatMulBiasAct.
FuseResidualLN
Fuses add(x, residual) → layer_norm into FusedResidualLN.
FuseResidualRmsNorm
Fuses add(x, residual) → rms_norm into Op::FusedResidualRmsNorm.
FuseRmsNormReshape
Fuses rms_norm([…, H]) → reshape([∏leading, H]) into a single RmsNorm with the flattened output shape, eliminating a memcpy.
FuseSharedInputMatMul
Detects two MatMul nodes with the same input and concatenates their weight matrices into a single larger MatMul.
FuseSwiGLU
Detects the post-FuseSharedInputMatMul SwiGLU pattern and replaces it with a single Op::FusedSwiGLU node consuming the concatenated matmul.
FuseSwiGLUDualMatmul
Fuses the common LLM FFN pattern in one rewrite: gate = matmul(x, wg); up = matmul(x, wu); out = mul(silu(gate), up)
FuseTransformerLayer
Fuses an entire BERT-style transformer layer (attention block + residual+LN + FFN + residual+LN) into one Op::FusedTransformerLayer node.
FusionLimits
Hardware / encoder limits for fusion passes.
FusionOptions
Per-target fusion toggles (env-driven on Metal today).
FusionReport
Before/after fusion statistics and missed-pattern tally.
GradWithLossOptions
Compute the reverse-mode gradient graph and the loss value.
HigherOrderOptions
Options for nth_order_grad_with_options.
KernelDispatchConfig
Per-compile overrides on top of KernelDispatchPolicy.
KernelDispatchReport
Full report after rewrite + legalization probe (same path as crate::rewrite::legalize_or_rewrite_for_backend_with_config).
KindDispatchSummary
Per-OpKind summary for one graph + backend claim set.
LegalizeBroadcast
Pass that materializes non-trailing broadcasts via Op::Expand.
LowerControlFlow
Pass form: rewrites Op::If and Op::While into primitive ops. No-op when neither op is present.
LowerDotGeneral
MarkElementwiseRegions
MemoryPlanOptions
Assign buffers using a greedy best-fit algorithm.
MissedFusion
A single fusion opportunity that remains in the graph.
PipelineInspect
Text dump of each compiler pipeline stage.
PrepareForAutodiff
Pass wrapper for prepare_graph_for_ad.
SharedWeightLayout
Persistent parameter slots extracted from a forward MemoryPlan.
SpecializeParams
Pass wrapper for the fusion pipeline / runtime preprocess hook.
TrainingCompileResult
Forward + backward LIR with a single shared weight region.
UnfuseElementwiseRegions
WeightSlot
One named parameter and its byte range in the shared weight region.
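
The idea behind FuseSharedInputMatMul can be checked numerically: multiplying one input by two weight matrices gives the same values as a single multiplication by their concatenation. The sketch below assumes row-major storage and concatenation along the output (column) axis; the source does not state the axis, and all helper names are hypothetical rather than part of rlx_opt.

```rust
/// Row-major matmul: a is [m, k], b is [k, n], result is [m, n].
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut out = vec![0.0; m * n];
    for i in 0..m {
        for j in 0..n {
            for p in 0..k {
                out[i * n + j] += a[i * k + p] * b[p * n + j];
            }
        }
    }
    out
}

/// Concatenate w1 [k, n1] and w2 [k, n2] along columns into [k, n1 + n2].
fn concat_cols(w1: &[f32], w2: &[f32], k: usize, n1: usize, n2: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(k * (n1 + n2));
    for r in 0..k {
        out.extend_from_slice(&w1[r * n1..(r + 1) * n1]);
        out.extend_from_slice(&w2[r * n2..(r + 1) * n2]);
    }
    out
}

/// Split y [m, n1 + n2] back into ([m, n1], [m, n2]).
fn split_cols(y: &[f32], m: usize, n1: usize, n2: usize) -> (Vec<f32>, Vec<f32>) {
    let n = n1 + n2;
    let mut a = Vec::with_capacity(m * n1);
    let mut b = Vec::with_capacity(m * n2);
    for r in 0..m {
        a.extend_from_slice(&y[r * n..r * n + n1]);
        b.extend_from_slice(&y[r * n + n1..(r + 1) * n]);
    }
    (a, b)
}
```

Splitting the fused output with split_cols recovers the two original products, which is what lets a later pass such as FuseSwiGLU consume the concatenated matmul directly.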

Enums§

AutodiffError
Error from grad_with_loss_module / jvp_module.
DispatchPath
How a logical / fused op reaches the backend executable.
FusionTarget
Compile target that selects a fusion pipeline.
KernelDispatchPolicy
When to use native backend kernels vs the shared IR common body.
MissReason
Why a recognizable fusion pattern was not collapsed.
OpKind
High-level op categorization for precision policies.
Precision
Which numeric precision to use for an op. (Subset of DType — only the ones we currently dispatch on.)
PrecisionPolicy
Declarative precision policy for graph compilation.
TrainingCompileError
Error from CompilePipeline::compile_training.

Traits§

MirAutodiffExt
MIR extensions for the training pipeline.
Pass
A graph-to-graph transformation pass.
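
A graph-to-graph pass plus a pipeline runner can be sketched as below. This is an illustration of the pattern only: the trait SketchPass, the toy Node and Graph types, and run_sketch_passes are all hypothetical, and the real Pass trait and run_passes signatures may differ.

```rust
/// Toy IR node; operands are indices of earlier nodes (topological order assumed).
#[derive(Debug, Clone, PartialEq)]
enum Node {
    Const(f32),
    Add(usize, usize),
}

type Graph = Vec<Node>;

/// Hypothetical stand-in for a pass trait: consumes a graph, returns a new one.
trait SketchPass {
    fn name(&self) -> &str;
    fn run(&self, graph: Graph) -> Graph;
}

/// Folds Add nodes whose operands are both constants.
struct FoldConstAdd;

impl SketchPass for FoldConstAdd {
    fn name(&self) -> &str {
        "fold_const_add"
    }

    fn run(&self, graph: Graph) -> Graph {
        let mut out: Graph = Vec::with_capacity(graph.len());
        for node in graph {
            let rewritten = match node {
                // Operands were pushed earlier, so already-folded results are visible here.
                Node::Add(a, b) => match (out.get(a), out.get(b)) {
                    (Some(Node::Const(x)), Some(Node::Const(y))) => Node::Const(x + y),
                    _ => Node::Add(a, b),
                },
                other => other,
            };
            out.push(rewritten);
        }
        out
    }
}

/// Runs passes in order, printing the graph after each one when `verbose` is set.
fn run_sketch_passes(mut graph: Graph, passes: &[Box<dyn SketchPass>], verbose: bool) -> Graph {
    for pass in passes {
        graph = pass.run(graph);
        if verbose {
            println!("after {}: {:?}", pass.name(), graph);
        }
    }
    graph
}
```

Keeping each pass as a value behind a trait object lets a pipeline be assembled per target as an ordered list, in the spirit of what fusion_passes is described as returning.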

Functions§

analyze_dispatch
Analyze the graph before rewrite (static — does not run unfuse passes).
backward_cleanup_passes
Passes safe on backward MIR after AD (no forward-only fusion).
build_force_energy_loss
Build w_f·MSE(−∇E, F_ref) + w_e·MSE(E, E_ref).
convert_scans_for_ad
cse
Structural-equality CSE for pure (non-leaf) ops.
decompose_backward_for_ad
Prepare a crate::autodiff::grad_with_loss graph for another reverse pass.
directional_nth_grad
Higher-order gradient for an N-dimensional wrt, computed via per-level direction contraction.
format_dispatch_report
Human-readable report for logs / CI / REPL.
format_legalize_error
Helper: format the legalize error as a single human-readable diagnostic. Used by backend compile paths to panic with a clear message when legalization fails.
fuse_elementwise
Opt-in elementwise fusion after higher-order stacking.
fusion_limits_for_target
Elementwise-region caps for target (matches GPU kernel encoders).
fusion_passes
Return the ordered fusion passes for target.
fusion_passes_for_supported
Return the ordered fusion passes allowed for supported.
grad
Backwards-compatible single-output alias (parameter gradients only, no loss). Kept for the existing tests; prefer grad_with_loss for training.
grad_subgraph
Decomposed, AD-ready gradient graph: [energy, dE/d(positions…)].
grad_subgraph_for_jvp
Prepared for an outer crate::autodiff_fwd::jvp.
grad_with_loss
Build a backward graph with scalar loss + gradients w.r.t. wrt.
grad_with_loss_module
Reverse-mode AD on a GraphModule at HIR or MIR stage.
grad_with_loss_opts
Like grad_with_loss with configurable unused-parameter handling.
hvp
Hessian-vector product via forward-over-reverse.
hvp_module
Hessian-vector product module wrapper (wrt params get tangent inputs).
inline_custom_fn_for_autodiff
Pre-AD pass: inline Op::CustomFn nodes that have neither a vjp_body nor a jvp_body by expanding their fwd_body into the parent graph. When either override body is present, keep the CustomFn wrapper so reverse- / forward-mode AD can dispatch to it.
inline_if
Inline Op::If sub-graphs into the parent and replace the If node with Where(predicate, then_output, else_output). Both branches are present in the rewritten graph and always evaluate.
inline_into
Inline source into target. Returns the target NodeIds that correspond to source.outputs.
inline_subgraph_into
Helper: copy sub’s nodes into out, mapping each Op::Input by position to the corresponding capture. Returns the new NodeId in out of the sub-graph’s first declared output.
insert_q_dq
Insert Quantize → Dequantize pairs at every tap in record. Returns a graph where each tagged node is followed by a Quantize → Dequantize pair, and every consumer of the original tap reads from the dequantized output instead.
inspect_compiled
Inspect a completed CompileResult plus the original HIR text.
inspect_fusion
Fusion report only (post-optimize diagnostics).
inspect_pipeline
Inspect every lowering stage for hir through pipeline.
is_pure_view
Public predicate for backends — true iff this op should compile to a Nop because its output aliases a parent buffer (the memory planner has already aliased its slot).
jvp
Compute the JVP graph for forward, perturbing each Input / Param named in tangent_for. Returns a new graph whose outputs are [primals..., tangents...], in the order forward listed them.
jvp_module
Forward-mode AD on a GraphModule at HIR or MIR stage.
legalize_for_backend
Check graph against the backend’s supported op set.
legalize_or_rewrite_for_backend
Legalize, rewriting unsupported ops first when possible.
legalize_or_rewrite_for_backend_with_config
Legalize with full KernelDispatchConfig.
legalize_or_rewrite_for_backend_with_dispatch
Legalize with explicit logical-kernel dispatch policy.
maybe_dump_pipeline
Write a full pipeline dump when RLX_IR_DUMP is set (path prefix or directory).
maybe_log_dispatch_report
Print when RLX_VERBOSE=1 or RLX_DISPATCH_REPORT=1.
nth_order_grad
Scalar wrt, scalar output: differentiate order times.
nth_order_grad_module
Higher-order reverse-mode AD on a GraphModule.
nth_order_grad_with_options
Like nth_order_grad with optional post-layer fusion.
plan_memory_backward
Plan backward activations, then alias params onto weights.
plan_memory_f32_uniform
Liveness-aware planning with every slot sized as num_elements * 4 bytes (wgpu / uniform-f32 arenas). Reuses dead tensor slots so large [n, n] pairwise graphs stay under WebGPU’s 128 MiB binding cap.
plan_memory_with_options
Plan memory with custom alignment and boundary allocation policy.
prepare_grad_graph_for_jvp
Prepare a crate::autodiff::grad_with_loss graph for an outer forward-mode (jvp) pass: decompose *Backward ops and bake in d_output = 1.
prepare_graph_for_ad
Canonical MIR pre-passes before reverse- or forward-mode AD.
prepare_graph_for_backend_with_report
Rewrite toward supported, then report native / common / rewritten / missing.
prepare_mir_for_ad
Return MIR suitable for inspection or a custom AD walk.
prepare_module_for_ad
Lower HIR if needed, then run prepare_graph_for_ad.
promote_params_to_inputs
Convert every Op::Param { name } whose name is in the list to Op::Input { name } of the same shape. All other ops are copied through unchanged. Returns a fresh Graph.
quantized_weight_bits
Returns Some(bits) if node_id is the output of an Op::FakeQuantize { bits, .. } (or FakeQuantizeLSQ) in the forward graph. Used by the autodiff Conv backward to detect the QAT pattern and emit a specialized weight-grad kernel that can skip dead bins (weights that round to the same code share the gradient). Today only the detection is exposed; the specialization is a follow-up commit.
Project a gradient back to the smaller shape it was broadcast from. target_shape is the broadcast source shape (e.g. [C] for a bias added to [N, C, H, W]). Sums over leading prepended axes and over any axis where target was 1 but the gradient is larger, then reshapes to drop the size-1 axes if the rank shrank.
rewrite_for_backend
Rewrite graph toward supported op kinds. Idempotent when already legal.
rewrite_for_backend_with_config
Full dispatch control (policy + per-OpKind overrides).
rewrite_for_backend_with_dispatch
Like rewrite_for_backend but applies logical-kernel common lowers first.
run_passes
Run a sequence of passes, printing the graph after each if verbose.
specialize_params
Substitute listed params with constants. Unlisted params are unchanged.
supported_for_target
Per-target op claims used when a backend doesn’t supply an explicit supported_ops slice. Must stay aligned with each backend’s *_SUPPORTED_OPS in rlx-runtime/src/backend.rs.
supports_op
True when supported is empty (no claim) or contains kind.
unfuse_fused_for_autodiff
Expand fused blocks so per-op VJP rules apply.
unroll_while
Bounded-unroll Op::While up to max_iterations. Each iteration inlines cond and body with all loop-carried captures, then applies Where(active, body_out, carried) per carry (MLX semantics).
vmap
Vectorize forward over a leading batch axis.
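
The bounded unroll described for unroll_while can be sketched on a single scalar carry. The function below is hypothetical and uses closures in place of the cond and body sub-graphs; the `if active` expression plays the role of Where(active, body_out, carried). Treating `active` as sticky is an assumption, though for a pure cond it is equivalent because the carry stops changing once the condition fails.

```rust
/// Bounded unroll of `while cond(x) { x = body(x) }` with at most `max_iterations` replicas.
/// Hypothetical scalar model of the lowering, not the crate's graph rewrite.
fn unroll_while_sketch(
    init: f32,
    max_iterations: usize,
    cond: impl Fn(f32) -> bool,
    body: impl Fn(f32) -> f32,
) -> f32 {
    let mut carried = init;
    let mut active = true;
    for _ in 0..max_iterations {
        // Once the condition fails, later replicas leave the carry untouched.
        active = active && cond(carried);
        // The body replica is always computed in the unrolled form.
        let body_out = body(carried);
        // Where(active, body_out, carried)
        carried = if active { body_out } else { carried };
    }
    carried
}
```

If the loop would need more than max_iterations steps, the unrolled form stops early and returns the carry reached so far, so the bound has to cover the longest expected trip count.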

Type Aliases§

CalibrationRecord
Map of tap NodeId → calibrated quant params.
LegalizeResult
Result of legalize_for_backend — list of (node, kind) pairs whose op is outside the backend’s claimed set. Ok(()) when the graph is fully legalized.
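
The legalization check and the supports_op rule (an empty supported set means no claim, so everything passes) can be modeled as follows. The Kind enum, the error payload type, and both function names are hypothetical; the real LegalizeResult and OpKind definitions are not shown on this page and may differ.

```rust
/// Hypothetical stand-in for an op-kind enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Kind {
    MatMul,
    Add,
    Softmax,
}

/// True when `supported` is empty (no claim) or contains `kind`.
fn supports(supported: &[Kind], kind: Kind) -> bool {
    supported.is_empty() || supported.contains(&kind)
}

/// Assumed shape of the result: Ok(()) when legal, else the offending (node, kind) pairs.
type SketchLegalizeResult = Result<(), Vec<(usize, Kind)>>;

/// Checks every node's kind against the backend's claimed set.
/// The graph is modeled as a flat list of kinds, indexed by node id.
fn legalize_sketch(graph: &[Kind], supported: &[Kind]) -> SketchLegalizeResult {
    let illegal: Vec<(usize, Kind)> = graph
        .iter()
        .copied()
        .enumerate()
        .filter(|&(_, kind)| !supports(supported, kind))
        .collect();
    if illegal.is_empty() {
        Ok(())
    } else {
        Err(illegal)
    }
}
```

Returning every offending pair at once, rather than failing on the first, is what makes a single readable diagnostic possible for a helper like format_legalize_error.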