hanzo-ml 0.10.2

Minimalist ML framework.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#version 450
// RMS norm over the last dim: y = x / sqrt(mean(x^2) + eps) * alpha. One invocation per row.
// Matches hanzo-ml CPU rms-norm: denom = sqrt(sum(x^2)/dim + eps); y = x / denom * alpha (dim is the push constant `m` below).
// Workgroup shape: 64 x 1 x 1 invocations. main() maps each invocation's global x index
// to one row, so a single workgroup covers up to 64 rows.
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
// Input values (read-only). Element i of row r is read from x[r * m + i].
layout(set = 0, binding = 0) readonly  buffer X { float x[]; };
// Per-element scale (read-only). Indexed by the position within the row only (alpha[i], i < m),
// so the same m scale values are applied to every row.
layout(set = 0, binding = 1) readonly  buffer A { float alpha[]; };
// Output values (write-only). Written at the same indices that are read from x.
layout(set = 0, binding = 2) writeonly buffer Y { float y[]; };
// Push constants:
//   nrows - number of rows to process; invocations with a global x index >= nrows do nothing.
//   m     - number of elements per row (the size of the normalized dimension).
//   eps   - value added to the mean of squares before the square root is taken.
layout(push_constant) uniform Pc { uint nrows; uint m; float eps; };
// Entry point: each invocation normalizes one row of `x` and writes the result to `y`.
//
// The row handled is selected by the invocation's global x index. For that row:
//   1. sum the squares of its m elements,
//   2. form rms = sqrt(sum / m + eps),
//   3. write y = x / rms * alpha for every element of the row.
// Invocations whose index is not below `nrows` touch no buffer at all.
void main() {
    uint rowIndex = gl_GlobalInvocationID.x;
    if (rowIndex < nrows) {
        // Offset of the first element of this row in x and y.
        uint rowStart = rowIndex * m;

        // Pass 1: accumulate the sum of squares, walking the row front to back.
        float sumSquares = 0.0;
        uint col = 0u;
        while (col < m) {
            float value = x[rowStart + col];
            sumSquares += value * value;
            col++;
        }

        // Root of (mean of squares + eps); eps sits inside the square root.
        float rms = sqrt(sumSquares / float(m) + eps);

        // Pass 2: divide by the rms first, then apply the per-column scale.
        for (uint j = 0u; j < m; j++) {
            uint idx = rowStart + j;
            y[idx] = x[idx] / rms * alpha[j];
        }
    }
}