#version 450
// Gather slices along one dimension (e.g. an embedding lookup). The source is viewed
// as a row-major [left, dim_size, right] tensor; `ids` selects n_ids indices along the
// middle dim, producing a [left, n_ids, right] f32 output where
// out[l][i][r] = inp[l][ids[i]][r]. total = left*n_ids*right.
// Workgroup of 64 invocations along x; the host is expected to dispatch along x only.
layout(local_size_x = 64) in;
// Indices into the middle (dim_size) axis of `inp`; n_ids entries, one per output slice.
layout(set = 0, binding = 0) readonly buffer Ids { uint ids[]; };
// Source tensor viewed as [left, dim_size, right]; presumably sized left*dim_size*right
// floats by the host — TODO confirm.
layout(set = 0, binding = 1) readonly buffer In { float inp[]; };
// Destination tensor viewed as [left, n_ids, right]; written at indices [0, left*n_ids*right).
layout(set = 0, binding = 2) writeonly buffer Out { float outp[]; };
// Shape parameters supplied per dispatch. All arithmetic on them is 32-bit uint, so the
// host must keep left*n_ids*right and left*dim_size*right below 2^32.
layout(push_constant) uniform Pc { uint left; uint dim_size; uint right; uint n_ids; };
// Gather kernel: outp[l][i][r] = inp[l][ids[i]][r] for every output element.
//
// Output elements are visited with a grid-stride loop (gid, gid + stride, ...), so the
// result is complete even if the host dispatches fewer than ceil(total / 64) workgroups,
// e.g. when capped at maxComputeWorkGroupCount[0]. With a full dispatch every invocation
// runs at most one iteration, which matches the original one-element-per-invocation mapping.
//
// An id >= dim_size would address memory outside the selected source row (undefined
// without robustBufferAccess), so such elements are written as 0.0 instead.
//
// NOTE(review): all index math is 32-bit; the host must keep element counts below 2^32.
void main() {
    uint total = left * n_ids * right;
    // Number of invocations in the whole dispatch (x dimension only).
    uint stride = gl_NumWorkGroups.x * gl_WorkGroupSize.x;
    // Elements per `left` slice of the output; loop-invariant.
    uint slice = n_ids * right;

    // total == 0 (any zero extent) skips the loop, so the divisions below never see 0.
    for (uint gid = gl_GlobalInvocationID.x; gid < total; gid += stride) {
        uint r = gid % right;
        uint i = (gid / right) % n_ids;
        uint l = gid / slice;
        uint src = ids[i];

        outp[gid] = (src < dim_size) ? inp[(l * dim_size + src) * right + r] : 0.0;

        // Stop before gid + stride could wrap around 2^32 and revisit earlier elements.
        if (total - gid <= stride) {
            break;
        }
    }
}