pub fn dispatch_softmax_backward(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    y: &MlxBuffer,
    dy: &MlxBuffer,
    dx: &MlxBuffer,
    params_buf: &MlxBuffer,
    rows: u32,
    cols: u32,
) -> Result<()>
Encode the softmax backward kernel, which writes the gradient with respect to the softmax input into dx.
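The kernel's exact math is not spelled out in these docs. Assuming it implements the standard softmax vector-Jacobian product, computed independently for each row as dx = y ⊙ (dy − Σ y ⊙ dy), a CPU reference like the following could be used to check GPU results. The function name softmax_backward_reference is hypothetical and not part of this crate.

```rust
/// Hypothetical CPU reference for softmax backward, assuming the standard
/// vector-Jacobian product per row:
///     dx[i] = y[i] * (dy[i] - sum_j(y[j] * dy[j]))
/// All slices are row-major [rows, cols] f32, matching the argument docs.
fn softmax_backward_reference(y: &[f32], dy: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(y.len(), rows * cols, "y must hold rows * cols values");
    assert_eq!(dy.len(), rows * cols, "dy must hold rows * cols values");
    let mut dx = vec![0.0f32; rows * cols];
    for r in 0..rows {
        let start = r * cols;
        let y_row = &y[start..start + cols];
        let dy_row = &dy[start..start + cols];
        // One reduction per row; on the GPU each row gets its own threadgroup.
        let dot: f32 = y_row.iter().zip(dy_row).map(|(a, b)| a * b).sum();
        for c in 0..cols {
            dx[start + c] = y_row[c] * (dy_row[c] - dot);
        }
    }
    dx
}
```

For a single row with y = [0.5, 0.5] (the softmax of [0, 0]) and dy = [1, 0], the row dot product is 0.5 and the reference returns [0.25, -0.25]. When each row of y sums to 1, each row of dx sums to 0, which is a cheap sanity check on kernel output.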
Arguments
- encoder: Command encoder.
- registry: Kernel registry (must have the softmax_backward source registered).
- device: Metal device.
- y: Forward softmax output [rows, cols], f32.
- dy: Upstream gradient [rows, cols], f32.
- dx: Output gradient [rows, cols], f32 (must be pre-allocated).
- params_buf: Params buffer containing [cols, 0] as f32.
- rows: Row count (one threadgroup per row).
- cols: Column count.
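The params buffer layout is documented only as [cols, 0] as f32. A minimal sketch of producing those 8 bytes on the host, assuming tight packing (no padding) and native byte order; the helper name softmax_backward_params_bytes is hypothetical, and copying the bytes into params_buf is not shown because that MlxBuffer API is not documented here.

```rust
/// Hypothetical helper: packs [cols, 0] as two f32 values (8 bytes total)
/// in native byte order, matching the layout described for params_buf.
fn softmax_backward_params_bytes(cols: u32) -> [u8; 8] {
    // f32 represents every integer up to 2^24 exactly; larger cols would round.
    debug_assert!(cols <= 1 << 24, "cols does not round-trip exactly through f32");
    let params: [f32; 2] = [cols as f32, 0.0];
    let mut bytes = [0u8; 8];
    bytes[..4].copy_from_slice(&params[0].to_ne_bytes());
    bytes[4..].copy_from_slice(&params[1].to_ne_bytes());
    bytes
}
```

Because cols travels as f32 rather than u32, column counts above 16,777,216 (2^24) would no longer be represented exactly. That is far beyond typical softmax widths, but it matters if the same layout is reused elsewhere.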
Errors
Returns MlxError::InvalidArgument if shapes are inconsistent or
any buffer is too small.
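The two documented error conditions can be made concrete with a sketch of the checks, assuming each of y, dy, and dx must hold at least rows × cols f32 values and params_buf at least its two f32 params. ValidationError is a hypothetical stand-in for MlxError, and the function takes byte lengths rather than MlxBuffer values because how to query an MlxBuffer's size is not shown in these docs.

```rust
/// Hypothetical stand-in for MlxError; only the documented variant is modeled.
#[derive(Debug, PartialEq)]
enum ValidationError {
    InvalidArgument(String),
}

/// Hypothetical pre-dispatch checks mirroring the documented error conditions.
/// All lengths are in bytes; each f32 element occupies 4 bytes.
fn validate_softmax_backward(
    y_bytes: usize,
    dy_bytes: usize,
    dx_bytes: usize,
    params_bytes: usize,
    rows: u32,
    cols: u32,
) -> Result<(), ValidationError> {
    let f32_size = std::mem::size_of::<f32>();
    // y, dy and dx all share the shape [rows, cols]; guard against overflow.
    let needed = (rows as usize)
        .checked_mul(cols as usize)
        .and_then(|n| n.checked_mul(f32_size))
        .ok_or_else(|| ValidationError::InvalidArgument("rows * cols overflows".to_string()))?;
    for &(name, len) in [("y", y_bytes), ("dy", dy_bytes), ("dx", dx_bytes)].iter() {
        if len < needed {
            return Err(ValidationError::InvalidArgument(format!(
                "{} buffer is {} bytes, need at least {}",
                name, len, needed
            )));
        }
    }
    // params_buf carries [cols, 0] as two f32 values.
    if params_bytes < 2 * f32_size {
        return Err(ValidationError::InvalidArgument(format!(
            "params buffer is {} bytes, need at least {}",
            params_bytes,
            2 * f32_size
        )));
    }
    Ok(())
}
```

With rows = 2 and cols = 4, each data buffer needs at least 32 bytes and the params buffer at least 8; a 16-byte dx or a 4-byte params buffer would be rejected before any work is encoded.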