use rlx_ir::hir::LowerError;
use rlx_ir::mir::MirModule;
use rlx_ir::{Graph, GraphModule, GraphStage, NodeId};
use rlx_fusion::pass::Pass;
pub use crate::autodiff::{convert_scans_for_ad, inline_custom_fn_for_autodiff};
pub use rlx_fusion::unfuse_fused_for_autodiff;
/// Errors returned by the module-level autodiff entry points in this file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AutodiffError {
/// The module was at a stage the autodiff entry points refuse to start from.
WrongStage {
/// Stage the rejected module was at.
got: GraphStage,
/// Static, human-readable suggestion shown after the stage in the error message.
hint: &'static str,
},
/// Converting the module to MIR failed; carries the underlying [`LowerError`].
Lower(LowerError),
}
/// User-facing rendering of [`AutodiffError`].
///
/// `WrongStage` prints the offending stage (via its `Debug` form) followed by
/// the recovery hint; `Lower` prefixes the wrapped error's own message.
impl std::fmt::Display for AutodiffError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Lower(e) => f.write_fmt(format_args!("HIR lower failed: {e}")),
            Self::WrongStage { got, hint } => f.write_fmt(format_args!(
                "autodiff: cannot run on {got:?} stage — {hint}"
            )),
        }
    }
}
impl std::error::Error for AutodiffError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Lower(e) => Some(e),
_ => None,
}
}
}
/// Runs the fixed sequence of graph rewrites applied before autodiff.
///
/// Every step consumes the graph produced by the step before it, so the passes
/// execute strictly in the order written below.
///
/// NOTE(review): the ordering is presumably significant (e.g. unfusing before
/// lowering) — confirm against the individual passes before reordering.
pub fn prepare_graph_for_ad(g: Graph) -> Graph {
    use rlx_fusion::pass::Pass as _;

    let regions_unfused = rlx_fusion::UnfuseElementwiseRegions.run(g);
    let fused_ops_unfused = rlx_fusion::unfuse_fused_for_autodiff(regions_unfused);
    let dots_lowered = rlx_fusion::LowerDotGeneral.run(fused_ops_unfused);
    let ifs_inlined = rlx_fusion::control_flow::inline_if(dots_lowered);
    let whiles_unrolled = rlx_fusion::control_flow::unroll_while(ifs_inlined);
    let custom_fns_inlined = inline_custom_fn_for_autodiff(whiles_unrolled);
    let scans_converted = convert_scans_for_ad(custom_fns_inlined);
    let reduces_legalized =
        crate::legalize_reduce::legalize_multi_axis_reduce(scans_converted);
    crate::fuse_splat::fuse_decomposed_gaussian_splat(reduces_legalized)
}
/// Stateless [`Pass`] wrapper around [`prepare_graph_for_ad`], so the autodiff
/// preparation can be used wherever a `Pass` value is expected.
#[derive(Debug, Clone, Copy, Default)]
pub struct PrepareForAutodiff;
impl Pass for PrepareForAutodiff {
/// Name reported for this pass.
fn name(&self) -> &str {
"prepare_for_autodiff"
}
/// Delegates directly to [`prepare_graph_for_ad`].
fn run(&self, graph: Graph) -> Graph {
prepare_graph_for_ad(graph)
}
}
/// Applies [`prepare_graph_for_ad`] to the graph inside a [`MirModule`].
///
/// The module is consumed, its graph is prepared, and a new `MirModule` is
/// built from the prepared graph.
pub fn prepare_mir_for_ad(mir: MirModule) -> MirModule {
    let prepared = prepare_graph_for_ad(mir.into_graph());
    MirModule::from_graph(prepared)
}
pub fn prepare_module_for_ad(module: GraphModule) -> Result<GraphModule, AutodiffError> {
let mir = module_into_mir(module)?;
Ok(MirModule::from_graph(prepare_graph_for_ad(mir.into_graph())).into())
}
/// Converts `module` to MIR and runs `crate::autodiff::grad_with_loss` on its
/// graph for the nodes in `wrt`.
///
/// NOTE(review): this wrapper does not call `prepare_graph_for_ad` itself;
/// presumably `grad_with_loss` performs any required preparation — confirm.
///
/// # Errors
///
/// Returns [`AutodiffError::WrongStage`] for an LIR module, or
/// [`AutodiffError::Lower`] when conversion to MIR fails.
pub fn grad_with_loss_module(module: GraphModule, wrt: &[NodeId]) -> Result<Graph, AutodiffError> {
    module_into_mir(module).map(|mir| crate::autodiff::grad_with_loss(mir.as_graph(), wrt))
}
/// Converts `module` to MIR and runs `crate::autodiff_fwd::jvp` on its graph
/// with the nodes in `tangent_for`.
///
/// # Errors
///
/// Returns [`AutodiffError::WrongStage`] for an LIR module, or
/// [`AutodiffError::Lower`] when conversion to MIR fails.
pub fn jvp_module(module: GraphModule, tangent_for: &[NodeId]) -> Result<Graph, AutodiffError> {
    module_into_mir(module).map(|mir| crate::autodiff_fwd::jvp(mir.as_graph(), tangent_for))
}
/// Turns a [`GraphModule`] into a [`MirModule`], rejecting LIR modules.
///
/// HIR and MIR modules are both passed through `GraphModule::into_mir`; any
/// failure there is wrapped in [`AutodiffError::Lower`]. LIR modules are
/// refused with [`AutodiffError::WrongStage`].
fn module_into_mir(module: GraphModule) -> Result<MirModule, AutodiffError> {
    match module.stage() {
        GraphStage::Hir | GraphStage::Mir => module.into_mir().map_err(AutodiffError::Lower),
        GraphStage::Lir => {
            let hint = "use the embedded `mir` from LIR or rebuild from HIR/MIR before AD";
            Err(AutodiffError::WrongStage {
                got: GraphStage::Lir,
                hint,
            })
        }
    }
}
/// Extension trait that exposes the autodiff helpers as methods.
/// Implemented in this file for [`MirModule`].
pub trait MirAutodiffExt {
/// Consumes the module and returns a module prepared for autodiff.
fn prepare_for_autodiff(self) -> MirModule;
/// Builds a graph for the nodes in `wrt` without consuming the module.
/// NOTE(review): presumably the backward graph carrying the loss and the
/// gradients w.r.t. `wrt` — confirm against `crate::autodiff::grad_with_loss`.
fn grad_with_loss(&self, wrt: &[NodeId]) -> Graph;
}
/// Method-call sugar on [`MirModule`]: each method forwards to the matching
/// free function.
impl MirAutodiffExt for MirModule {
    /// Forwards to [`prepare_mir_for_ad`].
    fn prepare_for_autodiff(self) -> Self {
        prepare_mir_for_ad(self)
    }

