
Module rms_norm


RMS Normalization GPU dispatch.

Computes: x * rsqrt(mean(x^2) + eps) * weight

The mean is computed over the last dimension. eps=1e-6 is the standard value for Gemma 4.
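
As a rough CPU-side illustration of the formula above, the sketch below applies the same computation to a row-major buffer. The function name `rms_norm_reference` and its parameters are hypothetical and not part of this module's API; it assumes `x` holds whole rows of length `dim` and that `weight` has one entry per element of the last dimension.

```rust
/// CPU reference for RMS normalization over the last dimension.
///
/// Hypothetical helper for illustration only; not part of this module.
/// `x` holds `rows * dim` values in row-major order, `weight` holds `dim` values.
fn rms_norm_reference(x: &[f32], weight: &[f32], dim: usize, eps: f32) -> Vec<f32> {
    assert!(dim > 0, "dim must be non-zero");
    assert_eq!(weight.len(), dim, "weight must have one entry per column");
    assert_eq!(x.len() % dim, 0, "x must contain whole rows");

    let mut out = Vec::with_capacity(x.len());
    for row in x.chunks_exact(dim) {
        // mean(x^2) over the last dimension
        let mean_sq = row.iter().map(|v| v * v).sum::<f32>() / dim as f32;
        // rsqrt(mean(x^2) + eps)
        let inv_rms = 1.0 / (mean_sq + eps).sqrt();
        // x * rsqrt(...) * weight
        out.extend(row.iter().zip(weight).map(|(v, w)| v * inv_rms * w));
    }
    out
}
```

A reference like this may be useful for checking GPU output within a tolerance. With unit weights and a negligible `eps`, each output row has a mean square of approximately 1.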

Statics

RMS_NORM_SHADER_SOURCE
MSL source for the RMS norm kernels (embedded at compile time).

Functions

dispatch_rms_norm
Dispatch an RMS normalization operation on the GPU.
dispatch_rms_norm_mul
Dispatch a fused RMS normalization + elementwise multiply.
dispatch_rms_norm_no_scale_bf16
Dispatch an RMS normalization without learned scale (bf16 only).
dispatch_rms_norm_no_scale_f32
Dispatch an RMS normalization without learned scale (f32).
register
Register RMS norm shader sources with the given kernel registry.
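
The variants listed above can be sketched as CPU references to make their intended math concrete. Everything in this block is hypothetical: the function names, the signatures, and in particular the reading of the fused variant as `rms_norm(x) * weight * mul` with `mul` shaped like `x` are assumptions, since this page does not document the dispatch arguments. The "no scale" form follows the description directly by omitting the `weight` factor.

```rust
/// Returns 1 / sqrt(mean(row^2) + eps) for a single row.
fn inv_rms(row: &[f32], eps: f32) -> f32 {
    let mean_sq = row.iter().map(|v| v * v).sum::<f32>() / row.len() as f32;
    1.0 / (mean_sq + eps).sqrt()
}

/// RMS normalization without a learned scale: x * rsqrt(mean(x^2) + eps).
/// Hypothetical CPU reference, not this module's API.
fn rms_norm_no_scale_reference(x: &[f32], dim: usize, eps: f32) -> Vec<f32> {
    assert!(dim > 0 && x.len() % dim == 0, "x must contain whole rows");
    x.chunks_exact(dim)
        .flat_map(|row| {
            let s = inv_rms(row, eps);
            row.iter().map(move |v| v * s)
        })
        .collect()
}

/// ASSUMPTION: the fused variant computes rms_norm(x) * weight * mul,
/// where `mul` has the same shape as `x`. The real kernel may differ.
fn rms_norm_mul_reference(
    x: &[f32],
    weight: &[f32],
    mul: &[f32],
    dim: usize,
    eps: f32,
) -> Vec<f32> {
    assert_eq!(weight.len(), dim, "weight must have one entry per column");
    assert_eq!(mul.len(), x.len(), "mul must match x in length");
    rms_norm_no_scale_reference(x, dim, eps)
        .iter()
        .zip(mul)
        .enumerate()
        .map(|(i, (n, m))| n * weight[i % dim] * m)
        .collect()
}
```

Fusing the multiply into the normalization kernel would typically save one pass over memory compared with two separate dispatches; this CPU sketch only mirrors the arithmetic, not that memory behavior.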