Greedy argmax GPU dispatch — finds the index of the maximum value in a float array entirely on the GPU.
For greedy (temperature=0) decoding with vocab_size=262144, this replaces a 1 MB GPU→CPU logits readback (262144 × 4-byte floats) with an 8-byte readback of the (index, value) pair. The kernel runs in a single threadgroup and uses a shared-memory tree reduction.
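The reduction pattern can be illustrated with a CPU-side reference in Rust. This is only a sketch under stated assumptions, not the embedded MSL kernel: the function name `argmax_tree_reduction`, the helper `better`, the `threads` parameter (a stand-in for the threadgroup size), and the lowest-index tie-break are all hypothetical choices made for illustration.

```rust
/// Returns true if candidate `a` should replace `b`.
/// Assumption: ties on value go to the lower index.
fn better(a: (u32, f32), b: (u32, f32)) -> bool {
    a.1 > b.1 || (a.1 == b.1 && a.0 < b.0)
}

/// CPU reference for a single-threadgroup argmax with tree reduction.
/// `threads` simulates the threadgroup size and must be a power of two.
/// NaN values never win a comparison, so they are never selected.
fn argmax_tree_reduction(logits: &[f32], threads: usize) -> Option<(u32, f32)> {
    assert!(threads.is_power_of_two());
    if logits.is_empty() {
        return None;
    }

    // Phase 1: each simulated thread scans a strided slice of the input
    // and writes its local best (index, value) into "shared memory".
    let mut shared: Vec<(u32, f32)> = (0..threads)
        .map(|tid| {
            let mut best = (u32::MAX, f32::NEG_INFINITY);
            let mut i = tid;
            while i < logits.len() {
                let cand = (i as u32, logits[i]);
                if better(cand, best) {
                    best = cand;
                }
                i += threads;
            }
            best
        })
        .collect();

    // Phase 2: tree reduction. Each step halves the number of active
    // slots; on a GPU a threadgroup barrier would separate the steps.
    let mut stride = threads / 2;
    while stride > 0 {
        for tid in 0..stride {
            if better(shared[tid + stride], shared[tid]) {
                shared[tid] = shared[tid + stride];
            }
        }
        stride /= 2;
    }

    // Slot 0 now holds the winning (index, value) pair.
    Some(shared[0])
}
```

Calling `argmax_tree_reduction(&logits, 256)` on a host-side copy of the logits gives a value to compare against the GPU result in tests. The returned `(u32, f32)` pair occupies 8 bytes in Rust, which matches the readback size described above, although the actual buffer layout used by the kernel is not specified here.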
Statics§
- ARGMAX_SHADER_SOURCE - MSL source for the argmax kernel (embedded at compile time).
Functions§
- dispatch_argmax_f32 - Dispatch an argmax operation on the GPU.
- register - Register argmax shader source with the given kernel registry.