Skip to main content

rlx_autodiff/
lib.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7
8//! JAX-shaped program transforms on RLX MIR: autodiff, JVP/HVP, and vmap.
9//!
10//! Run [`prepare_graph_for_ad`] (or [`PrepareForAutodiff`]) before the
11//! gradient walk when the graph contains fused ops from HIR `Direct`
12//! lowering or inference fusion passes.
13
// Public module tree of the crate. Every module below is declared `pub`, so
// its items stay reachable by module path. A subset of items is additionally
// re-exported at the crate root by the `pub use` lines that follow; in the
// re-export list visible in this file, nothing is re-exported from
// `activation_deriv`, `decompose_backward_kernels`, `fuse_splat`, or
// `legalize_reduce`, so those are reached through their module paths.
14pub mod activation_deriv;
15pub mod autodiff;
16pub mod autodiff_fwd;
17pub mod compose;
18pub mod decompose_backward;
19pub mod decompose_backward_kernels;
20pub mod fuse_splat;
21pub mod higher_order;
22pub mod legalize_reduce;
23pub mod mlip;
24pub mod prepare_ad;
25pub mod vmap;
26
// Crate-root re-exports: each `pub use` below makes the listed items nameable
// directly from the crate root in addition to their defining module's path.
// Removing or renaming an entry here changes the crate's public API even when
// the defining module is untouched.
27pub use autodiff::{
28    GradWithLossOptions, grad, grad_with_loss, grad_with_loss_opts, quantized_weight_bits,
29};
30pub use autodiff_fwd::{hvp, jvp};
31pub use compose::{broadcast_scalar, cse};
32pub use decompose_backward::{
33    decompose_backward_for_ad, decompose_backward_ops, decompose_backward_ops_except,
34    prepare_grad_graph_for_jvp,
35};
36pub use higher_order::{
37    HigherOrderOptions, directional_nth_grad, fuse_elementwise, nth_order_grad,
38    nth_order_grad_with_options,
39};
40pub use mlip::{
41    ForceEnergyLossWeights, build_force_energy_loss, grad_subgraph, grad_subgraph_for_jvp,
42};
43pub use prepare_ad::{
44    AutodiffError, MirAutodiffExt, PrepareForAutodiff, grad_with_loss_module, hvp_module,
45    jvp_module, nth_order_grad_module, prepare_graph_for_ad, prepare_mir_for_ad,
46    prepare_module_for_ad,
47};
// NOTE(review): `vmap` now names both the module declared above and this
// re-exported item at the crate root. That only compiles when the two live in
// different namespaces — presumably `vmap::vmap` is a function; confirm.
48pub use vmap::vmap;