use std::sync::Arc;
use crate::autograd::AutogradError;
use crate::tensor::{GradId, Layout, StorageHandle, Tensor, WeakStorageHandle};
/// Record of a storage's version counter taken at one point in time.
///
/// `VersionSnapshot::new` stores `storage.version()` together with a
/// downgraded (weak) handle to the same storage; `VersionSnapshot::check`
/// later compares the live version against the recorded one.
#[derive(Debug, Clone)]
pub struct VersionSnapshot {
/// Identifier reported in `AutogradError::VersionMismatch` when `check` fails.
pub grad_id: GradId,
/// Weak handle to the tracked storage, produced by `StorageHandle::downgrade`.
pub weak_storage: WeakStorageHandle,
/// Value returned by `storage.version()` when the snapshot was created.
pub recorded_version: usize,
}
impl VersionSnapshot {
    /// Captures the current version of `storage` and keeps only a weak
    /// handle to it, tagged with `grad_id`.
    pub fn new(grad_id: GradId, storage: &StorageHandle) -> Self {
        // Read the version first, then downgrade, mirroring the order in
        // which the two calls have always been made.
        let recorded_version = storage.version();
        let weak_storage = storage.downgrade();
        Self {
            grad_id,
            weak_storage,
            recorded_version,
        }
    }

