
Module binary_param_backward


Backward plan for the parameterized binary elementwise family.

Sibling of crate::BinaryParamPlan. Lerp computes out = a + weight·(b - a) = (1 - weight)·a + weight·b, so its backward formula is da = (1 - weight)·dy and db = weight·dy. No saved tensors are needed, because each gradient is simply dy scaled by a constant derived from the scalar weight.
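As an illustration of why nothing needs to be saved from the forward pass, here is a minimal sketch of the Lerp backward over contiguous f32 buffers. The function name `lerp_backward_f32` and its slice-based signature are hypothetical, not this crate's API; the real plan also covers f16, bf16, and f64.

```rust
/// Hypothetical sketch: Lerp backward for contiguous f32 buffers.
/// Forward: out[i] = (1 - weight) * a[i] + weight * b[i].
/// Backward: da[i] = (1 - weight) * dy[i], db[i] = weight * dy[i].
/// Only dy and the scalar weight are read; a, b, and out are never needed.
fn lerp_backward_f32(dy: &[f32], weight: f32, da: &mut [f32], db: &mut [f32]) {
    assert_eq!(dy.len(), da.len(), "da must have the same length as dy");
    assert_eq!(dy.len(), db.len(), "db must have the same length as dy");
    let wa = 1.0 - weight; // constant scale for the gradient w.r.t. `a`
    for ((g, ga), gb) in dy.iter().zip(da.iter_mut()).zip(db.iter_mut()) {
        *ga = wa * *g;
        *gb = weight * *g;
    }
}
```

Because the weight is a constant with respect to both inputs, the kernel is a pure elementwise scaling of dy, and no gradient buffer for the weight itself is produced.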

Currently wired: Lerp for f32, f16, bf16, and f64. The scalar weight is treated as a constant with respect to both inputs, so no gradient flows to it.

Trailblazer constraints: contiguous tensors only; dy.shape == da.shape == db.shape == desc.shape.
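The constraints above amount to a simple pre-launch check. The sketch below assumes that "contiguous" means a dense row-major layout described by per-dimension strides; the helpers `is_contiguous` and `check_backward_constraints` are hypothetical and not part of this module.

```rust
/// Hypothetical helper: true if `strides` describe a dense row-major layout for `shape`.
/// Size-1 dimensions are skipped because their stride never affects addressing.
fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    if shape.len() != strides.len() {
        return false;
    }
    let mut expected = 1usize;
    // Walk from the innermost dimension outward, accumulating the dense stride.
    for (&dim, &stride) in shape.iter().zip(strides).rev() {
        if dim != 1 && stride != expected {
            return false;
        }
        expected *= dim;
    }
    true
}

/// Hypothetical pre-launch check mirroring the stated constraints:
/// contiguous only, and every tensor's shape equals desc.shape.
fn check_backward_constraints(
    desc_shape: &[usize],
    tensors: &[(&str, &[usize], &[usize])], // (name, shape, strides)
) -> Result<(), String> {
    for &(name, shape, strides) in tensors {
        if shape != desc_shape {
            return Err(format!("{name}.shape {shape:?} != desc.shape {desc_shape:?}"));
        }
        if !is_contiguous(shape, strides) {
            return Err(format!("{name} is not contiguous (strides {strides:?})"));
        }
    }
    Ok(())
}
```

Rejecting strided or broadcast inputs up front keeps the kernel a flat loop over desc.shape elements.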

Structs

BinaryParamBackwardArgs
Args bundle for a parameterized binary backward launch.
BinaryParamBackwardDescriptor
Descriptor for a parameterized binary backward op. Same shape as the forward descriptor.
BinaryParamBackwardPlan
Parameterized binary backward plan.
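To show how the three structs plausibly fit together (describe once, build a plan, launch with an args bundle), here is a self-contained sketch. The types below are hypothetical stand-ins that only borrow the names listed above; their fields (`shape`, `weight`, `dy`, `da`, `db`) and methods (`new`, `launch`) are illustrative assumptions, not the crate's actual API, and only f32 is shown.

```rust
// Hypothetical stand-ins for the three structs listed above.
// Field and method names are illustrative only, NOT the crate's real API.

#[derive(Clone, Debug, PartialEq)]
struct BinaryParamBackwardDescriptor {
    shape: Vec<usize>, // same shape as the forward descriptor
    weight: f32,       // scalar Lerp weight, treated as a constant
}

struct BinaryParamBackwardArgs<'a> {
    dy: &'a [f32],
    da: &'a mut [f32],
    db: &'a mut [f32],
}

struct BinaryParamBackwardPlan {
    desc: BinaryParamBackwardDescriptor,
}

impl BinaryParamBackwardPlan {
    /// Build a plan once from the descriptor; it can be reused for many launches.
    fn new(desc: BinaryParamBackwardDescriptor) -> Self {
        Self { desc }
    }

    /// Validate buffer lengths against the descriptor, then apply the Lerp backward.
    fn launch(&self, args: BinaryParamBackwardArgs<'_>) -> Result<(), String> {
        let BinaryParamBackwardArgs { dy, da, db } = args;
        let n: usize = self.desc.shape.iter().product();
        if dy.len() != n || da.len() != n || db.len() != n {
            return Err(format!("buffer length mismatch: expected {n} elements"));
        }
        let w = self.desc.weight;
        for ((g, ga), gb) in dy.iter().zip(da.iter_mut()).zip(db.iter_mut()) {
            *ga = (1.0 - w) * *g; // da = (1 - weight) * dy
            *gb = w * *g;         // db = weight * dy
        }
        Ok(())
    }
}
```

Separating the descriptor (what to compute) from the args bundle (which buffers to use) lets one plan be validated once and then launched repeatedly on different buffers.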