#version 450
// Sum reduction compute shader
// Computes: result[0] = sum(input[0..n-1])
// Uses parallel reduction with shared memory for efficiency
// Single workgroup version to avoid atomic operations
// Workgroup shape: 256 invocations along x, a single invocation along y and z.
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

// Set 0 / binding 0: the values to be summed.
layout(set = 0, binding = 0) buffer InputBuffer {
    float input_data[];
};

// Set 0 / binding 1: output storage; main() stores the sum in result[0].
layout(set = 0, binding = 1) buffer OutputBuffer {
    float result[];
};

// Set 0 / binding 2: reduction parameters.
layout(set = 0, binding = 2) uniform UniformBuffer {
    uint size;  // how many leading elements of input_data take part in the sum
};

// Workgroup-shared scratch array, indexed by local invocation index in main().
// Its length is spelled out literally and equals local_size_x above (256);
// keep the two in sync if the workgroup size is ever changed.
shared float shared_data[256];
// Entry point: reduces input_data[0 .. size-1] to a single sum in result[0].
//
// Dispatch: intended to be dispatched as one workgroup (1, 1, 1). The final
// value is stored without atomics, so only workgroup (0, 0, 0) does any work.
// If more workgroups are dispatched they exit immediately instead of racing
// on result[0] with partial sums; the result is the full sum either way.
//
// Notes:
//   - Each invocation first accumulates a private partial sum over the
//     elements local_index, local_index + group_size, ... and the partial
//     sums are then combined pairwise through shared_data.
//   - The halving loop assumes gl_WorkGroupSize.x is a power of two; with any
//     other size some slots would never be folded into slot 0.
//   - size == 0 produces result[0] = 0.0.
void main() {
    // Whole workgroups leave here together (the condition depends only on
    // gl_WorkGroupID), so the barrier() calls below remain in uniform
    // control flow for the workgroup that continues.
    if (any(notEqual(gl_WorkGroupID, uvec3(0u)))) {
        return;
    }

    uint local_index = gl_LocalInvocationID.x;
    uint group_size = gl_WorkGroupSize.x;

    // Per-invocation partial sum. The stride is the workgroup size only, not
    // the dispatch size, so this single workgroup always covers every element
    // below `size`, however many workgroups were dispatched.
    float thread_sum = 0.0;
    for (uint i = local_index; i < size; i += group_size) {
        thread_sum += input_data[i];
    }
    shared_data[local_index] = thread_sum;

    // Every partial sum must be stored before any invocation reads a
    // neighbour's slot.
    barrier();

    // Pairwise reduction in shared memory: each pass folds the upper half of
    // the active range onto the lower half, halving the range until only
    // slot 0 remains.
    for (uint stride = group_size / 2u; stride > 0u; stride >>= 1) {
        if (local_index < stride) {
            shared_data[local_index] += shared_data[local_index + stride];
        }
        // Called outside the `if` so that all invocations reach it on every
        // pass.
        barrier();
    }

    // Slot 0 now holds the total; a single invocation publishes it.
    if (local_index == 0u) {
        result[0] = shared_data[0];
    }
}