Backward plan for the binary elementwise family.
Sibling of crate::BinaryPlan for gradient computation:
(da, db) = backward(dy, [saved tensors per op]).
Today wired: {Add, Sub, Mul, Div, Maximum, Minimum} × {f32, f16, bf16, f64}.
Add and Sub need no saved tensors; Mul, Div, Maximum, Minimum require
the saved forward inputs a and b:
- Add: (da, db) = (dy, dy); no saved tensors
- Sub: (da, db) = (dy, -dy); no saved tensors
- Mul: (da, db) = (dy * b, dy * a); needs saved a, b
- Div: (da, db) = (dy / b, -dy * a / b²); needs saved a, b
- Maximum / Minimum: saves used purely as comparison references; a tie splits dy evenly (PyTorch parity). For Maximum: da = where(a==b, dy/2, where(a<b, 0, dy)), db = where(a==b, dy/2, where(b<a, 0, dy)). Minimum flips < / >. NaN inputs propagate dy to both (all comparisons false).
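The per-op rules above can be checked against a scalar reference. The following is only a sketch: the `Op` enum and `backward_scalar` function are hypothetical names invented for illustration, not the crate's API, and it works on single `f32` values rather than tensors.

```rust
/// Hypothetical op tag for this sketch; not the crate's actual type.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Op {
    Add,
    Sub,
    Mul,
    Div,
    Maximum,
    Minimum,
}

/// Scalar reference for (da, db) = backward(dy, a, b).
/// `a` and `b` are ignored for Add and Sub, which need no saved inputs.
fn backward_scalar(op: Op, dy: f32, a: f32, b: f32) -> (f32, f32) {
    match op {
        Op::Add => (dy, dy),
        Op::Sub => (dy, -dy),
        Op::Mul => (dy * b, dy * a),
        Op::Div => (dy / b, -dy * a / (b * b)),
        // where(a==b, dy/2, where(a<b, 0, dy)) and the mirrored form for db.
        // With a NaN input every comparison is false, so both sides get dy.
        Op::Maximum => (
            if a == b { dy / 2.0 } else if a < b { 0.0 } else { dy },
            if a == b { dy / 2.0 } else if b < a { 0.0 } else { dy },
        ),
        // Same structure with the < comparisons flipped to >.
        Op::Minimum => (
            if a == b { dy / 2.0 } else if a > b { 0.0 } else { dy },
            if a == b { dy / 2.0 } else if b > a { 0.0 } else { dy },
        ),
    }
}
```

For instance, `backward_scalar(Op::Maximum, 3.0, 1.0, 1.0)` splits the gradient evenly between the two inputs, while passing `f32::NAN` for `a` routes the full `dy` to both.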
The Args struct carries a and b as Option<TensorRef> so
callers can omit them for ops that don’t need them. The dispatcher
validates that the required saves are present.
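One way the save validation described above might look is sketched below. Everything here is an assumption made for illustration: `OpKind`, `Args`, `ValidateError`, and `check_saves` are hypothetical names, and `TensorRef` is a stand-in that models only a shape, not the crate's real type.

```rust
/// Hypothetical op tag; not the crate's actual descriptor type.
#[derive(Clone, Copy, Debug, PartialEq)]
#[allow(dead_code)]
enum OpKind {
    Add,
    Sub,
    Mul,
    Div,
    Maximum,
    Minimum,
}

/// Stand-in for the crate's TensorRef; only the shape is modeled here.
#[derive(Debug, PartialEq)]
#[allow(dead_code)]
struct TensorRef {
    shape: Vec<usize>,
}

/// Assumed layout: dy is always present, the saves are optional.
#[allow(dead_code)]
struct Args {
    dy: TensorRef,
    a: Option<TensorRef>,
    b: Option<TensorRef>,
}

#[derive(Debug, PartialEq)]
enum ValidateError {
    /// Names the saved tensor that the op needs but the caller omitted.
    MissingSave(&'static str),
}

/// Add and Sub need no saves; every other op requires both a and b.
fn check_saves(op: OpKind, args: &Args) -> Result<(), ValidateError> {
    let needs_saves = !matches!(op, OpKind::Add | OpKind::Sub);
    if needs_saves {
        if args.a.is_none() {
            return Err(ValidateError::MissingSave("a"));
        }
        if args.b.is_none() {
            return Err(ValidateError::MissingSave("b"));
        }
    }
    Ok(())
}
```

Modeling the saves as `Option` keeps the Add and Sub call sites free of placeholder tensors, at the cost of a runtime check for the ops that do need them.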
Trailblazer constraints (same shape limits as the forward
trailblazer): contiguous tensors only (no broadcasting); dy.shape == da.shape == db.shape. Ops with saves additionally require
a.shape == b.shape == dy.shape.
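The shape constraints above amount to a pair of equality checks. The sketch below assumes shapes are plain `&[usize]` slices; `check_shapes` and `ShapeError` are hypothetical names, and contiguity is not modeled at all.

```rust
#[derive(Debug, PartialEq)]
enum ShapeError {
    /// dy, da, and db do not all share one shape.
    GradMismatch,
    /// A saved input's shape differs from dy's.
    SaveMismatch,
}

/// Checks dy.shape == da.shape == db.shape and, when saves are supplied,
/// a.shape == b.shape == dy.shape. No broadcasting is attempted.
fn check_shapes(
    dy: &[usize],
    da: &[usize],
    db: &[usize],
    saves: Option<(&[usize], &[usize])>,
) -> Result<(), ShapeError> {
    if dy != da || dy != db {
        return Err(ShapeError::GradMismatch);
    }
    if let Some((a, b)) = saves {
        if a != dy || b != dy {
            return Err(ShapeError::SaveMismatch);
        }
    }
    Ok(())
}
```

Because broadcasting is excluded, a shape such as `[1, 3]` paired with `[2, 3]` is rejected rather than expanded.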
Structs
- BinaryBackwardArgs: Args bundle for a binary backward launch.
- BinaryBackwardDescriptor: Descriptor for a binary backward op.
- BinaryBackwardPlan: Binary backward plan.