rlx_autodiff/
prepare_ad.rs1use rlx_ir::hir::LowerError;
37use rlx_ir::mir::MirModule;
38use rlx_ir::{Graph, GraphModule, GraphStage, NodeId};
39
40use rlx_fusion::pass::Pass;
41
42pub use crate::autodiff::{convert_scans_for_ad, inline_custom_fn_for_autodiff};
43pub use rlx_fusion::unfuse_fused_for_autodiff;
44
45#[derive(Debug, Clone, PartialEq, Eq)]
47pub enum AutodiffError {
48 WrongStage {
51 got: GraphStage,
52 hint: &'static str,
53 },
54 Lower(LowerError),
55}
56
57impl std::fmt::Display for AutodiffError {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 match self {
60 Self::WrongStage { got, hint } => {
61 write!(f, "autodiff: cannot run on {got:?} stage — {hint}")
62 }
63 Self::Lower(e) => write!(f, "HIR lower failed: {e}"),
64 }
65 }
66}
67
68impl std::error::Error for AutodiffError {
69 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
70 match self {
71 Self::Lower(e) => Some(e),
72 _ => None,
73 }
74 }
75}
76
77pub fn prepare_graph_for_ad(g: Graph) -> Graph {
88 use rlx_fusion::pass::Pass as _;
89 let g = rlx_fusion::UnfuseElementwiseRegions.run(g);
90 let g = rlx_fusion::unfuse_fused_for_autodiff(g);
91 let g = rlx_fusion::LowerDotGeneral.run(g);
92 let g = rlx_fusion::control_flow::inline_if(g);
93 let g = rlx_fusion::control_flow::unroll_while(g);
94 let g = inline_custom_fn_for_autodiff(g);
95 let g = convert_scans_for_ad(g);
96 let g = crate::legalize_reduce::legalize_multi_axis_reduce(g);
97 crate::fuse_splat::fuse_decomposed_gaussian_splat(g)
98}
99
100#[derive(Debug, Clone, Copy, Default)]
102pub struct PrepareForAutodiff;
103
104impl Pass for PrepareForAutodiff {
105 fn name(&self) -> &str {
106 "prepare_for_autodiff"
107 }
108
109 fn run(&self, graph: Graph) -> Graph {
110 prepare_graph_for_ad(graph)
111 }
112}
113
114pub fn prepare_mir_for_ad(mir: MirModule) -> MirModule {
116 MirModule::from_graph(prepare_graph_for_ad(mir.into_graph()))
117}
118
119pub fn prepare_module_for_ad(module: GraphModule) -> Result<GraphModule, AutodiffError> {
121 let mir = module_into_mir(module)?;
122 Ok(MirModule::from_graph(prepare_graph_for_ad(mir.into_graph())).into())
123}
124
125pub fn grad_with_loss_module(module: GraphModule, wrt: &[NodeId]) -> Result<Graph, AutodiffError> {
127 let mir = module_into_mir(module)?;
128 Ok(crate::autodiff::grad_with_loss(mir.as_graph(), wrt))
129}
130
131pub fn jvp_module(module: GraphModule, tangent_for: &[NodeId]) -> Result<Graph, AutodiffError> {
133 let mir = module_into_mir(module)?;
134 Ok(crate::autodiff_fwd::jvp(mir.as_graph(), tangent_for))
135}
136
137fn module_into_mir(module: GraphModule) -> Result<MirModule, AutodiffError> {
138 match module.stage() {
139 GraphStage::Lir => Err(AutodiffError::WrongStage {
140 got: GraphStage::Lir,
141 hint: "use the embedded `mir` from LIR or rebuild from HIR/MIR before AD",
142 }),
143 GraphStage::Hir => module.into_mir().map_err(AutodiffError::Lower),
144 GraphStage::Mir => module.into_mir().map_err(AutodiffError::Lower),
145 }
146}
147
148pub trait MirAutodiffExt {
150 fn prepare_for_autodiff(self) -> MirModule;
152
153 fn grad_with_loss(&self, wrt: &[NodeId]) -> Graph;
155}
156
157impl MirAutodiffExt for MirModule {
158 fn prepare_for_autodiff(self) -> MirModule {
159 prepare_mir_for_ad(self)
160 }
161
162 fn grad_with_loss(&self, wrt: &[NodeId]) -> Graph {
163 crate::autodiff::grad_with_loss(self.as_graph(), wrt)
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::Op;
    use rlx_ir::{DType, Shape};

    /// Shorthand for an `F32` shape with dimensions `d`.
    fn f32_shape(d: &[usize]) -> Shape {
        Shape::new(d, DType::F32)
    }

    /// A linear layer defined directly in HIR lowers to `FusedMatMulBiasAct` in MIR,
    /// and the backward graph built from that MIR contains no such fused op.
    #[test]
    fn hir_direct_linear_grad_module() {
        let module = GraphModule::define("layer", |m| {
            let x = m.input("x", f32_shape(&[2, 8]));
            let w = m.param("w", f32_shape(&[8, 8]));
            let b = m.param("b", f32_shape(&[8]));
            m.linear(x, w, Some(b), None, f32_shape(&[2, 8]))
        });
        let mir = module.into_mir().expect("lower");
        assert!(
            mir.as_graph()
                .nodes()
                .iter()
                .any(|n| matches!(n.op, Op::FusedMatMulBiasAct { .. })),
            "Direct HIR should lower to FusedMatMulBiasAct"
        );

        // Resolve the node id of param `w` before `mir` is moved into the module below.
        let w = mir
            .as_graph()
            .nodes()
            .iter()
            .find(|n| matches!(&n.op, Op::Param { name } if name == "w"))
            .map(|n| n.id)
            .expect("param w");
        let bwd = grad_with_loss_module(GraphModule::from_mir(mir), &[w]).expect("grad");
        assert!(
            !bwd.nodes()
                .iter()
                .any(|n| matches!(n.op, Op::FusedMatMulBiasAct { .. })),
            "backward graph should not retain fused ops"
        );
        assert!(bwd.outputs.len() >= 2);
    }

    /// The `Pass` wrapper and the free function must produce equivalent graphs.
    #[test]
    fn prepare_for_autodiff_pass_matches_fn() {
        let mut g = Graph::new("t");
        let x = g.input("x", f32_shape(&[4]));
        g.set_outputs(vec![x]);
        let via_pass = PrepareForAutodiff.run(g.clone());
        let via_fn = prepare_graph_for_ad(g);
        assert_eq!(via_pass.len(), via_fn.len());
        // Node count alone would not catch a wrapper that dropped or duplicated outputs.
        assert_eq!(via_pass.outputs.len(), via_fn.outputs.len());
    }
}