#version 450
// Softmax activation compute shader using shared memory
// Computes: result[i] = exp(input[i] - max) / sum(exp(input[j] - max))
// Uses traditional shared memory approach for maximum compatibility
// One workgroup of 64 invocations cooperates through `shared_data` below;
// the reductions rely on local_size_x being a power of two.
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
// Binding 0: input logits (SSBO); `size` elements are read.
layout(set = 0, binding = 0) buffer InputBuffer {
float input_data[];
};
// Binding 1: softmax output (SSBO); same element count as the input.
layout(set = 0, binding = 1) buffer OutputBuffer {
float result[];
};
// Binding 2: element count shared by both buffers (UBO).
layout(set = 0, binding = 2) uniform UniformBuffer {
uint size;
};
// Shared memory for reduction operations
// Sized to local_size_x (64); reused first for the max reduction,
// then overwritten for the sum reduction (a barrier separates the two uses).
shared float shared_data[64];
void main() {
    // Softmax is a GLOBAL normalization over all `size` elements, but
    // workgroups cannot synchronize with each other inside one dispatch.
    // The previous per-workgroup reduction silently normalized each chunk
    // of 64 elements independently when size > 64, which is not softmax.
    // Fix: workgroup 0 processes the entire buffer with strided loops;
    // any additional dispatched workgroups exit immediately, so existing
    // host-side dispatch sizes remain valid.
    if (gl_WorkGroupID.x != 0u) {
        return;
    }
    uint local_index = gl_LocalInvocationID.x;
    uint group_size = gl_WorkGroupSize.x;
    // Pass 1: each invocation computes a partial max over its strided slice.
    // Identity is the most negative finite float so no legitimate input
    // (however negative) is ever masked by the padding value.
    float thread_max = -3.402823466e38;
    for (uint i = local_index; i < size; i += group_size) {
        thread_max = max(thread_max, input_data[i]);
    }
    shared_data[local_index] = thread_max;
    barrier();
    // Tree reduction for the workgroup-wide maximum (numerical-stability
    // shift). group_size is a power of two, so halving strides are exact.
    // barrier() sits outside the divergent `if` so all invocations reach it.
    for (uint stride = group_size / 2u; stride > 0u; stride >>= 1) {
        if (local_index < stride) {
            shared_data[local_index] = max(shared_data[local_index], shared_data[local_index + stride]);
        }
        barrier();
    }
    float max_val = shared_data[0];
    // All invocations must read shared_data[0] before it is reused below.
    barrier();
    // Pass 2: partial sums of exp(x - max) over the same strided slices.
    float thread_sum = 0.0;
    for (uint i = local_index; i < size; i += group_size) {
        thread_sum += exp(input_data[i] - max_val);
    }
    shared_data[local_index] = thread_sum;
    barrier();
    // Tree reduction for the total sum.
    for (uint stride = group_size / 2u; stride > 0u; stride >>= 1) {
        if (local_index < stride) {
            shared_data[local_index] += shared_data[local_index + stride];
        }
        barrier();
    }
    float sum = shared_data[0];
    // Pass 3: normalize. exp() is recomputed because an arbitrary `size`
    // cannot be cached per-invocation. If the sum degenerates to zero
    // (all exponentials underflowed), fall back to a uniform distribution,
    // matching the original edge-case handling.
    for (uint i = local_index; i < size; i += group_size) {
        result[i] = (sum > 0.0) ? exp(input_data[i] - max_val) / sum
                                : 1.0 / float(size);
    }
}