    /// Forwards to `crate::autodiff::grad_with_loss` on this module's graph.
    fn grad_with_loss(&self, wrt: &[NodeId]) -> Graph {
        let graph = self.as_graph();
        crate::autodiff::grad_with_loss(graph, wrt)
    }
}
// Unit tests for the module-level autodiff entry points.
#[cfg(test)]
mod tests {
use super::*;
use rlx_ir::op::Op;
use rlx_ir::{DType, Shape};
// Builds an F32 `Shape` with the given dimensions.
fn f32_shape(d: &[usize]) -> Shape {
Shape::new(d, DType::F32)
}
// A linear layer defined via `GraphModule::define` lowers to a fused op, and
// the graph returned by `grad_with_loss_module` must not contain that op.
#[test]
fn hir_direct_linear_grad_module() {
let module = GraphModule::define("layer", |m| {
let x = m.input("x", f32_shape(&[2, 8]));
let w = m.param("w", f32_shape(&[8, 8]));
let b = m.param("b", f32_shape(&[8]));
m.linear(x, w, Some(b), None, f32_shape(&[2, 8]))
});
let mir = module.into_mir().expect("lower");
// Precondition: lowering produced the fused matmul+bias+activation op.
assert!(
mir.as_graph()
.nodes()
.iter()
.any(|n| matches!(n.op, Op::FusedMatMulBiasAct { .. })),
"Direct HIR should lower to FusedMatMulBiasAct"
);
// Look up the node id of parameter "w" to differentiate with respect to it.
let w = mir
.as_graph()
.nodes()
.iter()
.find(|n| matches!(&n.op, Op::Param { name } if name == "w"))
.map(|n| n.id)
.expect("param w");
let bwd = grad_with_loss_module(GraphModule::from_mir(mir), &[w]).expect("grad");
assert!(
!bwd.nodes()
.iter()
.any(|n| matches!(n.op, Op::FusedMatMulBiasAct { .. })),
"backward graph should not retain fused ops"
);
// At least two outputs are expected for a single `wrt` node.
// NOTE(review): presumably the loss plus one gradient — confirm.
assert!(bwd.outputs.len() >= 2);
}
// The `PrepareForAutodiff` pass and the free function must yield graphs of
// equal `len()` for the same input (only the length is compared).
#[test]
fn prepare_for_autodiff_pass_matches_fn() {
let mut g = Graph::new("t");
let x = g.input("x", f32_shape(&[4]));
g.set_outputs(vec![x]);
let via_pass = PrepareForAutodiff.run(g.clone());
let via_fn = prepare_graph_for_ad(g);
assert_eq!(via_pass.len(), via_fn.len());
}
}