Skip to main content

rlx_autodiff/
higher_order.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7
8//! Higher-order reverse-mode AD — stack `grad_with_loss` with backward
9//! decomposition for 2nd/3rd/4th derivatives.
10
11use rlx_ir::{Graph, Shape};
12
13use crate::autodiff::grad_with_loss;
14use crate::compose::{
15    cse, find_input_by_name, internalize_d_output, output_depends_on_differentiable,
16    peel_scalar_expands, zero_derivative_graph,
17};
18use crate::decompose_backward::{
19    contract_grad_with_direction, decompose_backward_for_ad, decompose_backward_ops,
20};
21
22/// Opt-in elementwise fusion after higher-order stacking.
23pub fn fuse_elementwise(g: Graph) -> Graph {
24    use rlx_fusion::Pass;
25    rlx_fusion::MarkElementwiseRegions.run(g)
26}
27
/// Options for [`nth_order_grad_with_options`].
///
/// The two constructors do not agree on the fusion setting:
///
/// * the derived [`Default`] yields `fuse_elementwise: false`;
/// * [`HigherOrderOptions::new`] enables fusion unless
///   `rlx_ir::env::flag("RLX_HIGHER_ORDER_NO_FUSE")` reports the flag as on.
///
/// Use `new()` to get the same fusion behaviour as [`nth_order_grad`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct HigherOrderOptions {
    /// Run elementwise fusion after each differentiation layer.
    ///
    /// `false` under `Default::default()`; see the type-level docs for how
    /// `new()` chooses this value.
    pub fuse_elementwise: bool,
}
34
35impl HigherOrderOptions {
36    /// Default options: elementwise fusion enabled unless `RLX_HIGHER_ORDER_NO_FUSE=1`.
37    pub fn new() -> Self {
38        Self {
39            fuse_elementwise: !rlx_ir::env::flag("RLX_HIGHER_ORDER_NO_FUSE"),
40        }
41    }
42}
43
44/// Like [`nth_order_grad`] with optional post-layer fusion.
45pub fn nth_order_grad_with_options(
46    forward: &Graph,
47    wrt_name: &str,
48    order: usize,
49    opts: HigherOrderOptions,
50) -> Graph {
51    nth_order_grad_inner(forward, wrt_name, order, opts.fuse_elementwise)
52}
53
54/// Scalar `wrt`, scalar output: differentiate `order` times.
55pub fn nth_order_grad(forward: &Graph, wrt_name: &str, order: usize) -> Graph {
56    nth_order_grad_inner(
57        forward,
58        wrt_name,
59        order,
60        !rlx_ir::env::flag("RLX_HIGHER_ORDER_NO_FUSE"),
61    )
62}
63
64fn nth_order_grad_inner(forward: &Graph, wrt_name: &str, order: usize, do_fuse: bool) -> Graph {
65    assert_eq!(
66        forward.outputs.len(),
67        1,
68        "nth_order_grad: forward must have exactly one output"
69    );
70    let wrt = find_input_by_name(forward, wrt_name)
71        .unwrap_or_else(|| panic!("nth_order_grad: no Input/Param named '{wrt_name}'"));
72    let dtype = forward.node(wrt).shape.dtype();
73    let loss = forward.outputs[0];
74    if order == 0 {
75        let mut g = forward.clone();
76        g.set_outputs(vec![loss]);
77        return g;
78    }
79    if !output_depends_on_differentiable(forward, loss, wrt) {
80        return zero_derivative_graph(&format!("{}_d{order}_zero", forward.name), wrt_name, dtype);
81    }
82
83    let mut g = forward.clone();
84    for layer in 0..order {
85        let wrt_id = find_input_by_name(&g, wrt_name).expect("wrt input preserved");
86        if layer > 0 && !output_depends_on_differentiable(&g, g.outputs[0], wrt_id) {
87            return zero_derivative_graph(
88                &format!("{}_d{order}_zero", forward.name),
89                wrt_name,
90                dtype,
91            );
92        }
93        let grad_g = grad_with_loss(&g, &[wrt_id]);
94        g = decompose_backward_for_ad(grad_g, 0);
95        g = cse(g);
96        g = peel_scalar_expands(g);
97        if do_fuse {
98            g = fuse_elementwise(g);
99        }
100        g.name = format!("{}_d{}", forward.name, layer + 1);
101    }
102    g
103}
104
105/// ND wrt via per-level direction contraction.
106///
107/// `directions.len()` is the derivative order. Each direction is exposed as
108/// `"dir_<level>"`. Order-2 with the same direction twice yields `<v, H v>`.
109pub fn directional_nth_grad(forward: &Graph, wrt_name: &str, directions: &[&str]) -> Graph {
110    assert_eq!(
111        forward.outputs.len(),
112        1,
113        "directional_nth_grad: forward must have exactly one output"
114    );
115    let order = directions.len();
116    let wrt = find_input_by_name(forward, wrt_name)
117        .unwrap_or_else(|| panic!("directional_nth_grad: no Input/Param named '{wrt_name}'"));
118    let dtype = forward.node(wrt).shape.dtype();
119    let wrt_shape = forward.node(wrt).shape.clone();
120    let loss = forward.outputs[0];
121    if order == 0 {
122        let mut g = forward.clone();
123        g.set_outputs(vec![loss]);
124        return g;
125    }
126    if !output_depends_on_differentiable(forward, loss, wrt) {
127        return zero_derivative_graph(
128            &format!("{}_dir_d{order}_zero", forward.name),
129            wrt_name,
130            dtype,
131        );
132    }
133
134    let mut g = forward.clone();
135    let fuse = !rlx_ir::env::flag("RLX_HIGHER_ORDER_NO_FUSE");
136    for (level, _dir_name) in directions.iter().enumerate() {
137        let wrt_id = find_input_by_name(&g, wrt_name).expect("wrt input preserved");
138        if level > 0 && !output_depends_on_differentiable(&g, g.outputs[0], wrt_id) {
139            return zero_derivative_graph(
140                &format!("{}_dir_d{order}_zero", forward.name),
141                wrt_name,
142                dtype,
143            );
144        }
145        let grad_g = grad_with_loss(&g, &[wrt_id]);
146        let grad_out = grad_g.outputs[1];
147
148        let mut contracted = decompose_backward_ops(grad_g);
149        internalize_d_output(&mut contracted);
150        let dir_input = contracted.input(
151            format!("dir_{level}"),
152            if wrt_shape.rank() == 0 {
153                Shape::scalar(dtype)
154            } else {
155                wrt_shape.clone()
156            },
157        );
158        let scalar = contract_grad_with_direction(&mut contracted, grad_out, dir_input);
159        contracted.set_outputs(vec![scalar]);
160        g = cse(contracted);
161        g = peel_scalar_expands(g);
162        if fuse {
163            g = fuse_elementwise(g);
164        }
165        g.name = format!("{}_dir_d{}", forward.name, level + 1);
166    }
167    g
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::BinaryOp;
    use rlx_ir::{DType, Graph, Op, Shape};

    /// Builds the single-output graph `x * x * x` over a scalar input `"x"`
    /// of dtype `dt`.
    fn scalar_cube(name: &str, dt: DType) -> Graph {
        let mut g = Graph::new(name);
        let x = g.input("x", Shape::scalar(dt));
        let squared = g.binary(BinaryOp::Mul, x, x, Shape::scalar(dt));
        let cubed = g.binary(BinaryOp::Mul, squared, x, Shape::scalar(dt));
        g.set_outputs(vec![cubed]);
        g
    }

    // Third derivative of x^3 has one output and no leftover `d_output` input.
    #[test]
    fn nth_order_x_cubed_graph_shape() {
        let forward = scalar_cube("x3", DType::F64);

        let third = nth_order_grad(&forward, "x", 3);
        assert_eq!(third.outputs.len(), 1);
        assert!(find_input_by_name(&third, "d_output").is_none());
    }

    // The output dtype of the derivative matches the low-precision input dtype.
    #[test]
    fn nth_order_f16_bf16_graph_builds() {
        for dt in [DType::F16, DType::BF16] {
            let forward = scalar_cube("x3_lp", dt);
            let third = nth_order_grad(&forward, "x", 3);
            assert_eq!(third.node(third.outputs[0]).shape.dtype(), dt);
        }
    }

    // Stacking through a ReLU activation builds for orders 2 and 3.
    #[test]
    fn relu_higher_order_builds() {
        use rlx_ir::op::Activation;

        let mut forward = Graph::new("relu");
        let x = forward.input("x", Shape::scalar(DType::F64));
        let y = forward.activation(Activation::Relu, x, Shape::scalar(DType::F64));
        forward.set_outputs(vec![y]);

        let second = nth_order_grad(&forward, "x", 2);
        let third = nth_order_grad(&forward, "x", 3);
        assert_eq!(second.outputs.len(), 1);
        assert_eq!(third.outputs.len(), 1);
        assert!(find_input_by_name(&third, "d_output").is_none());
    }

    // Three directions give a single-output graph that exposes `dir_0`.
    #[test]
    fn directional_scalar_x_cubed_third() {
        let forward = scalar_cube("x3_dir", DType::F64);

        let contracted = directional_nth_grad(&forward, "x", &["a", "b", "c"]);
        assert_eq!(contracted.outputs.len(), 1);
        assert!(find_input_by_name(&contracted, "dir_0").is_some());
    }

    // cast(x > 0) differentiated three times ends in a constant output node.
    #[test]
    fn unreachable_compare_path_short_circuits() {
        use rlx_ir::infer::GraphExt;
        use rlx_ir::op::CmpOp;

        let mut forward = Graph::new("cmp");
        let x = forward.input("x", Shape::scalar(DType::F64));
        let zero = forward.add_node(
            crate::compose::constant_zero(&Shape::scalar(DType::F64)),
            vec![],
            Shape::scalar(DType::F64),
        );
        let is_positive = forward.add_node(
            Op::Compare(CmpOp::Gt),
            vec![x, zero],
            Shape::scalar(DType::F32),
        );
        let out = forward.cast(is_positive, DType::F64);
        forward.set_outputs(vec![out]);

        let third = nth_order_grad(&forward, "x", 3);
        let result_op = &third.node(third.outputs[0]).op;
        assert!(matches!(result_op, Op::Constant { .. }));
    }
}
256}