// Two entry points:
// rms_norm_grad_w: gradient wrt weight, dispatch [ceil(cols/256), 1, 1]
// rms_norm_grad_x: gradient wrt input, dispatch [rows, 1, 1]
// Params field names: m=rows, n=cols, k=eps_bits
struct Params {
m: u32,
n: u32,
k: u32,
_pad: u32,
}
var<storage> src_a: array<f32>; // dy
var<storage> src_b: array<f32>; // x
var<storage> bias: array<f32>; // w (weight)
var<storage, read_write> dst: array<f32>;
var<uniform> params: Params;
var<workgroup> wg_data: array<f32, 256>;
// grad_w[col] = sum_i( dy[i*n+col] * x[i*n+col] * rsqrt_i )
//
// Dispatch: [ceil(cols/256), 1, 1]
// Each thread handles one column. Workgroup cooperatively precomputes
// rsqrt for each row via shared-memory reduction (O(rows × cols/256) per
// thread), then accumulates grad_w in O(rows) per thread.
// Total work: O(rows × cols) — down from the naive O(rows × cols²).
@compute @workgroup_size(256)
fn rms_norm_grad_w(@builtin(workgroup_id) wgid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) {
let col = wgid.x * 256u + lid.x;
let tid = lid.x;
let rows = params.m;
let cols = params.n;
let eps = bitcast<f32>(params.k);
// Process one row at a time: cooperative rsqrt then per-column accumulation
var acc = 0.0;
for (var i = 0u; i < rows; i++) {
let offset = i * cols;
// Cooperative rsqrt: each thread sums a stride of the row
var ss = 0.0;
var j = tid;
loop {
if j >= cols { break; }
let v = src_b[offset + j];
ss += v * v;
j += 256u;
}
wg_data[tid] = ss;
workgroupBarrier();
var stride = 128u;
loop {
if stride == 0u { break; }
if tid < stride {
wg_data[tid] += wg_data[tid + stride];
}
workgroupBarrier();
stride >>= 1u;
}
let rsqrt_i = inverseSqrt(wg_data[0] / f32(cols) + eps);
// Accumulate this row's contribution
if col < cols {
acc += src_a[offset + col] * src_b[offset + col] * rsqrt_i;
}
workgroupBarrier();
}
if col < cols {
dst[col] = acc;
}
}
// grad_x[i,j] = rsqrt_i * (dy[i,j]*w[j] - x[i,j] * s_i)
// where s_i = (rsqrt_i^2 / cols) * sum_j(dy[i,j]*w[j]*x[i,j])
@compute @workgroup_size(256)
fn rms_norm_grad_x(@builtin(workgroup_id) wgid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) {
let row = wgid.x;
let tid = lid.x;
let rows = params.m;
let cols = params.n;
let eps = bitcast<f32>(params.k);
if row >= rows { return; }
let offset = row * cols;
// Phase 1: Compute rsqrt via shared memory reduction
var ss = 0.0;
var j = tid;
loop {
if j >= cols { break; }
let v = src_b[offset + j];
ss += v * v;
j += 256u;
}
wg_data[tid] = ss;
workgroupBarrier();
var stride = 128u;
loop {
if stride == 0u { break; }
if tid < stride {
wg_data[tid] += wg_data[tid + stride];
}
workgroupBarrier();
stride >>= 1u;
}
let rsqrt_val = inverseSqrt(wg_data[0] / f32(cols) + eps);
// Phase 2: Compute s_i = (rsqrt^2 / cols) * sum_j(dy[j]*w[j]*x[j])
var dot = 0.0;
j = tid;
loop {
if j >= cols { break; }
dot += src_a[offset + j] * bias[j] * src_b[offset + j];
j += 256u;
}
wg_data[tid] = dot;
workgroupBarrier();
stride = 128u;
loop {
if stride == 0u { break; }
if tid < stride {
wg_data[tid] += wg_data[tid + stride];
}
workgroupBarrier();
stride >>= 1u;
}
let s_i = (rsqrt_val * rsqrt_val / f32(cols)) * wg_data[0];
// Phase 3: Write output
j = tid;
loop {
if j >= cols { break; }
let dy = src_a[offset + j];
let w = bias[j];
let x = src_b[offset + j];
dst[offset + j] = rsqrt_val * (dy * w - x * s_i);
j += 256u;
}
}