#version 450
// Tanh activation compute shader
// Computes, element-wise: result[i] = tanh(input_data[i])
// One-dimensional workgroups of 64 invocations; each invocation handles one
// element, indexed by gl_GlobalInvocationID.x in main().
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;

// Source values. Only read by this shader, hence readonly. std430 is spelled
// out so the float array stride is the tightly packed 4 bytes the host is
// expected to upload, rather than relying on a default packing rule.
layout(set = 0, binding = 0, std430) readonly buffer InputBuffer {
float input_data[];
};

// Destination values, written one element per invocation (same std430
// packing as the input buffer).
layout(set = 0, binding = 1, std430) buffer OutputBuffer {
float result[];
};
// Entry point: one invocation per element, writing result[i] = tanh(input_data[i]).
// Invocations whose global x index is outside either buffer return without
// writing, so the dispatch size may be rounded up past the element count.
void main() {
    uint index = gl_GlobalInvocationID.x;

    // Bounds checking. length() yields an int; convert explicitly for the
    // unsigned comparison.
    if (index >= uint(input_data.length()) || index >= uint(result.length())) {
        return;
    }

    float x = input_data[index];
    float ax = abs(x);

    if (ax > 10.0) {
        // In single precision tanh(x) already rounds to +-1 for |x| above ~9,
        // so skip exp() entirely and return the sign.
        result[index] = sign(x);
    } else if (ax < 0.125) {
        // Near zero the exp formula cancels catastrophically: exp(2x) - 1 loses
        // most of its bits, and for |x| below ~3e-8 exp(2x) rounds to exactly
        // 1.0, giving 0 instead of ~x. Use the odd Taylor series
        //   tanh(x) = x - x^3/3 + 2x^5/15 - 17x^7/315
        // whose truncation error (~62x^9/2835) is far below float precision
        // on this range. Evaluated as x + x^3 * p(x^2); this also keeps the
        // sign of -0.0.
        float x2 = x * x;
        result[index] = x + x * x2 * (-1.0 / 3.0 + x2 * (2.0 / 15.0 - x2 * (17.0 / 315.0)));
    } else {
        // Moderate |x|: tanh(x) = (e^(2x) - 1) / (e^(2x) + 1). Here e^(2x)
        // stays within a safe float range and the cancellation is mild.
        float exp_2x = exp(2.0 * x);
        result[index] = (exp_2x - 1.0) / (exp_2x + 1.0);
    }
}