
Module softmax_sample

Temperature-scaled softmax + categorical sample, entirely on GPU.

For stochastic decoding (temperature > 0), this replaces the 1 MB GPU→CPU readback of the full logits vector with an 8-byte readback: the sampled token index and its log-probability.
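
To illustrate the 8-byte result, here is a minimal sketch. It assumes the kernel writes a little-endian `u32` token index followed by an `f32` log-probability. The `SampleResult` type, its field order, and `from_bytes` are hypothetical, not this module's API.

```rust
/// Hypothetical layout of the 8-byte sample result:
/// a u32 token index followed by an f32 log-probability.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SampleResult {
    pub token: u32,
    pub logprob: f32,
}

impl SampleResult {
    /// Decode the raw bytes read back from the GPU result buffer.
    /// Assumes little-endian byte order (Apple silicon and x86 Macs are both little-endian).
    pub fn from_bytes(bytes: [u8; 8]) -> Self {
        let token = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        let logprob = f32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
        SampleResult { token, logprob }
    }
}

// With #[repr(C)] and two 4-byte fields, the struct is exactly 8 bytes.
const _: () = assert!(std::mem::size_of::<SampleResult>() == 8);
```

For scale: with a hypothetical 262,144-entry vocabulary, the dense `f32` logits readback is 262,144 × 4 = 1,048,576 bytes (1 MiB), compared with 8 bytes here.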

The kernel runs three parallel passes (max, exp-sum, normalize) using threadgroup reductions, then a sequential CDF scan by thread 0 to draw the sample.
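
The steps above can be mirrored by a sequential CPU reference, which is useful for checking the kernel's output. This is a minimal sketch, not this module's implementation. `softmax_sample_ref` is a hypothetical name, and the uniform draw `u` is passed in by the caller because this page does not describe the kernel's random-number source. The three passes run serially here instead of as threadgroup reductions.

```rust
/// Hypothetical CPU reference: temperature-scaled softmax followed by an
/// inverse-CDF categorical draw. `u` is a caller-supplied uniform in [0, 1).
/// Returns (token index, log-probability of that token).
pub fn softmax_sample_ref(logits: &[f32], temperature: f32, u: f32) -> (u32, f32) {
    assert!(!logits.is_empty(), "logits must be non-empty");
    assert!(temperature > 0.0, "stochastic path requires temperature > 0");
    let inv_t = 1.0 / temperature;

    // Pass 1: max of the scaled logits (subtracted later for numerical stability).
    let max = logits
        .iter()
        .map(|&x| x * inv_t)
        .fold(f32::NEG_INFINITY, f32::max);

    // Pass 2: sum of exp(scaled - max).
    let sum: f32 = logits.iter().map(|&x| (x * inv_t - max).exp()).sum();

    // Pass 3: normalize into probabilities.
    let probs: Vec<f32> = logits
        .iter()
        .map(|&x| (x * inv_t - max).exp() / sum)
        .collect();

    // Sequential CDF scan: pick the first index whose cumulative probability exceeds u.
    let mut cdf = 0.0f32;
    let mut idx = probs.len() - 1; // fallback if rounding leaves the final cdf below u
    for (i, &p) in probs.iter().enumerate() {
        cdf += p;
        if u < cdf {
            idx = i;
            break;
        }
    }

    // log p_i = (x_i / T - max) - ln(sum), computed without taking ln of p_i.
    let logprob = logits[idx] * inv_t - max - sum.ln();
    (idx as u32, logprob)
}
```

For example, `softmax_sample_ref(&[0.0, 0.0], 1.0, 0.25)` selects token 0 with log-probability −ln 2 ≈ −0.693. Subtracting the maximum before exponentiating keeps `exp` from overflowing on large scaled logits, and computing the log-probability as `x_i/T − max − ln(sum)` avoids taking the log of a probability that may have underflowed to zero.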

Statics

SOFTMAX_SAMPLE_SHADER_SOURCE
MSL source for the softmax_sample kernel (embedded at compile time).

Functions

dispatch_softmax_sample_f32
Dispatch a temperature-scaled softmax + categorical sample on the GPU.
register
Register the softmax_sample shader source with the given kernel registry.
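
To make the `register` entry concrete, here is a sketch built on an assumed registry shape. This page does not document the real registry type, so `KernelRegistry`, `register_source`, `source`, and the `register` signature shown are hypothetical stand-ins. `PLACEHOLDER_SHADER_SOURCE` is a placeholder for the embedded `SOFTMAX_SAMPLE_SHADER_SOURCE`.

```rust
use std::collections::HashMap;

/// Hypothetical kernel registry: maps a kernel name to its MSL source
/// so it can be compiled into a pipeline later.
#[derive(Default)]
pub struct KernelRegistry {
    sources: HashMap<&'static str, &'static str>,
}

impl KernelRegistry {
    /// Record (or overwrite) the source for a named kernel.
    pub fn register_source(&mut self, name: &'static str, source: &'static str) {
        self.sources.insert(name, source);
    }

    /// Look up a kernel's source by name.
    pub fn source(&self, name: &str) -> Option<&'static str> {
        self.sources.get(name).copied()
    }
}

/// Placeholder standing in for the real embedded MSL source.
pub const PLACEHOLDER_SHADER_SOURCE: &str = "kernel void softmax_sample(/* ... */) {}";

/// Hypothetical sketch of what `register` might do with such a registry.
pub fn register(registry: &mut KernelRegistry) {
    registry.register_source("softmax_sample", PLACEHOLDER_SHADER_SOURCE);
}
```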