    /// Verifies that the tracked storage still has the recorded version.
    ///
    /// # Errors
    ///
    /// Returns `AutogradError::VersionMismatch` (carrying this snapshot's
    /// `grad_id`, the recorded version as `expected` and the live version as
    /// `found`) when the storage can still be upgraded and its version
    /// differs from the recorded one.
    ///
    /// If the weak handle can no longer be upgraded the check succeeds.
    pub fn check(&self) -> Result<(), AutogradError> {
        if let Some(live) = self.weak_storage.upgrade() {
            let found = live.version();
            if found != self.recorded_version {
                return Err(AutogradError::VersionMismatch {
                    grad_id: self.grad_id,
                    expected: self.recorded_version,
                    found,
                });
            }
        }
        Ok(())
    }
}
/// Data carried by `BackwardOp::Add`.
///
/// Holds one version snapshot per input (`lhs`, `rhs`); no storage handles
/// are saved.
#[derive(Debug)]
pub struct AddBackward {
pub lhs_version: VersionSnapshot,
pub rhs_version: VersionSnapshot,
}
/// Data carried by `BackwardOp::Sub`.
///
/// Holds one version snapshot per input (`lhs`, `rhs`); no storage handles
/// are saved.
#[derive(Debug)]
pub struct SubBackward {
pub lhs_version: VersionSnapshot,
pub rhs_version: VersionSnapshot,
}
/// Data carried by `BackwardOp::Mul`.
///
/// For each of `lhs` and `rhs` it keeps a strong storage handle, a layout
/// and a version snapshot.
// NOTE(review): presumably the saved operands are read back when computing
// the gradients — confirm in the backward implementation.
#[derive(Debug)]
pub struct MulBackward {
pub lhs_storage: StorageHandle,
pub lhs_layout: Layout,
pub lhs_version: VersionSnapshot,
pub rhs_storage: StorageHandle,
pub rhs_layout: Layout,
pub rhs_version: VersionSnapshot,
}
/// Data carried by `BackwardOp::Matmul`.
///
/// For each of `lhs` and `rhs` it keeps a strong storage handle, a layout
/// and a version snapshot, plus three dimensions `m`, `k`, `n`.
// NOTE(review): presumably lhs is (m, k) and rhs is (k, n) — TODO confirm
// against the forward op.
#[derive(Debug)]
pub struct MatmulBackward {
pub lhs_storage: StorageHandle,
pub lhs_layout: Layout,
pub lhs_version: VersionSnapshot,
pub rhs_storage: StorageHandle,
pub rhs_layout: Layout,
pub rhs_version: VersionSnapshot,
pub m: usize,
pub k: usize,
pub n: usize,
}
/// Data carried by `BackwardOp::Relu`.
///
/// Keeps a strong storage handle, layout and version snapshot for the input.
#[derive(Debug)]
pub struct ReluBackward {
pub input_storage: StorageHandle,
pub input_layout: Layout,
pub input_version: VersionSnapshot,
}
/// Data carried by `BackwardOp::MseLoss`.
///
/// Keeps a strong storage handle, layout and version snapshot for both
/// `pred` and `target`, plus an element count.
#[derive(Debug)]
pub struct MseLossBackward {
pub pred_storage: StorageHandle,
pub pred_layout: Layout,
pub pred_version: VersionSnapshot,
pub target_storage: StorageHandle,
pub target_layout: Layout,
pub target_version: VersionSnapshot,
// NOTE(review): presumably the number of elements the loss was averaged
// over — confirm against the forward op.
pub numel: usize,
}
/// Data carried by `BackwardOp::AddBias`.
///
/// Holds version snapshots for the input and the bias plus two dimensions.
// NOTE(review): presumably the input is (m, n) and the bias has n entries —
// TODO confirm.
#[derive(Debug)]
pub struct AddBiasBackward {
pub input_version: VersionSnapshot,
pub bias_version: VersionSnapshot,
pub m: usize,
pub n: usize,
}
/// Data carried by `BackwardOp::SliceBatch`.
///
/// Holds a version snapshot of the input, a shape named `original_shape`
/// and an index.
// NOTE(review): presumably `index` selects along the leading (batch)
// dimension of `original_shape` — confirm.
#[derive(Debug)]
pub struct SliceBatchBackward {
pub input_version: VersionSnapshot,
pub original_shape: Vec<usize>,
pub index: usize,
}
/// Data carried by `BackwardOp::Im2Col`.
///
/// Holds a version snapshot of the input together with the scalar
/// parameters recorded for the op (channel count, spatial sizes, kernel
/// size, stride, padding and output spatial sizes).
// NOTE(review): a single `kernel_size`/`stride`/`padding` suggests square
// kernels with symmetric stride and padding — confirm.
#[derive(Debug)]
pub struct Im2ColBackward {
pub input_version: VersionSnapshot,
pub c_in: usize,
pub h: usize,
pub w: usize,
pub kernel_size: usize,
pub stride: usize,
pub padding: usize,
pub out_h: usize,
pub out_w: usize,
}
/// Data carried by `BackwardOp::Stack`.
///
/// Holds a count, a single shape and a list of version snapshots.
// NOTE(review): presumably `versions.len() == count` and every stacked input
// has shape `each_shape` — TODO confirm at the construction site.
#[derive(Debug)]
pub struct StackBackward {
pub count: usize,
pub each_shape: Vec<usize>,
pub versions: Vec<VersionSnapshot>,
}
/// Data carried by `BackwardOp::AddChannelBias`.
///
/// Holds version snapshots for the input and the bias plus two sizes.
// NOTE(review): presumably `spatial` is the number of elements per channel
// (e.g. height * width) — confirm.
#[derive(Debug)]
pub struct AddChannelBiasBackward {
pub input_version: VersionSnapshot,
pub bias_version: VersionSnapshot,
pub channels: usize,
pub spatial: usize,
}
/// Data carried by `BackwardOp::MaxPool2d`.
///
/// Holds a version snapshot of the input, a strong storage handle and layout
/// for an `indices` buffer, and the input/output spatial sizes.
// NOTE(review): `indices_*` presumably holds the arg-max positions chosen in
// the forward pass — confirm.
#[derive(Debug)]
pub struct MaxPool2dBackward {
pub input_version: VersionSnapshot,
pub indices_storage: StorageHandle,
pub indices_layout: Layout,
pub channels: usize,
pub h: usize,
pub w: usize,
pub out_h: usize,
pub out_w: usize,
}
/// Data carried by `BackwardOp::Reshape`.
///
/// Holds a version snapshot of the input and a shape named `original_shape`.
// NOTE(review): presumably the input's shape before the reshape — confirm.
#[derive(Debug)]
pub struct ReshapeBackward {
pub input_version: VersionSnapshot,
pub original_shape: Vec<usize>,
}
/// Data carried by `BackwardOp::Flatten`.
///
/// Holds a version snapshot of the input and a shape named `original_shape`.
// NOTE(review): presumably the input's shape before flattening — confirm.
#[derive(Debug)]
pub struct FlattenBackward {
pub input_version: VersionSnapshot,
pub original_shape: Vec<usize>,
}
/// Data carried by `BackwardOp::CrossEntropy`.
///
/// Holds a version snapshot of the input plus a strong storage handle and
/// layout named `grad_*`.
// NOTE(review): `grad_*` presumably holds a gradient precomputed during the
// forward pass — TODO confirm.
#[derive(Debug)]
pub struct CrossEntropyBackward {
pub input_version: VersionSnapshot,
pub grad_storage: StorageHandle,
pub grad_layout: Layout,
}
/// Data carried by `BackwardOp::Dropout`.
///
/// Holds a version snapshot of the input plus a strong storage handle and
/// layout for a mask buffer.
// NOTE(review): presumably the same mask that was applied in the forward
// pass — confirm.
#[derive(Debug)]
pub struct DropoutBackward {
pub input_version: VersionSnapshot,
pub mask_storage: StorageHandle,
pub mask_layout: Layout,
}
/// Data carried by `BackwardOp::Transpose`.
///
/// Holds a version snapshot of the input and two dimension indices.
// NOTE(review): presumably the two axes that were swapped — confirm.
#[derive(Debug)]
pub struct TransposeBackward {
pub input_version: VersionSnapshot,
pub dim0: usize,
pub dim1: usize,
}
/// Data carried by `BackwardOp::Bmm`.
///
/// For each of `lhs` and `rhs` it keeps a strong storage handle, a layout
/// and a version snapshot, plus the sizes `batch`, `m`, `k`, `n`.
// NOTE(review): presumably lhs is (batch, m, k) and rhs is (batch, k, n) —
// TODO confirm against the forward op.
#[derive(Debug)]
pub struct BmmBackward {
pub lhs_storage: StorageHandle,
pub lhs_layout: Layout,
pub lhs_version: VersionSnapshot,
pub rhs_storage: StorageHandle,
pub rhs_layout: Layout,
pub rhs_version: VersionSnapshot,
pub batch: usize,
pub m: usize,
pub k: usize,
pub n: usize,
}
/// Data carried by `BackwardOp::Softmax`.
///
/// Keeps a strong storage handle and layout named `output_*`, a version
/// snapshot of the input, and the row count / row length.
// NOTE(review): `output_*` presumably refers to the forward result rather
// than the input — confirm.
#[derive(Debug)]
pub struct SoftmaxBackward {
pub output_storage: StorageHandle,
pub output_layout: Layout,
pub input_version: VersionSnapshot,
pub num_rows: usize,
pub row_size: usize,
}
/// Data carried by `BackwardOp::LayerNorm`.
///
/// Keeps a strong storage handle, layout and version snapshot for both the
/// input and the weight, an extra `save_*` storage/layout pair, and two sizes.
#[derive(Debug)]
pub struct LayerNormBackward {
pub input_storage: StorageHandle,
pub input_layout: Layout,
pub input_version: VersionSnapshot,
pub weight_storage: StorageHandle,
pub weight_layout: Layout,
pub weight_version: VersionSnapshot,
// NOTE(review): presumably statistics saved by the forward pass (e.g. mean
// and inverse standard deviation) — TODO confirm contents and layout.
pub save_storage: StorageHandle, pub save_layout: Layout,
pub num_instances: usize,
pub norm_size: usize,
}
/// Data carried by `BackwardOp::Embedding`.
///
/// Holds a version snapshot named `input_version`, a strong storage handle
/// and layout for an indices buffer, and three sizes.
// NOTE(review): presumably the table is (vocab_size, embed_dim) and
// `total_lookups` is the number of indices — TODO confirm.
#[derive(Debug)]
pub struct EmbeddingBackward {
pub input_version: VersionSnapshot,
pub indices_storage: StorageHandle,
pub indices_layout: Layout,
pub vocab_size: usize,
pub embed_dim: usize,
pub total_lookups: usize,
}
/// Data carried by `BackwardOp::Sigmoid`.
///
/// Keeps a strong storage handle and layout named `output_*` and a version
/// snapshot of the input.
// NOTE(review): `output_*` presumably refers to the forward result — confirm.
#[derive(Debug)]
pub struct SigmoidBackward {
pub output_storage: StorageHandle,
pub output_layout: Layout,
pub input_version: VersionSnapshot,
}
/// Data carried by `BackwardOp::Tanh`.
///
/// Keeps a strong storage handle and layout named `output_*` and a version
/// snapshot of the input.
// NOTE(review): `output_*` presumably refers to the forward result — confirm.
#[derive(Debug)]
pub struct TanhBackward {
pub output_storage: StorageHandle,
pub output_layout: Layout,
pub input_version: VersionSnapshot,
}
/// Data carried by `BackwardOp::Gelu`.
///
/// Keeps a strong storage handle, layout and version snapshot for the input.
#[derive(Debug)]
pub struct GeluBackward {
pub input_storage: StorageHandle,
pub input_layout: Layout,
pub input_version: VersionSnapshot,
}
/// Data carried by `BackwardOp::LeakyRelu`.
///
/// Keeps a strong storage handle, layout and version snapshot for the input,
/// plus the `alpha` parameter of the op.
#[derive(Debug)]
pub struct LeakyReluBackward {
pub input_storage: StorageHandle,
pub input_layout: Layout,
pub input_version: VersionSnapshot,
// NOTE(review): presumably the slope applied to negative inputs — confirm.
pub alpha: f32,
}
/// Data carried by `BackwardOp::BroadcastAdd`.
///
/// Holds version snapshots for `lhs` and `rhs`, an optional `BroadcastInfo`
/// per operand, and the output shape. No storage handles are saved.
// NOTE(review): presumably a `None` broadcast entry means that operand
// already had the output shape — TODO confirm.
#[derive(Debug)]
pub struct BroadcastAddBackward {
pub lhs_version: VersionSnapshot,
pub rhs_version: VersionSnapshot,
pub lhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
pub rhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
pub output_shape: Vec<usize>,
}
/// Data carried by `BackwardOp::BroadcastSub`.
///
/// Holds version snapshots for `lhs` and `rhs`, an optional `BroadcastInfo`
/// per operand, and the output shape. No storage handles are saved.
// NOTE(review): presumably a `None` broadcast entry means that operand
// already had the output shape — TODO confirm.
#[derive(Debug)]
pub struct BroadcastSubBackward {
pub lhs_version: VersionSnapshot,
pub rhs_version: VersionSnapshot,
pub lhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
pub rhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
pub output_shape: Vec<usize>,
}
/// Data carried by `BackwardOp::BroadcastMul`.
///
/// For each of `lhs` and `rhs` it keeps a strong storage handle, a layout,
/// a version snapshot and an optional `BroadcastInfo`; it also records the
/// output shape.
// NOTE(review): presumably a `None` broadcast entry means that operand
// already had the output shape — TODO confirm.
#[derive(Debug)]
pub struct BroadcastMulBackward {
pub lhs_storage: StorageHandle,
pub lhs_layout: Layout,
pub lhs_version: VersionSnapshot,
pub rhs_storage: StorageHandle,
pub rhs_layout: Layout,
pub rhs_version: VersionSnapshot,
pub lhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
pub rhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
pub output_shape: Vec<usize>,
}
/// Data carried by `BackwardOp::BatchNorm2d`.
///
/// Keeps a strong storage handle, layout and version snapshot for both the
/// input and the weight, an extra `save_*` storage/layout pair, and the four
/// sizes `batch`, `channels`, `height`, `width`.
#[derive(Debug)]
pub struct BatchNorm2dBackward {
pub input_storage: StorageHandle,
pub input_layout: Layout,
pub input_version: VersionSnapshot,
pub weight_storage: StorageHandle,
pub weight_layout: Layout,
pub weight_version: VersionSnapshot,
// NOTE(review): presumably per-channel statistics saved by the forward pass
// — TODO confirm contents and layout.
pub save_storage: StorageHandle, pub save_layout: Layout,
pub batch: usize,
pub channels: usize,
pub height: usize,
pub width: usize,
}
/// Data carried by `BackwardOp::AdaptiveAvgPool2d`.
///
/// Holds a version snapshot of the input plus the batch size, channel count
/// and the input/output spatial sizes.
#[derive(Debug)]
pub struct AdaptiveAvgPool2dBackward {
pub input_version: VersionSnapshot,
pub batch: usize,
pub channels: usize,
pub h_in: usize,
pub w_in: usize,
pub h_out: usize,
pub w_out: usize,
}
/// Data carried by `BackwardOp::Cast`.
///
/// Holds a version snapshot of the input and a `DType` named `source_dtype`.
// NOTE(review): presumably the dtype the input had before the cast, so the
// gradient can be converted back — confirm.
#[derive(Debug)]
pub struct CastBackward {
pub input_version: VersionSnapshot,
pub source_dtype: crate::tensor::DType,
}
/// Tagged union over every backward-op payload defined in this module.
///
/// Each variant wraps the struct of the matching name (`Add` wraps
/// `AddBackward`, and so on). A compile-time assertion at the bottom of this
/// file requires the enum to be `Send + Sync`.
#[derive(Debug)]
pub enum BackwardOp {
Add(AddBackward),
Sub(SubBackward),
Mul(MulBackward),
Matmul(MatmulBackward),
Relu(ReluBackward),
MseLoss(MseLossBackward),
AddBias(AddBiasBackward),
Im2Col(Im2ColBackward),
Stack(StackBackward),
AddChannelBias(AddChannelBiasBackward),
SliceBatch(SliceBatchBackward),
MaxPool2d(MaxPool2dBackward),
Flatten(FlattenBackward),
Reshape(ReshapeBackward),
Dropout(DropoutBackward),
CrossEntropy(CrossEntropyBackward),
Sigmoid(SigmoidBackward),
Tanh(TanhBackward),
Gelu(GeluBackward),
LeakyRelu(LeakyReluBackward),
Transpose(TransposeBackward),
Bmm(BmmBackward),
Softmax(SoftmaxBackward),
LayerNorm(LayerNormBackward),
Embedding(EmbeddingBackward),
BroadcastAdd(BroadcastAddBackward),
BroadcastSub(BroadcastSubBackward),
BroadcastMul(BroadcastMulBackward),
BatchNorm2d(BatchNorm2dBackward),
AdaptiveAvgPool2d(AdaptiveAvgPool2dBackward),
Cast(CastBackward),
SliceRange(SliceRangeBackward),
Cat(CatBackward),
// Only present when the crate is built with the `multi_gpu` feature.
#[cfg(feature = "multi_gpu")]
FsdpLinear(FsdpLinearBackward),
// User-supplied backward logic behind a `dyn CustomBackward` handler.
Custom(CustomBackwardOp),
}
/// Interface for user-defined backward logic stored in `CustomBackwardOp`.
///
/// Implementors must be `Send + Sync + Debug`, which lets the handler sit
/// behind an `Arc<dyn CustomBackward>`.
pub trait CustomBackward: Send + Sync + std::fmt::Debug {
/// Receives the output gradient and a slice of saved tensors and returns a
/// list of tensors.
// NOTE(review): presumably the result holds one gradient per input, in input
// order — TODO confirm against the code that invokes the handler.
fn backward(&self, out_grad: &Tensor, saved: &[Tensor]) -> Vec<Tensor>;
}
/// Data carried by `BackwardOp::Custom`.
///
/// Bundles a shared `CustomBackward` handler with version snapshots of the
/// inputs and the storages, layouts and shapes of the saved tensors.
/// `Debug` is implemented by hand for this type rather than derived.
// NOTE(review): `saved_storages`, `saved_layouts` and `saved_shapes` are
// presumably parallel vectors of equal length — confirm at the construction
// site.
pub struct CustomBackwardOp {
pub handler: Arc<dyn CustomBackward>,
pub input_versions: Vec<VersionSnapshot>,
pub saved_storages: Vec<StorageHandle>,
pub saved_layouts: Vec<Layout>,
pub saved_shapes: Vec<Vec<usize>>,
}
/// Hand-written `Debug`: prints the handler and the number of saved
/// storages instead of every field.
impl std::fmt::Debug for CustomBackwardOp {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let saved_count = self.saved_storages.len();
        let mut out = f.debug_struct("CustomBackwardOp");
        out.field("handler", &self.handler);
        out.field("num_saved", &saved_count);
        out.finish()
    }
}
/// Shared synchronisation object referenced by `FsdpLinearBackward::sync`
/// (only built with the `multi_gpu` feature).
///
/// Pairs a mutex-protected `FsdpSyncState` with a condition variable and
/// records the world size.
// NOTE(review): the Mutex + Condvar pairing suggests participants wait on
// `cvar` for `state` to change — confirm in the code that uses this type.
#[cfg(feature = "multi_gpu")]
pub struct FsdpSync {
pub world_size: usize,
pub state: std::sync::Mutex<FsdpSyncState>,
pub cvar: std::sync::Condvar,
}
/// Mutable state guarded by `FsdpSync::state`.
///
/// `FsdpSync::new` starts with empty gradient lists, no results and a read
/// count of zero.
// NOTE(review): presumably `weight_grads` / `bias_grads` collect one buffer
// per rank, `*_result` holds the combined value once available, and
// `read_count` tracks how many ranks have consumed it — TODO confirm.
#[cfg(feature = "multi_gpu")]
pub struct FsdpSyncState {
pub weight_grads: Vec<Vec<f32>>,
pub bias_grads: Vec<Vec<f32>>,
pub weight_result: Option<Vec<f32>>,
pub bias_result: Option<Vec<f32>>,
pub read_count: usize,
}
#[cfg(feature = "multi_gpu")]
impl FsdpSync {
    /// Creates a sync object for `world_size` participants.
    ///
    /// The guarded state starts with empty gradient lists, no results and a
    /// read count of zero.
    pub fn new(world_size: usize) -> Self {
        let initial_state = FsdpSyncState {
            weight_grads: Vec::new(),
            bias_grads: Vec::new(),
            weight_result: None,
            bias_result: None,
            read_count: 0,
        };
        Self {
            world_size,
            state: std::sync::Mutex::new(initial_state),
            cvar: std::sync::Condvar::new(),
        }
    }
}
/// Hand-written `Debug`: prints only `world_size`; the mutex-guarded state
/// and the condition variable are left out.
#[cfg(feature = "multi_gpu")]
impl std::fmt::Debug for FsdpSync {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("FsdpSync");
        out.field("world_size", &self.world_size);
        out.finish()
    }
}
// `FsdpSync` is `Send + Sync` through the auto traits alone: its fields are a
// `usize`, a `Mutex<FsdpSyncState>` (whose payload is plain `Vec`/`Option`/
// `usize` data, all `Send`) and a `Condvar`. No `unsafe impl` is required, and
// keeping one would silently paper over a non-thread-safe field added later.
// This assertion states the requirement and lets the compiler check it.
#[cfg(feature = "multi_gpu")]
const _: () = {
    fn _assert_send_sync<T: Send + Sync>() {}
    fn _fsdp_sync_is_thread_safe() {
        _assert_send_sync::<FsdpSync>();
    }
};
/// Data carried by `BackwardOp::FsdpLinear` (only built with the
/// `multi_gpu` feature).
///
/// Holds a version snapshot, storage handle and layout for the input; the
/// storages and layouts of the weight shards with the full weight shape,
/// shard size and shard offset; rank, world size and device index; optional
/// bias shard information; and a shared `FsdpSync`.
// NOTE(review): presumably `weight_shard_storages` has one entry per rank and
// the `bias_*` fields are only meaningful when `has_bias` is true — TODO
// confirm at the construction site.
#[cfg(feature = "multi_gpu")]
#[derive(Debug)]
pub struct FsdpLinearBackward {
pub input_version: VersionSnapshot,
pub input_storage: StorageHandle,
pub input_layout: Layout,
pub weight_shard_storages: Vec<StorageHandle>,
pub weight_shard_layouts: Vec<Layout>,
pub full_weight_shape: Vec<usize>,
pub shard_size: usize,
pub weight_shard_offset: usize,
pub rank: usize,
pub world_size: usize,
pub device_index: usize,
pub has_bias: bool,
pub bias_shard_storages: Vec<StorageHandle>,
pub full_bias_shape: Vec<usize>,
pub bias_shard_offset: usize,
pub bias_shard_size: usize,
pub sync: std::sync::Arc<FsdpSync>,
}
/// Data carried by `BackwardOp::SliceRange`.
///
/// Holds a version snapshot of the input, a shape named `original_shape`,
/// a dimension index and a `start`/`end` pair.
// NOTE(review): presumably the slice covers `start..end` (end exclusive)
// along `dim` of `original_shape` — TODO confirm.
#[derive(Debug)]
pub struct SliceRangeBackward {
pub input_version: VersionSnapshot,
pub original_shape: Vec<usize>,
pub dim: usize,
pub start: usize,
pub end: usize,
}
/// Data carried by `BackwardOp::Cat`.
///
/// Holds a list of split sizes, a dimension index and a list of version
/// snapshots.
// NOTE(review): presumably `splits[i]` is the extent of input `i` along `dim`
// and `versions` has one entry per input — TODO confirm.
#[derive(Debug)]
pub struct CatBackward {
pub splits: Vec<usize>, pub dim: usize,
pub versions: Vec<VersionSnapshot>,
}
const _: () = {
fn _assert_send<T: Send>() {}
fn _assert_sync<T: Sync>() {}
fn _assertions() {
_assert_send::<BackwardOp>();
_assert_sync::<BackwardOp>();
}
};