JAX-style program transforms on RLX MIR: reverse-mode autodiff, forward-mode JVP/HVP, and vmap.
Run prepare_graph_for_ad (or PrepareForAutodiff) before the
gradient walk when the graph contains fused ops from HIR Direct
lowering or inference fusion passes.
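The prepare-before-grad rule can be illustrated on a toy graph. This is a minimal sketch under stated assumptions, not RLX code: `Op`, `prepare_for_ad`, `forward`, and `grad` are hypothetical names, and the fused `Softplus` node stands in for the kinds of fused ops that HIR Direct lowering or inference fusion might emit. The prepare pass rewrites each fused node into primitives that have local derivative rules, so the reverse walk only ever sees primitives.

```rust
// Toy IR for illustration only: these types are hypothetical, not the RLX MIR API.
#[derive(Clone, Copy, Debug)]
enum Op {
    Input(usize),      // reads inputs[i]
    Add(usize, usize), // operands are ids of earlier nodes
    Mul(usize, usize),
    Exp(usize),
    Ln1p(usize),     // ln(1 + x)
    Softplus(usize), // fused ln(1 + exp(x)); has no backward rule in this toy
}

/// Hypothetical prepare pass: split every fused Softplus into Exp then Ln1p,
/// remapping node ids so the graph stays topologically ordered.
fn prepare_for_ad(nodes: &[Op]) -> Vec<Op> {
    let mut out: Vec<Op> = Vec::new();
    let mut remap: Vec<usize> = Vec::with_capacity(nodes.len()); // old id -> new id
    for &op in nodes {
        let new_op = match op {
            Op::Input(i) => Op::Input(i),
            Op::Add(a, b) => Op::Add(remap[a], remap[b]),
            Op::Mul(a, b) => Op::Mul(remap[a], remap[b]),
            Op::Exp(a) => Op::Exp(remap[a]),
            Op::Ln1p(a) => Op::Ln1p(remap[a]),
            Op::Softplus(a) => {
                out.push(Op::Exp(remap[a])); // primitive 1: e = exp(x)
                Op::Ln1p(out.len() - 1)      // primitive 2: ln(1 + e)
            }
        };
        out.push(new_op);
        remap.push(out.len() - 1);
    }
    out
}

/// Evaluate every node in order and keep all intermediate values.
fn forward(nodes: &[Op], inputs: &[f64]) -> Vec<f64> {
    let mut v: Vec<f64> = Vec::with_capacity(nodes.len());
    for &op in nodes {
        let x = match op {
            Op::Input(i) => inputs[i],
            Op::Add(a, b) => v[a] + v[b],
            Op::Mul(a, b) => v[a] * v[b],
            Op::Exp(a) => v[a].exp(),
            Op::Ln1p(a) => v[a].ln_1p(),
            Op::Softplus(a) => v[a].exp().ln_1p(),
        };
        v.push(x);
    }
    v
}

/// Reverse-mode gradient walk. Assumes a non-empty graph whose last node is
/// the scalar output; panics if a fused node was not prepared away.
fn grad(nodes: &[Op], inputs: &[f64]) -> Vec<f64> {
    let v = forward(nodes, inputs);
    let mut adj = vec![0.0; nodes.len()];
    *adj.last_mut().unwrap() = 1.0; // seed d(out)/d(out) = 1
    let mut g = vec![0.0; inputs.len()];
    for (id, &op) in nodes.iter().enumerate().rev() {
        let d = adj[id];
        match op {
            Op::Input(i) => g[i] += d,
            Op::Add(a, b) => {
                adj[a] += d;
                adj[b] += d;
            }
            Op::Mul(a, b) => {
                adj[a] += d * v[b];
                adj[b] += d * v[a];
            }
            Op::Exp(a) => adj[a] += d * v[id], // exp'(x) = exp(x)
            Op::Ln1p(a) => adj[a] += d / (1.0 + v[a]),
            Op::Softplus(_) => panic!("fused op reached the gradient walk; prepare first"),
        }
    }
    g
}

fn main() {
    // f(x, y) = softplus(x * y) + x
    let fused = vec![
        Op::Input(0),    // 0: x
        Op::Input(1),    // 1: y
        Op::Mul(0, 1),   // 2: x * y
        Op::Softplus(2), // 3: fused node
        Op::Add(3, 0),   // 4: output
    ];
    let g = grad(&prepare_for_ad(&fused), &[0.5, 2.0]);
    println!("{g:?}");
}
```

In this toy, calling `grad` on `fused` directly panics by design; after preparation the result matches the closed form (sigmoid(xy)·y + 1, sigmoid(xy)·x).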
Re-exports
pub use autodiff::GradWithLossOptions;
pub use autodiff::grad;
pub use autodiff::grad_with_loss;
pub use autodiff::grad_with_loss_opts;
pub use autodiff::quantized_weight_bits;
pub use autodiff_fwd::hvp;
pub use autodiff_fwd::jvp;
pub use compose::broadcast_scalar;
pub use compose::cse;
pub use decompose_backward::decompose_backward_for_ad;
pub use decompose_backward::decompose_backward_ops;
pub use decompose_backward::decompose_backward_ops_except;
pub use decompose_backward::prepare_grad_graph_for_jvp;
pub use higher_order::HigherOrderOptions;
pub use higher_order::directional_nth_grad;
pub use higher_order::fuse_elementwise;
pub use higher_order::nth_order_grad;
pub use higher_order::nth_order_grad_with_options;
pub use mlip::ForceEnergyLossWeights;
pub use mlip::build_force_energy_loss;
pub use mlip::grad_subgraph;
pub use mlip::grad_subgraph_for_jvp;
pub use prepare_ad::AutodiffError;
pub use prepare_ad::MirAutodiffExt;
pub use prepare_ad::PrepareForAutodiff;
pub use prepare_ad::grad_with_loss_module;
pub use prepare_ad::hvp_module;
pub use prepare_ad::jvp_module;
pub use prepare_ad::nth_order_grad_module;
pub use prepare_ad::prepare_graph_for_ad;
pub use prepare_ad::prepare_mir_for_ad;
pub use prepare_ad::prepare_module_for_ad;
pub use vmap::vmap;
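The `jvp` and `hvp` re-exports compute a Jacobian-vector product and a Hessian-vector product. The sketch below shows only the underlying math, using dual numbers on a scalar function of two variables; it does not use this crate's signatures. `Dual`, `Scalar`, `f`, `grad_f`, and `jvp2` are hypothetical, and the hand-written `grad_f` stands in for a reverse-mode gradient graph. An HVP is then the JVP of the gradient, H(p)·v.

```rust
use std::ops::{Add, Mul};

/// Dual number val + tan·ε with ε² = 0; `tan` carries a directional derivative.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Dual {
    val: f64,
    tan: f64,
}

impl Add for Dual {
    type Output = Dual;
    fn add(self, o: Dual) -> Dual {
        Dual { val: self.val + o.val, tan: self.tan + o.tan }
    }
}

impl Mul for Dual {
    type Output = Dual;
    // Product rule: (a + a'ε)(b + b'ε) = ab + (ab' + a'b)ε
    fn mul(self, o: Dual) -> Dual {
        Dual { val: self.val * o.val, tan: self.val * o.tan + self.tan * o.val }
    }
}

/// Number types the example functions can be evaluated on.
trait Scalar: Copy + Add<Output = Self> + Mul<Output = Self> {
    fn lift(c: f64) -> Self;
}
impl Scalar for f64 {
    fn lift(c: f64) -> Self { c }
}
impl Scalar for Dual {
    fn lift(c: f64) -> Self { Dual { val: c, tan: 0.0 } } // constants have zero tangent
}

/// Example function f(x, y) = x²y + y³.
fn f<T: Scalar>(x: T, y: T) -> T {
    x * x * y + y * y * y
}

/// Hand-written ∇f = (2xy, x² + 3y²); stands in for a reverse-mode gradient graph.
fn grad_f<T: Scalar>(x: T, y: T) -> [T; 2] {
    let (two, three) = (T::lift(2.0), T::lift(3.0));
    [two * x * y, x * x + three * y * y]
}

/// JVP of g: R² -> R^M at point p along direction v; returns (g(p), J_g(p)·v).
fn jvp2<const M: usize>(
    g: impl Fn(Dual, Dual) -> [Dual; M],
    p: [f64; 2],
    v: [f64; 2],
) -> ([f64; M], [f64; M]) {
    let out = g(Dual { val: p[0], tan: v[0] }, Dual { val: p[1], tan: v[1] });
    (out.map(|d| d.val), out.map(|d| d.tan))
}

fn main() {
    let (p, v) = ([1.0, 2.0], [1.0, 0.0]);
    let (fp, dfv) = jvp2(|x, y| [f(x, y)], p, v); // f(p) and ∇f(p)·v
    let (gp, hv) = jvp2(grad_f::<Dual>, p, v); // ∇f(p) and H(p)·v (an HVP)
    println!("f = {fp:?}, df·v = {dfv:?}, grad = {gp:?}, Hv = {hv:?}");
}
```

With p = (1, 2) and v = (1, 0), the Hessian [[2y, 2x], [2x, 6y]] gives H·v = (4, 2), which is what the second call returns.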
Modules
- activation_deriv - Closed-form activation derivatives as primitive MIR (f'(x)).
- autodiff - Reverse-mode automatic differentiation (VJP transform).
- autodiff_fwd - Forward-mode AD (JVP transform).
- compose - Graph composition helpers for higher-order reverse-mode AD.
- decompose_backward - Decompose first-order *Backward kernels into primitive MIR so a second reverse-mode pass can differentiate through them.
- decompose_backward_kernels
- fuse_splat - Fuse GaussianSplatPrepare → GaussianSplatRasterize into Op::GaussianSplatRender for AD.
- higher_order - Higher-order reverse-mode AD: stack grad_with_loss with backward decomposition for 2nd/3rd/4th derivatives.
- legalize_reduce
- mlip - ML interatomic potential (MLIP) helpers: force + energy supervision via embedded inner gradients.
- prepare_ad - MIR preparation for autodiff: canonical pre-passes shared by reverse- and forward-mode AD.
- vmap - Batched function transformation (vmap).
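The contract of a batching transform like vmap can be sketched without a graph: take a function written for one example and return one that accepts a whole batch. This is a minimal sketch under assumptions, not the crate's `vmap`: `vmap_rows` and `sq_norm` are hypothetical, and the batch is assumed to be a row-major [batch, dim] buffer of `f64`.

```rust
/// Hypothetical per-example function: squared L2 norm of one row.
fn sq_norm(x: &[f64]) -> f64 {
    x.iter().map(|v| v * v).sum()
}

/// Lift a per-example function to a batched one over a row-major
/// [batch, dim] buffer by applying it to each row in turn.
fn vmap_rows<F>(f: F, dim: usize) -> impl Fn(&[f64]) -> Vec<f64>
where
    F: Fn(&[f64]) -> f64,
{
    assert!(dim > 0, "dim must be positive");
    move |batch: &[f64]| {
        assert_eq!(batch.len() % dim, 0, "batch length must be a multiple of dim");
        batch.chunks(dim).map(|row| f(row)).collect()
    }
}

fn main() {
    let batched = vmap_rows(sq_norm, 2);
    // Three rows of two elements: [1, 2], [3, 4], [0, 0.5]
    println!("{:?}", batched(&[1.0, 2.0, 3.0, 4.0, 0.0, 0.5]));
}
```

JAX's vmap does not loop like this; it rewrites each primitive to carry an extra batch axis, which keeps the batched program a single vectorized graph. The loop here only illustrates the input/output contract.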