#version 450
// ScatterAdd along axis 0 (output-centric, no atomics): updates [U, T],
// indices [U], output [out_dim, T]. out[j, t] = Σ_{i : idx[i]==j} updates[i, t].
// Each output element independently scans the updates, so there is no
// write-collision — avoids the need for float atomics.
// One-dimensional workgroups of 256 invocations; main() maps each global
// invocation index to one flattened output element.
layout(local_size_x = 256) in;
// Single float storage buffer at binding 0. main() reads the indices and the
// updates from it and writes the output into it, each region located by an
// offset from the push constants below. Indices are therefore stored as
// floats and rounded to integers when read.
layout(std430, binding = 0) buffer Arena { float data[]; };
// Per-dispatch parameters. All *_off members are used directly as indices
// into data[], i.e. they are element (float) offsets, not byte offsets.
layout(push_constant) uniform PC {
// Number of output rows (extent of the scattered axis 0).
uint out_dim;
// Number of elements per row (T); shared by updates and output.
uint trailing;
// Number of update rows / index entries (U).
uint num_updates;
// Offset of updates[0, 0] within data[].
uint upd_off;
// Offset of indices[0] within data[].
uint idx_off;
// Offset of out[0, 0] within data[].
uint out_off;
} pc;
// Entry point: each invocation computes one output element out[j, t] by
// scanning every index entry and summing the update rows whose index equals j.
// Each invocation writes only data[pc.out_off + gid], so no two invocations
// write the same output element.
//
// Index entries are floats and are rounded to the nearest integer by adding
// 0.5 and truncating. Entries that are negative, NaN, or too large for a uint
// match no output row and are skipped.
void main() {
    uint gid = gl_GlobalInvocationID.x;
    uint total = pc.out_dim * pc.trailing;
    // Invocations past the end of the output do nothing. When trailing == 0,
    // total is 0 and every invocation returns here, so the division and
    // modulo below never see a zero divisor.
    if (gid >= total) { return; }

    uint j = gid / pc.trailing;  // output row handled by this invocation
    uint t = gid % pc.trailing;  // position within the row

    // Loop-invariant part of the update address: updates[i, t] lives at
    // upd_base + i * trailing.
    uint upd_base = pc.upd_off + t;

    float acc = 0.0;
    for (uint i = 0u; i < pc.num_updates; i++) {
        float rounded = data[pc.idx_off + i] + 0.5;
        // Converting a negative float to uint is undefined in GLSL, and NaN
        // or values >= 2^32 cannot be represented either. Only convert when
        // the value is in range; NaN fails both comparisons and is skipped.
        bool representable = rounded >= 0.0 && rounded < 4294967296.0;
        if (representable && uint(rounded) == j) {
            acc += data[upd_base + i * pc.trailing];
        }
    }
    data[pc.out_off + gid] = acc;
}