hanzo-ml 0.10.2

Minimalist ML framework.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#version 450
// scatter set along `dim`. ids/src share a shape; for src flat index g = (outer, j, inner)
// with inner < right and j < dim_src, write dst[outer*(dim_dst*right) + ids[g]*right + inner] = src[g].
// (Duplicate ids -> last writer wins, which is unspecified but matches hanzo-ml's contract.)
// Workgroup shape: 64 invocations along x.
layout(local_size_x = 64) in;

// Scatter destination; main() stores into it.
layout(set = 0, binding = 0) buffer Dst {
    float dst[];
};

// Values to scatter, read by flat index.
layout(set = 0, binding = 1) readonly buffer Src {
    float src[];
};

// Target position along the scattered dimension, one entry per src element.
layout(set = 0, binding = 2) readonly buffer Ids {
    uint ids[];
};

// Launch parameters.
layout(push_constant) uniform Pc {
    uint n;        // invocations with global index >= n exit without writing
    uint right;    // extent of the trailing ("inner") part of the flat index
    uint dim_src;  // extent of src/ids along the scattered dimension
    uint dim_dst;  // extent of dst along the scattered dimension
};
// Scatter entry point: each invocation handles one flat index g of src/ids.
//
// g is decomposed as (outer, j, inner) with inner = g % right and
// outer = g / (right * dim_src); the element src[g] is stored at
// dst[outer * (dim_dst * right) + ids[g] * right + inner].
//
// Ids that are not below dim_dst are skipped: the computed offset would fall
// into a different outer slice (or beyond the end of dst), silently
// clobbering unrelated elements. Valid ids behave exactly as before.
void main() {
    uint g = gl_GlobalInvocationID.x;
    // Invocations past the n elements have nothing to do.
    if (g >= n) { return; }

    uint id = ids[g];
    // Reject out-of-range targets instead of writing outside this outer slice.
    if (id >= dim_dst) { return; }

    uint inner = g % right;
    uint outer = g / (right * dim_src);
    dst[outer * (dim_dst * right) + id * right + inner] = src[g];
}