hanzo-ml 0.10.2

Minimalist ML framework.
Documentation
#version 450
// GELU, erf formulation: out = 0.5 * x * (1 + erf(x / sqrt(2))).  (hanzo-ml "gelu_erf")
// "Exact" in the op name refers to the erf-based definition (as opposed to a
// tanh-based approximation of GELU). erf itself is NOT exact here: GLSL 4.50 has
// no built-in erf, so erf_approx evaluates a polynomial approximation of it.
// Elementwise: o[i] is computed from inp[i] alone.
layout(local_size_x = 64) in;
// Input values (binding 0), only read by this shader.
// NOTE(review): presumably a contiguous f32 tensor buffer — confirm with the host side.
layout(set = 0, binding = 0) readonly  buffer In  { float inp[]; };
// Output values (binding 1), only written by this shader.
layout(set = 0, binding = 1) writeonly buffer Out { float o[]; };
// n: element count; only indices < n of inp/o are accessed.
layout(push_constant) uniform Pc { uint n; };
// Approximates erf(x) using the rational/polynomial form of Abramowitz & Stegun 7.1.26:
//   erf(|x|) ~= 1 - (A1*t + A2*t^2 + A3*t^3 + A4*t^4 + A5*t^5) * exp(-x^2),
//   with t = 1 / (1 + P*|x|).
// The formula is evaluated for |x|; odd symmetry (erf(-x) = -erf(x)) is restored
// by multiplying with sign(x), which also makes erf_approx(0) return exactly 0.
float erf_approx(float x) {
    const float P  = 0.3275911;
    const float A1 = 0.254829592;
    const float A2 = -0.284496736;
    const float A3 = 1.421413741;
    const float A4 = -1.453152027;
    const float A5 = 1.061405429;

    float ax = abs(x);
    float t = 1.0 / (1.0 + P * ax);

    // Horner evaluation, innermost coefficient first; the final multiply by t
    // supplies the lowest power (A1*t).
    float poly = A5;
    poly = poly * t + A4;
    poly = poly * t + A3;
    poly = poly * t + A2;
    poly = poly * t + A1;
    poly = poly * t;

    float erf_of_abs = 1.0 - poly * exp(-x * x);
    return sign(x) * erf_of_abs;
}
// Entry point: writes o[i] = GELU(inp[i]) = 0.5 * x * (1 + erf(x / sqrt(2))) for every i < n.
// Grid-stride loop: each invocation starts at its global index and advances by the
// total number of invocations in the dispatch. All n elements are therefore covered
// even when the host dispatches fewer than ceil(n / 64) workgroups, e.g. to stay
// within maxComputeWorkGroupCount[0]. With a ceil(n / 64) dispatch, each invocation
// still handles at most one element.
void main() {
    uint stride = gl_NumWorkGroups.x * gl_WorkGroupSize.x;
    for (uint i = gl_GlobalInvocationID.x; i < n; i += stride) {
        float x = inp[i];
        // x / sqrt(2) is computed as x * (1 / sqrt(2)).
        o[i] = 0.5 * x * (1.0 + erf_approx(x * 0.7071067811865476));
        // This was the last element for this invocation. Breaking here also keeps
        // i + stride from wrapping around past 2^32 - 1.
        if (stride >= n - i) { break; }
    }
}