// Normalization operations. F32 only.
// Entry points: rms_norm_f32, layer_norm_f32, layer_norm_no_bias_f32, group_norm_f32
//
// Welford's online algorithm is used for LayerNorm and GroupNorm to compute
// mean and variance in a single pass with numerical stability. Each thread
// accumulates its own (count, mean, M2) triple, then a tree reduction merges
// accumulators across the workgroup using the parallel Welford merge formula:
// delta = mean_b - mean_a
// mean_ab = mean_a + delta * count_b / (count_a + count_b)
// M2_ab = M2_a + M2_b + delta^2 * count_a * count_b / (count_a + count_b)
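//
// Worked merge as a sanity check (illustrative numbers, not from a real run):
// chunk A = [1, 3] gives (count_a, mean_a, M2_a) = (2, 2, 2);
// chunk B = [5] gives (count_b, mean_b, M2_b) = (1, 5, 0).
// delta = 3, mean_ab = 2 + 3 * 1 / 3 = 3, M2_ab = 2 + 0 + 9 * 2 * 1 / 3 = 8,
// which matches the direct pass over [1, 3, 5]: mean = 3, M2 = 4 + 0 + 4 = 8.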
//
// Shared memory is sized to WORKGROUP_SIZE (256). All workgroup_size attributes
// and shared memory array sizes MUST be kept in sync with this constant.
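//
// Expected dispatch (inferred from the per-workgroup guards below): one
// workgroup per normalized row, i.e. dispatchWorkgroups(batch_size) for
// rms_norm_f32, layer_norm_f32, and layer_norm_no_bias_f32, and
// dispatchWorkgroups(batch_size * num_groups) for group_norm_f32.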
// ============================================================================
// Workgroup Configuration
// ============================================================================
const WORKGROUP_SIZE: u32 = 256u;
var<workgroup> norm_shared: array<f32, 256>;
// ============================================================================
// RMS Normalization
// ============================================================================
// rms_norm(x, weight, eps) = x / sqrt(mean(x^2) + eps) * weight
// Applied over the last dimension.
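// Worked example (illustrative values): x = [3, 4], weight = [1, 1], eps = 0:
// mean(x^2) = (9 + 16) / 2 = 12.5, rms = sqrt(12.5) ~= 3.5355, so the
// output is ~[0.8485, 1.1314].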
struct RmsNormParams {
batch_size: u32, // Product of all dims except the last
hidden_size: u32, // Size of the last dimension
eps: f32,
}
@group(0) @binding(0) var<storage, read_write> rms_input: array<f32>;
@group(0) @binding(1) var<storage, read_write> rms_weight: array<f32>;
@group(0) @binding(2) var<storage, read_write> rms_output: array<f32>;
@group(0) @binding(3) var<uniform> rms_params: RmsNormParams;
@compute @workgroup_size(256)
fn rms_norm_f32(@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) group_id: vec3<u32>) {
let tid = local_id.x;
let batch_idx = group_id.x;
if (batch_idx >= rms_params.batch_size) {
return;
}
let hidden_size = rms_params.hidden_size;
let eps = rms_params.eps;
let base_offset = batch_idx * hidden_size;
// Step 1: Compute sum of squares
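// Each thread strides over the row with step WORKGROUP_SIZE, covering
// elements tid, tid + 256, tid + 512, ..., so any hidden_size is handled.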
var sum_sq: f32 = 0.0;
var i: u32 = tid;
while (i < hidden_size) {
let val = rms_input[base_offset + i];
sum_sq = sum_sq + val * val;
i = i + WORKGROUP_SIZE;
}
norm_shared[tid] = sum_sq;
workgroupBarrier();
// Reduce to get total sum of squares
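// s halves each iteration (128, 64, ..., 1); afterwards norm_shared[0]
// holds the total over all WORKGROUP_SIZE partial sums.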
for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {
if (tid < s) {
norm_shared[tid] = norm_shared[tid] + norm_shared[tid + s];
}
workgroupBarrier();
}
// Compute RMS: sqrt(mean(x^2) + eps)
let rms = sqrt(norm_shared[0] / f32(hidden_size) + eps);
workgroupBarrier();
// Step 2: Normalize and apply weight
i = tid;
while (i < hidden_size) {
rms_output[base_offset + i] = rms_input[base_offset + i] / rms * rms_weight[i];
i = i + WORKGROUP_SIZE;
}
}
// ============================================================================
// Layer Normalization
// ============================================================================
// layer_norm(x, weight, bias, eps) = (x - mean(x)) / sqrt(var(x) + eps) * weight + bias
// Applied over the last dimension.
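// Worked example (illustrative values): x = [1, 3], weight = [1.5, 1.5],
// bias = [0.5, 0.5], eps = 0: mean = 2, var = 1, normalized = [-1, 1],
// output = [-1 * 1.5 + 0.5, 1 * 1.5 + 0.5] = [-1.0, 2.0].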
struct LayerNormParams {
batch_size: u32, // Product of all dims except the last
hidden_size: u32, // Size of the last dimension
eps: f32,
}
@group(0) @binding(0) var<storage, read_write> ln_input: array<f32>;
@group(0) @binding(1) var<storage, read_write> ln_weight: array<f32>;
@group(0) @binding(2) var<storage, read_write> ln_bias: array<f32>;
@group(0) @binding(3) var<storage, read_write> ln_output: array<f32>;
@group(0) @binding(4) var<uniform> ln_params: LayerNormParams;
// Welford shared memory: count, mean, M2 per thread
var<workgroup> ln_shared_count: array<f32, 256>;
var<workgroup> ln_shared_mean: array<f32, 256>;
var<workgroup> ln_shared_m2: array<f32, 256>;
@compute @workgroup_size(256)
fn layer_norm_f32(@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) group_id: vec3<u32>) {
let tid = local_id.x;
let batch_idx = group_id.x;
if (batch_idx >= ln_params.batch_size) {
return;
}
let hidden_size = ln_params.hidden_size;
let eps = ln_params.eps;
let base_offset = batch_idx * hidden_size;
// Step 1: Per-thread Welford accumulation (single pass over input)
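// Loop invariant: `mean` is the running mean and `m2` the sum of squared
// deviations of the elements this thread has seen so far.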
var count: f32 = 0.0;
var mean: f32 = 0.0;
var m2: f32 = 0.0;
var i: u32 = tid;
while (i < hidden_size) {
let x = ln_input[base_offset + i];
count = count + 1.0;
let delta = x - mean;
mean = mean + delta / count;
m2 = m2 + delta * (x - mean);
i = i + WORKGROUP_SIZE;
}
ln_shared_count[tid] = count;
ln_shared_mean[tid] = mean;
ln_shared_m2[tid] = m2;
workgroupBarrier();
// Step 2: Tree reduction with Welford merge
for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {
if (tid < s) {
let count_a = ln_shared_count[tid];
let mean_a = ln_shared_mean[tid];
let m2_a = ln_shared_m2[tid];
let count_b = ln_shared_count[tid + s];
let mean_b = ln_shared_mean[tid + s];
let m2_b = ln_shared_m2[tid + s];
let merged_count = count_a + count_b;
if (merged_count > 0.0) {
let delta = mean_b - mean_a;
let merged_mean = mean_a + delta * count_b / merged_count;
let merged_m2 = m2_a + m2_b + delta * delta * count_a * count_b / merged_count;
ln_shared_count[tid] = merged_count;
ln_shared_mean[tid] = merged_mean;
ln_shared_m2[tid] = merged_m2;
}
}
workgroupBarrier();
}
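// ln_shared_m2[0] now holds the row's total sum of squared deviations;
// dividing by hidden_size gives the biased (population) variance, as is
// standard for LayerNorm.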
let final_mean = ln_shared_mean[0];
let variance = ln_shared_m2[0] / f32(hidden_size);
let inv_std = 1.0 / sqrt(variance + eps);
workgroupBarrier();
// Step 3: Normalize and apply affine transformation (second pass over input)
i = tid;
while (i < hidden_size) {
let normalized = (ln_input[base_offset + i] - final_mean) * inv_std;
ln_output[base_offset + i] = normalized * ln_weight[i] + ln_bias[i];
i = i + WORKGROUP_SIZE;
}
}
// ============================================================================
// Layer Normalization without bias
// ============================================================================
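// Identical to layer_norm_f32 except that the final affine step applies the
// weight only; see that kernel for commentary on the reduction.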
@group(0) @binding(0) var<storage, read_write> ln_nb_input: array<f32>;
@group(0) @binding(1) var<storage, read_write> ln_nb_weight: array<f32>;
@group(0) @binding(2) var<storage, read_write> ln_nb_output: array<f32>;
@group(0) @binding(3) var<uniform> ln_nb_params: LayerNormParams;
@compute @workgroup_size(256)
fn layer_norm_no_bias_f32(@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) group_id: vec3<u32>) {
let tid = local_id.x;
let batch_idx = group_id.x;
if (batch_idx >= ln_nb_params.batch_size) {
return;
}
let hidden_size = ln_nb_params.hidden_size;
let eps = ln_nb_params.eps;
let base_offset = batch_idx * hidden_size;
// Step 1: Per-thread Welford accumulation (single pass)
var count: f32 = 0.0;
var mean: f32 = 0.0;
var m2: f32 = 0.0;
var i: u32 = tid;
while (i < hidden_size) {
let x = ln_nb_input[base_offset + i];
count = count + 1.0;
let delta = x - mean;
mean = mean + delta / count;
m2 = m2 + delta * (x - mean);
i = i + WORKGROUP_SIZE;
}
// Reuse layer_norm shared memory for reduction
ln_shared_count[tid] = count;
ln_shared_mean[tid] = mean;
ln_shared_m2[tid] = m2;
workgroupBarrier();
// Step 2: Tree reduction with Welford merge
for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {
if (tid < s) {
let count_a = ln_shared_count[tid];
let mean_a = ln_shared_mean[tid];
let m2_a = ln_shared_m2[tid];
let count_b = ln_shared_count[tid + s];
let mean_b = ln_shared_mean[tid + s];
let m2_b = ln_shared_m2[tid + s];
let merged_count = count_a + count_b;
if (merged_count > 0.0) {
let delta = mean_b - mean_a;
ln_shared_count[tid] = merged_count;
ln_shared_mean[tid] = mean_a + delta * count_b / merged_count;
ln_shared_m2[tid] = m2_a + m2_b + delta * delta * count_a * count_b / merged_count;
}
}
workgroupBarrier();
}
let final_mean = ln_shared_mean[0];
let variance = ln_shared_m2[0] / f32(hidden_size);
let inv_std = 1.0 / sqrt(variance + eps);
workgroupBarrier();
// Step 3: Normalize and apply weight only (second pass)
i = tid;
while (i < hidden_size) {
let normalized = (ln_nb_input[base_offset + i] - final_mean) * inv_std;
ln_nb_output[base_offset + i] = normalized * ln_nb_weight[i];
i = i + WORKGROUP_SIZE;
}
}
// ============================================================================
// Group Normalization
// ============================================================================
// group_norm(x, weight, bias, num_groups) normalizes over groups of channels
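// Input is assumed to be a row-major (batch, channels, spatial...) tensor,
// flattened; each workgroup normalizes one (batch, group) slice of
// channels_per_group * spatial elements.
// Worked example (illustrative values): channels = 4, num_groups = 2,
// spatial = 2 gives channels_per_group = 2 and group_size = 4; for batch 0,
// the workgroup with bg_id = 1 normalizes flat elements [4, 8).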
struct GroupNormParams {
batch_size: u32,
channels: u32, // Total channel count (assumed divisible by num_groups)
spatial: u32, // Product of the spatial dims (e.g. H * W)
num_groups: u32,
channels_per_group: u32, // = channels / num_groups, precomputed on the host
eps: f32,
_pad0: u32, // Padding so the struct size is a 16-byte multiple
_pad1: u32, // (host-side uniform layout convenience)
}
@group(0) @binding(0) var<storage, read_write> gn_input: array<f32>;
@group(0) @binding(1) var<storage, read_write> gn_weight: array<f32>;
@group(0) @binding(2) var<storage, read_write> gn_bias: array<f32>;
@group(0) @binding(3) var<storage, read_write> gn_output: array<f32>;
@group(0) @binding(4) var<uniform> gn_params: GroupNormParams;
var<workgroup> gn_shared_count: array<f32, 256>;
var<workgroup> gn_shared_mean: array<f32, 256>;
var<workgroup> gn_shared_m2: array<f32, 256>;
@compute @workgroup_size(256)
fn group_norm_f32(@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) group_id: vec3<u32>) {
let tid = local_id.x;
let bg_id = group_id.x; // Flattened index: batch_idx * num_groups + group_idx
let batch_size = gn_params.batch_size;
let channels = gn_params.channels;
let spatial = gn_params.spatial;
let num_groups = gn_params.num_groups;
let channels_per_group = gn_params.channels_per_group;
let eps = gn_params.eps;
if (bg_id >= batch_size * num_groups) {
return;
}
let batch_id = bg_id / num_groups;
let group_id_val = bg_id % num_groups;
let c_start = group_id_val * channels_per_group;
let group_size = channels_per_group * spatial;
let batch_offset = batch_id * channels * spatial;
let group_offset = batch_offset + c_start * spatial;
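// This group's elements are contiguous, occupying flat indices
// [group_offset, group_offset + group_size).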
// Step 1: Per-thread Welford accumulation (single pass)
var count: f32 = 0.0;
var mean: f32 = 0.0;
var m2: f32 = 0.0;
var i: u32 = tid;
while (i < group_size) {
// c_offset * spatial + s_offset == i, so the element is simply group_offset + i.
let x = gn_input[group_offset + i];
count = count + 1.0;
let delta = x - mean;
mean = mean + delta / count;
m2 = m2 + delta * (x - mean);
i = i + WORKGROUP_SIZE;
}
gn_shared_count[tid] = count;
gn_shared_mean[tid] = mean;
gn_shared_m2[tid] = m2;
workgroupBarrier();
// Step 2: Tree reduction with Welford merge
for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {
if (tid < s) {
let count_a = gn_shared_count[tid];
let mean_a = gn_shared_mean[tid];
let m2_a = gn_shared_m2[tid];
let count_b = gn_shared_count[tid + s];
let mean_b = gn_shared_mean[tid + s];
let m2_b = gn_shared_m2[tid + s];
let merged_count = count_a + count_b;
if (merged_count > 0.0) {
let delta = mean_b - mean_a;
gn_shared_count[tid] = merged_count;
gn_shared_mean[tid] = mean_a + delta * count_b / merged_count;
gn_shared_m2[tid] = m2_a + m2_b + delta * delta * count_a * count_b / merged_count;
}
}
workgroupBarrier();
}
let final_mean = gn_shared_mean[0];
let variance = gn_shared_m2[0] / f32(group_size);
let inv_std = 1.0 / sqrt(variance + eps);
workgroupBarrier();
// Step 3: Normalize and apply per-channel weight and bias (second pass)
i = tid;
while (i < group_size) {
let c_offset = i / spatial; // Channel within the group, for per-channel params
let idx = group_offset + i;
let channel = c_start + c_offset;
let normalized = (gn_input[idx] - final_mean) * inv_std;
gn_output[idx] = normalized * gn_weight[channel] + gn_bias[channel];
i = i + WORKGROUP_SIZE;
}
}