
Module softmax

Numerically stable softmax GPU dispatch.

Computes softmax along the last dimension of a 2D tensor using the subtract-max trick for numerical stability. All accumulations use f32 even when inputs are f16 to prevent overflow.
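The following is a CPU reference for the math the kernels compute. It is a minimal sketch that applies the subtract-max trick row by row to a row-major [rows, cols] buffer. `softmax_rows_reference` is a hypothetical helper and is not part of this module. Stable Rust has no built-in f16 type, so the sketch takes f32 input and accumulates in f32. It therefore illustrates the max subtraction, not the kernels' f16-to-f32 promotion.

```rust
/// Hypothetical CPU reference: softmax over the last dimension of a
/// row-major [rows, cols] buffer (not this module's API).
fn softmax_rows_reference(input: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(input.len(), rows * cols, "input must hold rows * cols elements");
    let mut out = vec![0.0f32; input.len()];
    for r in 0..rows {
        let row = &input[r * cols..(r + 1) * cols];
        let dst = &mut out[r * cols..(r + 1) * cols];
        // Subtract-max trick: shift the row so its largest exponent is exp(0) = 1.
        let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        // Exponentiate the shifted values and accumulate the normalizer in f32.
        let mut sum = 0.0f32;
        for (d, &x) in dst.iter_mut().zip(row) {
            let e = (x - max).exp();
            *d = e;
            sum += e;
        }
        // Normalize so each row sums to 1.
        for d in dst.iter_mut() {
            *d /= sum;
        }
    }
    out
}

fn main() {
    // Row 0 has logits large enough that a naive exp() would overflow f32.
    let x = [1000.0f32, 1000.0, 1.0, 2.0, 3.0, 4.0];
    let y = softmax_rows_reference(&x, 2, 3);
    println!("{:?}", y);
}
```

Without the max subtraction, `1000.0f32.exp()` is infinite and the first row would become NaN (infinity divided by infinity). With the subtraction, every exponent is at most exp(0) = 1, so the sum stays finite.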

Statics

SOFTMAX_SHADER_SOURCE
MSL source for the softmax kernels (embedded at compile time).

Functions

dispatch_softmax
Dispatch a softmax operation on the GPU.
register
Register softmax shader sources with the given kernel registry.
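This page does not show the signatures of `register` or `dispatch_softmax`. The sketch below is only a minimal illustration, under stated assumptions, of the two roles they describe: making embedded shader text available by name, and mapping a row-wise softmax onto a GPU grid. Every name in it is hypothetical and is not this crate's API. That includes `KernelRegistry`, `add_source`, `register_softmax`, `PLACEHOLDER_SOFTMAX_MSL`, `SoftmaxLaunch`, and `plan_softmax_launch`. The "one threadgroup per row" launch layout is likewise an assumed, common choice, not a documented one.

```rust
use std::collections::HashMap;

/// Hypothetical kernel registry: maps a source name to embedded shader text.
#[derive(Default)]
struct KernelRegistry {
    sources: HashMap<&'static str, &'static str>,
}

impl KernelRegistry {
    /// Stores `source` under `name`, returning any source it replaced.
    fn add_source(&mut self, name: &'static str, source: &'static str) -> Option<&'static str> {
        self.sources.insert(name, source)
    }

    /// Looks up a previously registered source by name.
    fn get_source(&self, name: &str) -> Option<&'static str> {
        self.sources.get(name).copied()
    }
}

/// Placeholder standing in for the embedded MSL; not the real kernel text.
static PLACEHOLDER_SOFTMAX_MSL: &str = "// softmax kernels (MSL) would go here";

/// Hypothetical counterpart of `register`: makes the softmax source available by name.
fn register_softmax(registry: &mut KernelRegistry) {
    registry.add_source("softmax", PLACEHOLDER_SOFTMAX_MSL);
}

/// Hypothetical launch plan for a [rows, cols] tensor: one threadgroup per row,
/// with that group's threads cooperating on the row's max and sum reductions.
#[derive(Debug, PartialEq)]
struct SoftmaxLaunch {
    threadgroups: usize,
    threads_per_group: usize,
}

fn plan_softmax_launch(shape: &[usize], max_threads: usize) -> Result<SoftmaxLaunch, String> {
    // Softmax here is defined only for 2D tensors, reducing over the last dimension.
    let (rows, cols) = match shape {
        [r, c] => (*r, *c),
        _ => return Err(format!("expected a 2D tensor, got {} dims", shape.len())),
    };
    if cols == 0 || max_threads == 0 {
        return Err("last dimension and thread limit must be non-zero".to_string());
    }
    // Largest power of two <= min(cols, max_threads), so a tree reduction halves evenly.
    let limit = cols.min(max_threads);
    let threads = 1usize << (usize::BITS - 1 - limit.leading_zeros());
    Ok(SoftmaxLaunch { threadgroups: rows, threads_per_group: threads })
}

fn main() {
    let mut registry = KernelRegistry::default();
    register_softmax(&mut registry);
    assert!(registry.get_source("softmax").is_some());

    let plan = plan_softmax_launch(&[4, 3000], 1024).unwrap();
    println!("{:?}", plan);
}
```

Rounding the thread count down to a power of two keeps a shared-memory tree reduction simple. Each thread then strides over the columns of its row when `cols` exceeds the group size. A real dispatcher may choose differently, for example by using SIMD-group reductions.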