pub fn dispatch_softmax_sample_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    logits: &MlxBuffer,
    scratch: &MlxBuffer,
    out_token: &MlxBuffer,
    out_logprob: &MlxBuffer,
    params_buf: &MlxBuffer,
    n_elements: u32,
    temperature: f32,
    random_val: f32,
) -> Result<()>
Dispatch a temperature-scaled softmax + categorical sample on the GPU.
Computes softmax(logits / temperature) entirely on the GPU, then samples
one token index using the provided uniform random value. Only 8 bytes
(token_id u32 + logprob f32) are transferred back to the CPU.
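
For orientation, here is a CPU reference sketch of the same computation in plain Rust. It is not the GPU kernel. The function name `softmax_sample_reference`, the max-subtraction step, the inverse-CDF walk, and the last-index fallback are all assumptions made for illustration, and the actual kernel may reduce and scan in a different order.

```rust
/// CPU reference sketch (hypothetical helper, not part of the crate).
/// Returns (sampled token index, log-probability of that token).
fn softmax_sample_reference(logits: &[f32], temperature: f32, random_val: f32) -> (u32, f32) {
    assert!(!logits.is_empty(), "need at least one logit");

    // Subtract the largest scaled logit so exp() cannot overflow.
    let max = logits
        .iter()
        .map(|&x| x / temperature)
        .fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits
        .iter()
        .map(|&x| (x / temperature - max).exp())
        .collect();
    let sum: f32 = exps.iter().sum();

    // Inverse-CDF walk: pick the first index whose cumulative
    // probability exceeds the uniform random value.
    let mut cumulative = 0.0f32;
    let mut chosen = logits.len() - 1; // fallback if rounding keeps the total below random_val
    for (i, &e) in exps.iter().enumerate() {
        cumulative += e / sum;
        if random_val < cumulative {
            chosen = i;
            break;
        }
    }

    (chosen as u32, (exps[chosen] / sum).ln())
}
```

A reference like this can be useful for checking GPU results on small vocabularies, allowing for floating-point differences from a different summation order.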
§Arguments
- encoder - Command encoder to record the dispatch into.
- registry - Kernel registry (must have softmax_sample_f32 registered).
- device - Metal device for pipeline compilation.
- logits - Input logits buffer [n_elements] (f32).
- scratch - Scratch buffer [n_elements] (f32) used for intermediate probability values. May be a transient allocation; must not alias logits.
- out_token - Output buffer [1] (u32): sampled token index.
- out_logprob - Output buffer [1] (f32): log-probability of the sampled token.
- params_buf - Params buffer [3] (f32) containing: [n_elements as f32, temperature, random_val]
- n_elements - Vocabulary size (number of logits).
- temperature - Sampling temperature (must be > 0.0).
- random_val - Uniform random value in [0, 1) for the categorical sample.
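
The params buffer layout described above can be sketched as a small packing helper. `pack_params` is a hypothetical name used only for illustration; how the three values are actually written into an MlxBuffer is not shown here because that API is not documented in this section.

```rust
/// Hypothetical helper: builds the three f32 values in the order the
/// params buffer expects: [n_elements as f32, temperature, random_val].
fn pack_params(n_elements: u32, temperature: f32, random_val: f32) -> [f32; 3] {
    // Note: f32 represents integers exactly only up to 2^24 (16,777,216),
    // so larger element counts would be rounded by this cast.
    [n_elements as f32, temperature, random_val]
}
```

The resulting array occupies 12 bytes (three 4-byte floats), which matches the [3] (f32) shape given for params_buf.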
§Errors
Returns MlxError::InvalidArgument if:
- n_elements is 0.
- temperature is not positive.
- random_val is not in [0, 1).
- Buffer sizes are inconsistent.
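
The scalar checks listed above can be sketched as follows. The `ArgError` enum and `validate_sample_args` function are stand-ins invented for this example; the real MlxError::InvalidArgument variant may carry different data, and the buffer-size check is omitted because it depends on the MlxBuffer API.

```rust
/// Stand-in error type for illustration only (not the crate's MlxError).
#[derive(Debug, PartialEq)]
enum ArgError {
    InvalidArgument(&'static str),
}

/// Hypothetical pre-dispatch validation of the scalar arguments.
fn validate_sample_args(
    n_elements: u32,
    temperature: f32,
    random_val: f32,
) -> Result<(), ArgError> {
    if n_elements == 0 {
        return Err(ArgError::InvalidArgument("n_elements must be > 0"));
    }
    // Written as a negated comparison so that NaN is rejected too.
    if !(temperature > 0.0) {
        return Err(ArgError::InvalidArgument("temperature must be > 0.0"));
    }
    // Half-open range: 0.0 is accepted, 1.0 is not; NaN fails the test.
    if !(0.0..1.0).contains(&random_val) {
        return Err(ArgError::InvalidArgument("random_val must be in [0, 1)"));
    }
    Ok(())
}
```

Checking these values on the CPU before encoding means a bad argument surfaces as an error result rather than as undefined output from the kernel.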