use crate::autograd::AutogradError;
use crate::tensor::{GradId, Layout, StorageHandle, WeakStorageHandle};
/// A record of a storage's version number, paired with a weak handle to that storage.
///
/// `VersionSnapshot::check` compares `recorded_version` with the storage's current
/// version (when the storage can still be upgraded) and reports a mismatch for `grad_id`.
#[derive(Debug, Clone)]
pub struct VersionSnapshot {
/// Identifier included in `AutogradError::VersionMismatch` when `check` fails.
pub grad_id: GradId,
/// Weak handle to the tracked storage; `new` obtains it via `StorageHandle::downgrade`.
pub weak_storage: WeakStorageHandle,
/// Version value to compare against; `new` fills it from `StorageHandle::version()`.
pub recorded_version: usize,
}
impl VersionSnapshot {
    /// Builds a snapshot for `grad_id` from `storage`.
    ///
    /// The storage's version is read first, then a weak handle to the storage is taken,
    /// so the snapshot itself does not hold a strong handle.
    pub fn new(grad_id: GradId, storage: &StorageHandle) -> Self {
        let recorded_version = storage.version();
        let weak_storage = storage.downgrade();
        Self {
            grad_id,
            weak_storage,
            recorded_version,
        }
    }

    /// Verifies that the tracked storage still has the version recorded in this snapshot.
    ///
    /// Returns `Ok(())` when the versions match, and also when the weak handle can no
    /// longer be upgraded.
    ///
    /// # Errors
    ///
    /// Returns `AutogradError::VersionMismatch` carrying `grad_id`, the recorded version
    /// (`expected`) and the current version (`found`) when the two differ.
    pub fn check(&self) -> Result<(), AutogradError> {
        // A failed upgrade is treated as success: there is nothing left to compare against.
        let live_storage = match self.weak_storage.upgrade() {
            Some(handle) => handle,
            None => return Ok(()),
        };

        let found = live_storage.version();
        if found == self.recorded_version {
            return Ok(());
        }

        Err(AutogradError::VersionMismatch {
            grad_id: self.grad_id,
            expected: self.recorded_version,
            found,
        })
    }
}
/// Payload of [`BackwardOp::Add`].
///
/// Stores only version snapshots for the two operands; no operand storage is kept.
/// NOTE(review): presumably the state saved for the backward pass of an addition —
/// confirm against the code that consumes this struct.
#[derive(Debug)]
pub struct AddBackward {
/// Version snapshot associated with the left-hand operand.
pub lhs_version: VersionSnapshot,
/// Version snapshot associated with the right-hand operand.
pub rhs_version: VersionSnapshot,
}
/// Payload of [`BackwardOp::Sub`].
///
/// Stores only version snapshots for the two operands; no operand storage is kept.
/// NOTE(review): presumably the state saved for the backward pass of a subtraction —
/// confirm against the code that consumes this struct.
#[derive(Debug)]
pub struct SubBackward {
/// Version snapshot associated with the left-hand operand.
pub lhs_version: VersionSnapshot,
/// Version snapshot associated with the right-hand operand.
pub rhs_version: VersionSnapshot,
}
/// Payload of [`BackwardOp::Mul`].
///
/// For each operand it keeps a storage handle, a layout and a version snapshot.
/// NOTE(review): presumably both operands are retained because the gradient of a
/// product needs the other operand's values — confirm against the backward code.
#[derive(Debug)]
pub struct MulBackward {
/// Storage handle for the left-hand operand.
pub lhs_storage: StorageHandle,
/// Layout to use when reading `lhs_storage`.
pub lhs_layout: Layout,
/// Version snapshot associated with the left-hand operand.
pub lhs_version: VersionSnapshot,
/// Storage handle for the right-hand operand.
pub rhs_storage: StorageHandle,
/// Layout to use when reading `rhs_storage`.
pub rhs_layout: Layout,
/// Version snapshot associated with the right-hand operand.
pub rhs_version: VersionSnapshot,
}
/// Payload of [`BackwardOp::Matmul`].
///
/// For each operand it keeps a storage handle, a layout and a version snapshot,
/// plus three dimension values.
/// NOTE(review): `m`, `k`, `n` presumably describe an (m x k) by (k x n) product —
/// TODO confirm against the matmul forward/backward code.
#[derive(Debug)]
pub struct MatmulBackward {
/// Storage handle for the left-hand operand.
pub lhs_storage: StorageHandle,
/// Layout to use when reading `lhs_storage`.
pub lhs_layout: Layout,
/// Version snapshot associated with the left-hand operand.
pub lhs_version: VersionSnapshot,
/// Storage handle for the right-hand operand.
pub rhs_storage: StorageHandle,
/// Layout to use when reading `rhs_storage`.
pub rhs_layout: Layout,
/// Version snapshot associated with the right-hand operand.
pub rhs_version: VersionSnapshot,
/// Dimension `m` (presumably rows of the left operand — verify).
pub m: usize,
/// Dimension `k` (presumably the shared inner dimension — verify).
pub k: usize,
/// Dimension `n` (presumably columns of the right operand — verify).
pub n: usize,
}
/// Payload of [`BackwardOp::Relu`].
///
/// Keeps the input's storage handle, layout and version snapshot.
/// NOTE(review): presumably the input is retained so the backward pass can tell
/// which elements were positive — confirm against the backward code.
#[derive(Debug)]
pub struct ReluBackward {
/// Storage handle for the op's input.
pub input_storage: StorageHandle,
/// Layout to use when reading `input_storage`.
pub input_layout: Layout,
/// Version snapshot associated with the input.
pub input_version: VersionSnapshot,
}
/// Payload of [`BackwardOp::MseLoss`].
///
/// Keeps storage handle, layout and version snapshot for both the prediction and
/// the target, plus an element count.
#[derive(Debug)]
pub struct MseLossBackward {
/// Storage handle for the prediction.
pub pred_storage: StorageHandle,
/// Layout to use when reading `pred_storage`.
pub pred_layout: Layout,
/// Version snapshot associated with the prediction.
pub pred_version: VersionSnapshot,
/// Storage handle for the target.
pub target_storage: StorageHandle,
/// Layout to use when reading `target_storage`.
pub target_layout: Layout,
/// Version snapshot associated with the target.
pub target_version: VersionSnapshot,
/// Element count (presumably the number of elements the loss was averaged over — verify).
pub numel: usize,
}
/// Payload of [`BackwardOp::AddBias`].
///
/// Stores version snapshots for the input and the bias plus two dimension values;
/// no operand storage is kept.
#[derive(Debug)]
pub struct AddBiasBackward {
/// Version snapshot associated with the input.
pub input_version: VersionSnapshot,
/// Version snapshot associated with the bias.
pub bias_version: VersionSnapshot,
/// Dimension `m` (presumably the number of rows of the input — TODO confirm).
pub m: usize,
/// Dimension `n` (presumably the number of columns / bias length — TODO confirm).
pub n: usize,
}
/// Payload of [`BackwardOp::SliceBatch`].
///
/// Stores the input's version snapshot, a shape and an index.
#[derive(Debug)]
pub struct SliceBatchBackward {
/// Version snapshot associated with the input.
pub input_version: VersionSnapshot,
/// Shape recorded as the "original" one (presumably the input's shape before slicing — verify).
pub original_shape: Vec<usize>,
/// Index recorded for the slice (presumably the position along the batch dimension — verify).
pub index: usize,
}
/// Payload of [`BackwardOp::Im2Col`].
///
/// Stores the input's version snapshot and the geometry values of the transform.
/// NOTE(review): the field meanings below are inferred from their names; confirm
/// against the im2col forward implementation.
#[derive(Debug)]
pub struct Im2ColBackward {
/// Version snapshot associated with the input.
pub input_version: VersionSnapshot,
/// Presumably the number of input channels.
pub c_in: usize,
/// Presumably the input height.
pub h: usize,
/// Presumably the input width.
pub w: usize,
/// Kernel size (a single value; presumably used for both spatial dimensions — verify).
pub kernel_size: usize,
/// Stride (a single value; presumably used for both spatial dimensions — verify).
pub stride: usize,
/// Padding (a single value; presumably used for both spatial dimensions — verify).
pub padding: usize,
/// Presumably the output height.
pub out_h: usize,
/// Presumably the output width.
pub out_w: usize,
}
/// Payload of [`BackwardOp::Stack`].
///
/// Stores a count, a shape and a list of version snapshots.
#[derive(Debug)]
pub struct StackBackward {
/// Count (presumably the number of stacked inputs — TODO confirm it matches `versions.len()`).
pub count: usize,
/// Shape recorded for each input (presumably all inputs share it — verify).
pub each_shape: Vec<usize>,
/// Version snapshots (presumably one per stacked input — verify).
pub versions: Vec<VersionSnapshot>,
}
/// Payload of [`BackwardOp::AddChannelBias`].
///
/// Stores version snapshots for the input and the bias plus two size values;
/// no operand storage is kept.
#[derive(Debug)]
pub struct AddChannelBiasBackward {
/// Version snapshot associated with the input.
pub input_version: VersionSnapshot,
/// Version snapshot associated with the bias.
pub bias_version: VersionSnapshot,
/// Channel count (presumably also the bias length — verify).
pub channels: usize,
/// Spatial size (presumably the number of elements per channel — verify).
pub spatial: usize,
}
/// Payload of [`BackwardOp::MaxPool2d`].
///
/// Stores the input's version snapshot, a storage handle and layout for an indices
/// buffer, and the geometry values of the pooling.
/// NOTE(review): the indices presumably record which input element was selected for
/// each output element — confirm against the forward implementation.
#[derive(Debug)]
pub struct MaxPool2dBackward {
/// Version snapshot associated with the input.
pub input_version: VersionSnapshot,
/// Storage handle for the indices buffer.
pub indices_storage: StorageHandle,
/// Layout to use when reading `indices_storage`.
pub indices_layout: Layout,
/// Channel count.
pub channels: usize,
/// Presumably the input height.
pub h: usize,
/// Presumably the input width.
pub w: usize,
/// Presumably the output height.
pub out_h: usize,
/// Presumably the output width.
pub out_w: usize,
}
/// Payload of [`BackwardOp::Reshape`].
///
/// Stores the input's version snapshot and a shape.
#[derive(Debug)]
pub struct ReshapeBackward {
/// Version snapshot associated with the input.
pub input_version: VersionSnapshot,
/// Shape recorded as the "original" one (presumably the input's shape before the reshape — verify).
pub original_shape: Vec<usize>,
}
/// Payload of [`BackwardOp::Flatten`].
///
/// Stores the input's version snapshot and a shape. Has the same fields as
/// `ReshapeBackward`.
#[derive(Debug)]
pub struct FlattenBackward {
/// Version snapshot associated with the input.
pub input_version: VersionSnapshot,
/// Shape recorded as the "original" one (presumably the input's shape before flattening — verify).
pub original_shape: Vec<usize>,
}
/// Payload of [`BackwardOp::CrossEntropy`].
///
/// Stores the input's version snapshot and a storage handle and layout for a
/// gradient buffer.
/// NOTE(review): `grad_storage` presumably holds a gradient precomputed during the
/// forward pass — confirm against the cross-entropy implementation.
#[derive(Debug)]
pub struct CrossEntropyBackward {
/// Version snapshot associated with the input.
pub input_version: VersionSnapshot,
/// Storage handle for the gradient buffer.
pub grad_storage: StorageHandle,
/// Layout to use when reading `grad_storage`.
pub grad_layout: Layout,
}
/// Payload of [`BackwardOp::Dropout`].
///
/// Stores the input's version snapshot and a storage handle and layout for a mask
/// buffer.
/// NOTE(review): the mask is presumably the one sampled in the forward pass (and may
/// include the scaling factor) — confirm against the dropout implementation.
#[derive(Debug)]
pub struct DropoutBackward {
/// Version snapshot associated with the input.
pub input_version: VersionSnapshot,
/// Storage handle for the mask buffer.
pub mask_storage: StorageHandle,
/// Layout to use when reading `mask_storage`.
pub mask_layout: Layout,
}
/// Tagged union over every `*Backward` payload struct defined in this file.
///
/// Each variant wraps exactly one payload struct of the matching name.
/// A compile-time assertion elsewhere in this file requires the enum to be
/// `Send + Sync`.
/// NOTE(review): presumably this is what an autograd graph node stores to describe
/// how to propagate gradients — confirm against the graph/engine code.
#[derive(Debug)]
pub enum BackwardOp {
/// Wraps [`AddBackward`].
Add(AddBackward),
/// Wraps [`SubBackward`].
Sub(SubBackward),
/// Wraps [`MulBackward`].
Mul(MulBackward),
/// Wraps [`MatmulBackward`].
Matmul(MatmulBackward),
/// Wraps [`ReluBackward`].
Relu(ReluBackward),
/// Wraps [`MseLossBackward`].
MseLoss(MseLossBackward),
/// Wraps [`AddBiasBackward`].
AddBias(AddBiasBackward),
/// Wraps [`Im2ColBackward`].
Im2Col(Im2ColBackward),
/// Wraps [`StackBackward`].
Stack(StackBackward),
/// Wraps [`AddChannelBiasBackward`].
AddChannelBias(AddChannelBiasBackward),
/// Wraps [`SliceBatchBackward`].
SliceBatch(SliceBatchBackward),
/// Wraps [`MaxPool2dBackward`].
MaxPool2d(MaxPool2dBackward),
/// Wraps [`FlattenBackward`].
Flatten(FlattenBackward),
/// Wraps [`ReshapeBackward`].
Reshape(ReshapeBackward),
/// Wraps [`DropoutBackward`].
Dropout(DropoutBackward),
/// Wraps [`CrossEntropyBackward`].
CrossEntropy(CrossEntropyBackward),
}
const _: () = {
fn _assert_send<T: Send>() {}
fn _assert_sync<T: Sync>() {}
fn _assertions() {
_assert_send::<BackwardOp>();
_assert_sync::<BackwardOp>();
}
};