RMS Normalization GPU dispatch.
Computes: x * rsqrt(mean(x^2) + eps) * weight
The mean is computed over the last dimension. eps=1e-6 is the standard value for Gemma 4.
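For checking GPU output, a CPU reference of the formula above can help. The following is a minimal sketch, assuming a row-major `[rows, dim]` f32 buffer normalized along its last dimension. `rms_norm_ref` is a hypothetical helper and is not part of this module's API.

```rust
/// Hypothetical CPU reference (not part of this module):
/// y = x * rsqrt(mean(x^2) + eps) * weight, computed per row of a
/// row-major `[rows, dim]` buffer (i.e. over the last dimension).
fn rms_norm_ref(x: &[f32], weight: &[f32], dim: usize, eps: f32) -> Vec<f32> {
    assert!(dim > 0 && x.len() % dim == 0, "x must hold whole rows of length dim");
    assert_eq!(weight.len(), dim, "weight must match the last dimension");
    let mut out = Vec::with_capacity(x.len());
    for row in x.chunks_exact(dim) {
        // Mean of squares along the last dimension.
        let mean_sq = row.iter().map(|v| v * v).sum::<f32>() / dim as f32;
        // eps keeps the reciprocal finite for all-zero rows.
        let inv_rms = 1.0 / (mean_sq + eps).sqrt();
        out.extend(row.iter().zip(weight).map(|(v, w)| v * inv_rms * w));
    }
    out
}

fn main() {
    // Two rows of length 2: [3, 4] and an all-zero row.
    let x = [3.0_f32, 4.0, 0.0, 0.0];
    let w = [1.0_f32, 1.0];
    let y = rms_norm_ref(&x, &w, 2, 1e-6);
    println!("{:?}", y);
}
```

For the row `[3, 4]`, mean(x^2) = 12.5, so each element is divided by roughly 3.536. The all-zero row stays zero rather than producing NaN, because eps keeps the denominator positive.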
Statics
- RMS_NORM_SHADER_SOURCE - Metal Shading Language (MSL) source for the RMS norm kernels (embedded at compile time).
Functions
- dispatch_rms_norm - Dispatch an RMS normalization operation on the GPU.
- dispatch_rms_norm_mul - Dispatch a fused RMS normalization + elementwise multiply.
- dispatch_rms_norm_no_scale_bf16 - Dispatch an RMS normalization without learned scale (bf16 only).
- dispatch_rms_norm_no_scale_f32 - Dispatch an RMS normalization without learned scale (f32).
- register - Register RMS norm shader sources with the given kernel registry.
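The no-scale variants compute x * rsqrt(mean(x^2) + eps) with the weight term dropped. Below is a minimal CPU sketch of the bf16 case. It rests on two assumptions that these docs do not confirm: the kernel stores bf16 values, and it accumulates the mean of squares in f32. bf16 values are handled as raw `u16` bits, since bf16 is the upper half of an IEEE-754 f32. All function names here are hypothetical.

```rust
/// bf16 -> f32: bf16 is the upper 16 bits of an f32, so widen by shifting.
fn bf16_to_f32(b: u16) -> f32 {
    f32::from_bits((b as u32) << 16)
}

/// f32 -> bf16 with round-to-nearest-even; NaN maps to a canonical quiet NaN.
fn f32_to_bf16(x: f32) -> u16 {
    if x.is_nan() {
        return 0x7FC0;
    }
    let bits = x.to_bits();
    // Add 0x7FFF plus the lowest kept bit so ties round to even.
    let lsb = (bits >> 16) & 1;
    ((bits + 0x7FFF + lsb) >> 16) as u16
}

/// Hypothetical reference for a no-scale RMS norm on bf16 data:
/// y = x * rsqrt(mean(x^2) + eps), per row of length `dim`,
/// with the reduction done in f32 (an assumption about the GPU kernel).
fn rms_norm_no_scale_bf16_ref(x: &[u16], dim: usize, eps: f32) -> Vec<u16> {
    assert!(dim > 0 && x.len() % dim == 0, "x must hold whole rows of length dim");
    let mut out = Vec::with_capacity(x.len());
    for row in x.chunks_exact(dim) {
        let mean_sq = row
            .iter()
            .map(|&b| {
                let v = bf16_to_f32(b);
                v * v
            })
            .sum::<f32>()
            / dim as f32;
        let inv_rms = 1.0 / (mean_sq + eps).sqrt();
        // Normalize in f32, then round back to bf16 for storage.
        out.extend(row.iter().map(|&b| f32_to_bf16(bf16_to_f32(b) * inv_rms)));
    }
    out
}

fn main() {
    let x: Vec<u16> = [3.0_f32, 4.0].iter().map(|&v| f32_to_bf16(v)).collect();
    let y = rms_norm_no_scale_bf16_ref(&x, 2, 1e-6);
    let y_f32: Vec<f32> = y.iter().map(|&b| bf16_to_f32(b)).collect();
    println!("{:?}", y_f32);
}
```

Accumulating in f32 is a common choice because bf16 carries only 8 significand bits. Summing squares directly in bf16 would lose precision quickly as `dim` grows.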