#version 450
// Pass 2: Normalize and scale by weight
// Workgroup size: 256 invocations along x; main() maps one invocation to
// one array element via gl_GlobalInvocationID.x.
layout(local_size_x = 256) in;
// Read-only source values (runtime-sized float array).
layout(set = 0, binding = 0) readonly buffer Input { float x[]; };
// Read-only per-element weights, indexed with the same index as x.
layout(set = 0, binding = 1) readonly buffer Weight { float w[]; };
// Write-only destination; main() stores x[i] * rms_inv * w[i] at result[i].
layout(set = 0, binding = 2) writeonly buffer Output { float result[]; };
// Per-dispatch parameters supplied by the host as push constants.
layout(push_constant) uniform Params {
// Number of elements to process; main() only touches indices below n.
int n;
// Scalar factor applied to every element.
// NOTE(review): presumably the reciprocal RMS of the input produced by an
// earlier pass (this file is labelled "Pass 2") -- confirm against host code.
float rms_inv;
};
// Entry point: each invocation handles a single element, selected by its
// global x index, and writes
//     result[idx] = x[idx] * rms_inv * w[idx]
// for every idx in [0, n). Invocations whose index falls outside that range
// do nothing.
void main() {
    uint idx = gl_GlobalInvocationID.x;

    // `n` is a signed push constant. Comparing it directly against the
    // unsigned index implicitly converts it to uint, so a negative n would
    // wrap to a huge bound and let every invocation read/write out of range.
    // Reject non-positive counts first, then compare explicitly in unsigned.
    if (n <= 0 || idx >= uint(n)) {
        return;
    }

    result[idx] = x[idx] * rms_inv * w[idx];
}