Skip to main content

rlx_autodiff/
prepare_ad.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! MIR preparation for autodiff — canonical pre-passes shared by reverse-
17//! and forward-mode AD.
18//!
19//! The HIR → MIR → LIR pipeline lowers fusion-friendly blocks to MIR
20//! (`FusionPolicy::Direct`) or primitive chains (`FusionPolicy::Fusable` /
21//! [`FusionPolicy::for_autodiff`]). Autodiff always runs on **MIR**
22//! ([`Graph`]); this module rewrites fused / control-flow / scan shapes
23//! into primitives the VJP table covers.
24//!
25//! Typical training flow:
26//!
//! ```text
//! GraphModule (HIR) ──lower──▶ MirModule ──prepare_graph_for_ad──▶ MirModule
//!                                      └──grad_with_loss──▶ backward Graph
//! ```
//!
//! Compile forward (fused) and backward (AD + cleanup) via
//! `rlx_compile::CompilePipeline::compile_training` when the `training`
//! feature is enabled on `rlx-compile`; backward params alias the forward
//! weight layout instead of duplicating arena storage.
35
36use rlx_ir::hir::LowerError;
37use rlx_ir::mir::MirModule;
38use rlx_ir::{Graph, GraphModule, GraphStage, NodeId};
39
40use rlx_fusion::pass::Pass;
41
42pub use crate::autodiff::{convert_scans_for_ad, inline_custom_fn_for_autodiff};
43pub use rlx_fusion::unfuse_fused_for_autodiff;
44
/// Error returned by the module-level autodiff entry points in this file:
/// [`prepare_module_for_ad`], [`grad_with_loss_module`], [`jvp_module`],
/// [`hvp_module`] and [`nth_order_grad_module`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AutodiffError {
    /// Autodiff requires MIR (or HIR to lower first). LIR carries a buffer
    /// plan that does not apply to a freshly built gradient graph.
    WrongStage {
        /// Stage the rejected module was at.
        got: GraphStage,
        /// Static remediation text appended to the `Display` message.
        hint: &'static str,
    },
    /// Converting the module to MIR failed; wraps the underlying
    /// [`LowerError`], which is also exposed through `Error::source`.
    Lower(LowerError),
}
56
57impl std::fmt::Display for AutodiffError {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        match self {
60            Self::WrongStage { got, hint } => {
61                write!(f, "autodiff: cannot run on {got:?} stage — {hint}")
62            }
63            Self::Lower(e) => write!(f, "HIR lower failed: {e}"),
64        }
65    }
66}
67
68impl std::error::Error for AutodiffError {
69    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
70        match self {
71            Self::Lower(e) => Some(e),
72            _ => None,
73        }
74    }
75}
76
77/// Canonical MIR pre-passes before reverse- or forward-mode AD.
78///
79/// Order:
80/// 0. [`DecomposeFusionRegions`](rlx_fusion::DecomposeFusionRegions) — FKL batch/transform → primitives
81/// 1. [`UnfuseElementwiseRegions`](rlx_fusion::UnfuseElementwiseRegions)
82/// 2. [`legalize_multi_axis_reduce`](crate::legalize_reduce::legalize_multi_axis_reduce)
83/// 3. [`rlx_fusion::unfuse_fused_for_autodiff`] — tier-2 fused ops → primitives
84/// 4. [`LowerDotGeneral`](rlx_fusion::LowerDotGeneral)
85/// 5. [`control_flow::inline_if`]
86/// 6. [`control_flow::unroll_while`]
87/// 7. [`inline_custom_fn_for_autodiff`]
88/// 8. [`convert_scans_for_ad`]
89///
90/// **Why `legalize_multi_axis_reduce` runs early** (step 2, before
91/// `unfuse_fused_for_autodiff`). The unfuse walker assumes every input
92/// is already mapped by the time its consumer is visited; a user-built
93/// `Op::Reduce { axes: [0, 1], .. }` in a custom loss head can leave
94/// a dangling `id_map[i]` lookup inside the unfuser. Decomposing
95/// multi-axis Reduce into a single-axis chain first keeps the walk in
96/// strict topo order and is a no-op for graphs that never used
97/// multi-axis reduce.
98pub fn prepare_graph_for_ad(g: Graph) -> Graph {
99    use rlx_fusion::pass::Pass as _;
100    let g = rlx_fusion::DecomposeFusionRegions.run(g);
101    let g = rlx_fusion::UnfuseElementwiseRegions::FOR_CPU.run(g);
102    let g = crate::legalize_reduce::legalize_multi_axis_reduce(g);
103    let g = rlx_fusion::unfuse_fused_for_autodiff(g);
104    let g = rlx_fusion::LowerDotGeneral.run(g);
105    let g = rlx_fusion::control_flow::inline_if(g);
106    let g = rlx_fusion::control_flow::unroll_while(g);
107    let g = inline_custom_fn_for_autodiff(g);
108    let g = convert_scans_for_ad(g);
109    crate::fuse_splat::fuse_decomposed_gaussian_splat(g)
110}
111
112/// [`Pass`] wrapper for [`prepare_graph_for_ad`].
113#[derive(Debug, Clone, Copy, Default)]
114pub struct PrepareForAutodiff;
115
116impl Pass for PrepareForAutodiff {
117    fn name(&self) -> &str {
118        "prepare_for_autodiff"
119    }
120
121    fn run(&self, graph: Graph) -> Graph {
122        prepare_graph_for_ad(graph)
123    }
124}
125
126/// Return MIR suitable for inspection or a custom AD walk.
127pub fn prepare_mir_for_ad(mir: MirModule) -> MirModule {
128    MirModule::from_graph(prepare_graph_for_ad(mir.into_graph()))
129}
130
131/// Lower HIR if needed, then run [`prepare_graph_for_ad`].
132pub fn prepare_module_for_ad(module: GraphModule) -> Result<GraphModule, AutodiffError> {
133    let mir = module_into_mir(module)?;
134    Ok(MirModule::from_graph(prepare_graph_for_ad(mir.into_graph())).into())
135}
136
137/// Reverse-mode AD on a [`GraphModule`] at HIR or MIR stage.
138pub fn grad_with_loss_module(module: GraphModule, wrt: &[NodeId]) -> Result<Graph, AutodiffError> {
139    let mir = module_into_mir(module)?;
140    Ok(crate::autodiff::grad_with_loss(mir.as_graph(), wrt))
141}
142
143/// Forward-mode AD on a [`GraphModule`] at HIR or MIR stage.
144pub fn jvp_module(module: GraphModule, tangent_for: &[NodeId]) -> Result<Graph, AutodiffError> {
145    let mir = module_into_mir(module)?;
146    Ok(crate::autodiff_fwd::jvp(mir.as_graph(), tangent_for))
147}
148
149/// Hessian-vector product module wrapper (`wrt` params get tangent inputs).
150pub fn hvp_module(module: GraphModule, wrt: &[NodeId]) -> Result<Graph, AutodiffError> {
151    let mir = module_into_mir(module)?;
152    Ok(crate::autodiff_fwd::hvp(mir.as_graph(), wrt))
153}
154
155/// Higher-order reverse-mode AD on a [`GraphModule`].
156pub fn nth_order_grad_module(
157    module: GraphModule,
158    wrt_name: &str,
159    order: usize,
160) -> Result<Graph, AutodiffError> {
161    let mir = module_into_mir(module)?;
162    Ok(crate::higher_order::nth_order_grad(
163        mir.as_graph(),
164        wrt_name,
165        order,
166    ))
167}
168
169fn module_into_mir(module: GraphModule) -> Result<MirModule, AutodiffError> {
170    match module.stage() {
171        GraphStage::Lir => Err(AutodiffError::WrongStage {
172            got: GraphStage::Lir,
173            hint: "use the embedded `mir` from LIR or rebuild from HIR/MIR before AD",
174        }),
175        GraphStage::Hir => module.into_mir().map_err(AutodiffError::Lower),
176        GraphStage::Mir => module.into_mir().map_err(AutodiffError::Lower),
177    }
178}
179
/// MIR extensions for the training pipeline.
///
/// Adds method-call forms of this module's free functions to an MIR module
/// type; the implementation for [`MirModule`] lives in this file.
pub trait MirAutodiffExt {
    /// Run [`prepare_graph_for_ad`] and return primitive MIR.
    ///
    /// Consumes `self`.
    fn prepare_for_autodiff(self) -> MirModule;

