Module argmax

Greedy argmax GPU dispatch: finds the index of the maximum value in a float array entirely on the GPU.

For greedy (temperature=0) decoding with vocab_size=262144, this replaces a 1 MiB GPU→CPU logits readback (262144 × 4-byte f32 values) with an 8-byte readback: the (index, value) pair. The kernel runs in a single threadgroup and uses a shared-memory tree reduction.
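The reduction described above can be sketched on the CPU as a reference model. This is not the module's kernel or its API: the function names `argmax_tree_reduce` and `better`, the threadgroup size `THREADS = 1024`, and the tie-breaking rule (lowest index wins) are all assumptions made for illustration. The sketch only shows the general shape of a single-threadgroup argmax, where each thread scans a strided slice and the per-thread results are then halved pairwise in shared memory.

```rust
// CPU reference model of a single-threadgroup argmax with tree reduction.
// Everything here is illustrative; none of these names come from the module.

/// Hypothetical threadgroup size (must be a power of two for the halving loop).
const THREADS: usize = 1024;

/// Bytes read back for a full logits buffer vs. one (index, value) pair.
const LOGITS_READBACK_BYTES: usize = 262_144 * std::mem::size_of::<f32>(); // 1 MiB
const ARGMAX_READBACK_BYTES: usize =
    std::mem::size_of::<u32>() + std::mem::size_of::<f32>(); // 8 bytes

/// Returns true if candidate `a` should replace `b`.
/// Assumed tie-break: on equal values the lower index wins.
/// NaN never compares greater, so NaN entries are ignored.
fn better(a: (u32, f32), b: (u32, f32)) -> bool {
    a.1 > b.1 || (a.1 == b.1 && a.0 < b.0)
}

fn argmax_tree_reduce(logits: &[f32]) -> Option<(u32, f32)> {
    if logits.is_empty() {
        return None;
    }

    // Phase 1: each simulated thread `tid` scans elements tid, tid + THREADS, ...
    // and writes its local best into its "shared memory" slot.
    let mut shared: Vec<(u32, f32)> = (0..THREADS)
        .map(|tid| {
            let mut best = (u32::MAX, f32::NEG_INFINITY);
            let mut i = tid;
            while i < logits.len() {
                let cand = (i as u32, logits[i]);
                if better(cand, best) {
                    best = cand;
                }
                i += THREADS;
            }
            best
        })
        .collect();

    // Phase 2: tree reduction. Each pass halves the number of active slots;
    // on a GPU a threadgroup barrier would separate the passes.
    let mut stride = THREADS / 2;
    while stride > 0 {
        for tid in 0..stride {
            let other = shared[tid + stride];
            if better(other, shared[tid]) {
                shared[tid] = other;
            }
        }
        stride /= 2;
    }

    // Slot 0 now holds the (index, value) pair that would be read back.
    Some(shared[0])
}
```

Calling `argmax_tree_reduce(&logits)` on a 262144-element slice returns the single pair that the 8-byte readback would carry, which is why the host never needs the full logits buffer for greedy decoding.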

Statics

ARGMAX_SHADER_SOURCE
MSL source for the argmax kernel (embedded at compile time).

Functions

dispatch_argmax_f32
Dispatch an argmax operation on the GPU.
register
Register argmax shader source with the given kernel registry.