use crate::autograd::AutogradError;
use crate::tensor::{GradId, Layout, StorageHandle, WeakStorageHandle};
/// A storage version captured at one point in time, kept so it can later be
/// compared against the storage's then-current version via
/// [`VersionSnapshot::check`].
#[derive(Debug, Clone)]
pub struct VersionSnapshot {
    /// Identifier copied into `AutogradError::VersionMismatch` when the check
    /// detects a differing version.
    pub grad_id: GradId,
    /// Weak handle produced by `StorageHandle::downgrade`; it is upgraded
    /// again each time the snapshot is checked.
    pub weak_storage: WeakStorageHandle,
    /// The value `StorageHandle::version()` returned when the snapshot was
    /// created.
    pub recorded_version: usize,
}
impl VersionSnapshot {
    /// Captures the current version of `storage` together with a weak handle
    /// to it, tagged with `grad_id`.
    pub fn new(grad_id: GradId, storage: &StorageHandle) -> Self {
        // Read the version first, then downgrade — same order as the fields
        // are consumed below.
        let recorded_version = storage.version();
        let weak_storage = storage.downgrade();
        Self {
            grad_id,
            recorded_version,
            weak_storage,
        }
    }

    /// Compares the storage's current version with the recorded one.
    ///
    /// # Errors
    ///
    /// Returns `AutogradError::VersionMismatch` (carrying `grad_id`, the
    /// recorded version as `expected`, and the current version as `found`)
    /// when the two versions differ.
    ///
    /// If the weak handle can no longer be upgraded, there is nothing to
    /// compare against and the check succeeds.
    pub fn check(&self) -> Result<(), AutogradError> {
        let strong = match self.weak_storage.upgrade() {
            Some(handle) => handle,
            // Upgrade failed: treat as "no mismatch".
            None => return Ok(()),
        };

        let found = strong.version();
        if found == self.recorded_version {
            return Ok(());
        }

        Err(AutogradError::VersionMismatch {
            grad_id: self.grad_id,
            expected: self.recorded_version,
            found,
        })
    }
}
/// Payload carried by [`BackwardOp::Add`].
///
/// Only version snapshots are stored; no operand storage is kept.
/// NOTE(review): `lhs`/`rhs` presumably name the two operands of the forward
/// addition — confirm against the code that records this op.
#[derive(Debug)]
pub struct AddBackward {
    /// Version snapshot associated with `lhs`.
    pub lhs_version: VersionSnapshot,
    /// Version snapshot associated with `rhs`.
    pub rhs_version: VersionSnapshot,
}
/// Payload carried by [`BackwardOp::Sub`].
///
/// Only version snapshots are stored; no operand storage is kept.
/// NOTE(review): `lhs`/`rhs` presumably name the minuend and subtrahend of the
/// forward subtraction — confirm against the code that records this op.
#[derive(Debug)]
pub struct SubBackward {
    /// Version snapshot associated with `lhs`.
    pub lhs_version: VersionSnapshot,
    /// Version snapshot associated with `rhs`.
    pub rhs_version: VersionSnapshot,
}
/// Payload carried by [`BackwardOp::Mul`].
///
/// For each of the two operands it keeps a storage handle, a layout, and a
/// version snapshot.
/// NOTE(review): the storages are presumably the forward operands needed to
/// form the product-rule gradients — confirm against the backward kernel.
#[derive(Debug)]
pub struct MulBackward {
    /// Storage handle saved for `lhs`.
    pub lhs_storage: StorageHandle,
    /// Layout paired with `lhs_storage`.
    pub lhs_layout: Layout,
    /// Version snapshot associated with `lhs`.
    pub lhs_version: VersionSnapshot,
    /// Storage handle saved for `rhs`.
    pub rhs_storage: StorageHandle,
    /// Layout paired with `rhs_storage`.
    pub rhs_layout: Layout,
    /// Version snapshot associated with `rhs`.
    pub rhs_version: VersionSnapshot,
}
/// Payload carried by [`BackwardOp::Matmul`].
///
/// Keeps storage, layout, and a version snapshot for both operands, plus three
/// dimension sizes.
/// NOTE(review): presumably `lhs` is `m x k` and `rhs` is `k x n` — confirm
/// against the forward matmul.
#[derive(Debug)]
pub struct MatmulBackward {
    /// Storage handle saved for `lhs`.
    pub lhs_storage: StorageHandle,
    /// Layout paired with `lhs_storage`.
    pub lhs_layout: Layout,
    /// Version snapshot associated with `lhs`.
    pub lhs_version: VersionSnapshot,
    /// Storage handle saved for `rhs`.
    pub rhs_storage: StorageHandle,
    /// Layout paired with `rhs_storage`.
    pub rhs_layout: Layout,
    /// Version snapshot associated with `rhs`.
    pub rhs_version: VersionSnapshot,
    /// Dimension size `m` (see the review note on the struct).
    pub m: usize,
    /// Dimension size `k` (see the review note on the struct).
    pub k: usize,
    /// Dimension size `n` (see the review note on the struct).
    pub n: usize,
}
/// Payload carried by [`BackwardOp::Relu`].
///
/// Keeps the input's storage handle, layout, and version snapshot.
/// NOTE(review): the saved input is presumably used to decide where the
/// gradient passes through — confirm against the backward kernel.
#[derive(Debug)]
pub struct ReluBackward {
    /// Storage handle saved for the input.
    pub input_storage: StorageHandle,
    /// Layout paired with `input_storage`.
    pub input_layout: Layout,
    /// Version snapshot associated with the input.
    pub input_version: VersionSnapshot,
}
/// Payload carried by [`BackwardOp::MseLoss`].
///
/// Keeps storage, layout, and a version snapshot for both the prediction and
/// the target, plus an element count.
#[derive(Debug)]
pub struct MseLossBackward {
    /// Storage handle saved for the prediction.
    pub pred_storage: StorageHandle,
    /// Layout paired with `pred_storage`.
    pub pred_layout: Layout,
    /// Version snapshot associated with the prediction.
    pub pred_version: VersionSnapshot,
    /// Storage handle saved for the target.
    pub target_storage: StorageHandle,
    /// Layout paired with `target_storage`.
    pub target_layout: Layout,
    /// Version snapshot associated with the target.
    pub target_version: VersionSnapshot,
    /// Element count.
    /// NOTE(review): presumably the number of elements the loss is averaged
    /// over — confirm against the forward pass.
    pub numel: usize,
}
/// Payload carried by [`BackwardOp::AddBias`].
///
/// Stores version snapshots for the input and the bias, plus two sizes.
/// NOTE(review): presumably the input is `m x n` and the bias has `n`
/// elements — confirm against the forward pass.
#[derive(Debug)]
pub struct AddBiasBackward {
    /// Version snapshot associated with the input.
    pub input_version: VersionSnapshot,
    /// Version snapshot associated with the bias.
    pub bias_version: VersionSnapshot,
    /// Size `m` (see the review note on the struct).
    pub m: usize,
    /// Size `n` (see the review note on the struct).
    pub n: usize,
}
/// Payload carried by [`BackwardOp::SliceBatch`].
///
/// Stores the input's version snapshot, a shape, and an index.
/// NOTE(review): presumably `index` selects one batch entry out of a tensor of
/// shape `original_shape` — confirm against the forward pass.
#[derive(Debug)]
pub struct SliceBatchBackward {
    /// Version snapshot associated with the input.
    pub input_version: VersionSnapshot,
    /// Shape recorded for the un-sliced input.
    pub original_shape: Vec<usize>,
    /// Index recorded by the slice operation.
    pub index: usize,
}
/// Payload carried by [`BackwardOp::Im2Col`].
///
/// Stores the input's version snapshot and the geometry parameters recorded
/// for the operation.
/// NOTE(review): field meanings below follow the usual im2col naming
/// (channels / height / width / square kernel) — confirm against the forward
/// implementation.
#[derive(Debug)]
pub struct Im2ColBackward {
    /// Version snapshot associated with the input.
    pub input_version: VersionSnapshot,
    /// Presumably the number of input channels.
    pub c_in: usize,
    /// Presumably the input height.
    pub h: usize,
    /// Presumably the input width.
    pub w: usize,
    /// Kernel size (a single value; presumably a square kernel).
    pub kernel_size: usize,
    /// Stride value recorded for the operation.
    pub stride: usize,
    /// Padding value recorded for the operation.
    pub padding: usize,
    /// Presumably the output height.
    pub out_h: usize,
    /// Presumably the output width.
    pub out_w: usize,
}
/// Payload carried by [`BackwardOp::Stack`].
///
/// Stores a count, one shape, and a list of version snapshots.
/// NOTE(review): presumably `count` inputs of identical shape `each_shape`
/// were stacked, with one snapshot per input — confirm against the forward
/// pass.
#[derive(Debug)]
pub struct StackBackward {
    /// Count recorded for the stack operation.
    pub count: usize,
    /// Shape recorded for each stacked element.
    pub each_shape: Vec<usize>,
    /// Version snapshots recorded for the operation.
    pub versions: Vec<VersionSnapshot>,
}
/// Payload carried by [`BackwardOp::AddChannelBias`].
///
/// Stores version snapshots for the input and the bias, plus two sizes.
/// NOTE(review): presumably the bias has `channels` elements, each applied
/// across `spatial` positions — confirm against the forward pass.
#[derive(Debug)]
pub struct AddChannelBiasBackward {
    /// Version snapshot associated with the input.
    pub input_version: VersionSnapshot,
    /// Version snapshot associated with the bias.
    pub bias_version: VersionSnapshot,
    /// Channel count recorded for the operation.
    pub channels: usize,
    /// Spatial size recorded for the operation.
    pub spatial: usize,
}
/// Payload carried by [`BackwardOp::MaxPool2d`].
///
/// Stores the input's version snapshot, a storage handle and layout for an
/// index buffer, and the sizes recorded for the operation.
/// NOTE(review): the index buffer presumably holds the positions selected by
/// the forward max — confirm against the forward pass.
#[derive(Debug)]
pub struct MaxPool2dBackward {
    /// Version snapshot associated with the input.
    pub input_version: VersionSnapshot,
    /// Storage handle for the saved indices.
    pub indices_storage: StorageHandle,
    /// Layout paired with `indices_storage`.
    pub indices_layout: Layout,
    /// Channel count recorded for the operation.
    pub channels: usize,
    /// Presumably the input height.
    pub h: usize,
    /// Presumably the input width.
    pub w: usize,
    /// Presumably the output height.
    pub out_h: usize,
    /// Presumably the output width.
    pub out_w: usize,
}
/// Payload carried by [`BackwardOp::Reshape`].
///
/// Stores the input's version snapshot and a shape.
/// NOTE(review): `original_shape` is presumably the pre-reshape shape that the
/// gradient is restored to — confirm against the backward pass.
#[derive(Debug)]
pub struct ReshapeBackward {
    /// Version snapshot associated with the input.
    pub input_version: VersionSnapshot,
    /// Shape recorded for the input before the operation.
    pub original_shape: Vec<usize>,
}
/// Payload carried by [`BackwardOp::Flatten`].
///
/// Stores the input's version snapshot and a shape.
/// NOTE(review): `original_shape` is presumably the pre-flatten shape that the
/// gradient is restored to — confirm against the backward pass.
#[derive(Debug)]
pub struct FlattenBackward {
    /// Version snapshot associated with the input.
    pub input_version: VersionSnapshot,
    /// Shape recorded for the input before the operation.
    pub original_shape: Vec<usize>,
}
/// Payload carried by [`BackwardOp::CrossEntropy`].
///
/// Stores the input's version snapshot plus a storage handle and layout named
/// `grad_*`.
/// NOTE(review): `grad_storage` presumably holds a gradient precomputed during
/// the forward pass — confirm against the forward implementation.
#[derive(Debug)]
pub struct CrossEntropyBackward {
    /// Version snapshot associated with the input.
    pub input_version: VersionSnapshot,
    /// Storage handle for the saved gradient buffer.
    pub grad_storage: StorageHandle,
    /// Layout paired with `grad_storage`.
    pub grad_layout: Layout,
}
/// Payload carried by [`BackwardOp::Dropout`].
///
/// Stores the input's version snapshot plus a storage handle and layout for a
/// mask.
/// NOTE(review): the mask is presumably the keep/drop pattern sampled in the
/// forward pass — confirm against the forward implementation.
#[derive(Debug)]
pub struct DropoutBackward {
    /// Version snapshot associated with the input.
    pub input_version: VersionSnapshot,
    /// Storage handle for the saved mask.
    pub mask_storage: StorageHandle,
    /// Layout paired with `mask_storage`.
    pub mask_layout: Layout,
}
/// Payload carried by [`BackwardOp::Transpose`].
///
/// Stores the input's version snapshot and two dimension indices.
/// NOTE(review): `dim0`/`dim1` are presumably the axes swapped by the forward
/// transpose — confirm against the forward implementation.
#[derive(Debug)]
pub struct TransposeBackward {
    /// Version snapshot associated with the input.
    pub input_version: VersionSnapshot,
    /// First dimension index recorded for the operation.
    pub dim0: usize,
    /// Second dimension index recorded for the operation.
    pub dim1: usize,
}
/// Payload carried by [`BackwardOp::Bmm`].
///
/// Keeps storage, layout, and a version snapshot for both operands, plus a
/// batch size and three dimension sizes.
/// NOTE(review): presumably `lhs` is `batch x m x k` and `rhs` is
/// `batch x k x n` — confirm against the forward batched matmul.
#[derive(Debug)]
pub struct BmmBackward {
    /// Storage handle saved for `lhs`.
    pub lhs_storage: StorageHandle,
    /// Layout paired with `lhs_storage`.
    pub lhs_layout: Layout,
    /// Version snapshot associated with `lhs`.
    pub lhs_version: VersionSnapshot,
    /// Storage handle saved for `rhs`.
    pub rhs_storage: StorageHandle,
    /// Layout paired with `rhs_storage`.
    pub rhs_layout: Layout,
    /// Version snapshot associated with `rhs`.
    pub rhs_version: VersionSnapshot,
    /// Batch size recorded for the operation.
    pub batch: usize,
    /// Dimension size `m` (see the review note on the struct).
    pub m: usize,
    /// Dimension size `k` (see the review note on the struct).
    pub k: usize,
    /// Dimension size `n` (see the review note on the struct).
    pub n: usize,
}
/// Payload carried by [`BackwardOp::Softmax`].
///
/// Stores a storage handle and layout for the *output*, a version snapshot
/// named for the *input*, and two sizes.
/// NOTE(review): presumably the softmax is applied independently to
/// `num_rows` rows of `row_size` elements each — confirm against the forward
/// pass.
#[derive(Debug)]
pub struct SoftmaxBackward {
    /// Storage handle saved for the forward output.
    pub output_storage: StorageHandle,
    /// Layout paired with `output_storage`.
    pub output_layout: Layout,
    /// Version snapshot associated with the input.
    pub input_version: VersionSnapshot,
    /// Row count recorded for the operation.
    pub num_rows: usize,
    /// Row length recorded for the operation.
    pub row_size: usize,
}
/// Payload carried by [`BackwardOp::LayerNorm`].
///
/// Keeps storage, layout, and a version snapshot for the input and the weight,
/// an extra `save_*` storage/layout pair, and two sizes.
/// NOTE(review): `save_storage` presumably holds per-instance statistics from
/// the forward pass (e.g. mean / inverse std), and normalization presumably
/// runs over `num_instances` groups of `norm_size` elements — confirm against
/// the forward implementation.
#[derive(Debug)]
pub struct LayerNormBackward {
    /// Storage handle saved for the input.
    pub input_storage: StorageHandle,
    /// Layout paired with `input_storage`.
    pub input_layout: Layout,
    /// Version snapshot associated with the input.
    pub input_version: VersionSnapshot,
    /// Storage handle saved for the weight.
    pub weight_storage: StorageHandle,
    /// Layout paired with `weight_storage`.
    pub weight_layout: Layout,
    /// Version snapshot associated with the weight.
    pub weight_version: VersionSnapshot,
    /// Storage handle for auxiliary data saved by the forward pass.
    pub save_storage: StorageHandle,
    /// Layout paired with `save_storage`.
    pub save_layout: Layout,
    /// Instance count recorded for the operation.
    pub num_instances: usize,
    /// Normalized size recorded for the operation.
    pub norm_size: usize,
}
/// Payload carried by [`BackwardOp::Embedding`].
///
/// Stores a version snapshot, a storage handle and layout for indices, and
/// three sizes.
/// NOTE(review): `input_version` presumably tracks the embedding table, which
/// would be `vocab_size x embed_dim`, and `total_lookups` presumably counts
/// the indices — confirm against the forward implementation.
#[derive(Debug)]
pub struct EmbeddingBackward {
    /// Version snapshot associated with the input.
    pub input_version: VersionSnapshot,
    /// Storage handle for the saved indices.
    pub indices_storage: StorageHandle,
    /// Layout paired with `indices_storage`.
    pub indices_layout: Layout,
    /// Vocabulary size recorded for the operation.
    pub vocab_size: usize,
    /// Embedding dimension recorded for the operation.
    pub embed_dim: usize,
    /// Lookup count recorded for the operation.
    pub total_lookups: usize,
}
/// Payload carried by [`BackwardOp::Sigmoid`].
///
/// Stores a storage handle and layout for the *output* and a version snapshot
/// named for the *input*.
/// NOTE(review): the gradient is presumably computed from the saved forward
/// output — confirm against the backward kernel.
#[derive(Debug)]
pub struct SigmoidBackward {
    /// Storage handle saved for the forward output.
    pub output_storage: StorageHandle,
    /// Layout paired with `output_storage`.
    pub output_layout: Layout,
    /// Version snapshot associated with the input.
    pub input_version: VersionSnapshot,
}
/// Payload carried by [`BackwardOp::Tanh`].
///
/// Stores a storage handle and layout for the *output* and a version snapshot
/// named for the *input*.
/// NOTE(review): the gradient is presumably computed from the saved forward
/// output — confirm against the backward kernel.
#[derive(Debug)]
pub struct TanhBackward {
    /// Storage handle saved for the forward output.
    pub output_storage: StorageHandle,
    /// Layout paired with `output_storage`.
    pub output_layout: Layout,
    /// Version snapshot associated with the input.
    pub input_version: VersionSnapshot,
}
/// Payload carried by [`BackwardOp::Gelu`].
///
/// Keeps the input's storage handle, layout, and version snapshot.
/// NOTE(review): the gradient is presumably computed from the saved forward
/// input — confirm against the backward kernel.
#[derive(Debug)]
pub struct GeluBackward {
    /// Storage handle saved for the input.
    pub input_storage: StorageHandle,
    /// Layout paired with `input_storage`.
    pub input_layout: Layout,
    /// Version snapshot associated with the input.
    pub input_version: VersionSnapshot,
}
/// Payload carried by [`BackwardOp::LeakyRelu`].
///
/// Keeps the input's storage handle, layout, and version snapshot, plus an
/// `alpha` coefficient.
/// NOTE(review): `alpha` is presumably the slope applied to negative inputs —
/// confirm against the forward implementation.
#[derive(Debug)]
pub struct LeakyReluBackward {
    /// Storage handle saved for the input.
    pub input_storage: StorageHandle,
    /// Layout paired with `input_storage`.
    pub input_layout: Layout,
    /// Version snapshot associated with the input.
    pub input_version: VersionSnapshot,
    /// Coefficient recorded for the operation.
    pub alpha: f32,
}
/// Payload carried by [`BackwardOp::BroadcastAdd`].
///
/// Stores version snapshots for both operands, optional broadcast information
/// for each, and the output shape.
/// NOTE(review): a `None` broadcast entry presumably means that operand was
/// not broadcast — confirm against the code that builds this struct.
#[derive(Debug)]
pub struct BroadcastAddBackward {
    /// Version snapshot associated with `lhs`.
    pub lhs_version: VersionSnapshot,
    /// Version snapshot associated with `rhs`.
    pub rhs_version: VersionSnapshot,
    /// Optional broadcast information for `lhs`.
    pub lhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    /// Optional broadcast information for `rhs`.
    pub rhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    /// Shape recorded for the output.
    pub output_shape: Vec<usize>,
}
/// Payload carried by [`BackwardOp::BroadcastSub`].
///
/// Stores version snapshots for both operands, optional broadcast information
/// for each, and the output shape.
/// NOTE(review): a `None` broadcast entry presumably means that operand was
/// not broadcast — confirm against the code that builds this struct.
#[derive(Debug)]
pub struct BroadcastSubBackward {
    /// Version snapshot associated with `lhs`.
    pub lhs_version: VersionSnapshot,
    /// Version snapshot associated with `rhs`.
    pub rhs_version: VersionSnapshot,
    /// Optional broadcast information for `lhs`.
    pub lhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    /// Optional broadcast information for `rhs`.
    pub rhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    /// Shape recorded for the output.
    pub output_shape: Vec<usize>,
}
/// Payload carried by [`BackwardOp::BroadcastMul`].
///
/// Keeps storage, layout, and a version snapshot for both operands, optional
/// broadcast information for each, and the output shape.
/// NOTE(review): a `None` broadcast entry presumably means that operand was
/// not broadcast — confirm against the code that builds this struct.
#[derive(Debug)]
pub struct BroadcastMulBackward {
    /// Storage handle saved for `lhs`.
    pub lhs_storage: StorageHandle,
    /// Layout paired with `lhs_storage`.
    pub lhs_layout: Layout,
    /// Version snapshot associated with `lhs`.
    pub lhs_version: VersionSnapshot,
    /// Storage handle saved for `rhs`.
    pub rhs_storage: StorageHandle,
    /// Layout paired with `rhs_storage`.
    pub rhs_layout: Layout,
    /// Version snapshot associated with `rhs`.
    pub rhs_version: VersionSnapshot,
    /// Optional broadcast information for `lhs`.
    pub lhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    /// Optional broadcast information for `rhs`.
    pub rhs_broadcast: Option<crate::tensor::broadcast::BroadcastInfo>,
    /// Shape recorded for the output.
    pub output_shape: Vec<usize>,
}
/// Payload carried by [`BackwardOp::BatchNorm2d`].
///
/// Keeps storage, layout, and a version snapshot for the input and the weight,
/// an extra `save_*` storage/layout pair, and four sizes.
/// NOTE(review): `save_storage` presumably holds per-channel statistics from
/// the forward pass, and the input is presumably
/// `batch x channels x height x width` — confirm against the forward
/// implementation.
#[derive(Debug)]
pub struct BatchNorm2dBackward {
    /// Storage handle saved for the input.
    pub input_storage: StorageHandle,
    /// Layout paired with `input_storage`.
    pub input_layout: Layout,
    /// Version snapshot associated with the input.
    pub input_version: VersionSnapshot,
    /// Storage handle saved for the weight.
    pub weight_storage: StorageHandle,
    /// Layout paired with `weight_storage`.
    pub weight_layout: Layout,
    /// Version snapshot associated with the weight.
    pub weight_version: VersionSnapshot,
    /// Storage handle for auxiliary data saved by the forward pass.
    pub save_storage: StorageHandle,
    /// Layout paired with `save_storage`.
    pub save_layout: Layout,
    /// Batch size recorded for the operation.
    pub batch: usize,
    /// Channel count recorded for the operation.
    pub channels: usize,
    /// Height recorded for the operation.
    pub height: usize,
    /// Width recorded for the operation.
    pub width: usize,
}
/// Payload carried by [`BackwardOp::AdaptiveAvgPool2d`].
///
/// Stores the input's version snapshot and six sizes.
/// NOTE(review): presumably the input is `batch x channels x h_in x w_in` and
/// the output is `batch x channels x h_out x w_out` — confirm against the
/// forward implementation.
#[derive(Debug)]
pub struct AdaptiveAvgPool2dBackward {
    /// Version snapshot associated with the input.
    pub input_version: VersionSnapshot,
    /// Batch size recorded for the operation.
    pub batch: usize,
    /// Channel count recorded for the operation.
    pub channels: usize,
    /// Presumably the input height.
    pub h_in: usize,
    /// Presumably the input width.
    pub w_in: usize,
    /// Presumably the output height.
    pub h_out: usize,
    /// Presumably the output width.
    pub w_out: usize,
}
/// Tagged union over the backward-op payload structs declared in this module;
/// each variant wraps exactly one payload type.
#[derive(Debug)]
pub enum BackwardOp {
    /// Wraps an [`AddBackward`] payload.
    Add(AddBackward),
    /// Wraps a [`SubBackward`] payload.
    Sub(SubBackward),
    /// Wraps a [`MulBackward`] payload.
    Mul(MulBackward),
    /// Wraps a [`MatmulBackward`] payload.
    Matmul(MatmulBackward),
    /// Wraps a [`ReluBackward`] payload.
    Relu(ReluBackward),
    /// Wraps a [`MseLossBackward`] payload.
    MseLoss(MseLossBackward),
    /// Wraps an [`AddBiasBackward`] payload.
    AddBias(AddBiasBackward),
    /// Wraps an [`Im2ColBackward`] payload.
    Im2Col(Im2ColBackward),
    /// Wraps a [`StackBackward`] payload.
    Stack(StackBackward),
    /// Wraps an [`AddChannelBiasBackward`] payload.
    AddChannelBias(AddChannelBiasBackward),
    /// Wraps a [`SliceBatchBackward`] payload.
    SliceBatch(SliceBatchBackward),
    /// Wraps a [`MaxPool2dBackward`] payload.
    MaxPool2d(MaxPool2dBackward),
    /// Wraps a [`FlattenBackward`] payload.
    Flatten(FlattenBackward),
    /// Wraps a [`ReshapeBackward`] payload.
    Reshape(ReshapeBackward),
    /// Wraps a [`DropoutBackward`] payload.
    Dropout(DropoutBackward),
    /// Wraps a [`CrossEntropyBackward`] payload.
    CrossEntropy(CrossEntropyBackward),
    /// Wraps a [`SigmoidBackward`] payload.
    Sigmoid(SigmoidBackward),
    /// Wraps a [`TanhBackward`] payload.
    Tanh(TanhBackward),
    /// Wraps a [`GeluBackward`] payload.
    Gelu(GeluBackward),
    /// Wraps a [`LeakyReluBackward`] payload.
    LeakyRelu(LeakyReluBackward),
    /// Wraps a [`TransposeBackward`] payload.
    Transpose(TransposeBackward),
    /// Wraps a [`BmmBackward`] payload.
    Bmm(BmmBackward),
    /// Wraps a [`SoftmaxBackward`] payload.
    Softmax(SoftmaxBackward),
    /// Wraps a [`LayerNormBackward`] payload.
    LayerNorm(LayerNormBackward),
    /// Wraps an [`EmbeddingBackward`] payload.
    Embedding(EmbeddingBackward),
    /// Wraps a [`BroadcastAddBackward`] payload.
    BroadcastAdd(BroadcastAddBackward),
    /// Wraps a [`BroadcastSubBackward`] payload.
    BroadcastSub(BroadcastSubBackward),
    /// Wraps a [`BroadcastMulBackward`] payload.
    BroadcastMul(BroadcastMulBackward),
    /// Wraps a [`BatchNorm2dBackward`] payload.
    BatchNorm2d(BatchNorm2dBackward),
    /// Wraps an [`AdaptiveAvgPool2dBackward`] payload.
    AdaptiveAvgPool2d(AdaptiveAvgPool2dBackward),
}
const _: () = {
fn _assert_send<T: Send>() {}
fn _assert_sync<T: Sync>() {}
fn _assertions() {
_assert_send::<BackwardOp>();
_assert_sync::<BackwardOp>();
}
};