hanzo-ml 0.10.2

Minimalist ML framework.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#version 450
// Numerically-stable softmax over the last dim. One invocation per row (row = all but
// the last dim flattened). Eliminates the GPU->CPU->GPU round-trip the fallback used.
//
// Data layout as used by main(): row r occupies elements [r * m, r * m + m) of both
// the input and the output buffer, i.e. rows are stored contiguously.

// Workgroup of 64 invocations along x. Each invocation handles the row whose index is
// its global x id, so covering every row needs at least ceil(nrows / 64) workgroups in x.
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
// Input values (declared readonly).
layout(set = 0, binding = 0) readonly  buffer X { float x[]; };
// Output values, same indexing as the input (declared writeonly).
layout(set = 0, binding = 1) writeonly buffer Y { float y[]; };
// nrows: number of rows to process; invocations with a row index >= nrows exit early.
// m:     number of elements per row (the softmax is taken over these m values).
layout(push_constant) uniform Pc { uint nrows; uint m; };
// Entry point: each invocation computes the softmax of one row of `x` into `y`.
//
// Row `row` occupies elements [row * m, row * m + m) of both buffers. The row is
// processed in three passes:
//   1. find the row maximum,
//   2. accumulate the sum of exp(x - max),
//   3. store exp(x - max) / sum into `y`.
// Subtracting the maximum keeps every exponent <= 0, so exp() cannot overflow.
//
// `y` is declared `writeonly`, so it must only ever be stored to. The exponential is
// therefore recomputed in pass 3 instead of being staged in `y` and rescaled in place
// with `*=`, which would read the write-only buffer.
void main() {
    uint row = gl_GlobalInvocationID.x;
    // Guard for invocations past the last row (e.g. a dispatch rounded up to whole workgroups).
    if (row >= nrows) { return; }
    uint base = row * m;

    // Pass 1: row maximum, seeded with the most negative finite 32-bit float.
    float mx = -3.402823466e38;
    for (uint i = 0u; i < m; i++) { mx = max(mx, x[base + i]); }

    // Pass 2: normaliser, the sum of the shifted exponentials. Nothing is written yet.
    float s = 0.0;
    for (uint i = 0u; i < m; i++) { s += exp(x[base + i] - mx); }

    // Pass 3: one store per element; `y` is never read.
    float inv = 1.0 / s;
    for (uint i = 0u; i < m; i++) { y[base + i] = exp(x[base + i] - mx) * inv; }
}