Temperature-scaled softmax + categorical sample, entirely on GPU.
For stochastic decoding (temperature > 0), this replaces the 1 MB GPU→CPU logits readback with an 8-byte readback: the sampled token index and its log-probability.
The kernel runs three parallel passes (max, exp-sum, normalize) using threadgroup reductions, then a sequential CDF scan by thread 0 to draw the sample.
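The steps above can be sketched as a CPU reference in Rust, which may be useful for checking the kernel's output. This is an illustration only: the function name `softmax_sample_reference`, its signature, and the caller-supplied uniform draw `u` are assumptions, not part of this module's API, and the real kernel performs the first three passes in parallel across a threadgroup.

```rust
/// Hypothetical CPU reference for the kernel's steps (not part of this crate).
/// `u` is a uniform random draw in [0, 1) supplied by the caller.
/// Returns (sampled token index, log-probability of that token).
fn softmax_sample_reference(logits: &[f32], temperature: f32, u: f32) -> (u32, f32) {
    assert!(!logits.is_empty(), "logits must not be empty");
    assert!(temperature > 0.0, "stochastic decoding needs temperature > 0");

    // Pass 1: maximum of the temperature-scaled logits (for numerical stability).
    let max = logits
        .iter()
        .map(|&x| x / temperature)
        .fold(f32::NEG_INFINITY, f32::max);

    // Pass 2: sum of exp(scaled logit - max).
    let sum: f32 = logits.iter().map(|&x| (x / temperature - max).exp()).sum();

    // Pass 3 (normalize) fused with the sequential CDF scan:
    // pick the first index whose cumulative probability exceeds `u`.
    let mut cdf = 0.0f32;
    let mut chosen = logits.len() - 1; // fallback if rounding keeps cdf below u
    for (i, &x) in logits.iter().enumerate() {
        cdf += (x / temperature - max).exp() / sum;
        if u < cdf {
            chosen = i;
            break;
        }
    }

    // Log-probability of the chosen token under the softmax.
    let logprob = logits[chosen] / temperature - max - sum.ln();
    (chosen as u32, logprob)
}
```

With two equal logits and `temperature = 1.0`, each token has probability 0.5, so a draw of `u = 0.25` selects index 0 and `u = 0.75` selects index 1, both with log-probability `-ln 2`.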
Statics§
- SOFTMAX_SAMPLE_SHADER_SOURCE - MSL source for the softmax_sample kernel (embedded at compile time).
Functions§
- dispatch_softmax_sample_f32 - Dispatch a temperature-scaled softmax + categorical sample on the GPU.
- register - Register the softmax_sample shader source with the given kernel registry.