#version 450
// Find max value per workgroup, write to partial_max[gl_WorkGroupID.x]
// Each workgroup runs 256 invocations along x; main() reduces one value per
// invocation, so a workgroup covers 256 consecutive elements of x.
layout(local_size_x = 256) in;
// Read-only storage buffer holding the input values, indexed by
// gl_GlobalInvocationID.x in main().
layout(set = 0, binding = 0) readonly buffer Input { float x[]; };
// Write-only storage buffer receiving one result per workgroup, indexed by
// gl_WorkGroupID.x in main().
layout(set = 0, binding = 1) writeonly buffer PartialMax { float partial_max[]; };
// Push constants supplied by the host.
layout(push_constant) uniform Params {
// Compared against the global invocation index in main(): only invocations
// whose index is below n read from x. Declared as a signed int.
int n;
};
// Workgroup-shared scratch space for the reduction, one slot per invocation.
// The length (256) is written out literally and has to stay equal to
// local_size_x above, because main() indexes it with gl_LocalInvocationID.x.
shared float sdata[256];
// Reduces this workgroup's slice of x to its maximum and stores the result in
// partial_max[gl_WorkGroupID.x].
//
// Every invocation loads at most one element into shared memory, then the
// workgroup performs a tree reduction that halves the number of active
// invocations each step. The halving loop relies on gl_WorkGroupSize.x being a
// power of two; otherwise some slots would never be folded in.
void main() {
    const uint tid = gl_LocalInvocationID.x;
    const uint idx = gl_GlobalInvocationID.x;

    // -infinity (IEEE-754 bit pattern 0xFF800000) is the true identity for
    // max(). Padding with -FLT_MAX instead would make a partially filled
    // workgroup report -FLT_MAX when all of its valid inputs are -infinity.
    const float neg_inf = uintBitsToFloat(0xFF800000u);

    // n is a signed int while idx is unsigned. A bare `idx < n` would
    // implicitly convert a negative n to a huge unsigned bound and let every
    // invocation read x[idx]; treat non-positive n as "no valid elements".
    const bool in_range = (n > 0) && (idx < uint(n));

    sdata[tid] = in_range ? x[idx] : neg_inf;
    barrier();

    // Tree reduction: at each step the lower `stride` invocations fold in the
    // value held `stride` slots above them.
    for (uint stride = gl_WorkGroupSize.x / 2u; stride > 0u; stride >>= 1u) {
        if (tid < stride) {
            sdata[tid] = max(sdata[tid], sdata[tid + stride]);
        }
        // Kept outside the `if` so every invocation in the workgroup reaches
        // the barrier.
        barrier();
    }

    // Slot 0 now holds the workgroup maximum; a single invocation writes it.
    if (tid == 0u) {
        partial_max[gl_WorkGroupID.x] = sdata[0];
    }
}