Fast Walsh-Hadamard Transform (FWHT) GPU kernel dispatch.
Applies an in-place, normalized FWHT to a flat buffer shaped
[num_heads, head_dim]. One Metal threadgroup is dispatched per head;
each threadgroup has head_dim threads that cooperate through shared
memory using the standard butterfly pattern.
The transform is normalized (scaled by 1/√head_dim overall) so that H·H = I, i.e. applying it twice returns the original vector. This property is required for the random-feature / scrambled Hadamard use case in Gemma-4 attention.
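For checking GPU output, the butterfly pattern and normalization described above can be sketched on the CPU. This is a minimal reference under stated assumptions, not the kernel itself: the function name `fwht_normalized_in_place` is hypothetical, the element type is assumed to be `f32`, and `head_dim` is assumed to be a power of two (which the butterfly requires).

```rust
/// Hypothetical CPU reference for the normalized FWHT; not part of this module.
/// `buf` is a flat buffer shaped [num_heads, head_dim], transformed in place.
/// Assumes `head_dim` is a power of two and the element type is f32.
fn fwht_normalized_in_place(buf: &mut [f32], head_dim: usize) {
    assert!(head_dim.is_power_of_two(), "head_dim must be a power of two");
    assert!(buf.len() % head_dim == 0, "buffer must hold whole heads");

    // Scaling by 1/sqrt(head_dim) makes the transform its own inverse.
    let scale = 1.0 / (head_dim as f32).sqrt();

    // Each head is independent (one threadgroup per head on the GPU).
    for head in buf.chunks_exact_mut(head_dim) {
        let mut h = 1;
        // log2(head_dim) butterfly stages with stride h = 1, 2, 4, ...
        while h < head_dim {
            for block in (0..head_dim).step_by(2 * h) {
                for i in block..block + h {
                    let a = head[i];
                    let b = head[i + h];
                    head[i] = a + b;
                    head[i + h] = a - b;
                }
            }
            h *= 2;
        }
        for v in head.iter_mut() {
            *v *= scale;
        }
    }
}
```

Applying the function twice to the same buffer should reproduce the input up to floating-point rounding, which gives a simple way to sanity-check the H·H = I property.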
Statics
- HADAMARD_SHADER_SOURCE - MSL source for the Hadamard transform kernel (embedded at compile time).
Functions
- dispatch_hadamard_transform - Dispatch an in-place normalized Fast Walsh-Hadamard Transform on the GPU.
- register - Register the Hadamard transform shader source with the given kernel registry.