    /// [`crate::autodiff::grad_with_loss`] on this module's graph.
    ///
    /// `wrt` lists the nodes to differentiate with respect to; the module
    /// itself is only borrowed.
    fn grad_with_loss(&self, wrt: &[NodeId]) -> Graph;
}
188
189impl MirAutodiffExt for MirModule {
190    fn prepare_for_autodiff(self) -> MirModule {
191        prepare_mir_for_ad(self)
192    }
193
194    fn grad_with_loss(&self, wrt: &[NodeId]) -> Graph {
195        crate::autodiff::grad_with_loss(self.as_graph(), wrt)
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::Op;
    use rlx_ir::{DType, Shape};

    /// Build an `f32` shape with the given dimensions.
    fn f32_shape(d: &[usize]) -> Shape {
        Shape::new(d, DType::F32)
    }

    /// True when at least one node of `g` carries an op accepted by `pred`.
    fn has_op(g: &Graph, pred: fn(&Op) -> bool) -> bool {
        g.nodes().iter().any(|node| pred(&node.op))
    }

    fn is_fused_linear(op: &Op) -> bool {
        matches!(op, Op::FusedMatMulBiasAct { .. })
    }

    fn is_batch_region(op: &Op) -> bool {
        matches!(op, Op::BatchElementwiseRegion { .. })
    }

    /// A Direct-lowered linear layer is fused in the forward MIR, and the
    /// graph produced by `grad_with_loss_module` contains no fused op.
    #[test]
    fn hir_direct_linear_grad_module() {
        let module = GraphModule::define("layer", |m| {
            let x = m.input("x", f32_shape(&[2, 8]));
            let w = m.param("w", f32_shape(&[8, 8]));
            let b = m.param("b", f32_shape(&[8]));
            m.linear(x, w, Some(b), None, f32_shape(&[2, 8]))
        });
        let mir = module.into_mir().expect("lower");
        assert!(
            has_op(mir.as_graph(), is_fused_linear),
            "Direct HIR should lower to FusedMatMulBiasAct"
        );

        // Locate the weight parameter by name to differentiate against it.
        let w = mir
            .as_graph()
            .nodes()
            .iter()
            .find(|n| matches!(&n.op, Op::Param { name } if name == "w"))
            .map(|n| n.id)
            .expect("param w");
        let bwd = grad_with_loss_module(GraphModule::from_mir(mir), &[w]).expect("grad");
        assert!(
            !has_op(&bwd, is_fused_linear),
            "backward graph should not retain fused ops"
        );
        assert!(bwd.outputs.len() >= 2);
    }

    /// AD preparation removes every `BatchElementwiseRegion` produced by the
    /// batch-fusion passes.
    #[test]
    fn prepare_for_autodiff_decomposes_batch_elementwise_region() {
        use rlx_fusion::fk_fusion::{FuseBatchPreprocess, MarkBatchSliceRegions};
        use rlx_fusion::fk_graphs::batch_narrow_relu_primitive_graph;

        let primitive = batch_narrow_relu_primitive_graph("batch_ad", 2, 3, 8, 8);
        let marked = MarkBatchSliceRegions.run(primitive);
        let fused = FuseBatchPreprocess.run(marked);
        assert!(has_op(&fused, is_batch_region));

        let prep = prepare_graph_for_ad(fused);
        assert!(
            !has_op(&prep, is_batch_region),
            "AD prep should decompose BatchElementwiseRegion"
        );
    }

    /// The `Pass` wrapper and the free function yield graphs of equal length
    /// on a trivial single-input graph.
    #[test]
    fn prepare_for_autodiff_pass_matches_fn() {
        let mut g = Graph::new("t");
        let x = g.input("x", f32_shape(&[4]));
        g.set_outputs(vec![x]);

        let via_pass = PrepareForAutodiff.run(g.clone());
        let via_fn = prepare_graph_for_ad(g);
        assert_eq!(via_pass.len(), via_fn.len());
    }
}
274}