pub fn dispatch_argmax_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    input: &MlxBuffer,
    out_index: &MlxBuffer,
    out_value: &MlxBuffer,
    params_buf: &MlxBuffer,
    n_elements: u32,
) -> Result<()>
Dispatch an argmax operation on the GPU.
Finds the index of the maximum element in `input` and writes that index to
`out_index` and the maximum value to `out_value`. The entire reduction runs in a
single Metal threadgroup, so only 8 bytes (one u32 index and one f32 value) need to
be read back instead of the full vocab-sized logits array.
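To make the single-threadgroup reduction concrete, here is a CPU-side sketch that simulates what such a kernel typically does: each thread scans a strided slice of `input` for its local maximum, then the threads merge their candidates with a tree reduction over shared memory. This is not the actual `argmax_f32` Metal kernel. The threadgroup width (`THREADS = 256`), the tie-break (lowest index wins), and the NaN behaviour (NaNs are skipped) are assumptions for illustration, as are the names `Candidate`, `combine`, and `argmax_threadgroup_sim`.

```rust
/// Hypothetical threadgroup width; the real kernel's width is not documented here.
const THREADS: usize = 256;

/// An (index, value) candidate for the maximum.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Candidate {
    index: u32,
    value: f32,
}

/// Keep the larger value; on ties keep the lower index (assumed tie-break).
/// Any comparison with NaN is false, so a NaN candidate never replaces `a`.
fn combine(a: Candidate, b: Candidate) -> Candidate {
    if b.value > a.value || (b.value == a.value && b.index < a.index) {
        b
    } else {
        a
    }
}

/// CPU simulation of a single-threadgroup argmax reduction.
fn argmax_threadgroup_sim(input: &[f32]) -> Candidate {
    let empty = Candidate { index: u32::MAX, value: f32::NEG_INFINITY };

    // Phase 1: each simulated thread scans input[tid], input[tid + THREADS], ...
    let mut shared = [empty; THREADS];
    for (tid, slot) in shared.iter_mut().enumerate() {
        let mut i = tid;
        while i < input.len() {
            *slot = combine(*slot, Candidate { index: i as u32, value: input[i] });
            i += THREADS;
        }
    }

    // Phase 2: tree reduction, halving the active width each step.
    // On the GPU, a threadgroup barrier would separate these steps.
    let mut stride = THREADS / 2;
    while stride > 0 {
        for tid in 0..stride {
            shared[tid] = combine(shared[tid], shared[tid + stride]);
        }
        stride /= 2;
    }
    shared[0]
}

fn main() {
    // 37 and 1000 are coprime, so this is a permutation of 0..1000.
    let logits: Vec<f32> = (0..1000).map(|i| ((i * 37) % 1000) as f32).collect();
    let best = argmax_threadgroup_sim(&logits);
    println!("index = {}, value = {}", best.index, best.value);
}
```

Because `combine` orders candidates by largest value first and smallest index second, it is associative and commutative. The result therefore does not depend on the order in which threads merge, which is what lets a parallel reduction return the same index as a sequential scan.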
§Arguments
* `encoder` - Command encoder to record the dispatch into.
* `registry` - Kernel registry (must have `argmax_f32` registered).
* `device` - Metal device for pipeline compilation.
* `input` - Input buffer of shape `[n_elements]` (f32).
* `out_index` - Output buffer `[1]` (u32): index of the maximum element.
* `out_value` - Output buffer `[1]` (f32): value of the maximum element.
* `params_buf` - Params buffer `[1]` (u32) containing `n_elements`.
* `n_elements` - Number of elements in `input`.
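The argument list fixes the byte layout of the small buffers: `params_buf` holds a single u32, and the two outputs together hold the 8 bytes read back after the dispatch. A minimal host-side sketch of packing and decoding those bytes, assuming little-endian layout (as on Apple silicon and Intel Macs) and using the hypothetical helpers `pack_params` and `decode_result`:

```rust
/// Bytes for `params_buf`: one u32 holding `n_elements`.
/// Little-endian byte order is an assumption about the target hardware.
fn pack_params(n_elements: u32) -> [u8; 4] {
    n_elements.to_le_bytes()
}

/// Decode the 4 bytes of `out_index` (u32) and the 4 bytes of `out_value` (f32).
fn decode_result(index_bytes: [u8; 4], value_bytes: [u8; 4]) -> (u32, f32) {
    (u32::from_le_bytes(index_bytes), f32::from_le_bytes(value_bytes))
}

fn main() {
    // Hypothetical vocabulary size, for illustration only.
    let params = pack_params(32_000);
    // Pretend the GPU wrote index 17 and value 3.5 into the two output buffers.
    let (index, value) = decode_result(17u32.to_le_bytes(), 3.5f32.to_le_bytes());
    println!("params = {:?}, index = {}, value = {}", params, index, value);
}
```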
§Errors
Returns `MlxError::InvalidArgument` if:
* `n_elements` is 0.
* The `input` element count does not match `n_elements`.
* The `out_index` or `out_value` element count is not 1.
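The documented error conditions can be mirrored in a standalone check, which is handy for unit-testing call sites without a GPU. In this sketch `ArgmaxError` is a hypothetical stand-in for `MlxError` (the real variant's payload is not documented here), buffer element counts are passed as plain integers, and the checks run in the order listed above, which is an assumption about the real function.

```rust
/// Hypothetical stand-in for `MlxError`; only the documented variant is modeled.
#[derive(Debug, PartialEq)]
enum ArgmaxError {
    InvalidArgument(String),
}

/// Mirrors the documented preconditions of `dispatch_argmax_f32`,
/// using element counts instead of real buffers.
fn validate_argmax_args(
    n_elements: u32,
    input_len: usize,
    out_index_len: usize,
    out_value_len: usize,
) -> Result<(), ArgmaxError> {
    if n_elements == 0 {
        return Err(ArgmaxError::InvalidArgument("n_elements must be > 0".into()));
    }
    if input_len != n_elements as usize {
        return Err(ArgmaxError::InvalidArgument(format!(
            "input has {input_len} elements, expected {n_elements}"
        )));
    }
    if out_index_len != 1 || out_value_len != 1 {
        return Err(ArgmaxError::InvalidArgument(format!(
            "outputs must hold 1 element each, got {out_index_len} and {out_value_len}"
        )));
    }
    Ok(())
}

fn main() {
    println!("{:?}", validate_argmax_args(4, 4, 1, 1));
    println!("{:?}", validate_argmax_args(4, 3, 1, 1));
}
```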