hanzo-ml 0.10.2

Minimalist ML framework.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#version 450
// Row-wise argmax over the last dim: out[row] = index of the maximum in row. Output u32, length rows.
// Workgroup size: 64 invocations along x; main() maps gl_GlobalInvocationID.x to a row index.
layout(local_size_x = 64) in;
// Input values (read-only). main() addresses element c of a row as inp[row * cols + c].
layout(set = 0, binding = 0) readonly  buffer In  { float inp[]; };
// Output indices (write-only). main() stores one uint per row at o[row].
layout(set = 0, binding = 1) writeonly buffer Out { uint o[]; };
// Push constants: rows = number of rows to process, cols = number of elements in each row.
layout(push_constant) uniform Pc { uint rows; uint cols; };
// Entry point: each invocation scans one row of inp and stores the column index of
// its largest element in o[row].
//
// - Ties resolve to the lowest column index, because a later element replaces the
//   current best only when it is strictly greater.
// - A row with no elements (cols == 0) produces index 0 without reading inp.
void main() {
    uint row = gl_GlobalInvocationID.x;
    // Invocations past the last row have nothing to do.
    if (row >= rows) { return; }
    // With zero columns there is no inp[base] belonging to this row; reading it
    // could fall outside the buffer, so emit the default index and stop.
    if (cols == 0u) {
        o[row] = 0u;
        return;
    }
    // Offset of the row's first element in the flat input buffer.
    uint base = row * cols;
    // Seed the running maximum with column 0, then scan the remaining columns.
    float best = inp[base];
    uint bi = 0u;
    for (uint c = 1u; c < cols; c++) {
        float v = inp[base + c];
        // Strict comparison keeps the earliest index among equal maxima.
        if (v > best) { best = v; bi = c; }
    }
    o[row] = bi;
}