Crate rlx_autodiff

JAX-shaped program transforms on RLX MIR: autodiff, JVP/HVP, and vmap.

Run prepare_graph_for_ad (or PrepareForAutodiff) before the gradient walk when the graph contains fused ops from HIR Direct lowering or inference fusion passes.
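Why the preparation step matters can be shown with a toy model. The sketch below is not the rlx_autodiff API: `Expr`, `toy_prepare_for_ad`, and `toy_value_and_grad` are hypothetical names invented for illustration, and the derivative walk uses a simple single-variable rule rather than the crate's reverse-mode pass. It assumes only the idea stated above: a differentiation walk knows rules for primitive ops, so fused ops have to be rewritten into primitives first.

```rust
// Toy illustration only: hypothetical types, not the rlx_autodiff API.
#[derive(Clone, Debug, PartialEq)]
enum Expr {
    Var,        // the single input x
    Const(f64), // a constant
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
    // Fused a * b + c, standing in for what a fusion pass might emit.
    FusedMulAdd(Box<Expr>, Box<Expr>, Box<Expr>),
}

// Pre-pass: rewrite fused nodes into primitives the derivative walk knows.
fn toy_prepare_for_ad(e: &Expr) -> Expr {
    match e {
        Expr::Var | Expr::Const(_) => e.clone(),
        Expr::Add(a, b) => Expr::Add(
            Box::new(toy_prepare_for_ad(a)),
            Box::new(toy_prepare_for_ad(b)),
        ),
        Expr::Mul(a, b) => Expr::Mul(
            Box::new(toy_prepare_for_ad(a)),
            Box::new(toy_prepare_for_ad(b)),
        ),
        Expr::FusedMulAdd(a, b, c) => Expr::Add(
            Box::new(Expr::Mul(
                Box::new(toy_prepare_for_ad(a)),
                Box::new(toy_prepare_for_ad(b)),
            )),
            Box::new(toy_prepare_for_ad(c)),
        ),
    }
}

// Derivative walk: returns (value, d value / dx), or an error on a fused op.
fn toy_value_and_grad(e: &Expr, x: f64) -> Result<(f64, f64), String> {
    match e {
        Expr::Var => Ok((x, 1.0)),
        Expr::Const(c) => Ok((*c, 0.0)),
        Expr::Add(a, b) => {
            let (va, da) = toy_value_and_grad(a, x)?;
            let (vb, db) = toy_value_and_grad(b, x)?;
            Ok((va + vb, da + db))
        }
        Expr::Mul(a, b) => {
            let (va, da) = toy_value_and_grad(a, x)?;
            let (vb, db) = toy_value_and_grad(b, x)?;
            // Product rule.
            Ok((va * vb, da * vb + va * db))
        }
        Expr::FusedMulAdd(..) => Err("fused op reached the derivative walk".to_string()),
    }
}
```

In this toy, calling `toy_value_and_grad` directly on a `FusedMulAdd` node returns an error, while calling it on the output of `toy_prepare_for_ad` succeeds. How the real crate reacts to an unprepared graph is not stated on this page.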

Re-exports

pub use autodiff::GradWithLossOptions;
pub use autodiff::grad;
pub use autodiff::grad_with_loss;
pub use autodiff::grad_with_loss_opts;
pub use autodiff::quantized_weight_bits;
pub use autodiff_fwd::hvp;
pub use autodiff_fwd::jvp;
pub use compose::broadcast_scalar;
pub use compose::cse;
pub use decompose_backward::decompose_backward_for_ad;
pub use decompose_backward::decompose_backward_ops;
pub use decompose_backward::decompose_backward_ops_except;
pub use decompose_backward::prepare_grad_graph_for_jvp;
pub use higher_order::HigherOrderOptions;
pub use higher_order::directional_nth_grad;
pub use higher_order::fuse_elementwise;
pub use higher_order::nth_order_grad;
pub use higher_order::nth_order_grad_with_options;
pub use mlip::ForceEnergyLossWeights;
pub use mlip::build_force_energy_loss;
pub use mlip::grad_subgraph;
pub use mlip::grad_subgraph_for_jvp;
pub use prepare_ad::AutodiffError;
pub use prepare_ad::MirAutodiffExt;
pub use prepare_ad::PrepareForAutodiff;
pub use prepare_ad::grad_with_loss_module;
pub use prepare_ad::hvp_module;
pub use prepare_ad::jvp_module;
pub use prepare_ad::nth_order_grad_module;
pub use prepare_ad::prepare_graph_for_ad;
pub use prepare_ad::prepare_mir_for_ad;
pub use prepare_ad::prepare_module_for_ad;
pub use vmap::vmap;

Modules

activation_deriv
Closed-form activation derivatives as primitive MIR (f'(x)).
autodiff
Reverse-mode automatic differentiation (VJP transform).
autodiff_fwd
Forward-mode AD (JVP transform).
compose
Graph composition helpers for higher-order reverse-mode AD.
decompose_backward
Decompose first-order *Backward kernels into primitive MIR so a second reverse-mode pass can differentiate through them.
decompose_backward_kernels
fuse_splat
Fuse GaussianSplatPrepare + GaussianSplatRasterize into Op::GaussianSplatRender for AD.
higher_order
Higher-order reverse-mode AD — stack grad_with_loss with backward decomposition for 2nd/3rd/4th derivatives.
legalize_reduce
mlip
ML interatomic potential (MLIP) helpers — force + energy supervision via embedded inner gradients.
prepare_ad
MIR preparation for autodiff — canonical pre-passes shared by reverse- and forward-mode AD.
vmap
Batched function transformation (vmap).
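
For readers new to the terms used in the module list, the following is a conceptual sketch of what a JVP and a vmap compute. It is not how this crate works internally: rlx_autodiff transforms MIR graphs, whereas this toy uses dual numbers on plain Rust functions. `Dual`, `f`, `toy_jvp`, and `toy_vmap` are hypothetical names chosen for illustration.

```rust
use std::ops::{Add, Mul};

// Conceptual sketch only: hypothetical names, not the rlx_autodiff API.
// A dual number carries a value and a tangent (directional derivative).
#[derive(Clone, Copy, Debug, PartialEq)]
struct Dual {
    val: f64,
    tan: f64,
}

impl Add for Dual {
    type Output = Dual;
    fn add(self, rhs: Dual) -> Dual {
        Dual { val: self.val + rhs.val, tan: self.tan + rhs.tan }
    }
}

impl Mul for Dual {
    type Output = Dual;
    fn mul(self, rhs: Dual) -> Dual {
        // Product rule on the tangent component.
        Dual {
            val: self.val * rhs.val,
            tan: self.tan * rhs.val + self.val * rhs.tan,
        }
    }
}

// Example function: f(x, y) = x * y + x.
fn f(x: Dual, y: Dual) -> Dual {
    x * y + x
}

// JVP: evaluate f at (x, y) and push the tangent (tx, ty) through it.
// Returns (f(x, y), J_f(x, y) applied to (tx, ty)).
fn toy_jvp(x: f64, y: f64, tx: f64, ty: f64) -> (f64, f64) {
    let out = f(Dual { val: x, tan: tx }, Dual { val: y, tan: ty });
    (out.val, out.tan)
}

// vmap semantics: apply a per-example function along a batch axis.
// A real implementation batches the ops themselves; a loop gives the same result.
fn toy_vmap<A: Copy, B, G: Fn(A) -> B>(g: G, batch: &[A]) -> Vec<B> {
    batch.iter().map(|&a| g(a)).collect()
}
```

With these definitions, `toy_jvp(2.0, 3.0, 1.0, 0.0)` yields the value of `f` together with its partial derivative in `x`, and `toy_vmap(|x: f64| toy_jvp(x, 3.0, 1.0, 0.0), &xs)` evaluates that JVP for every entry of a batch `xs`.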