#version 450
// Compute partial sum of squares per workgroup
// Workgroup size: 256 invocations along x (y and z default to 1).
layout(local_size_x = 256) in;
// Input values; declared read-only for this shader.
layout(set = 0, binding = 0) readonly buffer Input { float x[]; };
// Output; declared write-only. main() stores one value per workgroup,
// at index gl_WorkGroupID.x.
layout(set = 0, binding = 1) writeonly buffer PartialSum { float partial_sum[]; };
// Push-constant block supplied by the host at dispatch time.
layout(push_constant) uniform Params {
// Bound on valid indices into x: main() reads x[idx] only when idx < n,
// and treats every other invocation's value as 0.
int n;
};
// Workgroup-shared scratch for the reduction, indexed by gl_LocalInvocationID.x.
// The length is hard-coded; it must be at least local_size_x declared above.
shared float sdata[256];
// Entry point: each invocation squares one input element, then the workgroup
// tree-reduces those squares in shared memory into a single partial sum.
//
// Reads:  x[gl_GlobalInvocationID.x] for indices below n; invocations at or
//         beyond n contribute 0 to the sum.
// Writes: partial_sum[gl_WorkGroupID.x], stored by local invocation 0 only.
//
// The halving loop starts at gl_WorkGroupSize.x / 2, so it folds every slot
// of sdata only when the workgroup size is a power of two.
void main() {
    uint tid = gl_LocalInvocationID.x;
    uint idx = gl_GlobalInvocationID.x;

    // n is a signed int while idx is unsigned. Comparing them directly makes
    // GLSL convert n to uint implicitly, so a negative n would turn into a
    // huge bound and let every invocation read x[idx]. Clamp to zero first
    // and convert explicitly; behaviour for n >= 0 is unchanged.
    uint count = uint(max(n, 0));

    float val = (idx < count) ? x[idx] : 0.0;
    sdata[tid] = val * val;
    // All squares must be stored before any invocation reads a neighbour's slot.
    barrier();

    // Pairwise tree reduction: at each step the lower s invocations add in
    // the value held s slots above them, halving the active range.
    for (uint s = gl_WorkGroupSize.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        // Kept outside the branch so every invocation in the workgroup
        // reaches the barrier (it must be in uniform control flow).
        barrier();
    }

    // sdata[0] now holds the workgroup's sum of squares.
    if (tid == 0) {
        partial_sum[gl_WorkGroupID.x] = sdata[0];
    }
}