Function dispatch_softmax_backward 

pub fn dispatch_softmax_backward(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    y: &MlxBuffer,
    dy: &MlxBuffer,
    dx: &MlxBuffer,
    params_buf: &MlxBuffer,
    rows: u32,
    cols: u32,
) -> Result<()>

Encode the softmax backward kernel.
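
For reference, the gradient of a row-wise softmax y = softmax(x), given the upstream gradient dy, is dx_j = y_j * (dy_j - sum_k(dy_k * y_k)) for each row. The sketch below is a hypothetical CPU reference that computes this on row-major [rows, cols] f32 slices. softmax_backward_ref is not part of this crate, and the sketch assumes the GPU kernel implements this standard formula. A function like this can be useful for checking the kernel's dx output in tests.

```rust
/// Hypothetical CPU reference (not part of this crate's API) for softmax
/// backward on a row-major [rows, cols] f32 layout.
/// For each row: dx[j] = y[j] * (dy[j] - sum_k(dy[k] * y[k])).
fn softmax_backward_ref(y: &[f32], dy: &[f32], dx: &mut [f32], rows: usize, cols: usize) {
    assert_eq!(y.len(), rows * cols);
    assert_eq!(dy.len(), rows * cols);
    assert_eq!(dx.len(), rows * cols);
    for r in 0..rows {
        let start = r * cols;
        let y_row = &y[start..start + cols];
        let dy_row = &dy[start..start + cols];
        // Row-wise dot product <dy, y>. The documented dispatch uses one
        // threadgroup per row, presumably to perform this reduction.
        let dot: f32 = y_row.iter().zip(dy_row).map(|(a, b)| a * b).sum();
        for c in 0..cols {
            dx[start + c] = y_row[c] * (dy_row[c] - dot);
        }
    }
}

fn main() {
    // softmax([0, 0]) = [0.5, 0.5]; gradient flows only into the first output.
    let y = [0.5f32, 0.5];
    let dy = [1.0f32, 0.0];
    let mut dx = [0.0f32; 2];
    softmax_backward_ref(&y, &dy, &mut dx, 1, 2);
    println!("{:?}", dx); // prints [0.25, -0.25]
}
```

Because each row of y sums to 1, each row of dx produced by this formula sums to (approximately) zero, which makes a cheap sanity check on GPU output.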

§Arguments

  • encoder — Command encoder.
  • registry — Kernel registry (must have softmax_backward source registered).
  • device — Metal device.
  • y — Forward softmax output [rows, cols], f32.
  • dy — Upstream gradient [rows, cols], f32.
  • dx — Output gradient [rows, cols], f32 (must be pre-allocated).
  • params_buf — Params buffer containing [cols, 0] as f32.
  • rows — Row count (one threadgroup per row).
  • cols — Column count.
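
The params_buf layout described above ([cols, 0] stored as f32) could be produced on the host as in the following sketch. softmax_backward_params is a hypothetical helper, not part of this crate. It assumes the buffer holds exactly two consecutive host-native f32 values, and it leaves out how the bytes are copied into the MlxBuffer.

```rust
/// Hypothetical helper (not part of this crate): packs [cols, 0] as two
/// consecutive f32 values in host-native byte order, 8 bytes total.
fn softmax_backward_params(cols: u32) -> [u8; 8] {
    // Note: cols as f32 is exact only for cols <= 2^24 (16_777_216).
    let vals = [cols as f32, 0.0f32];
    let mut out = [0u8; 8];
    out[..4].copy_from_slice(&vals[0].to_ne_bytes());
    out[4..].copy_from_slice(&vals[1].to_ne_bytes());
    out
}

fn main() {
    let bytes = softmax_backward_params(4096);
    let cols = f32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
    let pad = f32::from_ne_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
    println!("cols = {cols}, second slot = {pad}");
}
```

Storing cols as f32 means very wide rows would lose precision. If that matters, it is worth confirming how the kernel reads this value.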

§Errors

Returns MlxError::InvalidArgument if shapes are inconsistent or any buffer is too small.
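
The following is a minimal sketch of the kind of pre-dispatch check this error implies. It works on buffer sizes in bytes. SketchError and check_softmax_backward_sizes are hypothetical stand-ins for MlxError and the crate's internal validation. Rejecting rows == 0 or cols == 0 is an assumption of this sketch, not documented behavior.

```rust
/// Hypothetical stand-in for MlxError::InvalidArgument.
#[derive(Debug, PartialEq)]
enum SketchError {
    InvalidArgument(String),
}

/// Hypothetical size check: y, dy and dx must each hold rows * cols f32
/// values, and params must hold at least two f32 values.
fn check_softmax_backward_sizes(
    y_bytes: usize,
    dy_bytes: usize,
    dx_bytes: usize,
    params_bytes: usize,
    rows: u32,
    cols: u32,
) -> Result<(), SketchError> {
    // Assumption of this sketch: empty shapes are rejected.
    if rows == 0 || cols == 0 {
        return Err(SketchError::InvalidArgument("rows and cols must be non-zero".into()));
    }
    // Required bytes per tensor, guarding against overflow.
    let needed = (rows as usize)
        .checked_mul(cols as usize)
        .and_then(|n| n.checked_mul(std::mem::size_of::<f32>()))
        .ok_or_else(|| SketchError::InvalidArgument("rows * cols overflows".into()))?;
    for (name, len) in [("y", y_bytes), ("dy", dy_bytes), ("dx", dx_bytes)] {
        if len < needed {
            return Err(SketchError::InvalidArgument(format!(
                "{name}: {len} bytes < {needed} required"
            )));
        }
    }
    if params_bytes < 2 * std::mem::size_of::<f32>() {
        return Err(SketchError::InvalidArgument("params buffer too small".into()));
    }
    Ok(())
}

fn main() {
    // 2 rows x 4 cols of f32 = 32 bytes per tensor; params = 8 bytes.
    println!("{:?}", check_softmax_backward_sizes(32, 32, 32, 8, 2, 4));
    println!("{:?}", check_softmax_backward_sizes(32, 32, 16, 8, 2, 4));
}
```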