#version 450
// Softmax along a strided axis of length `axis_len`. The tensor is viewed as
// [outer, axis_len, inner]: element (o, k, j) lives at index
// off + o*axis_len*inner + k*inner + j of `data`, where `off` is `in_off` for
// the input and `out_off` for the output. One invocation per (o, j).
// Workgroup size: 256 invocations along x. main() only consumes the x
// component of the global invocation ID.
layout(local_size_x = 256) in;
// Storage buffer of float elements at binding 0. main() reads the input and
// writes the output through this same array, at in_off and out_off.
layout(std430, binding = 0) buffer Arena { float data[]; };
// Push constants describing the [outer, axis_len, inner] view of the tensor
// and where the input and output regions start inside `data`.
layout(push_constant) uniform PC {
// Extent of the leading (outer) dimension of the view.
uint outer;
// Number of elements along the softmax axis.
uint axis_len;
// Extent of the trailing (inner) dimension; also the element stride between
// consecutive entries along the softmax axis.
uint inner;
// Index into `data` (in float elements, not bytes) where the input starts.
uint in_off;
// Index into `data` (in float elements, not bytes) where the output starts.
uint out_off;
} pc;
// Entry point. Each invocation handles one (o, j) pair of the
// [outer, axis_len, inner] view: the axis_len elements spaced pc.inner apart,
// starting at o*axis_len*inner + j. It makes three passes over that column:
// running maximum, sum of exp(x - max), then the normalized write.
void main() {
    uint column = gl_GlobalInvocationID.x;
    // Invocations beyond the last (o, j) pair have nothing to do.
    if (column >= pc.outer * pc.inner) {
        return;
    }

    // Decompose the flat invocation index into (o, j) and locate the first
    // element of this column relative to the start of a region.
    uint stride = pc.inner;
    uint first = (column / stride) * pc.axis_len * stride + (column % stride);
    uint src = pc.in_off + first;
    uint dst = pc.out_off + first;

    // Pass 1: maximum along the axis, seeded with the most negative finite
    // 32-bit float. Subtracting it below keeps every exp() argument <= 0.
    float peak = -3.402823466e38;
    for (uint k = 0u; k < pc.axis_len; ++k) {
        peak = max(peak, data[src + k * stride]);
    }

    // Pass 2: accumulate the softmax denominator.
    float denom = 0.0;
    for (uint k = 0u; k < pc.axis_len; ++k) {
        denom += exp(data[src + k * stride] - peak);
    }

    // Only a strictly positive sum is inverted; anything else (including a
    // sum that fails the comparison) gives a scale of 0 instead of dividing.
    float scale;
    if (denom > 0.0) {
        scale = 1.0 / denom;
    } else {
        scale = 0.0;
    }

    // Pass 3: recompute each exponential and write it out normalized.
    for (uint k = 0u; k < pc.axis_len; ++k) {
        uint step = k * stride;
        data[dst + step] = exp(data[src + step] - peak) * scale;
    }
}