hanzo-ml 0.10.2

Minimalist ML framework.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#version 450
// gather along `dim`. ids has the OUTPUT shape; for output flat index g decomposed as
// (outer, j, inner) with inner < right and j < dim_out, the source element is
//   src[outer*(dim_src*right) + ids[g]*right + inner].
layout(local_size_x = 64) in;
layout(set = 0, binding = 0) readonly  buffer Src { float src[]; };
layout(set = 0, binding = 1) readonly  buffer Ids { uint ids[]; };
layout(set = 0, binding = 2) writeonly buffer Out { float o[]; };
layout(push_constant) uniform Pc { uint n; uint right; uint dim_out; uint dim_src; };
void main() {
    uint g = gl_GlobalInvocationID.x;
    if (g >= n) { return; }
    uint inner = g % right;
    uint outer = g / (right * dim_out);
    uint id = ids[g];
    o[g] = src[outer * (dim_src * right) + id * right + inner];
}