Skip to main content

rlx_ir/ops/
backward.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Backward / training op builders.
//!
//! These nodes are emitted by `rlx-opt::autodiff` when it walks a
//! forward graph in reverse and needs a closed-form gradient kernel
//! (rather than composing one from primitives). Output shapes follow
//! directly from the forward shapes: `relu_backward` and
//! `maxpool2d_backward` match the original input; conv backward shapes
//! match the original input / weight; cross-entropy returns one loss
//! per row of logits.
//!
//! Shape checks here are debug-only; the verifier in `verify.rs` does
//! the rigorous version.
28
29use crate::op::{AttentionBwdWrt, MaskKind};
30use crate::{DType, Graph, NodeId, Op, Shape};
31
32impl Graph {
33    /// ReLU backward: `dx = dy where x > 0 else 0`. Output shape matches `x`.
34    pub fn relu_backward(&mut self, x: NodeId, dy: NodeId) -> NodeId {
35        let x_shape = self.shape(x).clone();
36        debug_assert_eq!(
37            self.shape(x),
38            self.shape(dy),
39            "relu_backward: x and dy must have identical shapes"
40        );
41        self.push(Op::ReluBackward, vec![x, dy], x_shape, None)
42    }
43
44    /// Element-wise activation backward — closed-form derivative of
45    /// any single-input activation other than ReLU. See
46    /// `Op::ActivationBackward` for the per-kind formulae.
47    pub fn activation_backward(
48        &mut self,
49        kind: crate::op::Activation,
50        x: NodeId,
51        dy: NodeId,
52    ) -> NodeId {
53        let x_shape = self.shape(x).clone();
54        debug_assert_eq!(
55            self.shape(x),
56            self.shape(dy),
57            "activation_backward: x and dy must have identical shapes"
58        );
59        self.push(Op::ActivationBackward { kind }, vec![x, dy], x_shape, None)
60    }
61
62    /// LayerNorm backward w.r.t. the input. Inputs `[x, gamma, dy]`.
63    /// Output shape matches `x`. Currently axis = -1 only.
64    pub fn layer_norm_backward_input(
65        &mut self,
66        x: NodeId,
67        gamma: NodeId,
68        dy: NodeId,
69        axis: i32,
70        eps: f32,
71    ) -> NodeId {
72        let x_shape = self.shape(x).clone();
73        debug_assert_eq!(
74            self.shape(x),
75            self.shape(dy),
76            "layer_norm_backward_input: x and dy must match"
77        );
78        self.push(
79            Op::LayerNormBackwardInput { axis, eps },
80            vec![x, gamma, dy],
81            x_shape,
82            None,
83        )
84    }
85
86    /// RMSNorm backward w.r.t. input. Inputs `[x, gamma, beta, dy]`.
87    pub fn rms_norm_backward_input(
88        &mut self,
89        x: NodeId,
90        gamma: NodeId,
91        beta: NodeId,
92        dy: NodeId,
93        axis: i32,
94        eps: f32,
95    ) -> NodeId {
96        let x_shape = self.shape(x).clone();
97        self.push(
98            Op::RmsNormBackwardInput { axis, eps },
99            vec![x, gamma, beta, dy],
100            x_shape,
101            None,
102        )
103    }
104
105    pub fn rms_norm_backward_gamma(
106        &mut self,
107        x: NodeId,
108        gamma: NodeId,
109        beta: NodeId,
110        dy: NodeId,
111        axis: i32,
112        eps: f32,
113    ) -> NodeId {
114        self.push(
115            Op::RmsNormBackwardGamma { axis, eps },
116            vec![x, gamma, beta, dy],
117            self.shape(gamma).clone(),
118            None,
119        )
120    }
121
122    pub fn rms_norm_backward_beta(
123        &mut self,
124        x: NodeId,
125        gamma: NodeId,
126        beta: NodeId,
127        dy: NodeId,
128        axis: i32,
129        eps: f32,
130    ) -> NodeId {
131        self.push(
132            Op::RmsNormBackwardBeta { axis, eps },
133            vec![x, gamma, beta, dy],
134            self.shape(beta).clone(),
135            None,
136        )
137    }
138
139    pub fn rope_backward(
140        &mut self,
141        dy: NodeId,
142        cos: NodeId,
143        sin: NodeId,
144        head_dim: usize,
145        n_rot: usize,
146    ) -> NodeId {
147        let out_shape = self.shape(dy).clone();
148        self.push(
149            Op::RopeBackward { head_dim, n_rot },
150            vec![dy, cos, sin],
151            out_shape,
152            None,
153        )
154    }
155
156    pub fn cumsum_backward(
157        &mut self,
158        dy: NodeId,
159        out_shape: Shape,
160        axis: i32,
161        exclusive: bool,
162    ) -> NodeId {
163        self.push(
164            Op::CumsumBackward { axis, exclusive },
165            vec![dy],
166            out_shape,
167            None,
168        )
169    }
170
171    pub fn gather_backward(
172        &mut self,
173        dy: NodeId,
174        indices: NodeId,
175        table_shape: Shape,
176        axis: i32,
177    ) -> NodeId {
178        self.push(
179            Op::GatherBackward { axis },
180            vec![dy, indices],
181            table_shape,
182            None,
183        )
184    }
185
186    /// GroupNorm (NCHW) backward w.r.t. input. Inputs `[x, gamma, beta, dy]`.
187    pub fn group_norm_backward_input(
188        &mut self,
189        x: NodeId,
190        gamma: NodeId,
191        beta: NodeId,
192        dy: NodeId,
193        num_groups: usize,
194        eps: f32,
195    ) -> NodeId {
196        let x_shape = self.shape(x).clone();
197        self.push(
198            Op::GroupNormBackwardInput { num_groups, eps },
199            vec![x, gamma, beta, dy],
200            x_shape,
201            None,
202        )
203    }
204
205    /// GroupNorm backward w.r.t. gamma. Inputs `[x, dy]`.
206    pub fn group_norm_backward_gamma(
207        &mut self,
208        x: NodeId,
209        dy: NodeId,
210        gamma_shape: Shape,
211        num_groups: usize,
212        eps: f32,
213    ) -> NodeId {
214        self.push(
215            Op::GroupNormBackwardGamma { num_groups, eps },
216            vec![x, dy],
217            gamma_shape,
218            None,
219        )
220    }
221
222    /// GroupNorm backward w.r.t. beta. Inputs `[x, dy]`.
223    pub fn group_norm_backward_beta(
224        &mut self,
225        x: NodeId,
226        dy: NodeId,
227        beta_shape: Shape,
228        num_groups: usize,
229        eps: f32,
230    ) -> NodeId {
231        self.push(
232            Op::GroupNormBackwardBeta { num_groups, eps },
233            vec![x, dy],
234            beta_shape,
235            None,
236        )
237    }
238
239    /// LayerNorm backward w.r.t. gamma. Inputs `[x, dy]`. Output shape
240    /// is provided by the caller — typically the gamma's shape, e.g.
241    /// `[D]` for a per-feature 1-D gamma.
242    pub fn layer_norm_backward_gamma(
243        &mut self,
244        x: NodeId,
245        dy: NodeId,
246        gamma_shape: Shape,
247        axis: i32,
248        eps: f32,
249    ) -> NodeId {
250        debug_assert_eq!(
251            self.shape(x),
252            self.shape(dy),
253            "layer_norm_backward_gamma: x and dy must match"
254        );
255        self.push(
256            Op::LayerNormBackwardGamma { axis, eps },
257            vec![x, dy],
258            gamma_shape,
259            None,
260        )
261    }
262
263    /// 2D max-pool backward. `x` is the original NCHW input; `dy` is
264    /// the upstream gradient with shape matching the pool's output.
265    /// Output shape matches `x`.
266    pub fn maxpool2d_backward(
267        &mut self,
268        x: NodeId,
269        dy: NodeId,
270        kernel_size: Vec<usize>,
271        stride: Vec<usize>,
272        padding: Vec<usize>,
273    ) -> NodeId {
274        let x_shape = self.shape(x).clone();
275        debug_assert_eq!(kernel_size.len(), 2, "maxpool2d_backward: 2-D only");
276        debug_assert_eq!(stride.len(), 2);
277        debug_assert_eq!(padding.len(), 2);
278        self.push(
279            Op::MaxPool2dBackward {
280                kernel_size,
281                stride,
282                padding,
283            },
284            vec![x, dy],
285            x_shape,
286            None,
287        )
288    }
289
290    /// Conv2D backward w.r.t. input. `dy` has the conv output shape;
291    /// `w` is the forward weight `[C_out, C_in/groups, kH, kW]`. The
292    /// output shape (the original input shape) is supplied by the
293    /// caller because it can't be unambiguously derived from `dy.shape`
294    /// alone in the presence of strides + padding.
295    pub fn conv2d_backward_input(
296        &mut self,
297        dy: NodeId,
298        w: NodeId,
299        x_shape: Shape,
300        kernel_size: Vec<usize>,
301        stride: Vec<usize>,
302        padding: Vec<usize>,
303        dilation: Vec<usize>,
304        groups: usize,
305    ) -> NodeId {
306        debug_assert_eq!(kernel_size.len(), 2);
307        debug_assert_eq!(stride.len(), 2);
308        debug_assert_eq!(padding.len(), 2);
309        debug_assert_eq!(dilation.len(), 2);
310        self.push(
311            Op::Conv2dBackwardInput {
312                kernel_size,
313                stride,
314                padding,
315                dilation,
316                groups,
317            },
318            vec![dy, w],
319            x_shape,
320            None,
321        )
322    }
323
324    /// Conv2D backward w.r.t. weight. Output shape matches the forward
325    /// weight `[C_out, C_in/groups, kH, kW]`.
326    pub fn conv2d_backward_weight(
327        &mut self,
328        x: NodeId,
329        dy: NodeId,
330        w_shape: Shape,
331        kernel_size: Vec<usize>,
332        stride: Vec<usize>,
333        padding: Vec<usize>,
334        dilation: Vec<usize>,
335        groups: usize,
336    ) -> NodeId {
337        debug_assert_eq!(kernel_size.len(), 2);
338        debug_assert_eq!(stride.len(), 2);
339        debug_assert_eq!(padding.len(), 2);
340        debug_assert_eq!(dilation.len(), 2);
341        self.push(
342            Op::Conv2dBackwardWeight {
343                kernel_size,
344                stride,
345                padding,
346                dilation,
347                groups,
348            },
349            vec![x, dy],
350            w_shape,
351            None,
352        )
353    }
354
355    /// Fused softmax + cross-entropy with f32-encoded integer labels.
356    /// `logits [N, C]`, `labels [N]` → `[N]` per-row loss.
357    pub fn softmax_cross_entropy_with_logits(&mut self, logits: NodeId, labels: NodeId) -> NodeId {
358        let logits_shape = self.shape(logits);
359        debug_assert_eq!(
360            logits_shape.rank(),
361            2,
362            "sce_with_logits: logits must be 2-D [N, C]"
363        );
364        let n = logits_shape.dim(0);
365        let dtype = logits_shape.dtype();
366        let out_shape = Shape::from_dims(&[n], dtype);
367        self.push(
368            Op::SoftmaxCrossEntropyWithLogits,
369            vec![logits, labels],
370            out_shape,
371            None,
372        )
373    }
374
375    /// Backward of `softmax_cross_entropy_with_logits`.
376    /// `[logits, labels, d_loss]` → `dlogits` shaped like `logits`.
377    pub fn softmax_cross_entropy_backward(
378        &mut self,
379        logits: NodeId,
380        labels: NodeId,
381        d_loss: NodeId,
382    ) -> NodeId {
383        let logits_shape = self.shape(logits).clone();
384        debug_assert_eq!(
385            logits_shape.rank(),
386            2,
387            "sce_backward: logits must be 2-D [N, C]"
388        );
389        self.push(
390            Op::SoftmaxCrossEntropyBackward,
391            vec![logits, labels, d_loss],
392            logits_shape,
393            None,
394        )
395    }
396
397    /// Element-wise complex squared-magnitude: `|z|² = re² + im²`.
398    /// Input must be `DType::C64`; output is same logical shape but
399    /// `DType::F32`. The canonical real-valued loss surface for
400    /// Wirtinger reverse-mode AD on complex graphs.
401    pub fn complex_norm_sq(&mut self, z: NodeId) -> NodeId {
402        let z_shape = self.shape(z).clone();
403        debug_assert_eq!(
404            z_shape.dtype(),
405            DType::C64,
406            "complex_norm_sq: input must be C64, got {:?}",
407            z_shape.dtype()
408        );
409        let out_shape = Shape::from_dims(z_shape.dims(), DType::F32);
410        self.push(Op::ComplexNormSq, vec![z], out_shape, None)
411    }
412
413    /// Scaled dot-product attention backward w.r.t. `q`, `k`, or `v`.
414    /// See [`Op::AttentionBackward`]. When `mask_kind` is [`MaskKind::Custom`]
415    /// or [`MaskKind::Bias`], pass the same mask tensor used in forward.
416    pub fn attention_backward(
417        &mut self,
418        wrt: AttentionBwdWrt,
419        q: NodeId,
420        k: NodeId,
421        v: NodeId,
422        dy: NodeId,
423        num_heads: usize,
424        head_dim: usize,
425        mask_kind: MaskKind,
426        mask: Option<NodeId>,
427    ) -> NodeId {
428        let out_shape = match wrt {
429            AttentionBwdWrt::Query => self.shape(q).clone(),
430            AttentionBwdWrt::Key => self.shape(k).clone(),
431            AttentionBwdWrt::Value => self.shape(v).clone(),
432        };
433        let mut inputs = vec![q, k, v, dy];
434        if matches!(mask_kind, MaskKind::Custom | MaskKind::Bias) {
435            inputs.push(mask.expect("attention_backward: mask required for Custom/Bias"));
436        }
437        self.push(
438            Op::AttentionBackward {
439                num_heads,
440                head_dim,
441                mask_kind,
442                wrt,
443            },
444            inputs,
445            out_shape,
446            None,
447        )
448    }
449
450    /// Emit `dQ`, `dK`, and `dV` for one [`Op::Attention`] forward node.
451    pub fn attention_backward_all(
452        &mut self,
453        q: NodeId,
454        k: NodeId,
455        v: NodeId,
456        dy: NodeId,
457        num_heads: usize,
458        head_dim: usize,
459        mask_kind: MaskKind,
460        mask: Option<NodeId>,
461    ) -> (NodeId, NodeId, NodeId) {
462        let dq = self.attention_backward(
463            AttentionBwdWrt::Query,
464            q,
465            k,
466            v,
467            dy,
468            num_heads,
469            head_dim,
470            mask_kind,
471            mask,
472        );
473        let dk = self.attention_backward(
474            AttentionBwdWrt::Key,
475            q,
476            k,
477            v,
478            dy,
479            num_heads,
480            head_dim,
481            mask_kind,
482            mask,
483        );
484        let dv = self.attention_backward(
485            AttentionBwdWrt::Value,
486            q,
487            k,
488            v,
489            dy,
490            num_heads,
491            head_dim,
492            mask_kind,
493            mask,
494        );
495        (dq, dk, dv)
496    }
497
498    /// Wirtinger backward for [`complex_norm_sq`]: given upstream `g`
499    /// (real, same shape as the forward output) and the original
500    /// complex input `z`, returns `dz = g · z` as C64.
501    pub fn complex_norm_sq_backward(&mut self, z: NodeId, g: NodeId) -> NodeId {
502        let z_shape = self.shape(z).clone();
503        debug_assert_eq!(z_shape.dtype(), DType::C64);
504        debug_assert_eq!(self.shape(g).dtype(), DType::F32);
505        debug_assert_eq!(
506            z_shape.dims(),
507            self.shape(g).dims(),
508            "complex_norm_sq_backward: z and g must share logical shape"
509        );
510        self.push(Op::ComplexNormSqBackward, vec![z, g], z_shape, None)
511    }
512
513    /// Element-wise complex conjugate: `z̄ = re - i·im`. Input must be
514    /// `DType::C64`; output is the same shape and dtype. Used by
515    /// Wirtinger VJP rules on C64 binary ops.
516    pub fn conjugate(&mut self, z: NodeId) -> NodeId {
517        let z_shape = self.shape(z).clone();
518        debug_assert_eq!(
519            z_shape.dtype(),
520            DType::C64,
521            "conjugate: input must be C64, got {:?}",
522            z_shape.dtype()
523        );
524        self.push(Op::Conjugate, vec![z], z_shape, None)
525    }
526}