Skip to main content

rumus/autograd/
backward_ops.rs

1// SPDX-License-Identifier: Apache-2.0 OR MIT
2//! Backward operation structs and the version-checking snapshot.
3//!
4//! Each struct captures the minimal data needed to compute gradients for
5//! its corresponding forward op.  No opaque closures — every backward op
6//! is a concrete, inspectable type that is `Send + Sync` by construction.
7
8use std::sync::Arc;
9
10use crate::autograd::AutogradError;
11use crate::tensor::{GradId, Layout, StorageHandle, Tensor, WeakStorageHandle};
12
13// ---------------------------------------------------------------------------
14// VersionSnapshot — weak-reference version checker
15// ---------------------------------------------------------------------------
16
/// Snapshot of a [`StorageHandle`]'s version counter at tape-record time.
///
/// Holds a [`WeakStorageHandle`] so recording does **not** keep intermediate
/// tensor memory alive.
///
/// Outcomes of [`VersionSnapshot::check`]:
///
/// - **Upgrade succeeds:** compare live version vs recorded.  Mismatch →
///   [`AutogradError::VersionMismatch`].
/// - **Upgrade fails:** dead tensor → provably unmutated → `Ok(())`.
#[derive(Debug, Clone)]
pub struct VersionSnapshot {
    /// Identifier reported in [`AutogradError::VersionMismatch`] when the
    /// check fails.
    pub grad_id: GradId,
    /// Weak handle to the storage whose version counter is being tracked.
    pub weak_storage: WeakStorageHandle,
    /// Value of `StorageHandle::version()` captured when the snapshot was
    /// created.
    pub recorded_version: usize,
}
31
32impl VersionSnapshot {
33    pub fn new(grad_id: GradId, storage: &StorageHandle) -> Self {
34        Self {
35            grad_id,
36            recorded_version: storage.version(),
37            weak_storage: storage.downgrade(),
38        }
39    }
40
41    pub fn check(&self) -> Result<(), AutogradError> {
42        match self.weak_storage.upgrade() {
43            Some(strong) => {
44                let current = strong.version();
45                if current != self.recorded_version {
46                    Err(AutogradError::VersionMismatch {
47                        grad_id: self.grad_id,
48                        expected: self.recorded_version,
49                        found: current,
50                    })
51                } else {
52                    Ok(())
53                }
54            }
55            None => Ok(()),
56        }
57    }
58}
59
60// ---------------------------------------------------------------------------
61// Per-op backward structs
62// ---------------------------------------------------------------------------
63
/// Backward for `c = a + b`.
///
/// `∂L/∂a = ∂L/∂c`,  `∂L/∂b = ∂L/∂c`  (identity).
///
/// The gradient does not depend on the operand values, so only version
/// snapshots are recorded — no operand storage is saved.
#[derive(Debug)]
pub struct AddBackward {
    /// Version snapshot of the left operand `a`.
    pub lhs_version: VersionSnapshot,
    /// Version snapshot of the right operand `b`.
    pub rhs_version: VersionSnapshot,
}
72
/// Backward for `c = a - b`.
///
/// `∂L/∂a = ∂L/∂c`,  `∂L/∂b = -∂L/∂c`.
///
/// The gradient does not depend on the operand values, so only version
/// snapshots are recorded — no operand storage is saved.
#[derive(Debug)]
pub struct SubBackward {
    /// Version snapshot of the left operand `a`.
    pub lhs_version: VersionSnapshot,
    /// Version snapshot of the right operand `b`.
    pub rhs_version: VersionSnapshot,
}
81
/// Backward for `c = a * b` (element-wise).
///
/// `∂L/∂a = ∂L/∂c ⊙ b`,  `∂L/∂b = ∂L/∂c ⊙ a`.
///
/// Each operand's gradient needs the *other* operand's values, so both
/// operands are saved alongside their version snapshots.
#[derive(Debug)]
pub struct MulBackward {
    /// Saved storage of `a` (needed for `∂L/∂b`).
    pub lhs_storage: StorageHandle,
    /// Layout describing how to read `lhs_storage`.
    pub lhs_layout: Layout,
    /// Version snapshot of `a`.
    pub lhs_version: VersionSnapshot,
    /// Saved storage of `b` (needed for `∂L/∂a`).
    pub rhs_storage: StorageHandle,
    /// Layout describing how to read `rhs_storage`.
    pub rhs_layout: Layout,
    /// Version snapshot of `b`.
    pub rhs_version: VersionSnapshot,
}
94
/// Backward for `C = A @ B`.
///
/// `∂L/∂A = ∂L/∂C @ Bᵀ`,  `∂L/∂B = Aᵀ @ ∂L/∂C`.
///
/// Both operands are saved because each gradient multiplies the incoming
/// gradient by the other operand.
#[derive(Debug)]
pub struct MatmulBackward {
    /// Saved storage of `A` (needed for `∂L/∂B`).
    pub lhs_storage: StorageHandle,
    /// Layout describing how to read `lhs_storage`.
    pub lhs_layout: Layout,
    /// Version snapshot of `A`.
    pub lhs_version: VersionSnapshot,
    /// Saved storage of `B` (needed for `∂L/∂A`).
    pub rhs_storage: StorageHandle,
    /// Layout describing how to read `rhs_storage`.
    pub rhs_layout: Layout,
    /// Version snapshot of `B`.
    pub rhs_version: VersionSnapshot,
    // NOTE(review): presumably `A` is `[m, k]`, `B` is `[k, n]` and `C` is
    // `[m, n]` — confirm against the forward op.
    pub m: usize,
    pub k: usize,
    pub n: usize,
}
110
/// Backward for `y = relu(x)`.
///
/// `∂L/∂x[i] = ∂L/∂y[i]  if x[i] > 0,  else 0`.
///
/// Saves the **input** `x`, whose sign decides which gradient elements pass.
#[derive(Debug)]
pub struct ReluBackward {
    /// Saved storage of the forward input `x`.
    pub input_storage: StorageHandle,
    /// Layout describing how to read `input_storage`.
    pub input_layout: Layout,
    /// Version snapshot of `x`.
    pub input_version: VersionSnapshot,
}
120
/// Backward for `loss = mse_loss(pred, target)` (fused).
///
/// `∂L/∂pred[i] = out_grad_scalar * 2 * (pred[i] - target[i]) / N`.
///
/// Only `pred` receives a gradient; `target` is treated as a constant.
/// Both tensors are still saved because the formula reads both.
#[derive(Debug)]
pub struct MseLossBackward {
    /// Saved storage of `pred`.
    pub pred_storage: StorageHandle,
    /// Layout describing how to read `pred_storage`.
    pub pred_layout: Layout,
    /// Version snapshot of `pred`.
    pub pred_version: VersionSnapshot,
    /// Saved storage of `target`.
    pub target_storage: StorageHandle,
    /// Layout describing how to read `target_storage`.
    pub target_layout: Layout,
    /// Version snapshot of `target`.
    pub target_version: VersionSnapshot,
    /// Element count — presumably the `N` in the formula above (TODO confirm).
    pub numel: usize,
}
136
/// Backward for `y = add_bias(matrix, bias)`.
///
/// `∂L/∂matrix = ∂L/∂y`  (identity, same shape `[m,n]`).
/// `∂L/∂bias = sum_rows(∂L/∂y)`  (reduce `[m,n]` → `[n]`).
///
/// Neither gradient reads the operand values, so only version snapshots and
/// the matrix dimensions are recorded.
#[derive(Debug)]
pub struct AddBiasBackward {
    /// Version snapshot of `matrix`.
    pub input_version: VersionSnapshot,
    /// Version snapshot of `bias`.
    pub bias_version: VersionSnapshot,
    /// Row count of the `[m,n]` matrix.
    pub m: usize,
    /// Column count of the `[m,n]` matrix (and length of `bias`).
    pub n: usize,
}
148
/// Backward for `slice_batch(input, index)`.
///
/// `∂L/∂input` is a zero tensor matching the original batched input shape,
/// with `∂L/∂output` placed at the `index`-th batch slot.
#[derive(Debug)]
pub struct SliceBatchBackward {
    /// Version snapshot of the batched `input`.
    pub input_version: VersionSnapshot,
    /// Shape of the original batched input (e.g. `[batch, C, H, W]`).
    pub original_shape: Vec<usize>,
    /// Which batch element was sliced.
    pub index: usize,
}
161
/// Backward for `im2col(input)`.
///
/// `∂L/∂input = col2im(∂L/∂output)`.
///
/// No tensor data is saved; only the geometry of the forward call is kept.
// NOTE(review): the geometry fields are presumably the arguments `col2im`
// needs to invert the unfolding — confirm against the backward kernel.
#[derive(Debug)]
pub struct Im2ColBackward {
    /// Version snapshot of `input`.
    pub input_version: VersionSnapshot,
    /// Number of input channels.
    pub c_in: usize,
    /// Input height.
    pub h: usize,
    /// Input width.
    pub w: usize,
    /// Kernel size (a single value is stored, so presumably square kernels).
    pub kernel_size: usize,
    /// Stride of the sliding window.
    pub stride: usize,
    /// Padding applied to the input.
    pub padding: usize,
    /// Output height.
    pub out_h: usize,
    /// Output width.
    pub out_w: usize,
}
177
/// Backward for `stack([t0, t1, ...], axis=0)`.
///
/// `∂L/∂t_i = slice(∂L/∂output, i)` along axis 0.
///
/// No tensor data is saved; the gradient is a pure re-slicing of the
/// incoming gradient.
#[derive(Debug)]
pub struct StackBackward {
    /// Number of tensors that were stacked.
    pub count: usize,
    /// Shape of each individual tensor (all must match).
    pub each_shape: Vec<usize>,
    /// Version snapshots for each input.
    pub versions: Vec<VersionSnapshot>,
}
190
/// Backward for `add_channel_bias(src, bias)`.
///
/// `∂L/∂src = ∂L/∂out`  (identity, same shape `[batch*C, spatial]`)
/// `∂L/∂bias = sum over spatial of ∂L/∂out` per channel.
///
/// Neither gradient reads the operand values, so only version snapshots and
/// dimensions are recorded.
#[derive(Debug)]
pub struct AddChannelBiasBackward {
    /// Version snapshot of `src`.
    pub input_version: VersionSnapshot,
    /// Version snapshot of `bias`.
    pub bias_version: VersionSnapshot,
    /// Channel count `C`.
    pub channels: usize,
    /// Size of the flattened spatial dimension.
    pub spatial: usize,
}
202
/// Backward for `max_pool2d(input)`.
///
/// Scatters `∂L/∂output` to the argmax positions saved during forward.
#[derive(Debug)]
pub struct MaxPool2dBackward {
    /// Version snapshot of `input`.
    pub input_version: VersionSnapshot,
    /// Saved argmax indices (flat spatial offsets stored as f32).
    pub indices_storage: StorageHandle,
    /// Layout describing how to read `indices_storage`.
    pub indices_layout: Layout,
    /// Channel count.
    pub channels: usize,
    /// Input height.
    pub h: usize,
    /// Input width.
    pub w: usize,
    /// Output (pooled) height.
    pub out_h: usize,
    /// Output (pooled) width.
    pub out_w: usize,
}
218
/// Backward for `reshape_tracked(input, new_shape)`.
///
/// `∂L/∂input = reshape(∂L/∂output, original_shape)` — zero-copy.
#[derive(Debug)]
pub struct ReshapeBackward {
    /// Version snapshot of `input`.
    pub input_version: VersionSnapshot,
    /// Shape of `input` before the reshape; the gradient is reshaped back
    /// to it.
    pub original_shape: Vec<usize>,
}
227
/// Backward for `flatten(input)`.
///
/// `∂L/∂input = reshape(∂L/∂output, original_shape)` — zero-copy.
#[derive(Debug)]
pub struct FlattenBackward {
    /// Version snapshot of `input`.
    pub input_version: VersionSnapshot,
    /// Shape of `input` before flattening; the gradient is reshaped back
    /// to it.
    pub original_shape: Vec<usize>,
}
236
/// Backward for `cross_entropy_loss(logits, targets)`.
///
/// The gradient was pre-computed during the forward pass (softmax - one_hot,
/// scaled by 1/B).  Backward simply scales by the incoming `out_grad` scalar.
#[derive(Debug)]
pub struct CrossEntropyBackward {
    /// Version snapshot of the input (presumably `logits` — TODO confirm).
    pub input_version: VersionSnapshot,
    /// Pre-computed gradient [B, C], saved during forward.
    pub grad_storage: StorageHandle,
    /// Layout describing how to read `grad_storage`.
    pub grad_layout: Layout,
}
248
/// Backward for `dropout(input, p)`.
///
/// `∂L/∂input = ∂L/∂output * saved_mask`.
/// Reuses the existing `mul` dispatch (auto CPU/GPU).
#[derive(Debug)]
pub struct DropoutBackward {
    /// Version snapshot of `input`.
    pub input_version: VersionSnapshot,
    /// Mask saved during forward (`saved_mask` in the formula above).
    pub mask_storage: StorageHandle,
    /// Layout describing how to read `mask_storage`.
    pub mask_layout: Layout,
}
259
/// Backward for tracked `transpose(dim0, dim1)`.
///
/// `grad_input = transpose(grad_output, dim0, dim1)` — reverse the swap.
#[derive(Debug)]
pub struct TransposeBackward {
    /// Version snapshot of the transposed input.
    pub input_version: VersionSnapshot,
    /// First of the two swapped dimensions.
    pub dim0: usize,
    /// Second of the two swapped dimensions.
    pub dim1: usize,
}
268
/// Backward for `bmm(A, B)`.
///
/// `grad_A = bmm(grad_C, B^T)`, `grad_B = bmm(A^T, grad_C)`.
///
/// Both operands are saved because each gradient multiplies the incoming
/// gradient by the other operand.
#[derive(Debug)]
pub struct BmmBackward {
    /// Saved storage of `A` (needed for `grad_B`).
    pub lhs_storage: StorageHandle,
    /// Layout describing how to read `lhs_storage`.
    pub lhs_layout: Layout,
    /// Version snapshot of `A`.
    pub lhs_version: VersionSnapshot,
    /// Saved storage of `B` (needed for `grad_A`).
    pub rhs_storage: StorageHandle,
    /// Layout describing how to read `rhs_storage`.
    pub rhs_layout: Layout,
    /// Version snapshot of `B`.
    pub rhs_version: VersionSnapshot,
    // NOTE(review): presumably `A` is `[batch, m, k]` and `B` is
    // `[batch, k, n]` — confirm against the forward op.
    pub batch: usize,
    pub m: usize,
    pub k: usize,
    pub n: usize,
}
284
/// Backward for `softmax(input)`.  Saves **output**.
///
/// `grad_input = saved * (grad_out - dot)` where `dot = Σ grad_out * saved`.
#[derive(Debug)]
pub struct SoftmaxBackward {
    /// Saved storage of the forward output (`saved` in the formula above).
    pub output_storage: StorageHandle,
    /// Layout describing how to read `output_storage`.
    pub output_layout: Layout,
    /// Version snapshot of `input`.
    pub input_version: VersionSnapshot,
    // NOTE(review): presumably the softmax (and the `dot` reduction) runs
    // independently over each of `num_rows` rows of length `row_size` —
    // confirm against the kernel.
    pub num_rows: usize,
    pub row_size: usize,
}
295
/// Backward for `layer_norm`.
///
/// Kernel 1: per-instance grad_input via c1/c2 reductions.
/// Kernel 2: grad_weight = reduce(grad_out * x_hat), grad_bias = reduce(grad_out).
///
/// Saves the input, the weight, and the per-instance statistics computed
/// during forward.
#[derive(Debug)]
pub struct LayerNormBackward {
    /// Saved storage of the forward input.
    pub input_storage: StorageHandle,
    /// Layout describing how to read `input_storage`.
    pub input_layout: Layout,
    /// Version snapshot of the input.
    pub input_version: VersionSnapshot,
    /// Saved storage of the weight.
    pub weight_storage: StorageHandle,
    /// Layout describing how to read `weight_storage`.
    pub weight_layout: Layout,
    /// Version snapshot of the weight.
    pub weight_version: VersionSnapshot,
    /// Saved statistics, `[num_instances, 2]`: mean + invstd.
    pub save_storage: StorageHandle,
    /// Layout describing how to read `save_storage`.
    pub save_layout: Layout,
    /// Number of normalized instances (rows of `save_storage`).
    pub num_instances: usize,
    /// Presumably the number of elements normalized per instance (TODO confirm).
    pub norm_size: usize,
}
313
/// Backward for `embedding(indices)`.
///
/// Sparse scatter: grad_weight[token_id] += grad_output[lookup].
/// CPU-only backward (no f32 atomics in WGSL).
#[derive(Debug)]
pub struct EmbeddingBackward {
    /// Version snapshot of the tracked input.
    // NOTE(review): since the gradient targets `grad_weight`, this is
    // presumably the embedding weight table rather than `indices` — confirm.
    pub input_version: VersionSnapshot,
    /// Saved storage of the lookup indices (the `token_id`s).
    pub indices_storage: StorageHandle,
    /// Layout describing how to read `indices_storage`.
    pub indices_layout: Layout,
    /// Presumably the number of rows in the embedding table (TODO confirm).
    pub vocab_size: usize,
    /// Presumably the length of each embedding vector (TODO confirm).
    pub embed_dim: usize,
    /// Presumably the total number of index lookups performed in forward
    /// (TODO confirm).
    pub total_lookups: usize,
}
327
/// Backward for `sigmoid(input)`.  Saves **output**.
///
/// `grad = out_grad * saved_out * (1 - saved_out)`
#[derive(Debug)]
pub struct SigmoidBackward {
    /// Saved storage of the forward output (`saved_out` in the formula above).
    pub output_storage: StorageHandle,
    /// Layout describing how to read `output_storage`.
    pub output_layout: Layout,
    /// Version snapshot of `input`.
    pub input_version: VersionSnapshot,
}
336
/// Backward for `tanh(input)`.  Saves **output**.
///
/// `grad = out_grad * (1 - saved_out^2)`
#[derive(Debug)]
pub struct TanhBackward {
    /// Saved storage of the forward output (`saved_out` in the formula above).
    pub output_storage: StorageHandle,
    /// Layout describing how to read `output_storage`.
    pub output_layout: Layout,
    /// Version snapshot of `input`.
    pub input_version: VersionSnapshot,
}
345
/// Backward for `gelu(input)` (tanh approx).  Saves **input**.
#[derive(Debug)]
pub struct GeluBackward {
    /// Saved storage of the forward input.
    pub input_storage: StorageHandle,
    /// Layout describing how to read `input_storage`.
    pub input_layout: Layout,
    /// Version snapshot of `input`.
    pub input_version: VersionSnapshot,
}
353
/// Backward for `leaky_relu(input, alpha)`.  Saves **input**.
#[derive(Debug)]
pub struct LeakyReluBackward {
    /// Saved storage of the forward input.
    pub input_storage: StorageHandle,
    /// Layout describing how to read `input_storage`.
    pub input_layout: Layout,
    /// Version snapshot of `input`.
    pub input_version: VersionSnapshot,
    /// The `alpha` argument of the forward call (presumably the slope applied
    /// to non-positive inputs — TODO confirm).
    pub alpha: f32,
}
362
/// Backward for a broadcasted binary add.
///
/// If an operand was broadcast, its gradient must be summed (reduced)
/// along the broadcast dimensions.
#[derive(Debug)]
pub struct BroadcastAddBackward {
    /// Version snapshot of the left operand.
    pub lhs_version: VersionSnapshot,
    /// Version snapshot of the right operand.
    pub rhs_version: VersionSnapshot,
    /// Broadcast description for the left operand.
    // NOTE(review): `None` presumably means the operand was not broadcast
    // and its gradient needs no reduction — confirm.
    pub lhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    /// Broadcast description for the right operand (same convention).
    pub rhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    /// Shape of the broadcasted output (and hence of the incoming gradient).
    pub output_shape: Vec<usize>,
}
375
/// Backward for a broadcasted binary subtract.
///
/// Carries the same data as the broadcasted add: version snapshots, optional
/// per-operand broadcast descriptions, and the output shape.
// NOTE(review): presumably follows the `c = a - b` rule (`∂L/∂b = -∂L/∂c`)
// with a reduction over broadcast dimensions — confirm against the
// backward implementation.
#[derive(Debug)]
pub struct BroadcastSubBackward {
    /// Version snapshot of the left operand.
    pub lhs_version: VersionSnapshot,
    /// Version snapshot of the right operand.
    pub rhs_version: VersionSnapshot,
    /// Broadcast description for the left operand.
    // NOTE(review): `None` presumably means the operand was not broadcast.
    pub lhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    /// Broadcast description for the right operand (same convention).
    pub rhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    /// Shape of the broadcasted output (and hence of the incoming gradient).
    pub output_shape: Vec<usize>,
}
384
/// Backward for a broadcasted element-wise multiply.
///
/// Unlike the broadcasted add/sub records, both operands are saved here.
// NOTE(review): presumably each gradient is the incoming gradient times the
// other operand, then reduced over broadcast dimensions — confirm against
// the backward implementation.
#[derive(Debug)]
pub struct BroadcastMulBackward {
    /// Saved storage of the left operand.
    pub lhs_storage: StorageHandle,
    /// Layout describing how to read `lhs_storage`.
    pub lhs_layout: Layout,
    /// Version snapshot of the left operand.
    pub lhs_version: VersionSnapshot,
    /// Saved storage of the right operand.
    pub rhs_storage: StorageHandle,
    /// Layout describing how to read `rhs_storage`.
    pub rhs_layout: Layout,
    /// Version snapshot of the right operand.
    pub rhs_version: VersionSnapshot,
    /// Broadcast description for the left operand.
    // NOTE(review): `None` presumably means the operand was not broadcast.
    pub lhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    /// Broadcast description for the right operand (same convention).
    pub rhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    /// Shape of the broadcasted output (and hence of the incoming gradient).
    pub output_shape: Vec<usize>,
}
397
/// Backward for `batch_norm_2d(input, weight, bias)`.
///
/// Saves input, weight, and mean+invstd for backward.
/// Tape records 3 inputs: [input, weight, bias].
#[derive(Debug)]
pub struct BatchNorm2dBackward {
    /// Saved storage of the forward input.
    pub input_storage: StorageHandle,
    /// Layout describing how to read `input_storage`.
    pub input_layout: Layout,
    /// Version snapshot of the input.
    pub input_version: VersionSnapshot,
    /// Saved storage of the weight.
    pub weight_storage: StorageHandle,
    /// Layout describing how to read `weight_storage`.
    pub weight_layout: Layout,
    /// Version snapshot of the weight.
    pub weight_version: VersionSnapshot,
    /// Saved statistics, `[channels, 2]`: mean + invstd per channel.
    pub save_storage: StorageHandle,
    /// Layout describing how to read `save_storage`.
    pub save_layout: Layout,
    /// Batch size of the input.
    pub batch: usize,
    /// Channel count of the input.
    pub channels: usize,
    /// Spatial height of the input.
    pub height: usize,
    /// Spatial width of the input.
    pub width: usize,
}
417
/// Backward for `adaptive_avg_pool2d(input)`.
///
/// Each input pixel distributes its gradient to the output bins that cover it,
/// weighted by `1/count`.
///
/// No tensor data is saved; only the input and output geometry is kept.
#[derive(Debug)]
pub struct AdaptiveAvgPool2dBackward {
    /// Version snapshot of `input`.
    pub input_version: VersionSnapshot,
    /// Batch size.
    pub batch: usize,
    /// Channel count.
    pub channels: usize,
    /// Input height.
    pub h_in: usize,
    /// Input width.
    pub w_in: usize,
    /// Output height.
    pub h_out: usize,
    /// Output width.
    pub w_out: usize,
}
432
/// Backward for `to_dtype(target_dtype)`.
///
/// The gradient of a cast is simply a cast in the reverse direction.
/// No data needs to be saved — only the source dtype for the reverse cast.
#[derive(Debug)]
pub struct CastBackward {
    /// Version snapshot of the tensor that was cast.
    pub input_version: VersionSnapshot,
    /// Dtype of the tensor before the cast; the gradient is cast back to it.
    pub source_dtype: crate::tensor::DType,
}
442
443// ---------------------------------------------------------------------------
444// BackwardOp enum
445// ---------------------------------------------------------------------------
446
/// Discriminated union of all backward operation types.
///
/// No closures: every built-in variant wraps a plain, inspectable data
/// struct.  The one exception to "no trait objects" is [`BackwardOp::Custom`],
/// whose handler is an `Arc<dyn CustomBackward>`; that trait requires
/// `Send + Sync + Debug`, so the enum as a whole remains `Send + Sync` and
/// debuggable (this is asserted at compile time at the bottom of this file).
#[derive(Debug)]
pub enum BackwardOp {
    Add(AddBackward),
    Sub(SubBackward),
    Mul(MulBackward),
    Matmul(MatmulBackward),
    Relu(ReluBackward),
    MseLoss(MseLossBackward),
    AddBias(AddBiasBackward),
    Im2Col(Im2ColBackward),
    Stack(StackBackward),
    AddChannelBias(AddChannelBiasBackward),
    SliceBatch(SliceBatchBackward),
    MaxPool2d(MaxPool2dBackward),
    Flatten(FlattenBackward),
    Reshape(ReshapeBackward),
    Dropout(DropoutBackward),
    CrossEntropy(CrossEntropyBackward),
    Sigmoid(SigmoidBackward),
    Tanh(TanhBackward),
    Gelu(GeluBackward),
    LeakyRelu(LeakyReluBackward),
    Transpose(TransposeBackward),
    Bmm(BmmBackward),
    Softmax(SoftmaxBackward),
    LayerNorm(LayerNormBackward),
    Embedding(EmbeddingBackward),
    BroadcastAdd(BroadcastAddBackward),
    BroadcastSub(BroadcastSubBackward),
    BroadcastMul(BroadcastMulBackward),
    BatchNorm2d(BatchNorm2dBackward),
    AdaptiveAvgPool2d(AdaptiveAvgPool2dBackward),
    Cast(CastBackward),
    /// Backward for `slice_range(dim, start, end)`.
    SliceRange(SliceRangeBackward),
    /// Backward for `cat(tensors, dim)`.
    Cat(CatBackward),
    /// Backward for FSDP sharded linear: re-gathers weights during backward.
    /// Only compiled with the `multi_gpu` feature.
    #[cfg(feature = "multi_gpu")]
    FsdpLinear(FsdpLinearBackward),
    /// User-defined custom backward op via `Arc<dyn CustomBackward>`.
    Custom(CustomBackwardOp),
}
493
494// ---------------------------------------------------------------------------
495// Custom backward op (plugin system)
496// ---------------------------------------------------------------------------
497
/// Trait for user-defined backward computations.
///
/// Implement this to define custom gradient math for operations injected
/// via `ext::custom_forward`.
///
/// The `Send + Sync + Debug` supertraits are what allow a handler to be
/// stored in [`BackwardOp`], which derives `Debug` and is asserted to be
/// `Send + Sync`.
pub trait CustomBackward: Send + Sync + std::fmt::Debug {
    /// Compute input gradients given the output gradient and saved tensors.
    ///
    /// * `out_grad` — gradient flowing in from the op's output.
    /// * `saved` — tensors saved during forward (presumably rebuilt from the
    ///   storages held by [`CustomBackwardOp`] — confirm in the engine).
    ///
    /// Returns one gradient per input (in forward input order).
    fn backward(&self, out_grad: &Tensor, saved: &[Tensor]) -> Vec<Tensor>;
}
508
/// Backward state for a custom op: the user's handler + saved tensors.
///
/// `Debug` is implemented by hand just below this struct and prints only the
/// handler and the number of saved storages.
pub struct CustomBackwardOp {
    /// The user's backward implementation.
    pub handler: Arc<dyn CustomBackward>,
    /// Version snapshots for each input (for mutation checking).
    pub input_versions: Vec<VersionSnapshot>,
    /// Tensors saved during forward for use in backward.
    // NOTE(review): `saved_storages`, `saved_layouts` and `saved_shapes` are
    // presumably parallel vectors (entry `i` of each describes saved tensor
    // `i`) — confirm against the code that builds this struct.
    pub saved_storages: Vec<StorageHandle>,
    pub saved_layouts: Vec<Layout>,
    pub saved_shapes: Vec<Vec<usize>>,
}
520
521impl std::fmt::Debug for CustomBackwardOp {
522    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
523        f.debug_struct("CustomBackwardOp")
524            .field("handler", &self.handler)
525            .field("num_saved", &self.saved_storages.len())
526            .finish()
527    }
528}
529
/// Cross-rank synchronization barrier for FSDP gradient reduce-scatter.
///
/// Shared across all ranks for a single layer.  Each rank pushes its
/// local gradients into the shared state (`weight_grads` / `bias_grads`),
/// then waits on the `Condvar` until all ranks have arrived.  The last
/// arrival reduces the gradients, stores the result in the shared state, and
/// wakes all waiters.
// NOTE(review): the protocol above is the author's description; the code
// that drives it lives outside this file — confirm there.
#[cfg(feature = "multi_gpu")]
pub struct FsdpSync {
    /// Number of ranks taking part in the barrier.
    pub world_size: usize,
    /// Shared gradient buffers and results, guarded by a mutex.
    pub state: std::sync::Mutex<FsdpSyncState>,
    /// Condition variable that ranks wait on while the barrier fills.
    pub cvar: std::sync::Condvar,
}
545
/// Mutable state protected by the mutex in `FsdpSync`.
#[cfg(feature = "multi_gpu")]
pub struct FsdpSyncState {
    /// Weight gradients deposited so far (presumably one entry per rank that
    /// has arrived — TODO confirm).
    pub weight_grads: Vec<Vec<f32>>,
    /// Bias gradients deposited so far (same convention as `weight_grads`).
    pub bias_grads: Vec<Vec<f32>>,
    /// The reduced (summed + averaged) result.  Set by the last arrival.
    pub weight_result: Option<Vec<f32>>,
    /// Reduced bias gradient; presumably set together with `weight_result`.
    pub bias_result: Option<Vec<f32>>,
    /// Counts how many ranks have read the result and exited.
    pub read_count: usize,
}
556
#[cfg(feature = "multi_gpu")]
impl FsdpSync {
    /// Creates a barrier for `world_size` ranks.
    ///
    /// The shared state starts empty: no gradients deposited, no reduced
    /// result available, and a read counter of zero.
    pub fn new(world_size: usize) -> Self {
        let initial_state = FsdpSyncState {
            weight_grads: vec![],
            bias_grads: vec![],
            weight_result: None,
            bias_result: None,
            read_count: 0,
        };
        let state = std::sync::Mutex::new(initial_state);
        let cvar = std::sync::Condvar::new();
        Self {
            world_size,
            state,
            cvar,
        }
    }
}
573
#[cfg(feature = "multi_gpu")]
impl std::fmt::Debug for FsdpSync {
    /// Prints only `world_size`.
    ///
    /// The mutex-guarded state and the condition variable are left out so
    /// that formatting never has to take the lock; `finish_non_exhaustive`
    /// appends `..` to make the omission visible in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FsdpSync")
            .field("world_size", &self.world_size)
            .finish_non_exhaustive()
    }
}
582
// `FsdpSync` is `Send + Sync` automatically: its fields are a `usize`, a
// `Mutex<FsdpSyncState>` (the state holds only `Vec`s, `Option<Vec>`s and a
// `usize`) and a `Condvar`, all of which are already thread-safe.  No
// `unsafe impl` is needed, and leaving it out means the compiler will reject
// any future field that is not thread-safe instead of silently accepting it.
// The assertion below documents and enforces that guarantee.
#[cfg(feature = "multi_gpu")]
const _: () = {
    fn _assert_send_sync<T: Send + Sync>() {}
    fn _fsdp_sync_is_thread_safe() {
        _assert_send_sync::<FsdpSync>();
    }
};
588
/// Backward for FSDP-sharded linear layer.
///
/// During backward, re-gathers the full weight from all shard storages,
/// computes grad_X and grad_W, then reduce-scatters grad_W back to shards.
/// The gathered weight is dropped immediately after use.
#[cfg(feature = "multi_gpu")]
#[derive(Debug)]
pub struct FsdpLinearBackward {
    /// Version snapshot of the layer input.
    pub input_version: VersionSnapshot,
    /// Saved input for grad_W = X^T @ grad_Y.
    pub input_storage: StorageHandle,
    /// Layout describing how to read `input_storage`.
    pub input_layout: Layout,
    /// Shard storages from ALL ranks (used to re-gather W during backward).
    pub weight_shard_storages: Vec<StorageHandle>,
    /// Layouts for `weight_shard_storages` (presumably index-aligned with
    /// it — TODO confirm).
    pub weight_shard_layouts: Vec<Layout>,
    /// Full weight shape [D_out, D_in] for re-assembly.
    pub full_weight_shape: Vec<usize>,
    /// Per-shard size along dim 0 for this rank's weight shard.
    pub shard_size: usize,
    /// Exact row offset in the full weight for this rank's shard.
    pub weight_shard_offset: usize,
    /// Which rank this backward op runs on.
    pub rank: usize,
    /// Total number of ranks taking part.
    pub world_size: usize,
    /// Device index for this rank.
    pub device_index: usize,
    /// Whether bias exists.
    pub has_bias: bool,
    /// Bias shard storages (one per rank), if bias exists.
    // NOTE(review): presumably empty when `has_bias` is false — confirm.
    pub bias_shard_storages: Vec<StorageHandle>,
    /// Full bias shape [D_out].
    pub full_bias_shape: Vec<usize>,
    /// Exact offset in the full bias for this rank's shard.
    pub bias_shard_offset: usize,
    /// Bias shard size for this rank.
    pub bias_shard_size: usize,
    /// Shared cross-rank synchronization barrier for reduce-scatter.
    pub sync: std::sync::Arc<FsdpSync>,
}
625
/// Backward for `slice_range`: scatter grad into a zero tensor at the slice position.
#[derive(Debug)]
pub struct SliceRangeBackward {
    /// Version snapshot of the sliced input.
    pub input_version: VersionSnapshot,
    /// Shape of the input before slicing (presumably the shape of the zero
    /// tensor the gradient is scattered into — TODO confirm).
    pub original_shape: Vec<usize>,
    /// Dimension along which the slice was taken.
    pub dim: usize,
    /// Start of the sliced range along `dim`.
    pub start: usize,
    /// End of the sliced range along `dim`.
    // NOTE(review): presumably exclusive, as in `start..end` — confirm.
    pub end: usize,
}
635
/// Backward for `cat`: split the grad along the cat dimension.
#[derive(Debug)]
pub struct CatBackward {
    /// Size of each input along the cat dim.
    pub splits: Vec<usize>,
    /// Dimension along which the inputs were concatenated.
    pub dim: usize,
    /// Version snapshots for the inputs (presumably one per entry of
    /// `splits` — TODO confirm).
    pub versions: Vec<VersionSnapshot>,
}
643
644const _: () = {
645    fn _assert_send<T: Send>() {}
646    fn _assert_sync<T: Sync>() {}
647    fn _assertions() {
648        _assert_send::<BackwardOp>();
649        _assert_sync::<BackwardOp>();
650    }
651};