
Module l2_norm


L2 Normalization GPU dispatch.

Computes: x / sqrt(sum(x^2) + eps) over the last dimension.
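For reference, the same formula can be written as a plain CPU loop, which may be handy for checking GPU output in tests. This is a minimal sketch assuming a row-major f32 buffer of shape [rows, dim]; `l2_norm_reference` is a hypothetical helper, not part of this module.

```rust
/// Hypothetical CPU reference: normalizes each row of a row-major [rows, dim] buffer
/// as x / sqrt(sum(x^2) + eps).
fn l2_norm_reference(input: &[f32], rows: usize, dim: usize, eps: f32) -> Vec<f32> {
    assert!(rows > 0 && dim > 0 && input.len() == rows * dim);
    let mut out = vec![0.0f32; input.len()];
    for (src, dst) in input.chunks_exact(dim).zip(out.chunks_exact_mut(dim)) {
        // Sum of squares over the last dimension (one row).
        let sum_sq: f32 = src.iter().map(|v| v * v).sum();
        // eps keeps the denominator finite for an all-zero row.
        let inv_norm = 1.0 / (sum_sq + eps).sqrt();
        for (d, s) in dst.iter_mut().zip(src) {
            *d = s * inv_norm;
        }
    }
    out
}

fn main() {
    // Two rows of dim 2: [3, 4] has norm 5; [0, 0] stays zero thanks to eps.
    let out = l2_norm_reference(&[3.0, 4.0, 0.0, 0.0], 2, 2, 1e-6);
    println!("{:?}", out);
}
```

Multiplying by a precomputed reciprocal keeps the inner loop to one multiply per element, which is also a common choice on the GPU side.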

Used by Gated DeltaNet to normalize Q and K after the conv1d state update (ADR-013 Decision 3; spec derived from the mathematical definition of L2 norm, not from llama.cpp source).

Reduction is always performed in f32 for numerical stability regardless of input dtype.
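To see why the accumulator dtype matters, the sketch below sums squares of bf16 inputs twice: once rounding the running sum back to bf16 after every step, and once keeping it in f32. bf16 is used only as an illustrative low-precision dtype (the module does not list which input dtypes it supports), and the conversion helpers are hypothetical, with NaN handling omitted for brevity.

```rust
/// bf16 is the top 16 bits of an f32, so widening is exact.
fn bf16_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// Narrow f32 to bf16 bits with round-to-nearest-even (NaN handling omitted).
fn f32_to_bf16(x: f32) -> u16 {
    let bits = x.to_bits();
    let rounding_bias = 0x7FFF + ((bits >> 16) & 1);
    (bits.wrapping_add(rounding_bias) >> 16) as u16
}

/// Sum of squares with the running total rounded to bf16 after every step.
fn sum_sq_bf16_accum(xs: &[u16]) -> f32 {
    let mut acc: u16 = f32_to_bf16(0.0);
    for &b in xs {
        let x = bf16_to_f32(b);
        acc = f32_to_bf16(bf16_to_f32(acc) + x * x);
    }
    bf16_to_f32(acc)
}

/// Sum of squares with inputs widened to f32 and the running total kept in f32.
fn sum_sq_f32_accum(xs: &[u16]) -> f32 {
    xs.iter()
        .map(|&b| {
            let x = bf16_to_f32(b);
            x * x
        })
        .sum()
}

fn main() {
    let ones = vec![f32_to_bf16(1.0); 1024];
    println!("bf16 accumulator: {}", sum_sq_bf16_accum(&ones));
    println!("f32 accumulator:  {}", sum_sq_f32_accum(&ones));
}
```

With 1024 inputs of 1.0, the bf16 accumulator stalls at 256: bf16 carries only 8 significant bits, so 256 + 1 rounds back to 256. The f32 accumulator reaches the correct 1024.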

Invariants

  • Input and output share the same shape [rows, dim] and dtype.
  • params_buf must hold exactly [eps, dim as f32] as two contiguous f32 values.
  • rows > 0, dim > 0, input.elements() == rows * dim.
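The invariants above can be checked host-side before encoding a dispatch. This is a hypothetical sketch: `pack_params`, `params_bytes`, and `check_l2_norm_args` are illustrative names, not this module's API, and the real dispatch functions may validate or report errors differently. Element counts are passed as plain integers rather than through the crate's buffer type.

```rust
/// Hypothetical helper: build the two-value params buffer [eps, dim as f32].
fn pack_params(eps: f32, dim: usize) -> [f32; 2] {
    [eps, dim as f32]
}

/// Hypothetical helper: the same two values as contiguous native-endian bytes.
fn params_bytes(params: &[f32; 2]) -> Vec<u8> {
    params.iter().flat_map(|p| p.to_ne_bytes()).collect()
}

/// Hypothetical check mirroring the documented invariants.
fn check_l2_norm_args(
    rows: usize,
    dim: usize,
    input_elements: usize,
    output_elements: usize,
    params: &[f32],
) -> Result<(), String> {
    if rows == 0 || dim == 0 {
        return Err(format!("rows ({rows}) and dim ({dim}) must both be > 0"));
    }
    // Input and output share the shape [rows, dim].
    if input_elements != rows * dim || output_elements != input_elements {
        return Err(format!(
            "expected {} elements in input and output, got {input_elements} and {output_elements}",
            rows * dim
        ));
    }
    // params_buf must be exactly [eps, dim as f32].
    if params.len() != 2 || params[1] != dim as f32 {
        return Err(format!("params must be [eps, {dim} as f32], got {params:?}"));
    }
    Ok(())
}

fn main() {
    let params = pack_params(1e-6, 128);
    println!("params buffer is {} bytes", params_bytes(&params).len());
    match check_l2_norm_args(4, 128, 512, 512, &params) {
        Ok(()) => println!("arguments satisfy the invariants"),
        Err(e) => println!("invalid: {e}"),
    }
}
```

Passing dim as f32 is lossless as long as dim is at most 2^24 (16,777,216), the largest range in which f32 represents every integer exactly.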

Threadgroup shape

One threadgroup per row; threadgroup size = min(256, next_power_of_two(dim)). A shared-memory array of tg_size f32 values (where tg_size is the threadgroup size) holds the partial sums for the tree reduction.
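The sizing rule and a shared-memory tree reduction can be simulated on the CPU as below. This is a sketch under assumptions: the strided first phase (each thread summing elements t, t + tg_size, ... when dim exceeds the threadgroup size) is a common pattern but is not spelled out in this module's docs, and `threadgroup_size` / `tree_reduce_sum_sq` are hypothetical names. On the GPU, each halving step would be separated by a threadgroup barrier; here the steps simply run in sequence.

```rust
/// Threadgroup size as documented: min(256, next_power_of_two(dim)).
fn threadgroup_size(dim: usize) -> usize {
    dim.next_power_of_two().min(256)
}

/// CPU simulation of one threadgroup reducing sum(x^2) for one row.
fn tree_reduce_sum_sq(row: &[f32], tg_size: usize) -> f32 {
    assert!(tg_size.is_power_of_two());
    // Phase 1 (assumed): thread t accumulates a strided partial sum into shared[t].
    let mut shared = vec![0.0f32; tg_size];
    for (t, slot) in shared.iter_mut().enumerate() {
        let mut i = t;
        while i < row.len() {
            *slot += row[i] * row[i];
            i += tg_size;
        }
    }
    // Phase 2: tree reduction, halving the number of active threads each step.
    let mut stride = tg_size / 2;
    while stride > 0 {
        for t in 0..stride {
            shared[t] += shared[t + stride];
        }
        stride /= 2;
    }
    shared[0]
}

fn main() {
    let row: Vec<f32> = (1..=10).map(|v| v as f32).collect();
    let tg = threadgroup_size(row.len());
    println!("tg_size = {tg}, sum of squares = {}", tree_reduce_sum_sq(&row, tg));
}
```

Rounding the threadgroup size up to a power of two is what lets every halving step pair threads evenly; the cap at 256 bounds shared memory at 256 floats (1 KiB) per threadgroup.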

Statics

L2_NORM_SHADER_SOURCE
MSL source for the L2 norm kernels (embedded at compile time).

Functions

dispatch_l2_norm
Dispatch an L2 normalization operation on the GPU.
dispatch_l2_norm_scale_f32
Dispatch a fused L2 normalization + scalar multiply on the GPU.
register
Register L2 norm shader sources with the given kernel registry.
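A CPU reference for the fused variant `dispatch_l2_norm_scale_f32` might look like the sketch below. It assumes the op computes scale * x / sqrt(sum(x^2) + eps) on f32 data with one scalar shared by all rows; how the real kernel receives the scalar is not documented here, and `l2_norm_scale_reference` is a hypothetical name.

```rust
/// Hypothetical CPU reference for fused L2 normalization followed by a scalar multiply.
fn l2_norm_scale_reference(input: &[f32], dim: usize, eps: f32, scale: f32) -> Vec<f32> {
    assert!(dim > 0 && !input.is_empty() && input.len() % dim == 0);
    input
        .chunks_exact(dim)
        .flat_map(|row| {
            let sum_sq: f32 = row.iter().map(|v| v * v).sum();
            // Fold the scalar into the per-row factor so each element needs one multiply.
            let factor = scale / (sum_sq + eps).sqrt();
            row.iter().map(move |v| v * factor)
        })
        .collect()
}

fn main() {
    // One row [3, 4]: norm 5, so with scale 10 each element is multiplied by 2.
    let out = l2_norm_scale_reference(&[3.0, 4.0], 2, 0.0, 10.0);
    println!("{:?}", out);
}
```

Fusing the multiply into the normalization avoids a second pass over the output buffer, which is usually the main benefit of such a kernel.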