#version 450
// Gather along `dim`. `ids` has the OUTPUT shape. An output flat index
// g = (outer*dim_out + j)*right + inner, with inner < right and j < dim_out,
// reads the source element src[outer*(dim_src*right) + ids[g]*right + inner].
// Workgroups are 1-D with 64 invocations each; only the x component of the
// global invocation id is used by main().
layout(local_size_x = 64) in;
// Source values that are gathered from (read-only).
layout(set = 0, binding = 0) readonly buffer Src { float src[]; };
// Gather indices, read at the same flat index as the output element (read-only).
layout(set = 0, binding = 1) readonly buffer Ids { uint ids[]; };
// Gathered result, one float written per invocation with index < n (write-only).
layout(set = 0, binding = 2) writeonly buffer Out { float o[]; };
// n       : invocations with global x index >= n do nothing (output element count).
// right   : stride of the gathered dimension, i.e. number of `inner` positions.
// dim_out : extent of the gathered dimension in the output / ids.
// dim_src : extent of the gathered dimension in src.
layout(push_constant) uniform Pc { uint n; uint right; uint dim_out; uint dim_src; };
// One invocation per output element: invocation g looks up ids[g] and copies
// src[outer*(dim_src*right) + ids[g]*right + inner] into o[g], where
// inner = g % right and outer = g / (right*dim_out).
//
// Indices outside [0, dim_src) are not dereferenced; the output element is set
// to 0.0 instead, so a bad index can neither read a neighbouring outer slice
// nor run past the end of `src`.
//
// NOTE(review): all index arithmetic is 32-bit unsigned and wraps silently;
// this presumably relies on the host keeping tensor sizes below 2^32 elements
// - confirm against the caller.
void main() {
    uint g = gl_GlobalInvocationID.x;
    // Invocations at or beyond n have no output element to produce. This also
    // keeps the `%` and `/` below from running when n == 0.
    if (g >= n) { return; }

    // Split the flat output index into its position inside the gathered
    // dimension's stride (inner) and the index of the enclosing slice (outer).
    uint inner = g % right;
    uint outer = g / (right * dim_out);

    uint id = ids[g];
    // Reject out-of-range indices before they are used to address `src`.
    if (id >= dim_src) {
        o[g] = 0.0;
        return;
    }

    o[g] = src[outer * (dim_src * right) + id * right + inner];
}