Skip to main content

rlx_autodiff/
prepare_ad.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! MIR preparation for autodiff — canonical pre-passes shared by reverse-
17//! and forward-mode AD.
18//!
19//! The HIR → MIR → LIR pipeline lowers fusion-friendly blocks to MIR
20//! (`FusionPolicy::Direct`) or primitive chains (`FusionPolicy::Fusable` /
21//! [`FusionPolicy::for_autodiff`]). Autodiff always runs on **MIR**
22//! ([`Graph`]); this module rewrites fused / control-flow / scan shapes
23//! into primitives the VJP table covers.
24//!
25//! Typical training flow:
26//!
//! ```text
//! GraphModule (HIR) ──lower──▶ MirModule ──prepare_graph_for_ad──▶ MirModule
//!                                      └──grad_with_loss──▶ backward Graph
//! ```
//!
//! Compile forward (fused) and backward (AD + cleanup) via
//! [`rlx_compile::CompilePipeline::compile_training`] when the `training`
//! feature is enabled on `rlx-compile`; backward params alias the forward
//! weight layout instead of duplicating arena storage.
35
36use rlx_ir::hir::LowerError;
37use rlx_ir::mir::MirModule;
38use rlx_ir::{Graph, GraphModule, GraphStage, NodeId};
39
40use rlx_fusion::pass::Pass;
41
42pub use crate::autodiff::{convert_scans_for_ad, inline_custom_fn_for_autodiff};
43pub use rlx_fusion::unfuse_fused_for_autodiff;
44
45/// Error from [`grad_with_loss_module`] / [`jvp_module`].
46#[derive(Debug, Clone, PartialEq, Eq)]
47pub enum AutodiffError {
48    /// Autodiff requires MIR (or HIR to lower first). LIR carries a buffer
49    /// plan that does not apply to a freshly built gradient graph.
50    WrongStage {
51        got: GraphStage,
52        hint: &'static str,
53    },
54    Lower(LowerError),
55}
56
57impl std::fmt::Display for AutodiffError {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        match self {
60            Self::WrongStage { got, hint } => {
61                write!(f, "autodiff: cannot run on {got:?} stage — {hint}")
62            }
63            Self::Lower(e) => write!(f, "HIR lower failed: {e}"),
64        }
65    }
66}
67
68impl std::error::Error for AutodiffError {
69    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
70        match self {
71            Self::Lower(e) => Some(e),
72            _ => None,
73        }
74    }
75}
76
77/// Canonical MIR pre-passes before reverse- or forward-mode AD.
78///
79/// Order:
80/// 1. [`UnfuseElementwiseRegions`](rlx_fusion::UnfuseElementwiseRegions)
81/// 2. [`rlx_fusion::unfuse_fused_for_autodiff`] — tier-2 fused ops → primitives
82/// 3. [`LowerDotGeneral`](rlx_fusion::LowerDotGeneral)
83/// 4. [`control_flow::inline_if`]
84/// 5. [`control_flow::unroll_while`]
85/// 6. [`inline_custom_fn_for_autodiff`]
86/// 7. [`convert_scans_for_ad`]
87pub fn prepare_graph_for_ad(g: Graph) -> Graph {
88    use rlx_fusion::pass::Pass as _;
89    let g = rlx_fusion::UnfuseElementwiseRegions.run(g);
90    let g = rlx_fusion::unfuse_fused_for_autodiff(g);
91    let g = rlx_fusion::LowerDotGeneral.run(g);
92    let g = rlx_fusion::control_flow::inline_if(g);
93    let g = rlx_fusion::control_flow::unroll_while(g);
94    let g = inline_custom_fn_for_autodiff(g);
95    let g = convert_scans_for_ad(g);
96    let g = crate::legalize_reduce::legalize_multi_axis_reduce(g);
97    crate::fuse_splat::fuse_decomposed_gaussian_splat(g)
98}
99
100/// [`Pass`] wrapper for [`prepare_graph_for_ad`].
101#[derive(Debug, Clone, Copy, Default)]
102pub struct PrepareForAutodiff;
103
104impl Pass for PrepareForAutodiff {
105    fn name(&self) -> &str {
106        "prepare_for_autodiff"
107    }
108
109    fn run(&self, graph: Graph) -> Graph {
110        prepare_graph_for_ad(graph)
111    }
112}
113
114/// Return MIR suitable for inspection or a custom AD walk.
115pub fn prepare_mir_for_ad(mir: MirModule) -> MirModule {
116    MirModule::from_graph(prepare_graph_for_ad(mir.into_graph()))
117}
118
119/// Lower HIR if needed, then run [`prepare_graph_for_ad`].
120pub fn prepare_module_for_ad(module: GraphModule) -> Result<GraphModule, AutodiffError> {
121    let mir = module_into_mir(module)?;
122    Ok(MirModule::from_graph(prepare_graph_for_ad(mir.into_graph())).into())
123}
124
125/// Reverse-mode AD on a [`GraphModule`] at HIR or MIR stage.
126pub fn grad_with_loss_module(module: GraphModule, wrt: &[NodeId]) -> Result<Graph, AutodiffError> {
127    let mir = module_into_mir(module)?;
128    Ok(crate::autodiff::grad_with_loss(mir.as_graph(), wrt))
129}
130
131/// Forward-mode AD on a [`GraphModule`] at HIR or MIR stage.
132pub fn jvp_module(module: GraphModule, tangent_for: &[NodeId]) -> Result<Graph, AutodiffError> {
133    let mir = module_into_mir(module)?;
134    Ok(crate::autodiff_fwd::jvp(mir.as_graph(), tangent_for))
135}
136
137fn module_into_mir(module: GraphModule) -> Result<MirModule, AutodiffError> {
138    match module.stage() {
139        GraphStage::Lir => Err(AutodiffError::WrongStage {
140            got: GraphStage::Lir,
141            hint: "use the embedded `mir` from LIR or rebuild from HIR/MIR before AD",
142        }),
143        GraphStage::Hir => module.into_mir().map_err(AutodiffError::Lower),
144        GraphStage::Mir => module.into_mir().map_err(AutodiffError::Lower),
145    }
146}
147
148/// MIR extensions for the training pipeline.
149pub trait MirAutodiffExt {
150    /// Run [`prepare_graph_for_ad`] and return primitive MIR.
151    fn prepare_for_autodiff(self) -> MirModule;
152
153    /// [`crate::autodiff::grad_with_loss`] on this module's graph.
154    fn grad_with_loss(&self, wrt: &[NodeId]) -> Graph;
155}
156
157impl MirAutodiffExt for MirModule {
158    fn prepare_for_autodiff(self) -> MirModule {
159        prepare_mir_for_ad(self)
160    }
161
162    fn grad_with_loss(&self, wrt: &[NodeId]) -> Graph {
163        crate::autodiff::grad_with_loss(self.as_graph(), wrt)
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::Op;
    use rlx_ir::{DType, Shape};

    /// Build an `F32` shape with the given dimensions.
    fn f32_shape(d: &[usize]) -> Shape {
        Shape::new(d, DType::F32)
    }

    /// A linear layer defined in HIR lowers to a fused op, and the backward
    /// graph produced by `grad_with_loss_module` contains no such fused op.
    #[test]
    fn hir_direct_linear_grad_module() {
        // Shared predicate: is this op the fused matmul + bias + activation?
        let is_fused = |op: &Op| matches!(op, Op::FusedMatMulBiasAct { .. });

        let layer = GraphModule::define("layer", |m| {
            let x = m.input("x", f32_shape(&[2, 8]));
            let w = m.param("w", f32_shape(&[8, 8]));
            let b = m.param("b", f32_shape(&[8]));
            m.linear(x, w, Some(b), None, f32_shape(&[2, 8]))
        });
        let mir = layer.into_mir().expect("lower");

        let forward_has_fused = mir.as_graph().nodes().iter().any(|n| is_fused(&n.op));
        assert!(
            forward_has_fused,
            "Direct HIR should lower to FusedMatMulBiasAct"
        );

        // Locate the id of the parameter node named "w" in the lowered graph.
        let w = mir
            .as_graph()
            .nodes()
            .iter()
            .find_map(|n| match &n.op {
                Op::Param { name } if name == "w" => Some(n.id),
                _ => None,
            })
            .expect("param w");

        let bwd = grad_with_loss_module(GraphModule::from_mir(mir), &[w]).expect("grad");
        let backward_has_fused = bwd.nodes().iter().any(|n| is_fused(&n.op));
        assert!(
            !backward_has_fused,
            "backward graph should not retain fused ops"
        );
        assert!(bwd.outputs.len() >= 2);
    }

    /// The `Pass` wrapper and the free function yield graphs of equal length
    /// for the same input graph.
    #[test]
    fn prepare_for_autodiff_pass_matches_fn() {
        let mut graph = Graph::new("t");
        let input = graph.input("x", f32_shape(&[4]));
        graph.set_outputs(vec![input]);

        let from_pass = PrepareForAutodiff.run(graph.clone());
        let from_fn = prepare_graph_for_ad(graph);
        assert_eq!(from_pass.len(), from_fn.len());
    }
}