Skip to main content

rlx_ir/ops/
backward.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Backward / training op builders.
17//!
18//! These nodes are emitted by `rlx-opt::autodiff` when it walks a
19//! forward graph in reverse and needs a closed-form gradient kernel
20//! (rather than composing one from primitives). Output shapes follow
21//! directly from the forward shapes: `relu_backward` and
22//! `maxpool2d_backward` match the original input; conv backward shapes
23//! match the original input / weight; cross-entropy returns one loss
24//! per row of logits.
25//!
26//! Shape checks here are debug-only; the verifier in `verify.rs` does
27//! the rigorous version.
28
29use crate::op::{AttentionBwdWrt, MaskKind};
30use crate::{DType, Graph, NodeId, Op, Shape};
31
32impl Graph {
33    /// ReLU backward: `dx = dy where x > 0 else 0`. Output shape matches `x`.
34    pub fn relu_backward(&mut self, x: NodeId, dy: NodeId) -> NodeId {
35        let x_shape = self.shape(x).clone();
36        debug_assert_eq!(
37            self.shape(x),
38            self.shape(dy),
39            "relu_backward: x and dy must have identical shapes"
40        );
41        self.push(Op::ReluBackward, vec![x, dy], x_shape, None)
42    }
43
44    /// Element-wise activation backward — closed-form derivative of
45    /// any single-input activation other than ReLU. See
46    /// `Op::ActivationBackward` for the per-kind formulae.
47    pub fn activation_backward(
48        &mut self,
49        kind: crate::op::Activation,
50        x: NodeId,
51        dy: NodeId,
52    ) -> NodeId {
53        let x_shape = self.shape(x).clone();
54        debug_assert_eq!(
55            self.shape(x),
56            self.shape(dy),
57            "activation_backward: x and dy must have identical shapes"
58        );
59        self.push(Op::ActivationBackward { kind }, vec![x, dy], x_shape, None)
60    }
61
62    /// LayerNorm backward w.r.t. the input. Inputs `[x, gamma, dy]`.
63    /// Output shape matches `x`. Currently axis = -1 only.
64    pub fn layer_norm_backward_input(
65        &mut self,
66        x: NodeId,
67        gamma: NodeId,
68        dy: NodeId,
69        axis: i32,
70        eps: f32,
71    ) -> NodeId {
72        let x_shape = self.shape(x).clone();
73        debug_assert_eq!(
74            self.shape(x),
75            self.shape(dy),
76            "layer_norm_backward_input: x and dy must match"
77        );
78        self.push(
79            Op::LayerNormBackwardInput { axis, eps },
80            vec![x, gamma, dy],
81            x_shape,
82            None,
83        )
84    }
85
86    /// RMSNorm backward w.r.t. input. Inputs `[x, gamma, beta, dy]`.
87    pub fn rms_norm_backward_input(
88        &mut self,
89        x: NodeId,
90        gamma: NodeId,
91        beta: NodeId,
92        dy: NodeId,
93        axis: i32,
94        eps: f32,
95    ) -> NodeId {
96        let x_shape = self.shape(x).clone();
97        self.push(
98            Op::RmsNormBackwardInput { axis, eps },
99            vec![x, gamma, beta, dy],
100            x_shape,
101            None,
102        )
103    }
104
105    pub fn rms_norm_backward_gamma(
106        &mut self,
107        x: NodeId,
108        gamma: NodeId,
109        beta: NodeId,
110        dy: NodeId,
111        axis: i32,
112        eps: f32,
113    ) -> NodeId {
114        self.push(
115            Op::RmsNormBackwardGamma { axis, eps },
116            vec![x, gamma, beta, dy],
117            self.shape(gamma).clone(),
118            None,
119        )
120    }
121
122    pub fn rms_norm_backward_beta(
123        &mut self,
124        x: NodeId,
125        gamma: NodeId,
126        beta: NodeId,
127        dy: NodeId,
128        axis: i32,
129        eps: f32,
130    ) -> NodeId {
131        self.push(
132            Op::RmsNormBackwardBeta { axis, eps },
133            vec![x, gamma, beta, dy],
134            self.shape(beta).clone(),
135            None,
136        )
137    }
138
139    pub fn rope_backward(
140        &mut self,
141        dy: NodeId,
142        cos: NodeId,
143        sin: NodeId,
144        head_dim: usize,
145        n_rot: usize,
146    ) -> NodeId {
147        let out_shape = self.shape(dy).clone();
148        self.push(
149            Op::RopeBackward { head_dim, n_rot },
150            vec![dy, cos, sin],
151            out_shape,
152            None,
153        )
154    }
155
156    pub fn cumsum_backward(
157        &mut self,
158        dy: NodeId,
159        out_shape: Shape,
160        axis: i32,
161        exclusive: bool,
162    ) -> NodeId {
163        self.push(
164            Op::CumsumBackward { axis, exclusive },
165            vec![dy],
166            out_shape,
167            None,
168        )
169    }
170
171    pub fn gather_backward(
172        &mut self,
173        dy: NodeId,
174        indices: NodeId,
175        table_shape: Shape,
176        axis: i32,
177    ) -> NodeId {
178        self.push(
179            Op::GatherBackward { axis },
180            vec![dy, indices],
181            table_shape,
182            None,
183        )
184    }
185
186    /// GroupNorm (NCHW) backward w.r.t. input. Inputs `[x, gamma, beta, dy]`.
187    pub fn group_norm_backward_input(
188        &mut self,
189        x: NodeId,
190        gamma: NodeId,
191        beta: NodeId,
192        dy: NodeId,
193        num_groups: usize,
194        eps: f32,
195    ) -> NodeId {
196        let x_shape = self.shape(x).clone();
197        self.push(
198            Op::GroupNormBackwardInput { num_groups, eps },
199            vec![x, gamma, beta, dy],
200            x_shape,
201            None,
202        )
203    }
204
205    /// GroupNorm backward w.r.t. gamma. Inputs `[x, dy]`.
206    pub fn group_norm_backward_gamma(
207        &mut self,
208        x: NodeId,
209        dy: NodeId,
210        gamma_shape: Shape,
211        num_groups: usize,
212        eps: f32,
213    ) -> NodeId {
214        self.push(
215            Op::GroupNormBackwardGamma { num_groups, eps },
216            vec![x, dy],
217            gamma_shape,
218            None,
219        )
220    }
221
222    /// GroupNorm backward w.r.t. beta. Inputs `[x, dy]`.
223    pub fn group_norm_backward_beta(
224        &mut self,
225        x: NodeId,
226        dy: NodeId,
227        beta_shape: Shape,
228        num_groups: usize,
229        eps: f32,
230    ) -> NodeId {
231        self.push(
232            Op::GroupNormBackwardBeta { num_groups, eps },
233            vec![x, dy],
234            beta_shape,
235            None,
236        )
237    }
238
239    /// BatchNorm inference backward w.r.t. input.
240    pub fn batch_norm_inference_backward_input(
241        &mut self,
242        x: NodeId,
243        gamma: NodeId,
244        mean: NodeId,
245        var: NodeId,
246        dy: NodeId,
247        eps: f32,
248    ) -> NodeId {
249        let x_shape = self.shape(x).clone();
250        debug_assert_eq!(self.shape(x), self.shape(dy));
251        self.push(
252            Op::BatchNormInferenceBackwardInput { eps },
253            vec![x, gamma, mean, var, dy],
254            x_shape,
255            None,
256        )
257    }
258
259    /// BatchNorm inference backward w.r.t. gamma.
260    pub fn batch_norm_inference_backward_gamma(
261        &mut self,
262        x: NodeId,
263        mean: NodeId,
264        var: NodeId,
265        dy: NodeId,
266        gamma_shape: Shape,
267        eps: f32,
268    ) -> NodeId {
269        self.push(
270            Op::BatchNormInferenceBackwardGamma { eps },
271            vec![x, mean, var, dy],
272            gamma_shape,
273            None,
274        )
275    }
276
277    /// BatchNorm inference backward w.r.t. beta.
278    pub fn batch_norm_inference_backward_beta(&mut self, dy: NodeId, beta_shape: Shape) -> NodeId {
279        self.push(
280            Op::BatchNormInferenceBackwardBeta,
281            vec![dy],
282            beta_shape,
283            None,
284        )
285    }
286
287    /// LayerNorm backward w.r.t. gamma. Inputs `[x, dy]`. Output shape
288    /// is provided by the caller — typically the gamma's shape, e.g.
289    /// `[D]` for a per-feature 1-D gamma.
290    pub fn layer_norm_backward_gamma(
291        &mut self,
292        x: NodeId,
293        dy: NodeId,
294        gamma_shape: Shape,
295        axis: i32,
296        eps: f32,
297    ) -> NodeId {
298        debug_assert_eq!(
299            self.shape(x),
300            self.shape(dy),
301            "layer_norm_backward_gamma: x and dy must match"
302        );
303        self.push(
304            Op::LayerNormBackwardGamma { axis, eps },
305            vec![x, dy],
306            gamma_shape,
307            None,
308        )
309    }
310
311    /// 2D max-pool backward. `x` is the original NCHW input; `dy` is
312    /// the upstream gradient with shape matching the pool's output.
313    /// Output shape matches `x`.
314    pub fn maxpool2d_backward(
315        &mut self,
316        x: NodeId,
317        dy: NodeId,
318        kernel_size: Vec<usize>,
319        stride: Vec<usize>,
320        padding: Vec<usize>,
321    ) -> NodeId {
322        let x_shape = self.shape(x).clone();
323        debug_assert_eq!(kernel_size.len(), 2, "maxpool2d_backward: 2-D only");
324        debug_assert_eq!(stride.len(), 2);
325        debug_assert_eq!(padding.len(), 2);
326        self.push(
327            Op::MaxPool2dBackward {
328                kernel_size,
329                stride,
330                padding,
331            },
332            vec![x, dy],
333            x_shape,
334            None,
335        )
336    }
337
338    /// Conv2D backward w.r.t. input. `dy` has the conv output shape;
339    /// `w` is the forward weight `[C_out, C_in/groups, kH, kW]`. The
340    /// output shape (the original input shape) is supplied by the
341    /// caller because it can't be unambiguously derived from `dy.shape`
342    /// alone in the presence of strides + padding.
343    pub fn conv2d_backward_input(
344        &mut self,
345        dy: NodeId,
346        w: NodeId,
347        x_shape: Shape,
348        kernel_size: Vec<usize>,
349        stride: Vec<usize>,
350        padding: Vec<usize>,
351        dilation: Vec<usize>,
352        groups: usize,
353    ) -> NodeId {
354        debug_assert_eq!(kernel_size.len(), 2);
355        debug_assert_eq!(stride.len(), 2);
356        debug_assert_eq!(padding.len(), 2);
357        debug_assert_eq!(dilation.len(), 2);
358        self.push(
359            Op::Conv2dBackwardInput {
360                kernel_size,
361                stride,
362                padding,
363                dilation,
364                groups,
365            },
366            vec![dy, w],
367            x_shape,
368            None,
369        )
370    }
371
372    /// Conv2D backward w.r.t. weight. Output shape matches the forward
373    /// weight `[C_out, C_in/groups, kH, kW]`.
374    pub fn conv2d_backward_weight(
375        &mut self,
376        x: NodeId,
377        dy: NodeId,
378        w_shape: Shape,
379        kernel_size: Vec<usize>,
380        stride: Vec<usize>,
381        padding: Vec<usize>,
382        dilation: Vec<usize>,
383        groups: usize,
384    ) -> NodeId {
385        debug_assert_eq!(kernel_size.len(), 2);
386        debug_assert_eq!(stride.len(), 2);
387        debug_assert_eq!(padding.len(), 2);
388        debug_assert_eq!(dilation.len(), 2);
389        self.push(
390            Op::Conv2dBackwardWeight {
391                kernel_size,
392                stride,
393                padding,
394                dilation,
395                groups,
396            },
397            vec![x, dy],
398            w_shape,
399            None,
400        )
401    }
402
403    /// Fused softmax + cross-entropy with f32-encoded integer labels.
404    /// `logits [N, C]`, `labels [N]` → `[N]` per-row loss.
405    pub fn softmax_cross_entropy_with_logits(&mut self, logits: NodeId, labels: NodeId) -> NodeId {
406        let logits_shape = self.shape(logits);
407        debug_assert_eq!(
408            logits_shape.rank(),
409            2,
410            "sce_with_logits: logits must be 2-D [N, C]"
411        );
412        let n = logits_shape.dim(0);
413        let dtype = logits_shape.dtype();
414        let out_shape = Shape::from_dims(&[n], dtype);
415        self.push(
416            Op::SoftmaxCrossEntropyWithLogits,
417            vec![logits, labels],
418            out_shape,
419            None,
420        )
421    }
422
423    /// Backward of `softmax_cross_entropy_with_logits`.
424    /// `[logits, labels, d_loss]` → `dlogits` shaped like `logits`.
425    pub fn softmax_cross_entropy_backward(
426        &mut self,
427        logits: NodeId,
428        labels: NodeId,
429        d_loss: NodeId,
430    ) -> NodeId {
431        let logits_shape = self.shape(logits).clone();
432        debug_assert_eq!(
433            logits_shape.rank(),
434            2,
435            "sce_backward: logits must be 2-D [N, C]"
436        );
437        self.push(
438            Op::SoftmaxCrossEntropyBackward,
439            vec![logits, labels, d_loss],
440            logits_shape,
441            None,
442        )
443    }
444
445    /// Element-wise complex squared-magnitude: `|z|² = re² + im²`.
446    /// Input must be `DType::C64`; output is same logical shape but
447    /// `DType::F32`. The canonical real-valued loss surface for
448    /// Wirtinger reverse-mode AD on complex graphs.
449    pub fn complex_norm_sq(&mut self, z: NodeId) -> NodeId {
450        let z_shape = self.shape(z).clone();
451        debug_assert_eq!(
452            z_shape.dtype(),
453            DType::C64,
454            "complex_norm_sq: input must be C64, got {:?}",
455            z_shape.dtype()
456        );
457        let out_shape = Shape::from_dims(z_shape.dims(), DType::F32);
458        self.push(Op::ComplexNormSq, vec![z], out_shape, None)
459    }
460
461    /// Scaled dot-product attention backward w.r.t. `q`, `k`, or `v`.
462    /// See [`Op::AttentionBackward`]. When `mask_kind` is [`MaskKind::Custom`]
463    /// or [`MaskKind::Bias`], pass the same mask tensor used in forward.
464    pub fn attention_backward(
465        &mut self,
466        wrt: AttentionBwdWrt,
467        q: NodeId,
468        k: NodeId,
469        v: NodeId,
470        dy: NodeId,
471        num_heads: usize,
472        head_dim: usize,
473        mask_kind: MaskKind,
474        mask: Option<NodeId>,
475    ) -> NodeId {
476        let out_shape = match wrt {
477            AttentionBwdWrt::Query => self.shape(q).clone(),
478            AttentionBwdWrt::Key => self.shape(k).clone(),
479            AttentionBwdWrt::Value => self.shape(v).clone(),
480        };
481        let mut inputs = vec![q, k, v, dy];
482        if matches!(mask_kind, MaskKind::Custom | MaskKind::Bias) {
483            inputs.push(mask.expect("attention_backward: mask required for Custom/Bias"));
484        }
485        self.push(
486            Op::AttentionBackward {
487                num_heads,
488                head_dim,
489                mask_kind,
490                wrt,
491            },
492            inputs,
493            out_shape,
494            None,
495        )
496    }
497
498    /// Emit `dQ`, `dK`, and `dV` for one [`Op::Attention`] forward node.
499    pub fn attention_backward_all(
500        &mut self,
501        q: NodeId,
502        k: NodeId,
503        v: NodeId,
504        dy: NodeId,
505        num_heads: usize,
506        head_dim: usize,
507        mask_kind: MaskKind,
508        mask: Option<NodeId>,
509    ) -> (NodeId, NodeId, NodeId) {
510        let dq = self.attention_backward(
511            AttentionBwdWrt::Query,
512            q,
513            k,
514            v,
515            dy,
516            num_heads,
517            head_dim,
518            mask_kind,
519            mask,
520        );
521        let dk = self.attention_backward(
522            AttentionBwdWrt::Key,
523            q,
524            k,
525            v,
526            dy,
527            num_heads,
528            head_dim,
529            mask_kind,
530            mask,
531        );
532        let dv = self.attention_backward(
533            AttentionBwdWrt::Value,
534            q,
535            k,
536            v,
537            dy,
538            num_heads,
539            head_dim,
540            mask_kind,
541            mask,
542        );
543        (dq, dk, dv)
544    }
545
546    /// Wirtinger backward for [`complex_norm_sq`]: given upstream `g`
547    /// (real, same shape as the forward output) and the original
548    /// complex input `z`, returns `dz = g · z` as C64.
549    pub fn complex_norm_sq_backward(&mut self, z: NodeId, g: NodeId) -> NodeId {
550        let z_shape = self.shape(z).clone();
551        debug_assert_eq!(z_shape.dtype(), DType::C64);
552        debug_assert_eq!(self.shape(g).dtype(), DType::F32);
553        debug_assert_eq!(
554            z_shape.dims(),
555            self.shape(g).dims(),
556            "complex_norm_sq_backward: z and g must share logical shape"
557        );
558        self.push(Op::ComplexNormSqBackward, vec![z, g], z_shape, None)
559    }
560
561    /// Element-wise complex conjugate: `z̄ = re - i·im`. Input must be
562    /// `DType::C64`; output is the same shape and dtype. Used by
563    /// Wirtinger VJP rules on C64 binary ops.
564    pub fn conjugate(&mut self, z: NodeId) -> NodeId {
565        let z_shape = self.shape(z).clone();
566        debug_assert_eq!(
567            z_shape.dtype(),
568            DType::C64,
569            "conjugate: input must be C64, got {:?}",
570            z_shape.dtype()
571        );
572        self.push(Op::Conjugate, vec![z], z_shape, None)
573    }
574}