
Function dispatch_softmax_sample_f32

pub fn dispatch_softmax_sample_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    logits: &MlxBuffer,
    scratch: &MlxBuffer,
    out_token: &MlxBuffer,
    out_logprob: &MlxBuffer,
    params_buf: &MlxBuffer,
    n_elements: u32,
    temperature: f32,
    random_val: f32,
) -> Result<()>

Dispatch a temperature-scaled softmax + categorical sample on the GPU.

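Computes softmax(logits / temperature) entirely on the GPU, then draws one token index from the resulting categorical distribution using the caller-supplied uniform random value random_val. Only 8 bytes (the u32 token ID and its f32 log-probability) are transferred back to the CPU, rather than the full n_elements-length probability vector.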
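To make the semantics concrete, here is a minimal CPU reference sketch in Rust. It assumes the kernel subtracts the maximum scaled logit for numerical stability and samples by inverse CDF (walk the cumulative distribution until it exceeds random_val times the total mass). The GPU kernel's actual method, reduction order, and tie-breaking are not documented here, so results may differ in the last bits or, near a CDF boundary, in the chosen token. The function name softmax_sample_ref is hypothetical.

```rust
/// Hypothetical CPU reference: softmax(logits / temperature) followed by
/// categorical sampling. Returns (token_id, log-probability of that token).
/// Assumption: inverse-CDF sampling, not necessarily the kernel's exact method.
fn softmax_sample_ref(logits: &[f32], temperature: f32, random_val: f32) -> (u32, f32) {
    assert!(!logits.is_empty());
    assert!(temperature > 0.0);
    assert!((0.0..1.0).contains(&random_val));

    // Softmax is invariant to subtracting a constant from every input,
    // so subtract the max scaled logit to keep exp() from overflowing.
    let max_scaled = logits
        .iter()
        .map(|&x| x / temperature)
        .fold(f32::NEG_INFINITY, f32::max);

    // Unnormalized probabilities (the role `scratch` plays on the GPU).
    let weights: Vec<f32> = logits
        .iter()
        .map(|&x| (x / temperature - max_scaled).exp())
        .collect();
    let sum: f32 = weights.iter().sum();

    // Inverse CDF: first index whose cumulative mass exceeds random_val * sum.
    let target = random_val * sum;
    let mut cumulative = 0.0f32;
    let mut token = weights.len() - 1; // fallback if rounding leaves cumulative <= target
    for (i, &w) in weights.iter().enumerate() {
        cumulative += w;
        if cumulative > target {
            token = i;
            break;
        }
    }

    // log p_i = (x_i / T - max) - ln(sum of weights)
    let logprob = (logits[token] / temperature - max_scaled) - sum.ln();
    (token as u32, logprob)
}

fn main() {
    let logits = [1.0f32, 2.0, 3.0];
    let (token, logprob) = softmax_sample_ref(&logits, 1.0, 0.5);
    println!("token={token} logprob={logprob:.4}");
}
```

A reference like this can be handy for checking GPU output against a small, fixed logits vector, keeping in mind the caveats about boundary cases above.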

§Arguments

  • encoder - Command encoder to record the dispatch into.
  • registry - Kernel registry (must have softmax_sample_f32 registered).
  • device - Metal device for pipeline compilation.
  • logits - Input logits buffer [n_elements] (f32).
  • scratch - Scratch buffer [n_elements] (f32) used for intermediate probability values. May be a transient allocation; must not alias logits.
  • out_token - Output buffer [1] (u32) — sampled token index.
  • out_logprob - Output buffer [1] (f32) — log-probability of the sampled token.
  • params_buf - Params buffer [3] (f32) containing: [n_elements as f32, temperature, random_val]
  • n_elements - Vocabulary size (number of logits).
  • temperature - Sampling temperature (must be > 0.0).
  • random_val - Uniform random value in [0, 1) used to draw the categorical sample.
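The params_buf layout above is three f32 values. A minimal sketch of preparing those bytes on the host follows, assuming a tightly packed 12-byte buffer and a little-endian host and GPU (as on current Apple platforms). pack_softmax_sample_params and uniform_from_u32 are hypothetical helpers, not part of this crate's API, and copying the bytes into an MlxBuffer is left out. Because n_elements is passed as f32, it is exact only up to 2^24 (16,777,216), which covers typical vocabulary sizes.

```rust
/// Hypothetical helper: packs [n_elements as f32, temperature, random_val]
/// into the 12-byte layout documented for `params_buf`.
/// Assumes a little-endian host and GPU.
fn pack_softmax_sample_params(n_elements: u32, temperature: f32, random_val: f32) -> [u8; 12] {
    // f32 represents every integer up to 2^24 exactly; larger counts would round.
    assert!(n_elements <= 1 << 24, "n_elements is not exactly representable as f32");
    let mut bytes = [0u8; 12];
    bytes[0..4].copy_from_slice(&(n_elements as f32).to_le_bytes());
    bytes[4..8].copy_from_slice(&temperature.to_le_bytes());
    bytes[8..12].copy_from_slice(&random_val.to_le_bytes());
    bytes
}

/// Hypothetical helper: maps a random u32 to a uniform f32 in [0, 1).
/// Keeping only the top 24 bits makes every result exactly representable;
/// the largest possible value is 1 - 2^-24, strictly below 1.0.
fn uniform_from_u32(bits: u32) -> f32 {
    (bits >> 8) as f32 / (1u32 << 24) as f32
}

fn main() {
    // Any source of random u32 values works here; this constant is arbitrary.
    let random_val = uniform_from_u32(0x9E37_79B9);
    let params = pack_softmax_sample_params(32_000, 0.7, random_val);
    println!("random_val={random_val} params={params:?}");
}
```

Taking the top 24 bits, rather than dividing a full u32 by u32::MAX, avoids rounding that could produce exactly 1.0, which the documented [0, 1) contract rejects.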

§Errors

Returns MlxError::InvalidArgument if:

  • n_elements is 0.
  • temperature is not positive.
  • random_val is not in [0, 1).
  • Buffer sizes are inconsistent with the sizes listed under §Arguments.
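The documented error conditions could be checked on the host roughly as follows. ArgError, BufferSizes, and validate_softmax_sample_args are hypothetical stand-ins (not MlxError or this crate's API), and the minimum byte sizes assume tightly packed 4-byte f32/u32 elements as listed under §Arguments.

```rust
/// Hypothetical stand-in for the crate's error type.
#[derive(Debug, PartialEq)]
enum ArgError {
    InvalidArgument(String),
}

/// Byte lengths of the buffers passed to the dispatch (hypothetical grouping).
struct BufferSizes {
    logits: usize,
    scratch: usize,
    out_token: usize,
    out_logprob: usize,
    params: usize,
}

/// Sketch of host-side checks mirroring the documented error cases.
fn validate_softmax_sample_args(
    n_elements: u32,
    temperature: f32,
    random_val: f32,
    sizes: &BufferSizes,
) -> Result<(), ArgError> {
    let invalid = |msg: &str| -> Result<(), ArgError> {
        Err(ArgError::InvalidArgument(msg.to_string()))
    };
    if n_elements == 0 {
        return invalid("n_elements must be > 0");
    }
    // Written as !(t > 0.0) so that NaN is rejected too.
    if !(temperature > 0.0) {
        return invalid("temperature must be positive");
    }
    // Range::contains is false for NaN as well.
    if !(0.0..1.0).contains(&random_val) {
        return invalid("random_val must be in [0, 1)");
    }
    // Assumption: tightly packed 4-byte elements.
    let vec_bytes = n_elements as usize * 4;
    let sizes_ok = sizes.logits >= vec_bytes
        && sizes.scratch >= vec_bytes
        && sizes.out_token >= 4
        && sizes.out_logprob >= 4
        && sizes.params >= 12;
    if !sizes_ok {
        return invalid("buffer sizes are inconsistent with n_elements");
    }
    Ok(())
}

fn main() {
    let sizes = BufferSizes {
        logits: 4 * 32_000,
        scratch: 4 * 32_000,
        out_token: 4,
        out_logprob: 4,
        params: 12,
    };
    println!("{:?}", validate_softmax_sample_args(32_000, 0.8, 0.25, &sizes));
    println!("{:?}", validate_softmax_sample_args(32_000, 0.0, 0.25, &sizes));
}
```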