#version 450
// Pass 2: Compute exp(x - max) per element (no reduction needed)
// Workgroup shape: 256 invocations along x; y and z default to 1.
layout(local_size_x = 256) in;
// Source values, read only by this pass.
layout(set = 0, binding = 0) readonly buffer Input { float x[]; };
// Destination for exp(x[i] - max_value); never read by this pass.
layout(set = 0, binding = 1) writeonly buffer Output { float result[]; };
// Per-dispatch parameters (default push-constant layout is std430:
// n at byte offset 0, max_value at byte offset 4).
layout(push_constant) uniform Params {
// Number of elements to process; main() only touches indices below n.
int n;
// Value subtracted from every element before exp(), which keeps the exponent
// small for numerical stability. Presumably the maximum of x produced by an
// earlier reduction pass ("Pass 1") - confirm against the host code.
float max_value;
};
// Element-wise pass: result[i] = exp(x[i] - max_value) for every i in [0, n).
// Each invocation handles exactly one element, chosen by its global x index.
// Invocations whose index falls at or beyond n (e.g. the idle tail when the
// dispatch is rounded up to a multiple of the 256-wide workgroup) exit
// without touching either buffer.
void main() {
    uint idx = gl_GlobalInvocationID.x;

    // `n` is a signed int but `idx` is unsigned. A direct `idx < n` implicitly
    // converts `n` to uint, so a negative n would wrap to a huge bound and
    // every invocation would access memory past the end of both buffers.
    // Reject non-positive counts first, then compare in the unsigned domain.
    if (n <= 0 || idx >= uint(n)) {
        return;
    }

    result[idx] = exp(x[idx] - max_value);
}