// SPDX-License-Identifier: Apache-2.0 OR MIT
//! Backward operation structs and the version-checking snapshot.
//!
//! Each struct captures the minimal data needed to compute gradients for
//! its corresponding forward op.  No opaque closures — every backward op
//! is a concrete, inspectable type that is `Send + Sync` by construction.

use std::sync::Arc;

use crate::autograd::AutogradError;
use crate::tensor::{GradId, Layout, StorageHandle, Tensor, WeakStorageHandle};

// ---------------------------------------------------------------------------
// VersionSnapshot — weak-reference version checker
// ---------------------------------------------------------------------------

/// Snapshot of a [`StorageHandle`]'s version counter at tape-record time.
///
/// Holds a [`WeakStorageHandle`] so recording does **not** keep intermediate
/// tensor memory alive.
///
/// - **Upgrade succeeds:** compare live version vs recorded.  Mismatch →
///   [`AutogradError::VersionMismatch`].
/// - **Upgrade fails:** dead tensor → provably unmutated → `Ok(())`.
#[derive(Debug, Clone)]
pub struct VersionSnapshot {
    pub grad_id: GradId,
    pub weak_storage: WeakStorageHandle,
    pub recorded_version: usize,
}

impl VersionSnapshot {
    pub fn new(grad_id: GradId, storage: &StorageHandle) -> Self {
        Self {
            grad_id,
            recorded_version: storage.version(),
            weak_storage: storage.downgrade(),
        }
    }

    pub fn check(&self) -> Result<(), AutogradError> {
        match self.weak_storage.upgrade() {
            Some(strong) => {
                let current = strong.version();
                if current != self.recorded_version {
                    Err(AutogradError::VersionMismatch {
                        grad_id: self.grad_id,
                        expected: self.recorded_version,
                        found: current,
                    })
                } else {
                    Ok(())
                }
            }
            None => Ok(()),
        }
    }
}
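
// A minimal illustrative helper (not used elsewhere in this module): the loop
// a backward executor would run over an op's snapshots before touching any
// saved operand, failing fast on the first in-place mutation it detects.
#[allow(dead_code)]
fn check_all(snapshots: &[VersionSnapshot]) -> Result<(), AutogradError> {
    snapshots.iter().try_for_each(VersionSnapshot::check)
}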

// ---------------------------------------------------------------------------
// Per-op backward structs
// ---------------------------------------------------------------------------

/// Backward for `c = a + b`.
///
/// `∂L/∂a = ∂L/∂c`,  `∂L/∂b = ∂L/∂c`  (identity).
#[derive(Debug)]
pub struct AddBackward {
    pub lhs_version: VersionSnapshot,
    pub rhs_version: VersionSnapshot,
}

/// Backward for `c = a - b`.
///
/// `∂L/∂a = ∂L/∂c`,  `∂L/∂b = -∂L/∂c`.
#[derive(Debug)]
pub struct SubBackward {
    pub lhs_version: VersionSnapshot,
    pub rhs_version: VersionSnapshot,
}

/// Backward for `c = a * b` (element-wise).
///
/// `∂L/∂a = ∂L/∂c ⊙ b`,  `∂L/∂b = ∂L/∂c ⊙ a`.
#[derive(Debug)]
pub struct MulBackward {
    pub lhs_storage: StorageHandle,
    pub lhs_layout: Layout,
    pub lhs_version: VersionSnapshot,
    pub rhs_storage: StorageHandle,
    pub rhs_layout: Layout,
    pub rhs_version: VersionSnapshot,
}
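
// Reference CPU semantics for `MulBackward` on contiguous f32 buffers
// (illustrative sketch only; the real executor dispatches through the saved
// `StorageHandle`/`Layout` pairs and may run on GPU):
#[allow(dead_code)]
fn mul_backward_ref(out_grad: &[f32], a: &[f32], b: &[f32]) -> (Vec<f32>, Vec<f32>) {
    let grad_a = out_grad.iter().zip(b).map(|(g, bv)| g * bv).collect(); // ∂L/∂a = ∂L/∂c ⊙ b
    let grad_b = out_grad.iter().zip(a).map(|(g, av)| g * av).collect(); // ∂L/∂b = ∂L/∂c ⊙ a
    (grad_a, grad_b)
}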

/// Backward for `C = A @ B`.
///
/// `∂L/∂A = ∂L/∂C @ Bᵀ`,  `∂L/∂B = Aᵀ @ ∂L/∂C`.
#[derive(Debug)]
pub struct MatmulBackward {
    pub lhs_storage: StorageHandle,
    pub lhs_layout: Layout,
    pub lhs_version: VersionSnapshot,
    pub rhs_storage: StorageHandle,
    pub rhs_layout: Layout,
    pub rhs_version: VersionSnapshot,
    pub m: usize,
    pub k: usize,
    pub n: usize,
}
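
// Reference semantics for `MatmulBackward` on row-major contiguous buffers
// (naive triple loop; illustrative only). A is [m,k], B is [k,n], grad_C is
// [m,n]:  grad_A = grad_C · Bᵀ  ([m,k]),  grad_B = Aᵀ · grad_C  ([k,n]).
#[allow(dead_code)]
fn matmul_backward_ref(
    grad_c: &[f32], a: &[f32], b: &[f32], m: usize, k: usize, n: usize,
) -> (Vec<f32>, Vec<f32>) {
    let mut grad_a = vec![0.0f32; m * k];
    let mut grad_b = vec![0.0f32; k * n];
    for i in 0..m {
        for j in 0..n {
            let g = grad_c[i * n + j];
            for p in 0..k {
                grad_a[i * k + p] += g * b[p * n + j]; // (grad_C · Bᵀ)[i,p]
                grad_b[p * n + j] += a[i * k + p] * g; // (Aᵀ · grad_C)[p,j]
            }
        }
    }
    (grad_a, grad_b)
}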

/// Backward for `y = relu(x)`.
///
/// `∂L/∂x[i] = ∂L/∂y[i]  if x[i] > 0,  else 0`.
#[derive(Debug)]
pub struct ReluBackward {
    pub input_storage: StorageHandle,
    pub input_layout: Layout,
    pub input_version: VersionSnapshot,
}
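
// Reference semantics for `ReluBackward` (illustrative): gate the incoming
// gradient by the sign of the saved forward input.
#[allow(dead_code)]
fn relu_backward_ref(out_grad: &[f32], input: &[f32]) -> Vec<f32> {
    out_grad
        .iter()
        .zip(input)
        .map(|(&g, &x)| if x > 0.0 { g } else { 0.0 })
        .collect()
}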

/// Backward for `loss = mse_loss(pred, target)` (fused).
///
/// `∂L/∂pred[i] = out_grad_scalar * 2 * (pred[i] - target[i]) / N`.
///
/// Only `pred` receives a gradient; `target` is treated as a constant.
#[derive(Debug)]
pub struct MseLossBackward {
    pub pred_storage: StorageHandle,
    pub pred_layout: Layout,
    pub pred_version: VersionSnapshot,
    pub target_storage: StorageHandle,
    pub target_layout: Layout,
    pub target_version: VersionSnapshot,
    pub numel: usize,
}
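
// Reference semantics for `MseLossBackward` (illustrative). `out_grad` is the
// scalar gradient flowing into the loss (1.0 when the loss is the root):
#[allow(dead_code)]
fn mse_backward_ref(out_grad: f32, pred: &[f32], target: &[f32]) -> Vec<f32> {
    let n = pred.len() as f32;
    pred.iter()
        .zip(target)
        .map(|(p, t)| out_grad * 2.0 * (p - t) / n)
        .collect()
}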

/// Backward for `y = add_bias(matrix, bias)`.
///
/// `∂L/∂matrix = ∂L/∂y`  (identity, same shape `[m,n]`).
/// `∂L/∂bias = sum_rows(∂L/∂y)`  (reduce `[m,n]` → `[n]`).
#[derive(Debug)]
pub struct AddBiasBackward {
    pub input_version: VersionSnapshot,
    pub bias_version: VersionSnapshot,
    pub m: usize,
    pub n: usize,
}
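
// Reference for the bias half of `AddBiasBackward` (illustrative): reduce the
// row-major [m,n] output gradient over rows into a [n] bias gradient.
#[allow(dead_code)]
fn bias_grad_ref(out_grad: &[f32], m: usize, n: usize) -> Vec<f32> {
    let mut grad_bias = vec![0.0f32; n];
    for row in 0..m {
        for col in 0..n {
            grad_bias[col] += out_grad[row * n + col];
        }
    }
    grad_bias
}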

/// Backward for `slice_batch(input, index)`.
///
/// `∂L/∂input` is a zero tensor matching the original batched input shape,
/// with `∂L/∂output` placed at the `index`-th batch slot.
#[derive(Debug)]
pub struct SliceBatchBackward {
    pub input_version: VersionSnapshot,
    /// Shape of the original batched input (e.g. `[batch, C, H, W]`).
    pub original_shape: Vec<usize>,
    /// Which batch element was sliced.
    pub index: usize,
}

/// Backward for `im2col(input)`.
///
/// `∂L/∂input = col2im(∂L/∂output)`.
#[derive(Debug)]
pub struct Im2ColBackward {
    pub input_version: VersionSnapshot,
    pub c_in: usize,
    pub h: usize,
    pub w: usize,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub out_h: usize,
    pub out_w: usize,
}
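
// The standard convolution output-size relation these fields are expected to
// satisfy (an assumption about the forward pass, stated here for reference):
//   out_h = (h + 2·padding − kernel_size) / stride + 1, and likewise for out_w.
#[allow(dead_code)]
fn conv_out_dim(dim: usize, kernel_size: usize, stride: usize, padding: usize) -> usize {
    (dim + 2 * padding - kernel_size) / stride + 1
}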

/// Backward for `stack([t0, t1, ...], axis=0)`.
///
/// `∂L/∂t_i = slice(∂L/∂output, i)` along axis 0.
#[derive(Debug)]
pub struct StackBackward {
    /// Number of tensors that were stacked.
    pub count: usize,
    /// Shape of each individual tensor (all must match).
    pub each_shape: Vec<usize>,
    /// Version snapshots for each input.
    pub versions: Vec<VersionSnapshot>,
}

/// Backward for `add_channel_bias(src, bias)`.
///
/// `∂L/∂src = ∂L/∂out`  (identity, same shape `[batch*C, spatial]`)
/// `∂L/∂bias = sum over spatial of ∂L/∂out` per channel.
#[derive(Debug)]
pub struct AddChannelBiasBackward {
    pub input_version: VersionSnapshot,
    pub bias_version: VersionSnapshot,
    pub channels: usize,
    pub spatial: usize,
}

/// Backward for `max_pool2d(input)`.
///
/// Scatters `∂L/∂output` to the argmax positions saved during forward.
#[derive(Debug)]
pub struct MaxPool2dBackward {
    pub input_version: VersionSnapshot,
    /// Saved argmax indices (flat spatial offsets stored as f32).
    pub indices_storage: StorageHandle,
    pub indices_layout: Layout,
    pub channels: usize,
    pub h: usize,
    pub w: usize,
    pub out_h: usize,
    pub out_w: usize,
}

/// Backward for `reshape_tracked(input, new_shape)`.
///
/// `∂L/∂input = reshape(∂L/∂output, original_shape)` — zero-copy.
#[derive(Debug)]
pub struct ReshapeBackward {
    pub input_version: VersionSnapshot,
    pub original_shape: Vec<usize>,
}

/// Backward for `flatten(input)`.
///
/// `∂L/∂input = reshape(∂L/∂output, original_shape)` — zero-copy.
#[derive(Debug)]
pub struct FlattenBackward {
    pub input_version: VersionSnapshot,
    pub original_shape: Vec<usize>,
}

/// Backward for `cross_entropy_loss(logits, targets)`.
///
/// The gradient was pre-computed during the forward pass (softmax - one_hot,
/// scaled by 1/B).  Backward simply scales by the incoming `out_grad` scalar.
#[derive(Debug)]
pub struct CrossEntropyBackward {
    pub input_version: VersionSnapshot,
    /// Pre-computed gradient [B, C], saved during forward.
    pub grad_storage: StorageHandle,
    pub grad_layout: Layout,
}

/// Backward for `dropout(input, p)`.
///
/// `∂L/∂input = ∂L/∂output * saved_mask`.
/// Reuses the existing `mul` dispatch (auto CPU/GPU).
#[derive(Debug)]
pub struct DropoutBackward {
    pub input_version: VersionSnapshot,
    pub mask_storage: StorageHandle,
    pub mask_layout: Layout,
}

/// Backward for tracked `transpose(dim0, dim1)`.
/// `grad_input = transpose(grad_output, dim0, dim1)` — reverse the swap.
#[derive(Debug)]
pub struct TransposeBackward {
    pub input_version: VersionSnapshot,
    pub dim0: usize,
    pub dim1: usize,
}

/// Backward for `bmm(A, B)`.
/// `grad_A = bmm(grad_C, B^T)`, `grad_B = bmm(A^T, grad_C)`.
#[derive(Debug)]
pub struct BmmBackward {
    pub lhs_storage: StorageHandle,
    pub lhs_layout: Layout,
    pub lhs_version: VersionSnapshot,
    pub rhs_storage: StorageHandle,
    pub rhs_layout: Layout,
    pub rhs_version: VersionSnapshot,
    pub batch: usize,
    pub m: usize,
    pub k: usize,
    pub n: usize,
}

/// Backward for `softmax(input)`.  Saves **output**.
/// `grad_input = saved * (grad_out - dot)` where `dot = Σ grad_out * saved`.
#[derive(Debug)]
pub struct SoftmaxBackward {
    pub output_storage: StorageHandle,
    pub output_layout: Layout,
    pub input_version: VersionSnapshot,
    pub num_rows: usize,
    pub row_size: usize,
}
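
// Reference per-row semantics for `SoftmaxBackward` (illustrative), using the
// saved forward output `saved` for one row of length `row_size`:
#[allow(dead_code)]
fn softmax_backward_row_ref(out_grad: &[f32], saved: &[f32]) -> Vec<f32> {
    let dot: f32 = out_grad.iter().zip(saved).map(|(g, s)| g * s).sum();
    out_grad
        .iter()
        .zip(saved)
        .map(|(g, s)| s * (g - dot))
        .collect()
}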

/// Backward for `layer_norm`.
///
/// Kernel 1: per-instance grad_input via c1/c2 reductions.
/// Kernel 2: grad_weight = reduce(grad_out * x_hat), grad_bias = reduce(grad_out).
#[derive(Debug)]
pub struct LayerNormBackward {
    pub input_storage: StorageHandle,
    pub input_layout: Layout,
    pub input_version: VersionSnapshot,
    pub weight_storage: StorageHandle,
    pub weight_layout: Layout,
    pub weight_version: VersionSnapshot,
    pub save_storage: StorageHandle,  // [num_instances, 2]: mean + invstd
    pub save_layout: Layout,
    pub num_instances: usize,
    pub norm_size: usize,
}
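
// Reference per-instance grad_input for `LayerNormBackward` (illustrative;
// this is the standard formulation the c1/c2 reductions above presumably
// implement). With x_hat = (x − mean)·invstd and gw = grad_out ⊙ w:
//   c1 = mean(gw),  c2 = mean(gw ⊙ x_hat),
//   grad_x = (gw − c1 − x_hat ⊙ c2) · invstd.
#[allow(dead_code)]
fn layer_norm_grad_input_ref(
    grad_out: &[f32], x: &[f32], w: &[f32], mean: f32, invstd: f32,
) -> Vec<f32> {
    let n = x.len() as f32;
    let x_hat: Vec<f32> = x.iter().map(|&v| (v - mean) * invstd).collect();
    let gw: Vec<f32> = grad_out.iter().zip(w).map(|(g, wv)| g * wv).collect();
    let c1 = gw.iter().sum::<f32>() / n;
    let c2 = gw.iter().zip(&x_hat).map(|(g, xh)| g * xh).sum::<f32>() / n;
    gw.iter()
        .zip(&x_hat)
        .map(|(g, xh)| (g - c1 - xh * c2) * invstd)
        .collect()
}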

/// Backward for `embedding(indices)`.
///
/// Sparse scatter: grad_weight[token_id] += grad_output[lookup].
/// CPU-only backward (no f32 atomics in WGSL).
#[derive(Debug)]
pub struct EmbeddingBackward {
    pub input_version: VersionSnapshot,
    pub indices_storage: StorageHandle,
    pub indices_layout: Layout,
    pub vocab_size: usize,
    pub embed_dim: usize,
    pub total_lookups: usize,
}
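
// Reference semantics for `EmbeddingBackward` (illustrative; token ids are
// taken as `usize` here, whereas the saved indices storage may encode them
// differently):
#[allow(dead_code)]
fn embedding_backward_ref(
    grad_out: &[f32], token_ids: &[usize], vocab_size: usize, embed_dim: usize,
) -> Vec<f32> {
    let mut grad_weight = vec![0.0f32; vocab_size * embed_dim];
    for (lookup, &tok) in token_ids.iter().enumerate() {
        for d in 0..embed_dim {
            // Sparse scatter-add: rows for tokens never looked up stay zero.
            grad_weight[tok * embed_dim + d] += grad_out[lookup * embed_dim + d];
        }
    }
    grad_weight
}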

/// Backward for `sigmoid(input)`.  Saves **output**.
/// `grad = out_grad * saved_out * (1 - saved_out)`
#[derive(Debug)]
pub struct SigmoidBackward {
    pub output_storage: StorageHandle,
    pub output_layout: Layout,
    pub input_version: VersionSnapshot,
}

/// Backward for `tanh(input)`.  Saves **output**.
/// `grad = out_grad * (1 - saved_out^2)`
#[derive(Debug)]
pub struct TanhBackward {
    pub output_storage: StorageHandle,
    pub output_layout: Layout,
    pub input_version: VersionSnapshot,
}
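
// Reference element-wise semantics for the two saved-output activations above
// (illustrative):
#[allow(dead_code)]
fn sigmoid_backward_ref(out_grad: &[f32], saved_out: &[f32]) -> Vec<f32> {
    out_grad
        .iter()
        .zip(saved_out)
        .map(|(g, s)| g * s * (1.0 - s))
        .collect()
}

#[allow(dead_code)]
fn tanh_backward_ref(out_grad: &[f32], saved_out: &[f32]) -> Vec<f32> {
    out_grad
        .iter()
        .zip(saved_out)
        .map(|(g, t)| g * (1.0 - t * t))
        .collect()
}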

/// Backward for `gelu(input)` (tanh approx).  Saves **input**.
#[derive(Debug)]
pub struct GeluBackward {
    pub input_storage: StorageHandle,
    pub input_layout: Layout,
    pub input_version: VersionSnapshot,
}

/// Backward for `leaky_relu(input, alpha)`.  Saves **input**.
#[derive(Debug)]
pub struct LeakyReluBackward {
    pub input_storage: StorageHandle,
    pub input_layout: Layout,
    pub input_version: VersionSnapshot,
    pub alpha: f32,
}

/// Backward for a broadcasted binary op.
///
/// If an operand was broadcast, its gradient must be summed (reduced)
/// along the broadcast dimensions.
#[derive(Debug)]
pub struct BroadcastAddBackward {
    pub lhs_version: VersionSnapshot,
    pub rhs_version: VersionSnapshot,
    pub lhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    pub rhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    pub output_shape: Vec<usize>,
}

#[derive(Debug)]
pub struct BroadcastSubBackward {
    pub lhs_version: VersionSnapshot,
    pub rhs_version: VersionSnapshot,
    pub lhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    pub rhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    pub output_shape: Vec<usize>,
}

#[derive(Debug)]
pub struct BroadcastMulBackward {
    pub lhs_storage: StorageHandle,
    pub lhs_layout: Layout,
    pub lhs_version: VersionSnapshot,
    pub rhs_storage: StorageHandle,
    pub rhs_layout: Layout,
    pub rhs_version: VersionSnapshot,
    pub lhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    pub rhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    pub output_shape: Vec<usize>,
}

/// Backward for `batch_norm_2d(input, weight, bias)`.
///
/// Saves input, weight, and mean+invstd for backward.
/// Tape records 3 inputs: [input, weight, bias].
#[derive(Debug)]
pub struct BatchNorm2dBackward {
    pub input_storage: StorageHandle,
    pub input_layout: Layout,
    pub input_version: VersionSnapshot,
    pub weight_storage: StorageHandle,
    pub weight_layout: Layout,
    pub weight_version: VersionSnapshot,
    pub save_storage: StorageHandle,  // [channels, 2]: mean + invstd per channel
    pub save_layout: Layout,
    pub batch: usize,
    pub channels: usize,
    pub height: usize,
    pub width: usize,
}

/// Backward for `adaptive_avg_pool2d(input)`.
///
/// Each input pixel distributes its gradient to the output bins that cover it,
/// weighted by `1/count`.
#[derive(Debug)]
pub struct AdaptiveAvgPool2dBackward {
    pub input_version: VersionSnapshot,
    pub batch: usize,
    pub channels: usize,
    pub h_in: usize,
    pub w_in: usize,
    pub h_out: usize,
    pub w_out: usize,
}
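
// The bin boundaries these fields imply, under the common floor/ceil adaptive
// pooling convention (an assumption about the forward pass): output bin `o`
// covers input positions [⌊o·in/out⌋, ⌈(o+1)·in/out⌉).
#[allow(dead_code)]
fn adaptive_bin(o: usize, in_dim: usize, out_dim: usize) -> (usize, usize) {
    let start = o * in_dim / out_dim; // floor
    let end = ((o + 1) * in_dim + out_dim - 1) / out_dim; // ceil
    (start, end)
}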

/// Backward for `to_dtype(target_dtype)`.
///
/// The gradient of a cast is simply a cast in the reverse direction.
/// No data needs to be saved — only the source dtype for the reverse cast.
#[derive(Debug)]
pub struct CastBackward {
    pub input_version: VersionSnapshot,
    pub source_dtype: crate::tensor::DType,
}

// ---------------------------------------------------------------------------
// BackwardOp enum
// ---------------------------------------------------------------------------

/// Discriminated union of all backward operation types.
///
/// No closures, no trait objects — `Send + Sync` and inspectable.
#[derive(Debug)]
pub enum BackwardOp {
    Add(AddBackward),
    Sub(SubBackward),
    Mul(MulBackward),
    Matmul(MatmulBackward),
    Relu(ReluBackward),
    MseLoss(MseLossBackward),
    AddBias(AddBiasBackward),
    Im2Col(Im2ColBackward),
    Stack(StackBackward),
    AddChannelBias(AddChannelBiasBackward),
    SliceBatch(SliceBatchBackward),
    MaxPool2d(MaxPool2dBackward),
    Flatten(FlattenBackward),
    Reshape(ReshapeBackward),
    Dropout(DropoutBackward),
    CrossEntropy(CrossEntropyBackward),
    Sigmoid(SigmoidBackward),
    Tanh(TanhBackward),
    Gelu(GeluBackward),
    LeakyRelu(LeakyReluBackward),
    Transpose(TransposeBackward),
    Bmm(BmmBackward),
    Softmax(SoftmaxBackward),
    LayerNorm(LayerNormBackward),
    Embedding(EmbeddingBackward),
    BroadcastAdd(BroadcastAddBackward),
    BroadcastSub(BroadcastSubBackward),
    BroadcastMul(BroadcastMulBackward),
    BatchNorm2d(BatchNorm2dBackward),
    AdaptiveAvgPool2d(AdaptiveAvgPool2dBackward),
    Cast(CastBackward),
    /// Backward for `slice_range(dim, start, end)`.
    SliceRange(SliceRangeBackward),
    /// Backward for `cat(tensors, dim)`.
    Cat(CatBackward),
    /// Backward for FSDP sharded linear: re-gathers weights during backward.
    #[cfg(feature = "multi_gpu")]
    FsdpLinear(FsdpLinearBackward),
    /// User-defined custom backward op via `Arc<dyn CustomBackward>`.
    Custom(CustomBackwardOp),
}

// ---------------------------------------------------------------------------
// Custom backward op (plugin system)
// ---------------------------------------------------------------------------

/// Trait for user-defined backward computations.
///
/// Implement this to define custom gradient math for operations injected
/// via `ext::custom_forward`.
pub trait CustomBackward: Send + Sync + std::fmt::Debug {
    /// Compute input gradients given the output gradient and saved tensors.
    ///
    /// Returns one gradient per input (in forward input order).
    fn backward(&self, out_grad: &Tensor, saved: &[Tensor]) -> Vec<Tensor>;
}
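
// Illustrative sketch of a user-defined backward op (kept in comment form:
// the element-wise `mul`/`scale` tensor methods used below are hypothetical,
// not confirmed parts of this crate's `Tensor` API):
//
//     #[derive(Debug)]
//     struct SquareBackward; // custom backward for y = x ⊙ x
//
//     impl CustomBackward for SquareBackward {
//         fn backward(&self, out_grad: &Tensor, saved: &[Tensor]) -> Vec<Tensor> {
//             // grad_x = 2 · x ⊙ out_grad, with x = saved[0]
//             vec![out_grad.mul(&saved[0]).scale(2.0)]
//         }
//     }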

/// Backward state for a custom op: the user's handler + saved tensors.
pub struct CustomBackwardOp {
    /// The user's backward implementation.
    pub handler: Arc<dyn CustomBackward>,
    /// Version snapshots for each input (for mutation checking).
    pub input_versions: Vec<VersionSnapshot>,
    /// Tensors saved during forward for use in backward.
    pub saved_storages: Vec<StorageHandle>,
    pub saved_layouts: Vec<Layout>,
    pub saved_shapes: Vec<Vec<usize>>,
}

impl std::fmt::Debug for CustomBackwardOp {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CustomBackwardOp")
            .field("handler", &self.handler)
            .field("num_saved", &self.saved_storages.len())
            .finish()
    }
}

/// Cross-rank synchronization barrier for FSDP gradient reduce-scatter.
///
/// Shared across all ranks for a single layer.  Each rank pushes its
/// local gradient into `grads`, then waits on the `Condvar` until all
/// ranks have arrived.  The last arrival sums the gradients and wakes
/// all waiters.
#[cfg(feature = "multi_gpu")]
pub struct FsdpSync {
    pub world_size: usize,
    pub state: std::sync::Mutex<FsdpSyncState>,
    pub cvar: std::sync::Condvar,
}

#[cfg(feature = "multi_gpu")]
pub struct FsdpSyncState {
    pub weight_grads: Vec<Vec<f32>>,
    pub bias_grads: Vec<Vec<f32>>,
    /// The reduced (summed + averaged) result.  Set by the last arrival.
    pub weight_result: Option<Vec<f32>>,
    pub bias_result: Option<Vec<f32>>,
    /// Counts how many ranks have read the result and exited.
    pub read_count: usize,
}

#[cfg(feature = "multi_gpu")]
impl FsdpSync {
    pub fn new(world_size: usize) -> Self {
        Self {
            world_size,
            state: std::sync::Mutex::new(FsdpSyncState {
                weight_grads: Vec::new(),
                bias_grads: Vec::new(),
                weight_result: None,
                bias_result: None,
                read_count: 0,
            }),
            cvar: std::sync::Condvar::new(),
        }
    }
}
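
// Illustrative rendezvous for the weight path (a hypothetical helper, not the
// crate's actual reduce-scatter; it also omits the `read_count` bookkeeping
// that reusing the barrier across steps would require):
#[cfg(feature = "multi_gpu")]
#[allow(dead_code)]
fn fsdp_reduce_weight(sync: &FsdpSync, local_grad: Vec<f32>) -> Vec<f32> {
    let mut st = sync.state.lock().unwrap();
    st.weight_grads.push(local_grad);
    if st.weight_grads.len() == sync.world_size {
        // Last arrival: sum all contributions, average, publish, wake waiters.
        let mut sum = st.weight_grads[0].clone();
        for grad in &st.weight_grads[1..] {
            for (acc, v) in sum.iter_mut().zip(grad) {
                *acc += v;
            }
        }
        let ws = sync.world_size as f32;
        sum.iter_mut().for_each(|acc| *acc /= ws);
        st.weight_result = Some(sum);
        sync.cvar.notify_all();
    }
    // Everyone (including the last arrival, whose loop check falls through)
    // waits until the reduced result is published.
    while st.weight_result.is_none() {
        st = sync.cvar.wait(st).unwrap();
    }
    st.weight_result.clone().unwrap()
}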

#[cfg(feature = "multi_gpu")]
impl std::fmt::Debug for FsdpSync {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FsdpSync")
            .field("world_size", &self.world_size)
            .finish()
    }
}

// `FsdpSync` contains only `std::sync` primitives and plain data, so `Send`
// and `Sync` are auto-derived; no unsafe impls are required.

/// Backward for FSDP-sharded linear layer.
///
/// During backward, re-gathers the full weight from all shard storages,
/// computes grad_X and grad_W, then reduce-scatters grad_W back to shards.
/// The gathered weight is dropped immediately after use.
#[cfg(feature = "multi_gpu")]
#[derive(Debug)]
pub struct FsdpLinearBackward {
    pub input_version: VersionSnapshot,
    /// Saved input for grad_W = X^T @ grad_Y.
    pub input_storage: StorageHandle,
    pub input_layout: Layout,
    /// Shard storages from ALL ranks (used to re-gather W during backward).
    pub weight_shard_storages: Vec<StorageHandle>,
    pub weight_shard_layouts: Vec<Layout>,
    /// Full weight shape [D_out, D_in] for re-assembly.
    pub full_weight_shape: Vec<usize>,
    /// Per-shard size along dim 0 for this rank's weight shard.
    pub shard_size: usize,
    /// Exact row offset in the full weight for this rank's shard.
    pub weight_shard_offset: usize,
    /// Which rank this backward op runs on.
    pub rank: usize,
    pub world_size: usize,
    /// Device index for this rank.
    pub device_index: usize,
    /// Whether bias exists.
    pub has_bias: bool,
    /// Bias shard storages (one per rank), if bias exists.
    pub bias_shard_storages: Vec<StorageHandle>,
    /// Full bias shape [D_out].
    pub full_bias_shape: Vec<usize>,
    /// Exact offset in the full bias for this rank's shard.
    pub bias_shard_offset: usize,
    /// Bias shard size for this rank.
    pub bias_shard_size: usize,
    /// Shared cross-rank synchronization barrier for reduce-scatter.
    pub sync: std::sync::Arc<FsdpSync>,
}

/// Backward for `slice_range`: scatter grad into a zero tensor at the slice position.
#[derive(Debug)]
pub struct SliceRangeBackward {
    pub input_version: VersionSnapshot,
    pub original_shape: Vec<usize>,
    pub dim: usize,
    pub start: usize,
    pub end: usize,
}

/// Backward for `cat`: split the grad along the cat dimension.
#[derive(Debug)]
pub struct CatBackward {
    pub splits: Vec<usize>,  // size of each input along the cat dim
    pub dim: usize,
    pub versions: Vec<VersionSnapshot>,
}

const _: () = {
    fn _assert_send<T: Send>() {}
    fn _assert_sync<T: Sync>() {}
    fn _assertions() {
        _assert_send::<BackwardOp>();
        _assert_sync::<BackwardOp>();
    }
};