rlx-autodiff 0.2.0

JAX-shaped transforms for RLX MIR — autodiff, JVP/HVP, vmap

rlx-autodiff

JAX-shaped transforms on RLX MIR: reverse-mode grad_with_loss, forward-mode jvp / hvp, and vmap (leading-axis batching).
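Forward mode can be illustrated on scalars. A JVP evaluates f(x) together with the directional derivative J(x)·v in a single forward pass. The sketch does this with dual numbers over plain f64 values. Dual and jvp_scalar are hypothetical names chosen for illustration, not rlx-autodiff's API, which transforms RLX MIR graphs rather than Rust closures. An HVP is commonly built by composing forward mode over reverse mode; this sketch covers only the first-order JVP.

```rust
use std::ops::{Add, Mul};

/// Hypothetical dual number: `val` is the primal value, `tan` the tangent
/// (directional derivative) carried alongside it through every op.
#[derive(Clone, Copy, Debug)]
struct Dual {
    val: f64,
    tan: f64,
}

impl Add for Dual {
    type Output = Dual;
    fn add(self, rhs: Dual) -> Dual {
        // Sum rule: d(a + b) = da + db
        Dual { val: self.val + rhs.val, tan: self.tan + rhs.tan }
    }
}

impl Mul for Dual {
    type Output = Dual;
    fn mul(self, rhs: Dual) -> Dual {
        // Product rule: d(a * b) = a * db + b * da
        Dual {
            val: self.val * rhs.val,
            tan: self.val * rhs.tan + rhs.val * self.tan,
        }
    }
}

/// Hypothetical helper (not the crate's `jvp`): seeds each input with its
/// tangent component from `v`, runs `f` once, and returns (f(x), J(x)·v).
fn jvp_scalar<F>(f: F, x: &[f64], v: &[f64]) -> (f64, f64)
where
    F: Fn(&[Dual]) -> Dual,
{
    assert_eq!(x.len(), v.len(), "primal and tangent must have the same length");
    let inputs: Vec<Dual> = x
        .iter()
        .zip(v)
        .map(|(&val, &tan)| Dual { val, tan })
        .collect();
    let out = f(&inputs);
    (out.val, out.tan)
}

fn main() {
    // f(x, y) = x*y + x*x, so df/dx = y + 2x and df/dy = x.
    let f = |p: &[Dual]| p[0] * p[1] + p[0] * p[0];
    // Tangent (1, 0) selects df/dx at (3, 4).
    let (y, t) = jvp_scalar(f, &[3.0, 4.0], &[1.0, 0.0]);
    println!("f = {y}, jvp = {t}");
}
```

Picking the tangent (0, 1) instead would yield df/dy; recovering a full Jacobian this way takes one forward pass per input, which is why reverse mode is preferred for scalar losses.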

Depends on rlx-ir and rlx-fusion (the latter is used to unfuse fused ops before AD when needed).
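Unfusing before AD means derivative rules are only needed for primitive ops. The toy pass below rewrites a fused multiply-add node into a Mul followed by an Add and checks that evaluation is unchanged. The Op enum, unfuse, and eval are hypothetical illustrations, not rlx-ir or rlx-fusion types. The real pipeline may also keep some fused ops intact and differentiate them directly, since the autodiff module provides fused-op VJPs.

```rust
/// Hypothetical toy op set; operands are indices of earlier nodes.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Op {
    Input(usize),                     // index into the input slice
    Mul(usize, usize),
    Add(usize, usize),
    FusedMulAdd(usize, usize, usize), // a * b + c as a single node
}

/// Rewrites every FusedMulAdd into Mul + Add, remapping operand indices.
fn unfuse(graph: &[Op]) -> Vec<Op> {
    let mut out = Vec::with_capacity(graph.len());
    // remap[i] = index in `out` that holds the value of original node i
    let mut remap: Vec<usize> = Vec::with_capacity(graph.len());
    for op in graph {
        let new_op = match *op {
            Op::Input(k) => Op::Input(k),
            Op::Mul(a, b) => Op::Mul(remap[a], remap[b]),
            Op::Add(a, b) => Op::Add(remap[a], remap[b]),
            Op::FusedMulAdd(a, b, c) => {
                // Emit the product first, then add `c` to it.
                out.push(Op::Mul(remap[a], remap[b]));
                let prod = out.len() - 1;
                Op::Add(prod, remap[c])
            }
        };
        out.push(new_op);
        remap.push(out.len() - 1);
    }
    out
}

/// Evaluates a graph in topological order and returns the last node's value.
fn eval(graph: &[Op], inputs: &[f64]) -> f64 {
    let mut vals: Vec<f64> = Vec::with_capacity(graph.len());
    for op in graph {
        let v = match *op {
            Op::Input(k) => inputs[k],
            Op::Mul(a, b) => vals[a] * vals[b],
            Op::Add(a, b) => vals[a] + vals[b],
            Op::FusedMulAdd(a, b, c) => vals[a] * vals[b] + vals[c],
        };
        vals.push(v);
    }
    *vals.last().expect("graph must not be empty")
}

fn main() {
    // Inputs a, b, c at nodes 0..=2; node 3 computes a*b + c as one fused op.
    let fused = vec![Op::Input(0), Op::Input(1), Op::Input(2), Op::FusedMulAdd(0, 1, 2)];
    let plain = unfuse(&fused);
    let x = [2.0, 3.0, 4.0];
    println!("{plain:?}");
    println!("fused = {}, unfused = {}", eval(&fused, &x), eval(&plain, &x));
}
```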

What's here

  • autodiff: reverse-mode AD; covers fused-op VJPs, control flow (If / While / Scan), and custom-fn inlining for AD.
  • autodiff_fwd: forward-mode AD.
  • prepare_ad: prepare_graph_for_ad and MIR/module preparation.
  • vmap: batched function transform.
  • legalize_reduce: reduce-legalization helpers for training graphs.
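The reverse-mode idea behind the autodiff module can be sketched on a scalar tape. The forward pass records each op with the local partial derivatives of its output with respect to its operands. A backward sweep then seeds the output adjoint with 1 and accumulates adjoints into the parents. Tape and loss_and_grad are hypothetical; returning (loss, gradient) only mirrors the name grad_with_loss and is not its actual signature, which operates on MIR graphs.

```rust
/// Hypothetical scalar tape for reverse-mode AD.
struct Tape {
    vals: Vec<f64>,
    // For each node: (parent index, local partial d node / d parent).
    parents: Vec<Vec<(usize, f64)>>,
}

impl Tape {
    fn new() -> Self {
        Tape { vals: Vec::new(), parents: Vec::new() }
    }

    fn push(&mut self, val: f64, parents: Vec<(usize, f64)>) -> usize {
        self.vals.push(val);
        self.parents.push(parents);
        self.vals.len() - 1
    }

    fn input(&mut self, v: f64) -> usize {
        self.push(v, Vec::new())
    }

    fn add(&mut self, a: usize, b: usize) -> usize {
        let v = self.vals[a] + self.vals[b];
        self.push(v, vec![(a, 1.0), (b, 1.0)])
    }

    fn mul(&mut self, a: usize, b: usize) -> usize {
        let (va, vb) = (self.vals[a], self.vals[b]);
        // d(a*b)/da = b, d(a*b)/db = a
        self.push(va * vb, vec![(a, vb), (b, va)])
    }

    /// Reverse sweep: adjoint of `out` is 1; walk nodes backwards,
    /// pushing adjoint * local partial into each parent.
    fn backward(&self, out: usize) -> Vec<f64> {
        let mut adj = vec![0.0; self.vals.len()];
        adj[out] = 1.0;
        for i in (0..=out).rev() {
            for &(p, local) in &self.parents[i] {
                adj[p] += adj[i] * local;
            }
        }
        adj
    }
}

/// Hypothetical loss-plus-gradient helper for loss = sum_i x_i^2.
fn loss_and_grad(x: &[f64]) -> (f64, Vec<f64>) {
    assert!(!x.is_empty(), "need at least one input");
    let mut t = Tape::new();
    let ins: Vec<usize> = x.iter().map(|&v| t.input(v)).collect();
    let mut acc = t.mul(ins[0], ins[0]);
    for &i in &ins[1..] {
        let sq = t.mul(i, i);
        acc = t.add(acc, sq);
    }
    let adj = t.backward(acc);
    (t.vals[acc], ins.iter().map(|&i| adj[i]).collect())
}

fn main() {
    // The gradient of sum(x_i^2) is 2x, obtained from one backward sweep.
    let (loss, grad) = loss_and_grad(&[1.0, 2.0, 3.0]);
    println!("loss = {loss}, grad = {grad:?}");
}
```

One backward sweep yields the gradient with respect to every input at once, which is why reverse mode suits scalar training losses.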

Feature

Enable it through rlx-opt with the training feature (on by default), or depend on this crate directly for a minimal, AD-only dependency tree.

Build / test

cargo test -p rlx-autodiff

License

GPL-3.0-only.