

Function dispatch_argmax_f32 

pub fn dispatch_argmax_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    input: &MlxBuffer,
    out_index: &MlxBuffer,
    out_value: &MlxBuffer,
    params_buf: &MlxBuffer,
    n_elements: u32,
) -> Result<()>

Dispatch an argmax operation on the GPU.

Finds the index of the maximum element in input, writing that index to out_index and the maximum value itself to out_value. The entire reduction runs in a single Metal threadgroup. The result is just 8 bytes (a u32 index plus an f32 value) rather than the full vocab-size logits array.
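This page does not show the kernel source. The following CPU model only illustrates what a single-threadgroup argmax typically looks like. Each lane scans a strided slice of the input, and the per-lane results are then combined by a pairwise tree reduction, which is the part that would live in threadgroup memory on the GPU. Several details are assumptions, not documented behavior of argmax_f32: the function name, the power-of-two lane count, and the choice to break ties toward the smaller index. NaN handling is not modeled.

```rust
/// Size of the result produced by the kernel: one u32 index plus one f32 value.
const RESULT_BYTES: usize = std::mem::size_of::<u32>() + std::mem::size_of::<f32>();

/// Hypothetical CPU model of a single-threadgroup argmax (not the real kernel).
/// Returns (index, value) of the maximum, or None for empty input.
/// Assumption: ties resolve to the smaller index; NaN inputs are not handled specially.
fn argmax_single_threadgroup(input: &[f32], threads: usize) -> Option<(u32, f32)> {
    // Assumption: the threadgroup width is a power of two, so the tree halves cleanly.
    assert!(threads.is_power_of_two(), "threads must be a power of two");
    if input.is_empty() {
        return None;
    }

    // Phase 1: each lane keeps its own best candidate (models threadgroup memory).
    // u32::MAX marks a lane that saw no elements (possible when threads > input.len()).
    let mut idx = vec![u32::MAX; threads];
    let mut val = vec![f32::NEG_INFINITY; threads];
    for lane in 0..threads {
        let mut i = lane;
        while i < input.len() {
            // Strict '>' keeps the first occurrence within this lane's strided slice.
            if idx[lane] == u32::MAX || input[i] > val[lane] {
                idx[lane] = i as u32;
                val[lane] = input[i];
            }
            i += threads;
        }
    }

    // Phase 2: pairwise tree reduction; on the GPU each step would end with a barrier.
    let mut stride = threads / 2;
    while stride > 0 {
        for lane in 0..stride {
            let other = lane + stride;
            if idx[other] == u32::MAX {
                continue; // the partner lane saw no elements
            }
            let take_other = idx[lane] == u32::MAX
                || val[other] > val[lane]
                || (val[other] == val[lane] && idx[other] < idx[lane]);
            if take_other {
                idx[lane] = idx[other];
                val[lane] = val[other];
            }
        }
        stride /= 2;
    }
    Some((idx[0], val[0]))
}
```

For example, argmax_single_threadgroup(&[1.0, 3.0, -2.0, 3.0, 0.5], 4) yields Some((1, 3.0)). The two 3.0 entries tie, and the smaller index wins. Whatever the vocabulary size, only RESULT_BYTES (8) bytes of output are produced.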

§Arguments

  • encoder - Command encoder to record the dispatch into.
  • registry - Kernel registry (must have argmax_f32 registered).
  • device - Metal device for pipeline compilation.
  • input - Input buffer of shape [n_elements] (f32).
  • out_index - Output buffer [1] (u32) — index of the maximum element.
  • out_value - Output buffer [1] (f32) — value of the maximum element.
  • params_buf - Params buffer [1] (u32) — contains n_elements.
  • n_elements - Number of elements in input.
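The argument list fixes the shape and element type of every buffer, so the minimum byte size of each buffer follows directly. The sketch below computes those sizes and the contents of params_buf for a given n_elements. The struct and function names are hypothetical. Writing n_elements little-endian is an assumption, matching the byte order of the Apple GPUs Metal targets. How the crate actually allocates or fills these buffers is not shown here.

```rust
use std::mem::size_of;

/// Hypothetical summary of the minimum byte size each buffer needs.
#[derive(Debug, PartialEq)]
struct ArgmaxBufferSizes {
    input: usize,     // [n_elements] f32
    out_index: usize, // [1] u32
    out_value: usize, // [1] f32
    params: usize,    // [1] u32 holding n_elements
}

/// Hypothetical helper: byte sizes derived from the documented shapes.
fn argmax_buffer_sizes(n_elements: u32) -> ArgmaxBufferSizes {
    ArgmaxBufferSizes {
        input: n_elements as usize * size_of::<f32>(),
        out_index: size_of::<u32>(),
        out_value: size_of::<f32>(),
        params: size_of::<u32>(),
    }
}

/// Hypothetical helper: bytes to upload into params_buf (assumed little-endian).
fn argmax_params_bytes(n_elements: u32) -> [u8; 4] {
    n_elements.to_le_bytes()
}
```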

§Errors

Returns MlxError::InvalidArgument if:

  • n_elements is 0.
  • input element count does not match n_elements.
  • out_index or out_value element count is not 1.
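The three documented error conditions can be mirrored as a host-side pre-check. This sketch works on plain element counts instead of MlxBuffer handles. ArgmaxError is a hypothetical stand-in for MlxError::InvalidArgument, since the real error type's constructor is not shown on this page. The order in which the real function performs the checks is also an assumption.

```rust
/// Hypothetical stand-in for MlxError::InvalidArgument.
#[derive(Debug, PartialEq)]
enum ArgmaxError {
    InvalidArgument(String),
}

/// Hypothetical pre-check mirroring the documented error conditions,
/// using element counts in place of real buffers.
fn validate_argmax_args(
    input_len: usize,
    out_index_len: usize,
    out_value_len: usize,
    n_elements: u32,
) -> Result<(), ArgmaxError> {
    // Condition 1: an empty reduction has no maximum.
    if n_elements == 0 {
        return Err(ArgmaxError::InvalidArgument("n_elements must be > 0".into()));
    }
    // Condition 2: the input buffer must hold exactly n_elements f32 values.
    if input_len != n_elements as usize {
        return Err(ArgmaxError::InvalidArgument(format!(
            "input has {input_len} elements, expected {n_elements}"
        )));
    }
    // Condition 3: each output buffer holds a single scalar.
    if out_index_len != 1 || out_value_len != 1 {
        return Err(ArgmaxError::InvalidArgument(format!(
            "out_index/out_value must hold 1 element, got {out_index_len}/{out_value_len}"
        )));
    }
    Ok(())
}
```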