1use std::sync::Arc;
9
10use crate::autograd::AutogradError;
11use crate::tensor::{GradId, Layout, StorageHandle, Tensor, WeakStorageHandle};
12
/// Snapshot of a tensor storage's version counter, taken when an op is
/// recorded. `check` later compares it against the live version so the
/// backward pass can detect that the storage was modified in the interim.
#[derive(Debug, Clone)]
pub struct VersionSnapshot {
    /// Gradient identity of the guarded tensor (reported in
    /// `AutogradError::VersionMismatch`).
    pub grad_id: GradId,
    /// Weak handle so the snapshot does not keep the storage alive.
    pub weak_storage: WeakStorageHandle,
    /// Version observed at the moment the snapshot was taken.
    pub recorded_version: usize,
}
31
32impl VersionSnapshot {
33 pub fn new(grad_id: GradId, storage: &StorageHandle) -> Self {
34 Self {
35 grad_id,
36 recorded_version: storage.version(),
37 weak_storage: storage.downgrade(),
38 }
39 }
40
41 pub fn check(&self) -> Result<(), AutogradError> {
42 match self.weak_storage.upgrade() {
43 Some(strong) => {
44 let current = strong.version();
45 if current != self.recorded_version {
46 Err(AutogradError::VersionMismatch {
47 grad_id: self.grad_id,
48 expected: self.recorded_version,
49 found: current,
50 })
51 } else {
52 Ok(())
53 }
54 }
55 None => Ok(()),
56 }
57 }
58}
59
/// Saved state for the backward pass of addition; only version snapshots
/// of the operands are needed, no values are saved.
#[derive(Debug)]
pub struct AddBackward {
    pub lhs_version: VersionSnapshot,
    pub rhs_version: VersionSnapshot,
}
72
/// Saved state for the backward pass of subtraction; like `AddBackward`,
/// only version snapshots are required.
#[derive(Debug)]
pub struct SubBackward {
    pub lhs_version: VersionSnapshot,
    pub rhs_version: VersionSnapshot,
}
81
/// Saved state for the backward pass of elementwise multiplication. Both
/// operands' values are kept: each side's gradient needs the other's data.
#[derive(Debug)]
pub struct MulBackward {
    pub lhs_storage: StorageHandle,
    pub lhs_layout: Layout,
    pub lhs_version: VersionSnapshot,
    pub rhs_storage: StorageHandle,
    pub rhs_layout: Layout,
    pub rhs_version: VersionSnapshot,
}
94
/// Saved state for the backward pass of matrix multiplication.
#[derive(Debug)]
pub struct MatmulBackward {
    pub lhs_storage: StorageHandle,
    pub lhs_layout: Layout,
    pub lhs_version: VersionSnapshot,
    pub rhs_storage: StorageHandle,
    pub rhs_layout: Layout,
    pub rhs_version: VersionSnapshot,
    // GEMM dimensions — presumably lhs is (m, k) and rhs is (k, n),
    // following the usual convention; TODO confirm against the forward op.
    pub m: usize,
    pub k: usize,
    pub n: usize,
}
110
/// Saved state for the backward pass of ReLU; the forward input is kept so
/// the gradient can be gated on where the input was positive.
#[derive(Debug)]
pub struct ReluBackward {
    pub input_storage: StorageHandle,
    pub input_layout: Layout,
    pub input_version: VersionSnapshot,
}
120
/// Saved state for the backward pass of mean-squared-error loss; keeps both
/// prediction and target values.
#[derive(Debug)]
pub struct MseLossBackward {
    pub pred_storage: StorageHandle,
    pub pred_layout: Layout,
    pub pred_version: VersionSnapshot,
    pub target_storage: StorageHandle,
    pub target_layout: Layout,
    pub target_version: VersionSnapshot,
    // Total element count — presumably the divisor for the mean; TODO confirm.
    pub numel: usize,
}
136
/// Saved state for the backward pass of adding a bias row to a matrix.
#[derive(Debug)]
pub struct AddBiasBackward {
    pub input_version: VersionSnapshot,
    pub bias_version: VersionSnapshot,
    // Matrix dimensions — presumably m rows by n columns, with the bias of
    // length n broadcast over rows; TODO confirm against the forward op.
    pub m: usize,
    pub n: usize,
}
148
/// Saved state for the backward pass of selecting one batch entry.
#[derive(Debug)]
pub struct SliceBatchBackward {
    pub input_version: VersionSnapshot,
    /// Full shape of the input, so the gradient can be scattered back.
    pub original_shape: Vec<usize>,
    /// Which batch entry was selected.
    pub index: usize,
}
161
/// Saved state for the backward pass of the im2col transform
/// (convolution lowering); records the full geometry so col2im can
/// reconstruct the input gradient.
#[derive(Debug)]
pub struct Im2ColBackward {
    pub input_version: VersionSnapshot,
    pub c_in: usize,
    pub h: usize,
    pub w: usize,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub out_h: usize,
    pub out_w: usize,
}
177
/// Saved state for the backward pass of stacking tensors into a new
/// leading dimension.
#[derive(Debug)]
pub struct StackBackward {
    /// Number of stacked inputs.
    pub count: usize,
    /// Shape shared by every input (pre-stack).
    pub each_shape: Vec<usize>,
    /// One version snapshot per stacked input.
    pub versions: Vec<VersionSnapshot>,
}
190
/// Saved state for the backward pass of adding a per-channel bias to a
/// (channels × spatial) activation map.
#[derive(Debug)]
pub struct AddChannelBiasBackward {
    pub input_version: VersionSnapshot,
    pub bias_version: VersionSnapshot,
    pub channels: usize,
    /// Elements per channel (e.g. h*w for 2-D maps) — TODO confirm.
    pub spatial: usize,
}
202
/// Saved state for the backward pass of 2-D max pooling.
#[derive(Debug)]
pub struct MaxPool2dBackward {
    pub input_version: VersionSnapshot,
    // Saved indices from the forward pass — presumably the argmax position
    // per pooling window, used to route gradients; TODO confirm.
    pub indices_storage: StorageHandle,
    pub indices_layout: Layout,
    pub channels: usize,
    pub h: usize,
    pub w: usize,
    pub out_h: usize,
    pub out_w: usize,
}
218
/// Saved state for the backward pass of reshape; the gradient is simply
/// viewed back to the original shape.
#[derive(Debug)]
pub struct ReshapeBackward {
    pub input_version: VersionSnapshot,
    pub original_shape: Vec<usize>,
}
227
/// Saved state for the backward pass of flatten; mirrors `ReshapeBackward`
/// but kept as a distinct variant.
#[derive(Debug)]
pub struct FlattenBackward {
    pub input_version: VersionSnapshot,
    pub original_shape: Vec<usize>,
}
236
/// Saved state for the backward pass of cross-entropy loss.
#[derive(Debug)]
pub struct CrossEntropyBackward {
    pub input_version: VersionSnapshot,
    // Gradient data saved at forward time — presumably the precomputed
    // (softmax - one_hot) term; TODO confirm against the forward op.
    pub grad_storage: StorageHandle,
    pub grad_layout: Layout,
}
248
/// Saved state for the backward pass of dropout; the forward mask is
/// replayed on the gradient.
#[derive(Debug)]
pub struct DropoutBackward {
    pub input_version: VersionSnapshot,
    pub mask_storage: StorageHandle,
    pub mask_layout: Layout,
}
259
/// Saved state for the backward pass of transposing two dimensions; the
/// gradient is transposed back over the same pair.
#[derive(Debug)]
pub struct TransposeBackward {
    pub input_version: VersionSnapshot,
    pub dim0: usize,
    pub dim1: usize,
}
268
/// Saved state for the backward pass of batched matrix multiplication.
#[derive(Debug)]
pub struct BmmBackward {
    pub lhs_storage: StorageHandle,
    pub lhs_layout: Layout,
    pub lhs_version: VersionSnapshot,
    pub rhs_storage: StorageHandle,
    pub rhs_layout: Layout,
    pub rhs_version: VersionSnapshot,
    // Batched GEMM dimensions — presumably lhs is (batch, m, k) and rhs is
    // (batch, k, n); TODO confirm against the forward op.
    pub batch: usize,
    pub m: usize,
    pub k: usize,
    pub n: usize,
}
284
/// Saved state for the backward pass of softmax; the forward *output* is
/// saved (softmax's gradient is expressible in terms of its output alone).
#[derive(Debug)]
pub struct SoftmaxBackward {
    pub output_storage: StorageHandle,
    pub output_layout: Layout,
    pub input_version: VersionSnapshot,
    /// Number of independent softmax rows.
    pub num_rows: usize,
    /// Elements per row (the normalized axis length).
    pub row_size: usize,
}
295
296#[derive(Debug)]
301pub struct LayerNormBackward {
302 pub input_storage: StorageHandle,
303 pub input_layout: Layout,
304 pub input_version: VersionSnapshot,
305 pub weight_storage: StorageHandle,
306 pub weight_layout: Layout,
307 pub weight_version: VersionSnapshot,
308 pub save_storage: StorageHandle, pub save_layout: Layout,
310 pub num_instances: usize,
311 pub norm_size: usize,
312}
313
/// Saved state for the backward pass of an embedding lookup; gradients are
/// scatter-added back into the embedding table at the looked-up indices.
#[derive(Debug)]
pub struct EmbeddingBackward {
    pub input_version: VersionSnapshot,
    pub indices_storage: StorageHandle,
    pub indices_layout: Layout,
    pub vocab_size: usize,
    pub embed_dim: usize,
    /// Total number of index lookups performed in the forward pass.
    pub total_lookups: usize,
}
327
/// Saved state for the backward pass of sigmoid; the forward *output* is
/// saved (d/dx sigmoid = y * (1 - y)).
#[derive(Debug)]
pub struct SigmoidBackward {
    pub output_storage: StorageHandle,
    pub output_layout: Layout,
    pub input_version: VersionSnapshot,
}
336
/// Saved state for the backward pass of tanh; the forward *output* is
/// saved (d/dx tanh = 1 - y^2).
#[derive(Debug)]
pub struct TanhBackward {
    pub output_storage: StorageHandle,
    pub output_layout: Layout,
    pub input_version: VersionSnapshot,
}
345
/// Saved state for the backward pass of GELU; unlike sigmoid/tanh the
/// forward *input* is saved, since GELU's derivative needs x itself.
#[derive(Debug)]
pub struct GeluBackward {
    pub input_storage: StorageHandle,
    pub input_layout: Layout,
    pub input_version: VersionSnapshot,
}
353
/// Saved state for the backward pass of leaky ReLU.
#[derive(Debug)]
pub struct LeakyReluBackward {
    pub input_storage: StorageHandle,
    pub input_layout: Layout,
    pub input_version: VersionSnapshot,
    /// Negative-side slope used in the forward pass.
    pub alpha: f32,
}
362
/// Saved state for the backward pass of broadcasting addition. A `None`
/// broadcast info means that operand already had the output shape.
#[derive(Debug)]
pub struct BroadcastAddBackward {
    pub lhs_version: VersionSnapshot,
    pub rhs_version: VersionSnapshot,
    /// How the lhs was expanded, if it was broadcast.
    pub lhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    /// How the rhs was expanded, if it was broadcast.
    pub rhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    pub output_shape: Vec<usize>,
}
375
/// Saved state for the backward pass of broadcasting subtraction; mirrors
/// `BroadcastAddBackward`.
#[derive(Debug)]
pub struct BroadcastSubBackward {
    pub lhs_version: VersionSnapshot,
    pub rhs_version: VersionSnapshot,
    pub lhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    pub rhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    pub output_shape: Vec<usize>,
}
384
/// Saved state for the backward pass of broadcasting multiplication; saves
/// operand values (like `MulBackward`) plus the broadcast geometry.
#[derive(Debug)]
pub struct BroadcastMulBackward {
    pub lhs_storage: StorageHandle,
    pub lhs_layout: Layout,
    pub lhs_version: VersionSnapshot,
    pub rhs_storage: StorageHandle,
    pub rhs_layout: Layout,
    pub rhs_version: VersionSnapshot,
    pub lhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    pub rhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    pub output_shape: Vec<usize>,
}
397
398#[derive(Debug)]
403pub struct BatchNorm2dBackward {
404 pub input_storage: StorageHandle,
405 pub input_layout: Layout,
406 pub input_version: VersionSnapshot,
407 pub weight_storage: StorageHandle,
408 pub weight_layout: Layout,
409 pub weight_version: VersionSnapshot,
410 pub save_storage: StorageHandle, pub save_layout: Layout,
412 pub batch: usize,
413 pub channels: usize,
414 pub height: usize,
415 pub width: usize,
416}
417
/// Saved state for the backward pass of adaptive average pooling; only the
/// input/output geometry is needed to spread gradients back.
#[derive(Debug)]
pub struct AdaptiveAvgPool2dBackward {
    pub input_version: VersionSnapshot,
    pub batch: usize,
    pub channels: usize,
    pub h_in: usize,
    pub w_in: usize,
    pub h_out: usize,
    pub w_out: usize,
}
432
/// Saved state for the backward pass of a dtype cast; the gradient is cast
/// back to the source dtype.
#[derive(Debug)]
pub struct CastBackward {
    pub input_version: VersionSnapshot,
    /// Dtype of the tensor before the cast.
    pub source_dtype: crate::tensor::DType,
}
442
/// Every backward operation the autograd engine can replay. Each variant
/// carries the saved state its gradient computation requires; version
/// snapshots inside let the engine detect storages mutated between the
/// forward record and the backward pass.
#[derive(Debug)]
pub enum BackwardOp {
    Add(AddBackward),
    Sub(SubBackward),
    Mul(MulBackward),
    Matmul(MatmulBackward),
    Relu(ReluBackward),
    MseLoss(MseLossBackward),
    AddBias(AddBiasBackward),
    Im2Col(Im2ColBackward),
    Stack(StackBackward),
    AddChannelBias(AddChannelBiasBackward),
    SliceBatch(SliceBatchBackward),
    MaxPool2d(MaxPool2dBackward),
    Flatten(FlattenBackward),
    Reshape(ReshapeBackward),
    Dropout(DropoutBackward),
    CrossEntropy(CrossEntropyBackward),
    Sigmoid(SigmoidBackward),
    Tanh(TanhBackward),
    Gelu(GeluBackward),
    LeakyRelu(LeakyReluBackward),
    Transpose(TransposeBackward),
    Bmm(BmmBackward),
    Softmax(SoftmaxBackward),
    LayerNorm(LayerNormBackward),
    Embedding(EmbeddingBackward),
    BroadcastAdd(BroadcastAddBackward),
    BroadcastSub(BroadcastSubBackward),
    BroadcastMul(BroadcastMulBackward),
    BatchNorm2d(BatchNorm2dBackward),
    AdaptiveAvgPool2d(AdaptiveAvgPool2dBackward),
    Cast(CastBackward),
    SliceRange(SliceRangeBackward),
    Cat(CatBackward),
    /// Only present when built with the `multi_gpu` feature.
    #[cfg(feature = "multi_gpu")]
    FsdpLinear(FsdpLinearBackward),
    /// User-supplied backward via the `CustomBackward` trait.
    Custom(CustomBackwardOp),
}
493
/// Hook for user-defined backward functions, stored behind an `Arc` in
/// `CustomBackwardOp`. `Send + Sync` is required so `BackwardOp` remains
/// thread-safe (asserted at the bottom of this file).
pub trait CustomBackward: Send + Sync + std::fmt::Debug {
    /// Compute input gradients from the output gradient and the tensors
    /// saved at forward time. The length and ordering of the returned
    /// `Vec` is a contract with the registrant — presumably one gradient
    /// per differentiable input; TODO confirm against the caller.
    fn backward(&self, out_grad: &Tensor, saved: &[Tensor]) -> Vec<Tensor>;
}
508
/// Recorded state for a user-defined backward op. `Debug` is implemented
/// manually below rather than derived.
pub struct CustomBackwardOp {
    /// Shared handler that performs the actual gradient computation.
    pub handler: Arc<dyn CustomBackward>,
    /// One snapshot per tracked input.
    pub input_versions: Vec<VersionSnapshot>,
    // Tensors saved at forward time, decomposed into storage / layout /
    // shape — presumably parallel vectors of equal length; TODO confirm.
    pub saved_storages: Vec<StorageHandle>,
    pub saved_layouts: Vec<Layout>,
    pub saved_shapes: Vec<Vec<usize>>,
}
520
521impl std::fmt::Debug for CustomBackwardOp {
522 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
523 f.debug_struct("CustomBackwardOp")
524 .field("handler", &self.handler)
525 .field("num_saved", &self.saved_storages.len())
526 .finish()
527 }
528}
529
/// Rendezvous point shared by the ranks of an FSDP linear layer for
/// exchanging gradients during backward (see `FsdpLinearBackward::sync`).
#[cfg(feature = "multi_gpu")]
pub struct FsdpSync {
    /// Number of participating ranks.
    pub world_size: usize,
    /// Shared mutable rendezvous state, guarded for cross-thread access.
    pub state: std::sync::Mutex<FsdpSyncState>,
    // Condvar paired with `state` — presumably used to block until all
    // ranks have contributed / results are published; TODO confirm.
    pub cvar: std::sync::Condvar,
}
545
/// Mutable state behind `FsdpSync::state`.
#[cfg(feature = "multi_gpu")]
pub struct FsdpSyncState {
    /// Per-rank weight-gradient contributions (flat f32 buffers).
    pub weight_grads: Vec<Vec<f32>>,
    /// Per-rank bias-gradient contributions.
    pub bias_grads: Vec<Vec<f32>>,
    /// Combined weight gradient once available (`None` until then).
    pub weight_result: Option<Vec<f32>>,
    /// Combined bias gradient once available.
    pub bias_result: Option<Vec<f32>>,
    // Count of ranks that have read the published results — presumably
    // used to reset the rendezvous for reuse; TODO confirm with consumer.
    pub read_count: usize,
}
556
#[cfg(feature = "multi_gpu")]
impl FsdpSync {
    /// Create a rendezvous for `world_size` ranks with empty gradient
    /// buffers, no published results, and a zeroed read counter.
    pub fn new(world_size: usize) -> Self {
        let initial = FsdpSyncState {
            weight_grads: Vec::new(),
            bias_grads: Vec::new(),
            weight_result: None,
            bias_result: None,
            read_count: 0,
        };
        Self {
            world_size,
            state: std::sync::Mutex::new(initial),
            cvar: std::sync::Condvar::new(),
        }
    }
}
573
#[cfg(feature = "multi_gpu")]
impl std::fmt::Debug for FsdpSync {
    /// Manual impl: `FsdpSyncState` has no `Debug` derive, so deriving on
    /// `FsdpSync` is unavailable; report only the world size.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("FsdpSync");
        builder.field("world_size", &self.world_size);
        builder.finish()
    }
}
582
// `FsdpSync` is `Send + Sync` automatically: every field (`usize`,
// `Mutex<FsdpSyncState>`, `Condvar`) already implements both auto traits,
// and `FsdpSyncState` holds only `Vec`/`Option`/`usize` data. The previous
// `unsafe impl Send`/`unsafe impl Sync` were therefore redundant — and
// worse, they would have silently overridden the compiler if a
// non-thread-safe field were ever added. The compile-time assertion below
// keeps the guarantee explicit without any `unsafe`.
#[cfg(feature = "multi_gpu")]
const _: () = {
    fn _assert_send_sync<T: Send + Sync>() {}
    fn _fsdp_sync_is_thread_safe() {
        _assert_send_sync::<FsdpSync>();
    }
};
588
/// Saved state for the backward pass of a fully-sharded (FSDP) linear
/// layer. Each rank holds only a shard of the weight (and optionally the
/// bias); the `sync` rendezvous combines per-rank gradients.
#[cfg(feature = "multi_gpu")]
#[derive(Debug)]
pub struct FsdpLinearBackward {
    pub input_version: VersionSnapshot,
    pub input_storage: StorageHandle,
    pub input_layout: Layout,
    /// Weight shards — presumably one entry per rank; TODO confirm.
    pub weight_shard_storages: Vec<StorageHandle>,
    pub weight_shard_layouts: Vec<Layout>,
    /// Shape of the unsharded weight.
    pub full_weight_shape: Vec<usize>,
    /// Element count of this rank's weight shard.
    pub shard_size: usize,
    /// Offset of this rank's shard within the flattened full weight.
    pub weight_shard_offset: usize,
    /// This rank's index within the group.
    pub rank: usize,
    /// Total number of ranks.
    pub world_size: usize,
    pub device_index: usize,
    /// Whether the layer has a bias; the bias fields below are only
    /// meaningful when true.
    pub has_bias: bool,
    pub bias_shard_storages: Vec<StorageHandle>,
    pub full_bias_shape: Vec<usize>,
    pub bias_shard_offset: usize,
    pub bias_shard_size: usize,
    /// Shared rendezvous used to combine gradients across ranks.
    pub sync: std::sync::Arc<FsdpSync>,
}
625
/// Saved state for the backward pass of slicing `start..end` along `dim`;
/// the gradient is scattered back into a zero tensor of `original_shape`.
#[derive(Debug)]
pub struct SliceRangeBackward {
    pub input_version: VersionSnapshot,
    pub original_shape: Vec<usize>,
    pub dim: usize,
    pub start: usize,
    /// Exclusive end of the sliced range — TODO confirm exclusivity.
    pub end: usize,
}
635
636#[derive(Debug)]
638pub struct CatBackward {
639 pub splits: Vec<usize>, pub dim: usize,
641 pub versions: Vec<VersionSnapshot>,
642}
643
644const _: () = {
645 fn _assert_send<T: Send>() {}
646 fn _assert_sync<T: Sync>() {}
647 fn _assertions() {
648 _assert_send::<BackwardOp>();
649 _assert_sync::<BackwardOp>();
650 }
651};