rlx_autodiff/
prepare_ad.rs1use rlx_ir::hir::LowerError;
37use rlx_ir::mir::MirModule;
38use rlx_ir::{Graph, GraphModule, GraphStage, NodeId};
39
40use rlx_fusion::pass::Pass;
41
42pub use crate::autodiff::{convert_scans_for_ad, inline_custom_fn_for_autodiff};
43pub use rlx_fusion::unfuse_fused_for_autodiff;
44
45#[derive(Debug, Clone, PartialEq, Eq)]
47pub enum AutodiffError {
48 WrongStage {
51 got: GraphStage,
52 hint: &'static str,
53 },
54 Lower(LowerError),
55}
56
57impl std::fmt::Display for AutodiffError {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 match self {
60 Self::WrongStage { got, hint } => {
61 write!(f, "autodiff: cannot run on {got:?} stage — {hint}")
62 }
63 Self::Lower(e) => write!(f, "HIR lower failed: {e}"),
64 }
65 }
66}
67
68impl std::error::Error for AutodiffError {
69 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
70 match self {
71 Self::Lower(e) => Some(e),
72 _ => None,
73 }
74 }
75}
76
77pub fn prepare_graph_for_ad(g: Graph) -> Graph {
99 use rlx_fusion::pass::Pass as _;
100 let g = rlx_fusion::DecomposeFusionRegions.run(g);
101 let g = rlx_fusion::UnfuseElementwiseRegions::FOR_CPU.run(g);
102 let g = crate::legalize_reduce::legalize_multi_axis_reduce(g);
103 let g = rlx_fusion::unfuse_fused_for_autodiff(g);
104 let g = rlx_fusion::LowerDotGeneral.run(g);
105 let g = rlx_fusion::control_flow::inline_if(g);
106 let g = rlx_fusion::control_flow::unroll_while(g);
107 let g = inline_custom_fn_for_autodiff(g);
108 let g = convert_scans_for_ad(g);
109 crate::fuse_splat::fuse_decomposed_gaussian_splat(g)
110}
111
/// Stateless marker type that exposes [`prepare_graph_for_ad`] through the
/// [`Pass`] trait.
#[derive(Clone, Copy, Debug, Default)]
pub struct PrepareForAutodiff;
115
116impl Pass for PrepareForAutodiff {
117 fn name(&self) -> &str {
118 "prepare_for_autodiff"
119 }
120
121 fn run(&self, graph: Graph) -> Graph {
122 prepare_graph_for_ad(graph)
123 }
124}
125
126pub fn prepare_mir_for_ad(mir: MirModule) -> MirModule {
128 MirModule::from_graph(prepare_graph_for_ad(mir.into_graph()))
129}
130
131pub fn prepare_module_for_ad(module: GraphModule) -> Result<GraphModule, AutodiffError> {
133 let mir = module_into_mir(module)?;
134 Ok(MirModule::from_graph(prepare_graph_for_ad(mir.into_graph())).into())
135}
136
137pub fn grad_with_loss_module(module: GraphModule, wrt: &[NodeId]) -> Result<Graph, AutodiffError> {
139 let mir = module_into_mir(module)?;
140 Ok(crate::autodiff::grad_with_loss(mir.as_graph(), wrt))
141}
142
143pub fn jvp_module(module: GraphModule, tangent_for: &[NodeId]) -> Result<Graph, AutodiffError> {
145 let mir = module_into_mir(module)?;
146 Ok(crate::autodiff_fwd::jvp(mir.as_graph(), tangent_for))
147}
148
149pub fn hvp_module(module: GraphModule, wrt: &[NodeId]) -> Result<Graph, AutodiffError> {
151 let mir = module_into_mir(module)?;
152 Ok(crate::autodiff_fwd::hvp(mir.as_graph(), wrt))
153}
154
155pub fn nth_order_grad_module(
157 module: GraphModule,
158 wrt_name: &str,
159 order: usize,
160) -> Result<Graph, AutodiffError> {
161 let mir = module_into_mir(module)?;
162 Ok(crate::higher_order::nth_order_grad(
163 mir.as_graph(),
164 wrt_name,
165 order,
166 ))
167}
168
169fn module_into_mir(module: GraphModule) -> Result<MirModule, AutodiffError> {
170 match module.stage() {
171 GraphStage::Lir => Err(AutodiffError::WrongStage {
172 got: GraphStage::Lir,
173 hint: "use the embedded `mir` from LIR or rebuild from HIR/MIR before AD",
174 }),
175 GraphStage::Hir => module.into_mir().map_err(AutodiffError::Lower),
176 GraphStage::Mir => module.into_mir().map_err(AutodiffError::Lower),
177 }
178}
179
180pub trait MirAutodiffExt {
182 fn prepare_for_autodiff(self) -> MirModule;
184
185 fn grad_with_loss(&self, wrt: &[NodeId]) -> Graph;
187}
188
189impl MirAutodiffExt for MirModule {
190 fn prepare_for_autodiff(self) -> MirModule {
191 prepare_mir_for_ad(self)
192 }
193
194 fn grad_with_loss(&self, wrt: &[NodeId]) -> Graph {
195 crate::autodiff::grad_with_loss(self.as_graph(), wrt)
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::Op;
    use rlx_ir::{DType, Shape};

    /// Builds an `F32` shape with the given dimensions.
    fn f32_shape(d: &[usize]) -> Shape {
        Shape::new(d, DType::F32)
    }

    /// Returns `true` when at least one node of `graph` has an op accepted by `pred`.
    fn any_op(graph: &Graph, pred: impl Fn(&Op) -> bool) -> bool {
        graph.nodes().iter().any(|node| pred(&node.op))
    }

    /// Matches the fused matmul+bias+activation op.
    fn is_fused_matmul(op: &Op) -> bool {
        matches!(op, Op::FusedMatMulBiasAct { .. })
    }

    /// Matches the batch elementwise region op.
    fn is_batch_region(op: &Op) -> bool {
        matches!(op, Op::BatchElementwiseRegion { .. })
    }

    #[test]
    fn hir_direct_linear_grad_module() {
        let module = GraphModule::define("layer", |m| {
            let x = m.input("x", f32_shape(&[2, 8]));
            let w = m.param("w", f32_shape(&[8, 8]));
            let b = m.param("b", f32_shape(&[8]));
            m.linear(x, w, Some(b), None, f32_shape(&[2, 8]))
        });
        let mir = module.into_mir().expect("lower");

        // The lowered MIR must contain the fused op before differentiation.
        assert!(
            any_op(mir.as_graph(), is_fused_matmul),
            "Direct HIR should lower to FusedMatMulBiasAct"
        );

        // Look up the node id of the parameter named "w".
        let w_id = mir
            .as_graph()
            .nodes()
            .iter()
            .filter(|n| matches!(&n.op, Op::Param { name } if name == "w"))
            .map(|n| n.id)
            .next()
            .expect("param w");

        let bwd = grad_with_loss_module(GraphModule::from_mir(mir), &[w_id]).expect("grad");
        assert!(
            !any_op(&bwd, is_fused_matmul),
            "backward graph should not retain fused ops"
        );
        assert!(bwd.outputs.len() >= 2);
    }

    #[test]
    fn prepare_for_autodiff_decomposes_batch_elementwise_region() {
        use rlx_fusion::fk_fusion::{FuseBatchPreprocess, MarkBatchSliceRegions};
        use rlx_fusion::fk_graphs::batch_narrow_relu_primitive_graph;

        let primitive = batch_narrow_relu_primitive_graph("batch_ad", 2, 3, 8, 8);
        let marked = MarkBatchSliceRegions.run(primitive);
        let fused = FuseBatchPreprocess.run(marked);
        assert!(any_op(&fused, is_batch_region));

        let prep = prepare_graph_for_ad(fused);
        assert!(
            !any_op(&prep, is_batch_region),
            "AD prep should decompose BatchElementwiseRegion"
        );
    }

    #[test]
    fn prepare_for_autodiff_pass_matches_fn() {
        let mut g = Graph::new("t");
        let x = g.input("x", f32_shape(&[4]));
        g.set_outputs(vec![x]);

        // Run the same input through both entry points and compare sizes.
        let via_pass = PrepareForAutodiff.run(g.clone());
        let via_fn = prepare_graph_for_ad(g);
        assert_eq!(via_pass.len(), via_fn.len());
    }
}