1use rlx_ir::{Graph, Shape};
12
13use crate::autodiff::grad_with_loss;
14use crate::compose::{
15 cse, find_input_by_name, internalize_d_output, output_depends_on_differentiable,
16 peel_scalar_expands, zero_derivative_graph,
17};
18use crate::decompose_backward::{
19 contract_grad_with_direction, decompose_backward_for_ad, decompose_backward_ops,
20};
21
22pub fn fuse_elementwise(g: Graph) -> Graph {
24 use rlx_fusion::Pass;
25 rlx_fusion::MarkElementwiseRegions.run(g)
26}
27
/// Knobs controlling how higher-order derivative graphs are post-processed.
///
/// NOTE(review): the derived `Default` yields `fuse_elementwise: false`,
/// whereas `HigherOrderOptions::new()` derives the value from the
/// `RLX_HIGHER_ORDER_NO_FUSE` flag — the two constructors are not equivalent.
#[derive(Clone, Copy, Debug, Default)]
pub struct HigherOrderOptions {
    /// When `true`, `fuse_elementwise` is applied to the graph after each
    /// differentiation layer.
    pub fuse_elementwise: bool,
}
34
35impl HigherOrderOptions {
36 pub fn new() -> Self {
38 Self {
39 fuse_elementwise: !rlx_ir::env::flag("RLX_HIGHER_ORDER_NO_FUSE"),
40 }
41 }
42}
43
44pub fn nth_order_grad_with_options(
46 forward: &Graph,
47 wrt_name: &str,
48 order: usize,
49 opts: HigherOrderOptions,
50) -> Graph {
51 nth_order_grad_inner(forward, wrt_name, order, opts.fuse_elementwise)
52}
53
54pub fn nth_order_grad(forward: &Graph, wrt_name: &str, order: usize) -> Graph {
56 nth_order_grad_inner(
57 forward,
58 wrt_name,
59 order,
60 !rlx_ir::env::flag("RLX_HIGHER_ORDER_NO_FUSE"),
61 )
62}
63
64fn nth_order_grad_inner(forward: &Graph, wrt_name: &str, order: usize, do_fuse: bool) -> Graph {
65 assert_eq!(
66 forward.outputs.len(),
67 1,
68 "nth_order_grad: forward must have exactly one output"
69 );
70 let wrt = find_input_by_name(forward, wrt_name)
71 .unwrap_or_else(|| panic!("nth_order_grad: no Input/Param named '{wrt_name}'"));
72 let dtype = forward.node(wrt).shape.dtype();
73 let loss = forward.outputs[0];
74 if order == 0 {
75 let mut g = forward.clone();
76 g.set_outputs(vec![loss]);
77 return g;
78 }
79 if !output_depends_on_differentiable(forward, loss, wrt) {
80 return zero_derivative_graph(&format!("{}_d{order}_zero", forward.name), wrt_name, dtype);
81 }
82
83 let mut g = forward.clone();
84 for layer in 0..order {
85 let wrt_id = find_input_by_name(&g, wrt_name).expect("wrt input preserved");
86 if layer > 0 && !output_depends_on_differentiable(&g, g.outputs[0], wrt_id) {
87 return zero_derivative_graph(
88 &format!("{}_d{order}_zero", forward.name),
89 wrt_name,
90 dtype,
91 );
92 }
93 let grad_g = grad_with_loss(&g, &[wrt_id]);
94 g = decompose_backward_for_ad(grad_g, 0);
95 g = cse(g);
96 g = peel_scalar_expands(g);
97 if do_fuse {
98 g = fuse_elementwise(g);
99 }
100 g.name = format!("{}_d{}", forward.name, layer + 1);
101 }
102 g
103}
104
105pub fn directional_nth_grad(forward: &Graph, wrt_name: &str, directions: &[&str]) -> Graph {
110 assert_eq!(
111 forward.outputs.len(),
112 1,
113 "directional_nth_grad: forward must have exactly one output"
114 );
115 let order = directions.len();
116 let wrt = find_input_by_name(forward, wrt_name)
117 .unwrap_or_else(|| panic!("directional_nth_grad: no Input/Param named '{wrt_name}'"));
118 let dtype = forward.node(wrt).shape.dtype();
119 let wrt_shape = forward.node(wrt).shape.clone();
120 let loss = forward.outputs[0];
121 if order == 0 {
122 let mut g = forward.clone();
123 g.set_outputs(vec![loss]);
124 return g;
125 }
126 if !output_depends_on_differentiable(forward, loss, wrt) {
127 return zero_derivative_graph(
128 &format!("{}_dir_d{order}_zero", forward.name),
129 wrt_name,
130 dtype,
131 );
132 }
133
134 let mut g = forward.clone();
135 let fuse = !rlx_ir::env::flag("RLX_HIGHER_ORDER_NO_FUSE");
136 for (level, _dir_name) in directions.iter().enumerate() {
137 let wrt_id = find_input_by_name(&g, wrt_name).expect("wrt input preserved");
138 if level > 0 && !output_depends_on_differentiable(&g, g.outputs[0], wrt_id) {
139 return zero_derivative_graph(
140 &format!("{}_dir_d{order}_zero", forward.name),
141 wrt_name,
142 dtype,
143 );
144 }
145 let grad_g = grad_with_loss(&g, &[wrt_id]);
146 let grad_out = grad_g.outputs[1];
147
148 let mut contracted = decompose_backward_ops(grad_g);
149 internalize_d_output(&mut contracted);
150 let dir_input = contracted.input(
151 format!("dir_{level}"),
152 if wrt_shape.rank() == 0 {
153 Shape::scalar(dtype)
154 } else {
155 wrt_shape.clone()
156 },
157 );
158 let scalar = contract_grad_with_direction(&mut contracted, grad_out, dir_input);
159 contracted.set_outputs(vec![scalar]);
160 g = cse(contracted);
161 g = peel_scalar_expands(g);
162 if fuse {
163 g = fuse_elementwise(g);
164 }
165 g.name = format!("{}_dir_d{}", forward.name, level + 1);
166 }
167 g
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::BinaryOp;
    use rlx_ir::{DType, Graph, Op, Shape};

    /// Builds a single-output scalar graph computing `x * x * x` from one
    /// input named "x" of element type `dt`.
    fn scalar_cube(name: &'static str, dt: DType) -> Graph {
        let mut g = Graph::new(name);
        let x = g.input("x", Shape::scalar(dt));
        let squared = g.binary(BinaryOp::Mul, x, x, Shape::scalar(dt));
        let cubed = g.binary(BinaryOp::Mul, squared, x, Shape::scalar(dt));
        g.set_outputs(vec![cubed]);
        g
    }

    /// The third-order graph of x^3 has one output and no "d_output" input.
    #[test]
    fn nth_order_x_cubed_graph_shape() {
        let forward = scalar_cube("x3", DType::F64);

        let third = nth_order_grad(&forward, "x", 3);
        assert_eq!(third.outputs.len(), 1);
        assert!(find_input_by_name(&third, "d_output").is_none());
    }

    /// Half-precision inputs keep their dtype on the derivative output.
    #[test]
    fn nth_order_f16_bf16_graph_builds() {
        for dt in [DType::F16, DType::BF16] {
            let forward = scalar_cube("x3_lp", dt);
            let third = nth_order_grad(&forward, "x", 3);
            assert_eq!(third.node(third.outputs[0]).shape.dtype(), dt);
        }
    }

    /// Second and third order graphs through a ReLU can be constructed.
    #[test]
    fn relu_higher_order_builds() {
        use rlx_ir::op::Activation;

        let mut forward = Graph::new("relu");
        let x = forward.input("x", Shape::scalar(DType::F64));
        let y = forward.activation(Activation::Relu, x, Shape::scalar(DType::F64));
        forward.set_outputs(vec![y]);

        let second = nth_order_grad(&forward, "x", 2);
        let third = nth_order_grad(&forward, "x", 3);
        assert_eq!(second.outputs.len(), 1);
        assert_eq!(third.outputs.len(), 1);
        assert!(find_input_by_name(&third, "d_output").is_none());
    }

    /// Three directional levels yield one output and a `dir_0` input.
    #[test]
    fn directional_scalar_x_cubed_third() {
        let forward = scalar_cube("x3_dir", DType::F64);

        let contracted = directional_nth_grad(&forward, "x", &["a", "b", "c"]);
        assert_eq!(contracted.outputs.len(), 1);
        assert!(find_input_by_name(&contracted, "dir_0").is_some());
    }

    /// An output reached only through a comparison collapses to a constant.
    #[test]
    fn unreachable_compare_path_short_circuits() {
        use rlx_ir::infer::GraphExt;
        use rlx_ir::op::CmpOp;

        let mut forward = Graph::new("cmp");
        let x = forward.input("x", Shape::scalar(DType::F64));
        let zero = forward.add_node(
            crate::compose::constant_zero(&Shape::scalar(DType::F64)),
            vec![],
            Shape::scalar(DType::F64),
        );
        let is_positive = forward.add_node(
            Op::Compare(CmpOp::Gt),
            vec![x, zero],
            Shape::scalar(DType::F32),
        );
        let as_f64 = forward.cast(is_positive, DType::F64);
        forward.set_outputs(vec![as_f64]);

        let third = nth_order_grad(&forward, "x", 3);
        assert!(matches!(&third.node(third.outputs[0]).op, Op::Constant { .. }));
    }
}
256}