L2 Normalization GPU dispatch.
Computes: x / sqrt(sum(x^2) + eps) over the last dimension.
Used by Gated DeltaNet to normalize Q and K after the conv1d state update (ADR-013 Decision 3; spec derived from the mathematical definition of L2 norm, not from llama.cpp source).
Reduction is always performed in f32 for numerical stability regardless of input dtype.
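As a point of comparison for the GPU result, the formula above can be sketched as a plain CPU reference. This is only an illustrative sketch: `l2_norm_reference` is a hypothetical helper that is not part of this module, and it assumes a row-major `[rows, dim]` layout with f32 input.

```rust
/// CPU reference for L2 normalization over the last dimension.
/// Hypothetical helper for illustration; not part of this module's API.
fn l2_norm_reference(input: &[f32], rows: usize, dim: usize, eps: f32) -> Vec<f32> {
    assert!(rows > 0 && dim > 0, "rows and dim must be non-zero");
    assert_eq!(input.len(), rows * dim, "input must hold rows * dim elements");

    let mut out = Vec::with_capacity(input.len());
    for row in input.chunks_exact(dim) {
        // Accumulate the sum of squares in f32, mirroring the documented
        // choice to reduce in f32 regardless of input dtype.
        let sum_sq: f32 = row.iter().map(|v| v * v).sum();
        // eps is added inside the square root: x / sqrt(sum(x^2) + eps).
        let inv_norm = 1.0 / (sum_sq + eps).sqrt();
        out.extend(row.iter().map(|v| v * inv_norm));
    }
    out
}
```

Because eps sits inside the square root, an all-zero row normalizes to zeros instead of producing NaN, provided eps is positive.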
§Invariants
- Input and output share the same shape [rows, dim] and dtype.
- params_buf must hold exactly [eps, dim as f32] as two contiguous f32.
- rows > 0, dim > 0, input.elements() == rows * dim.
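The invariants above could be checked and the parameter buffer packed on the host roughly as follows. All three function names are hypothetical and chosen for illustration; how the real dispatch functions validate their arguments or fill `params_buf` is not shown in this documentation.

```rust
/// Pack the two kernel parameters in the documented order: [eps, dim as f32].
/// Hypothetical helper. Note that `dim as f32` is exact only up to 2^24.
fn pack_l2_norm_params(eps: f32, dim: usize) -> [f32; 2] {
    [eps, dim as f32]
}

/// View the packed parameters as 8 contiguous bytes in native byte order,
/// which is what a GPU buffer upload would typically consume.
fn l2_norm_params_bytes(params: [f32; 2]) -> [u8; 8] {
    let mut bytes = [0u8; 8];
    bytes[..4].copy_from_slice(&params[0].to_ne_bytes());
    bytes[4..].copy_from_slice(&params[1].to_ne_bytes());
    bytes
}

/// Check the shape invariants: rows > 0, dim > 0, elements == rows * dim.
fn check_l2_norm_shapes(input_elements: usize, rows: usize, dim: usize) -> Result<(), String> {
    if rows == 0 || dim == 0 {
        return Err(format!("rows ({rows}) and dim ({dim}) must both be > 0"));
    }
    // checked_mul guards against overflow for very large shapes.
    match rows.checked_mul(dim) {
        Some(n) if n == input_elements => Ok(()),
        _ => Err(format!(
            "input has {input_elements} elements, expected rows * dim = {rows} * {dim}"
        )),
    }
}
```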
§Threadgroup shape
One threadgroup per row; threadgroup size = min(256, next_power_of_two(dim)).
Shared memory of tg_size floats is used for the tree reduction.
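To make the threadgroup shape concrete, the sketch below computes the threadgroup size and simulates the per-row reduction sequentially on the CPU. Both function names are hypothetical. The strided loop that lets each thread cover several elements when dim exceeds the threadgroup size is an assumption about the kernel, not something stated in this documentation.

```rust
/// Threadgroup size as documented: min(256, next_power_of_two(dim)).
fn threadgroup_size(dim: usize) -> usize {
    dim.next_power_of_two().min(256)
}

/// Sequential simulation of one threadgroup reducing one row to sum(x^2).
/// Hypothetical illustration; the real work happens in the MSL kernel.
fn simulate_row_sum_of_squares(row: &[f32]) -> f32 {
    let tg_size = threadgroup_size(row.len());
    // Shared memory: one f32 slot per thread in the threadgroup.
    let mut shared = vec![0.0f32; tg_size];

    // Phase 1 (assumed): each thread accumulates a strided slice of the row.
    for tid in 0..tg_size {
        let mut acc = 0.0f32;
        let mut i = tid;
        while i < row.len() {
            acc += row[i] * row[i];
            i += tg_size;
        }
        shared[tid] = acc;
    }

    // Phase 2: tree reduction. Each step halves the number of active threads;
    // this relies on tg_size being a power of two.
    let mut stride = tg_size / 2;
    while stride > 0 {
        for tid in 0..stride {
            shared[tid] += shared[tid + stride];
        }
        stride /= 2;
    }
    shared[0]
}
```

Rounding the size up to a power of two is what keeps the halving loop exact: with tg_size = 256 the reduction finishes in 8 steps, and any slots beyond dim simply contribute zero.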
Statics§
- L2_NORM_SHADER_SOURCE - MSL source for the L2 norm kernels (embedded at compile time).
Functions§
- dispatch_l2_norm - Dispatch an L2 normalization operation on the GPU.
- dispatch_l2_norm_scale_f32 - Dispatch a fused L2 normalization + scalar multiply on the GPU.
- register - Register L2 norm shader sources with the given kernel registry.