Module binary_backward

Backward plan for the binary elementwise family.

Sibling of crate::BinaryPlan for gradient computation: (da, db) = backward(dy, [saved tensors per op]).

Today wired: {Add, Sub, Mul, Div, Maximum, Minimum} × {f32, f16, bf16, f64}. Add and Sub need no saved tensors; Mul, Div, Maximum, Minimum require the saved forward inputs a and b:

  • Add: (da, db) = (dy, dy) — no saved
  • Sub: (da, db) = (dy, -dy) — no saved
  • Mul: (da, db) = (dy * b, dy * a) — needs saved a, b
  • Div: (da, db) = (dy / b, -dy * a / b²) — needs saved a, b
  • Maximum / Minimum: needs saved a, b, used purely as comparison references (their values never scale the gradient). A tie splits dy evenly (PyTorch parity). For Maximum: da = where(a==b, dy/2, where(a<b, 0, dy)), db = where(a==b, dy/2, where(b<a, 0, dy)). Minimum uses the same formulas with each < replaced by >. If either input is NaN, every comparison is false, so dy propagates unchanged to both da and db.
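
The per-op rules above can be written as a scalar reference function. This is a minimal sketch for illustration only: the `Op` enum, `backward_scalar`, and `select` are hypothetical names, not items from this module, and the real kernels operate on whole tensors rather than single `f32` values.

```rust
/// Hypothetical op tag for this sketch; the crate's real descriptor may differ.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Op {
    Add,
    Sub,
    Mul,
    Div,
    Maximum,
    Minimum,
}

/// Scalar reference for (da, db) given the upstream gradient `dy` and the
/// forward inputs `a`, `b`. `a` and `b` are ignored for Add and Sub.
pub fn backward_scalar(op: Op, dy: f32, a: f32, b: f32) -> (f32, f32) {
    match op {
        Op::Add => (dy, dy),
        Op::Sub => (dy, -dy),
        Op::Mul => (dy * b, dy * a),
        Op::Div => (dy / b, -dy * a / (b * b)),
        // Maximum: the smaller input receives no gradient.
        Op::Maximum => select(dy, a == b, a < b, b < a),
        // Minimum: same rule with the comparisons flipped.
        Op::Minimum => select(dy, a == b, a > b, b > a),
    }
}

/// Routes `dy` to the winning input, or splits it evenly on a tie.
/// With a NaN input all three flags are false, so both sides receive `dy`.
fn select(dy: f32, tie: bool, a_loses: bool, b_loses: bool) -> (f32, f32) {
    if tie {
        (dy / 2.0, dy / 2.0)
    } else {
        (
            if a_loses { 0.0 } else { dy },
            if b_loses { 0.0 } else { dy },
        )
    }
}
```

For example, `backward_scalar(Op::Maximum, 3.0, 1.0, 1.0)` splits the gradient across the tie, while passing `f32::NAN` as `a` sends the full `dy` to both outputs.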

The Args struct carries a and b as Option<TensorRef> so callers omit them for ops that don’t need them. The dispatcher validates that needed saves are present.
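
One way the presence check could look is sketched below. Everything here is an assumption made for illustration: `SketchArgs` stands in for the real args struct, plain `&[f32]` slices stand in for `TensorRef`, and the `needs_saves` flag and `String` error type are placeholders for whatever the dispatcher actually uses.

```rust
/// Hypothetical stand-in for the args bundle; slices replace `TensorRef`.
pub struct SketchArgs<'t> {
    pub dy: &'t [f32],
    /// Saved forward input `a`; `None` for ops that do not need it.
    pub a: Option<&'t [f32]>,
    /// Saved forward input `b`; `None` for ops that do not need it.
    pub b: Option<&'t [f32]>,
}

/// Returns the saved inputs when the op needs them, `None` when it does not,
/// and an error when a required save is missing.
pub fn saved_inputs<'t>(
    needs_saves: bool,
    args: &SketchArgs<'t>,
) -> Result<Option<(&'t [f32], &'t [f32])>, String> {
    if !needs_saves {
        // Add and Sub: any saves the caller supplied are simply ignored here.
        return Ok(None);
    }
    match (args.a, args.b) {
        (Some(a), Some(b)) => Ok(Some((a, b))),
        (None, _) => Err("saved input `a` is required for this op".to_string()),
        (_, None) => Err("saved input `b` is required for this op".to_string()),
    }
}
```

Validating once up front keeps the per-op kernels free of `Option` handling: after this check, an op that needs saves can rely on both inputs being present.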

Trailblazer constraints (same shape limits as the forward trailblazer): contiguous tensors only, with no broadcasting, so dy.shape == da.shape == db.shape. Ops with saves additionally require a.shape == b.shape == dy.shape.
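
The shape constraints can be checked along the following lines. This is a hedged sketch, not the crate's validation code: `check_shapes` and `is_row_major_contiguous` are hypothetical helpers, shapes and strides are modelled as `&[usize]` measured in elements, and "contiguous" is assumed to mean dense row-major layout.

```rust
/// Checks dy.shape == da.shape == db.shape, and for ops with saves,
/// a.shape == b.shape == dy.shape. No broadcasting is attempted.
pub fn check_shapes(
    dy: &[usize],
    da: &[usize],
    db: &[usize],
    saves: Option<(&[usize], &[usize])>,
) -> Result<(), String> {
    if da != dy || db != dy {
        return Err(format!("da {:?} and db {:?} must match dy {:?}", da, db, dy));
    }
    if let Some((a, b)) = saves {
        if a != dy || b != dy {
            return Err(format!("a {:?} and b {:?} must match dy {:?}", a, b, dy));
        }
    }
    Ok(())
}

/// True when `strides` (in elements) describe a dense row-major layout
/// for `shape`. Strides of size-1 dimensions are not checked.
pub fn is_row_major_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    if shape.len() != strides.len() {
        return false;
    }
    let mut expected = 1usize;
    // Walk from the innermost dimension outwards.
    for (&dim, &stride) in shape.iter().zip(strides).rev() {
        if dim != 1 && stride != expected {
            return false;
        }
        expected *= dim;
    }
    true
}
```

Under these assumptions a `[2, 3]` tensor with strides `[3, 1]` passes, while its transposed view with strides `[1, 2]` is rejected, as is a pair such as `a: [1, 3]`, `dy: [2, 3]` that would need broadcasting.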

Structs§

BinaryBackwardArgs
Args bundle for a binary backward launch.
BinaryBackwardDescriptor
Descriptor for a binary backward op.
BinaryBackwardPlan
Binary backward plan.