#version 450
// Numerically-stable softmax over the last dim. One invocation per row (row = all but
// the last dim flattened). Eliminates the GPU->CPU->GPU round-trip the fallback used.
// 64 invocations per workgroup along x; main() maps gl_GlobalInvocationID.x to a
// row, so the dispatch needs at least ceil(nrows / 64) workgroups along x.
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
// Input logits: main() reads row r at x[r * m .. r * m + m - 1] (rows contiguous).
layout(set = 0, binding = 0) readonly buffer X { float x[]; };
// Output probabilities, same layout as x. Declared writeonly: shader code must
// never read y back.
layout(set = 0, binding = 1) writeonly buffer Y { float y[]; };
// nrows: number of rows to process; m: length of each row (the softmax dimension).
layout(push_constant) uniform Pc { uint nrows; uint m; };
// Numerically stable softmax of one row per invocation:
//   y[base + i] = exp(x[base + i] - mx) / sum_j exp(x[base + j] - mx),  base = row * m,
// where mx is the row maximum (subtracting it keeps exp() from overflowing).
// Three passes over the row: max, sum of shifted exponentials, normalized write.
void main() {
    uint row = gl_GlobalInvocationID.x;
    // The last workgroup can extend past nrows; those invocations do nothing.
    if (row >= nrows) { return; }
    // NOTE(review): 32-bit index arithmetic; assumes nrows * m < 2^32.
    uint base = row * m;

    // Pass 1: row maximum, starting from -FLT_MAX.
    float mx = -3.402823466e38;
    for (uint i = 0u; i < m; i++) { mx = max(mx, x[base + i]); }

    // Pass 2: sum of the shifted exponentials (same order as pass 3 uses them).
    float s = 0.0;
    for (uint i = 0u; i < m; i++) { s += exp(x[base + i] - mx); }

    // Pass 3: write normalized values. y is declared writeonly, so the
    // exponentials are recomputed from x instead of being read back from y.
    // Reading y (e.g. `y[...] *= inv`) is a compile error.
    float inv = 1.0 / s;
    for (uint i = 0u; i < m; i++) { y[base + i] = exp(x[base + i] - mx) * inv; }
}