Module softmax_backward

Backward pass for row-wise softmax.

Given y = softmax(x) along the last dimension and an upstream gradient dy of the same shape, computes, for each row b and column i:

dx[b, i] = y[b, i] · (dy[b, i] − Σ_j y[b, j] · dy[b, j])
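
As a point of comparison for the formula above, here is a minimal CPU sketch in plain Rust. The function name `softmax_backward_ref`, the flat row-major `[rows, cols]` layout, and the `f32` element type are assumptions made for illustration; this is not the crate's kernel or its API.

```rust
/// Hypothetical CPU reference for the row-wise softmax backward formula:
/// dx[b, i] = y[b, i] * (dy[b, i] - sum_j y[b, j] * dy[b, j]).
/// `y` and `dy` are assumed to be row-major buffers of shape [rows, cols].
pub fn softmax_backward_ref(y: &[f32], dy: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(y.len(), rows * cols, "y must hold rows * cols elements");
    assert_eq!(dy.len(), rows * cols, "dy must hold rows * cols elements");

    let mut dx = vec![0.0f32; rows * cols];
    if cols == 0 {
        // Nothing to compute; also avoids chunking by zero below.
        return dx;
    }

    let rows_iter = y
        .chunks_exact(cols)
        .zip(dy.chunks_exact(cols))
        .zip(dx.chunks_exact_mut(cols));

    for ((y_row, dy_row), dx_row) in rows_iter {
        // Per-row reduction: sum_j y[b, j] * dy[b, j].
        let dot: f32 = y_row.iter().zip(dy_row).map(|(a, b)| a * b).sum();

        // Elementwise step: dx[b, i] = y[b, i] * (dy[b, i] - dot).
        for ((out, &yi), &dyi) in dx_row.iter_mut().zip(y_row).zip(dy_row) {
            *out = yi * (dyi - dot);
        }
    }
    dx
}
```

For a single row with y = [0.5, 0.5] and dy = [1.0, 0.0], the row sum is 0.5, so the sketch yields dx = [0.25, -0.25]. Because each row of y sums to 1, a dy that is constant along a row produces a zero gradient for that row.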

Companion to crate::ops::softmax::dispatch_softmax. Used by reverse-mode autograd in hf2q’s calibrate module (ADR-020 Track 1).

The threadgroup-per-row layout matches the softmax forward pass: one threadgroup processes one full row, using a tree reduction over the columns to compute the row sum Σ_j y[b, j] · dy[b, j].
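
The tree reduction over one row can be simulated on the CPU as follows. This is a sketch only: the names `tree_reduce_sum` and `row_dot_tree`, the power-of-two thread count, and the strided column assignment are assumptions for illustration, and the actual kernel's threadgroup size and memory layout are not described here.

```rust
/// Hypothetical simulation of a threadgroup tree reduction.
/// `partials` stands in for threadgroup-shared memory, one slot per thread.
pub fn tree_reduce_sum(partials: &mut [f32]) -> f32 {
    let n = partials.len();
    assert!(n.is_power_of_two(), "this sketch assumes a power-of-two thread count");

    let mut stride = n / 2;
    while stride > 0 {
        // On a GPU, the threads with tid < stride would run this in parallel.
        for tid in 0..stride {
            partials[tid] += partials[tid + stride];
        }
        // A real kernel would need a threadgroup barrier here before the next step.
        stride /= 2;
    }
    partials[0]
}

/// Computes sum_j y_row[j] * dy_row[j] for one row using `threads` simulated threads.
pub fn row_dot_tree(y_row: &[f32], dy_row: &[f32], threads: usize) -> f32 {
    assert_eq!(y_row.len(), dy_row.len(), "row lengths must match");

    let mut partials = vec![0.0f32; threads];
    for (tid, slot) in partials.iter_mut().enumerate() {
        // Each simulated thread accumulates columns tid, tid + threads, ...
        let mut acc = 0.0f32;
        let mut j = tid;
        while j < y_row.len() {
            acc += y_row[j] * dy_row[j];
            j += threads;
        }
        *slot = acc;
    }
    tree_reduce_sum(&mut partials)
}
```

The reduction takes log2(threads) steps rather than a serial pass over the partial sums. Since floating-point addition is not associative, the result can differ in the last bits from a left-to-right sum over the same row.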

Functions

dispatch_softmax_backward
Encode the softmax backward kernel.
register