hanzo-ml 0.10.2

Minimalist ML framework.
Documentation
#version 450
// Gather rows along a dim (embeddings). Source is viewed as
// [left, dim_size, right] row-major; `ids` selects n_ids indices along the
// middle dim, producing a [left, n_ids, right] f32 output. total = left*n_ids*right.
// Work split: one invocation per output element, 64 invocations per workgroup
// along x, so the host should dispatch at least ceil(total / 64) workgroups in x.
layout(local_size_x = 64) in;

// Binding 0: u32 gather indices; main() reads ids[0 .. n_ids-1]. Each value is
// a position along the middle (dim_size) axis of the source.
layout(set = 0, binding = 0) readonly  buffer Ids { uint ids[]; };
// Binding 1: f32 source tensor, addressed as [left, dim_size, right] row-major.
layout(set = 0, binding = 1) readonly  buffer In  { float inp[]; };
// Binding 2: f32 destination, written as [left, n_ids, right] row-major
// (total elements; element gid is written by invocation gid).
layout(set = 0, binding = 2) writeonly buffer Out { float outp[]; };
// Push constants: four u32 shape values, in declaration order, at byte offsets
// 0, 4, 8, 12 (16 bytes total).
layout(push_constant) uniform Pc { uint left; uint dim_size; uint right; uint n_ids; };

// Entry point: each invocation produces exactly one output element.
//
// The flat output index is decomposed as
//     gid = (l * n_ids + i) * right + r
// with l in [0, left), i in [0, n_ids), r in [0, right), i.e. the output is
// addressed as [left, n_ids, right] row-major. The value copied comes from
// source row ids[i] of slab l:
//     inp[(l * dim_size + ids[i]) * right + r]
void main() {
    uint gid = gl_GlobalInvocationID.x;
    uint total = left * n_ids * right;

    // Invocations beyond the last output element (e.g. from rounding the
    // dispatch up to whole 64-wide workgroups) do nothing. This guard also
    // covers right == 0 or n_ids == 0 (total == 0), so the divisions and
    // modulos below never divide by zero.
    if (gid >= total) {
        return;
    }

    uint r = gid % right;
    uint i = (gid / right) % n_ids;
    uint l = gid / (right * n_ids);
    uint row = ids[i];

    // `ids` is data, not host-side shape bookkeeping, so it can hold an index
    // >= dim_size. Such an index would silently read from the neighbouring
    // slab (l + 1), or past the end of `inp` on the last slab. Out-of-bounds
    // storage-buffer reads are undefined unless robustBufferAccess is enabled.
    // Emit a deterministic 0.0 for invalid indices instead; valid indices
    // behave exactly as before.
    if (row < dim_size) {
        outp[gid] = inp[(l * dim_size + row) * right + r];
    } else {
        outp[gid] = 0.0;
    }
}