#version 450
// Swish activation compute shader
// Computes: result[i] = input_data[i] * sigmoid(input_data[i])
// Swish(x) = x * (1 / (1 + exp(-x)))
// Workgroup shape: 64 invocations along x, 1 along y and z.
// main() indexes elements by gl_GlobalInvocationID.x only.
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;

// Source values (set 0, binding 0). The shader only reads this block, so it
// is declared `readonly`; an accidental write now fails at compile time.
// std430 is spelled out so the tightly packed float array layout is explicit
// rather than relying on the dialect's default.
layout(set = 0, binding = 0, std430) readonly buffer InputBuffer {
    float input_data[];
};

// Destination values (set 0, binding 1); one output element per input element.
// NOTE(review): `writeonly` is intentionally not added here because main()
// calls result.length(); confirm the target compiler accepts length() on a
// writeonly block before tightening this declaration.
layout(set = 0, binding = 1, std430) buffer OutputBuffer {
    float result[];
};
// Entry point. Each invocation handles one element:
//   result[i] = x * sigmoid(x),  where x = input_data[i]
void main() {
    uint gid = gl_GlobalInvocationID.x;

    // Proceed only when the element exists in both buffers; invocations past
    // the end of either array do nothing.
    bool in_range = gid < input_data.length() && gid < result.length();
    if (!in_range) {
        return;
    }

    float x = input_data[gid];

    // Numerically stable sigmoid. The exponent -|x| is never positive, so the
    // exponential stays in [0, 1] and cannot overflow:
    //   x >= 0 : sigmoid(x) = 1      / (1 + e^(-x))
    //   x <  0 : sigmoid(x) = e^(x)  / (1 + e^(x))
    // Both cases share the denominator and differ only in the numerator.
    float decay = exp(-abs(x));
    float numerator = (x >= 0.0) ? 1.0 : decay;

    // Swish(x) = x * sigmoid(x)
    result[gid] = x * (numerator / (1.0 + decay));
}