Numerically stable softmax GPU dispatch.
Computes softmax along the last dimension of a 2D tensor, using the subtract-max trick for numerical stability. All accumulations use f32, even when inputs are f16, to prevent overflow.
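A CPU reference of the same row-wise, subtract-max computation can help when checking GPU output. The sketch below is illustrative and is not the module's kernel. `softmax_rows_reference` is a hypothetical name, and the sketch assumes a row-major `rows x cols` layout. Stable Rust has no built-in f16 type, so it takes f32 inputs and accumulates in f32.

```rust
/// Hypothetical CPU reference: numerically stable softmax over the last
/// dimension of a row-major `rows x cols` matrix (not the GPU kernel).
fn softmax_rows_reference(input: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(input.len(), rows * cols, "input length must equal rows * cols");
    let mut out = vec![0.0f32; input.len()];
    for r in 0..rows {
        let row = &input[r * cols..(r + 1) * cols];
        let dst = &mut out[r * cols..(r + 1) * cols];
        // Subtract the row maximum so every exponent is <= 0 and exp() cannot overflow.
        let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        // Accumulate the row sum in f32.
        let mut sum = 0.0f32;
        for (d, &x) in dst.iter_mut().zip(row) {
            let e = (x - max).exp();
            *d = e;
            sum += e;
        }
        // Normalize so each row sums to 1.
        for d in dst.iter_mut() {
            *d /= sum;
        }
    }
    out
}

fn main() {
    // A naive exp(1000.0) is +inf in f32; the shifted form stays finite.
    let x = [1000.0f32, 1001.0, 1002.0, 0.0, 0.0, 0.0];
    let y = softmax_rows_reference(&x, 2, 3);
    println!("{:?}", y);
}
```

Subtracting the maximum does not change the result, because softmax is invariant to adding a constant to every element of a row. It only keeps the exponents in a range f32 can represent.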
Statics
- SOFTMAX_SHADER_SOURCE - MSL source for the softmax kernels (embedded at compile time).
Functions
- dispatch_softmax - Dispatch a softmax operation on the GPU.
- register - Register softmax shader sources with the given kernel registry.
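The relationship between the embedded shader source and a registration function can be sketched as follows. Everything here is hypothetical: `KernelRegistry`, `register_source`, `source`, and the `"softmax"` key are illustrative names and do not come from this crate. The real registry's type and methods are not documented here and may differ. The placeholder string stands in for the real MSL, which such a static would typically embed at compile time, for example with `include_str!`.

```rust
use std::collections::HashMap;

/// Hypothetical registry mapping a shader name to its source text.
struct KernelRegistry {
    sources: HashMap<&'static str, &'static str>,
}

impl KernelRegistry {
    fn new() -> Self {
        Self { sources: HashMap::new() }
    }

    /// Store (or overwrite) the source registered under `name`.
    fn register_source(&mut self, name: &'static str, source: &'static str) {
        self.sources.insert(name, source);
    }

    /// Look up a previously registered source, if any.
    fn source(&self, name: &str) -> Option<&'static str> {
        self.sources.get(name).copied()
    }
}

// Placeholder for the real embedded MSL; not the actual kernel text.
static SOFTMAX_SHADER_SOURCE: &str = "/* MSL softmax kernels would go here */";

/// Hypothetical counterpart of `register`: adds the softmax source to the registry.
fn register(registry: &mut KernelRegistry) {
    registry.register_source("softmax", SOFTMAX_SHADER_SOURCE);
}

fn main() {
    let mut registry = KernelRegistry::new();
    register(&mut registry);
    println!("registered: {}", registry.source("softmax").is_some());
}
```

With this split, a dispatch path can look the source up by name before compiling it into a pipeline, rather than hard-coding the shader text at each call site.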